307 lines
6.8 KiB
Go
307 lines
6.8 KiB
Go
package core
|
|
|
|
import (
|
|
"database/sql"
|
|
"errors"
|
|
"fmt"
|
|
"reflect"
|
|
"strings"
|
|
"time"
|
|
)
|
|
|
|
// FieldMapper 字段映射器实现 - 使用反射处理 Go 结构体与数据库字段之间的映射
|
|
type FieldMapper struct{}
|
|
|
|
// NewFieldMapper 创建字段映射器实例
|
|
func NewFieldMapper() IFieldMapper {
|
|
return &FieldMapper{}
|
|
}
|
|
|
|
// StructToColumns 将结构体转换为键值对 - 用于 INSERT/UPDATE 操作
|
|
func (fm *FieldMapper) StructToColumns(model interface{}) (map[string]interface{}, error) {
|
|
result := make(map[string]interface{})
|
|
|
|
// 获取反射对象
|
|
val := reflect.ValueOf(model)
|
|
if val.Kind() == reflect.Ptr {
|
|
val = val.Elem()
|
|
}
|
|
|
|
if val.Kind() != reflect.Struct {
|
|
return nil, errors.New("模型必须是结构体")
|
|
}
|
|
|
|
typ := val.Type()
|
|
|
|
// 遍历所有字段
|
|
for i := 0; i < val.NumField(); i++ {
|
|
field := typ.Field(i)
|
|
value := val.Field(i)
|
|
|
|
// 跳过未导出的字段
|
|
if !field.IsExported() {
|
|
continue
|
|
}
|
|
|
|
// 获取 db 标签
|
|
dbTag := field.Tag.Get("db")
|
|
if dbTag == "" || dbTag == "-" {
|
|
continue // 跳过没有 db 标签或标签为 - 的字段
|
|
}
|
|
|
|
// 跳过零值(可选优化)
|
|
if fm.isZeroValue(value) {
|
|
continue
|
|
}
|
|
|
|
// 添加到结果 map
|
|
result[dbTag] = value.Interface()
|
|
}
|
|
|
|
return result, nil
|
|
}
|
|
|
|
// ColumnsToStruct 将查询结果映射到结构体 - 用于 SELECT 操作
|
|
func (fm *FieldMapper) ColumnsToStruct(rows *sql.Rows, model interface{}) error {
|
|
// 获取列信息
|
|
columns, err := rows.Columns()
|
|
if err != nil {
|
|
return fmt.Errorf("获取列信息失败:%w", err)
|
|
}
|
|
|
|
// 获取反射对象
|
|
val := reflect.ValueOf(model)
|
|
if val.Kind() != reflect.Ptr {
|
|
return errors.New("模型必须是指针类型")
|
|
}
|
|
|
|
elem := val.Elem()
|
|
if elem.Kind() != reflect.Struct {
|
|
return errors.New("模型必须是指向结构体的指针")
|
|
}
|
|
|
|
// 创建扫描目标
|
|
scanTargets := make([]interface{}, len(columns))
|
|
fieldMap := make(map[int]int) // column index -> field index
|
|
|
|
// 建立列名到结构体字段的映射
|
|
for i, col := range columns {
|
|
found := false
|
|
for j := 0; j < elem.NumField(); j++ {
|
|
field := elem.Type().Field(j)
|
|
dbTag := field.Tag.Get("db")
|
|
|
|
// 匹配列名和字段
|
|
if dbTag == col || strings.ToLower(dbTag) == strings.ToLower(col) ||
|
|
strings.ToLower(field.Name) == strings.ToLower(col) {
|
|
fieldMap[i] = j
|
|
found = true
|
|
break
|
|
}
|
|
}
|
|
|
|
// 如果没找到匹配字段,使用 interface{} 占位
|
|
if !found {
|
|
var dummy interface{}
|
|
scanTargets[i] = &dummy
|
|
}
|
|
}
|
|
|
|
// 为找到的字段创建扫描目标
|
|
for i := range columns {
|
|
if fieldIdx, ok := fieldMap[i]; ok {
|
|
field := elem.Field(fieldIdx)
|
|
if field.CanSet() {
|
|
scanTargets[i] = field.Addr().Interface()
|
|
} else {
|
|
var dummy interface{}
|
|
scanTargets[i] = &dummy
|
|
}
|
|
}
|
|
}
|
|
|
|
// 执行扫描
|
|
if err := rows.Scan(scanTargets...); err != nil {
|
|
return fmt.Errorf("扫描数据失败:%w", err)
|
|
}
|
|
|
|
return nil
|
|
}
|
|
|
|
// GetTableName 获取模型对应的表名
|
|
func (fm *FieldMapper) GetTableName(model interface{}) string {
|
|
// 检查是否实现了 TableName() 方法
|
|
type tabler interface {
|
|
TableName() string
|
|
}
|
|
|
|
if t, ok := model.(tabler); ok {
|
|
return t.TableName()
|
|
}
|
|
|
|
// 否则使用结构体名称
|
|
val := reflect.ValueOf(model)
|
|
if val.Kind() == reflect.Ptr {
|
|
val = val.Elem()
|
|
}
|
|
|
|
typ := val.Type()
|
|
return fm.toSnakeCase(typ.Name())
|
|
}
|
|
|
|
// GetPrimaryKey 获取主键字段名 - 默认为 "id"
|
|
func (fm *FieldMapper) GetPrimaryKey(model interface{}) string {
|
|
// 查找标记为主键的字段
|
|
val := reflect.ValueOf(model)
|
|
if val.Kind() == reflect.Ptr {
|
|
val = val.Elem()
|
|
}
|
|
|
|
typ := val.Type()
|
|
for i := 0; i < val.NumField(); i++ {
|
|
field := typ.Field(i)
|
|
|
|
// 检查是否是 ID 字段
|
|
fieldName := field.Name
|
|
if fieldName == "ID" || fieldName == "Id" || fieldName == "id" {
|
|
dbTag := field.Tag.Get("db")
|
|
if dbTag != "" && dbTag != "-" {
|
|
return dbTag
|
|
}
|
|
return "id"
|
|
}
|
|
|
|
// 检查是否有 primary 标签
|
|
if field.Tag.Get("primary") == "true" {
|
|
dbTag := field.Tag.Get("db")
|
|
if dbTag != "" {
|
|
return dbTag
|
|
}
|
|
}
|
|
}
|
|
|
|
return "id" // 默认返回 id
|
|
}
|
|
|
|
// GetFields 获取所有字段信息 - 用于生成 SQL 语句
|
|
func (fm *FieldMapper) GetFields(model interface{}) []FieldInfo {
|
|
var fields []FieldInfo
|
|
|
|
val := reflect.ValueOf(model)
|
|
if val.Kind() == reflect.Ptr {
|
|
val = val.Elem()
|
|
}
|
|
|
|
typ := val.Type()
|
|
|
|
// 遍历所有字段
|
|
for i := 0; i < val.NumField(); i++ {
|
|
field := typ.Field(i)
|
|
|
|
// 跳过未导出的字段
|
|
if !field.IsExported() {
|
|
continue
|
|
}
|
|
|
|
// 获取 db 标签
|
|
dbTag := field.Tag.Get("db")
|
|
if dbTag == "" || dbTag == "-" {
|
|
continue
|
|
}
|
|
|
|
// 创建字段信息
|
|
info := FieldInfo{
|
|
Name: field.Name,
|
|
Column: dbTag,
|
|
Type: fm.getTypeName(field.Type),
|
|
DbType: fm.mapToDbType(field.Type),
|
|
Tag: dbTag,
|
|
}
|
|
|
|
// 检查是否是主键
|
|
if field.Tag.Get("primary") == "true" ||
|
|
field.Name == "ID" || field.Name == "Id" {
|
|
info.IsPrimary = true
|
|
}
|
|
|
|
// 检查是否是自增
|
|
if field.Tag.Get("auto") == "true" {
|
|
info.IsAuto = true
|
|
}
|
|
|
|
fields = append(fields, info)
|
|
}
|
|
|
|
return fields
|
|
}
|
|
|
|
// isZeroValue 检查是否是零值
|
|
func (fm *FieldMapper) isZeroValue(v reflect.Value) bool {
|
|
switch v.Kind() {
|
|
case reflect.Array, reflect.Map, reflect.Slice, reflect.String:
|
|
return v.Len() == 0
|
|
case reflect.Bool:
|
|
return !v.Bool()
|
|
case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64:
|
|
return v.Int() == 0
|
|
case reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64, reflect.Uintptr:
|
|
return v.Uint() == 0
|
|
case reflect.Float32, reflect.Float64:
|
|
return v.Float() == 0
|
|
case reflect.Interface, reflect.Ptr:
|
|
return v.IsNil()
|
|
case reflect.Struct:
|
|
// 特殊处理 time.Time
|
|
if t, ok := v.Interface().(time.Time); ok {
|
|
return t.IsZero()
|
|
}
|
|
return false
|
|
}
|
|
return false
|
|
}
|
|
|
|
// getTypeName 获取类型的名称
|
|
func (fm *FieldMapper) getTypeName(t reflect.Type) string {
|
|
return t.String()
|
|
}
|
|
|
|
// mapToDbType 将 Go 类型映射到数据库类型
|
|
func (fm *FieldMapper) mapToDbType(t reflect.Type) string {
|
|
switch t.Kind() {
|
|
case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64:
|
|
return "BIGINT"
|
|
case reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64:
|
|
return "BIGINT UNSIGNED"
|
|
case reflect.Float32, reflect.Float64:
|
|
return "DECIMAL"
|
|
case reflect.Bool:
|
|
return "TINYINT"
|
|
case reflect.String:
|
|
return "VARCHAR(255)"
|
|
default:
|
|
// 特殊类型
|
|
if t.PkgPath() == "time" && t.Name() == "Time" {
|
|
return "DATETIME"
|
|
}
|
|
return "TEXT"
|
|
}
|
|
}
|
|
|
|
// toSnakeCase 将驼峰命名转换为下划线命名
|
|
func (fm *FieldMapper) toSnakeCase(str string) string {
|
|
var result strings.Builder
|
|
|
|
for i, r := range str {
|
|
if r >= 'A' && r <= 'Z' {
|
|
if i > 0 {
|
|
result.WriteRune('_')
|
|
}
|
|
result.WriteRune(r + 32) // 转换为小写
|
|
} else {
|
|
result.WriteRune(r)
|
|
}
|
|
}
|
|
|
|
return result.String()
|
|
}
|