gin-base/db/core/mapper.go

307 lines
6.8 KiB
Go

package core
import (
"database/sql"
"errors"
"fmt"
"reflect"
"strings"
"time"
)
// FieldMapper 字段映射器实现 - 使用反射处理 Go 结构体与数据库字段之间的映射
type FieldMapper struct{}
// NewFieldMapper 创建字段映射器实例
func NewFieldMapper() IFieldMapper {
return &FieldMapper{}
}
// StructToColumns 将结构体转换为键值对 - 用于 INSERT/UPDATE 操作
func (fm *FieldMapper) StructToColumns(model interface{}) (map[string]interface{}, error) {
result := make(map[string]interface{})
// 获取反射对象
val := reflect.ValueOf(model)
if val.Kind() == reflect.Ptr {
val = val.Elem()
}
if val.Kind() != reflect.Struct {
return nil, errors.New("模型必须是结构体")
}
typ := val.Type()
// 遍历所有字段
for i := 0; i < val.NumField(); i++ {
field := typ.Field(i)
value := val.Field(i)
// 跳过未导出的字段
if !field.IsExported() {
continue
}
// 获取 db 标签
dbTag := field.Tag.Get("db")
if dbTag == "" || dbTag == "-" {
continue // 跳过没有 db 标签或标签为 - 的字段
}
// 跳过零值(可选优化)
if fm.isZeroValue(value) {
continue
}
// 添加到结果 map
result[dbTag] = value.Interface()
}
return result, nil
}
// ColumnsToStruct 将查询结果映射到结构体 - 用于 SELECT 操作
func (fm *FieldMapper) ColumnsToStruct(rows *sql.Rows, model interface{}) error {
// 获取列信息
columns, err := rows.Columns()
if err != nil {
return fmt.Errorf("获取列信息失败:%w", err)
}
// 获取反射对象
val := reflect.ValueOf(model)
if val.Kind() != reflect.Ptr {
return errors.New("模型必须是指针类型")
}
elem := val.Elem()
if elem.Kind() != reflect.Struct {
return errors.New("模型必须是指向结构体的指针")
}
// 创建扫描目标
scanTargets := make([]interface{}, len(columns))
fieldMap := make(map[int]int) // column index -> field index
// 建立列名到结构体字段的映射
for i, col := range columns {
found := false
for j := 0; j < elem.NumField(); j++ {
field := elem.Type().Field(j)
dbTag := field.Tag.Get("db")
// 匹配列名和字段
if dbTag == col || strings.ToLower(dbTag) == strings.ToLower(col) ||
strings.ToLower(field.Name) == strings.ToLower(col) {
fieldMap[i] = j
found = true
break
}
}
// 如果没找到匹配字段,使用 interface{} 占位
if !found {
var dummy interface{}
scanTargets[i] = &dummy
}
}
// 为找到的字段创建扫描目标
for i := range columns {
if fieldIdx, ok := fieldMap[i]; ok {
field := elem.Field(fieldIdx)
if field.CanSet() {
scanTargets[i] = field.Addr().Interface()
} else {
var dummy interface{}
scanTargets[i] = &dummy
}
}
}
// 执行扫描
if err := rows.Scan(scanTargets...); err != nil {
return fmt.Errorf("扫描数据失败:%w", err)
}
return nil
}
// GetTableName 获取模型对应的表名
func (fm *FieldMapper) GetTableName(model interface{}) string {
// 检查是否实现了 TableName() 方法
type tabler interface {
TableName() string
}
if t, ok := model.(tabler); ok {
return t.TableName()
}
// 否则使用结构体名称
val := reflect.ValueOf(model)
if val.Kind() == reflect.Ptr {
val = val.Elem()
}
typ := val.Type()
return fm.toSnakeCase(typ.Name())
}
// GetPrimaryKey 获取主键字段名 - 默认为 "id"
func (fm *FieldMapper) GetPrimaryKey(model interface{}) string {
// 查找标记为主键的字段
val := reflect.ValueOf(model)
if val.Kind() == reflect.Ptr {
val = val.Elem()
}
typ := val.Type()
for i := 0; i < val.NumField(); i++ {
field := typ.Field(i)
// 检查是否是 ID 字段
fieldName := field.Name
if fieldName == "ID" || fieldName == "Id" || fieldName == "id" {
dbTag := field.Tag.Get("db")
if dbTag != "" && dbTag != "-" {
return dbTag
}
return "id"
}
// 检查是否有 primary 标签
if field.Tag.Get("primary") == "true" {
dbTag := field.Tag.Get("db")
if dbTag != "" {
return dbTag
}
}
}
return "id" // 默认返回 id
}
// GetFields 获取所有字段信息 - 用于生成 SQL 语句
func (fm *FieldMapper) GetFields(model interface{}) []FieldInfo {
var fields []FieldInfo
val := reflect.ValueOf(model)
if val.Kind() == reflect.Ptr {
val = val.Elem()
}
typ := val.Type()
// 遍历所有字段
for i := 0; i < val.NumField(); i++ {
field := typ.Field(i)
// 跳过未导出的字段
if !field.IsExported() {
continue
}
// 获取 db 标签
dbTag := field.Tag.Get("db")
if dbTag == "" || dbTag == "-" {
continue
}
// 创建字段信息
info := FieldInfo{
Name: field.Name,
Column: dbTag,
Type: fm.getTypeName(field.Type),
DbType: fm.mapToDbType(field.Type),
Tag: dbTag,
}
// 检查是否是主键
if field.Tag.Get("primary") == "true" ||
field.Name == "ID" || field.Name == "Id" {
info.IsPrimary = true
}
// 检查是否是自增
if field.Tag.Get("auto") == "true" {
info.IsAuto = true
}
fields = append(fields, info)
}
return fields
}
// isZeroValue 检查是否是零值
func (fm *FieldMapper) isZeroValue(v reflect.Value) bool {
switch v.Kind() {
case reflect.Array, reflect.Map, reflect.Slice, reflect.String:
return v.Len() == 0
case reflect.Bool:
return !v.Bool()
case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64:
return v.Int() == 0
case reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64, reflect.Uintptr:
return v.Uint() == 0
case reflect.Float32, reflect.Float64:
return v.Float() == 0
case reflect.Interface, reflect.Ptr:
return v.IsNil()
case reflect.Struct:
// 特殊处理 time.Time
if t, ok := v.Interface().(time.Time); ok {
return t.IsZero()
}
return false
}
return false
}
// getTypeName 获取类型的名称
func (fm *FieldMapper) getTypeName(t reflect.Type) string {
return t.String()
}
// mapToDbType 将 Go 类型映射到数据库类型
func (fm *FieldMapper) mapToDbType(t reflect.Type) string {
switch t.Kind() {
case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64:
return "BIGINT"
case reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64:
return "BIGINT UNSIGNED"
case reflect.Float32, reflect.Float64:
return "DECIMAL"
case reflect.Bool:
return "TINYINT"
case reflect.String:
return "VARCHAR(255)"
default:
// 特殊类型
if t.PkgPath() == "time" && t.Name() == "Time" {
return "DATETIME"
}
return "TEXT"
}
}
// toSnakeCase 将驼峰命名转换为下划线命名
func (fm *FieldMapper) toSnakeCase(str string) string {
var result strings.Builder
for i, r := range str {
if r >= 'A' && r <= 'Z' {
if i > 0 {
result.WriteRune('_')
}
result.WriteRune(r + 32) // 转换为小写
} else {
result.WriteRune(r)
}
}
return result.String()
}