gin-base/db/core/query.go

549 lines
13 KiB
Go
Raw Blame History

This file contains ambiguous Unicode characters!

This file contains ambiguous Unicode characters that may be confused with others in your current locale. If your use case is intentional and legitimate, you can safely ignore this warning. Use the Escape button to highlight these characters.

package core
import (
"database/sql"
"fmt"
"strings"
"sync"
)
// QueryBuilder is the query builder implementation; it offers a fluent,
// chainable API for composing SELECT, UPDATE and DELETE statements.
// Builder methods mutate the receiver in place and return it, so chained
// calls all modify the same QueryBuilder value.
type QueryBuilder struct {
	db         *Database     // database wrapper: supplies the connection (db.db) when tx is nil and a global debug flag (db.debug)
	table      string        // explicit table name; takes precedence over model when building SQL
	model      interface{}   // model object; the table name is derived from it via FieldMapper when table is empty
	whereSQL   string        // accumulated WHERE condition SQL
	whereArgs  []interface{} // bind arguments for the placeholders in whereSQL
	selectCols []string      // columns for the SELECT list; empty means "*"
	orderSQL   string        // ORDER BY clause body, interpolated verbatim
	limit      int           // LIMIT value; only emitted when > 0
	offset     int           // OFFSET value; only emitted when > 0
	groupSQL   string        // GROUP BY clause body
	havingSQL  string        // HAVING condition SQL
	havingArgs []interface{} // bind arguments for the placeholders in havingSQL
	joinSQL    string        // accumulated JOIN clauses, separated by single spaces
	joinArgs   []interface{} // bind arguments for the placeholders in joinSQL
	debug      bool          // per-builder switch: print generated SQL and args before running
	dryRun     bool          // dry-run switch: build SQL but never execute it
	unscoped   bool          // set by Unscoped(); NOTE(review): not read by any builder visible in this file
	tx         *sql.Tx       // transaction to run statements in; when non-nil it is used instead of db
}
// Pools of reusable argument slices, intended to reduce allocations when
// collecting WHERE and JOIN bind arguments. Freshly created slices are empty
// with a small pre-reserved capacity.
//
// NOTE(review): no Get/Put calls on these pools are visible in this file.
// If they are used, note that putting a bare slice (not a pointer) into a
// sync.Pool allocates on every Put (staticcheck SA6002); pooling
// *[]interface{} avoids that.
var (
	whereArgsPool = sync.Pool{
		New: func() interface{} {
			const initialCap = 10
			return make([]interface{}, 0, initialCap)
		},
	}
	joinArgsPool = sync.Pool{
		New: func() interface{} {
			const initialCap = 5
			return make([]interface{}, 0, initialCap)
		},
	}
)
// Model starts a new query bound to this database. The table name is not
// fixed here; it is resolved from model when the SQL is built.
func (d *Database) Model(model interface{}) IQuery {
	qb := &QueryBuilder{}
	qb.db, qb.model = d, model
	return qb
}
// Table starts a new query bound to this database that targets the table
// with the given name verbatim.
func (d *Database) Table(name string) IQuery {
	var qb QueryBuilder
	qb.db = d
	qb.table = name
	return &qb
}
// Where adds a WHERE condition, joined to any existing conditions with AND.
// args bind the placeholders in query and are appended after previously
// collected WHERE arguments.
//
// AND binds tighter than OR in SQL, so naively producing "a OR b AND c" would
// mean "a OR (b AND c)". To keep each call's condition intact, an operand
// (the accumulated conditions, e.g. after Or, or the new query) that contains
// an OR keyword is wrapped in parentheses before joining. Conditions without
// OR are joined unchanged, e.g. "a = ? AND b = ?".
func (q *QueryBuilder) Where(query string, args ...interface{}) IQuery {
	if q.whereSQL == "" {
		q.whereSQL = query
	} else {
		left := parenthesizeIfHasOr(q.whereSQL)
		right := parenthesizeIfHasOr(query)
		var builder strings.Builder
		builder.Grow(len(left) + len(" AND ") + len(right))
		builder.WriteString(left)
		builder.WriteString(" AND ")
		builder.WriteString(right)
		q.whereSQL = builder.String()
	}
	q.whereArgs = append(q.whereArgs, args...)
	return q
}

