gin-base/db/core/dao.go

315 lines
7.6 KiB
Go
Raw Blame History

This file contains ambiguous Unicode characters!

This file contains ambiguous Unicode characters that may be confused with others in your current locale. If your use case is intentional and legitimate, you can safely ignore this warning. Use the Escape button to highlight these characters.

package core
import (
"context"
"reflect"
)
// DAO 数据访问对象基类 - 所有 DAO 都继承此结构
// 提供通用的 CRUD 操作方法,子类只需嵌入即可使用
type DAO struct {
db *Database // 数据库连接实例
modelType interface{} // 模型类型信息,用于 Columns 等方法
}
// NewDAO 创建 DAO 基类实例
// 自动使用全局默认 Database 实例
func NewDAO() *DAO {
return &DAO{
db: GetDefaultDatabase(),
}
}
// NewDAOWithModel 创建带模型类型的 DAO 基类实例
// 参数:
// - model: 模型实例(指针类型),用于获取表结构信息
//
// 自动使用全局默认 Database 实例
func NewDAOWithModel(model interface{}) *DAO {
return &DAO{
db: GetDefaultDatabase(),
modelType: model,
}
}
// Create 创建记录(通用方法)
// 自动使用 DAO 中已关联的 Database 实例
func (dao *DAO) Create(ctx context.Context, model interface{}) error {
// 使用事务来插入数据
tx, err := dao.db.Begin()
if err != nil {
return err
}
_, err = tx.Insert(model)
if err != nil {
tx.Rollback()
return err
}
return tx.Commit()
}
// GetByID 根据 ID 查询单条记录(通用方法)
// 自动使用 DAO 中已关联的 Database 实例
func (dao *DAO) GetByID(ctx context.Context, model interface{}, id int64) error {
return dao.db.Model(model).Where("id = ?", id).First(model)
}
// Update 更新记录(通用方法)
// 自动使用 DAO 中已关联的 Database 实例
func (dao *DAO) Update(ctx context.Context, model interface{}, data map[string]interface{}) error {
pkValue := getFieldValue(model, "ID")
if pkValue == 0 {
return nil
}
return dao.db.Model(model).Where("id = ?", pkValue).Updates(data)
}
// Delete 删除记录(通用方法)
// 自动使用 DAO 中已关联的 Database 实例
func (dao *DAO) Delete(ctx context.Context, model interface{}) error {
pkValue := getFieldValue(model, "ID")
if pkValue == 0 {
return nil
}
return dao.db.Model(model).Where("id = ?", pkValue).Delete()
}
// FindAll 查询所有记录(通用方法)
// 自动使用 DAO 中已关联的 Database 实例
func (dao *DAO) FindAll(ctx context.Context, model interface{}) error {
return dao.db.Model(model).Find(model)
}
// FindByPage 分页查询(通用方法)
// 自动使用 DAO 中已关联的 Database 实例
func (dao *DAO) FindByPage(ctx context.Context, model interface{}, page, pageSize int) error {
return dao.db.Model(model).Limit(pageSize).Offset((page - 1) * pageSize).Find(model)
}
// Count 统计记录数(通用方法)
// 自动使用 DAO 中已关联的 Database 实例
func (dao *DAO) Count(ctx context.Context, model interface{}, where ...string) (int64, error) {
var count int64
query := dao.db.Model(model)
if len(where) > 0 {
query = query.Where(where[0])
}
// Count 是链式调用,需要调用 Find 来执行
err := query.Count(&count).Find(model)
if err != nil {
return 0, err
}
return count, nil
}
// Exists 检查记录是否存在(通用方法)
// 自动使用 DAO 中已关联的 Database 实例
func (dao *DAO) Exists(ctx context.Context, model interface{}) (bool, error) {
count, err := dao.Count(ctx, model)
if err != nil {
return false, err
}
return count > 0, nil
}
// First 查询第一条记录(通用方法)
// 自动使用 DAO 中已关联的 Database 实例
func (dao *DAO) First(ctx context.Context, model interface{}) error {
return dao.db.Model(model).First(model)
}
// Columns 获取表的所有列名
// 返回一个动态创建的结构体类型,所有字段都是 string 类型
// 用途:用于构建 UPDATE、INSERT 等操作时的列名映射
//
// 示例:
//
// type UserDAO struct {
// *core.DAO
// }
//
// func NewUserDAO() *UserDAO {
// return &UserDAO{
// DAO: core.NewDAOWithModel(&model.User{}),
// }
// }
//
// // 使用
// dao := NewUserDAO()
// cols := dao.Columns() // 返回 *struct{ID string; Username string; ...}
func (dao *DAO) Columns() interface{} {
// 检查是否有模型类型信息
if dao.modelType == nil {
return nil
}
// 获取模型类型
modelType := reflect.TypeOf(dao.modelType)
if modelType.Kind() == reflect.Ptr {
modelType = modelType.Elem()
}
// 创建字段列表
fields := []reflect.StructField{}
// 遍历模型的所有字段
for i := 0; i < modelType.NumField(); i++ {
field := modelType.Field(i)
// 跳过未导出的字段
if !field.IsExported() {
continue
}
// 获取 db 标签,如果没有则跳过
dbTag := field.Tag.Get("db")
if dbTag == "" || dbTag == "-" {
continue
}
// 创建新的结构体字段,类型为 string
newField := reflect.StructField{
Name: field.Name,
Type: reflect.TypeOf(""), // string 类型
Tag: reflect.StructTag(`json:"` + field.Tag.Get("json") + `" db:"` + dbTag + `"`),
}
fields = append(fields, newField)
}
// 动态创建结构体类型
columnsType := reflect.StructOf(fields)
// 创建该类型的指针并返回
return reflect.New(columnsType).Interface()
}
// getFieldValue 获取结构体字段值(辅助函数)
// 用于获取主键或其他字段的值,支持多种数据类型
// 参数:
// - model: 模型实例(可以是指针或值)
// - fieldName: 字段名(如 "ID", "UserId" 等)
//
// 返回:
// - int64: 字段值(如果是数字类型)或 0无法获取时
func getFieldValue(model interface{}, fieldName string) int64 {
// 检查空值
if model == nil {
return 0
}
// 获取反射对象
val := reflect.ValueOf(model)
// 如果是指针,解引用
if val.Kind() == reflect.Ptr {
if val.IsNil() {
return 0
}
val = val.Elem()
}
// 确保是结构体
if val.Kind() != reflect.Struct {
return 0
}
// 查找字段
field := val.FieldByName(fieldName)
if !field.IsValid() {
// 尝试查找常见的主键字段名变体
alternativeNames := []string{"Id", "id", "ID"}
for _, name := range alternativeNames {
if name != fieldName {
field = val.FieldByName(name)
if field.IsValid() {
fieldName = name
break
}
}
}
if !field.IsValid() {
return 0
}
}
// 检查字段是否可以访问
if !field.CanInterface() {
return 0
}
// 获取字段值并转换为 int64
fieldValue := field.Interface()
// 根据字段类型进行转换
switch v := fieldValue.(type) {
case int:
return int64(v)
case int8:
return int64(v)
case int16:
return int64(v)
case int32:
return int64(v)
case int64:
return v
case uint:
return int64(v)
case uint8:
return int64(v)
case uint16:
return int64(v)
case uint32:
return int64(v)
case uint64:
// 注意uint64 转 int64 可能溢出,但这里假设 ID 不会超过 int64 范围
return int64(v)
case float32:
return int64(v)
case float64:
return int64(v)
case string:
// 尝试将字符串解析为数字
// 注意:这里不导入 strconv简单处理返回 0
return 0
default:
// 其他类型(如 sql.NullInt64 等),尝试使用反射
return convertToInteger(field)
}
}
// convertToInteger 使用反射将字段值转换为 int64
func convertToInteger(field reflect.Value) int64 {
// 获取实际的值(如果是指针则解引用)
if field.Kind() == reflect.Ptr {
if field.IsNil() {
return 0
}
field = field.Elem()
}
// 根据 Kind 进行转换
switch field.Kind() {
case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64:
return field.Int()
case reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64:
return int64(field.Uint())
case reflect.Float32, reflect.Float64:
return int64(field.Float())
case reflect.String:
// 字符串类型,尝试解析(简单实现,不处理错误)
return 0
default:
return 0
}
}