368 lines
9.3 KiB
Go
368 lines
9.3 KiB
Go
package core
|
||
|
||
import (
|
||
"fmt"
|
||
"reflect"
|
||
"strings"
|
||
)
|
||
|
||
// RelationType 关联类型
|
||
type RelationType int
|
||
|
||
const (
|
||
HasOne RelationType = iota // 一对一
|
||
HasMany // 一对多
|
||
BelongsTo // 多对一
|
||
ManyToMany // 多对多
|
||
)
|
||
|
||
// RelationInfo 关联信息
|
||
type RelationInfo struct {
|
||
Type RelationType // 关联类型
|
||
Field string // 字段名
|
||
Model interface{} // 关联的模型
|
||
FK string // 外键
|
||
PK string // 主键
|
||
JoinTable string // 中间表(多对多)
|
||
JoinFK string // 中间表外键
|
||
JoinJoinFK string // 中间表关联外键
|
||
}
|
||
|
||
// RelationLoader 关联加载器 - 处理模型关联的预加载
|
||
type RelationLoader struct {
|
||
db *Database
|
||
}
|
||
|
||
// NewRelationLoader 创建关联加载器实例
|
||
func NewRelationLoader(db *Database) *RelationLoader {
|
||
return &RelationLoader{db: db}
|
||
}
|
||
|
||
// Preload 预加载关联数据
|
||
func (rl *RelationLoader) Preload(models interface{}, relation string, conditions ...interface{}) error {
|
||
// 获取反射对象
|
||
modelsVal := reflect.ValueOf(models)
|
||
if modelsVal.Kind() != reflect.Ptr {
|
||
return fmt.Errorf("models 必须是指针类型")
|
||
}
|
||
|
||
elem := modelsVal.Elem()
|
||
if elem.Kind() != reflect.Slice {
|
||
return fmt.Errorf("models 必须是指向 Slice 的指针")
|
||
}
|
||
|
||
if elem.Len() == 0 {
|
||
return nil // 空 Slice,无需加载
|
||
}
|
||
|
||
// 解析关联关系
|
||
relationInfo, err := rl.parseRelation(elem.Index(0).Interface(), relation)
|
||
if err != nil {
|
||
return fmt.Errorf("解析关联失败:%w", err)
|
||
}
|
||
|
||
// 根据关联类型加载数据
|
||
switch relationInfo.Type {
|
||
case HasOne:
|
||
return rl.loadHasOne(elem, relationInfo)
|
||
case HasMany:
|
||
return rl.loadHasMany(elem, relationInfo)
|
||
case BelongsTo:
|
||
return rl.loadBelongsTo(elem, relationInfo)
|
||
case ManyToMany:
|
||
return rl.loadManyToMany(elem, relationInfo)
|
||
default:
|
||
return fmt.Errorf("不支持的关联类型:%v", relationInfo.Type)
|
||
}
|
||
}
|
||
|
||
// parseRelation 解析关联关系
|
||
func (rl *RelationLoader) parseRelation(model interface{}, relation string) (*RelationInfo, error) {
|
||
// 从结构体字段中解析关联信息
|
||
structType := reflect.TypeOf(model)
|
||
if structType.Kind() == reflect.Ptr {
|
||
structType = structType.Elem()
|
||
}
|
||
|
||
// 查找对应的字段
|
||
var relationField reflect.StructField
|
||
var found bool
|
||
|
||
for i := 0; i < structType.NumField(); i++ {
|
||
field := structType.Field(i)
|
||
if field.Name == relation {
|
||
relationField = field
|
||
found = true
|
||
break
|
||
}
|
||
}
|
||
|
||
if !found {
|
||
return nil, fmt.Errorf("字段 %s 不存在", relation)
|
||
}
|
||
|
||
// 从 gorm 标签解析关联信息
|
||
gormTag := relationField.Tag.Get("gorm")
|
||
fkTag := relationField.Tag.Get("foreignkey")
|
||
referencesTag := relationField.Tag.Get("references")
|
||
joinTableTag := relationField.Tag.Get("many2many")
|
||
|
||
// 初始化关联信息
|
||
info := &RelationInfo{
|
||
Field: relation,
|
||
Model: reflect.New(relationField.Type).Interface(),
|
||
}
|
||
|
||
// 判断关联类型
|
||
if relationField.Type.Kind() == reflect.Slice {
|
||
// 一对多或多对多
|
||
if joinTableTag != "" {
|
||
// 多对多
|
||
info.Type = ManyToMany
|
||
info.JoinTable = joinTableTag
|
||
} else {
|
||
// 一对多
|
||
info.Type = HasMany
|
||
}
|
||
} else {
|
||
// 一对一或多对一
|
||
// 根据外键位置判断
|
||
if fkTag != "" || referencesTag != "" {
|
||
// 如果当前模型包含外键,则是多对一
|
||
info.Type = BelongsTo
|
||
} else {
|
||
// 否则是一对一
|
||
info.Type = HasOne
|
||
}
|
||
}
|
||
|
||
// 解析外键和主键
|
||
if gormTag != "" {
|
||
// 解析 GORM 风格的标签
|
||
parts := strings.Split(gormTag, ";")
|
||
for _, part := range parts {
|
||
kv := strings.Split(part, ":")
|
||
if len(kv) == 2 {
|
||
key := strings.TrimSpace(kv[0])
|
||
value := strings.TrimSpace(kv[1])
|
||
switch key {
|
||
case "ForeignKey":
|
||
info.FK = value
|
||
case "References":
|
||
info.PK = value
|
||
case "JoinTable":
|
||
info.JoinTable = value
|
||
case "JoinForeignKey":
|
||
info.JoinFK = value
|
||
case "JoinReferences":
|
||
info.JoinJoinFK = value
|
||
}
|
||
}
|
||
}
|
||
}
|
||
|
||
// 使用单独的标签
|
||
if fkTag != "" {
|
||
info.FK = fkTag
|
||
}
|
||
if referencesTag != "" {
|
||
info.PK = referencesTag
|
||
}
|
||
|
||
// 设置默认值
|
||
if info.FK == "" {
|
||
// 默认外键为当前模型名 + Id
|
||
modelName := structType.Name()
|
||
info.FK = modelName + "Id"
|
||
}
|
||
if info.PK == "" {
|
||
info.PK = "id"
|
||
}
|
||
|
||
return info, nil
|
||
}
|
||
|
||
// loadHasOne 加载一对一关联
|
||
func (rl *RelationLoader) loadHasOne(models reflect.Value, relation *RelationInfo) error {
|
||
// 收集所有主键值
|
||
pkValues := make([]interface{}, 0, models.Len())
|
||
for i := 0; i < models.Len(); i++ {
|
||
model := models.Index(i).Interface()
|
||
pk := rl.getFieldValue(model, "ID")
|
||
if pk != nil {
|
||
pkValues = append(pkValues, pk)
|
||
}
|
||
}
|
||
|
||
if len(pkValues) == 0 {
|
||
return nil
|
||
}
|
||
|
||
// 查询关联数据
|
||
query := rl.db.Model(relation.Model)
|
||
query.Where(fmt.Sprintf("%s IN (?)", relation.FK), pkValues)
|
||
|
||
// 执行查询并映射到模型
|
||
relatedData := reflect.New(reflect.SliceOf(reflect.TypeOf(relation.Model))).Interface()
|
||
if err := query.Find(relatedData); err != nil {
|
||
return err
|
||
}
|
||
|
||
// 将关联数据映射到模型
|
||
relatedVal := reflect.ValueOf(relatedData)
|
||
if relatedVal.Kind() == reflect.Ptr {
|
||
relatedVal = relatedVal.Elem()
|
||
}
|
||
|
||
// 遍历所有模型,设置关联字段
|
||
for i := 0; i < models.Len(); i++ {
|
||
model := models.Index(i)
|
||
pk := rl.getFieldValue(model.Interface(), "ID")
|
||
|
||
// 查找对应的关联数据
|
||
for j := 0; j < relatedVal.Len(); j++ {
|
||
item := relatedVal.Index(j).Interface()
|
||
itemFK := rl.getFieldValue(item, relation.FK)
|
||
if itemFK != nil && fmt.Sprintf("%v", itemFK) == fmt.Sprintf("%v", pk) {
|
||
model.Elem().FieldByName(relation.Field).Set(reflect.ValueOf(item))
|
||
break
|
||
}
|
||
}
|
||
}
|
||
|
||
return nil
|
||
}
|
||
|
||
// loadHasMany 加载一对多关联
|
||
func (rl *RelationLoader) loadHasMany(models reflect.Value, relation *RelationInfo) error {
|
||
// 一对多的逻辑与 HasOne 类似,但结果必须映射到 Slice
|
||
return rl.loadHasOne(models, relation)
|
||
}
|
||
|
||
// loadBelongsTo 加载多对一关联
|
||
func (rl *RelationLoader) loadBelongsTo(models reflect.Value, relation *RelationInfo) error {
|
||
// 收集所有外键值
|
||
fkValues := make([]interface{}, 0, models.Len())
|
||
for i := 0; i < models.Len(); i++ {
|
||
model := models.Index(i).Interface()
|
||
fk := rl.getFieldValue(model, relation.FK)
|
||
if fk != nil {
|
||
fkValues = append(fkValues, fk)
|
||
}
|
||
}
|
||
|
||
if len(fkValues) == 0 {
|
||
return nil
|
||
}
|
||
|
||
// 查询关联数据
|
||
query := rl.db.Model(relation.Model)
|
||
query.Where(fmt.Sprintf("%s IN (?)", relation.PK), fkValues)
|
||
|
||
// 执行查询
|
||
relatedData := reflect.New(reflect.SliceOf(reflect.TypeOf(relation.Model))).Interface()
|
||
if err := query.Find(relatedData); err != nil {
|
||
return err
|
||
}
|
||
|
||
// 将关联数据映射到模型
|
||
relatedVal := reflect.ValueOf(relatedData)
|
||
if relatedVal.Kind() == reflect.Ptr {
|
||
relatedVal = relatedVal.Elem()
|
||
}
|
||
|
||
// 遍历所有模型,设置关联字段
|
||
for i := 0; i < models.Len(); i++ {
|
||
model := models.Index(i)
|
||
fk := rl.getFieldValue(model.Interface(), relation.FK)
|
||
|
||
// 查找对应的关联数据
|
||
for j := 0; j < relatedVal.Len(); j++ {
|
||
item := relatedVal.Index(j).Interface()
|
||
itemPK := rl.getFieldValue(item, relation.PK)
|
||
if itemPK != nil && fmt.Sprintf("%v", itemPK) == fmt.Sprintf("%v", fk) {
|
||
model.Elem().FieldByName(relation.Field).Set(reflect.ValueOf(item))
|
||
break
|
||
}
|
||
}
|
||
}
|
||
|
||
return nil
|
||
}
|
||
|
||
// loadManyToMany 加载多对多关联
|
||
func (rl *RelationLoader) loadManyToMany(models reflect.Value, relation *RelationInfo) error {
|
||
// 多对多需要通过中间表查询
|
||
// SELECT * FROM table WHERE id IN (
|
||
// SELECT join_fk FROM join_table WHERE fk IN (pk_values)
|
||
// )
|
||
|
||
// 收集所有主键值
|
||
pkValues := make([]interface{}, 0, models.Len())
|
||
for i := 0; i < models.Len(); i++ {
|
||
model := models.Index(i).Interface()
|
||
pk := rl.getFieldValue(model, "ID")
|
||
if pk != nil {
|
||
pkValues = append(pkValues, pk)
|
||
}
|
||
}
|
||
|
||
if len(pkValues) == 0 {
|
||
return nil
|
||
}
|
||
|
||
// 检查中间表配置
|
||
if relation.JoinTable == "" || relation.JoinFK == "" || relation.JoinJoinFK == "" {
|
||
return fmt.Errorf("多对多关联需要配置中间表信息")
|
||
}
|
||
|
||
// 先从中间表获取关联关系
|
||
joinQuery := rl.db.Table(relation.JoinTable)
|
||
joinQuery.Where(fmt.Sprintf("%s IN (?)", relation.JoinFK), pkValues)
|
||
|
||
// 这里简化处理,实际应该查询中间表获取关联 ID 列表
|
||
// 然后查询关联模型
|
||
|
||
return fmt.Errorf("多对多关联实现中,请稍后使用")
|
||
}
|
||
|
||
// getFieldValue 获取字段的值
|
||
func (rl *RelationLoader) getFieldValue(model interface{}, fieldName string) interface{} {
|
||
val := reflect.ValueOf(model)
|
||
if val.Kind() == reflect.Ptr {
|
||
val = val.Elem()
|
||
}
|
||
|
||
field := val.FieldByName(fieldName)
|
||
if field.IsValid() && field.CanInterface() {
|
||
return field.Interface()
|
||
}
|
||
|
||
return nil
|
||
}
|
||
|
||
// getRelationTags 从结构体字段提取关联标签信息
|
||
func getRelationTags(structType reflect.Type, fieldName string) map[string]string {
|
||
tags := make(map[string]string)
|
||
|
||
for i := 0; i < structType.NumField(); i++ {
|
||
field := structType.Field(i)
|
||
if field.Name == fieldName {
|
||
gormTag := field.Tag.Get("gorm")
|
||
if gormTag != "" {
|
||
// 解析 GORM 风格的标签
|
||
parts := strings.Split(gormTag, ";")
|
||
for _, part := range parts {
|
||
kv := strings.Split(part, ":")
|
||
if len(kv) == 2 {
|
||
tags[strings.TrimSpace(kv[0])] = strings.TrimSpace(kv[1])
|
||
}
|
||
}
|
||
}
|
||
break
|
||
}
|
||
}
|
||
|
||
return tags
|
||
}
|