gin-base/db/core/relation.go

368 lines
9.3 KiB
Go
Raw Blame History

This file contains ambiguous Unicode characters!

This file contains ambiguous Unicode characters that may be confused with others in your current locale. If your use case is intentional and legitimate, you can safely ignore this warning. Use the Escape button to highlight these characters.

package core
import (
"fmt"
"reflect"
"strings"
)
// RelationType identifies the kind of association between two models.
type RelationType int
// Relation kinds understood by RelationLoader; values are assigned by iota starting at 0.
const (
HasOne RelationType = iota // one-to-one
HasMany // one-to-many
BelongsTo // many-to-one
ManyToMany // many-to-many
)
// RelationInfo describes a single relation of a model as resolved from the
// struct field and its tags by parseRelation.
type RelationInfo struct {
Type RelationType // relation kind
Field string // name of the struct field that holds the relation
Model interface{} // related model value, built by reflection from the field type
FK string // foreign key
PK string // primary key
JoinTable string // join table (many-to-many)
JoinFK string // foreign key column in the join table
JoinJoinFK string // foreign key column in the join table for the related side
}
// RelationLoader preloads the relations of models.
type RelationLoader struct {
db *Database // database handle the loader builds its relation queries on
}
// NewRelationLoader returns a RelationLoader that runs its queries through db.
func NewRelationLoader(db *Database) *RelationLoader {
	loader := new(RelationLoader)
	loader.db = db
	return loader
}
// Preload loads the relation named by relation for every element of models.
//
// models must be a pointer to a slice; an empty slice is a no-op. The relation
// is resolved from the first element of the slice and then dispatched to the
// loader that matches its type. conditions is accepted but is not read by
// this method.
func (rl *RelationLoader) Preload(models interface{}, relation string, conditions ...interface{}) error {
	ptr := reflect.ValueOf(models)
	if ptr.Kind() != reflect.Ptr {
		return fmt.Errorf("models 必须是指针类型")
	}

	list := ptr.Elem()
	switch {
	case list.Kind() != reflect.Slice:
		return fmt.Errorf("models 必须是指向 Slice 的指针")
	case list.Len() == 0:
		// Nothing to load for an empty slice.
		return nil
	}

	// The first element stands in for the element type of the whole slice.
	info, err := rl.parseRelation(list.Index(0).Interface(), relation)
	if err != nil {
		return fmt.Errorf("解析关联失败:%w", err)
	}

	// Pick the loader for the resolved relation type.
	var load func(reflect.Value, *RelationInfo) error
	switch info.Type {
	case HasOne:
		load = rl.loadHasOne
	case HasMany:
		load = rl.loadHasMany
	case BelongsTo:
		load = rl.loadBelongsTo
	case ManyToMany:
		load = rl.loadManyToMany
	default:
		return fmt.Errorf("不支持的关联类型:%v", info.Type)
	}
	return load(list, info)
}
// parseRelation resolves the relation stored in the field named relation of
// model into a RelationInfo.
//
// model must be a struct or a pointer to one. The field must be a struct, a
// pointer to a struct, or a slice of either; RelationInfo.Model is always a
// freshly allocated pointer to that related struct type, whatever pointer or
// slice wrapping the field itself uses.
//
// Settings are read from the `gorm` tag (`key:value` pairs separated by `;`,
// keys matched case-insensitively) and from the standalone `foreignkey`,
// `references` and `many2many` tags; the standalone foreignkey/references
// tags override the gorm tag. Missing keys default to "<model name>Id" for
// the foreign key and "id" for the primary key.
//
// An error is returned when model is nil or not a struct, when the field does
// not exist, or when the field cannot hold a related struct.
func (rl *RelationLoader) parseRelation(model interface{}, relation string) (*RelationInfo, error) {
	structType := reflect.TypeOf(model)
	for structType != nil && structType.Kind() == reflect.Ptr {
		structType = structType.Elem()
	}
	// reflect.TypeOf(nil) is nil and NumField panics on non-structs, so both
	// cases are rejected up front.
	if structType == nil || structType.Kind() != reflect.Struct {
		return nil, fmt.Errorf("模型必须是结构体或结构体指针")
	}

	// Only fields declared directly on the struct are considered.
	var relationField reflect.StructField
	found := false
	for i := 0; i < structType.NumField(); i++ {
		if field := structType.Field(i); field.Name == relation {
			relationField, found = field, true
			break
		}
	}
	if !found {
		return nil, fmt.Errorf("字段 %s 不存在", relation)
	}

	// Strip pointer/slice wrappers so that T, *T, []T and []*T all resolve to
	// the related struct type T.
	relatedType := relationField.Type
	for relatedType.Kind() == reflect.Ptr || relatedType.Kind() == reflect.Slice || relatedType.Kind() == reflect.Array {
		relatedType = relatedType.Elem()
	}
	if relatedType.Kind() != reflect.Struct {
		return nil, fmt.Errorf("字段 %s 不是关联字段", relation)
	}

	gormTag := relationField.Tag.Get("gorm")
	fkTag := relationField.Tag.Get("foreignkey")
	referencesTag := relationField.Tag.Get("references")
	joinTableTag := relationField.Tag.Get("many2many")

	info := &RelationInfo{
		Field: relation,
		Model: reflect.New(relatedType).Interface(),
	}

	// Classify the relation: slices are to-many, everything else is to-one.
	isSlice := relationField.Type.Kind() == reflect.Slice
	switch {
	case isSlice && joinTableTag != "":
		info.Type = ManyToMany
		info.JoinTable = joinTableTag
	case isSlice:
		info.Type = HasMany
	case fkTag != "" || referencesTag != "":
		// A standalone key tag on a to-one field marks the key as living on
		// the current model, i.e. many-to-one.
		// NOTE(review): keys given only inside the gorm tag do not switch a
		// to-one field to BelongsTo, and the default foreign key below is
		// still derived from the current model's name — confirm intended.
		info.Type = BelongsTo
	default:
		info.Type = HasOne
	}

	// GORM-style settings. SplitN keeps values that themselves contain ':'.
	for _, part := range strings.Split(gormTag, ";") {
		kv := strings.SplitN(part, ":", 2)
		if len(kv) != 2 {
			continue
		}
		value := strings.TrimSpace(kv[1])
		switch strings.ToLower(strings.TrimSpace(kv[0])) {
		case "foreignkey":
			info.FK = value
		case "references":
			info.PK = value
		case "jointable":
			info.JoinTable = value
		case "many2many":
			// GORM declares many-to-many as gorm:"many2many:<table>".
			if isSlice {
				info.Type = ManyToMany
				info.JoinTable = value
			}
		case "joinforeignkey":
			info.JoinFK = value
		case "joinreferences":
			info.JoinJoinFK = value
		}
	}

	// Standalone tags take precedence over the gorm tag.
	if fkTag != "" {
		info.FK = fkTag
	}
	if referencesTag != "" {
		info.PK = referencesTag
	}

	// Defaults.
	if info.FK == "" {
		info.FK = structType.Name() + "Id"
	}
	if info.PK == "" {
		info.PK = "id"
	}
	return info, nil
}
// loadHasOne loads a one-to-one relation: for every model it finds the related
// row whose relation.FK equals the model's ID and stores it in the relation
// field.
func (rl *RelationLoader) loadHasOne(models reflect.Value, relation *RelationInfo) error {
	return rl.loadRelated(models, relation, false)
}

