549 lines
13 KiB
Go
549 lines
13 KiB
Go
package core
|
||
|
||
import (
	"database/sql"
	"fmt"
	"sort"
	"strings"
	"sync"
)
|
||
|
||
// QueryBuilder 查询构建器实现 - 提供流畅的链式查询构建能力
|
||
type QueryBuilder struct {
|
||
db *Database // 数据库连接实例
|
||
table string // 表名
|
||
model interface{} // 模型对象
|
||
whereSQL string // WHERE 条件 SQL
|
||
whereArgs []interface{} // WHERE 条件参数
|
||
selectCols []string // 选择的字段列表
|
||
orderSQL string // ORDER BY SQL
|
||
limit int // LIMIT 限制数量
|
||
offset int // OFFSET 偏移量
|
||
groupSQL string // GROUP BY SQL
|
||
havingSQL string // HAVING 条件 SQL
|
||
havingArgs []interface{} // HAVING 条件参数
|
||
joinSQL string // JOIN SQL
|
||
joinArgs []interface{} // JOIN 参数
|
||
debug bool // 调试模式开关
|
||
dryRun bool // 干跑模式开关
|
||
unscoped bool // 忽略软删除开关
|
||
tx *sql.Tx // 事务对象(如果在事务中)
|
||
}
|
||
|
||
// Pools of reusable argument slices, intended to cut allocations when
// building queries.
// NOTE(review): no code in this file calls Get/Put on these pools; confirm
// whether they are used elsewhere before relying on them.
var (
	// whereArgsPool yields empty []interface{} values with capacity 10.
	whereArgsPool = sync.Pool{
		New: func() interface{} { return make([]interface{}, 0, 10) },
	}

	// joinArgsPool yields empty []interface{} values with capacity 5.
	joinArgsPool = sync.Pool{
		New: func() interface{} { return make([]interface{}, 0, 5) },
	}
)
|
||
|
||
// Model 基于模型创建查询
|
||
func (d *Database) Model(model interface{}) IQuery {
|
||
return &QueryBuilder{
|
||
db: d,
|
||
model: model,
|
||
}
|
||
}
|
||
|
||
// Table 基于表名创建查询
|
||
func (d *Database) Table(name string) IQuery {
|
||
return &QueryBuilder{
|
||
db: d,
|
||
table: name,
|
||
}
|
||
}
|
||
|
||
// Where 添加 WHERE 条件 - 性能优化版本
|
||
func (q *QueryBuilder) Where(query string, args ...interface{}) IQuery {
|
||
if q.whereSQL == "" {
|
||
q.whereSQL = query
|
||
} else {
|
||
// 使用 strings.Builder 优化字符串拼接
|
||
var builder strings.Builder
|
||
builder.Grow(len(q.whereSQL) + 5 + len(query)) // 预分配内存
|
||
builder.WriteString(q.whereSQL)
|
||
builder.WriteString(" AND ")
|
||
builder.WriteString(query)
|
||
q.whereSQL = builder.String()
|
||
}
|
||
q.whereArgs = append(q.whereArgs, args...)
|
||
return q
|
||
}
|
||
|
||
// Or 添加 OR 条件 - 性能优化版本
|
||
func (q *QueryBuilder) Or(query string, args ...interface{}) IQuery {
|
||
if q.whereSQL == "" {
|
||
q.whereSQL = query
|
||
} else {
|
||
// 使用 strings.Builder 优化字符串拼接
|
||
var builder strings.Builder
|
||
builder.Grow(len(q.whereSQL) + 10 + len(query)) // 预分配内存
|
||
builder.WriteString(" (")
|
||
builder.WriteString(q.whereSQL)
|
||
builder.WriteString(") OR ")
|
||
builder.WriteString(query)
|
||
q.whereSQL = builder.String()
|
||
}
|
||
q.whereArgs = append(q.whereArgs, args...)
|
||
return q
|
||
}
|
||
|
||
// And 添加 AND 条件(同 Where)
|
||
func (q *QueryBuilder) And(query string, args ...interface{}) IQuery {
|
||
return q.Where(query, args...)
|
||
}
|
||
|
||
// Select 选择要查询的字段
|
||
func (q *QueryBuilder) Select(fields ...string) IQuery {
|
||
q.selectCols = fields
|
||
return q
|
||
}
|
||
|
||
// Omit 排除指定的字段(暂未实现)
|
||
func (q *QueryBuilder) Omit(fields ...string) IQuery {
|
||
// TODO: 实现字段排除逻辑,生成 SELECT 时排除这些字段
|
||
return q
|
||
}
|
||
|
||
// Order 设置排序规则
|
||
func (q *QueryBuilder) Order(order string) IQuery {
|
||
q.orderSQL = order
|
||
return q
|
||
}
|
||
|
||
// OrderBy 按指定字段和方向排序
|
||
func (q *QueryBuilder) OrderBy(field string, direction string) IQuery {
|
||
q.orderSQL = field + " " + direction
|
||
return q
|
||
}
|
||
|
||
// Limit 限制返回数量
|
||
func (q *QueryBuilder) Limit(limit int) IQuery {
|
||
q.limit = limit
|
||
return q
|
||
}
|
||
|
||
// Offset 设置偏移量
|
||
func (q *QueryBuilder) Offset(offset int) IQuery {
|
||
q.offset = offset
|
||
return q
|
||
}
|
||
|
||
// Page 分页查询
|
||
func (q *QueryBuilder) Page(page, pageSize int) IQuery {
|
||
q.limit = pageSize
|
||
q.offset = (page - 1) * pageSize
|
||
return q
|
||
}
|
||
|
||
// Group 设置分组字段
|
||
func (q *QueryBuilder) Group(group string) IQuery {
|
||
q.groupSQL = group
|
||
return q
|
||
}
|
||
|
||
// Having 添加 HAVING 条件
|
||
func (q *QueryBuilder) Having(having string, args ...interface{}) IQuery {
|
||
q.havingSQL = having
|
||
q.havingArgs = args
|
||
return q
|
||
}
|
||
|
||
// Join 添加 JOIN 连接 - 性能优化版本
|
||
func (q *QueryBuilder) Join(join string, args ...interface{}) IQuery {
|
||
if q.joinSQL == "" {
|
||
q.joinSQL = join
|
||
} else {
|
||
// 使用 strings.Builder 优化字符串拼接
|
||
var builder strings.Builder
|
||
builder.Grow(len(q.joinSQL) + 1 + len(join)) // 预分配内存
|
||
builder.WriteString(q.joinSQL)
|
||
builder.WriteByte(' ')
|
||
builder.WriteString(join)
|
||
q.joinSQL = builder.String()
|
||
}
|
||
q.joinArgs = append(q.joinArgs, args...)
|
||
return q
|
||
}
|
||
|
||
// LeftJoin 左连接
|
||
func (q *QueryBuilder) LeftJoin(table, on string) IQuery {
|
||
return q.Join("LEFT JOIN " + table + " ON " + on)
|
||
}
|
||
|
||
// RightJoin 右连接
|
||
func (q *QueryBuilder) RightJoin(table, on string) IQuery {
|
||
return q.Join("RIGHT JOIN " + table + " ON " + on)
|
||
}
|
||
|
||
// InnerJoin 内连接
|
||
func (q *QueryBuilder) InnerJoin(table, on string) IQuery {
|
||
return q.Join("INNER JOIN " + table + " ON " + on)
|
||
}
|
||
|
||
// Preload 预加载关联数据(暂未实现)
|
||
func (q *QueryBuilder) Preload(relation string, conditions ...interface{}) IQuery {
|
||
// TODO: 实现预加载逻辑
|
||
return q
|
||
}
|
||
|
||
// First 查询第一条记录
|
||
func (q *QueryBuilder) First(result interface{}) error {
|
||
q.limit = 1
|
||
return q.Find(result)
|
||
}
|
||
|
||
// Find 查询多条记录
|
||
func (q *QueryBuilder) Find(result interface{}) error {
|
||
sqlStr, args := q.BuildSelect()
|
||
|
||
// 调试模式打印 SQL
|
||
if q.debug || (q.db != nil && q.db.debug) {
|
||
fmt.Printf("[Magic-ORM] SQL: %s\n[Magic-ORM] Args: %v\n", sqlStr, args)
|
||
}
|
||
|
||
// 干跑模式不执行 SQL
|
||
if q.dryRun {
|
||
return nil
|
||
}
|
||
|
||
var rows *sql.Rows
|
||
var err error
|
||
|
||
// 判断是否在事务中
|
||
if q.tx != nil {
|
||
rows, err = q.tx.Query(sqlStr, args...)
|
||
} else if q.db != nil && q.db.db != nil {
|
||
rows, err = q.db.db.Query(sqlStr, args...)
|
||
} else {
|
||
return fmt.Errorf("数据库连接未初始化")
|
||
}
|
||
|
||
if err != nil {
|
||
return fmt.Errorf("查询失败:%w", err)
|
||
}
|
||
defer rows.Close()
|
||
|
||
// TODO: 实现结果映射逻辑
|
||
// 使用 FieldMapper 将查询结果映射到 result
|
||
|
||
return nil
|
||
}
|
||
|
||
// Count 统计记录数量
|
||
func (q *QueryBuilder) Count(count *int64) IQuery {
|
||
// 构建 COUNT 查询
|
||
originalSelect := q.selectCols
|
||
q.selectCols = []string{"COUNT(*)"}
|
||
|
||
sqlStr, args := q.BuildSelect()
|
||
|
||
// 调试模式
|
||
if q.debug || (q.db != nil && q.db.debug) {
|
||
fmt.Printf("[Magic-ORM] COUNT SQL: %s\n[Magic-ORM] Args: %v\n", sqlStr, args)
|
||
}
|
||
|
||
// 干跑模式
|
||
if q.dryRun {
|
||
return q
|
||
}
|
||
|
||
var err error
|
||
if q.tx != nil {
|
||
err = q.tx.QueryRow(sqlStr, args...).Scan(count)
|
||
} else if q.db != nil && q.db.db != nil {
|
||
err = q.db.db.QueryRow(sqlStr, args...).Scan(count)
|
||
}
|
||
|
||
if err != nil {
|
||
fmt.Printf("[Magic-ORM] Count 错误:%v\n", err)
|
||
}
|
||
|
||
// 恢复原来的选择字段
|
||
q.selectCols = originalSelect
|
||
return q
|
||
}
|
||
|
||
// Exists 检查记录是否存在
|
||
func (q *QueryBuilder) Exists() (bool, error) {
|
||
// 使用 LIMIT 1 优化查询
|
||
originalLimit := q.limit
|
||
q.limit = 1
|
||
|
||
sqlStr, args := q.BuildSelect()
|
||
|
||
// 调试模式
|
||
if q.debug || (q.db != nil && q.db.debug) {
|
||
fmt.Printf("[Magic-ORM] EXISTS SQL: %s\n[Magic-ORM] Args: %v\n", sqlStr, args)
|
||
}
|
||
|
||
// 干跑模式
|
||
if q.dryRun {
|
||
return false, nil
|
||
}
|
||
|
||
var rows *sql.Rows
|
||
var err error
|
||
|
||
if q.tx != nil {
|
||
rows, err = q.tx.Query(sqlStr, args...)
|
||
} else if q.db != nil && q.db.db != nil {
|
||
rows, err = q.db.db.Query(sqlStr, args...)
|
||
} else {
|
||
return false, fmt.Errorf("数据库连接未初始化")
|
||
}
|
||
defer rows.Close()
|
||
|
||
if err != nil {
|
||
return false, err
|
||
}
|
||
|
||
// 检查是否有结果
|
||
exists := rows.Next()
|
||
|
||
// 恢复原来的 limit
|
||
q.limit = originalLimit
|
||
|
||
return exists, nil
|
||
}
|
||
|
||
// Updates 更新数据
|
||
func (q *QueryBuilder) Updates(data interface{}) error {
|
||
sqlStr, args := q.BuildUpdate(data)
|
||
|
||
// 调试模式打印 SQL
|
||
if q.debug || (q.db != nil && q.db.debug) {
|
||
fmt.Printf("[Magic-ORM] UPDATE SQL: %s\n[Magic-ORM] Args: %v\n", sqlStr, args)
|
||
}
|
||
|
||
// 干跑模式不执行 SQL
|
||
if q.dryRun {
|
||
return nil
|
||
}
|
||
|
||
var err error
|
||
if q.tx != nil {
|
||
_, err = q.tx.Exec(sqlStr, args...)
|
||
} else if q.db != nil && q.db.db != nil {
|
||
_, err = q.db.db.Exec(sqlStr, args...)
|
||
} else {
|
||
return fmt.Errorf("数据库连接未初始化")
|
||
}
|
||
|
||
if err != nil {
|
||
return fmt.Errorf("更新失败:%w", err)
|
||
}
|
||
return nil
|
||
}
|
||
|
||
// UpdateColumn 更新单个字段
|
||
func (q *QueryBuilder) UpdateColumn(column string, value interface{}) error {
|
||
return q.Updates(map[string]interface{}{column: value})
|
||
}
|
||
|
||
// Delete 删除数据
|
||
func (q *QueryBuilder) Delete() error {
|
||
sqlStr, args := q.BuildDelete()
|
||
|
||
// 调试模式打印 SQL
|
||
if q.debug || (q.db != nil && q.db.debug) {
|
||
fmt.Printf("[Magic-ORM] DELETE SQL: %s\n[Magic-ORM] Args: %v\n", sqlStr, args)
|
||
}
|
||
|
||
// 干跑模式不执行 SQL
|
||
if q.dryRun {
|
||
return nil
|
||
}
|
||
|
||
var err error
|
||
if q.tx != nil {
|
||
_, err = q.tx.Exec(sqlStr, args...)
|
||
} else if q.db != nil && q.db.db != nil {
|
||
_, err = q.db.db.Exec(sqlStr, args...)
|
||
} else {
|
||
return fmt.Errorf("数据库连接未初始化")
|
||
}
|
||
|
||
if err != nil {
|
||
return fmt.Errorf("删除失败:%w", err)
|
||
}
|
||
return nil
|
||
}
|
||
|
||
// Unscoped 忽略软删除限制
|
||
func (q *QueryBuilder) Unscoped() IQuery {
|
||
q.unscoped = true
|
||
return q
|
||
}
|
||
|
||
// DryRun 设置干跑模式(只生成 SQL 不执行)
|
||
func (q *QueryBuilder) DryRun() IQuery {
|
||
q.dryRun = true
|
||
return q
|
||
}
|
||
|
||
// Debug 设置调试模式(打印 SQL 日志)
|
||
func (q *QueryBuilder) Debug() IQuery {
|
||
q.debug = true
|
||
return q
|
||
}
|
||
|
||
// Build 构建 SELECT SQL 语句
|
||
func (q *QueryBuilder) Build() (string, []interface{}) {
|
||
return q.BuildSelect()
|
||
}
|
||
|
||
// BuildSelect 构建 SELECT SQL 语句
|
||
func (q *QueryBuilder) BuildSelect() (string, []interface{}) {
|
||
var builder strings.Builder
|
||
|
||
// SELECT 部分
|
||
builder.WriteString("SELECT ")
|
||
if len(q.selectCols) > 0 {
|
||
builder.WriteString(strings.Join(q.selectCols, ", "))
|
||
} else {
|
||
builder.WriteString("*")
|
||
}
|
||
|
||
// FROM 部分
|
||
builder.WriteString(" FROM ")
|
||
if q.table != "" {
|
||
builder.WriteString(q.table)
|
||
} else if q.model != nil {
|
||
// 从模型获取表名
|
||
mapper := NewFieldMapper()
|
||
builder.WriteString(mapper.GetTableName(q.model))
|
||
} else {
|
||
builder.WriteString("unknown_table")
|
||
}
|
||
|
||
// JOIN 部分
|
||
if q.joinSQL != "" {
|
||
builder.WriteString(" ")
|
||
builder.WriteString(q.joinSQL)
|
||
}
|
||
|
||
// WHERE 部分
|
||
if q.whereSQL != "" {
|
||
builder.WriteString(" WHERE ")
|
||
builder.WriteString(q.whereSQL)
|
||
}
|
||
|
||
// GROUP BY 部分
|
||
if q.groupSQL != "" {
|
||
builder.WriteString(" GROUP BY ")
|
||
builder.WriteString(q.groupSQL)
|
||
}
|
||
|
||
// HAVING 部分
|
||
if q.havingSQL != "" {
|
||
builder.WriteString(" HAVING ")
|
||
builder.WriteString(q.havingSQL)
|
||
}
|
||
|
||
// ORDER BY 部分
|
||
if q.orderSQL != "" {
|
||
builder.WriteString(" ORDER BY ")
|
||
builder.WriteString(q.orderSQL)
|
||
}
|
||
|
||
// LIMIT 部分
|
||
if q.limit > 0 {
|
||
builder.WriteString(fmt.Sprintf(" LIMIT %d", q.limit))
|
||
}
|
||
|
||
// OFFSET 部分
|
||
if q.offset > 0 {
|
||
builder.WriteString(fmt.Sprintf(" OFFSET %d", q.offset))
|
||
}
|
||
|
||
// 合并参数
|
||
allArgs := make([]interface{}, 0)
|
||
allArgs = append(allArgs, q.joinArgs...)
|
||
allArgs = append(allArgs, q.whereArgs...)
|
||
allArgs = append(allArgs, q.havingArgs...)
|
||
|
||
return builder.String(), allArgs
|
||
}
|
||
|
||
// BuildUpdate 构建 UPDATE SQL 语句
|
||
func (q *QueryBuilder) BuildUpdate(data interface{}) (string, []interface{}) {
|
||
var builder strings.Builder
|
||
var args []interface{}
|
||
|
||
builder.WriteString("UPDATE ")
|
||
if q.table != "" {
|
||
builder.WriteString(q.table)
|
||
} else if q.model != nil {
|
||
mapper := NewFieldMapper()
|
||
builder.WriteString(mapper.GetTableName(q.model))
|
||
} else {
|
||
builder.WriteString("unknown_table")
|
||
}
|
||
|
||
builder.WriteString(" SET ")
|
||
|
||
// 根据 data 类型生成 SET 子句
|
||
switch v := data.(type) {
|
||
case map[string]interface{}:
|
||
// map 类型,生成 key=value 对
|
||
setParts := make([]string, 0, len(v))
|
||
for key, value := range v {
|
||
setParts = append(setParts, fmt.Sprintf("%s = ?", key))
|
||
args = append(args, value)
|
||
}
|
||
builder.WriteString(strings.Join(setParts, ", "))
|
||
case string:
|
||
// string 类型,直接使用(注意:实际使用需要转义)
|
||
builder.WriteString(v)
|
||
default:
|
||
// 结构体类型,使用字段映射器
|
||
mapper := NewFieldMapper()
|
||
columns, err := mapper.StructToColumns(data)
|
||
if err == nil && len(columns) > 0 {
|
||
setParts := make([]string, 0, len(columns))
|
||
for key := range columns {
|
||
setParts = append(setParts, fmt.Sprintf("%s = ?", key))
|
||
args = append(args, columns[key])
|
||
}
|
||
builder.WriteString(strings.Join(setParts, ", "))
|
||
}
|
||
}
|
||
|
||
// WHERE 部分
|
||
if q.whereSQL != "" {
|
||
builder.WriteString(" WHERE ")
|
||
builder.WriteString(q.whereSQL)
|
||
args = append(args, q.whereArgs...)
|
||
}
|
||
|
||
return builder.String(), args
|
||
}
|
||
|
||
// BuildDelete 构建 DELETE SQL 语句
|
||
func (q *QueryBuilder) BuildDelete() (string, []interface{}) {
|
||
var builder strings.Builder
|
||
|
||
builder.WriteString("DELETE FROM ")
|
||
if q.table != "" {
|
||
builder.WriteString(q.table)
|
||
} else if q.model != nil {
|
||
mapper := NewFieldMapper()
|
||
builder.WriteString(mapper.GetTableName(q.model))
|
||
} else {
|
||
builder.WriteString("unknown_table")
|
||
}
|
||
|
||
if q.whereSQL != "" {
|
||
builder.WriteString(" WHERE ")
|
||
builder.WriteString(q.whereSQL)
|
||
}
|
||
|
||
return builder.String(), q.whereArgs
|
||
}
|