gin-base/db/core/relation.go

200 lines
5.1 KiB
Go
Raw Blame History

This file contains ambiguous Unicode characters!

This file contains ambiguous Unicode characters that may be confused with others in your current locale. If your use case is intentional and legitimate, you can safely ignore this warning. Use the Escape button to highlight these characters.

package core
import (
"fmt"
"reflect"
"strings"
)
// RelationType identifies the kind of association between two models.
type RelationType int

const (
	HasOne     RelationType = iota // one-to-one
	HasMany                        // one-to-many
	BelongsTo                      // many-to-one
	ManyToMany                     // many-to-many
)

// String returns the constant's Go name (for example "HasMany"), or
// "RelationType(N)" for a value outside the declared set, so that errors and
// logs that format a RelationType with %v or %s are readable.
func (t RelationType) String() string {
	switch t {
	case HasOne:
		return "HasOne"
	case HasMany:
		return "HasMany"
	case BelongsTo:
		return "BelongsTo"
	case ManyToMany:
		return "ManyToMany"
	}
	// Convert to int first so fmt does not recurse back into String.
	return fmt.Sprintf("RelationType(%d)", int(t))
}
// RelationInfo holds the metadata needed to load a single association.
type RelationInfo struct {
	// Kind of association; Preload switches on it to pick a loader.
	Type RelationType
	// Name of the association field on the owner struct.
	Field string
	// Instance of the associated model, handed to Database.Model when the
	// lookup query is built.
	Model interface{}

	// Foreign key. NOTE(review): loadHasOne formats this into SQL as a column
	// of the associated table, while loadBelongsTo reads it as a field name on
	// the owner struct; the two uses disagree.
	FK string
	// Primary (referenced) key.
	PK string

	// Many-to-many only: name of the join table.
	JoinTable string
	// Many-to-many only: join-table foreign key (presumably pointing at the
	// owner — confirm).
	JoinFK string
	// Many-to-many only: join-table association foreign key (presumably
	// pointing at the associated model — confirm).
	JoinJoinFK string
}
// RelationLoader eagerly loads ("preloads") model associations, building its
// queries on the Database it wraps.
type RelationLoader struct {
	db *Database // handle used to build association queries
}