// loadHasMany loads a one-to-many relation: the lookup is the same as for
// loadHasOne, but every matching row is collected into the slice-typed
// relation field.
func (rl *RelationLoader) loadHasMany(models reflect.Value, relation *RelationInfo) error {
	return rl.loadRelated(models, relation, false)
}

// loadBelongsTo loads a many-to-one relation: for every model it finds the
// related row whose relation.PK equals the model's relation.FK and stores it
// in the relation field.
func (rl *RelationLoader) loadBelongsTo(models reflect.Value, relation *RelationInfo) error {
	return rl.loadRelated(models, relation, true)
}

// loadRelated is the shared implementation behind the three loaders above.
//
// models is the slice of owners; its elements may be structs or pointers to
// structs, and nil elements are skipped. The owner key is read from "ID"
// (or relation.FK when belongsTo is set); related rows are selected and
// matched on relation.FK (or relation.PK when belongsTo is set). Keys are
// compared by their "%v" formatting, so differing integer types still match.
//
// A slice-typed relation field receives all matching rows; any other field
// receives the first match and is left untouched when there is none. Rows
// are adapted to the field's type, so T, *T, []T and []*T fields all work.
func (rl *RelationLoader) loadRelated(models reflect.Value, relation *RelationInfo, belongsTo bool) error {
	if relation == nil {
		return fmt.Errorf("关联信息不能为空")
	}
	ownerKey, relatedKey := "ID", relation.FK
	if belongsTo {
		ownerKey, relatedKey = relation.FK, relation.PK
	}

	// Collect the key values to query for.
	keys := make([]interface{}, 0, models.Len())
	for i := 0; i < models.Len(); i++ {
		owner, ok := relationOwnerStruct(models.Index(i))
		if !ok {
			continue
		}
		if key := rl.getFieldValue(owner.Interface(), ownerKey); key != nil {
			keys = append(keys, key)
		}
	}
	if len(keys) == 0 {
		return nil
	}

	// Results are always scanned into []*T, where T is the related struct,
	// regardless of how relation.Model happens to be wrapped.
	relatedType, err := relationStructType(relation.Model)
	if err != nil {
		return err
	}
	fresh := reflect.New(relatedType)
	target := relation.Model
	if reflect.TypeOf(target) != fresh.Type() {
		target = fresh.Interface()
	}

	query := rl.db.Model(target)
	// NOTE(review): whatever Where returns is discarded, exactly as before;
	// this relies on Where modifying query in place — confirm against the
	// query builder.
	query.Where(fmt.Sprintf("%s IN (?)", relatedKey), keys)

	found := reflect.New(reflect.SliceOf(fresh.Type()))
	if err := query.Find(found.Interface()); err != nil {
		return err
	}
	rows := found.Elem()

	// Index the rows by key once instead of rescanning them for every owner.
	byKey := make(map[string][]reflect.Value, rows.Len())
	for j := 0; j < rows.Len(); j++ {
		row := rows.Index(j)
		if row.IsNil() {
			continue
		}
		key, ok := relationKey(rl.getFieldValue(row.Interface(), relatedKey))
		if !ok {
			continue
		}
		byKey[key] = append(byKey[key], row)
	}

	// Attach the rows to their owners.
	for i := 0; i < models.Len(); i++ {
		owner, ok := relationOwnerStruct(models.Index(i))
		if !ok {
			continue
		}
		field := owner.FieldByName(relation.Field)
		if !field.IsValid() || !field.CanSet() {
			return fmt.Errorf("字段 %s 无法赋值", relation.Field)
		}
		key, ok := relationKey(rl.getFieldValue(owner.Interface(), ownerKey))
		if !ok {
			continue
		}
		matches := byKey[key]

		if field.Kind() == reflect.Slice {
			list := reflect.MakeSlice(field.Type(), 0, len(matches))
			for _, row := range matches {
				item, err := relationConvert(row, field.Type().Elem())
				if err != nil {
					return err
				}
				list = reflect.Append(list, item)
			}
			field.Set(list)
			continue
		}

		if len(matches) == 0 {
			continue
		}
		item, err := relationConvert(matches[0], field.Type())
		if err != nil {
			return err
		}
		field.Set(item)
	}
	return nil
}

