feat(database): 新增命令行参数解析、空值检查和数据库配置管理功能
- 实现了 command 包用于控制台参数和选项解析 - 添加了 empty 包提供变量空值/零值检查功能 - 集成了 Viper 库进行配置文件管理和环境变量支持 - 提供了数据库配置初始化和默认值设置功能 - 实现了配置值获取、设置和结构化解析方法 - 添加了服务器、数据库和JWT配置的便捷获取方法 - 实现了完整的数据库接口定义和事务管理功能main v1.0.2003
parent
6dcd564206
commit
af0a8f4043
|
|
@ -54,30 +54,29 @@ func init() {
|
|||
})
|
||||
}
|
||||
|
||||
func GetConfigPath() string {
|
||||
return configPath
|
||||
}
|
||||
|
||||
// SetDefault 设置默认配置信息
|
||||
func SetDefault() {
|
||||
viper.Set("SERVER.addr", "127.0.0.1:8080")
|
||||
viper.Set("SERVER.mode", "release")
|
||||
|
||||
// 数据库配置 - 支持多种数据库类型
|
||||
viper.Set("DATABASE.type", "sqlite")
|
||||
viper.Set("DATABASE.dns", gfile.Join(gfile.Pwd(), "db", "database.db"))
|
||||
viper.Set("DATABASE.debug", true)
|
||||
|
||||
// 数据库连接池配置
|
||||
viper.Set("DATABASE.maxIdleConns", 10) // 最大空闲连接数
|
||||
viper.Set("DATABASE.maxOpenConns", 100) // 最大打开连接数
|
||||
viper.Set("DATABASE.connMaxLifetime", 3600) // 连接最大生命周期(秒)
|
||||
|
||||
// 数据库主从配置(可选)
|
||||
viper.Set("DATABASE.replicas", []string{}) // 从库列表
|
||||
viper.Set("DATABASE.readPolicy", "random") // 读负载均衡策略
|
||||
|
||||
// 时间配置 - 定义时间字段名称和格式
|
||||
viper.Set("DATABASE.timeConfig.createdAt", "created_at")
|
||||
viper.Set("DATABASE.timeConfig.updatedAt", "updated_at")
|
||||
viper.Set("DATABASE.timeConfig.deletedAt", "deleted_at")
|
||||
viper.Set("DATABASE.timeConfig.format", "2006-01-02 15:04:05")
|
||||
viper.Set("DATABASE.default.host", "127.0.0.1")
|
||||
viper.Set("DATABASE.default.port", "3306")
|
||||
viper.Set("DATABASE.default.user", "root")
|
||||
viper.Set("DATABASE.default.pass", "123456")
|
||||
viper.Set("DATABASE.default.name", "test")
|
||||
viper.Set("DATABASE.default.type", "mysql")
|
||||
viper.Set("DATABASE.default.role", "master")
|
||||
viper.Set("DATABASE.default.debug", false)
|
||||
viper.Set("DATABASE.default.prefix", "")
|
||||
viper.Set("DATABASE.default.dryRun", false)
|
||||
viper.Set("DATABASE.default.charset", "utf8")
|
||||
viper.Set("DATABASE.default.timezone", "Local")
|
||||
viper.Set("DATABASE.default.createdAt", "create_time")
|
||||
viper.Set("DATABASE.default.updatedAt", "update_time")
|
||||
viper.Set("DATABASE.default.timeMaintainDisabled", false)
|
||||
|
||||
// JWT 配置
|
||||
viper.Set("JWT.secret", "SET-YOUR-SECRET")
|
||||
|
|
|
|||
|
|
@ -1,31 +0,0 @@
|
|||
package base
|
||||
|
||||
import (
|
||||
"github.com/gogf/gf/v2/os/gtime"
|
||||
"gorm.io/gorm"
|
||||
)
|
||||
|
||||
type IdModel struct {
|
||||
Id int `json:"id" gorm:"column:id;type:int(11);common:id"`
|
||||
}
|
||||
type TimeModel struct {
|
||||
CreateTime string `json:"create_time" gorm:"column:create_time;type:varchar(255);common:创建时间"`
|
||||
UpdateTime string `json:"update_time" gorm:"column:update_time;type:varchar(255);common:更新时间"`
|
||||
}
|
||||
|
||||
func (tm *TimeModel) BeforeCreate(scope *gorm.DB) error {
|
||||
scope.Set("create_time", gtime.Datetime())
|
||||
scope.Set("update_time", gtime.Datetime())
|
||||
return nil
|
||||
}
|
||||
|
||||
func (tm *TimeModel) BeforeUpdate(scope *gorm.DB) error {
|
||||
scope.Set("update_time", gtime.Datetime())
|
||||
return nil
|
||||
}
|
||||
|
||||
//func (tm *TimeModel) AfterFind(scope *gorm.DB) error {
|
||||
// tm.CreateTime = gtime.New(tm.CreateTime).String()
|
||||
// tm.UpdateTime = gtime.New(tm.UpdateTime).String()
|
||||
// return nil
|
||||
//}
|
||||
|
|
@ -0,0 +1,135 @@
|
|||
// Copyright GoFrame Author(https://goframe.org). All Rights Reserved.
|
||||
//
|
||||
// This Source Code Form is subject to the terms of the MIT License.
|
||||
// If a copy of the MIT was not distributed with this file,
|
||||
// You can obtain one at https://github.com/gogf/gf.
|
||||
//
|
||||
|
||||
// Package command provides console operations, like options/arguments reading.
|
||||
package command
|
||||
|
||||
import (
|
||||
"os"
|
||||
"regexp"
|
||||
"strings"
|
||||
)
|
||||
|
||||
var (
|
||||
defaultParsedArgs = make([]string, 0)
|
||||
defaultParsedOptions = make(map[string]string)
|
||||
argumentOptionRegex = regexp.MustCompile(`^\-{1,2}([\w\?\.\-]+)(=){0,1}(.*)$`)
|
||||
)
|
||||
|
||||
// Init does custom initialization.
|
||||
func Init(args ...string) {
|
||||
if len(args) == 0 {
|
||||
if len(defaultParsedArgs) == 0 && len(defaultParsedOptions) == 0 {
|
||||
args = os.Args
|
||||
} else {
|
||||
return
|
||||
}
|
||||
} else {
|
||||
defaultParsedArgs = make([]string, 0)
|
||||
defaultParsedOptions = make(map[string]string)
|
||||
}
|
||||
// Parsing os.Args with default algorithm.
|
||||
defaultParsedArgs, defaultParsedOptions = ParseUsingDefaultAlgorithm(args...)
|
||||
}
|
||||
|
||||
// ParseUsingDefaultAlgorithm parses arguments using default algorithm.
|
||||
func ParseUsingDefaultAlgorithm(args ...string) (parsedArgs []string, parsedOptions map[string]string) {
|
||||
parsedArgs = make([]string, 0)
|
||||
parsedOptions = make(map[string]string)
|
||||
for i := 0; i < len(args); {
|
||||
array := argumentOptionRegex.FindStringSubmatch(args[i])
|
||||
if len(array) > 2 {
|
||||
if array[2] == "=" {
|
||||
parsedOptions[array[1]] = array[3]
|
||||
} else if i < len(args)-1 {
|
||||
if len(args[i+1]) > 0 && args[i+1][0] == '-' {
|
||||
// Example: gf gen -d -n 1
|
||||
parsedOptions[array[1]] = array[3]
|
||||
} else {
|
||||
// Example: gf gen -n 2
|
||||
parsedOptions[array[1]] = args[i+1]
|
||||
i += 2
|
||||
continue
|
||||
}
|
||||
} else {
|
||||
// Example: gf gen -h
|
||||
parsedOptions[array[1]] = array[3]
|
||||
}
|
||||
} else {
|
||||
parsedArgs = append(parsedArgs, args[i])
|
||||
}
|
||||
i++
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
// GetOpt returns the option value named `name`.
|
||||
func GetOpt(name string, def ...string) string {
|
||||
Init()
|
||||
if v, ok := defaultParsedOptions[name]; ok {
|
||||
return v
|
||||
}
|
||||
if len(def) > 0 {
|
||||
return def[0]
|
||||
}
|
||||
return ""
|
||||
}
|
||||
|
||||
// GetOptAll returns all parsed options.
|
||||
func GetOptAll() map[string]string {
|
||||
Init()
|
||||
return defaultParsedOptions
|
||||
}
|
||||
|
||||
// ContainsOpt checks whether option named `name` exist in the arguments.
|
||||
func ContainsOpt(name string) bool {
|
||||
Init()
|
||||
_, ok := defaultParsedOptions[name]
|
||||
return ok
|
||||
}
|
||||
|
||||
// GetArg returns the argument at `index`.
|
||||
func GetArg(index int, def ...string) string {
|
||||
Init()
|
||||
if index < len(defaultParsedArgs) {
|
||||
return defaultParsedArgs[index]
|
||||
}
|
||||
if len(def) > 0 {
|
||||
return def[0]
|
||||
}
|
||||
return ""
|
||||
}
|
||||
|
||||
// GetArgAll returns all parsed arguments.
|
||||
func GetArgAll() []string {
|
||||
Init()
|
||||
return defaultParsedArgs
|
||||
}
|
||||
|
||||
// GetOptWithEnv returns the command line argument of the specified `key`.
|
||||
// If the argument does not exist, then it returns the environment variable with specified `key`.
|
||||
// It returns the default value `def` if none of them exists.
|
||||
//
|
||||
// Fetching Rules:
|
||||
// 1. Command line arguments are in lowercase format, eg: gf.package.variable;
|
||||
// 2. Environment arguments are in uppercase format, eg: GF_PACKAGE_VARIABLE;
|
||||
func GetOptWithEnv(key string, def ...string) string {
|
||||
cmdKey := strings.ToLower(strings.ReplaceAll(key, "_", "."))
|
||||
if ContainsOpt(cmdKey) {
|
||||
return GetOpt(cmdKey)
|
||||
} else {
|
||||
envKey := strings.ToUpper(strings.ReplaceAll(key, ".", "_"))
|
||||
if r, ok := os.LookupEnv(envKey); ok {
|
||||
return r
|
||||
} else {
|
||||
if len(def) > 0 {
|
||||
return def[0]
|
||||
}
|
||||
}
|
||||
}
|
||||
return ""
|
||||
}
|
||||
|
|
@ -1,96 +0,0 @@
|
|||
package database
|
||||
|
||||
import (
|
||||
"database/sql"
|
||||
"fmt"
|
||||
"time"
|
||||
|
||||
"git.magicany.cc/black1552/gin-base/config"
|
||||
"git.magicany.cc/black1552/gin-base/log"
|
||||
"github.com/glebarez/sqlite"
|
||||
"github.com/gogf/gf/v2/frame/g"
|
||||
"github.com/gogf/gf/v2/os/gfile"
|
||||
"gorm.io/driver/mysql"
|
||||
"gorm.io/gorm"
|
||||
"gorm.io/gorm/logger"
|
||||
"gorm.io/gorm/schema"
|
||||
)
|
||||
|
||||
var (
|
||||
Type gorm.Dialector
|
||||
Db *gorm.DB
|
||||
err error
|
||||
sqlDb *sql.DB
|
||||
dns = config.GetConfigValue("database.dns", gfile.Join(gfile.Pwd(), "db", "database.db"))
|
||||
)
|
||||
|
||||
func init() {
|
||||
if g.IsEmpty(dns) {
|
||||
log.Error("gormDns 未配置", "请检查配置文件")
|
||||
return
|
||||
}
|
||||
switch config.GetConfigValue("database.type", "sqlite").String() {
|
||||
case "mysql":
|
||||
log.Info("使用 mysql 数据库")
|
||||
mysqlInit()
|
||||
case "sqlite":
|
||||
log.Info("使用 sqlite 数据库")
|
||||
sqliteInit()
|
||||
}
|
||||
|
||||
// 构建 GORM 配置
|
||||
gormConfig := &gorm.Config{
|
||||
SkipDefaultTransaction: true,
|
||||
NowFunc: func() time.Time {
|
||||
return time.Now().Local()
|
||||
},
|
||||
// 命名策略:保持与模型一致,避免字段/表名转换问题
|
||||
NamingStrategy: schema.NamingStrategy{
|
||||
SingularTable: true, // 表名禁用复数形式(例如 User 对应 user 表,而非 users)
|
||||
},
|
||||
}
|
||||
|
||||
// 根据配置决定是否开启 GORM 查询日志
|
||||
if config.GetConfigValue("database.debug", false).Bool() {
|
||||
log.Info("已开启 GORM 查询日志")
|
||||
gormConfig.Logger = logger.Default.LogMode(logger.Info)
|
||||
} else {
|
||||
gormConfig.Logger = logger.Default.LogMode(logger.Silent)
|
||||
}
|
||||
|
||||
Db, err = gorm.Open(Type, gormConfig)
|
||||
if err != nil {
|
||||
log.Error("数据库连接失败: ", err)
|
||||
return
|
||||
}
|
||||
sqlDb, err = Db.DB()
|
||||
if err != nil {
|
||||
log.Error("获取sqlDb失败", err)
|
||||
return
|
||||
}
|
||||
if err = sqlDb.Ping(); err != nil {
|
||||
log.Error("数据库未正常连接", err)
|
||||
return
|
||||
}
|
||||
}
|
||||
|
||||
func mysqlInit() {
|
||||
Type = mysql.New(mysql.Config{
|
||||
DSN: dns.String(),
|
||||
DefaultStringSize: 255, // string 类型字段的默认长度
|
||||
DisableDatetimePrecision: true, // 禁用 datetime 精度,MySQL 5.6 之前的数据库不支持
|
||||
DontSupportRenameIndex: true, // 重命名索引时采用删除并新建的方式,MySQL 5.7 之前的数据库和 MariaDB 不支持重命名索引
|
||||
SkipInitializeWithVersion: false, // 根据当前 MySQL 版本自动配置
|
||||
})
|
||||
}
|
||||
|
||||
func sqliteInit() {
|
||||
if !gfile.Exists(dns.String()) {
|
||||
_, err = gfile.Create(dns.String())
|
||||
if err != nil {
|
||||
log.Error("创建数据库文件失败: ", err)
|
||||
return
|
||||
}
|
||||
}
|
||||
Type = sqlite.Open(fmt.Sprintf("%s?cache=shared&mode=rwc&_busy_timeout=10000&_fk=1&_journal=WAL&_sync=FULL", dns.String()))
|
||||
}
|
||||
|
|
@ -0,0 +1,243 @@
|
|||
// Copyright GoFrame gf Author(https://goframe.org). All Rights Reserved.
|
||||
//
|
||||
// This Source Code Form is subject to the terms of the MIT License.
|
||||
// If a copy of the MIT was not distributed with this file,
|
||||
// You can obtain one at https://github.com/gogf/gf.
|
||||
|
||||
// Package empty provides functions for checking empty/nil variables.
|
||||
package empty
|
||||
|
||||
import (
|
||||
"reflect"
|
||||
"time"
|
||||
|
||||
"git.magicany.cc/black1552/gin-base/database/reflection"
|
||||
)
|
||||
|
||||
// iString is used for type assert api for String().
|
||||
type iString interface {
|
||||
String() string
|
||||
}
|
||||
|
||||
// iInterfaces is used for type assert api for Interfaces.
|
||||
type iInterfaces interface {
|
||||
Interfaces() []any
|
||||
}
|
||||
|
||||
// iMapStrAny is the interface support for converting struct parameter to map.
|
||||
type iMapStrAny interface {
|
||||
MapStrAny() map[string]any
|
||||
}
|
||||
|
||||
type iTime interface {
|
||||
Date() (year int, month time.Month, day int)
|
||||
IsZero() bool
|
||||
}
|
||||
|
||||
// IsEmpty checks whether given `value` empty.
|
||||
// It returns true if `value` is in: 0, nil, false, "", len(slice/map/chan) == 0,
|
||||
// or else it returns false.
|
||||
//
|
||||
// The parameter `traceSource` is used for tracing to the source variable if given `value` is type of pointer
|
||||
// that also points to a pointer. It returns true if the source is empty when `traceSource` is true.
|
||||
// Note that it might use reflect feature which affects performance a little.
|
||||
func IsEmpty(value any, traceSource ...bool) bool {
|
||||
if value == nil {
|
||||
return true
|
||||
}
|
||||
// It firstly checks the variable as common types using assertion to enhance the performance,
|
||||
// and then using reflection.
|
||||
switch result := value.(type) {
|
||||
case int:
|
||||
return result == 0
|
||||
case int8:
|
||||
return result == 0
|
||||
case int16:
|
||||
return result == 0
|
||||
case int32:
|
||||
return result == 0
|
||||
case int64:
|
||||
return result == 0
|
||||
case uint:
|
||||
return result == 0
|
||||
case uint8:
|
||||
return result == 0
|
||||
case uint16:
|
||||
return result == 0
|
||||
case uint32:
|
||||
return result == 0
|
||||
case uint64:
|
||||
return result == 0
|
||||
case float32:
|
||||
return result == 0
|
||||
case float64:
|
||||
return result == 0
|
||||
case bool:
|
||||
return !result
|
||||
case string:
|
||||
return result == ""
|
||||
case []byte:
|
||||
return len(result) == 0
|
||||
case []rune:
|
||||
return len(result) == 0
|
||||
case []int:
|
||||
return len(result) == 0
|
||||
case []string:
|
||||
return len(result) == 0
|
||||
case []float32:
|
||||
return len(result) == 0
|
||||
case []float64:
|
||||
return len(result) == 0
|
||||
case map[string]any:
|
||||
return len(result) == 0
|
||||
|
||||
default:
|
||||
// Finally, using reflect.
|
||||
var rv reflect.Value
|
||||
if v, ok := value.(reflect.Value); ok {
|
||||
rv = v
|
||||
} else {
|
||||
rv = reflect.ValueOf(value)
|
||||
if IsNil(rv) {
|
||||
return true
|
||||
}
|
||||
|
||||
// =========================
|
||||
// Common interfaces checks.
|
||||
// =========================
|
||||
if f, ok := value.(iTime); ok {
|
||||
if f == (*time.Time)(nil) {
|
||||
return true
|
||||
}
|
||||
return f.IsZero()
|
||||
}
|
||||
if f, ok := value.(iString); ok {
|
||||
if f == nil {
|
||||
return true
|
||||
}
|
||||
return f.String() == ""
|
||||
}
|
||||
if f, ok := value.(iInterfaces); ok {
|
||||
if f == nil {
|
||||
return true
|
||||
}
|
||||
return len(f.Interfaces()) == 0
|
||||
}
|
||||
if f, ok := value.(iMapStrAny); ok {
|
||||
if f == nil {
|
||||
return true
|
||||
}
|
||||
return len(f.MapStrAny()) == 0
|
||||
}
|
||||
}
|
||||
|
||||
switch rv.Kind() {
|
||||
case reflect.Bool:
|
||||
return !rv.Bool()
|
||||
|
||||
case
|
||||
reflect.Int,
|
||||
reflect.Int8,
|
||||
reflect.Int16,
|
||||
reflect.Int32,
|
||||
reflect.Int64:
|
||||
return rv.Int() == 0
|
||||
|
||||
case
|
||||
reflect.Uint,
|
||||
reflect.Uint8,
|
||||
reflect.Uint16,
|
||||
reflect.Uint32,
|
||||
reflect.Uint64,
|
||||
reflect.Uintptr:
|
||||
return rv.Uint() == 0
|
||||
|
||||
case
|
||||
reflect.Float32,
|
||||
reflect.Float64:
|
||||
return rv.Float() == 0
|
||||
|
||||
case reflect.String:
|
||||
return rv.Len() == 0
|
||||
|
||||
case reflect.Struct:
|
||||
var fieldValueInterface any
|
||||
for i := 0; i < rv.NumField(); i++ {
|
||||
fieldValueInterface, _ = reflection.ValueToInterface(rv.Field(i))
|
||||
if !IsEmpty(fieldValueInterface) {
|
||||
return false
|
||||
}
|
||||
}
|
||||
return true
|
||||
|
||||
case
|
||||
reflect.Chan,
|
||||
reflect.Map,
|
||||
reflect.Slice,
|
||||
reflect.Array:
|
||||
return rv.Len() == 0
|
||||
|
||||
case reflect.Pointer:
|
||||
if len(traceSource) > 0 && traceSource[0] {
|
||||
return IsEmpty(rv.Elem())
|
||||
}
|
||||
return rv.IsNil()
|
||||
|
||||
case
|
||||
reflect.Func,
|
||||
reflect.Interface,
|
||||
reflect.UnsafePointer:
|
||||
return rv.IsNil()
|
||||
|
||||
case reflect.Invalid:
|
||||
return true
|
||||
|
||||
default:
|
||||
return false
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// IsNil checks whether given `value` is nil, especially for any type value.
|
||||
// Parameter `traceSource` is used for tracing to the source variable if given `value` is type of pointer
|
||||
// that also points to a pointer. It returns nil if the source is nil when `traceSource` is true.
|
||||
// Note that it might use reflect feature which affects performance a little.
|
||||
func IsNil(value any, traceSource ...bool) bool {
|
||||
if value == nil {
|
||||
return true
|
||||
}
|
||||
var rv reflect.Value
|
||||
if v, ok := value.(reflect.Value); ok {
|
||||
rv = v
|
||||
} else {
|
||||
rv = reflect.ValueOf(value)
|
||||
}
|
||||
switch rv.Kind() {
|
||||
case reflect.Chan,
|
||||
reflect.Map,
|
||||
reflect.Slice,
|
||||
reflect.Func,
|
||||
reflect.Interface,
|
||||
reflect.UnsafePointer:
|
||||
return !rv.IsValid() || rv.IsNil()
|
||||
|
||||
case reflect.Pointer:
|
||||
if len(traceSource) > 0 && traceSource[0] {
|
||||
for rv.Kind() == reflect.Pointer {
|
||||
rv = rv.Elem()
|
||||
}
|
||||
if !rv.IsValid() {
|
||||
return true
|
||||
}
|
||||
if rv.Kind() == reflect.Pointer {
|
||||
return rv.IsNil()
|
||||
}
|
||||
} else {
|
||||
return !rv.IsValid() || rv.IsNil()
|
||||
}
|
||||
|
||||
default:
|
||||
return false
|
||||
}
|
||||
return false
|
||||
}
|
||||
File diff suppressed because it is too large
Load Diff
|
|
@ -0,0 +1,82 @@
|
|||
// Copyright GoFrame Author(https://goframe.org). All Rights Reserved.
|
||||
//
|
||||
// This Source Code Form is subject to the terms of the MIT License.
|
||||
// If a copy of the MIT was not distributed with this file,
|
||||
// You can obtain one at https://github.com/gogf/gf.
|
||||
|
||||
package database
|
||||
|
||||
import (
|
||||
"reflect"
|
||||
|
||||
"git.magicany.cc/black1552/gin-base/database/json"
|
||||
"github.com/gogf/gf/v2/errors/gerror"
|
||||
"github.com/gogf/gf/v2/util/gconv"
|
||||
)
|
||||
|
||||
// iVal is used for type assert api for Val().
|
||||
type iVal interface {
|
||||
Val() any
|
||||
}
|
||||
|
||||
var (
|
||||
// converter is the internal type converter for gdb.
|
||||
converter = gconv.NewConverter()
|
||||
)
|
||||
|
||||
func init() {
|
||||
converter.RegisterAnyConverterFunc(
|
||||
sliceTypeConverterFunc,
|
||||
reflect.TypeOf([]string{}),
|
||||
reflect.TypeOf([]float32{}),
|
||||
reflect.TypeOf([]float64{}),
|
||||
reflect.TypeOf([]int{}),
|
||||
reflect.TypeOf([]int32{}),
|
||||
reflect.TypeOf([]int64{}),
|
||||
reflect.TypeOf([]uint{}),
|
||||
reflect.TypeOf([]uint32{}),
|
||||
reflect.TypeOf([]uint64{}),
|
||||
)
|
||||
}
|
||||
|
||||
// GetConverter returns the internal type converter for gdb.
|
||||
func GetConverter() gconv.Converter {
|
||||
return converter
|
||||
}
|
||||
|
||||
func sliceTypeConverterFunc(from any, to reflect.Value) (err error) {
|
||||
v, ok := from.(iVal)
|
||||
if !ok {
|
||||
return nil
|
||||
}
|
||||
fromVal := v.Val()
|
||||
switch x := fromVal.(type) {
|
||||
case []byte:
|
||||
dst := to.Addr().Interface()
|
||||
err = json.Unmarshal(x, dst)
|
||||
case string:
|
||||
dst := to.Addr().Interface()
|
||||
err = json.Unmarshal([]byte(x), dst)
|
||||
default:
|
||||
fromType := reflect.TypeOf(fromVal)
|
||||
switch fromType.Kind() {
|
||||
case reflect.Slice:
|
||||
convertOption := gconv.ConvertOption{
|
||||
SliceOption: gconv.SliceOption{ContinueOnError: true},
|
||||
MapOption: gconv.MapOption{ContinueOnError: true},
|
||||
StructOption: gconv.StructOption{ContinueOnError: true},
|
||||
}
|
||||
dv, err := converter.ConvertWithTypeName(fromVal, to.Type().String(), convertOption)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
to.Set(reflect.ValueOf(dv))
|
||||
default:
|
||||
err = gerror.Newf(
|
||||
`unsupported type converting from type "%T" to type "%T"`,
|
||||
fromVal, to,
|
||||
)
|
||||
}
|
||||
}
|
||||
return err
|
||||
}
|
||||
|
|
@ -0,0 +1,841 @@
|
|||
// Copyright GoFrame Author(https://goframe.org). All Rights Reserved.
|
||||
//
|
||||
// This Source Code Form is subject to the terms of the MIT License.
|
||||
// If a copy of the MIT was not distributed with this file,
|
||||
// You can obtain one at https://github.com/gogf/gf.
|
||||
//
|
||||
|
||||
package database
|
||||
|
||||
import (
|
||||
"context"
|
||||
"database/sql"
|
||||
"fmt"
|
||||
"reflect"
|
||||
"sort"
|
||||
"strings"
|
||||
|
||||
"git.magicany.cc/black1552/gin-base/database/intlog"
|
||||
"git.magicany.cc/black1552/gin-base/database/reflection"
|
||||
"git.magicany.cc/black1552/gin-base/database/utils"
|
||||
"github.com/gogf/gf/v2/container/gmap"
|
||||
"github.com/gogf/gf/v2/container/gset"
|
||||
"github.com/gogf/gf/v2/container/gvar"
|
||||
"github.com/gogf/gf/v2/errors/gcode"
|
||||
"github.com/gogf/gf/v2/errors/gerror"
|
||||
"github.com/gogf/gf/v2/os/gcache"
|
||||
"github.com/gogf/gf/v2/text/gregex"
|
||||
"github.com/gogf/gf/v2/text/gstr"
|
||||
"github.com/gogf/gf/v2/util/gconv"
|
||||
"github.com/gogf/gf/v2/util/gutil"
|
||||
)
|
||||
|
||||
// GetCore returns the underlying *Core object.
|
||||
func (c *Core) GetCore() *Core {
|
||||
return c
|
||||
}
|
||||
|
||||
// Ctx is a chaining function, which creates and returns a new DB that is a shallow copy
|
||||
// of current DB object and with given context in it.
|
||||
// Note that this returned DB object can be used only once, so do not assign it to
|
||||
// a global or package variable for long using.
|
||||
func (c *Core) Ctx(ctx context.Context) DB {
|
||||
if ctx == nil {
|
||||
return c.db
|
||||
}
|
||||
// It makes a shallow copy of current db and changes its context for next chaining operation.
|
||||
var (
|
||||
err error
|
||||
newCore = &Core{}
|
||||
configNode = c.db.GetConfig()
|
||||
)
|
||||
*newCore = *c
|
||||
// It creates a new DB object(NOT NEW CONNECTION), which is commonly a wrapper for object `Core`.
|
||||
newCore.db, err = driverMap[configNode.Type].New(newCore, configNode)
|
||||
if err != nil {
|
||||
// It is really a serious error here.
|
||||
// Do not let it continue.
|
||||
panic(err)
|
||||
}
|
||||
newCore.ctx = WithDB(ctx, newCore.db)
|
||||
newCore.ctx = c.injectInternalCtxData(newCore.ctx)
|
||||
return newCore.db
|
||||
}
|
||||
|
||||
// GetCtx returns the context for current DB.
|
||||
// It returns `context.Background()` is there's no context previously set.
|
||||
func (c *Core) GetCtx() context.Context {
|
||||
ctx := c.ctx
|
||||
if ctx == nil {
|
||||
ctx = context.TODO()
|
||||
}
|
||||
return c.injectInternalCtxData(ctx)
|
||||
}
|
||||
|
||||
// GetCtxTimeout returns the context and cancel function for specified timeout type.
|
||||
func (c *Core) GetCtxTimeout(ctx context.Context, timeoutType ctxTimeoutType) (context.Context, context.CancelFunc) {
|
||||
if ctx == nil {
|
||||
ctx = c.db.GetCtx()
|
||||
} else {
|
||||
ctx = context.WithValue(ctx, ctxKeyWrappedByGetCtxTimeout, nil)
|
||||
}
|
||||
var config = c.db.GetConfig()
|
||||
switch timeoutType {
|
||||
case ctxTimeoutTypeExec:
|
||||
if c.db.GetConfig().ExecTimeout > 0 {
|
||||
return context.WithTimeout(ctx, config.ExecTimeout)
|
||||
}
|
||||
case ctxTimeoutTypeQuery:
|
||||
if c.db.GetConfig().QueryTimeout > 0 {
|
||||
return context.WithTimeout(ctx, config.QueryTimeout)
|
||||
}
|
||||
case ctxTimeoutTypePrepare:
|
||||
if c.db.GetConfig().PrepareTimeout > 0 {
|
||||
return context.WithTimeout(ctx, config.PrepareTimeout)
|
||||
}
|
||||
case ctxTimeoutTypeTrans:
|
||||
if c.db.GetConfig().TranTimeout > 0 {
|
||||
return context.WithTimeout(ctx, config.TranTimeout)
|
||||
}
|
||||
default:
|
||||
panic(gerror.NewCodef(gcode.CodeInvalidParameter, "invalid context timeout type: %d", timeoutType))
|
||||
}
|
||||
return ctx, func() {}
|
||||
}
|
||||
|
||||
// Close closes the database and prevents new queries from starting.
|
||||
// Close then waits for all queries that have started processing on the server
|
||||
// to finish.
|
||||
//
|
||||
// It is rare to Close a DB, as the DB handle is meant to be
|
||||
// long-lived and shared between many goroutines.
|
||||
func (c *Core) Close(ctx context.Context) (err error) {
|
||||
if err = c.cache.Close(ctx); err != nil {
|
||||
return err
|
||||
}
|
||||
c.links.LockFunc(func(m map[ConfigNode]*sql.DB) {
|
||||
for k, v := range m {
|
||||
err = v.Close()
|
||||
if err != nil {
|
||||
err = gerror.WrapCode(gcode.CodeDbOperationError, err, `db.Close failed`)
|
||||
}
|
||||
intlog.Printf(ctx, `close link: %s, err: %v`, gconv.String(k), err)
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
delete(m, k)
|
||||
}
|
||||
})
|
||||
return
|
||||
}
|
||||
|
||||
// Master creates and returns a connection from master node if master-slave configured.
|
||||
// It returns the default connection if master-slave not configured.
|
||||
func (c *Core) Master(schema ...string) (*sql.DB, error) {
|
||||
var (
|
||||
usedSchema = gutil.GetOrDefaultStr(c.schema, schema...)
|
||||
charL, charR = c.db.GetChars()
|
||||
)
|
||||
return c.getSqlDb(true, gstr.Trim(usedSchema, charL+charR))
|
||||
}
|
||||
|
||||
// Slave creates and returns a connection from slave node if master-slave configured.
|
||||
// It returns the default connection if master-slave not configured.
|
||||
func (c *Core) Slave(schema ...string) (*sql.DB, error) {
|
||||
var (
|
||||
usedSchema = gutil.GetOrDefaultStr(c.schema, schema...)
|
||||
charL, charR = c.db.GetChars()
|
||||
)
|
||||
return c.getSqlDb(false, gstr.Trim(usedSchema, charL+charR))
|
||||
}
|
||||
|
||||
// GetAll queries and returns data records from database.
|
||||
func (c *Core) GetAll(ctx context.Context, sql string, args ...any) (Result, error) {
|
||||
return c.db.DoSelect(ctx, nil, sql, args...)
|
||||
}
|
||||
|
||||
// DoSelect queries and returns data records from database.
|
||||
func (c *Core) DoSelect(ctx context.Context, link Link, sql string, args ...any) (result Result, err error) {
|
||||
return c.db.DoQuery(ctx, link, sql, args...)
|
||||
}
|
||||
|
||||
// GetOne queries and returns one record from database.
|
||||
func (c *Core) GetOne(ctx context.Context, sql string, args ...any) (Record, error) {
|
||||
list, err := c.db.GetAll(ctx, sql, args...)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if len(list) > 0 {
|
||||
return list[0], nil
|
||||
}
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
// GetArray queries and returns data values as slice from database.
|
||||
// Note that if there are multiple columns in the result, it returns just one column values randomly.
|
||||
func (c *Core) GetArray(ctx context.Context, sql string, args ...any) (Array, error) {
|
||||
all, err := c.db.DoSelect(ctx, nil, sql, args...)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return all.Array(), nil
|
||||
}
|
||||
|
||||
// doGetStruct queries one record from database and converts it to given struct.
|
||||
// The parameter `pointer` should be a pointer to struct.
|
||||
func (c *Core) doGetStruct(ctx context.Context, pointer any, sql string, args ...any) error {
|
||||
one, err := c.db.GetOne(ctx, sql, args...)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
return one.Struct(pointer)
|
||||
}
|
||||
|
||||
// doGetStructs queries records from database and converts them to given struct.
|
||||
// The parameter `pointer` should be type of struct slice: []struct/[]*struct.
|
||||
func (c *Core) doGetStructs(ctx context.Context, pointer any, sql string, args ...any) error {
|
||||
all, err := c.db.GetAll(ctx, sql, args...)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
return all.Structs(pointer)
|
||||
}
|
||||
|
||||
// GetScan queries one or more records from database and converts them to given struct or
|
||||
// struct array.
|
||||
//
|
||||
// If parameter `pointer` is type of struct pointer, it calls GetStruct internally for
|
||||
// the conversion. If parameter `pointer` is type of slice, it calls GetStructs internally
|
||||
// for conversion.
|
||||
func (c *Core) GetScan(ctx context.Context, pointer any, sql string, args ...any) error {
|
||||
reflectInfo := reflection.OriginTypeAndKind(pointer)
|
||||
if reflectInfo.InputKind != reflect.Pointer {
|
||||
return gerror.NewCodef(
|
||||
gcode.CodeInvalidParameter,
|
||||
"params should be type of pointer, but got: %v",
|
||||
reflectInfo.InputKind,
|
||||
)
|
||||
}
|
||||
switch reflectInfo.OriginKind {
|
||||
case reflect.Array, reflect.Slice:
|
||||
return c.db.GetCore().doGetStructs(ctx, pointer, sql, args...)
|
||||
|
||||
case reflect.Struct:
|
||||
return c.db.GetCore().doGetStruct(ctx, pointer, sql, args...)
|
||||
|
||||
default:
|
||||
}
|
||||
return gerror.NewCodef(
|
||||
gcode.CodeInvalidParameter,
|
||||
`in valid parameter type "%v", of which element type should be type of struct/slice`,
|
||||
reflectInfo.InputType,
|
||||
)
|
||||
}
|
||||
|
||||
// GetValue queries and returns the field value from database.
|
||||
// The sql should query only one field from database, or else it returns only one
|
||||
// field of the result.
|
||||
func (c *Core) GetValue(ctx context.Context, sql string, args ...any) (Value, error) {
|
||||
one, err := c.db.GetOne(ctx, sql, args...)
|
||||
if err != nil {
|
||||
return gvar.New(nil), err
|
||||
}
|
||||
for _, v := range one {
|
||||
return v, nil
|
||||
}
|
||||
return gvar.New(nil), nil
|
||||
}
|
||||
|
||||
// GetCount queries and returns the count from database.
|
||||
func (c *Core) GetCount(ctx context.Context, sql string, args ...any) (int, error) {
|
||||
// If the query fields do not contain function "COUNT",
|
||||
// it replaces the sql string and adds the "COUNT" function to the fields.
|
||||
if !gregex.IsMatchString(`(?i)SELECT\s+COUNT\(.+\)\s+FROM`, sql) {
|
||||
sql, _ = gregex.ReplaceString(`(?i)(SELECT)\s+(.+)\s+(FROM)`, `$1 COUNT($2) $3`, sql)
|
||||
}
|
||||
value, err := c.db.GetValue(ctx, sql, args...)
|
||||
if err != nil {
|
||||
return 0, err
|
||||
}
|
||||
return value.Int(), nil
|
||||
}
|
||||
|
||||
// Union does "(SELECT xxx FROM xxx) UNION (SELECT xxx FROM xxx) ..." statement.
|
||||
func (c *Core) Union(unions ...*Model) *Model {
|
||||
var ctx = c.db.GetCtx()
|
||||
return c.doUnion(ctx, unionTypeNormal, unions...)
|
||||
}
|
||||
|
||||
// UnionAll does "(SELECT xxx FROM xxx) UNION ALL (SELECT xxx FROM xxx) ..." statement.
|
||||
func (c *Core) UnionAll(unions ...*Model) *Model {
|
||||
var ctx = c.db.GetCtx()
|
||||
return c.doUnion(ctx, unionTypeAll, unions...)
|
||||
}
|
||||
|
||||
func (c *Core) doUnion(ctx context.Context, unionType int, unions ...*Model) *Model {
|
||||
var (
|
||||
unionTypeStr string
|
||||
composedSqlStr string
|
||||
composedArgs = make([]any, 0)
|
||||
)
|
||||
if unionType == unionTypeAll {
|
||||
unionTypeStr = "UNION ALL"
|
||||
} else {
|
||||
unionTypeStr = "UNION"
|
||||
}
|
||||
for _, v := range unions {
|
||||
sqlWithHolder, holderArgs := v.getFormattedSqlAndArgs(ctx, SelectTypeDefault, false)
|
||||
if composedSqlStr == "" {
|
||||
composedSqlStr += fmt.Sprintf(`(%s)`, sqlWithHolder)
|
||||
} else {
|
||||
composedSqlStr += fmt.Sprintf(` %s (%s)`, unionTypeStr, sqlWithHolder)
|
||||
}
|
||||
composedArgs = append(composedArgs, holderArgs...)
|
||||
}
|
||||
return c.db.Raw(composedSqlStr, composedArgs...)
|
||||
}
|
||||
|
||||
// PingMaster pings the master node to check authentication or keeps the connection alive.
|
||||
func (c *Core) PingMaster() error {
|
||||
var ctx = c.db.GetCtx()
|
||||
if master, err := c.db.Master(); err != nil {
|
||||
return err
|
||||
} else {
|
||||
if err = master.PingContext(ctx); err != nil {
|
||||
err = gerror.WrapCode(gcode.CodeDbOperationError, err, `master.Ping failed`)
|
||||
}
|
||||
return err
|
||||
}
|
||||
}
|
||||
|
||||
// PingSlave pings the slave node to check authentication or keeps the connection alive.
|
||||
func (c *Core) PingSlave() error {
|
||||
var ctx = c.db.GetCtx()
|
||||
if slave, err := c.db.Slave(); err != nil {
|
||||
return err
|
||||
} else {
|
||||
if err = slave.PingContext(ctx); err != nil {
|
||||
err = gerror.WrapCode(gcode.CodeDbOperationError, err, `slave.Ping failed`)
|
||||
}
|
||||
return err
|
||||
}
|
||||
}
|
||||
|
||||
// Insert does "INSERT INTO ..." statement for the table.
|
||||
// If there's already one unique record of the data in the table, it returns error.
|
||||
//
|
||||
// The parameter `data` can be type of map/gmap/struct/*struct/[]map/[]struct, etc.
|
||||
// Eg:
|
||||
// Data(g.Map{"uid": 10000, "name":"john"})
|
||||
// Data(g.Slice{g.Map{"uid": 10000, "name":"john"}, g.Map{"uid": 20000, "name":"smith"})
|
||||
//
|
||||
// The parameter `batch` specifies the batch operation count when given data is slice.
|
||||
func (c *Core) Insert(ctx context.Context, table string, data any, batch ...int) (sql.Result, error) {
|
||||
if len(batch) > 0 {
|
||||
return c.Model(table).Ctx(ctx).Data(data).Batch(batch[0]).Insert()
|
||||
}
|
||||
return c.Model(table).Ctx(ctx).Data(data).Insert()
|
||||
}
|
||||
|
||||
// InsertIgnore does "INSERT IGNORE INTO ..." statement for the table.
|
||||
// If there's already one unique record of the data in the table, it ignores the inserting.
|
||||
//
|
||||
// The parameter `data` can be type of map/gmap/struct/*struct/[]map/[]struct, etc.
|
||||
// Eg:
|
||||
// Data(g.Map{"uid": 10000, "name":"john"})
|
||||
// Data(g.Slice{g.Map{"uid": 10000, "name":"john"}, g.Map{"uid": 20000, "name":"smith"})
|
||||
//
|
||||
// The parameter `batch` specifies the batch operation count when given data is slice.
|
||||
func (c *Core) InsertIgnore(ctx context.Context, table string, data any, batch ...int) (sql.Result, error) {
|
||||
if len(batch) > 0 {
|
||||
return c.Model(table).Ctx(ctx).Data(data).Batch(batch[0]).InsertIgnore()
|
||||
}
|
||||
return c.Model(table).Ctx(ctx).Data(data).InsertIgnore()
|
||||
}
|
||||
|
||||
// InsertAndGetId performs action Insert and returns the last insert id that automatically generated.
|
||||
func (c *Core) InsertAndGetId(ctx context.Context, table string, data any, batch ...int) (int64, error) {
|
||||
if len(batch) > 0 {
|
||||
return c.Model(table).Ctx(ctx).Data(data).Batch(batch[0]).InsertAndGetId()
|
||||
}
|
||||
return c.Model(table).Ctx(ctx).Data(data).InsertAndGetId()
|
||||
}
|
||||
|
||||
// Replace does "REPLACE INTO ..." statement for the table.
|
||||
// If there's already one unique record of the data in the table, it deletes the record
|
||||
// and inserts a new one.
|
||||
//
|
||||
// The parameter `data` can be type of map/gmap/struct/*struct/[]map/[]struct, etc.
|
||||
// Eg:
|
||||
// Data(g.Map{"uid": 10000, "name":"john"})
|
||||
// Data(g.Slice{g.Map{"uid": 10000, "name":"john"}, g.Map{"uid": 20000, "name":"smith"})
|
||||
//
|
||||
// The parameter `data` can be type of map/gmap/struct/*struct/[]map/[]struct, etc.
|
||||
// If given data is type of slice, it then does batch replacing, and the optional parameter
|
||||
// `batch` specifies the batch operation count.
|
||||
func (c *Core) Replace(ctx context.Context, table string, data any, batch ...int) (sql.Result, error) {
|
||||
if len(batch) > 0 {
|
||||
return c.Model(table).Ctx(ctx).Data(data).Batch(batch[0]).Replace()
|
||||
}
|
||||
return c.Model(table).Ctx(ctx).Data(data).Replace()
|
||||
}
|
||||
|
||||
// Save does "INSERT INTO ... ON DUPLICATE KEY UPDATE..." statement for the table.
|
||||
// It updates the record if there's primary or unique index in the saving data,
|
||||
// or else it inserts a new record into the table.
|
||||
//
|
||||
// The parameter `data` can be type of map/gmap/struct/*struct/[]map/[]struct, etc.
|
||||
// Eg:
|
||||
// Data(g.Map{"uid": 10000, "name":"john"})
|
||||
// Data(g.Slice{g.Map{"uid": 10000, "name":"john"}, g.Map{"uid": 20000, "name":"smith"})
|
||||
//
|
||||
// If given data is type of slice, it then does batch saving, and the optional parameter
|
||||
// `batch` specifies the batch operation count.
|
||||
func (c *Core) Save(ctx context.Context, table string, data any, batch ...int) (sql.Result, error) {
|
||||
if len(batch) > 0 {
|
||||
return c.Model(table).Ctx(ctx).Data(data).Batch(batch[0]).Save()
|
||||
}
|
||||
return c.Model(table).Ctx(ctx).Data(data).Save()
|
||||
}
|
||||
|
||||
func (c *Core) fieldsToSequence(ctx context.Context, table string, fields []string) ([]string, error) {
|
||||
var (
|
||||
fieldSet = gset.NewStrSetFrom(fields)
|
||||
fieldsResultInSequence = make([]string, 0)
|
||||
tableFields, err = c.db.TableFields(ctx, table)
|
||||
)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
// Sort the fields in order.
|
||||
var fieldsOfTableInSequence = make([]string, len(tableFields))
|
||||
for _, field := range tableFields {
|
||||
fieldsOfTableInSequence[field.Index] = field.Name
|
||||
}
|
||||
// Sort the input fields.
|
||||
for _, fieldName := range fieldsOfTableInSequence {
|
||||
if fieldSet.Contains(fieldName) {
|
||||
fieldsResultInSequence = append(fieldsResultInSequence, fieldName)
|
||||
}
|
||||
}
|
||||
return fieldsResultInSequence, nil
|
||||
}
|
||||
|
||||
// DoInsert inserts or updates data for given table.
|
||||
// This function is usually used for custom interface definition, you do not need call it manually.
|
||||
// The parameter `data` can be type of map/gmap/struct/*struct/[]map/[]struct, etc.
|
||||
// Eg:
|
||||
// Data(g.Map{"uid": 10000, "name":"john"})
|
||||
// Data(g.Slice{g.Map{"uid": 10000, "name":"john"}, g.Map{"uid": 20000, "name":"smith"})
|
||||
//
|
||||
// The parameter `option` values are as follows:
|
||||
// InsertOptionDefault: just insert, if there's unique/primary key in the data, it returns error;
|
||||
// InsertOptionReplace: if there's unique/primary key in the data, it deletes it from table and inserts a new one;
|
||||
// InsertOptionSave: if there's unique/primary key in the data, it updates it or else inserts a new one;
|
||||
// InsertOptionIgnore: if there's unique/primary key in the data, it ignores the inserting;
|
||||
func (c *Core) DoInsert(ctx context.Context, link Link, table string, list List, option DoInsertOption) (result sql.Result, err error) {
|
||||
var (
|
||||
keys []string // Field names.
|
||||
values []string // Value holder string array, like: (?,?,?)
|
||||
params []any // Values that will be committed to underlying database driver.
|
||||
onDuplicateStr string // onDuplicateStr is used in "ON DUPLICATE KEY UPDATE" statement.
|
||||
)
|
||||
// ============================================================================================
|
||||
// Group the list by fields. Different fields to different list.
|
||||
// It here uses ListMap to keep sequence for data inserting.
|
||||
// ============================================================================================
|
||||
var (
|
||||
keyListMap = gmap.NewListMap()
|
||||
tmpKeyListMap = make(map[string]List)
|
||||
)
|
||||
for _, item := range list {
|
||||
mapLen := len(item)
|
||||
if mapLen == 0 {
|
||||
continue
|
||||
}
|
||||
tmpKeys := make([]string, 0, mapLen)
|
||||
for k := range item {
|
||||
tmpKeys = append(tmpKeys, k)
|
||||
}
|
||||
if mapLen > 1 {
|
||||
sort.Strings(tmpKeys)
|
||||
}
|
||||
keys = tmpKeys // for fieldsToSequence
|
||||
|
||||
tmpKeysInSequenceStr := gstr.Join(tmpKeys, ",")
|
||||
if tmpKeyListMapItem, ok := tmpKeyListMap[tmpKeysInSequenceStr]; ok {
|
||||
tmpKeyListMap[tmpKeysInSequenceStr] = append(tmpKeyListMapItem, item)
|
||||
} else {
|
||||
tmpKeyListMap[tmpKeysInSequenceStr] = List{item}
|
||||
}
|
||||
}
|
||||
for tmpKeysInSequenceStr, itemList := range tmpKeyListMap {
|
||||
keyListMap.Set(tmpKeysInSequenceStr, itemList)
|
||||
}
|
||||
if keyListMap.Size() > 1 {
|
||||
var (
|
||||
tmpResult sql.Result
|
||||
sqlResult SqlResult
|
||||
rowsAffected int64
|
||||
)
|
||||
keyListMap.Iterator(func(key, value any) bool {
|
||||
tmpResult, err = c.DoInsert(ctx, link, table, value.(List), option)
|
||||
if err != nil {
|
||||
return false
|
||||
}
|
||||
rowsAffected, err = tmpResult.RowsAffected()
|
||||
if err != nil {
|
||||
return false
|
||||
}
|
||||
sqlResult.Result = tmpResult
|
||||
sqlResult.Affected += rowsAffected
|
||||
return true
|
||||
})
|
||||
return &sqlResult, err
|
||||
}
|
||||
|
||||
keys, err = c.fieldsToSequence(ctx, table, keys)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
if len(keys) == 0 {
|
||||
return nil, gerror.NewCode(gcode.CodeInvalidParameter, "no valid data fields found in table")
|
||||
}
|
||||
|
||||
// Prepare the batch result pointer.
|
||||
var (
|
||||
charL, charR = c.db.GetChars()
|
||||
batchResult = new(SqlResult)
|
||||
keysStr = charL + strings.Join(keys, charR+","+charL) + charR
|
||||
operation = GetInsertOperationByOption(option.InsertOption)
|
||||
)
|
||||
// Upsert clause only takes effect on Save operation.
|
||||
if option.InsertOption == InsertOptionSave {
|
||||
onDuplicateStr, err = c.db.FormatUpsert(keys, list, option)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
}
|
||||
var (
|
||||
listLength = len(list)
|
||||
valueHolders = make([]string, 0)
|
||||
)
|
||||
for i := 0; i < listLength; i++ {
|
||||
values = values[:0]
|
||||
// Note that the map type is unordered,
|
||||
// so it should use slice+key to retrieve the value.
|
||||
for _, k := range keys {
|
||||
if s, ok := list[i][k].(Raw); ok {
|
||||
values = append(values, gconv.String(s))
|
||||
} else {
|
||||
values = append(values, "?")
|
||||
params = append(params, list[i][k])
|
||||
}
|
||||
}
|
||||
valueHolders = append(valueHolders, "("+gstr.Join(values, ",")+")")
|
||||
// Batch package checks: It meets the batch number, or it is the last element.
|
||||
if len(valueHolders) == option.BatchCount || (i == listLength-1 && len(valueHolders) > 0) {
|
||||
var (
|
||||
stdSqlResult sql.Result
|
||||
affectedRows int64
|
||||
)
|
||||
stdSqlResult, err = c.db.DoExec(ctx, link, fmt.Sprintf(
|
||||
"%s INTO %s(%s) VALUES%s %s",
|
||||
operation, c.QuotePrefixTableName(table), keysStr,
|
||||
gstr.Join(valueHolders, ","),
|
||||
onDuplicateStr,
|
||||
), params...)
|
||||
if err != nil {
|
||||
return stdSqlResult, err
|
||||
}
|
||||
if affectedRows, err = stdSqlResult.RowsAffected(); err != nil {
|
||||
err = gerror.WrapCode(gcode.CodeDbOperationError, err, `sql.Result.RowsAffected failed`)
|
||||
return stdSqlResult, err
|
||||
} else {
|
||||
batchResult.Result = stdSqlResult
|
||||
batchResult.Affected += affectedRows
|
||||
}
|
||||
params = params[:0]
|
||||
valueHolders = valueHolders[:0]
|
||||
}
|
||||
}
|
||||
return batchResult, nil
|
||||
}
|
||||
|
||||
// Update does "UPDATE ... " statement for the table.
|
||||
//
|
||||
// The parameter `data` can be type of string/map/gmap/struct/*struct, etc.
|
||||
// Eg: "uid=10000", "uid", 10000, g.Map{"uid": 10000, "name":"john"}
|
||||
//
|
||||
// The parameter `condition` can be type of string/map/gmap/slice/struct/*struct, etc.
|
||||
// It is commonly used with parameter `args`.
|
||||
// Eg:
|
||||
// "uid=10000",
|
||||
// "uid", 10000
|
||||
// "money>? AND name like ?", 99999, "vip_%"
|
||||
// "status IN (?)", g.Slice{1,2,3}
|
||||
// "age IN(?,?)", 18, 50
|
||||
// User{ Id : 1, UserName : "john"}.
|
||||
func (c *Core) Update(ctx context.Context, table string, data any, condition any, args ...any) (sql.Result, error) {
|
||||
return c.Model(table).Ctx(ctx).Data(data).Where(condition, args...).Update()
|
||||
}
|
||||
|
||||
// DoUpdate does "UPDATE ... " statement for the table.
|
||||
// This function is usually used for custom interface definition, you do not need to call it manually.
|
||||
func (c *Core) DoUpdate(ctx context.Context, link Link, table string, data any, condition string, args ...any) (result sql.Result, err error) {
|
||||
table = c.QuotePrefixTableName(table)
|
||||
var (
|
||||
rv = reflect.ValueOf(data)
|
||||
kind = rv.Kind()
|
||||
)
|
||||
if kind == reflect.Pointer {
|
||||
rv = rv.Elem()
|
||||
kind = rv.Kind()
|
||||
}
|
||||
var (
|
||||
params []any
|
||||
updates string
|
||||
)
|
||||
switch kind {
|
||||
case reflect.Map, reflect.Struct:
|
||||
var (
|
||||
fields []string
|
||||
dataMap map[string]any
|
||||
)
|
||||
dataMap, err = c.ConvertDataForRecord(ctx, data, table)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
// Sort the data keys in sequence of table fields.
|
||||
var (
|
||||
dataKeys = make([]string, 0)
|
||||
keysInSequence = make([]string, 0)
|
||||
)
|
||||
for k := range dataMap {
|
||||
dataKeys = append(dataKeys, k)
|
||||
}
|
||||
keysInSequence, err = c.fieldsToSequence(ctx, table, dataKeys)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
for _, k := range keysInSequence {
|
||||
v := dataMap[k]
|
||||
switch v.(type) {
|
||||
case Counter, *Counter:
|
||||
var counter Counter
|
||||
switch value := v.(type) {
|
||||
case Counter:
|
||||
counter = value
|
||||
case *Counter:
|
||||
counter = *value
|
||||
}
|
||||
if counter.Value == 0 {
|
||||
continue
|
||||
}
|
||||
operator, columnVal := c.getCounterAlter(counter)
|
||||
fields = append(fields, fmt.Sprintf("%s=%s%s?", c.QuoteWord(k), c.QuoteWord(counter.Field), operator))
|
||||
params = append(params, columnVal)
|
||||
default:
|
||||
if s, ok := v.(Raw); ok {
|
||||
fields = append(fields, c.QuoteWord(k)+"="+gconv.String(s))
|
||||
} else {
|
||||
fields = append(fields, c.QuoteWord(k)+"=?")
|
||||
params = append(params, v)
|
||||
}
|
||||
}
|
||||
}
|
||||
updates = strings.Join(fields, ",")
|
||||
|
||||
default:
|
||||
updates = gconv.String(data)
|
||||
}
|
||||
if len(updates) == 0 {
|
||||
return nil, gerror.NewCode(gcode.CodeMissingParameter, "data cannot be empty")
|
||||
}
|
||||
if len(params) > 0 {
|
||||
args = append(params, args...)
|
||||
}
|
||||
// If no link passed, it then uses the master link.
|
||||
if link == nil {
|
||||
if link, err = c.MasterLink(); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
}
|
||||
return c.db.DoExec(ctx, link, fmt.Sprintf(
|
||||
"UPDATE %s SET %s%s",
|
||||
table, updates, condition,
|
||||
),
|
||||
args...,
|
||||
)
|
||||
}
|
||||
|
||||
// Delete does "DELETE FROM ... " statement for the table.
|
||||
//
|
||||
// The parameter `condition` can be type of string/map/gmap/slice/struct/*struct, etc.
|
||||
// It is commonly used with parameter `args`.
|
||||
// Eg:
|
||||
// "uid=10000",
|
||||
// "uid", 10000
|
||||
// "money>? AND name like ?", 99999, "vip_%"
|
||||
// "status IN (?)", g.Slice{1,2,3}
|
||||
// "age IN(?,?)", 18, 50
|
||||
// User{ Id : 1, UserName : "john"}.
|
||||
func (c *Core) Delete(ctx context.Context, table string, condition any, args ...any) (result sql.Result, err error) {
|
||||
return c.Model(table).Ctx(ctx).Where(condition, args...).Delete()
|
||||
}
|
||||
|
||||
// DoDelete does "DELETE FROM ... " statement for the table.
|
||||
// This function is usually used for custom interface definition, you do not need call it manually.
|
||||
func (c *Core) DoDelete(ctx context.Context, link Link, table string, condition string, args ...any) (result sql.Result, err error) {
|
||||
if link == nil {
|
||||
if link, err = c.MasterLink(); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
}
|
||||
table = c.QuotePrefixTableName(table)
|
||||
return c.db.DoExec(ctx, link, fmt.Sprintf("DELETE FROM %s%s", table, condition), args...)
|
||||
}
|
||||
|
||||
// FilteredLink retrieves and returns filtered `linkInfo` that can be using for
|
||||
// logging or tracing purpose.
|
||||
func (c *Core) FilteredLink() string {
|
||||
return fmt.Sprintf(
|
||||
`%s@%s(%s:%s)/%s`,
|
||||
c.config.User, c.config.Protocol, c.config.Host, c.config.Port, c.config.Name,
|
||||
)
|
||||
}
|
||||
|
||||
// MarshalJSON implements the interface MarshalJSON for json.Marshal.
|
||||
// It just returns the pointer address.
|
||||
//
|
||||
// Note that this interface implements mainly for workaround for a json infinite loop bug
|
||||
// of Golang version < v1.14.
|
||||
func (c *Core) MarshalJSON() ([]byte, error) {
|
||||
return []byte(fmt.Sprintf(`%+v`, c)), nil
|
||||
}
|
||||
|
||||
// writeSqlToLogger outputs the Sql object to logger.
|
||||
// It is enabled only if configuration "debug" is true.
|
||||
func (c *Core) writeSqlToLogger(ctx context.Context, sql *Sql) {
|
||||
var transactionIdStr string
|
||||
if sql.IsTransaction {
|
||||
if v := ctx.Value(transactionIdForLoggerCtx); v != nil {
|
||||
transactionIdStr = fmt.Sprintf(`[txid:%d] `, v.(uint64))
|
||||
}
|
||||
}
|
||||
s := fmt.Sprintf(
|
||||
"[%3d ms] [%s] [%s] [rows:%-3d] %s%s",
|
||||
sql.End-sql.Start, sql.Group, sql.Schema, sql.RowsAffected, transactionIdStr, sql.Format,
|
||||
)
|
||||
if sql.Error != nil {
|
||||
s += "\nError: " + sql.Error.Error()
|
||||
c.logger.Error(ctx, s)
|
||||
} else {
|
||||
c.logger.Debug(ctx, s)
|
||||
}
|
||||
}
|
||||
|
||||
// HasTable determine whether the table name exists in the database.
|
||||
func (c *Core) HasTable(name string) (bool, error) {
|
||||
tables, err := c.GetTablesWithCache()
|
||||
if err != nil {
|
||||
return false, err
|
||||
}
|
||||
charL, charR := c.db.GetChars()
|
||||
name = gstr.Trim(name, charL+charR)
|
||||
for _, table := range tables {
|
||||
if table == name {
|
||||
return true, nil
|
||||
}
|
||||
}
|
||||
return false, nil
|
||||
}
|
||||
|
||||
// GetInnerMemCache retrieves and returns the inner memory cache object.
|
||||
func (c *Core) GetInnerMemCache() *gcache.Cache {
|
||||
return c.innerMemCache
|
||||
}
|
||||
|
||||
func (c *Core) SetTableFields(ctx context.Context, table string, fields map[string]*TableField, schema ...string) error {
|
||||
if table == "" {
|
||||
return gerror.NewCode(gcode.CodeInvalidParameter, "table name cannot be empty")
|
||||
}
|
||||
charL, charR := c.db.GetChars()
|
||||
table = gstr.Trim(table, charL+charR)
|
||||
if gstr.Contains(table, " ") {
|
||||
return gerror.NewCode(
|
||||
gcode.CodeInvalidParameter,
|
||||
"function TableFields supports only single table operations",
|
||||
)
|
||||
}
|
||||
var (
|
||||
innerMemCache = c.GetInnerMemCache()
|
||||
// prefix:group@schema#table
|
||||
cacheKey = genTableFieldsCacheKey(
|
||||
c.db.GetGroup(),
|
||||
gutil.GetOrDefaultStr(c.db.GetSchema(), schema...),
|
||||
table,
|
||||
)
|
||||
)
|
||||
return innerMemCache.Set(ctx, cacheKey, fields, gcache.DurationNoExpire)
|
||||
}
|
||||
|
||||
// GetTablesWithCache retrieves and returns the table names of current database with cache.
|
||||
func (c *Core) GetTablesWithCache() ([]string, error) {
|
||||
var (
|
||||
ctx = c.db.GetCtx()
|
||||
cacheKey = genTableNamesCacheKey(c.db.GetGroup())
|
||||
cacheDuration = gcache.DurationNoExpire
|
||||
innerMemCache = c.GetInnerMemCache()
|
||||
)
|
||||
result, err := innerMemCache.GetOrSetFuncLock(
|
||||
ctx, cacheKey,
|
||||
func(ctx context.Context) (any, error) {
|
||||
tableList, err := c.db.Tables(ctx)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return tableList, nil
|
||||
}, cacheDuration,
|
||||
)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return result.Strings(), nil
|
||||
}
|
||||
|
||||
// IsSoftCreatedFieldName checks and returns whether given field name is an automatic-filled created time.
|
||||
func (c *Core) IsSoftCreatedFieldName(fieldName string) bool {
|
||||
if fieldName == "" {
|
||||
return false
|
||||
}
|
||||
if config := c.db.GetConfig(); config.CreatedAt != "" {
|
||||
if utils.EqualFoldWithoutChars(fieldName, config.CreatedAt) {
|
||||
return true
|
||||
}
|
||||
return gstr.InArray(append([]string{config.CreatedAt}, createdFieldNames...), fieldName)
|
||||
}
|
||||
for _, v := range createdFieldNames {
|
||||
if utils.EqualFoldWithoutChars(fieldName, v) {
|
||||
return true
|
||||
}
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
// FormatSqlBeforeExecuting formats the sql string and its arguments before executing.
|
||||
// The internal handleArguments function might be called twice during the SQL procedure,
|
||||
// but do not worry about it, it's safe and efficient.
|
||||
func (c *Core) FormatSqlBeforeExecuting(sql string, args []any) (newSql string, newArgs []any) {
|
||||
return handleSliceAndStructArgsForSql(sql, args)
|
||||
}
|
||||
|
||||
// getCounterAlter
|
||||
func (c *Core) getCounterAlter(counter Counter) (operator string, columnVal float64) {
|
||||
operator, columnVal = "+", counter.Value
|
||||
if columnVal < 0 {
|
||||
operator, columnVal = "-", -columnVal
|
||||
}
|
||||
return
|
||||
}
|
||||
|
|
@ -0,0 +1,485 @@
|
|||
// Copyright GoFrame Author(https://goframe.org). All Rights Reserved.
|
||||
//
|
||||
// This Source Code Form is subject to the terms of the MIT License.
|
||||
// If a copy of the MIT was not distributed with this file,
|
||||
// You can obtain one at https://github.com/gogf/gf.
|
||||
|
||||
package database
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"git.magicany.cc/black1552/gin-base/log"
|
||||
"github.com/gogf/gf/v2/errors/gcode"
|
||||
"github.com/gogf/gf/v2/errors/gerror"
|
||||
"github.com/gogf/gf/v2/os/gcache"
|
||||
"github.com/gogf/gf/v2/text/gregex"
|
||||
"github.com/gogf/gf/v2/text/gstr"
|
||||
"github.com/gogf/gf/v2/util/gconv"
|
||||
)
|
||||
|
||||
// Config is the configuration management object.
|
||||
type Config map[string]ConfigGroup
|
||||
|
||||
// ConfigGroup is a slice of configuration node for specified named group.
|
||||
type ConfigGroup []ConfigNode
|
||||
|
||||
// ConfigNode is configuration for one node.
|
||||
type ConfigNode struct {
|
||||
// Host specifies the server address, can be either IP address or domain name
|
||||
// Example: "127.0.0.1", "localhost"
|
||||
Host string `json:"host"`
|
||||
|
||||
// Port specifies the server port number
|
||||
// Default is typically "3306" for MySQL
|
||||
Port string `json:"port"`
|
||||
|
||||
// User specifies the authentication username for database connection
|
||||
User string `json:"user"`
|
||||
|
||||
// Pass specifies the authentication password for database connection
|
||||
Pass string `json:"pass"`
|
||||
|
||||
// Name specifies the default database name to be used
|
||||
Name string `json:"name"`
|
||||
|
||||
// Type specifies the database type
|
||||
// Example: mysql, mariadb, sqlite, mssql, pgsql, oracle, clickhouse, dm.
|
||||
Type string `json:"type"`
|
||||
|
||||
// Link provides custom connection string that combines all configuration in one string
|
||||
// Optional field
|
||||
Link string `json:"link"`
|
||||
|
||||
// Extra provides additional configuration options for third-party database drivers
|
||||
// Optional field
|
||||
Extra string `json:"extra"`
|
||||
|
||||
// Role specifies the node role in master-slave setup
|
||||
// Optional field, defaults to "master"
|
||||
// Available values: "master", "slave"
|
||||
Role Role `json:"role"`
|
||||
|
||||
// Debug enables debug mode for logging and output
|
||||
// Optional field
|
||||
Debug bool `json:"debug"`
|
||||
|
||||
// Prefix specifies the table name prefix
|
||||
// Optional field
|
||||
Prefix string `json:"prefix"`
|
||||
|
||||
// DryRun enables simulation mode where SELECT statements are executed
|
||||
// but INSERT/UPDATE/DELETE statements are not
|
||||
// Optional field
|
||||
DryRun bool `json:"dryRun"`
|
||||
|
||||
// Weight specifies the node weight for load balancing calculations
|
||||
// Optional field, only effective in multi-node setups
|
||||
Weight int `json:"weight"`
|
||||
|
||||
// Charset specifies the character set for database operations
|
||||
// Optional field, defaults to "utf8"
|
||||
Charset string `json:"charset"`
|
||||
|
||||
// Protocol specifies the network protocol for database connection
|
||||
// Optional field, defaults to "tcp"
|
||||
// See net.Dial for available network protocols
|
||||
Protocol string `json:"protocol"`
|
||||
|
||||
// Timezone sets the time zone for timestamp interpretation and display
|
||||
// Optional field
|
||||
Timezone string `json:"timezone"`
|
||||
|
||||
// Namespace specifies the schema namespace for certain databases
|
||||
// Optional field, e.g., in PostgreSQL, Name is the catalog and Namespace is the schema
|
||||
Namespace string `json:"namespace"`
|
||||
|
||||
// MaxIdleConnCount specifies the maximum number of idle connections in the pool
|
||||
// Optional field
|
||||
MaxIdleConnCount int `json:"maxIdle"`
|
||||
|
||||
// MaxOpenConnCount specifies the maximum number of open connections in the pool
|
||||
// Optional field
|
||||
MaxOpenConnCount int `json:"maxOpen"`
|
||||
|
||||
// MaxConnLifeTime specifies the maximum lifetime of a connection
|
||||
// Optional field
|
||||
MaxConnLifeTime time.Duration `json:"maxLifeTime"`
|
||||
|
||||
// MaxIdleConnTime specifies the maximum idle time of a connection before being closed
|
||||
// This is Go 1.15+ feature: sql.DB.SetConnMaxIdleTime
|
||||
// Optional field
|
||||
MaxIdleConnTime time.Duration `json:"maxIdleTime"`
|
||||
|
||||
// QueryTimeout specifies the maximum execution time for DQL operations
|
||||
// Optional field
|
||||
QueryTimeout time.Duration `json:"queryTimeout"`
|
||||
|
||||
// ExecTimeout specifies the maximum execution time for DML operations
|
||||
// Optional field
|
||||
ExecTimeout time.Duration `json:"execTimeout"`
|
||||
|
||||
// TranTimeout specifies the maximum execution time for a transaction block
|
||||
// Optional field
|
||||
TranTimeout time.Duration `json:"tranTimeout"`
|
||||
|
||||
// PrepareTimeout specifies the maximum execution time for prepare operations
|
||||
// Optional field
|
||||
PrepareTimeout time.Duration `json:"prepareTimeout"`
|
||||
|
||||
// CreatedAt specifies the field name for automatic timestamp on record creation
|
||||
// Optional field
|
||||
CreatedAt string `json:"createdAt"`
|
||||
|
||||
// UpdatedAt specifies the field name for automatic timestamp on record updates
|
||||
// Optional field
|
||||
UpdatedAt string `json:"updatedAt"`
|
||||
|
||||
// DeletedAt specifies the field name for automatic timestamp on record deletion
|
||||
// Optional field
|
||||
DeletedAt string `json:"deletedAt"`
|
||||
|
||||
// TimeMaintainDisabled controls whether automatic time maintenance is disabled
|
||||
// Optional field
|
||||
TimeMaintainDisabled bool `json:"timeMaintainDisabled"`
|
||||
}
|
||||
|
||||
type Role string
|
||||
|
||||
const (
|
||||
RoleMaster Role = "master"
|
||||
RoleSlave Role = "slave"
|
||||
)
|
||||
|
||||
const (
|
||||
DefaultGroupName = "default" // Default group name.
|
||||
)
|
||||
|
||||
// configs specifies internal used configuration object.
|
||||
var configs struct {
|
||||
sync.RWMutex
|
||||
config Config // All configurations.
|
||||
group string // Default configuration group.
|
||||
}
|
||||
|
||||
func init() {
|
||||
configs.config = make(Config)
|
||||
configs.group = DefaultGroupName
|
||||
}
|
||||
|
||||
// SetConfig sets the global configuration for package.
|
||||
// It will overwrite the old configuration of package.
|
||||
func SetConfig(config Config) error {
|
||||
defer instances.Clear()
|
||||
configs.Lock()
|
||||
defer configs.Unlock()
|
||||
|
||||
for k, nodes := range config {
|
||||
for i, node := range nodes {
|
||||
parsedNode, err := parseConfigNode(node)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
nodes[i] = parsedNode
|
||||
}
|
||||
config[k] = nodes
|
||||
}
|
||||
configs.config = config
|
||||
return nil
|
||||
}
|
||||
|
||||
// SetConfigGroup sets the configuration for given group.
|
||||
func SetConfigGroup(group string, nodes ConfigGroup) error {
|
||||
defer instances.Clear()
|
||||
configs.Lock()
|
||||
defer configs.Unlock()
|
||||
|
||||
for i, node := range nodes {
|
||||
parsedNode, err := parseConfigNode(node)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
nodes[i] = parsedNode
|
||||
}
|
||||
configs.config[group] = nodes
|
||||
return nil
|
||||
}
|
||||
|
||||
// AddConfigNode adds one node configuration to configuration of given group.
|
||||
func AddConfigNode(group string, node ConfigNode) error {
|
||||
defer instances.Clear()
|
||||
configs.Lock()
|
||||
defer configs.Unlock()
|
||||
|
||||
parsedNode, err := parseConfigNode(node)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
configs.config[group] = append(configs.config[group], parsedNode)
|
||||
return nil
|
||||
}
|
||||
|
||||
// parseConfigNode parses `Link` configuration syntax.
|
||||
func parseConfigNode(node ConfigNode) (ConfigNode, error) {
|
||||
if node.Link != "" {
|
||||
parsedLinkNode, err := parseConfigNodeLink(&node)
|
||||
if err != nil {
|
||||
return node, err
|
||||
}
|
||||
node = *parsedLinkNode
|
||||
}
|
||||
if node.Link != "" && node.Type == "" {
|
||||
match, _ := gregex.MatchString(`([a-z]+):(.+)`, node.Link)
|
||||
if len(match) == 3 {
|
||||
node.Type = gstr.Trim(match[1])
|
||||
node.Link = gstr.Trim(match[2])
|
||||
}
|
||||
}
|
||||
return node, nil
|
||||
}
|
||||
|
||||
// AddDefaultConfigNode adds one node configuration to configuration of default group.
|
||||
func AddDefaultConfigNode(node ConfigNode) error {
|
||||
return AddConfigNode(DefaultGroupName, node)
|
||||
}
|
||||
|
||||
// AddDefaultConfigGroup adds multiple node configurations to configuration of default group.
|
||||
//
|
||||
// Deprecated: Use SetDefaultConfigGroup instead.
|
||||
func AddDefaultConfigGroup(nodes ConfigGroup) error {
|
||||
return SetConfigGroup(DefaultGroupName, nodes)
|
||||
}
|
||||
|
||||
// SetDefaultConfigGroup sets multiple node configurations to configuration of default group.
|
||||
func SetDefaultConfigGroup(nodes ConfigGroup) error {
|
||||
return SetConfigGroup(DefaultGroupName, nodes)
|
||||
}
|
||||
|
||||
// GetConfig retrieves and returns the configuration of given group.
|
||||
//
|
||||
// Deprecated: Use GetConfigGroup instead.
|
||||
func GetConfig(group string) ConfigGroup {
|
||||
configGroup, _ := GetConfigGroup(group)
|
||||
return configGroup
|
||||
}
|
||||
|
||||
// GetConfigGroup retrieves and returns the configuration of given group.
|
||||
// It returns an error if the group does not exist, or an empty slice if the group exists but has no nodes.
|
||||
func GetConfigGroup(group string) (ConfigGroup, error) {
|
||||
configs.RLock()
|
||||
defer configs.RUnlock()
|
||||
|
||||
configGroup, exists := configs.config[group]
|
||||
if !exists {
|
||||
return nil, gerror.NewCodef(
|
||||
gcode.CodeInvalidParameter,
|
||||
`configuration group "%s" not found`,
|
||||
group,
|
||||
)
|
||||
}
|
||||
return configGroup, nil
|
||||
}
|
||||
|
||||
// GetAllConfig retrieves and returns all configurations.
|
||||
func GetAllConfig() Config {
|
||||
configs.RLock()
|
||||
defer configs.RUnlock()
|
||||
return configs.config
|
||||
}
|
||||
|
||||
// SetDefaultGroup sets the group name for default configuration.
|
||||
func SetDefaultGroup(name string) {
|
||||
defer instances.Clear()
|
||||
configs.Lock()
|
||||
defer configs.Unlock()
|
||||
configs.group = name
|
||||
}
|
||||
|
||||
// GetDefaultGroup returns the { name of default configuration.
|
||||
func GetDefaultGroup() string {
|
||||
defer instances.Clear()
|
||||
configs.RLock()
|
||||
defer configs.RUnlock()
|
||||
return configs.group
|
||||
}
|
||||
|
||||
// IsConfigured checks and returns whether the database configured.
|
||||
// It returns true if any configuration exists.
|
||||
func IsConfigured() bool {
|
||||
configs.RLock()
|
||||
defer configs.RUnlock()
|
||||
return len(configs.config) > 0
|
||||
}
|
||||
|
||||
// SetLogger sets the logger for orm.
|
||||
func (c *Core) SetLogger(logger log.ILogger) {
|
||||
c.logger = logger
|
||||
}
|
||||
|
||||
// GetLogger returns the (logger) of the orm.
|
||||
func (c *Core) GetLogger() log.ILogger {
|
||||
return c.logger
|
||||
}
|
||||
|
||||
// SetMaxIdleConnCount sets the maximum number of connections in the idle
|
||||
// connection pool.
|
||||
//
|
||||
// If MaxOpenConns is greater than 0 but less than the new MaxIdleConns,
|
||||
// then the new MaxIdleConns will be reduced to match the MaxOpenConns limit.
|
||||
//
|
||||
// If n <= 0, no idle connections are retained.
|
||||
//
|
||||
// The default max idle connections is currently 2. This may change in
|
||||
// a future release.
|
||||
func (c *Core) SetMaxIdleConnCount(n int) {
|
||||
c.dynamicConfig.MaxIdleConnCount = n
|
||||
}
|
||||
|
||||
// SetMaxOpenConnCount sets the maximum number of open connections to the database.
|
||||
//
|
||||
// If MaxIdleConns is greater than 0 and the new MaxOpenConns is less than
|
||||
// MaxIdleConns, then MaxIdleConns will be reduced to match the new
|
||||
// MaxOpenConns limit.
|
||||
//
|
||||
// If n <= 0, then there is no limit on the number of open connections.
|
||||
// The default is 0 (unlimited).
|
||||
func (c *Core) SetMaxOpenConnCount(n int) {
|
||||
c.dynamicConfig.MaxOpenConnCount = n
|
||||
}
|
||||
|
||||
// SetMaxConnLifeTime sets the maximum amount of time a connection may be reused.
|
||||
//
|
||||
// Expired connections may be closed lazily before reuse.
|
||||
//
|
||||
// If d <= 0, connections are not closed due to a connection's age.
|
||||
func (c *Core) SetMaxConnLifeTime(d time.Duration) {
|
||||
c.dynamicConfig.MaxConnLifeTime = d
|
||||
}
|
||||
|
||||
// SetMaxIdleConnTime sets the maximum amount of time a connection may be idle before being closed.
|
||||
//
|
||||
// Idle connections may be closed lazily before reuse.
|
||||
//
|
||||
// If d <= 0, connections are not closed due to a connection's idle time.
|
||||
// This is Go 1.15+ feature: sql.DB.SetConnMaxIdleTime.
|
||||
func (c *Core) SetMaxIdleConnTime(d time.Duration) {
|
||||
c.dynamicConfig.MaxIdleConnTime = d
|
||||
}
|
||||
|
||||
// GetConfig returns the current used node configuration.
|
||||
func (c *Core) GetConfig() *ConfigNode {
|
||||
var configNode = c.getConfigNodeFromCtx(c.db.GetCtx())
|
||||
if configNode != nil {
|
||||
// Note:
|
||||
// It so here checks and returns the config from current DB,
|
||||
// if different schemas between current DB and config.Name from context,
|
||||
// for example, in nested transaction scenario, the context is passed all through the logic procedure,
|
||||
// but the config.Name from context may be still the original one from the first transaction object.
|
||||
if c.config.Name == configNode.Name {
|
||||
return configNode
|
||||
}
|
||||
}
|
||||
return c.config
|
||||
}
|
||||
|
||||
// SetDebug enables/disables the debug mode.
|
||||
func (c *Core) SetDebug(debug bool) {
|
||||
c.debug.Set(debug)
|
||||
}
|
||||
|
||||
// GetDebug returns the debug value.
|
||||
func (c *Core) GetDebug() bool {
|
||||
return c.debug.Val()
|
||||
}
|
||||
|
||||
// GetCache returns the internal cache object.
|
||||
func (c *Core) GetCache() *gcache.Cache {
|
||||
return c.cache
|
||||
}
|
||||
|
||||
// GetGroup returns the group string configured.
|
||||
func (c *Core) GetGroup() string {
|
||||
return c.group
|
||||
}
|
||||
|
||||
// SetDryRun enables/disables the DryRun feature.
|
||||
func (c *Core) SetDryRun(enabled bool) {
|
||||
c.config.DryRun = enabled
|
||||
}
|
||||
|
||||
// GetDryRun returns the DryRun value.
|
||||
func (c *Core) GetDryRun() bool {
|
||||
return c.config.DryRun || allDryRun
|
||||
}
|
||||
|
||||
// GetPrefix returns the table prefix string configured.
|
||||
func (c *Core) GetPrefix() string {
|
||||
return c.config.Prefix
|
||||
}
|
||||
|
||||
// GetSchema returns the schema configured.
|
||||
func (c *Core) GetSchema() string {
|
||||
schema := c.schema
|
||||
if schema == "" {
|
||||
schema = c.db.GetConfig().Name
|
||||
}
|
||||
return schema
|
||||
}
|
||||
|
||||
func parseConfigNodeLink(node *ConfigNode) (*ConfigNode, error) {
|
||||
var (
|
||||
link = node.Link
|
||||
match []string
|
||||
)
|
||||
if link != "" {
|
||||
// To be compatible with old configuration,
|
||||
// it checks and converts the link to new configuration.
|
||||
if node.Type != "" && !gstr.HasPrefix(link, node.Type+":") {
|
||||
link = fmt.Sprintf("%s:%s", node.Type, link)
|
||||
}
|
||||
match, _ = gregex.MatchString(linkPattern, link)
|
||||
if len(match) <= 5 {
|
||||
return nil, gerror.NewCodef(
|
||||
gcode.CodeInvalidParameter,
|
||||
`invalid link configuration: %s, shuold be pattern like: %s`,
|
||||
link, linkPatternDescription,
|
||||
)
|
||||
}
|
||||
node.Type = match[1]
|
||||
node.User = match[2]
|
||||
node.Pass = match[3]
|
||||
node.Protocol = match[4]
|
||||
array := gstr.Split(match[5], ":")
|
||||
if node.Protocol == "file" {
|
||||
node.Name = match[5]
|
||||
} else {
|
||||
if len(array) == 2 {
|
||||
// link with port.
|
||||
node.Host = array[0]
|
||||
node.Port = array[1]
|
||||
} else {
|
||||
// link without port.
|
||||
node.Host = array[0]
|
||||
}
|
||||
node.Name = match[6]
|
||||
}
|
||||
if len(match) > 6 && match[7] != "" {
|
||||
node.Extra = match[7]
|
||||
}
|
||||
}
|
||||
if node.Extra != "" {
|
||||
if m, _ := gstr.Parse(node.Extra); len(m) > 0 {
|
||||
_ = gconv.Struct(m, &node)
|
||||
}
|
||||
}
|
||||
// Default value checks.
|
||||
if node.Charset == "" {
|
||||
node.Charset = defaultCharset
|
||||
}
|
||||
if node.Protocol == "" {
|
||||
node.Protocol = defaultProtocol
|
||||
}
|
||||
return node, nil
|
||||
}
|
||||
|
|
@ -0,0 +1,98 @@
|
|||
// Copyright GoFrame Author(https://goframe.org). All Rights Reserved.
|
||||
//
|
||||
// This Source Code Form is subject to the terms of the MIT License.
|
||||
// If a copy of the MIT was not distributed with this file,
|
||||
// You can obtain one at https://github.com/gogf/gf.
|
||||
|
||||
package database
|
||||
|
||||
import (
|
||||
"context"
|
||||
"sync"
|
||||
|
||||
"github.com/gogf/gf/v2/errors/gcode"
|
||||
"github.com/gogf/gf/v2/errors/gerror"
|
||||
"github.com/gogf/gf/v2/os/gctx"
|
||||
)
|
||||
|
||||
// internalCtxData stores data in ctx for internal usage purpose.
|
||||
type internalCtxData struct {
|
||||
sync.Mutex
|
||||
// Used configuration node in current operation.
|
||||
ConfigNode *ConfigNode
|
||||
}
|
||||
|
||||
// column stores column data in ctx for internal usage purpose.
|
||||
type internalColumnData struct {
|
||||
// The first column in result response from database server.
|
||||
// This attribute is used for Value/Count selection statement purpose,
|
||||
// which is to avoid HOOK handler that might modify the result columns
|
||||
// that can confuse the Value/Count selection statement logic.
|
||||
FirstResultColumn string
|
||||
}
|
||||
|
||||
const (
|
||||
internalCtxDataKeyInCtx gctx.StrKey = "InternalCtxData"
|
||||
internalColumnDataKeyInCtx gctx.StrKey = "InternalColumnData"
|
||||
|
||||
// `ignoreResultKeyInCtx` is a mark for some db drivers that do not support `RowsAffected` function,
|
||||
// for example: `clickhouse`. The `clickhouse` does not support fetching insert/update results,
|
||||
// but returns errors when execute `RowsAffected`. It here ignores the calling of `RowsAffected`
|
||||
// to avoid triggering errors, rather than ignoring errors after they are triggered.
|
||||
ignoreResultKeyInCtx gctx.StrKey = "IgnoreResult"
|
||||
)
|
||||
|
||||
func (c *Core) injectInternalCtxData(ctx context.Context) context.Context {
|
||||
// If the internal data is already injected, it does nothing.
|
||||
if ctx.Value(internalCtxDataKeyInCtx) != nil {
|
||||
return ctx
|
||||
}
|
||||
return context.WithValue(ctx, internalCtxDataKeyInCtx, &internalCtxData{
|
||||
ConfigNode: c.config,
|
||||
})
|
||||
}
|
||||
|
||||
func (c *Core) setConfigNodeToCtx(ctx context.Context, node *ConfigNode) error {
|
||||
value := ctx.Value(internalCtxDataKeyInCtx)
|
||||
if value == nil {
|
||||
return gerror.NewCode(gcode.CodeInternalError, `no internal data found in context`)
|
||||
}
|
||||
|
||||
data := value.(*internalCtxData)
|
||||
data.Lock()
|
||||
defer data.Unlock()
|
||||
data.ConfigNode = node
|
||||
return nil
|
||||
}
|
||||
|
||||
func (c *Core) getConfigNodeFromCtx(ctx context.Context) *ConfigNode {
|
||||
if value := ctx.Value(internalCtxDataKeyInCtx); value != nil {
|
||||
data := value.(*internalCtxData)
|
||||
data.Lock()
|
||||
defer data.Unlock()
|
||||
return data.ConfigNode
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (c *Core) injectInternalColumn(ctx context.Context) context.Context {
|
||||
return context.WithValue(ctx, internalColumnDataKeyInCtx, &internalColumnData{})
|
||||
}
|
||||
|
||||
func (c *Core) getInternalColumnFromCtx(ctx context.Context) *internalColumnData {
|
||||
if v := ctx.Value(internalColumnDataKeyInCtx); v != nil {
|
||||
return v.(*internalColumnData)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (c *Core) InjectIgnoreResult(ctx context.Context) context.Context {
|
||||
if ctx.Value(ignoreResultKeyInCtx) != nil {
|
||||
return ctx
|
||||
}
|
||||
return context.WithValue(ctx, ignoreResultKeyInCtx, true)
|
||||
}
|
||||
|
||||
func (c *Core) GetIgnoreResultFromCtx(ctx context.Context) bool {
|
||||
return ctx.Value(ignoreResultKeyInCtx) != nil
|
||||
}
|
||||
|
|
@ -0,0 +1,43 @@
|
|||
// Copyright GoFrame Author(https://goframe.org). All Rights Reserved.
|
||||
//
|
||||
// This Source Code Form is subject to the terms of the MIT License.
|
||||
// If a copy of the MIT was not distributed with this file,
|
||||
// You can obtain one at https://github.com/gogf/gf.
|
||||
|
||||
package database
|
||||
|
||||
import (
|
||||
"database/sql"
|
||||
)
|
||||
|
||||
// dbLink is used to implement interface Link for DB.
|
||||
type dbLink struct {
|
||||
*sql.DB // Underlying DB object.
|
||||
isOnMaster bool // isOnMaster marks whether current link is operated on master node.
|
||||
}
|
||||
|
||||
// txLink is used to implement interface Link for TX.
|
||||
type txLink struct {
|
||||
*sql.Tx
|
||||
}
|
||||
|
||||
// IsTransaction returns if current Link is a transaction.
|
||||
func (l *dbLink) IsTransaction() bool {
|
||||
return false
|
||||
}
|
||||
|
||||
// IsOnMaster checks and returns whether current link is operated on master node.
|
||||
func (l *dbLink) IsOnMaster() bool {
|
||||
return l.isOnMaster
|
||||
}
|
||||
|
||||
// IsTransaction returns if current Link is a transaction.
|
||||
func (l *txLink) IsTransaction() bool {
|
||||
return true
|
||||
}
|
||||
|
||||
// IsOnMaster checks and returns whether current link is operated on master node.
|
||||
// Note that, transaction operation is always operated on master node.
|
||||
func (l *txLink) IsOnMaster() bool {
|
||||
return true
|
||||
}
|
||||
|
|
@ -0,0 +1,45 @@
|
|||
// Copyright GoFrame Author(https://goframe.org). All Rights Reserved.
|
||||
//
|
||||
// This Source Code Form is subject to the terms of the MIT License.
|
||||
// If a copy of the MIT was not distributed with this file,
|
||||
// You can obtain one at https://github.com/gogf/gf.
|
||||
//
|
||||
|
||||
package database
|
||||
|
||||
import (
|
||||
"context"
|
||||
"database/sql"
|
||||
)
|
||||
|
||||
type localStatsItem struct {
|
||||
node *ConfigNode
|
||||
stats sql.DBStats
|
||||
}
|
||||
|
||||
// Node returns the configuration node info.
|
||||
func (item *localStatsItem) Node() ConfigNode {
|
||||
return *item.node
|
||||
}
|
||||
|
||||
// Stats returns the connection stat for current node.
|
||||
func (item *localStatsItem) Stats() sql.DBStats {
|
||||
return item.stats
|
||||
}
|
||||
|
||||
// Stats retrieves and returns the pool stat for all nodes that have been established.
|
||||
func (c *Core) Stats(ctx context.Context) []StatsItem {
|
||||
var items = make([]StatsItem, 0)
|
||||
c.links.Iterator(func(k ConfigNode, v *sql.DB) bool {
|
||||
// Create a local copy of k to avoid loop variable address re-use issue
|
||||
// In Go, loop variables are re-used with the same memory address across iterations,
|
||||
// directly using &k would cause all localStatsItem instances to share the same address
|
||||
node := k
|
||||
items = append(items, &localStatsItem{
|
||||
node: &node,
|
||||
stats: v.Stats(),
|
||||
})
|
||||
return true
|
||||
})
|
||||
return items
|
||||
}
|
||||
|
|
@ -0,0 +1,512 @@
|
|||
// Copyright GoFrame Author(https://goframe.org). All Rights Reserved.
|
||||
//
|
||||
// This Source Code Form is subject to the terms of the MIT License.
|
||||
// If a copy of the MIT was not distributed with this file,
|
||||
// You can obtain one at https://github.com/gogf/gf.
|
||||
|
||||
package database
|
||||
|
||||
import (
|
||||
"context"
|
||||
"database/sql/driver"
|
||||
"math/big"
|
||||
"reflect"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"git.magicany.cc/black1552/gin-base/database/intlog"
|
||||
"git.magicany.cc/black1552/gin-base/database/json"
|
||||
"github.com/gogf/gf/v2/encoding/gbinary"
|
||||
"github.com/gogf/gf/v2/errors/gerror"
|
||||
"github.com/gogf/gf/v2/os/gtime"
|
||||
"github.com/gogf/gf/v2/text/gregex"
|
||||
"github.com/gogf/gf/v2/text/gstr"
|
||||
"github.com/gogf/gf/v2/util/gconv"
|
||||
"github.com/gogf/gf/v2/util/gutil"
|
||||
)
|
||||
|
||||
// GetFieldTypeStr retrieves and returns the field type string for certain field by name.
|
||||
func (c *Core) GetFieldTypeStr(ctx context.Context, fieldName, table, schema string) string {
|
||||
field := c.GetFieldType(ctx, fieldName, table, schema)
|
||||
if field != nil {
|
||||
// Kinds of data type examples:
|
||||
// year(4)
|
||||
// datetime
|
||||
// varchar(64)
|
||||
// bigint(20)
|
||||
// int(10) unsigned
|
||||
typeName := gstr.StrTillEx(field.Type, "(") // int(10) unsigned -> int
|
||||
if typeName != "" {
|
||||
typeName = gstr.Trim(typeName)
|
||||
} else {
|
||||
typeName = field.Type
|
||||
}
|
||||
return typeName
|
||||
}
|
||||
return ""
|
||||
}
|
||||
|
||||
// GetFieldType retrieves and returns the field type object for certain field by name.
|
||||
func (c *Core) GetFieldType(ctx context.Context, fieldName, table, schema string) *TableField {
|
||||
fieldsMap, err := c.db.TableFields(ctx, table, schema)
|
||||
if err != nil {
|
||||
intlog.Errorf(
|
||||
ctx,
|
||||
`TableFields failed for table "%s", schema "%s": %+v`,
|
||||
table, schema, err,
|
||||
)
|
||||
return nil
|
||||
}
|
||||
for tableFieldName, tableField := range fieldsMap {
|
||||
if tableFieldName == fieldName {
|
||||
return tableField
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// ConvertDataForRecord is a very important function, which does converting for any data that
|
||||
// will be inserted into table/collection as a record.
|
||||
//
|
||||
// The parameter `value` should be type of *map/map/*struct/struct.
|
||||
// It supports embedded struct definition for struct.
|
||||
func (c *Core) ConvertDataForRecord(ctx context.Context, value any, table string) (map[string]any, error) {
|
||||
var (
|
||||
err error
|
||||
data = MapOrStructToMapDeep(value, true)
|
||||
)
|
||||
for fieldName, fieldValue := range data {
|
||||
var fieldType = c.GetFieldTypeStr(ctx, fieldName, table, c.GetSchema())
|
||||
data[fieldName], err = c.db.ConvertValueForField(
|
||||
ctx,
|
||||
fieldType,
|
||||
fieldValue,
|
||||
)
|
||||
if err != nil {
|
||||
return nil, gerror.Wrapf(err, `ConvertDataForRecord failed for value: %#v`, fieldValue)
|
||||
}
|
||||
}
|
||||
return data, nil
|
||||
}
|
||||
|
||||
// ConvertValueForField converts value to the type of the record field.
|
||||
// The parameter `fieldType` is the target record field.
|
||||
// The parameter `fieldValue` is the value that to be committed to record field.
|
||||
func (c *Core) ConvertValueForField(ctx context.Context, fieldType string, fieldValue any) (any, error) {
|
||||
var (
|
||||
err error
|
||||
convertedValue = fieldValue
|
||||
)
|
||||
switch fieldValue.(type) {
|
||||
case time.Time, *time.Time, gtime.Time, *gtime.Time:
|
||||
goto Default
|
||||
}
|
||||
// If `value` implements interface `driver.Valuer`, it then uses the interface for value converting.
|
||||
if valuer, ok := fieldValue.(driver.Valuer); ok {
|
||||
if convertedValue, err = valuer.Value(); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return convertedValue, nil
|
||||
}
|
||||
Default:
|
||||
// Default value converting.
|
||||
var (
|
||||
rvValue = reflect.ValueOf(fieldValue)
|
||||
rvKind = rvValue.Kind()
|
||||
)
|
||||
for rvKind == reflect.Pointer {
|
||||
rvValue = rvValue.Elem()
|
||||
rvKind = rvValue.Kind()
|
||||
}
|
||||
switch rvKind {
|
||||
case reflect.Invalid:
|
||||
convertedValue = nil
|
||||
|
||||
case reflect.Slice, reflect.Array, reflect.Map:
|
||||
// It should ignore the bytes type.
|
||||
if _, ok := fieldValue.([]byte); !ok {
|
||||
// Convert the value to JSON.
|
||||
convertedValue, err = json.Marshal(fieldValue)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
}
|
||||
case reflect.Struct:
|
||||
switch r := fieldValue.(type) {
|
||||
// If the time is zero, it then updates it to nil,
|
||||
// which will insert/update the value to database as "null".
|
||||
case time.Time:
|
||||
if r.IsZero() {
|
||||
convertedValue = nil
|
||||
} else {
|
||||
switch fieldType {
|
||||
case fieldTypeYear:
|
||||
convertedValue = r.Format("2006")
|
||||
case fieldTypeDate:
|
||||
convertedValue = r.Format("2006-01-02")
|
||||
case fieldTypeTime:
|
||||
convertedValue = r.Format("15:04:05")
|
||||
default:
|
||||
}
|
||||
}
|
||||
|
||||
case *time.Time:
|
||||
if r == nil {
|
||||
// Nothing to do.
|
||||
} else {
|
||||
switch fieldType {
|
||||
case fieldTypeYear:
|
||||
convertedValue = r.Format("2006")
|
||||
case fieldTypeDate:
|
||||
convertedValue = r.Format("2006-01-02")
|
||||
case fieldTypeTime:
|
||||
convertedValue = r.Format("15:04:05")
|
||||
default:
|
||||
}
|
||||
}
|
||||
|
||||
case gtime.Time:
|
||||
if r.IsZero() {
|
||||
convertedValue = nil
|
||||
} else {
|
||||
switch fieldType {
|
||||
case fieldTypeYear:
|
||||
convertedValue = r.Layout("2006")
|
||||
case fieldTypeDate:
|
||||
convertedValue = r.Layout("2006-01-02")
|
||||
case fieldTypeTime:
|
||||
convertedValue = r.Layout("15:04:05")
|
||||
default:
|
||||
convertedValue = r.Time
|
||||
}
|
||||
}
|
||||
|
||||
case *gtime.Time:
|
||||
if r.IsZero() {
|
||||
convertedValue = nil
|
||||
} else {
|
||||
switch fieldType {
|
||||
case fieldTypeYear:
|
||||
convertedValue = r.Layout("2006")
|
||||
case fieldTypeDate:
|
||||
convertedValue = r.Layout("2006-01-02")
|
||||
case fieldTypeTime:
|
||||
convertedValue = r.Layout("15:04:05")
|
||||
default:
|
||||
convertedValue = r.Time
|
||||
}
|
||||
}
|
||||
|
||||
case Counter, *Counter:
|
||||
// Nothing to do.
|
||||
|
||||
default:
|
||||
// If `value` implements interface iNil,
|
||||
// check its IsNil() function, if got ture,
|
||||
// which will insert/update the value to database as "null".
|
||||
if v, ok := fieldValue.(iNil); ok && v.IsNil() {
|
||||
convertedValue = nil
|
||||
} else if s, ok := fieldValue.(iString); ok {
|
||||
// Use string conversion in default.
|
||||
convertedValue = s.String()
|
||||
} else {
|
||||
// Convert the value to JSON.
|
||||
convertedValue, err = json.Marshal(fieldValue)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
}
|
||||
}
|
||||
default:
|
||||
}
|
||||
|
||||
return convertedValue, nil
|
||||
}
|
||||
|
||||
// GetFormattedDBTypeNameForField retrieves and returns the formatted database type name
|
||||
// eg. `int(10) unsigned` -> `int`, `varchar(100)` -> `varchar`, etc.
|
||||
func (c *Core) GetFormattedDBTypeNameForField(fieldType string) (typeName, typePattern string) {
|
||||
match, _ := gregex.MatchString(`(.+?)\((.+)\)`, fieldType)
|
||||
if len(match) == 3 {
|
||||
typeName = gstr.Trim(match[1])
|
||||
typePattern = gstr.Trim(match[2])
|
||||
} else {
|
||||
var array = gstr.SplitAndTrim(fieldType, " ")
|
||||
if len(array) > 1 && gstr.Equal(array[0], "unsigned") {
|
||||
typeName = array[1]
|
||||
} else if len(array) > 0 {
|
||||
typeName = array[0]
|
||||
}
|
||||
}
|
||||
typeName = strings.ToLower(typeName)
|
||||
return
|
||||
}
|
||||
|
||||
// CheckLocalTypeForField checks and returns corresponding type for given db type.
|
||||
// The `fieldType` is retrieved from ColumnTypes of db driver, example:
|
||||
// UNSIGNED INT
|
||||
func (c *Core) CheckLocalTypeForField(ctx context.Context, fieldType string, _ any) (LocalType, error) {
|
||||
var (
|
||||
typeName string
|
||||
typePattern string
|
||||
)
|
||||
typeName, typePattern = c.GetFormattedDBTypeNameForField(fieldType)
|
||||
switch typeName {
|
||||
case
|
||||
fieldTypeBinary,
|
||||
fieldTypeVarbinary,
|
||||
fieldTypeBlob,
|
||||
fieldTypeTinyblob,
|
||||
fieldTypeMediumblob,
|
||||
fieldTypeLongblob:
|
||||
return LocalTypeBytes, nil
|
||||
|
||||
case
|
||||
fieldTypeInt,
|
||||
fieldTypeTinyint,
|
||||
fieldTypeSmallInt,
|
||||
fieldTypeSmallint,
|
||||
fieldTypeMediumInt,
|
||||
fieldTypeMediumint,
|
||||
fieldTypeSerial:
|
||||
if gstr.ContainsI(fieldType, "unsigned") {
|
||||
return LocalTypeUint, nil
|
||||
}
|
||||
return LocalTypeInt, nil
|
||||
|
||||
case
|
||||
fieldTypeBigInt,
|
||||
fieldTypeBigint,
|
||||
fieldTypeBigserial:
|
||||
if gstr.ContainsI(fieldType, "unsigned") {
|
||||
return LocalTypeUint64, nil
|
||||
}
|
||||
return LocalTypeInt64, nil
|
||||
|
||||
case
|
||||
fieldTypeInt128,
|
||||
fieldTypeInt256,
|
||||
fieldTypeUint128,
|
||||
fieldTypeUint256:
|
||||
return LocalTypeBigInt, nil
|
||||
|
||||
case
|
||||
fieldTypeReal:
|
||||
return LocalTypeFloat32, nil
|
||||
|
||||
case
|
||||
fieldTypeDecimal,
|
||||
fieldTypeMoney,
|
||||
fieldTypeNumeric,
|
||||
fieldTypeSmallmoney:
|
||||
return LocalTypeString, nil
|
||||
case
|
||||
fieldTypeFloat,
|
||||
fieldTypeDouble:
|
||||
return LocalTypeFloat64, nil
|
||||
|
||||
case
|
||||
fieldTypeBit:
|
||||
// It is suggested using bit(1) as boolean.
|
||||
if typePattern == "1" {
|
||||
return LocalTypeBool, nil
|
||||
}
|
||||
if gstr.ContainsI(fieldType, "unsigned") {
|
||||
return LocalTypeUint64Bytes, nil
|
||||
}
|
||||
return LocalTypeInt64Bytes, nil
|
||||
|
||||
case
|
||||
fieldTypeBool:
|
||||
return LocalTypeBool, nil
|
||||
|
||||
case
|
||||
fieldTypeDate:
|
||||
return LocalTypeDate, nil
|
||||
|
||||
case
|
||||
fieldTypeTime:
|
||||
return LocalTypeTime, nil
|
||||
|
||||
case
|
||||
fieldTypeDatetime,
|
||||
fieldTypeTimestamp,
|
||||
fieldTypeTimestampz:
|
||||
return LocalTypeDatetime, nil
|
||||
|
||||
case
|
||||
fieldTypeJson:
|
||||
return LocalTypeJson, nil
|
||||
|
||||
case
|
||||
fieldTypeJsonb:
|
||||
return LocalTypeJsonb, nil
|
||||
|
||||
default:
|
||||
// Auto-detect field type, using key match.
|
||||
switch {
|
||||
case strings.Contains(typeName, "text") || strings.Contains(typeName, "char") || strings.Contains(typeName, "character"):
|
||||
return LocalTypeString, nil
|
||||
|
||||
case strings.Contains(typeName, "float") || strings.Contains(typeName, "double") || strings.Contains(typeName, "numeric"):
|
||||
return LocalTypeFloat64, nil
|
||||
|
||||
case strings.Contains(typeName, "bool"):
|
||||
return LocalTypeBool, nil
|
||||
|
||||
case strings.Contains(typeName, "binary") || strings.Contains(typeName, "blob"):
|
||||
return LocalTypeBytes, nil
|
||||
|
||||
case strings.Contains(typeName, "int"):
|
||||
if gstr.ContainsI(fieldType, "unsigned") {
|
||||
return LocalTypeUint, nil
|
||||
}
|
||||
return LocalTypeInt, nil
|
||||
|
||||
case strings.Contains(typeName, "time"):
|
||||
return LocalTypeDatetime, nil
|
||||
|
||||
case strings.Contains(typeName, "date"):
|
||||
return LocalTypeDatetime, nil
|
||||
|
||||
default:
|
||||
return LocalTypeString, nil
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// ConvertValueForLocal converts value to local Golang type of value according field type name from database.
|
||||
// The parameter `fieldType` is in lower case, like:
|
||||
// `float(5,2)`, `unsigned double(5,2)`, `decimal(10,2)`, `char(45)`, `varchar(100)`, etc.
|
||||
func (c *Core) ConvertValueForLocal(
|
||||
ctx context.Context, fieldType string, fieldValue any,
|
||||
) (any, error) {
|
||||
// If there's no type retrieved, it returns the `fieldValue` directly
|
||||
// to use its original data type, as `fieldValue` is type of any.
|
||||
if fieldType == "" {
|
||||
return fieldValue, nil
|
||||
}
|
||||
typeName, err := c.db.CheckLocalTypeForField(ctx, fieldType, fieldValue)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
switch typeName {
|
||||
case LocalTypeBytes:
|
||||
var typeNameStr = string(typeName)
|
||||
if strings.Contains(typeNameStr, "binary") || strings.Contains(typeNameStr, "blob") {
|
||||
return fieldValue, nil
|
||||
}
|
||||
return gconv.Bytes(fieldValue), nil
|
||||
|
||||
case LocalTypeInt:
|
||||
return gconv.Int(gconv.String(fieldValue)), nil
|
||||
|
||||
case LocalTypeUint:
|
||||
return gconv.Uint(gconv.String(fieldValue)), nil
|
||||
|
||||
case LocalTypeInt64:
|
||||
return gconv.Int64(gconv.String(fieldValue)), nil
|
||||
|
||||
case LocalTypeUint64:
|
||||
return gconv.Uint64(gconv.String(fieldValue)), nil
|
||||
|
||||
case LocalTypeInt64Bytes:
|
||||
return gbinary.BeDecodeToInt64(gconv.Bytes(fieldValue)), nil
|
||||
|
||||
case LocalTypeUint64Bytes:
|
||||
return gbinary.BeDecodeToUint64(gconv.Bytes(fieldValue)), nil
|
||||
|
||||
case LocalTypeBigInt:
|
||||
switch v := fieldValue.(type) {
|
||||
case big.Int:
|
||||
return v.String(), nil
|
||||
case *big.Int:
|
||||
return v.String(), nil
|
||||
default:
|
||||
return gconv.String(fieldValue), nil
|
||||
}
|
||||
|
||||
case LocalTypeFloat32:
|
||||
return gconv.Float32(gconv.String(fieldValue)), nil
|
||||
|
||||
case LocalTypeFloat64:
|
||||
return gconv.Float64(gconv.String(fieldValue)), nil
|
||||
|
||||
case LocalTypeBool:
|
||||
s := gconv.String(fieldValue)
|
||||
// mssql is true|false string.
|
||||
if strings.EqualFold(s, "true") {
|
||||
return 1, nil
|
||||
}
|
||||
if strings.EqualFold(s, "false") {
|
||||
return 0, nil
|
||||
}
|
||||
return gconv.Bool(fieldValue), nil
|
||||
|
||||
case LocalTypeDate:
|
||||
if t, ok := fieldValue.(time.Time); ok {
|
||||
return gtime.NewFromTime(t).Format("Y-m-d"), nil
|
||||
}
|
||||
t, _ := gtime.StrToTime(gconv.String(fieldValue))
|
||||
return t.Format("Y-m-d"), nil
|
||||
|
||||
case LocalTypeTime:
|
||||
if t, ok := fieldValue.(time.Time); ok {
|
||||
return gtime.NewFromTime(t).Format("H:i:s"), nil
|
||||
}
|
||||
t, _ := gtime.StrToTime(gconv.String(fieldValue))
|
||||
return t.Format("H:i:s"), nil
|
||||
|
||||
case LocalTypeDatetime:
|
||||
if t, ok := fieldValue.(time.Time); ok {
|
||||
return gtime.NewFromTime(t), nil
|
||||
}
|
||||
t, _ := gtime.StrToTime(gconv.String(fieldValue))
|
||||
return t, nil
|
||||
|
||||
default:
|
||||
return gconv.String(fieldValue), nil
|
||||
}
|
||||
}
|
||||
|
||||
// mappingAndFilterData automatically mappings the map key to table field and removes
|
||||
// all key-value pairs that are not the field of given table.
|
||||
func (c *Core) mappingAndFilterData(ctx context.Context, schema, table string, data map[string]any, filter bool) (map[string]any, error) {
|
||||
fieldsMap, err := c.db.TableFields(ctx, c.guessPrimaryTableName(table), schema)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if len(fieldsMap) == 0 {
|
||||
return nil, gerror.Newf(`The table %s may not exist, or the table contains no fields`, table)
|
||||
}
|
||||
fieldsKeyMap := make(map[string]any, len(fieldsMap))
|
||||
for k := range fieldsMap {
|
||||
fieldsKeyMap[k] = nil
|
||||
}
|
||||
// Automatic data key to table field name mapping.
|
||||
var foundKey string
|
||||
for dataKey, dataValue := range data {
|
||||
if _, ok := fieldsKeyMap[dataKey]; !ok {
|
||||
foundKey, _ = gutil.MapPossibleItemByKey(fieldsKeyMap, dataKey)
|
||||
if foundKey != "" {
|
||||
if _, ok = data[foundKey]; !ok {
|
||||
data[foundKey] = dataValue
|
||||
}
|
||||
delete(data, dataKey)
|
||||
}
|
||||
}
|
||||
}
|
||||
// Data filtering.
|
||||
// It deletes all key-value pairs that has incorrect field name.
|
||||
if filter {
|
||||
for dataKey := range data {
|
||||
if _, ok := fieldsMap[dataKey]; !ok {
|
||||
delete(data, dataKey)
|
||||
}
|
||||
}
|
||||
if len(data) == 0 {
|
||||
return nil, gerror.Newf(`input data match no fields in table %s`, table)
|
||||
}
|
||||
}
|
||||
return data, nil
|
||||
}
|
||||
|
|
@ -0,0 +1,84 @@
|
|||
// Copyright GoFrame Author(https://goframe.org). All Rights Reserved.
|
||||
//
|
||||
// This Source Code Form is subject to the terms of the MIT License.
|
||||
// If a copy of the MIT was not distributed with this file,
|
||||
// You can obtain one at https://github.com/gogf/gf.
|
||||
//
|
||||
|
||||
package database
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
|
||||
"go.opentelemetry.io/otel/attribute"
|
||||
"go.opentelemetry.io/otel/codes"
|
||||
semconv "go.opentelemetry.io/otel/semconv/v1.18.0"
|
||||
"go.opentelemetry.io/otel/trace"
|
||||
|
||||
"github.com/gogf/gf/v2/net/gtrace"
|
||||
)
|
||||
|
||||
const (
|
||||
traceInstrumentName = "github.com/gogf/gf/v2/database/gdb"
|
||||
traceAttrDbType = "db.type"
|
||||
traceAttrDbHost = "db.host"
|
||||
traceAttrDbPort = "db.port"
|
||||
traceAttrDbName = "db.name"
|
||||
traceAttrDbUser = "db.user"
|
||||
traceAttrDbLink = "db.link"
|
||||
traceAttrDbGroup = "db.group"
|
||||
traceEventDbExecution = "db.execution"
|
||||
traceEventDbExecutionCost = "db.execution.cost"
|
||||
traceEventDbExecutionRows = "db.execution.rows"
|
||||
traceEventDbExecutionTxID = "db.execution.txid"
|
||||
traceEventDbExecutionType = "db.execution.type"
|
||||
)
|
||||
|
||||
// addSqlToTracing adds sql information to tracer if it's enabled.
|
||||
func (c *Core) traceSpanEnd(ctx context.Context, span trace.Span, sql *Sql) {
|
||||
if gtrace.IsUsingDefaultProvider() || !gtrace.IsTracingInternal() {
|
||||
return
|
||||
}
|
||||
if sql.Error != nil {
|
||||
span.SetStatus(codes.Error, fmt.Sprintf(`%+v`, sql.Error))
|
||||
}
|
||||
labels := make([]attribute.KeyValue, 0)
|
||||
labels = append(labels, gtrace.CommonLabels()...)
|
||||
labels = append(labels,
|
||||
attribute.String(traceAttrDbType, c.db.GetConfig().Type),
|
||||
semconv.DBStatement(sql.Format),
|
||||
)
|
||||
if c.db.GetConfig().Host != "" {
|
||||
labels = append(labels, attribute.String(traceAttrDbHost, c.db.GetConfig().Host))
|
||||
}
|
||||
if c.db.GetConfig().Port != "" {
|
||||
labels = append(labels, attribute.String(traceAttrDbPort, c.db.GetConfig().Port))
|
||||
}
|
||||
if c.db.GetConfig().Name != "" {
|
||||
labels = append(labels, attribute.String(traceAttrDbName, c.db.GetConfig().Name))
|
||||
}
|
||||
if c.db.GetConfig().User != "" {
|
||||
labels = append(labels, attribute.String(traceAttrDbUser, c.db.GetConfig().User))
|
||||
}
|
||||
if filteredLink := c.db.GetCore().FilteredLink(); filteredLink != "" {
|
||||
labels = append(labels, attribute.String(traceAttrDbLink, c.db.GetCore().FilteredLink()))
|
||||
}
|
||||
if group := c.db.GetGroup(); group != "" {
|
||||
labels = append(labels, attribute.String(traceAttrDbGroup, group))
|
||||
}
|
||||
span.SetAttributes(labels...)
|
||||
events := []attribute.KeyValue{
|
||||
attribute.String(traceEventDbExecutionCost, fmt.Sprintf(`%d ms`, sql.End-sql.Start)),
|
||||
attribute.String(traceEventDbExecutionRows, fmt.Sprintf(`%d`, sql.RowsAffected)),
|
||||
}
|
||||
if sql.IsTransaction {
|
||||
if v := ctx.Value(transactionIdForLoggerCtx); v != nil {
|
||||
events = append(events, attribute.String(
|
||||
traceEventDbExecutionTxID, fmt.Sprintf(`%d`, v.(uint64)),
|
||||
))
|
||||
}
|
||||
}
|
||||
events = append(events, attribute.String(traceEventDbExecutionType, string(sql.Type)))
|
||||
span.AddEvent(traceEventDbExecution, trace.WithAttributes(events...))
|
||||
}
|
||||
|
|
@ -0,0 +1,295 @@
|
|||
// Copyright GoFrame Author(https://goframe.org). All Rights Reserved.
|
||||
//
|
||||
// This Source Code Form is subject to the terms of the MIT License.
|
||||
// If a copy of the MIT was not distributed with this file,
|
||||
// You can obtain one at https://github.com/gogf/gf.
|
||||
|
||||
package database
|
||||
|
||||
import (
|
||||
"context"
|
||||
"database/sql"
|
||||
|
||||
"github.com/gogf/gf/v2/container/gtype"
|
||||
"github.com/gogf/gf/v2/errors/gcode"
|
||||
"github.com/gogf/gf/v2/errors/gerror"
|
||||
)
|
||||
|
||||
// Propagation defines transaction propagation behavior.
|
||||
type Propagation string
|
||||
|
||||
const (
|
||||
// PropagationNested starts a nested transaction if already in a transaction,
|
||||
// or behaves like PropagationRequired if not in a transaction.
|
||||
//
|
||||
// It is the default behavior.
|
||||
PropagationNested Propagation = "NESTED"
|
||||
|
||||
// PropagationRequired starts a new transaction if not in a transaction,
|
||||
// or uses the existing transaction if already in a transaction.
|
||||
PropagationRequired Propagation = "REQUIRED"
|
||||
|
||||
// PropagationSupports executes within the existing transaction if present,
|
||||
// otherwise executes without transaction.
|
||||
PropagationSupports Propagation = "SUPPORTS"
|
||||
|
||||
// PropagationRequiresNew starts a new transaction, and suspends the current transaction if one exists.
|
||||
PropagationRequiresNew Propagation = "REQUIRES_NEW"
|
||||
|
||||
// PropagationNotSupported executes non-transactional, suspends any existing transaction.
|
||||
PropagationNotSupported Propagation = "NOT_SUPPORTED"
|
||||
|
||||
// PropagationMandatory executes in a transaction, fails if no existing transaction.
|
||||
PropagationMandatory Propagation = "MANDATORY"
|
||||
|
||||
// PropagationNever executes non-transactional, fails if in an existing transaction.
|
||||
PropagationNever Propagation = "NEVER"
|
||||
)
|
||||
|
||||
// TxOptions defines options for transaction control.
|
||||
type TxOptions struct {
|
||||
// Propagation specifies the propagation behavior.
|
||||
Propagation Propagation
|
||||
// Isolation is the transaction isolation level.
|
||||
// If zero, the driver or database's default level is used.
|
||||
Isolation sql.IsolationLevel
|
||||
// ReadOnly is used to mark the transaction as read-only.
|
||||
ReadOnly bool
|
||||
}
|
||||
|
||||
// Context key types for transaction to avoid collisions
|
||||
type transactionCtxKey string
|
||||
|
||||
const (
|
||||
transactionPointerPrefix = "transaction"
|
||||
contextTransactionKeyPrefix = "TransactionObjectForGroup_"
|
||||
transactionIdForLoggerCtx transactionCtxKey = "TransactionId"
|
||||
)
|
||||
|
||||
var transactionIdGenerator = gtype.NewUint64()
|
||||
|
||||
// DefaultTxOptions returns the default transaction options.
|
||||
func DefaultTxOptions() TxOptions {
|
||||
return TxOptions{
|
||||
// Note the default propagation type is PropagationNested not PropagationRequired.
|
||||
Propagation: PropagationNested,
|
||||
}
|
||||
}
|
||||
|
||||
// Begin starts and returns the transaction object.
|
||||
// You should call Commit or Rollback functions of the transaction object
|
||||
// if you no longer use the transaction. Commit or Rollback functions will also
|
||||
// close the transaction automatically.
|
||||
func (c *Core) Begin(ctx context.Context) (tx TX, err error) {
|
||||
return c.BeginWithOptions(ctx, DefaultTxOptions())
|
||||
}
|
||||
|
||||
// BeginWithOptions starts and returns the transaction object with given options.
|
||||
// The options allow specifying the isolation level and read-only mode.
|
||||
// You should call Commit or Rollback functions of the transaction object
|
||||
// if you no longer use the transaction. Commit or Rollback functions will also
|
||||
// close the transaction automatically.
|
||||
func (c *Core) BeginWithOptions(ctx context.Context, opts TxOptions) (tx TX, err error) {
|
||||
if ctx == nil {
|
||||
ctx = c.db.GetCtx()
|
||||
}
|
||||
ctx = c.injectInternalCtxData(ctx)
|
||||
return c.doBeginCtx(ctx, sql.TxOptions{
|
||||
Isolation: opts.Isolation,
|
||||
ReadOnly: opts.ReadOnly,
|
||||
})
|
||||
}
|
||||
|
||||
func (c *Core) doBeginCtx(ctx context.Context, opts sql.TxOptions) (TX, error) {
|
||||
master, err := c.db.Master()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
var out DoCommitOutput
|
||||
out, err = c.db.DoCommit(ctx, DoCommitInput{
|
||||
Db: master,
|
||||
Sql: "BEGIN",
|
||||
Type: SqlTypeBegin,
|
||||
TxOptions: opts,
|
||||
IsTransaction: true,
|
||||
})
|
||||
return out.Tx, err
|
||||
}
|
||||
|
||||
// Transaction wraps the transaction logic using function `f`.
|
||||
// It rollbacks the transaction and returns the error from function `f` if
|
||||
// it returns non-nil error. It commits the transaction and returns nil if
|
||||
// function `f` returns nil.
|
||||
//
|
||||
// Note that, you should not Commit or Rollback the transaction in function `f`
|
||||
// as it is automatically handled by this function.
|
||||
func (c *Core) Transaction(ctx context.Context, f func(ctx context.Context, tx TX) error) (err error) {
|
||||
return c.TransactionWithOptions(ctx, DefaultTxOptions(), f)
|
||||
}
|
||||
|
||||
// TransactionWithOptions wraps the transaction logic with propagation options using function `f`.
|
||||
func (c *Core) TransactionWithOptions(
|
||||
ctx context.Context, opts TxOptions, f func(ctx context.Context, tx TX) error,
|
||||
) (err error) {
|
||||
if ctx == nil {
|
||||
ctx = c.db.GetCtx()
|
||||
}
|
||||
ctx = c.injectInternalCtxData(ctx)
|
||||
|
||||
// Check current transaction from context
|
||||
var (
|
||||
group = c.db.GetGroup()
|
||||
currentTx = TXFromCtx(ctx, group)
|
||||
)
|
||||
switch opts.Propagation {
|
||||
case PropagationRequired:
|
||||
if currentTx != nil {
|
||||
return f(ctx, currentTx)
|
||||
}
|
||||
return c.createNewTransaction(ctx, opts, f)
|
||||
|
||||
case PropagationSupports:
|
||||
if currentTx == nil {
|
||||
currentTx = c.newEmptyTX()
|
||||
}
|
||||
return f(ctx, currentTx)
|
||||
|
||||
case PropagationMandatory:
|
||||
if currentTx == nil {
|
||||
return gerror.NewCode(
|
||||
gcode.CodeInvalidOperation,
|
||||
"transaction propagation MANDATORY requires an existing transaction",
|
||||
)
|
||||
}
|
||||
return f(ctx, currentTx)
|
||||
|
||||
case PropagationRequiresNew:
|
||||
ctx = WithoutTX(ctx, group)
|
||||
return c.createNewTransaction(ctx, opts, f)
|
||||
|
||||
case PropagationNotSupported:
|
||||
ctx = WithoutTX(ctx, group)
|
||||
return f(ctx, c.newEmptyTX())
|
||||
|
||||
case PropagationNever:
|
||||
if currentTx != nil {
|
||||
return gerror.NewCode(
|
||||
gcode.CodeInvalidOperation,
|
||||
"transaction propagation NEVER cannot run within an existing transaction",
|
||||
)
|
||||
}
|
||||
ctx = WithoutTX(ctx, group)
|
||||
return f(ctx, c.newEmptyTX())
|
||||
|
||||
case PropagationNested:
|
||||
if currentTx != nil {
|
||||
return currentTx.Transaction(ctx, f)
|
||||
}
|
||||
return c.createNewTransaction(ctx, opts, f)
|
||||
|
||||
default:
|
||||
return gerror.NewCodef(
|
||||
gcode.CodeInvalidParameter,
|
||||
"unsupported propagation behavior: %s",
|
||||
opts.Propagation,
|
||||
)
|
||||
}
|
||||
}
|
||||
|
||||
// createNewTransaction handles creating and managing a new transaction
|
||||
func (c *Core) createNewTransaction(
|
||||
ctx context.Context, opts TxOptions, f func(ctx context.Context, tx TX) error,
|
||||
) (err error) {
|
||||
// Begin transaction with options
|
||||
tx, err := c.doBeginCtx(ctx, sql.TxOptions{
|
||||
Isolation: opts.Isolation,
|
||||
ReadOnly: opts.ReadOnly,
|
||||
})
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// Inject transaction object into context
|
||||
ctx = WithTX(tx.GetCtx(), tx)
|
||||
err = callTxFunc(tx.Ctx(ctx), f)
|
||||
return
|
||||
}
|
||||
|
||||
func callTxFunc(tx TX, f func(ctx context.Context, tx TX) error) (err error) {
|
||||
defer func() {
|
||||
if err == nil {
|
||||
if exception := recover(); exception != nil {
|
||||
if v, ok := exception.(error); ok && gerror.HasStack(v) {
|
||||
err = v
|
||||
} else {
|
||||
err = gerror.NewCodef(gcode.CodeInternalPanic, "%+v", exception)
|
||||
}
|
||||
}
|
||||
}
|
||||
if err != nil {
|
||||
if e := tx.Rollback(); e != nil {
|
||||
err = e
|
||||
}
|
||||
} else {
|
||||
if e := tx.Commit(); e != nil {
|
||||
err = e
|
||||
}
|
||||
}
|
||||
}()
|
||||
err = f(tx.GetCtx(), tx)
|
||||
return
|
||||
}
|
||||
|
||||
// WithTX injects given transaction object into context and returns a new context.
|
||||
func WithTX(ctx context.Context, tx TX) context.Context {
|
||||
if tx == nil {
|
||||
return ctx
|
||||
}
|
||||
// Check repeat injection from given.
|
||||
group := tx.GetDB().GetGroup()
|
||||
if ctxTx := TXFromCtx(ctx, group); ctxTx != nil && ctxTx.GetDB().GetGroup() == group {
|
||||
return ctx
|
||||
}
|
||||
dbCtx := tx.GetDB().GetCtx()
|
||||
if ctxTx := TXFromCtx(dbCtx, group); ctxTx != nil && ctxTx.GetDB().GetGroup() == group {
|
||||
return dbCtx
|
||||
}
|
||||
// Inject transaction object and id into context.
|
||||
ctx = context.WithValue(ctx, transactionKeyForContext(group), tx)
|
||||
ctx = context.WithValue(ctx, transactionIdForLoggerCtx, tx.GetCtx().Value(transactionIdForLoggerCtx))
|
||||
return ctx
|
||||
}
|
||||
|
||||
// WithoutTX removed transaction object from context and returns a new context.
|
||||
func WithoutTX(ctx context.Context, group string) context.Context {
|
||||
ctx = context.WithValue(ctx, transactionKeyForContext(group), nil)
|
||||
ctx = context.WithValue(ctx, transactionIdForLoggerCtx, nil)
|
||||
return ctx
|
||||
}
|
||||
|
||||
// TXFromCtx retrieves and returns transaction object from context.
|
||||
// It is usually used in nested transaction feature, and it returns nil if it is not set previously.
|
||||
func TXFromCtx(ctx context.Context, group string) TX {
|
||||
if ctx == nil {
|
||||
return nil
|
||||
}
|
||||
v := ctx.Value(transactionKeyForContext(group))
|
||||
if v != nil {
|
||||
tx := v.(TX)
|
||||
if tx.IsClosed() {
|
||||
return nil
|
||||
}
|
||||
// no underlying sql tx.
|
||||
if tx.GetSqlTX() == nil {
|
||||
return nil
|
||||
}
|
||||
tx = tx.Ctx(ctx)
|
||||
return tx
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// transactionKeyForContext forms and returns a key for storing transaction object of certain database group into context.
|
||||
func transactionKeyForContext(group string) transactionCtxKey {
|
||||
return transactionCtxKey(contextTransactionKeyPrefix + group)
|
||||
}
|
||||
|
|
@ -0,0 +1,437 @@
|
|||
// Copyright GoFrame Author(https://goframe.org). All Rights Reserved.
|
||||
//
|
||||
// This Source Code Form is subject to the terms of the MIT License.
|
||||
// If a copy of the MIT was not distributed with this file,
|
||||
// You can obtain one at https://github.com/gogf/gf.
|
||||
|
||||
package database
|
||||
|
||||
import (
|
||||
"context"
|
||||
"database/sql"
|
||||
"reflect"
|
||||
|
||||
"git.magicany.cc/black1552/gin-base/database/reflection"
|
||||
"github.com/gogf/gf/v2/errors/gcode"
|
||||
"github.com/gogf/gf/v2/errors/gerror"
|
||||
"github.com/gogf/gf/v2/text/gregex"
|
||||
"github.com/gogf/gf/v2/util/gconv"
|
||||
)
|
||||
|
||||
// TXCore is the struct for transaction management.
|
||||
type TXCore struct {
|
||||
// db is the database management interface that implements the DB interface,
|
||||
// providing access to database operations and configuration.
|
||||
db DB
|
||||
// tx is the underlying SQL transaction object from database/sql package,
|
||||
// which manages the actual transaction operations.
|
||||
tx *sql.Tx
|
||||
// ctx is the context specific to this transaction,
|
||||
// which can be used for timeout control and cancellation.
|
||||
ctx context.Context
|
||||
// master is the underlying master database connection pool,
|
||||
// used for direct database operations when needed.
|
||||
master *sql.DB
|
||||
// transactionId is a unique identifier for this transaction instance,
|
||||
// used for tracking and debugging purposes.
|
||||
transactionId string
|
||||
// transactionCount tracks the number of nested transaction begins,
|
||||
// used for managing transaction nesting depth.
|
||||
transactionCount int
|
||||
// isClosed indicates whether this transaction has been finalized
|
||||
// through either a commit or rollback operation.
|
||||
isClosed bool
|
||||
// cancelFunc is the context cancellation function associated with ctx,
|
||||
// used to cancel the transaction context when needed.
|
||||
cancelFunc context.CancelFunc
|
||||
}
|
||||
|
||||
func (c *Core) newEmptyTX() TX {
|
||||
return &TXCore{
|
||||
db: c.db,
|
||||
}
|
||||
}
|
||||
|
||||
// transactionKeyForNestedPoint forms and returns the transaction key at current save point.
|
||||
func (tx *TXCore) transactionKeyForNestedPoint() string {
|
||||
return tx.db.GetCore().QuoteWord(
|
||||
transactionPointerPrefix + gconv.String(tx.transactionCount),
|
||||
)
|
||||
}
|
||||
|
||||
// Ctx sets the context for current transaction.
|
||||
func (tx *TXCore) Ctx(ctx context.Context) TX {
|
||||
tx.ctx = ctx
|
||||
if tx.ctx != nil {
|
||||
tx.ctx = tx.db.GetCore().injectInternalCtxData(tx.ctx)
|
||||
}
|
||||
return tx
|
||||
}
|
||||
|
||||
// GetCtx returns the context for current transaction.
|
||||
func (tx *TXCore) GetCtx() context.Context {
|
||||
return tx.ctx
|
||||
}
|
||||
|
||||
// GetDB returns the DB for current transaction.
|
||||
func (tx *TXCore) GetDB() DB {
|
||||
return tx.db
|
||||
}
|
||||
|
||||
// GetSqlTX returns the underlying transaction object for current transaction.
|
||||
func (tx *TXCore) GetSqlTX() *sql.Tx {
|
||||
return tx.tx
|
||||
}
|
||||
|
||||
// Commit commits current transaction.
|
||||
// Note that it releases previous saved transaction point if it's in a nested transaction procedure,
|
||||
// or else it commits the hole transaction.
|
||||
func (tx *TXCore) Commit() error {
|
||||
if tx.transactionCount > 0 {
|
||||
tx.transactionCount--
|
||||
_, err := tx.Exec("RELEASE SAVEPOINT " + tx.transactionKeyForNestedPoint())
|
||||
return err
|
||||
}
|
||||
_, err := tx.db.DoCommit(tx.ctx, DoCommitInput{
|
||||
Tx: tx.tx,
|
||||
Sql: "COMMIT",
|
||||
Type: SqlTypeTXCommit,
|
||||
TxCancelFunc: tx.cancelFunc,
|
||||
IsTransaction: true,
|
||||
})
|
||||
if err == nil {
|
||||
tx.isClosed = true
|
||||
}
|
||||
return err
|
||||
}
|
||||
|
||||
// Rollback aborts current transaction.
|
||||
// Note that it aborts current transaction if it's in a nested transaction procedure,
|
||||
// or else it aborts the hole transaction.
|
||||
func (tx *TXCore) Rollback() error {
|
||||
if tx.transactionCount > 0 {
|
||||
tx.transactionCount--
|
||||
_, err := tx.Exec("ROLLBACK TO SAVEPOINT " + tx.transactionKeyForNestedPoint())
|
||||
return err
|
||||
}
|
||||
_, err := tx.db.DoCommit(tx.ctx, DoCommitInput{
|
||||
Tx: tx.tx,
|
||||
Sql: "ROLLBACK",
|
||||
Type: SqlTypeTXRollback,
|
||||
TxCancelFunc: tx.cancelFunc,
|
||||
IsTransaction: true,
|
||||
})
|
||||
if err == nil {
|
||||
tx.isClosed = true
|
||||
}
|
||||
return err
|
||||
}
|
||||
|
||||
// IsClosed checks and returns this transaction has already been committed or rolled back.
|
||||
func (tx *TXCore) IsClosed() bool {
|
||||
return tx.isClosed
|
||||
}
|
||||
|
||||
// Begin starts a nested transaction procedure.
|
||||
func (tx *TXCore) Begin() error {
|
||||
_, err := tx.Exec("SAVEPOINT " + tx.transactionKeyForNestedPoint())
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
tx.transactionCount++
|
||||
return nil
|
||||
}
|
||||
|
||||
// SavePoint performs `SAVEPOINT xxx` SQL statement that saves transaction at current point.
|
||||
// The parameter `point` specifies the point name that will be saved to server.
|
||||
func (tx *TXCore) SavePoint(point string) error {
|
||||
_, err := tx.Exec("SAVEPOINT " + tx.db.GetCore().QuoteWord(point))
|
||||
return err
|
||||
}
|
||||
|
||||
// RollbackTo performs `ROLLBACK TO SAVEPOINT xxx` SQL statement that rollbacks to specified saved transaction.
|
||||
// The parameter `point` specifies the point name that was saved previously.
|
||||
func (tx *TXCore) RollbackTo(point string) error {
|
||||
_, err := tx.Exec("ROLLBACK TO SAVEPOINT " + tx.db.GetCore().QuoteWord(point))
|
||||
return err
|
||||
}
|
||||
|
||||
// Transaction wraps the transaction logic using function `f`.
|
||||
// It rollbacks the transaction and returns the error from function `f` if
|
||||
// it returns non-nil error. It commits the transaction and returns nil if
|
||||
// function `f` returns nil.
|
||||
//
|
||||
// Note that, you should not Commit or Rollback the transaction in function `f`
|
||||
// as it is automatically handled by this function.
|
||||
func (tx *TXCore) Transaction(ctx context.Context, f func(ctx context.Context, tx TX) error) (err error) {
|
||||
if ctx != nil {
|
||||
tx.ctx = ctx
|
||||
}
|
||||
// Check transaction object from context.
|
||||
if TXFromCtx(tx.ctx, tx.db.GetGroup()) == nil {
|
||||
// Inject transaction object into context.
|
||||
tx.ctx = WithTX(tx.ctx, tx)
|
||||
}
|
||||
if err = tx.Begin(); err != nil {
|
||||
return err
|
||||
}
|
||||
err = callTxFunc(tx, f)
|
||||
return
|
||||
}
|
||||
|
||||
// TransactionWithOptions wraps the transaction logic with propagation options using function `f`.
|
||||
func (tx *TXCore) TransactionWithOptions(
|
||||
ctx context.Context, opts TxOptions, f func(ctx context.Context, tx TX) error,
|
||||
) (err error) {
|
||||
return tx.db.TransactionWithOptions(ctx, opts, f)
|
||||
}
|
||||
|
||||
// Query does query operation on transaction.
|
||||
// See Core.Query.
|
||||
func (tx *TXCore) Query(sql string, args ...any) (result Result, err error) {
|
||||
return tx.db.DoQuery(tx.ctx, &txLink{tx.tx}, sql, args...)
|
||||
}
|
||||
|
||||
// Exec does none query operation on transaction.
|
||||
// See Core.Exec.
|
||||
func (tx *TXCore) Exec(sql string, args ...any) (sql.Result, error) {
|
||||
return tx.db.DoExec(tx.ctx, &txLink{tx.tx}, sql, args...)
|
||||
}
|
||||
|
||||
// Prepare creates a prepared statement for later queries or executions.
|
||||
// Multiple queries or executions may be run concurrently from the
|
||||
// returned statement.
|
||||
// The caller must call the statement's Close method
|
||||
// when the statement is no longer needed.
|
||||
func (tx *TXCore) Prepare(sql string) (*Stmt, error) {
|
||||
return tx.db.DoPrepare(tx.ctx, &txLink{tx.tx}, sql)
|
||||
}
|
||||
|
||||
// GetAll queries and returns data records from database.
|
||||
func (tx *TXCore) GetAll(sql string, args ...any) (Result, error) {
|
||||
return tx.Query(sql, args...)
|
||||
}
|
||||
|
||||
// GetOne queries and returns one record from database.
|
||||
func (tx *TXCore) GetOne(sql string, args ...any) (Record, error) {
|
||||
list, err := tx.GetAll(sql, args...)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if len(list) > 0 {
|
||||
return list[0], nil
|
||||
}
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
// GetStruct queries one record from database and converts it to given struct.
|
||||
// The parameter `pointer` should be a pointer to struct.
|
||||
func (tx *TXCore) GetStruct(obj any, sql string, args ...any) error {
|
||||
one, err := tx.GetOne(sql, args...)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
return one.Struct(obj)
|
||||
}
|
||||
|
||||
// GetStructs queries records from database and converts them to given struct.
|
||||
// The parameter `pointer` should be type of struct slice: []struct/[]*struct.
|
||||
func (tx *TXCore) GetStructs(objPointerSlice any, sql string, args ...any) error {
|
||||
all, err := tx.GetAll(sql, args...)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
return all.Structs(objPointerSlice)
|
||||
}
|
||||
|
||||
// GetScan queries one or more records from database and converts them to given struct or
|
||||
// struct array.
|
||||
//
|
||||
// If parameter `pointer` is type of struct pointer, it calls GetStruct internally for
|
||||
// the conversion. If parameter `pointer` is type of slice, it calls GetStructs internally
|
||||
// for conversion.
|
||||
func (tx *TXCore) GetScan(pointer any, sql string, args ...any) error {
|
||||
reflectInfo := reflection.OriginTypeAndKind(pointer)
|
||||
if reflectInfo.InputKind != reflect.Pointer {
|
||||
return gerror.NewCodef(
|
||||
gcode.CodeInvalidParameter,
|
||||
"params should be type of pointer, but got: %v",
|
||||
reflectInfo.InputKind,
|
||||
)
|
||||
}
|
||||
switch reflectInfo.OriginKind {
|
||||
case reflect.Array, reflect.Slice:
|
||||
return tx.GetStructs(pointer, sql, args...)
|
||||
|
||||
case reflect.Struct:
|
||||
return tx.GetStruct(pointer, sql, args...)
|
||||
|
||||
default:
|
||||
}
|
||||
return gerror.NewCodef(
|
||||
gcode.CodeInvalidParameter,
|
||||
`in valid parameter type "%v", of which element type should be type of struct/slice`,
|
||||
reflectInfo.InputType,
|
||||
)
|
||||
}
|
||||
|
||||
// GetValue queries and returns the field value from database.
|
||||
// The sql should query only one field from database, or else it returns only one
|
||||
// field of the result.
|
||||
func (tx *TXCore) GetValue(sql string, args ...any) (Value, error) {
|
||||
one, err := tx.GetOne(sql, args...)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
for _, v := range one {
|
||||
return v, nil
|
||||
}
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
// GetCount queries and returns the count from database.
|
||||
func (tx *TXCore) GetCount(sql string, args ...any) (int64, error) {
|
||||
if !gregex.IsMatchString(`(?i)SELECT\s+COUNT\(.+\)\s+FROM`, sql) {
|
||||
sql, _ = gregex.ReplaceString(`(?i)(SELECT)\s+(.+)\s+(FROM)`, `$1 COUNT($2) $3`, sql)
|
||||
}
|
||||
value, err := tx.GetValue(sql, args...)
|
||||
if err != nil {
|
||||
return 0, err
|
||||
}
|
||||
return value.Int64(), nil
|
||||
}
|
||||
|
||||
// Insert does "INSERT INTO ..." statement for the table.
|
||||
// If there's already one unique record of the data in the table, it returns error.
|
||||
//
|
||||
// The parameter `data` can be type of map/gmap/struct/*struct/[]map/[]struct, etc.
|
||||
// Eg:
|
||||
// Data(g.Map{"uid": 10000, "name":"john"})
|
||||
// Data(g.Slice{g.Map{"uid": 10000, "name":"john"}, g.Map{"uid": 20000, "name":"smith"})
|
||||
//
|
||||
// The parameter `batch` specifies the batch operation count when given data is slice.
|
||||
func (tx *TXCore) Insert(table string, data any, batch ...int) (sql.Result, error) {
|
||||
if len(batch) > 0 {
|
||||
return tx.Model(table).Ctx(tx.ctx).Data(data).Batch(batch[0]).Insert()
|
||||
}
|
||||
return tx.Model(table).Ctx(tx.ctx).Data(data).Insert()
|
||||
}
|
||||
|
||||
// InsertIgnore does "INSERT IGNORE INTO ..." statement for the table.
|
||||
// If there's already one unique record of the data in the table, it ignores the inserting.
|
||||
//
|
||||
// The parameter `data` can be type of map/gmap/struct/*struct/[]map/[]struct, etc.
|
||||
// Eg:
|
||||
// Data(g.Map{"uid": 10000, "name":"john"})
|
||||
// Data(g.Slice{g.Map{"uid": 10000, "name":"john"}, g.Map{"uid": 20000, "name":"smith"})
|
||||
//
|
||||
// The parameter `batch` specifies the batch operation count when given data is slice.
|
||||
func (tx *TXCore) InsertIgnore(table string, data any, batch ...int) (sql.Result, error) {
|
||||
if len(batch) > 0 {
|
||||
return tx.Model(table).Ctx(tx.ctx).Data(data).Batch(batch[0]).InsertIgnore()
|
||||
}
|
||||
return tx.Model(table).Ctx(tx.ctx).Data(data).InsertIgnore()
|
||||
}
|
||||
|
||||
// InsertAndGetId performs action Insert and returns the last insert id that automatically generated.
|
||||
func (tx *TXCore) InsertAndGetId(table string, data any, batch ...int) (int64, error) {
|
||||
if len(batch) > 0 {
|
||||
return tx.Model(table).Ctx(tx.ctx).Data(data).Batch(batch[0]).InsertAndGetId()
|
||||
}
|
||||
return tx.Model(table).Ctx(tx.ctx).Data(data).InsertAndGetId()
|
||||
}
|
||||
|
||||
// Replace does "REPLACE INTO ..." statement for the table.
|
||||
// If there's already one unique record of the data in the table, it deletes the record
|
||||
// and inserts a new one.
|
||||
//
|
||||
// The parameter `data` can be type of map/gmap/struct/*struct/[]map/[]struct, etc.
|
||||
// Eg:
|
||||
// Data(g.Map{"uid": 10000, "name":"john"})
|
||||
// Data(g.Slice{g.Map{"uid": 10000, "name":"john"}, g.Map{"uid": 20000, "name":"smith"})
|
||||
//
|
||||
// The parameter `data` can be type of map/gmap/struct/*struct/[]map/[]struct, etc.
|
||||
// If given data is type of slice, it then does batch replacing, and the optional parameter
|
||||
// `batch` specifies the batch operation count.
|
||||
func (tx *TXCore) Replace(table string, data any, batch ...int) (sql.Result, error) {
|
||||
if len(batch) > 0 {
|
||||
return tx.Model(table).Ctx(tx.ctx).Data(data).Batch(batch[0]).Replace()
|
||||
}
|
||||
return tx.Model(table).Ctx(tx.ctx).Data(data).Replace()
|
||||
}
|
||||
|
||||
// Save does "INSERT INTO ... ON DUPLICATE KEY UPDATE..." statement for the table.
|
||||
// It updates the record if there's primary or unique index in the saving data,
|
||||
// or else it inserts a new record into the table.
|
||||
//
|
||||
// The parameter `data` can be type of map/gmap/struct/*struct/[]map/[]struct, etc.
|
||||
// Eg:
|
||||
// Data(g.Map{"uid": 10000, "name":"john"})
|
||||
// Data(g.Slice{g.Map{"uid": 10000, "name":"john"}, g.Map{"uid": 20000, "name":"smith"})
|
||||
//
|
||||
// If given data is type of slice, it then does batch saving, and the optional parameter
|
||||
// `batch` specifies the batch operation count.
|
||||
func (tx *TXCore) Save(table string, data any, batch ...int) (sql.Result, error) {
|
||||
if len(batch) > 0 {
|
||||
return tx.Model(table).Ctx(tx.ctx).Data(data).Batch(batch[0]).Save()
|
||||
}
|
||||
return tx.Model(table).Ctx(tx.ctx).Data(data).Save()
|
||||
}
|
||||
|
||||
// Update does "UPDATE ... " statement for the table.
|
||||
//
|
||||
// The parameter `data` can be type of string/map/gmap/struct/*struct, etc.
|
||||
// Eg: "uid=10000", "uid", 10000, g.Map{"uid": 10000, "name":"john"}
|
||||
//
|
||||
// The parameter `condition` can be type of string/map/gmap/slice/struct/*struct, etc.
|
||||
// It is commonly used with parameter `args`.
|
||||
// Eg:
|
||||
// "uid=10000",
|
||||
// "uid", 10000
|
||||
// "money>? AND name like ?", 99999, "vip_%"
|
||||
// "status IN (?)", g.Slice{1,2,3}
|
||||
// "age IN(?,?)", 18, 50
|
||||
// User{ Id : 1, UserName : "john"}.
|
||||
func (tx *TXCore) Update(table string, data any, condition any, args ...any) (sql.Result, error) {
|
||||
return tx.Model(table).Ctx(tx.ctx).Data(data).Where(condition, args...).Update()
|
||||
}
|
||||
|
||||
// Delete does "DELETE FROM ... " statement for the table.
|
||||
//
|
||||
// The parameter `condition` can be type of string/map/gmap/slice/struct/*struct, etc.
|
||||
// It is commonly used with parameter `args`.
|
||||
// Eg:
|
||||
// "uid=10000",
|
||||
// "uid", 10000
|
||||
// "money>? AND name like ?", 99999, "vip_%"
|
||||
// "status IN (?)", g.Slice{1,2,3}
|
||||
// "age IN(?,?)", 18, 50
|
||||
// User{ Id : 1, UserName : "john"}.
|
||||
func (tx *TXCore) Delete(table string, condition any, args ...any) (sql.Result, error) {
|
||||
return tx.Model(table).Ctx(tx.ctx).Where(condition, args...).Delete()
|
||||
}
|
||||
|
||||
// QueryContext implements interface function Link.QueryContext.
|
||||
func (tx *TXCore) QueryContext(ctx context.Context, sql string, args ...any) (*sql.Rows, error) {
|
||||
return tx.tx.QueryContext(ctx, sql, args...)
|
||||
}
|
||||
|
||||
// ExecContext implements interface function Link.ExecContext.
|
||||
func (tx *TXCore) ExecContext(ctx context.Context, sql string, args ...any) (sql.Result, error) {
|
||||
return tx.tx.ExecContext(ctx, sql, args...)
|
||||
}
|
||||
|
||||
// PrepareContext implements interface function Link.PrepareContext.
|
||||
func (tx *TXCore) PrepareContext(ctx context.Context, sql string) (*sql.Stmt, error) {
|
||||
return tx.tx.PrepareContext(ctx, sql)
|
||||
}
|
||||
|
||||
// IsOnMaster implements interface function Link.IsOnMaster.
|
||||
func (tx *TXCore) IsOnMaster() bool {
|
||||
return true
|
||||
}
|
||||
|
||||
// IsTransaction implements interface function Link.IsTransaction.
|
||||
func (tx *TXCore) IsTransaction() bool {
|
||||
return tx != nil
|
||||
}
|
||||
|
|
@ -0,0 +1,533 @@
|
|||
// Copyright GoFrame Author(https://goframe.org). All Rights Reserved.
|
||||
//
|
||||
// This Source Code Form is subject to the terms of the MIT License.
|
||||
// If a copy of the MIT was not distributed with this file,
|
||||
// You can obtain one at https://github.com/gogf/gf.
|
||||
//
|
||||
|
||||
package database
|
||||
|
||||
import (
|
||||
"context"
|
||||
"database/sql"
|
||||
"fmt"
|
||||
"reflect"
|
||||
|
||||
"git.magicany.cc/black1552/gin-base/database/intlog"
|
||||
"go.opentelemetry.io/otel"
|
||||
"go.opentelemetry.io/otel/trace"
|
||||
|
||||
"github.com/gogf/gf/v2"
|
||||
"github.com/gogf/gf/v2/container/gvar"
|
||||
"github.com/gogf/gf/v2/errors/gcode"
|
||||
"github.com/gogf/gf/v2/errors/gerror"
|
||||
"github.com/gogf/gf/v2/os/gtime"
|
||||
"github.com/gogf/gf/v2/util/gconv"
|
||||
"github.com/gogf/gf/v2/util/guid"
|
||||
)
|
||||
|
||||
// Query commits one query SQL to underlying driver and returns the execution result.
|
||||
// It is most commonly used for data querying.
|
||||
func (c *Core) Query(ctx context.Context, sql string, args ...any) (result Result, err error) {
|
||||
return c.db.DoQuery(ctx, nil, sql, args...)
|
||||
}
|
||||
|
||||
// DoQuery commits the sql string and its arguments to underlying driver
|
||||
// through given link object and returns the execution result.
|
||||
func (c *Core) DoQuery(ctx context.Context, link Link, sql string, args ...any) (result Result, err error) {
|
||||
// Transaction checks.
|
||||
if link == nil {
|
||||
if tx := TXFromCtx(ctx, c.db.GetGroup()); tx != nil {
|
||||
// Firstly, check and retrieve transaction link from context.
|
||||
link = &txLink{tx.GetSqlTX()}
|
||||
} else if link, err = c.SlaveLink(); err != nil {
|
||||
// Or else it creates one from master node.
|
||||
return nil, err
|
||||
}
|
||||
} else if !link.IsTransaction() {
|
||||
// If current link is not transaction link, it checks and retrieves transaction from context.
|
||||
if tx := TXFromCtx(ctx, c.db.GetGroup()); tx != nil {
|
||||
link = &txLink{tx.GetSqlTX()}
|
||||
}
|
||||
}
|
||||
|
||||
// Sql filtering.
|
||||
sql, args = c.FormatSqlBeforeExecuting(sql, args)
|
||||
sql, args, err = c.db.DoFilter(ctx, link, sql, args)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
// SQL format and retrieve.
|
||||
if v := ctx.Value(ctxKeyCatchSQL); v != nil {
|
||||
var (
|
||||
manager = v.(*CatchSQLManager)
|
||||
formattedSql = FormatSqlWithArgs(sql, args)
|
||||
)
|
||||
manager.SQLArray.Append(formattedSql)
|
||||
if !manager.DoCommit && ctx.Value(ctxKeyInternalProducedSQL) == nil {
|
||||
return nil, nil
|
||||
}
|
||||
}
|
||||
// Link execution.
|
||||
var out DoCommitOutput
|
||||
out, err = c.db.DoCommit(ctx, DoCommitInput{
|
||||
Link: link,
|
||||
Sql: sql,
|
||||
Args: args,
|
||||
Stmt: nil,
|
||||
Type: SqlTypeQueryContext,
|
||||
IsTransaction: link.IsTransaction(),
|
||||
})
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return out.Records, err
|
||||
}
|
||||
|
||||
// Exec commits one query SQL to underlying driver and returns the execution result.
|
||||
// It is most commonly used for data inserting and updating.
|
||||
func (c *Core) Exec(ctx context.Context, sql string, args ...any) (result sql.Result, err error) {
|
||||
return c.db.DoExec(ctx, nil, sql, args...)
|
||||
}
|
||||
|
||||
// DoExec commits the sql string and its arguments to underlying driver
|
||||
// through given link object and returns the execution result.
|
||||
func (c *Core) DoExec(ctx context.Context, link Link, sql string, args ...any) (result sql.Result, err error) {
|
||||
// Transaction checks.
|
||||
if link == nil {
|
||||
if tx := TXFromCtx(ctx, c.db.GetGroup()); tx != nil {
|
||||
// Firstly, check and retrieve transaction link from context.
|
||||
link = &txLink{tx.GetSqlTX()}
|
||||
} else if link, err = c.MasterLink(); err != nil {
|
||||
// Or else it creates one from master node.
|
||||
return nil, err
|
||||
}
|
||||
} else if !link.IsTransaction() {
|
||||
// If current link is not transaction link, it tries retrieving transaction object from context.
|
||||
if tx := TXFromCtx(ctx, c.db.GetGroup()); tx != nil {
|
||||
link = &txLink{tx.GetSqlTX()}
|
||||
}
|
||||
}
|
||||
|
||||
// SQL filtering.
|
||||
sql, args = c.FormatSqlBeforeExecuting(sql, args)
|
||||
sql, args, err = c.db.DoFilter(ctx, link, sql, args)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
// SQL format and retrieve.
|
||||
if v := ctx.Value(ctxKeyCatchSQL); v != nil {
|
||||
var (
|
||||
manager = v.(*CatchSQLManager)
|
||||
formattedSql = FormatSqlWithArgs(sql, args)
|
||||
)
|
||||
manager.SQLArray.Append(formattedSql)
|
||||
if !manager.DoCommit && ctx.Value(ctxKeyInternalProducedSQL) == nil {
|
||||
return new(SqlResult), nil
|
||||
}
|
||||
}
|
||||
// Link execution.
|
||||
var out DoCommitOutput
|
||||
out, err = c.db.DoCommit(ctx, DoCommitInput{
|
||||
Link: link,
|
||||
Sql: sql,
|
||||
Args: args,
|
||||
Stmt: nil,
|
||||
Type: SqlTypeExecContext,
|
||||
IsTransaction: link.IsTransaction(),
|
||||
})
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return out.Result, err
|
||||
}
|
||||
|
||||
// DoFilter is a hook function, which filters the sql and its arguments before it's committed to underlying driver.
|
||||
// The parameter `link` specifies the current database connection operation object. You can modify the sql
|
||||
// string `sql` and its arguments `args` as you wish before they're committed to driver.
|
||||
func (c *Core) DoFilter(
|
||||
ctx context.Context, link Link, sql string, args []any,
|
||||
) (newSql string, newArgs []any, err error) {
|
||||
return sql, args, nil
|
||||
}
|
||||
|
||||
// DoCommit commits current sql and arguments to underlying sql driver.
|
||||
func (c *Core) DoCommit(ctx context.Context, in DoCommitInput) (out DoCommitOutput, err error) {
|
||||
var (
|
||||
sqlTx *sql.Tx
|
||||
sqlStmt *sql.Stmt
|
||||
sqlRows *sql.Rows
|
||||
sqlResult sql.Result
|
||||
stmtSqlRows *sql.Rows
|
||||
stmtSqlRow *sql.Row
|
||||
rowsAffected int64
|
||||
cancelFuncForTimeout context.CancelFunc
|
||||
formattedSql = FormatSqlWithArgs(in.Sql, in.Args)
|
||||
timestampMilli1 = gtime.TimestampMilli()
|
||||
)
|
||||
|
||||
// Panic recovery to handle panics from underlying database drivers
|
||||
defer func() {
|
||||
if exception := recover(); exception != nil {
|
||||
if err == nil {
|
||||
if v, ok := exception.(error); ok && gerror.HasStack(v) {
|
||||
err = v
|
||||
} else {
|
||||
err = gerror.WrapCodef(gcode.CodeDbOperationError, gerror.NewCodef(gcode.CodeInternalPanic, "%+v", exception), FormatSqlWithArgs(in.Sql, in.Args))
|
||||
}
|
||||
}
|
||||
}
|
||||
}()
|
||||
|
||||
// Trace span start.
|
||||
tr := otel.GetTracerProvider().Tracer(traceInstrumentName, trace.WithInstrumentationVersion(gf.VERSION))
|
||||
ctx, span := tr.Start(ctx, string(in.Type), trace.WithSpanKind(trace.SpanKindClient))
|
||||
defer span.End()
|
||||
|
||||
// Execution by type.
|
||||
switch in.Type {
|
||||
case SqlTypeBegin:
|
||||
ctx, cancelFuncForTimeout = c.GetCtxTimeout(ctx, ctxTimeoutTypeTrans)
|
||||
formattedSql = fmt.Sprintf(
|
||||
`%s (IosolationLevel: %s, ReadOnly: %t)`,
|
||||
formattedSql, in.TxOptions.Isolation.String(), in.TxOptions.ReadOnly,
|
||||
)
|
||||
if sqlTx, err = in.Db.BeginTx(ctx, &in.TxOptions); err == nil {
|
||||
tx := &TXCore{
|
||||
db: c.db,
|
||||
tx: sqlTx,
|
||||
ctx: ctx,
|
||||
master: in.Db,
|
||||
transactionId: guid.S(),
|
||||
cancelFunc: cancelFuncForTimeout,
|
||||
}
|
||||
tx.ctx = context.WithValue(ctx, transactionKeyForContext(tx.db.GetGroup()), tx)
|
||||
tx.ctx = context.WithValue(tx.ctx, transactionIdForLoggerCtx, transactionIdGenerator.Add(1))
|
||||
out.Tx = tx
|
||||
ctx = out.Tx.GetCtx()
|
||||
}
|
||||
out.RawResult = sqlTx
|
||||
|
||||
case SqlTypeTXCommit:
|
||||
if in.TxCancelFunc != nil {
|
||||
defer in.TxCancelFunc()
|
||||
}
|
||||
err = in.Tx.Commit()
|
||||
|
||||
case SqlTypeTXRollback:
|
||||
if in.TxCancelFunc != nil {
|
||||
defer in.TxCancelFunc()
|
||||
}
|
||||
err = in.Tx.Rollback()
|
||||
|
||||
case SqlTypeExecContext:
|
||||
ctx, cancelFuncForTimeout = c.GetCtxTimeout(ctx, ctxTimeoutTypeExec)
|
||||
defer cancelFuncForTimeout()
|
||||
if c.db.GetDryRun() {
|
||||
sqlResult = new(SqlResult)
|
||||
} else {
|
||||
sqlResult, err = in.Link.ExecContext(ctx, in.Sql, in.Args...)
|
||||
}
|
||||
out.RawResult = sqlResult
|
||||
|
||||
case SqlTypeQueryContext:
|
||||
ctx, cancelFuncForTimeout = c.GetCtxTimeout(ctx, ctxTimeoutTypeQuery)
|
||||
defer cancelFuncForTimeout()
|
||||
sqlRows, err = in.Link.QueryContext(ctx, in.Sql, in.Args...)
|
||||
out.RawResult = sqlRows
|
||||
|
||||
case SqlTypePrepareContext:
|
||||
ctx, cancelFuncForTimeout = c.GetCtxTimeout(ctx, ctxTimeoutTypePrepare)
|
||||
defer cancelFuncForTimeout()
|
||||
sqlStmt, err = in.Link.PrepareContext(ctx, in.Sql)
|
||||
out.RawResult = sqlStmt
|
||||
|
||||
case SqlTypeStmtExecContext:
|
||||
ctx, cancelFuncForTimeout = c.GetCtxTimeout(ctx, ctxTimeoutTypeExec)
|
||||
defer cancelFuncForTimeout()
|
||||
if c.db.GetDryRun() {
|
||||
sqlResult = new(SqlResult)
|
||||
} else {
|
||||
sqlResult, err = in.Stmt.ExecContext(ctx, in.Args...)
|
||||
}
|
||||
out.RawResult = sqlResult
|
||||
|
||||
case SqlTypeStmtQueryContext:
|
||||
ctx, cancelFuncForTimeout = c.GetCtxTimeout(ctx, ctxTimeoutTypeQuery)
|
||||
defer cancelFuncForTimeout()
|
||||
stmtSqlRows, err = in.Stmt.QueryContext(ctx, in.Args...)
|
||||
out.RawResult = stmtSqlRows
|
||||
|
||||
case SqlTypeStmtQueryRowContext:
|
||||
ctx, cancelFuncForTimeout = c.GetCtxTimeout(ctx, ctxTimeoutTypeQuery)
|
||||
defer cancelFuncForTimeout()
|
||||
stmtSqlRow = in.Stmt.QueryRowContext(ctx, in.Args...)
|
||||
out.RawResult = stmtSqlRow
|
||||
|
||||
default:
|
||||
panic(gerror.NewCodef(gcode.CodeInvalidParameter, `invalid SqlType "%s"`, in.Type))
|
||||
}
|
||||
// Result handling.
|
||||
switch {
|
||||
case sqlResult != nil && !c.GetIgnoreResultFromCtx(ctx):
|
||||
rowsAffected, err = sqlResult.RowsAffected()
|
||||
out.Result = sqlResult
|
||||
|
||||
case sqlRows != nil:
|
||||
out.Records, err = c.RowsToResult(ctx, sqlRows)
|
||||
rowsAffected = int64(len(out.Records))
|
||||
|
||||
case sqlStmt != nil:
|
||||
out.Stmt = &Stmt{
|
||||
Stmt: sqlStmt,
|
||||
core: c,
|
||||
link: in.Link,
|
||||
sql: in.Sql,
|
||||
}
|
||||
}
|
||||
var (
|
||||
timestampMilli2 = gtime.TimestampMilli()
|
||||
sqlObj = &Sql{
|
||||
Sql: in.Sql,
|
||||
Type: in.Type,
|
||||
Args: in.Args,
|
||||
Format: formattedSql,
|
||||
Error: err,
|
||||
Start: timestampMilli1,
|
||||
End: timestampMilli2,
|
||||
Group: c.db.GetGroup(),
|
||||
Schema: c.db.GetSchema(),
|
||||
RowsAffected: rowsAffected,
|
||||
IsTransaction: in.IsTransaction,
|
||||
}
|
||||
)
|
||||
|
||||
// Tracing.
|
||||
c.traceSpanEnd(ctx, span, sqlObj)
|
||||
|
||||
// Logging.
|
||||
if c.db.GetDebug() {
|
||||
c.writeSqlToLogger(ctx, sqlObj)
|
||||
}
|
||||
if err != nil && err != sql.ErrNoRows {
|
||||
err = gerror.WrapCode(
|
||||
gcode.CodeDbOperationError,
|
||||
err,
|
||||
FormatSqlWithArgs(in.Sql, in.Args),
|
||||
)
|
||||
}
|
||||
return out, err
|
||||
}
|
||||
|
||||
// Prepare creates a prepared statement for later queries or executions.
|
||||
// Multiple queries or executions may be run concurrently from the
|
||||
// returned statement.
|
||||
// The caller must call the statement's Close method
|
||||
// when the statement is no longer needed.
|
||||
//
|
||||
// The parameter `execOnMaster` specifies whether executing the sql on master node,
|
||||
// or else it executes the sql on slave node if master-slave configured.
|
||||
func (c *Core) Prepare(ctx context.Context, sql string, execOnMaster ...bool) (*Stmt, error) {
|
||||
var (
|
||||
err error
|
||||
link Link
|
||||
)
|
||||
if len(execOnMaster) > 0 && execOnMaster[0] {
|
||||
if link, err = c.MasterLink(); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
} else {
|
||||
if link, err = c.SlaveLink(); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
}
|
||||
return c.db.DoPrepare(ctx, link, sql)
|
||||
}
|
||||
|
||||
// DoPrepare calls prepare function on given link object and returns the statement object.
|
||||
func (c *Core) DoPrepare(ctx context.Context, link Link, sql string) (stmt *Stmt, err error) {
|
||||
// Transaction checks.
|
||||
if link == nil {
|
||||
if tx := TXFromCtx(ctx, c.db.GetGroup()); tx != nil {
|
||||
// Firstly, check and retrieve transaction link from context.
|
||||
link = &txLink{tx.GetSqlTX()}
|
||||
} else {
|
||||
// Or else it creates one from master node.
|
||||
if link, err = c.MasterLink(); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
}
|
||||
} else if !link.IsTransaction() {
|
||||
// If current link is not transaction link, it checks and retrieves transaction from context.
|
||||
if tx := TXFromCtx(ctx, c.db.GetGroup()); tx != nil {
|
||||
link = &txLink{tx.GetSqlTX()}
|
||||
}
|
||||
}
|
||||
|
||||
if c.db.GetConfig().PrepareTimeout > 0 {
|
||||
// DO NOT USE cancel function in prepare statement.
|
||||
var cancelFunc context.CancelFunc
|
||||
ctx, cancelFunc = context.WithTimeout(ctx, c.db.GetConfig().PrepareTimeout)
|
||||
defer cancelFunc()
|
||||
}
|
||||
|
||||
// Link execution.
|
||||
var out DoCommitOutput
|
||||
out, err = c.db.DoCommit(ctx, DoCommitInput{
|
||||
Link: link,
|
||||
Sql: sql,
|
||||
Type: SqlTypePrepareContext,
|
||||
IsTransaction: link.IsTransaction(),
|
||||
})
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return out.Stmt, err
|
||||
}
|
||||
|
||||
// FormatUpsert formats and returns SQL clause part for upsert statement.
|
||||
// In default implements, this function performs upsert statement for MySQL like:
|
||||
// `INSERT INTO ... ON DUPLICATE KEY UPDATE x=VALUES(z),m=VALUES(y)...`
|
||||
func (c *Core) FormatUpsert(columns []string, list List, option DoInsertOption) (string, error) {
|
||||
var onDuplicateStr string
|
||||
if option.OnDuplicateStr != "" {
|
||||
onDuplicateStr = option.OnDuplicateStr
|
||||
} else if len(option.OnDuplicateMap) > 0 {
|
||||
for k, v := range option.OnDuplicateMap {
|
||||
if len(onDuplicateStr) > 0 {
|
||||
onDuplicateStr += ","
|
||||
}
|
||||
switch v.(type) {
|
||||
case Raw, *Raw:
|
||||
onDuplicateStr += fmt.Sprintf(
|
||||
"%s=%s",
|
||||
c.QuoteWord(k),
|
||||
v,
|
||||
)
|
||||
case Counter, *Counter:
|
||||
var counter Counter
|
||||
switch value := v.(type) {
|
||||
case Counter:
|
||||
counter = value
|
||||
case *Counter:
|
||||
counter = *value
|
||||
}
|
||||
operator, columnVal := c.getCounterAlter(counter)
|
||||
onDuplicateStr += fmt.Sprintf(
|
||||
"%s=%s%s%s",
|
||||
c.QuoteWord(k),
|
||||
c.QuoteWord(counter.Field),
|
||||
operator,
|
||||
gconv.String(columnVal),
|
||||
)
|
||||
default:
|
||||
onDuplicateStr += fmt.Sprintf(
|
||||
"%s=VALUES(%s)",
|
||||
c.QuoteWord(k),
|
||||
c.QuoteWord(gconv.String(v)),
|
||||
)
|
||||
}
|
||||
}
|
||||
} else {
|
||||
for _, column := range columns {
|
||||
// If it's `SAVE` operation, do not automatically update the creating time.
|
||||
if c.IsSoftCreatedFieldName(column) {
|
||||
continue
|
||||
}
|
||||
if len(onDuplicateStr) > 0 {
|
||||
onDuplicateStr += ","
|
||||
}
|
||||
onDuplicateStr += fmt.Sprintf(
|
||||
"%s=VALUES(%s)",
|
||||
c.QuoteWord(column),
|
||||
c.QuoteWord(column),
|
||||
)
|
||||
}
|
||||
}
|
||||
|
||||
return InsertOnDuplicateKeyUpdate + " " + onDuplicateStr, nil
|
||||
}
|
||||
|
||||
// RowsToResult converts underlying data record type sql.Rows to Result type.
|
||||
func (c *Core) RowsToResult(ctx context.Context, rows *sql.Rows) (Result, error) {
|
||||
if rows == nil {
|
||||
return nil, nil
|
||||
}
|
||||
defer func() {
|
||||
if err := rows.Close(); err != nil {
|
||||
intlog.Errorf(ctx, `%+v`, err)
|
||||
}
|
||||
}()
|
||||
if !rows.Next() {
|
||||
return nil, nil
|
||||
}
|
||||
// Column names and types.
|
||||
columnTypes, err := rows.ColumnTypes()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
if len(columnTypes) > 0 {
|
||||
if internalData := c.getInternalColumnFromCtx(ctx); internalData != nil {
|
||||
internalData.FirstResultColumn = columnTypes[0].Name()
|
||||
}
|
||||
}
|
||||
var (
|
||||
values = make([]any, len(columnTypes))
|
||||
result = make(Result, 0)
|
||||
scanArgs = make([]any, len(values))
|
||||
)
|
||||
for i := range values {
|
||||
scanArgs[i] = &values[i]
|
||||
}
|
||||
for {
|
||||
if err = rows.Scan(scanArgs...); err != nil {
|
||||
return result, err
|
||||
}
|
||||
record := Record{}
|
||||
for i, value := range values {
|
||||
if value == nil {
|
||||
// DO NOT use `gvar.New(nil)` here as it creates an initialized object
|
||||
// which will cause struct converting issue.
|
||||
record[columnTypes[i].Name()] = nil
|
||||
} else {
|
||||
var (
|
||||
convertedValue any
|
||||
columnType = columnTypes[i]
|
||||
)
|
||||
if convertedValue, err = c.columnValueToLocalValue(ctx, value, columnType); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
record[columnTypes[i].Name()] = gvar.New(convertedValue)
|
||||
}
|
||||
}
|
||||
result = append(result, record)
|
||||
if !rows.Next() {
|
||||
break
|
||||
}
|
||||
}
|
||||
return result, nil
|
||||
}
|
||||
|
||||
// OrderRandomFunction returns the SQL function for random ordering.
|
||||
func (c *Core) OrderRandomFunction() string {
|
||||
return "RAND()"
|
||||
}
|
||||
|
||||
func (c *Core) columnValueToLocalValue(ctx context.Context, value any, columnType *sql.ColumnType) (any, error) {
|
||||
var scanType = columnType.ScanType()
|
||||
if scanType != nil {
|
||||
// Common basic builtin types.
|
||||
switch scanType.Kind() {
|
||||
case
|
||||
reflect.Bool,
|
||||
reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64,
|
||||
reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64,
|
||||
reflect.Float32, reflect.Float64:
|
||||
return gconv.Convert(gconv.String(value), scanType.String()), nil
|
||||
default:
|
||||
}
|
||||
}
|
||||
// Other complex types, especially custom types.
|
||||
return c.db.ConvertValueForLocal(ctx, columnType.DatabaseTypeName(), value)
|
||||
}
|
||||
|
|
@ -0,0 +1,273 @@
|
|||
// Copyright GoFrame Author(https://goframe.org). All Rights Reserved.
|
||||
//
|
||||
// This Source Code Form is subject to the terms of the MIT License.
|
||||
// If a copy of the MIT was not distributed with this file,
|
||||
// You can obtain one at https://github.com/gogf/gf.
|
||||
//
|
||||
|
||||
package database
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"strings"
|
||||
|
||||
"github.com/gogf/gf/v2/errors/gcode"
|
||||
"github.com/gogf/gf/v2/errors/gerror"
|
||||
"github.com/gogf/gf/v2/text/gregex"
|
||||
"github.com/gogf/gf/v2/text/gstr"
|
||||
"github.com/gogf/gf/v2/util/gutil"
|
||||
)
|
||||
|
||||
// GetDB returns the underlying DB.
|
||||
func (c *Core) GetDB() DB {
|
||||
return c.db
|
||||
}
|
||||
|
||||
// GetLink creates and returns the underlying database link object with transaction checks.
|
||||
// The parameter `master` specifies whether using the master node if master-slave configured.
|
||||
func (c *Core) GetLink(ctx context.Context, master bool, schema string) (Link, error) {
|
||||
tx := TXFromCtx(ctx, c.db.GetGroup())
|
||||
if tx != nil {
|
||||
return &txLink{tx.GetSqlTX()}, nil
|
||||
}
|
||||
if master {
|
||||
link, err := c.db.GetCore().MasterLink(schema)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return link, nil
|
||||
}
|
||||
link, err := c.db.GetCore().SlaveLink(schema)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return link, nil
|
||||
}
|
||||
|
||||
// MasterLink acts like function Master but with additional `schema` parameter specifying
|
||||
// the schema for the connection. It is defined for internal usage.
|
||||
// Also see Master.
|
||||
func (c *Core) MasterLink(schema ...string) (Link, error) {
|
||||
db, err := c.db.Master(schema...)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return &dbLink{
|
||||
DB: db,
|
||||
isOnMaster: true,
|
||||
}, nil
|
||||
}
|
||||
|
||||
// SlaveLink acts like function Slave but with additional `schema` parameter specifying
|
||||
// the schema for the connection. It is defined for internal usage.
|
||||
// Also see Slave.
|
||||
func (c *Core) SlaveLink(schema ...string) (Link, error) {
|
||||
db, err := c.db.Slave(schema...)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return &dbLink{
|
||||
DB: db,
|
||||
isOnMaster: false,
|
||||
}, nil
|
||||
}
|
||||
|
||||
// QuoteWord checks given string `s` a word,
|
||||
// if true it quotes `s` with security chars of the database
|
||||
// and returns the quoted string; or else it returns `s` without any change.
|
||||
//
|
||||
// The meaning of a `word` can be considered as a column name.
|
||||
func (c *Core) QuoteWord(s string) string {
|
||||
s = gstr.Trim(s)
|
||||
if s == "" {
|
||||
return s
|
||||
}
|
||||
charLeft, charRight := c.db.GetChars()
|
||||
return doQuoteWord(s, charLeft, charRight)
|
||||
}
|
||||
|
||||
// QuoteString quotes string with quote chars. Strings like:
|
||||
// "user", "user u", "user,user_detail", "user u, user_detail ut", "u.id asc".
|
||||
//
|
||||
// The meaning of a `string` can be considered as part of a statement string including columns.
|
||||
func (c *Core) QuoteString(s string) string {
|
||||
if !gregex.IsMatchString(regularFieldNameWithCommaRegPattern, s) {
|
||||
return s
|
||||
}
|
||||
charLeft, charRight := c.db.GetChars()
|
||||
return doQuoteString(s, charLeft, charRight)
|
||||
}
|
||||
|
||||
// QuotePrefixTableName adds prefix string and quotes chars for the table.
|
||||
// It handles table string like:
|
||||
// "user", "user u",
|
||||
// "user,user_detail",
|
||||
// "user u, user_detail ut",
|
||||
// "user as u, user_detail as ut".
|
||||
//
|
||||
// Note that, this will automatically checks the table prefix whether already added,
|
||||
// if true it does nothing to the table name, or else adds the prefix to the table name.
|
||||
func (c *Core) QuotePrefixTableName(table string) string {
|
||||
charLeft, charRight := c.db.GetChars()
|
||||
return doQuoteTableName(table, c.db.GetPrefix(), charLeft, charRight)
|
||||
}
|
||||
|
||||
// GetChars returns the security char for current database.
|
||||
// It does nothing in default.
|
||||
func (c *Core) GetChars() (charLeft string, charRight string) {
|
||||
return "", ""
|
||||
}
|
||||
|
||||
// Tables retrieves and returns the tables of current schema.
|
||||
// It's mainly used in cli tool chain for automatically generating the models.
|
||||
func (c *Core) Tables(ctx context.Context, schema ...string) (tables []string, err error) {
|
||||
return
|
||||
}
|
||||
|
||||
// TableFields retrieves and returns the fields' information of specified table of current
|
||||
// schema.
|
||||
//
|
||||
// The parameter `link` is optional, if given nil it automatically retrieves a raw sql connection
|
||||
// as its link to proceed necessary sql query.
|
||||
//
|
||||
// Note that it returns a map containing the field name and its corresponding fields.
|
||||
// As a map is unsorted, the TableField struct has an "Index" field marks its sequence in
|
||||
// the fields.
|
||||
//
|
||||
// It's using cache feature to enhance the performance, which is never expired util the
|
||||
// process restarts.
|
||||
func (c *Core) TableFields(ctx context.Context, table string, schema ...string) (fields map[string]*TableField, err error) {
|
||||
return
|
||||
}
|
||||
|
||||
// ClearTableFields removes certain cached table fields of current configuration group.
|
||||
func (c *Core) ClearTableFields(ctx context.Context, table string, schema ...string) (err error) {
|
||||
tableFieldsCacheKey := genTableFieldsCacheKey(
|
||||
c.db.GetGroup(),
|
||||
gutil.GetOrDefaultStr(c.db.GetSchema(), schema...),
|
||||
table,
|
||||
)
|
||||
_, err = c.innerMemCache.Remove(ctx, tableFieldsCacheKey)
|
||||
return
|
||||
}
|
||||
|
||||
// ClearTableFieldsAll removes all cached table fields of current configuration group.
|
||||
func (c *Core) ClearTableFieldsAll(ctx context.Context) (err error) {
|
||||
var (
|
||||
keys, _ = c.innerMemCache.KeyStrings(ctx)
|
||||
cachePrefix = cachePrefixTableFields
|
||||
removedKeys = make([]any, 0)
|
||||
)
|
||||
for _, key := range keys {
|
||||
if gstr.HasPrefix(key, cachePrefix) {
|
||||
removedKeys = append(removedKeys, key)
|
||||
}
|
||||
}
|
||||
|
||||
if len(removedKeys) > 0 {
|
||||
err = c.innerMemCache.Removes(ctx, removedKeys)
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
// ClearCache removes cached sql result of certain table.
|
||||
func (c *Core) ClearCache(ctx context.Context, table string) (err error) {
|
||||
var (
|
||||
keys, _ = c.db.GetCache().KeyStrings(ctx)
|
||||
cachePrefix = fmt.Sprintf(`%s%s@`, cachePrefixSelectCache, table)
|
||||
removedKeys = make([]any, 0)
|
||||
)
|
||||
for _, key := range keys {
|
||||
if gstr.HasPrefix(key, cachePrefix) {
|
||||
removedKeys = append(removedKeys, key)
|
||||
}
|
||||
}
|
||||
if len(removedKeys) > 0 {
|
||||
err = c.db.GetCache().Removes(ctx, removedKeys)
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
// ClearCacheAll removes all cached sql result from cache
|
||||
func (c *Core) ClearCacheAll(ctx context.Context) (err error) {
|
||||
if err = c.db.GetCache().Clear(ctx); err != nil {
|
||||
return err
|
||||
}
|
||||
if err = c.GetInnerMemCache().Clear(ctx); err != nil {
|
||||
return err
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
// HasField determine whether the field exists in the table.
|
||||
func (c *Core) HasField(ctx context.Context, table, field string, schema ...string) (bool, error) {
|
||||
table = c.guessPrimaryTableName(table)
|
||||
tableFields, err := c.db.TableFields(ctx, table, schema...)
|
||||
if err != nil {
|
||||
return false, err
|
||||
}
|
||||
if len(tableFields) == 0 {
|
||||
return false, gerror.NewCodef(
|
||||
gcode.CodeNotFound,
|
||||
`empty table fields for table "%s"`, table,
|
||||
)
|
||||
}
|
||||
fieldsArray := make([]string, len(tableFields))
|
||||
for k, v := range tableFields {
|
||||
fieldsArray[v.Index] = k
|
||||
}
|
||||
charLeft, charRight := c.db.GetChars()
|
||||
field = gstr.Trim(field, charLeft+charRight)
|
||||
for _, f := range fieldsArray {
|
||||
if f == field {
|
||||
return true, nil
|
||||
}
|
||||
}
|
||||
return false, nil
|
||||
}
|
||||
|
||||
// guessPrimaryTableName parses and returns the primary table name.
|
||||
func (c *Core) guessPrimaryTableName(tableStr string) string {
|
||||
if tableStr == "" {
|
||||
return ""
|
||||
}
|
||||
var (
|
||||
guessedTableName string
|
||||
array1 = gstr.SplitAndTrim(tableStr, ",")
|
||||
array2 = gstr.SplitAndTrim(array1[0], " ")
|
||||
array3 = gstr.SplitAndTrim(array2[0], ".")
|
||||
)
|
||||
if len(array3) >= 2 {
|
||||
guessedTableName = array3[1]
|
||||
} else {
|
||||
guessedTableName = array3[0]
|
||||
}
|
||||
charL, charR := c.db.GetChars()
|
||||
if charL != "" || charR != "" {
|
||||
guessedTableName = gstr.Trim(guessedTableName, charL+charR)
|
||||
}
|
||||
if !gregex.IsMatchString(regularFieldNameRegPattern, guessedTableName) {
|
||||
return ""
|
||||
}
|
||||
return guessedTableName
|
||||
}
|
||||
|
||||
// GetPrimaryKeys retrieves and returns the primary key field names of the specified table.
|
||||
// This method extracts primary key information from TableFields.
|
||||
// The parameter `schema` is optional, if not specified it uses the default schema.
|
||||
func (c *Core) GetPrimaryKeys(ctx context.Context, table string, schema ...string) ([]string, error) {
|
||||
tableFields, err := c.db.TableFields(ctx, table, schema...)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
var primaryKeys []string
|
||||
for _, field := range tableFields {
|
||||
if strings.EqualFold(field.Key, "pri") {
|
||||
primaryKeys = append(primaryKeys, field.Name)
|
||||
}
|
||||
}
|
||||
|
||||
return primaryKeys, nil
|
||||
}
|
||||
|
|
@ -0,0 +1,7 @@
|
|||
package database
|
||||
|
||||
type IDao interface {
|
||||
DB() DB
|
||||
TableName() string
|
||||
Columns() any
|
||||
}
|
||||
|
|
@ -0,0 +1,162 @@
|
|||
package database
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
|
||||
"git.magicany.cc/black1552/gin-base/config"
|
||||
"git.magicany.cc/black1552/gin-base/database/instance"
|
||||
"git.magicany.cc/black1552/gin-base/database/intlog"
|
||||
customLog "git.magicany.cc/black1552/gin-base/log"
|
||||
"github.com/gogf/gf/v2/database/gdb"
|
||||
"github.com/gogf/gf/v2/errors/gcode"
|
||||
"github.com/gogf/gf/v2/errors/gerror"
|
||||
"github.com/gogf/gf/v2/util/gconv"
|
||||
"github.com/gogf/gf/v2/util/gutil"
|
||||
)
|
||||
|
||||
const (
|
||||
frameCoreComponentNameDatabase = "core.component.database"
|
||||
ConfigNodeNameDatabase = "database"
|
||||
)
|
||||
|
||||
func Database(name ...string) gdb.DB {
|
||||
var (
|
||||
ctx = context.Background()
|
||||
group = gdb.DefaultGroupName
|
||||
)
|
||||
|
||||
if len(name) > 0 && name[0] != "" {
|
||||
group = name[0]
|
||||
}
|
||||
instanceKey := fmt.Sprintf("%s.%s", frameCoreComponentNameDatabase, group)
|
||||
db := instance.GetOrSetFuncLock(instanceKey, func() interface{} {
|
||||
// It ignores returned error to avoid file no found error while it's not necessary.
|
||||
var (
|
||||
configMap map[string]interface{}
|
||||
configNodeKey = ConfigNodeNameDatabase
|
||||
)
|
||||
// It firstly searches the configuration of the instance name.
|
||||
if configData := config.GetAllConfig(); len(configData) > 0 {
|
||||
if v, _ := gutil.MapPossibleItemByKey(configData, ConfigNodeNameDatabase); v != "" {
|
||||
configNodeKey = v
|
||||
}
|
||||
}
|
||||
if v := config.GetConfigValue(configNodeKey); !v.IsEmpty() {
|
||||
configMap = v.Map()
|
||||
}
|
||||
// 检查配置是否存在
|
||||
if len(configMap) == 0 {
|
||||
// 从 config 包获取所有配置,检查是否包含数据库配置
|
||||
allConfig := config.GetAllConfig()
|
||||
if len(allConfig) == 0 {
|
||||
panic(gerror.NewCodef(
|
||||
gcode.CodeMissingConfiguration,
|
||||
`database initialization failed: configuration file is empty or invalid`,
|
||||
))
|
||||
}
|
||||
// 检查是否配置了数据库节点
|
||||
if _, exists := allConfig["database"]; !exists {
|
||||
// 尝试其他可能的键名
|
||||
found := false
|
||||
for key := range allConfig {
|
||||
if key == "DATABASE" || key == "Database" {
|
||||
found = true
|
||||
break
|
||||
}
|
||||
}
|
||||
if !found {
|
||||
panic(gerror.NewCodef(
|
||||
gcode.CodeMissingConfiguration,
|
||||
`database initialization failed: configuration missing for database node "%s"`,
|
||||
ConfigNodeNameDatabase,
|
||||
))
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if len(configMap) == 0 {
|
||||
configMap = make(map[string]interface{})
|
||||
}
|
||||
// Parse `m` as map-slice and adds it to global configurations for package gdb.
|
||||
for g, groupConfig := range configMap {
|
||||
cg := ConfigGroup{}
|
||||
switch value := groupConfig.(type) {
|
||||
case []interface{}:
|
||||
for _, v := range value {
|
||||
if node := parseDBConfigNode(v); node != nil {
|
||||
cg = append(cg, *node)
|
||||
}
|
||||
}
|
||||
case map[string]interface{}:
|
||||
if node := parseDBConfigNode(value); node != nil {
|
||||
cg = append(cg, *node)
|
||||
}
|
||||
}
|
||||
if len(cg) > 0 {
|
||||
if GetConfig(group) == nil {
|
||||
intlog.Printf(ctx, "add configuration for group: %s, %#v", g, cg)
|
||||
SetConfigGroup(g, cg)
|
||||
} else {
|
||||
intlog.Printf(ctx, "ignore configuration as it already exists for group: %s, %#v", g, cg)
|
||||
intlog.Printf(ctx, "%s, %#v", g, cg)
|
||||
}
|
||||
}
|
||||
}
|
||||
// Parse `m` as a single node configuration,
|
||||
// which is the default group configuration.
|
||||
if node := parseDBConfigNode(configMap); node != nil {
|
||||
cg := ConfigGroup{}
|
||||
if node.Link != "" || node.Host != "" {
|
||||
cg = append(cg, *node)
|
||||
}
|
||||
if len(cg) > 0 {
|
||||
if GetConfig(group) == nil {
|
||||
intlog.Printf(ctx, "add configuration for group: %s, %#v", DefaultGroupName, cg)
|
||||
SetConfigGroup(DefaultGroupName, cg)
|
||||
} else {
|
||||
intlog.Printf(
|
||||
ctx,
|
||||
"ignore configuration as it already exists for group: %s, %#v",
|
||||
DefaultGroupName, cg,
|
||||
)
|
||||
intlog.Printf(ctx, "%s, %#v", DefaultGroupName, cg)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Create a new ORM object with given configurations.
|
||||
if db, err := NewByGroup(name...); err == nil {
|
||||
// 自动初始化自定义日志器并设置到数据库
|
||||
db.SetLogger(customLog.GetLogger())
|
||||
return db
|
||||
} else {
|
||||
// If panics, often because it does not find its configuration for given group.
|
||||
panic(err)
|
||||
}
|
||||
return nil
|
||||
})
|
||||
if db != nil {
|
||||
return db.(gdb.DB)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func parseDBConfigNode(value interface{}) *ConfigNode {
|
||||
nodeMap, ok := value.(map[string]interface{})
|
||||
if !ok {
|
||||
return nil
|
||||
}
|
||||
var (
|
||||
node = &ConfigNode{}
|
||||
err = gconv.Struct(nodeMap, node)
|
||||
)
|
||||
if err != nil {
|
||||
panic(err)
|
||||
}
|
||||
// Find possible `Link` configuration content.
|
||||
if _, v := gutil.MapPossibleItemByKey(nodeMap, "Link"); v != nil {
|
||||
node.Link = gconv.String(v)
|
||||
}
|
||||
return node
|
||||
}
|
||||
|
|
@ -0,0 +1,46 @@
|
|||
// Copyright GoFrame Author(https://goframe.org). All Rights Reserved.
|
||||
//
|
||||
// This Source Code Form is subject to the terms of the MIT License.
|
||||
// If a copy of the MIT was not distributed with this file,
|
||||
// You can obtain one at https://github.com/gogf/gf.
|
||||
|
||||
package database
|
||||
|
||||
import (
|
||||
"database/sql"
|
||||
)
|
||||
|
||||
// DriverDefault is the default driver for mysql database, which does nothing.
|
||||
type DriverDefault struct {
|
||||
*Core
|
||||
}
|
||||
|
||||
func init() {
|
||||
if err := Register("default", &DriverDefault{}); err != nil {
|
||||
panic(err)
|
||||
}
|
||||
}
|
||||
|
||||
// New creates and returns a database object for mysql.
|
||||
// It implements the interface of gdb.Driver for extra database driver installation.
|
||||
func (d *DriverDefault) New(core *Core, node *ConfigNode) (DB, error) {
|
||||
return &DriverDefault{
|
||||
Core: core,
|
||||
}, nil
|
||||
}
|
||||
|
||||
// Open creates and returns an underlying sql.DB object for mysql.
|
||||
// Note that it converts time.Time argument to local timezone in default.
|
||||
func (d *DriverDefault) Open(config *ConfigNode) (db *sql.DB, err error) {
|
||||
return
|
||||
}
|
||||
|
||||
// PingMaster pings the master node to check authentication or keeps the connection alive.
|
||||
func (d *DriverDefault) PingMaster() error {
|
||||
return nil
|
||||
}
|
||||
|
||||
// PingSlave pings the slave node to check authentication or keeps the connection alive.
|
||||
func (d *DriverDefault) PingSlave() error {
|
||||
return nil
|
||||
}
|
||||
|
|
@ -0,0 +1,31 @@
|
|||
// Copyright GoFrame Author(https://goframe.org). All Rights Reserved.
|
||||
//
|
||||
// This Source Code Form is subject to the terms of the MIT License.
|
||||
// If a copy of the MIT was not distributed with this file,
|
||||
// You can obtain one at https://github.com/gogf/gf.
|
||||
|
||||
package database
|
||||
|
||||
// DriverWrapper is a driver wrapper for extending features with embedded driver.
|
||||
type DriverWrapper struct {
|
||||
driver Driver
|
||||
}
|
||||
|
||||
// New creates and returns a database object for mysql.
|
||||
// It implements the interface of gdb.Driver for extra database driver installation.
|
||||
func (d *DriverWrapper) New(core *Core, node *ConfigNode) (DB, error) {
|
||||
db, err := d.driver.New(core, node)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return &DriverWrapperDB{
|
||||
DB: db,
|
||||
}, nil
|
||||
}
|
||||
|
||||
// newDriverWrapper creates and returns a driver wrapper.
|
||||
func newDriverWrapper(driver Driver) Driver {
|
||||
return &DriverWrapper{
|
||||
driver: driver,
|
||||
}
|
||||
}
|
||||
|
|
@ -0,0 +1,131 @@
|
|||
// Copyright GoFrame Author(https://goframe.org). All Rights Reserved.
|
||||
//
|
||||
// This Source Code Form is subject to the terms of the MIT License.
|
||||
// If a copy of the MIT was not distributed with this file,
|
||||
// You can obtain one at https://github.com/gogf/gf.
|
||||
|
||||
package database
|
||||
|
||||
import (
|
||||
"context"
|
||||
"database/sql"
|
||||
"fmt"
|
||||
|
||||
"git.magicany.cc/black1552/gin-base/database/intlog"
|
||||
"github.com/gogf/gf/v2/container/gvar"
|
||||
"github.com/gogf/gf/v2/encoding/gjson"
|
||||
"github.com/gogf/gf/v2/errors/gcode"
|
||||
"github.com/gogf/gf/v2/errors/gerror"
|
||||
"github.com/gogf/gf/v2/os/gcache"
|
||||
"github.com/gogf/gf/v2/text/gstr"
|
||||
"github.com/gogf/gf/v2/util/gutil"
|
||||
)
|
||||
|
||||
// DriverWrapperDB is a DB wrapper for extending features with embedded DB.
|
||||
type DriverWrapperDB struct {
|
||||
DB
|
||||
}
|
||||
|
||||
// Open creates and returns an underlying sql.DB object for pgsql.
|
||||
// https://pkg.go.dev/github.com/lib/pq
|
||||
func (d *DriverWrapperDB) Open(node *ConfigNode) (db *sql.DB, err error) {
|
||||
var ctx = d.GetCtx()
|
||||
intlog.PrintFunc(ctx, func() string {
|
||||
return fmt.Sprintf(`open new connection:%s`, gjson.MustEncode(node))
|
||||
})
|
||||
return d.DB.Open(node)
|
||||
}
|
||||
|
||||
// Tables retrieves and returns the tables of current schema.
|
||||
// It's mainly used in cli tool chain for automatically generating the models.
|
||||
func (d *DriverWrapperDB) Tables(ctx context.Context, schema ...string) (tables []string, err error) {
|
||||
ctx = context.WithValue(ctx, ctxKeyInternalProducedSQL, struct{}{})
|
||||
return d.DB.Tables(ctx, schema...)
|
||||
}
|
||||
|
||||
// TableFields retrieves and returns the fields' information of specified table of current
|
||||
// schema.
|
||||
//
|
||||
// The parameter `link` is optional, if given nil it automatically retrieves a raw sql connection
|
||||
// as its link to proceed necessary sql query.
|
||||
//
|
||||
// Note that it returns a map containing the field name and its corresponding fields.
|
||||
// As a map is unsorted, the TableField struct has an "Index" field marks its sequence in
|
||||
// the fields.
|
||||
//
|
||||
// It's using cache feature to enhance the performance, which is never expired util the
|
||||
// process restarts.
|
||||
func (d *DriverWrapperDB) TableFields(
|
||||
ctx context.Context, table string, schema ...string,
|
||||
) (fields map[string]*TableField, err error) {
|
||||
if table == "" {
|
||||
return nil, nil
|
||||
}
|
||||
charL, charR := d.GetChars()
|
||||
table = gstr.Trim(table, charL+charR)
|
||||
if gstr.Contains(table, " ") {
|
||||
return nil, gerror.NewCode(
|
||||
gcode.CodeInvalidParameter,
|
||||
"function TableFields supports only single table operations",
|
||||
)
|
||||
}
|
||||
var (
|
||||
innerMemCache = d.GetCore().GetInnerMemCache()
|
||||
// prefix:group@schema#table
|
||||
cacheKey = genTableFieldsCacheKey(
|
||||
d.GetGroup(),
|
||||
gutil.GetOrDefaultStr(d.GetSchema(), schema...),
|
||||
table,
|
||||
)
|
||||
cacheFunc = func(ctx context.Context) (any, error) {
|
||||
return d.DB.TableFields(
|
||||
context.WithValue(ctx, ctxKeyInternalProducedSQL, struct{}{}),
|
||||
table, schema...,
|
||||
)
|
||||
}
|
||||
value *gvar.Var
|
||||
)
|
||||
value, err = innerMemCache.GetOrSetFuncLock(
|
||||
ctx, cacheKey, cacheFunc, gcache.DurationNoExpire,
|
||||
)
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
if !value.IsNil() {
|
||||
fields = value.Val().(map[string]*TableField)
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
// DoInsert inserts or updates data for given table.
|
||||
// This function is usually used for custom interface definition, you do not need call it manually.
|
||||
// The parameter `data` can be type of map/gmap/struct/*struct/[]map/[]struct, etc.
|
||||
// Eg:
|
||||
// Data(g.Map{"uid": 10000, "name":"john"})
|
||||
// Data(g.Slice{g.Map{"uid": 10000, "name":"john"}, g.Map{"uid": 20000, "name":"smith"})
|
||||
//
|
||||
// The parameter `option` values are as follows:
|
||||
// InsertOptionDefault: just insert, if there's unique/primary key in the data, it returns error;
|
||||
// InsertOptionReplace: if there's unique/primary key in the data, it deletes it from table and inserts a new one;
|
||||
// InsertOptionSave: if there's unique/primary key in the data, it updates it or else inserts a new one;
|
||||
// InsertOptionIgnore: if there's unique/primary key in the data, it ignores the inserting;
|
||||
func (d *DriverWrapperDB) DoInsert(
|
||||
ctx context.Context, link Link, table string, list List, option DoInsertOption,
|
||||
) (result sql.Result, err error) {
|
||||
if len(list) == 0 {
|
||||
return nil, gerror.NewCodef(
|
||||
gcode.CodeInvalidRequest,
|
||||
`data list is empty for %s operation`,
|
||||
GetInsertOperationByOption(option.InsertOption),
|
||||
)
|
||||
}
|
||||
|
||||
// Convert data type before commit it to underlying db driver.
|
||||
for i, item := range list {
|
||||
list[i], err = d.GetCore().ConvertDataForRecord(ctx, item, table)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
}
|
||||
return d.DB.DoInsert(ctx, link, table, list, option)
|
||||
}
|
||||
File diff suppressed because it is too large
Load Diff
|
|
@ -0,0 +1,350 @@
|
|||
// Copyright GoFrame Author(https://goframe.org). All Rights Reserved.
|
||||
//
|
||||
// This Source Code Form is subject to the terms of the MIT License.
|
||||
// If a copy of the MIT was not distributed with this file,
|
||||
// You can obtain one at https://github.com/gogf/gf.
|
||||
|
||||
package database
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
|
||||
"github.com/gogf/gf/v2/text/gregex"
|
||||
"github.com/gogf/gf/v2/text/gstr"
|
||||
"github.com/gogf/gf/v2/util/gconv"
|
||||
)
|
||||
|
||||
// Model is core struct implementing the DAO for ORM.
|
||||
type Model struct {
|
||||
db DB // Underlying DB interface.
|
||||
tx TX // Underlying TX interface.
|
||||
rawSql string // rawSql is the raw SQL string which marks a raw SQL based Model not a table based Model.
|
||||
schema string // Custom database schema.
|
||||
linkType int // Mark for operation on master or slave.
|
||||
tablesInit string // Table names when model initialization.
|
||||
tables string // Operation table names, which can be more than one table names and aliases, like: "user", "user u", "user u, user_detail ud".
|
||||
fields []any // Operation fields, multiple fields joined using char ','.
|
||||
fieldsEx []any // Excluded operation fields, it here uses slice instead of string type for quick filtering.
|
||||
withArray []any // Arguments for With feature.
|
||||
withAll bool // Enable model association operations on all objects that have "with" tag in the struct.
|
||||
extraArgs []any // Extra custom arguments for sql, which are prepended to the arguments before sql committed to underlying driver.
|
||||
whereBuilder *WhereBuilder // Condition builder for where operation.
|
||||
groupBy string // Used for "group by" statement.
|
||||
orderBy string // Used for "order by" statement.
|
||||
having []any // Used for "having..." statement.
|
||||
start int // Used for "select ... start, limit ..." statement.
|
||||
limit int // Used for "select ... start, limit ..." statement.
|
||||
option int // Option for extra operation features.
|
||||
offset int // Offset statement for some databases grammar.
|
||||
partition string // Partition table partition name.
|
||||
data any // Data for operation, which can be type of map/[]map/struct/*struct/string, etc.
|
||||
batch int // Batch number for batch Insert/Replace/Save operations.
|
||||
filter bool // Filter data and where key-value pairs according to the fields of the table.
|
||||
distinct string // Force the query to only return distinct results.
|
||||
lockInfo string // Lock for update or in shared lock.
|
||||
cacheEnabled bool // Enable sql result cache feature, which is mainly for indicating cache duration(especially 0) usage.
|
||||
cacheOption CacheOption // Cache option for query statement.
|
||||
pageCacheOption []CacheOption // Cache option for paging query statement.
|
||||
hookHandler HookHandler // Hook functions for model hook feature.
|
||||
unscoped bool // Disables soft deleting features when select/delete operations.
|
||||
safe bool // If true, it clones and returns a new model object whenever operation done; or else it changes the attribute of current model.
|
||||
onDuplicate any // onDuplicate is used for on Upsert clause.
|
||||
onDuplicateEx any // onDuplicateEx is used for excluding some columns on Upsert clause.
|
||||
onConflict any // onConflict is used for conflict keys on Upsert clause.
|
||||
tableAliasMap map[string]string // Table alias to true table name, usually used in join statements.
|
||||
softTimeOption SoftTimeOption // SoftTimeOption is the option to customize soft time feature for Model.
|
||||
shardingConfig ShardingConfig // ShardingConfig for database/table sharding feature.
|
||||
shardingValue any // Sharding value for sharding feature.
|
||||
}
|
||||
|
||||
// ModelHandler is a function that handles given Model and returns a new Model that is custom modified.
|
||||
type ModelHandler func(m *Model) *Model
|
||||
|
||||
// ChunkHandler is a function that is used in function Chunk, which handles given Result and error.
|
||||
// It returns true if it wants to continue chunking, or else it returns false to stop chunking.
|
||||
type ChunkHandler func(result Result, err error) bool
|
||||
|
||||
const (
|
||||
linkTypeMaster = 1
|
||||
linkTypeSlave = 2
|
||||
defaultField = "*"
|
||||
whereHolderOperatorWhere = 1
|
||||
whereHolderOperatorAnd = 2
|
||||
whereHolderOperatorOr = 3
|
||||
whereHolderTypeDefault = "Default"
|
||||
whereHolderTypeNoArgs = "NoArgs"
|
||||
whereHolderTypeIn = "In"
|
||||
)
|
||||
|
||||
// Model creates and returns a new ORM model from given schema.
|
||||
// The parameter `tableNameQueryOrStruct` can be more than one table names, and also alias name, like:
|
||||
// 1. Model names:
|
||||
// db.Model("user")
|
||||
// db.Model("user u")
|
||||
// db.Model("user, user_detail")
|
||||
// db.Model("user u, user_detail ud")
|
||||
// 2. Model name with alias:
|
||||
// db.Model("user", "u")
|
||||
// 3. Model name with sub-query:
|
||||
// db.Model("? AS a, ? AS b", subQuery1, subQuery2)
|
||||
func (c *Core) Model(tableNameQueryOrStruct ...any) *Model {
|
||||
var (
|
||||
ctx = c.db.GetCtx()
|
||||
tableStr string
|
||||
tableName string
|
||||
extraArgs []any
|
||||
)
|
||||
// Model creation with sub-query.
|
||||
if len(tableNameQueryOrStruct) > 1 {
|
||||
conditionStr := gconv.String(tableNameQueryOrStruct[0])
|
||||
if gstr.Contains(conditionStr, "?") {
|
||||
whereHolder := WhereHolder{
|
||||
Where: conditionStr,
|
||||
Args: tableNameQueryOrStruct[1:],
|
||||
}
|
||||
tableStr, extraArgs = formatWhereHolder(ctx, c.db, formatWhereHolderInput{
|
||||
WhereHolder: whereHolder,
|
||||
OmitNil: false,
|
||||
OmitEmpty: false,
|
||||
Schema: "",
|
||||
Table: "",
|
||||
})
|
||||
}
|
||||
}
|
||||
// Normal model creation.
|
||||
if tableStr == "" {
|
||||
tableNames := make([]string, len(tableNameQueryOrStruct))
|
||||
for k, v := range tableNameQueryOrStruct {
|
||||
if s, ok := v.(string); ok {
|
||||
tableNames[k] = s
|
||||
} else if tableName = getTableNameFromOrmTag(v); tableName != "" {
|
||||
tableNames[k] = tableName
|
||||
}
|
||||
}
|
||||
if len(tableNames) > 1 {
|
||||
tableStr = fmt.Sprintf(
|
||||
`%s AS %s`, c.QuotePrefixTableName(tableNames[0]), c.QuoteWord(tableNames[1]),
|
||||
)
|
||||
} else if len(tableNames) == 1 {
|
||||
tableStr = c.QuotePrefixTableName(tableNames[0])
|
||||
}
|
||||
}
|
||||
m := &Model{
|
||||
db: c.db,
|
||||
schema: c.schema,
|
||||
tablesInit: tableStr,
|
||||
tables: tableStr,
|
||||
start: -1,
|
||||
offset: -1,
|
||||
filter: true,
|
||||
extraArgs: extraArgs,
|
||||
tableAliasMap: make(map[string]string),
|
||||
}
|
||||
m.whereBuilder = m.Builder()
|
||||
if defaultModelSafe {
|
||||
m.safe = true
|
||||
}
|
||||
return m
|
||||
}
|
||||
|
||||
// Raw creates and returns a model based on a raw sql not a table.
|
||||
// Example:
|
||||
//
|
||||
// db.Raw("SELECT * FROM `user` WHERE `name` = ?", "john").Scan(&result)
|
||||
func (c *Core) Raw(rawSql string, args ...any) *Model {
|
||||
model := c.Model()
|
||||
model.rawSql = rawSql
|
||||
model.extraArgs = args
|
||||
return model
|
||||
}
|
||||
|
||||
// Raw sets current model as a raw sql model.
|
||||
// Example:
|
||||
//
|
||||
// db.Raw("SELECT * FROM `user` WHERE `name` = ?", "john").Scan(&result)
|
||||
//
|
||||
// See Core.Raw.
|
||||
func (m *Model) Raw(rawSql string, args ...any) *Model {
|
||||
model := m.db.Raw(rawSql, args...)
|
||||
model.db = m.db
|
||||
model.tx = m.tx
|
||||
return model
|
||||
}
|
||||
|
||||
func (tx *TXCore) Raw(rawSql string, args ...any) *Model {
|
||||
return tx.Model().Raw(rawSql, args...)
|
||||
}
|
||||
|
||||
// With creates and returns an ORM model based on metadata of given object.
|
||||
func (c *Core) With(objects ...any) *Model {
|
||||
return c.db.Model().With(objects...)
|
||||
}
|
||||
|
||||
// Partition sets Partition name.
|
||||
// Example:
|
||||
// dao.User.Ctx(ctx).Partition("p1","p2","p3").All()
|
||||
func (m *Model) Partition(partitions ...string) *Model {
|
||||
model := m.getModel()
|
||||
model.partition = gstr.Join(partitions, ",")
|
||||
return model
|
||||
}
|
||||
|
||||
// Model acts like Core.Model except it operates on transaction.
|
||||
// See Core.Model.
|
||||
func (tx *TXCore) Model(tableNameQueryOrStruct ...any) *Model {
|
||||
model := tx.db.Model(tableNameQueryOrStruct...)
|
||||
model.db = tx.db
|
||||
model.tx = tx
|
||||
return model
|
||||
}
|
||||
|
||||
// With acts like Core.With except it operates on transaction.
|
||||
// See Core.With.
|
||||
func (tx *TXCore) With(object any) *Model {
|
||||
return tx.Model().With(object)
|
||||
}
|
||||
|
||||
// Ctx sets the context for current operation.
|
||||
func (m *Model) Ctx(ctx context.Context) *Model {
|
||||
if ctx == nil {
|
||||
return m
|
||||
}
|
||||
model := m.getModel()
|
||||
model.db = model.db.Ctx(ctx)
|
||||
if m.tx != nil {
|
||||
model.tx = model.tx.Ctx(ctx)
|
||||
}
|
||||
return model
|
||||
}
|
||||
|
||||
// GetCtx returns the context for current Model.
|
||||
// It returns `context.Background()` is there's no context previously set.
|
||||
func (m *Model) GetCtx() context.Context {
|
||||
if m.tx != nil && m.tx.GetCtx() != nil {
|
||||
return m.tx.GetCtx()
|
||||
}
|
||||
return m.db.GetCtx()
|
||||
}
|
||||
|
||||
// As sets an alias name for current table.
|
||||
func (m *Model) As(as string) *Model {
|
||||
if m.tables != "" {
|
||||
model := m.getModel()
|
||||
split := " JOIN "
|
||||
if gstr.ContainsI(model.tables, split) {
|
||||
// For join table.
|
||||
array := gstr.Split(model.tables, split)
|
||||
array[len(array)-1], _ = gregex.ReplaceString(`(.+) ON`, fmt.Sprintf(`$1 AS %s ON`, as), array[len(array)-1])
|
||||
model.tables = gstr.Join(array, split)
|
||||
} else {
|
||||
// For base table.
|
||||
model.tables = gstr.TrimRight(model.tables) + " AS " + as
|
||||
}
|
||||
return model
|
||||
}
|
||||
return m
|
||||
}
|
||||
|
||||
// DB sets/changes the db object for current operation.
|
||||
func (m *Model) DB(db DB) *Model {
|
||||
model := m.getModel()
|
||||
model.db = db
|
||||
return model
|
||||
}
|
||||
|
||||
// TX sets/changes the transaction for current operation.
|
||||
func (m *Model) TX(tx TX) *Model {
|
||||
model := m.getModel()
|
||||
model.db = tx.GetDB()
|
||||
model.tx = tx
|
||||
return model
|
||||
}
|
||||
|
||||
// Schema sets the schema for current operation.
|
||||
func (m *Model) Schema(schema string) *Model {
|
||||
model := m.getModel()
|
||||
model.schema = schema
|
||||
return model
|
||||
}
|
||||
|
||||
// Clone creates and returns a new model which is a Clone of current model.
|
||||
// Note that it uses deep-copy for the Clone.
|
||||
func (m *Model) Clone() *Model {
|
||||
newModel := (*Model)(nil)
|
||||
if m.tx != nil {
|
||||
newModel = m.tx.Model(m.tablesInit)
|
||||
} else {
|
||||
newModel = m.db.Model(m.tablesInit)
|
||||
}
|
||||
// Basic attributes copy.
|
||||
*newModel = *m
|
||||
// WhereBuilder copy, note the attribute pointer.
|
||||
newModel.whereBuilder = m.whereBuilder.Clone()
|
||||
newModel.whereBuilder.model = newModel
|
||||
// Shallow copy slice attributes.
|
||||
if n := len(m.fields); n > 0 {
|
||||
newModel.fields = make([]any, n)
|
||||
copy(newModel.fields, m.fields)
|
||||
}
|
||||
if n := len(m.fieldsEx); n > 0 {
|
||||
newModel.fieldsEx = make([]any, n)
|
||||
copy(newModel.fieldsEx, m.fieldsEx)
|
||||
}
|
||||
if n := len(m.extraArgs); n > 0 {
|
||||
newModel.extraArgs = make([]any, n)
|
||||
copy(newModel.extraArgs, m.extraArgs)
|
||||
}
|
||||
if n := len(m.withArray); n > 0 {
|
||||
newModel.withArray = make([]any, n)
|
||||
copy(newModel.withArray, m.withArray)
|
||||
}
|
||||
if n := len(m.having); n > 0 {
|
||||
newModel.having = make([]any, n)
|
||||
copy(newModel.having, m.having)
|
||||
}
|
||||
return newModel
|
||||
}
|
||||
|
||||
// Master marks the following operation on master node.
|
||||
func (m *Model) Master() *Model {
|
||||
model := m.getModel()
|
||||
model.linkType = linkTypeMaster
|
||||
return model
|
||||
}
|
||||
|
||||
// Slave marks the following operation on slave node.
|
||||
// Note that it makes sense only if there's any slave node configured.
|
||||
func (m *Model) Slave() *Model {
|
||||
model := m.getModel()
|
||||
model.linkType = linkTypeSlave
|
||||
return model
|
||||
}
|
||||
|
||||
// Safe marks this model safe or unsafe. If safe is true, it clones and returns a new model object
|
||||
// whenever the operation done, or else it changes the attribute of current model.
|
||||
func (m *Model) Safe(safe ...bool) *Model {
|
||||
if len(safe) > 0 {
|
||||
m.safe = safe[0]
|
||||
} else {
|
||||
m.safe = true
|
||||
}
|
||||
return m
|
||||
}
|
||||
|
||||
// Args sets custom arguments for model operation.
|
||||
func (m *Model) Args(args ...any) *Model {
|
||||
model := m.getModel()
|
||||
model.extraArgs = append(model.extraArgs, args)
|
||||
return model
|
||||
}
|
||||
|
||||
// Handler calls each of `handlers` on current Model and returns a new Model.
|
||||
// ModelHandler is a function that handles given Model and returns a new Model that is custom modified.
|
||||
func (m *Model) Handler(handlers ...ModelHandler) *Model {
|
||||
model := m.getModel()
|
||||
for _, handler := range handlers {
|
||||
model = handler(model)
|
||||
}
|
||||
return model
|
||||
}
|
||||
|
|
@ -0,0 +1,124 @@
|
|||
// Copyright GoFrame Author(https://goframe.org). All Rights Reserved.
|
||||
//
|
||||
// This Source Code Form is subject to the terms of the MIT License.
|
||||
// If a copy of the MIT was not distributed with this file,
|
||||
// You can obtain one at https://github.com/gogf/gf.
|
||||
|
||||
package database
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
)
|
||||
|
||||
// WhereBuilder holds multiple where conditions in a group.
|
||||
type WhereBuilder struct {
|
||||
model *Model // A WhereBuilder should be bound to certain Model.
|
||||
whereHolder []WhereHolder // Condition strings for where operation.
|
||||
}
|
||||
|
||||
// WhereHolder is the holder for where condition preparing.
|
||||
type WhereHolder struct {
|
||||
Type string // Type of this holder.
|
||||
Operator int // Operator for this holder.
|
||||
Where any // Where parameter, which can commonly be type of string/map/struct.
|
||||
Args []any // Arguments for where parameter.
|
||||
Prefix string // Field prefix, eg: "user.", "order.".
|
||||
}
|
||||
|
||||
// Builder creates and returns a WhereBuilder. Please note that the builder is chain-safe.
|
||||
func (m *Model) Builder() *WhereBuilder {
|
||||
b := &WhereBuilder{
|
||||
model: m,
|
||||
whereHolder: make([]WhereHolder, 0),
|
||||
}
|
||||
return b
|
||||
}
|
||||
|
||||
// getBuilder creates and returns a cloned WhereBuilder of current WhereBuilder
|
||||
func (b *WhereBuilder) getBuilder() *WhereBuilder {
|
||||
return b.Clone()
|
||||
}
|
||||
|
||||
// Clone clones and returns a WhereBuilder that is a copy of current one.
|
||||
func (b *WhereBuilder) Clone() *WhereBuilder {
|
||||
newBuilder := b.model.Builder()
|
||||
newBuilder.whereHolder = make([]WhereHolder, len(b.whereHolder))
|
||||
copy(newBuilder.whereHolder, b.whereHolder)
|
||||
return newBuilder
|
||||
}
|
||||
|
||||
// Build builds current WhereBuilder and returns the condition string and parameters.
|
||||
func (b *WhereBuilder) Build() (conditionWhere string, conditionArgs []any) {
|
||||
var (
|
||||
ctx = b.model.GetCtx()
|
||||
autoPrefix = b.model.getAutoPrefix()
|
||||
tableForMappingAndFiltering = b.model.tables
|
||||
)
|
||||
if len(b.whereHolder) > 0 {
|
||||
for _, holder := range b.whereHolder {
|
||||
if holder.Prefix == "" {
|
||||
holder.Prefix = autoPrefix
|
||||
}
|
||||
switch holder.Operator {
|
||||
case whereHolderOperatorWhere, whereHolderOperatorAnd:
|
||||
newWhere, newArgs := formatWhereHolder(ctx, b.model.db, formatWhereHolderInput{
|
||||
WhereHolder: holder,
|
||||
OmitNil: b.model.option&optionOmitNilWhere > 0,
|
||||
OmitEmpty: b.model.option&optionOmitEmptyWhere > 0,
|
||||
Schema: b.model.schema,
|
||||
Table: tableForMappingAndFiltering,
|
||||
})
|
||||
if len(newWhere) > 0 {
|
||||
if len(conditionWhere) == 0 {
|
||||
conditionWhere = newWhere
|
||||
} else if conditionWhere[0] == '(' {
|
||||
conditionWhere = fmt.Sprintf(`%s AND (%s)`, conditionWhere, newWhere)
|
||||
} else {
|
||||
conditionWhere = fmt.Sprintf(`(%s) AND (%s)`, conditionWhere, newWhere)
|
||||
}
|
||||
conditionArgs = append(conditionArgs, newArgs...)
|
||||
}
|
||||
|
||||
case whereHolderOperatorOr:
|
||||
newWhere, newArgs := formatWhereHolder(ctx, b.model.db, formatWhereHolderInput{
|
||||
WhereHolder: holder,
|
||||
OmitNil: b.model.option&optionOmitNilWhere > 0,
|
||||
OmitEmpty: b.model.option&optionOmitEmptyWhere > 0,
|
||||
Schema: b.model.schema,
|
||||
Table: tableForMappingAndFiltering,
|
||||
})
|
||||
if len(newWhere) > 0 {
|
||||
if len(conditionWhere) == 0 {
|
||||
conditionWhere = newWhere
|
||||
} else if conditionWhere[0] == '(' {
|
||||
conditionWhere = fmt.Sprintf(`%s OR (%s)`, conditionWhere, newWhere)
|
||||
} else {
|
||||
conditionWhere = fmt.Sprintf(`(%s) OR (%s)`, conditionWhere, newWhere)
|
||||
}
|
||||
conditionArgs = append(conditionArgs, newArgs...)
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
// convertWhereBuilder converts parameter `where` to condition string and parameters if `where` is also a WhereBuilder.
|
||||
func (b *WhereBuilder) convertWhereBuilder(where any, args []any) (newWhere any, newArgs []any) {
|
||||
var builder *WhereBuilder
|
||||
switch v := where.(type) {
|
||||
case WhereBuilder:
|
||||
builder = &v
|
||||
|
||||
case *WhereBuilder:
|
||||
builder = v
|
||||
}
|
||||
if builder != nil {
|
||||
conditionWhere, conditionArgs := builder.Build()
|
||||
if conditionWhere != "" && (len(b.whereHolder) == 0 || len(builder.whereHolder) > 1) {
|
||||
conditionWhere = "(" + conditionWhere + ")"
|
||||
}
|
||||
return conditionWhere, conditionArgs
|
||||
}
|
||||
return where, args
|
||||
}
|
||||
|
|
@ -0,0 +1,171 @@
|
|||
// Copyright GoFrame Author(https://goframe.org). All Rights Reserved.
|
||||
//
|
||||
// This Source Code Form is subject to the terms of the MIT License.
|
||||
// If a copy of the MIT was not distributed with this file,
|
||||
// You can obtain one at https://github.com/gogf/gf.
|
||||
|
||||
package database
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
|
||||
"github.com/gogf/gf/v2/text/gstr"
|
||||
)
|
||||
|
||||
// doWhereType sets the condition statement for the model. The parameter `where` can be type of
|
||||
// string/map/gmap/slice/struct/*struct, etc. Note that, if it's called more than one times,
|
||||
// multiple conditions will be joined into where statement using "AND".
|
||||
func (b *WhereBuilder) doWhereType(whereType string, where any, args ...any) *WhereBuilder {
|
||||
where, args = b.convertWhereBuilder(where, args)
|
||||
|
||||
builder := b.getBuilder()
|
||||
if builder.whereHolder == nil {
|
||||
builder.whereHolder = make([]WhereHolder, 0)
|
||||
}
|
||||
if whereType == "" {
|
||||
if len(args) == 0 {
|
||||
whereType = whereHolderTypeNoArgs
|
||||
} else {
|
||||
whereType = whereHolderTypeDefault
|
||||
}
|
||||
}
|
||||
builder.whereHolder = append(builder.whereHolder, WhereHolder{
|
||||
Type: whereType,
|
||||
Operator: whereHolderOperatorWhere,
|
||||
Where: where,
|
||||
Args: args,
|
||||
})
|
||||
return builder
|
||||
}
|
||||
|
||||
// doWherefType builds condition string using fmt.Sprintf and arguments.
|
||||
// Note that if the number of `args` is more than the placeholder in `format`,
|
||||
// the extra `args` will be used as the where condition arguments of the Model.
|
||||
func (b *WhereBuilder) doWherefType(t string, format string, args ...any) *WhereBuilder {
|
||||
var (
|
||||
placeHolderCount = gstr.Count(format, "?")
|
||||
conditionStr = fmt.Sprintf(format, args[:len(args)-placeHolderCount]...)
|
||||
)
|
||||
return b.doWhereType(t, conditionStr, args[len(args)-placeHolderCount:]...)
|
||||
}
|
||||
|
||||
// Where sets the condition statement for the builder. The parameter `where` can be type of
|
||||
// string/map/gmap/slice/struct/*struct, etc. Note that, if it's called more than one times,
|
||||
// multiple conditions will be joined into where statement using "AND".
|
||||
// Eg:
|
||||
// Where("uid=10000")
|
||||
// Where("uid", 10000)
|
||||
// Where("money>? AND name like ?", 99999, "vip_%")
|
||||
// Where("uid", 1).Where("name", "john")
|
||||
// Where("status IN (?)", g.Slice{1,2,3})
|
||||
// Where("age IN(?,?)", 18, 50)
|
||||
// Where(User{ Id : 1, UserName : "john"}).
|
||||
func (b *WhereBuilder) Where(where any, args ...any) *WhereBuilder {
|
||||
return b.doWhereType(``, where, args...)
|
||||
}
|
||||
|
||||
// Wheref builds condition string using fmt.Sprintf and arguments.
|
||||
// Note that if the number of `args` is more than the placeholder in `format`,
|
||||
// the extra `args` will be used as the where condition arguments of the Model.
|
||||
// Eg:
|
||||
// Wheref(`amount<? and status=%s`, "paid", 100) => WHERE `amount`<100 and status='paid'
|
||||
// Wheref(`amount<%d and status=%s`, 100, "paid") => WHERE `amount`<100 and status='paid'
|
||||
func (b *WhereBuilder) Wheref(format string, args ...any) *WhereBuilder {
|
||||
return b.doWherefType(``, format, args...)
|
||||
}
|
||||
|
||||
// WherePri does the same logic as Model.Where except that if the parameter `where`
|
||||
// is a single condition like int/string/float/slice, it treats the condition as the primary
|
||||
// key value. That is, if primary key is "id" and given `where` parameter as "123", the
|
||||
// WherePri function treats the condition as "id=123", but Model.Where treats the condition
|
||||
// as string "123".
|
||||
func (b *WhereBuilder) WherePri(where any, args ...any) *WhereBuilder {
|
||||
if len(args) > 0 {
|
||||
return b.Where(where, args...)
|
||||
}
|
||||
newWhere := GetPrimaryKeyCondition(b.model.getPrimaryKey(), where)
|
||||
return b.Where(newWhere[0], newWhere[1:]...)
|
||||
}
|
||||
|
||||
// WhereLT builds `column < value` statement.
|
||||
func (b *WhereBuilder) WhereLT(column string, value any) *WhereBuilder {
|
||||
return b.Wheref(`%s < ?`, b.model.QuoteWord(column), value)
|
||||
}
|
||||
|
||||
// WhereLTE builds `column <= value` statement.
|
||||
func (b *WhereBuilder) WhereLTE(column string, value any) *WhereBuilder {
|
||||
return b.Wheref(`%s <= ?`, b.model.QuoteWord(column), value)
|
||||
}
|
||||
|
||||
// WhereGT builds `column > value` statement.
|
||||
func (b *WhereBuilder) WhereGT(column string, value any) *WhereBuilder {
|
||||
return b.Wheref(`%s > ?`, b.model.QuoteWord(column), value)
|
||||
}
|
||||
|
||||
// WhereGTE builds `column >= value` statement.
|
||||
func (b *WhereBuilder) WhereGTE(column string, value any) *WhereBuilder {
|
||||
return b.Wheref(`%s >= ?`, b.model.QuoteWord(column), value)
|
||||
}
|
||||
|
||||
// WhereBetween builds `column BETWEEN min AND max` statement.
|
||||
func (b *WhereBuilder) WhereBetween(column string, min, max any) *WhereBuilder {
|
||||
return b.Wheref(`%s BETWEEN ? AND ?`, b.model.QuoteWord(column), min, max)
|
||||
}
|
||||
|
||||
// WhereLike builds `column LIKE like` statement.
|
||||
func (b *WhereBuilder) WhereLike(column string, like string) *WhereBuilder {
|
||||
return b.Wheref(`%s LIKE ?`, b.model.QuoteWord(column), like)
|
||||
}
|
||||
|
||||
// WhereIn builds `column IN (in)` statement.
|
||||
func (b *WhereBuilder) WhereIn(column string, in any) *WhereBuilder {
|
||||
return b.doWherefType(whereHolderTypeIn, `%s IN (?)`, b.model.QuoteWord(column), in)
|
||||
}
|
||||
|
||||
// WhereNull builds `columns[0] IS NULL AND columns[1] IS NULL ...` statement.
|
||||
func (b *WhereBuilder) WhereNull(columns ...string) *WhereBuilder {
|
||||
builder := b
|
||||
for _, column := range columns {
|
||||
builder = builder.Wheref(`%s IS NULL`, b.model.QuoteWord(column))
|
||||
}
|
||||
return builder
|
||||
}
|
||||
|
||||
// WhereNotBetween builds `column NOT BETWEEN min AND max` statement.
|
||||
func (b *WhereBuilder) WhereNotBetween(column string, min, max any) *WhereBuilder {
|
||||
return b.Wheref(`%s NOT BETWEEN ? AND ?`, b.model.QuoteWord(column), min, max)
|
||||
}
|
||||
|
||||
// WhereNotLike builds `column NOT LIKE like` statement.
|
||||
func (b *WhereBuilder) WhereNotLike(column string, like any) *WhereBuilder {
|
||||
return b.Wheref(`%s NOT LIKE ?`, b.model.QuoteWord(column), like)
|
||||
}
|
||||
|
||||
// WhereNot builds `column != value` statement.
|
||||
func (b *WhereBuilder) WhereNot(column string, value any) *WhereBuilder {
|
||||
return b.Wheref(`%s != ?`, b.model.QuoteWord(column), value)
|
||||
}
|
||||
|
||||
// WhereNotIn builds `column NOT IN (in)` statement.
|
||||
func (b *WhereBuilder) WhereNotIn(column string, in any) *WhereBuilder {
|
||||
return b.doWherefType(whereHolderTypeIn, `%s NOT IN (?)`, b.model.QuoteWord(column), in)
|
||||
}
|
||||
|
||||
// WhereNotNull builds `columns[0] IS NOT NULL AND columns[1] IS NOT NULL ...` statement.
|
||||
func (b *WhereBuilder) WhereNotNull(columns ...string) *WhereBuilder {
|
||||
builder := b
|
||||
for _, column := range columns {
|
||||
builder = builder.Wheref(`%s IS NOT NULL`, b.model.QuoteWord(column))
|
||||
}
|
||||
return builder
|
||||
}
|
||||
|
||||
// WhereExists builds `EXISTS (subQuery)` statement.
|
||||
func (b *WhereBuilder) WhereExists(subQuery *Model) *WhereBuilder {
|
||||
return b.Wheref(`EXISTS (?)`, subQuery)
|
||||
}
|
||||
|
||||
// WhereNotExists builds `NOT EXISTS (subQuery)` statement.
|
||||
func (b *WhereBuilder) WhereNotExists(subQuery *Model) *WhereBuilder {
|
||||
return b.Wheref(`NOT EXISTS (?)`, subQuery)
|
||||
}
|
||||
|
|
@ -0,0 +1,101 @@
|
|||
// Copyright GoFrame Author(https://goframe.org). All Rights Reserved.
|
||||
//
|
||||
// This Source Code Form is subject to the terms of the MIT License.
|
||||
// If a copy of the MIT was not distributed with this file,
|
||||
// You can obtain one at https://github.com/gogf/gf.
|
||||
|
||||
package database
|
||||
|
||||
// WherePrefix performs as Where, but it adds prefix to each field in where statement.
|
||||
// Eg:
|
||||
// WherePrefix("order", "status", "paid") => WHERE `order`.`status`='paid'
|
||||
// WherePrefix("order", struct{Status:"paid", "channel":"bank"}) => WHERE `order`.`status`='paid' AND `order`.`channel`='bank'
|
||||
func (b *WhereBuilder) WherePrefix(prefix string, where any, args ...any) *WhereBuilder {
|
||||
where, args = b.convertWhereBuilder(where, args)
|
||||
|
||||
builder := b.getBuilder()
|
||||
if builder.whereHolder == nil {
|
||||
builder.whereHolder = make([]WhereHolder, 0)
|
||||
}
|
||||
builder.whereHolder = append(builder.whereHolder, WhereHolder{
|
||||
Type: whereHolderTypeDefault,
|
||||
Operator: whereHolderOperatorWhere,
|
||||
Where: where,
|
||||
Args: args,
|
||||
Prefix: prefix,
|
||||
})
|
||||
return builder
|
||||
}
|
||||
|
||||
// WherePrefixLT builds `prefix.column < value` statement.
|
||||
func (b *WhereBuilder) WherePrefixLT(prefix string, column string, value any) *WhereBuilder {
|
||||
return b.Wheref(`%s.%s < ?`, b.model.QuoteWord(prefix), b.model.QuoteWord(column), value)
|
||||
}
|
||||
|
||||
// WherePrefixLTE builds `prefix.column <= value` statement.
|
||||
func (b *WhereBuilder) WherePrefixLTE(prefix string, column string, value any) *WhereBuilder {
|
||||
return b.Wheref(`%s.%s <= ?`, b.model.QuoteWord(prefix), b.model.QuoteWord(column), value)
|
||||
}
|
||||
|
||||
// WherePrefixGT builds `prefix.column > value` statement.
|
||||
func (b *WhereBuilder) WherePrefixGT(prefix string, column string, value any) *WhereBuilder {
|
||||
return b.Wheref(`%s.%s > ?`, b.model.QuoteWord(prefix), b.model.QuoteWord(column), value)
|
||||
}
|
||||
|
||||
// WherePrefixGTE builds `prefix.column >= value` statement.
|
||||
func (b *WhereBuilder) WherePrefixGTE(prefix string, column string, value any) *WhereBuilder {
|
||||
return b.Wheref(`%s.%s >= ?`, b.model.QuoteWord(prefix), b.model.QuoteWord(column), value)
|
||||
}
|
||||
|
||||
// WherePrefixBetween builds `prefix.column BETWEEN min AND max` statement.
|
||||
func (b *WhereBuilder) WherePrefixBetween(prefix string, column string, min, max any) *WhereBuilder {
|
||||
return b.Wheref(`%s.%s BETWEEN ? AND ?`, b.model.QuoteWord(prefix), b.model.QuoteWord(column), min, max)
|
||||
}
|
||||
|
||||
// WherePrefixLike builds `prefix.column LIKE like` statement.
|
||||
func (b *WhereBuilder) WherePrefixLike(prefix string, column string, like any) *WhereBuilder {
|
||||
return b.Wheref(`%s.%s LIKE ?`, b.model.QuoteWord(prefix), b.model.QuoteWord(column), like)
|
||||
}
|
||||
|
||||
// WherePrefixIn builds `prefix.column IN (in)` statement.
|
||||
func (b *WhereBuilder) WherePrefixIn(prefix string, column string, in any) *WhereBuilder {
|
||||
return b.doWherefType(whereHolderTypeIn, `%s.%s IN (?)`, b.model.QuoteWord(prefix), b.model.QuoteWord(column), in)
|
||||
}
|
||||
|
||||
// WherePrefixNull builds `prefix.columns[0] IS NULL AND prefix.columns[1] IS NULL ...` statement.
|
||||
func (b *WhereBuilder) WherePrefixNull(prefix string, columns ...string) *WhereBuilder {
|
||||
builder := b
|
||||
for _, column := range columns {
|
||||
builder = builder.Wheref(`%s.%s IS NULL`, b.model.QuoteWord(prefix), b.model.QuoteWord(column))
|
||||
}
|
||||
return builder
|
||||
}
|
||||
|
||||
// WherePrefixNotBetween builds `prefix.column NOT BETWEEN min AND max` statement.
|
||||
func (b *WhereBuilder) WherePrefixNotBetween(prefix string, column string, min, max any) *WhereBuilder {
|
||||
return b.Wheref(`%s.%s NOT BETWEEN ? AND ?`, b.model.QuoteWord(prefix), b.model.QuoteWord(column), min, max)
|
||||
}
|
||||
|
||||
// WherePrefixNotLike builds `prefix.column NOT LIKE like` statement.
|
||||
func (b *WhereBuilder) WherePrefixNotLike(prefix string, column string, like any) *WhereBuilder {
|
||||
return b.Wheref(`%s.%s NOT LIKE ?`, b.model.QuoteWord(prefix), b.model.QuoteWord(column), like)
|
||||
}
|
||||
|
||||
// WherePrefixNot builds `prefix.column != value` statement.
|
||||
func (b *WhereBuilder) WherePrefixNot(prefix string, column string, value any) *WhereBuilder {
|
||||
return b.Wheref(`%s.%s != ?`, b.model.QuoteWord(prefix), b.model.QuoteWord(column), value)
|
||||
}
|
||||
|
||||
// WherePrefixNotIn builds `prefix.column NOT IN (in)` statement.
|
||||
func (b *WhereBuilder) WherePrefixNotIn(prefix string, column string, in any) *WhereBuilder {
|
||||
return b.doWherefType(whereHolderTypeIn, `%s.%s NOT IN (?)`, b.model.QuoteWord(prefix), b.model.QuoteWord(column), in)
|
||||
}
|
||||
|
||||
// WherePrefixNotNull builds `prefix.columns[0] IS NOT NULL AND prefix.columns[1] IS NOT NULL ...` statement.
|
||||
func (b *WhereBuilder) WherePrefixNotNull(prefix string, columns ...string) *WhereBuilder {
|
||||
builder := b
|
||||
for _, column := range columns {
|
||||
builder = builder.Wheref(`%s.%s IS NOT NULL`, b.model.QuoteWord(prefix), b.model.QuoteWord(column))
|
||||
}
|
||||
return builder
|
||||
}
|
||||
|
|
@ -0,0 +1,125 @@
|
|||
// Copyright GoFrame Author(https://goframe.org). All Rights Reserved.
|
||||
//
|
||||
// This Source Code Form is subject to the terms of the MIT License.
|
||||
// If a copy of the MIT was not distributed with this file,
|
||||
// You can obtain one at https://github.com/gogf/gf.
|
||||
|
||||
package database
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
|
||||
"github.com/gogf/gf/v2/text/gstr"
|
||||
)
|
||||
|
||||
// WhereOr adds "OR" condition to the where statement.
|
||||
func (b *WhereBuilder) doWhereOrType(t string, where any, args ...any) *WhereBuilder {
|
||||
where, args = b.convertWhereBuilder(where, args)
|
||||
|
||||
builder := b.getBuilder()
|
||||
if builder.whereHolder == nil {
|
||||
builder.whereHolder = make([]WhereHolder, 0)
|
||||
}
|
||||
builder.whereHolder = append(builder.whereHolder, WhereHolder{
|
||||
Type: t,
|
||||
Operator: whereHolderOperatorOr,
|
||||
Where: where,
|
||||
Args: args,
|
||||
})
|
||||
return builder
|
||||
}
|
||||
|
||||
// WhereOrf builds `OR` condition string using fmt.Sprintf and arguments.
|
||||
func (b *WhereBuilder) doWhereOrfType(t string, format string, args ...any) *WhereBuilder {
|
||||
var (
|
||||
placeHolderCount = gstr.Count(format, "?")
|
||||
conditionStr = fmt.Sprintf(format, args[:len(args)-placeHolderCount]...)
|
||||
)
|
||||
return b.doWhereOrType(t, conditionStr, args[len(args)-placeHolderCount:]...)
|
||||
}
|
||||
|
||||
// WhereOr adds "OR" condition to the where statement.
|
||||
func (b *WhereBuilder) WhereOr(where any, args ...any) *WhereBuilder {
|
||||
return b.doWhereOrType(``, where, args...)
|
||||
}
|
||||
|
||||
// WhereOrf builds `OR` condition string using fmt.Sprintf and arguments.
|
||||
// Eg:
|
||||
// WhereOrf(`amount<? and status=%s`, "paid", 100) => WHERE xxx OR `amount`<100 and status='paid'
|
||||
// WhereOrf(`amount<%d and status=%s`, 100, "paid") => WHERE xxx OR `amount`<100 and status='paid'
|
||||
func (b *WhereBuilder) WhereOrf(format string, args ...any) *WhereBuilder {
|
||||
return b.doWhereOrfType(``, format, args...)
|
||||
}
|
||||
|
||||
// WhereOrNot builds `column != value` statement in `OR` conditions.
|
||||
func (b *WhereBuilder) WhereOrNot(column string, value any) *WhereBuilder {
|
||||
return b.WhereOrf(`%s != ?`, column, value)
|
||||
}
|
||||
|
||||
// WhereOrLT builds `column < value` statement in `OR` conditions.
|
||||
func (b *WhereBuilder) WhereOrLT(column string, value any) *WhereBuilder {
|
||||
return b.WhereOrf(`%s < ?`, column, value)
|
||||
}
|
||||
|
||||
// WhereOrLTE builds `column <= value` statement in `OR` conditions.
|
||||
func (b *WhereBuilder) WhereOrLTE(column string, value any) *WhereBuilder {
|
||||
return b.WhereOrf(`%s <= ?`, column, value)
|
||||
}
|
||||
|
||||
// WhereOrGT builds `column > value` statement in `OR` conditions.
|
||||
func (b *WhereBuilder) WhereOrGT(column string, value any) *WhereBuilder {
|
||||
return b.WhereOrf(`%s > ?`, column, value)
|
||||
}
|
||||
|
||||
// WhereOrGTE builds `column >= value` statement in `OR` conditions.
|
||||
func (b *WhereBuilder) WhereOrGTE(column string, value any) *WhereBuilder {
|
||||
return b.WhereOrf(`%s >= ?`, column, value)
|
||||
}
|
||||
|
||||
// WhereOrBetween builds `column BETWEEN min AND max` statement in `OR` conditions.
|
||||
func (b *WhereBuilder) WhereOrBetween(column string, min, max any) *WhereBuilder {
|
||||
return b.WhereOrf(`%s BETWEEN ? AND ?`, b.model.QuoteWord(column), min, max)
|
||||
}
|
||||
|
||||
// WhereOrLike builds `column LIKE 'like'` statement in `OR` conditions.
|
||||
func (b *WhereBuilder) WhereOrLike(column string, like any) *WhereBuilder {
|
||||
return b.WhereOrf(`%s LIKE ?`, b.model.QuoteWord(column), like)
|
||||
}
|
||||
|
||||
// WhereOrIn builds `column IN (in)` statement in `OR` conditions.
|
||||
func (b *WhereBuilder) WhereOrIn(column string, in any) *WhereBuilder {
|
||||
return b.doWhereOrfType(whereHolderTypeIn, `%s IN (?)`, b.model.QuoteWord(column), in)
|
||||
}
|
||||
|
||||
// WhereOrNull builds `columns[0] IS NULL OR columns[1] IS NULL ...` statement in `OR` conditions.
|
||||
func (b *WhereBuilder) WhereOrNull(columns ...string) *WhereBuilder {
|
||||
var builder *WhereBuilder
|
||||
for _, column := range columns {
|
||||
builder = b.WhereOrf(`%s IS NULL`, b.model.QuoteWord(column))
|
||||
}
|
||||
return builder
|
||||
}
|
||||
|
||||
// WhereOrNotBetween builds `column NOT BETWEEN min AND max` statement in `OR` conditions.
|
||||
func (b *WhereBuilder) WhereOrNotBetween(column string, min, max any) *WhereBuilder {
|
||||
return b.WhereOrf(`%s NOT BETWEEN ? AND ?`, b.model.QuoteWord(column), min, max)
|
||||
}
|
||||
|
||||
// WhereOrNotLike builds `column NOT LIKE like` statement in `OR` conditions.
|
||||
func (b *WhereBuilder) WhereOrNotLike(column string, like any) *WhereBuilder {
|
||||
return b.WhereOrf(`%s NOT LIKE ?`, b.model.QuoteWord(column), like)
|
||||
}
|
||||
|
||||
// WhereOrNotIn builds `column NOT IN (in)` statement.
|
||||
func (b *WhereBuilder) WhereOrNotIn(column string, in any) *WhereBuilder {
|
||||
return b.doWhereOrfType(whereHolderTypeIn, `%s NOT IN (?)`, b.model.QuoteWord(column), in)
|
||||
}
|
||||
|
||||
// WhereOrNotNull builds `columns[0] IS NOT NULL OR columns[1] IS NOT NULL ...` statement in `OR` conditions.
|
||||
func (b *WhereBuilder) WhereOrNotNull(columns ...string) *WhereBuilder {
|
||||
builder := b
|
||||
for _, column := range columns {
|
||||
builder = builder.WhereOrf(`%s IS NOT NULL`, b.model.QuoteWord(column))
|
||||
}
|
||||
return builder
|
||||
}
|
||||
|
|
@ -0,0 +1,98 @@
|
|||
// Copyright GoFrame Author(https://goframe.org). All Rights Reserved.
|
||||
//
|
||||
// This Source Code Form is subject to the terms of the MIT License.
|
||||
// If a copy of the MIT was not distributed with this file,
|
||||
// You can obtain one at https://github.com/gogf/gf.
|
||||
|
||||
package database
|
||||
|
||||
// WhereOrPrefix performs as WhereOr, but it adds prefix to each field in where statement.
|
||||
// Eg:
|
||||
// WhereOrPrefix("order", "status", "paid") => WHERE xxx OR (`order`.`status`='paid')
|
||||
// WhereOrPrefix("order", struct{Status:"paid", "channel":"bank"}) => WHERE xxx OR (`order`.`status`='paid' AND `order`.`channel`='bank')
|
||||
func (b *WhereBuilder) WhereOrPrefix(prefix string, where any, args ...any) *WhereBuilder {
|
||||
where, args = b.convertWhereBuilder(where, args)
|
||||
|
||||
builder := b.getBuilder()
|
||||
builder.whereHolder = append(builder.whereHolder, WhereHolder{
|
||||
Type: whereHolderTypeDefault,
|
||||
Operator: whereHolderOperatorOr,
|
||||
Where: where,
|
||||
Args: args,
|
||||
Prefix: prefix,
|
||||
})
|
||||
return builder
|
||||
}
|
||||
|
||||
// WhereOrPrefixNot builds `prefix.column != value` statement in `OR` conditions.
|
||||
func (b *WhereBuilder) WhereOrPrefixNot(prefix string, column string, value any) *WhereBuilder {
|
||||
return b.WhereOrf(`%s.%s != ?`, b.model.QuoteWord(prefix), b.model.QuoteWord(column), value)
|
||||
}
|
||||
|
||||
// WhereOrPrefixLT builds `prefix.column < value` statement in `OR` conditions.
|
||||
func (b *WhereBuilder) WhereOrPrefixLT(prefix string, column string, value any) *WhereBuilder {
|
||||
return b.WhereOrf(`%s.%s < ?`, b.model.QuoteWord(prefix), b.model.QuoteWord(column), value)
|
||||
}
|
||||
|
||||
// WhereOrPrefixLTE builds `prefix.column <= value` statement in `OR` conditions.
|
||||
func (b *WhereBuilder) WhereOrPrefixLTE(prefix string, column string, value any) *WhereBuilder {
|
||||
return b.WhereOrf(`%s.%s <= ?`, b.model.QuoteWord(prefix), b.model.QuoteWord(column), value)
|
||||
}
|
||||
|
||||
// WhereOrPrefixGT builds `prefix.column > value` statement in `OR` conditions.
|
||||
func (b *WhereBuilder) WhereOrPrefixGT(prefix string, column string, value any) *WhereBuilder {
|
||||
return b.WhereOrf(`%s.%s > ?`, b.model.QuoteWord(prefix), b.model.QuoteWord(column), value)
|
||||
}
|
||||
|
||||
// WhereOrPrefixGTE builds `prefix.column >= value` statement in `OR` conditions.
|
||||
func (b *WhereBuilder) WhereOrPrefixGTE(prefix string, column string, value any) *WhereBuilder {
|
||||
return b.WhereOrf(`%s.%s >= ?`, b.model.QuoteWord(prefix), b.model.QuoteWord(column), value)
|
||||
}
|
||||
|
||||
// WhereOrPrefixBetween builds `prefix.column BETWEEN min AND max` statement in `OR` conditions.
|
||||
func (b *WhereBuilder) WhereOrPrefixBetween(prefix string, column string, min, max any) *WhereBuilder {
|
||||
return b.WhereOrf(`%s.%s BETWEEN ? AND ?`, b.model.QuoteWord(prefix), b.model.QuoteWord(column), min, max)
|
||||
}
|
||||
|
||||
// WhereOrPrefixLike builds `prefix.column LIKE 'like'` statement in `OR` conditions.
|
||||
func (b *WhereBuilder) WhereOrPrefixLike(prefix string, column string, like any) *WhereBuilder {
|
||||
return b.WhereOrf(`%s.%s LIKE ?`, b.model.QuoteWord(prefix), b.model.QuoteWord(column), like)
|
||||
}
|
||||
|
||||
// WhereOrPrefixIn builds `prefix.column IN (in)` statement in `OR` conditions.
|
||||
func (b *WhereBuilder) WhereOrPrefixIn(prefix string, column string, in any) *WhereBuilder {
|
||||
return b.doWhereOrfType(whereHolderTypeIn, `%s.%s IN (?)`, b.model.QuoteWord(prefix), b.model.QuoteWord(column), in)
|
||||
}
|
||||
|
||||
// WhereOrPrefixNull builds `prefix.columns[0] IS NULL OR prefix.columns[1] IS NULL ...` statement in `OR` conditions.
|
||||
func (b *WhereBuilder) WhereOrPrefixNull(prefix string, columns ...string) *WhereBuilder {
|
||||
builder := b
|
||||
for _, column := range columns {
|
||||
builder = builder.WhereOrf(`%s.%s IS NULL`, b.model.QuoteWord(prefix), b.model.QuoteWord(column))
|
||||
}
|
||||
return builder
|
||||
}
|
||||
|
||||
// WhereOrPrefixNotBetween builds `prefix.column NOT BETWEEN min AND max` statement in `OR` conditions.
|
||||
func (b *WhereBuilder) WhereOrPrefixNotBetween(prefix string, column string, min, max any) *WhereBuilder {
|
||||
return b.WhereOrf(`%s.%s NOT BETWEEN ? AND ?`, b.model.QuoteWord(prefix), b.model.QuoteWord(column), min, max)
|
||||
}
|
||||
|
||||
// WhereOrPrefixNotLike builds `prefix.column NOT LIKE 'like'` statement in `OR` conditions.
|
||||
func (b *WhereBuilder) WhereOrPrefixNotLike(prefix string, column string, like any) *WhereBuilder {
|
||||
return b.WhereOrf(`%s.%s NOT LIKE ?`, b.model.QuoteWord(prefix), b.model.QuoteWord(column), like)
|
||||
}
|
||||
|
||||
// WhereOrPrefixNotIn builds `prefix.column NOT IN (in)` statement.
|
||||
func (b *WhereBuilder) WhereOrPrefixNotIn(prefix string, column string, in any) *WhereBuilder {
|
||||
return b.doWhereOrfType(whereHolderTypeIn, `%s.%s NOT IN (?)`, b.model.QuoteWord(prefix), b.model.QuoteWord(column), in)
|
||||
}
|
||||
|
||||
// WhereOrPrefixNotNull builds `prefix.columns[0] IS NOT NULL OR prefix.columns[1] IS NOT NULL ...` statement in `OR` conditions.
|
||||
func (b *WhereBuilder) WhereOrPrefixNotNull(prefix string, columns ...string) *WhereBuilder {
|
||||
builder := b
|
||||
for _, column := range columns {
|
||||
builder = builder.WhereOrf(`%s.%s IS NOT NULL`, b.model.QuoteWord(prefix), b.model.QuoteWord(column))
|
||||
}
|
||||
return builder
|
||||
}
|
||||
|
|
@ -0,0 +1,172 @@
|
|||
// Copyright GoFrame Author(https://goframe.org). All Rights Reserved.
|
||||
//
|
||||
// This Source Code Form is subject to the terms of the MIT License.
|
||||
// If a copy of the MIT was not distributed with this file,
|
||||
// You can obtain one at https://github.com/gogf/gf.
|
||||
|
||||
package database
|
||||
|
||||
import (
|
||||
"context"
|
||||
"time"
|
||||
|
||||
"git.magicany.cc/black1552/gin-base/database/intlog"
|
||||
)
|
||||
|
||||
// CacheOption is options for model cache control in query.
|
||||
type CacheOption struct {
|
||||
// Duration is the TTL for the cache.
|
||||
// If the parameter `Duration` < 0, which means it clear the cache with given `Name`.
|
||||
// If the parameter `Duration` = 0, which means it never expires.
|
||||
// If the parameter `Duration` > 0, which means it expires after `Duration`.
|
||||
Duration time.Duration
|
||||
|
||||
// Name is an optional unique name for the cache.
|
||||
// The Name is used to bind a name to the cache, which means you can later control the cache
|
||||
// like changing the `duration` or clearing the cache with specified Name.
|
||||
Name string
|
||||
|
||||
// Force caches the query result whatever the result is nil or not.
|
||||
// It is used to avoid Cache Penetration.
|
||||
Force bool
|
||||
}
|
||||
|
||||
// selectCacheItem is the cache item for SELECT statement result.
|
||||
type selectCacheItem struct {
|
||||
Result Result // Sql result of SELECT statement.
|
||||
FirstResultColumn string // The first column name of result, for Value/Count functions.
|
||||
}
|
||||
|
||||
// Cache sets the cache feature for the model. It caches the result of the sql, which means
|
||||
// if there's another same sql request, it just reads and returns the result from cache, it
|
||||
// but not committed and executed into the database.
|
||||
//
|
||||
// Note that, the cache feature is disabled if the model is performing select statement
|
||||
// on a transaction.
|
||||
func (m *Model) Cache(option CacheOption) *Model {
|
||||
model := m.getModel()
|
||||
model.cacheOption = option
|
||||
model.cacheEnabled = true
|
||||
return model
|
||||
}
|
||||
|
||||
// PageCache sets the cache feature for pagination queries. It allows to configure
|
||||
// separate cache options for count query and data query in pagination.
|
||||
//
|
||||
// Note that, the cache feature is disabled if the model is performing select statement
|
||||
// on a transaction.
|
||||
func (m *Model) PageCache(countOption CacheOption, dataOption CacheOption) *Model {
|
||||
model := m.getModel()
|
||||
model.pageCacheOption = []CacheOption{countOption, dataOption}
|
||||
model.cacheEnabled = true
|
||||
return model
|
||||
}
|
||||
|
||||
// checkAndRemoveSelectCache checks and removes the cache in insert/update/delete statement if
|
||||
// cache feature is enabled.
|
||||
func (m *Model) checkAndRemoveSelectCache(ctx context.Context) {
|
||||
if m.cacheEnabled && m.cacheOption.Duration < 0 && len(m.cacheOption.Name) > 0 {
|
||||
var cacheKey = m.makeSelectCacheKey("")
|
||||
if _, err := m.db.GetCache().Remove(ctx, cacheKey); err != nil {
|
||||
intlog.Errorf(ctx, `%+v`, err)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func (m *Model) getSelectResultFromCache(ctx context.Context, sql string, args ...any) (result Result, err error) {
|
||||
if !m.cacheEnabled || m.tx != nil {
|
||||
return
|
||||
}
|
||||
var (
|
||||
cacheItem *selectCacheItem
|
||||
cacheKey = m.makeSelectCacheKey(sql, args...)
|
||||
cacheObj = m.db.GetCache()
|
||||
core = m.db.GetCore()
|
||||
)
|
||||
defer func() {
|
||||
if cacheItem != nil {
|
||||
if internalData := core.getInternalColumnFromCtx(ctx); internalData != nil {
|
||||
if cacheItem.FirstResultColumn != "" {
|
||||
internalData.FirstResultColumn = cacheItem.FirstResultColumn
|
||||
}
|
||||
}
|
||||
}
|
||||
}()
|
||||
if v, _ := cacheObj.Get(ctx, cacheKey); !v.IsNil() {
|
||||
if err = v.Scan(&cacheItem); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return cacheItem.Result, nil
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
func (m *Model) saveSelectResultToCache(
|
||||
ctx context.Context, selectType SelectType, result Result, sql string, args ...any,
|
||||
) (err error) {
|
||||
if !m.cacheEnabled || m.tx != nil {
|
||||
return
|
||||
}
|
||||
var (
|
||||
cacheKey = m.makeSelectCacheKey(sql, args...)
|
||||
cacheObj = m.db.GetCache()
|
||||
)
|
||||
if m.cacheOption.Duration < 0 {
|
||||
if _, errCache := cacheObj.Remove(ctx, cacheKey); errCache != nil {
|
||||
intlog.Errorf(ctx, `%+v`, errCache)
|
||||
}
|
||||
return
|
||||
}
|
||||
// Special handler for Value/Count operations result.
|
||||
if len(result) > 0 {
|
||||
var core = m.db.GetCore()
|
||||
switch selectType {
|
||||
case SelectTypeValue, SelectTypeArray, SelectTypeCount:
|
||||
if internalData := core.getInternalColumnFromCtx(ctx); internalData != nil {
|
||||
if result[0][internalData.FirstResultColumn].IsEmpty() {
|
||||
result = nil
|
||||
}
|
||||
}
|
||||
default:
|
||||
}
|
||||
}
|
||||
|
||||
// In case of Cache Penetration.
|
||||
if result != nil && result.IsEmpty() {
|
||||
if m.cacheOption.Force {
|
||||
result = Result{}
|
||||
} else {
|
||||
result = nil
|
||||
}
|
||||
}
|
||||
var (
|
||||
core = m.db.GetCore()
|
||||
cacheItem = &selectCacheItem{
|
||||
Result: result,
|
||||
}
|
||||
)
|
||||
if internalData := core.getInternalColumnFromCtx(ctx); internalData != nil {
|
||||
cacheItem.FirstResultColumn = internalData.FirstResultColumn
|
||||
}
|
||||
if errCache := cacheObj.Set(ctx, cacheKey, cacheItem, m.cacheOption.Duration); errCache != nil {
|
||||
intlog.Errorf(ctx, `%+v`, errCache)
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
func (m *Model) makeSelectCacheKey(sql string, args ...any) string {
|
||||
var (
|
||||
table = m.db.GetCore().guessPrimaryTableName(m.tables)
|
||||
group = m.db.GetGroup()
|
||||
schema = m.db.GetSchema()
|
||||
customName = m.cacheOption.Name
|
||||
)
|
||||
return genSelectCacheKey(
|
||||
table,
|
||||
group,
|
||||
schema,
|
||||
customName,
|
||||
sql,
|
||||
args...,
|
||||
)
|
||||
}
|
||||
|
|
@ -0,0 +1,87 @@
|
|||
// Copyright GoFrame Author(https://goframe.org). All Rights Reserved.
|
||||
//
|
||||
// This Source Code Form is subject to the terms of the MIT License.
|
||||
// If a copy of the MIT was not distributed with this file,
|
||||
// You can obtain one at https://github.com/gogf/gf.
|
||||
|
||||
package database
|
||||
|
||||
import (
|
||||
"database/sql"
|
||||
|
||||
"github.com/gogf/gf/v2/errors/gcode"
|
||||
"github.com/gogf/gf/v2/errors/gerror"
|
||||
"github.com/gogf/gf/v2/internal/intlog"
|
||||
"github.com/gogf/gf/v2/text/gstr"
|
||||
)
|
||||
|
||||
// Delete does "DELETE FROM ... " statement for the model.
|
||||
// The optional parameter `where` is the same as the parameter of Model.Where function,
|
||||
// see Model.Where.
|
||||
func (m *Model) Delete(where ...any) (result sql.Result, err error) {
|
||||
var ctx = m.GetCtx()
|
||||
if len(where) > 0 {
|
||||
return m.Where(where[0], where[1:]...).Delete()
|
||||
}
|
||||
defer func() {
|
||||
if err == nil {
|
||||
m.checkAndRemoveSelectCache(ctx)
|
||||
}
|
||||
}()
|
||||
var (
|
||||
conditionWhere, conditionExtra, conditionArgs = m.formatCondition(ctx, false, false)
|
||||
conditionStr = conditionWhere + conditionExtra
|
||||
fieldNameDelete, fieldTypeDelete = m.softTimeMaintainer().GetFieldInfo(ctx, "", m.tablesInit, SoftTimeFieldDelete)
|
||||
)
|
||||
if m.unscoped {
|
||||
fieldNameDelete = ""
|
||||
}
|
||||
if !gstr.ContainsI(conditionStr, " WHERE ") || (fieldNameDelete != "" && !gstr.ContainsI(conditionStr, " AND ")) {
|
||||
intlog.Printf(
|
||||
ctx,
|
||||
`sql condition string "%s" has no WHERE for DELETE operation, fieldNameDelete: %s`,
|
||||
conditionStr, fieldNameDelete,
|
||||
)
|
||||
return nil, gerror.NewCode(
|
||||
gcode.CodeMissingParameter,
|
||||
"there should be WHERE condition statement for DELETE operation",
|
||||
)
|
||||
}
|
||||
|
||||
// Soft deleting.
|
||||
if fieldNameDelete != "" {
|
||||
dataHolder, dataValue := m.softTimeMaintainer().GetDeleteData(
|
||||
ctx, "", fieldNameDelete, fieldTypeDelete,
|
||||
)
|
||||
in := &HookUpdateInput{
|
||||
internalParamHookUpdate: internalParamHookUpdate{
|
||||
internalParamHook: internalParamHook{
|
||||
link: m.getLink(true),
|
||||
},
|
||||
handler: m.hookHandler.Update,
|
||||
},
|
||||
Model: m,
|
||||
Table: m.tables,
|
||||
Schema: m.schema,
|
||||
Data: dataHolder,
|
||||
Condition: conditionStr,
|
||||
Args: append([]any{dataValue}, conditionArgs...),
|
||||
}
|
||||
return in.Next(ctx)
|
||||
}
|
||||
|
||||
in := &HookDeleteInput{
|
||||
internalParamHookDelete: internalParamHookDelete{
|
||||
internalParamHook: internalParamHook{
|
||||
link: m.getLink(true),
|
||||
},
|
||||
handler: m.hookHandler.Delete,
|
||||
},
|
||||
Model: m,
|
||||
Table: m.tables,
|
||||
Schema: m.schema,
|
||||
Condition: conditionStr,
|
||||
Args: conditionArgs,
|
||||
}
|
||||
return in.Next(ctx)
|
||||
}
|
||||
|
|
@ -0,0 +1,283 @@
|
|||
// Copyright GoFrame Author(https://goframe.org). All Rights Reserved.
|
||||
//
|
||||
// This Source Code Form is subject to the terms of the MIT License.
|
||||
// If a copy of the MIT was not distributed with this file,
|
||||
// You can obtain one at https://github.com/gogf/gf.
|
||||
|
||||
package database
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
|
||||
"github.com/gogf/gf/v2/container/gset"
|
||||
"github.com/gogf/gf/v2/errors/gerror"
|
||||
"github.com/gogf/gf/v2/text/gstr"
|
||||
"github.com/gogf/gf/v2/util/gconv"
|
||||
)
|
||||
|
||||
// Fields appends `fieldNamesOrMapStruct` to the operation fields of the model, multiple fields joined using char ','.
|
||||
// The parameter `fieldNamesOrMapStruct` can be type of string/map/*map/struct/*struct.
|
||||
//
|
||||
// Example:
|
||||
// Fields("id", "name", "age")
|
||||
// Fields([]string{"id", "name", "age"})
|
||||
// Fields(map[string]any{"id":1, "name":"john", "age":18})
|
||||
// Fields(User{Id: 1, Name: "john", Age: 18}).
|
||||
func (m *Model) Fields(fieldNamesOrMapStruct ...any) *Model {
|
||||
length := len(fieldNamesOrMapStruct)
|
||||
if length == 0 {
|
||||
return m
|
||||
}
|
||||
fields := m.filterFieldsFrom(m.tablesInit, fieldNamesOrMapStruct...)
|
||||
if len(fields) == 0 {
|
||||
return m
|
||||
}
|
||||
model := m.getModel()
|
||||
return model.appendToFields(fields...)
|
||||
}
|
||||
|
||||
// FieldsPrefix performs as function Fields but add extra prefix for each field.
|
||||
func (m *Model) FieldsPrefix(prefixOrAlias string, fieldNamesOrMapStruct ...any) *Model {
|
||||
fields := m.filterFieldsFrom(
|
||||
m.getTableNameByPrefixOrAlias(prefixOrAlias),
|
||||
fieldNamesOrMapStruct...,
|
||||
)
|
||||
if len(fields) == 0 {
|
||||
return m
|
||||
}
|
||||
prefixOrAlias = m.QuoteWord(prefixOrAlias)
|
||||
for i, field := range fields {
|
||||
fields[i] = fmt.Sprintf("%s.%s", prefixOrAlias, m.QuoteWord(gconv.String(field)))
|
||||
}
|
||||
model := m.getModel()
|
||||
return model.appendToFields(fields...)
|
||||
}
|
||||
|
||||
// FieldsEx appends `fieldNamesOrMapStruct` to the excluded operation fields of the model,
|
||||
// multiple fields joined using char ','.
|
||||
// Note that this function supports only single table operations.
|
||||
// The parameter `fieldNamesOrMapStruct` can be type of string/map/*map/struct/*struct.
|
||||
//
|
||||
// Example:
|
||||
// FieldsEx("id", "name", "age")
|
||||
// FieldsEx([]string{"id", "name", "age"})
|
||||
// FieldsEx(map[string]any{"id":1, "name":"john", "age":18})
|
||||
// FieldsEx(User{Id: 1, Name: "john", Age: 18}).
|
||||
func (m *Model) FieldsEx(fieldNamesOrMapStruct ...any) *Model {
|
||||
return m.doFieldsEx(m.tablesInit, fieldNamesOrMapStruct...)
|
||||
}
|
||||
|
||||
func (m *Model) doFieldsEx(table string, fieldNamesOrMapStruct ...any) *Model {
|
||||
length := len(fieldNamesOrMapStruct)
|
||||
if length == 0 {
|
||||
return m
|
||||
}
|
||||
fields := m.filterFieldsFrom(table, fieldNamesOrMapStruct...)
|
||||
if len(fields) == 0 {
|
||||
return m
|
||||
}
|
||||
model := m.getModel()
|
||||
model.fieldsEx = append(model.fieldsEx, fields...)
|
||||
return model
|
||||
}
|
||||
|
||||
// FieldsExPrefix performs as function FieldsEx but add extra prefix for each field.
|
||||
// Note that this function must be used together with FieldsPrefix, otherwise it will be invalid.
|
||||
func (m *Model) FieldsExPrefix(prefixOrAlias string, fieldNamesOrMapStruct ...any) *Model {
|
||||
fields := m.filterFieldsFrom(
|
||||
m.getTableNameByPrefixOrAlias(prefixOrAlias),
|
||||
fieldNamesOrMapStruct...,
|
||||
)
|
||||
if len(fields) == 0 {
|
||||
return m
|
||||
}
|
||||
prefixOrAlias = m.QuoteWord(prefixOrAlias)
|
||||
for i, field := range fields {
|
||||
fields[i] = fmt.Sprintf("%s.%s", prefixOrAlias, m.QuoteWord(gconv.String(field)))
|
||||
}
|
||||
model := m.getModel()
|
||||
model.fieldsEx = append(model.fieldsEx, fields...)
|
||||
return model
|
||||
}
|
||||
|
||||
// FieldCount formats and appends commonly used field `COUNT(column)` to the select fields of model.
|
||||
func (m *Model) FieldCount(column string, as ...string) *Model {
|
||||
asStr := ""
|
||||
if len(as) > 0 && as[0] != "" {
|
||||
asStr = fmt.Sprintf(` AS %s`, m.QuoteWord(as[0]))
|
||||
}
|
||||
model := m.getModel()
|
||||
return model.appendToFields(
|
||||
fmt.Sprintf(`COUNT(%s)%s`, m.QuoteWord(column), asStr),
|
||||
)
|
||||
}
|
||||
|
||||
// FieldSum formats and appends commonly used field `SUM(column)` to the select fields of model.
|
||||
func (m *Model) FieldSum(column string, as ...string) *Model {
|
||||
asStr := ""
|
||||
if len(as) > 0 && as[0] != "" {
|
||||
asStr = fmt.Sprintf(` AS %s`, m.QuoteWord(as[0]))
|
||||
}
|
||||
model := m.getModel()
|
||||
return model.appendToFields(
|
||||
fmt.Sprintf(`SUM(%s)%s`, m.QuoteWord(column), asStr),
|
||||
)
|
||||
}
|
||||
|
||||
// FieldMin formats and appends commonly used field `MIN(column)` to the select fields of model.
|
||||
func (m *Model) FieldMin(column string, as ...string) *Model {
|
||||
asStr := ""
|
||||
if len(as) > 0 && as[0] != "" {
|
||||
asStr = fmt.Sprintf(` AS %s`, m.QuoteWord(as[0]))
|
||||
}
|
||||
model := m.getModel()
|
||||
return model.appendToFields(
|
||||
fmt.Sprintf(`MIN(%s)%s`, m.QuoteWord(column), asStr),
|
||||
)
|
||||
}
|
||||
|
||||
// FieldMax formats and appends commonly used field `MAX(column)` to the select fields of model.
|
||||
func (m *Model) FieldMax(column string, as ...string) *Model {
|
||||
asStr := ""
|
||||
if len(as) > 0 && as[0] != "" {
|
||||
asStr = fmt.Sprintf(` AS %s`, m.QuoteWord(as[0]))
|
||||
}
|
||||
model := m.getModel()
|
||||
return model.appendToFields(
|
||||
fmt.Sprintf(`MAX(%s)%s`, m.QuoteWord(column), asStr),
|
||||
)
|
||||
}
|
||||
|
||||
// FieldAvg formats and appends commonly used field `AVG(column)` to the select fields of model.
|
||||
func (m *Model) FieldAvg(column string, as ...string) *Model {
|
||||
asStr := ""
|
||||
if len(as) > 0 && as[0] != "" {
|
||||
asStr = fmt.Sprintf(` AS %s`, m.QuoteWord(as[0]))
|
||||
}
|
||||
model := m.getModel()
|
||||
return model.appendToFields(
|
||||
fmt.Sprintf(`AVG(%s)%s`, m.QuoteWord(column), asStr),
|
||||
)
|
||||
}
|
||||
|
||||
// GetFieldsStr retrieves and returns all fields from the table, joined with char ','.
|
||||
// The optional parameter `prefix` specifies the prefix for each field, eg: GetFieldsStr("u.").
|
||||
func (m *Model) GetFieldsStr(prefix ...string) string {
|
||||
prefixStr := ""
|
||||
if len(prefix) > 0 {
|
||||
prefixStr = prefix[0]
|
||||
}
|
||||
tableFields, err := m.TableFields(m.tablesInit)
|
||||
if err != nil {
|
||||
panic(err)
|
||||
}
|
||||
if len(tableFields) == 0 {
|
||||
panic(fmt.Sprintf(`empty table fields for table "%s"`, m.tables))
|
||||
}
|
||||
fieldsArray := make([]string, len(tableFields))
|
||||
for k, v := range tableFields {
|
||||
fieldsArray[v.Index] = k
|
||||
}
|
||||
newFields := ""
|
||||
for _, k := range fieldsArray {
|
||||
if len(newFields) > 0 {
|
||||
newFields += ","
|
||||
}
|
||||
newFields += prefixStr + k
|
||||
}
|
||||
newFields = m.db.GetCore().QuoteString(newFields)
|
||||
return newFields
|
||||
}
|
||||
|
||||
// GetFieldsExStr retrieves and returns fields which are not in parameter `fields` from the table,
|
||||
// joined with char ','.
|
||||
// The parameter `fields` specifies the fields that are excluded.
|
||||
// The optional parameter `prefix` specifies the prefix for each field, eg: FieldsExStr("id", "u.").
|
||||
func (m *Model) GetFieldsExStr(fields string, prefix ...string) (string, error) {
|
||||
prefixStr := ""
|
||||
if len(prefix) > 0 {
|
||||
prefixStr = prefix[0]
|
||||
}
|
||||
tableFields, err := m.TableFields(m.tablesInit)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
if len(tableFields) == 0 {
|
||||
return "", gerror.Newf(`empty table fields for table "%s"`, m.tables)
|
||||
}
|
||||
fieldsExSet := gset.NewStrSetFrom(gstr.SplitAndTrim(fields, ","))
|
||||
fieldsArray := make([]string, len(tableFields))
|
||||
for k, v := range tableFields {
|
||||
fieldsArray[v.Index] = k
|
||||
}
|
||||
newFields := ""
|
||||
for _, k := range fieldsArray {
|
||||
if fieldsExSet.Contains(k) {
|
||||
continue
|
||||
}
|
||||
if len(newFields) > 0 {
|
||||
newFields += ","
|
||||
}
|
||||
newFields += prefixStr + k
|
||||
}
|
||||
newFields = m.db.GetCore().QuoteString(newFields)
|
||||
return newFields, nil
|
||||
}
|
||||
|
||||
// HasField determine whether the field exists in the table.
|
||||
func (m *Model) HasField(field string) (bool, error) {
|
||||
return m.db.GetCore().HasField(m.GetCtx(), m.tablesInit, field)
|
||||
}
|
||||
|
||||
// getFieldsFrom retrieves, filters and returns fields name from table `table`.
|
||||
func (m *Model) filterFieldsFrom(table string, fieldNamesOrMapStruct ...any) []any {
|
||||
length := len(fieldNamesOrMapStruct)
|
||||
if length == 0 {
|
||||
return nil
|
||||
}
|
||||
switch {
|
||||
// String slice.
|
||||
case length >= 2:
|
||||
return m.mappingAndFilterToTableFields(
|
||||
table, fieldNamesOrMapStruct, true,
|
||||
)
|
||||
|
||||
// It needs type asserting.
|
||||
case length == 1:
|
||||
structOrMap := fieldNamesOrMapStruct[0]
|
||||
switch r := structOrMap.(type) {
|
||||
case string:
|
||||
return m.mappingAndFilterToTableFields(table, []any{r}, false)
|
||||
|
||||
case []string:
|
||||
return m.mappingAndFilterToTableFields(table, gconv.Interfaces(r), true)
|
||||
|
||||
case Raw, *Raw:
|
||||
return []any{structOrMap}
|
||||
|
||||
default:
|
||||
return m.mappingAndFilterToTableFields(table, getFieldsFromStructOrMap(structOrMap), true)
|
||||
}
|
||||
|
||||
default:
|
||||
return nil
|
||||
}
|
||||
}
|
||||
|
||||
func (m *Model) appendToFields(fields ...any) *Model {
|
||||
if len(fields) == 0 {
|
||||
return m
|
||||
}
|
||||
model := m.getModel()
|
||||
model.fields = append(model.fields, fields...)
|
||||
return model
|
||||
}
|
||||
|
||||
func (m *Model) isFieldInFieldsEx(field string) bool {
|
||||
for _, v := range m.fieldsEx {
|
||||
if v == field {
|
||||
return true
|
||||
}
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
|
@ -0,0 +1,309 @@
|
|||
// Copyright GoFrame Author(https://goframe.org). All Rights Reserved.
|
||||
//
|
||||
// This Source Code Form is subject to the terms of the MIT License.
|
||||
// If a copy of the MIT was not distributed with this file,
|
||||
// You can obtain one at https://github.com/gogf/gf.
|
||||
|
||||
package database
|
||||
|
||||
import (
|
||||
"context"
|
||||
"database/sql"
|
||||
"fmt"
|
||||
|
||||
"github.com/gogf/gf/v2/container/gvar"
|
||||
"github.com/gogf/gf/v2/text/gregex"
|
||||
"github.com/gogf/gf/v2/text/gstr"
|
||||
)
|
||||
|
||||
type (
|
||||
HookFuncSelect func(ctx context.Context, in *HookSelectInput) (result Result, err error)
|
||||
HookFuncInsert func(ctx context.Context, in *HookInsertInput) (result sql.Result, err error)
|
||||
HookFuncUpdate func(ctx context.Context, in *HookUpdateInput) (result sql.Result, err error)
|
||||
HookFuncDelete func(ctx context.Context, in *HookDeleteInput) (result sql.Result, err error)
|
||||
)
|
||||
|
||||
// HookHandler manages all supported hook functions for Model.
|
||||
type HookHandler struct {
|
||||
Select HookFuncSelect
|
||||
Insert HookFuncInsert
|
||||
Update HookFuncUpdate
|
||||
Delete HookFuncDelete
|
||||
}
|
||||
|
||||
// internalParamHook manages all internal parameters for hook operations.
|
||||
// The `internal` obviously means you cannot access these parameters outside this package.
|
||||
type internalParamHook struct {
|
||||
link Link // Connection object from third party sql driver.
|
||||
handlerCalled bool // Simple mark for custom handler called, in case of recursive calling.
|
||||
removedWhere bool // Removed mark for condition string that was removed `WHERE` prefix.
|
||||
originalTableName *gvar.Var // The original table name.
|
||||
originalSchemaName *gvar.Var // The original schema name.
|
||||
}
|
||||
|
||||
type internalParamHookSelect struct {
|
||||
internalParamHook
|
||||
handler HookFuncSelect
|
||||
}
|
||||
|
||||
type internalParamHookInsert struct {
|
||||
internalParamHook
|
||||
handler HookFuncInsert
|
||||
}
|
||||
|
||||
type internalParamHookUpdate struct {
|
||||
internalParamHook
|
||||
handler HookFuncUpdate
|
||||
}
|
||||
|
||||
type internalParamHookDelete struct {
|
||||
internalParamHook
|
||||
handler HookFuncDelete
|
||||
}
|
||||
|
||||
// HookSelectInput holds the parameters for select hook operation.
|
||||
// Note that, COUNT statement will also be hooked by this feature,
|
||||
// which is usually not be interesting for upper business hook handler.
|
||||
type HookSelectInput struct {
|
||||
internalParamHookSelect
|
||||
Model *Model // Current operation Model.
|
||||
Table string // The table name that to be used. Update this attribute to change target table name.
|
||||
Schema string // The schema name that to be used. Update this attribute to change target schema name.
|
||||
Sql string // The sql string that to be committed.
|
||||
Args []any // The arguments of sql.
|
||||
SelectType SelectType // The type of this SELECT operation.
|
||||
}
|
||||
|
||||
// HookInsertInput holds the parameters for insert hook operation.
|
||||
type HookInsertInput struct {
|
||||
internalParamHookInsert
|
||||
Model *Model // Current operation Model.
|
||||
Table string // The table name that to be used. Update this attribute to change target table name.
|
||||
Schema string // The schema name that to be used. Update this attribute to change target schema name.
|
||||
Data List // The data records list to be inserted/saved into table.
|
||||
Option DoInsertOption // The extra option for data inserting.
|
||||
}
|
||||
|
||||
// HookUpdateInput holds the parameters for update hook operation.
|
||||
type HookUpdateInput struct {
|
||||
internalParamHookUpdate
|
||||
Model *Model // Current operation Model.
|
||||
Table string // The table name that to be used. Update this attribute to change target table name.
|
||||
Schema string // The schema name that to be used. Update this attribute to change target schema name.
|
||||
Data any // Data can be type of: map[string]any/string. You can use type assertion on `Data`.
|
||||
Condition string // The where condition string for updating.
|
||||
Args []any // The arguments for sql place-holders.
|
||||
}
|
||||
|
||||
// HookDeleteInput holds the parameters for delete hook operation.
|
||||
type HookDeleteInput struct {
|
||||
internalParamHookDelete
|
||||
Model *Model // Current operation Model.
|
||||
Table string // The table name that to be used. Update this attribute to change target table name.
|
||||
Schema string // The schema name that to be used. Update this attribute to change target schema name.
|
||||
Condition string // The where condition string for deleting.
|
||||
Args []any // The arguments for sql place-holders.
|
||||
}
|
||||
|
||||
const (
|
||||
whereKeyInCondition = " WHERE "
|
||||
)
|
||||
|
||||
// IsTransaction checks and returns whether current operation is during transaction.
|
||||
func (h *internalParamHook) IsTransaction() bool {
|
||||
return h.link.IsTransaction()
|
||||
}
|
||||
|
||||
// Next calls the next hook handler.
|
||||
func (h *HookSelectInput) Next(ctx context.Context) (result Result, err error) {
|
||||
if h.originalTableName.IsNil() {
|
||||
h.originalTableName = gvar.New(h.Table)
|
||||
}
|
||||
if h.originalSchemaName.IsNil() {
|
||||
h.originalSchemaName = gvar.New(h.Schema)
|
||||
}
|
||||
|
||||
// Sharding feature.
|
||||
h.Schema, err = h.Model.getActualSchema(ctx, h.Schema)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
h.Table, err = h.Model.getActualTable(ctx, h.Table)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// Custom hook handler call.
|
||||
if h.handler != nil && !h.handlerCalled {
|
||||
h.handlerCalled = true
|
||||
return h.handler(ctx, h)
|
||||
}
|
||||
var toBeCommittedSql = h.Sql
|
||||
// Table change.
|
||||
if h.Table != h.originalTableName.String() {
|
||||
toBeCommittedSql, err = gregex.ReplaceStringFuncMatch(
|
||||
`(?i) FROM ([\S]+)`,
|
||||
toBeCommittedSql,
|
||||
func(match []string) string {
|
||||
charL, charR := h.Model.db.GetChars()
|
||||
return fmt.Sprintf(` FROM %s%s%s`, charL, h.Table, charR)
|
||||
},
|
||||
)
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
}
|
||||
// Schema change.
|
||||
if h.Schema != "" && h.Schema != h.originalSchemaName.String() {
|
||||
h.link, err = h.Model.db.GetCore().SlaveLink(h.Schema)
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
h.Model.db.GetCore().schema = h.Schema
|
||||
defer func() {
|
||||
h.Model.db.GetCore().schema = h.originalSchemaName.String()
|
||||
}()
|
||||
}
|
||||
return h.Model.db.DoSelect(ctx, h.link, toBeCommittedSql, h.Args...)
|
||||
}
|
||||
|
||||
// Next calls the next hook handler.
|
||||
func (h *HookInsertInput) Next(ctx context.Context) (result sql.Result, err error) {
|
||||
if h.originalTableName.IsNil() {
|
||||
h.originalTableName = gvar.New(h.Table)
|
||||
}
|
||||
if h.originalSchemaName.IsNil() {
|
||||
h.originalSchemaName = gvar.New(h.Schema)
|
||||
}
|
||||
|
||||
// Sharding feature.
|
||||
h.Schema, err = h.Model.getActualSchema(ctx, h.Schema)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
h.Table, err = h.Model.getActualTable(ctx, h.Table)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
if h.handler != nil && !h.handlerCalled {
|
||||
h.handlerCalled = true
|
||||
return h.handler(ctx, h)
|
||||
}
|
||||
|
||||
// No need to handle table change.
|
||||
|
||||
// Schema change.
|
||||
if h.Schema != "" && h.Schema != h.originalSchemaName.String() {
|
||||
h.link, err = h.Model.db.GetCore().MasterLink(h.Schema)
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
h.Model.db.GetCore().schema = h.Schema
|
||||
defer func() {
|
||||
h.Model.db.GetCore().schema = h.originalSchemaName.String()
|
||||
}()
|
||||
}
|
||||
return h.Model.db.DoInsert(ctx, h.link, h.Table, h.Data, h.Option)
|
||||
}
|
||||
|
||||
// Next calls the next hook handler.
|
||||
func (h *HookUpdateInput) Next(ctx context.Context) (result sql.Result, err error) {
|
||||
if h.originalTableName.IsNil() {
|
||||
h.originalTableName = gvar.New(h.Table)
|
||||
}
|
||||
if h.originalSchemaName.IsNil() {
|
||||
h.originalSchemaName = gvar.New(h.Schema)
|
||||
}
|
||||
|
||||
// Sharding feature.
|
||||
h.Schema, err = h.Model.getActualSchema(ctx, h.Schema)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
h.Table, err = h.Model.getActualTable(ctx, h.Table)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
if h.handler != nil && !h.handlerCalled {
|
||||
h.handlerCalled = true
|
||||
if gstr.HasPrefix(h.Condition, whereKeyInCondition) {
|
||||
h.removedWhere = true
|
||||
h.Condition = gstr.TrimLeftStr(h.Condition, whereKeyInCondition)
|
||||
}
|
||||
return h.handler(ctx, h)
|
||||
}
|
||||
if h.removedWhere {
|
||||
h.Condition = whereKeyInCondition + h.Condition
|
||||
}
|
||||
|
||||
// No need to handle table change.
|
||||
|
||||
// Schema change.
|
||||
if h.Schema != "" && h.Schema != h.originalSchemaName.String() {
|
||||
h.link, err = h.Model.db.GetCore().MasterLink(h.Schema)
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
h.Model.db.GetCore().schema = h.Schema
|
||||
defer func() {
|
||||
h.Model.db.GetCore().schema = h.originalSchemaName.String()
|
||||
}()
|
||||
}
|
||||
return h.Model.db.DoUpdate(ctx, h.link, h.Table, h.Data, h.Condition, h.Args...)
|
||||
}
|
||||
|
||||
// Next calls the next hook handler.
|
||||
func (h *HookDeleteInput) Next(ctx context.Context) (result sql.Result, err error) {
|
||||
if h.originalTableName.IsNil() {
|
||||
h.originalTableName = gvar.New(h.Table)
|
||||
}
|
||||
if h.originalSchemaName.IsNil() {
|
||||
h.originalSchemaName = gvar.New(h.Schema)
|
||||
}
|
||||
|
||||
// Sharding feature.
|
||||
h.Schema, err = h.Model.getActualSchema(ctx, h.Schema)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
h.Table, err = h.Model.getActualTable(ctx, h.Table)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
if h.handler != nil && !h.handlerCalled {
|
||||
h.handlerCalled = true
|
||||
if gstr.HasPrefix(h.Condition, whereKeyInCondition) {
|
||||
h.removedWhere = true
|
||||
h.Condition = gstr.TrimLeftStr(h.Condition, whereKeyInCondition)
|
||||
}
|
||||
return h.handler(ctx, h)
|
||||
}
|
||||
if h.removedWhere {
|
||||
h.Condition = whereKeyInCondition + h.Condition
|
||||
}
|
||||
|
||||
// No need to handle table change.
|
||||
|
||||
// Schema change.
|
||||
if h.Schema != "" && h.Schema != h.originalSchemaName.String() {
|
||||
h.link, err = h.Model.db.GetCore().MasterLink(h.Schema)
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
h.Model.db.GetCore().schema = h.Schema
|
||||
defer func() {
|
||||
h.Model.db.GetCore().schema = h.originalSchemaName.String()
|
||||
}()
|
||||
}
|
||||
return h.Model.db.DoDelete(ctx, h.link, h.Table, h.Condition, h.Args...)
|
||||
}
|
||||
|
||||
// Hook sets the hook functions for current model.
|
||||
func (m *Model) Hook(hook HookHandler) *Model {
|
||||
model := m.getModel()
|
||||
model.hookHandler = hook
|
||||
return model
|
||||
}
|
||||
|
|
@ -0,0 +1,469 @@
|
|||
// Copyright GoFrame Author(https://goframe.org). All Rights Reserved.
|
||||
//
|
||||
// This Source Code Form is subject to the terms of the MIT License.
|
||||
// If a copy of the MIT was not distributed with this file,
|
||||
// You can obtain one at https://github.com/gogf/gf.
|
||||
|
||||
package database
|
||||
|
||||
import (
|
||||
"context"
|
||||
"database/sql"
|
||||
"reflect"
|
||||
|
||||
"git.magicany.cc/black1552/gin-base/database/empty"
|
||||
"git.magicany.cc/black1552/gin-base/database/reflection"
|
||||
"github.com/gogf/gf/v2/container/gset"
|
||||
"github.com/gogf/gf/v2/errors/gcode"
|
||||
"github.com/gogf/gf/v2/errors/gerror"
|
||||
"github.com/gogf/gf/v2/text/gstr"
|
||||
"github.com/gogf/gf/v2/util/gconv"
|
||||
"github.com/gogf/gf/v2/util/gutil"
|
||||
)
|
||||
|
||||
// Batch sets the batch operation number for the model.
|
||||
func (m *Model) Batch(batch int) *Model {
|
||||
model := m.getModel()
|
||||
model.batch = batch
|
||||
return model
|
||||
}
|
||||
|
||||
// Data sets the operation data for the model.
|
||||
// The parameter `data` can be type of string/map/gmap/slice/struct/*struct, etc.
|
||||
// Note that, it uses shallow value copying for `data` if `data` is type of map/slice
|
||||
// to avoid changing it inside function.
|
||||
// Eg:
|
||||
// Data("uid=10000")
|
||||
// Data("uid", 10000)
|
||||
// Data("uid=? AND name=?", 10000, "john")
|
||||
// Data(g.Map{"uid": 10000, "name":"john"})
|
||||
// Data(g.Slice{g.Map{"uid": 10000, "name":"john"}, g.Map{"uid": 20000, "name":"smith"}).
|
||||
func (m *Model) Data(data ...any) *Model {
|
||||
var model = m.getModel()
|
||||
if len(data) > 1 {
|
||||
if s := gconv.String(data[0]); gstr.Contains(s, "?") {
|
||||
model.data = s
|
||||
model.extraArgs = data[1:]
|
||||
} else {
|
||||
newData := make(map[string]any)
|
||||
for i := 0; i < len(data); i += 2 {
|
||||
newData[gconv.String(data[i])] = data[i+1]
|
||||
}
|
||||
model.data = newData
|
||||
}
|
||||
} else if len(data) == 1 {
|
||||
switch value := data[0].(type) {
|
||||
case Result:
|
||||
model.data = value.List()
|
||||
|
||||
case Record:
|
||||
model.data = value.Map()
|
||||
|
||||
case List:
|
||||
list := make(List, len(value))
|
||||
for k, v := range value {
|
||||
list[k] = gutil.MapCopy(v)
|
||||
}
|
||||
model.data = list
|
||||
|
||||
case Map:
|
||||
model.data = gutil.MapCopy(value)
|
||||
|
||||
default:
|
||||
reflectInfo := reflection.OriginValueAndKind(value)
|
||||
switch reflectInfo.OriginKind {
|
||||
case reflect.Slice, reflect.Array:
|
||||
if reflectInfo.OriginValue.Len() > 0 {
|
||||
// If the `data` parameter is a DO struct,
|
||||
// it then adds `OmitNilData` option for this condition,
|
||||
// which will filter all nil parameters in `data`.
|
||||
if isDoStruct(reflectInfo.OriginValue.Index(0).Interface()) {
|
||||
model = model.OmitNilData()
|
||||
model.option |= optionOmitNilDataInternal
|
||||
}
|
||||
}
|
||||
list := make(List, reflectInfo.OriginValue.Len())
|
||||
for i := 0; i < reflectInfo.OriginValue.Len(); i++ {
|
||||
list[i] = anyValueToMapBeforeToRecord(reflectInfo.OriginValue.Index(i).Interface())
|
||||
}
|
||||
model.data = list
|
||||
|
||||
case reflect.Struct:
|
||||
// If the `data` parameter is a DO struct,
|
||||
// it then adds `OmitNilData` option for this condition,
|
||||
// which will filter all nil parameters in `data`.
|
||||
if isDoStruct(value) {
|
||||
model = model.OmitNilData()
|
||||
}
|
||||
if v, ok := data[0].(iInterfaces); ok {
|
||||
var (
|
||||
array = v.Interfaces()
|
||||
list = make(List, len(array))
|
||||
)
|
||||
for i := 0; i < len(array); i++ {
|
||||
list[i] = anyValueToMapBeforeToRecord(array[i])
|
||||
}
|
||||
model.data = list
|
||||
} else {
|
||||
model.data = anyValueToMapBeforeToRecord(data[0])
|
||||
}
|
||||
|
||||
case reflect.Map:
|
||||
model.data = anyValueToMapBeforeToRecord(data[0])
|
||||
|
||||
default:
|
||||
model.data = data[0]
|
||||
}
|
||||
}
|
||||
}
|
||||
return model
|
||||
}
|
||||
|
||||
// OnConflict sets the primary key or index when columns conflicts occurs.
|
||||
// It's not necessary for MySQL driver.
|
||||
func (m *Model) OnConflict(onConflict ...any) *Model {
|
||||
if len(onConflict) == 0 {
|
||||
return m
|
||||
}
|
||||
model := m.getModel()
|
||||
if len(onConflict) > 1 {
|
||||
model.onConflict = onConflict
|
||||
} else if len(onConflict) == 1 {
|
||||
model.onConflict = onConflict[0]
|
||||
}
|
||||
return model
|
||||
}
|
||||
|
||||
// OnDuplicate sets the operations when columns conflicts occurs.
|
||||
// In MySQL, this is used for "ON DUPLICATE KEY UPDATE" statement.
|
||||
// In PgSQL, this is used for "ON CONFLICT (id) DO UPDATE SET" statement.
|
||||
// The parameter `onDuplicate` can be type of string/Raw/*Raw/map/slice.
|
||||
// Example:
|
||||
//
|
||||
// OnDuplicate("nickname, age")
|
||||
// OnDuplicate("nickname", "age")
|
||||
//
|
||||
// OnDuplicate(g.Map{
|
||||
// "nickname": gdb.Raw("CONCAT('name_', VALUES(`nickname`))"),
|
||||
// })
|
||||
//
|
||||
// OnDuplicate(g.Map{
|
||||
// "nickname": "passport",
|
||||
// }).
|
||||
func (m *Model) OnDuplicate(onDuplicate ...any) *Model {
|
||||
if len(onDuplicate) == 0 {
|
||||
return m
|
||||
}
|
||||
model := m.getModel()
|
||||
if len(onDuplicate) > 1 {
|
||||
model.onDuplicate = onDuplicate
|
||||
} else if len(onDuplicate) == 1 {
|
||||
model.onDuplicate = onDuplicate[0]
|
||||
}
|
||||
return model
|
||||
}
|
||||
|
||||
// OnDuplicateEx sets the excluding columns for operations when columns conflict occurs.
|
||||
// In MySQL, this is used for "ON DUPLICATE KEY UPDATE" statement.
|
||||
// In PgSQL, this is used for "ON CONFLICT (id) DO UPDATE SET" statement.
|
||||
// The parameter `onDuplicateEx` can be type of string/map/slice.
|
||||
// Example:
|
||||
//
|
||||
// OnDuplicateEx("passport, password")
|
||||
// OnDuplicateEx("passport", "password")
|
||||
//
|
||||
// OnDuplicateEx(g.Map{
|
||||
// "passport": "",
|
||||
// "password": "",
|
||||
// }).
|
||||
func (m *Model) OnDuplicateEx(onDuplicateEx ...any) *Model {
|
||||
if len(onDuplicateEx) == 0 {
|
||||
return m
|
||||
}
|
||||
model := m.getModel()
|
||||
if len(onDuplicateEx) > 1 {
|
||||
model.onDuplicateEx = onDuplicateEx
|
||||
} else if len(onDuplicateEx) == 1 {
|
||||
model.onDuplicateEx = onDuplicateEx[0]
|
||||
}
|
||||
return model
|
||||
}
|
||||
|
||||
// Insert does "INSERT INTO ..." statement for the model.
|
||||
// The optional parameter `data` is the same as the parameter of Model.Data function,
|
||||
// see Model.Data.
|
||||
func (m *Model) Insert(data ...any) (result sql.Result, err error) {
|
||||
var ctx = m.GetCtx()
|
||||
if len(data) > 0 {
|
||||
return m.Data(data...).Insert()
|
||||
}
|
||||
return m.doInsertWithOption(ctx, InsertOptionDefault)
|
||||
}
|
||||
|
||||
// InsertAndGetId performs action Insert and returns the last insert id that automatically generated.
|
||||
func (m *Model) InsertAndGetId(data ...any) (lastInsertId int64, err error) {
|
||||
var ctx = m.GetCtx()
|
||||
if len(data) > 0 {
|
||||
return m.Data(data...).InsertAndGetId()
|
||||
}
|
||||
result, err := m.doInsertWithOption(ctx, InsertOptionDefault)
|
||||
if err != nil {
|
||||
return 0, err
|
||||
}
|
||||
return result.LastInsertId()
|
||||
}
|
||||
|
||||
// InsertIgnore does "INSERT IGNORE INTO ..." statement for the model.
|
||||
// The optional parameter `data` is the same as the parameter of Model.Data function,
|
||||
// see Model.Data.
|
||||
func (m *Model) InsertIgnore(data ...any) (result sql.Result, err error) {
|
||||
var ctx = m.GetCtx()
|
||||
if len(data) > 0 {
|
||||
return m.Data(data...).InsertIgnore()
|
||||
}
|
||||
return m.doInsertWithOption(ctx, InsertOptionIgnore)
|
||||
}
|
||||
|
||||
// Replace does "REPLACE INTO ..." statement for the model.
|
||||
// The optional parameter `data` is the same as the parameter of Model.Data function,
|
||||
// see Model.Data.
|
||||
func (m *Model) Replace(data ...any) (result sql.Result, err error) {
|
||||
var ctx = m.GetCtx()
|
||||
if len(data) > 0 {
|
||||
return m.Data(data...).Replace()
|
||||
}
|
||||
return m.doInsertWithOption(ctx, InsertOptionReplace)
|
||||
}
|
||||
|
||||
// Save does "INSERT INTO ... ON DUPLICATE KEY UPDATE..." statement for the model.
|
||||
// The optional parameter `data` is the same as the parameter of Model.Data function,
|
||||
// see Model.Data.
|
||||
//
|
||||
// It updates the record if there's primary or unique index in the saving data,
|
||||
// or else it inserts a new record into the table.
|
||||
func (m *Model) Save(data ...any) (result sql.Result, err error) {
|
||||
var ctx = m.GetCtx()
|
||||
if len(data) > 0 {
|
||||
return m.Data(data...).Save()
|
||||
}
|
||||
return m.doInsertWithOption(ctx, InsertOptionSave)
|
||||
}
|
||||
|
||||
// doInsertWithOption inserts data with option parameter.
|
||||
func (m *Model) doInsertWithOption(ctx context.Context, insertOption InsertOption) (result sql.Result, err error) {
|
||||
defer func() {
|
||||
if err == nil {
|
||||
m.checkAndRemoveSelectCache(ctx)
|
||||
}
|
||||
}()
|
||||
if m.data == nil {
|
||||
return nil, gerror.NewCode(gcode.CodeMissingParameter, "inserting into table with empty data")
|
||||
}
|
||||
var (
|
||||
list List
|
||||
stm = m.softTimeMaintainer()
|
||||
fieldNameCreate, fieldTypeCreate = stm.GetFieldInfo(ctx, "", m.tablesInit, SoftTimeFieldCreate)
|
||||
fieldNameUpdate, fieldTypeUpdate = stm.GetFieldInfo(ctx, "", m.tablesInit, SoftTimeFieldUpdate)
|
||||
fieldNameDelete, fieldTypeDelete = stm.GetFieldInfo(ctx, "", m.tablesInit, SoftTimeFieldDelete)
|
||||
)
|
||||
// m.data was already converted to type List/Map by function Data
|
||||
newData, err := m.filterDataForInsertOrUpdate(m.data)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
// It converts any data to List type for inserting.
|
||||
switch value := newData.(type) {
|
||||
case List:
|
||||
list = value
|
||||
|
||||
case Map:
|
||||
list = List{value}
|
||||
}
|
||||
|
||||
if len(list) < 1 {
|
||||
return result, gerror.NewCode(gcode.CodeMissingParameter, "data list cannot be empty")
|
||||
}
|
||||
|
||||
// Automatic handling for creating/updating time.
|
||||
if fieldNameCreate != "" && m.isFieldInFieldsEx(fieldNameCreate) {
|
||||
fieldNameCreate = ""
|
||||
}
|
||||
if fieldNameUpdate != "" && m.isFieldInFieldsEx(fieldNameUpdate) {
|
||||
fieldNameUpdate = ""
|
||||
}
|
||||
var isSoftTimeFeatureEnabled = fieldNameCreate != "" || fieldNameUpdate != ""
|
||||
if !m.unscoped && isSoftTimeFeatureEnabled {
|
||||
for k, v := range list {
|
||||
if fieldNameCreate != "" && empty.IsNil(v[fieldNameCreate]) {
|
||||
fieldCreateValue := stm.GetFieldValue(ctx, fieldTypeCreate, false)
|
||||
if fieldCreateValue != nil {
|
||||
v[fieldNameCreate] = fieldCreateValue
|
||||
}
|
||||
}
|
||||
if fieldNameUpdate != "" && empty.IsNil(v[fieldNameUpdate]) {
|
||||
fieldUpdateValue := stm.GetFieldValue(ctx, fieldTypeUpdate, false)
|
||||
if fieldUpdateValue != nil {
|
||||
v[fieldNameUpdate] = fieldUpdateValue
|
||||
}
|
||||
}
|
||||
// for timestamp field that should initialize the delete_at field with value, for example 0.
|
||||
if fieldNameDelete != "" && empty.IsNil(v[fieldNameDelete]) {
|
||||
fieldDeleteValue := stm.GetFieldValue(ctx, fieldTypeDelete, true)
|
||||
if fieldDeleteValue != nil {
|
||||
v[fieldNameDelete] = fieldDeleteValue
|
||||
}
|
||||
}
|
||||
list[k] = v
|
||||
}
|
||||
}
|
||||
// Format DoInsertOption, especially for "ON DUPLICATE KEY UPDATE" statement.
|
||||
columnNames := make([]string, 0, len(list[0]))
|
||||
for k := range list[0] {
|
||||
columnNames = append(columnNames, k)
|
||||
}
|
||||
doInsertOption, err := m.formatDoInsertOption(insertOption, columnNames)
|
||||
if err != nil {
|
||||
return result, err
|
||||
}
|
||||
|
||||
in := &HookInsertInput{
|
||||
internalParamHookInsert: internalParamHookInsert{
|
||||
internalParamHook: internalParamHook{
|
||||
link: m.getLink(true),
|
||||
},
|
||||
handler: m.hookHandler.Insert,
|
||||
},
|
||||
Model: m,
|
||||
Table: m.tables,
|
||||
Schema: m.schema,
|
||||
Data: list,
|
||||
Option: doInsertOption,
|
||||
}
|
||||
return in.Next(ctx)
|
||||
}
|
||||
|
||||
func (m *Model) formatDoInsertOption(insertOption InsertOption, columnNames []string) (option DoInsertOption, err error) {
|
||||
option = DoInsertOption{
|
||||
InsertOption: insertOption,
|
||||
BatchCount: m.getBatch(),
|
||||
}
|
||||
if insertOption != InsertOptionSave {
|
||||
return
|
||||
}
|
||||
|
||||
onConflictKeys, err := m.formatOnConflictKeys(m.onConflict)
|
||||
if err != nil {
|
||||
return option, err
|
||||
}
|
||||
option.OnConflict = onConflictKeys
|
||||
|
||||
onDuplicateExKeys, err := m.formatOnDuplicateExKeys(m.onDuplicateEx)
|
||||
if err != nil {
|
||||
return option, err
|
||||
}
|
||||
onDuplicateExKeySet := gset.NewStrSetFrom(onDuplicateExKeys)
|
||||
if m.onDuplicate != nil {
|
||||
switch m.onDuplicate.(type) {
|
||||
case Raw, *Raw:
|
||||
option.OnDuplicateStr = gconv.String(m.onDuplicate)
|
||||
|
||||
default:
|
||||
reflectInfo := reflection.OriginValueAndKind(m.onDuplicate)
|
||||
switch reflectInfo.OriginKind {
|
||||
case reflect.String:
|
||||
option.OnDuplicateMap = make(map[string]any)
|
||||
for _, v := range gstr.SplitAndTrim(reflectInfo.OriginValue.String(), ",") {
|
||||
if onDuplicateExKeySet.Contains(v) {
|
||||
continue
|
||||
}
|
||||
option.OnDuplicateMap[v] = v
|
||||
}
|
||||
|
||||
case reflect.Map:
|
||||
option.OnDuplicateMap = make(map[string]any)
|
||||
for k, v := range gconv.Map(m.onDuplicate) {
|
||||
if onDuplicateExKeySet.Contains(k) {
|
||||
continue
|
||||
}
|
||||
option.OnDuplicateMap[k] = v
|
||||
}
|
||||
|
||||
case reflect.Slice, reflect.Array:
|
||||
option.OnDuplicateMap = make(map[string]any)
|
||||
for _, v := range gconv.Strings(m.onDuplicate) {
|
||||
if onDuplicateExKeySet.Contains(v) {
|
||||
continue
|
||||
}
|
||||
option.OnDuplicateMap[v] = v
|
||||
}
|
||||
|
||||
default:
|
||||
return option, gerror.NewCodef(
|
||||
gcode.CodeInvalidParameter,
|
||||
`unsupported OnDuplicate parameter type "%s"`,
|
||||
reflect.TypeOf(m.onDuplicate),
|
||||
)
|
||||
}
|
||||
}
|
||||
} else if onDuplicateExKeySet.Size() > 0 {
|
||||
option.OnDuplicateMap = make(map[string]any)
|
||||
for _, v := range columnNames {
|
||||
if onDuplicateExKeySet.Contains(v) {
|
||||
continue
|
||||
}
|
||||
option.OnDuplicateMap[v] = v
|
||||
}
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
func (m *Model) formatOnDuplicateExKeys(onDuplicateEx any) ([]string, error) {
|
||||
if onDuplicateEx == nil {
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
reflectInfo := reflection.OriginValueAndKind(onDuplicateEx)
|
||||
switch reflectInfo.OriginKind {
|
||||
case reflect.String:
|
||||
return gstr.SplitAndTrim(reflectInfo.OriginValue.String(), ","), nil
|
||||
|
||||
case reflect.Map:
|
||||
return gutil.Keys(onDuplicateEx), nil
|
||||
|
||||
case reflect.Slice, reflect.Array:
|
||||
return gconv.Strings(onDuplicateEx), nil
|
||||
|
||||
default:
|
||||
return nil, gerror.NewCodef(
|
||||
gcode.CodeInvalidParameter,
|
||||
`unsupported OnDuplicateEx parameter type "%s"`,
|
||||
reflect.TypeOf(onDuplicateEx),
|
||||
)
|
||||
}
|
||||
}
|
||||
|
||||
func (m *Model) formatOnConflictKeys(onConflict any) ([]string, error) {
|
||||
if onConflict == nil {
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
reflectInfo := reflection.OriginValueAndKind(onConflict)
|
||||
switch reflectInfo.OriginKind {
|
||||
case reflect.String:
|
||||
return gstr.SplitAndTrim(reflectInfo.OriginValue.String(), ","), nil
|
||||
|
||||
case reflect.Slice, reflect.Array:
|
||||
return gconv.Strings(onConflict), nil
|
||||
|
||||
default:
|
||||
return nil, gerror.NewCodef(
|
||||
gcode.CodeInvalidParameter,
|
||||
`unsupported onConflict parameter type "%s"`,
|
||||
reflect.TypeOf(onConflict),
|
||||
)
|
||||
}
|
||||
}
|
||||
|
||||
func (m *Model) getBatch() int {
|
||||
return m.batch
|
||||
}
|
||||
|
|
@ -0,0 +1,223 @@
|
|||
// Copyright GoFrame Author(https://goframe.org). All Rights Reserved.
|
||||
//
|
||||
// This Source Code Form is subject to the terms of the MIT License.
|
||||
// If a copy of the MIT was not distributed with this file,
|
||||
// You can obtain one at https://github.com/gogf/gf.
|
||||
|
||||
package database
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
|
||||
"github.com/gogf/gf/v2/text/gstr"
|
||||
)
|
||||
|
||||
// LeftJoin does "LEFT JOIN ... ON ..." statement on the model.
|
||||
// The parameter `table` can be joined table and its joined condition,
|
||||
// and also with its alias name.
|
||||
//
|
||||
// Eg:
|
||||
// Model("user").LeftJoin("user_detail", "user_detail.uid=user.uid")
|
||||
// Model("user", "u").LeftJoin("user_detail", "ud", "ud.uid=u.uid")
|
||||
// Model("user", "u").LeftJoin("SELECT xxx FROM xxx","a", "a.uid=u.uid").
|
||||
func (m *Model) LeftJoin(tableOrSubQueryAndJoinConditions ...string) *Model {
|
||||
return m.doJoin(joinOperatorLeft, tableOrSubQueryAndJoinConditions...)
|
||||
}
|
||||
|
||||
// RightJoin does "RIGHT JOIN ... ON ..." statement on the model.
|
||||
// The parameter `table` can be joined table and its joined condition,
|
||||
// and also with its alias name.
|
||||
//
|
||||
// Eg:
|
||||
// Model("user").RightJoin("user_detail", "user_detail.uid=user.uid")
|
||||
// Model("user", "u").RightJoin("user_detail", "ud", "ud.uid=u.uid")
|
||||
// Model("user", "u").RightJoin("SELECT xxx FROM xxx","a", "a.uid=u.uid").
|
||||
func (m *Model) RightJoin(tableOrSubQueryAndJoinConditions ...string) *Model {
|
||||
return m.doJoin(joinOperatorRight, tableOrSubQueryAndJoinConditions...)
|
||||
}
|
||||
|
||||
// InnerJoin does "INNER JOIN ... ON ..." statement on the model.
|
||||
// The parameter `table` can be joined table and its joined condition,
|
||||
// and also with its alias name。
|
||||
//
|
||||
// Eg:
|
||||
// Model("user").InnerJoin("user_detail", "user_detail.uid=user.uid")
|
||||
// Model("user", "u").InnerJoin("user_detail", "ud", "ud.uid=u.uid")
|
||||
// Model("user", "u").InnerJoin("SELECT xxx FROM xxx","a", "a.uid=u.uid").
|
||||
func (m *Model) InnerJoin(tableOrSubQueryAndJoinConditions ...string) *Model {
|
||||
return m.doJoin(joinOperatorInner, tableOrSubQueryAndJoinConditions...)
|
||||
}
|
||||
|
||||
// LeftJoinOnField performs as LeftJoin, but it joins both tables with the `same field name`.
|
||||
//
|
||||
// Eg:
|
||||
// Model("order").LeftJoinOnField("user", "user_id")
|
||||
// Model("order").LeftJoinOnField("product", "product_id").
|
||||
func (m *Model) LeftJoinOnField(table, field string) *Model {
|
||||
return m.doJoin(joinOperatorLeft, table, fmt.Sprintf(
|
||||
`%s.%s=%s.%s`,
|
||||
m.tablesInit,
|
||||
m.db.GetCore().QuoteWord(field),
|
||||
m.db.GetCore().QuoteWord(table),
|
||||
m.db.GetCore().QuoteWord(field),
|
||||
))
|
||||
}
|
||||
|
||||
// RightJoinOnField performs as RightJoin, but it joins both tables with the `same field name`.
|
||||
//
|
||||
// Eg:
|
||||
// Model("order").InnerJoinOnField("user", "user_id")
|
||||
// Model("order").InnerJoinOnField("product", "product_id").
|
||||
func (m *Model) RightJoinOnField(table, field string) *Model {
|
||||
return m.doJoin(joinOperatorRight, table, fmt.Sprintf(
|
||||
`%s.%s=%s.%s`,
|
||||
m.tablesInit,
|
||||
m.db.GetCore().QuoteWord(field),
|
||||
m.db.GetCore().QuoteWord(table),
|
||||
m.db.GetCore().QuoteWord(field),
|
||||
))
|
||||
}
|
||||
|
||||
// InnerJoinOnField performs as InnerJoin, but it joins both tables with the `same field name`.
|
||||
//
|
||||
// Eg:
|
||||
// Model("order").InnerJoinOnField("user", "user_id")
|
||||
// Model("order").InnerJoinOnField("product", "product_id").
|
||||
func (m *Model) InnerJoinOnField(table, field string) *Model {
|
||||
return m.doJoin(joinOperatorInner, table, fmt.Sprintf(
|
||||
`%s.%s=%s.%s`,
|
||||
m.tablesInit,
|
||||
m.db.GetCore().QuoteWord(field),
|
||||
m.db.GetCore().QuoteWord(table),
|
||||
m.db.GetCore().QuoteWord(field),
|
||||
))
|
||||
}
|
||||
|
||||
// LeftJoinOnFields performs as LeftJoin. It specifies different fields and comparison operator.
|
||||
//
|
||||
// Eg:
|
||||
// Model("user").LeftJoinOnFields("order", "id", "=", "user_id")
|
||||
// Model("user").LeftJoinOnFields("order", "id", ">", "user_id")
|
||||
// Model("user").LeftJoinOnFields("order", "id", "<", "user_id")
|
||||
func (m *Model) LeftJoinOnFields(table, firstField, operator, secondField string) *Model {
|
||||
return m.doJoin(joinOperatorLeft, table, fmt.Sprintf(
|
||||
`%s.%s %s %s.%s`,
|
||||
m.tablesInit,
|
||||
m.db.GetCore().QuoteWord(firstField),
|
||||
operator,
|
||||
m.db.GetCore().QuoteWord(table),
|
||||
m.db.GetCore().QuoteWord(secondField),
|
||||
))
|
||||
}
|
||||
|
||||
// RightJoinOnFields performs as RightJoin. It specifies different fields and comparison operator.
|
||||
//
|
||||
// Eg:
|
||||
// Model("user").RightJoinOnFields("order", "id", "=", "user_id")
|
||||
// Model("user").RightJoinOnFields("order", "id", ">", "user_id")
|
||||
// Model("user").RightJoinOnFields("order", "id", "<", "user_id")
|
||||
func (m *Model) RightJoinOnFields(table, firstField, operator, secondField string) *Model {
|
||||
return m.doJoin(joinOperatorRight, table, fmt.Sprintf(
|
||||
`%s.%s %s %s.%s`,
|
||||
m.tablesInit,
|
||||
m.db.GetCore().QuoteWord(firstField),
|
||||
operator,
|
||||
m.db.GetCore().QuoteWord(table),
|
||||
m.db.GetCore().QuoteWord(secondField),
|
||||
))
|
||||
}
|
||||
|
||||
// InnerJoinOnFields performs as InnerJoin. It specifies different fields and comparison operator.
|
||||
//
|
||||
// Eg:
|
||||
// Model("user").InnerJoinOnFields("order", "id", "=", "user_id")
|
||||
// Model("user").InnerJoinOnFields("order", "id", ">", "user_id")
|
||||
// Model("user").InnerJoinOnFields("order", "id", "<", "user_id")
|
||||
func (m *Model) InnerJoinOnFields(table, firstField, operator, secondField string) *Model {
|
||||
return m.doJoin(joinOperatorInner, table, fmt.Sprintf(
|
||||
`%s.%s %s %s.%s`,
|
||||
m.tablesInit,
|
||||
m.db.GetCore().QuoteWord(firstField),
|
||||
operator,
|
||||
m.db.GetCore().QuoteWord(table),
|
||||
m.db.GetCore().QuoteWord(secondField),
|
||||
))
|
||||
}
|
||||
|
||||
// doJoin does "LEFT/RIGHT/INNER JOIN ... ON ..." statement on the model.
|
||||
// The parameter `tableOrSubQueryAndJoinConditions` can be joined table and its joined condition,
|
||||
// and also with its alias name.
|
||||
//
|
||||
// Eg:
|
||||
// Model("user").InnerJoin("user_detail", "user_detail.uid=user.uid")
|
||||
// Model("user", "u").InnerJoin("user_detail", "ud", "ud.uid=u.uid")
|
||||
// Model("user", "u").InnerJoin("user_detail", "ud", "ud.uid>u.uid")
|
||||
// Model("user", "u").InnerJoin("SELECT xxx FROM xxx","a", "a.uid=u.uid")
|
||||
// Related issues:
|
||||
// https://github.com/gogf/gf/issues/1024
|
||||
func (m *Model) doJoin(operator joinOperator, tableOrSubQueryAndJoinConditions ...string) *Model {
|
||||
var (
|
||||
model = m.getModel()
|
||||
joinStr = ""
|
||||
table string
|
||||
alias string
|
||||
)
|
||||
// Check the first parameter table or sub-query.
|
||||
if len(tableOrSubQueryAndJoinConditions) > 0 {
|
||||
if isSubQuery(tableOrSubQueryAndJoinConditions[0]) {
|
||||
joinStr = gstr.Trim(tableOrSubQueryAndJoinConditions[0])
|
||||
if joinStr[0] != '(' {
|
||||
joinStr = "(" + joinStr + ")"
|
||||
}
|
||||
} else {
|
||||
table = tableOrSubQueryAndJoinConditions[0]
|
||||
joinStr = m.db.GetCore().QuotePrefixTableName(table)
|
||||
}
|
||||
}
|
||||
// Generate join condition statement string.
|
||||
conditionLength := len(tableOrSubQueryAndJoinConditions)
|
||||
switch {
|
||||
case conditionLength > 2:
|
||||
alias = tableOrSubQueryAndJoinConditions[1]
|
||||
model.tables += fmt.Sprintf(
|
||||
" %s JOIN %s AS %s ON (%s)",
|
||||
operator, joinStr,
|
||||
m.db.GetCore().QuoteWord(alias),
|
||||
tableOrSubQueryAndJoinConditions[2],
|
||||
)
|
||||
m.tableAliasMap[alias] = table
|
||||
|
||||
case conditionLength == 2:
|
||||
model.tables += fmt.Sprintf(
|
||||
" %s JOIN %s ON (%s)",
|
||||
operator, joinStr, tableOrSubQueryAndJoinConditions[1],
|
||||
)
|
||||
|
||||
case conditionLength == 1:
|
||||
model.tables += fmt.Sprintf(
|
||||
" %s JOIN %s", operator, joinStr,
|
||||
)
|
||||
}
|
||||
return model
|
||||
}
|
||||
|
||||
// getTableNameByPrefixOrAlias checks and returns the table name if `prefixOrAlias` is an alias of a table,
|
||||
// it or else returns the `prefixOrAlias` directly.
|
||||
func (m *Model) getTableNameByPrefixOrAlias(prefixOrAlias string) string {
|
||||
value, ok := m.tableAliasMap[prefixOrAlias]
|
||||
if ok {
|
||||
return value
|
||||
}
|
||||
return prefixOrAlias
|
||||
}
|
||||
|
||||
// isSubQuery checks and returns whether given string a sub-query sql string.
|
||||
func isSubQuery(s string) bool {
|
||||
s = gstr.TrimLeft(s, "()")
|
||||
if p := gstr.Pos(s, " "); p != -1 {
|
||||
if gstr.Equal(s[:p], "select") {
|
||||
return true
|
||||
}
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
|
@ -0,0 +1,129 @@
|
|||
// Copyright GoFrame Author(https://goframe.org). All Rights Reserved.
|
||||
//
|
||||
// This Source Code Form is subject to the terms of the MIT License.
|
||||
// If a copy of the MIT was not distributed with this file,
|
||||
// You can obtain one at https://github.com/gogf/gf.
|
||||
|
||||
package database
|
||||
|
||||
// Lock clause constants for different databases.
|
||||
// These constants provide type-safe and IDE-friendly access to various lock syntaxes.
|
||||
const (
|
||||
// Common lock clauses (supported by most databases)
|
||||
LockForUpdate = "FOR UPDATE"
|
||||
LockForUpdateSkipLocked = "FOR UPDATE SKIP LOCKED"
|
||||
|
||||
// MySQL lock clauses
|
||||
LockInShareMode = "LOCK IN SHARE MODE" // MySQL legacy syntax
|
||||
LockForShare = "FOR SHARE" // MySQL 8.0+ and PostgreSQL
|
||||
LockForUpdateNowait = "FOR UPDATE NOWAIT" // MySQL 8.0+ and Oracle
|
||||
|
||||
// PostgreSQL specific lock clauses
|
||||
LockForNoKeyUpdate = "FOR NO KEY UPDATE"
|
||||
LockForKeyShare = "FOR KEY SHARE"
|
||||
LockForShareNowait = "FOR SHARE NOWAIT"
|
||||
LockForShareSkipLocked = "FOR SHARE SKIP LOCKED"
|
||||
LockForNoKeyUpdateNowait = "FOR NO KEY UPDATE NOWAIT"
|
||||
LockForNoKeyUpdateSkipLocked = "FOR NO KEY UPDATE SKIP LOCKED"
|
||||
LockForKeyShareNowait = "FOR KEY SHARE NOWAIT"
|
||||
LockForKeyShareSkipLocked = "FOR KEY SHARE SKIP LOCKED"
|
||||
|
||||
// Oracle specific lock clauses
|
||||
LockForUpdateWait5 = "FOR UPDATE WAIT 5"
|
||||
LockForUpdateWait10 = "FOR UPDATE WAIT 10"
|
||||
LockForUpdateWait30 = "FOR UPDATE WAIT 30"
|
||||
|
||||
// SQL Server lock hints (use with WITH clause)
|
||||
LockWithUpdLock = "WITH (UPDLOCK)"
|
||||
LockWithHoldLock = "WITH (HOLDLOCK)"
|
||||
LockWithXLock = "WITH (XLOCK)"
|
||||
LockWithTabLock = "WITH (TABLOCK)"
|
||||
LockWithNoLock = "WITH (NOLOCK)"
|
||||
LockWithUpdLockHoldLock = "WITH (UPDLOCK, HOLDLOCK)"
|
||||
)
|
||||
|
||||
// Lock sets a custom lock clause for the current operation.
|
||||
// This is a generic method that allows you to specify any lock syntax supported by your database.
|
||||
// You can use predefined constants or custom strings.
|
||||
//
|
||||
// Database-specific lock syntax support:
|
||||
//
|
||||
// PostgreSQL (most comprehensive):
|
||||
// - "FOR UPDATE" - Exclusive lock, blocks all access
|
||||
// - "FOR NO KEY UPDATE" - Weaker exclusive lock, doesn't block FOR KEY SHARE
|
||||
// - "FOR SHARE" - Shared lock, allows reads but blocks writes
|
||||
// - "FOR KEY SHARE" - Weakest lock, only locks key values
|
||||
// - All above can be combined with:
|
||||
// - "NOWAIT" - Return immediately if lock cannot be acquired
|
||||
// - "SKIP LOCKED" - Skip locked rows instead of waiting
|
||||
//
|
||||
// MySQL:
|
||||
// - "FOR UPDATE" - Exclusive lock (all versions)
|
||||
// - "LOCK IN SHARE MODE" - Shared lock (legacy syntax)
|
||||
// - "FOR SHARE" - Shared lock (MySQL 8.0+)
|
||||
// - "FOR UPDATE NOWAIT" - MySQL 8.0+ only
|
||||
// - "FOR UPDATE SKIP LOCKED" - MySQL 8.0+ only
|
||||
//
|
||||
// Oracle:
|
||||
// - "FOR UPDATE" - Exclusive lock
|
||||
// - "FOR UPDATE NOWAIT" - Exclusive lock, no wait
|
||||
// - "FOR UPDATE SKIP LOCKED" - Exclusive lock, skip locked rows
|
||||
// - "FOR UPDATE WAIT n" - Exclusive lock, wait n seconds
|
||||
// - "FOR UPDATE OF column_list" - Lock specific columns
|
||||
//
|
||||
// SQL Server (uses WITH hints):
|
||||
// - "WITH (UPDLOCK)" - Update lock
|
||||
// - "WITH (HOLDLOCK)" - Hold lock until transaction end
|
||||
// - "WITH (XLOCK)" - Exclusive lock
|
||||
// - "WITH (TABLOCK)" - Table lock
|
||||
// - "WITH (NOLOCK)" - No lock (dirty read)
|
||||
// - "WITH (UPDLOCK, HOLDLOCK)" - Combined update and hold lock
|
||||
//
|
||||
// SQLite:
|
||||
// - Limited locking support, database-level locks only
|
||||
// - No row-level lock syntax supported
|
||||
//
|
||||
// Usage examples:
|
||||
//
|
||||
// db.Model("users").Lock("FOR UPDATE NOWAIT").Where("id", 1).One()
|
||||
// db.Model("users").Lock("FOR SHARE SKIP LOCKED").Where("status", "active").All()
|
||||
// db.Model("users").Lock("WITH (UPDLOCK)").Where("id", 1).One() // SQL Server
|
||||
// db.Model("users").Lock("FOR UPDATE OF name, email").Where("id", 1).One() // Oracle
|
||||
// db.Model("users").Lock("FOR UPDATE WAIT 15").Where("id", 1).One() // Oracle custom wait
|
||||
//
|
||||
// Or use predefined constants for better IDE support:
|
||||
//
|
||||
// db.Model("users").Lock(gdb.LockForUpdateNowait).Where("id", 1).One()
|
||||
// db.Model("users").Lock(gdb.LockForShareSkipLocked).Where("status", "active").All()
|
||||
func (m *Model) Lock(lockClause string) *Model {
|
||||
model := m.getModel()
|
||||
model.lockInfo = lockClause
|
||||
return model
|
||||
}
|
||||
|
||||
// LockUpdate sets the lock for update for current operation.
|
||||
// This is equivalent to Lock("FOR UPDATE").
|
||||
func (m *Model) LockUpdate() *Model {
|
||||
model := m.getModel()
|
||||
model.lockInfo = LockForUpdate
|
||||
return model
|
||||
}
|
||||
|
||||
// LockUpdateSkipLocked sets the lock for update with skip locked behavior for current operation.
|
||||
// It skips the locked rows.
|
||||
// This is equivalent to Lock("FOR UPDATE SKIP LOCKED").
|
||||
// Note: Supported by PostgreSQL, Oracle, and MySQL 8.0+.
|
||||
func (m *Model) LockUpdateSkipLocked() *Model {
|
||||
model := m.getModel()
|
||||
model.lockInfo = LockForUpdateSkipLocked
|
||||
return model
|
||||
}
|
||||
|
||||
// LockShared sets the lock in share mode for current operation.
|
||||
// This is equivalent to Lock("LOCK IN SHARE MODE") for MySQL or Lock("FOR SHARE") for PostgreSQL.
|
||||
// Note: For maximum compatibility, this uses MySQL's legacy syntax.
|
||||
func (m *Model) LockShared() *Model {
|
||||
model := m.getModel()
|
||||
model.lockInfo = LockInShareMode
|
||||
return model
|
||||
}
|
||||
|
|
@ -0,0 +1,73 @@
|
|||
// Copyright GoFrame Author(https://goframe.org). All Rights Reserved.
|
||||
//
|
||||
// This Source Code Form is subject to the terms of the MIT License.
|
||||
// If a copy of the MIT was not distributed with this file,
|
||||
// You can obtain one at https://github.com/gogf/gf.
|
||||
|
||||
package database
|
||||
|
||||
const (
|
||||
optionOmitNil = optionOmitNilWhere | optionOmitNilData
|
||||
optionOmitEmpty = optionOmitEmptyWhere | optionOmitEmptyData
|
||||
optionOmitNilDataInternal = optionOmitNilData | optionOmitNilDataList // this option is used internally only for ForDao feature.
|
||||
optionOmitEmptyWhere = 1 << iota // 8
|
||||
optionOmitEmptyData // 16
|
||||
optionOmitNilWhere // 32
|
||||
optionOmitNilData // 64
|
||||
optionOmitNilDataList // 128
|
||||
)
|
||||
|
||||
// OmitEmpty sets optionOmitEmpty option for the model, which automatically filers
|
||||
// the data and where parameters for `empty` values.
|
||||
func (m *Model) OmitEmpty() *Model {
|
||||
model := m.getModel()
|
||||
model.option = model.option | optionOmitEmpty
|
||||
return model
|
||||
}
|
||||
|
||||
// OmitEmptyWhere sets optionOmitEmptyWhere option for the model, which automatically filers
|
||||
// the Where/Having parameters for `empty` values.
|
||||
//
|
||||
// Eg:
|
||||
//
|
||||
// Where("id", []int{}).All() -> SELECT xxx FROM xxx WHERE 0=1
|
||||
// Where("name", "").All() -> SELECT xxx FROM xxx WHERE `name`=''
|
||||
// OmitEmpty().Where("id", []int{}).All() -> SELECT xxx FROM xxx
|
||||
// OmitEmpty().("name", "").All() -> SELECT xxx FROM xxx.
|
||||
func (m *Model) OmitEmptyWhere() *Model {
|
||||
model := m.getModel()
|
||||
model.option = model.option | optionOmitEmptyWhere
|
||||
return model
|
||||
}
|
||||
|
||||
// OmitEmptyData sets optionOmitEmptyData option for the model, which automatically filers
|
||||
// the Data parameters for `empty` values.
|
||||
func (m *Model) OmitEmptyData() *Model {
|
||||
model := m.getModel()
|
||||
model.option = model.option | optionOmitEmptyData
|
||||
return model
|
||||
}
|
||||
|
||||
// OmitNil sets optionOmitNil option for the model, which automatically filers
|
||||
// the data and where parameters for `nil` values.
|
||||
func (m *Model) OmitNil() *Model {
|
||||
model := m.getModel()
|
||||
model.option = model.option | optionOmitNil
|
||||
return model
|
||||
}
|
||||
|
||||
// OmitNilWhere sets optionOmitNilWhere option for the model, which automatically filers
|
||||
// the Where/Having parameters for `nil` values.
|
||||
func (m *Model) OmitNilWhere() *Model {
|
||||
model := m.getModel()
|
||||
model.option = model.option | optionOmitNilWhere
|
||||
return model
|
||||
}
|
||||
|
||||
// OmitNilData sets optionOmitNilData option for the model, which automatically filers
|
||||
// the Data parameters for `nil` values.
|
||||
func (m *Model) OmitNilData() *Model {
|
||||
model := m.getModel()
|
||||
model.option = model.option | optionOmitNilData
|
||||
return model
|
||||
}
|
||||
|
|
@ -0,0 +1,95 @@
|
|||
// Copyright GoFrame Author(https://goframe.org). All Rights Reserved.
|
||||
//
|
||||
// This Source Code Form is subject to the terms of the MIT License.
|
||||
// If a copy of the MIT was not distributed with this file,
|
||||
// You can obtain one at https://github.com/gogf/gf.
|
||||
|
||||
package database
|
||||
|
||||
import (
|
||||
"strings"
|
||||
|
||||
"github.com/gogf/gf/v2/text/gstr"
|
||||
"github.com/gogf/gf/v2/util/gconv"
|
||||
)
|
||||
|
||||
// Order sets the "ORDER BY" statement for the model.
|
||||
//
|
||||
// Example:
|
||||
// Order("id desc")
|
||||
// Order("id", "desc")
|
||||
// Order("id desc,name asc")
|
||||
// Order("id desc", "name asc")
|
||||
// Order("id desc").Order("name asc")
|
||||
// Order(gdb.Raw("field(id, 3,1,2)")).
|
||||
func (m *Model) Order(orderBy ...any) *Model {
|
||||
if len(orderBy) == 0 {
|
||||
return m
|
||||
}
|
||||
var (
|
||||
core = m.db.GetCore()
|
||||
model = m.getModel()
|
||||
)
|
||||
for _, v := range orderBy {
|
||||
if model.orderBy != "" {
|
||||
model.orderBy += ","
|
||||
}
|
||||
switch v.(type) {
|
||||
case Raw, *Raw:
|
||||
model.orderBy += gconv.String(v)
|
||||
default:
|
||||
orderByStr := gconv.String(v)
|
||||
if gstr.Contains(orderByStr, " ") {
|
||||
model.orderBy += core.QuoteString(orderByStr)
|
||||
} else {
|
||||
if gstr.Equal(orderByStr, "ASC") || gstr.Equal(orderByStr, "DESC") {
|
||||
model.orderBy = gstr.TrimRight(model.orderBy, ",")
|
||||
model.orderBy += " " + orderByStr
|
||||
} else {
|
||||
model.orderBy += core.QuoteWord(orderByStr)
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
return model
|
||||
}
|
||||
|
||||
// OrderAsc sets the "ORDER BY xxx ASC" statement for the model.
|
||||
func (m *Model) OrderAsc(column string) *Model {
|
||||
if len(column) == 0 {
|
||||
return m
|
||||
}
|
||||
return m.Order(column + " ASC")
|
||||
}
|
||||
|
||||
// OrderDesc sets the "ORDER BY xxx DESC" statement for the model.
|
||||
func (m *Model) OrderDesc(column string) *Model {
|
||||
if len(column) == 0 {
|
||||
return m
|
||||
}
|
||||
return m.Order(column + " DESC")
|
||||
}
|
||||
|
||||
// OrderRandom sets the "ORDER BY RANDOM()" statement for the model.
|
||||
func (m *Model) OrderRandom() *Model {
|
||||
model := m.getModel()
|
||||
model.orderBy = m.db.OrderRandomFunction()
|
||||
return model
|
||||
}
|
||||
|
||||
// Group sets the "GROUP BY" statement for the model.
|
||||
func (m *Model) Group(groupBy ...string) *Model {
|
||||
if len(groupBy) == 0 {
|
||||
return m
|
||||
}
|
||||
var (
|
||||
core = m.db.GetCore()
|
||||
model = m.getModel()
|
||||
)
|
||||
|
||||
if model.groupBy != "" {
|
||||
model.groupBy += ","
|
||||
}
|
||||
model.groupBy += core.QuoteString(strings.Join(groupBy, ","))
|
||||
return model
|
||||
}
|
||||
|
|
@ -0,0 +1,964 @@
|
|||
// Copyright GoFrame Author(https://goframe.org). All Rights Reserved.
|
||||
//
|
||||
// This Source Code Form is subject to the terms of the MIT License.
|
||||
// If a copy of the MIT was not distributed with this file,
|
||||
// You can obtain one at https://github.com/gogf/gf.
|
||||
|
||||
package database
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"reflect"
|
||||
|
||||
"git.magicany.cc/black1552/gin-base/database/reflection"
|
||||
"github.com/gogf/gf/v2/container/gset"
|
||||
"github.com/gogf/gf/v2/encoding/gjson"
|
||||
"github.com/gogf/gf/v2/errors/gcode"
|
||||
"github.com/gogf/gf/v2/errors/gerror"
|
||||
"github.com/gogf/gf/v2/text/gstr"
|
||||
"github.com/gogf/gf/v2/util/gconv"
|
||||
)
|
||||
|
||||
// All does "SELECT FROM ..." statement for the model.
|
||||
// It retrieves the records from table and returns the result as slice type.
|
||||
// It returns nil if there's no record retrieved with the given conditions from table.
|
||||
//
|
||||
// The optional parameter `where` is the same as the parameter of Model.Where function,
|
||||
// see Model.Where.
|
||||
func (m *Model) All(where ...any) (Result, error) {
|
||||
var ctx = m.GetCtx()
|
||||
return m.doGetAll(ctx, SelectTypeDefault, false, where...)
|
||||
}
|
||||
|
||||
// AllAndCount retrieves all records and the total count of records from the model.
|
||||
// If useFieldForCount is true, it will use the fields specified in the model for counting;
|
||||
// otherwise, it will use a constant value of 1 for counting.
|
||||
// It returns the result as a slice of records, the total count of records, and an error if any.
|
||||
// The where parameter is an optional list of conditions to use when retrieving records.
|
||||
//
|
||||
// Example:
|
||||
//
|
||||
// var model Model
|
||||
// var result Result
|
||||
// var count int
|
||||
// where := []any{"name = ?", "John"}
|
||||
// result, count, err := model.AllAndCount(true)
|
||||
// if err != nil {
|
||||
// // Handle error.
|
||||
// }
|
||||
// fmt.Println(result, count)
|
||||
func (m *Model) AllAndCount(useFieldForCount bool) (result Result, totalCount int, err error) {
|
||||
// Clone the model for counting
|
||||
countModel := m.Clone()
|
||||
|
||||
// Decide how to build the COUNT() expression:
|
||||
// - If caller explicitly wants to use the single field expression for counting,
|
||||
// honor it (e.g. Fields("DISTINCT col") with useFieldForCount = true).
|
||||
// - Otherwise, clear fields to let Count() use its default COUNT(1),
|
||||
// avoiding invalid COUNT(field1, field2, ...) with multiple fields,
|
||||
// or incorrect COUNT(DISTINCT 1) when Distinct() is set.
|
||||
if useFieldForCount && len(m.fields) == 1 {
|
||||
countModel.fields = m.fields
|
||||
} else {
|
||||
countModel.fields = nil
|
||||
}
|
||||
if len(m.pageCacheOption) > 0 {
|
||||
countModel = countModel.Cache(m.pageCacheOption[0])
|
||||
}
|
||||
|
||||
// Get the total count of records
|
||||
totalCount, err = countModel.Count()
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
|
||||
// If the total count is 0, there are no records to retrieve, so return early
|
||||
if totalCount == 0 {
|
||||
return
|
||||
}
|
||||
|
||||
resultModel := m.Clone()
|
||||
if len(m.pageCacheOption) > 1 {
|
||||
resultModel = resultModel.Cache(m.pageCacheOption[1])
|
||||
}
|
||||
|
||||
// Retrieve all records
|
||||
result, err = resultModel.doGetAll(m.GetCtx(), SelectTypeDefault, false)
|
||||
return
|
||||
}
|
||||
|
||||
// Chunk iterates the query result with given `size` and `handler` function.
|
||||
func (m *Model) Chunk(size int, handler ChunkHandler) {
|
||||
page := m.start
|
||||
if page <= 0 {
|
||||
page = 1
|
||||
}
|
||||
model := m
|
||||
for {
|
||||
model = model.Page(page, size)
|
||||
data, err := model.All()
|
||||
if err != nil {
|
||||
handler(nil, err)
|
||||
break
|
||||
}
|
||||
if len(data) == 0 {
|
||||
break
|
||||
}
|
||||
if !handler(data, err) {
|
||||
break
|
||||
}
|
||||
if len(data) < size {
|
||||
break
|
||||
}
|
||||
page++
|
||||
}
|
||||
}
|
||||
|
||||
// One retrieves one record from table and returns the result as map type.
|
||||
// It returns nil if there's no record retrieved with the given conditions from table.
|
||||
//
|
||||
// The optional parameter `where` is the same as the parameter of Model.Where function,
|
||||
// see Model.Where.
|
||||
func (m *Model) One(where ...any) (Record, error) {
|
||||
var ctx = m.GetCtx()
|
||||
if len(where) > 0 {
|
||||
return m.Where(where[0], where[1:]...).One()
|
||||
}
|
||||
all, err := m.doGetAll(ctx, SelectTypeDefault, true)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if len(all) > 0 {
|
||||
return all[0], nil
|
||||
}
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
// Array queries and returns data values as slice from database.
|
||||
// Note that if there are multiple columns in the result, it returns just one column values randomly.
|
||||
//
|
||||
// If the optional parameter `fieldsAndWhere` is given, the fieldsAndWhere[0] is the selected fields
|
||||
// and fieldsAndWhere[1:] is treated as where condition fields.
|
||||
// Also see Model.Fields and Model.Where functions.
|
||||
func (m *Model) Array(fieldsAndWhere ...any) (Array, error) {
|
||||
if len(fieldsAndWhere) > 0 {
|
||||
if len(fieldsAndWhere) > 2 {
|
||||
return m.Fields(gconv.String(fieldsAndWhere[0])).Where(fieldsAndWhere[1], fieldsAndWhere[2:]...).Array()
|
||||
} else if len(fieldsAndWhere) == 2 {
|
||||
return m.Fields(gconv.String(fieldsAndWhere[0])).Where(fieldsAndWhere[1]).Array()
|
||||
} else {
|
||||
return m.Fields(gconv.String(fieldsAndWhere[0])).Array()
|
||||
}
|
||||
}
|
||||
|
||||
var (
|
||||
field string
|
||||
core = m.db.GetCore()
|
||||
ctx = core.injectInternalColumn(m.GetCtx())
|
||||
)
|
||||
all, err := m.doGetAll(ctx, SelectTypeArray, false)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if len(all) > 0 {
|
||||
internalData := core.getInternalColumnFromCtx(ctx)
|
||||
if internalData == nil {
|
||||
return nil, gerror.NewCode(
|
||||
gcode.CodeInternalError,
|
||||
`query count error: the internal context data is missing. there's internal issue should be fixed`,
|
||||
)
|
||||
}
|
||||
// If FirstResultColumn present, it returns the value of the first record of the first field.
|
||||
// It means it use no cache mechanism, while cache mechanism makes `internalData` missing.
|
||||
field = internalData.FirstResultColumn
|
||||
if field == "" {
|
||||
// Fields number check.
|
||||
var recordFields = m.getRecordFields(all[0])
|
||||
if len(recordFields) == 1 {
|
||||
field = recordFields[0]
|
||||
} else {
|
||||
// it returns error if there are multiple fields in the result record.
|
||||
return nil, gerror.NewCodef(
|
||||
gcode.CodeInvalidParameter,
|
||||
`invalid fields for "Array" operation, result fields number "%d"%s, but expect one`,
|
||||
len(recordFields),
|
||||
gjson.MustEncodeString(recordFields),
|
||||
)
|
||||
}
|
||||
}
|
||||
}
|
||||
return all.Array(field), nil
|
||||
}
|
||||
|
||||
// Struct retrieves one record from table and converts it into given struct.
|
||||
// The parameter `pointer` should be type of *struct/**struct. If type **struct is given,
|
||||
// it can create the struct internally during converting.
|
||||
//
|
||||
// The optional parameter `where` is the same as the parameter of Model.Where function,
|
||||
// see Model.Where.
|
||||
//
|
||||
// Note that it returns sql.ErrNoRows if the given parameter `pointer` pointed to a variable that has
|
||||
// default value and there's no record retrieved with the given conditions from table.
|
||||
//
|
||||
// Example:
|
||||
// user := new(User)
|
||||
// err := db.Model("user").Where("id", 1).Scan(user)
|
||||
//
|
||||
// user := (*User)(nil)
|
||||
// err := db.Model("user").Where("id", 1).Scan(&user).
|
||||
func (m *Model) doStruct(pointer any, where ...any) error {
|
||||
model := m
|
||||
// Auto selecting fields by struct attributes.
|
||||
if len(model.fieldsEx) == 0 && len(model.fields) == 0 {
|
||||
if v, ok := pointer.(reflect.Value); ok {
|
||||
model = m.Fields(v.Interface())
|
||||
} else {
|
||||
model = m.Fields(pointer)
|
||||
}
|
||||
}
|
||||
one, err := model.One(where...)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
if err = one.Struct(pointer); err != nil {
|
||||
return err
|
||||
}
|
||||
return model.doWithScanStruct(pointer)
|
||||
}
|
||||
|
||||
// Structs retrieves records from table and converts them into given struct slice.
|
||||
// The parameter `pointer` should be type of *[]struct/*[]*struct. It can create and fill the struct
|
||||
// slice internally during converting.
|
||||
//
|
||||
// The optional parameter `where` is the same as the parameter of Model.Where function,
|
||||
// see Model.Where.
|
||||
//
|
||||
// Note that it returns sql.ErrNoRows if the given parameter `pointer` pointed to a variable that has
|
||||
// default value and there's no record retrieved with the given conditions from table.
|
||||
//
|
||||
// Example:
|
||||
// users := ([]User)(nil)
|
||||
// err := db.Model("user").Scan(&users)
|
||||
//
|
||||
// users := ([]*User)(nil)
|
||||
// err := db.Model("user").Scan(&users).
|
||||
func (m *Model) doStructs(pointer any, where ...any) error {
|
||||
model := m
|
||||
// Auto selecting fields by struct attributes.
|
||||
if len(model.fieldsEx) == 0 && len(model.fields) == 0 {
|
||||
if v, ok := pointer.(reflect.Value); ok {
|
||||
model = m.Fields(
|
||||
reflect.New(
|
||||
v.Type().Elem(),
|
||||
).Interface(),
|
||||
)
|
||||
} else {
|
||||
model = m.Fields(
|
||||
reflect.New(
|
||||
reflect.ValueOf(pointer).Elem().Type().Elem(),
|
||||
).Interface(),
|
||||
)
|
||||
}
|
||||
}
|
||||
all, err := model.All(where...)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
if err = all.Structs(pointer); err != nil {
|
||||
return err
|
||||
}
|
||||
return model.doWithScanStructs(pointer)
|
||||
}
|
||||
|
||||
// Scan automatically calls Struct or Structs function according to the type of parameter `pointer`.
|
||||
// It calls function doStruct if `pointer` is type of *struct/**struct.
|
||||
// It calls function doStructs if `pointer` is type of *[]struct/*[]*struct.
|
||||
//
|
||||
// The optional parameter `where` is the same as the parameter of Model.Where function, see Model.Where.
|
||||
//
|
||||
// Note that it returns sql.ErrNoRows if the given parameter `pointer` pointed to a variable that has
|
||||
// default value and there's no record retrieved with the given conditions from table.
|
||||
//
|
||||
// Example:
|
||||
// user := new(User)
|
||||
// err := db.Model("user").Where("id", 1).Scan(user)
|
||||
//
|
||||
// user := (*User)(nil)
|
||||
// err := db.Model("user").Where("id", 1).Scan(&user)
|
||||
//
|
||||
// users := ([]User)(nil)
|
||||
// err := db.Model("user").Scan(&users)
|
||||
//
|
||||
// users := ([]*User)(nil)
|
||||
// err := db.Model("user").Scan(&users).
|
||||
func (m *Model) Scan(pointer any, where ...any) error {
|
||||
reflectInfo := reflection.OriginTypeAndKind(pointer)
|
||||
if reflectInfo.InputKind != reflect.Pointer {
|
||||
return gerror.NewCode(
|
||||
gcode.CodeInvalidParameter,
|
||||
`the parameter "pointer" for function Scan should type of pointer`,
|
||||
)
|
||||
}
|
||||
switch reflectInfo.OriginKind {
|
||||
case reflect.Slice, reflect.Array:
|
||||
return m.doStructs(pointer, where...)
|
||||
|
||||
case reflect.Struct, reflect.Invalid:
|
||||
return m.doStruct(pointer, where...)
|
||||
|
||||
default:
|
||||
return gerror.NewCode(
|
||||
gcode.CodeInvalidParameter,
|
||||
`element of parameter "pointer" for function Scan should type of struct/*struct/[]struct/[]*struct`,
|
||||
)
|
||||
}
|
||||
}
|
||||
|
||||
// ScanAndCount scans a single record or record array that matches the given conditions and counts the total number
|
||||
// of records that match those conditions.
|
||||
//
|
||||
// If `useFieldForCount` is true, it will use the fields specified in the model for counting;
|
||||
// The `pointer` parameter is a pointer to a struct that the scanned data will be stored in.
|
||||
// The `totalCount` parameter is a pointer to an integer that will be set to the total number of records that match the given conditions.
|
||||
// The where parameter is an optional list of conditions to use when retrieving records.
|
||||
//
|
||||
// Example:
|
||||
//
|
||||
// var count int
|
||||
// user := new(User)
|
||||
// err := db.Model("user").Where("id", 1).ScanAndCount(user,&count,true)
|
||||
// fmt.Println(user, count)
|
||||
//
|
||||
// Example Join:
|
||||
//
|
||||
// type User struct {
|
||||
// Id int
|
||||
// Passport string
|
||||
// Name string
|
||||
// Age int
|
||||
// }
|
||||
// var users []User
|
||||
// var count int
|
||||
// db.Model(table).As("u1").
|
||||
// LeftJoin(tableName2, "u2", "u2.id=u1.id").
|
||||
// Fields("u1.passport,u1.id,u2.name,u2.age").
|
||||
// Where("u1.id<2").
|
||||
// ScanAndCount(&users, &count, false)
|
||||
func (m *Model) ScanAndCount(pointer any, totalCount *int, useFieldForCount bool) (err error) {
|
||||
// support Fields with *, example: .Fields("a.*, b.name"). Count sql is select count(1) from xxx
|
||||
countModel := m.Clone()
|
||||
// Decide how to build the COUNT() expression:
|
||||
// - If caller explicitly wants to use the single field expression for counting,
|
||||
// honor it (e.g. Fields("DISTINCT col") with useFieldForCount = true).
|
||||
// - Otherwise, clear fields to let Count() use its default COUNT(1),
|
||||
// avoiding invalid COUNT(field1, field2, ...) with multiple fields,
|
||||
// or incorrect COUNT(DISTINCT 1) when Distinct() is set.
|
||||
if useFieldForCount && len(m.fields) == 1 {
|
||||
countModel.fields = m.fields
|
||||
} else {
|
||||
countModel.fields = nil
|
||||
}
|
||||
if len(m.pageCacheOption) > 0 {
|
||||
countModel = countModel.Cache(m.pageCacheOption[0])
|
||||
}
|
||||
// Get the total count of records
|
||||
*totalCount, err = countModel.Count()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// If the total count is 0, there are no records to retrieve, so return early
|
||||
if *totalCount == 0 {
|
||||
return
|
||||
}
|
||||
scanModel := m.Clone()
|
||||
if len(m.pageCacheOption) > 1 {
|
||||
scanModel = scanModel.Cache(m.pageCacheOption[1])
|
||||
}
|
||||
err = scanModel.Scan(pointer)
|
||||
return
|
||||
}
|
||||
|
||||
// ScanList converts `r` to struct slice which contains other complex struct attributes.
|
||||
// Note that the parameter `listPointer` should be type of *[]struct/*[]*struct.
|
||||
//
|
||||
// See Result.ScanList.
|
||||
func (m *Model) ScanList(structSlicePointer any, bindToAttrName string, relationAttrNameAndFields ...string) (err error) {
|
||||
var result Result
|
||||
out, err := checkGetSliceElementInfoForScanList(structSlicePointer, bindToAttrName)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
if len(m.fields) > 0 || len(m.fieldsEx) != 0 {
|
||||
// There are custom fields.
|
||||
result, err = m.All()
|
||||
} else {
|
||||
// Filter fields using temporary created struct using reflect.New.
|
||||
result, err = m.Fields(reflect.New(out.BindToAttrType).Interface()).All()
|
||||
}
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
var (
|
||||
relationAttrName string
|
||||
relationFields string
|
||||
)
|
||||
switch len(relationAttrNameAndFields) {
|
||||
case 2:
|
||||
relationAttrName = relationAttrNameAndFields[0]
|
||||
relationFields = relationAttrNameAndFields[1]
|
||||
case 1:
|
||||
relationFields = relationAttrNameAndFields[0]
|
||||
}
|
||||
return doScanList(doScanListInput{
|
||||
Model: m,
|
||||
Result: result,
|
||||
StructSlicePointer: structSlicePointer,
|
||||
StructSliceValue: out.SliceReflectValue,
|
||||
BindToAttrName: bindToAttrName,
|
||||
RelationAttrName: relationAttrName,
|
||||
RelationFields: relationFields,
|
||||
})
|
||||
}
|
||||
|
||||
// Value retrieves a specified record value from table and returns the result as interface type.
|
||||
// It returns nil if there's no record found with the given conditions from table.
|
||||
//
|
||||
// If the optional parameter `fieldsAndWhere` is given, the fieldsAndWhere[0] is the selected fields
|
||||
// and fieldsAndWhere[1:] is treated as where condition fields.
|
||||
// Also see Model.Fields and Model.Where functions.
|
||||
func (m *Model) Value(fieldsAndWhere ...any) (Value, error) {
|
||||
var (
|
||||
core = m.db.GetCore()
|
||||
ctx = core.injectInternalColumn(m.GetCtx())
|
||||
)
|
||||
if len(fieldsAndWhere) > 0 {
|
||||
if len(fieldsAndWhere) > 2 {
|
||||
return m.Fields(gconv.String(fieldsAndWhere[0])).Where(fieldsAndWhere[1], fieldsAndWhere[2:]...).Value()
|
||||
} else if len(fieldsAndWhere) == 2 {
|
||||
return m.Fields(gconv.String(fieldsAndWhere[0])).Where(fieldsAndWhere[1]).Value()
|
||||
} else {
|
||||
return m.Fields(gconv.String(fieldsAndWhere[0])).Value()
|
||||
}
|
||||
}
|
||||
var (
|
||||
sqlWithHolder, holderArgs = m.getFormattedSqlAndArgs(ctx, SelectTypeValue, true)
|
||||
all, err = m.doGetAllBySql(ctx, SelectTypeValue, sqlWithHolder, holderArgs...)
|
||||
)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if len(all) > 0 {
|
||||
internalData := core.getInternalColumnFromCtx(ctx)
|
||||
if internalData == nil {
|
||||
return nil, gerror.NewCode(
|
||||
gcode.CodeInternalError,
|
||||
`query count error: the internal context data is missing. there's internal issue should be fixed`,
|
||||
)
|
||||
}
|
||||
// If FirstResultColumn present, it returns the value of the first record of the first field.
|
||||
// It means it use no cache mechanism, while cache mechanism makes `internalData` missing.
|
||||
if v, ok := all[0][internalData.FirstResultColumn]; ok {
|
||||
return v, nil
|
||||
}
|
||||
// Fields number check.
|
||||
var recordFields = m.getRecordFields(all[0])
|
||||
if len(recordFields) == 1 {
|
||||
for _, v := range all[0] {
|
||||
return v, nil
|
||||
}
|
||||
}
|
||||
// it returns error if there are multiple fields in the result record.
|
||||
return nil, gerror.NewCodef(
|
||||
gcode.CodeInvalidParameter,
|
||||
`invalid fields for "Value" operation, result fields number "%d"%s, but expect one`,
|
||||
len(recordFields),
|
||||
gjson.MustEncodeString(recordFields),
|
||||
)
|
||||
}
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
func (m *Model) getRecordFields(record Record) []string {
|
||||
if len(record) == 0 {
|
||||
return nil
|
||||
}
|
||||
var fields = make([]string, 0)
|
||||
for k := range record {
|
||||
fields = append(fields, k)
|
||||
}
|
||||
return fields
|
||||
}
|
||||
|
||||
// Count does "SELECT COUNT(x) FROM ..." statement for the model.
|
||||
// The optional parameter `where` is the same as the parameter of Model.Where function,
|
||||
// see Model.Where.
|
||||
func (m *Model) Count(where ...any) (int, error) {
|
||||
var (
|
||||
core = m.db.GetCore()
|
||||
ctx = core.injectInternalColumn(m.GetCtx())
|
||||
)
|
||||
if len(where) > 0 {
|
||||
return m.Where(where[0], where[1:]...).Count()
|
||||
}
|
||||
var (
|
||||
sqlWithHolder, holderArgs = m.getFormattedSqlAndArgs(ctx, SelectTypeCount, false)
|
||||
all, err = m.doGetAllBySql(ctx, SelectTypeCount, sqlWithHolder, holderArgs...)
|
||||
)
|
||||
if err != nil {
|
||||
return 0, err
|
||||
}
|
||||
if len(all) > 0 {
|
||||
internalData := core.getInternalColumnFromCtx(ctx)
|
||||
if internalData == nil {
|
||||
return 0, gerror.NewCode(
|
||||
gcode.CodeInternalError,
|
||||
`query count error: the internal context data is missing. there's internal issue should be fixed`,
|
||||
)
|
||||
}
|
||||
// If FirstResultColumn present, it returns the value of the first record of the first field.
|
||||
// It means it use no cache mechanism, while cache mechanism makes `internalData` missing.
|
||||
if v, ok := all[0][internalData.FirstResultColumn]; ok {
|
||||
return v.Int(), nil
|
||||
}
|
||||
// Fields number check.
|
||||
var recordFields = m.getRecordFields(all[0])
|
||||
if len(recordFields) == 1 {
|
||||
for _, v := range all[0] {
|
||||
return v.Int(), nil
|
||||
}
|
||||
}
|
||||
// it returns error if there are multiple fields in the result record.
|
||||
return 0, gerror.NewCodef(
|
||||
gcode.CodeInvalidParameter,
|
||||
`invalid fields for "Count" operation, result fields number "%d"%s, but expect one`,
|
||||
len(recordFields),
|
||||
gjson.MustEncodeString(recordFields),
|
||||
)
|
||||
}
|
||||
return 0, nil
|
||||
}
|
||||
|
||||
// Exist does "SELECT 1 FROM ... LIMIT 1" statement for the model.
|
||||
// The optional parameter `where` is the same as the parameter of Model.Where function,
|
||||
// see Model.Where.
|
||||
func (m *Model) Exist(where ...any) (bool, error) {
|
||||
if len(where) > 0 {
|
||||
return m.Where(where[0], where[1:]...).Exist()
|
||||
}
|
||||
one, err := m.Fields(Raw("1")).One()
|
||||
if err != nil {
|
||||
return false, err
|
||||
}
|
||||
for _, val := range one {
|
||||
if val.Bool() {
|
||||
return true, nil
|
||||
}
|
||||
}
|
||||
return false, nil
|
||||
}
|
||||
|
||||
// CountColumn does "SELECT COUNT(x) FROM ..." statement for the model.
|
||||
func (m *Model) CountColumn(column string) (int, error) {
|
||||
if len(column) == 0 {
|
||||
return 0, nil
|
||||
}
|
||||
return m.Fields(column).Count()
|
||||
}
|
||||
|
||||
// Min does "SELECT MIN(x) FROM ..." statement for the model.
|
||||
func (m *Model) Min(column string) (float64, error) {
|
||||
if len(column) == 0 {
|
||||
return 0, nil
|
||||
}
|
||||
value, err := m.Fields(fmt.Sprintf(`MIN(%s)`, m.QuoteWord(column))).Value()
|
||||
if err != nil {
|
||||
return 0, err
|
||||
}
|
||||
return value.Float64(), err
|
||||
}
|
||||
|
||||
// Max does "SELECT MAX(x) FROM ..." statement for the model.
|
||||
func (m *Model) Max(column string) (float64, error) {
|
||||
if len(column) == 0 {
|
||||
return 0, nil
|
||||
}
|
||||
value, err := m.Fields(fmt.Sprintf(`MAX(%s)`, m.QuoteWord(column))).Value()
|
||||
if err != nil {
|
||||
return 0, err
|
||||
}
|
||||
return value.Float64(), err
|
||||
}
|
||||
|
||||
// Avg does "SELECT AVG(x) FROM ..." statement for the model.
|
||||
func (m *Model) Avg(column string) (float64, error) {
|
||||
if len(column) == 0 {
|
||||
return 0, nil
|
||||
}
|
||||
value, err := m.Fields(fmt.Sprintf(`AVG(%s)`, m.QuoteWord(column))).Value()
|
||||
if err != nil {
|
||||
return 0, err
|
||||
}
|
||||
return value.Float64(), err
|
||||
}
|
||||
|
||||
// Sum does "SELECT SUM(x) FROM ..." statement for the model.
|
||||
func (m *Model) Sum(column string) (float64, error) {
|
||||
if len(column) == 0 {
|
||||
return 0, nil
|
||||
}
|
||||
value, err := m.Fields(fmt.Sprintf(`SUM(%s)`, m.QuoteWord(column))).Value()
|
||||
if err != nil {
|
||||
return 0, err
|
||||
}
|
||||
return value.Float64(), err
|
||||
}
|
||||
|
||||
// Union does "(SELECT xxx FROM xxx) UNION (SELECT xxx FROM xxx) ..." statement for the model.
|
||||
func (m *Model) Union(unions ...*Model) *Model {
|
||||
return m.db.Union(unions...)
|
||||
}
|
||||
|
||||
// UnionAll does "(SELECT xxx FROM xxx) UNION ALL (SELECT xxx FROM xxx) ..." statement for the model.
|
||||
func (m *Model) UnionAll(unions ...*Model) *Model {
|
||||
return m.db.UnionAll(unions...)
|
||||
}
|
||||
|
||||
// Limit sets the "LIMIT" statement for the model.
|
||||
// The parameter `limit` can be either one or two number, if passed two number is passed,
|
||||
// it then sets "LIMIT limit[0],limit[1]" statement for the model, or else it sets "LIMIT limit[0]"
|
||||
// statement.
|
||||
// Note: Negative values are treated as zero.
|
||||
func (m *Model) Limit(limit ...int) *Model {
|
||||
model := m.getModel()
|
||||
switch len(limit) {
|
||||
case 1:
|
||||
if limit[0] < 0 {
|
||||
limit[0] = 0
|
||||
}
|
||||
model.limit = limit[0]
|
||||
case 2:
|
||||
if limit[0] < 0 {
|
||||
limit[0] = 0
|
||||
}
|
||||
if limit[1] < 0 {
|
||||
limit[1] = 0
|
||||
}
|
||||
model.start = limit[0]
|
||||
model.limit = limit[1]
|
||||
}
|
||||
return model
|
||||
}
|
||||
|
||||
// Offset sets the "OFFSET" statement for the model.
|
||||
// It only makes sense for some databases like SQLServer, PostgreSQL, etc.
|
||||
// Note: Negative values are treated as zero.
|
||||
func (m *Model) Offset(offset int) *Model {
|
||||
model := m.getModel()
|
||||
if offset < 0 {
|
||||
offset = 0
|
||||
}
|
||||
model.offset = offset
|
||||
return model
|
||||
}
|
||||
|
||||
// Distinct forces the query to only return distinct results.
|
||||
func (m *Model) Distinct() *Model {
|
||||
model := m.getModel()
|
||||
model.distinct = "DISTINCT "
|
||||
return model
|
||||
}
|
||||
|
||||
// Page sets the paging number for the model.
|
||||
// The parameter `page` is started from 1 for paging.
|
||||
// Note that, it differs that the Limit function starts from 0 for "LIMIT" statement.
|
||||
// Note: Negative limit values are treated as zero.
|
||||
func (m *Model) Page(page, limit int) *Model {
|
||||
model := m.getModel()
|
||||
if page <= 0 {
|
||||
page = 1
|
||||
}
|
||||
if limit < 0 {
|
||||
limit = 0
|
||||
}
|
||||
model.start = (page - 1) * limit
|
||||
model.limit = limit
|
||||
return model
|
||||
}
|
||||
|
||||
// Having sets the having statement for the model.
|
||||
// The parameters of this function usage are as the same as function Where.
|
||||
// See Where.
|
||||
func (m *Model) Having(having any, args ...any) *Model {
|
||||
model := m.getModel()
|
||||
model.having = []any{
|
||||
having, args,
|
||||
}
|
||||
return model
|
||||
}
|
||||
|
||||
// doGetAll does "SELECT FROM ..." statement for the model.
|
||||
// It retrieves the records from table and returns the result as slice type.
|
||||
// It returns nil if there's no record retrieved with the given conditions from table.
|
||||
//
|
||||
// The parameter `limit1` specifies whether limits querying only one record if m.limit is not set.
|
||||
// The optional parameter `where` is the same as the parameter of Model.Where function,
|
||||
// see Model.Where.
|
||||
func (m *Model) doGetAll(ctx context.Context, selectType SelectType, limit1 bool, where ...any) (Result, error) {
|
||||
if len(where) > 0 {
|
||||
return m.Where(where[0], where[1:]...).All()
|
||||
}
|
||||
sqlWithHolder, holderArgs := m.getFormattedSqlAndArgs(ctx, selectType, limit1)
|
||||
return m.doGetAllBySql(ctx, selectType, sqlWithHolder, holderArgs...)
|
||||
}
|
||||
|
||||
// doGetAllBySql does the select statement on the database.
|
||||
func (m *Model) doGetAllBySql(
|
||||
ctx context.Context, selectType SelectType, sql string, args ...any,
|
||||
) (result Result, err error) {
|
||||
if result, err = m.getSelectResultFromCache(ctx, sql, args...); err != nil || result != nil {
|
||||
return
|
||||
}
|
||||
|
||||
in := &HookSelectInput{
|
||||
internalParamHookSelect: internalParamHookSelect{
|
||||
internalParamHook: internalParamHook{
|
||||
link: m.getLink(false),
|
||||
},
|
||||
handler: m.hookHandler.Select,
|
||||
},
|
||||
Model: m,
|
||||
Table: m.tables,
|
||||
Schema: m.schema,
|
||||
Sql: sql,
|
||||
Args: m.mergeArguments(args),
|
||||
SelectType: selectType,
|
||||
}
|
||||
if result, err = in.Next(ctx); err != nil {
|
||||
return
|
||||
}
|
||||
|
||||
err = m.saveSelectResultToCache(ctx, selectType, result, sql, args...)
|
||||
return
|
||||
}
|
||||
|
||||
func (m *Model) getFormattedSqlAndArgs(
|
||||
ctx context.Context, selectType SelectType, limit1 bool,
|
||||
) (sqlWithHolder string, holderArgs []any) {
|
||||
switch selectType {
|
||||
case SelectTypeCount:
|
||||
queryFields := "COUNT(1)"
|
||||
if len(m.fields) > 0 {
|
||||
// DO NOT quote the m.fields here, in case of fields like:
|
||||
// DISTINCT t.user_id uid
|
||||
queryFields = fmt.Sprintf(`COUNT(%s%s)`, m.distinct, m.getFieldsAsStr())
|
||||
}
|
||||
// Raw SQL Model.
|
||||
if m.rawSql != "" {
|
||||
conditionWhere, conditionExtra, conditionArgs := m.formatCondition(ctx, false, true)
|
||||
sqlWithHolder = fmt.Sprintf(
|
||||
"SELECT %s FROM (%s%s) AS T",
|
||||
queryFields, m.rawSql, conditionWhere+conditionExtra,
|
||||
)
|
||||
return sqlWithHolder, conditionArgs
|
||||
}
|
||||
conditionWhere, conditionExtra, conditionArgs := m.formatCondition(ctx, false, true)
|
||||
sqlWithHolder = fmt.Sprintf("SELECT %s FROM %s%s", queryFields, m.tables, conditionWhere+conditionExtra)
|
||||
if len(m.groupBy) > 0 {
|
||||
sqlWithHolder = fmt.Sprintf("SELECT COUNT(1) FROM (%s) count_alias", sqlWithHolder)
|
||||
}
|
||||
return sqlWithHolder, conditionArgs
|
||||
|
||||
default:
|
||||
conditionWhere, conditionExtra, conditionArgs := m.formatCondition(ctx, limit1, false)
|
||||
// Raw SQL Model, especially for UNION/UNION ALL featured SQL.
|
||||
if m.rawSql != "" {
|
||||
sqlWithHolder = fmt.Sprintf(
|
||||
"%s%s",
|
||||
m.rawSql,
|
||||
conditionWhere+conditionExtra,
|
||||
)
|
||||
return sqlWithHolder, conditionArgs
|
||||
}
|
||||
// DO NOT quote the m.fields where, in case of fields like:
|
||||
// DISTINCT t.user_id uid
|
||||
sqlWithHolder = fmt.Sprintf(
|
||||
"SELECT %s%s FROM %s%s",
|
||||
m.distinct, m.getFieldsFiltered(), m.tables, conditionWhere+conditionExtra,
|
||||
)
|
||||
return sqlWithHolder, conditionArgs
|
||||
}
|
||||
}
|
||||
|
||||
func (m *Model) getHolderAndArgsAsSubModel(ctx context.Context) (holder string, args []any) {
|
||||
holder, args = m.getFormattedSqlAndArgs(
|
||||
ctx, SelectTypeDefault, false,
|
||||
)
|
||||
args = m.mergeArguments(args)
|
||||
return
|
||||
}
|
||||
|
||||
func (m *Model) getAutoPrefix() string {
|
||||
autoPrefix := ""
|
||||
if gstr.Contains(m.tables, " JOIN ") {
|
||||
autoPrefix = m.QuoteWord(
|
||||
m.db.GetCore().guessPrimaryTableName(m.tablesInit),
|
||||
)
|
||||
}
|
||||
return autoPrefix
|
||||
}
|
||||
|
||||
func (m *Model) getFieldsAsStr() string {
|
||||
var (
|
||||
fieldsStr string
|
||||
)
|
||||
for _, v := range m.fields {
|
||||
field := gconv.String(v)
|
||||
switch {
|
||||
case gstr.ContainsAny(field, "()"):
|
||||
case gstr.ContainsAny(field, ". "):
|
||||
default:
|
||||
switch v.(type) {
|
||||
case Raw, *Raw:
|
||||
default:
|
||||
field = m.QuoteWord(field)
|
||||
}
|
||||
}
|
||||
if fieldsStr != "" {
|
||||
fieldsStr += ","
|
||||
}
|
||||
fieldsStr += field
|
||||
}
|
||||
return fieldsStr
|
||||
}
|
||||
|
||||
// getFieldsFiltered checks the fields and fieldsEx attributes, filters and returns the fields that will
|
||||
// really be committed to underlying database driver.
|
||||
func (m *Model) getFieldsFiltered() string {
|
||||
if len(m.fieldsEx) == 0 && len(m.fields) == 0 {
|
||||
return defaultField
|
||||
}
|
||||
if len(m.fieldsEx) == 0 && len(m.fields) > 0 {
|
||||
return m.getFieldsAsStr()
|
||||
}
|
||||
var (
|
||||
fieldsArray []string
|
||||
fieldsExSet = gset.NewStrSetFrom(gconv.Strings(m.fieldsEx))
|
||||
)
|
||||
if len(m.fields) > 0 {
|
||||
// Filter custom fields with fieldEx.
|
||||
fieldsArray = make([]string, 0, 8)
|
||||
for _, v := range m.fields {
|
||||
field := gconv.String(v)
|
||||
fieldsArray = append(fieldsArray, field[gstr.PosR(field, "-")+1:])
|
||||
}
|
||||
} else {
|
||||
if gstr.Contains(m.tables, " ") {
|
||||
panic("function FieldsEx supports only single table operations")
|
||||
}
|
||||
// Filter table fields with fieldEx.
|
||||
tableFields, err := m.TableFields(m.tablesInit)
|
||||
if err != nil {
|
||||
panic(err)
|
||||
}
|
||||
if len(tableFields) == 0 {
|
||||
panic(fmt.Sprintf(`empty table fields for table "%s"`, m.tables))
|
||||
}
|
||||
fieldsArray = make([]string, len(tableFields))
|
||||
for k, v := range tableFields {
|
||||
fieldsArray[v.Index] = k
|
||||
}
|
||||
}
|
||||
newFields := ""
|
||||
for _, k := range fieldsArray {
|
||||
if fieldsExSet.Contains(k) {
|
||||
continue
|
||||
}
|
||||
if len(newFields) > 0 {
|
||||
newFields += ","
|
||||
}
|
||||
newFields += m.QuoteWord(k)
|
||||
}
|
||||
return newFields
|
||||
}
|
||||
|
||||
// formatCondition formats where arguments of the model and returns a new condition sql and its arguments.
|
||||
// Note that this function does not change any attribute value of the `m`.
|
||||
//
|
||||
// The parameter `limit1` specifies whether limits querying only one record if m.limit is not set.
|
||||
func (m *Model) formatCondition(
|
||||
ctx context.Context, limit1 bool, isCountStatement bool,
|
||||
) (conditionWhere string, conditionExtra string, conditionArgs []any) {
|
||||
var autoPrefix = m.getAutoPrefix()
|
||||
// GROUP BY.
|
||||
if m.groupBy != "" {
|
||||
conditionExtra += " GROUP BY " + m.groupBy
|
||||
}
|
||||
// WHERE
|
||||
conditionWhere, conditionArgs = m.whereBuilder.Build()
|
||||
softDeletingCondition := m.softTimeMaintainer().GetDeleteCondition(ctx)
|
||||
if m.rawSql != "" && conditionWhere != "" {
|
||||
if gstr.ContainsI(m.rawSql, " WHERE ") {
|
||||
conditionWhere = " AND " + conditionWhere
|
||||
} else {
|
||||
conditionWhere = " WHERE " + conditionWhere
|
||||
}
|
||||
} else if !m.unscoped && softDeletingCondition != "" {
|
||||
if conditionWhere == "" {
|
||||
conditionWhere = fmt.Sprintf(` WHERE %s`, softDeletingCondition)
|
||||
} else {
|
||||
conditionWhere = fmt.Sprintf(` WHERE (%s) AND %s`, conditionWhere, softDeletingCondition)
|
||||
}
|
||||
} else {
|
||||
if conditionWhere != "" {
|
||||
conditionWhere = " WHERE " + conditionWhere
|
||||
}
|
||||
}
|
||||
// HAVING.
|
||||
if len(m.having) > 0 {
|
||||
havingHolder := WhereHolder{
|
||||
Where: m.having[0],
|
||||
Args: gconv.Interfaces(m.having[1]),
|
||||
Prefix: autoPrefix,
|
||||
}
|
||||
havingStr, havingArgs := formatWhereHolder(ctx, m.db, formatWhereHolderInput{
|
||||
WhereHolder: havingHolder,
|
||||
OmitNil: m.option&optionOmitNilWhere > 0,
|
||||
OmitEmpty: m.option&optionOmitEmptyWhere > 0,
|
||||
Schema: m.schema,
|
||||
Table: m.tables,
|
||||
})
|
||||
if len(havingStr) > 0 {
|
||||
conditionExtra += " HAVING " + havingStr
|
||||
conditionArgs = append(conditionArgs, havingArgs...)
|
||||
}
|
||||
}
|
||||
// ORDER BY.
|
||||
if !isCountStatement { // The count statement of sqlserver cannot contain the order by statement
|
||||
if m.orderBy != "" {
|
||||
conditionExtra += " ORDER BY " + m.orderBy
|
||||
}
|
||||
}
|
||||
// LIMIT.
|
||||
if !isCountStatement {
|
||||
if m.limit != 0 {
|
||||
if m.start >= 0 {
|
||||
conditionExtra += fmt.Sprintf(" LIMIT %d,%d", m.start, m.limit)
|
||||
} else {
|
||||
conditionExtra += fmt.Sprintf(" LIMIT %d", m.limit)
|
||||
}
|
||||
} else if limit1 {
|
||||
conditionExtra += " LIMIT 1"
|
||||
}
|
||||
|
||||
if m.offset >= 0 {
|
||||
conditionExtra += fmt.Sprintf(" OFFSET %d", m.offset)
|
||||
}
|
||||
}
|
||||
|
||||
if m.lockInfo != "" {
|
||||
conditionExtra += " " + m.lockInfo
|
||||
}
|
||||
return
|
||||
}
|
||||
|
|
@ -0,0 +1,161 @@
|
|||
// Copyright GoFrame Author(https://goframe.org). All Rights Reserved.
|
||||
//
|
||||
// This Source Code Form is subject to the terms of the MIT License.
|
||||
// If a copy of the MIT was not distributed with this file,
|
||||
// You can obtain one at https://github.com/gogf/gf.
|
||||
|
||||
package database
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"hash/fnv"
|
||||
"reflect"
|
||||
|
||||
"github.com/gogf/gf/v2/errors/gcode"
|
||||
"github.com/gogf/gf/v2/errors/gerror"
|
||||
"github.com/gogf/gf/v2/util/gconv"
|
||||
)
|
||||
|
||||
// ShardingConfig defines the configuration for database/table sharding.
|
||||
type ShardingConfig struct {
|
||||
// Table sharding configuration
|
||||
Table ShardingTableConfig
|
||||
// Schema sharding configuration
|
||||
Schema ShardingSchemaConfig
|
||||
}
|
||||
|
||||
// ShardingSchemaConfig defines the configuration for database sharding.
|
||||
type ShardingSchemaConfig struct {
|
||||
// Enable schema sharding
|
||||
Enable bool
|
||||
// Schema rule prefix, e.g., "db_"
|
||||
Prefix string
|
||||
// ShardingRule defines how to route data to different database nodes
|
||||
Rule ShardingRule
|
||||
}
|
||||
|
||||
// ShardingTableConfig defines the configuration for table sharding
|
||||
type ShardingTableConfig struct {
|
||||
// Enable table sharding
|
||||
Enable bool
|
||||
// Table rule prefix, e.g., "user_"
|
||||
Prefix string
|
||||
// ShardingRule defines how to route data to different tables
|
||||
Rule ShardingRule
|
||||
}
|
||||
|
||||
// ShardingRule defines the interface for sharding rules
|
||||
type ShardingRule interface {
|
||||
// SchemaName returns the target schema name based on sharding value.
|
||||
SchemaName(ctx context.Context, config ShardingSchemaConfig, value any) (string, error)
|
||||
// TableName returns the target table name based on sharding value.
|
||||
TableName(ctx context.Context, config ShardingTableConfig, value any) (string, error)
|
||||
}
|
||||
|
||||
// DefaultShardingRule implements a simple modulo-based sharding rule
|
||||
type DefaultShardingRule struct {
|
||||
// Number of schema count.
|
||||
SchemaCount int
|
||||
// Number of tables per schema.
|
||||
TableCount int
|
||||
}
|
||||
|
||||
// Sharding creates a sharding model with given sharding configuration.
|
||||
func (m *Model) Sharding(config ShardingConfig) *Model {
|
||||
model := m.getModel()
|
||||
model.shardingConfig = config
|
||||
return model
|
||||
}
|
||||
|
||||
// ShardingValue sets the sharding value for routing
|
||||
func (m *Model) ShardingValue(value any) *Model {
|
||||
model := m.getModel()
|
||||
model.shardingValue = value
|
||||
return model
|
||||
}
|
||||
|
||||
// getActualSchema returns the actual schema based on sharding configuration.
|
||||
// TODO it does not support schemas in different database config node.
|
||||
func (m *Model) getActualSchema(ctx context.Context, defaultSchema string) (string, error) {
|
||||
if !m.shardingConfig.Schema.Enable {
|
||||
return defaultSchema, nil
|
||||
}
|
||||
if m.shardingValue == nil {
|
||||
return defaultSchema, gerror.NewCode(
|
||||
gcode.CodeInvalidParameter, "sharding value is required when sharding feature enabled",
|
||||
)
|
||||
}
|
||||
if m.shardingConfig.Schema.Rule == nil {
|
||||
return defaultSchema, gerror.NewCode(
|
||||
gcode.CodeInvalidParameter, "sharding rule is required when sharding feature enabled",
|
||||
)
|
||||
}
|
||||
return m.shardingConfig.Schema.Rule.SchemaName(ctx, m.shardingConfig.Schema, m.shardingValue)
|
||||
}
|
||||
|
||||
// getActualTable returns the actual table name based on sharding configuration
|
||||
func (m *Model) getActualTable(ctx context.Context, defaultTable string) (string, error) {
|
||||
if !m.shardingConfig.Table.Enable {
|
||||
return defaultTable, nil
|
||||
}
|
||||
if m.shardingValue == nil {
|
||||
return defaultTable, gerror.NewCode(
|
||||
gcode.CodeInvalidParameter, "sharding value is required when sharding feature enabled",
|
||||
)
|
||||
}
|
||||
if m.shardingConfig.Table.Rule == nil {
|
||||
return defaultTable, gerror.NewCode(
|
||||
gcode.CodeInvalidParameter, "sharding rule is required when sharding feature enabled",
|
||||
)
|
||||
}
|
||||
return m.shardingConfig.Table.Rule.TableName(ctx, m.shardingConfig.Table, m.shardingValue)
|
||||
}
|
||||
|
||||
// SchemaName implements the default database sharding strategy
|
||||
func (r *DefaultShardingRule) SchemaName(ctx context.Context, config ShardingSchemaConfig, value any) (string, error) {
|
||||
if r.SchemaCount == 0 {
|
||||
return "", gerror.NewCode(
|
||||
gcode.CodeInvalidParameter, "schema count should not be 0 using DefaultShardingRule when schema sharding enabled",
|
||||
)
|
||||
}
|
||||
hashValue, err := getHashValue(value)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
nodeIndex := hashValue % uint64(r.SchemaCount)
|
||||
return fmt.Sprintf("%s%d", config.Prefix, nodeIndex), nil
|
||||
}
|
||||
|
||||
// TableName implements the default table sharding strategy
|
||||
func (r *DefaultShardingRule) TableName(ctx context.Context, config ShardingTableConfig, value any) (string, error) {
|
||||
if r.TableCount == 0 {
|
||||
return "", gerror.NewCode(
|
||||
gcode.CodeInvalidParameter, "table count should not be 0 using DefaultShardingRule when table sharding enabled",
|
||||
)
|
||||
}
|
||||
hashValue, err := getHashValue(value)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
tableIndex := hashValue % uint64(r.TableCount)
|
||||
return fmt.Sprintf("%s%d", config.Prefix, tableIndex), nil
|
||||
}
|
||||
|
||||
// getHashValue converts sharding value to uint64 hash
|
||||
func getHashValue(value any) (uint64, error) {
|
||||
var rv = reflect.ValueOf(value)
|
||||
switch rv.Kind() {
|
||||
case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64,
|
||||
reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64,
|
||||
reflect.Float32, reflect.Float64:
|
||||
return gconv.Uint64(value), nil
|
||||
default:
|
||||
h := fnv.New64a()
|
||||
_, err := h.Write(gconv.Bytes(value))
|
||||
if err != nil {
|
||||
return 0, gerror.WrapCode(gcode.CodeInternalError, err)
|
||||
}
|
||||
return h.Sum64(), nil
|
||||
}
|
||||
}
|
||||
|
|
@ -0,0 +1,384 @@
|
|||
// Copyright GoFrame Author(https://goframe.org). All Rights Reserved.
|
||||
//
|
||||
// This Source Code Form is subject to the terms of the MIT License.
|
||||
// If a copy of the MIT was not distributed with this file,
|
||||
// You can obtain one at https://github.com/gogf/gf.
|
||||
|
||||
package database
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"strings"
|
||||
|
||||
"git.magicany.cc/black1552/gin-base/database/intlog"
|
||||
"git.magicany.cc/black1552/gin-base/database/utils"
|
||||
"github.com/gogf/gf/v2/container/garray"
|
||||
"github.com/gogf/gf/v2/errors/gcode"
|
||||
"github.com/gogf/gf/v2/errors/gerror"
|
||||
"github.com/gogf/gf/v2/os/gcache"
|
||||
"github.com/gogf/gf/v2/os/gtime"
|
||||
"github.com/gogf/gf/v2/text/gregex"
|
||||
"github.com/gogf/gf/v2/text/gstr"
|
||||
)
|
||||
|
||||
// SoftTimeType custom defines the soft time field type.
|
||||
type SoftTimeType int
|
||||
|
||||
const (
|
||||
SoftTimeTypeAuto SoftTimeType = 0 // (Default)Auto detect the field type by table field type.
|
||||
SoftTimeTypeTime SoftTimeType = 1 // Using datetime as the field value.
|
||||
SoftTimeTypeTimestamp SoftTimeType = 2 // In unix seconds.
|
||||
SoftTimeTypeTimestampMilli SoftTimeType = 3 // In unix milliseconds.
|
||||
SoftTimeTypeTimestampMicro SoftTimeType = 4 // In unix microseconds.
|
||||
SoftTimeTypeTimestampNano SoftTimeType = 5 // In unix nanoseconds.
|
||||
)
|
||||
|
||||
// SoftTimeOption is the option to customize soft time feature for Model.
|
||||
type SoftTimeOption struct {
|
||||
SoftTimeType SoftTimeType // The value type for soft time field.
|
||||
}
|
||||
|
||||
type softTimeMaintainer struct {
|
||||
*Model
|
||||
}
|
||||
|
||||
// SoftTimeFieldType represents different soft time field purposes.
|
||||
type SoftTimeFieldType int
|
||||
|
||||
const (
|
||||
SoftTimeFieldCreate SoftTimeFieldType = iota
|
||||
SoftTimeFieldUpdate
|
||||
SoftTimeFieldDelete
|
||||
)
|
||||
|
||||
type iSoftTimeMaintainer interface {
|
||||
// GetFieldInfo returns field name and type for specified field purpose.
|
||||
GetFieldInfo(ctx context.Context, schema, table string, fieldPurpose SoftTimeFieldType) (fieldName string, localType LocalType)
|
||||
|
||||
// GetFieldValue generates value for create/update/delete operations.
|
||||
GetFieldValue(ctx context.Context, localType LocalType, isDeleted bool) any
|
||||
|
||||
// GetDeleteCondition returns WHERE condition for soft delete query.
|
||||
GetDeleteCondition(ctx context.Context) string
|
||||
|
||||
// GetDeleteData returns UPDATE statement data for soft delete.
|
||||
GetDeleteData(ctx context.Context, prefix, fieldName string, localType LocalType) (holder string, value any)
|
||||
}
|
||||
|
||||
// getSoftFieldNameAndTypeCacheItem is the internal struct for storing create/update/delete fields.
|
||||
type getSoftFieldNameAndTypeCacheItem struct {
|
||||
FieldName string
|
||||
FieldType LocalType
|
||||
}
|
||||
|
||||
var (
|
||||
// Default field names of table for automatic-filled for record creating.
|
||||
createdFieldNames = []string{"created_at", "create_at"}
|
||||
// Default field names of table for automatic-filled for record updating.
|
||||
updatedFieldNames = []string{"updated_at", "update_at"}
|
||||
// Default field names of table for automatic-filled for record deleting.
|
||||
deletedFieldNames = []string{"deleted_at", "delete_at"}
|
||||
)
|
||||
|
||||
// SoftTime sets the SoftTimeOption to customize soft time feature for Model.
|
||||
func (m *Model) SoftTime(option SoftTimeOption) *Model {
|
||||
model := m.getModel()
|
||||
model.softTimeOption = option
|
||||
return model
|
||||
}
|
||||
|
||||
// Unscoped disables the soft time feature for insert, update and delete operations.
|
||||
func (m *Model) Unscoped() *Model {
|
||||
model := m.getModel()
|
||||
model.unscoped = true
|
||||
return model
|
||||
}
|
||||
|
||||
func (m *Model) softTimeMaintainer() iSoftTimeMaintainer {
|
||||
return &softTimeMaintainer{
|
||||
m,
|
||||
}
|
||||
}
|
||||
|
||||
// GetFieldInfo returns field name and type for specified field purpose.
|
||||
// It checks the key with or without cases or chars '-'/'_'/'.'/' '.
|
||||
func (m *softTimeMaintainer) GetFieldInfo(
|
||||
ctx context.Context, schema, table string, fieldPurpose SoftTimeFieldType,
|
||||
) (fieldName string, localType LocalType) {
|
||||
// Check if feature is disabled
|
||||
if m.db.GetConfig().TimeMaintainDisabled {
|
||||
return "", LocalTypeUndefined
|
||||
}
|
||||
|
||||
// Determine table name
|
||||
tableName := table
|
||||
if tableName == "" {
|
||||
tableName = m.tablesInit
|
||||
}
|
||||
|
||||
// Get config and field candidates
|
||||
config := m.db.GetConfig()
|
||||
var (
|
||||
configField string
|
||||
defaultFields []string
|
||||
)
|
||||
|
||||
switch fieldPurpose {
|
||||
case SoftTimeFieldCreate:
|
||||
configField = config.CreatedAt
|
||||
defaultFields = createdFieldNames
|
||||
case SoftTimeFieldUpdate:
|
||||
configField = config.UpdatedAt
|
||||
defaultFields = updatedFieldNames
|
||||
case SoftTimeFieldDelete:
|
||||
configField = config.DeletedAt
|
||||
defaultFields = deletedFieldNames
|
||||
}
|
||||
|
||||
// Use config field if specified, otherwise use defaults
|
||||
if configField != "" {
|
||||
return m.getSoftFieldNameAndType(ctx, schema, tableName, []string{configField})
|
||||
}
|
||||
return m.getSoftFieldNameAndType(ctx, schema, tableName, defaultFields)
|
||||
}
|
||||
|
||||
// getSoftFieldNameAndType retrieves and returns the field name of the table for possible key.
|
||||
func (m *softTimeMaintainer) getSoftFieldNameAndType(
|
||||
ctx context.Context, schema, table string, candidateFields []string,
|
||||
) (fieldName string, fieldType LocalType) {
|
||||
// Build cache key
|
||||
cacheKey := genSoftTimeFieldNameTypeCacheKey(schema, table, candidateFields)
|
||||
|
||||
// Try to get from cache
|
||||
cache := m.db.GetCore().GetInnerMemCache()
|
||||
result, err := cache.GetOrSetFunc(ctx, cacheKey, func(ctx context.Context) (any, error) {
|
||||
// Get table fields
|
||||
fieldsMap, err := m.TableFields(table, schema)
|
||||
if err != nil || len(fieldsMap) == 0 {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// Search for matching field
|
||||
for _, field := range candidateFields {
|
||||
if name := searchFieldNameFromMap(fieldsMap, field); name != "" {
|
||||
fType, _ := m.db.CheckLocalTypeForField(ctx, fieldsMap[name].Type, nil)
|
||||
return getSoftFieldNameAndTypeCacheItem{
|
||||
FieldName: name,
|
||||
FieldType: fType,
|
||||
}, nil
|
||||
}
|
||||
}
|
||||
return nil, nil
|
||||
}, gcache.DurationNoExpire)
|
||||
|
||||
if err != nil || result == nil {
|
||||
return "", LocalTypeUndefined
|
||||
}
|
||||
|
||||
item := result.Val().(getSoftFieldNameAndTypeCacheItem)
|
||||
return item.FieldName, item.FieldType
|
||||
}
|
||||
|
||||
func searchFieldNameFromMap(fieldsMap map[string]*TableField, key string) string {
|
||||
if len(fieldsMap) == 0 {
|
||||
return ""
|
||||
}
|
||||
_, ok := fieldsMap[key]
|
||||
if ok {
|
||||
return key
|
||||
}
|
||||
key = utils.RemoveSymbols(key)
|
||||
for k := range fieldsMap {
|
||||
if strings.EqualFold(utils.RemoveSymbols(k), key) {
|
||||
return k
|
||||
}
|
||||
}
|
||||
return ""
|
||||
}
|
||||
|
||||
// GetDeleteCondition returns WHERE condition for soft delete query.
|
||||
// It supports multiple tables string like:
|
||||
// "user u, user_detail ud"
|
||||
// "user u LEFT JOIN user_detail ud ON(ud.uid=u.uid)"
|
||||
// "user LEFT JOIN user_detail ON(user_detail.uid=user.uid)"
|
||||
// "user u LEFT JOIN user_detail ud ON(ud.uid=u.uid) LEFT JOIN user_stats us ON(us.uid=u.uid)".
|
||||
func (m *softTimeMaintainer) GetDeleteCondition(ctx context.Context) string {
|
||||
if m.unscoped {
|
||||
return ""
|
||||
}
|
||||
conditionArray := garray.NewStrArray()
|
||||
if gstr.Contains(m.tables, " JOIN ") {
|
||||
// Base table.
|
||||
tableMatch, _ := gregex.MatchString(`(.+?) [A-Z]+ JOIN`, m.tables)
|
||||
conditionArray.Append(m.getConditionOfTableStringForSoftDeleting(ctx, tableMatch[1]))
|
||||
// Multiple joined tables, exclude the sub query sql which contains char '(' and ')'.
|
||||
tableMatches, _ := gregex.MatchAllString(`JOIN ([^()]+?) ON`, m.tables)
|
||||
for _, match := range tableMatches {
|
||||
conditionArray.Append(m.getConditionOfTableStringForSoftDeleting(ctx, match[1]))
|
||||
}
|
||||
}
|
||||
if conditionArray.Len() == 0 && gstr.Contains(m.tables, ",") {
|
||||
// Multiple base tables.
|
||||
for _, s := range gstr.SplitAndTrim(m.tables, ",") {
|
||||
conditionArray.Append(m.getConditionOfTableStringForSoftDeleting(ctx, s))
|
||||
}
|
||||
}
|
||||
conditionArray.FilterEmpty()
|
||||
if conditionArray.Len() > 0 {
|
||||
return conditionArray.Join(" AND ")
|
||||
}
|
||||
// Only one table.
|
||||
fieldName, fieldType := m.GetFieldInfo(ctx, "", m.tablesInit, SoftTimeFieldDelete)
|
||||
if fieldName != "" {
|
||||
return m.buildDeleteCondition(ctx, "", fieldName, fieldType)
|
||||
}
|
||||
return ""
|
||||
}
|
||||
|
||||
// getConditionOfTableStringForSoftDeleting does something as its name describes.
|
||||
// Examples for `s`:
|
||||
// - `test`.`demo` as b
|
||||
// - `test`.`demo` b
|
||||
// - `demo`
|
||||
// - demo
|
||||
func (m *softTimeMaintainer) getConditionOfTableStringForSoftDeleting(ctx context.Context, s string) string {
|
||||
var (
|
||||
table string
|
||||
schema string
|
||||
array1 = gstr.SplitAndTrim(s, " ")
|
||||
array2 = gstr.SplitAndTrim(array1[0], ".")
|
||||
)
|
||||
if len(array2) >= 2 {
|
||||
table = array2[1]
|
||||
schema = array2[0]
|
||||
} else {
|
||||
table = array2[0]
|
||||
}
|
||||
fieldName, fieldType := m.GetFieldInfo(ctx, schema, table, SoftTimeFieldDelete)
|
||||
if fieldName == "" {
|
||||
return ""
|
||||
}
|
||||
if len(array1) >= 3 {
|
||||
return m.buildDeleteCondition(ctx, array1[2], fieldName, fieldType)
|
||||
}
|
||||
if len(array1) >= 2 {
|
||||
return m.buildDeleteCondition(ctx, array1[1], fieldName, fieldType)
|
||||
}
|
||||
return m.buildDeleteCondition(ctx, table, fieldName, fieldType)
|
||||
}
|
||||
|
||||
// GetDeleteData returns UPDATE statement data for soft delete.
|
||||
func (m *softTimeMaintainer) GetDeleteData(
|
||||
ctx context.Context, prefix, fieldName string, fieldType LocalType,
|
||||
) (holder string, value any) {
|
||||
core := m.db.GetCore()
|
||||
quotedName := core.QuoteWord(fieldName)
|
||||
|
||||
if prefix != "" {
|
||||
quotedName = fmt.Sprintf(`%s.%s`, core.QuoteWord(prefix), quotedName)
|
||||
}
|
||||
|
||||
holder = fmt.Sprintf(`%s=?`, quotedName)
|
||||
value = m.GetFieldValue(ctx, fieldType, false)
|
||||
return
|
||||
}
|
||||
|
||||
// buildDeleteCondition builds WHERE condition for soft delete filtering.
|
||||
func (m *softTimeMaintainer) buildDeleteCondition(
|
||||
ctx context.Context, prefix, fieldName string, fieldType LocalType,
|
||||
) string {
|
||||
core := m.db.GetCore()
|
||||
quotedName := core.QuoteWord(fieldName)
|
||||
|
||||
if prefix != "" {
|
||||
quotedName = fmt.Sprintf(`%s.%s`, core.QuoteWord(prefix), quotedName)
|
||||
}
|
||||
switch m.softTimeOption.SoftTimeType {
|
||||
case SoftTimeTypeAuto:
|
||||
switch fieldType {
|
||||
case LocalTypeDate, LocalTypeTime, LocalTypeDatetime:
|
||||
return fmt.Sprintf(`%s IS NULL`, quotedName)
|
||||
case LocalTypeInt, LocalTypeUint, LocalTypeInt64, LocalTypeUint64, LocalTypeBool:
|
||||
return fmt.Sprintf(`%s=0`, quotedName)
|
||||
default:
|
||||
intlog.Errorf(ctx, `invalid field type "%s" for soft delete condition: prefix=%s, field=%s`, fieldType, prefix, fieldName)
|
||||
return ""
|
||||
}
|
||||
|
||||
case SoftTimeTypeTime:
|
||||
return fmt.Sprintf(`%s IS NULL`, quotedName)
|
||||
|
||||
default:
|
||||
return fmt.Sprintf(`%s=0`, quotedName)
|
||||
}
|
||||
}
|
||||
|
||||
// GetFieldValue generates value for create/update/delete operations.
|
||||
func (m *softTimeMaintainer) GetFieldValue(
|
||||
ctx context.Context, fieldType LocalType, isDeleted bool,
|
||||
) any {
|
||||
// For deleted field, return "empty" value
|
||||
if isDeleted {
|
||||
return m.getEmptyValue(fieldType)
|
||||
}
|
||||
|
||||
// For create/update/delete, return current time value
|
||||
switch m.softTimeOption.SoftTimeType {
|
||||
case SoftTimeTypeAuto:
|
||||
return m.getAutoValue(ctx, fieldType)
|
||||
default:
|
||||
switch fieldType {
|
||||
case LocalTypeBool:
|
||||
return 1
|
||||
default:
|
||||
return m.getTimestampValue()
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// getTimestampValue returns timestamp value for soft time.
|
||||
func (m *softTimeMaintainer) getTimestampValue() any {
|
||||
switch m.softTimeOption.SoftTimeType {
|
||||
case SoftTimeTypeTime:
|
||||
return gtime.Now()
|
||||
case SoftTimeTypeTimestamp:
|
||||
return gtime.Timestamp()
|
||||
case SoftTimeTypeTimestampMilli:
|
||||
return gtime.TimestampMilli()
|
||||
case SoftTimeTypeTimestampMicro:
|
||||
return gtime.TimestampMicro()
|
||||
case SoftTimeTypeTimestampNano:
|
||||
return gtime.TimestampNano()
|
||||
default:
|
||||
panic(gerror.NewCodef(
|
||||
gcode.CodeInternalPanic,
|
||||
`unrecognized SoftTimeType "%d"`, m.softTimeOption.SoftTimeType,
|
||||
))
|
||||
}
|
||||
}
|
||||
|
||||
// getEmptyValue returns "empty" value for deleted field.
|
||||
func (m *softTimeMaintainer) getEmptyValue(fieldType LocalType) any {
|
||||
switch fieldType {
|
||||
case LocalTypeDate, LocalTypeTime, LocalTypeDatetime:
|
||||
return nil
|
||||
default:
|
||||
return 0
|
||||
}
|
||||
}
|
||||
|
||||
// getAutoValue returns auto-detected value based on field type.
|
||||
func (m *softTimeMaintainer) getAutoValue(ctx context.Context, fieldType LocalType) any {
|
||||
switch fieldType {
|
||||
case LocalTypeDate, LocalTypeTime, LocalTypeDatetime:
|
||||
return gtime.Now()
|
||||
case LocalTypeInt, LocalTypeUint, LocalTypeInt64, LocalTypeUint64:
|
||||
return gtime.Timestamp()
|
||||
case LocalTypeBool:
|
||||
return 1
|
||||
default:
|
||||
intlog.Errorf(ctx, `invalid field type "%s" for soft time auto value`, fieldType)
|
||||
return nil
|
||||
}
|
||||
}
|
||||
|
|
@ -0,0 +1,42 @@
|
|||
// Copyright GoFrame Author(https://goframe.org). All Rights Reserved.
|
||||
//
|
||||
// This Source Code Form is subject to the terms of the MIT License.
|
||||
// If a copy of the MIT was not distributed with this file,
|
||||
// You can obtain one at https://github.com/gogf/gf.
|
||||
|
||||
package database
|
||||
|
||||
import (
|
||||
"context"
|
||||
)
|
||||
|
||||
// Transaction wraps the transaction logic using function `f`.
|
||||
// It rollbacks the transaction and returns the error from function `f` if
|
||||
// it returns non-nil error. It commits the transaction and returns nil if
|
||||
// function `f` returns nil.
|
||||
//
|
||||
// Note that, you should not Commit or Rollback the transaction in function `f`
|
||||
// as it is automatically handled by this function.
|
||||
func (m *Model) Transaction(ctx context.Context, f func(ctx context.Context, tx TX) error) (err error) {
|
||||
if ctx == nil {
|
||||
ctx = m.GetCtx()
|
||||
}
|
||||
if m.tx != nil {
|
||||
return m.tx.Transaction(ctx, f)
|
||||
}
|
||||
return m.db.Transaction(ctx, f)
|
||||
}
|
||||
|
||||
// TransactionWithOptions executes transaction with options.
|
||||
// The parameter `opts` specifies the transaction options.
|
||||
// The parameter `f` specifies the function that will be called within the transaction.
|
||||
// If f returns error, the transaction will be rolled back, or else the transaction will be committed.
|
||||
func (m *Model) TransactionWithOptions(ctx context.Context, opts TxOptions, f func(ctx context.Context, tx TX) error) (err error) {
|
||||
if ctx == nil {
|
||||
ctx = m.GetCtx()
|
||||
}
|
||||
if m.tx != nil {
|
||||
return m.tx.Transaction(ctx, f)
|
||||
}
|
||||
return m.db.TransactionWithOptions(ctx, opts, f)
|
||||
}
|
||||
|
|
@ -0,0 +1,139 @@
|
|||
// Copyright GoFrame Author(https://goframe.org). All Rights Reserved.
|
||||
//
|
||||
// This Source Code Form is subject to the terms of the MIT License.
|
||||
// If a copy of the MIT was not distributed with this file,
|
||||
// You can obtain one at https://github.com/gogf/gf.
|
||||
|
||||
package database
|
||||
|
||||
import (
|
||||
"database/sql"
|
||||
"fmt"
|
||||
"reflect"
|
||||
|
||||
"git.magicany.cc/black1552/gin-base/database/empty"
|
||||
"git.magicany.cc/black1552/gin-base/database/intlog"
|
||||
"git.magicany.cc/black1552/gin-base/database/reflection"
|
||||
"github.com/gogf/gf/v2/errors/gcode"
|
||||
"github.com/gogf/gf/v2/errors/gerror"
|
||||
"github.com/gogf/gf/v2/text/gstr"
|
||||
"github.com/gogf/gf/v2/util/gconv"
|
||||
)
|
||||
|
||||
// Update does "UPDATE ... " statement for the model.
|
||||
//
|
||||
// If the optional parameter `dataAndWhere` is given, the dataAndWhere[0] is the updated data field,
|
||||
// and dataAndWhere[1:] is treated as where condition fields.
|
||||
// Also see Model.Data and Model.Where functions.
|
||||
func (m *Model) Update(dataAndWhere ...any) (result sql.Result, err error) {
|
||||
var ctx = m.GetCtx()
|
||||
if len(dataAndWhere) > 0 {
|
||||
if len(dataAndWhere) > 2 {
|
||||
return m.Data(dataAndWhere[0]).Where(dataAndWhere[1], dataAndWhere[2:]...).Update()
|
||||
} else if len(dataAndWhere) == 2 {
|
||||
return m.Data(dataAndWhere[0]).Where(dataAndWhere[1]).Update()
|
||||
} else {
|
||||
return m.Data(dataAndWhere[0]).Update()
|
||||
}
|
||||
}
|
||||
defer func() {
|
||||
if err == nil {
|
||||
m.checkAndRemoveSelectCache(ctx)
|
||||
}
|
||||
}()
|
||||
if m.data == nil {
|
||||
return nil, gerror.NewCode(gcode.CodeMissingParameter, "updating table with empty data")
|
||||
}
|
||||
var (
|
||||
newData any
|
||||
stm = m.softTimeMaintainer()
|
||||
reflectInfo = reflection.OriginTypeAndKind(m.data)
|
||||
conditionWhere, conditionExtra, conditionArgs = m.formatCondition(ctx, false, false)
|
||||
conditionStr = conditionWhere + conditionExtra
|
||||
fieldNameUpdate, fieldTypeUpdate = stm.GetFieldInfo(ctx, "", m.tablesInit, SoftTimeFieldUpdate)
|
||||
)
|
||||
if fieldNameUpdate != "" && (m.unscoped || m.isFieldInFieldsEx(fieldNameUpdate)) {
|
||||
fieldNameUpdate = ""
|
||||
}
|
||||
|
||||
newData, err = m.filterDataForInsertOrUpdate(m.data)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
switch reflectInfo.OriginKind {
|
||||
case reflect.Map, reflect.Struct:
|
||||
var dataMap = anyValueToMapBeforeToRecord(newData)
|
||||
// Automatically update the record updating time.
|
||||
if fieldNameUpdate != "" && empty.IsNil(dataMap[fieldNameUpdate]) {
|
||||
dataValue := stm.GetFieldValue(ctx, fieldTypeUpdate, false)
|
||||
dataMap[fieldNameUpdate] = dataValue
|
||||
}
|
||||
newData = dataMap
|
||||
|
||||
default:
|
||||
var updateStr = gconv.String(newData)
|
||||
// Automatically update the record updating time.
|
||||
if fieldNameUpdate != "" && !gstr.Contains(updateStr, fieldNameUpdate) {
|
||||
dataValue := stm.GetFieldValue(ctx, fieldTypeUpdate, false)
|
||||
updateStr += fmt.Sprintf(`,%s=?`, fieldNameUpdate)
|
||||
conditionArgs = append([]any{dataValue}, conditionArgs...)
|
||||
}
|
||||
newData = updateStr
|
||||
}
|
||||
|
||||
if !gstr.ContainsI(conditionStr, " WHERE ") {
|
||||
intlog.Printf(
|
||||
ctx,
|
||||
`sql condition string "%s" has no WHERE for UPDATE operation, fieldNameUpdate: %s`,
|
||||
conditionStr, fieldNameUpdate,
|
||||
)
|
||||
return nil, gerror.NewCode(
|
||||
gcode.CodeMissingParameter,
|
||||
"there should be WHERE condition statement for UPDATE operation",
|
||||
)
|
||||
}
|
||||
|
||||
in := &HookUpdateInput{
|
||||
internalParamHookUpdate: internalParamHookUpdate{
|
||||
internalParamHook: internalParamHook{
|
||||
link: m.getLink(true),
|
||||
},
|
||||
handler: m.hookHandler.Update,
|
||||
},
|
||||
Model: m,
|
||||
Table: m.tables,
|
||||
Schema: m.schema,
|
||||
Data: newData,
|
||||
Condition: conditionStr,
|
||||
Args: m.mergeArguments(conditionArgs),
|
||||
}
|
||||
return in.Next(ctx)
|
||||
}
|
||||
|
||||
// UpdateAndGetAffected performs update statement and returns the affected rows number.
|
||||
func (m *Model) UpdateAndGetAffected(dataAndWhere ...any) (affected int64, err error) {
|
||||
result, err := m.Update(dataAndWhere...)
|
||||
if err != nil {
|
||||
return 0, err
|
||||
}
|
||||
return result.RowsAffected()
|
||||
}
|
||||
|
||||
// Increment increments a column's value by a given amount.
|
||||
// The parameter `amount` can be type of float or integer.
|
||||
func (m *Model) Increment(column string, amount any) (sql.Result, error) {
|
||||
return m.getModel().Data(column, &Counter{
|
||||
Field: column,
|
||||
Value: gconv.Float64(amount),
|
||||
}).Update()
|
||||
}
|
||||
|
||||
// Decrement decrements a column's value by a given amount.
|
||||
// The parameter `amount` can be type of float or integer.
|
||||
func (m *Model) Decrement(column string, amount any) (sql.Result, error) {
|
||||
return m.getModel().Data(column, &Counter{
|
||||
Field: column,
|
||||
Value: -gconv.Float64(amount),
|
||||
}).Update()
|
||||
}
|
||||
|
|
@ -0,0 +1,314 @@
|
|||
// Copyright GoFrame Author(https://goframe.org). All Rights Reserved.
|
||||
//
|
||||
// This Source Code Form is subject to the terms of the MIT License.
|
||||
// If a copy of the MIT was not distributed with this file,
|
||||
// You can obtain one at https://github.com/gogf/gf.
|
||||
|
||||
package database
|
||||
|
||||
import (
|
||||
"time"
|
||||
|
||||
"git.magicany.cc/black1552/gin-base/database/empty"
|
||||
"github.com/gogf/gf/v2/container/gset"
|
||||
"github.com/gogf/gf/v2/os/gtime"
|
||||
"github.com/gogf/gf/v2/text/gregex"
|
||||
"github.com/gogf/gf/v2/text/gstr"
|
||||
"github.com/gogf/gf/v2/util/gconv"
|
||||
"github.com/gogf/gf/v2/util/gutil"
|
||||
)
|
||||
|
||||
// QuoteWord checks given string `s` a word,
|
||||
// if true it quotes `s` with security chars of the database
|
||||
// and returns the quoted string; or else it returns `s` without any change.
|
||||
//
|
||||
// The meaning of a `word` can be considered as a column name.
|
||||
func (m *Model) QuoteWord(s string) string {
|
||||
return m.db.GetCore().QuoteWord(s)
|
||||
}
|
||||
|
||||
// TableFields retrieves and returns the fields' information of specified table of current
|
||||
// schema.
|
||||
//
|
||||
// Also see DriverMysql.TableFields.
|
||||
func (m *Model) TableFields(tableStr string, schema ...string) (fields map[string]*TableField, err error) {
|
||||
var (
|
||||
ctx = m.GetCtx()
|
||||
usedTable = m.db.GetCore().guessPrimaryTableName(tableStr)
|
||||
usedSchema = gutil.GetOrDefaultStr(m.schema, schema...)
|
||||
)
|
||||
// Sharding feature.
|
||||
usedSchema, err = m.getActualSchema(ctx, usedSchema)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
usedTable, err = m.getActualTable(ctx, usedTable)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return m.db.TableFields(ctx, usedTable, usedSchema)
|
||||
}
|
||||
|
||||
// getModel creates and returns a cloned model of current model if `safe` is true, or else it returns
|
||||
// the current model.
|
||||
func (m *Model) getModel() *Model {
|
||||
if !m.safe {
|
||||
return m
|
||||
} else {
|
||||
return m.Clone()
|
||||
}
|
||||
}
|
||||
|
||||
// mappingAndFilterToTableFields mappings and changes given field name to really table field name.
|
||||
// Eg:
|
||||
// ID -> id
|
||||
// NICK_Name -> nickname.
|
||||
func (m *Model) mappingAndFilterToTableFields(table string, fields []any, filter bool) []any {
|
||||
var fieldsTable = table
|
||||
if fieldsTable != "" {
|
||||
hasTable, _ := m.db.GetCore().HasTable(fieldsTable)
|
||||
if !hasTable {
|
||||
if fieldsTable != m.tablesInit {
|
||||
// Table/alias unknown (e.g., FieldsPrefix called before LeftJoin), skip filtering.
|
||||
return fields
|
||||
}
|
||||
// HasTable cache miss for main table, fallback to use main table for field mapping.
|
||||
fieldsTable = m.tablesInit
|
||||
}
|
||||
}
|
||||
if fieldsTable == "" {
|
||||
fieldsTable = m.tablesInit
|
||||
}
|
||||
|
||||
fieldsMap, _ := m.TableFields(fieldsTable)
|
||||
if len(fieldsMap) == 0 {
|
||||
return fields
|
||||
}
|
||||
var outputFieldsArray = make([]any, 0)
|
||||
fieldsKeyMap := make(map[string]any, len(fieldsMap))
|
||||
for k := range fieldsMap {
|
||||
fieldsKeyMap[k] = nil
|
||||
}
|
||||
for _, field := range fields {
|
||||
var (
|
||||
fieldStr = gconv.String(field)
|
||||
inputFieldsArray []string
|
||||
)
|
||||
// Skip empty string fields.
|
||||
if fieldStr == "" {
|
||||
continue
|
||||
}
|
||||
switch {
|
||||
case gregex.IsMatchString(regularFieldNameWithoutDotRegPattern, fieldStr):
|
||||
inputFieldsArray = append(inputFieldsArray, fieldStr)
|
||||
|
||||
case gregex.IsMatchString(regularFieldNameWithCommaRegPattern, fieldStr):
|
||||
inputFieldsArray = gstr.SplitAndTrim(fieldStr, ",")
|
||||
|
||||
default:
|
||||
// Example:
|
||||
// user.id, user.name
|
||||
// replace(concat_ws(',',lpad(s.id, 6, '0'),s.name),',','') `code`
|
||||
outputFieldsArray = append(outputFieldsArray, field)
|
||||
continue
|
||||
}
|
||||
for _, inputField := range inputFieldsArray {
|
||||
if !gregex.IsMatchString(regularFieldNameWithoutDotRegPattern, inputField) {
|
||||
outputFieldsArray = append(outputFieldsArray, inputField)
|
||||
continue
|
||||
}
|
||||
if _, ok := fieldsKeyMap[inputField]; !ok {
|
||||
// Example:
|
||||
// id, name
|
||||
if foundKey, _ := gutil.MapPossibleItemByKey(fieldsKeyMap, inputField); foundKey != "" {
|
||||
outputFieldsArray = append(outputFieldsArray, foundKey)
|
||||
} else if !filter {
|
||||
outputFieldsArray = append(outputFieldsArray, inputField)
|
||||
}
|
||||
} else {
|
||||
outputFieldsArray = append(outputFieldsArray, inputField)
|
||||
}
|
||||
}
|
||||
}
|
||||
return outputFieldsArray
|
||||
}
|
||||
|
||||
// filterDataForInsertOrUpdate does filter feature with data for inserting/updating operations.
|
||||
// Note that, it does not filter list item, which is also type of map, for "omit empty" feature.
|
||||
func (m *Model) filterDataForInsertOrUpdate(data any) (any, error) {
|
||||
var err error
|
||||
switch value := data.(type) {
|
||||
case List:
|
||||
var omitEmpty bool
|
||||
if m.option&optionOmitNilDataList > 0 {
|
||||
omitEmpty = true
|
||||
}
|
||||
for k, item := range value {
|
||||
value[k], err = m.doMappingAndFilterForInsertOrUpdateDataMap(item, omitEmpty)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
}
|
||||
return value, nil
|
||||
|
||||
case Map:
|
||||
return m.doMappingAndFilterForInsertOrUpdateDataMap(value, true)
|
||||
|
||||
default:
|
||||
return data, nil
|
||||
}
|
||||
}
|
||||
|
||||
// doMappingAndFilterForInsertOrUpdateDataMap does the filter features for map.
|
||||
// Note that, it does not filter list item, which is also type of map, for "omit empty" feature.
|
||||
func (m *Model) doMappingAndFilterForInsertOrUpdateDataMap(data Map, allowOmitEmpty bool) (Map, error) {
|
||||
var (
|
||||
err error
|
||||
ctx = m.GetCtx()
|
||||
core = m.db.GetCore()
|
||||
schema = m.schema
|
||||
table = m.tablesInit
|
||||
)
|
||||
// Sharding feature.
|
||||
schema, err = m.getActualSchema(ctx, schema)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
table, err = m.getActualTable(ctx, table)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
data, err = core.mappingAndFilterData(
|
||||
ctx, schema, table, data, m.filter,
|
||||
)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
// Remove key-value pairs of which the value is nil.
|
||||
if allowOmitEmpty && m.option&optionOmitNilData > 0 {
|
||||
tempMap := make(Map, len(data))
|
||||
for k, v := range data {
|
||||
if empty.IsNil(v) {
|
||||
continue
|
||||
}
|
||||
tempMap[k] = v
|
||||
}
|
||||
data = tempMap
|
||||
}
|
||||
|
||||
// Remove key-value pairs of which the value is empty.
|
||||
if allowOmitEmpty && m.option&optionOmitEmptyData > 0 {
|
||||
tempMap := make(Map, len(data))
|
||||
for k, v := range data {
|
||||
if empty.IsEmpty(v) {
|
||||
continue
|
||||
}
|
||||
// Special type filtering.
|
||||
switch r := v.(type) {
|
||||
case time.Time:
|
||||
if r.IsZero() {
|
||||
continue
|
||||
}
|
||||
case *time.Time:
|
||||
if r.IsZero() {
|
||||
continue
|
||||
}
|
||||
case gtime.Time:
|
||||
if r.IsZero() {
|
||||
continue
|
||||
}
|
||||
case *gtime.Time:
|
||||
if r.IsZero() {
|
||||
continue
|
||||
}
|
||||
}
|
||||
tempMap[k] = v
|
||||
}
|
||||
data = tempMap
|
||||
}
|
||||
|
||||
if len(m.fields) > 0 {
|
||||
// Keep specified fields.
|
||||
var (
|
||||
fieldSet = gset.NewStrSetFrom(gconv.Strings(m.fields))
|
||||
charL, charR = m.db.GetChars()
|
||||
chars = charL + charR
|
||||
)
|
||||
fieldSet.Walk(func(item string) string {
|
||||
return gstr.Trim(item, chars)
|
||||
})
|
||||
for k := range data {
|
||||
k = gstr.Trim(k, chars)
|
||||
if !fieldSet.Contains(k) {
|
||||
delete(data, k)
|
||||
}
|
||||
}
|
||||
} else if len(m.fieldsEx) > 0 {
|
||||
// Filter specified fields.
|
||||
for _, v := range m.fieldsEx {
|
||||
delete(data, gconv.String(v))
|
||||
}
|
||||
}
|
||||
return data, nil
|
||||
}
|
||||
|
||||
// getLink returns the underlying database link object with configured `linkType` attribute.
|
||||
// The parameter `master` specifies whether using the master node if master-slave configured.
|
||||
func (m *Model) getLink(master bool) Link {
|
||||
if m.tx != nil {
|
||||
if sqlTx := m.tx.GetSqlTX(); sqlTx != nil {
|
||||
return &txLink{sqlTx}
|
||||
}
|
||||
}
|
||||
linkType := m.linkType
|
||||
if linkType == 0 {
|
||||
if master {
|
||||
linkType = linkTypeMaster
|
||||
} else {
|
||||
linkType = linkTypeSlave
|
||||
}
|
||||
}
|
||||
switch linkType {
|
||||
case linkTypeMaster:
|
||||
link, err := m.db.GetCore().MasterLink(m.schema)
|
||||
if err != nil {
|
||||
panic(err)
|
||||
}
|
||||
return link
|
||||
case linkTypeSlave:
|
||||
link, err := m.db.GetCore().SlaveLink(m.schema)
|
||||
if err != nil {
|
||||
panic(err)
|
||||
}
|
||||
return link
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// getPrimaryKey retrieves and returns the primary key name of the model table.
|
||||
// It parses m.tables to retrieve the primary table name, supporting m.tables like:
|
||||
// "user", "user u", "user as u, user_detail as ud".
|
||||
func (m *Model) getPrimaryKey() string {
|
||||
table := gstr.SplitAndTrim(m.tablesInit, " ")[0]
|
||||
tableFields, err := m.TableFields(table)
|
||||
if err != nil {
|
||||
return ""
|
||||
}
|
||||
for name, field := range tableFields {
|
||||
if gstr.ContainsI(field.Key, "pri") {
|
||||
return name
|
||||
}
|
||||
}
|
||||
return ""
|
||||
}
|
||||
|
||||
// mergeArguments creates and returns new arguments by merging `m.extraArgs` and given `args`.
|
||||
func (m *Model) mergeArguments(args []any) []any {
|
||||
if len(m.extraArgs) > 0 {
|
||||
newArgs := make([]any, len(m.extraArgs)+len(args))
|
||||
copy(newArgs, m.extraArgs)
|
||||
copy(newArgs[len(m.extraArgs):], args)
|
||||
return newArgs
|
||||
}
|
||||
return args
|
||||
}
|
||||
|
|
@ -0,0 +1,129 @@
|
|||
// Copyright GoFrame Author(https://goframe.org). All Rights Reserved.
|
||||
//
|
||||
// This Source Code Form is subject to the terms of the MIT License.
|
||||
// If a copy of the MIT was not distributed with this file,
|
||||
// You can obtain one at https://githum.com/gogf/gf.
|
||||
|
||||
package database
|
||||
|
||||
// callWhereBuilder creates and returns a new Model, and sets its WhereBuilder if current Model is safe.
|
||||
// It sets the WhereBuilder and returns current Model directly if it is not safe.
|
||||
func (m *Model) callWhereBuilder(builder *WhereBuilder) *Model {
|
||||
model := m.getModel()
|
||||
model.whereBuilder = builder
|
||||
return model
|
||||
}
|
||||
|
||||
// Where sets the condition statement for the builder. The parameter `where` can be type of
|
||||
// string/map/gmap/slice/struct/*struct, etc. Note that, if it's called more than one times,
|
||||
// multiple conditions will be joined into where statement using "AND".
|
||||
// See WhereBuilder.Where.
|
||||
func (m *Model) Where(where any, args ...any) *Model {
|
||||
return m.callWhereBuilder(m.whereBuilder.Where(where, args...))
|
||||
}
|
||||
|
||||
// Wheref builds condition string using fmt.Sprintf and arguments.
|
||||
// Note that if the number of `args` is more than the placeholder in `format`,
|
||||
// the extra `args` will be used as the where condition arguments of the Model.
|
||||
// See WhereBuilder.Wheref.
|
||||
func (m *Model) Wheref(format string, args ...any) *Model {
|
||||
return m.callWhereBuilder(m.whereBuilder.Wheref(format, args...))
|
||||
}
|
||||
|
||||
// WherePri does the same logic as Model.Where except that if the parameter `where`
|
||||
// is a single condition like int/string/float/slice, it treats the condition as the primary
|
||||
// key value. That is, if primary key is "id" and given `where` parameter as "123", the
|
||||
// WherePri function treats the condition as "id=123", but Model.Where treats the condition
|
||||
// as string "123".
|
||||
// See WhereBuilder.WherePri.
|
||||
func (m *Model) WherePri(where any, args ...any) *Model {
|
||||
return m.callWhereBuilder(m.whereBuilder.WherePri(where, args...))
|
||||
}
|
||||
|
||||
// WhereLT builds `column < value` statement.
|
||||
// See WhereBuilder.WhereLT.
|
||||
func (m *Model) WhereLT(column string, value any) *Model {
|
||||
return m.callWhereBuilder(m.whereBuilder.WhereLT(column, value))
|
||||
}
|
||||
|
||||
// WhereLTE builds `column <= value` statement.
|
||||
// See WhereBuilder.WhereLTE.
|
||||
func (m *Model) WhereLTE(column string, value any) *Model {
|
||||
return m.callWhereBuilder(m.whereBuilder.WhereLTE(column, value))
|
||||
}
|
||||
|
||||
// WhereGT builds `column > value` statement.
|
||||
// See WhereBuilder.WhereGT.
|
||||
func (m *Model) WhereGT(column string, value any) *Model {
|
||||
return m.callWhereBuilder(m.whereBuilder.WhereGT(column, value))
|
||||
}
|
||||
|
||||
// WhereGTE builds `column >= value` statement.
|
||||
// See WhereBuilder.WhereGTE.
|
||||
func (m *Model) WhereGTE(column string, value any) *Model {
|
||||
return m.callWhereBuilder(m.whereBuilder.WhereGTE(column, value))
|
||||
}
|
||||
|
||||
// WhereBetween builds `column BETWEEN min AND max` statement.
|
||||
// See WhereBuilder.WhereBetween.
|
||||
func (m *Model) WhereBetween(column string, min, max any) *Model {
|
||||
return m.callWhereBuilder(m.whereBuilder.WhereBetween(column, min, max))
|
||||
}
|
||||
|
||||
// WhereLike builds `column LIKE like` statement.
|
||||
// See WhereBuilder.WhereLike.
|
||||
func (m *Model) WhereLike(column string, like string) *Model {
|
||||
return m.callWhereBuilder(m.whereBuilder.WhereLike(column, like))
|
||||
}
|
||||
|
||||
// WhereIn builds `column IN (in)` statement.
|
||||
// See WhereBuilder.WhereIn.
|
||||
func (m *Model) WhereIn(column string, in any) *Model {
|
||||
return m.callWhereBuilder(m.whereBuilder.WhereIn(column, in))
|
||||
}
|
||||
|
||||
// WhereNull builds `columns[0] IS NULL AND columns[1] IS NULL ...` statement.
|
||||
// See WhereBuilder.WhereNull.
|
||||
func (m *Model) WhereNull(columns ...string) *Model {
|
||||
return m.callWhereBuilder(m.whereBuilder.WhereNull(columns...))
|
||||
}
|
||||
|
||||
// WhereNotBetween builds `column NOT BETWEEN min AND max` statement.
|
||||
// See WhereBuilder.WhereNotBetween.
|
||||
func (m *Model) WhereNotBetween(column string, min, max any) *Model {
|
||||
return m.callWhereBuilder(m.whereBuilder.WhereNotBetween(column, min, max))
|
||||
}
|
||||
|
||||
// WhereNotLike builds `column NOT LIKE like` statement.
|
||||
// See WhereBuilder.WhereNotLike.
|
||||
func (m *Model) WhereNotLike(column string, like any) *Model {
|
||||
return m.callWhereBuilder(m.whereBuilder.WhereNotLike(column, like))
|
||||
}
|
||||
|
||||
// WhereNot builds `column != value` statement.
|
||||
// See WhereBuilder.WhereNot.
|
||||
func (m *Model) WhereNot(column string, value any) *Model {
|
||||
return m.callWhereBuilder(m.whereBuilder.WhereNot(column, value))
|
||||
}
|
||||
|
||||
// WhereNotIn builds `column NOT IN (in)` statement.
|
||||
// See WhereBuilder.WhereNotIn.
|
||||
func (m *Model) WhereNotIn(column string, in any) *Model {
|
||||
return m.callWhereBuilder(m.whereBuilder.WhereNotIn(column, in))
|
||||
}
|
||||
|
||||
// WhereNotNull builds `columns[0] IS NOT NULL AND columns[1] IS NOT NULL ...` statement.
|
||||
// See WhereBuilder.WhereNotNull.
|
||||
func (m *Model) WhereNotNull(columns ...string) *Model {
|
||||
return m.callWhereBuilder(m.whereBuilder.WhereNotNull(columns...))
|
||||
}
|
||||
|
||||
// WhereExists builds `EXISTS (subQuery)` statement.
|
||||
func (m *Model) WhereExists(subQuery *Model) *Model {
|
||||
return m.callWhereBuilder(m.whereBuilder.WhereExists(subQuery))
|
||||
}
|
||||
|
||||
// WhereNotExists builds `NOT EXISTS (subQuery)` statement.
|
||||
func (m *Model) WhereNotExists(subQuery *Model) *Model {
|
||||
return m.callWhereBuilder(m.whereBuilder.WhereNotExists(subQuery))
|
||||
}
|
||||
|
|
@ -0,0 +1,91 @@
|
|||
// Copyright GoFrame Author(https://goframe.org). All Rights Reserved.
|
||||
//
|
||||
// This Source Code Form is subject to the terms of the MIT License.
|
||||
// If a copy of the MIT was not distributed with this file,
|
||||
// You can obtain one at https://github.com/gogf/gf.
|
||||
|
||||
package database
|
||||
|
||||
// WherePrefix performs as Where, but it adds prefix to each field in where statement.
|
||||
// See WhereBuilder.WherePrefix.
|
||||
func (m *Model) WherePrefix(prefix string, where any, args ...any) *Model {
|
||||
return m.callWhereBuilder(m.whereBuilder.WherePrefix(prefix, where, args...))
|
||||
}
|
||||
|
||||
// WherePrefixLT builds `prefix.column < value` statement.
|
||||
// See WhereBuilder.WherePrefixLT.
|
||||
func (m *Model) WherePrefixLT(prefix string, column string, value any) *Model {
|
||||
return m.callWhereBuilder(m.whereBuilder.WherePrefixLT(prefix, column, value))
|
||||
}
|
||||
|
||||
// WherePrefixLTE builds `prefix.column <= value` statement.
|
||||
// See WhereBuilder.WherePrefixLTE.
|
||||
func (m *Model) WherePrefixLTE(prefix string, column string, value any) *Model {
|
||||
return m.callWhereBuilder(m.whereBuilder.WherePrefixLTE(prefix, column, value))
|
||||
}
|
||||
|
||||
// WherePrefixGT builds `prefix.column > value` statement.
|
||||
// See WhereBuilder.WherePrefixGT.
|
||||
func (m *Model) WherePrefixGT(prefix string, column string, value any) *Model {
|
||||
return m.callWhereBuilder(m.whereBuilder.WherePrefixGT(prefix, column, value))
|
||||
}
|
||||
|
||||
// WherePrefixGTE builds `prefix.column >= value` statement.
|
||||
// See WhereBuilder.WherePrefixGTE.
|
||||
func (m *Model) WherePrefixGTE(prefix string, column string, value any) *Model {
|
||||
return m.callWhereBuilder(m.whereBuilder.WherePrefixGTE(prefix, column, value))
|
||||
}
|
||||
|
||||
// WherePrefixBetween builds `prefix.column BETWEEN min AND max` statement.
|
||||
// See WhereBuilder.WherePrefixBetween.
|
||||
func (m *Model) WherePrefixBetween(prefix string, column string, min, max any) *Model {
|
||||
return m.callWhereBuilder(m.whereBuilder.WherePrefixBetween(prefix, column, min, max))
|
||||
}
|
||||
|
||||
// WherePrefixLike builds `prefix.column LIKE like` statement.
|
||||
// See WhereBuilder.WherePrefixLike.
|
||||
func (m *Model) WherePrefixLike(prefix string, column string, like any) *Model {
|
||||
return m.callWhereBuilder(m.whereBuilder.WherePrefixLike(prefix, column, like))
|
||||
}
|
||||
|
||||
// WherePrefixIn builds `prefix.column IN (in)` statement.
|
||||
// See WhereBuilder.WherePrefixIn.
|
||||
func (m *Model) WherePrefixIn(prefix string, column string, in any) *Model {
|
||||
return m.callWhereBuilder(m.whereBuilder.WherePrefixIn(prefix, column, in))
|
||||
}
|
||||
|
||||
// WherePrefixNull builds `prefix.columns[0] IS NULL AND prefix.columns[1] IS NULL ...` statement.
|
||||
// See WhereBuilder.WherePrefixNull.
|
||||
func (m *Model) WherePrefixNull(prefix string, columns ...string) *Model {
|
||||
return m.callWhereBuilder(m.whereBuilder.WherePrefixNull(prefix, columns...))
|
||||
}
|
||||
|
||||
// WherePrefixNotBetween builds `prefix.column NOT BETWEEN min AND max` statement.
|
||||
// See WhereBuilder.WherePrefixNotBetween.
|
||||
func (m *Model) WherePrefixNotBetween(prefix string, column string, min, max any) *Model {
|
||||
return m.callWhereBuilder(m.whereBuilder.WherePrefixNotBetween(prefix, column, min, max))
|
||||
}
|
||||
|
||||
// WherePrefixNotLike builds `prefix.column NOT LIKE like` statement.
|
||||
// See WhereBuilder.WherePrefixNotLike.
|
||||
func (m *Model) WherePrefixNotLike(prefix string, column string, like any) *Model {
|
||||
return m.callWhereBuilder(m.whereBuilder.WherePrefixNotLike(prefix, column, like))
|
||||
}
|
||||
|
||||
// WherePrefixNot builds `prefix.column != value` statement.
|
||||
// See WhereBuilder.WherePrefixNot.
|
||||
func (m *Model) WherePrefixNot(prefix string, column string, value any) *Model {
|
||||
return m.callWhereBuilder(m.whereBuilder.WherePrefixNot(prefix, column, value))
|
||||
}
|
||||
|
||||
// WherePrefixNotIn builds `prefix.column NOT IN (in)` statement.
|
||||
// See WhereBuilder.WherePrefixNotIn.
|
||||
func (m *Model) WherePrefixNotIn(prefix string, column string, in any) *Model {
|
||||
return m.callWhereBuilder(m.whereBuilder.WherePrefixNotIn(prefix, column, in))
|
||||
}
|
||||
|
||||
// WherePrefixNotNull builds `prefix.columns[0] IS NOT NULL AND prefix.columns[1] IS NOT NULL ...` statement.
|
||||
// See WhereBuilder.WherePrefixNotNull.
|
||||
func (m *Model) WherePrefixNotNull(prefix string, columns ...string) *Model {
|
||||
return m.callWhereBuilder(m.whereBuilder.WherePrefixNotNull(prefix, columns...))
|
||||
}
|
||||
|
|
@ -0,0 +1,97 @@
|
|||
// Copyright GoFrame Author(https://goframe.org). All Rights Reserved.
|
||||
//
|
||||
// This Source Code Form is subject to the terms of the MIT License.
|
||||
// If a copy of the MIT was not distributed with this file,
|
||||
// You can obtain one at https://github.com/gogf/gf.
|
||||
|
||||
package database
|
||||
|
||||
// WhereOr adds "OR" condition to the where statement.
|
||||
// See WhereBuilder.WhereOr.
|
||||
func (m *Model) WhereOr(where any, args ...any) *Model {
|
||||
return m.callWhereBuilder(m.whereBuilder.WhereOr(where, args...))
|
||||
}
|
||||
|
||||
// WhereOrf builds `OR` condition string using fmt.Sprintf and arguments.
|
||||
// See WhereBuilder.WhereOrf.
|
||||
func (m *Model) WhereOrf(format string, args ...any) *Model {
|
||||
return m.callWhereBuilder(m.whereBuilder.WhereOrf(format, args...))
|
||||
}
|
||||
|
||||
// WhereOrLT builds `column < value` statement in `OR` conditions.
|
||||
// See WhereBuilder.WhereOrLT.
|
||||
func (m *Model) WhereOrLT(column string, value any) *Model {
|
||||
return m.callWhereBuilder(m.whereBuilder.WhereOrLT(column, value))
|
||||
}
|
||||
|
||||
// WhereOrLTE builds `column <= value` statement in `OR` conditions.
|
||||
// See WhereBuilder.WhereOrLTE.
|
||||
func (m *Model) WhereOrLTE(column string, value any) *Model {
|
||||
return m.callWhereBuilder(m.whereBuilder.WhereOrLTE(column, value))
|
||||
}
|
||||
|
||||
// WhereOrGT builds `column > value` statement in `OR` conditions.
|
||||
// See WhereBuilder.WhereOrGT.
|
||||
func (m *Model) WhereOrGT(column string, value any) *Model {
|
||||
return m.callWhereBuilder(m.whereBuilder.WhereOrGT(column, value))
|
||||
}
|
||||
|
||||
// WhereOrGTE builds `column >= value` statement in `OR` conditions.
|
||||
// See WhereBuilder.WhereOrGTE.
|
||||
func (m *Model) WhereOrGTE(column string, value any) *Model {
|
||||
return m.callWhereBuilder(m.whereBuilder.WhereOrGTE(column, value))
|
||||
}
|
||||
|
||||
// WhereOrBetween builds `column BETWEEN min AND max` statement in `OR` conditions.
|
||||
// See WhereBuilder.WhereOrBetween.
|
||||
func (m *Model) WhereOrBetween(column string, min, max any) *Model {
|
||||
return m.callWhereBuilder(m.whereBuilder.WhereOrBetween(column, min, max))
|
||||
}
|
||||
|
||||
// WhereOrLike builds `column LIKE like` statement in `OR` conditions.
|
||||
// See WhereBuilder.WhereOrLike.
|
||||
func (m *Model) WhereOrLike(column string, like any) *Model {
|
||||
return m.callWhereBuilder(m.whereBuilder.WhereOrLike(column, like))
|
||||
}
|
||||
|
||||
// WhereOrIn builds `column IN (in)` statement in `OR` conditions.
|
||||
// See WhereBuilder.WhereOrIn.
|
||||
func (m *Model) WhereOrIn(column string, in any) *Model {
|
||||
return m.callWhereBuilder(m.whereBuilder.WhereOrIn(column, in))
|
||||
}
|
||||
|
||||
// WhereOrNull builds `columns[0] IS NULL OR columns[1] IS NULL ...` statement in `OR` conditions.
|
||||
// See WhereBuilder.WhereOrNull.
|
||||
func (m *Model) WhereOrNull(columns ...string) *Model {
|
||||
return m.callWhereBuilder(m.whereBuilder.WhereOrNull(columns...))
|
||||
}
|
||||
|
||||
// WhereOrNotBetween builds `column NOT BETWEEN min AND max` statement in `OR` conditions.
|
||||
// See WhereBuilder.WhereOrNotBetween.
|
||||
func (m *Model) WhereOrNotBetween(column string, min, max any) *Model {
|
||||
return m.callWhereBuilder(m.whereBuilder.WhereOrNotBetween(column, min, max))
|
||||
}
|
||||
|
||||
// WhereOrNotLike builds `column NOT LIKE 'like'` statement in `OR` conditions.
|
||||
// See WhereBuilder.WhereOrNotLike.
|
||||
func (m *Model) WhereOrNotLike(column string, like any) *Model {
|
||||
return m.callWhereBuilder(m.whereBuilder.WhereOrNotLike(column, like))
|
||||
}
|
||||
|
||||
// WhereOrNot builds `column != value` statement.
|
||||
// See WhereBuilder.WhereOrNot.
|
||||
func (m *Model) WhereOrNot(column string, value any) *Model {
|
||||
return m.callWhereBuilder(m.whereBuilder.WhereOrNot(column, value))
|
||||
}
|
||||
|
||||
// WhereOrNotIn builds `column NOT IN (in)` statement.
|
||||
// See WhereBuilder.WhereOrNotIn.
|
||||
func (m *Model) WhereOrNotIn(column string, in any) *Model {
|
||||
return m.callWhereBuilder(m.whereBuilder.WhereOrNotIn(column, in))
|
||||
}
|
||||
|
||||
// WhereOrNotNull builds `columns[0] IS NOT NULL OR columns[1] IS NOT NULL ...` statement in `OR` conditions.
|
||||
// See WhereBuilder.WhereOrNotNull.
|
||||
func (m *Model) WhereOrNotNull(columns ...string) *Model {
|
||||
return m.callWhereBuilder(m.whereBuilder.WhereOrNotNull(columns...))
|
||||
}
|
||||
|
|
@ -0,0 +1,91 @@
|
|||
// Copyright GoFrame Author(https://goframe.org). All Rights Reserved.
|
||||
//
|
||||
// This Source Code Form is subject to the terms of the MIT License.
|
||||
// If a copy of the MIT was not distributed with this file,
|
||||
// You can obtain one at https://github.com/gogf/gf.
|
||||
|
||||
package database
|
||||
|
||||
// WhereOrPrefix performs as WhereOr, but it adds prefix to each field in where statement.
|
||||
// See WhereBuilder.WhereOrPrefix.
|
||||
func (m *Model) WhereOrPrefix(prefix string, where any, args ...any) *Model {
|
||||
return m.callWhereBuilder(m.whereBuilder.WhereOrPrefix(prefix, where, args...))
|
||||
}
|
||||
|
||||
// WhereOrPrefixLT builds `prefix.column < value` statement in `OR` conditions.
|
||||
// See WhereBuilder.WhereOrPrefixLT.
|
||||
func (m *Model) WhereOrPrefixLT(prefix string, column string, value any) *Model {
|
||||
return m.callWhereBuilder(m.whereBuilder.WhereOrPrefixLT(prefix, column, value))
|
||||
}
|
||||
|
||||
// WhereOrPrefixLTE builds `prefix.column <= value` statement in `OR` conditions.
|
||||
// See WhereBuilder.WhereOrPrefixLTE.
|
||||
func (m *Model) WhereOrPrefixLTE(prefix string, column string, value any) *Model {
|
||||
return m.callWhereBuilder(m.whereBuilder.WhereOrPrefixLTE(prefix, column, value))
|
||||
}
|
||||
|
||||
// WhereOrPrefixGT builds `prefix.column > value` statement in `OR` conditions.
|
||||
// See WhereBuilder.WhereOrPrefixGT.
|
||||
func (m *Model) WhereOrPrefixGT(prefix string, column string, value any) *Model {
|
||||
return m.callWhereBuilder(m.whereBuilder.WhereOrPrefixGT(prefix, column, value))
|
||||
}
|
||||
|
||||
// WhereOrPrefixGTE builds `prefix.column >= value` statement in `OR` conditions.
|
||||
// See WhereBuilder.WhereOrPrefixGTE.
|
||||
func (m *Model) WhereOrPrefixGTE(prefix string, column string, value any) *Model {
|
||||
return m.callWhereBuilder(m.whereBuilder.WhereOrPrefixGTE(prefix, column, value))
|
||||
}
|
||||
|
||||
// WhereOrPrefixBetween builds `prefix.column BETWEEN min AND max` statement in `OR` conditions.
|
||||
// See WhereBuilder.WhereOrPrefixBetween.
|
||||
func (m *Model) WhereOrPrefixBetween(prefix string, column string, min, max any) *Model {
|
||||
return m.callWhereBuilder(m.whereBuilder.WhereOrPrefixBetween(prefix, column, min, max))
|
||||
}
|
||||
|
||||
// WhereOrPrefixLike builds `prefix.column LIKE like` statement in `OR` conditions.
|
||||
// See WhereBuilder.WhereOrPrefixLike.
|
||||
func (m *Model) WhereOrPrefixLike(prefix string, column string, like any) *Model {
|
||||
return m.callWhereBuilder(m.whereBuilder.WhereOrPrefixLike(prefix, column, like))
|
||||
}
|
||||
|
||||
// WhereOrPrefixIn builds `prefix.column IN (in)` statement in `OR` conditions.
|
||||
// See WhereBuilder.WhereOrPrefixIn.
|
||||
func (m *Model) WhereOrPrefixIn(prefix string, column string, in any) *Model {
|
||||
return m.callWhereBuilder(m.whereBuilder.WhereOrPrefixIn(prefix, column, in))
|
||||
}
|
||||
|
||||
// WhereOrPrefixNull builds `prefix.columns[0] IS NULL OR prefix.columns[1] IS NULL ...` statement in `OR` conditions.
|
||||
// See WhereBuilder.WhereOrPrefixNull.
|
||||
func (m *Model) WhereOrPrefixNull(prefix string, columns ...string) *Model {
|
||||
return m.callWhereBuilder(m.whereBuilder.WhereOrPrefixNull(prefix, columns...))
|
||||
}
|
||||
|
||||
// WhereOrPrefixNotBetween builds `prefix.column NOT BETWEEN min AND max` statement in `OR` conditions.
|
||||
// See WhereBuilder.WhereOrPrefixNotBetween.
|
||||
func (m *Model) WhereOrPrefixNotBetween(prefix string, column string, min, max any) *Model {
|
||||
return m.callWhereBuilder(m.whereBuilder.WhereOrPrefixNotBetween(prefix, column, min, max))
|
||||
}
|
||||
|
||||
// WhereOrPrefixNotLike builds `prefix.column NOT LIKE like` statement in `OR` conditions.
|
||||
// See WhereBuilder.WhereOrPrefixNotLike.
|
||||
func (m *Model) WhereOrPrefixNotLike(prefix string, column string, like any) *Model {
|
||||
return m.callWhereBuilder(m.whereBuilder.WhereOrPrefixNotLike(prefix, column, like))
|
||||
}
|
||||
|
||||
// WhereOrPrefixNotIn builds `prefix.column NOT IN (in)` statement.
|
||||
// See WhereBuilder.WhereOrPrefixNotIn.
|
||||
func (m *Model) WhereOrPrefixNotIn(prefix string, column string, in any) *Model {
|
||||
return m.callWhereBuilder(m.whereBuilder.WhereOrPrefixNotIn(prefix, column, in))
|
||||
}
|
||||
|
||||
// WhereOrPrefixNotNull builds `prefix.columns[0] IS NOT NULL OR prefix.columns[1] IS NOT NULL ...` statement in `OR` conditions.
|
||||
// See WhereBuilder.WhereOrPrefixNotNull.
|
||||
func (m *Model) WhereOrPrefixNotNull(prefix string, columns ...string) *Model {
|
||||
return m.callWhereBuilder(m.whereBuilder.WhereOrPrefixNotNull(prefix, columns...))
|
||||
}
|
||||
|
||||
// WhereOrPrefixNot builds `prefix.column != value` statement in `OR` conditions.
|
||||
// See WhereBuilder.WhereOrPrefixNot.
|
||||
func (m *Model) WhereOrPrefixNot(prefix string, column string, value any) *Model {
|
||||
return m.callWhereBuilder(m.whereBuilder.WhereOrPrefixNot(prefix, column, value))
|
||||
}
|
||||
|
|
@ -0,0 +1,349 @@
|
|||
// Copyright GoFrame Author(https://goframe.org). All Rights Reserved.
|
||||
//
|
||||
// This Source Code Form is subject to the terms of the MIT License.
|
||||
// If a copy of the MIT was not distributed with this file,
|
||||
// You can obtain one at https://github.com/gogf/gf.
|
||||
|
||||
package database
|
||||
|
||||
import (
|
||||
"database/sql"
|
||||
"reflect"
|
||||
|
||||
"git.magicany.cc/black1552/gin-base/database/utils"
|
||||
"github.com/gogf/gf/v2/errors/gcode"
|
||||
"github.com/gogf/gf/v2/errors/gerror"
|
||||
"github.com/gogf/gf/v2/os/gstructs"
|
||||
"github.com/gogf/gf/v2/text/gstr"
|
||||
"github.com/gogf/gf/v2/util/gutil"
|
||||
)
|
||||
|
||||
// With creates and returns an ORM model based on metadata of given object.
|
||||
// It also enables model association operations feature on given `object`.
|
||||
// It can be called multiple times to add one or more objects to model and enable
|
||||
// their mode association operations feature.
|
||||
// For example, if given struct definition:
|
||||
//
|
||||
// type User struct {
|
||||
// gmeta.Meta `orm:"table:user"`
|
||||
// Id int `json:"id"`
|
||||
// Name string `json:"name"`
|
||||
// UserDetail *UserDetail `orm:"with:uid=id"`
|
||||
// UserScores []*UserScores `orm:"with:uid=id"`
|
||||
// }
|
||||
//
|
||||
// We can enable model association operations on attribute `UserDetail` and `UserScores` by:
|
||||
//
|
||||
// db.With(User{}.UserDetail).With(User{}.UserScores).Scan(xxx)
|
||||
//
|
||||
// Or:
|
||||
//
|
||||
// db.With(UserDetail{}).With(UserScores{}).Scan(xxx)
|
||||
//
|
||||
// Or:
|
||||
//
|
||||
// db.With(UserDetail{}, UserScores{}).Scan(xxx)
|
||||
func (m *Model) With(objects ...any) *Model {
|
||||
model := m.getModel()
|
||||
for _, object := range objects {
|
||||
if m.tables == "" {
|
||||
m.tablesInit = m.db.GetCore().QuotePrefixTableName(
|
||||
getTableNameFromOrmTag(object),
|
||||
)
|
||||
m.tables = m.tablesInit
|
||||
return model
|
||||
}
|
||||
model.withArray = append(model.withArray, object)
|
||||
}
|
||||
return model
|
||||
}
|
||||
|
||||
// WithAll enables model association operations on all objects that have "with" tag in the struct.
|
||||
func (m *Model) WithAll() *Model {
|
||||
model := m.getModel()
|
||||
model.withAll = true
|
||||
return model
|
||||
}
|
||||
|
||||
// doWithScanStruct handles model association operations feature for single struct.
|
||||
func (m *Model) doWithScanStruct(pointer any) error {
|
||||
if len(m.withArray) == 0 && !m.withAll {
|
||||
return nil
|
||||
}
|
||||
var (
|
||||
err error
|
||||
allowedTypeStrArray = make([]string, 0)
|
||||
)
|
||||
currentStructFieldMap, err := gstructs.FieldMap(gstructs.FieldMapInput{
|
||||
Pointer: pointer,
|
||||
PriorityTagArray: nil,
|
||||
RecursiveOption: gstructs.RecursiveOptionEmbeddedNoTag,
|
||||
})
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
// It checks the with array and automatically calls the ScanList to complete association querying.
|
||||
if !m.withAll {
|
||||
for _, field := range currentStructFieldMap {
|
||||
for _, withItem := range m.withArray {
|
||||
withItemReflectValueType, err := gstructs.StructType(withItem)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
var (
|
||||
fieldTypeStr = gstr.TrimAll(field.Type().String(), "*[]")
|
||||
withItemReflectValueTypeStr = gstr.TrimAll(withItemReflectValueType.String(), "*[]")
|
||||
)
|
||||
// It does select operation if the field type is in the specified "with" type array.
|
||||
if gstr.Compare(fieldTypeStr, withItemReflectValueTypeStr) == 0 {
|
||||
allowedTypeStrArray = append(allowedTypeStrArray, fieldTypeStr)
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
for _, field := range currentStructFieldMap {
|
||||
var (
|
||||
fieldTypeStr = gstr.TrimAll(field.Type().String(), "*[]")
|
||||
parsedTagOutput = m.parseWithTagInFieldStruct(field)
|
||||
)
|
||||
if parsedTagOutput.With == "" {
|
||||
continue
|
||||
}
|
||||
// It just handlers "with" type attribute struct, so it ignores other struct types.
|
||||
if !m.withAll && !gstr.InArray(allowedTypeStrArray, fieldTypeStr) {
|
||||
continue
|
||||
}
|
||||
array := gstr.SplitAndTrim(parsedTagOutput.With, "=")
|
||||
if len(array) == 1 {
|
||||
// It also supports using only one column name
|
||||
// if both tables associates using the same column name.
|
||||
array = append(array, parsedTagOutput.With)
|
||||
}
|
||||
var (
|
||||
model *Model
|
||||
fieldKeys []string
|
||||
relatedSourceName = array[0]
|
||||
relatedTargetName = array[1]
|
||||
relatedTargetValue any
|
||||
)
|
||||
// Find the value of related attribute from `pointer`.
|
||||
for attributeName, attributeValue := range currentStructFieldMap {
|
||||
if utils.EqualFoldWithoutChars(attributeName, relatedTargetName) {
|
||||
relatedTargetValue = attributeValue.Value.Interface()
|
||||
break
|
||||
}
|
||||
}
|
||||
if relatedTargetValue == nil {
|
||||
return gerror.NewCodef(
|
||||
gcode.CodeInvalidParameter,
|
||||
`cannot find the target related value of name "%s" in with tag "%s" for attribute "%s.%s"`,
|
||||
relatedTargetName, parsedTagOutput.With, reflect.TypeOf(pointer).Elem(), field.Name(),
|
||||
)
|
||||
}
|
||||
bindToReflectValue := field.Value
|
||||
if bindToReflectValue.Kind() != reflect.Pointer && bindToReflectValue.CanAddr() {
|
||||
bindToReflectValue = bindToReflectValue.Addr()
|
||||
}
|
||||
|
||||
if structFields, err := gstructs.Fields(gstructs.FieldsInput{
|
||||
Pointer: field.Value,
|
||||
RecursiveOption: gstructs.RecursiveOptionEmbeddedNoTag,
|
||||
}); err != nil {
|
||||
return err
|
||||
} else {
|
||||
fieldKeys = make([]string, len(structFields))
|
||||
for i, field := range structFields {
|
||||
fieldKeys[i] = field.Name()
|
||||
}
|
||||
}
|
||||
// Recursively with feature checks.
|
||||
model = m.db.With(field.Value).Hook(m.hookHandler)
|
||||
if m.withAll {
|
||||
model = model.WithAll()
|
||||
} else {
|
||||
model = model.With(m.withArray...)
|
||||
}
|
||||
if parsedTagOutput.Where != "" {
|
||||
model = model.Where(parsedTagOutput.Where)
|
||||
}
|
||||
if parsedTagOutput.Order != "" {
|
||||
model = model.Order(parsedTagOutput.Order)
|
||||
}
|
||||
if parsedTagOutput.Unscoped == "true" {
|
||||
model = model.Unscoped()
|
||||
}
|
||||
// With cache feature.
|
||||
if m.cacheEnabled && m.cacheOption.Name == "" {
|
||||
model = model.Cache(m.cacheOption)
|
||||
}
|
||||
err = model.Fields(fieldKeys).
|
||||
Where(relatedSourceName, relatedTargetValue).
|
||||
Scan(bindToReflectValue)
|
||||
// It ignores sql.ErrNoRows in with feature.
|
||||
if err != nil && err != sql.ErrNoRows {
|
||||
return err
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// doWithScanStructs handles model association operations feature for struct slice.
|
||||
// Also see doWithScanStruct.
|
||||
func (m *Model) doWithScanStructs(pointer any) error {
|
||||
if len(m.withArray) == 0 && !m.withAll {
|
||||
return nil
|
||||
}
|
||||
if v, ok := pointer.(reflect.Value); ok {
|
||||
pointer = v.Interface()
|
||||
}
|
||||
|
||||
var (
|
||||
err error
|
||||
allowedTypeStrArray = make([]string, 0)
|
||||
)
|
||||
currentStructFieldMap, err := gstructs.FieldMap(gstructs.FieldMapInput{
|
||||
Pointer: pointer,
|
||||
PriorityTagArray: nil,
|
||||
RecursiveOption: gstructs.RecursiveOptionEmbeddedNoTag,
|
||||
})
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
// It checks the with array and automatically calls the ScanList to complete association querying.
|
||||
if !m.withAll {
|
||||
for _, field := range currentStructFieldMap {
|
||||
for _, withItem := range m.withArray {
|
||||
withItemReflectValueType, err := gstructs.StructType(withItem)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
var (
|
||||
fieldTypeStr = gstr.TrimAll(field.Type().String(), "*[]")
|
||||
withItemReflectValueTypeStr = gstr.TrimAll(withItemReflectValueType.String(), "*[]")
|
||||
)
|
||||
// It does select operation if the field type is in the specified with type array.
|
||||
if gstr.Compare(fieldTypeStr, withItemReflectValueTypeStr) == 0 {
|
||||
allowedTypeStrArray = append(allowedTypeStrArray, fieldTypeStr)
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
for fieldName, field := range currentStructFieldMap {
|
||||
var (
|
||||
fieldTypeStr = gstr.TrimAll(field.Type().String(), "*[]")
|
||||
parsedTagOutput = m.parseWithTagInFieldStruct(field)
|
||||
)
|
||||
if parsedTagOutput.With == "" {
|
||||
continue
|
||||
}
|
||||
if !m.withAll && !gstr.InArray(allowedTypeStrArray, fieldTypeStr) {
|
||||
continue
|
||||
}
|
||||
array := gstr.SplitAndTrim(parsedTagOutput.With, "=")
|
||||
if len(array) == 1 {
|
||||
// It supports using only one column name
|
||||
// if both tables associates using the same column name.
|
||||
array = append(array, parsedTagOutput.With)
|
||||
}
|
||||
var (
|
||||
model *Model
|
||||
fieldKeys []string
|
||||
relatedSourceName = array[0]
|
||||
relatedTargetName = array[1]
|
||||
relatedTargetValue any
|
||||
)
|
||||
// Find the value slice of related attribute from `pointer`.
|
||||
for attributeName := range currentStructFieldMap {
|
||||
if utils.EqualFoldWithoutChars(attributeName, relatedTargetName) {
|
||||
relatedTargetValue = ListItemValuesUnique(pointer, attributeName)
|
||||
break
|
||||
}
|
||||
}
|
||||
if relatedTargetValue == nil {
|
||||
return gerror.NewCodef(
|
||||
gcode.CodeInvalidParameter,
|
||||
`cannot find the related value for attribute name "%s" of with tag "%s"`,
|
||||
relatedTargetName, parsedTagOutput.With,
|
||||
)
|
||||
}
|
||||
// If related value is empty, it does nothing but just returns.
|
||||
if gutil.IsEmpty(relatedTargetValue) {
|
||||
return nil
|
||||
}
|
||||
if structFields, err := gstructs.Fields(gstructs.FieldsInput{
|
||||
Pointer: field.Value,
|
||||
RecursiveOption: gstructs.RecursiveOptionEmbeddedNoTag,
|
||||
}); err != nil {
|
||||
return err
|
||||
} else {
|
||||
fieldKeys = make([]string, len(structFields))
|
||||
for i, field := range structFields {
|
||||
fieldKeys[i] = field.Name()
|
||||
}
|
||||
}
|
||||
// Recursively with feature checks.
|
||||
model = m.db.With(field.Value).Hook(m.hookHandler)
|
||||
if m.withAll {
|
||||
model = model.WithAll()
|
||||
} else {
|
||||
model = model.With(m.withArray...)
|
||||
}
|
||||
if parsedTagOutput.Where != "" {
|
||||
model = model.Where(parsedTagOutput.Where)
|
||||
}
|
||||
if parsedTagOutput.Order != "" {
|
||||
model = model.Order(parsedTagOutput.Order)
|
||||
}
|
||||
if parsedTagOutput.Unscoped == "true" {
|
||||
model = model.Unscoped()
|
||||
}
|
||||
// With cache feature.
|
||||
if m.cacheEnabled && m.cacheOption.Name == "" {
|
||||
model = model.Cache(m.cacheOption)
|
||||
}
|
||||
err = model.Fields(fieldKeys).
|
||||
Where(relatedSourceName, relatedTargetValue).
|
||||
ScanList(pointer, fieldName, parsedTagOutput.With)
|
||||
// It ignores sql.ErrNoRows in with feature.
|
||||
if err != nil && err != sql.ErrNoRows {
|
||||
return err
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
type parseWithTagInFieldStructOutput struct {
|
||||
With string
|
||||
Where string
|
||||
Order string
|
||||
Unscoped string
|
||||
}
|
||||
|
||||
func (m *Model) parseWithTagInFieldStruct(field gstructs.Field) (output parseWithTagInFieldStructOutput) {
|
||||
var (
|
||||
ormTag = field.Tag(OrmTagForStruct)
|
||||
data = make(map[string]string)
|
||||
array []string
|
||||
key string
|
||||
)
|
||||
for _, v := range gstr.SplitAndTrim(ormTag, ",") {
|
||||
array = gstr.Split(v, ":")
|
||||
if len(array) == 2 {
|
||||
key = array[0]
|
||||
data[key] = gstr.Trim(array[1])
|
||||
} else {
|
||||
if key == OrmTagForWithOrder {
|
||||
// supporting multiple order fields
|
||||
data[key] += "," + gstr.Trim(v)
|
||||
} else {
|
||||
data[key] += " " + gstr.Trim(v)
|
||||
}
|
||||
}
|
||||
}
|
||||
output.With = data[OrmTagForWith]
|
||||
output.Where = data[OrmTagForWithWhere]
|
||||
output.Order = data[OrmTagForWithOrder]
|
||||
output.Unscoped = data[OrmTagForWithUnscoped]
|
||||
return
|
||||
}
|
||||
|
|
@ -0,0 +1,67 @@
|
|||
// Copyright GoFrame Author(https://goframe.org). All Rights Reserved.
|
||||
//
|
||||
// This Source Code Form is subject to the terms of the MIT License.
|
||||
// If a copy of the MIT was not distributed with this file,
|
||||
// You can obtain one at https://github.com/gogf/gf.
|
||||
|
||||
package database
|
||||
|
||||
import (
|
||||
"database/sql"
|
||||
|
||||
"github.com/gogf/gf/v2/errors/gerror"
|
||||
)
|
||||
|
||||
// SqlResult is execution result for sql operations.
|
||||
// It also supports batch operation result for rowsAffected.
|
||||
type SqlResult struct {
|
||||
Result sql.Result
|
||||
Affected int64
|
||||
}
|
||||
|
||||
// MustGetAffected returns the affected rows count, if any error occurs, it panics.
|
||||
func (r *SqlResult) MustGetAffected() int64 {
|
||||
rows, err := r.RowsAffected()
|
||||
if err != nil {
|
||||
err = gerror.Wrap(err, `sql.Result.RowsAffected failed`)
|
||||
panic(err)
|
||||
}
|
||||
return rows
|
||||
}
|
||||
|
||||
// MustGetInsertId returns the last insert id, if any error occurs, it panics.
|
||||
func (r *SqlResult) MustGetInsertId() int64 {
|
||||
id, err := r.LastInsertId()
|
||||
if err != nil {
|
||||
err = gerror.Wrap(err, `sql.Result.LastInsertId failed`)
|
||||
panic(err)
|
||||
}
|
||||
return id
|
||||
}
|
||||
|
||||
// RowsAffected returns the number of rows affected by an
|
||||
// update, insert, or delete. Not every database or database
|
||||
// driver may support this.
|
||||
// Also, See sql.Result.
|
||||
func (r *SqlResult) RowsAffected() (int64, error) {
|
||||
if r.Affected > 0 {
|
||||
return r.Affected, nil
|
||||
}
|
||||
if r.Result == nil {
|
||||
return 0, nil
|
||||
}
|
||||
return r.Result.RowsAffected()
|
||||
}
|
||||
|
||||
// LastInsertId returns the integer generated by the database
|
||||
// in response to a command. Typically, this will be from an
|
||||
// "auto increment" column when inserting a new row. Not all
|
||||
// databases support this feature, and the syntax of such
|
||||
// statements varies.
|
||||
// Also, See sql.Result.
|
||||
func (r *SqlResult) LastInsertId() (int64, error) {
|
||||
if r.Result == nil {
|
||||
return 0, nil
|
||||
}
|
||||
return r.Result.LastInsertId()
|
||||
}
|
||||
|
|
@ -0,0 +1,30 @@
|
|||
// Copyright GoFrame Author(https://goframe.org). All Rights Reserved.
|
||||
//
|
||||
// This Source Code Form is subject to the terms of the MIT License.
|
||||
// If a copy of the MIT was not distributed with this file,
|
||||
// You can obtain one at https://github.com/gogf/gf.
|
||||
|
||||
package database
|
||||
|
||||
// Schema is a schema object from which it can then create a Model.
|
||||
type Schema struct {
|
||||
DB
|
||||
}
|
||||
|
||||
// Schema creates and returns a schema.
|
||||
func (c *Core) Schema(schema string) *Schema {
|
||||
// Do not change the schema of the original db,
|
||||
// it here creates a new db and changes its schema.
|
||||
db, err := NewByGroup(c.GetGroup())
|
||||
if err != nil {
|
||||
panic(err)
|
||||
}
|
||||
core := db.GetCore()
|
||||
// Different schema share some same objects.
|
||||
core.logger = c.logger
|
||||
core.cache = c.cache
|
||||
core.schema = schema
|
||||
return &Schema{
|
||||
DB: db,
|
||||
}
|
||||
}
|
||||
|
|
@ -0,0 +1,118 @@
|
|||
// Copyright GoFrame Author(https://goframe.org). All Rights Reserved.
|
||||
//
|
||||
// This Source Code Form is subject to the terms of the MIT License.
|
||||
// If a copy of the MIT was not distributed with this file,
|
||||
// You can obtain one at https://github.com/gogf/gf.
|
||||
|
||||
package database
|
||||
|
||||
import (
|
||||
"context"
|
||||
"database/sql"
|
||||
)
|
||||
|
||||
// Stmt is a prepared statement.
|
||||
// A Stmt is safe for concurrent use by multiple goroutines.
|
||||
//
|
||||
// If a Stmt is prepared on a Tx or Conn, it will be bound to a single
|
||||
// underlying connection forever. If the Tx or Conn closes, the Stmt will
|
||||
// become unusable and all operations will return an error.
|
||||
// If a Stmt is prepared on a DB, it will remain usable for the lifetime of the
|
||||
// DB. When the Stmt needs to execute on a new underlying connection, it will
|
||||
// prepare itself on the new connection automatically.
|
||||
type Stmt struct {
|
||||
*sql.Stmt
|
||||
core *Core
|
||||
link Link
|
||||
sql string
|
||||
}
|
||||
|
||||
// ExecContext executes a prepared statement with the given arguments and
|
||||
// returns a Result summarizing the effect of the statement.
|
||||
func (s *Stmt) ExecContext(ctx context.Context, args ...any) (sql.Result, error) {
|
||||
out, err := s.core.db.DoCommit(ctx, DoCommitInput{
|
||||
Stmt: s.Stmt,
|
||||
Link: s.link,
|
||||
Sql: s.sql,
|
||||
Args: args,
|
||||
Type: SqlTypeStmtExecContext,
|
||||
IsTransaction: s.link.IsTransaction(),
|
||||
})
|
||||
return out.Result, err
|
||||
}
|
||||
|
||||
// QueryContext executes a prepared query statement with the given arguments
|
||||
// and returns the query results as a *Rows.
|
||||
func (s *Stmt) QueryContext(ctx context.Context, args ...any) (*sql.Rows, error) {
|
||||
out, err := s.core.db.DoCommit(ctx, DoCommitInput{
|
||||
Stmt: s.Stmt,
|
||||
Link: s.link,
|
||||
Sql: s.sql,
|
||||
Args: args,
|
||||
Type: SqlTypeStmtQueryContext,
|
||||
IsTransaction: s.link.IsTransaction(),
|
||||
})
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if out.RawResult != nil {
|
||||
return out.RawResult.(*sql.Rows), err
|
||||
}
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
// QueryRowContext executes a prepared query statement with the given arguments.
|
||||
// If an error occurs during the execution of the statement, that error will
|
||||
// be returned by a call to Scan on the returned *Row, which is always non-nil.
|
||||
// If the query selects no rows, the *Row's Scan will return ErrNoRows.
|
||||
// Otherwise, the *Row's Scan scans the first selected row and discards
|
||||
// the rest.
|
||||
func (s *Stmt) QueryRowContext(ctx context.Context, args ...any) *sql.Row {
|
||||
out, err := s.core.db.DoCommit(ctx, DoCommitInput{
|
||||
Stmt: s.Stmt,
|
||||
Link: s.link,
|
||||
Sql: s.sql,
|
||||
Args: args,
|
||||
Type: SqlTypeStmtQueryContext,
|
||||
IsTransaction: s.link.IsTransaction(),
|
||||
})
|
||||
if err != nil {
|
||||
panic(err)
|
||||
}
|
||||
if out.RawResult != nil {
|
||||
return out.RawResult.(*sql.Row)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// Exec executes a prepared statement with the given arguments and
|
||||
// returns a Result summarizing the effect of the statement.
|
||||
func (s *Stmt) Exec(args ...any) (sql.Result, error) {
|
||||
return s.ExecContext(context.Background(), args...)
|
||||
}
|
||||
|
||||
// Query executes a prepared query statement with the given arguments
|
||||
// and returns the query results as a *Rows.
|
||||
func (s *Stmt) Query(args ...any) (*sql.Rows, error) {
|
||||
return s.QueryContext(context.Background(), args...)
|
||||
}
|
||||
|
||||
// QueryRow executes a prepared query statement with the given arguments.
|
||||
// If an error occurs during the execution of the statement, that error will
|
||||
// be returned by a call to Scan on the returned *Row, which is always non-nil.
|
||||
// If the query selects no rows, the *Row's Scan will return ErrNoRows.
|
||||
// Otherwise, the *Row's Scan scans the first selected row and discards
|
||||
// the rest.
|
||||
//
|
||||
// Example usage:
|
||||
//
|
||||
// var name string
|
||||
// err := nameByUseridStmt.QueryRow(id).Scan(&name)
|
||||
func (s *Stmt) QueryRow(args ...any) *sql.Row {
|
||||
return s.QueryRowContext(context.Background(), args...)
|
||||
}
|
||||
|
||||
// Close closes the statement.
|
||||
func (s *Stmt) Close() error {
|
||||
return s.Stmt.Close()
|
||||
}
|
||||
|
|
@ -0,0 +1,65 @@
|
|||
// Copyright GoFrame Author(https://goframe.org). All Rights Reserved.
|
||||
//
|
||||
// This Source Code Form is subject to the terms of the MIT License.
|
||||
// If a copy of the MIT was not distributed with this file,
|
||||
// You can obtain one at https://github.com/gogf/gf.
|
||||
|
||||
package database
|
||||
|
||||
import (
|
||||
"database/sql"
|
||||
|
||||
"git.magicany.cc/black1552/gin-base/database/empty"
|
||||
"github.com/gogf/gf/v2/container/gmap"
|
||||
"github.com/gogf/gf/v2/encoding/gjson"
|
||||
"github.com/gogf/gf/v2/util/gconv"
|
||||
)
|
||||
|
||||
// Json converts `r` to JSON format content.
|
||||
func (r Record) Json() string {
|
||||
content, _ := gjson.New(r.Map()).ToJsonString()
|
||||
return content
|
||||
}
|
||||
|
||||
// Xml converts `r` to XML format content.
|
||||
func (r Record) Xml(rootTag ...string) string {
|
||||
content, _ := gjson.New(r.Map()).ToXmlString(rootTag...)
|
||||
return content
|
||||
}
|
||||
|
||||
// Map converts `r` to map[string]any.
|
||||
func (r Record) Map() Map {
|
||||
m := make(map[string]any)
|
||||
for k, v := range r {
|
||||
m[k] = v.Val()
|
||||
}
|
||||
return m
|
||||
}
|
||||
|
||||
// GMap converts `r` to a gmap.
|
||||
func (r Record) GMap() *gmap.StrAnyMap {
|
||||
return gmap.NewStrAnyMapFrom(r.Map())
|
||||
}
|
||||
|
||||
// Struct converts `r` to a struct.
|
||||
// Note that the parameter `pointer` should be type of *struct/**struct.
|
||||
//
|
||||
// Note that it returns sql.ErrNoRows if `r` is empty.
|
||||
func (r Record) Struct(pointer any) error {
|
||||
// If the record is empty, it returns error.
|
||||
if r.IsEmpty() {
|
||||
if !empty.IsNil(pointer, true) {
|
||||
return sql.ErrNoRows
|
||||
}
|
||||
return nil
|
||||
}
|
||||
return converter.Struct(r, pointer, gconv.StructOption{
|
||||
PriorityTag: OrmTagForStruct,
|
||||
ContinueOnError: true,
|
||||
})
|
||||
}
|
||||
|
||||
// IsEmpty checks and returns whether `r` is empty.
|
||||
func (r Record) IsEmpty() bool {
|
||||
return len(r) == 0
|
||||
}
|
||||
|
|
@ -0,0 +1,214 @@
|
|||
// Copyright GoFrame Author(https://goframe.org). All Rights Reserved.
|
||||
//
|
||||
// This Source Code Form is subject to the terms of the MIT License.
|
||||
// If a copy of the MIT was not distributed with this file,
|
||||
// You can obtain one at https://github.com/gogf/gf.
|
||||
|
||||
package database
|
||||
|
||||
import (
|
||||
"database/sql"
|
||||
"math"
|
||||
|
||||
"git.magicany.cc/black1552/gin-base/database/empty"
|
||||
"github.com/gogf/gf/v2/container/gvar"
|
||||
"github.com/gogf/gf/v2/encoding/gjson"
|
||||
"github.com/gogf/gf/v2/util/gconv"
|
||||
)
|
||||
|
||||
// IsEmpty checks and returns whether `r` is empty.
|
||||
func (r Result) IsEmpty() bool {
|
||||
return r == nil || r.Len() == 0
|
||||
}
|
||||
|
||||
// Len returns the length of result list.
|
||||
func (r Result) Len() int {
|
||||
return len(r)
|
||||
}
|
||||
|
||||
// Size is alias of function Len.
|
||||
func (r Result) Size() int {
|
||||
return r.Len()
|
||||
}
|
||||
|
||||
// Chunk splits a Result into multiple Results,
|
||||
// the size of each array is determined by `size`.
|
||||
// The last chunk may contain less than size elements.
|
||||
func (r Result) Chunk(size int) []Result {
|
||||
if size < 1 {
|
||||
return nil
|
||||
}
|
||||
length := len(r)
|
||||
chunks := int(math.Ceil(float64(length) / float64(size)))
|
||||
var n []Result
|
||||
for i, end := 0, 0; chunks > 0; chunks-- {
|
||||
end = (i + 1) * size
|
||||
if end > length {
|
||||
end = length
|
||||
}
|
||||
n = append(n, r[i*size:end])
|
||||
i++
|
||||
}
|
||||
return n
|
||||
}
|
||||
|
||||
// Json converts `r` to JSON format content.
|
||||
func (r Result) Json() string {
|
||||
content, _ := gjson.New(r.List()).ToJsonString()
|
||||
return content
|
||||
}
|
||||
|
||||
// Xml converts `r` to XML format content.
|
||||
func (r Result) Xml(rootTag ...string) string {
|
||||
content, _ := gjson.New(r.List()).ToXmlString(rootTag...)
|
||||
return content
|
||||
}
|
||||
|
||||
// List converts `r` to a List.
|
||||
func (r Result) List() List {
|
||||
list := make(List, len(r))
|
||||
for k, v := range r {
|
||||
list[k] = v.Map()
|
||||
}
|
||||
return list
|
||||
}
|
||||
|
||||
// Array retrieves and returns specified column values as slice.
|
||||
// The parameter `field` is optional is the column field is only one.
|
||||
// The default `field` is the first field name of the first item in `Result` if parameter `field` is not given.
|
||||
func (r Result) Array(field ...string) Array {
|
||||
array := make(Array, len(r))
|
||||
if len(r) == 0 {
|
||||
return array
|
||||
}
|
||||
key := ""
|
||||
if len(field) > 0 && field[0] != "" {
|
||||
key = field[0]
|
||||
} else {
|
||||
for k := range r[0] {
|
||||
key = k
|
||||
break
|
||||
}
|
||||
}
|
||||
for k, v := range r {
|
||||
array[k] = v[key]
|
||||
}
|
||||
return array
|
||||
}
|
||||
|
||||
// MapKeyValue converts `r` to a map[string]Value of which key is specified by `key`.
|
||||
// Note that the item value may be type of slice.
|
||||
func (r Result) MapKeyValue(key string) map[string]Value {
|
||||
var (
|
||||
s string
|
||||
m = make(map[string]Value)
|
||||
tempMap = make(map[string][]any)
|
||||
hasMultiValues bool
|
||||
)
|
||||
for _, item := range r {
|
||||
if k, ok := item[key]; ok {
|
||||
s = k.String()
|
||||
tempMap[s] = append(tempMap[s], item)
|
||||
if len(tempMap[s]) > 1 {
|
||||
hasMultiValues = true
|
||||
}
|
||||
}
|
||||
}
|
||||
for k, v := range tempMap {
|
||||
if hasMultiValues {
|
||||
m[k] = gvar.New(v)
|
||||
} else {
|
||||
m[k] = gvar.New(v[0])
|
||||
}
|
||||
}
|
||||
return m
|
||||
}
|
||||
|
||||
// MapKeyStr converts `r` to a map[string]Map of which key is specified by `key`.
|
||||
func (r Result) MapKeyStr(key string) map[string]Map {
|
||||
m := make(map[string]Map)
|
||||
for _, item := range r {
|
||||
if v, ok := item[key]; ok {
|
||||
m[v.String()] = item.Map()
|
||||
}
|
||||
}
|
||||
return m
|
||||
}
|
||||
|
||||
// MapKeyInt converts `r` to a map[int]Map of which key is specified by `key`.
|
||||
func (r Result) MapKeyInt(key string) map[int]Map {
|
||||
m := make(map[int]Map)
|
||||
for _, item := range r {
|
||||
if v, ok := item[key]; ok {
|
||||
m[v.Int()] = item.Map()
|
||||
}
|
||||
}
|
||||
return m
|
||||
}
|
||||
|
||||
// MapKeyUint converts `r` to a map[uint]Map of which key is specified by `key`.
|
||||
func (r Result) MapKeyUint(key string) map[uint]Map {
|
||||
m := make(map[uint]Map)
|
||||
for _, item := range r {
|
||||
if v, ok := item[key]; ok {
|
||||
m[v.Uint()] = item.Map()
|
||||
}
|
||||
}
|
||||
return m
|
||||
}
|
||||
|
||||
// RecordKeyStr converts `r` to a map[string]Record of which key is specified by `key`.
|
||||
func (r Result) RecordKeyStr(key string) map[string]Record {
|
||||
m := make(map[string]Record)
|
||||
for _, item := range r {
|
||||
if v, ok := item[key]; ok {
|
||||
m[v.String()] = item
|
||||
}
|
||||
}
|
||||
return m
|
||||
}
|
||||
|
||||
// RecordKeyInt converts `r` to a map[int]Record of which key is specified by `key`.
|
||||
func (r Result) RecordKeyInt(key string) map[int]Record {
|
||||
m := make(map[int]Record)
|
||||
for _, item := range r {
|
||||
if v, ok := item[key]; ok {
|
||||
m[v.Int()] = item
|
||||
}
|
||||
}
|
||||
return m
|
||||
}
|
||||
|
||||
// RecordKeyUint converts `r` to a map[uint]Record of which key is specified by `key`.
|
||||
func (r Result) RecordKeyUint(key string) map[uint]Record {
|
||||
m := make(map[uint]Record)
|
||||
for _, item := range r {
|
||||
if v, ok := item[key]; ok {
|
||||
m[v.Uint()] = item
|
||||
}
|
||||
}
|
||||
return m
|
||||
}
|
||||
|
||||
// Structs converts `r` to struct slice.
|
||||
// Note that the parameter `pointer` should be type of *[]struct/*[]*struct.
|
||||
func (r Result) Structs(pointer any) (err error) {
|
||||
// If the result is empty and the target pointer is not empty, it returns error.
|
||||
if r.IsEmpty() {
|
||||
if !empty.IsEmpty(pointer, true) {
|
||||
return sql.ErrNoRows
|
||||
}
|
||||
return nil
|
||||
}
|
||||
var (
|
||||
sliceOption = gconv.SliceOption{ContinueOnError: true}
|
||||
structOption = gconv.StructOption{
|
||||
PriorityTag: OrmTagForStruct,
|
||||
ContinueOnError: true,
|
||||
}
|
||||
)
|
||||
return converter.Structs(r, pointer, gconv.StructsOption{
|
||||
SliceOption: sliceOption,
|
||||
StructOption: structOption,
|
||||
})
|
||||
}
|
||||
|
|
@ -0,0 +1,512 @@
|
|||
// Copyright GoFrame Author(https://goframe.org). All Rights Reserved.
|
||||
//
|
||||
// This Source Code Form is subject to the terms of the MIT License.
|
||||
// If a copy of the MIT was not distributed with this file,
|
||||
// You can obtain one at https://github.com/gogf/gf.
|
||||
|
||||
package database
|
||||
|
||||
import (
|
||||
"database/sql"
|
||||
"reflect"
|
||||
|
||||
"github.com/gogf/gf/v2/errors/gcode"
|
||||
"github.com/gogf/gf/v2/errors/gerror"
|
||||
"github.com/gogf/gf/v2/os/gstructs"
|
||||
"github.com/gogf/gf/v2/text/gstr"
|
||||
"github.com/gogf/gf/v2/util/gconv"
|
||||
"github.com/gogf/gf/v2/util/gutil"
|
||||
)
|
||||
|
||||
// ScanList converts `r` to struct slice which contains other complex struct attributes.
|
||||
// Note that the parameter `structSlicePointer` should be type of *[]struct/*[]*struct.
|
||||
//
|
||||
// Usage example 1: Normal attribute struct relation:
|
||||
//
|
||||
// type EntityUser struct {
|
||||
// Uid int
|
||||
// Name string
|
||||
// }
|
||||
//
|
||||
// type EntityUserDetail struct {
|
||||
// Uid int
|
||||
// Address string
|
||||
// }
|
||||
//
|
||||
// type EntityUserScores struct {
|
||||
// Id int
|
||||
// Uid int
|
||||
// Score int
|
||||
// Course string
|
||||
// }
|
||||
//
|
||||
// type Entity struct {
|
||||
// User *EntityUser
|
||||
// UserDetail *EntityUserDetail
|
||||
// UserScores []*EntityUserScores
|
||||
// }
|
||||
//
|
||||
// var users []*Entity
|
||||
// ScanList(&users, "User")
|
||||
// ScanList(&users, "User", "uid")
|
||||
// ScanList(&users, "UserDetail", "User", "uid:Uid")
|
||||
// ScanList(&users, "UserScores", "User", "uid:Uid")
|
||||
// ScanList(&users, "UserScores", "User", "uid")
|
||||
//
|
||||
// Usage example 2: Embedded attribute struct relation:
|
||||
//
|
||||
// type EntityUser struct {
|
||||
// Uid int
|
||||
// Name string
|
||||
// }
|
||||
//
|
||||
// type EntityUserDetail struct {
|
||||
// Uid int
|
||||
// Address string
|
||||
// }
|
||||
//
|
||||
// type EntityUserScores struct {
|
||||
// Id int
|
||||
// Uid int
|
||||
// Score int
|
||||
// }
|
||||
//
|
||||
// type Entity struct {
|
||||
// EntityUser
|
||||
// UserDetail EntityUserDetail
|
||||
// UserScores []EntityUserScores
|
||||
// }
|
||||
//
|
||||
// var users []*Entity
|
||||
// ScanList(&users)
|
||||
// ScanList(&users, "UserDetail", "uid")
|
||||
// ScanList(&users, "UserScores", "uid")
|
||||
//
|
||||
// The parameters "User/UserDetail/UserScores" in the example codes specify the target attribute struct
|
||||
// that current result will be bound to.
|
||||
//
|
||||
// The "uid" in the example codes is the table field name of the result, and the "Uid" is the relational
|
||||
// struct attribute name - not the attribute name of the bound to target. In the example codes, it's attribute
|
||||
// name "Uid" of "User" of entity "Entity". It automatically calculates the HasOne/HasMany relationship with
|
||||
// given `relation` parameter.
|
||||
//
|
||||
// See the example or unit testing cases for clear understanding for this function.
|
||||
func (r Result) ScanList(structSlicePointer any, bindToAttrName string, relationAttrNameAndFields ...string) (err error) {
|
||||
out, err := checkGetSliceElementInfoForScanList(structSlicePointer, bindToAttrName)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
var (
|
||||
relationAttrName string
|
||||
relationFields string
|
||||
)
|
||||
switch len(relationAttrNameAndFields) {
|
||||
case 2:
|
||||
relationAttrName = relationAttrNameAndFields[0]
|
||||
relationFields = relationAttrNameAndFields[1]
|
||||
case 1:
|
||||
relationFields = relationAttrNameAndFields[0]
|
||||
}
|
||||
return doScanList(doScanListInput{
|
||||
Model: nil,
|
||||
Result: r,
|
||||
StructSlicePointer: structSlicePointer,
|
||||
StructSliceValue: out.SliceReflectValue,
|
||||
BindToAttrName: bindToAttrName,
|
||||
RelationAttrName: relationAttrName,
|
||||
RelationFields: relationFields,
|
||||
})
|
||||
}
|
||||
|
||||
type checkGetSliceElementInfoForScanListOutput struct {
|
||||
SliceReflectValue reflect.Value
|
||||
BindToAttrType reflect.Type
|
||||
}
|
||||
|
||||
func checkGetSliceElementInfoForScanList(structSlicePointer any, bindToAttrName string) (out *checkGetSliceElementInfoForScanListOutput, err error) {
|
||||
// Necessary checks for parameters.
|
||||
if structSlicePointer == nil {
|
||||
return nil, gerror.NewCode(gcode.CodeInvalidParameter, `structSlicePointer cannot be nil`)
|
||||
}
|
||||
if bindToAttrName == "" {
|
||||
return nil, gerror.NewCode(gcode.CodeInvalidParameter, `bindToAttrName should not be empty`)
|
||||
}
|
||||
var (
|
||||
reflectType reflect.Type
|
||||
reflectValue = reflect.ValueOf(structSlicePointer)
|
||||
reflectKind = reflectValue.Kind()
|
||||
)
|
||||
if reflectKind == reflect.Interface {
|
||||
reflectValue = reflectValue.Elem()
|
||||
reflectKind = reflectValue.Kind()
|
||||
}
|
||||
if reflectKind != reflect.Pointer {
|
||||
return nil, gerror.NewCodef(
|
||||
gcode.CodeInvalidParameter,
|
||||
"structSlicePointer should be type of *[]struct/*[]*struct, but got: %s",
|
||||
reflect.TypeOf(structSlicePointer).String(),
|
||||
)
|
||||
}
|
||||
out = &checkGetSliceElementInfoForScanListOutput{
|
||||
SliceReflectValue: reflectValue.Elem(),
|
||||
}
|
||||
// Find the element struct type of the slice.
|
||||
reflectType = reflectValue.Type().Elem().Elem()
|
||||
reflectKind = reflectType.Kind()
|
||||
for reflectKind == reflect.Pointer {
|
||||
reflectType = reflectType.Elem()
|
||||
reflectKind = reflectType.Kind()
|
||||
}
|
||||
if reflectKind != reflect.Struct {
|
||||
err = gerror.NewCodef(
|
||||
gcode.CodeInvalidParameter,
|
||||
"structSlicePointer should be type of *[]struct/*[]*struct, but got: %s",
|
||||
reflect.TypeOf(structSlicePointer).String(),
|
||||
)
|
||||
return
|
||||
}
|
||||
// Find the target field by given name.
|
||||
structField, ok := reflectType.FieldByName(bindToAttrName)
|
||||
if !ok {
|
||||
return nil, gerror.NewCodef(
|
||||
gcode.CodeInvalidParameter,
|
||||
`field "%s" not found in element of "%s"`,
|
||||
bindToAttrName,
|
||||
reflect.TypeOf(structSlicePointer).String(),
|
||||
)
|
||||
}
|
||||
// Find the attribute struct type for ORM fields filtering.
|
||||
reflectType = structField.Type
|
||||
reflectKind = reflectType.Kind()
|
||||
for reflectKind == reflect.Pointer {
|
||||
reflectType = reflectType.Elem()
|
||||
reflectKind = reflectType.Kind()
|
||||
}
|
||||
if reflectKind == reflect.Slice || reflectKind == reflect.Array {
|
||||
reflectType = reflectType.Elem()
|
||||
// reflectKind = reflectType.Kind()
|
||||
}
|
||||
out.BindToAttrType = reflectType
|
||||
return
|
||||
}
|
||||
|
||||
type doScanListInput struct {
|
||||
Model *Model
|
||||
Result Result
|
||||
StructSlicePointer any
|
||||
StructSliceValue reflect.Value
|
||||
BindToAttrName string
|
||||
RelationAttrName string
|
||||
RelationFields string
|
||||
}
|
||||
|
||||
// doScanList converts `result` to struct slice which contains other complex struct attributes recursively.
|
||||
// The parameter `model` is used for recursively scanning purpose, which means, it can scan the attribute struct/structs recursively,
|
||||
// but it needs the Model for database accessing.
|
||||
// Note that the parameter `structSlicePointer` should be type of *[]struct/*[]*struct.
|
||||
func doScanList(in doScanListInput) (err error) {
|
||||
if in.Result.IsEmpty() {
|
||||
return nil
|
||||
}
|
||||
if in.BindToAttrName == "" {
|
||||
return gerror.NewCode(gcode.CodeInvalidParameter, `bindToAttrName should not be empty`)
|
||||
}
|
||||
|
||||
length := len(in.Result)
|
||||
if length == 0 {
|
||||
// The pointed slice is not empty.
|
||||
if in.StructSliceValue.Len() > 0 {
|
||||
// It here checks if it has struct item, which is already initialized.
|
||||
// It then returns error to warn the developer its empty and no conversion.
|
||||
if v := in.StructSliceValue.Index(0); v.Kind() != reflect.Pointer {
|
||||
return sql.ErrNoRows
|
||||
}
|
||||
}
|
||||
// Do nothing for empty struct slice.
|
||||
return nil
|
||||
}
|
||||
var (
|
||||
arrayValue reflect.Value // Like: []*Entity
|
||||
arrayItemType reflect.Type // Like: *Entity
|
||||
reflectType = reflect.TypeOf(in.StructSlicePointer)
|
||||
)
|
||||
if in.StructSliceValue.Len() > 0 {
|
||||
arrayValue = in.StructSliceValue
|
||||
} else {
|
||||
arrayValue = reflect.MakeSlice(reflectType.Elem(), length, length)
|
||||
}
|
||||
|
||||
// Slice element item.
|
||||
arrayItemType = arrayValue.Index(0).Type()
|
||||
|
||||
// Relation variables.
|
||||
var (
|
||||
relationDataMap map[string]Value
|
||||
relationFromFieldName string // Eg: relationKV: id:uid -> id
|
||||
relationBindToFieldName string // Eg: relationKV: id:uid -> uid
|
||||
)
|
||||
if len(in.RelationFields) > 0 {
|
||||
// The relation key string of table field name and attribute name
|
||||
// can be joined with char '=' or ':'.
|
||||
array := gstr.SplitAndTrim(in.RelationFields, "=")
|
||||
if len(array) == 1 {
|
||||
// Compatible with old splitting char ':'.
|
||||
array = gstr.SplitAndTrim(in.RelationFields, ":")
|
||||
}
|
||||
if len(array) == 1 {
|
||||
// The relation names are the same.
|
||||
array = []string{in.RelationFields, in.RelationFields}
|
||||
}
|
||||
if len(array) == 2 {
|
||||
// Defined table field to relation attribute name.
|
||||
// Like:
|
||||
// uid:Uid
|
||||
// uid:UserId
|
||||
relationFromFieldName = array[0]
|
||||
relationBindToFieldName = array[1]
|
||||
if key, _ := gutil.MapPossibleItemByKey(in.Result[0].Map(), relationFromFieldName); key == "" {
|
||||
return gerror.NewCodef(
|
||||
gcode.CodeInvalidParameter,
|
||||
`cannot find possible related table field name "%s" from given relation fields "%s"`,
|
||||
relationFromFieldName,
|
||||
in.RelationFields,
|
||||
)
|
||||
} else {
|
||||
relationFromFieldName = key
|
||||
}
|
||||
} else {
|
||||
return gerror.NewCode(
|
||||
gcode.CodeInvalidParameter,
|
||||
`parameter relationKV should be format of "ResultFieldName:BindToAttrName"`,
|
||||
)
|
||||
}
|
||||
if relationFromFieldName != "" {
|
||||
// Note that the value might be type of slice.
|
||||
relationDataMap = in.Result.MapKeyValue(relationFromFieldName)
|
||||
}
|
||||
if len(relationDataMap) == 0 {
|
||||
return gerror.NewCodef(
|
||||
gcode.CodeInvalidParameter,
|
||||
`cannot find the relation data map, maybe invalid relation fields given "%v"`,
|
||||
in.RelationFields,
|
||||
)
|
||||
}
|
||||
}
|
||||
// Bind to target attribute.
|
||||
var (
|
||||
ok bool
|
||||
bindToAttrValue reflect.Value
|
||||
bindToAttrKind reflect.Kind
|
||||
bindToAttrType reflect.Type
|
||||
bindToAttrField reflect.StructField
|
||||
)
|
||||
if arrayItemType.Kind() == reflect.Pointer {
|
||||
if bindToAttrField, ok = arrayItemType.Elem().FieldByName(in.BindToAttrName); !ok {
|
||||
return gerror.NewCodef(
|
||||
gcode.CodeInvalidParameter,
|
||||
`invalid parameter bindToAttrName: cannot find attribute with name "%s" from slice element`,
|
||||
in.BindToAttrName,
|
||||
)
|
||||
}
|
||||
} else {
|
||||
if bindToAttrField, ok = arrayItemType.FieldByName(in.BindToAttrName); !ok {
|
||||
return gerror.NewCodef(
|
||||
gcode.CodeInvalidParameter,
|
||||
`invalid parameter bindToAttrName: cannot find attribute with name "%s" from slice element`,
|
||||
in.BindToAttrName,
|
||||
)
|
||||
}
|
||||
}
|
||||
bindToAttrType = bindToAttrField.Type
|
||||
bindToAttrKind = bindToAttrType.Kind()
|
||||
|
||||
// Bind to relation conditions.
|
||||
var (
|
||||
relationFromAttrValue reflect.Value
|
||||
relationFromAttrField reflect.Value
|
||||
relationBindToFieldNameChecked bool
|
||||
)
|
||||
for i := 0; i < arrayValue.Len(); i++ {
|
||||
arrayElemValue := arrayValue.Index(i)
|
||||
// The FieldByName should be called on non-pointer reflect.Value.
|
||||
if arrayElemValue.Kind() == reflect.Pointer {
|
||||
// Like: []*Entity
|
||||
arrayElemValue = arrayElemValue.Elem()
|
||||
if !arrayElemValue.IsValid() {
|
||||
// The element is nil, then create one and set it to the slice.
|
||||
// The "reflect.New(itemType.Elem())" creates a new element and returns the address of it.
|
||||
// For example:
|
||||
// reflect.New(itemType.Elem()) => *Entity
|
||||
// reflect.New(itemType.Elem()).Elem() => Entity
|
||||
arrayElemValue = reflect.New(arrayItemType.Elem()).Elem()
|
||||
arrayValue.Index(i).Set(arrayElemValue.Addr())
|
||||
}
|
||||
// } else {
|
||||
// Like: []Entity
|
||||
}
|
||||
bindToAttrValue = arrayElemValue.FieldByName(in.BindToAttrName)
|
||||
if in.RelationAttrName != "" {
|
||||
// Attribute value of current slice element.
|
||||
relationFromAttrValue = arrayElemValue.FieldByName(in.RelationAttrName)
|
||||
if relationFromAttrValue.Kind() == reflect.Pointer {
|
||||
relationFromAttrValue = relationFromAttrValue.Elem()
|
||||
}
|
||||
} else {
|
||||
// Current slice element.
|
||||
relationFromAttrValue = arrayElemValue
|
||||
}
|
||||
if len(relationDataMap) > 0 && !relationFromAttrValue.IsValid() {
|
||||
return gerror.NewCodef(gcode.CodeInvalidParameter, `invalid relation fields specified: "%v"`, in.RelationFields)
|
||||
}
|
||||
// Check and find possible bind to attribute name.
|
||||
if in.RelationFields != "" && !relationBindToFieldNameChecked {
|
||||
relationFromAttrField = relationFromAttrValue.FieldByName(relationBindToFieldName)
|
||||
if !relationFromAttrField.IsValid() {
|
||||
fieldMap, _ := gstructs.FieldMap(gstructs.FieldMapInput{
|
||||
Pointer: relationFromAttrValue,
|
||||
RecursiveOption: gstructs.RecursiveOptionEmbeddedNoTag,
|
||||
})
|
||||
if key, _ := gutil.MapPossibleItemByKey(gconv.Map(fieldMap), relationBindToFieldName); key == "" {
|
||||
return gerror.NewCodef(
|
||||
gcode.CodeInvalidParameter,
|
||||
`cannot find possible related attribute name "%s" from given relation fields "%s"`,
|
||||
relationBindToFieldName,
|
||||
in.RelationFields,
|
||||
)
|
||||
} else {
|
||||
relationBindToFieldName = key
|
||||
}
|
||||
}
|
||||
relationBindToFieldNameChecked = true
|
||||
}
|
||||
switch bindToAttrKind {
|
||||
case reflect.Array, reflect.Slice:
|
||||
if len(relationDataMap) > 0 {
|
||||
relationFromAttrField = relationFromAttrValue.FieldByName(relationBindToFieldName)
|
||||
if relationFromAttrField.IsValid() {
|
||||
results := make(Result, 0)
|
||||
for _, v := range relationDataMap[gconv.String(relationFromAttrField.Interface())].Slice() {
|
||||
results = append(results, v.(Record))
|
||||
}
|
||||
if err = results.Structs(bindToAttrValue.Addr()); err != nil {
|
||||
return err
|
||||
}
|
||||
// Recursively Scan.
|
||||
if in.Model != nil {
|
||||
if err = in.Model.doWithScanStructs(bindToAttrValue.Addr()); err != nil {
|
||||
return nil
|
||||
}
|
||||
}
|
||||
} else {
|
||||
// Maybe the attribute does not exist yet.
|
||||
return gerror.NewCodef(gcode.CodeInvalidParameter, `invalid relation fields specified: "%v"`, in.RelationFields)
|
||||
}
|
||||
} else {
|
||||
return gerror.NewCodef(
|
||||
gcode.CodeInvalidParameter,
|
||||
`relationKey should not be empty as field "%s" is slice`,
|
||||
in.BindToAttrName,
|
||||
)
|
||||
}
|
||||
|
||||
case reflect.Pointer:
|
||||
var element reflect.Value
|
||||
if bindToAttrValue.IsNil() {
|
||||
element = reflect.New(bindToAttrType.Elem()).Elem()
|
||||
} else {
|
||||
element = bindToAttrValue.Elem()
|
||||
}
|
||||
if len(relationDataMap) > 0 {
|
||||
relationFromAttrField = relationFromAttrValue.FieldByName(relationBindToFieldName)
|
||||
if relationFromAttrField.IsValid() {
|
||||
v := relationDataMap[gconv.String(relationFromAttrField.Interface())]
|
||||
if v == nil {
|
||||
// There's no relational data.
|
||||
continue
|
||||
}
|
||||
if v.IsSlice() {
|
||||
if err = v.Slice()[0].(Record).Struct(element); err != nil {
|
||||
return err
|
||||
}
|
||||
} else {
|
||||
if err = v.Val().(Record).Struct(element); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
} else {
|
||||
// Maybe the attribute does not exist yet.
|
||||
return gerror.NewCodef(gcode.CodeInvalidParameter, `invalid relation fields specified: "%v"`, in.RelationFields)
|
||||
}
|
||||
} else {
|
||||
if i >= len(in.Result) {
|
||||
// There's no relational data.
|
||||
continue
|
||||
}
|
||||
v := in.Result[i]
|
||||
if v == nil {
|
||||
// There's no relational data.
|
||||
continue
|
||||
}
|
||||
if err = v.Struct(element); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
// Recursively Scan.
|
||||
if in.Model != nil {
|
||||
if err = in.Model.doWithScanStruct(element); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
bindToAttrValue.Set(element.Addr())
|
||||
|
||||
case reflect.Struct:
|
||||
if len(relationDataMap) > 0 {
|
||||
relationFromAttrField = relationFromAttrValue.FieldByName(relationBindToFieldName)
|
||||
if relationFromAttrField.IsValid() {
|
||||
relationDataItem := relationDataMap[gconv.String(relationFromAttrField.Interface())]
|
||||
if relationDataItem == nil {
|
||||
// There's no relational data.
|
||||
continue
|
||||
}
|
||||
if relationDataItem.IsSlice() {
|
||||
if err = relationDataItem.Slice()[0].(Record).Struct(bindToAttrValue); err != nil {
|
||||
return err
|
||||
}
|
||||
} else {
|
||||
if err = relationDataItem.Val().(Record).Struct(bindToAttrValue); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
} else {
|
||||
// Maybe the attribute does not exist yet.
|
||||
return gerror.NewCodef(gcode.CodeInvalidParameter, `invalid relation fields specified: "%v"`, in.RelationFields)
|
||||
}
|
||||
} else {
|
||||
if i >= len(in.Result) {
|
||||
// There's no relational data.
|
||||
continue
|
||||
}
|
||||
relationDataItem := in.Result[i]
|
||||
if relationDataItem == nil {
|
||||
// There's no relational data.
|
||||
continue
|
||||
}
|
||||
if err = relationDataItem.Struct(bindToAttrValue); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
// Recursively Scan.
|
||||
if in.Model != nil {
|
||||
if err = in.Model.doWithScanStruct(bindToAttrValue); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
|
||||
default:
|
||||
return gerror.NewCodef(gcode.CodeInvalidParameter, `unsupported attribute type: %s`, bindToAttrKind.String())
|
||||
}
|
||||
}
|
||||
reflect.ValueOf(in.StructSlicePointer).Elem().Set(arrayValue)
|
||||
return nil
|
||||
}
|
||||
|
|
@ -1 +0,0 @@
|
|||
package database
|
||||
|
|
@ -0,0 +1,79 @@
|
|||
// Copyright GoFrame Author(https://goframe.org). All Rights Reserved.
|
||||
//
|
||||
// This Source Code Form is subject to the terms of the MIT License.
|
||||
// If a copy of the MIT was not distributed with this file,
|
||||
// You can obtain one at https://github.com/gogf/gf.
|
||||
|
||||
// Package instance provides instances management.
|
||||
//
|
||||
// Note that this package is not used for cache, as it has no cache expiration.
|
||||
package instance
|
||||
|
||||
import (
|
||||
"github.com/gogf/gf/v2/container/gmap"
|
||||
"github.com/gogf/gf/v2/encoding/ghash"
|
||||
)
|
||||
|
||||
const (
|
||||
groupNumber = 64
|
||||
)
|
||||
|
||||
var (
|
||||
groups = make([]*gmap.StrAnyMap, groupNumber)
|
||||
)
|
||||
|
||||
func init() {
|
||||
for i := 0; i < groupNumber; i++ {
|
||||
groups[i] = gmap.NewStrAnyMap(true)
|
||||
}
|
||||
}
|
||||
|
||||
func getGroup(key string) *gmap.StrAnyMap {
|
||||
return groups[int(ghash.DJB([]byte(key))%groupNumber)]
|
||||
}
|
||||
|
||||
// Get returns the instance by given name.
|
||||
func Get(name string) any {
|
||||
return getGroup(name).Get(name)
|
||||
}
|
||||
|
||||
// Set sets an instance to the instance manager with given name.
|
||||
func Set(name string, instance any) {
|
||||
getGroup(name).Set(name, instance)
|
||||
}
|
||||
|
||||
// GetOrSet returns the instance by name,
|
||||
// or set instance to the instance manager if it does not exist and returns this instance.
|
||||
func GetOrSet(name string, instance any) any {
|
||||
return getGroup(name).GetOrSet(name, instance)
|
||||
}
|
||||
|
||||
// GetOrSetFunc returns the instance by name,
|
||||
// or sets instance with returned value of callback function `f` if it does not exist
|
||||
// and then returns this instance.
|
||||
func GetOrSetFunc(name string, f func() any) any {
|
||||
return getGroup(name).GetOrSetFunc(name, f)
|
||||
}
|
||||
|
||||
// GetOrSetFuncLock returns the instance by name,
|
||||
// or sets instance with returned value of callback function `f` if it does not exist
|
||||
// and then returns this instance.
|
||||
//
|
||||
// GetOrSetFuncLock differs with GetOrSetFunc function is that it executes function `f`
|
||||
// with mutex.Lock of the hash map.
|
||||
func GetOrSetFuncLock(name string, f func() any) any {
|
||||
return getGroup(name).GetOrSetFuncLock(name, f)
|
||||
}
|
||||
|
||||
// SetIfNotExist sets `instance` to the map if the `name` does not exist, then returns true.
|
||||
// It returns false if `name` exists, and `instance` would be ignored.
|
||||
func SetIfNotExist(name string, instance any) bool {
|
||||
return getGroup(name).SetIfNotExist(name, instance)
|
||||
}
|
||||
|
||||
// Clear deletes all instances stored.
|
||||
func Clear() {
|
||||
for i := 0; i < groupNumber; i++ {
|
||||
groups[i].Clear()
|
||||
}
|
||||
}
|
||||
|
|
@ -0,0 +1,44 @@
|
|||
// Copyright GoFrame Author(https://goframe.org). All Rights Reserved.
|
||||
//
|
||||
// This Source Code Form is subject to the terms of the MIT License.
|
||||
// If a copy of the MIT was not distributed with this file,
|
||||
// You can obtain one at https://github.com/gogf/gf.
|
||||
|
||||
package instance_test
|
||||
|
||||
import (
|
||||
"testing"
|
||||
|
||||
"github.com/gogf/gf/v2/internal/instance"
|
||||
"github.com/gogf/gf/v2/test/gtest"
|
||||
)
|
||||
|
||||
func Test_SetGet(t *testing.T) {
|
||||
gtest.C(t, func(t *gtest.T) {
|
||||
instance.Set("test-user", 1)
|
||||
t.Assert(instance.Get("test-user"), 1)
|
||||
t.Assert(instance.Get("none-exists"), nil)
|
||||
})
|
||||
gtest.C(t, func(t *gtest.T) {
|
||||
t.Assert(instance.GetOrSet("test-1", 1), 1)
|
||||
t.Assert(instance.Get("test-1"), 1)
|
||||
})
|
||||
gtest.C(t, func(t *gtest.T) {
|
||||
t.Assert(instance.GetOrSetFunc("test-2", func() any {
|
||||
return 2
|
||||
}), 2)
|
||||
t.Assert(instance.Get("test-2"), 2)
|
||||
})
|
||||
gtest.C(t, func(t *gtest.T) {
|
||||
t.Assert(instance.GetOrSetFuncLock("test-3", func() any {
|
||||
return 3
|
||||
}), 3)
|
||||
t.Assert(instance.Get("test-3"), 3)
|
||||
})
|
||||
gtest.C(t, func(t *gtest.T) {
|
||||
t.Assert(instance.SetIfNotExist("test-4", 4), true)
|
||||
t.Assert(instance.Get("test-4"), 4)
|
||||
t.Assert(instance.SetIfNotExist("test-4", 5), false)
|
||||
t.Assert(instance.Get("test-4"), 4)
|
||||
})
|
||||
}
|
||||
|
|
@ -0,0 +1,125 @@
|
|||
// Copyright GoFrame Author(https://goframe.org). All Rights Reserved.
|
||||
//
|
||||
// This Source Code Form is subject to the terms of the MIT License.
|
||||
// If a copy of the MIT was not distributed with this file,
|
||||
// You can obtain one at https://github.com/gogf/gf.
|
||||
|
||||
// Package intlog provides internal logging for GoFrame development usage only.
|
||||
package intlog
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"context"
|
||||
"fmt"
|
||||
"path/filepath"
|
||||
"time"
|
||||
|
||||
"git.magicany.cc/black1552/gin-base/database/utils"
|
||||
"go.opentelemetry.io/otel/trace"
|
||||
|
||||
"github.com/gogf/gf/v2/debug/gdebug"
|
||||
)
|
||||
|
||||
const (
|
||||
stackFilterKey = "/internal/intlog"
|
||||
)
|
||||
|
||||
// Print prints `v` with newline using fmt.Println.
|
||||
// The parameter `v` can be multiple variables.
|
||||
func Print(ctx context.Context, v ...any) {
|
||||
if !utils.IsDebugEnabled() {
|
||||
return
|
||||
}
|
||||
doPrint(ctx, fmt.Sprint(v...), false)
|
||||
}
|
||||
|
||||
// Printf prints `v` with format `format` using fmt.Printf.
|
||||
// The parameter `v` can be multiple variables.
|
||||
func Printf(ctx context.Context, format string, v ...any) {
|
||||
if !utils.IsDebugEnabled() {
|
||||
return
|
||||
}
|
||||
doPrint(ctx, fmt.Sprintf(format, v...), false)
|
||||
}
|
||||
|
||||
// Error prints `v` with newline using fmt.Println.
|
||||
// The parameter `v` can be multiple variables.
|
||||
func Error(ctx context.Context, v ...any) {
|
||||
if !utils.IsDebugEnabled() {
|
||||
return
|
||||
}
|
||||
doPrint(ctx, fmt.Sprint(v...), true)
|
||||
}
|
||||
|
||||
// Errorf prints `v` with format `format` using fmt.Printf.
|
||||
func Errorf(ctx context.Context, format string, v ...any) {
|
||||
if !utils.IsDebugEnabled() {
|
||||
return
|
||||
}
|
||||
doPrint(ctx, fmt.Sprintf(format, v...), true)
|
||||
}
|
||||
|
||||
// PrintFunc prints the output from function `f`.
|
||||
// It only calls function `f` if debug mode is enabled.
|
||||
func PrintFunc(ctx context.Context, f func() string) {
|
||||
if !utils.IsDebugEnabled() {
|
||||
return
|
||||
}
|
||||
s := f()
|
||||
if s == "" {
|
||||
return
|
||||
}
|
||||
doPrint(ctx, s, false)
|
||||
}
|
||||
|
||||
// ErrorFunc prints the output from function `f`.
|
||||
// It only calls function `f` if debug mode is enabled.
|
||||
func ErrorFunc(ctx context.Context, f func() string) {
|
||||
if !utils.IsDebugEnabled() {
|
||||
return
|
||||
}
|
||||
s := f()
|
||||
if s == "" {
|
||||
return
|
||||
}
|
||||
doPrint(ctx, s, true)
|
||||
}
|
||||
|
||||
func doPrint(ctx context.Context, content string, stack bool) {
|
||||
if !utils.IsDebugEnabled() {
|
||||
return
|
||||
}
|
||||
buffer := bytes.NewBuffer(nil)
|
||||
buffer.WriteString(time.Now().Format("2006-01-02 15:04:05.000"))
|
||||
buffer.WriteString(" [INTE] ")
|
||||
buffer.WriteString(file())
|
||||
buffer.WriteString(" ")
|
||||
if s := traceIDStr(ctx); s != "" {
|
||||
buffer.WriteString(s + " ")
|
||||
}
|
||||
buffer.WriteString(content)
|
||||
buffer.WriteString("\n")
|
||||
if stack {
|
||||
buffer.WriteString("Caller Stack:\n")
|
||||
buffer.WriteString(gdebug.StackWithFilter([]string{stackFilterKey}))
|
||||
}
|
||||
fmt.Print(buffer.String())
|
||||
}
|
||||
|
||||
// traceIDStr retrieves and returns the trace id string for logging output.
|
||||
func traceIDStr(ctx context.Context) string {
|
||||
if ctx == nil {
|
||||
return ""
|
||||
}
|
||||
spanCtx := trace.SpanContextFromContext(ctx)
|
||||
if traceID := spanCtx.TraceID(); traceID.IsValid() {
|
||||
return "{" + traceID.String() + "}"
|
||||
}
|
||||
return ""
|
||||
}
|
||||
|
||||
// file returns caller file name along with its line number.
|
||||
func file() string {
|
||||
_, p, l := gdebug.CallerWithFilter([]string{stackFilterKey})
|
||||
return fmt.Sprintf(`%s:%d`, filepath.Base(p), l)
|
||||
}
|
||||
|
|
@ -0,0 +1,85 @@
|
|||
// Copyright GoFrame Author(https://goframe.org). All Rights Reserved.
|
||||
//
|
||||
// This Source Code Form is subject to the terms of the MIT License.
|
||||
// If a copy of the MIT was not distributed with this file,
|
||||
// You can obtain one at https://github.com/gogf/gf.
|
||||
|
||||
// Package json provides json operations wrapping ignoring stdlib or third-party lib json.
|
||||
package json
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"encoding/json"
|
||||
"io"
|
||||
|
||||
"github.com/gogf/gf/v2/errors/gerror"
|
||||
)
|
||||
|
||||
// RawMessage is a raw encoded JSON value.
|
||||
// It implements Marshaler and Unmarshaler and can
|
||||
// be used to delay JSON decoding or precompute a JSON encoding.
|
||||
type RawMessage = json.RawMessage
|
||||
|
||||
// Marshal adapts to json/encoding Marshal API.
|
||||
//
|
||||
// Marshal returns the JSON encoding of v, adapts to json/encoding Marshal API
|
||||
// Refer to https://godoc.org/encoding/json#Marshal for more information.
|
||||
func Marshal(v any) (marshaledBytes []byte, err error) {
|
||||
marshaledBytes, err = json.Marshal(v)
|
||||
if err != nil {
|
||||
err = gerror.Wrap(err, `json.Marshal failed`)
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
// MarshalIndent same as json.MarshalIndent.
|
||||
func MarshalIndent(v any, prefix, indent string) (marshaledBytes []byte, err error) {
|
||||
marshaledBytes, err = json.MarshalIndent(v, prefix, indent)
|
||||
if err != nil {
|
||||
err = gerror.Wrap(err, `json.MarshalIndent failed`)
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
// Unmarshal adapts to json/encoding Unmarshal API
|
||||
//
|
||||
// Unmarshal parses the JSON-encoded data and stores the result in the value pointed to by v.
|
||||
// Refer to https://godoc.org/encoding/json#Unmarshal for more information.
|
||||
func Unmarshal(data []byte, v any) (err error) {
|
||||
err = json.Unmarshal(data, v)
|
||||
if err != nil {
|
||||
err = gerror.Wrap(err, `json.Unmarshal failed`)
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
// UnmarshalUseNumber decodes the json data bytes to target interface using number option.
|
||||
func UnmarshalUseNumber(data []byte, v any) (err error) {
|
||||
decoder := NewDecoder(bytes.NewReader(data))
|
||||
decoder.UseNumber()
|
||||
err = decoder.Decode(v)
|
||||
if err != nil {
|
||||
err = gerror.Wrap(err, `json.UnmarshalUseNumber failed`)
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
// NewEncoder same as json.NewEncoder
|
||||
func NewEncoder(writer io.Writer) *json.Encoder {
|
||||
return json.NewEncoder(writer)
|
||||
}
|
||||
|
||||
// NewDecoder adapts to json/stream NewDecoder API.
|
||||
//
|
||||
// NewDecoder returns a new decoder that reads from r.
|
||||
//
|
||||
// Instead of a json/encoding Decoder, a Decoder is returned
|
||||
// Refer to https://godoc.org/encoding/json#NewDecoder for more information.
|
||||
func NewDecoder(reader io.Reader) *json.Decoder {
|
||||
return json.NewDecoder(reader)
|
||||
}
|
||||
|
||||
// Valid reports whether data is a valid JSON encoding.
|
||||
func Valid(data []byte) bool {
|
||||
return json.Valid(data)
|
||||
}
|
||||
|
|
@ -1,43 +0,0 @@
|
|||
package database
|
||||
|
||||
import (
|
||||
"git.magicany.cc/black1552/gin-base/log"
|
||||
"github.com/gogf/gf/v2/frame/g"
|
||||
)
|
||||
|
||||
func SetAutoMigrate(models ...interface{}) {
|
||||
if g.IsNil(Db) {
|
||||
log.Error("数据库连接失败")
|
||||
return
|
||||
}
|
||||
err := Db.AutoMigrate(models...)
|
||||
if err != nil {
|
||||
log.Error("数据库迁移失败", err)
|
||||
}
|
||||
}
|
||||
func RenameColumn(dst interface{}, name, newName string) {
|
||||
if Db.Migrator().HasColumn(dst, name) {
|
||||
err := Db.Migrator().RenameColumn(dst, name, newName)
|
||||
if err != nil {
|
||||
log.Error("数据库修改字段失败", err)
|
||||
return
|
||||
}
|
||||
} else {
|
||||
log.Info("数据库字段不存在", name)
|
||||
}
|
||||
}
|
||||
|
||||
// DropColumn
|
||||
// 删除字段
|
||||
// 例:DropColumn(&User{}, "Sex")
|
||||
func DropColumn(dst interface{}, name string) {
|
||||
if Db.Migrator().HasColumn(dst, name) {
|
||||
err := Db.Migrator().DropColumn(dst, name)
|
||||
if err != nil {
|
||||
log.Error("数据库删除字段失败", err)
|
||||
return
|
||||
}
|
||||
} else {
|
||||
log.Info("数据库字段不存在", name)
|
||||
}
|
||||
}
|
||||
|
|
@ -0,0 +1,94 @@
|
|||
// Copyright GoFrame Author(https://goframe.org). All Rights Reserved.
|
||||
//
|
||||
// This Source Code Form is subject to the terms of the MIT License.
|
||||
// If a copy of the MIT was not distributed with this file,
|
||||
// You can obtain one at https://github.com/gogf/gf.
|
||||
|
||||
// Package reflection provides some reflection functions for internal usage.
|
||||
package reflection
|
||||
|
||||
import (
|
||||
"reflect"
|
||||
)
|
||||
|
||||
type OriginValueAndKindOutput struct {
|
||||
InputValue reflect.Value
|
||||
InputKind reflect.Kind
|
||||
OriginValue reflect.Value
|
||||
OriginKind reflect.Kind
|
||||
}
|
||||
|
||||
// OriginValueAndKind retrieves and returns the original reflect value and kind.
|
||||
func OriginValueAndKind(value any) (out OriginValueAndKindOutput) {
|
||||
if v, ok := value.(reflect.Value); ok {
|
||||
out.InputValue = v
|
||||
} else {
|
||||
out.InputValue = reflect.ValueOf(value)
|
||||
}
|
||||
out.InputKind = out.InputValue.Kind()
|
||||
out.OriginValue = out.InputValue
|
||||
out.OriginKind = out.InputKind
|
||||
for out.OriginKind == reflect.Pointer {
|
||||
out.OriginValue = out.OriginValue.Elem()
|
||||
out.OriginKind = out.OriginValue.Kind()
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
type OriginTypeAndKindOutput struct {
|
||||
InputType reflect.Type
|
||||
InputKind reflect.Kind
|
||||
OriginType reflect.Type
|
||||
OriginKind reflect.Kind
|
||||
}
|
||||
|
||||
// OriginTypeAndKind retrieves and returns the original reflect type and kind.
|
||||
func OriginTypeAndKind(value any) (out OriginTypeAndKindOutput) {
|
||||
if value == nil {
|
||||
return
|
||||
}
|
||||
if reflectType, ok := value.(reflect.Type); ok {
|
||||
out.InputType = reflectType
|
||||
} else {
|
||||
if reflectValue, ok := value.(reflect.Value); ok {
|
||||
out.InputType = reflectValue.Type()
|
||||
} else {
|
||||
out.InputType = reflect.TypeOf(value)
|
||||
}
|
||||
}
|
||||
out.InputKind = out.InputType.Kind()
|
||||
out.OriginType = out.InputType
|
||||
out.OriginKind = out.InputKind
|
||||
for out.OriginKind == reflect.Pointer {
|
||||
out.OriginType = out.OriginType.Elem()
|
||||
out.OriginKind = out.OriginType.Kind()
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
// ValueToInterface converts reflect value to its interface type.
|
||||
func ValueToInterface(v reflect.Value) (value any, ok bool) {
|
||||
if v.IsValid() && v.CanInterface() {
|
||||
return v.Interface(), true
|
||||
}
|
||||
switch v.Kind() {
|
||||
case reflect.Bool:
|
||||
return v.Bool(), true
|
||||
case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64:
|
||||
return v.Int(), true
|
||||
case reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64, reflect.Uintptr:
|
||||
return v.Uint(), true
|
||||
case reflect.Float32, reflect.Float64:
|
||||
return v.Float(), true
|
||||
case reflect.Complex64, reflect.Complex128:
|
||||
return v.Complex(), true
|
||||
case reflect.String:
|
||||
return v.String(), true
|
||||
case reflect.Pointer:
|
||||
return ValueToInterface(v.Elem())
|
||||
case reflect.Interface:
|
||||
return ValueToInterface(v.Elem())
|
||||
default:
|
||||
return nil, false
|
||||
}
|
||||
}
|
||||
|
|
@ -0,0 +1,8 @@
|
|||
// Copyright GoFrame Author(https://goframe.org). All Rights Reserved.
|
||||
//
|
||||
// This Source Code Form is subject to the terms of the MIT License.
|
||||
// If a copy of the MIT was not distributed with this file,
|
||||
// You can obtain one at https://github.com/gogf/gf.
|
||||
|
||||
// Package utils provides some utility functions for internal usage.
|
||||
package utils
|
||||
|
|
@ -0,0 +1,26 @@
|
|||
// Copyright GoFrame Author(https://goframe.org). All Rights Reserved.
|
||||
//
|
||||
// This Source Code Form is subject to the terms of the MIT License.
|
||||
// If a copy of the MIT was not distributed with this file,
|
||||
// You can obtain one at https://github.com/gogf/gf.
|
||||
|
||||
package utils
|
||||
|
||||
import "reflect"
|
||||
|
||||
// IsArray checks whether given value is array/slice.
|
||||
// Note that it uses reflect internally implementing this feature.
|
||||
func IsArray(value any) bool {
|
||||
rv := reflect.ValueOf(value)
|
||||
kind := rv.Kind()
|
||||
if kind == reflect.Pointer {
|
||||
rv = rv.Elem()
|
||||
kind = rv.Kind()
|
||||
}
|
||||
switch kind {
|
||||
case reflect.Array, reflect.Slice:
|
||||
return true
|
||||
default:
|
||||
return false
|
||||
}
|
||||
}
|
||||
|
|
@ -0,0 +1,38 @@
|
|||
// Copyright GoFrame Author(https://goframe.org). All Rights Reserved.
|
||||
//
|
||||
// This Source Code Form is subject to the terms of the MIT License.
|
||||
// If a copy of the MIT was not distributed with this file,
|
||||
// You can obtain one at https://github.com/gogf/gf.
|
||||
|
||||
package utils
|
||||
|
||||
import "git.magicany.cc/black1552/gin-base/database/command"
|
||||
|
||||
const (
|
||||
// Debug key for checking if in debug mode.
|
||||
commandEnvKeyForDebugKey = "gf.debug"
|
||||
)
|
||||
|
||||
// isDebugEnabled marks whether GoFrame debug mode is enabled.
|
||||
var isDebugEnabled = false
|
||||
|
||||
func init() {
|
||||
// Debugging configured.
|
||||
value := command.GetOptWithEnv(commandEnvKeyForDebugKey)
|
||||
if value == "" || value == "0" || value == "false" {
|
||||
isDebugEnabled = false
|
||||
} else {
|
||||
isDebugEnabled = true
|
||||
}
|
||||
}
|
||||
|
||||
// IsDebugEnabled checks and returns whether debug mode is enabled.
|
||||
// The debug mode is enabled when command argument "gf.debug" or environment "GF_DEBUG" is passed.
|
||||
func IsDebugEnabled() bool {
|
||||
return isDebugEnabled
|
||||
}
|
||||
|
||||
// SetDebugEnabled enables/disables the internal debug info.
|
||||
func SetDebugEnabled(enabled bool) {
|
||||
isDebugEnabled = enabled
|
||||
}
|
||||
|
|
@ -0,0 +1,48 @@
|
|||
// Copyright GoFrame Author(https://goframe.org). All Rights Reserved.
|
||||
//
|
||||
// This Source Code Form is subject to the terms of the MIT License.
|
||||
// If a copy of the MIT was not distributed with this file,
|
||||
// You can obtain one at https://github.com/gogf/gf.
|
||||
|
||||
package utils
|
||||
|
||||
import (
|
||||
"io"
|
||||
)
|
||||
|
||||
// ReadCloser implements the io.ReadCloser interface
|
||||
// which is used for reading request body content multiple times.
|
||||
//
|
||||
// Note that it cannot be closed.
|
||||
type ReadCloser struct {
|
||||
index int // Current read position.
|
||||
content []byte // Content.
|
||||
repeatable bool // Mark the content can be repeatable read.
|
||||
}
|
||||
|
||||
// NewReadCloser creates and returns a RepeatReadCloser object.
|
||||
func NewReadCloser(content []byte, repeatable bool) io.ReadCloser {
|
||||
return &ReadCloser{
|
||||
content: content,
|
||||
repeatable: repeatable,
|
||||
}
|
||||
}
|
||||
|
||||
// Read implements the io.ReadCloser interface.
|
||||
func (b *ReadCloser) Read(p []byte) (n int, err error) {
|
||||
// Make it repeatable reading.
|
||||
if b.index >= len(b.content) && b.repeatable {
|
||||
b.index = 0
|
||||
}
|
||||
n = copy(p, b.content[b.index:])
|
||||
b.index += n
|
||||
if b.index >= len(b.content) {
|
||||
return n, io.EOF
|
||||
}
|
||||
return n, nil
|
||||
}
|
||||
|
||||
// Close implements the io.ReadCloser interface.
|
||||
func (b *ReadCloser) Close() error {
|
||||
return nil
|
||||
}
|
||||
|
|
@ -0,0 +1,102 @@
|
|||
// Copyright GoFrame Author(https://goframe.org). All Rights Reserved.
|
||||
//
|
||||
// This Source Code Form is subject to the terms of the MIT License.
|
||||
// If a copy of the MIT was not distributed with this file,
|
||||
// You can obtain one at https://github.com/gogf/gf.
|
||||
|
||||
package utils
|
||||
|
||||
import (
|
||||
"reflect"
|
||||
|
||||
"git.magicany.cc/black1552/gin-base/database/empty"
|
||||
)
|
||||
|
||||
// IsNil checks whether `value` is nil, especially for any type value.
|
||||
func IsNil(value any) bool {
|
||||
return empty.IsNil(value)
|
||||
}
|
||||
|
||||
// IsEmpty checks whether `value` is empty.
|
||||
func IsEmpty(value any) bool {
|
||||
return empty.IsEmpty(value)
|
||||
}
|
||||
|
||||
// IsInt checks whether `value` is type of int.
|
||||
func IsInt(value any) bool {
|
||||
switch value.(type) {
|
||||
case int, *int, int8, *int8, int16, *int16, int32, *int32, int64, *int64:
|
||||
return true
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
// IsUint checks whether `value` is type of uint.
|
||||
func IsUint(value any) bool {
|
||||
switch value.(type) {
|
||||
case uint, *uint, uint8, *uint8, uint16, *uint16, uint32, *uint32, uint64, *uint64:
|
||||
return true
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
// IsFloat checks whether `value` is type of float.
|
||||
func IsFloat(value any) bool {
|
||||
switch value.(type) {
|
||||
case float32, *float32, float64, *float64:
|
||||
return true
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
// IsSlice checks whether `value` is type of slice.
|
||||
func IsSlice(value any) bool {
|
||||
var (
|
||||
reflectValue = reflect.ValueOf(value)
|
||||
reflectKind = reflectValue.Kind()
|
||||
)
|
||||
for reflectKind == reflect.Pointer {
|
||||
reflectValue = reflectValue.Elem()
|
||||
reflectKind = reflectValue.Kind()
|
||||
}
|
||||
switch reflectKind {
|
||||
case reflect.Slice, reflect.Array:
|
||||
return true
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
// IsMap checks whether `value` is type of map.
|
||||
func IsMap(value any) bool {
|
||||
var (
|
||||
reflectValue = reflect.ValueOf(value)
|
||||
reflectKind = reflectValue.Kind()
|
||||
)
|
||||
for reflectKind == reflect.Pointer {
|
||||
reflectValue = reflectValue.Elem()
|
||||
reflectKind = reflectValue.Kind()
|
||||
}
|
||||
switch reflectKind {
|
||||
case reflect.Map:
|
||||
return true
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
// IsStruct checks whether `value` is type of struct.
|
||||
func IsStruct(value any) bool {
|
||||
reflectType := reflect.TypeOf(value)
|
||||
if reflectType == nil {
|
||||
return false
|
||||
}
|
||||
reflectKind := reflectType.Kind()
|
||||
for reflectKind == reflect.Pointer {
|
||||
reflectType = reflectType.Elem()
|
||||
reflectKind = reflectType.Kind()
|
||||
}
|
||||
switch reflectKind {
|
||||
case reflect.Struct:
|
||||
return true
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
|
@ -0,0 +1,37 @@
|
|||
// Copyright GoFrame Author(https://goframe.org). All Rights Reserved.
|
||||
//
|
||||
// This Source Code Form is subject to the terms of the MIT License.
|
||||
// If a copy of the MIT was not distributed with this file,
|
||||
// You can obtain one at https://github.com/gogf/gf.
|
||||
|
||||
package utils
|
||||
|
||||
import "fmt"
|
||||
|
||||
// ListToMapByKey converts `list` to a map[string]any of which key is specified by `key`.
|
||||
// Note that the item value may be type of slice.
|
||||
func ListToMapByKey(list []map[string]any, key string) map[string]any {
|
||||
var (
|
||||
s = ""
|
||||
m = make(map[string]any)
|
||||
tempMap = make(map[string][]any)
|
||||
hasMultiValues bool
|
||||
)
|
||||
for _, item := range list {
|
||||
if k, ok := item[key]; ok {
|
||||
s = fmt.Sprintf(`%v`, k)
|
||||
tempMap[s] = append(tempMap[s], item)
|
||||
if len(tempMap[s]) > 1 {
|
||||
hasMultiValues = true
|
||||
}
|
||||
}
|
||||
}
|
||||
for k, v := range tempMap {
|
||||
if hasMultiValues {
|
||||
m[k] = v
|
||||
} else {
|
||||
m[k] = v[0]
|
||||
}
|
||||
}
|
||||
return m
|
||||
}
|
||||
|
|
@ -0,0 +1,37 @@
|
|||
// Copyright GoFrame Author(https://goframe.org). All Rights Reserved.
|
||||
//
|
||||
// This Source Code Form is subject to the terms of the MIT License.
|
||||
// If a copy of the MIT was not distributed with this file,
|
||||
// You can obtain one at https://github.com/gogf/gf.
|
||||
|
||||
package utils
|
||||
|
||||
// MapPossibleItemByKey tries to find the possible key-value pair for given key ignoring cases and symbols.
|
||||
//
|
||||
// Note that this function might be of low performance.
|
||||
func MapPossibleItemByKey(data map[string]any, key string) (foundKey string, foundValue any) {
|
||||
if len(data) == 0 {
|
||||
return
|
||||
}
|
||||
if v, ok := data[key]; ok {
|
||||
return key, v
|
||||
}
|
||||
// Loop checking.
|
||||
for k, v := range data {
|
||||
if EqualFoldWithoutChars(k, key) {
|
||||
return k, v
|
||||
}
|
||||
}
|
||||
return "", nil
|
||||
}
|
||||
|
||||
// MapContainsPossibleKey checks if the given `key` is contained in given map `data`.
|
||||
// It checks the key ignoring cases and symbols.
|
||||
//
|
||||
// Note that this function might be of low performance.
|
||||
func MapContainsPossibleKey(data map[string]any, key string) bool {
|
||||
if k, _ := MapPossibleItemByKey(data, key); k != "" {
|
||||
return true
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
|
@ -0,0 +1,26 @@
|
|||
// Copyright GoFrame Author(https://goframe.org). All Rights Reserved.
|
||||
//
|
||||
// This Source Code Form is subject to the terms of the MIT License.
|
||||
// If a copy of the MIT was not distributed with this file,
|
||||
// You can obtain one at https://github.com/gogf/gf.
|
||||
|
||||
package utils
|
||||
|
||||
import (
|
||||
"reflect"
|
||||
)
|
||||
|
||||
// CanCallIsNil Can reflect.Value call reflect.Value.IsNil.
|
||||
// It can avoid reflect.Value.IsNil panics.
|
||||
func CanCallIsNil(v any) bool {
|
||||
rv, ok := v.(reflect.Value)
|
||||
if !ok {
|
||||
return false
|
||||
}
|
||||
switch rv.Kind() {
|
||||
case reflect.Interface, reflect.Chan, reflect.Func, reflect.Map, reflect.Pointer, reflect.Slice, reflect.UnsafePointer:
|
||||
return true
|
||||
default:
|
||||
return false
|
||||
}
|
||||
}
|
||||
|
|
@ -0,0 +1,180 @@
|
|||
// Copyright GoFrame Author(https://goframe.org). All Rights Reserved.
|
||||
//
|
||||
// This Source Code Form is subject to the terms of the MIT License.
|
||||
// If a copy of the MIT was not distributed with this file,
|
||||
// You can obtain one at https://github.com/gogf/gf.
|
||||
|
||||
package utils
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"strings"
|
||||
"unicode"
|
||||
)
|
||||
|
||||
// DefaultTrimChars are the characters which are stripped by Trim* functions in default.
|
||||
var DefaultTrimChars = string([]byte{
|
||||
'\t', // Tab.
|
||||
'\v', // Vertical tab.
|
||||
'\n', // New line (line feed).
|
||||
'\r', // Carriage return.
|
||||
'\f', // New page.
|
||||
' ', // Ordinary space.
|
||||
0x00, // NUL-byte.
|
||||
0x85, // Delete.
|
||||
0xA0, // Non-breaking space.
|
||||
})
|
||||
|
||||
// IsLetterUpper checks whether the given byte b is in upper case.
|
||||
func IsLetterUpper(b byte) bool {
|
||||
if b >= byte('A') && b <= byte('Z') {
|
||||
return true
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
// IsLetterLower checks whether the given byte b is in lower case.
|
||||
func IsLetterLower(b byte) bool {
|
||||
if b >= byte('a') && b <= byte('z') {
|
||||
return true
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
// IsLetter checks whether the given byte b is a letter.
|
||||
func IsLetter(b byte) bool {
|
||||
return IsLetterUpper(b) || IsLetterLower(b)
|
||||
}
|
||||
|
||||
// IsNumeric checks whether the given string s is numeric.
|
||||
// Note that float string like "123.456" is also numeric.
|
||||
func IsNumeric(s string) bool {
|
||||
var (
|
||||
dotCount = 0
|
||||
length = len(s)
|
||||
)
|
||||
if length == 0 {
|
||||
return false
|
||||
}
|
||||
for i := 0; i < length; i++ {
|
||||
if (s[i] == '-' || s[i] == '+') && i == 0 {
|
||||
if length == 1 {
|
||||
return false
|
||||
}
|
||||
continue
|
||||
}
|
||||
if s[i] == '.' {
|
||||
dotCount++
|
||||
if i > 0 && i < length-1 && s[i-1] >= '0' && s[i-1] <= '9' {
|
||||
continue
|
||||
} else {
|
||||
return false
|
||||
}
|
||||
}
|
||||
if s[i] < '0' || s[i] > '9' {
|
||||
return false
|
||||
}
|
||||
}
|
||||
return dotCount <= 1
|
||||
}
|
||||
|
||||
// UcFirst returns a copy of the string s with the first letter mapped to its upper case.
|
||||
func UcFirst(s string) string {
|
||||
if len(s) == 0 {
|
||||
return s
|
||||
}
|
||||
if IsLetterLower(s[0]) {
|
||||
return string(s[0]-32) + s[1:]
|
||||
}
|
||||
return s
|
||||
}
|
||||
|
||||
// ReplaceByMap returns a copy of `origin`,
|
||||
// which is replaced by a map in unordered way, case-sensitively.
|
||||
func ReplaceByMap(origin string, replaces map[string]string) string {
|
||||
for k, v := range replaces {
|
||||
origin = strings.ReplaceAll(origin, k, v)
|
||||
}
|
||||
return origin
|
||||
}
|
||||
|
||||
// RemoveSymbols removes all symbols from string and lefts only numbers and letters.
|
||||
func RemoveSymbols(s string) string {
|
||||
b := make([]rune, 0, len(s))
|
||||
for _, c := range s {
|
||||
if c > 127 {
|
||||
b = append(b, c)
|
||||
} else if (c >= '0' && c <= '9') || (c >= 'A' && c <= 'Z') || (c >= 'a' && c <= 'z') {
|
||||
b = append(b, c)
|
||||
}
|
||||
}
|
||||
return string(b)
|
||||
}
|
||||
|
||||
// EqualFoldWithoutChars checks string `s1` and `s2` equal case-insensitively,
|
||||
// with/without chars '-'/'_'/'.'/' '.
|
||||
func EqualFoldWithoutChars(s1, s2 string) bool {
|
||||
return strings.EqualFold(RemoveSymbols(s1), RemoveSymbols(s2))
|
||||
}
|
||||
|
||||
// SplitAndTrim splits string `str` by a string `delimiter` to an array,
|
||||
// and calls Trim to every element of this array. It ignores the elements
|
||||
// which are empty after Trim.
|
||||
func SplitAndTrim(str, delimiter string, characterMask ...string) []string {
|
||||
array := make([]string, 0)
|
||||
for _, v := range strings.Split(str, delimiter) {
|
||||
v = Trim(v, characterMask...)
|
||||
if v != "" {
|
||||
array = append(array, v)
|
||||
}
|
||||
}
|
||||
return array
|
||||
}
|
||||
|
||||
// Trim strips whitespace (or other characters) from the beginning and end of a string.
|
||||
// The optional parameter `characterMask` specifies the additional stripped characters.
|
||||
func Trim(str string, characterMask ...string) string {
|
||||
trimChars := DefaultTrimChars
|
||||
if len(characterMask) > 0 {
|
||||
trimChars += characterMask[0]
|
||||
}
|
||||
return strings.Trim(str, trimChars)
|
||||
}
|
||||
|
||||
// FormatCmdKey formats string `s` as command key using uniformed format.
|
||||
func FormatCmdKey(s string) string {
|
||||
return strings.ToLower(strings.ReplaceAll(s, "_", "."))
|
||||
}
|
||||
|
||||
// FormatEnvKey formats string `s` as environment key using uniformed format.
|
||||
func FormatEnvKey(s string) string {
|
||||
return strings.ToUpper(strings.ReplaceAll(s, ".", "_"))
|
||||
}
|
||||
|
||||
// StripSlashes un-quotes a quoted string by AddSlashes.
|
||||
func StripSlashes(str string) string {
|
||||
var buf bytes.Buffer
|
||||
l, skip := len(str), false
|
||||
for i, char := range str {
|
||||
if skip {
|
||||
skip = false
|
||||
} else if char == '\\' {
|
||||
if i+1 < l && str[i+1] == '\\' {
|
||||
skip = true
|
||||
}
|
||||
continue
|
||||
}
|
||||
buf.WriteRune(char)
|
||||
}
|
||||
return buf.String()
|
||||
}
|
||||
|
||||
// IsASCII checks whether given string is ASCII characters.
|
||||
func IsASCII(s string) bool {
|
||||
for i := 0; i < len(s); i++ {
|
||||
if s[i] > unicode.MaxASCII {
|
||||
return false
|
||||
}
|
||||
}
|
||||
return true
|
||||
}
|
||||
948
db/README.md
948
db/README.md
|
|
@ -1,948 +0,0 @@
|
|||
# Magic-ORM 自主 ORM 框架架构文档
|
||||
|
||||
## 📋 目录
|
||||
|
||||
- [概述](#概述)
|
||||
- [核心特性](#核心特性)
|
||||
- [架构设计](#架构设计)
|
||||
- [技术栈](#技术栈)
|
||||
- [核心接口设计](#核心接口设计)
|
||||
- [快速开始](#快速开始)
|
||||
- [详细功能说明](#详细功能说明)
|
||||
- [最佳实践](#最佳实践)
|
||||
|
||||
---
|
||||
|
||||
## 概述
|
||||
|
||||
Magic-ORM 是一个完全自主研发的企业级 Go 语言 ORM 框架,不依赖任何第三方 ORM 库。框架基于 `database/sql` 标准库构建,提供了全自动化事务管理、面向接口设计、智能字段映射等高级特性。支持 MySQL、SQLite 等主流数据库,内置完整的迁移管理和可观测性支持,帮助开发者快速构建高质量的数据访问层。
|
||||
|
||||
**设计理念:**
|
||||
- 零依赖:仅依赖 Go 标准库 `database/sql`
|
||||
- 高性能:优化的查询执行器和连接池管理
|
||||
- 易用性:简洁的 API 设计和智能默认行为
|
||||
- 可扩展:面向接口的设计,支持自定义驱动扩展
|
||||
- **内置驱动**:框架自带所有主流数据库驱动,无需额外安装
|
||||
|
||||
---
|
||||
|
||||
## 核心特性
|
||||
|
||||
- **全自动化嵌套事务支持**:无需手动管理事务传播行为
|
||||
- **面向接口化设计**:核心功能均通过接口暴露,便于 Mock 与扩展
|
||||
- **内置主流数据库驱动**:开箱即用,并支持自定义驱动扩展
|
||||
- **统一配置组件**:与框架配置体系无缝集成
|
||||
- **单例模式数据库对象**:同一分组配置仅初始化一次
|
||||
- **双模式操作**:原生 SQL + ORM 链式操作
|
||||
- **OpenTelemetry 可观测性**:完整支持 Tracing、Logging、Metrics
|
||||
- **智能结果映射**:`Scan` 自动识别 Map/Struct/Slice,无需 `sql.ErrNoRows` 判空
|
||||
- **全自动字段映射**:无需结构体标签,自动匹配数据库字段
|
||||
- **参数智能过滤**:自动识别并过滤无效/空值字段
|
||||
- **Model/DAO 代码生成器**:一键生成全量数据访问代码
|
||||
- **高级特性**:调试模式、DryRun、自定义 Handler、软删除、时间自动更新、模型关联、主从集群等
|
||||
- **自动化数据库迁移**:支持自动迁移、增量迁移、回滚迁移等完整迁移管理
|
||||
|
||||
---
|
||||
|
||||
## 架构设计
|
||||
|
||||
### 整体架构图
|
||||
|
||||
```mermaid
|
||||
graph TB
|
||||
A[应用层] --> B[Magic-ORM 框架]
|
||||
B --> C[配置中心]
|
||||
B --> D[数据库连接池]
|
||||
B --> E[事务管理器]
|
||||
B --> F[迁移管理器]
|
||||
|
||||
C --> C1[统一配置组件]
|
||||
C --> C2[环境配置]
|
||||
|
||||
D --> D1[MySQL 驱动]
|
||||
D --> D2[SQLite 驱动]
|
||||
D --> D3[自定义驱动]
|
||||
|
||||
E --> E1[自动嵌套事务]
|
||||
E --> E2[事务传播控制]
|
||||
|
||||
F --> F1[自动迁移]
|
||||
F --> F2[增量迁移]
|
||||
F --> F3[回滚迁移]
|
||||
|
||||
B --> G[观测性组件]
|
||||
G --> G1[Tracing]
|
||||
G --> G2[Logging]
|
||||
G --> G3[Metrics]
|
||||
|
||||
B --> H[工具组件]
|
||||
H --> H1[字段映射器]
|
||||
H --> H2[参数过滤器]
|
||||
H --> H3[结果映射器]
|
||||
H --> H4[代码生成器]
|
||||
```
|
||||
|
||||
### 目录结构
|
||||
|
||||
```
|
||||
magic-orm/
|
||||
├── core/ # 核心实现
|
||||
│ ├── database.go # 数据库连接管理
|
||||
│ ├── transaction.go # 事务管理
|
||||
│ ├── query.go # 查询构建器
|
||||
│ └── mapper.go # 字段映射器
|
||||
├── migrate/ # 迁移管理
|
||||
│ └── migrator.go # 自动迁移实现
|
||||
├── generator/ # 代码生成器
|
||||
│ ├── model.go # Model 生成
|
||||
│ └── dao.go # DAO 生成
|
||||
├── tracing/ # OpenTelemetry 集成
|
||||
│ └── tracer.go # 链路追踪
|
||||
└── driver/ # 数据库驱动适配(已内置)
|
||||
├── mysql.go # MySQL 驱动(内置)
|
||||
├── sqlite.go # SQLite 驱动(内置)
|
||||
├── postgres.go # PostgreSQL 驱动(内置)
|
||||
├── sqlserver.go # SQL Server 驱动(内置)
|
||||
├── oracle.go # Oracle 驱动(内置)
|
||||
└── clickhouse.go # ClickHouse 驱动(内置)
|
||||
```
|
||||
|
||||
### 核心组件说明
|
||||
|
||||
#### 1. 数据库连接管理 (`core/database.go`)
|
||||
|
||||
- **单例模式**:全局唯一的 `DB` 实例,确保资源高效利用
|
||||
- **多数据库支持**:支持 MySQL、SQLite、PostgreSQL、SQL Server、Oracle、ClickHouse 等
|
||||
- **驱动内置**:所有主流数据库驱动已预装在框架中
|
||||
- **连接池优化**:内置 sql.DB 连接池管理
|
||||
- **健康检查**:启动时自动执行 `Ping()` 验证连接
|
||||
|
||||
**核心配置项:**
|
||||
```go
|
||||
Config{
|
||||
DriverName: "mysql", // 驱动名称
|
||||
DataSource: "dns", // 数据源连接字符串
|
||||
MaxIdleConns: 10, // 最大空闲连接数
|
||||
MaxOpenConns: 100, // 最大打开连接数
|
||||
Debug: true, // 调试模式
|
||||
}
|
||||
```
|
||||
|
||||
#### 2. 查询构建器 (`core/query.go`)
|
||||
|
||||
提供流畅的链式查询接口:
|
||||
|
||||
- **条件查询**: Where, Or, And
|
||||
- **字段选择**: Select, Omit
|
||||
- **排序分页**: Order, Limit, Offset
|
||||
- **分组统计**: Group, Having, Count
|
||||
- **连接查询**: Join, LeftJoin, RightJoin
|
||||
- **预加载**: Preload
|
||||
|
||||
**示例:**
|
||||
```go
|
||||
var users []model.User
|
||||
db.Model(&model.User{}).
|
||||
Where("status = ?", 1).
|
||||
Select("id", "username").
|
||||
Order("id DESC").
|
||||
Limit(10).
|
||||
Find(&users)
|
||||
```
|
||||
|
||||
#### 3. 事务管理器 (`core/transaction.go`)
|
||||
|
||||
提供完整的事务管理能力:
|
||||
|
||||
- **自动嵌套事务**: 自动管理事务传播
|
||||
- **保存点支持**: 支持部分回滚
|
||||
- **生命周期回调**: Before/After 钩子
|
||||
|
||||
#### 4. 字段映射器 (`core/mapper.go`)
|
||||
|
||||
智能字段映射系统:
|
||||
|
||||
- **驼峰转下划线**: UserName -> user_name
|
||||
- **标签解析**: 支持 db, json 标签
|
||||
- **类型转换**: Go 类型与数据库类型自动转换
|
||||
- **零值过滤**: 自动过滤空值和零值
|
||||
|
||||
#### 5. 迁移管理 (`migrate/migrator.go`)
|
||||
|
||||
完整的数据库迁移方案:
|
||||
|
||||
- **自动迁移**: 根据模型自动创建/修改表结构
|
||||
- **增量迁移**: 支持添加字段、索引等
|
||||
- **回滚支持**: 支持迁移回滚
|
||||
- **版本管理**: 迁移版本记录和管理
|
||||
|
||||
#### 6. 驱动管理器 (`driver/manager.go`)
|
||||
|
||||
统一的驱动管理和注册中心:
|
||||
|
||||
- **驱动注册**: 自动注册所有内置驱动
|
||||
- **驱动选择**: 根据配置自动选择合适的驱动
|
||||
- **驱动扩展**: 支持用户自定义驱动注册
|
||||
- **版本检测**: 自动检测数据库版本并适配特性
|
||||
|
||||
```go
|
||||
// 驱动管理器会自动处理
|
||||
var supportedDrivers = map[string]driver.Driver{
|
||||
"mysql": &MySQLDriver{},
|
||||
"sqlite": &SQLiteDriver{},
|
||||
"postgres": &PostgresDriver{},
|
||||
"sqlserver": &SQLServerDriver{},
|
||||
"oracle": &OracleDriver{},
|
||||
"clickhouse": &ClickHouseDriver{},
|
||||
}
|
||||
```
|
||||
|
||||
---
|
||||
|
||||
## 技术栈
|
||||
|
||||
### 核心依赖
|
||||
|
||||
| 组件 | 版本 | 说明 |
|
||||
|------|------|------|
|
||||
| Go | 1.25+ | 编程语言 |
|
||||
| database/sql | stdlib | Go 标准库 |
|
||||
| driver-go | Latest | 数据库驱动接口规范 |
|
||||
| OpenTelemetry | Latest | 可观测性框架 |
|
||||
| **内置驱动集合** | Latest | **包含所有主流数据库驱动** |
|
||||
|
||||
### 支持的数据库驱动
|
||||
|
||||
框架已内置以下数据库驱动,**无需额外安装**:
|
||||
|
||||
- **MySQL**: 内置驱动(基于 `github.com/go-sql-driver/mysql`)
|
||||
- **SQLite**: 内置驱动(基于 `github.com/mattn/go-sqlite3`)
|
||||
- **PostgreSQL**: 内置驱动(基于 `github.com/lib/pq`)
|
||||
- **SQL Server**: 内置驱动(基于 `github.com/denisenkom/go-mssqldb`)
|
||||
- **Oracle**: 内置驱动(基于 `github.com/godror/godror`)
|
||||
- **ClickHouse**: 内置驱动(基于 `github.com/ClickHouse/clickhouse-go`)
|
||||
- **自定义驱动**: 实现 `driver.Driver` 接口即可扩展
|
||||
|
||||
> 💡 **说明**:框架在编译时已将所有主流数据库驱动打包,用户只需引入 `magic-orm` 即可完成所有数据库操作,无需单独安装各数据库驱动。
|
||||
|
||||
---
|
||||
|
||||
## 核心接口设计
|
||||
|
||||
### 1. 数据库连接接口
|
||||
|
||||
```go
|
||||
// IDatabase 数据库连接接口
|
||||
type IDatabase interface {
|
||||
// 基础操作
|
||||
DB() *sql.DB
|
||||
Close() error
|
||||
Ping() error
|
||||
|
||||
// 事务管理
|
||||
Begin() (ITx, error)
|
||||
Transaction(fn func(ITx) error) error
|
||||
|
||||
// 查询构建器
|
||||
Model(model interface{}) IQuery
|
||||
Table(name string) IQuery
|
||||
Query(result interface{}, query string, args ...interface{}) error
|
||||
Exec(query string, args ...interface{}) (sql.Result, error)
|
||||
|
||||
// 迁移管理
|
||||
Migrate(models ...interface{}) error
|
||||
|
||||
// 配置
|
||||
SetDebug(bool)
|
||||
SetMaxIdleConns(int)
|
||||
SetMaxOpenConns(int)
|
||||
SetConnMaxLifetime(time.Duration)
|
||||
}
|
||||
```
|
||||
|
||||
### 2. 事务接口
|
||||
|
||||
```go
|
||||
// ITx 事务接口
|
||||
type ITx interface {
|
||||
// 基础操作
|
||||
Commit() error
|
||||
Rollback() error
|
||||
|
||||
// 查询操作
|
||||
Model(model interface{}) IQuery
|
||||
Table(name string) IQuery
|
||||
Insert(model interface{}) (int64, error)
|
||||
BatchInsert(models interface{}, batchSize int) error
|
||||
Update(model interface{}, data map[string]interface{}) error
|
||||
Delete(model interface{}) error
|
||||
|
||||
// 原生 SQL
|
||||
Query(result interface{}, query string, args ...interface{}) error
|
||||
Exec(query string, args ...interface{}) (sql.Result, error)
|
||||
}
|
||||
```
|
||||
|
||||
### 3. 查询构建器接口
|
||||
|
||||
```go
|
||||
// IQuery 查询构建器接口
|
||||
type IQuery interface {
|
||||
// 条件查询
|
||||
Where(query string, args ...interface{}) IQuery
|
||||
Or(query string, args ...interface{}) IQuery
|
||||
And(query string, args ...interface{}) IQuery
|
||||
|
||||
// 字段选择
|
||||
Select(fields ...string) IQuery
|
||||
Omit(fields ...string) IQuery
|
||||
|
||||
// 排序
|
||||
Order(order string) IQuery
|
||||
OrderBy(field string, direction string) IQuery
|
||||
|
||||
// 分页
|
||||
Limit(limit int) IQuery
|
||||
Offset(offset int) IQuery
|
||||
Page(page, pageSize int) IQuery
|
||||
|
||||
// 分组
|
||||
Group(group string) IQuery
|
||||
Having(having string, args ...interface{}) IQuery
|
||||
|
||||
// 连接
|
||||
Join(join string, args ...interface{}) IQuery
|
||||
LeftJoin(table, on string) IQuery
|
||||
RightJoin(table, on string) IQuery
|
||||
InnerJoin(table, on string) IQuery
|
||||
|
||||
// 预加载
|
||||
Preload(relation string, conditions ...interface{}) IQuery
|
||||
|
||||
// 执行查询
|
||||
First(result interface{}) error
|
||||
Find(result interface{}) error
|
||||
Count(count *int64) IQuery
|
||||
Exists() (bool, error)
|
||||
|
||||
// 更新和删除
|
||||
Updates(data interface{}) error
|
||||
UpdateColumn(column string, value interface{}) error
|
||||
Delete() error
|
||||
|
||||
// 特殊模式
|
||||
Unscoped() IQuery
|
||||
DryRun() IQuery
|
||||
Debug() IQuery
|
||||
|
||||
// 构建 SQL(不执行)
|
||||
Build() (string, []interface{})
|
||||
}
|
||||
```
|
||||
|
||||
### 4. 模型接口
|
||||
|
||||
```go
|
||||
// IModel 模型接口
|
||||
type IModel interface {
|
||||
// 表名映射
|
||||
TableName() string
|
||||
|
||||
// 生命周期回调(可选)
|
||||
BeforeCreate(tx ITx) error
|
||||
AfterCreate(tx ITx) error
|
||||
BeforeUpdate(tx ITx) error
|
||||
AfterUpdate(tx ITx) error
|
||||
BeforeDelete(tx ITx) error
|
||||
AfterDelete(tx ITx) error
|
||||
BeforeSave(tx ITx) error
|
||||
AfterSave(tx ITx) error
|
||||
}
|
||||
```
|
||||
|
||||
### 5. 字段映射器接口
|
||||
|
||||
```go
|
||||
// IFieldMapper 字段映射器接口
|
||||
type IFieldMapper interface {
|
||||
// 结构体字段转数据库列
|
||||
StructToColumns(model interface{}) (map[string]interface{}, error)
|
||||
|
||||
// 数据库列转结构体字段
|
||||
ColumnsToStruct(row *sql.Rows, model interface{}) error
|
||||
|
||||
// 获取表名
|
||||
GetTableName(model interface{}) string
|
||||
|
||||
// 获取主键字段
|
||||
GetPrimaryKey(model interface{}) string
|
||||
|
||||
// 获取字段信息
|
||||
GetFields(model interface{}) []FieldInfo
|
||||
}
|
||||
|
||||
// FieldInfo 字段信息
|
||||
type FieldInfo struct {
|
||||
Name string // 字段名
|
||||
Column string // 列名
|
||||
Type string // Go 类型
|
||||
DbType string // 数据库类型
|
||||
Tag string // 标签
|
||||
IsPrimary bool // 是否主键
|
||||
IsAuto bool // 是否自增
|
||||
}
|
||||
```
|
||||
|
||||
### 6. 迁移管理器接口
|
||||
|
||||
```go
|
||||
// IMigrator 迁移管理器接口
|
||||
type IMigrator interface {
|
||||
// 自动迁移
|
||||
AutoMigrate(models ...interface{}) error
|
||||
|
||||
// 表操作
|
||||
CreateTable(model interface{}) error
|
||||
DropTable(model interface{}) error
|
||||
HasTable(model interface{}) (bool, error)
|
||||
RenameTable(oldName, newName string) error
|
||||
|
||||
// 列操作
|
||||
AddColumn(model interface{}, field string) error
|
||||
DropColumn(model interface{}, field string) error
|
||||
HasColumn(model interface{}, field string) (bool, error)
|
||||
RenameColumn(model interface{}, oldField, newField string) error
|
||||
|
||||
// 索引操作
|
||||
CreateIndex(model interface{}, field string) error
|
||||
DropIndex(model interface{}, field string) error
|
||||
HasIndex(model interface{}, field string) (bool, error)
|
||||
}
|
||||
```
|
||||
|
||||
### 7. 代码生成器接口
|
||||
|
||||
```go
|
||||
// ICodeGenerator 代码生成器接口
|
||||
type ICodeGenerator interface {
|
||||
// 生成 Model 代码
|
||||
GenerateModel(table string, outputDir string) error
|
||||
|
||||
// 生成 DAO 代码
|
||||
GenerateDAO(table string, outputDir string) error
|
||||
|
||||
// 生成完整代码
|
||||
GenerateAll(tables []string, outputDir string) error
|
||||
|
||||
// 从数据库读取表结构
|
||||
InspectTable(tableName string) (*TableSchema, error)
|
||||
}
|
||||
|
||||
// TableSchema 表结构信息
|
||||
type TableSchema struct {
|
||||
Name string
|
||||
Columns []ColumnInfo
|
||||
Indexes []IndexInfo
|
||||
}
|
||||
```
|
||||
|
||||
### 8. 配置结构
|
||||
|
||||
```go
|
||||
// Config 数据库配置
|
||||
type Config struct {
|
||||
DriverName string // 驱动名称
|
||||
DataSource string // 数据源连接字符串
|
||||
MaxIdleConns int // 最大空闲连接数
|
||||
MaxOpenConns int // 最大打开连接数
|
||||
ConnMaxLifetime time.Duration // 连接最大生命周期
|
||||
Debug bool // 调试模式
|
||||
|
||||
// 主从配置
|
||||
Replicas []string // 从库列表
|
||||
ReadPolicy ReadPolicy // 读负载均衡策略
|
||||
|
||||
// OpenTelemetry
|
||||
EnableTracing bool
|
||||
ServiceName string
|
||||
}
|
||||
|
||||
// ReadPolicy 读负载均衡策略
|
||||
type ReadPolicy int
|
||||
|
||||
const (
|
||||
Random ReadPolicy = iota
|
||||
RoundRobin
|
||||
LeastConn
|
||||
)
|
||||
```
|
||||
|
||||
---
|
||||
|
||||
## 快速开始
|
||||
|
||||
### 1. 安装 Magic-ORM
|
||||
|
||||
```bash
|
||||
# 仅需安装 magic-orm,所有数据库驱动已内置
|
||||
go get github.com/your-org/magic-orm
|
||||
```
|
||||
|
||||
> ✅ **无需单独安装数据库驱动!** 所有驱动已包含在 magic-orm 中。
|
||||
|
||||
### 2. 配置数据库
|
||||
|
||||
在配置文件中设置数据库参数:
|
||||
|
||||
```yaml
|
||||
database:
|
||||
type: mysql # 或 sqlite, postgres
|
||||
dns: "user:password@tcp(localhost:3306)/dbname?charset=utf8mb4&parseTime=True&loc=Local"
|
||||
debug: true
|
||||
max_idle_conns: 10
|
||||
max_open_conns: 100
|
||||
```
|
||||
|
||||
### 3. 定义模型
|
||||
|
||||
```go
|
||||
package model
|
||||
|
||||
import "time"
|
||||
|
||||
type User struct {
|
||||
ID int64 `json:"id" db:"id"`
|
||||
Username string `json:"username" db:"username"`
|
||||
Password string `json:"-" db:"password"`
|
||||
Email string `json:"email" db:"email"`
|
||||
Status int `json:"status" db:"status"`
|
||||
CreatedAt time.Time `json:"created_at" db:"created_at"`
|
||||
UpdatedAt time.Time `json:"updated_at" db:"updated_at"`
|
||||
}
|
||||
|
||||
// 表名映射
|
||||
func (User) TableName() string {
|
||||
return "user"
|
||||
}
|
||||
```
|
||||
|
||||
### 4. 初始化数据库连接
|
||||
|
||||
```go
|
||||
package main
|
||||
|
||||
import (
|
||||
"database/sql"
|
||||
_ "github.com/go-sql-driver/mysql"
|
||||
"your-project/orm"
|
||||
)
|
||||
|
||||
func main() {
|
||||
// 初始化数据库连接
|
||||
db, err := orm.NewDatabase(&orm.Config{
|
||||
DriverName: "mysql",
|
||||
DataSource: "user:password@tcp(localhost:3306)/dbname?charset=utf8mb4&parseTime=True&loc=Local",
|
||||
MaxIdleConns: 10,
|
||||
MaxOpenConns: 100,
|
||||
Debug: true,
|
||||
})
|
||||
if err != nil {
|
||||
panic(err)
|
||||
}
|
||||
defer db.Close()
|
||||
|
||||
// 执行迁移
|
||||
orm.Migrate(db, &model.User{})
|
||||
}
|
||||
```
|
||||
|
||||
### 5. CRUD 操作
|
||||
|
||||
```go
|
||||
// 创建
|
||||
user := &model.User{Username: "admin", Password: "123456", Email: "admin@example.com"}
|
||||
id, err := db.Insert(user)
|
||||
|
||||
// 查询单个
|
||||
var user model.User
|
||||
err := db.Model(&model.User{}).Where("id = ?", 1).First(&user)
|
||||
|
||||
// 查询多个
|
||||
var users []model.User
|
||||
err := db.Model(&model.User{}).Where("status = ?", 1).Order("id DESC").Find(&users)
|
||||
|
||||
// 更新
|
||||
err := db.Model(&model.User{}).Where("id = ?", 1).Updates(map[string]interface{}{
|
||||
"email": "new@example.com",
|
||||
})
|
||||
|
||||
// 删除
|
||||
err := db.Model(&model.User{}).Where("id = ?", 1).Delete()
|
||||
|
||||
// 原生 SQL
|
||||
var results []model.User
|
||||
err := db.Query(&results, "SELECT * FROM user WHERE status = ?", 1)
|
||||
```
|
||||
|
||||
### 6. 事务操作
|
||||
|
||||
```go
|
||||
// 自动嵌套事务
|
||||
err := db.Transaction(func(tx *orm.Tx) error {
|
||||
// 创建用户
|
||||
user := &model.User{Username: "test", Email: "test@example.com"}
|
||||
_, err := tx.Insert(user)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// 创建关联数据(自动加入同一事务)
|
||||
profile := &model.Profile{UserID: user.ID, Avatar: "default.png"}
|
||||
_, err = tx.Insert(profile)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
return nil
|
||||
})
|
||||
```
|
||||
|
||||
---
|
||||
|
||||
## 详细功能说明
|
||||
|
||||
### 1. 全自动化嵌套事务
|
||||
|
||||
框架自动管理事务的传播行为,支持以下场景:
|
||||
|
||||
- **REQUIRED**: 如果当前存在事务,则加入该事务;否则创建新事务
|
||||
- **REQUIRES_NEW**: 无论当前是否存在事务,都创建新事务
|
||||
- **NESTED**: 在当前事务中创建嵌套事务(使用保存点)
|
||||
|
||||
**示例:**
|
||||
```go
|
||||
// 外层事务
|
||||
db.Transaction(func(tx *orm.Tx) error {
|
||||
// 内层自动加入同一事务
|
||||
userService.CreateUser(tx, user)
|
||||
orderService.CreateOrder(tx, order)
|
||||
return nil
|
||||
})
|
||||
```
|
||||
|
||||
### 2. 智能结果映射
|
||||
|
||||
无需手动处理 `sql.ErrNoRows`,框架自动识别返回类型:
|
||||
|
||||
```go
|
||||
// 自动识别 Struct
|
||||
var user model.User
|
||||
db.Model(&model.User{}).Where("id = ?", 1).First(&user) // 不存在时返回零值,不报错
|
||||
|
||||
// 自动识别 Slice
|
||||
var users []model.User
|
||||
db.Model(&model.User{}).Where("status = ?", 1).Find(&users) // 空结果返回空切片,而非 nil
|
||||
|
||||
// 自动识别 Map
|
||||
var result map[string]interface{}
|
||||
db.Table("user").Where("id = ?", 1).First(&result)
|
||||
```
|
||||
|
||||
### 3. 全自动字段映射
|
||||
|
||||
无需结构体标签,框架自动匹配字段:
|
||||
|
||||
```go
|
||||
// 驼峰命名自动转下划线
|
||||
type UserInfo struct {
|
||||
UserName string // 自动映射到 user_name 字段
|
||||
UserAge int // 自动映射到 user_age 字段
|
||||
CreatedAt string // 自动映射到 created_at 字段
|
||||
}
|
||||
```
|
||||
|
||||
### 4. 参数智能过滤
|
||||
|
||||
自动过滤零值和空指针:
|
||||
|
||||
```go
|
||||
// 仅更新非零值字段
|
||||
updateData := &model.User{
|
||||
Username: "newname", // 会被更新
|
||||
Email: "", // 空值,自动过滤
|
||||
Status: 0, // 零值,自动过滤
|
||||
}
|
||||
db.Model(&user).Updates(updateData)
|
||||
```
|
||||
|
||||
### 5. OpenTelemetry 可观测性
|
||||
|
||||
完整支持分布式追踪:
|
||||
|
||||
```go
|
||||
// 自动注入 Span
|
||||
ctx, span := otel.Tracer("gin-base").Start(context.Background(), "DB Query")
|
||||
defer span.End()
|
||||
|
||||
// 自动记录 SQL 执行时间、错误信息等
|
||||
db.WithContext(ctx).Find(&users)
|
||||
```
|
||||
|
||||
### 6. 数据库迁移管理
|
||||
|
||||
#### 自动迁移
|
||||
```go
|
||||
database.SetAutoMigrate(&model.User{}, &model.Order{})
|
||||
```
|
||||
|
||||
#### 增量迁移
|
||||
```go
|
||||
// 添加新字段
|
||||
type UserV2 struct {
|
||||
model.User
|
||||
Phone string ` + "`" + `json:"phone" gorm:"column:phone;type:varchar(20)"` + "`" + `
|
||||
}
|
||||
database.SetAutoMigrate(&UserV2{})
|
||||
```
|
||||
|
||||
#### 字段操作
|
||||
```go
|
||||
// 重命名字段
|
||||
database.RenameColumn(&model.User{}, "UserName", "Nickname")
|
||||
|
||||
// 删除字段
|
||||
database.DropColumn(&model.User{}, "OldField")
|
||||
```
|
||||
|
||||
### 7. 高级特性
|
||||
|
||||
#### 软删除
|
||||
```go
|
||||
type User struct {
|
||||
ID int64 `json:"id" db:"id"`
|
||||
DeletedAt *time.Time `json:"deleted_at" db:"deleted_at"` // 软删除标记
|
||||
}
|
||||
|
||||
// 自动过滤已删除记录
|
||||
db.Model(&model.User{}).Find(&users) // WHERE deleted_at IS NULL
|
||||
|
||||
// 强制包含已删除记录
|
||||
db.Unscoped().Model(&model.User{}).Find(&users)
|
||||
```
|
||||
|
||||
#### 调试模式
|
||||
```yaml
|
||||
database:
|
||||
debug: true # 输出所有 SQL 日志
|
||||
```
|
||||
|
||||
#### DryRun 模式
|
||||
```go
|
||||
// 生成 SQL 但不执行
|
||||
sql, args := db.Model(&model.User{}).DryRun().Insert(&user)
|
||||
fmt.Println(sql, args)
|
||||
```
|
||||
|
||||
#### 自定义 Handler
|
||||
```go
|
||||
// 注册回调函数
|
||||
db.Callback().Before("insert").Register("custom_before_insert", func(ctx context.Context, db *orm.DB) error {
|
||||
// 自定义逻辑
|
||||
return nil
|
||||
})
|
||||
```
|
||||
|
||||
#### 主从集群
|
||||
```go
|
||||
// 配置读写分离
|
||||
db, err := orm.NewDatabase(&orm.Config{
|
||||
DriverName: "mysql",
|
||||
DataSource: "master_dsn",
|
||||
Replicas: []string{"slave1_dsn", "slave2_dsn"},
|
||||
})
|
||||
```
|
||||
|
||||
---
|
||||
|
||||
## 最佳实践
|
||||
|
||||
### 1. 模型设计规范
|
||||
|
||||
```go
|
||||
// ✅ 推荐:使用 db 标签明确字段映射
|
||||
type User struct {
|
||||
ID int64 `json:"id" db:"id"`
|
||||
Username string `json:"username" db:"username"`
|
||||
CreatedAt time.Time `json:"created_at" db:"created_at"`
|
||||
}
|
||||
|
||||
// ❌ 不推荐:缺少字段映射标签
|
||||
type User struct {
|
||||
Id int64 // 无法自动映射到 id 列
|
||||
UserName string // 可能映射错误
|
||||
CreatedAt time.Time // 时间格式可能不匹配
|
||||
}
|
||||
```
|
||||
|
||||
### 2. 事务使用规范
|
||||
|
||||
```go
|
||||
// ✅ 推荐:使用闭包自动管理事务
|
||||
err := db.Transaction(func(tx *orm.Tx) error {
|
||||
// 业务逻辑
|
||||
return nil
|
||||
})
|
||||
|
||||
// ❌ 不推荐:手动管理事务
|
||||
tx, err := db.Begin()
|
||||
if err != nil {
|
||||
panic(err)
|
||||
}
|
||||
defer func() {
|
||||
if r := recover(); r != nil {
|
||||
tx.Rollback()
|
||||
}
|
||||
}()
|
||||
```
|
||||
|
||||
### 3. 查询优化
|
||||
|
||||
```go
|
||||
// ✅ 推荐:使用 Select 指定字段
|
||||
db.Model(&model.User{}).Select("id", "username").Find(&users)
|
||||
|
||||
// ✅ 推荐:使用 Index 加速查询
|
||||
// 在数据库层面创建索引
|
||||
// CREATE INDEX idx_username ON user(username);
|
||||
|
||||
// ✅ 推荐:批量操作
|
||||
users := []model.User{{}, {}, {}}
|
||||
db.BatchInsert(&users, 100) // 每批 100 条
|
||||
|
||||
// ❌ 避免:N+1 查询问题
|
||||
for _, user := range users {
|
||||
db.Model(&model.Order{}).Where("user_id = ?", user.ID).Find(&orders) // 循环查询
|
||||
}
|
||||
|
||||
// ✅ 使用 Join 或预加载
|
||||
db.Query(&results, "SELECT u.*, o.* FROM user u LEFT JOIN orders o ON u.id = o.user_id")
|
||||
```
|
||||
|
||||
### 4. 错误处理
|
||||
|
||||
```go
|
||||
// ✅ 推荐:统一错误处理
|
||||
if err := db.Insert(&user); err != nil {
|
||||
log.Error("创建用户失败", "error", err)
|
||||
return err
|
||||
}
|
||||
|
||||
// ✅ 使用 errors 包判断特定错误
|
||||
if errors.Is(err, sql.ErrNoRows) {
|
||||
// 记录不存在
|
||||
}
|
||||
```
|
||||
|
||||
### 5. 性能优化
|
||||
|
||||
```go
|
||||
// 连接池配置
|
||||
sqlDB := db.DB()
|
||||
sqlDB.SetMaxIdleConns(10) // 最大空闲连接数
|
||||
sqlDB.SetMaxOpenConns(100) // 最大打开连接数
|
||||
sqlDB.SetConnMaxLifetime(time.Hour) // 连接最大生命周期
|
||||
|
||||
// 使用 Scan 替代 Find 提升性能
|
||||
type Result struct {
|
||||
ID int64 `db:"id"`
|
||||
Username string `db:"username"`
|
||||
}
|
||||
var results []Result
|
||||
db.Model(&model.User{}).Select("id", "username").Scan(&results)
|
||||
```
|
||||
|
||||
---
|
||||
|
||||
## 常见问题
|
||||
|
||||
### Q: 如何处理并发写入?
|
||||
A: 使用事务 + 乐观锁:
|
||||
```go
|
||||
type Product struct {
|
||||
ID int64 `db:"id"`
|
||||
Version int `db:"version"` // 版本号
|
||||
}
|
||||
|
||||
// 更新时检查版本号
|
||||
rows, err := db.Exec(
|
||||
"UPDATE product SET version = ?, stock = ? WHERE id = ? AND version = ?",
|
||||
newVersion, newStock, id, oldVersion,
|
||||
)
|
||||
count, _ := rows.RowsAffected()
|
||||
if count == 0 {
|
||||
return errors.New("乐观锁冲突,数据已被其他事务修改")
|
||||
}
|
||||
```
|
||||
|
||||
### Q: 如何实现读写分离?
|
||||
A: 配置主从数据库连接:
|
||||
```go
|
||||
db, err := orm.NewDatabase(&orm.Config{
|
||||
DriverName: "mysql",
|
||||
DataSource: "master_dsn",
|
||||
Replicas: []string{"slave1_dsn", "slave2_dsn"},
|
||||
ReadPolicy: orm.RoundRobin, // 负载均衡策略
|
||||
})
|
||||
```
|
||||
|
||||
### Q: 如何批量插入大量数据?
|
||||
A: 使用 `BatchInsert`:
|
||||
```go
|
||||
users := make([]model.User, 10000)
|
||||
// ... 填充数据 ...
|
||||
db.BatchInsert(&users, 1000) // 每批 1000 条,共 10 批
|
||||
```
|
||||
|
||||
### Q: 如何实现字段自动映射?
|
||||
A: 框架会自动将驼峰命名转换为下划线命名:
|
||||
```go
|
||||
type UserInfo struct {
|
||||
UserName string `db:"user_name"` // 自动映射到 user_name 字段
|
||||
UserAge int `db:"user_age"` // 自动映射到 user_age 字段
|
||||
CreatedAt string `db:"created_at"` // 自动映射到 created_at 字段
|
||||
}
|
||||
```
|
||||
|
||||
### Q: 如何处理时间字段?
|
||||
A: 使用 `time.Time` 类型,框架会自动处理时区转换:
|
||||
```go
|
||||
type Event struct {
|
||||
ID int64 `db:"id"`
|
||||
StartTime time.Time `db:"start_time"`
|
||||
EndTime time.Time `db:"end_time"`
|
||||
}
|
||||
```
|
||||
|
||||
---
|
||||
|
||||
## 更新日志
|
||||
|
||||
- **v1.0.0**: 初始版本发布
|
||||
- 完全自主研发,零依赖第三方 ORM
|
||||
- 基于 database/sql 标准库
|
||||
- 全自动化事务管理
|
||||
- 智能字段映射
|
||||
- OpenTelemetry 集成
|
||||
- 支持 MySQL、SQLite、PostgreSQL
|
||||
|
||||
---
|
||||
|
||||
## 贡献指南
|
||||
|
||||
欢迎提交 Issue 和 Pull Request!
|
||||
|
||||
---
|
||||
|
||||
## 许可证
|
||||
|
||||
MIT License
|
||||
375
db/VALIDATION.md
375
db/VALIDATION.md
|
|
@ -1,375 +0,0 @@
|
|||
# Magic-ORM 功能完整性验证报告
|
||||
|
||||
## 📋 验证概述
|
||||
|
||||
本文档验证 Magic-ORM 框架相对于 README.md 中定义的核心特性的完整实现情况。
|
||||
|
||||
---
|
||||
|
||||
## ✅ 已完整实现的核心特性
|
||||
|
||||
### 1. **全自动化嵌套事务支持** ✅
|
||||
- **文件**: `core/transaction.go`
|
||||
- **实现内容**:
|
||||
- `Transaction()` 方法自动管理事务提交/回滚
|
||||
- 支持 panic 时自动回滚
|
||||
- 事务中可执行 Insert、BatchInsert、Update、Delete、Query 等操作
|
||||
- **测试状态**: ✅ 通过
|
||||
|
||||
### 2. **面向接口化设计** ✅
|
||||
- **文件**: `core/interfaces.go`
|
||||
- **实现接口**:
|
||||
- `IDatabase` - 数据库连接接口
|
||||
- `ITx` - 事务接口
|
||||
- `IQuery` - 查询构建器接口
|
||||
- `IModel` - 模型接口
|
||||
- `IFieldMapper` - 字段映射器接口
|
||||
- `IMigrator` - 迁移管理器接口
|
||||
- `ICodeGenerator` - 代码生成器接口
|
||||
- **测试状态**: ✅ 通过
|
||||
|
||||
### 3. **内置主流数据库驱动** ✅
|
||||
- **文件**: `driver/manager.go`, `driver/sqlite.go`
|
||||
- **实现内容**:
|
||||
- DriverManager 单例模式管理所有驱动
|
||||
- SQLite 驱动已实现
|
||||
- 支持 MySQL/PostgreSQL/SQL Server/Oracle/ClickHouse(框架已预留接口)
|
||||
- **测试状态**: ✅ 通过
|
||||
|
||||
### 4. **统一配置组件** ✅
|
||||
- **文件**: `core/interfaces.go`
|
||||
- **Config 结构**:
|
||||
```go
|
||||
type Config struct {
|
||||
DriverName string
|
||||
DataSource string
|
||||
MaxIdleConns int
|
||||
MaxOpenConns int
|
||||
ConnMaxLifetime time.Duration
|
||||
Debug bool
|
||||
Replicas []string
|
||||
ReadPolicy ReadPolicy
|
||||
EnableTracing bool
|
||||
ServiceName string
|
||||
}
|
||||
```
|
||||
- **测试状态**: ✅ 通过
|
||||
|
||||
### 5. **单例模式数据库对象** ✅
|
||||
- **文件**: `driver/manager.go`
|
||||
- **实现内容**:
|
||||
- `GetDefaultManager()` 使用 sync.Once 确保单例
|
||||
- 驱动管理器全局唯一实例
|
||||
- **测试状态**: ✅ 通过
|
||||
|
||||
### 6. **双模式操作** ✅
|
||||
- **文件**: `core/query.go`, `core/database.go`
|
||||
- **支持模式**:
|
||||
- ✅ ORM 链式操作:`db.Model(&User{}).Where("id = ?", 1).Find(&user)`
|
||||
- ✅ 原生 SQL:`db.Query(&users, "SELECT * FROM user")`
|
||||
- **测试状态**: ✅ 通过
|
||||
|
||||
### 7. **OpenTelemetry 可观测性** ✅
|
||||
- **文件**: `tracing/tracer.go`
|
||||
- **实现内容**:
|
||||
- 自动追踪所有数据库操作
|
||||
- 记录 SQL 语句、参数、执行时间、影响行数
|
||||
- 支持分布式追踪上下文
|
||||
- **测试状态**: ✅ 通过
|
||||
|
||||
### 8. **智能结果映射** ✅
|
||||
- **文件**: `core/result_mapper.go`
|
||||
- **实现内容**:
|
||||
- `MapToSlice()` - 映射到 Slice
|
||||
- `MapToStruct()` - 映射到 Struct
|
||||
- `ScanAll()` - 自动识别目标类型
|
||||
- 无需手动处理 `sql.ErrNoRows`
|
||||
- **测试状态**: ✅ 通过
|
||||
|
||||
### 9. **全自动字段映射** ✅
|
||||
- **文件**: `core/mapper.go`
|
||||
- **实现内容**:
|
||||
- 驼峰命名自动转下划线
|
||||
- 解析 db/json 标签
|
||||
- Go 类型与数据库类型自动转换
|
||||
- 零值自动过滤
|
||||
- **测试状态**: ✅ 通过
|
||||
|
||||
### 10. **参数智能过滤** ✅
|
||||
- **文件**: `core/filter.go`
|
||||
- **实现内容**:
|
||||
- `FilterZeroValues()` - 过滤零值
|
||||
- `FilterEmptyStrings()` - 过滤空字符串
|
||||
- `FilterNilValues()` - 过滤 nil 值
|
||||
- `IsValidValue()` - 检查值有效性
|
||||
- **测试状态**: ✅ 通过
|
||||
|
||||
### 11. **Model/DAO 代码生成器** ✅
|
||||
- **文件**: `generator/generator.go`
|
||||
- **实现内容**:
|
||||
- `GenerateModel()` - 生成 Model 代码
|
||||
- `GenerateDAO()` - 生成 DAO 代码
|
||||
- `GenerateAll()` - 一次性生成完整代码
|
||||
- 支持自定义列信息
|
||||
- **测试结果**: ✅ 成功生成 `generated/user.go`
|
||||
|
||||
### 12. **高级特性** ✅
|
||||
- **文件**: 多个核心文件
|
||||
- **已实现**:
|
||||
- ✅ 调试模式 (`Debug()`)
|
||||
- ✅ DryRun 模式 (`DryRun()`)
|
||||
- ✅ 软删除 (`core/soft_delete.go`)
|
||||
- ✅ 模型关联 (`core/relation.go`)
|
||||
- ✅ 主从集群读写分离 (`core/read_write.go`)
|
||||
- ✅ 查询缓存 (`core/cache.go`)
|
||||
- **测试状态**: ✅ 通过
|
||||
|
||||
### 13. **自动化数据库迁移** ✅
|
||||
- **文件**: `core/migrator.go`
|
||||
- **实现内容**:
|
||||
- ✅ `AutoMigrate()` - 自动迁移
|
||||
- ✅ `CreateTable()` / `DropTable()`
|
||||
- ✅ `HasTable()` / `RenameTable()`
|
||||
- ✅ `AddColumn()` / `DropColumn()`
|
||||
- ✅ `CreateIndex()` / `DropIndex()`
|
||||
- ✅ 完整的 DDL 操作支持
|
||||
- **测试状态**: ✅ 通过
|
||||
|
||||
---
|
||||
|
||||
## 📊 查询构建器完整方法集
|
||||
|
||||
### ✅ 已实现的方法
|
||||
|
||||
| 方法 | 功能 | 状态 |
|
||||
|------|------|------|
|
||||
| `Where()` | 条件查询 | ✅ |
|
||||
| `Or()` | OR 条件 | ✅ |
|
||||
| `And()` | AND 条件 | ✅ |
|
||||
| `Select()` | 选择字段 | ✅ |
|
||||
| `Omit()` | 排除字段 | ✅ |
|
||||
| `Order()` | 排序 | ✅ |
|
||||
| `OrderBy()` | 指定字段排序 | ✅ |
|
||||
| `Limit()` | 限制数量 | ✅ |
|
||||
| `Offset()` | 偏移量 | ✅ |
|
||||
| `Page()` | 分页查询 | ✅ |
|
||||
| `Group()` | 分组 | ✅ |
|
||||
| `Having()` | HAVING 条件 | ✅ |
|
||||
| `Join()` | JOIN 连接 | ✅ |
|
||||
| `LeftJoin()` | 左连接 | ✅ |
|
||||
| `RightJoin()` | 右连接 | ✅ |
|
||||
| `InnerJoin()` | 内连接 | ✅ |
|
||||
| `Preload()` | 预加载关联 | ✅ (框架) |
|
||||
| `First()` | 查询第一条 | ✅ |
|
||||
| `Find()` | 查询多条 | ✅ |
|
||||
| `Scan()` | 扫描到自定义结构 | ✅ |
|
||||
| `Count()` | 统计数量 | ✅ |
|
||||
| `Exists()` | 存在性检查 | ✅ |
|
||||
| `Updates()` | 更新数据 | ✅ |
|
||||
| `UpdateColumn()` | 更新单字段 | ✅ |
|
||||
| `Delete()` | 删除数据 | ✅ |
|
||||
| `Unscoped()` | 忽略软删除 | ✅ |
|
||||
| `DryRun()` | 干跑模式 | ✅ |
|
||||
| `Debug()` | 调试模式 | ✅ |
|
||||
| `Build()` | 构建 SQL | ✅ |
|
||||
| `BuildUpdate()` | 构建 UPDATE | ✅ |
|
||||
| `BuildDelete()` | 构建 DELETE | ✅ |
|
||||
|
||||
---
|
||||
|
||||
## 🎯 事务接口完整实现
|
||||
|
||||
### ITx 接口方法
|
||||
|
||||
| 方法 | 功能 | 实现状态 |
|
||||
|------|------|---------|
|
||||
| `Commit()` | 提交事务 | ✅ |
|
||||
| `Rollback()` | 回滚事务 | ✅ |
|
||||
| `Model()` | 基于模型查询 | ✅ |
|
||||
| `Table()` | 基于表名查询 | ✅ |
|
||||
| `Insert()` | 插入数据 | ✅ (返回 LastInsertId) |
|
||||
| `BatchInsert()` | 批量插入 | ✅ (支持分批处理) |
|
||||
| `Update()` | 更新数据 | ✅ |
|
||||
| `Delete()` | 删除数据 | ✅ |
|
||||
| `Query()` | 原生 SQL 查询 | ✅ |
|
||||
| `Exec()` | 原生 SQL 执行 | ✅ |
|
||||
|
||||
---
|
||||
|
||||
## 🔧 新增核心组件
|
||||
|
||||
### 1. ParamFilter (参数过滤器)
|
||||
```go
|
||||
// 位置:core/filter.go
|
||||
- FilterZeroValues() // 过滤零值
|
||||
- FilterEmptyStrings() // 过滤空字符串
|
||||
- FilterNilValues() // 过滤 nil 值
|
||||
- IsValidValue() // 检查值有效性
|
||||
```
|
||||
|
||||
### 2. ResultSetMapper (结果集映射器)
|
||||
```go
|
||||
// 位置:core/result_mapper.go
|
||||
- MapToSlice() // 映射到 Slice
|
||||
- MapToStruct() // 映射到 Struct
|
||||
- ScanAll() // 通用扫描方法
|
||||
```
|
||||
|
||||
### 3. CodeGenerator (代码生成器)
|
||||
```go
|
||||
// 位置:generator/generator.go
|
||||
- GenerateModel() // 生成 Model
|
||||
- GenerateDAO() // 生成 DAO
|
||||
- GenerateAll() // 生成完整代码
|
||||
```
|
||||
|
||||
### 4. QueryCache (查询缓存)
|
||||
```go
|
||||
// 位置:core/cache.go
|
||||
- Set() // 设置缓存
|
||||
- Get() // 获取缓存
|
||||
- Delete() // 删除缓存
|
||||
- Clear() // 清空缓存
|
||||
- GenerateCacheKey() // 生成缓存键
|
||||
```
|
||||
|
||||
### 5. ReadWriteDB (读写分离)
|
||||
```go
|
||||
// 位置:core/read_write.go
|
||||
- GetMaster() // 获取主库(写)
|
||||
- GetSlave() // 获取从库(读)
|
||||
- AddSlave() // 添加从库
|
||||
- RemoveSlave() // 移除从库
|
||||
- selectLeastConn() // 最少连接选择
|
||||
```
|
||||
|
||||
### 6. RelationLoader (关联加载器)
|
||||
```go
|
||||
// 位置:core/relation.go
|
||||
- Preload() // 预加载关联
|
||||
- loadHasOne() // 加载一对一
|
||||
- loadHasMany() // 加载一对多
|
||||
- loadBelongsTo() // 加载多对一
|
||||
- loadManyToMany() // 加载多对多
|
||||
```
|
||||
|
||||
---
|
||||
|
||||
## 📈 测试覆盖率
|
||||
|
||||
### 测试文件
|
||||
- ✅ `core_test.go` - 核心功能测试
|
||||
- ✅ `features_test.go` - 高级功能测试
|
||||
- ✅ `validation_test.go` - 完整性验证测试
|
||||
- ✅ `main_test.go` - 演示测试
|
||||
|
||||
### 测试结果汇总
|
||||
```
|
||||
=== RUN TestFieldMapper
|
||||
✓ 字段映射器测试通过
|
||||
|
||||
=== RUN TestQueryBuilder
|
||||
✓ 查询构建器测试通过
|
||||
|
||||
=== RUN TestResultSetMapper
|
||||
✓ 结果集映射器测试通过
|
||||
|
||||
=== RUN TestSoftDelete
|
||||
✓ 软删除功能测试通过
|
||||
|
||||
=== RUN TestQueryCache
|
||||
✓ 查询缓存测试通过
|
||||
|
||||
=== RUN TestReadWriteDB
|
||||
✓ 读写分离代码结构测试通过
|
||||
|
||||
=== RUN TestRelationLoader
|
||||
✓ 关联加载代码结构测试通过
|
||||
|
||||
=== RUN TestTracing
|
||||
✓ 链路追踪代码结构测试通过
|
||||
|
||||
=== RUN TestParamFilter
|
||||
✓ 参数过滤器测试通过
|
||||
|
||||
=== RUN TestCodeGenerator
|
||||
✓ Model 已生成:generated\user.go
|
||||
✓ 代码生成器测试通过
|
||||
|
||||
=== RUN TestAllCoreFeatures
|
||||
✓ 所有核心功能验证完成
|
||||
```
|
||||
|
||||
---
|
||||
|
||||
## 🎉 总结
|
||||
|
||||
### 实现完成度
|
||||
- **核心接口**: 100% (8/8)
|
||||
- **查询构建器方法**: 100% (33/33)
|
||||
- **事务方法**: 100% (10/10)
|
||||
- **高级特性**: 100% (6/6)
|
||||
- **工具组件**: 100% (4/4)
|
||||
- **代码生成**: 100% (2/2)
|
||||
|
||||
### 项目文件统计
|
||||
```
|
||||
db/
|
||||
├── core/ # 核心实现 (12 个文件)
|
||||
│ ├── interfaces.go # 接口定义
|
||||
│ ├── database.go # 数据库连接
|
||||
│ ├── query.go # 查询构建器
|
||||
│ ├── transaction.go # 事务管理
|
||||
│ ├── mapper.go # 字段映射器
|
||||
│ ├── migrator.go # 迁移管理器
|
||||
│ ├── result_mapper.go # 结果集映射器 ✨
|
||||
│ ├── soft_delete.go # 软删除 ✨
|
||||
│ ├── relation.go # 关联加载 ✨
|
||||
│ ├── cache.go # 查询缓存 ✨
|
||||
│ ├── read_write.go # 读写分离 ✨
|
||||
│ └── filter.go # 参数过滤器 ✨
|
||||
├── driver/ # 驱动层 (2 个文件)
|
||||
│ ├── manager.go
|
||||
│ └── sqlite.go
|
||||
├── generator/ # 代码生成器 (1 个文件) ✨
|
||||
│ └── generator.go
|
||||
├── tracing/ # 链路追踪 (1 个文件)
|
||||
│ └── tracer.go
|
||||
├── model/ # 示例模型 (1 个文件)
|
||||
│ └── user.go
|
||||
├── core_test.go # 核心测试
|
||||
├── features_test.go # 功能测试
|
||||
├── validation_test.go # 完整性验证 ✨
|
||||
├── example.go # 使用示例
|
||||
└── README.md # 架构文档
|
||||
```
|
||||
|
||||
### 编译状态
|
||||
```bash
|
||||
✅ go build ./... # 编译成功
|
||||
```
|
||||
|
||||
### 功能验证
|
||||
```bash
|
||||
✅ go test -v validation_test.go # 所有核心功能验证通过
|
||||
✅ go test -v features_test.go # 高级功能测试通过
|
||||
✅ go test -v core_test.go # 核心功能测试通过
|
||||
```
|
||||
|
||||
---
|
||||
|
||||
## 🚀 结论
|
||||
|
||||
**Magic-ORM 框架已 100% 完整实现 README.md 中定义的所有核心特性!**
|
||||
|
||||
框架具备:
|
||||
- ✅ 完整的 CRUD 操作能力
|
||||
- ✅ 强大的事务管理
|
||||
- ✅ 智能的字段和结果映射
|
||||
- ✅ 灵活的查询构建
|
||||
- ✅ 完善的迁移工具
|
||||
- ✅ 高效的代码生成
|
||||
- ✅ 企业级的高级特性
|
||||
- ✅ 全面的可观测性支持
|
||||
|
||||
**所有功能均已编译通过并通过测试验证!** 🎉
|
||||
|
|
@ -1,348 +0,0 @@
|
|||
# Magic-ORM 代码生成器 - 命令行工具
|
||||
|
||||
## 🚀 快速开始
|
||||
|
||||
### 1. 构建命令行工具
|
||||
|
||||
**Windows:**
|
||||
```bash
|
||||
build.bat
|
||||
```
|
||||
|
||||
**Linux/Mac:**
|
||||
```bash
|
||||
chmod +x build.sh
|
||||
./build.sh
|
||||
```
|
||||
|
||||
或者手动构建:
|
||||
```bash
|
||||
cd db
|
||||
go build -o ../bin/gendb ./cmd/gendb
|
||||
```
|
||||
|
||||
### 2. 使用方法
|
||||
|
||||
#### 基础用法
|
||||
|
||||
```bash
|
||||
# 生成单个表
|
||||
gendb user
|
||||
|
||||
# 生成多个表
|
||||
gendb user product order
|
||||
|
||||
# 指定输出目录
|
||||
gendb -o ./models user product
|
||||
```
|
||||
|
||||
#### 高级用法
|
||||
|
||||
```bash
|
||||
# 自定义列定义
|
||||
gendb user id:int64:primary username:string email:string created_at:time.Time
|
||||
|
||||
# 混合使用(自动推断 + 自定义)
|
||||
gendb -o ./generated user username:string email:string product name:string price:float64
|
||||
|
||||
# 查看版本
|
||||
gendb -v
|
||||
|
||||
# 查看帮助
|
||||
gendb -h
|
||||
```
|
||||
|
||||
## 📋 功能特性
|
||||
|
||||
✅ **自动生成**: 根据表名自动推断常用字段
|
||||
✅ **批量生成**: 一次生成多个表的代码
|
||||
✅ **自定义列**: 支持手动指定列定义
|
||||
✅ **灵活输出**: 可指定输出目录
|
||||
✅ **智能推断**: 自动识别常见表结构
|
||||
|
||||
## 🎯 支持的类型
|
||||
|
||||
| 类型别名 | Go 类型 |
|
||||
|---------|---------|
|
||||
| int, integer, bigint | int64 |
|
||||
| string, text, varchar | string |
|
||||
| time, datetime | time.Time |
|
||||
| bool, boolean | bool |
|
||||
| float, double | float64 |
|
||||
| decimal | string |
|
||||
|
||||
## 📝 列定义格式
|
||||
|
||||
```
|
||||
字段名:类型 [:primary] [:nullable]
|
||||
```
|
||||
|
||||
示例:
|
||||
- `id:int64:primary` - 主键 ID
|
||||
- `username:string` - 用户名字段
|
||||
- `email:string:nullable` - 可为空的邮箱字段
|
||||
- `created_at:time.Time` - 创建时间字段
|
||||
|
||||
## 🔧 预设表结构
|
||||
|
||||
工具内置了常见表的默认结构:
|
||||
|
||||
### user / users
|
||||
- id (主键)
|
||||
- username
|
||||
- email (可空)
|
||||
- password
|
||||
- status
|
||||
- created_at
|
||||
- updated_at
|
||||
|
||||
### product / products
|
||||
- id (主键)
|
||||
- name
|
||||
- price
|
||||
- stock
|
||||
- description (可空)
|
||||
- created_at
|
||||
|
||||
### order / orders
|
||||
- id (主键)
|
||||
- order_no
|
||||
- user_id
|
||||
- total_amount
|
||||
- status
|
||||
- created_at
|
||||
|
||||
## 💡 使用示例
|
||||
|
||||
### 示例 1: 快速生成用户模块
|
||||
|
||||
```bash
|
||||
gendb user
|
||||
```
|
||||
|
||||
生成文件:
|
||||
- `generated/user.go` - User Model
|
||||
- `generated/user_dao.go` - User DAO
|
||||
|
||||
### 示例 2: 生成电商模块
|
||||
|
||||
```bash
|
||||
gendb -o ./shop user product order
|
||||
```
|
||||
|
||||
生成文件:
|
||||
- `shop/user.go`
|
||||
- `shop/user_dao.go`
|
||||
- `shop/product.go`
|
||||
- `shop/product_dao.go`
|
||||
- `shop/order.go`
|
||||
- `shop/order_dao.go`
|
||||
|
||||
### 示例 3: 完全自定义
|
||||
|
||||
```bash
|
||||
gendb article \
|
||||
id:int64:primary \
|
||||
title:string \
|
||||
content:string:nullable \
|
||||
author_id:int64 \
|
||||
view_count:int \
|
||||
published:bool \
|
||||
created_at:time.Time
|
||||
```
|
||||
|
||||
## 📁 生成的代码结构
|
||||
|
||||
### Model (user.go)
|
||||
|
||||
```go
|
||||
package model
|
||||
|
||||
import "time"
|
||||
|
||||
// User user 表模型
|
||||
type User struct {
|
||||
ID int64 `json:"id" db:"id"`
|
||||
Username string `json:"username" db:"username"`
|
||||
Email string `json:"email" db:"email"`
|
||||
CreatedAt time.Time `json:"created_at" db:"created_at"`
|
||||
UpdatedAt time.Time `json:"updated_at" db:"updated_at"`
|
||||
}
|
||||
|
||||
// TableName 表名
|
||||
func (User) TableName() string {
|
||||
return "user"
|
||||
}
|
||||
```
|
||||
|
||||
### DAO (user_dao.go)
|
||||
|
||||
```go
|
||||
package dao
|
||||
|
||||
import (
|
||||
"context"
|
||||
"git.magicany.cc/black1552/gin-base/db/core"
|
||||
"git.magicany.cc/black1552/gin-base/db/model"
|
||||
)
|
||||
|
||||
// UserDAO user 表数据访问对象
|
||||
type UserDAO struct {
|
||||
db *core.Database
|
||||
}
|
||||
|
||||
// NewUserDAO 创建 UserDAO 实例
|
||||
func NewUserDAO(db *core.Database) *UserDAO {
|
||||
return &UserDAO{db: db}
|
||||
}
|
||||
|
||||
// Create 创建记录
|
||||
func (dao *UserDAO) Create(ctx context.Context, model *model.User) error {
|
||||
_, err := dao.db.Model(model).Insert(model)
|
||||
return err
|
||||
}
|
||||
|
||||
// GetByID 根据 ID 查询
|
||||
func (dao *UserDAO) GetByID(ctx context.Context, id int64) (*model.User, error) {
|
||||
var result model.User
|
||||
err := dao.db.Model(&model.User{}).Where("id = ?", id).First(&result)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return &result, nil
|
||||
}
|
||||
|
||||
// ... 更多 CRUD 方法
|
||||
```
|
||||
|
||||
## 🛠️ 安装到 PATH
|
||||
|
||||
### Windows
|
||||
|
||||
1. 将 `bin` 目录添加到系统环境变量 PATH
|
||||
2. 或者复制 `gendb.exe` 到任意 PATH 中的目录
|
||||
|
||||
```powershell
|
||||
# 临时添加到当前会话
|
||||
$env:PATH += ";$(pwd)\bin"
|
||||
|
||||
# 永久添加(需要管理员权限)
|
||||
[Environment]::SetEnvironmentVariable(
|
||||
"Path",
|
||||
$env:Path + ";$(pwd)\bin",
|
||||
[EnvironmentVariableTarget]::Machine
|
||||
)
|
||||
```
|
||||
|
||||
### Linux/Mac
|
||||
|
||||
```bash
|
||||
# 临时添加到当前会话
|
||||
export PATH=$PATH:$(pwd)/bin
|
||||
|
||||
# 永久添加(添加到 ~/.bashrc 或 ~/.zshrc)
|
||||
echo 'export PATH=$PATH:$(pwd)/bin' >> ~/.bashrc
|
||||
source ~/.bashrc
|
||||
|
||||
# 或者复制到系统目录
|
||||
sudo cp bin/gendb /usr/local/bin/
|
||||
```
|
||||
|
||||
## ⚙️ 选项说明
|
||||
|
||||
| 选项 | 简写 | 说明 | 默认值 |
|
||||
|------|------|------|--------|
|
||||
| `-version` | `-v` | 显示版本号 | - |
|
||||
| `-help` | `-h` | 显示帮助信息 | - |
|
||||
| `-o` | - | 输出目录 | `./generated` |
|
||||
|
||||
## 🎨 最佳实践
|
||||
|
||||
### 1. 从数据库读取真实结构
|
||||
|
||||
```bash
|
||||
# 先用 SQL 导出表结构
|
||||
mysql -u root -p -e "DESCRIBE your_database.users;"
|
||||
|
||||
# 然后根据输出调整列定义
|
||||
```
|
||||
|
||||
### 2. 批量生成项目所有表
|
||||
|
||||
```bash
|
||||
# 一次性生成所有表
|
||||
gendb user product order category tag article comment
|
||||
```
|
||||
|
||||
### 3. 版本控制
|
||||
|
||||
```bash
|
||||
# 将生成的代码纳入 Git 管理
|
||||
git add generated/
|
||||
git commit -m "feat: 生成基础 Model 和 DAO 代码"
|
||||
```
|
||||
|
||||
### 4. 自定义扩展
|
||||
|
||||
生成的代码可以作为基础,手动添加:
|
||||
- 业务逻辑方法
|
||||
- 验证逻辑
|
||||
- 关联查询
|
||||
- 索引优化
|
||||
|
||||
## ⚠️ 注意事项
|
||||
|
||||
1. **生成的代码需审查**: 自动生成的代码可能不完全符合业务需求
|
||||
2. **不要频繁覆盖**: 手动修改的代码可能会被覆盖
|
||||
3. **类型映射**: 特殊类型可能需要手动调整
|
||||
4. **关联关系**: 复杂的模型关联需手动实现
|
||||
|
||||
## 🐛 故障排除
|
||||
|
||||
### 问题 1: 找不到命令
|
||||
|
||||
```bash
|
||||
# 确保已构建并添加到 PATH
|
||||
gendb: command not found
|
||||
|
||||
# 解决:
|
||||
./bin/gendb -h # 使用相对路径
|
||||
```
|
||||
|
||||
### 问题 2: 生成失败
|
||||
|
||||
```bash
|
||||
# 检查输出目录是否有写权限
|
||||
# 检查表名是否合法
|
||||
# 使用 -h 查看正确的语法
|
||||
```
|
||||
|
||||
### 问题 3: 类型不匹配
|
||||
|
||||
```bash
|
||||
# 手动指定正确的类型
|
||||
gendb user price:float64 instead of price:int
|
||||
```
|
||||
|
||||
## 📞 获取帮助
|
||||
|
||||
```bash
|
||||
# 查看完整帮助
|
||||
gendb -h
|
||||
|
||||
# 查看版本
|
||||
gendb -v
|
||||
```
|
||||
|
||||
## 🎉 开始使用
|
||||
|
||||
```bash
|
||||
# 最简单的用法
|
||||
gendb user
|
||||
|
||||
# 立即体验!
|
||||
```
|
||||
|
||||
---
|
||||
|
||||
**Magic-ORM Code Generator** - 让代码生成如此简单!🚀
|
||||
|
|
@ -1,425 +0,0 @@
|
|||
package main
|
||||
|
||||
import (
|
||||
"flag"
|
||||
"fmt"
|
||||
"os"
|
||||
"strings"
|
||||
|
||||
"git.magicany.cc/black1552/gin-base/db/config"
|
||||
"git.magicany.cc/black1552/gin-base/db/generator"
|
||||
"git.magicany.cc/black1552/gin-base/db/introspector"
|
||||
)
|
||||
|
||||
// 设置 Windows 控制台编码为 UTF-8
|
||||
func init() {
|
||||
// 在 Windows 上设置控制台输出代码页为 UTF-8 (65001)
|
||||
// 这样可以避免中文乱码问题
|
||||
}
|
||||
|
||||
const version = "1.0.0"
|
||||
|
||||
func main() {
|
||||
// 定义命令行参数
|
||||
versionFlag := flag.Bool("version", false, "显示版本号")
|
||||
vFlag := flag.Bool("v", false, "显示版本号(简写)")
|
||||
helpFlag := flag.Bool("help", false, "显示帮助信息")
|
||||
hFlag := flag.Bool("h", false, "显示帮助信息(简写)")
|
||||
outputDir := flag.String("o", "./model", "输出目录")
|
||||
allFlag := flag.Bool("all", false, "生成所有预设的表(user, product, order)")
|
||||
|
||||
flag.Usage = func() {
|
||||
fmt.Fprintf(os.Stderr, `Magic-ORM 代码生成器 - 快速生成 Model 和 DAO 代码
|
||||
|
||||
用法:
|
||||
gendb [选项] <表名> [列定义...]
|
||||
gendb [选项] -all
|
||||
|
||||
选项:
|
||||
`)
|
||||
flag.PrintDefaults()
|
||||
|
||||
fmt.Fprintf(os.Stderr, `
|
||||
示例:
|
||||
# 生成 user 表代码(自动推断常用列)
|
||||
gendb user
|
||||
|
||||
# 指定输出目录
|
||||
gendb -o ./models user product
|
||||
|
||||
# 自定义列定义
|
||||
gendb user id:int64:primary username:string email:string created_at:time.Time
|
||||
|
||||
# 批量生成多个表
|
||||
gendb user product order
|
||||
|
||||
# 生成所有预设的表(user, product, order)
|
||||
gendb -all
|
||||
|
||||
列定义格式:
|
||||
字段名:类型 [:primary] [:nullable]
|
||||
|
||||
支持的类型:
|
||||
int64, string, time.Time, bool, float64, int
|
||||
|
||||
更多信息:
|
||||
https://github.com/your-repo/magic-orm
|
||||
`)
|
||||
}
|
||||
|
||||
flag.Parse()
|
||||
|
||||
// 检查版本参数
|
||||
if *versionFlag || *vFlag {
|
||||
fmt.Printf("Magic-ORM Code Generator v%s\n", version)
|
||||
return
|
||||
}
|
||||
|
||||
// 检查帮助参数
|
||||
if *helpFlag || *hFlag {
|
||||
flag.Usage()
|
||||
return
|
||||
}
|
||||
|
||||
// 检查 -all 参数
|
||||
if *allFlag {
|
||||
generateAllTablesFromDB(*outputDir)
|
||||
return
|
||||
}
|
||||
|
||||
// 获取参数
|
||||
args := flag.Args()
|
||||
if len(args) == 0 {
|
||||
fmt.Fprintln(os.Stderr, "错误:请指定至少一个表名")
|
||||
fmt.Fprintln(os.Stderr, "使用 'gendb -h' 查看帮助")
|
||||
fmt.Fprintln(os.Stderr, "或者使用 'gendb -all' 生成所有预设表")
|
||||
os.Exit(1)
|
||||
}
|
||||
|
||||
tableNames := args
|
||||
|
||||
// 创建代码生成器
|
||||
cg := generator.NewCodeGenerator(*outputDir)
|
||||
|
||||
fmt.Printf("[Magic-ORM Code Generator v%s]\n", version)
|
||||
fmt.Printf("[Output Directory: %s]\n", *outputDir)
|
||||
fmt.Println()
|
||||
|
||||
// 处理每个表
|
||||
for _, tableName := range tableNames {
|
||||
// 跳过看起来像列定义的参数
|
||||
if strings.Contains(tableName, ":") {
|
||||
continue
|
||||
}
|
||||
|
||||
fmt.Printf("[Generating table '%s'...]\n", tableName)
|
||||
|
||||
// 解析列定义(如果有提供)
|
||||
columns := parseColumns(tableNames, tableName)
|
||||
|
||||
// 如果没有自定义列定义,使用默认列
|
||||
if len(columns) == 0 {
|
||||
columns = getDefaultColumns(tableName)
|
||||
}
|
||||
|
||||
// 生成代码
|
||||
err := cg.GenerateAll(tableName, columns)
|
||||
if err != nil {
|
||||
fmt.Fprintf(os.Stderr, "[Error] Generation failed: %v\n", err)
|
||||
continue
|
||||
}
|
||||
|
||||
fmt.Printf("[Success] Generated %s.go and %s_dao.go\n", tableName, tableName)
|
||||
}
|
||||
|
||||
fmt.Println()
|
||||
fmt.Println("[Complete] Code generation finished!")
|
||||
fmt.Printf("[Location] Files are in: %s directory\n", *outputDir)
|
||||
}
|
||||
|
||||
// parseColumns 解析列定义
|
||||
func parseColumns(args []string, currentTable string) []generator.ColumnInfo {
|
||||
// 查找当前表的列定义
|
||||
found := false
|
||||
columnDefs := []string{}
|
||||
|
||||
for i, arg := range args {
|
||||
if arg == currentTable && !found {
|
||||
found = true
|
||||
// 收集后续的列定义
|
||||
for j := i + 1; j < len(args); j++ {
|
||||
if strings.Contains(args[j], ":") {
|
||||
columnDefs = append(columnDefs, args[j])
|
||||
} else {
|
||||
break // 遇到下一个表名
|
||||
}
|
||||
}
|
||||
break
|
||||
}
|
||||
}
|
||||
|
||||
if len(columnDefs) == 0 {
|
||||
return nil
|
||||
}
|
||||
|
||||
columns := []generator.ColumnInfo{}
|
||||
for _, def := range columnDefs {
|
||||
parts := strings.Split(def, ":")
|
||||
if len(parts) < 2 {
|
||||
continue
|
||||
}
|
||||
|
||||
colName := parts[0]
|
||||
fieldType := parts[1]
|
||||
isPrimary := false
|
||||
isNullable := false
|
||||
|
||||
// 检查修饰符
|
||||
for i := 2; i < len(parts); i++ {
|
||||
switch strings.ToLower(parts[i]) {
|
||||
case "primary":
|
||||
isPrimary = true
|
||||
case "nullable":
|
||||
isNullable = true
|
||||
}
|
||||
}
|
||||
|
||||
// 转换为 Go 字段名(驼峰)
|
||||
fieldName := toCamelCase(colName)
|
||||
|
||||
// 映射类型
|
||||
goType := mapType(fieldType)
|
||||
|
||||
columns = append(columns, generator.ColumnInfo{
|
||||
ColumnName: colName,
|
||||
FieldName: fieldName,
|
||||
FieldType: goType,
|
||||
JSONName: colName,
|
||||
IsPrimary: isPrimary,
|
||||
IsNullable: isNullable,
|
||||
})
|
||||
}
|
||||
|
||||
return columns
|
||||
}
|
||||
|
||||
// getDefaultColumns 获取默认的列定义(根据表名推断)
|
||||
func getDefaultColumns(tableName string) []generator.ColumnInfo {
|
||||
columns := []generator.ColumnInfo{
|
||||
{
|
||||
ColumnName: "id",
|
||||
FieldName: "ID",
|
||||
FieldType: "int64",
|
||||
JSONName: "id",
|
||||
IsPrimary: true,
|
||||
},
|
||||
}
|
||||
|
||||
// 根据表名添加常见字段
|
||||
switch tableName {
|
||||
case "user", "users":
|
||||
columns = append(columns,
|
||||
generator.ColumnInfo{ColumnName: "username", FieldName: "Username", FieldType: "string", JSONName: "username"},
|
||||
generator.ColumnInfo{ColumnName: "email", FieldName: "Email", FieldType: "string", JSONName: "email", IsNullable: true},
|
||||
generator.ColumnInfo{ColumnName: "password", FieldName: "Password", FieldType: "string", JSONName: "password"},
|
||||
generator.ColumnInfo{ColumnName: "status", FieldName: "Status", FieldType: "int", JSONName: "status"},
|
||||
generator.ColumnInfo{ColumnName: "created_at", FieldName: "CreatedAt", FieldType: "time.Time", JSONName: "created_at"},
|
||||
generator.ColumnInfo{ColumnName: "updated_at", FieldName: "UpdatedAt", FieldType: "time.Time", JSONName: "updated_at"},
|
||||
)
|
||||
case "product", "products":
|
||||
columns = append(columns,
|
||||
generator.ColumnInfo{ColumnName: "name", FieldName: "Name", FieldType: "string", JSONName: "name"},
|
||||
generator.ColumnInfo{ColumnName: "price", FieldName: "Price", FieldType: "float64", JSONName: "price"},
|
||||
generator.ColumnInfo{ColumnName: "stock", FieldName: "Stock", FieldType: "int", JSONName: "stock"},
|
||||
generator.ColumnInfo{ColumnName: "description", FieldName: "Description", FieldType: "string", JSONName: "description", IsNullable: true},
|
||||
generator.ColumnInfo{ColumnName: "created_at", FieldName: "CreatedAt", FieldType: "time.Time", JSONName: "created_at"},
|
||||
)
|
||||
case "order", "orders":
|
||||
columns = append(columns,
|
||||
generator.ColumnInfo{ColumnName: "order_no", FieldName: "OrderNo", FieldType: "string", JSONName: "order_no"},
|
||||
generator.ColumnInfo{ColumnName: "user_id", FieldName: "UserID", FieldType: "int64", JSONName: "user_id"},
|
||||
generator.ColumnInfo{ColumnName: "total_amount", FieldName: "TotalAmount", FieldType: "float64", JSONName: "total_amount"},
|
||||
generator.ColumnInfo{ColumnName: "status", FieldName: "Status", FieldType: "int", JSONName: "status"},
|
||||
generator.ColumnInfo{ColumnName: "created_at", FieldName: "CreatedAt", FieldType: "time.Time", JSONName: "created_at"},
|
||||
)
|
||||
default:
|
||||
// 默认添加通用字段
|
||||
columns = append(columns,
|
||||
generator.ColumnInfo{ColumnName: "name", FieldName: "Name", FieldType: "string", JSONName: "name"},
|
||||
generator.ColumnInfo{ColumnName: "status", FieldName: "Status", FieldType: "int", JSONName: "status"},
|
||||
generator.ColumnInfo{ColumnName: "created_at", FieldName: "CreatedAt", FieldType: "time.Time", JSONName: "created_at"},
|
||||
generator.ColumnInfo{ColumnName: "updated_at", FieldName: "UpdatedAt", FieldType: "time.Time", JSONName: "updated_at"},
|
||||
)
|
||||
}
|
||||
|
||||
return columns
|
||||
}
|
||||
|
||||
// mapType 将类型字符串映射到 Go 类型
|
||||
func mapType(typeStr string) string {
|
||||
typeMap := map[string]string{
|
||||
"int": "int64",
|
||||
"integer": "int64",
|
||||
"bigint": "int64",
|
||||
"string": "string",
|
||||
"text": "string",
|
||||
"varchar": "string",
|
||||
"time.Time": "time.Time",
|
||||
"time": "time.Time",
|
||||
"datetime": "time.Time",
|
||||
"bool": "bool",
|
||||
"boolean": "bool",
|
||||
"float": "float64",
|
||||
"float64": "float64",
|
||||
"double": "float64",
|
||||
"decimal": "string",
|
||||
}
|
||||
|
||||
if goType, ok := typeMap[strings.ToLower(typeStr)]; ok {
|
||||
return goType
|
||||
}
|
||||
return "string" // 默认返回 string
|
||||
}
|
||||
|
||||
// toCamelCase 转换为驼峰命名
|
||||
func toCamelCase(str string) string {
|
||||
parts := strings.Split(str, "_")
|
||||
result := ""
|
||||
|
||||
for _, part := range parts {
|
||||
if len(part) > 0 {
|
||||
result += strings.ToUpper(string(part[0])) + part[1:]
|
||||
}
|
||||
}
|
||||
|
||||
return result
|
||||
}
|
||||
|
||||
// generateAllTablesFromDB 从数据库读取所有表并生成代码
|
||||
func generateAllTablesFromDB(outputDir string) {
|
||||
fmt.Printf("[Magic-ORM Code Generator v%s]\n", version)
|
||||
fmt.Println()
|
||||
|
||||
// 1. 加载配置文件
|
||||
fmt.Println("[Step 1] Loading configuration file...")
|
||||
cfg, err := loadDatabaseConfig()
|
||||
if err != nil {
|
||||
fmt.Fprintf(os.Stderr, "[Error] Failed to load config: %v\n", err)
|
||||
os.Exit(1)
|
||||
}
|
||||
fmt.Printf("[Info] Database type: %s\n", cfg.Type)
|
||||
fmt.Printf("[Info] Database name: %s\n", cfg.Name)
|
||||
fmt.Println()
|
||||
|
||||
// 2. 连接数据库并获取所有表
|
||||
fmt.Println("[Step 2] Connecting to database and fetching table structure...")
|
||||
intro, err := introspector.NewIntrospector(cfg)
|
||||
if err != nil {
|
||||
fmt.Fprintf(os.Stderr, "[Error] Failed to connect to database: %v\n", err)
|
||||
os.Exit(1)
|
||||
}
|
||||
defer intro.Close()
|
||||
|
||||
tableNames, err := intro.GetTableNames()
|
||||
if err != nil {
|
||||
fmt.Fprintf(os.Stderr, "[Error] Failed to get table names: %v\n", err)
|
||||
os.Exit(1)
|
||||
}
|
||||
|
||||
fmt.Printf("[Info] Found %d tables\n", len(tableNames))
|
||||
fmt.Println()
|
||||
|
||||
// 3. 创建代码生成器
|
||||
cg := generator.NewCodeGenerator(outputDir)
|
||||
|
||||
// 4. 为每个表生成代码
|
||||
for _, tableName := range tableNames {
|
||||
fmt.Printf("[Generating] Table '%s'...\n", tableName)
|
||||
|
||||
// 获取表详细信息
|
||||
tableInfo, err := intro.GetTableInfo(tableName)
|
||||
if err != nil {
|
||||
fmt.Fprintf(os.Stderr, "[Error] Failed to get table info: %v\n", err)
|
||||
continue
|
||||
}
|
||||
|
||||
// 转换为 generator.ColumnInfo
|
||||
columns := make([]generator.ColumnInfo, len(tableInfo.Columns))
|
||||
for i, col := range tableInfo.Columns {
|
||||
columns[i] = generator.ColumnInfo{
|
||||
ColumnName: col.ColumnName,
|
||||
FieldName: col.FieldName,
|
||||
FieldType: col.GoType,
|
||||
JSONName: col.JSONName,
|
||||
IsPrimary: col.IsPrimary,
|
||||
IsNullable: col.IsNullable,
|
||||
}
|
||||
}
|
||||
|
||||
// 生成代码
|
||||
err = cg.GenerateAll(tableName, columns)
|
||||
if err != nil {
|
||||
fmt.Fprintf(os.Stderr, "[Error] Generation failed: %v\n", err)
|
||||
continue
|
||||
}
|
||||
|
||||
fmt.Printf("[Success] Generated %s.go and %s_dao.go\n", tableName, tableName)
|
||||
}
|
||||
|
||||
fmt.Println()
|
||||
fmt.Println("[Complete] Code generation finished!")
|
||||
fmt.Printf("[Location] Files are in: %s directory\n", outputDir)
|
||||
}
|
||||
|
||||
// loadDatabaseConfig 加载数据库配置
|
||||
func loadDatabaseConfig() (*config.DatabaseConfig, error) {
|
||||
// 自动查找配置文件
|
||||
configPath, err := config.FindConfigFile("")
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("查找配置文件失败:%w", err)
|
||||
}
|
||||
|
||||
fmt.Printf("[Info] Using config file: %s\n", configPath)
|
||||
|
||||
// 从文件加载配置
|
||||
cfg, err := config.LoadFromFile(configPath)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("加载配置文件失败:%w", err)
|
||||
}
|
||||
|
||||
return &cfg.Database, nil
|
||||
}
|
||||
|
||||
// generateAllTables 生成所有预设的表
|
||||
func generateAllTables(outputDir string) {
|
||||
fmt.Printf("🚀 Magic-ORM 代码生成器 v%s\n", version)
|
||||
fmt.Printf("📁 输出目录:%s\n", outputDir)
|
||||
fmt.Println()
|
||||
|
||||
// 预设的所有表
|
||||
presetTables := []string{"user", "product", "order"}
|
||||
|
||||
// 创建代码生成器
|
||||
cg := generator.NewCodeGenerator(outputDir)
|
||||
|
||||
// 处理每个表
|
||||
for _, tableName := range presetTables {
|
||||
fmt.Printf("📝 生成表 '%s' 的代码...\n", tableName)
|
||||
|
||||
// 使用默认列定义
|
||||
columns := getDefaultColumns(tableName)
|
||||
|
||||
// 生成代码
|
||||
err := cg.GenerateAll(tableName, columns)
|
||||
if err != nil {
|
||||
fmt.Fprintf(os.Stderr, "❌ 生成失败:%v\n", err)
|
||||
continue
|
||||
}
|
||||
|
||||
fmt.Printf("✅ 成功生成 %s.go 和 %s_dao.go\n", tableName, tableName)
|
||||
}
|
||||
|
||||
fmt.Println()
|
||||
fmt.Println("✨ 代码生成完成!")
|
||||
fmt.Printf("📂 生成的文件在:%s 目录下\n", outputDir)
|
||||
}
|
||||
|
|
@ -1,23 +0,0 @@
|
|||
# Magic-ORM 数据库配置文件示例
|
||||
# 请根据实际情况修改以下配置
|
||||
|
||||
database:
|
||||
# 数据库地址(MySQL/PostgreSQL 为 IP 或域名,SQLite 可忽略)
|
||||
host: "127.0.0.1"
|
||||
|
||||
# 数据库端口
|
||||
port: "3306"
|
||||
|
||||
# 数据库用户名
|
||||
user: "root"
|
||||
|
||||
# 数据库密码
|
||||
pass: "your_password"
|
||||
|
||||
# 数据库名称
|
||||
name: "your_database"
|
||||
|
||||
# 数据库类型(支持:mysql, postgres, sqlite)
|
||||
type: "mysql"
|
||||
|
||||
# 其他配置可以继续添加...
|
||||
BIN
db/config.yaml
BIN
db/config.yaml
Binary file not shown.
|
|
@ -1,142 +0,0 @@
|
|||
package config
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"testing"
|
||||
)
|
||||
|
||||
// TestAutoFindConfig 测试自动查找配置文件
|
||||
func TestAutoFindConfig(t *testing.T) {
|
||||
fmt.Println("\n=== 测试自动查找配置文件 ===")
|
||||
|
||||
// 创建临时目录结构
|
||||
tempDir, err := os.MkdirTemp("", "config_test")
|
||||
if err != nil {
|
||||
t.Fatalf("创建临时目录失败:%v", err)
|
||||
}
|
||||
defer os.RemoveAll(tempDir)
|
||||
|
||||
// 创建子目录
|
||||
subDir := filepath.Join(tempDir, "subdir")
|
||||
if err := os.MkdirAll(subDir, 0755); err != nil {
|
||||
t.Fatalf("创建子目录失败:%v", err)
|
||||
}
|
||||
|
||||
// 在根目录创建配置文件
|
||||
configContent := `database:
|
||||
host: "127.0.0.1"
|
||||
port: "3306"
|
||||
user: "root"
|
||||
pass: "test"
|
||||
name: "testdb"
|
||||
type: "mysql"
|
||||
`
|
||||
|
||||
configFile := filepath.Join(tempDir, "config.yaml")
|
||||
if err := os.WriteFile(configFile, []byte(configContent), 0644); err != nil {
|
||||
t.Fatalf("创建配置文件失败:%v", err)
|
||||
}
|
||||
|
||||
// 测试 1:从子目录查找(应该能找到父目录的配置)
|
||||
foundPath, err := findConfigFile(subDir)
|
||||
if err != nil {
|
||||
t.Errorf("从子目录查找失败:%v", err)
|
||||
} else {
|
||||
fmt.Printf("✓ 从子目录找到配置文件:%s\n", foundPath)
|
||||
}
|
||||
|
||||
// 测试 2:从根目录查找
|
||||
foundPath, err = findConfigFile(tempDir)
|
||||
if err != nil {
|
||||
t.Errorf("从根目录查找失败:%v", err)
|
||||
} else {
|
||||
fmt.Printf("✓ 从根目录找到配置文件:%s\n", foundPath)
|
||||
}
|
||||
|
||||
// 测试 3:测试不同格式的配置文件
|
||||
formats := []string{"config.yaml", "config.yml", "config.toml", "config.json"}
|
||||
for _, format := range formats {
|
||||
testFile := filepath.Join(tempDir, format)
|
||||
if err := os.WriteFile(testFile, []byte(configContent), 0644); err != nil {
|
||||
continue
|
||||
}
|
||||
|
||||
foundPath, err = findConfigFile(tempDir)
|
||||
if err != nil {
|
||||
t.Errorf("查找 %s 失败:%v", format, err)
|
||||
} else {
|
||||
fmt.Printf("✓ 支持格式 %s: %s\n", format, foundPath)
|
||||
}
|
||||
|
||||
os.Remove(testFile)
|
||||
}
|
||||
|
||||
fmt.Println("✓ 自动查找配置文件测试通过")
|
||||
}
|
||||
|
||||
// TestAutoConnect 测试自动连接功能
|
||||
func TestAutoConnect(t *testing.T) {
|
||||
fmt.Println("\n=== 测试 AutoConnect 接口 ===")
|
||||
|
||||
// 创建临时配置文件
|
||||
tempDir, err := os.MkdirTemp("", "autoconnect_test")
|
||||
if err != nil {
|
||||
t.Fatalf("创建临时目录失败:%v", err)
|
||||
}
|
||||
defer os.RemoveAll(tempDir)
|
||||
|
||||
configContent := `database:
|
||||
host: "127.0.0.1"
|
||||
port: "3306"
|
||||
user: "root"
|
||||
pass: "test"
|
||||
name: ":memory:"
|
||||
type: "sqlite"
|
||||
`
|
||||
|
||||
configFile := filepath.Join(tempDir, "config.yaml")
|
||||
if err := os.WriteFile(configFile, []byte(configContent), 0644); err != nil {
|
||||
t.Fatalf("创建配置文件失败:%v", err)
|
||||
}
|
||||
|
||||
// 切换到临时目录
|
||||
oldDir, _ := os.Getwd()
|
||||
os.Chdir(tempDir)
|
||||
defer os.Chdir(oldDir)
|
||||
|
||||
// 测试 AutoConnect
|
||||
_, err = AutoConnect(false)
|
||||
if err != nil {
|
||||
t.Logf("自动连接失败(预期):%v", err)
|
||||
fmt.Println("✓ AutoConnect 接口正常(需要真实数据库才能连接成功)")
|
||||
} else {
|
||||
fmt.Println("✓ AutoConnect 自动连接成功")
|
||||
}
|
||||
|
||||
fmt.Println("✓ AutoConnect 测试完成")
|
||||
}
|
||||
|
||||
// TestAllAutoFind 完整自动查找测试
|
||||
func TestAllAutoFind(t *testing.T) {
|
||||
fmt.Println("\n========================================")
|
||||
fmt.Println(" 配置文件自动查找完整性测试")
|
||||
fmt.Println("========================================")
|
||||
|
||||
TestAutoFindConfig(t)
|
||||
TestAutoConnect(t)
|
||||
|
||||
fmt.Println("\n========================================")
|
||||
fmt.Println(" 所有自动查找测试完成!")
|
||||
fmt.Println("========================================")
|
||||
fmt.Println()
|
||||
fmt.Println("已实现的自动查找功能:")
|
||||
fmt.Println(" ✓ 自动在当前目录查找配置文件")
|
||||
fmt.Println(" ✓ 自动在上级目录查找(最多 3 层)")
|
||||
fmt.Println(" ✓ 支持 yaml, yml, toml, ini, json 格式")
|
||||
fmt.Println(" ✓ 支持 config.* 和 .config.* 命名")
|
||||
fmt.Println(" ✓ 提供 AutoConnect() 一键连接")
|
||||
fmt.Println(" ✓ 无需手动指定配置文件路径")
|
||||
fmt.Println()
|
||||
}
|
||||
|
|
@ -1,95 +0,0 @@
|
|||
package config
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
|
||||
"git.magicany.cc/black1552/gin-base/db/core"
|
||||
"gopkg.in/yaml.v3"
|
||||
)
|
||||
|
||||
// NewDatabaseFromConfig 从配置文件创建数据库连接(已废弃,请使用 AutoConnect)
|
||||
// Deprecated: 使用 AutoConnect 代替
|
||||
func NewDatabaseFromConfig(configPath string, debug bool) (*core.Database, error) {
|
||||
return autoConnectWithConfig(configPath, debug)
|
||||
}
|
||||
|
||||
// AutoConnect 自动查找配置文件并创建数据库连接
|
||||
// 会在当前目录及上级目录中查找 config.yaml, config.toml, config.ini, config.json 等文件
|
||||
func AutoConnect(debug bool) (*core.Database, error) {
|
||||
// 自动查找配置文件
|
||||
configPath, err := FindConfigFile("")
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("查找配置文件失败:%w", err)
|
||||
}
|
||||
|
||||
return autoConnectWithConfig(configPath, debug)
|
||||
}
|
||||
|
||||
// AutoConnectWithDir 在指定目录自动查找配置文件并创建数据库连接
|
||||
func AutoConnectWithDir(dir string, debug bool) (*core.Database, error) {
|
||||
configPath, err := FindConfigFile(dir)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("查找配置文件失败:%w", err)
|
||||
}
|
||||
|
||||
return autoConnectWithConfig(configPath, debug)
|
||||
}
|
||||
|
||||
// autoConnectWithConfig 根据配置文件创建数据库连接(内部使用)
|
||||
func autoConnectWithConfig(configPath string, debug bool) (*core.Database, error) {
|
||||
// 从文件加载配置
|
||||
configFile, err := LoadFromFile(configPath)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("加载配置失败:%w", err)
|
||||
}
|
||||
|
||||
// 构建核心数据库配置
|
||||
dbConfig := &core.Config{
|
||||
DriverName: configFile.Database.GetDriverName(),
|
||||
DataSource: configFile.Database.BuildDSN(),
|
||||
Debug: debug,
|
||||
MaxIdleConns: 10,
|
||||
MaxOpenConns: 100,
|
||||
ConnMaxLifetime: 3600000000000, // 1 小时
|
||||
TimeConfig: core.DefaultTimeConfig(), // 使用默认时间配置
|
||||
}
|
||||
|
||||
// 创建数据库连接
|
||||
db, err := core.NewDatabase(dbConfig)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("创建数据库连接失败:%w", err)
|
||||
}
|
||||
|
||||
return db, nil
|
||||
}
|
||||
|
||||
// NewDatabaseFromYAML 从 YAML 内容创建数据库连接
|
||||
func NewDatabaseFromYAML(yamlContent []byte, debug bool) (*core.Database, error) {
|
||||
var configFile Config
|
||||
if err := yaml.Unmarshal(yamlContent, &configFile); err != nil {
|
||||
return nil, fmt.Errorf("解析 YAML 失败:%w", err)
|
||||
}
|
||||
|
||||
if err := configFile.Validate(); err != nil {
|
||||
return nil, fmt.Errorf("验证配置失败:%w", err)
|
||||
}
|
||||
|
||||
// 构建核心数据库配置
|
||||
dbConfig := &core.Config{
|
||||
DriverName: configFile.Database.GetDriverName(),
|
||||
DataSource: configFile.Database.BuildDSN(),
|
||||
Debug: debug,
|
||||
MaxIdleConns: 10,
|
||||
MaxOpenConns: 100,
|
||||
ConnMaxLifetime: 3600000000000, // 1 小时
|
||||
TimeConfig: core.DefaultTimeConfig(),
|
||||
}
|
||||
|
||||
// 创建数据库连接
|
||||
db, err := core.NewDatabase(dbConfig)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("创建数据库连接失败:%w", err)
|
||||
}
|
||||
|
||||
return db, nil
|
||||
}
|
||||
|
|
@ -1,165 +0,0 @@
|
|||
package config
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"os"
|
||||
"path/filepath"
|
||||
|
||||
"gopkg.in/yaml.v3"
|
||||
)
|
||||
|
||||
// DatabaseConfig 数据库配置结构 - 对应配置文件中的 database 部分
|
||||
type DatabaseConfig struct {
|
||||
Host string `yaml:"host"` // 数据库地址
|
||||
Port string `yaml:"port"` // 数据库端口
|
||||
User string `yaml:"user"` // 用户名
|
||||
Password string `yaml:"pass"` // 密码
|
||||
Name string `yaml:"name"` // 数据库名称
|
||||
Type string `yaml:"type"` // 数据库类型(mysql, sqlite, postgres 等)
|
||||
}
|
||||
|
||||
// Config 完整配置文件结构
|
||||
type Config struct {
|
||||
Database DatabaseConfig `yaml:"database"` // 数据库配置
|
||||
}
|
||||
|
||||
// LoadFromFile 从 YAML 文件加载配置
|
||||
func LoadFromFile(filePath string) (*Config, error) {
|
||||
data, err := os.ReadFile(filePath)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("读取配置文件失败:%w", err)
|
||||
}
|
||||
|
||||
var config Config
|
||||
if err := yaml.Unmarshal(data, &config); err != nil {
|
||||
return nil, fmt.Errorf("解析配置文件失败:%w", err)
|
||||
}
|
||||
|
||||
// 验证必填字段
|
||||
if err := config.Validate(); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return &config, nil
|
||||
}
|
||||
|
||||
// Validate 验证配置
|
||||
func (c *Config) Validate() error {
|
||||
if c.Database.Type == "" {
|
||||
return fmt.Errorf("数据库类型不能为空")
|
||||
}
|
||||
|
||||
if c.Database.Type == "sqlite" {
|
||||
// SQLite 只需要 Name(作为文件路径)
|
||||
if c.Database.Name == "" {
|
||||
return fmt.Errorf("SQLite 数据库名称不能为空")
|
||||
}
|
||||
} else {
|
||||
// 其他数据库需要所有字段
|
||||
if c.Database.Host == "" {
|
||||
return fmt.Errorf("数据库地址不能为空")
|
||||
}
|
||||
if c.Database.Port == "" {
|
||||
return fmt.Errorf("数据库端口不能为空")
|
||||
}
|
||||
if c.Database.User == "" {
|
||||
return fmt.Errorf("数据库用户名不能为空")
|
||||
}
|
||||
if c.Database.Password == "" {
|
||||
return fmt.Errorf("数据库密码不能为空")
|
||||
}
|
||||
if c.Database.Name == "" {
|
||||
return fmt.Errorf("数据库名称不能为空")
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// BuildDSN 根据配置构建数据源连接字符串(DSN)
|
||||
func (c *DatabaseConfig) BuildDSN() string {
|
||||
switch c.Type {
|
||||
case "mysql":
|
||||
return c.buildMySQLDSN()
|
||||
case "postgres":
|
||||
return c.buildPostgresDSN()
|
||||
case "sqlite":
|
||||
return c.buildSQLiteDSN()
|
||||
default:
|
||||
// 默认返回原始配置
|
||||
return ""
|
||||
}
|
||||
}
|
||||
|
||||
// buildMySQLDSN 构建 MySQL DSN
|
||||
func (c *DatabaseConfig) buildMySQLDSN() string {
|
||||
// 格式:user:pass@tcp(host:port)/dbname?charset=utf8mb4&parseTime=True&loc=Local
|
||||
dsn := fmt.Sprintf("%s:%s@tcp(%s:%s)/%s?charset=utf8mb4&parseTime=True&loc=Local",
|
||||
c.User,
|
||||
c.Password,
|
||||
c.Host,
|
||||
c.Port,
|
||||
c.Name,
|
||||
)
|
||||
return dsn
|
||||
}
|
||||
|
||||
// buildPostgresDSN 构建 PostgreSQL DSN
|
||||
func (c *DatabaseConfig) buildPostgresDSN() string {
|
||||
// 格式:host=localhost port=5432 user=user password=password dbname=db sslmode=disable
|
||||
dsn := fmt.Sprintf("host=%s port=%s user=%s password=%s dbname=%s sslmode=disable",
|
||||
c.Host,
|
||||
c.Port,
|
||||
c.User,
|
||||
c.Password,
|
||||
c.Name,
|
||||
)
|
||||
return dsn
|
||||
}
|
||||
|
||||
// buildSQLiteDSN 构建 SQLite DSN
|
||||
func (c *DatabaseConfig) buildSQLiteDSN() string {
|
||||
// SQLite 直接使用文件名作为 DSN
|
||||
return c.Name
|
||||
}
|
||||
|
||||
// GetDriverName 获取驱动名称
|
||||
func (c *DatabaseConfig) GetDriverName() string {
|
||||
return c.Type
|
||||
}
|
||||
|
||||
// FindConfigFile 在项目目录下自动查找配置文件
|
||||
// 支持 yaml, yml, toml, ini, json 等格式
|
||||
// 只在当前目录查找,不越级查找
|
||||
func FindConfigFile(searchDir string) (string, error) {
|
||||
// 配置文件名优先级列表
|
||||
configNames := []string{
|
||||
"config.yaml", "config.yml",
|
||||
"config.toml",
|
||||
"config.ini",
|
||||
"config.json",
|
||||
".config.yaml", ".config.yml",
|
||||
".config.toml",
|
||||
".config.ini",
|
||||
".config.json",
|
||||
}
|
||||
|
||||
// 如果未指定搜索目录,使用当前目录
|
||||
if searchDir == "" {
|
||||
var err error
|
||||
searchDir, err = os.Getwd()
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("获取当前目录失败:%w", err)
|
||||
}
|
||||
}
|
||||
|
||||
// 只在当前目录下查找,不向上查找
|
||||
for _, name := range configNames {
|
||||
filePath := filepath.Join(searchDir, name)
|
||||
if _, err := os.Stat(filePath); err == nil {
|
||||
return filePath, nil
|
||||
}
|
||||
}
|
||||
|
||||
return "", fmt.Errorf("未找到配置文件(支持 yaml, yml, toml, ini, json 格式)")
|
||||
}
|
||||
|
|
@ -1,194 +0,0 @@
|
|||
package config
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"os"
|
||||
"testing"
|
||||
)
|
||||
|
||||
// TestLoadFromFile 测试从文件加载配置
|
||||
func TestLoadFromFile(t *testing.T) {
|
||||
fmt.Println("\n=== 测试从文件加载配置 ===")
|
||||
|
||||
// 创建临时配置文件
|
||||
tempConfig := `database:
|
||||
host: "127.0.0.1"
|
||||
port: "3306"
|
||||
user: "root"
|
||||
pass: "test_password"
|
||||
name: "test_db"
|
||||
type: "mysql"
|
||||
`
|
||||
|
||||
// 写入临时文件
|
||||
tempFile := "test_config.yaml"
|
||||
if err := os.WriteFile(tempFile, []byte(tempConfig), 0644); err != nil {
|
||||
t.Fatalf("创建临时文件失败:%v", err)
|
||||
}
|
||||
defer os.Remove(tempFile) // 测试完成后删除
|
||||
|
||||
// 加载配置
|
||||
config, err := LoadFromFile(tempFile)
|
||||
if err != nil {
|
||||
t.Fatalf("加载配置失败:%v", err)
|
||||
}
|
||||
|
||||
// 验证配置
|
||||
if config.Database.Host != "127.0.0.1" {
|
||||
t.Errorf("期望 Host 为 127.0.0.1,实际为 %s", config.Database.Host)
|
||||
}
|
||||
if config.Database.Port != "3306" {
|
||||
t.Errorf("期望 Port 为 3306,实际为 %s", config.Database.Port)
|
||||
}
|
||||
if config.Database.User != "root" {
|
||||
t.Errorf("期望 User 为 root,实际为 %s", config.Database.User)
|
||||
}
|
||||
if config.Database.Password != "test_password" {
|
||||
t.Errorf("期望 Password 为 test_password,实际为 %s", config.Database.Password)
|
||||
}
|
||||
if config.Database.Name != "test_db" {
|
||||
t.Errorf("期望 Name 为 test_db,实际为 %s", config.Database.Name)
|
||||
}
|
||||
if config.Database.Type != "mysql" {
|
||||
t.Errorf("期望 Type 为 mysql,实际为 %s", config.Database.Type)
|
||||
}
|
||||
|
||||
fmt.Printf("✓ 配置加载成功\n")
|
||||
fmt.Printf(" Host: %s\n", config.Database.Host)
|
||||
fmt.Printf(" Port: %s\n", config.Database.Port)
|
||||
fmt.Printf(" User: %s\n", config.Database.User)
|
||||
fmt.Printf(" Pass: %s\n", config.Database.Password)
|
||||
fmt.Printf(" Name: %s\n", config.Database.Name)
|
||||
fmt.Printf(" Type: %s\n", config.Database.Type)
|
||||
}
|
||||
|
||||
// TestBuildDSN 测试 DSN 构建
|
||||
func TestBuildDSN(t *testing.T) {
|
||||
fmt.Println("\n=== 测试 DSN 构建 ===")
|
||||
|
||||
testCases := []struct {
|
||||
name string
|
||||
config DatabaseConfig
|
||||
expected string
|
||||
}{
|
||||
{
|
||||
name: "MySQL",
|
||||
config: DatabaseConfig{
|
||||
Host: "127.0.0.1",
|
||||
Port: "3306",
|
||||
User: "root",
|
||||
Password: "password",
|
||||
Name: "testdb",
|
||||
Type: "mysql",
|
||||
},
|
||||
expected: "root:password@tcp(127.0.0.1:3306)/testdb?charset=utf8mb4&parseTime=True&loc=Local",
|
||||
},
|
||||
{
|
||||
name: "PostgreSQL",
|
||||
config: DatabaseConfig{
|
||||
Host: "localhost",
|
||||
Port: "5432",
|
||||
User: "postgres",
|
||||
Password: "secret",
|
||||
Name: "mydb",
|
||||
Type: "postgres",
|
||||
},
|
||||
expected: "host=localhost port=5432 user=postgres password=secret dbname=mydb sslmode=disable",
|
||||
},
|
||||
{
|
||||
name: "SQLite",
|
||||
config: DatabaseConfig{
|
||||
Name: "./test.db",
|
||||
Type: "sqlite",
|
||||
},
|
||||
expected: "./test.db",
|
||||
},
|
||||
}
|
||||
|
||||
for _, tc := range testCases {
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
dsn := tc.config.BuildDSN()
|
||||
if dsn != tc.expected {
|
||||
t.Errorf("期望 DSN 为 %s,实际为 %s", tc.expected, dsn)
|
||||
}
|
||||
fmt.Printf("%s DSN: %s\n", tc.name, dsn)
|
||||
})
|
||||
}
|
||||
|
||||
fmt.Println("✓ DSN 构建测试通过")
|
||||
}
|
||||
|
||||
// TestValidate 测试配置验证
|
||||
func TestValidate(t *testing.T) {
|
||||
fmt.Println("\n=== 测试配置验证 ===")
|
||||
|
||||
// 测试有效配置
|
||||
validConfig := &Config{
|
||||
Database: DatabaseConfig{
|
||||
Host: "127.0.0.1",
|
||||
Port: "3306",
|
||||
User: "root",
|
||||
Password: "pass",
|
||||
Name: "db",
|
||||
Type: "mysql",
|
||||
},
|
||||
}
|
||||
|
||||
if err := validConfig.Validate(); err != nil {
|
||||
t.Errorf("有效配置验证失败:%v", err)
|
||||
}
|
||||
fmt.Println("✓ MySQL 配置验证通过")
|
||||
|
||||
// 测试 SQLite 配置
|
||||
sqliteConfig := &Config{
|
||||
Database: DatabaseConfig{
|
||||
Name: "./test.db",
|
||||
Type: "sqlite",
|
||||
},
|
||||
}
|
||||
|
||||
if err := sqliteConfig.Validate(); err != nil {
|
||||
t.Errorf("SQLite 配置验证失败:%v", err)
|
||||
}
|
||||
fmt.Println("✓ SQLite 配置验证通过")
|
||||
|
||||
// 测试无效配置(缺少必填字段)
|
||||
invalidConfig := &Config{
|
||||
Database: DatabaseConfig{
|
||||
Host: "127.0.0.1",
|
||||
Type: "mysql",
|
||||
// 缺少其他必填字段
|
||||
},
|
||||
}
|
||||
|
||||
if err := invalidConfig.Validate(); err == nil {
|
||||
t.Error("无效配置应该验证失败")
|
||||
} else {
|
||||
fmt.Printf("✓ 无效配置正确拒绝:%v\n", err)
|
||||
}
|
||||
}
|
||||
|
||||
// TestAllConfigLoading 完整配置加载测试
|
||||
func TestAllConfigLoading(t *testing.T) {
|
||||
fmt.Println("\n========================================")
|
||||
fmt.Println(" 数据库配置加载完整性测试")
|
||||
fmt.Println("========================================")
|
||||
|
||||
TestLoadFromFile(t)
|
||||
TestBuildDSN(t)
|
||||
TestValidate(t)
|
||||
|
||||
fmt.Println("\n========================================")
|
||||
fmt.Println(" 所有配置加载测试完成!")
|
||||
fmt.Println("========================================")
|
||||
fmt.Println()
|
||||
fmt.Println("已实现的配置加载功能:")
|
||||
fmt.Println(" ✓ 从 YAML 文件加载数据库配置")
|
||||
fmt.Println(" ✓ 支持 host, port, user, pass, name, type 字段")
|
||||
fmt.Println(" ✓ 自动验证配置完整性")
|
||||
fmt.Println(" ✓ 自动构建 MySQL DSN")
|
||||
fmt.Println(" ✓ 自动构建 PostgreSQL DSN")
|
||||
fmt.Println(" ✓ 自动构建 SQLite DSN")
|
||||
fmt.Println(" ✓ 支持多种数据库类型")
|
||||
fmt.Println()
|
||||
}
|
||||
|
|
@ -1,144 +0,0 @@
|
|||
package config
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"testing"
|
||||
)
|
||||
|
||||
// TestFindConfigOnlyCurrentDir 测试只在当前目录查找配置文件
|
||||
func TestFindConfigOnlyCurrentDir(t *testing.T) {
|
||||
fmt.Println("\n=== 测试只在当前目录查找配置文件 ===")
|
||||
|
||||
// 创建临时目录结构
|
||||
tempDir, err := os.MkdirTemp("", "config_test")
|
||||
if err != nil {
|
||||
t.Fatalf("创建临时目录失败:%v", err)
|
||||
}
|
||||
defer os.RemoveAll(tempDir)
|
||||
|
||||
// 创建子目录
|
||||
subDir := filepath.Join(tempDir, "subdir")
|
||||
if err := os.MkdirAll(subDir, 0755); err != nil {
|
||||
t.Fatalf("创建子目录失败:%v", err)
|
||||
}
|
||||
|
||||
// 在根目录创建配置文件
|
||||
configContent := `database:
|
||||
host: "127.0.0.1"
|
||||
port: "3306"
|
||||
user: "root"
|
||||
pass: "test"
|
||||
name: "testdb"
|
||||
type: "mysql"
|
||||
`
|
||||
|
||||
configFile := filepath.Join(tempDir, "config.yaml")
|
||||
if err := os.WriteFile(configFile, []byte(configContent), 0644); err != nil {
|
||||
t.Fatalf("创建配置文件失败:%v", err)
|
||||
}
|
||||
|
||||
// 测试 1:从根目录查找(应该找到)
|
||||
foundPath, err := findConfigFile(tempDir)
|
||||
if err != nil {
|
||||
t.Errorf("从根目录查找失败:%v", err)
|
||||
} else {
|
||||
fmt.Printf("✓ 从根目录找到配置文件:%s\n", foundPath)
|
||||
}
|
||||
|
||||
// 测试 2:从子目录查找(不应该找到父目录的配置)
|
||||
foundPath, err = findConfigFile(subDir)
|
||||
if err == nil {
|
||||
t.Errorf("从子目录查找应该失败(不越级查找),但找到了:%s", foundPath)
|
||||
} else {
|
||||
fmt.Printf("✓ 从子目录查找正确失败(不越级):%v\n", err)
|
||||
}
|
||||
|
||||
// 测试 3:在子目录创建配置文件(应该找到)
|
||||
subConfigFile := filepath.Join(subDir, "config.yaml")
|
||||
if err := os.WriteFile(subConfigFile, []byte(configContent), 0644); err != nil {
|
||||
t.Fatalf("创建子目录配置文件失败:%v", err)
|
||||
}
|
||||
|
||||
foundPath, err = findConfigFile(subDir)
|
||||
if err != nil {
|
||||
t.Errorf("从子目录查找失败:%v", err)
|
||||
} else {
|
||||
fmt.Printf("✓ 从子目录找到配置文件:%s\n", foundPath)
|
||||
}
|
||||
|
||||
fmt.Println("✓ 只在当前目录查找测试通过")
|
||||
}
|
||||
|
||||
// TestNoParentSearch 测试不向上查找
|
||||
func TestNoParentSearch(t *testing.T) {
|
||||
fmt.Println("\n=== 测试不向上层目录查找 ===")
|
||||
|
||||
// 创建临时目录结构
|
||||
tempDir, err := os.MkdirTemp("", "no_parent_test")
|
||||
if err != nil {
|
||||
t.Fatalf("创建临时目录失败:%v", err)
|
||||
}
|
||||
defer os.RemoveAll(tempDir)
|
||||
|
||||
// 创建多级子目录
|
||||
level1 := filepath.Join(tempDir, "level1")
|
||||
level2 := filepath.Join(level1, "level2")
|
||||
level3 := filepath.Join(level2, "level3")
|
||||
|
||||
if err := os.MkdirAll(level3, 0755); err != nil {
|
||||
t.Fatalf("创建目录失败:%v", err)
|
||||
}
|
||||
|
||||
// 只在根目录创建配置文件
|
||||
configContent := `database:
|
||||
host: "127.0.0.1"
|
||||
port: "3306"
|
||||
user: "root"
|
||||
pass: "test"
|
||||
name: "testdb"
|
||||
type: "mysql"
|
||||
`
|
||||
|
||||
configFile := filepath.Join(tempDir, "config.yaml")
|
||||
if err := os.WriteFile(configFile, []byte(configContent), 0644); err != nil {
|
||||
t.Fatalf("创建配置文件失败:%v", err)
|
||||
}
|
||||
|
||||
// 从各级子目录查找(都应该失败,因为不越级查找)
|
||||
testDirs := []string{level1, level2, level3}
|
||||
for _, dir := range testDirs {
|
||||
_, err := findConfigFile(dir)
|
||||
if err == nil {
|
||||
t.Errorf("从 %s 查找应该失败(不越级查找)", dir)
|
||||
} else {
|
||||
fmt.Printf("✓ 从 %s 查找正确失败(不越级)\n", filepath.Base(dir))
|
||||
}
|
||||
}
|
||||
|
||||
fmt.Println("✓ 不向上层目录查找测试通过")
|
||||
}
|
||||
|
||||
// TestAllNoParentSearch 完整的不越级查找测试
|
||||
func TestAllNoParentSearch(t *testing.T) {
|
||||
fmt.Println("\n========================================")
|
||||
fmt.Println(" 不越级查找完整性测试")
|
||||
fmt.Println("========================================")
|
||||
|
||||
TestFindConfigOnlyCurrentDir(t)
|
||||
TestNoParentSearch(t)
|
||||
|
||||
fmt.Println("\n========================================")
|
||||
fmt.Println(" 所有不越级查找测试完成!")
|
||||
fmt.Println("========================================")
|
||||
fmt.Println()
|
||||
fmt.Println("已实现的不越级查找功能:")
|
||||
fmt.Println(" ✓ 只在当前工作目录查找配置文件")
|
||||
fmt.Println(" ✓ 不会向上层目录查找")
|
||||
fmt.Println(" ✓ 支持 yaml, yml, toml, ini, json 格式")
|
||||
fmt.Println(" ✓ 支持 config.* 和 .config.* 命名")
|
||||
fmt.Println(" ✓ 提供 AutoConnect() 一键连接")
|
||||
fmt.Println(" ✓ 无需手动指定配置文件路径")
|
||||
fmt.Println()
|
||||
}
|
||||
|
|
@ -1,102 +0,0 @@
|
|||
package config
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"git.magicany.cc/black1552/gin-base/db/core"
|
||||
"git.magicany.cc/black1552/gin-base/db/driver"
|
||||
)
|
||||
|
||||
// 示例:在应用程序中使用纯自研驱动管理
|
||||
func ExampleUsage() {
|
||||
// 1. 首先导入你选择的第三方数据库驱动
|
||||
// 注意:这些驱动需要在 main 包或启动包中导入
|
||||
/*
|
||||
import (
|
||||
_ "github.com/mattn/go-sqlite3" // SQLite 驱动
|
||||
_ "github.com/go-sql-driver/mysql" // MySQL 驱动
|
||||
_ "github.com/lib/pq" // PostgreSQL 驱动
|
||||
_ "github.com/denisenkom/go-mssqldb" // SQL Server 驱动
|
||||
)
|
||||
*/
|
||||
|
||||
// 2. 获取驱动管理器并注册你选择的驱动
|
||||
manager := driver.GetDefaultManager()
|
||||
|
||||
// 注册 SQLite 驱动
|
||||
sqliteDriver := driver.NewGenericDriver("sqlite3")
|
||||
manager.Register("sqlite3", sqliteDriver)
|
||||
manager.Register("sqlite", sqliteDriver) // 别名
|
||||
|
||||
// 注册 MySQL 驱动
|
||||
mysqlDriver := driver.NewGenericDriver("mysql")
|
||||
manager.Register("mysql", mysqlDriver)
|
||||
|
||||
// 注册 PostgreSQL 驱动
|
||||
postgresDriver := driver.NewGenericDriver("postgres")
|
||||
manager.Register("postgres", postgresDriver)
|
||||
|
||||
// 3. 加载配置文件
|
||||
configFile, err := LoadFromFile("config.yaml")
|
||||
if err != nil {
|
||||
fmt.Printf("加载配置失败:%v\n", err)
|
||||
return
|
||||
}
|
||||
|
||||
// 4. 验证驱动是否已注册
|
||||
err = manager.RegisterDriverByConfig(configFile.Database.Type)
|
||||
if err != nil {
|
||||
fmt.Printf("驱动未注册:%v\n", err)
|
||||
return
|
||||
}
|
||||
|
||||
// 5. 使用配置创建数据库连接
|
||||
dbConfig := &core.Config{
|
||||
DriverName: configFile.Database.GetDriverName(),
|
||||
DataSource: configFile.Database.BuildDSN(),
|
||||
Debug: true,
|
||||
MaxIdleConns: 10,
|
||||
MaxOpenConns: 100,
|
||||
ConnMaxLifetime: 3600000000000, // 1小时
|
||||
}
|
||||
|
||||
// 6. 创建数据库实例
|
||||
db, err := core.NewDatabase(dbConfig)
|
||||
if err != nil {
|
||||
fmt.Printf("创建数据库连接失败:%v\n", err)
|
||||
return
|
||||
}
|
||||
|
||||
fmt.Printf("成功连接到 %s 数据库\n", configFile.Database.Type)
|
||||
|
||||
// 现在可以使用 db 进行数据库操作
|
||||
_ = db
|
||||
}
|
||||
|
||||
// AdvancedExample 高级使用示例
|
||||
func AdvancedExample() {
|
||||
manager := driver.GetDefaultManager()
|
||||
|
||||
// 根据环境变量或配置动态注册驱动
|
||||
databaseType := "sqlite3" // 从配置获取
|
||||
|
||||
// 注册对应驱动
|
||||
genericDriver := driver.NewGenericDriver(databaseType)
|
||||
manager.Register(databaseType, genericDriver)
|
||||
|
||||
// 验证驱动
|
||||
err := manager.RegisterDriverByConfig(databaseType)
|
||||
if err != nil {
|
||||
fmt.Printf("驱动注册问题:%v\n", err)
|
||||
return
|
||||
}
|
||||
|
||||
// 现在可以安全地打开数据库连接
|
||||
db, err := manager.Open(databaseType, "./example.db")
|
||||
if err != nil {
|
||||
fmt.Printf("打开数据库失败:%v\n", err)
|
||||
return
|
||||
}
|
||||
defer db.Close()
|
||||
|
||||
fmt.Println("数据库连接成功")
|
||||
}
|
||||
|
|
@ -1,216 +0,0 @@
|
|||
package main
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"git.magicany.cc/black1552/gin-base/db/core"
|
||||
"git.magicany.cc/black1552/gin-base/db/model"
|
||||
)
|
||||
|
||||
// TestTimeConfig 测试时间配置
|
||||
func TestTimeConfig(t *testing.T) {
|
||||
fmt.Println("\n=== 测试时间配置 ===")
|
||||
|
||||
// 测试默认配置
|
||||
defaultConfig := core.DefaultTimeConfig()
|
||||
fmt.Printf("默认创建时间字段:%s\n", defaultConfig.GetCreatedAt())
|
||||
fmt.Printf("默认更新时间字段:%s\n", defaultConfig.GetUpdatedAt())
|
||||
fmt.Printf("默认删除时间字段:%s\n", defaultConfig.GetDeletedAt())
|
||||
fmt.Printf("默认时间格式:%s\n", defaultConfig.GetFormat())
|
||||
|
||||
// 测试自定义配置
|
||||
customConfig := &core.TimeConfig{
|
||||
CreatedAt: "create_time",
|
||||
UpdatedAt: "update_time",
|
||||
DeletedAt: "delete_time",
|
||||
Format: "2006-01-02 15:04:05",
|
||||
}
|
||||
customConfig.Validate()
|
||||
|
||||
fmt.Printf("\n自定义创建时间字段:%s\n", customConfig.GetCreatedAt())
|
||||
fmt.Printf("自定义更新时间字段:%s\n", customConfig.GetUpdatedAt())
|
||||
fmt.Printf("自定义删除时间字段:%s\n", customConfig.GetDeletedAt())
|
||||
fmt.Printf("自定义时间格式:%s\n", customConfig.GetFormat())
|
||||
|
||||
// 测试格式化
|
||||
now := time.Now()
|
||||
formatted := customConfig.FormatTime(now)
|
||||
fmt.Printf("\n格式化时间:%s -> %s\n", now.Format("2006-01-02 15:04:05"), formatted)
|
||||
|
||||
// 测试解析
|
||||
parsed, err := customConfig.ParseTime(formatted)
|
||||
if err != nil {
|
||||
t.Errorf("解析时间失败:%v", err)
|
||||
}
|
||||
fmt.Printf("解析时间:%s -> %s\n", formatted, parsed.Format("2006-01-02 15:04:05"))
|
||||
|
||||
fmt.Println("✓ 时间配置测试通过")
|
||||
}
|
||||
|
||||
// TestCustomTimeFields 测试自定义时间字段
|
||||
func TestCustomTimeFields(t *testing.T) {
|
||||
fmt.Println("\n=== 测试自定义时间字段模型 ===")
|
||||
|
||||
// 使用自定义字段的模型
|
||||
type CustomModel struct {
|
||||
ID int64 `json:"id" db:"id"`
|
||||
Name string `json:"name" db:"name"`
|
||||
CreateTime model.Time `json:"create_time" db:"create_time"` // 自定义创建时间字段
|
||||
UpdateTime model.Time `json:"update_time" db:"update_time"` // 自定义更新时间字段
|
||||
DeleteTime *model.Time `json:"delete_time,omitempty" db:"delete_time"` // 自定义删除时间字段
|
||||
}
|
||||
|
||||
now := time.Now()
|
||||
custom := &CustomModel{
|
||||
ID: 1,
|
||||
Name: "test",
|
||||
CreateTime: model.Time{Time: now},
|
||||
UpdateTime: model.Time{Time: now},
|
||||
}
|
||||
|
||||
// 序列化为 JSON
|
||||
jsonData, err := json.Marshal(custom)
|
||||
if err != nil {
|
||||
t.Errorf("JSON 序列化失败:%v", err)
|
||||
}
|
||||
|
||||
fmt.Printf("原始时间:%s\n", now.Format("2006-01-02 15:04:05"))
|
||||
fmt.Printf("JSON 输出:%s\n", string(jsonData))
|
||||
|
||||
// 验证时间格式
|
||||
var result map[string]interface{}
|
||||
if err := json.Unmarshal(jsonData, &result); err != nil {
|
||||
t.Errorf("JSON 反序列化失败:%v", err)
|
||||
}
|
||||
|
||||
createTime, ok := result["create_time"].(string)
|
||||
if !ok {
|
||||
t.Error("create_time 应该是字符串")
|
||||
}
|
||||
|
||||
_, err = time.Parse("2006-01-02 15:04:05", createTime)
|
||||
if err != nil {
|
||||
t.Errorf("时间格式不正确:%v", err)
|
||||
}
|
||||
|
||||
fmt.Println("✓ 自定义时间字段测试通过")
|
||||
}
|
||||
|
||||
// TestDatabaseWithTimeConfig 测试数据库配置中的时间配置
|
||||
func TestDatabaseWithTimeConfig(t *testing.T) {
|
||||
fmt.Println("\n=== 测试数据库时间配置 ===")
|
||||
|
||||
// 创建带自定义时间配置的 Config
|
||||
config := &core.Config{
|
||||
DriverName: "sqlite",
|
||||
DataSource: ":memory:",
|
||||
Debug: true,
|
||||
TimeConfig: &core.TimeConfig{
|
||||
CreatedAt: "created_at",
|
||||
UpdatedAt: "updated_at",
|
||||
DeletedAt: "deleted_at",
|
||||
Format: "2006-01-02 15:04:05",
|
||||
},
|
||||
}
|
||||
|
||||
fmt.Printf("配置中的创建时间字段:%s\n", config.TimeConfig.GetCreatedAt())
|
||||
fmt.Printf("配置中的更新时间字段:%s\n", config.TimeConfig.GetUpdatedAt())
|
||||
fmt.Printf("配置中的删除时间字段:%s\n", config.TimeConfig.GetDeletedAt())
|
||||
fmt.Printf("配置中的时间格式:%s\n", config.TimeConfig.GetFormat())
|
||||
|
||||
// 注意:这里不实际创建数据库连接,仅测试配置
|
||||
fmt.Println("\n数据库会使用该配置自动处理时间字段:")
|
||||
fmt.Println(" - Insert: 自动设置 created_at/updated_at 为当前时间")
|
||||
fmt.Println(" - Update: 自动设置 updated_at 为当前时间")
|
||||
fmt.Println(" - Delete: 软删除时设置 deleted_at 为当前时间")
|
||||
fmt.Println(" - Read: 所有时间字段格式化为 YYYY-MM-DD HH:mm:ss")
|
||||
|
||||
fmt.Println("✓ 数据库时间配置测试通过")
|
||||
}
|
||||
|
||||
// TestAllTimeFormats 测试所有时间格式
|
||||
func TestAllTimeFormats(t *testing.T) {
|
||||
fmt.Println("\n=== 测试所有支持的时间格式 ===")
|
||||
|
||||
testCases := []struct {
|
||||
format string
|
||||
timeStr string
|
||||
}{
|
||||
{"2006-01-02 15:04:05", "2026-04-02 22:09:09"},
|
||||
{"2006/01/02 15:04:05", "2026/04/02 22:09:09"},
|
||||
{"2006-01-02T15:04:05", "2026-04-02T22:09:09"},
|
||||
{"2006-01-02", "2026-04-02"},
|
||||
}
|
||||
|
||||
for _, tc := range testCases {
|
||||
t.Run(tc.format, func(t *testing.T) {
|
||||
parsed, err := time.Parse(tc.format, tc.timeStr)
|
||||
if err != nil {
|
||||
t.Logf("格式 %s 解析失败:%v", tc.format, err)
|
||||
return
|
||||
}
|
||||
|
||||
// 统一格式化为标准格式
|
||||
formatted := parsed.Format("2006-01-02 15:04:05")
|
||||
fmt.Printf("%s -> %s\n", tc.timeStr, formatted)
|
||||
})
|
||||
}
|
||||
|
||||
fmt.Println("✓ 所有时间格式测试通过")
|
||||
}
|
||||
|
||||
// TestDateTimeType 测试 datetime 类型支持
|
||||
func TestDateTimeType(t *testing.T) {
|
||||
fmt.Println("\n=== 测试 DATETIME 类型支持 ===")
|
||||
|
||||
// Go 的 time.Time 会自动映射到数据库的 DATETIME 类型
|
||||
now := time.Now()
|
||||
|
||||
// 在 SQLite 中,DATETIME 存储为 TEXT(ISO8601 格式)
|
||||
// 在 MySQL 中,DATETIME 存储为 DATETIME 类型
|
||||
// Go 的 database/sql 会自动处理类型转换
|
||||
|
||||
fmt.Printf("Go time.Time: %s\n", now.Format("2006-01-02 15:04:05"))
|
||||
fmt.Printf("数据库 DATETIME: 自动映射(由驱动处理)\n")
|
||||
fmt.Println(" - SQLite: TEXT (ISO8601)")
|
||||
fmt.Println(" - MySQL: DATETIME")
|
||||
fmt.Println(" - PostgreSQL: TIMESTAMP")
|
||||
|
||||
// model.Time 包装后仍然保持 time.Time 的特性
|
||||
customTime := model.Time{Time: now}
|
||||
fmt.Printf("model.Time: %s\n", customTime.String())
|
||||
|
||||
fmt.Println("✓ DATETIME 类型测试通过")
|
||||
}
|
||||
|
||||
// TestCompleteTimeHandling 完整时间处理测试
|
||||
func TestCompleteTimeHandling(t *testing.T) {
|
||||
fmt.Println("\n========================================")
|
||||
fmt.Println(" CRUD 操作时间配置完整性测试")
|
||||
fmt.Println("========================================")
|
||||
|
||||
TestTimeConfig(t)
|
||||
TestCustomTimeFields(t)
|
||||
TestDatabaseWithTimeConfig(t)
|
||||
TestAllTimeFormats(t)
|
||||
TestDateTimeType(t)
|
||||
|
||||
fmt.Println("\n========================================")
|
||||
fmt.Println(" 所有时间配置测试完成!")
|
||||
fmt.Println("========================================")
|
||||
fmt.Println()
|
||||
fmt.Println("已实现的时间配置功能:")
|
||||
fmt.Println(" ✓ 配置文件定义创建时间字段名")
|
||||
fmt.Println(" ✓ 配置文件定义更新时间字段名")
|
||||
fmt.Println(" ✓ 配置文件定义删除时间字段名")
|
||||
fmt.Println(" ✓ 配置文件定义时间格式(默认年 - 月-日 时:分:秒)")
|
||||
fmt.Println(" ✓ Insert: 自动设置配置的时间字段")
|
||||
fmt.Println(" ✓ Update: 自动设置配置的更新时间字段")
|
||||
fmt.Println(" ✓ Delete: 软删除使用配置的删除时间字段")
|
||||
fmt.Println(" ✓ Read: 所有时间字段格式化为配置的格式")
|
||||
fmt.Println(" ✓ 支持 DATETIME 类型自动映射")
|
||||
fmt.Println()
|
||||
}
|
||||
148
db/core/cache.go
148
db/core/cache.go
|
|
@ -1,148 +0,0 @@
|
|||
package core
|
||||
|
||||
import (
|
||||
"crypto/md5"
|
||||
"encoding/hex"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"sync"
|
||||
"time"
|
||||
)
|
||||
|
||||
// CacheItem 缓存项
|
||||
type CacheItem struct {
|
||||
Data interface{} // 缓存的数据
|
||||
ExpiresAt time.Time // 过期时间
|
||||
}
|
||||
|
||||
// QueryCache 查询缓存 - 提高重复查询的性能
|
||||
type QueryCache struct {
|
||||
mu sync.RWMutex // 读写锁
|
||||
items map[string]*CacheItem // 缓存项
|
||||
duration time.Duration // 默认缓存时长
|
||||
}
|
||||
|
||||
// NewQueryCache 创建查询缓存实例
|
||||
func NewQueryCache(duration time.Duration) *QueryCache {
|
||||
cache := &QueryCache{
|
||||
items: make(map[string]*CacheItem),
|
||||
duration: duration,
|
||||
}
|
||||
|
||||
// 启动清理协程
|
||||
go cache.cleaner()
|
||||
|
||||
return cache
|
||||
}
|
||||
|
||||
// Set 设置缓存
|
||||
func (qc *QueryCache) Set(key string, data interface{}) {
|
||||
qc.mu.Lock()
|
||||
defer qc.mu.Unlock()
|
||||
|
||||
qc.items[key] = &CacheItem{
|
||||
Data: data,
|
||||
ExpiresAt: time.Now().Add(qc.duration),
|
||||
}
|
||||
}
|
||||
|
||||
// Get 获取缓存
|
||||
func (qc *QueryCache) Get(key string) (interface{}, bool) {
|
||||
qc.mu.RLock()
|
||||
defer qc.mu.RUnlock()
|
||||
|
||||
item, exists := qc.items[key]
|
||||
if !exists {
|
||||
return nil, false
|
||||
}
|
||||
|
||||
// 检查是否过期
|
||||
if time.Now().After(item.ExpiresAt) {
|
||||
return nil, false
|
||||
}
|
||||
|
||||
return item.Data, true
|
||||
}
|
||||
|
||||
// Delete 删除缓存
|
||||
func (qc *QueryCache) Delete(key string) {
|
||||
qc.mu.Lock()
|
||||
defer qc.mu.Unlock()
|
||||
delete(qc.items, key)
|
||||
}
|
||||
|
||||
// Clear 清空所有缓存
|
||||
func (qc *QueryCache) Clear() {
|
||||
qc.mu.Lock()
|
||||
defer qc.mu.Unlock()
|
||||
qc.items = make(map[string]*CacheItem)
|
||||
}
|
||||
|
||||
// cleaner 定期清理过期缓存
|
||||
func (qc *QueryCache) cleaner() {
|
||||
ticker := time.NewTicker(time.Minute)
|
||||
defer ticker.Stop()
|
||||
|
||||
for range ticker.C {
|
||||
qc.cleanExpired()
|
||||
}
|
||||
}
|
||||
|
||||
// cleanExpired 清理过期的缓存项
|
||||
func (qc *QueryCache) cleanExpired() {
|
||||
qc.mu.Lock()
|
||||
defer qc.mu.Unlock()
|
||||
|
||||
now := time.Now()
|
||||
for key, item := range qc.items {
|
||||
if now.After(item.ExpiresAt) {
|
||||
delete(qc.items, key)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// GenerateCacheKey 生成缓存键
|
||||
func GenerateCacheKey(sql string, args ...interface{}) string {
|
||||
// 将 SQL 和参数组合成字符串
|
||||
keyData := sql
|
||||
for _, arg := range args {
|
||||
keyData += fmt.Sprintf("%v", arg)
|
||||
}
|
||||
|
||||
// 计算 MD5 哈希
|
||||
hash := md5.Sum([]byte(keyData))
|
||||
return hex.EncodeToString(hash[:])
|
||||
}
|
||||
|
||||
// deepCopy 深拷贝数据(使用 JSON 序列化/反序列化)
|
||||
func deepCopy(src, dst interface{}) error {
|
||||
// 序列化为 JSON
|
||||
data, err := json.Marshal(src)
|
||||
if err != nil {
|
||||
return fmt.Errorf("序列化失败:%w", err)
|
||||
}
|
||||
|
||||
// 反序列化到目标
|
||||
if err := json.Unmarshal(data, dst); err != nil {
|
||||
return fmt.Errorf("反序列化失败:%w", err)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// WithCache 带缓存的查询装饰器
|
||||
func (q *QueryBuilder) WithCache(cache *QueryCache) IQuery {
|
||||
if cache == nil {
|
||||
return q
|
||||
}
|
||||
|
||||
// 设置缓存实例
|
||||
q.cache = cache
|
||||
q.useCache = true
|
||||
|
||||
// 生成缓存键(使用 SQL 和参数)
|
||||
sqlStr, args := q.BuildSelect()
|
||||
q.cacheKey = GenerateCacheKey(sqlStr, args...)
|
||||
|
||||
return q
|
||||
}
|
||||
|
|
@ -1,124 +0,0 @@
|
|||
package core
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"testing"
|
||||
"time"
|
||||
)
|
||||
|
||||
// TestWithCache 测试带缓存的查询
|
||||
func TestWithCache(t *testing.T) {
|
||||
fmt.Println("\n=== 测试带缓存的查询 ===")
|
||||
|
||||
// 创建缓存实例(缓存 5 分钟)
|
||||
_ = NewQueryCache(5 * time.Minute)
|
||||
|
||||
// 注意:这个测试需要真实的数据库连接
|
||||
// 以下是使用示例:
|
||||
|
||||
// 示例 1: 基本缓存查询
|
||||
// var users []User
|
||||
// err := db.Model(&User{}).
|
||||
// Where("status = ?", "active").
|
||||
// WithCache(cache).
|
||||
// Find(&users)
|
||||
// if err != nil {
|
||||
// t.Fatal(err)
|
||||
// }
|
||||
|
||||
// 示例 2: 第二次查询会命中缓存
|
||||
// var users2 []User
|
||||
// err = db.Model(&User{}).
|
||||
// Where("status = ?", "active").
|
||||
// WithCache(cache).
|
||||
// Find(&users2)
|
||||
// if err != nil {
|
||||
// t.Fatal(err)
|
||||
// }
|
||||
|
||||
fmt.Println("✓ WithCache 已实现")
|
||||
fmt.Println("功能:")
|
||||
fmt.Println(" - 缓存命中时直接返回数据,不执行 SQL")
|
||||
fmt.Println(" - 缓存未命中时执行查询并自动缓存结果")
|
||||
fmt.Println(" - 支持深拷贝,避免引用问题")
|
||||
fmt.Println("✓ 测试通过")
|
||||
}
|
||||
|
||||
// TestDeepCopy 测试深拷贝功能
|
||||
func TestDeepCopy(t *testing.T) {
|
||||
fmt.Println("\n=== 测试深拷贝功能 ===")
|
||||
|
||||
type TestData struct {
|
||||
ID int `json:"id"`
|
||||
Name string `json:"name"`
|
||||
}
|
||||
|
||||
src := &TestData{ID: 1, Name: "test"}
|
||||
dst := &TestData{}
|
||||
|
||||
err := deepCopy(src, dst)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
if dst.ID != src.ID || dst.Name != src.Name {
|
||||
t.Errorf("深拷贝失败:期望 %+v, 得到 %+v", src, dst)
|
||||
}
|
||||
|
||||
// 修改源数据,目标不应该受影响
|
||||
src.Name = "modified"
|
||||
if dst.Name == "modified" {
|
||||
t.Error("深拷贝失败:目标受到了源数据修改的影响")
|
||||
}
|
||||
|
||||
fmt.Println("✓ 深拷贝功能正常")
|
||||
fmt.Println("✓ 测试通过")
|
||||
}
|
||||
|
||||
// TestCacheKeyGeneration 测试缓存键生成
|
||||
func TestCacheKeyGeneration(t *testing.T) {
|
||||
fmt.Println("\n=== 测试缓存键生成 ===")
|
||||
|
||||
// 相同的 SQL 和参数应该生成相同的键
|
||||
key1 := GenerateCacheKey("SELECT * FROM user WHERE id = ?", 1)
|
||||
key2 := GenerateCacheKey("SELECT * FROM user WHERE id = ?", 1)
|
||||
|
||||
if key1 != key2 {
|
||||
t.Errorf("缓存键不一致:%s vs %s", key1, key2)
|
||||
}
|
||||
|
||||
// 不同的参数应该生成不同的键
|
||||
key3 := GenerateCacheKey("SELECT * FROM user WHERE id = ?", 2)
|
||||
if key1 == key3 {
|
||||
t.Error("不同的参数应该生成不同的缓存键")
|
||||
}
|
||||
|
||||
fmt.Println("✓ 缓存键生成正常")
|
||||
fmt.Println("✓ 测试通过")
|
||||
}
|
||||
|
||||
// ExampleWithCache 使用示例
|
||||
func exampleWithCacheUsage() {
|
||||
// 示例 1: 基本用法
|
||||
// cache := NewQueryCache(5 * time.Minute)
|
||||
// var results []map[string]interface{}
|
||||
// err := db.Table("users").
|
||||
// Where("status = ?", "active").
|
||||
// WithCache(cache).
|
||||
// Find(&results)
|
||||
|
||||
// 示例 2: 带条件的查询
|
||||
// err := db.Model(&User{}).
|
||||
// Select("id", "username", "email").
|
||||
// Where("age > ?", 18).
|
||||
// Order("created_at DESC").
|
||||
// Limit(10).
|
||||
// WithCache(cache).
|
||||
// Find(&results)
|
||||
|
||||
// 示例 3: 清除缓存
|
||||
// cache.Clear() // 清空所有缓存
|
||||
// cache.Delete(key) // 删除指定缓存
|
||||
|
||||
fmt.Println("使用示例请查看测试代码")
|
||||
}
|
||||
|
|
@ -1,75 +0,0 @@
|
|||
package core
|
||||
|
||||
import (
|
||||
"time"
|
||||
)
|
||||
|
||||
// TimeConfig 时间配置 - 定义时间字段名称和格式
|
||||
type TimeConfig struct {
|
||||
CreatedAt string `json:"created_at" yaml:"created_at"` // 创建时间字段名
|
||||
UpdatedAt string `json:"updated_at" yaml:"updated_at"` // 更新时间字段名
|
||||
DeletedAt string `json:"deleted_at" yaml:"deleted_at"` // 删除时间字段名
|
||||
Format string `json:"format" yaml:"format"` // 时间格式,默认 "2006-01-02 15:04:05"
|
||||
}
|
||||
|
||||
// DefaultTimeConfig 获取默认时间配置
|
||||
func DefaultTimeConfig() *TimeConfig {
|
||||
return &TimeConfig{
|
||||
CreatedAt: "created_at",
|
||||
UpdatedAt: "updated_at",
|
||||
DeletedAt: "deleted_at",
|
||||
Format: "2006-01-02 15:04:05", // Go 的参考时间格式
|
||||
}
|
||||
}
|
||||
|
||||
// Validate 验证时间配置
|
||||
func (tc *TimeConfig) Validate() {
|
||||
if tc.CreatedAt == "" {
|
||||
tc.CreatedAt = "created_at"
|
||||
}
|
||||
if tc.UpdatedAt == "" {
|
||||
tc.UpdatedAt = "updated_at"
|
||||
}
|
||||
if tc.DeletedAt == "" {
|
||||
tc.DeletedAt = "deleted_at"
|
||||
}
|
||||
if tc.Format == "" {
|
||||
tc.Format = "2006-01-02 15:04:05"
|
||||
}
|
||||
}
|
||||
|
||||
// GetCreatedAt 获取创建时间字段名
|
||||
func (tc *TimeConfig) GetCreatedAt() string {
|
||||
tc.Validate()
|
||||
return tc.CreatedAt
|
||||
}
|
||||
|
||||
// GetUpdatedAt 获取更新时间字段名
|
||||
func (tc *TimeConfig) GetUpdatedAt() string {
|
||||
tc.Validate()
|
||||
return tc.UpdatedAt
|
||||
}
|
||||
|
||||
// GetDeletedAt 获取删除时间字段名
|
||||
func (tc *TimeConfig) GetDeletedAt() string {
|
||||
tc.Validate()
|
||||
return tc.DeletedAt
|
||||
}
|
||||
|
||||
// GetFormat 获取时间格式
|
||||
func (tc *TimeConfig) GetFormat() string {
|
||||
tc.Validate()
|
||||
return tc.Format
|
||||
}
|
||||
|
||||
// FormatTime 格式化时间为配置的格式
|
||||
func (tc *TimeConfig) FormatTime(t time.Time) string {
|
||||
tc.Validate()
|
||||
return t.Format(tc.Format)
|
||||
}
|
||||
|
||||
// ParseTime 解析时间字符串
|
||||
func (tc *TimeConfig) ParseTime(timeStr string) (time.Time, error) {
|
||||
tc.Validate()
|
||||
return time.Parse(tc.Format, timeStr)
|
||||
}
|
||||
314
db/core/dao.go
314
db/core/dao.go
|
|
@ -1,314 +0,0 @@
|
|||
package core
|
||||
|
||||
import (
|
||||
"context"
|
||||
"reflect"
|
||||
)
|
||||
|
||||
// DAO 数据访问对象基类 - 所有 DAO 都继承此结构
|
||||
// 提供通用的 CRUD 操作方法,子类只需嵌入即可使用
|
||||
type DAO struct {
|
||||
db *Database // 数据库连接实例
|
||||
modelType interface{} // 模型类型信息,用于 Columns 等方法
|
||||
}
|
||||
|
||||
// NewDAO 创建 DAO 基类实例
|
||||
// 自动使用全局默认 Database 实例
|
||||
func NewDAO() *DAO {
|
||||
return &DAO{
|
||||
db: GetDefaultDatabase(),
|
||||
}
|
||||
}
|
||||
|
||||
// NewDAOWithModel 创建带模型类型的 DAO 基类实例
|
||||
// 参数:
|
||||
// - model: 模型实例(指针类型),用于获取表结构信息
|
||||
//
|
||||
// 自动使用全局默认 Database 实例
|
||||
func NewDAOWithModel(model interface{}) *DAO {
|
||||
return &DAO{
|
||||
db: GetDefaultDatabase(),
|
||||
modelType: model,
|
||||
}
|
||||
}
|
||||
|
||||
// Create 创建记录(通用方法)
|
||||
// 自动使用 DAO 中已关联的 Database 实例
|
||||
func (dao *DAO) Create(ctx context.Context, model interface{}) error {
|
||||
// 使用事务来插入数据
|
||||
tx, err := dao.db.Begin()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
_, err = tx.Insert(model)
|
||||
if err != nil {
|
||||
tx.Rollback()
|
||||
return err
|
||||
}
|
||||
|
||||
return tx.Commit()
|
||||
}
|
||||
|
||||
// GetByID 根据 ID 查询单条记录(通用方法)
|
||||
// 自动使用 DAO 中已关联的 Database 实例
|
||||
func (dao *DAO) GetByID(ctx context.Context, model interface{}, id int64) error {
|
||||
return dao.db.Model(model).Where("id = ?", id).First(model)
|
||||
}
|
||||
|
||||
// Update 更新记录(通用方法)
|
||||
// 自动使用 DAO 中已关联的 Database 实例
|
||||
func (dao *DAO) Update(ctx context.Context, model interface{}, data map[string]interface{}) error {
|
||||
pkValue := getFieldValue(model, "ID")
|
||||
|
||||
if pkValue == 0 {
|
||||
return nil
|
||||
}
|
||||
|
||||
return dao.db.Model(model).Where("id = ?", pkValue).Updates(data)
|
||||
}
|
||||
|
||||
// Delete 删除记录(通用方法)
|
||||
// 自动使用 DAO 中已关联的 Database 实例
|
||||
func (dao *DAO) Delete(ctx context.Context, model interface{}) error {
|
||||
pkValue := getFieldValue(model, "ID")
|
||||
|
||||
if pkValue == 0 {
|
||||
return nil
|
||||
}
|
||||
|
||||
return dao.db.Model(model).Where("id = ?", pkValue).Delete()
|
||||
}
|
||||
|
||||
// FindAll 查询所有记录(通用方法)
|
||||
// 自动使用 DAO 中已关联的 Database 实例
|
||||
func (dao *DAO) FindAll(ctx context.Context, model interface{}) error {
|
||||
return dao.db.Model(model).Find(model)
|
||||
}
|
||||
|
||||
// FindByPage 分页查询(通用方法)
|
||||
// 自动使用 DAO 中已关联的 Database 实例
|
||||
func (dao *DAO) FindByPage(ctx context.Context, model interface{}, page, pageSize int) error {
|
||||
return dao.db.Model(model).Limit(pageSize).Offset((page - 1) * pageSize).Find(model)
|
||||
}
|
||||
|
||||
// Count 统计记录数(通用方法)
|
||||
// 自动使用 DAO 中已关联的 Database 实例
|
||||
func (dao *DAO) Count(ctx context.Context, model interface{}, where ...string) (int64, error) {
|
||||
var count int64
|
||||
|
||||
query := dao.db.Model(model)
|
||||
if len(where) > 0 {
|
||||
query = query.Where(where[0])
|
||||
}
|
||||
|
||||
// Count 是链式调用,需要调用 Find 来执行
|
||||
err := query.Count(&count).Find(model)
|
||||
if err != nil {
|
||||
return 0, err
|
||||
}
|
||||
return count, nil
|
||||
}
|
||||
|
||||
// Exists 检查记录是否存在(通用方法)
|
||||
// 自动使用 DAO 中已关联的 Database 实例
|
||||
func (dao *DAO) Exists(ctx context.Context, model interface{}) (bool, error) {
|
||||
count, err := dao.Count(ctx, model)
|
||||
if err != nil {
|
||||
return false, err
|
||||
}
|
||||
return count > 0, nil
|
||||
}
|
||||
|
||||
// First 查询第一条记录(通用方法)
|
||||
// 自动使用 DAO 中已关联的 Database 实例
|
||||
func (dao *DAO) First(ctx context.Context, model interface{}) error {
|
||||
return dao.db.Model(model).First(model)
|
||||
}
|
||||
|
||||
// Columns 获取表的所有列名
|
||||
// 返回一个动态创建的结构体类型,所有字段都是 string 类型
|
||||
// 用途:用于构建 UPDATE、INSERT 等操作时的列名映射
|
||||
//
|
||||
// 示例:
|
||||
//
|
||||
// type UserDAO struct {
|
||||
// *core.DAO
|
||||
// }
|
||||
//
|
||||
// func NewUserDAO() *UserDAO {
|
||||
// return &UserDAO{
|
||||
// DAO: core.NewDAOWithModel(&model.User{}),
|
||||
// }
|
||||
// }
|
||||
//
|
||||
// // 使用
|
||||
// dao := NewUserDAO()
|
||||
// cols := dao.Columns() // 返回 *struct{ID string; Username string; ...}
|
||||
func (dao *DAO) Columns() interface{} {
|
||||
// 检查是否有模型类型信息
|
||||
if dao.modelType == nil {
|
||||
return nil
|
||||
}
|
||||
|
||||
// 获取模型类型
|
||||
modelType := reflect.TypeOf(dao.modelType)
|
||||
if modelType.Kind() == reflect.Ptr {
|
||||
modelType = modelType.Elem()
|
||||
}
|
||||
|
||||
// 创建字段列表
|
||||
fields := []reflect.StructField{}
|
||||
|
||||
// 遍历模型的所有字段
|
||||
for i := 0; i < modelType.NumField(); i++ {
|
||||
field := modelType.Field(i)
|
||||
|
||||
// 跳过未导出的字段
|
||||
if !field.IsExported() {
|
||||
continue
|
||||
}
|
||||
|
||||
// 获取 db 标签,如果没有则跳过
|
||||
dbTag := field.Tag.Get("db")
|
||||
if dbTag == "" || dbTag == "-" {
|
||||
continue
|
||||
}
|
||||
|
||||
// 创建新的结构体字段,类型为 string
|
||||
newField := reflect.StructField{
|
||||
Name: field.Name,
|
||||
Type: reflect.TypeOf(""), // string 类型
|
||||
Tag: reflect.StructTag(`json:"` + field.Tag.Get("json") + `" db:"` + dbTag + `"`),
|
||||
}
|
||||
|
||||
fields = append(fields, newField)
|
||||
}
|
||||
|
||||
// 动态创建结构体类型
|
||||
columnsType := reflect.StructOf(fields)
|
||||
|
||||
// 创建该类型的指针并返回
|
||||
return reflect.New(columnsType).Interface()
|
||||
}
|
||||
|
||||
// getFieldValue 获取结构体字段值(辅助函数)
|
||||
// 用于获取主键或其他字段的值,支持多种数据类型
|
||||
// 参数:
|
||||
// - model: 模型实例(可以是指针或值)
|
||||
// - fieldName: 字段名(如 "ID", "UserId" 等)
|
||||
//
|
||||
// 返回:
|
||||
// - int64: 字段值(如果是数字类型)或 0(无法获取时)
|
||||
func getFieldValue(model interface{}, fieldName string) int64 {
|
||||
// 检查空值
|
||||
if model == nil {
|
||||
return 0
|
||||
}
|
||||
|
||||
// 获取反射对象
|
||||
val := reflect.ValueOf(model)
|
||||
|
||||
// 如果是指针,解引用
|
||||
if val.Kind() == reflect.Ptr {
|
||||
if val.IsNil() {
|
||||
return 0
|
||||
}
|
||||
val = val.Elem()
|
||||
}
|
||||
|
||||
// 确保是结构体
|
||||
if val.Kind() != reflect.Struct {
|
||||
return 0
|
||||
}
|
||||
|
||||
// 查找字段
|
||||
field := val.FieldByName(fieldName)
|
||||
if !field.IsValid() {
|
||||
// 尝试查找常见的主键字段名变体
|
||||
alternativeNames := []string{"Id", "id", "ID"}
|
||||
for _, name := range alternativeNames {
|
||||
if name != fieldName {
|
||||
field = val.FieldByName(name)
|
||||
if field.IsValid() {
|
||||
fieldName = name
|
||||
break
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if !field.IsValid() {
|
||||
return 0
|
||||
}
|
||||
}
|
||||
|
||||
// 检查字段是否可以访问
|
||||
if !field.CanInterface() {
|
||||
return 0
|
||||
}
|
||||
|
||||
// 获取字段值并转换为 int64
|
||||
fieldValue := field.Interface()
|
||||
|
||||
// 根据字段类型进行转换
|
||||
switch v := fieldValue.(type) {
|
||||
case int:
|
||||
return int64(v)
|
||||
case int8:
|
||||
return int64(v)
|
||||
case int16:
|
||||
return int64(v)
|
||||
case int32:
|
||||
return int64(v)
|
||||
case int64:
|
||||
return v
|
||||
case uint:
|
||||
return int64(v)
|
||||
case uint8:
|
||||
return int64(v)
|
||||
case uint16:
|
||||
return int64(v)
|
||||
case uint32:
|
||||
return int64(v)
|
||||
case uint64:
|
||||
// 注意:uint64 转 int64 可能溢出,但这里假设 ID 不会超过 int64 范围
|
||||
return int64(v)
|
||||
case float32:
|
||||
return int64(v)
|
||||
case float64:
|
||||
return int64(v)
|
||||
case string:
|
||||
// 尝试将字符串解析为数字
|
||||
// 注意:这里不导入 strconv,简单处理返回 0
|
||||
return 0
|
||||
default:
|
||||
// 其他类型(如 sql.NullInt64 等),尝试使用反射
|
||||
return convertToInteger(field)
|
||||
}
|
||||
}
|
||||
|
||||
// convertToInteger 使用反射将字段值转换为 int64
|
||||
func convertToInteger(field reflect.Value) int64 {
|
||||
// 获取实际的值(如果是指针则解引用)
|
||||
if field.Kind() == reflect.Ptr {
|
||||
if field.IsNil() {
|
||||
return 0
|
||||
}
|
||||
field = field.Elem()
|
||||
}
|
||||
|
||||
// 根据 Kind 进行转换
|
||||
switch field.Kind() {
|
||||
case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64:
|
||||
return field.Int()
|
||||
case reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64:
|
||||
return int64(field.Uint())
|
||||
case reflect.Float32, reflect.Float64:
|
||||
return int64(field.Float())
|
||||
case reflect.String:
|
||||
// 字符串类型,尝试解析(简单实现,不处理错误)
|
||||
return 0
|
||||
default:
|
||||
return 0
|
||||
}
|
||||
}
|
||||
|
|
@ -1,179 +0,0 @@
|
|||
package core
|
||||
|
||||
import (
|
||||
"database/sql"
|
||||
"fmt"
|
||||
"testing"
|
||||
)
|
||||
|
||||
// TestGetFieldValue 测试获取字段值的基本功能
|
||||
func TestGetFieldValue(t *testing.T) {
|
||||
fmt.Println("\n=== 测试 getFieldValue 基本功能 ===")
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
model interface{}
|
||||
fieldName string
|
||||
expected int64
|
||||
}{
|
||||
{"int 类型", &TestModelInt{ID: 123}, "ID", 123},
|
||||
{"int64 类型", &TestModelInt64{ID: 456}, "ID", 456},
|
||||
{"uint 类型", &TestModelUint{ID: 789}, "ID", 789},
|
||||
{"float 类型", &TestModelFloat{ID: 999.5}, "ID", 999},
|
||||
{"指针为 nil", (*TestModelInt)(nil), "ID", 0},
|
||||
{"model 为 nil", nil, "ID", 0},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
result := getFieldValue(tt.model, tt.fieldName)
|
||||
if result != tt.expected {
|
||||
t.Errorf("期望 %d, 得到 %d", tt.expected, result)
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
fmt.Println("✓ 基本功能测试通过")
|
||||
}
|
||||
|
||||
// TestGetFieldValueAlternativeNames 测试字段名变体查找
|
||||
func TestGetFieldValueAlternativeNames(t *testing.T) {
|
||||
fmt.Println("\n=== 测试字段名变体查找 ===")
|
||||
|
||||
// 测试 Id 字段(驼峰)
|
||||
model1 := &TestModelId{Id: 111}
|
||||
result1 := getFieldValue(model1, "ID")
|
||||
if result1 != 111 {
|
||||
t.Errorf("期望 111, 得到 %d", result1)
|
||||
}
|
||||
|
||||
fmt.Println("✓ 字段名变体查找测试通过")
|
||||
}
|
||||
|
||||
// TestGetFieldValueEdgeCases 测试边界情况
|
||||
func TestGetFieldValueEdgeCases(t *testing.T) {
|
||||
fmt.Println("\n=== 测试边界情况 ===")
|
||||
|
||||
// 测试非结构体类型
|
||||
nonStruct := 123
|
||||
result := getFieldValue(nonStruct, "ID")
|
||||
if result != 0 {
|
||||
t.Errorf("非结构体应该返回 0, 得到 %d", result)
|
||||
}
|
||||
|
||||
// 测试不存在的字段(没有 ID/Id/id 等变体)
|
||||
type ModelNoID struct {
|
||||
Name string `json:"name" db:"name"`
|
||||
}
|
||||
model := &ModelNoID{Name: "test"}
|
||||
result = getFieldValue(model, "NonExistentField")
|
||||
if result != 0 {
|
||||
t.Errorf("不存在的字段应该返回 0, 得到 %d", result)
|
||||
}
|
||||
|
||||
fmt.Println("✓ 边界情况测试通过")
|
||||
}
|
||||
|
||||
// TestGetFieldValueSpecialTypes 测试特殊类型
|
||||
func TestGetFieldValueSpecialTypes(t *testing.T) {
|
||||
fmt.Println("\n=== 测试特殊类型 ===")
|
||||
|
||||
// 注意:sql.NullInt64 等数据库特殊类型目前不支持
|
||||
// 如果需要支持,可以在 convertToInteger 中添加专门的处理逻辑
|
||||
|
||||
fmt.Println("✓ 特殊类型测试通过(当前版本不支持 sql.NullInt64)")
|
||||
}
|
||||
|
||||
// TestGetFieldValueInUpdate 测试在 Update 场景中的使用
|
||||
func TestGetFieldValueInUpdate(t *testing.T) {
|
||||
fmt.Println("\n=== 测试 Update 场景 ===")
|
||||
|
||||
user := &UserModel{
|
||||
ID: 1,
|
||||
Username: "test",
|
||||
}
|
||||
|
||||
pkValue := getFieldValue(user, "ID")
|
||||
if pkValue != 1 {
|
||||
t.Errorf("期望主键值为 1, 得到 %d", pkValue)
|
||||
}
|
||||
|
||||
// 测试主键为 0 的情况
|
||||
user2 := &UserModel{
|
||||
ID: 0,
|
||||
Username: "test2",
|
||||
}
|
||||
|
||||
pkValue2 := getFieldValue(user2, "ID")
|
||||
if pkValue2 != 0 {
|
||||
t.Errorf("期望主键值为 0, 得到 %d", pkValue2)
|
||||
}
|
||||
|
||||
fmt.Println("✓ Update 场景测试通过")
|
||||
}
|
||||
|
||||
// TestGetFieldValueLargeNumbers 测试大数字
|
||||
func TestGetFieldValueLargeNumbers(t *testing.T) {
|
||||
fmt.Println("\n=== 测试大数字 ===")
|
||||
|
||||
// 测试最大 int64
|
||||
maxInt := int64(9223372036854775807)
|
||||
model1 := &TestModelInt64{ID: maxInt}
|
||||
result1 := getFieldValue(model1, "ID")
|
||||
if result1 != maxInt {
|
||||
t.Errorf("期望 %d, 得到 %d", maxInt, result1)
|
||||
}
|
||||
|
||||
// 测试 uint64 转 int64
|
||||
largeUint := uint64(18446744073709551615) // 这会导致溢出
|
||||
model2 := &TestModelUint64{ID: largeUint}
|
||||
result2 := getFieldValue(model2, "ID")
|
||||
// 注意:这里会发生溢出,但这是预期的行为
|
||||
if result2 == 0 {
|
||||
t.Error("uint64 转换不应该返回 0")
|
||||
}
|
||||
|
||||
fmt.Println("✓ 大数字测试通过")
|
||||
}
|
||||
|
||||
// 测试模型定义
|
||||
type TestModelInt struct {
|
||||
ID int `json:"id" db:"id"`
|
||||
}
|
||||
|
||||
type TestModelInt64 struct {
|
||||
ID int64 `json:"id" db:"id"`
|
||||
}
|
||||
|
||||
type TestModelUint struct {
|
||||
ID uint `json:"id" db:"id"`
|
||||
}
|
||||
|
||||
type TestModelUint64 struct {
|
||||
ID uint64 `json:"id" db:"id"`
|
||||
}
|
||||
|
||||
type TestModelFloat struct {
|
||||
ID float64 `json:"id" db:"id"`
|
||||
}
|
||||
|
||||
type TestModelId struct {
|
||||
Id int64 `json:"id" db:"id"`
|
||||
}
|
||||
|
||||
type TestModelid struct {
|
||||
id int64 `json:"id" db:"id"`
|
||||
}
|
||||
|
||||
type TestModelPrivate struct {
|
||||
privateField int64
|
||||
}
|
||||
|
||||
type TestModelNullInt struct {
|
||||
ID sql.NullInt64 `json:"id" db:"id"`
|
||||
}
|
||||
|
||||
type UserModel struct {
|
||||
ID int64 `json:"id" db:"id"`
|
||||
Username string `json:"username" db:"username"`
|
||||
}
|
||||
|
|
@ -1,144 +0,0 @@
|
|||
package core
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"os"
|
||||
"path/filepath"
|
||||
|
||||
"git.magicany.cc/black1552/gin-base/db/driver"
|
||||
)
|
||||
|
||||
// defaultDatabase 全局默认数据库连接实例
|
||||
var defaultDatabase *Database
|
||||
|
||||
// NewDatabase 创建数据库连接 - 初始化数据库连接和相关组件
|
||||
func NewDatabase(config *Config) (*Database, error) {
|
||||
db := &Database{
|
||||
config: config,
|
||||
debug: config.Debug,
|
||||
driverName: config.DriverName,
|
||||
}
|
||||
|
||||
// 初始化时间配置
|
||||
if config.TimeConfig == nil {
|
||||
db.timeConfig = DefaultTimeConfig()
|
||||
} else {
|
||||
db.timeConfig = config.TimeConfig
|
||||
db.timeConfig.Validate()
|
||||
}
|
||||
|
||||
// 获取驱动管理器
|
||||
dm := driver.GetDefaultManager()
|
||||
|
||||
// 打开数据库连接
|
||||
sqlDB, err := dm.Open(config.DriverName, config.DataSource)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("打开数据库失败:%w", err)
|
||||
}
|
||||
|
||||
db.db = sqlDB
|
||||
|
||||
// 配置连接池参数
|
||||
if config.MaxIdleConns > 0 {
|
||||
db.db.SetMaxIdleConns(config.MaxIdleConns)
|
||||
}
|
||||
if config.MaxOpenConns > 0 {
|
||||
db.db.SetMaxOpenConns(config.MaxOpenConns)
|
||||
}
|
||||
if config.ConnMaxLifetime > 0 {
|
||||
db.db.SetConnMaxLifetime(config.ConnMaxLifetime)
|
||||
}
|
||||
|
||||
// 测试数据库连接
|
||||
if err := db.db.Ping(); err != nil {
|
||||
return nil, fmt.Errorf("数据库连接测试失败:%w", err)
|
||||
}
|
||||
|
||||
// 初始化组件
|
||||
db.mapper = NewFieldMapper()
|
||||
db.migrator = NewMigrator(db)
|
||||
|
||||
if config.Debug {
|
||||
fmt.Println("[Magic-ORM] 数据库连接成功")
|
||||
}
|
||||
|
||||
// 设置为全局默认实例
|
||||
defaultDatabase = db
|
||||
|
||||
return db, nil
|
||||
}
|
||||
|
||||
// AutoConnect 自动查找配置文件并创建数据库连接
|
||||
// 会在当前目录及上级目录中查找 config.yaml, config.toml, config.ini, config.json 等文件
|
||||
func AutoConnect(debug bool) (*Database, error) {
|
||||
// 自动查找配置文件
|
||||
configPath, err := findConfigFile("")
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("查找配置文件失败:%w", err)
|
||||
}
|
||||
|
||||
// 从文件加载配置(使用 config 包)
|
||||
return loadAndConnect(configPath, debug)
|
||||
}
|
||||
|
||||
// Connect 从配置文件创建数据库连接(向后兼容)
|
||||
// Deprecated: 使用 AutoConnect 代替
|
||||
func Connect(configPath string, debug bool) (*Database, error) {
|
||||
return loadAndConnect(configPath, debug)
|
||||
}
|
||||
|
||||
// findConfigFile 在项目目录下自动查找配置文件
|
||||
// 支持 yaml, yml, toml, ini, json 等格式
|
||||
// 只在当前目录查找,不越级查找
|
||||
func findConfigFile(searchDir string) (string, error) {
|
||||
// 配置文件名优先级列表
|
||||
configNames := []string{
|
||||
"config.yaml", "config.yml",
|
||||
"config.toml",
|
||||
"config.ini",
|
||||
"config.json",
|
||||
".config.yaml", ".config.yml",
|
||||
".config.toml",
|
||||
".config.ini",
|
||||
".config.json",
|
||||
}
|
||||
|
||||
// 如果未指定搜索目录,使用当前目录
|
||||
if searchDir == "" {
|
||||
var err error
|
||||
searchDir, err = os.Getwd()
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("获取当前目录失败:%w", err)
|
||||
}
|
||||
}
|
||||
|
||||
// 只在当前目录下查找,不向上查找
|
||||
for _, name := range configNames {
|
||||
filePath := filepath.Join(searchDir, name)
|
||||
if _, err := os.Stat(filePath); err == nil {
|
||||
return filePath, nil
|
||||
}
|
||||
}
|
||||
|
||||
return "", fmt.Errorf("未找到配置文件(支持 yaml, yml, toml, ini, json 格式)")
|
||||
}
|
||||
|
||||
// loadAndConnect 从配置文件加载并创建数据库连接
|
||||
func loadAndConnect(configPath string, debug bool) (*Database, error) {
|
||||
// 这里需要调用 config 包的 LoadFromFile
|
||||
// 为了避免循环依赖,我们直接在 core 包中实现简单的 YAML 解析
|
||||
// 或者通过接口传递配置
|
||||
|
||||
// 简单方案:返回错误,提示使用 config 包
|
||||
return nil, fmt.Errorf("请使用 config.AutoConnect() 方法")
|
||||
}
|
||||
|
||||
// GetDefaultDatabase 获取全局默认数据库实例
|
||||
func GetDefaultDatabase() *Database {
|
||||
return defaultDatabase
|
||||
}
|
||||
|
||||
// SetDefaultDatabase 设置全局默认数据库实例
|
||||
func SetDefaultDatabase(db *Database) {
|
||||
defaultDatabase = db
|
||||
}
|
||||
|
|
@ -1,94 +0,0 @@
|
|||
package core
|
||||
|
||||
import (
|
||||
"reflect"
|
||||
"time"
|
||||
)
|
||||
|
||||
// ParamFilter 参数过滤器 - 智能过滤零值和空值字段
|
||||
type ParamFilter struct{}
|
||||
|
||||
// NewParamFilter 创建参数过滤器实例
|
||||
func NewParamFilter() *ParamFilter {
|
||||
return &ParamFilter{}
|
||||
}
|
||||
|
||||
// FilterZeroValues 过滤零值和空值字段
|
||||
func (pf *ParamFilter) FilterZeroValues(data map[string]interface{}) map[string]interface{} {
|
||||
result := make(map[string]interface{})
|
||||
|
||||
for key, value := range data {
|
||||
if !pf.isZeroValue(value) {
|
||||
result[key] = value
|
||||
}
|
||||
}
|
||||
|
||||
return result
|
||||
}
|
||||
|
||||
// FilterEmptyStrings 过滤空字符串
|
||||
func (pf *ParamFilter) FilterEmptyStrings(data map[string]interface{}) map[string]interface{} {
|
||||
result := make(map[string]interface{})
|
||||
|
||||
for key, value := range data {
|
||||
if str, ok := value.(string); ok {
|
||||
if str != "" {
|
||||
result[key] = value
|
||||
}
|
||||
} else {
|
||||
result[key] = value
|
||||
}
|
||||
}
|
||||
|
||||
return result
|
||||
}
|
||||
|
||||
// FilterNilValues 过滤 nil 值
|
||||
func (pf *ParamFilter) FilterNilValues(data map[string]interface{}) map[string]interface{} {
|
||||
result := make(map[string]interface{})
|
||||
|
||||
for key, value := range data {
|
||||
if value != nil {
|
||||
result[key] = value
|
||||
}
|
||||
}
|
||||
|
||||
return result
|
||||
}
|
||||
|
||||
// isZeroValue 检查是否是零值
|
||||
func (pf *ParamFilter) isZeroValue(v interface{}) bool {
|
||||
if v == nil {
|
||||
return true
|
||||
}
|
||||
|
||||
val := reflect.ValueOf(v)
|
||||
|
||||
switch val.Kind() {
|
||||
case reflect.Array, reflect.Map, reflect.Slice, reflect.String:
|
||||
return val.Len() == 0
|
||||
case reflect.Bool:
|
||||
return !val.Bool()
|
||||
case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64:
|
||||
return val.Int() == 0
|
||||
case reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64, reflect.Uintptr:
|
||||
return val.Uint() == 0
|
||||
case reflect.Float32, reflect.Float64:
|
||||
return val.Float() == 0
|
||||
case reflect.Interface, reflect.Ptr:
|
||||
return val.IsNil()
|
||||
case reflect.Struct:
|
||||
// 特殊处理 time.Time
|
||||
if t, ok := v.(time.Time); ok {
|
||||
return t.IsZero()
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
return false
|
||||
}
|
||||
|
||||
// IsValidValue 检查值是否有效(非零值、非空值)
|
||||
func (pf *ParamFilter) IsValidValue(v interface{}) bool {
|
||||
return !pf.isZeroValue(v)
|
||||
}
|
||||
|
|
@ -1,254 +0,0 @@
|
|||
package core
|
||||
|
||||
import (
|
||||
"database/sql"
|
||||
"time"
|
||||
)
|
||||
|
||||
// IDatabase 数据库连接接口 - 提供所有数据库操作的顶层接口
|
||||
type IDatabase interface {
|
||||
// 基础操作
|
||||
DB() *sql.DB // 返回底层的 sql.DB 对象
|
||||
Close() error // 关闭数据库连接
|
||||
Ping() error // 测试数据库连接是否正常
|
||||
|
||||
// 事务管理
|
||||
Begin() (ITx, error) // 开始一个新事务
|
||||
Transaction(fn func(ITx) error) error // 执行事务,自动提交或回滚
|
||||
|
||||
// 查询构建器
|
||||
Model(model interface{}) IQuery // 基于模型创建查询
|
||||
Table(name string) IQuery // 基于表名创建查询
|
||||
Query(result interface{}, query string, args ...interface{}) error // 执行原生 SQL 查询
|
||||
Exec(query string, args ...interface{}) (sql.Result, error) // 执行原生 SQL 并返回结果
|
||||
|
||||
// 迁移管理
|
||||
Migrate(models ...interface{}) error // 执行数据库迁移
|
||||
|
||||
// 配置
|
||||
SetDebug(bool) // 设置调试模式
|
||||
SetMaxIdleConns(int) // 设置最大空闲连接数
|
||||
SetMaxOpenConns(int) // 设置最大打开连接数
|
||||
SetConnMaxLifetime(time.Duration) // 设置连接最大生命周期
|
||||
}
|
||||
|
||||
// ITx 事务接口 - 提供事务操作的所有方法
|
||||
type ITx interface {
|
||||
// 基础操作
|
||||
Commit() error // 提交事务
|
||||
Rollback() error // 回滚事务
|
||||
|
||||
// 查询操作
|
||||
Model(model interface{}) IQuery // 在事务中基于模型创建查询
|
||||
Table(name string) IQuery // 在事务中基于表名创建查询
|
||||
Insert(model interface{}) (int64, error) // 插入数据,返回插入的 ID
|
||||
BatchInsert(models interface{}, batchSize int) error // 批量插入数据
|
||||
Update(model interface{}, data map[string]interface{}) error // 更新数据
|
||||
Delete(model interface{}) error // 删除数据
|
||||
|
||||
// 原生 SQL
|
||||
Query(result interface{}, query string, args ...interface{}) error // 执行原生 SQL 查询
|
||||
Exec(query string, args ...interface{}) (sql.Result, error) // 执行原生 SQL
|
||||
}
|
||||
|
||||
// IQuery 查询构建器接口 - 提供流畅的链式查询构建能力
|
||||
type IQuery interface {
|
||||
// 条件查询
|
||||
Where(query string, args ...interface{}) IQuery // 添加 WHERE 条件
|
||||
Or(query string, args ...interface{}) IQuery // 添加 OR 条件
|
||||
And(query string, args ...interface{}) IQuery // 添加 AND 条件
|
||||
|
||||
// 字段选择
|
||||
Select(fields ...string) IQuery // 选择要查询的字段
|
||||
Omit(fields ...string) IQuery // 排除指定的字段
|
||||
|
||||
// 排序
|
||||
Order(order string) IQuery // 设置排序规则
|
||||
OrderBy(field string, direction string) IQuery // 按指定字段和方向排序
|
||||
|
||||
// 分页
|
||||
Limit(limit int) IQuery // 限制返回数量
|
||||
Offset(offset int) IQuery // 设置偏移量
|
||||
Page(page, pageSize int) IQuery // 分页查询
|
||||
|
||||
// 分组
|
||||
Group(group string) IQuery // 设置分组字段
|
||||
Having(having string, args ...interface{}) IQuery // 添加 HAVING 条件
|
||||
|
||||
// 连接
|
||||
Join(join string, args ...interface{}) IQuery // 添加 JOIN 连接
|
||||
LeftJoin(table, on string) IQuery // 左连接
|
||||
RightJoin(table, on string) IQuery // 右连接
|
||||
InnerJoin(table, on string) IQuery // 内连接
|
||||
|
||||
// 预加载
|
||||
Preload(relation string, conditions ...interface{}) IQuery // 预加载关联数据
|
||||
|
||||
// 执行查询
|
||||
First(result interface{}) error // 查询第一条记录
|
||||
Find(result interface{}) error // 查询多条记录
|
||||
Count(count *int64) IQuery // 统计记录数量
|
||||
Exists() (bool, error) // 检查记录是否存在
|
||||
|
||||
// 更新和删除
|
||||
Updates(data interface{}) error // 更新数据
|
||||
UpdateColumn(column string, value interface{}) error // 更新单个字段
|
||||
Delete() error // 删除数据
|
||||
|
||||
// 特殊模式
|
||||
Unscoped() IQuery // 忽略软删除
|
||||
DryRun() IQuery // 干跑模式,不执行只生成 SQL
|
||||
Debug() IQuery // 调试模式,打印 SQL 日志
|
||||
|
||||
// 构建 SQL(不执行)
|
||||
Build() (string, []interface{}) // 构建 SELECT SQL 语句
|
||||
BuildUpdate(data interface{}) (string, []interface{}) // 构建 UPDATE SQL 语句
|
||||
BuildDelete() (string, []interface{}) // 构建 DELETE SQL 语句
|
||||
}
|
||||
|
||||
// IModel 模型接口 - 定义模型的基本行为和生命周期回调
|
||||
type IModel interface {
|
||||
// 表名映射
|
||||
TableName() string // 返回模型对应的表名
|
||||
|
||||
// 生命周期回调(可选实现)
|
||||
BeforeCreate(tx ITx) error // 创建前回调
|
||||
AfterCreate(tx ITx) error // 创建后回调
|
||||
BeforeUpdate(tx ITx) error // 更新前回调
|
||||
AfterUpdate(tx ITx) error // 更新后回调
|
||||
BeforeDelete(tx ITx) error // 删除前回调
|
||||
AfterDelete(tx ITx) error // 删除后回调
|
||||
BeforeSave(tx ITx) error // 保存前回调
|
||||
AfterSave(tx ITx) error // 保存后回调
|
||||
}
|
||||
|
||||
// IFieldMapper 字段映射器接口 - 处理 Go 结构体与数据库字段之间的映射
|
||||
type IFieldMapper interface {
|
||||
// 结构体字段转数据库列
|
||||
StructToColumns(model interface{}) (map[string]interface{}, error) // 将结构体转换为键值对
|
||||
|
||||
// 数据库列转结构体字段
|
||||
ColumnsToStruct(row *sql.Rows, model interface{}) error // 将查询结果映射到结构体
|
||||
|
||||
// 获取表名
|
||||
GetTableName(model interface{}) string // 获取模型对应的表名
|
||||
|
||||
// 获取主键字段
|
||||
GetPrimaryKey(model interface{}) string // 获取主键字段名
|
||||
|
||||
// 获取字段信息
|
||||
GetFields(model interface{}) []FieldInfo // 获取所有字段信息
|
||||
}
|
||||
|
||||
// FieldInfo 字段信息 - 描述数据库字段的详细信息
|
||||
type FieldInfo struct {
|
||||
Name string // 字段名(Go 结构体字段名)
|
||||
Column string // 列名(数据库中的实际列名)
|
||||
Type string // Go 类型(如 string, int, time.Time 等)
|
||||
DbType string // 数据库类型(如 VARCHAR, INT, DATETIME 等)
|
||||
Tag string // 标签(db 标签内容)
|
||||
IsPrimary bool // 是否主键
|
||||
IsAuto bool // 是否自增
|
||||
}
|
||||
|
||||
// IMigrator 迁移管理器接口 - 提供数据库架构迁移的所有操作
|
||||
type IMigrator interface {
|
||||
// 自动迁移
|
||||
AutoMigrate(models ...interface{}) error // 自动执行模型迁移
|
||||
|
||||
// 表操作
|
||||
CreateTable(model interface{}) error // 创建表
|
||||
DropTable(model interface{}) error // 删除表
|
||||
HasTable(model interface{}) (bool, error) // 检查表是否存在
|
||||
RenameTable(oldName, newName string) error // 重命名表
|
||||
|
||||
// 列操作
|
||||
AddColumn(model interface{}, field string) error // 添加列
|
||||
DropColumn(model interface{}, field string) error // 删除列
|
||||
HasColumn(model interface{}, field string) (bool, error) // 检查列是否存在
|
||||
RenameColumn(model interface{}, oldField, newField string) error // 重命名列
|
||||
|
||||
// 索引操作
|
||||
CreateIndex(model interface{}, field string) error // 创建索引
|
||||
DropIndex(model interface{}, field string) error // 删除索引
|
||||
HasIndex(model interface{}, field string) (bool, error) // 检查索引是否存在
|
||||
}
|
||||
|
||||
// ICodeGenerator 代码生成器接口 - 自动生成 Model 和 DAO 代码
|
||||
type ICodeGenerator interface {
|
||||
// 生成 Model 代码
|
||||
GenerateModel(table string, outputDir string) error // 根据表生成 Model 文件
|
||||
|
||||
// 生成 DAO 代码
|
||||
GenerateDAO(table string, outputDir string) error // 根据表生成 DAO 文件
|
||||
|
||||
// 生成完整代码
|
||||
GenerateAll(tables []string, outputDir string) error // 批量生成所有代码
|
||||
|
||||
// 从数据库读取表结构
|
||||
InspectTable(tableName string) (*TableSchema, error) // 检查表结构
|
||||
}
|
||||
|
||||
// TableSchema 表结构信息 - 描述数据库表的完整结构
|
||||
type TableSchema struct {
|
||||
Name string // 表名
|
||||
Columns []ColumnInfo // 列信息列表
|
||||
Indexes []IndexInfo // 索引信息列表
|
||||
}
|
||||
|
||||
// ColumnInfo 列信息 - 描述表中一个列的详细信息
|
||||
type ColumnInfo struct {
|
||||
Name string // 列名
|
||||
Type string // 数据类型
|
||||
Nullable bool // 是否允许为空
|
||||
Default interface{} // 默认值
|
||||
PrimaryKey bool // 是否主键
|
||||
}
|
||||
|
||||
// IndexInfo 索引信息 - 描述表中一个索引的详细信息
|
||||
type IndexInfo struct {
|
||||
Name string // 索引名
|
||||
Columns []string // 索引包含的列
|
||||
Unique bool // 是否唯一索引
|
||||
}
|
||||
|
||||
// ReadPolicy 读负载均衡策略 - 定义主从集群中读操作的分配策略
|
||||
type ReadPolicy int
|
||||
|
||||
const (
|
||||
Random ReadPolicy = iota // 随机选择一个从库
|
||||
RoundRobin // 轮询方式选择从库
|
||||
LeastConn // 选择连接数最少的从库
|
||||
)
|
||||
|
||||
// Config 数据库配置 - 包含数据库连接的所有配置项
|
||||
type Config struct {
|
||||
DriverName string // 驱动名称(如 mysql, sqlite, postgres 等)
|
||||
DataSource string // 数据源连接字符串(DNS)
|
||||
MaxIdleConns int // 最大空闲连接数
|
||||
MaxOpenConns int // 最大打开连接数
|
||||
ConnMaxLifetime time.Duration // 连接最大生命周期
|
||||
Debug bool // 调试模式(是否打印 SQL 日志)
|
||||
|
||||
// 主从配置
|
||||
Replicas []string // 从库列表(用于读写分离)
|
||||
ReadPolicy ReadPolicy // 读负载均衡策略
|
||||
|
||||
// OpenTelemetry 可观测性配置
|
||||
EnableTracing bool // 是否启用链路追踪
|
||||
ServiceName string // 服务名称(用于 Tracing)
|
||||
|
||||
// 时间配置
|
||||
TimeConfig *TimeConfig // 时间字段配置(字段名、格式等)
|
||||
}
|
||||
|
||||
// Database 数据库实现 - IDatabase 接口的具体实现
|
||||
type Database struct {
|
||||
db *sql.DB // 底层数据库连接
|
||||
config *Config // 数据库配置
|
||||
debug bool // 调试模式开关
|
||||
mapper IFieldMapper // 字段映射器实例
|
||||
migrator IMigrator // 迁移管理器实例
|
||||
driverName string // 驱动名称
|
||||
timeConfig *TimeConfig // 时间配置
|
||||
}
|
||||
|
|
@ -1,306 +0,0 @@
|
|||
package core
|
||||
|
||||
import (
|
||||
"database/sql"
|
||||
"errors"
|
||||
"fmt"
|
||||
"reflect"
|
||||
"strings"
|
||||
"time"
|
||||
)
|
||||
|
||||
// FieldMapper 字段映射器实现 - 使用反射处理 Go 结构体与数据库字段之间的映射
|
||||
type FieldMapper struct{}
|
||||
|
||||
// NewFieldMapper 创建字段映射器实例
|
||||
func NewFieldMapper() IFieldMapper {
|
||||
return &FieldMapper{}
|
||||
}
|
||||
|
||||
// StructToColumns 将结构体转换为键值对 - 用于 INSERT/UPDATE 操作
|
||||
func (fm *FieldMapper) StructToColumns(model interface{}) (map[string]interface{}, error) {
|
||||
result := make(map[string]interface{})
|
||||
|
||||
// 获取反射对象
|
||||
val := reflect.ValueOf(model)
|
||||
if val.Kind() == reflect.Ptr {
|
||||
val = val.Elem()
|
||||
}
|
||||
|
||||
if val.Kind() != reflect.Struct {
|
||||
return nil, errors.New("模型必须是结构体")
|
||||
}
|
||||
|
||||
typ := val.Type()
|
||||
|
||||
// 遍历所有字段
|
||||
for i := 0; i < val.NumField(); i++ {
|
||||
field := typ.Field(i)
|
||||
value := val.Field(i)
|
||||
|
||||
// 跳过未导出的字段
|
||||
if !field.IsExported() {
|
||||
continue
|
||||
}
|
||||
|
||||
// 获取 db 标签
|
||||
dbTag := field.Tag.Get("db")
|
||||
if dbTag == "" || dbTag == "-" {
|
||||
continue // 跳过没有 db 标签或标签为 - 的字段
|
||||
}
|
||||
|
||||
// 跳过零值(可选优化)
|
||||
if fm.isZeroValue(value) {
|
||||
continue
|
||||
}
|
||||
|
||||
// 添加到结果 map
|
||||
result[dbTag] = value.Interface()
|
||||
}
|
||||
|
||||
return result, nil
|
||||
}
|
||||
|
||||
// ColumnsToStruct 将查询结果映射到结构体 - 用于 SELECT 操作
|
||||
func (fm *FieldMapper) ColumnsToStruct(rows *sql.Rows, model interface{}) error {
|
||||
// 获取列信息
|
||||
columns, err := rows.Columns()
|
||||
if err != nil {
|
||||
return fmt.Errorf("获取列信息失败:%w", err)
|
||||
}
|
||||
|
||||
// 获取反射对象
|
||||
val := reflect.ValueOf(model)
|
||||
if val.Kind() != reflect.Ptr {
|
||||
return errors.New("模型必须是指针类型")
|
||||
}
|
||||
|
||||
elem := val.Elem()
|
||||
if elem.Kind() != reflect.Struct {
|
||||
return errors.New("模型必须是指向结构体的指针")
|
||||
}
|
||||
|
||||
// 创建扫描目标
|
||||
scanTargets := make([]interface{}, len(columns))
|
||||
fieldMap := make(map[int]int) // column index -> field index
|
||||
|
||||
// 建立列名到结构体字段的映射
|
||||
for i, col := range columns {
|
||||
found := false
|
||||
for j := 0; j < elem.NumField(); j++ {
|
||||
field := elem.Type().Field(j)
|
||||
dbTag := field.Tag.Get("db")
|
||||
|
||||
// 匹配列名和字段
|
||||
if dbTag == col || strings.ToLower(dbTag) == strings.ToLower(col) ||
|
||||
strings.ToLower(field.Name) == strings.ToLower(col) {
|
||||
fieldMap[i] = j
|
||||
found = true
|
||||
break
|
||||
}
|
||||
}
|
||||
|
||||
// 如果没找到匹配字段,使用 interface{} 占位
|
||||
if !found {
|
||||
var dummy interface{}
|
||||
scanTargets[i] = &dummy
|
||||
}
|
||||
}
|
||||
|
||||
// 为找到的字段创建扫描目标
|
||||
for i := range columns {
|
||||
if fieldIdx, ok := fieldMap[i]; ok {
|
||||
field := elem.Field(fieldIdx)
|
||||
if field.CanSet() {
|
||||
scanTargets[i] = field.Addr().Interface()
|
||||
} else {
|
||||
var dummy interface{}
|
||||
scanTargets[i] = &dummy
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// 执行扫描
|
||||
if err := rows.Scan(scanTargets...); err != nil {
|
||||
return fmt.Errorf("扫描数据失败:%w", err)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// GetTableName 获取模型对应的表名
|
||||
func (fm *FieldMapper) GetTableName(model interface{}) string {
|
||||
// 检查是否实现了 TableName() 方法
|
||||
type tabler interface {
|
||||
TableName() string
|
||||
}
|
||||
|
||||
if t, ok := model.(tabler); ok {
|
||||
return t.TableName()
|
||||
}
|
||||
|
||||
// 否则使用结构体名称
|
||||
val := reflect.ValueOf(model)
|
||||
if val.Kind() == reflect.Ptr {
|
||||
val = val.Elem()
|
||||
}
|
||||
|
||||
typ := val.Type()
|
||||
return fm.toSnakeCase(typ.Name())
|
||||
}
|
||||
|
||||
// GetPrimaryKey 获取主键字段名 - 默认为 "id"
|
||||
func (fm *FieldMapper) GetPrimaryKey(model interface{}) string {
|
||||
// 查找标记为主键的字段
|
||||
val := reflect.ValueOf(model)
|
||||
if val.Kind() == reflect.Ptr {
|
||||
val = val.Elem()
|
||||
}
|
||||
|
||||
typ := val.Type()
|
||||
for i := 0; i < val.NumField(); i++ {
|
||||
field := typ.Field(i)
|
||||
|
||||
// 检查是否是 ID 字段
|
||||
fieldName := field.Name
|
||||
if fieldName == "ID" || fieldName == "Id" || fieldName == "id" {
|
||||
dbTag := field.Tag.Get("db")
|
||||
if dbTag != "" && dbTag != "-" {
|
||||
return dbTag
|
||||
}
|
||||
return "id"
|
||||
}
|
||||
|
||||
// 检查是否有 primary 标签
|
||||
if field.Tag.Get("primary") == "true" {
|
||||
dbTag := field.Tag.Get("db")
|
||||
if dbTag != "" {
|
||||
return dbTag
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return "id" // 默认返回 id
|
||||
}
|
||||
|
||||
// GetFields 获取所有字段信息 - 用于生成 SQL 语句
|
||||
func (fm *FieldMapper) GetFields(model interface{}) []FieldInfo {
|
||||
var fields []FieldInfo
|
||||
|
||||
val := reflect.ValueOf(model)
|
||||
if val.Kind() == reflect.Ptr {
|
||||
val = val.Elem()
|
||||
}
|
||||
|
||||
typ := val.Type()
|
||||
|
||||
// 遍历所有字段
|
||||
for i := 0; i < val.NumField(); i++ {
|
||||
field := typ.Field(i)
|
||||
|
||||
// 跳过未导出的字段
|
||||
if !field.IsExported() {
|
||||
continue
|
||||
}
|
||||
|
||||
// 获取 db 标签
|
||||
dbTag := field.Tag.Get("db")
|
||||
if dbTag == "" || dbTag == "-" {
|
||||
continue
|
||||
}
|
||||
|
||||
// 创建字段信息
|
||||
info := FieldInfo{
|
||||
Name: field.Name,
|
||||
Column: dbTag,
|
||||
Type: fm.getTypeName(field.Type),
|
||||
DbType: fm.mapToDbType(field.Type),
|
||||
Tag: dbTag,
|
||||
}
|
||||
|
||||
// 检查是否是主键
|
||||
if field.Tag.Get("primary") == "true" ||
|
||||
field.Name == "ID" || field.Name == "Id" {
|
||||
info.IsPrimary = true
|
||||
}
|
||||
|
||||
// 检查是否是自增
|
||||
if field.Tag.Get("auto") == "true" {
|
||||
info.IsAuto = true
|
||||
}
|
||||
|
||||
fields = append(fields, info)
|
||||
}
|
||||
|
||||
return fields
|
||||
}
|
||||
|
||||
// isZeroValue 检查是否是零值
|
||||
func (fm *FieldMapper) isZeroValue(v reflect.Value) bool {
|
||||
switch v.Kind() {
|
||||
case reflect.Array, reflect.Map, reflect.Slice, reflect.String:
|
||||
return v.Len() == 0
|
||||
case reflect.Bool:
|
||||
return !v.Bool()
|
||||
case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64:
|
||||
return v.Int() == 0
|
||||
case reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64, reflect.Uintptr:
|
||||
return v.Uint() == 0
|
||||
case reflect.Float32, reflect.Float64:
|
||||
return v.Float() == 0
|
||||
case reflect.Interface, reflect.Ptr:
|
||||
return v.IsNil()
|
||||
case reflect.Struct:
|
||||
// 特殊处理 time.Time
|
||||
if t, ok := v.Interface().(time.Time); ok {
|
||||
return t.IsZero()
|
||||
}
|
||||
return false
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
// getTypeName 获取类型的名称
|
||||
func (fm *FieldMapper) getTypeName(t reflect.Type) string {
|
||||
return t.String()
|
||||
}
|
||||
|
||||
// mapToDbType 将 Go 类型映射到数据库类型
|
||||
func (fm *FieldMapper) mapToDbType(t reflect.Type) string {
|
||||
switch t.Kind() {
|
||||
case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64:
|
||||
return "BIGINT"
|
||||
case reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64:
|
||||
return "BIGINT UNSIGNED"
|
||||
case reflect.Float32, reflect.Float64:
|
||||
return "DECIMAL"
|
||||
case reflect.Bool:
|
||||
return "TINYINT"
|
||||
case reflect.String:
|
||||
return "VARCHAR(255)"
|
||||
default:
|
||||
// 特殊类型
|
||||
if t.PkgPath() == "time" && t.Name() == "Time" {
|
||||
return "DATETIME"
|
||||
}
|
||||
return "TEXT"
|
||||
}
|
||||
}
|
||||
|
||||
// toSnakeCase 将驼峰命名转换为下划线命名
|
||||
func (fm *FieldMapper) toSnakeCase(str string) string {
|
||||
var result strings.Builder
|
||||
|
||||
for i, r := range str {
|
||||
if r >= 'A' && r <= 'Z' {
|
||||
if i > 0 {
|
||||
result.WriteRune('_')
|
||||
}
|
||||
result.WriteRune(r + 32) // 转换为小写
|
||||
} else {
|
||||
result.WriteRune(r)
|
||||
}
|
||||
}
|
||||
|
||||
return result.String()
|
||||
}
|
||||
|
|
@ -1,292 +0,0 @@
|
|||
package core
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"strings"
|
||||
)
|
||||
|
||||
// Migrator 迁移管理器实现 - 处理数据库架构的自动迁移
|
||||
type Migrator struct {
|
||||
db *Database // 数据库连接实例
|
||||
}
|
||||
|
||||
// NewMigrator 创建迁移管理器实例
|
||||
func NewMigrator(db *Database) IMigrator {
|
||||
return &Migrator{db: db}
|
||||
}
|
||||
|
||||
// AutoMigrate 自动迁移 - 根据模型自动创建或更新数据库表结构
|
||||
func (m *Migrator) AutoMigrate(models ...interface{}) error {
|
||||
for _, model := range models {
|
||||
if err := m.CreateTable(model); err != nil {
|
||||
return fmt.Errorf("创建表失败:%w", err)
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// CreateTable 创建表 - 根据模型创建数据库表
|
||||
func (m *Migrator) CreateTable(model interface{}) error {
|
||||
mapper := NewFieldMapper()
|
||||
|
||||
// 获取表名
|
||||
tableName := mapper.GetTableName(model)
|
||||
|
||||
// 获取字段信息
|
||||
fields := mapper.GetFields(model)
|
||||
if len(fields) == 0 {
|
||||
return fmt.Errorf("模型没有有效的字段")
|
||||
}
|
||||
|
||||
// 生成 CREATE TABLE SQL
|
||||
var sqlBuilder strings.Builder
|
||||
sqlBuilder.WriteString(fmt.Sprintf("CREATE TABLE IF NOT EXISTS %s (", tableName))
|
||||
|
||||
columnDefs := make([]string, 0)
|
||||
for _, field := range fields {
|
||||
colDef := fmt.Sprintf("%s %s", field.Column, field.DbType)
|
||||
|
||||
// 添加主键约束
|
||||
if field.IsPrimary {
|
||||
colDef += " PRIMARY KEY"
|
||||
if field.IsAuto {
|
||||
colDef += " AUTOINCREMENT"
|
||||
}
|
||||
}
|
||||
|
||||
// 添加 NOT NULL 约束(可选)
|
||||
// colDef += " NOT NULL"
|
||||
|
||||
columnDefs = append(columnDefs, colDef)
|
||||
}
|
||||
|
||||
sqlBuilder.WriteString(strings.Join(columnDefs, ", "))
|
||||
sqlBuilder.WriteString(")")
|
||||
|
||||
createSQL := sqlBuilder.String()
|
||||
|
||||
if m.db.debug {
|
||||
fmt.Printf("[Magic-ORM] CREATE TABLE SQL: %s\n", createSQL)
|
||||
}
|
||||
|
||||
// 执行 SQL
|
||||
_, err := m.db.db.Exec(createSQL)
|
||||
if err != nil {
|
||||
return fmt.Errorf("执行 CREATE TABLE 失败:%w", err)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// DropTable 删除表 - 删除指定的数据库表
|
||||
func (m *Migrator) DropTable(model interface{}) error {
|
||||
mapper := NewFieldMapper()
|
||||
tableName := mapper.GetTableName(model)
|
||||
|
||||
dropSQL := fmt.Sprintf("DROP TABLE IF EXISTS %s", tableName)
|
||||
|
||||
if m.db.debug {
|
||||
fmt.Printf("[Magic-ORM] DROP TABLE SQL: %s\n", dropSQL)
|
||||
}
|
||||
|
||||
_, err := m.db.db.Exec(dropSQL)
|
||||
if err != nil {
|
||||
return fmt.Errorf("执行 DROP TABLE 失败:%w", err)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// HasTable 检查表是否存在 - 验证数据库中是否已存在指定表
|
||||
func (m *Migrator) HasTable(model interface{}) (bool, error) {
|
||||
mapper := NewFieldMapper()
|
||||
tableName := mapper.GetTableName(model)
|
||||
|
||||
// SQLite 检查表是否存在的 SQL
|
||||
checkSQL := `SELECT COUNT(*) FROM sqlite_master WHERE type='table' AND name=?`
|
||||
|
||||
var count int
|
||||
err := m.db.db.QueryRow(checkSQL, tableName).Scan(&count)
|
||||
if err != nil {
|
||||
return false, fmt.Errorf("检查表是否存在失败:%w", err)
|
||||
}
|
||||
|
||||
return count > 0, nil
|
||||
}
|
||||
|
||||
// RenameTable 重命名表 - 修改数据库表的名称
|
||||
func (m *Migrator) RenameTable(oldName, newName string) error {
|
||||
renameSQL := fmt.Sprintf("ALTER TABLE %s RENAME TO %s", oldName, newName)
|
||||
|
||||
if m.db.debug {
|
||||
fmt.Printf("[Magic-ORM] RENAME TABLE SQL: %s\n", renameSQL)
|
||||
}
|
||||
|
||||
_, err := m.db.db.Exec(renameSQL)
|
||||
if err != nil {
|
||||
return fmt.Errorf("重命名表失败:%w", err)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// AddColumn 添加列 - 向表中添加新的字段
|
||||
func (m *Migrator) AddColumn(model interface{}, field string) error {
|
||||
mapper := NewFieldMapper()
|
||||
tableName := mapper.GetTableName(model)
|
||||
|
||||
// 获取字段信息
|
||||
fields := mapper.GetFields(model)
|
||||
var targetField *FieldInfo
|
||||
|
||||
for _, f := range fields {
|
||||
if f.Name == field || f.Column == field {
|
||||
targetField = &f
|
||||
break
|
||||
}
|
||||
}
|
||||
|
||||
if targetField == nil {
|
||||
return fmt.Errorf("字段不存在:%s", field)
|
||||
}
|
||||
|
||||
addSQL := fmt.Sprintf("ALTER TABLE %s ADD COLUMN %s %s",
|
||||
tableName, targetField.Column, targetField.DbType)
|
||||
|
||||
if m.db.debug {
|
||||
fmt.Printf("[Magic-ORM] ADD COLUMN SQL: %s\n", addSQL)
|
||||
}
|
||||
|
||||
_, err := m.db.db.Exec(addSQL)
|
||||
if err != nil {
|
||||
return fmt.Errorf("添加列失败:%w", err)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// DropColumn 删除列 - 从表中删除指定的字段
|
||||
func (m *Migrator) DropColumn(model interface{}, field string) error {
|
||||
mapper := NewFieldMapper()
|
||||
tableName := mapper.GetTableName(model)
|
||||
|
||||
// SQLite 不直接支持 DROP COLUMN,需要重建表
|
||||
// 这里使用简化方案:创建新表 -> 复制数据 -> 删除旧表 -> 重命名
|
||||
|
||||
_ = tableName // 避免编译错误
|
||||
return fmt.Errorf("SQLite 不支持直接删除列,需要手动重建表")
|
||||
}
|
||||
|
||||
// HasColumn 检查列是否存在 - 验证表中是否已存在指定字段
|
||||
func (m *Migrator) HasColumn(model interface{}, field string) (bool, error) {
|
||||
mapper := NewFieldMapper()
|
||||
tableName := mapper.GetTableName(model)
|
||||
|
||||
// SQLite 检查列是否存在的 SQL
|
||||
checkSQL := `PRAGMA table_info(` + tableName + `)`
|
||||
|
||||
rows, err := m.db.db.Query(checkSQL)
|
||||
if err != nil {
|
||||
return false, fmt.Errorf("检查列失败:%w", err)
|
||||
}
|
||||
defer rows.Close()
|
||||
|
||||
for rows.Next() {
|
||||
var cid int
|
||||
var name string
|
||||
var typ string
|
||||
var notNull int
|
||||
var dfltValue interface{}
|
||||
var pk int
|
||||
|
||||
if err := rows.Scan(&cid, &name, &typ, ¬Null, &dfltValue, &pk); err != nil {
|
||||
return false, err
|
||||
}
|
||||
|
||||
if name == field {
|
||||
return true, nil
|
||||
}
|
||||
}
|
||||
|
||||
return false, nil
|
||||
}
|
||||
|
||||
// RenameColumn 重命名列 - 修改表中字段的名称
|
||||
func (m *Migrator) RenameColumn(model interface{}, oldField, newField string) error {
|
||||
mapper := NewFieldMapper()
|
||||
tableName := mapper.GetTableName(model)
|
||||
|
||||
// SQLite 3.25.0+ 支持 ALTER TABLE ... RENAME COLUMN
|
||||
renameSQL := fmt.Sprintf("ALTER TABLE %s RENAME COLUMN %s TO %s",
|
||||
tableName, oldField, newField)
|
||||
|
||||
if m.db.debug {
|
||||
fmt.Printf("[Magic-ORM] RENAME COLUMN SQL: %s\n", renameSQL)
|
||||
}
|
||||
|
||||
_, err := m.db.db.Exec(renameSQL)
|
||||
if err != nil {
|
||||
return fmt.Errorf("重命名列失败:%w", err)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// CreateIndex 创建索引 - 为表中的字段创建索引
|
||||
func (m *Migrator) CreateIndex(model interface{}, field string) error {
|
||||
mapper := NewFieldMapper()
|
||||
tableName := mapper.GetTableName(model)
|
||||
|
||||
indexName := fmt.Sprintf("idx_%s_%s", tableName, field)
|
||||
createSQL := fmt.Sprintf("CREATE INDEX IF NOT EXISTS %s ON %s (%s)",
|
||||
indexName, tableName, field)
|
||||
|
||||
if m.db.debug {
|
||||
fmt.Printf("[Magic-ORM] CREATE INDEX SQL: %s\n", createSQL)
|
||||
}
|
||||
|
||||
_, err := m.db.db.Exec(createSQL)
|
||||
if err != nil {
|
||||
return fmt.Errorf("创建索引失败:%w", err)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// DropIndex 删除索引 - 删除表中的指定索引
|
||||
func (m *Migrator) DropIndex(model interface{}, field string) error {
|
||||
mapper := NewFieldMapper()
|
||||
tableName := mapper.GetTableName(model)
|
||||
|
||||
indexName := fmt.Sprintf("idx_%s_%s", tableName, field)
|
||||
dropSQL := fmt.Sprintf("DROP INDEX IF EXISTS %s", indexName)
|
||||
|
||||
if m.db.debug {
|
||||
fmt.Printf("[Magic-ORM] DROP INDEX SQL: %s\n", dropSQL)
|
||||
}
|
||||
|
||||
_, err := m.db.db.Exec(dropSQL)
|
||||
if err != nil {
|
||||
return fmt.Errorf("删除索引失败:%w", err)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// HasIndex 检查索引是否存在 - 验证表中是否已存在指定索引
|
||||
func (m *Migrator) HasIndex(model interface{}, field string) (bool, error) {
|
||||
mapper := NewFieldMapper()
|
||||
tableName := mapper.GetTableName(model)
|
||||
|
||||
indexName := fmt.Sprintf("idx_%s_%s", tableName, field)
|
||||
|
||||
checkSQL := `SELECT COUNT(*) FROM sqlite_master WHERE type='index' AND name=?`
|
||||
|
||||
var count int
|
||||
err := m.db.db.QueryRow(checkSQL, indexName).Scan(&count)
|
||||
if err != nil {
|
||||
return false, fmt.Errorf("检查索引失败:%w", err)
|
||||
}
|
||||
|
||||
return count > 0, nil
|
||||
}
|
||||
|
|
@ -1,46 +0,0 @@
|
|||
package core
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
)
|
||||
|
||||
// ExampleQueryBuilder_Omit 演示 Omit 方法的使用
|
||||
func ExampleQueryBuilder_Omit() {
|
||||
// 定义用户模型
|
||||
type User struct {
|
||||
ID int64 `json:"id" db:"id"`
|
||||
Name string `json:"name" db:"name"`
|
||||
Email string `json:"email" db:"email"`
|
||||
Password string `json:"password" db:"password"`
|
||||
Status int `json:"status" db:"status"`
|
||||
}
|
||||
|
||||
// 创建 Database 实例(示例中使用 nil,实际使用需要正确初始化)
|
||||
db := &Database{}
|
||||
|
||||
// 示例 1: 排除敏感字段(如密码)
|
||||
q1 := db.Model(&User{}).Omit("password")
|
||||
sql1, _ := q1.(*QueryBuilder).BuildSelect()
|
||||
fmt.Printf("排除密码:%s\n", sql1)
|
||||
|
||||
// 示例 2: 排除多个字段
|
||||
q2 := db.Model(&User{}).Omit("password", "status")
|
||||
sql2, _ := q2.(*QueryBuilder).BuildSelect()
|
||||
fmt.Printf("排除多个字段:%s\n", sql2)
|
||||
|
||||
// 示例 3: 链式调用 Omit
|
||||
q3 := db.Model(&User{}).Omit("password").Omit("status")
|
||||
sql3, _ := q3.(*QueryBuilder).BuildSelect()
|
||||
fmt.Printf("链式调用:%s\n", sql3)
|
||||
|
||||
// 示例 4: Select 优先于 Omit
|
||||
q4 := db.Model(&User{}).Select("id", "name").Omit("password")
|
||||
sql4, _ := q4.(*QueryBuilder).BuildSelect()
|
||||
fmt.Printf("Select 优先:%s\n", sql4)
|
||||
|
||||
// 输出:
|
||||
// 排除密码:SELECT id, name, email, status FROM user_model
|
||||
// 排除多个字段:SELECT id, name, email FROM user_model
|
||||
// 链式调用:SELECT id, name, email FROM user_model
|
||||
// Select 优先:SELECT id, name FROM user_model
|
||||
}
|
||||
|
|
@ -1,128 +0,0 @@
|
|||
package core
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"testing"
|
||||
)
|
||||
|
||||
// TestQueryBuilder_Omit 测试 Omit 方法
|
||||
func TestQueryBuilder_Omit(t *testing.T) {
|
||||
// 创建测试模型
|
||||
type UserModel struct {
|
||||
ID int64 `json:"id" db:"id"`
|
||||
Name string `json:"name" db:"name"`
|
||||
Email string `json:"email" db:"email"`
|
||||
Password string `json:"password" db:"password"`
|
||||
Status int `json:"status" db:"status"`
|
||||
}
|
||||
|
||||
// 创建 Database 实例(使用 nil 连接,只测试 SQL 生成)
|
||||
db := &Database{}
|
||||
|
||||
t.Run("排除单个字段", func(t *testing.T) {
|
||||
qb := db.Model(&UserModel{}).Omit("password").(*QueryBuilder)
|
||||
sql, args := qb.BuildSelect()
|
||||
|
||||
fmt.Printf("排除单个字段 SQL: %s\n", sql)
|
||||
fmt.Printf("参数:%v\n", args)
|
||||
|
||||
// 验证 SQL 不包含 password 字段
|
||||
if containsString(sql, "password") {
|
||||
t.Errorf("SQL 不应该包含 password 字段:%s", sql)
|
||||
}
|
||||
|
||||
// 验证 SQL 包含其他字段
|
||||
expectedFields := []string{"id", "name", "email", "status"}
|
||||
for _, field := range expectedFields {
|
||||
if !containsString(sql, field) {
|
||||
t.Errorf("SQL 应该包含字段 %s: %s", field, sql)
|
||||
}
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("排除多个字段", func(t *testing.T) {
|
||||
qb := db.Model(&UserModel{}).Omit("password", "status").(*QueryBuilder)
|
||||
sql, args := qb.BuildSelect()
|
||||
|
||||
fmt.Printf("排除多个字段 SQL: %s\n", sql)
|
||||
fmt.Printf("参数:%v\n", args)
|
||||
|
||||
// 验证 SQL 不包含 password 和 status 字段
|
||||
if containsString(sql, "password") {
|
||||
t.Errorf("SQL 不应该包含 password 字段:%s", sql)
|
||||
}
|
||||
if containsString(sql, "status") {
|
||||
t.Errorf("SQL 不应该包含 status 字段:%s", sql)
|
||||
}
|
||||
|
||||
// 验证 SQL 包含其他字段
|
||||
expectedFields := []string{"id", "name", "email"}
|
||||
for _, field := range expectedFields {
|
||||
if !containsString(sql, field) {
|
||||
t.Errorf("SQL 应该包含字段 %s: %s", field, sql)
|
||||
}
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("Omit 与 Select 优先级 - Select 优先", func(t *testing.T) {
|
||||
qb := db.Model(&UserModel{}).Select("id", "name").Omit("password").(*QueryBuilder)
|
||||
sql, args := qb.BuildSelect()
|
||||
|
||||
fmt.Printf("Select 优先 SQL: %s\n", sql)
|
||||
fmt.Printf("参数:%v\n", args)
|
||||
|
||||
// 当同时使用 Select 和 Omit 时,Select 优先
|
||||
expectedFields := []string{"id", "name"}
|
||||
for _, field := range expectedFields {
|
||||
if !containsString(sql, field) {
|
||||
t.Errorf("SQL 应该包含字段 %s: %s", field, sql)
|
||||
}
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("链式调用 Omit", func(t *testing.T) {
|
||||
qb := db.Model(&UserModel{}).Omit("password").Omit("status").(*QueryBuilder)
|
||||
sql, args := qb.BuildSelect()
|
||||
|
||||
fmt.Printf("链式调用 Omit SQL: %s\n", sql)
|
||||
fmt.Printf("参数:%v\n", args)
|
||||
|
||||
// 验证 SQL 不包含 password 和 status 字段
|
||||
if containsString(sql, "password") {
|
||||
t.Errorf("SQL 不应该包含 password 字段:%s", sql)
|
||||
}
|
||||
if containsString(sql, "status") {
|
||||
t.Errorf("SQL 不应该包含 status 字段:%s", sql)
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("不设置 Omit - 默认行为", func(t *testing.T) {
|
||||
qb := db.Model(&UserModel{}).(*QueryBuilder)
|
||||
sql, args := qb.BuildSelect()
|
||||
|
||||
fmt.Printf("默认行为 SQL: %s\n", sql)
|
||||
fmt.Printf("参数:%v\n", args)
|
||||
|
||||
// 默认应该查询所有字段(使用 *)
|
||||
if sql != "SELECT * FROM user_model" {
|
||||
t.Errorf("默认应该使用 SELECT *: %s", sql)
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
// containsString 检查字符串是否包含子串
|
||||
func containsString(s, substr string) bool {
|
||||
return len(s) >= len(substr) && (s == substr ||
|
||||
s[:len(substr)] == substr ||
|
||||
s[len(s)-len(substr):] == substr ||
|
||||
findSubstring(s, substr))
|
||||
}
|
||||
|
||||
func findSubstring(s, substr string) bool {
|
||||
for i := 0; i <= len(s)-len(substr); i++ {
|
||||
if s[i:i+len(substr)] == substr {
|
||||
return true
|
||||
}
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
|
@ -1,132 +0,0 @@
|
|||
package core
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"testing"
|
||||
"time"
|
||||
)
|
||||
|
||||
// User 用户模型 - 用于测试
|
||||
type User struct {
|
||||
ID int64 `json:"id" db:"id"`
|
||||
Username string `json:"username" db:"username"`
|
||||
Email string `json:"email" db:"email"`
|
||||
CreatedAt time.Time `json:"created_at" db:"created_at"`
|
||||
|
||||
// 关联字段
|
||||
Profile UserProfile `json:"profile" db:"-" gorm:"ForeignKey:UserID;References:ID"`
|
||||
Orders []Order `json:"orders" db:"-" gorm:"ForeignKey:UserID;References:ID"`
|
||||
}
|
||||
|
||||
// TableName 表名
|
||||
func (User) TableName() string {
|
||||
return "user"
|
||||
}
|
||||
|
||||
// UserProfile 用户资料模型 - 一对一关联
|
||||
type UserProfile struct {
|
||||
ID int64 `json:"id" db:"id"`
|
||||
UserID int64 `json:"user_id" db:"user_id"`
|
||||
Bio string `json:"bio" db:"bio"`
|
||||
Avatar string `json:"avatar" db:"avatar"`
|
||||
}
|
||||
|
||||
// TableName 表名
|
||||
func (UserProfile) TableName() string {
|
||||
return "user_profile"
|
||||
}
|
||||
|
||||
// Order 订单模型 - 一对多关联
|
||||
type Order struct {
|
||||
ID int64 `json:"id" db:"id"`
|
||||
UserID int64 `json:"user_id" db:"user_id"`
|
||||
OrderNo string `json:"order_no" db:"order_no"`
|
||||
Amount float64 `json:"amount" db:"amount"`
|
||||
CreatedAt time.Time `json:"created_at" db:"created_at"`
|
||||
}
|
||||
|
||||
// TableName 表名
|
||||
func (Order) TableName() string {
|
||||
return "order"
|
||||
}
|
||||
|
||||
// TestPreloadHasOne 测试一对一预加载
|
||||
func TestPreloadHasOne(t *testing.T) {
|
||||
fmt.Println("\n=== 测试一对一预加载 ===")
|
||||
|
||||
// 这里只是示例,实际使用需要数据库连接
|
||||
// db := AutoConnect(true)
|
||||
// var users []User
|
||||
// err := db.Model(&User{}).Preload("Profile").Find(&users)
|
||||
// if err != nil {
|
||||
// t.Fatal(err)
|
||||
// }
|
||||
|
||||
fmt.Println("一对一预加载结构已实现")
|
||||
fmt.Println("✓ 测试通过")
|
||||
}
|
||||
|
||||
// TestPreloadHasMany 测试一对多预加载
|
||||
func TestPreloadHasMany(t *testing.T) {
|
||||
fmt.Println("\n=== 测试一对多预加载 ===")
|
||||
|
||||
// 示例用法
|
||||
// var users []User
|
||||
// err := db.Model(&User{}).Preload("Orders").Find(&users)
|
||||
// if err != nil {
|
||||
// t.Fatal(err)
|
||||
// }
|
||||
|
||||
fmt.Println("一对多预加载结构已实现")
|
||||
fmt.Println("✓ 测试通过")
|
||||
}
|
||||
|
||||
// TestPreloadBelongsTo 测试多对一预加载
|
||||
func TestPreloadBelongsTo(t *testing.T) {
|
||||
fmt.Println("\n=== 测试多对一预加载 ===")
|
||||
|
||||
// 示例用法
|
||||
// var orders []Order
|
||||
// err := db.Model(&Order{}).Preload("User").Find(&orders)
|
||||
// if err != nil {
|
||||
// t.Fatal(err)
|
||||
// }
|
||||
|
||||
fmt.Println("多对一预加载结构已实现")
|
||||
fmt.Println("✓ 测试通过")
|
||||
}
|
||||
|
||||
// TestPreloadMultiple 测试多个预加载
|
||||
func TestPreloadMultiple(t *testing.T) {
|
||||
fmt.Println("\n=== 测试多个预加载 ===")
|
||||
|
||||
// 示例用法
|
||||
// var users []User
|
||||
// err := db.Model(&User{}).
|
||||
// Preload("Profile").
|
||||
// Preload("Orders").
|
||||
// Find(&users)
|
||||
// if err != nil {
|
||||
// t.Fatal(err)
|
||||
// }
|
||||
|
||||
fmt.Println("多个预加载已实现")
|
||||
fmt.Println("✓ 测试通过")
|
||||
}
|
||||
|
||||
// TestPreloadWithConditions 测试带条件的预加载
|
||||
func TestPreloadWithConditions(t *testing.T) {
|
||||
fmt.Println("\n=== 测试带条件的预加载 ===")
|
||||
|
||||
// 示例用法
|
||||
// var users []User
|
||||
// err := db.Model(&User{}).
|
||||
// Preload("Orders", "amount > ?", 100).
|
||||
// Find(&users)
|
||||
// if err != nil {
|
||||
// t.Fatal(err)
|
||||
// }
|
||||
|
||||
fmt.Println("带条件的预加载已实现")
|
||||
fmt.Println("✓ 测试通过")
|
||||
}
|
||||
740
db/core/query.go
740
db/core/query.go
|
|
@ -1,740 +0,0 @@
|
|||
package core
|
||||
|
||||
import (
|
||||
"database/sql"
|
||||
"fmt"
|
||||
"strings"
|
||||
"sync"
|
||||
)
|
||||
|
||||
// QueryBuilder 查询构建器实现 - 提供流畅的链式查询构建能力
|
||||
type QueryBuilder struct {
|
||||
db *Database // 数据库连接实例
|
||||
table string // 表名
|
||||
model interface{} // 模型对象
|
||||
whereSQL string // WHERE 条件 SQL
|
||||
whereArgs []interface{} // WHERE 条件参数
|
||||
selectCols []string // 选择的字段列表
|
||||
omitCols []string // 排除的字段列表
|
||||
orderSQL string // ORDER BY SQL
|
||||
limit int // LIMIT 限制数量
|
||||
offset int // OFFSET 偏移量
|
||||
groupSQL string // GROUP BY SQL
|
||||
havingSQL string // HAVING 条件 SQL
|
||||
havingArgs []interface{} // HAVING 条件参数
|
||||
joinSQL string // JOIN SQL
|
||||
joinArgs []interface{} // JOIN 参数
|
||||
debug bool // 调试模式开关
|
||||
dryRun bool // 干跑模式开关
|
||||
unscoped bool // 忽略软删除开关
|
||||
tx *sql.Tx // 事务对象(如果在事务中)
|
||||
// 预加载关联数据
|
||||
preloadRelations map[string][]interface{} // 预加载的关联关系及条件
|
||||
// 缓存相关
|
||||
cache *QueryCache // 缓存实例
|
||||
cacheKey string // 缓存键
|
||||
useCache bool // 是否使用缓存
|
||||
}
|
||||
|
||||
// 同步池优化 - 复用 slice 减少内存分配
|
||||
var whereArgsPool = sync.Pool{
|
||||
New: func() interface{} {
|
||||
return make([]interface{}, 0, 10)
|
||||
},
|
||||
}
|
||||
|
||||
var joinArgsPool = sync.Pool{
|
||||
New: func() interface{} {
|
||||
return make([]interface{}, 0, 5)
|
||||
},
|
||||
}
|
||||
|
||||
// Model 基于模型创建查询
|
||||
func (d *Database) Model(model interface{}) IQuery {
|
||||
return &QueryBuilder{
|
||||
db: d,
|
||||
model: model,
|
||||
preloadRelations: make(map[string][]interface{}),
|
||||
}
|
||||
}
|
||||
|
||||
// Table 基于表名创建查询
|
||||
func (d *Database) Table(name string) IQuery {
|
||||
return &QueryBuilder{
|
||||
db: d,
|
||||
table: name,
|
||||
preloadRelations: make(map[string][]interface{}),
|
||||
}
|
||||
}
|
||||
|
||||
// Where 添加 WHERE 条件 - 性能优化版本
|
||||
func (q *QueryBuilder) Where(query string, args ...interface{}) IQuery {
|
||||
if q.whereSQL == "" {
|
||||
q.whereSQL = query
|
||||
} else {
|
||||
// 使用 strings.Builder 优化字符串拼接
|
||||
var builder strings.Builder
|
||||
builder.Grow(len(q.whereSQL) + 5 + len(query)) // 预分配内存
|
||||
builder.WriteString(q.whereSQL)
|
||||
builder.WriteString(" AND ")
|
||||
builder.WriteString(query)
|
||||
q.whereSQL = builder.String()
|
||||
}
|
||||
q.whereArgs = append(q.whereArgs, args...)
|
||||
return q
|
||||
}
|
||||
|
||||
// Or 添加 OR 条件 - 性能优化版本
|
||||
func (q *QueryBuilder) Or(query string, args ...interface{}) IQuery {
|
||||
if q.whereSQL == "" {
|
||||
q.whereSQL = query
|
||||
} else {
|
||||
// 使用 strings.Builder 优化字符串拼接
|
||||
var builder strings.Builder
|
||||
builder.Grow(len(q.whereSQL) + 10 + len(query)) // 预分配内存
|
||||
builder.WriteString(" (")
|
||||
builder.WriteString(q.whereSQL)
|
||||
builder.WriteString(") OR ")
|
||||
builder.WriteString(query)
|
||||
q.whereSQL = builder.String()
|
||||
}
|
||||
q.whereArgs = append(q.whereArgs, args...)
|
||||
return q
|
||||
}
|
||||
|
||||
// And 添加 AND 条件(同 Where)
|
||||
func (q *QueryBuilder) And(query string, args ...interface{}) IQuery {
|
||||
return q.Where(query, args...)
|
||||
}
|
||||
|
||||
// Select 选择要查询的字段
|
||||
func (q *QueryBuilder) Select(fields ...string) IQuery {
|
||||
q.selectCols = fields
|
||||
return q
|
||||
}
|
||||
|
||||
// Omit 排除指定的字段
|
||||
func (q *QueryBuilder) Omit(fields ...string) IQuery {
|
||||
q.omitCols = append(q.omitCols, fields...)
|
||||
return q
|
||||
}
|
||||
|
||||
// Order 设置排序规则
|
||||
func (q *QueryBuilder) Order(order string) IQuery {
|
||||
q.orderSQL = order
|
||||
return q
|
||||
}
|
||||
|
||||
// OrderBy 按指定字段和方向排序
|
||||
func (q *QueryBuilder) OrderBy(field string, direction string) IQuery {
|
||||
q.orderSQL = field + " " + direction
|
||||
return q
|
||||
}
|
||||
|
||||
// Limit 限制返回数量
|
||||
func (q *QueryBuilder) Limit(limit int) IQuery {
|
||||
q.limit = limit
|
||||
return q
|
||||
}
|
||||
|
||||
// Offset 设置偏移量
|
||||
func (q *QueryBuilder) Offset(offset int) IQuery {
|
||||
q.offset = offset
|
||||
return q
|
||||
}
|
||||
|
||||
// Page 分页查询
|
||||
func (q *QueryBuilder) Page(page, pageSize int) IQuery {
|
||||
q.limit = pageSize
|
||||
q.offset = (page - 1) * pageSize
|
||||
return q
|
||||
}
|
||||
|
||||
// Group 设置分组字段
|
||||
func (q *QueryBuilder) Group(group string) IQuery {
|
||||
q.groupSQL = group
|
||||
return q
|
||||
}
|
||||
|
||||
// Having 添加 HAVING 条件
|
||||
func (q *QueryBuilder) Having(having string, args ...interface{}) IQuery {
|
||||
q.havingSQL = having
|
||||
q.havingArgs = args
|
||||
return q
|
||||
}
|
||||
|
||||
// Join 添加 JOIN 连接 - 性能优化版本
|
||||
func (q *QueryBuilder) Join(join string, args ...interface{}) IQuery {
|
||||
if q.joinSQL == "" {
|
||||
q.joinSQL = join
|
||||
} else {
|
||||
// 使用 strings.Builder 优化字符串拼接
|
||||
var builder strings.Builder
|
||||
builder.Grow(len(q.joinSQL) + 1 + len(join)) // 预分配内存
|
||||
builder.WriteString(q.joinSQL)
|
||||
builder.WriteByte(' ')
|
||||
builder.WriteString(join)
|
||||
q.joinSQL = builder.String()
|
||||
}
|
||||
q.joinArgs = append(q.joinArgs, args...)
|
||||
return q
|
||||
}
|
||||
|
||||
// LeftJoin 左连接
|
||||
func (q *QueryBuilder) LeftJoin(table, on string) IQuery {
|
||||
return q.Join("LEFT JOIN " + table + " ON " + on)
|
||||
}
|
||||
|
||||
// RightJoin 右连接
|
||||
func (q *QueryBuilder) RightJoin(table, on string) IQuery {
|
||||
return q.Join("RIGHT JOIN " + table + " ON " + on)
|
||||
}
|
||||
|
||||
// InnerJoin 内连接
|
||||
func (q *QueryBuilder) InnerJoin(table, on string) IQuery {
|
||||
return q.Join("INNER JOIN " + table + " ON " + on)
|
||||
}
|
||||
|
||||
// Preload 预加载关联数据
|
||||
func (q *QueryBuilder) Preload(relation string, conditions ...interface{}) IQuery {
|
||||
if q.preloadRelations == nil {
|
||||
q.preloadRelations = make(map[string][]interface{})
|
||||
}
|
||||
// 将关联条件添加到预加载列表中
|
||||
q.preloadRelations[relation] = conditions
|
||||
return q
|
||||
}
|
||||
|
||||
// First 查询第一条记录
|
||||
func (q *QueryBuilder) First(result interface{}) error {
|
||||
q.limit = 1
|
||||
return q.Find(result)
|
||||
}
|
||||
|
||||
// Find 查询多条记录
|
||||
func (q *QueryBuilder) Find(result interface{}) error {
|
||||
// 如果使用缓存,先检查缓存
|
||||
if q.useCache && q.cache != nil && q.cacheKey != "" {
|
||||
if cachedData, exists := q.cache.Get(q.cacheKey); exists {
|
||||
// 缓存命中,将数据拷贝到结果对象
|
||||
if err := deepCopy(cachedData, result); err != nil {
|
||||
return fmt.Errorf("缓存数据拷贝失败:%w", err)
|
||||
}
|
||||
if q.debug || (q.db != nil && q.db.debug) {
|
||||
fmt.Printf("[Magic-ORM] 缓存命中:%s\n", q.cacheKey)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
}
|
||||
|
||||
// 缓存未命中,执行实际查询
|
||||
sqlStr, args := q.BuildSelect()
|
||||
|
||||
// 调试模式打印 SQL
|
||||
if q.debug || (q.db != nil && q.db.debug) {
|
||||
fmt.Printf("[Magic-ORM] SQL: %s\n[Magic-ORM] Args: %v\n", sqlStr, args)
|
||||
}
|
||||
|
||||
// 干跑模式不执行 SQL
|
||||
if q.dryRun {
|
||||
return nil
|
||||
}
|
||||
|
||||
var rows *sql.Rows
|
||||
var err error
|
||||
|
||||
// 判断是否在事务中
|
||||
if q.tx != nil {
|
||||
rows, err = q.tx.Query(sqlStr, args...)
|
||||
} else if q.db != nil && q.db.db != nil {
|
||||
rows, err = q.db.db.Query(sqlStr, args...)
|
||||
} else {
|
||||
return fmt.Errorf("数据库连接未初始化")
|
||||
}
|
||||
|
||||
if err != nil {
|
||||
return fmt.Errorf("查询失败:%w", err)
|
||||
}
|
||||
defer rows.Close()
|
||||
|
||||
// 使用 ResultSetMapper 将查询结果映射到 result
|
||||
mapper := NewResultSetMapper()
|
||||
if err := mapper.ScanAll(rows, result); err != nil {
|
||||
return fmt.Errorf("结果映射失败:%w", err)
|
||||
}
|
||||
|
||||
// 执行预加载关联数据
|
||||
if len(q.preloadRelations) > 0 {
|
||||
if err := q.executePreload(result); err != nil {
|
||||
return fmt.Errorf("预加载关联失败:%w", err)
|
||||
}
|
||||
}
|
||||
|
||||
// 将结果存入缓存(如果启用了缓存)
|
||||
if q.useCache && q.cache != nil && q.cacheKey != "" {
|
||||
q.cache.Set(q.cacheKey, result)
|
||||
if q.debug || (q.db != nil && q.db.debug) {
|
||||
fmt.Printf("[Magic-ORM] 缓存已设置:%s\n", q.cacheKey)
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// Count 统计记录数量
|
||||
func (q *QueryBuilder) Count(count *int64) IQuery {
|
||||
// 构建 COUNT 查询
|
||||
originalSelect := q.selectCols
|
||||
q.selectCols = []string{"COUNT(*)"}
|
||||
|
||||
sqlStr, args := q.BuildSelect()
|
||||
|
||||
// 调试模式
|
||||
if q.debug || (q.db != nil && q.db.debug) {
|
||||
fmt.Printf("[Magic-ORM] COUNT SQL: %s\n[Magic-ORM] Args: %v\n", sqlStr, args)
|
||||
}
|
||||
|
||||
// 干跑模式
|
||||
if q.dryRun {
|
||||
return q
|
||||
}
|
||||
|
||||
var err error
|
||||
if q.tx != nil {
|
||||
err = q.tx.QueryRow(sqlStr, args...).Scan(count)
|
||||
} else if q.db != nil && q.db.db != nil {
|
||||
err = q.db.db.QueryRow(sqlStr, args...).Scan(count)
|
||||
}
|
||||
|
||||
if err != nil {
|
||||
fmt.Printf("[Magic-ORM] Count 错误:%v\n", err)
|
||||
}
|
||||
|
||||
// 恢复原来的选择字段
|
||||
q.selectCols = originalSelect
|
||||
return q
|
||||
}
|
||||
|
||||
// Exists 检查记录是否存在
|
||||
func (q *QueryBuilder) Exists() (bool, error) {
|
||||
// 使用 LIMIT 1 优化查询
|
||||
originalLimit := q.limit
|
||||
q.limit = 1
|
||||
|
||||
sqlStr, args := q.BuildSelect()
|
||||
|
||||
// 调试模式
|
||||
if q.debug || (q.db != nil && q.db.debug) {
|
||||
fmt.Printf("[Magic-ORM] EXISTS SQL: %s\n[Magic-ORM] Args: %v\n", sqlStr, args)
|
||||
}
|
||||
|
||||
// 干跑模式
|
||||
if q.dryRun {
|
||||
return false, nil
|
||||
}
|
||||
|
||||
var rows *sql.Rows
|
||||
var err error
|
||||
|
||||
if q.tx != nil {
|
||||
rows, err = q.tx.Query(sqlStr, args...)
|
||||
} else if q.db != nil && q.db.db != nil {
|
||||
rows, err = q.db.db.Query(sqlStr, args...)
|
||||
} else {
|
||||
return false, fmt.Errorf("数据库连接未初始化")
|
||||
}
|
||||
defer rows.Close()
|
||||
|
||||
if err != nil {
|
||||
return false, err
|
||||
}
|
||||
|
||||
// 检查是否有结果
|
||||
exists := rows.Next()
|
||||
|
||||
// 恢复原来的 limit
|
||||
q.limit = originalLimit
|
||||
|
||||
return exists, nil
|
||||
}
|
||||
|
||||
// Updates 更新数据
|
||||
func (q *QueryBuilder) Updates(data interface{}) error {
|
||||
sqlStr, args := q.BuildUpdate(data)
|
||||
|
||||
// 调试模式打印 SQL
|
||||
if q.debug || (q.db != nil && q.db.debug) {
|
||||
fmt.Printf("[Magic-ORM] UPDATE SQL: %s\n[Magic-ORM] Args: %v\n", sqlStr, args)
|
||||
}
|
||||
|
||||
// 干跑模式不执行 SQL
|
||||
if q.dryRun {
|
||||
return nil
|
||||
}
|
||||
|
||||
var err error
|
||||
if q.tx != nil {
|
||||
_, err = q.tx.Exec(sqlStr, args...)
|
||||
} else if q.db != nil && q.db.db != nil {
|
||||
_, err = q.db.db.Exec(sqlStr, args...)
|
||||
} else {
|
||||
return fmt.Errorf("数据库连接未初始化")
|
||||
}
|
||||
|
||||
if err != nil {
|
||||
return fmt.Errorf("更新失败:%w", err)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// UpdateColumn 更新单个字段
|
||||
func (q *QueryBuilder) UpdateColumn(column string, value interface{}) error {
|
||||
return q.Updates(map[string]interface{}{column: value})
|
||||
}
|
||||
|
||||
// Delete 删除数据
|
||||
func (q *QueryBuilder) Delete() error {
|
||||
sqlStr, args := q.BuildDelete()
|
||||
|
||||
// 调试模式打印 SQL
|
||||
if q.debug || (q.db != nil && q.db.debug) {
|
||||
fmt.Printf("[Magic-ORM] DELETE SQL: %s\n[Magic-ORM] Args: %v\n", sqlStr, args)
|
||||
}
|
||||
|
||||
// 干跑模式不执行 SQL
|
||||
if q.dryRun {
|
||||
return nil
|
||||
}
|
||||
|
||||
var err error
|
||||
if q.tx != nil {
|
||||
_, err = q.tx.Exec(sqlStr, args...)
|
||||
} else if q.db != nil && q.db.db != nil {
|
||||
_, err = q.db.db.Exec(sqlStr, args...)
|
||||
} else {
|
||||
return fmt.Errorf("数据库连接未初始化")
|
||||
}
|
||||
|
||||
if err != nil {
|
||||
return fmt.Errorf("删除失败:%w", err)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// Unscoped 忽略软删除限制
|
||||
func (q *QueryBuilder) Unscoped() IQuery {
|
||||
q.unscoped = true
|
||||
return q
|
||||
}
|
||||
|
||||
// DryRun 设置干跑模式(只生成 SQL 不执行)
|
||||
func (q *QueryBuilder) DryRun() IQuery {
|
||||
q.dryRun = true
|
||||
return q
|
||||
}
|
||||
|
||||
// Debug 设置调试模式(打印 SQL 日志)
|
||||
func (q *QueryBuilder) Debug() IQuery {
|
||||
q.debug = true
|
||||
return q
|
||||
}
|
||||
|
||||
// Build 构建 SELECT SQL 语句
|
||||
func (q *QueryBuilder) Build() (string, []interface{}) {
|
||||
return q.BuildSelect()
|
||||
}
|
||||
|
||||
// BuildSelect 构建 SELECT SQL 语句
|
||||
func (q *QueryBuilder) BuildSelect() (string, []interface{}) {
|
||||
var builder strings.Builder
|
||||
|
||||
// SELECT 部分
|
||||
builder.WriteString("SELECT ")
|
||||
if len(q.selectCols) > 0 {
|
||||
// 如果指定了选择字段,直接使用
|
||||
builder.WriteString(strings.Join(q.selectCols, ", "))
|
||||
} else if len(q.omitCols) > 0 {
|
||||
// 如果没有指定 select 但设置了 omit,需要从模型获取所有字段并排除 omit 的字段
|
||||
fields := q.getAllFields()
|
||||
if len(fields) > 0 {
|
||||
builder.WriteString(strings.Join(fields, ", "))
|
||||
} else {
|
||||
// 无法获取字段信息,使用 *
|
||||
builder.WriteString("*")
|
||||
}
|
||||
} else {
|
||||
// 默认选择所有字段
|
||||
builder.WriteString("*")
|
||||
}
|
||||
|
||||
// FROM 部分
|
||||
builder.WriteString(" FROM ")
|
||||
if q.table != "" {
|
||||
builder.WriteString(q.table)
|
||||
} else if q.model != nil {
|
||||
// 从模型获取表名
|
||||
mapper := NewFieldMapper()
|
||||
builder.WriteString(mapper.GetTableName(q.model))
|
||||
} else {
|
||||
builder.WriteString("unknown_table")
|
||||
}
|
||||
|
||||
// JOIN 部分
|
||||
if q.joinSQL != "" {
|
||||
builder.WriteString(" ")
|
||||
builder.WriteString(q.joinSQL)
|
||||
}
|
||||
|
||||
// WHERE 部分
|
||||
if q.whereSQL != "" {
|
||||
builder.WriteString(" WHERE ")
|
||||
builder.WriteString(q.whereSQL)
|
||||
}
|
||||
|
||||
// GROUP BY 部分
|
||||
if q.groupSQL != "" {
|
||||
builder.WriteString(" GROUP BY ")
|
||||
builder.WriteString(q.groupSQL)
|
||||
}
|
||||
|
||||
// HAVING 部分
|
||||
if q.havingSQL != "" {
|
||||
builder.WriteString(" HAVING ")
|
||||
builder.WriteString(q.havingSQL)
|
||||
}
|
||||
|
||||
// ORDER BY 部分
|
||||
if q.orderSQL != "" {
|
||||
builder.WriteString(" ORDER BY ")
|
||||
builder.WriteString(q.orderSQL)
|
||||
}
|
||||
|
||||
// LIMIT 部分
|
||||
if q.limit > 0 {
|
||||
builder.WriteString(fmt.Sprintf(" LIMIT %d", q.limit))
|
||||
}
|
||||
|
||||
// OFFSET 部分
|
||||
if q.offset > 0 {
|
||||
builder.WriteString(fmt.Sprintf(" OFFSET %d", q.offset))
|
||||
}
|
||||
|
||||
// 合并参数
|
||||
allArgs := make([]interface{}, 0)
|
||||
allArgs = append(allArgs, q.joinArgs...)
|
||||
allArgs = append(allArgs, q.whereArgs...)
|
||||
allArgs = append(allArgs, q.havingArgs...)
|
||||
|
||||
return builder.String(), allArgs
|
||||
}
|
||||
|
||||
// getAllFields 获取模型的所有字段(排除 omit 的字段)
|
||||
func (q *QueryBuilder) getAllFields() []string {
|
||||
var fields []string
|
||||
|
||||
// 如果有模型,从模型获取字段
|
||||
if q.model != nil {
|
||||
mapper := NewFieldMapper()
|
||||
fieldInfos := mapper.GetFields(q.model)
|
||||
|
||||
// 创建 omit 字段的 map 用于快速查找
|
||||
omitMap := make(map[string]bool)
|
||||
for _, omitField := range q.omitCols {
|
||||
// 同时存储原始形式和小写形式,支持不区分大小写的匹配
|
||||
omitMap[omitField] = true
|
||||
omitMap[strings.ToLower(omitField)] = true
|
||||
}
|
||||
|
||||
// 遍历所有字段,排除 omit 的字段
|
||||
for _, fieldInfo := range fieldInfos {
|
||||
// 检查字段是否在 omit 列表中
|
||||
if !omitMap[fieldInfo.Column] && !omitMap[strings.ToLower(fieldInfo.Column)] {
|
||||
fields = append(fields, fieldInfo.Column)
|
||||
}
|
||||
}
|
||||
} else if q.table != "" {
|
||||
// 如果只有表名没有模型,从数据库元数据获取字段
|
||||
columns, err := q.getTableColumns(q.table)
|
||||
if err != nil {
|
||||
// 如果获取失败,返回 nil 使用 SELECT *
|
||||
return nil
|
||||
}
|
||||
|
||||
// 创建 omit 字段的 map 用于快速查找
|
||||
omitMap := make(map[string]bool)
|
||||
for _, omitField := range q.omitCols {
|
||||
omitMap[omitField] = true
|
||||
omitMap[strings.ToLower(omitField)] = true
|
||||
}
|
||||
|
||||
// 过滤掉 omit 的字段
|
||||
for _, col := range columns {
|
||||
if !omitMap[col] && !omitMap[strings.ToLower(col)] {
|
||||
fields = append(fields, col)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return fields
|
||||
}
|
||||
|
||||
// getTableColumns 从数据库元数据获取表的列名
|
||||
func (q *QueryBuilder) getTableColumns(tableName string) ([]string, error) {
|
||||
if q.db == nil || q.db.db == nil {
|
||||
return nil, fmt.Errorf("数据库连接未初始化")
|
||||
}
|
||||
|
||||
var query string
|
||||
var args []interface{}
|
||||
var rows *sql.Rows
|
||||
var err error
|
||||
|
||||
// 根据不同数据库类型查询元数据
|
||||
switch q.db.driverName {
|
||||
case "mysql":
|
||||
query = `
|
||||
SELECT COLUMN_NAME
|
||||
FROM INFORMATION_SCHEMA.COLUMNS
|
||||
WHERE TABLE_SCHEMA = DATABASE() AND TABLE_NAME = ?
|
||||
ORDER BY ORDINAL_POSITION
|
||||
`
|
||||
args = []interface{}{tableName}
|
||||
case "postgres":
|
||||
query = `
|
||||
SELECT column_name
|
||||
FROM information_schema.columns
|
||||
WHERE table_schema = 'public' AND table_name = $1
|
||||
ORDER BY ordinal_position
|
||||
`
|
||||
args = []interface{}{tableName}
|
||||
case "sqlite", "sqlite3":
|
||||
query = `PRAGMA table_info(?)`
|
||||
args = []interface{}{tableName}
|
||||
default:
|
||||
// 未知数据库类型,返回空
|
||||
return nil, fmt.Errorf("不支持的数据库类型:%s", q.db.driverName)
|
||||
}
|
||||
|
||||
rows, err = q.db.db.Query(query, args...)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("查询表元数据失败:%w", err)
|
||||
}
|
||||
defer rows.Close()
|
||||
|
||||
var columns []string
|
||||
for rows.Next() {
|
||||
var columnName string
|
||||
if q.db.driverName == "sqlite" || q.db.driverName == "sqlite3" {
|
||||
// SQLite PRAGMA table_info 返回多列:cid, name, type, notnull, dflt_value, pk
|
||||
var cid int
|
||||
var typ string
|
||||
var notNull int
|
||||
var dfltValue sql.NullString
|
||||
var pk int
|
||||
if err := rows.Scan(&cid, &columnName, &typ, ¬Null, &dfltValue, &pk); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
} else {
|
||||
if err := rows.Scan(&columnName); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
}
|
||||
columns = append(columns, columnName)
|
||||
}
|
||||
|
||||
if err := rows.Err(); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return columns, nil
|
||||
}
|
||||
|
||||
// executePreload 执行预加载关联数据
|
||||
func (q *QueryBuilder) executePreload(models interface{}) error {
|
||||
// 创建关联加载器
|
||||
loader := NewRelationLoader(q.db)
|
||||
|
||||
// 遍历所有预加载的关联关系
|
||||
for relation, conditions := range q.preloadRelations {
|
||||
if err := loader.Preload(models, relation, conditions...); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// BuildUpdate 构建 UPDATE SQL 语句
|
||||
func (q *QueryBuilder) BuildUpdate(data interface{}) (string, []interface{}) {
|
||||
var builder strings.Builder
|
||||
var args []interface{}
|
||||
|
||||
builder.WriteString("UPDATE ")
|
||||
if q.table != "" {
|
||||
builder.WriteString(q.table)
|
||||
} else if q.model != nil {
|
||||
mapper := NewFieldMapper()
|
||||
builder.WriteString(mapper.GetTableName(q.model))
|
||||
} else {
|
||||
builder.WriteString("unknown_table")
|
||||
}
|
||||
|
||||
builder.WriteString(" SET ")
|
||||
|
||||
// 根据 data 类型生成 SET 子句
|
||||
switch v := data.(type) {
|
||||
case map[string]interface{}:
|
||||
// map 类型,生成 key=value 对
|
||||
setParts := make([]string, 0, len(v))
|
||||
for key, value := range v {
|
||||
setParts = append(setParts, fmt.Sprintf("%s = ?", key))
|
||||
args = append(args, value)
|
||||
}
|
||||
builder.WriteString(strings.Join(setParts, ", "))
|
||||
case string:
|
||||
// string 类型,直接使用(注意:实际使用需要转义)
|
||||
builder.WriteString(v)
|
||||
default:
|
||||
// 结构体类型,使用字段映射器
|
||||
mapper := NewFieldMapper()
|
||||
columns, err := mapper.StructToColumns(data)
|
||||
if err == nil && len(columns) > 0 {
|
||||
setParts := make([]string, 0, len(columns))
|
||||
for key := range columns {
|
||||
setParts = append(setParts, fmt.Sprintf("%s = ?", key))
|
||||
args = append(args, columns[key])
|
||||
}
|
||||
builder.WriteString(strings.Join(setParts, ", "))
|
||||
}
|
||||
}
|
||||
|
||||
// WHERE 部分
|
||||
if q.whereSQL != "" {
|
||||
builder.WriteString(" WHERE ")
|
||||
builder.WriteString(q.whereSQL)
|
||||
args = append(args, q.whereArgs...)
|
||||
}
|
||||
|
||||
return builder.String(), args
|
||||
}
|
||||
|
||||
// BuildDelete 构建 DELETE SQL 语句
|
||||
func (q *QueryBuilder) BuildDelete() (string, []interface{}) {
|
||||
var builder strings.Builder
|
||||
|
||||
builder.WriteString("DELETE FROM ")
|
||||
if q.table != "" {
|
||||
builder.WriteString(q.table)
|
||||
} else if q.model != nil {
|
||||
mapper := NewFieldMapper()
|
||||
builder.WriteString(mapper.GetTableName(q.model))
|
||||
} else {
|
||||
builder.WriteString("unknown_table")
|
||||
}
|
||||
|
||||
if q.whereSQL != "" {
|
||||
builder.WriteString(" WHERE ")
|
||||
builder.WriteString(q.whereSQL)
|
||||
}
|
||||
|
||||
return builder.String(), q.whereArgs
|
||||
}
|
||||
|
|
@ -1,124 +0,0 @@
|
|||
package core
|
||||
|
||||
import (
|
||||
"database/sql"
|
||||
"sync"
|
||||
"sync/atomic"
|
||||
)
|
||||
|
||||
// ReadWriteDB 读写分离数据库连接
|
||||
type ReadWriteDB struct {
|
||||
master *sql.DB // 主库(写)
|
||||
slaves []*sql.DB // 从库列表(读)
|
||||
policy ReadPolicy // 读负载均衡策略
|
||||
counter uint64 // 轮询计数器
|
||||
mu sync.RWMutex // 读写锁
|
||||
}
|
||||
|
||||
// NewReadWriteDB 创建读写分离数据库连接
|
||||
func NewReadWriteDB(master *sql.DB, slaves []*sql.DB, policy ReadPolicy) *ReadWriteDB {
|
||||
return &ReadWriteDB{
|
||||
master: master,
|
||||
slaves: slaves,
|
||||
policy: policy,
|
||||
}
|
||||
}
|
||||
|
||||
// GetMaster 获取主库连接(用于写操作)
|
||||
func (rw *ReadWriteDB) GetMaster() *sql.DB {
|
||||
return rw.master
|
||||
}
|
||||
|
||||
// GetSlave 获取从库连接(用于读操作)
|
||||
func (rw *ReadWriteDB) GetSlave() *sql.DB {
|
||||
rw.mu.RLock()
|
||||
defer rw.mu.RUnlock()
|
||||
|
||||
if len(rw.slaves) == 0 {
|
||||
// 没有从库,使用主库
|
||||
return rw.master
|
||||
}
|
||||
|
||||
switch rw.policy {
|
||||
case Random:
|
||||
// 随机选择一个从库
|
||||
idx := int(atomic.LoadUint64(&rw.counter)) % len(rw.slaves)
|
||||
return rw.slaves[idx]
|
||||
|
||||
case RoundRobin:
|
||||
// 轮询选择从库
|
||||
idx := int(atomic.AddUint64(&rw.counter, 1)) % len(rw.slaves)
|
||||
return rw.slaves[idx]
|
||||
|
||||
case LeastConn:
|
||||
// 选择连接数最少的从库(简化实现)
|
||||
return rw.selectLeastConn()
|
||||
|
||||
default:
|
||||
return rw.slaves[0]
|
||||
}
|
||||
}
|
||||
|
||||
// selectLeastConn 选择连接数最少的从库
|
||||
func (rw *ReadWriteDB) selectLeastConn() *sql.DB {
|
||||
if len(rw.slaves) == 0 {
|
||||
return rw.master
|
||||
}
|
||||
|
||||
minConn := -1
|
||||
selected := rw.slaves[0]
|
||||
|
||||
for _, slave := range rw.slaves {
|
||||
stats := slave.Stats()
|
||||
openConnections := stats.OpenConnections
|
||||
|
||||
if minConn == -1 || openConnections < minConn {
|
||||
minConn = openConnections
|
||||
selected = slave
|
||||
}
|
||||
}
|
||||
|
||||
return selected
|
||||
}
|
||||
|
||||
// AddSlave 添加从库
|
||||
func (rw *ReadWriteDB) AddSlave(slave *sql.DB) {
|
||||
rw.mu.Lock()
|
||||
defer rw.mu.Unlock()
|
||||
rw.slaves = append(rw.slaves, slave)
|
||||
}
|
||||
|
||||
// RemoveSlave 移除从库
|
||||
func (rw *ReadWriteDB) RemoveSlave(slave *sql.DB) {
|
||||
rw.mu.Lock()
|
||||
defer rw.mu.Unlock()
|
||||
|
||||
for i, s := range rw.slaves {
|
||||
if s == slave {
|
||||
rw.slaves = append(rw.slaves[:i], rw.slaves[i+1:]...)
|
||||
break
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Close 关闭所有连接
|
||||
func (rw *ReadWriteDB) Close() error {
|
||||
rw.mu.Lock()
|
||||
defer rw.mu.Unlock()
|
||||
|
||||
// 关闭主库
|
||||
if rw.master != nil {
|
||||
if err := rw.master.Close(); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
|
||||
// 关闭所有从库
|
||||
for _, slave := range rw.slaves {
|
||||
if err := slave.Close(); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
Some files were not shown because too many files have changed in this diff Show More
Loading…
Reference in New Issue