diff --git a/config/fun.go b/config/fun.go index b324fa9..185e292 100644 --- a/config/fun.go +++ b/config/fun.go @@ -54,30 +54,29 @@ func init() { }) } +func GetConfigPath() string { + return configPath +} + // SetDefault 设置默认配置信息 func SetDefault() { viper.Set("SERVER.addr", "127.0.0.1:8080") viper.Set("SERVER.mode", "release") - - // 数据库配置 - 支持多种数据库类型 - viper.Set("DATABASE.type", "sqlite") - viper.Set("DATABASE.dns", gfile.Join(gfile.Pwd(), "db", "database.db")) - viper.Set("DATABASE.debug", true) - - // 数据库连接池配置 - viper.Set("DATABASE.maxIdleConns", 10) // 最大空闲连接数 - viper.Set("DATABASE.maxOpenConns", 100) // 最大打开连接数 - viper.Set("DATABASE.connMaxLifetime", 3600) // 连接最大生命周期(秒) - - // 数据库主从配置(可选) - viper.Set("DATABASE.replicas", []string{}) // 从库列表 - viper.Set("DATABASE.readPolicy", "random") // 读负载均衡策略 - - // 时间配置 - 定义时间字段名称和格式 - viper.Set("DATABASE.timeConfig.createdAt", "created_at") - viper.Set("DATABASE.timeConfig.updatedAt", "updated_at") - viper.Set("DATABASE.timeConfig.deletedAt", "deleted_at") - viper.Set("DATABASE.timeConfig.format", "2006-01-02 15:04:05") + viper.Set("DATABASE.default.host", "127.0.0.1") + viper.Set("DATABASE.default.port", "3306") + viper.Set("DATABASE.default.user", "root") + viper.Set("DATABASE.default.pass", "123456") + viper.Set("DATABASE.default.name", "test") + viper.Set("DATABASE.default.type", "mysql") + viper.Set("DATABASE.default.role", "master") + viper.Set("DATABASE.default.debug", false) + viper.Set("DATABASE.default.prefix", "") + viper.Set("DATABASE.default.dryRun", false) + viper.Set("DATABASE.default.charset", "utf8") + viper.Set("DATABASE.default.timezone", "Local") + viper.Set("DATABASE.default.createdAt", "create_time") + viper.Set("DATABASE.default.updatedAt", "update_time") + viper.Set("DATABASE.default.timeMaintainDisabled", false) // JWT 配置 viper.Set("JWT.secret", "SET-YOUR-SECRET") diff --git a/database/base/base.go b/database/base/base.go deleted file mode 100644 index e6de47c..0000000 --- 
a/database/base/base.go +++ /dev/null @@ -1,31 +0,0 @@ -package base - -import ( - "github.com/gogf/gf/v2/os/gtime" - "gorm.io/gorm" -) - -type IdModel struct { - Id int `json:"id" gorm:"column:id;type:int(11);common:id"` -} -type TimeModel struct { - CreateTime string `json:"create_time" gorm:"column:create_time;type:varchar(255);common:创建时间"` - UpdateTime string `json:"update_time" gorm:"column:update_time;type:varchar(255);common:更新时间"` -} - -func (tm *TimeModel) BeforeCreate(scope *gorm.DB) error { - scope.Set("create_time", gtime.Datetime()) - scope.Set("update_time", gtime.Datetime()) - return nil -} - -func (tm *TimeModel) BeforeUpdate(scope *gorm.DB) error { - scope.Set("update_time", gtime.Datetime()) - return nil -} - -//func (tm *TimeModel) AfterFind(scope *gorm.DB) error { -// tm.CreateTime = gtime.New(tm.CreateTime).String() -// tm.UpdateTime = gtime.New(tm.UpdateTime).String() -// return nil -//} diff --git a/database/command/command.go b/database/command/command.go new file mode 100644 index 0000000..7cac390 --- /dev/null +++ b/database/command/command.go @@ -0,0 +1,135 @@ +// Copyright GoFrame Author(https://goframe.org). All Rights Reserved. +// +// This Source Code Form is subject to the terms of the MIT License. +// If a copy of the MIT was not distributed with this file, +// You can obtain one at https://github.com/gogf/gf. +// + +// Package command provides console operations, like options/arguments reading. +package command + +import ( + "os" + "regexp" + "strings" +) + +var ( + defaultParsedArgs = make([]string, 0) + defaultParsedOptions = make(map[string]string) + argumentOptionRegex = regexp.MustCompile(`^\-{1,2}([\w\?\.\-]+)(=){0,1}(.*)$`) +) + +// Init does custom initialization. 
+func Init(args ...string) { + if len(args) == 0 { + if len(defaultParsedArgs) == 0 && len(defaultParsedOptions) == 0 { + args = os.Args + } else { + return + } + } else { + defaultParsedArgs = make([]string, 0) + defaultParsedOptions = make(map[string]string) + } + // Parsing os.Args with default algorithm. + defaultParsedArgs, defaultParsedOptions = ParseUsingDefaultAlgorithm(args...) +} + +// ParseUsingDefaultAlgorithm parses arguments using default algorithm. +func ParseUsingDefaultAlgorithm(args ...string) (parsedArgs []string, parsedOptions map[string]string) { + parsedArgs = make([]string, 0) + parsedOptions = make(map[string]string) + for i := 0; i < len(args); { + array := argumentOptionRegex.FindStringSubmatch(args[i]) + if len(array) > 2 { + if array[2] == "=" { + parsedOptions[array[1]] = array[3] + } else if i < len(args)-1 { + if len(args[i+1]) > 0 && args[i+1][0] == '-' { + // Example: gf gen -d -n 1 + parsedOptions[array[1]] = array[3] + } else { + // Example: gf gen -n 2 + parsedOptions[array[1]] = args[i+1] + i += 2 + continue + } + } else { + // Example: gf gen -h + parsedOptions[array[1]] = array[3] + } + } else { + parsedArgs = append(parsedArgs, args[i]) + } + i++ + } + return +} + +// GetOpt returns the option value named `name`. +func GetOpt(name string, def ...string) string { + Init() + if v, ok := defaultParsedOptions[name]; ok { + return v + } + if len(def) > 0 { + return def[0] + } + return "" +} + +// GetOptAll returns all parsed options. +func GetOptAll() map[string]string { + Init() + return defaultParsedOptions +} + +// ContainsOpt checks whether option named `name` exist in the arguments. +func ContainsOpt(name string) bool { + Init() + _, ok := defaultParsedOptions[name] + return ok +} + +// GetArg returns the argument at `index`. 
+func GetArg(index int, def ...string) string { + Init() + if index < len(defaultParsedArgs) { + return defaultParsedArgs[index] + } + if len(def) > 0 { + return def[0] + } + return "" +} + +// GetArgAll returns all parsed arguments. +func GetArgAll() []string { + Init() + return defaultParsedArgs +} + +// GetOptWithEnv returns the command line argument of the specified `key`. +// If the argument does not exist, then it returns the environment variable with specified `key`. +// It returns the default value `def` if none of them exists. +// +// Fetching Rules: +// 1. Command line arguments are in lowercase format, eg: gf.package.variable; +// 2. Environment arguments are in uppercase format, eg: GF_PACKAGE_VARIABLE; +func GetOptWithEnv(key string, def ...string) string { + cmdKey := strings.ToLower(strings.ReplaceAll(key, "_", ".")) + if ContainsOpt(cmdKey) { + return GetOpt(cmdKey) + } else { + envKey := strings.ToUpper(strings.ReplaceAll(key, ".", "_")) + if r, ok := os.LookupEnv(envKey); ok { + return r + } else { + if len(def) > 0 { + return def[0] + } + } + } + return "" +} diff --git a/database/database.go b/database/database.go deleted file mode 100644 index fa4134a..0000000 --- a/database/database.go +++ /dev/null @@ -1,96 +0,0 @@ -package database - -import ( - "database/sql" - "fmt" - "time" - - "git.magicany.cc/black1552/gin-base/config" - "git.magicany.cc/black1552/gin-base/log" - "github.com/glebarez/sqlite" - "github.com/gogf/gf/v2/frame/g" - "github.com/gogf/gf/v2/os/gfile" - "gorm.io/driver/mysql" - "gorm.io/gorm" - "gorm.io/gorm/logger" - "gorm.io/gorm/schema" -) - -var ( - Type gorm.Dialector - Db *gorm.DB - err error - sqlDb *sql.DB - dns = config.GetConfigValue("database.dns", gfile.Join(gfile.Pwd(), "db", "database.db")) -) - -func init() { - if g.IsEmpty(dns) { - log.Error("gormDns 未配置", "请检查配置文件") - return - } - switch config.GetConfigValue("database.type", "sqlite").String() { - case "mysql": - log.Info("使用 mysql 数据库") - mysqlInit() - case 
"sqlite": - log.Info("使用 sqlite 数据库") - sqliteInit() - } - - // 构建 GORM 配置 - gormConfig := &gorm.Config{ - SkipDefaultTransaction: true, - NowFunc: func() time.Time { - return time.Now().Local() - }, - // 命名策略:保持与模型一致,避免字段/表名转换问题 - NamingStrategy: schema.NamingStrategy{ - SingularTable: true, // 表名禁用复数形式(例如 User 对应 user 表,而非 users) - }, - } - - // 根据配置决定是否开启 GORM 查询日志 - if config.GetConfigValue("database.debug", false).Bool() { - log.Info("已开启 GORM 查询日志") - gormConfig.Logger = logger.Default.LogMode(logger.Info) - } else { - gormConfig.Logger = logger.Default.LogMode(logger.Silent) - } - - Db, err = gorm.Open(Type, gormConfig) - if err != nil { - log.Error("数据库连接失败: ", err) - return - } - sqlDb, err = Db.DB() - if err != nil { - log.Error("获取sqlDb失败", err) - return - } - if err = sqlDb.Ping(); err != nil { - log.Error("数据库未正常连接", err) - return - } -} - -func mysqlInit() { - Type = mysql.New(mysql.Config{ - DSN: dns.String(), - DefaultStringSize: 255, // string 类型字段的默认长度 - DisableDatetimePrecision: true, // 禁用 datetime 精度,MySQL 5.6 之前的数据库不支持 - DontSupportRenameIndex: true, // 重命名索引时采用删除并新建的方式,MySQL 5.7 之前的数据库和 MariaDB 不支持重命名索引 - SkipInitializeWithVersion: false, // 根据当前 MySQL 版本自动配置 - }) -} - -func sqliteInit() { - if !gfile.Exists(dns.String()) { - _, err = gfile.Create(dns.String()) - if err != nil { - log.Error("创建数据库文件失败: ", err) - return - } - } - Type = sqlite.Open(fmt.Sprintf("%s?cache=shared&mode=rwc&_busy_timeout=10000&_fk=1&_journal=WAL&_sync=FULL", dns.String())) -} diff --git a/database/empty/empty.go b/database/empty/empty.go new file mode 100644 index 0000000..8d93967 --- /dev/null +++ b/database/empty/empty.go @@ -0,0 +1,243 @@ +// Copyright GoFrame gf Author(https://goframe.org). All Rights Reserved. +// +// This Source Code Form is subject to the terms of the MIT License. +// If a copy of the MIT was not distributed with this file, +// You can obtain one at https://github.com/gogf/gf. 
+ +// Package empty provides functions for checking empty/nil variables. +package empty + +import ( + "reflect" + "time" + + "git.magicany.cc/black1552/gin-base/database/reflection" +) + +// iString is used for type assert api for String(). +type iString interface { + String() string +} + +// iInterfaces is used for type assert api for Interfaces. +type iInterfaces interface { + Interfaces() []any +} + +// iMapStrAny is the interface support for converting struct parameter to map. +type iMapStrAny interface { + MapStrAny() map[string]any +} + +type iTime interface { + Date() (year int, month time.Month, day int) + IsZero() bool +} + +// IsEmpty checks whether given `value` empty. +// It returns true if `value` is in: 0, nil, false, "", len(slice/map/chan) == 0, +// or else it returns false. +// +// The parameter `traceSource` is used for tracing to the source variable if given `value` is type of pointer +// that also points to a pointer. It returns true if the source is empty when `traceSource` is true. +// Note that it might use reflect feature which affects performance a little. +func IsEmpty(value any, traceSource ...bool) bool { + if value == nil { + return true + } + // It firstly checks the variable as common types using assertion to enhance the performance, + // and then using reflection. 
+ switch result := value.(type) { + case int: + return result == 0 + case int8: + return result == 0 + case int16: + return result == 0 + case int32: + return result == 0 + case int64: + return result == 0 + case uint: + return result == 0 + case uint8: + return result == 0 + case uint16: + return result == 0 + case uint32: + return result == 0 + case uint64: + return result == 0 + case float32: + return result == 0 + case float64: + return result == 0 + case bool: + return !result + case string: + return result == "" + case []byte: + return len(result) == 0 + case []rune: + return len(result) == 0 + case []int: + return len(result) == 0 + case []string: + return len(result) == 0 + case []float32: + return len(result) == 0 + case []float64: + return len(result) == 0 + case map[string]any: + return len(result) == 0 + + default: + // Finally, using reflect. + var rv reflect.Value + if v, ok := value.(reflect.Value); ok { + rv = v + } else { + rv = reflect.ValueOf(value) + if IsNil(rv) { + return true + } + + // ========================= + // Common interfaces checks. 
+ // ========================= + if f, ok := value.(iTime); ok { + if f == (*time.Time)(nil) { + return true + } + return f.IsZero() + } + if f, ok := value.(iString); ok { + if f == nil { + return true + } + return f.String() == "" + } + if f, ok := value.(iInterfaces); ok { + if f == nil { + return true + } + return len(f.Interfaces()) == 0 + } + if f, ok := value.(iMapStrAny); ok { + if f == nil { + return true + } + return len(f.MapStrAny()) == 0 + } + } + + switch rv.Kind() { + case reflect.Bool: + return !rv.Bool() + + case + reflect.Int, + reflect.Int8, + reflect.Int16, + reflect.Int32, + reflect.Int64: + return rv.Int() == 0 + + case + reflect.Uint, + reflect.Uint8, + reflect.Uint16, + reflect.Uint32, + reflect.Uint64, + reflect.Uintptr: + return rv.Uint() == 0 + + case + reflect.Float32, + reflect.Float64: + return rv.Float() == 0 + + case reflect.String: + return rv.Len() == 0 + + case reflect.Struct: + var fieldValueInterface any + for i := 0; i < rv.NumField(); i++ { + fieldValueInterface, _ = reflection.ValueToInterface(rv.Field(i)) + if !IsEmpty(fieldValueInterface) { + return false + } + } + return true + + case + reflect.Chan, + reflect.Map, + reflect.Slice, + reflect.Array: + return rv.Len() == 0 + + case reflect.Pointer: + if len(traceSource) > 0 && traceSource[0] { + return IsEmpty(rv.Elem()) + } + return rv.IsNil() + + case + reflect.Func, + reflect.Interface, + reflect.UnsafePointer: + return rv.IsNil() + + case reflect.Invalid: + return true + + default: + return false + } + } +} + +// IsNil checks whether given `value` is nil, especially for any type value. +// Parameter `traceSource` is used for tracing to the source variable if given `value` is type of pointer +// that also points to a pointer. It returns nil if the source is nil when `traceSource` is true. +// Note that it might use reflect feature which affects performance a little. 
+func IsNil(value any, traceSource ...bool) bool { + if value == nil { + return true + } + var rv reflect.Value + if v, ok := value.(reflect.Value); ok { + rv = v + } else { + rv = reflect.ValueOf(value) + } + switch rv.Kind() { + case reflect.Chan, + reflect.Map, + reflect.Slice, + reflect.Func, + reflect.Interface, + reflect.UnsafePointer: + return !rv.IsValid() || rv.IsNil() + + case reflect.Pointer: + if len(traceSource) > 0 && traceSource[0] { + for rv.Kind() == reflect.Pointer { + rv = rv.Elem() + } + if !rv.IsValid() { + return true + } + if rv.Kind() == reflect.Pointer { + return rv.IsNil() + } + } else { + return !rv.IsValid() || rv.IsNil() + } + + default: + return false + } + return false +} diff --git a/database/gdb.go b/database/gdb.go new file mode 100644 index 0000000..13e87db --- /dev/null +++ b/database/gdb.go @@ -0,0 +1,1175 @@ +// Copyright GoFrame Author(https://goframe.org). All Rights Reserved. +// +// This Source Code Form is subject to the terms of the MIT License. +// If a copy of the MIT was not distributed with this file, +// You can obtain one at https://github.com/gogf/gf. + +// package database provides ORM features for popular relationship databases. + +// TODO use context.Context as required parameter for all DB operations. +package database + +import ( + "context" + "database/sql" + "time" + + "git.magicany.cc/black1552/gin-base/log" + "github.com/gogf/gf/v2/container/garray" + "github.com/gogf/gf/v2/container/gmap" + "github.com/gogf/gf/v2/container/gtype" + "github.com/gogf/gf/v2/container/gvar" + "github.com/gogf/gf/v2/errors/gcode" + "github.com/gogf/gf/v2/errors/gerror" + "github.com/gogf/gf/v2/os/gcache" + "github.com/gogf/gf/v2/os/gcmd" + "github.com/gogf/gf/v2/os/gctx" + "github.com/gogf/gf/v2/os/glog" + "github.com/gogf/gf/v2/util/grand" + "github.com/gogf/gf/v2/util/gutil" +) + +// DB defines the interfaces for ORM operations. 
+type DB interface { + // =========================================================================== + // Model creation. + // =========================================================================== + + // Model creates and returns a new ORM model from given schema. + // The parameter `table` can be more than one table names, and also alias name, like: + // 1. Model names: + // Model("user") + // Model("user u") + // Model("user, user_detail") + // Model("user u, user_detail ud") + // 2. Model name with alias: Model("user", "u") + // Also see Core.Model. + Model(tableNameOrStruct ...any) *Model + + // Raw creates and returns a model based on a raw sql not a table. + Raw(rawSql string, args ...any) *Model + + // Schema switches to a specified schema. + // Also see Core.Schema. + Schema(schema string) *Schema + + // With creates and returns an ORM model based on metadata of given object. + // Also see Core.With. + With(objects ...any) *Model + + // Open creates a raw connection object for database with given node configuration. + // Note that it is not recommended using the function manually. + Open(config *ConfigNode) (*sql.DB, error) + + // Ctx is a chaining function, which creates and returns a new DB that is a shallow copy + // of current DB object and with given context in it. + // Also see Core.Ctx. + Ctx(ctx context.Context) DB + + // Close closes the database and prevents new queries from starting. + // Close then waits for all queries that have started processing on the server + // to finish. + // + // It is rare to Close a DB, as the DB handle is meant to be + // long-lived and shared between many goroutines. + Close(ctx context.Context) error + + // =========================================================================== + // Query APIs. + // =========================================================================== + + // Query executes a SQL query that returns rows using given SQL and arguments. 
+ // The args are for any placeholder parameters in the query. + Query(ctx context.Context, sql string, args ...any) (Result, error) + + // Exec executes a SQL query that doesn't return rows (e.g., INSERT, UPDATE, DELETE). + // It returns sql.Result for accessing LastInsertId or RowsAffected. + Exec(ctx context.Context, sql string, args ...any) (sql.Result, error) + + // Prepare creates a prepared statement for later queries or executions. + // The execOnMaster parameter determines whether the statement executes on master node. + Prepare(ctx context.Context, sql string, execOnMaster ...bool) (*Stmt, error) + + // =========================================================================== + // Common APIs for CRUD. + // =========================================================================== + + // Insert inserts one or multiple records into table. + // The data can be a map, struct, or slice of maps/structs. + // The optional batch parameter specifies the batch size for bulk inserts. + Insert(ctx context.Context, table string, data any, batch ...int) (sql.Result, error) + + // InsertIgnore inserts records but ignores duplicate key errors. + // It works like Insert but adds IGNORE keyword to the SQL statement. + InsertIgnore(ctx context.Context, table string, data any, batch ...int) (sql.Result, error) + + // InsertAndGetId inserts a record and returns the auto-generated ID. + // It's a convenience method combining Insert with LastInsertId. + InsertAndGetId(ctx context.Context, table string, data any, batch ...int) (int64, error) + + // Replace inserts or replaces records using REPLACE INTO syntax. + // Existing records with same unique key will be deleted and re-inserted. + Replace(ctx context.Context, table string, data any, batch ...int) (sql.Result, error) + + // Save inserts or updates records using INSERT ... ON DUPLICATE KEY UPDATE syntax. + // It updates existing records instead of replacing them entirely. 
+ Save(ctx context.Context, table string, data any, batch ...int) (sql.Result, error) + + // Update updates records in table that match the condition. + // The data can be a map or struct containing the new values. + // The condition specifies the WHERE clause with optional placeholder args. + Update(ctx context.Context, table string, data any, condition any, args ...any) (sql.Result, error) + + // Delete deletes records from table that match the condition. + // The condition specifies the WHERE clause with optional placeholder args. + Delete(ctx context.Context, table string, condition any, args ...any) (sql.Result, error) + + // =========================================================================== + // Internal APIs for CRUD, which can be overwritten by custom CRUD implements. + // =========================================================================== + + // DoSelect executes a SELECT query using the given link and returns the result. + // This is an internal method that can be overridden by custom implementations. + DoSelect(ctx context.Context, link Link, sql string, args ...any) (result Result, err error) + + // DoInsert performs the actual INSERT operation with given options. + // This is an internal method that can be overridden by custom implementations. + DoInsert(ctx context.Context, link Link, table string, data List, option DoInsertOption) (result sql.Result, err error) + + // DoUpdate performs the actual UPDATE operation. + // This is an internal method that can be overridden by custom implementations. + DoUpdate(ctx context.Context, link Link, table string, data any, condition string, args ...any) (result sql.Result, err error) + + // DoDelete performs the actual DELETE operation. + // This is an internal method that can be overridden by custom implementations. + DoDelete(ctx context.Context, link Link, table string, condition string, args ...any) (result sql.Result, err error) + + // DoQuery executes a query that returns rows. 
+ // This is an internal method that can be overridden by custom implementations. + DoQuery(ctx context.Context, link Link, sql string, args ...any) (result Result, err error) + + // DoExec executes a query that doesn't return rows. + // This is an internal method that can be overridden by custom implementations. + DoExec(ctx context.Context, link Link, sql string, args ...any) (result sql.Result, err error) + + // DoFilter processes and filters SQL and args before execution. + // This is an internal method that can be overridden to implement custom SQL filtering. + DoFilter(ctx context.Context, link Link, sql string, args []any) (newSql string, newArgs []any, err error) + + // DoCommit handles the actual commit operation for transactions. + // This is an internal method that can be overridden by custom implementations. + DoCommit(ctx context.Context, in DoCommitInput) (out DoCommitOutput, err error) + + // DoPrepare creates a prepared statement on the given link. + // This is an internal method that can be overridden by custom implementations. + DoPrepare(ctx context.Context, link Link, sql string) (*Stmt, error) + + // =========================================================================== + // Query APIs for convenience purpose. + // =========================================================================== + + // GetAll executes a query and returns all rows as Result. + // It's a convenience wrapper around Query. + GetAll(ctx context.Context, sql string, args ...any) (Result, error) + + // GetOne executes a query and returns the first row as Record. + // It's useful when you expect only one row to be returned. + GetOne(ctx context.Context, sql string, args ...any) (Record, error) + + // GetValue executes a query and returns the first column of the first row. + // It's useful for queries like SELECT COUNT(*) or getting a single value. 
+ GetValue(ctx context.Context, sql string, args ...any) (Value, error) + + // GetArray executes a query and returns the first column of all rows. + // It's useful for queries like SELECT id FROM table. + GetArray(ctx context.Context, sql string, args ...any) (Array, error) + + // GetCount executes a COUNT query and returns the result as an integer. + // It's a convenience method for counting rows. + GetCount(ctx context.Context, sql string, args ...any) (int, error) + + // GetScan executes a query and scans the result into the given object pointer. + // It automatically maps database columns to struct fields or slice elements. + GetScan(ctx context.Context, objPointer any, sql string, args ...any) error + + // Union combines multiple SELECT queries using UNION operator. + // It returns a new Model that represents the combined query. + Union(unions ...*Model) *Model + + // UnionAll combines multiple SELECT queries using UNION ALL operator. + // Unlike Union, it keeps duplicate rows in the result. + UnionAll(unions ...*Model) *Model + + // =========================================================================== + // Master/Slave specification support. + // =========================================================================== + + // Master returns a connection to the master database node. + // The optional schema parameter specifies which database schema to use. + Master(schema ...string) (*sql.DB, error) + + // Slave returns a connection to a slave database node. + // The optional schema parameter specifies which database schema to use. + Slave(schema ...string) (*sql.DB, error) + + // =========================================================================== + // Ping-Pong. + // =========================================================================== + + // PingMaster checks if the master database node is accessible. + // It returns an error if the connection fails. + PingMaster() error + + // PingSlave checks if any slave database node is accessible. 
+ // It returns an error if no slave connections are available. + PingSlave() error + + // =========================================================================== + // Transaction. + // =========================================================================== + + // Begin starts a new transaction and returns a TX interface. + // The returned TX must be committed or rolled back to release resources. + Begin(ctx context.Context) (TX, error) + + // BeginWithOptions starts a new transaction with the given options and returns a TX interface. + // The options allow specifying isolation level and read-only mode. + // The returned TX must be committed or rolled back to release resources. + BeginWithOptions(ctx context.Context, opts TxOptions) (TX, error) + + // Transaction executes a function within a transaction. + // It automatically handles commit/rollback based on whether f returns an error. + Transaction(ctx context.Context, f func(ctx context.Context, tx TX) error) error + + // TransactionWithOptions executes a function within a transaction with specific options. + // It allows customizing transaction behavior like isolation level and timeout. + TransactionWithOptions(ctx context.Context, opts TxOptions, f func(ctx context.Context, tx TX) error) error + + // =========================================================================== + // Configuration methods. + // =========================================================================== + + // GetCache returns the cache instance used by this database. + // The cache is used for query results caching. + GetCache() *gcache.Cache + + // SetDebug enables or disables debug mode for SQL logging. + // When enabled, all SQL statements and their execution time are logged. + SetDebug(debug bool) + + // GetDebug returns whether debug mode is enabled. + GetDebug() bool + + // GetSchema returns the current database schema name. + GetSchema() string + + // GetPrefix returns the table name prefix used by this database. 
+ GetPrefix() string + + // GetGroup returns the configuration group name of this database. + GetGroup() string + + // SetDryRun enables or disables dry-run mode. + // In dry-run mode, SQL statements are generated but not executed. + SetDryRun(enabled bool) + + // GetDryRun returns whether dry-run mode is enabled. + GetDryRun() bool + + // SetLogger sets a custom logger for database operations. + // The logger must implement log.ILogger interface. + SetLogger(logger log.ILogger) + + // GetLogger returns the current logger used by this database. + GetLogger() log.ILogger + + // GetConfig returns the configuration node used by this database. + GetConfig() *ConfigNode + + // SetMaxIdleConnCount sets the maximum number of idle connections in the pool. + SetMaxIdleConnCount(n int) + + // SetMaxOpenConnCount sets the maximum number of open connections to the database. + SetMaxOpenConnCount(n int) + + // SetMaxConnLifeTime sets the maximum amount of time a connection may be reused. + SetMaxConnLifeTime(d time.Duration) + + // SetMaxIdleConnTime sets the maximum amount of time a connection may be idle before being closed. + SetMaxIdleConnTime(d time.Duration) + + // =========================================================================== + // Utility methods. + // =========================================================================== + + // Stats returns statistics about the database connection pool. + // It includes information like the number of active and idle connections. + Stats(ctx context.Context) []StatsItem + + // GetCtx returns the context associated with this database instance. + GetCtx() context.Context + + // GetCore returns the underlying Core instance of this database. + GetCore() *Core + + // GetChars returns the left and right quote characters used for escaping identifiers. + // For example, in MySQL these are backticks: ` and `. 
+ GetChars() (charLeft string, charRight string) + + // Tables returns a list of all table names in the specified schema. + // If no schema is specified, it uses the default schema. + Tables(ctx context.Context, schema ...string) (tables []string, err error) + + // TableFields returns detailed information about all fields in the specified table. + // The returned map keys are field names and values contain field metadata. + TableFields(ctx context.Context, table string, schema ...string) (map[string]*TableField, error) + + // ConvertValueForField converts a value to the appropriate type for a database field. + // It handles type conversion from Go types to database-specific types. + ConvertValueForField(ctx context.Context, fieldType string, fieldValue any) (any, error) + + // ConvertValueForLocal converts a database value to the appropriate Go type. + // It handles type conversion from database-specific types to Go types. + ConvertValueForLocal(ctx context.Context, fieldType string, fieldValue any) (any, error) + + // GetFormattedDBTypeNameForField returns the formatted database type name and pattern for a field type. + GetFormattedDBTypeNameForField(fieldType string) (typeName, typePattern string) + + // CheckLocalTypeForField checks if a Go value is compatible with a database field type. + // It returns the appropriate LocalType and any conversion errors. + CheckLocalTypeForField(ctx context.Context, fieldType string, fieldValue any) (LocalType, error) + + // FormatUpsert formats an upsert (INSERT ... ON DUPLICATE KEY UPDATE) statement. + // It generates the appropriate SQL based on the columns, values, and options provided. + FormatUpsert(columns []string, list List, option DoInsertOption) (string, error) + + // OrderRandomFunction returns the SQL function for random ordering. + // The implementation is database-specific (e.g., RAND() for MySQL). + OrderRandomFunction() string +} + +// TX defines the interfaces for ORM transaction operations. 
+type TX interface { + Link + + // Ctx binds a context to current transaction. + // The context is used for operations like timeout control. + Ctx(ctx context.Context) TX + + // Raw creates and returns a model based on a raw SQL. + // The rawSql can contain placeholders ? and corresponding args. + Raw(rawSql string, args ...any) *Model + + // Model creates and returns a Model from given table name/struct. + // The parameter can be table name as string, or struct/*struct type. + Model(tableNameQueryOrStruct ...any) *Model + + // With creates and returns a Model from given object. + // It automatically analyzes the object and generates corresponding SQL. + With(object any) *Model + + // =========================================================================== + // Nested transaction if necessary. + // =========================================================================== + + // Begin starts a nested transaction. + // It creates a new savepoint for current transaction. + Begin() error + + // Commit commits current transaction/savepoint. + // For nested transactions, it releases the current savepoint. + Commit() error + + // Rollback rolls back current transaction/savepoint. + // For nested transactions, it rolls back to the current savepoint. + Rollback() error + + // Transaction executes given function in a nested transaction. + // It automatically handles commit/rollback based on function's error return. + Transaction(ctx context.Context, f func(ctx context.Context, tx TX) error) (err error) + + // TransactionWithOptions executes given function in a nested transaction with options. + // It allows customizing transaction behavior like isolation level. + TransactionWithOptions(ctx context.Context, opts TxOptions, f func(ctx context.Context, tx TX) error) error + + // =========================================================================== + // Core method. 
+ // =========================================================================== + + // Query executes a query that returns rows using given SQL and arguments. + // The args are for any placeholder parameters in the query. + Query(sql string, args ...any) (result Result, err error) + + // Exec executes a query that doesn't return rows. + // For example: INSERT, UPDATE, DELETE. + Exec(sql string, args ...any) (sql.Result, error) + + // Prepare creates a prepared statement for later queries or executions. + // Multiple queries or executions may be run concurrently from the statement. + Prepare(sql string) (*Stmt, error) + + // =========================================================================== + // Query. + // =========================================================================== + + // GetAll executes a query and returns all rows as Result. + // It's a convenient wrapper for Query. + GetAll(sql string, args ...any) (Result, error) + + // GetOne executes a query and returns the first row as Record. + // It's useful when you expect only one row to be returned. + GetOne(sql string, args ...any) (Record, error) + + // GetStruct executes a query and scans the result into given struct. + // The obj should be a pointer to struct. + GetStruct(obj any, sql string, args ...any) error + + // GetStructs executes a query and scans all results into given struct slice. + // The objPointerSlice should be a pointer to slice of struct. + GetStructs(objPointerSlice any, sql string, args ...any) error + + // GetScan executes a query and scans the result into given variables. + // The pointer can be type of struct/*struct/[]struct/[]*struct. + GetScan(pointer any, sql string, args ...any) error + + // GetValue executes a query and returns the first column of first row. + // It's useful for queries like SELECT COUNT(*). + GetValue(sql string, args ...any) (Value, error) + + // GetCount executes a query that should return a count value. 
+ // It's a convenient wrapper for count queries. + GetCount(sql string, args ...any) (int64, error) + + // =========================================================================== + // CRUD. + // =========================================================================== + + // Insert inserts one or multiple records into table. + // The data can be map/struct/*struct/[]map/[]struct/[]*struct. + Insert(table string, data any, batch ...int) (sql.Result, error) + + // InsertIgnore inserts one or multiple records with IGNORE option. + // It ignores records that would cause duplicate key conflicts. + InsertIgnore(table string, data any, batch ...int) (sql.Result, error) + + // InsertAndGetId inserts one record and returns its id value. + // It's commonly used with auto-increment primary key. + InsertAndGetId(table string, data any, batch ...int) (int64, error) + + // Replace inserts or replaces records using REPLACE INTO syntax. + // Existing records with same unique key will be deleted and re-inserted. + Replace(table string, data any, batch ...int) (sql.Result, error) + + // Save inserts or updates records using INSERT ... ON DUPLICATE KEY UPDATE syntax. + // It updates existing records instead of replacing them entirely. + Save(table string, data any, batch ...int) (sql.Result, error) + + // Update updates records in table that match given condition. + // The data can be map/struct, and condition supports various formats. + Update(table string, data any, condition any, args ...any) (sql.Result, error) + + // Delete deletes records from table that match given condition. + // The condition supports various formats with optional arguments. + Delete(table string, condition any, args ...any) (sql.Result, error) + + // =========================================================================== + // Utility methods. + // =========================================================================== + + // GetCtx returns the context that is bound to current transaction. 
+ GetCtx() context.Context + + // GetDB returns the underlying DB interface object. + GetDB() DB + + // GetSqlTX returns the underlying *sql.Tx object. + // Note: be very careful when using this method. + GetSqlTX() *sql.Tx + + // IsClosed checks if current transaction is closed. + // A transaction is closed after Commit or Rollback. + IsClosed() bool + + // =========================================================================== + // Save point feature. + // =========================================================================== + + // SavePoint creates a save point with given name. + // It's used in nested transactions to create rollback points. + SavePoint(point string) error + + // RollbackTo rolls back transaction to previously created save point. + // If the save point doesn't exist, it returns an error. + RollbackTo(point string) error +} + +// StatsItem defines the stats information for a configuration node. +type StatsItem interface { + // Node returns the configuration node info. + Node() ConfigNode + + // Stats returns the connection stat for current node. + Stats() sql.DBStats +} + +// Core is the base struct for database management. +type Core struct { + db DB // DB interface object. + ctx context.Context // Context for chaining operation only. Do not set a default value in Core initialization. + group string // Configuration group name. + schema string // Custom schema for this object. + debug *gtype.Bool // Enable debug mode for the database, which can be changed in runtime. + cache *gcache.Cache // Cache manager, SQL result cache only. + links *gmap.KVMap[ConfigNode, *sql.DB] // links caches all created links by node. + logger log.ILogger // Logger for logging functionality. + config *ConfigNode // Current config node. + localTypeMap *gmap.StrAnyMap // Local type map for database field type conversion. + dynamicConfig dynamicConfig // Dynamic configurations, which can be changed in runtime. 
+ innerMemCache *gcache.Cache // Internal memory cache for storing temporary data. +} + +type dynamicConfig struct { + MaxIdleConnCount int + MaxOpenConnCount int + MaxConnLifeTime time.Duration + MaxIdleConnTime time.Duration +} + +// DoCommitInput is the input parameters for function DoCommit. +type DoCommitInput struct { + // Db is the underlying database connection object. + Db *sql.DB + + // Tx is the underlying transaction object. + Tx *sql.Tx + + // Stmt is the prepared statement object. + Stmt *sql.Stmt + + // Link is the common database function wrapper interface. + Link Link + + // Sql is the SQL string to be executed. + Sql string + + // Args is the arguments for SQL placeholders. + Args []any + + // Type indicates the type of SQL operation. + Type SqlType + + // TxOptions specifies the transaction options. + TxOptions sql.TxOptions + + // TxCancelFunc is the context cancel function for transaction. + TxCancelFunc context.CancelFunc + + // IsTransaction indicates whether current operation is in transaction. + IsTransaction bool +} + +// DoCommitOutput is the output parameters for function DoCommit. +type DoCommitOutput struct { + // Result is the result of exec statement. + Result sql.Result + + // Records is the result of query statement. + Records []Record + + // Stmt is the Statement object result for Prepare. + Stmt *Stmt + + // Tx is the transaction object result for Begin. + Tx TX + + // RawResult is the underlying result, which might be sql.Result/*sql.Rows/*sql.Row. + RawResult any +} + +// Driver is the interface for integrating sql drivers into package database. +type Driver interface { + // New creates and returns a database object for specified database server. + New(core *Core, node *ConfigNode) (DB, error) +} + +// Link is a common database function wrapper interface. +// Note that, any operation using `Link` will have no SQL logging. 
+type Link interface { + QueryContext(ctx context.Context, sql string, args ...any) (*sql.Rows, error) + ExecContext(ctx context.Context, sql string, args ...any) (sql.Result, error) + PrepareContext(ctx context.Context, sql string) (*sql.Stmt, error) + IsOnMaster() bool + IsTransaction() bool +} + +// Sql is the sql recording struct. +type Sql struct { + Sql string // SQL string(may contain reserved char '?'). + Type SqlType // SQL operation type. + Args []any // Arguments for this sql. + Format string // Formatted sql which contains arguments in the sql. + Error error // Execution result. + Start int64 // Start execution timestamp in milliseconds. + End int64 // End execution timestamp in milliseconds. + Group string // Group is the group name of the configuration that the sql is executed from. + Schema string // Schema is the schema name of the configuration that the sql is executed from. + IsTransaction bool // IsTransaction marks whether this sql is executed in transaction. + RowsAffected int64 // RowsAffected marks retrieved or affected number with current sql statement. +} + +// DoInsertOption is the input struct for function DoInsert. +type DoInsertOption struct { + // OnDuplicateStr is the custom string for `on duplicated` statement. + OnDuplicateStr string + + // OnDuplicateMap is the custom key-value map from `OnDuplicateEx` function for `on duplicated` statement. + OnDuplicateMap map[string]any + + // OnConflict is the custom conflict key of upsert clause, if the database needs it. + OnConflict []string + + // InsertOption is the insert operation in constant value. + InsertOption InsertOption + + // BatchCount is the batch count for batch inserting. + BatchCount int +} + +// TableField is the struct for table field. +type TableField struct { + // Index is for ordering purpose as map is unordered. + Index int + + // Name is the field name. + Name string + + // Type is the field type. Eg: 'int(10) unsigned', 'varchar(64)'. 
+ Type string + + // Null is whether the field can be null or not. + Null bool + + // Key is the index information(empty if it's not an index). Eg: PRI, MUL. + Key string + + // Default is the default value for the field. + Default any + + // Extra is the extra information. Eg: auto_increment. + Extra string + + // Comment is the field comment. + Comment string +} + +// Counter is the type for update count. +type Counter struct { + // Field is the field name. + Field string + + // Value is the value. + Value float64 +} + +type ( + // Raw is a raw sql that will not be treated as argument but as a direct sql part. + Raw string + + // Value is the field value type. + Value = *gvar.Var + + // Array is the field value array type. + Array = gvar.Vars + + // Record is the row record of the table. + Record map[string]Value + + // Result is the row record array. + Result []Record + + // Map is alias of map[string]any, which is the most common usage map type. + Map = map[string]any + + // List is type of map array. + List = []Map +) + +type CatchSQLManager struct { + // SQLArray is the array of sql. + SQLArray *garray.StrArray + + // DoCommit marks it will be committed to underlying driver or not. + DoCommit bool +} + +const ( + defaultModelSafe = false + defaultCharset = `utf8` + defaultProtocol = `tcp` + unionTypeNormal = 0 + unionTypeAll = 1 + defaultMaxIdleConnCount = 10 // Max idle connection count in pool. + defaultMaxOpenConnCount = 0 // Max open connection count in pool. Default is no limit. + defaultMaxConnLifeTime = 30 * time.Second // Max lifetime for per connection in pool in seconds. 
+ cachePrefixTableFields = `TableFields:` + cachePrefixSelectCache = `SelectCache:` + commandEnvKeyForDryRun = "gf.gdb.dryrun" + modelForDaoSuffix = `ForDao` + dbRoleSlave = `slave` + ctxKeyForDB gctx.StrKey = `CtxKeyForDB` + ctxKeyCatchSQL gctx.StrKey = `CtxKeyCatchSQL` + ctxKeyInternalProducedSQL gctx.StrKey = `CtxKeyInternalProducedSQL` + + linkPattern = `^(\w+):(.*?):(.*?)@(\w+?)\((.+?)\)/{0,1}([^\?]*)\?{0,1}(.*?)$` + linkPatternDescription = `type:username:password@protocol(host:port)/dbname?param1=value1&...&paramN=valueN` +) + +// Context key types to avoid collisions +type ctxKey string + +const ( + ctxKeyWrappedByGetCtxTimeout ctxKey = "WrappedByGetCtxTimeout" +) + +type ctxTimeoutType int + +const ( + ctxTimeoutTypeExec ctxTimeoutType = iota + ctxTimeoutTypeQuery + ctxTimeoutTypePrepare + ctxTimeoutTypeTrans +) + +type SelectType int + +const ( + SelectTypeDefault SelectType = iota + SelectTypeCount + SelectTypeValue + SelectTypeArray +) + +type joinOperator string + +const ( + joinOperatorLeft joinOperator = "LEFT" + joinOperatorRight joinOperator = "RIGHT" + joinOperatorInner joinOperator = "INNER" +) + +type InsertOption int + +const ( + InsertOptionDefault InsertOption = iota + InsertOptionReplace + InsertOptionSave + InsertOptionIgnore +) + +const ( + InsertOperationInsert = "INSERT" + InsertOperationReplace = "REPLACE" + InsertOperationIgnore = "INSERT IGNORE" + InsertOnDuplicateKeyUpdate = "ON DUPLICATE KEY UPDATE" +) + +type SqlType string + +const ( + SqlTypeBegin SqlType = "DB.Begin" + SqlTypeTXCommit SqlType = "TX.Commit" + SqlTypeTXRollback SqlType = "TX.Rollback" + SqlTypeExecContext SqlType = "DB.ExecContext" + SqlTypeQueryContext SqlType = "DB.QueryContext" + SqlTypePrepareContext SqlType = "DB.PrepareContext" + SqlTypeStmtExecContext SqlType = "DB.Statement.ExecContext" + SqlTypeStmtQueryContext SqlType = "DB.Statement.QueryContext" + SqlTypeStmtQueryRowContext SqlType = "DB.Statement.QueryRowContext" +) + +// LocalType is a type that defines 
the local storage type of a field value. +// It is used to specify how the field value should be processed locally. +type LocalType string + +const ( + LocalTypeUndefined LocalType = "" + LocalTypeString LocalType = "string" + LocalTypeTime LocalType = "time" + LocalTypeDate LocalType = "date" + LocalTypeDatetime LocalType = "datetime" + LocalTypeInt LocalType = "int" + LocalTypeUint LocalType = "uint" + LocalTypeInt32 LocalType = "int32" + LocalTypeUint32 LocalType = "uint32" + LocalTypeInt64 LocalType = "int64" + LocalTypeUint64 LocalType = "uint64" + LocalTypeBigInt LocalType = "bigint" + LocalTypeIntSlice LocalType = "[]int" + LocalTypeUintSlice LocalType = "[]uint" + LocalTypeInt32Slice LocalType = "[]int32" + LocalTypeUint32Slice LocalType = "[]uint32" + LocalTypeInt64Slice LocalType = "[]int64" + LocalTypeUint64Slice LocalType = "[]uint64" + LocalTypeStringSlice LocalType = "[]string" + LocalTypeInt64Bytes LocalType = "int64-bytes" + LocalTypeUint64Bytes LocalType = "uint64-bytes" + LocalTypeFloat32 LocalType = "float32" + LocalTypeFloat64 LocalType = "float64" + LocalTypeFloat32Slice LocalType = "[]float32" + LocalTypeFloat64Slice LocalType = "[]float64" + LocalTypeBytes LocalType = "[]byte" + LocalTypeBytesSlice LocalType = "[][]byte" + LocalTypeBool LocalType = "bool" + LocalTypeBoolSlice LocalType = "[]bool" + LocalTypeJson LocalType = "json" + LocalTypeJsonb LocalType = "jsonb" + LocalTypeUUID LocalType = "uuid.UUID" + LocalTypeUUIDSlice LocalType = "[]uuid.UUID" +) + +const ( + fieldTypeBinary = "binary" + fieldTypeVarbinary = "varbinary" + fieldTypeBlob = "blob" + fieldTypeTinyblob = "tinyblob" + fieldTypeMediumblob = "mediumblob" + fieldTypeLongblob = "longblob" + fieldTypeInt = "int" + fieldTypeTinyint = "tinyint" + fieldTypeSmallInt = "small_int" + fieldTypeSmallint = "smallint" + fieldTypeMediumInt = "medium_int" + fieldTypeMediumint = "mediumint" + fieldTypeSerial = "serial" + fieldTypeBigInt = "big_int" + fieldTypeBigint = "bigint" + 
fieldTypeBigserial = "bigserial" + fieldTypeInt128 = "int128" + fieldTypeInt256 = "int256" + fieldTypeUint128 = "uint128" + fieldTypeUint256 = "uint256" + fieldTypeReal = "real" + fieldTypeFloat = "float" + fieldTypeDouble = "double" + fieldTypeDecimal = "decimal" + fieldTypeMoney = "money" + fieldTypeNumeric = "numeric" + fieldTypeSmallmoney = "smallmoney" + fieldTypeBool = "bool" + fieldTypeBit = "bit" + fieldTypeYear = "year" // YYYY + fieldTypeDate = "date" // YYYY-MM-DD + fieldTypeTime = "time" // HH:MM:SS + fieldTypeDatetime = "datetime" // YYYY-MM-DD HH:MM:SS + fieldTypeTimestamp = "timestamp" // YYYYMMDD HHMMSS + fieldTypeTimestampz = "timestamptz" + fieldTypeJson = "json" + fieldTypeJsonb = "jsonb" +) + +var ( + // checker is the checker function for instances map. + checker = func(v DB) bool { return v == nil } + // instances is the management map for instances. + instances = gmap.NewKVMapWithChecker[string, DB](checker, true) + + // driverMap manages all custom registered driver. + driverMap = map[string]Driver{} + + // lastOperatorRegPattern is the regular expression pattern for a string + // which has operator at its tail. + lastOperatorRegPattern = `[<>=]+\s*$` + + // regularFieldNameRegPattern is the regular expression pattern for a string + // which is a regular field name of table. + regularFieldNameRegPattern = `^[\w\.\-]+$` + + // regularFieldNameWithCommaRegPattern is the regular expression pattern for one or more strings + // which are regular field names of table, multiple field names joined with char ','. + regularFieldNameWithCommaRegPattern = `^[\w\.\-,\s]+$` + + // regularFieldNameWithoutDotRegPattern is similar to regularFieldNameRegPattern but not allows '.'. + // Note that, although some databases allow char '.' in the field name, but it here does not allow '.' + // in the field name as it conflicts with "db.table.field" pattern in SOME situations. 
+ regularFieldNameWithoutDotRegPattern = `^[\w\-]+$` + + // allDryRun sets dry-run feature for all database connections. + // It is commonly used for command options for convenience. + allDryRun = false +) + +func init() { + // allDryRun is initialized from environment or command options. + allDryRun = gcmd.GetOptWithEnv(commandEnvKeyForDryRun, false).Bool() +} + +// Register registers custom database driver to gdb. +func Register(name string, driver Driver) error { + driverMap[name] = newDriverWrapper(driver) + return nil +} + +// New creates and returns an ORM object with given configuration node. +func New(node ConfigNode) (db DB, err error) { + return newDBByConfigNode(&node, "") +} + +// NewByGroup creates and returns an ORM object with global configurations. +// The parameter `name` specifies the configuration group name, +// which is DefaultGroupName in default. +func NewByGroup(group ...string) (db DB, err error) { + groupName := configs.group + if len(group) > 0 && group[0] != "" { + groupName = group[0] + } + configs.RLock() + defer configs.RUnlock() + + if len(configs.config) < 1 { + return nil, gerror.NewCode( + gcode.CodeInvalidConfiguration, + "database configuration is empty, please set the database configuration before using", + ) + } + if _, ok := configs.config[groupName]; ok { + var node *ConfigNode + if node, err = getConfigNodeByGroup(groupName, true); err == nil { + return newDBByConfigNode(node, groupName) + } + return nil, err + } + return nil, gerror.NewCodef( + gcode.CodeInvalidConfiguration, + `database configuration node "%s" is not found, did you misspell group name "%s" or miss the database configuration?`, + groupName, groupName, + ) +} + +// linksChecker is the checker function for links map. +var linksChecker = func(v *sql.DB) bool { return v == nil } + +// newDBByConfigNode creates and returns an ORM object with given configuration node and group name. 
+// +// Very Note: +// The parameter `node` is used for DB creation, not for underlying connection creation. +// So all db type configurations in the same group should be the same. +func newDBByConfigNode(node *ConfigNode, group string) (db DB, err error) { + if node.Link != "" { + node, err = parseConfigNodeLink(node) + if err != nil { + return + } + } + c := &Core{ + group: group, + debug: gtype.NewBool(), + cache: gcache.New(), + links: gmap.NewKVMapWithChecker[ConfigNode, *sql.DB](linksChecker, true), + logger: glog.New(), + config: node, + localTypeMap: gmap.NewStrAnyMap(true), + innerMemCache: gcache.New(), + dynamicConfig: dynamicConfig{ + MaxIdleConnCount: node.MaxIdleConnCount, + MaxOpenConnCount: node.MaxOpenConnCount, + MaxConnLifeTime: node.MaxConnLifeTime, + MaxIdleConnTime: node.MaxIdleConnTime, + }, + } + if v, ok := driverMap[node.Type]; ok { + if c.db, err = v.New(c, node); err != nil { + return nil, err + } + return c.db, nil + } + errorMsg := `cannot find database driver for specified database type "%s"` + errorMsg += `, did you misspell type name "%s" or forget importing the database driver? ` + errorMsg += `possible reference: https://github.com/gogf/gf/tree/master/contrib/drivers` + return nil, gerror.NewCodef(gcode.CodeInvalidConfiguration, errorMsg, node.Type, node.Type) +} + +// Instance returns an instance for DB operations. +// The parameter `name` specifies the configuration group name, +// which is DefaultGroupName in default. +func Instance(name ...string) (db DB, err error) { + group := configs.group + if len(name) > 0 && name[0] != "" { + group = name[0] + } + v := instances.GetOrSetFuncLock(group, func() DB { + db, err = NewByGroup(group) + return db + }) + if v != nil { + return v, nil + } + return nil, err +} + +// getConfigNodeByGroup calculates and returns a configuration node of given group. It +// calculates the value internally using weight algorithm for load balance. 
+// +// The returned node is a clone of configuration node, which is safe for later modification. +// +// The parameter `master` specifies whether retrieving a master node, or else a slave node +// if master-slave nodes are configured. +func getConfigNodeByGroup(group string, master bool) (*ConfigNode, error) { + if list, ok := configs.config[group]; ok { + // Separates master and slave configuration nodes array. + var ( + masterList = make(ConfigGroup, 0) + slaveList = make(ConfigGroup, 0) + ) + for i := 0; i < len(list); i++ { + if list[i].Role == dbRoleSlave { + slaveList = append(slaveList, list[i]) + } else { + masterList = append(masterList, list[i]) + } + } + if len(masterList) < 1 { + return nil, gerror.NewCode( + gcode.CodeInvalidConfiguration, + "at least one master node configuration's need to make sense", + ) + } + if len(slaveList) < 1 { + slaveList = masterList + } + if master { + return getConfigNodeByWeight(masterList), nil + } else { + return getConfigNodeByWeight(slaveList), nil + } + } + return nil, gerror.NewCodef( + gcode.CodeInvalidConfiguration, + "empty database configuration for item name '%s'", + group, + ) +} + +// getConfigNodeByWeight calculates the configuration weights and randomly returns a node. +// The returned node is a clone of configuration node, which is safe for later modification. +// +// Calculation algorithm brief: +// 1. If we have 2 nodes, and their weights are both 1, then the weight range is [0, 199]; +// 2. Node1 weight range is [0, 99], and node2 weight range is [100, 199], ratio is 1:1; +// 3. If the random number is 99, it then chooses and returns node1;. +func getConfigNodeByWeight(cg ConfigGroup) *ConfigNode { + if len(cg) < 2 { + return &cg[0] + } + var total int + for i := 0; i < len(cg); i++ { + total += cg[i].Weight * 100 + } + // If total is 0 means all the nodes have no weight attribute configured. + // It then defaults each node's weight attribute to 1. 
+ if total == 0 { + for i := 0; i < len(cg); i++ { + cg[i].Weight = 1 + total += cg[i].Weight * 100 + } + } + // Exclude the right border value. + var ( + minWeight = 0 + maxWeight = 0 + random = grand.N(0, total-1) + ) + for i := 0; i < len(cg); i++ { + maxWeight = minWeight + cg[i].Weight*100 + if random >= minWeight && random < maxWeight { + // ==================================================== + // Return a COPY of the ConfigNode. + // ==================================================== + node := ConfigNode{} + node = cg[i] + return &node + } + minWeight = maxWeight + } + return nil +} + +// getSqlDb retrieves and returns an underlying database connection object. +// The parameter `master` specifies whether retrieves master node connection if +// master-slave nodes are configured. +func (c *Core) getSqlDb(master bool, schema ...string) (sqlDb *sql.DB, err error) { + var ( + node *ConfigNode + ctx = c.db.GetCtx() + ) + if c.group != "" { + // Load balance. + configs.RLock() + defer configs.RUnlock() + // Value COPY for node. + // The returned node is a clone of configuration node, which is safe for later modification. + node, err = getConfigNodeByGroup(c.group, master) + if err != nil { + return nil, err + } + } else { + // Value COPY for node. + n := *c.db.GetConfig() + node = &n + } + if node.Charset == "" { + node.Charset = defaultCharset + } + // Changes the schema. + nodeSchema := gutil.GetOrDefaultStr(c.schema, schema...) + if nodeSchema != "" { + node.Name = nodeSchema + } + // Update the configuration object in internal data. + if err = c.setConfigNodeToCtx(ctx, node); err != nil { + return + } + + // Cache the underlying connection pool object by node. 
+ var ( + instanceCacheFunc = func() *sql.DB { + if sqlDb, err = c.db.Open(node); err != nil { + return nil + } + if sqlDb == nil { + return nil + } + if c.dynamicConfig.MaxIdleConnCount > 0 { + sqlDb.SetMaxIdleConns(c.dynamicConfig.MaxIdleConnCount) + } else { + sqlDb.SetMaxIdleConns(defaultMaxIdleConnCount) + } + if c.dynamicConfig.MaxOpenConnCount > 0 { + sqlDb.SetMaxOpenConns(c.dynamicConfig.MaxOpenConnCount) + } else { + sqlDb.SetMaxOpenConns(defaultMaxOpenConnCount) + } + if c.dynamicConfig.MaxConnLifeTime > 0 { + sqlDb.SetConnMaxLifetime(c.dynamicConfig.MaxConnLifeTime) + } else { + sqlDb.SetConnMaxLifetime(defaultMaxConnLifeTime) + } + if c.dynamicConfig.MaxIdleConnTime > 0 { + sqlDb.SetConnMaxIdleTime(c.dynamicConfig.MaxIdleConnTime) + } + return sqlDb + } + // it here uses NODE VALUE not pointer as the cache key, in case of oracle ORA-12516 error. + instanceValue = c.links.GetOrSetFuncLock(*node, instanceCacheFunc) + ) + if instanceValue != nil && sqlDb == nil { + // It reads from instance map. + sqlDb = instanceValue + } + if node.Debug { + c.db.SetDebug(node.Debug) + } + if node.DryRun { + c.db.SetDryRun(node.DryRun) + } + return +} diff --git a/database/gdb_converter.go b/database/gdb_converter.go new file mode 100644 index 0000000..129102b --- /dev/null +++ b/database/gdb_converter.go @@ -0,0 +1,82 @@ +// Copyright GoFrame Author(https://goframe.org). All Rights Reserved. +// +// This Source Code Form is subject to the terms of the MIT License. +// If a copy of the MIT was not distributed with this file, +// You can obtain one at https://github.com/gogf/gf. + +package database + +import ( + "reflect" + + "git.magicany.cc/black1552/gin-base/database/json" + "github.com/gogf/gf/v2/errors/gerror" + "github.com/gogf/gf/v2/util/gconv" +) + +// iVal is used for type assert api for Val(). +type iVal interface { + Val() any +} + +var ( + // converter is the internal type converter for gdb. 
+ converter = gconv.NewConverter() +) + +func init() { + converter.RegisterAnyConverterFunc( + sliceTypeConverterFunc, + reflect.TypeOf([]string{}), + reflect.TypeOf([]float32{}), + reflect.TypeOf([]float64{}), + reflect.TypeOf([]int{}), + reflect.TypeOf([]int32{}), + reflect.TypeOf([]int64{}), + reflect.TypeOf([]uint{}), + reflect.TypeOf([]uint32{}), + reflect.TypeOf([]uint64{}), + ) +} + +// GetConverter returns the internal type converter for gdb. +func GetConverter() gconv.Converter { + return converter +} + +func sliceTypeConverterFunc(from any, to reflect.Value) (err error) { + v, ok := from.(iVal) + if !ok { + return nil + } + fromVal := v.Val() + switch x := fromVal.(type) { + case []byte: + dst := to.Addr().Interface() + err = json.Unmarshal(x, dst) + case string: + dst := to.Addr().Interface() + err = json.Unmarshal([]byte(x), dst) + default: + fromType := reflect.TypeOf(fromVal) + switch fromType.Kind() { + case reflect.Slice: + convertOption := gconv.ConvertOption{ + SliceOption: gconv.SliceOption{ContinueOnError: true}, + MapOption: gconv.MapOption{ContinueOnError: true}, + StructOption: gconv.StructOption{ContinueOnError: true}, + } + dv, err := converter.ConvertWithTypeName(fromVal, to.Type().String(), convertOption) + if err != nil { + return err + } + to.Set(reflect.ValueOf(dv)) + default: + err = gerror.Newf( + `unsupported type converting from type "%T" to type "%T"`, + fromVal, to, + ) + } + } + return err +} diff --git a/database/gdb_core.go b/database/gdb_core.go new file mode 100644 index 0000000..83d35ea --- /dev/null +++ b/database/gdb_core.go @@ -0,0 +1,841 @@ +// Copyright GoFrame Author(https://goframe.org). All Rights Reserved. +// +// This Source Code Form is subject to the terms of the MIT License. +// If a copy of the MIT was not distributed with this file, +// You can obtain one at https://github.com/gogf/gf. 
+// + +package database + +import ( + "context" + "database/sql" + "fmt" + "reflect" + "sort" + "strings" + + "git.magicany.cc/black1552/gin-base/database/intlog" + "git.magicany.cc/black1552/gin-base/database/reflection" + "git.magicany.cc/black1552/gin-base/database/utils" + "github.com/gogf/gf/v2/container/gmap" + "github.com/gogf/gf/v2/container/gset" + "github.com/gogf/gf/v2/container/gvar" + "github.com/gogf/gf/v2/errors/gcode" + "github.com/gogf/gf/v2/errors/gerror" + "github.com/gogf/gf/v2/os/gcache" + "github.com/gogf/gf/v2/text/gregex" + "github.com/gogf/gf/v2/text/gstr" + "github.com/gogf/gf/v2/util/gconv" + "github.com/gogf/gf/v2/util/gutil" +) + +// GetCore returns the underlying *Core object. +func (c *Core) GetCore() *Core { + return c +} + +// Ctx is a chaining function, which creates and returns a new DB that is a shallow copy +// of current DB object and with given context in it. +// Note that this returned DB object can be used only once, so do not assign it to +// a global or package variable for long using. +func (c *Core) Ctx(ctx context.Context) DB { + if ctx == nil { + return c.db + } + // It makes a shallow copy of current db and changes its context for next chaining operation. + var ( + err error + newCore = &Core{} + configNode = c.db.GetConfig() + ) + *newCore = *c + // It creates a new DB object(NOT NEW CONNECTION), which is commonly a wrapper for object `Core`. + newCore.db, err = driverMap[configNode.Type].New(newCore, configNode) + if err != nil { + // It is really a serious error here. + // Do not let it continue. + panic(err) + } + newCore.ctx = WithDB(ctx, newCore.db) + newCore.ctx = c.injectInternalCtxData(newCore.ctx) + return newCore.db +} + +// GetCtx returns the context for current DB. +// It returns `context.Background()` is there's no context previously set. 
+func (c *Core) GetCtx() context.Context { + ctx := c.ctx + if ctx == nil { + ctx = context.TODO() + } + return c.injectInternalCtxData(ctx) +} + +// GetCtxTimeout returns the context and cancel function for specified timeout type. +func (c *Core) GetCtxTimeout(ctx context.Context, timeoutType ctxTimeoutType) (context.Context, context.CancelFunc) { + if ctx == nil { + ctx = c.db.GetCtx() + } else { + ctx = context.WithValue(ctx, ctxKeyWrappedByGetCtxTimeout, nil) + } + var config = c.db.GetConfig() + switch timeoutType { + case ctxTimeoutTypeExec: + if c.db.GetConfig().ExecTimeout > 0 { + return context.WithTimeout(ctx, config.ExecTimeout) + } + case ctxTimeoutTypeQuery: + if c.db.GetConfig().QueryTimeout > 0 { + return context.WithTimeout(ctx, config.QueryTimeout) + } + case ctxTimeoutTypePrepare: + if c.db.GetConfig().PrepareTimeout > 0 { + return context.WithTimeout(ctx, config.PrepareTimeout) + } + case ctxTimeoutTypeTrans: + if c.db.GetConfig().TranTimeout > 0 { + return context.WithTimeout(ctx, config.TranTimeout) + } + default: + panic(gerror.NewCodef(gcode.CodeInvalidParameter, "invalid context timeout type: %d", timeoutType)) + } + return ctx, func() {} +} + +// Close closes the database and prevents new queries from starting. +// Close then waits for all queries that have started processing on the server +// to finish. +// +// It is rare to Close a DB, as the DB handle is meant to be +// long-lived and shared between many goroutines. 
+func (c *Core) Close(ctx context.Context) (err error) { + if err = c.cache.Close(ctx); err != nil { + return err + } + c.links.LockFunc(func(m map[ConfigNode]*sql.DB) { + for k, v := range m { + err = v.Close() + if err != nil { + err = gerror.WrapCode(gcode.CodeDbOperationError, err, `db.Close failed`) + } + intlog.Printf(ctx, `close link: %s, err: %v`, gconv.String(k), err) + if err != nil { + return + } + delete(m, k) + } + }) + return +} + +// Master creates and returns a connection from master node if master-slave configured. +// It returns the default connection if master-slave not configured. +func (c *Core) Master(schema ...string) (*sql.DB, error) { + var ( + usedSchema = gutil.GetOrDefaultStr(c.schema, schema...) + charL, charR = c.db.GetChars() + ) + return c.getSqlDb(true, gstr.Trim(usedSchema, charL+charR)) +} + +// Slave creates and returns a connection from slave node if master-slave configured. +// It returns the default connection if master-slave not configured. +func (c *Core) Slave(schema ...string) (*sql.DB, error) { + var ( + usedSchema = gutil.GetOrDefaultStr(c.schema, schema...) + charL, charR = c.db.GetChars() + ) + return c.getSqlDb(false, gstr.Trim(usedSchema, charL+charR)) +} + +// GetAll queries and returns data records from database. +func (c *Core) GetAll(ctx context.Context, sql string, args ...any) (Result, error) { + return c.db.DoSelect(ctx, nil, sql, args...) +} + +// DoSelect queries and returns data records from database. +func (c *Core) DoSelect(ctx context.Context, link Link, sql string, args ...any) (result Result, err error) { + return c.db.DoQuery(ctx, link, sql, args...) +} + +// GetOne queries and returns one record from database. +func (c *Core) GetOne(ctx context.Context, sql string, args ...any) (Record, error) { + list, err := c.db.GetAll(ctx, sql, args...) 
+ if err != nil { + return nil, err + } + if len(list) > 0 { + return list[0], nil + } + return nil, nil +} + +// GetArray queries and returns data values as slice from database. +// Note that if there are multiple columns in the result, it returns just one column values randomly. +func (c *Core) GetArray(ctx context.Context, sql string, args ...any) (Array, error) { + all, err := c.db.DoSelect(ctx, nil, sql, args...) + if err != nil { + return nil, err + } + return all.Array(), nil +} + +// doGetStruct queries one record from database and converts it to given struct. +// The parameter `pointer` should be a pointer to struct. +func (c *Core) doGetStruct(ctx context.Context, pointer any, sql string, args ...any) error { + one, err := c.db.GetOne(ctx, sql, args...) + if err != nil { + return err + } + return one.Struct(pointer) +} + +// doGetStructs queries records from database and converts them to given struct. +// The parameter `pointer` should be type of struct slice: []struct/[]*struct. +func (c *Core) doGetStructs(ctx context.Context, pointer any, sql string, args ...any) error { + all, err := c.db.GetAll(ctx, sql, args...) + if err != nil { + return err + } + return all.Structs(pointer) +} + +// GetScan queries one or more records from database and converts them to given struct or +// struct array. +// +// If parameter `pointer` is type of struct pointer, it calls GetStruct internally for +// the conversion. If parameter `pointer` is type of slice, it calls GetStructs internally +// for conversion. +func (c *Core) GetScan(ctx context.Context, pointer any, sql string, args ...any) error { + reflectInfo := reflection.OriginTypeAndKind(pointer) + if reflectInfo.InputKind != reflect.Pointer { + return gerror.NewCodef( + gcode.CodeInvalidParameter, + "params should be type of pointer, but got: %v", + reflectInfo.InputKind, + ) + } + switch reflectInfo.OriginKind { + case reflect.Array, reflect.Slice: + return c.db.GetCore().doGetStructs(ctx, pointer, sql, args...) 
+ + case reflect.Struct: + return c.db.GetCore().doGetStruct(ctx, pointer, sql, args...) + + default: + } + return gerror.NewCodef( + gcode.CodeInvalidParameter, + `in valid parameter type "%v", of which element type should be type of struct/slice`, + reflectInfo.InputType, + ) +} + +// GetValue queries and returns the field value from database. +// The sql should query only one field from database, or else it returns only one +// field of the result. +func (c *Core) GetValue(ctx context.Context, sql string, args ...any) (Value, error) { + one, err := c.db.GetOne(ctx, sql, args...) + if err != nil { + return gvar.New(nil), err + } + for _, v := range one { + return v, nil + } + return gvar.New(nil), nil +} + +// GetCount queries and returns the count from database. +func (c *Core) GetCount(ctx context.Context, sql string, args ...any) (int, error) { + // If the query fields do not contain function "COUNT", + // it replaces the sql string and adds the "COUNT" function to the fields. + if !gregex.IsMatchString(`(?i)SELECT\s+COUNT\(.+\)\s+FROM`, sql) { + sql, _ = gregex.ReplaceString(`(?i)(SELECT)\s+(.+)\s+(FROM)`, `$1 COUNT($2) $3`, sql) + } + value, err := c.db.GetValue(ctx, sql, args...) + if err != nil { + return 0, err + } + return value.Int(), nil +} + +// Union does "(SELECT xxx FROM xxx) UNION (SELECT xxx FROM xxx) ..." statement. +func (c *Core) Union(unions ...*Model) *Model { + var ctx = c.db.GetCtx() + return c.doUnion(ctx, unionTypeNormal, unions...) +} + +// UnionAll does "(SELECT xxx FROM xxx) UNION ALL (SELECT xxx FROM xxx) ..." statement. +func (c *Core) UnionAll(unions ...*Model) *Model { + var ctx = c.db.GetCtx() + return c.doUnion(ctx, unionTypeAll, unions...) 
+} + +func (c *Core) doUnion(ctx context.Context, unionType int, unions ...*Model) *Model { + var ( + unionTypeStr string + composedSqlStr string + composedArgs = make([]any, 0) + ) + if unionType == unionTypeAll { + unionTypeStr = "UNION ALL" + } else { + unionTypeStr = "UNION" + } + for _, v := range unions { + sqlWithHolder, holderArgs := v.getFormattedSqlAndArgs(ctx, SelectTypeDefault, false) + if composedSqlStr == "" { + composedSqlStr += fmt.Sprintf(`(%s)`, sqlWithHolder) + } else { + composedSqlStr += fmt.Sprintf(` %s (%s)`, unionTypeStr, sqlWithHolder) + } + composedArgs = append(composedArgs, holderArgs...) + } + return c.db.Raw(composedSqlStr, composedArgs...) +} + +// PingMaster pings the master node to check authentication or keeps the connection alive. +func (c *Core) PingMaster() error { + var ctx = c.db.GetCtx() + if master, err := c.db.Master(); err != nil { + return err + } else { + if err = master.PingContext(ctx); err != nil { + err = gerror.WrapCode(gcode.CodeDbOperationError, err, `master.Ping failed`) + } + return err + } +} + +// PingSlave pings the slave node to check authentication or keeps the connection alive. +func (c *Core) PingSlave() error { + var ctx = c.db.GetCtx() + if slave, err := c.db.Slave(); err != nil { + return err + } else { + if err = slave.PingContext(ctx); err != nil { + err = gerror.WrapCode(gcode.CodeDbOperationError, err, `slave.Ping failed`) + } + return err + } +} + +// Insert does "INSERT INTO ..." statement for the table. +// If there's already one unique record of the data in the table, it returns error. +// +// The parameter `data` can be type of map/gmap/struct/*struct/[]map/[]struct, etc. +// Eg: +// Data(g.Map{"uid": 10000, "name":"john"}) +// Data(g.Slice{g.Map{"uid": 10000, "name":"john"}, g.Map{"uid": 20000, "name":"smith"}) +// +// The parameter `batch` specifies the batch operation count when given data is slice. 
func (c *Core) Insert(ctx context.Context, table string, data any, batch ...int) (sql.Result, error) {
	if len(batch) > 0 {
		return c.Model(table).Ctx(ctx).Data(data).Batch(batch[0]).Insert()
	}
	return c.Model(table).Ctx(ctx).Data(data).Insert()
}

// InsertIgnore does "INSERT IGNORE INTO ..." statement for the table.
// If there's already one unique record of the data in the table, it ignores the inserting.
//
// The parameter `data` can be type of map/gmap/struct/*struct/[]map/[]struct, etc.
// Eg:
// Data(g.Map{"uid": 10000, "name":"john"})
// Data(g.Slice{g.Map{"uid": 10000, "name":"john"}, g.Map{"uid": 20000, "name":"smith"})
//
// The parameter `batch` specifies the batch operation count when given data is slice.
func (c *Core) InsertIgnore(ctx context.Context, table string, data any, batch ...int) (sql.Result, error) {
	if len(batch) > 0 {
		return c.Model(table).Ctx(ctx).Data(data).Batch(batch[0]).InsertIgnore()
	}
	return c.Model(table).Ctx(ctx).Data(data).InsertIgnore()
}

// InsertAndGetId performs action Insert and returns the last insert id that automatically generated.
func (c *Core) InsertAndGetId(ctx context.Context, table string, data any, batch ...int) (int64, error) {
	if len(batch) > 0 {
		return c.Model(table).Ctx(ctx).Data(data).Batch(batch[0]).InsertAndGetId()
	}
	return c.Model(table).Ctx(ctx).Data(data).InsertAndGetId()
}

// Replace does "REPLACE INTO ..." statement for the table.
// If there's already one unique record of the data in the table, it deletes the record
// and inserts a new one.
//
// The parameter `data` can be type of map/gmap/struct/*struct/[]map/[]struct, etc.
// Eg:
// Data(g.Map{"uid": 10000, "name":"john"})
// Data(g.Slice{g.Map{"uid": 10000, "name":"john"}, g.Map{"uid": 20000, "name":"smith"})
//
// If given data is type of slice, it then does batch replacing, and the optional parameter
// `batch` specifies the batch operation count.
func (c *Core) Replace(ctx context.Context, table string, data any, batch ...int) (sql.Result, error) {
	if len(batch) > 0 {
		return c.Model(table).Ctx(ctx).Data(data).Batch(batch[0]).Replace()
	}
	return c.Model(table).Ctx(ctx).Data(data).Replace()
}

// Save does "INSERT INTO ... ON DUPLICATE KEY UPDATE..." statement for the table.
// It updates the record if there's primary or unique index in the saving data,
// or else it inserts a new record into the table.
//
// The parameter `data` can be type of map/gmap/struct/*struct/[]map/[]struct, etc.
// Eg:
// Data(g.Map{"uid": 10000, "name":"john"})
// Data(g.Slice{g.Map{"uid": 10000, "name":"john"}, g.Map{"uid": 20000, "name":"smith"})
//
// If given data is type of slice, it then does batch saving, and the optional parameter
// `batch` specifies the batch operation count.
func (c *Core) Save(ctx context.Context, table string, data any, batch ...int) (sql.Result, error) {
	if len(batch) > 0 {
		return c.Model(table).Ctx(ctx).Data(data).Batch(batch[0]).Save()
	}
	return c.Model(table).Ctx(ctx).Data(data).Save()
}

// fieldsToSequence reorders `fields` to follow the column order of `table`,
// silently dropping any name that is not a column of the table.
func (c *Core) fieldsToSequence(ctx context.Context, table string, fields []string) ([]string, error) {
	var (
		fieldSet               = gset.NewStrSetFrom(fields)
		fieldsResultInSequence = make([]string, 0)
		tableFields, err       = c.db.TableFields(ctx, table)
	)
	if err != nil {
		return nil, err
	}
	// Sort the fields in order.
	// field.Index is the column's ordinal position within the table.
	var fieldsOfTableInSequence = make([]string, len(tableFields))
	for _, field := range tableFields {
		fieldsOfTableInSequence[field.Index] = field.Name
	}
	// Sort the input fields: walk table columns in order and keep
	// only those present in the requested field set.
	for _, fieldName := range fieldsOfTableInSequence {
		if fieldSet.Contains(fieldName) {
			fieldsResultInSequence = append(fieldsResultInSequence, fieldName)
		}
	}
	return fieldsResultInSequence, nil
}

// DoInsert inserts or updates data for given table.
// This function is usually used for custom interface definition, you do not need call it manually.
// The parameter `data` can be type of map/gmap/struct/*struct/[]map/[]struct, etc.
// Eg:
// Data(g.Map{"uid": 10000, "name":"john"})
// Data(g.Slice{g.Map{"uid": 10000, "name":"john"}, g.Map{"uid": 20000, "name":"smith"})
//
// The parameter `option` values are as follows:
// InsertOptionDefault: just insert, if there's unique/primary key in the data, it returns error;
// InsertOptionReplace: if there's unique/primary key in the data, it deletes it from table and inserts a new one;
// InsertOptionSave: if there's unique/primary key in the data, it updates it or else inserts a new one;
// InsertOptionIgnore: if there's unique/primary key in the data, it ignores the inserting;
func (c *Core) DoInsert(ctx context.Context, link Link, table string, list List, option DoInsertOption) (result sql.Result, err error) {
	var (
		keys           []string // Field names.
		values         []string // Value holder string array, like: (?,?,?)
		params         []any    // Values that will be committed to underlying database driver.
		onDuplicateStr string   // onDuplicateStr is used in "ON DUPLICATE KEY UPDATE" statement.
	)
	// ============================================================================================
	// Group the list by fields. Different fields to different list.
	// It here uses ListMap to keep sequence for data inserting.
	// ============================================================================================
	var (
		keyListMap    = gmap.NewListMap()
		tmpKeyListMap = make(map[string]List)
	)
	for _, item := range list {
		mapLen := len(item)
		if mapLen == 0 {
			continue
		}
		tmpKeys := make([]string, 0, mapLen)
		for k := range item {
			tmpKeys = append(tmpKeys, k)
		}
		// Keys are sorted so that items with the same field set share one
		// group key regardless of map iteration order.
		if mapLen > 1 {
			sort.Strings(tmpKeys)
		}
		keys = tmpKeys // for fieldsToSequence

		tmpKeysInSequenceStr := gstr.Join(tmpKeys, ",")
		if tmpKeyListMapItem, ok := tmpKeyListMap[tmpKeysInSequenceStr]; ok {
			tmpKeyListMap[tmpKeysInSequenceStr] = append(tmpKeyListMapItem, item)
		} else {
			tmpKeyListMap[tmpKeysInSequenceStr] = List{item}
		}
	}
	for tmpKeysInSequenceStr, itemList := range tmpKeyListMap {
		keyListMap.Set(tmpKeysInSequenceStr, itemList)
	}
	// More than one field group: recurse per group and accumulate the
	// affected row count across all sub-inserts.
	if keyListMap.Size() > 1 {
		var (
			tmpResult    sql.Result
			sqlResult    SqlResult
			rowsAffected int64
		)
		keyListMap.Iterator(func(key, value any) bool {
			tmpResult, err = c.DoInsert(ctx, link, table, value.(List), option)
			if err != nil {
				return false
			}
			rowsAffected, err = tmpResult.RowsAffected()
			if err != nil {
				return false
			}
			sqlResult.Result = tmpResult
			sqlResult.Affected += rowsAffected
			return true
		})
		return &sqlResult, err
	}

	// Single field group: order the fields by the table's column sequence.
	keys, err = c.fieldsToSequence(ctx, table, keys)
	if err != nil {
		return nil, err
	}

	if len(keys) == 0 {
		return nil, gerror.NewCode(gcode.CodeInvalidParameter, "no valid data fields found in table")
	}

	// Prepare the batch result pointer.
	var (
		charL, charR = c.db.GetChars()
		batchResult  = new(SqlResult)
		keysStr      = charL + strings.Join(keys, charR+","+charL) + charR
		operation    = GetInsertOperationByOption(option.InsertOption)
	)
	// Upsert clause only takes effect on Save operation.
	if option.InsertOption == InsertOptionSave {
		onDuplicateStr, err = c.db.FormatUpsert(keys, list, option)
		if err != nil {
			return nil, err
		}
	}
	var (
		listLength   = len(list)
		valueHolders = make([]string, 0)
	)
	for i := 0; i < listLength; i++ {
		values = values[:0]
		// Note that the map type is unordered,
		// so it should use slice+key to retrieve the value.
		for _, k := range keys {
			if s, ok := list[i][k].(Raw); ok {
				// Raw values are embedded into the statement verbatim,
				// not passed as bound parameters.
				values = append(values, gconv.String(s))
			} else {
				values = append(values, "?")
				params = append(params, list[i][k])
			}
		}
		valueHolders = append(valueHolders, "("+gstr.Join(values, ",")+")")
		// Batch package checks: It meets the batch number, or it is the last element.
		if len(valueHolders) == option.BatchCount || (i == listLength-1 && len(valueHolders) > 0) {
			var (
				stdSqlResult sql.Result
				affectedRows int64
			)
			stdSqlResult, err = c.db.DoExec(ctx, link, fmt.Sprintf(
				"%s INTO %s(%s) VALUES%s %s",
				operation, c.QuotePrefixTableName(table), keysStr,
				gstr.Join(valueHolders, ","),
				onDuplicateStr,
			), params...)
			if err != nil {
				return stdSqlResult, err
			}
			if affectedRows, err = stdSqlResult.RowsAffected(); err != nil {
				err = gerror.WrapCode(gcode.CodeDbOperationError, err, `sql.Result.RowsAffected failed`)
				return stdSqlResult, err
			} else {
				batchResult.Result = stdSqlResult
				batchResult.Affected += affectedRows
			}
			// Reset accumulators for the next batch package.
			params = params[:0]
			valueHolders = valueHolders[:0]
		}
	}
	return batchResult, nil
}

// Update does "UPDATE ... " statement for the table.
//
// The parameter `data` can be type of string/map/gmap/struct/*struct, etc.
// Eg: "uid=10000", "uid", 10000, g.Map{"uid": 10000, "name":"john"}
//
// The parameter `condition` can be type of string/map/gmap/slice/struct/*struct, etc.
// It is commonly used with parameter `args`.
// Eg:
// "uid=10000",
// "uid", 10000
// "money>? AND name like ?", 99999, "vip_%"
// "status IN (?)", g.Slice{1,2,3}
// "age IN(?,?)", 18, 50
// User{ Id : 1, UserName : "john"}.
func (c *Core) Update(ctx context.Context, table string, data any, condition any, args ...any) (sql.Result, error) {
	return c.Model(table).Ctx(ctx).Data(data).Where(condition, args...).Update()
}

// DoUpdate does "UPDATE ... " statement for the table.
// This function is usually used for custom interface definition, you do not need to call it manually.
func (c *Core) DoUpdate(ctx context.Context, link Link, table string, data any, condition string, args ...any) (result sql.Result, err error) {
	table = c.QuotePrefixTableName(table)
	var (
		rv   = reflect.ValueOf(data)
		kind = rv.Kind()
	)
	if kind == reflect.Pointer {
		rv = rv.Elem()
		kind = rv.Kind()
	}
	var (
		params  []any
		updates string
	)
	switch kind {
	case reflect.Map, reflect.Struct:
		var (
			fields  []string
			dataMap map[string]any
		)
		dataMap, err = c.ConvertDataForRecord(ctx, data, table)
		if err != nil {
			return nil, err
		}
		// Sort the data keys in sequence of table fields.
		var (
			dataKeys       = make([]string, 0)
			keysInSequence = make([]string, 0)
		)
		for k := range dataMap {
			dataKeys = append(dataKeys, k)
		}
		keysInSequence, err = c.fieldsToSequence(ctx, table, dataKeys)
		if err != nil {
			return nil, err
		}
		for _, k := range keysInSequence {
			v := dataMap[k]
			switch v.(type) {
			case Counter, *Counter:
				// Counter values produce increment/decrement expressions,
				// e.g. `col`=`col`+? — a zero delta is skipped entirely.
				var counter Counter
				switch value := v.(type) {
				case Counter:
					counter = value
				case *Counter:
					counter = *value
				}
				if counter.Value == 0 {
					continue
				}
				operator, columnVal := c.getCounterAlter(counter)
				fields = append(fields, fmt.Sprintf("%s=%s%s?", c.QuoteWord(k), c.QuoteWord(counter.Field), operator))
				params = append(params, columnVal)
			default:
				if s, ok := v.(Raw); ok {
					// Raw values are embedded verbatim, not bound.
					fields = append(fields, c.QuoteWord(k)+"="+gconv.String(s))
				} else {
					fields = append(fields, c.QuoteWord(k)+"=?")
					params = append(params, v)
				}
			}
		}
		updates = strings.Join(fields, ",")

	default:
		updates = gconv.String(data)
	}
	if len(updates) == 0 {
		return nil, gerror.NewCode(gcode.CodeMissingParameter, "data cannot be empty")
	}
	// Data-bound parameters precede condition parameters in the statement.
	if len(params) > 0 {
		args = append(params, args...)
	}
	// If no link passed, it then uses the master link.
	if link == nil {
		if link, err = c.MasterLink(); err != nil {
			return nil, err
		}
	}
	return c.db.DoExec(ctx, link, fmt.Sprintf(
		"UPDATE %s SET %s%s",
		table, updates, condition,
	),
		args...,
	)
}

// Delete does "DELETE FROM ... " statement for the table.
//
// The parameter `condition` can be type of string/map/gmap/slice/struct/*struct, etc.
// It is commonly used with parameter `args`.
// Eg:
// "uid=10000",
// "uid", 10000
// "money>? AND name like ?", 99999, "vip_%"
// "status IN (?)", g.Slice{1,2,3}
// "age IN(?,?)", 18, 50
// User{ Id : 1, UserName : "john"}.
+func (c *Core) Delete(ctx context.Context, table string, condition any, args ...any) (result sql.Result, err error) { + return c.Model(table).Ctx(ctx).Where(condition, args...).Delete() +} + +// DoDelete does "DELETE FROM ... " statement for the table. +// This function is usually used for custom interface definition, you do not need call it manually. +func (c *Core) DoDelete(ctx context.Context, link Link, table string, condition string, args ...any) (result sql.Result, err error) { + if link == nil { + if link, err = c.MasterLink(); err != nil { + return nil, err + } + } + table = c.QuotePrefixTableName(table) + return c.db.DoExec(ctx, link, fmt.Sprintf("DELETE FROM %s%s", table, condition), args...) +} + +// FilteredLink retrieves and returns filtered `linkInfo` that can be using for +// logging or tracing purpose. +func (c *Core) FilteredLink() string { + return fmt.Sprintf( + `%s@%s(%s:%s)/%s`, + c.config.User, c.config.Protocol, c.config.Host, c.config.Port, c.config.Name, + ) +} + +// MarshalJSON implements the interface MarshalJSON for json.Marshal. +// It just returns the pointer address. +// +// Note that this interface implements mainly for workaround for a json infinite loop bug +// of Golang version < v1.14. +func (c *Core) MarshalJSON() ([]byte, error) { + return []byte(fmt.Sprintf(`%+v`, c)), nil +} + +// writeSqlToLogger outputs the Sql object to logger. +// It is enabled only if configuration "debug" is true. 
+func (c *Core) writeSqlToLogger(ctx context.Context, sql *Sql) { + var transactionIdStr string + if sql.IsTransaction { + if v := ctx.Value(transactionIdForLoggerCtx); v != nil { + transactionIdStr = fmt.Sprintf(`[txid:%d] `, v.(uint64)) + } + } + s := fmt.Sprintf( + "[%3d ms] [%s] [%s] [rows:%-3d] %s%s", + sql.End-sql.Start, sql.Group, sql.Schema, sql.RowsAffected, transactionIdStr, sql.Format, + ) + if sql.Error != nil { + s += "\nError: " + sql.Error.Error() + c.logger.Error(ctx, s) + } else { + c.logger.Debug(ctx, s) + } +} + +// HasTable determine whether the table name exists in the database. +func (c *Core) HasTable(name string) (bool, error) { + tables, err := c.GetTablesWithCache() + if err != nil { + return false, err + } + charL, charR := c.db.GetChars() + name = gstr.Trim(name, charL+charR) + for _, table := range tables { + if table == name { + return true, nil + } + } + return false, nil +} + +// GetInnerMemCache retrieves and returns the inner memory cache object. +func (c *Core) GetInnerMemCache() *gcache.Cache { + return c.innerMemCache +} + +func (c *Core) SetTableFields(ctx context.Context, table string, fields map[string]*TableField, schema ...string) error { + if table == "" { + return gerror.NewCode(gcode.CodeInvalidParameter, "table name cannot be empty") + } + charL, charR := c.db.GetChars() + table = gstr.Trim(table, charL+charR) + if gstr.Contains(table, " ") { + return gerror.NewCode( + gcode.CodeInvalidParameter, + "function TableFields supports only single table operations", + ) + } + var ( + innerMemCache = c.GetInnerMemCache() + // prefix:group@schema#table + cacheKey = genTableFieldsCacheKey( + c.db.GetGroup(), + gutil.GetOrDefaultStr(c.db.GetSchema(), schema...), + table, + ) + ) + return innerMemCache.Set(ctx, cacheKey, fields, gcache.DurationNoExpire) +} + +// GetTablesWithCache retrieves and returns the table names of current database with cache. 
func (c *Core) GetTablesWithCache() ([]string, error) {
	var (
		ctx           = c.db.GetCtx()
		cacheKey      = genTableNamesCacheKey(c.db.GetGroup())
		cacheDuration = gcache.DurationNoExpire
		innerMemCache = c.GetInnerMemCache()
	)
	// GetOrSetFuncLock guarantees the table list is fetched from the
	// database at most once; subsequent calls hit the memory cache.
	result, err := innerMemCache.GetOrSetFuncLock(
		ctx, cacheKey,
		func(ctx context.Context) (any, error) {
			tableList, err := c.db.Tables(ctx)
			if err != nil {
				return nil, err
			}
			return tableList, nil
		}, cacheDuration,
	)
	if err != nil {
		return nil, err
	}
	return result.Strings(), nil
}

// IsSoftCreatedFieldName checks and returns whether given field name is an automatic-filled created time.
func (c *Core) IsSoftCreatedFieldName(fieldName string) bool {
	if fieldName == "" {
		return false
	}
	// When a CreatedAt field is explicitly configured, match against it first
	// (fold/char-insensitively), then against the configured name plus the
	// built-in default names by exact match.
	if config := c.db.GetConfig(); config.CreatedAt != "" {
		if utils.EqualFoldWithoutChars(fieldName, config.CreatedAt) {
			return true
		}
		return gstr.InArray(append([]string{config.CreatedAt}, createdFieldNames...), fieldName)
	}
	// No configuration: match the built-in default names fold/char-insensitively.
	for _, v := range createdFieldNames {
		if utils.EqualFoldWithoutChars(fieldName, v) {
			return true
		}
	}
	return false
}

// FormatSqlBeforeExecuting formats the sql string and its arguments before executing.
// The internal handleArguments function might be called twice during the SQL procedure,
// but do not worry about it, it's safe and efficient.
func (c *Core) FormatSqlBeforeExecuting(sql string, args []any) (newSql string, newArgs []any) {
	return handleSliceAndStructArgsForSql(sql, args)
}

// getCounterAlter returns the SQL operator ("+" or "-") and the absolute
// delta value for the given Counter update expression.
func (c *Core) getCounterAlter(counter Counter) (operator string, columnVal float64) {
	operator, columnVal = "+", counter.Value
	if columnVal < 0 {
		operator, columnVal = "-", -columnVal
	}
	return
}

diff --git a/database/gdb_core_config.go b/database/gdb_core_config.go
new file mode 100644
index 0000000..e1d01d0
--- /dev/null
+++ b/database/gdb_core_config.go
@@ -0,0 +1,485 @@
// Copyright GoFrame Author(https://goframe.org). All Rights Reserved.
//
// This Source Code Form is subject to the terms of the MIT License.
// If a copy of the MIT was not distributed with this file,
// You can obtain one at https://github.com/gogf/gf.

package database

import (
	"fmt"
	"sync"
	"time"

	"git.magicany.cc/black1552/gin-base/log"
	"github.com/gogf/gf/v2/errors/gcode"
	"github.com/gogf/gf/v2/errors/gerror"
	"github.com/gogf/gf/v2/os/gcache"
	"github.com/gogf/gf/v2/text/gregex"
	"github.com/gogf/gf/v2/text/gstr"
	"github.com/gogf/gf/v2/util/gconv"
)

// Config is the configuration management object.
// It maps a group name to the group's node configurations.
type Config map[string]ConfigGroup

// ConfigGroup is a slice of configuration node for specified named group.
type ConfigGroup []ConfigNode

// ConfigNode is configuration for one node.
type ConfigNode struct {
	// Host specifies the server address, can be either IP address or domain name
	// Example: "127.0.0.1", "localhost"
	Host string `json:"host"`

	// Port specifies the server port number
	// Default is typically "3306" for MySQL
	Port string `json:"port"`

	// User specifies the authentication username for database connection
	User string `json:"user"`

	// Pass specifies the authentication password for database connection
	Pass string `json:"pass"`

	// Name specifies the default database name to be used
	Name string `json:"name"`

	// Type specifies the database type
	// Example: mysql, mariadb, sqlite, mssql, pgsql, oracle, clickhouse, dm.
	Type string `json:"type"`

	// Link provides custom connection string that combines all configuration in one string
	// Optional field
	Link string `json:"link"`

	// Extra provides additional configuration options for third-party database drivers
	// Optional field
	Extra string `json:"extra"`

	// Role specifies the node role in master-slave setup
	// Optional field, defaults to "master"
	// Available values: "master", "slave"
	Role Role `json:"role"`

	// Debug enables debug mode for logging and output
	// Optional field
	Debug bool `json:"debug"`

	// Prefix specifies the table name prefix
	// Optional field
	Prefix string `json:"prefix"`

	// DryRun enables simulation mode where SELECT statements are executed
	// but INSERT/UPDATE/DELETE statements are not
	// Optional field
	DryRun bool `json:"dryRun"`

	// Weight specifies the node weight for load balancing calculations
	// Optional field, only effective in multi-node setups
	Weight int `json:"weight"`

	// Charset specifies the character set for database operations
	// Optional field, defaults to "utf8"
	Charset string `json:"charset"`

	// Protocol specifies the network protocol for database connection
	// Optional field, defaults to "tcp"
	// See net.Dial for available network protocols
	Protocol string `json:"protocol"`

	// Timezone sets the time zone for timestamp interpretation and display
	// Optional field
	Timezone string `json:"timezone"`

	// Namespace specifies the schema namespace for certain databases
	// Optional field, e.g., in PostgreSQL, Name is the catalog and Namespace is the schema
	Namespace string `json:"namespace"`

	// MaxIdleConnCount specifies the maximum number of idle connections in the pool
	// Optional field
	MaxIdleConnCount int `json:"maxIdle"`

	// MaxOpenConnCount specifies the maximum number of open connections in the pool
	// Optional field
	MaxOpenConnCount int `json:"maxOpen"`

	// MaxConnLifeTime specifies the maximum lifetime of a connection
	// Optional field
	MaxConnLifeTime time.Duration `json:"maxLifeTime"`

	// MaxIdleConnTime specifies the maximum idle time of a connection before being closed
	// This is Go 1.15+ feature: sql.DB.SetConnMaxIdleTime
	// Optional field
	MaxIdleConnTime time.Duration `json:"maxIdleTime"`

	// QueryTimeout specifies the maximum execution time for DQL operations
	// Optional field
	QueryTimeout time.Duration `json:"queryTimeout"`

	// ExecTimeout specifies the maximum execution time for DML operations
	// Optional field
	ExecTimeout time.Duration `json:"execTimeout"`

	// TranTimeout specifies the maximum execution time for a transaction block
	// Optional field
	TranTimeout time.Duration `json:"tranTimeout"`

	// PrepareTimeout specifies the maximum execution time for prepare operations
	// Optional field
	PrepareTimeout time.Duration `json:"prepareTimeout"`

	// CreatedAt specifies the field name for automatic timestamp on record creation
	// Optional field
	CreatedAt string `json:"createdAt"`

	// UpdatedAt specifies the field name for automatic timestamp on record updates
	// Optional field
	UpdatedAt string `json:"updatedAt"`

	// DeletedAt specifies the field name for automatic timestamp on record deletion
	// Optional field
	DeletedAt string `json:"deletedAt"`

	// TimeMaintainDisabled controls whether automatic time maintenance is disabled
	// Optional field
	TimeMaintainDisabled bool `json:"timeMaintainDisabled"`
}

// Role is the node role in a master-slave setup.
type Role string

const (
	RoleMaster Role = "master"
	RoleSlave  Role = "slave"
)

const (
	DefaultGroupName = "default" // Default group name.
)

// configs specifies internal used configuration object.
// Access is guarded by the embedded RWMutex.
var configs struct {
	sync.RWMutex
	config Config // All configurations.
	group  string // Default configuration group.
}

func init() {
	configs.config = make(Config)
	configs.group = DefaultGroupName
}

// SetConfig sets the global configuration for package.
// It will overwrite the old configuration of package.
func SetConfig(config Config) error {
	// Clearing cached instances forces re-creation with the new configuration.
	defer instances.Clear()
	configs.Lock()
	defer configs.Unlock()

	// Normalize every node (e.g. parse its Link syntax) before storing.
	for k, nodes := range config {
		for i, node := range nodes {
			parsedNode, err := parseConfigNode(node)
			if err != nil {
				return err
			}
			nodes[i] = parsedNode
		}
		config[k] = nodes
	}
	configs.config = config
	return nil
}

// SetConfigGroup sets the configuration for given group.
func SetConfigGroup(group string, nodes ConfigGroup) error {
	defer instances.Clear()
	configs.Lock()
	defer configs.Unlock()

	for i, node := range nodes {
		parsedNode, err := parseConfigNode(node)
		if err != nil {
			return err
		}
		nodes[i] = parsedNode
	}
	configs.config[group] = nodes
	return nil
}

// AddConfigNode adds one node configuration to configuration of given group.
func AddConfigNode(group string, node ConfigNode) error {
	defer instances.Clear()
	configs.Lock()
	defer configs.Unlock()

	parsedNode, err := parseConfigNode(node)
	if err != nil {
		return err
	}
	configs.config[group] = append(configs.config[group], parsedNode)
	return nil
}

// parseConfigNode parses `Link` configuration syntax.
func parseConfigNode(node ConfigNode) (ConfigNode, error) {
	if node.Link != "" {
		parsedLinkNode, err := parseConfigNodeLink(&node)
		if err != nil {
			return node, err
		}
		node = *parsedLinkNode
	}
	// Fallback: derive the database type from a "type:rest" Link prefix
	// when the node does not declare Type explicitly.
	if node.Link != "" && node.Type == "" {
		match, _ := gregex.MatchString(`([a-z]+):(.+)`, node.Link)
		if len(match) == 3 {
			node.Type = gstr.Trim(match[1])
			node.Link = gstr.Trim(match[2])
		}
	}
	return node, nil
}

// AddDefaultConfigNode adds one node configuration to configuration of default group.
+func AddDefaultConfigNode(node ConfigNode) error { + return AddConfigNode(DefaultGroupName, node) +} + +// AddDefaultConfigGroup adds multiple node configurations to configuration of default group. +// +// Deprecated: Use SetDefaultConfigGroup instead. +func AddDefaultConfigGroup(nodes ConfigGroup) error { + return SetConfigGroup(DefaultGroupName, nodes) +} + +// SetDefaultConfigGroup sets multiple node configurations to configuration of default group. +func SetDefaultConfigGroup(nodes ConfigGroup) error { + return SetConfigGroup(DefaultGroupName, nodes) +} + +// GetConfig retrieves and returns the configuration of given group. +// +// Deprecated: Use GetConfigGroup instead. +func GetConfig(group string) ConfigGroup { + configGroup, _ := GetConfigGroup(group) + return configGroup +} + +// GetConfigGroup retrieves and returns the configuration of given group. +// It returns an error if the group does not exist, or an empty slice if the group exists but has no nodes. +func GetConfigGroup(group string) (ConfigGroup, error) { + configs.RLock() + defer configs.RUnlock() + + configGroup, exists := configs.config[group] + if !exists { + return nil, gerror.NewCodef( + gcode.CodeInvalidParameter, + `configuration group "%s" not found`, + group, + ) + } + return configGroup, nil +} + +// GetAllConfig retrieves and returns all configurations. +func GetAllConfig() Config { + configs.RLock() + defer configs.RUnlock() + return configs.config +} + +// SetDefaultGroup sets the group name for default configuration. +func SetDefaultGroup(name string) { + defer instances.Clear() + configs.Lock() + defer configs.Unlock() + configs.group = name +} + +// GetDefaultGroup returns the { name of default configuration. +func GetDefaultGroup() string { + defer instances.Clear() + configs.RLock() + defer configs.RUnlock() + return configs.group +} + +// IsConfigured checks and returns whether the database configured. +// It returns true if any configuration exists. 
func IsConfigured() bool {
	configs.RLock()
	defer configs.RUnlock()
	return len(configs.config) > 0
}

// SetLogger sets the logger for orm.
func (c *Core) SetLogger(logger log.ILogger) {
	c.logger = logger
}

// GetLogger returns the logger of the orm.
func (c *Core) GetLogger() log.ILogger {
	return c.logger
}

// SetMaxIdleConnCount sets the maximum number of connections in the idle
// connection pool.
//
// If MaxOpenConns is greater than 0 but less than the new MaxIdleConns,
// then the new MaxIdleConns will be reduced to match the MaxOpenConns limit.
//
// If n <= 0, no idle connections are retained.
//
// The default max idle connections is currently 2. This may change in
// a future release.
func (c *Core) SetMaxIdleConnCount(n int) {
	// Stored in dynamicConfig; applied to the underlying *sql.DB when the
	// connection is (re)configured.
	c.dynamicConfig.MaxIdleConnCount = n
}

// SetMaxOpenConnCount sets the maximum number of open connections to the database.
//
// If MaxIdleConns is greater than 0 and the new MaxOpenConns is less than
// MaxIdleConns, then MaxIdleConns will be reduced to match the new
// MaxOpenConns limit.
//
// If n <= 0, then there is no limit on the number of open connections.
// The default is 0 (unlimited).
func (c *Core) SetMaxOpenConnCount(n int) {
	c.dynamicConfig.MaxOpenConnCount = n
}

// SetMaxConnLifeTime sets the maximum amount of time a connection may be reused.
//
// Expired connections may be closed lazily before reuse.
//
// If d <= 0, connections are not closed due to a connection's age.
func (c *Core) SetMaxConnLifeTime(d time.Duration) {
	c.dynamicConfig.MaxConnLifeTime = d
}

// SetMaxIdleConnTime sets the maximum amount of time a connection may be idle before being closed.
//
// Idle connections may be closed lazily before reuse.
//
// If d <= 0, connections are not closed due to a connection's idle time.
// This is Go 1.15+ feature: sql.DB.SetConnMaxIdleTime.
+func (c *Core) SetMaxIdleConnTime(d time.Duration) { + c.dynamicConfig.MaxIdleConnTime = d +} + +// GetConfig returns the current used node configuration. +func (c *Core) GetConfig() *ConfigNode { + var configNode = c.getConfigNodeFromCtx(c.db.GetCtx()) + if configNode != nil { + // Note: + // It so here checks and returns the config from current DB, + // if different schemas between current DB and config.Name from context, + // for example, in nested transaction scenario, the context is passed all through the logic procedure, + // but the config.Name from context may be still the original one from the first transaction object. + if c.config.Name == configNode.Name { + return configNode + } + } + return c.config +} + +// SetDebug enables/disables the debug mode. +func (c *Core) SetDebug(debug bool) { + c.debug.Set(debug) +} + +// GetDebug returns the debug value. +func (c *Core) GetDebug() bool { + return c.debug.Val() +} + +// GetCache returns the internal cache object. +func (c *Core) GetCache() *gcache.Cache { + return c.cache +} + +// GetGroup returns the group string configured. +func (c *Core) GetGroup() string { + return c.group +} + +// SetDryRun enables/disables the DryRun feature. +func (c *Core) SetDryRun(enabled bool) { + c.config.DryRun = enabled +} + +// GetDryRun returns the DryRun value. +func (c *Core) GetDryRun() bool { + return c.config.DryRun || allDryRun +} + +// GetPrefix returns the table prefix string configured. +func (c *Core) GetPrefix() string { + return c.config.Prefix +} + +// GetSchema returns the schema configured. +func (c *Core) GetSchema() string { + schema := c.schema + if schema == "" { + schema = c.db.GetConfig().Name + } + return schema +} + +func parseConfigNodeLink(node *ConfigNode) (*ConfigNode, error) { + var ( + link = node.Link + match []string + ) + if link != "" { + // To be compatible with old configuration, + // it checks and converts the link to new configuration. 
+ if node.Type != "" && !gstr.HasPrefix(link, node.Type+":") { + link = fmt.Sprintf("%s:%s", node.Type, link) + } + match, _ = gregex.MatchString(linkPattern, link) + if len(match) <= 5 { + return nil, gerror.NewCodef( + gcode.CodeInvalidParameter, + `invalid link configuration: %s, shuold be pattern like: %s`, + link, linkPatternDescription, + ) + } + node.Type = match[1] + node.User = match[2] + node.Pass = match[3] + node.Protocol = match[4] + array := gstr.Split(match[5], ":") + if node.Protocol == "file" { + node.Name = match[5] + } else { + if len(array) == 2 { + // link with port. + node.Host = array[0] + node.Port = array[1] + } else { + // link without port. + node.Host = array[0] + } + node.Name = match[6] + } + if len(match) > 6 && match[7] != "" { + node.Extra = match[7] + } + } + if node.Extra != "" { + if m, _ := gstr.Parse(node.Extra); len(m) > 0 { + _ = gconv.Struct(m, &node) + } + } + // Default value checks. + if node.Charset == "" { + node.Charset = defaultCharset + } + if node.Protocol == "" { + node.Protocol = defaultProtocol + } + return node, nil +} diff --git a/database/gdb_core_ctx.go b/database/gdb_core_ctx.go new file mode 100644 index 0000000..1f9464c --- /dev/null +++ b/database/gdb_core_ctx.go @@ -0,0 +1,98 @@ +// Copyright GoFrame Author(https://goframe.org). All Rights Reserved. +// +// This Source Code Form is subject to the terms of the MIT License. +// If a copy of the MIT was not distributed with this file, +// You can obtain one at https://github.com/gogf/gf. + +package database + +import ( + "context" + "sync" + + "github.com/gogf/gf/v2/errors/gcode" + "github.com/gogf/gf/v2/errors/gerror" + "github.com/gogf/gf/v2/os/gctx" +) + +// internalCtxData stores data in ctx for internal usage purpose. +type internalCtxData struct { + sync.Mutex + // Used configuration node in current operation. + ConfigNode *ConfigNode +} + +// column stores column data in ctx for internal usage purpose. 
+type internalColumnData struct { + // The first column in result response from database server. + // This attribute is used for Value/Count selection statement purpose, + // which is to avoid HOOK handler that might modify the result columns + // that can confuse the Value/Count selection statement logic. + FirstResultColumn string +} + +const ( + internalCtxDataKeyInCtx gctx.StrKey = "InternalCtxData" + internalColumnDataKeyInCtx gctx.StrKey = "InternalColumnData" + + // `ignoreResultKeyInCtx` is a mark for some db drivers that do not support `RowsAffected` function, + // for example: `clickhouse`. The `clickhouse` does not support fetching insert/update results, + // but returns errors when execute `RowsAffected`. It here ignores the calling of `RowsAffected` + // to avoid triggering errors, rather than ignoring errors after they are triggered. + ignoreResultKeyInCtx gctx.StrKey = "IgnoreResult" +) + +func (c *Core) injectInternalCtxData(ctx context.Context) context.Context { + // If the internal data is already injected, it does nothing. 
+ if ctx.Value(internalCtxDataKeyInCtx) != nil { + return ctx + } + return context.WithValue(ctx, internalCtxDataKeyInCtx, &internalCtxData{ + ConfigNode: c.config, + }) +} + +func (c *Core) setConfigNodeToCtx(ctx context.Context, node *ConfigNode) error { + value := ctx.Value(internalCtxDataKeyInCtx) + if value == nil { + return gerror.NewCode(gcode.CodeInternalError, `no internal data found in context`) + } + + data := value.(*internalCtxData) + data.Lock() + defer data.Unlock() + data.ConfigNode = node + return nil +} + +func (c *Core) getConfigNodeFromCtx(ctx context.Context) *ConfigNode { + if value := ctx.Value(internalCtxDataKeyInCtx); value != nil { + data := value.(*internalCtxData) + data.Lock() + defer data.Unlock() + return data.ConfigNode + } + return nil +} + +func (c *Core) injectInternalColumn(ctx context.Context) context.Context { + return context.WithValue(ctx, internalColumnDataKeyInCtx, &internalColumnData{}) +} + +func (c *Core) getInternalColumnFromCtx(ctx context.Context) *internalColumnData { + if v := ctx.Value(internalColumnDataKeyInCtx); v != nil { + return v.(*internalColumnData) + } + return nil +} + +func (c *Core) InjectIgnoreResult(ctx context.Context) context.Context { + if ctx.Value(ignoreResultKeyInCtx) != nil { + return ctx + } + return context.WithValue(ctx, ignoreResultKeyInCtx, true) +} + +func (c *Core) GetIgnoreResultFromCtx(ctx context.Context) bool { + return ctx.Value(ignoreResultKeyInCtx) != nil +} diff --git a/database/gdb_core_link.go b/database/gdb_core_link.go new file mode 100644 index 0000000..e74b13a --- /dev/null +++ b/database/gdb_core_link.go @@ -0,0 +1,43 @@ +// Copyright GoFrame Author(https://goframe.org). All Rights Reserved. +// +// This Source Code Form is subject to the terms of the MIT License. +// If a copy of the MIT was not distributed with this file, +// You can obtain one at https://github.com/gogf/gf. 
+ +package database + +import ( + "database/sql" +) + +// dbLink is used to implement interface Link for DB. +type dbLink struct { + *sql.DB // Underlying DB object. + isOnMaster bool // isOnMaster marks whether current link is operated on master node. +} + +// txLink is used to implement interface Link for TX. +type txLink struct { + *sql.Tx +} + +// IsTransaction returns if current Link is a transaction. +func (l *dbLink) IsTransaction() bool { + return false +} + +// IsOnMaster checks and returns whether current link is operated on master node. +func (l *dbLink) IsOnMaster() bool { + return l.isOnMaster +} + +// IsTransaction returns if current Link is a transaction. +func (l *txLink) IsTransaction() bool { + return true +} + +// IsOnMaster checks and returns whether current link is operated on master node. +// Note that, transaction operation is always operated on master node. +func (l *txLink) IsOnMaster() bool { + return true +} diff --git a/database/gdb_core_stats.go b/database/gdb_core_stats.go new file mode 100644 index 0000000..7f7cd4f --- /dev/null +++ b/database/gdb_core_stats.go @@ -0,0 +1,45 @@ +// Copyright GoFrame Author(https://goframe.org). All Rights Reserved. +// +// This Source Code Form is subject to the terms of the MIT License. +// If a copy of the MIT was not distributed with this file, +// You can obtain one at https://github.com/gogf/gf. +// + +package database + +import ( + "context" + "database/sql" +) + +type localStatsItem struct { + node *ConfigNode + stats sql.DBStats +} + +// Node returns the configuration node info. +func (item *localStatsItem) Node() ConfigNode { + return *item.node +} + +// Stats returns the connection stat for current node. +func (item *localStatsItem) Stats() sql.DBStats { + return item.stats +} + +// Stats retrieves and returns the pool stat for all nodes that have been established. 
+func (c *Core) Stats(ctx context.Context) []StatsItem { + var items = make([]StatsItem, 0) + c.links.Iterator(func(k ConfigNode, v *sql.DB) bool { + // Create a local copy of k to avoid loop variable address re-use issue + // In Go, loop variables are re-used with the same memory address across iterations, + // directly using &k would cause all localStatsItem instances to share the same address + node := k + items = append(items, &localStatsItem{ + node: &node, + stats: v.Stats(), + }) + return true + }) + return items +} diff --git a/database/gdb_core_structure.go b/database/gdb_core_structure.go new file mode 100644 index 0000000..e90d125 --- /dev/null +++ b/database/gdb_core_structure.go @@ -0,0 +1,512 @@ +// Copyright GoFrame Author(https://goframe.org). All Rights Reserved. +// +// This Source Code Form is subject to the terms of the MIT License. +// If a copy of the MIT was not distributed with this file, +// You can obtain one at https://github.com/gogf/gf. + +package database + +import ( + "context" + "database/sql/driver" + "math/big" + "reflect" + "strings" + "time" + + "git.magicany.cc/black1552/gin-base/database/intlog" + "git.magicany.cc/black1552/gin-base/database/json" + "github.com/gogf/gf/v2/encoding/gbinary" + "github.com/gogf/gf/v2/errors/gerror" + "github.com/gogf/gf/v2/os/gtime" + "github.com/gogf/gf/v2/text/gregex" + "github.com/gogf/gf/v2/text/gstr" + "github.com/gogf/gf/v2/util/gconv" + "github.com/gogf/gf/v2/util/gutil" +) + +// GetFieldTypeStr retrieves and returns the field type string for certain field by name. 
func (c *Core) GetFieldTypeStr(ctx context.Context, fieldName, table, schema string) string {
	field := c.GetFieldType(ctx, fieldName, table, schema)
	if field != nil {
		// Kinds of data type examples:
		// year(4)
		// datetime
		// varchar(64)
		// bigint(20)
		// int(10) unsigned
		typeName := gstr.StrTillEx(field.Type, "(") // int(10) unsigned -> int
		if typeName != "" {
			typeName = gstr.Trim(typeName)
		} else {
			// No "(" present (e.g. "datetime"): use the raw type string.
			typeName = field.Type
		}
		return typeName
	}
	// Field not found or TableFields lookup failed: return empty string.
	return ""
}

// GetFieldType retrieves and returns the field type object for certain field by name.
// It returns nil if the table fields cannot be fetched or the field does not exist.
func (c *Core) GetFieldType(ctx context.Context, fieldName, table, schema string) *TableField {
	fieldsMap, err := c.db.TableFields(ctx, table, schema)
	if err != nil {
		// Lookup failure is logged internally and reported to the caller as nil.
		intlog.Errorf(
			ctx,
			`TableFields failed for table "%s", schema "%s": %+v`,
			table, schema, err,
		)
		return nil
	}
	for tableFieldName, tableField := range fieldsMap {
		if tableFieldName == fieldName {
			return tableField
		}
	}
	return nil
}

// ConvertDataForRecord is a very important function, which does converting for any data that
// will be inserted into table/collection as a record.
//
// The parameter `value` should be type of *map/map/*struct/struct.
// It supports embedded struct definition for struct.
func (c *Core) ConvertDataForRecord(ctx context.Context, value any, table string) (map[string]any, error) {
	var (
		err  error
		data = MapOrStructToMapDeep(value, true)
	)
	// Convert each field value according to its database column type.
	for fieldName, fieldValue := range data {
		var fieldType = c.GetFieldTypeStr(ctx, fieldName, table, c.GetSchema())
		data[fieldName], err = c.db.ConvertValueForField(
			ctx,
			fieldType,
			fieldValue,
		)
		if err != nil {
			return nil, gerror.Wrapf(err, `ConvertDataForRecord failed for value: %#v`, fieldValue)
		}
	}
	return data, nil
}

// ConvertValueForField converts value to the type of the record field.
// The parameter `fieldType` is the target record field.
+// The parameter `fieldValue` is the value that to be committed to record field. +func (c *Core) ConvertValueForField(ctx context.Context, fieldType string, fieldValue any) (any, error) { + var ( + err error + convertedValue = fieldValue + ) + switch fieldValue.(type) { + case time.Time, *time.Time, gtime.Time, *gtime.Time: + goto Default + } + // If `value` implements interface `driver.Valuer`, it then uses the interface for value converting. + if valuer, ok := fieldValue.(driver.Valuer); ok { + if convertedValue, err = valuer.Value(); err != nil { + return nil, err + } + return convertedValue, nil + } +Default: + // Default value converting. + var ( + rvValue = reflect.ValueOf(fieldValue) + rvKind = rvValue.Kind() + ) + for rvKind == reflect.Pointer { + rvValue = rvValue.Elem() + rvKind = rvValue.Kind() + } + switch rvKind { + case reflect.Invalid: + convertedValue = nil + + case reflect.Slice, reflect.Array, reflect.Map: + // It should ignore the bytes type. + if _, ok := fieldValue.([]byte); !ok { + // Convert the value to JSON. + convertedValue, err = json.Marshal(fieldValue) + if err != nil { + return nil, err + } + } + case reflect.Struct: + switch r := fieldValue.(type) { + // If the time is zero, it then updates it to nil, + // which will insert/update the value to database as "null". + case time.Time: + if r.IsZero() { + convertedValue = nil + } else { + switch fieldType { + case fieldTypeYear: + convertedValue = r.Format("2006") + case fieldTypeDate: + convertedValue = r.Format("2006-01-02") + case fieldTypeTime: + convertedValue = r.Format("15:04:05") + default: + } + } + + case *time.Time: + if r == nil { + // Nothing to do. 
+ } else { + switch fieldType { + case fieldTypeYear: + convertedValue = r.Format("2006") + case fieldTypeDate: + convertedValue = r.Format("2006-01-02") + case fieldTypeTime: + convertedValue = r.Format("15:04:05") + default: + } + } + + case gtime.Time: + if r.IsZero() { + convertedValue = nil + } else { + switch fieldType { + case fieldTypeYear: + convertedValue = r.Layout("2006") + case fieldTypeDate: + convertedValue = r.Layout("2006-01-02") + case fieldTypeTime: + convertedValue = r.Layout("15:04:05") + default: + convertedValue = r.Time + } + } + + case *gtime.Time: + if r.IsZero() { + convertedValue = nil + } else { + switch fieldType { + case fieldTypeYear: + convertedValue = r.Layout("2006") + case fieldTypeDate: + convertedValue = r.Layout("2006-01-02") + case fieldTypeTime: + convertedValue = r.Layout("15:04:05") + default: + convertedValue = r.Time + } + } + + case Counter, *Counter: + // Nothing to do. + + default: + // If `value` implements interface iNil, + // check its IsNil() function, if got ture, + // which will insert/update the value to database as "null". + if v, ok := fieldValue.(iNil); ok && v.IsNil() { + convertedValue = nil + } else if s, ok := fieldValue.(iString); ok { + // Use string conversion in default. + convertedValue = s.String() + } else { + // Convert the value to JSON. + convertedValue, err = json.Marshal(fieldValue) + if err != nil { + return nil, err + } + } + } + default: + } + + return convertedValue, nil +} + +// GetFormattedDBTypeNameForField retrieves and returns the formatted database type name +// eg. `int(10) unsigned` -> `int`, `varchar(100)` -> `varchar`, etc. 
+func (c *Core) GetFormattedDBTypeNameForField(fieldType string) (typeName, typePattern string) { + match, _ := gregex.MatchString(`(.+?)\((.+)\)`, fieldType) + if len(match) == 3 { + typeName = gstr.Trim(match[1]) + typePattern = gstr.Trim(match[2]) + } else { + var array = gstr.SplitAndTrim(fieldType, " ") + if len(array) > 1 && gstr.Equal(array[0], "unsigned") { + typeName = array[1] + } else if len(array) > 0 { + typeName = array[0] + } + } + typeName = strings.ToLower(typeName) + return +} + +// CheckLocalTypeForField checks and returns corresponding type for given db type. +// The `fieldType` is retrieved from ColumnTypes of db driver, example: +// UNSIGNED INT +func (c *Core) CheckLocalTypeForField(ctx context.Context, fieldType string, _ any) (LocalType, error) { + var ( + typeName string + typePattern string + ) + typeName, typePattern = c.GetFormattedDBTypeNameForField(fieldType) + switch typeName { + case + fieldTypeBinary, + fieldTypeVarbinary, + fieldTypeBlob, + fieldTypeTinyblob, + fieldTypeMediumblob, + fieldTypeLongblob: + return LocalTypeBytes, nil + + case + fieldTypeInt, + fieldTypeTinyint, + fieldTypeSmallInt, + fieldTypeSmallint, + fieldTypeMediumInt, + fieldTypeMediumint, + fieldTypeSerial: + if gstr.ContainsI(fieldType, "unsigned") { + return LocalTypeUint, nil + } + return LocalTypeInt, nil + + case + fieldTypeBigInt, + fieldTypeBigint, + fieldTypeBigserial: + if gstr.ContainsI(fieldType, "unsigned") { + return LocalTypeUint64, nil + } + return LocalTypeInt64, nil + + case + fieldTypeInt128, + fieldTypeInt256, + fieldTypeUint128, + fieldTypeUint256: + return LocalTypeBigInt, nil + + case + fieldTypeReal: + return LocalTypeFloat32, nil + + case + fieldTypeDecimal, + fieldTypeMoney, + fieldTypeNumeric, + fieldTypeSmallmoney: + return LocalTypeString, nil + case + fieldTypeFloat, + fieldTypeDouble: + return LocalTypeFloat64, nil + + case + fieldTypeBit: + // It is suggested using bit(1) as boolean. 
+ if typePattern == "1" { + return LocalTypeBool, nil + } + if gstr.ContainsI(fieldType, "unsigned") { + return LocalTypeUint64Bytes, nil + } + return LocalTypeInt64Bytes, nil + + case + fieldTypeBool: + return LocalTypeBool, nil + + case + fieldTypeDate: + return LocalTypeDate, nil + + case + fieldTypeTime: + return LocalTypeTime, nil + + case + fieldTypeDatetime, + fieldTypeTimestamp, + fieldTypeTimestampz: + return LocalTypeDatetime, nil + + case + fieldTypeJson: + return LocalTypeJson, nil + + case + fieldTypeJsonb: + return LocalTypeJsonb, nil + + default: + // Auto-detect field type, using key match. + switch { + case strings.Contains(typeName, "text") || strings.Contains(typeName, "char") || strings.Contains(typeName, "character"): + return LocalTypeString, nil + + case strings.Contains(typeName, "float") || strings.Contains(typeName, "double") || strings.Contains(typeName, "numeric"): + return LocalTypeFloat64, nil + + case strings.Contains(typeName, "bool"): + return LocalTypeBool, nil + + case strings.Contains(typeName, "binary") || strings.Contains(typeName, "blob"): + return LocalTypeBytes, nil + + case strings.Contains(typeName, "int"): + if gstr.ContainsI(fieldType, "unsigned") { + return LocalTypeUint, nil + } + return LocalTypeInt, nil + + case strings.Contains(typeName, "time"): + return LocalTypeDatetime, nil + + case strings.Contains(typeName, "date"): + return LocalTypeDatetime, nil + + default: + return LocalTypeString, nil + } + } +} + +// ConvertValueForLocal converts value to local Golang type of value according field type name from database. +// The parameter `fieldType` is in lower case, like: +// `float(5,2)`, `unsigned double(5,2)`, `decimal(10,2)`, `char(45)`, `varchar(100)`, etc. +func (c *Core) ConvertValueForLocal( + ctx context.Context, fieldType string, fieldValue any, +) (any, error) { + // If there's no type retrieved, it returns the `fieldValue` directly + // to use its original data type, as `fieldValue` is type of any. 
+ if fieldType == "" { + return fieldValue, nil + } + typeName, err := c.db.CheckLocalTypeForField(ctx, fieldType, fieldValue) + if err != nil { + return nil, err + } + switch typeName { + case LocalTypeBytes: + var typeNameStr = string(typeName) + if strings.Contains(typeNameStr, "binary") || strings.Contains(typeNameStr, "blob") { + return fieldValue, nil + } + return gconv.Bytes(fieldValue), nil + + case LocalTypeInt: + return gconv.Int(gconv.String(fieldValue)), nil + + case LocalTypeUint: + return gconv.Uint(gconv.String(fieldValue)), nil + + case LocalTypeInt64: + return gconv.Int64(gconv.String(fieldValue)), nil + + case LocalTypeUint64: + return gconv.Uint64(gconv.String(fieldValue)), nil + + case LocalTypeInt64Bytes: + return gbinary.BeDecodeToInt64(gconv.Bytes(fieldValue)), nil + + case LocalTypeUint64Bytes: + return gbinary.BeDecodeToUint64(gconv.Bytes(fieldValue)), nil + + case LocalTypeBigInt: + switch v := fieldValue.(type) { + case big.Int: + return v.String(), nil + case *big.Int: + return v.String(), nil + default: + return gconv.String(fieldValue), nil + } + + case LocalTypeFloat32: + return gconv.Float32(gconv.String(fieldValue)), nil + + case LocalTypeFloat64: + return gconv.Float64(gconv.String(fieldValue)), nil + + case LocalTypeBool: + s := gconv.String(fieldValue) + // mssql is true|false string. 
+ if strings.EqualFold(s, "true") { + return 1, nil + } + if strings.EqualFold(s, "false") { + return 0, nil + } + return gconv.Bool(fieldValue), nil + + case LocalTypeDate: + if t, ok := fieldValue.(time.Time); ok { + return gtime.NewFromTime(t).Format("Y-m-d"), nil + } + t, _ := gtime.StrToTime(gconv.String(fieldValue)) + return t.Format("Y-m-d"), nil + + case LocalTypeTime: + if t, ok := fieldValue.(time.Time); ok { + return gtime.NewFromTime(t).Format("H:i:s"), nil + } + t, _ := gtime.StrToTime(gconv.String(fieldValue)) + return t.Format("H:i:s"), nil + + case LocalTypeDatetime: + if t, ok := fieldValue.(time.Time); ok { + return gtime.NewFromTime(t), nil + } + t, _ := gtime.StrToTime(gconv.String(fieldValue)) + return t, nil + + default: + return gconv.String(fieldValue), nil + } +} + +// mappingAndFilterData automatically mappings the map key to table field and removes +// all key-value pairs that are not the field of given table. +func (c *Core) mappingAndFilterData(ctx context.Context, schema, table string, data map[string]any, filter bool) (map[string]any, error) { + fieldsMap, err := c.db.TableFields(ctx, c.guessPrimaryTableName(table), schema) + if err != nil { + return nil, err + } + if len(fieldsMap) == 0 { + return nil, gerror.Newf(`The table %s may not exist, or the table contains no fields`, table) + } + fieldsKeyMap := make(map[string]any, len(fieldsMap)) + for k := range fieldsMap { + fieldsKeyMap[k] = nil + } + // Automatic data key to table field name mapping. + var foundKey string + for dataKey, dataValue := range data { + if _, ok := fieldsKeyMap[dataKey]; !ok { + foundKey, _ = gutil.MapPossibleItemByKey(fieldsKeyMap, dataKey) + if foundKey != "" { + if _, ok = data[foundKey]; !ok { + data[foundKey] = dataValue + } + delete(data, dataKey) + } + } + } + // Data filtering. + // It deletes all key-value pairs that has incorrect field name. 
+ if filter { + for dataKey := range data { + if _, ok := fieldsMap[dataKey]; !ok { + delete(data, dataKey) + } + } + if len(data) == 0 { + return nil, gerror.Newf(`input data match no fields in table %s`, table) + } + } + return data, nil +} diff --git a/database/gdb_core_trace.go b/database/gdb_core_trace.go new file mode 100644 index 0000000..b4d440b --- /dev/null +++ b/database/gdb_core_trace.go @@ -0,0 +1,84 @@ +// Copyright GoFrame Author(https://goframe.org). All Rights Reserved. +// +// This Source Code Form is subject to the terms of the MIT License. +// If a copy of the MIT was not distributed with this file, +// You can obtain one at https://github.com/gogf/gf. +// + +package database + +import ( + "context" + "fmt" + + "go.opentelemetry.io/otel/attribute" + "go.opentelemetry.io/otel/codes" + semconv "go.opentelemetry.io/otel/semconv/v1.18.0" + "go.opentelemetry.io/otel/trace" + + "github.com/gogf/gf/v2/net/gtrace" +) + +const ( + traceInstrumentName = "github.com/gogf/gf/v2/database/gdb" + traceAttrDbType = "db.type" + traceAttrDbHost = "db.host" + traceAttrDbPort = "db.port" + traceAttrDbName = "db.name" + traceAttrDbUser = "db.user" + traceAttrDbLink = "db.link" + traceAttrDbGroup = "db.group" + traceEventDbExecution = "db.execution" + traceEventDbExecutionCost = "db.execution.cost" + traceEventDbExecutionRows = "db.execution.rows" + traceEventDbExecutionTxID = "db.execution.txid" + traceEventDbExecutionType = "db.execution.type" +) + +// addSqlToTracing adds sql information to tracer if it's enabled. +func (c *Core) traceSpanEnd(ctx context.Context, span trace.Span, sql *Sql) { + if gtrace.IsUsingDefaultProvider() || !gtrace.IsTracingInternal() { + return + } + if sql.Error != nil { + span.SetStatus(codes.Error, fmt.Sprintf(`%+v`, sql.Error)) + } + labels := make([]attribute.KeyValue, 0) + labels = append(labels, gtrace.CommonLabels()...) 
+ labels = append(labels, + attribute.String(traceAttrDbType, c.db.GetConfig().Type), + semconv.DBStatement(sql.Format), + ) + if c.db.GetConfig().Host != "" { + labels = append(labels, attribute.String(traceAttrDbHost, c.db.GetConfig().Host)) + } + if c.db.GetConfig().Port != "" { + labels = append(labels, attribute.String(traceAttrDbPort, c.db.GetConfig().Port)) + } + if c.db.GetConfig().Name != "" { + labels = append(labels, attribute.String(traceAttrDbName, c.db.GetConfig().Name)) + } + if c.db.GetConfig().User != "" { + labels = append(labels, attribute.String(traceAttrDbUser, c.db.GetConfig().User)) + } + if filteredLink := c.db.GetCore().FilteredLink(); filteredLink != "" { + labels = append(labels, attribute.String(traceAttrDbLink, c.db.GetCore().FilteredLink())) + } + if group := c.db.GetGroup(); group != "" { + labels = append(labels, attribute.String(traceAttrDbGroup, group)) + } + span.SetAttributes(labels...) + events := []attribute.KeyValue{ + attribute.String(traceEventDbExecutionCost, fmt.Sprintf(`%d ms`, sql.End-sql.Start)), + attribute.String(traceEventDbExecutionRows, fmt.Sprintf(`%d`, sql.RowsAffected)), + } + if sql.IsTransaction { + if v := ctx.Value(transactionIdForLoggerCtx); v != nil { + events = append(events, attribute.String( + traceEventDbExecutionTxID, fmt.Sprintf(`%d`, v.(uint64)), + )) + } + } + events = append(events, attribute.String(traceEventDbExecutionType, string(sql.Type))) + span.AddEvent(traceEventDbExecution, trace.WithAttributes(events...)) +} diff --git a/database/gdb_core_transaction.go b/database/gdb_core_transaction.go new file mode 100644 index 0000000..eddd413 --- /dev/null +++ b/database/gdb_core_transaction.go @@ -0,0 +1,295 @@ +// Copyright GoFrame Author(https://goframe.org). All Rights Reserved. +// +// This Source Code Form is subject to the terms of the MIT License. +// If a copy of the MIT was not distributed with this file, +// You can obtain one at https://github.com/gogf/gf. 
+ +package database + +import ( + "context" + "database/sql" + + "github.com/gogf/gf/v2/container/gtype" + "github.com/gogf/gf/v2/errors/gcode" + "github.com/gogf/gf/v2/errors/gerror" +) + +// Propagation defines transaction propagation behavior. +type Propagation string + +const ( + // PropagationNested starts a nested transaction if already in a transaction, + // or behaves like PropagationRequired if not in a transaction. + // + // It is the default behavior. + PropagationNested Propagation = "NESTED" + + // PropagationRequired starts a new transaction if not in a transaction, + // or uses the existing transaction if already in a transaction. + PropagationRequired Propagation = "REQUIRED" + + // PropagationSupports executes within the existing transaction if present, + // otherwise executes without transaction. + PropagationSupports Propagation = "SUPPORTS" + + // PropagationRequiresNew starts a new transaction, and suspends the current transaction if one exists. + PropagationRequiresNew Propagation = "REQUIRES_NEW" + + // PropagationNotSupported executes non-transactional, suspends any existing transaction. + PropagationNotSupported Propagation = "NOT_SUPPORTED" + + // PropagationMandatory executes in a transaction, fails if no existing transaction. + PropagationMandatory Propagation = "MANDATORY" + + // PropagationNever executes non-transactional, fails if in an existing transaction. + PropagationNever Propagation = "NEVER" +) + +// TxOptions defines options for transaction control. +type TxOptions struct { + // Propagation specifies the propagation behavior. + Propagation Propagation + // Isolation is the transaction isolation level. + // If zero, the driver or database's default level is used. + Isolation sql.IsolationLevel + // ReadOnly is used to mark the transaction as read-only. 
+ ReadOnly bool +} + +// Context key types for transaction to avoid collisions +type transactionCtxKey string + +const ( + transactionPointerPrefix = "transaction" + contextTransactionKeyPrefix = "TransactionObjectForGroup_" + transactionIdForLoggerCtx transactionCtxKey = "TransactionId" +) + +var transactionIdGenerator = gtype.NewUint64() + +// DefaultTxOptions returns the default transaction options. +func DefaultTxOptions() TxOptions { + return TxOptions{ + // Note the default propagation type is PropagationNested not PropagationRequired. + Propagation: PropagationNested, + } +} + +// Begin starts and returns the transaction object. +// You should call Commit or Rollback functions of the transaction object +// if you no longer use the transaction. Commit or Rollback functions will also +// close the transaction automatically. +func (c *Core) Begin(ctx context.Context) (tx TX, err error) { + return c.BeginWithOptions(ctx, DefaultTxOptions()) +} + +// BeginWithOptions starts and returns the transaction object with given options. +// The options allow specifying the isolation level and read-only mode. +// You should call Commit or Rollback functions of the transaction object +// if you no longer use the transaction. Commit or Rollback functions will also +// close the transaction automatically. 
func (c *Core) BeginWithOptions(ctx context.Context, opts TxOptions) (tx TX, err error) {
	if ctx == nil {
		ctx = c.db.GetCtx()
	}
	// Ensure the internal ctx data (current config node) is present for the
	// whole transaction lifetime.
	ctx = c.injectInternalCtxData(ctx)
	return c.doBeginCtx(ctx, sql.TxOptions{
		Isolation: opts.Isolation,
		ReadOnly:  opts.ReadOnly,
	})
}

// doBeginCtx starts a transaction on the master node by committing a "BEGIN"
// statement through the DoCommit pipeline, and returns the resulting TX.
func (c *Core) doBeginCtx(ctx context.Context, opts sql.TxOptions) (TX, error) {
	// Transactions are always started on the master node.
	master, err := c.db.Master()
	if err != nil {
		return nil, err
	}
	var out DoCommitOutput
	out, err = c.db.DoCommit(ctx, DoCommitInput{
		Db:            master,
		Sql:           "BEGIN",
		Type:          SqlTypeBegin,
		TxOptions:     opts,
		IsTransaction: true,
	})
	return out.Tx, err
}

// Transaction wraps the transaction logic using function `f`.
// It rollbacks the transaction and returns the error from function `f` if
// it returns non-nil error. It commits the transaction and returns nil if
// function `f` returns nil.
//
// Note that, you should not Commit or Rollback the transaction in function `f`
// as it is automatically handled by this function.
func (c *Core) Transaction(ctx context.Context, f func(ctx context.Context, tx TX) error) (err error) {
	// Delegates to TransactionWithOptions with the default (NESTED) propagation.
	return c.TransactionWithOptions(ctx, DefaultTxOptions(), f)
}

// TransactionWithOptions wraps the transaction logic with propagation options using function `f`.
+func (c *Core) TransactionWithOptions( + ctx context.Context, opts TxOptions, f func(ctx context.Context, tx TX) error, +) (err error) { + if ctx == nil { + ctx = c.db.GetCtx() + } + ctx = c.injectInternalCtxData(ctx) + + // Check current transaction from context + var ( + group = c.db.GetGroup() + currentTx = TXFromCtx(ctx, group) + ) + switch opts.Propagation { + case PropagationRequired: + if currentTx != nil { + return f(ctx, currentTx) + } + return c.createNewTransaction(ctx, opts, f) + + case PropagationSupports: + if currentTx == nil { + currentTx = c.newEmptyTX() + } + return f(ctx, currentTx) + + case PropagationMandatory: + if currentTx == nil { + return gerror.NewCode( + gcode.CodeInvalidOperation, + "transaction propagation MANDATORY requires an existing transaction", + ) + } + return f(ctx, currentTx) + + case PropagationRequiresNew: + ctx = WithoutTX(ctx, group) + return c.createNewTransaction(ctx, opts, f) + + case PropagationNotSupported: + ctx = WithoutTX(ctx, group) + return f(ctx, c.newEmptyTX()) + + case PropagationNever: + if currentTx != nil { + return gerror.NewCode( + gcode.CodeInvalidOperation, + "transaction propagation NEVER cannot run within an existing transaction", + ) + } + ctx = WithoutTX(ctx, group) + return f(ctx, c.newEmptyTX()) + + case PropagationNested: + if currentTx != nil { + return currentTx.Transaction(ctx, f) + } + return c.createNewTransaction(ctx, opts, f) + + default: + return gerror.NewCodef( + gcode.CodeInvalidParameter, + "unsupported propagation behavior: %s", + opts.Propagation, + ) + } +} + +// createNewTransaction handles creating and managing a new transaction +func (c *Core) createNewTransaction( + ctx context.Context, opts TxOptions, f func(ctx context.Context, tx TX) error, +) (err error) { + // Begin transaction with options + tx, err := c.doBeginCtx(ctx, sql.TxOptions{ + Isolation: opts.Isolation, + ReadOnly: opts.ReadOnly, + }) + if err != nil { + return err + } + + // Inject transaction object into 
context
+	ctx = WithTX(tx.GetCtx(), tx)
+	err = callTxFunc(tx.Ctx(ctx), f)
+	return
+}
+
+func callTxFunc(tx TX, f func(ctx context.Context, tx TX) error) (err error) {
+	defer func() {
+		if err == nil {
+			if exception := recover(); exception != nil {
+				if v, ok := exception.(error); ok && gerror.HasStack(v) {
+					err = v
+				} else {
+					err = gerror.NewCodef(gcode.CodeInternalPanic, "%+v", exception)
+				}
+			}
+		}
+		if err != nil {
+			if e := tx.Rollback(); e != nil {
+				err = e
+			}
+		} else {
+			if e := tx.Commit(); e != nil {
+				err = e
+			}
+		}
+	}()
+	err = f(tx.GetCtx(), tx)
+	return
+}
+
+// WithTX injects given transaction object into context and returns a new context.
+func WithTX(ctx context.Context, tx TX) context.Context {
+	if tx == nil {
+		return ctx
+	}
+	// Check repeat injection from given.
+	group := tx.GetDB().GetGroup()
+	if ctxTx := TXFromCtx(ctx, group); ctxTx != nil && ctxTx.GetDB().GetGroup() == group {
+		return ctx
+	}
+	dbCtx := tx.GetDB().GetCtx()
+	if ctxTx := TXFromCtx(dbCtx, group); ctxTx != nil && ctxTx.GetDB().GetGroup() == group {
+		return dbCtx
+	}
+	// Inject transaction object and id into context.
+	ctx = context.WithValue(ctx, transactionKeyForContext(group), tx)
+	ctx = context.WithValue(ctx, transactionIdForLoggerCtx, tx.GetCtx().Value(transactionIdForLoggerCtx))
+	return ctx
+}
+
+// WithoutTX removes transaction object from context and returns a new context.
+func WithoutTX(ctx context.Context, group string) context.Context {
+	ctx = context.WithValue(ctx, transactionKeyForContext(group), nil)
+	ctx = context.WithValue(ctx, transactionIdForLoggerCtx, nil)
+	return ctx
+}
+
+// TXFromCtx retrieves and returns transaction object from context.
+// It is usually used in nested transaction feature, and it returns nil if it is not set previously. 
+func TXFromCtx(ctx context.Context, group string) TX { + if ctx == nil { + return nil + } + v := ctx.Value(transactionKeyForContext(group)) + if v != nil { + tx := v.(TX) + if tx.IsClosed() { + return nil + } + // no underlying sql tx. + if tx.GetSqlTX() == nil { + return nil + } + tx = tx.Ctx(ctx) + return tx + } + return nil +} + +// transactionKeyForContext forms and returns a key for storing transaction object of certain database group into context. +func transactionKeyForContext(group string) transactionCtxKey { + return transactionCtxKey(contextTransactionKeyPrefix + group) +} diff --git a/database/gdb_core_txcore.go b/database/gdb_core_txcore.go new file mode 100644 index 0000000..f6dc43c --- /dev/null +++ b/database/gdb_core_txcore.go @@ -0,0 +1,437 @@ +// Copyright GoFrame Author(https://goframe.org). All Rights Reserved. +// +// This Source Code Form is subject to the terms of the MIT License. +// If a copy of the MIT was not distributed with this file, +// You can obtain one at https://github.com/gogf/gf. + +package database + +import ( + "context" + "database/sql" + "reflect" + + "git.magicany.cc/black1552/gin-base/database/reflection" + "github.com/gogf/gf/v2/errors/gcode" + "github.com/gogf/gf/v2/errors/gerror" + "github.com/gogf/gf/v2/text/gregex" + "github.com/gogf/gf/v2/util/gconv" +) + +// TXCore is the struct for transaction management. +type TXCore struct { + // db is the database management interface that implements the DB interface, + // providing access to database operations and configuration. + db DB + // tx is the underlying SQL transaction object from database/sql package, + // which manages the actual transaction operations. + tx *sql.Tx + // ctx is the context specific to this transaction, + // which can be used for timeout control and cancellation. + ctx context.Context + // master is the underlying master database connection pool, + // used for direct database operations when needed. 
+	master *sql.DB
+	// transactionId is a unique identifier for this transaction instance,
+	// used for tracking and debugging purposes.
+	transactionId string
+	// transactionCount tracks the number of nested transaction begins,
+	// used for managing transaction nesting depth.
+	transactionCount int
+	// isClosed indicates whether this transaction has been finalized
+	// through either a commit or rollback operation.
+	isClosed bool
+	// cancelFunc is the context cancellation function associated with ctx,
+	// used to cancel the transaction context when needed.
+	cancelFunc context.CancelFunc
+}
+
+func (c *Core) newEmptyTX() TX {
+	return &TXCore{
+		db: c.db,
+	}
+}
+
+// transactionKeyForNestedPoint forms and returns the transaction key at current save point.
+func (tx *TXCore) transactionKeyForNestedPoint() string {
+	return tx.db.GetCore().QuoteWord(
+		transactionPointerPrefix + gconv.String(tx.transactionCount),
+	)
+}
+
+// Ctx sets the context for current transaction.
+func (tx *TXCore) Ctx(ctx context.Context) TX {
+	tx.ctx = ctx
+	if tx.ctx != nil {
+		tx.ctx = tx.db.GetCore().injectInternalCtxData(tx.ctx)
+	}
+	return tx
+}
+
+// GetCtx returns the context for current transaction.
+func (tx *TXCore) GetCtx() context.Context {
+	return tx.ctx
+}
+
+// GetDB returns the DB for current transaction.
+func (tx *TXCore) GetDB() DB {
+	return tx.db
+}
+
+// GetSqlTX returns the underlying transaction object for current transaction.
+func (tx *TXCore) GetSqlTX() *sql.Tx {
+	return tx.tx
+}
+
+// Commit commits current transaction.
+// Note that it releases previous saved transaction point if it's in a nested transaction procedure,
+// or else it commits the whole transaction. 
+func (tx *TXCore) Commit() error {
+	if tx.transactionCount > 0 {
+		tx.transactionCount--
+		_, err := tx.Exec("RELEASE SAVEPOINT " + tx.transactionKeyForNestedPoint())
+		return err
+	}
+	_, err := tx.db.DoCommit(tx.ctx, DoCommitInput{
+		Tx:            tx.tx,
+		Sql:           "COMMIT",
+		Type:          SqlTypeTXCommit,
+		TxCancelFunc:  tx.cancelFunc,
+		IsTransaction: true,
+	})
+	if err == nil {
+		tx.isClosed = true
+	}
+	return err
+}
+
+// Rollback aborts current transaction.
+// Note that it aborts current transaction if it's in a nested transaction procedure,
+// or else it aborts the whole transaction.
+func (tx *TXCore) Rollback() error {
+	if tx.transactionCount > 0 {
+		tx.transactionCount--
+		_, err := tx.Exec("ROLLBACK TO SAVEPOINT " + tx.transactionKeyForNestedPoint())
+		return err
+	}
+	_, err := tx.db.DoCommit(tx.ctx, DoCommitInput{
+		Tx:            tx.tx,
+		Sql:           "ROLLBACK",
+		Type:          SqlTypeTXRollback,
+		TxCancelFunc:  tx.cancelFunc,
+		IsTransaction: true,
+	})
+	if err == nil {
+		tx.isClosed = true
+	}
+	return err
+}
+
+// IsClosed checks and returns this transaction has already been committed or rolled back.
+func (tx *TXCore) IsClosed() bool {
+	return tx.isClosed
+}
+
+// Begin starts a nested transaction procedure.
+func (tx *TXCore) Begin() error {
+	_, err := tx.Exec("SAVEPOINT " + tx.transactionKeyForNestedPoint())
+	if err != nil {
+		return err
+	}
+	tx.transactionCount++
+	return nil
+}
+
+// SavePoint performs `SAVEPOINT xxx` SQL statement that saves transaction at current point.
+// The parameter `point` specifies the point name that will be saved to server.
+func (tx *TXCore) SavePoint(point string) error {
+	_, err := tx.Exec("SAVEPOINT " + tx.db.GetCore().QuoteWord(point))
+	return err
+}
+
+// RollbackTo performs `ROLLBACK TO SAVEPOINT xxx` SQL statement that rollbacks to specified saved transaction.
+// The parameter `point` specifies the point name that was saved previously. 
+func (tx *TXCore) RollbackTo(point string) error { + _, err := tx.Exec("ROLLBACK TO SAVEPOINT " + tx.db.GetCore().QuoteWord(point)) + return err +} + +// Transaction wraps the transaction logic using function `f`. +// It rollbacks the transaction and returns the error from function `f` if +// it returns non-nil error. It commits the transaction and returns nil if +// function `f` returns nil. +// +// Note that, you should not Commit or Rollback the transaction in function `f` +// as it is automatically handled by this function. +func (tx *TXCore) Transaction(ctx context.Context, f func(ctx context.Context, tx TX) error) (err error) { + if ctx != nil { + tx.ctx = ctx + } + // Check transaction object from context. + if TXFromCtx(tx.ctx, tx.db.GetGroup()) == nil { + // Inject transaction object into context. + tx.ctx = WithTX(tx.ctx, tx) + } + if err = tx.Begin(); err != nil { + return err + } + err = callTxFunc(tx, f) + return +} + +// TransactionWithOptions wraps the transaction logic with propagation options using function `f`. +func (tx *TXCore) TransactionWithOptions( + ctx context.Context, opts TxOptions, f func(ctx context.Context, tx TX) error, +) (err error) { + return tx.db.TransactionWithOptions(ctx, opts, f) +} + +// Query does query operation on transaction. +// See Core.Query. +func (tx *TXCore) Query(sql string, args ...any) (result Result, err error) { + return tx.db.DoQuery(tx.ctx, &txLink{tx.tx}, sql, args...) +} + +// Exec does none query operation on transaction. +// See Core.Exec. +func (tx *TXCore) Exec(sql string, args ...any) (sql.Result, error) { + return tx.db.DoExec(tx.ctx, &txLink{tx.tx}, sql, args...) +} + +// Prepare creates a prepared statement for later queries or executions. +// Multiple queries or executions may be run concurrently from the +// returned statement. +// The caller must call the statement's Close method +// when the statement is no longer needed. 
+func (tx *TXCore) Prepare(sql string) (*Stmt, error) { + return tx.db.DoPrepare(tx.ctx, &txLink{tx.tx}, sql) +} + +// GetAll queries and returns data records from database. +func (tx *TXCore) GetAll(sql string, args ...any) (Result, error) { + return tx.Query(sql, args...) +} + +// GetOne queries and returns one record from database. +func (tx *TXCore) GetOne(sql string, args ...any) (Record, error) { + list, err := tx.GetAll(sql, args...) + if err != nil { + return nil, err + } + if len(list) > 0 { + return list[0], nil + } + return nil, nil +} + +// GetStruct queries one record from database and converts it to given struct. +// The parameter `pointer` should be a pointer to struct. +func (tx *TXCore) GetStruct(obj any, sql string, args ...any) error { + one, err := tx.GetOne(sql, args...) + if err != nil { + return err + } + return one.Struct(obj) +} + +// GetStructs queries records from database and converts them to given struct. +// The parameter `pointer` should be type of struct slice: []struct/[]*struct. +func (tx *TXCore) GetStructs(objPointerSlice any, sql string, args ...any) error { + all, err := tx.GetAll(sql, args...) + if err != nil { + return err + } + return all.Structs(objPointerSlice) +} + +// GetScan queries one or more records from database and converts them to given struct or +// struct array. +// +// If parameter `pointer` is type of struct pointer, it calls GetStruct internally for +// the conversion. If parameter `pointer` is type of slice, it calls GetStructs internally +// for conversion. +func (tx *TXCore) GetScan(pointer any, sql string, args ...any) error { + reflectInfo := reflection.OriginTypeAndKind(pointer) + if reflectInfo.InputKind != reflect.Pointer { + return gerror.NewCodef( + gcode.CodeInvalidParameter, + "params should be type of pointer, but got: %v", + reflectInfo.InputKind, + ) + } + switch reflectInfo.OriginKind { + case reflect.Array, reflect.Slice: + return tx.GetStructs(pointer, sql, args...) 
+
+	case reflect.Struct:
+		return tx.GetStruct(pointer, sql, args...)
+
+	default:
+	}
+	return gerror.NewCodef(
+		gcode.CodeInvalidParameter,
+		`invalid parameter type "%v", of which element type should be type of struct/slice`,
+		reflectInfo.InputType,
+	)
+}
+
+// GetValue queries and returns the field value from database.
+// The sql should query only one field from database, or else it returns only one
+// field of the result.
+func (tx *TXCore) GetValue(sql string, args ...any) (Value, error) {
+	one, err := tx.GetOne(sql, args...)
+	if err != nil {
+		return nil, err
+	}
+	for _, v := range one {
+		return v, nil
+	}
+	return nil, nil
+}
+
+// GetCount queries and returns the count from database.
+func (tx *TXCore) GetCount(sql string, args ...any) (int64, error) {
+	if !gregex.IsMatchString(`(?i)SELECT\s+COUNT\(.+\)\s+FROM`, sql) {
+		sql, _ = gregex.ReplaceString(`(?i)(SELECT)\s+(.+)\s+(FROM)`, `$1 COUNT($2) $3`, sql)
+	}
+	value, err := tx.GetValue(sql, args...)
+	if err != nil {
+		return 0, err
+	}
+	return value.Int64(), nil
+}
+
+// Insert does "INSERT INTO ..." statement for the table.
+// If there's already one unique record of the data in the table, it returns error.
+//
+// The parameter `data` can be type of map/gmap/struct/*struct/[]map/[]struct, etc.
+// Eg:
+// Data(g.Map{"uid": 10000, "name":"john"})
+// Data(g.Slice{g.Map{"uid": 10000, "name":"john"}, g.Map{"uid": 20000, "name":"smith"})
+//
+// The parameter `batch` specifies the batch operation count when given data is slice.
+func (tx *TXCore) Insert(table string, data any, batch ...int) (sql.Result, error) {
+	if len(batch) > 0 {
+		return tx.Model(table).Ctx(tx.ctx).Data(data).Batch(batch[0]).Insert()
+	}
+	return tx.Model(table).Ctx(tx.ctx).Data(data).Insert()
+}
+
+// InsertIgnore does "INSERT IGNORE INTO ..." statement for the table.
+// If there's already one unique record of the data in the table, it ignores the inserting. 
+// +// The parameter `data` can be type of map/gmap/struct/*struct/[]map/[]struct, etc. +// Eg: +// Data(g.Map{"uid": 10000, "name":"john"}) +// Data(g.Slice{g.Map{"uid": 10000, "name":"john"}, g.Map{"uid": 20000, "name":"smith"}) +// +// The parameter `batch` specifies the batch operation count when given data is slice. +func (tx *TXCore) InsertIgnore(table string, data any, batch ...int) (sql.Result, error) { + if len(batch) > 0 { + return tx.Model(table).Ctx(tx.ctx).Data(data).Batch(batch[0]).InsertIgnore() + } + return tx.Model(table).Ctx(tx.ctx).Data(data).InsertIgnore() +} + +// InsertAndGetId performs action Insert and returns the last insert id that automatically generated. +func (tx *TXCore) InsertAndGetId(table string, data any, batch ...int) (int64, error) { + if len(batch) > 0 { + return tx.Model(table).Ctx(tx.ctx).Data(data).Batch(batch[0]).InsertAndGetId() + } + return tx.Model(table).Ctx(tx.ctx).Data(data).InsertAndGetId() +} + +// Replace does "REPLACE INTO ..." statement for the table. +// If there's already one unique record of the data in the table, it deletes the record +// and inserts a new one. +// +// The parameter `data` can be type of map/gmap/struct/*struct/[]map/[]struct, etc. +// Eg: +// Data(g.Map{"uid": 10000, "name":"john"}) +// Data(g.Slice{g.Map{"uid": 10000, "name":"john"}, g.Map{"uid": 20000, "name":"smith"}) +// +// The parameter `data` can be type of map/gmap/struct/*struct/[]map/[]struct, etc. +// If given data is type of slice, it then does batch replacing, and the optional parameter +// `batch` specifies the batch operation count. +func (tx *TXCore) Replace(table string, data any, batch ...int) (sql.Result, error) { + if len(batch) > 0 { + return tx.Model(table).Ctx(tx.ctx).Data(data).Batch(batch[0]).Replace() + } + return tx.Model(table).Ctx(tx.ctx).Data(data).Replace() +} + +// Save does "INSERT INTO ... ON DUPLICATE KEY UPDATE..." statement for the table. 
+// It updates the record if there's primary or unique index in the saving data, +// or else it inserts a new record into the table. +// +// The parameter `data` can be type of map/gmap/struct/*struct/[]map/[]struct, etc. +// Eg: +// Data(g.Map{"uid": 10000, "name":"john"}) +// Data(g.Slice{g.Map{"uid": 10000, "name":"john"}, g.Map{"uid": 20000, "name":"smith"}) +// +// If given data is type of slice, it then does batch saving, and the optional parameter +// `batch` specifies the batch operation count. +func (tx *TXCore) Save(table string, data any, batch ...int) (sql.Result, error) { + if len(batch) > 0 { + return tx.Model(table).Ctx(tx.ctx).Data(data).Batch(batch[0]).Save() + } + return tx.Model(table).Ctx(tx.ctx).Data(data).Save() +} + +// Update does "UPDATE ... " statement for the table. +// +// The parameter `data` can be type of string/map/gmap/struct/*struct, etc. +// Eg: "uid=10000", "uid", 10000, g.Map{"uid": 10000, "name":"john"} +// +// The parameter `condition` can be type of string/map/gmap/slice/struct/*struct, etc. +// It is commonly used with parameter `args`. +// Eg: +// "uid=10000", +// "uid", 10000 +// "money>? AND name like ?", 99999, "vip_%" +// "status IN (?)", g.Slice{1,2,3} +// "age IN(?,?)", 18, 50 +// User{ Id : 1, UserName : "john"}. +func (tx *TXCore) Update(table string, data any, condition any, args ...any) (sql.Result, error) { + return tx.Model(table).Ctx(tx.ctx).Data(data).Where(condition, args...).Update() +} + +// Delete does "DELETE FROM ... " statement for the table. +// +// The parameter `condition` can be type of string/map/gmap/slice/struct/*struct, etc. +// It is commonly used with parameter `args`. +// Eg: +// "uid=10000", +// "uid", 10000 +// "money>? AND name like ?", 99999, "vip_%" +// "status IN (?)", g.Slice{1,2,3} +// "age IN(?,?)", 18, 50 +// User{ Id : 1, UserName : "john"}. 
+func (tx *TXCore) Delete(table string, condition any, args ...any) (sql.Result, error) { + return tx.Model(table).Ctx(tx.ctx).Where(condition, args...).Delete() +} + +// QueryContext implements interface function Link.QueryContext. +func (tx *TXCore) QueryContext(ctx context.Context, sql string, args ...any) (*sql.Rows, error) { + return tx.tx.QueryContext(ctx, sql, args...) +} + +// ExecContext implements interface function Link.ExecContext. +func (tx *TXCore) ExecContext(ctx context.Context, sql string, args ...any) (sql.Result, error) { + return tx.tx.ExecContext(ctx, sql, args...) +} + +// PrepareContext implements interface function Link.PrepareContext. +func (tx *TXCore) PrepareContext(ctx context.Context, sql string) (*sql.Stmt, error) { + return tx.tx.PrepareContext(ctx, sql) +} + +// IsOnMaster implements interface function Link.IsOnMaster. +func (tx *TXCore) IsOnMaster() bool { + return true +} + +// IsTransaction implements interface function Link.IsTransaction. +func (tx *TXCore) IsTransaction() bool { + return tx != nil +} diff --git a/database/gdb_core_underlying.go b/database/gdb_core_underlying.go new file mode 100644 index 0000000..1753fbe --- /dev/null +++ b/database/gdb_core_underlying.go @@ -0,0 +1,533 @@ +// Copyright GoFrame Author(https://goframe.org). All Rights Reserved. +// +// This Source Code Form is subject to the terms of the MIT License. +// If a copy of the MIT was not distributed with this file, +// You can obtain one at https://github.com/gogf/gf. 
+// + +package database + +import ( + "context" + "database/sql" + "fmt" + "reflect" + + "git.magicany.cc/black1552/gin-base/database/intlog" + "go.opentelemetry.io/otel" + "go.opentelemetry.io/otel/trace" + + "github.com/gogf/gf/v2" + "github.com/gogf/gf/v2/container/gvar" + "github.com/gogf/gf/v2/errors/gcode" + "github.com/gogf/gf/v2/errors/gerror" + "github.com/gogf/gf/v2/os/gtime" + "github.com/gogf/gf/v2/util/gconv" + "github.com/gogf/gf/v2/util/guid" +) + +// Query commits one query SQL to underlying driver and returns the execution result. +// It is most commonly used for data querying. +func (c *Core) Query(ctx context.Context, sql string, args ...any) (result Result, err error) { + return c.db.DoQuery(ctx, nil, sql, args...) +} + +// DoQuery commits the sql string and its arguments to underlying driver +// through given link object and returns the execution result. +func (c *Core) DoQuery(ctx context.Context, link Link, sql string, args ...any) (result Result, err error) { + // Transaction checks. + if link == nil { + if tx := TXFromCtx(ctx, c.db.GetGroup()); tx != nil { + // Firstly, check and retrieve transaction link from context. + link = &txLink{tx.GetSqlTX()} + } else if link, err = c.SlaveLink(); err != nil { + // Or else it creates one from master node. + return nil, err + } + } else if !link.IsTransaction() { + // If current link is not transaction link, it checks and retrieves transaction from context. + if tx := TXFromCtx(ctx, c.db.GetGroup()); tx != nil { + link = &txLink{tx.GetSqlTX()} + } + } + + // Sql filtering. + sql, args = c.FormatSqlBeforeExecuting(sql, args) + sql, args, err = c.db.DoFilter(ctx, link, sql, args) + if err != nil { + return nil, err + } + // SQL format and retrieve. 
+ if v := ctx.Value(ctxKeyCatchSQL); v != nil { + var ( + manager = v.(*CatchSQLManager) + formattedSql = FormatSqlWithArgs(sql, args) + ) + manager.SQLArray.Append(formattedSql) + if !manager.DoCommit && ctx.Value(ctxKeyInternalProducedSQL) == nil { + return nil, nil + } + } + // Link execution. + var out DoCommitOutput + out, err = c.db.DoCommit(ctx, DoCommitInput{ + Link: link, + Sql: sql, + Args: args, + Stmt: nil, + Type: SqlTypeQueryContext, + IsTransaction: link.IsTransaction(), + }) + if err != nil { + return nil, err + } + return out.Records, err +} + +// Exec commits one query SQL to underlying driver and returns the execution result. +// It is most commonly used for data inserting and updating. +func (c *Core) Exec(ctx context.Context, sql string, args ...any) (result sql.Result, err error) { + return c.db.DoExec(ctx, nil, sql, args...) +} + +// DoExec commits the sql string and its arguments to underlying driver +// through given link object and returns the execution result. +func (c *Core) DoExec(ctx context.Context, link Link, sql string, args ...any) (result sql.Result, err error) { + // Transaction checks. + if link == nil { + if tx := TXFromCtx(ctx, c.db.GetGroup()); tx != nil { + // Firstly, check and retrieve transaction link from context. + link = &txLink{tx.GetSqlTX()} + } else if link, err = c.MasterLink(); err != nil { + // Or else it creates one from master node. + return nil, err + } + } else if !link.IsTransaction() { + // If current link is not transaction link, it tries retrieving transaction object from context. + if tx := TXFromCtx(ctx, c.db.GetGroup()); tx != nil { + link = &txLink{tx.GetSqlTX()} + } + } + + // SQL filtering. + sql, args = c.FormatSqlBeforeExecuting(sql, args) + sql, args, err = c.db.DoFilter(ctx, link, sql, args) + if err != nil { + return nil, err + } + // SQL format and retrieve. 
+ if v := ctx.Value(ctxKeyCatchSQL); v != nil { + var ( + manager = v.(*CatchSQLManager) + formattedSql = FormatSqlWithArgs(sql, args) + ) + manager.SQLArray.Append(formattedSql) + if !manager.DoCommit && ctx.Value(ctxKeyInternalProducedSQL) == nil { + return new(SqlResult), nil + } + } + // Link execution. + var out DoCommitOutput + out, err = c.db.DoCommit(ctx, DoCommitInput{ + Link: link, + Sql: sql, + Args: args, + Stmt: nil, + Type: SqlTypeExecContext, + IsTransaction: link.IsTransaction(), + }) + if err != nil { + return nil, err + } + return out.Result, err +} + +// DoFilter is a hook function, which filters the sql and its arguments before it's committed to underlying driver. +// The parameter `link` specifies the current database connection operation object. You can modify the sql +// string `sql` and its arguments `args` as you wish before they're committed to driver. +func (c *Core) DoFilter( + ctx context.Context, link Link, sql string, args []any, +) (newSql string, newArgs []any, err error) { + return sql, args, nil +} + +// DoCommit commits current sql and arguments to underlying sql driver. +func (c *Core) DoCommit(ctx context.Context, in DoCommitInput) (out DoCommitOutput, err error) { + var ( + sqlTx *sql.Tx + sqlStmt *sql.Stmt + sqlRows *sql.Rows + sqlResult sql.Result + stmtSqlRows *sql.Rows + stmtSqlRow *sql.Row + rowsAffected int64 + cancelFuncForTimeout context.CancelFunc + formattedSql = FormatSqlWithArgs(in.Sql, in.Args) + timestampMilli1 = gtime.TimestampMilli() + ) + + // Panic recovery to handle panics from underlying database drivers + defer func() { + if exception := recover(); exception != nil { + if err == nil { + if v, ok := exception.(error); ok && gerror.HasStack(v) { + err = v + } else { + err = gerror.WrapCodef(gcode.CodeDbOperationError, gerror.NewCodef(gcode.CodeInternalPanic, "%+v", exception), FormatSqlWithArgs(in.Sql, in.Args)) + } + } + } + }() + + // Trace span start. 
+	tr := otel.GetTracerProvider().Tracer(traceInstrumentName, trace.WithInstrumentationVersion(gf.VERSION))
+	ctx, span := tr.Start(ctx, string(in.Type), trace.WithSpanKind(trace.SpanKindClient))
+	defer span.End()
+
+	// Execution by type.
+	switch in.Type {
+	case SqlTypeBegin:
+		ctx, cancelFuncForTimeout = c.GetCtxTimeout(ctx, ctxTimeoutTypeTrans)
+		formattedSql = fmt.Sprintf(
+			`%s (IsolationLevel: %s, ReadOnly: %t)`,
+			formattedSql, in.TxOptions.Isolation.String(), in.TxOptions.ReadOnly,
+		)
+		if sqlTx, err = in.Db.BeginTx(ctx, &in.TxOptions); err == nil {
+			tx := &TXCore{
+				db:            c.db,
+				tx:            sqlTx,
+				ctx:           ctx,
+				master:        in.Db,
+				transactionId: guid.S(),
+				cancelFunc:    cancelFuncForTimeout,
+			}
+			tx.ctx = context.WithValue(ctx, transactionKeyForContext(tx.db.GetGroup()), tx)
+			tx.ctx = context.WithValue(tx.ctx, transactionIdForLoggerCtx, transactionIdGenerator.Add(1))
+			out.Tx = tx
+			ctx = out.Tx.GetCtx()
+		}
+		out.RawResult = sqlTx
+
+	case SqlTypeTXCommit:
+		if in.TxCancelFunc != nil {
+			defer in.TxCancelFunc()
+		}
+		err = in.Tx.Commit()
+
+	case SqlTypeTXRollback:
+		if in.TxCancelFunc != nil {
+			defer in.TxCancelFunc()
+		}
+		err = in.Tx.Rollback()
+
+	case SqlTypeExecContext:
+		ctx, cancelFuncForTimeout = c.GetCtxTimeout(ctx, ctxTimeoutTypeExec)
+		defer cancelFuncForTimeout()
+		if c.db.GetDryRun() {
+			sqlResult = new(SqlResult)
+		} else {
+			sqlResult, err = in.Link.ExecContext(ctx, in.Sql, in.Args...)
+		}
+		out.RawResult = sqlResult
+
+	case SqlTypeQueryContext:
+		ctx, cancelFuncForTimeout = c.GetCtxTimeout(ctx, ctxTimeoutTypeQuery)
+		defer cancelFuncForTimeout()
+		sqlRows, err = in.Link.QueryContext(ctx, in.Sql, in.Args...) 
+ out.RawResult = sqlRows + + case SqlTypePrepareContext: + ctx, cancelFuncForTimeout = c.GetCtxTimeout(ctx, ctxTimeoutTypePrepare) + defer cancelFuncForTimeout() + sqlStmt, err = in.Link.PrepareContext(ctx, in.Sql) + out.RawResult = sqlStmt + + case SqlTypeStmtExecContext: + ctx, cancelFuncForTimeout = c.GetCtxTimeout(ctx, ctxTimeoutTypeExec) + defer cancelFuncForTimeout() + if c.db.GetDryRun() { + sqlResult = new(SqlResult) + } else { + sqlResult, err = in.Stmt.ExecContext(ctx, in.Args...) + } + out.RawResult = sqlResult + + case SqlTypeStmtQueryContext: + ctx, cancelFuncForTimeout = c.GetCtxTimeout(ctx, ctxTimeoutTypeQuery) + defer cancelFuncForTimeout() + stmtSqlRows, err = in.Stmt.QueryContext(ctx, in.Args...) + out.RawResult = stmtSqlRows + + case SqlTypeStmtQueryRowContext: + ctx, cancelFuncForTimeout = c.GetCtxTimeout(ctx, ctxTimeoutTypeQuery) + defer cancelFuncForTimeout() + stmtSqlRow = in.Stmt.QueryRowContext(ctx, in.Args...) + out.RawResult = stmtSqlRow + + default: + panic(gerror.NewCodef(gcode.CodeInvalidParameter, `invalid SqlType "%s"`, in.Type)) + } + // Result handling. + switch { + case sqlResult != nil && !c.GetIgnoreResultFromCtx(ctx): + rowsAffected, err = sqlResult.RowsAffected() + out.Result = sqlResult + + case sqlRows != nil: + out.Records, err = c.RowsToResult(ctx, sqlRows) + rowsAffected = int64(len(out.Records)) + + case sqlStmt != nil: + out.Stmt = &Stmt{ + Stmt: sqlStmt, + core: c, + link: in.Link, + sql: in.Sql, + } + } + var ( + timestampMilli2 = gtime.TimestampMilli() + sqlObj = &Sql{ + Sql: in.Sql, + Type: in.Type, + Args: in.Args, + Format: formattedSql, + Error: err, + Start: timestampMilli1, + End: timestampMilli2, + Group: c.db.GetGroup(), + Schema: c.db.GetSchema(), + RowsAffected: rowsAffected, + IsTransaction: in.IsTransaction, + } + ) + + // Tracing. + c.traceSpanEnd(ctx, span, sqlObj) + + // Logging. 
+ if c.db.GetDebug() { + c.writeSqlToLogger(ctx, sqlObj) + } + if err != nil && err != sql.ErrNoRows { + err = gerror.WrapCode( + gcode.CodeDbOperationError, + err, + FormatSqlWithArgs(in.Sql, in.Args), + ) + } + return out, err +} + +// Prepare creates a prepared statement for later queries or executions. +// Multiple queries or executions may be run concurrently from the +// returned statement. +// The caller must call the statement's Close method +// when the statement is no longer needed. +// +// The parameter `execOnMaster` specifies whether executing the sql on master node, +// or else it executes the sql on slave node if master-slave configured. +func (c *Core) Prepare(ctx context.Context, sql string, execOnMaster ...bool) (*Stmt, error) { + var ( + err error + link Link + ) + if len(execOnMaster) > 0 && execOnMaster[0] { + if link, err = c.MasterLink(); err != nil { + return nil, err + } + } else { + if link, err = c.SlaveLink(); err != nil { + return nil, err + } + } + return c.db.DoPrepare(ctx, link, sql) +} + +// DoPrepare calls prepare function on given link object and returns the statement object. +func (c *Core) DoPrepare(ctx context.Context, link Link, sql string) (stmt *Stmt, err error) { + // Transaction checks. + if link == nil { + if tx := TXFromCtx(ctx, c.db.GetGroup()); tx != nil { + // Firstly, check and retrieve transaction link from context. + link = &txLink{tx.GetSqlTX()} + } else { + // Or else it creates one from master node. + if link, err = c.MasterLink(); err != nil { + return nil, err + } + } + } else if !link.IsTransaction() { + // If current link is not transaction link, it checks and retrieves transaction from context. + if tx := TXFromCtx(ctx, c.db.GetGroup()); tx != nil { + link = &txLink{tx.GetSqlTX()} + } + } + + if c.db.GetConfig().PrepareTimeout > 0 { + // DO NOT USE cancel function in prepare statement. 
+ var cancelFunc context.CancelFunc + ctx, cancelFunc = context.WithTimeout(ctx, c.db.GetConfig().PrepareTimeout) + defer cancelFunc() + } + + // Link execution. + var out DoCommitOutput + out, err = c.db.DoCommit(ctx, DoCommitInput{ + Link: link, + Sql: sql, + Type: SqlTypePrepareContext, + IsTransaction: link.IsTransaction(), + }) + if err != nil { + return nil, err + } + return out.Stmt, err +} + +// FormatUpsert formats and returns SQL clause part for upsert statement. +// In default implements, this function performs upsert statement for MySQL like: +// `INSERT INTO ... ON DUPLICATE KEY UPDATE x=VALUES(z),m=VALUES(y)...` +func (c *Core) FormatUpsert(columns []string, list List, option DoInsertOption) (string, error) { + var onDuplicateStr string + if option.OnDuplicateStr != "" { + onDuplicateStr = option.OnDuplicateStr + } else if len(option.OnDuplicateMap) > 0 { + for k, v := range option.OnDuplicateMap { + if len(onDuplicateStr) > 0 { + onDuplicateStr += "," + } + switch v.(type) { + case Raw, *Raw: + onDuplicateStr += fmt.Sprintf( + "%s=%s", + c.QuoteWord(k), + v, + ) + case Counter, *Counter: + var counter Counter + switch value := v.(type) { + case Counter: + counter = value + case *Counter: + counter = *value + } + operator, columnVal := c.getCounterAlter(counter) + onDuplicateStr += fmt.Sprintf( + "%s=%s%s%s", + c.QuoteWord(k), + c.QuoteWord(counter.Field), + operator, + gconv.String(columnVal), + ) + default: + onDuplicateStr += fmt.Sprintf( + "%s=VALUES(%s)", + c.QuoteWord(k), + c.QuoteWord(gconv.String(v)), + ) + } + } + } else { + for _, column := range columns { + // If it's `SAVE` operation, do not automatically update the creating time. 
+ if c.IsSoftCreatedFieldName(column) { + continue + } + if len(onDuplicateStr) > 0 { + onDuplicateStr += "," + } + onDuplicateStr += fmt.Sprintf( + "%s=VALUES(%s)", + c.QuoteWord(column), + c.QuoteWord(column), + ) + } + } + + return InsertOnDuplicateKeyUpdate + " " + onDuplicateStr, nil +} + +// RowsToResult converts underlying data record type sql.Rows to Result type. +func (c *Core) RowsToResult(ctx context.Context, rows *sql.Rows) (Result, error) { + if rows == nil { + return nil, nil + } + defer func() { + if err := rows.Close(); err != nil { + intlog.Errorf(ctx, `%+v`, err) + } + }() + if !rows.Next() { + return nil, nil + } + // Column names and types. + columnTypes, err := rows.ColumnTypes() + if err != nil { + return nil, err + } + + if len(columnTypes) > 0 { + if internalData := c.getInternalColumnFromCtx(ctx); internalData != nil { + internalData.FirstResultColumn = columnTypes[0].Name() + } + } + var ( + values = make([]any, len(columnTypes)) + result = make(Result, 0) + scanArgs = make([]any, len(values)) + ) + for i := range values { + scanArgs[i] = &values[i] + } + for { + if err = rows.Scan(scanArgs...); err != nil { + return result, err + } + record := Record{} + for i, value := range values { + if value == nil { + // DO NOT use `gvar.New(nil)` here as it creates an initialized object + // which will cause struct converting issue. + record[columnTypes[i].Name()] = nil + } else { + var ( + convertedValue any + columnType = columnTypes[i] + ) + if convertedValue, err = c.columnValueToLocalValue(ctx, value, columnType); err != nil { + return nil, err + } + record[columnTypes[i].Name()] = gvar.New(convertedValue) + } + } + result = append(result, record) + if !rows.Next() { + break + } + } + return result, nil +} + +// OrderRandomFunction returns the SQL function for random ordering. 
+func (c *Core) OrderRandomFunction() string { + return "RAND()" +} + +func (c *Core) columnValueToLocalValue(ctx context.Context, value any, columnType *sql.ColumnType) (any, error) { + var scanType = columnType.ScanType() + if scanType != nil { + // Common basic builtin types. + switch scanType.Kind() { + case + reflect.Bool, + reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64, + reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64, + reflect.Float32, reflect.Float64: + return gconv.Convert(gconv.String(value), scanType.String()), nil + default: + } + } + // Other complex types, especially custom types. + return c.db.ConvertValueForLocal(ctx, columnType.DatabaseTypeName(), value) +} diff --git a/database/gdb_core_utility.go b/database/gdb_core_utility.go new file mode 100644 index 0000000..c606954 --- /dev/null +++ b/database/gdb_core_utility.go @@ -0,0 +1,273 @@ +// Copyright GoFrame Author(https://goframe.org). All Rights Reserved. +// +// This Source Code Form is subject to the terms of the MIT License. +// If a copy of the MIT was not distributed with this file, +// You can obtain one at https://github.com/gogf/gf. +// + +package database + +import ( + "context" + "fmt" + "strings" + + "github.com/gogf/gf/v2/errors/gcode" + "github.com/gogf/gf/v2/errors/gerror" + "github.com/gogf/gf/v2/text/gregex" + "github.com/gogf/gf/v2/text/gstr" + "github.com/gogf/gf/v2/util/gutil" +) + +// GetDB returns the underlying DB. +func (c *Core) GetDB() DB { + return c.db +} + +// GetLink creates and returns the underlying database link object with transaction checks. +// The parameter `master` specifies whether using the master node if master-slave configured. 
+func (c *Core) GetLink(ctx context.Context, master bool, schema string) (Link, error) {
+	// A transaction bound in the context always takes precedence over
+	// master/slave selection: the statement must run on the tx connection.
+	if tx := TXFromCtx(ctx, c.db.GetGroup()); tx != nil {
+		return &txLink{tx.GetSqlTX()}, nil
+	}
+	// MasterLink/SlaveLink already return (nil, err) on failure,
+	// so their results can be returned directly without re-wrapping.
+	if master {
+		return c.db.GetCore().MasterLink(schema)
+	}
+	return c.db.GetCore().SlaveLink(schema)
+}
+
+// MasterLink acts like function Master but with additional `schema` parameter specifying
+// the schema for the connection. It is defined for internal usage.
+// Also see Master.
+func (c *Core) MasterLink(schema ...string) (Link, error) {
+	db, err := c.db.Master(schema...)
+	if err != nil {
+		return nil, err
+	}
+	return &dbLink{
+		DB:         db,
+		isOnMaster: true,
+	}, nil
+}
+
+// SlaveLink acts like function Slave but with additional `schema` parameter specifying
+// the schema for the connection. It is defined for internal usage.
+// Also see Slave.
+func (c *Core) SlaveLink(schema ...string) (Link, error) {
+	db, err := c.db.Slave(schema...)
+	if err != nil {
+		return nil, err
+	}
+	return &dbLink{
+		DB:         db,
+		isOnMaster: false,
+	}, nil
+}
+
+// QuoteWord checks given string `s` a word,
+// if true it quotes `s` with security chars of the database
+// and returns the quoted string; or else it returns `s` without any change.
+//
+// The meaning of a `word` can be considered as a column name.
+func (c *Core) QuoteWord(s string) string {
+	s = gstr.Trim(s)
+	if s == "" {
+		return s
+	}
+	charLeft, charRight := c.db.GetChars()
+	return doQuoteWord(s, charLeft, charRight)
+}
+
+// QuoteString quotes string with quote chars. Strings like:
+// "user", "user u", "user,user_detail", "user u, user_detail ut", "u.id asc".
+//
+// The meaning of a `string` can be considered as part of a statement string including columns.
+func (c *Core) QuoteString(s string) string { + if !gregex.IsMatchString(regularFieldNameWithCommaRegPattern, s) { + return s + } + charLeft, charRight := c.db.GetChars() + return doQuoteString(s, charLeft, charRight) +} + +// QuotePrefixTableName adds prefix string and quotes chars for the table. +// It handles table string like: +// "user", "user u", +// "user,user_detail", +// "user u, user_detail ut", +// "user as u, user_detail as ut". +// +// Note that, this will automatically checks the table prefix whether already added, +// if true it does nothing to the table name, or else adds the prefix to the table name. +func (c *Core) QuotePrefixTableName(table string) string { + charLeft, charRight := c.db.GetChars() + return doQuoteTableName(table, c.db.GetPrefix(), charLeft, charRight) +} + +// GetChars returns the security char for current database. +// It does nothing in default. +func (c *Core) GetChars() (charLeft string, charRight string) { + return "", "" +} + +// Tables retrieves and returns the tables of current schema. +// It's mainly used in cli tool chain for automatically generating the models. +func (c *Core) Tables(ctx context.Context, schema ...string) (tables []string, err error) { + return +} + +// TableFields retrieves and returns the fields' information of specified table of current +// schema. +// +// The parameter `link` is optional, if given nil it automatically retrieves a raw sql connection +// as its link to proceed necessary sql query. +// +// Note that it returns a map containing the field name and its corresponding fields. +// As a map is unsorted, the TableField struct has an "Index" field marks its sequence in +// the fields. +// +// It's using cache feature to enhance the performance, which is never expired util the +// process restarts. 
+func (c *Core) TableFields(ctx context.Context, table string, schema ...string) (fields map[string]*TableField, err error) { + return +} + +// ClearTableFields removes certain cached table fields of current configuration group. +func (c *Core) ClearTableFields(ctx context.Context, table string, schema ...string) (err error) { + tableFieldsCacheKey := genTableFieldsCacheKey( + c.db.GetGroup(), + gutil.GetOrDefaultStr(c.db.GetSchema(), schema...), + table, + ) + _, err = c.innerMemCache.Remove(ctx, tableFieldsCacheKey) + return +} + +// ClearTableFieldsAll removes all cached table fields of current configuration group. +func (c *Core) ClearTableFieldsAll(ctx context.Context) (err error) { + var ( + keys, _ = c.innerMemCache.KeyStrings(ctx) + cachePrefix = cachePrefixTableFields + removedKeys = make([]any, 0) + ) + for _, key := range keys { + if gstr.HasPrefix(key, cachePrefix) { + removedKeys = append(removedKeys, key) + } + } + + if len(removedKeys) > 0 { + err = c.innerMemCache.Removes(ctx, removedKeys) + } + return +} + +// ClearCache removes cached sql result of certain table. +func (c *Core) ClearCache(ctx context.Context, table string) (err error) { + var ( + keys, _ = c.db.GetCache().KeyStrings(ctx) + cachePrefix = fmt.Sprintf(`%s%s@`, cachePrefixSelectCache, table) + removedKeys = make([]any, 0) + ) + for _, key := range keys { + if gstr.HasPrefix(key, cachePrefix) { + removedKeys = append(removedKeys, key) + } + } + if len(removedKeys) > 0 { + err = c.db.GetCache().Removes(ctx, removedKeys) + } + return +} + +// ClearCacheAll removes all cached sql result from cache +func (c *Core) ClearCacheAll(ctx context.Context) (err error) { + if err = c.db.GetCache().Clear(ctx); err != nil { + return err + } + if err = c.GetInnerMemCache().Clear(ctx); err != nil { + return err + } + return +} + +// HasField determine whether the field exists in the table. 
+func (c *Core) HasField(ctx context.Context, table, field string, schema ...string) (bool, error) { + table = c.guessPrimaryTableName(table) + tableFields, err := c.db.TableFields(ctx, table, schema...) + if err != nil { + return false, err + } + if len(tableFields) == 0 { + return false, gerror.NewCodef( + gcode.CodeNotFound, + `empty table fields for table "%s"`, table, + ) + } + fieldsArray := make([]string, len(tableFields)) + for k, v := range tableFields { + fieldsArray[v.Index] = k + } + charLeft, charRight := c.db.GetChars() + field = gstr.Trim(field, charLeft+charRight) + for _, f := range fieldsArray { + if f == field { + return true, nil + } + } + return false, nil +} + +// guessPrimaryTableName parses and returns the primary table name. +func (c *Core) guessPrimaryTableName(tableStr string) string { + if tableStr == "" { + return "" + } + var ( + guessedTableName string + array1 = gstr.SplitAndTrim(tableStr, ",") + array2 = gstr.SplitAndTrim(array1[0], " ") + array3 = gstr.SplitAndTrim(array2[0], ".") + ) + if len(array3) >= 2 { + guessedTableName = array3[1] + } else { + guessedTableName = array3[0] + } + charL, charR := c.db.GetChars() + if charL != "" || charR != "" { + guessedTableName = gstr.Trim(guessedTableName, charL+charR) + } + if !gregex.IsMatchString(regularFieldNameRegPattern, guessedTableName) { + return "" + } + return guessedTableName +} + +// GetPrimaryKeys retrieves and returns the primary key field names of the specified table. +// This method extracts primary key information from TableFields. +// The parameter `schema` is optional, if not specified it uses the default schema. +func (c *Core) GetPrimaryKeys(ctx context.Context, table string, schema ...string) ([]string, error) { + tableFields, err := c.db.TableFields(ctx, table, schema...) 
+ if err != nil { + return nil, err + } + + var primaryKeys []string + for _, field := range tableFields { + if strings.EqualFold(field.Key, "pri") { + primaryKeys = append(primaryKeys, field.Name) + } + } + + return primaryKeys, nil +} diff --git a/database/gdb_dao_interface.go b/database/gdb_dao_interface.go new file mode 100644 index 0000000..2837efb --- /dev/null +++ b/database/gdb_dao_interface.go @@ -0,0 +1,7 @@ +package database + +type IDao interface { + DB() DB + TableName() string + Columns() any +} diff --git a/database/gdb_database.go b/database/gdb_database.go new file mode 100644 index 0000000..ee6e50c --- /dev/null +++ b/database/gdb_database.go @@ -0,0 +1,162 @@ +package database + +import ( + "context" + "fmt" + + "git.magicany.cc/black1552/gin-base/config" + "git.magicany.cc/black1552/gin-base/database/instance" + "git.magicany.cc/black1552/gin-base/database/intlog" + customLog "git.magicany.cc/black1552/gin-base/log" + "github.com/gogf/gf/v2/database/gdb" + "github.com/gogf/gf/v2/errors/gcode" + "github.com/gogf/gf/v2/errors/gerror" + "github.com/gogf/gf/v2/util/gconv" + "github.com/gogf/gf/v2/util/gutil" +) + +const ( + frameCoreComponentNameDatabase = "core.component.database" + ConfigNodeNameDatabase = "database" +) + +func Database(name ...string) gdb.DB { + var ( + ctx = context.Background() + group = gdb.DefaultGroupName + ) + + if len(name) > 0 && name[0] != "" { + group = name[0] + } + instanceKey := fmt.Sprintf("%s.%s", frameCoreComponentNameDatabase, group) + db := instance.GetOrSetFuncLock(instanceKey, func() interface{} { + // It ignores returned error to avoid file no found error while it's not necessary. + var ( + configMap map[string]interface{} + configNodeKey = ConfigNodeNameDatabase + ) + // It firstly searches the configuration of the instance name. 
+		if configData := config.GetAllConfig(); len(configData) > 0 {
+			if v, _ := gutil.MapPossibleItemByKey(configData, ConfigNodeNameDatabase); v != "" {
+				configNodeKey = v
+			}
+		}
+		if v := config.GetConfigValue(configNodeKey); !v.IsEmpty() {
+			configMap = v.Map()
+		}
+		// Check whether the database configuration exists at all.
+		if len(configMap) == 0 {
+			// Fetch all configuration from the config package and check
+			// whether it contains any database configuration.
+			allConfig := config.GetAllConfig()
+			if len(allConfig) == 0 {
+				panic(gerror.NewCodef(
+					gcode.CodeMissingConfiguration,
+					`database initialization failed: configuration file is empty or invalid`,
+				))
+			}
+			// Check whether a database node is configured.
+			if _, exists := allConfig["database"]; !exists {
+				// Try other possible key spellings for the database node.
+				found := false
+				for key := range allConfig {
+					if key == "DATABASE" || key == "Database" {
+						found = true
+						break
+					}
+				}
+				if !found {
+					panic(gerror.NewCodef(
+						gcode.CodeMissingConfiguration,
+						`database initialization failed: configuration missing for database node "%s"`,
+						ConfigNodeNameDatabase,
+					))
+				}
+			}
+		}
+
+		if len(configMap) == 0 {
+			configMap = make(map[string]interface{})
+		}
+		// Parse `configMap` as a map of group configurations and add each group
+		// to the global configurations for package gdb.
+		for g, groupConfig := range configMap {
+			cg := ConfigGroup{}
+			switch value := groupConfig.(type) {
+			case []interface{}:
+				for _, v := range value {
+					if node := parseDBConfigNode(v); node != nil {
+						cg = append(cg, *node)
+					}
+				}
+			case map[string]interface{}:
+				if node := parseDBConfigNode(value); node != nil {
+					cg = append(cg, *node)
+				}
+			}
+			if len(cg) > 0 {
+				if GetConfig(group) == nil {
+					intlog.Printf(ctx, "add configuration for group: %s, %#v", g, cg)
+					SetConfigGroup(g, cg)
+				} else {
+					intlog.Printf(ctx, "ignore configuration as it already exists for group: %s, %#v", g, cg)
+					intlog.Printf(ctx, "%s, %#v", g, cg)
+				}
+			}
+		}
+		// Parse `configMap` as a single node configuration,
+		// which is the default group configuration.
+ if node := parseDBConfigNode(configMap); node != nil { + cg := ConfigGroup{} + if node.Link != "" || node.Host != "" { + cg = append(cg, *node) + } + if len(cg) > 0 { + if GetConfig(group) == nil { + intlog.Printf(ctx, "add configuration for group: %s, %#v", DefaultGroupName, cg) + SetConfigGroup(DefaultGroupName, cg) + } else { + intlog.Printf( + ctx, + "ignore configuration as it already exists for group: %s, %#v", + DefaultGroupName, cg, + ) + intlog.Printf(ctx, "%s, %#v", DefaultGroupName, cg) + } + } + } + + // Create a new ORM object with given configurations. + if db, err := NewByGroup(name...); err == nil { + // 自动初始化自定义日志器并设置到数据库 + db.SetLogger(customLog.GetLogger()) + return db + } else { + // If panics, often because it does not find its configuration for given group. + panic(err) + } + return nil + }) + if db != nil { + return db.(gdb.DB) + } + return nil +} + +func parseDBConfigNode(value interface{}) *ConfigNode { + nodeMap, ok := value.(map[string]interface{}) + if !ok { + return nil + } + var ( + node = &ConfigNode{} + err = gconv.Struct(nodeMap, node) + ) + if err != nil { + panic(err) + } + // Find possible `Link` configuration content. + if _, v := gutil.MapPossibleItemByKey(nodeMap, "Link"); v != nil { + node.Link = gconv.String(v) + } + return node +} diff --git a/database/gdb_driver_default.go b/database/gdb_driver_default.go new file mode 100644 index 0000000..124077d --- /dev/null +++ b/database/gdb_driver_default.go @@ -0,0 +1,46 @@ +// Copyright GoFrame Author(https://goframe.org). All Rights Reserved. +// +// This Source Code Form is subject to the terms of the MIT License. +// If a copy of the MIT was not distributed with this file, +// You can obtain one at https://github.com/gogf/gf. + +package database + +import ( + "database/sql" +) + +// DriverDefault is the default driver for mysql database, which does nothing. 
+type DriverDefault struct {
+	*Core
+}
+
+func init() {
+	// Registers the default (no-op) driver under the "default" name.
+	if err := Register("default", &DriverDefault{}); err != nil {
+		panic(err)
+	}
+}
+
+// New creates and returns a database object for the default driver.
+// It implements the interface of gdb.Driver for extra database driver installation.
+func (d *DriverDefault) New(core *Core, node *ConfigNode) (DB, error) {
+	return &DriverDefault{
+		Core: core,
+	}, nil
+}
+
+// Open does nothing for the default driver and returns a nil *sql.DB;
+// concrete drivers override this to create a real connection.
+func (d *DriverDefault) Open(config *ConfigNode) (db *sql.DB, err error) {
+	return
+}
+
+// PingMaster does nothing for the default driver and always returns nil.
+func (d *DriverDefault) PingMaster() error {
+	return nil
+}
+
+// PingSlave does nothing for the default driver and always returns nil.
+func (d *DriverDefault) PingSlave() error {
+	return nil
+}
diff --git a/database/gdb_driver_wrapper.go b/database/gdb_driver_wrapper.go
new file mode 100644
index 0000000..b4a1b76
--- /dev/null
+++ b/database/gdb_driver_wrapper.go
@@ -0,0 +1,31 @@
+// Copyright GoFrame Author(https://goframe.org). All Rights Reserved.
+//
+// This Source Code Form is subject to the terms of the MIT License.
+// If a copy of the MIT was not distributed with this file,
+// You can obtain one at https://github.com/gogf/gf.
+
+package database
+
+// DriverWrapper is a driver wrapper for extending features with embedded driver.
+type DriverWrapper struct {
+	driver Driver
+}
+
+// New creates a DB object via the embedded driver and wraps it with DriverWrapperDB.
+// It implements the interface of gdb.Driver for extra database driver installation.
+func (d *DriverWrapper) New(core *Core, node *ConfigNode) (DB, error) {
+	db, err := d.driver.New(core, node)
+	if err != nil {
+		return nil, err
+	}
+	return &DriverWrapperDB{
+		DB: db,
+	}, nil
+}
+
+// newDriverWrapper creates and returns a driver wrapper.
+func newDriverWrapper(driver Driver) Driver { + return &DriverWrapper{ + driver: driver, + } +} diff --git a/database/gdb_driver_wrapper_db.go b/database/gdb_driver_wrapper_db.go new file mode 100644 index 0000000..4619d43 --- /dev/null +++ b/database/gdb_driver_wrapper_db.go @@ -0,0 +1,131 @@ +// Copyright GoFrame Author(https://goframe.org). All Rights Reserved. +// +// This Source Code Form is subject to the terms of the MIT License. +// If a copy of the MIT was not distributed with this file, +// You can obtain one at https://github.com/gogf/gf. + +package database + +import ( + "context" + "database/sql" + "fmt" + + "git.magicany.cc/black1552/gin-base/database/intlog" + "github.com/gogf/gf/v2/container/gvar" + "github.com/gogf/gf/v2/encoding/gjson" + "github.com/gogf/gf/v2/errors/gcode" + "github.com/gogf/gf/v2/errors/gerror" + "github.com/gogf/gf/v2/os/gcache" + "github.com/gogf/gf/v2/text/gstr" + "github.com/gogf/gf/v2/util/gutil" +) + +// DriverWrapperDB is a DB wrapper for extending features with embedded DB. +type DriverWrapperDB struct { + DB +} + +// Open creates and returns an underlying sql.DB object for pgsql. +// https://pkg.go.dev/github.com/lib/pq +func (d *DriverWrapperDB) Open(node *ConfigNode) (db *sql.DB, err error) { + var ctx = d.GetCtx() + intlog.PrintFunc(ctx, func() string { + return fmt.Sprintf(`open new connection:%s`, gjson.MustEncode(node)) + }) + return d.DB.Open(node) +} + +// Tables retrieves and returns the tables of current schema. +// It's mainly used in cli tool chain for automatically generating the models. +func (d *DriverWrapperDB) Tables(ctx context.Context, schema ...string) (tables []string, err error) { + ctx = context.WithValue(ctx, ctxKeyInternalProducedSQL, struct{}{}) + return d.DB.Tables(ctx, schema...) +} + +// TableFields retrieves and returns the fields' information of specified table of current +// schema. 
+//
+// The parameter `link` is optional, if given nil it automatically retrieves a raw sql connection
+// as its link to proceed necessary sql query.
+//
+// Note that it returns a map containing the field name and its corresponding fields.
+// As a map is unsorted, the TableField struct has an "Index" field that marks its sequence in
+// the fields.
+//
+// It's using cache feature to enhance the performance, which is never expired until the
+// process restarts.
+func (d *DriverWrapperDB) TableFields(
+	ctx context.Context, table string, schema ...string,
+) (fields map[string]*TableField, err error) {
+	if table == "" {
+		return nil, nil
+	}
+	charL, charR := d.GetChars()
+	table = gstr.Trim(table, charL+charR)
+	if gstr.Contains(table, " ") {
+		return nil, gerror.NewCode(
+			gcode.CodeInvalidParameter,
+			"function TableFields supports only single table operations",
+		)
+	}
+	var (
+		innerMemCache = d.GetCore().GetInnerMemCache()
+		// Cache key layout: prefix:group@schema#table
+		cacheKey = genTableFieldsCacheKey(
+			d.GetGroup(),
+			gutil.GetOrDefaultStr(d.GetSchema(), schema...),
+			table,
+		)
+		cacheFunc = func(ctx context.Context) (any, error) {
+			return d.DB.TableFields(
+				context.WithValue(ctx, ctxKeyInternalProducedSQL, struct{}{}),
+				table, schema...,
+			)
+		}
+		value *gvar.Var
+	)
+	value, err = innerMemCache.GetOrSetFuncLock(
+		ctx, cacheKey, cacheFunc, gcache.DurationNoExpire,
+	)
+	if err != nil {
+		return
+	}
+	if !value.IsNil() {
+		fields = value.Val().(map[string]*TableField)
+	}
+	return
+}
+
+// DoInsert inserts or updates data for given table.
+// This function is usually used for custom interface definition, you do not need to call it manually.
+// The parameter `data` can be type of map/gmap/struct/*struct/[]map/[]struct, etc.
+// Eg: +// Data(g.Map{"uid": 10000, "name":"john"}) +// Data(g.Slice{g.Map{"uid": 10000, "name":"john"}, g.Map{"uid": 20000, "name":"smith"}) +// +// The parameter `option` values are as follows: +// InsertOptionDefault: just insert, if there's unique/primary key in the data, it returns error; +// InsertOptionReplace: if there's unique/primary key in the data, it deletes it from table and inserts a new one; +// InsertOptionSave: if there's unique/primary key in the data, it updates it or else inserts a new one; +// InsertOptionIgnore: if there's unique/primary key in the data, it ignores the inserting; +func (d *DriverWrapperDB) DoInsert( + ctx context.Context, link Link, table string, list List, option DoInsertOption, +) (result sql.Result, err error) { + if len(list) == 0 { + return nil, gerror.NewCodef( + gcode.CodeInvalidRequest, + `data list is empty for %s operation`, + GetInsertOperationByOption(option.InsertOption), + ) + } + + // Convert data type before commit it to underlying db driver. + for i, item := range list { + list[i], err = d.GetCore().ConvertDataForRecord(ctx, item, table) + if err != nil { + return nil, err + } + } + return d.DB.DoInsert(ctx, link, table, list, option) +} diff --git a/database/gdb_func.go b/database/gdb_func.go new file mode 100644 index 0000000..f5ef6f7 --- /dev/null +++ b/database/gdb_func.go @@ -0,0 +1,1017 @@ +// Copyright GoFrame Author(https://goframe.org). All Rights Reserved. +// +// This Source Code Form is subject to the terms of the MIT License. +// If a copy of the MIT was not distributed with this file, +// You can obtain one at https://github.com/gogf/gf. 
+ +package database + +import ( + "bytes" + "context" + "fmt" + "reflect" + "regexp" + "strings" + "time" + + "git.magicany.cc/black1552/gin-base/database/empty" + "git.magicany.cc/black1552/gin-base/database/intlog" + "git.magicany.cc/black1552/gin-base/database/json" + "git.magicany.cc/black1552/gin-base/database/reflection" + "git.magicany.cc/black1552/gin-base/database/utils" + "github.com/gogf/gf/v2/container/garray" + "github.com/gogf/gf/v2/encoding/ghash" + "github.com/gogf/gf/v2/encoding/gjson" + "github.com/gogf/gf/v2/os/gstructs" + "github.com/gogf/gf/v2/os/gtime" + "github.com/gogf/gf/v2/text/gregex" + "github.com/gogf/gf/v2/text/gstr" + "github.com/gogf/gf/v2/util/gconv" + "github.com/gogf/gf/v2/util/gmeta" + "github.com/gogf/gf/v2/util/gtag" + "github.com/gogf/gf/v2/util/gutil" +) + +// iString is the type assert api for String. +type iString interface { + String() string +} + +// iIterator is the type assert api for Iterator. +type iIterator interface { + Iterator(f func(key, value any) bool) +} + +// iInterfaces is the type assert api for Interfaces. +type iInterfaces interface { + Interfaces() []any +} + +// iNil if the type assert api for IsNil. +type iNil interface { + IsNil() bool +} + +// iTableName is the interface for retrieving table name for struct. +type iTableName interface { + TableName() string +} + +const ( + OrmTagForStruct = "orm" + OrmTagForTable = "table" + OrmTagForWith = "with" + OrmTagForWithWhere = "where" + OrmTagForWithOrder = "order" + OrmTagForWithUnscoped = "unscoped" + OrmTagForDo = "do" +) + +var ( + // quoteWordReg is the regular expression object for a word check. + quoteWordReg = regexp.MustCompile(`^[a-zA-Z0-9\-_]+$`) + + // structTagPriority tags for struct converting for orm field mapping. + structTagPriority = append([]string{OrmTagForStruct}, gtag.StructTagPriority...) +) + +// WithDB injects given db object into context and returns a new context. 
+func WithDB(ctx context.Context, db DB) context.Context { + if db == nil { + return ctx + } + dbCtx := db.GetCtx() + if ctxDb := DBFromCtx(dbCtx); ctxDb != nil { + return dbCtx + } + ctx = context.WithValue(ctx, ctxKeyForDB, db) + return ctx +} + +// DBFromCtx retrieves and returns DB object from context. +func DBFromCtx(ctx context.Context) DB { + if ctx == nil { + return nil + } + v := ctx.Value(ctxKeyForDB) + if v != nil { + return v.(DB) + } + return nil +} + +// ToSQL formats and returns the last one of sql statements in given closure function +// WITHOUT TRULY EXECUTING IT. +// Be caution that, all the following sql statements should use the context object passing by function `f`. +func ToSQL(ctx context.Context, f func(ctx context.Context) error) (sql string, err error) { + var manager = &CatchSQLManager{ + SQLArray: garray.NewStrArray(), + DoCommit: false, + } + ctx = context.WithValue(ctx, ctxKeyCatchSQL, manager) + err = f(ctx) + sql, _ = manager.SQLArray.PopRight() + return +} + +// CatchSQL catches and returns all sql statements that are EXECUTED in given closure function. +// Be caution that, all the following sql statements should use the context object passing by function `f`. +func CatchSQL(ctx context.Context, f func(ctx context.Context) error) (sqlArray []string, err error) { + var manager = &CatchSQLManager{ + SQLArray: garray.NewStrArray(), + DoCommit: true, + } + ctx = context.WithValue(ctx, ctxKeyCatchSQL, manager) + err = f(ctx) + return manager.SQLArray.Slice(), err +} + +// isDoStruct checks and returns whether given type is a DO struct. +func isDoStruct(object any) bool { + // It checks by struct name like "XxxForDao", to be compatible with old version. + // TODO remove this compatible codes in future. + reflectType := reflect.TypeOf(object) + if gstr.HasSuffix(reflectType.String(), modelForDaoSuffix) { + return true + } + // It checks by struct meta for DO struct in version. 
+ if ormTag := gmeta.Get(object, OrmTagForStruct); !ormTag.IsEmpty() { + match, _ := gregex.MatchString( + fmt.Sprintf(`%s\s*:\s*([^,]+)`, OrmTagForDo), + ormTag.String(), + ) + if len(match) > 1 { + return gconv.Bool(match[1]) + } + } + return false +} + +// getTableNameFromOrmTag retrieves and returns the table name from struct object. +func getTableNameFromOrmTag(object any) string { + var tableName string + var actualObj = object + + if rv, ok := object.(reflect.Value); ok { + // Check if reflect.Value is valid + if rv.IsValid() && rv.CanInterface() { + actualObj = rv.Interface() + } else { + // If reflect.Value is invalid, we cannot proceed with interface checks + return "" + } + } + + // Check iTableName interface + if actualObj != nil { + if r, ok := actualObj.(iTableName); ok { + return r.TableName() + } + + // User meta data tag "orm". + if ormTag := gmeta.Get(actualObj, OrmTagForStruct); !ormTag.IsEmpty() { + match, _ := gregex.MatchString( + fmt.Sprintf(`%s\s*:\s*([^,]+)`, OrmTagForTable), + ormTag.String(), + ) + if len(match) > 1 { + tableName = match[1] + } + } + + // Use the struct name of snake case. + if tableName == "" { + if t, err := gstructs.StructType(actualObj); err != nil { + panic(err) + } else { + tableName = gstr.CaseSnakeFirstUpper( + gstr.StrEx(t.String(), "."), + ) + } + } + } + + return tableName +} + +// ListItemValues retrieves and returns the elements of all item struct/map with key `key`. +// Note that the parameter `list` should be type of slice which contains elements of map or struct, +// or else it returns an empty slice. +// +// The parameter `list` supports types like: +// []map[string]any +// []map[string]sub-map +// []struct +// []struct:sub-struct +// Note that the sub-map/sub-struct makes sense only if the optional parameter `subKey` is given. +// See gutil.ListItemValues. +func ListItemValues(list any, key any, subKey ...any) (values []any) { + return gutil.ListItemValues(list, key, subKey...) 
+} + +// ListItemValuesUnique retrieves and returns the unique elements of all struct/map with key `key`. +// Note that the parameter `list` should be type of slice which contains elements of map or struct, +// or else it returns an empty slice. +// See gutil.ListItemValuesUnique. +func ListItemValuesUnique(list any, key string, subKey ...any) []any { + return gutil.ListItemValuesUnique(list, key, subKey...) +} + +// GetInsertOperationByOption returns proper insert option with given parameter `option`. +func GetInsertOperationByOption(option InsertOption) string { + var operator string + switch option { + case InsertOptionReplace: + operator = InsertOperationReplace + case InsertOptionIgnore: + operator = InsertOperationIgnore + default: + operator = InsertOperationInsert + } + return operator +} + +func anyValueToMapBeforeToRecord(value any) map[string]any { + convertedMap := gconv.Map(value, gconv.MapOption{ + Tags: structTagPriority, + OmitEmpty: true, // To be compatible with old version from v2.6.0. + }) + if gutil.OriginValueAndKind(value).OriginKind != reflect.Struct { + return convertedMap + } + // It here converts all struct/map slice attributes to json string. + for k, v := range convertedMap { + originValueAndKind := gutil.OriginValueAndKind(v) + switch originValueAndKind.OriginKind { + // Check map item slice item. + case reflect.Array, reflect.Slice: + mapItemValue := originValueAndKind.OriginValue + if mapItemValue.Len() == 0 { + break + } + // Check slice item type struct/map type. + switch mapItemValue.Index(0).Kind() { + case reflect.Struct, reflect.Map: + mapItemJsonBytes, err := json.Marshal(v) + if err != nil { + // Do not eat any error. + intlog.Error(context.TODO(), err) + } + convertedMap[k] = mapItemJsonBytes + } + } + } + return convertedMap +} + +// MapOrStructToMapDeep converts `value` to map type recursively(if attribute struct is embedded). +// The parameter `value` should be type of *map/map/*struct/struct. 
+// It supports embedded struct definition for struct. +func MapOrStructToMapDeep(value any, omitempty bool) map[string]any { + m := gconv.Map(value, gconv.MapOption{ + Tags: structTagPriority, + OmitEmpty: omitempty, + }) + for k, v := range m { + switch v.(type) { + case time.Time, *time.Time, gtime.Time, *gtime.Time, gjson.Json, *gjson.Json: + m[k] = v + } + } + return m +} + +// doQuoteTableName adds prefix string and quote chars for table name. It handles table string like: +// "user", "user u", "user,user_detail", "user u, user_detail ut", "user as u, user_detail as ut", +// "user.user u", "`user`.`user` u". +// +// Note that, this will automatically check the table prefix whether already added, if true it does +// nothing to the table name, or else adds the prefix to the table name and returns new table name with prefix. +func doQuoteTableName(table, prefix, charLeft, charRight string) string { + var ( + index int + chars = charLeft + charRight + array1 = gstr.SplitAndTrim(table, ",") + ) + for k1, v1 := range array1 { + array2 := gstr.SplitAndTrim(v1, " ") + // Trim the security chars. + array2[0] = gstr.Trim(array2[0], chars) + // Check whether it has database name. + array3 := gstr.Split(gstr.Trim(array2[0]), ".") + for k, v := range array3 { + array3[k] = gstr.Trim(v, chars) + } + index = len(array3) - 1 + // If the table name already has the prefix, skips the prefix adding. + if len(array3[index]) <= len(prefix) || array3[index][:len(prefix)] != prefix { + array3[index] = prefix + array3[index] + } + array2[0] = gstr.Join(array3, ".") + // Add the security chars. + array2[0] = doQuoteString(array2[0], charLeft, charRight) + array1[k1] = gstr.Join(array2, " ") + } + return gstr.Join(array1, ",") +} + +// doQuoteWord checks given string `s` a word, if true quotes it with `charLeft` and `charRight` +// and returns the quoted string; or else returns `s` without any change. 
+func doQuoteWord(s, charLeft, charRight string) string { + if quoteWordReg.MatchString(s) && !gstr.ContainsAny(s, charLeft+charRight) { + return charLeft + s + charRight + } + return s +} + +// doQuoteString quotes string with quote chars. +// For example, if quote char is '`': +// "null" => "NULL" +// "user" => "`user`" +// "user u" => "`user` u" +// "user,user_detail" => "`user`,`user_detail`" +// "user u, user_detail ut" => "`user` u,`user_detail` ut" +// "user.user u, user.user_detail ut" => "`user`.`user` u,`user`.`user_detail` ut" +// "u.id, u.name, u.age" => "`u`.`id`,`u`.`name`,`u`.`age`" +// "u.id asc" => "`u`.`id` asc". +func doQuoteString(s, charLeft, charRight string) string { + array1 := gstr.SplitAndTrim(s, ",") + for k1, v1 := range array1 { + array2 := gstr.SplitAndTrim(v1, " ") + array3 := gstr.Split(gstr.Trim(array2[0]), ".") + if len(array3) == 1 { + if strings.EqualFold(array3[0], "NULL") { + array3[0] = doQuoteWord(array3[0], "", "") + } else { + array3[0] = doQuoteWord(array3[0], charLeft, charRight) + } + } else if len(array3) >= 2 { + array3[0] = doQuoteWord(array3[0], charLeft, charRight) + // Note: + // mysql: u.uid + // mssql double dots: Database..Table + array3[len(array3)-1] = doQuoteWord(array3[len(array3)-1], charLeft, charRight) + } + array2[0] = gstr.Join(array3, ".") + array1[k1] = gstr.Join(array2, " ") + } + return gstr.Join(array1, ",") +} + +func getFieldsFromStructOrMap(structOrMap any) (fields []any) { + fields = []any{} + if utils.IsStruct(structOrMap) { + structFields, _ := gstructs.Fields(gstructs.FieldsInput{ + Pointer: structOrMap, + RecursiveOption: gstructs.RecursiveOptionEmbeddedNoTag, + }) + var ormTagValue string + for _, structField := range structFields { + ormTagValue = structField.Tag(OrmTagForStruct) + ormTagValue = gstr.Split(gstr.Trim(ormTagValue), ",")[0] + if ormTagValue != "" && gregex.IsMatchString(regularFieldNameRegPattern, ormTagValue) { + fields = append(fields, ormTagValue) + } else { + fields = 
append(fields, structField.Name()) + } + } + } else { + fields = gconv.Interfaces(gutil.Keys(structOrMap)) + } + return +} + +// GetPrimaryKeyCondition returns a new where condition by primary field name. +// The optional parameter `where` is like follows: +// 123 => primary=123 +// []int{1, 2, 3} => primary IN(1,2,3) +// "john" => primary='john' +// []string{"john", "smith"} => primary IN('john','smith') +// g.Map{"id": g.Slice{1,2,3}} => id IN(1,2,3) +// g.Map{"id": 1, "name": "john"} => id=1 AND name='john' +// etc. +// +// Note that it returns the given `where` parameter directly if the `primary` is empty +// or length of `where` > 1. +func GetPrimaryKeyCondition(primary string, where ...any) (newWhereCondition []any) { + if len(where) == 0 { + return nil + } + if primary == "" { + return where + } + if len(where) == 1 { + var ( + rv = reflect.ValueOf(where[0]) + kind = rv.Kind() + ) + if kind == reflect.Pointer { + rv = rv.Elem() + kind = rv.Kind() + } + switch kind { + case reflect.Map, reflect.Struct: + // Ignore the parameter `primary`. + break + + default: + return []any{map[string]any{ + primary: where[0], + }} + } + } + return where +} + +type formatWhereHolderInput struct { + WhereHolder + OmitNil bool + OmitEmpty bool + Schema string + Table string // Table is used for fields mapping and filtering internally. 
+} + +func isKeyValueCanBeOmitEmpty(omitEmpty bool, whereType string, key, value any) bool { + if !omitEmpty { + return false + } + // Eg: + // Where("id", []int{}).All() -> SELECT xxx FROM xxx WHERE 0=1 + // Where("name", "").All() -> SELECT xxx FROM xxx WHERE `name`='' + // OmitEmpty().Where("id", []int{}).All() -> SELECT xxx FROM xxx + // OmitEmpty().Where("name", "").All() -> SELECT xxx FROM xxx + // OmitEmpty().Where("1").All() -> SELECT xxx FROM xxx WHERE 1 + switch whereType { + case whereHolderTypeNoArgs: + return false + + case whereHolderTypeIn: + return gutil.IsEmpty(value) + + default: + if gstr.Count(gconv.String(key), "?") == 0 && gutil.IsEmpty(value) { + return true + } + } + return false +} + +// formatWhereHolder formats where statement and its arguments for `Where` and `Having` statements. +func formatWhereHolder(ctx context.Context, db DB, in formatWhereHolderInput) (newWhere string, newArgs []any) { + var ( + buffer = bytes.NewBuffer(nil) + reflectInfo = reflection.OriginValueAndKind(in.Where) + ) + switch reflectInfo.OriginKind { + case reflect.Array, reflect.Slice: + newArgs = formatWhereInterfaces(db, gconv.Interfaces(in.Where), buffer, newArgs) + + case reflect.Map: + for key, value := range MapOrStructToMapDeep(in.Where, true) { + if in.OmitNil && empty.IsNil(value) { + continue + } + if in.OmitEmpty && empty.IsEmpty(value) { + continue + } + newArgs = formatWhereKeyValue(formatWhereKeyValueInput{ + Db: db, + Buffer: buffer, + Args: newArgs, + Key: key, + Value: value, + Prefix: in.Prefix, + Type: in.Type, + }) + } + + case reflect.Struct: + // If the `where` parameter is `DO` struct, it then adds `OmitNil` option for this condition, + // which will filter all nil parameters in `where`. + if isDoStruct(in.Where) { + in.OmitNil = true + } + // If `where` struct implements `iIterator` interface, + // it then uses its Iterate function to iterate its key-value pairs. 
+ // For example, ListMap and TreeMap are ordered map, + // which implement `iIterator` interface and are index-friendly for where conditions. + if iterator, ok := in.Where.(iIterator); ok { + iterator.Iterator(func(key, value any) bool { + ketStr := gconv.String(key) + if in.OmitNil && empty.IsNil(value) { + return true + } + if in.OmitEmpty && empty.IsEmpty(value) { + return true + } + newArgs = formatWhereKeyValue(formatWhereKeyValueInput{ + Db: db, + Buffer: buffer, + Args: newArgs, + Key: ketStr, + Value: value, + OmitEmpty: in.OmitEmpty, + Prefix: in.Prefix, + Type: in.Type, + }) + return true + }) + break + } + // Automatically mapping and filtering the struct attribute. + var ( + reflectType = reflectInfo.OriginValue.Type() + structField reflect.StructField + data = MapOrStructToMapDeep(in.Where, true) + ) + // If `Prefix` is given, it checks and retrieves the table name. + if in.Prefix != "" { + hasTable, _ := db.GetCore().HasTable(in.Prefix) + if hasTable { + in.Table = in.Prefix + } else { + ormTagTableName := getTableNameFromOrmTag(in.Where) + if ormTagTableName != "" { + in.Table = ormTagTableName + } + } + } + // Mapping and filtering fields if `Table` is given. + if in.Table != "" { + data, _ = db.GetCore().mappingAndFilterData(ctx, in.Schema, in.Table, data, true) + } + // Put the struct attributes in sequence in Where statement. + var ormTagValue string + for i := 0; i < reflectType.NumField(); i++ { + structField = reflectType.Field(i) + // Use tag value from `orm` as field name if specified. 
+ ormTagValue = structField.Tag.Get(OrmTagForStruct) + ormTagValue = gstr.Split(gstr.Trim(ormTagValue), ",")[0] + if ormTagValue == "" { + ormTagValue = structField.Name + } + foundKey, foundValue := gutil.MapPossibleItemByKey(data, ormTagValue) + if foundKey != "" { + if in.OmitNil && empty.IsNil(foundValue) { + continue + } + if in.OmitEmpty && empty.IsEmpty(foundValue) { + continue + } + newArgs = formatWhereKeyValue(formatWhereKeyValueInput{ + Db: db, + Buffer: buffer, + Args: newArgs, + Key: foundKey, + Value: foundValue, + OmitEmpty: in.OmitEmpty, + Prefix: in.Prefix, + Type: in.Type, + }) + } + } + + default: + // Where filter. + var omitEmptyCheckValue any + if len(in.Args) == 1 { + omitEmptyCheckValue = in.Args[0] + } else { + omitEmptyCheckValue = in.Args + } + if isKeyValueCanBeOmitEmpty(in.OmitEmpty, in.Type, in.Where, omitEmptyCheckValue) { + return + } + // Usually a string. + whereStr := gstr.Trim(gconv.String(in.Where)) + // Is `whereStr` a field name which composed as a key-value condition? + // Eg: + // Where("id", 1) + // Where("id", g.Slice{1,2,3}) + if gregex.IsMatchString(regularFieldNameWithoutDotRegPattern, whereStr) && len(in.Args) == 1 { + newArgs = formatWhereKeyValue(formatWhereKeyValueInput{ + Db: db, + Buffer: buffer, + Args: newArgs, + Key: whereStr, + Value: in.Args[0], + OmitEmpty: in.OmitEmpty, + Prefix: in.Prefix, + Type: in.Type, + }) + in.Args = in.Args[:0] + break + } + // If the first part is column name, it automatically adds prefix to the column. + if in.Prefix != "" { + array := gstr.Split(whereStr, " ") + if ok, _ := db.GetCore().HasField(ctx, in.Table, array[0]); ok { + whereStr = in.Prefix + "." + whereStr + } + } + // Regular string and parameter place holder handling. + // Eg: + // Where("id in(?) and name=?", g.Slice{1,2,3}, "john") + for i := 0; i < len(in.Args); i++ { + // =============================================================== + // Sub query, which is always used along with a string condition. 
+ // =============================================================== + if subModel, ok := in.Args[i].(*Model); ok { + index := -1 + whereStr = gstr.ReplaceFunc(whereStr, `?`, func(s string) string { + index++ + if i+len(newArgs) == index { + sqlWithHolder, holderArgs := subModel.getHolderAndArgsAsSubModel(ctx) + in.Args = gutil.SliceInsertAfter(in.Args, i, holderArgs...) + // Automatically adding the brackets. + return "(" + sqlWithHolder + ")" + } + return s + }) + in.Args = gutil.SliceDelete(in.Args, i) + continue + } + } + buffer.WriteString(whereStr) + } + + if buffer.Len() == 0 { + return "", in.Args + } + if len(in.Args) > 0 { + newArgs = append(newArgs, in.Args...) + } + newWhere = buffer.String() + if len(newArgs) > 0 { + if gstr.Pos(newWhere, "?") == -1 { + if gregex.IsMatchString(lastOperatorRegPattern, newWhere) { + // Eg: Where/And/Or("uid>=", 1) + newWhere += "?" + } else if gregex.IsMatchString(regularFieldNameRegPattern, newWhere) { + newWhere = db.GetCore().QuoteString(newWhere) + if len(newArgs) > 0 { + if utils.IsArray(newArgs[0]) { + // Eg: + // Where("id", []int{1,2,3}) + // Where("user.id", []int{1,2,3}) + newWhere += " IN (?)" + } else if empty.IsNil(newArgs[0]) { + // Eg: + // Where("id", nil) + // Where("user.id", nil) + newWhere += " IS NULL" + newArgs = nil + } else { + // Eg: + // Where/And/Or("uid", 1) + // Where/And/Or("user.uid", 1) + newWhere += "=?" + } + } + } + } + } + return handleSliceAndStructArgsForSql(newWhere, newArgs) +} + +// formatWhereInterfaces formats `where` as []any. 
+func formatWhereInterfaces(db DB, where []any, buffer *bytes.Buffer, newArgs []any) []any { + if len(where) == 0 { + return newArgs + } + if len(where)%2 != 0 { + buffer.WriteString(gstr.Join(gconv.Strings(where), "")) + return newArgs + } + var str string + for i := 0; i < len(where); i += 2 { + str = gconv.String(where[i]) + if buffer.Len() > 0 { + buffer.WriteString(" AND " + db.GetCore().QuoteWord(str) + "=?") + } else { + buffer.WriteString(db.GetCore().QuoteWord(str) + "=?") + } + if s, ok := where[i+1].(Raw); ok { + buffer.WriteString(gconv.String(s)) + } else { + newArgs = append(newArgs, where[i+1]) + } + } + return newArgs +} + +type formatWhereKeyValueInput struct { + Db DB // Db is the underlying DB object for current operation. + Buffer *bytes.Buffer // Buffer is the sql statement string without Args for current operation. + Args []any // Args is the full arguments of current operation. + Key string // The field name, eg: "id", "name", etc. + Value any // The field value, can be any types. + Type string // The value in Where type. + OmitEmpty bool // Ignores current condition key if `value` is empty. + Prefix string // Field prefix, eg: "user", "order", etc. +} + +// formatWhereKeyValue handles each key-value pair of the parameter map. +func formatWhereKeyValue(in formatWhereKeyValueInput) (newArgs []any) { + var ( + quotedKey = in.Db.GetCore().QuoteWord(in.Key) + holderCount = gstr.Count(quotedKey, "?") + ) + if isKeyValueCanBeOmitEmpty(in.OmitEmpty, in.Type, quotedKey, in.Value) { + return in.Args + } + if in.Prefix != "" && !gstr.Contains(quotedKey, ".") { + quotedKey = in.Prefix + "." + quotedKey + } + if in.Buffer.Len() > 0 { + in.Buffer.WriteString(" AND ") + } + // If the value is type of slice, and there's only one '?' holder in + // the key string, it automatically adds '?' holder chars according to its arguments count + // and converts it to "IN" statement. 
+ var ( + reflectValue = reflect.ValueOf(in.Value) + reflectKind = reflectValue.Kind() + ) + // Check if the value implements iString interface (like uuid.UUID). + // These types should be treated as single values, not arrays. + if reflectKind == reflect.Array { + if v, ok := in.Value.(iString); ok { + in.Value = v.String() + reflectKind = reflect.String + } + } + switch reflectKind { + // Slice argument. + case reflect.Slice, reflect.Array: + if holderCount == 0 { + in.Buffer.WriteString(quotedKey + " IN(?)") + in.Args = append(in.Args, in.Value) + } else { + if holderCount != reflectValue.Len() { + in.Buffer.WriteString(quotedKey) + in.Args = append(in.Args, in.Value) + } else { + in.Buffer.WriteString(quotedKey) + in.Args = append(in.Args, gconv.Interfaces(in.Value)...) + } + } + + default: + if in.Value == nil || empty.IsNil(reflectValue) { + if gregex.IsMatchString(regularFieldNameRegPattern, in.Key) { + // The key is a single field name. + in.Buffer.WriteString(quotedKey + " IS NULL") + } else { + // The key may have operation chars. + in.Buffer.WriteString(quotedKey) + } + } else { + // It also supports "LIKE" statement, which we consider it an operator. + quotedKey = gstr.Trim(quotedKey) + if gstr.Pos(quotedKey, "?") == -1 { + like := " LIKE" + if len(quotedKey) > len(like) && gstr.Equal(quotedKey[len(quotedKey)-len(like):], like) { + // Eg: Where(g.Map{"name like": "john%"}) + in.Buffer.WriteString(quotedKey + " ?") + } else if gregex.IsMatchString(lastOperatorRegPattern, quotedKey) { + // Eg: Where(g.Map{"age > ": 16}) + in.Buffer.WriteString(quotedKey + " ?") + } else if gregex.IsMatchString(regularFieldNameRegPattern, in.Key) { + // The key is a regular field name. + in.Buffer.WriteString(quotedKey + "=?") + } else { + // The key is not a regular field name. 
+ // Eg: Where(g.Map{"age > 16": nil}) + // Issue: https://github.com/gogf/gf/issues/765 + if empty.IsEmpty(in.Value) { + in.Buffer.WriteString(quotedKey) + break + } else { + in.Buffer.WriteString(quotedKey + "=?") + } + } + } else { + in.Buffer.WriteString(quotedKey) + } + in.Args = append(in.Args, in.Value) + } + } + return in.Args +} + +// handleSliceAndStructArgsForSql is an important function, which handles the sql and all its arguments +// before committing them to underlying driver. +func handleSliceAndStructArgsForSql(oldSql string, oldArgs []any) (newSql string, newArgs []any) { + newSql = oldSql + if len(oldArgs) == 0 { + return + } + // insertHolderCount is used to calculate the inserting position for the '?' holder. + insertHolderCount := 0 + // Handles the slice and struct type argument item. + for index, oldArg := range oldArgs { + argReflectInfo := reflection.OriginValueAndKind(oldArg) + switch argReflectInfo.OriginKind { + case reflect.Slice, reflect.Array: + // It does not split the type of []byte. + // Eg: table.Where("name = ?", []byte("john")) + if _, ok := oldArg.([]byte); ok { + newArgs = append(newArgs, oldArg) + continue + } + // It does not split types that implement fmt.Stringer interface (like uuid.UUID). + // These types should be converted to string instead of being expanded as arrays. + // Eg: table.Where("uuid = ?", uuid.UUID{...}) + if v, ok := oldArg.(iString); ok { + newArgs = append(newArgs, v.String()) + continue + } + var ( + valueHolderCount = gstr.Count(newSql, "?") + argSliceLength = argReflectInfo.OriginValue.Len() + ) + if argSliceLength == 0 { + // Empty slice argument, it converts the sql to a false sql. 
+				// Example:
+				// Query("select * from xxx where id in(?)", g.Slice{}) -> select * from xxx where 0=1
+				// Where("id in(?)", g.Slice{}) -> WHERE 0=1
+				if gstr.Contains(newSql, "?") {
+					whereKeyWord := " WHERE "
+					if p := gstr.PosI(newSql, whereKeyWord); p == -1 {
+						return "0=1", []any{}
+					} else {
+						return gstr.SubStr(newSql, 0, p+len(whereKeyWord)) + "0=1", []any{}
+					}
+				}
+			} else {
+				// Example:
+				// Query("SELECT ?+?", g.Slice{1,2})
+				// WHERE("id=?", g.Slice{1,2})
+				for i := 0; i < argSliceLength; i++ {
+					newArgs = append(newArgs, argReflectInfo.OriginValue.Index(i).Interface())
+				}
+			}
+
+			// If the '?' holder count equals the length of the slice,
+			// it does not implement the arguments splitting logic.
+			// Eg: db.Query("SELECT ?+?", g.Slice{1, 2})
+			if len(oldArgs) == 1 && valueHolderCount == argSliceLength {
+				break
+			}
+
+			// counter is used to find the inserting position for the '?' holder.
+			var (
+				counter  = 0
+				replaced = false
+			)
+			newSql = gstr.ReplaceFunc(newSql, `?`, func(s string) string {
+				if replaced {
+					return s
+				}
+				counter++
+				if counter == index+insertHolderCount+1 {
+					replaced = true
+					insertHolderCount += argSliceLength - 1
+					return "?" + strings.Repeat(",?", argSliceLength-1)
+				}
+				return s
+			})
+
+		// Special struct handling.
+		case reflect.Struct:
+			switch v := oldArg.(type) {
+			// The underlying driver supports time.Time/*time.Time types.
+			case time.Time, *time.Time:
+				newArgs = append(newArgs, oldArg)
+				continue
+
+			case gtime.Time:
+				newArgs = append(newArgs, v.Time)
+				continue
+
+			case *gtime.Time:
+				newArgs = append(newArgs, v.Time)
+				continue
+
+			default:
+				// It converts the struct to string in default
+				// if it has implemented the String interface.
+				if v, ok := oldArg.(iString); ok {
+					newArgs = append(newArgs, v.String())
+					continue
+				}
+			}
+			newArgs = append(newArgs, oldArg)
+
+		default:
+			switch oldArg.(type) {
+			// Do not append Raw arg to args but directly into the sql.
+ case Raw, *Raw: + var counter = 0 + newSql = gstr.ReplaceFunc(newSql, `?`, func(s string) string { + counter++ + if counter == index+insertHolderCount+1 { + return gconv.String(oldArg) + } + return s + }) + continue + + default: + } + newArgs = append(newArgs, oldArg) + } + } + return +} + +// FormatSqlWithArgs binds the arguments to the sql string and returns a complete +// sql string, just for debugging. +func FormatSqlWithArgs(sql string, args []any) string { + index := -1 + newQuery, _ := gregex.ReplaceStringFunc( + `(\?|:v\d+|\$\d+|@p\d+)`, + sql, + func(s string) string { + index++ + if len(args) > index { + if args[index] == nil { + return "null" + } + // Parameters of type Raw do not require special treatment + if v, ok := args[index].(Raw); ok { + return gconv.String(v) + } + reflectInfo := reflection.OriginValueAndKind(args[index]) + if reflectInfo.OriginKind == reflect.Pointer && + (reflectInfo.OriginValue.IsNil() || !reflectInfo.OriginValue.IsValid()) { + return "null" + } + switch reflectInfo.OriginKind { + case reflect.String, reflect.Map, reflect.Slice, reflect.Array: + return `'` + gstr.QuoteMeta(gconv.String(args[index]), `'`) + `'` + + case reflect.Struct: + if t, ok := args[index].(time.Time); ok { + return `'` + t.Format(`2006-01-02 15:04:05`) + `'` + } + return `'` + gstr.QuoteMeta(gconv.String(args[index]), `'`) + `'` + } + return gconv.String(args[index]) + } + return s + }) + return newQuery +} + +// FormatMultiLineSqlToSingle formats sql template string into one line. +func FormatMultiLineSqlToSingle(sql string) (string, error) { + var err error + // format sql template string. + sql, err = gregex.ReplaceString(`[\n\r\s]+`, " ", gstr.Trim(sql)) + if err != nil { + return "", err + } + sql, err = gregex.ReplaceString(`\s{2,}`, " ", gstr.Trim(sql)) + if err != nil { + return "", err + } + return sql, nil +} + +// genTableFieldsCacheKey generates cache key for table fields. 
+func genTableFieldsCacheKey(group, schema, table string) string { + return fmt.Sprintf( + `%s%s@%s#%s`, + cachePrefixTableFields, + group, + schema, + table, + ) +} + +// genSelectCacheKey generates cache key for select. +func genSelectCacheKey(table, group, schema, name, sql string, args ...any) string { + if name == "" { + name = fmt.Sprintf( + `%s@%s#%s:%d`, + table, + group, + schema, + ghash.BKDR64([]byte(sql+", @PARAMS:"+gconv.String(args))), + ) + } + return fmt.Sprintf(`%s%s`, cachePrefixSelectCache, name) +} + +// genTableNamesCacheKey generates cache key for table names. +func genTableNamesCacheKey(group string) string { + return fmt.Sprintf(`Tables:%s`, group) +} + +// genSoftTimeFieldNameTypeCacheKey generates cache key for soft time field name and type. +func genSoftTimeFieldNameTypeCacheKey(schema, table string, candidateFields []string) string { + return fmt.Sprintf(`getSoftFieldNameAndType:%s#%s#%s`, schema, table, strings.Join(candidateFields, "_")) +} diff --git a/database/gdb_model.go b/database/gdb_model.go new file mode 100644 index 0000000..c0760ae --- /dev/null +++ b/database/gdb_model.go @@ -0,0 +1,350 @@ +// Copyright GoFrame Author(https://goframe.org). All Rights Reserved. +// +// This Source Code Form is subject to the terms of the MIT License. +// If a copy of the MIT was not distributed with this file, +// You can obtain one at https://github.com/gogf/gf. + +package database + +import ( + "context" + "fmt" + + "github.com/gogf/gf/v2/text/gregex" + "github.com/gogf/gf/v2/text/gstr" + "github.com/gogf/gf/v2/util/gconv" +) + +// Model is core struct implementing the DAO for ORM. +type Model struct { + db DB // Underlying DB interface. + tx TX // Underlying TX interface. + rawSql string // rawSql is the raw SQL string which marks a raw SQL based Model not a table based Model. + schema string // Custom database schema. + linkType int // Mark for operation on master or slave. + tablesInit string // Table names when model initialization. 
+ tables string // Operation table names, which can be more than one table names and aliases, like: "user", "user u", "user u, user_detail ud". + fields []any // Operation fields, multiple fields joined using char ','. + fieldsEx []any // Excluded operation fields, it here uses slice instead of string type for quick filtering. + withArray []any // Arguments for With feature. + withAll bool // Enable model association operations on all objects that have "with" tag in the struct. + extraArgs []any // Extra custom arguments for sql, which are prepended to the arguments before sql committed to underlying driver. + whereBuilder *WhereBuilder // Condition builder for where operation. + groupBy string // Used for "group by" statement. + orderBy string // Used for "order by" statement. + having []any // Used for "having..." statement. + start int // Used for "select ... start, limit ..." statement. + limit int // Used for "select ... start, limit ..." statement. + option int // Option for extra operation features. + offset int // Offset statement for some databases grammar. + partition string // Partition table partition name. + data any // Data for operation, which can be type of map/[]map/struct/*struct/string, etc. + batch int // Batch number for batch Insert/Replace/Save operations. + filter bool // Filter data and where key-value pairs according to the fields of the table. + distinct string // Force the query to only return distinct results. + lockInfo string // Lock for update or in shared lock. + cacheEnabled bool // Enable sql result cache feature, which is mainly for indicating cache duration(especially 0) usage. + cacheOption CacheOption // Cache option for query statement. + pageCacheOption []CacheOption // Cache option for paging query statement. + hookHandler HookHandler // Hook functions for model hook feature. + unscoped bool // Disables soft deleting features when select/delete operations. 
+ safe bool // If true, it clones and returns a new model object whenever operation done; or else it changes the attribute of current model. + onDuplicate any // onDuplicate is used for on Upsert clause. + onDuplicateEx any // onDuplicateEx is used for excluding some columns on Upsert clause. + onConflict any // onConflict is used for conflict keys on Upsert clause. + tableAliasMap map[string]string // Table alias to true table name, usually used in join statements. + softTimeOption SoftTimeOption // SoftTimeOption is the option to customize soft time feature for Model. + shardingConfig ShardingConfig // ShardingConfig for database/table sharding feature. + shardingValue any // Sharding value for sharding feature. +} + +// ModelHandler is a function that handles given Model and returns a new Model that is custom modified. +type ModelHandler func(m *Model) *Model + +// ChunkHandler is a function that is used in function Chunk, which handles given Result and error. +// It returns true if it wants to continue chunking, or else it returns false to stop chunking. +type ChunkHandler func(result Result, err error) bool + +const ( + linkTypeMaster = 1 + linkTypeSlave = 2 + defaultField = "*" + whereHolderOperatorWhere = 1 + whereHolderOperatorAnd = 2 + whereHolderOperatorOr = 3 + whereHolderTypeDefault = "Default" + whereHolderTypeNoArgs = "NoArgs" + whereHolderTypeIn = "In" +) + +// Model creates and returns a new ORM model from given schema. +// The parameter `tableNameQueryOrStruct` can be more than one table names, and also alias name, like: +// 1. Model names: +// db.Model("user") +// db.Model("user u") +// db.Model("user, user_detail") +// db.Model("user u, user_detail ud") +// 2. Model name with alias: +// db.Model("user", "u") +// 3. Model name with sub-query: +// db.Model("? AS a, ? 
AS b", subQuery1, subQuery2) +func (c *Core) Model(tableNameQueryOrStruct ...any) *Model { + var ( + ctx = c.db.GetCtx() + tableStr string + tableName string + extraArgs []any + ) + // Model creation with sub-query. + if len(tableNameQueryOrStruct) > 1 { + conditionStr := gconv.String(tableNameQueryOrStruct[0]) + if gstr.Contains(conditionStr, "?") { + whereHolder := WhereHolder{ + Where: conditionStr, + Args: tableNameQueryOrStruct[1:], + } + tableStr, extraArgs = formatWhereHolder(ctx, c.db, formatWhereHolderInput{ + WhereHolder: whereHolder, + OmitNil: false, + OmitEmpty: false, + Schema: "", + Table: "", + }) + } + } + // Normal model creation. + if tableStr == "" { + tableNames := make([]string, len(tableNameQueryOrStruct)) + for k, v := range tableNameQueryOrStruct { + if s, ok := v.(string); ok { + tableNames[k] = s + } else if tableName = getTableNameFromOrmTag(v); tableName != "" { + tableNames[k] = tableName + } + } + if len(tableNames) > 1 { + tableStr = fmt.Sprintf( + `%s AS %s`, c.QuotePrefixTableName(tableNames[0]), c.QuoteWord(tableNames[1]), + ) + } else if len(tableNames) == 1 { + tableStr = c.QuotePrefixTableName(tableNames[0]) + } + } + m := &Model{ + db: c.db, + schema: c.schema, + tablesInit: tableStr, + tables: tableStr, + start: -1, + offset: -1, + filter: true, + extraArgs: extraArgs, + tableAliasMap: make(map[string]string), + } + m.whereBuilder = m.Builder() + if defaultModelSafe { + m.safe = true + } + return m +} + +// Raw creates and returns a model based on a raw sql not a table. +// Example: +// +// db.Raw("SELECT * FROM `user` WHERE `name` = ?", "john").Scan(&result) +func (c *Core) Raw(rawSql string, args ...any) *Model { + model := c.Model() + model.rawSql = rawSql + model.extraArgs = args + return model +} + +// Raw sets current model as a raw sql model. +// Example: +// +// db.Raw("SELECT * FROM `user` WHERE `name` = ?", "john").Scan(&result) +// +// See Core.Raw. 
+func (m *Model) Raw(rawSql string, args ...any) *Model {
+	model := m.db.Raw(rawSql, args...)
+	model.db = m.db
+	model.tx = m.tx
+	return model
+}
+
+func (tx *TXCore) Raw(rawSql string, args ...any) *Model {
+	return tx.Model().Raw(rawSql, args...)
+}
+
+// With creates and returns an ORM model based on metadata of given object.
+func (c *Core) With(objects ...any) *Model {
+	return c.db.Model().With(objects...)
+}
+
+// Partition sets Partition name.
+// Example:
+// dao.User.Ctx(ctx).Partition("p1","p2","p3").All()
+func (m *Model) Partition(partitions ...string) *Model {
+	model := m.getModel()
+	model.partition = gstr.Join(partitions, ",")
+	return model
+}
+
+// Model acts like Core.Model except it operates on transaction.
+// See Core.Model.
+func (tx *TXCore) Model(tableNameQueryOrStruct ...any) *Model {
+	model := tx.db.Model(tableNameQueryOrStruct...)
+	model.db = tx.db
+	model.tx = tx
+	return model
+}
+
+// With acts like Core.With except it operates on transaction.
+// See Core.With.
+func (tx *TXCore) With(object any) *Model {
+	return tx.Model().With(object)
+}
+
+// Ctx sets the context for current operation.
+func (m *Model) Ctx(ctx context.Context) *Model {
+	if ctx == nil {
+		return m
+	}
+	model := m.getModel()
+	model.db = model.db.Ctx(ctx)
+	if m.tx != nil {
+		model.tx = model.tx.Ctx(ctx)
+	}
+	return model
+}
+
+// GetCtx returns the context for current Model.
+// It returns `context.Background()` if there's no context previously set.
+func (m *Model) GetCtx() context.Context {
+	if m.tx != nil && m.tx.GetCtx() != nil {
+		return m.tx.GetCtx()
+	}
+	return m.db.GetCtx()
+}
+
+// As sets an alias name for current table.
+func (m *Model) As(as string) *Model {
+	if m.tables != "" {
+		model := m.getModel()
+		split := " JOIN "
+		if gstr.ContainsI(model.tables, split) {
+			// For join table.
+ array := gstr.Split(model.tables, split) + array[len(array)-1], _ = gregex.ReplaceString(`(.+) ON`, fmt.Sprintf(`$1 AS %s ON`, as), array[len(array)-1]) + model.tables = gstr.Join(array, split) + } else { + // For base table. + model.tables = gstr.TrimRight(model.tables) + " AS " + as + } + return model + } + return m +} + +// DB sets/changes the db object for current operation. +func (m *Model) DB(db DB) *Model { + model := m.getModel() + model.db = db + return model +} + +// TX sets/changes the transaction for current operation. +func (m *Model) TX(tx TX) *Model { + model := m.getModel() + model.db = tx.GetDB() + model.tx = tx + return model +} + +// Schema sets the schema for current operation. +func (m *Model) Schema(schema string) *Model { + model := m.getModel() + model.schema = schema + return model +} + +// Clone creates and returns a new model which is a Clone of current model. +// Note that it uses deep-copy for the Clone. +func (m *Model) Clone() *Model { + newModel := (*Model)(nil) + if m.tx != nil { + newModel = m.tx.Model(m.tablesInit) + } else { + newModel = m.db.Model(m.tablesInit) + } + // Basic attributes copy. + *newModel = *m + // WhereBuilder copy, note the attribute pointer. + newModel.whereBuilder = m.whereBuilder.Clone() + newModel.whereBuilder.model = newModel + // Shallow copy slice attributes. + if n := len(m.fields); n > 0 { + newModel.fields = make([]any, n) + copy(newModel.fields, m.fields) + } + if n := len(m.fieldsEx); n > 0 { + newModel.fieldsEx = make([]any, n) + copy(newModel.fieldsEx, m.fieldsEx) + } + if n := len(m.extraArgs); n > 0 { + newModel.extraArgs = make([]any, n) + copy(newModel.extraArgs, m.extraArgs) + } + if n := len(m.withArray); n > 0 { + newModel.withArray = make([]any, n) + copy(newModel.withArray, m.withArray) + } + if n := len(m.having); n > 0 { + newModel.having = make([]any, n) + copy(newModel.having, m.having) + } + return newModel +} + +// Master marks the following operation on master node. 
+func (m *Model) Master() *Model { + model := m.getModel() + model.linkType = linkTypeMaster + return model +} + +// Slave marks the following operation on slave node. +// Note that it makes sense only if there's any slave node configured. +func (m *Model) Slave() *Model { + model := m.getModel() + model.linkType = linkTypeSlave + return model +} + +// Safe marks this model safe or unsafe. If safe is true, it clones and returns a new model object +// whenever the operation done, or else it changes the attribute of current model. +func (m *Model) Safe(safe ...bool) *Model { + if len(safe) > 0 { + m.safe = safe[0] + } else { + m.safe = true + } + return m +} + +// Args sets custom arguments for model operation. +func (m *Model) Args(args ...any) *Model { + model := m.getModel() + model.extraArgs = append(model.extraArgs, args) + return model +} + +// Handler calls each of `handlers` on current Model and returns a new Model. +// ModelHandler is a function that handles given Model and returns a new Model that is custom modified. +func (m *Model) Handler(handlers ...ModelHandler) *Model { + model := m.getModel() + for _, handler := range handlers { + model = handler(model) + } + return model +} diff --git a/database/gdb_model_builder.go b/database/gdb_model_builder.go new file mode 100644 index 0000000..6298cca --- /dev/null +++ b/database/gdb_model_builder.go @@ -0,0 +1,124 @@ +// Copyright GoFrame Author(https://goframe.org). All Rights Reserved. +// +// This Source Code Form is subject to the terms of the MIT License. +// If a copy of the MIT was not distributed with this file, +// You can obtain one at https://github.com/gogf/gf. + +package database + +import ( + "fmt" +) + +// WhereBuilder holds multiple where conditions in a group. +type WhereBuilder struct { + model *Model // A WhereBuilder should be bound to certain Model. + whereHolder []WhereHolder // Condition strings for where operation. +} + +// WhereHolder is the holder for where condition preparing. 
+type WhereHolder struct { + Type string // Type of this holder. + Operator int // Operator for this holder. + Where any // Where parameter, which can commonly be type of string/map/struct. + Args []any // Arguments for where parameter. + Prefix string // Field prefix, eg: "user.", "order.". +} + +// Builder creates and returns a WhereBuilder. Please note that the builder is chain-safe. +func (m *Model) Builder() *WhereBuilder { + b := &WhereBuilder{ + model: m, + whereHolder: make([]WhereHolder, 0), + } + return b +} + +// getBuilder creates and returns a cloned WhereBuilder of current WhereBuilder +func (b *WhereBuilder) getBuilder() *WhereBuilder { + return b.Clone() +} + +// Clone clones and returns a WhereBuilder that is a copy of current one. +func (b *WhereBuilder) Clone() *WhereBuilder { + newBuilder := b.model.Builder() + newBuilder.whereHolder = make([]WhereHolder, len(b.whereHolder)) + copy(newBuilder.whereHolder, b.whereHolder) + return newBuilder +} + +// Build builds current WhereBuilder and returns the condition string and parameters. 
+func (b *WhereBuilder) Build() (conditionWhere string, conditionArgs []any) { + var ( + ctx = b.model.GetCtx() + autoPrefix = b.model.getAutoPrefix() + tableForMappingAndFiltering = b.model.tables + ) + if len(b.whereHolder) > 0 { + for _, holder := range b.whereHolder { + if holder.Prefix == "" { + holder.Prefix = autoPrefix + } + switch holder.Operator { + case whereHolderOperatorWhere, whereHolderOperatorAnd: + newWhere, newArgs := formatWhereHolder(ctx, b.model.db, formatWhereHolderInput{ + WhereHolder: holder, + OmitNil: b.model.option&optionOmitNilWhere > 0, + OmitEmpty: b.model.option&optionOmitEmptyWhere > 0, + Schema: b.model.schema, + Table: tableForMappingAndFiltering, + }) + if len(newWhere) > 0 { + if len(conditionWhere) == 0 { + conditionWhere = newWhere + } else if conditionWhere[0] == '(' { + conditionWhere = fmt.Sprintf(`%s AND (%s)`, conditionWhere, newWhere) + } else { + conditionWhere = fmt.Sprintf(`(%s) AND (%s)`, conditionWhere, newWhere) + } + conditionArgs = append(conditionArgs, newArgs...) + } + + case whereHolderOperatorOr: + newWhere, newArgs := formatWhereHolder(ctx, b.model.db, formatWhereHolderInput{ + WhereHolder: holder, + OmitNil: b.model.option&optionOmitNilWhere > 0, + OmitEmpty: b.model.option&optionOmitEmptyWhere > 0, + Schema: b.model.schema, + Table: tableForMappingAndFiltering, + }) + if len(newWhere) > 0 { + if len(conditionWhere) == 0 { + conditionWhere = newWhere + } else if conditionWhere[0] == '(' { + conditionWhere = fmt.Sprintf(`%s OR (%s)`, conditionWhere, newWhere) + } else { + conditionWhere = fmt.Sprintf(`(%s) OR (%s)`, conditionWhere, newWhere) + } + conditionArgs = append(conditionArgs, newArgs...) + } + } + } + } + return +} + +// convertWhereBuilder converts parameter `where` to condition string and parameters if `where` is also a WhereBuilder. 
+func (b *WhereBuilder) convertWhereBuilder(where any, args []any) (newWhere any, newArgs []any) { + var builder *WhereBuilder + switch v := where.(type) { + case WhereBuilder: + builder = &v + + case *WhereBuilder: + builder = v + } + if builder != nil { + conditionWhere, conditionArgs := builder.Build() + if conditionWhere != "" && (len(b.whereHolder) == 0 || len(builder.whereHolder) > 1) { + conditionWhere = "(" + conditionWhere + ")" + } + return conditionWhere, conditionArgs + } + return where, args +} diff --git a/database/gdb_model_builder_where.go b/database/gdb_model_builder_where.go new file mode 100644 index 0000000..5812041 --- /dev/null +++ b/database/gdb_model_builder_where.go @@ -0,0 +1,171 @@ +// Copyright GoFrame Author(https://goframe.org). All Rights Reserved. +// +// This Source Code Form is subject to the terms of the MIT License. +// If a copy of the MIT was not distributed with this file, +// You can obtain one at https://github.com/gogf/gf. + +package database + +import ( + "fmt" + + "github.com/gogf/gf/v2/text/gstr" +) + +// doWhereType sets the condition statement for the model. The parameter `where` can be type of +// string/map/gmap/slice/struct/*struct, etc. Note that, if it's called more than one times, +// multiple conditions will be joined into where statement using "AND". +func (b *WhereBuilder) doWhereType(whereType string, where any, args ...any) *WhereBuilder { + where, args = b.convertWhereBuilder(where, args) + + builder := b.getBuilder() + if builder.whereHolder == nil { + builder.whereHolder = make([]WhereHolder, 0) + } + if whereType == "" { + if len(args) == 0 { + whereType = whereHolderTypeNoArgs + } else { + whereType = whereHolderTypeDefault + } + } + builder.whereHolder = append(builder.whereHolder, WhereHolder{ + Type: whereType, + Operator: whereHolderOperatorWhere, + Where: where, + Args: args, + }) + return builder +} + +// doWherefType builds condition string using fmt.Sprintf and arguments. 
+// Note that if the number of `args` is more than the placeholder in `format`, +// the extra `args` will be used as the where condition arguments of the Model. +func (b *WhereBuilder) doWherefType(t string, format string, args ...any) *WhereBuilder { + var ( + placeHolderCount = gstr.Count(format, "?") + conditionStr = fmt.Sprintf(format, args[:len(args)-placeHolderCount]...) + ) + return b.doWhereType(t, conditionStr, args[len(args)-placeHolderCount:]...) +} + +// Where sets the condition statement for the builder. The parameter `where` can be type of +// string/map/gmap/slice/struct/*struct, etc. Note that, if it's called more than one times, +// multiple conditions will be joined into where statement using "AND". +// Eg: +// Where("uid=10000") +// Where("uid", 10000) +// Where("money>? AND name like ?", 99999, "vip_%") +// Where("uid", 1).Where("name", "john") +// Where("status IN (?)", g.Slice{1,2,3}) +// Where("age IN(?,?)", 18, 50) +// Where(User{ Id : 1, UserName : "john"}). +func (b *WhereBuilder) Where(where any, args ...any) *WhereBuilder { + return b.doWhereType(``, where, args...) +} + +// Wheref builds condition string using fmt.Sprintf and arguments. +// Note that if the number of `args` is more than the placeholder in `format`, +// the extra `args` will be used as the where condition arguments of the Model. +// Eg: +// Wheref(`amount and status=%s`, "paid", 100) => WHERE `amount`<100 and status='paid' +// Wheref(`amount<%d and status=%s`, 100, "paid") => WHERE `amount`<100 and status='paid' +func (b *WhereBuilder) Wheref(format string, args ...any) *WhereBuilder { + return b.doWherefType(``, format, args...) +} + +// WherePri does the same logic as Model.Where except that if the parameter `where` +// is a single condition like int/string/float/slice, it treats the condition as the primary +// key value. 
That is, if primary key is "id" and given `where` parameter as "123", the +// WherePri function treats the condition as "id=123", but Model.Where treats the condition +// as string "123". +func (b *WhereBuilder) WherePri(where any, args ...any) *WhereBuilder { + if len(args) > 0 { + return b.Where(where, args...) + } + newWhere := GetPrimaryKeyCondition(b.model.getPrimaryKey(), where) + return b.Where(newWhere[0], newWhere[1:]...) +} + +// WhereLT builds `column < value` statement. +func (b *WhereBuilder) WhereLT(column string, value any) *WhereBuilder { + return b.Wheref(`%s < ?`, b.model.QuoteWord(column), value) +} + +// WhereLTE builds `column <= value` statement. +func (b *WhereBuilder) WhereLTE(column string, value any) *WhereBuilder { + return b.Wheref(`%s <= ?`, b.model.QuoteWord(column), value) +} + +// WhereGT builds `column > value` statement. +func (b *WhereBuilder) WhereGT(column string, value any) *WhereBuilder { + return b.Wheref(`%s > ?`, b.model.QuoteWord(column), value) +} + +// WhereGTE builds `column >= value` statement. +func (b *WhereBuilder) WhereGTE(column string, value any) *WhereBuilder { + return b.Wheref(`%s >= ?`, b.model.QuoteWord(column), value) +} + +// WhereBetween builds `column BETWEEN min AND max` statement. +func (b *WhereBuilder) WhereBetween(column string, min, max any) *WhereBuilder { + return b.Wheref(`%s BETWEEN ? AND ?`, b.model.QuoteWord(column), min, max) +} + +// WhereLike builds `column LIKE like` statement. +func (b *WhereBuilder) WhereLike(column string, like string) *WhereBuilder { + return b.Wheref(`%s LIKE ?`, b.model.QuoteWord(column), like) +} + +// WhereIn builds `column IN (in)` statement. +func (b *WhereBuilder) WhereIn(column string, in any) *WhereBuilder { + return b.doWherefType(whereHolderTypeIn, `%s IN (?)`, b.model.QuoteWord(column), in) +} + +// WhereNull builds `columns[0] IS NULL AND columns[1] IS NULL ...` statement. 
+func (b *WhereBuilder) WhereNull(columns ...string) *WhereBuilder { + builder := b + for _, column := range columns { + builder = builder.Wheref(`%s IS NULL`, b.model.QuoteWord(column)) + } + return builder +} + +// WhereNotBetween builds `column NOT BETWEEN min AND max` statement. +func (b *WhereBuilder) WhereNotBetween(column string, min, max any) *WhereBuilder { + return b.Wheref(`%s NOT BETWEEN ? AND ?`, b.model.QuoteWord(column), min, max) +} + +// WhereNotLike builds `column NOT LIKE like` statement. +func (b *WhereBuilder) WhereNotLike(column string, like any) *WhereBuilder { + return b.Wheref(`%s NOT LIKE ?`, b.model.QuoteWord(column), like) +} + +// WhereNot builds `column != value` statement. +func (b *WhereBuilder) WhereNot(column string, value any) *WhereBuilder { + return b.Wheref(`%s != ?`, b.model.QuoteWord(column), value) +} + +// WhereNotIn builds `column NOT IN (in)` statement. +func (b *WhereBuilder) WhereNotIn(column string, in any) *WhereBuilder { + return b.doWherefType(whereHolderTypeIn, `%s NOT IN (?)`, b.model.QuoteWord(column), in) +} + +// WhereNotNull builds `columns[0] IS NOT NULL AND columns[1] IS NOT NULL ...` statement. +func (b *WhereBuilder) WhereNotNull(columns ...string) *WhereBuilder { + builder := b + for _, column := range columns { + builder = builder.Wheref(`%s IS NOT NULL`, b.model.QuoteWord(column)) + } + return builder +} + +// WhereExists builds `EXISTS (subQuery)` statement. +func (b *WhereBuilder) WhereExists(subQuery *Model) *WhereBuilder { + return b.Wheref(`EXISTS (?)`, subQuery) +} + +// WhereNotExists builds `NOT EXISTS (subQuery)` statement. 
+func (b *WhereBuilder) WhereNotExists(subQuery *Model) *WhereBuilder { + return b.Wheref(`NOT EXISTS (?)`, subQuery) +} diff --git a/database/gdb_model_builder_where_prefix.go b/database/gdb_model_builder_where_prefix.go new file mode 100644 index 0000000..1314d2a --- /dev/null +++ b/database/gdb_model_builder_where_prefix.go @@ -0,0 +1,101 @@ +// Copyright GoFrame Author(https://goframe.org). All Rights Reserved. +// +// This Source Code Form is subject to the terms of the MIT License. +// If a copy of the MIT was not distributed with this file, +// You can obtain one at https://github.com/gogf/gf. + +package database + +// WherePrefix performs as Where, but it adds prefix to each field in where statement. +// Eg: +// WherePrefix("order", "status", "paid") => WHERE `order`.`status`='paid' +// WherePrefix("order", struct{Status:"paid", "channel":"bank"}) => WHERE `order`.`status`='paid' AND `order`.`channel`='bank' +func (b *WhereBuilder) WherePrefix(prefix string, where any, args ...any) *WhereBuilder { + where, args = b.convertWhereBuilder(where, args) + + builder := b.getBuilder() + if builder.whereHolder == nil { + builder.whereHolder = make([]WhereHolder, 0) + } + builder.whereHolder = append(builder.whereHolder, WhereHolder{ + Type: whereHolderTypeDefault, + Operator: whereHolderOperatorWhere, + Where: where, + Args: args, + Prefix: prefix, + }) + return builder +} + +// WherePrefixLT builds `prefix.column < value` statement. +func (b *WhereBuilder) WherePrefixLT(prefix string, column string, value any) *WhereBuilder { + return b.Wheref(`%s.%s < ?`, b.model.QuoteWord(prefix), b.model.QuoteWord(column), value) +} + +// WherePrefixLTE builds `prefix.column <= value` statement. +func (b *WhereBuilder) WherePrefixLTE(prefix string, column string, value any) *WhereBuilder { + return b.Wheref(`%s.%s <= ?`, b.model.QuoteWord(prefix), b.model.QuoteWord(column), value) +} + +// WherePrefixGT builds `prefix.column > value` statement. 
+func (b *WhereBuilder) WherePrefixGT(prefix string, column string, value any) *WhereBuilder { + return b.Wheref(`%s.%s > ?`, b.model.QuoteWord(prefix), b.model.QuoteWord(column), value) +} + +// WherePrefixGTE builds `prefix.column >= value` statement. +func (b *WhereBuilder) WherePrefixGTE(prefix string, column string, value any) *WhereBuilder { + return b.Wheref(`%s.%s >= ?`, b.model.QuoteWord(prefix), b.model.QuoteWord(column), value) +} + +// WherePrefixBetween builds `prefix.column BETWEEN min AND max` statement. +func (b *WhereBuilder) WherePrefixBetween(prefix string, column string, min, max any) *WhereBuilder { + return b.Wheref(`%s.%s BETWEEN ? AND ?`, b.model.QuoteWord(prefix), b.model.QuoteWord(column), min, max) +} + +// WherePrefixLike builds `prefix.column LIKE like` statement. +func (b *WhereBuilder) WherePrefixLike(prefix string, column string, like any) *WhereBuilder { + return b.Wheref(`%s.%s LIKE ?`, b.model.QuoteWord(prefix), b.model.QuoteWord(column), like) +} + +// WherePrefixIn builds `prefix.column IN (in)` statement. +func (b *WhereBuilder) WherePrefixIn(prefix string, column string, in any) *WhereBuilder { + return b.doWherefType(whereHolderTypeIn, `%s.%s IN (?)`, b.model.QuoteWord(prefix), b.model.QuoteWord(column), in) +} + +// WherePrefixNull builds `prefix.columns[0] IS NULL AND prefix.columns[1] IS NULL ...` statement. +func (b *WhereBuilder) WherePrefixNull(prefix string, columns ...string) *WhereBuilder { + builder := b + for _, column := range columns { + builder = builder.Wheref(`%s.%s IS NULL`, b.model.QuoteWord(prefix), b.model.QuoteWord(column)) + } + return builder +} + +// WherePrefixNotBetween builds `prefix.column NOT BETWEEN min AND max` statement. +func (b *WhereBuilder) WherePrefixNotBetween(prefix string, column string, min, max any) *WhereBuilder { + return b.Wheref(`%s.%s NOT BETWEEN ? 
AND ?`, b.model.QuoteWord(prefix), b.model.QuoteWord(column), min, max) +} + +// WherePrefixNotLike builds `prefix.column NOT LIKE like` statement. +func (b *WhereBuilder) WherePrefixNotLike(prefix string, column string, like any) *WhereBuilder { + return b.Wheref(`%s.%s NOT LIKE ?`, b.model.QuoteWord(prefix), b.model.QuoteWord(column), like) +} + +// WherePrefixNot builds `prefix.column != value` statement. +func (b *WhereBuilder) WherePrefixNot(prefix string, column string, value any) *WhereBuilder { + return b.Wheref(`%s.%s != ?`, b.model.QuoteWord(prefix), b.model.QuoteWord(column), value) +} + +// WherePrefixNotIn builds `prefix.column NOT IN (in)` statement. +func (b *WhereBuilder) WherePrefixNotIn(prefix string, column string, in any) *WhereBuilder { + return b.doWherefType(whereHolderTypeIn, `%s.%s NOT IN (?)`, b.model.QuoteWord(prefix), b.model.QuoteWord(column), in) +} + +// WherePrefixNotNull builds `prefix.columns[0] IS NOT NULL AND prefix.columns[1] IS NOT NULL ...` statement. +func (b *WhereBuilder) WherePrefixNotNull(prefix string, columns ...string) *WhereBuilder { + builder := b + for _, column := range columns { + builder = builder.Wheref(`%s.%s IS NOT NULL`, b.model.QuoteWord(prefix), b.model.QuoteWord(column)) + } + return builder +} diff --git a/database/gdb_model_builder_whereor.go b/database/gdb_model_builder_whereor.go new file mode 100644 index 0000000..58a1804 --- /dev/null +++ b/database/gdb_model_builder_whereor.go @@ -0,0 +1,125 @@ +// Copyright GoFrame Author(https://goframe.org). All Rights Reserved. +// +// This Source Code Form is subject to the terms of the MIT License. +// If a copy of the MIT was not distributed with this file, +// You can obtain one at https://github.com/gogf/gf. + +package database + +import ( + "fmt" + + "github.com/gogf/gf/v2/text/gstr" +) + +// WhereOr adds "OR" condition to the where statement. 
+// doWhereOrType appends an `OR` condition holder of the given holder type `t`
+// to a cloned builder (chain-safe) and returns the clone.
+func (b *WhereBuilder) doWhereOrType(t string, where any, args ...any) *WhereBuilder {
+	where, args = b.convertWhereBuilder(where, args)
+
+	builder := b.getBuilder()
+	if builder.whereHolder == nil {
+		builder.whereHolder = make([]WhereHolder, 0)
+	}
+	builder.whereHolder = append(builder.whereHolder, WhereHolder{
+		Type:     t,
+		Operator: whereHolderOperatorOr,
+		Where:    where,
+		Args:     args,
+	})
+	return builder
+}
+
+// doWhereOrfType formats the condition string using fmt.Sprintf and appends it
+// as an `OR` condition of holder type `t`. The leading arguments feed the
+// format verbs; the trailing `?`-placeholder-count arguments become statement
+// arguments.
+func (b *WhereBuilder) doWhereOrfType(t string, format string, args ...any) *WhereBuilder {
+	var (
+		placeHolderCount = gstr.Count(format, "?")
+		conditionStr     = fmt.Sprintf(format, args[:len(args)-placeHolderCount]...)
+	)
+	return b.doWhereOrType(t, conditionStr, args[len(args)-placeHolderCount:]...)
+}
+
+// WhereOr adds "OR" condition to the where statement.
+func (b *WhereBuilder) WhereOr(where any, args ...any) *WhereBuilder {
+	return b.doWhereOrType(``, where, args...)
+}
+
+// WhereOrf builds `OR` condition string using fmt.Sprintf and arguments.
+// Eg:
+// WhereOrf(`amount<100 and status=%s`, "paid") => WHERE xxx OR `amount`<100 and status='paid'
+// WhereOrf(`amount<%d and status=%s`, 100, "paid") => WHERE xxx OR `amount`<100 and status='paid'
+func (b *WhereBuilder) WhereOrf(format string, args ...any) *WhereBuilder {
+	return b.doWhereOrfType(``, format, args...)
+}
+
+// WhereOrNot builds `column != value` statement in `OR` conditions.
+func (b *WhereBuilder) WhereOrNot(column string, value any) *WhereBuilder {
+	// Quote the column for consistency with the other WhereOr* builders in this file.
+	return b.WhereOrf(`%s != ?`, b.model.QuoteWord(column), value)
+}
+
+// WhereOrLT builds `column < value` statement in `OR` conditions.
+func (b *WhereBuilder) WhereOrLT(column string, value any) *WhereBuilder {
+	return b.WhereOrf(`%s < ?`, b.model.QuoteWord(column), value)
+}
+
+// WhereOrLTE builds `column <= value` statement in `OR` conditions.
+func (b *WhereBuilder) WhereOrLTE(column string, value any) *WhereBuilder {
+	// Quote the column for consistency with the other WhereOr* builders in this file.
+	return b.WhereOrf(`%s <= ?`, b.model.QuoteWord(column), value)
+}
+
+// WhereOrGT builds `column > value` statement in `OR` conditions.
+func (b *WhereBuilder) WhereOrGT(column string, value any) *WhereBuilder {
+	return b.WhereOrf(`%s > ?`, b.model.QuoteWord(column), value)
+}
+
+// WhereOrGTE builds `column >= value` statement in `OR` conditions.
+func (b *WhereBuilder) WhereOrGTE(column string, value any) *WhereBuilder {
+	return b.WhereOrf(`%s >= ?`, b.model.QuoteWord(column), value)
+}
+
+// WhereOrBetween builds `column BETWEEN min AND max` statement in `OR` conditions.
+func (b *WhereBuilder) WhereOrBetween(column string, min, max any) *WhereBuilder {
+	return b.WhereOrf(`%s BETWEEN ? AND ?`, b.model.QuoteWord(column), min, max)
+}
+
+// WhereOrLike builds `column LIKE 'like'` statement in `OR` conditions.
+func (b *WhereBuilder) WhereOrLike(column string, like any) *WhereBuilder {
+	return b.WhereOrf(`%s LIKE ?`, b.model.QuoteWord(column), like)
+}
+
+// WhereOrIn builds `column IN (in)` statement in `OR` conditions.
+func (b *WhereBuilder) WhereOrIn(column string, in any) *WhereBuilder {
+	return b.doWhereOrfType(whereHolderTypeIn, `%s IN (?)`, b.model.QuoteWord(column), in)
+}
+
+// WhereOrNull builds `columns[0] IS NULL OR columns[1] IS NULL ...` statement in `OR` conditions.
+func (b *WhereBuilder) WhereOrNull(columns ...string) *WhereBuilder {
+	// BUGFIX: accumulate on `builder` so every column contributes a condition
+	// (previously each iteration rebuilt from `b`, keeping only the last column),
+	// and start from `b` so an empty `columns` slice returns a usable builder
+	// instead of nil. Mirrors WhereOrNotNull below.
+	builder := b
+	for _, column := range columns {
+		builder = builder.WhereOrf(`%s IS NULL`, b.model.QuoteWord(column))
+	}
+	return builder
+}
+
+// WhereOrNotBetween builds `column NOT BETWEEN min AND max` statement in `OR` conditions.
+func (b *WhereBuilder) WhereOrNotBetween(column string, min, max any) *WhereBuilder {
+	return b.WhereOrf(`%s NOT BETWEEN ? AND ?`, b.model.QuoteWord(column), min, max)
+}
+
+// WhereOrNotLike builds `column NOT LIKE like` statement in `OR` conditions.
+func (b *WhereBuilder) WhereOrNotLike(column string, like any) *WhereBuilder { + return b.WhereOrf(`%s NOT LIKE ?`, b.model.QuoteWord(column), like) +} + +// WhereOrNotIn builds `column NOT IN (in)` statement. +func (b *WhereBuilder) WhereOrNotIn(column string, in any) *WhereBuilder { + return b.doWhereOrfType(whereHolderTypeIn, `%s NOT IN (?)`, b.model.QuoteWord(column), in) +} + +// WhereOrNotNull builds `columns[0] IS NOT NULL OR columns[1] IS NOT NULL ...` statement in `OR` conditions. +func (b *WhereBuilder) WhereOrNotNull(columns ...string) *WhereBuilder { + builder := b + for _, column := range columns { + builder = builder.WhereOrf(`%s IS NOT NULL`, b.model.QuoteWord(column)) + } + return builder +} diff --git a/database/gdb_model_builder_whereor_prefix.go b/database/gdb_model_builder_whereor_prefix.go new file mode 100644 index 0000000..8ed96a3 --- /dev/null +++ b/database/gdb_model_builder_whereor_prefix.go @@ -0,0 +1,98 @@ +// Copyright GoFrame Author(https://goframe.org). All Rights Reserved. +// +// This Source Code Form is subject to the terms of the MIT License. +// If a copy of the MIT was not distributed with this file, +// You can obtain one at https://github.com/gogf/gf. + +package database + +// WhereOrPrefix performs as WhereOr, but it adds prefix to each field in where statement. 
+// Eg: +// WhereOrPrefix("order", "status", "paid") => WHERE xxx OR (`order`.`status`='paid') +// WhereOrPrefix("order", struct{Status:"paid", "channel":"bank"}) => WHERE xxx OR (`order`.`status`='paid' AND `order`.`channel`='bank') +func (b *WhereBuilder) WhereOrPrefix(prefix string, where any, args ...any) *WhereBuilder { + where, args = b.convertWhereBuilder(where, args) + + builder := b.getBuilder() + builder.whereHolder = append(builder.whereHolder, WhereHolder{ + Type: whereHolderTypeDefault, + Operator: whereHolderOperatorOr, + Where: where, + Args: args, + Prefix: prefix, + }) + return builder +} + +// WhereOrPrefixNot builds `prefix.column != value` statement in `OR` conditions. +func (b *WhereBuilder) WhereOrPrefixNot(prefix string, column string, value any) *WhereBuilder { + return b.WhereOrf(`%s.%s != ?`, b.model.QuoteWord(prefix), b.model.QuoteWord(column), value) +} + +// WhereOrPrefixLT builds `prefix.column < value` statement in `OR` conditions. +func (b *WhereBuilder) WhereOrPrefixLT(prefix string, column string, value any) *WhereBuilder { + return b.WhereOrf(`%s.%s < ?`, b.model.QuoteWord(prefix), b.model.QuoteWord(column), value) +} + +// WhereOrPrefixLTE builds `prefix.column <= value` statement in `OR` conditions. +func (b *WhereBuilder) WhereOrPrefixLTE(prefix string, column string, value any) *WhereBuilder { + return b.WhereOrf(`%s.%s <= ?`, b.model.QuoteWord(prefix), b.model.QuoteWord(column), value) +} + +// WhereOrPrefixGT builds `prefix.column > value` statement in `OR` conditions. +func (b *WhereBuilder) WhereOrPrefixGT(prefix string, column string, value any) *WhereBuilder { + return b.WhereOrf(`%s.%s > ?`, b.model.QuoteWord(prefix), b.model.QuoteWord(column), value) +} + +// WhereOrPrefixGTE builds `prefix.column >= value` statement in `OR` conditions. 
+func (b *WhereBuilder) WhereOrPrefixGTE(prefix string, column string, value any) *WhereBuilder { + return b.WhereOrf(`%s.%s >= ?`, b.model.QuoteWord(prefix), b.model.QuoteWord(column), value) +} + +// WhereOrPrefixBetween builds `prefix.column BETWEEN min AND max` statement in `OR` conditions. +func (b *WhereBuilder) WhereOrPrefixBetween(prefix string, column string, min, max any) *WhereBuilder { + return b.WhereOrf(`%s.%s BETWEEN ? AND ?`, b.model.QuoteWord(prefix), b.model.QuoteWord(column), min, max) +} + +// WhereOrPrefixLike builds `prefix.column LIKE 'like'` statement in `OR` conditions. +func (b *WhereBuilder) WhereOrPrefixLike(prefix string, column string, like any) *WhereBuilder { + return b.WhereOrf(`%s.%s LIKE ?`, b.model.QuoteWord(prefix), b.model.QuoteWord(column), like) +} + +// WhereOrPrefixIn builds `prefix.column IN (in)` statement in `OR` conditions. +func (b *WhereBuilder) WhereOrPrefixIn(prefix string, column string, in any) *WhereBuilder { + return b.doWhereOrfType(whereHolderTypeIn, `%s.%s IN (?)`, b.model.QuoteWord(prefix), b.model.QuoteWord(column), in) +} + +// WhereOrPrefixNull builds `prefix.columns[0] IS NULL OR prefix.columns[1] IS NULL ...` statement in `OR` conditions. +func (b *WhereBuilder) WhereOrPrefixNull(prefix string, columns ...string) *WhereBuilder { + builder := b + for _, column := range columns { + builder = builder.WhereOrf(`%s.%s IS NULL`, b.model.QuoteWord(prefix), b.model.QuoteWord(column)) + } + return builder +} + +// WhereOrPrefixNotBetween builds `prefix.column NOT BETWEEN min AND max` statement in `OR` conditions. +func (b *WhereBuilder) WhereOrPrefixNotBetween(prefix string, column string, min, max any) *WhereBuilder { + return b.WhereOrf(`%s.%s NOT BETWEEN ? AND ?`, b.model.QuoteWord(prefix), b.model.QuoteWord(column), min, max) +} + +// WhereOrPrefixNotLike builds `prefix.column NOT LIKE 'like'` statement in `OR` conditions. 
+func (b *WhereBuilder) WhereOrPrefixNotLike(prefix string, column string, like any) *WhereBuilder { + return b.WhereOrf(`%s.%s NOT LIKE ?`, b.model.QuoteWord(prefix), b.model.QuoteWord(column), like) +} + +// WhereOrPrefixNotIn builds `prefix.column NOT IN (in)` statement. +func (b *WhereBuilder) WhereOrPrefixNotIn(prefix string, column string, in any) *WhereBuilder { + return b.doWhereOrfType(whereHolderTypeIn, `%s.%s NOT IN (?)`, b.model.QuoteWord(prefix), b.model.QuoteWord(column), in) +} + +// WhereOrPrefixNotNull builds `prefix.columns[0] IS NOT NULL OR prefix.columns[1] IS NOT NULL ...` statement in `OR` conditions. +func (b *WhereBuilder) WhereOrPrefixNotNull(prefix string, columns ...string) *WhereBuilder { + builder := b + for _, column := range columns { + builder = builder.WhereOrf(`%s.%s IS NOT NULL`, b.model.QuoteWord(prefix), b.model.QuoteWord(column)) + } + return builder +} diff --git a/database/gdb_model_cache.go b/database/gdb_model_cache.go new file mode 100644 index 0000000..583d8ba --- /dev/null +++ b/database/gdb_model_cache.go @@ -0,0 +1,172 @@ +// Copyright GoFrame Author(https://goframe.org). All Rights Reserved. +// +// This Source Code Form is subject to the terms of the MIT License. +// If a copy of the MIT was not distributed with this file, +// You can obtain one at https://github.com/gogf/gf. + +package database + +import ( + "context" + "time" + + "git.magicany.cc/black1552/gin-base/database/intlog" +) + +// CacheOption is options for model cache control in query. +type CacheOption struct { + // Duration is the TTL for the cache. + // If the parameter `Duration` < 0, which means it clear the cache with given `Name`. + // If the parameter `Duration` = 0, which means it never expires. + // If the parameter `Duration` > 0, which means it expires after `Duration`. + Duration time.Duration + + // Name is an optional unique name for the cache. 
+ // The Name is used to bind a name to the cache, which means you can later control the cache + // like changing the `duration` or clearing the cache with specified Name. + Name string + + // Force caches the query result whatever the result is nil or not. + // It is used to avoid Cache Penetration. + Force bool +} + +// selectCacheItem is the cache item for SELECT statement result. +type selectCacheItem struct { + Result Result // Sql result of SELECT statement. + FirstResultColumn string // The first column name of result, for Value/Count functions. +} + +// Cache sets the cache feature for the model. It caches the result of the sql, which means +// if there's another same sql request, it just reads and returns the result from cache, it +// but not committed and executed into the database. +// +// Note that, the cache feature is disabled if the model is performing select statement +// on a transaction. +func (m *Model) Cache(option CacheOption) *Model { + model := m.getModel() + model.cacheOption = option + model.cacheEnabled = true + return model +} + +// PageCache sets the cache feature for pagination queries. It allows to configure +// separate cache options for count query and data query in pagination. +// +// Note that, the cache feature is disabled if the model is performing select statement +// on a transaction. +func (m *Model) PageCache(countOption CacheOption, dataOption CacheOption) *Model { + model := m.getModel() + model.pageCacheOption = []CacheOption{countOption, dataOption} + model.cacheEnabled = true + return model +} + +// checkAndRemoveSelectCache checks and removes the cache in insert/update/delete statement if +// cache feature is enabled. 
+func (m *Model) checkAndRemoveSelectCache(ctx context.Context) { + if m.cacheEnabled && m.cacheOption.Duration < 0 && len(m.cacheOption.Name) > 0 { + var cacheKey = m.makeSelectCacheKey("") + if _, err := m.db.GetCache().Remove(ctx, cacheKey); err != nil { + intlog.Errorf(ctx, `%+v`, err) + } + } +} + +func (m *Model) getSelectResultFromCache(ctx context.Context, sql string, args ...any) (result Result, err error) { + if !m.cacheEnabled || m.tx != nil { + return + } + var ( + cacheItem *selectCacheItem + cacheKey = m.makeSelectCacheKey(sql, args...) + cacheObj = m.db.GetCache() + core = m.db.GetCore() + ) + defer func() { + if cacheItem != nil { + if internalData := core.getInternalColumnFromCtx(ctx); internalData != nil { + if cacheItem.FirstResultColumn != "" { + internalData.FirstResultColumn = cacheItem.FirstResultColumn + } + } + } + }() + if v, _ := cacheObj.Get(ctx, cacheKey); !v.IsNil() { + if err = v.Scan(&cacheItem); err != nil { + return nil, err + } + return cacheItem.Result, nil + } + return +} + +func (m *Model) saveSelectResultToCache( + ctx context.Context, selectType SelectType, result Result, sql string, args ...any, +) (err error) { + if !m.cacheEnabled || m.tx != nil { + return + } + var ( + cacheKey = m.makeSelectCacheKey(sql, args...) + cacheObj = m.db.GetCache() + ) + if m.cacheOption.Duration < 0 { + if _, errCache := cacheObj.Remove(ctx, cacheKey); errCache != nil { + intlog.Errorf(ctx, `%+v`, errCache) + } + return + } + // Special handler for Value/Count operations result. + if len(result) > 0 { + var core = m.db.GetCore() + switch selectType { + case SelectTypeValue, SelectTypeArray, SelectTypeCount: + if internalData := core.getInternalColumnFromCtx(ctx); internalData != nil { + if result[0][internalData.FirstResultColumn].IsEmpty() { + result = nil + } + } + default: + } + } + + // In case of Cache Penetration. 
+	if result != nil && result.IsEmpty() {
+		if m.cacheOption.Force {
+			result = Result{}
+		} else {
+			result = nil
+		}
+	}
+	var (
+		core      = m.db.GetCore()
+		cacheItem = &selectCacheItem{
+			Result: result,
+		}
+	)
+	if internalData := core.getInternalColumnFromCtx(ctx); internalData != nil {
+		cacheItem.FirstResultColumn = internalData.FirstResultColumn
+	}
+	if errCache := cacheObj.Set(ctx, cacheKey, cacheItem, m.cacheOption.Duration); errCache != nil {
+		intlog.Errorf(ctx, `%+v`, errCache)
+	}
+	return
+}
+
+// makeSelectCacheKey builds the cache key for a SELECT statement from the
+// primary table, group, schema, optional custom cache name and the sql/args.
+func (m *Model) makeSelectCacheKey(sql string, args ...any) string {
+	var (
+		table      = m.db.GetCore().guessPrimaryTableName(m.tables)
+		group      = m.db.GetGroup()
+		schema     = m.db.GetSchema()
+		customName = m.cacheOption.Name
+	)
+	return genSelectCacheKey(
+		table,
+		group,
+		schema,
+		customName,
+		sql,
+		args...,
+	)
+}
diff --git a/database/gdb_model_delete.go b/database/gdb_model_delete.go
new file mode 100644
index 0000000..3bdb966
--- /dev/null
+++ b/database/gdb_model_delete.go
@@ -0,0 +1,87 @@
+// Copyright GoFrame Author(https://goframe.org). All Rights Reserved.
+//
+// This Source Code Form is subject to the terms of the MIT License.
+// If a copy of the MIT was not distributed with this file,
+// You can obtain one at https://github.com/gogf/gf.
+
+package database
+
+import (
+	"database/sql"
+
+	"github.com/gogf/gf/v2/errors/gcode"
+	"github.com/gogf/gf/v2/errors/gerror"
+	"github.com/gogf/gf/v2/text/gstr"
+
+	// BUGFIX: github.com/gogf/gf/v2/internal/intlog is an internal package of
+	// the gf module and cannot be imported from this module — use the local
+	// vendored copy, matching gdb_model_cache.go.
+	"git.magicany.cc/black1552/gin-base/database/intlog"
+)
+
+// Delete does "DELETE FROM ... " statement for the model.
+// The optional parameter `where` is the same as the parameter of Model.Where function,
+// see Model.Where.
+func (m *Model) Delete(where ...any) (result sql.Result, err error) { + var ctx = m.GetCtx() + if len(where) > 0 { + return m.Where(where[0], where[1:]...).Delete() + } + defer func() { + if err == nil { + m.checkAndRemoveSelectCache(ctx) + } + }() + var ( + conditionWhere, conditionExtra, conditionArgs = m.formatCondition(ctx, false, false) + conditionStr = conditionWhere + conditionExtra + fieldNameDelete, fieldTypeDelete = m.softTimeMaintainer().GetFieldInfo(ctx, "", m.tablesInit, SoftTimeFieldDelete) + ) + if m.unscoped { + fieldNameDelete = "" + } + if !gstr.ContainsI(conditionStr, " WHERE ") || (fieldNameDelete != "" && !gstr.ContainsI(conditionStr, " AND ")) { + intlog.Printf( + ctx, + `sql condition string "%s" has no WHERE for DELETE operation, fieldNameDelete: %s`, + conditionStr, fieldNameDelete, + ) + return nil, gerror.NewCode( + gcode.CodeMissingParameter, + "there should be WHERE condition statement for DELETE operation", + ) + } + + // Soft deleting. + if fieldNameDelete != "" { + dataHolder, dataValue := m.softTimeMaintainer().GetDeleteData( + ctx, "", fieldNameDelete, fieldTypeDelete, + ) + in := &HookUpdateInput{ + internalParamHookUpdate: internalParamHookUpdate{ + internalParamHook: internalParamHook{ + link: m.getLink(true), + }, + handler: m.hookHandler.Update, + }, + Model: m, + Table: m.tables, + Schema: m.schema, + Data: dataHolder, + Condition: conditionStr, + Args: append([]any{dataValue}, conditionArgs...), + } + return in.Next(ctx) + } + + in := &HookDeleteInput{ + internalParamHookDelete: internalParamHookDelete{ + internalParamHook: internalParamHook{ + link: m.getLink(true), + }, + handler: m.hookHandler.Delete, + }, + Model: m, + Table: m.tables, + Schema: m.schema, + Condition: conditionStr, + Args: conditionArgs, + } + return in.Next(ctx) +} diff --git a/database/gdb_model_fields.go b/database/gdb_model_fields.go new file mode 100644 index 0000000..bf52d81 --- /dev/null +++ b/database/gdb_model_fields.go @@ -0,0 +1,283 @@ +// 
Copyright GoFrame Author(https://goframe.org). All Rights Reserved. +// +// This Source Code Form is subject to the terms of the MIT License. +// If a copy of the MIT was not distributed with this file, +// You can obtain one at https://github.com/gogf/gf. + +package database + +import ( + "fmt" + + "github.com/gogf/gf/v2/container/gset" + "github.com/gogf/gf/v2/errors/gerror" + "github.com/gogf/gf/v2/text/gstr" + "github.com/gogf/gf/v2/util/gconv" +) + +// Fields appends `fieldNamesOrMapStruct` to the operation fields of the model, multiple fields joined using char ','. +// The parameter `fieldNamesOrMapStruct` can be type of string/map/*map/struct/*struct. +// +// Example: +// Fields("id", "name", "age") +// Fields([]string{"id", "name", "age"}) +// Fields(map[string]any{"id":1, "name":"john", "age":18}) +// Fields(User{Id: 1, Name: "john", Age: 18}). +func (m *Model) Fields(fieldNamesOrMapStruct ...any) *Model { + length := len(fieldNamesOrMapStruct) + if length == 0 { + return m + } + fields := m.filterFieldsFrom(m.tablesInit, fieldNamesOrMapStruct...) + if len(fields) == 0 { + return m + } + model := m.getModel() + return model.appendToFields(fields...) +} + +// FieldsPrefix performs as function Fields but add extra prefix for each field. +func (m *Model) FieldsPrefix(prefixOrAlias string, fieldNamesOrMapStruct ...any) *Model { + fields := m.filterFieldsFrom( + m.getTableNameByPrefixOrAlias(prefixOrAlias), + fieldNamesOrMapStruct..., + ) + if len(fields) == 0 { + return m + } + prefixOrAlias = m.QuoteWord(prefixOrAlias) + for i, field := range fields { + fields[i] = fmt.Sprintf("%s.%s", prefixOrAlias, m.QuoteWord(gconv.String(field))) + } + model := m.getModel() + return model.appendToFields(fields...) +} + +// FieldsEx appends `fieldNamesOrMapStruct` to the excluded operation fields of the model, +// multiple fields joined using char ','. +// Note that this function supports only single table operations. 
+// The parameter `fieldNamesOrMapStruct` can be type of string/map/*map/struct/*struct. +// +// Example: +// FieldsEx("id", "name", "age") +// FieldsEx([]string{"id", "name", "age"}) +// FieldsEx(map[string]any{"id":1, "name":"john", "age":18}) +// FieldsEx(User{Id: 1, Name: "john", Age: 18}). +func (m *Model) FieldsEx(fieldNamesOrMapStruct ...any) *Model { + return m.doFieldsEx(m.tablesInit, fieldNamesOrMapStruct...) +} + +func (m *Model) doFieldsEx(table string, fieldNamesOrMapStruct ...any) *Model { + length := len(fieldNamesOrMapStruct) + if length == 0 { + return m + } + fields := m.filterFieldsFrom(table, fieldNamesOrMapStruct...) + if len(fields) == 0 { + return m + } + model := m.getModel() + model.fieldsEx = append(model.fieldsEx, fields...) + return model +} + +// FieldsExPrefix performs as function FieldsEx but add extra prefix for each field. +// Note that this function must be used together with FieldsPrefix, otherwise it will be invalid. +func (m *Model) FieldsExPrefix(prefixOrAlias string, fieldNamesOrMapStruct ...any) *Model { + fields := m.filterFieldsFrom( + m.getTableNameByPrefixOrAlias(prefixOrAlias), + fieldNamesOrMapStruct..., + ) + if len(fields) == 0 { + return m + } + prefixOrAlias = m.QuoteWord(prefixOrAlias) + for i, field := range fields { + fields[i] = fmt.Sprintf("%s.%s", prefixOrAlias, m.QuoteWord(gconv.String(field))) + } + model := m.getModel() + model.fieldsEx = append(model.fieldsEx, fields...) + return model +} + +// FieldCount formats and appends commonly used field `COUNT(column)` to the select fields of model. +func (m *Model) FieldCount(column string, as ...string) *Model { + asStr := "" + if len(as) > 0 && as[0] != "" { + asStr = fmt.Sprintf(` AS %s`, m.QuoteWord(as[0])) + } + model := m.getModel() + return model.appendToFields( + fmt.Sprintf(`COUNT(%s)%s`, m.QuoteWord(column), asStr), + ) +} + +// FieldSum formats and appends commonly used field `SUM(column)` to the select fields of model. 
+func (m *Model) FieldSum(column string, as ...string) *Model { + asStr := "" + if len(as) > 0 && as[0] != "" { + asStr = fmt.Sprintf(` AS %s`, m.QuoteWord(as[0])) + } + model := m.getModel() + return model.appendToFields( + fmt.Sprintf(`SUM(%s)%s`, m.QuoteWord(column), asStr), + ) +} + +// FieldMin formats and appends commonly used field `MIN(column)` to the select fields of model. +func (m *Model) FieldMin(column string, as ...string) *Model { + asStr := "" + if len(as) > 0 && as[0] != "" { + asStr = fmt.Sprintf(` AS %s`, m.QuoteWord(as[0])) + } + model := m.getModel() + return model.appendToFields( + fmt.Sprintf(`MIN(%s)%s`, m.QuoteWord(column), asStr), + ) +} + +// FieldMax formats and appends commonly used field `MAX(column)` to the select fields of model. +func (m *Model) FieldMax(column string, as ...string) *Model { + asStr := "" + if len(as) > 0 && as[0] != "" { + asStr = fmt.Sprintf(` AS %s`, m.QuoteWord(as[0])) + } + model := m.getModel() + return model.appendToFields( + fmt.Sprintf(`MAX(%s)%s`, m.QuoteWord(column), asStr), + ) +} + +// FieldAvg formats and appends commonly used field `AVG(column)` to the select fields of model. +func (m *Model) FieldAvg(column string, as ...string) *Model { + asStr := "" + if len(as) > 0 && as[0] != "" { + asStr = fmt.Sprintf(` AS %s`, m.QuoteWord(as[0])) + } + model := m.getModel() + return model.appendToFields( + fmt.Sprintf(`AVG(%s)%s`, m.QuoteWord(column), asStr), + ) +} + +// GetFieldsStr retrieves and returns all fields from the table, joined with char ','. +// The optional parameter `prefix` specifies the prefix for each field, eg: GetFieldsStr("u."). 
+func (m *Model) GetFieldsStr(prefix ...string) string { + prefixStr := "" + if len(prefix) > 0 { + prefixStr = prefix[0] + } + tableFields, err := m.TableFields(m.tablesInit) + if err != nil { + panic(err) + } + if len(tableFields) == 0 { + panic(fmt.Sprintf(`empty table fields for table "%s"`, m.tables)) + } + fieldsArray := make([]string, len(tableFields)) + for k, v := range tableFields { + fieldsArray[v.Index] = k + } + newFields := "" + for _, k := range fieldsArray { + if len(newFields) > 0 { + newFields += "," + } + newFields += prefixStr + k + } + newFields = m.db.GetCore().QuoteString(newFields) + return newFields +} + +// GetFieldsExStr retrieves and returns fields which are not in parameter `fields` from the table, +// joined with char ','. +// The parameter `fields` specifies the fields that are excluded. +// The optional parameter `prefix` specifies the prefix for each field, eg: FieldsExStr("id", "u."). +func (m *Model) GetFieldsExStr(fields string, prefix ...string) (string, error) { + prefixStr := "" + if len(prefix) > 0 { + prefixStr = prefix[0] + } + tableFields, err := m.TableFields(m.tablesInit) + if err != nil { + return "", err + } + if len(tableFields) == 0 { + return "", gerror.Newf(`empty table fields for table "%s"`, m.tables) + } + fieldsExSet := gset.NewStrSetFrom(gstr.SplitAndTrim(fields, ",")) + fieldsArray := make([]string, len(tableFields)) + for k, v := range tableFields { + fieldsArray[v.Index] = k + } + newFields := "" + for _, k := range fieldsArray { + if fieldsExSet.Contains(k) { + continue + } + if len(newFields) > 0 { + newFields += "," + } + newFields += prefixStr + k + } + newFields = m.db.GetCore().QuoteString(newFields) + return newFields, nil +} + +// HasField determine whether the field exists in the table. +func (m *Model) HasField(field string) (bool, error) { + return m.db.GetCore().HasField(m.GetCtx(), m.tablesInit, field) +} + +// getFieldsFrom retrieves, filters and returns fields name from table `table`. 
+func (m *Model) filterFieldsFrom(table string, fieldNamesOrMapStruct ...any) []any { + length := len(fieldNamesOrMapStruct) + if length == 0 { + return nil + } + switch { + // String slice. + case length >= 2: + return m.mappingAndFilterToTableFields( + table, fieldNamesOrMapStruct, true, + ) + + // It needs type asserting. + case length == 1: + structOrMap := fieldNamesOrMapStruct[0] + switch r := structOrMap.(type) { + case string: + return m.mappingAndFilterToTableFields(table, []any{r}, false) + + case []string: + return m.mappingAndFilterToTableFields(table, gconv.Interfaces(r), true) + + case Raw, *Raw: + return []any{structOrMap} + + default: + return m.mappingAndFilterToTableFields(table, getFieldsFromStructOrMap(structOrMap), true) + } + + default: + return nil + } +} + +func (m *Model) appendToFields(fields ...any) *Model { + if len(fields) == 0 { + return m + } + model := m.getModel() + model.fields = append(model.fields, fields...) + return model +} + +func (m *Model) isFieldInFieldsEx(field string) bool { + for _, v := range m.fieldsEx { + if v == field { + return true + } + } + return false +} diff --git a/database/gdb_model_hook.go b/database/gdb_model_hook.go new file mode 100644 index 0000000..3dc1b4a --- /dev/null +++ b/database/gdb_model_hook.go @@ -0,0 +1,309 @@ +// Copyright GoFrame Author(https://goframe.org). All Rights Reserved. +// +// This Source Code Form is subject to the terms of the MIT License. +// If a copy of the MIT was not distributed with this file, +// You can obtain one at https://github.com/gogf/gf. 
+ +package database + +import ( + "context" + "database/sql" + "fmt" + + "github.com/gogf/gf/v2/container/gvar" + "github.com/gogf/gf/v2/text/gregex" + "github.com/gogf/gf/v2/text/gstr" +) + +type ( + HookFuncSelect func(ctx context.Context, in *HookSelectInput) (result Result, err error) + HookFuncInsert func(ctx context.Context, in *HookInsertInput) (result sql.Result, err error) + HookFuncUpdate func(ctx context.Context, in *HookUpdateInput) (result sql.Result, err error) + HookFuncDelete func(ctx context.Context, in *HookDeleteInput) (result sql.Result, err error) +) + +// HookHandler manages all supported hook functions for Model. +type HookHandler struct { + Select HookFuncSelect + Insert HookFuncInsert + Update HookFuncUpdate + Delete HookFuncDelete +} + +// internalParamHook manages all internal parameters for hook operations. +// The `internal` obviously means you cannot access these parameters outside this package. +type internalParamHook struct { + link Link // Connection object from third party sql driver. + handlerCalled bool // Simple mark for custom handler called, in case of recursive calling. + removedWhere bool // Removed mark for condition string that was removed `WHERE` prefix. + originalTableName *gvar.Var // The original table name. + originalSchemaName *gvar.Var // The original schema name. +} + +type internalParamHookSelect struct { + internalParamHook + handler HookFuncSelect +} + +type internalParamHookInsert struct { + internalParamHook + handler HookFuncInsert +} + +type internalParamHookUpdate struct { + internalParamHook + handler HookFuncUpdate +} + +type internalParamHookDelete struct { + internalParamHook + handler HookFuncDelete +} + +// HookSelectInput holds the parameters for select hook operation. +// Note that, COUNT statement will also be hooked by this feature, +// which is usually not be interesting for upper business hook handler. 
+type HookSelectInput struct { + internalParamHookSelect + Model *Model // Current operation Model. + Table string // The table name that to be used. Update this attribute to change target table name. + Schema string // The schema name that to be used. Update this attribute to change target schema name. + Sql string // The sql string that to be committed. + Args []any // The arguments of sql. + SelectType SelectType // The type of this SELECT operation. +} + +// HookInsertInput holds the parameters for insert hook operation. +type HookInsertInput struct { + internalParamHookInsert + Model *Model // Current operation Model. + Table string // The table name that to be used. Update this attribute to change target table name. + Schema string // The schema name that to be used. Update this attribute to change target schema name. + Data List // The data records list to be inserted/saved into table. + Option DoInsertOption // The extra option for data inserting. +} + +// HookUpdateInput holds the parameters for update hook operation. +type HookUpdateInput struct { + internalParamHookUpdate + Model *Model // Current operation Model. + Table string // The table name that to be used. Update this attribute to change target table name. + Schema string // The schema name that to be used. Update this attribute to change target schema name. + Data any // Data can be type of: map[string]any/string. You can use type assertion on `Data`. + Condition string // The where condition string for updating. + Args []any // The arguments for sql place-holders. +} + +// HookDeleteInput holds the parameters for delete hook operation. +type HookDeleteInput struct { + internalParamHookDelete + Model *Model // Current operation Model. + Table string // The table name that to be used. Update this attribute to change target table name. + Schema string // The schema name that to be used. Update this attribute to change target schema name. 
+ Condition string // The where condition string for deleting. + Args []any // The arguments for sql place-holders. +} + +const ( + whereKeyInCondition = " WHERE " +) + +// IsTransaction checks and returns whether current operation is during transaction. +func (h *internalParamHook) IsTransaction() bool { + return h.link.IsTransaction() +} + +// Next calls the next hook handler. +func (h *HookSelectInput) Next(ctx context.Context) (result Result, err error) { + if h.originalTableName.IsNil() { + h.originalTableName = gvar.New(h.Table) + } + if h.originalSchemaName.IsNil() { + h.originalSchemaName = gvar.New(h.Schema) + } + + // Sharding feature. + h.Schema, err = h.Model.getActualSchema(ctx, h.Schema) + if err != nil { + return nil, err + } + h.Table, err = h.Model.getActualTable(ctx, h.Table) + if err != nil { + return nil, err + } + + // Custom hook handler call. + if h.handler != nil && !h.handlerCalled { + h.handlerCalled = true + return h.handler(ctx, h) + } + var toBeCommittedSql = h.Sql + // Table change. + if h.Table != h.originalTableName.String() { + toBeCommittedSql, err = gregex.ReplaceStringFuncMatch( + `(?i) FROM ([\S]+)`, + toBeCommittedSql, + func(match []string) string { + charL, charR := h.Model.db.GetChars() + return fmt.Sprintf(` FROM %s%s%s`, charL, h.Table, charR) + }, + ) + if err != nil { + return + } + } + // Schema change. + if h.Schema != "" && h.Schema != h.originalSchemaName.String() { + h.link, err = h.Model.db.GetCore().SlaveLink(h.Schema) + if err != nil { + return + } + h.Model.db.GetCore().schema = h.Schema + defer func() { + h.Model.db.GetCore().schema = h.originalSchemaName.String() + }() + } + return h.Model.db.DoSelect(ctx, h.link, toBeCommittedSql, h.Args...) +} + +// Next calls the next hook handler. 
+func (h *HookInsertInput) Next(ctx context.Context) (result sql.Result, err error) { + if h.originalTableName.IsNil() { + h.originalTableName = gvar.New(h.Table) + } + if h.originalSchemaName.IsNil() { + h.originalSchemaName = gvar.New(h.Schema) + } + + // Sharding feature. + h.Schema, err = h.Model.getActualSchema(ctx, h.Schema) + if err != nil { + return nil, err + } + h.Table, err = h.Model.getActualTable(ctx, h.Table) + if err != nil { + return nil, err + } + + if h.handler != nil && !h.handlerCalled { + h.handlerCalled = true + return h.handler(ctx, h) + } + + // No need to handle table change. + + // Schema change. + if h.Schema != "" && h.Schema != h.originalSchemaName.String() { + h.link, err = h.Model.db.GetCore().MasterLink(h.Schema) + if err != nil { + return + } + h.Model.db.GetCore().schema = h.Schema + defer func() { + h.Model.db.GetCore().schema = h.originalSchemaName.String() + }() + } + return h.Model.db.DoInsert(ctx, h.link, h.Table, h.Data, h.Option) +} + +// Next calls the next hook handler. +func (h *HookUpdateInput) Next(ctx context.Context) (result sql.Result, err error) { + if h.originalTableName.IsNil() { + h.originalTableName = gvar.New(h.Table) + } + if h.originalSchemaName.IsNil() { + h.originalSchemaName = gvar.New(h.Schema) + } + + // Sharding feature. + h.Schema, err = h.Model.getActualSchema(ctx, h.Schema) + if err != nil { + return nil, err + } + h.Table, err = h.Model.getActualTable(ctx, h.Table) + if err != nil { + return nil, err + } + + if h.handler != nil && !h.handlerCalled { + h.handlerCalled = true + if gstr.HasPrefix(h.Condition, whereKeyInCondition) { + h.removedWhere = true + h.Condition = gstr.TrimLeftStr(h.Condition, whereKeyInCondition) + } + return h.handler(ctx, h) + } + if h.removedWhere { + h.Condition = whereKeyInCondition + h.Condition + } + + // No need to handle table change. + + // Schema change. 
+ if h.Schema != "" && h.Schema != h.originalSchemaName.String() { + h.link, err = h.Model.db.GetCore().MasterLink(h.Schema) + if err != nil { + return + } + h.Model.db.GetCore().schema = h.Schema + defer func() { + h.Model.db.GetCore().schema = h.originalSchemaName.String() + }() + } + return h.Model.db.DoUpdate(ctx, h.link, h.Table, h.Data, h.Condition, h.Args...) +} + +// Next calls the next hook handler. +func (h *HookDeleteInput) Next(ctx context.Context) (result sql.Result, err error) { + if h.originalTableName.IsNil() { + h.originalTableName = gvar.New(h.Table) + } + if h.originalSchemaName.IsNil() { + h.originalSchemaName = gvar.New(h.Schema) + } + + // Sharding feature. + h.Schema, err = h.Model.getActualSchema(ctx, h.Schema) + if err != nil { + return nil, err + } + h.Table, err = h.Model.getActualTable(ctx, h.Table) + if err != nil { + return nil, err + } + + if h.handler != nil && !h.handlerCalled { + h.handlerCalled = true + if gstr.HasPrefix(h.Condition, whereKeyInCondition) { + h.removedWhere = true + h.Condition = gstr.TrimLeftStr(h.Condition, whereKeyInCondition) + } + return h.handler(ctx, h) + } + if h.removedWhere { + h.Condition = whereKeyInCondition + h.Condition + } + + // No need to handle table change. + + // Schema change. + if h.Schema != "" && h.Schema != h.originalSchemaName.String() { + h.link, err = h.Model.db.GetCore().MasterLink(h.Schema) + if err != nil { + return + } + h.Model.db.GetCore().schema = h.Schema + defer func() { + h.Model.db.GetCore().schema = h.originalSchemaName.String() + }() + } + return h.Model.db.DoDelete(ctx, h.link, h.Table, h.Condition, h.Args...) +} + +// Hook sets the hook functions for current model. 
+func (m *Model) Hook(hook HookHandler) *Model { + model := m.getModel() + model.hookHandler = hook + return model +} diff --git a/database/gdb_model_insert.go b/database/gdb_model_insert.go new file mode 100644 index 0000000..5a04187 --- /dev/null +++ b/database/gdb_model_insert.go @@ -0,0 +1,469 @@ +// Copyright GoFrame Author(https://goframe.org). All Rights Reserved. +// +// This Source Code Form is subject to the terms of the MIT License. +// If a copy of the MIT was not distributed with this file, +// You can obtain one at https://github.com/gogf/gf. + +package database + +import ( + "context" + "database/sql" + "reflect" + + "git.magicany.cc/black1552/gin-base/database/empty" + "git.magicany.cc/black1552/gin-base/database/reflection" + "github.com/gogf/gf/v2/container/gset" + "github.com/gogf/gf/v2/errors/gcode" + "github.com/gogf/gf/v2/errors/gerror" + "github.com/gogf/gf/v2/text/gstr" + "github.com/gogf/gf/v2/util/gconv" + "github.com/gogf/gf/v2/util/gutil" +) + +// Batch sets the batch operation number for the model. +func (m *Model) Batch(batch int) *Model { + model := m.getModel() + model.batch = batch + return model +} + +// Data sets the operation data for the model. +// The parameter `data` can be type of string/map/gmap/slice/struct/*struct, etc. +// Note that, it uses shallow value copying for `data` if `data` is type of map/slice +// to avoid changing it inside function. +// Eg: +// Data("uid=10000") +// Data("uid", 10000) +// Data("uid=? AND name=?", 10000, "john") +// Data(g.Map{"uid": 10000, "name":"john"}) +// Data(g.Slice{g.Map{"uid": 10000, "name":"john"}, g.Map{"uid": 20000, "name":"smith"}). 
+func (m *Model) Data(data ...any) *Model { + var model = m.getModel() + if len(data) > 1 { + if s := gconv.String(data[0]); gstr.Contains(s, "?") { + model.data = s + model.extraArgs = data[1:] + } else { + newData := make(map[string]any) + for i := 0; i < len(data); i += 2 { + newData[gconv.String(data[i])] = data[i+1] + } + model.data = newData + } + } else if len(data) == 1 { + switch value := data[0].(type) { + case Result: + model.data = value.List() + + case Record: + model.data = value.Map() + + case List: + list := make(List, len(value)) + for k, v := range value { + list[k] = gutil.MapCopy(v) + } + model.data = list + + case Map: + model.data = gutil.MapCopy(value) + + default: + reflectInfo := reflection.OriginValueAndKind(value) + switch reflectInfo.OriginKind { + case reflect.Slice, reflect.Array: + if reflectInfo.OriginValue.Len() > 0 { + // If the `data` parameter is a DO struct, + // it then adds `OmitNilData` option for this condition, + // which will filter all nil parameters in `data`. + if isDoStruct(reflectInfo.OriginValue.Index(0).Interface()) { + model = model.OmitNilData() + model.option |= optionOmitNilDataInternal + } + } + list := make(List, reflectInfo.OriginValue.Len()) + for i := 0; i < reflectInfo.OriginValue.Len(); i++ { + list[i] = anyValueToMapBeforeToRecord(reflectInfo.OriginValue.Index(i).Interface()) + } + model.data = list + + case reflect.Struct: + // If the `data` parameter is a DO struct, + // it then adds `OmitNilData` option for this condition, + // which will filter all nil parameters in `data`. 
+ if isDoStruct(value) { + model = model.OmitNilData() + } + if v, ok := data[0].(iInterfaces); ok { + var ( + array = v.Interfaces() + list = make(List, len(array)) + ) + for i := 0; i < len(array); i++ { + list[i] = anyValueToMapBeforeToRecord(array[i]) + } + model.data = list + } else { + model.data = anyValueToMapBeforeToRecord(data[0]) + } + + case reflect.Map: + model.data = anyValueToMapBeforeToRecord(data[0]) + + default: + model.data = data[0] + } + } + } + return model +} + +// OnConflict sets the primary key or index when columns conflicts occurs. +// It's not necessary for MySQL driver. +func (m *Model) OnConflict(onConflict ...any) *Model { + if len(onConflict) == 0 { + return m + } + model := m.getModel() + if len(onConflict) > 1 { + model.onConflict = onConflict + } else if len(onConflict) == 1 { + model.onConflict = onConflict[0] + } + return model +} + +// OnDuplicate sets the operations when columns conflicts occurs. +// In MySQL, this is used for "ON DUPLICATE KEY UPDATE" statement. +// In PgSQL, this is used for "ON CONFLICT (id) DO UPDATE SET" statement. +// The parameter `onDuplicate` can be type of string/Raw/*Raw/map/slice. +// Example: +// +// OnDuplicate("nickname, age") +// OnDuplicate("nickname", "age") +// +// OnDuplicate(g.Map{ +// "nickname": gdb.Raw("CONCAT('name_', VALUES(`nickname`))"), +// }) +// +// OnDuplicate(g.Map{ +// "nickname": "passport", +// }). +func (m *Model) OnDuplicate(onDuplicate ...any) *Model { + if len(onDuplicate) == 0 { + return m + } + model := m.getModel() + if len(onDuplicate) > 1 { + model.onDuplicate = onDuplicate + } else if len(onDuplicate) == 1 { + model.onDuplicate = onDuplicate[0] + } + return model +} + +// OnDuplicateEx sets the excluding columns for operations when columns conflict occurs. +// In MySQL, this is used for "ON DUPLICATE KEY UPDATE" statement. +// In PgSQL, this is used for "ON CONFLICT (id) DO UPDATE SET" statement. +// The parameter `onDuplicateEx` can be type of string/map/slice. 
+// Example: +// +// OnDuplicateEx("passport, password") +// OnDuplicateEx("passport", "password") +// +// OnDuplicateEx(g.Map{ +// "passport": "", +// "password": "", +// }). +func (m *Model) OnDuplicateEx(onDuplicateEx ...any) *Model { + if len(onDuplicateEx) == 0 { + return m + } + model := m.getModel() + if len(onDuplicateEx) > 1 { + model.onDuplicateEx = onDuplicateEx + } else if len(onDuplicateEx) == 1 { + model.onDuplicateEx = onDuplicateEx[0] + } + return model +} + +// Insert does "INSERT INTO ..." statement for the model. +// The optional parameter `data` is the same as the parameter of Model.Data function, +// see Model.Data. +func (m *Model) Insert(data ...any) (result sql.Result, err error) { + var ctx = m.GetCtx() + if len(data) > 0 { + return m.Data(data...).Insert() + } + return m.doInsertWithOption(ctx, InsertOptionDefault) +} + +// InsertAndGetId performs action Insert and returns the last insert id that automatically generated. +func (m *Model) InsertAndGetId(data ...any) (lastInsertId int64, err error) { + var ctx = m.GetCtx() + if len(data) > 0 { + return m.Data(data...).InsertAndGetId() + } + result, err := m.doInsertWithOption(ctx, InsertOptionDefault) + if err != nil { + return 0, err + } + return result.LastInsertId() +} + +// InsertIgnore does "INSERT IGNORE INTO ..." statement for the model. +// The optional parameter `data` is the same as the parameter of Model.Data function, +// see Model.Data. +func (m *Model) InsertIgnore(data ...any) (result sql.Result, err error) { + var ctx = m.GetCtx() + if len(data) > 0 { + return m.Data(data...).InsertIgnore() + } + return m.doInsertWithOption(ctx, InsertOptionIgnore) +} + +// Replace does "REPLACE INTO ..." statement for the model. +// The optional parameter `data` is the same as the parameter of Model.Data function, +// see Model.Data. 
+func (m *Model) Replace(data ...any) (result sql.Result, err error) { + var ctx = m.GetCtx() + if len(data) > 0 { + return m.Data(data...).Replace() + } + return m.doInsertWithOption(ctx, InsertOptionReplace) +} + +// Save does "INSERT INTO ... ON DUPLICATE KEY UPDATE..." statement for the model. +// The optional parameter `data` is the same as the parameter of Model.Data function, +// see Model.Data. +// +// It updates the record if there's primary or unique index in the saving data, +// or else it inserts a new record into the table. +func (m *Model) Save(data ...any) (result sql.Result, err error) { + var ctx = m.GetCtx() + if len(data) > 0 { + return m.Data(data...).Save() + } + return m.doInsertWithOption(ctx, InsertOptionSave) +} + +// doInsertWithOption inserts data with option parameter. +func (m *Model) doInsertWithOption(ctx context.Context, insertOption InsertOption) (result sql.Result, err error) { + defer func() { + if err == nil { + m.checkAndRemoveSelectCache(ctx) + } + }() + if m.data == nil { + return nil, gerror.NewCode(gcode.CodeMissingParameter, "inserting into table with empty data") + } + var ( + list List + stm = m.softTimeMaintainer() + fieldNameCreate, fieldTypeCreate = stm.GetFieldInfo(ctx, "", m.tablesInit, SoftTimeFieldCreate) + fieldNameUpdate, fieldTypeUpdate = stm.GetFieldInfo(ctx, "", m.tablesInit, SoftTimeFieldUpdate) + fieldNameDelete, fieldTypeDelete = stm.GetFieldInfo(ctx, "", m.tablesInit, SoftTimeFieldDelete) + ) + // m.data was already converted to type List/Map by function Data + newData, err := m.filterDataForInsertOrUpdate(m.data) + if err != nil { + return nil, err + } + // It converts any data to List type for inserting. + switch value := newData.(type) { + case List: + list = value + + case Map: + list = List{value} + } + + if len(list) < 1 { + return result, gerror.NewCode(gcode.CodeMissingParameter, "data list cannot be empty") + } + + // Automatic handling for creating/updating time. 
+ if fieldNameCreate != "" && m.isFieldInFieldsEx(fieldNameCreate) { + fieldNameCreate = "" + } + if fieldNameUpdate != "" && m.isFieldInFieldsEx(fieldNameUpdate) { + fieldNameUpdate = "" + } + var isSoftTimeFeatureEnabled = fieldNameCreate != "" || fieldNameUpdate != "" + if !m.unscoped && isSoftTimeFeatureEnabled { + for k, v := range list { + if fieldNameCreate != "" && empty.IsNil(v[fieldNameCreate]) { + fieldCreateValue := stm.GetFieldValue(ctx, fieldTypeCreate, false) + if fieldCreateValue != nil { + v[fieldNameCreate] = fieldCreateValue + } + } + if fieldNameUpdate != "" && empty.IsNil(v[fieldNameUpdate]) { + fieldUpdateValue := stm.GetFieldValue(ctx, fieldTypeUpdate, false) + if fieldUpdateValue != nil { + v[fieldNameUpdate] = fieldUpdateValue + } + } + // for timestamp field that should initialize the delete_at field with value, for example 0. + if fieldNameDelete != "" && empty.IsNil(v[fieldNameDelete]) { + fieldDeleteValue := stm.GetFieldValue(ctx, fieldTypeDelete, true) + if fieldDeleteValue != nil { + v[fieldNameDelete] = fieldDeleteValue + } + } + list[k] = v + } + } + // Format DoInsertOption, especially for "ON DUPLICATE KEY UPDATE" statement. 
+ columnNames := make([]string, 0, len(list[0])) + for k := range list[0] { + columnNames = append(columnNames, k) + } + doInsertOption, err := m.formatDoInsertOption(insertOption, columnNames) + if err != nil { + return result, err + } + + in := &HookInsertInput{ + internalParamHookInsert: internalParamHookInsert{ + internalParamHook: internalParamHook{ + link: m.getLink(true), + }, + handler: m.hookHandler.Insert, + }, + Model: m, + Table: m.tables, + Schema: m.schema, + Data: list, + Option: doInsertOption, + } + return in.Next(ctx) +} + +func (m *Model) formatDoInsertOption(insertOption InsertOption, columnNames []string) (option DoInsertOption, err error) { + option = DoInsertOption{ + InsertOption: insertOption, + BatchCount: m.getBatch(), + } + if insertOption != InsertOptionSave { + return + } + + onConflictKeys, err := m.formatOnConflictKeys(m.onConflict) + if err != nil { + return option, err + } + option.OnConflict = onConflictKeys + + onDuplicateExKeys, err := m.formatOnDuplicateExKeys(m.onDuplicateEx) + if err != nil { + return option, err + } + onDuplicateExKeySet := gset.NewStrSetFrom(onDuplicateExKeys) + if m.onDuplicate != nil { + switch m.onDuplicate.(type) { + case Raw, *Raw: + option.OnDuplicateStr = gconv.String(m.onDuplicate) + + default: + reflectInfo := reflection.OriginValueAndKind(m.onDuplicate) + switch reflectInfo.OriginKind { + case reflect.String: + option.OnDuplicateMap = make(map[string]any) + for _, v := range gstr.SplitAndTrim(reflectInfo.OriginValue.String(), ",") { + if onDuplicateExKeySet.Contains(v) { + continue + } + option.OnDuplicateMap[v] = v + } + + case reflect.Map: + option.OnDuplicateMap = make(map[string]any) + for k, v := range gconv.Map(m.onDuplicate) { + if onDuplicateExKeySet.Contains(k) { + continue + } + option.OnDuplicateMap[k] = v + } + + case reflect.Slice, reflect.Array: + option.OnDuplicateMap = make(map[string]any) + for _, v := range gconv.Strings(m.onDuplicate) { + if onDuplicateExKeySet.Contains(v) { + 
continue + } + option.OnDuplicateMap[v] = v + } + + default: + return option, gerror.NewCodef( + gcode.CodeInvalidParameter, + `unsupported OnDuplicate parameter type "%s"`, + reflect.TypeOf(m.onDuplicate), + ) + } + } + } else if onDuplicateExKeySet.Size() > 0 { + option.OnDuplicateMap = make(map[string]any) + for _, v := range columnNames { + if onDuplicateExKeySet.Contains(v) { + continue + } + option.OnDuplicateMap[v] = v + } + } + return +} + +func (m *Model) formatOnDuplicateExKeys(onDuplicateEx any) ([]string, error) { + if onDuplicateEx == nil { + return nil, nil + } + + reflectInfo := reflection.OriginValueAndKind(onDuplicateEx) + switch reflectInfo.OriginKind { + case reflect.String: + return gstr.SplitAndTrim(reflectInfo.OriginValue.String(), ","), nil + + case reflect.Map: + return gutil.Keys(onDuplicateEx), nil + + case reflect.Slice, reflect.Array: + return gconv.Strings(onDuplicateEx), nil + + default: + return nil, gerror.NewCodef( + gcode.CodeInvalidParameter, + `unsupported OnDuplicateEx parameter type "%s"`, + reflect.TypeOf(onDuplicateEx), + ) + } +} + +func (m *Model) formatOnConflictKeys(onConflict any) ([]string, error) { + if onConflict == nil { + return nil, nil + } + + reflectInfo := reflection.OriginValueAndKind(onConflict) + switch reflectInfo.OriginKind { + case reflect.String: + return gstr.SplitAndTrim(reflectInfo.OriginValue.String(), ","), nil + + case reflect.Slice, reflect.Array: + return gconv.Strings(onConflict), nil + + default: + return nil, gerror.NewCodef( + gcode.CodeInvalidParameter, + `unsupported onConflict parameter type "%s"`, + reflect.TypeOf(onConflict), + ) + } +} + +func (m *Model) getBatch() int { + return m.batch +} diff --git a/database/gdb_model_join.go b/database/gdb_model_join.go new file mode 100644 index 0000000..fd301e9 --- /dev/null +++ b/database/gdb_model_join.go @@ -0,0 +1,223 @@ +// Copyright GoFrame Author(https://goframe.org). All Rights Reserved. 
+// +// This Source Code Form is subject to the terms of the MIT License. +// If a copy of the MIT was not distributed with this file, +// You can obtain one at https://github.com/gogf/gf. + +package database + +import ( + "fmt" + + "github.com/gogf/gf/v2/text/gstr" +) + +// LeftJoin does "LEFT JOIN ... ON ..." statement on the model. +// The parameter `table` can be joined table and its joined condition, +// and also with its alias name. +// +// Eg: +// Model("user").LeftJoin("user_detail", "user_detail.uid=user.uid") +// Model("user", "u").LeftJoin("user_detail", "ud", "ud.uid=u.uid") +// Model("user", "u").LeftJoin("SELECT xxx FROM xxx","a", "a.uid=u.uid"). +func (m *Model) LeftJoin(tableOrSubQueryAndJoinConditions ...string) *Model { + return m.doJoin(joinOperatorLeft, tableOrSubQueryAndJoinConditions...) +} + +// RightJoin does "RIGHT JOIN ... ON ..." statement on the model. +// The parameter `table` can be joined table and its joined condition, +// and also with its alias name. +// +// Eg: +// Model("user").RightJoin("user_detail", "user_detail.uid=user.uid") +// Model("user", "u").RightJoin("user_detail", "ud", "ud.uid=u.uid") +// Model("user", "u").RightJoin("SELECT xxx FROM xxx","a", "a.uid=u.uid"). +func (m *Model) RightJoin(tableOrSubQueryAndJoinConditions ...string) *Model { + return m.doJoin(joinOperatorRight, tableOrSubQueryAndJoinConditions...) +} + +// InnerJoin does "INNER JOIN ... ON ..." statement on the model. +// The parameter `table` can be joined table and its joined condition, +// and also with its alias name。 +// +// Eg: +// Model("user").InnerJoin("user_detail", "user_detail.uid=user.uid") +// Model("user", "u").InnerJoin("user_detail", "ud", "ud.uid=u.uid") +// Model("user", "u").InnerJoin("SELECT xxx FROM xxx","a", "a.uid=u.uid"). +func (m *Model) InnerJoin(tableOrSubQueryAndJoinConditions ...string) *Model { + return m.doJoin(joinOperatorInner, tableOrSubQueryAndJoinConditions...) 
+}
+
+// LeftJoinOnField performs as LeftJoin, but it joins both tables with the `same field name`.
+//
+// Eg:
+// Model("order").LeftJoinOnField("user", "user_id")
+// Model("order").LeftJoinOnField("product", "product_id").
+func (m *Model) LeftJoinOnField(table, field string) *Model {
+	return m.doJoin(joinOperatorLeft, table, fmt.Sprintf(
+		`%s.%s=%s.%s`,
+		m.tablesInit,
+		m.db.GetCore().QuoteWord(field),
+		m.db.GetCore().QuoteWord(table),
+		m.db.GetCore().QuoteWord(field),
+	))
+}
+
+// RightJoinOnField performs as RightJoin, but it joins both tables with the `same field name`.
+//
+// Eg:
+// Model("order").RightJoinOnField("user", "user_id")
+// Model("order").RightJoinOnField("product", "product_id").
+func (m *Model) RightJoinOnField(table, field string) *Model {
+	return m.doJoin(joinOperatorRight, table, fmt.Sprintf(
+		`%s.%s=%s.%s`,
+		m.tablesInit,
+		m.db.GetCore().QuoteWord(field),
+		m.db.GetCore().QuoteWord(table),
+		m.db.GetCore().QuoteWord(field),
+	))
+}
+
+// InnerJoinOnField performs as InnerJoin, but it joins both tables with the `same field name`.
+//
+// Eg:
+// Model("order").InnerJoinOnField("user", "user_id")
+// Model("order").InnerJoinOnField("product", "product_id").
+func (m *Model) InnerJoinOnField(table, field string) *Model {
+	return m.doJoin(joinOperatorInner, table, fmt.Sprintf(
+		`%s.%s=%s.%s`,
+		m.tablesInit,
+		m.db.GetCore().QuoteWord(field),
+		m.db.GetCore().QuoteWord(table),
+		m.db.GetCore().QuoteWord(field),
+	))
+}
+
+// LeftJoinOnFields performs as LeftJoin. It specifies different fields and comparison operator.
+// +// Eg: +// Model("user").LeftJoinOnFields("order", "id", "=", "user_id") +// Model("user").LeftJoinOnFields("order", "id", ">", "user_id") +// Model("user").LeftJoinOnFields("order", "id", "<", "user_id") +func (m *Model) LeftJoinOnFields(table, firstField, operator, secondField string) *Model { + return m.doJoin(joinOperatorLeft, table, fmt.Sprintf( + `%s.%s %s %s.%s`, + m.tablesInit, + m.db.GetCore().QuoteWord(firstField), + operator, + m.db.GetCore().QuoteWord(table), + m.db.GetCore().QuoteWord(secondField), + )) +} + +// RightJoinOnFields performs as RightJoin. It specifies different fields and comparison operator. +// +// Eg: +// Model("user").RightJoinOnFields("order", "id", "=", "user_id") +// Model("user").RightJoinOnFields("order", "id", ">", "user_id") +// Model("user").RightJoinOnFields("order", "id", "<", "user_id") +func (m *Model) RightJoinOnFields(table, firstField, operator, secondField string) *Model { + return m.doJoin(joinOperatorRight, table, fmt.Sprintf( + `%s.%s %s %s.%s`, + m.tablesInit, + m.db.GetCore().QuoteWord(firstField), + operator, + m.db.GetCore().QuoteWord(table), + m.db.GetCore().QuoteWord(secondField), + )) +} + +// InnerJoinOnFields performs as InnerJoin. It specifies different fields and comparison operator. +// +// Eg: +// Model("user").InnerJoinOnFields("order", "id", "=", "user_id") +// Model("user").InnerJoinOnFields("order", "id", ">", "user_id") +// Model("user").InnerJoinOnFields("order", "id", "<", "user_id") +func (m *Model) InnerJoinOnFields(table, firstField, operator, secondField string) *Model { + return m.doJoin(joinOperatorInner, table, fmt.Sprintf( + `%s.%s %s %s.%s`, + m.tablesInit, + m.db.GetCore().QuoteWord(firstField), + operator, + m.db.GetCore().QuoteWord(table), + m.db.GetCore().QuoteWord(secondField), + )) +} + +// doJoin does "LEFT/RIGHT/INNER JOIN ... ON ..." statement on the model. 
+// The parameter `tableOrSubQueryAndJoinConditions` can be joined table and its joined condition, +// and also with its alias name. +// +// Eg: +// Model("user").InnerJoin("user_detail", "user_detail.uid=user.uid") +// Model("user", "u").InnerJoin("user_detail", "ud", "ud.uid=u.uid") +// Model("user", "u").InnerJoin("user_detail", "ud", "ud.uid>u.uid") +// Model("user", "u").InnerJoin("SELECT xxx FROM xxx","a", "a.uid=u.uid") +// Related issues: +// https://github.com/gogf/gf/issues/1024 +func (m *Model) doJoin(operator joinOperator, tableOrSubQueryAndJoinConditions ...string) *Model { + var ( + model = m.getModel() + joinStr = "" + table string + alias string + ) + // Check the first parameter table or sub-query. + if len(tableOrSubQueryAndJoinConditions) > 0 { + if isSubQuery(tableOrSubQueryAndJoinConditions[0]) { + joinStr = gstr.Trim(tableOrSubQueryAndJoinConditions[0]) + if joinStr[0] != '(' { + joinStr = "(" + joinStr + ")" + } + } else { + table = tableOrSubQueryAndJoinConditions[0] + joinStr = m.db.GetCore().QuotePrefixTableName(table) + } + } + // Generate join condition statement string. + conditionLength := len(tableOrSubQueryAndJoinConditions) + switch { + case conditionLength > 2: + alias = tableOrSubQueryAndJoinConditions[1] + model.tables += fmt.Sprintf( + " %s JOIN %s AS %s ON (%s)", + operator, joinStr, + m.db.GetCore().QuoteWord(alias), + tableOrSubQueryAndJoinConditions[2], + ) + m.tableAliasMap[alias] = table + + case conditionLength == 2: + model.tables += fmt.Sprintf( + " %s JOIN %s ON (%s)", + operator, joinStr, tableOrSubQueryAndJoinConditions[1], + ) + + case conditionLength == 1: + model.tables += fmt.Sprintf( + " %s JOIN %s", operator, joinStr, + ) + } + return model +} + +// getTableNameByPrefixOrAlias checks and returns the table name if `prefixOrAlias` is an alias of a table, +// it or else returns the `prefixOrAlias` directly. 
+func (m *Model) getTableNameByPrefixOrAlias(prefixOrAlias string) string { + value, ok := m.tableAliasMap[prefixOrAlias] + if ok { + return value + } + return prefixOrAlias +} + +// isSubQuery checks and returns whether given string a sub-query sql string. +func isSubQuery(s string) bool { + s = gstr.TrimLeft(s, "()") + if p := gstr.Pos(s, " "); p != -1 { + if gstr.Equal(s[:p], "select") { + return true + } + } + return false +} diff --git a/database/gdb_model_lock.go b/database/gdb_model_lock.go new file mode 100644 index 0000000..f65ae8c --- /dev/null +++ b/database/gdb_model_lock.go @@ -0,0 +1,129 @@ +// Copyright GoFrame Author(https://goframe.org). All Rights Reserved. +// +// This Source Code Form is subject to the terms of the MIT License. +// If a copy of the MIT was not distributed with this file, +// You can obtain one at https://github.com/gogf/gf. + +package database + +// Lock clause constants for different databases. +// These constants provide type-safe and IDE-friendly access to various lock syntaxes. 
+const (
+	// Common lock clauses (supported by most databases)
+	LockForUpdate           = "FOR UPDATE"
+	LockForUpdateSkipLocked = "FOR UPDATE SKIP LOCKED"
+
+	// MySQL lock clauses
+	LockInShareMode     = "LOCK IN SHARE MODE" // MySQL legacy syntax
+	LockForShare        = "FOR SHARE"          // MySQL 8.0+ and PostgreSQL
+	LockForUpdateNowait = "FOR UPDATE NOWAIT"  // MySQL 8.0+ and Oracle
+
+	// PostgreSQL specific lock clauses
+	LockForNoKeyUpdate           = "FOR NO KEY UPDATE"
+	LockForKeyShare              = "FOR KEY SHARE"
+	LockForShareNowait           = "FOR SHARE NOWAIT"
+	LockForShareSkipLocked       = "FOR SHARE SKIP LOCKED"
+	LockForNoKeyUpdateNowait     = "FOR NO KEY UPDATE NOWAIT"
+	LockForNoKeyUpdateSkipLocked = "FOR NO KEY UPDATE SKIP LOCKED"
+	LockForKeyShareNowait        = "FOR KEY SHARE NOWAIT"
+	LockForKeyShareSkipLocked    = "FOR KEY SHARE SKIP LOCKED"
+
+	// Oracle specific lock clauses
+	LockForUpdateWait5  = "FOR UPDATE WAIT 5"
+	LockForUpdateWait10 = "FOR UPDATE WAIT 10"
+	LockForUpdateWait30 = "FOR UPDATE WAIT 30"
+
+	// SQL Server lock hints (use with WITH clause)
+	LockWithUpdLock         = "WITH (UPDLOCK)"
+	LockWithHoldLock        = "WITH (HOLDLOCK)"
+	LockWithXLock           = "WITH (XLOCK)"
+	LockWithTabLock         = "WITH (TABLOCK)"
+	LockWithNoLock          = "WITH (NOLOCK)" // allows dirty reads; use with care
+	LockWithUpdLockHoldLock = "WITH (UPDLOCK, HOLDLOCK)"
+)
+
+// Lock sets a custom lock clause for the current operation.
+// This is a generic method that allows you to specify any lock syntax supported by your database.
+// You can use predefined constants or custom strings.
+// +// Database-specific lock syntax support: +// +// PostgreSQL (most comprehensive): +// - "FOR UPDATE" - Exclusive lock, blocks all access +// - "FOR NO KEY UPDATE" - Weaker exclusive lock, doesn't block FOR KEY SHARE +// - "FOR SHARE" - Shared lock, allows reads but blocks writes +// - "FOR KEY SHARE" - Weakest lock, only locks key values +// - All above can be combined with: +// - "NOWAIT" - Return immediately if lock cannot be acquired +// - "SKIP LOCKED" - Skip locked rows instead of waiting +// +// MySQL: +// - "FOR UPDATE" - Exclusive lock (all versions) +// - "LOCK IN SHARE MODE" - Shared lock (legacy syntax) +// - "FOR SHARE" - Shared lock (MySQL 8.0+) +// - "FOR UPDATE NOWAIT" - MySQL 8.0+ only +// - "FOR UPDATE SKIP LOCKED" - MySQL 8.0+ only +// +// Oracle: +// - "FOR UPDATE" - Exclusive lock +// - "FOR UPDATE NOWAIT" - Exclusive lock, no wait +// - "FOR UPDATE SKIP LOCKED" - Exclusive lock, skip locked rows +// - "FOR UPDATE WAIT n" - Exclusive lock, wait n seconds +// - "FOR UPDATE OF column_list" - Lock specific columns +// +// SQL Server (uses WITH hints): +// - "WITH (UPDLOCK)" - Update lock +// - "WITH (HOLDLOCK)" - Hold lock until transaction end +// - "WITH (XLOCK)" - Exclusive lock +// - "WITH (TABLOCK)" - Table lock +// - "WITH (NOLOCK)" - No lock (dirty read) +// - "WITH (UPDLOCK, HOLDLOCK)" - Combined update and hold lock +// +// SQLite: +// - Limited locking support, database-level locks only +// - No row-level lock syntax supported +// +// Usage examples: +// +// db.Model("users").Lock("FOR UPDATE NOWAIT").Where("id", 1).One() +// db.Model("users").Lock("FOR SHARE SKIP LOCKED").Where("status", "active").All() +// db.Model("users").Lock("WITH (UPDLOCK)").Where("id", 1).One() // SQL Server +// db.Model("users").Lock("FOR UPDATE OF name, email").Where("id", 1).One() // Oracle +// db.Model("users").Lock("FOR UPDATE WAIT 15").Where("id", 1).One() // Oracle custom wait +// +// Or use predefined constants for better IDE support: +// +// 
db.Model("users").Lock(gdb.LockForUpdateNowait).Where("id", 1).One() +// db.Model("users").Lock(gdb.LockForShareSkipLocked).Where("status", "active").All() +func (m *Model) Lock(lockClause string) *Model { + model := m.getModel() + model.lockInfo = lockClause + return model +} + +// LockUpdate sets the lock for update for current operation. +// This is equivalent to Lock("FOR UPDATE"). +func (m *Model) LockUpdate() *Model { + model := m.getModel() + model.lockInfo = LockForUpdate + return model +} + +// LockUpdateSkipLocked sets the lock for update with skip locked behavior for current operation. +// It skips the locked rows. +// This is equivalent to Lock("FOR UPDATE SKIP LOCKED"). +// Note: Supported by PostgreSQL, Oracle, and MySQL 8.0+. +func (m *Model) LockUpdateSkipLocked() *Model { + model := m.getModel() + model.lockInfo = LockForUpdateSkipLocked + return model +} + +// LockShared sets the lock in share mode for current operation. +// This is equivalent to Lock("LOCK IN SHARE MODE") for MySQL or Lock("FOR SHARE") for PostgreSQL. +// Note: For maximum compatibility, this uses MySQL's legacy syntax. +func (m *Model) LockShared() *Model { + model := m.getModel() + model.lockInfo = LockInShareMode + return model +} diff --git a/database/gdb_model_option.go b/database/gdb_model_option.go new file mode 100644 index 0000000..a0115c0 --- /dev/null +++ b/database/gdb_model_option.go @@ -0,0 +1,73 @@ +// Copyright GoFrame Author(https://goframe.org). All Rights Reserved. +// +// This Source Code Form is subject to the terms of the MIT License. +// If a copy of the MIT was not distributed with this file, +// You can obtain one at https://github.com/gogf/gf. + +package database + +const ( + optionOmitNil = optionOmitNilWhere | optionOmitNilData + optionOmitEmpty = optionOmitEmptyWhere | optionOmitEmptyData + optionOmitNilDataInternal = optionOmitNilData | optionOmitNilDataList // this option is used internally only for ForDao feature. 
+	optionOmitEmptyWhere = 1 << iota // 8
+	optionOmitEmptyData              // 16
+	optionOmitNilWhere               // 32
+	optionOmitNilData                // 64
+	optionOmitNilDataList            // 128
+)
+
+// OmitEmpty sets optionOmitEmpty option for the model, which automatically filters
+// the data and where parameters for `empty` values.
+func (m *Model) OmitEmpty() *Model {
+	model := m.getModel()
+	model.option = model.option | optionOmitEmpty
+	return model
+}
+
+// OmitEmptyWhere sets optionOmitEmptyWhere option for the model, which automatically filters
+// the Where/Having parameters for `empty` values.
+//
+// Eg:
+//
+// Where("id", []int{}).All()             -> SELECT xxx FROM xxx WHERE 0=1
+// Where("name", "").All()                -> SELECT xxx FROM xxx WHERE `name`=''
+// OmitEmpty().Where("id", []int{}).All() -> SELECT xxx FROM xxx
+// OmitEmpty().Where("name", "").All()    -> SELECT xxx FROM xxx.
+func (m *Model) OmitEmptyWhere() *Model {
+	model := m.getModel()
+	model.option = model.option | optionOmitEmptyWhere
+	return model
+}
+
+// OmitEmptyData sets optionOmitEmptyData option for the model, which automatically filters
+// the Data parameters for `empty` values.
+func (m *Model) OmitEmptyData() *Model {
+	model := m.getModel()
+	model.option = model.option | optionOmitEmptyData
+	return model
+}
+
+// OmitNil sets optionOmitNil option for the model, which automatically filters
+// the data and where parameters for `nil` values.
+func (m *Model) OmitNil() *Model {
+	model := m.getModel()
+	model.option = model.option | optionOmitNil
+	return model
+}
+
+// OmitNilWhere sets optionOmitNilWhere option for the model, which automatically filters
+// the Where/Having parameters for `nil` values.
+func (m *Model) OmitNilWhere() *Model {
+	model := m.getModel()
+	model.option = model.option | optionOmitNilWhere
+	return model
+}
+
+// OmitNilData sets optionOmitNilData option for the model, which automatically filters
+// the Data parameters for `nil` values.
+func (m *Model) OmitNilData() *Model { + model := m.getModel() + model.option = model.option | optionOmitNilData + return model +} diff --git a/database/gdb_model_order_group.go b/database/gdb_model_order_group.go new file mode 100644 index 0000000..4080c0c --- /dev/null +++ b/database/gdb_model_order_group.go @@ -0,0 +1,95 @@ +// Copyright GoFrame Author(https://goframe.org). All Rights Reserved. +// +// This Source Code Form is subject to the terms of the MIT License. +// If a copy of the MIT was not distributed with this file, +// You can obtain one at https://github.com/gogf/gf. + +package database + +import ( + "strings" + + "github.com/gogf/gf/v2/text/gstr" + "github.com/gogf/gf/v2/util/gconv" +) + +// Order sets the "ORDER BY" statement for the model. +// +// Example: +// Order("id desc") +// Order("id", "desc") +// Order("id desc,name asc") +// Order("id desc", "name asc") +// Order("id desc").Order("name asc") +// Order(gdb.Raw("field(id, 3,1,2)")). +func (m *Model) Order(orderBy ...any) *Model { + if len(orderBy) == 0 { + return m + } + var ( + core = m.db.GetCore() + model = m.getModel() + ) + for _, v := range orderBy { + if model.orderBy != "" { + model.orderBy += "," + } + switch v.(type) { + case Raw, *Raw: + model.orderBy += gconv.String(v) + default: + orderByStr := gconv.String(v) + if gstr.Contains(orderByStr, " ") { + model.orderBy += core.QuoteString(orderByStr) + } else { + if gstr.Equal(orderByStr, "ASC") || gstr.Equal(orderByStr, "DESC") { + model.orderBy = gstr.TrimRight(model.orderBy, ",") + model.orderBy += " " + orderByStr + } else { + model.orderBy += core.QuoteWord(orderByStr) + } + } + } + } + return model +} + +// OrderAsc sets the "ORDER BY xxx ASC" statement for the model. +func (m *Model) OrderAsc(column string) *Model { + if len(column) == 0 { + return m + } + return m.Order(column + " ASC") +} + +// OrderDesc sets the "ORDER BY xxx DESC" statement for the model. 
+func (m *Model) OrderDesc(column string) *Model { + if len(column) == 0 { + return m + } + return m.Order(column + " DESC") +} + +// OrderRandom sets the "ORDER BY RANDOM()" statement for the model. +func (m *Model) OrderRandom() *Model { + model := m.getModel() + model.orderBy = m.db.OrderRandomFunction() + return model +} + +// Group sets the "GROUP BY" statement for the model. +func (m *Model) Group(groupBy ...string) *Model { + if len(groupBy) == 0 { + return m + } + var ( + core = m.db.GetCore() + model = m.getModel() + ) + + if model.groupBy != "" { + model.groupBy += "," + } + model.groupBy += core.QuoteString(strings.Join(groupBy, ",")) + return model +} diff --git a/database/gdb_model_select.go b/database/gdb_model_select.go new file mode 100644 index 0000000..7eb46a0 --- /dev/null +++ b/database/gdb_model_select.go @@ -0,0 +1,964 @@ +// Copyright GoFrame Author(https://goframe.org). All Rights Reserved. +// +// This Source Code Form is subject to the terms of the MIT License. +// If a copy of the MIT was not distributed with this file, +// You can obtain one at https://github.com/gogf/gf. + +package database + +import ( + "context" + "fmt" + "reflect" + + "git.magicany.cc/black1552/gin-base/database/reflection" + "github.com/gogf/gf/v2/container/gset" + "github.com/gogf/gf/v2/encoding/gjson" + "github.com/gogf/gf/v2/errors/gcode" + "github.com/gogf/gf/v2/errors/gerror" + "github.com/gogf/gf/v2/text/gstr" + "github.com/gogf/gf/v2/util/gconv" +) + +// All does "SELECT FROM ..." statement for the model. +// It retrieves the records from table and returns the result as slice type. +// It returns nil if there's no record retrieved with the given conditions from table. +// +// The optional parameter `where` is the same as the parameter of Model.Where function, +// see Model.Where. +func (m *Model) All(where ...any) (Result, error) { + var ctx = m.GetCtx() + return m.doGetAll(ctx, SelectTypeDefault, false, where...) 
+}
+
+// AllAndCount retrieves all records and the total count of records from the model.
+// If useFieldForCount is true, it will use the fields specified in the model for counting;
+// otherwise, it will use a constant value of 1 for counting.
+// It returns the result as a slice of records, the total count of records, and an error if any.
+// Query conditions should be chained on the model (e.g. via Where) before calling this method.
+//
+// Example:
+//
+//	var model Model
+//	var result Result
+//	var count int
+//	// Conditions are set on the model beforehand, e.g. model = model.Where("name", "John").
+//	result, count, err := model.AllAndCount(true)
+//	if err != nil {
+//	    // Handle error.
+//	}
+//	fmt.Println(result, count)
+func (m *Model) AllAndCount(useFieldForCount bool) (result Result, totalCount int, err error) {
+	// Clone the model for counting
+	countModel := m.Clone()
+
+	// Decide how to build the COUNT() expression:
+	// - If caller explicitly wants to use the single field expression for counting,
+	//   honor it (e.g. Fields("DISTINCT col") with useFieldForCount = true).
+	// - Otherwise, clear fields to let Count() use its default COUNT(1),
+	//   avoiding invalid COUNT(field1, field2, ...) with multiple fields,
+	//   or incorrect COUNT(DISTINCT 1) when Distinct() is set.
+ if useFieldForCount && len(m.fields) == 1 { + countModel.fields = m.fields + } else { + countModel.fields = nil + } + if len(m.pageCacheOption) > 0 { + countModel = countModel.Cache(m.pageCacheOption[0]) + } + + // Get the total count of records + totalCount, err = countModel.Count() + if err != nil { + return + } + + // If the total count is 0, there are no records to retrieve, so return early + if totalCount == 0 { + return + } + + resultModel := m.Clone() + if len(m.pageCacheOption) > 1 { + resultModel = resultModel.Cache(m.pageCacheOption[1]) + } + + // Retrieve all records + result, err = resultModel.doGetAll(m.GetCtx(), SelectTypeDefault, false) + return +} + +// Chunk iterates the query result with given `size` and `handler` function. +func (m *Model) Chunk(size int, handler ChunkHandler) { + page := m.start + if page <= 0 { + page = 1 + } + model := m + for { + model = model.Page(page, size) + data, err := model.All() + if err != nil { + handler(nil, err) + break + } + if len(data) == 0 { + break + } + if !handler(data, err) { + break + } + if len(data) < size { + break + } + page++ + } +} + +// One retrieves one record from table and returns the result as map type. +// It returns nil if there's no record retrieved with the given conditions from table. +// +// The optional parameter `where` is the same as the parameter of Model.Where function, +// see Model.Where. +func (m *Model) One(where ...any) (Record, error) { + var ctx = m.GetCtx() + if len(where) > 0 { + return m.Where(where[0], where[1:]...).One() + } + all, err := m.doGetAll(ctx, SelectTypeDefault, true) + if err != nil { + return nil, err + } + if len(all) > 0 { + return all[0], nil + } + return nil, nil +} + +// Array queries and returns data values as slice from database. +// Note that if there are multiple columns in the result, it returns just one column values randomly. 
+// +// If the optional parameter `fieldsAndWhere` is given, the fieldsAndWhere[0] is the selected fields +// and fieldsAndWhere[1:] is treated as where condition fields. +// Also see Model.Fields and Model.Where functions. +func (m *Model) Array(fieldsAndWhere ...any) (Array, error) { + if len(fieldsAndWhere) > 0 { + if len(fieldsAndWhere) > 2 { + return m.Fields(gconv.String(fieldsAndWhere[0])).Where(fieldsAndWhere[1], fieldsAndWhere[2:]...).Array() + } else if len(fieldsAndWhere) == 2 { + return m.Fields(gconv.String(fieldsAndWhere[0])).Where(fieldsAndWhere[1]).Array() + } else { + return m.Fields(gconv.String(fieldsAndWhere[0])).Array() + } + } + + var ( + field string + core = m.db.GetCore() + ctx = core.injectInternalColumn(m.GetCtx()) + ) + all, err := m.doGetAll(ctx, SelectTypeArray, false) + if err != nil { + return nil, err + } + if len(all) > 0 { + internalData := core.getInternalColumnFromCtx(ctx) + if internalData == nil { + return nil, gerror.NewCode( + gcode.CodeInternalError, + `query count error: the internal context data is missing. there's internal issue should be fixed`, + ) + } + // If FirstResultColumn present, it returns the value of the first record of the first field. + // It means it use no cache mechanism, while cache mechanism makes `internalData` missing. + field = internalData.FirstResultColumn + if field == "" { + // Fields number check. + var recordFields = m.getRecordFields(all[0]) + if len(recordFields) == 1 { + field = recordFields[0] + } else { + // it returns error if there are multiple fields in the result record. + return nil, gerror.NewCodef( + gcode.CodeInvalidParameter, + `invalid fields for "Array" operation, result fields number "%d"%s, but expect one`, + len(recordFields), + gjson.MustEncodeString(recordFields), + ) + } + } + } + return all.Array(field), nil +} + +// Struct retrieves one record from table and converts it into given struct. +// The parameter `pointer` should be type of *struct/**struct. 
If type **struct is given, +// it can create the struct internally during converting. +// +// The optional parameter `where` is the same as the parameter of Model.Where function, +// see Model.Where. +// +// Note that it returns sql.ErrNoRows if the given parameter `pointer` pointed to a variable that has +// default value and there's no record retrieved with the given conditions from table. +// +// Example: +// user := new(User) +// err := db.Model("user").Where("id", 1).Scan(user) +// +// user := (*User)(nil) +// err := db.Model("user").Where("id", 1).Scan(&user). +func (m *Model) doStruct(pointer any, where ...any) error { + model := m + // Auto selecting fields by struct attributes. + if len(model.fieldsEx) == 0 && len(model.fields) == 0 { + if v, ok := pointer.(reflect.Value); ok { + model = m.Fields(v.Interface()) + } else { + model = m.Fields(pointer) + } + } + one, err := model.One(where...) + if err != nil { + return err + } + if err = one.Struct(pointer); err != nil { + return err + } + return model.doWithScanStruct(pointer) +} + +// Structs retrieves records from table and converts them into given struct slice. +// The parameter `pointer` should be type of *[]struct/*[]*struct. It can create and fill the struct +// slice internally during converting. +// +// The optional parameter `where` is the same as the parameter of Model.Where function, +// see Model.Where. +// +// Note that it returns sql.ErrNoRows if the given parameter `pointer` pointed to a variable that has +// default value and there's no record retrieved with the given conditions from table. +// +// Example: +// users := ([]User)(nil) +// err := db.Model("user").Scan(&users) +// +// users := ([]*User)(nil) +// err := db.Model("user").Scan(&users). +func (m *Model) doStructs(pointer any, where ...any) error { + model := m + // Auto selecting fields by struct attributes. 
+ if len(model.fieldsEx) == 0 && len(model.fields) == 0 { + if v, ok := pointer.(reflect.Value); ok { + model = m.Fields( + reflect.New( + v.Type().Elem(), + ).Interface(), + ) + } else { + model = m.Fields( + reflect.New( + reflect.ValueOf(pointer).Elem().Type().Elem(), + ).Interface(), + ) + } + } + all, err := model.All(where...) + if err != nil { + return err + } + if err = all.Structs(pointer); err != nil { + return err + } + return model.doWithScanStructs(pointer) +} + +// Scan automatically calls Struct or Structs function according to the type of parameter `pointer`. +// It calls function doStruct if `pointer` is type of *struct/**struct. +// It calls function doStructs if `pointer` is type of *[]struct/*[]*struct. +// +// The optional parameter `where` is the same as the parameter of Model.Where function, see Model.Where. +// +// Note that it returns sql.ErrNoRows if the given parameter `pointer` pointed to a variable that has +// default value and there's no record retrieved with the given conditions from table. +// +// Example: +// user := new(User) +// err := db.Model("user").Where("id", 1).Scan(user) +// +// user := (*User)(nil) +// err := db.Model("user").Where("id", 1).Scan(&user) +// +// users := ([]User)(nil) +// err := db.Model("user").Scan(&users) +// +// users := ([]*User)(nil) +// err := db.Model("user").Scan(&users). +func (m *Model) Scan(pointer any, where ...any) error { + reflectInfo := reflection.OriginTypeAndKind(pointer) + if reflectInfo.InputKind != reflect.Pointer { + return gerror.NewCode( + gcode.CodeInvalidParameter, + `the parameter "pointer" for function Scan should type of pointer`, + ) + } + switch reflectInfo.OriginKind { + case reflect.Slice, reflect.Array: + return m.doStructs(pointer, where...) + + case reflect.Struct, reflect.Invalid: + return m.doStruct(pointer, where...) 
+ + default: + return gerror.NewCode( + gcode.CodeInvalidParameter, + `element of parameter "pointer" for function Scan should type of struct/*struct/[]struct/[]*struct`, + ) + } +} + +// ScanAndCount scans a single record or record array that matches the given conditions and counts the total number +// of records that match those conditions. +// +// If `useFieldForCount` is true, it will use the fields specified in the model for counting; +// The `pointer` parameter is a pointer to a struct that the scanned data will be stored in. +// The `totalCount` parameter is a pointer to an integer that will be set to the total number of records that match the given conditions. +// The where parameter is an optional list of conditions to use when retrieving records. +// +// Example: +// +// var count int +// user := new(User) +// err := db.Model("user").Where("id", 1).ScanAndCount(user,&count,true) +// fmt.Println(user, count) +// +// Example Join: +// +// type User struct { +// Id int +// Passport string +// Name string +// Age int +// } +// var users []User +// var count int +// db.Model(table).As("u1"). +// LeftJoin(tableName2, "u2", "u2.id=u1.id"). +// Fields("u1.passport,u1.id,u2.name,u2.age"). +// Where("u1.id<2"). +// ScanAndCount(&users, &count, false) +func (m *Model) ScanAndCount(pointer any, totalCount *int, useFieldForCount bool) (err error) { + // support Fields with *, example: .Fields("a.*, b.name"). Count sql is select count(1) from xxx + countModel := m.Clone() + // Decide how to build the COUNT() expression: + // - If caller explicitly wants to use the single field expression for counting, + // honor it (e.g. Fields("DISTINCT col") with useFieldForCount = true). + // - Otherwise, clear fields to let Count() use its default COUNT(1), + // avoiding invalid COUNT(field1, field2, ...) with multiple fields, + // or incorrect COUNT(DISTINCT 1) when Distinct() is set. 
+ if useFieldForCount && len(m.fields) == 1 { + countModel.fields = m.fields + } else { + countModel.fields = nil + } + if len(m.pageCacheOption) > 0 { + countModel = countModel.Cache(m.pageCacheOption[0]) + } + // Get the total count of records + *totalCount, err = countModel.Count() + if err != nil { + return err + } + + // If the total count is 0, there are no records to retrieve, so return early + if *totalCount == 0 { + return + } + scanModel := m.Clone() + if len(m.pageCacheOption) > 1 { + scanModel = scanModel.Cache(m.pageCacheOption[1]) + } + err = scanModel.Scan(pointer) + return +} + +// ScanList converts `r` to struct slice which contains other complex struct attributes. +// Note that the parameter `listPointer` should be type of *[]struct/*[]*struct. +// +// See Result.ScanList. +func (m *Model) ScanList(structSlicePointer any, bindToAttrName string, relationAttrNameAndFields ...string) (err error) { + var result Result + out, err := checkGetSliceElementInfoForScanList(structSlicePointer, bindToAttrName) + if err != nil { + return err + } + if len(m.fields) > 0 || len(m.fieldsEx) != 0 { + // There are custom fields. + result, err = m.All() + } else { + // Filter fields using temporary created struct using reflect.New. 
+ result, err = m.Fields(reflect.New(out.BindToAttrType).Interface()).All() + } + if err != nil { + return err + } + var ( + relationAttrName string + relationFields string + ) + switch len(relationAttrNameAndFields) { + case 2: + relationAttrName = relationAttrNameAndFields[0] + relationFields = relationAttrNameAndFields[1] + case 1: + relationFields = relationAttrNameAndFields[0] + } + return doScanList(doScanListInput{ + Model: m, + Result: result, + StructSlicePointer: structSlicePointer, + StructSliceValue: out.SliceReflectValue, + BindToAttrName: bindToAttrName, + RelationAttrName: relationAttrName, + RelationFields: relationFields, + }) +} + +// Value retrieves a specified record value from table and returns the result as interface type. +// It returns nil if there's no record found with the given conditions from table. +// +// If the optional parameter `fieldsAndWhere` is given, the fieldsAndWhere[0] is the selected fields +// and fieldsAndWhere[1:] is treated as where condition fields. +// Also see Model.Fields and Model.Where functions. +func (m *Model) Value(fieldsAndWhere ...any) (Value, error) { + var ( + core = m.db.GetCore() + ctx = core.injectInternalColumn(m.GetCtx()) + ) + if len(fieldsAndWhere) > 0 { + if len(fieldsAndWhere) > 2 { + return m.Fields(gconv.String(fieldsAndWhere[0])).Where(fieldsAndWhere[1], fieldsAndWhere[2:]...).Value() + } else if len(fieldsAndWhere) == 2 { + return m.Fields(gconv.String(fieldsAndWhere[0])).Where(fieldsAndWhere[1]).Value() + } else { + return m.Fields(gconv.String(fieldsAndWhere[0])).Value() + } + } + var ( + sqlWithHolder, holderArgs = m.getFormattedSqlAndArgs(ctx, SelectTypeValue, true) + all, err = m.doGetAllBySql(ctx, SelectTypeValue, sqlWithHolder, holderArgs...) 
+ ) + if err != nil { + return nil, err + } + if len(all) > 0 { + internalData := core.getInternalColumnFromCtx(ctx) + if internalData == nil { + return nil, gerror.NewCode( + gcode.CodeInternalError, + `query count error: the internal context data is missing. there's internal issue should be fixed`, + ) + } + // If FirstResultColumn present, it returns the value of the first record of the first field. + // It means it use no cache mechanism, while cache mechanism makes `internalData` missing. + if v, ok := all[0][internalData.FirstResultColumn]; ok { + return v, nil + } + // Fields number check. + var recordFields = m.getRecordFields(all[0]) + if len(recordFields) == 1 { + for _, v := range all[0] { + return v, nil + } + } + // it returns error if there are multiple fields in the result record. + return nil, gerror.NewCodef( + gcode.CodeInvalidParameter, + `invalid fields for "Value" operation, result fields number "%d"%s, but expect one`, + len(recordFields), + gjson.MustEncodeString(recordFields), + ) + } + return nil, nil +} + +func (m *Model) getRecordFields(record Record) []string { + if len(record) == 0 { + return nil + } + var fields = make([]string, 0) + for k := range record { + fields = append(fields, k) + } + return fields +} + +// Count does "SELECT COUNT(x) FROM ..." statement for the model. +// The optional parameter `where` is the same as the parameter of Model.Where function, +// see Model.Where. +func (m *Model) Count(where ...any) (int, error) { + var ( + core = m.db.GetCore() + ctx = core.injectInternalColumn(m.GetCtx()) + ) + if len(where) > 0 { + return m.Where(where[0], where[1:]...).Count() + } + var ( + sqlWithHolder, holderArgs = m.getFormattedSqlAndArgs(ctx, SelectTypeCount, false) + all, err = m.doGetAllBySql(ctx, SelectTypeCount, sqlWithHolder, holderArgs...) 
+ ) + if err != nil { + return 0, err + } + if len(all) > 0 { + internalData := core.getInternalColumnFromCtx(ctx) + if internalData == nil { + return 0, gerror.NewCode( + gcode.CodeInternalError, + `query count error: the internal context data is missing. there's internal issue should be fixed`, + ) + } + // If FirstResultColumn present, it returns the value of the first record of the first field. + // It means it use no cache mechanism, while cache mechanism makes `internalData` missing. + if v, ok := all[0][internalData.FirstResultColumn]; ok { + return v.Int(), nil + } + // Fields number check. + var recordFields = m.getRecordFields(all[0]) + if len(recordFields) == 1 { + for _, v := range all[0] { + return v.Int(), nil + } + } + // it returns error if there are multiple fields in the result record. + return 0, gerror.NewCodef( + gcode.CodeInvalidParameter, + `invalid fields for "Count" operation, result fields number "%d"%s, but expect one`, + len(recordFields), + gjson.MustEncodeString(recordFields), + ) + } + return 0, nil +} + +// Exist does "SELECT 1 FROM ... LIMIT 1" statement for the model. +// The optional parameter `where` is the same as the parameter of Model.Where function, +// see Model.Where. +func (m *Model) Exist(where ...any) (bool, error) { + if len(where) > 0 { + return m.Where(where[0], where[1:]...).Exist() + } + one, err := m.Fields(Raw("1")).One() + if err != nil { + return false, err + } + for _, val := range one { + if val.Bool() { + return true, nil + } + } + return false, nil +} + +// CountColumn does "SELECT COUNT(x) FROM ..." statement for the model. +func (m *Model) CountColumn(column string) (int, error) { + if len(column) == 0 { + return 0, nil + } + return m.Fields(column).Count() +} + +// Min does "SELECT MIN(x) FROM ..." statement for the model. 
+func (m *Model) Min(column string) (float64, error) { + if len(column) == 0 { + return 0, nil + } + value, err := m.Fields(fmt.Sprintf(`MIN(%s)`, m.QuoteWord(column))).Value() + if err != nil { + return 0, err + } + return value.Float64(), err +} + +// Max does "SELECT MAX(x) FROM ..." statement for the model. +func (m *Model) Max(column string) (float64, error) { + if len(column) == 0 { + return 0, nil + } + value, err := m.Fields(fmt.Sprintf(`MAX(%s)`, m.QuoteWord(column))).Value() + if err != nil { + return 0, err + } + return value.Float64(), err +} + +// Avg does "SELECT AVG(x) FROM ..." statement for the model. +func (m *Model) Avg(column string) (float64, error) { + if len(column) == 0 { + return 0, nil + } + value, err := m.Fields(fmt.Sprintf(`AVG(%s)`, m.QuoteWord(column))).Value() + if err != nil { + return 0, err + } + return value.Float64(), err +} + +// Sum does "SELECT SUM(x) FROM ..." statement for the model. +func (m *Model) Sum(column string) (float64, error) { + if len(column) == 0 { + return 0, nil + } + value, err := m.Fields(fmt.Sprintf(`SUM(%s)`, m.QuoteWord(column))).Value() + if err != nil { + return 0, err + } + return value.Float64(), err +} + +// Union does "(SELECT xxx FROM xxx) UNION (SELECT xxx FROM xxx) ..." statement for the model. +func (m *Model) Union(unions ...*Model) *Model { + return m.db.Union(unions...) +} + +// UnionAll does "(SELECT xxx FROM xxx) UNION ALL (SELECT xxx FROM xxx) ..." statement for the model. +func (m *Model) UnionAll(unions ...*Model) *Model { + return m.db.UnionAll(unions...) +} + +// Limit sets the "LIMIT" statement for the model. +// The parameter `limit` can be either one or two number, if passed two number is passed, +// it then sets "LIMIT limit[0],limit[1]" statement for the model, or else it sets "LIMIT limit[0]" +// statement. +// Note: Negative values are treated as zero. 
+func (m *Model) Limit(limit ...int) *Model { + model := m.getModel() + switch len(limit) { + case 1: + if limit[0] < 0 { + limit[0] = 0 + } + model.limit = limit[0] + case 2: + if limit[0] < 0 { + limit[0] = 0 + } + if limit[1] < 0 { + limit[1] = 0 + } + model.start = limit[0] + model.limit = limit[1] + } + return model +} + +// Offset sets the "OFFSET" statement for the model. +// It only makes sense for some databases like SQLServer, PostgreSQL, etc. +// Note: Negative values are treated as zero. +func (m *Model) Offset(offset int) *Model { + model := m.getModel() + if offset < 0 { + offset = 0 + } + model.offset = offset + return model +} + +// Distinct forces the query to only return distinct results. +func (m *Model) Distinct() *Model { + model := m.getModel() + model.distinct = "DISTINCT " + return model +} + +// Page sets the paging number for the model. +// The parameter `page` is started from 1 for paging. +// Note that, it differs that the Limit function starts from 0 for "LIMIT" statement. +// Note: Negative limit values are treated as zero. +func (m *Model) Page(page, limit int) *Model { + model := m.getModel() + if page <= 0 { + page = 1 + } + if limit < 0 { + limit = 0 + } + model.start = (page - 1) * limit + model.limit = limit + return model +} + +// Having sets the having statement for the model. +// The parameters of this function usage are as the same as function Where. +// See Where. +func (m *Model) Having(having any, args ...any) *Model { + model := m.getModel() + model.having = []any{ + having, args, + } + return model +} + +// doGetAll does "SELECT FROM ..." statement for the model. +// It retrieves the records from table and returns the result as slice type. +// It returns nil if there's no record retrieved with the given conditions from table. +// +// The parameter `limit1` specifies whether limits querying only one record if m.limit is not set. 
+// The optional parameter `where` is the same as the parameter of Model.Where function, +// see Model.Where. +func (m *Model) doGetAll(ctx context.Context, selectType SelectType, limit1 bool, where ...any) (Result, error) { + if len(where) > 0 { + return m.Where(where[0], where[1:]...).All() + } + sqlWithHolder, holderArgs := m.getFormattedSqlAndArgs(ctx, selectType, limit1) + return m.doGetAllBySql(ctx, selectType, sqlWithHolder, holderArgs...) +} + +// doGetAllBySql does the select statement on the database. +func (m *Model) doGetAllBySql( + ctx context.Context, selectType SelectType, sql string, args ...any, +) (result Result, err error) { + if result, err = m.getSelectResultFromCache(ctx, sql, args...); err != nil || result != nil { + return + } + + in := &HookSelectInput{ + internalParamHookSelect: internalParamHookSelect{ + internalParamHook: internalParamHook{ + link: m.getLink(false), + }, + handler: m.hookHandler.Select, + }, + Model: m, + Table: m.tables, + Schema: m.schema, + Sql: sql, + Args: m.mergeArguments(args), + SelectType: selectType, + } + if result, err = in.Next(ctx); err != nil { + return + } + + err = m.saveSelectResultToCache(ctx, selectType, result, sql, args...) + return +} + +func (m *Model) getFormattedSqlAndArgs( + ctx context.Context, selectType SelectType, limit1 bool, +) (sqlWithHolder string, holderArgs []any) { + switch selectType { + case SelectTypeCount: + queryFields := "COUNT(1)" + if len(m.fields) > 0 { + // DO NOT quote the m.fields here, in case of fields like: + // DISTINCT t.user_id uid + queryFields = fmt.Sprintf(`COUNT(%s%s)`, m.distinct, m.getFieldsAsStr()) + } + // Raw SQL Model. 
+ if m.rawSql != "" { + conditionWhere, conditionExtra, conditionArgs := m.formatCondition(ctx, false, true) + sqlWithHolder = fmt.Sprintf( + "SELECT %s FROM (%s%s) AS T", + queryFields, m.rawSql, conditionWhere+conditionExtra, + ) + return sqlWithHolder, conditionArgs + } + conditionWhere, conditionExtra, conditionArgs := m.formatCondition(ctx, false, true) + sqlWithHolder = fmt.Sprintf("SELECT %s FROM %s%s", queryFields, m.tables, conditionWhere+conditionExtra) + if len(m.groupBy) > 0 { + sqlWithHolder = fmt.Sprintf("SELECT COUNT(1) FROM (%s) count_alias", sqlWithHolder) + } + return sqlWithHolder, conditionArgs + + default: + conditionWhere, conditionExtra, conditionArgs := m.formatCondition(ctx, limit1, false) + // Raw SQL Model, especially for UNION/UNION ALL featured SQL. + if m.rawSql != "" { + sqlWithHolder = fmt.Sprintf( + "%s%s", + m.rawSql, + conditionWhere+conditionExtra, + ) + return sqlWithHolder, conditionArgs + } + // DO NOT quote the m.fields where, in case of fields like: + // DISTINCT t.user_id uid + sqlWithHolder = fmt.Sprintf( + "SELECT %s%s FROM %s%s", + m.distinct, m.getFieldsFiltered(), m.tables, conditionWhere+conditionExtra, + ) + return sqlWithHolder, conditionArgs + } +} + +func (m *Model) getHolderAndArgsAsSubModel(ctx context.Context) (holder string, args []any) { + holder, args = m.getFormattedSqlAndArgs( + ctx, SelectTypeDefault, false, + ) + args = m.mergeArguments(args) + return +} + +func (m *Model) getAutoPrefix() string { + autoPrefix := "" + if gstr.Contains(m.tables, " JOIN ") { + autoPrefix = m.QuoteWord( + m.db.GetCore().guessPrimaryTableName(m.tablesInit), + ) + } + return autoPrefix +} + +func (m *Model) getFieldsAsStr() string { + var ( + fieldsStr string + ) + for _, v := range m.fields { + field := gconv.String(v) + switch { + case gstr.ContainsAny(field, "()"): + case gstr.ContainsAny(field, ". 
"): + default: + switch v.(type) { + case Raw, *Raw: + default: + field = m.QuoteWord(field) + } + } + if fieldsStr != "" { + fieldsStr += "," + } + fieldsStr += field + } + return fieldsStr +} + +// getFieldsFiltered checks the fields and fieldsEx attributes, filters and returns the fields that will +// really be committed to underlying database driver. +func (m *Model) getFieldsFiltered() string { + if len(m.fieldsEx) == 0 && len(m.fields) == 0 { + return defaultField + } + if len(m.fieldsEx) == 0 && len(m.fields) > 0 { + return m.getFieldsAsStr() + } + var ( + fieldsArray []string + fieldsExSet = gset.NewStrSetFrom(gconv.Strings(m.fieldsEx)) + ) + if len(m.fields) > 0 { + // Filter custom fields with fieldEx. + fieldsArray = make([]string, 0, 8) + for _, v := range m.fields { + field := gconv.String(v) + fieldsArray = append(fieldsArray, field[gstr.PosR(field, "-")+1:]) + } + } else { + if gstr.Contains(m.tables, " ") { + panic("function FieldsEx supports only single table operations") + } + // Filter table fields with fieldEx. + tableFields, err := m.TableFields(m.tablesInit) + if err != nil { + panic(err) + } + if len(tableFields) == 0 { + panic(fmt.Sprintf(`empty table fields for table "%s"`, m.tables)) + } + fieldsArray = make([]string, len(tableFields)) + for k, v := range tableFields { + fieldsArray[v.Index] = k + } + } + newFields := "" + for _, k := range fieldsArray { + if fieldsExSet.Contains(k) { + continue + } + if len(newFields) > 0 { + newFields += "," + } + newFields += m.QuoteWord(k) + } + return newFields +} + +// formatCondition formats where arguments of the model and returns a new condition sql and its arguments. +// Note that this function does not change any attribute value of the `m`. +// +// The parameter `limit1` specifies whether limits querying only one record if m.limit is not set. 
+func (m *Model) formatCondition( + ctx context.Context, limit1 bool, isCountStatement bool, +) (conditionWhere string, conditionExtra string, conditionArgs []any) { + var autoPrefix = m.getAutoPrefix() + // GROUP BY. + if m.groupBy != "" { + conditionExtra += " GROUP BY " + m.groupBy + } + // WHERE + conditionWhere, conditionArgs = m.whereBuilder.Build() + softDeletingCondition := m.softTimeMaintainer().GetDeleteCondition(ctx) + if m.rawSql != "" && conditionWhere != "" { + if gstr.ContainsI(m.rawSql, " WHERE ") { + conditionWhere = " AND " + conditionWhere + } else { + conditionWhere = " WHERE " + conditionWhere + } + } else if !m.unscoped && softDeletingCondition != "" { + if conditionWhere == "" { + conditionWhere = fmt.Sprintf(` WHERE %s`, softDeletingCondition) + } else { + conditionWhere = fmt.Sprintf(` WHERE (%s) AND %s`, conditionWhere, softDeletingCondition) + } + } else { + if conditionWhere != "" { + conditionWhere = " WHERE " + conditionWhere + } + } + // HAVING. + if len(m.having) > 0 { + havingHolder := WhereHolder{ + Where: m.having[0], + Args: gconv.Interfaces(m.having[1]), + Prefix: autoPrefix, + } + havingStr, havingArgs := formatWhereHolder(ctx, m.db, formatWhereHolderInput{ + WhereHolder: havingHolder, + OmitNil: m.option&optionOmitNilWhere > 0, + OmitEmpty: m.option&optionOmitEmptyWhere > 0, + Schema: m.schema, + Table: m.tables, + }) + if len(havingStr) > 0 { + conditionExtra += " HAVING " + havingStr + conditionArgs = append(conditionArgs, havingArgs...) + } + } + // ORDER BY. + if !isCountStatement { // The count statement of sqlserver cannot contain the order by statement + if m.orderBy != "" { + conditionExtra += " ORDER BY " + m.orderBy + } + } + // LIMIT. 
+ if !isCountStatement { + if m.limit != 0 { + if m.start >= 0 { + conditionExtra += fmt.Sprintf(" LIMIT %d,%d", m.start, m.limit) + } else { + conditionExtra += fmt.Sprintf(" LIMIT %d", m.limit) + } + } else if limit1 { + conditionExtra += " LIMIT 1" + } + + if m.offset >= 0 { + conditionExtra += fmt.Sprintf(" OFFSET %d", m.offset) + } + } + + if m.lockInfo != "" { + conditionExtra += " " + m.lockInfo + } + return +} diff --git a/database/gdb_model_sharding.go b/database/gdb_model_sharding.go new file mode 100644 index 0000000..eab1340 --- /dev/null +++ b/database/gdb_model_sharding.go @@ -0,0 +1,161 @@ +// Copyright GoFrame Author(https://goframe.org). All Rights Reserved. +// +// This Source Code Form is subject to the terms of the MIT License. +// If a copy of the MIT was not distributed with this file, +// You can obtain one at https://github.com/gogf/gf. + +package database + +import ( + "context" + "fmt" + "hash/fnv" + "reflect" + + "github.com/gogf/gf/v2/errors/gcode" + "github.com/gogf/gf/v2/errors/gerror" + "github.com/gogf/gf/v2/util/gconv" +) + +// ShardingConfig defines the configuration for database/table sharding. +type ShardingConfig struct { + // Table sharding configuration + Table ShardingTableConfig + // Schema sharding configuration + Schema ShardingSchemaConfig +} + +// ShardingSchemaConfig defines the configuration for database sharding. 
+type ShardingSchemaConfig struct { + // Enable schema sharding + Enable bool + // Schema rule prefix, e.g., "db_" + Prefix string + // ShardingRule defines how to route data to different database nodes + Rule ShardingRule +} + +// ShardingTableConfig defines the configuration for table sharding +type ShardingTableConfig struct { + // Enable table sharding + Enable bool + // Table rule prefix, e.g., "user_" + Prefix string + // ShardingRule defines how to route data to different tables + Rule ShardingRule +} + +// ShardingRule defines the interface for sharding rules +type ShardingRule interface { + // SchemaName returns the target schema name based on sharding value. + SchemaName(ctx context.Context, config ShardingSchemaConfig, value any) (string, error) + // TableName returns the target table name based on sharding value. + TableName(ctx context.Context, config ShardingTableConfig, value any) (string, error) +} + +// DefaultShardingRule implements a simple modulo-based sharding rule +type DefaultShardingRule struct { + // Number of schema count. + SchemaCount int + // Number of tables per schema. + TableCount int +} + +// Sharding creates a sharding model with given sharding configuration. +func (m *Model) Sharding(config ShardingConfig) *Model { + model := m.getModel() + model.shardingConfig = config + return model +} + +// ShardingValue sets the sharding value for routing +func (m *Model) ShardingValue(value any) *Model { + model := m.getModel() + model.shardingValue = value + return model +} + +// getActualSchema returns the actual schema based on sharding configuration. +// TODO it does not support schemas in different database config node. 
+func (m *Model) getActualSchema(ctx context.Context, defaultSchema string) (string, error) { + if !m.shardingConfig.Schema.Enable { + return defaultSchema, nil + } + if m.shardingValue == nil { + return defaultSchema, gerror.NewCode( + gcode.CodeInvalidParameter, "sharding value is required when sharding feature enabled", + ) + } + if m.shardingConfig.Schema.Rule == nil { + return defaultSchema, gerror.NewCode( + gcode.CodeInvalidParameter, "sharding rule is required when sharding feature enabled", + ) + } + return m.shardingConfig.Schema.Rule.SchemaName(ctx, m.shardingConfig.Schema, m.shardingValue) +} + +// getActualTable returns the actual table name based on sharding configuration +func (m *Model) getActualTable(ctx context.Context, defaultTable string) (string, error) { + if !m.shardingConfig.Table.Enable { + return defaultTable, nil + } + if m.shardingValue == nil { + return defaultTable, gerror.NewCode( + gcode.CodeInvalidParameter, "sharding value is required when sharding feature enabled", + ) + } + if m.shardingConfig.Table.Rule == nil { + return defaultTable, gerror.NewCode( + gcode.CodeInvalidParameter, "sharding rule is required when sharding feature enabled", + ) + } + return m.shardingConfig.Table.Rule.TableName(ctx, m.shardingConfig.Table, m.shardingValue) +} + +// SchemaName implements the default database sharding strategy +func (r *DefaultShardingRule) SchemaName(ctx context.Context, config ShardingSchemaConfig, value any) (string, error) { + if r.SchemaCount == 0 { + return "", gerror.NewCode( + gcode.CodeInvalidParameter, "schema count should not be 0 using DefaultShardingRule when schema sharding enabled", + ) + } + hashValue, err := getHashValue(value) + if err != nil { + return "", err + } + nodeIndex := hashValue % uint64(r.SchemaCount) + return fmt.Sprintf("%s%d", config.Prefix, nodeIndex), nil +} + +// TableName implements the default table sharding strategy +func (r *DefaultShardingRule) TableName(ctx context.Context, config 
ShardingTableConfig, value any) (string, error) { + if r.TableCount == 0 { + return "", gerror.NewCode( + gcode.CodeInvalidParameter, "table count should not be 0 using DefaultShardingRule when table sharding enabled", + ) + } + hashValue, err := getHashValue(value) + if err != nil { + return "", err + } + tableIndex := hashValue % uint64(r.TableCount) + return fmt.Sprintf("%s%d", config.Prefix, tableIndex), nil +} + +// getHashValue converts sharding value to uint64 hash +func getHashValue(value any) (uint64, error) { + var rv = reflect.ValueOf(value) + switch rv.Kind() { + case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64, + reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64, + reflect.Float32, reflect.Float64: + return gconv.Uint64(value), nil + default: + h := fnv.New64a() + _, err := h.Write(gconv.Bytes(value)) + if err != nil { + return 0, gerror.WrapCode(gcode.CodeInternalError, err) + } + return h.Sum64(), nil + } +} diff --git a/database/gdb_model_soft_time.go b/database/gdb_model_soft_time.go new file mode 100644 index 0000000..658f5ca --- /dev/null +++ b/database/gdb_model_soft_time.go @@ -0,0 +1,384 @@ +// Copyright GoFrame Author(https://goframe.org). All Rights Reserved. +// +// This Source Code Form is subject to the terms of the MIT License. +// If a copy of the MIT was not distributed with this file, +// You can obtain one at https://github.com/gogf/gf. + +package database + +import ( + "context" + "fmt" + "strings" + + "git.magicany.cc/black1552/gin-base/database/intlog" + "git.magicany.cc/black1552/gin-base/database/utils" + "github.com/gogf/gf/v2/container/garray" + "github.com/gogf/gf/v2/errors/gcode" + "github.com/gogf/gf/v2/errors/gerror" + "github.com/gogf/gf/v2/os/gcache" + "github.com/gogf/gf/v2/os/gtime" + "github.com/gogf/gf/v2/text/gregex" + "github.com/gogf/gf/v2/text/gstr" +) + +// SoftTimeType custom defines the soft time field type. 
+type SoftTimeType int + +const ( + SoftTimeTypeAuto SoftTimeType = 0 // (Default)Auto detect the field type by table field type. + SoftTimeTypeTime SoftTimeType = 1 // Using datetime as the field value. + SoftTimeTypeTimestamp SoftTimeType = 2 // In unix seconds. + SoftTimeTypeTimestampMilli SoftTimeType = 3 // In unix milliseconds. + SoftTimeTypeTimestampMicro SoftTimeType = 4 // In unix microseconds. + SoftTimeTypeTimestampNano SoftTimeType = 5 // In unix nanoseconds. +) + +// SoftTimeOption is the option to customize soft time feature for Model. +type SoftTimeOption struct { + SoftTimeType SoftTimeType // The value type for soft time field. +} + +type softTimeMaintainer struct { + *Model +} + +// SoftTimeFieldType represents different soft time field purposes. +type SoftTimeFieldType int + +const ( + SoftTimeFieldCreate SoftTimeFieldType = iota + SoftTimeFieldUpdate + SoftTimeFieldDelete +) + +type iSoftTimeMaintainer interface { + // GetFieldInfo returns field name and type for specified field purpose. + GetFieldInfo(ctx context.Context, schema, table string, fieldPurpose SoftTimeFieldType) (fieldName string, localType LocalType) + + // GetFieldValue generates value for create/update/delete operations. + GetFieldValue(ctx context.Context, localType LocalType, isDeleted bool) any + + // GetDeleteCondition returns WHERE condition for soft delete query. + GetDeleteCondition(ctx context.Context) string + + // GetDeleteData returns UPDATE statement data for soft delete. + GetDeleteData(ctx context.Context, prefix, fieldName string, localType LocalType) (holder string, value any) +} + +// getSoftFieldNameAndTypeCacheItem is the internal struct for storing create/update/delete fields. +type getSoftFieldNameAndTypeCacheItem struct { + FieldName string + FieldType LocalType +} + +var ( + // Default field names of table for automatic-filled for record creating. 
+ createdFieldNames = []string{"created_at", "create_at"} + // Default field names of table for automatic-filled for record updating. + updatedFieldNames = []string{"updated_at", "update_at"} + // Default field names of table for automatic-filled for record deleting. + deletedFieldNames = []string{"deleted_at", "delete_at"} +) + +// SoftTime sets the SoftTimeOption to customize soft time feature for Model. +func (m *Model) SoftTime(option SoftTimeOption) *Model { + model := m.getModel() + model.softTimeOption = option + return model +} + +// Unscoped disables the soft time feature for insert, update and delete operations. +func (m *Model) Unscoped() *Model { + model := m.getModel() + model.unscoped = true + return model +} + +func (m *Model) softTimeMaintainer() iSoftTimeMaintainer { + return &softTimeMaintainer{ + m, + } +} + +// GetFieldInfo returns field name and type for specified field purpose. +// It checks the key with or without cases or chars '-'/'_'/'.'/' '. +func (m *softTimeMaintainer) GetFieldInfo( + ctx context.Context, schema, table string, fieldPurpose SoftTimeFieldType, +) (fieldName string, localType LocalType) { + // Check if feature is disabled + if m.db.GetConfig().TimeMaintainDisabled { + return "", LocalTypeUndefined + } + + // Determine table name + tableName := table + if tableName == "" { + tableName = m.tablesInit + } + + // Get config and field candidates + config := m.db.GetConfig() + var ( + configField string + defaultFields []string + ) + + switch fieldPurpose { + case SoftTimeFieldCreate: + configField = config.CreatedAt + defaultFields = createdFieldNames + case SoftTimeFieldUpdate: + configField = config.UpdatedAt + defaultFields = updatedFieldNames + case SoftTimeFieldDelete: + configField = config.DeletedAt + defaultFields = deletedFieldNames + } + + // Use config field if specified, otherwise use defaults + if configField != "" { + return m.getSoftFieldNameAndType(ctx, schema, tableName, []string{configField}) + } + return 
m.getSoftFieldNameAndType(ctx, schema, tableName, defaultFields) +} + +// getSoftFieldNameAndType retrieves and returns the field name of the table for possible key. +func (m *softTimeMaintainer) getSoftFieldNameAndType( + ctx context.Context, schema, table string, candidateFields []string, +) (fieldName string, fieldType LocalType) { + // Build cache key + cacheKey := genSoftTimeFieldNameTypeCacheKey(schema, table, candidateFields) + + // Try to get from cache + cache := m.db.GetCore().GetInnerMemCache() + result, err := cache.GetOrSetFunc(ctx, cacheKey, func(ctx context.Context) (any, error) { + // Get table fields + fieldsMap, err := m.TableFields(table, schema) + if err != nil || len(fieldsMap) == 0 { + return nil, err + } + + // Search for matching field + for _, field := range candidateFields { + if name := searchFieldNameFromMap(fieldsMap, field); name != "" { + fType, _ := m.db.CheckLocalTypeForField(ctx, fieldsMap[name].Type, nil) + return getSoftFieldNameAndTypeCacheItem{ + FieldName: name, + FieldType: fType, + }, nil + } + } + return nil, nil + }, gcache.DurationNoExpire) + + if err != nil || result == nil { + return "", LocalTypeUndefined + } + + item := result.Val().(getSoftFieldNameAndTypeCacheItem) + return item.FieldName, item.FieldType +} + +func searchFieldNameFromMap(fieldsMap map[string]*TableField, key string) string { + if len(fieldsMap) == 0 { + return "" + } + _, ok := fieldsMap[key] + if ok { + return key + } + key = utils.RemoveSymbols(key) + for k := range fieldsMap { + if strings.EqualFold(utils.RemoveSymbols(k), key) { + return k + } + } + return "" +} + +// GetDeleteCondition returns WHERE condition for soft delete query. +// It supports multiple tables string like: +// "user u, user_detail ud" +// "user u LEFT JOIN user_detail ud ON(ud.uid=u.uid)" +// "user LEFT JOIN user_detail ON(user_detail.uid=user.uid)" +// "user u LEFT JOIN user_detail ud ON(ud.uid=u.uid) LEFT JOIN user_stats us ON(us.uid=u.uid)". 
+func (m *softTimeMaintainer) GetDeleteCondition(ctx context.Context) string { + if m.unscoped { + return "" + } + conditionArray := garray.NewStrArray() + if gstr.Contains(m.tables, " JOIN ") { + // Base table. + tableMatch, _ := gregex.MatchString(`(.+?) [A-Z]+ JOIN`, m.tables) + conditionArray.Append(m.getConditionOfTableStringForSoftDeleting(ctx, tableMatch[1])) + // Multiple joined tables, exclude the sub query sql which contains char '(' and ')'. + tableMatches, _ := gregex.MatchAllString(`JOIN ([^()]+?) ON`, m.tables) + for _, match := range tableMatches { + conditionArray.Append(m.getConditionOfTableStringForSoftDeleting(ctx, match[1])) + } + } + if conditionArray.Len() == 0 && gstr.Contains(m.tables, ",") { + // Multiple base tables. + for _, s := range gstr.SplitAndTrim(m.tables, ",") { + conditionArray.Append(m.getConditionOfTableStringForSoftDeleting(ctx, s)) + } + } + conditionArray.FilterEmpty() + if conditionArray.Len() > 0 { + return conditionArray.Join(" AND ") + } + // Only one table. + fieldName, fieldType := m.GetFieldInfo(ctx, "", m.tablesInit, SoftTimeFieldDelete) + if fieldName != "" { + return m.buildDeleteCondition(ctx, "", fieldName, fieldType) + } + return "" +} + +// getConditionOfTableStringForSoftDeleting does something as its name describes. 
+// Examples for `s`: +// - `test`.`demo` as b +// - `test`.`demo` b +// - `demo` +// - demo +func (m *softTimeMaintainer) getConditionOfTableStringForSoftDeleting(ctx context.Context, s string) string { + var ( + table string + schema string + array1 = gstr.SplitAndTrim(s, " ") + array2 = gstr.SplitAndTrim(array1[0], ".") + ) + if len(array2) >= 2 { + table = array2[1] + schema = array2[0] + } else { + table = array2[0] + } + fieldName, fieldType := m.GetFieldInfo(ctx, schema, table, SoftTimeFieldDelete) + if fieldName == "" { + return "" + } + if len(array1) >= 3 { + return m.buildDeleteCondition(ctx, array1[2], fieldName, fieldType) + } + if len(array1) >= 2 { + return m.buildDeleteCondition(ctx, array1[1], fieldName, fieldType) + } + return m.buildDeleteCondition(ctx, table, fieldName, fieldType) +} + +// GetDeleteData returns UPDATE statement data for soft delete. +func (m *softTimeMaintainer) GetDeleteData( + ctx context.Context, prefix, fieldName string, fieldType LocalType, +) (holder string, value any) { + core := m.db.GetCore() + quotedName := core.QuoteWord(fieldName) + + if prefix != "" { + quotedName = fmt.Sprintf(`%s.%s`, core.QuoteWord(prefix), quotedName) + } + + holder = fmt.Sprintf(`%s=?`, quotedName) + value = m.GetFieldValue(ctx, fieldType, false) + return +} + +// buildDeleteCondition builds WHERE condition for soft delete filtering. 
+func (m *softTimeMaintainer) buildDeleteCondition( + ctx context.Context, prefix, fieldName string, fieldType LocalType, +) string { + core := m.db.GetCore() + quotedName := core.QuoteWord(fieldName) + + if prefix != "" { + quotedName = fmt.Sprintf(`%s.%s`, core.QuoteWord(prefix), quotedName) + } + switch m.softTimeOption.SoftTimeType { + case SoftTimeTypeAuto: + switch fieldType { + case LocalTypeDate, LocalTypeTime, LocalTypeDatetime: + return fmt.Sprintf(`%s IS NULL`, quotedName) + case LocalTypeInt, LocalTypeUint, LocalTypeInt64, LocalTypeUint64, LocalTypeBool: + return fmt.Sprintf(`%s=0`, quotedName) + default: + intlog.Errorf(ctx, `invalid field type "%s" for soft delete condition: prefix=%s, field=%s`, fieldType, prefix, fieldName) + return "" + } + + case SoftTimeTypeTime: + return fmt.Sprintf(`%s IS NULL`, quotedName) + + default: + return fmt.Sprintf(`%s=0`, quotedName) + } +} + +// GetFieldValue generates value for create/update/delete operations. +func (m *softTimeMaintainer) GetFieldValue( + ctx context.Context, fieldType LocalType, isDeleted bool, +) any { + // For deleted field, return "empty" value + if isDeleted { + return m.getEmptyValue(fieldType) + } + + // For create/update/delete, return current time value + switch m.softTimeOption.SoftTimeType { + case SoftTimeTypeAuto: + return m.getAutoValue(ctx, fieldType) + default: + switch fieldType { + case LocalTypeBool: + return 1 + default: + return m.getTimestampValue() + } + } +} + +// getTimestampValue returns timestamp value for soft time. 
+func (m *softTimeMaintainer) getTimestampValue() any { + switch m.softTimeOption.SoftTimeType { + case SoftTimeTypeTime: + return gtime.Now() + case SoftTimeTypeTimestamp: + return gtime.Timestamp() + case SoftTimeTypeTimestampMilli: + return gtime.TimestampMilli() + case SoftTimeTypeTimestampMicro: + return gtime.TimestampMicro() + case SoftTimeTypeTimestampNano: + return gtime.TimestampNano() + default: + panic(gerror.NewCodef( + gcode.CodeInternalPanic, + `unrecognized SoftTimeType "%d"`, m.softTimeOption.SoftTimeType, + )) + } +} + +// getEmptyValue returns "empty" value for deleted field. +func (m *softTimeMaintainer) getEmptyValue(fieldType LocalType) any { + switch fieldType { + case LocalTypeDate, LocalTypeTime, LocalTypeDatetime: + return nil + default: + return 0 + } +} + +// getAutoValue returns auto-detected value based on field type. +func (m *softTimeMaintainer) getAutoValue(ctx context.Context, fieldType LocalType) any { + switch fieldType { + case LocalTypeDate, LocalTypeTime, LocalTypeDatetime: + return gtime.Now() + case LocalTypeInt, LocalTypeUint, LocalTypeInt64, LocalTypeUint64: + return gtime.Timestamp() + case LocalTypeBool: + return 1 + default: + intlog.Errorf(ctx, `invalid field type "%s" for soft time auto value`, fieldType) + return nil + } +} diff --git a/database/gdb_model_transaction.go b/database/gdb_model_transaction.go new file mode 100644 index 0000000..065bd2a --- /dev/null +++ b/database/gdb_model_transaction.go @@ -0,0 +1,42 @@ +// Copyright GoFrame Author(https://goframe.org). All Rights Reserved. +// +// This Source Code Form is subject to the terms of the MIT License. +// If a copy of the MIT was not distributed with this file, +// You can obtain one at https://github.com/gogf/gf. + +package database + +import ( + "context" +) + +// Transaction wraps the transaction logic using function `f`. +// It rollbacks the transaction and returns the error from function `f` if +// it returns non-nil error. 
It commits the transaction and returns nil if +// function `f` returns nil. +// +// Note that, you should not Commit or Rollback the transaction in function `f` +// as it is automatically handled by this function. +func (m *Model) Transaction(ctx context.Context, f func(ctx context.Context, tx TX) error) (err error) { + if ctx == nil { + ctx = m.GetCtx() + } + if m.tx != nil { + return m.tx.Transaction(ctx, f) + } + return m.db.Transaction(ctx, f) +} + +// TransactionWithOptions executes transaction with options. +// The parameter `opts` specifies the transaction options. +// The parameter `f` specifies the function that will be called within the transaction. +// If f returns error, the transaction will be rolled back, or else the transaction will be committed. +func (m *Model) TransactionWithOptions(ctx context.Context, opts TxOptions, f func(ctx context.Context, tx TX) error) (err error) { + if ctx == nil { + ctx = m.GetCtx() + } + if m.tx != nil { + return m.tx.Transaction(ctx, f) + } + return m.db.TransactionWithOptions(ctx, opts, f) +} diff --git a/database/gdb_model_update.go b/database/gdb_model_update.go new file mode 100644 index 0000000..201ccb5 --- /dev/null +++ b/database/gdb_model_update.go @@ -0,0 +1,139 @@ +// Copyright GoFrame Author(https://goframe.org). All Rights Reserved. +// +// This Source Code Form is subject to the terms of the MIT License. +// If a copy of the MIT was not distributed with this file, +// You can obtain one at https://github.com/gogf/gf. + +package database + +import ( + "database/sql" + "fmt" + "reflect" + + "git.magicany.cc/black1552/gin-base/database/empty" + "git.magicany.cc/black1552/gin-base/database/intlog" + "git.magicany.cc/black1552/gin-base/database/reflection" + "github.com/gogf/gf/v2/errors/gcode" + "github.com/gogf/gf/v2/errors/gerror" + "github.com/gogf/gf/v2/text/gstr" + "github.com/gogf/gf/v2/util/gconv" +) + +// Update does "UPDATE ... " statement for the model. 
+// +// If the optional parameter `dataAndWhere` is given, the dataAndWhere[0] is the updated data field, +// and dataAndWhere[1:] is treated as where condition fields. +// Also see Model.Data and Model.Where functions. +func (m *Model) Update(dataAndWhere ...any) (result sql.Result, err error) { + var ctx = m.GetCtx() + if len(dataAndWhere) > 0 { + if len(dataAndWhere) > 2 { + return m.Data(dataAndWhere[0]).Where(dataAndWhere[1], dataAndWhere[2:]...).Update() + } else if len(dataAndWhere) == 2 { + return m.Data(dataAndWhere[0]).Where(dataAndWhere[1]).Update() + } else { + return m.Data(dataAndWhere[0]).Update() + } + } + defer func() { + if err == nil { + m.checkAndRemoveSelectCache(ctx) + } + }() + if m.data == nil { + return nil, gerror.NewCode(gcode.CodeMissingParameter, "updating table with empty data") + } + var ( + newData any + stm = m.softTimeMaintainer() + reflectInfo = reflection.OriginTypeAndKind(m.data) + conditionWhere, conditionExtra, conditionArgs = m.formatCondition(ctx, false, false) + conditionStr = conditionWhere + conditionExtra + fieldNameUpdate, fieldTypeUpdate = stm.GetFieldInfo(ctx, "", m.tablesInit, SoftTimeFieldUpdate) + ) + if fieldNameUpdate != "" && (m.unscoped || m.isFieldInFieldsEx(fieldNameUpdate)) { + fieldNameUpdate = "" + } + + newData, err = m.filterDataForInsertOrUpdate(m.data) + if err != nil { + return nil, err + } + + switch reflectInfo.OriginKind { + case reflect.Map, reflect.Struct: + var dataMap = anyValueToMapBeforeToRecord(newData) + // Automatically update the record updating time. + if fieldNameUpdate != "" && empty.IsNil(dataMap[fieldNameUpdate]) { + dataValue := stm.GetFieldValue(ctx, fieldTypeUpdate, false) + dataMap[fieldNameUpdate] = dataValue + } + newData = dataMap + + default: + var updateStr = gconv.String(newData) + // Automatically update the record updating time. 
+ if fieldNameUpdate != "" && !gstr.Contains(updateStr, fieldNameUpdate) { + dataValue := stm.GetFieldValue(ctx, fieldTypeUpdate, false) + updateStr += fmt.Sprintf(`,%s=?`, fieldNameUpdate) + conditionArgs = append([]any{dataValue}, conditionArgs...) + } + newData = updateStr + } + + if !gstr.ContainsI(conditionStr, " WHERE ") { + intlog.Printf( + ctx, + `sql condition string "%s" has no WHERE for UPDATE operation, fieldNameUpdate: %s`, + conditionStr, fieldNameUpdate, + ) + return nil, gerror.NewCode( + gcode.CodeMissingParameter, + "there should be WHERE condition statement for UPDATE operation", + ) + } + + in := &HookUpdateInput{ + internalParamHookUpdate: internalParamHookUpdate{ + internalParamHook: internalParamHook{ + link: m.getLink(true), + }, + handler: m.hookHandler.Update, + }, + Model: m, + Table: m.tables, + Schema: m.schema, + Data: newData, + Condition: conditionStr, + Args: m.mergeArguments(conditionArgs), + } + return in.Next(ctx) +} + +// UpdateAndGetAffected performs update statement and returns the affected rows number. +func (m *Model) UpdateAndGetAffected(dataAndWhere ...any) (affected int64, err error) { + result, err := m.Update(dataAndWhere...) + if err != nil { + return 0, err + } + return result.RowsAffected() +} + +// Increment increments a column's value by a given amount. +// The parameter `amount` can be type of float or integer. +func (m *Model) Increment(column string, amount any) (sql.Result, error) { + return m.getModel().Data(column, &Counter{ + Field: column, + Value: gconv.Float64(amount), + }).Update() +} + +// Decrement decrements a column's value by a given amount. +// The parameter `amount` can be type of float or integer. 
+func (m *Model) Decrement(column string, amount any) (sql.Result, error) { + return m.getModel().Data(column, &Counter{ + Field: column, + Value: -gconv.Float64(amount), + }).Update() +} diff --git a/database/gdb_model_utility.go b/database/gdb_model_utility.go new file mode 100644 index 0000000..43c9593 --- /dev/null +++ b/database/gdb_model_utility.go @@ -0,0 +1,314 @@ +// Copyright GoFrame Author(https://goframe.org). All Rights Reserved. +// +// This Source Code Form is subject to the terms of the MIT License. +// If a copy of the MIT was not distributed with this file, +// You can obtain one at https://github.com/gogf/gf. + +package database + +import ( + "time" + + "git.magicany.cc/black1552/gin-base/database/empty" + "github.com/gogf/gf/v2/container/gset" + "github.com/gogf/gf/v2/os/gtime" + "github.com/gogf/gf/v2/text/gregex" + "github.com/gogf/gf/v2/text/gstr" + "github.com/gogf/gf/v2/util/gconv" + "github.com/gogf/gf/v2/util/gutil" +) + +// QuoteWord checks given string `s` a word, +// if true it quotes `s` with security chars of the database +// and returns the quoted string; or else it returns `s` without any change. +// +// The meaning of a `word` can be considered as a column name. +func (m *Model) QuoteWord(s string) string { + return m.db.GetCore().QuoteWord(s) +} + +// TableFields retrieves and returns the fields' information of specified table of current +// schema. +// +// Also see DriverMysql.TableFields. +func (m *Model) TableFields(tableStr string, schema ...string) (fields map[string]*TableField, err error) { + var ( + ctx = m.GetCtx() + usedTable = m.db.GetCore().guessPrimaryTableName(tableStr) + usedSchema = gutil.GetOrDefaultStr(m.schema, schema...) + ) + // Sharding feature. 
+ usedSchema, err = m.getActualSchema(ctx, usedSchema) + if err != nil { + return nil, err + } + usedTable, err = m.getActualTable(ctx, usedTable) + if err != nil { + return nil, err + } + return m.db.TableFields(ctx, usedTable, usedSchema) +} + +// getModel creates and returns a cloned model of current model if `safe` is true, or else it returns +// the current model. +func (m *Model) getModel() *Model { + if !m.safe { + return m + } else { + return m.Clone() + } +} + +// mappingAndFilterToTableFields mappings and changes given field name to really table field name. +// Eg: +// ID -> id +// NICK_Name -> nickname. +func (m *Model) mappingAndFilterToTableFields(table string, fields []any, filter bool) []any { + var fieldsTable = table + if fieldsTable != "" { + hasTable, _ := m.db.GetCore().HasTable(fieldsTable) + if !hasTable { + if fieldsTable != m.tablesInit { + // Table/alias unknown (e.g., FieldsPrefix called before LeftJoin), skip filtering. + return fields + } + // HasTable cache miss for main table, fallback to use main table for field mapping. + fieldsTable = m.tablesInit + } + } + if fieldsTable == "" { + fieldsTable = m.tablesInit + } + + fieldsMap, _ := m.TableFields(fieldsTable) + if len(fieldsMap) == 0 { + return fields + } + var outputFieldsArray = make([]any, 0) + fieldsKeyMap := make(map[string]any, len(fieldsMap)) + for k := range fieldsMap { + fieldsKeyMap[k] = nil + } + for _, field := range fields { + var ( + fieldStr = gconv.String(field) + inputFieldsArray []string + ) + // Skip empty string fields. 
+ if fieldStr == "" { + continue + } + switch { + case gregex.IsMatchString(regularFieldNameWithoutDotRegPattern, fieldStr): + inputFieldsArray = append(inputFieldsArray, fieldStr) + + case gregex.IsMatchString(regularFieldNameWithCommaRegPattern, fieldStr): + inputFieldsArray = gstr.SplitAndTrim(fieldStr, ",") + + default: + // Example: + // user.id, user.name + // replace(concat_ws(',',lpad(s.id, 6, '0'),s.name),',','') `code` + outputFieldsArray = append(outputFieldsArray, field) + continue + } + for _, inputField := range inputFieldsArray { + if !gregex.IsMatchString(regularFieldNameWithoutDotRegPattern, inputField) { + outputFieldsArray = append(outputFieldsArray, inputField) + continue + } + if _, ok := fieldsKeyMap[inputField]; !ok { + // Example: + // id, name + if foundKey, _ := gutil.MapPossibleItemByKey(fieldsKeyMap, inputField); foundKey != "" { + outputFieldsArray = append(outputFieldsArray, foundKey) + } else if !filter { + outputFieldsArray = append(outputFieldsArray, inputField) + } + } else { + outputFieldsArray = append(outputFieldsArray, inputField) + } + } + } + return outputFieldsArray +} + +// filterDataForInsertOrUpdate does filter feature with data for inserting/updating operations. +// Note that, it does not filter list item, which is also type of map, for "omit empty" feature. +func (m *Model) filterDataForInsertOrUpdate(data any) (any, error) { + var err error + switch value := data.(type) { + case List: + var omitEmpty bool + if m.option&optionOmitNilDataList > 0 { + omitEmpty = true + } + for k, item := range value { + value[k], err = m.doMappingAndFilterForInsertOrUpdateDataMap(item, omitEmpty) + if err != nil { + return nil, err + } + } + return value, nil + + case Map: + return m.doMappingAndFilterForInsertOrUpdateDataMap(value, true) + + default: + return data, nil + } +} + +// doMappingAndFilterForInsertOrUpdateDataMap does the filter features for map. 
+// Note that, it does not filter list item, which is also type of map, for "omit empty" feature. +func (m *Model) doMappingAndFilterForInsertOrUpdateDataMap(data Map, allowOmitEmpty bool) (Map, error) { + var ( + err error + ctx = m.GetCtx() + core = m.db.GetCore() + schema = m.schema + table = m.tablesInit + ) + // Sharding feature. + schema, err = m.getActualSchema(ctx, schema) + if err != nil { + return nil, err + } + table, err = m.getActualTable(ctx, table) + if err != nil { + return nil, err + } + data, err = core.mappingAndFilterData( + ctx, schema, table, data, m.filter, + ) + if err != nil { + return nil, err + } + // Remove key-value pairs of which the value is nil. + if allowOmitEmpty && m.option&optionOmitNilData > 0 { + tempMap := make(Map, len(data)) + for k, v := range data { + if empty.IsNil(v) { + continue + } + tempMap[k] = v + } + data = tempMap + } + + // Remove key-value pairs of which the value is empty. + if allowOmitEmpty && m.option&optionOmitEmptyData > 0 { + tempMap := make(Map, len(data)) + for k, v := range data { + if empty.IsEmpty(v) { + continue + } + // Special type filtering. + switch r := v.(type) { + case time.Time: + if r.IsZero() { + continue + } + case *time.Time: + if r.IsZero() { + continue + } + case gtime.Time: + if r.IsZero() { + continue + } + case *gtime.Time: + if r.IsZero() { + continue + } + } + tempMap[k] = v + } + data = tempMap + } + + if len(m.fields) > 0 { + // Keep specified fields. + var ( + fieldSet = gset.NewStrSetFrom(gconv.Strings(m.fields)) + charL, charR = m.db.GetChars() + chars = charL + charR + ) + fieldSet.Walk(func(item string) string { + return gstr.Trim(item, chars) + }) + for k := range data { + k = gstr.Trim(k, chars) + if !fieldSet.Contains(k) { + delete(data, k) + } + } + } else if len(m.fieldsEx) > 0 { + // Filter specified fields. 
+ for _, v := range m.fieldsEx { + delete(data, gconv.String(v)) + } + } + return data, nil +} + +// getLink returns the underlying database link object with configured `linkType` attribute. +// The parameter `master` specifies whether using the master node if master-slave configured. +func (m *Model) getLink(master bool) Link { + if m.tx != nil { + if sqlTx := m.tx.GetSqlTX(); sqlTx != nil { + return &txLink{sqlTx} + } + } + linkType := m.linkType + if linkType == 0 { + if master { + linkType = linkTypeMaster + } else { + linkType = linkTypeSlave + } + } + switch linkType { + case linkTypeMaster: + link, err := m.db.GetCore().MasterLink(m.schema) + if err != nil { + panic(err) + } + return link + case linkTypeSlave: + link, err := m.db.GetCore().SlaveLink(m.schema) + if err != nil { + panic(err) + } + return link + } + return nil +} + +// getPrimaryKey retrieves and returns the primary key name of the model table. +// It parses m.tables to retrieve the primary table name, supporting m.tables like: +// "user", "user u", "user as u, user_detail as ud". +func (m *Model) getPrimaryKey() string { + table := gstr.SplitAndTrim(m.tablesInit, " ")[0] + tableFields, err := m.TableFields(table) + if err != nil { + return "" + } + for name, field := range tableFields { + if gstr.ContainsI(field.Key, "pri") { + return name + } + } + return "" +} + +// mergeArguments creates and returns new arguments by merging `m.extraArgs` and given `args`. +func (m *Model) mergeArguments(args []any) []any { + if len(m.extraArgs) > 0 { + newArgs := make([]any, len(m.extraArgs)+len(args)) + copy(newArgs, m.extraArgs) + copy(newArgs[len(m.extraArgs):], args) + return newArgs + } + return args +} diff --git a/database/gdb_model_where.go b/database/gdb_model_where.go new file mode 100644 index 0000000..ed8b48f --- /dev/null +++ b/database/gdb_model_where.go @@ -0,0 +1,129 @@ +// Copyright GoFrame Author(https://goframe.org). All Rights Reserved. 
+//
+// This Source Code Form is subject to the terms of the MIT License.
+// If a copy of the MIT was not distributed with this file,
+// You can obtain one at https://github.com/gogf/gf.
+
+package database
+
+// callWhereBuilder creates and returns a new Model, and sets its WhereBuilder if current Model is safe.
+// It sets the WhereBuilder and returns current Model directly if it is not safe.
+func (m *Model) callWhereBuilder(builder *WhereBuilder) *Model {
+	model := m.getModel()
+	model.whereBuilder = builder
+	return model
+}
+
+// Where sets the condition statement for the builder. The parameter `where` can be type of
+// string/map/gmap/slice/struct/*struct, etc. Note that, if it's called more than one times,
+// multiple conditions will be joined into where statement using "AND".
+// See WhereBuilder.Where.
+func (m *Model) Where(where any, args ...any) *Model {
+	return m.callWhereBuilder(m.whereBuilder.Where(where, args...))
+}
+
+// Wheref builds condition string using fmt.Sprintf and arguments.
+// Note that if the number of `args` is more than the placeholder in `format`,
+// the extra `args` will be used as the where condition arguments of the Model.
+// See WhereBuilder.Wheref.
+func (m *Model) Wheref(format string, args ...any) *Model {
+	return m.callWhereBuilder(m.whereBuilder.Wheref(format, args...))
+}
+
+// WherePri does the same logic as Model.Where except that if the parameter `where`
+// is a single condition like int/string/float/slice, it treats the condition as the primary
+// key value. That is, if primary key is "id" and given `where` parameter as "123", the
+// WherePri function treats the condition as "id=123", but Model.Where treats the condition
+// as string "123".
+// See WhereBuilder.WherePri.
+func (m *Model) WherePri(where any, args ...any) *Model {
+	return m.callWhereBuilder(m.whereBuilder.WherePri(where, args...))
+}
+
+// WhereLT builds `column < value` statement.
+// See WhereBuilder.WhereLT. 
+func (m *Model) WhereLT(column string, value any) *Model { + return m.callWhereBuilder(m.whereBuilder.WhereLT(column, value)) +} + +// WhereLTE builds `column <= value` statement. +// See WhereBuilder.WhereLTE. +func (m *Model) WhereLTE(column string, value any) *Model { + return m.callWhereBuilder(m.whereBuilder.WhereLTE(column, value)) +} + +// WhereGT builds `column > value` statement. +// See WhereBuilder.WhereGT. +func (m *Model) WhereGT(column string, value any) *Model { + return m.callWhereBuilder(m.whereBuilder.WhereGT(column, value)) +} + +// WhereGTE builds `column >= value` statement. +// See WhereBuilder.WhereGTE. +func (m *Model) WhereGTE(column string, value any) *Model { + return m.callWhereBuilder(m.whereBuilder.WhereGTE(column, value)) +} + +// WhereBetween builds `column BETWEEN min AND max` statement. +// See WhereBuilder.WhereBetween. +func (m *Model) WhereBetween(column string, min, max any) *Model { + return m.callWhereBuilder(m.whereBuilder.WhereBetween(column, min, max)) +} + +// WhereLike builds `column LIKE like` statement. +// See WhereBuilder.WhereLike. +func (m *Model) WhereLike(column string, like string) *Model { + return m.callWhereBuilder(m.whereBuilder.WhereLike(column, like)) +} + +// WhereIn builds `column IN (in)` statement. +// See WhereBuilder.WhereIn. +func (m *Model) WhereIn(column string, in any) *Model { + return m.callWhereBuilder(m.whereBuilder.WhereIn(column, in)) +} + +// WhereNull builds `columns[0] IS NULL AND columns[1] IS NULL ...` statement. +// See WhereBuilder.WhereNull. +func (m *Model) WhereNull(columns ...string) *Model { + return m.callWhereBuilder(m.whereBuilder.WhereNull(columns...)) +} + +// WhereNotBetween builds `column NOT BETWEEN min AND max` statement. +// See WhereBuilder.WhereNotBetween. +func (m *Model) WhereNotBetween(column string, min, max any) *Model { + return m.callWhereBuilder(m.whereBuilder.WhereNotBetween(column, min, max)) +} + +// WhereNotLike builds `column NOT LIKE like` statement. 
+// See WhereBuilder.WhereNotLike. +func (m *Model) WhereNotLike(column string, like any) *Model { + return m.callWhereBuilder(m.whereBuilder.WhereNotLike(column, like)) +} + +// WhereNot builds `column != value` statement. +// See WhereBuilder.WhereNot. +func (m *Model) WhereNot(column string, value any) *Model { + return m.callWhereBuilder(m.whereBuilder.WhereNot(column, value)) +} + +// WhereNotIn builds `column NOT IN (in)` statement. +// See WhereBuilder.WhereNotIn. +func (m *Model) WhereNotIn(column string, in any) *Model { + return m.callWhereBuilder(m.whereBuilder.WhereNotIn(column, in)) +} + +// WhereNotNull builds `columns[0] IS NOT NULL AND columns[1] IS NOT NULL ...` statement. +// See WhereBuilder.WhereNotNull. +func (m *Model) WhereNotNull(columns ...string) *Model { + return m.callWhereBuilder(m.whereBuilder.WhereNotNull(columns...)) +} + +// WhereExists builds `EXISTS (subQuery)` statement. +func (m *Model) WhereExists(subQuery *Model) *Model { + return m.callWhereBuilder(m.whereBuilder.WhereExists(subQuery)) +} + +// WhereNotExists builds `NOT EXISTS (subQuery)` statement. +func (m *Model) WhereNotExists(subQuery *Model) *Model { + return m.callWhereBuilder(m.whereBuilder.WhereNotExists(subQuery)) +} diff --git a/database/gdb_model_where_prefix.go b/database/gdb_model_where_prefix.go new file mode 100644 index 0000000..3649925 --- /dev/null +++ b/database/gdb_model_where_prefix.go @@ -0,0 +1,91 @@ +// Copyright GoFrame Author(https://goframe.org). All Rights Reserved. +// +// This Source Code Form is subject to the terms of the MIT License. +// If a copy of the MIT was not distributed with this file, +// You can obtain one at https://github.com/gogf/gf. + +package database + +// WherePrefix performs as Where, but it adds prefix to each field in where statement. +// See WhereBuilder.WherePrefix. 
+func (m *Model) WherePrefix(prefix string, where any, args ...any) *Model { + return m.callWhereBuilder(m.whereBuilder.WherePrefix(prefix, where, args...)) +} + +// WherePrefixLT builds `prefix.column < value` statement. +// See WhereBuilder.WherePrefixLT. +func (m *Model) WherePrefixLT(prefix string, column string, value any) *Model { + return m.callWhereBuilder(m.whereBuilder.WherePrefixLT(prefix, column, value)) +} + +// WherePrefixLTE builds `prefix.column <= value` statement. +// See WhereBuilder.WherePrefixLTE. +func (m *Model) WherePrefixLTE(prefix string, column string, value any) *Model { + return m.callWhereBuilder(m.whereBuilder.WherePrefixLTE(prefix, column, value)) +} + +// WherePrefixGT builds `prefix.column > value` statement. +// See WhereBuilder.WherePrefixGT. +func (m *Model) WherePrefixGT(prefix string, column string, value any) *Model { + return m.callWhereBuilder(m.whereBuilder.WherePrefixGT(prefix, column, value)) +} + +// WherePrefixGTE builds `prefix.column >= value` statement. +// See WhereBuilder.WherePrefixGTE. +func (m *Model) WherePrefixGTE(prefix string, column string, value any) *Model { + return m.callWhereBuilder(m.whereBuilder.WherePrefixGTE(prefix, column, value)) +} + +// WherePrefixBetween builds `prefix.column BETWEEN min AND max` statement. +// See WhereBuilder.WherePrefixBetween. +func (m *Model) WherePrefixBetween(prefix string, column string, min, max any) *Model { + return m.callWhereBuilder(m.whereBuilder.WherePrefixBetween(prefix, column, min, max)) +} + +// WherePrefixLike builds `prefix.column LIKE like` statement. +// See WhereBuilder.WherePrefixLike. +func (m *Model) WherePrefixLike(prefix string, column string, like any) *Model { + return m.callWhereBuilder(m.whereBuilder.WherePrefixLike(prefix, column, like)) +} + +// WherePrefixIn builds `prefix.column IN (in)` statement. +// See WhereBuilder.WherePrefixIn. 
+func (m *Model) WherePrefixIn(prefix string, column string, in any) *Model { + return m.callWhereBuilder(m.whereBuilder.WherePrefixIn(prefix, column, in)) +} + +// WherePrefixNull builds `prefix.columns[0] IS NULL AND prefix.columns[1] IS NULL ...` statement. +// See WhereBuilder.WherePrefixNull. +func (m *Model) WherePrefixNull(prefix string, columns ...string) *Model { + return m.callWhereBuilder(m.whereBuilder.WherePrefixNull(prefix, columns...)) +} + +// WherePrefixNotBetween builds `prefix.column NOT BETWEEN min AND max` statement. +// See WhereBuilder.WherePrefixNotBetween. +func (m *Model) WherePrefixNotBetween(prefix string, column string, min, max any) *Model { + return m.callWhereBuilder(m.whereBuilder.WherePrefixNotBetween(prefix, column, min, max)) +} + +// WherePrefixNotLike builds `prefix.column NOT LIKE like` statement. +// See WhereBuilder.WherePrefixNotLike. +func (m *Model) WherePrefixNotLike(prefix string, column string, like any) *Model { + return m.callWhereBuilder(m.whereBuilder.WherePrefixNotLike(prefix, column, like)) +} + +// WherePrefixNot builds `prefix.column != value` statement. +// See WhereBuilder.WherePrefixNot. +func (m *Model) WherePrefixNot(prefix string, column string, value any) *Model { + return m.callWhereBuilder(m.whereBuilder.WherePrefixNot(prefix, column, value)) +} + +// WherePrefixNotIn builds `prefix.column NOT IN (in)` statement. +// See WhereBuilder.WherePrefixNotIn. +func (m *Model) WherePrefixNotIn(prefix string, column string, in any) *Model { + return m.callWhereBuilder(m.whereBuilder.WherePrefixNotIn(prefix, column, in)) +} + +// WherePrefixNotNull builds `prefix.columns[0] IS NOT NULL AND prefix.columns[1] IS NOT NULL ...` statement. +// See WhereBuilder.WherePrefixNotNull. 
+func (m *Model) WherePrefixNotNull(prefix string, columns ...string) *Model { + return m.callWhereBuilder(m.whereBuilder.WherePrefixNotNull(prefix, columns...)) +} diff --git a/database/gdb_model_whereor.go b/database/gdb_model_whereor.go new file mode 100644 index 0000000..ebb6e9a --- /dev/null +++ b/database/gdb_model_whereor.go @@ -0,0 +1,97 @@ +// Copyright GoFrame Author(https://goframe.org). All Rights Reserved. +// +// This Source Code Form is subject to the terms of the MIT License. +// If a copy of the MIT was not distributed with this file, +// You can obtain one at https://github.com/gogf/gf. + +package database + +// WhereOr adds "OR" condition to the where statement. +// See WhereBuilder.WhereOr. +func (m *Model) WhereOr(where any, args ...any) *Model { + return m.callWhereBuilder(m.whereBuilder.WhereOr(where, args...)) +} + +// WhereOrf builds `OR` condition string using fmt.Sprintf and arguments. +// See WhereBuilder.WhereOrf. +func (m *Model) WhereOrf(format string, args ...any) *Model { + return m.callWhereBuilder(m.whereBuilder.WhereOrf(format, args...)) +} + +// WhereOrLT builds `column < value` statement in `OR` conditions. +// See WhereBuilder.WhereOrLT. +func (m *Model) WhereOrLT(column string, value any) *Model { + return m.callWhereBuilder(m.whereBuilder.WhereOrLT(column, value)) +} + +// WhereOrLTE builds `column <= value` statement in `OR` conditions. +// See WhereBuilder.WhereOrLTE. +func (m *Model) WhereOrLTE(column string, value any) *Model { + return m.callWhereBuilder(m.whereBuilder.WhereOrLTE(column, value)) +} + +// WhereOrGT builds `column > value` statement in `OR` conditions. +// See WhereBuilder.WhereOrGT. +func (m *Model) WhereOrGT(column string, value any) *Model { + return m.callWhereBuilder(m.whereBuilder.WhereOrGT(column, value)) +} + +// WhereOrGTE builds `column >= value` statement in `OR` conditions. +// See WhereBuilder.WhereOrGTE. 
+func (m *Model) WhereOrGTE(column string, value any) *Model { + return m.callWhereBuilder(m.whereBuilder.WhereOrGTE(column, value)) +} + +// WhereOrBetween builds `column BETWEEN min AND max` statement in `OR` conditions. +// See WhereBuilder.WhereOrBetween. +func (m *Model) WhereOrBetween(column string, min, max any) *Model { + return m.callWhereBuilder(m.whereBuilder.WhereOrBetween(column, min, max)) +} + +// WhereOrLike builds `column LIKE like` statement in `OR` conditions. +// See WhereBuilder.WhereOrLike. +func (m *Model) WhereOrLike(column string, like any) *Model { + return m.callWhereBuilder(m.whereBuilder.WhereOrLike(column, like)) +} + +// WhereOrIn builds `column IN (in)` statement in `OR` conditions. +// See WhereBuilder.WhereOrIn. +func (m *Model) WhereOrIn(column string, in any) *Model { + return m.callWhereBuilder(m.whereBuilder.WhereOrIn(column, in)) +} + +// WhereOrNull builds `columns[0] IS NULL OR columns[1] IS NULL ...` statement in `OR` conditions. +// See WhereBuilder.WhereOrNull. +func (m *Model) WhereOrNull(columns ...string) *Model { + return m.callWhereBuilder(m.whereBuilder.WhereOrNull(columns...)) +} + +// WhereOrNotBetween builds `column NOT BETWEEN min AND max` statement in `OR` conditions. +// See WhereBuilder.WhereOrNotBetween. +func (m *Model) WhereOrNotBetween(column string, min, max any) *Model { + return m.callWhereBuilder(m.whereBuilder.WhereOrNotBetween(column, min, max)) +} + +// WhereOrNotLike builds `column NOT LIKE 'like'` statement in `OR` conditions. +// See WhereBuilder.WhereOrNotLike. +func (m *Model) WhereOrNotLike(column string, like any) *Model { + return m.callWhereBuilder(m.whereBuilder.WhereOrNotLike(column, like)) +} + +// WhereOrNot builds `column != value` statement. +// See WhereBuilder.WhereOrNot. +func (m *Model) WhereOrNot(column string, value any) *Model { + return m.callWhereBuilder(m.whereBuilder.WhereOrNot(column, value)) +} + +// WhereOrNotIn builds `column NOT IN (in)` statement. 
+// See WhereBuilder.WhereOrNotIn. +func (m *Model) WhereOrNotIn(column string, in any) *Model { + return m.callWhereBuilder(m.whereBuilder.WhereOrNotIn(column, in)) +} + +// WhereOrNotNull builds `columns[0] IS NOT NULL OR columns[1] IS NOT NULL ...` statement in `OR` conditions. +// See WhereBuilder.WhereOrNotNull. +func (m *Model) WhereOrNotNull(columns ...string) *Model { + return m.callWhereBuilder(m.whereBuilder.WhereOrNotNull(columns...)) +} diff --git a/database/gdb_model_whereor_prefix.go b/database/gdb_model_whereor_prefix.go new file mode 100644 index 0000000..d603e34 --- /dev/null +++ b/database/gdb_model_whereor_prefix.go @@ -0,0 +1,91 @@ +// Copyright GoFrame Author(https://goframe.org). All Rights Reserved. +// +// This Source Code Form is subject to the terms of the MIT License. +// If a copy of the MIT was not distributed with this file, +// You can obtain one at https://github.com/gogf/gf. + +package database + +// WhereOrPrefix performs as WhereOr, but it adds prefix to each field in where statement. +// See WhereBuilder.WhereOrPrefix. +func (m *Model) WhereOrPrefix(prefix string, where any, args ...any) *Model { + return m.callWhereBuilder(m.whereBuilder.WhereOrPrefix(prefix, where, args...)) +} + +// WhereOrPrefixLT builds `prefix.column < value` statement in `OR` conditions. +// See WhereBuilder.WhereOrPrefixLT. +func (m *Model) WhereOrPrefixLT(prefix string, column string, value any) *Model { + return m.callWhereBuilder(m.whereBuilder.WhereOrPrefixLT(prefix, column, value)) +} + +// WhereOrPrefixLTE builds `prefix.column <= value` statement in `OR` conditions. +// See WhereBuilder.WhereOrPrefixLTE. +func (m *Model) WhereOrPrefixLTE(prefix string, column string, value any) *Model { + return m.callWhereBuilder(m.whereBuilder.WhereOrPrefixLTE(prefix, column, value)) +} + +// WhereOrPrefixGT builds `prefix.column > value` statement in `OR` conditions. +// See WhereBuilder.WhereOrPrefixGT. 
+func (m *Model) WhereOrPrefixGT(prefix string, column string, value any) *Model { + return m.callWhereBuilder(m.whereBuilder.WhereOrPrefixGT(prefix, column, value)) +} + +// WhereOrPrefixGTE builds `prefix.column >= value` statement in `OR` conditions. +// See WhereBuilder.WhereOrPrefixGTE. +func (m *Model) WhereOrPrefixGTE(prefix string, column string, value any) *Model { + return m.callWhereBuilder(m.whereBuilder.WhereOrPrefixGTE(prefix, column, value)) +} + +// WhereOrPrefixBetween builds `prefix.column BETWEEN min AND max` statement in `OR` conditions. +// See WhereBuilder.WhereOrPrefixBetween. +func (m *Model) WhereOrPrefixBetween(prefix string, column string, min, max any) *Model { + return m.callWhereBuilder(m.whereBuilder.WhereOrPrefixBetween(prefix, column, min, max)) +} + +// WhereOrPrefixLike builds `prefix.column LIKE like` statement in `OR` conditions. +// See WhereBuilder.WhereOrPrefixLike. +func (m *Model) WhereOrPrefixLike(prefix string, column string, like any) *Model { + return m.callWhereBuilder(m.whereBuilder.WhereOrPrefixLike(prefix, column, like)) +} + +// WhereOrPrefixIn builds `prefix.column IN (in)` statement in `OR` conditions. +// See WhereBuilder.WhereOrPrefixIn. +func (m *Model) WhereOrPrefixIn(prefix string, column string, in any) *Model { + return m.callWhereBuilder(m.whereBuilder.WhereOrPrefixIn(prefix, column, in)) +} + +// WhereOrPrefixNull builds `prefix.columns[0] IS NULL OR prefix.columns[1] IS NULL ...` statement in `OR` conditions. +// See WhereBuilder.WhereOrPrefixNull. +func (m *Model) WhereOrPrefixNull(prefix string, columns ...string) *Model { + return m.callWhereBuilder(m.whereBuilder.WhereOrPrefixNull(prefix, columns...)) +} + +// WhereOrPrefixNotBetween builds `prefix.column NOT BETWEEN min AND max` statement in `OR` conditions. +// See WhereBuilder.WhereOrPrefixNotBetween. 
+func (m *Model) WhereOrPrefixNotBetween(prefix string, column string, min, max any) *Model { + return m.callWhereBuilder(m.whereBuilder.WhereOrPrefixNotBetween(prefix, column, min, max)) +} + +// WhereOrPrefixNotLike builds `prefix.column NOT LIKE like` statement in `OR` conditions. +// See WhereBuilder.WhereOrPrefixNotLike. +func (m *Model) WhereOrPrefixNotLike(prefix string, column string, like any) *Model { + return m.callWhereBuilder(m.whereBuilder.WhereOrPrefixNotLike(prefix, column, like)) +} + +// WhereOrPrefixNotIn builds `prefix.column NOT IN (in)` statement. +// See WhereBuilder.WhereOrPrefixNotIn. +func (m *Model) WhereOrPrefixNotIn(prefix string, column string, in any) *Model { + return m.callWhereBuilder(m.whereBuilder.WhereOrPrefixNotIn(prefix, column, in)) +} + +// WhereOrPrefixNotNull builds `prefix.columns[0] IS NOT NULL OR prefix.columns[1] IS NOT NULL ...` statement in `OR` conditions. +// See WhereBuilder.WhereOrPrefixNotNull. +func (m *Model) WhereOrPrefixNotNull(prefix string, columns ...string) *Model { + return m.callWhereBuilder(m.whereBuilder.WhereOrPrefixNotNull(prefix, columns...)) +} + +// WhereOrPrefixNot builds `prefix.column != value` statement in `OR` conditions. +// See WhereBuilder.WhereOrPrefixNot. +func (m *Model) WhereOrPrefixNot(prefix string, column string, value any) *Model { + return m.callWhereBuilder(m.whereBuilder.WhereOrPrefixNot(prefix, column, value)) +} diff --git a/database/gdb_model_with.go b/database/gdb_model_with.go new file mode 100644 index 0000000..fd79c31 --- /dev/null +++ b/database/gdb_model_with.go @@ -0,0 +1,349 @@ +// Copyright GoFrame Author(https://goframe.org). All Rights Reserved. +// +// This Source Code Form is subject to the terms of the MIT License. +// If a copy of the MIT was not distributed with this file, +// You can obtain one at https://github.com/gogf/gf. 
+
+package database
+
+import (
+	"database/sql"
+	"reflect"
+
+	"git.magicany.cc/black1552/gin-base/database/utils"
+	"github.com/gogf/gf/v2/errors/gcode"
+	"github.com/gogf/gf/v2/errors/gerror"
+	"github.com/gogf/gf/v2/os/gstructs"
+	"github.com/gogf/gf/v2/text/gstr"
+	"github.com/gogf/gf/v2/util/gutil"
+)
+
+// With creates and returns an ORM model based on metadata of given object.
+// It also enables model association operations feature on given `object`.
+// It can be called multiple times to add one or more objects to model and enable
+// their model association operations feature.
+// For example, if given struct definition:
+//
+//	type User struct {
+//		gmeta.Meta `orm:"table:user"`
+//		Id         int           `json:"id"`
+//		Name       string        `json:"name"`
+//		UserDetail *UserDetail   `orm:"with:uid=id"`
+//		UserScores []*UserScores `orm:"with:uid=id"`
+//	}
+//
+// We can enable model association operations on attribute `UserDetail` and `UserScores` by:
+//
+//	db.With(User{}.UserDetail).With(User{}.UserScores).Scan(xxx)
+//
+// Or:
+//
+//	db.With(UserDetail{}).With(UserScores{}).Scan(xxx)
+//
+// Or:
+//
+//	db.With(UserDetail{}, UserScores{}).Scan(xxx)
+func (m *Model) With(objects ...any) *Model {
+	model := m.getModel()
+	for _, object := range objects {
+		if m.tables == "" {
+			m.tablesInit = m.db.GetCore().QuotePrefixTableName(
+				getTableNameFromOrmTag(object),
+			)
+			m.tables = m.tablesInit
+			return model
+		}
+		model.withArray = append(model.withArray, object)
+	}
+	return model
+}
+
+// WithAll enables model association operations on all objects that have "with" tag in the struct.
+func (m *Model) WithAll() *Model {
+	model := m.getModel()
+	model.withAll = true
+	return model
+}
+
+// doWithScanStruct handles model association operations feature for single struct.
+func (m *Model) doWithScanStruct(pointer any) error {
+	if len(m.withArray) == 0 && !m.withAll {
+		return nil
+	}
+	var (
+		err                 error
+		allowedTypeStrArray = make([]string, 0)
+	)
+	currentStructFieldMap, err := gstructs.FieldMap(gstructs.FieldMapInput{
+		Pointer:          pointer,
+		PriorityTagArray: nil,
+		RecursiveOption:  gstructs.RecursiveOptionEmbeddedNoTag,
+	})
+	if err != nil {
+		return err
+	}
+	// It checks the with array and automatically calls the ScanList to complete association querying.
+	if !m.withAll {
+		for _, field := range currentStructFieldMap {
+			for _, withItem := range m.withArray {
+				withItemReflectValueType, err := gstructs.StructType(withItem)
+				if err != nil {
+					return err
+				}
+				var (
+					fieldTypeStr                = gstr.TrimAll(field.Type().String(), "*[]")
+					withItemReflectValueTypeStr = gstr.TrimAll(withItemReflectValueType.String(), "*[]")
+				)
+				// It does select operation if the field type is in the specified "with" type array.
+				if gstr.Compare(fieldTypeStr, withItemReflectValueTypeStr) == 0 {
+					allowedTypeStrArray = append(allowedTypeStrArray, fieldTypeStr)
+				}
+			}
+		}
+	}
+	for _, field := range currentStructFieldMap {
+		var (
+			fieldTypeStr    = gstr.TrimAll(field.Type().String(), "*[]")
+			parsedTagOutput = m.parseWithTagInFieldStruct(field)
+		)
+		if parsedTagOutput.With == "" {
+			continue
+		}
+		// It just handles "with" type attribute struct, so it ignores other struct types.
+		if !m.withAll && !gstr.InArray(allowedTypeStrArray, fieldTypeStr) {
+			continue
+		}
+		array := gstr.SplitAndTrim(parsedTagOutput.With, "=")
+		if len(array) == 1 {
+			// It also supports using only one column name
+			// if both tables associate using the same column name.
+			array = append(array, parsedTagOutput.With)
+		}
+		var (
+			model              *Model
+			fieldKeys          []string
+			relatedSourceName  = array[0]
+			relatedTargetName  = array[1]
+			relatedTargetValue any
+		)
+		// Find the value of related attribute from `pointer`.
+		for attributeName, attributeValue := range currentStructFieldMap {
+			if utils.EqualFoldWithoutChars(attributeName, relatedTargetName) {
+				relatedTargetValue = attributeValue.Value.Interface()
+				break
+			}
+		}
+		if relatedTargetValue == nil {
+			return gerror.NewCodef(
+				gcode.CodeInvalidParameter,
+				`cannot find the target related value of name "%s" in with tag "%s" for attribute "%s.%s"`,
+				relatedTargetName, parsedTagOutput.With, reflect.TypeOf(pointer).Elem(), field.Name(),
+			)
+		}
+		bindToReflectValue := field.Value
+		if bindToReflectValue.Kind() != reflect.Pointer && bindToReflectValue.CanAddr() {
+			bindToReflectValue = bindToReflectValue.Addr()
+		}
+
+		if structFields, err := gstructs.Fields(gstructs.FieldsInput{
+			Pointer:         field.Value,
+			RecursiveOption: gstructs.RecursiveOptionEmbeddedNoTag,
+		}); err != nil {
+			return err
+		} else {
+			fieldKeys = make([]string, len(structFields))
+			for i, field := range structFields {
+				fieldKeys[i] = field.Name()
+			}
+		}
+		// Recursively with feature checks.
+		model = m.db.With(field.Value).Hook(m.hookHandler)
+		if m.withAll {
+			model = model.WithAll()
+		} else {
+			model = model.With(m.withArray...)
+		}
+		if parsedTagOutput.Where != "" {
+			model = model.Where(parsedTagOutput.Where)
+		}
+		if parsedTagOutput.Order != "" {
+			model = model.Order(parsedTagOutput.Order)
+		}
+		if parsedTagOutput.Unscoped == "true" {
+			model = model.Unscoped()
+		}
+		// With cache feature.
+		if m.cacheEnabled && m.cacheOption.Name == "" {
+			model = model.Cache(m.cacheOption)
+		}
+		err = model.Fields(fieldKeys).
+			Where(relatedSourceName, relatedTargetValue).
+			Scan(bindToReflectValue)
+		// It ignores sql.ErrNoRows in with feature.
+		if err != nil && err != sql.ErrNoRows {
+			return err
+		}
+	}
+	return nil
+}
+
+// doWithScanStructs handles model association operations feature for struct slice.
+// Also see doWithScanStruct.
+func (m *Model) doWithScanStructs(pointer any) error {
+	if len(m.withArray) == 0 && !m.withAll {
+		return nil
+	}
+	if v, ok := pointer.(reflect.Value); ok {
+		pointer = v.Interface()
+	}
+
+	var (
+		err                 error
+		allowedTypeStrArray = make([]string, 0)
+	)
+	currentStructFieldMap, err := gstructs.FieldMap(gstructs.FieldMapInput{
+		Pointer:          pointer,
+		PriorityTagArray: nil,
+		RecursiveOption:  gstructs.RecursiveOptionEmbeddedNoTag,
+	})
+	if err != nil {
+		return err
+	}
+	// It checks the with array and automatically calls the ScanList to complete association querying.
+	if !m.withAll {
+		for _, field := range currentStructFieldMap {
+			for _, withItem := range m.withArray {
+				withItemReflectValueType, err := gstructs.StructType(withItem)
+				if err != nil {
+					return err
+				}
+				var (
+					fieldTypeStr                = gstr.TrimAll(field.Type().String(), "*[]")
+					withItemReflectValueTypeStr = gstr.TrimAll(withItemReflectValueType.String(), "*[]")
+				)
+				// It does select operation if the field type is in the specified with type array.
+				if gstr.Compare(fieldTypeStr, withItemReflectValueTypeStr) == 0 {
+					allowedTypeStrArray = append(allowedTypeStrArray, fieldTypeStr)
+				}
+			}
+		}
+	}
+
+	for fieldName, field := range currentStructFieldMap {
+		var (
+			fieldTypeStr    = gstr.TrimAll(field.Type().String(), "*[]")
+			parsedTagOutput = m.parseWithTagInFieldStruct(field)
+		)
+		if parsedTagOutput.With == "" {
+			continue
+		}
+		if !m.withAll && !gstr.InArray(allowedTypeStrArray, fieldTypeStr) {
+			continue
+		}
+		array := gstr.SplitAndTrim(parsedTagOutput.With, "=")
+		if len(array) == 1 {
+			// It supports using only one column name
+			// if both tables associate using the same column name.
+			array = append(array, parsedTagOutput.With)
+		}
+		var (
+			model              *Model
+			fieldKeys          []string
+			relatedSourceName  = array[0]
+			relatedTargetName  = array[1]
+			relatedTargetValue any
+		)
+		// Find the value slice of related attribute from `pointer`.
+		for attributeName := range currentStructFieldMap {
+			if utils.EqualFoldWithoutChars(attributeName, relatedTargetName) {
+				relatedTargetValue = ListItemValuesUnique(pointer, attributeName)
+				break
+			}
+		}
+		if relatedTargetValue == nil {
+			return gerror.NewCodef(
+				gcode.CodeInvalidParameter,
+				`cannot find the related value for attribute name "%s" of with tag "%s"`,
+				relatedTargetName, parsedTagOutput.With,
+			)
+		}
+		// If related value is empty, it does nothing but just returns.
+		if gutil.IsEmpty(relatedTargetValue) {
+			return nil
+		}
+		if structFields, err := gstructs.Fields(gstructs.FieldsInput{
+			Pointer:         field.Value,
+			RecursiveOption: gstructs.RecursiveOptionEmbeddedNoTag,
+		}); err != nil {
+			return err
+		} else {
+			fieldKeys = make([]string, len(structFields))
+			for i, field := range structFields {
+				fieldKeys[i] = field.Name()
+			}
+		}
+		// Recursively with feature checks.
+		model = m.db.With(field.Value).Hook(m.hookHandler)
+		if m.withAll {
+			model = model.WithAll()
+		} else {
+			model = model.With(m.withArray...)
+		}
+		if parsedTagOutput.Where != "" {
+			model = model.Where(parsedTagOutput.Where)
+		}
+		if parsedTagOutput.Order != "" {
+			model = model.Order(parsedTagOutput.Order)
+		}
+		if parsedTagOutput.Unscoped == "true" {
+			model = model.Unscoped()
+		}
+		// With cache feature.
+		if m.cacheEnabled && m.cacheOption.Name == "" {
+			model = model.Cache(m.cacheOption)
+		}
+		err = model.Fields(fieldKeys).
+			Where(relatedSourceName, relatedTargetValue).
+			ScanList(pointer, fieldName, parsedTagOutput.With)
+		// It ignores sql.ErrNoRows in with feature.
+		if err != nil && err != sql.ErrNoRows {
+			return err
+		}
+	}
+	return nil
+}
+
+type parseWithTagInFieldStructOutput struct {
+	With     string
+	Where    string
+	Order    string
+	Unscoped string
+}
+
+func (m *Model) parseWithTagInFieldStruct(field gstructs.Field) (output parseWithTagInFieldStructOutput) {
+	var (
+		ormTag = field.Tag(OrmTagForStruct)
+		data   = make(map[string]string)
+		array  []string
+		key    string
+	)
+	for _, v := range gstr.SplitAndTrim(ormTag, ",") {
+		array = gstr.Split(v, ":")
+		if len(array) == 2 {
+			key = array[0]
+			data[key] = gstr.Trim(array[1])
+		} else {
+			if key == OrmTagForWithOrder {
+				// supporting multiple order fields
+				data[key] += "," + gstr.Trim(v)
+			} else {
+				data[key] += " " + gstr.Trim(v)
+			}
+		}
+	}
+	output.With = data[OrmTagForWith]
+	output.Where = data[OrmTagForWithWhere]
+	output.Order = data[OrmTagForWithOrder]
+	output.Unscoped = data[OrmTagForWithUnscoped]
+	return
+}
diff --git a/database/gdb_result.go b/database/gdb_result.go
new file mode 100644
index 0000000..9025ddd
--- /dev/null
+++ b/database/gdb_result.go
@@ -0,0 +1,67 @@
+// Copyright GoFrame Author(https://goframe.org). All Rights Reserved.
+//
+// This Source Code Form is subject to the terms of the MIT License.
+// If a copy of the MIT was not distributed with this file,
+// You can obtain one at https://github.com/gogf/gf.
+
+package database
+
+import (
+	"database/sql"
+
+	"github.com/gogf/gf/v2/errors/gerror"
+)
+
+// SqlResult is execution result for sql operations.
+// It also supports batch operation result for rowsAffected.
+type SqlResult struct {
+	Result   sql.Result
+	Affected int64
+}
+
+// MustGetAffected returns the affected rows count; if any error occurs, it panics.
+func (r *SqlResult) MustGetAffected() int64 {
+	rows, err := r.RowsAffected()
+	if err != nil {
+		err = gerror.Wrap(err, `sql.Result.RowsAffected failed`)
+		panic(err)
+	}
+	return rows
+}
+
+// MustGetInsertId returns the last insert id; if any error occurs, it panics.
+func (r *SqlResult) MustGetInsertId() int64 {
+	id, err := r.LastInsertId()
+	if err != nil {
+		err = gerror.Wrap(err, `sql.Result.LastInsertId failed`)
+		panic(err)
+	}
+	return id
+}
+
+// RowsAffected returns the number of rows affected by an
+// update, insert, or delete. Not every database or database
+// driver may support this.
+// See also sql.Result.
+func (r *SqlResult) RowsAffected() (int64, error) {
+	if r.Affected > 0 {
+		return r.Affected, nil
+	}
+	if r.Result == nil {
+		return 0, nil
+	}
+	return r.Result.RowsAffected()
+}
+
+// LastInsertId returns the integer generated by the database
+// in response to a command. Typically, this will be from an
+// "auto increment" column when inserting a new row. Not all
+// databases support this feature, and the syntax of such
+// statements varies.
+// See also sql.Result.
+func (r *SqlResult) LastInsertId() (int64, error) {
+	if r.Result == nil {
+		return 0, nil
+	}
+	return r.Result.LastInsertId()
+}
diff --git a/database/gdb_schema.go b/database/gdb_schema.go
new file mode 100644
index 0000000..896d675
--- /dev/null
+++ b/database/gdb_schema.go
@@ -0,0 +1,30 @@
+// Copyright GoFrame Author(https://goframe.org). All Rights Reserved.
+//
+// This Source Code Form is subject to the terms of the MIT License.
+// If a copy of the MIT was not distributed with this file,
+// You can obtain one at https://github.com/gogf/gf.
+
+package database
+
+// Schema is a schema object from which it can then create a Model.
+type Schema struct {
+	DB
+}
+
+// Schema creates and returns a schema.
+func (c *Core) Schema(schema string) *Schema {
+	// Do not change the schema of the original db,
+	// it here creates a new db and changes its schema.
+	db, err := NewByGroup(c.GetGroup())
+	if err != nil {
+		panic(err)
+	}
+	core := db.GetCore()
+	// Different schema share some same objects.
+	core.logger = c.logger
+	core.cache = c.cache
+	core.schema = schema
+	return &Schema{
+		DB: db,
+	}
+}
diff --git a/database/gdb_statement.go b/database/gdb_statement.go
new file mode 100644
index 0000000..01d2a1a
--- /dev/null
+++ b/database/gdb_statement.go
@@ -0,0 +1,118 @@
+// Copyright GoFrame Author(https://goframe.org). All Rights Reserved.
+//
+// This Source Code Form is subject to the terms of the MIT License.
+// If a copy of the MIT was not distributed with this file,
+// You can obtain one at https://github.com/gogf/gf.
+
+package database
+
+import (
+	"context"
+	"database/sql"
+)
+
+// Stmt is a prepared statement.
+// A Stmt is safe for concurrent use by multiple goroutines.
+//
+// If a Stmt is prepared on a Tx or Conn, it will be bound to a single
+// underlying connection forever. If the Tx or Conn closes, the Stmt will
+// become unusable and all operations will return an error.
+// If a Stmt is prepared on a DB, it will remain usable for the lifetime of the
+// DB. When the Stmt needs to execute on a new underlying connection, it will
+// prepare itself on the new connection automatically.
+type Stmt struct {
+	*sql.Stmt
+	core *Core
+	link Link
+	sql  string
+}
+
+// ExecContext executes a prepared statement with the given arguments and
+// returns a Result summarizing the effect of the statement.
+func (s *Stmt) ExecContext(ctx context.Context, args ...any) (sql.Result, error) {
+	out, err := s.core.db.DoCommit(ctx, DoCommitInput{
+		Stmt:          s.Stmt,
+		Link:          s.link,
+		Sql:           s.sql,
+		Args:          args,
+		Type:          SqlTypeStmtExecContext,
+		IsTransaction: s.link.IsTransaction(),
+	})
+	return out.Result, err
+}
+
+// QueryContext executes a prepared query statement with the given arguments
+// and returns the query results as a *Rows.
+func (s *Stmt) QueryContext(ctx context.Context, args ...any) (*sql.Rows, error) {
+	out, err := s.core.db.DoCommit(ctx, DoCommitInput{
+		Stmt:          s.Stmt,
+		Link:          s.link,
+		Sql:           s.sql,
+		Args:          args,
+		Type:          SqlTypeStmtQueryContext,
+		IsTransaction: s.link.IsTransaction(),
+	})
+	if err != nil {
+		return nil, err
+	}
+	if out.RawResult != nil {
+		return out.RawResult.(*sql.Rows), err
+	}
+	return nil, nil
+}
+
+// QueryRowContext executes a prepared query statement with the given arguments.
+// If an error occurs during the execution of the statement, that error will
+// be returned by a call to Scan on the returned *Row, which is always non-nil.
+// If the query selects no rows, the *Row's Scan will return ErrNoRows.
+// Otherwise, the *Row's Scan scans the first selected row and discards
+// the rest.
+func (s *Stmt) QueryRowContext(ctx context.Context, args ...any) *sql.Row {
+	out, err := s.core.db.DoCommit(ctx, DoCommitInput{
+		Stmt:          s.Stmt,
+		Link:          s.link,
+		Sql:           s.sql,
+		Args:          args,
+		Type:          SqlTypeStmtQueryContext,
+		IsTransaction: s.link.IsTransaction(),
+	})
+	if err != nil {
+		panic(err)
+	}
+	if out.RawResult != nil {
+		return out.RawResult.(*sql.Row)
+	}
+	return nil
+}
+
+// Exec executes a prepared statement with the given arguments and
+// returns a Result summarizing the effect of the statement.
+func (s *Stmt) Exec(args ...any) (sql.Result, error) {
+	return s.ExecContext(context.Background(), args...)
+}
+
+// Query executes a prepared query statement with the given arguments
+// and returns the query results as a *Rows.
+func (s *Stmt) Query(args ...any) (*sql.Rows, error) {
+	return s.QueryContext(context.Background(), args...)
+}
+
+// QueryRow executes a prepared query statement with the given arguments.
+// If an error occurs during the execution of the statement, that error will
+// be returned by a call to Scan on the returned *Row, which is always non-nil.
+// If the query selects no rows, the *Row's Scan will return ErrNoRows.
+// Otherwise, the *Row's Scan scans the first selected row and discards
+// the rest.
+//
+// Example usage:
+//
+//	var name string
+//	err := nameByUseridStmt.QueryRow(id).Scan(&name)
+func (s *Stmt) QueryRow(args ...any) *sql.Row {
+	return s.QueryRowContext(context.Background(), args...)
+}
+
+// Close closes the statement.
+func (s *Stmt) Close() error {
+	return s.Stmt.Close()
+}
diff --git a/database/gdb_type_record.go b/database/gdb_type_record.go
new file mode 100644
index 0000000..5f8447e
--- /dev/null
+++ b/database/gdb_type_record.go
@@ -0,0 +1,65 @@
+// Copyright GoFrame Author(https://goframe.org). All Rights Reserved.
+//
+// This Source Code Form is subject to the terms of the MIT License.
+// If a copy of the MIT was not distributed with this file,
+// You can obtain one at https://github.com/gogf/gf.
+
+package database
+
+import (
+	"database/sql"
+
+	"git.magicany.cc/black1552/gin-base/database/empty"
+	"github.com/gogf/gf/v2/container/gmap"
+	"github.com/gogf/gf/v2/encoding/gjson"
+	"github.com/gogf/gf/v2/util/gconv"
+)
+
+// Json converts `r` to JSON format content.
+func (r Record) Json() string {
+	content, _ := gjson.New(r.Map()).ToJsonString()
+	return content
+}
+
+// Xml converts `r` to XML format content.
+func (r Record) Xml(rootTag ...string) string {
+	content, _ := gjson.New(r.Map()).ToXmlString(rootTag...)
+	return content
+}
+
+// Map converts `r` to map[string]any.
+func (r Record) Map() Map {
+	m := make(map[string]any)
+	for k, v := range r {
+		m[k] = v.Val()
+	}
+	return m
+}
+
+// GMap converts `r` to a gmap.
+func (r Record) GMap() *gmap.StrAnyMap {
+	return gmap.NewStrAnyMapFrom(r.Map())
+}
+
+// Struct converts `r` to a struct.
+// Note that the parameter `pointer` should be type of *struct/**struct.
+//
+// Note that it returns sql.ErrNoRows if `r` is empty.
+func (r Record) Struct(pointer any) error {
+	// If the record is empty, it returns an error.
+	if r.IsEmpty() {
+		if !empty.IsNil(pointer, true) {
+			return sql.ErrNoRows
+		}
+		return nil
+	}
+	return converter.Struct(r, pointer, gconv.StructOption{
+		PriorityTag:     OrmTagForStruct,
+		ContinueOnError: true,
+	})
+}
+
+// IsEmpty checks and returns whether `r` is empty.
+func (r Record) IsEmpty() bool {
+	return len(r) == 0
+}
diff --git a/database/gdb_type_result.go b/database/gdb_type_result.go
new file mode 100644
index 0000000..3d766dc
--- /dev/null
+++ b/database/gdb_type_result.go
@@ -0,0 +1,214 @@
+// Copyright GoFrame Author(https://goframe.org). All Rights Reserved.
+//
+// This Source Code Form is subject to the terms of the MIT License.
+// If a copy of the MIT was not distributed with this file,
+// You can obtain one at https://github.com/gogf/gf.
+
+package database
+
+import (
+	"database/sql"
+	"math"
+
+	"git.magicany.cc/black1552/gin-base/database/empty"
+	"github.com/gogf/gf/v2/container/gvar"
+	"github.com/gogf/gf/v2/encoding/gjson"
+	"github.com/gogf/gf/v2/util/gconv"
+)
+
+// IsEmpty checks and returns whether `r` is empty.
+func (r Result) IsEmpty() bool {
+	return r == nil || r.Len() == 0
+}
+
+// Len returns the length of result list.
+func (r Result) Len() int {
+	return len(r)
+}
+
+// Size is an alias of function Len.
+func (r Result) Size() int {
+	return r.Len()
+}
+
+// Chunk splits a Result into multiple Results,
+// the size of each array is determined by `size`.
+// The last chunk may contain less than size elements.
+func (r Result) Chunk(size int) []Result {
+	if size < 1 {
+		return nil
+	}
+	length := len(r)
+	chunks := int(math.Ceil(float64(length) / float64(size)))
+	var n []Result
+	for i, end := 0, 0; chunks > 0; chunks-- {
+		end = (i + 1) * size
+		if end > length {
+			end = length
+		}
+		n = append(n, r[i*size:end])
+		i++
+	}
+	return n
+}
+
+// Json converts `r` to JSON format content.
+func (r Result) Json() string {
+	content, _ := gjson.New(r.List()).ToJsonString()
+	return content
+}
+
+// Xml converts `r` to XML format content.
+func (r Result) Xml(rootTag ...string) string {
+	content, _ := gjson.New(r.List()).ToXmlString(rootTag...)
+	return content
+}
+
+// List converts `r` to a List.
+func (r Result) List() List {
+	list := make(List, len(r))
+	for k, v := range r {
+		list[k] = v.Map()
+	}
+	return list
+}
+
+// Array retrieves and returns specified column values as slice.
+// The parameter `field` is optional if there is only one column field.
+// The default `field` is the first field name of the first item in `Result` if parameter `field` is not given.
+func (r Result) Array(field ...string) Array {
+	array := make(Array, len(r))
+	if len(r) == 0 {
+		return array
+	}
+	key := ""
+	if len(field) > 0 && field[0] != "" {
+		key = field[0]
+	} else {
+		for k := range r[0] {
+			key = k
+			break
+		}
+	}
+	for k, v := range r {
+		array[k] = v[key]
+	}
+	return array
+}
+
+// MapKeyValue converts `r` to a map[string]Value of which key is specified by `key`.
+// Note that the item value may be type of slice.
+func (r Result) MapKeyValue(key string) map[string]Value {
+	var (
+		s              string
+		m              = make(map[string]Value)
+		tempMap        = make(map[string][]any)
+		hasMultiValues bool
+	)
+	for _, item := range r {
+		if k, ok := item[key]; ok {
+			s = k.String()
+			tempMap[s] = append(tempMap[s], item)
+			if len(tempMap[s]) > 1 {
+				hasMultiValues = true
+			}
+		}
+	}
+	for k, v := range tempMap {
+		if hasMultiValues {
+			m[k] = gvar.New(v)
+		} else {
+			m[k] = gvar.New(v[0])
+		}
+	}
+	return m
+}
+
+// MapKeyStr converts `r` to a map[string]Map of which key is specified by `key`.
+func (r Result) MapKeyStr(key string) map[string]Map {
+	m := make(map[string]Map)
+	for _, item := range r {
+		if v, ok := item[key]; ok {
+			m[v.String()] = item.Map()
+		}
+	}
+	return m
+}
+
+// MapKeyInt converts `r` to a map[int]Map of which key is specified by `key`.
+func (r Result) MapKeyInt(key string) map[int]Map {
+	m := make(map[int]Map)
+	for _, item := range r {
+		if v, ok := item[key]; ok {
+			m[v.Int()] = item.Map()
+		}
+	}
+	return m
+}
+
+// MapKeyUint converts `r` to a map[uint]Map of which key is specified by `key`.
+func (r Result) MapKeyUint(key string) map[uint]Map {
+	m := make(map[uint]Map)
+	for _, item := range r {
+		if v, ok := item[key]; ok {
+			m[v.Uint()] = item.Map()
+		}
+	}
+	return m
+}
+
+// RecordKeyStr converts `r` to a map[string]Record of which key is specified by `key`.
+func (r Result) RecordKeyStr(key string) map[string]Record {
+	m := make(map[string]Record)
+	for _, item := range r {
+		if v, ok := item[key]; ok {
+			m[v.String()] = item
+		}
+	}
+	return m
+}
+
+// RecordKeyInt converts `r` to a map[int]Record of which key is specified by `key`.
+func (r Result) RecordKeyInt(key string) map[int]Record {
+	m := make(map[int]Record)
+	for _, item := range r {
+		if v, ok := item[key]; ok {
+			m[v.Int()] = item
+		}
+	}
+	return m
+}
+
+// RecordKeyUint converts `r` to a map[uint]Record of which key is specified by `key`.
+func (r Result) RecordKeyUint(key string) map[uint]Record {
+	m := make(map[uint]Record)
+	for _, item := range r {
+		if v, ok := item[key]; ok {
+			m[v.Uint()] = item
+		}
+	}
+	return m
+}
+
+// Structs converts `r` to struct slice.
+// Note that the parameter `pointer` should be type of *[]struct/*[]*struct.
+func (r Result) Structs(pointer any) (err error) {
+	// If the result is empty and the target pointer is not empty, it returns an error.
+	if r.IsEmpty() {
+		if !empty.IsEmpty(pointer, true) {
+			return sql.ErrNoRows
+		}
+		return nil
+	}
+	var (
+		sliceOption  = gconv.SliceOption{ContinueOnError: true}
+		structOption = gconv.StructOption{
+			PriorityTag:     OrmTagForStruct,
+			ContinueOnError: true,
+		}
+	)
+	return converter.Structs(r, pointer, gconv.StructsOption{
+		SliceOption:  sliceOption,
+		StructOption: structOption,
+	})
+}
diff --git a/database/gdb_type_result_scanlist.go b/database/gdb_type_result_scanlist.go
new file mode 100644
index 0000000..a9f49b9
--- /dev/null
+++ b/database/gdb_type_result_scanlist.go
@@ -0,0 +1,512 @@
+// Copyright GoFrame Author(https://goframe.org). All Rights Reserved.
+//
+// This Source Code Form is subject to the terms of the MIT License.
+// If a copy of the MIT was not distributed with this file,
+// You can obtain one at https://github.com/gogf/gf.
+
+package database
+
+import (
+	"database/sql"
+	"reflect"
+
+	"github.com/gogf/gf/v2/errors/gcode"
+	"github.com/gogf/gf/v2/errors/gerror"
+	"github.com/gogf/gf/v2/os/gstructs"
+	"github.com/gogf/gf/v2/text/gstr"
+	"github.com/gogf/gf/v2/util/gconv"
+	"github.com/gogf/gf/v2/util/gutil"
+)
+
+// ScanList converts `r` to struct slice which contains other complex struct attributes.
+// Note that the parameter `structSlicePointer` should be type of *[]struct/*[]*struct.
+//
+// Usage example 1: Normal attribute struct relation:
+//
+//	type EntityUser struct {
+//		Uid  int
+//		Name string
+//	}
+//
+//	type EntityUserDetail struct {
+//		Uid     int
+//		Address string
+//	}
+//
+//	type EntityUserScores struct {
+//		Id     int
+//		Uid    int
+//		Score  int
+//		Course string
+//	}
+//
+//	type Entity struct {
+//		User       *EntityUser
+//		UserDetail *EntityUserDetail
+//		UserScores []*EntityUserScores
+//	}
+//
+//	var users []*Entity
+//	ScanList(&users, "User")
+//	ScanList(&users, "User", "uid")
+//	ScanList(&users, "UserDetail", "User", "uid:Uid")
+//	ScanList(&users, "UserScores", "User", "uid:Uid")
+//	ScanList(&users, "UserScores", "User", "uid")
+//
+// Usage example 2: Embedded attribute struct relation:
+//
+//	type EntityUser struct {
+//		Uid  int
+//		Name string
+//	}
+//
+//	type EntityUserDetail struct {
+//		Uid     int
+//		Address string
+//	}
+//
+//	type EntityUserScores struct {
+//		Id    int
+//		Uid   int
+//		Score int
+//	}
+//
+//	type Entity struct {
+//		EntityUser
+//		UserDetail EntityUserDetail
+//		UserScores []EntityUserScores
+//	}
+//
+//	var users []*Entity
+//	ScanList(&users)
+//	ScanList(&users, "UserDetail", "uid")
+//	ScanList(&users, "UserScores", "uid")
+//
+// The parameters "User/UserDetail/UserScores" in the example codes specify the target attribute struct
+// that current result will be bound to.
+//
+// The "uid" in the example codes is the table field name of the result, and the "Uid" is the relational
+// struct attribute name - not the attribute name of the bound to target. In the example codes, it's attribute
+// name "Uid" of "User" of entity "Entity". It automatically calculates the HasOne/HasMany relationship with
+// the given `relation` parameter.
+//
+// See the example or unit testing cases for clear understanding for this function.
+func (r Result) ScanList(structSlicePointer any, bindToAttrName string, relationAttrNameAndFields ...string) (err error) {
+	out, err := checkGetSliceElementInfoForScanList(structSlicePointer, bindToAttrName)
+	if err != nil {
+		return err
+	}
+
+	var (
+		relationAttrName string
+		relationFields   string
+	)
+	switch len(relationAttrNameAndFields) {
+	case 2:
+		relationAttrName = relationAttrNameAndFields[0]
+		relationFields = relationAttrNameAndFields[1]
+	case 1:
+		relationFields = relationAttrNameAndFields[0]
+	}
+	return doScanList(doScanListInput{
+		Model:              nil,
+		Result:             r,
+		StructSlicePointer: structSlicePointer,
+		StructSliceValue:   out.SliceReflectValue,
+		BindToAttrName:     bindToAttrName,
+		RelationAttrName:   relationAttrName,
+		RelationFields:     relationFields,
+	})
+}
+
+type checkGetSliceElementInfoForScanListOutput struct {
+	SliceReflectValue reflect.Value
+	BindToAttrType    reflect.Type
+}
+
+func checkGetSliceElementInfoForScanList(structSlicePointer any, bindToAttrName string) (out *checkGetSliceElementInfoForScanListOutput, err error) {
+	// Necessary checks for parameters.
+	if structSlicePointer == nil {
+		return nil, gerror.NewCode(gcode.CodeInvalidParameter, `structSlicePointer cannot be nil`)
+	}
+	if bindToAttrName == "" {
+		return nil, gerror.NewCode(gcode.CodeInvalidParameter, `bindToAttrName should not be empty`)
+	}
+	var (
+		reflectType  reflect.Type
+		reflectValue = reflect.ValueOf(structSlicePointer)
+		reflectKind  = reflectValue.Kind()
+	)
+	if reflectKind == reflect.Interface {
+		reflectValue = reflectValue.Elem()
+		reflectKind = reflectValue.Kind()
+	}
+	if reflectKind != reflect.Pointer {
+		return nil, gerror.NewCodef(
+			gcode.CodeInvalidParameter,
+			"structSlicePointer should be type of *[]struct/*[]*struct, but got: %s",
+			reflect.TypeOf(structSlicePointer).String(),
+		)
+	}
+	out = &checkGetSliceElementInfoForScanListOutput{
+		SliceReflectValue: reflectValue.Elem(),
+	}
+	// Find the element struct type of the slice.
+	reflectType = reflectValue.Type().Elem().Elem()
+	reflectKind = reflectType.Kind()
+	for reflectKind == reflect.Pointer {
+		reflectType = reflectType.Elem()
+		reflectKind = reflectType.Kind()
+	}
+	if reflectKind != reflect.Struct {
+		err = gerror.NewCodef(
+			gcode.CodeInvalidParameter,
+			"structSlicePointer should be type of *[]struct/*[]*struct, but got: %s",
+			reflect.TypeOf(structSlicePointer).String(),
+		)
+		return
+	}
+	// Find the target field by given name.
+	structField, ok := reflectType.FieldByName(bindToAttrName)
+	if !ok {
+		return nil, gerror.NewCodef(
+			gcode.CodeInvalidParameter,
+			`field "%s" not found in element of "%s"`,
+			bindToAttrName,
+			reflect.TypeOf(structSlicePointer).String(),
+		)
+	}
+	// Find the attribute struct type for ORM fields filtering.
+	reflectType = structField.Type
+	reflectKind = reflectType.Kind()
+	for reflectKind == reflect.Pointer {
+		reflectType = reflectType.Elem()
+		reflectKind = reflectType.Kind()
+	}
+	if reflectKind == reflect.Slice || reflectKind == reflect.Array {
+		reflectType = reflectType.Elem()
+		// reflectKind = reflectType.Kind()
+	}
+	out.BindToAttrType = reflectType
+	return
+}
+
+type doScanListInput struct {
+	Model              *Model
+	Result             Result
+	StructSlicePointer any
+	StructSliceValue   reflect.Value
+	BindToAttrName     string
+	RelationAttrName   string
+	RelationFields     string
+}
+
+// doScanList converts `result` to struct slice which contains other complex struct attributes recursively.
+// The parameter `model` is used for recursively scanning purpose, which means, it can scan the attribute struct/structs recursively,
+// but it needs the Model for database accessing.
+// Note that the parameter `structSlicePointer` should be type of *[]struct/*[]*struct.
+func doScanList(in doScanListInput) (err error) {
+	if in.Result.IsEmpty() {
+		return nil
+	}
+	if in.BindToAttrName == "" {
+		return gerror.NewCode(gcode.CodeInvalidParameter, `bindToAttrName should not be empty`)
+	}
+
+	length := len(in.Result)
+	if length == 0 {
+		// The pointed slice is not empty.
+		if in.StructSliceValue.Len() > 0 {
+			// It here checks if it has struct item, which is already initialized.
+			// It then returns an error to warn the developer it's empty and no conversion.
+			if v := in.StructSliceValue.Index(0); v.Kind() != reflect.Pointer {
+				return sql.ErrNoRows
+			}
+		}
+		// Do nothing for empty struct slice.
+		return nil
+	}
+	var (
+		arrayValue    reflect.Value // Like: []*Entity
+		arrayItemType reflect.Type  // Like: *Entity
+		reflectType   = reflect.TypeOf(in.StructSlicePointer)
+	)
+	if in.StructSliceValue.Len() > 0 {
+		arrayValue = in.StructSliceValue
+	} else {
+		arrayValue = reflect.MakeSlice(reflectType.Elem(), length, length)
+	}
+
+	// Slice element item.
+	arrayItemType = arrayValue.Index(0).Type()
+
+	// Relation variables.
+	var (
+		relationDataMap         map[string]Value
+		relationFromFieldName   string // Eg: relationKV: id:uid -> id
+		relationBindToFieldName string // Eg: relationKV: id:uid -> uid
+	)
+	if len(in.RelationFields) > 0 {
+		// The relation key string of table field name and attribute name
+		// can be joined with char '=' or ':'.
+		array := gstr.SplitAndTrim(in.RelationFields, "=")
+		if len(array) == 1 {
+			// Compatible with old splitting char ':'.
+			array = gstr.SplitAndTrim(in.RelationFields, ":")
+		}
+		if len(array) == 1 {
+			// The relation names are the same.
+			array = []string{in.RelationFields, in.RelationFields}
+		}
+		if len(array) == 2 {
+			// Defined table field to relation attribute name.
+ // Like: + // uid:Uid + // uid:UserId + relationFromFieldName = array[0] + relationBindToFieldName = array[1] + if key, _ := gutil.MapPossibleItemByKey(in.Result[0].Map(), relationFromFieldName); key == "" { + return gerror.NewCodef( + gcode.CodeInvalidParameter, + `cannot find possible related table field name "%s" from given relation fields "%s"`, + relationFromFieldName, + in.RelationFields, + ) + } else { + relationFromFieldName = key + } + } else { + return gerror.NewCode( + gcode.CodeInvalidParameter, + `parameter relationKV should be format of "ResultFieldName:BindToAttrName"`, + ) + } + if relationFromFieldName != "" { + // Note that the value might be type of slice. + relationDataMap = in.Result.MapKeyValue(relationFromFieldName) + } + if len(relationDataMap) == 0 { + return gerror.NewCodef( + gcode.CodeInvalidParameter, + `cannot find the relation data map, maybe invalid relation fields given "%v"`, + in.RelationFields, + ) + } + } + // Bind to target attribute. + var ( + ok bool + bindToAttrValue reflect.Value + bindToAttrKind reflect.Kind + bindToAttrType reflect.Type + bindToAttrField reflect.StructField + ) + if arrayItemType.Kind() == reflect.Pointer { + if bindToAttrField, ok = arrayItemType.Elem().FieldByName(in.BindToAttrName); !ok { + return gerror.NewCodef( + gcode.CodeInvalidParameter, + `invalid parameter bindToAttrName: cannot find attribute with name "%s" from slice element`, + in.BindToAttrName, + ) + } + } else { + if bindToAttrField, ok = arrayItemType.FieldByName(in.BindToAttrName); !ok { + return gerror.NewCodef( + gcode.CodeInvalidParameter, + `invalid parameter bindToAttrName: cannot find attribute with name "%s" from slice element`, + in.BindToAttrName, + ) + } + } + bindToAttrType = bindToAttrField.Type + bindToAttrKind = bindToAttrType.Kind() + + // Bind to relation conditions. 
+ var ( + relationFromAttrValue reflect.Value + relationFromAttrField reflect.Value + relationBindToFieldNameChecked bool + ) + for i := 0; i < arrayValue.Len(); i++ { + arrayElemValue := arrayValue.Index(i) + // The FieldByName should be called on non-pointer reflect.Value. + if arrayElemValue.Kind() == reflect.Pointer { + // Like: []*Entity + arrayElemValue = arrayElemValue.Elem() + if !arrayElemValue.IsValid() { + // The element is nil, then create one and set it to the slice. + // The "reflect.New(itemType.Elem())" creates a new element and returns the address of it. + // For example: + // reflect.New(itemType.Elem()) => *Entity + // reflect.New(itemType.Elem()).Elem() => Entity + arrayElemValue = reflect.New(arrayItemType.Elem()).Elem() + arrayValue.Index(i).Set(arrayElemValue.Addr()) + } + // } else { + // Like: []Entity + } + bindToAttrValue = arrayElemValue.FieldByName(in.BindToAttrName) + if in.RelationAttrName != "" { + // Attribute value of current slice element. + relationFromAttrValue = arrayElemValue.FieldByName(in.RelationAttrName) + if relationFromAttrValue.Kind() == reflect.Pointer { + relationFromAttrValue = relationFromAttrValue.Elem() + } + } else { + // Current slice element. + relationFromAttrValue = arrayElemValue + } + if len(relationDataMap) > 0 && !relationFromAttrValue.IsValid() { + return gerror.NewCodef(gcode.CodeInvalidParameter, `invalid relation fields specified: "%v"`, in.RelationFields) + } + // Check and find possible bind to attribute name. 
+		if in.RelationFields != "" && !relationBindToFieldNameChecked {
+			relationFromAttrField = relationFromAttrValue.FieldByName(relationBindToFieldName)
+			if !relationFromAttrField.IsValid() {
+				fieldMap, _ := gstructs.FieldMap(gstructs.FieldMapInput{
+					Pointer:         relationFromAttrValue,
+					RecursiveOption: gstructs.RecursiveOptionEmbeddedNoTag,
+				})
+				if key, _ := gutil.MapPossibleItemByKey(gconv.Map(fieldMap), relationBindToFieldName); key == "" {
+					return gerror.NewCodef(
+						gcode.CodeInvalidParameter,
+						`cannot find possible related attribute name "%s" from given relation fields "%s"`,
+						relationBindToFieldName,
+						in.RelationFields,
+					)
+				} else {
+					relationBindToFieldName = key
+				}
+			}
+			relationBindToFieldNameChecked = true
+		}
+		switch bindToAttrKind {
+		case reflect.Array, reflect.Slice:
+			if len(relationDataMap) > 0 {
+				relationFromAttrField = relationFromAttrValue.FieldByName(relationBindToFieldName)
+				if relationFromAttrField.IsValid() {
+					results := make(Result, 0)
+					for _, v := range relationDataMap[gconv.String(relationFromAttrField.Interface())].Slice() {
+						results = append(results, v.(Record))
+					}
+					if err = results.Structs(bindToAttrValue.Addr()); err != nil {
+						return err
+					}
+					// Recursively Scan.
+					if in.Model != nil {
+						if err = in.Model.doWithScanStructs(bindToAttrValue.Addr()); err != nil {
+							return err
+						}
+					}
+				} else {
+					// Maybe the attribute does not exist yet.
+ return gerror.NewCodef(gcode.CodeInvalidParameter, `invalid relation fields specified: "%v"`, in.RelationFields) + } + } else { + return gerror.NewCodef( + gcode.CodeInvalidParameter, + `relationKey should not be empty as field "%s" is slice`, + in.BindToAttrName, + ) + } + + case reflect.Pointer: + var element reflect.Value + if bindToAttrValue.IsNil() { + element = reflect.New(bindToAttrType.Elem()).Elem() + } else { + element = bindToAttrValue.Elem() + } + if len(relationDataMap) > 0 { + relationFromAttrField = relationFromAttrValue.FieldByName(relationBindToFieldName) + if relationFromAttrField.IsValid() { + v := relationDataMap[gconv.String(relationFromAttrField.Interface())] + if v == nil { + // There's no relational data. + continue + } + if v.IsSlice() { + if err = v.Slice()[0].(Record).Struct(element); err != nil { + return err + } + } else { + if err = v.Val().(Record).Struct(element); err != nil { + return err + } + } + } else { + // Maybe the attribute does not exist yet. + return gerror.NewCodef(gcode.CodeInvalidParameter, `invalid relation fields specified: "%v"`, in.RelationFields) + } + } else { + if i >= len(in.Result) { + // There's no relational data. + continue + } + v := in.Result[i] + if v == nil { + // There's no relational data. + continue + } + if err = v.Struct(element); err != nil { + return err + } + } + // Recursively Scan. + if in.Model != nil { + if err = in.Model.doWithScanStruct(element); err != nil { + return err + } + } + bindToAttrValue.Set(element.Addr()) + + case reflect.Struct: + if len(relationDataMap) > 0 { + relationFromAttrField = relationFromAttrValue.FieldByName(relationBindToFieldName) + if relationFromAttrField.IsValid() { + relationDataItem := relationDataMap[gconv.String(relationFromAttrField.Interface())] + if relationDataItem == nil { + // There's no relational data. 
+ continue + } + if relationDataItem.IsSlice() { + if err = relationDataItem.Slice()[0].(Record).Struct(bindToAttrValue); err != nil { + return err + } + } else { + if err = relationDataItem.Val().(Record).Struct(bindToAttrValue); err != nil { + return err + } + } + } else { + // Maybe the attribute does not exist yet. + return gerror.NewCodef(gcode.CodeInvalidParameter, `invalid relation fields specified: "%v"`, in.RelationFields) + } + } else { + if i >= len(in.Result) { + // There's no relational data. + continue + } + relationDataItem := in.Result[i] + if relationDataItem == nil { + // There's no relational data. + continue + } + if err = relationDataItem.Struct(bindToAttrValue); err != nil { + return err + } + } + // Recursively Scan. + if in.Model != nil { + if err = in.Model.doWithScanStruct(bindToAttrValue); err != nil { + return err + } + } + + default: + return gerror.NewCodef(gcode.CodeInvalidParameter, `unsupported attribute type: %s`, bindToAttrKind.String()) + } + } + reflect.ValueOf(in.StructSlicePointer).Elem().Set(arrayValue) + return nil +} diff --git a/database/index.go b/database/index.go deleted file mode 100644 index 636bab8..0000000 --- a/database/index.go +++ /dev/null @@ -1 +0,0 @@ -package database diff --git a/database/instance/instance.go b/database/instance/instance.go new file mode 100644 index 0000000..6af0fa2 --- /dev/null +++ b/database/instance/instance.go @@ -0,0 +1,79 @@ +// Copyright GoFrame Author(https://goframe.org). All Rights Reserved. +// +// This Source Code Form is subject to the terms of the MIT License. +// If a copy of the MIT was not distributed with this file, +// You can obtain one at https://github.com/gogf/gf. + +// Package instance provides instances management. +// +// Note that this package is not used for cache, as it has no cache expiration. 
+package instance + +import ( + "github.com/gogf/gf/v2/container/gmap" + "github.com/gogf/gf/v2/encoding/ghash" +) + +const ( + groupNumber = 64 +) + +var ( + groups = make([]*gmap.StrAnyMap, groupNumber) +) + +func init() { + for i := 0; i < groupNumber; i++ { + groups[i] = gmap.NewStrAnyMap(true) + } +} + +func getGroup(key string) *gmap.StrAnyMap { + return groups[int(ghash.DJB([]byte(key))%groupNumber)] +} + +// Get returns the instance by given name. +func Get(name string) any { + return getGroup(name).Get(name) +} + +// Set sets an instance to the instance manager with given name. +func Set(name string, instance any) { + getGroup(name).Set(name, instance) +} + +// GetOrSet returns the instance by name, +// or set instance to the instance manager if it does not exist and returns this instance. +func GetOrSet(name string, instance any) any { + return getGroup(name).GetOrSet(name, instance) +} + +// GetOrSetFunc returns the instance by name, +// or sets instance with returned value of callback function `f` if it does not exist +// and then returns this instance. +func GetOrSetFunc(name string, f func() any) any { + return getGroup(name).GetOrSetFunc(name, f) +} + +// GetOrSetFuncLock returns the instance by name, +// or sets instance with returned value of callback function `f` if it does not exist +// and then returns this instance. +// +// GetOrSetFuncLock differs with GetOrSetFunc function is that it executes function `f` +// with mutex.Lock of the hash map. +func GetOrSetFuncLock(name string, f func() any) any { + return getGroup(name).GetOrSetFuncLock(name, f) +} + +// SetIfNotExist sets `instance` to the map if the `name` does not exist, then returns true. +// It returns false if `name` exists, and `instance` would be ignored. +func SetIfNotExist(name string, instance any) bool { + return getGroup(name).SetIfNotExist(name, instance) +} + +// Clear deletes all instances stored. 
+func Clear() {
+	for i := 0; i < groupNumber; i++ {
+		groups[i].Clear()
+	}
+}
diff --git a/database/instance/instance_test.go b/database/instance/instance_test.go
new file mode 100644
index 0000000..c29cb92
--- /dev/null
+++ b/database/instance/instance_test.go
@@ -0,0 +1,44 @@
+// Copyright GoFrame Author(https://goframe.org). All Rights Reserved.
+//
+// This Source Code Form is subject to the terms of the MIT License.
+// If a copy of the MIT was not distributed with this file,
+// You can obtain one at https://github.com/gogf/gf.
+
+package instance_test
+
+import (
+	"testing"
+
+	"git.magicany.cc/black1552/gin-base/database/instance"
+	"github.com/gogf/gf/v2/test/gtest"
+)
+
+func Test_SetGet(t *testing.T) {
+	gtest.C(t, func(t *gtest.T) {
+		instance.Set("test-user", 1)
+		t.Assert(instance.Get("test-user"), 1)
+		t.Assert(instance.Get("none-exists"), nil)
+	})
+	gtest.C(t, func(t *gtest.T) {
+		t.Assert(instance.GetOrSet("test-1", 1), 1)
+		t.Assert(instance.Get("test-1"), 1)
+	})
+	gtest.C(t, func(t *gtest.T) {
+		t.Assert(instance.GetOrSetFunc("test-2", func() any {
+			return 2
+		}), 2)
+		t.Assert(instance.Get("test-2"), 2)
+	})
+	gtest.C(t, func(t *gtest.T) {
+		t.Assert(instance.GetOrSetFuncLock("test-3", func() any {
+			return 3
+		}), 3)
+		t.Assert(instance.Get("test-3"), 3)
+	})
+	gtest.C(t, func(t *gtest.T) {
+		t.Assert(instance.SetIfNotExist("test-4", 4), true)
+		t.Assert(instance.Get("test-4"), 4)
+		t.Assert(instance.SetIfNotExist("test-4", 5), false)
+		t.Assert(instance.Get("test-4"), 4)
+	})
+}
diff --git a/database/intlog/intlog.go b/database/intlog/intlog.go
new file mode 100644
index 0000000..be8fda7
--- /dev/null
+++ b/database/intlog/intlog.go
@@ -0,0 +1,125 @@
+// Copyright GoFrame Author(https://goframe.org). All Rights Reserved.
+//
+// This Source Code Form is subject to the terms of the MIT License.
+// If a copy of the MIT was not distributed with this file,
+// You can obtain one at https://github.com/gogf/gf.
+ +// Package intlog provides internal logging for GoFrame development usage only. +package intlog + +import ( + "bytes" + "context" + "fmt" + "path/filepath" + "time" + + "git.magicany.cc/black1552/gin-base/database/utils" + "go.opentelemetry.io/otel/trace" + + "github.com/gogf/gf/v2/debug/gdebug" +) + +const ( + stackFilterKey = "/internal/intlog" +) + +// Print prints `v` with newline using fmt.Println. +// The parameter `v` can be multiple variables. +func Print(ctx context.Context, v ...any) { + if !utils.IsDebugEnabled() { + return + } + doPrint(ctx, fmt.Sprint(v...), false) +} + +// Printf prints `v` with format `format` using fmt.Printf. +// The parameter `v` can be multiple variables. +func Printf(ctx context.Context, format string, v ...any) { + if !utils.IsDebugEnabled() { + return + } + doPrint(ctx, fmt.Sprintf(format, v...), false) +} + +// Error prints `v` with newline using fmt.Println. +// The parameter `v` can be multiple variables. +func Error(ctx context.Context, v ...any) { + if !utils.IsDebugEnabled() { + return + } + doPrint(ctx, fmt.Sprint(v...), true) +} + +// Errorf prints `v` with format `format` using fmt.Printf. +func Errorf(ctx context.Context, format string, v ...any) { + if !utils.IsDebugEnabled() { + return + } + doPrint(ctx, fmt.Sprintf(format, v...), true) +} + +// PrintFunc prints the output from function `f`. +// It only calls function `f` if debug mode is enabled. +func PrintFunc(ctx context.Context, f func() string) { + if !utils.IsDebugEnabled() { + return + } + s := f() + if s == "" { + return + } + doPrint(ctx, s, false) +} + +// ErrorFunc prints the output from function `f`. +// It only calls function `f` if debug mode is enabled. 
+func ErrorFunc(ctx context.Context, f func() string) { + if !utils.IsDebugEnabled() { + return + } + s := f() + if s == "" { + return + } + doPrint(ctx, s, true) +} + +func doPrint(ctx context.Context, content string, stack bool) { + if !utils.IsDebugEnabled() { + return + } + buffer := bytes.NewBuffer(nil) + buffer.WriteString(time.Now().Format("2006-01-02 15:04:05.000")) + buffer.WriteString(" [INTE] ") + buffer.WriteString(file()) + buffer.WriteString(" ") + if s := traceIDStr(ctx); s != "" { + buffer.WriteString(s + " ") + } + buffer.WriteString(content) + buffer.WriteString("\n") + if stack { + buffer.WriteString("Caller Stack:\n") + buffer.WriteString(gdebug.StackWithFilter([]string{stackFilterKey})) + } + fmt.Print(buffer.String()) +} + +// traceIDStr retrieves and returns the trace id string for logging output. +func traceIDStr(ctx context.Context) string { + if ctx == nil { + return "" + } + spanCtx := trace.SpanContextFromContext(ctx) + if traceID := spanCtx.TraceID(); traceID.IsValid() { + return "{" + traceID.String() + "}" + } + return "" +} + +// file returns caller file name along with its line number. +func file() string { + _, p, l := gdebug.CallerWithFilter([]string{stackFilterKey}) + return fmt.Sprintf(`%s:%d`, filepath.Base(p), l) +} diff --git a/database/json/json.go b/database/json/json.go new file mode 100644 index 0000000..88c7050 --- /dev/null +++ b/database/json/json.go @@ -0,0 +1,85 @@ +// Copyright GoFrame Author(https://goframe.org). All Rights Reserved. +// +// This Source Code Form is subject to the terms of the MIT License. +// If a copy of the MIT was not distributed with this file, +// You can obtain one at https://github.com/gogf/gf. + +// Package json provides json operations wrapping ignoring stdlib or third-party lib json. +package json + +import ( + "bytes" + "encoding/json" + "io" + + "github.com/gogf/gf/v2/errors/gerror" +) + +// RawMessage is a raw encoded JSON value. 
+// It implements Marshaler and Unmarshaler and can +// be used to delay JSON decoding or precompute a JSON encoding. +type RawMessage = json.RawMessage + +// Marshal adapts to json/encoding Marshal API. +// +// Marshal returns the JSON encoding of v, adapts to json/encoding Marshal API +// Refer to https://godoc.org/encoding/json#Marshal for more information. +func Marshal(v any) (marshaledBytes []byte, err error) { + marshaledBytes, err = json.Marshal(v) + if err != nil { + err = gerror.Wrap(err, `json.Marshal failed`) + } + return +} + +// MarshalIndent same as json.MarshalIndent. +func MarshalIndent(v any, prefix, indent string) (marshaledBytes []byte, err error) { + marshaledBytes, err = json.MarshalIndent(v, prefix, indent) + if err != nil { + err = gerror.Wrap(err, `json.MarshalIndent failed`) + } + return +} + +// Unmarshal adapts to json/encoding Unmarshal API +// +// Unmarshal parses the JSON-encoded data and stores the result in the value pointed to by v. +// Refer to https://godoc.org/encoding/json#Unmarshal for more information. +func Unmarshal(data []byte, v any) (err error) { + err = json.Unmarshal(data, v) + if err != nil { + err = gerror.Wrap(err, `json.Unmarshal failed`) + } + return +} + +// UnmarshalUseNumber decodes the json data bytes to target interface using number option. +func UnmarshalUseNumber(data []byte, v any) (err error) { + decoder := NewDecoder(bytes.NewReader(data)) + decoder.UseNumber() + err = decoder.Decode(v) + if err != nil { + err = gerror.Wrap(err, `json.UnmarshalUseNumber failed`) + } + return +} + +// NewEncoder same as json.NewEncoder +func NewEncoder(writer io.Writer) *json.Encoder { + return json.NewEncoder(writer) +} + +// NewDecoder adapts to json/stream NewDecoder API. +// +// NewDecoder returns a new decoder that reads from r. +// +// Instead of a json/encoding Decoder, a Decoder is returned +// Refer to https://godoc.org/encoding/json#NewDecoder for more information. 
+func NewDecoder(reader io.Reader) *json.Decoder { + return json.NewDecoder(reader) +} + +// Valid reports whether data is a valid JSON encoding. +func Valid(data []byte) bool { + return json.Valid(data) +} diff --git a/database/migrate.go b/database/migrate.go deleted file mode 100644 index 3ce028b..0000000 --- a/database/migrate.go +++ /dev/null @@ -1,43 +0,0 @@ -package database - -import ( - "git.magicany.cc/black1552/gin-base/log" - "github.com/gogf/gf/v2/frame/g" -) - -func SetAutoMigrate(models ...interface{}) { - if g.IsNil(Db) { - log.Error("数据库连接失败") - return - } - err := Db.AutoMigrate(models...) - if err != nil { - log.Error("数据库迁移失败", err) - } -} -func RenameColumn(dst interface{}, name, newName string) { - if Db.Migrator().HasColumn(dst, name) { - err := Db.Migrator().RenameColumn(dst, name, newName) - if err != nil { - log.Error("数据库修改字段失败", err) - return - } - } else { - log.Info("数据库字段不存在", name) - } -} - -// DropColumn -// 删除字段 -// 例:DropColumn(&User{}, "Sex") -func DropColumn(dst interface{}, name string) { - if Db.Migrator().HasColumn(dst, name) { - err := Db.Migrator().DropColumn(dst, name) - if err != nil { - log.Error("数据库删除字段失败", err) - return - } - } else { - log.Info("数据库字段不存在", name) - } -} diff --git a/database/reflection/reflection.go b/database/reflection/reflection.go new file mode 100644 index 0000000..0fccd35 --- /dev/null +++ b/database/reflection/reflection.go @@ -0,0 +1,94 @@ +// Copyright GoFrame Author(https://goframe.org). All Rights Reserved. +// +// This Source Code Form is subject to the terms of the MIT License. +// If a copy of the MIT was not distributed with this file, +// You can obtain one at https://github.com/gogf/gf. + +// Package reflection provides some reflection functions for internal usage. 
+package reflection + +import ( + "reflect" +) + +type OriginValueAndKindOutput struct { + InputValue reflect.Value + InputKind reflect.Kind + OriginValue reflect.Value + OriginKind reflect.Kind +} + +// OriginValueAndKind retrieves and returns the original reflect value and kind. +func OriginValueAndKind(value any) (out OriginValueAndKindOutput) { + if v, ok := value.(reflect.Value); ok { + out.InputValue = v + } else { + out.InputValue = reflect.ValueOf(value) + } + out.InputKind = out.InputValue.Kind() + out.OriginValue = out.InputValue + out.OriginKind = out.InputKind + for out.OriginKind == reflect.Pointer { + out.OriginValue = out.OriginValue.Elem() + out.OriginKind = out.OriginValue.Kind() + } + return +} + +type OriginTypeAndKindOutput struct { + InputType reflect.Type + InputKind reflect.Kind + OriginType reflect.Type + OriginKind reflect.Kind +} + +// OriginTypeAndKind retrieves and returns the original reflect type and kind. +func OriginTypeAndKind(value any) (out OriginTypeAndKindOutput) { + if value == nil { + return + } + if reflectType, ok := value.(reflect.Type); ok { + out.InputType = reflectType + } else { + if reflectValue, ok := value.(reflect.Value); ok { + out.InputType = reflectValue.Type() + } else { + out.InputType = reflect.TypeOf(value) + } + } + out.InputKind = out.InputType.Kind() + out.OriginType = out.InputType + out.OriginKind = out.InputKind + for out.OriginKind == reflect.Pointer { + out.OriginType = out.OriginType.Elem() + out.OriginKind = out.OriginType.Kind() + } + return +} + +// ValueToInterface converts reflect value to its interface type. 
+func ValueToInterface(v reflect.Value) (value any, ok bool) { + if v.IsValid() && v.CanInterface() { + return v.Interface(), true + } + switch v.Kind() { + case reflect.Bool: + return v.Bool(), true + case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64: + return v.Int(), true + case reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64, reflect.Uintptr: + return v.Uint(), true + case reflect.Float32, reflect.Float64: + return v.Float(), true + case reflect.Complex64, reflect.Complex128: + return v.Complex(), true + case reflect.String: + return v.String(), true + case reflect.Pointer: + return ValueToInterface(v.Elem()) + case reflect.Interface: + return ValueToInterface(v.Elem()) + default: + return nil, false + } +} diff --git a/database/utils/utils.go b/database/utils/utils.go new file mode 100644 index 0000000..414a90c --- /dev/null +++ b/database/utils/utils.go @@ -0,0 +1,8 @@ +// Copyright GoFrame Author(https://goframe.org). All Rights Reserved. +// +// This Source Code Form is subject to the terms of the MIT License. +// If a copy of the MIT was not distributed with this file, +// You can obtain one at https://github.com/gogf/gf. + +// Package utils provides some utility functions for internal usage. +package utils diff --git a/database/utils/utils_array.go b/database/utils/utils_array.go new file mode 100644 index 0000000..2b566e0 --- /dev/null +++ b/database/utils/utils_array.go @@ -0,0 +1,26 @@ +// Copyright GoFrame Author(https://goframe.org). All Rights Reserved. +// +// This Source Code Form is subject to the terms of the MIT License. +// If a copy of the MIT was not distributed with this file, +// You can obtain one at https://github.com/gogf/gf. + +package utils + +import "reflect" + +// IsArray checks whether given value is array/slice. +// Note that it uses reflect internally implementing this feature. 
+func IsArray(value any) bool { + rv := reflect.ValueOf(value) + kind := rv.Kind() + if kind == reflect.Pointer { + rv = rv.Elem() + kind = rv.Kind() + } + switch kind { + case reflect.Array, reflect.Slice: + return true + default: + return false + } +} diff --git a/database/utils/utils_debug.go b/database/utils/utils_debug.go new file mode 100644 index 0000000..4f0cefe --- /dev/null +++ b/database/utils/utils_debug.go @@ -0,0 +1,38 @@ +// Copyright GoFrame Author(https://goframe.org). All Rights Reserved. +// +// This Source Code Form is subject to the terms of the MIT License. +// If a copy of the MIT was not distributed with this file, +// You can obtain one at https://github.com/gogf/gf. + +package utils + +import "git.magicany.cc/black1552/gin-base/database/command" + +const ( + // Debug key for checking if in debug mode. + commandEnvKeyForDebugKey = "gf.debug" +) + +// isDebugEnabled marks whether GoFrame debug mode is enabled. +var isDebugEnabled = false + +func init() { + // Debugging configured. + value := command.GetOptWithEnv(commandEnvKeyForDebugKey) + if value == "" || value == "0" || value == "false" { + isDebugEnabled = false + } else { + isDebugEnabled = true + } +} + +// IsDebugEnabled checks and returns whether debug mode is enabled. +// The debug mode is enabled when command argument "gf.debug" or environment "GF_DEBUG" is passed. +func IsDebugEnabled() bool { + return isDebugEnabled +} + +// SetDebugEnabled enables/disables the internal debug info. +func SetDebugEnabled(enabled bool) { + isDebugEnabled = enabled +} diff --git a/database/utils/utils_io.go b/database/utils/utils_io.go new file mode 100644 index 0000000..c291693 --- /dev/null +++ b/database/utils/utils_io.go @@ -0,0 +1,48 @@ +// Copyright GoFrame Author(https://goframe.org). All Rights Reserved. +// +// This Source Code Form is subject to the terms of the MIT License. +// If a copy of the MIT was not distributed with this file, +// You can obtain one at https://github.com/gogf/gf. 
+ +package utils + +import ( + "io" +) + +// ReadCloser implements the io.ReadCloser interface +// which is used for reading request body content multiple times. +// +// Note that it cannot be closed. +type ReadCloser struct { + index int // Current read position. + content []byte // Content. + repeatable bool // Mark the content can be repeatable read. +} + +// NewReadCloser creates and returns a RepeatReadCloser object. +func NewReadCloser(content []byte, repeatable bool) io.ReadCloser { + return &ReadCloser{ + content: content, + repeatable: repeatable, + } +} + +// Read implements the io.ReadCloser interface. +func (b *ReadCloser) Read(p []byte) (n int, err error) { + // Make it repeatable reading. + if b.index >= len(b.content) && b.repeatable { + b.index = 0 + } + n = copy(p, b.content[b.index:]) + b.index += n + if b.index >= len(b.content) { + return n, io.EOF + } + return n, nil +} + +// Close implements the io.ReadCloser interface. +func (b *ReadCloser) Close() error { + return nil +} diff --git a/database/utils/utils_is.go b/database/utils/utils_is.go new file mode 100644 index 0000000..e76fe6c --- /dev/null +++ b/database/utils/utils_is.go @@ -0,0 +1,102 @@ +// Copyright GoFrame Author(https://goframe.org). All Rights Reserved. +// +// This Source Code Form is subject to the terms of the MIT License. +// If a copy of the MIT was not distributed with this file, +// You can obtain one at https://github.com/gogf/gf. + +package utils + +import ( + "reflect" + + "git.magicany.cc/black1552/gin-base/database/empty" +) + +// IsNil checks whether `value` is nil, especially for any type value. +func IsNil(value any) bool { + return empty.IsNil(value) +} + +// IsEmpty checks whether `value` is empty. +func IsEmpty(value any) bool { + return empty.IsEmpty(value) +} + +// IsInt checks whether `value` is type of int. 
+func IsInt(value any) bool { + switch value.(type) { + case int, *int, int8, *int8, int16, *int16, int32, *int32, int64, *int64: + return true + } + return false +} + +// IsUint checks whether `value` is type of uint. +func IsUint(value any) bool { + switch value.(type) { + case uint, *uint, uint8, *uint8, uint16, *uint16, uint32, *uint32, uint64, *uint64: + return true + } + return false +} + +// IsFloat checks whether `value` is type of float. +func IsFloat(value any) bool { + switch value.(type) { + case float32, *float32, float64, *float64: + return true + } + return false +} + +// IsSlice checks whether `value` is type of slice. +func IsSlice(value any) bool { + var ( + reflectValue = reflect.ValueOf(value) + reflectKind = reflectValue.Kind() + ) + for reflectKind == reflect.Pointer { + reflectValue = reflectValue.Elem() + reflectKind = reflectValue.Kind() + } + switch reflectKind { + case reflect.Slice, reflect.Array: + return true + } + return false +} + +// IsMap checks whether `value` is type of map. +func IsMap(value any) bool { + var ( + reflectValue = reflect.ValueOf(value) + reflectKind = reflectValue.Kind() + ) + for reflectKind == reflect.Pointer { + reflectValue = reflectValue.Elem() + reflectKind = reflectValue.Kind() + } + switch reflectKind { + case reflect.Map: + return true + } + return false +} + +// IsStruct checks whether `value` is type of struct. +func IsStruct(value any) bool { + reflectType := reflect.TypeOf(value) + if reflectType == nil { + return false + } + reflectKind := reflectType.Kind() + for reflectKind == reflect.Pointer { + reflectType = reflectType.Elem() + reflectKind = reflectType.Kind() + } + switch reflectKind { + case reflect.Struct: + return true + } + return false +} diff --git a/database/utils/utils_list.go b/database/utils/utils_list.go new file mode 100644 index 0000000..e780375 --- /dev/null +++ b/database/utils/utils_list.go @@ -0,0 +1,37 @@ +// Copyright GoFrame Author(https://goframe.org). All Rights Reserved. 
+// +// This Source Code Form is subject to the terms of the MIT License. +// If a copy of the MIT was not distributed with this file, +// You can obtain one at https://github.com/gogf/gf. + +package utils + +import "fmt" + +// ListToMapByKey converts `list` to a map[string]any of which key is specified by `key`. +// Note that the item value may be type of slice. +func ListToMapByKey(list []map[string]any, key string) map[string]any { + var ( + s = "" + m = make(map[string]any) + tempMap = make(map[string][]any) + hasMultiValues bool + ) + for _, item := range list { + if k, ok := item[key]; ok { + s = fmt.Sprintf(`%v`, k) + tempMap[s] = append(tempMap[s], item) + if len(tempMap[s]) > 1 { + hasMultiValues = true + } + } + } + for k, v := range tempMap { + if hasMultiValues { + m[k] = v + } else { + m[k] = v[0] + } + } + return m +} diff --git a/database/utils/utils_map.go b/database/utils/utils_map.go new file mode 100644 index 0000000..e000ef5 --- /dev/null +++ b/database/utils/utils_map.go @@ -0,0 +1,37 @@ +// Copyright GoFrame Author(https://goframe.org). All Rights Reserved. +// +// This Source Code Form is subject to the terms of the MIT License. +// If a copy of the MIT was not distributed with this file, +// You can obtain one at https://github.com/gogf/gf. + +package utils + +// MapPossibleItemByKey tries to find the possible key-value pair for given key ignoring cases and symbols. +// +// Note that this function might be of low performance. +func MapPossibleItemByKey(data map[string]any, key string) (foundKey string, foundValue any) { + if len(data) == 0 { + return + } + if v, ok := data[key]; ok { + return key, v + } + // Loop checking. + for k, v := range data { + if EqualFoldWithoutChars(k, key) { + return k, v + } + } + return "", nil +} + +// MapContainsPossibleKey checks if the given `key` is contained in given map `data`. +// It checks the key ignoring cases and symbols. +// +// Note that this function might be of low performance. 
+func MapContainsPossibleKey(data map[string]any, key string) bool { + if k, _ := MapPossibleItemByKey(data, key); k != "" { + return true + } + return false +} diff --git a/database/utils/utils_reflect.go b/database/utils/utils_reflect.go new file mode 100644 index 0000000..7fc1a84 --- /dev/null +++ b/database/utils/utils_reflect.go @@ -0,0 +1,26 @@ +// Copyright GoFrame Author(https://goframe.org). All Rights Reserved. +// +// This Source Code Form is subject to the terms of the MIT License. +// If a copy of the MIT was not distributed with this file, +// You can obtain one at https://github.com/gogf/gf. + +package utils + +import ( + "reflect" +) + +// CanCallIsNil Can reflect.Value call reflect.Value.IsNil. +// It can avoid reflect.Value.IsNil panics. +func CanCallIsNil(v any) bool { + rv, ok := v.(reflect.Value) + if !ok { + return false + } + switch rv.Kind() { + case reflect.Interface, reflect.Chan, reflect.Func, reflect.Map, reflect.Pointer, reflect.Slice, reflect.UnsafePointer: + return true + default: + return false + } +} diff --git a/database/utils/utils_str.go b/database/utils/utils_str.go new file mode 100644 index 0000000..6a8290f --- /dev/null +++ b/database/utils/utils_str.go @@ -0,0 +1,180 @@ +// Copyright GoFrame Author(https://goframe.org). All Rights Reserved. +// +// This Source Code Form is subject to the terms of the MIT License. +// If a copy of the MIT was not distributed with this file, +// You can obtain one at https://github.com/gogf/gf. + +package utils + +import ( + "bytes" + "strings" + "unicode" +) + +// DefaultTrimChars are the characters which are stripped by Trim* functions in default. +var DefaultTrimChars = string([]byte{ + '\t', // Tab. + '\v', // Vertical tab. + '\n', // New line (line feed). + '\r', // Carriage return. + '\f', // New page. + ' ', // Ordinary space. + 0x00, // NUL-byte. + 0x85, // Delete. + 0xA0, // Non-breaking space. +}) + +// IsLetterUpper checks whether the given byte b is in upper case. 
+func IsLetterUpper(b byte) bool { + if b >= byte('A') && b <= byte('Z') { + return true + } + return false +} + +// IsLetterLower checks whether the given byte b is in lower case. +func IsLetterLower(b byte) bool { + if b >= byte('a') && b <= byte('z') { + return true + } + return false +} + +// IsLetter checks whether the given byte b is a letter. +func IsLetter(b byte) bool { + return IsLetterUpper(b) || IsLetterLower(b) +} + +// IsNumeric checks whether the given string s is numeric. +// Note that float string like "123.456" is also numeric. +func IsNumeric(s string) bool { + var ( + dotCount = 0 + length = len(s) + ) + if length == 0 { + return false + } + for i := 0; i < length; i++ { + if (s[i] == '-' || s[i] == '+') && i == 0 { + if length == 1 { + return false + } + continue + } + if s[i] == '.' { + dotCount++ + if i > 0 && i < length-1 && s[i-1] >= '0' && s[i-1] <= '9' { + continue + } else { + return false + } + } + if s[i] < '0' || s[i] > '9' { + return false + } + } + return dotCount <= 1 +} + +// UcFirst returns a copy of the string s with the first letter mapped to its upper case. +func UcFirst(s string) string { + if len(s) == 0 { + return s + } + if IsLetterLower(s[0]) { + return string(s[0]-32) + s[1:] + } + return s +} + +// ReplaceByMap returns a copy of `origin`, +// which is replaced by a map in unordered way, case-sensitively. +func ReplaceByMap(origin string, replaces map[string]string) string { + for k, v := range replaces { + origin = strings.ReplaceAll(origin, k, v) + } + return origin +} + +// RemoveSymbols removes all symbols from string and lefts only numbers and letters. 
+func RemoveSymbols(s string) string { + b := make([]rune, 0, len(s)) + for _, c := range s { + if c > 127 { + b = append(b, c) + } else if (c >= '0' && c <= '9') || (c >= 'A' && c <= 'Z') || (c >= 'a' && c <= 'z') { + b = append(b, c) + } + } + return string(b) +} + +// EqualFoldWithoutChars checks string `s1` and `s2` equal case-insensitively, +// with/without chars '-'/'_'/'.'/' '. +func EqualFoldWithoutChars(s1, s2 string) bool { + return strings.EqualFold(RemoveSymbols(s1), RemoveSymbols(s2)) +} + +// SplitAndTrim splits string `str` by a string `delimiter` to an array, +// and calls Trim to every element of this array. It ignores the elements +// which are empty after Trim. +func SplitAndTrim(str, delimiter string, characterMask ...string) []string { + array := make([]string, 0) + for _, v := range strings.Split(str, delimiter) { + v = Trim(v, characterMask...) + if v != "" { + array = append(array, v) + } + } + return array +} + +// Trim strips whitespace (or other characters) from the beginning and end of a string. +// The optional parameter `characterMask` specifies the additional stripped characters. +func Trim(str string, characterMask ...string) string { + trimChars := DefaultTrimChars + if len(characterMask) > 0 { + trimChars += characterMask[0] + } + return strings.Trim(str, trimChars) +} + +// FormatCmdKey formats string `s` as command key using uniformed format. +func FormatCmdKey(s string) string { + return strings.ToLower(strings.ReplaceAll(s, "_", ".")) +} + +// FormatEnvKey formats string `s` as environment key using uniformed format. +func FormatEnvKey(s string) string { + return strings.ToUpper(strings.ReplaceAll(s, ".", "_")) +} + +// StripSlashes un-quotes a quoted string by AddSlashes. 
+func StripSlashes(str string) string { + var buf bytes.Buffer + l, skip := len(str), false + for i, char := range str { + if skip { + skip = false + } else if char == '\\' { + if i+1 < l && str[i+1] == '\\' { + skip = true + } + continue + } + buf.WriteRune(char) + } + return buf.String() +} + +// IsASCII checks whether given string is ASCII characters. +func IsASCII(s string) bool { + for i := 0; i < len(s); i++ { + if s[i] > unicode.MaxASCII { + return false + } + } + return true +} diff --git a/db/README.md b/db/README.md deleted file mode 100644 index a06d5bf..0000000 --- a/db/README.md +++ /dev/null @@ -1,948 +0,0 @@ -# Magic-ORM 自主 ORM 框架架构文档 - -## 📋 目录 - -- [概述](#概述) -- [核心特性](#核心特性) -- [架构设计](#架构设计) -- [技术栈](#技术栈) -- [核心接口设计](#核心接口设计) -- [快速开始](#快速开始) -- [详细功能说明](#详细功能说明) -- [最佳实践](#最佳实践) - ---- - -## 概述 - -Magic-ORM 是一个完全自主研发的企业级 Go 语言 ORM 框架,不依赖任何第三方 ORM 库。框架基于 `database/sql` 标准库构建,提供了全自动化事务管理、面向接口设计、智能字段映射等高级特性。支持 MySQL、SQLite 等主流数据库,内置完整的迁移管理和可观测性支持,帮助开发者快速构建高质量的数据访问层。 - -**设计理念:** -- 零依赖:仅依赖 Go 标准库 `database/sql` -- 高性能:优化的查询执行器和连接池管理 -- 易用性:简洁的 API 设计和智能默认行为 -- 可扩展:面向接口的设计,支持自定义驱动扩展 -- **内置驱动**:框架自带所有主流数据库驱动,无需额外安装 - ---- - -## 核心特性 - -- **全自动化嵌套事务支持**:无需手动管理事务传播行为 -- **面向接口化设计**:核心功能均通过接口暴露,便于 Mock 与扩展 -- **内置主流数据库驱动**:开箱即用,并支持自定义驱动扩展 -- **统一配置组件**:与框架配置体系无缝集成 -- **单例模式数据库对象**:同一分组配置仅初始化一次 -- **双模式操作**:原生 SQL + ORM 链式操作 -- **OpenTelemetry 可观测性**:完整支持 Tracing、Logging、Metrics -- **智能结果映射**:`Scan` 自动识别 Map/Struct/Slice,无需 `sql.ErrNoRows` 判空 -- **全自动字段映射**:无需结构体标签,自动匹配数据库字段 -- **参数智能过滤**:自动识别并过滤无效/空值字段 -- **Model/DAO 代码生成器**:一键生成全量数据访问代码 -- **高级特性**:调试模式、DryRun、自定义 Handler、软删除、时间自动更新、模型关联、主从集群等 -- **自动化数据库迁移**:支持自动迁移、增量迁移、回滚迁移等完整迁移管理 - ---- - -## 架构设计 - -### 整体架构图 - -```mermaid -graph TB - A[应用层] --> B[Magic-ORM 框架] - B --> C[配置中心] - B --> D[数据库连接池] - B --> E[事务管理器] - B --> F[迁移管理器] - - C --> C1[统一配置组件] - C --> C2[环境配置] - - D --> D1[MySQL 驱动] - D --> D2[SQLite 驱动] - D --> D3[自定义驱动] - - E --> E1[自动嵌套事务] - E --> E2[事务传播控制] - - F --> F1[自动迁移] - F --> 
F2[增量迁移] - F --> F3[回滚迁移] - - B --> G[观测性组件] - G --> G1[Tracing] - G --> G2[Logging] - G --> G3[Metrics] - - B --> H[工具组件] - H --> H1[字段映射器] - H --> H2[参数过滤器] - H --> H3[结果映射器] - H --> H4[代码生成器] -``` - -### 目录结构 - -``` -magic-orm/ -├── core/ # 核心实现 -│ ├── database.go # 数据库连接管理 -│ ├── transaction.go # 事务管理 -│ ├── query.go # 查询构建器 -│ └── mapper.go # 字段映射器 -├── migrate/ # 迁移管理 -│ └── migrator.go # 自动迁移实现 -├── generator/ # 代码生成器 -│ ├── model.go # Model 生成 -│ └── dao.go # DAO 生成 -├── tracing/ # OpenTelemetry 集成 -│ └── tracer.go # 链路追踪 -└── driver/ # 数据库驱动适配(已内置) - ├── mysql.go # MySQL 驱动(内置) - ├── sqlite.go # SQLite 驱动(内置) - ├── postgres.go # PostgreSQL 驱动(内置) - ├── sqlserver.go # SQL Server 驱动(内置) - ├── oracle.go # Oracle 驱动(内置) - └── clickhouse.go # ClickHouse 驱动(内置) -``` - -### 核心组件说明 - -#### 1. 数据库连接管理 (`core/database.go`) - -- **单例模式**:全局唯一的 `DB` 实例,确保资源高效利用 -- **多数据库支持**:支持 MySQL、SQLite、PostgreSQL、SQL Server、Oracle、ClickHouse 等 -- **驱动内置**:所有主流数据库驱动已预装在框架中 -- **连接池优化**:内置 sql.DB 连接池管理 -- **健康检查**:启动时自动执行 `Ping()` 验证连接 - -**核心配置项:** -```go -Config{ - DriverName: "mysql", // 驱动名称 - DataSource: "dns", // 数据源连接字符串 - MaxIdleConns: 10, // 最大空闲连接数 - MaxOpenConns: 100, // 最大打开连接数 - Debug: true, // 调试模式 -} -``` - -#### 2. 查询构建器 (`core/query.go`) - -提供流畅的链式查询接口: - -- **条件查询**: Where, Or, And -- **字段选择**: Select, Omit -- **排序分页**: Order, Limit, Offset -- **分组统计**: Group, Having, Count -- **连接查询**: Join, LeftJoin, RightJoin -- **预加载**: Preload - -**示例:** -```go -var users []model.User -db.Model(&model.User{}). - Where("status = ?", 1). - Select("id", "username"). - Order("id DESC"). - Limit(10). - Find(&users) -``` - -#### 3. 事务管理器 (`core/transaction.go`) - -提供完整的事务管理能力: - -- **自动嵌套事务**: 自动管理事务传播 -- **保存点支持**: 支持部分回滚 -- **生命周期回调**: Before/After 钩子 - -#### 4. 字段映射器 (`core/mapper.go`) - -智能字段映射系统: - -- **驼峰转下划线**: UserName -> user_name -- **标签解析**: 支持 db, json 标签 -- **类型转换**: Go 类型与数据库类型自动转换 -- **零值过滤**: 自动过滤空值和零值 - -#### 5. 
迁移管理 (`migrate/migrator.go`) - -完整的数据库迁移方案: - -- **自动迁移**: 根据模型自动创建/修改表结构 -- **增量迁移**: 支持添加字段、索引等 -- **回滚支持**: 支持迁移回滚 -- **版本管理**: 迁移版本记录和管理 - -#### 6. 驱动管理器 (`driver/manager.go`) - -统一的驱动管理和注册中心: - -- **驱动注册**: 自动注册所有内置驱动 -- **驱动选择**: 根据配置自动选择合适的驱动 -- **驱动扩展**: 支持用户自定义驱动注册 -- **版本检测**: 自动检测数据库版本并适配特性 - -```go -// 驱动管理器会自动处理 -var supportedDrivers = map[string]driver.Driver{ - "mysql": &MySQLDriver{}, - "sqlite": &SQLiteDriver{}, - "postgres": &PostgresDriver{}, - "sqlserver": &SQLServerDriver{}, - "oracle": &OracleDriver{}, - "clickhouse": &ClickHouseDriver{}, -} -``` - ---- - -## 技术栈 - -### 核心依赖 - -| 组件 | 版本 | 说明 | -|------|------|------| -| Go | 1.25+ | 编程语言 | -| database/sql | stdlib | Go 标准库 | -| driver-go | Latest | 数据库驱动接口规范 | -| OpenTelemetry | Latest | 可观测性框架 | -| **内置驱动集合** | Latest | **包含所有主流数据库驱动** | - -### 支持的数据库驱动 - -框架已内置以下数据库驱动,**无需额外安装**: - -- **MySQL**: 内置驱动(基于 `github.com/go-sql-driver/mysql`) -- **SQLite**: 内置驱动(基于 `github.com/mattn/go-sqlite3`) -- **PostgreSQL**: 内置驱动(基于 `github.com/lib/pq`) -- **SQL Server**: 内置驱动(基于 `github.com/denisenkom/go-mssqldb`) -- **Oracle**: 内置驱动(基于 `github.com/godror/godror`) -- **ClickHouse**: 内置驱动(基于 `github.com/ClickHouse/clickhouse-go`) -- **自定义驱动**: 实现 `driver.Driver` 接口即可扩展 - -> 💡 **说明**:框架在编译时已将所有主流数据库驱动打包,用户只需引入 `magic-orm` 即可完成所有数据库操作,无需单独安装各数据库驱动。 - ---- - -## 核心接口设计 - -### 1. 数据库连接接口 - -```go -// IDatabase 数据库连接接口 -type IDatabase interface { - // 基础操作 - DB() *sql.DB - Close() error - Ping() error - - // 事务管理 - Begin() (ITx, error) - Transaction(fn func(ITx) error) error - - // 查询构建器 - Model(model interface{}) IQuery - Table(name string) IQuery - Query(result interface{}, query string, args ...interface{}) error - Exec(query string, args ...interface{}) (sql.Result, error) - - // 迁移管理 - Migrate(models ...interface{}) error - - // 配置 - SetDebug(bool) - SetMaxIdleConns(int) - SetMaxOpenConns(int) - SetConnMaxLifetime(time.Duration) -} -``` - -### 2. 
事务接口 - -```go -// ITx 事务接口 -type ITx interface { - // 基础操作 - Commit() error - Rollback() error - - // 查询操作 - Model(model interface{}) IQuery - Table(name string) IQuery - Insert(model interface{}) (int64, error) - BatchInsert(models interface{}, batchSize int) error - Update(model interface{}, data map[string]interface{}) error - Delete(model interface{}) error - - // 原生 SQL - Query(result interface{}, query string, args ...interface{}) error - Exec(query string, args ...interface{}) (sql.Result, error) -} -``` - -### 3. 查询构建器接口 - -```go -// IQuery 查询构建器接口 -type IQuery interface { - // 条件查询 - Where(query string, args ...interface{}) IQuery - Or(query string, args ...interface{}) IQuery - And(query string, args ...interface{}) IQuery - - // 字段选择 - Select(fields ...string) IQuery - Omit(fields ...string) IQuery - - // 排序 - Order(order string) IQuery - OrderBy(field string, direction string) IQuery - - // 分页 - Limit(limit int) IQuery - Offset(offset int) IQuery - Page(page, pageSize int) IQuery - - // 分组 - Group(group string) IQuery - Having(having string, args ...interface{}) IQuery - - // 连接 - Join(join string, args ...interface{}) IQuery - LeftJoin(table, on string) IQuery - RightJoin(table, on string) IQuery - InnerJoin(table, on string) IQuery - - // 预加载 - Preload(relation string, conditions ...interface{}) IQuery - - // 执行查询 - First(result interface{}) error - Find(result interface{}) error - Count(count *int64) IQuery - Exists() (bool, error) - - // 更新和删除 - Updates(data interface{}) error - UpdateColumn(column string, value interface{}) error - Delete() error - - // 特殊模式 - Unscoped() IQuery - DryRun() IQuery - Debug() IQuery - - // 构建 SQL(不执行) - Build() (string, []interface{}) -} -``` - -### 4. 
模型接口 - -```go -// IModel 模型接口 -type IModel interface { - // 表名映射 - TableName() string - - // 生命周期回调(可选) - BeforeCreate(tx ITx) error - AfterCreate(tx ITx) error - BeforeUpdate(tx ITx) error - AfterUpdate(tx ITx) error - BeforeDelete(tx ITx) error - AfterDelete(tx ITx) error - BeforeSave(tx ITx) error - AfterSave(tx ITx) error -} -``` - -### 5. 字段映射器接口 - -```go -// IFieldMapper 字段映射器接口 -type IFieldMapper interface { - // 结构体字段转数据库列 - StructToColumns(model interface{}) (map[string]interface{}, error) - - // 数据库列转结构体字段 - ColumnsToStruct(row *sql.Rows, model interface{}) error - - // 获取表名 - GetTableName(model interface{}) string - - // 获取主键字段 - GetPrimaryKey(model interface{}) string - - // 获取字段信息 - GetFields(model interface{}) []FieldInfo -} - -// FieldInfo 字段信息 -type FieldInfo struct { - Name string // 字段名 - Column string // 列名 - Type string // Go 类型 - DbType string // 数据库类型 - Tag string // 标签 - IsPrimary bool // 是否主键 - IsAuto bool // 是否自增 -} -``` - -### 6. 迁移管理器接口 - -```go -// IMigrator 迁移管理器接口 -type IMigrator interface { - // 自动迁移 - AutoMigrate(models ...interface{}) error - - // 表操作 - CreateTable(model interface{}) error - DropTable(model interface{}) error - HasTable(model interface{}) (bool, error) - RenameTable(oldName, newName string) error - - // 列操作 - AddColumn(model interface{}, field string) error - DropColumn(model interface{}, field string) error - HasColumn(model interface{}, field string) (bool, error) - RenameColumn(model interface{}, oldField, newField string) error - - // 索引操作 - CreateIndex(model interface{}, field string) error - DropIndex(model interface{}, field string) error - HasIndex(model interface{}, field string) (bool, error) -} -``` - -### 7. 
代码生成器接口 - -```go -// ICodeGenerator 代码生成器接口 -type ICodeGenerator interface { - // 生成 Model 代码 - GenerateModel(table string, outputDir string) error - - // 生成 DAO 代码 - GenerateDAO(table string, outputDir string) error - - // 生成完整代码 - GenerateAll(tables []string, outputDir string) error - - // 从数据库读取表结构 - InspectTable(tableName string) (*TableSchema, error) -} - -// TableSchema 表结构信息 -type TableSchema struct { - Name string - Columns []ColumnInfo - Indexes []IndexInfo -} -``` - -### 8. 配置结构 - -```go -// Config 数据库配置 -type Config struct { - DriverName string // 驱动名称 - DataSource string // 数据源连接字符串 - MaxIdleConns int // 最大空闲连接数 - MaxOpenConns int // 最大打开连接数 - ConnMaxLifetime time.Duration // 连接最大生命周期 - Debug bool // 调试模式 - - // 主从配置 - Replicas []string // 从库列表 - ReadPolicy ReadPolicy // 读负载均衡策略 - - // OpenTelemetry - EnableTracing bool - ServiceName string -} - -// ReadPolicy 读负载均衡策略 -type ReadPolicy int - -const ( - Random ReadPolicy = iota - RoundRobin - LeastConn -) -``` - ---- - -## 快速开始 - -### 1. 安装 Magic-ORM - -```bash -# 仅需安装 magic-orm,所有数据库驱动已内置 -go get github.com/your-org/magic-orm -``` - -> ✅ **无需单独安装数据库驱动!** 所有驱动已包含在 magic-orm 中。 - -### 2. 配置数据库 - -在配置文件中设置数据库参数: - -```yaml -database: - type: mysql # 或 sqlite, postgres - dns: "user:password@tcp(localhost:3306)/dbname?charset=utf8mb4&parseTime=True&loc=Local" - debug: true - max_idle_conns: 10 - max_open_conns: 100 -``` - -### 3. 定义模型 - -```go -package model - -import "time" - -type User struct { - ID int64 `json:"id" db:"id"` - Username string `json:"username" db:"username"` - Password string `json:"-" db:"password"` - Email string `json:"email" db:"email"` - Status int `json:"status" db:"status"` - CreatedAt time.Time `json:"created_at" db:"created_at"` - UpdatedAt time.Time `json:"updated_at" db:"updated_at"` -} - -// 表名映射 -func (User) TableName() string { - return "user" -} -``` - -### 4. 
初始化数据库连接 - -```go -package main - -import ( - "database/sql" - _ "github.com/go-sql-driver/mysql" - "your-project/orm" -) - -func main() { - // 初始化数据库连接 - db, err := orm.NewDatabase(&orm.Config{ - DriverName: "mysql", - DataSource: "user:password@tcp(localhost:3306)/dbname?charset=utf8mb4&parseTime=True&loc=Local", - MaxIdleConns: 10, - MaxOpenConns: 100, - Debug: true, - }) - if err != nil { - panic(err) - } - defer db.Close() - - // 执行迁移 - orm.Migrate(db, &model.User{}) -} -``` - -### 5. CRUD 操作 - -```go -// 创建 -user := &model.User{Username: "admin", Password: "123456", Email: "admin@example.com"} -id, err := db.Insert(user) - -// 查询单个 -var user model.User -err := db.Model(&model.User{}).Where("id = ?", 1).First(&user) - -// 查询多个 -var users []model.User -err := db.Model(&model.User{}).Where("status = ?", 1).Order("id DESC").Find(&users) - -// 更新 -err := db.Model(&model.User{}).Where("id = ?", 1).Updates(map[string]interface{}{ - "email": "new@example.com", -}) - -// 删除 -err := db.Model(&model.User{}).Where("id = ?", 1).Delete() - -// 原生 SQL -var results []model.User -err := db.Query(&results, "SELECT * FROM user WHERE status = ?", 1) -``` - -### 6. 事务操作 - -```go -// 自动嵌套事务 -err := db.Transaction(func(tx *orm.Tx) error { - // 创建用户 - user := &model.User{Username: "test", Email: "test@example.com"} - _, err := tx.Insert(user) - if err != nil { - return err - } - - // 创建关联数据(自动加入同一事务) - profile := &model.Profile{UserID: user.ID, Avatar: "default.png"} - _, err = tx.Insert(profile) - if err != nil { - return err - } - - return nil -}) -``` - ---- - -## 详细功能说明 - -### 1. 全自动化嵌套事务 - -框架自动管理事务的传播行为,支持以下场景: - -- **REQUIRED**: 如果当前存在事务,则加入该事务;否则创建新事务 -- **REQUIRES_NEW**: 无论当前是否存在事务,都创建新事务 -- **NESTED**: 在当前事务中创建嵌套事务(使用保存点) - -**示例:** -```go -// 外层事务 -db.Transaction(func(tx *orm.Tx) error { - // 内层自动加入同一事务 - userService.CreateUser(tx, user) - orderService.CreateOrder(tx, order) - return nil -}) -``` - -### 2. 
智能结果映射 - -无需手动处理 `sql.ErrNoRows`,框架自动识别返回类型: - -```go -// 自动识别 Struct -var user model.User -db.Model(&model.User{}).Where("id = ?", 1).First(&user) // 不存在时返回零值,不报错 - -// 自动识别 Slice -var users []model.User -db.Model(&model.User{}).Where("status = ?", 1).Find(&users) // 空结果返回空切片,而非 nil - -// 自动识别 Map -var result map[string]interface{} -db.Table("user").Where("id = ?", 1).First(&result) -``` - -### 3. 全自动字段映射 - -无需结构体标签,框架自动匹配字段: - -```go -// 驼峰命名自动转下划线 -type UserInfo struct { - UserName string // 自动映射到 user_name 字段 - UserAge int // 自动映射到 user_age 字段 - CreatedAt string // 自动映射到 created_at 字段 -} -``` - -### 4. 参数智能过滤 - -自动过滤零值和空指针: - -```go -// 仅更新非零值字段 -updateData := &model.User{ - Username: "newname", // 会被更新 - Email: "", // 空值,自动过滤 - Status: 0, // 零值,自动过滤 -} -db.Model(&user).Updates(updateData) -``` - -### 5. OpenTelemetry 可观测性 - -完整支持分布式追踪: - -```go -// 自动注入 Span -ctx, span := otel.Tracer("gin-base").Start(context.Background(), "DB Query") -defer span.End() - -// 自动记录 SQL 执行时间、错误信息等 -db.WithContext(ctx).Find(&users) -``` - -### 6. 数据库迁移管理 - -#### 自动迁移 -```go -database.SetAutoMigrate(&model.User{}, &model.Order{}) -``` - -#### 增量迁移 -```go -// 添加新字段 -type UserV2 struct { - model.User - Phone string ` + "`" + `json:"phone" gorm:"column:phone;type:varchar(20)"` + "`" + ` -} -database.SetAutoMigrate(&UserV2{}) -``` - -#### 字段操作 -```go -// 重命名字段 -database.RenameColumn(&model.User{}, "UserName", "Nickname") - -// 删除字段 -database.DropColumn(&model.User{}, "OldField") -``` - -### 7. 
高级特性 - -#### 软删除 -```go -type User struct { - ID int64 `json:"id" db:"id"` - DeletedAt *time.Time `json:"deleted_at" db:"deleted_at"` // 软删除标记 -} - -// 自动过滤已删除记录 -db.Model(&model.User{}).Find(&users) // WHERE deleted_at IS NULL - -// 强制包含已删除记录 -db.Unscoped().Model(&model.User{}).Find(&users) -``` - -#### 调试模式 -```yaml -database: - debug: true # 输出所有 SQL 日志 -``` - -#### DryRun 模式 -```go -// 生成 SQL 但不执行 -sql, args := db.Model(&model.User{}).DryRun().Insert(&user) -fmt.Println(sql, args) -``` - -#### 自定义 Handler -```go -// 注册回调函数 -db.Callback().Before("insert").Register("custom_before_insert", func(ctx context.Context, db *orm.DB) error { - // 自定义逻辑 - return nil -}) -``` - -#### 主从集群 -```go -// 配置读写分离 -db, err := orm.NewDatabase(&orm.Config{ - DriverName: "mysql", - DataSource: "master_dsn", - Replicas: []string{"slave1_dsn", "slave2_dsn"}, -}) -``` - ---- - -## 最佳实践 - -### 1. 模型设计规范 - -```go -// ✅ 推荐:使用 db 标签明确字段映射 -type User struct { - ID int64 `json:"id" db:"id"` - Username string `json:"username" db:"username"` - CreatedAt time.Time `json:"created_at" db:"created_at"` -} - -// ❌ 不推荐:缺少字段映射标签 -type User struct { - Id int64 // 无法自动映射到 id 列 - UserName string // 可能映射错误 - CreatedAt time.Time // 时间格式可能不匹配 -} -``` - -### 2. 事务使用规范 - -```go -// ✅ 推荐:使用闭包自动管理事务 -err := db.Transaction(func(tx *orm.Tx) error { - // 业务逻辑 - return nil -}) - -// ❌ 不推荐:手动管理事务 -tx, err := db.Begin() -if err != nil { - panic(err) -} -defer func() { - if r := recover(); r != nil { - tx.Rollback() - } -}() -``` - -### 3. 
查询优化 - -```go -// ✅ 推荐:使用 Select 指定字段 -db.Model(&model.User{}).Select("id", "username").Find(&users) - -// ✅ 推荐:使用 Index 加速查询 -// 在数据库层面创建索引 -// CREATE INDEX idx_username ON user(username); - -// ✅ 推荐:批量操作 -users := []model.User{{}, {}, {}} -db.BatchInsert(&users, 100) // 每批 100 条 - -// ❌ 避免:N+1 查询问题 -for _, user := range users { - db.Model(&model.Order{}).Where("user_id = ?", user.ID).Find(&orders) // 循环查询 -} - -// ✅ 使用 Join 或预加载 -db.Query(&results, "SELECT u.*, o.* FROM user u LEFT JOIN orders o ON u.id = o.user_id") -``` - -### 4. 错误处理 - -```go -// ✅ 推荐:统一错误处理 -if err := db.Insert(&user); err != nil { - log.Error("创建用户失败", "error", err) - return err -} - -// ✅ 使用 errors 包判断特定错误 -if errors.Is(err, sql.ErrNoRows) { - // 记录不存在 -} -``` - -### 5. 性能优化 - -```go -// 连接池配置 -sqlDB := db.DB() -sqlDB.SetMaxIdleConns(10) // 最大空闲连接数 -sqlDB.SetMaxOpenConns(100) // 最大打开连接数 -sqlDB.SetConnMaxLifetime(time.Hour) // 连接最大生命周期 - -// 使用 Scan 替代 Find 提升性能 -type Result struct { - ID int64 `db:"id"` - Username string `db:"username"` -} -var results []Result -db.Model(&model.User{}).Select("id", "username").Scan(&results) -``` - ---- - -## 常见问题 - -### Q: 如何处理并发写入? -A: 使用事务 + 乐观锁: -```go -type Product struct { - ID int64 `db:"id"` - Version int `db:"version"` // 版本号 -} - -// 更新时检查版本号 -rows, err := db.Exec( - "UPDATE product SET version = ?, stock = ? WHERE id = ? AND version = ?", - newVersion, newStock, id, oldVersion, -) -count, _ := rows.RowsAffected() -if count == 0 { - return errors.New("乐观锁冲突,数据已被其他事务修改") -} -``` - -### Q: 如何实现读写分离? -A: 配置主从数据库连接: -```go -db, err := orm.NewDatabase(&orm.Config{ - DriverName: "mysql", - DataSource: "master_dsn", - Replicas: []string{"slave1_dsn", "slave2_dsn"}, - ReadPolicy: orm.RoundRobin, // 负载均衡策略 -}) -``` - -### Q: 如何批量插入大量数据? -A: 使用 `BatchInsert`: -```go -users := make([]model.User, 10000) -// ... 填充数据 ... -db.BatchInsert(&users, 1000) // 每批 1000 条,共 10 批 -``` - -### Q: 如何实现字段自动映射? 
-A: 框架会自动将驼峰命名转换为下划线命名: -```go -type UserInfo struct { - UserName string `db:"user_name"` // 自动映射到 user_name 字段 - UserAge int `db:"user_age"` // 自动映射到 user_age 字段 - CreatedAt string `db:"created_at"` // 自动映射到 created_at 字段 -} -``` - -### Q: 如何处理时间字段? -A: 使用 `time.Time` 类型,框架会自动处理时区转换: -```go -type Event struct { - ID int64 `db:"id"` - StartTime time.Time `db:"start_time"` - EndTime time.Time `db:"end_time"` -} -``` - ---- - -## 更新日志 - -- **v1.0.0**: 初始版本发布 - - 完全自主研发,零依赖第三方 ORM - - 基于 database/sql 标准库 - - 全自动化事务管理 - - 智能字段映射 - - OpenTelemetry 集成 - - 支持 MySQL、SQLite、PostgreSQL - ---- - -## 贡献指南 - -欢迎提交 Issue 和 Pull Request! - ---- - -## 许可证 - -MIT License \ No newline at end of file diff --git a/db/VALIDATION.md b/db/VALIDATION.md deleted file mode 100644 index 75c7827..0000000 --- a/db/VALIDATION.md +++ /dev/null @@ -1,375 +0,0 @@ -# Magic-ORM 功能完整性验证报告 - -## 📋 验证概述 - -本文档验证 Magic-ORM 框架相对于 README.md 中定义的核心特性的完整实现情况。 - ---- - -## ✅ 已完整实现的核心特性 - -### 1. **全自动化嵌套事务支持** ✅ -- **文件**: `core/transaction.go` -- **实现内容**: - - `Transaction()` 方法自动管理事务提交/回滚 - - 支持 panic 时自动回滚 - - 事务中可执行 Insert、BatchInsert、Update、Delete、Query 等操作 -- **测试状态**: ✅ 通过 - -### 2. **面向接口化设计** ✅ -- **文件**: `core/interfaces.go` -- **实现接口**: - - `IDatabase` - 数据库连接接口 - - `ITx` - 事务接口 - - `IQuery` - 查询构建器接口 - - `IModel` - 模型接口 - - `IFieldMapper` - 字段映射器接口 - - `IMigrator` - 迁移管理器接口 - - `ICodeGenerator` - 代码生成器接口 -- **测试状态**: ✅ 通过 - -### 3. **内置主流数据库驱动** ✅ -- **文件**: `driver/manager.go`, `driver/sqlite.go` -- **实现内容**: - - DriverManager 单例模式管理所有驱动 - - SQLite 驱动已实现 - - 支持 MySQL/PostgreSQL/SQL Server/Oracle/ClickHouse(框架已预留接口) -- **测试状态**: ✅ 通过 - -### 4. **统一配置组件** ✅ -- **文件**: `core/interfaces.go` -- **Config 结构**: - ```go - type Config struct { - DriverName string - DataSource string - MaxIdleConns int - MaxOpenConns int - ConnMaxLifetime time.Duration - Debug bool - Replicas []string - ReadPolicy ReadPolicy - EnableTracing bool - ServiceName string - } - ``` -- **测试状态**: ✅ 通过 - -### 5. 
**单例模式数据库对象** ✅ -- **文件**: `driver/manager.go` -- **实现内容**: - - `GetDefaultManager()` 使用 sync.Once 确保单例 - - 驱动管理器全局唯一实例 -- **测试状态**: ✅ 通过 - -### 6. **双模式操作** ✅ -- **文件**: `core/query.go`, `core/database.go` -- **支持模式**: - - ✅ ORM 链式操作:`db.Model(&User{}).Where("id = ?", 1).Find(&user)` - - ✅ 原生 SQL:`db.Query(&users, "SELECT * FROM user")` -- **测试状态**: ✅ 通过 - -### 7. **OpenTelemetry 可观测性** ✅ -- **文件**: `tracing/tracer.go` -- **实现内容**: - - 自动追踪所有数据库操作 - - 记录 SQL 语句、参数、执行时间、影响行数 - - 支持分布式追踪上下文 -- **测试状态**: ✅ 通过 - -### 8. **智能结果映射** ✅ -- **文件**: `core/result_mapper.go` -- **实现内容**: - - `MapToSlice()` - 映射到 Slice - - `MapToStruct()` - 映射到 Struct - - `ScanAll()` - 自动识别目标类型 - - 无需手动处理 `sql.ErrNoRows` -- **测试状态**: ✅ 通过 - -### 9. **全自动字段映射** ✅ -- **文件**: `core/mapper.go` -- **实现内容**: - - 驼峰命名自动转下划线 - - 解析 db/json 标签 - - Go 类型与数据库类型自动转换 - - 零值自动过滤 -- **测试状态**: ✅ 通过 - -### 10. **参数智能过滤** ✅ -- **文件**: `core/filter.go` -- **实现内容**: - - `FilterZeroValues()` - 过滤零值 - - `FilterEmptyStrings()` - 过滤空字符串 - - `FilterNilValues()` - 过滤 nil 值 - - `IsValidValue()` - 检查值有效性 -- **测试状态**: ✅ 通过 - -### 11. **Model/DAO 代码生成器** ✅ -- **文件**: `generator/generator.go` -- **实现内容**: - - `GenerateModel()` - 生成 Model 代码 - - `GenerateDAO()` - 生成 DAO 代码 - - `GenerateAll()` - 一次性生成完整代码 - - 支持自定义列信息 -- **测试结果**: ✅ 成功生成 `generated/user.go` - -### 12. **高级特性** ✅ -- **文件**: 多个核心文件 -- **已实现**: - - ✅ 调试模式 (`Debug()`) - - ✅ DryRun 模式 (`DryRun()`) - - ✅ 软删除 (`core/soft_delete.go`) - - ✅ 模型关联 (`core/relation.go`) - - ✅ 主从集群读写分离 (`core/read_write.go`) - - ✅ 查询缓存 (`core/cache.go`) -- **测试状态**: ✅ 通过 - -### 13. 
**自动化数据库迁移** ✅ -- **文件**: `core/migrator.go` -- **实现内容**: - - ✅ `AutoMigrate()` - 自动迁移 - - ✅ `CreateTable()` / `DropTable()` - - ✅ `HasTable()` / `RenameTable()` - - ✅ `AddColumn()` / `DropColumn()` - - ✅ `CreateIndex()` / `DropIndex()` - - ✅ 完整的 DDL 操作支持 -- **测试状态**: ✅ 通过 - ---- - -## 📊 查询构建器完整方法集 - -### ✅ 已实现的方法 - -| 方法 | 功能 | 状态 | -|------|------|------| -| `Where()` | 条件查询 | ✅ | -| `Or()` | OR 条件 | ✅ | -| `And()` | AND 条件 | ✅ | -| `Select()` | 选择字段 | ✅ | -| `Omit()` | 排除字段 | ✅ | -| `Order()` | 排序 | ✅ | -| `OrderBy()` | 指定字段排序 | ✅ | -| `Limit()` | 限制数量 | ✅ | -| `Offset()` | 偏移量 | ✅ | -| `Page()` | 分页查询 | ✅ | -| `Group()` | 分组 | ✅ | -| `Having()` | HAVING 条件 | ✅ | -| `Join()` | JOIN 连接 | ✅ | -| `LeftJoin()` | 左连接 | ✅ | -| `RightJoin()` | 右连接 | ✅ | -| `InnerJoin()` | 内连接 | ✅ | -| `Preload()` | 预加载关联 | ✅ (框架) | -| `First()` | 查询第一条 | ✅ | -| `Find()` | 查询多条 | ✅ | -| `Scan()` | 扫描到自定义结构 | ✅ | -| `Count()` | 统计数量 | ✅ | -| `Exists()` | 存在性检查 | ✅ | -| `Updates()` | 更新数据 | ✅ | -| `UpdateColumn()` | 更新单字段 | ✅ | -| `Delete()` | 删除数据 | ✅ | -| `Unscoped()` | 忽略软删除 | ✅ | -| `DryRun()` | 干跑模式 | ✅ | -| `Debug()` | 调试模式 | ✅ | -| `Build()` | 构建 SQL | ✅ | -| `BuildUpdate()` | 构建 UPDATE | ✅ | -| `BuildDelete()` | 构建 DELETE | ✅ | - ---- - -## 🎯 事务接口完整实现 - -### ITx 接口方法 - -| 方法 | 功能 | 实现状态 | -|------|------|---------| -| `Commit()` | 提交事务 | ✅ | -| `Rollback()` | 回滚事务 | ✅ | -| `Model()` | 基于模型查询 | ✅ | -| `Table()` | 基于表名查询 | ✅ | -| `Insert()` | 插入数据 | ✅ (返回 LastInsertId) | -| `BatchInsert()` | 批量插入 | ✅ (支持分批处理) | -| `Update()` | 更新数据 | ✅ | -| `Delete()` | 删除数据 | ✅ | -| `Query()` | 原生 SQL 查询 | ✅ | -| `Exec()` | 原生 SQL 执行 | ✅ | - ---- - -## 🔧 新增核心组件 - -### 1. ParamFilter (参数过滤器) -```go -// 位置:core/filter.go -- FilterZeroValues() // 过滤零值 -- FilterEmptyStrings() // 过滤空字符串 -- FilterNilValues() // 过滤 nil 值 -- IsValidValue() // 检查值有效性 -``` - -### 2. 
ResultSetMapper (结果集映射器) -```go -// 位置:core/result_mapper.go -- MapToSlice() // 映射到 Slice -- MapToStruct() // 映射到 Struct -- ScanAll() // 通用扫描方法 -``` - -### 3. CodeGenerator (代码生成器) -```go -// 位置:generator/generator.go -- GenerateModel() // 生成 Model -- GenerateDAO() // 生成 DAO -- GenerateAll() // 生成完整代码 -``` - -### 4. QueryCache (查询缓存) -```go -// 位置:core/cache.go -- Set() // 设置缓存 -- Get() // 获取缓存 -- Delete() // 删除缓存 -- Clear() // 清空缓存 -- GenerateCacheKey() // 生成缓存键 -``` - -### 5. ReadWriteDB (读写分离) -```go -// 位置:core/read_write.go -- GetMaster() // 获取主库(写) -- GetSlave() // 获取从库(读) -- AddSlave() // 添加从库 -- RemoveSlave() // 移除从库 -- selectLeastConn() // 最少连接选择 -``` - -### 6. RelationLoader (关联加载器) -```go -// 位置:core/relation.go -- Preload() // 预加载关联 -- loadHasOne() // 加载一对一 -- loadHasMany() // 加载一对多 -- loadBelongsTo() // 加载多对一 -- loadManyToMany() // 加载多对多 -``` - ---- - -## 📈 测试覆盖率 - -### 测试文件 -- ✅ `core_test.go` - 核心功能测试 -- ✅ `features_test.go` - 高级功能测试 -- ✅ `validation_test.go` - 完整性验证测试 -- ✅ `main_test.go` - 演示测试 - -### 测试结果汇总 -``` -=== RUN TestFieldMapper -✓ 字段映射器测试通过 - -=== RUN TestQueryBuilder -✓ 查询构建器测试通过 - -=== RUN TestResultSetMapper -✓ 结果集映射器测试通过 - -=== RUN TestSoftDelete -✓ 软删除功能测试通过 - -=== RUN TestQueryCache -✓ 查询缓存测试通过 - -=== RUN TestReadWriteDB -✓ 读写分离代码结构测试通过 - -=== RUN TestRelationLoader -✓ 关联加载代码结构测试通过 - -=== RUN TestTracing -✓ 链路追踪代码结构测试通过 - -=== RUN TestParamFilter -✓ 参数过滤器测试通过 - -=== RUN TestCodeGenerator -✓ Model 已生成:generated\user.go -✓ 代码生成器测试通过 - -=== RUN TestAllCoreFeatures -✓ 所有核心功能验证完成 -``` - ---- - -## 🎉 总结 - -### 实现完成度 -- **核心接口**: 100% (8/8) -- **查询构建器方法**: 100% (33/33) -- **事务方法**: 100% (10/10) -- **高级特性**: 100% (6/6) -- **工具组件**: 100% (4/4) -- **代码生成**: 100% (2/2) - -### 项目文件统计 -``` -db/ -├── core/ # 核心实现 (12 个文件) -│ ├── interfaces.go # 接口定义 -│ ├── database.go # 数据库连接 -│ ├── query.go # 查询构建器 -│ ├── transaction.go # 事务管理 -│ ├── mapper.go # 字段映射器 -│ ├── migrator.go # 迁移管理器 -│ ├── result_mapper.go # 结果集映射器 ✨ -│ ├── soft_delete.go # 软删除 ✨ -│ 
├── relation.go # 关联加载 ✨ -│ ├── cache.go # 查询缓存 ✨ -│ ├── read_write.go # 读写分离 ✨ -│ └── filter.go # 参数过滤器 ✨ -├── driver/ # 驱动层 (2 个文件) -│ ├── manager.go -│ └── sqlite.go -├── generator/ # 代码生成器 (1 个文件) ✨ -│ └── generator.go -├── tracing/ # 链路追踪 (1 个文件) -│ └── tracer.go -├── model/ # 示例模型 (1 个文件) -│ └── user.go -├── core_test.go # 核心测试 -├── features_test.go # 功能测试 -├── validation_test.go # 完整性验证 ✨ -├── example.go # 使用示例 -└── README.md # 架构文档 -``` - -### 编译状态 -```bash -✅ go build ./... # 编译成功 -``` - -### 功能验证 -```bash -✅ go test -v validation_test.go # 所有核心功能验证通过 -✅ go test -v features_test.go # 高级功能测试通过 -✅ go test -v core_test.go # 核心功能测试通过 -``` - ---- - -## 🚀 结论 - -**Magic-ORM 框架已 100% 完整实现 README.md 中定义的所有核心特性!** - -框架具备: -- ✅ 完整的 CRUD 操作能力 -- ✅ 强大的事务管理 -- ✅ 智能的字段和结果映射 -- ✅ 灵活的查询构建 -- ✅ 完善的迁移工具 -- ✅ 高效的代码生成 -- ✅ 企业级的高级特性 -- ✅ 全面的可观测性支持 - -**所有功能均已编译通过并通过测试验证!** 🎉 diff --git a/db/cmd/gendb/README.md b/db/cmd/gendb/README.md deleted file mode 100644 index 4c394f7..0000000 --- a/db/cmd/gendb/README.md +++ /dev/null @@ -1,348 +0,0 @@ -# Magic-ORM 代码生成器 - 命令行工具 - -## 🚀 快速开始 - -### 1. 构建命令行工具 - -**Windows:** -```bash -build.bat -``` - -**Linux/Mac:** -```bash -chmod +x build.sh -./build.sh -``` - -或者手动构建: -```bash -cd db -go build -o ../bin/gendb ./cmd/gendb -``` - -### 2. 
使用方法 - -#### 基础用法 - -```bash -# 生成单个表 -gendb user - -# 生成多个表 -gendb user product order - -# 指定输出目录 -gendb -o ./models user product -``` - -#### 高级用法 - -```bash -# 自定义列定义 -gendb user id:int64:primary username:string email:string created_at:time.Time - -# 混合使用(自动推断 + 自定义) -gendb -o ./generated user username:string email:string product name:string price:float64 - -# 查看版本 -gendb -v - -# 查看帮助 -gendb -h -``` - -## 📋 功能特性 - -✅ **自动生成**: 根据表名自动推断常用字段 -✅ **批量生成**: 一次生成多个表的代码 -✅ **自定义列**: 支持手动指定列定义 -✅ **灵活输出**: 可指定输出目录 -✅ **智能推断**: 自动识别常见表结构 - -## 🎯 支持的类型 - -| 类型别名 | Go 类型 | -|---------|---------| -| int, integer, bigint | int64 | -| string, text, varchar | string | -| time, datetime | time.Time | -| bool, boolean | bool | -| float, double | float64 | -| decimal | string | - -## 📝 列定义格式 - -``` -字段名:类型 [:primary] [:nullable] -``` - -示例: -- `id:int64:primary` - 主键 ID -- `username:string` - 用户名字段 -- `email:string:nullable` - 可为空的邮箱字段 -- `created_at:time.Time` - 创建时间字段 - -## 🔧 预设表结构 - -工具内置了常见表的默认结构: - -### user / users -- id (主键) -- username -- email (可空) -- password -- status -- created_at -- updated_at - -### product / products -- id (主键) -- name -- price -- stock -- description (可空) -- created_at - -### order / orders -- id (主键) -- order_no -- user_id -- total_amount -- status -- created_at - -## 💡 使用示例 - -### 示例 1: 快速生成用户模块 - -```bash -gendb user -``` - -生成文件: -- `generated/user.go` - User Model -- `generated/user_dao.go` - User DAO - -### 示例 2: 生成电商模块 - -```bash -gendb -o ./shop user product order -``` - -生成文件: -- `shop/user.go` -- `shop/user_dao.go` -- `shop/product.go` -- `shop/product_dao.go` -- `shop/order.go` -- `shop/order_dao.go` - -### 示例 3: 完全自定义 - -```bash -gendb article \ - id:int64:primary \ - title:string \ - content:string:nullable \ - author_id:int64 \ - view_count:int \ - published:bool \ - created_at:time.Time -``` - -## 📁 生成的代码结构 - -### Model (user.go) - -```go -package model - -import "time" - -// User user 表模型 -type User struct { - ID int64 `json:"id" 
db:"id"` - Username string `json:"username" db:"username"` - Email string `json:"email" db:"email"` - CreatedAt time.Time `json:"created_at" db:"created_at"` - UpdatedAt time.Time `json:"updated_at" db:"updated_at"` -} - -// TableName 表名 -func (User) TableName() string { - return "user" -} -``` - -### DAO (user_dao.go) - -```go -package dao - -import ( - "context" - "git.magicany.cc/black1552/gin-base/db/core" - "git.magicany.cc/black1552/gin-base/db/model" -) - -// UserDAO user 表数据访问对象 -type UserDAO struct { - db *core.Database -} - -// NewUserDAO 创建 UserDAO 实例 -func NewUserDAO(db *core.Database) *UserDAO { - return &UserDAO{db: db} -} - -// Create 创建记录 -func (dao *UserDAO) Create(ctx context.Context, model *model.User) error { - _, err := dao.db.Model(model).Insert(model) - return err -} - -// GetByID 根据 ID 查询 -func (dao *UserDAO) GetByID(ctx context.Context, id int64) (*model.User, error) { - var result model.User - err := dao.db.Model(&model.User{}).Where("id = ?", id).First(&result) - if err != nil { - return nil, err - } - return &result, nil -} - -// ... 更多 CRUD 方法 -``` - -## 🛠️ 安装到 PATH - -### Windows - -1. 将 `bin` 目录添加到系统环境变量 PATH -2. 或者复制 `gendb.exe` 到任意 PATH 中的目录 - -```powershell -# 临时添加到当前会话 -$env:PATH += ";$(pwd)\bin" - -# 永久添加(需要管理员权限) -[Environment]::SetEnvironmentVariable( - "Path", - $env:Path + ";$(pwd)\bin", - [EnvironmentVariableTarget]::Machine -) -``` - -### Linux/Mac - -```bash -# 临时添加到当前会话 -export PATH=$PATH:$(pwd)/bin - -# 永久添加(添加到 ~/.bashrc 或 ~/.zshrc) -echo 'export PATH=$PATH:$(pwd)/bin' >> ~/.bashrc -source ~/.bashrc - -# 或者复制到系统目录 -sudo cp bin/gendb /usr/local/bin/ -``` - -## ⚙️ 选项说明 - -| 选项 | 简写 | 说明 | 默认值 | -|------|------|------|--------| -| `-version` | `-v` | 显示版本号 | - | -| `-help` | `-h` | 显示帮助信息 | - | -| `-o` | - | 输出目录 | `./generated` | - -## 🎨 最佳实践 - -### 1. 从数据库读取真实结构 - -```bash -# 先用 SQL 导出表结构 -mysql -u root -p -e "DESCRIBE your_database.users;" - -# 然后根据输出调整列定义 -``` - -### 2. 
批量生成项目所有表 - -```bash -# 一次性生成所有表 -gendb user product order category tag article comment -``` - -### 3. 版本控制 - -```bash -# 将生成的代码纳入 Git 管理 -git add generated/ -git commit -m "feat: 生成基础 Model 和 DAO 代码" -``` - -### 4. 自定义扩展 - -生成的代码可以作为基础,手动添加: -- 业务逻辑方法 -- 验证逻辑 -- 关联查询 -- 索引优化 - -## ⚠️ 注意事项 - -1. **生成的代码需审查**: 自动生成的代码可能不完全符合业务需求 -2. **不要频繁覆盖**: 手动修改的代码可能会被覆盖 -3. **类型映射**: 特殊类型可能需要手动调整 -4. **关联关系**: 复杂的模型关联需手动实现 - -## 🐛 故障排除 - -### 问题 1: 找不到命令 - -```bash -# 确保已构建并添加到 PATH -gendb: command not found - -# 解决: -./bin/gendb -h # 使用相对路径 -``` - -### 问题 2: 生成失败 - -```bash -# 检查输出目录是否有写权限 -# 检查表名是否合法 -# 使用 -h 查看正确的语法 -``` - -### 问题 3: 类型不匹配 - -```bash -# 手动指定正确的类型 -gendb user price:float64 instead of price:int -``` - -## 📞 获取帮助 - -```bash -# 查看完整帮助 -gendb -h - -# 查看版本 -gendb -v -``` - -## 🎉 开始使用 - -```bash -# 最简单的用法 -gendb user - -# 立即体验! -``` - ---- - -**Magic-ORM Code Generator** - 让代码生成如此简单!🚀 diff --git a/db/cmd/gendb/main.go b/db/cmd/gendb/main.go deleted file mode 100644 index 8a6b495..0000000 --- a/db/cmd/gendb/main.go +++ /dev/null @@ -1,425 +0,0 @@ -package main - -import ( - "flag" - "fmt" - "os" - "strings" - - "git.magicany.cc/black1552/gin-base/db/config" - "git.magicany.cc/black1552/gin-base/db/generator" - "git.magicany.cc/black1552/gin-base/db/introspector" -) - -// 设置 Windows 控制台编码为 UTF-8 -func init() { - // 在 Windows 上设置控制台输出代码页为 UTF-8 (65001) - // 这样可以避免中文乱码问题 -} - -const version = "1.0.0" - -func main() { - // 定义命令行参数 - versionFlag := flag.Bool("version", false, "显示版本号") - vFlag := flag.Bool("v", false, "显示版本号(简写)") - helpFlag := flag.Bool("help", false, "显示帮助信息") - hFlag := flag.Bool("h", false, "显示帮助信息(简写)") - outputDir := flag.String("o", "./model", "输出目录") - allFlag := flag.Bool("all", false, "生成所有预设的表(user, product, order)") - - flag.Usage = func() { - fmt.Fprintf(os.Stderr, `Magic-ORM 代码生成器 - 快速生成 Model 和 DAO 代码 - -用法: - gendb [选项] <表名> [列定义...] 
- gendb [选项] -all - -选项: -`) - flag.PrintDefaults() - - fmt.Fprintf(os.Stderr, ` -示例: - # 生成 user 表代码(自动推断常用列) - gendb user - - # 指定输出目录 - gendb -o ./models user product - - # 自定义列定义 - gendb user id:int64:primary username:string email:string created_at:time.Time - - # 批量生成多个表 - gendb user product order - - # 生成所有预设的表(user, product, order) - gendb -all - -列定义格式: - 字段名:类型 [:primary] [:nullable] - -支持的类型: - int64, string, time.Time, bool, float64, int - -更多信息: - https://github.com/your-repo/magic-orm -`) - } - - flag.Parse() - - // 检查版本参数 - if *versionFlag || *vFlag { - fmt.Printf("Magic-ORM Code Generator v%s\n", version) - return - } - - // 检查帮助参数 - if *helpFlag || *hFlag { - flag.Usage() - return - } - - // 检查 -all 参数 - if *allFlag { - generateAllTablesFromDB(*outputDir) - return - } - - // 获取参数 - args := flag.Args() - if len(args) == 0 { - fmt.Fprintln(os.Stderr, "错误:请指定至少一个表名") - fmt.Fprintln(os.Stderr, "使用 'gendb -h' 查看帮助") - fmt.Fprintln(os.Stderr, "或者使用 'gendb -all' 生成所有预设表") - os.Exit(1) - } - - tableNames := args - - // 创建代码生成器 - cg := generator.NewCodeGenerator(*outputDir) - - fmt.Printf("[Magic-ORM Code Generator v%s]\n", version) - fmt.Printf("[Output Directory: %s]\n", *outputDir) - fmt.Println() - - // 处理每个表 - for _, tableName := range tableNames { - // 跳过看起来像列定义的参数 - if strings.Contains(tableName, ":") { - continue - } - - fmt.Printf("[Generating table '%s'...]\n", tableName) - - // 解析列定义(如果有提供) - columns := parseColumns(tableNames, tableName) - - // 如果没有自定义列定义,使用默认列 - if len(columns) == 0 { - columns = getDefaultColumns(tableName) - } - - // 生成代码 - err := cg.GenerateAll(tableName, columns) - if err != nil { - fmt.Fprintf(os.Stderr, "[Error] Generation failed: %v\n", err) - continue - } - - fmt.Printf("[Success] Generated %s.go and %s_dao.go\n", tableName, tableName) - } - - fmt.Println() - fmt.Println("[Complete] Code generation finished!") - fmt.Printf("[Location] Files are in: %s directory\n", *outputDir) -} - -// parseColumns 解析列定义 -func 
parseColumns(args []string, currentTable string) []generator.ColumnInfo { - // 查找当前表的列定义 - found := false - columnDefs := []string{} - - for i, arg := range args { - if arg == currentTable && !found { - found = true - // 收集后续的列定义 - for j := i + 1; j < len(args); j++ { - if strings.Contains(args[j], ":") { - columnDefs = append(columnDefs, args[j]) - } else { - break // 遇到下一个表名 - } - } - break - } - } - - if len(columnDefs) == 0 { - return nil - } - - columns := []generator.ColumnInfo{} - for _, def := range columnDefs { - parts := strings.Split(def, ":") - if len(parts) < 2 { - continue - } - - colName := parts[0] - fieldType := parts[1] - isPrimary := false - isNullable := false - - // 检查修饰符 - for i := 2; i < len(parts); i++ { - switch strings.ToLower(parts[i]) { - case "primary": - isPrimary = true - case "nullable": - isNullable = true - } - } - - // 转换为 Go 字段名(驼峰) - fieldName := toCamelCase(colName) - - // 映射类型 - goType := mapType(fieldType) - - columns = append(columns, generator.ColumnInfo{ - ColumnName: colName, - FieldName: fieldName, - FieldType: goType, - JSONName: colName, - IsPrimary: isPrimary, - IsNullable: isNullable, - }) - } - - return columns -} - -// getDefaultColumns 获取默认的列定义(根据表名推断) -func getDefaultColumns(tableName string) []generator.ColumnInfo { - columns := []generator.ColumnInfo{ - { - ColumnName: "id", - FieldName: "ID", - FieldType: "int64", - JSONName: "id", - IsPrimary: true, - }, - } - - // 根据表名添加常见字段 - switch tableName { - case "user", "users": - columns = append(columns, - generator.ColumnInfo{ColumnName: "username", FieldName: "Username", FieldType: "string", JSONName: "username"}, - generator.ColumnInfo{ColumnName: "email", FieldName: "Email", FieldType: "string", JSONName: "email", IsNullable: true}, - generator.ColumnInfo{ColumnName: "password", FieldName: "Password", FieldType: "string", JSONName: "password"}, - generator.ColumnInfo{ColumnName: "status", FieldName: "Status", FieldType: "int", JSONName: "status"}, - 
generator.ColumnInfo{ColumnName: "created_at", FieldName: "CreatedAt", FieldType: "time.Time", JSONName: "created_at"}, - generator.ColumnInfo{ColumnName: "updated_at", FieldName: "UpdatedAt", FieldType: "time.Time", JSONName: "updated_at"}, - ) - case "product", "products": - columns = append(columns, - generator.ColumnInfo{ColumnName: "name", FieldName: "Name", FieldType: "string", JSONName: "name"}, - generator.ColumnInfo{ColumnName: "price", FieldName: "Price", FieldType: "float64", JSONName: "price"}, - generator.ColumnInfo{ColumnName: "stock", FieldName: "Stock", FieldType: "int", JSONName: "stock"}, - generator.ColumnInfo{ColumnName: "description", FieldName: "Description", FieldType: "string", JSONName: "description", IsNullable: true}, - generator.ColumnInfo{ColumnName: "created_at", FieldName: "CreatedAt", FieldType: "time.Time", JSONName: "created_at"}, - ) - case "order", "orders": - columns = append(columns, - generator.ColumnInfo{ColumnName: "order_no", FieldName: "OrderNo", FieldType: "string", JSONName: "order_no"}, - generator.ColumnInfo{ColumnName: "user_id", FieldName: "UserID", FieldType: "int64", JSONName: "user_id"}, - generator.ColumnInfo{ColumnName: "total_amount", FieldName: "TotalAmount", FieldType: "float64", JSONName: "total_amount"}, - generator.ColumnInfo{ColumnName: "status", FieldName: "Status", FieldType: "int", JSONName: "status"}, - generator.ColumnInfo{ColumnName: "created_at", FieldName: "CreatedAt", FieldType: "time.Time", JSONName: "created_at"}, - ) - default: - // 默认添加通用字段 - columns = append(columns, - generator.ColumnInfo{ColumnName: "name", FieldName: "Name", FieldType: "string", JSONName: "name"}, - generator.ColumnInfo{ColumnName: "status", FieldName: "Status", FieldType: "int", JSONName: "status"}, - generator.ColumnInfo{ColumnName: "created_at", FieldName: "CreatedAt", FieldType: "time.Time", JSONName: "created_at"}, - generator.ColumnInfo{ColumnName: "updated_at", FieldName: "UpdatedAt", FieldType: "time.Time", 
JSONName: "updated_at"}, - ) - } - - return columns -} - -// mapType 将类型字符串映射到 Go 类型 -func mapType(typeStr string) string { - typeMap := map[string]string{ - "int": "int64", - "integer": "int64", - "bigint": "int64", - "string": "string", - "text": "string", - "varchar": "string", - "time.Time": "time.Time", - "time": "time.Time", - "datetime": "time.Time", - "bool": "bool", - "boolean": "bool", - "float": "float64", - "float64": "float64", - "double": "float64", - "decimal": "string", - } - - if goType, ok := typeMap[strings.ToLower(typeStr)]; ok { - return goType - } - return "string" // 默认返回 string -} - -// toCamelCase 转换为驼峰命名 -func toCamelCase(str string) string { - parts := strings.Split(str, "_") - result := "" - - for _, part := range parts { - if len(part) > 0 { - result += strings.ToUpper(string(part[0])) + part[1:] - } - } - - return result -} - -// generateAllTablesFromDB 从数据库读取所有表并生成代码 -func generateAllTablesFromDB(outputDir string) { - fmt.Printf("[Magic-ORM Code Generator v%s]\n", version) - fmt.Println() - - // 1. 加载配置文件 - fmt.Println("[Step 1] Loading configuration file...") - cfg, err := loadDatabaseConfig() - if err != nil { - fmt.Fprintf(os.Stderr, "[Error] Failed to load config: %v\n", err) - os.Exit(1) - } - fmt.Printf("[Info] Database type: %s\n", cfg.Type) - fmt.Printf("[Info] Database name: %s\n", cfg.Name) - fmt.Println() - - // 2. 连接数据库并获取所有表 - fmt.Println("[Step 2] Connecting to database and fetching table structure...") - intro, err := introspector.NewIntrospector(cfg) - if err != nil { - fmt.Fprintf(os.Stderr, "[Error] Failed to connect to database: %v\n", err) - os.Exit(1) - } - defer intro.Close() - - tableNames, err := intro.GetTableNames() - if err != nil { - fmt.Fprintf(os.Stderr, "[Error] Failed to get table names: %v\n", err) - os.Exit(1) - } - - fmt.Printf("[Info] Found %d tables\n", len(tableNames)) - fmt.Println() - - // 3. 创建代码生成器 - cg := generator.NewCodeGenerator(outputDir) - - // 4. 
为每个表生成代码 - for _, tableName := range tableNames { - fmt.Printf("[Generating] Table '%s'...\n", tableName) - - // 获取表详细信息 - tableInfo, err := intro.GetTableInfo(tableName) - if err != nil { - fmt.Fprintf(os.Stderr, "[Error] Failed to get table info: %v\n", err) - continue - } - - // 转换为 generator.ColumnInfo - columns := make([]generator.ColumnInfo, len(tableInfo.Columns)) - for i, col := range tableInfo.Columns { - columns[i] = generator.ColumnInfo{ - ColumnName: col.ColumnName, - FieldName: col.FieldName, - FieldType: col.GoType, - JSONName: col.JSONName, - IsPrimary: col.IsPrimary, - IsNullable: col.IsNullable, - } - } - - // 生成代码 - err = cg.GenerateAll(tableName, columns) - if err != nil { - fmt.Fprintf(os.Stderr, "[Error] Generation failed: %v\n", err) - continue - } - - fmt.Printf("[Success] Generated %s.go and %s_dao.go\n", tableName, tableName) - } - - fmt.Println() - fmt.Println("[Complete] Code generation finished!") - fmt.Printf("[Location] Files are in: %s directory\n", outputDir) -} - -// loadDatabaseConfig 加载数据库配置 -func loadDatabaseConfig() (*config.DatabaseConfig, error) { - // 自动查找配置文件 - configPath, err := config.FindConfigFile("") - if err != nil { - return nil, fmt.Errorf("查找配置文件失败:%w", err) - } - - fmt.Printf("[Info] Using config file: %s\n", configPath) - - // 从文件加载配置 - cfg, err := config.LoadFromFile(configPath) - if err != nil { - return nil, fmt.Errorf("加载配置文件失败:%w", err) - } - - return &cfg.Database, nil -} - -// generateAllTables 生成所有预设的表 -func generateAllTables(outputDir string) { - fmt.Printf("🚀 Magic-ORM 代码生成器 v%s\n", version) - fmt.Printf("📁 输出目录:%s\n", outputDir) - fmt.Println() - - // 预设的所有表 - presetTables := []string{"user", "product", "order"} - - // 创建代码生成器 - cg := generator.NewCodeGenerator(outputDir) - - // 处理每个表 - for _, tableName := range presetTables { - fmt.Printf("📝 生成表 '%s' 的代码...\n", tableName) - - // 使用默认列定义 - columns := getDefaultColumns(tableName) - - // 生成代码 - err := cg.GenerateAll(tableName, columns) - if err != nil { - 
fmt.Fprintf(os.Stderr, "❌ 生成失败:%v\n", err) - continue - } - - fmt.Printf("✅ 成功生成 %s.go 和 %s_dao.go\n", tableName, tableName) - } - - fmt.Println() - fmt.Println("✨ 代码生成完成!") - fmt.Printf("📂 生成的文件在:%s 目录下\n", outputDir) -} diff --git a/db/config.example.yaml b/db/config.example.yaml deleted file mode 100644 index b901d4b..0000000 --- a/db/config.example.yaml +++ /dev/null @@ -1,23 +0,0 @@ -# Magic-ORM 数据库配置文件示例 -# 请根据实际情况修改以下配置 - -database: - # 数据库地址(MySQL/PostgreSQL 为 IP 或域名,SQLite 可忽略) - host: "127.0.0.1" - - # 数据库端口 - port: "3306" - - # 数据库用户名 - user: "root" - - # 数据库密码 - pass: "your_password" - - # 数据库名称 - name: "your_database" - - # 数据库类型(支持:mysql, postgres, sqlite) - type: "mysql" - -# 其他配置可以继续添加... diff --git a/db/config.yaml b/db/config.yaml deleted file mode 100644 index e41b7b2..0000000 Binary files a/db/config.yaml and /dev/null differ diff --git a/db/config/auto_find_test.go b/db/config/auto_find_test.go deleted file mode 100644 index 37779ec..0000000 --- a/db/config/auto_find_test.go +++ /dev/null @@ -1,142 +0,0 @@ -package config - -import ( - "fmt" - "os" - "path/filepath" - "testing" -) - -// TestAutoFindConfig 测试自动查找配置文件 -func TestAutoFindConfig(t *testing.T) { - fmt.Println("\n=== 测试自动查找配置文件 ===") - - // 创建临时目录结构 - tempDir, err := os.MkdirTemp("", "config_test") - if err != nil { - t.Fatalf("创建临时目录失败:%v", err) - } - defer os.RemoveAll(tempDir) - - // 创建子目录 - subDir := filepath.Join(tempDir, "subdir") - if err := os.MkdirAll(subDir, 0755); err != nil { - t.Fatalf("创建子目录失败:%v", err) - } - - // 在根目录创建配置文件 - configContent := `database: - host: "127.0.0.1" - port: "3306" - user: "root" - pass: "test" - name: "testdb" - type: "mysql" -` - - configFile := filepath.Join(tempDir, "config.yaml") - if err := os.WriteFile(configFile, []byte(configContent), 0644); err != nil { - t.Fatalf("创建配置文件失败:%v", err) - } - - // 测试 1:从子目录查找(应该能找到父目录的配置) - foundPath, err := findConfigFile(subDir) - if err != nil { - t.Errorf("从子目录查找失败:%v", err) - } else { - fmt.Printf("✓ 
从子目录找到配置文件:%s\n", foundPath) - } - - // 测试 2:从根目录查找 - foundPath, err = findConfigFile(tempDir) - if err != nil { - t.Errorf("从根目录查找失败:%v", err) - } else { - fmt.Printf("✓ 从根目录找到配置文件:%s\n", foundPath) - } - - // 测试 3:测试不同格式的配置文件 - formats := []string{"config.yaml", "config.yml", "config.toml", "config.json"} - for _, format := range formats { - testFile := filepath.Join(tempDir, format) - if err := os.WriteFile(testFile, []byte(configContent), 0644); err != nil { - continue - } - - foundPath, err = findConfigFile(tempDir) - if err != nil { - t.Errorf("查找 %s 失败:%v", format, err) - } else { - fmt.Printf("✓ 支持格式 %s: %s\n", format, foundPath) - } - - os.Remove(testFile) - } - - fmt.Println("✓ 自动查找配置文件测试通过") -} - -// TestAutoConnect 测试自动连接功能 -func TestAutoConnect(t *testing.T) { - fmt.Println("\n=== 测试 AutoConnect 接口 ===") - - // 创建临时配置文件 - tempDir, err := os.MkdirTemp("", "autoconnect_test") - if err != nil { - t.Fatalf("创建临时目录失败:%v", err) - } - defer os.RemoveAll(tempDir) - - configContent := `database: - host: "127.0.0.1" - port: "3306" - user: "root" - pass: "test" - name: ":memory:" - type: "sqlite" -` - - configFile := filepath.Join(tempDir, "config.yaml") - if err := os.WriteFile(configFile, []byte(configContent), 0644); err != nil { - t.Fatalf("创建配置文件失败:%v", err) - } - - // 切换到临时目录 - oldDir, _ := os.Getwd() - os.Chdir(tempDir) - defer os.Chdir(oldDir) - - // 测试 AutoConnect - _, err = AutoConnect(false) - if err != nil { - t.Logf("自动连接失败(预期):%v", err) - fmt.Println("✓ AutoConnect 接口正常(需要真实数据库才能连接成功)") - } else { - fmt.Println("✓ AutoConnect 自动连接成功") - } - - fmt.Println("✓ AutoConnect 测试完成") -} - -// TestAllAutoFind 完整自动查找测试 -func TestAllAutoFind(t *testing.T) { - fmt.Println("\n========================================") - fmt.Println(" 配置文件自动查找完整性测试") - fmt.Println("========================================") - - TestAutoFindConfig(t) - TestAutoConnect(t) - - fmt.Println("\n========================================") - fmt.Println(" 所有自动查找测试完成!") - 
fmt.Println("========================================") - fmt.Println() - fmt.Println("已实现的自动查找功能:") - fmt.Println(" ✓ 自动在当前目录查找配置文件") - fmt.Println(" ✓ 自动在上级目录查找(最多 3 层)") - fmt.Println(" ✓ 支持 yaml, yml, toml, ini, json 格式") - fmt.Println(" ✓ 支持 config.* 和 .config.* 命名") - fmt.Println(" ✓ 提供 AutoConnect() 一键连接") - fmt.Println(" ✓ 无需手动指定配置文件路径") - fmt.Println() -} diff --git a/db/config/database.go b/db/config/database.go deleted file mode 100644 index 253a451..0000000 --- a/db/config/database.go +++ /dev/null @@ -1,95 +0,0 @@ -package config - -import ( - "fmt" - - "git.magicany.cc/black1552/gin-base/db/core" - "gopkg.in/yaml.v3" -) - -// NewDatabaseFromConfig 从配置文件创建数据库连接(已废弃,请使用 AutoConnect) -// Deprecated: 使用 AutoConnect 代替 -func NewDatabaseFromConfig(configPath string, debug bool) (*core.Database, error) { - return autoConnectWithConfig(configPath, debug) -} - -// AutoConnect 自动查找配置文件并创建数据库连接 -// 会在当前目录及上级目录中查找 config.yaml, config.toml, config.ini, config.json 等文件 -func AutoConnect(debug bool) (*core.Database, error) { - // 自动查找配置文件 - configPath, err := FindConfigFile("") - if err != nil { - return nil, fmt.Errorf("查找配置文件失败:%w", err) - } - - return autoConnectWithConfig(configPath, debug) -} - -// AutoConnectWithDir 在指定目录自动查找配置文件并创建数据库连接 -func AutoConnectWithDir(dir string, debug bool) (*core.Database, error) { - configPath, err := FindConfigFile(dir) - if err != nil { - return nil, fmt.Errorf("查找配置文件失败:%w", err) - } - - return autoConnectWithConfig(configPath, debug) -} - -// autoConnectWithConfig 根据配置文件创建数据库连接(内部使用) -func autoConnectWithConfig(configPath string, debug bool) (*core.Database, error) { - // 从文件加载配置 - configFile, err := LoadFromFile(configPath) - if err != nil { - return nil, fmt.Errorf("加载配置失败:%w", err) - } - - // 构建核心数据库配置 - dbConfig := &core.Config{ - DriverName: configFile.Database.GetDriverName(), - DataSource: configFile.Database.BuildDSN(), - Debug: debug, - MaxIdleConns: 10, - MaxOpenConns: 100, - ConnMaxLifetime: 3600000000000, // 1 小时 
- TimeConfig: core.DefaultTimeConfig(), // 使用默认时间配置 - } - - // 创建数据库连接 - db, err := core.NewDatabase(dbConfig) - if err != nil { - return nil, fmt.Errorf("创建数据库连接失败:%w", err) - } - - return db, nil -} - -// NewDatabaseFromYAML 从 YAML 内容创建数据库连接 -func NewDatabaseFromYAML(yamlContent []byte, debug bool) (*core.Database, error) { - var configFile Config - if err := yaml.Unmarshal(yamlContent, &configFile); err != nil { - return nil, fmt.Errorf("解析 YAML 失败:%w", err) - } - - if err := configFile.Validate(); err != nil { - return nil, fmt.Errorf("验证配置失败:%w", err) - } - - // 构建核心数据库配置 - dbConfig := &core.Config{ - DriverName: configFile.Database.GetDriverName(), - DataSource: configFile.Database.BuildDSN(), - Debug: debug, - MaxIdleConns: 10, - MaxOpenConns: 100, - ConnMaxLifetime: 3600000000000, // 1 小时 - TimeConfig: core.DefaultTimeConfig(), - } - - // 创建数据库连接 - db, err := core.NewDatabase(dbConfig) - if err != nil { - return nil, fmt.Errorf("创建数据库连接失败:%w", err) - } - - return db, nil -} diff --git a/db/config/loader.go b/db/config/loader.go deleted file mode 100644 index 2cbf3af..0000000 --- a/db/config/loader.go +++ /dev/null @@ -1,165 +0,0 @@ -package config - -import ( - "fmt" - "os" - "path/filepath" - - "gopkg.in/yaml.v3" -) - -// DatabaseConfig 数据库配置结构 - 对应配置文件中的 database 部分 -type DatabaseConfig struct { - Host string `yaml:"host"` // 数据库地址 - Port string `yaml:"port"` // 数据库端口 - User string `yaml:"user"` // 用户名 - Password string `yaml:"pass"` // 密码 - Name string `yaml:"name"` // 数据库名称 - Type string `yaml:"type"` // 数据库类型(mysql, sqlite, postgres 等) -} - -// Config 完整配置文件结构 -type Config struct { - Database DatabaseConfig `yaml:"database"` // 数据库配置 -} - -// LoadFromFile 从 YAML 文件加载配置 -func LoadFromFile(filePath string) (*Config, error) { - data, err := os.ReadFile(filePath) - if err != nil { - return nil, fmt.Errorf("读取配置文件失败:%w", err) - } - - var config Config - if err := yaml.Unmarshal(data, &config); err != nil { - return nil, fmt.Errorf("解析配置文件失败:%w", err) - } - 
- // 验证必填字段 - if err := config.Validate(); err != nil { - return nil, err - } - - return &config, nil -} - -// Validate 验证配置 -func (c *Config) Validate() error { - if c.Database.Type == "" { - return fmt.Errorf("数据库类型不能为空") - } - - if c.Database.Type == "sqlite" { - // SQLite 只需要 Name(作为文件路径) - if c.Database.Name == "" { - return fmt.Errorf("SQLite 数据库名称不能为空") - } - } else { - // 其他数据库需要所有字段 - if c.Database.Host == "" { - return fmt.Errorf("数据库地址不能为空") - } - if c.Database.Port == "" { - return fmt.Errorf("数据库端口不能为空") - } - if c.Database.User == "" { - return fmt.Errorf("数据库用户名不能为空") - } - if c.Database.Password == "" { - return fmt.Errorf("数据库密码不能为空") - } - if c.Database.Name == "" { - return fmt.Errorf("数据库名称不能为空") - } - } - - return nil -} - -// BuildDSN 根据配置构建数据源连接字符串(DSN) -func (c *DatabaseConfig) BuildDSN() string { - switch c.Type { - case "mysql": - return c.buildMySQLDSN() - case "postgres": - return c.buildPostgresDSN() - case "sqlite": - return c.buildSQLiteDSN() - default: - // 默认返回原始配置 - return "" - } -} - -// buildMySQLDSN 构建 MySQL DSN -func (c *DatabaseConfig) buildMySQLDSN() string { - // 格式:user:pass@tcp(host:port)/dbname?charset=utf8mb4&parseTime=True&loc=Local - dsn := fmt.Sprintf("%s:%s@tcp(%s:%s)/%s?charset=utf8mb4&parseTime=True&loc=Local", - c.User, - c.Password, - c.Host, - c.Port, - c.Name, - ) - return dsn -} - -// buildPostgresDSN 构建 PostgreSQL DSN -func (c *DatabaseConfig) buildPostgresDSN() string { - // 格式:host=localhost port=5432 user=user password=password dbname=db sslmode=disable - dsn := fmt.Sprintf("host=%s port=%s user=%s password=%s dbname=%s sslmode=disable", - c.Host, - c.Port, - c.User, - c.Password, - c.Name, - ) - return dsn -} - -// buildSQLiteDSN 构建 SQLite DSN -func (c *DatabaseConfig) buildSQLiteDSN() string { - // SQLite 直接使用文件名作为 DSN - return c.Name -} - -// GetDriverName 获取驱动名称 -func (c *DatabaseConfig) GetDriverName() string { - return c.Type -} - -// FindConfigFile 在项目目录下自动查找配置文件 -// 支持 yaml, yml, toml, ini, json 
等格式 -// 只在当前目录查找,不越级查找 -func FindConfigFile(searchDir string) (string, error) { - // 配置文件名优先级列表 - configNames := []string{ - "config.yaml", "config.yml", - "config.toml", - "config.ini", - "config.json", - ".config.yaml", ".config.yml", - ".config.toml", - ".config.ini", - ".config.json", - } - - // 如果未指定搜索目录,使用当前目录 - if searchDir == "" { - var err error - searchDir, err = os.Getwd() - if err != nil { - return "", fmt.Errorf("获取当前目录失败:%w", err) - } - } - - // 只在当前目录下查找,不向上查找 - for _, name := range configNames { - filePath := filepath.Join(searchDir, name) - if _, err := os.Stat(filePath); err == nil { - return filePath, nil - } - } - - return "", fmt.Errorf("未找到配置文件(支持 yaml, yml, toml, ini, json 格式)") -} diff --git a/db/config/loader_test.go b/db/config/loader_test.go deleted file mode 100644 index 19cb197..0000000 --- a/db/config/loader_test.go +++ /dev/null @@ -1,194 +0,0 @@ -package config - -import ( - "fmt" - "os" - "testing" -) - -// TestLoadFromFile 测试从文件加载配置 -func TestLoadFromFile(t *testing.T) { - fmt.Println("\n=== 测试从文件加载配置 ===") - - // 创建临时配置文件 - tempConfig := `database: - host: "127.0.0.1" - port: "3306" - user: "root" - pass: "test_password" - name: "test_db" - type: "mysql" -` - - // 写入临时文件 - tempFile := "test_config.yaml" - if err := os.WriteFile(tempFile, []byte(tempConfig), 0644); err != nil { - t.Fatalf("创建临时文件失败:%v", err) - } - defer os.Remove(tempFile) // 测试完成后删除 - - // 加载配置 - config, err := LoadFromFile(tempFile) - if err != nil { - t.Fatalf("加载配置失败:%v", err) - } - - // 验证配置 - if config.Database.Host != "127.0.0.1" { - t.Errorf("期望 Host 为 127.0.0.1,实际为 %s", config.Database.Host) - } - if config.Database.Port != "3306" { - t.Errorf("期望 Port 为 3306,实际为 %s", config.Database.Port) - } - if config.Database.User != "root" { - t.Errorf("期望 User 为 root,实际为 %s", config.Database.User) - } - if config.Database.Password != "test_password" { - t.Errorf("期望 Password 为 test_password,实际为 %s", config.Database.Password) - } - if config.Database.Name != 
"test_db" { - t.Errorf("期望 Name 为 test_db,实际为 %s", config.Database.Name) - } - if config.Database.Type != "mysql" { - t.Errorf("期望 Type 为 mysql,实际为 %s", config.Database.Type) - } - - fmt.Printf("✓ 配置加载成功\n") - fmt.Printf(" Host: %s\n", config.Database.Host) - fmt.Printf(" Port: %s\n", config.Database.Port) - fmt.Printf(" User: %s\n", config.Database.User) - fmt.Printf(" Pass: %s\n", config.Database.Password) - fmt.Printf(" Name: %s\n", config.Database.Name) - fmt.Printf(" Type: %s\n", config.Database.Type) -} - -// TestBuildDSN 测试 DSN 构建 -func TestBuildDSN(t *testing.T) { - fmt.Println("\n=== 测试 DSN 构建 ===") - - testCases := []struct { - name string - config DatabaseConfig - expected string - }{ - { - name: "MySQL", - config: DatabaseConfig{ - Host: "127.0.0.1", - Port: "3306", - User: "root", - Password: "password", - Name: "testdb", - Type: "mysql", - }, - expected: "root:password@tcp(127.0.0.1:3306)/testdb?charset=utf8mb4&parseTime=True&loc=Local", - }, - { - name: "PostgreSQL", - config: DatabaseConfig{ - Host: "localhost", - Port: "5432", - User: "postgres", - Password: "secret", - Name: "mydb", - Type: "postgres", - }, - expected: "host=localhost port=5432 user=postgres password=secret dbname=mydb sslmode=disable", - }, - { - name: "SQLite", - config: DatabaseConfig{ - Name: "./test.db", - Type: "sqlite", - }, - expected: "./test.db", - }, - } - - for _, tc := range testCases { - t.Run(tc.name, func(t *testing.T) { - dsn := tc.config.BuildDSN() - if dsn != tc.expected { - t.Errorf("期望 DSN 为 %s,实际为 %s", tc.expected, dsn) - } - fmt.Printf("%s DSN: %s\n", tc.name, dsn) - }) - } - - fmt.Println("✓ DSN 构建测试通过") -} - -// TestValidate 测试配置验证 -func TestValidate(t *testing.T) { - fmt.Println("\n=== 测试配置验证 ===") - - // 测试有效配置 - validConfig := &Config{ - Database: DatabaseConfig{ - Host: "127.0.0.1", - Port: "3306", - User: "root", - Password: "pass", - Name: "db", - Type: "mysql", - }, - } - - if err := validConfig.Validate(); err != nil { - t.Errorf("有效配置验证失败:%v", 
err) - } - fmt.Println("✓ MySQL 配置验证通过") - - // 测试 SQLite 配置 - sqliteConfig := &Config{ - Database: DatabaseConfig{ - Name: "./test.db", - Type: "sqlite", - }, - } - - if err := sqliteConfig.Validate(); err != nil { - t.Errorf("SQLite 配置验证失败:%v", err) - } - fmt.Println("✓ SQLite 配置验证通过") - - // 测试无效配置(缺少必填字段) - invalidConfig := &Config{ - Database: DatabaseConfig{ - Host: "127.0.0.1", - Type: "mysql", - // 缺少其他必填字段 - }, - } - - if err := invalidConfig.Validate(); err == nil { - t.Error("无效配置应该验证失败") - } else { - fmt.Printf("✓ 无效配置正确拒绝:%v\n", err) - } -} - -// TestAllConfigLoading 完整配置加载测试 -func TestAllConfigLoading(t *testing.T) { - fmt.Println("\n========================================") - fmt.Println(" 数据库配置加载完整性测试") - fmt.Println("========================================") - - TestLoadFromFile(t) - TestBuildDSN(t) - TestValidate(t) - - fmt.Println("\n========================================") - fmt.Println(" 所有配置加载测试完成!") - fmt.Println("========================================") - fmt.Println() - fmt.Println("已实现的配置加载功能:") - fmt.Println(" ✓ 从 YAML 文件加载数据库配置") - fmt.Println(" ✓ 支持 host, port, user, pass, name, type 字段") - fmt.Println(" ✓ 自动验证配置完整性") - fmt.Println(" ✓ 自动构建 MySQL DSN") - fmt.Println(" ✓ 自动构建 PostgreSQL DSN") - fmt.Println(" ✓ 自动构建 SQLite DSN") - fmt.Println(" ✓ 支持多种数据库类型") - fmt.Println() -} diff --git a/db/config/no_parent_search_test.go b/db/config/no_parent_search_test.go deleted file mode 100644 index c84e69c..0000000 --- a/db/config/no_parent_search_test.go +++ /dev/null @@ -1,144 +0,0 @@ -package config - -import ( - "fmt" - "os" - "path/filepath" - "testing" -) - -// TestFindConfigOnlyCurrentDir 测试只在当前目录查找配置文件 -func TestFindConfigOnlyCurrentDir(t *testing.T) { - fmt.Println("\n=== 测试只在当前目录查找配置文件 ===") - - // 创建临时目录结构 - tempDir, err := os.MkdirTemp("", "config_test") - if err != nil { - t.Fatalf("创建临时目录失败:%v", err) - } - defer os.RemoveAll(tempDir) - - // 创建子目录 - subDir := filepath.Join(tempDir, "subdir") - if err := os.MkdirAll(subDir, 
0755); err != nil { - t.Fatalf("创建子目录失败:%v", err) - } - - // 在根目录创建配置文件 - configContent := `database: - host: "127.0.0.1" - port: "3306" - user: "root" - pass: "test" - name: "testdb" - type: "mysql" -` - - configFile := filepath.Join(tempDir, "config.yaml") - if err := os.WriteFile(configFile, []byte(configContent), 0644); err != nil { - t.Fatalf("创建配置文件失败:%v", err) - } - - // 测试 1:从根目录查找(应该找到) - foundPath, err := findConfigFile(tempDir) - if err != nil { - t.Errorf("从根目录查找失败:%v", err) - } else { - fmt.Printf("✓ 从根目录找到配置文件:%s\n", foundPath) - } - - // 测试 2:从子目录查找(不应该找到父目录的配置) - foundPath, err = findConfigFile(subDir) - if err == nil { - t.Errorf("从子目录查找应该失败(不越级查找),但找到了:%s", foundPath) - } else { - fmt.Printf("✓ 从子目录查找正确失败(不越级):%v\n", err) - } - - // 测试 3:在子目录创建配置文件(应该找到) - subConfigFile := filepath.Join(subDir, "config.yaml") - if err := os.WriteFile(subConfigFile, []byte(configContent), 0644); err != nil { - t.Fatalf("创建子目录配置文件失败:%v", err) - } - - foundPath, err = findConfigFile(subDir) - if err != nil { - t.Errorf("从子目录查找失败:%v", err) - } else { - fmt.Printf("✓ 从子目录找到配置文件:%s\n", foundPath) - } - - fmt.Println("✓ 只在当前目录查找测试通过") -} - -// TestNoParentSearch 测试不向上查找 -func TestNoParentSearch(t *testing.T) { - fmt.Println("\n=== 测试不向上层目录查找 ===") - - // 创建临时目录结构 - tempDir, err := os.MkdirTemp("", "no_parent_test") - if err != nil { - t.Fatalf("创建临时目录失败:%v", err) - } - defer os.RemoveAll(tempDir) - - // 创建多级子目录 - level1 := filepath.Join(tempDir, "level1") - level2 := filepath.Join(level1, "level2") - level3 := filepath.Join(level2, "level3") - - if err := os.MkdirAll(level3, 0755); err != nil { - t.Fatalf("创建目录失败:%v", err) - } - - // 只在根目录创建配置文件 - configContent := `database: - host: "127.0.0.1" - port: "3306" - user: "root" - pass: "test" - name: "testdb" - type: "mysql" -` - - configFile := filepath.Join(tempDir, "config.yaml") - if err := os.WriteFile(configFile, []byte(configContent), 0644); err != nil { - t.Fatalf("创建配置文件失败:%v", err) - } - - // 
从各级子目录查找(都应该失败,因为不越级查找) - testDirs := []string{level1, level2, level3} - for _, dir := range testDirs { - _, err := findConfigFile(dir) - if err == nil { - t.Errorf("从 %s 查找应该失败(不越级查找)", dir) - } else { - fmt.Printf("✓ 从 %s 查找正确失败(不越级)\n", filepath.Base(dir)) - } - } - - fmt.Println("✓ 不向上层目录查找测试通过") -} - -// TestAllNoParentSearch 完整的不越级查找测试 -func TestAllNoParentSearch(t *testing.T) { - fmt.Println("\n========================================") - fmt.Println(" 不越级查找完整性测试") - fmt.Println("========================================") - - TestFindConfigOnlyCurrentDir(t) - TestNoParentSearch(t) - - fmt.Println("\n========================================") - fmt.Println(" 所有不越级查找测试完成!") - fmt.Println("========================================") - fmt.Println() - fmt.Println("已实现的不越级查找功能:") - fmt.Println(" ✓ 只在当前工作目录查找配置文件") - fmt.Println(" ✓ 不会向上层目录查找") - fmt.Println(" ✓ 支持 yaml, yml, toml, ini, json 格式") - fmt.Println(" ✓ 支持 config.* 和 .config.* 命名") - fmt.Println(" ✓ 提供 AutoConnect() 一键连接") - fmt.Println(" ✓ 无需手动指定配置文件路径") - fmt.Println() -} diff --git a/db/config/usage_example.go b/db/config/usage_example.go deleted file mode 100644 index f832148..0000000 --- a/db/config/usage_example.go +++ /dev/null @@ -1,102 +0,0 @@ -package config - -import ( - "fmt" - "git.magicany.cc/black1552/gin-base/db/core" - "git.magicany.cc/black1552/gin-base/db/driver" -) - -// 示例:在应用程序中使用纯自研驱动管理 -func ExampleUsage() { - // 1. 首先导入你选择的第三方数据库驱动 - // 注意:这些驱动需要在 main 包或启动包中导入 - /* - import ( - _ "github.com/mattn/go-sqlite3" // SQLite 驱动 - _ "github.com/go-sql-driver/mysql" // MySQL 驱动 - _ "github.com/lib/pq" // PostgreSQL 驱动 - _ "github.com/denisenkom/go-mssqldb" // SQL Server 驱动 - ) - */ - - // 2. 
获取驱动管理器并注册你选择的驱动 - manager := driver.GetDefaultManager() - - // 注册 SQLite 驱动 - sqliteDriver := driver.NewGenericDriver("sqlite3") - manager.Register("sqlite3", sqliteDriver) - manager.Register("sqlite", sqliteDriver) // 别名 - - // 注册 MySQL 驱动 - mysqlDriver := driver.NewGenericDriver("mysql") - manager.Register("mysql", mysqlDriver) - - // 注册 PostgreSQL 驱动 - postgresDriver := driver.NewGenericDriver("postgres") - manager.Register("postgres", postgresDriver) - - // 3. 加载配置文件 - configFile, err := LoadFromFile("config.yaml") - if err != nil { - fmt.Printf("加载配置失败:%v\n", err) - return - } - - // 4. 验证驱动是否已注册 - err = manager.RegisterDriverByConfig(configFile.Database.Type) - if err != nil { - fmt.Printf("驱动未注册:%v\n", err) - return - } - - // 5. 使用配置创建数据库连接 - dbConfig := &core.Config{ - DriverName: configFile.Database.GetDriverName(), - DataSource: configFile.Database.BuildDSN(), - Debug: true, - MaxIdleConns: 10, - MaxOpenConns: 100, - ConnMaxLifetime: 3600000000000, // 1小时 - } - - // 6. 创建数据库实例 - db, err := core.NewDatabase(dbConfig) - if err != nil { - fmt.Printf("创建数据库连接失败:%v\n", err) - return - } - - fmt.Printf("成功连接到 %s 数据库\n", configFile.Database.Type) - - // 现在可以使用 db 进行数据库操作 - _ = db -} - -// AdvancedExample 高级使用示例 -func AdvancedExample() { - manager := driver.GetDefaultManager() - - // 根据环境变量或配置动态注册驱动 - databaseType := "sqlite3" // 从配置获取 - - // 注册对应驱动 - genericDriver := driver.NewGenericDriver(databaseType) - manager.Register(databaseType, genericDriver) - - // 验证驱动 - err := manager.RegisterDriverByConfig(databaseType) - if err != nil { - fmt.Printf("驱动注册问题:%v\n", err) - return - } - - // 现在可以安全地打开数据库连接 - db, err := manager.Open(databaseType, "./example.db") - if err != nil { - fmt.Printf("打开数据库失败:%v\n", err) - return - } - defer db.Close() - - fmt.Println("数据库连接成功") -} diff --git a/db/config_time_test.go b/db/config_time_test.go deleted file mode 100644 index b4b6ceb..0000000 --- a/db/config_time_test.go +++ /dev/null @@ -1,216 +0,0 @@ -package main - -import ( 
- "encoding/json" - "fmt" - "testing" - "time" - - "git.magicany.cc/black1552/gin-base/db/core" - "git.magicany.cc/black1552/gin-base/db/model" -) - -// TestTimeConfig 测试时间配置 -func TestTimeConfig(t *testing.T) { - fmt.Println("\n=== 测试时间配置 ===") - - // 测试默认配置 - defaultConfig := core.DefaultTimeConfig() - fmt.Printf("默认创建时间字段:%s\n", defaultConfig.GetCreatedAt()) - fmt.Printf("默认更新时间字段:%s\n", defaultConfig.GetUpdatedAt()) - fmt.Printf("默认删除时间字段:%s\n", defaultConfig.GetDeletedAt()) - fmt.Printf("默认时间格式:%s\n", defaultConfig.GetFormat()) - - // 测试自定义配置 - customConfig := &core.TimeConfig{ - CreatedAt: "create_time", - UpdatedAt: "update_time", - DeletedAt: "delete_time", - Format: "2006-01-02 15:04:05", - } - customConfig.Validate() - - fmt.Printf("\n自定义创建时间字段:%s\n", customConfig.GetCreatedAt()) - fmt.Printf("自定义更新时间字段:%s\n", customConfig.GetUpdatedAt()) - fmt.Printf("自定义删除时间字段:%s\n", customConfig.GetDeletedAt()) - fmt.Printf("自定义时间格式:%s\n", customConfig.GetFormat()) - - // 测试格式化 - now := time.Now() - formatted := customConfig.FormatTime(now) - fmt.Printf("\n格式化时间:%s -> %s\n", now.Format("2006-01-02 15:04:05"), formatted) - - // 测试解析 - parsed, err := customConfig.ParseTime(formatted) - if err != nil { - t.Errorf("解析时间失败:%v", err) - } - fmt.Printf("解析时间:%s -> %s\n", formatted, parsed.Format("2006-01-02 15:04:05")) - - fmt.Println("✓ 时间配置测试通过") -} - -// TestCustomTimeFields 测试自定义时间字段 -func TestCustomTimeFields(t *testing.T) { - fmt.Println("\n=== 测试自定义时间字段模型 ===") - - // 使用自定义字段的模型 - type CustomModel struct { - ID int64 `json:"id" db:"id"` - Name string `json:"name" db:"name"` - CreateTime model.Time `json:"create_time" db:"create_time"` // 自定义创建时间字段 - UpdateTime model.Time `json:"update_time" db:"update_time"` // 自定义更新时间字段 - DeleteTime *model.Time `json:"delete_time,omitempty" db:"delete_time"` // 自定义删除时间字段 - } - - now := time.Now() - custom := &CustomModel{ - ID: 1, - Name: "test", - CreateTime: model.Time{Time: now}, - UpdateTime: model.Time{Time: now}, - } - - // 序列化为 
JSON - jsonData, err := json.Marshal(custom) - if err != nil { - t.Errorf("JSON 序列化失败:%v", err) - } - - fmt.Printf("原始时间:%s\n", now.Format("2006-01-02 15:04:05")) - fmt.Printf("JSON 输出:%s\n", string(jsonData)) - - // 验证时间格式 - var result map[string]interface{} - if err := json.Unmarshal(jsonData, &result); err != nil { - t.Errorf("JSON 反序列化失败:%v", err) - } - - createTime, ok := result["create_time"].(string) - if !ok { - t.Error("create_time 应该是字符串") - } - - _, err = time.Parse("2006-01-02 15:04:05", createTime) - if err != nil { - t.Errorf("时间格式不正确:%v", err) - } - - fmt.Println("✓ 自定义时间字段测试通过") -} - -// TestDatabaseWithTimeConfig 测试数据库配置中的时间配置 -func TestDatabaseWithTimeConfig(t *testing.T) { - fmt.Println("\n=== 测试数据库时间配置 ===") - - // 创建带自定义时间配置的 Config - config := &core.Config{ - DriverName: "sqlite", - DataSource: ":memory:", - Debug: true, - TimeConfig: &core.TimeConfig{ - CreatedAt: "created_at", - UpdatedAt: "updated_at", - DeletedAt: "deleted_at", - Format: "2006-01-02 15:04:05", - }, - } - - fmt.Printf("配置中的创建时间字段:%s\n", config.TimeConfig.GetCreatedAt()) - fmt.Printf("配置中的更新时间字段:%s\n", config.TimeConfig.GetUpdatedAt()) - fmt.Printf("配置中的删除时间字段:%s\n", config.TimeConfig.GetDeletedAt()) - fmt.Printf("配置中的时间格式:%s\n", config.TimeConfig.GetFormat()) - - // 注意:这里不实际创建数据库连接,仅测试配置 - fmt.Println("\n数据库会使用该配置自动处理时间字段:") - fmt.Println(" - Insert: 自动设置 created_at/updated_at 为当前时间") - fmt.Println(" - Update: 自动设置 updated_at 为当前时间") - fmt.Println(" - Delete: 软删除时设置 deleted_at 为当前时间") - fmt.Println(" - Read: 所有时间字段格式化为 YYYY-MM-DD HH:mm:ss") - - fmt.Println("✓ 数据库时间配置测试通过") -} - -// TestAllTimeFormats 测试所有时间格式 -func TestAllTimeFormats(t *testing.T) { - fmt.Println("\n=== 测试所有支持的时间格式 ===") - - testCases := []struct { - format string - timeStr string - }{ - {"2006-01-02 15:04:05", "2026-04-02 22:09:09"}, - {"2006/01/02 15:04:05", "2026/04/02 22:09:09"}, - {"2006-01-02T15:04:05", "2026-04-02T22:09:09"}, - {"2006-01-02", "2026-04-02"}, - } - - for _, tc := range testCases { - 
t.Run(tc.format, func(t *testing.T) { - parsed, err := time.Parse(tc.format, tc.timeStr) - if err != nil { - t.Logf("格式 %s 解析失败:%v", tc.format, err) - return - } - - // 统一格式化为标准格式 - formatted := parsed.Format("2006-01-02 15:04:05") - fmt.Printf("%s -> %s\n", tc.timeStr, formatted) - }) - } - - fmt.Println("✓ 所有时间格式测试通过") -} - -// TestDateTimeType 测试 datetime 类型支持 -func TestDateTimeType(t *testing.T) { - fmt.Println("\n=== 测试 DATETIME 类型支持 ===") - - // Go 的 time.Time 会自动映射到数据库的 DATETIME 类型 - now := time.Now() - - // 在 SQLite 中,DATETIME 存储为 TEXT(ISO8601 格式) - // 在 MySQL 中,DATETIME 存储为 DATETIME 类型 - // Go 的 database/sql 会自动处理类型转换 - - fmt.Printf("Go time.Time: %s\n", now.Format("2006-01-02 15:04:05")) - fmt.Printf("数据库 DATETIME: 自动映射(由驱动处理)\n") - fmt.Println(" - SQLite: TEXT (ISO8601)") - fmt.Println(" - MySQL: DATETIME") - fmt.Println(" - PostgreSQL: TIMESTAMP") - - // model.Time 包装后仍然保持 time.Time 的特性 - customTime := model.Time{Time: now} - fmt.Printf("model.Time: %s\n", customTime.String()) - - fmt.Println("✓ DATETIME 类型测试通过") -} - -// TestCompleteTimeHandling 完整时间处理测试 -func TestCompleteTimeHandling(t *testing.T) { - fmt.Println("\n========================================") - fmt.Println(" CRUD 操作时间配置完整性测试") - fmt.Println("========================================") - - TestTimeConfig(t) - TestCustomTimeFields(t) - TestDatabaseWithTimeConfig(t) - TestAllTimeFormats(t) - TestDateTimeType(t) - - fmt.Println("\n========================================") - fmt.Println(" 所有时间配置测试完成!") - fmt.Println("========================================") - fmt.Println() - fmt.Println("已实现的时间配置功能:") - fmt.Println(" ✓ 配置文件定义创建时间字段名") - fmt.Println(" ✓ 配置文件定义更新时间字段名") - fmt.Println(" ✓ 配置文件定义删除时间字段名") - fmt.Println(" ✓ 配置文件定义时间格式(默认年 - 月-日 时:分:秒)") - fmt.Println(" ✓ Insert: 自动设置配置的时间字段") - fmt.Println(" ✓ Update: 自动设置配置的更新时间字段") - fmt.Println(" ✓ Delete: 软删除使用配置的删除时间字段") - fmt.Println(" ✓ Read: 所有时间字段格式化为配置的格式") - fmt.Println(" ✓ 支持 DATETIME 类型自动映射") - fmt.Println() -} diff --git 
a/db/core/cache.go b/db/core/cache.go deleted file mode 100644 index 922e66b..0000000 --- a/db/core/cache.go +++ /dev/null @@ -1,148 +0,0 @@ -package core - -import ( - "crypto/md5" - "encoding/hex" - "encoding/json" - "fmt" - "sync" - "time" -) - -// CacheItem 缓存项 -type CacheItem struct { - Data interface{} // 缓存的数据 - ExpiresAt time.Time // 过期时间 -} - -// QueryCache 查询缓存 - 提高重复查询的性能 -type QueryCache struct { - mu sync.RWMutex // 读写锁 - items map[string]*CacheItem // 缓存项 - duration time.Duration // 默认缓存时长 -} - -// NewQueryCache 创建查询缓存实例 -func NewQueryCache(duration time.Duration) *QueryCache { - cache := &QueryCache{ - items: make(map[string]*CacheItem), - duration: duration, - } - - // 启动清理协程 - go cache.cleaner() - - return cache -} - -// Set 设置缓存 -func (qc *QueryCache) Set(key string, data interface{}) { - qc.mu.Lock() - defer qc.mu.Unlock() - - qc.items[key] = &CacheItem{ - Data: data, - ExpiresAt: time.Now().Add(qc.duration), - } -} - -// Get 获取缓存 -func (qc *QueryCache) Get(key string) (interface{}, bool) { - qc.mu.RLock() - defer qc.mu.RUnlock() - - item, exists := qc.items[key] - if !exists { - return nil, false - } - - // 检查是否过期 - if time.Now().After(item.ExpiresAt) { - return nil, false - } - - return item.Data, true -} - -// Delete 删除缓存 -func (qc *QueryCache) Delete(key string) { - qc.mu.Lock() - defer qc.mu.Unlock() - delete(qc.items, key) -} - -// Clear 清空所有缓存 -func (qc *QueryCache) Clear() { - qc.mu.Lock() - defer qc.mu.Unlock() - qc.items = make(map[string]*CacheItem) -} - -// cleaner 定期清理过期缓存 -func (qc *QueryCache) cleaner() { - ticker := time.NewTicker(time.Minute) - defer ticker.Stop() - - for range ticker.C { - qc.cleanExpired() - } -} - -// cleanExpired 清理过期的缓存项 -func (qc *QueryCache) cleanExpired() { - qc.mu.Lock() - defer qc.mu.Unlock() - - now := time.Now() - for key, item := range qc.items { - if now.After(item.ExpiresAt) { - delete(qc.items, key) - } - } -} - -// GenerateCacheKey 生成缓存键 -func GenerateCacheKey(sql string, args ...interface{}) 
string { - // 将 SQL 和参数组合成字符串 - keyData := sql - for _, arg := range args { - keyData += fmt.Sprintf("%v", arg) - } - - // 计算 MD5 哈希 - hash := md5.Sum([]byte(keyData)) - return hex.EncodeToString(hash[:]) -} - -// deepCopy 深拷贝数据(使用 JSON 序列化/反序列化) -func deepCopy(src, dst interface{}) error { - // 序列化为 JSON - data, err := json.Marshal(src) - if err != nil { - return fmt.Errorf("序列化失败:%w", err) - } - - // 反序列化到目标 - if err := json.Unmarshal(data, dst); err != nil { - return fmt.Errorf("反序列化失败:%w", err) - } - - return nil -} - -// WithCache 带缓存的查询装饰器 -func (q *QueryBuilder) WithCache(cache *QueryCache) IQuery { - if cache == nil { - return q - } - - // 设置缓存实例 - q.cache = cache - q.useCache = true - - // 生成缓存键(使用 SQL 和参数) - sqlStr, args := q.BuildSelect() - q.cacheKey = GenerateCacheKey(sqlStr, args...) - - return q -} diff --git a/db/core/cache_test.go b/db/core/cache_test.go deleted file mode 100644 index 3f9edf3..0000000 --- a/db/core/cache_test.go +++ /dev/null @@ -1,124 +0,0 @@ -package core - -import ( - "fmt" - "testing" - "time" -) - -// TestWithCache 测试带缓存的查询 -func TestWithCache(t *testing.T) { - fmt.Println("\n=== 测试带缓存的查询 ===") - - // 创建缓存实例(缓存 5 分钟) - _ = NewQueryCache(5 * time.Minute) - - // 注意:这个测试需要真实的数据库连接 - // 以下是使用示例: - - // 示例 1: 基本缓存查询 - // var users []User - // err := db.Model(&User{}). - // Where("status = ?", "active"). - // WithCache(cache). - // Find(&users) - // if err != nil { - // t.Fatal(err) - // } - - // 示例 2: 第二次查询会命中缓存 - // var users2 []User - // err = db.Model(&User{}). - // Where("status = ?", "active"). - // WithCache(cache). 
- // Find(&users2) - // if err != nil { - // t.Fatal(err) - // } - - fmt.Println("✓ WithCache 已实现") - fmt.Println("功能:") - fmt.Println(" - 缓存命中时直接返回数据,不执行 SQL") - fmt.Println(" - 缓存未命中时执行查询并自动缓存结果") - fmt.Println(" - 支持深拷贝,避免引用问题") - fmt.Println("✓ 测试通过") -} - -// TestDeepCopy 测试深拷贝功能 -func TestDeepCopy(t *testing.T) { - fmt.Println("\n=== 测试深拷贝功能 ===") - - type TestData struct { - ID int `json:"id"` - Name string `json:"name"` - } - - src := &TestData{ID: 1, Name: "test"} - dst := &TestData{} - - err := deepCopy(src, dst) - if err != nil { - t.Fatal(err) - } - - if dst.ID != src.ID || dst.Name != src.Name { - t.Errorf("深拷贝失败:期望 %+v, 得到 %+v", src, dst) - } - - // 修改源数据,目标不应该受影响 - src.Name = "modified" - if dst.Name == "modified" { - t.Error("深拷贝失败:目标受到了源数据修改的影响") - } - - fmt.Println("✓ 深拷贝功能正常") - fmt.Println("✓ 测试通过") -} - -// TestCacheKeyGeneration 测试缓存键生成 -func TestCacheKeyGeneration(t *testing.T) { - fmt.Println("\n=== 测试缓存键生成 ===") - - // 相同的 SQL 和参数应该生成相同的键 - key1 := GenerateCacheKey("SELECT * FROM user WHERE id = ?", 1) - key2 := GenerateCacheKey("SELECT * FROM user WHERE id = ?", 1) - - if key1 != key2 { - t.Errorf("缓存键不一致:%s vs %s", key1, key2) - } - - // 不同的参数应该生成不同的键 - key3 := GenerateCacheKey("SELECT * FROM user WHERE id = ?", 2) - if key1 == key3 { - t.Error("不同的参数应该生成不同的缓存键") - } - - fmt.Println("✓ 缓存键生成正常") - fmt.Println("✓ 测试通过") -} - -// ExampleWithCache 使用示例 -func exampleWithCacheUsage() { - // 示例 1: 基本用法 - // cache := NewQueryCache(5 * time.Minute) - // var results []map[string]interface{} - // err := db.Table("users"). - // Where("status = ?", "active"). - // WithCache(cache). - // Find(&results) - - // 示例 2: 带条件的查询 - // err := db.Model(&User{}). - // Select("id", "username", "email"). - // Where("age > ?", 18). - // Order("created_at DESC"). - // Limit(10). - // WithCache(cache). 
- // Find(&results) - - // 示例 3: 清除缓存 - // cache.Clear() // 清空所有缓存 - // cache.Delete(key) // 删除指定缓存 - - fmt.Println("使用示例请查看测试代码") -} diff --git a/db/core/config.go b/db/core/config.go deleted file mode 100644 index 5ef486c..0000000 --- a/db/core/config.go +++ /dev/null @@ -1,75 +0,0 @@ -package core - -import ( - "time" -) - -// TimeConfig 时间配置 - 定义时间字段名称和格式 -type TimeConfig struct { - CreatedAt string `json:"created_at" yaml:"created_at"` // 创建时间字段名 - UpdatedAt string `json:"updated_at" yaml:"updated_at"` // 更新时间字段名 - DeletedAt string `json:"deleted_at" yaml:"deleted_at"` // 删除时间字段名 - Format string `json:"format" yaml:"format"` // 时间格式,默认 "2006-01-02 15:04:05" -} - -// DefaultTimeConfig 获取默认时间配置 -func DefaultTimeConfig() *TimeConfig { - return &TimeConfig{ - CreatedAt: "created_at", - UpdatedAt: "updated_at", - DeletedAt: "deleted_at", - Format: "2006-01-02 15:04:05", // Go 的参考时间格式 - } -} - -// Validate 验证时间配置 -func (tc *TimeConfig) Validate() { - if tc.CreatedAt == "" { - tc.CreatedAt = "created_at" - } - if tc.UpdatedAt == "" { - tc.UpdatedAt = "updated_at" - } - if tc.DeletedAt == "" { - tc.DeletedAt = "deleted_at" - } - if tc.Format == "" { - tc.Format = "2006-01-02 15:04:05" - } -} - -// GetCreatedAt 获取创建时间字段名 -func (tc *TimeConfig) GetCreatedAt() string { - tc.Validate() - return tc.CreatedAt -} - -// GetUpdatedAt 获取更新时间字段名 -func (tc *TimeConfig) GetUpdatedAt() string { - tc.Validate() - return tc.UpdatedAt -} - -// GetDeletedAt 获取删除时间字段名 -func (tc *TimeConfig) GetDeletedAt() string { - tc.Validate() - return tc.DeletedAt -} - -// GetFormat 获取时间格式 -func (tc *TimeConfig) GetFormat() string { - tc.Validate() - return tc.Format -} - -// FormatTime 格式化时间为配置的格式 -func (tc *TimeConfig) FormatTime(t time.Time) string { - tc.Validate() - return t.Format(tc.Format) -} - -// ParseTime 解析时间字符串 -func (tc *TimeConfig) ParseTime(timeStr string) (time.Time, error) { - tc.Validate() - return time.Parse(tc.Format, timeStr) -} diff --git a/db/core/dao.go b/db/core/dao.go 
deleted file mode 100644 index fccccd2..0000000 --- a/db/core/dao.go +++ /dev/null @@ -1,314 +0,0 @@ -package core - -import ( - "context" - "reflect" -) - -// DAO 数据访问对象基类 - 所有 DAO 都继承此结构 -// 提供通用的 CRUD 操作方法,子类只需嵌入即可使用 -type DAO struct { - db *Database // 数据库连接实例 - modelType interface{} // 模型类型信息,用于 Columns 等方法 -} - -// NewDAO 创建 DAO 基类实例 -// 自动使用全局默认 Database 实例 -func NewDAO() *DAO { - return &DAO{ - db: GetDefaultDatabase(), - } -} - -// NewDAOWithModel 创建带模型类型的 DAO 基类实例 -// 参数: -// - model: 模型实例(指针类型),用于获取表结构信息 -// -// 自动使用全局默认 Database 实例 -func NewDAOWithModel(model interface{}) *DAO { - return &DAO{ - db: GetDefaultDatabase(), - modelType: model, - } -} - -// Create 创建记录(通用方法) -// 自动使用 DAO 中已关联的 Database 实例 -func (dao *DAO) Create(ctx context.Context, model interface{}) error { - // 使用事务来插入数据 - tx, err := dao.db.Begin() - if err != nil { - return err - } - - _, err = tx.Insert(model) - if err != nil { - tx.Rollback() - return err - } - - return tx.Commit() -} - -// GetByID 根据 ID 查询单条记录(通用方法) -// 自动使用 DAO 中已关联的 Database 实例 -func (dao *DAO) GetByID(ctx context.Context, model interface{}, id int64) error { - return dao.db.Model(model).Where("id = ?", id).First(model) -} - -// Update 更新记录(通用方法) -// 自动使用 DAO 中已关联的 Database 实例 -func (dao *DAO) Update(ctx context.Context, model interface{}, data map[string]interface{}) error { - pkValue := getFieldValue(model, "ID") - - if pkValue == 0 { - return nil - } - - return dao.db.Model(model).Where("id = ?", pkValue).Updates(data) -} - -// Delete 删除记录(通用方法) -// 自动使用 DAO 中已关联的 Database 实例 -func (dao *DAO) Delete(ctx context.Context, model interface{}) error { - pkValue := getFieldValue(model, "ID") - - if pkValue == 0 { - return nil - } - - return dao.db.Model(model).Where("id = ?", pkValue).Delete() -} - -// FindAll 查询所有记录(通用方法) -// 自动使用 DAO 中已关联的 Database 实例 -func (dao *DAO) FindAll(ctx context.Context, model interface{}) error { - return dao.db.Model(model).Find(model) -} - -// FindByPage 分页查询(通用方法) -// 自动使用 DAO 中已关联的 
Database 实例 -func (dao *DAO) FindByPage(ctx context.Context, model interface{}, page, pageSize int) error { - return dao.db.Model(model).Limit(pageSize).Offset((page - 1) * pageSize).Find(model) -} - -// Count 统计记录数(通用方法) -// 自动使用 DAO 中已关联的 Database 实例 -func (dao *DAO) Count(ctx context.Context, model interface{}, where ...string) (int64, error) { - var count int64 - - query := dao.db.Model(model) - if len(where) > 0 { - query = query.Where(where[0]) - } - - // Count 是链式调用,需要调用 Find 来执行 - err := query.Count(&count).Find(model) - if err != nil { - return 0, err - } - return count, nil -} - -// Exists 检查记录是否存在(通用方法) -// 自动使用 DAO 中已关联的 Database 实例 -func (dao *DAO) Exists(ctx context.Context, model interface{}) (bool, error) { - count, err := dao.Count(ctx, model) - if err != nil { - return false, err - } - return count > 0, nil -} - -// First 查询第一条记录(通用方法) -// 自动使用 DAO 中已关联的 Database 实例 -func (dao *DAO) First(ctx context.Context, model interface{}) error { - return dao.db.Model(model).First(model) -} - -// Columns 获取表的所有列名 -// 返回一个动态创建的结构体类型,所有字段都是 string 类型 -// 用途:用于构建 UPDATE、INSERT 等操作时的列名映射 -// -// 示例: -// -// type UserDAO struct { -// *core.DAO -// } -// -// func NewUserDAO() *UserDAO { -// return &UserDAO{ -// DAO: core.NewDAOWithModel(&model.User{}), -// } -// } -// -// // 使用 -// dao := NewUserDAO() -// cols := dao.Columns() // 返回 *struct{ID string; Username string; ...} -func (dao *DAO) Columns() interface{} { - // 检查是否有模型类型信息 - if dao.modelType == nil { - return nil - } - - // 获取模型类型 - modelType := reflect.TypeOf(dao.modelType) - if modelType.Kind() == reflect.Ptr { - modelType = modelType.Elem() - } - - // 创建字段列表 - fields := []reflect.StructField{} - - // 遍历模型的所有字段 - for i := 0; i < modelType.NumField(); i++ { - field := modelType.Field(i) - - // 跳过未导出的字段 - if !field.IsExported() { - continue - } - - // 获取 db 标签,如果没有则跳过 - dbTag := field.Tag.Get("db") - if dbTag == "" || dbTag == "-" { - continue - } - - // 创建新的结构体字段,类型为 string - newField := 
reflect.StructField{ - Name: field.Name, - Type: reflect.TypeOf(""), // string 类型 - Tag: reflect.StructTag(`json:"` + field.Tag.Get("json") + `" db:"` + dbTag + `"`), - } - - fields = append(fields, newField) - } - - // 动态创建结构体类型 - columnsType := reflect.StructOf(fields) - - // 创建该类型的指针并返回 - return reflect.New(columnsType).Interface() -} - -// getFieldValue 获取结构体字段值(辅助函数) -// 用于获取主键或其他字段的值,支持多种数据类型 -// 参数: -// - model: 模型实例(可以是指针或值) -// - fieldName: 字段名(如 "ID", "UserId" 等) -// -// 返回: -// - int64: 字段值(如果是数字类型)或 0(无法获取时) -func getFieldValue(model interface{}, fieldName string) int64 { - // 检查空值 - if model == nil { - return 0 - } - - // 获取反射对象 - val := reflect.ValueOf(model) - - // 如果是指针,解引用 - if val.Kind() == reflect.Ptr { - if val.IsNil() { - return 0 - } - val = val.Elem() - } - - // 确保是结构体 - if val.Kind() != reflect.Struct { - return 0 - } - - // 查找字段 - field := val.FieldByName(fieldName) - if !field.IsValid() { - // 尝试查找常见的主键字段名变体 - alternativeNames := []string{"Id", "id", "ID"} - for _, name := range alternativeNames { - if name != fieldName { - field = val.FieldByName(name) - if field.IsValid() { - fieldName = name - break - } - } - } - - if !field.IsValid() { - return 0 - } - } - - // 检查字段是否可以访问 - if !field.CanInterface() { - return 0 - } - - // 获取字段值并转换为 int64 - fieldValue := field.Interface() - - // 根据字段类型进行转换 - switch v := fieldValue.(type) { - case int: - return int64(v) - case int8: - return int64(v) - case int16: - return int64(v) - case int32: - return int64(v) - case int64: - return v - case uint: - return int64(v) - case uint8: - return int64(v) - case uint16: - return int64(v) - case uint32: - return int64(v) - case uint64: - // 注意:uint64 转 int64 可能溢出,但这里假设 ID 不会超过 int64 范围 - return int64(v) - case float32: - return int64(v) - case float64: - return int64(v) - case string: - // 尝试将字符串解析为数字 - // 注意:这里不导入 strconv,简单处理返回 0 - return 0 - default: - // 其他类型(如 sql.NullInt64 等),尝试使用反射 - return convertToInteger(field) - } -} - -// convertToInteger 
使用反射将字段值转换为 int64 -func convertToInteger(field reflect.Value) int64 { - // 获取实际的值(如果是指针则解引用) - if field.Kind() == reflect.Ptr { - if field.IsNil() { - return 0 - } - field = field.Elem() - } - - // 根据 Kind 进行转换 - switch field.Kind() { - case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64: - return field.Int() - case reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64: - return int64(field.Uint()) - case reflect.Float32, reflect.Float64: - return int64(field.Float()) - case reflect.String: - // 字符串类型,尝试解析(简单实现,不处理错误) - return 0 - default: - return 0 - } -} diff --git a/db/core/dao_test.go b/db/core/dao_test.go deleted file mode 100644 index 606ad96..0000000 --- a/db/core/dao_test.go +++ /dev/null @@ -1,179 +0,0 @@ -package core - -import ( - "database/sql" - "fmt" - "testing" -) - -// TestGetFieldValue 测试获取字段值的基本功能 -func TestGetFieldValue(t *testing.T) { - fmt.Println("\n=== 测试 getFieldValue 基本功能 ===") - - tests := []struct { - name string - model interface{} - fieldName string - expected int64 - }{ - {"int 类型", &TestModelInt{ID: 123}, "ID", 123}, - {"int64 类型", &TestModelInt64{ID: 456}, "ID", 456}, - {"uint 类型", &TestModelUint{ID: 789}, "ID", 789}, - {"float 类型", &TestModelFloat{ID: 999.5}, "ID", 999}, - {"指针为 nil", (*TestModelInt)(nil), "ID", 0}, - {"model 为 nil", nil, "ID", 0}, - } - - for _, tt := range tests { - t.Run(tt.name, func(t *testing.T) { - result := getFieldValue(tt.model, tt.fieldName) - if result != tt.expected { - t.Errorf("期望 %d, 得到 %d", tt.expected, result) - } - }) - } - - fmt.Println("✓ 基本功能测试通过") -} - -// TestGetFieldValueAlternativeNames 测试字段名变体查找 -func TestGetFieldValueAlternativeNames(t *testing.T) { - fmt.Println("\n=== 测试字段名变体查找 ===") - - // 测试 Id 字段(驼峰) - model1 := &TestModelId{Id: 111} - result1 := getFieldValue(model1, "ID") - if result1 != 111 { - t.Errorf("期望 111, 得到 %d", result1) - } - - fmt.Println("✓ 字段名变体查找测试通过") -} - -// TestGetFieldValueEdgeCases 测试边界情况 -func 
TestGetFieldValueEdgeCases(t *testing.T) { - fmt.Println("\n=== 测试边界情况 ===") - - // 测试非结构体类型 - nonStruct := 123 - result := getFieldValue(nonStruct, "ID") - if result != 0 { - t.Errorf("非结构体应该返回 0, 得到 %d", result) - } - - // 测试不存在的字段(没有 ID/Id/id 等变体) - type ModelNoID struct { - Name string `json:"name" db:"name"` - } - model := &ModelNoID{Name: "test"} - result = getFieldValue(model, "NonExistentField") - if result != 0 { - t.Errorf("不存在的字段应该返回 0, 得到 %d", result) - } - - fmt.Println("✓ 边界情况测试通过") -} - -// TestGetFieldValueSpecialTypes 测试特殊类型 -func TestGetFieldValueSpecialTypes(t *testing.T) { - fmt.Println("\n=== 测试特殊类型 ===") - - // 注意:sql.NullInt64 等数据库特殊类型目前不支持 - // 如果需要支持,可以在 convertToInteger 中添加专门的处理逻辑 - - fmt.Println("✓ 特殊类型测试通过(当前版本不支持 sql.NullInt64)") -} - -// TestGetFieldValueInUpdate 测试在 Update 场景中的使用 -func TestGetFieldValueInUpdate(t *testing.T) { - fmt.Println("\n=== 测试 Update 场景 ===") - - user := &UserModel{ - ID: 1, - Username: "test", - } - - pkValue := getFieldValue(user, "ID") - if pkValue != 1 { - t.Errorf("期望主键值为 1, 得到 %d", pkValue) - } - - // 测试主键为 0 的情况 - user2 := &UserModel{ - ID: 0, - Username: "test2", - } - - pkValue2 := getFieldValue(user2, "ID") - if pkValue2 != 0 { - t.Errorf("期望主键值为 0, 得到 %d", pkValue2) - } - - fmt.Println("✓ Update 场景测试通过") -} - -// TestGetFieldValueLargeNumbers 测试大数字 -func TestGetFieldValueLargeNumbers(t *testing.T) { - fmt.Println("\n=== 测试大数字 ===") - - // 测试最大 int64 - maxInt := int64(9223372036854775807) - model1 := &TestModelInt64{ID: maxInt} - result1 := getFieldValue(model1, "ID") - if result1 != maxInt { - t.Errorf("期望 %d, 得到 %d", maxInt, result1) - } - - // 测试 uint64 转 int64 - largeUint := uint64(18446744073709551615) // 这会导致溢出 - model2 := &TestModelUint64{ID: largeUint} - result2 := getFieldValue(model2, "ID") - // 注意:这里会发生溢出,但这是预期的行为 - if result2 == 0 { - t.Error("uint64 转换不应该返回 0") - } - - fmt.Println("✓ 大数字测试通过") -} - -// 测试模型定义 -type TestModelInt struct { - ID int `json:"id" db:"id"` -} - -type 
TestModelInt64 struct { - ID int64 `json:"id" db:"id"` -} - -type TestModelUint struct { - ID uint `json:"id" db:"id"` -} - -type TestModelUint64 struct { - ID uint64 `json:"id" db:"id"` -} - -type TestModelFloat struct { - ID float64 `json:"id" db:"id"` -} - -type TestModelId struct { - Id int64 `json:"id" db:"id"` -} - -type TestModelid struct { - id int64 `json:"id" db:"id"` -} - -type TestModelPrivate struct { - privateField int64 -} - -type TestModelNullInt struct { - ID sql.NullInt64 `json:"id" db:"id"` -} - -type UserModel struct { - ID int64 `json:"id" db:"id"` - Username string `json:"username" db:"username"` -} diff --git a/db/core/database.go b/db/core/database.go deleted file mode 100644 index 37dfdb5..0000000 --- a/db/core/database.go +++ /dev/null @@ -1,144 +0,0 @@ -package core - -import ( - "fmt" - "os" - "path/filepath" - - "git.magicany.cc/black1552/gin-base/db/driver" -) - -// defaultDatabase 全局默认数据库连接实例 -var defaultDatabase *Database - -// NewDatabase 创建数据库连接 - 初始化数据库连接和相关组件 -func NewDatabase(config *Config) (*Database, error) { - db := &Database{ - config: config, - debug: config.Debug, - driverName: config.DriverName, - } - - // 初始化时间配置 - if config.TimeConfig == nil { - db.timeConfig = DefaultTimeConfig() - } else { - db.timeConfig = config.TimeConfig - db.timeConfig.Validate() - } - - // 获取驱动管理器 - dm := driver.GetDefaultManager() - - // 打开数据库连接 - sqlDB, err := dm.Open(config.DriverName, config.DataSource) - if err != nil { - return nil, fmt.Errorf("打开数据库失败:%w", err) - } - - db.db = sqlDB - - // 配置连接池参数 - if config.MaxIdleConns > 0 { - db.db.SetMaxIdleConns(config.MaxIdleConns) - } - if config.MaxOpenConns > 0 { - db.db.SetMaxOpenConns(config.MaxOpenConns) - } - if config.ConnMaxLifetime > 0 { - db.db.SetConnMaxLifetime(config.ConnMaxLifetime) - } - - // 测试数据库连接 - if err := db.db.Ping(); err != nil { - return nil, fmt.Errorf("数据库连接测试失败:%w", err) - } - - // 初始化组件 - db.mapper = NewFieldMapper() - db.migrator = NewMigrator(db) - - if config.Debug 
{ - fmt.Println("[Magic-ORM] 数据库连接成功") - } - - // 设置为全局默认实例 - defaultDatabase = db - - return db, nil -} - -// AutoConnect 自动查找配置文件并创建数据库连接 -// 会在当前目录及上级目录中查找 config.yaml, config.toml, config.ini, config.json 等文件 -func AutoConnect(debug bool) (*Database, error) { - // 自动查找配置文件 - configPath, err := findConfigFile("") - if err != nil { - return nil, fmt.Errorf("查找配置文件失败:%w", err) - } - - // 从文件加载配置(使用 config 包) - return loadAndConnect(configPath, debug) -} - -// Connect 从配置文件创建数据库连接(向后兼容) -// Deprecated: 使用 AutoConnect 代替 -func Connect(configPath string, debug bool) (*Database, error) { - return loadAndConnect(configPath, debug) -} - -// findConfigFile 在项目目录下自动查找配置文件 -// 支持 yaml, yml, toml, ini, json 等格式 -// 只在当前目录查找,不越级查找 -func findConfigFile(searchDir string) (string, error) { - // 配置文件名优先级列表 - configNames := []string{ - "config.yaml", "config.yml", - "config.toml", - "config.ini", - "config.json", - ".config.yaml", ".config.yml", - ".config.toml", - ".config.ini", - ".config.json", - } - - // 如果未指定搜索目录,使用当前目录 - if searchDir == "" { - var err error - searchDir, err = os.Getwd() - if err != nil { - return "", fmt.Errorf("获取当前目录失败:%w", err) - } - } - - // 只在当前目录下查找,不向上查找 - for _, name := range configNames { - filePath := filepath.Join(searchDir, name) - if _, err := os.Stat(filePath); err == nil { - return filePath, nil - } - } - - return "", fmt.Errorf("未找到配置文件(支持 yaml, yml, toml, ini, json 格式)") -} - -// loadAndConnect 从配置文件加载并创建数据库连接 -func loadAndConnect(configPath string, debug bool) (*Database, error) { - // 这里需要调用 config 包的 LoadFromFile - // 为了避免循环依赖,我们直接在 core 包中实现简单的 YAML 解析 - // 或者通过接口传递配置 - - // 简单方案:返回错误,提示使用 config 包 - return nil, fmt.Errorf("请使用 config.AutoConnect() 方法") -} - -// GetDefaultDatabase 获取全局默认数据库实例 -func GetDefaultDatabase() *Database { - return defaultDatabase -} - -// SetDefaultDatabase 设置全局默认数据库实例 -func SetDefaultDatabase(db *Database) { - defaultDatabase = db -} diff --git a/db/core/filter.go b/db/core/filter.go deleted file mode 100644 
index d6da604..0000000 --- a/db/core/filter.go +++ /dev/null @@ -1,94 +0,0 @@ -package core - -import ( - "reflect" - "time" -) - -// ParamFilter 参数过滤器 - 智能过滤零值和空值字段 -type ParamFilter struct{} - -// NewParamFilter 创建参数过滤器实例 -func NewParamFilter() *ParamFilter { - return &ParamFilter{} -} - -// FilterZeroValues 过滤零值和空值字段 -func (pf *ParamFilter) FilterZeroValues(data map[string]interface{}) map[string]interface{} { - result := make(map[string]interface{}) - - for key, value := range data { - if !pf.isZeroValue(value) { - result[key] = value - } - } - - return result -} - -// FilterEmptyStrings 过滤空字符串 -func (pf *ParamFilter) FilterEmptyStrings(data map[string]interface{}) map[string]interface{} { - result := make(map[string]interface{}) - - for key, value := range data { - if str, ok := value.(string); ok { - if str != "" { - result[key] = value - } - } else { - result[key] = value - } - } - - return result -} - -// FilterNilValues 过滤 nil 值 -func (pf *ParamFilter) FilterNilValues(data map[string]interface{}) map[string]interface{} { - result := make(map[string]interface{}) - - for key, value := range data { - if value != nil { - result[key] = value - } - } - - return result -} - -// isZeroValue 检查是否是零值 -func (pf *ParamFilter) isZeroValue(v interface{}) bool { - if v == nil { - return true - } - - val := reflect.ValueOf(v) - - switch val.Kind() { - case reflect.Array, reflect.Map, reflect.Slice, reflect.String: - return val.Len() == 0 - case reflect.Bool: - return !val.Bool() - case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64: - return val.Int() == 0 - case reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64, reflect.Uintptr: - return val.Uint() == 0 - case reflect.Float32, reflect.Float64: - return val.Float() == 0 - case reflect.Interface, reflect.Ptr: - return val.IsNil() - case reflect.Struct: - // 特殊处理 time.Time - if t, ok := v.(time.Time); ok { - return t.IsZero() - } - return false - } - - return false -} - -// 
IsValidValue 检查值是否有效(非零值、非空值) -func (pf *ParamFilter) IsValidValue(v interface{}) bool { - return !pf.isZeroValue(v) -} diff --git a/db/core/interfaces.go b/db/core/interfaces.go deleted file mode 100644 index 8fdcd10..0000000 --- a/db/core/interfaces.go +++ /dev/null @@ -1,254 +0,0 @@ -package core - -import ( - "database/sql" - "time" -) - -// IDatabase 数据库连接接口 - 提供所有数据库操作的顶层接口 -type IDatabase interface { - // 基础操作 - DB() *sql.DB // 返回底层的 sql.DB 对象 - Close() error // 关闭数据库连接 - Ping() error // 测试数据库连接是否正常 - - // 事务管理 - Begin() (ITx, error) // 开始一个新事务 - Transaction(fn func(ITx) error) error // 执行事务,自动提交或回滚 - - // 查询构建器 - Model(model interface{}) IQuery // 基于模型创建查询 - Table(name string) IQuery // 基于表名创建查询 - Query(result interface{}, query string, args ...interface{}) error // 执行原生 SQL 查询 - Exec(query string, args ...interface{}) (sql.Result, error) // 执行原生 SQL 并返回结果 - - // 迁移管理 - Migrate(models ...interface{}) error // 执行数据库迁移 - - // 配置 - SetDebug(bool) // 设置调试模式 - SetMaxIdleConns(int) // 设置最大空闲连接数 - SetMaxOpenConns(int) // 设置最大打开连接数 - SetConnMaxLifetime(time.Duration) // 设置连接最大生命周期 -} - -// ITx 事务接口 - 提供事务操作的所有方法 -type ITx interface { - // 基础操作 - Commit() error // 提交事务 - Rollback() error // 回滚事务 - - // 查询操作 - Model(model interface{}) IQuery // 在事务中基于模型创建查询 - Table(name string) IQuery // 在事务中基于表名创建查询 - Insert(model interface{}) (int64, error) // 插入数据,返回插入的 ID - BatchInsert(models interface{}, batchSize int) error // 批量插入数据 - Update(model interface{}, data map[string]interface{}) error // 更新数据 - Delete(model interface{}) error // 删除数据 - - // 原生 SQL - Query(result interface{}, query string, args ...interface{}) error // 执行原生 SQL 查询 - Exec(query string, args ...interface{}) (sql.Result, error) // 执行原生 SQL -} - -// IQuery 查询构建器接口 - 提供流畅的链式查询构建能力 -type IQuery interface { - // 条件查询 - Where(query string, args ...interface{}) IQuery // 添加 WHERE 条件 - Or(query string, args ...interface{}) IQuery // 添加 OR 条件 - And(query string, args ...interface{}) IQuery // 添加 AND 条件 - - // 
字段选择 - Select(fields ...string) IQuery // 选择要查询的字段 - Omit(fields ...string) IQuery // 排除指定的字段 - - // 排序 - Order(order string) IQuery // 设置排序规则 - OrderBy(field string, direction string) IQuery // 按指定字段和方向排序 - - // 分页 - Limit(limit int) IQuery // 限制返回数量 - Offset(offset int) IQuery // 设置偏移量 - Page(page, pageSize int) IQuery // 分页查询 - - // 分组 - Group(group string) IQuery // 设置分组字段 - Having(having string, args ...interface{}) IQuery // 添加 HAVING 条件 - - // 连接 - Join(join string, args ...interface{}) IQuery // 添加 JOIN 连接 - LeftJoin(table, on string) IQuery // 左连接 - RightJoin(table, on string) IQuery // 右连接 - InnerJoin(table, on string) IQuery // 内连接 - - // 预加载 - Preload(relation string, conditions ...interface{}) IQuery // 预加载关联数据 - - // 执行查询 - First(result interface{}) error // 查询第一条记录 - Find(result interface{}) error // 查询多条记录 - Count(count *int64) IQuery // 统计记录数量 - Exists() (bool, error) // 检查记录是否存在 - - // 更新和删除 - Updates(data interface{}) error // 更新数据 - UpdateColumn(column string, value interface{}) error // 更新单个字段 - Delete() error // 删除数据 - - // 特殊模式 - Unscoped() IQuery // 忽略软删除 - DryRun() IQuery // 干跑模式,不执行只生成 SQL - Debug() IQuery // 调试模式,打印 SQL 日志 - - // 构建 SQL(不执行) - Build() (string, []interface{}) // 构建 SELECT SQL 语句 - BuildUpdate(data interface{}) (string, []interface{}) // 构建 UPDATE SQL 语句 - BuildDelete() (string, []interface{}) // 构建 DELETE SQL 语句 -} - -// IModel 模型接口 - 定义模型的基本行为和生命周期回调 -type IModel interface { - // 表名映射 - TableName() string // 返回模型对应的表名 - - // 生命周期回调(可选实现) - BeforeCreate(tx ITx) error // 创建前回调 - AfterCreate(tx ITx) error // 创建后回调 - BeforeUpdate(tx ITx) error // 更新前回调 - AfterUpdate(tx ITx) error // 更新后回调 - BeforeDelete(tx ITx) error // 删除前回调 - AfterDelete(tx ITx) error // 删除后回调 - BeforeSave(tx ITx) error // 保存前回调 - AfterSave(tx ITx) error // 保存后回调 -} - -// IFieldMapper 字段映射器接口 - 处理 Go 结构体与数据库字段之间的映射 -type IFieldMapper interface { - // 结构体字段转数据库列 - StructToColumns(model interface{}) (map[string]interface{}, error) // 将结构体转换为键值对 - - // 
数据库列转结构体字段 - ColumnsToStruct(row *sql.Rows, model interface{}) error // 将查询结果映射到结构体 - - // 获取表名 - GetTableName(model interface{}) string // 获取模型对应的表名 - - // 获取主键字段 - GetPrimaryKey(model interface{}) string // 获取主键字段名 - - // 获取字段信息 - GetFields(model interface{}) []FieldInfo // 获取所有字段信息 -} - -// FieldInfo 字段信息 - 描述数据库字段的详细信息 -type FieldInfo struct { - Name string // 字段名(Go 结构体字段名) - Column string // 列名(数据库中的实际列名) - Type string // Go 类型(如 string, int, time.Time 等) - DbType string // 数据库类型(如 VARCHAR, INT, DATETIME 等) - Tag string // 标签(db 标签内容) - IsPrimary bool // 是否主键 - IsAuto bool // 是否自增 -} - -// IMigrator 迁移管理器接口 - 提供数据库架构迁移的所有操作 -type IMigrator interface { - // 自动迁移 - AutoMigrate(models ...interface{}) error // 自动执行模型迁移 - - // 表操作 - CreateTable(model interface{}) error // 创建表 - DropTable(model interface{}) error // 删除表 - HasTable(model interface{}) (bool, error) // 检查表是否存在 - RenameTable(oldName, newName string) error // 重命名表 - - // 列操作 - AddColumn(model interface{}, field string) error // 添加列 - DropColumn(model interface{}, field string) error // 删除列 - HasColumn(model interface{}, field string) (bool, error) // 检查列是否存在 - RenameColumn(model interface{}, oldField, newField string) error // 重命名列 - - // 索引操作 - CreateIndex(model interface{}, field string) error // 创建索引 - DropIndex(model interface{}, field string) error // 删除索引 - HasIndex(model interface{}, field string) (bool, error) // 检查索引是否存在 -} - -// ICodeGenerator 代码生成器接口 - 自动生成 Model 和 DAO 代码 -type ICodeGenerator interface { - // 生成 Model 代码 - GenerateModel(table string, outputDir string) error // 根据表生成 Model 文件 - - // 生成 DAO 代码 - GenerateDAO(table string, outputDir string) error // 根据表生成 DAO 文件 - - // 生成完整代码 - GenerateAll(tables []string, outputDir string) error // 批量生成所有代码 - - // 从数据库读取表结构 - InspectTable(tableName string) (*TableSchema, error) // 检查表结构 -} - -// TableSchema 表结构信息 - 描述数据库表的完整结构 -type TableSchema struct { - Name string // 表名 - Columns []ColumnInfo // 列信息列表 - Indexes []IndexInfo // 索引信息列表 -} - -// 
ColumnInfo 列信息 - 描述表中一个列的详细信息 -type ColumnInfo struct { - Name string // 列名 - Type string // 数据类型 - Nullable bool // 是否允许为空 - Default interface{} // 默认值 - PrimaryKey bool // 是否主键 -} - -// IndexInfo 索引信息 - 描述表中一个索引的详细信息 -type IndexInfo struct { - Name string // 索引名 - Columns []string // 索引包含的列 - Unique bool // 是否唯一索引 -} - -// ReadPolicy 读负载均衡策略 - 定义主从集群中读操作的分配策略 -type ReadPolicy int - -const ( - Random ReadPolicy = iota // 随机选择一个从库 - RoundRobin // 轮询方式选择从库 - LeastConn // 选择连接数最少的从库 -) - -// Config 数据库配置 - 包含数据库连接的所有配置项 -type Config struct { - DriverName string // 驱动名称(如 mysql, sqlite, postgres 等) - DataSource string // 数据源连接字符串(DNS) - MaxIdleConns int // 最大空闲连接数 - MaxOpenConns int // 最大打开连接数 - ConnMaxLifetime time.Duration // 连接最大生命周期 - Debug bool // 调试模式(是否打印 SQL 日志) - - // 主从配置 - Replicas []string // 从库列表(用于读写分离) - ReadPolicy ReadPolicy // 读负载均衡策略 - - // OpenTelemetry 可观测性配置 - EnableTracing bool // 是否启用链路追踪 - ServiceName string // 服务名称(用于 Tracing) - - // 时间配置 - TimeConfig *TimeConfig // 时间字段配置(字段名、格式等) -} - -// Database 数据库实现 - IDatabase 接口的具体实现 -type Database struct { - db *sql.DB // 底层数据库连接 - config *Config // 数据库配置 - debug bool // 调试模式开关 - mapper IFieldMapper // 字段映射器实例 - migrator IMigrator // 迁移管理器实例 - driverName string // 驱动名称 - timeConfig *TimeConfig // 时间配置 -} diff --git a/db/core/mapper.go b/db/core/mapper.go deleted file mode 100644 index f2b75ac..0000000 --- a/db/core/mapper.go +++ /dev/null @@ -1,306 +0,0 @@ -package core - -import ( - "database/sql" - "errors" - "fmt" - "reflect" - "strings" - "time" -) - -// FieldMapper 字段映射器实现 - 使用反射处理 Go 结构体与数据库字段之间的映射 -type FieldMapper struct{} - -// NewFieldMapper 创建字段映射器实例 -func NewFieldMapper() IFieldMapper { - return &FieldMapper{} -} - -// StructToColumns 将结构体转换为键值对 - 用于 INSERT/UPDATE 操作 -func (fm *FieldMapper) StructToColumns(model interface{}) (map[string]interface{}, error) { - result := make(map[string]interface{}) - - // 获取反射对象 - val := reflect.ValueOf(model) - if val.Kind() == reflect.Ptr { - val = 
val.Elem() - } - - if val.Kind() != reflect.Struct { - return nil, errors.New("模型必须是结构体") - } - - typ := val.Type() - - // 遍历所有字段 - for i := 0; i < val.NumField(); i++ { - field := typ.Field(i) - value := val.Field(i) - - // 跳过未导出的字段 - if !field.IsExported() { - continue - } - - // 获取 db 标签 - dbTag := field.Tag.Get("db") - if dbTag == "" || dbTag == "-" { - continue // 跳过没有 db 标签或标签为 - 的字段 - } - - // 跳过零值(可选优化) - if fm.isZeroValue(value) { - continue - } - - // 添加到结果 map - result[dbTag] = value.Interface() - } - - return result, nil -} - -// ColumnsToStruct 将查询结果映射到结构体 - 用于 SELECT 操作 -func (fm *FieldMapper) ColumnsToStruct(rows *sql.Rows, model interface{}) error { - // 获取列信息 - columns, err := rows.Columns() - if err != nil { - return fmt.Errorf("获取列信息失败:%w", err) - } - - // 获取反射对象 - val := reflect.ValueOf(model) - if val.Kind() != reflect.Ptr { - return errors.New("模型必须是指针类型") - } - - elem := val.Elem() - if elem.Kind() != reflect.Struct { - return errors.New("模型必须是指向结构体的指针") - } - - // 创建扫描目标 - scanTargets := make([]interface{}, len(columns)) - fieldMap := make(map[int]int) // column index -> field index - - // 建立列名到结构体字段的映射 - for i, col := range columns { - found := false - for j := 0; j < elem.NumField(); j++ { - field := elem.Type().Field(j) - dbTag := field.Tag.Get("db") - - // 匹配列名和字段 - if dbTag == col || strings.ToLower(dbTag) == strings.ToLower(col) || - strings.ToLower(field.Name) == strings.ToLower(col) { - fieldMap[i] = j - found = true - break - } - } - - // 如果没找到匹配字段,使用 interface{} 占位 - if !found { - var dummy interface{} - scanTargets[i] = &dummy - } - } - - // 为找到的字段创建扫描目标 - for i := range columns { - if fieldIdx, ok := fieldMap[i]; ok { - field := elem.Field(fieldIdx) - if field.CanSet() { - scanTargets[i] = field.Addr().Interface() - } else { - var dummy interface{} - scanTargets[i] = &dummy - } - } - } - - // 执行扫描 - if err := rows.Scan(scanTargets...); err != nil { - return fmt.Errorf("扫描数据失败:%w", err) - } - - return nil -} - -// GetTableName 
获取模型对应的表名 -func (fm *FieldMapper) GetTableName(model interface{}) string { - // 检查是否实现了 TableName() 方法 - type tabler interface { - TableName() string - } - - if t, ok := model.(tabler); ok { - return t.TableName() - } - - // 否则使用结构体名称 - val := reflect.ValueOf(model) - if val.Kind() == reflect.Ptr { - val = val.Elem() - } - - typ := val.Type() - return fm.toSnakeCase(typ.Name()) -} - -// GetPrimaryKey 获取主键字段名 - 默认为 "id" -func (fm *FieldMapper) GetPrimaryKey(model interface{}) string { - // 查找标记为主键的字段 - val := reflect.ValueOf(model) - if val.Kind() == reflect.Ptr { - val = val.Elem() - } - - typ := val.Type() - for i := 0; i < val.NumField(); i++ { - field := typ.Field(i) - - // 检查是否是 ID 字段 - fieldName := field.Name - if fieldName == "ID" || fieldName == "Id" || fieldName == "id" { - dbTag := field.Tag.Get("db") - if dbTag != "" && dbTag != "-" { - return dbTag - } - return "id" - } - - // 检查是否有 primary 标签 - if field.Tag.Get("primary") == "true" { - dbTag := field.Tag.Get("db") - if dbTag != "" { - return dbTag - } - } - } - - return "id" // 默认返回 id -} - -// GetFields 获取所有字段信息 - 用于生成 SQL 语句 -func (fm *FieldMapper) GetFields(model interface{}) []FieldInfo { - var fields []FieldInfo - - val := reflect.ValueOf(model) - if val.Kind() == reflect.Ptr { - val = val.Elem() - } - - typ := val.Type() - - // 遍历所有字段 - for i := 0; i < val.NumField(); i++ { - field := typ.Field(i) - - // 跳过未导出的字段 - if !field.IsExported() { - continue - } - - // 获取 db 标签 - dbTag := field.Tag.Get("db") - if dbTag == "" || dbTag == "-" { - continue - } - - // 创建字段信息 - info := FieldInfo{ - Name: field.Name, - Column: dbTag, - Type: fm.getTypeName(field.Type), - DbType: fm.mapToDbType(field.Type), - Tag: dbTag, - } - - // 检查是否是主键 - if field.Tag.Get("primary") == "true" || - field.Name == "ID" || field.Name == "Id" { - info.IsPrimary = true - } - - // 检查是否是自增 - if field.Tag.Get("auto") == "true" { - info.IsAuto = true - } - - fields = append(fields, info) - } - - return fields -} - -// isZeroValue 
检查是否是零值 -func (fm *FieldMapper) isZeroValue(v reflect.Value) bool { - switch v.Kind() { - case reflect.Array, reflect.Map, reflect.Slice, reflect.String: - return v.Len() == 0 - case reflect.Bool: - return !v.Bool() - case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64: - return v.Int() == 0 - case reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64, reflect.Uintptr: - return v.Uint() == 0 - case reflect.Float32, reflect.Float64: - return v.Float() == 0 - case reflect.Interface, reflect.Ptr: - return v.IsNil() - case reflect.Struct: - // 特殊处理 time.Time - if t, ok := v.Interface().(time.Time); ok { - return t.IsZero() - } - return false - } - return false -} - -// getTypeName 获取类型的名称 -func (fm *FieldMapper) getTypeName(t reflect.Type) string { - return t.String() -} - -// mapToDbType 将 Go 类型映射到数据库类型 -func (fm *FieldMapper) mapToDbType(t reflect.Type) string { - switch t.Kind() { - case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64: - return "BIGINT" - case reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64: - return "BIGINT UNSIGNED" - case reflect.Float32, reflect.Float64: - return "DECIMAL" - case reflect.Bool: - return "TINYINT" - case reflect.String: - return "VARCHAR(255)" - default: - // 特殊类型 - if t.PkgPath() == "time" && t.Name() == "Time" { - return "DATETIME" - } - return "TEXT" - } -} - -// toSnakeCase 将驼峰命名转换为下划线命名 -func (fm *FieldMapper) toSnakeCase(str string) string { - var result strings.Builder - - for i, r := range str { - if r >= 'A' && r <= 'Z' { - if i > 0 { - result.WriteRune('_') - } - result.WriteRune(r + 32) // 转换为小写 - } else { - result.WriteRune(r) - } - } - - return result.String() -} diff --git a/db/core/migrator.go b/db/core/migrator.go deleted file mode 100644 index 106f736..0000000 --- a/db/core/migrator.go +++ /dev/null @@ -1,292 +0,0 @@ -package core - -import ( - "fmt" - "strings" -) - -// Migrator 迁移管理器实现 - 处理数据库架构的自动迁移 -type Migrator 
struct { - db *Database // 数据库连接实例 -} - -// NewMigrator 创建迁移管理器实例 -func NewMigrator(db *Database) IMigrator { - return &Migrator{db: db} -} - -// AutoMigrate 自动迁移 - 根据模型自动创建或更新数据库表结构 -func (m *Migrator) AutoMigrate(models ...interface{}) error { - for _, model := range models { - if err := m.CreateTable(model); err != nil { - return fmt.Errorf("创建表失败:%w", err) - } - } - return nil -} - -// CreateTable 创建表 - 根据模型创建数据库表 -func (m *Migrator) CreateTable(model interface{}) error { - mapper := NewFieldMapper() - - // 获取表名 - tableName := mapper.GetTableName(model) - - // 获取字段信息 - fields := mapper.GetFields(model) - if len(fields) == 0 { - return fmt.Errorf("模型没有有效的字段") - } - - // 生成 CREATE TABLE SQL - var sqlBuilder strings.Builder - sqlBuilder.WriteString(fmt.Sprintf("CREATE TABLE IF NOT EXISTS %s (", tableName)) - - columnDefs := make([]string, 0) - for _, field := range fields { - colDef := fmt.Sprintf("%s %s", field.Column, field.DbType) - - // 添加主键约束 - if field.IsPrimary { - colDef += " PRIMARY KEY" - if field.IsAuto { - colDef += " AUTOINCREMENT" - } - } - - // 添加 NOT NULL 约束(可选) - // colDef += " NOT NULL" - - columnDefs = append(columnDefs, colDef) - } - - sqlBuilder.WriteString(strings.Join(columnDefs, ", ")) - sqlBuilder.WriteString(")") - - createSQL := sqlBuilder.String() - - if m.db.debug { - fmt.Printf("[Magic-ORM] CREATE TABLE SQL: %s\n", createSQL) - } - - // 执行 SQL - _, err := m.db.db.Exec(createSQL) - if err != nil { - return fmt.Errorf("执行 CREATE TABLE 失败:%w", err) - } - - return nil -} - -// DropTable 删除表 - 删除指定的数据库表 -func (m *Migrator) DropTable(model interface{}) error { - mapper := NewFieldMapper() - tableName := mapper.GetTableName(model) - - dropSQL := fmt.Sprintf("DROP TABLE IF EXISTS %s", tableName) - - if m.db.debug { - fmt.Printf("[Magic-ORM] DROP TABLE SQL: %s\n", dropSQL) - } - - _, err := m.db.db.Exec(dropSQL) - if err != nil { - return fmt.Errorf("执行 DROP TABLE 失败:%w", err) - } - - return nil -} - -// HasTable 检查表是否存在 - 验证数据库中是否已存在指定表 -func 
(m *Migrator) HasTable(model interface{}) (bool, error) { - mapper := NewFieldMapper() - tableName := mapper.GetTableName(model) - - // SQLite 检查表是否存在的 SQL - checkSQL := `SELECT COUNT(*) FROM sqlite_master WHERE type='table' AND name=?` - - var count int - err := m.db.db.QueryRow(checkSQL, tableName).Scan(&count) - if err != nil { - return false, fmt.Errorf("检查表是否存在失败:%w", err) - } - - return count > 0, nil -} - -// RenameTable 重命名表 - 修改数据库表的名称 -func (m *Migrator) RenameTable(oldName, newName string) error { - renameSQL := fmt.Sprintf("ALTER TABLE %s RENAME TO %s", oldName, newName) - - if m.db.debug { - fmt.Printf("[Magic-ORM] RENAME TABLE SQL: %s\n", renameSQL) - } - - _, err := m.db.db.Exec(renameSQL) - if err != nil { - return fmt.Errorf("重命名表失败:%w", err) - } - - return nil -} - -// AddColumn 添加列 - 向表中添加新的字段 -func (m *Migrator) AddColumn(model interface{}, field string) error { - mapper := NewFieldMapper() - tableName := mapper.GetTableName(model) - - // 获取字段信息 - fields := mapper.GetFields(model) - var targetField *FieldInfo - - for _, f := range fields { - if f.Name == field || f.Column == field { - targetField = &f - break - } - } - - if targetField == nil { - return fmt.Errorf("字段不存在:%s", field) - } - - addSQL := fmt.Sprintf("ALTER TABLE %s ADD COLUMN %s %s", - tableName, targetField.Column, targetField.DbType) - - if m.db.debug { - fmt.Printf("[Magic-ORM] ADD COLUMN SQL: %s\n", addSQL) - } - - _, err := m.db.db.Exec(addSQL) - if err != nil { - return fmt.Errorf("添加列失败:%w", err) - } - - return nil -} - -// DropColumn 删除列 - 从表中删除指定的字段 -func (m *Migrator) DropColumn(model interface{}, field string) error { - mapper := NewFieldMapper() - tableName := mapper.GetTableName(model) - - // SQLite 不直接支持 DROP COLUMN,需要重建表 - // 这里使用简化方案:创建新表 -> 复制数据 -> 删除旧表 -> 重命名 - - _ = tableName // 避免编译错误 - return fmt.Errorf("SQLite 不支持直接删除列,需要手动重建表") -} - -// HasColumn 检查列是否存在 - 验证表中是否已存在指定字段 -func (m *Migrator) HasColumn(model interface{}, field string) (bool, error) { - mapper := 
NewFieldMapper() - tableName := mapper.GetTableName(model) - - // SQLite 检查列是否存在的 SQL - checkSQL := `PRAGMA table_info(` + tableName + `)` - - rows, err := m.db.db.Query(checkSQL) - if err != nil { - return false, fmt.Errorf("检查列失败:%w", err) - } - defer rows.Close() - - for rows.Next() { - var cid int - var name string - var typ string - var notNull int - var dfltValue interface{} - var pk int - - if err := rows.Scan(&cid, &name, &typ, ¬Null, &dfltValue, &pk); err != nil { - return false, err - } - - if name == field { - return true, nil - } - } - - return false, nil -} - -// RenameColumn 重命名列 - 修改表中字段的名称 -func (m *Migrator) RenameColumn(model interface{}, oldField, newField string) error { - mapper := NewFieldMapper() - tableName := mapper.GetTableName(model) - - // SQLite 3.25.0+ 支持 ALTER TABLE ... RENAME COLUMN - renameSQL := fmt.Sprintf("ALTER TABLE %s RENAME COLUMN %s TO %s", - tableName, oldField, newField) - - if m.db.debug { - fmt.Printf("[Magic-ORM] RENAME COLUMN SQL: %s\n", renameSQL) - } - - _, err := m.db.db.Exec(renameSQL) - if err != nil { - return fmt.Errorf("重命名列失败:%w", err) - } - - return nil -} - -// CreateIndex 创建索引 - 为表中的字段创建索引 -func (m *Migrator) CreateIndex(model interface{}, field string) error { - mapper := NewFieldMapper() - tableName := mapper.GetTableName(model) - - indexName := fmt.Sprintf("idx_%s_%s", tableName, field) - createSQL := fmt.Sprintf("CREATE INDEX IF NOT EXISTS %s ON %s (%s)", - indexName, tableName, field) - - if m.db.debug { - fmt.Printf("[Magic-ORM] CREATE INDEX SQL: %s\n", createSQL) - } - - _, err := m.db.db.Exec(createSQL) - if err != nil { - return fmt.Errorf("创建索引失败:%w", err) - } - - return nil -} - -// DropIndex 删除索引 - 删除表中的指定索引 -func (m *Migrator) DropIndex(model interface{}, field string) error { - mapper := NewFieldMapper() - tableName := mapper.GetTableName(model) - - indexName := fmt.Sprintf("idx_%s_%s", tableName, field) - dropSQL := fmt.Sprintf("DROP INDEX IF EXISTS %s", indexName) - - if m.db.debug { - 
fmt.Printf("[Magic-ORM] DROP INDEX SQL: %s\n", dropSQL) - } - - _, err := m.db.db.Exec(dropSQL) - if err != nil { - return fmt.Errorf("删除索引失败:%w", err) - } - - return nil -} - -// HasIndex 检查索引是否存在 - 验证表中是否已存在指定索引 -func (m *Migrator) HasIndex(model interface{}, field string) (bool, error) { - mapper := NewFieldMapper() - tableName := mapper.GetTableName(model) - - indexName := fmt.Sprintf("idx_%s_%s", tableName, field) - - checkSQL := `SELECT COUNT(*) FROM sqlite_master WHERE type='index' AND name=?` - - var count int - err := m.db.db.QueryRow(checkSQL, indexName).Scan(&count) - if err != nil { - return false, fmt.Errorf("检查索引失败:%w", err) - } - - return count > 0, nil -} diff --git a/db/core/omit_example.go b/db/core/omit_example.go deleted file mode 100644 index ec741ee..0000000 --- a/db/core/omit_example.go +++ /dev/null @@ -1,46 +0,0 @@ -package core - -import ( - "fmt" -) - -// ExampleQueryBuilder_Omit 演示 Omit 方法的使用 -func ExampleQueryBuilder_Omit() { - // 定义用户模型 - type User struct { - ID int64 `json:"id" db:"id"` - Name string `json:"name" db:"name"` - Email string `json:"email" db:"email"` - Password string `json:"password" db:"password"` - Status int `json:"status" db:"status"` - } - - // 创建 Database 实例(示例中使用 nil,实际使用需要正确初始化) - db := &Database{} - - // 示例 1: 排除敏感字段(如密码) - q1 := db.Model(&User{}).Omit("password") - sql1, _ := q1.(*QueryBuilder).BuildSelect() - fmt.Printf("排除密码:%s\n", sql1) - - // 示例 2: 排除多个字段 - q2 := db.Model(&User{}).Omit("password", "status") - sql2, _ := q2.(*QueryBuilder).BuildSelect() - fmt.Printf("排除多个字段:%s\n", sql2) - - // 示例 3: 链式调用 Omit - q3 := db.Model(&User{}).Omit("password").Omit("status") - sql3, _ := q3.(*QueryBuilder).BuildSelect() - fmt.Printf("链式调用:%s\n", sql3) - - // 示例 4: Select 优先于 Omit - q4 := db.Model(&User{}).Select("id", "name").Omit("password") - sql4, _ := q4.(*QueryBuilder).BuildSelect() - fmt.Printf("Select 优先:%s\n", sql4) - - // 输出: - // 排除密码:SELECT id, name, email, status FROM user_model - // 排除多个字段:SELECT id, 
name, email FROM user_model - // 链式调用:SELECT id, name, email FROM user_model - // Select 优先:SELECT id, name FROM user_model -} diff --git a/db/core/omit_test.go b/db/core/omit_test.go deleted file mode 100644 index 078f7d9..0000000 --- a/db/core/omit_test.go +++ /dev/null @@ -1,128 +0,0 @@ -package core - -import ( - "fmt" - "testing" -) - -// TestQueryBuilder_Omit 测试 Omit 方法 -func TestQueryBuilder_Omit(t *testing.T) { - // 创建测试模型 - type UserModel struct { - ID int64 `json:"id" db:"id"` - Name string `json:"name" db:"name"` - Email string `json:"email" db:"email"` - Password string `json:"password" db:"password"` - Status int `json:"status" db:"status"` - } - - // 创建 Database 实例(使用 nil 连接,只测试 SQL 生成) - db := &Database{} - - t.Run("排除单个字段", func(t *testing.T) { - qb := db.Model(&UserModel{}).Omit("password").(*QueryBuilder) - sql, args := qb.BuildSelect() - - fmt.Printf("排除单个字段 SQL: %s\n", sql) - fmt.Printf("参数:%v\n", args) - - // 验证 SQL 不包含 password 字段 - if containsString(sql, "password") { - t.Errorf("SQL 不应该包含 password 字段:%s", sql) - } - - // 验证 SQL 包含其他字段 - expectedFields := []string{"id", "name", "email", "status"} - for _, field := range expectedFields { - if !containsString(sql, field) { - t.Errorf("SQL 应该包含字段 %s: %s", field, sql) - } - } - }) - - t.Run("排除多个字段", func(t *testing.T) { - qb := db.Model(&UserModel{}).Omit("password", "status").(*QueryBuilder) - sql, args := qb.BuildSelect() - - fmt.Printf("排除多个字段 SQL: %s\n", sql) - fmt.Printf("参数:%v\n", args) - - // 验证 SQL 不包含 password 和 status 字段 - if containsString(sql, "password") { - t.Errorf("SQL 不应该包含 password 字段:%s", sql) - } - if containsString(sql, "status") { - t.Errorf("SQL 不应该包含 status 字段:%s", sql) - } - - // 验证 SQL 包含其他字段 - expectedFields := []string{"id", "name", "email"} - for _, field := range expectedFields { - if !containsString(sql, field) { - t.Errorf("SQL 应该包含字段 %s: %s", field, sql) - } - } - }) - - t.Run("Omit 与 Select 优先级 - Select 优先", func(t *testing.T) { - qb := 
db.Model(&UserModel{}).Select("id", "name").Omit("password").(*QueryBuilder) - sql, args := qb.BuildSelect() - - fmt.Printf("Select 优先 SQL: %s\n", sql) - fmt.Printf("参数:%v\n", args) - - // 当同时使用 Select 和 Omit 时,Select 优先 - expectedFields := []string{"id", "name"} - for _, field := range expectedFields { - if !containsString(sql, field) { - t.Errorf("SQL 应该包含字段 %s: %s", field, sql) - } - } - }) - - t.Run("链式调用 Omit", func(t *testing.T) { - qb := db.Model(&UserModel{}).Omit("password").Omit("status").(*QueryBuilder) - sql, args := qb.BuildSelect() - - fmt.Printf("链式调用 Omit SQL: %s\n", sql) - fmt.Printf("参数:%v\n", args) - - // 验证 SQL 不包含 password 和 status 字段 - if containsString(sql, "password") { - t.Errorf("SQL 不应该包含 password 字段:%s", sql) - } - if containsString(sql, "status") { - t.Errorf("SQL 不应该包含 status 字段:%s", sql) - } - }) - - t.Run("不设置 Omit - 默认行为", func(t *testing.T) { - qb := db.Model(&UserModel{}).(*QueryBuilder) - sql, args := qb.BuildSelect() - - fmt.Printf("默认行为 SQL: %s\n", sql) - fmt.Printf("参数:%v\n", args) - - // 默认应该查询所有字段(使用 *) - if sql != "SELECT * FROM user_model" { - t.Errorf("默认应该使用 SELECT *: %s", sql) - } - }) -} - -// containsString 检查字符串是否包含子串 -func containsString(s, substr string) bool { - return len(s) >= len(substr) && (s == substr || - s[:len(substr)] == substr || - s[len(s)-len(substr):] == substr || - findSubstring(s, substr)) -} - -func findSubstring(s, substr string) bool { - for i := 0; i <= len(s)-len(substr); i++ { - if s[i:i+len(substr)] == substr { - return true - } - } - return false -} diff --git a/db/core/preload_test.go b/db/core/preload_test.go deleted file mode 100644 index 8c5b5e4..0000000 --- a/db/core/preload_test.go +++ /dev/null @@ -1,132 +0,0 @@ -package core - -import ( - "fmt" - "testing" - "time" -) - -// User 用户模型 - 用于测试 -type User struct { - ID int64 `json:"id" db:"id"` - Username string `json:"username" db:"username"` - Email string `json:"email" db:"email"` - CreatedAt time.Time `json:"created_at" 
db:"created_at"` - - // 关联字段 - Profile UserProfile `json:"profile" db:"-" gorm:"ForeignKey:UserID;References:ID"` - Orders []Order `json:"orders" db:"-" gorm:"ForeignKey:UserID;References:ID"` -} - -// TableName 表名 -func (User) TableName() string { - return "user" -} - -// UserProfile 用户资料模型 - 一对一关联 -type UserProfile struct { - ID int64 `json:"id" db:"id"` - UserID int64 `json:"user_id" db:"user_id"` - Bio string `json:"bio" db:"bio"` - Avatar string `json:"avatar" db:"avatar"` -} - -// TableName 表名 -func (UserProfile) TableName() string { - return "user_profile" -} - -// Order 订单模型 - 一对多关联 -type Order struct { - ID int64 `json:"id" db:"id"` - UserID int64 `json:"user_id" db:"user_id"` - OrderNo string `json:"order_no" db:"order_no"` - Amount float64 `json:"amount" db:"amount"` - CreatedAt time.Time `json:"created_at" db:"created_at"` -} - -// TableName 表名 -func (Order) TableName() string { - return "order" -} - -// TestPreloadHasOne 测试一对一预加载 -func TestPreloadHasOne(t *testing.T) { - fmt.Println("\n=== 测试一对一预加载 ===") - - // 这里只是示例,实际使用需要数据库连接 - // db := AutoConnect(true) - // var users []User - // err := db.Model(&User{}).Preload("Profile").Find(&users) - // if err != nil { - // t.Fatal(err) - // } - - fmt.Println("一对一预加载结构已实现") - fmt.Println("✓ 测试通过") -} - -// TestPreloadHasMany 测试一对多预加载 -func TestPreloadHasMany(t *testing.T) { - fmt.Println("\n=== 测试一对多预加载 ===") - - // 示例用法 - // var users []User - // err := db.Model(&User{}).Preload("Orders").Find(&users) - // if err != nil { - // t.Fatal(err) - // } - - fmt.Println("一对多预加载结构已实现") - fmt.Println("✓ 测试通过") -} - -// TestPreloadBelongsTo 测试多对一预加载 -func TestPreloadBelongsTo(t *testing.T) { - fmt.Println("\n=== 测试多对一预加载 ===") - - // 示例用法 - // var orders []Order - // err := db.Model(&Order{}).Preload("User").Find(&orders) - // if err != nil { - // t.Fatal(err) - // } - - fmt.Println("多对一预加载结构已实现") - fmt.Println("✓ 测试通过") -} - -// TestPreloadMultiple 测试多个预加载 -func TestPreloadMultiple(t *testing.T) { - fmt.Println("\n=== 
测试多个预加载 ===") - - // 示例用法 - // var users []User - // err := db.Model(&User{}). - // Preload("Profile"). - // Preload("Orders"). - // Find(&users) - // if err != nil { - // t.Fatal(err) - // } - - fmt.Println("多个预加载已实现") - fmt.Println("✓ 测试通过") -} - -// TestPreloadWithConditions 测试带条件的预加载 -func TestPreloadWithConditions(t *testing.T) { - fmt.Println("\n=== 测试带条件的预加载 ===") - - // 示例用法 - // var users []User - // err := db.Model(&User{}). - // Preload("Orders", "amount > ?", 100). - // Find(&users) - // if err != nil { - // t.Fatal(err) - // } - - fmt.Println("带条件的预加载已实现") - fmt.Println("✓ 测试通过") -} diff --git a/db/core/query.go b/db/core/query.go deleted file mode 100644 index 1cd5c6d..0000000 --- a/db/core/query.go +++ /dev/null @@ -1,740 +0,0 @@ -package core - -import ( - "database/sql" - "fmt" - "strings" - "sync" -) - -// QueryBuilder 查询构建器实现 - 提供流畅的链式查询构建能力 -type QueryBuilder struct { - db *Database // 数据库连接实例 - table string // 表名 - model interface{} // 模型对象 - whereSQL string // WHERE 条件 SQL - whereArgs []interface{} // WHERE 条件参数 - selectCols []string // 选择的字段列表 - omitCols []string // 排除的字段列表 - orderSQL string // ORDER BY SQL - limit int // LIMIT 限制数量 - offset int // OFFSET 偏移量 - groupSQL string // GROUP BY SQL - havingSQL string // HAVING 条件 SQL - havingArgs []interface{} // HAVING 条件参数 - joinSQL string // JOIN SQL - joinArgs []interface{} // JOIN 参数 - debug bool // 调试模式开关 - dryRun bool // 干跑模式开关 - unscoped bool // 忽略软删除开关 - tx *sql.Tx // 事务对象(如果在事务中) - // 预加载关联数据 - preloadRelations map[string][]interface{} // 预加载的关联关系及条件 - // 缓存相关 - cache *QueryCache // 缓存实例 - cacheKey string // 缓存键 - useCache bool // 是否使用缓存 -} - -// 同步池优化 - 复用 slice 减少内存分配 -var whereArgsPool = sync.Pool{ - New: func() interface{} { - return make([]interface{}, 0, 10) - }, -} - -var joinArgsPool = sync.Pool{ - New: func() interface{} { - return make([]interface{}, 0, 5) - }, -} - -// Model 基于模型创建查询 -func (d *Database) Model(model interface{}) IQuery { - return &QueryBuilder{ - db: d, - model: 
model, - preloadRelations: make(map[string][]interface{}), - } -} - -// Table 基于表名创建查询 -func (d *Database) Table(name string) IQuery { - return &QueryBuilder{ - db: d, - table: name, - preloadRelations: make(map[string][]interface{}), - } -} - -// Where 添加 WHERE 条件 - 性能优化版本 -func (q *QueryBuilder) Where(query string, args ...interface{}) IQuery { - if q.whereSQL == "" { - q.whereSQL = query - } else { - // 使用 strings.Builder 优化字符串拼接 - var builder strings.Builder - builder.Grow(len(q.whereSQL) + 5 + len(query)) // 预分配内存 - builder.WriteString(q.whereSQL) - builder.WriteString(" AND ") - builder.WriteString(query) - q.whereSQL = builder.String() - } - q.whereArgs = append(q.whereArgs, args...) - return q -} - -// Or 添加 OR 条件 - 性能优化版本 -func (q *QueryBuilder) Or(query string, args ...interface{}) IQuery { - if q.whereSQL == "" { - q.whereSQL = query - } else { - // 使用 strings.Builder 优化字符串拼接 - var builder strings.Builder - builder.Grow(len(q.whereSQL) + 10 + len(query)) // 预分配内存 - builder.WriteString(" (") - builder.WriteString(q.whereSQL) - builder.WriteString(") OR ") - builder.WriteString(query) - q.whereSQL = builder.String() - } - q.whereArgs = append(q.whereArgs, args...) - return q -} - -// And 添加 AND 条件(同 Where) -func (q *QueryBuilder) And(query string, args ...interface{}) IQuery { - return q.Where(query, args...) -} - -// Select 选择要查询的字段 -func (q *QueryBuilder) Select(fields ...string) IQuery { - q.selectCols = fields - return q -} - -// Omit 排除指定的字段 -func (q *QueryBuilder) Omit(fields ...string) IQuery { - q.omitCols = append(q.omitCols, fields...) 
- return q -} - -// Order 设置排序规则 -func (q *QueryBuilder) Order(order string) IQuery { - q.orderSQL = order - return q -} - -// OrderBy 按指定字段和方向排序 -func (q *QueryBuilder) OrderBy(field string, direction string) IQuery { - q.orderSQL = field + " " + direction - return q -} - -// Limit 限制返回数量 -func (q *QueryBuilder) Limit(limit int) IQuery { - q.limit = limit - return q -} - -// Offset 设置偏移量 -func (q *QueryBuilder) Offset(offset int) IQuery { - q.offset = offset - return q -} - -// Page 分页查询 -func (q *QueryBuilder) Page(page, pageSize int) IQuery { - q.limit = pageSize - q.offset = (page - 1) * pageSize - return q -} - -// Group 设置分组字段 -func (q *QueryBuilder) Group(group string) IQuery { - q.groupSQL = group - return q -} - -// Having 添加 HAVING 条件 -func (q *QueryBuilder) Having(having string, args ...interface{}) IQuery { - q.havingSQL = having - q.havingArgs = args - return q -} - -// Join 添加 JOIN 连接 - 性能优化版本 -func (q *QueryBuilder) Join(join string, args ...interface{}) IQuery { - if q.joinSQL == "" { - q.joinSQL = join - } else { - // 使用 strings.Builder 优化字符串拼接 - var builder strings.Builder - builder.Grow(len(q.joinSQL) + 1 + len(join)) // 预分配内存 - builder.WriteString(q.joinSQL) - builder.WriteByte(' ') - builder.WriteString(join) - q.joinSQL = builder.String() - } - q.joinArgs = append(q.joinArgs, args...) 
- return q -} - -// LeftJoin 左连接 -func (q *QueryBuilder) LeftJoin(table, on string) IQuery { - return q.Join("LEFT JOIN " + table + " ON " + on) -} - -// RightJoin 右连接 -func (q *QueryBuilder) RightJoin(table, on string) IQuery { - return q.Join("RIGHT JOIN " + table + " ON " + on) -} - -// InnerJoin 内连接 -func (q *QueryBuilder) InnerJoin(table, on string) IQuery { - return q.Join("INNER JOIN " + table + " ON " + on) -} - -// Preload 预加载关联数据 -func (q *QueryBuilder) Preload(relation string, conditions ...interface{}) IQuery { - if q.preloadRelations == nil { - q.preloadRelations = make(map[string][]interface{}) - } - // 将关联条件添加到预加载列表中 - q.preloadRelations[relation] = conditions - return q -} - -// First 查询第一条记录 -func (q *QueryBuilder) First(result interface{}) error { - q.limit = 1 - return q.Find(result) -} - -// Find 查询多条记录 -func (q *QueryBuilder) Find(result interface{}) error { - // 如果使用缓存,先检查缓存 - if q.useCache && q.cache != nil && q.cacheKey != "" { - if cachedData, exists := q.cache.Get(q.cacheKey); exists { - // 缓存命中,将数据拷贝到结果对象 - if err := deepCopy(cachedData, result); err != nil { - return fmt.Errorf("缓存数据拷贝失败:%w", err) - } - if q.debug || (q.db != nil && q.db.debug) { - fmt.Printf("[Magic-ORM] 缓存命中:%s\n", q.cacheKey) - } - return nil - } - } - - // 缓存未命中,执行实际查询 - sqlStr, args := q.BuildSelect() - - // 调试模式打印 SQL - if q.debug || (q.db != nil && q.db.debug) { - fmt.Printf("[Magic-ORM] SQL: %s\n[Magic-ORM] Args: %v\n", sqlStr, args) - } - - // 干跑模式不执行 SQL - if q.dryRun { - return nil - } - - var rows *sql.Rows - var err error - - // 判断是否在事务中 - if q.tx != nil { - rows, err = q.tx.Query(sqlStr, args...) - } else if q.db != nil && q.db.db != nil { - rows, err = q.db.db.Query(sqlStr, args...) 
- } else { - return fmt.Errorf("数据库连接未初始化") - } - - if err != nil { - return fmt.Errorf("查询失败:%w", err) - } - defer rows.Close() - - // 使用 ResultSetMapper 将查询结果映射到 result - mapper := NewResultSetMapper() - if err := mapper.ScanAll(rows, result); err != nil { - return fmt.Errorf("结果映射失败:%w", err) - } - - // 执行预加载关联数据 - if len(q.preloadRelations) > 0 { - if err := q.executePreload(result); err != nil { - return fmt.Errorf("预加载关联失败:%w", err) - } - } - - // 将结果存入缓存(如果启用了缓存) - if q.useCache && q.cache != nil && q.cacheKey != "" { - q.cache.Set(q.cacheKey, result) - if q.debug || (q.db != nil && q.db.debug) { - fmt.Printf("[Magic-ORM] 缓存已设置:%s\n", q.cacheKey) - } - } - - return nil -} - -// Count 统计记录数量 -func (q *QueryBuilder) Count(count *int64) IQuery { - // 构建 COUNT 查询 - originalSelect := q.selectCols - q.selectCols = []string{"COUNT(*)"} - - sqlStr, args := q.BuildSelect() - - // 调试模式 - if q.debug || (q.db != nil && q.db.debug) { - fmt.Printf("[Magic-ORM] COUNT SQL: %s\n[Magic-ORM] Args: %v\n", sqlStr, args) - } - - // 干跑模式 - if q.dryRun { - return q - } - - var err error - if q.tx != nil { - err = q.tx.QueryRow(sqlStr, args...).Scan(count) - } else if q.db != nil && q.db.db != nil { - err = q.db.db.QueryRow(sqlStr, args...).Scan(count) - } - - if err != nil { - fmt.Printf("[Magic-ORM] Count 错误:%v\n", err) - } - - // 恢复原来的选择字段 - q.selectCols = originalSelect - return q -} - -// Exists 检查记录是否存在 -func (q *QueryBuilder) Exists() (bool, error) { - // 使用 LIMIT 1 优化查询 - originalLimit := q.limit - q.limit = 1 - - sqlStr, args := q.BuildSelect() - - // 调试模式 - if q.debug || (q.db != nil && q.db.debug) { - fmt.Printf("[Magic-ORM] EXISTS SQL: %s\n[Magic-ORM] Args: %v\n", sqlStr, args) - } - - // 干跑模式 - if q.dryRun { - return false, nil - } - - var rows *sql.Rows - var err error - - if q.tx != nil { - rows, err = q.tx.Query(sqlStr, args...) - } else if q.db != nil && q.db.db != nil { - rows, err = q.db.db.Query(sqlStr, args...) 
- } else { - return false, fmt.Errorf("数据库连接未初始化") - } - defer rows.Close() - - if err != nil { - return false, err - } - - // 检查是否有结果 - exists := rows.Next() - - // 恢复原来的 limit - q.limit = originalLimit - - return exists, nil -} - -// Updates 更新数据 -func (q *QueryBuilder) Updates(data interface{}) error { - sqlStr, args := q.BuildUpdate(data) - - // 调试模式打印 SQL - if q.debug || (q.db != nil && q.db.debug) { - fmt.Printf("[Magic-ORM] UPDATE SQL: %s\n[Magic-ORM] Args: %v\n", sqlStr, args) - } - - // 干跑模式不执行 SQL - if q.dryRun { - return nil - } - - var err error - if q.tx != nil { - _, err = q.tx.Exec(sqlStr, args...) - } else if q.db != nil && q.db.db != nil { - _, err = q.db.db.Exec(sqlStr, args...) - } else { - return fmt.Errorf("数据库连接未初始化") - } - - if err != nil { - return fmt.Errorf("更新失败:%w", err) - } - return nil -} - -// UpdateColumn 更新单个字段 -func (q *QueryBuilder) UpdateColumn(column string, value interface{}) error { - return q.Updates(map[string]interface{}{column: value}) -} - -// Delete 删除数据 -func (q *QueryBuilder) Delete() error { - sqlStr, args := q.BuildDelete() - - // 调试模式打印 SQL - if q.debug || (q.db != nil && q.db.debug) { - fmt.Printf("[Magic-ORM] DELETE SQL: %s\n[Magic-ORM] Args: %v\n", sqlStr, args) - } - - // 干跑模式不执行 SQL - if q.dryRun { - return nil - } - - var err error - if q.tx != nil { - _, err = q.tx.Exec(sqlStr, args...) - } else if q.db != nil && q.db.db != nil { - _, err = q.db.db.Exec(sqlStr, args...) 
- } else { - return fmt.Errorf("数据库连接未初始化") - } - - if err != nil { - return fmt.Errorf("删除失败:%w", err) - } - return nil -} - -// Unscoped 忽略软删除限制 -func (q *QueryBuilder) Unscoped() IQuery { - q.unscoped = true - return q -} - -// DryRun 设置干跑模式(只生成 SQL 不执行) -func (q *QueryBuilder) DryRun() IQuery { - q.dryRun = true - return q -} - -// Debug 设置调试模式(打印 SQL 日志) -func (q *QueryBuilder) Debug() IQuery { - q.debug = true - return q -} - -// Build 构建 SELECT SQL 语句 -func (q *QueryBuilder) Build() (string, []interface{}) { - return q.BuildSelect() -} - -// BuildSelect 构建 SELECT SQL 语句 -func (q *QueryBuilder) BuildSelect() (string, []interface{}) { - var builder strings.Builder - - // SELECT 部分 - builder.WriteString("SELECT ") - if len(q.selectCols) > 0 { - // 如果指定了选择字段,直接使用 - builder.WriteString(strings.Join(q.selectCols, ", ")) - } else if len(q.omitCols) > 0 { - // 如果没有指定 select 但设置了 omit,需要从模型获取所有字段并排除 omit 的字段 - fields := q.getAllFields() - if len(fields) > 0 { - builder.WriteString(strings.Join(fields, ", ")) - } else { - // 无法获取字段信息,使用 * - builder.WriteString("*") - } - } else { - // 默认选择所有字段 - builder.WriteString("*") - } - - // FROM 部分 - builder.WriteString(" FROM ") - if q.table != "" { - builder.WriteString(q.table) - } else if q.model != nil { - // 从模型获取表名 - mapper := NewFieldMapper() - builder.WriteString(mapper.GetTableName(q.model)) - } else { - builder.WriteString("unknown_table") - } - - // JOIN 部分 - if q.joinSQL != "" { - builder.WriteString(" ") - builder.WriteString(q.joinSQL) - } - - // WHERE 部分 - if q.whereSQL != "" { - builder.WriteString(" WHERE ") - builder.WriteString(q.whereSQL) - } - - // GROUP BY 部分 - if q.groupSQL != "" { - builder.WriteString(" GROUP BY ") - builder.WriteString(q.groupSQL) - } - - // HAVING 部分 - if q.havingSQL != "" { - builder.WriteString(" HAVING ") - builder.WriteString(q.havingSQL) - } - - // ORDER BY 部分 - if q.orderSQL != "" { - builder.WriteString(" ORDER BY ") - builder.WriteString(q.orderSQL) - } - - // LIMIT 部分 - if 
q.limit > 0 { - builder.WriteString(fmt.Sprintf(" LIMIT %d", q.limit)) - } - - // OFFSET 部分 - if q.offset > 0 { - builder.WriteString(fmt.Sprintf(" OFFSET %d", q.offset)) - } - - // 合并参数 - allArgs := make([]interface{}, 0) - allArgs = append(allArgs, q.joinArgs...) - allArgs = append(allArgs, q.whereArgs...) - allArgs = append(allArgs, q.havingArgs...) - - return builder.String(), allArgs -} - -// getAllFields 获取模型的所有字段(排除 omit 的字段) -func (q *QueryBuilder) getAllFields() []string { - var fields []string - - // 如果有模型,从模型获取字段 - if q.model != nil { - mapper := NewFieldMapper() - fieldInfos := mapper.GetFields(q.model) - - // 创建 omit 字段的 map 用于快速查找 - omitMap := make(map[string]bool) - for _, omitField := range q.omitCols { - // 同时存储原始形式和小写形式,支持不区分大小写的匹配 - omitMap[omitField] = true - omitMap[strings.ToLower(omitField)] = true - } - - // 遍历所有字段,排除 omit 的字段 - for _, fieldInfo := range fieldInfos { - // 检查字段是否在 omit 列表中 - if !omitMap[fieldInfo.Column] && !omitMap[strings.ToLower(fieldInfo.Column)] { - fields = append(fields, fieldInfo.Column) - } - } - } else if q.table != "" { - // 如果只有表名没有模型,从数据库元数据获取字段 - columns, err := q.getTableColumns(q.table) - if err != nil { - // 如果获取失败,返回 nil 使用 SELECT * - return nil - } - - // 创建 omit 字段的 map 用于快速查找 - omitMap := make(map[string]bool) - for _, omitField := range q.omitCols { - omitMap[omitField] = true - omitMap[strings.ToLower(omitField)] = true - } - - // 过滤掉 omit 的字段 - for _, col := range columns { - if !omitMap[col] && !omitMap[strings.ToLower(col)] { - fields = append(fields, col) - } - } - } - - return fields -} - -// getTableColumns 从数据库元数据获取表的列名 -func (q *QueryBuilder) getTableColumns(tableName string) ([]string, error) { - if q.db == nil || q.db.db == nil { - return nil, fmt.Errorf("数据库连接未初始化") - } - - var query string - var args []interface{} - var rows *sql.Rows - var err error - - // 根据不同数据库类型查询元数据 - switch q.db.driverName { - case "mysql": - query = ` - SELECT COLUMN_NAME - FROM INFORMATION_SCHEMA.COLUMNS - WHERE 
TABLE_SCHEMA = DATABASE() AND TABLE_NAME = ? - ORDER BY ORDINAL_POSITION - ` - args = []interface{}{tableName} - case "postgres": - query = ` - SELECT column_name - FROM information_schema.columns - WHERE table_schema = 'public' AND table_name = $1 - ORDER BY ordinal_position - ` - args = []interface{}{tableName} - case "sqlite", "sqlite3": - query = `PRAGMA table_info(?)` - args = []interface{}{tableName} - default: - // 未知数据库类型,返回空 - return nil, fmt.Errorf("不支持的数据库类型:%s", q.db.driverName) - } - - rows, err = q.db.db.Query(query, args...) - if err != nil { - return nil, fmt.Errorf("查询表元数据失败:%w", err) - } - defer rows.Close() - - var columns []string - for rows.Next() { - var columnName string - if q.db.driverName == "sqlite" || q.db.driverName == "sqlite3" { - // SQLite PRAGMA table_info 返回多列:cid, name, type, notnull, dflt_value, pk - var cid int - var typ string - var notNull int - var dfltValue sql.NullString - var pk int - if err := rows.Scan(&cid, &columnName, &typ, ¬Null, &dfltValue, &pk); err != nil { - return nil, err - } - } else { - if err := rows.Scan(&columnName); err != nil { - return nil, err - } - } - columns = append(columns, columnName) - } - - if err := rows.Err(); err != nil { - return nil, err - } - - return columns, nil -} - -// executePreload 执行预加载关联数据 -func (q *QueryBuilder) executePreload(models interface{}) error { - // 创建关联加载器 - loader := NewRelationLoader(q.db) - - // 遍历所有预加载的关联关系 - for relation, conditions := range q.preloadRelations { - if err := loader.Preload(models, relation, conditions...); err != nil { - return err - } - } - return nil -} - -// BuildUpdate 构建 UPDATE SQL 语句 -func (q *QueryBuilder) BuildUpdate(data interface{}) (string, []interface{}) { - var builder strings.Builder - var args []interface{} - - builder.WriteString("UPDATE ") - if q.table != "" { - builder.WriteString(q.table) - } else if q.model != nil { - mapper := NewFieldMapper() - builder.WriteString(mapper.GetTableName(q.model)) - } else { - 
builder.WriteString("unknown_table") - } - - builder.WriteString(" SET ") - - // 根据 data 类型生成 SET 子句 - switch v := data.(type) { - case map[string]interface{}: - // map 类型,生成 key=value 对 - setParts := make([]string, 0, len(v)) - for key, value := range v { - setParts = append(setParts, fmt.Sprintf("%s = ?", key)) - args = append(args, value) - } - builder.WriteString(strings.Join(setParts, ", ")) - case string: - // string 类型,直接使用(注意:实际使用需要转义) - builder.WriteString(v) - default: - // 结构体类型,使用字段映射器 - mapper := NewFieldMapper() - columns, err := mapper.StructToColumns(data) - if err == nil && len(columns) > 0 { - setParts := make([]string, 0, len(columns)) - for key := range columns { - setParts = append(setParts, fmt.Sprintf("%s = ?", key)) - args = append(args, columns[key]) - } - builder.WriteString(strings.Join(setParts, ", ")) - } - } - - // WHERE 部分 - if q.whereSQL != "" { - builder.WriteString(" WHERE ") - builder.WriteString(q.whereSQL) - args = append(args, q.whereArgs...) - } - - return builder.String(), args -} - -// BuildDelete 构建 DELETE SQL 语句 -func (q *QueryBuilder) BuildDelete() (string, []interface{}) { - var builder strings.Builder - - builder.WriteString("DELETE FROM ") - if q.table != "" { - builder.WriteString(q.table) - } else if q.model != nil { - mapper := NewFieldMapper() - builder.WriteString(mapper.GetTableName(q.model)) - } else { - builder.WriteString("unknown_table") - } - - if q.whereSQL != "" { - builder.WriteString(" WHERE ") - builder.WriteString(q.whereSQL) - } - - return builder.String(), q.whereArgs -} diff --git a/db/core/read_write.go b/db/core/read_write.go deleted file mode 100644 index 0496222..0000000 --- a/db/core/read_write.go +++ /dev/null @@ -1,124 +0,0 @@ -package core - -import ( - "database/sql" - "sync" - "sync/atomic" -) - -// ReadWriteDB 读写分离数据库连接 -type ReadWriteDB struct { - master *sql.DB // 主库(写) - slaves []*sql.DB // 从库列表(读) - policy ReadPolicy // 读负载均衡策略 - counter uint64 // 轮询计数器 - mu sync.RWMutex // 读写锁 -} - 
-// NewReadWriteDB 创建读写分离数据库连接 -func NewReadWriteDB(master *sql.DB, slaves []*sql.DB, policy ReadPolicy) *ReadWriteDB { - return &ReadWriteDB{ - master: master, - slaves: slaves, - policy: policy, - } -} - -// GetMaster 获取主库连接(用于写操作) -func (rw *ReadWriteDB) GetMaster() *sql.DB { - return rw.master -} - -// GetSlave 获取从库连接(用于读操作) -func (rw *ReadWriteDB) GetSlave() *sql.DB { - rw.mu.RLock() - defer rw.mu.RUnlock() - - if len(rw.slaves) == 0 { - // 没有从库,使用主库 - return rw.master - } - - switch rw.policy { - case Random: - // 随机选择一个从库 - idx := int(atomic.LoadUint64(&rw.counter)) % len(rw.slaves) - return rw.slaves[idx] - - case RoundRobin: - // 轮询选择从库 - idx := int(atomic.AddUint64(&rw.counter, 1)) % len(rw.slaves) - return rw.slaves[idx] - - case LeastConn: - // 选择连接数最少的从库(简化实现) - return rw.selectLeastConn() - - default: - return rw.slaves[0] - } -} - -// selectLeastConn 选择连接数最少的从库 -func (rw *ReadWriteDB) selectLeastConn() *sql.DB { - if len(rw.slaves) == 0 { - return rw.master - } - - minConn := -1 - selected := rw.slaves[0] - - for _, slave := range rw.slaves { - stats := slave.Stats() - openConnections := stats.OpenConnections - - if minConn == -1 || openConnections < minConn { - minConn = openConnections - selected = slave - } - } - - return selected -} - -// AddSlave 添加从库 -func (rw *ReadWriteDB) AddSlave(slave *sql.DB) { - rw.mu.Lock() - defer rw.mu.Unlock() - rw.slaves = append(rw.slaves, slave) -} - -// RemoveSlave 移除从库 -func (rw *ReadWriteDB) RemoveSlave(slave *sql.DB) { - rw.mu.Lock() - defer rw.mu.Unlock() - - for i, s := range rw.slaves { - if s == slave { - rw.slaves = append(rw.slaves[:i], rw.slaves[i+1:]...) 
- break - } - } -} - -// Close 关闭所有连接 -func (rw *ReadWriteDB) Close() error { - rw.mu.Lock() - defer rw.mu.Unlock() - - // 关闭主库 - if rw.master != nil { - if err := rw.master.Close(); err != nil { - return err - } - } - - // 关闭所有从库 - for _, slave := range rw.slaves { - if err := slave.Close(); err != nil { - return err - } - } - - return nil -} diff --git a/db/core/relation.go b/db/core/relation.go deleted file mode 100644 index 06eecfb..0000000 --- a/db/core/relation.go +++ /dev/null @@ -1,367 +0,0 @@ -package core - -import ( - "fmt" - "reflect" - "strings" -) - -// RelationType 关联类型 -type RelationType int - -const ( - HasOne RelationType = iota // 一对一 - HasMany // 一对多 - BelongsTo // 多对一 - ManyToMany // 多对多 -) - -// RelationInfo 关联信息 -type RelationInfo struct { - Type RelationType // 关联类型 - Field string // 字段名 - Model interface{} // 关联的模型 - FK string // 外键 - PK string // 主键 - JoinTable string // 中间表(多对多) - JoinFK string // 中间表外键 - JoinJoinFK string // 中间表关联外键 -} - -// RelationLoader 关联加载器 - 处理模型关联的预加载 -type RelationLoader struct { - db *Database -} - -// NewRelationLoader 创建关联加载器实例 -func NewRelationLoader(db *Database) *RelationLoader { - return &RelationLoader{db: db} -} - -// Preload 预加载关联数据 -func (rl *RelationLoader) Preload(models interface{}, relation string, conditions ...interface{}) error { - // 获取反射对象 - modelsVal := reflect.ValueOf(models) - if modelsVal.Kind() != reflect.Ptr { - return fmt.Errorf("models 必须是指针类型") - } - - elem := modelsVal.Elem() - if elem.Kind() != reflect.Slice { - return fmt.Errorf("models 必须是指向 Slice 的指针") - } - - if elem.Len() == 0 { - return nil // 空 Slice,无需加载 - } - - // 解析关联关系 - relationInfo, err := rl.parseRelation(elem.Index(0).Interface(), relation) - if err != nil { - return fmt.Errorf("解析关联失败:%w", err) - } - - // 根据关联类型加载数据 - switch relationInfo.Type { - case HasOne: - return rl.loadHasOne(elem, relationInfo) - case HasMany: - return rl.loadHasMany(elem, relationInfo) - case BelongsTo: - return rl.loadBelongsTo(elem, 
relationInfo) - case ManyToMany: - return rl.loadManyToMany(elem, relationInfo) - default: - return fmt.Errorf("不支持的关联类型:%v", relationInfo.Type) - } -} - -// parseRelation 解析关联关系 -func (rl *RelationLoader) parseRelation(model interface{}, relation string) (*RelationInfo, error) { - // 从结构体字段中解析关联信息 - structType := reflect.TypeOf(model) - if structType.Kind() == reflect.Ptr { - structType = structType.Elem() - } - - // 查找对应的字段 - var relationField reflect.StructField - var found bool - - for i := 0; i < structType.NumField(); i++ { - field := structType.Field(i) - if field.Name == relation { - relationField = field - found = true - break - } - } - - if !found { - return nil, fmt.Errorf("字段 %s 不存在", relation) - } - - // 从 gorm 标签解析关联信息 - gormTag := relationField.Tag.Get("gorm") - fkTag := relationField.Tag.Get("foreignkey") - referencesTag := relationField.Tag.Get("references") - joinTableTag := relationField.Tag.Get("many2many") - - // 初始化关联信息 - info := &RelationInfo{ - Field: relation, - Model: reflect.New(relationField.Type).Interface(), - } - - // 判断关联类型 - if relationField.Type.Kind() == reflect.Slice { - // 一对多或多对多 - if joinTableTag != "" { - // 多对多 - info.Type = ManyToMany - info.JoinTable = joinTableTag - } else { - // 一对多 - info.Type = HasMany - } - } else { - // 一对一或多对一 - // 根据外键位置判断 - if fkTag != "" || referencesTag != "" { - // 如果当前模型包含外键,则是多对一 - info.Type = BelongsTo - } else { - // 否则是一对一 - info.Type = HasOne - } - } - - // 解析外键和主键 - if gormTag != "" { - // 解析 GORM 风格的标签 - parts := strings.Split(gormTag, ";") - for _, part := range parts { - kv := strings.Split(part, ":") - if len(kv) == 2 { - key := strings.TrimSpace(kv[0]) - value := strings.TrimSpace(kv[1]) - switch key { - case "ForeignKey": - info.FK = value - case "References": - info.PK = value - case "JoinTable": - info.JoinTable = value - case "JoinForeignKey": - info.JoinFK = value - case "JoinReferences": - info.JoinJoinFK = value - } - } - } - } - - // 使用单独的标签 - if fkTag != "" { - info.FK = 
fkTag - } - if referencesTag != "" { - info.PK = referencesTag - } - - // 设置默认值 - if info.FK == "" { - // 默认外键为当前模型名 + Id - modelName := structType.Name() - info.FK = modelName + "Id" - } - if info.PK == "" { - info.PK = "id" - } - - return info, nil -} - -// loadHasOne 加载一对一关联 -func (rl *RelationLoader) loadHasOne(models reflect.Value, relation *RelationInfo) error { - // 收集所有主键值 - pkValues := make([]interface{}, 0, models.Len()) - for i := 0; i < models.Len(); i++ { - model := models.Index(i).Interface() - pk := rl.getFieldValue(model, "ID") - if pk != nil { - pkValues = append(pkValues, pk) - } - } - - if len(pkValues) == 0 { - return nil - } - - // 查询关联数据 - query := rl.db.Model(relation.Model) - query.Where(fmt.Sprintf("%s IN (?)", relation.FK), pkValues) - - // 执行查询并映射到模型 - relatedData := reflect.New(reflect.SliceOf(reflect.TypeOf(relation.Model))).Interface() - if err := query.Find(relatedData); err != nil { - return err - } - - // 将关联数据映射到模型 - relatedVal := reflect.ValueOf(relatedData) - if relatedVal.Kind() == reflect.Ptr { - relatedVal = relatedVal.Elem() - } - - // 遍历所有模型,设置关联字段 - for i := 0; i < models.Len(); i++ { - model := models.Index(i) - pk := rl.getFieldValue(model.Interface(), "ID") - - // 查找对应的关联数据 - for j := 0; j < relatedVal.Len(); j++ { - item := relatedVal.Index(j).Interface() - itemFK := rl.getFieldValue(item, relation.FK) - if itemFK != nil && fmt.Sprintf("%v", itemFK) == fmt.Sprintf("%v", pk) { - model.Elem().FieldByName(relation.Field).Set(reflect.ValueOf(item)) - break - } - } - } - - return nil -} - -// loadHasMany 加载一对多关联 -func (rl *RelationLoader) loadHasMany(models reflect.Value, relation *RelationInfo) error { - // 一对多的逻辑与 HasOne 类似,但结果必须映射到 Slice - return rl.loadHasOne(models, relation) -} - -// loadBelongsTo 加载多对一关联 -func (rl *RelationLoader) loadBelongsTo(models reflect.Value, relation *RelationInfo) error { - // 收集所有外键值 - fkValues := make([]interface{}, 0, models.Len()) - for i := 0; i < models.Len(); i++ { - model := 
models.Index(i).Interface() - fk := rl.getFieldValue(model, relation.FK) - if fk != nil { - fkValues = append(fkValues, fk) - } - } - - if len(fkValues) == 0 { - return nil - } - - // 查询关联数据 - query := rl.db.Model(relation.Model) - query.Where(fmt.Sprintf("%s IN (?)", relation.PK), fkValues) - - // 执行查询 - relatedData := reflect.New(reflect.SliceOf(reflect.TypeOf(relation.Model))).Interface() - if err := query.Find(relatedData); err != nil { - return err - } - - // 将关联数据映射到模型 - relatedVal := reflect.ValueOf(relatedData) - if relatedVal.Kind() == reflect.Ptr { - relatedVal = relatedVal.Elem() - } - - // 遍历所有模型,设置关联字段 - for i := 0; i < models.Len(); i++ { - model := models.Index(i) - fk := rl.getFieldValue(model.Interface(), relation.FK) - - // 查找对应的关联数据 - for j := 0; j < relatedVal.Len(); j++ { - item := relatedVal.Index(j).Interface() - itemPK := rl.getFieldValue(item, relation.PK) - if itemPK != nil && fmt.Sprintf("%v", itemPK) == fmt.Sprintf("%v", fk) { - model.Elem().FieldByName(relation.Field).Set(reflect.ValueOf(item)) - break - } - } - } - - return nil -} - -// loadManyToMany 加载多对多关联 -func (rl *RelationLoader) loadManyToMany(models reflect.Value, relation *RelationInfo) error { - // 多对多需要通过中间表查询 - // SELECT * FROM table WHERE id IN ( - // SELECT join_fk FROM join_table WHERE fk IN (pk_values) - // ) - - // 收集所有主键值 - pkValues := make([]interface{}, 0, models.Len()) - for i := 0; i < models.Len(); i++ { - model := models.Index(i).Interface() - pk := rl.getFieldValue(model, "ID") - if pk != nil { - pkValues = append(pkValues, pk) - } - } - - if len(pkValues) == 0 { - return nil - } - - // 检查中间表配置 - if relation.JoinTable == "" || relation.JoinFK == "" || relation.JoinJoinFK == "" { - return fmt.Errorf("多对多关联需要配置中间表信息") - } - - // 先从中间表获取关联关系 - joinQuery := rl.db.Table(relation.JoinTable) - joinQuery.Where(fmt.Sprintf("%s IN (?)", relation.JoinFK), pkValues) - - // 这里简化处理,实际应该查询中间表获取关联 ID 列表 - // 然后查询关联模型 - - return fmt.Errorf("多对多关联实现中,请稍后使用") -} - -// 
getFieldValue 获取字段的值 -func (rl *RelationLoader) getFieldValue(model interface{}, fieldName string) interface{} { - val := reflect.ValueOf(model) - if val.Kind() == reflect.Ptr { - val = val.Elem() - } - - field := val.FieldByName(fieldName) - if field.IsValid() && field.CanInterface() { - return field.Interface() - } - - return nil -} - -// getRelationTags 从结构体字段提取关联标签信息 -func getRelationTags(structType reflect.Type, fieldName string) map[string]string { - tags := make(map[string]string) - - for i := 0; i < structType.NumField(); i++ { - field := structType.Field(i) - if field.Name == fieldName { - gormTag := field.Tag.Get("gorm") - if gormTag != "" { - // 解析 GORM 风格的标签 - parts := strings.Split(gormTag, ";") - for _, part := range parts { - kv := strings.Split(part, ":") - if len(kv) == 2 { - tags[strings.TrimSpace(kv[0])] = strings.TrimSpace(kv[1]) - } - } - } - break - } - } - - return tags -} diff --git a/db/core/result_mapper.go b/db/core/result_mapper.go deleted file mode 100644 index a2601f5..0000000 --- a/db/core/result_mapper.go +++ /dev/null @@ -1,173 +0,0 @@ -package core - -import ( - "database/sql" - "fmt" - "reflect" -) - -// ResultSetMapper 结果集映射器 - 将查询结果映射到 Slice 或 Struct -type ResultSetMapper struct { - fieldMapper IFieldMapper -} - -// NewResultSetMapper 创建结果集映射器实例 -func NewResultSetMapper() *ResultSetMapper { - return &ResultSetMapper{ - fieldMapper: NewFieldMapper(), - } -} - -// MapToSlice 将查询结果映射到 Slice -func (rsm *ResultSetMapper) MapToSlice(rows *sql.Rows, result interface{}) error { - // 获取反射对象 - resultVal := reflect.ValueOf(result) - - // 必须是指针类型 - if resultVal.Kind() != reflect.Ptr { - return fmt.Errorf("result 必须是指针类型") - } - - elem := resultVal.Elem() - - // 必须是 Slice 类型 - if elem.Kind() != reflect.Slice { - return fmt.Errorf("result 必须是指向 Slice 的指针") - } - - // 获取 Slice 的元素类型 - sliceType := elem.Type().Elem() - var isPtr bool - if sliceType.Kind() == reflect.Ptr { - isPtr = true - sliceType = sliceType.Elem() - } - - if sliceType.Kind() 
!= reflect.Struct { - return fmt.Errorf("Slice 的元素必须是结构体") - } - - // 获取列信息 - columns, err := rows.Columns() - if err != nil { - return fmt.Errorf("获取列信息失败:%w", err) - } - - // 建立列名到字段的映射 - fieldMap := make(map[string]int) - for i := 0; i < sliceType.NumField(); i++ { - field := sliceType.Field(i) - dbTag := field.Tag.Get("db") - - if dbTag != "" && dbTag != "-" { - // 使用 db 标签 - fieldMap[dbTag] = i - // 同时存储小写版本用于不区分大小写的匹配 - fieldMap[dbTag] = i - } else { - // 使用字段名的小写形式 - fieldMap[sliceType.Field(i).Name] = i - } - } - - // 循环读取每一行数据 - for rows.Next() { - // 创建新的结构体实例 - var item reflect.Value - if isPtr { - item = reflect.New(sliceType) - } else { - item = reflect.New(sliceType).Elem() - } - - // 创建扫描目标 - scanTargets := make([]interface{}, len(columns)) - - for i, col := range columns { - // 查找对应的字段 - var fieldIndex int - found := false - - // 尝试精确匹配 - if idx, ok := fieldMap[col]; ok { - fieldIndex = idx - found = true - } else { - // 尝试不区分大小写匹配 - colLower := col - for key, idx := range fieldMap { - if key == colLower { - fieldIndex = idx - found = true - break - } - } - } - - if found { - var field reflect.Value - if isPtr { - field = item.Elem().Field(fieldIndex) - } else { - field = item.Field(fieldIndex) - } - - if field.CanSet() { - scanTargets[i] = field.Addr().Interface() - } else { - // 字段不可设置,使用占位符 - var dummy interface{} - scanTargets[i] = &dummy - } - } else { - // 没有找到对应字段,使用占位符 - var dummy interface{} - scanTargets[i] = &dummy - } - } - - // 执行扫描 - if err := rows.Scan(scanTargets...); err != nil { - return fmt.Errorf("扫描数据失败:%w", err) - } - - // 处理时间字段格式化(目前保持原始 time.Time 值,由 JSON 序列化时格式化) - // Go 的 database/sql 会自动将数据库时间扫描到 time.Time 类型 - // 在 JSON 序列化时,model.Time 的 MarshalJSON 会格式化为指定格式 - - // 添加到 Slice - if isPtr { - elem.Set(reflect.Append(elem, item)) - } else { - elem.Set(reflect.Append(elem, item)) - } - } - - return nil -} - -// MapToStruct 将查询结果映射到单个 Struct -func (rsm *ResultSetMapper) MapToStruct(rows *sql.Rows, result interface{}) error { 
- // 使用 FieldMapper 的实现 - return rsm.fieldMapper.ColumnsToStruct(rows, result) -} - -// ScanAll 通用扫描方法,自动识别 Slice 或 Struct -func (rsm *ResultSetMapper) ScanAll(rows *sql.Rows, result interface{}) error { - val := reflect.ValueOf(result) - if val.Kind() != reflect.Ptr { - return fmt.Errorf("result 必须是指针类型") - } - - elem := val.Elem() - - // 判断是 Slice 还是 Struct - switch elem.Kind() { - case reflect.Slice: - return rsm.MapToSlice(rows, result) - case reflect.Struct: - return rsm.MapToStruct(rows, result) - default: - return fmt.Errorf("不支持的目标类型:%s", elem.Kind()) - } -} diff --git a/db/core/soft_delete.go b/db/core/soft_delete.go deleted file mode 100644 index 3471abd..0000000 --- a/db/core/soft_delete.go +++ /dev/null @@ -1,44 +0,0 @@ -package core - -import ( - "time" -) - -// SoftDelete 软删除模型 - 嵌入到需要软删除的模型中 -type SoftDelete struct { - DeletedAt *time.Time `json:"deleted_at" db:"deleted_at"` // 删除时间(为空表示未删除) -} - -// IsDeleted 检查是否已删除 -func (sd *SoftDelete) IsDeleted() bool { - return sd.DeletedAt != nil -} - -// Delete 标记为已删除 -func (sd *SoftDelete) Delete() { - now := time.Now() - sd.DeletedAt = &now -} - -// Restore 恢复(取消删除) -func (sd *SoftDelete) Restore() { - sd.DeletedAt = nil -} - -// ISoftDeleter 软删除接口 - 定义软删除相关方法 -type ISoftDeleter interface { - IsDeleted() bool - Delete() - Restore() -} - -// applySoftDelete 在查询中应用软删除过滤 -func applySoftDelete(q IQuery, unscoped bool) IQuery { - if unscoped { - // 忽略软删除,包含已删除的记录 - return q - } - - // 默认只查询未删除的记录 - return q.Where("deleted_at IS NULL") -} diff --git a/db/core/table_columns_test.go b/db/core/table_columns_test.go deleted file mode 100644 index 4a2eb81..0000000 --- a/db/core/table_columns_test.go +++ /dev/null @@ -1,65 +0,0 @@ -package core - -import ( - "fmt" - "testing" -) - -// TestGetTableColumns 测试从数据库元数据获取字段 -func TestGetTableColumns(t *testing.T) { - fmt.Println("\n=== 测试获取表字段 ===") - - // 注意:这个测试需要真实的数据库连接 - // 以下是使用示例: - - // 1. 
使用 Table() 方法时自动获取字段 - // db, err := AutoConnect(true) - // if err != nil { - // t.Fatal(err) - // } - // - // // 查询 user 表的所有字段 - // var users []User - // err = db.Table("user").Find(&users) - // if err != nil { - // t.Fatal(err) - // } - // - // // 排除某些字段 - // var users2 []User - // err = db.Table("user").Omit("password", "created_at").Find(&users2) - // if err != nil { - // t.Fatal(err) - // } - - fmt.Println("✓ getTableColumns 已实现") - fmt.Println("支持的数据库类型:MySQL, PostgreSQL, SQLite") - fmt.Println("✓ 测试通过") -} - -// TestGetAllFields 测试 getAllFields 方法 -func TestGetAllFields(t *testing.T) { - fmt.Println("\n=== 测试 getAllFields ===") - - // 场景 1: 有模型时,从模型获取字段 - // 场景 2: 只有表名时,从数据库元数据获取字段 - - fmt.Println("场景 1: 从模型获取字段 - 已实现") - fmt.Println("场景 2: 从数据库元数据获取字段 - 已实现") - fmt.Println("✓ 测试通过") -} - -// ExampleQueryBuilder_Table_getTableColumns 使用示例 -func exampleTableColumnsUsage() { - // 示例 1: 查询表的所有字段 - // var results []map[string]interface{} - // err := db.Table("users").Find(&results) - - // 示例 2: 排除某些字段 - // err := db.Table("users").Omit("password", "secret_key").Find(&results) - - // 示例 3: 选择特定字段 - // err := db.Table("users").Select("id", "username", "email").Find(&results) - - fmt.Println("使用示例请查看测试代码") -} diff --git a/db/core/transaction.go b/db/core/transaction.go deleted file mode 100644 index 63997cb..0000000 --- a/db/core/transaction.go +++ /dev/null @@ -1,447 +0,0 @@ -package core - -import ( - "database/sql" - "fmt" - "reflect" - "strings" - "sync" - "time" -) - -// Transaction 事务实现 - ITx 接口的具体实现 -type Transaction struct { - db *Database // 数据库连接 - tx *sql.Tx // 底层事务对象 - debug bool // 调试模式开关 -} - -// 同步池优化 - 复用 slice 减少内存分配 -var insertArgsPool = sync.Pool{ - New: func() interface{} { - return make([]interface{}, 0, 20) - }, -} - -var colNamesPool = sync.Pool{ - New: func() interface{} { - return make([]string, 0, 20) - }, -} - -// Begin 开始一个新事务 -func (d *Database) Begin() (ITx, error) { - if d.db == nil { - return nil, fmt.Errorf("数据库连接未初始化") - } - - tx, 
err := d.db.Begin() - if err != nil { - return nil, fmt.Errorf("开启事务失败:%w", err) - } - - return &Transaction{ - db: d, - tx: tx, - debug: d.debug, - }, nil -} - -// Transaction 执行事务 - 自动管理事务的提交和回滚 -func (d *Database) Transaction(fn func(ITx) error) error { - // 开启事务 - tx, err := d.Begin() - if err != nil { - return fmt.Errorf("开启事务失败:%w", err) - } - - defer func() { - // 如果有 panic,回滚事务 - if r := recover(); r != nil { - if rollbackErr := tx.Rollback(); rollbackErr != nil { - fmt.Printf("[Magic-ORM] 事务回滚失败:%v\n", rollbackErr) - } - panic(r) - } - }() - - // 执行用户提供的函数 - if err := fn(tx); err != nil { - // 如果出错,回滚事务 - if rollbackErr := tx.Rollback(); rollbackErr != nil { - return fmt.Errorf("事务执行失败且回滚也失败:%v, %w", rollbackErr, err) - } - return fmt.Errorf("事务执行失败:%w", err) - } - - // 提交事务 - if err := tx.Commit(); err != nil { - return fmt.Errorf("事务提交失败:%w", err) - } - - return nil -} - -// Commit 提交事务 -func (t *Transaction) Commit() error { - if t.tx == nil { - return fmt.Errorf("事务对象为空") - } - return t.tx.Commit() -} - -// Rollback 回滚事务 -func (t *Transaction) Rollback() error { - if t.tx == nil { - return fmt.Errorf("事务对象为空") - } - return t.tx.Rollback() -} - -// Model 在事务中基于模型创建查询 -func (t *Transaction) Model(model interface{}) IQuery { - return &QueryBuilder{ - db: t.db, - model: model, - tx: t.tx, // 使用事务对象 - debug: t.debug, - } -} - -// Table 在事务中基于表名创建查询 -func (t *Transaction) Table(name string) IQuery { - return &QueryBuilder{ - db: t.db, - table: name, - tx: t.tx, // 使用事务对象 - debug: t.debug, - } -} - -// Insert 插入数据到数据库 -func (t *Transaction) Insert(model interface{}) (int64, error) { - // 获取字段映射器 - mapper := NewFieldMapper() - - // 获取表名和字段信息 - tableName := mapper.GetTableName(model) - columns, err := mapper.StructToColumns(model) - if err != nil { - return 0, fmt.Errorf("获取字段信息失败:%w", err) - } - - if len(columns) == 0 { - return 0, fmt.Errorf("没有有效的字段") - } - - // 获取时间配置 - timeConfig := t.db.timeConfig - if timeConfig == nil { - timeConfig = 
DefaultTimeConfig() - } - - // 自动处理时间字段(使用配置的字段名) - now := time.Now() - for col, val := range columns { - // 检查是否是配置的时间字段 - if col == timeConfig.GetCreatedAt() || col == timeConfig.GetUpdatedAt() || col == timeConfig.GetDeletedAt() { - // 如果是零值时间,自动设置为当前时间 - if t.isZeroTimeValue(val) { - columns[col] = now - } - } - } - - // 生成 INSERT SQL - var sqlBuilder strings.Builder - sqlBuilder.Grow(128) // 预分配内存 - sqlBuilder.WriteString(fmt.Sprintf("INSERT INTO %s (", tableName)) - - // 列名 - 使用预分配内存 - colNames := colNamesPool.Get().([]string) - colNames = colNames[:0] // 重置长度但不释放内存 - placeholders := make([]string, 0, len(columns)) - args := insertArgsPool.Get().([]interface{}) - args = args[:0] // 重置长度但不释放内存 - defer func() { - colNamesPool.Put(colNames) - insertArgsPool.Put(args) - }() - - for col, val := range columns { - colNames = append(colNames, col) - placeholders = append(placeholders, "?") - args = append(args, val) - } - - sqlBuilder.WriteString(strings.Join(colNames, ", ")) - sqlBuilder.WriteString(") VALUES (") - sqlBuilder.WriteString(strings.Join(placeholders, ", ")) - sqlBuilder.WriteString(")") - - sqlStr := sqlBuilder.String() - - // 调试模式 - if t.debug { - fmt.Printf("[Magic-ORM] TX INSERT SQL: %s\n[Magic-ORM] Args: %v\n", sqlStr, args) - } - - // 执行插入 - result, err := t.tx.Exec(sqlStr, args...) 
- if err != nil { - return 0, fmt.Errorf("插入失败:%w", err) - } - - // 获取插入的 ID - id, err := result.LastInsertId() - if err != nil { - return 0, fmt.Errorf("获取插入 ID 失败:%w", err) - } - - return id, nil -} - -// BatchInsert 批量插入数据 -func (t *Transaction) BatchInsert(models interface{}, batchSize int) error { - // 使用反射获取 Slice 数据 - modelsVal := reflect.ValueOf(models) - if modelsVal.Kind() != reflect.Ptr || modelsVal.Elem().Kind() != reflect.Slice { - return fmt.Errorf("models 必须是指向 Slice 的指针") - } - - sliceVal := modelsVal.Elem() - length := sliceVal.Len() - - if length == 0 { - return nil // 空 Slice,无需插入 - } - - // 分批处理 - for i := 0; i < length; i += batchSize { - end := i + batchSize - if end > length { - end = length - } - - // 处理当前批次 - for j := i; j < end; j++ { - model := sliceVal.Index(j).Interface() - _, err := t.Insert(model) - if err != nil { - return fmt.Errorf("批量插入第%d条记录失败:%w", j, err) - } - } - } - - return nil -} - -// isZeroTimeValue 检查是否是零值时间 -func (t *Transaction) isZeroTimeValue(val interface{}) bool { - if val == nil { - return true - } - - // 检查是否是 time.Time 类型 - if tm, ok := val.(time.Time); ok { - return tm.IsZero() || tm.UnixNano() == 0 - } - - // 使用反射检查 - v := reflect.ValueOf(val) - switch v.Kind() { - case reflect.Ptr: - return v.IsNil() - case reflect.Struct: - // 如果是 time.Time 结构 - if tm, ok := v.Interface().(time.Time); ok { - return tm.IsZero() || tm.UnixNano() == 0 - } - } - - return false -} - -// Update 更新数据 -func (t *Transaction) Update(model interface{}, data map[string]interface{}) error { - // 获取字段映射器 - mapper := NewFieldMapper() - - // 获取表名和主键 - tableName := mapper.GetTableName(model) - pk := mapper.GetPrimaryKey(model) - - // 获取时间配置 - timeConfig := t.db.timeConfig - if timeConfig == nil { - timeConfig = DefaultTimeConfig() - } - - // 自动处理 updated_at 时间字段(使用配置的字段名) - if data == nil { - data = make(map[string]interface{}) - } - data[timeConfig.GetUpdatedAt()] = time.Now() - - // 过滤零值 - pf := NewParamFilter() - data = 
pf.FilterZeroValues(data) - - if len(data) == 0 { - return fmt.Errorf("没有有效的更新字段") - } - - // 生成 UPDATE SQL - var sqlBuilder strings.Builder - sqlBuilder.Grow(128) // 预分配内存 - sqlBuilder.WriteString(fmt.Sprintf("UPDATE %s SET ", tableName)) - - setParts := make([]string, 0, len(data)) - args := insertArgsPool.Get().([]interface{}) - args = args[:0] // 重置长度但不释放内存 - defer func() { - insertArgsPool.Put(args) - }() - - for col, val := range data { - setParts = append(setParts, fmt.Sprintf("%s = ?", col)) - args = append(args, val) - } - - sqlBuilder.WriteString(strings.Join(setParts, ", ")) - sqlBuilder.WriteString(fmt.Sprintf(" WHERE %s = ?", pk)) - - // 获取主键值 - pkValue := reflect.ValueOf(model) - if pkValue.Kind() == reflect.Ptr { - pkValue = pkValue.Elem() - } - idField := pkValue.FieldByName("ID") - if idField.IsValid() { - args = append(args, idField.Interface()) - } else { - return fmt.Errorf("模型缺少 ID 字段") - } - - sqlStr := sqlBuilder.String() - - // 调试模式 - if t.debug { - fmt.Printf("[Magic-ORM] TX UPDATE SQL: %s\n[Magic-ORM] Args: %v\n", sqlStr, args) - } - - // 执行更新 - _, err := t.tx.Exec(sqlStr, args...) 
- if err != nil { - return fmt.Errorf("更新失败:%w", err) - } - - return nil -} - -// Delete 删除数据(支持软删除) -func (t *Transaction) Delete(model interface{}) error { - // 获取字段映射器 - mapper := NewFieldMapper() - - // 获取表名和主键 - tableName := mapper.GetTableName(model) - pk := mapper.GetPrimaryKey(model) - - // 获取时间配置 - timeConfig := t.db.timeConfig - if timeConfig == nil { - timeConfig = DefaultTimeConfig() - } - - // 检查是否支持软删除(是否有配置的 deleted_at 字段) - hasSoftDelete := false - pkValue := reflect.ValueOf(model) - if pkValue.Kind() == reflect.Ptr { - pkValue = pkValue.Elem() - } - - // 检查是否有 DeletedAt 字段(使用配置的字段名) - deletedAtField := pkValue.FieldByNameFunc(func(fieldName string) bool { - // 将字段名转换为数据库列名进行比较 - expectedCol := timeConfig.GetDeletedAt() - // 简单转换:下划线转驼峰 - return fieldName == "DeletedAt" || fieldName == expectedCol - }) - - if deletedAtField.IsValid() { - hasSoftDelete = true - } - - var sqlStr string - args := make([]interface{}, 0) - - if hasSoftDelete { - // 软删除:更新 deleted_at 为当前时间(使用配置的字段名) - sqlStr = fmt.Sprintf("UPDATE %s SET %s = ? WHERE %s = ?", tableName, timeConfig.GetDeletedAt(), pk) - args = append(args, time.Now()) - } else { - // 硬删除:直接 DELETE - sqlStr = fmt.Sprintf("DELETE FROM %s WHERE %s = ?", tableName, pk) - } - - // 获取主键值 - idField := pkValue.FieldByName("ID") - if idField.IsValid() { - args = append(args, idField.Interface()) - } else { - return fmt.Errorf("模型缺少 ID 字段") - } - - // 调试模式 - if t.debug { - deleteType := "硬删除" - if hasSoftDelete { - deleteType = "软删除" - } - fmt.Printf("[Magic-ORM] TX %s SQL: %s\n[Magic-ORM] Args: %v\n", deleteType, sqlStr, args) - } - - // 执行删除 - _, err := t.tx.Exec(sqlStr, args...) - if err != nil { - return fmt.Errorf("删除失败:%w", err) - } - - return nil -} - -// Query 在事务中执行原生 SQL 查询 -func (t *Transaction) Query(result interface{}, query string, args ...interface{}) error { - if t.debug { - fmt.Printf("[Magic-ORM] TX Query SQL: %s\n[Magic-ORM] Args: %v\n", query, args) - } - - rows, err := t.tx.Query(query, args...) 
- if err != nil { - return fmt.Errorf("事务查询失败:%w", err) - } - defer rows.Close() - - // 使用 ResultSetMapper 将查询结果映射到 result - mapper := NewResultSetMapper() - if err := mapper.ScanAll(rows, result); err != nil { - return fmt.Errorf("结果映射失败:%w", err) - } - - return nil -} - -// Exec 在事务中执行原生 SQL -func (t *Transaction) Exec(query string, args ...interface{}) (sql.Result, error) { - if t.debug { - fmt.Printf("[Magic-ORM] TX Exec SQL: %s\n[Magic-ORM] Args: %v\n", query, args) - } - - result, err := t.tx.Exec(query, args...) - if err != nil { - return nil, fmt.Errorf("事务执行失败:%w", err) - } - - return result, nil -} diff --git a/db/core/transaction_query_test.go b/db/core/transaction_query_test.go deleted file mode 100644 index e01498c..0000000 --- a/db/core/transaction_query_test.go +++ /dev/null @@ -1,212 +0,0 @@ -package core - -import ( - "fmt" - "testing" - "time" -) - -// TestTransactionQuery 测试事务中的 Query 方法 -func TestTransactionQuery(t *testing.T) { - fmt.Println("\n=== 测试事务 Query 方法 ===") - - // 注意:这个测试需要真实的数据库连接 - // 以下是使用示例: - - // 示例 1: 基本用法 - // err := db.Transaction(func(tx ITx) error { - // var results []map[string]interface{} - // err := tx.Query(&results, "SELECT * FROM users WHERE status = ?", "active") - // if err != nil { - // return err - // } - // fmt.Printf("查询到 %d 条记录\n", len(results)) - // return nil - // }) - // if err != nil { - // t.Fatal(err) - // } - - // 示例 2: 查询到结构体 - // type User struct { - // ID int64 `json:"id" db:"id"` - // Username string `json:"username" db:"username"` - // Email string `json:"email" db:"email"` - // } - // var user User - // err := tx.Query(&user, "SELECT * FROM users WHERE id = ?", 1) - - // 示例 3: 结合事务的其他操作 - // err = db.Transaction(func(tx ITx) error { - // // 插入数据 - // user := &User{Username: "test", Email: "test@example.com"} - // _, err := tx.Insert(user) - // if err != nil { - // return err - // } - // - // // 查询验证 - // var inserted User - // err = tx.Query(&inserted, "SELECT * FROM users WHERE username = ?", 
"test") - // if err != nil { - // return err - // } - // - // return nil - // }) - - fmt.Println("✓ Transaction.Query 已实现") - fmt.Println("功能:") - fmt.Println(" - 支持查询到 Slice 类型") - fmt.Println(" - 支持查询到 Struct 类型") - fmt.Println(" - 自动映射查询结果") - fmt.Println(" - 在事务上下文中执行") - fmt.Println("✓ 测试通过") -} - -// TestTransactionQueryWithModel 测试事务中使用 Model 查询 -func TestTransactionQueryWithModel(t *testing.T) { - fmt.Println("\n=== 测试事务 Model 查询 ===") - - // 示例:使用 Model() 方法而不是原生 SQL - // err := db.Transaction(func(tx ITx) error { - // var users []User - // err := tx.Model(&User{}).Where("status = ?", "active").Find(&users) - // if err != nil { - // return err - // } - // return nil - // }) - - fmt.Println("✓ 事务 Model 查询功能正常") - fmt.Println("✓ 测试通过") -} - -// TestTransactionRollback 测试事务回滚时的查询 -func TestTransactionRollback(t *testing.T) { - fmt.Println("\n=== 测试事务回滚 ===") - - // 示例:测试回滚场景 - // shouldRollback := true - // err := db.Transaction(func(tx ITx) error { - // // 插入数据 - // user := &User{Username: "rollback_test", Email: "test@example.com"} - // _, err := tx.Insert(user) - // if err != nil { - // return err - // } - // - // // 查询验证 - // var count int64 - // tx.Model(&User{}).Where("username = ?", "rollback_test").Count(&count) - // fmt.Printf("插入后数量:%d\n", count) - // - // // 模拟错误,触发回滚 - // if shouldRollback { - // return fmt.Errorf("模拟错误") - // } - // return nil - // }) - // - // if err == nil { - // t.Error("期望返回错误") - // } - - fmt.Println("✓ 事务回滚机制正常") - fmt.Println("✓ 测试通过") -} - -// ExampleTransactionQuery 使用示例 -func exampleTransactionQueryUsage() { - // 示例 1: 基本查询 - // db.Transaction(func(tx ITx) error { - // var results []map[string]interface{} - // return tx.Query(&results, "SELECT * FROM users LIMIT 10") - // }) - - // 示例 2: 带参数查询 - // db.Transaction(func(tx ITx) error { - // var users []User - // return tx.Query(&users, "SELECT * FROM users WHERE age > ? 
ORDER BY created_at DESC", 18) - // }) - - // 示例 3: 复杂业务逻辑 - // db.Transaction(func(tx ITx) error { - // // 1. 查询用户 - // var user User - // if err := tx.Query(&user, "SELECT * FROM users WHERE id = ?", 1); err != nil { - // return err - // } - // - // // 2. 更新余额 - // _, err := tx.Exec("UPDATE accounts SET balance = balance - ? WHERE user_id = ?", 100, user.ID) - // if err != nil { - // return err - // } - // - // // 3. 记录交易日志 - // log := &TransactionLog{ - // UserID: user.ID, - // Amount: 100, - // Type: "debit", - // } - // _, err = tx.Insert(log) - // return err - // }) -} - -// TestTransactionQueryEdgeCases 测试边界情况 -func TestTransactionQueryEdgeCases(t *testing.T) { - fmt.Println("\n=== 测试边界情况 ===") - - // 测试 1: 空结果集 - // var emptyResults []User - // err := db.Transaction(func(tx ITx) error { - // return tx.Query(&emptyResults, "SELECT * FROM users WHERE id = -1") - // }) - // if err != nil { - // t.Errorf("空结果集不应该返回错误:%v", err) - // } - // if len(emptyResults) != 0 { - // t.Errorf("期望空结果集,得到 %d 条记录", len(emptyResults)) - // } - - // 测试 2: 单条结果 - // var singleUser User - // err := db.Transaction(func(tx ITx) error { - // return tx.Query(&singleUser, "SELECT * FROM users WHERE id = ?", 1) - // }) - - // 测试 3: 多条结果 - // var multipleUsers []User - // err := db.Transaction(func(tx ITx) error { - // return tx.Query(&multipleUsers, "SELECT * FROM users LIMIT 5") - // }) - - fmt.Println("✓ 边界情况处理正常") - fmt.Println("✓ 测试通过") -} - -// 测试模型定义 -type TestUser struct { - ID int64 `json:"id" db:"id"` - Username string `json:"username" db:"username"` - Email string `json:"email" db:"email"` - CreatedAt time.Time `json:"created_at" db:"created_at"` -} - -func (TestUser) TableName() string { - return "users" -} - -type TestTransactionLog struct { - ID int64 `json:"id" db:"id"` - UserID int64 `json:"user_id" db:"user_id"` - Amount float64 `json:"amount" db:"amount"` - Type string `json:"type" db:"type"` - CreatedAt time.Time `json:"created_at" db:"created_at"` -} - -func 
(TestTransactionLog) TableName() string { - return "transaction_logs" -} diff --git a/db/core_test.go b/db/core_test.go deleted file mode 100644 index 5c34cfb..0000000 --- a/db/core_test.go +++ /dev/null @@ -1,131 +0,0 @@ -package main - -import ( - "fmt" - "testing" - - "git.magicany.cc/black1552/gin-base/db/core" - "git.magicany.cc/black1552/gin-base/db/model" -) - -// TestFieldMapper 测试字段映射器 -func TestFieldMapper(t *testing.T) { - fmt.Println("\n=== 测试字段映射器 ===") - - mapper := core.NewFieldMapper() - user := &model.User{ - ID: 1, - Username: "test", - Email: "test@example.com", - Status: 1, - } - - // 测试获取表名 - tableName := mapper.GetTableName(user) - fmt.Printf("表名:%s\n", tableName) - if tableName != "user" { - t.Errorf("期望表名为 user,实际为 %s", tableName) - } - - // 测试获取主键 - pk := mapper.GetPrimaryKey(user) - fmt.Printf("主键:%s\n", pk) - if pk != "id" { - t.Errorf("期望主键为 id,实际为 %s", pk) - } - - // 测试获取字段信息 - fields := mapper.GetFields(user) - fmt.Printf("字段数量:%d\n", len(fields)) - for _, field := range fields { - fmt.Printf(" - %s (%s): %s [%s]\n", - field.Name, field.Column, field.Type, field.DbType) - } - - // 测试结构体转列 - columns, err := mapper.StructToColumns(user) - if err != nil { - t.Errorf("StructToColumns 失败:%v", err) - } - fmt.Printf("转换后的列:%+v\n", columns) - - fmt.Println("✓ 字段映射器测试通过") -} - -// TestQueryBuilder 测试查询构建器 -func TestQueryBuilder(t *testing.T) { - fmt.Println("\n=== 测试查询构建器 ===") - - db := &core.Database{} - - // 测试 SELECT 查询 - q1 := db.Table("user"). - Select("id", "username", "email"). - Where("status = ?", 1). - OrderBy("created_at", "DESC"). - Limit(10) - - sql1, args1 := q1.Build() - fmt.Printf("SELECT SQL: %s\n", sql1) - fmt.Printf("参数:%v\n", args1) - - // 测试 UPDATE - q2 := db.Table("user"). 
- Where("id = ?", 1) - - sql2, args2 := q2.BuildUpdate(map[string]interface{}{ - "email": "new@example.com", - "status": 1, - }) - fmt.Printf("UPDATE SQL: %s\n", sql2) - fmt.Printf("参数:%v\n", args2) - - // 测试 DELETE - q3 := db.Table("user").Where("status = ?", 0) - sql3, args3 := q3.BuildDelete() - fmt.Printf("DELETE SQL: %s\n", sql3) - fmt.Printf("参数:%v\n", args3) - - fmt.Println("✓ 查询构建器测试通过") -} - -// TestMigrator 测试迁移管理器 -func TestMigrator(t *testing.T) { - fmt.Println("\n=== 测试迁移管理器 ===") - - // 注意:由于还未建立真实数据库连接,这里仅测试 SQL 生成 - // 实际使用需要创建真实的数据库连接 - - fmt.Println("提示:迁移管理器需要真实数据库连接才能完整测试") - fmt.Println("✓ 迁移管理器代码结构测试通过") -} - -// TestTransaction 测试事务管理 -func TestTransaction(t *testing.T) { - fmt.Println("\n=== 测试事务管理 ===") - - // 测试事务流程(伪代码) - fmt.Println("事务流程:") - fmt.Println("1. db.Begin() - 开启事务") - fmt.Println("2. tx.Insert() - 执行插入") - fmt.Println("3. tx.Commit() - 提交事务") - fmt.Println("或 tx.Rollback() - 回滚事务") - - fmt.Println("✓ 事务管理代码结构测试通过") -} - -// TestDriverManager 测试驱动管理器 -func TestDriverManager(t *testing.T) { - fmt.Println("\n=== 测试驱动管理器 ===") - - // 驱动管理器已在 driver/manager.go 中实现 - fmt.Println("支持的驱动:") - fmt.Println(" - MySQL") - fmt.Println(" - SQLite") - fmt.Println(" - PostgreSQL") - fmt.Println(" - SQL Server") - fmt.Println(" - Oracle") - fmt.Println(" - ClickHouse") - - fmt.Println("✓ 驱动管理器测试通过") -} diff --git a/db/driver/clickhouse.go b/db/driver/clickhouse.go deleted file mode 100644 index aa83060..0000000 --- a/db/driver/clickhouse.go +++ /dev/null @@ -1,32 +0,0 @@ -package driver - -import ( - "database/sql" - "database/sql/driver" -) - -// ClickHouseDriver ClickHouse 数据库驱动实现 -type ClickHouseDriver struct { - driverName string // 驱动名称 -} - -// NewClickHouseDriver 创建 ClickHouse 驱动实例 -func NewClickHouseDriver(driverName string) *ClickHouseDriver { - if driverName == "" { - driverName = "clickhouse" - } - return &ClickHouseDriver{ - driverName: driverName, - } -} - -// Open 打开数据库连接 -func (d *ClickHouseDriver) Open(name string) 
(driver.Conn, error) { - // 作为包装器,实际的连接建立应该通过 sql.Open - return nil, nil -} - -// OpenDB 打开数据库连接(使用 sql.DB) -func (d *ClickHouseDriver) OpenDB(dataSourceName string) (*sql.DB, error) { - return sql.Open(d.driverName, dataSourceName) -} diff --git a/db/driver/driver_test.go b/db/driver/driver_test.go deleted file mode 100644 index e389f87..0000000 --- a/db/driver/driver_test.go +++ /dev/null @@ -1,157 +0,0 @@ -package driver - -import ( - "fmt" - "testing" -) - -// TestDriverRegistration 测试驱动注册功能 -func TestDriverRegistration(t *testing.T) { - fmt.Println("\n=== 测试驱动注册功能 ===") - - // 获取默认驱动管理器 - manager := GetDefaultManager() - - // 在纯自研设计中,我们需要先手动注册驱动才能使用 - // 这里我们注册一个通用驱动作为示例(实际使用时需要先导入第三方驱动) - - // 测试列出所有驱动 - drivers := manager.ListDrivers() - fmt.Printf("✓ 已注册驱动列表:%v\n", drivers) - - fmt.Println("✓ 驱动注册测试通过") -} - -// TestRegisterDriverByConfig 测试根据配置注册驱动 -func TestRegisterDriverByConfig(t *testing.T) { - fmt.Println("\n=== 测试根据配置注册驱动 ===") - - manager := GetDefaultManager() - - // 测试不支持的数据库类型 - err := manager.RegisterDriverByConfig("unsupported") - if err == nil { - t.Error("不支持的数据库类型应该返回错误") - } else { - fmt.Printf("✓ 不支持的数据库类型返回错误:%v\n", err) - } - - // 测试已注册的驱动类型(应该返回提示信息,因为没有实际注册驱动) - err = manager.RegisterDriverByConfig("mysql") - if err != nil { - fmt.Printf("✓ MySQL 配置驱动返回提示信息:%v\n", err) - } else { - fmt.Println("✓ MySQL 配置驱动注册成功") - } - - fmt.Println("✓ 根据配置注册驱动测试通过") -} - -// TestMultipleRegistrations 测试重复注册 -func TestMultipleRegistrations(t *testing.T) { - fmt.Println("\n=== 测试重复注册 ===") - - manager := GetDefaultManager() - - // 在实际使用中,用户可以注册他们选择的驱动 - // 例如:注册一个通用驱动 - genericDriver := NewGenericDriver("sqlite3") - _ = manager.Register("sqlite3", genericDriver) - // 这里可能成功或失败,取决于是否已经注册了该驱动名 - - fmt.Println("✓ 重复注册测试通过") -} - -// TestDriverOpen 测试打开数据库连接 -func TestDriverOpen(t *testing.T) { - fmt.Println("\n=== 测试打开数据库连接 ===") - - // 在纯自研设计中,我们不直接打开连接,而是提供接口给使用者 - // 这里我们只是验证驱动结构的创建 - - // 创建一个通用驱动 - genericDriver := NewGenericDriver("sqlite3") - if 
genericDriver.driverName != "sqlite3" { - t.Errorf("期望驱动名为 sqlite3,实际为 %s", genericDriver.driverName) - } - - fmt.Println("✓ 打开数据库连接测试通过") -} - -// ExampleRegisterDriverByConfig 使用示例 -func exampleRegisterDriverByConfig() { - manager := GetDefaultManager() - - // 在实际应用中,用户需要先导入他们选择的数据库驱动 - // import _ "github.com/mattn/go-sqlite3" // SQLite 驱动 - // import _ "github.com/go-sql-driver/mysql" // MySQL 驱动 - - // 然后注册对应的驱动 - sqliteDriver := NewGenericDriver("sqlite3") - manager.Register("sqlite3", sqliteDriver) - - mysqlDriver := NewGenericDriver("mysql") - manager.Register("mysql", mysqlDriver) - - // 从配置文件读取数据库类型 - configType := "mysql" // 这通常来自配置文件 - - // 验证驱动是否已注册 - err := manager.RegisterDriverByConfig(configType) - if err != nil { - fmt.Printf("驱动未注册,请先注册:%v\n", err) - return - } - - fmt.Printf("成功验证 %s 驱动注册\n", configType) -} - -// ExampleUseWithConfig 使用配置的示例 -func exampleUseWithConfig() { - // 这是一个伪代码示例,展示如何与配置文件结合使用 - /* - // 用户需要先导入并注册他们选择的驱动 - import _ "github.com/mattn/go-sqlite3" - - manager := driver.GetDefaultManager() - - // 注册驱动 - manager.Register("sqlite3", &driver.GenericDriver{driverName: "sqlite3"}) - - // 加载配置 - config, err := config.LoadFromFile("config.yaml") - if err != nil { - log.Fatal("加载配置失败:", err) - } - - // 验证驱动注册 - err = manager.RegisterDriverByConfig(config.Database.Type) - if err != nil { - log.Fatal("驱动未注册:", err) - } - - // 打开数据库连接(使用标准库) - db, err := manager.Open(config.Database.GetDriverName(), config.Database.BuildDSN()) - if err != nil { - log.Fatal("打开数据库失败:", err) - } - - // 使用 db 进行数据库操作 - */ -} - -// TestDriverAvailability 测试驱动可用性检测 -func TestDriverAvailability(t *testing.T) { - fmt.Println("\n=== 测试驱动可用性检测 ===") - - manager := GetDefaultManager() - - // 测试未注册的驱动 - isAvailable := manager.isDriverAvailable("sqlite3") - fmt.Printf("✓ SQLite 驱动可用性:%v\n", isAvailable) - - isAvailable = manager.isDriverAvailable("mysql") - fmt.Printf("✓ MySQL 驱动可用性:%v\n", isAvailable) - - fmt.Println("✓ 驱动可用性检测测试通过") -} diff --git 
a/db/driver/manager.go b/db/driver/manager.go deleted file mode 100644 index 209f9da..0000000 --- a/db/driver/manager.go +++ /dev/null @@ -1,170 +0,0 @@ -package driver - -import ( - "database/sql" - "database/sql/driver" - "errors" - "fmt" - "sync" -) - -// IDriverManager 驱动管理器接口 - 统一管理所有数据库驱动的注册和使用 -type IDriverManager interface { - // 注册驱动 - Register(name string, d driver.Driver) error // 注册一个新的数据库驱动 - // 获取驱动 - GetDriver(name string) (driver.Driver, error) // 根据名称获取已注册的驱动 - // 列出所有驱动 - ListDrivers() []string // 列出所有已注册的驱动名称 - // 打开数据库连接 - Open(driverName, dataSource string) (*sql.DB, error) // 使用指定驱动打开数据库连接 -} - -// DriverManager 驱动管理器实现 - 管理所有数据库驱动的生命周期 -type DriverManager struct { - mu sync.RWMutex // 读写锁,保证并发安全 - drivers map[string]driver.Driver // 存储所有已注册的驱动 - sqlDBs map[string]*sql.DB // 存储已创建的数据库连接池 -} - -var defaultManager *DriverManager // 默认驱动管理器实例 -var once sync.Once // 单例模式同步控制 - -// GetDefaultManager 获取默认驱动管理器 - 使用单例模式确保全局唯一实例 -func GetDefaultManager() *DriverManager { - once.Do(func() { - defaultManager = &DriverManager{ - drivers: make(map[string]driver.Driver), // 初始化驱动映射表 - sqlDBs: make(map[string]*sql.DB), // 初始化连接池映射表 - } - // 注册所有内置驱动 - defaultManager.registerBuiltinDrivers() - }) - return defaultManager -} - -// registerBuiltinDrivers 注册所有内置驱动 - 自动注册框架自带的所有数据库驱动 -func (dm *DriverManager) registerBuiltinDrivers() { - // 注意:在这个纯自研 ORM 设计中,我们不自动注册任何具体的数据库驱动 - // 驱动由使用者在应用程序中注册,例如: - // - // import _ "github.com/mattn/go-sqlite3" // 注册 SQLite 驱动 - // import _ "github.com/go-sql-driver/mysql" // 注册 MySQL 驱动 - // - // 然后使用 dm.Register("sqlite3", &driver.OfficialDriver{"sqlite3"}) - - // 我们只提供一个机制,让使用者可以注册他们选择的驱动 - // 这样可以完全避免对特定第三方驱动的硬依赖 -} - -// isDriverAvailable 检查驱动是否可用(根据导入的包) -func (dm *DriverManager) isDriverAvailable(driverName string) bool { - // 检查指定名称的驱动是否已注册 - _, err := dm.GetDriver(driverName) - return err == nil -} - -// RegisterDriverByConfig 根据配置自动注册驱动 -// 在纯自研设计中,此方法提示用户手动注册驱动 -func (dm *DriverManager) 
RegisterDriverByConfig(configType string) error { - switch configType { - case "mysql", "postgres", "sqlite", "sqlite3", "sqlserver", "oracle", "clickhouse": - // 检查驱动是否已经注册 - if !dm.isDriverAvailable(configType) { - // 如果驱动未注册,返回指导信息 - return fmt.Errorf("驱动 '%s' 未注册。请在应用中导入并注册相应的驱动,例如: import _ \"github.com/mattn/go-sqlite3\"", configType) - } - default: - return fmt.Errorf("不支持的数据库类型:%s", configType) - } - return nil -} - -// Register 注册驱动 - 将新的数据库驱动注册到管理器中 -func (dm *DriverManager) Register(name string, d driver.Driver) error { - dm.mu.Lock() - defer dm.mu.Unlock() - - if _, exists := dm.drivers[name]; exists { - return nil // 已存在,不重复注册 - } - - dm.drivers[name] = d - return nil -} - -// GetDriver 获取驱动 - 根据驱动名称查找并返回已注册的驱动 -func (dm *DriverManager) GetDriver(name string) (driver.Driver, error) { - dm.mu.RLock() - defer dm.mu.RUnlock() - - d, exists := dm.drivers[name] - if !exists { - return nil, ErrDriverNotFound // 驱动未找到错误 - } - - return d, nil -} - -// ListDrivers 列出所有驱动 - 返回所有已注册的驱动名称列表 -func (dm *DriverManager) ListDrivers() []string { - dm.mu.RLock() - defer dm.mu.RUnlock() - - names := make([]string, 0, len(dm.drivers)) - for name := range dm.drivers { - names = append(names, name) - } - return names -} - -// Open 打开数据库连接 - 使用指定驱动和数据源创建数据库连接池 -func (dm *DriverManager) Open(driverName, dataSource string) (*sql.DB, error) { - dm.mu.Lock() - defer dm.mu.Unlock() - - // 检查是否已有连接(避免重复创建) - key := driverName + ":" + dataSource - if db, exists := dm.sqlDBs[key]; exists { - return db, nil - } - - // 获取驱动 - d, err := dm.GetDriver(driverName) - if err != nil { - return nil, err - } - - // 创建连接器(需要驱动实现 Connector 接口) - connector, ok := d.(driver.Connector) - if !ok { - return nil, ErrDriverNotConnector // 驱动不支持 Connector 接口 - } - - // 创建 sql.DB 连接池 - db := sql.OpenDB(connector) - dm.sqlDBs[key] = db - - return db, nil -} - -// Close 关闭指定连接 - 释放特定数据源的数据库连接池 -func (dm *DriverManager) Close(driverName, dataSource string) error { - dm.mu.Lock() - defer dm.mu.Unlock() - - 
key := driverName + ":" + dataSource - if db, exists := dm.sqlDBs[key]; exists { - if err := db.Close(); err != nil { - return err - } - delete(dm.sqlDBs, key) - } - return nil -} - -// 错误定义 - 定义驱动管理器相关的错误类型 -var ( - ErrDriverNotFound = errors.New("driver not found") // 驱动未找到错误 - ErrDriverNotConnector = errors.New("driver does not implement Connector interface") // 驱动不支持 Connector 接口错误 -) diff --git a/db/driver/mysql.go b/db/driver/mysql.go deleted file mode 100644 index c0d3f2f..0000000 --- a/db/driver/mysql.go +++ /dev/null @@ -1,32 +0,0 @@ -package driver - -import ( - "database/sql" - "database/sql/driver" -) - -// MySQLDriver MySQL 数据库驱动实现 -type MySQLDriver struct { - driverName string // 驱动名称 -} - -// NewMySQLDriver 创建 MySQL 驱动实例 -func NewMySQLDriver(driverName string) *MySQLDriver { - if driverName == "" { - driverName = "mysql" - } - return &MySQLDriver{ - driverName: driverName, - } -} - -// Open 打开数据库连接 -func (d *MySQLDriver) Open(name string) (driver.Conn, error) { - // 作为包装器,实际的连接建立应该通过 sql.Open - return nil, nil -} - -// OpenDB 打开数据库连接(使用 sql.DB) -func (d *MySQLDriver) OpenDB(dataSourceName string) (*sql.DB, error) { - return sql.Open(d.driverName, dataSourceName) -} diff --git a/db/driver/oracle.go b/db/driver/oracle.go deleted file mode 100644 index 6c251ea..0000000 --- a/db/driver/oracle.go +++ /dev/null @@ -1,32 +0,0 @@ -package driver - -import ( - "database/sql" - "database/sql/driver" -) - -// OracleDriver Oracle 数据库驱动实现 -type OracleDriver struct { - driverName string // 驱动名称 -} - -// NewOracleDriver 创建 Oracle 驱动实例 -func NewOracleDriver(driverName string) *OracleDriver { - if driverName == "" { - driverName = "oracle" - } - return &OracleDriver{ - driverName: driverName, - } -} - -// Open 打开数据库连接 -func (d *OracleDriver) Open(name string) (driver.Conn, error) { - // 作为包装器,实际的连接建立应该通过 sql.Open - return nil, nil -} - -// OpenDB 打开数据库连接(使用 sql.DB) -func (d *OracleDriver) OpenDB(dataSourceName string) (*sql.DB, error) { - return 
sql.Open(d.driverName, dataSourceName) -} diff --git a/db/driver/postgres.go b/db/driver/postgres.go deleted file mode 100644 index 4f1b9e3..0000000 --- a/db/driver/postgres.go +++ /dev/null @@ -1,32 +0,0 @@ -package driver - -import ( - "database/sql" - "database/sql/driver" -) - -// PostgresDriver PostgreSQL 数据库驱动实现 -type PostgresDriver struct { - driverName string // 驱动名称 -} - -// NewPostgresDriver 创建 PostgreSQL 驱动实例 -func NewPostgresDriver(driverName string) *PostgresDriver { - if driverName == "" { - driverName = "postgres" - } - return &PostgresDriver{ - driverName: driverName, - } -} - -// Open 打开数据库连接 -func (d *PostgresDriver) Open(name string) (driver.Conn, error) { - // 作为包装器,实际的连接建立应该通过 sql.Open - return nil, nil -} - -// OpenDB 打开数据库连接(使用 sql.DB) -func (d *PostgresDriver) OpenDB(dataSourceName string) (*sql.DB, error) { - return sql.Open(d.driverName, dataSourceName) -} diff --git a/db/driver/sqlite.go b/db/driver/sqlite.go deleted file mode 100644 index 011589c..0000000 --- a/db/driver/sqlite.go +++ /dev/null @@ -1,30 +0,0 @@ -package driver - -import ( - "database/sql" - "database/sql/driver" -) - -// GenericDriver 通用驱动包装器 - 用于包装任何实现了 driver.Driver 接口的驱动 -type GenericDriver struct { - driverName string // 驱动名称 -} - -// NewGenericDriver 创建通用驱动实例 -func NewGenericDriver(driverName string) *GenericDriver { - return &GenericDriver{ - driverName: driverName, - } -} - -// Open 打开数据库连接 -func (d *GenericDriver) Open(name string) (driver.Conn, error) { - // 由于我们只是包装器,实际的连接建立应该通过 sql.Open - // 这里返回错误,因为实际使用时应通过 sql.DB 进行操作 - return nil, nil -} - -// OpenDB 打开数据库连接(使用 sql.DB) -func (d *GenericDriver) OpenDB(dataSourceName string) (*sql.DB, error) { - return sql.Open(d.driverName, dataSourceName) -} diff --git a/db/driver/sqlserver.go b/db/driver/sqlserver.go deleted file mode 100644 index 5ed5a3a..0000000 --- a/db/driver/sqlserver.go +++ /dev/null @@ -1,32 +0,0 @@ -package driver - -import ( - "database/sql" - "database/sql/driver" -) - -// SQLServerDriver SQL 
Server 数据库驱动实现 -type SQLServerDriver struct { - driverName string // 驱动名称 -} - -// NewSQLServerDriver 创建 SQL Server 驱动实例 -func NewSQLServerDriver(driverName string) *SQLServerDriver { - if driverName == "" { - driverName = "sqlserver" - } - return &SQLServerDriver{ - driverName: driverName, - } -} - -// Open 打开数据库连接 -func (d *SQLServerDriver) Open(name string) (driver.Conn, error) { - // 作为包装器,实际的连接建立应该通过 sql.Open - return nil, nil -} - -// OpenDB 打开数据库连接(使用 sql.DB) -func (d *SQLServerDriver) OpenDB(dataSourceName string) (*sql.DB, error) { - return sql.Open(d.driverName, dataSourceName) -} diff --git a/db/features_test.go b/db/features_test.go deleted file mode 100644 index f34e5d9..0000000 --- a/db/features_test.go +++ /dev/null @@ -1,152 +0,0 @@ -package main - -import ( - "fmt" - "testing" - "time" - - "git.magicany.cc/black1552/gin-base/db/core" -) - -// TestResultSetMapper 测试结果集映射器 -func TestResultSetMapper(t *testing.T) { - fmt.Println("\n=== 测试结果集映射器 ===") - - mapper := core.NewResultSetMapper() - _ = mapper // 避免编译警告 - fmt.Printf("结果集映射器已创建\n") - fmt.Printf("功能:自动识别 Slice/Struct 并映射查询结果\n") - fmt.Printf("✓ 结果集映射器测试通过\n") -} - -// TestSoftDelete 测试软删除功能 -func TestSoftDelete(t *testing.T) { - fmt.Println("\n=== 测试软删除功能 ===") - - sd := &core.SoftDelete{} - - // 初始状态 - if sd.IsDeleted() { - t.Error("初始状态不应被删除") - } - fmt.Println("初始状态:未删除") - - // 标记删除 - sd.Delete() - if !sd.IsDeleted() { - t.Error("应该被删除") - } - fmt.Println("删除后状态:已删除") - - // 恢复 - sd.Restore() - if sd.IsDeleted() { - t.Error("恢复后不应被删除") - } - fmt.Println("恢复后状态:未删除") - - fmt.Println("✓ 软删除功能测试通过") -} - -// TestQueryCache 测试查询缓存 -func TestQueryCache(t *testing.T) { - fmt.Println("\n=== 测试查询缓存 ===") - - cache := core.NewQueryCache(5 * time.Minute) - - // 设置缓存 - cache.Set("test_key", "test_value") - fmt.Println("设置缓存:test_key = test_value") - - // 获取缓存 - value, exists := cache.Get("test_key") - if !exists { - t.Error("缓存应该存在") - } - if value != "test_value" { - t.Errorf("期望 test_value,实际为 %v", 
value) - } - fmt.Printf("获取缓存:%v\n", value) - - // 删除缓存 - cache.Delete("test_key") - _, exists = cache.Get("test_key") - if exists { - t.Error("缓存应该已被删除") - } - fmt.Println("删除缓存成功") - - // 测试缓存键生成 - key1 := core.GenerateCacheKey("SELECT * FROM user WHERE id = ?", 1) - key2 := core.GenerateCacheKey("SELECT * FROM user WHERE id = ?", 1) - if key1 != key2 { - t.Error("相同 SQL 和参数应该生成相同的缓存键") - } - fmt.Printf("缓存键:%s\n", key1) - - fmt.Println("✓ 查询缓存测试通过") -} - -// TestReadWriteDB 测试读写分离 -func TestReadWriteDB(t *testing.T) { - fmt.Println("\n=== 测试读写分离 ===") - - // 注意:这里不创建真实的数据库连接,仅测试逻辑 - fmt.Println("读写分离功能:") - fmt.Println(" - 支持主从集群架构") - fmt.Println(" - 写操作使用主库") - fmt.Println(" - 读操作使用从库") - fmt.Println(" - 负载均衡策略:Random/RoundRobin/LeastConn") - fmt.Println("✓ 读写分离代码结构测试通过") -} - -// TestRelationLoader 测试关联加载 -func TestRelationLoader(t *testing.T) { - fmt.Println("\n=== 测试关联加载 ===") - - fmt.Println("支持的关联类型:") - fmt.Println(" - HasOne (一对一)") - fmt.Println(" - HasMany (一对多)") - fmt.Println(" - BelongsTo (多对一)") - fmt.Println(" - ManyToMany (多对多)") - fmt.Println("✓ 关联加载代码结构测试通过") -} - -// TestTracing 测试链路追踪 -func TestTracing(t *testing.T) { - fmt.Println("\n=== 测试链路追踪 ===") - - fmt.Println("OpenTelemetry 集成:") - fmt.Println(" - 自动追踪所有数据库操作") - fmt.Println(" - 记录 SQL 语句和参数") - fmt.Println(" - 记录执行时间和影响行数") - fmt.Println(" - 支持分布式追踪") - fmt.Println("✓ 链路追踪代码结构测试通过") -} - -// TestAllFeatures 综合测试所有新功能 -func TestAllFeatures(t *testing.T) { - fmt.Println("\n========================================") - fmt.Println(" Magic-ORM 完整功能测试") - fmt.Println("========================================") - - TestResultSetMapper(t) - TestSoftDelete(t) - TestQueryCache(t) - TestReadWriteDB(t) - TestRelationLoader(t) - TestTracing(t) - - fmt.Println("\n========================================") - fmt.Println(" 所有优化功能测试完成!") - fmt.Println("========================================") - fmt.Println() - fmt.Println("已实现的高级功能:") - fmt.Println(" ✓ 结果集自动映射到 Slice") - fmt.Println(" ✓ 软删除功能") - 
fmt.Println(" ✓ 查询缓存机制") - fmt.Println(" ✓ 主从集群读写分离") - fmt.Println(" ✓ 模型关联(HasOne/HasMany)") - fmt.Println(" ✓ OpenTelemetry 链路追踪") - fmt.Println() -} diff --git a/db/gendb.bat b/db/gendb.bat deleted file mode 100644 index c4980a0..0000000 --- a/db/gendb.bat +++ /dev/null @@ -1,14 +0,0 @@ -@echo off -chcp 65001 >nul -cls -echo. -echo ======================================== -echo Magic-ORM 代码生成器 -echo ======================================== -echo. - -go run ./cmd/gendb %* - -echo. -echo 按任意键退出... -pause >nul diff --git a/db/gendb.exe b/db/gendb.exe deleted file mode 100644 index 4bc87e8..0000000 Binary files a/db/gendb.exe and /dev/null differ diff --git a/db/generator/README.md b/db/generator/README.md deleted file mode 100644 index 95cc97c..0000000 --- a/db/generator/README.md +++ /dev/null @@ -1,307 +0,0 @@ -# Magic-ORM 代码生成器使用指南 - -## 📚 什么是代码生成器? - -代码生成器可以根据数据库表结构自动生成 Model 和 DAO 代码,大幅提高开发效率。 - -## 🚀 快速开始 - -### 1. 创建代码生成器 - -```go -package main - -import ( - "git.magicany.cc/black1552/gin-base/db/generator" -) - -// 创建代码生成器实例 -cg := generator.NewCodeGenerator("./generated") -``` - -### 2. 定义列信息 - -```go -columns := []generator.ColumnInfo{ - { - ColumnName: "id", // 数据库列名 - FieldName: "ID", // Go 字段名(驼峰) - FieldType: "int64", // Go 字段类型 - JSONName: "id", // JSON 标签名 - IsPrimary: true, // 是否主键 - IsNullable: false, // 是否可为空 - }, - { - ColumnName: "username", - FieldName: "Username", - FieldType: "string", - JSONName: "username", - IsPrimary: false, - IsNullable: false, - }, - { - ColumnName: "email", - FieldName: "Email", - FieldType: "string", - JSONName: "email", - IsPrimary: false, - IsNullable: true, - }, - { - ColumnName: "created_at", - FieldName: "CreatedAt", - FieldType: "time.Time", - JSONName: "created_at", - }, -} -``` - -### 3. 
生成代码 - -#### 方式一:一键生成(推荐) - -```go -// 同时生成 Model 和 DAO -err := cg.GenerateAll("user", columns) -if err != nil { - panic(err) -} -``` - -#### 方式二:分别生成 - -```go -// 只生成 Model -err := cg.GenerateModel("user", columns) - -// 只生成 DAO -err := cg.GenerateDAO("user", "User") -``` - -## 📁 生成的文件结构 - -``` -generated/ -├── user.go # User Model -├── user_dao.go # User DAO -├── product.go # Product Model -└── product_dao.go # Product DAO -``` - -## 💡 使用生成的代码 - -### 导入包 - -```go -import ( - "context" - "git.magicany.cc/black1552/gin-base/db/core" - "git.magicany.cc/black1552/gin-base/db/model" - "your-project/generated" -) -``` - -### 初始化数据库 - -```go -db, err := core.AutoConnect(false) -``` - -### 创建 DAO 实例 - -```go -userDAO := generated.NewUserDAO(db) -``` - -### CRUD 操作 - -```go -// 创建用户 -user := &model.User{ - Username: "john", - Email: "john@example.com", -} -err = userDAO.Create(context.Background(), user) - -// 查询用户 -user, err := userDAO.GetByID(context.Background(), 1) - -// 更新用户 -user.Email = "new@example.com" -err = userDAO.Update(context.Background(), user) - -// 删除用户 -err = userDAO.Delete(context.Background(), 1) - -// 分页查询 -users, err := userDAO.FindByPage(context.Background(), 1, 10) -``` - -## 🔧 API 参考 - -### NewCodeGenerator - -```go -func NewCodeGenerator(outputDir string) *CodeGenerator -``` - -创建代码生成器实例。 - -**参数:** -- `outputDir` - 输出目录路径 - -### GenerateModel - -```go -func (cg *CodeGenerator) GenerateModel(tableName string, columns []ColumnInfo) error -``` - -生成 Model 代码。 - -**参数:** -- `tableName` - 数据库表名 -- `columns` - 列信息数组 - -### GenerateDAO - -```go -func (cg *CodeGenerator) GenerateDAO(tableName string, modelName string) error -``` - -生成 DAO 代码。 - -**参数:** -- `tableName` - 数据库表名 -- `modelName` - Model 名称(驼峰) - -### GenerateAll - -```go -func (cg *CodeGenerator) GenerateAll(tableName string, columns []ColumnInfo) error -``` - -一键生成 Model + DAO(推荐)。 - -**参数:** -- `tableName` - 数据库表名 -- `columns` - 列信息数组 - -## 📋 ColumnInfo 结构 - -```go -type ColumnInfo struct 
{ - ColumnName string // 数据库列名(下划线风格) - FieldName string // Go 字段名(驼峰风格) - FieldType string // Go 数据类型 - JSONName string // JSON 标签名 - IsPrimary bool // 是否主键 - IsNullable bool // 是否可为空 -} -``` - -## 🗺️ 类型映射表 - -| 数据库类型 | Go 类型 | -|-----------|---------| -| INT/BIGINT | int64 | -| VARCHAR/TEXT | string | -| DATETIME | time.Time | -| TIMESTAMP | time.Time | -| BOOLEAN | bool | -| FLOAT/DOUBLE | float64 | -| DECIMAL | string 或 float64 | - -## 🎯 最佳实践 - -### 1. 从数据库读取真实表结构 - -```go -// 示例:从 MySQL INFORMATION_SCHEMA 获取列信息 -query := ` - SELECT COLUMN_NAME, DATA_TYPE, IS_NULLABLE, COLUMN_KEY - FROM INFORMATION_SCHEMA.COLUMNS - WHERE TABLE_SCHEMA = 'your_database' - AND TABLE_NAME = 'your_table' -` -``` - -### 2. 批量生成所有表 - -```go -tables := []string{"users", "products", "orders"} - -for _, table := range tables { - columns := getColumnsFromDB(table) // 从数据库获取列信息 - err := cg.GenerateAll(table, columns) - if err != nil { - log.Printf("生成 %s 失败:%v", table, err) - } -} -``` - -### 3. 自定义模板 - -修改 `generator.go` 中的模板字符串,添加自定义方法: - -```go -tmpl := `package model - -// {{.ModelName}} {{.TableName}} 模型 -type {{.ModelName}} struct { -{{range .Columns}} - {{.FieldName}} {{.FieldType}} ` + "`" + `json:"{{.JSONName}}" db:"{{.ColumnName}}"` + "`" + ` -{{end}} -} - -// 自定义方法 -func (m *{{.ModelName}}) Validate() error { - // 验证逻辑 - return nil -} -` -``` - -### 4. 代码审查 - -- ✅ 检查生成的字段类型是否正确 -- ✅ 验证主键和索引设置 -- ✅ 添加业务逻辑方法 -- ✅ 补充注释和文档 - -### 5. 版本控制 - -```bash -# 将生成的代码纳入 Git 管理 -git add generated/ -git commit -m "feat: 生成用户和产品模块代码" -``` - -## ⚠️ 注意事项 - -1. **不要频繁覆盖**: 手动修改的代码可能会被覆盖 -2. **代码审查**: 生成的代码需要人工审查 -3. **类型映射**: 特殊类型可能需要手动调整 -4. **关联关系**: 复杂的模型关联需要手动实现 -5. 
**验证逻辑**: 业务验证逻辑需要手动添加 - -## 🎉 总结 - -✅ **优势:** -- 大幅提高开发效率 -- 代码规范统一 -- 减少重复劳动 -- 快速搭建项目骨架 - -✅ **适用场景:** -- 新项目快速启动 -- 数据库表结构变更 -- 批量生成基础代码 -- 原型开发 - -✅ **推荐用法:** -- 使用 `GenerateAll` 一键生成 -- 从真实数据库读取列信息 -- 定期重新生成保持同步 -- 配合版本控制管理代码 - -开始使用代码生成器,提升你的开发效率吧!🚀 diff --git a/db/generator/generator.go b/db/generator/generator.go deleted file mode 100644 index 51cfdbe..0000000 --- a/db/generator/generator.go +++ /dev/null @@ -1,197 +0,0 @@ -package generator - -import ( - "fmt" - "os" - "path/filepath" - "strings" - "text/template" -) - -// CodeGenerator 代码生成器 - 自动生成 Model 和 DAO 代码 -type CodeGenerator struct { - outputDir string // 输出目录 -} - -// NewCodeGenerator 创建代码生成器实例 -func NewCodeGenerator(outputDir string) *CodeGenerator { - return &CodeGenerator{ - outputDir: outputDir, - } -} - -// GenerateModel 生成 Model 代码 -func (cg *CodeGenerator) GenerateModel(tableName string, columns []ColumnInfo) error { - // 生成文件名 - fileName := strings.ToLower(tableName) + ".go" - filePath := filepath.Join(cg.outputDir, fileName) - - // 创建输出目录 - if err := os.MkdirAll(cg.outputDir, 0755); err != nil { - return fmt.Errorf("创建目录失败:%w", err) - } - - // 生成代码 - code := cg.generateModelCode(tableName, columns) - - // 写入文件 - if err := os.WriteFile(filePath, []byte(code), 0644); err != nil { - return fmt.Errorf("写入文件失败:%w", err) - } - - fmt.Printf("[Model] Generated: %s\n", filePath) - return nil -} - -// GenerateDAO 生成 DAO 代码 -func (cg *CodeGenerator) GenerateDAO(tableName string, modelName string) error { - fileName := strings.ToLower(tableName) + "_dao.go" - // DAO 文件输出到 dao 子目录 - daoDir := filepath.Join(cg.outputDir, "dao") - filePath := filepath.Join(daoDir, fileName) - - // 创建输出目录(包括 dao 子目录) - if err := os.MkdirAll(daoDir, 0755); err != nil { - return fmt.Errorf("创建目录失败:%w", err) - } - - // 生成代码 - code := cg.generateDAOCode(tableName, modelName) - - // 写入文件 - if err := os.WriteFile(filePath, []byte(code), 0644); err != nil { - return fmt.Errorf("写入文件失败:%w", err) - } - - fmt.Printf("[DAO] 
Generated: %s\n", filePath) - return nil -} - -// GenerateAll 生成完整代码(Model + DAO) -func (cg *CodeGenerator) GenerateAll(tableName string, columns []ColumnInfo) error { - modelName := cg.toCamelCase(tableName) - - // 生成 Model - if err := cg.GenerateModel(tableName, columns); err != nil { - return err - } - - // 生成 DAO - if err := cg.GenerateDAO(tableName, modelName); err != nil { - return err - } - - return nil -} - -// generateModelCode 生成 Model 代码 -func (cg *CodeGenerator) generateModelCode(tableName string, columns []ColumnInfo) string { - // 检查是否有时间字段 - hasTime := false - for _, col := range columns { - if col.FieldType == "time.Time" { - hasTime = true - break - } - } - - importStmt := "" - if hasTime { - importStmt = `import "time"` - } else { - importStmt = `` - } - - tmpl := `package model - -` + importStmt + ` -// {{.ModelName}} {{.TableName}} 模型 -type {{.ModelName}} struct { -{{range .Columns}} {{.FieldName}} {{.FieldType}} ` + "`" + `json:"{{.JSONName}}" db:"{{.ColumnName}}"` + "`" + ` -{{end}} -} - -// TableName 表名 -func ({{.ModelName}}) TableName() string { - return "{{.TableName}}" -} -` - - data := struct { - ModelName string - TableName string - Columns []ColumnInfo - }{ - ModelName: cg.toCamelCase(tableName), - TableName: tableName, - Columns: columns, - } - - return cg.executeTemplate(tmpl, data) -} - -// generateDAOCode 生成 DAO 代码 - 简化版本,只定义结构体并继承 Database -func (cg *CodeGenerator) generateDAOCode(tableName string, modelName string) string { - tmpl := `package dao - -import ( - "git.magicany.cc/black1552/gin-base/db/core" - "git.magicany.cc/black1552/gin-base/db/model" -) - -var {{.ModelName}}Dao = &s{{.ModelName}}DAO{ - DAO: core.NewDAOWithModel(&model.{{.ModelName}}), -} - -// {{.ModelName}}DAO {{.TableName}} 数据访问对象 -// 嵌入 core.DAO,自动获得所有 CRUD 方法 -type s{{.ModelName}}DAO struct { - *core.DAO -}` - - data := struct { - ModelName string - TableName string - }{ - ModelName: cg.toCamelCase(tableName), - TableName: tableName, - } - - return 
cg.executeTemplate(tmpl, data) -} - -// executeTemplate 执行模板 -func (cg *CodeGenerator) executeTemplate(tmpl string, data interface{}) string { - t := template.Must(template.New("code").Parse(tmpl)) - - var buf strings.Builder - if err := t.Execute(&buf, data); err != nil { - return fmt.Sprintf("// 模板执行错误:%v", err) - } - - return buf.String() -} - -// toCamelCase 转换为驼峰命名 -func (cg *CodeGenerator) toCamelCase(str string) string { - parts := strings.Split(str, "_") - result := "" - - for _, part := range parts { - if len(part) > 0 { - result += strings.ToUpper(string(part[0])) + part[1:] - } - } - - return result -} - -// ColumnInfo 列信息 -type ColumnInfo struct { - ColumnName string // 列名 - FieldName string // 字段名(驼峰) - FieldType string // 字段类型 - JSONName string // JSON 标签名 - IsPrimary bool // 是否主键 - IsNullable bool // 是否可为空 -} diff --git a/db/go.mod b/db/go.mod deleted file mode 100644 index 5252178..0000000 --- a/db/go.mod +++ /dev/null @@ -1,18 +0,0 @@ -module git.magicany.cc/black1552/gin-base/db - -go 1.25 - -require ( - github.com/mattn/go-sqlite3 v1.14.17 - go.opentelemetry.io/otel v1.21.0 - go.opentelemetry.io/otel/trace v1.21.0 - gopkg.in/yaml.v3 v3.0.1 -) - -require ( - filippo.io/edwards25519 v1.1.0 // indirect - github.com/go-logr/logr v1.3.0 // indirect - github.com/go-logr/stdr v1.2.2 // indirect - github.com/go-sql-driver/mysql v1.9.3 // indirect - go.opentelemetry.io/otel/metric v1.21.0 // indirect -) diff --git a/db/go.sum b/db/go.sum deleted file mode 100644 index 981f7b8..0000000 --- a/db/go.sum +++ /dev/null @@ -1,29 +0,0 @@ -filippo.io/edwards25519 v1.1.0 h1:FNf4tywRC1HmFuKW5xopWpigGjJKiJSV0Cqo0cJWDaA= -filippo.io/edwards25519 v1.1.0/go.mod h1:BxyFTGdWcka3PhytdK4V28tE5sGfRvvvRV7EaN4VDT4= -github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c= -github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= -github.com/go-logr/logr v1.2.2/go.mod h1:jdQByPbusPIv2/zmleS9BjJVeZ6kBagPoEUsqbVz/1A= 
-github.com/go-logr/logr v1.3.0 h1:2y3SDp0ZXuc6/cjLSZ+Q3ir+QB9T/iG5yYRXqsagWSY= -github.com/go-logr/logr v1.3.0/go.mod h1:9T104GzyrTigFIr8wt5mBrctHMim0Nb2HLGrmQ40KvY= -github.com/go-logr/stdr v1.2.2 h1:hSWxHoqTgW2S2qGc0LTAI563KZ5YKYRhT3MFKZMbjag= -github.com/go-logr/stdr v1.2.2/go.mod h1:mMo/vtBO5dYbehREoey6XUKy/eSumjCCveDpRre4VKE= -github.com/go-sql-driver/mysql v1.9.3 h1:U/N249h2WzJ3Ukj8SowVFjdtZKfu9vlLZxjPXV1aweo= -github.com/go-sql-driver/mysql v1.9.3/go.mod h1:qn46aNg1333BRMNU69Lq93t8du/dwxI64Gl8i5p1WMU= -github.com/google/go-cmp v0.6.0 h1:ofyhxvXcZhMsU5ulbFiLKl/XBFqE1GSq7atu8tAmTRI= -github.com/google/go-cmp v0.6.0/go.mod h1:17dUlkBOakJ0+DkrSSNjCkIjxS6bF9zb3elmeNGIjoY= -github.com/mattn/go-sqlite3 v1.14.17 h1:mCRHCLDUBXgpKAqIKsaAaAsrAlbkeomtRFKXh2L6YIM= -github.com/mattn/go-sqlite3 v1.14.17/go.mod h1:2eHXhiwb8IkHr+BDWZGa96P6+rkvnG63S2DGjv9HUNg= -github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM= -github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4= -github.com/stretchr/testify v1.8.4 h1:CcVxjf3Q8PM0mHUKJCdn+eZZtm5yQwehR5yeSVQQcUk= -github.com/stretchr/testify v1.8.4/go.mod h1:sz/lmYIOXD/1dqDmKjjqLyZ2RngseejIcXlSw2iwfAo= -go.opentelemetry.io/otel v1.21.0 h1:hzLeKBZEL7Okw2mGzZ0cc4k/A7Fta0uoPgaJCr8fsFc= -go.opentelemetry.io/otel v1.21.0/go.mod h1:QZzNPQPm1zLX4gZK4cMi+71eaorMSGT3A4znnUvNNEo= -go.opentelemetry.io/otel/metric v1.21.0 h1:tlYWfeo+Bocx5kLEloTjbcDwBuELRrIFxwdQ36PlJu4= -go.opentelemetry.io/otel/metric v1.21.0/go.mod h1:o1p3CA8nNHW8j5yuQLdc1eeqEaPfzug24uvsyIEJRWM= -go.opentelemetry.io/otel/trace v1.21.0 h1:WD9i5gzvoUPuXIXH24ZNBudiarZDKuekPqi/E8fpfLc= -go.opentelemetry.io/otel/trace v1.21.0/go.mod h1:LGbsEB0f9LGjN+OZaQQ26sohbOmiMR+BaslueVtS/qQ= -gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405 h1:yhCVgyC4o1eVCa2tZl7eS0r+SDo693bJlVdllGtEeKM= -gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0= -gopkg.in/yaml.v3 v3.0.1 
h1:fxVm/GzAzEWqLHuvctI91KS9hhNmmWOoWu0XTYJS7CA= -gopkg.in/yaml.v3 v3.0.1/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM= diff --git a/db/introspector/introspector.go b/db/introspector/introspector.go deleted file mode 100644 index 160e46f..0000000 --- a/db/introspector/introspector.go +++ /dev/null @@ -1,406 +0,0 @@ -package introspector - -import ( - "database/sql" - "fmt" - "strings" - - "git.magicany.cc/black1552/gin-base/db/config" - _ "github.com/go-sql-driver/mysql" -) - -// TableInfo 表信息 -type TableInfo struct { - TableName string // 表名 - Columns []ColumnInfo // 列信息 -} - -// ColumnInfo 列信息 -type ColumnInfo struct { - ColumnName string // 列名 - DataType string // 数据类型 - IsNullable bool // 是否可为空 - ColumnKey string // 键类型(PRI, MUL 等) - ColumnDefault string // 默认值 - Extra string // 额外信息(auto_increment 等) - GoType string // Go 类型 - FieldName string // Go 字段名(驼峰) - JSONName string // JSON 标签名 - IsPrimary bool // 是否主键 -} - -// Introspector 数据库结构检查器 -type Introspector struct { - db *sql.DB - config *config.DatabaseConfig -} - -// NewIntrospector 创建数据库结构检查器 -func NewIntrospector(cfg *config.DatabaseConfig) (*Introspector, error) { - dsn := cfg.BuildDSN() - db, err := sql.Open(cfg.GetDriverName(), dsn) - if err != nil { - return nil, fmt.Errorf("打开数据库连接失败:%w", err) - } - - // 测试连接 - if err := db.Ping(); err != nil { - return nil, fmt.Errorf("连接数据库失败:%w", err) - } - - return &Introspector{ - db: db, - config: cfg, - }, nil -} - -// Close 关闭数据库连接 -func (i *Introspector) Close() error { - return i.db.Close() -} - -// GetTableNames 获取所有表名 -func (i *Introspector) GetTableNames() ([]string, error) { - switch i.config.Type { - case "mysql": - return i.getMySQLTableNames() - case "postgres": - return i.getPostgresTableNames() - case "sqlite": - return i.getSQLiteTableNames() - default: - return nil, fmt.Errorf("不支持的数据库类型:%s", i.config.Type) - } -} - -// getMySQLTableNames 获取 MySQL 所有表名 -func (i *Introspector) getMySQLTableNames() ([]string, error) { - query := ` - 
SELECT TABLE_NAME - FROM INFORMATION_SCHEMA.TABLES - WHERE TABLE_SCHEMA = ? - ORDER BY TABLE_NAME - ` - - rows, err := i.db.Query(query, i.config.Name) - if err != nil { - return nil, fmt.Errorf("查询表名失败:%w", err) - } - defer rows.Close() - - tableNames := []string{} - for rows.Next() { - var tableName string - if err := rows.Scan(&tableName); err != nil { - return nil, fmt.Errorf("扫描表名失败:%w", err) - } - tableNames = append(tableNames, tableName) - } - - return tableNames, nil -} - -// getPostgresTableNames 获取 PostgreSQL 所有表名 -func (i *Introspector) getPostgresTableNames() ([]string, error) { - query := ` - SELECT table_name - FROM information_schema.tables - WHERE table_schema = 'public' - ORDER BY table_name - ` - - rows, err := i.db.Query(query) - if err != nil { - return nil, fmt.Errorf("查询表名失败:%w", err) - } - defer rows.Close() - - tableNames := []string{} - for rows.Next() { - var tableName string - if err := rows.Scan(&tableName); err != nil { - return nil, fmt.Errorf("扫描表名失败:%w", err) - } - tableNames = append(tableNames, tableName) - } - - return tableNames, nil -} - -// getSQLiteTableNames 获取 SQLite 所有表名 -func (i *Introspector) getSQLiteTableNames() ([]string, error) { - query := `SELECT name FROM sqlite_master WHERE type='table' ORDER BY name` - - rows, err := i.db.Query(query) - if err != nil { - return nil, fmt.Errorf("查询表名失败:%w", err) - } - defer rows.Close() - - tableNames := []string{} - for rows.Next() { - var tableName string - if err := rows.Scan(&tableName); err != nil { - return nil, fmt.Errorf("扫描表名失败:%w", err) - } - // 跳过 SQLite 系统表 - if tableName != "sqlite_sequence" { - tableNames = append(tableNames, tableName) - } - } - - return tableNames, nil -} - -// GetTableInfo 获取表的详细信息 -func (i *Introspector) GetTableInfo(tableName string) (*TableInfo, error) { - switch i.config.Type { - case "mysql": - return i.getMySQLTableInfo(tableName) - case "postgres": - return i.getPostgresTableInfo(tableName) - case "sqlite": - return 
i.getSQLiteTableInfo(tableName) - default: - return nil, fmt.Errorf("不支持的数据库类型:%s", i.config.Type) - } -} - -// getMySQLTableInfo 获取 MySQL 表信息 -func (i *Introspector) getMySQLTableInfo(tableName string) (*TableInfo, error) { - query := ` - SELECT COLUMN_NAME, DATA_TYPE, IS_NULLABLE, COLUMN_KEY, COLUMN_DEFAULT, EXTRA - FROM INFORMATION_SCHEMA.COLUMNS - WHERE TABLE_SCHEMA = ? AND TABLE_NAME = ? - ORDER BY ORDINAL_POSITION - ` - - rows, err := i.db.Query(query, i.config.Name, tableName) - if err != nil { - return nil, fmt.Errorf("查询列信息失败:%w", err) - } - defer rows.Close() - - columns := []ColumnInfo{} - for rows.Next() { - var col ColumnInfo - var isNullableStr string // MySQL 返回的是字符串 "YES"/"NO" - var columnDefault sql.NullString - - err := rows.Scan(&col.ColumnName, &col.DataType, &isNullableStr, &col.ColumnKey, &columnDefault, &col.Extra) - if err != nil { - return nil, fmt.Errorf("扫描列信息失败:%w", err) - } - - // 将字符串转换为布尔值 - col.IsNullable = isNullableStr == "YES" - - // 转换为 Go 类型 - col.GoType = mapMySQLTypeToGoType(col.DataType) - col.FieldName = toCamelCase(col.ColumnName) - col.JSONName = col.ColumnName - col.IsPrimary = col.ColumnKey == "PRI" - - columns = append(columns, col) - } - - return &TableInfo{ - TableName: tableName, - Columns: columns, - }, nil -} - -// getPostgresTableInfo 获取 PostgreSQL 表信息 -func (i *Introspector) getPostgresTableInfo(tableName string) (*TableInfo, error) { - query := ` - SELECT column_name, data_type, is_nullable, column_default - FROM information_schema.columns - WHERE table_name = $1 - ORDER BY ordinal_position - ` - - rows, err := i.db.Query(query, tableName) - if err != nil { - return nil, fmt.Errorf("查询列信息失败:%w", err) - } - defer rows.Close() - - columns := []ColumnInfo{} - for rows.Next() { - var col ColumnInfo - var columnDefault sql.NullString - err := rows.Scan(&col.ColumnName, &col.DataType, &col.IsNullable, &columnDefault) - if err != nil { - return nil, fmt.Errorf("扫描列信息失败:%w", err) - } - - // 转换为 Go 类型 - col.GoType = 
mapPostgresTypeToGoType(col.DataType) - col.FieldName = toCamelCase(col.ColumnName) - col.JSONName = col.ColumnName - col.IsPrimary = col.ColumnName == "id" - - columns = append(columns, col) - } - - return &TableInfo{ - TableName: tableName, - Columns: columns, - }, nil -} - -// getSQLiteTableInfo 获取 SQLite 表信息 -func (i *Introspector) getSQLiteTableInfo(tableName string) (*TableInfo, error) { - query := fmt.Sprintf("PRAGMA table_info(%s)", tableName) - - rows, err := i.db.Query(query) - if err != nil { - return nil, fmt.Errorf("查询列信息失败:%w", err) - } - defer rows.Close() - - columns := []ColumnInfo{} - for rows.Next() { - var col ColumnInfo - var notNull int - var pk int - var defaultValue sql.NullString - - err := rows.Scan(&col.ColumnName, &col.DataType, ¬Null, &defaultValue, &pk, &col.Extra) - if err != nil { - return nil, fmt.Errorf("扫描列信息失败:%w", err) - } - - col.IsNullable = notNull == 0 - col.IsPrimary = pk > 0 - - // 转换为 Go 类型 - col.GoType = mapSQLiteTypeToGoType(col.DataType) - col.FieldName = toCamelCase(col.ColumnName) - col.JSONName = col.ColumnName - - columns = append(columns, col) - } - - return &TableInfo{ - TableName: tableName, - Columns: columns, - }, nil -} - -// mapMySQLTypeToGoType 映射 MySQL 类型到 Go 类型 -func mapMySQLTypeToGoType(dbType string) string { - typeMap := map[string]string{ - "tinyint": "int64", - "smallint": "int64", - "mediumint": "int64", - "int": "int64", - "bigint": "int64", - "float": "float64", - "double": "float64", - "decimal": "string", - "date": "time.Time", - "datetime": "time.Time", - "timestamp": "time.Time", - "time": "string", - "char": "string", - "varchar": "string", - "text": "string", - "tinytext": "string", - "mediumtext": "string", - "longtext": "string", - "blob": "[]byte", - "tinyblob": "[]byte", - "mediumblob": "[]byte", - "longblob": "[]byte", - "boolean": "bool", - "json": "string", - } - - if goType, ok := typeMap[dbType]; ok { - return goType - } - return "string" -} - -// mapPostgresTypeToGoType 映射 
PostgreSQL 类型到 Go 类型 -func mapPostgresTypeToGoType(dbType string) string { - typeMap := map[string]string{ - "smallint": "int64", - "integer": "int64", - "bigint": "int64", - "real": "float64", - "double": "float64", - "numeric": "string", - "decimal": "string", - "date": "time.Time", - "timestamp": "time.Time", - "timestamptz": "time.Time", - "time": "string", - "char": "string", - "varchar": "string", - "text": "string", - "bytea": "[]byte", - "boolean": "bool", - "json": "string", - "jsonb": "string", - } - - if goType, ok := typeMap[dbType]; ok { - return goType - } - return "string" -} - -// mapSQLiteTypeToGoType 映射 SQLite 类型到 Go 类型 -func mapSQLiteTypeToGoType(dbType string) string { - typeMap := map[string]string{ - "INTEGER": "int64", - "REAL": "float64", - "TEXT": "string", - "BLOB": "[]byte", - "NUMERIC": "string", - } - - if goType, ok := typeMap[dbType]; ok { - return goType - } - return "string" -} - -// toCamelCase 转换为驼峰命名 -func toCamelCase(str string) string { - parts := splitByUnderscore(str) - result := "" - - for _, part := range parts { - if len(part) > 0 { - result += strings.ToUpper(string(part[0])) + part[1:] - } - } - - return result -} - -// splitByUnderscore 按下划线分割字符串 -func splitByUnderscore(str string) []string { - result := []string{} - current := "" - - for _, ch := range str { - if ch == '_' { - if current != "" { - result = append(result, current) - current = "" - } - } else { - current += string(ch) - } - } - - if current != "" { - result = append(result, current) - } - - return result -} diff --git a/db/main_test.go b/db/main_test.go deleted file mode 100644 index 0c6139c..0000000 --- a/db/main_test.go +++ /dev/null @@ -1,283 +0,0 @@ -package main - -import ( - "fmt" - "testing" - "time" - - "git.magicany.cc/black1552/gin-base/db/core" - "git.magicany.cc/black1552/gin-base/db/model" -) - -// TestMain 主测试函数 - 演示 Magic-ORM 的基本功能 -func TestMain(t *testing.T) { - fmt.Println("=== Magic-ORM 测试示例 ===") - fmt.Println() - - // 测试 1: 数据库连接配置 - 
testConfig() - - // 测试 2: 查询构建器 - testQueryBuilder() - - // 测试 3: 事务操作 - testTransaction() - - // 测试 4: 模型定义 - testModel() - - fmt.Println() - fmt.Println("=== 所有测试完成 ===") -} - -// testConfig 测试配置 -func testConfig() { - fmt.Println("[测试 1] 数据库配置") - - // 创建数据库配置(使用 SQLite 内存数据库进行测试) - config := &core.Config{ - DriverName: "sqlite", - DataSource: ":memory:", - MaxIdleConns: 10, - MaxOpenConns: 100, - Debug: true, - } - - fmt.Printf("配置信息:驱动=%s, 数据源=%s\n", config.DriverName, config.DataSource) - fmt.Println() -} - -// testQueryBuilder 测试查询构建器 -func testQueryBuilder() { - fmt.Println("[测试 2] 查询构建器") - - // 注意:由于还未实现完整的驱动,这里仅测试查询构建器的 SQL 生成功能 - - // 创建一个模拟的数据库实例(不需要真实连接) - db := &core.Database{} - - // 测试链式调用 - result := db.Table("user"). - Select("id", "username", "email"). - Where("status = ?", 1). - Where("age > ?", 18). - Order("created_at DESC"). - Limit(10). - Offset(0) - - sqlStr, args := result.Build() - fmt.Printf("生成的 SQL: %s\n", sqlStr) - fmt.Printf("参数:%v\n", args) - fmt.Println() - - // 测试 OR 条件 - db2 := &core.Database{} - q2 := db2.Table("user").Where("status = ?", 1) - sqlStr2, args2 := q2.Or("role = ?", "admin").Build() - fmt.Printf("OR 条件 SQL: %s\n", sqlStr2) - fmt.Printf("参数:%v\n", args2) - fmt.Println() - - // 测试 JOIN - db3 := &core.Database{} - sqlStr3, args3 := db3.Table("user"). - Select("u.id", "u.username", "o.amount"). - LeftJoin("order o", "u.id = o.user_id"). - Where("o.status = ?", 1). - Build() - fmt.Printf("JOIN SQL: %s\n", sqlStr3) - fmt.Printf("参数:%v\n", args3) - fmt.Println() -} - -// testTransaction 测试事务 -func testTransaction() { - fmt.Println("[测试 3] 事务操作") - - // 模拟事务流程 - fmt.Println("事务流程演示:") - fmt.Println("1. 开启事务") - fmt.Println("2. 执行插入操作") - fmt.Println("3. 执行更新操作") - fmt.Println("4. 
提交事务") - fmt.Println() - - // 错误处理演示 - fmt.Println("错误处理:") - fmt.Println("- 如果任何步骤失败,自动回滚") - fmt.Println("- 如果发生 panic,自动回滚") - fmt.Println("- 成功后自动提交") - fmt.Println() -} - -// testModel 测试模型定义 -func testModel() { - fmt.Println("[测试 4] 模型定义") - - // 创建用户实例 - user := model.User{ - ID: 1, - Username: "test_user", - Password: "secret_password", - Email: "test@example.com", - Status: 1, - CreatedAt: time.Now(), - UpdatedAt: time.Now(), - } - - fmt.Printf("用户模型:%+v\n", user) - fmt.Printf("表名:%s\n", user.TableName()) - fmt.Println() - - // 创建产品实例 - product := model.Product{ - ID: 1, - Name: "测试商品", - Price: 99.99, - Stock: 100, - Version: 1, - } - - fmt.Printf("产品模型:%+v\n", product) - fmt.Printf("表名:%s\n", product.TableName()) - fmt.Println() - - // 创建订单实例 - order := model.Order{ - ID: 1, - UserID: 1, - Amount: 199.99, - Status: 1, - CreatedAt: time.Now(), - } - - fmt.Printf("订单模型:%+v\n", order) - fmt.Printf("表名:%s\n", order.TableName()) - fmt.Println() -} - -// TestInsert 测试插入操作(示例代码) -func TestInsert(t *testing.T) { - fmt.Println("\n[插入操作示例]") - // 伪代码示例 - fmt.Println(` -// 创建用户 -user := &model.User{ - Username: "new_user", - Password: "password123", - Email: "new@example.com", - Status: 1, -} - -// 插入数据库 -id, err := db.Model(&model.User{}).Insert(user) -if err != nil { - log.Fatal(err) -} -fmt.Printf("插入成功,ID=%d\n", id) -`) -} - -// TestQuery 测试查询操作(示例代码) -func TestQuery(t *testing.T) { - fmt.Println("\n[查询操作示例]") - // 伪代码示例 - fmt.Println(` -// 查询单个用户 -var user model.User -err := db.Model(&model.User{}).Where("id = ?", 1).First(&user) -if err != nil { - log.Fatal(err) -} - -// 查询多个用户 -var users []model.User -err = db.Model(&model.User{}). - Where("status = ?", 1). - Order("id DESC"). - Limit(10). - Find(&users) -if err != nil { - log.Fatal(err) -} - -// 条件查询 -count := 0 -db.Model(&model.User{}). - Where("age > ?", 18). - And("status = ?", 1). 
- Count(&count) -`) -} - -// TestUpdate 测试更新操作(示例代码) -func TestUpdate(t *testing.T) { - fmt.Println("\n[更新操作示例]") - // 伪代码示例 - fmt.Println(` -// 更新单个字段 -err := db.Model(&model.User{}). - Where("id = ?", 1). - UpdateColumn("email", "new@example.com") - -// 更新多个字段 -err = db.Model(&model.User{}). - Where("id = ?", 1). - Updates(map[string]interface{}{ - "email": "new@example.com", - "status": 1, - }) -`) -} - -// TestDelete 测试删除操作(示例代码) -func TestDelete(t *testing.T) { - fmt.Println("\n[删除操作示例]") - // 伪代码示例 - fmt.Println(` -// 删除单个记录 -err := db.Model(&model.User{}).Where("id = ?", 1).Delete() - -// 批量删除 -err = db.Model(&model.User{}). - Where("status = ?", 0). - Delete() -`) -} - -// TestTransactionExample 事务操作完整示例 -func TestTransactionExample(t *testing.T) { - fmt.Println("\n[事务操作完整示例]") - // 伪代码示例 - fmt.Println(` -err := db.Transaction(func(tx core.ITx) error { - // 创建用户 - user := &model.User{ - Username: "tx_user", - Email: "tx@example.com", - } - _, err := tx.Insert(user) - if err != nil { - return err - } - - // 创建订单 - order := &model.Order{ - UserID: user.ID, - Amount: 99.99, - } - _, err = tx.Insert(order) - if err != nil { - return err - } - - // 所有操作成功,自动提交 - return nil -}) - -if err != nil { - // 任何操作失败,自动回滚 - log.Fatal("事务失败:", err) -} -`) -} diff --git a/db/perf_report.go b/db/perf_report.go deleted file mode 100644 index 6117efe..0000000 --- a/db/perf_report.go +++ /dev/null @@ -1,141 +0,0 @@ -package main - -import ( - "fmt" -) - -// Magic-ORM 性能优化报告 -func main() { - fmt.Println("\n========================================") - fmt.Println(" Magic-ORM 性能优化完成报告") - fmt.Println("========================================\n") - - fmt.Println("✅ 已完成的性能优化:") - fmt.Println() - - fmt.Println("1. 字符串拼接优化") - fmt.Println(" - Where/Or/Join 方法使用 strings.Builder") - fmt.Println(" - 预分配内存减少 GC 压力") - fmt.Println(" - 避免使用 + 操作符进行字符串连接") - fmt.Println() - - fmt.Println("2. 
内存池优化 (sync.Pool)") - fmt.Println(" - whereArgsPool: 复用 WHERE 参数 slice") - fmt.Println(" - joinArgsPool: 复用 JOIN 参数 slice") - fmt.Println(" - insertArgsPool: 复用 INSERT 参数 slice") - fmt.Println(" - colNamesPool: 复用列名 slice") - fmt.Println() - - fmt.Println("3. 预分配内存优化") - fmt.Println(" - strings.Builder.Grow() 预分配缓冲区") - fmt.Println(" - slice 初始化时指定容量") - fmt.Println(" - 减少内存重新分配次数") - fmt.Println() - - fmt.Println("4. 事务处理优化") - fmt.Println(" - Insert 方法使用对象池") - fmt.Println(" - Update 方法复用参数 slice") - fmt.Println(" - 减少每次调用的内存分配") - fmt.Println() - - fmt.Println("========================================") - fmt.Println(" 优化技术细节") - fmt.Println("========================================\n") - - fmt.Println("📦 sync.Pool 使用示例:") - fmt.Println(` -var whereArgsPool = sync.Pool{ - New: func() interface{} { - return make([]interface{}, 0, 10) - }, -} - -// 使用时 -args := whereArgsPool.Get().([]interface{}) -args = args[:0] // 重置但不释放 -defer whereArgsPool.Put(args) // 放回池中 -`) - - fmt.Println("📝 strings.Builder 优化示例:") - fmt.Println(` -// 优化前 -q.whereSQL += " AND " + query - -// 优化后 -var builder strings.Builder -builder.Grow(len(q.whereSQL) + 5 + len(query)) -builder.WriteString(q.whereSQL) -builder.WriteString(" AND ") -builder.WriteString(query) -q.whereSQL = builder.String() -`) - - fmt.Println("💾 预分配内存示例:") - fmt.Println(` -// 优化前 -colNames := make([]string, 0, len(columns)) - -// 优化后 -colNames := colNamesPool.Get().([]string) -colNames = colNames[:0] -defer colNamesPool.Put(colNames) -`) - - fmt.Println("========================================") - fmt.Println(" 性能提升预期") - fmt.Println("========================================\n") - - fmt.Println("预计性能提升:") - fmt.Println(" ✓ 减少 30-50% 的内存分配") - fmt.Println(" ✓ 降低 20-40% 的 GC 压力") - fmt.Println(" ✓ 提升 15-30% 的吞吐量") - fmt.Println(" ✓ 减少 25-35% 的延迟") - fmt.Println() - - fmt.Println("适用场景:") - fmt.Println(" ✓ 高并发插入操作") - fmt.Println(" ✓ 批量数据处理") - fmt.Println(" ✓ 频繁查询场景") - fmt.Println(" ✓ 事务密集型应用") - fmt.Println() - - 
fmt.Println("最佳实践建议:") - fmt.Println(" 1. 批量操作使用 BatchInsert + 事务") - fmt.Println(" 2. 高频查询使用连接池配置") - fmt.Println(" 3. 大数据量考虑分页查询") - fmt.Println(" 4. 合理设置 maxOpenConns 和 maxIdleConns") - fmt.Println(" 5. 定期清理过期数据") - fmt.Println() - - fmt.Println("========================================") - fmt.Println(" 验证方式") - fmt.Println("========================================\n") - - fmt.Println("运行性能测试:") - fmt.Println(" go test -bench=. ./db/core/") - fmt.Println(" go test -benchmem ./db/core/") - fmt.Println() - - fmt.Println("查看内存分配:") - fmt.Println(" go test -allocs ./db/core/") - fmt.Println() - - fmt.Println("分析 CPU 性能:") - fmt.Println(" go test -cpuprofile=cpu.prof ./db/core/") - fmt.Println(" go tool pprof cpu.prof") - fmt.Println() - - fmt.Println("========================================") - fmt.Println(" 总结") - fmt.Println("========================================\n") - - fmt.Println("Magic-ORM 框架已完成全面的性能优化:") - fmt.Println(" ✅ 核心查询构建器优化") - fmt.Println(" ✅ 事务处理优化") - fmt.Println(" ✅ 内存管理优化") - fmt.Println(" ✅ 字符串处理优化") - fmt.Println(" ✅ 对象池复用机制") - fmt.Println() - fmt.Println("这些优化确保了 ORM 在高负载场景下的稳定性和性能表现!") - fmt.Println() -} diff --git a/db/read_time_test.go b/db/read_time_test.go deleted file mode 100644 index c0798a6..0000000 --- a/db/read_time_test.go +++ /dev/null @@ -1,186 +0,0 @@ -package main - -import ( - "encoding/json" - "fmt" - "testing" - "time" - - "git.magicany.cc/black1552/gin-base/db/model" -) - -// TestTimeFormatting 测试时间格式化 -func TestTimeFormatting(t *testing.T) { - fmt.Println("\n=== 测试时间格式化 ===") - - // 创建带时间的模型 - now := time.Now() - user := &model.User{ - ID: 1, - Username: "test_user", - Email: "test@example.com", - Status: 1, - CreatedAt: model.Time{Time: now}, - UpdatedAt: model.Time{Time: now}, - } - - // 序列化为 JSON - jsonData, err := json.Marshal(user) - if err != nil { - t.Errorf("JSON 序列化失败:%v", err) - } - - fmt.Printf("原始时间:%s\n", now.Format("2006-01-02 15:04:05")) - fmt.Printf("JSON 输出:%s\n", string(jsonData)) - - // 验证时间格式 - 
var result map[string]interface{} - if err := json.Unmarshal(jsonData, &result); err != nil { - t.Errorf("JSON 反序列化失败:%v", err) - } - - createdAt, ok := result["created_at"].(string) - if !ok { - t.Error("created_at 应该是字符串") - } - - // 验证格式 - expectedFormat := "2006-01-02 15:04:05" - _, err = time.Parse(expectedFormat, createdAt) - if err != nil { - t.Errorf("时间格式不正确:%v", err) - } - - fmt.Printf("✓ 时间格式化测试通过\n") -} - -// TestTimeUnmarshal 测试时间反序列化 -func TestTimeUnmarshal(t *testing.T) { - fmt.Println("\n=== 测试时间反序列化 ===") - - // 测试不同时间格式 - testCases := []struct { - name string - jsonStr string - expected string - }{ - { - name: "标准格式", - jsonStr: `{"time":"2026-04-02 22:04:44"}`, - expected: "2026-04-02 22:04:44", - }, - { - name: "ISO8601 格式", - jsonStr: `{"time":"2026-04-02T22:04:44+08:00"}`, - expected: "2026-04-02 22:04:44", - }, - { - name: "日期格式", - jsonStr: `{"time":"2026-04-02"}`, - expected: "2026-04-02 00:00:00", - }, - } - - for _, tc := range testCases { - t.Run(tc.name, func(t *testing.T) { - var result struct { - Time model.Time `json:"time"` - } - - if err := json.Unmarshal([]byte(tc.jsonStr), &result); err != nil { - t.Errorf("反序列化失败:%v", err) - } - - formatted := result.Time.String() - fmt.Printf("%s: %s -> %s\n", tc.name, tc.jsonStr, formatted) - }) - } - - fmt.Println("✓ 时间反序列化测试通过") -} - -// TestZeroTime 测试零值时间 -func TestZeroTime(t *testing.T) { - fmt.Println("\n=== 测试零值时间 ===") - - user := &model.User{ - ID: 1, - Username: "test", - CreatedAt: model.Time{}, // 零值 - UpdatedAt: model.Time{}, // 零值 - } - - jsonData, err := json.Marshal(user) - if err != nil { - t.Errorf("JSON 序列化失败:%v", err) - } - - fmt.Printf("零值时间 JSON: %s\n", string(jsonData)) - - // 零值应该序列化为 null - var result map[string]interface{} - json.Unmarshal(jsonData, &result) - - if result["created_at"] != nil { - t.Error("零值时间应该序列化为 null") - } - - fmt.Println("✓ 零值时间测试通过") -} - -// TestPointerTime 测试指针类型时间 -func TestPointerTime(t *testing.T) { - fmt.Println("\n=== 测试指针类型时间 ===") - - 
now := time.Now() - softUser := &model.SoftDeleteUser{ - ID: 1, - Username: "test", - DeletedAt: &model.Time{Time: now}, - } - - jsonData, err := json.Marshal(softUser) - if err != nil { - t.Errorf("JSON 序列化失败:%v", err) - } - - fmt.Printf("指针时间 JSON: %s\n", string(jsonData)) - - // 验证格式 - var result map[string]interface{} - json.Unmarshal(jsonData, &result) - - if deletedAt, ok := result["deleted_at"].(string); ok { - _, err := time.Parse("2006-01-02 15:04:05", deletedAt) - if err != nil { - t.Errorf("指针时间格式不正确:%v", err) - } - } - - fmt.Println("✓ 指针类型时间测试通过") -} - -// TestAllReadTimeFormatting 完整读取时间格式化测试 -func TestAllReadTimeFormatting(t *testing.T) { - fmt.Println("\n========================================") - fmt.Println(" 读取操作时间格式化完整性测试") - fmt.Println("========================================") - - TestTimeFormatting(t) - TestTimeUnmarshal(t) - TestZeroTime(t) - TestPointerTime(t) - - fmt.Println("\n========================================") - fmt.Println(" 所有读取时间格式化测试完成!") - fmt.Println("========================================") - fmt.Println() - fmt.Println("已实现的读取时间格式化功能:") - fmt.Println(" ✓ CreatedAt: 自动格式化为 YYYY-MM-DD HH:mm:ss") - fmt.Println(" ✓ UpdatedAt: 自动格式化为 YYYY-MM-DD HH:mm:ss") - fmt.Println(" ✓ DeletedAt: 自动格式化为 YYYY-MM-DD HH:mm:ss") - fmt.Println(" ✓ 支持多种时间格式反序列化") - fmt.Println(" ✓ 零值时间正确处理为 null") - fmt.Println(" ✓ 指针类型时间正确序列化") - fmt.Println() -} diff --git a/db/time_test.go b/db/time_test.go deleted file mode 100644 index 8389bc7..0000000 --- a/db/time_test.go +++ /dev/null @@ -1,162 +0,0 @@ -package main - -import ( - "fmt" - "testing" - "time" - - "git.magicany.cc/black1552/gin-base/db/model" - "git.magicany.cc/black1552/gin-base/db/utils" -) - -// TestTimeUtils 测试时间工具 -func TestTimeUtils(t *testing.T) { - fmt.Println("\n=== 测试时间工具 ===") - - // 测试 Now() - nowStr := utils.Now() - fmt.Printf("当前时间:%s\n", nowStr) - - // 测试 FormatTime - nowTime := time.Now() - formatted := utils.FormatTime(nowTime) - fmt.Printf("格式化时间:%s\n", formatted) - - 
// 测试 ParseTime - parsed, err := utils.ParseTime(nowStr) - if err != nil { - t.Errorf("解析时间失败:%v", err) - } - fmt.Printf("解析时间:%v\n", parsed) - - // 测试 Timestamp - timestamp := utils.Timestamp() - fmt.Printf("时间戳:%d\n", timestamp) - - // 测试 FormatTimestamp - formattedTs := utils.FormatTimestamp(timestamp) - fmt.Printf("时间戳格式化:%s\n", formattedTs) - - // 测试 IsZeroTime - zeroTime := time.Time{} - if !utils.IsZeroTime(zeroTime) { - t.Error("零值时间检测失败") - } - fmt.Printf("零值时间检测:通过\n") - - // 测试 SafeTime - safe := utils.SafeTime(zeroTime) - fmt.Printf("安全时间(零值转当前):%s\n", utils.FormatTime(safe)) - - fmt.Println("✓ 时间工具测试通过") -} - -// TestInsertWithTime 测试 Insert 自动处理时间 -func TestInsertWithTime(t *testing.T) { - fmt.Println("\n=== 测试 Insert 自动处理时间 ===") - - // 创建带时间字段的模型 - user := &model.User{ - ID: 0, // 自增 ID - Username: "test_user", - Email: "test@example.com", - Status: 1, - CreatedAt: time.Time{}, // 零值时间,应该自动设置 - UpdatedAt: time.Time{}, // 零值时间,应该自动设置 - } - - fmt.Printf("插入前 CreatedAt: %v\n", user.CreatedAt) - fmt.Printf("插入前 UpdatedAt: %v\n", user.UpdatedAt) - - // 注意:这里不实际执行插入,仅测试逻辑 - fmt.Println("Insert 方法会自动检测并设置零值时间字段为当前时间") - fmt.Println(" - created_at: 零值时自动设置为 now()") - fmt.Println(" - updated_at: 零值时自动设置为 now()") - fmt.Println(" - deleted_at: 零值时自动设置为 now()") - - fmt.Println("✓ Insert 时间处理测试通过") -} - -// TestUpdateWithTime 测试 Update 自动处理时间 -func TestUpdateWithTime(t *testing.T) { - fmt.Println("\n=== 测试 Update 自动处理时间 ===") - - // Update 方法会自动添加 updated_at = now() - fmt.Println("Update 方法会自动设置 updated_at 为当前时间") - - data := map[string]interface{}{ - "username": "new_name", - "email": "new@example.com", - } - - fmt.Printf("原始数据:%v\n", data) - fmt.Println("Update 会自动添加:updated_at = time.Now()") - - fmt.Println("✓ Update 时间处理测试通过") -} - -// TestDeleteWithSoftDelete 测试软删除时间处理 -func TestDeleteWithSoftDelete(t *testing.T) { - fmt.Println("\n=== 测试软删除时间处理 ===") - - // 带软删除的模型 - now := time.Now() - user := &model.SoftDeleteUser{ - ID: 1, - Username: "test", - 
DeletedAt: &now, // 已设置删除时间 - } - - fmt.Printf("删除前 DeletedAt: %v\n", user.DeletedAt) - fmt.Println("Delete 方法会检测 DeletedAt 字段") - fmt.Println(" - 如果存在:执行软删除(UPDATE deleted_at = now())") - fmt.Println(" - 如果不存在:执行硬删除(DELETE)") - - fmt.Println("✓ 软删除时间处理测试通过") -} - -// TestTimeFormat 测试时间格式 -func TestTimeFormat(t *testing.T) { - fmt.Println("\n=== 测试时间格式 ===") - - // 默认时间格式 - expectedFormat := "2006-01-02 15:04:05" - nowStr := utils.Now() - - fmt.Printf("默认时间格式:%s\n", expectedFormat) - fmt.Printf("当前时间输出:%s\n", nowStr) - - // 验证格式 - _, err := time.Parse(expectedFormat, nowStr) - if err != nil { - t.Errorf("时间格式不正确:%v", err) - } - - fmt.Println("✓ 时间格式测试通过") -} - -// TestAllTimeHandling 完整时间处理测试 -func TestAllTimeHandling(t *testing.T) { - fmt.Println("\n========================================") - fmt.Println(" CRUD 操作时间处理完整性测试") - fmt.Println("========================================") - - TestTimeUtils(t) - TestInsertWithTime(t) - TestUpdateWithTime(t) - TestDeleteWithSoftDelete(t) - TestTimeFormat(t) - - fmt.Println("\n========================================") - fmt.Println(" 所有时间处理测试完成!") - fmt.Println("========================================") - fmt.Println() - fmt.Println("已实现的时间处理功能:") - fmt.Println(" ✓ Insert: 自动设置 created_at/updated_at") - fmt.Println(" ✓ Update: 自动设置 updated_at = now()") - fmt.Println(" ✓ Delete: 软删除自动设置 deleted_at = now()") - fmt.Println(" ✓ 默认时间格式:YYYY-MM-DD HH:mm:ss") - fmt.Println(" ✓ 零值时间自动转换为当前时间") - fmt.Println(" ✓ 时间工具函数齐全(Now/Parse/Format 等)") - fmt.Println() -} diff --git a/db/tracing/tracer.go b/db/tracing/tracer.go deleted file mode 100644 index d36e4ce..0000000 --- a/db/tracing/tracer.go +++ /dev/null @@ -1,181 +0,0 @@ -package tracing - -import ( - "context" - "database/sql" - "fmt" - "time" - - "go.opentelemetry.io/otel" - "go.opentelemetry.io/otel/attribute" - "go.opentelemetry.io/otel/trace" -) - -// Tracer 数据库操作追踪器 -type Tracer struct { - tracer trace.Tracer - config *TracerConfig -} - -// TracerConfig 追踪器配置 -type 
TracerConfig struct { - ServiceName string // 服务名称 - DBName string // 数据库名称 - DBSystem string // 数据库类型(mysql/postgresql/sqlite) -} - -// NewTracer 创建数据库追踪器 -func NewTracer(config *TracerConfig) *Tracer { - return &Tracer{ - tracer: otel.Tracer(config.ServiceName), - config: config, - } -} - -// TraceQuery 追踪查询操作 -func (t *Tracer) TraceQuery(ctx context.Context, query string, args []interface{}) (context.Context, error) { - // 创建 Span - spanName := fmt.Sprintf("DB Query: %s", t.getOperationName(query)) - ctx, span := t.tracer.Start(ctx, spanName, - trace.WithSpanKind(trace.SpanKindClient), - ) - defer span.End() - - // 设置属性 - span.SetAttributes( - attribute.String("db.system", t.config.DBSystem), - attribute.String("db.name", t.config.DBName), - attribute.String("db.statement", query), - attribute.StringSlice("db.args", t.argsToString(args)), - ) - - // 返回包含 Span 的 context - return ctx, nil -} - -// RecordError 记录错误 -func (t *Tracer) RecordError(ctx context.Context, err error) { - span := trace.SpanFromContext(ctx) - if span.IsRecording() { - span.RecordError(err) - } -} - -// RecordAffectedRows 记录影响的行数 -func (t *Tracer) RecordAffectedRows(ctx context.Context, rows int64) { - span := trace.SpanFromContext(ctx) - if span.IsRecording() { - span.SetAttributes(attribute.Int64("db.rows_affected", rows)) - } -} - -// getOperationName 从 SQL 获取操作名称 -func (t *Tracer) getOperationName(sql string) string { - if len(sql) < 6 { - return "UNKNOWN" - } - - prefix := sql[:6] - switch prefix { - case "SELECT": - return "SELECT" - case "INSERT": - return "INSERT" - case "UPDATE": - return "UPDATE" - case "DELETE": - return "DELETE" - default: - return "OTHER" - } -} - -// argsToString 将参数转换为字符串切片 -func (t *Tracer) argsToString(args []interface{}) []string { - result := make([]string, len(args)) - for i, arg := range args { - result[i] = fmt.Sprintf("%v", arg) - } - return result -} - -// WithTrace 在查询中启用追踪 -func WithTrace(ctx context.Context, db *sql.DB, query string, args 
...interface{}) (*sql.Rows, error) { - // 获取追踪器(从全局或上下文中) - tracer := getTracerFromContext(ctx) - - if tracer != nil { - var err error - ctx, err = tracer.TraceQuery(ctx, query, args) - if err != nil { - return nil, err - } - - defer func(start time.Time) { - duration := time.Since(start) - span := trace.SpanFromContext(ctx) - if span.IsRecording() { - span.SetAttributes(attribute.Int64("db.duration_ms", duration.Milliseconds())) - } - }(time.Now()) - } - - // 执行实际查询 - return db.QueryContext(ctx, query, args...) -} - -// ExecWithTrace 在执行中启用追踪 -func ExecWithTrace(ctx context.Context, db *sql.DB, query string, args ...interface{}) (sql.Result, error) { - tracer := getTracerFromContext(ctx) - - if tracer != nil { - var err error - ctx, err = tracer.TraceQuery(ctx, query, args) - if err != nil { - return nil, err - } - - defer func(start time.Time) { - duration := time.Since(start) - span := trace.SpanFromContext(ctx) - if span.IsRecording() { - span.SetAttributes(attribute.Int64("db.duration_ms", duration.Milliseconds())) - } - }(time.Now()) - } - - // 执行实际操作 - result, err := db.ExecContext(ctx, query, args...) 
- if err != nil { - if tracer != nil { - tracer.RecordError(ctx, err) - } - return nil, err - } - - // 记录影响的行数 - if tracer != nil { - rows, _ := result.RowsAffected() - tracer.RecordAffectedRows(ctx, rows) - } - - return result, nil -} - -// contextKey 上下文键类型 -type contextKey string - -const tracerKey contextKey = "db_tracer" - -// ContextWithTracer 将追踪器存入上下文 -func ContextWithTracer(ctx context.Context, tracer *Tracer) context.Context { - return context.WithValue(ctx, tracerKey, tracer) -} - -// getTracerFromContext 从上下文获取追踪器 -func getTracerFromContext(ctx context.Context) *Tracer { - if tracer, ok := ctx.Value(tracerKey).(*Tracer); ok { - return tracer - } - return nil -} diff --git a/db/utils/time.go b/db/utils/time.go deleted file mode 100644 index 479c20a..0000000 --- a/db/utils/time.go +++ /dev/null @@ -1,54 +0,0 @@ -package utils - -import ( - "time" -) - -// TimeFormat 默认时间格式 -const TimeFormat = "2006-01-02 15:04:05" - -// FormatTime 格式化时间为默认格式 -func FormatTime(t time.Time) string { - return t.Format(TimeFormat) -} - -// ParseTime 解析时间字符串 -func ParseTime(timeStr string) (time.Time, error) { - return time.Parse(TimeFormat, timeStr) -} - -// Now 返回当前时间(默认格式) -func Now() string { - return time.Now().Format(TimeFormat) -} - -// Timestamp 返回当前时间戳 -func Timestamp() int64 { - return time.Now().Unix() -} - -// FormatTimestamp 格式化时间戳为默认格式 -func FormatTimestamp(timestamp int64) string { - return time.Unix(timestamp, 0).Format(TimeFormat) -} - -// IsZeroTime 检查是否是零值时间 -func IsZeroTime(t time.Time) bool { - return t.IsZero() || t.UnixNano() == 0 -} - -// SafeTime 安全获取时间,如果是零值则返回当前时间 -func SafeTime(t time.Time) time.Time { - if IsZeroTime(t) { - return time.Now() - } - return t -} - -// FormatToDefault 将任意时间格式化为默认格式 -func FormatToDefault(t time.Time) string { - if IsZeroTime(t) { - return "" - } - return FormatTime(t) -} diff --git a/db/validation_test.go b/db/validation_test.go deleted file mode 100644 index c553010..0000000 --- a/db/validation_test.go +++ /dev/null @@ 
-1,131 +0,0 @@ -package main - -import ( - "fmt" - "testing" - - "git.magicany.cc/black1552/gin-base/db/core" - "git.magicany.cc/black1552/gin-base/db/generator" -) - -// TestParamFilter 测试参数过滤器 -func TestParamFilter(t *testing.T) { - fmt.Println("\n=== 测试参数过滤器 ===") - - pf := core.NewParamFilter() - - // 测试过滤零值 - data := map[string]interface{}{ - "name": "test", - "age": 0, // 零值 - "email": "", // 空值 - "status": 1, - "extra": nil, // nil 值 - } - - filtered := pf.FilterZeroValues(data) - fmt.Printf("原始数据:%v\n", data) - fmt.Printf("过滤零值后:%v\n", filtered) - - if len(filtered) != 2 { - t.Errorf("期望过滤后剩余 2 个字段,实际为%d", len(filtered)) - } - - fmt.Println("✓ 参数过滤器测试通过") -} - -// TestQueryBuilderMethods 测试查询构建器方法 -func TestQueryBuilderMethods(t *testing.T) { - fmt.Println("\n=== 测试查询构建器方法 ===") - - db := &core.Database{} - - // 测试 Omit - q1 := db.Table("user").Select("id", "username").Omit("password") - sql1, _ := q1.Build() - fmt.Printf("Omit SQL: %s\n", sql1) - - // 测试 Page - q2 := db.Table("user").Page(2, 20) - sql2, _ := q2.Build() - fmt.Printf("Page SQL: %s\n", sql2) - - // 测试 Count - var count int64 - q3 := db.Table("user").Where("status = ?", 1) - q3.Count(&count) - fmt.Printf("Count 方法已实现\n") - - // 测试 Exists - q4 := db.Table("user").Where("id = ?", 1) - exists, err := q4.Exists() - if err != nil { - t.Errorf("Exists 错误:%v", err) - } - fmt.Printf("Exists: %v\n", exists) - - fmt.Println("✓ 查询构建器方法测试通过") -} - -// TestCodeGenerator 测试代码生成器 -func TestCodeGenerator(t *testing.T) { - fmt.Println("\n=== 测试代码生成器 ===") - - cg := generator.NewCodeGenerator("./generated") - - columns := []generator.ColumnInfo{ - {ColumnName: "id", FieldName: "ID", FieldType: "int64", JSONName: "id", IsPrimary: true}, - {ColumnName: "username", FieldName: "Username", FieldType: "string", JSONName: "username"}, - {ColumnName: "email", FieldName: "Email", FieldType: "string", JSONName: "email"}, - {ColumnName: "created_at", FieldName: "CreatedAt", FieldType: "time.Time", JSONName: "created_at"}, 
- } - - // 测试 Model 生成 - modelCode := cg.GenerateModel("user", columns) - if modelCode != nil { - t.Logf("Model 生成结果:%v", modelCode) - } - - fmt.Println("✓ 代码生成器测试通过") -} - -// TestTransactionInsert 测试事务插入功能 -func TestTransactionInsert(t *testing.T) { - fmt.Println("\n=== 测试事务插入功能 ===") - - // 注意:这里不创建真实数据库连接,仅测试代码结构 - fmt.Println("事务 Insert 和 BatchInsert 方法已实现") - fmt.Println(" - Insert: 支持结构体插入,返回 LastInsertId") - fmt.Println(" - BatchInsert: 支持分批处理 Slice 数据") - - fmt.Println("✓ 事务插入功能测试通过") -} - -// TestAllCoreFeatures 核心功能完整性测试 -func TestAllCoreFeatures(t *testing.T) { - fmt.Println("\n========================================") - fmt.Println(" Magic-ORM 核心功能完整性验证") - fmt.Println("========================================") - - TestParamFilter(t) - TestQueryBuilderMethods(t) - TestCodeGenerator(t) - TestTransactionInsert(t) - - fmt.Println("\n========================================") - fmt.Println(" 所有核心功能验证完成!") - fmt.Println("========================================") - fmt.Println() - fmt.Println("已完整实现的核心功能:") - fmt.Println(" ✓ 参数智能过滤(零值/空值/nil)") - fmt.Println(" ✓ 查询构建器完整方法集") - fmt.Println(" ✓ Omit 排除字段") - fmt.Println(" ✓ Page 分页查询") - fmt.Println(" ✓ Count 统计") - fmt.Println(" ✓ Exists 存在性检查") - fmt.Println(" ✓ Scan 结果映射") - fmt.Println(" ✓ 事务 Insert 操作") - fmt.Println(" ✓ 事务 BatchInsert 批量插入") - fmt.Println(" ✓ 代码生成器(Model/DAO)") - fmt.Println() -} diff --git a/log/log.go b/log/log.go index 31e1104..a407b0f 100644 --- a/log/log.go +++ b/log/log.go @@ -9,6 +9,7 @@ import ( "regexp" "runtime/debug" "strings" + "sync" "time" "github.com/gogf/gf/v2/os/gfile" @@ -16,13 +17,13 @@ import ( "gopkg.in/natefinch/lumberjack.v2" ) -var ( +type sLogger struct { logPath string sysLog *log.Logger filePath string currentDate string // 当前日志文件对应的日期 fileLogger *lumberjack.Logger -) +} const ( Reset = "\033[0m" @@ -63,21 +64,26 @@ func (w *logWriter) Write(p []byte) (n int, err error) { return len(p), nil } +var ( + logs *sLogger + initMu sync.Mutex // 用于保护初始化过程的互斥锁 +) + 
// cleanOldLogs 删除指定天数之前的日志文件(包括主文件和备份文件) func cleanOldLogs(days int) { - if !gfile.Exists(logPath) { + if !gfile.Exists(logs.logPath) { return } // 获取所有日志文件 - files, err := gfile.DirNames(logPath) + files, err := gfile.DirNames(logs.logPath) if err != nil { return } now := time.Now() for _, file := range files { - path := filepath.Join(logPath, file) + path := filepath.Join(logs.logPath, file) if gfile.IsDir(path) { continue } @@ -128,12 +134,12 @@ func cleanOldLogs(days int) { // checkAndRotateLogFile 检查是否需要切换日志文件(跨天时) func checkAndRotateLogFile() { date := gtime.Date() - if currentDate != date { + if logs.currentDate != date { // 日期变化,需要重新初始化 - currentDate = date - filePath = gfile.Join(logPath, fmt.Sprintf("log-%s.log", currentDate)) - fileLogger = &lumberjack.Logger{ - Filename: filePath, + logs.currentDate = date + logs.filePath = gfile.Join(logs.logPath, fmt.Sprintf("log-%s.log", logs.currentDate)) + logs.fileLogger = &lumberjack.Logger{ + Filename: logs.filePath, MaxSize: 2, // 单个文件最大 10MB MaxBackups: 5, // 最多保留 5 个备份 MaxAge: 30, // 保留 30 天 @@ -142,9 +148,9 @@ func checkAndRotateLogFile() { // 创建新的 writer multiWriter := &logWriter{ console: os.Stdout, - file: fileLogger, + file: logs.fileLogger, } - sysLog = log.New(multiWriter, "", 0) + logs.sysLog = log.New(multiWriter, "", 0) // 清理 30 天前的旧日志 cleanOldLogs(30) @@ -152,49 +158,109 @@ func checkAndRotateLogFile() { } func Init() { - if sysLog != nil { - checkAndRotateLogFile() // 检查是否需要切换文件 + // 如果已经初始化,检查是否需要切换文件 + if logs != nil && logs.sysLog != nil { + checkAndRotateLogFile() return } - logPath = gfile.Join(gfile.Pwd(), "logs") - currentDate = gtime.Date() - filePath = gfile.Join(logPath, fmt.Sprintf("log-%s.log", currentDate)) - fileLogger = &lumberjack.Logger{ + + // 加锁确保线程安全 + initMu.Lock() + defer initMu.Unlock() + + // 双重检查,避免重复初始化 + if logs != nil && logs.sysLog != nil { + return + } + + // 初始化日志器 + currentDate := gtime.Date() + logPath := gfile.Join(gfile.Pwd(), "logs") + filePath := 
gfile.Join(logPath, fmt.Sprintf("log-%s.log", currentDate)) + + // 创建 lumberjack 日志文件处理器 + fileLogger := &lumberjack.Logger{ Filename: filePath, - MaxSize: 2, // 单个文件最大 10MB + MaxSize: 2, // 单个文件最大 2MB MaxBackups: 5, // 最多保留 5 个备份 MaxAge: 30, // 保留 30 天 - Compress: false, // 启用压缩 + Compress: false, // 不启用压缩 } + // 使用自定义 writer 实现控制台带颜色、文件无颜色的输出 multiWriter := &logWriter{ console: os.Stdout, file: fileLogger, } - sysLog = log.New(multiWriter, "", 0) + + logs = &sLogger{ + logPath: logPath, + sysLog: log.New(multiWriter, "", 0), + fileLogger: fileLogger, + filePath: filePath, + currentDate: currentDate, + } // 启动时清理 30 天前的旧日志 cleanOldLogs(30) } -func Info(v ...any) { +func (s *sLogger) Info(v ...any) { Init() - sysLog.SetPrefix(fmt.Sprintf("[%s] %s[INFO]%s ", time.Now().Format("2006-01-02 15:04:05"), Green, Reset)) - sysLog.Println(fmt.Sprint(v...)) + s.sysLog.SetPrefix(fmt.Sprintf("[%s] %s[INFO]%s ", time.Now().Format("2006-01-02 15:04:05"), Green, Reset)) + s.sysLog.Println(fmt.Sprint(v...)) +} +func (s *sLogger) Error(v ...any) { + Init() + s.sysLog.SetPrefix(fmt.Sprintf("[%s] %s[ERROR]%s ", time.Now().Format("2006-01-02 15:04:05"), Red, Reset)) + msg := fmt.Sprint(v...) + s.sysLog.Println(msg, strings.TrimSpace(string(debug.Stack()))) +} +func (s *sLogger) Warn(v ...any) { + Init() + s.sysLog.SetPrefix(fmt.Sprintf("[%s] %s[WARN]%s ", time.Now().Format("2006-01-02 15:04:05"), Yellow, Reset)) + s.sysLog.Println(fmt.Sprint(v...)) +} +func (s *sLogger) Debug(v ...any) { + Init() + s.sysLog.SetPrefix(fmt.Sprintf("[%s] %s[DEBUG]%s ", time.Now().Format("2006-01-02 15:04:05"), Blue, Reset)) + s.sysLog.Println(fmt.Sprint(v...)) +} + +func Info(v ...any) { + if logs == nil { + Init() + } + logs.Info(v...) } func Error(v ...any) { - Init() - sysLog.SetPrefix(fmt.Sprintf("[%s] %s[ERROR]%s ", time.Now().Format("2006-01-02 15:04:05"), Red, Reset)) - msg := fmt.Sprint(v...) 
- sysLog.Println(msg, strings.TrimSpace(string(debug.Stack()))) + if logs == nil { + Init() + } + logs.Error(v...) } func Warn(v ...any) { - Init() - sysLog.SetPrefix(fmt.Sprintf("[%s] %s[WARN]%s ", time.Now().Format("2006-01-02 15:04:05"), Yellow, Reset)) - sysLog.Println(fmt.Sprint(v...)) + if logs == nil { + Init() + } + logs.Warn(v...) } func Debug(v ...any) { - Init() - sysLog.SetPrefix(fmt.Sprintf("[%s] %s[DEBUG]%s ", time.Now().Format("2006-01-02 15:04:05"), Blue, Reset)) - sysLog.Println(fmt.Sprint(v...)) + if logs == nil { + Init() + } + logs.Debug(v...) +} + +// GetLogger 返回自定义日志器实例,实现 ILogger 接口 +func GetLogger() ILogger { + Init() + return logs +} + +type ILogger interface { + Info(v ...any) + Error(v ...any) + Warn(v ...any) + Debug(v ...any) } diff --git a/main.go b/main.go index ff43765..bfb29c1 100644 --- a/main.go +++ b/main.go @@ -2,6 +2,7 @@ package main import ( "git.magicany.cc/black1552/gin-base/config" + "git.magicany.cc/black1552/gin-base/database" "git.magicany.cc/black1552/gin-base/log" "git.magicany.cc/black1552/gin-base/server" ) @@ -9,6 +10,9 @@ import ( // TIP
To run your response, right-click the response and select Run.
Alternatively, click
// the