// relationOwnerStruct unwraps pointers and interfaces around v and reports
// whether a struct was reached; nil pointers and non-struct values yield false.
func relationOwnerStruct(v reflect.Value) (reflect.Value, bool) {
	for v.Kind() == reflect.Ptr || v.Kind() == reflect.Interface {
		if v.IsNil() {
			return reflect.Value{}, false
		}
		v = v.Elem()
	}
	return v, v.Kind() == reflect.Struct
}

// relationStructType returns the struct type behind model after stripping any
// pointer, slice or array wrappers, or an error if there is none.
func relationStructType(model interface{}) (reflect.Type, error) {
	t := reflect.TypeOf(model)
	for t != nil && (t.Kind() == reflect.Ptr || t.Kind() == reflect.Slice || t.Kind() == reflect.Array) {
		t = t.Elem()
	}
	if t == nil || t.Kind() != reflect.Struct {
		return nil, fmt.Errorf("关联模型必须是结构体类型")
	}
	return t, nil
}

// relationKey formats a key value for matching. Nil keys never match; a
// pointer to a non-struct value is compared by the value it points to.
func relationKey(v interface{}) (string, bool) {
	rv := reflect.ValueOf(v)
	for rv.IsValid() && rv.Kind() == reflect.Ptr {
		if rv.IsNil() {
			return "", false
		}
		if rv.Elem().Kind() == reflect.Struct {
			break
		}
		rv = rv.Elem()
	}
	if !rv.IsValid() {
		return "", false
	}
	return fmt.Sprintf("%v", rv.Interface()), true
}