// parenthesizeIfHasOr wraps cond in parentheses when it contains an OR keyword
// (case-insensitive, any whitespace around it), so it can be combined with AND
// without changing its meaning. Redundant parentheses never alter a boolean
// expression, so a false positive (e.g. "OR" inside a string literal) is harmless.
func parenthesizeIfHasOr(cond string) string {
	normalized := " " + strings.ToUpper(strings.Join(strings.Fields(cond), " ")) + " "
	if strings.Contains(normalized, " OR ") {
		return "(" + cond + ")"
	}
	return cond
}
// Or combines the conditions accumulated so far with query using OR,
// producing "(previous) OR query". If no condition exists yet, query becomes
// the whole WHERE clause. args bind the placeholders in query and are appended
// after previously collected WHERE arguments.
func (q *QueryBuilder) Or(query string, args ...interface{}) IQuery {
	if q.whereSQL == "" {
		q.whereSQL = query
	} else {
		var builder strings.Builder
		builder.Grow(len("(") + len(q.whereSQL) + len(") OR ") + len(query))
		// No leading space: Build* already writes " WHERE " before this text.
		builder.WriteByte('(')
		builder.WriteString(q.whereSQL)
		builder.WriteString(") OR ")
		builder.WriteString(query)
		q.whereSQL = builder.String()
	}
	q.whereArgs = append(q.whereArgs, args...)
	return q
}
// And is an alias of Where: it joins query to the existing conditions with
// AND and appends args to the WHERE arguments.
func (q *QueryBuilder) And(query string, args ...interface{}) IQuery {
	result := q.Where(query, args...)
	return result
}
// Select sets the columns for the SELECT list, replacing any earlier
// selection. With no fields, BuildSelect falls back to "*".
// The variadic slice is stored as-is, not copied.
func (q *QueryBuilder) Select(fields ...string) IQuery {
	cols := fields
	q.selectCols = cols
	return q
}
// Omit is meant to exclude the given columns from the generated SELECT list.
// It is not implemented yet: fields is ignored and the builder is returned
// unchanged.
func (q *QueryBuilder) Omit(fields ...string) IQuery {
	// TODO: remember fields and drop them from the SELECT column list.
	_ = fields
	return q
}
// Order sets the raw ORDER BY clause body (for example "id DESC"), replacing
// any previous ordering. The text is interpolated into SQL verbatim, so it
// must not come from unvalidated user input.
func (q *QueryBuilder) Order(order string) IQuery {
	q.orderSQL = order // replaces, does not append
	return IQuery(q)
}
// OrderBy sets the ORDER BY clause to field followed by direction (e.g.
// "ASC"/"DESC"), replacing any previous ordering. Surrounding whitespace in
// direction is trimmed; an empty direction yields just field, so no dangling
// space is emitted and the database default order applies.
//
// Both field and direction are interpolated into SQL verbatim; callers must
// validate them before passing user-supplied values.
func (q *QueryBuilder) OrderBy(field string, direction string) IQuery {
	direction = strings.TrimSpace(direction)
	if direction == "" {
		q.orderSQL = field
		return q
	}
	q.orderSQL = field + " " + direction
	return q
}
// Limit caps the number of rows returned. BuildSelect only emits LIMIT for
// values greater than zero, so 0 or a negative value means "no limit".
func (q *QueryBuilder) Limit(limit int) IQuery {
	q.limit = limit
	return IQuery(q)
}
// Offset sets how many rows to skip. BuildSelect only emits OFFSET for values
// greater than zero.
func (q *QueryBuilder) Offset(offset int) IQuery {
	q.offset = offset
	return IQuery(q)
}
// Page configures LIMIT/OFFSET for 1-based pagination: page 1 returns the
// first pageSize rows, page 2 the next pageSize, and so on. It overwrites any
// earlier Limit/Offset.
//
// A page below 1 is treated as page 1, and a non-positive pageSize resets the
// offset to 0, so the stored offset is never negative. pageSize itself is
// stored as given; BuildSelect omits LIMIT when it is <= 0.
func (q *QueryBuilder) Page(page, pageSize int) IQuery {
	if page < 1 {
		page = 1
	}
	q.limit = pageSize
	q.offset = 0
	if pageSize > 0 {
		q.offset = (page - 1) * pageSize
	}
	return q
}
// Group sets the GROUP BY clause body (interpolated verbatim), replacing any
// previous grouping.
func (q *QueryBuilder) Group(group string) IQuery {
	q.groupSQL = group // replaces, does not append
	return IQuery(q)
}
// Having sets the HAVING condition and its bind arguments. Unlike Where, a
// second call replaces both the condition and the arguments instead of
// appending to them.
func (q *QueryBuilder) Having(having string, args ...interface{}) IQuery {
	q.havingSQL, q.havingArgs = having, args
	return q
}
// Join appends a raw JOIN clause (e.g. "LEFT JOIN t ON ..."), separated from
// earlier JOIN clauses by a single space. args bind its placeholders and are
// appended after previously collected JOIN arguments.
func (q *QueryBuilder) Join(join string, args ...interface{}) IQuery {
	switch {
	case q.joinSQL == "":
		q.joinSQL = join
	default:
		var sb strings.Builder
		sb.Grow(len(q.joinSQL) + 1 + len(join))
		sb.WriteString(q.joinSQL)
		sb.WriteByte(' ')
		sb.WriteString(join)
		q.joinSQL = sb.String()
	}
	q.joinArgs = append(q.joinArgs, args...)
	return q
}
// LeftJoin appends "LEFT JOIN <table> ON <on>" via Join, with no bind
// arguments. Both pieces are interpolated verbatim.
func (q *QueryBuilder) LeftJoin(table, on string) IQuery {
	clause := strings.Join([]string{"LEFT JOIN", table, "ON", on}, " ")
	return q.Join(clause)
}
// RightJoin appends "RIGHT JOIN <table> ON <on>" via Join, with no bind
// arguments. Both pieces are interpolated verbatim.
func (q *QueryBuilder) RightJoin(table, on string) IQuery {
	clause := "RIGHT JOIN " + table
	clause += " ON " + on
	return q.Join(clause)
}
// InnerJoin appends "INNER JOIN <table> ON <on>" via Join, with no bind
// arguments. Both pieces are interpolated verbatim.
func (q *QueryBuilder) InnerJoin(table, on string) IQuery {
	const keyword = "INNER JOIN "
	return q.Join(keyword + table + " ON " + on)
}
// Preload is meant to eager-load the named relation, optionally filtered by
// conditions. It is not implemented yet: both arguments are ignored and the
// builder is returned unchanged.
func (q *QueryBuilder) Preload(relation string, conditions ...interface{}) IQuery {
	// TODO: implement relation preloading.
	_, _ = relation, conditions
	return q
}
// First fetches a single record into result by forcing LIMIT 1 and
// delegating to Find. The limit is left at 1 on the builder afterwards.
func (q *QueryBuilder) First(result interface{}) error {
	const single = 1
	q.limit = single
	err := q.Find(result)
	return err
}
// Find executes the SELECT produced by BuildSelect. It is intended to map the
// resulting rows into result.
//
// Behavior:
//   - when debug is enabled on the builder or on the Database, the SQL and
//     args are printed to stdout;
//   - in dry-run mode nothing is executed and nil is returned;
//   - the query runs on q.tx when set, otherwise on q.db.db;
//   - an error is returned if no connection is available or the query fails.
//
// NOTE(review): row-to-result mapping is not implemented yet (see TODO), so
// result is never populated and a successful query returns nil with no data.
func (q *QueryBuilder) Find(result interface{}) error {
	query, params := q.BuildSelect()

	if q.debug || (q.db != nil && q.db.debug) {
		fmt.Printf("[Magic-ORM] SQL: %s\n[Magic-ORM] Args: %v\n", query, params)
	}
	if q.dryRun {
		return nil
	}

	var (
		rows *sql.Rows
		err  error
	)
	switch {
	case q.tx != nil:
		rows, err = q.tx.Query(query, params...)
	case q.db != nil && q.db.db != nil:
		rows, err = q.db.db.Query(query, params...)
	default:
		return fmt.Errorf("数据库连接未初始化")
	}
	if err != nil {
		return fmt.Errorf("查询失败:%w", err)
	}
	defer rows.Close()

	// TODO: map rows into result (e.g. via FieldMapper).
	return nil
}
// Count runs a COUNT(*) variant of the current query, stores the result in
// *count, and returns the builder for further chaining.
//
// While the count query is built, the SELECT list is replaced with COUNT(*)
// and ORDER BY / LIMIT / OFFSET are cleared. With LIMIT/OFFSET still applied
// (e.g. after Page), the single aggregate row would be skipped and Scan would
// fail. Some databases (e.g. PostgreSQL) also reject ORDER BY on a
// non-aggregated column next to COUNT(*). All four settings are restored on
// every return path, including dry-run.
//
// This signature cannot return an error, so failures (including a missing
// connection) are printed to stdout. In dry-run mode nothing is executed and
// *count is left untouched.
func (q *QueryBuilder) Count(count *int64) IQuery {
	savedSelect, savedOrder := q.selectCols, q.orderSQL
	savedLimit, savedOffset := q.limit, q.offset
	defer func() {
		q.selectCols, q.orderSQL = savedSelect, savedOrder
		q.limit, q.offset = savedLimit, savedOffset
	}()

	q.selectCols = []string{"COUNT(*)"}
	q.orderSQL = ""
	q.limit, q.offset = 0, 0
	sqlStr, args := q.BuildSelect()

	if q.debug || (q.db != nil && q.db.debug) {
		fmt.Printf("[Magic-ORM] COUNT SQL: %s\n[Magic-ORM] Args: %v\n", sqlStr, args)
	}
	if q.dryRun {
		return q
	}

	var err error
	switch {
	case q.tx != nil:
		err = q.tx.QueryRow(sqlStr, args...).Scan(count)
	case q.db != nil && q.db.db != nil:
		err = q.db.db.QueryRow(sqlStr, args...).Scan(count)
	default:
		err = fmt.Errorf("数据库连接未初始化")
	}
	if err != nil {
		fmt.Printf("[Magic-ORM] Count 错误:%v\n", err)
	}
	return q
}
// Exists reports whether the current query matches at least one row.
//
// The query is built with LIMIT 1. The caller's limit is restored on every
// return path, including dry-run and errors. In dry-run mode nothing is
// executed and (false, nil) is returned. The query runs on q.tx when set,
// otherwise on q.db.db. A missing connection, a query failure, or a
// row-iteration error is returned as an error.
func (q *QueryBuilder) Exists() (bool, error) {
	originalLimit := q.limit
	q.limit = 1
	defer func() { q.limit = originalLimit }()

	sqlStr, args := q.BuildSelect()

	if q.debug || (q.db != nil && q.db.debug) {
		fmt.Printf("[Magic-ORM] EXISTS SQL: %s\n[Magic-ORM] Args: %v\n", sqlStr, args)
	}
	if q.dryRun {
		return false, nil
	}

	var rows *sql.Rows
	var err error
	switch {
	case q.tx != nil:
		rows, err = q.tx.Query(sqlStr, args...)
	case q.db != nil && q.db.db != nil:
		rows, err = q.db.db.Query(sqlStr, args...)
	default:
		return false, fmt.Errorf("数据库连接未初始化")
	}
	// Check the error before deferring Close: on failure rows is nil and
	// (*sql.Rows).Close would panic.
	if err != nil {
		return false, fmt.Errorf("查询失败:%w", err)
	}
	defer rows.Close()

	exists := rows.Next()
	// Next returns false both for "no rows" and for an iteration error;
	// Err distinguishes the two.
	if err = rows.Err(); err != nil {
		return false, fmt.Errorf("查询失败:%w", err)
	}
	return exists, nil
}
// Updates executes the UPDATE produced by BuildUpdate(data).
//
// When debug is enabled on the builder or on the Database, the SQL and args
// are printed. Dry-run mode returns nil without executing. The statement runs
// on q.tx when set, otherwise on q.db.db. A missing connection or an execution
// failure is returned as an error.
//
// NOTE(review): BuildUpdate only adds WHERE when conditions exist, so
// calling Updates on an unconditioned builder modifies every row.
func (q *QueryBuilder) Updates(data interface{}) error {
	stmt, params := q.BuildUpdate(data)

	if q.debug || (q.db != nil && q.db.debug) {
		fmt.Printf("[Magic-ORM] UPDATE SQL: %s\n[Magic-ORM] Args: %v\n", stmt, params)
	}
	if q.dryRun {
		return nil
	}

	var execErr error
	switch {
	case q.tx != nil:
		_, execErr = q.tx.Exec(stmt, params...)
	case q.db != nil && q.db.db != nil:
		_, execErr = q.db.db.Exec(stmt, params...)
	default:
		return fmt.Errorf("数据库连接未初始化")
	}
	if execErr != nil {
		return fmt.Errorf("更新失败:%w", execErr)
	}
	return nil
}
// UpdateColumn updates a single column to value by delegating to Updates
// with a one-entry column map.
func (q *QueryBuilder) UpdateColumn(column string, value interface{}) error {
	single := make(map[string]interface{}, 1)
	single[column] = value
	return q.Updates(single)
}
// Delete executes the DELETE produced by BuildDelete.
//
// When debug is enabled on the builder or on the Database, the SQL and args
// are printed. Dry-run mode returns nil without executing. The statement runs
// on q.tx when set, otherwise on q.db.db. A missing connection or an execution
// failure is returned as an error.
//
// NOTE(review): BuildDelete only adds WHERE when conditions exist, so
// calling Delete on an unconditioned builder removes every row in the table.
func (q *QueryBuilder) Delete() error {
	stmt, params := q.BuildDelete()

	if q.debug || (q.db != nil && q.db.debug) {
		fmt.Printf("[Magic-ORM] DELETE SQL: %s\n[Magic-ORM] Args: %v\n", stmt, params)
	}
	if q.dryRun {
		return nil
	}

	var execErr error
	switch {
	case q.tx != nil:
		_, execErr = q.tx.Exec(stmt, params...)
	case q.db != nil && q.db.db != nil:
		_, execErr = q.db.db.Exec(stmt, params...)
	default:
		return fmt.Errorf("数据库连接未初始化")
	}
	if execErr != nil {
		return fmt.Errorf("删除失败:%w", execErr)
	}
	return nil
}
// Unscoped marks the builder to ignore soft-delete scoping by setting the
// unscoped flag, and returns the builder.
func (q *QueryBuilder) Unscoped() IQuery {
	const ignoreSoftDelete = true
	q.unscoped = ignoreSoftDelete
	return IQuery(q)
}
// DryRun switches the builder to dry-run mode: SQL is still built (and printed
// in debug mode), but the executing methods return without touching the
// database.
func (q *QueryBuilder) DryRun() IQuery {
	const buildOnly = true
	q.dryRun = buildOnly
	return IQuery(q)
}
// Debug enables per-builder debug mode, which makes the executing methods
// print the generated SQL and its arguments to stdout.
func (q *QueryBuilder) Debug() IQuery {
	const printSQL = true
	q.debug = printSQL
	return IQuery(q)
}
// Build returns the SELECT statement and its bind arguments; it is an alias
// of BuildSelect.
func (q *QueryBuilder) Build() (string, []interface{}) {
	sqlText, args := q.BuildSelect()
	return sqlText, args
}
// BuildSelect renders the SELECT statement and its bind arguments.
//
// Clause order: SELECT cols FROM table [joins] [WHERE] [GROUP BY] [HAVING]
// [ORDER BY] [LIMIT n] [OFFSET n].
//   - With no selected columns, "*" is used.
//   - The table is q.table, else the model's table name from FieldMapper,
//     else the placeholder "unknown_table".
//   - LIMIT and OFFSET are only emitted when greater than zero.
//
// Arguments are returned in placeholder order: JOIN, then WHERE, then HAVING.
// The returned slice is freshly allocated and never aliases builder state.
func (q *QueryBuilder) BuildSelect() (string, []interface{}) {
	var builder strings.Builder

	builder.WriteString("SELECT ")
	if len(q.selectCols) > 0 {
		builder.WriteString(strings.Join(q.selectCols, ", "))
	} else {
		builder.WriteString("*")
	}

	builder.WriteString(" FROM ")
	switch {
	case q.table != "":
		builder.WriteString(q.table)
	case q.model != nil:
		mapper := NewFieldMapper()
		builder.WriteString(mapper.GetTableName(q.model))
	default:
		builder.WriteString("unknown_table")
	}

	if q.joinSQL != "" {
		builder.WriteString(" ")
		builder.WriteString(q.joinSQL)
	}
	if q.whereSQL != "" {
		builder.WriteString(" WHERE ")
		builder.WriteString(q.whereSQL)
	}
	if q.groupSQL != "" {
		builder.WriteString(" GROUP BY ")
		builder.WriteString(q.groupSQL)
	}
	if q.havingSQL != "" {
		builder.WriteString(" HAVING ")
		builder.WriteString(q.havingSQL)
	}
	if q.orderSQL != "" {
		builder.WriteString(" ORDER BY ")
		builder.WriteString(q.orderSQL)
	}

	// Fprintf writes straight into the builder instead of allocating an
	// intermediate string as Sprintf+WriteString would; writes to a
	// strings.Builder never fail, so the results are safely ignored.
	if q.limit > 0 {
		fmt.Fprintf(&builder, " LIMIT %d", q.limit)
	}
	if q.offset > 0 {
		fmt.Fprintf(&builder, " OFFSET %d", q.offset)
	}

	// Pre-size so the three appends below never reallocate.
	allArgs := make([]interface{}, 0, len(q.joinArgs)+len(q.whereArgs)+len(q.havingArgs))
	allArgs = append(allArgs, q.joinArgs...)
	allArgs = append(allArgs, q.whereArgs...)
	allArgs = append(allArgs, q.havingArgs...)

	return builder.String(), allArgs
}
// BuildUpdate renders an UPDATE statement for data and returns it with its
// bind arguments: SET values first, then WHERE arguments.
//
// data may be:
//   - map[string]interface{}: each key becomes "key = ?";
//   - string: used verbatim as the SET clause body (not escaped, so never
//     pass untrusted input);
//   - anything else: converted to columns with FieldMapper.StructToColumns.
//
// Map and struct columns are emitted in sorted order. Go's map iteration
// order is random, so without sorting the same input would produce different
// SQL text on each call, which defeats prepared-statement caching and makes
// logs and tests non-deterministic.
//
// NOTE(review): if data yields no columns (an empty map, or StructToColumns
// fails; its error is discarded here), the SET clause is empty and the SQL
// is invalid, so the failure only surfaces at execution. WHERE is emitted
// only when conditions exist, so an unconditioned builder updates every row.
func (q *QueryBuilder) BuildUpdate(data interface{}) (string, []interface{}) {
	var builder strings.Builder
	var args []interface{}

	builder.WriteString("UPDATE ")
	if q.table != "" {
		builder.WriteString(q.table)
	} else if q.model != nil {
		mapper := NewFieldMapper()
		builder.WriteString(mapper.GetTableName(q.model))
	} else {
		builder.WriteString("unknown_table")
	}
	builder.WriteString(" SET ")

	switch v := data.(type) {
	case map[string]interface{}:
		keys := make([]string, 0, len(v))
		for key := range v {
			keys = append(keys, key)
		}
		sortColumnNames(keys)
		setParts := make([]string, 0, len(keys))
		for _, key := range keys {
			setParts = append(setParts, key+" = ?")
			args = append(args, v[key])
		}
		builder.WriteString(strings.Join(setParts, ", "))
	case string:
		builder.WriteString(v)
	default:
		mapper := NewFieldMapper()
		columns, err := mapper.StructToColumns(data)
		if err == nil && len(columns) > 0 {
			keys := make([]string, 0, len(columns))
			for key := range columns {
				keys = append(keys, key)
			}
			sortColumnNames(keys)
			setParts := make([]string, 0, len(keys))
			for _, key := range keys {
				setParts = append(setParts, key+" = ?")
				args = append(args, columns[key])
			}
			builder.WriteString(strings.Join(setParts, ", "))
		}
	}

	if q.whereSQL != "" {
		builder.WriteString(" WHERE ")
		builder.WriteString(q.whereSQL)
		args = append(args, q.whereArgs...)
	}

	return builder.String(), args
}

