200 lines
5.1 KiB
Go
200 lines
5.1 KiB
Go
package core
|
||
|
||
import (
|
||
"fmt"
|
||
"reflect"
|
||
"strings"
|
||
)
|
||
|
||
// RelationType 关联类型
|
||
type RelationType int
|
||
|
||
const (
|
||
HasOne RelationType = iota // 一对一
|
||
HasMany // 一对多
|
||
BelongsTo // 多对一
|
||
ManyToMany // 多对多
|
||
)
|
||
|
||
// RelationInfo 关联信息
|
||
type RelationInfo struct {
|
||
Type RelationType // 关联类型
|
||
Field string // 字段名
|
||
Model interface{} // 关联的模型
|
||
FK string // 外键
|
||
PK string // 主键
|
||
JoinTable string // 中间表(多对多)
|
||
JoinFK string // 中间表外键
|
||
JoinJoinFK string // 中间表关联外键
|
||
}
|
||
|
||
// RelationLoader 关联加载器 - 处理模型关联的预加载
|
||
type RelationLoader struct {
|
||
db *Database
|
||
}
|
||
|
||
// NewRelationLoader 创建关联加载器实例
|
||
func NewRelationLoader(db *Database) *RelationLoader {
|
||
return &RelationLoader{db: db}
|
||
}
|
||
|
||
// Preload 预加载关联数据
|
||
func (rl *RelationLoader) Preload(models interface{}, relation string, conditions ...interface{}) error {
|
||
// 获取反射对象
|
||
modelsVal := reflect.ValueOf(models)
|
||
if modelsVal.Kind() != reflect.Ptr {
|
||
return fmt.Errorf("models 必须是指针类型")
|
||
}
|
||
|
||
elem := modelsVal.Elem()
|
||
if elem.Kind() != reflect.Slice {
|
||
return fmt.Errorf("models 必须是指向 Slice 的指针")
|
||
}
|
||
|
||
if elem.Len() == 0 {
|
||
return nil // 空 Slice,无需加载
|
||
}
|
||
|
||
// 解析关联关系
|
||
relationInfo, err := rl.parseRelation(elem.Index(0).Interface(), relation)
|
||
if err != nil {
|
||
return fmt.Errorf("解析关联失败:%w", err)
|
||
}
|
||
|
||
// 根据关联类型加载数据
|
||
switch relationInfo.Type {
|
||
case HasOne:
|
||
return rl.loadHasOne(elem, relationInfo)
|
||
case HasMany:
|
||
return rl.loadHasMany(elem, relationInfo)
|
||
case BelongsTo:
|
||
return rl.loadBelongsTo(elem, relationInfo)
|
||
case ManyToMany:
|
||
return rl.loadManyToMany(elem, relationInfo)
|
||
default:
|
||
return fmt.Errorf("不支持的关联类型:%v", relationInfo.Type)
|
||
}
|
||
}
|
||
|
||
// parseRelation 解析关联关系
|
||
func (rl *RelationLoader) parseRelation(model interface{}, relation string) (*RelationInfo, error) {
|
||
// TODO: 从结构体标签中解析关联信息
|
||
// 示例:
|
||
// type Order struct {
|
||
// User User `gorm:"ForeignKey:user_id;References:id"`
|
||
// Items []Item `gorm:"ForeignKey:order_id;References:id"`
|
||
// }
|
||
|
||
// 这里提供简化的实现
|
||
return &RelationInfo{
|
||
Type: HasOne, // 默认假设为一对一
|
||
Field: relation,
|
||
}, nil
|
||
}
|
||
|
||
// loadHasOne 加载一对一关联
|
||
func (rl *RelationLoader) loadHasOne(models reflect.Value, relation *RelationInfo) error {
|
||
// 收集所有主键值
|
||
pkValues := make([]interface{}, 0, models.Len())
|
||
for i := 0; i < models.Len(); i++ {
|
||
model := models.Index(i).Interface()
|
||
pk := rl.getFieldValue(model, "ID")
|
||
if pk != nil {
|
||
pkValues = append(pkValues, pk)
|
||
}
|
||
}
|
||
|
||
if len(pkValues) == 0 {
|
||
return nil
|
||
}
|
||
|
||
// 查询关联数据
|
||
query := rl.db.Model(relation.Model)
|
||
query.Where(fmt.Sprintf("%s IN ?", relation.FK), pkValues)
|
||
|
||
// TODO: 执行查询并映射到模型
|
||
|
||
return nil
|
||
}
|
||
|
||
// loadHasMany 加载一对多关联
|
||
func (rl *RelationLoader) loadHasMany(models reflect.Value, relation *RelationInfo) error {
|
||
// 类似 HasOne,但结果需要映射到 Slice
|
||
return rl.loadHasOne(models, relation)
|
||
}
|
||
|
||
// loadBelongsTo 加载多对一关联
|
||
func (rl *RelationLoader) loadBelongsTo(models reflect.Value, relation *RelationInfo) error {
|
||
// 收集所有外键值
|
||
fkValues := make([]interface{}, 0, models.Len())
|
||
for i := 0; i < models.Len(); i++ {
|
||
model := models.Index(i).Interface()
|
||
fk := rl.getFieldValue(model, relation.FK)
|
||
if fk != nil {
|
||
fkValues = append(fkValues, fk)
|
||
}
|
||
}
|
||
|
||
if len(fkValues) == 0 {
|
||
return nil
|
||
}
|
||
|
||
// 查询关联数据
|
||
query := rl.db.Model(relation.Model)
|
||
query.Where(fmt.Sprintf("id IN ?"), fkValues)
|
||
|
||
// TODO: 执行查询并映射到模型
|
||
|
||
return nil
|
||
}
|
||
|
||
// loadManyToMany 加载多对多关联
|
||
func (rl *RelationLoader) loadManyToMany(models reflect.Value, relation *RelationInfo) error {
|
||
// 多对多需要通过中间表查询
|
||
// SELECT * FROM table WHERE id IN (
|
||
// SELECT join_fk FROM join_table WHERE fk IN (pk_values)
|
||
// )
|
||
|
||
return fmt.Errorf("多对多关联暂未实现")
|
||
}
|
||
|
||
// getFieldValue 获取字段的值
|
||
func (rl *RelationLoader) getFieldValue(model interface{}, fieldName string) interface{} {
|
||
val := reflect.ValueOf(model)
|
||
if val.Kind() == reflect.Ptr {
|
||
val = val.Elem()
|
||
}
|
||
|
||
field := val.FieldByName(fieldName)
|
||
if field.IsValid() && field.CanInterface() {
|
||
return field.Interface()
|
||
}
|
||
|
||
return nil
|
||
}
|
||
|
||
// getRelationTags 从结构体字段提取关联标签信息
|
||
func getRelationTags(structType reflect.Type, fieldName string) map[string]string {
|
||
tags := make(map[string]string)
|
||
|
||
for i := 0; i < structType.NumField(); i++ {
|
||
field := structType.Field(i)
|
||
if field.Name == fieldName {
|
||
gormTag := field.Tag.Get("gorm")
|
||
if gormTag != "" {
|
||
// 解析 GORM 风格的标签
|
||
parts := strings.Split(gormTag, ";")
|
||
for _, part := range parts {
|
||
kv := strings.Split(part, ":")
|
||
if len(kv) == 2 {
|
||
tags[strings.TrimSpace(kv[0])] = strings.TrimSpace(kv[1])
|
||
}
|
||
}
|
||
}
|
||
break
|
||
}
|
||
}
|
||
|
||
return tags
|
||
}
|