// relationConvert adapts row, a non-nil pointer to a related struct, to the
// wanted type: the pointer itself when it is assignable, otherwise a copy of
// the struct it points to.
func relationConvert(row reflect.Value, want reflect.Type) (reflect.Value, error) {
	switch {
	case row.Type().AssignableTo(want):
		return row, nil
	case row.Elem().Type().AssignableTo(want):
		return row.Elem(), nil
	}
	return reflect.Value{}, fmt.Errorf("关联数据类型 %s 无法赋值给 %s", row.Type(), want)
}
// loadManyToMany is the entry point for many-to-many relations, which have to
// be resolved through a join table:
//
//	SELECT * FROM table WHERE id IN (
//	    SELECT join_fk FROM join_table WHERE fk IN (pk_values)
//	)
//
// The implementation is unfinished. It returns nil when no model yields an
// "ID" value, an error when the join table settings are incomplete, and
// otherwise prepares the join-table query and returns a "not implemented yet"
// error without executing it.
func (rl *RelationLoader) loadManyToMany(models reflect.Value, relation *RelationInfo) error {
	// Gather the primary keys of all owners.
	count := models.Len()
	ownerIDs := make([]interface{}, 0, count)
	for idx := 0; idx < count; idx++ {
		if id := rl.getFieldValue(models.Index(idx).Interface(), "ID"); id != nil {
			ownerIDs = append(ownerIDs, id)
		}
	}
	if len(ownerIDs) == 0 {
		return nil
	}

	// All three join-table settings are required.
	for _, setting := range []string{relation.JoinTable, relation.JoinFK, relation.JoinJoinFK} {
		if setting == "" {
			return fmt.Errorf("多对多关联需要配置中间表信息")
		}
	}

	// Prepare the lookup against the join table. Simplified: the join table
	// should be queried for the related IDs here, followed by a query for the
	// related models; neither step exists yet.
	joinQuery := rl.db.Table(relation.JoinTable)
	joinQuery.Where(fmt.Sprintf("%s IN (?)", relation.JoinFK), ownerIDs)

	return fmt.Errorf("多对多关联实现中,请稍后使用")
}
// getFieldValue returns the value of the field named fieldName of model, or
// nil when it cannot be read.
//
// model may be a struct or any chain of pointers/interfaces leading to one;
// nil, nil pointers and non-struct values yield nil instead of panicking.
// The field is looked up by exact name first. If that finds nothing, the
// lookup is retried ignoring letter case and underscores, so column-style
// names such as "id" or "user_id" resolve to ID and UserID. Unexported fields
// always yield nil.
func (rl *RelationLoader) getFieldValue(model interface{}, fieldName string) interface{} {
	val := reflect.ValueOf(model)
	for val.Kind() == reflect.Ptr || val.Kind() == reflect.Interface {
		if val.IsNil() {
			return nil
		}
		val = val.Elem()
	}
	// FieldByName panics on anything but a struct (including the zero Value
	// produced by a nil model).
	if val.Kind() != reflect.Struct {
		return nil
	}

	field := val.FieldByName(fieldName)
	if !field.IsValid() {
		if want := strings.ReplaceAll(fieldName, "_", ""); want != "" {
			field = val.FieldByNameFunc(func(name string) bool {
				return strings.EqualFold(strings.ReplaceAll(name, "_", ""), want)
			})
		}
	}
	if field.IsValid() && field.CanInterface() {
		return field.Interface()
	}
	return nil
}
// getRelationTags returns the `key:value` settings found in the `gorm` tag of
// the field named fieldName, with keys and values trimmed of surrounding
// whitespace.
//
// structType may be a struct type or a pointer to one. Settings are separated
// by ';' and split at the first ':' only, so values that themselves contain
// colons (for example "constraint:OnUpdate:CASCADE") are kept whole. Entries
// without a ':' are ignored. The result is an empty, non-nil map when
// structType is nil or not a struct, when the field does not exist, or when
// it has no gorm tag. Only fields declared directly on the struct are
// considered.
func getRelationTags(structType reflect.Type, fieldName string) map[string]string {
	tags := make(map[string]string)

	for structType != nil && structType.Kind() == reflect.Ptr {
		structType = structType.Elem()
	}
	// NumField panics on anything but a struct type.
	if structType == nil || structType.Kind() != reflect.Struct {
		return tags
	}

	for i := 0; i < structType.NumField(); i++ {
		field := structType.Field(i)
		if field.Name != fieldName {
			continue
		}
		for _, part := range strings.Split(field.Tag.Get("gorm"), ";") {
			kv := strings.SplitN(part, ":", 2)
			if len(kv) != 2 {
				continue
			}
			tags[strings.TrimSpace(kv[0])] = strings.TrimSpace(kv[1])
		}
		break
	}
	return tags
}