// sortColumnNames sorts names in ascending order in place. Column lists are
// short, so a simple insertion sort is sufficient.
func sortColumnNames(names []string) {
	for i := 1; i < len(names); i++ {
		for j := i; j > 0 && names[j] < names[j-1]; j-- {
			names[j], names[j-1] = names[j-1], names[j]
		}
	}
}
// BuildDelete renders a DELETE statement and its bind arguments.
//
// The table is q.table, else the model's table name from FieldMapper, else
// the placeholder "unknown_table". WHERE is emitted only when conditions
// exist, so an unconditioned builder yields a statement that deletes every row.
//
// The returned arguments are a copy of the WHERE arguments. Returning
// q.whereArgs directly would let the caller and later builder calls
// (Where/Or append) share one backing array. The arguments are returned even
// when there is no WHERE text, so a stray argument makes execution fail
// instead of silently running an unconditioned DELETE.
func (q *QueryBuilder) BuildDelete() (string, []interface{}) {
	var builder strings.Builder

	builder.WriteString("DELETE FROM ")
	if q.table != "" {
		builder.WriteString(q.table)
	} else if q.model != nil {
		mapper := NewFieldMapper()
		builder.WriteString(mapper.GetTableName(q.model))
	} else {
		builder.WriteString("unknown_table")
	}

	if q.whereSQL != "" {
		builder.WriteString(" WHERE ")
		builder.WriteString(q.whereSQL)
	}

	args := append([]interface{}(nil), q.whereArgs...)
	return builder.String(), args
}