feat(database): 新增命令行参数解析、空值检查和数据库配置管理功能

- 实现了 command 包用于控制台参数和选项解析
- 添加了 empty 包提供变量空值/零值检查功能
- 集成了 Viper 库进行配置文件管理和环境变量支持
- 提供了数据库配置初始化和默认值设置功能
- 实现了配置值获取、设置和结构化解析方法
- 添加了服务器、数据库和JWT配置的便捷获取方法
- 实现了完整的数据库接口定义和事务管理功能
main v1.0.2003
maguodong 2026-04-08 09:33:59 +08:00
parent 6dcd564206
commit af0a8f4043
132 changed files with 14331 additions and 10719 deletions

View File

@ -54,30 +54,29 @@ func init() {
}) })
} }
func GetConfigPath() string {
return configPath
}
// SetDefault 设置默认配置信息 // SetDefault 设置默认配置信息
func SetDefault() { func SetDefault() {
viper.Set("SERVER.addr", "127.0.0.1:8080") viper.Set("SERVER.addr", "127.0.0.1:8080")
viper.Set("SERVER.mode", "release") viper.Set("SERVER.mode", "release")
viper.Set("DATABASE.default.host", "127.0.0.1")
// 数据库配置 - 支持多种数据库类型 viper.Set("DATABASE.default.port", "3306")
viper.Set("DATABASE.type", "sqlite") viper.Set("DATABASE.default.user", "root")
viper.Set("DATABASE.dns", gfile.Join(gfile.Pwd(), "db", "database.db")) viper.Set("DATABASE.default.pass", "123456")
viper.Set("DATABASE.debug", true) viper.Set("DATABASE.default.name", "test")
viper.Set("DATABASE.default.type", "mysql")
// 数据库连接池配置 viper.Set("DATABASE.default.role", "master")
viper.Set("DATABASE.maxIdleConns", 10) // 最大空闲连接数 viper.Set("DATABASE.default.debug", false)
viper.Set("DATABASE.maxOpenConns", 100) // 最大打开连接数 viper.Set("DATABASE.default.prefix", "")
viper.Set("DATABASE.connMaxLifetime", 3600) // 连接最大生命周期(秒) viper.Set("DATABASE.default.dryRun", false)
viper.Set("DATABASE.default.charset", "utf8")
// 数据库主从配置(可选) viper.Set("DATABASE.default.timezone", "Local")
viper.Set("DATABASE.replicas", []string{}) // 从库列表 viper.Set("DATABASE.default.createdAt", "create_time")
viper.Set("DATABASE.readPolicy", "random") // 读负载均衡策略 viper.Set("DATABASE.default.updatedAt", "update_time")
viper.Set("DATABASE.default.timeMaintainDisabled", false)
// 时间配置 - 定义时间字段名称和格式
viper.Set("DATABASE.timeConfig.createdAt", "created_at")
viper.Set("DATABASE.timeConfig.updatedAt", "updated_at")
viper.Set("DATABASE.timeConfig.deletedAt", "deleted_at")
viper.Set("DATABASE.timeConfig.format", "2006-01-02 15:04:05")
// JWT 配置 // JWT 配置
viper.Set("JWT.secret", "SET-YOUR-SECRET") viper.Set("JWT.secret", "SET-YOUR-SECRET")

View File

@ -1,31 +0,0 @@
package base
import (
"github.com/gogf/gf/v2/os/gtime"
"gorm.io/gorm"
)
type IdModel struct {
Id int `json:"id" gorm:"column:id;type:int(11);common:id"`
}
type TimeModel struct {
CreateTime string `json:"create_time" gorm:"column:create_time;type:varchar(255);common:创建时间"`
UpdateTime string `json:"update_time" gorm:"column:update_time;type:varchar(255);common:更新时间"`
}
func (tm *TimeModel) BeforeCreate(scope *gorm.DB) error {
scope.Set("create_time", gtime.Datetime())
scope.Set("update_time", gtime.Datetime())
return nil
}
func (tm *TimeModel) BeforeUpdate(scope *gorm.DB) error {
scope.Set("update_time", gtime.Datetime())
return nil
}
//func (tm *TimeModel) AfterFind(scope *gorm.DB) error {
// tm.CreateTime = gtime.New(tm.CreateTime).String()
// tm.UpdateTime = gtime.New(tm.UpdateTime).String()
// return nil
//}

135
database/command/command.go Normal file
View File

@ -0,0 +1,135 @@
// Copyright GoFrame Author(https://goframe.org). All Rights Reserved.
//
// This Source Code Form is subject to the terms of the MIT License.
// If a copy of the MIT was not distributed with this file,
// You can obtain one at https://github.com/gogf/gf.
//
// Package command provides console operations, like options/arguments reading.
package command
import (
"os"
"regexp"
"strings"
)
var (
defaultParsedArgs = make([]string, 0)
defaultParsedOptions = make(map[string]string)
argumentOptionRegex = regexp.MustCompile(`^\-{1,2}([\w\?\.\-]+)(=){0,1}(.*)$`)
)
// Init does custom initialization.
func Init(args ...string) {
if len(args) == 0 {
if len(defaultParsedArgs) == 0 && len(defaultParsedOptions) == 0 {
args = os.Args
} else {
return
}
} else {
defaultParsedArgs = make([]string, 0)
defaultParsedOptions = make(map[string]string)
}
// Parsing os.Args with default algorithm.
defaultParsedArgs, defaultParsedOptions = ParseUsingDefaultAlgorithm(args...)
}
// ParseUsingDefaultAlgorithm parses arguments using default algorithm.
func ParseUsingDefaultAlgorithm(args ...string) (parsedArgs []string, parsedOptions map[string]string) {
parsedArgs = make([]string, 0)
parsedOptions = make(map[string]string)
for i := 0; i < len(args); {
array := argumentOptionRegex.FindStringSubmatch(args[i])
if len(array) > 2 {
if array[2] == "=" {
parsedOptions[array[1]] = array[3]
} else if i < len(args)-1 {
if len(args[i+1]) > 0 && args[i+1][0] == '-' {
// Example: gf gen -d -n 1
parsedOptions[array[1]] = array[3]
} else {
// Example: gf gen -n 2
parsedOptions[array[1]] = args[i+1]
i += 2
continue
}
} else {
// Example: gf gen -h
parsedOptions[array[1]] = array[3]
}
} else {
parsedArgs = append(parsedArgs, args[i])
}
i++
}
return
}
// GetOpt returns the option value named `name`.
func GetOpt(name string, def ...string) string {
Init()
if v, ok := defaultParsedOptions[name]; ok {
return v
}
if len(def) > 0 {
return def[0]
}
return ""
}
// GetOptAll returns all parsed options.
func GetOptAll() map[string]string {
Init()
return defaultParsedOptions
}
// ContainsOpt checks whether option named `name` exist in the arguments.
func ContainsOpt(name string) bool {
Init()
_, ok := defaultParsedOptions[name]
return ok
}
// GetArg returns the argument at `index`.
func GetArg(index int, def ...string) string {
Init()
if index < len(defaultParsedArgs) {
return defaultParsedArgs[index]
}
if len(def) > 0 {
return def[0]
}
return ""
}
// GetArgAll returns all parsed arguments.
func GetArgAll() []string {
Init()
return defaultParsedArgs
}
// GetOptWithEnv returns the command line argument of the specified `key`.
// If the argument does not exist, then it returns the environment variable with specified `key`.
// It returns the default value `def` if none of them exists.
//
// Fetching Rules:
// 1. Command line arguments are in lowercase format, eg: gf.package.variable;
// 2. Environment arguments are in uppercase format, eg: GF_PACKAGE_VARIABLE
func GetOptWithEnv(key string, def ...string) string {
cmdKey := strings.ToLower(strings.ReplaceAll(key, "_", "."))
if ContainsOpt(cmdKey) {
return GetOpt(cmdKey)
} else {
envKey := strings.ToUpper(strings.ReplaceAll(key, ".", "_"))
if r, ok := os.LookupEnv(envKey); ok {
return r
} else {
if len(def) > 0 {
return def[0]
}
}
}
return ""
}

View File

@ -1,96 +0,0 @@
package database
import (
"database/sql"
"fmt"
"time"
"git.magicany.cc/black1552/gin-base/config"
"git.magicany.cc/black1552/gin-base/log"
"github.com/glebarez/sqlite"
"github.com/gogf/gf/v2/frame/g"
"github.com/gogf/gf/v2/os/gfile"
"gorm.io/driver/mysql"
"gorm.io/gorm"
"gorm.io/gorm/logger"
"gorm.io/gorm/schema"
)
var (
Type gorm.Dialector
Db *gorm.DB
err error
sqlDb *sql.DB
dns = config.GetConfigValue("database.dns", gfile.Join(gfile.Pwd(), "db", "database.db"))
)
func init() {
if g.IsEmpty(dns) {
log.Error("gormDns 未配置", "请检查配置文件")
return
}
switch config.GetConfigValue("database.type", "sqlite").String() {
case "mysql":
log.Info("使用 mysql 数据库")
mysqlInit()
case "sqlite":
log.Info("使用 sqlite 数据库")
sqliteInit()
}
// 构建 GORM 配置
gormConfig := &gorm.Config{
SkipDefaultTransaction: true,
NowFunc: func() time.Time {
return time.Now().Local()
},
// 命名策略:保持与模型一致,避免字段/表名转换问题
NamingStrategy: schema.NamingStrategy{
SingularTable: true, // 表名禁用复数形式(例如 User 对应 user 表,而非 users
},
}
// 根据配置决定是否开启 GORM 查询日志
if config.GetConfigValue("database.debug", false).Bool() {
log.Info("已开启 GORM 查询日志")
gormConfig.Logger = logger.Default.LogMode(logger.Info)
} else {
gormConfig.Logger = logger.Default.LogMode(logger.Silent)
}
Db, err = gorm.Open(Type, gormConfig)
if err != nil {
log.Error("数据库连接失败: ", err)
return
}
sqlDb, err = Db.DB()
if err != nil {
log.Error("获取sqlDb失败", err)
return
}
if err = sqlDb.Ping(); err != nil {
log.Error("数据库未正常连接", err)
return
}
}
func mysqlInit() {
Type = mysql.New(mysql.Config{
DSN: dns.String(),
DefaultStringSize: 255, // string 类型字段的默认长度
DisableDatetimePrecision: true, // 禁用 datetime 精度MySQL 5.6 之前的数据库不支持
DontSupportRenameIndex: true, // 重命名索引时采用删除并新建的方式MySQL 5.7 之前的数据库和 MariaDB 不支持重命名索引
SkipInitializeWithVersion: false, // 根据当前 MySQL 版本自动配置
})
}
func sqliteInit() {
if !gfile.Exists(dns.String()) {
_, err = gfile.Create(dns.String())
if err != nil {
log.Error("创建数据库文件失败: ", err)
return
}
}
Type = sqlite.Open(fmt.Sprintf("%s?cache=shared&mode=rwc&_busy_timeout=10000&_fk=1&_journal=WAL&_sync=FULL", dns.String()))
}

243
database/empty/empty.go Normal file
View File

@ -0,0 +1,243 @@
// Copyright GoFrame gf Author(https://goframe.org). All Rights Reserved.
//
// This Source Code Form is subject to the terms of the MIT License.
// If a copy of the MIT was not distributed with this file,
// You can obtain one at https://github.com/gogf/gf.
// Package empty provides functions for checking empty/nil variables.
package empty
import (
"reflect"
"time"
"git.magicany.cc/black1552/gin-base/database/reflection"
)
// iString is used for type assert api for String().
type iString interface {
String() string
}
// iInterfaces is used for type assert api for Interfaces.
type iInterfaces interface {
Interfaces() []any
}
// iMapStrAny is the interface support for converting struct parameter to map.
type iMapStrAny interface {
MapStrAny() map[string]any
}
type iTime interface {
Date() (year int, month time.Month, day int)
IsZero() bool
}
// IsEmpty checks whether given `value` empty.
// It returns true if `value` is in: 0, nil, false, "", len(slice/map/chan) == 0,
// or else it returns false.
//
// The parameter `traceSource` is used for tracing to the source variable if given `value` is type of pointer
// that also points to a pointer. It returns true if the source is empty when `traceSource` is true.
// Note that it might use reflect feature which affects performance a little.
func IsEmpty(value any, traceSource ...bool) bool {
if value == nil {
return true
}
// It firstly checks the variable as common types using assertion to enhance the performance,
// and then using reflection.
switch result := value.(type) {
case int:
return result == 0
case int8:
return result == 0
case int16:
return result == 0
case int32:
return result == 0
case int64:
return result == 0
case uint:
return result == 0
case uint8:
return result == 0
case uint16:
return result == 0
case uint32:
return result == 0
case uint64:
return result == 0
case float32:
return result == 0
case float64:
return result == 0
case bool:
return !result
case string:
return result == ""
case []byte:
return len(result) == 0
case []rune:
return len(result) == 0
case []int:
return len(result) == 0
case []string:
return len(result) == 0
case []float32:
return len(result) == 0
case []float64:
return len(result) == 0
case map[string]any:
return len(result) == 0
default:
// Finally, using reflect.
var rv reflect.Value
if v, ok := value.(reflect.Value); ok {
rv = v
} else {
rv = reflect.ValueOf(value)
if IsNil(rv) {
return true
}
// =========================
// Common interfaces checks.
// =========================
if f, ok := value.(iTime); ok {
if f == (*time.Time)(nil) {
return true
}
return f.IsZero()
}
if f, ok := value.(iString); ok {
if f == nil {
return true
}
return f.String() == ""
}
if f, ok := value.(iInterfaces); ok {
if f == nil {
return true
}
return len(f.Interfaces()) == 0
}
if f, ok := value.(iMapStrAny); ok {
if f == nil {
return true
}
return len(f.MapStrAny()) == 0
}
}
switch rv.Kind() {
case reflect.Bool:
return !rv.Bool()
case
reflect.Int,
reflect.Int8,
reflect.Int16,
reflect.Int32,
reflect.Int64:
return rv.Int() == 0
case
reflect.Uint,
reflect.Uint8,
reflect.Uint16,
reflect.Uint32,
reflect.Uint64,
reflect.Uintptr:
return rv.Uint() == 0
case
reflect.Float32,
reflect.Float64:
return rv.Float() == 0
case reflect.String:
return rv.Len() == 0
case reflect.Struct:
var fieldValueInterface any
for i := 0; i < rv.NumField(); i++ {
fieldValueInterface, _ = reflection.ValueToInterface(rv.Field(i))
if !IsEmpty(fieldValueInterface) {
return false
}
}
return true
case
reflect.Chan,
reflect.Map,
reflect.Slice,
reflect.Array:
return rv.Len() == 0
case reflect.Pointer:
if len(traceSource) > 0 && traceSource[0] {
return IsEmpty(rv.Elem())
}
return rv.IsNil()
case
reflect.Func,
reflect.Interface,
reflect.UnsafePointer:
return rv.IsNil()
case reflect.Invalid:
return true
default:
return false
}
}
}
// IsNil checks whether given `value` is nil, especially for any type value.
// Parameter `traceSource` is used for tracing to the source variable if given `value` is type of pointer
// that also points to a pointer. It returns nil if the source is nil when `traceSource` is true.
// Note that it might use reflect feature which affects performance a little.
func IsNil(value any, traceSource ...bool) bool {
if value == nil {
return true
}
var rv reflect.Value
if v, ok := value.(reflect.Value); ok {
rv = v
} else {
rv = reflect.ValueOf(value)
}
switch rv.Kind() {
case reflect.Chan,
reflect.Map,
reflect.Slice,
reflect.Func,
reflect.Interface,
reflect.UnsafePointer:
return !rv.IsValid() || rv.IsNil()
case reflect.Pointer:
if len(traceSource) > 0 && traceSource[0] {
for rv.Kind() == reflect.Pointer {
rv = rv.Elem()
}
if !rv.IsValid() {
return true
}
if rv.Kind() == reflect.Pointer {
return rv.IsNil()
}
} else {
return !rv.IsValid() || rv.IsNil()
}
default:
return false
}
return false
}

1175
database/gdb.go Normal file

File diff suppressed because it is too large Load Diff

82
database/gdb_converter.go Normal file
View File

@ -0,0 +1,82 @@
// Copyright GoFrame Author(https://goframe.org). All Rights Reserved.
//
// This Source Code Form is subject to the terms of the MIT License.
// If a copy of the MIT was not distributed with this file,
// You can obtain one at https://github.com/gogf/gf.
package database
import (
"reflect"
"git.magicany.cc/black1552/gin-base/database/json"
"github.com/gogf/gf/v2/errors/gerror"
"github.com/gogf/gf/v2/util/gconv"
)
// iVal is used for type assert api for Val().
type iVal interface {
Val() any
}
var (
// converter is the internal type converter for gdb.
converter = gconv.NewConverter()
)
func init() {
converter.RegisterAnyConverterFunc(
sliceTypeConverterFunc,
reflect.TypeOf([]string{}),
reflect.TypeOf([]float32{}),
reflect.TypeOf([]float64{}),
reflect.TypeOf([]int{}),
reflect.TypeOf([]int32{}),
reflect.TypeOf([]int64{}),
reflect.TypeOf([]uint{}),
reflect.TypeOf([]uint32{}),
reflect.TypeOf([]uint64{}),
)
}
// GetConverter returns the internal type converter for gdb.
func GetConverter() gconv.Converter {
return converter
}
func sliceTypeConverterFunc(from any, to reflect.Value) (err error) {
v, ok := from.(iVal)
if !ok {
return nil
}
fromVal := v.Val()
switch x := fromVal.(type) {
case []byte:
dst := to.Addr().Interface()
err = json.Unmarshal(x, dst)
case string:
dst := to.Addr().Interface()
err = json.Unmarshal([]byte(x), dst)
default:
fromType := reflect.TypeOf(fromVal)
switch fromType.Kind() {
case reflect.Slice:
convertOption := gconv.ConvertOption{
SliceOption: gconv.SliceOption{ContinueOnError: true},
MapOption: gconv.MapOption{ContinueOnError: true},
StructOption: gconv.StructOption{ContinueOnError: true},
}
dv, err := converter.ConvertWithTypeName(fromVal, to.Type().String(), convertOption)
if err != nil {
return err
}
to.Set(reflect.ValueOf(dv))
default:
err = gerror.Newf(
`unsupported type converting from type "%T" to type "%T"`,
fromVal, to,
)
}
}
return err
}

841
database/gdb_core.go Normal file
View File

@ -0,0 +1,841 @@
// Copyright GoFrame Author(https://goframe.org). All Rights Reserved.
//
// This Source Code Form is subject to the terms of the MIT License.
// If a copy of the MIT was not distributed with this file,
// You can obtain one at https://github.com/gogf/gf.
//
package database
import (
"context"
"database/sql"
"fmt"
"reflect"
"sort"
"strings"
"git.magicany.cc/black1552/gin-base/database/intlog"
"git.magicany.cc/black1552/gin-base/database/reflection"
"git.magicany.cc/black1552/gin-base/database/utils"
"github.com/gogf/gf/v2/container/gmap"
"github.com/gogf/gf/v2/container/gset"
"github.com/gogf/gf/v2/container/gvar"
"github.com/gogf/gf/v2/errors/gcode"
"github.com/gogf/gf/v2/errors/gerror"
"github.com/gogf/gf/v2/os/gcache"
"github.com/gogf/gf/v2/text/gregex"
"github.com/gogf/gf/v2/text/gstr"
"github.com/gogf/gf/v2/util/gconv"
"github.com/gogf/gf/v2/util/gutil"
)
// GetCore returns the underlying *Core object.
func (c *Core) GetCore() *Core {
return c
}
// Ctx is a chaining function, which creates and returns a new DB that is a shallow copy
// of current DB object and with given context in it.
// Note that this returned DB object can be used only once, so do not assign it to
// a global or package variable for long using.
func (c *Core) Ctx(ctx context.Context) DB {
if ctx == nil {
return c.db
}
// It makes a shallow copy of current db and changes its context for next chaining operation.
var (
err error
newCore = &Core{}
configNode = c.db.GetConfig()
)
*newCore = *c
// It creates a new DB object(NOT NEW CONNECTION), which is commonly a wrapper for object `Core`.
newCore.db, err = driverMap[configNode.Type].New(newCore, configNode)
if err != nil {
// It is really a serious error here.
// Do not let it continue.
panic(err)
}
newCore.ctx = WithDB(ctx, newCore.db)
newCore.ctx = c.injectInternalCtxData(newCore.ctx)
return newCore.db
}
// GetCtx returns the context for current DB.
// It returns `context.Background()` is there's no context previously set.
func (c *Core) GetCtx() context.Context {
ctx := c.ctx
if ctx == nil {
ctx = context.TODO()
}
return c.injectInternalCtxData(ctx)
}
// GetCtxTimeout returns the context and cancel function for specified timeout type.
func (c *Core) GetCtxTimeout(ctx context.Context, timeoutType ctxTimeoutType) (context.Context, context.CancelFunc) {
if ctx == nil {
ctx = c.db.GetCtx()
} else {
ctx = context.WithValue(ctx, ctxKeyWrappedByGetCtxTimeout, nil)
}
var config = c.db.GetConfig()
switch timeoutType {
case ctxTimeoutTypeExec:
if c.db.GetConfig().ExecTimeout > 0 {
return context.WithTimeout(ctx, config.ExecTimeout)
}
case ctxTimeoutTypeQuery:
if c.db.GetConfig().QueryTimeout > 0 {
return context.WithTimeout(ctx, config.QueryTimeout)
}
case ctxTimeoutTypePrepare:
if c.db.GetConfig().PrepareTimeout > 0 {
return context.WithTimeout(ctx, config.PrepareTimeout)
}
case ctxTimeoutTypeTrans:
if c.db.GetConfig().TranTimeout > 0 {
return context.WithTimeout(ctx, config.TranTimeout)
}
default:
panic(gerror.NewCodef(gcode.CodeInvalidParameter, "invalid context timeout type: %d", timeoutType))
}
return ctx, func() {}
}
// Close closes the database and prevents new queries from starting.
// Close then waits for all queries that have started processing on the server
// to finish.
//
// It is rare to Close a DB, as the DB handle is meant to be
// long-lived and shared between many goroutines.
func (c *Core) Close(ctx context.Context) (err error) {
if err = c.cache.Close(ctx); err != nil {
return err
}
c.links.LockFunc(func(m map[ConfigNode]*sql.DB) {
for k, v := range m {
err = v.Close()
if err != nil {
err = gerror.WrapCode(gcode.CodeDbOperationError, err, `db.Close failed`)
}
intlog.Printf(ctx, `close link: %s, err: %v`, gconv.String(k), err)
if err != nil {
return
}
delete(m, k)
}
})
return
}
// Master creates and returns a connection from master node if master-slave configured.
// It returns the default connection if master-slave not configured.
func (c *Core) Master(schema ...string) (*sql.DB, error) {
var (
usedSchema = gutil.GetOrDefaultStr(c.schema, schema...)
charL, charR = c.db.GetChars()
)
return c.getSqlDb(true, gstr.Trim(usedSchema, charL+charR))
}
// Slave creates and returns a connection from slave node if master-slave configured.
// It returns the default connection if master-slave not configured.
func (c *Core) Slave(schema ...string) (*sql.DB, error) {
var (
usedSchema = gutil.GetOrDefaultStr(c.schema, schema...)
charL, charR = c.db.GetChars()
)
return c.getSqlDb(false, gstr.Trim(usedSchema, charL+charR))
}
// GetAll queries and returns data records from database.
func (c *Core) GetAll(ctx context.Context, sql string, args ...any) (Result, error) {
return c.db.DoSelect(ctx, nil, sql, args...)
}
// DoSelect queries and returns data records from database.
func (c *Core) DoSelect(ctx context.Context, link Link, sql string, args ...any) (result Result, err error) {
return c.db.DoQuery(ctx, link, sql, args...)
}
// GetOne queries and returns one record from database.
func (c *Core) GetOne(ctx context.Context, sql string, args ...any) (Record, error) {
list, err := c.db.GetAll(ctx, sql, args...)
if err != nil {
return nil, err
}
if len(list) > 0 {
return list[0], nil
}
return nil, nil
}
// GetArray queries and returns data values as slice from database.
// Note that if there are multiple columns in the result, it returns just one column values randomly.
func (c *Core) GetArray(ctx context.Context, sql string, args ...any) (Array, error) {
all, err := c.db.DoSelect(ctx, nil, sql, args...)
if err != nil {
return nil, err
}
return all.Array(), nil
}
// doGetStruct queries one record from database and converts it to given struct.
// The parameter `pointer` should be a pointer to struct.
func (c *Core) doGetStruct(ctx context.Context, pointer any, sql string, args ...any) error {
one, err := c.db.GetOne(ctx, sql, args...)
if err != nil {
return err
}
return one.Struct(pointer)
}
// doGetStructs queries records from database and converts them to given struct.
// The parameter `pointer` should be type of struct slice: []struct/[]*struct.
func (c *Core) doGetStructs(ctx context.Context, pointer any, sql string, args ...any) error {
all, err := c.db.GetAll(ctx, sql, args...)
if err != nil {
return err
}
return all.Structs(pointer)
}
// GetScan queries one or more records from database and converts them to given struct or
// struct array.
//
// If parameter `pointer` is type of struct pointer, it calls GetStruct internally for
// the conversion. If parameter `pointer` is type of slice, it calls GetStructs internally
// for conversion.
func (c *Core) GetScan(ctx context.Context, pointer any, sql string, args ...any) error {
reflectInfo := reflection.OriginTypeAndKind(pointer)
if reflectInfo.InputKind != reflect.Pointer {
return gerror.NewCodef(
gcode.CodeInvalidParameter,
"params should be type of pointer, but got: %v",
reflectInfo.InputKind,
)
}
switch reflectInfo.OriginKind {
case reflect.Array, reflect.Slice:
return c.db.GetCore().doGetStructs(ctx, pointer, sql, args...)
case reflect.Struct:
return c.db.GetCore().doGetStruct(ctx, pointer, sql, args...)
default:
}
return gerror.NewCodef(
gcode.CodeInvalidParameter,
`in valid parameter type "%v", of which element type should be type of struct/slice`,
reflectInfo.InputType,
)
}
// GetValue queries and returns the field value from database.
// The sql should query only one field from database, or else it returns only one
// field of the result.
func (c *Core) GetValue(ctx context.Context, sql string, args ...any) (Value, error) {
one, err := c.db.GetOne(ctx, sql, args...)
if err != nil {
return gvar.New(nil), err
}
for _, v := range one {
return v, nil
}
return gvar.New(nil), nil
}
// GetCount queries and returns the count from database.
func (c *Core) GetCount(ctx context.Context, sql string, args ...any) (int, error) {
// If the query fields do not contain function "COUNT",
// it replaces the sql string and adds the "COUNT" function to the fields.
if !gregex.IsMatchString(`(?i)SELECT\s+COUNT\(.+\)\s+FROM`, sql) {
sql, _ = gregex.ReplaceString(`(?i)(SELECT)\s+(.+)\s+(FROM)`, `$1 COUNT($2) $3`, sql)
}
value, err := c.db.GetValue(ctx, sql, args...)
if err != nil {
return 0, err
}
return value.Int(), nil
}
// Union does "(SELECT xxx FROM xxx) UNION (SELECT xxx FROM xxx) ..." statement.
func (c *Core) Union(unions ...*Model) *Model {
var ctx = c.db.GetCtx()
return c.doUnion(ctx, unionTypeNormal, unions...)
}
// UnionAll does "(SELECT xxx FROM xxx) UNION ALL (SELECT xxx FROM xxx) ..." statement.
func (c *Core) UnionAll(unions ...*Model) *Model {
var ctx = c.db.GetCtx()
return c.doUnion(ctx, unionTypeAll, unions...)
}
func (c *Core) doUnion(ctx context.Context, unionType int, unions ...*Model) *Model {
var (
unionTypeStr string
composedSqlStr string
composedArgs = make([]any, 0)
)
if unionType == unionTypeAll {
unionTypeStr = "UNION ALL"
} else {
unionTypeStr = "UNION"
}
for _, v := range unions {
sqlWithHolder, holderArgs := v.getFormattedSqlAndArgs(ctx, SelectTypeDefault, false)
if composedSqlStr == "" {
composedSqlStr += fmt.Sprintf(`(%s)`, sqlWithHolder)
} else {
composedSqlStr += fmt.Sprintf(` %s (%s)`, unionTypeStr, sqlWithHolder)
}
composedArgs = append(composedArgs, holderArgs...)
}
return c.db.Raw(composedSqlStr, composedArgs...)
}
// PingMaster pings the master node to check authentication or keeps the connection alive.
func (c *Core) PingMaster() error {
var ctx = c.db.GetCtx()
if master, err := c.db.Master(); err != nil {
return err
} else {
if err = master.PingContext(ctx); err != nil {
err = gerror.WrapCode(gcode.CodeDbOperationError, err, `master.Ping failed`)
}
return err
}
}
// PingSlave pings the slave node to check authentication or keeps the connection alive.
func (c *Core) PingSlave() error {
var ctx = c.db.GetCtx()
if slave, err := c.db.Slave(); err != nil {
return err
} else {
if err = slave.PingContext(ctx); err != nil {
err = gerror.WrapCode(gcode.CodeDbOperationError, err, `slave.Ping failed`)
}
return err
}
}
// Insert does "INSERT INTO ..." statement for the table.
// If there's already one unique record of the data in the table, it returns error.
//
// The parameter `data` can be type of map/gmap/struct/*struct/[]map/[]struct, etc.
// Eg:
// Data(g.Map{"uid": 10000, "name":"john"})
// Data(g.Slice{g.Map{"uid": 10000, "name":"john"}, g.Map{"uid": 20000, "name":"smith"})
//
// The parameter `batch` specifies the batch operation count when given data is slice.
func (c *Core) Insert(ctx context.Context, table string, data any, batch ...int) (sql.Result, error) {
if len(batch) > 0 {
return c.Model(table).Ctx(ctx).Data(data).Batch(batch[0]).Insert()
}
return c.Model(table).Ctx(ctx).Data(data).Insert()
}
// InsertIgnore does "INSERT IGNORE INTO ..." statement for the table.
// If there's already one unique record of the data in the table, it ignores the inserting.
//
// The parameter `data` can be type of map/gmap/struct/*struct/[]map/[]struct, etc.
// Eg:
// Data(g.Map{"uid": 10000, "name":"john"})
// Data(g.Slice{g.Map{"uid": 10000, "name":"john"}, g.Map{"uid": 20000, "name":"smith"})
//
// The parameter `batch` specifies the batch operation count when given data is slice.
func (c *Core) InsertIgnore(ctx context.Context, table string, data any, batch ...int) (sql.Result, error) {
if len(batch) > 0 {
return c.Model(table).Ctx(ctx).Data(data).Batch(batch[0]).InsertIgnore()
}
return c.Model(table).Ctx(ctx).Data(data).InsertIgnore()
}
// InsertAndGetId performs action Insert and returns the last insert id that automatically generated.
func (c *Core) InsertAndGetId(ctx context.Context, table string, data any, batch ...int) (int64, error) {
if len(batch) > 0 {
return c.Model(table).Ctx(ctx).Data(data).Batch(batch[0]).InsertAndGetId()
}
return c.Model(table).Ctx(ctx).Data(data).InsertAndGetId()
}
// Replace does "REPLACE INTO ..." statement for the table.
// If there's already one unique record of the data in the table, it deletes the record
// and inserts a new one.
//
// The parameter `data` can be type of map/gmap/struct/*struct/[]map/[]struct, etc.
// Eg:
// Data(g.Map{"uid": 10000, "name":"john"})
// Data(g.Slice{g.Map{"uid": 10000, "name":"john"}, g.Map{"uid": 20000, "name":"smith"})
//
// The parameter `data` can be type of map/gmap/struct/*struct/[]map/[]struct, etc.
// If given data is type of slice, it then does batch replacing, and the optional parameter
// `batch` specifies the batch operation count.
func (c *Core) Replace(ctx context.Context, table string, data any, batch ...int) (sql.Result, error) {
if len(batch) > 0 {
return c.Model(table).Ctx(ctx).Data(data).Batch(batch[0]).Replace()
}
return c.Model(table).Ctx(ctx).Data(data).Replace()
}
// Save does "INSERT INTO ... ON DUPLICATE KEY UPDATE..." statement for the table.
// It updates the record if there's primary or unique index in the saving data,
// or else it inserts a new record into the table.
//
// The parameter `data` can be type of map/gmap/struct/*struct/[]map/[]struct, etc.
// Eg:
// Data(g.Map{"uid": 10000, "name":"john"})
// Data(g.Slice{g.Map{"uid": 10000, "name":"john"}, g.Map{"uid": 20000, "name":"smith"})
//
// If given data is type of slice, it then does batch saving, and the optional parameter
// `batch` specifies the batch operation count.
func (c *Core) Save(ctx context.Context, table string, data any, batch ...int) (sql.Result, error) {
if len(batch) > 0 {
return c.Model(table).Ctx(ctx).Data(data).Batch(batch[0]).Save()
}
return c.Model(table).Ctx(ctx).Data(data).Save()
}
func (c *Core) fieldsToSequence(ctx context.Context, table string, fields []string) ([]string, error) {
var (
fieldSet = gset.NewStrSetFrom(fields)
fieldsResultInSequence = make([]string, 0)
tableFields, err = c.db.TableFields(ctx, table)
)
if err != nil {
return nil, err
}
// Sort the fields in order.
var fieldsOfTableInSequence = make([]string, len(tableFields))
for _, field := range tableFields {
fieldsOfTableInSequence[field.Index] = field.Name
}
// Sort the input fields.
for _, fieldName := range fieldsOfTableInSequence {
if fieldSet.Contains(fieldName) {
fieldsResultInSequence = append(fieldsResultInSequence, fieldName)
}
}
return fieldsResultInSequence, nil
}
// DoInsert inserts or updates data for given table.
// This function is usually used for custom interface definition, you do not need call it manually.
// The parameter `data` can be type of map/gmap/struct/*struct/[]map/[]struct, etc.
// Eg:
// Data(g.Map{"uid": 10000, "name":"john"})
// Data(g.Slice{g.Map{"uid": 10000, "name":"john"}, g.Map{"uid": 20000, "name":"smith"})
//
// The parameter `option` values are as follows:
// InsertOptionDefault: just insert, if there's unique/primary key in the data, it returns error;
// InsertOptionReplace: if there's unique/primary key in the data, it deletes it from table and inserts a new one;
// InsertOptionSave: if there's unique/primary key in the data, it updates it or else inserts a new one;
// InsertOptionIgnore: if there's unique/primary key in the data, it ignores the inserting;
func (c *Core) DoInsert(ctx context.Context, link Link, table string, list List, option DoInsertOption) (result sql.Result, err error) {
var (
keys []string // Field names.
values []string // Value holder string array, like: (?,?,?)
params []any // Values that will be committed to underlying database driver.
onDuplicateStr string // onDuplicateStr is used in "ON DUPLICATE KEY UPDATE" statement.
)
// ============================================================================================
// Group the list by fields. Different fields to different list.
// It here uses ListMap to keep sequence for data inserting.
// ============================================================================================
var (
keyListMap = gmap.NewListMap()
tmpKeyListMap = make(map[string]List)
)
for _, item := range list {
mapLen := len(item)
if mapLen == 0 {
continue
}
tmpKeys := make([]string, 0, mapLen)
for k := range item {
tmpKeys = append(tmpKeys, k)
}
if mapLen > 1 {
sort.Strings(tmpKeys)
}
keys = tmpKeys // for fieldsToSequence
tmpKeysInSequenceStr := gstr.Join(tmpKeys, ",")
if tmpKeyListMapItem, ok := tmpKeyListMap[tmpKeysInSequenceStr]; ok {
tmpKeyListMap[tmpKeysInSequenceStr] = append(tmpKeyListMapItem, item)
} else {
tmpKeyListMap[tmpKeysInSequenceStr] = List{item}
}
}
for tmpKeysInSequenceStr, itemList := range tmpKeyListMap {
keyListMap.Set(tmpKeysInSequenceStr, itemList)
}
if keyListMap.Size() > 1 {
var (
tmpResult sql.Result
sqlResult SqlResult
rowsAffected int64
)
keyListMap.Iterator(func(key, value any) bool {
tmpResult, err = c.DoInsert(ctx, link, table, value.(List), option)
if err != nil {
return false
}
rowsAffected, err = tmpResult.RowsAffected()
if err != nil {
return false
}
sqlResult.Result = tmpResult
sqlResult.Affected += rowsAffected
return true
})
return &sqlResult, err
}
keys, err = c.fieldsToSequence(ctx, table, keys)
if err != nil {
return nil, err
}
if len(keys) == 0 {
return nil, gerror.NewCode(gcode.CodeInvalidParameter, "no valid data fields found in table")
}
// Prepare the batch result pointer.
var (
charL, charR = c.db.GetChars()
batchResult = new(SqlResult)
keysStr = charL + strings.Join(keys, charR+","+charL) + charR
operation = GetInsertOperationByOption(option.InsertOption)
)
// Upsert clause only takes effect on Save operation.
if option.InsertOption == InsertOptionSave {
onDuplicateStr, err = c.db.FormatUpsert(keys, list, option)
if err != nil {
return nil, err
}
}
var (
listLength = len(list)
valueHolders = make([]string, 0)
)
for i := 0; i < listLength; i++ {
values = values[:0]
// Note that the map type is unordered,
// so it should use slice+key to retrieve the value.
for _, k := range keys {
if s, ok := list[i][k].(Raw); ok {
values = append(values, gconv.String(s))
} else {
values = append(values, "?")
params = append(params, list[i][k])
}
}
valueHolders = append(valueHolders, "("+gstr.Join(values, ",")+")")
// Batch package checks: It meets the batch number, or it is the last element.
if len(valueHolders) == option.BatchCount || (i == listLength-1 && len(valueHolders) > 0) {
var (
stdSqlResult sql.Result
affectedRows int64
)
stdSqlResult, err = c.db.DoExec(ctx, link, fmt.Sprintf(
"%s INTO %s(%s) VALUES%s %s",
operation, c.QuotePrefixTableName(table), keysStr,
gstr.Join(valueHolders, ","),
onDuplicateStr,
), params...)
if err != nil {
return stdSqlResult, err
}
if affectedRows, err = stdSqlResult.RowsAffected(); err != nil {
err = gerror.WrapCode(gcode.CodeDbOperationError, err, `sql.Result.RowsAffected failed`)
return stdSqlResult, err
} else {
batchResult.Result = stdSqlResult
batchResult.Affected += affectedRows
}
params = params[:0]
valueHolders = valueHolders[:0]
}
}
return batchResult, nil
}
// Update does "UPDATE ... " statement for the table.
//
// The parameter `data` can be type of string/map/gmap/struct/*struct, etc.
// Eg: "uid=10000", "uid", 10000, g.Map{"uid": 10000, "name":"john"}
//
// The parameter `condition` can be type of string/map/gmap/slice/struct/*struct, etc.
// It is commonly used with parameter `args`.
// Eg:
// "uid=10000",
// "uid", 10000
// "money>? AND name like ?", 99999, "vip_%"
// "status IN (?)", g.Slice{1,2,3}
// "age IN(?,?)", 18, 50
// User{ Id : 1, UserName : "john"}.
func (c *Core) Update(ctx context.Context, table string, data any, condition any, args ...any) (sql.Result, error) {
return c.Model(table).Ctx(ctx).Data(data).Where(condition, args...).Update()
}
// DoUpdate does "UPDATE ... " statement for the table.
// This function is usually used for custom interface definition, you do not need to call it manually.
func (c *Core) DoUpdate(ctx context.Context, link Link, table string, data any, condition string, args ...any) (result sql.Result, err error) {
table = c.QuotePrefixTableName(table)
var (
rv = reflect.ValueOf(data)
kind = rv.Kind()
)
if kind == reflect.Pointer {
rv = rv.Elem()
kind = rv.Kind()
}
var (
params []any
updates string
)
switch kind {
case reflect.Map, reflect.Struct:
var (
fields []string
dataMap map[string]any
)
dataMap, err = c.ConvertDataForRecord(ctx, data, table)
if err != nil {
return nil, err
}
// Sort the data keys in sequence of table fields.
var (
dataKeys = make([]string, 0)
keysInSequence = make([]string, 0)
)
for k := range dataMap {
dataKeys = append(dataKeys, k)
}
keysInSequence, err = c.fieldsToSequence(ctx, table, dataKeys)
if err != nil {
return nil, err
}
for _, k := range keysInSequence {
v := dataMap[k]
switch v.(type) {
case Counter, *Counter:
var counter Counter
switch value := v.(type) {
case Counter:
counter = value
case *Counter:
counter = *value
}
if counter.Value == 0 {
continue
}
operator, columnVal := c.getCounterAlter(counter)
fields = append(fields, fmt.Sprintf("%s=%s%s?", c.QuoteWord(k), c.QuoteWord(counter.Field), operator))
params = append(params, columnVal)
default:
if s, ok := v.(Raw); ok {
fields = append(fields, c.QuoteWord(k)+"="+gconv.String(s))
} else {
fields = append(fields, c.QuoteWord(k)+"=?")
params = append(params, v)
}
}
}
updates = strings.Join(fields, ",")
default:
updates = gconv.String(data)
}
if len(updates) == 0 {
return nil, gerror.NewCode(gcode.CodeMissingParameter, "data cannot be empty")
}
if len(params) > 0 {
args = append(params, args...)
}
// If no link passed, it then uses the master link.
if link == nil {
if link, err = c.MasterLink(); err != nil {
return nil, err
}
}
return c.db.DoExec(ctx, link, fmt.Sprintf(
"UPDATE %s SET %s%s",
table, updates, condition,
),
args...,
)
}
// Delete does "DELETE FROM ... " statement for the table.
//
// The parameter `condition` can be type of string/map/gmap/slice/struct/*struct, etc.
// It is commonly used with parameter `args`.
// Eg:
// "uid=10000",
// "uid", 10000
// "money>? AND name like ?", 99999, "vip_%"
// "status IN (?)", g.Slice{1,2,3}
// "age IN(?,?)", 18, 50
// User{ Id : 1, UserName : "john"}.
func (c *Core) Delete(ctx context.Context, table string, condition any, args ...any) (result sql.Result, err error) {
return c.Model(table).Ctx(ctx).Where(condition, args...).Delete()
}
// DoDelete does "DELETE FROM ... " statement for the table.
// This function is usually used for custom interface definition, you do not need call it manually.
func (c *Core) DoDelete(ctx context.Context, link Link, table string, condition string, args ...any) (result sql.Result, err error) {
if link == nil {
if link, err = c.MasterLink(); err != nil {
return nil, err
}
}
table = c.QuotePrefixTableName(table)
return c.db.DoExec(ctx, link, fmt.Sprintf("DELETE FROM %s%s", table, condition), args...)
}
// FilteredLink retrieves and returns filtered `linkInfo` that can be using for
// logging or tracing purpose.
func (c *Core) FilteredLink() string {
return fmt.Sprintf(
`%s@%s(%s:%s)/%s`,
c.config.User, c.config.Protocol, c.config.Host, c.config.Port, c.config.Name,
)
}
// MarshalJSON implements the interface MarshalJSON for json.Marshal.
// It just returns the pointer address.
//
// Note that this interface implements mainly for workaround for a json infinite loop bug
// of Golang version < v1.14.
func (c *Core) MarshalJSON() ([]byte, error) {
return []byte(fmt.Sprintf(`%+v`, c)), nil
}
// writeSqlToLogger outputs the Sql object to logger.
// It is enabled only if configuration "debug" is true.
func (c *Core) writeSqlToLogger(ctx context.Context, sql *Sql) {
var transactionIdStr string
if sql.IsTransaction {
if v := ctx.Value(transactionIdForLoggerCtx); v != nil {
transactionIdStr = fmt.Sprintf(`[txid:%d] `, v.(uint64))
}
}
s := fmt.Sprintf(
"[%3d ms] [%s] [%s] [rows:%-3d] %s%s",
sql.End-sql.Start, sql.Group, sql.Schema, sql.RowsAffected, transactionIdStr, sql.Format,
)
if sql.Error != nil {
s += "\nError: " + sql.Error.Error()
c.logger.Error(ctx, s)
} else {
c.logger.Debug(ctx, s)
}
}
// HasTable determine whether the table name exists in the database.
func (c *Core) HasTable(name string) (bool, error) {
tables, err := c.GetTablesWithCache()
if err != nil {
return false, err
}
charL, charR := c.db.GetChars()
name = gstr.Trim(name, charL+charR)
for _, table := range tables {
if table == name {
return true, nil
}
}
return false, nil
}
// GetInnerMemCache retrieves and returns the inner memory cache object.
func (c *Core) GetInnerMemCache() *gcache.Cache {
return c.innerMemCache
}
func (c *Core) SetTableFields(ctx context.Context, table string, fields map[string]*TableField, schema ...string) error {
if table == "" {
return gerror.NewCode(gcode.CodeInvalidParameter, "table name cannot be empty")
}
charL, charR := c.db.GetChars()
table = gstr.Trim(table, charL+charR)
if gstr.Contains(table, " ") {
return gerror.NewCode(
gcode.CodeInvalidParameter,
"function TableFields supports only single table operations",
)
}
var (
innerMemCache = c.GetInnerMemCache()
// prefix:group@schema#table
cacheKey = genTableFieldsCacheKey(
c.db.GetGroup(),
gutil.GetOrDefaultStr(c.db.GetSchema(), schema...),
table,
)
)
return innerMemCache.Set(ctx, cacheKey, fields, gcache.DurationNoExpire)
}
// GetTablesWithCache retrieves and returns the table names of current database with cache.
func (c *Core) GetTablesWithCache() ([]string, error) {
var (
ctx = c.db.GetCtx()
cacheKey = genTableNamesCacheKey(c.db.GetGroup())
cacheDuration = gcache.DurationNoExpire
innerMemCache = c.GetInnerMemCache()
)
result, err := innerMemCache.GetOrSetFuncLock(
ctx, cacheKey,
func(ctx context.Context) (any, error) {
tableList, err := c.db.Tables(ctx)
if err != nil {
return nil, err
}
return tableList, nil
}, cacheDuration,
)
if err != nil {
return nil, err
}
return result.Strings(), nil
}
// IsSoftCreatedFieldName checks and returns whether given field name is an automatic-filled created time.
func (c *Core) IsSoftCreatedFieldName(fieldName string) bool {
if fieldName == "" {
return false
}
if config := c.db.GetConfig(); config.CreatedAt != "" {
if utils.EqualFoldWithoutChars(fieldName, config.CreatedAt) {
return true
}
return gstr.InArray(append([]string{config.CreatedAt}, createdFieldNames...), fieldName)
}
for _, v := range createdFieldNames {
if utils.EqualFoldWithoutChars(fieldName, v) {
return true
}
}
return false
}
// FormatSqlBeforeExecuting formats the sql string and its arguments before executing.
// The internal handleArguments function might be called twice during the SQL procedure,
// but do not worry about it, it's safe and efficient.
func (c *Core) FormatSqlBeforeExecuting(sql string, args []any) (newSql string, newArgs []any) {
return handleSliceAndStructArgsForSql(sql, args)
}
// getCounterAlter
func (c *Core) getCounterAlter(counter Counter) (operator string, columnVal float64) {
operator, columnVal = "+", counter.Value
if columnVal < 0 {
operator, columnVal = "-", -columnVal
}
return
}

485
database/gdb_core_config.go Normal file
View File

@ -0,0 +1,485 @@
// Copyright GoFrame Author(https://goframe.org). All Rights Reserved.
//
// This Source Code Form is subject to the terms of the MIT License.
// If a copy of the MIT was not distributed with this file,
// You can obtain one at https://github.com/gogf/gf.
package database
import (
"fmt"
"sync"
"time"
"git.magicany.cc/black1552/gin-base/log"
"github.com/gogf/gf/v2/errors/gcode"
"github.com/gogf/gf/v2/errors/gerror"
"github.com/gogf/gf/v2/os/gcache"
"github.com/gogf/gf/v2/text/gregex"
"github.com/gogf/gf/v2/text/gstr"
"github.com/gogf/gf/v2/util/gconv"
)
// Config is the configuration management object.
type Config map[string]ConfigGroup
// ConfigGroup is a slice of configuration node for specified named group.
type ConfigGroup []ConfigNode
// ConfigNode is configuration for one node.
type ConfigNode struct {
// Host specifies the server address, can be either IP address or domain name
// Example: "127.0.0.1", "localhost"
Host string `json:"host"`
// Port specifies the server port number
// Default is typically "3306" for MySQL
Port string `json:"port"`
// User specifies the authentication username for database connection
User string `json:"user"`
// Pass specifies the authentication password for database connection
Pass string `json:"pass"`
// Name specifies the default database name to be used
Name string `json:"name"`
// Type specifies the database type
// Example: mysql, mariadb, sqlite, mssql, pgsql, oracle, clickhouse, dm.
Type string `json:"type"`
// Link provides custom connection string that combines all configuration in one string
// Optional field
Link string `json:"link"`
// Extra provides additional configuration options for third-party database drivers
// Optional field
Extra string `json:"extra"`
// Role specifies the node role in master-slave setup
// Optional field, defaults to "master"
// Available values: "master", "slave"
Role Role `json:"role"`
// Debug enables debug mode for logging and output
// Optional field
Debug bool `json:"debug"`
// Prefix specifies the table name prefix
// Optional field
Prefix string `json:"prefix"`
// DryRun enables simulation mode where SELECT statements are executed
// but INSERT/UPDATE/DELETE statements are not
// Optional field
DryRun bool `json:"dryRun"`
// Weight specifies the node weight for load balancing calculations
// Optional field, only effective in multi-node setups
Weight int `json:"weight"`
// Charset specifies the character set for database operations
// Optional field, defaults to "utf8"
Charset string `json:"charset"`
// Protocol specifies the network protocol for database connection
// Optional field, defaults to "tcp"
// See net.Dial for available network protocols
Protocol string `json:"protocol"`
// Timezone sets the time zone for timestamp interpretation and display
// Optional field
Timezone string `json:"timezone"`
// Namespace specifies the schema namespace for certain databases
// Optional field, e.g., in PostgreSQL, Name is the catalog and Namespace is the schema
Namespace string `json:"namespace"`
// MaxIdleConnCount specifies the maximum number of idle connections in the pool
// Optional field
MaxIdleConnCount int `json:"maxIdle"`
// MaxOpenConnCount specifies the maximum number of open connections in the pool
// Optional field
MaxOpenConnCount int `json:"maxOpen"`
// MaxConnLifeTime specifies the maximum lifetime of a connection
// Optional field
MaxConnLifeTime time.Duration `json:"maxLifeTime"`
// MaxIdleConnTime specifies the maximum idle time of a connection before being closed
// This is Go 1.15+ feature: sql.DB.SetConnMaxIdleTime
// Optional field
MaxIdleConnTime time.Duration `json:"maxIdleTime"`
// QueryTimeout specifies the maximum execution time for DQL operations
// Optional field
QueryTimeout time.Duration `json:"queryTimeout"`
// ExecTimeout specifies the maximum execution time for DML operations
// Optional field
ExecTimeout time.Duration `json:"execTimeout"`
// TranTimeout specifies the maximum execution time for a transaction block
// Optional field
TranTimeout time.Duration `json:"tranTimeout"`
// PrepareTimeout specifies the maximum execution time for prepare operations
// Optional field
PrepareTimeout time.Duration `json:"prepareTimeout"`
// CreatedAt specifies the field name for automatic timestamp on record creation
// Optional field
CreatedAt string `json:"createdAt"`
// UpdatedAt specifies the field name for automatic timestamp on record updates
// Optional field
UpdatedAt string `json:"updatedAt"`
// DeletedAt specifies the field name for automatic timestamp on record deletion
// Optional field
DeletedAt string `json:"deletedAt"`
// TimeMaintainDisabled controls whether automatic time maintenance is disabled
// Optional field
TimeMaintainDisabled bool `json:"timeMaintainDisabled"`
}
type Role string
const (
RoleMaster Role = "master"
RoleSlave Role = "slave"
)
const (
DefaultGroupName = "default" // Default group name.
)
// configs specifies internal used configuration object.
var configs struct {
sync.RWMutex
config Config // All configurations.
group string // Default configuration group.
}
func init() {
configs.config = make(Config)
configs.group = DefaultGroupName
}
// SetConfig sets the global configuration for package.
// It will overwrite the old configuration of package.
func SetConfig(config Config) error {
defer instances.Clear()
configs.Lock()
defer configs.Unlock()
for k, nodes := range config {
for i, node := range nodes {
parsedNode, err := parseConfigNode(node)
if err != nil {
return err
}
nodes[i] = parsedNode
}
config[k] = nodes
}
configs.config = config
return nil
}
// SetConfigGroup sets the configuration for given group.
func SetConfigGroup(group string, nodes ConfigGroup) error {
defer instances.Clear()
configs.Lock()
defer configs.Unlock()
for i, node := range nodes {
parsedNode, err := parseConfigNode(node)
if err != nil {
return err
}
nodes[i] = parsedNode
}
configs.config[group] = nodes
return nil
}
// AddConfigNode adds one node configuration to configuration of given group.
func AddConfigNode(group string, node ConfigNode) error {
defer instances.Clear()
configs.Lock()
defer configs.Unlock()
parsedNode, err := parseConfigNode(node)
if err != nil {
return err
}
configs.config[group] = append(configs.config[group], parsedNode)
return nil
}
// parseConfigNode parses `Link` configuration syntax.
func parseConfigNode(node ConfigNode) (ConfigNode, error) {
if node.Link != "" {
parsedLinkNode, err := parseConfigNodeLink(&node)
if err != nil {
return node, err
}
node = *parsedLinkNode
}
if node.Link != "" && node.Type == "" {
match, _ := gregex.MatchString(`([a-z]+):(.+)`, node.Link)
if len(match) == 3 {
node.Type = gstr.Trim(match[1])
node.Link = gstr.Trim(match[2])
}
}
return node, nil
}
// AddDefaultConfigNode adds one node configuration to configuration of default group.
func AddDefaultConfigNode(node ConfigNode) error {
return AddConfigNode(DefaultGroupName, node)
}
// AddDefaultConfigGroup adds multiple node configurations to configuration of default group.
//
// Deprecated: Use SetDefaultConfigGroup instead.
func AddDefaultConfigGroup(nodes ConfigGroup) error {
return SetConfigGroup(DefaultGroupName, nodes)
}
// SetDefaultConfigGroup sets multiple node configurations to configuration of default group.
func SetDefaultConfigGroup(nodes ConfigGroup) error {
return SetConfigGroup(DefaultGroupName, nodes)
}
// GetConfig retrieves and returns the configuration of given group.
//
// Deprecated: Use GetConfigGroup instead.
func GetConfig(group string) ConfigGroup {
configGroup, _ := GetConfigGroup(group)
return configGroup
}
// GetConfigGroup retrieves and returns the configuration of given group.
// It returns an error if the group does not exist, or an empty slice if the group exists but has no nodes.
func GetConfigGroup(group string) (ConfigGroup, error) {
configs.RLock()
defer configs.RUnlock()
configGroup, exists := configs.config[group]
if !exists {
return nil, gerror.NewCodef(
gcode.CodeInvalidParameter,
`configuration group "%s" not found`,
group,
)
}
return configGroup, nil
}
// GetAllConfig retrieves and returns all configurations.
func GetAllConfig() Config {
configs.RLock()
defer configs.RUnlock()
return configs.config
}
// SetDefaultGroup sets the group name for default configuration.
func SetDefaultGroup(name string) {
defer instances.Clear()
configs.Lock()
defer configs.Unlock()
configs.group = name
}
// GetDefaultGroup returns the { name of default configuration.
func GetDefaultGroup() string {
defer instances.Clear()
configs.RLock()
defer configs.RUnlock()
return configs.group
}
// IsConfigured checks and returns whether the database configured.
// It returns true if any configuration exists.
func IsConfigured() bool {
configs.RLock()
defer configs.RUnlock()
return len(configs.config) > 0
}
// SetLogger sets the logger for orm.
func (c *Core) SetLogger(logger log.ILogger) {
c.logger = logger
}
// GetLogger returns the (logger) of the orm.
func (c *Core) GetLogger() log.ILogger {
return c.logger
}
// SetMaxIdleConnCount sets the maximum number of connections in the idle
// connection pool.
//
// If MaxOpenConns is greater than 0 but less than the new MaxIdleConns,
// then the new MaxIdleConns will be reduced to match the MaxOpenConns limit.
//
// If n <= 0, no idle connections are retained.
//
// The default max idle connections is currently 2. This may change in
// a future release.
func (c *Core) SetMaxIdleConnCount(n int) {
c.dynamicConfig.MaxIdleConnCount = n
}
// SetMaxOpenConnCount sets the maximum number of open connections to the database.
//
// If MaxIdleConns is greater than 0 and the new MaxOpenConns is less than
// MaxIdleConns, then MaxIdleConns will be reduced to match the new
// MaxOpenConns limit.
//
// If n <= 0, then there is no limit on the number of open connections.
// The default is 0 (unlimited).
func (c *Core) SetMaxOpenConnCount(n int) {
c.dynamicConfig.MaxOpenConnCount = n
}
// SetMaxConnLifeTime sets the maximum amount of time a connection may be reused.
//
// Expired connections may be closed lazily before reuse.
//
// If d <= 0, connections are not closed due to a connection's age.
func (c *Core) SetMaxConnLifeTime(d time.Duration) {
c.dynamicConfig.MaxConnLifeTime = d
}
// SetMaxIdleConnTime sets the maximum amount of time a connection may be idle before being closed.
//
// Idle connections may be closed lazily before reuse.
//
// If d <= 0, connections are not closed due to a connection's idle time.
// This is Go 1.15+ feature: sql.DB.SetConnMaxIdleTime.
func (c *Core) SetMaxIdleConnTime(d time.Duration) {
c.dynamicConfig.MaxIdleConnTime = d
}
// GetConfig returns the current used node configuration.
func (c *Core) GetConfig() *ConfigNode {
var configNode = c.getConfigNodeFromCtx(c.db.GetCtx())
if configNode != nil {
// Note:
// It so here checks and returns the config from current DB,
// if different schemas between current DB and config.Name from context,
// for example, in nested transaction scenario, the context is passed all through the logic procedure,
// but the config.Name from context may be still the original one from the first transaction object.
if c.config.Name == configNode.Name {
return configNode
}
}
return c.config
}
// SetDebug enables/disables the debug mode.
func (c *Core) SetDebug(debug bool) {
c.debug.Set(debug)
}
// GetDebug returns the debug value.
func (c *Core) GetDebug() bool {
return c.debug.Val()
}
// GetCache returns the internal cache object.
func (c *Core) GetCache() *gcache.Cache {
return c.cache
}
// GetGroup returns the group string configured.
func (c *Core) GetGroup() string {
return c.group
}
// SetDryRun enables/disables the DryRun feature.
func (c *Core) SetDryRun(enabled bool) {
c.config.DryRun = enabled
}
// GetDryRun returns the DryRun value.
func (c *Core) GetDryRun() bool {
return c.config.DryRun || allDryRun
}
// GetPrefix returns the table prefix string configured.
func (c *Core) GetPrefix() string {
return c.config.Prefix
}
// GetSchema returns the schema configured.
func (c *Core) GetSchema() string {
schema := c.schema
if schema == "" {
schema = c.db.GetConfig().Name
}
return schema
}
func parseConfigNodeLink(node *ConfigNode) (*ConfigNode, error) {
var (
link = node.Link
match []string
)
if link != "" {
// To be compatible with old configuration,
// it checks and converts the link to new configuration.
if node.Type != "" && !gstr.HasPrefix(link, node.Type+":") {
link = fmt.Sprintf("%s:%s", node.Type, link)
}
match, _ = gregex.MatchString(linkPattern, link)
if len(match) <= 5 {
return nil, gerror.NewCodef(
gcode.CodeInvalidParameter,
`invalid link configuration: %s, shuold be pattern like: %s`,
link, linkPatternDescription,
)
}
node.Type = match[1]
node.User = match[2]
node.Pass = match[3]
node.Protocol = match[4]
array := gstr.Split(match[5], ":")
if node.Protocol == "file" {
node.Name = match[5]
} else {
if len(array) == 2 {
// link with port.
node.Host = array[0]
node.Port = array[1]
} else {
// link without port.
node.Host = array[0]
}
node.Name = match[6]
}
if len(match) > 6 && match[7] != "" {
node.Extra = match[7]
}
}
if node.Extra != "" {
if m, _ := gstr.Parse(node.Extra); len(m) > 0 {
_ = gconv.Struct(m, &node)
}
}
// Default value checks.
if node.Charset == "" {
node.Charset = defaultCharset
}
if node.Protocol == "" {
node.Protocol = defaultProtocol
}
return node, nil
}

98
database/gdb_core_ctx.go Normal file
View File

@ -0,0 +1,98 @@
// Copyright GoFrame Author(https://goframe.org). All Rights Reserved.
//
// This Source Code Form is subject to the terms of the MIT License.
// If a copy of the MIT was not distributed with this file,
// You can obtain one at https://github.com/gogf/gf.
package database
import (
"context"
"sync"
"github.com/gogf/gf/v2/errors/gcode"
"github.com/gogf/gf/v2/errors/gerror"
"github.com/gogf/gf/v2/os/gctx"
)
// internalCtxData stores data in ctx for internal usage purpose.
type internalCtxData struct {
sync.Mutex
// Used configuration node in current operation.
ConfigNode *ConfigNode
}
// column stores column data in ctx for internal usage purpose.
type internalColumnData struct {
// The first column in result response from database server.
// This attribute is used for Value/Count selection statement purpose,
// which is to avoid HOOK handler that might modify the result columns
// that can confuse the Value/Count selection statement logic.
FirstResultColumn string
}
const (
internalCtxDataKeyInCtx gctx.StrKey = "InternalCtxData"
internalColumnDataKeyInCtx gctx.StrKey = "InternalColumnData"
// `ignoreResultKeyInCtx` is a mark for some db drivers that do not support `RowsAffected` function,
// for example: `clickhouse`. The `clickhouse` does not support fetching insert/update results,
// but returns errors when execute `RowsAffected`. It here ignores the calling of `RowsAffected`
// to avoid triggering errors, rather than ignoring errors after they are triggered.
ignoreResultKeyInCtx gctx.StrKey = "IgnoreResult"
)
func (c *Core) injectInternalCtxData(ctx context.Context) context.Context {
// If the internal data is already injected, it does nothing.
if ctx.Value(internalCtxDataKeyInCtx) != nil {
return ctx
}
return context.WithValue(ctx, internalCtxDataKeyInCtx, &internalCtxData{
ConfigNode: c.config,
})
}
func (c *Core) setConfigNodeToCtx(ctx context.Context, node *ConfigNode) error {
value := ctx.Value(internalCtxDataKeyInCtx)
if value == nil {
return gerror.NewCode(gcode.CodeInternalError, `no internal data found in context`)
}
data := value.(*internalCtxData)
data.Lock()
defer data.Unlock()
data.ConfigNode = node
return nil
}
func (c *Core) getConfigNodeFromCtx(ctx context.Context) *ConfigNode {
if value := ctx.Value(internalCtxDataKeyInCtx); value != nil {
data := value.(*internalCtxData)
data.Lock()
defer data.Unlock()
return data.ConfigNode
}
return nil
}
func (c *Core) injectInternalColumn(ctx context.Context) context.Context {
return context.WithValue(ctx, internalColumnDataKeyInCtx, &internalColumnData{})
}
func (c *Core) getInternalColumnFromCtx(ctx context.Context) *internalColumnData {
if v := ctx.Value(internalColumnDataKeyInCtx); v != nil {
return v.(*internalColumnData)
}
return nil
}
func (c *Core) InjectIgnoreResult(ctx context.Context) context.Context {
if ctx.Value(ignoreResultKeyInCtx) != nil {
return ctx
}
return context.WithValue(ctx, ignoreResultKeyInCtx, true)
}
func (c *Core) GetIgnoreResultFromCtx(ctx context.Context) bool {
return ctx.Value(ignoreResultKeyInCtx) != nil
}

43
database/gdb_core_link.go Normal file
View File

@ -0,0 +1,43 @@
// Copyright GoFrame Author(https://goframe.org). All Rights Reserved.
//
// This Source Code Form is subject to the terms of the MIT License.
// If a copy of the MIT was not distributed with this file,
// You can obtain one at https://github.com/gogf/gf.
package database
import (
"database/sql"
)
// dbLink is used to implement interface Link for DB.
type dbLink struct {
*sql.DB // Underlying DB object.
isOnMaster bool // isOnMaster marks whether current link is operated on master node.
}
// txLink is used to implement interface Link for TX.
type txLink struct {
*sql.Tx
}
// IsTransaction returns if current Link is a transaction.
func (l *dbLink) IsTransaction() bool {
return false
}
// IsOnMaster checks and returns whether current link is operated on master node.
func (l *dbLink) IsOnMaster() bool {
return l.isOnMaster
}
// IsTransaction returns if current Link is a transaction.
func (l *txLink) IsTransaction() bool {
return true
}
// IsOnMaster checks and returns whether current link is operated on master node.
// Note that, transaction operation is always operated on master node.
func (l *txLink) IsOnMaster() bool {
return true
}

View File

@ -0,0 +1,45 @@
// Copyright GoFrame Author(https://goframe.org). All Rights Reserved.
//
// This Source Code Form is subject to the terms of the MIT License.
// If a copy of the MIT was not distributed with this file,
// You can obtain one at https://github.com/gogf/gf.
//
package database
import (
"context"
"database/sql"
)
type localStatsItem struct {
node *ConfigNode
stats sql.DBStats
}
// Node returns the configuration node info.
func (item *localStatsItem) Node() ConfigNode {
return *item.node
}
// Stats returns the connection stat for current node.
func (item *localStatsItem) Stats() sql.DBStats {
return item.stats
}
// Stats retrieves and returns the pool stat for all nodes that have been established.
func (c *Core) Stats(ctx context.Context) []StatsItem {
var items = make([]StatsItem, 0)
c.links.Iterator(func(k ConfigNode, v *sql.DB) bool {
// Create a local copy of k to avoid loop variable address re-use issue
// In Go, loop variables are re-used with the same memory address across iterations,
// directly using &k would cause all localStatsItem instances to share the same address
node := k
items = append(items, &localStatsItem{
node: &node,
stats: v.Stats(),
})
return true
})
return items
}

View File

@ -0,0 +1,512 @@
// Copyright GoFrame Author(https://goframe.org). All Rights Reserved.
//
// This Source Code Form is subject to the terms of the MIT License.
// If a copy of the MIT was not distributed with this file,
// You can obtain one at https://github.com/gogf/gf.
package database
import (
"context"
"database/sql/driver"
"math/big"
"reflect"
"strings"
"time"
"git.magicany.cc/black1552/gin-base/database/intlog"
"git.magicany.cc/black1552/gin-base/database/json"
"github.com/gogf/gf/v2/encoding/gbinary"
"github.com/gogf/gf/v2/errors/gerror"
"github.com/gogf/gf/v2/os/gtime"
"github.com/gogf/gf/v2/text/gregex"
"github.com/gogf/gf/v2/text/gstr"
"github.com/gogf/gf/v2/util/gconv"
"github.com/gogf/gf/v2/util/gutil"
)
// GetFieldTypeStr retrieves and returns the field type string for certain field by name.
func (c *Core) GetFieldTypeStr(ctx context.Context, fieldName, table, schema string) string {
field := c.GetFieldType(ctx, fieldName, table, schema)
if field != nil {
// Kinds of data type examples:
// year(4)
// datetime
// varchar(64)
// bigint(20)
// int(10) unsigned
typeName := gstr.StrTillEx(field.Type, "(") // int(10) unsigned -> int
if typeName != "" {
typeName = gstr.Trim(typeName)
} else {
typeName = field.Type
}
return typeName
}
return ""
}
// GetFieldType retrieves and returns the field type object for certain field by name.
func (c *Core) GetFieldType(ctx context.Context, fieldName, table, schema string) *TableField {
fieldsMap, err := c.db.TableFields(ctx, table, schema)
if err != nil {
intlog.Errorf(
ctx,
`TableFields failed for table "%s", schema "%s": %+v`,
table, schema, err,
)
return nil
}
for tableFieldName, tableField := range fieldsMap {
if tableFieldName == fieldName {
return tableField
}
}
return nil
}
// ConvertDataForRecord is a very important function, which does converting for any data that
// will be inserted into table/collection as a record.
//
// The parameter `value` should be type of *map/map/*struct/struct.
// It supports embedded struct definition for struct.
func (c *Core) ConvertDataForRecord(ctx context.Context, value any, table string) (map[string]any, error) {
var (
err error
data = MapOrStructToMapDeep(value, true)
)
for fieldName, fieldValue := range data {
var fieldType = c.GetFieldTypeStr(ctx, fieldName, table, c.GetSchema())
data[fieldName], err = c.db.ConvertValueForField(
ctx,
fieldType,
fieldValue,
)
if err != nil {
return nil, gerror.Wrapf(err, `ConvertDataForRecord failed for value: %#v`, fieldValue)
}
}
return data, nil
}
// ConvertValueForField converts value to the type of the record field.
// The parameter `fieldType` is the target record field.
// The parameter `fieldValue` is the value that to be committed to record field.
func (c *Core) ConvertValueForField(ctx context.Context, fieldType string, fieldValue any) (any, error) {
var (
err error
convertedValue = fieldValue
)
switch fieldValue.(type) {
case time.Time, *time.Time, gtime.Time, *gtime.Time:
goto Default
}
// If `value` implements interface `driver.Valuer`, it then uses the interface for value converting.
if valuer, ok := fieldValue.(driver.Valuer); ok {
if convertedValue, err = valuer.Value(); err != nil {
return nil, err
}
return convertedValue, nil
}
Default:
// Default value converting.
var (
rvValue = reflect.ValueOf(fieldValue)
rvKind = rvValue.Kind()
)
for rvKind == reflect.Pointer {
rvValue = rvValue.Elem()
rvKind = rvValue.Kind()
}
switch rvKind {
case reflect.Invalid:
convertedValue = nil
case reflect.Slice, reflect.Array, reflect.Map:
// It should ignore the bytes type.
if _, ok := fieldValue.([]byte); !ok {
// Convert the value to JSON.
convertedValue, err = json.Marshal(fieldValue)
if err != nil {
return nil, err
}
}
case reflect.Struct:
switch r := fieldValue.(type) {
// If the time is zero, it then updates it to nil,
// which will insert/update the value to database as "null".
case time.Time:
if r.IsZero() {
convertedValue = nil
} else {
switch fieldType {
case fieldTypeYear:
convertedValue = r.Format("2006")
case fieldTypeDate:
convertedValue = r.Format("2006-01-02")
case fieldTypeTime:
convertedValue = r.Format("15:04:05")
default:
}
}
case *time.Time:
if r == nil {
// Nothing to do.
} else {
switch fieldType {
case fieldTypeYear:
convertedValue = r.Format("2006")
case fieldTypeDate:
convertedValue = r.Format("2006-01-02")
case fieldTypeTime:
convertedValue = r.Format("15:04:05")
default:
}
}
case gtime.Time:
if r.IsZero() {
convertedValue = nil
} else {
switch fieldType {
case fieldTypeYear:
convertedValue = r.Layout("2006")
case fieldTypeDate:
convertedValue = r.Layout("2006-01-02")
case fieldTypeTime:
convertedValue = r.Layout("15:04:05")
default:
convertedValue = r.Time
}
}
case *gtime.Time:
if r.IsZero() {
convertedValue = nil
} else {
switch fieldType {
case fieldTypeYear:
convertedValue = r.Layout("2006")
case fieldTypeDate:
convertedValue = r.Layout("2006-01-02")
case fieldTypeTime:
convertedValue = r.Layout("15:04:05")
default:
convertedValue = r.Time
}
}
case Counter, *Counter:
// Nothing to do.
default:
// If `value` implements interface iNil,
// check its IsNil() function, if got ture,
// which will insert/update the value to database as "null".
if v, ok := fieldValue.(iNil); ok && v.IsNil() {
convertedValue = nil
} else if s, ok := fieldValue.(iString); ok {
// Use string conversion in default.
convertedValue = s.String()
} else {
// Convert the value to JSON.
convertedValue, err = json.Marshal(fieldValue)
if err != nil {
return nil, err
}
}
}
default:
}
return convertedValue, nil
}
// GetFormattedDBTypeNameForField retrieves and returns the formatted database type name
// eg. `int(10) unsigned` -> `int`, `varchar(100)` -> `varchar`, etc.
func (c *Core) GetFormattedDBTypeNameForField(fieldType string) (typeName, typePattern string) {
match, _ := gregex.MatchString(`(.+?)\((.+)\)`, fieldType)
if len(match) == 3 {
typeName = gstr.Trim(match[1])
typePattern = gstr.Trim(match[2])
} else {
var array = gstr.SplitAndTrim(fieldType, " ")
if len(array) > 1 && gstr.Equal(array[0], "unsigned") {
typeName = array[1]
} else if len(array) > 0 {
typeName = array[0]
}
}
typeName = strings.ToLower(typeName)
return
}
// CheckLocalTypeForField checks and returns corresponding type for given db type.
// The `fieldType` is retrieved from ColumnTypes of db driver, example:
// UNSIGNED INT
func (c *Core) CheckLocalTypeForField(ctx context.Context, fieldType string, _ any) (LocalType, error) {
var (
typeName string
typePattern string
)
typeName, typePattern = c.GetFormattedDBTypeNameForField(fieldType)
switch typeName {
case
fieldTypeBinary,
fieldTypeVarbinary,
fieldTypeBlob,
fieldTypeTinyblob,
fieldTypeMediumblob,
fieldTypeLongblob:
return LocalTypeBytes, nil
case
fieldTypeInt,
fieldTypeTinyint,
fieldTypeSmallInt,
fieldTypeSmallint,
fieldTypeMediumInt,
fieldTypeMediumint,
fieldTypeSerial:
if gstr.ContainsI(fieldType, "unsigned") {
return LocalTypeUint, nil
}
return LocalTypeInt, nil
case
fieldTypeBigInt,
fieldTypeBigint,
fieldTypeBigserial:
if gstr.ContainsI(fieldType, "unsigned") {
return LocalTypeUint64, nil
}
return LocalTypeInt64, nil
case
fieldTypeInt128,
fieldTypeInt256,
fieldTypeUint128,
fieldTypeUint256:
return LocalTypeBigInt, nil
case
fieldTypeReal:
return LocalTypeFloat32, nil
case
fieldTypeDecimal,
fieldTypeMoney,
fieldTypeNumeric,
fieldTypeSmallmoney:
return LocalTypeString, nil
case
fieldTypeFloat,
fieldTypeDouble:
return LocalTypeFloat64, nil
case
fieldTypeBit:
// It is suggested using bit(1) as boolean.
if typePattern == "1" {
return LocalTypeBool, nil
}
if gstr.ContainsI(fieldType, "unsigned") {
return LocalTypeUint64Bytes, nil
}
return LocalTypeInt64Bytes, nil
case
fieldTypeBool:
return LocalTypeBool, nil
case
fieldTypeDate:
return LocalTypeDate, nil
case
fieldTypeTime:
return LocalTypeTime, nil
case
fieldTypeDatetime,
fieldTypeTimestamp,
fieldTypeTimestampz:
return LocalTypeDatetime, nil
case
fieldTypeJson:
return LocalTypeJson, nil
case
fieldTypeJsonb:
return LocalTypeJsonb, nil
default:
// Auto-detect field type, using key match.
switch {
case strings.Contains(typeName, "text") || strings.Contains(typeName, "char") || strings.Contains(typeName, "character"):
return LocalTypeString, nil
case strings.Contains(typeName, "float") || strings.Contains(typeName, "double") || strings.Contains(typeName, "numeric"):
return LocalTypeFloat64, nil
case strings.Contains(typeName, "bool"):
return LocalTypeBool, nil
case strings.Contains(typeName, "binary") || strings.Contains(typeName, "blob"):
return LocalTypeBytes, nil
case strings.Contains(typeName, "int"):
if gstr.ContainsI(fieldType, "unsigned") {
return LocalTypeUint, nil
}
return LocalTypeInt, nil
case strings.Contains(typeName, "time"):
return LocalTypeDatetime, nil
case strings.Contains(typeName, "date"):
return LocalTypeDatetime, nil
default:
return LocalTypeString, nil
}
}
}
// ConvertValueForLocal converts value to local Golang type of value according field type name from database.
// The parameter `fieldType` is in lower case, like:
// `float(5,2)`, `unsigned double(5,2)`, `decimal(10,2)`, `char(45)`, `varchar(100)`, etc.
func (c *Core) ConvertValueForLocal(
ctx context.Context, fieldType string, fieldValue any,
) (any, error) {
// If there's no type retrieved, it returns the `fieldValue` directly
// to use its original data type, as `fieldValue` is type of any.
if fieldType == "" {
return fieldValue, nil
}
typeName, err := c.db.CheckLocalTypeForField(ctx, fieldType, fieldValue)
if err != nil {
return nil, err
}
switch typeName {
case LocalTypeBytes:
var typeNameStr = string(typeName)
if strings.Contains(typeNameStr, "binary") || strings.Contains(typeNameStr, "blob") {
return fieldValue, nil
}
return gconv.Bytes(fieldValue), nil
case LocalTypeInt:
return gconv.Int(gconv.String(fieldValue)), nil
case LocalTypeUint:
return gconv.Uint(gconv.String(fieldValue)), nil
case LocalTypeInt64:
return gconv.Int64(gconv.String(fieldValue)), nil
case LocalTypeUint64:
return gconv.Uint64(gconv.String(fieldValue)), nil
case LocalTypeInt64Bytes:
return gbinary.BeDecodeToInt64(gconv.Bytes(fieldValue)), nil
case LocalTypeUint64Bytes:
return gbinary.BeDecodeToUint64(gconv.Bytes(fieldValue)), nil
case LocalTypeBigInt:
switch v := fieldValue.(type) {
case big.Int:
return v.String(), nil
case *big.Int:
return v.String(), nil
default:
return gconv.String(fieldValue), nil
}
case LocalTypeFloat32:
return gconv.Float32(gconv.String(fieldValue)), nil
case LocalTypeFloat64:
return gconv.Float64(gconv.String(fieldValue)), nil
case LocalTypeBool:
s := gconv.String(fieldValue)
// mssql is true|false string.
if strings.EqualFold(s, "true") {
return 1, nil
}
if strings.EqualFold(s, "false") {
return 0, nil
}
return gconv.Bool(fieldValue), nil
case LocalTypeDate:
if t, ok := fieldValue.(time.Time); ok {
return gtime.NewFromTime(t).Format("Y-m-d"), nil
}
t, _ := gtime.StrToTime(gconv.String(fieldValue))
return t.Format("Y-m-d"), nil
case LocalTypeTime:
if t, ok := fieldValue.(time.Time); ok {
return gtime.NewFromTime(t).Format("H:i:s"), nil
}
t, _ := gtime.StrToTime(gconv.String(fieldValue))
return t.Format("H:i:s"), nil
case LocalTypeDatetime:
if t, ok := fieldValue.(time.Time); ok {
return gtime.NewFromTime(t), nil
}
t, _ := gtime.StrToTime(gconv.String(fieldValue))
return t, nil
default:
return gconv.String(fieldValue), nil
}
}
// mappingAndFilterData automatically mappings the map key to table field and removes
// all key-value pairs that are not the field of given table.
func (c *Core) mappingAndFilterData(ctx context.Context, schema, table string, data map[string]any, filter bool) (map[string]any, error) {
fieldsMap, err := c.db.TableFields(ctx, c.guessPrimaryTableName(table), schema)
if err != nil {
return nil, err
}
if len(fieldsMap) == 0 {
return nil, gerror.Newf(`The table %s may not exist, or the table contains no fields`, table)
}
fieldsKeyMap := make(map[string]any, len(fieldsMap))
for k := range fieldsMap {
fieldsKeyMap[k] = nil
}
// Automatic data key to table field name mapping.
var foundKey string
for dataKey, dataValue := range data {
if _, ok := fieldsKeyMap[dataKey]; !ok {
foundKey, _ = gutil.MapPossibleItemByKey(fieldsKeyMap, dataKey)
if foundKey != "" {
if _, ok = data[foundKey]; !ok {
data[foundKey] = dataValue
}
delete(data, dataKey)
}
}
}
// Data filtering.
// It deletes all key-value pairs that has incorrect field name.
if filter {
for dataKey := range data {
if _, ok := fieldsMap[dataKey]; !ok {
delete(data, dataKey)
}
}
if len(data) == 0 {
return nil, gerror.Newf(`input data match no fields in table %s`, table)
}
}
return data, nil
}

View File

@ -0,0 +1,84 @@
// Copyright GoFrame Author(https://goframe.org). All Rights Reserved.
//
// This Source Code Form is subject to the terms of the MIT License.
// If a copy of the MIT was not distributed with this file,
// You can obtain one at https://github.com/gogf/gf.
//
package database
import (
"context"
"fmt"
"go.opentelemetry.io/otel/attribute"
"go.opentelemetry.io/otel/codes"
semconv "go.opentelemetry.io/otel/semconv/v1.18.0"
"go.opentelemetry.io/otel/trace"
"github.com/gogf/gf/v2/net/gtrace"
)
const (
traceInstrumentName = "github.com/gogf/gf/v2/database/gdb"
traceAttrDbType = "db.type"
traceAttrDbHost = "db.host"
traceAttrDbPort = "db.port"
traceAttrDbName = "db.name"
traceAttrDbUser = "db.user"
traceAttrDbLink = "db.link"
traceAttrDbGroup = "db.group"
traceEventDbExecution = "db.execution"
traceEventDbExecutionCost = "db.execution.cost"
traceEventDbExecutionRows = "db.execution.rows"
traceEventDbExecutionTxID = "db.execution.txid"
traceEventDbExecutionType = "db.execution.type"
)
// addSqlToTracing adds sql information to tracer if it's enabled.
func (c *Core) traceSpanEnd(ctx context.Context, span trace.Span, sql *Sql) {
if gtrace.IsUsingDefaultProvider() || !gtrace.IsTracingInternal() {
return
}
if sql.Error != nil {
span.SetStatus(codes.Error, fmt.Sprintf(`%+v`, sql.Error))
}
labels := make([]attribute.KeyValue, 0)
labels = append(labels, gtrace.CommonLabels()...)
labels = append(labels,
attribute.String(traceAttrDbType, c.db.GetConfig().Type),
semconv.DBStatement(sql.Format),
)
if c.db.GetConfig().Host != "" {
labels = append(labels, attribute.String(traceAttrDbHost, c.db.GetConfig().Host))
}
if c.db.GetConfig().Port != "" {
labels = append(labels, attribute.String(traceAttrDbPort, c.db.GetConfig().Port))
}
if c.db.GetConfig().Name != "" {
labels = append(labels, attribute.String(traceAttrDbName, c.db.GetConfig().Name))
}
if c.db.GetConfig().User != "" {
labels = append(labels, attribute.String(traceAttrDbUser, c.db.GetConfig().User))
}
if filteredLink := c.db.GetCore().FilteredLink(); filteredLink != "" {
labels = append(labels, attribute.String(traceAttrDbLink, c.db.GetCore().FilteredLink()))
}
if group := c.db.GetGroup(); group != "" {
labels = append(labels, attribute.String(traceAttrDbGroup, group))
}
span.SetAttributes(labels...)
events := []attribute.KeyValue{
attribute.String(traceEventDbExecutionCost, fmt.Sprintf(`%d ms`, sql.End-sql.Start)),
attribute.String(traceEventDbExecutionRows, fmt.Sprintf(`%d`, sql.RowsAffected)),
}
if sql.IsTransaction {
if v := ctx.Value(transactionIdForLoggerCtx); v != nil {
events = append(events, attribute.String(
traceEventDbExecutionTxID, fmt.Sprintf(`%d`, v.(uint64)),
))
}
}
events = append(events, attribute.String(traceEventDbExecutionType, string(sql.Type)))
span.AddEvent(traceEventDbExecution, trace.WithAttributes(events...))
}

View File

@ -0,0 +1,295 @@
// Copyright GoFrame Author(https://goframe.org). All Rights Reserved.
//
// This Source Code Form is subject to the terms of the MIT License.
// If a copy of the MIT was not distributed with this file,
// You can obtain one at https://github.com/gogf/gf.
package database
import (
"context"
"database/sql"
"github.com/gogf/gf/v2/container/gtype"
"github.com/gogf/gf/v2/errors/gcode"
"github.com/gogf/gf/v2/errors/gerror"
)
// Propagation defines transaction propagation behavior.
type Propagation string
const (
// PropagationNested starts a nested transaction if already in a transaction,
// or behaves like PropagationRequired if not in a transaction.
//
// It is the default behavior.
PropagationNested Propagation = "NESTED"
// PropagationRequired starts a new transaction if not in a transaction,
// or uses the existing transaction if already in a transaction.
PropagationRequired Propagation = "REQUIRED"
// PropagationSupports executes within the existing transaction if present,
// otherwise executes without transaction.
PropagationSupports Propagation = "SUPPORTS"
// PropagationRequiresNew starts a new transaction, and suspends the current transaction if one exists.
PropagationRequiresNew Propagation = "REQUIRES_NEW"
// PropagationNotSupported executes non-transactional, suspends any existing transaction.
PropagationNotSupported Propagation = "NOT_SUPPORTED"
// PropagationMandatory executes in a transaction, fails if no existing transaction.
PropagationMandatory Propagation = "MANDATORY"
// PropagationNever executes non-transactional, fails if in an existing transaction.
PropagationNever Propagation = "NEVER"
)
// TxOptions defines options for transaction control.
type TxOptions struct {
// Propagation specifies the propagation behavior.
Propagation Propagation
// Isolation is the transaction isolation level.
// If zero, the driver or database's default level is used.
Isolation sql.IsolationLevel
// ReadOnly is used to mark the transaction as read-only.
ReadOnly bool
}
// Context key types for transaction to avoid collisions
type transactionCtxKey string
const (
transactionPointerPrefix = "transaction"
contextTransactionKeyPrefix = "TransactionObjectForGroup_"
transactionIdForLoggerCtx transactionCtxKey = "TransactionId"
)
var transactionIdGenerator = gtype.NewUint64()
// DefaultTxOptions returns the default transaction options.
func DefaultTxOptions() TxOptions {
return TxOptions{
// Note the default propagation type is PropagationNested not PropagationRequired.
Propagation: PropagationNested,
}
}
// Begin starts and returns the transaction object.
// You should call Commit or Rollback functions of the transaction object
// if you no longer use the transaction. Commit or Rollback functions will also
// close the transaction automatically.
func (c *Core) Begin(ctx context.Context) (tx TX, err error) {
return c.BeginWithOptions(ctx, DefaultTxOptions())
}
// BeginWithOptions starts and returns the transaction object with given options.
// The options allow specifying the isolation level and read-only mode.
// You should call Commit or Rollback functions of the transaction object
// if you no longer use the transaction. Commit or Rollback functions will also
// close the transaction automatically.
func (c *Core) BeginWithOptions(ctx context.Context, opts TxOptions) (tx TX, err error) {
if ctx == nil {
ctx = c.db.GetCtx()
}
ctx = c.injectInternalCtxData(ctx)
return c.doBeginCtx(ctx, sql.TxOptions{
Isolation: opts.Isolation,
ReadOnly: opts.ReadOnly,
})
}
func (c *Core) doBeginCtx(ctx context.Context, opts sql.TxOptions) (TX, error) {
master, err := c.db.Master()
if err != nil {
return nil, err
}
var out DoCommitOutput
out, err = c.db.DoCommit(ctx, DoCommitInput{
Db: master,
Sql: "BEGIN",
Type: SqlTypeBegin,
TxOptions: opts,
IsTransaction: true,
})
return out.Tx, err
}
// Transaction wraps the transaction logic using function `f`.
// It rollbacks the transaction and returns the error from function `f` if
// it returns non-nil error. It commits the transaction and returns nil if
// function `f` returns nil.
//
// Note that, you should not Commit or Rollback the transaction in function `f`
// as it is automatically handled by this function.
func (c *Core) Transaction(ctx context.Context, f func(ctx context.Context, tx TX) error) (err error) {
return c.TransactionWithOptions(ctx, DefaultTxOptions(), f)
}
// TransactionWithOptions wraps the transaction logic with propagation options using function `f`.
func (c *Core) TransactionWithOptions(
ctx context.Context, opts TxOptions, f func(ctx context.Context, tx TX) error,
) (err error) {
if ctx == nil {
ctx = c.db.GetCtx()
}
ctx = c.injectInternalCtxData(ctx)
// Check current transaction from context
var (
group = c.db.GetGroup()
currentTx = TXFromCtx(ctx, group)
)
switch opts.Propagation {
case PropagationRequired:
if currentTx != nil {
return f(ctx, currentTx)
}
return c.createNewTransaction(ctx, opts, f)
case PropagationSupports:
if currentTx == nil {
currentTx = c.newEmptyTX()
}
return f(ctx, currentTx)
case PropagationMandatory:
if currentTx == nil {
return gerror.NewCode(
gcode.CodeInvalidOperation,
"transaction propagation MANDATORY requires an existing transaction",
)
}
return f(ctx, currentTx)
case PropagationRequiresNew:
ctx = WithoutTX(ctx, group)
return c.createNewTransaction(ctx, opts, f)
case PropagationNotSupported:
ctx = WithoutTX(ctx, group)
return f(ctx, c.newEmptyTX())
case PropagationNever:
if currentTx != nil {
return gerror.NewCode(
gcode.CodeInvalidOperation,
"transaction propagation NEVER cannot run within an existing transaction",
)
}
ctx = WithoutTX(ctx, group)
return f(ctx, c.newEmptyTX())
case PropagationNested:
if currentTx != nil {
return currentTx.Transaction(ctx, f)
}
return c.createNewTransaction(ctx, opts, f)
default:
return gerror.NewCodef(
gcode.CodeInvalidParameter,
"unsupported propagation behavior: %s",
opts.Propagation,
)
}
}
// createNewTransaction handles creating and managing a new transaction
func (c *Core) createNewTransaction(
ctx context.Context, opts TxOptions, f func(ctx context.Context, tx TX) error,
) (err error) {
// Begin transaction with options
tx, err := c.doBeginCtx(ctx, sql.TxOptions{
Isolation: opts.Isolation,
ReadOnly: opts.ReadOnly,
})
if err != nil {
return err
}
// Inject transaction object into context
ctx = WithTX(tx.GetCtx(), tx)
err = callTxFunc(tx.Ctx(ctx), f)
return
}
func callTxFunc(tx TX, f func(ctx context.Context, tx TX) error) (err error) {
defer func() {
if err == nil {
if exception := recover(); exception != nil {
if v, ok := exception.(error); ok && gerror.HasStack(v) {
err = v
} else {
err = gerror.NewCodef(gcode.CodeInternalPanic, "%+v", exception)
}
}
}
if err != nil {
if e := tx.Rollback(); e != nil {
err = e
}
} else {
if e := tx.Commit(); e != nil {
err = e
}
}
}()
err = f(tx.GetCtx(), tx)
return
}
// WithTX injects given transaction object into context and returns a new context.
func WithTX(ctx context.Context, tx TX) context.Context {
if tx == nil {
return ctx
}
// Check repeat injection from given.
group := tx.GetDB().GetGroup()
if ctxTx := TXFromCtx(ctx, group); ctxTx != nil && ctxTx.GetDB().GetGroup() == group {
return ctx
}
dbCtx := tx.GetDB().GetCtx()
if ctxTx := TXFromCtx(dbCtx, group); ctxTx != nil && ctxTx.GetDB().GetGroup() == group {
return dbCtx
}
// Inject transaction object and id into context.
ctx = context.WithValue(ctx, transactionKeyForContext(group), tx)
ctx = context.WithValue(ctx, transactionIdForLoggerCtx, tx.GetCtx().Value(transactionIdForLoggerCtx))
return ctx
}
// WithoutTX removed transaction object from context and returns a new context.
func WithoutTX(ctx context.Context, group string) context.Context {
ctx = context.WithValue(ctx, transactionKeyForContext(group), nil)
ctx = context.WithValue(ctx, transactionIdForLoggerCtx, nil)
return ctx
}
// TXFromCtx retrieves and returns transaction object from context.
// It is usually used in nested transaction feature, and it returns nil if it is not set previously.
func TXFromCtx(ctx context.Context, group string) TX {
if ctx == nil {
return nil
}
v := ctx.Value(transactionKeyForContext(group))
if v != nil {
tx := v.(TX)
if tx.IsClosed() {
return nil
}
// no underlying sql tx.
if tx.GetSqlTX() == nil {
return nil
}
tx = tx.Ctx(ctx)
return tx
}
return nil
}
// transactionKeyForContext forms and returns a key for storing transaction object of certain database group into context.
func transactionKeyForContext(group string) transactionCtxKey {
return transactionCtxKey(contextTransactionKeyPrefix + group)
}

437
database/gdb_core_txcore.go Normal file
View File

@ -0,0 +1,437 @@
// Copyright GoFrame Author(https://goframe.org). All Rights Reserved.
//
// This Source Code Form is subject to the terms of the MIT License.
// If a copy of the MIT was not distributed with this file,
// You can obtain one at https://github.com/gogf/gf.
package database
import (
"context"
"database/sql"
"reflect"
"git.magicany.cc/black1552/gin-base/database/reflection"
"github.com/gogf/gf/v2/errors/gcode"
"github.com/gogf/gf/v2/errors/gerror"
"github.com/gogf/gf/v2/text/gregex"
"github.com/gogf/gf/v2/util/gconv"
)
// TXCore is the struct for transaction management.
type TXCore struct {
// db is the database management interface that implements the DB interface,
// providing access to database operations and configuration.
db DB
// tx is the underlying SQL transaction object from database/sql package,
// which manages the actual transaction operations.
tx *sql.Tx
// ctx is the context specific to this transaction,
// which can be used for timeout control and cancellation.
ctx context.Context
// master is the underlying master database connection pool,
// used for direct database operations when needed.
master *sql.DB
// transactionId is a unique identifier for this transaction instance,
// used for tracking and debugging purposes.
transactionId string
// transactionCount tracks the number of nested transaction begins,
// used for managing transaction nesting depth.
transactionCount int
// isClosed indicates whether this transaction has been finalized
// through either a commit or rollback operation.
isClosed bool
// cancelFunc is the context cancellation function associated with ctx,
// used to cancel the transaction context when needed.
cancelFunc context.CancelFunc
}
func (c *Core) newEmptyTX() TX {
return &TXCore{
db: c.db,
}
}
// transactionKeyForNestedPoint forms and returns the transaction key at current save point.
func (tx *TXCore) transactionKeyForNestedPoint() string {
return tx.db.GetCore().QuoteWord(
transactionPointerPrefix + gconv.String(tx.transactionCount),
)
}
// Ctx sets the context for current transaction.
func (tx *TXCore) Ctx(ctx context.Context) TX {
tx.ctx = ctx
if tx.ctx != nil {
tx.ctx = tx.db.GetCore().injectInternalCtxData(tx.ctx)
}
return tx
}
// GetCtx returns the context for current transaction.
func (tx *TXCore) GetCtx() context.Context {
return tx.ctx
}
// GetDB returns the DB for current transaction.
func (tx *TXCore) GetDB() DB {
return tx.db
}
// GetSqlTX returns the underlying transaction object for current transaction.
func (tx *TXCore) GetSqlTX() *sql.Tx {
return tx.tx
}
// Commit commits current transaction.
// Note that it releases previous saved transaction point if it's in a nested transaction procedure,
// or else it commits the hole transaction.
func (tx *TXCore) Commit() error {
if tx.transactionCount > 0 {
tx.transactionCount--
_, err := tx.Exec("RELEASE SAVEPOINT " + tx.transactionKeyForNestedPoint())
return err
}
_, err := tx.db.DoCommit(tx.ctx, DoCommitInput{
Tx: tx.tx,
Sql: "COMMIT",
Type: SqlTypeTXCommit,
TxCancelFunc: tx.cancelFunc,
IsTransaction: true,
})
if err == nil {
tx.isClosed = true
}
return err
}
// Rollback aborts current transaction.
// Note that it aborts current transaction if it's in a nested transaction procedure,
// or else it aborts the hole transaction.
func (tx *TXCore) Rollback() error {
if tx.transactionCount > 0 {
tx.transactionCount--
_, err := tx.Exec("ROLLBACK TO SAVEPOINT " + tx.transactionKeyForNestedPoint())
return err
}
_, err := tx.db.DoCommit(tx.ctx, DoCommitInput{
Tx: tx.tx,
Sql: "ROLLBACK",
Type: SqlTypeTXRollback,
TxCancelFunc: tx.cancelFunc,
IsTransaction: true,
})
if err == nil {
tx.isClosed = true
}
return err
}
// IsClosed checks and returns this transaction has already been committed or rolled back.
func (tx *TXCore) IsClosed() bool {
return tx.isClosed
}
// Begin starts a nested transaction procedure.
func (tx *TXCore) Begin() error {
_, err := tx.Exec("SAVEPOINT " + tx.transactionKeyForNestedPoint())
if err != nil {
return err
}
tx.transactionCount++
return nil
}
// SavePoint performs `SAVEPOINT xxx` SQL statement that saves transaction at current point.
// The parameter `point` specifies the point name that will be saved to server.
func (tx *TXCore) SavePoint(point string) error {
_, err := tx.Exec("SAVEPOINT " + tx.db.GetCore().QuoteWord(point))
return err
}
// RollbackTo performs `ROLLBACK TO SAVEPOINT xxx` SQL statement that rollbacks to specified saved transaction.
// The parameter `point` specifies the point name that was saved previously.
func (tx *TXCore) RollbackTo(point string) error {
_, err := tx.Exec("ROLLBACK TO SAVEPOINT " + tx.db.GetCore().QuoteWord(point))
return err
}
// Transaction wraps the transaction logic using function `f`.
// It rollbacks the transaction and returns the error from function `f` if
// it returns non-nil error. It commits the transaction and returns nil if
// function `f` returns nil.
//
// Note that, you should not Commit or Rollback the transaction in function `f`
// as it is automatically handled by this function.
func (tx *TXCore) Transaction(ctx context.Context, f func(ctx context.Context, tx TX) error) (err error) {
if ctx != nil {
tx.ctx = ctx
}
// Check transaction object from context.
if TXFromCtx(tx.ctx, tx.db.GetGroup()) == nil {
// Inject transaction object into context.
tx.ctx = WithTX(tx.ctx, tx)
}
if err = tx.Begin(); err != nil {
return err
}
err = callTxFunc(tx, f)
return
}
// TransactionWithOptions wraps the transaction logic with propagation options using function `f`.
func (tx *TXCore) TransactionWithOptions(
ctx context.Context, opts TxOptions, f func(ctx context.Context, tx TX) error,
) (err error) {
return tx.db.TransactionWithOptions(ctx, opts, f)
}
// Query does query operation on transaction.
// See Core.Query.
func (tx *TXCore) Query(sql string, args ...any) (result Result, err error) {
return tx.db.DoQuery(tx.ctx, &txLink{tx.tx}, sql, args...)
}
// Exec does none query operation on transaction.
// See Core.Exec.
func (tx *TXCore) Exec(sql string, args ...any) (sql.Result, error) {
return tx.db.DoExec(tx.ctx, &txLink{tx.tx}, sql, args...)
}
// Prepare creates a prepared statement for later queries or executions.
// Multiple queries or executions may be run concurrently from the
// returned statement.
// The caller must call the statement's Close method
// when the statement is no longer needed.
func (tx *TXCore) Prepare(sql string) (*Stmt, error) {
return tx.db.DoPrepare(tx.ctx, &txLink{tx.tx}, sql)
}
// GetAll queries and returns data records from database.
func (tx *TXCore) GetAll(sql string, args ...any) (Result, error) {
return tx.Query(sql, args...)
}
// GetOne queries and returns one record from database.
func (tx *TXCore) GetOne(sql string, args ...any) (Record, error) {
list, err := tx.GetAll(sql, args...)
if err != nil {
return nil, err
}
if len(list) > 0 {
return list[0], nil
}
return nil, nil
}
// GetStruct queries one record from database and converts it to given struct.
// The parameter `pointer` should be a pointer to struct.
func (tx *TXCore) GetStruct(obj any, sql string, args ...any) error {
one, err := tx.GetOne(sql, args...)
if err != nil {
return err
}
return one.Struct(obj)
}
// GetStructs queries records from database and converts them to given struct.
// The parameter `pointer` should be type of struct slice: []struct/[]*struct.
func (tx *TXCore) GetStructs(objPointerSlice any, sql string, args ...any) error {
all, err := tx.GetAll(sql, args...)
if err != nil {
return err
}
return all.Structs(objPointerSlice)
}
// GetScan queries one or more records from database and converts them to given struct or
// struct array.
//
// If parameter `pointer` is type of struct pointer, it calls GetStruct internally for
// the conversion. If parameter `pointer` is type of slice, it calls GetStructs internally
// for conversion.
func (tx *TXCore) GetScan(pointer any, sql string, args ...any) error {
reflectInfo := reflection.OriginTypeAndKind(pointer)
if reflectInfo.InputKind != reflect.Pointer {
return gerror.NewCodef(
gcode.CodeInvalidParameter,
"params should be type of pointer, but got: %v",
reflectInfo.InputKind,
)
}
switch reflectInfo.OriginKind {
case reflect.Array, reflect.Slice:
return tx.GetStructs(pointer, sql, args...)
case reflect.Struct:
return tx.GetStruct(pointer, sql, args...)
default:
}
return gerror.NewCodef(
gcode.CodeInvalidParameter,
`in valid parameter type "%v", of which element type should be type of struct/slice`,
reflectInfo.InputType,
)
}
// GetValue queries and returns the field value from database.
// The sql should query only one field from database, or else it returns only one
// field of the result.
func (tx *TXCore) GetValue(sql string, args ...any) (Value, error) {
one, err := tx.GetOne(sql, args...)
if err != nil {
return nil, err
}
for _, v := range one {
return v, nil
}
return nil, nil
}
// GetCount queries and returns the count from database.
func (tx *TXCore) GetCount(sql string, args ...any) (int64, error) {
if !gregex.IsMatchString(`(?i)SELECT\s+COUNT\(.+\)\s+FROM`, sql) {
sql, _ = gregex.ReplaceString(`(?i)(SELECT)\s+(.+)\s+(FROM)`, `$1 COUNT($2) $3`, sql)
}
value, err := tx.GetValue(sql, args...)
if err != nil {
return 0, err
}
return value.Int64(), nil
}
// Insert does "INSERT INTO ..." statement for the table.
// If there's already one unique record of the data in the table, it returns error.
//
// The parameter `data` can be type of map/gmap/struct/*struct/[]map/[]struct, etc.
// Eg:
// Data(g.Map{"uid": 10000, "name":"john"})
// Data(g.Slice{g.Map{"uid": 10000, "name":"john"}, g.Map{"uid": 20000, "name":"smith"})
//
// The parameter `batch` specifies the batch operation count when given data is slice.
func (tx *TXCore) Insert(table string, data any, batch ...int) (sql.Result, error) {
if len(batch) > 0 {
return tx.Model(table).Ctx(tx.ctx).Data(data).Batch(batch[0]).Insert()
}
return tx.Model(table).Ctx(tx.ctx).Data(data).Insert()
}
// InsertIgnore does "INSERT IGNORE INTO ..." statement for the table.
// If there's already one unique record of the data in the table, it ignores the inserting.
//
// The parameter `data` can be type of map/gmap/struct/*struct/[]map/[]struct, etc.
// Eg:
// Data(g.Map{"uid": 10000, "name":"john"})
// Data(g.Slice{g.Map{"uid": 10000, "name":"john"}, g.Map{"uid": 20000, "name":"smith"})
//
// The parameter `batch` specifies the batch operation count when given data is slice.
func (tx *TXCore) InsertIgnore(table string, data any, batch ...int) (sql.Result, error) {
if len(batch) > 0 {
return tx.Model(table).Ctx(tx.ctx).Data(data).Batch(batch[0]).InsertIgnore()
}
return tx.Model(table).Ctx(tx.ctx).Data(data).InsertIgnore()
}
// InsertAndGetId performs action Insert and returns the last insert id that automatically generated.
func (tx *TXCore) InsertAndGetId(table string, data any, batch ...int) (int64, error) {
if len(batch) > 0 {
return tx.Model(table).Ctx(tx.ctx).Data(data).Batch(batch[0]).InsertAndGetId()
}
return tx.Model(table).Ctx(tx.ctx).Data(data).InsertAndGetId()
}
// Replace does "REPLACE INTO ..." statement for the table.
// If there's already one unique record of the data in the table, it deletes the record
// and inserts a new one.
//
// The parameter `data` can be type of map/gmap/struct/*struct/[]map/[]struct, etc.
// Eg:
// Data(g.Map{"uid": 10000, "name":"john"})
// Data(g.Slice{g.Map{"uid": 10000, "name":"john"}, g.Map{"uid": 20000, "name":"smith"})
//
// The parameter `data` can be type of map/gmap/struct/*struct/[]map/[]struct, etc.
// If given data is type of slice, it then does batch replacing, and the optional parameter
// `batch` specifies the batch operation count.
func (tx *TXCore) Replace(table string, data any, batch ...int) (sql.Result, error) {
if len(batch) > 0 {
return tx.Model(table).Ctx(tx.ctx).Data(data).Batch(batch[0]).Replace()
}
return tx.Model(table).Ctx(tx.ctx).Data(data).Replace()
}
// Save does "INSERT INTO ... ON DUPLICATE KEY UPDATE..." statement for the table.
// It updates the record if there's primary or unique index in the saving data,
// or else it inserts a new record into the table.
//
// The parameter `data` can be type of map/gmap/struct/*struct/[]map/[]struct, etc.
// Eg:
// Data(g.Map{"uid": 10000, "name":"john"})
// Data(g.Slice{g.Map{"uid": 10000, "name":"john"}, g.Map{"uid": 20000, "name":"smith"})
//
// If given data is type of slice, it then does batch saving, and the optional parameter
// `batch` specifies the batch operation count.
func (tx *TXCore) Save(table string, data any, batch ...int) (sql.Result, error) {
if len(batch) > 0 {
return tx.Model(table).Ctx(tx.ctx).Data(data).Batch(batch[0]).Save()
}
return tx.Model(table).Ctx(tx.ctx).Data(data).Save()
}
// Update does "UPDATE ... " statement for the table.
//
// The parameter `data` can be type of string/map/gmap/struct/*struct, etc.
// Eg: "uid=10000", "uid", 10000, g.Map{"uid": 10000, "name":"john"}
//
// The parameter `condition` can be type of string/map/gmap/slice/struct/*struct, etc.
// It is commonly used with parameter `args`.
// Eg:
// "uid=10000",
// "uid", 10000
// "money>? AND name like ?", 99999, "vip_%"
// "status IN (?)", g.Slice{1,2,3}
// "age IN(?,?)", 18, 50
// User{ Id : 1, UserName : "john"}.
func (tx *TXCore) Update(table string, data any, condition any, args ...any) (sql.Result, error) {
return tx.Model(table).Ctx(tx.ctx).Data(data).Where(condition, args...).Update()
}
// Delete does "DELETE FROM ... " statement for the table.
//
// The parameter `condition` can be type of string/map/gmap/slice/struct/*struct, etc.
// It is commonly used with parameter `args`.
// Eg:
// "uid=10000",
// "uid", 10000
// "money>? AND name like ?", 99999, "vip_%"
// "status IN (?)", g.Slice{1,2,3}
// "age IN(?,?)", 18, 50
// User{ Id : 1, UserName : "john"}.
func (tx *TXCore) Delete(table string, condition any, args ...any) (sql.Result, error) {
return tx.Model(table).Ctx(tx.ctx).Where(condition, args...).Delete()
}
// QueryContext implements interface function Link.QueryContext.
func (tx *TXCore) QueryContext(ctx context.Context, sql string, args ...any) (*sql.Rows, error) {
return tx.tx.QueryContext(ctx, sql, args...)
}
// ExecContext implements interface function Link.ExecContext.
func (tx *TXCore) ExecContext(ctx context.Context, sql string, args ...any) (sql.Result, error) {
return tx.tx.ExecContext(ctx, sql, args...)
}
// PrepareContext implements interface function Link.PrepareContext.
func (tx *TXCore) PrepareContext(ctx context.Context, sql string) (*sql.Stmt, error) {
return tx.tx.PrepareContext(ctx, sql)
}
// IsOnMaster implements interface function Link.IsOnMaster.
func (tx *TXCore) IsOnMaster() bool {
return true
}
// IsTransaction implements interface function Link.IsTransaction.
func (tx *TXCore) IsTransaction() bool {
return tx != nil
}

View File

@ -0,0 +1,533 @@
// Copyright GoFrame Author(https://goframe.org). All Rights Reserved.
//
// This Source Code Form is subject to the terms of the MIT License.
// If a copy of the MIT was not distributed with this file,
// You can obtain one at https://github.com/gogf/gf.
//
package database
import (
"context"
"database/sql"
"fmt"
"reflect"
"git.magicany.cc/black1552/gin-base/database/intlog"
"go.opentelemetry.io/otel"
"go.opentelemetry.io/otel/trace"
"github.com/gogf/gf/v2"
"github.com/gogf/gf/v2/container/gvar"
"github.com/gogf/gf/v2/errors/gcode"
"github.com/gogf/gf/v2/errors/gerror"
"github.com/gogf/gf/v2/os/gtime"
"github.com/gogf/gf/v2/util/gconv"
"github.com/gogf/gf/v2/util/guid"
)
// Query commits one query SQL to underlying driver and returns the execution result.
// It is most commonly used for data querying.
func (c *Core) Query(ctx context.Context, sql string, args ...any) (result Result, err error) {
return c.db.DoQuery(ctx, nil, sql, args...)
}
// DoQuery commits the sql string and its arguments to underlying driver
// through given link object and returns the execution result.
func (c *Core) DoQuery(ctx context.Context, link Link, sql string, args ...any) (result Result, err error) {
// Transaction checks.
if link == nil {
if tx := TXFromCtx(ctx, c.db.GetGroup()); tx != nil {
// Firstly, check and retrieve transaction link from context.
link = &txLink{tx.GetSqlTX()}
} else if link, err = c.SlaveLink(); err != nil {
// Or else it creates one from master node.
return nil, err
}
} else if !link.IsTransaction() {
// If current link is not transaction link, it checks and retrieves transaction from context.
if tx := TXFromCtx(ctx, c.db.GetGroup()); tx != nil {
link = &txLink{tx.GetSqlTX()}
}
}
// Sql filtering.
sql, args = c.FormatSqlBeforeExecuting(sql, args)
sql, args, err = c.db.DoFilter(ctx, link, sql, args)
if err != nil {
return nil, err
}
// SQL format and retrieve.
if v := ctx.Value(ctxKeyCatchSQL); v != nil {
var (
manager = v.(*CatchSQLManager)
formattedSql = FormatSqlWithArgs(sql, args)
)
manager.SQLArray.Append(formattedSql)
if !manager.DoCommit && ctx.Value(ctxKeyInternalProducedSQL) == nil {
return nil, nil
}
}
// Link execution.
var out DoCommitOutput
out, err = c.db.DoCommit(ctx, DoCommitInput{
Link: link,
Sql: sql,
Args: args,
Stmt: nil,
Type: SqlTypeQueryContext,
IsTransaction: link.IsTransaction(),
})
if err != nil {
return nil, err
}
return out.Records, err
}
// Exec commits one query SQL to underlying driver and returns the execution result.
// It is most commonly used for data inserting and updating.
func (c *Core) Exec(ctx context.Context, sql string, args ...any) (result sql.Result, err error) {
return c.db.DoExec(ctx, nil, sql, args...)
}
// DoExec commits the sql string and its arguments to underlying driver
// through given link object and returns the execution result.
func (c *Core) DoExec(ctx context.Context, link Link, sql string, args ...any) (result sql.Result, err error) {
// Transaction checks.
if link == nil {
if tx := TXFromCtx(ctx, c.db.GetGroup()); tx != nil {
// Firstly, check and retrieve transaction link from context.
link = &txLink{tx.GetSqlTX()}
} else if link, err = c.MasterLink(); err != nil {
// Or else it creates one from master node.
return nil, err
}
} else if !link.IsTransaction() {
// If current link is not transaction link, it tries retrieving transaction object from context.
if tx := TXFromCtx(ctx, c.db.GetGroup()); tx != nil {
link = &txLink{tx.GetSqlTX()}
}
}
// SQL filtering.
sql, args = c.FormatSqlBeforeExecuting(sql, args)
sql, args, err = c.db.DoFilter(ctx, link, sql, args)
if err != nil {
return nil, err
}
// SQL format and retrieve.
if v := ctx.Value(ctxKeyCatchSQL); v != nil {
var (
manager = v.(*CatchSQLManager)
formattedSql = FormatSqlWithArgs(sql, args)
)
manager.SQLArray.Append(formattedSql)
if !manager.DoCommit && ctx.Value(ctxKeyInternalProducedSQL) == nil {
return new(SqlResult), nil
}
}
// Link execution.
var out DoCommitOutput
out, err = c.db.DoCommit(ctx, DoCommitInput{
Link: link,
Sql: sql,
Args: args,
Stmt: nil,
Type: SqlTypeExecContext,
IsTransaction: link.IsTransaction(),
})
if err != nil {
return nil, err
}
return out.Result, err
}
// DoFilter is a hook function, which filters the sql and its arguments before it's committed to underlying driver.
// The parameter `link` specifies the current database connection operation object. You can modify the sql
// string `sql` and its arguments `args` as you wish before they're committed to driver.
func (c *Core) DoFilter(
ctx context.Context, link Link, sql string, args []any,
) (newSql string, newArgs []any, err error) {
return sql, args, nil
}
// DoCommit commits current sql and arguments to underlying sql driver.
func (c *Core) DoCommit(ctx context.Context, in DoCommitInput) (out DoCommitOutput, err error) {
var (
sqlTx *sql.Tx
sqlStmt *sql.Stmt
sqlRows *sql.Rows
sqlResult sql.Result
stmtSqlRows *sql.Rows
stmtSqlRow *sql.Row
rowsAffected int64
cancelFuncForTimeout context.CancelFunc
formattedSql = FormatSqlWithArgs(in.Sql, in.Args)
timestampMilli1 = gtime.TimestampMilli()
)
// Panic recovery to handle panics from underlying database drivers
defer func() {
if exception := recover(); exception != nil {
if err == nil {
if v, ok := exception.(error); ok && gerror.HasStack(v) {
err = v
} else {
err = gerror.WrapCodef(gcode.CodeDbOperationError, gerror.NewCodef(gcode.CodeInternalPanic, "%+v", exception), FormatSqlWithArgs(in.Sql, in.Args))
}
}
}
}()
// Trace span start.
tr := otel.GetTracerProvider().Tracer(traceInstrumentName, trace.WithInstrumentationVersion(gf.VERSION))
ctx, span := tr.Start(ctx, string(in.Type), trace.WithSpanKind(trace.SpanKindClient))
defer span.End()
// Execution by type.
switch in.Type {
case SqlTypeBegin:
ctx, cancelFuncForTimeout = c.GetCtxTimeout(ctx, ctxTimeoutTypeTrans)
formattedSql = fmt.Sprintf(
`%s (IosolationLevel: %s, ReadOnly: %t)`,
formattedSql, in.TxOptions.Isolation.String(), in.TxOptions.ReadOnly,
)
if sqlTx, err = in.Db.BeginTx(ctx, &in.TxOptions); err == nil {
tx := &TXCore{
db: c.db,
tx: sqlTx,
ctx: ctx,
master: in.Db,
transactionId: guid.S(),
cancelFunc: cancelFuncForTimeout,
}
tx.ctx = context.WithValue(ctx, transactionKeyForContext(tx.db.GetGroup()), tx)
tx.ctx = context.WithValue(tx.ctx, transactionIdForLoggerCtx, transactionIdGenerator.Add(1))
out.Tx = tx
ctx = out.Tx.GetCtx()
}
out.RawResult = sqlTx
case SqlTypeTXCommit:
if in.TxCancelFunc != nil {
defer in.TxCancelFunc()
}
err = in.Tx.Commit()
case SqlTypeTXRollback:
if in.TxCancelFunc != nil {
defer in.TxCancelFunc()
}
err = in.Tx.Rollback()
case SqlTypeExecContext:
ctx, cancelFuncForTimeout = c.GetCtxTimeout(ctx, ctxTimeoutTypeExec)
defer cancelFuncForTimeout()
if c.db.GetDryRun() {
sqlResult = new(SqlResult)
} else {
sqlResult, err = in.Link.ExecContext(ctx, in.Sql, in.Args...)
}
out.RawResult = sqlResult
case SqlTypeQueryContext:
ctx, cancelFuncForTimeout = c.GetCtxTimeout(ctx, ctxTimeoutTypeQuery)
defer cancelFuncForTimeout()
sqlRows, err = in.Link.QueryContext(ctx, in.Sql, in.Args...)
out.RawResult = sqlRows
case SqlTypePrepareContext:
ctx, cancelFuncForTimeout = c.GetCtxTimeout(ctx, ctxTimeoutTypePrepare)
defer cancelFuncForTimeout()
sqlStmt, err = in.Link.PrepareContext(ctx, in.Sql)
out.RawResult = sqlStmt
case SqlTypeStmtExecContext:
ctx, cancelFuncForTimeout = c.GetCtxTimeout(ctx, ctxTimeoutTypeExec)
defer cancelFuncForTimeout()
if c.db.GetDryRun() {
sqlResult = new(SqlResult)
} else {
sqlResult, err = in.Stmt.ExecContext(ctx, in.Args...)
}
out.RawResult = sqlResult
case SqlTypeStmtQueryContext:
ctx, cancelFuncForTimeout = c.GetCtxTimeout(ctx, ctxTimeoutTypeQuery)
defer cancelFuncForTimeout()
stmtSqlRows, err = in.Stmt.QueryContext(ctx, in.Args...)
out.RawResult = stmtSqlRows
case SqlTypeStmtQueryRowContext:
ctx, cancelFuncForTimeout = c.GetCtxTimeout(ctx, ctxTimeoutTypeQuery)
defer cancelFuncForTimeout()
stmtSqlRow = in.Stmt.QueryRowContext(ctx, in.Args...)
out.RawResult = stmtSqlRow
default:
panic(gerror.NewCodef(gcode.CodeInvalidParameter, `invalid SqlType "%s"`, in.Type))
}
// Result handling.
switch {
case sqlResult != nil && !c.GetIgnoreResultFromCtx(ctx):
rowsAffected, err = sqlResult.RowsAffected()
out.Result = sqlResult
case sqlRows != nil:
out.Records, err = c.RowsToResult(ctx, sqlRows)
rowsAffected = int64(len(out.Records))
case sqlStmt != nil:
out.Stmt = &Stmt{
Stmt: sqlStmt,
core: c,
link: in.Link,
sql: in.Sql,
}
}
var (
timestampMilli2 = gtime.TimestampMilli()
sqlObj = &Sql{
Sql: in.Sql,
Type: in.Type,
Args: in.Args,
Format: formattedSql,
Error: err,
Start: timestampMilli1,
End: timestampMilli2,
Group: c.db.GetGroup(),
Schema: c.db.GetSchema(),
RowsAffected: rowsAffected,
IsTransaction: in.IsTransaction,
}
)
// Tracing.
c.traceSpanEnd(ctx, span, sqlObj)
// Logging.
if c.db.GetDebug() {
c.writeSqlToLogger(ctx, sqlObj)
}
if err != nil && err != sql.ErrNoRows {
err = gerror.WrapCode(
gcode.CodeDbOperationError,
err,
FormatSqlWithArgs(in.Sql, in.Args),
)
}
return out, err
}
// Prepare creates a prepared statement for later queries or executions.
// Multiple queries or executions may be run concurrently from the
// returned statement.
// The caller must call the statement's Close method
// when the statement is no longer needed.
//
// The parameter `execOnMaster` specifies whether executing the sql on master node,
// or else it executes the sql on slave node if master-slave configured.
func (c *Core) Prepare(ctx context.Context, sql string, execOnMaster ...bool) (*Stmt, error) {
var (
err error
link Link
)
if len(execOnMaster) > 0 && execOnMaster[0] {
if link, err = c.MasterLink(); err != nil {
return nil, err
}
} else {
if link, err = c.SlaveLink(); err != nil {
return nil, err
}
}
return c.db.DoPrepare(ctx, link, sql)
}
// DoPrepare calls prepare function on given link object and returns the statement object.
func (c *Core) DoPrepare(ctx context.Context, link Link, sql string) (stmt *Stmt, err error) {
// Transaction checks.
if link == nil {
if tx := TXFromCtx(ctx, c.db.GetGroup()); tx != nil {
// Firstly, check and retrieve transaction link from context.
link = &txLink{tx.GetSqlTX()}
} else {
// Or else it creates one from master node.
if link, err = c.MasterLink(); err != nil {
return nil, err
}
}
} else if !link.IsTransaction() {
// If current link is not transaction link, it checks and retrieves transaction from context.
if tx := TXFromCtx(ctx, c.db.GetGroup()); tx != nil {
link = &txLink{tx.GetSqlTX()}
}
}
if c.db.GetConfig().PrepareTimeout > 0 {
// DO NOT USE cancel function in prepare statement.
var cancelFunc context.CancelFunc
ctx, cancelFunc = context.WithTimeout(ctx, c.db.GetConfig().PrepareTimeout)
defer cancelFunc()
}
// Link execution.
var out DoCommitOutput
out, err = c.db.DoCommit(ctx, DoCommitInput{
Link: link,
Sql: sql,
Type: SqlTypePrepareContext,
IsTransaction: link.IsTransaction(),
})
if err != nil {
return nil, err
}
return out.Stmt, err
}
// FormatUpsert formats and returns SQL clause part for upsert statement.
// In default implements, this function performs upsert statement for MySQL like:
// `INSERT INTO ... ON DUPLICATE KEY UPDATE x=VALUES(z),m=VALUES(y)...`
func (c *Core) FormatUpsert(columns []string, list List, option DoInsertOption) (string, error) {
var onDuplicateStr string
if option.OnDuplicateStr != "" {
onDuplicateStr = option.OnDuplicateStr
} else if len(option.OnDuplicateMap) > 0 {
for k, v := range option.OnDuplicateMap {
if len(onDuplicateStr) > 0 {
onDuplicateStr += ","
}
switch v.(type) {
case Raw, *Raw:
onDuplicateStr += fmt.Sprintf(
"%s=%s",
c.QuoteWord(k),
v,
)
case Counter, *Counter:
var counter Counter
switch value := v.(type) {
case Counter:
counter = value
case *Counter:
counter = *value
}
operator, columnVal := c.getCounterAlter(counter)
onDuplicateStr += fmt.Sprintf(
"%s=%s%s%s",
c.QuoteWord(k),
c.QuoteWord(counter.Field),
operator,
gconv.String(columnVal),
)
default:
onDuplicateStr += fmt.Sprintf(
"%s=VALUES(%s)",
c.QuoteWord(k),
c.QuoteWord(gconv.String(v)),
)
}
}
} else {
for _, column := range columns {
// If it's `SAVE` operation, do not automatically update the creating time.
if c.IsSoftCreatedFieldName(column) {
continue
}
if len(onDuplicateStr) > 0 {
onDuplicateStr += ","
}
onDuplicateStr += fmt.Sprintf(
"%s=VALUES(%s)",
c.QuoteWord(column),
c.QuoteWord(column),
)
}
}
return InsertOnDuplicateKeyUpdate + " " + onDuplicateStr, nil
}
// RowsToResult converts underlying data record type sql.Rows to Result type.
func (c *Core) RowsToResult(ctx context.Context, rows *sql.Rows) (Result, error) {
if rows == nil {
return nil, nil
}
defer func() {
if err := rows.Close(); err != nil {
intlog.Errorf(ctx, `%+v`, err)
}
}()
if !rows.Next() {
return nil, nil
}
// Column names and types.
columnTypes, err := rows.ColumnTypes()
if err != nil {
return nil, err
}
if len(columnTypes) > 0 {
if internalData := c.getInternalColumnFromCtx(ctx); internalData != nil {
internalData.FirstResultColumn = columnTypes[0].Name()
}
}
var (
values = make([]any, len(columnTypes))
result = make(Result, 0)
scanArgs = make([]any, len(values))
)
for i := range values {
scanArgs[i] = &values[i]
}
for {
if err = rows.Scan(scanArgs...); err != nil {
return result, err
}
record := Record{}
for i, value := range values {
if value == nil {
// DO NOT use `gvar.New(nil)` here as it creates an initialized object
// which will cause struct converting issue.
record[columnTypes[i].Name()] = nil
} else {
var (
convertedValue any
columnType = columnTypes[i]
)
if convertedValue, err = c.columnValueToLocalValue(ctx, value, columnType); err != nil {
return nil, err
}
record[columnTypes[i].Name()] = gvar.New(convertedValue)
}
}
result = append(result, record)
if !rows.Next() {
break
}
}
return result, nil
}
// OrderRandomFunction returns the SQL function for random ordering.
func (c *Core) OrderRandomFunction() string {
return "RAND()"
}
func (c *Core) columnValueToLocalValue(ctx context.Context, value any, columnType *sql.ColumnType) (any, error) {
var scanType = columnType.ScanType()
if scanType != nil {
// Common basic builtin types.
switch scanType.Kind() {
case
reflect.Bool,
reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64,
reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64,
reflect.Float32, reflect.Float64:
return gconv.Convert(gconv.String(value), scanType.String()), nil
default:
}
}
// Other complex types, especially custom types.
return c.db.ConvertValueForLocal(ctx, columnType.DatabaseTypeName(), value)
}

View File

@ -0,0 +1,273 @@
// Copyright GoFrame Author(https://goframe.org). All Rights Reserved.
//
// This Source Code Form is subject to the terms of the MIT License.
// If a copy of the MIT was not distributed with this file,
// You can obtain one at https://github.com/gogf/gf.
//
package database
import (
"context"
"fmt"
"strings"
"github.com/gogf/gf/v2/errors/gcode"
"github.com/gogf/gf/v2/errors/gerror"
"github.com/gogf/gf/v2/text/gregex"
"github.com/gogf/gf/v2/text/gstr"
"github.com/gogf/gf/v2/util/gutil"
)
// GetDB returns the underlying DB.
func (c *Core) GetDB() DB {
return c.db
}
// GetLink creates and returns the underlying database link object with transaction checks.
// The parameter `master` specifies whether using the master node if master-slave configured.
func (c *Core) GetLink(ctx context.Context, master bool, schema string) (Link, error) {
tx := TXFromCtx(ctx, c.db.GetGroup())
if tx != nil {
return &txLink{tx.GetSqlTX()}, nil
}
if master {
link, err := c.db.GetCore().MasterLink(schema)
if err != nil {
return nil, err
}
return link, nil
}
link, err := c.db.GetCore().SlaveLink(schema)
if err != nil {
return nil, err
}
return link, nil
}
// MasterLink acts like function Master but with additional `schema` parameter specifying
// the schema for the connection. It is defined for internal usage.
// Also see Master.
func (c *Core) MasterLink(schema ...string) (Link, error) {
db, err := c.db.Master(schema...)
if err != nil {
return nil, err
}
return &dbLink{
DB: db,
isOnMaster: true,
}, nil
}
// SlaveLink acts like function Slave but with additional `schema` parameter specifying
// the schema for the connection. It is defined for internal usage.
// Also see Slave.
func (c *Core) SlaveLink(schema ...string) (Link, error) {
db, err := c.db.Slave(schema...)
if err != nil {
return nil, err
}
return &dbLink{
DB: db,
isOnMaster: false,
}, nil
}
// QuoteWord checks given string `s` a word,
// if true it quotes `s` with security chars of the database
// and returns the quoted string; or else it returns `s` without any change.
//
// The meaning of a `word` can be considered as a column name.
func (c *Core) QuoteWord(s string) string {
s = gstr.Trim(s)
if s == "" {
return s
}
charLeft, charRight := c.db.GetChars()
return doQuoteWord(s, charLeft, charRight)
}
// QuoteString quotes string with quote chars. Strings like:
// "user", "user u", "user,user_detail", "user u, user_detail ut", "u.id asc".
//
// The meaning of a `string` can be considered as part of a statement string including columns.
func (c *Core) QuoteString(s string) string {
if !gregex.IsMatchString(regularFieldNameWithCommaRegPattern, s) {
return s
}
charLeft, charRight := c.db.GetChars()
return doQuoteString(s, charLeft, charRight)
}
// QuotePrefixTableName adds prefix string and quotes chars for the table.
// It handles table string like:
// "user", "user u",
// "user,user_detail",
// "user u, user_detail ut",
// "user as u, user_detail as ut".
//
// Note that, this will automatically checks the table prefix whether already added,
// if true it does nothing to the table name, or else adds the prefix to the table name.
func (c *Core) QuotePrefixTableName(table string) string {
charLeft, charRight := c.db.GetChars()
return doQuoteTableName(table, c.db.GetPrefix(), charLeft, charRight)
}
// GetChars returns the security char for current database.
// It does nothing in default.
func (c *Core) GetChars() (charLeft string, charRight string) {
return "", ""
}
// Tables retrieves and returns the tables of current schema.
// It's mainly used in cli tool chain for automatically generating the models.
func (c *Core) Tables(ctx context.Context, schema ...string) (tables []string, err error) {
return
}
// TableFields retrieves and returns the fields' information of specified table of current
// schema.
//
// The parameter `link` is optional, if given nil it automatically retrieves a raw sql connection
// as its link to proceed necessary sql query.
//
// Note that it returns a map containing the field name and its corresponding fields.
// As a map is unsorted, the TableField struct has an "Index" field marks its sequence in
// the fields.
//
// It's using cache feature to enhance the performance, which is never expired util the
// process restarts.
func (c *Core) TableFields(ctx context.Context, table string, schema ...string) (fields map[string]*TableField, err error) {
return
}
// ClearTableFields removes certain cached table fields of current configuration group.
func (c *Core) ClearTableFields(ctx context.Context, table string, schema ...string) (err error) {
tableFieldsCacheKey := genTableFieldsCacheKey(
c.db.GetGroup(),
gutil.GetOrDefaultStr(c.db.GetSchema(), schema...),
table,
)
_, err = c.innerMemCache.Remove(ctx, tableFieldsCacheKey)
return
}
// ClearTableFieldsAll removes all cached table fields of current configuration group.
func (c *Core) ClearTableFieldsAll(ctx context.Context) (err error) {
var (
keys, _ = c.innerMemCache.KeyStrings(ctx)
cachePrefix = cachePrefixTableFields
removedKeys = make([]any, 0)
)
for _, key := range keys {
if gstr.HasPrefix(key, cachePrefix) {
removedKeys = append(removedKeys, key)
}
}
if len(removedKeys) > 0 {
err = c.innerMemCache.Removes(ctx, removedKeys)
}
return
}
// ClearCache removes cached sql result of certain table.
func (c *Core) ClearCache(ctx context.Context, table string) (err error) {
var (
keys, _ = c.db.GetCache().KeyStrings(ctx)
cachePrefix = fmt.Sprintf(`%s%s@`, cachePrefixSelectCache, table)
removedKeys = make([]any, 0)
)
for _, key := range keys {
if gstr.HasPrefix(key, cachePrefix) {
removedKeys = append(removedKeys, key)
}
}
if len(removedKeys) > 0 {
err = c.db.GetCache().Removes(ctx, removedKeys)
}
return
}
// ClearCacheAll removes all cached sql result from cache
func (c *Core) ClearCacheAll(ctx context.Context) (err error) {
if err = c.db.GetCache().Clear(ctx); err != nil {
return err
}
if err = c.GetInnerMemCache().Clear(ctx); err != nil {
return err
}
return
}
// HasField determine whether the field exists in the table.
func (c *Core) HasField(ctx context.Context, table, field string, schema ...string) (bool, error) {
table = c.guessPrimaryTableName(table)
tableFields, err := c.db.TableFields(ctx, table, schema...)
if err != nil {
return false, err
}
if len(tableFields) == 0 {
return false, gerror.NewCodef(
gcode.CodeNotFound,
`empty table fields for table "%s"`, table,
)
}
fieldsArray := make([]string, len(tableFields))
for k, v := range tableFields {
fieldsArray[v.Index] = k
}
charLeft, charRight := c.db.GetChars()
field = gstr.Trim(field, charLeft+charRight)
for _, f := range fieldsArray {
if f == field {
return true, nil
}
}
return false, nil
}
// guessPrimaryTableName parses and returns the primary table name.
func (c *Core) guessPrimaryTableName(tableStr string) string {
if tableStr == "" {
return ""
}
var (
guessedTableName string
array1 = gstr.SplitAndTrim(tableStr, ",")
array2 = gstr.SplitAndTrim(array1[0], " ")
array3 = gstr.SplitAndTrim(array2[0], ".")
)
if len(array3) >= 2 {
guessedTableName = array3[1]
} else {
guessedTableName = array3[0]
}
charL, charR := c.db.GetChars()
if charL != "" || charR != "" {
guessedTableName = gstr.Trim(guessedTableName, charL+charR)
}
if !gregex.IsMatchString(regularFieldNameRegPattern, guessedTableName) {
return ""
}
return guessedTableName
}
// GetPrimaryKeys retrieves and returns the primary key field names of the specified table.
// This method extracts primary key information from TableFields.
// The parameter `schema` is optional, if not specified it uses the default schema.
func (c *Core) GetPrimaryKeys(ctx context.Context, table string, schema ...string) ([]string, error) {
tableFields, err := c.db.TableFields(ctx, table, schema...)
if err != nil {
return nil, err
}
var primaryKeys []string
for _, field := range tableFields {
if strings.EqualFold(field.Key, "pri") {
primaryKeys = append(primaryKeys, field.Name)
}
}
return primaryKeys, nil
}

View File

@ -0,0 +1,7 @@
package database
type IDao interface {
DB() DB
TableName() string
Columns() any
}

162
database/gdb_database.go Normal file
View File

@ -0,0 +1,162 @@
package database
import (
"context"
"fmt"
"git.magicany.cc/black1552/gin-base/config"
"git.magicany.cc/black1552/gin-base/database/instance"
"git.magicany.cc/black1552/gin-base/database/intlog"
customLog "git.magicany.cc/black1552/gin-base/log"
"github.com/gogf/gf/v2/database/gdb"
"github.com/gogf/gf/v2/errors/gcode"
"github.com/gogf/gf/v2/errors/gerror"
"github.com/gogf/gf/v2/util/gconv"
"github.com/gogf/gf/v2/util/gutil"
)
const (
frameCoreComponentNameDatabase = "core.component.database"
ConfigNodeNameDatabase = "database"
)
func Database(name ...string) gdb.DB {
var (
ctx = context.Background()
group = gdb.DefaultGroupName
)
if len(name) > 0 && name[0] != "" {
group = name[0]
}
instanceKey := fmt.Sprintf("%s.%s", frameCoreComponentNameDatabase, group)
db := instance.GetOrSetFuncLock(instanceKey, func() interface{} {
// It ignores returned error to avoid file no found error while it's not necessary.
var (
configMap map[string]interface{}
configNodeKey = ConfigNodeNameDatabase
)
// It firstly searches the configuration of the instance name.
if configData := config.GetAllConfig(); len(configData) > 0 {
if v, _ := gutil.MapPossibleItemByKey(configData, ConfigNodeNameDatabase); v != "" {
configNodeKey = v
}
}
if v := config.GetConfigValue(configNodeKey); !v.IsEmpty() {
configMap = v.Map()
}
// 检查配置是否存在
if len(configMap) == 0 {
// 从 config 包获取所有配置,检查是否包含数据库配置
allConfig := config.GetAllConfig()
if len(allConfig) == 0 {
panic(gerror.NewCodef(
gcode.CodeMissingConfiguration,
`database initialization failed: configuration file is empty or invalid`,
))
}
// 检查是否配置了数据库节点
if _, exists := allConfig["database"]; !exists {
// 尝试其他可能的键名
found := false
for key := range allConfig {
if key == "DATABASE" || key == "Database" {
found = true
break
}
}
if !found {
panic(gerror.NewCodef(
gcode.CodeMissingConfiguration,
`database initialization failed: configuration missing for database node "%s"`,
ConfigNodeNameDatabase,
))
}
}
}
if len(configMap) == 0 {
configMap = make(map[string]interface{})
}
// Parse `m` as map-slice and adds it to global configurations for package gdb.
for g, groupConfig := range configMap {
cg := ConfigGroup{}
switch value := groupConfig.(type) {
case []interface{}:
for _, v := range value {
if node := parseDBConfigNode(v); node != nil {
cg = append(cg, *node)
}
}
case map[string]interface{}:
if node := parseDBConfigNode(value); node != nil {
cg = append(cg, *node)
}
}
if len(cg) > 0 {
if GetConfig(group) == nil {
intlog.Printf(ctx, "add configuration for group: %s, %#v", g, cg)
SetConfigGroup(g, cg)
} else {
intlog.Printf(ctx, "ignore configuration as it already exists for group: %s, %#v", g, cg)
intlog.Printf(ctx, "%s, %#v", g, cg)
}
}
}
// Parse `m` as a single node configuration,
// which is the default group configuration.
if node := parseDBConfigNode(configMap); node != nil {
cg := ConfigGroup{}
if node.Link != "" || node.Host != "" {
cg = append(cg, *node)
}
if len(cg) > 0 {
if GetConfig(group) == nil {
intlog.Printf(ctx, "add configuration for group: %s, %#v", DefaultGroupName, cg)
SetConfigGroup(DefaultGroupName, cg)
} else {
intlog.Printf(
ctx,
"ignore configuration as it already exists for group: %s, %#v",
DefaultGroupName, cg,
)
intlog.Printf(ctx, "%s, %#v", DefaultGroupName, cg)
}
}
}
// Create a new ORM object with given configurations.
if db, err := NewByGroup(name...); err == nil {
// 自动初始化自定义日志器并设置到数据库
db.SetLogger(customLog.GetLogger())
return db
} else {
// If panics, often because it does not find its configuration for given group.
panic(err)
}
return nil
})
if db != nil {
return db.(gdb.DB)
}
return nil
}
func parseDBConfigNode(value interface{}) *ConfigNode {
nodeMap, ok := value.(map[string]interface{})
if !ok {
return nil
}
var (
node = &ConfigNode{}
err = gconv.Struct(nodeMap, node)
)
if err != nil {
panic(err)
}
// Find possible `Link` configuration content.
if _, v := gutil.MapPossibleItemByKey(nodeMap, "Link"); v != nil {
node.Link = gconv.String(v)
}
return node
}

View File

@ -0,0 +1,46 @@
// Copyright GoFrame Author(https://goframe.org). All Rights Reserved.
//
// This Source Code Form is subject to the terms of the MIT License.
// If a copy of the MIT was not distributed with this file,
// You can obtain one at https://github.com/gogf/gf.
package database
import (
"database/sql"
)
// DriverDefault is the default driver for mysql database, which does nothing.
type DriverDefault struct {
*Core
}
func init() {
if err := Register("default", &DriverDefault{}); err != nil {
panic(err)
}
}
// New creates and returns a database object for mysql.
// It implements the interface of gdb.Driver for extra database driver installation.
func (d *DriverDefault) New(core *Core, node *ConfigNode) (DB, error) {
return &DriverDefault{
Core: core,
}, nil
}
// Open creates and returns an underlying sql.DB object for mysql.
// Note that it converts time.Time argument to local timezone in default.
func (d *DriverDefault) Open(config *ConfigNode) (db *sql.DB, err error) {
return
}
// PingMaster pings the master node to check authentication or keeps the connection alive.
func (d *DriverDefault) PingMaster() error {
return nil
}
// PingSlave pings the slave node to check authentication or keeps the connection alive.
func (d *DriverDefault) PingSlave() error {
return nil
}

View File

@ -0,0 +1,31 @@
// Copyright GoFrame Author(https://goframe.org). All Rights Reserved.
//
// This Source Code Form is subject to the terms of the MIT License.
// If a copy of the MIT was not distributed with this file,
// You can obtain one at https://github.com/gogf/gf.
package database
// DriverWrapper is a driver wrapper for extending features with embedded driver.
type DriverWrapper struct {
driver Driver
}
// New creates and returns a database object for mysql.
// It implements the interface of gdb.Driver for extra database driver installation.
func (d *DriverWrapper) New(core *Core, node *ConfigNode) (DB, error) {
db, err := d.driver.New(core, node)
if err != nil {
return nil, err
}
return &DriverWrapperDB{
DB: db,
}, nil
}
// newDriverWrapper creates and returns a driver wrapper.
func newDriverWrapper(driver Driver) Driver {
return &DriverWrapper{
driver: driver,
}
}

View File

@ -0,0 +1,131 @@
// Copyright GoFrame Author(https://goframe.org). All Rights Reserved.
//
// This Source Code Form is subject to the terms of the MIT License.
// If a copy of the MIT was not distributed with this file,
// You can obtain one at https://github.com/gogf/gf.
package database
import (
"context"
"database/sql"
"fmt"
"git.magicany.cc/black1552/gin-base/database/intlog"
"github.com/gogf/gf/v2/container/gvar"
"github.com/gogf/gf/v2/encoding/gjson"
"github.com/gogf/gf/v2/errors/gcode"
"github.com/gogf/gf/v2/errors/gerror"
"github.com/gogf/gf/v2/os/gcache"
"github.com/gogf/gf/v2/text/gstr"
"github.com/gogf/gf/v2/util/gutil"
)
// DriverWrapperDB is a DB wrapper for extending features with embedded DB.
type DriverWrapperDB struct {
DB
}
// Open creates and returns an underlying sql.DB object for pgsql.
// https://pkg.go.dev/github.com/lib/pq
func (d *DriverWrapperDB) Open(node *ConfigNode) (db *sql.DB, err error) {
var ctx = d.GetCtx()
intlog.PrintFunc(ctx, func() string {
return fmt.Sprintf(`open new connection:%s`, gjson.MustEncode(node))
})
return d.DB.Open(node)
}
// Tables retrieves and returns the tables of current schema.
// It's mainly used in cli tool chain for automatically generating the models.
func (d *DriverWrapperDB) Tables(ctx context.Context, schema ...string) (tables []string, err error) {
ctx = context.WithValue(ctx, ctxKeyInternalProducedSQL, struct{}{})
return d.DB.Tables(ctx, schema...)
}
// TableFields retrieves and returns the fields' information of specified table of current
// schema.
//
// The parameter `link` is optional, if given nil it automatically retrieves a raw sql connection
// as its link to proceed necessary sql query.
//
// Note that it returns a map containing the field name and its corresponding fields.
// As a map is unsorted, the TableField struct has an "Index" field marks its sequence in
// the fields.
//
// It's using cache feature to enhance the performance, which is never expired util the
// process restarts.
func (d *DriverWrapperDB) TableFields(
ctx context.Context, table string, schema ...string,
) (fields map[string]*TableField, err error) {
if table == "" {
return nil, nil
}
charL, charR := d.GetChars()
table = gstr.Trim(table, charL+charR)
if gstr.Contains(table, " ") {
return nil, gerror.NewCode(
gcode.CodeInvalidParameter,
"function TableFields supports only single table operations",
)
}
var (
innerMemCache = d.GetCore().GetInnerMemCache()
// prefix:group@schema#table
cacheKey = genTableFieldsCacheKey(
d.GetGroup(),
gutil.GetOrDefaultStr(d.GetSchema(), schema...),
table,
)
cacheFunc = func(ctx context.Context) (any, error) {
return d.DB.TableFields(
context.WithValue(ctx, ctxKeyInternalProducedSQL, struct{}{}),
table, schema...,
)
}
value *gvar.Var
)
value, err = innerMemCache.GetOrSetFuncLock(
ctx, cacheKey, cacheFunc, gcache.DurationNoExpire,
)
if err != nil {
return
}
if !value.IsNil() {
fields = value.Val().(map[string]*TableField)
}
return
}
// DoInsert inserts or updates data for given table.
// This function is usually used for custom interface definition, you do not need call it manually.
// The parameter `data` can be type of map/gmap/struct/*struct/[]map/[]struct, etc.
// Eg:
// Data(g.Map{"uid": 10000, "name":"john"})
// Data(g.Slice{g.Map{"uid": 10000, "name":"john"}, g.Map{"uid": 20000, "name":"smith"})
//
// The parameter `option` values are as follows:
// InsertOptionDefault: just insert, if there's unique/primary key in the data, it returns error;
// InsertOptionReplace: if there's unique/primary key in the data, it deletes it from table and inserts a new one;
// InsertOptionSave: if there's unique/primary key in the data, it updates it or else inserts a new one;
// InsertOptionIgnore: if there's unique/primary key in the data, it ignores the inserting;
func (d *DriverWrapperDB) DoInsert(
ctx context.Context, link Link, table string, list List, option DoInsertOption,
) (result sql.Result, err error) {
if len(list) == 0 {
return nil, gerror.NewCodef(
gcode.CodeInvalidRequest,
`data list is empty for %s operation`,
GetInsertOperationByOption(option.InsertOption),
)
}
// Convert data type before commit it to underlying db driver.
for i, item := range list {
list[i], err = d.GetCore().ConvertDataForRecord(ctx, item, table)
if err != nil {
return nil, err
}
}
return d.DB.DoInsert(ctx, link, table, list, option)
}

1017
database/gdb_func.go Normal file

File diff suppressed because it is too large Load Diff

350
database/gdb_model.go Normal file
View File

@ -0,0 +1,350 @@
// Copyright GoFrame Author(https://goframe.org). All Rights Reserved.
//
// This Source Code Form is subject to the terms of the MIT License.
// If a copy of the MIT was not distributed with this file,
// You can obtain one at https://github.com/gogf/gf.
package database
import (
"context"
"fmt"
"github.com/gogf/gf/v2/text/gregex"
"github.com/gogf/gf/v2/text/gstr"
"github.com/gogf/gf/v2/util/gconv"
)
// Model is core struct implementing the DAO for ORM.
type Model struct {
db DB // Underlying DB interface.
tx TX // Underlying TX interface.
rawSql string // rawSql is the raw SQL string which marks a raw SQL based Model not a table based Model.
schema string // Custom database schema.
linkType int // Mark for operation on master or slave.
tablesInit string // Table names when model initialization.
tables string // Operation table names, which can be more than one table names and aliases, like: "user", "user u", "user u, user_detail ud".
fields []any // Operation fields, multiple fields joined using char ','.
fieldsEx []any // Excluded operation fields, it here uses slice instead of string type for quick filtering.
withArray []any // Arguments for With feature.
withAll bool // Enable model association operations on all objects that have "with" tag in the struct.
extraArgs []any // Extra custom arguments for sql, which are prepended to the arguments before sql committed to underlying driver.
whereBuilder *WhereBuilder // Condition builder for where operation.
groupBy string // Used for "group by" statement.
orderBy string // Used for "order by" statement.
having []any // Used for "having..." statement.
start int // Used for "select ... start, limit ..." statement.
limit int // Used for "select ... start, limit ..." statement.
option int // Option for extra operation features.
offset int // Offset statement for some databases grammar.
partition string // Partition table partition name.
data any // Data for operation, which can be type of map/[]map/struct/*struct/string, etc.
batch int // Batch number for batch Insert/Replace/Save operations.
filter bool // Filter data and where key-value pairs according to the fields of the table.
distinct string // Force the query to only return distinct results.
lockInfo string // Lock for update or in shared lock.
cacheEnabled bool // Enable sql result cache feature, which is mainly for indicating cache duration(especially 0) usage.
cacheOption CacheOption // Cache option for query statement.
pageCacheOption []CacheOption // Cache option for paging query statement.
hookHandler HookHandler // Hook functions for model hook feature.
unscoped bool // Disables soft deleting features when select/delete operations.
safe bool // If true, it clones and returns a new model object whenever operation done; or else it changes the attribute of current model.
onDuplicate any // onDuplicate is used for on Upsert clause.
onDuplicateEx any // onDuplicateEx is used for excluding some columns on Upsert clause.
onConflict any // onConflict is used for conflict keys on Upsert clause.
tableAliasMap map[string]string // Table alias to true table name, usually used in join statements.
softTimeOption SoftTimeOption // SoftTimeOption is the option to customize soft time feature for Model.
shardingConfig ShardingConfig // ShardingConfig for database/table sharding feature.
shardingValue any // Sharding value for sharding feature.
}
// ModelHandler is a function that handles given Model and returns a new Model that is custom modified.
type ModelHandler func(m *Model) *Model
// ChunkHandler is a function that is used in function Chunk, which handles given Result and error.
// It returns true if it wants to continue chunking, or else it returns false to stop chunking.
type ChunkHandler func(result Result, err error) bool
const (
linkTypeMaster = 1
linkTypeSlave = 2
defaultField = "*"
whereHolderOperatorWhere = 1
whereHolderOperatorAnd = 2
whereHolderOperatorOr = 3
whereHolderTypeDefault = "Default"
whereHolderTypeNoArgs = "NoArgs"
whereHolderTypeIn = "In"
)
// Model creates and returns a new ORM model from given schema.
// The parameter `tableNameQueryOrStruct` can be more than one table names, and also alias name, like:
// 1. Model names:
// db.Model("user")
// db.Model("user u")
// db.Model("user, user_detail")
// db.Model("user u, user_detail ud")
// 2. Model name with alias:
// db.Model("user", "u")
// 3. Model name with sub-query:
// db.Model("? AS a, ? AS b", subQuery1, subQuery2)
func (c *Core) Model(tableNameQueryOrStruct ...any) *Model {
var (
ctx = c.db.GetCtx()
tableStr string
tableName string
extraArgs []any
)
// Model creation with sub-query.
if len(tableNameQueryOrStruct) > 1 {
conditionStr := gconv.String(tableNameQueryOrStruct[0])
if gstr.Contains(conditionStr, "?") {
whereHolder := WhereHolder{
Where: conditionStr,
Args: tableNameQueryOrStruct[1:],
}
tableStr, extraArgs = formatWhereHolder(ctx, c.db, formatWhereHolderInput{
WhereHolder: whereHolder,
OmitNil: false,
OmitEmpty: false,
Schema: "",
Table: "",
})
}
}
// Normal model creation.
if tableStr == "" {
tableNames := make([]string, len(tableNameQueryOrStruct))
for k, v := range tableNameQueryOrStruct {
if s, ok := v.(string); ok {
tableNames[k] = s
} else if tableName = getTableNameFromOrmTag(v); tableName != "" {
tableNames[k] = tableName
}
}
if len(tableNames) > 1 {
tableStr = fmt.Sprintf(
`%s AS %s`, c.QuotePrefixTableName(tableNames[0]), c.QuoteWord(tableNames[1]),
)
} else if len(tableNames) == 1 {
tableStr = c.QuotePrefixTableName(tableNames[0])
}
}
m := &Model{
db: c.db,
schema: c.schema,
tablesInit: tableStr,
tables: tableStr,
start: -1,
offset: -1,
filter: true,
extraArgs: extraArgs,
tableAliasMap: make(map[string]string),
}
m.whereBuilder = m.Builder()
if defaultModelSafe {
m.safe = true
}
return m
}
// Raw creates and returns a model based on a raw sql not a table.
// Example:
//
// db.Raw("SELECT * FROM `user` WHERE `name` = ?", "john").Scan(&result)
func (c *Core) Raw(rawSql string, args ...any) *Model {
model := c.Model()
model.rawSql = rawSql
model.extraArgs = args
return model
}
// Raw sets current model as a raw sql model.
// Example:
//
// db.Raw("SELECT * FROM `user` WHERE `name` = ?", "john").Scan(&result)
//
// See Core.Raw.
func (m *Model) Raw(rawSql string, args ...any) *Model {
model := m.db.Raw(rawSql, args...)
model.db = m.db
model.tx = m.tx
return model
}
func (tx *TXCore) Raw(rawSql string, args ...any) *Model {
return tx.Model().Raw(rawSql, args...)
}
// With creates and returns an ORM model based on metadata of given object.
func (c *Core) With(objects ...any) *Model {
return c.db.Model().With(objects...)
}
// Partition sets Partition name.
// Example:
// dao.User.Ctx(ctx).Partition"p1","p2","p3").All()
func (m *Model) Partition(partitions ...string) *Model {
model := m.getModel()
model.partition = gstr.Join(partitions, ",")
return model
}
// Model acts like Core.Model except it operates on transaction.
// See Core.Model.
func (tx *TXCore) Model(tableNameQueryOrStruct ...any) *Model {
model := tx.db.Model(tableNameQueryOrStruct...)
model.db = tx.db
model.tx = tx
return model
}
// With acts like Core.With except it operates on transaction.
// See Core.With.
func (tx *TXCore) With(object any) *Model {
return tx.Model().With(object)
}
// Ctx sets the context for current operation.
func (m *Model) Ctx(ctx context.Context) *Model {
if ctx == nil {
return m
}
model := m.getModel()
model.db = model.db.Ctx(ctx)
if m.tx != nil {
model.tx = model.tx.Ctx(ctx)
}
return model
}
// GetCtx returns the context for current Model.
// It returns `context.Background()` is there's no context previously set.
func (m *Model) GetCtx() context.Context {
if m.tx != nil && m.tx.GetCtx() != nil {
return m.tx.GetCtx()
}
return m.db.GetCtx()
}
// As sets an alias name for current table.
func (m *Model) As(as string) *Model {
if m.tables != "" {
model := m.getModel()
split := " JOIN "
if gstr.ContainsI(model.tables, split) {
// For join table.
array := gstr.Split(model.tables, split)
array[len(array)-1], _ = gregex.ReplaceString(`(.+) ON`, fmt.Sprintf(`$1 AS %s ON`, as), array[len(array)-1])
model.tables = gstr.Join(array, split)
} else {
// For base table.
model.tables = gstr.TrimRight(model.tables) + " AS " + as
}
return model
}
return m
}
// DB sets/changes the db object for current operation.
func (m *Model) DB(db DB) *Model {
model := m.getModel()
model.db = db
return model
}
// TX sets/changes the transaction for current operation.
func (m *Model) TX(tx TX) *Model {
model := m.getModel()
model.db = tx.GetDB()
model.tx = tx
return model
}
// Schema sets the schema for current operation.
func (m *Model) Schema(schema string) *Model {
model := m.getModel()
model.schema = schema
return model
}
// Clone creates and returns a new model which is a Clone of current model.
// Note that it uses deep-copy for the Clone.
func (m *Model) Clone() *Model {
newModel := (*Model)(nil)
if m.tx != nil {
newModel = m.tx.Model(m.tablesInit)
} else {
newModel = m.db.Model(m.tablesInit)
}
// Basic attributes copy.
*newModel = *m
// WhereBuilder copy, note the attribute pointer.
newModel.whereBuilder = m.whereBuilder.Clone()
newModel.whereBuilder.model = newModel
// Shallow copy slice attributes.
if n := len(m.fields); n > 0 {
newModel.fields = make([]any, n)
copy(newModel.fields, m.fields)
}
if n := len(m.fieldsEx); n > 0 {
newModel.fieldsEx = make([]any, n)
copy(newModel.fieldsEx, m.fieldsEx)
}
if n := len(m.extraArgs); n > 0 {
newModel.extraArgs = make([]any, n)
copy(newModel.extraArgs, m.extraArgs)
}
if n := len(m.withArray); n > 0 {
newModel.withArray = make([]any, n)
copy(newModel.withArray, m.withArray)
}
if n := len(m.having); n > 0 {
newModel.having = make([]any, n)
copy(newModel.having, m.having)
}
return newModel
}
// Master marks the following operation on master node.
func (m *Model) Master() *Model {
model := m.getModel()
model.linkType = linkTypeMaster
return model
}
// Slave marks the following operation on slave node.
// Note that it makes sense only if there's any slave node configured.
func (m *Model) Slave() *Model {
model := m.getModel()
model.linkType = linkTypeSlave
return model
}
// Safe marks this model safe or unsafe. If safe is true, it clones and returns a new model object
// whenever the operation done, or else it changes the attribute of current model.
func (m *Model) Safe(safe ...bool) *Model {
if len(safe) > 0 {
m.safe = safe[0]
} else {
m.safe = true
}
return m
}
// Args sets custom arguments for model operation.
func (m *Model) Args(args ...any) *Model {
model := m.getModel()
model.extraArgs = append(model.extraArgs, args)
return model
}
// Handler calls each of `handlers` on current Model and returns a new Model.
// ModelHandler is a function that handles given Model and returns a new Model that is custom modified.
func (m *Model) Handler(handlers ...ModelHandler) *Model {
model := m.getModel()
for _, handler := range handlers {
model = handler(model)
}
return model
}

View File

@ -0,0 +1,124 @@
// Copyright GoFrame Author(https://goframe.org). All Rights Reserved.
//
// This Source Code Form is subject to the terms of the MIT License.
// If a copy of the MIT was not distributed with this file,
// You can obtain one at https://github.com/gogf/gf.
package database
import (
"fmt"
)
// WhereBuilder holds multiple where conditions in a group.
type WhereBuilder struct {
model *Model // A WhereBuilder should be bound to certain Model.
whereHolder []WhereHolder // Condition strings for where operation.
}
// WhereHolder is the holder for where condition preparing.
type WhereHolder struct {
Type string // Type of this holder.
Operator int // Operator for this holder.
Where any // Where parameter, which can commonly be type of string/map/struct.
Args []any // Arguments for where parameter.
Prefix string // Field prefix, eg: "user.", "order.".
}
// Builder creates and returns a WhereBuilder. Please note that the builder is chain-safe.
func (m *Model) Builder() *WhereBuilder {
b := &WhereBuilder{
model: m,
whereHolder: make([]WhereHolder, 0),
}
return b
}
// getBuilder creates and returns a cloned WhereBuilder of current WhereBuilder
func (b *WhereBuilder) getBuilder() *WhereBuilder {
return b.Clone()
}
// Clone clones and returns a WhereBuilder that is a copy of current one.
func (b *WhereBuilder) Clone() *WhereBuilder {
newBuilder := b.model.Builder()
newBuilder.whereHolder = make([]WhereHolder, len(b.whereHolder))
copy(newBuilder.whereHolder, b.whereHolder)
return newBuilder
}
// Build builds current WhereBuilder and returns the condition string and parameters.
func (b *WhereBuilder) Build() (conditionWhere string, conditionArgs []any) {
var (
ctx = b.model.GetCtx()
autoPrefix = b.model.getAutoPrefix()
tableForMappingAndFiltering = b.model.tables
)
if len(b.whereHolder) > 0 {
for _, holder := range b.whereHolder {
if holder.Prefix == "" {
holder.Prefix = autoPrefix
}
switch holder.Operator {
case whereHolderOperatorWhere, whereHolderOperatorAnd:
newWhere, newArgs := formatWhereHolder(ctx, b.model.db, formatWhereHolderInput{
WhereHolder: holder,
OmitNil: b.model.option&optionOmitNilWhere > 0,
OmitEmpty: b.model.option&optionOmitEmptyWhere > 0,
Schema: b.model.schema,
Table: tableForMappingAndFiltering,
})
if len(newWhere) > 0 {
if len(conditionWhere) == 0 {
conditionWhere = newWhere
} else if conditionWhere[0] == '(' {
conditionWhere = fmt.Sprintf(`%s AND (%s)`, conditionWhere, newWhere)
} else {
conditionWhere = fmt.Sprintf(`(%s) AND (%s)`, conditionWhere, newWhere)
}
conditionArgs = append(conditionArgs, newArgs...)
}
case whereHolderOperatorOr:
newWhere, newArgs := formatWhereHolder(ctx, b.model.db, formatWhereHolderInput{
WhereHolder: holder,
OmitNil: b.model.option&optionOmitNilWhere > 0,
OmitEmpty: b.model.option&optionOmitEmptyWhere > 0,
Schema: b.model.schema,
Table: tableForMappingAndFiltering,
})
if len(newWhere) > 0 {
if len(conditionWhere) == 0 {
conditionWhere = newWhere
} else if conditionWhere[0] == '(' {
conditionWhere = fmt.Sprintf(`%s OR (%s)`, conditionWhere, newWhere)
} else {
conditionWhere = fmt.Sprintf(`(%s) OR (%s)`, conditionWhere, newWhere)
}
conditionArgs = append(conditionArgs, newArgs...)
}
}
}
}
return
}
// convertWhereBuilder converts parameter `where` to condition string and parameters if `where` is also a WhereBuilder.
func (b *WhereBuilder) convertWhereBuilder(where any, args []any) (newWhere any, newArgs []any) {
var builder *WhereBuilder
switch v := where.(type) {
case WhereBuilder:
builder = &v
case *WhereBuilder:
builder = v
}
if builder != nil {
conditionWhere, conditionArgs := builder.Build()
if conditionWhere != "" && (len(b.whereHolder) == 0 || len(builder.whereHolder) > 1) {
conditionWhere = "(" + conditionWhere + ")"
}
return conditionWhere, conditionArgs
}
return where, args
}

View File

@ -0,0 +1,171 @@
// Copyright GoFrame Author(https://goframe.org). All Rights Reserved.
//
// This Source Code Form is subject to the terms of the MIT License.
// If a copy of the MIT was not distributed with this file,
// You can obtain one at https://github.com/gogf/gf.
package database
import (
"fmt"
"github.com/gogf/gf/v2/text/gstr"
)
// doWhereType sets the condition statement for the model. The parameter `where` can be type of
// string/map/gmap/slice/struct/*struct, etc. Note that, if it's called more than one times,
// multiple conditions will be joined into where statement using "AND".
func (b *WhereBuilder) doWhereType(whereType string, where any, args ...any) *WhereBuilder {
where, args = b.convertWhereBuilder(where, args)
builder := b.getBuilder()
if builder.whereHolder == nil {
builder.whereHolder = make([]WhereHolder, 0)
}
if whereType == "" {
if len(args) == 0 {
whereType = whereHolderTypeNoArgs
} else {
whereType = whereHolderTypeDefault
}
}
builder.whereHolder = append(builder.whereHolder, WhereHolder{
Type: whereType,
Operator: whereHolderOperatorWhere,
Where: where,
Args: args,
})
return builder
}
// doWherefType builds condition string using fmt.Sprintf and arguments.
// Note that if the number of `args` is more than the placeholder in `format`,
// the extra `args` will be used as the where condition arguments of the Model.
func (b *WhereBuilder) doWherefType(t string, format string, args ...any) *WhereBuilder {
var (
placeHolderCount = gstr.Count(format, "?")
conditionStr = fmt.Sprintf(format, args[:len(args)-placeHolderCount]...)
)
return b.doWhereType(t, conditionStr, args[len(args)-placeHolderCount:]...)
}
// Where sets the condition statement for the builder. The parameter `where` can be type of
// string/map/gmap/slice/struct/*struct, etc. Note that, if it's called more than one times,
// multiple conditions will be joined into where statement using "AND".
// Eg:
// Where("uid=10000")
// Where("uid", 10000)
// Where("money>? AND name like ?", 99999, "vip_%")
// Where("uid", 1).Where("name", "john")
// Where("status IN (?)", g.Slice{1,2,3})
// Where("age IN(?,?)", 18, 50)
// Where(User{ Id : 1, UserName : "john"}).
func (b *WhereBuilder) Where(where any, args ...any) *WhereBuilder {
return b.doWhereType(``, where, args...)
}
// Wheref builds condition string using fmt.Sprintf and arguments.
// Note that if the number of `args` is more than the placeholder in `format`,
// the extra `args` will be used as the where condition arguments of the Model.
// Eg:
// Wheref(`amount<? and status=%s`, "paid", 100) => WHERE `amount`<100 and status='paid'
// Wheref(`amount<%d and status=%s`, 100, "paid") => WHERE `amount`<100 and status='paid'
func (b *WhereBuilder) Wheref(format string, args ...any) *WhereBuilder {
return b.doWherefType(``, format, args...)
}
// WherePri does the same logic as Model.Where except that if the parameter `where`
// is a single condition like int/string/float/slice, it treats the condition as the primary
// key value. That is, if primary key is "id" and given `where` parameter as "123", the
// WherePri function treats the condition as "id=123", but Model.Where treats the condition
// as string "123".
func (b *WhereBuilder) WherePri(where any, args ...any) *WhereBuilder {
if len(args) > 0 {
return b.Where(where, args...)
}
newWhere := GetPrimaryKeyCondition(b.model.getPrimaryKey(), where)
return b.Where(newWhere[0], newWhere[1:]...)
}
// WhereLT builds `column < value` statement.
func (b *WhereBuilder) WhereLT(column string, value any) *WhereBuilder {
return b.Wheref(`%s < ?`, b.model.QuoteWord(column), value)
}
// WhereLTE builds `column <= value` statement.
func (b *WhereBuilder) WhereLTE(column string, value any) *WhereBuilder {
return b.Wheref(`%s <= ?`, b.model.QuoteWord(column), value)
}
// WhereGT builds `column > value` statement.
func (b *WhereBuilder) WhereGT(column string, value any) *WhereBuilder {
return b.Wheref(`%s > ?`, b.model.QuoteWord(column), value)
}
// WhereGTE builds `column >= value` statement.
func (b *WhereBuilder) WhereGTE(column string, value any) *WhereBuilder {
return b.Wheref(`%s >= ?`, b.model.QuoteWord(column), value)
}
// WhereBetween builds `column BETWEEN min AND max` statement.
func (b *WhereBuilder) WhereBetween(column string, min, max any) *WhereBuilder {
return b.Wheref(`%s BETWEEN ? AND ?`, b.model.QuoteWord(column), min, max)
}
// WhereLike builds `column LIKE like` statement.
func (b *WhereBuilder) WhereLike(column string, like string) *WhereBuilder {
return b.Wheref(`%s LIKE ?`, b.model.QuoteWord(column), like)
}
// WhereIn builds `column IN (in)` statement.
func (b *WhereBuilder) WhereIn(column string, in any) *WhereBuilder {
return b.doWherefType(whereHolderTypeIn, `%s IN (?)`, b.model.QuoteWord(column), in)
}
// WhereNull builds `columns[0] IS NULL AND columns[1] IS NULL ...` statement.
func (b *WhereBuilder) WhereNull(columns ...string) *WhereBuilder {
builder := b
for _, column := range columns {
builder = builder.Wheref(`%s IS NULL`, b.model.QuoteWord(column))
}
return builder
}
// WhereNotBetween builds `column NOT BETWEEN min AND max` statement.
func (b *WhereBuilder) WhereNotBetween(column string, min, max any) *WhereBuilder {
return b.Wheref(`%s NOT BETWEEN ? AND ?`, b.model.QuoteWord(column), min, max)
}
// WhereNotLike builds `column NOT LIKE like` statement.
func (b *WhereBuilder) WhereNotLike(column string, like any) *WhereBuilder {
return b.Wheref(`%s NOT LIKE ?`, b.model.QuoteWord(column), like)
}
// WhereNot builds `column != value` statement.
func (b *WhereBuilder) WhereNot(column string, value any) *WhereBuilder {
return b.Wheref(`%s != ?`, b.model.QuoteWord(column), value)
}
// WhereNotIn builds `column NOT IN (in)` statement.
func (b *WhereBuilder) WhereNotIn(column string, in any) *WhereBuilder {
return b.doWherefType(whereHolderTypeIn, `%s NOT IN (?)`, b.model.QuoteWord(column), in)
}
// WhereNotNull builds `columns[0] IS NOT NULL AND columns[1] IS NOT NULL ...` statement.
func (b *WhereBuilder) WhereNotNull(columns ...string) *WhereBuilder {
builder := b
for _, column := range columns {
builder = builder.Wheref(`%s IS NOT NULL`, b.model.QuoteWord(column))
}
return builder
}
// WhereExists builds `EXISTS (subQuery)` statement.
func (b *WhereBuilder) WhereExists(subQuery *Model) *WhereBuilder {
return b.Wheref(`EXISTS (?)`, subQuery)
}
// WhereNotExists builds `NOT EXISTS (subQuery)` statement.
func (b *WhereBuilder) WhereNotExists(subQuery *Model) *WhereBuilder {
return b.Wheref(`NOT EXISTS (?)`, subQuery)
}

View File

@ -0,0 +1,101 @@
// Copyright GoFrame Author(https://goframe.org). All Rights Reserved.
//
// This Source Code Form is subject to the terms of the MIT License.
// If a copy of the MIT was not distributed with this file,
// You can obtain one at https://github.com/gogf/gf.
package database
// WherePrefix performs as Where, but it adds prefix to each field in where statement.
// Eg:
// WherePrefix("order", "status", "paid") => WHERE `order`.`status`='paid'
// WherePrefix("order", struct{Status:"paid", "channel":"bank"}) => WHERE `order`.`status`='paid' AND `order`.`channel`='bank'
func (b *WhereBuilder) WherePrefix(prefix string, where any, args ...any) *WhereBuilder {
where, args = b.convertWhereBuilder(where, args)
builder := b.getBuilder()
if builder.whereHolder == nil {
builder.whereHolder = make([]WhereHolder, 0)
}
builder.whereHolder = append(builder.whereHolder, WhereHolder{
Type: whereHolderTypeDefault,
Operator: whereHolderOperatorWhere,
Where: where,
Args: args,
Prefix: prefix,
})
return builder
}
// WherePrefixLT builds `prefix.column < value` statement.
func (b *WhereBuilder) WherePrefixLT(prefix string, column string, value any) *WhereBuilder {
return b.Wheref(`%s.%s < ?`, b.model.QuoteWord(prefix), b.model.QuoteWord(column), value)
}
// WherePrefixLTE builds `prefix.column <= value` statement.
func (b *WhereBuilder) WherePrefixLTE(prefix string, column string, value any) *WhereBuilder {
return b.Wheref(`%s.%s <= ?`, b.model.QuoteWord(prefix), b.model.QuoteWord(column), value)
}
// WherePrefixGT builds `prefix.column > value` statement.
func (b *WhereBuilder) WherePrefixGT(prefix string, column string, value any) *WhereBuilder {
return b.Wheref(`%s.%s > ?`, b.model.QuoteWord(prefix), b.model.QuoteWord(column), value)
}
// WherePrefixGTE builds `prefix.column >= value` statement.
func (b *WhereBuilder) WherePrefixGTE(prefix string, column string, value any) *WhereBuilder {
return b.Wheref(`%s.%s >= ?`, b.model.QuoteWord(prefix), b.model.QuoteWord(column), value)
}
// WherePrefixBetween builds `prefix.column BETWEEN min AND max` statement.
func (b *WhereBuilder) WherePrefixBetween(prefix string, column string, min, max any) *WhereBuilder {
return b.Wheref(`%s.%s BETWEEN ? AND ?`, b.model.QuoteWord(prefix), b.model.QuoteWord(column), min, max)
}
// WherePrefixLike builds `prefix.column LIKE like` statement.
func (b *WhereBuilder) WherePrefixLike(prefix string, column string, like any) *WhereBuilder {
return b.Wheref(`%s.%s LIKE ?`, b.model.QuoteWord(prefix), b.model.QuoteWord(column), like)
}
// WherePrefixIn builds `prefix.column IN (in)` statement.
func (b *WhereBuilder) WherePrefixIn(prefix string, column string, in any) *WhereBuilder {
return b.doWherefType(whereHolderTypeIn, `%s.%s IN (?)`, b.model.QuoteWord(prefix), b.model.QuoteWord(column), in)
}
// WherePrefixNull builds `prefix.columns[0] IS NULL AND prefix.columns[1] IS NULL ...` statement.
func (b *WhereBuilder) WherePrefixNull(prefix string, columns ...string) *WhereBuilder {
builder := b
for _, column := range columns {
builder = builder.Wheref(`%s.%s IS NULL`, b.model.QuoteWord(prefix), b.model.QuoteWord(column))
}
return builder
}
// WherePrefixNotBetween builds `prefix.column NOT BETWEEN min AND max` statement.
func (b *WhereBuilder) WherePrefixNotBetween(prefix string, column string, min, max any) *WhereBuilder {
return b.Wheref(`%s.%s NOT BETWEEN ? AND ?`, b.model.QuoteWord(prefix), b.model.QuoteWord(column), min, max)
}
// WherePrefixNotLike builds `prefix.column NOT LIKE like` statement.
func (b *WhereBuilder) WherePrefixNotLike(prefix string, column string, like any) *WhereBuilder {
return b.Wheref(`%s.%s NOT LIKE ?`, b.model.QuoteWord(prefix), b.model.QuoteWord(column), like)
}
// WherePrefixNot builds `prefix.column != value` statement.
func (b *WhereBuilder) WherePrefixNot(prefix string, column string, value any) *WhereBuilder {
return b.Wheref(`%s.%s != ?`, b.model.QuoteWord(prefix), b.model.QuoteWord(column), value)
}
// WherePrefixNotIn builds `prefix.column NOT IN (in)` statement.
func (b *WhereBuilder) WherePrefixNotIn(prefix string, column string, in any) *WhereBuilder {
return b.doWherefType(whereHolderTypeIn, `%s.%s NOT IN (?)`, b.model.QuoteWord(prefix), b.model.QuoteWord(column), in)
}
// WherePrefixNotNull builds `prefix.columns[0] IS NOT NULL AND prefix.columns[1] IS NOT NULL ...` statement.
func (b *WhereBuilder) WherePrefixNotNull(prefix string, columns ...string) *WhereBuilder {
builder := b
for _, column := range columns {
builder = builder.Wheref(`%s.%s IS NOT NULL`, b.model.QuoteWord(prefix), b.model.QuoteWord(column))
}
return builder
}

View File

@ -0,0 +1,125 @@
// Copyright GoFrame Author(https://goframe.org). All Rights Reserved.
//
// This Source Code Form is subject to the terms of the MIT License.
// If a copy of the MIT was not distributed with this file,
// You can obtain one at https://github.com/gogf/gf.
package database
import (
"fmt"
"github.com/gogf/gf/v2/text/gstr"
)
// WhereOr adds "OR" condition to the where statement.
func (b *WhereBuilder) doWhereOrType(t string, where any, args ...any) *WhereBuilder {
where, args = b.convertWhereBuilder(where, args)
builder := b.getBuilder()
if builder.whereHolder == nil {
builder.whereHolder = make([]WhereHolder, 0)
}
builder.whereHolder = append(builder.whereHolder, WhereHolder{
Type: t,
Operator: whereHolderOperatorOr,
Where: where,
Args: args,
})
return builder
}
// WhereOrf builds `OR` condition string using fmt.Sprintf and arguments.
func (b *WhereBuilder) doWhereOrfType(t string, format string, args ...any) *WhereBuilder {
var (
placeHolderCount = gstr.Count(format, "?")
conditionStr = fmt.Sprintf(format, args[:len(args)-placeHolderCount]...)
)
return b.doWhereOrType(t, conditionStr, args[len(args)-placeHolderCount:]...)
}
// WhereOr adds "OR" condition to the where statement.
func (b *WhereBuilder) WhereOr(where any, args ...any) *WhereBuilder {
return b.doWhereOrType(``, where, args...)
}
// WhereOrf builds `OR` condition string using fmt.Sprintf and arguments.
// Eg:
// WhereOrf(`amount<? and status=%s`, "paid", 100) => WHERE xxx OR `amount`<100 and status='paid'
// WhereOrf(`amount<%d and status=%s`, 100, "paid") => WHERE xxx OR `amount`<100 and status='paid'
func (b *WhereBuilder) WhereOrf(format string, args ...any) *WhereBuilder {
return b.doWhereOrfType(``, format, args...)
}
// WhereOrNot builds `column != value` statement in `OR` conditions.
func (b *WhereBuilder) WhereOrNot(column string, value any) *WhereBuilder {
return b.WhereOrf(`%s != ?`, column, value)
}
// WhereOrLT builds `column < value` statement in `OR` conditions.
func (b *WhereBuilder) WhereOrLT(column string, value any) *WhereBuilder {
return b.WhereOrf(`%s < ?`, column, value)
}
// WhereOrLTE builds `column <= value` statement in `OR` conditions.
func (b *WhereBuilder) WhereOrLTE(column string, value any) *WhereBuilder {
return b.WhereOrf(`%s <= ?`, column, value)
}
// WhereOrGT builds `column > value` statement in `OR` conditions.
func (b *WhereBuilder) WhereOrGT(column string, value any) *WhereBuilder {
return b.WhereOrf(`%s > ?`, column, value)
}
// WhereOrGTE builds `column >= value` statement in `OR` conditions.
func (b *WhereBuilder) WhereOrGTE(column string, value any) *WhereBuilder {
return b.WhereOrf(`%s >= ?`, column, value)
}
// WhereOrBetween builds `column BETWEEN min AND max` statement in `OR` conditions.
func (b *WhereBuilder) WhereOrBetween(column string, min, max any) *WhereBuilder {
return b.WhereOrf(`%s BETWEEN ? AND ?`, b.model.QuoteWord(column), min, max)
}
// WhereOrLike builds `column LIKE 'like'` statement in `OR` conditions.
func (b *WhereBuilder) WhereOrLike(column string, like any) *WhereBuilder {
return b.WhereOrf(`%s LIKE ?`, b.model.QuoteWord(column), like)
}
// WhereOrIn builds `column IN (in)` statement in `OR` conditions.
func (b *WhereBuilder) WhereOrIn(column string, in any) *WhereBuilder {
return b.doWhereOrfType(whereHolderTypeIn, `%s IN (?)`, b.model.QuoteWord(column), in)
}
// WhereOrNull builds `columns[0] IS NULL OR columns[1] IS NULL ...` statement in `OR` conditions.
func (b *WhereBuilder) WhereOrNull(columns ...string) *WhereBuilder {
var builder *WhereBuilder
for _, column := range columns {
builder = b.WhereOrf(`%s IS NULL`, b.model.QuoteWord(column))
}
return builder
}
// WhereOrNotBetween builds `column NOT BETWEEN min AND max` statement in `OR` conditions.
func (b *WhereBuilder) WhereOrNotBetween(column string, min, max any) *WhereBuilder {
return b.WhereOrf(`%s NOT BETWEEN ? AND ?`, b.model.QuoteWord(column), min, max)
}
// WhereOrNotLike builds `column NOT LIKE like` statement in `OR` conditions.
func (b *WhereBuilder) WhereOrNotLike(column string, like any) *WhereBuilder {
return b.WhereOrf(`%s NOT LIKE ?`, b.model.QuoteWord(column), like)
}
// WhereOrNotIn builds `column NOT IN (in)` statement.
func (b *WhereBuilder) WhereOrNotIn(column string, in any) *WhereBuilder {
return b.doWhereOrfType(whereHolderTypeIn, `%s NOT IN (?)`, b.model.QuoteWord(column), in)
}
// WhereOrNotNull builds `columns[0] IS NOT NULL OR columns[1] IS NOT NULL ...` statement in `OR` conditions.
func (b *WhereBuilder) WhereOrNotNull(columns ...string) *WhereBuilder {
builder := b
for _, column := range columns {
builder = builder.WhereOrf(`%s IS NOT NULL`, b.model.QuoteWord(column))
}
return builder
}

View File

@ -0,0 +1,98 @@
// Copyright GoFrame Author(https://goframe.org). All Rights Reserved.
//
// This Source Code Form is subject to the terms of the MIT License.
// If a copy of the MIT was not distributed with this file,
// You can obtain one at https://github.com/gogf/gf.
package database
// WhereOrPrefix performs as WhereOr, but it adds prefix to each field in where statement.
// Eg:
// WhereOrPrefix("order", "status", "paid") => WHERE xxx OR (`order`.`status`='paid')
// WhereOrPrefix("order", struct{Status:"paid", "channel":"bank"}) => WHERE xxx OR (`order`.`status`='paid' AND `order`.`channel`='bank')
func (b *WhereBuilder) WhereOrPrefix(prefix string, where any, args ...any) *WhereBuilder {
where, args = b.convertWhereBuilder(where, args)
builder := b.getBuilder()
builder.whereHolder = append(builder.whereHolder, WhereHolder{
Type: whereHolderTypeDefault,
Operator: whereHolderOperatorOr,
Where: where,
Args: args,
Prefix: prefix,
})
return builder
}
// WhereOrPrefixNot builds `prefix.column != value` statement in `OR` conditions.
func (b *WhereBuilder) WhereOrPrefixNot(prefix string, column string, value any) *WhereBuilder {
return b.WhereOrf(`%s.%s != ?`, b.model.QuoteWord(prefix), b.model.QuoteWord(column), value)
}
// WhereOrPrefixLT builds `prefix.column < value` statement in `OR` conditions.
func (b *WhereBuilder) WhereOrPrefixLT(prefix string, column string, value any) *WhereBuilder {
return b.WhereOrf(`%s.%s < ?`, b.model.QuoteWord(prefix), b.model.QuoteWord(column), value)
}
// WhereOrPrefixLTE builds `prefix.column <= value` statement in `OR` conditions.
func (b *WhereBuilder) WhereOrPrefixLTE(prefix string, column string, value any) *WhereBuilder {
return b.WhereOrf(`%s.%s <= ?`, b.model.QuoteWord(prefix), b.model.QuoteWord(column), value)
}
// WhereOrPrefixGT builds `prefix.column > value` statement in `OR` conditions.
func (b *WhereBuilder) WhereOrPrefixGT(prefix string, column string, value any) *WhereBuilder {
return b.WhereOrf(`%s.%s > ?`, b.model.QuoteWord(prefix), b.model.QuoteWord(column), value)
}
// WhereOrPrefixGTE builds `prefix.column >= value` statement in `OR` conditions.
func (b *WhereBuilder) WhereOrPrefixGTE(prefix string, column string, value any) *WhereBuilder {
return b.WhereOrf(`%s.%s >= ?`, b.model.QuoteWord(prefix), b.model.QuoteWord(column), value)
}
// WhereOrPrefixBetween builds `prefix.column BETWEEN min AND max` statement in `OR` conditions.
func (b *WhereBuilder) WhereOrPrefixBetween(prefix string, column string, min, max any) *WhereBuilder {
return b.WhereOrf(`%s.%s BETWEEN ? AND ?`, b.model.QuoteWord(prefix), b.model.QuoteWord(column), min, max)
}
// WhereOrPrefixLike builds `prefix.column LIKE 'like'` statement in `OR` conditions.
func (b *WhereBuilder) WhereOrPrefixLike(prefix string, column string, like any) *WhereBuilder {
return b.WhereOrf(`%s.%s LIKE ?`, b.model.QuoteWord(prefix), b.model.QuoteWord(column), like)
}
// WhereOrPrefixIn builds `prefix.column IN (in)` statement in `OR` conditions.
func (b *WhereBuilder) WhereOrPrefixIn(prefix string, column string, in any) *WhereBuilder {
return b.doWhereOrfType(whereHolderTypeIn, `%s.%s IN (?)`, b.model.QuoteWord(prefix), b.model.QuoteWord(column), in)
}
// WhereOrPrefixNull builds `prefix.columns[0] IS NULL OR prefix.columns[1] IS NULL ...` statement in `OR` conditions.
func (b *WhereBuilder) WhereOrPrefixNull(prefix string, columns ...string) *WhereBuilder {
builder := b
for _, column := range columns {
builder = builder.WhereOrf(`%s.%s IS NULL`, b.model.QuoteWord(prefix), b.model.QuoteWord(column))
}
return builder
}
// WhereOrPrefixNotBetween builds `prefix.column NOT BETWEEN min AND max` statement in `OR` conditions.
func (b *WhereBuilder) WhereOrPrefixNotBetween(prefix string, column string, min, max any) *WhereBuilder {
return b.WhereOrf(`%s.%s NOT BETWEEN ? AND ?`, b.model.QuoteWord(prefix), b.model.QuoteWord(column), min, max)
}
// WhereOrPrefixNotLike builds `prefix.column NOT LIKE 'like'` statement in `OR` conditions.
func (b *WhereBuilder) WhereOrPrefixNotLike(prefix string, column string, like any) *WhereBuilder {
return b.WhereOrf(`%s.%s NOT LIKE ?`, b.model.QuoteWord(prefix), b.model.QuoteWord(column), like)
}
// WhereOrPrefixNotIn builds `prefix.column NOT IN (in)` statement.
func (b *WhereBuilder) WhereOrPrefixNotIn(prefix string, column string, in any) *WhereBuilder {
return b.doWhereOrfType(whereHolderTypeIn, `%s.%s NOT IN (?)`, b.model.QuoteWord(prefix), b.model.QuoteWord(column), in)
}
// WhereOrPrefixNotNull builds `prefix.columns[0] IS NOT NULL OR prefix.columns[1] IS NOT NULL ...` statement in `OR` conditions.
func (b *WhereBuilder) WhereOrPrefixNotNull(prefix string, columns ...string) *WhereBuilder {
builder := b
for _, column := range columns {
builder = builder.WhereOrf(`%s.%s IS NOT NULL`, b.model.QuoteWord(prefix), b.model.QuoteWord(column))
}
return builder
}

172
database/gdb_model_cache.go Normal file
View File

@ -0,0 +1,172 @@
// Copyright GoFrame Author(https://goframe.org). All Rights Reserved.
//
// This Source Code Form is subject to the terms of the MIT License.
// If a copy of the MIT was not distributed with this file,
// You can obtain one at https://github.com/gogf/gf.
package database
import (
"context"
"time"
"git.magicany.cc/black1552/gin-base/database/intlog"
)
// CacheOption is options for model cache control in query.
type CacheOption struct {
// Duration is the TTL for the cache.
// If the parameter `Duration` < 0, which means it clear the cache with given `Name`.
// If the parameter `Duration` = 0, which means it never expires.
// If the parameter `Duration` > 0, which means it expires after `Duration`.
Duration time.Duration
// Name is an optional unique name for the cache.
// The Name is used to bind a name to the cache, which means you can later control the cache
// like changing the `duration` or clearing the cache with specified Name.
Name string
// Force caches the query result whatever the result is nil or not.
// It is used to avoid Cache Penetration.
Force bool
}
// selectCacheItem is the cache item for SELECT statement result.
type selectCacheItem struct {
Result Result // Sql result of SELECT statement.
FirstResultColumn string // The first column name of result, for Value/Count functions.
}
// Cache sets the cache feature for the model. It caches the result of the sql, which means
// if there's another same sql request, it just reads and returns the result from cache, it
// but not committed and executed into the database.
//
// Note that, the cache feature is disabled if the model is performing select statement
// on a transaction.
func (m *Model) Cache(option CacheOption) *Model {
model := m.getModel()
model.cacheOption = option
model.cacheEnabled = true
return model
}
// PageCache sets the cache feature for pagination queries. It allows to configure
// separate cache options for count query and data query in pagination.
//
// Note that, the cache feature is disabled if the model is performing select statement
// on a transaction.
func (m *Model) PageCache(countOption CacheOption, dataOption CacheOption) *Model {
model := m.getModel()
model.pageCacheOption = []CacheOption{countOption, dataOption}
model.cacheEnabled = true
return model
}
// checkAndRemoveSelectCache checks and removes the cache in insert/update/delete statement if
// cache feature is enabled.
func (m *Model) checkAndRemoveSelectCache(ctx context.Context) {
if m.cacheEnabled && m.cacheOption.Duration < 0 && len(m.cacheOption.Name) > 0 {
var cacheKey = m.makeSelectCacheKey("")
if _, err := m.db.GetCache().Remove(ctx, cacheKey); err != nil {
intlog.Errorf(ctx, `%+v`, err)
}
}
}
func (m *Model) getSelectResultFromCache(ctx context.Context, sql string, args ...any) (result Result, err error) {
if !m.cacheEnabled || m.tx != nil {
return
}
var (
cacheItem *selectCacheItem
cacheKey = m.makeSelectCacheKey(sql, args...)
cacheObj = m.db.GetCache()
core = m.db.GetCore()
)
defer func() {
if cacheItem != nil {
if internalData := core.getInternalColumnFromCtx(ctx); internalData != nil {
if cacheItem.FirstResultColumn != "" {
internalData.FirstResultColumn = cacheItem.FirstResultColumn
}
}
}
}()
if v, _ := cacheObj.Get(ctx, cacheKey); !v.IsNil() {
if err = v.Scan(&cacheItem); err != nil {
return nil, err
}
return cacheItem.Result, nil
}
return
}
func (m *Model) saveSelectResultToCache(
ctx context.Context, selectType SelectType, result Result, sql string, args ...any,
) (err error) {
if !m.cacheEnabled || m.tx != nil {
return
}
var (
cacheKey = m.makeSelectCacheKey(sql, args...)
cacheObj = m.db.GetCache()
)
if m.cacheOption.Duration < 0 {
if _, errCache := cacheObj.Remove(ctx, cacheKey); errCache != nil {
intlog.Errorf(ctx, `%+v`, errCache)
}
return
}
// Special handler for Value/Count operations result.
if len(result) > 0 {
var core = m.db.GetCore()
switch selectType {
case SelectTypeValue, SelectTypeArray, SelectTypeCount:
if internalData := core.getInternalColumnFromCtx(ctx); internalData != nil {
if result[0][internalData.FirstResultColumn].IsEmpty() {
result = nil
}
}
default:
}
}
// In case of Cache Penetration.
if result != nil && result.IsEmpty() {
if m.cacheOption.Force {
result = Result{}
} else {
result = nil
}
}
var (
core = m.db.GetCore()
cacheItem = &selectCacheItem{
Result: result,
}
)
if internalData := core.getInternalColumnFromCtx(ctx); internalData != nil {
cacheItem.FirstResultColumn = internalData.FirstResultColumn
}
if errCache := cacheObj.Set(ctx, cacheKey, cacheItem, m.cacheOption.Duration); errCache != nil {
intlog.Errorf(ctx, `%+v`, errCache)
}
return
}
func (m *Model) makeSelectCacheKey(sql string, args ...any) string {
var (
table = m.db.GetCore().guessPrimaryTableName(m.tables)
group = m.db.GetGroup()
schema = m.db.GetSchema()
customName = m.cacheOption.Name
)
return genSelectCacheKey(
table,
group,
schema,
customName,
sql,
args...,
)
}

View File

@ -0,0 +1,87 @@
// Copyright GoFrame Author(https://goframe.org). All Rights Reserved.
//
// This Source Code Form is subject to the terms of the MIT License.
// If a copy of the MIT was not distributed with this file,
// You can obtain one at https://github.com/gogf/gf.
package database
import (
"database/sql"
"github.com/gogf/gf/v2/errors/gcode"
"github.com/gogf/gf/v2/errors/gerror"
"github.com/gogf/gf/v2/internal/intlog"
"github.com/gogf/gf/v2/text/gstr"
)
// Delete does "DELETE FROM ... " statement for the model.
// The optional parameter `where` is the same as the parameter of Model.Where function,
// see Model.Where.
func (m *Model) Delete(where ...any) (result sql.Result, err error) {
var ctx = m.GetCtx()
if len(where) > 0 {
return m.Where(where[0], where[1:]...).Delete()
}
defer func() {
if err == nil {
m.checkAndRemoveSelectCache(ctx)
}
}()
var (
conditionWhere, conditionExtra, conditionArgs = m.formatCondition(ctx, false, false)
conditionStr = conditionWhere + conditionExtra
fieldNameDelete, fieldTypeDelete = m.softTimeMaintainer().GetFieldInfo(ctx, "", m.tablesInit, SoftTimeFieldDelete)
)
if m.unscoped {
fieldNameDelete = ""
}
if !gstr.ContainsI(conditionStr, " WHERE ") || (fieldNameDelete != "" && !gstr.ContainsI(conditionStr, " AND ")) {
intlog.Printf(
ctx,
`sql condition string "%s" has no WHERE for DELETE operation, fieldNameDelete: %s`,
conditionStr, fieldNameDelete,
)
return nil, gerror.NewCode(
gcode.CodeMissingParameter,
"there should be WHERE condition statement for DELETE operation",
)
}
// Soft deleting.
if fieldNameDelete != "" {
dataHolder, dataValue := m.softTimeMaintainer().GetDeleteData(
ctx, "", fieldNameDelete, fieldTypeDelete,
)
in := &HookUpdateInput{
internalParamHookUpdate: internalParamHookUpdate{
internalParamHook: internalParamHook{
link: m.getLink(true),
},
handler: m.hookHandler.Update,
},
Model: m,
Table: m.tables,
Schema: m.schema,
Data: dataHolder,
Condition: conditionStr,
Args: append([]any{dataValue}, conditionArgs...),
}
return in.Next(ctx)
}
in := &HookDeleteInput{
internalParamHookDelete: internalParamHookDelete{
internalParamHook: internalParamHook{
link: m.getLink(true),
},
handler: m.hookHandler.Delete,
},
Model: m,
Table: m.tables,
Schema: m.schema,
Condition: conditionStr,
Args: conditionArgs,
}
return in.Next(ctx)
}

View File

@ -0,0 +1,283 @@
// Copyright GoFrame Author(https://goframe.org). All Rights Reserved.
//
// This Source Code Form is subject to the terms of the MIT License.
// If a copy of the MIT was not distributed with this file,
// You can obtain one at https://github.com/gogf/gf.
package database
import (
"fmt"
"github.com/gogf/gf/v2/container/gset"
"github.com/gogf/gf/v2/errors/gerror"
"github.com/gogf/gf/v2/text/gstr"
"github.com/gogf/gf/v2/util/gconv"
)
// Fields appends `fieldNamesOrMapStruct` to the operation fields of the model, multiple fields joined using char ','.
// The parameter `fieldNamesOrMapStruct` can be type of string/map/*map/struct/*struct.
//
// Example:
// Fields("id", "name", "age")
// Fields([]string{"id", "name", "age"})
// Fields(map[string]any{"id":1, "name":"john", "age":18})
// Fields(User{Id: 1, Name: "john", Age: 18}).
func (m *Model) Fields(fieldNamesOrMapStruct ...any) *Model {
length := len(fieldNamesOrMapStruct)
if length == 0 {
return m
}
fields := m.filterFieldsFrom(m.tablesInit, fieldNamesOrMapStruct...)
if len(fields) == 0 {
return m
}
model := m.getModel()
return model.appendToFields(fields...)
}
// FieldsPrefix performs as function Fields but add extra prefix for each field.
func (m *Model) FieldsPrefix(prefixOrAlias string, fieldNamesOrMapStruct ...any) *Model {
fields := m.filterFieldsFrom(
m.getTableNameByPrefixOrAlias(prefixOrAlias),
fieldNamesOrMapStruct...,
)
if len(fields) == 0 {
return m
}
prefixOrAlias = m.QuoteWord(prefixOrAlias)
for i, field := range fields {
fields[i] = fmt.Sprintf("%s.%s", prefixOrAlias, m.QuoteWord(gconv.String(field)))
}
model := m.getModel()
return model.appendToFields(fields...)
}
// FieldsEx appends `fieldNamesOrMapStruct` to the excluded operation fields of the model,
// multiple fields joined using char ','.
// Note that this function supports only single table operations.
// The parameter `fieldNamesOrMapStruct` can be type of string/map/*map/struct/*struct.
//
// Example:
// FieldsEx("id", "name", "age")
// FieldsEx([]string{"id", "name", "age"})
// FieldsEx(map[string]any{"id":1, "name":"john", "age":18})
// FieldsEx(User{Id: 1, Name: "john", Age: 18}).
func (m *Model) FieldsEx(fieldNamesOrMapStruct ...any) *Model {
return m.doFieldsEx(m.tablesInit, fieldNamesOrMapStruct...)
}
func (m *Model) doFieldsEx(table string, fieldNamesOrMapStruct ...any) *Model {
length := len(fieldNamesOrMapStruct)
if length == 0 {
return m
}
fields := m.filterFieldsFrom(table, fieldNamesOrMapStruct...)
if len(fields) == 0 {
return m
}
model := m.getModel()
model.fieldsEx = append(model.fieldsEx, fields...)
return model
}
// FieldsExPrefix performs as function FieldsEx but add extra prefix for each field.
// Note that this function must be used together with FieldsPrefix, otherwise it will be invalid.
func (m *Model) FieldsExPrefix(prefixOrAlias string, fieldNamesOrMapStruct ...any) *Model {
fields := m.filterFieldsFrom(
m.getTableNameByPrefixOrAlias(prefixOrAlias),
fieldNamesOrMapStruct...,
)
if len(fields) == 0 {
return m
}
prefixOrAlias = m.QuoteWord(prefixOrAlias)
for i, field := range fields {
fields[i] = fmt.Sprintf("%s.%s", prefixOrAlias, m.QuoteWord(gconv.String(field)))
}
model := m.getModel()
model.fieldsEx = append(model.fieldsEx, fields...)
return model
}
// FieldCount formats and appends commonly used field `COUNT(column)` to the select fields of model.
func (m *Model) FieldCount(column string, as ...string) *Model {
asStr := ""
if len(as) > 0 && as[0] != "" {
asStr = fmt.Sprintf(` AS %s`, m.QuoteWord(as[0]))
}
model := m.getModel()
return model.appendToFields(
fmt.Sprintf(`COUNT(%s)%s`, m.QuoteWord(column), asStr),
)
}
// FieldSum formats and appends commonly used field `SUM(column)` to the select fields of model.
func (m *Model) FieldSum(column string, as ...string) *Model {
asStr := ""
if len(as) > 0 && as[0] != "" {
asStr = fmt.Sprintf(` AS %s`, m.QuoteWord(as[0]))
}
model := m.getModel()
return model.appendToFields(
fmt.Sprintf(`SUM(%s)%s`, m.QuoteWord(column), asStr),
)
}
// FieldMin formats and appends commonly used field `MIN(column)` to the select fields of model.
func (m *Model) FieldMin(column string, as ...string) *Model {
asStr := ""
if len(as) > 0 && as[0] != "" {
asStr = fmt.Sprintf(` AS %s`, m.QuoteWord(as[0]))
}
model := m.getModel()
return model.appendToFields(
fmt.Sprintf(`MIN(%s)%s`, m.QuoteWord(column), asStr),
)
}
// FieldMax formats and appends commonly used field `MAX(column)` to the select fields of model.
func (m *Model) FieldMax(column string, as ...string) *Model {
asStr := ""
if len(as) > 0 && as[0] != "" {
asStr = fmt.Sprintf(` AS %s`, m.QuoteWord(as[0]))
}
model := m.getModel()
return model.appendToFields(
fmt.Sprintf(`MAX(%s)%s`, m.QuoteWord(column), asStr),
)
}
// FieldAvg formats and appends commonly used field `AVG(column)` to the select fields of model.
func (m *Model) FieldAvg(column string, as ...string) *Model {
asStr := ""
if len(as) > 0 && as[0] != "" {
asStr = fmt.Sprintf(` AS %s`, m.QuoteWord(as[0]))
}
model := m.getModel()
return model.appendToFields(
fmt.Sprintf(`AVG(%s)%s`, m.QuoteWord(column), asStr),
)
}
// GetFieldsStr retrieves and returns all fields from the table, joined with char ','.
// The optional parameter `prefix` specifies the prefix for each field, eg: GetFieldsStr("u.").
func (m *Model) GetFieldsStr(prefix ...string) string {
prefixStr := ""
if len(prefix) > 0 {
prefixStr = prefix[0]
}
tableFields, err := m.TableFields(m.tablesInit)
if err != nil {
panic(err)
}
if len(tableFields) == 0 {
panic(fmt.Sprintf(`empty table fields for table "%s"`, m.tables))
}
fieldsArray := make([]string, len(tableFields))
for k, v := range tableFields {
fieldsArray[v.Index] = k
}
newFields := ""
for _, k := range fieldsArray {
if len(newFields) > 0 {
newFields += ","
}
newFields += prefixStr + k
}
newFields = m.db.GetCore().QuoteString(newFields)
return newFields
}
// GetFieldsExStr retrieves and returns fields which are not in parameter `fields` from the table,
// joined with char ','.
// The parameter `fields` specifies the fields that are excluded.
// The optional parameter `prefix` specifies the prefix for each field, eg: FieldsExStr("id", "u.").
func (m *Model) GetFieldsExStr(fields string, prefix ...string) (string, error) {
prefixStr := ""
if len(prefix) > 0 {
prefixStr = prefix[0]
}
tableFields, err := m.TableFields(m.tablesInit)
if err != nil {
return "", err
}
if len(tableFields) == 0 {
return "", gerror.Newf(`empty table fields for table "%s"`, m.tables)
}
fieldsExSet := gset.NewStrSetFrom(gstr.SplitAndTrim(fields, ","))
fieldsArray := make([]string, len(tableFields))
for k, v := range tableFields {
fieldsArray[v.Index] = k
}
newFields := ""
for _, k := range fieldsArray {
if fieldsExSet.Contains(k) {
continue
}
if len(newFields) > 0 {
newFields += ","
}
newFields += prefixStr + k
}
newFields = m.db.GetCore().QuoteString(newFields)
return newFields, nil
}
// HasField determine whether the field exists in the table.
func (m *Model) HasField(field string) (bool, error) {
return m.db.GetCore().HasField(m.GetCtx(), m.tablesInit, field)
}
// getFieldsFrom retrieves, filters and returns fields name from table `table`.
func (m *Model) filterFieldsFrom(table string, fieldNamesOrMapStruct ...any) []any {
length := len(fieldNamesOrMapStruct)
if length == 0 {
return nil
}
switch {
// String slice.
case length >= 2:
return m.mappingAndFilterToTableFields(
table, fieldNamesOrMapStruct, true,
)
// It needs type asserting.
case length == 1:
structOrMap := fieldNamesOrMapStruct[0]
switch r := structOrMap.(type) {
case string:
return m.mappingAndFilterToTableFields(table, []any{r}, false)
case []string:
return m.mappingAndFilterToTableFields(table, gconv.Interfaces(r), true)
case Raw, *Raw:
return []any{structOrMap}
default:
return m.mappingAndFilterToTableFields(table, getFieldsFromStructOrMap(structOrMap), true)
}
default:
return nil
}
}
func (m *Model) appendToFields(fields ...any) *Model {
if len(fields) == 0 {
return m
}
model := m.getModel()
model.fields = append(model.fields, fields...)
return model
}
func (m *Model) isFieldInFieldsEx(field string) bool {
for _, v := range m.fieldsEx {
if v == field {
return true
}
}
return false
}

309
database/gdb_model_hook.go Normal file
View File

@ -0,0 +1,309 @@
// Copyright GoFrame Author(https://goframe.org). All Rights Reserved.
//
// This Source Code Form is subject to the terms of the MIT License.
// If a copy of the MIT was not distributed with this file,
// You can obtain one at https://github.com/gogf/gf.
package database
import (
"context"
"database/sql"
"fmt"
"github.com/gogf/gf/v2/container/gvar"
"github.com/gogf/gf/v2/text/gregex"
"github.com/gogf/gf/v2/text/gstr"
)
type (
HookFuncSelect func(ctx context.Context, in *HookSelectInput) (result Result, err error)
HookFuncInsert func(ctx context.Context, in *HookInsertInput) (result sql.Result, err error)
HookFuncUpdate func(ctx context.Context, in *HookUpdateInput) (result sql.Result, err error)
HookFuncDelete func(ctx context.Context, in *HookDeleteInput) (result sql.Result, err error)
)
// HookHandler manages all supported hook functions for Model.
type HookHandler struct {
Select HookFuncSelect
Insert HookFuncInsert
Update HookFuncUpdate
Delete HookFuncDelete
}
// internalParamHook manages all internal parameters for hook operations.
// The `internal` obviously means you cannot access these parameters outside this package.
type internalParamHook struct {
link Link // Connection object from third party sql driver.
handlerCalled bool // Simple mark for custom handler called, in case of recursive calling.
removedWhere bool // Removed mark for condition string that was removed `WHERE` prefix.
originalTableName *gvar.Var // The original table name.
originalSchemaName *gvar.Var // The original schema name.
}
type internalParamHookSelect struct {
internalParamHook
handler HookFuncSelect
}
type internalParamHookInsert struct {
internalParamHook
handler HookFuncInsert
}
type internalParamHookUpdate struct {
internalParamHook
handler HookFuncUpdate
}
type internalParamHookDelete struct {
internalParamHook
handler HookFuncDelete
}
// HookSelectInput holds the parameters for select hook operation.
// Note that, COUNT statement will also be hooked by this feature,
// which is usually not be interesting for upper business hook handler.
type HookSelectInput struct {
internalParamHookSelect
Model *Model // Current operation Model.
Table string // The table name that to be used. Update this attribute to change target table name.
Schema string // The schema name that to be used. Update this attribute to change target schema name.
Sql string // The sql string that to be committed.
Args []any // The arguments of sql.
SelectType SelectType // The type of this SELECT operation.
}
// HookInsertInput holds the parameters for insert hook operation.
type HookInsertInput struct {
internalParamHookInsert
Model *Model // Current operation Model.
Table string // The table name that to be used. Update this attribute to change target table name.
Schema string // The schema name that to be used. Update this attribute to change target schema name.
Data List // The data records list to be inserted/saved into table.
Option DoInsertOption // The extra option for data inserting.
}
// HookUpdateInput holds the parameters for update hook operation.
type HookUpdateInput struct {
internalParamHookUpdate
Model *Model // Current operation Model.
Table string // The table name that to be used. Update this attribute to change target table name.
Schema string // The schema name that to be used. Update this attribute to change target schema name.
Data any // Data can be type of: map[string]any/string. You can use type assertion on `Data`.
Condition string // The where condition string for updating.
Args []any // The arguments for sql place-holders.
}
// HookDeleteInput holds the parameters for delete hook operation.
type HookDeleteInput struct {
internalParamHookDelete
Model *Model // Current operation Model.
Table string // The table name that to be used. Update this attribute to change target table name.
Schema string // The schema name that to be used. Update this attribute to change target schema name.
Condition string // The where condition string for deleting.
Args []any // The arguments for sql place-holders.
}
const (
whereKeyInCondition = " WHERE "
)
// IsTransaction checks and returns whether current operation is during transaction.
func (h *internalParamHook) IsTransaction() bool {
return h.link.IsTransaction()
}
// Next calls the next hook handler.
func (h *HookSelectInput) Next(ctx context.Context) (result Result, err error) {
if h.originalTableName.IsNil() {
h.originalTableName = gvar.New(h.Table)
}
if h.originalSchemaName.IsNil() {
h.originalSchemaName = gvar.New(h.Schema)
}
// Sharding feature.
h.Schema, err = h.Model.getActualSchema(ctx, h.Schema)
if err != nil {
return nil, err
}
h.Table, err = h.Model.getActualTable(ctx, h.Table)
if err != nil {
return nil, err
}
// Custom hook handler call.
if h.handler != nil && !h.handlerCalled {
h.handlerCalled = true
return h.handler(ctx, h)
}
var toBeCommittedSql = h.Sql
// Table change.
if h.Table != h.originalTableName.String() {
toBeCommittedSql, err = gregex.ReplaceStringFuncMatch(
`(?i) FROM ([\S]+)`,
toBeCommittedSql,
func(match []string) string {
charL, charR := h.Model.db.GetChars()
return fmt.Sprintf(` FROM %s%s%s`, charL, h.Table, charR)
},
)
if err != nil {
return
}
}
// Schema change.
if h.Schema != "" && h.Schema != h.originalSchemaName.String() {
h.link, err = h.Model.db.GetCore().SlaveLink(h.Schema)
if err != nil {
return
}
h.Model.db.GetCore().schema = h.Schema
defer func() {
h.Model.db.GetCore().schema = h.originalSchemaName.String()
}()
}
return h.Model.db.DoSelect(ctx, h.link, toBeCommittedSql, h.Args...)
}
// Next calls the next hook handler.
func (h *HookInsertInput) Next(ctx context.Context) (result sql.Result, err error) {
if h.originalTableName.IsNil() {
h.originalTableName = gvar.New(h.Table)
}
if h.originalSchemaName.IsNil() {
h.originalSchemaName = gvar.New(h.Schema)
}
// Sharding feature.
h.Schema, err = h.Model.getActualSchema(ctx, h.Schema)
if err != nil {
return nil, err
}
h.Table, err = h.Model.getActualTable(ctx, h.Table)
if err != nil {
return nil, err
}
if h.handler != nil && !h.handlerCalled {
h.handlerCalled = true
return h.handler(ctx, h)
}
// No need to handle table change.
// Schema change.
if h.Schema != "" && h.Schema != h.originalSchemaName.String() {
h.link, err = h.Model.db.GetCore().MasterLink(h.Schema)
if err != nil {
return
}
h.Model.db.GetCore().schema = h.Schema
defer func() {
h.Model.db.GetCore().schema = h.originalSchemaName.String()
}()
}
return h.Model.db.DoInsert(ctx, h.link, h.Table, h.Data, h.Option)
}
// Next calls the next hook handler.
func (h *HookUpdateInput) Next(ctx context.Context) (result sql.Result, err error) {
if h.originalTableName.IsNil() {
h.originalTableName = gvar.New(h.Table)
}
if h.originalSchemaName.IsNil() {
h.originalSchemaName = gvar.New(h.Schema)
}
// Sharding feature.
h.Schema, err = h.Model.getActualSchema(ctx, h.Schema)
if err != nil {
return nil, err
}
h.Table, err = h.Model.getActualTable(ctx, h.Table)
if err != nil {
return nil, err
}
if h.handler != nil && !h.handlerCalled {
h.handlerCalled = true
if gstr.HasPrefix(h.Condition, whereKeyInCondition) {
h.removedWhere = true
h.Condition = gstr.TrimLeftStr(h.Condition, whereKeyInCondition)
}
return h.handler(ctx, h)
}
if h.removedWhere {
h.Condition = whereKeyInCondition + h.Condition
}
// No need to handle table change.
// Schema change.
if h.Schema != "" && h.Schema != h.originalSchemaName.String() {
h.link, err = h.Model.db.GetCore().MasterLink(h.Schema)
if err != nil {
return
}
h.Model.db.GetCore().schema = h.Schema
defer func() {
h.Model.db.GetCore().schema = h.originalSchemaName.String()
}()
}
return h.Model.db.DoUpdate(ctx, h.link, h.Table, h.Data, h.Condition, h.Args...)
}
// Next calls the next hook handler.
func (h *HookDeleteInput) Next(ctx context.Context) (result sql.Result, err error) {
if h.originalTableName.IsNil() {
h.originalTableName = gvar.New(h.Table)
}
if h.originalSchemaName.IsNil() {
h.originalSchemaName = gvar.New(h.Schema)
}
// Sharding feature.
h.Schema, err = h.Model.getActualSchema(ctx, h.Schema)
if err != nil {
return nil, err
}
h.Table, err = h.Model.getActualTable(ctx, h.Table)
if err != nil {
return nil, err
}
if h.handler != nil && !h.handlerCalled {
h.handlerCalled = true
if gstr.HasPrefix(h.Condition, whereKeyInCondition) {
h.removedWhere = true
h.Condition = gstr.TrimLeftStr(h.Condition, whereKeyInCondition)
}
return h.handler(ctx, h)
}
if h.removedWhere {
h.Condition = whereKeyInCondition + h.Condition
}
// No need to handle table change.
// Schema change.
if h.Schema != "" && h.Schema != h.originalSchemaName.String() {
h.link, err = h.Model.db.GetCore().MasterLink(h.Schema)
if err != nil {
return
}
h.Model.db.GetCore().schema = h.Schema
defer func() {
h.Model.db.GetCore().schema = h.originalSchemaName.String()
}()
}
return h.Model.db.DoDelete(ctx, h.link, h.Table, h.Condition, h.Args...)
}
// Hook sets the hook functions for current model.
func (m *Model) Hook(hook HookHandler) *Model {
model := m.getModel()
model.hookHandler = hook
return model
}

View File

@ -0,0 +1,469 @@
// Copyright GoFrame Author(https://goframe.org). All Rights Reserved.
//
// This Source Code Form is subject to the terms of the MIT License.
// If a copy of the MIT was not distributed with this file,
// You can obtain one at https://github.com/gogf/gf.
package database
import (
"context"
"database/sql"
"reflect"
"git.magicany.cc/black1552/gin-base/database/empty"
"git.magicany.cc/black1552/gin-base/database/reflection"
"github.com/gogf/gf/v2/container/gset"
"github.com/gogf/gf/v2/errors/gcode"
"github.com/gogf/gf/v2/errors/gerror"
"github.com/gogf/gf/v2/text/gstr"
"github.com/gogf/gf/v2/util/gconv"
"github.com/gogf/gf/v2/util/gutil"
)
// Batch sets the batch operation number for the model.
func (m *Model) Batch(batch int) *Model {
model := m.getModel()
model.batch = batch
return model
}
// Data sets the operation data for the model.
// The parameter `data` can be type of string/map/gmap/slice/struct/*struct, etc.
// Note that, it uses shallow value copying for `data` if `data` is type of map/slice
// to avoid changing it inside function.
// Eg:
// Data("uid=10000")
// Data("uid", 10000)
// Data("uid=? AND name=?", 10000, "john")
// Data(g.Map{"uid": 10000, "name":"john"})
// Data(g.Slice{g.Map{"uid": 10000, "name":"john"}, g.Map{"uid": 20000, "name":"smith"}).
func (m *Model) Data(data ...any) *Model {
var model = m.getModel()
if len(data) > 1 {
if s := gconv.String(data[0]); gstr.Contains(s, "?") {
model.data = s
model.extraArgs = data[1:]
} else {
newData := make(map[string]any)
for i := 0; i < len(data); i += 2 {
newData[gconv.String(data[i])] = data[i+1]
}
model.data = newData
}
} else if len(data) == 1 {
switch value := data[0].(type) {
case Result:
model.data = value.List()
case Record:
model.data = value.Map()
case List:
list := make(List, len(value))
for k, v := range value {
list[k] = gutil.MapCopy(v)
}
model.data = list
case Map:
model.data = gutil.MapCopy(value)
default:
reflectInfo := reflection.OriginValueAndKind(value)
switch reflectInfo.OriginKind {
case reflect.Slice, reflect.Array:
if reflectInfo.OriginValue.Len() > 0 {
// If the `data` parameter is a DO struct,
// it then adds `OmitNilData` option for this condition,
// which will filter all nil parameters in `data`.
if isDoStruct(reflectInfo.OriginValue.Index(0).Interface()) {
model = model.OmitNilData()
model.option |= optionOmitNilDataInternal
}
}
list := make(List, reflectInfo.OriginValue.Len())
for i := 0; i < reflectInfo.OriginValue.Len(); i++ {
list[i] = anyValueToMapBeforeToRecord(reflectInfo.OriginValue.Index(i).Interface())
}
model.data = list
case reflect.Struct:
// If the `data` parameter is a DO struct,
// it then adds `OmitNilData` option for this condition,
// which will filter all nil parameters in `data`.
if isDoStruct(value) {
model = model.OmitNilData()
}
if v, ok := data[0].(iInterfaces); ok {
var (
array = v.Interfaces()
list = make(List, len(array))
)
for i := 0; i < len(array); i++ {
list[i] = anyValueToMapBeforeToRecord(array[i])
}
model.data = list
} else {
model.data = anyValueToMapBeforeToRecord(data[0])
}
case reflect.Map:
model.data = anyValueToMapBeforeToRecord(data[0])
default:
model.data = data[0]
}
}
}
return model
}
// OnConflict sets the primary key or index when columns conflicts occurs.
// It's not necessary for MySQL driver.
func (m *Model) OnConflict(onConflict ...any) *Model {
if len(onConflict) == 0 {
return m
}
model := m.getModel()
if len(onConflict) > 1 {
model.onConflict = onConflict
} else if len(onConflict) == 1 {
model.onConflict = onConflict[0]
}
return model
}
// OnDuplicate sets the operations when columns conflicts occurs.
// In MySQL, this is used for "ON DUPLICATE KEY UPDATE" statement.
// In PgSQL, this is used for "ON CONFLICT (id) DO UPDATE SET" statement.
// The parameter `onDuplicate` can be type of string/Raw/*Raw/map/slice.
// Example:
//
// OnDuplicate("nickname, age")
// OnDuplicate("nickname", "age")
//
// OnDuplicate(g.Map{
// "nickname": gdb.Raw("CONCAT('name_', VALUES(`nickname`))"),
// })
//
// OnDuplicate(g.Map{
// "nickname": "passport",
// }).
func (m *Model) OnDuplicate(onDuplicate ...any) *Model {
if len(onDuplicate) == 0 {
return m
}
model := m.getModel()
if len(onDuplicate) > 1 {
model.onDuplicate = onDuplicate
} else if len(onDuplicate) == 1 {
model.onDuplicate = onDuplicate[0]
}
return model
}
// OnDuplicateEx sets the excluding columns for operations when columns conflict occurs.
// In MySQL, this is used for "ON DUPLICATE KEY UPDATE" statement.
// In PgSQL, this is used for "ON CONFLICT (id) DO UPDATE SET" statement.
// The parameter `onDuplicateEx` can be type of string/map/slice.
// Example:
//
// OnDuplicateEx("passport, password")
// OnDuplicateEx("passport", "password")
//
// OnDuplicateEx(g.Map{
// "passport": "",
// "password": "",
// }).
func (m *Model) OnDuplicateEx(onDuplicateEx ...any) *Model {
if len(onDuplicateEx) == 0 {
return m
}
model := m.getModel()
if len(onDuplicateEx) > 1 {
model.onDuplicateEx = onDuplicateEx
} else if len(onDuplicateEx) == 1 {
model.onDuplicateEx = onDuplicateEx[0]
}
return model
}
// Insert does "INSERT INTO ..." statement for the model.
// The optional parameter `data` is the same as the parameter of Model.Data function,
// see Model.Data.
func (m *Model) Insert(data ...any) (result sql.Result, err error) {
var ctx = m.GetCtx()
if len(data) > 0 {
return m.Data(data...).Insert()
}
return m.doInsertWithOption(ctx, InsertOptionDefault)
}
// InsertAndGetId performs action Insert and returns the last insert id that automatically generated.
func (m *Model) InsertAndGetId(data ...any) (lastInsertId int64, err error) {
var ctx = m.GetCtx()
if len(data) > 0 {
return m.Data(data...).InsertAndGetId()
}
result, err := m.doInsertWithOption(ctx, InsertOptionDefault)
if err != nil {
return 0, err
}
return result.LastInsertId()
}
// InsertIgnore does "INSERT IGNORE INTO ..." statement for the model.
// The optional parameter `data` is the same as the parameter of Model.Data function,
// see Model.Data.
func (m *Model) InsertIgnore(data ...any) (result sql.Result, err error) {
var ctx = m.GetCtx()
if len(data) > 0 {
return m.Data(data...).InsertIgnore()
}
return m.doInsertWithOption(ctx, InsertOptionIgnore)
}
// Replace does "REPLACE INTO ..." statement for the model.
// The optional parameter `data` is the same as the parameter of Model.Data function,
// see Model.Data.
func (m *Model) Replace(data ...any) (result sql.Result, err error) {
var ctx = m.GetCtx()
if len(data) > 0 {
return m.Data(data...).Replace()
}
return m.doInsertWithOption(ctx, InsertOptionReplace)
}
// Save does "INSERT INTO ... ON DUPLICATE KEY UPDATE..." statement for the model.
// The optional parameter `data` is the same as the parameter of Model.Data function,
// see Model.Data.
//
// It updates the record if there's primary or unique index in the saving data,
// or else it inserts a new record into the table.
func (m *Model) Save(data ...any) (result sql.Result, err error) {
var ctx = m.GetCtx()
if len(data) > 0 {
return m.Data(data...).Save()
}
return m.doInsertWithOption(ctx, InsertOptionSave)
}
// doInsertWithOption inserts data with option parameter.
func (m *Model) doInsertWithOption(ctx context.Context, insertOption InsertOption) (result sql.Result, err error) {
defer func() {
if err == nil {
m.checkAndRemoveSelectCache(ctx)
}
}()
if m.data == nil {
return nil, gerror.NewCode(gcode.CodeMissingParameter, "inserting into table with empty data")
}
var (
list List
stm = m.softTimeMaintainer()
fieldNameCreate, fieldTypeCreate = stm.GetFieldInfo(ctx, "", m.tablesInit, SoftTimeFieldCreate)
fieldNameUpdate, fieldTypeUpdate = stm.GetFieldInfo(ctx, "", m.tablesInit, SoftTimeFieldUpdate)
fieldNameDelete, fieldTypeDelete = stm.GetFieldInfo(ctx, "", m.tablesInit, SoftTimeFieldDelete)
)
// m.data was already converted to type List/Map by function Data
newData, err := m.filterDataForInsertOrUpdate(m.data)
if err != nil {
return nil, err
}
// It converts any data to List type for inserting.
switch value := newData.(type) {
case List:
list = value
case Map:
list = List{value}
}
if len(list) < 1 {
return result, gerror.NewCode(gcode.CodeMissingParameter, "data list cannot be empty")
}
// Automatic handling for creating/updating time.
if fieldNameCreate != "" && m.isFieldInFieldsEx(fieldNameCreate) {
fieldNameCreate = ""
}
if fieldNameUpdate != "" && m.isFieldInFieldsEx(fieldNameUpdate) {
fieldNameUpdate = ""
}
var isSoftTimeFeatureEnabled = fieldNameCreate != "" || fieldNameUpdate != ""
if !m.unscoped && isSoftTimeFeatureEnabled {
for k, v := range list {
if fieldNameCreate != "" && empty.IsNil(v[fieldNameCreate]) {
fieldCreateValue := stm.GetFieldValue(ctx, fieldTypeCreate, false)
if fieldCreateValue != nil {
v[fieldNameCreate] = fieldCreateValue
}
}
if fieldNameUpdate != "" && empty.IsNil(v[fieldNameUpdate]) {
fieldUpdateValue := stm.GetFieldValue(ctx, fieldTypeUpdate, false)
if fieldUpdateValue != nil {
v[fieldNameUpdate] = fieldUpdateValue
}
}
// for timestamp field that should initialize the delete_at field with value, for example 0.
if fieldNameDelete != "" && empty.IsNil(v[fieldNameDelete]) {
fieldDeleteValue := stm.GetFieldValue(ctx, fieldTypeDelete, true)
if fieldDeleteValue != nil {
v[fieldNameDelete] = fieldDeleteValue
}
}
list[k] = v
}
}
// Format DoInsertOption, especially for "ON DUPLICATE KEY UPDATE" statement.
columnNames := make([]string, 0, len(list[0]))
for k := range list[0] {
columnNames = append(columnNames, k)
}
doInsertOption, err := m.formatDoInsertOption(insertOption, columnNames)
if err != nil {
return result, err
}
in := &HookInsertInput{
internalParamHookInsert: internalParamHookInsert{
internalParamHook: internalParamHook{
link: m.getLink(true),
},
handler: m.hookHandler.Insert,
},
Model: m,
Table: m.tables,
Schema: m.schema,
Data: list,
Option: doInsertOption,
}
return in.Next(ctx)
}
func (m *Model) formatDoInsertOption(insertOption InsertOption, columnNames []string) (option DoInsertOption, err error) {
option = DoInsertOption{
InsertOption: insertOption,
BatchCount: m.getBatch(),
}
if insertOption != InsertOptionSave {
return
}
onConflictKeys, err := m.formatOnConflictKeys(m.onConflict)
if err != nil {
return option, err
}
option.OnConflict = onConflictKeys
onDuplicateExKeys, err := m.formatOnDuplicateExKeys(m.onDuplicateEx)
if err != nil {
return option, err
}
onDuplicateExKeySet := gset.NewStrSetFrom(onDuplicateExKeys)
if m.onDuplicate != nil {
switch m.onDuplicate.(type) {
case Raw, *Raw:
option.OnDuplicateStr = gconv.String(m.onDuplicate)
default:
reflectInfo := reflection.OriginValueAndKind(m.onDuplicate)
switch reflectInfo.OriginKind {
case reflect.String:
option.OnDuplicateMap = make(map[string]any)
for _, v := range gstr.SplitAndTrim(reflectInfo.OriginValue.String(), ",") {
if onDuplicateExKeySet.Contains(v) {
continue
}
option.OnDuplicateMap[v] = v
}
case reflect.Map:
option.OnDuplicateMap = make(map[string]any)
for k, v := range gconv.Map(m.onDuplicate) {
if onDuplicateExKeySet.Contains(k) {
continue
}
option.OnDuplicateMap[k] = v
}
case reflect.Slice, reflect.Array:
option.OnDuplicateMap = make(map[string]any)
for _, v := range gconv.Strings(m.onDuplicate) {
if onDuplicateExKeySet.Contains(v) {
continue
}
option.OnDuplicateMap[v] = v
}
default:
return option, gerror.NewCodef(
gcode.CodeInvalidParameter,
`unsupported OnDuplicate parameter type "%s"`,
reflect.TypeOf(m.onDuplicate),
)
}
}
} else if onDuplicateExKeySet.Size() > 0 {
option.OnDuplicateMap = make(map[string]any)
for _, v := range columnNames {
if onDuplicateExKeySet.Contains(v) {
continue
}
option.OnDuplicateMap[v] = v
}
}
return
}
func (m *Model) formatOnDuplicateExKeys(onDuplicateEx any) ([]string, error) {
if onDuplicateEx == nil {
return nil, nil
}
reflectInfo := reflection.OriginValueAndKind(onDuplicateEx)
switch reflectInfo.OriginKind {
case reflect.String:
return gstr.SplitAndTrim(reflectInfo.OriginValue.String(), ","), nil
case reflect.Map:
return gutil.Keys(onDuplicateEx), nil
case reflect.Slice, reflect.Array:
return gconv.Strings(onDuplicateEx), nil
default:
return nil, gerror.NewCodef(
gcode.CodeInvalidParameter,
`unsupported OnDuplicateEx parameter type "%s"`,
reflect.TypeOf(onDuplicateEx),
)
}
}
func (m *Model) formatOnConflictKeys(onConflict any) ([]string, error) {
if onConflict == nil {
return nil, nil
}
reflectInfo := reflection.OriginValueAndKind(onConflict)
switch reflectInfo.OriginKind {
case reflect.String:
return gstr.SplitAndTrim(reflectInfo.OriginValue.String(), ","), nil
case reflect.Slice, reflect.Array:
return gconv.Strings(onConflict), nil
default:
return nil, gerror.NewCodef(
gcode.CodeInvalidParameter,
`unsupported onConflict parameter type "%s"`,
reflect.TypeOf(onConflict),
)
}
}
func (m *Model) getBatch() int {
return m.batch
}

223
database/gdb_model_join.go Normal file
View File

@ -0,0 +1,223 @@
// Copyright GoFrame Author(https://goframe.org). All Rights Reserved.
//
// This Source Code Form is subject to the terms of the MIT License.
// If a copy of the MIT was not distributed with this file,
// You can obtain one at https://github.com/gogf/gf.
package database
import (
"fmt"
"github.com/gogf/gf/v2/text/gstr"
)
// LeftJoin does "LEFT JOIN ... ON ..." statement on the model.
// The parameter `table` can be joined table and its joined condition,
// and also with its alias name.
//
// Eg:
// Model("user").LeftJoin("user_detail", "user_detail.uid=user.uid")
// Model("user", "u").LeftJoin("user_detail", "ud", "ud.uid=u.uid")
// Model("user", "u").LeftJoin("SELECT xxx FROM xxx","a", "a.uid=u.uid").
func (m *Model) LeftJoin(tableOrSubQueryAndJoinConditions ...string) *Model {
return m.doJoin(joinOperatorLeft, tableOrSubQueryAndJoinConditions...)
}
// RightJoin does "RIGHT JOIN ... ON ..." statement on the model.
// The parameter `table` can be joined table and its joined condition,
// and also with its alias name.
//
// Eg:
// Model("user").RightJoin("user_detail", "user_detail.uid=user.uid")
// Model("user", "u").RightJoin("user_detail", "ud", "ud.uid=u.uid")
// Model("user", "u").RightJoin("SELECT xxx FROM xxx","a", "a.uid=u.uid").
func (m *Model) RightJoin(tableOrSubQueryAndJoinConditions ...string) *Model {
return m.doJoin(joinOperatorRight, tableOrSubQueryAndJoinConditions...)
}
// InnerJoin does "INNER JOIN ... ON ..." statement on the model.
// The parameter `table` can be joined table and its joined condition,
// and also with its alias name。
//
// Eg:
// Model("user").InnerJoin("user_detail", "user_detail.uid=user.uid")
// Model("user", "u").InnerJoin("user_detail", "ud", "ud.uid=u.uid")
// Model("user", "u").InnerJoin("SELECT xxx FROM xxx","a", "a.uid=u.uid").
func (m *Model) InnerJoin(tableOrSubQueryAndJoinConditions ...string) *Model {
return m.doJoin(joinOperatorInner, tableOrSubQueryAndJoinConditions...)
}
// LeftJoinOnField performs as LeftJoin, but it joins both tables with the `same field name`.
//
// Eg:
// Model("order").LeftJoinOnField("user", "user_id")
// Model("order").LeftJoinOnField("product", "product_id").
func (m *Model) LeftJoinOnField(table, field string) *Model {
return m.doJoin(joinOperatorLeft, table, fmt.Sprintf(
`%s.%s=%s.%s`,
m.tablesInit,
m.db.GetCore().QuoteWord(field),
m.db.GetCore().QuoteWord(table),
m.db.GetCore().QuoteWord(field),
))
}
// RightJoinOnField performs as RightJoin, but it joins both tables with the `same field name`.
//
// Eg:
// Model("order").InnerJoinOnField("user", "user_id")
// Model("order").InnerJoinOnField("product", "product_id").
func (m *Model) RightJoinOnField(table, field string) *Model {
return m.doJoin(joinOperatorRight, table, fmt.Sprintf(
`%s.%s=%s.%s`,
m.tablesInit,
m.db.GetCore().QuoteWord(field),
m.db.GetCore().QuoteWord(table),
m.db.GetCore().QuoteWord(field),
))
}
// InnerJoinOnField performs as InnerJoin, but it joins both tables with the `same field name`.
//
// Eg:
// Model("order").InnerJoinOnField("user", "user_id")
// Model("order").InnerJoinOnField("product", "product_id").
func (m *Model) InnerJoinOnField(table, field string) *Model {
return m.doJoin(joinOperatorInner, table, fmt.Sprintf(
`%s.%s=%s.%s`,
m.tablesInit,
m.db.GetCore().QuoteWord(field),
m.db.GetCore().QuoteWord(table),
m.db.GetCore().QuoteWord(field),
))
}
// LeftJoinOnFields performs as LeftJoin. It specifies different fields and comparison operator.
//
// Eg:
// Model("user").LeftJoinOnFields("order", "id", "=", "user_id")
// Model("user").LeftJoinOnFields("order", "id", ">", "user_id")
// Model("user").LeftJoinOnFields("order", "id", "<", "user_id")
func (m *Model) LeftJoinOnFields(table, firstField, operator, secondField string) *Model {
return m.doJoin(joinOperatorLeft, table, fmt.Sprintf(
`%s.%s %s %s.%s`,
m.tablesInit,
m.db.GetCore().QuoteWord(firstField),
operator,
m.db.GetCore().QuoteWord(table),
m.db.GetCore().QuoteWord(secondField),
))
}
// RightJoinOnFields performs as RightJoin. It specifies different fields and comparison operator.
//
// Eg:
// Model("user").RightJoinOnFields("order", "id", "=", "user_id")
// Model("user").RightJoinOnFields("order", "id", ">", "user_id")
// Model("user").RightJoinOnFields("order", "id", "<", "user_id")
func (m *Model) RightJoinOnFields(table, firstField, operator, secondField string) *Model {
return m.doJoin(joinOperatorRight, table, fmt.Sprintf(
`%s.%s %s %s.%s`,
m.tablesInit,
m.db.GetCore().QuoteWord(firstField),
operator,
m.db.GetCore().QuoteWord(table),
m.db.GetCore().QuoteWord(secondField),
))
}
// InnerJoinOnFields performs as InnerJoin. It specifies different fields and comparison operator.
//
// Eg:
// Model("user").InnerJoinOnFields("order", "id", "=", "user_id")
// Model("user").InnerJoinOnFields("order", "id", ">", "user_id")
// Model("user").InnerJoinOnFields("order", "id", "<", "user_id")
func (m *Model) InnerJoinOnFields(table, firstField, operator, secondField string) *Model {
return m.doJoin(joinOperatorInner, table, fmt.Sprintf(
`%s.%s %s %s.%s`,
m.tablesInit,
m.db.GetCore().QuoteWord(firstField),
operator,
m.db.GetCore().QuoteWord(table),
m.db.GetCore().QuoteWord(secondField),
))
}
// doJoin does "LEFT/RIGHT/INNER JOIN ... ON ..." statement on the model.
// The parameter `tableOrSubQueryAndJoinConditions` can be joined table and its joined condition,
// and also with its alias name.
//
// Eg:
// Model("user").InnerJoin("user_detail", "user_detail.uid=user.uid")
// Model("user", "u").InnerJoin("user_detail", "ud", "ud.uid=u.uid")
// Model("user", "u").InnerJoin("user_detail", "ud", "ud.uid>u.uid")
// Model("user", "u").InnerJoin("SELECT xxx FROM xxx","a", "a.uid=u.uid")
// Related issues:
// https://github.com/gogf/gf/issues/1024
func (m *Model) doJoin(operator joinOperator, tableOrSubQueryAndJoinConditions ...string) *Model {
var (
model = m.getModel()
joinStr = ""
table string
alias string
)
// Check the first parameter table or sub-query.
if len(tableOrSubQueryAndJoinConditions) > 0 {
if isSubQuery(tableOrSubQueryAndJoinConditions[0]) {
joinStr = gstr.Trim(tableOrSubQueryAndJoinConditions[0])
if joinStr[0] != '(' {
joinStr = "(" + joinStr + ")"
}
} else {
table = tableOrSubQueryAndJoinConditions[0]
joinStr = m.db.GetCore().QuotePrefixTableName(table)
}
}
// Generate join condition statement string.
conditionLength := len(tableOrSubQueryAndJoinConditions)
switch {
case conditionLength > 2:
alias = tableOrSubQueryAndJoinConditions[1]
model.tables += fmt.Sprintf(
" %s JOIN %s AS %s ON (%s)",
operator, joinStr,
m.db.GetCore().QuoteWord(alias),
tableOrSubQueryAndJoinConditions[2],
)
m.tableAliasMap[alias] = table
case conditionLength == 2:
model.tables += fmt.Sprintf(
" %s JOIN %s ON (%s)",
operator, joinStr, tableOrSubQueryAndJoinConditions[1],
)
case conditionLength == 1:
model.tables += fmt.Sprintf(
" %s JOIN %s", operator, joinStr,
)
}
return model
}
// getTableNameByPrefixOrAlias checks and returns the table name if `prefixOrAlias` is an alias of a table,
// it or else returns the `prefixOrAlias` directly.
func (m *Model) getTableNameByPrefixOrAlias(prefixOrAlias string) string {
value, ok := m.tableAliasMap[prefixOrAlias]
if ok {
return value
}
return prefixOrAlias
}
// isSubQuery checks and returns whether given string a sub-query sql string.
func isSubQuery(s string) bool {
s = gstr.TrimLeft(s, "()")
if p := gstr.Pos(s, " "); p != -1 {
if gstr.Equal(s[:p], "select") {
return true
}
}
return false
}

129
database/gdb_model_lock.go Normal file
View File

@ -0,0 +1,129 @@
// Copyright GoFrame Author(https://goframe.org). All Rights Reserved.
//
// This Source Code Form is subject to the terms of the MIT License.
// If a copy of the MIT was not distributed with this file,
// You can obtain one at https://github.com/gogf/gf.
package database
// Lock clause constants for different databases.
// These constants provide type-safe and IDE-friendly access to various lock syntaxes.
const (
// Common lock clauses (supported by most databases)
LockForUpdate = "FOR UPDATE"
LockForUpdateSkipLocked = "FOR UPDATE SKIP LOCKED"
// MySQL lock clauses
LockInShareMode = "LOCK IN SHARE MODE" // MySQL legacy syntax
LockForShare = "FOR SHARE" // MySQL 8.0+ and PostgreSQL
LockForUpdateNowait = "FOR UPDATE NOWAIT" // MySQL 8.0+ and Oracle
// PostgreSQL specific lock clauses
LockForNoKeyUpdate = "FOR NO KEY UPDATE"
LockForKeyShare = "FOR KEY SHARE"
LockForShareNowait = "FOR SHARE NOWAIT"
LockForShareSkipLocked = "FOR SHARE SKIP LOCKED"
LockForNoKeyUpdateNowait = "FOR NO KEY UPDATE NOWAIT"
LockForNoKeyUpdateSkipLocked = "FOR NO KEY UPDATE SKIP LOCKED"
LockForKeyShareNowait = "FOR KEY SHARE NOWAIT"
LockForKeyShareSkipLocked = "FOR KEY SHARE SKIP LOCKED"
// Oracle specific lock clauses
LockForUpdateWait5 = "FOR UPDATE WAIT 5"
LockForUpdateWait10 = "FOR UPDATE WAIT 10"
LockForUpdateWait30 = "FOR UPDATE WAIT 30"
// SQL Server lock hints (use with WITH clause)
LockWithUpdLock = "WITH (UPDLOCK)"
LockWithHoldLock = "WITH (HOLDLOCK)"
LockWithXLock = "WITH (XLOCK)"
LockWithTabLock = "WITH (TABLOCK)"
LockWithNoLock = "WITH (NOLOCK)"
LockWithUpdLockHoldLock = "WITH (UPDLOCK, HOLDLOCK)"
)
// Lock sets a custom lock clause for the current operation.
// This is a generic method that allows you to specify any lock syntax supported by your database.
// You can use predefined constants or custom strings.
//
// Database-specific lock syntax support:
//
// PostgreSQL (most comprehensive):
// - "FOR UPDATE" - Exclusive lock, blocks all access
// - "FOR NO KEY UPDATE" - Weaker exclusive lock, doesn't block FOR KEY SHARE
// - "FOR SHARE" - Shared lock, allows reads but blocks writes
// - "FOR KEY SHARE" - Weakest lock, only locks key values
// - All above can be combined with:
// - "NOWAIT" - Return immediately if lock cannot be acquired
// - "SKIP LOCKED" - Skip locked rows instead of waiting
//
// MySQL:
// - "FOR UPDATE" - Exclusive lock (all versions)
// - "LOCK IN SHARE MODE" - Shared lock (legacy syntax)
// - "FOR SHARE" - Shared lock (MySQL 8.0+)
// - "FOR UPDATE NOWAIT" - MySQL 8.0+ only
// - "FOR UPDATE SKIP LOCKED" - MySQL 8.0+ only
//
// Oracle:
// - "FOR UPDATE" - Exclusive lock
// - "FOR UPDATE NOWAIT" - Exclusive lock, no wait
// - "FOR UPDATE SKIP LOCKED" - Exclusive lock, skip locked rows
// - "FOR UPDATE WAIT n" - Exclusive lock, wait n seconds
// - "FOR UPDATE OF column_list" - Lock specific columns
//
// SQL Server (uses WITH hints):
// - "WITH (UPDLOCK)" - Update lock
// - "WITH (HOLDLOCK)" - Hold lock until transaction end
// - "WITH (XLOCK)" - Exclusive lock
// - "WITH (TABLOCK)" - Table lock
// - "WITH (NOLOCK)" - No lock (dirty read)
// - "WITH (UPDLOCK, HOLDLOCK)" - Combined update and hold lock
//
// SQLite:
// - Limited locking support, database-level locks only
// - No row-level lock syntax supported
//
// Usage examples:
//
// db.Model("users").Lock("FOR UPDATE NOWAIT").Where("id", 1).One()
// db.Model("users").Lock("FOR SHARE SKIP LOCKED").Where("status", "active").All()
// db.Model("users").Lock("WITH (UPDLOCK)").Where("id", 1).One() // SQL Server
// db.Model("users").Lock("FOR UPDATE OF name, email").Where("id", 1).One() // Oracle
// db.Model("users").Lock("FOR UPDATE WAIT 15").Where("id", 1).One() // Oracle custom wait
//
// Or use predefined constants for better IDE support:
//
// db.Model("users").Lock(gdb.LockForUpdateNowait).Where("id", 1).One()
// db.Model("users").Lock(gdb.LockForShareSkipLocked).Where("status", "active").All()
func (m *Model) Lock(lockClause string) *Model {
model := m.getModel()
model.lockInfo = lockClause
return model
}
// LockUpdate sets the lock for update for current operation.
// This is equivalent to Lock("FOR UPDATE").
func (m *Model) LockUpdate() *Model {
model := m.getModel()
model.lockInfo = LockForUpdate
return model
}
// LockUpdateSkipLocked sets the lock for update with skip locked behavior for current operation.
// It skips the locked rows.
// This is equivalent to Lock("FOR UPDATE SKIP LOCKED").
// Note: Supported by PostgreSQL, Oracle, and MySQL 8.0+.
func (m *Model) LockUpdateSkipLocked() *Model {
model := m.getModel()
model.lockInfo = LockForUpdateSkipLocked
return model
}
// LockShared sets the lock in share mode for current operation.
// This is equivalent to Lock("LOCK IN SHARE MODE") for MySQL or Lock("FOR SHARE") for PostgreSQL.
// Note: For maximum compatibility, this uses MySQL's legacy syntax.
func (m *Model) LockShared() *Model {
model := m.getModel()
model.lockInfo = LockInShareMode
return model
}

View File

@ -0,0 +1,73 @@
// Copyright GoFrame Author(https://goframe.org). All Rights Reserved.
//
// This Source Code Form is subject to the terms of the MIT License.
// If a copy of the MIT was not distributed with this file,
// You can obtain one at https://github.com/gogf/gf.
package database
const (
optionOmitNil = optionOmitNilWhere | optionOmitNilData
optionOmitEmpty = optionOmitEmptyWhere | optionOmitEmptyData
optionOmitNilDataInternal = optionOmitNilData | optionOmitNilDataList // this option is used internally only for ForDao feature.
optionOmitEmptyWhere = 1 << iota // 8
optionOmitEmptyData // 16
optionOmitNilWhere // 32
optionOmitNilData // 64
optionOmitNilDataList // 128
)
// OmitEmpty sets optionOmitEmpty option for the model, which automatically filers
// the data and where parameters for `empty` values.
func (m *Model) OmitEmpty() *Model {
model := m.getModel()
model.option = model.option | optionOmitEmpty
return model
}
// OmitEmptyWhere sets optionOmitEmptyWhere option for the model, which automatically filers
// the Where/Having parameters for `empty` values.
//
// Eg:
//
// Where("id", []int{}).All() -> SELECT xxx FROM xxx WHERE 0=1
// Where("name", "").All() -> SELECT xxx FROM xxx WHERE `name`=''
// OmitEmpty().Where("id", []int{}).All() -> SELECT xxx FROM xxx
// OmitEmpty().("name", "").All() -> SELECT xxx FROM xxx.
func (m *Model) OmitEmptyWhere() *Model {
model := m.getModel()
model.option = model.option | optionOmitEmptyWhere
return model
}
// OmitEmptyData sets optionOmitEmptyData option for the model, which automatically filers
// the Data parameters for `empty` values.
func (m *Model) OmitEmptyData() *Model {
model := m.getModel()
model.option = model.option | optionOmitEmptyData
return model
}
// OmitNil sets optionOmitNil option for the model, which automatically filers
// the data and where parameters for `nil` values.
func (m *Model) OmitNil() *Model {
model := m.getModel()
model.option = model.option | optionOmitNil
return model
}
// OmitNilWhere sets optionOmitNilWhere option for the model, which automatically filers
// the Where/Having parameters for `nil` values.
func (m *Model) OmitNilWhere() *Model {
model := m.getModel()
model.option = model.option | optionOmitNilWhere
return model
}
// OmitNilData sets optionOmitNilData option for the model, which automatically filers
// the Data parameters for `nil` values.
func (m *Model) OmitNilData() *Model {
model := m.getModel()
model.option = model.option | optionOmitNilData
return model
}

View File

@ -0,0 +1,95 @@
// Copyright GoFrame Author(https://goframe.org). All Rights Reserved.
//
// This Source Code Form is subject to the terms of the MIT License.
// If a copy of the MIT was not distributed with this file,
// You can obtain one at https://github.com/gogf/gf.
package database
import (
"strings"
"github.com/gogf/gf/v2/text/gstr"
"github.com/gogf/gf/v2/util/gconv"
)
// Order sets the "ORDER BY" statement for the model.
//
// Example:
// Order("id desc")
// Order("id", "desc")
// Order("id desc,name asc")
// Order("id desc", "name asc")
// Order("id desc").Order("name asc")
// Order(gdb.Raw("field(id, 3,1,2)")).
func (m *Model) Order(orderBy ...any) *Model {
if len(orderBy) == 0 {
return m
}
var (
core = m.db.GetCore()
model = m.getModel()
)
for _, v := range orderBy {
if model.orderBy != "" {
model.orderBy += ","
}
switch v.(type) {
case Raw, *Raw:
model.orderBy += gconv.String(v)
default:
orderByStr := gconv.String(v)
if gstr.Contains(orderByStr, " ") {
model.orderBy += core.QuoteString(orderByStr)
} else {
if gstr.Equal(orderByStr, "ASC") || gstr.Equal(orderByStr, "DESC") {
model.orderBy = gstr.TrimRight(model.orderBy, ",")
model.orderBy += " " + orderByStr
} else {
model.orderBy += core.QuoteWord(orderByStr)
}
}
}
}
return model
}
// OrderAsc sets the "ORDER BY xxx ASC" statement for the model.
func (m *Model) OrderAsc(column string) *Model {
if len(column) == 0 {
return m
}
return m.Order(column + " ASC")
}
// OrderDesc sets the "ORDER BY xxx DESC" statement for the model.
func (m *Model) OrderDesc(column string) *Model {
if len(column) == 0 {
return m
}
return m.Order(column + " DESC")
}
// OrderRandom sets the "ORDER BY RANDOM()" statement for the model.
func (m *Model) OrderRandom() *Model {
model := m.getModel()
model.orderBy = m.db.OrderRandomFunction()
return model
}
// Group sets the "GROUP BY" statement for the model.
func (m *Model) Group(groupBy ...string) *Model {
if len(groupBy) == 0 {
return m
}
var (
core = m.db.GetCore()
model = m.getModel()
)
if model.groupBy != "" {
model.groupBy += ","
}
model.groupBy += core.QuoteString(strings.Join(groupBy, ","))
return model
}

View File

@ -0,0 +1,964 @@
// Copyright GoFrame Author(https://goframe.org). All Rights Reserved.
//
// This Source Code Form is subject to the terms of the MIT License.
// If a copy of the MIT was not distributed with this file,
// You can obtain one at https://github.com/gogf/gf.
package database
import (
"context"
"fmt"
"reflect"
"git.magicany.cc/black1552/gin-base/database/reflection"
"github.com/gogf/gf/v2/container/gset"
"github.com/gogf/gf/v2/encoding/gjson"
"github.com/gogf/gf/v2/errors/gcode"
"github.com/gogf/gf/v2/errors/gerror"
"github.com/gogf/gf/v2/text/gstr"
"github.com/gogf/gf/v2/util/gconv"
)
// All does "SELECT FROM ..." statement for the model.
// It retrieves the records from table and returns the result as slice type.
// It returns nil if there's no record retrieved with the given conditions from table.
//
// The optional parameter `where` is the same as the parameter of Model.Where function,
// see Model.Where.
func (m *Model) All(where ...any) (Result, error) {
var ctx = m.GetCtx()
return m.doGetAll(ctx, SelectTypeDefault, false, where...)
}
// AllAndCount retrieves all records and the total count of records from the model.
// If useFieldForCount is true, it will use the fields specified in the model for counting;
// otherwise, it will use a constant value of 1 for counting.
// It returns the result as a slice of records, the total count of records, and an error if any.
// The where parameter is an optional list of conditions to use when retrieving records.
//
// Example:
//
// var model Model
// var result Result
// var count int
// where := []any{"name = ?", "John"}
// result, count, err := model.AllAndCount(true)
// if err != nil {
// // Handle error.
// }
// fmt.Println(result, count)
func (m *Model) AllAndCount(useFieldForCount bool) (result Result, totalCount int, err error) {
// Clone the model for counting
countModel := m.Clone()
// Decide how to build the COUNT() expression:
// - If caller explicitly wants to use the single field expression for counting,
// honor it (e.g. Fields("DISTINCT col") with useFieldForCount = true).
// - Otherwise, clear fields to let Count() use its default COUNT(1),
// avoiding invalid COUNT(field1, field2, ...) with multiple fields,
// or incorrect COUNT(DISTINCT 1) when Distinct() is set.
if useFieldForCount && len(m.fields) == 1 {
countModel.fields = m.fields
} else {
countModel.fields = nil
}
if len(m.pageCacheOption) > 0 {
countModel = countModel.Cache(m.pageCacheOption[0])
}
// Get the total count of records
totalCount, err = countModel.Count()
if err != nil {
return
}
// If the total count is 0, there are no records to retrieve, so return early
if totalCount == 0 {
return
}
resultModel := m.Clone()
if len(m.pageCacheOption) > 1 {
resultModel = resultModel.Cache(m.pageCacheOption[1])
}
// Retrieve all records
result, err = resultModel.doGetAll(m.GetCtx(), SelectTypeDefault, false)
return
}
// Chunk iterates the query result with given `size` and `handler` function.
func (m *Model) Chunk(size int, handler ChunkHandler) {
page := m.start
if page <= 0 {
page = 1
}
model := m
for {
model = model.Page(page, size)
data, err := model.All()
if err != nil {
handler(nil, err)
break
}
if len(data) == 0 {
break
}
if !handler(data, err) {
break
}
if len(data) < size {
break
}
page++
}
}
// One retrieves one record from table and returns the result as map type.
// It returns nil if there's no record retrieved with the given conditions from table.
//
// The optional parameter `where` is the same as the parameter of Model.Where function,
// see Model.Where.
func (m *Model) One(where ...any) (Record, error) {
var ctx = m.GetCtx()
if len(where) > 0 {
return m.Where(where[0], where[1:]...).One()
}
all, err := m.doGetAll(ctx, SelectTypeDefault, true)
if err != nil {
return nil, err
}
if len(all) > 0 {
return all[0], nil
}
return nil, nil
}
// Array queries and returns data values as slice from database.
// Note that if there are multiple columns in the result, it returns just one column values randomly.
//
// If the optional parameter `fieldsAndWhere` is given, the fieldsAndWhere[0] is the selected fields
// and fieldsAndWhere[1:] is treated as where condition fields.
// Also see Model.Fields and Model.Where functions.
func (m *Model) Array(fieldsAndWhere ...any) (Array, error) {
if len(fieldsAndWhere) > 0 {
if len(fieldsAndWhere) > 2 {
return m.Fields(gconv.String(fieldsAndWhere[0])).Where(fieldsAndWhere[1], fieldsAndWhere[2:]...).Array()
} else if len(fieldsAndWhere) == 2 {
return m.Fields(gconv.String(fieldsAndWhere[0])).Where(fieldsAndWhere[1]).Array()
} else {
return m.Fields(gconv.String(fieldsAndWhere[0])).Array()
}
}
var (
field string
core = m.db.GetCore()
ctx = core.injectInternalColumn(m.GetCtx())
)
all, err := m.doGetAll(ctx, SelectTypeArray, false)
if err != nil {
return nil, err
}
if len(all) > 0 {
internalData := core.getInternalColumnFromCtx(ctx)
if internalData == nil {
return nil, gerror.NewCode(
gcode.CodeInternalError,
`query count error: the internal context data is missing. there's internal issue should be fixed`,
)
}
// If FirstResultColumn present, it returns the value of the first record of the first field.
// It means it use no cache mechanism, while cache mechanism makes `internalData` missing.
field = internalData.FirstResultColumn
if field == "" {
// Fields number check.
var recordFields = m.getRecordFields(all[0])
if len(recordFields) == 1 {
field = recordFields[0]
} else {
// it returns error if there are multiple fields in the result record.
return nil, gerror.NewCodef(
gcode.CodeInvalidParameter,
`invalid fields for "Array" operation, result fields number "%d"%s, but expect one`,
len(recordFields),
gjson.MustEncodeString(recordFields),
)
}
}
}
return all.Array(field), nil
}
// Struct retrieves one record from table and converts it into given struct.
// The parameter `pointer` should be type of *struct/**struct. If type **struct is given,
// it can create the struct internally during converting.
//
// The optional parameter `where` is the same as the parameter of Model.Where function,
// see Model.Where.
//
// Note that it returns sql.ErrNoRows if the given parameter `pointer` pointed to a variable that has
// default value and there's no record retrieved with the given conditions from table.
//
// Example:
// user := new(User)
// err := db.Model("user").Where("id", 1).Scan(user)
//
// user := (*User)(nil)
// err := db.Model("user").Where("id", 1).Scan(&user).
func (m *Model) doStruct(pointer any, where ...any) error {
model := m
// Auto selecting fields by struct attributes.
if len(model.fieldsEx) == 0 && len(model.fields) == 0 {
if v, ok := pointer.(reflect.Value); ok {
model = m.Fields(v.Interface())
} else {
model = m.Fields(pointer)
}
}
one, err := model.One(where...)
if err != nil {
return err
}
if err = one.Struct(pointer); err != nil {
return err
}
return model.doWithScanStruct(pointer)
}
// Structs retrieves records from table and converts them into given struct slice.
// The parameter `pointer` should be type of *[]struct/*[]*struct. It can create and fill the struct
// slice internally during converting.
//
// The optional parameter `where` is the same as the parameter of Model.Where function,
// see Model.Where.
//
// Note that it returns sql.ErrNoRows if the given parameter `pointer` pointed to a variable that has
// default value and there's no record retrieved with the given conditions from table.
//
// Example:
// users := ([]User)(nil)
// err := db.Model("user").Scan(&users)
//
// users := ([]*User)(nil)
// err := db.Model("user").Scan(&users).
func (m *Model) doStructs(pointer any, where ...any) error {
model := m
// Auto selecting fields by struct attributes.
if len(model.fieldsEx) == 0 && len(model.fields) == 0 {
if v, ok := pointer.(reflect.Value); ok {
model = m.Fields(
reflect.New(
v.Type().Elem(),
).Interface(),
)
} else {
model = m.Fields(
reflect.New(
reflect.ValueOf(pointer).Elem().Type().Elem(),
).Interface(),
)
}
}
all, err := model.All(where...)
if err != nil {
return err
}
if err = all.Structs(pointer); err != nil {
return err
}
return model.doWithScanStructs(pointer)
}
// Scan automatically calls Struct or Structs function according to the type of parameter `pointer`.
// It calls function doStruct if `pointer` is type of *struct/**struct.
// It calls function doStructs if `pointer` is type of *[]struct/*[]*struct.
//
// The optional parameter `where` is the same as the parameter of Model.Where function, see Model.Where.
//
// Note that it returns sql.ErrNoRows if the given parameter `pointer` pointed to a variable that has
// default value and there's no record retrieved with the given conditions from table.
//
// Example:
// user := new(User)
// err := db.Model("user").Where("id", 1).Scan(user)
//
// user := (*User)(nil)
// err := db.Model("user").Where("id", 1).Scan(&user)
//
// users := ([]User)(nil)
// err := db.Model("user").Scan(&users)
//
// users := ([]*User)(nil)
// err := db.Model("user").Scan(&users).
func (m *Model) Scan(pointer any, where ...any) error {
reflectInfo := reflection.OriginTypeAndKind(pointer)
if reflectInfo.InputKind != reflect.Pointer {
return gerror.NewCode(
gcode.CodeInvalidParameter,
`the parameter "pointer" for function Scan should type of pointer`,
)
}
switch reflectInfo.OriginKind {
case reflect.Slice, reflect.Array:
return m.doStructs(pointer, where...)
case reflect.Struct, reflect.Invalid:
return m.doStruct(pointer, where...)
default:
return gerror.NewCode(
gcode.CodeInvalidParameter,
`element of parameter "pointer" for function Scan should type of struct/*struct/[]struct/[]*struct`,
)
}
}
// ScanAndCount scans a single record or record array that matches the given conditions and counts the total number
// of records that match those conditions.
//
// If `useFieldForCount` is true, it will use the fields specified in the model for counting;
// The `pointer` parameter is a pointer to a struct that the scanned data will be stored in.
// The `totalCount` parameter is a pointer to an integer that will be set to the total number of records that match the given conditions.
// The where parameter is an optional list of conditions to use when retrieving records.
//
// Example:
//
// var count int
// user := new(User)
// err := db.Model("user").Where("id", 1).ScanAndCount(user,&count,true)
// fmt.Println(user, count)
//
// Example Join:
//
// type User struct {
// Id int
// Passport string
// Name string
// Age int
// }
// var users []User
// var count int
// db.Model(table).As("u1").
// LeftJoin(tableName2, "u2", "u2.id=u1.id").
// Fields("u1.passport,u1.id,u2.name,u2.age").
// Where("u1.id<2").
// ScanAndCount(&users, &count, false)
func (m *Model) ScanAndCount(pointer any, totalCount *int, useFieldForCount bool) (err error) {
// support Fields with *, example: .Fields("a.*, b.name"). Count sql is select count(1) from xxx
countModel := m.Clone()
// Decide how to build the COUNT() expression:
// - If caller explicitly wants to use the single field expression for counting,
// honor it (e.g. Fields("DISTINCT col") with useFieldForCount = true).
// - Otherwise, clear fields to let Count() use its default COUNT(1),
// avoiding invalid COUNT(field1, field2, ...) with multiple fields,
// or incorrect COUNT(DISTINCT 1) when Distinct() is set.
if useFieldForCount && len(m.fields) == 1 {
countModel.fields = m.fields
} else {
countModel.fields = nil
}
if len(m.pageCacheOption) > 0 {
countModel = countModel.Cache(m.pageCacheOption[0])
}
// Get the total count of records
*totalCount, err = countModel.Count()
if err != nil {
return err
}
// If the total count is 0, there are no records to retrieve, so return early
if *totalCount == 0 {
return
}
scanModel := m.Clone()
if len(m.pageCacheOption) > 1 {
scanModel = scanModel.Cache(m.pageCacheOption[1])
}
err = scanModel.Scan(pointer)
return
}
// ScanList converts `r` to struct slice which contains other complex struct attributes.
// Note that the parameter `listPointer` should be type of *[]struct/*[]*struct.
//
// See Result.ScanList.
func (m *Model) ScanList(structSlicePointer any, bindToAttrName string, relationAttrNameAndFields ...string) (err error) {
var result Result
out, err := checkGetSliceElementInfoForScanList(structSlicePointer, bindToAttrName)
if err != nil {
return err
}
if len(m.fields) > 0 || len(m.fieldsEx) != 0 {
// There are custom fields.
result, err = m.All()
} else {
// Filter fields using temporary created struct using reflect.New.
result, err = m.Fields(reflect.New(out.BindToAttrType).Interface()).All()
}
if err != nil {
return err
}
var (
relationAttrName string
relationFields string
)
switch len(relationAttrNameAndFields) {
case 2:
relationAttrName = relationAttrNameAndFields[0]
relationFields = relationAttrNameAndFields[1]
case 1:
relationFields = relationAttrNameAndFields[0]
}
return doScanList(doScanListInput{
Model: m,
Result: result,
StructSlicePointer: structSlicePointer,
StructSliceValue: out.SliceReflectValue,
BindToAttrName: bindToAttrName,
RelationAttrName: relationAttrName,
RelationFields: relationFields,
})
}
// Value retrieves a specified record value from table and returns the result as interface type.
// It returns nil if there's no record found with the given conditions from table.
//
// If the optional parameter `fieldsAndWhere` is given, the fieldsAndWhere[0] is the selected fields
// and fieldsAndWhere[1:] is treated as where condition fields.
// Also see Model.Fields and Model.Where functions.
func (m *Model) Value(fieldsAndWhere ...any) (Value, error) {
var (
core = m.db.GetCore()
ctx = core.injectInternalColumn(m.GetCtx())
)
if len(fieldsAndWhere) > 0 {
if len(fieldsAndWhere) > 2 {
return m.Fields(gconv.String(fieldsAndWhere[0])).Where(fieldsAndWhere[1], fieldsAndWhere[2:]...).Value()
} else if len(fieldsAndWhere) == 2 {
return m.Fields(gconv.String(fieldsAndWhere[0])).Where(fieldsAndWhere[1]).Value()
} else {
return m.Fields(gconv.String(fieldsAndWhere[0])).Value()
}
}
var (
sqlWithHolder, holderArgs = m.getFormattedSqlAndArgs(ctx, SelectTypeValue, true)
all, err = m.doGetAllBySql(ctx, SelectTypeValue, sqlWithHolder, holderArgs...)
)
if err != nil {
return nil, err
}
if len(all) > 0 {
internalData := core.getInternalColumnFromCtx(ctx)
if internalData == nil {
return nil, gerror.NewCode(
gcode.CodeInternalError,
`query count error: the internal context data is missing. there's internal issue should be fixed`,
)
}
// If FirstResultColumn present, it returns the value of the first record of the first field.
// It means it use no cache mechanism, while cache mechanism makes `internalData` missing.
if v, ok := all[0][internalData.FirstResultColumn]; ok {
return v, nil
}
// Fields number check.
var recordFields = m.getRecordFields(all[0])
if len(recordFields) == 1 {
for _, v := range all[0] {
return v, nil
}
}
// it returns error if there are multiple fields in the result record.
return nil, gerror.NewCodef(
gcode.CodeInvalidParameter,
`invalid fields for "Value" operation, result fields number "%d"%s, but expect one`,
len(recordFields),
gjson.MustEncodeString(recordFields),
)
}
return nil, nil
}
func (m *Model) getRecordFields(record Record) []string {
if len(record) == 0 {
return nil
}
var fields = make([]string, 0)
for k := range record {
fields = append(fields, k)
}
return fields
}
// Count does "SELECT COUNT(x) FROM ..." statement for the model.
// The optional parameter `where` is the same as the parameter of Model.Where function,
// see Model.Where.
func (m *Model) Count(where ...any) (int, error) {
var (
core = m.db.GetCore()
ctx = core.injectInternalColumn(m.GetCtx())
)
if len(where) > 0 {
return m.Where(where[0], where[1:]...).Count()
}
var (
sqlWithHolder, holderArgs = m.getFormattedSqlAndArgs(ctx, SelectTypeCount, false)
all, err = m.doGetAllBySql(ctx, SelectTypeCount, sqlWithHolder, holderArgs...)
)
if err != nil {
return 0, err
}
if len(all) > 0 {
internalData := core.getInternalColumnFromCtx(ctx)
if internalData == nil {
return 0, gerror.NewCode(
gcode.CodeInternalError,
`query count error: the internal context data is missing. there's internal issue should be fixed`,
)
}
// If FirstResultColumn present, it returns the value of the first record of the first field.
// It means it use no cache mechanism, while cache mechanism makes `internalData` missing.
if v, ok := all[0][internalData.FirstResultColumn]; ok {
return v.Int(), nil
}
// Fields number check.
var recordFields = m.getRecordFields(all[0])
if len(recordFields) == 1 {
for _, v := range all[0] {
return v.Int(), nil
}
}
// it returns error if there are multiple fields in the result record.
return 0, gerror.NewCodef(
gcode.CodeInvalidParameter,
`invalid fields for "Count" operation, result fields number "%d"%s, but expect one`,
len(recordFields),
gjson.MustEncodeString(recordFields),
)
}
return 0, nil
}
// Exist does "SELECT 1 FROM ... LIMIT 1" statement for the model.
// The optional parameter `where` is the same as the parameter of Model.Where function,
// see Model.Where.
func (m *Model) Exist(where ...any) (bool, error) {
if len(where) > 0 {
return m.Where(where[0], where[1:]...).Exist()
}
one, err := m.Fields(Raw("1")).One()
if err != nil {
return false, err
}
for _, val := range one {
if val.Bool() {
return true, nil
}
}
return false, nil
}
// CountColumn does "SELECT COUNT(x) FROM ..." statement for the model.
func (m *Model) CountColumn(column string) (int, error) {
if len(column) == 0 {
return 0, nil
}
return m.Fields(column).Count()
}
// Min does "SELECT MIN(x) FROM ..." statement for the model.
func (m *Model) Min(column string) (float64, error) {
if len(column) == 0 {
return 0, nil
}
value, err := m.Fields(fmt.Sprintf(`MIN(%s)`, m.QuoteWord(column))).Value()
if err != nil {
return 0, err
}
return value.Float64(), err
}
// Max does "SELECT MAX(x) FROM ..." statement for the model.
func (m *Model) Max(column string) (float64, error) {
if len(column) == 0 {
return 0, nil
}
value, err := m.Fields(fmt.Sprintf(`MAX(%s)`, m.QuoteWord(column))).Value()
if err != nil {
return 0, err
}
return value.Float64(), err
}
// Avg does "SELECT AVG(x) FROM ..." statement for the model.
func (m *Model) Avg(column string) (float64, error) {
if len(column) == 0 {
return 0, nil
}
value, err := m.Fields(fmt.Sprintf(`AVG(%s)`, m.QuoteWord(column))).Value()
if err != nil {
return 0, err
}
return value.Float64(), err
}
// Sum does "SELECT SUM(x) FROM ..." statement for the model.
func (m *Model) Sum(column string) (float64, error) {
if len(column) == 0 {
return 0, nil
}
value, err := m.Fields(fmt.Sprintf(`SUM(%s)`, m.QuoteWord(column))).Value()
if err != nil {
return 0, err
}
return value.Float64(), err
}
// Union does "(SELECT xxx FROM xxx) UNION (SELECT xxx FROM xxx) ..." statement for the model.
func (m *Model) Union(unions ...*Model) *Model {
return m.db.Union(unions...)
}
// UnionAll does "(SELECT xxx FROM xxx) UNION ALL (SELECT xxx FROM xxx) ..." statement for the model.
func (m *Model) UnionAll(unions ...*Model) *Model {
return m.db.UnionAll(unions...)
}
// Limit sets the "LIMIT" statement for the model.
// The parameter `limit` can be either one or two number, if passed two number is passed,
// it then sets "LIMIT limit[0],limit[1]" statement for the model, or else it sets "LIMIT limit[0]"
// statement.
// Note: Negative values are treated as zero.
func (m *Model) Limit(limit ...int) *Model {
model := m.getModel()
switch len(limit) {
case 1:
if limit[0] < 0 {
limit[0] = 0
}
model.limit = limit[0]
case 2:
if limit[0] < 0 {
limit[0] = 0
}
if limit[1] < 0 {
limit[1] = 0
}
model.start = limit[0]
model.limit = limit[1]
}
return model
}
// Offset sets the "OFFSET" statement for the model.
// It only makes sense for some databases like SQLServer, PostgreSQL, etc.
// Note: Negative values are treated as zero.
func (m *Model) Offset(offset int) *Model {
model := m.getModel()
if offset < 0 {
offset = 0
}
model.offset = offset
return model
}
// Distinct forces the query to only return distinct results.
func (m *Model) Distinct() *Model {
model := m.getModel()
model.distinct = "DISTINCT "
return model
}
// Page sets the paging number for the model.
// The parameter `page` is started from 1 for paging.
// Note that, it differs that the Limit function starts from 0 for "LIMIT" statement.
// Note: Negative limit values are treated as zero.
func (m *Model) Page(page, limit int) *Model {
model := m.getModel()
if page <= 0 {
page = 1
}
if limit < 0 {
limit = 0
}
model.start = (page - 1) * limit
model.limit = limit
return model
}
// Having sets the having statement for the model.
// The parameters of this function usage are as the same as function Where.
// See Where.
func (m *Model) Having(having any, args ...any) *Model {
model := m.getModel()
model.having = []any{
having, args,
}
return model
}
// doGetAll does "SELECT FROM ..." statement for the model.
// It retrieves the records from table and returns the result as slice type.
// It returns nil if there's no record retrieved with the given conditions from table.
//
// The parameter `limit1` specifies whether limits querying only one record if m.limit is not set.
// The optional parameter `where` is the same as the parameter of Model.Where function,
// see Model.Where.
func (m *Model) doGetAll(ctx context.Context, selectType SelectType, limit1 bool, where ...any) (Result, error) {
if len(where) > 0 {
return m.Where(where[0], where[1:]...).All()
}
sqlWithHolder, holderArgs := m.getFormattedSqlAndArgs(ctx, selectType, limit1)
return m.doGetAllBySql(ctx, selectType, sqlWithHolder, holderArgs...)
}
// doGetAllBySql does the select statement on the database.
func (m *Model) doGetAllBySql(
ctx context.Context, selectType SelectType, sql string, args ...any,
) (result Result, err error) {
if result, err = m.getSelectResultFromCache(ctx, sql, args...); err != nil || result != nil {
return
}
in := &HookSelectInput{
internalParamHookSelect: internalParamHookSelect{
internalParamHook: internalParamHook{
link: m.getLink(false),
},
handler: m.hookHandler.Select,
},
Model: m,
Table: m.tables,
Schema: m.schema,
Sql: sql,
Args: m.mergeArguments(args),
SelectType: selectType,
}
if result, err = in.Next(ctx); err != nil {
return
}
err = m.saveSelectResultToCache(ctx, selectType, result, sql, args...)
return
}
func (m *Model) getFormattedSqlAndArgs(
ctx context.Context, selectType SelectType, limit1 bool,
) (sqlWithHolder string, holderArgs []any) {
switch selectType {
case SelectTypeCount:
queryFields := "COUNT(1)"
if len(m.fields) > 0 {
// DO NOT quote the m.fields here, in case of fields like:
// DISTINCT t.user_id uid
queryFields = fmt.Sprintf(`COUNT(%s%s)`, m.distinct, m.getFieldsAsStr())
}
// Raw SQL Model.
if m.rawSql != "" {
conditionWhere, conditionExtra, conditionArgs := m.formatCondition(ctx, false, true)
sqlWithHolder = fmt.Sprintf(
"SELECT %s FROM (%s%s) AS T",
queryFields, m.rawSql, conditionWhere+conditionExtra,
)
return sqlWithHolder, conditionArgs
}
conditionWhere, conditionExtra, conditionArgs := m.formatCondition(ctx, false, true)
sqlWithHolder = fmt.Sprintf("SELECT %s FROM %s%s", queryFields, m.tables, conditionWhere+conditionExtra)
if len(m.groupBy) > 0 {
sqlWithHolder = fmt.Sprintf("SELECT COUNT(1) FROM (%s) count_alias", sqlWithHolder)
}
return sqlWithHolder, conditionArgs
default:
conditionWhere, conditionExtra, conditionArgs := m.formatCondition(ctx, limit1, false)
// Raw SQL Model, especially for UNION/UNION ALL featured SQL.
if m.rawSql != "" {
sqlWithHolder = fmt.Sprintf(
"%s%s",
m.rawSql,
conditionWhere+conditionExtra,
)
return sqlWithHolder, conditionArgs
}
// DO NOT quote the m.fields where, in case of fields like:
// DISTINCT t.user_id uid
sqlWithHolder = fmt.Sprintf(
"SELECT %s%s FROM %s%s",
m.distinct, m.getFieldsFiltered(), m.tables, conditionWhere+conditionExtra,
)
return sqlWithHolder, conditionArgs
}
}
func (m *Model) getHolderAndArgsAsSubModel(ctx context.Context) (holder string, args []any) {
holder, args = m.getFormattedSqlAndArgs(
ctx, SelectTypeDefault, false,
)
args = m.mergeArguments(args)
return
}
func (m *Model) getAutoPrefix() string {
autoPrefix := ""
if gstr.Contains(m.tables, " JOIN ") {
autoPrefix = m.QuoteWord(
m.db.GetCore().guessPrimaryTableName(m.tablesInit),
)
}
return autoPrefix
}
func (m *Model) getFieldsAsStr() string {
var (
fieldsStr string
)
for _, v := range m.fields {
field := gconv.String(v)
switch {
case gstr.ContainsAny(field, "()"):
case gstr.ContainsAny(field, ". "):
default:
switch v.(type) {
case Raw, *Raw:
default:
field = m.QuoteWord(field)
}
}
if fieldsStr != "" {
fieldsStr += ","
}
fieldsStr += field
}
return fieldsStr
}
// getFieldsFiltered checks the fields and fieldsEx attributes, filters and returns the fields that will
// really be committed to underlying database driver.
func (m *Model) getFieldsFiltered() string {
if len(m.fieldsEx) == 0 && len(m.fields) == 0 {
return defaultField
}
if len(m.fieldsEx) == 0 && len(m.fields) > 0 {
return m.getFieldsAsStr()
}
var (
fieldsArray []string
fieldsExSet = gset.NewStrSetFrom(gconv.Strings(m.fieldsEx))
)
if len(m.fields) > 0 {
// Filter custom fields with fieldEx.
fieldsArray = make([]string, 0, 8)
for _, v := range m.fields {
field := gconv.String(v)
fieldsArray = append(fieldsArray, field[gstr.PosR(field, "-")+1:])
}
} else {
if gstr.Contains(m.tables, " ") {
panic("function FieldsEx supports only single table operations")
}
// Filter table fields with fieldEx.
tableFields, err := m.TableFields(m.tablesInit)
if err != nil {
panic(err)
}
if len(tableFields) == 0 {
panic(fmt.Sprintf(`empty table fields for table "%s"`, m.tables))
}
fieldsArray = make([]string, len(tableFields))
for k, v := range tableFields {
fieldsArray[v.Index] = k
}
}
newFields := ""
for _, k := range fieldsArray {
if fieldsExSet.Contains(k) {
continue
}
if len(newFields) > 0 {
newFields += ","
}
newFields += m.QuoteWord(k)
}
return newFields
}
// formatCondition formats where arguments of the model and returns a new condition sql and its arguments.
// Note that this function does not change any attribute value of the `m`.
//
// The parameter `limit1` specifies whether limits querying only one record if m.limit is not set.
func (m *Model) formatCondition(
ctx context.Context, limit1 bool, isCountStatement bool,
) (conditionWhere string, conditionExtra string, conditionArgs []any) {
var autoPrefix = m.getAutoPrefix()
// GROUP BY.
if m.groupBy != "" {
conditionExtra += " GROUP BY " + m.groupBy
}
// WHERE
conditionWhere, conditionArgs = m.whereBuilder.Build()
softDeletingCondition := m.softTimeMaintainer().GetDeleteCondition(ctx)
if m.rawSql != "" && conditionWhere != "" {
if gstr.ContainsI(m.rawSql, " WHERE ") {
conditionWhere = " AND " + conditionWhere
} else {
conditionWhere = " WHERE " + conditionWhere
}
} else if !m.unscoped && softDeletingCondition != "" {
if conditionWhere == "" {
conditionWhere = fmt.Sprintf(` WHERE %s`, softDeletingCondition)
} else {
conditionWhere = fmt.Sprintf(` WHERE (%s) AND %s`, conditionWhere, softDeletingCondition)
}
} else {
if conditionWhere != "" {
conditionWhere = " WHERE " + conditionWhere
}
}
// HAVING.
if len(m.having) > 0 {
havingHolder := WhereHolder{
Where: m.having[0],
Args: gconv.Interfaces(m.having[1]),
Prefix: autoPrefix,
}
havingStr, havingArgs := formatWhereHolder(ctx, m.db, formatWhereHolderInput{
WhereHolder: havingHolder,
OmitNil: m.option&optionOmitNilWhere > 0,
OmitEmpty: m.option&optionOmitEmptyWhere > 0,
Schema: m.schema,
Table: m.tables,
})
if len(havingStr) > 0 {
conditionExtra += " HAVING " + havingStr
conditionArgs = append(conditionArgs, havingArgs...)
}
}
// ORDER BY.
if !isCountStatement { // The count statement of sqlserver cannot contain the order by statement
if m.orderBy != "" {
conditionExtra += " ORDER BY " + m.orderBy
}
}
// LIMIT.
if !isCountStatement {
if m.limit != 0 {
if m.start >= 0 {
conditionExtra += fmt.Sprintf(" LIMIT %d,%d", m.start, m.limit)
} else {
conditionExtra += fmt.Sprintf(" LIMIT %d", m.limit)
}
} else if limit1 {
conditionExtra += " LIMIT 1"
}
if m.offset >= 0 {
conditionExtra += fmt.Sprintf(" OFFSET %d", m.offset)
}
}
if m.lockInfo != "" {
conditionExtra += " " + m.lockInfo
}
return
}

View File

@ -0,0 +1,161 @@
// Copyright GoFrame Author(https://goframe.org). All Rights Reserved.
//
// This Source Code Form is subject to the terms of the MIT License.
// If a copy of the MIT was not distributed with this file,
// You can obtain one at https://github.com/gogf/gf.
package database
import (
"context"
"fmt"
"hash/fnv"
"reflect"
"github.com/gogf/gf/v2/errors/gcode"
"github.com/gogf/gf/v2/errors/gerror"
"github.com/gogf/gf/v2/util/gconv"
)
// ShardingConfig defines the configuration for database/table sharding.
type ShardingConfig struct {
// Table sharding configuration
Table ShardingTableConfig
// Schema sharding configuration
Schema ShardingSchemaConfig
}
// ShardingSchemaConfig defines the configuration for database sharding.
type ShardingSchemaConfig struct {
// Enable schema sharding
Enable bool
// Schema rule prefix, e.g., "db_"
Prefix string
// ShardingRule defines how to route data to different database nodes
Rule ShardingRule
}
// ShardingTableConfig defines the configuration for table sharding
type ShardingTableConfig struct {
// Enable table sharding
Enable bool
// Table rule prefix, e.g., "user_"
Prefix string
// ShardingRule defines how to route data to different tables
Rule ShardingRule
}
// ShardingRule defines the interface for sharding rules
type ShardingRule interface {
// SchemaName returns the target schema name based on sharding value.
SchemaName(ctx context.Context, config ShardingSchemaConfig, value any) (string, error)
// TableName returns the target table name based on sharding value.
TableName(ctx context.Context, config ShardingTableConfig, value any) (string, error)
}
// DefaultShardingRule implements a simple modulo-based sharding rule
type DefaultShardingRule struct {
// Number of schema count.
SchemaCount int
// Number of tables per schema.
TableCount int
}
// Sharding creates a sharding model with given sharding configuration.
func (m *Model) Sharding(config ShardingConfig) *Model {
model := m.getModel()
model.shardingConfig = config
return model
}
// ShardingValue sets the sharding value for routing
func (m *Model) ShardingValue(value any) *Model {
model := m.getModel()
model.shardingValue = value
return model
}
// getActualSchema returns the actual schema based on sharding configuration.
// TODO it does not support schemas in different database config node.
func (m *Model) getActualSchema(ctx context.Context, defaultSchema string) (string, error) {
if !m.shardingConfig.Schema.Enable {
return defaultSchema, nil
}
if m.shardingValue == nil {
return defaultSchema, gerror.NewCode(
gcode.CodeInvalidParameter, "sharding value is required when sharding feature enabled",
)
}
if m.shardingConfig.Schema.Rule == nil {
return defaultSchema, gerror.NewCode(
gcode.CodeInvalidParameter, "sharding rule is required when sharding feature enabled",
)
}
return m.shardingConfig.Schema.Rule.SchemaName(ctx, m.shardingConfig.Schema, m.shardingValue)
}
// getActualTable returns the actual table name based on sharding configuration
func (m *Model) getActualTable(ctx context.Context, defaultTable string) (string, error) {
if !m.shardingConfig.Table.Enable {
return defaultTable, nil
}
if m.shardingValue == nil {
return defaultTable, gerror.NewCode(
gcode.CodeInvalidParameter, "sharding value is required when sharding feature enabled",
)
}
if m.shardingConfig.Table.Rule == nil {
return defaultTable, gerror.NewCode(
gcode.CodeInvalidParameter, "sharding rule is required when sharding feature enabled",
)
}
return m.shardingConfig.Table.Rule.TableName(ctx, m.shardingConfig.Table, m.shardingValue)
}
// SchemaName implements the default database sharding strategy
func (r *DefaultShardingRule) SchemaName(ctx context.Context, config ShardingSchemaConfig, value any) (string, error) {
if r.SchemaCount == 0 {
return "", gerror.NewCode(
gcode.CodeInvalidParameter, "schema count should not be 0 using DefaultShardingRule when schema sharding enabled",
)
}
hashValue, err := getHashValue(value)
if err != nil {
return "", err
}
nodeIndex := hashValue % uint64(r.SchemaCount)
return fmt.Sprintf("%s%d", config.Prefix, nodeIndex), nil
}
// TableName implements the default table sharding strategy
func (r *DefaultShardingRule) TableName(ctx context.Context, config ShardingTableConfig, value any) (string, error) {
if r.TableCount == 0 {
return "", gerror.NewCode(
gcode.CodeInvalidParameter, "table count should not be 0 using DefaultShardingRule when table sharding enabled",
)
}
hashValue, err := getHashValue(value)
if err != nil {
return "", err
}
tableIndex := hashValue % uint64(r.TableCount)
return fmt.Sprintf("%s%d", config.Prefix, tableIndex), nil
}
// getHashValue converts sharding value to uint64 hash
func getHashValue(value any) (uint64, error) {
var rv = reflect.ValueOf(value)
switch rv.Kind() {
case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64,
reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64,
reflect.Float32, reflect.Float64:
return gconv.Uint64(value), nil
default:
h := fnv.New64a()
_, err := h.Write(gconv.Bytes(value))
if err != nil {
return 0, gerror.WrapCode(gcode.CodeInternalError, err)
}
return h.Sum64(), nil
}
}

View File

@ -0,0 +1,384 @@
// Copyright GoFrame Author(https://goframe.org). All Rights Reserved.
//
// This Source Code Form is subject to the terms of the MIT License.
// If a copy of the MIT was not distributed with this file,
// You can obtain one at https://github.com/gogf/gf.
package database
import (
"context"
"fmt"
"strings"
"git.magicany.cc/black1552/gin-base/database/intlog"
"git.magicany.cc/black1552/gin-base/database/utils"
"github.com/gogf/gf/v2/container/garray"
"github.com/gogf/gf/v2/errors/gcode"
"github.com/gogf/gf/v2/errors/gerror"
"github.com/gogf/gf/v2/os/gcache"
"github.com/gogf/gf/v2/os/gtime"
"github.com/gogf/gf/v2/text/gregex"
"github.com/gogf/gf/v2/text/gstr"
)
// SoftTimeType custom defines the soft time field type.
type SoftTimeType int
const (
SoftTimeTypeAuto SoftTimeType = 0 // (Default)Auto detect the field type by table field type.
SoftTimeTypeTime SoftTimeType = 1 // Using datetime as the field value.
SoftTimeTypeTimestamp SoftTimeType = 2 // In unix seconds.
SoftTimeTypeTimestampMilli SoftTimeType = 3 // In unix milliseconds.
SoftTimeTypeTimestampMicro SoftTimeType = 4 // In unix microseconds.
SoftTimeTypeTimestampNano SoftTimeType = 5 // In unix nanoseconds.
)
// SoftTimeOption is the option to customize soft time feature for Model.
type SoftTimeOption struct {
SoftTimeType SoftTimeType // The value type for soft time field.
}
type softTimeMaintainer struct {
*Model
}
// SoftTimeFieldType represents different soft time field purposes.
type SoftTimeFieldType int
const (
SoftTimeFieldCreate SoftTimeFieldType = iota
SoftTimeFieldUpdate
SoftTimeFieldDelete
)
type iSoftTimeMaintainer interface {
// GetFieldInfo returns field name and type for specified field purpose.
GetFieldInfo(ctx context.Context, schema, table string, fieldPurpose SoftTimeFieldType) (fieldName string, localType LocalType)
// GetFieldValue generates value for create/update/delete operations.
GetFieldValue(ctx context.Context, localType LocalType, isDeleted bool) any
// GetDeleteCondition returns WHERE condition for soft delete query.
GetDeleteCondition(ctx context.Context) string
// GetDeleteData returns UPDATE statement data for soft delete.
GetDeleteData(ctx context.Context, prefix, fieldName string, localType LocalType) (holder string, value any)
}
// getSoftFieldNameAndTypeCacheItem is the internal struct for storing create/update/delete fields.
type getSoftFieldNameAndTypeCacheItem struct {
FieldName string
FieldType LocalType
}
var (
// Default field names of table for automatic-filled for record creating.
createdFieldNames = []string{"created_at", "create_at"}
// Default field names of table for automatic-filled for record updating.
updatedFieldNames = []string{"updated_at", "update_at"}
// Default field names of table for automatic-filled for record deleting.
deletedFieldNames = []string{"deleted_at", "delete_at"}
)
// SoftTime sets the SoftTimeOption to customize soft time feature for Model.
func (m *Model) SoftTime(option SoftTimeOption) *Model {
model := m.getModel()
model.softTimeOption = option
return model
}
// Unscoped disables the soft time feature for insert, update and delete operations.
func (m *Model) Unscoped() *Model {
model := m.getModel()
model.unscoped = true
return model
}
func (m *Model) softTimeMaintainer() iSoftTimeMaintainer {
return &softTimeMaintainer{
m,
}
}
// GetFieldInfo returns field name and type for specified field purpose.
// It checks the key with or without cases or chars '-'/'_'/'.'/' '.
func (m *softTimeMaintainer) GetFieldInfo(
ctx context.Context, schema, table string, fieldPurpose SoftTimeFieldType,
) (fieldName string, localType LocalType) {
// Check if feature is disabled
if m.db.GetConfig().TimeMaintainDisabled {
return "", LocalTypeUndefined
}
// Determine table name
tableName := table
if tableName == "" {
tableName = m.tablesInit
}
// Get config and field candidates
config := m.db.GetConfig()
var (
configField string
defaultFields []string
)
switch fieldPurpose {
case SoftTimeFieldCreate:
configField = config.CreatedAt
defaultFields = createdFieldNames
case SoftTimeFieldUpdate:
configField = config.UpdatedAt
defaultFields = updatedFieldNames
case SoftTimeFieldDelete:
configField = config.DeletedAt
defaultFields = deletedFieldNames
}
// Use config field if specified, otherwise use defaults
if configField != "" {
return m.getSoftFieldNameAndType(ctx, schema, tableName, []string{configField})
}
return m.getSoftFieldNameAndType(ctx, schema, tableName, defaultFields)
}
// getSoftFieldNameAndType retrieves and returns the field name of the table for possible key.
func (m *softTimeMaintainer) getSoftFieldNameAndType(
ctx context.Context, schema, table string, candidateFields []string,
) (fieldName string, fieldType LocalType) {
// Build cache key
cacheKey := genSoftTimeFieldNameTypeCacheKey(schema, table, candidateFields)
// Try to get from cache
cache := m.db.GetCore().GetInnerMemCache()
result, err := cache.GetOrSetFunc(ctx, cacheKey, func(ctx context.Context) (any, error) {
// Get table fields
fieldsMap, err := m.TableFields(table, schema)
if err != nil || len(fieldsMap) == 0 {
return nil, err
}
// Search for matching field
for _, field := range candidateFields {
if name := searchFieldNameFromMap(fieldsMap, field); name != "" {
fType, _ := m.db.CheckLocalTypeForField(ctx, fieldsMap[name].Type, nil)
return getSoftFieldNameAndTypeCacheItem{
FieldName: name,
FieldType: fType,
}, nil
}
}
return nil, nil
}, gcache.DurationNoExpire)
if err != nil || result == nil {
return "", LocalTypeUndefined
}
item := result.Val().(getSoftFieldNameAndTypeCacheItem)
return item.FieldName, item.FieldType
}
func searchFieldNameFromMap(fieldsMap map[string]*TableField, key string) string {
if len(fieldsMap) == 0 {
return ""
}
_, ok := fieldsMap[key]
if ok {
return key
}
key = utils.RemoveSymbols(key)
for k := range fieldsMap {
if strings.EqualFold(utils.RemoveSymbols(k), key) {
return k
}
}
return ""
}
// GetDeleteCondition returns WHERE condition for soft delete query.
// It supports multiple tables string like:
// "user u, user_detail ud"
// "user u LEFT JOIN user_detail ud ON(ud.uid=u.uid)"
// "user LEFT JOIN user_detail ON(user_detail.uid=user.uid)"
// "user u LEFT JOIN user_detail ud ON(ud.uid=u.uid) LEFT JOIN user_stats us ON(us.uid=u.uid)".
func (m *softTimeMaintainer) GetDeleteCondition(ctx context.Context) string {
if m.unscoped {
return ""
}
conditionArray := garray.NewStrArray()
if gstr.Contains(m.tables, " JOIN ") {
// Base table.
tableMatch, _ := gregex.MatchString(`(.+?) [A-Z]+ JOIN`, m.tables)
conditionArray.Append(m.getConditionOfTableStringForSoftDeleting(ctx, tableMatch[1]))
// Multiple joined tables, exclude the sub query sql which contains char '(' and ')'.
tableMatches, _ := gregex.MatchAllString(`JOIN ([^()]+?) ON`, m.tables)
for _, match := range tableMatches {
conditionArray.Append(m.getConditionOfTableStringForSoftDeleting(ctx, match[1]))
}
}
if conditionArray.Len() == 0 && gstr.Contains(m.tables, ",") {
// Multiple base tables.
for _, s := range gstr.SplitAndTrim(m.tables, ",") {
conditionArray.Append(m.getConditionOfTableStringForSoftDeleting(ctx, s))
}
}
conditionArray.FilterEmpty()
if conditionArray.Len() > 0 {
return conditionArray.Join(" AND ")
}
// Only one table.
fieldName, fieldType := m.GetFieldInfo(ctx, "", m.tablesInit, SoftTimeFieldDelete)
if fieldName != "" {
return m.buildDeleteCondition(ctx, "", fieldName, fieldType)
}
return ""
}
// getConditionOfTableStringForSoftDeleting does something as its name describes.
// Examples for `s`:
// - `test`.`demo` as b
// - `test`.`demo` b
// - `demo`
// - demo
func (m *softTimeMaintainer) getConditionOfTableStringForSoftDeleting(ctx context.Context, s string) string {
var (
table string
schema string
array1 = gstr.SplitAndTrim(s, " ")
array2 = gstr.SplitAndTrim(array1[0], ".")
)
if len(array2) >= 2 {
table = array2[1]
schema = array2[0]
} else {
table = array2[0]
}
fieldName, fieldType := m.GetFieldInfo(ctx, schema, table, SoftTimeFieldDelete)
if fieldName == "" {
return ""
}
if len(array1) >= 3 {
return m.buildDeleteCondition(ctx, array1[2], fieldName, fieldType)
}
if len(array1) >= 2 {
return m.buildDeleteCondition(ctx, array1[1], fieldName, fieldType)
}
return m.buildDeleteCondition(ctx, table, fieldName, fieldType)
}
// GetDeleteData returns UPDATE statement data for soft delete.
func (m *softTimeMaintainer) GetDeleteData(
ctx context.Context, prefix, fieldName string, fieldType LocalType,
) (holder string, value any) {
core := m.db.GetCore()
quotedName := core.QuoteWord(fieldName)
if prefix != "" {
quotedName = fmt.Sprintf(`%s.%s`, core.QuoteWord(prefix), quotedName)
}
holder = fmt.Sprintf(`%s=?`, quotedName)
value = m.GetFieldValue(ctx, fieldType, false)
return
}
// buildDeleteCondition builds WHERE condition for soft delete filtering.
func (m *softTimeMaintainer) buildDeleteCondition(
ctx context.Context, prefix, fieldName string, fieldType LocalType,
) string {
core := m.db.GetCore()
quotedName := core.QuoteWord(fieldName)
if prefix != "" {
quotedName = fmt.Sprintf(`%s.%s`, core.QuoteWord(prefix), quotedName)
}
switch m.softTimeOption.SoftTimeType {
case SoftTimeTypeAuto:
switch fieldType {
case LocalTypeDate, LocalTypeTime, LocalTypeDatetime:
return fmt.Sprintf(`%s IS NULL`, quotedName)
case LocalTypeInt, LocalTypeUint, LocalTypeInt64, LocalTypeUint64, LocalTypeBool:
return fmt.Sprintf(`%s=0`, quotedName)
default:
intlog.Errorf(ctx, `invalid field type "%s" for soft delete condition: prefix=%s, field=%s`, fieldType, prefix, fieldName)
return ""
}
case SoftTimeTypeTime:
return fmt.Sprintf(`%s IS NULL`, quotedName)
default:
return fmt.Sprintf(`%s=0`, quotedName)
}
}
// GetFieldValue generates value for create/update/delete operations.
func (m *softTimeMaintainer) GetFieldValue(
ctx context.Context, fieldType LocalType, isDeleted bool,
) any {
// For deleted field, return "empty" value
if isDeleted {
return m.getEmptyValue(fieldType)
}
// For create/update/delete, return current time value
switch m.softTimeOption.SoftTimeType {
case SoftTimeTypeAuto:
return m.getAutoValue(ctx, fieldType)
default:
switch fieldType {
case LocalTypeBool:
return 1
default:
return m.getTimestampValue()
}
}
}
// getTimestampValue returns timestamp value for soft time.
func (m *softTimeMaintainer) getTimestampValue() any {
switch m.softTimeOption.SoftTimeType {
case SoftTimeTypeTime:
return gtime.Now()
case SoftTimeTypeTimestamp:
return gtime.Timestamp()
case SoftTimeTypeTimestampMilli:
return gtime.TimestampMilli()
case SoftTimeTypeTimestampMicro:
return gtime.TimestampMicro()
case SoftTimeTypeTimestampNano:
return gtime.TimestampNano()
default:
panic(gerror.NewCodef(
gcode.CodeInternalPanic,
`unrecognized SoftTimeType "%d"`, m.softTimeOption.SoftTimeType,
))
}
}
// getEmptyValue returns "empty" value for deleted field.
func (m *softTimeMaintainer) getEmptyValue(fieldType LocalType) any {
switch fieldType {
case LocalTypeDate, LocalTypeTime, LocalTypeDatetime:
return nil
default:
return 0
}
}
// getAutoValue returns auto-detected value based on field type.
func (m *softTimeMaintainer) getAutoValue(ctx context.Context, fieldType LocalType) any {
switch fieldType {
case LocalTypeDate, LocalTypeTime, LocalTypeDatetime:
return gtime.Now()
case LocalTypeInt, LocalTypeUint, LocalTypeInt64, LocalTypeUint64:
return gtime.Timestamp()
case LocalTypeBool:
return 1
default:
intlog.Errorf(ctx, `invalid field type "%s" for soft time auto value`, fieldType)
return nil
}
}

View File

@ -0,0 +1,42 @@
// Copyright GoFrame Author(https://goframe.org). All Rights Reserved.
//
// This Source Code Form is subject to the terms of the MIT License.
// If a copy of the MIT was not distributed with this file,
// You can obtain one at https://github.com/gogf/gf.
package database
import (
"context"
)
// Transaction wraps the transaction logic using function `f`.
// It rollbacks the transaction and returns the error from function `f` if
// it returns non-nil error. It commits the transaction and returns nil if
// function `f` returns nil.
//
// Note that, you should not Commit or Rollback the transaction in function `f`
// as it is automatically handled by this function.
func (m *Model) Transaction(ctx context.Context, f func(ctx context.Context, tx TX) error) (err error) {
if ctx == nil {
ctx = m.GetCtx()
}
if m.tx != nil {
return m.tx.Transaction(ctx, f)
}
return m.db.Transaction(ctx, f)
}
// TransactionWithOptions executes transaction with options.
// The parameter `opts` specifies the transaction options.
// The parameter `f` specifies the function that will be called within the transaction.
// If f returns error, the transaction will be rolled back, or else the transaction will be committed.
func (m *Model) TransactionWithOptions(ctx context.Context, opts TxOptions, f func(ctx context.Context, tx TX) error) (err error) {
if ctx == nil {
ctx = m.GetCtx()
}
if m.tx != nil {
return m.tx.Transaction(ctx, f)
}
return m.db.TransactionWithOptions(ctx, opts, f)
}

View File

@ -0,0 +1,139 @@
// Copyright GoFrame Author(https://goframe.org). All Rights Reserved.
//
// This Source Code Form is subject to the terms of the MIT License.
// If a copy of the MIT was not distributed with this file,
// You can obtain one at https://github.com/gogf/gf.
package database
import (
"database/sql"
"fmt"
"reflect"
"git.magicany.cc/black1552/gin-base/database/empty"
"git.magicany.cc/black1552/gin-base/database/intlog"
"git.magicany.cc/black1552/gin-base/database/reflection"
"github.com/gogf/gf/v2/errors/gcode"
"github.com/gogf/gf/v2/errors/gerror"
"github.com/gogf/gf/v2/text/gstr"
"github.com/gogf/gf/v2/util/gconv"
)
// Update does "UPDATE ... " statement for the model.
//
// If the optional parameter `dataAndWhere` is given, the dataAndWhere[0] is the updated data field,
// and dataAndWhere[1:] is treated as where condition fields.
// Also see Model.Data and Model.Where functions.
func (m *Model) Update(dataAndWhere ...any) (result sql.Result, err error) {
var ctx = m.GetCtx()
if len(dataAndWhere) > 0 {
if len(dataAndWhere) > 2 {
return m.Data(dataAndWhere[0]).Where(dataAndWhere[1], dataAndWhere[2:]...).Update()
} else if len(dataAndWhere) == 2 {
return m.Data(dataAndWhere[0]).Where(dataAndWhere[1]).Update()
} else {
return m.Data(dataAndWhere[0]).Update()
}
}
defer func() {
if err == nil {
m.checkAndRemoveSelectCache(ctx)
}
}()
if m.data == nil {
return nil, gerror.NewCode(gcode.CodeMissingParameter, "updating table with empty data")
}
var (
newData any
stm = m.softTimeMaintainer()
reflectInfo = reflection.OriginTypeAndKind(m.data)
conditionWhere, conditionExtra, conditionArgs = m.formatCondition(ctx, false, false)
conditionStr = conditionWhere + conditionExtra
fieldNameUpdate, fieldTypeUpdate = stm.GetFieldInfo(ctx, "", m.tablesInit, SoftTimeFieldUpdate)
)
if fieldNameUpdate != "" && (m.unscoped || m.isFieldInFieldsEx(fieldNameUpdate)) {
fieldNameUpdate = ""
}
newData, err = m.filterDataForInsertOrUpdate(m.data)
if err != nil {
return nil, err
}
switch reflectInfo.OriginKind {
case reflect.Map, reflect.Struct:
var dataMap = anyValueToMapBeforeToRecord(newData)
// Automatically update the record updating time.
if fieldNameUpdate != "" && empty.IsNil(dataMap[fieldNameUpdate]) {
dataValue := stm.GetFieldValue(ctx, fieldTypeUpdate, false)
dataMap[fieldNameUpdate] = dataValue
}
newData = dataMap
default:
var updateStr = gconv.String(newData)
// Automatically update the record updating time.
if fieldNameUpdate != "" && !gstr.Contains(updateStr, fieldNameUpdate) {
dataValue := stm.GetFieldValue(ctx, fieldTypeUpdate, false)
updateStr += fmt.Sprintf(`,%s=?`, fieldNameUpdate)
conditionArgs = append([]any{dataValue}, conditionArgs...)
}
newData = updateStr
}
if !gstr.ContainsI(conditionStr, " WHERE ") {
intlog.Printf(
ctx,
`sql condition string "%s" has no WHERE for UPDATE operation, fieldNameUpdate: %s`,
conditionStr, fieldNameUpdate,
)
return nil, gerror.NewCode(
gcode.CodeMissingParameter,
"there should be WHERE condition statement for UPDATE operation",
)
}
in := &HookUpdateInput{
internalParamHookUpdate: internalParamHookUpdate{
internalParamHook: internalParamHook{
link: m.getLink(true),
},
handler: m.hookHandler.Update,
},
Model: m,
Table: m.tables,
Schema: m.schema,
Data: newData,
Condition: conditionStr,
Args: m.mergeArguments(conditionArgs),
}
return in.Next(ctx)
}
// UpdateAndGetAffected performs update statement and returns the affected rows number.
func (m *Model) UpdateAndGetAffected(dataAndWhere ...any) (affected int64, err error) {
result, err := m.Update(dataAndWhere...)
if err != nil {
return 0, err
}
return result.RowsAffected()
}
// Increment increments a column's value by a given amount.
// The parameter `amount` can be type of float or integer.
func (m *Model) Increment(column string, amount any) (sql.Result, error) {
return m.getModel().Data(column, &Counter{
Field: column,
Value: gconv.Float64(amount),
}).Update()
}
// Decrement decrements a column's value by a given amount.
// The parameter `amount` can be type of float or integer.
func (m *Model) Decrement(column string, amount any) (sql.Result, error) {
return m.getModel().Data(column, &Counter{
Field: column,
Value: -gconv.Float64(amount),
}).Update()
}

View File

@ -0,0 +1,314 @@
// Copyright GoFrame Author(https://goframe.org). All Rights Reserved.
//
// This Source Code Form is subject to the terms of the MIT License.
// If a copy of the MIT was not distributed with this file,
// You can obtain one at https://github.com/gogf/gf.
package database
import (
"time"
"git.magicany.cc/black1552/gin-base/database/empty"
"github.com/gogf/gf/v2/container/gset"
"github.com/gogf/gf/v2/os/gtime"
"github.com/gogf/gf/v2/text/gregex"
"github.com/gogf/gf/v2/text/gstr"
"github.com/gogf/gf/v2/util/gconv"
"github.com/gogf/gf/v2/util/gutil"
)
// QuoteWord checks given string `s` a word,
// if true it quotes `s` with security chars of the database
// and returns the quoted string; or else it returns `s` without any change.
//
// The meaning of a `word` can be considered as a column name.
func (m *Model) QuoteWord(s string) string {
return m.db.GetCore().QuoteWord(s)
}
// TableFields retrieves and returns the fields' information of specified table of current
// schema.
//
// Also see DriverMysql.TableFields.
func (m *Model) TableFields(tableStr string, schema ...string) (fields map[string]*TableField, err error) {
var (
ctx = m.GetCtx()
usedTable = m.db.GetCore().guessPrimaryTableName(tableStr)
usedSchema = gutil.GetOrDefaultStr(m.schema, schema...)
)
// Sharding feature.
usedSchema, err = m.getActualSchema(ctx, usedSchema)
if err != nil {
return nil, err
}
usedTable, err = m.getActualTable(ctx, usedTable)
if err != nil {
return nil, err
}
return m.db.TableFields(ctx, usedTable, usedSchema)
}
// getModel creates and returns a cloned model of current model if `safe` is true, or else it returns
// the current model.
func (m *Model) getModel() *Model {
if !m.safe {
return m
} else {
return m.Clone()
}
}
// mappingAndFilterToTableFields mappings and changes given field name to really table field name.
// Eg:
// ID -> id
// NICK_Name -> nickname.
func (m *Model) mappingAndFilterToTableFields(table string, fields []any, filter bool) []any {
var fieldsTable = table
if fieldsTable != "" {
hasTable, _ := m.db.GetCore().HasTable(fieldsTable)
if !hasTable {
if fieldsTable != m.tablesInit {
// Table/alias unknown (e.g., FieldsPrefix called before LeftJoin), skip filtering.
return fields
}
// HasTable cache miss for main table, fallback to use main table for field mapping.
fieldsTable = m.tablesInit
}
}
if fieldsTable == "" {
fieldsTable = m.tablesInit
}
fieldsMap, _ := m.TableFields(fieldsTable)
if len(fieldsMap) == 0 {
return fields
}
var outputFieldsArray = make([]any, 0)
fieldsKeyMap := make(map[string]any, len(fieldsMap))
for k := range fieldsMap {
fieldsKeyMap[k] = nil
}
for _, field := range fields {
var (
fieldStr = gconv.String(field)
inputFieldsArray []string
)
// Skip empty string fields.
if fieldStr == "" {
continue
}
switch {
case gregex.IsMatchString(regularFieldNameWithoutDotRegPattern, fieldStr):
inputFieldsArray = append(inputFieldsArray, fieldStr)
case gregex.IsMatchString(regularFieldNameWithCommaRegPattern, fieldStr):
inputFieldsArray = gstr.SplitAndTrim(fieldStr, ",")
default:
// Example:
// user.id, user.name
// replace(concat_ws(',',lpad(s.id, 6, '0'),s.name),',','') `code`
outputFieldsArray = append(outputFieldsArray, field)
continue
}
for _, inputField := range inputFieldsArray {
if !gregex.IsMatchString(regularFieldNameWithoutDotRegPattern, inputField) {
outputFieldsArray = append(outputFieldsArray, inputField)
continue
}
if _, ok := fieldsKeyMap[inputField]; !ok {
// Example:
// id, name
if foundKey, _ := gutil.MapPossibleItemByKey(fieldsKeyMap, inputField); foundKey != "" {
outputFieldsArray = append(outputFieldsArray, foundKey)
} else if !filter {
outputFieldsArray = append(outputFieldsArray, inputField)
}
} else {
outputFieldsArray = append(outputFieldsArray, inputField)
}
}
}
return outputFieldsArray
}
// filterDataForInsertOrUpdate does filter feature with data for inserting/updating operations.
// Note that, it does not filter list item, which is also type of map, for "omit empty" feature.
func (m *Model) filterDataForInsertOrUpdate(data any) (any, error) {
var err error
switch value := data.(type) {
case List:
var omitEmpty bool
if m.option&optionOmitNilDataList > 0 {
omitEmpty = true
}
for k, item := range value {
value[k], err = m.doMappingAndFilterForInsertOrUpdateDataMap(item, omitEmpty)
if err != nil {
return nil, err
}
}
return value, nil
case Map:
return m.doMappingAndFilterForInsertOrUpdateDataMap(value, true)
default:
return data, nil
}
}
// doMappingAndFilterForInsertOrUpdateDataMap does the filter features for map.
// Note that, it does not filter list item, which is also type of map, for "omit empty" feature.
func (m *Model) doMappingAndFilterForInsertOrUpdateDataMap(data Map, allowOmitEmpty bool) (Map, error) {
var (
err error
ctx = m.GetCtx()
core = m.db.GetCore()
schema = m.schema
table = m.tablesInit
)
// Sharding feature.
schema, err = m.getActualSchema(ctx, schema)
if err != nil {
return nil, err
}
table, err = m.getActualTable(ctx, table)
if err != nil {
return nil, err
}
data, err = core.mappingAndFilterData(
ctx, schema, table, data, m.filter,
)
if err != nil {
return nil, err
}
// Remove key-value pairs of which the value is nil.
if allowOmitEmpty && m.option&optionOmitNilData > 0 {
tempMap := make(Map, len(data))
for k, v := range data {
if empty.IsNil(v) {
continue
}
tempMap[k] = v
}
data = tempMap
}
// Remove key-value pairs of which the value is empty.
if allowOmitEmpty && m.option&optionOmitEmptyData > 0 {
tempMap := make(Map, len(data))
for k, v := range data {
if empty.IsEmpty(v) {
continue
}
// Special type filtering.
switch r := v.(type) {
case time.Time:
if r.IsZero() {
continue
}
case *time.Time:
if r.IsZero() {
continue
}
case gtime.Time:
if r.IsZero() {
continue
}
case *gtime.Time:
if r.IsZero() {
continue
}
}
tempMap[k] = v
}
data = tempMap
}
if len(m.fields) > 0 {
// Keep specified fields.
var (
fieldSet = gset.NewStrSetFrom(gconv.Strings(m.fields))
charL, charR = m.db.GetChars()
chars = charL + charR
)
fieldSet.Walk(func(item string) string {
return gstr.Trim(item, chars)
})
for k := range data {
k = gstr.Trim(k, chars)
if !fieldSet.Contains(k) {
delete(data, k)
}
}
} else if len(m.fieldsEx) > 0 {
// Filter specified fields.
for _, v := range m.fieldsEx {
delete(data, gconv.String(v))
}
}
return data, nil
}
// getLink returns the underlying database link object with configured `linkType` attribute.
// The parameter `master` specifies whether using the master node if master-slave configured.
func (m *Model) getLink(master bool) Link {
if m.tx != nil {
if sqlTx := m.tx.GetSqlTX(); sqlTx != nil {
return &txLink{sqlTx}
}
}
linkType := m.linkType
if linkType == 0 {
if master {
linkType = linkTypeMaster
} else {
linkType = linkTypeSlave
}
}
switch linkType {
case linkTypeMaster:
link, err := m.db.GetCore().MasterLink(m.schema)
if err != nil {
panic(err)
}
return link
case linkTypeSlave:
link, err := m.db.GetCore().SlaveLink(m.schema)
if err != nil {
panic(err)
}
return link
}
return nil
}
// getPrimaryKey retrieves and returns the primary key name of the model table.
// It parses m.tables to retrieve the primary table name, supporting m.tables like:
// "user", "user u", "user as u, user_detail as ud".
func (m *Model) getPrimaryKey() string {
table := gstr.SplitAndTrim(m.tablesInit, " ")[0]
tableFields, err := m.TableFields(table)
if err != nil {
return ""
}
for name, field := range tableFields {
if gstr.ContainsI(field.Key, "pri") {
return name
}
}
return ""
}
// mergeArguments creates and returns new arguments by merging `m.extraArgs` and given `args`.
func (m *Model) mergeArguments(args []any) []any {
if len(m.extraArgs) > 0 {
newArgs := make([]any, len(m.extraArgs)+len(args))
copy(newArgs, m.extraArgs)
copy(newArgs[len(m.extraArgs):], args)
return newArgs
}
return args
}

129
database/gdb_model_where.go Normal file
View File

@ -0,0 +1,129 @@
// Copyright GoFrame Author(https://goframe.org). All Rights Reserved.
//
// This Source Code Form is subject to the terms of the MIT License.
// If a copy of the MIT was not distributed with this file,
// You can obtain one at https://githum.com/gogf/gf.
package database
// callWhereBuilder creates and returns a new Model, and sets its WhereBuilder if current Model is safe.
// It sets the WhereBuilder and returns current Model directly if it is not safe.
func (m *Model) callWhereBuilder(builder *WhereBuilder) *Model {
model := m.getModel()
model.whereBuilder = builder
return model
}
// Where sets the condition statement for the builder. The parameter `where` can be type of
// string/map/gmap/slice/struct/*struct, etc. Note that, if it's called more than one times,
// multiple conditions will be joined into where statement using "AND".
// See WhereBuilder.Where.
func (m *Model) Where(where any, args ...any) *Model {
return m.callWhereBuilder(m.whereBuilder.Where(where, args...))
}
// Wheref builds condition string using fmt.Sprintf and arguments.
// Note that if the number of `args` is more than the placeholder in `format`,
// the extra `args` will be used as the where condition arguments of the Model.
// See WhereBuilder.Wheref.
func (m *Model) Wheref(format string, args ...any) *Model {
return m.callWhereBuilder(m.whereBuilder.Wheref(format, args...))
}
// WherePri does the same logic as Model.Where except that if the parameter `where`
// is a single condition like int/string/float/slice, it treats the condition as the primary
// key value. That is, if primary key is "id" and given `where` parameter as "123", the
// WherePri function treats the condition as "id=123", but Model.Where treats the condition
// as string "123".
// See WhereBuilder.WherePri.
func (m *Model) WherePri(where any, args ...any) *Model {
return m.callWhereBuilder(m.whereBuilder.WherePri(where, args...))
}
// WhereLT builds `column < value` statement.
// See WhereBuilder.WhereLT.
func (m *Model) WhereLT(column string, value any) *Model {
return m.callWhereBuilder(m.whereBuilder.WhereLT(column, value))
}
// WhereLTE builds `column <= value` statement.
// See WhereBuilder.WhereLTE.
func (m *Model) WhereLTE(column string, value any) *Model {
return m.callWhereBuilder(m.whereBuilder.WhereLTE(column, value))
}
// WhereGT builds `column > value` statement.
// See WhereBuilder.WhereGT.
func (m *Model) WhereGT(column string, value any) *Model {
return m.callWhereBuilder(m.whereBuilder.WhereGT(column, value))
}
// WhereGTE builds `column >= value` statement.
// See WhereBuilder.WhereGTE.
func (m *Model) WhereGTE(column string, value any) *Model {
return m.callWhereBuilder(m.whereBuilder.WhereGTE(column, value))
}
// WhereBetween builds `column BETWEEN min AND max` statement.
// See WhereBuilder.WhereBetween.
func (m *Model) WhereBetween(column string, min, max any) *Model {
return m.callWhereBuilder(m.whereBuilder.WhereBetween(column, min, max))
}
// WhereLike builds `column LIKE like` statement.
// See WhereBuilder.WhereLike.
func (m *Model) WhereLike(column string, like string) *Model {
return m.callWhereBuilder(m.whereBuilder.WhereLike(column, like))
}
// WhereIn builds `column IN (in)` statement.
// See WhereBuilder.WhereIn.
func (m *Model) WhereIn(column string, in any) *Model {
return m.callWhereBuilder(m.whereBuilder.WhereIn(column, in))
}
// WhereNull builds `columns[0] IS NULL AND columns[1] IS NULL ...` statement.
// See WhereBuilder.WhereNull.
func (m *Model) WhereNull(columns ...string) *Model {
return m.callWhereBuilder(m.whereBuilder.WhereNull(columns...))
}
// WhereNotBetween builds `column NOT BETWEEN min AND max` statement.
// See WhereBuilder.WhereNotBetween.
func (m *Model) WhereNotBetween(column string, min, max any) *Model {
return m.callWhereBuilder(m.whereBuilder.WhereNotBetween(column, min, max))
}
// WhereNotLike builds `column NOT LIKE like` statement.
// See WhereBuilder.WhereNotLike.
func (m *Model) WhereNotLike(column string, like any) *Model {
return m.callWhereBuilder(m.whereBuilder.WhereNotLike(column, like))
}
// WhereNot builds `column != value` statement.
// See WhereBuilder.WhereNot.
func (m *Model) WhereNot(column string, value any) *Model {
return m.callWhereBuilder(m.whereBuilder.WhereNot(column, value))
}
// WhereNotIn builds `column NOT IN (in)` statement.
// See WhereBuilder.WhereNotIn.
func (m *Model) WhereNotIn(column string, in any) *Model {
return m.callWhereBuilder(m.whereBuilder.WhereNotIn(column, in))
}
// WhereNotNull builds `columns[0] IS NOT NULL AND columns[1] IS NOT NULL ...` statement.
// See WhereBuilder.WhereNotNull.
func (m *Model) WhereNotNull(columns ...string) *Model {
return m.callWhereBuilder(m.whereBuilder.WhereNotNull(columns...))
}
// WhereExists builds `EXISTS (subQuery)` statement.
func (m *Model) WhereExists(subQuery *Model) *Model {
return m.callWhereBuilder(m.whereBuilder.WhereExists(subQuery))
}
// WhereNotExists builds `NOT EXISTS (subQuery)` statement.
func (m *Model) WhereNotExists(subQuery *Model) *Model {
return m.callWhereBuilder(m.whereBuilder.WhereNotExists(subQuery))
}

View File

@ -0,0 +1,91 @@
// Copyright GoFrame Author(https://goframe.org). All Rights Reserved.
//
// This Source Code Form is subject to the terms of the MIT License.
// If a copy of the MIT was not distributed with this file,
// You can obtain one at https://github.com/gogf/gf.
package database
// WherePrefix performs as Where, but it adds prefix to each field in where statement.
// See WhereBuilder.WherePrefix.
func (m *Model) WherePrefix(prefix string, where any, args ...any) *Model {
return m.callWhereBuilder(m.whereBuilder.WherePrefix(prefix, where, args...))
}
// WherePrefixLT builds `prefix.column < value` statement.
// See WhereBuilder.WherePrefixLT.
func (m *Model) WherePrefixLT(prefix string, column string, value any) *Model {
return m.callWhereBuilder(m.whereBuilder.WherePrefixLT(prefix, column, value))
}
// WherePrefixLTE builds `prefix.column <= value` statement.
// See WhereBuilder.WherePrefixLTE.
func (m *Model) WherePrefixLTE(prefix string, column string, value any) *Model {
return m.callWhereBuilder(m.whereBuilder.WherePrefixLTE(prefix, column, value))
}
// WherePrefixGT builds `prefix.column > value` statement.
// See WhereBuilder.WherePrefixGT.
func (m *Model) WherePrefixGT(prefix string, column string, value any) *Model {
return m.callWhereBuilder(m.whereBuilder.WherePrefixGT(prefix, column, value))
}
// WherePrefixGTE builds `prefix.column >= value` statement.
// See WhereBuilder.WherePrefixGTE.
func (m *Model) WherePrefixGTE(prefix string, column string, value any) *Model {
return m.callWhereBuilder(m.whereBuilder.WherePrefixGTE(prefix, column, value))
}
// WherePrefixBetween builds `prefix.column BETWEEN min AND max` statement.
// See WhereBuilder.WherePrefixBetween.
func (m *Model) WherePrefixBetween(prefix string, column string, min, max any) *Model {
return m.callWhereBuilder(m.whereBuilder.WherePrefixBetween(prefix, column, min, max))
}
// WherePrefixLike builds `prefix.column LIKE like` statement.
// See WhereBuilder.WherePrefixLike.
func (m *Model) WherePrefixLike(prefix string, column string, like any) *Model {
return m.callWhereBuilder(m.whereBuilder.WherePrefixLike(prefix, column, like))
}
// WherePrefixIn builds `prefix.column IN (in)` statement.
// See WhereBuilder.WherePrefixIn.
func (m *Model) WherePrefixIn(prefix string, column string, in any) *Model {
return m.callWhereBuilder(m.whereBuilder.WherePrefixIn(prefix, column, in))
}
// WherePrefixNull builds `prefix.columns[0] IS NULL AND prefix.columns[1] IS NULL ...` statement.
// See WhereBuilder.WherePrefixNull.
func (m *Model) WherePrefixNull(prefix string, columns ...string) *Model {
return m.callWhereBuilder(m.whereBuilder.WherePrefixNull(prefix, columns...))
}
// WherePrefixNotBetween builds `prefix.column NOT BETWEEN min AND max` statement.
// See WhereBuilder.WherePrefixNotBetween.
func (m *Model) WherePrefixNotBetween(prefix string, column string, min, max any) *Model {
return m.callWhereBuilder(m.whereBuilder.WherePrefixNotBetween(prefix, column, min, max))
}
// WherePrefixNotLike builds `prefix.column NOT LIKE like` statement.
// See WhereBuilder.WherePrefixNotLike.
func (m *Model) WherePrefixNotLike(prefix string, column string, like any) *Model {
return m.callWhereBuilder(m.whereBuilder.WherePrefixNotLike(prefix, column, like))
}
// WherePrefixNot builds `prefix.column != value` statement.
// See WhereBuilder.WherePrefixNot.
func (m *Model) WherePrefixNot(prefix string, column string, value any) *Model {
return m.callWhereBuilder(m.whereBuilder.WherePrefixNot(prefix, column, value))
}
// WherePrefixNotIn builds `prefix.column NOT IN (in)` statement.
// See WhereBuilder.WherePrefixNotIn.
func (m *Model) WherePrefixNotIn(prefix string, column string, in any) *Model {
return m.callWhereBuilder(m.whereBuilder.WherePrefixNotIn(prefix, column, in))
}
// WherePrefixNotNull builds `prefix.columns[0] IS NOT NULL AND prefix.columns[1] IS NOT NULL ...` statement.
// See WhereBuilder.WherePrefixNotNull.
func (m *Model) WherePrefixNotNull(prefix string, columns ...string) *Model {
return m.callWhereBuilder(m.whereBuilder.WherePrefixNotNull(prefix, columns...))
}

View File

@ -0,0 +1,97 @@
// Copyright GoFrame Author(https://goframe.org). All Rights Reserved.
//
// This Source Code Form is subject to the terms of the MIT License.
// If a copy of the MIT was not distributed with this file,
// You can obtain one at https://github.com/gogf/gf.
package database
// WhereOr adds "OR" condition to the where statement.
// See WhereBuilder.WhereOr.
func (m *Model) WhereOr(where any, args ...any) *Model {
return m.callWhereBuilder(m.whereBuilder.WhereOr(where, args...))
}
// WhereOrf builds `OR` condition string using fmt.Sprintf and arguments.
// See WhereBuilder.WhereOrf.
func (m *Model) WhereOrf(format string, args ...any) *Model {
return m.callWhereBuilder(m.whereBuilder.WhereOrf(format, args...))
}
// WhereOrLT builds `column < value` statement in `OR` conditions.
// See WhereBuilder.WhereOrLT.
func (m *Model) WhereOrLT(column string, value any) *Model {
return m.callWhereBuilder(m.whereBuilder.WhereOrLT(column, value))
}
// WhereOrLTE builds `column <= value` statement in `OR` conditions.
// See WhereBuilder.WhereOrLTE.
func (m *Model) WhereOrLTE(column string, value any) *Model {
return m.callWhereBuilder(m.whereBuilder.WhereOrLTE(column, value))
}
// WhereOrGT builds `column > value` statement in `OR` conditions.
// See WhereBuilder.WhereOrGT.
func (m *Model) WhereOrGT(column string, value any) *Model {
return m.callWhereBuilder(m.whereBuilder.WhereOrGT(column, value))
}
// WhereOrGTE builds `column >= value` statement in `OR` conditions.
// See WhereBuilder.WhereOrGTE.
func (m *Model) WhereOrGTE(column string, value any) *Model {
return m.callWhereBuilder(m.whereBuilder.WhereOrGTE(column, value))
}
// WhereOrBetween builds `column BETWEEN min AND max` statement in `OR` conditions.
// See WhereBuilder.WhereOrBetween.
func (m *Model) WhereOrBetween(column string, min, max any) *Model {
return m.callWhereBuilder(m.whereBuilder.WhereOrBetween(column, min, max))
}
// WhereOrLike builds `column LIKE like` statement in `OR` conditions.
// See WhereBuilder.WhereOrLike.
func (m *Model) WhereOrLike(column string, like any) *Model {
return m.callWhereBuilder(m.whereBuilder.WhereOrLike(column, like))
}
// WhereOrIn builds `column IN (in)` statement in `OR` conditions.
// See WhereBuilder.WhereOrIn.
func (m *Model) WhereOrIn(column string, in any) *Model {
return m.callWhereBuilder(m.whereBuilder.WhereOrIn(column, in))
}
// WhereOrNull builds `columns[0] IS NULL OR columns[1] IS NULL ...` statement in `OR` conditions.
// See WhereBuilder.WhereOrNull.
func (m *Model) WhereOrNull(columns ...string) *Model {
return m.callWhereBuilder(m.whereBuilder.WhereOrNull(columns...))
}
// WhereOrNotBetween builds `column NOT BETWEEN min AND max` statement in `OR` conditions.
// See WhereBuilder.WhereOrNotBetween.
func (m *Model) WhereOrNotBetween(column string, min, max any) *Model {
return m.callWhereBuilder(m.whereBuilder.WhereOrNotBetween(column, min, max))
}
// WhereOrNotLike builds `column NOT LIKE 'like'` statement in `OR` conditions.
// See WhereBuilder.WhereOrNotLike.
func (m *Model) WhereOrNotLike(column string, like any) *Model {
return m.callWhereBuilder(m.whereBuilder.WhereOrNotLike(column, like))
}
// WhereOrNot builds `column != value` statement.
// See WhereBuilder.WhereOrNot.
func (m *Model) WhereOrNot(column string, value any) *Model {
return m.callWhereBuilder(m.whereBuilder.WhereOrNot(column, value))
}
// WhereOrNotIn builds `column NOT IN (in)` statement.
// See WhereBuilder.WhereOrNotIn.
func (m *Model) WhereOrNotIn(column string, in any) *Model {
return m.callWhereBuilder(m.whereBuilder.WhereOrNotIn(column, in))
}
// WhereOrNotNull builds `columns[0] IS NOT NULL OR columns[1] IS NOT NULL ...` statement in `OR` conditions.
// See WhereBuilder.WhereOrNotNull.
func (m *Model) WhereOrNotNull(columns ...string) *Model {
return m.callWhereBuilder(m.whereBuilder.WhereOrNotNull(columns...))
}

View File

@ -0,0 +1,91 @@
// Copyright GoFrame Author(https://goframe.org). All Rights Reserved.
//
// This Source Code Form is subject to the terms of the MIT License.
// If a copy of the MIT was not distributed with this file,
// You can obtain one at https://github.com/gogf/gf.
package database
// WhereOrPrefix performs as WhereOr, but it adds prefix to each field in where statement.
// See WhereBuilder.WhereOrPrefix.
func (m *Model) WhereOrPrefix(prefix string, where any, args ...any) *Model {
return m.callWhereBuilder(m.whereBuilder.WhereOrPrefix(prefix, where, args...))
}
// WhereOrPrefixLT builds `prefix.column < value` statement in `OR` conditions.
// See WhereBuilder.WhereOrPrefixLT.
func (m *Model) WhereOrPrefixLT(prefix string, column string, value any) *Model {
return m.callWhereBuilder(m.whereBuilder.WhereOrPrefixLT(prefix, column, value))
}
// WhereOrPrefixLTE builds `prefix.column <= value` statement in `OR` conditions.
// See WhereBuilder.WhereOrPrefixLTE.
func (m *Model) WhereOrPrefixLTE(prefix string, column string, value any) *Model {
return m.callWhereBuilder(m.whereBuilder.WhereOrPrefixLTE(prefix, column, value))
}
// WhereOrPrefixGT builds `prefix.column > value` statement in `OR` conditions.
// See WhereBuilder.WhereOrPrefixGT.
func (m *Model) WhereOrPrefixGT(prefix string, column string, value any) *Model {
return m.callWhereBuilder(m.whereBuilder.WhereOrPrefixGT(prefix, column, value))
}
// WhereOrPrefixGTE builds `prefix.column >= value` statement in `OR` conditions.
// See WhereBuilder.WhereOrPrefixGTE.
func (m *Model) WhereOrPrefixGTE(prefix string, column string, value any) *Model {
return m.callWhereBuilder(m.whereBuilder.WhereOrPrefixGTE(prefix, column, value))
}
// WhereOrPrefixBetween builds `prefix.column BETWEEN min AND max` statement in `OR` conditions.
// See WhereBuilder.WhereOrPrefixBetween.
func (m *Model) WhereOrPrefixBetween(prefix string, column string, min, max any) *Model {
return m.callWhereBuilder(m.whereBuilder.WhereOrPrefixBetween(prefix, column, min, max))
}
// WhereOrPrefixLike builds `prefix.column LIKE like` statement in `OR` conditions.
// See WhereBuilder.WhereOrPrefixLike.
func (m *Model) WhereOrPrefixLike(prefix string, column string, like any) *Model {
return m.callWhereBuilder(m.whereBuilder.WhereOrPrefixLike(prefix, column, like))
}
// WhereOrPrefixIn builds `prefix.column IN (in)` statement in `OR` conditions.
// See WhereBuilder.WhereOrPrefixIn.
func (m *Model) WhereOrPrefixIn(prefix string, column string, in any) *Model {
return m.callWhereBuilder(m.whereBuilder.WhereOrPrefixIn(prefix, column, in))
}
// WhereOrPrefixNull builds `prefix.columns[0] IS NULL OR prefix.columns[1] IS NULL ...` statement in `OR` conditions.
// See WhereBuilder.WhereOrPrefixNull.
func (m *Model) WhereOrPrefixNull(prefix string, columns ...string) *Model {
return m.callWhereBuilder(m.whereBuilder.WhereOrPrefixNull(prefix, columns...))
}
// WhereOrPrefixNotBetween builds `prefix.column NOT BETWEEN min AND max` statement in `OR` conditions.
// See WhereBuilder.WhereOrPrefixNotBetween.
func (m *Model) WhereOrPrefixNotBetween(prefix string, column string, min, max any) *Model {
return m.callWhereBuilder(m.whereBuilder.WhereOrPrefixNotBetween(prefix, column, min, max))
}
// WhereOrPrefixNotLike builds `prefix.column NOT LIKE like` statement in `OR` conditions.
// See WhereBuilder.WhereOrPrefixNotLike.
func (m *Model) WhereOrPrefixNotLike(prefix string, column string, like any) *Model {
return m.callWhereBuilder(m.whereBuilder.WhereOrPrefixNotLike(prefix, column, like))
}
// WhereOrPrefixNotIn builds `prefix.column NOT IN (in)` statement.
// See WhereBuilder.WhereOrPrefixNotIn.
func (m *Model) WhereOrPrefixNotIn(prefix string, column string, in any) *Model {
return m.callWhereBuilder(m.whereBuilder.WhereOrPrefixNotIn(prefix, column, in))
}
// WhereOrPrefixNotNull builds `prefix.columns[0] IS NOT NULL OR prefix.columns[1] IS NOT NULL ...` statement in `OR` conditions.
// See WhereBuilder.WhereOrPrefixNotNull.
func (m *Model) WhereOrPrefixNotNull(prefix string, columns ...string) *Model {
return m.callWhereBuilder(m.whereBuilder.WhereOrPrefixNotNull(prefix, columns...))
}
// WhereOrPrefixNot builds `prefix.column != value` statement in `OR` conditions.
// See WhereBuilder.WhereOrPrefixNot.
func (m *Model) WhereOrPrefixNot(prefix string, column string, value any) *Model {
return m.callWhereBuilder(m.whereBuilder.WhereOrPrefixNot(prefix, column, value))
}

349
database/gdb_model_with.go Normal file
View File

@ -0,0 +1,349 @@
// Copyright GoFrame Author(https://goframe.org). All Rights Reserved.
//
// This Source Code Form is subject to the terms of the MIT License.
// If a copy of the MIT was not distributed with this file,
// You can obtain one at https://github.com/gogf/gf.
package database
import (
"database/sql"
"reflect"
"git.magicany.cc/black1552/gin-base/database/utils"
"github.com/gogf/gf/v2/errors/gcode"
"github.com/gogf/gf/v2/errors/gerror"
"github.com/gogf/gf/v2/os/gstructs"
"github.com/gogf/gf/v2/text/gstr"
"github.com/gogf/gf/v2/util/gutil"
)
// With creates and returns an ORM model based on metadata of given object.
// It also enables model association operations feature on given `object`.
// It can be called multiple times to add one or more objects to model and enable
// their mode association operations feature.
// For example, if given struct definition:
//
// type User struct {
// gmeta.Meta `orm:"table:user"`
// Id int `json:"id"`
// Name string `json:"name"`
// UserDetail *UserDetail `orm:"with:uid=id"`
// UserScores []*UserScores `orm:"with:uid=id"`
// }
//
// We can enable model association operations on attribute `UserDetail` and `UserScores` by:
//
// db.With(User{}.UserDetail).With(User{}.UserScores).Scan(xxx)
//
// Or:
//
// db.With(UserDetail{}).With(UserScores{}).Scan(xxx)
//
// Or:
//
// db.With(UserDetail{}, UserScores{}).Scan(xxx)
func (m *Model) With(objects ...any) *Model {
model := m.getModel()
for _, object := range objects {
if m.tables == "" {
m.tablesInit = m.db.GetCore().QuotePrefixTableName(
getTableNameFromOrmTag(object),
)
m.tables = m.tablesInit
return model
}
model.withArray = append(model.withArray, object)
}
return model
}
// WithAll enables model association operations on all objects that have "with" tag in the struct.
func (m *Model) WithAll() *Model {
model := m.getModel()
model.withAll = true
return model
}
// doWithScanStruct handles model association operations feature for single struct.
func (m *Model) doWithScanStruct(pointer any) error {
if len(m.withArray) == 0 && !m.withAll {
return nil
}
var (
err error
allowedTypeStrArray = make([]string, 0)
)
currentStructFieldMap, err := gstructs.FieldMap(gstructs.FieldMapInput{
Pointer: pointer,
PriorityTagArray: nil,
RecursiveOption: gstructs.RecursiveOptionEmbeddedNoTag,
})
if err != nil {
return err
}
// It checks the with array and automatically calls the ScanList to complete association querying.
if !m.withAll {
for _, field := range currentStructFieldMap {
for _, withItem := range m.withArray {
withItemReflectValueType, err := gstructs.StructType(withItem)
if err != nil {
return err
}
var (
fieldTypeStr = gstr.TrimAll(field.Type().String(), "*[]")
withItemReflectValueTypeStr = gstr.TrimAll(withItemReflectValueType.String(), "*[]")
)
// It does select operation if the field type is in the specified "with" type array.
if gstr.Compare(fieldTypeStr, withItemReflectValueTypeStr) == 0 {
allowedTypeStrArray = append(allowedTypeStrArray, fieldTypeStr)
}
}
}
}
for _, field := range currentStructFieldMap {
var (
fieldTypeStr = gstr.TrimAll(field.Type().String(), "*[]")
parsedTagOutput = m.parseWithTagInFieldStruct(field)
)
if parsedTagOutput.With == "" {
continue
}
// It just handlers "with" type attribute struct, so it ignores other struct types.
if !m.withAll && !gstr.InArray(allowedTypeStrArray, fieldTypeStr) {
continue
}
array := gstr.SplitAndTrim(parsedTagOutput.With, "=")
if len(array) == 1 {
// It also supports using only one column name
// if both tables associates using the same column name.
array = append(array, parsedTagOutput.With)
}
var (
model *Model
fieldKeys []string
relatedSourceName = array[0]
relatedTargetName = array[1]
relatedTargetValue any
)
// Find the value of related attribute from `pointer`.
for attributeName, attributeValue := range currentStructFieldMap {
if utils.EqualFoldWithoutChars(attributeName, relatedTargetName) {
relatedTargetValue = attributeValue.Value.Interface()
break
}
}
if relatedTargetValue == nil {
return gerror.NewCodef(
gcode.CodeInvalidParameter,
`cannot find the target related value of name "%s" in with tag "%s" for attribute "%s.%s"`,
relatedTargetName, parsedTagOutput.With, reflect.TypeOf(pointer).Elem(), field.Name(),
)
}
bindToReflectValue := field.Value
if bindToReflectValue.Kind() != reflect.Pointer && bindToReflectValue.CanAddr() {
bindToReflectValue = bindToReflectValue.Addr()
}
if structFields, err := gstructs.Fields(gstructs.FieldsInput{
Pointer: field.Value,
RecursiveOption: gstructs.RecursiveOptionEmbeddedNoTag,
}); err != nil {
return err
} else {
fieldKeys = make([]string, len(structFields))
for i, field := range structFields {
fieldKeys[i] = field.Name()
}
}
// Recursively with feature checks.
model = m.db.With(field.Value).Hook(m.hookHandler)
if m.withAll {
model = model.WithAll()
} else {
model = model.With(m.withArray...)
}
if parsedTagOutput.Where != "" {
model = model.Where(parsedTagOutput.Where)
}
if parsedTagOutput.Order != "" {
model = model.Order(parsedTagOutput.Order)
}
if parsedTagOutput.Unscoped == "true" {
model = model.Unscoped()
}
// With cache feature.
if m.cacheEnabled && m.cacheOption.Name == "" {
model = model.Cache(m.cacheOption)
}
err = model.Fields(fieldKeys).
Where(relatedSourceName, relatedTargetValue).
Scan(bindToReflectValue)
// It ignores sql.ErrNoRows in with feature.
if err != nil && err != sql.ErrNoRows {
return err
}
}
return nil
}
// doWithScanStructs handles model association operations feature for struct slice.
// Also see doWithScanStruct.
func (m *Model) doWithScanStructs(pointer any) error {
if len(m.withArray) == 0 && !m.withAll {
return nil
}
if v, ok := pointer.(reflect.Value); ok {
pointer = v.Interface()
}
var (
err error
allowedTypeStrArray = make([]string, 0)
)
currentStructFieldMap, err := gstructs.FieldMap(gstructs.FieldMapInput{
Pointer: pointer,
PriorityTagArray: nil,
RecursiveOption: gstructs.RecursiveOptionEmbeddedNoTag,
})
if err != nil {
return err
}
// It checks the with array and automatically calls the ScanList to complete association querying.
if !m.withAll {
for _, field := range currentStructFieldMap {
for _, withItem := range m.withArray {
withItemReflectValueType, err := gstructs.StructType(withItem)
if err != nil {
return err
}
var (
fieldTypeStr = gstr.TrimAll(field.Type().String(), "*[]")
withItemReflectValueTypeStr = gstr.TrimAll(withItemReflectValueType.String(), "*[]")
)
// It does select operation if the field type is in the specified with type array.
if gstr.Compare(fieldTypeStr, withItemReflectValueTypeStr) == 0 {
allowedTypeStrArray = append(allowedTypeStrArray, fieldTypeStr)
}
}
}
}
for fieldName, field := range currentStructFieldMap {
var (
fieldTypeStr = gstr.TrimAll(field.Type().String(), "*[]")
parsedTagOutput = m.parseWithTagInFieldStruct(field)
)
if parsedTagOutput.With == "" {
continue
}
if !m.withAll && !gstr.InArray(allowedTypeStrArray, fieldTypeStr) {
continue
}
array := gstr.SplitAndTrim(parsedTagOutput.With, "=")
if len(array) == 1 {
// It supports using only one column name
// if both tables associates using the same column name.
array = append(array, parsedTagOutput.With)
}
var (
model *Model
fieldKeys []string
relatedSourceName = array[0]
relatedTargetName = array[1]
relatedTargetValue any
)
// Find the value slice of related attribute from `pointer`.
for attributeName := range currentStructFieldMap {
if utils.EqualFoldWithoutChars(attributeName, relatedTargetName) {
relatedTargetValue = ListItemValuesUnique(pointer, attributeName)
break
}
}
if relatedTargetValue == nil {
return gerror.NewCodef(
gcode.CodeInvalidParameter,
`cannot find the related value for attribute name "%s" of with tag "%s"`,
relatedTargetName, parsedTagOutput.With,
)
}
// If related value is empty, it does nothing but just returns.
if gutil.IsEmpty(relatedTargetValue) {
return nil
}
if structFields, err := gstructs.Fields(gstructs.FieldsInput{
Pointer: field.Value,
RecursiveOption: gstructs.RecursiveOptionEmbeddedNoTag,
}); err != nil {
return err
} else {
fieldKeys = make([]string, len(structFields))
for i, field := range structFields {
fieldKeys[i] = field.Name()
}
}
// Recursively with feature checks.
model = m.db.With(field.Value).Hook(m.hookHandler)
if m.withAll {
model = model.WithAll()
} else {
model = model.With(m.withArray...)
}
if parsedTagOutput.Where != "" {
model = model.Where(parsedTagOutput.Where)
}
if parsedTagOutput.Order != "" {
model = model.Order(parsedTagOutput.Order)
}
if parsedTagOutput.Unscoped == "true" {
model = model.Unscoped()
}
// With cache feature.
if m.cacheEnabled && m.cacheOption.Name == "" {
model = model.Cache(m.cacheOption)
}
err = model.Fields(fieldKeys).
Where(relatedSourceName, relatedTargetValue).
ScanList(pointer, fieldName, parsedTagOutput.With)
// It ignores sql.ErrNoRows in with feature.
if err != nil && err != sql.ErrNoRows {
return err
}
}
return nil
}
type parseWithTagInFieldStructOutput struct {
With string
Where string
Order string
Unscoped string
}
func (m *Model) parseWithTagInFieldStruct(field gstructs.Field) (output parseWithTagInFieldStructOutput) {
var (
ormTag = field.Tag(OrmTagForStruct)
data = make(map[string]string)
array []string
key string
)
for _, v := range gstr.SplitAndTrim(ormTag, ",") {
array = gstr.Split(v, ":")
if len(array) == 2 {
key = array[0]
data[key] = gstr.Trim(array[1])
} else {
if key == OrmTagForWithOrder {
// supporting multiple order fields
data[key] += "," + gstr.Trim(v)
} else {
data[key] += " " + gstr.Trim(v)
}
}
}
output.With = data[OrmTagForWith]
output.Where = data[OrmTagForWithWhere]
output.Order = data[OrmTagForWithOrder]
output.Unscoped = data[OrmTagForWithUnscoped]
return
}

67
database/gdb_result.go Normal file
View File

@ -0,0 +1,67 @@
// Copyright GoFrame Author(https://goframe.org). All Rights Reserved.
//
// This Source Code Form is subject to the terms of the MIT License.
// If a copy of the MIT was not distributed with this file,
// You can obtain one at https://github.com/gogf/gf.
package database
import (
"database/sql"
"github.com/gogf/gf/v2/errors/gerror"
)
// SqlResult is execution result for sql operations.
// It also supports batch operation result for rowsAffected.
type SqlResult struct {
Result sql.Result
Affected int64
}
// MustGetAffected returns the affected rows count, if any error occurs, it panics.
func (r *SqlResult) MustGetAffected() int64 {
rows, err := r.RowsAffected()
if err != nil {
err = gerror.Wrap(err, `sql.Result.RowsAffected failed`)
panic(err)
}
return rows
}
// MustGetInsertId returns the last insert id, if any error occurs, it panics.
func (r *SqlResult) MustGetInsertId() int64 {
id, err := r.LastInsertId()
if err != nil {
err = gerror.Wrap(err, `sql.Result.LastInsertId failed`)
panic(err)
}
return id
}
// RowsAffected returns the number of rows affected by an
// update, insert, or delete. Not every database or database
// driver may support this.
// Also, See sql.Result.
func (r *SqlResult) RowsAffected() (int64, error) {
if r.Affected > 0 {
return r.Affected, nil
}
if r.Result == nil {
return 0, nil
}
return r.Result.RowsAffected()
}
// LastInsertId returns the integer generated by the database
// in response to a command. Typically, this will be from an
// "auto increment" column when inserting a new row. Not all
// databases support this feature, and the syntax of such
// statements varies.
// Also, See sql.Result.
func (r *SqlResult) LastInsertId() (int64, error) {
if r.Result == nil {
return 0, nil
}
return r.Result.LastInsertId()
}

30
database/gdb_schema.go Normal file
View File

@ -0,0 +1,30 @@
// Copyright GoFrame Author(https://goframe.org). All Rights Reserved.
//
// This Source Code Form is subject to the terms of the MIT License.
// If a copy of the MIT was not distributed with this file,
// You can obtain one at https://github.com/gogf/gf.
package database
// Schema is a schema object from which it can then create a Model.
type Schema struct {
DB
}
// Schema creates and returns a schema.
func (c *Core) Schema(schema string) *Schema {
// Do not change the schema of the original db,
// it here creates a new db and changes its schema.
db, err := NewByGroup(c.GetGroup())
if err != nil {
panic(err)
}
core := db.GetCore()
// Different schema share some same objects.
core.logger = c.logger
core.cache = c.cache
core.schema = schema
return &Schema{
DB: db,
}
}

118
database/gdb_statement.go Normal file
View File

@ -0,0 +1,118 @@
// Copyright GoFrame Author(https://goframe.org). All Rights Reserved.
//
// This Source Code Form is subject to the terms of the MIT License.
// If a copy of the MIT was not distributed with this file,
// You can obtain one at https://github.com/gogf/gf.
package database
import (
"context"
"database/sql"
)
// Stmt is a prepared statement.
// A Stmt is safe for concurrent use by multiple goroutines.
//
// If a Stmt is prepared on a Tx or Conn, it will be bound to a single
// underlying connection forever. If the Tx or Conn closes, the Stmt will
// become unusable and all operations will return an error.
// If a Stmt is prepared on a DB, it will remain usable for the lifetime of the
// DB. When the Stmt needs to execute on a new underlying connection, it will
// prepare itself on the new connection automatically.
type Stmt struct {
*sql.Stmt
core *Core
link Link
sql string
}
// ExecContext executes a prepared statement with the given arguments and
// returns a Result summarizing the effect of the statement.
func (s *Stmt) ExecContext(ctx context.Context, args ...any) (sql.Result, error) {
out, err := s.core.db.DoCommit(ctx, DoCommitInput{
Stmt: s.Stmt,
Link: s.link,
Sql: s.sql,
Args: args,
Type: SqlTypeStmtExecContext,
IsTransaction: s.link.IsTransaction(),
})
return out.Result, err
}
// QueryContext executes a prepared query statement with the given arguments
// and returns the query results as a *Rows.
func (s *Stmt) QueryContext(ctx context.Context, args ...any) (*sql.Rows, error) {
out, err := s.core.db.DoCommit(ctx, DoCommitInput{
Stmt: s.Stmt,
Link: s.link,
Sql: s.sql,
Args: args,
Type: SqlTypeStmtQueryContext,
IsTransaction: s.link.IsTransaction(),
})
if err != nil {
return nil, err
}
if out.RawResult != nil {
return out.RawResult.(*sql.Rows), err
}
return nil, nil
}
// QueryRowContext executes a prepared query statement with the given arguments.
// If an error occurs during the execution of the statement, that error will
// be returned by a call to Scan on the returned *Row, which is always non-nil.
// If the query selects no rows, the *Row's Scan will return ErrNoRows.
// Otherwise, the *Row's Scan scans the first selected row and discards
// the rest.
func (s *Stmt) QueryRowContext(ctx context.Context, args ...any) *sql.Row {
out, err := s.core.db.DoCommit(ctx, DoCommitInput{
Stmt: s.Stmt,
Link: s.link,
Sql: s.sql,
Args: args,
Type: SqlTypeStmtQueryContext,
IsTransaction: s.link.IsTransaction(),
})
if err != nil {
panic(err)
}
if out.RawResult != nil {
return out.RawResult.(*sql.Row)
}
return nil
}
// Exec executes a prepared statement with the given arguments and
// returns a Result summarizing the effect of the statement.
func (s *Stmt) Exec(args ...any) (sql.Result, error) {
return s.ExecContext(context.Background(), args...)
}
// Query executes a prepared query statement with the given arguments
// and returns the query results as a *Rows.
func (s *Stmt) Query(args ...any) (*sql.Rows, error) {
return s.QueryContext(context.Background(), args...)
}
// QueryRow executes a prepared query statement with the given arguments.
// If an error occurs during the execution of the statement, that error will
// be returned by a call to Scan on the returned *Row, which is always non-nil.
// If the query selects no rows, the *Row's Scan will return ErrNoRows.
// Otherwise, the *Row's Scan scans the first selected row and discards
// the rest.
//
// Example usage:
//
// var name string
// err := nameByUseridStmt.QueryRow(id).Scan(&name)
func (s *Stmt) QueryRow(args ...any) *sql.Row {
return s.QueryRowContext(context.Background(), args...)
}
// Close closes the statement.
func (s *Stmt) Close() error {
return s.Stmt.Close()
}

View File

@ -0,0 +1,65 @@
// Copyright GoFrame Author(https://goframe.org). All Rights Reserved.
//
// This Source Code Form is subject to the terms of the MIT License.
// If a copy of the MIT was not distributed with this file,
// You can obtain one at https://github.com/gogf/gf.
package database
import (
"database/sql"
"git.magicany.cc/black1552/gin-base/database/empty"
"github.com/gogf/gf/v2/container/gmap"
"github.com/gogf/gf/v2/encoding/gjson"
"github.com/gogf/gf/v2/util/gconv"
)
// Json converts `r` to JSON format content.
func (r Record) Json() string {
content, _ := gjson.New(r.Map()).ToJsonString()
return content
}
// Xml converts `r` to XML format content.
func (r Record) Xml(rootTag ...string) string {
content, _ := gjson.New(r.Map()).ToXmlString(rootTag...)
return content
}
// Map converts `r` to map[string]any.
func (r Record) Map() Map {
m := make(map[string]any)
for k, v := range r {
m[k] = v.Val()
}
return m
}
// GMap converts `r` to a gmap.
func (r Record) GMap() *gmap.StrAnyMap {
return gmap.NewStrAnyMapFrom(r.Map())
}
// Struct converts `r` to a struct.
// Note that the parameter `pointer` should be type of *struct/**struct.
//
// Note that it returns sql.ErrNoRows if `r` is empty.
func (r Record) Struct(pointer any) error {
// If the record is empty, it returns error.
if r.IsEmpty() {
if !empty.IsNil(pointer, true) {
return sql.ErrNoRows
}
return nil
}
return converter.Struct(r, pointer, gconv.StructOption{
PriorityTag: OrmTagForStruct,
ContinueOnError: true,
})
}
// IsEmpty checks and returns whether `r` is empty.
func (r Record) IsEmpty() bool {
return len(r) == 0
}

214
database/gdb_type_result.go Normal file
View File

@ -0,0 +1,214 @@
// Copyright GoFrame Author(https://goframe.org). All Rights Reserved.
//
// This Source Code Form is subject to the terms of the MIT License.
// If a copy of the MIT was not distributed with this file,
// You can obtain one at https://github.com/gogf/gf.
package database
import (
"database/sql"
"math"
"git.magicany.cc/black1552/gin-base/database/empty"
"github.com/gogf/gf/v2/container/gvar"
"github.com/gogf/gf/v2/encoding/gjson"
"github.com/gogf/gf/v2/util/gconv"
)
// IsEmpty checks and returns whether `r` is empty.
func (r Result) IsEmpty() bool {
return r == nil || r.Len() == 0
}
// Len returns the length of result list.
func (r Result) Len() int {
return len(r)
}
// Size is alias of function Len.
func (r Result) Size() int {
return r.Len()
}
// Chunk splits a Result into multiple Results,
// the size of each array is determined by `size`.
// The last chunk may contain less than size elements.
func (r Result) Chunk(size int) []Result {
if size < 1 {
return nil
}
length := len(r)
chunks := int(math.Ceil(float64(length) / float64(size)))
var n []Result
for i, end := 0, 0; chunks > 0; chunks-- {
end = (i + 1) * size
if end > length {
end = length
}
n = append(n, r[i*size:end])
i++
}
return n
}
// Json converts `r` to JSON format content.
func (r Result) Json() string {
content, _ := gjson.New(r.List()).ToJsonString()
return content
}
// Xml converts `r` to XML format content.
func (r Result) Xml(rootTag ...string) string {
content, _ := gjson.New(r.List()).ToXmlString(rootTag...)
return content
}
// List converts `r` to a List.
func (r Result) List() List {
list := make(List, len(r))
for k, v := range r {
list[k] = v.Map()
}
return list
}
// Array retrieves and returns specified column values as slice.
// The parameter `field` is optional is the column field is only one.
// The default `field` is the first field name of the first item in `Result` if parameter `field` is not given.
func (r Result) Array(field ...string) Array {
array := make(Array, len(r))
if len(r) == 0 {
return array
}
key := ""
if len(field) > 0 && field[0] != "" {
key = field[0]
} else {
for k := range r[0] {
key = k
break
}
}
for k, v := range r {
array[k] = v[key]
}
return array
}
// MapKeyValue converts `r` to a map[string]Value of which key is specified by `key`.
// Note that the item value may be type of slice.
func (r Result) MapKeyValue(key string) map[string]Value {
var (
s string
m = make(map[string]Value)
tempMap = make(map[string][]any)
hasMultiValues bool
)
for _, item := range r {
if k, ok := item[key]; ok {
s = k.String()
tempMap[s] = append(tempMap[s], item)
if len(tempMap[s]) > 1 {
hasMultiValues = true
}
}
}
for k, v := range tempMap {
if hasMultiValues {
m[k] = gvar.New(v)
} else {
m[k] = gvar.New(v[0])
}
}
return m
}
// MapKeyStr converts `r` to a map[string]Map of which key is specified by `key`.
func (r Result) MapKeyStr(key string) map[string]Map {
m := make(map[string]Map)
for _, item := range r {
if v, ok := item[key]; ok {
m[v.String()] = item.Map()
}
}
return m
}
// MapKeyInt converts `r` to a map[int]Map of which key is specified by `key`.
func (r Result) MapKeyInt(key string) map[int]Map {
m := make(map[int]Map)
for _, item := range r {
if v, ok := item[key]; ok {
m[v.Int()] = item.Map()
}
}
return m
}
// MapKeyUint converts `r` to a map[uint]Map of which key is specified by `key`.
func (r Result) MapKeyUint(key string) map[uint]Map {
m := make(map[uint]Map)
for _, item := range r {
if v, ok := item[key]; ok {
m[v.Uint()] = item.Map()
}
}
return m
}
// RecordKeyStr converts `r` to a map[string]Record of which key is specified by `key`.
func (r Result) RecordKeyStr(key string) map[string]Record {
m := make(map[string]Record)
for _, item := range r {
if v, ok := item[key]; ok {
m[v.String()] = item
}
}
return m
}
// RecordKeyInt converts `r` to a map[int]Record of which key is specified by `key`.
func (r Result) RecordKeyInt(key string) map[int]Record {
m := make(map[int]Record)
for _, item := range r {
if v, ok := item[key]; ok {
m[v.Int()] = item
}
}
return m
}
// RecordKeyUint converts `r` to a map[uint]Record of which key is specified by `key`.
func (r Result) RecordKeyUint(key string) map[uint]Record {
m := make(map[uint]Record)
for _, item := range r {
if v, ok := item[key]; ok {
m[v.Uint()] = item
}
}
return m
}
// Structs converts `r` to struct slice.
// Note that the parameter `pointer` should be type of *[]struct/*[]*struct.
func (r Result) Structs(pointer any) (err error) {
// If the result is empty and the target pointer is not empty, it returns error.
if r.IsEmpty() {
if !empty.IsEmpty(pointer, true) {
return sql.ErrNoRows
}
return nil
}
var (
sliceOption = gconv.SliceOption{ContinueOnError: true}
structOption = gconv.StructOption{
PriorityTag: OrmTagForStruct,
ContinueOnError: true,
}
)
return converter.Structs(r, pointer, gconv.StructsOption{
SliceOption: sliceOption,
StructOption: structOption,
})
}

View File

@ -0,0 +1,512 @@
// Copyright GoFrame Author(https://goframe.org). All Rights Reserved.
//
// This Source Code Form is subject to the terms of the MIT License.
// If a copy of the MIT was not distributed with this file,
// You can obtain one at https://github.com/gogf/gf.
package database
import (
"database/sql"
"reflect"
"github.com/gogf/gf/v2/errors/gcode"
"github.com/gogf/gf/v2/errors/gerror"
"github.com/gogf/gf/v2/os/gstructs"
"github.com/gogf/gf/v2/text/gstr"
"github.com/gogf/gf/v2/util/gconv"
"github.com/gogf/gf/v2/util/gutil"
)
// ScanList converts `r` to struct slice which contains other complex struct attributes.
// Note that the parameter `structSlicePointer` should be type of *[]struct/*[]*struct.
//
// Usage example 1: Normal attribute struct relation:
//
// type EntityUser struct {
// Uid int
// Name string
// }
//
// type EntityUserDetail struct {
// Uid int
// Address string
// }
//
// type EntityUserScores struct {
// Id int
// Uid int
// Score int
// Course string
// }
//
// type Entity struct {
// User *EntityUser
// UserDetail *EntityUserDetail
// UserScores []*EntityUserScores
// }
//
// var users []*Entity
// ScanList(&users, "User")
// ScanList(&users, "User", "uid")
// ScanList(&users, "UserDetail", "User", "uid:Uid")
// ScanList(&users, "UserScores", "User", "uid:Uid")
// ScanList(&users, "UserScores", "User", "uid")
//
// Usage example 2: Embedded attribute struct relation:
//
// type EntityUser struct {
// Uid int
// Name string
// }
//
// type EntityUserDetail struct {
// Uid int
// Address string
// }
//
// type EntityUserScores struct {
// Id int
// Uid int
// Score int
// }
//
// type Entity struct {
// EntityUser
// UserDetail EntityUserDetail
// UserScores []EntityUserScores
// }
//
// var users []*Entity
// ScanList(&users)
// ScanList(&users, "UserDetail", "uid")
// ScanList(&users, "UserScores", "uid")
//
// The parameters "User/UserDetail/UserScores" in the example codes specify the target attribute struct
// that current result will be bound to.
//
// The "uid" in the example codes is the table field name of the result, and the "Uid" is the relational
// struct attribute name - not the attribute name of the bound to target. In the example codes, it's attribute
// name "Uid" of "User" of entity "Entity". It automatically calculates the HasOne/HasMany relationship with
// given `relation` parameter.
//
// See the example or unit testing cases for clear understanding for this function.
func (r Result) ScanList(structSlicePointer any, bindToAttrName string, relationAttrNameAndFields ...string) (err error) {
out, err := checkGetSliceElementInfoForScanList(structSlicePointer, bindToAttrName)
if err != nil {
return err
}
var (
relationAttrName string
relationFields string
)
switch len(relationAttrNameAndFields) {
case 2:
relationAttrName = relationAttrNameAndFields[0]
relationFields = relationAttrNameAndFields[1]
case 1:
relationFields = relationAttrNameAndFields[0]
}
return doScanList(doScanListInput{
Model: nil,
Result: r,
StructSlicePointer: structSlicePointer,
StructSliceValue: out.SliceReflectValue,
BindToAttrName: bindToAttrName,
RelationAttrName: relationAttrName,
RelationFields: relationFields,
})
}
type checkGetSliceElementInfoForScanListOutput struct {
SliceReflectValue reflect.Value
BindToAttrType reflect.Type
}
func checkGetSliceElementInfoForScanList(structSlicePointer any, bindToAttrName string) (out *checkGetSliceElementInfoForScanListOutput, err error) {
// Necessary checks for parameters.
if structSlicePointer == nil {
return nil, gerror.NewCode(gcode.CodeInvalidParameter, `structSlicePointer cannot be nil`)
}
if bindToAttrName == "" {
return nil, gerror.NewCode(gcode.CodeInvalidParameter, `bindToAttrName should not be empty`)
}
var (
reflectType reflect.Type
reflectValue = reflect.ValueOf(structSlicePointer)
reflectKind = reflectValue.Kind()
)
if reflectKind == reflect.Interface {
reflectValue = reflectValue.Elem()
reflectKind = reflectValue.Kind()
}
if reflectKind != reflect.Pointer {
return nil, gerror.NewCodef(
gcode.CodeInvalidParameter,
"structSlicePointer should be type of *[]struct/*[]*struct, but got: %s",
reflect.TypeOf(structSlicePointer).String(),
)
}
out = &checkGetSliceElementInfoForScanListOutput{
SliceReflectValue: reflectValue.Elem(),
}
// Find the element struct type of the slice.
reflectType = reflectValue.Type().Elem().Elem()
reflectKind = reflectType.Kind()
for reflectKind == reflect.Pointer {
reflectType = reflectType.Elem()
reflectKind = reflectType.Kind()
}
if reflectKind != reflect.Struct {
err = gerror.NewCodef(
gcode.CodeInvalidParameter,
"structSlicePointer should be type of *[]struct/*[]*struct, but got: %s",
reflect.TypeOf(structSlicePointer).String(),
)
return
}
// Find the target field by given name.
structField, ok := reflectType.FieldByName(bindToAttrName)
if !ok {
return nil, gerror.NewCodef(
gcode.CodeInvalidParameter,
`field "%s" not found in element of "%s"`,
bindToAttrName,
reflect.TypeOf(structSlicePointer).String(),
)
}
// Find the attribute struct type for ORM fields filtering.
reflectType = structField.Type
reflectKind = reflectType.Kind()
for reflectKind == reflect.Pointer {
reflectType = reflectType.Elem()
reflectKind = reflectType.Kind()
}
if reflectKind == reflect.Slice || reflectKind == reflect.Array {
reflectType = reflectType.Elem()
// reflectKind = reflectType.Kind()
}
out.BindToAttrType = reflectType
return
}
type doScanListInput struct {
Model *Model
Result Result
StructSlicePointer any
StructSliceValue reflect.Value
BindToAttrName string
RelationAttrName string
RelationFields string
}
// doScanList converts `result` to struct slice which contains other complex struct attributes recursively.
// The parameter `model` is used for recursively scanning purpose, which means, it can scan the attribute struct/structs recursively,
// but it needs the Model for database accessing.
// Note that the parameter `structSlicePointer` should be type of *[]struct/*[]*struct.
func doScanList(in doScanListInput) (err error) {
if in.Result.IsEmpty() {
return nil
}
if in.BindToAttrName == "" {
return gerror.NewCode(gcode.CodeInvalidParameter, `bindToAttrName should not be empty`)
}
length := len(in.Result)
if length == 0 {
// The pointed slice is not empty.
if in.StructSliceValue.Len() > 0 {
// It here checks if it has struct item, which is already initialized.
// It then returns error to warn the developer its empty and no conversion.
if v := in.StructSliceValue.Index(0); v.Kind() != reflect.Pointer {
return sql.ErrNoRows
}
}
// Do nothing for empty struct slice.
return nil
}
var (
arrayValue reflect.Value // Like: []*Entity
arrayItemType reflect.Type // Like: *Entity
reflectType = reflect.TypeOf(in.StructSlicePointer)
)
if in.StructSliceValue.Len() > 0 {
arrayValue = in.StructSliceValue
} else {
arrayValue = reflect.MakeSlice(reflectType.Elem(), length, length)
}
// Slice element item.
arrayItemType = arrayValue.Index(0).Type()
// Relation variables.
var (
relationDataMap map[string]Value
relationFromFieldName string // Eg: relationKV: id:uid -> id
relationBindToFieldName string // Eg: relationKV: id:uid -> uid
)
if len(in.RelationFields) > 0 {
// The relation key string of table field name and attribute name
// can be joined with char '=' or ':'.
array := gstr.SplitAndTrim(in.RelationFields, "=")
if len(array) == 1 {
// Compatible with old splitting char ':'.
array = gstr.SplitAndTrim(in.RelationFields, ":")
}
if len(array) == 1 {
// The relation names are the same.
array = []string{in.RelationFields, in.RelationFields}
}
if len(array) == 2 {
// Defined table field to relation attribute name.
// Like:
// uid:Uid
// uid:UserId
relationFromFieldName = array[0]
relationBindToFieldName = array[1]
if key, _ := gutil.MapPossibleItemByKey(in.Result[0].Map(), relationFromFieldName); key == "" {
return gerror.NewCodef(
gcode.CodeInvalidParameter,
`cannot find possible related table field name "%s" from given relation fields "%s"`,
relationFromFieldName,
in.RelationFields,
)
} else {
relationFromFieldName = key
}
} else {
return gerror.NewCode(
gcode.CodeInvalidParameter,
`parameter relationKV should be format of "ResultFieldName:BindToAttrName"`,
)
}
if relationFromFieldName != "" {
// Note that the value might be type of slice.
relationDataMap = in.Result.MapKeyValue(relationFromFieldName)
}
if len(relationDataMap) == 0 {
return gerror.NewCodef(
gcode.CodeInvalidParameter,
`cannot find the relation data map, maybe invalid relation fields given "%v"`,
in.RelationFields,
)
}
}
// Bind to target attribute.
var (
ok bool
bindToAttrValue reflect.Value
bindToAttrKind reflect.Kind
bindToAttrType reflect.Type
bindToAttrField reflect.StructField
)
if arrayItemType.Kind() == reflect.Pointer {
if bindToAttrField, ok = arrayItemType.Elem().FieldByName(in.BindToAttrName); !ok {
return gerror.NewCodef(
gcode.CodeInvalidParameter,
`invalid parameter bindToAttrName: cannot find attribute with name "%s" from slice element`,
in.BindToAttrName,
)
}
} else {
if bindToAttrField, ok = arrayItemType.FieldByName(in.BindToAttrName); !ok {
return gerror.NewCodef(
gcode.CodeInvalidParameter,
`invalid parameter bindToAttrName: cannot find attribute with name "%s" from slice element`,
in.BindToAttrName,
)
}
}
bindToAttrType = bindToAttrField.Type
bindToAttrKind = bindToAttrType.Kind()
// Bind to relation conditions.
var (
relationFromAttrValue reflect.Value
relationFromAttrField reflect.Value
relationBindToFieldNameChecked bool
)
for i := 0; i < arrayValue.Len(); i++ {
arrayElemValue := arrayValue.Index(i)
// The FieldByName should be called on non-pointer reflect.Value.
if arrayElemValue.Kind() == reflect.Pointer {
// Like: []*Entity
arrayElemValue = arrayElemValue.Elem()
if !arrayElemValue.IsValid() {
// The element is nil, then create one and set it to the slice.
// The "reflect.New(itemType.Elem())" creates a new element and returns the address of it.
// For example:
// reflect.New(itemType.Elem()) => *Entity
// reflect.New(itemType.Elem()).Elem() => Entity
arrayElemValue = reflect.New(arrayItemType.Elem()).Elem()
arrayValue.Index(i).Set(arrayElemValue.Addr())
}
// } else {
// Like: []Entity
}
bindToAttrValue = arrayElemValue.FieldByName(in.BindToAttrName)
if in.RelationAttrName != "" {
// Attribute value of current slice element.
relationFromAttrValue = arrayElemValue.FieldByName(in.RelationAttrName)
if relationFromAttrValue.Kind() == reflect.Pointer {
relationFromAttrValue = relationFromAttrValue.Elem()
}
} else {
// Current slice element.
relationFromAttrValue = arrayElemValue
}
if len(relationDataMap) > 0 && !relationFromAttrValue.IsValid() {
return gerror.NewCodef(gcode.CodeInvalidParameter, `invalid relation fields specified: "%v"`, in.RelationFields)
}
// Check and find possible bind to attribute name.
if in.RelationFields != "" && !relationBindToFieldNameChecked {
relationFromAttrField = relationFromAttrValue.FieldByName(relationBindToFieldName)
if !relationFromAttrField.IsValid() {
fieldMap, _ := gstructs.FieldMap(gstructs.FieldMapInput{
Pointer: relationFromAttrValue,
RecursiveOption: gstructs.RecursiveOptionEmbeddedNoTag,
})
if key, _ := gutil.MapPossibleItemByKey(gconv.Map(fieldMap), relationBindToFieldName); key == "" {
return gerror.NewCodef(
gcode.CodeInvalidParameter,
`cannot find possible related attribute name "%s" from given relation fields "%s"`,
relationBindToFieldName,
in.RelationFields,
)
} else {
relationBindToFieldName = key
}
}
relationBindToFieldNameChecked = true
}
switch bindToAttrKind {
case reflect.Array, reflect.Slice:
if len(relationDataMap) > 0 {
relationFromAttrField = relationFromAttrValue.FieldByName(relationBindToFieldName)
if relationFromAttrField.IsValid() {
results := make(Result, 0)
for _, v := range relationDataMap[gconv.String(relationFromAttrField.Interface())].Slice() {
results = append(results, v.(Record))
}
if err = results.Structs(bindToAttrValue.Addr()); err != nil {
return err
}
// Recursively Scan.
if in.Model != nil {
if err = in.Model.doWithScanStructs(bindToAttrValue.Addr()); err != nil {
return nil
}
}
} else {
// Maybe the attribute does not exist yet.
return gerror.NewCodef(gcode.CodeInvalidParameter, `invalid relation fields specified: "%v"`, in.RelationFields)
}
} else {
return gerror.NewCodef(
gcode.CodeInvalidParameter,
`relationKey should not be empty as field "%s" is slice`,
in.BindToAttrName,
)
}
case reflect.Pointer:
var element reflect.Value
if bindToAttrValue.IsNil() {
element = reflect.New(bindToAttrType.Elem()).Elem()
} else {
element = bindToAttrValue.Elem()
}
if len(relationDataMap) > 0 {
relationFromAttrField = relationFromAttrValue.FieldByName(relationBindToFieldName)
if relationFromAttrField.IsValid() {
v := relationDataMap[gconv.String(relationFromAttrField.Interface())]
if v == nil {
// There's no relational data.
continue
}
if v.IsSlice() {
if err = v.Slice()[0].(Record).Struct(element); err != nil {
return err
}
} else {
if err = v.Val().(Record).Struct(element); err != nil {
return err
}
}
} else {
// Maybe the attribute does not exist yet.
return gerror.NewCodef(gcode.CodeInvalidParameter, `invalid relation fields specified: "%v"`, in.RelationFields)
}
} else {
if i >= len(in.Result) {
// There's no relational data.
continue
}
v := in.Result[i]
if v == nil {
// There's no relational data.
continue
}
if err = v.Struct(element); err != nil {
return err
}
}
// Recursively Scan.
if in.Model != nil {
if err = in.Model.doWithScanStruct(element); err != nil {
return err
}
}
bindToAttrValue.Set(element.Addr())
case reflect.Struct:
if len(relationDataMap) > 0 {
relationFromAttrField = relationFromAttrValue.FieldByName(relationBindToFieldName)
if relationFromAttrField.IsValid() {
relationDataItem := relationDataMap[gconv.String(relationFromAttrField.Interface())]
if relationDataItem == nil {
// There's no relational data.
continue
}
if relationDataItem.IsSlice() {
if err = relationDataItem.Slice()[0].(Record).Struct(bindToAttrValue); err != nil {
return err
}
} else {
if err = relationDataItem.Val().(Record).Struct(bindToAttrValue); err != nil {
return err
}
}
} else {
// Maybe the attribute does not exist yet.
return gerror.NewCodef(gcode.CodeInvalidParameter, `invalid relation fields specified: "%v"`, in.RelationFields)
}
} else {
if i >= len(in.Result) {
// There's no relational data.
continue
}
relationDataItem := in.Result[i]
if relationDataItem == nil {
// There's no relational data.
continue
}
if err = relationDataItem.Struct(bindToAttrValue); err != nil {
return err
}
}
// Recursively Scan.
if in.Model != nil {
if err = in.Model.doWithScanStruct(bindToAttrValue); err != nil {
return err
}
}
default:
return gerror.NewCodef(gcode.CodeInvalidParameter, `unsupported attribute type: %s`, bindToAttrKind.String())
}
}
reflect.ValueOf(in.StructSlicePointer).Elem().Set(arrayValue)
return nil
}

View File

@ -1 +0,0 @@
package database

View File

@ -0,0 +1,79 @@
// Copyright GoFrame Author(https://goframe.org). All Rights Reserved.
//
// This Source Code Form is subject to the terms of the MIT License.
// If a copy of the MIT was not distributed with this file,
// You can obtain one at https://github.com/gogf/gf.
// Package instance provides instances management.
//
// Note that this package is not used for cache, as it has no cache expiration.
package instance
import (
"github.com/gogf/gf/v2/container/gmap"
"github.com/gogf/gf/v2/encoding/ghash"
)
const (
groupNumber = 64
)
var (
groups = make([]*gmap.StrAnyMap, groupNumber)
)
func init() {
for i := 0; i < groupNumber; i++ {
groups[i] = gmap.NewStrAnyMap(true)
}
}
func getGroup(key string) *gmap.StrAnyMap {
return groups[int(ghash.DJB([]byte(key))%groupNumber)]
}
// Get returns the instance by given name.
func Get(name string) any {
return getGroup(name).Get(name)
}
// Set sets an instance to the instance manager with given name.
func Set(name string, instance any) {
getGroup(name).Set(name, instance)
}
// GetOrSet returns the instance by name,
// or set instance to the instance manager if it does not exist and returns this instance.
func GetOrSet(name string, instance any) any {
return getGroup(name).GetOrSet(name, instance)
}
// GetOrSetFunc returns the instance by name,
// or sets instance with returned value of callback function `f` if it does not exist
// and then returns this instance.
func GetOrSetFunc(name string, f func() any) any {
return getGroup(name).GetOrSetFunc(name, f)
}
// GetOrSetFuncLock returns the instance by name,
// or sets instance with returned value of callback function `f` if it does not exist
// and then returns this instance.
//
// GetOrSetFuncLock differs with GetOrSetFunc function is that it executes function `f`
// with mutex.Lock of the hash map.
func GetOrSetFuncLock(name string, f func() any) any {
return getGroup(name).GetOrSetFuncLock(name, f)
}
// SetIfNotExist sets `instance` to the map if the `name` does not exist, then returns true.
// It returns false if `name` exists, and `instance` would be ignored.
func SetIfNotExist(name string, instance any) bool {
return getGroup(name).SetIfNotExist(name, instance)
}
// Clear deletes all instances stored.
func Clear() {
for i := 0; i < groupNumber; i++ {
groups[i].Clear()
}
}

View File

@ -0,0 +1,44 @@
// Copyright GoFrame Author(https://goframe.org). All Rights Reserved.
//
// This Source Code Form is subject to the terms of the MIT License.
// If a copy of the MIT was not distributed with this file,
// You can obtain one at https://github.com/gogf/gf.
package instance_test
import (
"testing"
"github.com/gogf/gf/v2/internal/instance"
"github.com/gogf/gf/v2/test/gtest"
)
func Test_SetGet(t *testing.T) {
gtest.C(t, func(t *gtest.T) {
instance.Set("test-user", 1)
t.Assert(instance.Get("test-user"), 1)
t.Assert(instance.Get("none-exists"), nil)
})
gtest.C(t, func(t *gtest.T) {
t.Assert(instance.GetOrSet("test-1", 1), 1)
t.Assert(instance.Get("test-1"), 1)
})
gtest.C(t, func(t *gtest.T) {
t.Assert(instance.GetOrSetFunc("test-2", func() any {
return 2
}), 2)
t.Assert(instance.Get("test-2"), 2)
})
gtest.C(t, func(t *gtest.T) {
t.Assert(instance.GetOrSetFuncLock("test-3", func() any {
return 3
}), 3)
t.Assert(instance.Get("test-3"), 3)
})
gtest.C(t, func(t *gtest.T) {
t.Assert(instance.SetIfNotExist("test-4", 4), true)
t.Assert(instance.Get("test-4"), 4)
t.Assert(instance.SetIfNotExist("test-4", 5), false)
t.Assert(instance.Get("test-4"), 4)
})
}

125
database/intlog/intlog.go Normal file
View File

@ -0,0 +1,125 @@
// Copyright GoFrame Author(https://goframe.org). All Rights Reserved.
//
// This Source Code Form is subject to the terms of the MIT License.
// If a copy of the MIT was not distributed with this file,
// You can obtain one at https://github.com/gogf/gf.
// Package intlog provides internal logging for GoFrame development usage only.
package intlog
import (
"bytes"
"context"
"fmt"
"path/filepath"
"time"
"git.magicany.cc/black1552/gin-base/database/utils"
"go.opentelemetry.io/otel/trace"
"github.com/gogf/gf/v2/debug/gdebug"
)
const (
stackFilterKey = "/internal/intlog"
)
// Print prints `v` with newline using fmt.Println.
// The parameter `v` can be multiple variables.
func Print(ctx context.Context, v ...any) {
if !utils.IsDebugEnabled() {
return
}
doPrint(ctx, fmt.Sprint(v...), false)
}
// Printf prints `v` with format `format` using fmt.Printf.
// The parameter `v` can be multiple variables.
func Printf(ctx context.Context, format string, v ...any) {
if !utils.IsDebugEnabled() {
return
}
doPrint(ctx, fmt.Sprintf(format, v...), false)
}
// Error prints `v` with newline using fmt.Println.
// The parameter `v` can be multiple variables.
func Error(ctx context.Context, v ...any) {
if !utils.IsDebugEnabled() {
return
}
doPrint(ctx, fmt.Sprint(v...), true)
}
// Errorf prints `v` with format `format` using fmt.Printf.
func Errorf(ctx context.Context, format string, v ...any) {
if !utils.IsDebugEnabled() {
return
}
doPrint(ctx, fmt.Sprintf(format, v...), true)
}
// PrintFunc prints the output from function `f`.
// It only calls function `f` if debug mode is enabled.
func PrintFunc(ctx context.Context, f func() string) {
if !utils.IsDebugEnabled() {
return
}
s := f()
if s == "" {
return
}
doPrint(ctx, s, false)
}
// ErrorFunc prints the output from function `f`.
// It only calls function `f` if debug mode is enabled.
func ErrorFunc(ctx context.Context, f func() string) {
if !utils.IsDebugEnabled() {
return
}
s := f()
if s == "" {
return
}
doPrint(ctx, s, true)
}
func doPrint(ctx context.Context, content string, stack bool) {
if !utils.IsDebugEnabled() {
return
}
buffer := bytes.NewBuffer(nil)
buffer.WriteString(time.Now().Format("2006-01-02 15:04:05.000"))
buffer.WriteString(" [INTE] ")
buffer.WriteString(file())
buffer.WriteString(" ")
if s := traceIDStr(ctx); s != "" {
buffer.WriteString(s + " ")
}
buffer.WriteString(content)
buffer.WriteString("\n")
if stack {
buffer.WriteString("Caller Stack:\n")
buffer.WriteString(gdebug.StackWithFilter([]string{stackFilterKey}))
}
fmt.Print(buffer.String())
}
// traceIDStr retrieves and returns the trace id string for logging output.
func traceIDStr(ctx context.Context) string {
if ctx == nil {
return ""
}
spanCtx := trace.SpanContextFromContext(ctx)
if traceID := spanCtx.TraceID(); traceID.IsValid() {
return "{" + traceID.String() + "}"
}
return ""
}
// file returns caller file name along with its line number.
func file() string {
_, p, l := gdebug.CallerWithFilter([]string{stackFilterKey})
return fmt.Sprintf(`%s:%d`, filepath.Base(p), l)
}

85
database/json/json.go Normal file
View File

@ -0,0 +1,85 @@
// Copyright GoFrame Author(https://goframe.org). All Rights Reserved.
//
// This Source Code Form is subject to the terms of the MIT License.
// If a copy of the MIT was not distributed with this file,
// You can obtain one at https://github.com/gogf/gf.
// Package json provides json operations wrapping ignoring stdlib or third-party lib json.
package json
import (
"bytes"
"encoding/json"
"io"
"github.com/gogf/gf/v2/errors/gerror"
)
// RawMessage is a raw encoded JSON value.
// It implements Marshaler and Unmarshaler and can
// be used to delay JSON decoding or precompute a JSON encoding.
type RawMessage = json.RawMessage
// Marshal adapts to json/encoding Marshal API.
//
// Marshal returns the JSON encoding of v, adapts to json/encoding Marshal API
// Refer to https://godoc.org/encoding/json#Marshal for more information.
func Marshal(v any) (marshaledBytes []byte, err error) {
marshaledBytes, err = json.Marshal(v)
if err != nil {
err = gerror.Wrap(err, `json.Marshal failed`)
}
return
}
// MarshalIndent same as json.MarshalIndent.
func MarshalIndent(v any, prefix, indent string) (marshaledBytes []byte, err error) {
marshaledBytes, err = json.MarshalIndent(v, prefix, indent)
if err != nil {
err = gerror.Wrap(err, `json.MarshalIndent failed`)
}
return
}
// Unmarshal adapts to json/encoding Unmarshal API
//
// Unmarshal parses the JSON-encoded data and stores the result in the value pointed to by v.
// Refer to https://godoc.org/encoding/json#Unmarshal for more information.
func Unmarshal(data []byte, v any) (err error) {
err = json.Unmarshal(data, v)
if err != nil {
err = gerror.Wrap(err, `json.Unmarshal failed`)
}
return
}
// UnmarshalUseNumber decodes the json data bytes to target interface using number option.
func UnmarshalUseNumber(data []byte, v any) (err error) {
decoder := NewDecoder(bytes.NewReader(data))
decoder.UseNumber()
err = decoder.Decode(v)
if err != nil {
err = gerror.Wrap(err, `json.UnmarshalUseNumber failed`)
}
return
}
// NewEncoder same as json.NewEncoder
func NewEncoder(writer io.Writer) *json.Encoder {
return json.NewEncoder(writer)
}
// NewDecoder adapts to json/stream NewDecoder API.
//
// NewDecoder returns a new decoder that reads from r.
//
// Instead of a json/encoding Decoder, a Decoder is returned
// Refer to https://godoc.org/encoding/json#NewDecoder for more information.
func NewDecoder(reader io.Reader) *json.Decoder {
return json.NewDecoder(reader)
}
// Valid reports whether data is a valid JSON encoding.
func Valid(data []byte) bool {
return json.Valid(data)
}

View File

@ -1,43 +0,0 @@
package database
import (
"git.magicany.cc/black1552/gin-base/log"
"github.com/gogf/gf/v2/frame/g"
)
func SetAutoMigrate(models ...interface{}) {
if g.IsNil(Db) {
log.Error("数据库连接失败")
return
}
err := Db.AutoMigrate(models...)
if err != nil {
log.Error("数据库迁移失败", err)
}
}
func RenameColumn(dst interface{}, name, newName string) {
if Db.Migrator().HasColumn(dst, name) {
err := Db.Migrator().RenameColumn(dst, name, newName)
if err != nil {
log.Error("数据库修改字段失败", err)
return
}
} else {
log.Info("数据库字段不存在", name)
}
}
// DropColumn
// 删除字段
// 例DropColumn(&User{}, "Sex")
func DropColumn(dst interface{}, name string) {
if Db.Migrator().HasColumn(dst, name) {
err := Db.Migrator().DropColumn(dst, name)
if err != nil {
log.Error("数据库删除字段失败", err)
return
}
} else {
log.Info("数据库字段不存在", name)
}
}

View File

@ -0,0 +1,94 @@
// Copyright GoFrame Author(https://goframe.org). All Rights Reserved.
//
// This Source Code Form is subject to the terms of the MIT License.
// If a copy of the MIT was not distributed with this file,
// You can obtain one at https://github.com/gogf/gf.
// Package reflection provides some reflection functions for internal usage.
package reflection
import (
"reflect"
)
type OriginValueAndKindOutput struct {
InputValue reflect.Value
InputKind reflect.Kind
OriginValue reflect.Value
OriginKind reflect.Kind
}
// OriginValueAndKind retrieves and returns the original reflect value and kind.
func OriginValueAndKind(value any) (out OriginValueAndKindOutput) {
if v, ok := value.(reflect.Value); ok {
out.InputValue = v
} else {
out.InputValue = reflect.ValueOf(value)
}
out.InputKind = out.InputValue.Kind()
out.OriginValue = out.InputValue
out.OriginKind = out.InputKind
for out.OriginKind == reflect.Pointer {
out.OriginValue = out.OriginValue.Elem()
out.OriginKind = out.OriginValue.Kind()
}
return
}
type OriginTypeAndKindOutput struct {
InputType reflect.Type
InputKind reflect.Kind
OriginType reflect.Type
OriginKind reflect.Kind
}
// OriginTypeAndKind retrieves and returns the original reflect type and kind.
func OriginTypeAndKind(value any) (out OriginTypeAndKindOutput) {
if value == nil {
return
}
if reflectType, ok := value.(reflect.Type); ok {
out.InputType = reflectType
} else {
if reflectValue, ok := value.(reflect.Value); ok {
out.InputType = reflectValue.Type()
} else {
out.InputType = reflect.TypeOf(value)
}
}
out.InputKind = out.InputType.Kind()
out.OriginType = out.InputType
out.OriginKind = out.InputKind
for out.OriginKind == reflect.Pointer {
out.OriginType = out.OriginType.Elem()
out.OriginKind = out.OriginType.Kind()
}
return
}
// ValueToInterface converts reflect value to its interface type.
func ValueToInterface(v reflect.Value) (value any, ok bool) {
if v.IsValid() && v.CanInterface() {
return v.Interface(), true
}
switch v.Kind() {
case reflect.Bool:
return v.Bool(), true
case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64:
return v.Int(), true
case reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64, reflect.Uintptr:
return v.Uint(), true
case reflect.Float32, reflect.Float64:
return v.Float(), true
case reflect.Complex64, reflect.Complex128:
return v.Complex(), true
case reflect.String:
return v.String(), true
case reflect.Pointer:
return ValueToInterface(v.Elem())
case reflect.Interface:
return ValueToInterface(v.Elem())
default:
return nil, false
}
}

8
database/utils/utils.go Normal file
View File

@ -0,0 +1,8 @@
// Copyright GoFrame Author(https://goframe.org). All Rights Reserved.
//
// This Source Code Form is subject to the terms of the MIT License.
// If a copy of the MIT was not distributed with this file,
// You can obtain one at https://github.com/gogf/gf.
// Package utils provides some utility functions for internal usage.
package utils

View File

@ -0,0 +1,26 @@
// Copyright GoFrame Author(https://goframe.org). All Rights Reserved.
//
// This Source Code Form is subject to the terms of the MIT License.
// If a copy of the MIT was not distributed with this file,
// You can obtain one at https://github.com/gogf/gf.
package utils
import "reflect"
// IsArray checks whether given value is array/slice.
// Note that it uses reflect internally implementing this feature.
func IsArray(value any) bool {
rv := reflect.ValueOf(value)
kind := rv.Kind()
if kind == reflect.Pointer {
rv = rv.Elem()
kind = rv.Kind()
}
switch kind {
case reflect.Array, reflect.Slice:
return true
default:
return false
}
}

View File

@ -0,0 +1,38 @@
// Copyright GoFrame Author(https://goframe.org). All Rights Reserved.
//
// This Source Code Form is subject to the terms of the MIT License.
// If a copy of the MIT was not distributed with this file,
// You can obtain one at https://github.com/gogf/gf.
package utils
import "git.magicany.cc/black1552/gin-base/database/command"
const (
// Debug key for checking if in debug mode.
commandEnvKeyForDebugKey = "gf.debug"
)
// isDebugEnabled marks whether GoFrame debug mode is enabled.
var isDebugEnabled = false
func init() {
// Debugging configured.
value := command.GetOptWithEnv(commandEnvKeyForDebugKey)
if value == "" || value == "0" || value == "false" {
isDebugEnabled = false
} else {
isDebugEnabled = true
}
}
// IsDebugEnabled checks and returns whether debug mode is enabled.
// The debug mode is enabled when command argument "gf.debug" or environment "GF_DEBUG" is passed.
func IsDebugEnabled() bool {
return isDebugEnabled
}
// SetDebugEnabled enables/disables the internal debug info.
func SetDebugEnabled(enabled bool) {
isDebugEnabled = enabled
}

View File

@ -0,0 +1,48 @@
// Copyright GoFrame Author(https://goframe.org). All Rights Reserved.
//
// This Source Code Form is subject to the terms of the MIT License.
// If a copy of the MIT was not distributed with this file,
// You can obtain one at https://github.com/gogf/gf.
package utils
import (
"io"
)
// ReadCloser implements the io.ReadCloser interface
// which is used for reading request body content multiple times.
//
// Note that it cannot be closed.
type ReadCloser struct {
index int // Current read position.
content []byte // Content.
repeatable bool // Mark the content can be repeatable read.
}
// NewReadCloser creates and returns a RepeatReadCloser object.
func NewReadCloser(content []byte, repeatable bool) io.ReadCloser {
return &ReadCloser{
content: content,
repeatable: repeatable,
}
}
// Read implements the io.ReadCloser interface.
func (b *ReadCloser) Read(p []byte) (n int, err error) {
// Make it repeatable reading.
if b.index >= len(b.content) && b.repeatable {
b.index = 0
}
n = copy(p, b.content[b.index:])
b.index += n
if b.index >= len(b.content) {
return n, io.EOF
}
return n, nil
}
// Close implements the io.ReadCloser interface.
func (b *ReadCloser) Close() error {
return nil
}

102
database/utils/utils_is.go Normal file
View File

@ -0,0 +1,102 @@
// Copyright GoFrame Author(https://goframe.org). All Rights Reserved.
//
// This Source Code Form is subject to the terms of the MIT License.
// If a copy of the MIT was not distributed with this file,
// You can obtain one at https://github.com/gogf/gf.
package utils
import (
"reflect"
"git.magicany.cc/black1552/gin-base/database/empty"
)
// IsNil checks whether `value` is nil, especially for any type value.
func IsNil(value any) bool {
return empty.IsNil(value)
}
// IsEmpty checks whether `value` is empty.
func IsEmpty(value any) bool {
return empty.IsEmpty(value)
}
// IsInt checks whether `value` is type of int.
func IsInt(value any) bool {
switch value.(type) {
case int, *int, int8, *int8, int16, *int16, int32, *int32, int64, *int64:
return true
}
return false
}
// IsUint checks whether `value` is type of uint.
func IsUint(value any) bool {
switch value.(type) {
case uint, *uint, uint8, *uint8, uint16, *uint16, uint32, *uint32, uint64, *uint64:
return true
}
return false
}
// IsFloat checks whether `value` is type of float.
func IsFloat(value any) bool {
switch value.(type) {
case float32, *float32, float64, *float64:
return true
}
return false
}
// IsSlice checks whether `value` is type of slice.
func IsSlice(value any) bool {
var (
reflectValue = reflect.ValueOf(value)
reflectKind = reflectValue.Kind()
)
for reflectKind == reflect.Pointer {
reflectValue = reflectValue.Elem()
reflectKind = reflectValue.Kind()
}
switch reflectKind {
case reflect.Slice, reflect.Array:
return true
}
return false
}
// IsMap checks whether `value` is type of map.
func IsMap(value any) bool {
var (
reflectValue = reflect.ValueOf(value)
reflectKind = reflectValue.Kind()
)
for reflectKind == reflect.Pointer {
reflectValue = reflectValue.Elem()
reflectKind = reflectValue.Kind()
}
switch reflectKind {
case reflect.Map:
return true
}
return false
}
// IsStruct checks whether `value` is type of struct.
func IsStruct(value any) bool {
reflectType := reflect.TypeOf(value)
if reflectType == nil {
return false
}
reflectKind := reflectType.Kind()
for reflectKind == reflect.Pointer {
reflectType = reflectType.Elem()
reflectKind = reflectType.Kind()
}
switch reflectKind {
case reflect.Struct:
return true
}
return false
}

View File

@ -0,0 +1,37 @@
// Copyright GoFrame Author(https://goframe.org). All Rights Reserved.
//
// This Source Code Form is subject to the terms of the MIT License.
// If a copy of the MIT was not distributed with this file,
// You can obtain one at https://github.com/gogf/gf.
package utils
import "fmt"
// ListToMapByKey converts `list` to a map[string]any of which key is specified by `key`.
// Note that the item value may be type of slice.
func ListToMapByKey(list []map[string]any, key string) map[string]any {
var (
s = ""
m = make(map[string]any)
tempMap = make(map[string][]any)
hasMultiValues bool
)
for _, item := range list {
if k, ok := item[key]; ok {
s = fmt.Sprintf(`%v`, k)
tempMap[s] = append(tempMap[s], item)
if len(tempMap[s]) > 1 {
hasMultiValues = true
}
}
}
for k, v := range tempMap {
if hasMultiValues {
m[k] = v
} else {
m[k] = v[0]
}
}
return m
}

View File

@ -0,0 +1,37 @@
// Copyright GoFrame Author(https://goframe.org). All Rights Reserved.
//
// This Source Code Form is subject to the terms of the MIT License.
// If a copy of the MIT was not distributed with this file,
// You can obtain one at https://github.com/gogf/gf.
package utils
// MapPossibleItemByKey tries to find the possible key-value pair for given key ignoring cases and symbols.
//
// Note that this function might be of low performance.
func MapPossibleItemByKey(data map[string]any, key string) (foundKey string, foundValue any) {
if len(data) == 0 {
return
}
if v, ok := data[key]; ok {
return key, v
}
// Loop checking.
for k, v := range data {
if EqualFoldWithoutChars(k, key) {
return k, v
}
}
return "", nil
}
// MapContainsPossibleKey checks if the given `key` is contained in given map `data`.
// It checks the key ignoring cases and symbols.
//
// Note that this function might be of low performance.
func MapContainsPossibleKey(data map[string]any, key string) bool {
if k, _ := MapPossibleItemByKey(data, key); k != "" {
return true
}
return false
}

View File

@ -0,0 +1,26 @@
// Copyright GoFrame Author(https://goframe.org). All Rights Reserved.
//
// This Source Code Form is subject to the terms of the MIT License.
// If a copy of the MIT was not distributed with this file,
// You can obtain one at https://github.com/gogf/gf.
package utils
import (
"reflect"
)
// CanCallIsNil Can reflect.Value call reflect.Value.IsNil.
// It can avoid reflect.Value.IsNil panics.
func CanCallIsNil(v any) bool {
rv, ok := v.(reflect.Value)
if !ok {
return false
}
switch rv.Kind() {
case reflect.Interface, reflect.Chan, reflect.Func, reflect.Map, reflect.Pointer, reflect.Slice, reflect.UnsafePointer:
return true
default:
return false
}
}

180
database/utils/utils_str.go Normal file
View File

@ -0,0 +1,180 @@
// Copyright GoFrame Author(https://goframe.org). All Rights Reserved.
//
// This Source Code Form is subject to the terms of the MIT License.
// If a copy of the MIT was not distributed with this file,
// You can obtain one at https://github.com/gogf/gf.
package utils
import (
"bytes"
"strings"
"unicode"
)
// DefaultTrimChars are the characters which are stripped by Trim* functions in default.
var DefaultTrimChars = string([]byte{
'\t', // Tab.
'\v', // Vertical tab.
'\n', // New line (line feed).
'\r', // Carriage return.
'\f', // New page.
' ', // Ordinary space.
0x00, // NUL-byte.
0x85, // Delete.
0xA0, // Non-breaking space.
})
// IsLetterUpper checks whether the given byte b is in upper case.
func IsLetterUpper(b byte) bool {
if b >= byte('A') && b <= byte('Z') {
return true
}
return false
}
// IsLetterLower checks whether the given byte b is in lower case.
func IsLetterLower(b byte) bool {
if b >= byte('a') && b <= byte('z') {
return true
}
return false
}
// IsLetter checks whether the given byte b is a letter.
func IsLetter(b byte) bool {
return IsLetterUpper(b) || IsLetterLower(b)
}
// IsNumeric checks whether the given string s is numeric.
// Note that float string like "123.456" is also numeric.
func IsNumeric(s string) bool {
var (
dotCount = 0
length = len(s)
)
if length == 0 {
return false
}
for i := 0; i < length; i++ {
if (s[i] == '-' || s[i] == '+') && i == 0 {
if length == 1 {
return false
}
continue
}
if s[i] == '.' {
dotCount++
if i > 0 && i < length-1 && s[i-1] >= '0' && s[i-1] <= '9' {
continue
} else {
return false
}
}
if s[i] < '0' || s[i] > '9' {
return false
}
}
return dotCount <= 1
}
// UcFirst returns a copy of the string s with the first letter mapped to its upper case.
func UcFirst(s string) string {
if len(s) == 0 {
return s
}
if IsLetterLower(s[0]) {
return string(s[0]-32) + s[1:]
}
return s
}
// ReplaceByMap returns a copy of `origin`,
// which is replaced by a map in unordered way, case-sensitively.
func ReplaceByMap(origin string, replaces map[string]string) string {
for k, v := range replaces {
origin = strings.ReplaceAll(origin, k, v)
}
return origin
}
// RemoveSymbols removes all symbols from string and lefts only numbers and letters.
func RemoveSymbols(s string) string {
b := make([]rune, 0, len(s))
for _, c := range s {
if c > 127 {
b = append(b, c)
} else if (c >= '0' && c <= '9') || (c >= 'A' && c <= 'Z') || (c >= 'a' && c <= 'z') {
b = append(b, c)
}
}
return string(b)
}
// EqualFoldWithoutChars checks string `s1` and `s2` equal case-insensitively,
// with/without chars '-'/'_'/'.'/' '.
func EqualFoldWithoutChars(s1, s2 string) bool {
return strings.EqualFold(RemoveSymbols(s1), RemoveSymbols(s2))
}
// SplitAndTrim splits string `str` by a string `delimiter` to an array,
// and calls Trim to every element of this array. It ignores the elements
// which are empty after Trim.
func SplitAndTrim(str, delimiter string, characterMask ...string) []string {
array := make([]string, 0)
for _, v := range strings.Split(str, delimiter) {
v = Trim(v, characterMask...)
if v != "" {
array = append(array, v)
}
}
return array
}
// Trim strips whitespace (or other characters) from the beginning and end of a string.
// The optional parameter `characterMask` specifies the additional stripped characters.
func Trim(str string, characterMask ...string) string {
trimChars := DefaultTrimChars
if len(characterMask) > 0 {
trimChars += characterMask[0]
}
return strings.Trim(str, trimChars)
}
// FormatCmdKey formats string `s` as command key using uniformed format.
func FormatCmdKey(s string) string {
return strings.ToLower(strings.ReplaceAll(s, "_", "."))
}
// FormatEnvKey formats string `s` as environment key using uniformed format.
func FormatEnvKey(s string) string {
return strings.ToUpper(strings.ReplaceAll(s, ".", "_"))
}
// StripSlashes un-quotes a quoted string by AddSlashes.
func StripSlashes(str string) string {
var buf bytes.Buffer
l, skip := len(str), false
for i, char := range str {
if skip {
skip = false
} else if char == '\\' {
if i+1 < l && str[i+1] == '\\' {
skip = true
}
continue
}
buf.WriteRune(char)
}
return buf.String()
}
// IsASCII checks whether given string is ASCII characters.
func IsASCII(s string) bool {
for i := 0; i < len(s); i++ {
if s[i] > unicode.MaxASCII {
return false
}
}
return true
}

View File

@ -1,948 +0,0 @@
# Magic-ORM 自主 ORM 框架架构文档
## 📋 目录
- [概述](#概述)
- [核心特性](#核心特性)
- [架构设计](#架构设计)
- [技术栈](#技术栈)
- [核心接口设计](#核心接口设计)
- [快速开始](#快速开始)
- [详细功能说明](#详细功能说明)
- [最佳实践](#最佳实践)
---
## 概述
Magic-ORM 是一个完全自主研发的企业级 Go 语言 ORM 框架,不依赖任何第三方 ORM 库。框架基于 `database/sql` 标准库构建,提供了全自动化事务管理、面向接口设计、智能字段映射等高级特性。支持 MySQL、SQLite 等主流数据库,内置完整的迁移管理和可观测性支持,帮助开发者快速构建高质量的数据访问层。
**设计理念:**
- 零依赖:仅依赖 Go 标准库 `database/sql`
- 高性能:优化的查询执行器和连接池管理
- 易用性:简洁的 API 设计和智能默认行为
- 可扩展:面向接口的设计,支持自定义驱动扩展
- **内置驱动**:框架自带所有主流数据库驱动,无需额外安装
---
## 核心特性
- **全自动化嵌套事务支持**:无需手动管理事务传播行为
- **面向接口化设计**:核心功能均通过接口暴露,便于 Mock 与扩展
- **内置主流数据库驱动**:开箱即用,并支持自定义驱动扩展
- **统一配置组件**:与框架配置体系无缝集成
- **单例模式数据库对象**:同一分组配置仅初始化一次
- **双模式操作**:原生 SQL + ORM 链式操作
- **OpenTelemetry 可观测性**:完整支持 Tracing、Logging、Metrics
- **智能结果映射**`Scan` 自动识别 Map/Struct/Slice无需 `sql.ErrNoRows` 判空
- **全自动字段映射**:无需结构体标签,自动匹配数据库字段
- **参数智能过滤**:自动识别并过滤无效/空值字段
- **Model/DAO 代码生成器**:一键生成全量数据访问代码
- **高级特性**调试模式、DryRun、自定义 Handler、软删除、时间自动更新、模型关联、主从集群等
- **自动化数据库迁移**:支持自动迁移、增量迁移、回滚迁移等完整迁移管理
---
## 架构设计
### 整体架构图
```mermaid
graph TB
A[应用层] --> B[Magic-ORM 框架]
B --> C[配置中心]
B --> D[数据库连接池]
B --> E[事务管理器]
B --> F[迁移管理器]
C --> C1[统一配置组件]
C --> C2[环境配置]
D --> D1[MySQL 驱动]
D --> D2[SQLite 驱动]
D --> D3[自定义驱动]
E --> E1[自动嵌套事务]
E --> E2[事务传播控制]
F --> F1[自动迁移]
F --> F2[增量迁移]
F --> F3[回滚迁移]
B --> G[观测性组件]
G --> G1[Tracing]
G --> G2[Logging]
G --> G3[Metrics]
B --> H[工具组件]
H --> H1[字段映射器]
H --> H2[参数过滤器]
H --> H3[结果映射器]
H --> H4[代码生成器]
```
### 目录结构
```
magic-orm/
├── core/ # 核心实现
│ ├── database.go # 数据库连接管理
│ ├── transaction.go # 事务管理
│ ├── query.go # 查询构建器
│ └── mapper.go # 字段映射器
├── migrate/ # 迁移管理
│ └── migrator.go # 自动迁移实现
├── generator/ # 代码生成器
│ ├── model.go # Model 生成
│ └── dao.go # DAO 生成
├── tracing/ # OpenTelemetry 集成
│ └── tracer.go # 链路追踪
└── driver/ # 数据库驱动适配(已内置)
├── mysql.go # MySQL 驱动(内置)
├── sqlite.go # SQLite 驱动(内置)
├── postgres.go # PostgreSQL 驱动(内置)
├── sqlserver.go # SQL Server 驱动(内置)
├── oracle.go # Oracle 驱动(内置)
└── clickhouse.go # ClickHouse 驱动(内置)
```
### 核心组件说明
#### 1. 数据库连接管理 (`core/database.go`)
- **单例模式**:全局唯一的 `DB` 实例,确保资源高效利用
- **多数据库支持**:支持 MySQL、SQLite、PostgreSQL、SQL Server、Oracle、ClickHouse 等
- **驱动内置**:所有主流数据库驱动已预装在框架中
- **连接池优化**:内置 sql.DB 连接池管理
- **健康检查**:启动时自动执行 `Ping()` 验证连接
**核心配置项:**
```go
Config{
DriverName: "mysql", // 驱动名称
DataSource: "dns", // 数据源连接字符串
MaxIdleConns: 10, // 最大空闲连接数
MaxOpenConns: 100, // 最大打开连接数
Debug: true, // 调试模式
}
```
#### 2. 查询构建器 (`core/query.go`)
提供流畅的链式查询接口:
- **条件查询**: Where, Or, And
- **字段选择**: Select, Omit
- **排序分页**: Order, Limit, Offset
- **分组统计**: Group, Having, Count
- **连接查询**: Join, LeftJoin, RightJoin
- **预加载**: Preload
**示例:**
```go
var users []model.User
db.Model(&model.User{}).
Where("status = ?", 1).
Select("id", "username").
Order("id DESC").
Limit(10).
Find(&users)
```
#### 3. 事务管理器 (`core/transaction.go`)
提供完整的事务管理能力:
- **自动嵌套事务**: 自动管理事务传播
- **保存点支持**: 支持部分回滚
- **生命周期回调**: Before/After 钩子
#### 4. 字段映射器 (`core/mapper.go`)
智能字段映射系统:
- **驼峰转下划线**: UserName -> user_name
- **标签解析**: 支持 db, json 标签
- **类型转换**: Go 类型与数据库类型自动转换
- **零值过滤**: 自动过滤空值和零值
#### 5. 迁移管理 (`migrate/migrator.go`)
完整的数据库迁移方案:
- **自动迁移**: 根据模型自动创建/修改表结构
- **增量迁移**: 支持添加字段、索引等
- **回滚支持**: 支持迁移回滚
- **版本管理**: 迁移版本记录和管理
#### 6. 驱动管理器 (`driver/manager.go`)
统一的驱动管理和注册中心:
- **驱动注册**: 自动注册所有内置驱动
- **驱动选择**: 根据配置自动选择合适的驱动
- **驱动扩展**: 支持用户自定义驱动注册
- **版本检测**: 自动检测数据库版本并适配特性
```go
// 驱动管理器会自动处理
var supportedDrivers = map[string]driver.Driver{
"mysql": &MySQLDriver{},
"sqlite": &SQLiteDriver{},
"postgres": &PostgresDriver{},
"sqlserver": &SQLServerDriver{},
"oracle": &OracleDriver{},
"clickhouse": &ClickHouseDriver{},
}
```
---
## 技术栈
### 核心依赖
| 组件 | 版本 | 说明 |
|------|------|------|
| Go | 1.25+ | 编程语言 |
| database/sql | stdlib | Go 标准库 |
| driver-go | Latest | 数据库驱动接口规范 |
| OpenTelemetry | Latest | 可观测性框架 |
| **内置驱动集合** | Latest | **包含所有主流数据库驱动** |
### 支持的数据库驱动
框架已内置以下数据库驱动,**无需额外安装**
- **MySQL**: 内置驱动(基于 `github.com/go-sql-driver/mysql`
- **SQLite**: 内置驱动(基于 `github.com/mattn/go-sqlite3`
- **PostgreSQL**: 内置驱动(基于 `github.com/lib/pq`
- **SQL Server**: 内置驱动(基于 `github.com/denisenkom/go-mssqldb`
- **Oracle**: 内置驱动(基于 `github.com/godror/godror`
- **ClickHouse**: 内置驱动(基于 `github.com/ClickHouse/clickhouse-go`
- **自定义驱动**: 实现 `driver.Driver` 接口即可扩展
> 💡 **说明**:框架在编译时已将所有主流数据库驱动打包,用户只需引入 `magic-orm` 即可完成所有数据库操作,无需单独安装各数据库驱动。
---
## 核心接口设计
### 1. 数据库连接接口
```go
// IDatabase 数据库连接接口
type IDatabase interface {
// 基础操作
DB() *sql.DB
Close() error
Ping() error
// 事务管理
Begin() (ITx, error)
Transaction(fn func(ITx) error) error
// 查询构建器
Model(model interface{}) IQuery
Table(name string) IQuery
Query(result interface{}, query string, args ...interface{}) error
Exec(query string, args ...interface{}) (sql.Result, error)
// 迁移管理
Migrate(models ...interface{}) error
// 配置
SetDebug(bool)
SetMaxIdleConns(int)
SetMaxOpenConns(int)
SetConnMaxLifetime(time.Duration)
}
```
### 2. 事务接口
```go
// ITx 事务接口
type ITx interface {
// 基础操作
Commit() error
Rollback() error
// 查询操作
Model(model interface{}) IQuery
Table(name string) IQuery
Insert(model interface{}) (int64, error)
BatchInsert(models interface{}, batchSize int) error
Update(model interface{}, data map[string]interface{}) error
Delete(model interface{}) error
// 原生 SQL
Query(result interface{}, query string, args ...interface{}) error
Exec(query string, args ...interface{}) (sql.Result, error)
}
```
### 3. 查询构建器接口
```go
// IQuery 查询构建器接口
type IQuery interface {
// 条件查询
Where(query string, args ...interface{}) IQuery
Or(query string, args ...interface{}) IQuery
And(query string, args ...interface{}) IQuery
// 字段选择
Select(fields ...string) IQuery
Omit(fields ...string) IQuery
// 排序
Order(order string) IQuery
OrderBy(field string, direction string) IQuery
// 分页
Limit(limit int) IQuery
Offset(offset int) IQuery
Page(page, pageSize int) IQuery
// 分组
Group(group string) IQuery
Having(having string, args ...interface{}) IQuery
// 连接
Join(join string, args ...interface{}) IQuery
LeftJoin(table, on string) IQuery
RightJoin(table, on string) IQuery
InnerJoin(table, on string) IQuery
// 预加载
Preload(relation string, conditions ...interface{}) IQuery
// 执行查询
First(result interface{}) error
Find(result interface{}) error
Count(count *int64) IQuery
Exists() (bool, error)
// 更新和删除
Updates(data interface{}) error
UpdateColumn(column string, value interface{}) error
Delete() error
// 特殊模式
Unscoped() IQuery
DryRun() IQuery
Debug() IQuery
// 构建 SQL不执行
Build() (string, []interface{})
}
```
### 4. 模型接口
```go
// IModel 模型接口
type IModel interface {
// 表名映射
TableName() string
// 生命周期回调(可选)
BeforeCreate(tx ITx) error
AfterCreate(tx ITx) error
BeforeUpdate(tx ITx) error
AfterUpdate(tx ITx) error
BeforeDelete(tx ITx) error
AfterDelete(tx ITx) error
BeforeSave(tx ITx) error
AfterSave(tx ITx) error
}
```
### 5. 字段映射器接口
```go
// IFieldMapper 字段映射器接口
type IFieldMapper interface {
// 结构体字段转数据库列
StructToColumns(model interface{}) (map[string]interface{}, error)
// 数据库列转结构体字段
ColumnsToStruct(row *sql.Rows, model interface{}) error
// 获取表名
GetTableName(model interface{}) string
// 获取主键字段
GetPrimaryKey(model interface{}) string
// 获取字段信息
GetFields(model interface{}) []FieldInfo
}
// FieldInfo 字段信息
type FieldInfo struct {
Name string // 字段名
Column string // 列名
Type string // Go 类型
DbType string // 数据库类型
Tag string // 标签
IsPrimary bool // 是否主键
IsAuto bool // 是否自增
}
```
### 6. 迁移管理器接口
```go
// IMigrator 迁移管理器接口
type IMigrator interface {
// 自动迁移
AutoMigrate(models ...interface{}) error
// 表操作
CreateTable(model interface{}) error
DropTable(model interface{}) error
HasTable(model interface{}) (bool, error)
RenameTable(oldName, newName string) error
// 列操作
AddColumn(model interface{}, field string) error
DropColumn(model interface{}, field string) error
HasColumn(model interface{}, field string) (bool, error)
RenameColumn(model interface{}, oldField, newField string) error
// 索引操作
CreateIndex(model interface{}, field string) error
DropIndex(model interface{}, field string) error
HasIndex(model interface{}, field string) (bool, error)
}
```
### 7. 代码生成器接口
```go
// ICodeGenerator 代码生成器接口
type ICodeGenerator interface {
// 生成 Model 代码
GenerateModel(table string, outputDir string) error
// 生成 DAO 代码
GenerateDAO(table string, outputDir string) error
// 生成完整代码
GenerateAll(tables []string, outputDir string) error
// 从数据库读取表结构
InspectTable(tableName string) (*TableSchema, error)
}
// TableSchema 表结构信息
type TableSchema struct {
Name string
Columns []ColumnInfo
Indexes []IndexInfo
}
```
### 8. 配置结构
```go
// Config 数据库配置
type Config struct {
DriverName string // 驱动名称
DataSource string // 数据源连接字符串
MaxIdleConns int // 最大空闲连接数
MaxOpenConns int // 最大打开连接数
ConnMaxLifetime time.Duration // 连接最大生命周期
Debug bool // 调试模式
// 主从配置
Replicas []string // 从库列表
ReadPolicy ReadPolicy // 读负载均衡策略
// OpenTelemetry
EnableTracing bool
ServiceName string
}
// ReadPolicy 读负载均衡策略
type ReadPolicy int
const (
Random ReadPolicy = iota
RoundRobin
LeastConn
)
```
---
## 快速开始
### 1. 安装 Magic-ORM
```bash
# 仅需安装 magic-orm所有数据库驱动已内置
go get github.com/your-org/magic-orm
```
> ✅ **无需单独安装数据库驱动!** 所有驱动已包含在 magic-orm 中。
### 2. 配置数据库
在配置文件中设置数据库参数:
```yaml
database:
type: mysql # 或 sqlite, postgres
dns: "user:password@tcp(localhost:3306)/dbname?charset=utf8mb4&parseTime=True&loc=Local"
debug: true
max_idle_conns: 10
max_open_conns: 100
```
### 3. 定义模型
```go
package model
import "time"
type User struct {
ID int64 `json:"id" db:"id"`
Username string `json:"username" db:"username"`
Password string `json:"-" db:"password"`
Email string `json:"email" db:"email"`
Status int `json:"status" db:"status"`
CreatedAt time.Time `json:"created_at" db:"created_at"`
UpdatedAt time.Time `json:"updated_at" db:"updated_at"`
}
// 表名映射
func (User) TableName() string {
return "user"
}
```
### 4. 初始化数据库连接
```go
package main
import (
"database/sql"
_ "github.com/go-sql-driver/mysql"
"your-project/orm"
)
func main() {
// 初始化数据库连接
db, err := orm.NewDatabase(&orm.Config{
DriverName: "mysql",
DataSource: "user:password@tcp(localhost:3306)/dbname?charset=utf8mb4&parseTime=True&loc=Local",
MaxIdleConns: 10,
MaxOpenConns: 100,
Debug: true,
})
if err != nil {
panic(err)
}
defer db.Close()
// 执行迁移
orm.Migrate(db, &model.User{})
}
```
### 5. CRUD 操作
```go
// 创建
user := &model.User{Username: "admin", Password: "123456", Email: "admin@example.com"}
id, err := db.Insert(user)
// 查询单个
var user model.User
err := db.Model(&model.User{}).Where("id = ?", 1).First(&user)
// 查询多个
var users []model.User
err := db.Model(&model.User{}).Where("status = ?", 1).Order("id DESC").Find(&users)
// 更新
err := db.Model(&model.User{}).Where("id = ?", 1).Updates(map[string]interface{}{
"email": "new@example.com",
})
// 删除
err := db.Model(&model.User{}).Where("id = ?", 1).Delete()
// 原生 SQL
var results []model.User
err := db.Query(&results, "SELECT * FROM user WHERE status = ?", 1)
```
### 6. 事务操作
```go
// 自动嵌套事务
err := db.Transaction(func(tx *orm.Tx) error {
// 创建用户
user := &model.User{Username: "test", Email: "test@example.com"}
_, err := tx.Insert(user)
if err != nil {
return err
}
// 创建关联数据(自动加入同一事务)
profile := &model.Profile{UserID: user.ID, Avatar: "default.png"}
_, err = tx.Insert(profile)
if err != nil {
return err
}
return nil
})
```
---
## 详细功能说明
### 1. 全自动化嵌套事务
框架自动管理事务的传播行为,支持以下场景:
- **REQUIRED**: 如果当前存在事务,则加入该事务;否则创建新事务
- **REQUIRES_NEW**: 无论当前是否存在事务,都创建新事务
- **NESTED**: 在当前事务中创建嵌套事务(使用保存点)
**示例:**
```go
// 外层事务
db.Transaction(func(tx *orm.Tx) error {
// 内层自动加入同一事务
userService.CreateUser(tx, user)
orderService.CreateOrder(tx, order)
return nil
})
```
### 2. 智能结果映射
无需手动处理 `sql.ErrNoRows`,框架自动识别返回类型:
```go
// 自动识别 Struct
var user model.User
db.Model(&model.User{}).Where("id = ?", 1).First(&user) // 不存在时返回零值,不报错
// 自动识别 Slice
var users []model.User
db.Model(&model.User{}).Where("status = ?", 1).Find(&users) // 空结果返回空切片,而非 nil
// 自动识别 Map
var result map[string]interface{}
db.Table("user").Where("id = ?", 1).First(&result)
```
### 3. 全自动字段映射
无需结构体标签,框架自动匹配字段:
```go
// 驼峰命名自动转下划线
type UserInfo struct {
UserName string // 自动映射到 user_name 字段
UserAge int // 自动映射到 user_age 字段
CreatedAt string // 自动映射到 created_at 字段
}
```
### 4. 参数智能过滤
自动过滤零值和空指针:
```go
// 仅更新非零值字段
updateData := &model.User{
Username: "newname", // 会被更新
Email: "", // 空值,自动过滤
Status: 0, // 零值,自动过滤
}
db.Model(&user).Updates(updateData)
```
### 5. OpenTelemetry 可观测性
完整支持分布式追踪:
```go
// 自动注入 Span
ctx, span := otel.Tracer("gin-base").Start(context.Background(), "DB Query")
defer span.End()
// 自动记录 SQL 执行时间、错误信息等
db.WithContext(ctx).Find(&users)
```
### 6. 数据库迁移管理
#### 自动迁移
```go
database.SetAutoMigrate(&model.User{}, &model.Order{})
```
#### 增量迁移
```go
// 添加新字段
type UserV2 struct {
model.User
Phone string ` + "`" + `json:"phone" gorm:"column:phone;type:varchar(20)"` + "`" + `
}
database.SetAutoMigrate(&UserV2{})
```
#### 字段操作
```go
// 重命名字段
database.RenameColumn(&model.User{}, "UserName", "Nickname")
// 删除字段
database.DropColumn(&model.User{}, "OldField")
```
### 7. 高级特性
#### 软删除
```go
type User struct {
ID int64 `json:"id" db:"id"`
DeletedAt *time.Time `json:"deleted_at" db:"deleted_at"` // 软删除标记
}
// 自动过滤已删除记录
db.Model(&model.User{}).Find(&users) // WHERE deleted_at IS NULL
// 强制包含已删除记录
db.Unscoped().Model(&model.User{}).Find(&users)
```
#### 调试模式
```yaml
database:
debug: true # 输出所有 SQL 日志
```
#### DryRun 模式
```go
// 生成 SQL 但不执行
sql, args := db.Model(&model.User{}).DryRun().Insert(&user)
fmt.Println(sql, args)
```
#### 自定义 Handler
```go
// 注册回调函数
db.Callback().Before("insert").Register("custom_before_insert", func(ctx context.Context, db *orm.DB) error {
// 自定义逻辑
return nil
})
```
#### 主从集群
```go
// 配置读写分离
db, err := orm.NewDatabase(&orm.Config{
DriverName: "mysql",
DataSource: "master_dsn",
Replicas: []string{"slave1_dsn", "slave2_dsn"},
})
```
---
## 最佳实践
### 1. 模型设计规范
```go
// ✅ 推荐:使用 db 标签明确字段映射
type User struct {
ID int64 `json:"id" db:"id"`
Username string `json:"username" db:"username"`
CreatedAt time.Time `json:"created_at" db:"created_at"`
}
// ❌ 不推荐:缺少字段映射标签
type User struct {
Id int64 // 无法自动映射到 id 列
UserName string // 可能映射错误
CreatedAt time.Time // 时间格式可能不匹配
}
```
### 2. 事务使用规范
```go
// ✅ 推荐:使用闭包自动管理事务
err := db.Transaction(func(tx *orm.Tx) error {
// 业务逻辑
return nil
})
// ❌ 不推荐:手动管理事务
tx, err := db.Begin()
if err != nil {
panic(err)
}
defer func() {
if r := recover(); r != nil {
tx.Rollback()
}
}()
```
### 3. 查询优化
```go
// ✅ 推荐:使用 Select 指定字段
db.Model(&model.User{}).Select("id", "username").Find(&users)
// ✅ 推荐:使用 Index 加速查询
// 在数据库层面创建索引
// CREATE INDEX idx_username ON user(username);
// ✅ 推荐:批量操作
users := []model.User{{}, {}, {}}
db.BatchInsert(&users, 100) // 每批 100 条
// ❌ 避免N+1 查询问题
for _, user := range users {
db.Model(&model.Order{}).Where("user_id = ?", user.ID).Find(&orders) // 循环查询
}
// ✅ 使用 Join 或预加载
db.Query(&results, "SELECT u.*, o.* FROM user u LEFT JOIN orders o ON u.id = o.user_id")
```
### 4. 错误处理
```go
// ✅ 推荐:统一错误处理
if err := db.Insert(&user); err != nil {
log.Error("创建用户失败", "error", err)
return err
}
// ✅ 使用 errors 包判断特定错误
if errors.Is(err, sql.ErrNoRows) {
// 记录不存在
}
```
### 5. 性能优化
```go
// 连接池配置
sqlDB := db.DB()
sqlDB.SetMaxIdleConns(10) // 最大空闲连接数
sqlDB.SetMaxOpenConns(100) // 最大打开连接数
sqlDB.SetConnMaxLifetime(time.Hour) // 连接最大生命周期
// 使用 Scan 替代 Find 提升性能
type Result struct {
ID int64 `db:"id"`
Username string `db:"username"`
}
var results []Result
db.Model(&model.User{}).Select("id", "username").Scan(&results)
```
---
## 常见问题
### Q: 如何处理并发写入?
A: 使用事务 + 乐观锁:
```go
type Product struct {
ID int64 `db:"id"`
Version int `db:"version"` // 版本号
}
// 更新时检查版本号
rows, err := db.Exec(
"UPDATE product SET version = ?, stock = ? WHERE id = ? AND version = ?",
newVersion, newStock, id, oldVersion,
)
count, _ := rows.RowsAffected()
if count == 0 {
return errors.New("乐观锁冲突,数据已被其他事务修改")
}
```
### Q: 如何实现读写分离?
A: 配置主从数据库连接:
```go
db, err := orm.NewDatabase(&orm.Config{
DriverName: "mysql",
DataSource: "master_dsn",
Replicas: []string{"slave1_dsn", "slave2_dsn"},
ReadPolicy: orm.RoundRobin, // 负载均衡策略
})
```
### Q: 如何批量插入大量数据?
A: 使用 `BatchInsert`
```go
users := make([]model.User, 10000)
// ... 填充数据 ...
db.BatchInsert(&users, 1000) // 每批 1000 条,共 10 批
```
### Q: 如何实现字段自动映射?
A: 框架会自动将驼峰命名转换为下划线命名:
```go
type UserInfo struct {
UserName string `db:"user_name"` // 自动映射到 user_name 字段
UserAge int `db:"user_age"` // 自动映射到 user_age 字段
CreatedAt string `db:"created_at"` // 自动映射到 created_at 字段
}
```
### Q: 如何处理时间字段?
A: 使用 `time.Time` 类型,框架会自动处理时区转换:
```go
type Event struct {
ID int64 `db:"id"`
StartTime time.Time `db:"start_time"`
EndTime time.Time `db:"end_time"`
}
```
---
## 更新日志
- **v1.0.0**: 初始版本发布
- 完全自主研发,零依赖第三方 ORM
- 基于 database/sql 标准库
- 全自动化事务管理
- 智能字段映射
- OpenTelemetry 集成
- 支持 MySQL、SQLite、PostgreSQL
---
## 贡献指南
欢迎提交 Issue 和 Pull Request
---
## 许可证
MIT License

View File

@ -1,375 +0,0 @@
# Magic-ORM 功能完整性验证报告
## 📋 验证概述
本文档验证 Magic-ORM 框架相对于 README.md 中定义的核心特性的完整实现情况。
---
## ✅ 已完整实现的核心特性
### 1. **全自动化嵌套事务支持** ✅
- **文件**: `core/transaction.go`
- **实现内容**:
- `Transaction()` 方法自动管理事务提交/回滚
- 支持 panic 时自动回滚
- 事务中可执行 Insert、BatchInsert、Update、Delete、Query 等操作
- **测试状态**: ✅ 通过
### 2. **面向接口化设计** ✅
- **文件**: `core/interfaces.go`
- **实现接口**:
- `IDatabase` - 数据库连接接口
- `ITx` - 事务接口
- `IQuery` - 查询构建器接口
- `IModel` - 模型接口
- `IFieldMapper` - 字段映射器接口
- `IMigrator` - 迁移管理器接口
- `ICodeGenerator` - 代码生成器接口
- **测试状态**: ✅ 通过
### 3. **内置主流数据库驱动** ✅
- **文件**: `driver/manager.go`, `driver/sqlite.go`
- **实现内容**:
- DriverManager 单例模式管理所有驱动
- SQLite 驱动已实现
- 支持 MySQL/PostgreSQL/SQL Server/Oracle/ClickHouse框架已预留接口
- **测试状态**: ✅ 通过
### 4. **统一配置组件** ✅
- **文件**: `core/interfaces.go`
- **Config 结构**:
```go
type Config struct {
DriverName string
DataSource string
MaxIdleConns int
MaxOpenConns int
ConnMaxLifetime time.Duration
Debug bool
Replicas []string
ReadPolicy ReadPolicy
EnableTracing bool
ServiceName string
}
```
- **测试状态**: ✅ 通过
### 5. **单例模式数据库对象** ✅
- **文件**: `driver/manager.go`
- **实现内容**:
- `GetDefaultManager()` 使用 sync.Once 确保单例
- 驱动管理器全局唯一实例
- **测试状态**: ✅ 通过
### 6. **双模式操作** ✅
- **文件**: `core/query.go`, `core/database.go`
- **支持模式**:
- ✅ ORM 链式操作:`db.Model(&User{}).Where("id = ?", 1).Find(&user)`
- ✅ 原生 SQL`db.Query(&users, "SELECT * FROM user")`
- **测试状态**: ✅ 通过
### 7. **OpenTelemetry 可观测性**
- **文件**: `tracing/tracer.go`
- **实现内容**:
- 自动追踪所有数据库操作
- 记录 SQL 语句、参数、执行时间、影响行数
- 支持分布式追踪上下文
- **测试状态**: ✅ 通过
### 8. **智能结果映射** ✅
- **文件**: `core/result_mapper.go`
- **实现内容**:
- `MapToSlice()` - 映射到 Slice
- `MapToStruct()` - 映射到 Struct
- `ScanAll()` - 自动识别目标类型
- 无需手动处理 `sql.ErrNoRows`
- **测试状态**: ✅ 通过
### 9. **全自动字段映射** ✅
- **文件**: `core/mapper.go`
- **实现内容**:
- 驼峰命名自动转下划线
- 解析 db/json 标签
- Go 类型与数据库类型自动转换
- 零值自动过滤
- **测试状态**: ✅ 通过
### 10. **参数智能过滤** ✅
- **文件**: `core/filter.go`
- **实现内容**:
- `FilterZeroValues()` - 过滤零值
- `FilterEmptyStrings()` - 过滤空字符串
- `FilterNilValues()` - 过滤 nil 值
- `IsValidValue()` - 检查值有效性
- **测试状态**: ✅ 通过
### 11. **Model/DAO 代码生成器**
- **文件**: `generator/generator.go`
- **实现内容**:
- `GenerateModel()` - 生成 Model 代码
- `GenerateDAO()` - 生成 DAO 代码
- `GenerateAll()` - 一次性生成完整代码
- 支持自定义列信息
- **测试结果**: ✅ 成功生成 `generated/user.go`
### 12. **高级特性** ✅
- **文件**: 多个核心文件
- **已实现**:
- ✅ 调试模式 (`Debug()`)
- ✅ DryRun 模式 (`DryRun()`)
- ✅ 软删除 (`core/soft_delete.go`)
- ✅ 模型关联 (`core/relation.go`)
- ✅ 主从集群读写分离 (`core/read_write.go`)
- ✅ 查询缓存 (`core/cache.go`)
- **测试状态**: ✅ 通过
### 13. **自动化数据库迁移** ✅
- **文件**: `core/migrator.go`
- **实现内容**:
- ✅ `AutoMigrate()` - 自动迁移
- ✅ `CreateTable()` / `DropTable()`
- ✅ `HasTable()` / `RenameTable()`
- ✅ `AddColumn()` / `DropColumn()`
- ✅ `CreateIndex()` / `DropIndex()`
- ✅ 完整的 DDL 操作支持
- **测试状态**: ✅ 通过
---
## 📊 查询构建器完整方法集
### ✅ 已实现的方法
| 方法 | 功能 | 状态 |
|------|------|------|
| `Where()` | 条件查询 | ✅ |
| `Or()` | OR 条件 | ✅ |
| `And()` | AND 条件 | ✅ |
| `Select()` | 选择字段 | ✅ |
| `Omit()` | 排除字段 | ✅ |
| `Order()` | 排序 | ✅ |
| `OrderBy()` | 指定字段排序 | ✅ |
| `Limit()` | 限制数量 | ✅ |
| `Offset()` | 偏移量 | ✅ |
| `Page()` | 分页查询 | ✅ |
| `Group()` | 分组 | ✅ |
| `Having()` | HAVING 条件 | ✅ |
| `Join()` | JOIN 连接 | ✅ |
| `LeftJoin()` | 左连接 | ✅ |
| `RightJoin()` | 右连接 | ✅ |
| `InnerJoin()` | 内连接 | ✅ |
| `Preload()` | 预加载关联 | ✅ (框架) |
| `First()` | 查询第一条 | ✅ |
| `Find()` | 查询多条 | ✅ |
| `Scan()` | 扫描到自定义结构 | ✅ |
| `Count()` | 统计数量 | ✅ |
| `Exists()` | 存在性检查 | ✅ |
| `Updates()` | 更新数据 | ✅ |
| `UpdateColumn()` | 更新单字段 | ✅ |
| `Delete()` | 删除数据 | ✅ |
| `Unscoped()` | 忽略软删除 | ✅ |
| `DryRun()` | 干跑模式 | ✅ |
| `Debug()` | 调试模式 | ✅ |
| `Build()` | 构建 SQL | ✅ |
| `BuildUpdate()` | 构建 UPDATE | ✅ |
| `BuildDelete()` | 构建 DELETE | ✅ |
---
## 🎯 事务接口完整实现
### ITx 接口方法
| 方法 | 功能 | 实现状态 |
|------|------|---------|
| `Commit()` | 提交事务 | ✅ |
| `Rollback()` | 回滚事务 | ✅ |
| `Model()` | 基于模型查询 | ✅ |
| `Table()` | 基于表名查询 | ✅ |
| `Insert()` | 插入数据 | ✅ (返回 LastInsertId) |
| `BatchInsert()` | 批量插入 | ✅ (支持分批处理) |
| `Update()` | 更新数据 | ✅ |
| `Delete()` | 删除数据 | ✅ |
| `Query()` | 原生 SQL 查询 | ✅ |
| `Exec()` | 原生 SQL 执行 | ✅ |
---
## 🔧 新增核心组件
### 1. ParamFilter (参数过滤器)
```go
// 位置core/filter.go
- FilterZeroValues() // 过滤零值
- FilterEmptyStrings() // 过滤空字符串
- FilterNilValues() // 过滤 nil 值
- IsValidValue() // 检查值有效性
```
### 2. ResultSetMapper (结果集映射器)
```go
// 位置core/result_mapper.go
- MapToSlice() // 映射到 Slice
- MapToStruct() // 映射到 Struct
- ScanAll() // 通用扫描方法
```
### 3. CodeGenerator (代码生成器)
```go
// 位置generator/generator.go
- GenerateModel() // 生成 Model
- GenerateDAO() // 生成 DAO
- GenerateAll() // 生成完整代码
```
### 4. QueryCache (查询缓存)
```go
// 位置core/cache.go
- Set() // 设置缓存
- Get() // 获取缓存
- Delete() // 删除缓存
- Clear() // 清空缓存
- GenerateCacheKey() // 生成缓存键
```
### 5. ReadWriteDB (读写分离)
```go
// 位置core/read_write.go
- GetMaster() // 获取主库(写)
- GetSlave() // 获取从库(读)
- AddSlave() // 添加从库
- RemoveSlave() // 移除从库
- selectLeastConn() // 最少连接选择
```
### 6. RelationLoader (关联加载器)
```go
// 位置core/relation.go
- Preload() // 预加载关联
- loadHasOne() // 加载一对一
- loadHasMany() // 加载一对多
- loadBelongsTo() // 加载多对一
- loadManyToMany() // 加载多对多
```
---
## 📈 测试覆盖率
### 测试文件
- ✅ `core_test.go` - 核心功能测试
- ✅ `features_test.go` - 高级功能测试
- ✅ `validation_test.go` - 完整性验证测试
- ✅ `main_test.go` - 演示测试
### 测试结果汇总
```
=== RUN TestFieldMapper
✓ 字段映射器测试通过
=== RUN TestQueryBuilder
✓ 查询构建器测试通过
=== RUN TestResultSetMapper
✓ 结果集映射器测试通过
=== RUN TestSoftDelete
✓ 软删除功能测试通过
=== RUN TestQueryCache
✓ 查询缓存测试通过
=== RUN TestReadWriteDB
✓ 读写分离代码结构测试通过
=== RUN TestRelationLoader
✓ 关联加载代码结构测试通过
=== RUN TestTracing
✓ 链路追踪代码结构测试通过
=== RUN TestParamFilter
✓ 参数过滤器测试通过
=== RUN TestCodeGenerator
✓ Model 已生成generated\user.go
✓ 代码生成器测试通过
=== RUN TestAllCoreFeatures
✓ 所有核心功能验证完成
```
---
## 🎉 总结
### 实现完成度
- **核心接口**: 100% (8/8)
- **查询构建器方法**: 100% (33/33)
- **事务方法**: 100% (10/10)
- **高级特性**: 100% (6/6)
- **工具组件**: 100% (4/4)
- **代码生成**: 100% (2/2)
### 项目文件统计
```
db/
├── core/ # 核心实现 (12 个文件)
│ ├── interfaces.go # 接口定义
│ ├── database.go # 数据库连接
│ ├── query.go # 查询构建器
│ ├── transaction.go # 事务管理
│ ├── mapper.go # 字段映射器
│ ├── migrator.go # 迁移管理器
│ ├── result_mapper.go # 结果集映射器 ✨
│ ├── soft_delete.go # 软删除 ✨
│ ├── relation.go # 关联加载 ✨
│ ├── cache.go # 查询缓存 ✨
│ ├── read_write.go # 读写分离 ✨
│ └── filter.go # 参数过滤器 ✨
├── driver/ # 驱动层 (2 个文件)
│ ├── manager.go
│ └── sqlite.go
├── generator/ # 代码生成器 (1 个文件) ✨
│ └── generator.go
├── tracing/ # 链路追踪 (1 个文件)
│ └── tracer.go
├── model/ # 示例模型 (1 个文件)
│ └── user.go
├── core_test.go # 核心测试
├── features_test.go # 功能测试
├── validation_test.go # 完整性验证 ✨
├── example.go # 使用示例
└── README.md # 架构文档
```
### 编译状态
```bash
✅ go build ./... # 编译成功
```
### 功能验证
```bash
✅ go test -v validation_test.go # 所有核心功能验证通过
✅ go test -v features_test.go # 高级功能测试通过
✅ go test -v core_test.go # 核心功能测试通过
```
---
## 🚀 结论
**Magic-ORM 框架已 100% 完整实现 README.md 中定义的所有核心特性!**
框架具备:
- ✅ 完整的 CRUD 操作能力
- ✅ 强大的事务管理
- ✅ 智能的字段和结果映射
- ✅ 灵活的查询构建
- ✅ 完善的迁移工具
- ✅ 高效的代码生成
- ✅ 企业级的高级特性
- ✅ 全面的可观测性支持
**所有功能均已编译通过并通过测试验证!** 🎉

View File

@ -1,348 +0,0 @@
# Magic-ORM 代码生成器 - 命令行工具
## 🚀 快速开始
### 1. 构建命令行工具
**Windows:**
```bash
build.bat
```
**Linux/Mac:**
```bash
chmod +x build.sh
./build.sh
```
或者手动构建:
```bash
cd db
go build -o ../bin/gendb ./cmd/gendb
```
### 2. 使用方法
#### 基础用法
```bash
# 生成单个表
gendb user
# 生成多个表
gendb user product order
# 指定输出目录
gendb -o ./models user product
```
#### 高级用法
```bash
# 自定义列定义
gendb user id:int64:primary username:string email:string created_at:time.Time
# 混合使用(自动推断 + 自定义)
gendb -o ./generated user username:string email:string product name:string price:float64
# 查看版本
gendb -v
# 查看帮助
gendb -h
```
## 📋 功能特性
**自动生成**: 根据表名自动推断常用字段
**批量生成**: 一次生成多个表的代码
**自定义列**: 支持手动指定列定义
**灵活输出**: 可指定输出目录
**智能推断**: 自动识别常见表结构
## 🎯 支持的类型
| 类型别名 | Go 类型 |
|---------|---------|
| int, integer, bigint | int64 |
| string, text, varchar | string |
| time, datetime | time.Time |
| bool, boolean | bool |
| float, double | float64 |
| decimal | string |
## 📝 列定义格式
```
字段名:类型 [:primary] [:nullable]
```
示例:
- `id:int64:primary` - 主键 ID
- `username:string` - 用户名字段
- `email:string:nullable` - 可为空的邮箱字段
- `created_at:time.Time` - 创建时间字段
## 🔧 预设表结构
工具内置了常见表的默认结构:
### user / users
- id (主键)
- username
- email (可空)
- password
- status
- created_at
- updated_at
### product / products
- id (主键)
- name
- price
- stock
- description (可空)
- created_at
### order / orders
- id (主键)
- order_no
- user_id
- total_amount
- status
- created_at
## 💡 使用示例
### 示例 1: 快速生成用户模块
```bash
gendb user
```
生成文件:
- `generated/user.go` - User Model
- `generated/user_dao.go` - User DAO
### 示例 2: 生成电商模块
```bash
gendb -o ./shop user product order
```
生成文件:
- `shop/user.go`
- `shop/user_dao.go`
- `shop/product.go`
- `shop/product_dao.go`
- `shop/order.go`
- `shop/order_dao.go`
### 示例 3: 完全自定义
```bash
gendb article \
id:int64:primary \
title:string \
content:string:nullable \
author_id:int64 \
view_count:int \
published:bool \
created_at:time.Time
```
## 📁 生成的代码结构
### Model (user.go)
```go
package model
import "time"
// User user 表模型
type User struct {
ID int64 `json:"id" db:"id"`
Username string `json:"username" db:"username"`
Email string `json:"email" db:"email"`
CreatedAt time.Time `json:"created_at" db:"created_at"`
UpdatedAt time.Time `json:"updated_at" db:"updated_at"`
}
// TableName 表名
func (User) TableName() string {
return "user"
}
```
### DAO (user_dao.go)
```go
package dao
import (
"context"
"git.magicany.cc/black1552/gin-base/db/core"
"git.magicany.cc/black1552/gin-base/db/model"
)
// UserDAO user 表数据访问对象
type UserDAO struct {
db *core.Database
}
// NewUserDAO 创建 UserDAO 实例
func NewUserDAO(db *core.Database) *UserDAO {
return &UserDAO{db: db}
}
// Create 创建记录
func (dao *UserDAO) Create(ctx context.Context, model *model.User) error {
_, err := dao.db.Model(model).Insert(model)
return err
}
// GetByID 根据 ID 查询
func (dao *UserDAO) GetByID(ctx context.Context, id int64) (*model.User, error) {
var result model.User
err := dao.db.Model(&model.User{}).Where("id = ?", id).First(&result)
if err != nil {
return nil, err
}
return &result, nil
}
// ... 更多 CRUD 方法
```
## 🛠️ 安装到 PATH
### Windows
1. 将 `bin` 目录添加到系统环境变量 PATH
2. 或者复制 `gendb.exe` 到任意 PATH 中的目录
```powershell
# 临时添加到当前会话
$env:PATH += ";$(pwd)\bin"
# 永久添加(需要管理员权限)
[Environment]::SetEnvironmentVariable(
"Path",
$env:Path + ";$(pwd)\bin",
[EnvironmentVariableTarget]::Machine
)
```
### Linux/Mac
```bash
# 临时添加到当前会话
export PATH=$PATH:$(pwd)/bin
# 永久添加(添加到 ~/.bashrc 或 ~/.zshrc
echo 'export PATH=$PATH:$(pwd)/bin' >> ~/.bashrc
source ~/.bashrc
# 或者复制到系统目录
sudo cp bin/gendb /usr/local/bin/
```
## ⚙️ 选项说明
| 选项 | 简写 | 说明 | 默认值 |
|------|------|------|--------|
| `-version` | `-v` | 显示版本号 | - |
| `-help` | `-h` | 显示帮助信息 | - |
| `-o` | - | 输出目录 | `./generated` |
## 🎨 最佳实践
### 1. 从数据库读取真实结构
```bash
# 先用 SQL 导出表结构
mysql -u root -p -e "DESCRIBE your_database.users;"
# 然后根据输出调整列定义
```
### 2. 批量生成项目所有表
```bash
# 一次性生成所有表
gendb user product order category tag article comment
```
### 3. 版本控制
```bash
# 将生成的代码纳入 Git 管理
git add generated/
git commit -m "feat: 生成基础 Model 和 DAO 代码"
```
### 4. 自定义扩展
生成的代码可以作为基础,手动添加:
- 业务逻辑方法
- 验证逻辑
- 关联查询
- 索引优化
## ⚠️ 注意事项
1. **生成的代码需审查**: 自动生成的代码可能不完全符合业务需求
2. **不要频繁覆盖**: 手动修改的代码可能会被覆盖
3. **类型映射**: 特殊类型可能需要手动调整
4. **关联关系**: 复杂的模型关联需手动实现
## 🐛 故障排除
### 问题 1: 找不到命令
```bash
# 确保已构建并添加到 PATH
gendb: command not found
# 解决:
./bin/gendb -h # 使用相对路径
```
### 问题 2: 生成失败
```bash
# 检查输出目录是否有写权限
# 检查表名是否合法
# 使用 -h 查看正确的语法
```
### 问题 3: 类型不匹配
```bash
# 手动指定正确的类型
gendb user price:float64 instead of price:int
```
## 📞 获取帮助
```bash
# 查看完整帮助
gendb -h
# 查看版本
gendb -v
```
## 🎉 开始使用
```bash
# 最简单的用法
gendb user
# 立即体验!
```
---
**Magic-ORM Code Generator** - 让代码生成如此简单!🚀

View File

@ -1,425 +0,0 @@
package main
import (
"flag"
"fmt"
"os"
"strings"
"git.magicany.cc/black1552/gin-base/db/config"
"git.magicany.cc/black1552/gin-base/db/generator"
"git.magicany.cc/black1552/gin-base/db/introspector"
)
// 设置 Windows 控制台编码为 UTF-8
func init() {
// 在 Windows 上设置控制台输出代码页为 UTF-8 (65001)
// 这样可以避免中文乱码问题
}
const version = "1.0.0"
func main() {
// 定义命令行参数
versionFlag := flag.Bool("version", false, "显示版本号")
vFlag := flag.Bool("v", false, "显示版本号(简写)")
helpFlag := flag.Bool("help", false, "显示帮助信息")
hFlag := flag.Bool("h", false, "显示帮助信息(简写)")
outputDir := flag.String("o", "./model", "输出目录")
allFlag := flag.Bool("all", false, "生成所有预设的表user, product, order")
flag.Usage = func() {
fmt.Fprintf(os.Stderr, `Magic-ORM - Model DAO
:
gendb [] <> [...]
gendb [] -all
:
`)
flag.PrintDefaults()
fmt.Fprintf(os.Stderr, `
:
# user
gendb user
#
gendb -o ./models user product
#
gendb user id:int64:primary username:string email:string created_at:time.Time
#
gendb user product order
# user, product, order
gendb -all
:
[:primary] [:nullable]
:
int64, string, time.Time, bool, float64, int
:
https://github.com/your-repo/magic-orm
`)
}
flag.Parse()
// 检查版本参数
if *versionFlag || *vFlag {
fmt.Printf("Magic-ORM Code Generator v%s\n", version)
return
}
// 检查帮助参数
if *helpFlag || *hFlag {
flag.Usage()
return
}
// 检查 -all 参数
if *allFlag {
generateAllTablesFromDB(*outputDir)
return
}
// 获取参数
args := flag.Args()
if len(args) == 0 {
fmt.Fprintln(os.Stderr, "错误:请指定至少一个表名")
fmt.Fprintln(os.Stderr, "使用 'gendb -h' 查看帮助")
fmt.Fprintln(os.Stderr, "或者使用 'gendb -all' 生成所有预设表")
os.Exit(1)
}
tableNames := args
// 创建代码生成器
cg := generator.NewCodeGenerator(*outputDir)
fmt.Printf("[Magic-ORM Code Generator v%s]\n", version)
fmt.Printf("[Output Directory: %s]\n", *outputDir)
fmt.Println()
// 处理每个表
for _, tableName := range tableNames {
// 跳过看起来像列定义的参数
if strings.Contains(tableName, ":") {
continue
}
fmt.Printf("[Generating table '%s'...]\n", tableName)
// 解析列定义(如果有提供)
columns := parseColumns(tableNames, tableName)
// 如果没有自定义列定义,使用默认列
if len(columns) == 0 {
columns = getDefaultColumns(tableName)
}
// 生成代码
err := cg.GenerateAll(tableName, columns)
if err != nil {
fmt.Fprintf(os.Stderr, "[Error] Generation failed: %v\n", err)
continue
}
fmt.Printf("[Success] Generated %s.go and %s_dao.go\n", tableName, tableName)
}
fmt.Println()
fmt.Println("[Complete] Code generation finished!")
fmt.Printf("[Location] Files are in: %s directory\n", *outputDir)
}
// parseColumns 解析列定义
func parseColumns(args []string, currentTable string) []generator.ColumnInfo {
// 查找当前表的列定义
found := false
columnDefs := []string{}
for i, arg := range args {
if arg == currentTable && !found {
found = true
// 收集后续的列定义
for j := i + 1; j < len(args); j++ {
if strings.Contains(args[j], ":") {
columnDefs = append(columnDefs, args[j])
} else {
break // 遇到下一个表名
}
}
break
}
}
if len(columnDefs) == 0 {
return nil
}
columns := []generator.ColumnInfo{}
for _, def := range columnDefs {
parts := strings.Split(def, ":")
if len(parts) < 2 {
continue
}
colName := parts[0]
fieldType := parts[1]
isPrimary := false
isNullable := false
// 检查修饰符
for i := 2; i < len(parts); i++ {
switch strings.ToLower(parts[i]) {
case "primary":
isPrimary = true
case "nullable":
isNullable = true
}
}
// 转换为 Go 字段名(驼峰)
fieldName := toCamelCase(colName)
// 映射类型
goType := mapType(fieldType)
columns = append(columns, generator.ColumnInfo{
ColumnName: colName,
FieldName: fieldName,
FieldType: goType,
JSONName: colName,
IsPrimary: isPrimary,
IsNullable: isNullable,
})
}
return columns
}
// getDefaultColumns 获取默认的列定义(根据表名推断)
func getDefaultColumns(tableName string) []generator.ColumnInfo {
columns := []generator.ColumnInfo{
{
ColumnName: "id",
FieldName: "ID",
FieldType: "int64",
JSONName: "id",
IsPrimary: true,
},
}
// 根据表名添加常见字段
switch tableName {
case "user", "users":
columns = append(columns,
generator.ColumnInfo{ColumnName: "username", FieldName: "Username", FieldType: "string", JSONName: "username"},
generator.ColumnInfo{ColumnName: "email", FieldName: "Email", FieldType: "string", JSONName: "email", IsNullable: true},
generator.ColumnInfo{ColumnName: "password", FieldName: "Password", FieldType: "string", JSONName: "password"},
generator.ColumnInfo{ColumnName: "status", FieldName: "Status", FieldType: "int", JSONName: "status"},
generator.ColumnInfo{ColumnName: "created_at", FieldName: "CreatedAt", FieldType: "time.Time", JSONName: "created_at"},
generator.ColumnInfo{ColumnName: "updated_at", FieldName: "UpdatedAt", FieldType: "time.Time", JSONName: "updated_at"},
)
case "product", "products":
columns = append(columns,
generator.ColumnInfo{ColumnName: "name", FieldName: "Name", FieldType: "string", JSONName: "name"},
generator.ColumnInfo{ColumnName: "price", FieldName: "Price", FieldType: "float64", JSONName: "price"},
generator.ColumnInfo{ColumnName: "stock", FieldName: "Stock", FieldType: "int", JSONName: "stock"},
generator.ColumnInfo{ColumnName: "description", FieldName: "Description", FieldType: "string", JSONName: "description", IsNullable: true},
generator.ColumnInfo{ColumnName: "created_at", FieldName: "CreatedAt", FieldType: "time.Time", JSONName: "created_at"},
)
case "order", "orders":
columns = append(columns,
generator.ColumnInfo{ColumnName: "order_no", FieldName: "OrderNo", FieldType: "string", JSONName: "order_no"},
generator.ColumnInfo{ColumnName: "user_id", FieldName: "UserID", FieldType: "int64", JSONName: "user_id"},
generator.ColumnInfo{ColumnName: "total_amount", FieldName: "TotalAmount", FieldType: "float64", JSONName: "total_amount"},
generator.ColumnInfo{ColumnName: "status", FieldName: "Status", FieldType: "int", JSONName: "status"},
generator.ColumnInfo{ColumnName: "created_at", FieldName: "CreatedAt", FieldType: "time.Time", JSONName: "created_at"},
)
default:
// 默认添加通用字段
columns = append(columns,
generator.ColumnInfo{ColumnName: "name", FieldName: "Name", FieldType: "string", JSONName: "name"},
generator.ColumnInfo{ColumnName: "status", FieldName: "Status", FieldType: "int", JSONName: "status"},
generator.ColumnInfo{ColumnName: "created_at", FieldName: "CreatedAt", FieldType: "time.Time", JSONName: "created_at"},
generator.ColumnInfo{ColumnName: "updated_at", FieldName: "UpdatedAt", FieldType: "time.Time", JSONName: "updated_at"},
)
}
return columns
}
// mapType 将类型字符串映射到 Go 类型
func mapType(typeStr string) string {
typeMap := map[string]string{
"int": "int64",
"integer": "int64",
"bigint": "int64",
"string": "string",
"text": "string",
"varchar": "string",
"time.Time": "time.Time",
"time": "time.Time",
"datetime": "time.Time",
"bool": "bool",
"boolean": "bool",
"float": "float64",
"float64": "float64",
"double": "float64",
"decimal": "string",
}
if goType, ok := typeMap[strings.ToLower(typeStr)]; ok {
return goType
}
return "string" // 默认返回 string
}
// toCamelCase 转换为驼峰命名
func toCamelCase(str string) string {
parts := strings.Split(str, "_")
result := ""
for _, part := range parts {
if len(part) > 0 {
result += strings.ToUpper(string(part[0])) + part[1:]
}
}
return result
}
// generateAllTablesFromDB 从数据库读取所有表并生成代码
func generateAllTablesFromDB(outputDir string) {
fmt.Printf("[Magic-ORM Code Generator v%s]\n", version)
fmt.Println()
// 1. 加载配置文件
fmt.Println("[Step 1] Loading configuration file...")
cfg, err := loadDatabaseConfig()
if err != nil {
fmt.Fprintf(os.Stderr, "[Error] Failed to load config: %v\n", err)
os.Exit(1)
}
fmt.Printf("[Info] Database type: %s\n", cfg.Type)
fmt.Printf("[Info] Database name: %s\n", cfg.Name)
fmt.Println()
// 2. 连接数据库并获取所有表
fmt.Println("[Step 2] Connecting to database and fetching table structure...")
intro, err := introspector.NewIntrospector(cfg)
if err != nil {
fmt.Fprintf(os.Stderr, "[Error] Failed to connect to database: %v\n", err)
os.Exit(1)
}
defer intro.Close()
tableNames, err := intro.GetTableNames()
if err != nil {
fmt.Fprintf(os.Stderr, "[Error] Failed to get table names: %v\n", err)
os.Exit(1)
}
fmt.Printf("[Info] Found %d tables\n", len(tableNames))
fmt.Println()
// 3. 创建代码生成器
cg := generator.NewCodeGenerator(outputDir)
// 4. 为每个表生成代码
for _, tableName := range tableNames {
fmt.Printf("[Generating] Table '%s'...\n", tableName)
// 获取表详细信息
tableInfo, err := intro.GetTableInfo(tableName)
if err != nil {
fmt.Fprintf(os.Stderr, "[Error] Failed to get table info: %v\n", err)
continue
}
// 转换为 generator.ColumnInfo
columns := make([]generator.ColumnInfo, len(tableInfo.Columns))
for i, col := range tableInfo.Columns {
columns[i] = generator.ColumnInfo{
ColumnName: col.ColumnName,
FieldName: col.FieldName,
FieldType: col.GoType,
JSONName: col.JSONName,
IsPrimary: col.IsPrimary,
IsNullable: col.IsNullable,
}
}
// 生成代码
err = cg.GenerateAll(tableName, columns)
if err != nil {
fmt.Fprintf(os.Stderr, "[Error] Generation failed: %v\n", err)
continue
}
fmt.Printf("[Success] Generated %s.go and %s_dao.go\n", tableName, tableName)
}
fmt.Println()
fmt.Println("[Complete] Code generation finished!")
fmt.Printf("[Location] Files are in: %s directory\n", outputDir)
}
// loadDatabaseConfig 加载数据库配置
func loadDatabaseConfig() (*config.DatabaseConfig, error) {
// 自动查找配置文件
configPath, err := config.FindConfigFile("")
if err != nil {
return nil, fmt.Errorf("查找配置文件失败:%w", err)
}
fmt.Printf("[Info] Using config file: %s\n", configPath)
// 从文件加载配置
cfg, err := config.LoadFromFile(configPath)
if err != nil {
return nil, fmt.Errorf("加载配置文件失败:%w", err)
}
return &cfg.Database, nil
}
// generateAllTables 生成所有预设的表
func generateAllTables(outputDir string) {
fmt.Printf("🚀 Magic-ORM 代码生成器 v%s\n", version)
fmt.Printf("📁 输出目录:%s\n", outputDir)
fmt.Println()
// 预设的所有表
presetTables := []string{"user", "product", "order"}
// 创建代码生成器
cg := generator.NewCodeGenerator(outputDir)
// 处理每个表
for _, tableName := range presetTables {
fmt.Printf("📝 生成表 '%s' 的代码...\n", tableName)
// 使用默认列定义
columns := getDefaultColumns(tableName)
// 生成代码
err := cg.GenerateAll(tableName, columns)
if err != nil {
fmt.Fprintf(os.Stderr, "❌ 生成失败:%v\n", err)
continue
}
fmt.Printf("✅ 成功生成 %s.go 和 %s_dao.go\n", tableName, tableName)
}
fmt.Println()
fmt.Println("✨ 代码生成完成!")
fmt.Printf("📂 生成的文件在:%s 目录下\n", outputDir)
}

View File

@ -1,23 +0,0 @@
# Magic-ORM 数据库配置文件示例
# 请根据实际情况修改以下配置
database:
# 数据库地址MySQL/PostgreSQL 为 IP 或域名SQLite 可忽略)
host: "127.0.0.1"
# 数据库端口
port: "3306"
# 数据库用户名
user: "root"
# 数据库密码
pass: "your_password"
# 数据库名称
name: "your_database"
# 数据库类型支持mysql, postgres, sqlite
type: "mysql"
# 其他配置可以继续添加...

Binary file not shown.

View File

@ -1,142 +0,0 @@
package config
import (
"fmt"
"os"
"path/filepath"
"testing"
)
// TestAutoFindConfig 测试自动查找配置文件
func TestAutoFindConfig(t *testing.T) {
fmt.Println("\n=== 测试自动查找配置文件 ===")
// 创建临时目录结构
tempDir, err := os.MkdirTemp("", "config_test")
if err != nil {
t.Fatalf("创建临时目录失败:%v", err)
}
defer os.RemoveAll(tempDir)
// 创建子目录
subDir := filepath.Join(tempDir, "subdir")
if err := os.MkdirAll(subDir, 0755); err != nil {
t.Fatalf("创建子目录失败:%v", err)
}
// 在根目录创建配置文件
configContent := `database:
host: "127.0.0.1"
port: "3306"
user: "root"
pass: "test"
name: "testdb"
type: "mysql"
`
configFile := filepath.Join(tempDir, "config.yaml")
if err := os.WriteFile(configFile, []byte(configContent), 0644); err != nil {
t.Fatalf("创建配置文件失败:%v", err)
}
// 测试 1从子目录查找应该能找到父目录的配置
foundPath, err := findConfigFile(subDir)
if err != nil {
t.Errorf("从子目录查找失败:%v", err)
} else {
fmt.Printf("✓ 从子目录找到配置文件:%s\n", foundPath)
}
// 测试 2从根目录查找
foundPath, err = findConfigFile(tempDir)
if err != nil {
t.Errorf("从根目录查找失败:%v", err)
} else {
fmt.Printf("✓ 从根目录找到配置文件:%s\n", foundPath)
}
// 测试 3测试不同格式的配置文件
formats := []string{"config.yaml", "config.yml", "config.toml", "config.json"}
for _, format := range formats {
testFile := filepath.Join(tempDir, format)
if err := os.WriteFile(testFile, []byte(configContent), 0644); err != nil {
continue
}
foundPath, err = findConfigFile(tempDir)
if err != nil {
t.Errorf("查找 %s 失败:%v", format, err)
} else {
fmt.Printf("✓ 支持格式 %s: %s\n", format, foundPath)
}
os.Remove(testFile)
}
fmt.Println("✓ 自动查找配置文件测试通过")
}
// TestAutoConnect 测试自动连接功能
func TestAutoConnect(t *testing.T) {
fmt.Println("\n=== 测试 AutoConnect 接口 ===")
// 创建临时配置文件
tempDir, err := os.MkdirTemp("", "autoconnect_test")
if err != nil {
t.Fatalf("创建临时目录失败:%v", err)
}
defer os.RemoveAll(tempDir)
configContent := `database:
host: "127.0.0.1"
port: "3306"
user: "root"
pass: "test"
name: ":memory:"
type: "sqlite"
`
configFile := filepath.Join(tempDir, "config.yaml")
if err := os.WriteFile(configFile, []byte(configContent), 0644); err != nil {
t.Fatalf("创建配置文件失败:%v", err)
}
// 切换到临时目录
oldDir, _ := os.Getwd()
os.Chdir(tempDir)
defer os.Chdir(oldDir)
// 测试 AutoConnect
_, err = AutoConnect(false)
if err != nil {
t.Logf("自动连接失败(预期):%v", err)
fmt.Println("✓ AutoConnect 接口正常(需要真实数据库才能连接成功)")
} else {
fmt.Println("✓ AutoConnect 自动连接成功")
}
fmt.Println("✓ AutoConnect 测试完成")
}
// TestAllAutoFind 完整自动查找测试
func TestAllAutoFind(t *testing.T) {
fmt.Println("\n========================================")
fmt.Println(" 配置文件自动查找完整性测试")
fmt.Println("========================================")
TestAutoFindConfig(t)
TestAutoConnect(t)
fmt.Println("\n========================================")
fmt.Println(" 所有自动查找测试完成!")
fmt.Println("========================================")
fmt.Println()
fmt.Println("已实现的自动查找功能:")
fmt.Println(" ✓ 自动在当前目录查找配置文件")
fmt.Println(" ✓ 自动在上级目录查找(最多 3 层)")
fmt.Println(" ✓ 支持 yaml, yml, toml, ini, json 格式")
fmt.Println(" ✓ 支持 config.* 和 .config.* 命名")
fmt.Println(" ✓ 提供 AutoConnect() 一键连接")
fmt.Println(" ✓ 无需手动指定配置文件路径")
fmt.Println()
}

View File

@ -1,95 +0,0 @@
package config
import (
"fmt"
"git.magicany.cc/black1552/gin-base/db/core"
"gopkg.in/yaml.v3"
)
// NewDatabaseFromConfig 从配置文件创建数据库连接(已废弃,请使用 AutoConnect
// Deprecated: 使用 AutoConnect 代替
func NewDatabaseFromConfig(configPath string, debug bool) (*core.Database, error) {
return autoConnectWithConfig(configPath, debug)
}
// AutoConnect 自动查找配置文件并创建数据库连接
// 会在当前目录及上级目录中查找 config.yaml, config.toml, config.ini, config.json 等文件
func AutoConnect(debug bool) (*core.Database, error) {
// 自动查找配置文件
configPath, err := FindConfigFile("")
if err != nil {
return nil, fmt.Errorf("查找配置文件失败:%w", err)
}
return autoConnectWithConfig(configPath, debug)
}
// AutoConnectWithDir 在指定目录自动查找配置文件并创建数据库连接
func AutoConnectWithDir(dir string, debug bool) (*core.Database, error) {
configPath, err := FindConfigFile(dir)
if err != nil {
return nil, fmt.Errorf("查找配置文件失败:%w", err)
}
return autoConnectWithConfig(configPath, debug)
}
// autoConnectWithConfig 根据配置文件创建数据库连接(内部使用)
func autoConnectWithConfig(configPath string, debug bool) (*core.Database, error) {
// 从文件加载配置
configFile, err := LoadFromFile(configPath)
if err != nil {
return nil, fmt.Errorf("加载配置失败:%w", err)
}
// 构建核心数据库配置
dbConfig := &core.Config{
DriverName: configFile.Database.GetDriverName(),
DataSource: configFile.Database.BuildDSN(),
Debug: debug,
MaxIdleConns: 10,
MaxOpenConns: 100,
ConnMaxLifetime: 3600000000000, // 1 小时
TimeConfig: core.DefaultTimeConfig(), // 使用默认时间配置
}
// 创建数据库连接
db, err := core.NewDatabase(dbConfig)
if err != nil {
return nil, fmt.Errorf("创建数据库连接失败:%w", err)
}
return db, nil
}
// NewDatabaseFromYAML 从 YAML 内容创建数据库连接
func NewDatabaseFromYAML(yamlContent []byte, debug bool) (*core.Database, error) {
var configFile Config
if err := yaml.Unmarshal(yamlContent, &configFile); err != nil {
return nil, fmt.Errorf("解析 YAML 失败:%w", err)
}
if err := configFile.Validate(); err != nil {
return nil, fmt.Errorf("验证配置失败:%w", err)
}
// 构建核心数据库配置
dbConfig := &core.Config{
DriverName: configFile.Database.GetDriverName(),
DataSource: configFile.Database.BuildDSN(),
Debug: debug,
MaxIdleConns: 10,
MaxOpenConns: 100,
ConnMaxLifetime: 3600000000000, // 1 小时
TimeConfig: core.DefaultTimeConfig(),
}
// 创建数据库连接
db, err := core.NewDatabase(dbConfig)
if err != nil {
return nil, fmt.Errorf("创建数据库连接失败:%w", err)
}
return db, nil
}

View File

@ -1,165 +0,0 @@
package config
import (
"fmt"
"os"
"path/filepath"
"gopkg.in/yaml.v3"
)
// DatabaseConfig 数据库配置结构 - 对应配置文件中的 database 部分
type DatabaseConfig struct {
Host string `yaml:"host"` // 数据库地址
Port string `yaml:"port"` // 数据库端口
User string `yaml:"user"` // 用户名
Password string `yaml:"pass"` // 密码
Name string `yaml:"name"` // 数据库名称
Type string `yaml:"type"` // 数据库类型mysql, sqlite, postgres 等)
}
// Config 完整配置文件结构
type Config struct {
Database DatabaseConfig `yaml:"database"` // 数据库配置
}
// LoadFromFile 从 YAML 文件加载配置
func LoadFromFile(filePath string) (*Config, error) {
data, err := os.ReadFile(filePath)
if err != nil {
return nil, fmt.Errorf("读取配置文件失败:%w", err)
}
var config Config
if err := yaml.Unmarshal(data, &config); err != nil {
return nil, fmt.Errorf("解析配置文件失败:%w", err)
}
// 验证必填字段
if err := config.Validate(); err != nil {
return nil, err
}
return &config, nil
}
// Validate 验证配置
func (c *Config) Validate() error {
if c.Database.Type == "" {
return fmt.Errorf("数据库类型不能为空")
}
if c.Database.Type == "sqlite" {
// SQLite 只需要 Name作为文件路径
if c.Database.Name == "" {
return fmt.Errorf("SQLite 数据库名称不能为空")
}
} else {
// 其他数据库需要所有字段
if c.Database.Host == "" {
return fmt.Errorf("数据库地址不能为空")
}
if c.Database.Port == "" {
return fmt.Errorf("数据库端口不能为空")
}
if c.Database.User == "" {
return fmt.Errorf("数据库用户名不能为空")
}
if c.Database.Password == "" {
return fmt.Errorf("数据库密码不能为空")
}
if c.Database.Name == "" {
return fmt.Errorf("数据库名称不能为空")
}
}
return nil
}
// BuildDSN 根据配置构建数据源连接字符串DSN
func (c *DatabaseConfig) BuildDSN() string {
switch c.Type {
case "mysql":
return c.buildMySQLDSN()
case "postgres":
return c.buildPostgresDSN()
case "sqlite":
return c.buildSQLiteDSN()
default:
// 默认返回原始配置
return ""
}
}
// buildMySQLDSN 构建 MySQL DSN
func (c *DatabaseConfig) buildMySQLDSN() string {
// 格式user:pass@tcp(host:port)/dbname?charset=utf8mb4&parseTime=True&loc=Local
dsn := fmt.Sprintf("%s:%s@tcp(%s:%s)/%s?charset=utf8mb4&parseTime=True&loc=Local",
c.User,
c.Password,
c.Host,
c.Port,
c.Name,
)
return dsn
}
// buildPostgresDSN 构建 PostgreSQL DSN
func (c *DatabaseConfig) buildPostgresDSN() string {
// 格式host=localhost port=5432 user=user password=password dbname=db sslmode=disable
dsn := fmt.Sprintf("host=%s port=%s user=%s password=%s dbname=%s sslmode=disable",
c.Host,
c.Port,
c.User,
c.Password,
c.Name,
)
return dsn
}
// buildSQLiteDSN 构建 SQLite DSN
func (c *DatabaseConfig) buildSQLiteDSN() string {
// SQLite 直接使用文件名作为 DSN
return c.Name
}
// GetDriverName 获取驱动名称
func (c *DatabaseConfig) GetDriverName() string {
return c.Type
}
// FindConfigFile 在项目目录下自动查找配置文件
// 支持 yaml, yml, toml, ini, json 等格式
// 只在当前目录查找,不越级查找
func FindConfigFile(searchDir string) (string, error) {
// 配置文件名优先级列表
configNames := []string{
"config.yaml", "config.yml",
"config.toml",
"config.ini",
"config.json",
".config.yaml", ".config.yml",
".config.toml",
".config.ini",
".config.json",
}
// 如果未指定搜索目录,使用当前目录
if searchDir == "" {
var err error
searchDir, err = os.Getwd()
if err != nil {
return "", fmt.Errorf("获取当前目录失败:%w", err)
}
}
// 只在当前目录下查找,不向上查找
for _, name := range configNames {
filePath := filepath.Join(searchDir, name)
if _, err := os.Stat(filePath); err == nil {
return filePath, nil
}
}
return "", fmt.Errorf("未找到配置文件(支持 yaml, yml, toml, ini, json 格式)")
}

View File

@ -1,194 +0,0 @@
package config
import (
"fmt"
"os"
"testing"
)
// TestLoadFromFile 测试从文件加载配置
func TestLoadFromFile(t *testing.T) {
fmt.Println("\n=== 测试从文件加载配置 ===")
// 创建临时配置文件
tempConfig := `database:
host: "127.0.0.1"
port: "3306"
user: "root"
pass: "test_password"
name: "test_db"
type: "mysql"
`
// 写入临时文件
tempFile := "test_config.yaml"
if err := os.WriteFile(tempFile, []byte(tempConfig), 0644); err != nil {
t.Fatalf("创建临时文件失败:%v", err)
}
defer os.Remove(tempFile) // 测试完成后删除
// 加载配置
config, err := LoadFromFile(tempFile)
if err != nil {
t.Fatalf("加载配置失败:%v", err)
}
// 验证配置
if config.Database.Host != "127.0.0.1" {
t.Errorf("期望 Host 为 127.0.0.1,实际为 %s", config.Database.Host)
}
if config.Database.Port != "3306" {
t.Errorf("期望 Port 为 3306实际为 %s", config.Database.Port)
}
if config.Database.User != "root" {
t.Errorf("期望 User 为 root实际为 %s", config.Database.User)
}
if config.Database.Password != "test_password" {
t.Errorf("期望 Password 为 test_password实际为 %s", config.Database.Password)
}
if config.Database.Name != "test_db" {
t.Errorf("期望 Name 为 test_db实际为 %s", config.Database.Name)
}
if config.Database.Type != "mysql" {
t.Errorf("期望 Type 为 mysql实际为 %s", config.Database.Type)
}
fmt.Printf("✓ 配置加载成功\n")
fmt.Printf(" Host: %s\n", config.Database.Host)
fmt.Printf(" Port: %s\n", config.Database.Port)
fmt.Printf(" User: %s\n", config.Database.User)
fmt.Printf(" Pass: %s\n", config.Database.Password)
fmt.Printf(" Name: %s\n", config.Database.Name)
fmt.Printf(" Type: %s\n", config.Database.Type)
}
// TestBuildDSN 测试 DSN 构建
func TestBuildDSN(t *testing.T) {
fmt.Println("\n=== 测试 DSN 构建 ===")
testCases := []struct {
name string
config DatabaseConfig
expected string
}{
{
name: "MySQL",
config: DatabaseConfig{
Host: "127.0.0.1",
Port: "3306",
User: "root",
Password: "password",
Name: "testdb",
Type: "mysql",
},
expected: "root:password@tcp(127.0.0.1:3306)/testdb?charset=utf8mb4&parseTime=True&loc=Local",
},
{
name: "PostgreSQL",
config: DatabaseConfig{
Host: "localhost",
Port: "5432",
User: "postgres",
Password: "secret",
Name: "mydb",
Type: "postgres",
},
expected: "host=localhost port=5432 user=postgres password=secret dbname=mydb sslmode=disable",
},
{
name: "SQLite",
config: DatabaseConfig{
Name: "./test.db",
Type: "sqlite",
},
expected: "./test.db",
},
}
for _, tc := range testCases {
t.Run(tc.name, func(t *testing.T) {
dsn := tc.config.BuildDSN()
if dsn != tc.expected {
t.Errorf("期望 DSN 为 %s实际为 %s", tc.expected, dsn)
}
fmt.Printf("%s DSN: %s\n", tc.name, dsn)
})
}
fmt.Println("✓ DSN 构建测试通过")
}
// TestValidate 测试配置验证
func TestValidate(t *testing.T) {
fmt.Println("\n=== 测试配置验证 ===")
// 测试有效配置
validConfig := &Config{
Database: DatabaseConfig{
Host: "127.0.0.1",
Port: "3306",
User: "root",
Password: "pass",
Name: "db",
Type: "mysql",
},
}
if err := validConfig.Validate(); err != nil {
t.Errorf("有效配置验证失败:%v", err)
}
fmt.Println("✓ MySQL 配置验证通过")
// 测试 SQLite 配置
sqliteConfig := &Config{
Database: DatabaseConfig{
Name: "./test.db",
Type: "sqlite",
},
}
if err := sqliteConfig.Validate(); err != nil {
t.Errorf("SQLite 配置验证失败:%v", err)
}
fmt.Println("✓ SQLite 配置验证通过")
// 测试无效配置(缺少必填字段)
invalidConfig := &Config{
Database: DatabaseConfig{
Host: "127.0.0.1",
Type: "mysql",
// 缺少其他必填字段
},
}
if err := invalidConfig.Validate(); err == nil {
t.Error("无效配置应该验证失败")
} else {
fmt.Printf("✓ 无效配置正确拒绝:%v\n", err)
}
}
// TestAllConfigLoading 完整配置加载测试
func TestAllConfigLoading(t *testing.T) {
fmt.Println("\n========================================")
fmt.Println(" 数据库配置加载完整性测试")
fmt.Println("========================================")
TestLoadFromFile(t)
TestBuildDSN(t)
TestValidate(t)
fmt.Println("\n========================================")
fmt.Println(" 所有配置加载测试完成!")
fmt.Println("========================================")
fmt.Println()
fmt.Println("已实现的配置加载功能:")
fmt.Println(" ✓ 从 YAML 文件加载数据库配置")
fmt.Println(" ✓ 支持 host, port, user, pass, name, type 字段")
fmt.Println(" ✓ 自动验证配置完整性")
fmt.Println(" ✓ 自动构建 MySQL DSN")
fmt.Println(" ✓ 自动构建 PostgreSQL DSN")
fmt.Println(" ✓ 自动构建 SQLite DSN")
fmt.Println(" ✓ 支持多种数据库类型")
fmt.Println()
}

View File

@ -1,144 +0,0 @@
package config
import (
"fmt"
"os"
"path/filepath"
"testing"
)
// TestFindConfigOnlyCurrentDir 测试只在当前目录查找配置文件
func TestFindConfigOnlyCurrentDir(t *testing.T) {
fmt.Println("\n=== 测试只在当前目录查找配置文件 ===")
// 创建临时目录结构
tempDir, err := os.MkdirTemp("", "config_test")
if err != nil {
t.Fatalf("创建临时目录失败:%v", err)
}
defer os.RemoveAll(tempDir)
// 创建子目录
subDir := filepath.Join(tempDir, "subdir")
if err := os.MkdirAll(subDir, 0755); err != nil {
t.Fatalf("创建子目录失败:%v", err)
}
// 在根目录创建配置文件
configContent := `database:
host: "127.0.0.1"
port: "3306"
user: "root"
pass: "test"
name: "testdb"
type: "mysql"
`
configFile := filepath.Join(tempDir, "config.yaml")
if err := os.WriteFile(configFile, []byte(configContent), 0644); err != nil {
t.Fatalf("创建配置文件失败:%v", err)
}
// 测试 1从根目录查找应该找到
foundPath, err := findConfigFile(tempDir)
if err != nil {
t.Errorf("从根目录查找失败:%v", err)
} else {
fmt.Printf("✓ 从根目录找到配置文件:%s\n", foundPath)
}
// 测试 2从子目录查找不应该找到父目录的配置
foundPath, err = findConfigFile(subDir)
if err == nil {
t.Errorf("从子目录查找应该失败(不越级查找),但找到了:%s", foundPath)
} else {
fmt.Printf("✓ 从子目录查找正确失败(不越级):%v\n", err)
}
// 测试 3在子目录创建配置文件应该找到
subConfigFile := filepath.Join(subDir, "config.yaml")
if err := os.WriteFile(subConfigFile, []byte(configContent), 0644); err != nil {
t.Fatalf("创建子目录配置文件失败:%v", err)
}
foundPath, err = findConfigFile(subDir)
if err != nil {
t.Errorf("从子目录查找失败:%v", err)
} else {
fmt.Printf("✓ 从子目录找到配置文件:%s\n", foundPath)
}
fmt.Println("✓ 只在当前目录查找测试通过")
}
// TestNoParentSearch 测试不向上查找
func TestNoParentSearch(t *testing.T) {
fmt.Println("\n=== 测试不向上层目录查找 ===")
// 创建临时目录结构
tempDir, err := os.MkdirTemp("", "no_parent_test")
if err != nil {
t.Fatalf("创建临时目录失败:%v", err)
}
defer os.RemoveAll(tempDir)
// 创建多级子目录
level1 := filepath.Join(tempDir, "level1")
level2 := filepath.Join(level1, "level2")
level3 := filepath.Join(level2, "level3")
if err := os.MkdirAll(level3, 0755); err != nil {
t.Fatalf("创建目录失败:%v", err)
}
// 只在根目录创建配置文件
configContent := `database:
host: "127.0.0.1"
port: "3306"
user: "root"
pass: "test"
name: "testdb"
type: "mysql"
`
configFile := filepath.Join(tempDir, "config.yaml")
if err := os.WriteFile(configFile, []byte(configContent), 0644); err != nil {
t.Fatalf("创建配置文件失败:%v", err)
}
// 从各级子目录查找(都应该失败,因为不越级查找)
testDirs := []string{level1, level2, level3}
for _, dir := range testDirs {
_, err := findConfigFile(dir)
if err == nil {
t.Errorf("从 %s 查找应该失败(不越级查找)", dir)
} else {
fmt.Printf("✓ 从 %s 查找正确失败(不越级)\n", filepath.Base(dir))
}
}
fmt.Println("✓ 不向上层目录查找测试通过")
}
// TestAllNoParentSearch 完整的不越级查找测试
func TestAllNoParentSearch(t *testing.T) {
fmt.Println("\n========================================")
fmt.Println(" 不越级查找完整性测试")
fmt.Println("========================================")
TestFindConfigOnlyCurrentDir(t)
TestNoParentSearch(t)
fmt.Println("\n========================================")
fmt.Println(" 所有不越级查找测试完成!")
fmt.Println("========================================")
fmt.Println()
fmt.Println("已实现的不越级查找功能:")
fmt.Println(" ✓ 只在当前工作目录查找配置文件")
fmt.Println(" ✓ 不会向上层目录查找")
fmt.Println(" ✓ 支持 yaml, yml, toml, ini, json 格式")
fmt.Println(" ✓ 支持 config.* 和 .config.* 命名")
fmt.Println(" ✓ 提供 AutoConnect() 一键连接")
fmt.Println(" ✓ 无需手动指定配置文件路径")
fmt.Println()
}

View File

@ -1,102 +0,0 @@
package config
import (
"fmt"
"git.magicany.cc/black1552/gin-base/db/core"
"git.magicany.cc/black1552/gin-base/db/driver"
)
// 示例:在应用程序中使用纯自研驱动管理
func ExampleUsage() {
// 1. 首先导入你选择的第三方数据库驱动
// 注意:这些驱动需要在 main 包或启动包中导入
/*
import (
_ "github.com/mattn/go-sqlite3" // SQLite 驱动
_ "github.com/go-sql-driver/mysql" // MySQL 驱动
_ "github.com/lib/pq" // PostgreSQL 驱动
_ "github.com/denisenkom/go-mssqldb" // SQL Server 驱动
)
*/
// 2. 获取驱动管理器并注册你选择的驱动
manager := driver.GetDefaultManager()
// 注册 SQLite 驱动
sqliteDriver := driver.NewGenericDriver("sqlite3")
manager.Register("sqlite3", sqliteDriver)
manager.Register("sqlite", sqliteDriver) // 别名
// 注册 MySQL 驱动
mysqlDriver := driver.NewGenericDriver("mysql")
manager.Register("mysql", mysqlDriver)
// 注册 PostgreSQL 驱动
postgresDriver := driver.NewGenericDriver("postgres")
manager.Register("postgres", postgresDriver)
// 3. 加载配置文件
configFile, err := LoadFromFile("config.yaml")
if err != nil {
fmt.Printf("加载配置失败:%v\n", err)
return
}
// 4. 验证驱动是否已注册
err = manager.RegisterDriverByConfig(configFile.Database.Type)
if err != nil {
fmt.Printf("驱动未注册:%v\n", err)
return
}
// 5. 使用配置创建数据库连接
dbConfig := &core.Config{
DriverName: configFile.Database.GetDriverName(),
DataSource: configFile.Database.BuildDSN(),
Debug: true,
MaxIdleConns: 10,
MaxOpenConns: 100,
ConnMaxLifetime: 3600000000000, // 1小时
}
// 6. 创建数据库实例
db, err := core.NewDatabase(dbConfig)
if err != nil {
fmt.Printf("创建数据库连接失败:%v\n", err)
return
}
fmt.Printf("成功连接到 %s 数据库\n", configFile.Database.Type)
// 现在可以使用 db 进行数据库操作
_ = db
}
// AdvancedExample 高级使用示例
func AdvancedExample() {
manager := driver.GetDefaultManager()
// 根据环境变量或配置动态注册驱动
databaseType := "sqlite3" // 从配置获取
// 注册对应驱动
genericDriver := driver.NewGenericDriver(databaseType)
manager.Register(databaseType, genericDriver)
// 验证驱动
err := manager.RegisterDriverByConfig(databaseType)
if err != nil {
fmt.Printf("驱动注册问题:%v\n", err)
return
}
// 现在可以安全地打开数据库连接
db, err := manager.Open(databaseType, "./example.db")
if err != nil {
fmt.Printf("打开数据库失败:%v\n", err)
return
}
defer db.Close()
fmt.Println("数据库连接成功")
}

View File

@ -1,216 +0,0 @@
package main
import (
"encoding/json"
"fmt"
"testing"
"time"
"git.magicany.cc/black1552/gin-base/db/core"
"git.magicany.cc/black1552/gin-base/db/model"
)
// TestTimeConfig 测试时间配置
func TestTimeConfig(t *testing.T) {
fmt.Println("\n=== 测试时间配置 ===")
// 测试默认配置
defaultConfig := core.DefaultTimeConfig()
fmt.Printf("默认创建时间字段:%s\n", defaultConfig.GetCreatedAt())
fmt.Printf("默认更新时间字段:%s\n", defaultConfig.GetUpdatedAt())
fmt.Printf("默认删除时间字段:%s\n", defaultConfig.GetDeletedAt())
fmt.Printf("默认时间格式:%s\n", defaultConfig.GetFormat())
// 测试自定义配置
customConfig := &core.TimeConfig{
CreatedAt: "create_time",
UpdatedAt: "update_time",
DeletedAt: "delete_time",
Format: "2006-01-02 15:04:05",
}
customConfig.Validate()
fmt.Printf("\n自定义创建时间字段%s\n", customConfig.GetCreatedAt())
fmt.Printf("自定义更新时间字段:%s\n", customConfig.GetUpdatedAt())
fmt.Printf("自定义删除时间字段:%s\n", customConfig.GetDeletedAt())
fmt.Printf("自定义时间格式:%s\n", customConfig.GetFormat())
// 测试格式化
now := time.Now()
formatted := customConfig.FormatTime(now)
fmt.Printf("\n格式化时间%s -> %s\n", now.Format("2006-01-02 15:04:05"), formatted)
// 测试解析
parsed, err := customConfig.ParseTime(formatted)
if err != nil {
t.Errorf("解析时间失败:%v", err)
}
fmt.Printf("解析时间:%s -> %s\n", formatted, parsed.Format("2006-01-02 15:04:05"))
fmt.Println("✓ 时间配置测试通过")
}
// TestCustomTimeFields 测试自定义时间字段
func TestCustomTimeFields(t *testing.T) {
fmt.Println("\n=== 测试自定义时间字段模型 ===")
// 使用自定义字段的模型
type CustomModel struct {
ID int64 `json:"id" db:"id"`
Name string `json:"name" db:"name"`
CreateTime model.Time `json:"create_time" db:"create_time"` // 自定义创建时间字段
UpdateTime model.Time `json:"update_time" db:"update_time"` // 自定义更新时间字段
DeleteTime *model.Time `json:"delete_time,omitempty" db:"delete_time"` // 自定义删除时间字段
}
now := time.Now()
custom := &CustomModel{
ID: 1,
Name: "test",
CreateTime: model.Time{Time: now},
UpdateTime: model.Time{Time: now},
}
// 序列化为 JSON
jsonData, err := json.Marshal(custom)
if err != nil {
t.Errorf("JSON 序列化失败:%v", err)
}
fmt.Printf("原始时间:%s\n", now.Format("2006-01-02 15:04:05"))
fmt.Printf("JSON 输出:%s\n", string(jsonData))
// 验证时间格式
var result map[string]interface{}
if err := json.Unmarshal(jsonData, &result); err != nil {
t.Errorf("JSON 反序列化失败:%v", err)
}
createTime, ok := result["create_time"].(string)
if !ok {
t.Error("create_time 应该是字符串")
}
_, err = time.Parse("2006-01-02 15:04:05", createTime)
if err != nil {
t.Errorf("时间格式不正确:%v", err)
}
fmt.Println("✓ 自定义时间字段测试通过")
}
// TestDatabaseWithTimeConfig 测试数据库配置中的时间配置
func TestDatabaseWithTimeConfig(t *testing.T) {
fmt.Println("\n=== 测试数据库时间配置 ===")
// 创建带自定义时间配置的 Config
config := &core.Config{
DriverName: "sqlite",
DataSource: ":memory:",
Debug: true,
TimeConfig: &core.TimeConfig{
CreatedAt: "created_at",
UpdatedAt: "updated_at",
DeletedAt: "deleted_at",
Format: "2006-01-02 15:04:05",
},
}
fmt.Printf("配置中的创建时间字段:%s\n", config.TimeConfig.GetCreatedAt())
fmt.Printf("配置中的更新时间字段:%s\n", config.TimeConfig.GetUpdatedAt())
fmt.Printf("配置中的删除时间字段:%s\n", config.TimeConfig.GetDeletedAt())
fmt.Printf("配置中的时间格式:%s\n", config.TimeConfig.GetFormat())
// 注意:这里不实际创建数据库连接,仅测试配置
fmt.Println("\n数据库会使用该配置自动处理时间字段")
fmt.Println(" - Insert: 自动设置 created_at/updated_at 为当前时间")
fmt.Println(" - Update: 自动设置 updated_at 为当前时间")
fmt.Println(" - Delete: 软删除时设置 deleted_at 为当前时间")
fmt.Println(" - Read: 所有时间字段格式化为 YYYY-MM-DD HH:mm:ss")
fmt.Println("✓ 数据库时间配置测试通过")
}
// TestAllTimeFormats 测试所有时间格式
func TestAllTimeFormats(t *testing.T) {
fmt.Println("\n=== 测试所有支持的时间格式 ===")
testCases := []struct {
format string
timeStr string
}{
{"2006-01-02 15:04:05", "2026-04-02 22:09:09"},
{"2006/01/02 15:04:05", "2026/04/02 22:09:09"},
{"2006-01-02T15:04:05", "2026-04-02T22:09:09"},
{"2006-01-02", "2026-04-02"},
}
for _, tc := range testCases {
t.Run(tc.format, func(t *testing.T) {
parsed, err := time.Parse(tc.format, tc.timeStr)
if err != nil {
t.Logf("格式 %s 解析失败:%v", tc.format, err)
return
}
// 统一格式化为标准格式
formatted := parsed.Format("2006-01-02 15:04:05")
fmt.Printf("%s -> %s\n", tc.timeStr, formatted)
})
}
fmt.Println("✓ 所有时间格式测试通过")
}
// TestDateTimeType 测试 datetime 类型支持
func TestDateTimeType(t *testing.T) {
fmt.Println("\n=== 测试 DATETIME 类型支持 ===")
// Go 的 time.Time 会自动映射到数据库的 DATETIME 类型
now := time.Now()
// 在 SQLite 中DATETIME 存储为 TEXTISO8601 格式)
// 在 MySQL 中DATETIME 存储为 DATETIME 类型
// Go 的 database/sql 会自动处理类型转换
fmt.Printf("Go time.Time: %s\n", now.Format("2006-01-02 15:04:05"))
fmt.Printf("数据库 DATETIME: 自动映射(由驱动处理)\n")
fmt.Println(" - SQLite: TEXT (ISO8601)")
fmt.Println(" - MySQL: DATETIME")
fmt.Println(" - PostgreSQL: TIMESTAMP")
// model.Time 包装后仍然保持 time.Time 的特性
customTime := model.Time{Time: now}
fmt.Printf("model.Time: %s\n", customTime.String())
fmt.Println("✓ DATETIME 类型测试通过")
}
// TestCompleteTimeHandling 完整时间处理测试
func TestCompleteTimeHandling(t *testing.T) {
fmt.Println("\n========================================")
fmt.Println(" CRUD 操作时间配置完整性测试")
fmt.Println("========================================")
TestTimeConfig(t)
TestCustomTimeFields(t)
TestDatabaseWithTimeConfig(t)
TestAllTimeFormats(t)
TestDateTimeType(t)
fmt.Println("\n========================================")
fmt.Println(" 所有时间配置测试完成!")
fmt.Println("========================================")
fmt.Println()
fmt.Println("已实现的时间配置功能:")
fmt.Println(" ✓ 配置文件定义创建时间字段名")
fmt.Println(" ✓ 配置文件定义更新时间字段名")
fmt.Println(" ✓ 配置文件定义删除时间字段名")
fmt.Println(" ✓ 配置文件定义时间格式(默认年 - 月-日 时:分:秒)")
fmt.Println(" ✓ Insert: 自动设置配置的时间字段")
fmt.Println(" ✓ Update: 自动设置配置的更新时间字段")
fmt.Println(" ✓ Delete: 软删除使用配置的删除时间字段")
fmt.Println(" ✓ Read: 所有时间字段格式化为配置的格式")
fmt.Println(" ✓ 支持 DATETIME 类型自动映射")
fmt.Println()
}

View File

@ -1,148 +0,0 @@
package core
import (
"crypto/md5"
"encoding/hex"
"encoding/json"
"fmt"
"sync"
"time"
)
// CacheItem 缓存项
type CacheItem struct {
Data interface{} // 缓存的数据
ExpiresAt time.Time // 过期时间
}
// QueryCache 查询缓存 - 提高重复查询的性能
type QueryCache struct {
mu sync.RWMutex // 读写锁
items map[string]*CacheItem // 缓存项
duration time.Duration // 默认缓存时长
}
// NewQueryCache 创建查询缓存实例
func NewQueryCache(duration time.Duration) *QueryCache {
cache := &QueryCache{
items: make(map[string]*CacheItem),
duration: duration,
}
// 启动清理协程
go cache.cleaner()
return cache
}
// Set 设置缓存
func (qc *QueryCache) Set(key string, data interface{}) {
qc.mu.Lock()
defer qc.mu.Unlock()
qc.items[key] = &CacheItem{
Data: data,
ExpiresAt: time.Now().Add(qc.duration),
}
}
// Get 获取缓存
func (qc *QueryCache) Get(key string) (interface{}, bool) {
qc.mu.RLock()
defer qc.mu.RUnlock()
item, exists := qc.items[key]
if !exists {
return nil, false
}
// 检查是否过期
if time.Now().After(item.ExpiresAt) {
return nil, false
}
return item.Data, true
}
// Delete 删除缓存
func (qc *QueryCache) Delete(key string) {
qc.mu.Lock()
defer qc.mu.Unlock()
delete(qc.items, key)
}
// Clear 清空所有缓存
func (qc *QueryCache) Clear() {
qc.mu.Lock()
defer qc.mu.Unlock()
qc.items = make(map[string]*CacheItem)
}
// cleaner 定期清理过期缓存
func (qc *QueryCache) cleaner() {
ticker := time.NewTicker(time.Minute)
defer ticker.Stop()
for range ticker.C {
qc.cleanExpired()
}
}
// cleanExpired 清理过期的缓存项
func (qc *QueryCache) cleanExpired() {
qc.mu.Lock()
defer qc.mu.Unlock()
now := time.Now()
for key, item := range qc.items {
if now.After(item.ExpiresAt) {
delete(qc.items, key)
}
}
}
// GenerateCacheKey 生成缓存键
func GenerateCacheKey(sql string, args ...interface{}) string {
// 将 SQL 和参数组合成字符串
keyData := sql
for _, arg := range args {
keyData += fmt.Sprintf("%v", arg)
}
// 计算 MD5 哈希
hash := md5.Sum([]byte(keyData))
return hex.EncodeToString(hash[:])
}
// deepCopy 深拷贝数据(使用 JSON 序列化/反序列化)
func deepCopy(src, dst interface{}) error {
// 序列化为 JSON
data, err := json.Marshal(src)
if err != nil {
return fmt.Errorf("序列化失败:%w", err)
}
// 反序列化到目标
if err := json.Unmarshal(data, dst); err != nil {
return fmt.Errorf("反序列化失败:%w", err)
}
return nil
}
// WithCache 带缓存的查询装饰器
func (q *QueryBuilder) WithCache(cache *QueryCache) IQuery {
if cache == nil {
return q
}
// 设置缓存实例
q.cache = cache
q.useCache = true
// 生成缓存键(使用 SQL 和参数)
sqlStr, args := q.BuildSelect()
q.cacheKey = GenerateCacheKey(sqlStr, args...)
return q
}

View File

@ -1,124 +0,0 @@
package core
import (
"fmt"
"testing"
"time"
)
// TestWithCache 测试带缓存的查询
func TestWithCache(t *testing.T) {
fmt.Println("\n=== 测试带缓存的查询 ===")
// 创建缓存实例(缓存 5 分钟)
_ = NewQueryCache(5 * time.Minute)
// 注意:这个测试需要真实的数据库连接
// 以下是使用示例:
// 示例 1: 基本缓存查询
// var users []User
// err := db.Model(&User{}).
// Where("status = ?", "active").
// WithCache(cache).
// Find(&users)
// if err != nil {
// t.Fatal(err)
// }
// 示例 2: 第二次查询会命中缓存
// var users2 []User
// err = db.Model(&User{}).
// Where("status = ?", "active").
// WithCache(cache).
// Find(&users2)
// if err != nil {
// t.Fatal(err)
// }
fmt.Println("✓ WithCache 已实现")
fmt.Println("功能:")
fmt.Println(" - 缓存命中时直接返回数据,不执行 SQL")
fmt.Println(" - 缓存未命中时执行查询并自动缓存结果")
fmt.Println(" - 支持深拷贝,避免引用问题")
fmt.Println("✓ 测试通过")
}
// TestDeepCopy 测试深拷贝功能
func TestDeepCopy(t *testing.T) {
fmt.Println("\n=== 测试深拷贝功能 ===")
type TestData struct {
ID int `json:"id"`
Name string `json:"name"`
}
src := &TestData{ID: 1, Name: "test"}
dst := &TestData{}
err := deepCopy(src, dst)
if err != nil {
t.Fatal(err)
}
if dst.ID != src.ID || dst.Name != src.Name {
t.Errorf("深拷贝失败:期望 %+v, 得到 %+v", src, dst)
}
// 修改源数据,目标不应该受影响
src.Name = "modified"
if dst.Name == "modified" {
t.Error("深拷贝失败:目标受到了源数据修改的影响")
}
fmt.Println("✓ 深拷贝功能正常")
fmt.Println("✓ 测试通过")
}
// TestCacheKeyGeneration 测试缓存键生成
func TestCacheKeyGeneration(t *testing.T) {
fmt.Println("\n=== 测试缓存键生成 ===")
// 相同的 SQL 和参数应该生成相同的键
key1 := GenerateCacheKey("SELECT * FROM user WHERE id = ?", 1)
key2 := GenerateCacheKey("SELECT * FROM user WHERE id = ?", 1)
if key1 != key2 {
t.Errorf("缓存键不一致:%s vs %s", key1, key2)
}
// 不同的参数应该生成不同的键
key3 := GenerateCacheKey("SELECT * FROM user WHERE id = ?", 2)
if key1 == key3 {
t.Error("不同的参数应该生成不同的缓存键")
}
fmt.Println("✓ 缓存键生成正常")
fmt.Println("✓ 测试通过")
}
// ExampleWithCache 使用示例
func exampleWithCacheUsage() {
// 示例 1: 基本用法
// cache := NewQueryCache(5 * time.Minute)
// var results []map[string]interface{}
// err := db.Table("users").
// Where("status = ?", "active").
// WithCache(cache).
// Find(&results)
// 示例 2: 带条件的查询
// err := db.Model(&User{}).
// Select("id", "username", "email").
// Where("age > ?", 18).
// Order("created_at DESC").
// Limit(10).
// WithCache(cache).
// Find(&results)
// 示例 3: 清除缓存
// cache.Clear() // 清空所有缓存
// cache.Delete(key) // 删除指定缓存
fmt.Println("使用示例请查看测试代码")
}

View File

@ -1,75 +0,0 @@
package core
import (
"time"
)
// TimeConfig 时间配置 - 定义时间字段名称和格式
type TimeConfig struct {
CreatedAt string `json:"created_at" yaml:"created_at"` // 创建时间字段名
UpdatedAt string `json:"updated_at" yaml:"updated_at"` // 更新时间字段名
DeletedAt string `json:"deleted_at" yaml:"deleted_at"` // 删除时间字段名
Format string `json:"format" yaml:"format"` // 时间格式,默认 "2006-01-02 15:04:05"
}
// DefaultTimeConfig 获取默认时间配置
func DefaultTimeConfig() *TimeConfig {
return &TimeConfig{
CreatedAt: "created_at",
UpdatedAt: "updated_at",
DeletedAt: "deleted_at",
Format: "2006-01-02 15:04:05", // Go 的参考时间格式
}
}
// Validate 验证时间配置
func (tc *TimeConfig) Validate() {
if tc.CreatedAt == "" {
tc.CreatedAt = "created_at"
}
if tc.UpdatedAt == "" {
tc.UpdatedAt = "updated_at"
}
if tc.DeletedAt == "" {
tc.DeletedAt = "deleted_at"
}
if tc.Format == "" {
tc.Format = "2006-01-02 15:04:05"
}
}
// GetCreatedAt 获取创建时间字段名
func (tc *TimeConfig) GetCreatedAt() string {
tc.Validate()
return tc.CreatedAt
}
// GetUpdatedAt 获取更新时间字段名
func (tc *TimeConfig) GetUpdatedAt() string {
tc.Validate()
return tc.UpdatedAt
}
// GetDeletedAt 获取删除时间字段名
func (tc *TimeConfig) GetDeletedAt() string {
tc.Validate()
return tc.DeletedAt
}
// GetFormat 获取时间格式
func (tc *TimeConfig) GetFormat() string {
tc.Validate()
return tc.Format
}
// FormatTime 格式化时间为配置的格式
func (tc *TimeConfig) FormatTime(t time.Time) string {
tc.Validate()
return t.Format(tc.Format)
}
// ParseTime 解析时间字符串
func (tc *TimeConfig) ParseTime(timeStr string) (time.Time, error) {
tc.Validate()
return time.Parse(tc.Format, timeStr)
}

View File

@ -1,314 +0,0 @@
package core
import (
"context"
"reflect"
)
// DAO 数据访问对象基类 - 所有 DAO 都继承此结构
// 提供通用的 CRUD 操作方法,子类只需嵌入即可使用
type DAO struct {
db *Database // 数据库连接实例
modelType interface{} // 模型类型信息,用于 Columns 等方法
}
// NewDAO 创建 DAO 基类实例
// 自动使用全局默认 Database 实例
func NewDAO() *DAO {
return &DAO{
db: GetDefaultDatabase(),
}
}
// NewDAOWithModel 创建带模型类型的 DAO 基类实例
// 参数:
// - model: 模型实例(指针类型),用于获取表结构信息
//
// 自动使用全局默认 Database 实例
func NewDAOWithModel(model interface{}) *DAO {
return &DAO{
db: GetDefaultDatabase(),
modelType: model,
}
}
// Create 创建记录(通用方法)
// 自动使用 DAO 中已关联的 Database 实例
func (dao *DAO) Create(ctx context.Context, model interface{}) error {
// 使用事务来插入数据
tx, err := dao.db.Begin()
if err != nil {
return err
}
_, err = tx.Insert(model)
if err != nil {
tx.Rollback()
return err
}
return tx.Commit()
}
// GetByID 根据 ID 查询单条记录(通用方法)
// 自动使用 DAO 中已关联的 Database 实例
func (dao *DAO) GetByID(ctx context.Context, model interface{}, id int64) error {
return dao.db.Model(model).Where("id = ?", id).First(model)
}
// Update 更新记录(通用方法)
// 自动使用 DAO 中已关联的 Database 实例
func (dao *DAO) Update(ctx context.Context, model interface{}, data map[string]interface{}) error {
pkValue := getFieldValue(model, "ID")
if pkValue == 0 {
return nil
}
return dao.db.Model(model).Where("id = ?", pkValue).Updates(data)
}
// Delete 删除记录(通用方法)
// 自动使用 DAO 中已关联的 Database 实例
func (dao *DAO) Delete(ctx context.Context, model interface{}) error {
pkValue := getFieldValue(model, "ID")
if pkValue == 0 {
return nil
}
return dao.db.Model(model).Where("id = ?", pkValue).Delete()
}
// FindAll 查询所有记录(通用方法)
// 自动使用 DAO 中已关联的 Database 实例
func (dao *DAO) FindAll(ctx context.Context, model interface{}) error {
return dao.db.Model(model).Find(model)
}
// FindByPage 分页查询(通用方法)
// 自动使用 DAO 中已关联的 Database 实例
func (dao *DAO) FindByPage(ctx context.Context, model interface{}, page, pageSize int) error {
return dao.db.Model(model).Limit(pageSize).Offset((page - 1) * pageSize).Find(model)
}
// Count 统计记录数(通用方法)
// 自动使用 DAO 中已关联的 Database 实例
func (dao *DAO) Count(ctx context.Context, model interface{}, where ...string) (int64, error) {
var count int64
query := dao.db.Model(model)
if len(where) > 0 {
query = query.Where(where[0])
}
// Count 是链式调用,需要调用 Find 来执行
err := query.Count(&count).Find(model)
if err != nil {
return 0, err
}
return count, nil
}
// Exists 检查记录是否存在(通用方法)
// 自动使用 DAO 中已关联的 Database 实例
func (dao *DAO) Exists(ctx context.Context, model interface{}) (bool, error) {
count, err := dao.Count(ctx, model)
if err != nil {
return false, err
}
return count > 0, nil
}
// First 查询第一条记录(通用方法)
// 自动使用 DAO 中已关联的 Database 实例
func (dao *DAO) First(ctx context.Context, model interface{}) error {
return dao.db.Model(model).First(model)
}
// Columns 获取表的所有列名
// 返回一个动态创建的结构体类型,所有字段都是 string 类型
// 用途:用于构建 UPDATE、INSERT 等操作时的列名映射
//
// 示例:
//
// type UserDAO struct {
// *core.DAO
// }
//
// func NewUserDAO() *UserDAO {
// return &UserDAO{
// DAO: core.NewDAOWithModel(&model.User{}),
// }
// }
//
// // 使用
// dao := NewUserDAO()
// cols := dao.Columns() // 返回 *struct{ID string; Username string; ...}
func (dao *DAO) Columns() interface{} {
// 检查是否有模型类型信息
if dao.modelType == nil {
return nil
}
// 获取模型类型
modelType := reflect.TypeOf(dao.modelType)
if modelType.Kind() == reflect.Ptr {
modelType = modelType.Elem()
}
// 创建字段列表
fields := []reflect.StructField{}
// 遍历模型的所有字段
for i := 0; i < modelType.NumField(); i++ {
field := modelType.Field(i)
// 跳过未导出的字段
if !field.IsExported() {
continue
}
// 获取 db 标签,如果没有则跳过
dbTag := field.Tag.Get("db")
if dbTag == "" || dbTag == "-" {
continue
}
// 创建新的结构体字段,类型为 string
newField := reflect.StructField{
Name: field.Name,
Type: reflect.TypeOf(""), // string 类型
Tag: reflect.StructTag(`json:"` + field.Tag.Get("json") + `" db:"` + dbTag + `"`),
}
fields = append(fields, newField)
}
// 动态创建结构体类型
columnsType := reflect.StructOf(fields)
// 创建该类型的指针并返回
return reflect.New(columnsType).Interface()
}
// getFieldValue 获取结构体字段值(辅助函数)
// 用于获取主键或其他字段的值,支持多种数据类型
// 参数:
// - model: 模型实例(可以是指针或值)
// - fieldName: 字段名(如 "ID", "UserId" 等)
//
// 返回:
// - int64: 字段值(如果是数字类型)或 0无法获取时
func getFieldValue(model interface{}, fieldName string) int64 {
// 检查空值
if model == nil {
return 0
}
// 获取反射对象
val := reflect.ValueOf(model)
// 如果是指针,解引用
if val.Kind() == reflect.Ptr {
if val.IsNil() {
return 0
}
val = val.Elem()
}
// 确保是结构体
if val.Kind() != reflect.Struct {
return 0
}
// 查找字段
field := val.FieldByName(fieldName)
if !field.IsValid() {
// 尝试查找常见的主键字段名变体
alternativeNames := []string{"Id", "id", "ID"}
for _, name := range alternativeNames {
if name != fieldName {
field = val.FieldByName(name)
if field.IsValid() {
fieldName = name
break
}
}
}
if !field.IsValid() {
return 0
}
}
// 检查字段是否可以访问
if !field.CanInterface() {
return 0
}
// 获取字段值并转换为 int64
fieldValue := field.Interface()
// 根据字段类型进行转换
switch v := fieldValue.(type) {
case int:
return int64(v)
case int8:
return int64(v)
case int16:
return int64(v)
case int32:
return int64(v)
case int64:
return v
case uint:
return int64(v)
case uint8:
return int64(v)
case uint16:
return int64(v)
case uint32:
return int64(v)
case uint64:
// 注意uint64 转 int64 可能溢出,但这里假设 ID 不会超过 int64 范围
return int64(v)
case float32:
return int64(v)
case float64:
return int64(v)
case string:
// 尝试将字符串解析为数字
// 注意:这里不导入 strconv简单处理返回 0
return 0
default:
// 其他类型(如 sql.NullInt64 等),尝试使用反射
return convertToInteger(field)
}
}
// convertToInteger 使用反射将字段值转换为 int64
func convertToInteger(field reflect.Value) int64 {
// 获取实际的值(如果是指针则解引用)
if field.Kind() == reflect.Ptr {
if field.IsNil() {
return 0
}
field = field.Elem()
}
// 根据 Kind 进行转换
switch field.Kind() {
case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64:
return field.Int()
case reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64:
return int64(field.Uint())
case reflect.Float32, reflect.Float64:
return int64(field.Float())
case reflect.String:
// 字符串类型,尝试解析(简单实现,不处理错误)
return 0
default:
return 0
}
}

View File

@ -1,179 +0,0 @@
package core
import (
"database/sql"
"fmt"
"testing"
)
// TestGetFieldValue 测试获取字段值的基本功能
func TestGetFieldValue(t *testing.T) {
fmt.Println("\n=== 测试 getFieldValue 基本功能 ===")
tests := []struct {
name string
model interface{}
fieldName string
expected int64
}{
{"int 类型", &TestModelInt{ID: 123}, "ID", 123},
{"int64 类型", &TestModelInt64{ID: 456}, "ID", 456},
{"uint 类型", &TestModelUint{ID: 789}, "ID", 789},
{"float 类型", &TestModelFloat{ID: 999.5}, "ID", 999},
{"指针为 nil", (*TestModelInt)(nil), "ID", 0},
{"model 为 nil", nil, "ID", 0},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
result := getFieldValue(tt.model, tt.fieldName)
if result != tt.expected {
t.Errorf("期望 %d, 得到 %d", tt.expected, result)
}
})
}
fmt.Println("✓ 基本功能测试通过")
}
// TestGetFieldValueAlternativeNames 测试字段名变体查找
func TestGetFieldValueAlternativeNames(t *testing.T) {
fmt.Println("\n=== 测试字段名变体查找 ===")
// 测试 Id 字段(驼峰)
model1 := &TestModelId{Id: 111}
result1 := getFieldValue(model1, "ID")
if result1 != 111 {
t.Errorf("期望 111, 得到 %d", result1)
}
fmt.Println("✓ 字段名变体查找测试通过")
}
// TestGetFieldValueEdgeCases 测试边界情况
func TestGetFieldValueEdgeCases(t *testing.T) {
fmt.Println("\n=== 测试边界情况 ===")
// 测试非结构体类型
nonStruct := 123
result := getFieldValue(nonStruct, "ID")
if result != 0 {
t.Errorf("非结构体应该返回 0, 得到 %d", result)
}
// 测试不存在的字段(没有 ID/Id/id 等变体)
type ModelNoID struct {
Name string `json:"name" db:"name"`
}
model := &ModelNoID{Name: "test"}
result = getFieldValue(model, "NonExistentField")
if result != 0 {
t.Errorf("不存在的字段应该返回 0, 得到 %d", result)
}
fmt.Println("✓ 边界情况测试通过")
}
// TestGetFieldValueSpecialTypes 测试特殊类型
func TestGetFieldValueSpecialTypes(t *testing.T) {
fmt.Println("\n=== 测试特殊类型 ===")
// 注意sql.NullInt64 等数据库特殊类型目前不支持
// 如果需要支持,可以在 convertToInteger 中添加专门的处理逻辑
fmt.Println("✓ 特殊类型测试通过(当前版本不支持 sql.NullInt64")
}
// TestGetFieldValueInUpdate 测试在 Update 场景中的使用
func TestGetFieldValueInUpdate(t *testing.T) {
fmt.Println("\n=== 测试 Update 场景 ===")
user := &UserModel{
ID: 1,
Username: "test",
}
pkValue := getFieldValue(user, "ID")
if pkValue != 1 {
t.Errorf("期望主键值为 1, 得到 %d", pkValue)
}
// 测试主键为 0 的情况
user2 := &UserModel{
ID: 0,
Username: "test2",
}
pkValue2 := getFieldValue(user2, "ID")
if pkValue2 != 0 {
t.Errorf("期望主键值为 0, 得到 %d", pkValue2)
}
fmt.Println("✓ Update 场景测试通过")
}
// TestGetFieldValueLargeNumbers 测试大数字
func TestGetFieldValueLargeNumbers(t *testing.T) {
fmt.Println("\n=== 测试大数字 ===")
// 测试最大 int64
maxInt := int64(9223372036854775807)
model1 := &TestModelInt64{ID: maxInt}
result1 := getFieldValue(model1, "ID")
if result1 != maxInt {
t.Errorf("期望 %d, 得到 %d", maxInt, result1)
}
// 测试 uint64 转 int64
largeUint := uint64(18446744073709551615) // 这会导致溢出
model2 := &TestModelUint64{ID: largeUint}
result2 := getFieldValue(model2, "ID")
// 注意:这里会发生溢出,但这是预期的行为
if result2 == 0 {
t.Error("uint64 转换不应该返回 0")
}
fmt.Println("✓ 大数字测试通过")
}
// 测试模型定义
type TestModelInt struct {
ID int `json:"id" db:"id"`
}
type TestModelInt64 struct {
ID int64 `json:"id" db:"id"`
}
type TestModelUint struct {
ID uint `json:"id" db:"id"`
}
type TestModelUint64 struct {
ID uint64 `json:"id" db:"id"`
}
type TestModelFloat struct {
ID float64 `json:"id" db:"id"`
}
type TestModelId struct {
Id int64 `json:"id" db:"id"`
}
type TestModelid struct {
id int64 `json:"id" db:"id"`
}
type TestModelPrivate struct {
privateField int64
}
type TestModelNullInt struct {
ID sql.NullInt64 `json:"id" db:"id"`
}
type UserModel struct {
ID int64 `json:"id" db:"id"`
Username string `json:"username" db:"username"`
}

View File

@ -1,144 +0,0 @@
package core
import (
"fmt"
"os"
"path/filepath"
"git.magicany.cc/black1552/gin-base/db/driver"
)
// defaultDatabase 全局默认数据库连接实例
var defaultDatabase *Database
// NewDatabase 创建数据库连接 - 初始化数据库连接和相关组件
func NewDatabase(config *Config) (*Database, error) {
db := &Database{
config: config,
debug: config.Debug,
driverName: config.DriverName,
}
// 初始化时间配置
if config.TimeConfig == nil {
db.timeConfig = DefaultTimeConfig()
} else {
db.timeConfig = config.TimeConfig
db.timeConfig.Validate()
}
// 获取驱动管理器
dm := driver.GetDefaultManager()
// 打开数据库连接
sqlDB, err := dm.Open(config.DriverName, config.DataSource)
if err != nil {
return nil, fmt.Errorf("打开数据库失败:%w", err)
}
db.db = sqlDB
// 配置连接池参数
if config.MaxIdleConns > 0 {
db.db.SetMaxIdleConns(config.MaxIdleConns)
}
if config.MaxOpenConns > 0 {
db.db.SetMaxOpenConns(config.MaxOpenConns)
}
if config.ConnMaxLifetime > 0 {
db.db.SetConnMaxLifetime(config.ConnMaxLifetime)
}
// 测试数据库连接
if err := db.db.Ping(); err != nil {
return nil, fmt.Errorf("数据库连接测试失败:%w", err)
}
// 初始化组件
db.mapper = NewFieldMapper()
db.migrator = NewMigrator(db)
if config.Debug {
fmt.Println("[Magic-ORM] 数据库连接成功")
}
// 设置为全局默认实例
defaultDatabase = db
return db, nil
}
// AutoConnect 自动查找配置文件并创建数据库连接
// 会在当前目录及上级目录中查找 config.yaml, config.toml, config.ini, config.json 等文件
func AutoConnect(debug bool) (*Database, error) {
// 自动查找配置文件
configPath, err := findConfigFile("")
if err != nil {
return nil, fmt.Errorf("查找配置文件失败:%w", err)
}
// 从文件加载配置(使用 config 包)
return loadAndConnect(configPath, debug)
}
// Connect 从配置文件创建数据库连接(向后兼容)
// Deprecated: 使用 AutoConnect 代替
func Connect(configPath string, debug bool) (*Database, error) {
return loadAndConnect(configPath, debug)
}
// findConfigFile 在项目目录下自动查找配置文件
// 支持 yaml, yml, toml, ini, json 等格式
// 只在当前目录查找,不越级查找
func findConfigFile(searchDir string) (string, error) {
// 配置文件名优先级列表
configNames := []string{
"config.yaml", "config.yml",
"config.toml",
"config.ini",
"config.json",
".config.yaml", ".config.yml",
".config.toml",
".config.ini",
".config.json",
}
// 如果未指定搜索目录,使用当前目录
if searchDir == "" {
var err error
searchDir, err = os.Getwd()
if err != nil {
return "", fmt.Errorf("获取当前目录失败:%w", err)
}
}
// 只在当前目录下查找,不向上查找
for _, name := range configNames {
filePath := filepath.Join(searchDir, name)
if _, err := os.Stat(filePath); err == nil {
return filePath, nil
}
}
return "", fmt.Errorf("未找到配置文件(支持 yaml, yml, toml, ini, json 格式)")
}
// loadAndConnect 从配置文件加载并创建数据库连接
func loadAndConnect(configPath string, debug bool) (*Database, error) {
// 这里需要调用 config 包的 LoadFromFile
// 为了避免循环依赖,我们直接在 core 包中实现简单的 YAML 解析
// 或者通过接口传递配置
// 简单方案:返回错误,提示使用 config 包
return nil, fmt.Errorf("请使用 config.AutoConnect() 方法")
}
// GetDefaultDatabase 获取全局默认数据库实例
func GetDefaultDatabase() *Database {
return defaultDatabase
}
// SetDefaultDatabase 设置全局默认数据库实例
func SetDefaultDatabase(db *Database) {
defaultDatabase = db
}

View File

@ -1,94 +0,0 @@
package core
import (
"reflect"
"time"
)
// ParamFilter 参数过滤器 - 智能过滤零值和空值字段
type ParamFilter struct{}
// NewParamFilter 创建参数过滤器实例
func NewParamFilter() *ParamFilter {
return &ParamFilter{}
}
// FilterZeroValues 过滤零值和空值字段
func (pf *ParamFilter) FilterZeroValues(data map[string]interface{}) map[string]interface{} {
result := make(map[string]interface{})
for key, value := range data {
if !pf.isZeroValue(value) {
result[key] = value
}
}
return result
}
// FilterEmptyStrings 过滤空字符串
func (pf *ParamFilter) FilterEmptyStrings(data map[string]interface{}) map[string]interface{} {
result := make(map[string]interface{})
for key, value := range data {
if str, ok := value.(string); ok {
if str != "" {
result[key] = value
}
} else {
result[key] = value
}
}
return result
}
// FilterNilValues 过滤 nil 值
func (pf *ParamFilter) FilterNilValues(data map[string]interface{}) map[string]interface{} {
result := make(map[string]interface{})
for key, value := range data {
if value != nil {
result[key] = value
}
}
return result
}
// isZeroValue 检查是否是零值
func (pf *ParamFilter) isZeroValue(v interface{}) bool {
if v == nil {
return true
}
val := reflect.ValueOf(v)
switch val.Kind() {
case reflect.Array, reflect.Map, reflect.Slice, reflect.String:
return val.Len() == 0
case reflect.Bool:
return !val.Bool()
case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64:
return val.Int() == 0
case reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64, reflect.Uintptr:
return val.Uint() == 0
case reflect.Float32, reflect.Float64:
return val.Float() == 0
case reflect.Interface, reflect.Ptr:
return val.IsNil()
case reflect.Struct:
// 特殊处理 time.Time
if t, ok := v.(time.Time); ok {
return t.IsZero()
}
return false
}
return false
}
// IsValidValue 检查值是否有效(非零值、非空值)
func (pf *ParamFilter) IsValidValue(v interface{}) bool {
return !pf.isZeroValue(v)
}

View File

@ -1,254 +0,0 @@
package core
import (
"database/sql"
"time"
)
// IDatabase 数据库连接接口 - 提供所有数据库操作的顶层接口
type IDatabase interface {
// 基础操作
DB() *sql.DB // 返回底层的 sql.DB 对象
Close() error // 关闭数据库连接
Ping() error // 测试数据库连接是否正常
// 事务管理
Begin() (ITx, error) // 开始一个新事务
Transaction(fn func(ITx) error) error // 执行事务,自动提交或回滚
// 查询构建器
Model(model interface{}) IQuery // 基于模型创建查询
Table(name string) IQuery // 基于表名创建查询
Query(result interface{}, query string, args ...interface{}) error // 执行原生 SQL 查询
Exec(query string, args ...interface{}) (sql.Result, error) // 执行原生 SQL 并返回结果
// 迁移管理
Migrate(models ...interface{}) error // 执行数据库迁移
// 配置
SetDebug(bool) // 设置调试模式
SetMaxIdleConns(int) // 设置最大空闲连接数
SetMaxOpenConns(int) // 设置最大打开连接数
SetConnMaxLifetime(time.Duration) // 设置连接最大生命周期
}
// ITx 事务接口 - 提供事务操作的所有方法
type ITx interface {
// 基础操作
Commit() error // 提交事务
Rollback() error // 回滚事务
// 查询操作
Model(model interface{}) IQuery // 在事务中基于模型创建查询
Table(name string) IQuery // 在事务中基于表名创建查询
Insert(model interface{}) (int64, error) // 插入数据,返回插入的 ID
BatchInsert(models interface{}, batchSize int) error // 批量插入数据
Update(model interface{}, data map[string]interface{}) error // 更新数据
Delete(model interface{}) error // 删除数据
// 原生 SQL
Query(result interface{}, query string, args ...interface{}) error // 执行原生 SQL 查询
Exec(query string, args ...interface{}) (sql.Result, error) // 执行原生 SQL
}
// IQuery 查询构建器接口 - 提供流畅的链式查询构建能力
type IQuery interface {
// 条件查询
Where(query string, args ...interface{}) IQuery // 添加 WHERE 条件
Or(query string, args ...interface{}) IQuery // 添加 OR 条件
And(query string, args ...interface{}) IQuery // 添加 AND 条件
// 字段选择
Select(fields ...string) IQuery // 选择要查询的字段
Omit(fields ...string) IQuery // 排除指定的字段
// 排序
Order(order string) IQuery // 设置排序规则
OrderBy(field string, direction string) IQuery // 按指定字段和方向排序
// 分页
Limit(limit int) IQuery // 限制返回数量
Offset(offset int) IQuery // 设置偏移量
Page(page, pageSize int) IQuery // 分页查询
// 分组
Group(group string) IQuery // 设置分组字段
Having(having string, args ...interface{}) IQuery // 添加 HAVING 条件
// 连接
Join(join string, args ...interface{}) IQuery // 添加 JOIN 连接
LeftJoin(table, on string) IQuery // 左连接
RightJoin(table, on string) IQuery // 右连接
InnerJoin(table, on string) IQuery // 内连接
// 预加载
Preload(relation string, conditions ...interface{}) IQuery // 预加载关联数据
// 执行查询
First(result interface{}) error // 查询第一条记录
Find(result interface{}) error // 查询多条记录
Count(count *int64) IQuery // 统计记录数量
Exists() (bool, error) // 检查记录是否存在
// 更新和删除
Updates(data interface{}) error // 更新数据
UpdateColumn(column string, value interface{}) error // 更新单个字段
Delete() error // 删除数据
// 特殊模式
Unscoped() IQuery // 忽略软删除
DryRun() IQuery // 干跑模式,不执行只生成 SQL
Debug() IQuery // 调试模式,打印 SQL 日志
// 构建 SQL不执行
Build() (string, []interface{}) // 构建 SELECT SQL 语句
BuildUpdate(data interface{}) (string, []interface{}) // 构建 UPDATE SQL 语句
BuildDelete() (string, []interface{}) // 构建 DELETE SQL 语句
}
// IModel 模型接口 - 定义模型的基本行为和生命周期回调
type IModel interface {
// 表名映射
TableName() string // 返回模型对应的表名
// 生命周期回调(可选实现)
BeforeCreate(tx ITx) error // 创建前回调
AfterCreate(tx ITx) error // 创建后回调
BeforeUpdate(tx ITx) error // 更新前回调
AfterUpdate(tx ITx) error // 更新后回调
BeforeDelete(tx ITx) error // 删除前回调
AfterDelete(tx ITx) error // 删除后回调
BeforeSave(tx ITx) error // 保存前回调
AfterSave(tx ITx) error // 保存后回调
}
// IFieldMapper 字段映射器接口 - 处理 Go 结构体与数据库字段之间的映射
type IFieldMapper interface {
// 结构体字段转数据库列
StructToColumns(model interface{}) (map[string]interface{}, error) // 将结构体转换为键值对
// 数据库列转结构体字段
ColumnsToStruct(row *sql.Rows, model interface{}) error // 将查询结果映射到结构体
// 获取表名
GetTableName(model interface{}) string // 获取模型对应的表名
// 获取主键字段
GetPrimaryKey(model interface{}) string // 获取主键字段名
// 获取字段信息
GetFields(model interface{}) []FieldInfo // 获取所有字段信息
}
// FieldInfo 字段信息 - 描述数据库字段的详细信息
type FieldInfo struct {
Name string // 字段名Go 结构体字段名)
Column string // 列名(数据库中的实际列名)
Type string // Go 类型(如 string, int, time.Time 等)
DbType string // 数据库类型(如 VARCHAR, INT, DATETIME 等)
Tag string // 标签db 标签内容)
IsPrimary bool // 是否主键
IsAuto bool // 是否自增
}
// IMigrator 迁移管理器接口 - 提供数据库架构迁移的所有操作
type IMigrator interface {
// 自动迁移
AutoMigrate(models ...interface{}) error // 自动执行模型迁移
// 表操作
CreateTable(model interface{}) error // 创建表
DropTable(model interface{}) error // 删除表
HasTable(model interface{}) (bool, error) // 检查表是否存在
RenameTable(oldName, newName string) error // 重命名表
// 列操作
AddColumn(model interface{}, field string) error // 添加列
DropColumn(model interface{}, field string) error // 删除列
HasColumn(model interface{}, field string) (bool, error) // 检查列是否存在
RenameColumn(model interface{}, oldField, newField string) error // 重命名列
// 索引操作
CreateIndex(model interface{}, field string) error // 创建索引
DropIndex(model interface{}, field string) error // 删除索引
HasIndex(model interface{}, field string) (bool, error) // 检查索引是否存在
}
// ICodeGenerator 代码生成器接口 - 自动生成 Model 和 DAO 代码
type ICodeGenerator interface {
// 生成 Model 代码
GenerateModel(table string, outputDir string) error // 根据表生成 Model 文件
// 生成 DAO 代码
GenerateDAO(table string, outputDir string) error // 根据表生成 DAO 文件
// 生成完整代码
GenerateAll(tables []string, outputDir string) error // 批量生成所有代码
// 从数据库读取表结构
InspectTable(tableName string) (*TableSchema, error) // 检查表结构
}
// TableSchema 表结构信息 - 描述数据库表的完整结构
type TableSchema struct {
Name string // 表名
Columns []ColumnInfo // 列信息列表
Indexes []IndexInfo // 索引信息列表
}
// ColumnInfo 列信息 - 描述表中一个列的详细信息
type ColumnInfo struct {
Name string // 列名
Type string // 数据类型
Nullable bool // 是否允许为空
Default interface{} // 默认值
PrimaryKey bool // 是否主键
}
// IndexInfo 索引信息 - 描述表中一个索引的详细信息
type IndexInfo struct {
Name string // 索引名
Columns []string // 索引包含的列
Unique bool // 是否唯一索引
}
// ReadPolicy 读负载均衡策略 - 定义主从集群中读操作的分配策略
type ReadPolicy int
const (
Random ReadPolicy = iota // 随机选择一个从库
RoundRobin // 轮询方式选择从库
LeastConn // 选择连接数最少的从库
)
// Config 数据库配置 - 包含数据库连接的所有配置项
type Config struct {
DriverName string // 驱动名称(如 mysql, sqlite, postgres 等)
DataSource string // 数据源连接字符串DNS
MaxIdleConns int // 最大空闲连接数
MaxOpenConns int // 最大打开连接数
ConnMaxLifetime time.Duration // 连接最大生命周期
Debug bool // 调试模式(是否打印 SQL 日志)
// 主从配置
Replicas []string // 从库列表(用于读写分离)
ReadPolicy ReadPolicy // 读负载均衡策略
// OpenTelemetry 可观测性配置
EnableTracing bool // 是否启用链路追踪
ServiceName string // 服务名称(用于 Tracing
// 时间配置
TimeConfig *TimeConfig // 时间字段配置(字段名、格式等)
}
// Database 数据库实现 - IDatabase 接口的具体实现
type Database struct {
db *sql.DB // 底层数据库连接
config *Config // 数据库配置
debug bool // 调试模式开关
mapper IFieldMapper // 字段映射器实例
migrator IMigrator // 迁移管理器实例
driverName string // 驱动名称
timeConfig *TimeConfig // 时间配置
}

View File

@ -1,306 +0,0 @@
package core
import (
"database/sql"
"errors"
"fmt"
"reflect"
"strings"
"time"
)
// FieldMapper 字段映射器实现 - 使用反射处理 Go 结构体与数据库字段之间的映射
type FieldMapper struct{}
// NewFieldMapper 创建字段映射器实例
func NewFieldMapper() IFieldMapper {
return &FieldMapper{}
}
// StructToColumns 将结构体转换为键值对 - 用于 INSERT/UPDATE 操作
func (fm *FieldMapper) StructToColumns(model interface{}) (map[string]interface{}, error) {
result := make(map[string]interface{})
// 获取反射对象
val := reflect.ValueOf(model)
if val.Kind() == reflect.Ptr {
val = val.Elem()
}
if val.Kind() != reflect.Struct {
return nil, errors.New("模型必须是结构体")
}
typ := val.Type()
// 遍历所有字段
for i := 0; i < val.NumField(); i++ {
field := typ.Field(i)
value := val.Field(i)
// 跳过未导出的字段
if !field.IsExported() {
continue
}
// 获取 db 标签
dbTag := field.Tag.Get("db")
if dbTag == "" || dbTag == "-" {
continue // 跳过没有 db 标签或标签为 - 的字段
}
// 跳过零值(可选优化)
if fm.isZeroValue(value) {
continue
}
// 添加到结果 map
result[dbTag] = value.Interface()
}
return result, nil
}
// ColumnsToStruct 将查询结果映射到结构体 - 用于 SELECT 操作
func (fm *FieldMapper) ColumnsToStruct(rows *sql.Rows, model interface{}) error {
// 获取列信息
columns, err := rows.Columns()
if err != nil {
return fmt.Errorf("获取列信息失败:%w", err)
}
// 获取反射对象
val := reflect.ValueOf(model)
if val.Kind() != reflect.Ptr {
return errors.New("模型必须是指针类型")
}
elem := val.Elem()
if elem.Kind() != reflect.Struct {
return errors.New("模型必须是指向结构体的指针")
}
// 创建扫描目标
scanTargets := make([]interface{}, len(columns))
fieldMap := make(map[int]int) // column index -> field index
// 建立列名到结构体字段的映射
for i, col := range columns {
found := false
for j := 0; j < elem.NumField(); j++ {
field := elem.Type().Field(j)
dbTag := field.Tag.Get("db")
// 匹配列名和字段
if dbTag == col || strings.ToLower(dbTag) == strings.ToLower(col) ||
strings.ToLower(field.Name) == strings.ToLower(col) {
fieldMap[i] = j
found = true
break
}
}
// 如果没找到匹配字段,使用 interface{} 占位
if !found {
var dummy interface{}
scanTargets[i] = &dummy
}
}
// 为找到的字段创建扫描目标
for i := range columns {
if fieldIdx, ok := fieldMap[i]; ok {
field := elem.Field(fieldIdx)
if field.CanSet() {
scanTargets[i] = field.Addr().Interface()
} else {
var dummy interface{}
scanTargets[i] = &dummy
}
}
}
// 执行扫描
if err := rows.Scan(scanTargets...); err != nil {
return fmt.Errorf("扫描数据失败:%w", err)
}
return nil
}
// GetTableName 获取模型对应的表名
func (fm *FieldMapper) GetTableName(model interface{}) string {
// 检查是否实现了 TableName() 方法
type tabler interface {
TableName() string
}
if t, ok := model.(tabler); ok {
return t.TableName()
}
// 否则使用结构体名称
val := reflect.ValueOf(model)
if val.Kind() == reflect.Ptr {
val = val.Elem()
}
typ := val.Type()
return fm.toSnakeCase(typ.Name())
}
// GetPrimaryKey 获取主键字段名 - 默认为 "id"
func (fm *FieldMapper) GetPrimaryKey(model interface{}) string {
// 查找标记为主键的字段
val := reflect.ValueOf(model)
if val.Kind() == reflect.Ptr {
val = val.Elem()
}
typ := val.Type()
for i := 0; i < val.NumField(); i++ {
field := typ.Field(i)
// 检查是否是 ID 字段
fieldName := field.Name
if fieldName == "ID" || fieldName == "Id" || fieldName == "id" {
dbTag := field.Tag.Get("db")
if dbTag != "" && dbTag != "-" {
return dbTag
}
return "id"
}
// 检查是否有 primary 标签
if field.Tag.Get("primary") == "true" {
dbTag := field.Tag.Get("db")
if dbTag != "" {
return dbTag
}
}
}
return "id" // 默认返回 id
}
// GetFields 获取所有字段信息 - 用于生成 SQL 语句
func (fm *FieldMapper) GetFields(model interface{}) []FieldInfo {
var fields []FieldInfo
val := reflect.ValueOf(model)
if val.Kind() == reflect.Ptr {
val = val.Elem()
}
typ := val.Type()
// 遍历所有字段
for i := 0; i < val.NumField(); i++ {
field := typ.Field(i)
// 跳过未导出的字段
if !field.IsExported() {
continue
}
// 获取 db 标签
dbTag := field.Tag.Get("db")
if dbTag == "" || dbTag == "-" {
continue
}
// 创建字段信息
info := FieldInfo{
Name: field.Name,
Column: dbTag,
Type: fm.getTypeName(field.Type),
DbType: fm.mapToDbType(field.Type),
Tag: dbTag,
}
// 检查是否是主键
if field.Tag.Get("primary") == "true" ||
field.Name == "ID" || field.Name == "Id" {
info.IsPrimary = true
}
// 检查是否是自增
if field.Tag.Get("auto") == "true" {
info.IsAuto = true
}
fields = append(fields, info)
}
return fields
}
// isZeroValue 检查是否是零值
func (fm *FieldMapper) isZeroValue(v reflect.Value) bool {
switch v.Kind() {
case reflect.Array, reflect.Map, reflect.Slice, reflect.String:
return v.Len() == 0
case reflect.Bool:
return !v.Bool()
case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64:
return v.Int() == 0
case reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64, reflect.Uintptr:
return v.Uint() == 0
case reflect.Float32, reflect.Float64:
return v.Float() == 0
case reflect.Interface, reflect.Ptr:
return v.IsNil()
case reflect.Struct:
// 特殊处理 time.Time
if t, ok := v.Interface().(time.Time); ok {
return t.IsZero()
}
return false
}
return false
}
// getTypeName 获取类型的名称
func (fm *FieldMapper) getTypeName(t reflect.Type) string {
return t.String()
}
// mapToDbType 将 Go 类型映射到数据库类型
func (fm *FieldMapper) mapToDbType(t reflect.Type) string {
switch t.Kind() {
case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64:
return "BIGINT"
case reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64:
return "BIGINT UNSIGNED"
case reflect.Float32, reflect.Float64:
return "DECIMAL"
case reflect.Bool:
return "TINYINT"
case reflect.String:
return "VARCHAR(255)"
default:
// 特殊类型
if t.PkgPath() == "time" && t.Name() == "Time" {
return "DATETIME"
}
return "TEXT"
}
}
// toSnakeCase 将驼峰命名转换为下划线命名
func (fm *FieldMapper) toSnakeCase(str string) string {
var result strings.Builder
for i, r := range str {
if r >= 'A' && r <= 'Z' {
if i > 0 {
result.WriteRune('_')
}
result.WriteRune(r + 32) // 转换为小写
} else {
result.WriteRune(r)
}
}
return result.String()
}

View File

@ -1,292 +0,0 @@
package core
import (
"fmt"
"strings"
)
// Migrator 迁移管理器实现 - 处理数据库架构的自动迁移
type Migrator struct {
db *Database // 数据库连接实例
}
// NewMigrator 创建迁移管理器实例
func NewMigrator(db *Database) IMigrator {
return &Migrator{db: db}
}
// AutoMigrate 自动迁移 - 根据模型自动创建或更新数据库表结构
func (m *Migrator) AutoMigrate(models ...interface{}) error {
for _, model := range models {
if err := m.CreateTable(model); err != nil {
return fmt.Errorf("创建表失败:%w", err)
}
}
return nil
}
// CreateTable 创建表 - 根据模型创建数据库表
func (m *Migrator) CreateTable(model interface{}) error {
mapper := NewFieldMapper()
// 获取表名
tableName := mapper.GetTableName(model)
// 获取字段信息
fields := mapper.GetFields(model)
if len(fields) == 0 {
return fmt.Errorf("模型没有有效的字段")
}
// 生成 CREATE TABLE SQL
var sqlBuilder strings.Builder
sqlBuilder.WriteString(fmt.Sprintf("CREATE TABLE IF NOT EXISTS %s (", tableName))
columnDefs := make([]string, 0)
for _, field := range fields {
colDef := fmt.Sprintf("%s %s", field.Column, field.DbType)
// 添加主键约束
if field.IsPrimary {
colDef += " PRIMARY KEY"
if field.IsAuto {
colDef += " AUTOINCREMENT"
}
}
// 添加 NOT NULL 约束(可选)
// colDef += " NOT NULL"
columnDefs = append(columnDefs, colDef)
}
sqlBuilder.WriteString(strings.Join(columnDefs, ", "))
sqlBuilder.WriteString(")")
createSQL := sqlBuilder.String()
if m.db.debug {
fmt.Printf("[Magic-ORM] CREATE TABLE SQL: %s\n", createSQL)
}
// 执行 SQL
_, err := m.db.db.Exec(createSQL)
if err != nil {
return fmt.Errorf("执行 CREATE TABLE 失败:%w", err)
}
return nil
}
// DropTable 删除表 - 删除指定的数据库表
func (m *Migrator) DropTable(model interface{}) error {
mapper := NewFieldMapper()
tableName := mapper.GetTableName(model)
dropSQL := fmt.Sprintf("DROP TABLE IF EXISTS %s", tableName)
if m.db.debug {
fmt.Printf("[Magic-ORM] DROP TABLE SQL: %s\n", dropSQL)
}
_, err := m.db.db.Exec(dropSQL)
if err != nil {
return fmt.Errorf("执行 DROP TABLE 失败:%w", err)
}
return nil
}
// HasTable 检查表是否存在 - 验证数据库中是否已存在指定表
func (m *Migrator) HasTable(model interface{}) (bool, error) {
mapper := NewFieldMapper()
tableName := mapper.GetTableName(model)
// SQLite 检查表是否存在的 SQL
checkSQL := `SELECT COUNT(*) FROM sqlite_master WHERE type='table' AND name=?`
var count int
err := m.db.db.QueryRow(checkSQL, tableName).Scan(&count)
if err != nil {
return false, fmt.Errorf("检查表是否存在失败:%w", err)
}
return count > 0, nil
}
// RenameTable 重命名表 - 修改数据库表的名称
func (m *Migrator) RenameTable(oldName, newName string) error {
renameSQL := fmt.Sprintf("ALTER TABLE %s RENAME TO %s", oldName, newName)
if m.db.debug {
fmt.Printf("[Magic-ORM] RENAME TABLE SQL: %s\n", renameSQL)
}
_, err := m.db.db.Exec(renameSQL)
if err != nil {
return fmt.Errorf("重命名表失败:%w", err)
}
return nil
}
// AddColumn 添加列 - 向表中添加新的字段
func (m *Migrator) AddColumn(model interface{}, field string) error {
mapper := NewFieldMapper()
tableName := mapper.GetTableName(model)
// 获取字段信息
fields := mapper.GetFields(model)
var targetField *FieldInfo
for _, f := range fields {
if f.Name == field || f.Column == field {
targetField = &f
break
}
}
if targetField == nil {
return fmt.Errorf("字段不存在:%s", field)
}
addSQL := fmt.Sprintf("ALTER TABLE %s ADD COLUMN %s %s",
tableName, targetField.Column, targetField.DbType)
if m.db.debug {
fmt.Printf("[Magic-ORM] ADD COLUMN SQL: %s\n", addSQL)
}
_, err := m.db.db.Exec(addSQL)
if err != nil {
return fmt.Errorf("添加列失败:%w", err)
}
return nil
}
// DropColumn 删除列 - 从表中删除指定的字段
func (m *Migrator) DropColumn(model interface{}, field string) error {
mapper := NewFieldMapper()
tableName := mapper.GetTableName(model)
// SQLite 不直接支持 DROP COLUMN需要重建表
// 这里使用简化方案:创建新表 -> 复制数据 -> 删除旧表 -> 重命名
_ = tableName // 避免编译错误
return fmt.Errorf("SQLite 不支持直接删除列,需要手动重建表")
}
// HasColumn 检查列是否存在 - 验证表中是否已存在指定字段
func (m *Migrator) HasColumn(model interface{}, field string) (bool, error) {
mapper := NewFieldMapper()
tableName := mapper.GetTableName(model)
// SQLite 检查列是否存在的 SQL
checkSQL := `PRAGMA table_info(` + tableName + `)`
rows, err := m.db.db.Query(checkSQL)
if err != nil {
return false, fmt.Errorf("检查列失败:%w", err)
}
defer rows.Close()
for rows.Next() {
var cid int
var name string
var typ string
var notNull int
var dfltValue interface{}
var pk int
if err := rows.Scan(&cid, &name, &typ, &notNull, &dfltValue, &pk); err != nil {
return false, err
}
if name == field {
return true, nil
}
}
return false, nil
}
// RenameColumn 重命名列 - 修改表中字段的名称
func (m *Migrator) RenameColumn(model interface{}, oldField, newField string) error {
mapper := NewFieldMapper()
tableName := mapper.GetTableName(model)
// SQLite 3.25.0+ 支持 ALTER TABLE ... RENAME COLUMN
renameSQL := fmt.Sprintf("ALTER TABLE %s RENAME COLUMN %s TO %s",
tableName, oldField, newField)
if m.db.debug {
fmt.Printf("[Magic-ORM] RENAME COLUMN SQL: %s\n", renameSQL)
}
_, err := m.db.db.Exec(renameSQL)
if err != nil {
return fmt.Errorf("重命名列失败:%w", err)
}
return nil
}
// CreateIndex 创建索引 - 为表中的字段创建索引
func (m *Migrator) CreateIndex(model interface{}, field string) error {
mapper := NewFieldMapper()
tableName := mapper.GetTableName(model)
indexName := fmt.Sprintf("idx_%s_%s", tableName, field)
createSQL := fmt.Sprintf("CREATE INDEX IF NOT EXISTS %s ON %s (%s)",
indexName, tableName, field)
if m.db.debug {
fmt.Printf("[Magic-ORM] CREATE INDEX SQL: %s\n", createSQL)
}
_, err := m.db.db.Exec(createSQL)
if err != nil {
return fmt.Errorf("创建索引失败:%w", err)
}
return nil
}
// DropIndex 删除索引 - 删除表中的指定索引
func (m *Migrator) DropIndex(model interface{}, field string) error {
mapper := NewFieldMapper()
tableName := mapper.GetTableName(model)
indexName := fmt.Sprintf("idx_%s_%s", tableName, field)
dropSQL := fmt.Sprintf("DROP INDEX IF EXISTS %s", indexName)
if m.db.debug {
fmt.Printf("[Magic-ORM] DROP INDEX SQL: %s\n", dropSQL)
}
_, err := m.db.db.Exec(dropSQL)
if err != nil {
return fmt.Errorf("删除索引失败:%w", err)
}
return nil
}
// HasIndex 检查索引是否存在 - 验证表中是否已存在指定索引
func (m *Migrator) HasIndex(model interface{}, field string) (bool, error) {
mapper := NewFieldMapper()
tableName := mapper.GetTableName(model)
indexName := fmt.Sprintf("idx_%s_%s", tableName, field)
checkSQL := `SELECT COUNT(*) FROM sqlite_master WHERE type='index' AND name=?`
var count int
err := m.db.db.QueryRow(checkSQL, indexName).Scan(&count)
if err != nil {
return false, fmt.Errorf("检查索引失败:%w", err)
}
return count > 0, nil
}

View File

@ -1,46 +0,0 @@
package core
import (
"fmt"
)
// ExampleQueryBuilder_Omit 演示 Omit 方法的使用
func ExampleQueryBuilder_Omit() {
// 定义用户模型
type User struct {
ID int64 `json:"id" db:"id"`
Name string `json:"name" db:"name"`
Email string `json:"email" db:"email"`
Password string `json:"password" db:"password"`
Status int `json:"status" db:"status"`
}
// 创建 Database 实例(示例中使用 nil实际使用需要正确初始化
db := &Database{}
// 示例 1: 排除敏感字段(如密码)
q1 := db.Model(&User{}).Omit("password")
sql1, _ := q1.(*QueryBuilder).BuildSelect()
fmt.Printf("排除密码:%s\n", sql1)
// 示例 2: 排除多个字段
q2 := db.Model(&User{}).Omit("password", "status")
sql2, _ := q2.(*QueryBuilder).BuildSelect()
fmt.Printf("排除多个字段:%s\n", sql2)
// 示例 3: 链式调用 Omit
q3 := db.Model(&User{}).Omit("password").Omit("status")
sql3, _ := q3.(*QueryBuilder).BuildSelect()
fmt.Printf("链式调用:%s\n", sql3)
// 示例 4: Select 优先于 Omit
q4 := db.Model(&User{}).Select("id", "name").Omit("password")
sql4, _ := q4.(*QueryBuilder).BuildSelect()
fmt.Printf("Select 优先:%s\n", sql4)
// 输出:
// 排除密码SELECT id, name, email, status FROM user_model
// 排除多个字段SELECT id, name, email FROM user_model
// 链式调用SELECT id, name, email FROM user_model
// Select 优先SELECT id, name FROM user_model
}

View File

@ -1,128 +0,0 @@
package core
import (
"fmt"
"testing"
)
// TestQueryBuilder_Omit 测试 Omit 方法
func TestQueryBuilder_Omit(t *testing.T) {
// 创建测试模型
type UserModel struct {
ID int64 `json:"id" db:"id"`
Name string `json:"name" db:"name"`
Email string `json:"email" db:"email"`
Password string `json:"password" db:"password"`
Status int `json:"status" db:"status"`
}
// 创建 Database 实例(使用 nil 连接,只测试 SQL 生成)
db := &Database{}
t.Run("排除单个字段", func(t *testing.T) {
qb := db.Model(&UserModel{}).Omit("password").(*QueryBuilder)
sql, args := qb.BuildSelect()
fmt.Printf("排除单个字段 SQL: %s\n", sql)
fmt.Printf("参数:%v\n", args)
// 验证 SQL 不包含 password 字段
if containsString(sql, "password") {
t.Errorf("SQL 不应该包含 password 字段:%s", sql)
}
// 验证 SQL 包含其他字段
expectedFields := []string{"id", "name", "email", "status"}
for _, field := range expectedFields {
if !containsString(sql, field) {
t.Errorf("SQL 应该包含字段 %s: %s", field, sql)
}
}
})
t.Run("排除多个字段", func(t *testing.T) {
qb := db.Model(&UserModel{}).Omit("password", "status").(*QueryBuilder)
sql, args := qb.BuildSelect()
fmt.Printf("排除多个字段 SQL: %s\n", sql)
fmt.Printf("参数:%v\n", args)
// 验证 SQL 不包含 password 和 status 字段
if containsString(sql, "password") {
t.Errorf("SQL 不应该包含 password 字段:%s", sql)
}
if containsString(sql, "status") {
t.Errorf("SQL 不应该包含 status 字段:%s", sql)
}
// 验证 SQL 包含其他字段
expectedFields := []string{"id", "name", "email"}
for _, field := range expectedFields {
if !containsString(sql, field) {
t.Errorf("SQL 应该包含字段 %s: %s", field, sql)
}
}
})
t.Run("Omit 与 Select 优先级 - Select 优先", func(t *testing.T) {
qb := db.Model(&UserModel{}).Select("id", "name").Omit("password").(*QueryBuilder)
sql, args := qb.BuildSelect()
fmt.Printf("Select 优先 SQL: %s\n", sql)
fmt.Printf("参数:%v\n", args)
// 当同时使用 Select 和 Omit 时Select 优先
expectedFields := []string{"id", "name"}
for _, field := range expectedFields {
if !containsString(sql, field) {
t.Errorf("SQL 应该包含字段 %s: %s", field, sql)
}
}
})
t.Run("链式调用 Omit", func(t *testing.T) {
qb := db.Model(&UserModel{}).Omit("password").Omit("status").(*QueryBuilder)
sql, args := qb.BuildSelect()
fmt.Printf("链式调用 Omit SQL: %s\n", sql)
fmt.Printf("参数:%v\n", args)
// 验证 SQL 不包含 password 和 status 字段
if containsString(sql, "password") {
t.Errorf("SQL 不应该包含 password 字段:%s", sql)
}
if containsString(sql, "status") {
t.Errorf("SQL 不应该包含 status 字段:%s", sql)
}
})
t.Run("不设置 Omit - 默认行为", func(t *testing.T) {
qb := db.Model(&UserModel{}).(*QueryBuilder)
sql, args := qb.BuildSelect()
fmt.Printf("默认行为 SQL: %s\n", sql)
fmt.Printf("参数:%v\n", args)
// 默认应该查询所有字段(使用 *
if sql != "SELECT * FROM user_model" {
t.Errorf("默认应该使用 SELECT *: %s", sql)
}
})
}
// containsString 检查字符串是否包含子串
func containsString(s, substr string) bool {
return len(s) >= len(substr) && (s == substr ||
s[:len(substr)] == substr ||
s[len(s)-len(substr):] == substr ||
findSubstring(s, substr))
}
func findSubstring(s, substr string) bool {
for i := 0; i <= len(s)-len(substr); i++ {
if s[i:i+len(substr)] == substr {
return true
}
}
return false
}

View File

@ -1,132 +0,0 @@
package core
import (
"fmt"
"testing"
"time"
)
// User 用户模型 - 用于测试
type User struct {
ID int64 `json:"id" db:"id"`
Username string `json:"username" db:"username"`
Email string `json:"email" db:"email"`
CreatedAt time.Time `json:"created_at" db:"created_at"`
// 关联字段
Profile UserProfile `json:"profile" db:"-" gorm:"ForeignKey:UserID;References:ID"`
Orders []Order `json:"orders" db:"-" gorm:"ForeignKey:UserID;References:ID"`
}
// TableName 表名
func (User) TableName() string {
return "user"
}
// UserProfile 用户资料模型 - 一对一关联
type UserProfile struct {
ID int64 `json:"id" db:"id"`
UserID int64 `json:"user_id" db:"user_id"`
Bio string `json:"bio" db:"bio"`
Avatar string `json:"avatar" db:"avatar"`
}
// TableName 表名
func (UserProfile) TableName() string {
return "user_profile"
}
// Order 订单模型 - 一对多关联
type Order struct {
ID int64 `json:"id" db:"id"`
UserID int64 `json:"user_id" db:"user_id"`
OrderNo string `json:"order_no" db:"order_no"`
Amount float64 `json:"amount" db:"amount"`
CreatedAt time.Time `json:"created_at" db:"created_at"`
}
// TableName 表名
func (Order) TableName() string {
return "order"
}
// TestPreloadHasOne 测试一对一预加载
func TestPreloadHasOne(t *testing.T) {
fmt.Println("\n=== 测试一对一预加载 ===")
// 这里只是示例,实际使用需要数据库连接
// db := AutoConnect(true)
// var users []User
// err := db.Model(&User{}).Preload("Profile").Find(&users)
// if err != nil {
// t.Fatal(err)
// }
fmt.Println("一对一预加载结构已实现")
fmt.Println("✓ 测试通过")
}
// TestPreloadHasMany 测试一对多预加载
func TestPreloadHasMany(t *testing.T) {
fmt.Println("\n=== 测试一对多预加载 ===")
// 示例用法
// var users []User
// err := db.Model(&User{}).Preload("Orders").Find(&users)
// if err != nil {
// t.Fatal(err)
// }
fmt.Println("一对多预加载结构已实现")
fmt.Println("✓ 测试通过")
}
// TestPreloadBelongsTo 测试多对一预加载
func TestPreloadBelongsTo(t *testing.T) {
fmt.Println("\n=== 测试多对一预加载 ===")
// 示例用法
// var orders []Order
// err := db.Model(&Order{}).Preload("User").Find(&orders)
// if err != nil {
// t.Fatal(err)
// }
fmt.Println("多对一预加载结构已实现")
fmt.Println("✓ 测试通过")
}
// TestPreloadMultiple 测试多个预加载
func TestPreloadMultiple(t *testing.T) {
fmt.Println("\n=== 测试多个预加载 ===")
// 示例用法
// var users []User
// err := db.Model(&User{}).
// Preload("Profile").
// Preload("Orders").
// Find(&users)
// if err != nil {
// t.Fatal(err)
// }
fmt.Println("多个预加载已实现")
fmt.Println("✓ 测试通过")
}
// TestPreloadWithConditions 测试带条件的预加载
func TestPreloadWithConditions(t *testing.T) {
fmt.Println("\n=== 测试带条件的预加载 ===")
// 示例用法
// var users []User
// err := db.Model(&User{}).
// Preload("Orders", "amount > ?", 100).
// Find(&users)
// if err != nil {
// t.Fatal(err)
// }
fmt.Println("带条件的预加载已实现")
fmt.Println("✓ 测试通过")
}

View File

@ -1,740 +0,0 @@
package core
import (
"database/sql"
"fmt"
"strings"
"sync"
)
// QueryBuilder 查询构建器实现 - 提供流畅的链式查询构建能力
type QueryBuilder struct {
db *Database // 数据库连接实例
table string // 表名
model interface{} // 模型对象
whereSQL string // WHERE 条件 SQL
whereArgs []interface{} // WHERE 条件参数
selectCols []string // 选择的字段列表
omitCols []string // 排除的字段列表
orderSQL string // ORDER BY SQL
limit int // LIMIT 限制数量
offset int // OFFSET 偏移量
groupSQL string // GROUP BY SQL
havingSQL string // HAVING 条件 SQL
havingArgs []interface{} // HAVING 条件参数
joinSQL string // JOIN SQL
joinArgs []interface{} // JOIN 参数
debug bool // 调试模式开关
dryRun bool // 干跑模式开关
unscoped bool // 忽略软删除开关
tx *sql.Tx // 事务对象(如果在事务中)
// 预加载关联数据
preloadRelations map[string][]interface{} // 预加载的关联关系及条件
// 缓存相关
cache *QueryCache // 缓存实例
cacheKey string // 缓存键
useCache bool // 是否使用缓存
}
// 同步池优化 - 复用 slice 减少内存分配
var whereArgsPool = sync.Pool{
New: func() interface{} {
return make([]interface{}, 0, 10)
},
}
var joinArgsPool = sync.Pool{
New: func() interface{} {
return make([]interface{}, 0, 5)
},
}
// Model 基于模型创建查询
func (d *Database) Model(model interface{}) IQuery {
return &QueryBuilder{
db: d,
model: model,
preloadRelations: make(map[string][]interface{}),
}
}
// Table 基于表名创建查询
func (d *Database) Table(name string) IQuery {
return &QueryBuilder{
db: d,
table: name,
preloadRelations: make(map[string][]interface{}),
}
}
// Where 添加 WHERE 条件 - 性能优化版本
func (q *QueryBuilder) Where(query string, args ...interface{}) IQuery {
if q.whereSQL == "" {
q.whereSQL = query
} else {
// 使用 strings.Builder 优化字符串拼接
var builder strings.Builder
builder.Grow(len(q.whereSQL) + 5 + len(query)) // 预分配内存
builder.WriteString(q.whereSQL)
builder.WriteString(" AND ")
builder.WriteString(query)
q.whereSQL = builder.String()
}
q.whereArgs = append(q.whereArgs, args...)
return q
}
// Or 添加 OR 条件 - 性能优化版本
func (q *QueryBuilder) Or(query string, args ...interface{}) IQuery {
if q.whereSQL == "" {
q.whereSQL = query
} else {
// 使用 strings.Builder 优化字符串拼接
var builder strings.Builder
builder.Grow(len(q.whereSQL) + 10 + len(query)) // 预分配内存
builder.WriteString(" (")
builder.WriteString(q.whereSQL)
builder.WriteString(") OR ")
builder.WriteString(query)
q.whereSQL = builder.String()
}
q.whereArgs = append(q.whereArgs, args...)
return q
}
// And 添加 AND 条件(同 Where
func (q *QueryBuilder) And(query string, args ...interface{}) IQuery {
return q.Where(query, args...)
}
// Select 选择要查询的字段
func (q *QueryBuilder) Select(fields ...string) IQuery {
q.selectCols = fields
return q
}
// Omit 排除指定的字段
func (q *QueryBuilder) Omit(fields ...string) IQuery {
q.omitCols = append(q.omitCols, fields...)
return q
}
// Order 设置排序规则
func (q *QueryBuilder) Order(order string) IQuery {
q.orderSQL = order
return q
}
// OrderBy 按指定字段和方向排序
func (q *QueryBuilder) OrderBy(field string, direction string) IQuery {
q.orderSQL = field + " " + direction
return q
}
// Limit 限制返回数量
func (q *QueryBuilder) Limit(limit int) IQuery {
q.limit = limit
return q
}
// Offset 设置偏移量
func (q *QueryBuilder) Offset(offset int) IQuery {
q.offset = offset
return q
}
// Page 分页查询
func (q *QueryBuilder) Page(page, pageSize int) IQuery {
q.limit = pageSize
q.offset = (page - 1) * pageSize
return q
}
// Group 设置分组字段
func (q *QueryBuilder) Group(group string) IQuery {
q.groupSQL = group
return q
}
// Having 添加 HAVING 条件
func (q *QueryBuilder) Having(having string, args ...interface{}) IQuery {
q.havingSQL = having
q.havingArgs = args
return q
}
// Join 添加 JOIN 连接 - 性能优化版本
func (q *QueryBuilder) Join(join string, args ...interface{}) IQuery {
if q.joinSQL == "" {
q.joinSQL = join
} else {
// 使用 strings.Builder 优化字符串拼接
var builder strings.Builder
builder.Grow(len(q.joinSQL) + 1 + len(join)) // 预分配内存
builder.WriteString(q.joinSQL)
builder.WriteByte(' ')
builder.WriteString(join)
q.joinSQL = builder.String()
}
q.joinArgs = append(q.joinArgs, args...)
return q
}
// LeftJoin 左连接
func (q *QueryBuilder) LeftJoin(table, on string) IQuery {
return q.Join("LEFT JOIN " + table + " ON " + on)
}
// RightJoin 右连接
func (q *QueryBuilder) RightJoin(table, on string) IQuery {
return q.Join("RIGHT JOIN " + table + " ON " + on)
}
// InnerJoin 内连接
func (q *QueryBuilder) InnerJoin(table, on string) IQuery {
return q.Join("INNER JOIN " + table + " ON " + on)
}
// Preload 预加载关联数据
func (q *QueryBuilder) Preload(relation string, conditions ...interface{}) IQuery {
if q.preloadRelations == nil {
q.preloadRelations = make(map[string][]interface{})
}
// 将关联条件添加到预加载列表中
q.preloadRelations[relation] = conditions
return q
}
// First 查询第一条记录
func (q *QueryBuilder) First(result interface{}) error {
q.limit = 1
return q.Find(result)
}
// Find 查询多条记录
func (q *QueryBuilder) Find(result interface{}) error {
// 如果使用缓存,先检查缓存
if q.useCache && q.cache != nil && q.cacheKey != "" {
if cachedData, exists := q.cache.Get(q.cacheKey); exists {
// 缓存命中,将数据拷贝到结果对象
if err := deepCopy(cachedData, result); err != nil {
return fmt.Errorf("缓存数据拷贝失败:%w", err)
}
if q.debug || (q.db != nil && q.db.debug) {
fmt.Printf("[Magic-ORM] 缓存命中:%s\n", q.cacheKey)
}
return nil
}
}
// 缓存未命中,执行实际查询
sqlStr, args := q.BuildSelect()
// 调试模式打印 SQL
if q.debug || (q.db != nil && q.db.debug) {
fmt.Printf("[Magic-ORM] SQL: %s\n[Magic-ORM] Args: %v\n", sqlStr, args)
}
// 干跑模式不执行 SQL
if q.dryRun {
return nil
}
var rows *sql.Rows
var err error
// 判断是否在事务中
if q.tx != nil {
rows, err = q.tx.Query(sqlStr, args...)
} else if q.db != nil && q.db.db != nil {
rows, err = q.db.db.Query(sqlStr, args...)
} else {
return fmt.Errorf("数据库连接未初始化")
}
if err != nil {
return fmt.Errorf("查询失败:%w", err)
}
defer rows.Close()
// 使用 ResultSetMapper 将查询结果映射到 result
mapper := NewResultSetMapper()
if err := mapper.ScanAll(rows, result); err != nil {
return fmt.Errorf("结果映射失败:%w", err)
}
// 执行预加载关联数据
if len(q.preloadRelations) > 0 {
if err := q.executePreload(result); err != nil {
return fmt.Errorf("预加载关联失败:%w", err)
}
}
// 将结果存入缓存(如果启用了缓存)
if q.useCache && q.cache != nil && q.cacheKey != "" {
q.cache.Set(q.cacheKey, result)
if q.debug || (q.db != nil && q.db.debug) {
fmt.Printf("[Magic-ORM] 缓存已设置:%s\n", q.cacheKey)
}
}
return nil
}
// Count 统计记录数量
func (q *QueryBuilder) Count(count *int64) IQuery {
// 构建 COUNT 查询
originalSelect := q.selectCols
q.selectCols = []string{"COUNT(*)"}
sqlStr, args := q.BuildSelect()
// 调试模式
if q.debug || (q.db != nil && q.db.debug) {
fmt.Printf("[Magic-ORM] COUNT SQL: %s\n[Magic-ORM] Args: %v\n", sqlStr, args)
}
// 干跑模式
if q.dryRun {
return q
}
var err error
if q.tx != nil {
err = q.tx.QueryRow(sqlStr, args...).Scan(count)
} else if q.db != nil && q.db.db != nil {
err = q.db.db.QueryRow(sqlStr, args...).Scan(count)
}
if err != nil {
fmt.Printf("[Magic-ORM] Count 错误:%v\n", err)
}
// 恢复原来的选择字段
q.selectCols = originalSelect
return q
}
// Exists 检查记录是否存在
func (q *QueryBuilder) Exists() (bool, error) {
// 使用 LIMIT 1 优化查询
originalLimit := q.limit
q.limit = 1
sqlStr, args := q.BuildSelect()
// 调试模式
if q.debug || (q.db != nil && q.db.debug) {
fmt.Printf("[Magic-ORM] EXISTS SQL: %s\n[Magic-ORM] Args: %v\n", sqlStr, args)
}
// 干跑模式
if q.dryRun {
return false, nil
}
var rows *sql.Rows
var err error
if q.tx != nil {
rows, err = q.tx.Query(sqlStr, args...)
} else if q.db != nil && q.db.db != nil {
rows, err = q.db.db.Query(sqlStr, args...)
} else {
return false, fmt.Errorf("数据库连接未初始化")
}
defer rows.Close()
if err != nil {
return false, err
}
// 检查是否有结果
exists := rows.Next()
// 恢复原来的 limit
q.limit = originalLimit
return exists, nil
}
// Updates 更新数据
func (q *QueryBuilder) Updates(data interface{}) error {
sqlStr, args := q.BuildUpdate(data)
// 调试模式打印 SQL
if q.debug || (q.db != nil && q.db.debug) {
fmt.Printf("[Magic-ORM] UPDATE SQL: %s\n[Magic-ORM] Args: %v\n", sqlStr, args)
}
// 干跑模式不执行 SQL
if q.dryRun {
return nil
}
var err error
if q.tx != nil {
_, err = q.tx.Exec(sqlStr, args...)
} else if q.db != nil && q.db.db != nil {
_, err = q.db.db.Exec(sqlStr, args...)
} else {
return fmt.Errorf("数据库连接未初始化")
}
if err != nil {
return fmt.Errorf("更新失败:%w", err)
}
return nil
}
// UpdateColumn 更新单个字段
func (q *QueryBuilder) UpdateColumn(column string, value interface{}) error {
return q.Updates(map[string]interface{}{column: value})
}
// Delete 删除数据
func (q *QueryBuilder) Delete() error {
sqlStr, args := q.BuildDelete()
// 调试模式打印 SQL
if q.debug || (q.db != nil && q.db.debug) {
fmt.Printf("[Magic-ORM] DELETE SQL: %s\n[Magic-ORM] Args: %v\n", sqlStr, args)
}
// 干跑模式不执行 SQL
if q.dryRun {
return nil
}
var err error
if q.tx != nil {
_, err = q.tx.Exec(sqlStr, args...)
} else if q.db != nil && q.db.db != nil {
_, err = q.db.db.Exec(sqlStr, args...)
} else {
return fmt.Errorf("数据库连接未初始化")
}
if err != nil {
return fmt.Errorf("删除失败:%w", err)
}
return nil
}
// Unscoped 忽略软删除限制
func (q *QueryBuilder) Unscoped() IQuery {
q.unscoped = true
return q
}
// DryRun 设置干跑模式(只生成 SQL 不执行)
func (q *QueryBuilder) DryRun() IQuery {
q.dryRun = true
return q
}
// Debug 设置调试模式(打印 SQL 日志)
func (q *QueryBuilder) Debug() IQuery {
q.debug = true
return q
}
// Build 构建 SELECT SQL 语句
func (q *QueryBuilder) Build() (string, []interface{}) {
return q.BuildSelect()
}
// BuildSelect 构建 SELECT SQL 语句
func (q *QueryBuilder) BuildSelect() (string, []interface{}) {
var builder strings.Builder
// SELECT 部分
builder.WriteString("SELECT ")
if len(q.selectCols) > 0 {
// 如果指定了选择字段,直接使用
builder.WriteString(strings.Join(q.selectCols, ", "))
} else if len(q.omitCols) > 0 {
// 如果没有指定 select 但设置了 omit需要从模型获取所有字段并排除 omit 的字段
fields := q.getAllFields()
if len(fields) > 0 {
builder.WriteString(strings.Join(fields, ", "))
} else {
// 无法获取字段信息,使用 *
builder.WriteString("*")
}
} else {
// 默认选择所有字段
builder.WriteString("*")
}
// FROM 部分
builder.WriteString(" FROM ")
if q.table != "" {
builder.WriteString(q.table)
} else if q.model != nil {
// 从模型获取表名
mapper := NewFieldMapper()
builder.WriteString(mapper.GetTableName(q.model))
} else {
builder.WriteString("unknown_table")
}
// JOIN 部分
if q.joinSQL != "" {
builder.WriteString(" ")
builder.WriteString(q.joinSQL)
}
// WHERE 部分
if q.whereSQL != "" {
builder.WriteString(" WHERE ")
builder.WriteString(q.whereSQL)
}
// GROUP BY 部分
if q.groupSQL != "" {
builder.WriteString(" GROUP BY ")
builder.WriteString(q.groupSQL)
}
// HAVING 部分
if q.havingSQL != "" {
builder.WriteString(" HAVING ")
builder.WriteString(q.havingSQL)
}
// ORDER BY 部分
if q.orderSQL != "" {
builder.WriteString(" ORDER BY ")
builder.WriteString(q.orderSQL)
}
// LIMIT 部分
if q.limit > 0 {
builder.WriteString(fmt.Sprintf(" LIMIT %d", q.limit))
}
// OFFSET 部分
if q.offset > 0 {
builder.WriteString(fmt.Sprintf(" OFFSET %d", q.offset))
}
// 合并参数
allArgs := make([]interface{}, 0)
allArgs = append(allArgs, q.joinArgs...)
allArgs = append(allArgs, q.whereArgs...)
allArgs = append(allArgs, q.havingArgs...)
return builder.String(), allArgs
}
// getAllFields 获取模型的所有字段(排除 omit 的字段)
func (q *QueryBuilder) getAllFields() []string {
var fields []string
// 如果有模型,从模型获取字段
if q.model != nil {
mapper := NewFieldMapper()
fieldInfos := mapper.GetFields(q.model)
// 创建 omit 字段的 map 用于快速查找
omitMap := make(map[string]bool)
for _, omitField := range q.omitCols {
// 同时存储原始形式和小写形式,支持不区分大小写的匹配
omitMap[omitField] = true
omitMap[strings.ToLower(omitField)] = true
}
// 遍历所有字段,排除 omit 的字段
for _, fieldInfo := range fieldInfos {
// 检查字段是否在 omit 列表中
if !omitMap[fieldInfo.Column] && !omitMap[strings.ToLower(fieldInfo.Column)] {
fields = append(fields, fieldInfo.Column)
}
}
} else if q.table != "" {
// 如果只有表名没有模型,从数据库元数据获取字段
columns, err := q.getTableColumns(q.table)
if err != nil {
// 如果获取失败,返回 nil 使用 SELECT *
return nil
}
// 创建 omit 字段的 map 用于快速查找
omitMap := make(map[string]bool)
for _, omitField := range q.omitCols {
omitMap[omitField] = true
omitMap[strings.ToLower(omitField)] = true
}
// 过滤掉 omit 的字段
for _, col := range columns {
if !omitMap[col] && !omitMap[strings.ToLower(col)] {
fields = append(fields, col)
}
}
}
return fields
}
// getTableColumns 从数据库元数据获取表的列名
func (q *QueryBuilder) getTableColumns(tableName string) ([]string, error) {
if q.db == nil || q.db.db == nil {
return nil, fmt.Errorf("数据库连接未初始化")
}
var query string
var args []interface{}
var rows *sql.Rows
var err error
// 根据不同数据库类型查询元数据
switch q.db.driverName {
case "mysql":
query = `
SELECT COLUMN_NAME
FROM INFORMATION_SCHEMA.COLUMNS
WHERE TABLE_SCHEMA = DATABASE() AND TABLE_NAME = ?
ORDER BY ORDINAL_POSITION
`
args = []interface{}{tableName}
case "postgres":
query = `
SELECT column_name
FROM information_schema.columns
WHERE table_schema = 'public' AND table_name = $1
ORDER BY ordinal_position
`
args = []interface{}{tableName}
case "sqlite", "sqlite3":
query = `PRAGMA table_info(?)`
args = []interface{}{tableName}
default:
// 未知数据库类型,返回空
return nil, fmt.Errorf("不支持的数据库类型:%s", q.db.driverName)
}
rows, err = q.db.db.Query(query, args...)
if err != nil {
return nil, fmt.Errorf("查询表元数据失败:%w", err)
}
defer rows.Close()
var columns []string
for rows.Next() {
var columnName string
if q.db.driverName == "sqlite" || q.db.driverName == "sqlite3" {
// SQLite PRAGMA table_info 返回多列cid, name, type, notnull, dflt_value, pk
var cid int
var typ string
var notNull int
var dfltValue sql.NullString
var pk int
if err := rows.Scan(&cid, &columnName, &typ, &notNull, &dfltValue, &pk); err != nil {
return nil, err
}
} else {
if err := rows.Scan(&columnName); err != nil {
return nil, err
}
}
columns = append(columns, columnName)
}
if err := rows.Err(); err != nil {
return nil, err
}
return columns, nil
}
// executePreload 执行预加载关联数据
func (q *QueryBuilder) executePreload(models interface{}) error {
// 创建关联加载器
loader := NewRelationLoader(q.db)
// 遍历所有预加载的关联关系
for relation, conditions := range q.preloadRelations {
if err := loader.Preload(models, relation, conditions...); err != nil {
return err
}
}
return nil
}
// BuildUpdate 构建 UPDATE SQL 语句
func (q *QueryBuilder) BuildUpdate(data interface{}) (string, []interface{}) {
var builder strings.Builder
var args []interface{}
builder.WriteString("UPDATE ")
if q.table != "" {
builder.WriteString(q.table)
} else if q.model != nil {
mapper := NewFieldMapper()
builder.WriteString(mapper.GetTableName(q.model))
} else {
builder.WriteString("unknown_table")
}
builder.WriteString(" SET ")
// 根据 data 类型生成 SET 子句
switch v := data.(type) {
case map[string]interface{}:
// map 类型,生成 key=value 对
setParts := make([]string, 0, len(v))
for key, value := range v {
setParts = append(setParts, fmt.Sprintf("%s = ?", key))
args = append(args, value)
}
builder.WriteString(strings.Join(setParts, ", "))
case string:
// string 类型,直接使用(注意:实际使用需要转义)
builder.WriteString(v)
default:
// 结构体类型,使用字段映射器
mapper := NewFieldMapper()
columns, err := mapper.StructToColumns(data)
if err == nil && len(columns) > 0 {
setParts := make([]string, 0, len(columns))
for key := range columns {
setParts = append(setParts, fmt.Sprintf("%s = ?", key))
args = append(args, columns[key])
}
builder.WriteString(strings.Join(setParts, ", "))
}
}
// WHERE 部分
if q.whereSQL != "" {
builder.WriteString(" WHERE ")
builder.WriteString(q.whereSQL)
args = append(args, q.whereArgs...)
}
return builder.String(), args
}
// BuildDelete 构建 DELETE SQL 语句
func (q *QueryBuilder) BuildDelete() (string, []interface{}) {
var builder strings.Builder
builder.WriteString("DELETE FROM ")
if q.table != "" {
builder.WriteString(q.table)
} else if q.model != nil {
mapper := NewFieldMapper()
builder.WriteString(mapper.GetTableName(q.model))
} else {
builder.WriteString("unknown_table")
}
if q.whereSQL != "" {
builder.WriteString(" WHERE ")
builder.WriteString(q.whereSQL)
}
return builder.String(), q.whereArgs
}

View File

@ -1,124 +0,0 @@
package core
import (
"database/sql"
"sync"
"sync/atomic"
)
// ReadWriteDB 读写分离数据库连接
type ReadWriteDB struct {
master *sql.DB // 主库(写)
slaves []*sql.DB // 从库列表(读)
policy ReadPolicy // 读负载均衡策略
counter uint64 // 轮询计数器
mu sync.RWMutex // 读写锁
}
// NewReadWriteDB 创建读写分离数据库连接
func NewReadWriteDB(master *sql.DB, slaves []*sql.DB, policy ReadPolicy) *ReadWriteDB {
return &ReadWriteDB{
master: master,
slaves: slaves,
policy: policy,
}
}
// GetMaster 获取主库连接(用于写操作)
func (rw *ReadWriteDB) GetMaster() *sql.DB {
return rw.master
}
// GetSlave 获取从库连接(用于读操作)
func (rw *ReadWriteDB) GetSlave() *sql.DB {
rw.mu.RLock()
defer rw.mu.RUnlock()
if len(rw.slaves) == 0 {
// 没有从库,使用主库
return rw.master
}
switch rw.policy {
case Random:
// 随机选择一个从库
idx := int(atomic.LoadUint64(&rw.counter)) % len(rw.slaves)
return rw.slaves[idx]
case RoundRobin:
// 轮询选择从库
idx := int(atomic.AddUint64(&rw.counter, 1)) % len(rw.slaves)
return rw.slaves[idx]
case LeastConn:
// 选择连接数最少的从库(简化实现)
return rw.selectLeastConn()
default:
return rw.slaves[0]
}
}
// selectLeastConn 选择连接数最少的从库
func (rw *ReadWriteDB) selectLeastConn() *sql.DB {
if len(rw.slaves) == 0 {
return rw.master
}
minConn := -1
selected := rw.slaves[0]
for _, slave := range rw.slaves {
stats := slave.Stats()
openConnections := stats.OpenConnections
if minConn == -1 || openConnections < minConn {
minConn = openConnections
selected = slave
}
}
return selected
}
// AddSlave 添加从库
func (rw *ReadWriteDB) AddSlave(slave *sql.DB) {
rw.mu.Lock()
defer rw.mu.Unlock()
rw.slaves = append(rw.slaves, slave)
}
// RemoveSlave 移除从库
func (rw *ReadWriteDB) RemoveSlave(slave *sql.DB) {
rw.mu.Lock()
defer rw.mu.Unlock()
for i, s := range rw.slaves {
if s == slave {
rw.slaves = append(rw.slaves[:i], rw.slaves[i+1:]...)
break
}
}
}
// Close 关闭所有连接
func (rw *ReadWriteDB) Close() error {
rw.mu.Lock()
defer rw.mu.Unlock()
// 关闭主库
if rw.master != nil {
if err := rw.master.Close(); err != nil {
return err
}
}
// 关闭所有从库
for _, slave := range rw.slaves {
if err := slave.Close(); err != nil {
return err
}
}
return nil
}

Some files were not shown because too many files have changed in this diff Show More