// NewRelationLoader returns a RelationLoader that builds its queries on db.
func NewRelationLoader(db *Database) *RelationLoader {
	loader := new(RelationLoader)
	loader.db = db
	return loader
}
// Preload eagerly loads the association named relation for every element of
// the slice that models points to.
//
// models must be a pointer to a slice; anything else is rejected with an
// error, and an empty slice is a no-op. The association metadata is resolved
// from the first element only, so all elements are assumed to share one type.
// The actual loading is delegated to the loader matching the relation type.
//
// NOTE(review): conditions is accepted but never used here, so any extra
// filters passed by callers are currently ignored.
func (rl *RelationLoader) Preload(models interface{}, relation string, conditions ...interface{}) error {
	ptr := reflect.ValueOf(models)
	if ptr.Kind() != reflect.Ptr {
		return fmt.Errorf("models 必须是指针类型")
	}

	slice := ptr.Elem()
	if slice.Kind() != reflect.Slice {
		return fmt.Errorf("models 必须是指向 Slice 的指针")
	}
	if slice.Len() == 0 {
		// Nothing to attach associations to.
		return nil
	}

	info, err := rl.parseRelation(slice.Index(0).Interface(), relation)
	if err != nil {
		return fmt.Errorf("解析关联失败:%w", err)
	}

	// Select the loader first, then invoke it once.
	var load func(reflect.Value, *RelationInfo) error
	switch info.Type {
	case HasOne:
		load = rl.loadHasOne
	case HasMany:
		load = rl.loadHasMany
	case BelongsTo:
		load = rl.loadBelongsTo
	case ManyToMany:
		load = rl.loadManyToMany
	default:
		return fmt.Errorf("不支持的关联类型:%v", info.Type)
	}
	return load(slice, info)
}
// parseRelation resolves the association metadata for the field named
// relation on model.
//
// NOTE(review): this is currently a stub. model is not inspected at all. The
// result always has Type HasOne and Field set to relation, and every other
// field (Model, FK, PK, join-table settings) is left empty. The error return
// is always nil today. The intended implementation reads GORM-style struct
// tags, e.g.
//
//	type Order struct {
//	    User  User   `gorm:"ForeignKey:user_id;References:id"`
//	    Items []Item `gorm:"ForeignKey:order_id;References:id"`
//	}
func (rl *RelationLoader) parseRelation(model interface{}, relation string) (*RelationInfo, error) {
	info := new(RelationInfo)
	info.Field = relation
	info.Type = HasOne // until tags are parsed, assume one-to-one
	return info, nil
}
// loadHasOne reads the "ID" field of every owner in models and builds a
// query on relation.Model filtered by "<relation.FK> IN (ids)".
//
// Owners whose ID cannot be read (getFieldValue returns nil) are skipped. When
// no ID is collected, it returns nil without building a query.
//
// NOTE(review): the query is only built, never executed, and nothing is mapped
// back onto models, so no associated data is loaded and the function always
// returns nil. Whether Where mutates query in place depends on Database;
// confirm.
func (rl *RelationLoader) loadHasOne(models reflect.Value, relation *RelationInfo) error {
	count := models.Len()
	ids := make([]interface{}, 0, count)
	for idx := 0; idx < count; idx++ {
		if id := rl.getFieldValue(models.Index(idx).Interface(), "ID"); id != nil {
			ids = append(ids, id)
		}
	}
	if len(ids) == 0 {
		return nil
	}

	cond := fmt.Sprintf("%s IN ?", relation.FK)
	query := rl.db.Model(relation.Model)
	query.Where(cond, ids)
	// TODO: execute the query and map the results onto models.
	return nil
}
// loadHasMany is meant to attach a slice of associated records to each owner.
// For now it delegates entirely to loadHasOne. The key collection and query
// are identical, and only the (still unimplemented) result mapping would
// differ.
func (rl *RelationLoader) loadHasMany(models reflect.Value, relation *RelationInfo) error {
	if err := rl.loadHasOne(models, relation); err != nil {
		return err
	}
	return nil
}
// loadBelongsTo handles many-to-one associations. It reads the field named
// relation.FK from every owner in models and builds a query on relation.Model
// filtered by "<referenced column> IN (fks)".
//
// The referenced column is relation.PK when it is set, and "id" otherwise.
// Owners whose foreign key cannot be read (getFieldValue returns nil) are
// skipped. When no key is collected, it returns nil without building a query.
//
// NOTE(review): the query is only built, never executed, so nothing is loaded
// yet and the function always returns nil. relation.FK is read here as a Go
// field name on the owner, whereas loadHasOne uses it as a SQL column name.
// Confirm which convention parseRelation should produce.
func (rl *RelationLoader) loadBelongsTo(models reflect.Value, relation *RelationInfo) error {
	fkValues := make([]interface{}, 0, models.Len())
	for i := 0; i < models.Len(); i++ {
		if fk := rl.getFieldValue(models.Index(i).Interface(), relation.FK); fk != nil {
			fkValues = append(fkValues, fk)
		}
	}
	if len(fkValues) == 0 {
		return nil
	}

	// Honor an explicitly configured referenced key instead of always
	// assuming "id".
	refColumn := relation.PK
	if refColumn == "" {
		refColumn = "id"
	}

	query := rl.db.Model(relation.Model)
	query.Where(refColumn+" IN ?", fkValues)
	// TODO: execute the query and map the results onto models.
	return nil
}
// loadManyToMany is not implemented yet and always returns an error.
//
// A future implementation would go through the join table, roughly:
//
//	SELECT * FROM table WHERE id IN (
//	    SELECT join_fk FROM join_table WHERE fk IN (pk_values)
//	)
func (rl *RelationLoader) loadManyToMany(models reflect.Value, relation *RelationInfo) error {
	notImplemented := fmt.Errorf("多对多关联暂未实现")
	return notImplemented
}
// getFieldValue returns the value of the exported field fieldName on model,
// or nil when that value is unavailable.
//
// model may be a struct or a (possibly multi-level) pointer to one. nil is
// returned when:
//   - model is nil, a nil pointer, or not a struct;
//   - the struct has no field fieldName, or the field is unexported;
//   - the field is promoted through an embedded pointer that is nil
//     (reflect.Value.FieldByName would panic in that case);
//   - the field itself holds a nil pointer, interface, map, slice, func or
//     channel.
//
// The last rule matters because callers filter results with `!= nil`. A typed
// nil (e.g. a nil *uint foreign key) boxed in interface{} is not equal to nil
// and would otherwise be collected as a key.
func (rl *RelationLoader) getFieldValue(model interface{}, fieldName string) interface{} {
	val := reflect.ValueOf(model)
	for val.Kind() == reflect.Ptr {
		if val.IsNil() {
			return nil
		}
		val = val.Elem()
	}
	// FieldByName panics on anything that is not a struct.
	if val.Kind() != reflect.Struct {
		return nil
	}

	sf, ok := val.Type().FieldByName(fieldName)
	if !ok {
		return nil
	}
	// Walk the index path by hand so that a nil embedded pointer yields nil
	// instead of the panic FieldByIndex would raise.
	field := val
	for _, idx := range sf.Index {
		if field.Kind() == reflect.Ptr {
			if field.IsNil() {
				return nil
			}
			field = field.Elem()
		}
		field = field.Field(idx)
	}
	if !field.CanInterface() {
		return nil
	}

	switch field.Kind() {
	case reflect.Ptr, reflect.Interface, reflect.Map, reflect.Slice, reflect.Func, reflect.Chan:
		if field.IsNil() {
			return nil
		}
	}
	return field.Interface()
}
// getRelationTags returns the key:value pairs from the `gorm` struct tag of
// the field named fieldName on structType.
//
// structType may be a struct type or a (possibly multi-level) pointer to one.
// Promoted fields of embedded structs are found as well. The tag is split on
// ';', and each part is split on its first ':' only, so values that themselves
// contain ':' (e.g. "default:12:00") are kept intact. Keys and values are
// whitespace-trimmed. Parts without a ':' (bare flags such as primaryKey) are
// ignored. The result is never nil. It is empty when structType is nil or not
// a struct, the field does not exist, or the field has no gorm tag.
func getRelationTags(structType reflect.Type, fieldName string) map[string]string {
	tags := make(map[string]string)

	for structType != nil && structType.Kind() == reflect.Ptr {
		structType = structType.Elem()
	}
	if structType == nil || structType.Kind() != reflect.Struct {
		return tags
	}

	field, ok := structType.FieldByName(fieldName)
	if !ok {
		return tags
	}

	// An empty tag splits into a single "" part, which has no ':' and is
	// therefore skipped; no special case is needed.
	for _, part := range strings.Split(field.Tag.Get("gorm"), ";") {
		kv := strings.SplitN(part, ":", 2)
		if len(kv) == 2 {
			tags[strings.TrimSpace(kv[0])] = strings.TrimSpace(kv[1])
		}
	}
	return tags
}