feat(database): 添加ClickHouse数据库驱动支持

- 实现了ClickHouse数据库驱动程序,支持基本的数据库操作
- 添加了ClickHouse特定的迁移功能,包括表、列、索引的创建和管理
- 集成了ClickHouse的语法特性,如MergeTree引擎和Nullable类型
- 实现了数据库连接池管理和SQL执行接口
- 添加了对系统表查询的支持,用于检查表和列的存在性
main v1.0.2023
black 2026-04-13 16:07:54 +08:00
parent a083b74f9b
commit 284f9380ed
23 changed files with 3599 additions and 8 deletions

View File

@ -41,16 +41,59 @@ func main() {
user := getStringValue(defaultDbConfig, "user", "root")
pass := getStringValue(defaultDbConfig, "pass", "")
dbType := getStringValue(defaultDbConfig, "type", "mysql")
link := getStringValue(defaultDbConfig, "link", "")
fmt.Println("=== Gin-Base DAO 代码生成工具 ===")
fmt.Printf("📊 数据库: %s\n", name)
fmt.Printf("🔧 类型: %s\n", dbType)
fmt.Printf("🌐 主机: %s:%s\n\n", host, port)
// 构建数据库连接字符串
link := fmt.Sprintf("mysql:%s:%s@tcp(%s:%s)/%s?charset=utf8&parseTime=true&loc=Local",
user, pass, host, port, name,
)
var connectionInfo string
if link == "" {
// 如果没有配置 link则根据数据库类型构建
switch dbType {
case "sqlite":
// SQLite 使用配置文件中的 link 或默认路径
connectionInfo = fmt.Sprintf("📁 数据库文件: %s", link)
case "mysql", "mariadb":
link = fmt.Sprintf("mysql:%s:%s@tcp(%s:%s)/%s?charset=utf8&parseTime=true&loc=Local",
user, pass, host, port, name,
)
connectionInfo = fmt.Sprintf("📊 数据库: %s\n🌐 主机: %s:%s", name, host, port)
case "pgsql", "postgresql":
link = fmt.Sprintf("pgsql:%s:%s@tcp(%s:%s)/%s?sslmode=disable",
user, pass, host, port, name,
)
connectionInfo = fmt.Sprintf("📊 数据库: %s\n🌐 主机: %s:%s", name, host, port)
case "mssql":
link = fmt.Sprintf("mssql:%s:%s@tcp(%s:%s)/%s",
user, pass, host, port, name,
)
connectionInfo = fmt.Sprintf("📊 数据库: %s\n🌐 主机: %s:%s", name, host, port)
case "oracle":
link = fmt.Sprintf("oracle:%s:%s@%s:%s/%s",
user, pass, host, port, name,
)
connectionInfo = fmt.Sprintf("📊 数据库: %s\n🌐 主机: %s:%s", name, host, port)
case "clickhouse":
link = fmt.Sprintf("clickhouse:%s:%s@tcp(%s:%s)/%s",
user, pass, host, port, name,
)
connectionInfo = fmt.Sprintf("📊 数据库: %s\n🌐 主机: %s:%s", name, host, port)
default:
fmt.Printf("⚠️ 警告: 未知的数据库类型 %s尝试使用配置的 link\n", dbType)
if link == "" {
fmt.Println("❌ 错误: 未配置数据库连接信息")
os.Exit(1)
}
connectionInfo = "🔗 使用自定义连接"
}
} else {
// 使用配置文件中直接提供的 link
connectionInfo = fmt.Sprintf("🔗 连接: %s", link)
}
fmt.Println(connectionInfo)
fmt.Println()
// 准备表名参数
tablesArg := ""

View File

@ -61,3 +61,13 @@ func (d *Driver) injectNeedParsedSql(ctx context.Context) context.Context {
}
return context.WithValue(ctx, needParsedSqlInCtx, true)
}
// Migration returns a Migration instance for ClickHouse database operations.
func (d *Driver) Migration() *Migration {
return NewMigration(d)
}
// GetMigration returns a Migration instance implementing the Migration interface.
func (d *Driver) GetMigration() database.Migration {
return d.Migration()
}

View File

@ -0,0 +1,321 @@
// Copyright GoFrame Author(https://goframe.org). All Rights Reserved.
//
// This Source Code Form is subject to the terms of the MIT License.
// If a copy of the MIT was not distributed with this file,
// You can obtain one at https://github.com/gogf/gf.
package clickhouse
import (
"context"
"fmt"
"strings"
"git.magicany.cc/black1552/gin-base/database"
)
// Migration implements database migration operations for ClickHouse.
type Migration struct {
*database.MigrationCore
*database.AutoMigrateCore
}
// NewMigration creates a new ClickHouse Migration instance.
func NewMigration(db database.DB) *Migration {
return &Migration{
MigrationCore: database.NewMigrationCore(db),
AutoMigrateCore: database.NewAutoMigrateCore(db),
}
}
// CreateTable creates a new table with the given name and column definitions.
func (m *Migration) CreateTable(ctx context.Context, table string, columns map[string]*database.ColumnDefinition, options ...database.TableOption) error {
if len(columns) == 0 {
return fmt.Errorf("cannot create table without columns")
}
var opts database.TableOptions
for _, opt := range options {
opt(&opts)
}
var sql strings.Builder
sql.WriteString("CREATE TABLE ")
if opts.IfNotExists {
sql.WriteString("IF NOT EXISTS ")
}
sql.WriteString(database.QuoteIdentifier(table))
sql.WriteString(" (\n")
// Add columns
var colDefs []string
for name, def := range columns {
colDef := m.buildColumnDefinition(name, def)
colDefs = append(colDefs, " "+colDef)
}
sql.WriteString(strings.Join(colDefs, ",\n"))
sql.WriteString("\n)")
// Add engine specification (required for ClickHouse)
sql.WriteString(" ENGINE = MergeTree()")
if opts.Comment != "" {
sql.WriteString(fmt.Sprintf(" COMMENT '%s'", escapeString(opts.Comment)))
}
return m.ExecuteSQL(ctx, sql.String())
}
// buildColumnDefinition builds column definition for ClickHouse.
func (m *Migration) buildColumnDefinition(name string, def *database.ColumnDefinition) string {
var parts []string
parts = append(parts, database.QuoteIdentifier(name))
// Handle ClickHouse-specific types
dbType := def.Type
parts = append(parts, dbType)
if !def.Null {
// ClickHouse uses Nullable type wrapper
if !strings.HasPrefix(dbType, "Nullable(") {
// Type is already non-nullable by default in ClickHouse
}
} else {
// Make type nullable if needed
if !strings.HasPrefix(dbType, "Nullable(") {
parts[len(parts)-1] = fmt.Sprintf("Nullable(%s)", dbType)
}
}
if def.Default != nil {
defaultValue := formatDefaultValue(def.Default)
parts = append(parts, fmt.Sprintf("DEFAULT %s", defaultValue))
}
if def.Comment != "" {
parts = append(parts, fmt.Sprintf("COMMENT '%s'", escapeString(def.Comment)))
}
return strings.Join(parts, " ")
}
// DropTable drops an existing table from the database.
func (m *Migration) DropTable(ctx context.Context, table string, ifExists ...bool) error {
sql := "DROP TABLE "
if len(ifExists) > 0 && ifExists[0] {
sql += "IF EXISTS "
}
sql += database.QuoteIdentifier(table)
return m.ExecuteSQL(ctx, sql)
}
// HasTable checks if a table exists in the database.
func (m *Migration) HasTable(ctx context.Context, table string) (bool, error) {
schema := m.GetDB().GetSchema()
if schema == "" {
schema = "currentDatabase()"
} else {
schema = fmt.Sprintf("'%s'", schema)
}
query := fmt.Sprintf(
"SELECT COUNT(*) FROM system.tables WHERE database = %s AND name = '%s'",
schema,
table,
)
value, err := m.GetDB().GetValue(ctx, query)
if err != nil {
return false, err
}
return value.Int() > 0, nil
}
// RenameTable renames an existing table from oldName to newName.
func (m *Migration) RenameTable(ctx context.Context, oldName, newName string) error {
sql := fmt.Sprintf(
"RENAME TABLE %s TO %s",
database.QuoteIdentifier(oldName),
database.QuoteIdentifier(newName),
)
return m.ExecuteSQL(ctx, sql)
}
// TruncateTable removes all records from a table but keeps the table structure.
func (m *Migration) TruncateTable(ctx context.Context, table string) error {
sql := fmt.Sprintf("TRUNCATE TABLE %s", database.QuoteIdentifier(table))
return m.ExecuteSQL(ctx, sql)
}
// AddColumn adds a new column to an existing table.
func (m *Migration) AddColumn(ctx context.Context, table, column string, definition *database.ColumnDefinition) error {
colDef := m.buildColumnDefinition(column, definition)
sql := fmt.Sprintf("ALTER TABLE %s ADD COLUMN %s", database.QuoteIdentifier(table), colDef)
return m.ExecuteSQL(ctx, sql)
}
// DropColumn removes a column from an existing table.
func (m *Migration) DropColumn(ctx context.Context, table, column string) error {
sql := fmt.Sprintf(
"ALTER TABLE %s DROP COLUMN %s",
database.QuoteIdentifier(table),
database.QuoteIdentifier(column),
)
return m.ExecuteSQL(ctx, sql)
}
// RenameColumn renames a column in an existing table.
func (m *Migration) RenameColumn(ctx context.Context, table, oldName, newName string) error {
sql := fmt.Sprintf(
"ALTER TABLE %s RENAME COLUMN %s TO %s",
database.QuoteIdentifier(table),
database.QuoteIdentifier(oldName),
database.QuoteIdentifier(newName),
)
return m.ExecuteSQL(ctx, sql)
}
// ModifyColumn modifies an existing column's definition.
func (m *Migration) ModifyColumn(ctx context.Context, table, column string, definition *database.ColumnDefinition) error {
colDef := m.buildColumnDefinition(column, definition)
sql := fmt.Sprintf("ALTER TABLE %s MODIFY COLUMN %s", database.QuoteIdentifier(table), colDef)
return m.ExecuteSQL(ctx, sql)
}
// HasColumn checks if a column exists in a table.
func (m *Migration) HasColumn(ctx context.Context, table, column string) (bool, error) {
fields, err := m.GetDB().TableFields(ctx, table)
if err != nil {
return false, err
}
_, exists := fields[column]
return exists, nil
}
// CreateIndex creates a new index on the specified table and columns.
// Note: ClickHouse has limited index support compared to traditional databases.
func (m *Migration) CreateIndex(ctx context.Context, table, index string, columns []string, options ...database.IndexOption) error {
var opts database.IndexOptions
for _, opt := range options {
opt(&opts)
}
// ClickHouse uses different indexing mechanisms
// For now, we'll use ALTER TABLE ADD INDEX syntax
colList := m.BuildIndexColumns(columns)
sql := fmt.Sprintf(
"ALTER TABLE %s ADD INDEX %s (%s) TYPE minmax GRANULARITY 1",
database.QuoteIdentifier(table),
database.QuoteIdentifier(index),
colList,
)
return m.ExecuteSQL(ctx, sql)
}
// DropIndex drops an existing index from a table.
func (m *Migration) DropIndex(ctx context.Context, table, index string) error {
sql := fmt.Sprintf(
"ALTER TABLE %s DROP INDEX %s",
database.QuoteIdentifier(table),
database.QuoteIdentifier(index),
)
return m.ExecuteSQL(ctx, sql)
}
// HasIndex checks if an index exists on a table.
func (m *Migration) HasIndex(ctx context.Context, table, index string) (bool, error) {
schema := m.GetDB().GetSchema()
if schema == "" {
schema = "currentDatabase()"
} else {
schema = fmt.Sprintf("'%s'", schema)
}
query := fmt.Sprintf(
"SELECT COUNT(*) FROM system.data_skipping_indices WHERE database = %s AND table = '%s' AND name = '%s'",
schema,
table,
index,
)
value, err := m.GetDB().GetValue(ctx, query)
if err != nil {
return false, err
}
return value.Int() > 0, nil
}
// CreateForeignKey is not supported in ClickHouse.
func (m *Migration) CreateForeignKey(ctx context.Context, table, constraint string, columns []string, refTable string, refColumns []string, options ...database.ForeignKeyOption) error {
// ClickHouse does not support foreign keys
return fmt.Errorf("ClickHouse does not support foreign key constraints")
}
// DropForeignKey is not supported in ClickHouse.
func (m *Migration) DropForeignKey(ctx context.Context, table, constraint string) error {
// ClickHouse does not support foreign keys
return fmt.Errorf("ClickHouse does not support foreign key constraints")
}
// HasForeignKey always returns false as ClickHouse doesn't support foreign keys.
func (m *Migration) HasForeignKey(ctx context.Context, table, constraint string) (bool, error) {
return false, nil
}
// CreateSchema creates a new database schema.
func (m *Migration) CreateSchema(ctx context.Context, schema string) error {
sql := fmt.Sprintf("CREATE DATABASE IF NOT EXISTS %s", database.QuoteIdentifier(schema))
return m.ExecuteSQL(ctx, sql)
}
// DropSchema drops an existing database schema.
func (m *Migration) DropSchema(ctx context.Context, schema string, cascade ...bool) error {
sql := fmt.Sprintf("DROP DATABASE IF EXISTS %s", database.QuoteIdentifier(schema))
return m.ExecuteSQL(ctx, sql)
}
// HasSchema checks if a schema exists.
func (m *Migration) HasSchema(ctx context.Context, schema string) (bool, error) {
query := fmt.Sprintf(
"SELECT COUNT(*) FROM system.databases WHERE name = '%s'",
schema,
)
value, err := m.GetDB().GetValue(ctx, query)
if err != nil {
return false, err
}
return value.Int() > 0, nil
}
// formatDefaultValue formats the default value for SQL.
func formatDefaultValue(value any) string {
switch v := value.(type) {
case string:
return fmt.Sprintf("'%s'", escapeString(v))
case int, int8, int16, int32, int64, uint, uint8, uint16, uint32, uint64:
return fmt.Sprintf("%d", v)
case float32, float64:
return fmt.Sprintf("%f", v)
case bool:
if v {
return "1"
}
return "0"
case nil:
return "NULL"
default:
return fmt.Sprintf("'%v'", v)
}
}
// escapeString escapes special characters in strings for SQL.
func escapeString(s string) string {
s = strings.ReplaceAll(s, "'", "\\'")
return s
}

View File

@ -46,3 +46,13 @@ func (d *Driver) New(core *database.Core, node *database.ConfigNode) (database.D
func (d *Driver) GetChars() (charLeft string, charRight string) {
return quoteChar, quoteChar
}
// Migration returns a Migration instance for MSSQL database operations.
func (d *Driver) Migration() *Migration {
return NewMigration(d)
}
// GetMigration returns a Migration instance implementing the Migration interface.
func (d *Driver) GetMigration() database.Migration {
return d.Migration()
}

View File

@ -0,0 +1,340 @@
// Copyright GoFrame Author(https://goframe.org). All Rights Reserved.
//
// This Source Code Form is subject to the terms of the MIT License.
// If a copy of the MIT was not distributed with this file,
// You can obtain one at https://github.com/gogf/gf.
package mssql
import (
"context"
"fmt"
"strings"
"git.magicany.cc/black1552/gin-base/database"
)
// Migration implements database migration operations for SQL Server.
type Migration struct {
*database.MigrationCore
*database.AutoMigrateCore
}
// NewMigration creates a new SQL Server Migration instance.
func NewMigration(db database.DB) *Migration {
return &Migration{
MigrationCore: database.NewMigrationCore(db),
AutoMigrateCore: database.NewAutoMigrateCore(db),
}
}
// CreateTable creates a new table with the given name and column definitions.
func (m *Migration) CreateTable(ctx context.Context, table string, columns map[string]*database.ColumnDefinition, options ...database.TableOption) error {
if len(columns) == 0 {
return fmt.Errorf("cannot create table without columns")
}
var opts database.TableOptions
for _, opt := range options {
opt(&opts)
}
var sql strings.Builder
sql.WriteString("CREATE TABLE ")
sql.WriteString(database.QuoteIdentifier(table))
sql.WriteString(" (\n")
// Add columns
var colDefs []string
var primaryKeys []string
for name, def := range columns {
colDef := m.buildColumnDefinition(name, def)
if def.PrimaryKey {
primaryKeys = append(primaryKeys, database.QuoteIdentifier(name))
}
colDefs = append(colDefs, " "+colDef)
}
// Add primary key constraint if needed
if len(primaryKeys) > 0 {
colDefs = append(colDefs, fmt.Sprintf(" PRIMARY KEY (%s)", strings.Join(primaryKeys, ", ")))
}
sql.WriteString(strings.Join(colDefs, ",\n"))
sql.WriteString("\n)")
return m.ExecuteSQL(ctx, sql.String())
}
// buildColumnDefinition builds column definition for SQL Server.
func (m *Migration) buildColumnDefinition(name string, def *database.ColumnDefinition) string {
var parts []string
parts = append(parts, database.QuoteIdentifier(name))
// Handle SQL Server-specific types
dbType := def.Type
if def.AutoIncrement {
if dbType == "INT" || dbType == "INTEGER" {
dbType = "INT IDENTITY(1,1)"
} else if dbType == "BIGINT" {
dbType = "BIGINT IDENTITY(1,1)"
}
}
parts = append(parts, dbType)
if !def.Null {
parts = append(parts, "NOT NULL")
} else {
parts = append(parts, "NULL")
}
if def.Default != nil {
defaultValue := formatDefaultValue(def.Default)
parts = append(parts, fmt.Sprintf("DEFAULT %s", defaultValue))
}
return strings.Join(parts, " ")
}
// DropTable drops an existing table from the database.
func (m *Migration) DropTable(ctx context.Context, table string, ifExists ...bool) error {
if len(ifExists) > 0 && ifExists[0] {
sql := fmt.Sprintf(
"IF OBJECT_ID('%s', 'U') IS NOT NULL DROP TABLE %s",
table,
database.QuoteIdentifier(table),
)
return m.ExecuteSQL(ctx, sql)
}
sql := fmt.Sprintf("DROP TABLE %s", database.QuoteIdentifier(table))
return m.ExecuteSQL(ctx, sql)
}
// HasTable checks if a table exists in the database.
func (m *Migration) HasTable(ctx context.Context, table string) (bool, error) {
query := fmt.Sprintf(
"SELECT COUNT(*) FROM INFORMATION_SCHEMA.TABLES WHERE TABLE_NAME = '%s'",
table,
)
value, err := m.GetDB().GetValue(ctx, query)
if err != nil {
return false, err
}
return value.Int() > 0, nil
}
// RenameTable renames an existing table from oldName to newName.
func (m *Migration) RenameTable(ctx context.Context, oldName, newName string) error {
sql := fmt.Sprintf(
"EXEC sp_rename '%s', '%s'",
oldName,
newName,
)
return m.ExecuteSQL(ctx, sql)
}
// TruncateTable removes all records from a table but keeps the table structure.
func (m *Migration) TruncateTable(ctx context.Context, table string) error {
sql := fmt.Sprintf("TRUNCATE TABLE %s", database.QuoteIdentifier(table))
return m.ExecuteSQL(ctx, sql)
}
// AddColumn adds a new column to an existing table.
func (m *Migration) AddColumn(ctx context.Context, table, column string, definition *database.ColumnDefinition) error {
colDef := m.buildColumnDefinition(column, definition)
sql := fmt.Sprintf("ALTER TABLE %s ADD %s", database.QuoteIdentifier(table), colDef)
return m.ExecuteSQL(ctx, sql)
}
// DropColumn removes a column from an existing table.
func (m *Migration) DropColumn(ctx context.Context, table, column string) error {
sql := fmt.Sprintf(
"ALTER TABLE %s DROP COLUMN %s",
database.QuoteIdentifier(table),
database.QuoteIdentifier(column),
)
return m.ExecuteSQL(ctx, sql)
}
// RenameColumn renames a column in an existing table.
func (m *Migration) RenameColumn(ctx context.Context, table, oldName, newName string) error {
sql := fmt.Sprintf(
"EXEC sp_rename '%s.%s', '%s', 'COLUMN'",
table,
oldName,
newName,
)
return m.ExecuteSQL(ctx, sql)
}
// ModifyColumn modifies an existing column's definition.
func (m *Migration) ModifyColumn(ctx context.Context, table, column string, definition *database.ColumnDefinition) error {
colDef := m.buildColumnDefinition(column, definition)
sql := fmt.Sprintf("ALTER TABLE %s ALTER COLUMN %s", database.QuoteIdentifier(table), colDef)
return m.ExecuteSQL(ctx, sql)
}
// HasColumn checks if a column exists in a table.
func (m *Migration) HasColumn(ctx context.Context, table, column string) (bool, error) {
fields, err := m.GetDB().TableFields(ctx, table)
if err != nil {
return false, err
}
_, exists := fields[column]
return exists, nil
}
// CreateIndex creates a new index on the specified table and columns.
func (m *Migration) CreateIndex(ctx context.Context, table, index string, columns []string, options ...database.IndexOption) error {
var opts database.IndexOptions
for _, opt := range options {
opt(&opts)
}
var sql strings.Builder
sql.WriteString("CREATE ")
if opts.Unique {
sql.WriteString("UNIQUE ")
}
sql.WriteString("INDEX ")
sql.WriteString(database.QuoteIdentifier(index))
sql.WriteString(" ON ")
sql.WriteString(database.QuoteIdentifier(table))
colList := m.BuildIndexColumns(columns)
sql.WriteString(fmt.Sprintf(" (%s)", colList))
return m.ExecuteSQL(ctx, sql.String())
}
// DropIndex drops an existing index from a table.
func (m *Migration) DropIndex(ctx context.Context, table, index string) error {
sql := fmt.Sprintf("DROP INDEX %s ON %s", database.QuoteIdentifier(index), database.QuoteIdentifier(table))
return m.ExecuteSQL(ctx, sql)
}
// HasIndex checks if an index exists on a table.
func (m *Migration) HasIndex(ctx context.Context, table, index string) (bool, error) {
query := fmt.Sprintf(
"SELECT COUNT(*) FROM sys.indexes WHERE name = '%s' AND object_id = OBJECT_ID('%s')",
index,
table,
)
value, err := m.GetDB().GetValue(ctx, query)
if err != nil {
return false, err
}
return value.Int() > 0, nil
}
// CreateForeignKey creates a foreign key constraint.
func (m *Migration) CreateForeignKey(ctx context.Context, table, constraint string, columns []string, refTable string, refColumns []string, options ...database.ForeignKeyOption) error {
var opts database.ForeignKeyOptions
for _, opt := range options {
opt(&opts)
}
sql := fmt.Sprintf(
"ALTER TABLE %s ADD CONSTRAINT %s FOREIGN KEY %s REFERENCES %s %s",
database.QuoteIdentifier(table),
database.QuoteIdentifier(constraint),
m.BuildForeignKeyColumns(columns),
database.QuoteIdentifier(refTable),
m.BuildForeignKeyColumns(refColumns),
)
if opts.OnDelete != "" {
sql += fmt.Sprintf(" ON DELETE %s", opts.OnDelete)
}
if opts.OnUpdate != "" {
sql += fmt.Sprintf(" ON UPDATE %s", opts.OnUpdate)
}
return m.ExecuteSQL(ctx, sql)
}
// DropForeignKey drops a foreign key constraint.
func (m *Migration) DropForeignKey(ctx context.Context, table, constraint string) error {
sql := fmt.Sprintf(
"ALTER TABLE %s DROP CONSTRAINT %s",
database.QuoteIdentifier(table),
database.QuoteIdentifier(constraint),
)
return m.ExecuteSQL(ctx, sql)
}
// HasForeignKey checks if a foreign key constraint exists.
func (m *Migration) HasForeignKey(ctx context.Context, table, constraint string) (bool, error) {
query := fmt.Sprintf(
"SELECT COUNT(*) FROM INFORMATION_SCHEMA.TABLE_CONSTRAINTS WHERE CONSTRAINT_NAME = '%s' AND TABLE_NAME = '%s' AND CONSTRAINT_TYPE = 'FOREIGN KEY'",
constraint,
table,
)
value, err := m.GetDB().GetValue(ctx, query)
if err != nil {
return false, err
}
return value.Int() > 0, nil
}
// CreateSchema creates a new database schema.
func (m *Migration) CreateSchema(ctx context.Context, schema string) error {
sql := fmt.Sprintf("CREATE SCHEMA %s", database.QuoteIdentifier(schema))
return m.ExecuteSQL(ctx, sql)
}
// DropSchema drops an existing database schema.
func (m *Migration) DropSchema(ctx context.Context, schema string, cascade ...bool) error {
sql := fmt.Sprintf("DROP SCHEMA %s", database.QuoteIdentifier(schema))
return m.ExecuteSQL(ctx, sql)
}
// HasSchema checks if a schema exists.
func (m *Migration) HasSchema(ctx context.Context, schema string) (bool, error) {
query := fmt.Sprintf(
"SELECT COUNT(*) FROM INFORMATION_SCHEMA.SCHEMATA WHERE SCHEMA_NAME = '%s'",
schema,
)
value, err := m.GetDB().GetValue(ctx, query)
if err != nil {
return false, err
}
return value.Int() > 0, nil
}
// formatDefaultValue formats the default value for SQL.
func formatDefaultValue(value any) string {
switch v := value.(type) {
case string:
return fmt.Sprintf("'%s'", escapeString(v))
case int, int8, int16, int32, int64, uint, uint8, uint16, uint32, uint64:
return fmt.Sprintf("%d", v)
case float32, float64:
return fmt.Sprintf("%f", v)
case bool:
if v {
return "1"
}
return "0"
case nil:
return "NULL"
default:
return fmt.Sprintf("'%v'", v)
}
}
// escapeString escapes special characters in strings for SQL.
func escapeString(s string) string {
s = strings.ReplaceAll(s, "'", "''")
return s
}

View File

@ -52,3 +52,13 @@ func (d *Driver) New(core *database.Core, node *database.ConfigNode) (database.D
func (d *Driver) GetChars() (charLeft string, charRight string) {
return quoteChar, quoteChar
}
// Migration returns a Migration instance for MySQL database operations.
func (d *Driver) Migration() *Migration {
return NewMigration(d)
}
// GetMigration returns a Migration instance implementing the Migration interface.
func (d *Driver) GetMigration() database.Migration {
return d.Migration()
}

View File

@ -0,0 +1,365 @@
// Copyright GoFrame Author(https://goframe.org). All Rights Reserved.
//
// This Source Code Form is subject to the terms of the MIT License.
// If a copy of the MIT was not distributed with this file,
// You can obtain one at https://github.com/gogf/gf.
package mysql
import (
"context"
"fmt"
"strings"
"git.magicany.cc/black1552/gin-base/database"
)
// Migration implements database migration operations for MySQL.
type Migration struct {
*database.MigrationCore
*database.AutoMigrateCore
}
// NewMigration creates a new MySQL Migration instance.
func NewMigration(db database.DB) *Migration {
return &Migration{
MigrationCore: database.NewMigrationCore(db),
AutoMigrateCore: database.NewAutoMigrateCore(db),
}
}
// CreateTable creates a new table with the given name and column definitions.
func (m *Migration) CreateTable(ctx context.Context, table string, columns map[string]*database.ColumnDefinition, options ...database.TableOption) error {
if len(columns) == 0 {
return fmt.Errorf("cannot create table without columns")
}
var opts database.TableOptions
for _, opt := range options {
opt(&opts)
}
var sql strings.Builder
sql.WriteString("CREATE TABLE ")
if opts.IfNotExists {
sql.WriteString("IF NOT EXISTS ")
}
sql.WriteString(database.QuoteIdentifier(table))
sql.WriteString(" (\n")
// Add columns
var colDefs []string
var primaryKeys []string
for name, def := range columns {
colDef := m.BuildColumnDefinition(name, def)
if def.PrimaryKey {
primaryKeys = append(primaryKeys, database.QuoteIdentifier(name))
}
colDefs = append(colDefs, " "+colDef)
}
// Add primary key constraint if needed
if len(primaryKeys) > 0 {
colDefs = append(colDefs, fmt.Sprintf(" PRIMARY KEY (%s)", strings.Join(primaryKeys, ", ")))
}
sql.WriteString(strings.Join(colDefs, ",\n"))
sql.WriteString("\n)")
// Add table options
if opts.Engine != "" {
sql.WriteString(fmt.Sprintf(" ENGINE=%s", opts.Engine))
}
if opts.Charset != "" {
sql.WriteString(fmt.Sprintf(" DEFAULT CHARSET=%s", opts.Charset))
}
if opts.Collation != "" {
sql.WriteString(fmt.Sprintf(" COLLATE=%s", opts.Collation))
}
if opts.Comment != "" {
sql.WriteString(fmt.Sprintf(" COMMENT='%s'", escapeString(opts.Comment)))
}
return m.ExecuteSQL(ctx, sql.String())
}
// DropTable drops an existing table from the database.
func (m *Migration) DropTable(ctx context.Context, table string, ifExists ...bool) error {
sql := "DROP TABLE "
if len(ifExists) > 0 && ifExists[0] {
sql += "IF EXISTS "
}
sql += database.QuoteIdentifier(table)
return m.ExecuteSQL(ctx, sql)
}
// HasTable checks if a table exists in the database.
func (m *Migration) HasTable(ctx context.Context, table string) (bool, error) {
schema := m.GetDB().GetSchema()
if schema == "" {
schema = "DATABASE()"
} else {
schema = fmt.Sprintf("'%s'", schema)
}
query := fmt.Sprintf(
"SELECT COUNT(*) FROM information_schema.tables WHERE table_schema = %s AND table_name = '%s'",
schema,
table,
)
value, err := m.GetDB().GetValue(ctx, query)
if err != nil {
return false, err
}
return value.Int() > 0, nil
}
// RenameTable renames an existing table from oldName to newName.
func (m *Migration) RenameTable(ctx context.Context, oldName, newName string) error {
sql := fmt.Sprintf(
"RENAME TABLE %s TO %s",
database.QuoteIdentifier(oldName),
database.QuoteIdentifier(newName),
)
return m.ExecuteSQL(ctx, sql)
}
// TruncateTable removes all records from a table but keeps the table structure.
func (m *Migration) TruncateTable(ctx context.Context, table string) error {
sql := fmt.Sprintf("TRUNCATE TABLE %s", database.QuoteIdentifier(table))
return m.ExecuteSQL(ctx, sql)
}
// AddColumn adds a new column to an existing table.
func (m *Migration) AddColumn(ctx context.Context, table, column string, definition *database.ColumnDefinition) error {
colDef := m.BuildColumnDefinition(column, definition)
sql := fmt.Sprintf("ALTER TABLE %s ADD COLUMN %s", database.QuoteIdentifier(table), colDef)
return m.ExecuteSQL(ctx, sql)
}
// DropColumn removes a column from an existing table.
func (m *Migration) DropColumn(ctx context.Context, table, column string) error {
sql := fmt.Sprintf(
"ALTER TABLE %s DROP COLUMN %s",
database.QuoteIdentifier(table),
database.QuoteIdentifier(column),
)
return m.ExecuteSQL(ctx, sql)
}
// RenameColumn renames a column in an existing table.
func (m *Migration) RenameColumn(ctx context.Context, table, oldName, newName string) error {
// MySQL requires the full column definition when renaming
fields, err := m.GetDB().TableFields(ctx, table)
if err != nil {
return err
}
field, ok := fields[oldName]
if !ok {
return fmt.Errorf("column %s does not exist in table %s", oldName, table)
}
def := &database.ColumnDefinition{
Type: field.Type,
Null: field.Null,
Default: field.Default,
Comment: field.Comment,
AutoIncrement: strings.Contains(field.Extra, "auto_increment"),
}
colDef := m.BuildColumnDefinition(newName, def)
sql := fmt.Sprintf(
"ALTER TABLE %s CHANGE COLUMN %s %s",
database.QuoteIdentifier(table),
database.QuoteIdentifier(oldName),
colDef,
)
return m.ExecuteSQL(ctx, sql)
}
// ModifyColumn modifies an existing column's definition.
func (m *Migration) ModifyColumn(ctx context.Context, table, column string, definition *database.ColumnDefinition) error {
colDef := m.BuildColumnDefinition(column, definition)
sql := fmt.Sprintf("ALTER TABLE %s MODIFY COLUMN %s", database.QuoteIdentifier(table), colDef)
return m.ExecuteSQL(ctx, sql)
}
// HasColumn checks if a column exists in a table.
func (m *Migration) HasColumn(ctx context.Context, table, column string) (bool, error) {
fields, err := m.GetDB().TableFields(ctx, table)
if err != nil {
return false, err
}
_, exists := fields[column]
return exists, nil
}
// CreateIndex creates a new index on the specified table and columns.
func (m *Migration) CreateIndex(ctx context.Context, table, index string, columns []string, options ...database.IndexOption) error {
var opts database.IndexOptions
for _, opt := range options {
opt(&opts)
}
var sql strings.Builder
sql.WriteString("CREATE ")
if opts.Unique {
sql.WriteString("UNIQUE ")
}
if opts.FullText {
sql.WriteString("FULLTEXT ")
}
if opts.Spatial {
sql.WriteString("SPATIAL ")
}
sql.WriteString("INDEX ")
sql.WriteString(database.QuoteIdentifier(index))
sql.WriteString(" ON ")
sql.WriteString(database.QuoteIdentifier(table))
colList := m.BuildIndexColumns(columns)
sql.WriteString(fmt.Sprintf(" (%s)", colList))
if opts.Using != "" {
sql.WriteString(fmt.Sprintf(" USING %s", opts.Using))
}
return m.ExecuteSQL(ctx, sql.String())
}
// DropIndex drops an existing index from a table.
func (m *Migration) DropIndex(ctx context.Context, table, index string) error {
sql := fmt.Sprintf(
"DROP INDEX %s ON %s",
database.QuoteIdentifier(index),
database.QuoteIdentifier(table),
)
return m.ExecuteSQL(ctx, sql)
}
// HasIndex checks if an index exists on a table.
func (m *Migration) HasIndex(ctx context.Context, table, index string) (bool, error) {
schema := m.GetDB().GetSchema()
if schema == "" {
schema = "DATABASE()"
} else {
schema = fmt.Sprintf("'%s'", schema)
}
query := fmt.Sprintf(
"SELECT COUNT(*) FROM information_schema.statistics WHERE table_schema = %s AND table_name = '%s' AND index_name = '%s'",
schema,
table,
index,
)
value, err := m.GetDB().GetValue(ctx, query)
if err != nil {
return false, err
}
return value.Int() > 0, nil
}
// CreateForeignKey creates a foreign key constraint.
func (m *Migration) CreateForeignKey(ctx context.Context, table, constraint string, columns []string, refTable string, refColumns []string, options ...database.ForeignKeyOption) error {
var opts database.ForeignKeyOptions
for _, opt := range options {
opt(&opts)
}
sql := fmt.Sprintf(
"ALTER TABLE %s ADD CONSTRAINT %s FOREIGN KEY %s REFERENCES %s %s",
database.QuoteIdentifier(table),
database.QuoteIdentifier(constraint),
m.BuildForeignKeyColumns(columns),
database.QuoteIdentifier(refTable),
m.BuildForeignKeyColumns(refColumns),
)
if opts.OnDelete != "" {
sql += fmt.Sprintf(" ON DELETE %s", opts.OnDelete)
}
if opts.OnUpdate != "" {
sql += fmt.Sprintf(" ON UPDATE %s", opts.OnUpdate)
}
return m.ExecuteSQL(ctx, sql)
}
// DropForeignKey drops a foreign key constraint.
func (m *Migration) DropForeignKey(ctx context.Context, table, constraint string) error {
sql := fmt.Sprintf(
"ALTER TABLE %s DROP FOREIGN KEY %s",
database.QuoteIdentifier(table),
database.QuoteIdentifier(constraint),
)
return m.ExecuteSQL(ctx, sql)
}
// HasForeignKey checks if a foreign key constraint exists.
func (m *Migration) HasForeignKey(ctx context.Context, table, constraint string) (bool, error) {
schema := m.GetDB().GetSchema()
if schema == "" {
schema = "DATABASE()"
} else {
schema = fmt.Sprintf("'%s'", schema)
}
query := fmt.Sprintf(
"SELECT COUNT(*) FROM information_schema.table_constraints WHERE constraint_schema = %s AND table_name = '%s' AND constraint_name = '%s' AND constraint_type = 'FOREIGN KEY'",
schema,
table,
constraint,
)
value, err := m.GetDB().GetValue(ctx, query)
if err != nil {
return false, err
}
return value.Int() > 0, nil
}
// CreateSchema creates a new database schema.
func (m *Migration) CreateSchema(ctx context.Context, schema string) error {
sql := fmt.Sprintf("CREATE DATABASE IF NOT EXISTS %s", database.QuoteIdentifier(schema))
return m.ExecuteSQL(ctx, sql)
}
// DropSchema drops an existing database schema.
func (m *Migration) DropSchema(ctx context.Context, schema string, cascade ...bool) error {
sql := "DROP DATABASE "
if len(cascade) > 0 && cascade[0] {
sql += "IF EXISTS "
}
sql += database.QuoteIdentifier(schema)
return m.ExecuteSQL(ctx, sql)
}
// HasSchema checks if a schema exists.
func (m *Migration) HasSchema(ctx context.Context, schema string) (bool, error) {
query := fmt.Sprintf(
"SELECT COUNT(*) FROM information_schema.schemata WHERE schema_name = '%s'",
schema,
)
value, err := m.GetDB().GetValue(ctx, query)
if err != nil {
return false, err
}
return value.Int() > 0, nil
}
// escapeString escapes special characters in strings for SQL.
func escapeString(s string) string {
s = strings.ReplaceAll(s, "'", "''")
s = strings.ReplaceAll(s, "\\", "\\\\")
return s
}

View File

@ -44,3 +44,13 @@ func (d *Driver) New(core *database.Core, node *database.ConfigNode) (database.D
func (d *Driver) GetChars() (charLeft string, charRight string) {
return quoteChar, quoteChar
}
// Migration returns a Migration instance for Oracle database operations.
func (d *Driver) Migration() *Migration {
return NewMigration(d)
}
// GetMigration returns a Migration instance implementing the Migration interface.
func (d *Driver) GetMigration() database.Migration {
return d.Migration()
}

View File

@ -0,0 +1,351 @@
// Copyright GoFrame Author(https://goframe.org). All Rights Reserved.
//
// This Source Code Form is subject to the terms of the MIT License.
// If a copy of the MIT was not distributed with this file,
// You can obtain one at https://github.com/gogf/gf.
package oracle
import (
"context"
"fmt"
"strings"
"git.magicany.cc/black1552/gin-base/database"
)
// Migration implements database migration operations for Oracle.
type Migration struct {
*database.MigrationCore
*database.AutoMigrateCore
}
// NewMigration creates a new Oracle Migration instance.
func NewMigration(db database.DB) *Migration {
return &Migration{
MigrationCore: database.NewMigrationCore(db),
AutoMigrateCore: database.NewAutoMigrateCore(db),
}
}
// CreateTable creates a new table with the given name and column definitions.
func (m *Migration) CreateTable(ctx context.Context, table string, columns map[string]*database.ColumnDefinition, options ...database.TableOption) error {
if len(columns) == 0 {
return fmt.Errorf("cannot create table without columns")
}
var opts database.TableOptions
for _, opt := range options {
opt(&opts)
}
var sql strings.Builder
sql.WriteString("CREATE TABLE ")
sql.WriteString(database.QuoteIdentifierDouble(table))
sql.WriteString(" (\n")
// Add columns
var colDefs []string
var primaryKeys []string
for name, def := range columns {
colDef := m.buildColumnDefinition(name, def)
if def.PrimaryKey {
primaryKeys = append(primaryKeys, database.QuoteIdentifierDouble(name))
}
colDefs = append(colDefs, " "+colDef)
}
// Add primary key constraint if needed
if len(primaryKeys) > 0 {
colDefs = append(colDefs, fmt.Sprintf(" PRIMARY KEY (%s)", strings.Join(primaryKeys, ", ")))
}
sql.WriteString(strings.Join(colDefs, ",\n"))
sql.WriteString("\n)")
return m.ExecuteSQL(ctx, sql.String())
}
// buildColumnDefinition builds column definition for Oracle.
func (m *Migration) buildColumnDefinition(name string, def *database.ColumnDefinition) string {
var parts []string
parts = append(parts, database.QuoteIdentifierDouble(name))
// Handle Oracle-specific types
dbType := def.Type
parts = append(parts, dbType)
if !def.Null {
parts = append(parts, "NOT NULL")
}
if def.Default != nil {
defaultValue := formatDefaultValue(def.Default)
parts = append(parts, fmt.Sprintf("DEFAULT %s", defaultValue))
}
return strings.Join(parts, " ")
}
// DropTable drops an existing table from the database.
func (m *Migration) DropTable(ctx context.Context, table string, ifExists ...bool) error {
if len(ifExists) > 0 && ifExists[0] {
sql := fmt.Sprintf(
"BEGIN EXECUTE IMMEDIATE 'DROP TABLE %s'; EXCEPTION WHEN OTHERS THEN IF SQLCODE != -942 THEN RAISE; END IF; END;",
database.QuoteIdentifierDouble(table),
)
return m.ExecuteSQL(ctx, sql)
}
sql := fmt.Sprintf("DROP TABLE %s", database.QuoteIdentifierDouble(table))
return m.ExecuteSQL(ctx, sql)
}
// HasTable checks if a table exists in the database.
func (m *Migration) HasTable(ctx context.Context, table string) (bool, error) {
query := fmt.Sprintf(
"SELECT COUNT(*) FROM USER_TABLES WHERE TABLE_NAME = '%s'",
strings.ToUpper(table),
)
value, err := m.GetDB().GetValue(ctx, query)
if err != nil {
return false, err
}
return value.Int() > 0, nil
}
// RenameTable renames an existing table from oldName to newName.
func (m *Migration) RenameTable(ctx context.Context, oldName, newName string) error {
sql := fmt.Sprintf(
"ALTER TABLE %s RENAME TO %s",
database.QuoteIdentifierDouble(oldName),
database.QuoteIdentifierDouble(newName),
)
return m.ExecuteSQL(ctx, sql)
}
// TruncateTable removes all records from a table but keeps the table structure.
func (m *Migration) TruncateTable(ctx context.Context, table string) error {
sql := fmt.Sprintf("TRUNCATE TABLE %s", database.QuoteIdentifierDouble(table))
return m.ExecuteSQL(ctx, sql)
}
// AddColumn adds a new column to an existing table.
func (m *Migration) AddColumn(ctx context.Context, table, column string, definition *database.ColumnDefinition) error {
colDef := m.buildColumnDefinition(column, definition)
sql := fmt.Sprintf("ALTER TABLE %s ADD %s", database.QuoteIdentifierDouble(table), colDef)
return m.ExecuteSQL(ctx, sql)
}
// DropColumn removes a column from an existing table.
func (m *Migration) DropColumn(ctx context.Context, table, column string) error {
sql := fmt.Sprintf(
"ALTER TABLE %s DROP COLUMN %s",
database.QuoteIdentifierDouble(table),
database.QuoteIdentifierDouble(column),
)
return m.ExecuteSQL(ctx, sql)
}
// RenameColumn renames a column in an existing table.
func (m *Migration) RenameColumn(ctx context.Context, table, oldName, newName string) error {
sql := fmt.Sprintf(
"ALTER TABLE %s RENAME COLUMN %s TO %s",
database.QuoteIdentifierDouble(table),
database.QuoteIdentifierDouble(oldName),
database.QuoteIdentifierDouble(newName),
)
return m.ExecuteSQL(ctx, sql)
}
// ModifyColumn modifies an existing column's definition.
func (m *Migration) ModifyColumn(ctx context.Context, table, column string, definition *database.ColumnDefinition) error {
colDef := m.buildColumnDefinition(column, definition)
sql := fmt.Sprintf("ALTER TABLE %s MODIFY %s", database.QuoteIdentifierDouble(table), colDef)
return m.ExecuteSQL(ctx, sql)
}
// HasColumn checks if a column exists in a table.
func (m *Migration) HasColumn(ctx context.Context, table, column string) (bool, error) {
fields, err := m.GetDB().TableFields(ctx, table)
if err != nil {
return false, err
}
_, exists := fields[column]
return exists, nil
}
// CreateIndex creates a new index on the specified table and columns.
func (m *Migration) CreateIndex(ctx context.Context, table, index string, columns []string, options ...database.IndexOption) error {
var opts database.IndexOptions
for _, opt := range options {
opt(&opts)
}
var sql strings.Builder
sql.WriteString("CREATE ")
if opts.Unique {
sql.WriteString("UNIQUE ")
}
sql.WriteString("INDEX ")
sql.WriteString(database.QuoteIdentifierDouble(index))
sql.WriteString(" ON ")
sql.WriteString(database.QuoteIdentifierDouble(table))
colList := m.BuildIndexColumnsWithDouble(columns)
sql.WriteString(fmt.Sprintf(" (%s)", colList))
return m.ExecuteSQL(ctx, sql.String())
}
// BuildIndexColumnsWithDouble builds the column list for index creation using double quotes.
func (m *Migration) BuildIndexColumnsWithDouble(columns []string) string {
quoted := make([]string, len(columns))
for i, col := range columns {
quoted[i] = database.QuoteIdentifierDouble(col)
}
return strings.Join(quoted, ", ")
}
// DropIndex drops an existing index from a table.
func (m *Migration) DropIndex(ctx context.Context, table, index string) error {
sql := fmt.Sprintf("DROP INDEX %s", database.QuoteIdentifierDouble(index))
return m.ExecuteSQL(ctx, sql)
}
// HasIndex checks if an index exists on a table.
func (m *Migration) HasIndex(ctx context.Context, table, index string) (bool, error) {
query := fmt.Sprintf(
"SELECT COUNT(*) FROM USER_INDEXES WHERE INDEX_NAME = '%s' AND TABLE_NAME = '%s'",
strings.ToUpper(index),
strings.ToUpper(table),
)
value, err := m.GetDB().GetValue(ctx, query)
if err != nil {
return false, err
}
return value.Int() > 0, nil
}
// CreateForeignKey creates a foreign key constraint.
func (m *Migration) CreateForeignKey(ctx context.Context, table, constraint string, columns []string, refTable string, refColumns []string, options ...database.ForeignKeyOption) error {
var opts database.ForeignKeyOptions
for _, opt := range options {
opt(&opts)
}
sql := fmt.Sprintf(
"ALTER TABLE %s ADD CONSTRAINT %s FOREIGN KEY %s REFERENCES %s %s",
database.QuoteIdentifierDouble(table),
database.QuoteIdentifierDouble(constraint),
m.BuildForeignKeyColumnsWithDouble(columns),
database.QuoteIdentifierDouble(refTable),
m.BuildForeignKeyColumnsWithDouble(refColumns),
)
if opts.OnDelete != "" {
sql += fmt.Sprintf(" ON DELETE %s", opts.OnDelete)
}
return m.ExecuteSQL(ctx, sql)
}
// BuildForeignKeyColumnsWithDouble builds the column list for foreign key using double quotes.
func (m *Migration) BuildForeignKeyColumnsWithDouble(columns []string) string {
quoted := make([]string, len(columns))
for i, col := range columns {
quoted[i] = database.QuoteIdentifierDouble(col)
}
return "(" + strings.Join(quoted, ", ") + ")"
}
// DropForeignKey drops a foreign key constraint.
func (m *Migration) DropForeignKey(ctx context.Context, table, constraint string) error {
sql := fmt.Sprintf(
"ALTER TABLE %s DROP CONSTRAINT %s",
database.QuoteIdentifierDouble(table),
database.QuoteIdentifierDouble(constraint),
)
return m.ExecuteSQL(ctx, sql)
}
// HasForeignKey checks if a foreign key constraint exists.
func (m *Migration) HasForeignKey(ctx context.Context, table, constraint string) (bool, error) {
query := fmt.Sprintf(
"SELECT COUNT(*) FROM USER_CONSTRAINTS WHERE CONSTRAINT_NAME = '%s' AND TABLE_NAME = '%s' AND CONSTRAINT_TYPE = 'R'",
strings.ToUpper(constraint),
strings.ToUpper(table),
)
value, err := m.GetDB().GetValue(ctx, query)
if err != nil {
return false, err
}
return value.Int() > 0, nil
}
// CreateSchema creates a new database schema (user in Oracle).
func (m *Migration) CreateSchema(ctx context.Context, schema string) error {
// In Oracle, schema is equivalent to user
sql := fmt.Sprintf("CREATE USER %s IDENTIFIED BY password", database.QuoteIdentifierDouble(schema))
return m.ExecuteSQL(ctx, sql)
}
// DropSchema drops an existing database schema (user in Oracle).
func (m *Migration) DropSchema(ctx context.Context, schema string, cascade ...bool) error {
sql := "DROP USER "
if len(cascade) > 0 && cascade[0] {
sql += database.QuoteIdentifierDouble(schema) + " CASCADE"
} else {
sql += database.QuoteIdentifierDouble(schema)
}
return m.ExecuteSQL(ctx, sql)
}
// HasSchema checks if a schema exists.
func (m *Migration) HasSchema(ctx context.Context, schema string) (bool, error) {
query := fmt.Sprintf(
"SELECT COUNT(*) FROM ALL_USERS WHERE USERNAME = '%s'",
strings.ToUpper(schema),
)
value, err := m.GetDB().GetValue(ctx, query)
if err != nil {
return false, err
}
return value.Int() > 0, nil
}
// formatDefaultValue formats the default value for SQL.
func formatDefaultValue(value any) string {
switch v := value.(type) {
case string:
return fmt.Sprintf("'%s'", escapeString(v))
case int, int8, int16, int32, int64, uint, uint8, uint16, uint32, uint64:
return fmt.Sprintf("%d", v)
case float32, float64:
return fmt.Sprintf("%f", v)
case bool:
if v {
return "1"
}
return "0"
case nil:
return "NULL"
default:
return fmt.Sprintf("'%v'", v)
}
}
// escapeString escapes special characters in strings for SQL.
func escapeString(s string) string {
s = strings.ReplaceAll(s, "'", "''")
return s
}

View File

@ -48,3 +48,13 @@ func (d *Driver) New(core *database.Core, node *database.ConfigNode) (database.D
func (d *Driver) GetChars() (charLeft string, charRight string) {
return quoteChar, quoteChar
}
// Migration returns a Migration instance for PostgreSQL database operations.
func (d *Driver) Migration() *Migration {
return NewMigration(d)
}
// GetMigration returns a Migration instance implementing the Migration interface.
func (d *Driver) GetMigration() database.Migration {
return d.Migration()
}

View File

@ -0,0 +1,481 @@
// Copyright GoFrame Author(https://goframe.org). All Rights Reserved.
//
// This Source Code Form is subject to the terms of the MIT License.
// If a copy of the MIT was not distributed with this file,
// You can obtain one at https://github.com/gogf/gf.
package pgsql
import (
"context"
"fmt"
"strings"
"git.magicany.cc/black1552/gin-base/database"
)
// Migration implements database migration operations for PostgreSQL.
type Migration struct {
*database.MigrationCore
*database.AutoMigrateCore
}
// NewMigration creates a new PostgreSQL Migration instance.
func NewMigration(db database.DB) *Migration {
return &Migration{
MigrationCore: database.NewMigrationCore(db),
AutoMigrateCore: database.NewAutoMigrateCore(db),
}
}
// CreateTable creates a new table with the given name and column definitions.
func (m *Migration) CreateTable(ctx context.Context, table string, columns map[string]*database.ColumnDefinition, options ...database.TableOption) error {
if len(columns) == 0 {
return fmt.Errorf("cannot create table without columns")
}
var opts database.TableOptions
for _, opt := range options {
opt(&opts)
}
var sql strings.Builder
sql.WriteString("CREATE TABLE ")
if opts.IfNotExists {
sql.WriteString("IF NOT EXISTS ")
}
sql.WriteString(database.QuoteIdentifierDouble(table))
sql.WriteString(" (\n")
// Add columns
var colDefs []string
var primaryKeys []string
for name, def := range columns {
colDef := m.buildColumnDefinition(name, def)
if def.PrimaryKey {
primaryKeys = append(primaryKeys, database.QuoteIdentifierDouble(name))
}
colDefs = append(colDefs, " "+colDef)
}
// Add primary key constraint if needed
if len(primaryKeys) > 0 {
colDefs = append(colDefs, fmt.Sprintf(" PRIMARY KEY (%s)", strings.Join(primaryKeys, ", ")))
}
sql.WriteString(strings.Join(colDefs, ",\n"))
sql.WriteString("\n)")
// Add table comment if provided
if opts.Comment != "" {
commentSQL := fmt.Sprintf(
"COMMENT ON TABLE %s IS '%s'",
database.QuoteIdentifierDouble(table),
escapeString(opts.Comment),
)
if err := m.ExecuteSQL(ctx, commentSQL); err != nil {
return err
}
}
return m.ExecuteSQL(ctx, sql.String())
}
// buildColumnDefinition builds column definition for PostgreSQL.
func (m *Migration) buildColumnDefinition(name string, def *database.ColumnDefinition) string {
var parts []string
parts = append(parts, database.QuoteIdentifierDouble(name))
// Handle PostgreSQL-specific types
dbType := def.Type
if def.AutoIncrement {
if dbType == "INT" || dbType == "INTEGER" {
dbType = "SERIAL"
} else if dbType == "BIGINT" {
dbType = "BIGSERIAL"
}
}
parts = append(parts, dbType)
if def.PrimaryKey {
// Primary key is handled separately
} else {
if !def.Null {
parts = append(parts, "NOT NULL")
}
if def.Unique {
parts = append(parts, "UNIQUE")
}
if def.Default != nil {
defaultValue := formatDefaultValue(def.Default)
parts = append(parts, fmt.Sprintf("DEFAULT %s", defaultValue))
}
}
return strings.Join(parts, " ")
}
// DropTable drops an existing table from the database.
func (m *Migration) DropTable(ctx context.Context, table string, ifExists ...bool) error {
sql := "DROP TABLE "
if len(ifExists) > 0 && ifExists[0] {
sql += "IF EXISTS "
}
sql += database.QuoteIdentifierDouble(table)
return m.ExecuteSQL(ctx, sql)
}
// HasTable checks if a table exists in the database.
func (m *Migration) HasTable(ctx context.Context, table string) (bool, error) {
schema := m.GetDB().GetSchema()
if schema == "" {
schema = "current_schema()"
} else {
schema = fmt.Sprintf("'%s'", schema)
}
query := fmt.Sprintf(
"SELECT COUNT(*) FROM information_schema.tables WHERE table_schema = %s AND table_name = '%s'",
schema,
table,
)
value, err := m.GetDB().GetValue(ctx, query)
if err != nil {
return false, err
}
return value.Int() > 0, nil
}
// RenameTable renames an existing table from oldName to newName.
func (m *Migration) RenameTable(ctx context.Context, oldName, newName string) error {
sql := fmt.Sprintf(
"ALTER TABLE %s RENAME TO %s",
database.QuoteIdentifierDouble(oldName),
database.QuoteIdentifierDouble(newName),
)
return m.ExecuteSQL(ctx, sql)
}
// TruncateTable removes all records from a table but keeps the table structure.
func (m *Migration) TruncateTable(ctx context.Context, table string) error {
sql := fmt.Sprintf("TRUNCATE TABLE %s", database.QuoteIdentifierDouble(table))
return m.ExecuteSQL(ctx, sql)
}
// AddColumn adds a new column to an existing table.
func (m *Migration) AddColumn(ctx context.Context, table, column string, definition *database.ColumnDefinition) error {
colDef := m.buildColumnDefinition(column, definition)
sql := fmt.Sprintf("ALTER TABLE %s ADD COLUMN %s", database.QuoteIdentifierDouble(table), colDef)
return m.ExecuteSQL(ctx, sql)
}
// DropColumn removes a column from an existing table.
func (m *Migration) DropColumn(ctx context.Context, table, column string) error {
sql := fmt.Sprintf(
"ALTER TABLE %s DROP COLUMN %s",
database.QuoteIdentifierDouble(table),
database.QuoteIdentifierDouble(column),
)
return m.ExecuteSQL(ctx, sql)
}
// RenameColumn renames a column in an existing table.
func (m *Migration) RenameColumn(ctx context.Context, table, oldName, newName string) error {
sql := fmt.Sprintf(
"ALTER TABLE %s RENAME COLUMN %s TO %s",
database.QuoteIdentifierDouble(table),
database.QuoteIdentifierDouble(oldName),
database.QuoteIdentifierDouble(newName),
)
return m.ExecuteSQL(ctx, sql)
}
// ModifyColumn modifies an existing column's definition.
func (m *Migration) ModifyColumn(ctx context.Context, table, column string, definition *database.ColumnDefinition) error {
// PostgreSQL requires multiple ALTER statements for different modifications
var statements []string
if definition.Type != "" {
statements = append(statements, fmt.Sprintf(
"ALTER TABLE %s ALTER COLUMN %s TYPE %s",
database.QuoteIdentifierDouble(table),
database.QuoteIdentifierDouble(column),
definition.Type,
))
}
if definition.Default != nil {
defaultValue := formatDefaultValue(definition.Default)
statements = append(statements, fmt.Sprintf(
"ALTER TABLE %s ALTER COLUMN %s SET DEFAULT %s",
database.QuoteIdentifierDouble(table),
database.QuoteIdentifierDouble(column),
defaultValue,
))
} else {
statements = append(statements, fmt.Sprintf(
"ALTER TABLE %s ALTER COLUMN %s DROP DEFAULT",
database.QuoteIdentifierDouble(table),
database.QuoteIdentifierDouble(column),
))
}
if !definition.Null {
statements = append(statements, fmt.Sprintf(
"ALTER TABLE %s ALTER COLUMN %s SET NOT NULL",
database.QuoteIdentifierDouble(table),
database.QuoteIdentifierDouble(column),
))
} else {
statements = append(statements, fmt.Sprintf(
"ALTER TABLE %s ALTER COLUMN %s DROP NOT NULL",
database.QuoteIdentifierDouble(table),
database.QuoteIdentifierDouble(column),
))
}
// Execute all statements
for _, stmt := range statements {
if err := m.ExecuteSQL(ctx, stmt); err != nil {
return err
}
}
return nil
}
// HasColumn checks if a column exists in a table.
func (m *Migration) HasColumn(ctx context.Context, table, column string) (bool, error) {
fields, err := m.GetDB().TableFields(ctx, table)
if err != nil {
return false, err
}
_, exists := fields[column]
return exists, nil
}
// CreateIndex creates a new index on the specified table and columns.
func (m *Migration) CreateIndex(ctx context.Context, table, index string, columns []string, options ...database.IndexOption) error {
var opts database.IndexOptions
for _, opt := range options {
opt(&opts)
}
var sql strings.Builder
sql.WriteString("CREATE ")
if opts.Unique {
sql.WriteString("UNIQUE ")
}
sql.WriteString("INDEX ")
sql.WriteString(database.QuoteIdentifierDouble(index))
sql.WriteString(" ON ")
sql.WriteString(database.QuoteIdentifierDouble(table))
colList := m.BuildIndexColumnsWithDouble(columns)
sql.WriteString(fmt.Sprintf(" (%s)", colList))
if opts.Using != "" {
sql.WriteString(fmt.Sprintf(" USING %s", opts.Using))
}
if opts.Comment != "" {
// Comment will be added after index creation
}
err := m.ExecuteSQL(ctx, sql.String())
if err != nil {
return err
}
// Add comment if provided
if opts.Comment != "" {
commentSQL := fmt.Sprintf(
"COMMENT ON INDEX %s IS '%s'",
database.QuoteIdentifierDouble(index),
escapeString(opts.Comment),
)
return m.ExecuteSQL(ctx, commentSQL)
}
return nil
}
// BuildIndexColumnsWithDouble builds the column list for index creation using double quotes.
func (m *Migration) BuildIndexColumnsWithDouble(columns []string) string {
quoted := make([]string, len(columns))
for i, col := range columns {
quoted[i] = database.QuoteIdentifierDouble(col)
}
return strings.Join(quoted, ", ")
}
// DropIndex drops an existing index from a table.
func (m *Migration) DropIndex(ctx context.Context, table, index string) error {
sql := fmt.Sprintf("DROP INDEX %s", database.QuoteIdentifierDouble(index))
return m.ExecuteSQL(ctx, sql)
}
// HasIndex checks if an index exists on a table.
func (m *Migration) HasIndex(ctx context.Context, table, index string) (bool, error) {
schema := m.GetDB().GetSchema()
if schema == "" {
schema = "current_schema()"
} else {
schema = fmt.Sprintf("'%s'", schema)
}
query := fmt.Sprintf(
"SELECT COUNT(*) FROM pg_indexes WHERE schemaname = %s AND tablename = '%s' AND indexname = '%s'",
schema,
table,
index,
)
value, err := m.GetDB().GetValue(ctx, query)
if err != nil {
return false, err
}
return value.Int() > 0, nil
}
// CreateForeignKey creates a foreign key constraint.
func (m *Migration) CreateForeignKey(ctx context.Context, table, constraint string, columns []string, refTable string, refColumns []string, options ...database.ForeignKeyOption) error {
var opts database.ForeignKeyOptions
for _, opt := range options {
opt(&opts)
}
sql := fmt.Sprintf(
"ALTER TABLE %s ADD CONSTRAINT %s FOREIGN KEY %s REFERENCES %s %s",
database.QuoteIdentifierDouble(table),
database.QuoteIdentifierDouble(constraint),
m.BuildForeignKeyColumnsWithDouble(columns),
database.QuoteIdentifierDouble(refTable),
m.BuildForeignKeyColumnsWithDouble(refColumns),
)
if opts.OnDelete != "" {
sql += fmt.Sprintf(" ON DELETE %s", opts.OnDelete)
}
if opts.OnUpdate != "" {
sql += fmt.Sprintf(" ON UPDATE %s", opts.OnUpdate)
}
if opts.Deferrable {
sql += " DEFERRABLE"
if opts.InitiallyDeferred {
sql += " INITIALLY DEFERRED"
}
}
return m.ExecuteSQL(ctx, sql)
}
// BuildForeignKeyColumnsWithDouble builds the column list for foreign key using double quotes.
func (m *Migration) BuildForeignKeyColumnsWithDouble(columns []string) string {
quoted := make([]string, len(columns))
for i, col := range columns {
quoted[i] = database.QuoteIdentifierDouble(col)
}
return "(" + strings.Join(quoted, ", ") + ")"
}
// DropForeignKey drops a foreign key constraint.
func (m *Migration) DropForeignKey(ctx context.Context, table, constraint string) error {
sql := fmt.Sprintf(
"ALTER TABLE %s DROP CONSTRAINT %s",
database.QuoteIdentifierDouble(table),
database.QuoteIdentifierDouble(constraint),
)
return m.ExecuteSQL(ctx, sql)
}
// HasForeignKey checks if a foreign key constraint exists.
func (m *Migration) HasForeignKey(ctx context.Context, table, constraint string) (bool, error) {
schema := m.GetDB().GetSchema()
if schema == "" {
schema = "current_schema()"
} else {
schema = fmt.Sprintf("'%s'", schema)
}
query := fmt.Sprintf(
"SELECT COUNT(*) FROM information_schema.table_constraints WHERE constraint_schema = %s AND table_name = '%s' AND constraint_name = '%s' AND constraint_type = 'FOREIGN KEY'",
schema,
table,
constraint,
)
value, err := m.GetDB().GetValue(ctx, query)
if err != nil {
return false, err
}
return value.Int() > 0, nil
}
// CreateSchema creates a new database schema.
func (m *Migration) CreateSchema(ctx context.Context, schema string) error {
sql := fmt.Sprintf("CREATE SCHEMA IF NOT EXISTS %s", database.QuoteIdentifierDouble(schema))
return m.ExecuteSQL(ctx, sql)
}
// DropSchema drops an existing database schema.
func (m *Migration) DropSchema(ctx context.Context, schema string, cascade ...bool) error {
sql := "DROP SCHEMA "
if len(cascade) > 0 && cascade[0] {
sql += "IF EXISTS "
sql += database.QuoteIdentifierDouble(schema) + " CASCADE"
} else {
sql += "IF EXISTS " + database.QuoteIdentifierDouble(schema)
}
return m.ExecuteSQL(ctx, sql)
}
// HasSchema checks if a schema exists.
func (m *Migration) HasSchema(ctx context.Context, schema string) (bool, error) {
query := fmt.Sprintf(
"SELECT COUNT(*) FROM information_schema.schemata WHERE schema_name = '%s'",
schema,
)
value, err := m.GetDB().GetValue(ctx, query)
if err != nil {
return false, err
}
return value.Int() > 0, nil
}
// formatDefaultValue formats the default value for SQL.
func formatDefaultValue(value any) string {
switch v := value.(type) {
case string:
return fmt.Sprintf("'%s'", escapeString(v))
case int, int8, int16, int32, int64, uint, uint8, uint16, uint32, uint64:
return fmt.Sprintf("%d", v)
case float32, float64:
return fmt.Sprintf("%f", v)
case bool:
if v {
return "TRUE"
}
return "FALSE"
case nil:
return "NULL"
default:
return fmt.Sprintf("'%v'", v)
}
}
// escapeString escapes special characters in strings for SQL.
func escapeString(s string) string {
s = strings.ReplaceAll(s, "'", "''")
return s
}

View File

@ -45,3 +45,13 @@ func (d *Driver) New(core *database.Core, node *database.ConfigNode) (database.D
func (d *Driver) GetChars() (charLeft string, charRight string) {
return quoteChar, quoteChar
}
// Migration returns a Migration instance for SQLite database operations.
func (d *Driver) Migration() *Migration {
return NewMigration(d)
}
// GetMigration returns a Migration instance implementing the Migration interface.
func (d *Driver) GetMigration() database.Migration {
return d.Migration()
}

View File

@ -0,0 +1,322 @@
// Copyright GoFrame Author(https://goframe.org). All Rights Reserved.
//
// This Source Code Form is subject to the terms of the MIT License.
// If a copy of the MIT was not distributed with this file,
// You can obtain one at https://github.com/gogf/gf.
package sqlite
import (
"context"
"fmt"
"strings"
"git.magicany.cc/black1552/gin-base/database"
)
// Migration implements database migration operations for SQLite.
type Migration struct {
*database.MigrationCore
*database.AutoMigrateCore
}
// NewMigration creates a new SQLite Migration instance.
func NewMigration(db database.DB) *Migration {
return &Migration{
MigrationCore: database.NewMigrationCore(db),
AutoMigrateCore: database.NewAutoMigrateCore(db),
}
}
// CreateTable creates a new table with the given name and column definitions.
func (m *Migration) CreateTable(ctx context.Context, table string, columns map[string]*database.ColumnDefinition, options ...database.TableOption) error {
if len(columns) == 0 {
return fmt.Errorf("cannot create table without columns")
}
var opts database.TableOptions
for _, opt := range options {
opt(&opts)
}
var sql strings.Builder
sql.WriteString("CREATE TABLE ")
if opts.IfNotExists {
sql.WriteString("IF NOT EXISTS ")
}
sql.WriteString(database.QuoteIdentifier(table))
sql.WriteString(" (\n")
// Add columns
var colDefs []string
var primaryKeys []string
for name, def := range columns {
colDef := m.buildColumnDefinition(name, def)
if def.PrimaryKey {
primaryKeys = append(primaryKeys, database.QuoteIdentifier(name))
}
colDefs = append(colDefs, " "+colDef)
}
// Add primary key constraint if needed
if len(primaryKeys) > 0 {
colDefs = append(colDefs, fmt.Sprintf(" PRIMARY KEY (%s)", strings.Join(primaryKeys, ", ")))
}
sql.WriteString(strings.Join(colDefs, ",\n"))
sql.WriteString("\n)")
return m.ExecuteSQL(ctx, sql.String())
}
// buildColumnDefinition builds column definition for SQLite.
func (m *Migration) buildColumnDefinition(name string, def *database.ColumnDefinition) string {
var parts []string
parts = append(parts, database.QuoteIdentifier(name))
// Handle SQLite-specific types
dbType := def.Type
if def.AutoIncrement && def.PrimaryKey {
if dbType == "INT" || dbType == "INTEGER" {
dbType = "INTEGER"
}
}
parts = append(parts, dbType)
if def.PrimaryKey && def.AutoIncrement {
parts = append(parts, "PRIMARY KEY AUTOINCREMENT")
} else {
if !def.Null {
parts = append(parts, "NOT NULL")
}
if def.Unique && !def.PrimaryKey {
parts = append(parts, "UNIQUE")
}
if def.Default != nil {
defaultValue := formatDefaultValue(def.Default)
parts = append(parts, fmt.Sprintf("DEFAULT %s", defaultValue))
}
}
return strings.Join(parts, " ")
}
// DropTable drops an existing table from the database.
func (m *Migration) DropTable(ctx context.Context, table string, ifExists ...bool) error {
sql := "DROP TABLE "
if len(ifExists) > 0 && ifExists[0] {
sql += "IF EXISTS "
}
sql += database.QuoteIdentifier(table)
return m.ExecuteSQL(ctx, sql)
}
// HasTable checks if a table exists in the database.
func (m *Migration) HasTable(ctx context.Context, table string) (bool, error) {
query := fmt.Sprintf(
"SELECT COUNT(*) FROM sqlite_master WHERE type='table' AND name='%s'",
table,
)
value, err := m.GetDB().GetValue(ctx, query)
if err != nil {
return false, err
}
return value.Int() > 0, nil
}
// RenameTable renames an existing table from oldName to newName.
func (m *Migration) RenameTable(ctx context.Context, oldName, newName string) error {
sql := fmt.Sprintf(
"ALTER TABLE %s RENAME TO %s",
database.QuoteIdentifier(oldName),
database.QuoteIdentifier(newName),
)
return m.ExecuteSQL(ctx, sql)
}
// TruncateTable removes all records from a table but keeps the table structure.
func (m *Migration) TruncateTable(ctx context.Context, table string) error {
sql := fmt.Sprintf("DELETE FROM %s", database.QuoteIdentifier(table))
return m.ExecuteSQL(ctx, sql)
}
// AddColumn adds a new column to an existing table.
func (m *Migration) AddColumn(ctx context.Context, table, column string, definition *database.ColumnDefinition) error {
colDef := m.buildColumnDefinition(column, definition)
sql := fmt.Sprintf("ALTER TABLE %s ADD COLUMN %s", database.QuoteIdentifier(table), colDef)
return m.ExecuteSQL(ctx, sql)
}
// DropColumn removes a column from an existing table.
// Note: SQLite has limited ALTER TABLE support. This may require table recreation.
func (m *Migration) DropColumn(ctx context.Context, table, column string) error {
// SQLite 3.35.0+ supports DROP COLUMN
sql := fmt.Sprintf(
"ALTER TABLE %s DROP COLUMN %s",
database.QuoteIdentifier(table),
database.QuoteIdentifier(column),
)
return m.ExecuteSQL(ctx, sql)
}
// RenameColumn renames a column in an existing table.
func (m *Migration) RenameColumn(ctx context.Context, table, oldName, newName string) error {
sql := fmt.Sprintf(
"ALTER TABLE %s RENAME COLUMN %s TO %s",
database.QuoteIdentifier(table),
database.QuoteIdentifier(oldName),
database.QuoteIdentifierDouble(newName),
)
return m.ExecuteSQL(ctx, sql)
}
// ModifyColumn modifies an existing column's definition.
// Note: SQLite requires table recreation for most column modifications.
func (m *Migration) ModifyColumn(ctx context.Context, table, column string, definition *database.ColumnDefinition) error {
// SQLite has very limited ALTER TABLE support
// This would typically require recreating the table
return fmt.Errorf("SQLite does not support MODIFY COLUMN directly, table recreation required")
}
// HasColumn checks if a column exists in a table.
func (m *Migration) HasColumn(ctx context.Context, table, column string) (bool, error) {
fields, err := m.GetDB().TableFields(ctx, table)
if err != nil {
return false, err
}
_, exists := fields[column]
return exists, nil
}
// CreateIndex creates a new index on the specified table and columns.
func (m *Migration) CreateIndex(ctx context.Context, table, index string, columns []string, options ...database.IndexOption) error {
var opts database.IndexOptions
for _, opt := range options {
opt(&opts)
}
var sql strings.Builder
sql.WriteString("CREATE ")
if opts.Unique {
sql.WriteString("UNIQUE ")
}
sql.WriteString("INDEX ")
if opts.IfNotExists {
sql.WriteString("IF NOT EXISTS ")
}
sql.WriteString(database.QuoteIdentifier(index))
sql.WriteString(" ON ")
sql.WriteString(database.QuoteIdentifier(table))
colList := m.BuildIndexColumns(columns)
sql.WriteString(fmt.Sprintf(" (%s)", colList))
return m.ExecuteSQL(ctx, sql.String())
}
// DropIndex drops an existing index from a table.
func (m *Migration) DropIndex(ctx context.Context, table, index string) error {
sql := fmt.Sprintf("DROP INDEX IF EXISTS %s", database.QuoteIdentifier(index))
return m.ExecuteSQL(ctx, sql)
}
// HasIndex checks if an index exists on a table.
func (m *Migration) HasIndex(ctx context.Context, table, index string) (bool, error) {
query := fmt.Sprintf(
"SELECT COUNT(*) FROM sqlite_master WHERE type='index' AND name='%s' AND tbl_name='%s'",
index,
table,
)
value, err := m.GetDB().GetValue(ctx, query)
if err != nil {
return false, err
}
return value.Int() > 0, nil
}
// CreateForeignKey creates a foreign key constraint.
// Note: SQLite requires foreign keys to be enabled with PRAGMA foreign_keys = ON
func (m *Migration) CreateForeignKey(ctx context.Context, table, constraint string, columns []string, refTable string, refColumns []string, options ...database.ForeignKeyOption) error {
// SQLite doesn't support adding foreign keys to existing tables
// Foreign keys must be defined during table creation
return fmt.Errorf("SQLite does not support adding foreign keys to existing tables")
}
// DropForeignKey drops a foreign key constraint.
func (m *Migration) DropForeignKey(ctx context.Context, table, constraint string) error {
// SQLite doesn't support dropping foreign keys from existing tables
return fmt.Errorf("SQLite does not support dropping foreign keys from existing tables")
}
// HasForeignKey checks if a foreign key constraint exists.
func (m *Migration) HasForeignKey(ctx context.Context, table, constraint string) (bool, error) {
// Query pragma to check foreign keys
query := fmt.Sprintf("PRAGMA foreign_key_list(%s)", database.QuoteIdentifier(table))
result, err := m.GetDB().GetAll(ctx, query)
if err != nil {
return false, err
}
// Check if constraint exists in the result
for _, row := range result {
if id, ok := row["id"]; ok && id.String() == constraint {
return true, nil
}
}
return false, nil
}
// CreateSchema is not applicable for SQLite (single database file).
func (m *Migration) CreateSchema(ctx context.Context, schema string) error {
// SQLite doesn't support schemas
return fmt.Errorf("SQLite does not support schemas")
}
// DropSchema is not applicable for SQLite (single database file).
func (m *Migration) DropSchema(ctx context.Context, schema string, cascade ...bool) error {
// SQLite doesn't support schemas
return fmt.Errorf("SQLite does not support schemas")
}
// HasSchema checks if a schema exists (always returns false for SQLite).
func (m *Migration) HasSchema(ctx context.Context, schema string) (bool, error) {
// SQLite doesn't support schemas
return false, nil
}
// formatDefaultValue formats the default value for SQL.
func formatDefaultValue(value any) string {
switch v := value.(type) {
case string:
return fmt.Sprintf("'%s'", escapeString(v))
case int, int8, int16, int32, int64, uint, uint8, uint16, uint32, uint64:
return fmt.Sprintf("%d", v)
case float32, float64:
return fmt.Sprintf("%f", v)
case bool:
if v {
return "1"
}
return "0"
case nil:
return "NULL"
default:
return fmt.Sprintf("'%v'", v)
}
}
// escapeString escapes special characters in strings for SQL.
func escapeString(s string) string {
s = strings.ReplaceAll(s, "'", "''")
return s
}

View File

@ -319,6 +319,10 @@ type DB interface {
// If no schema is specified, it uses the default schema.
Tables(ctx context.Context, schema ...string) (tables []string, err error)
// HasTable checks if a table exists in the database.
// It returns true if the table exists, false otherwise.
HasTable(ctx context.Context, table string) (bool, error)
// TableFields returns detailed information about all fields in the specified table.
// The returned map keys are field names and values contain field metadata.
TableFields(ctx context.Context, table string, schema ...string) (map[string]*TableField, error)
@ -345,6 +349,14 @@ type DB interface {
// OrderRandomFunction returns the SQL function for random ordering.
// The implementation is database-specific (e.g., RAND() for MySQL).
OrderRandomFunction() string
// ===========================================================================
// Migration support.
// ===========================================================================
// GetMigration returns a Migration instance for database schema operations.
// The returned Migration can be used to create, alter, and drop tables, columns, indexes, etc.
GetMigration() Migration
}
// TX defines the interfaces for ORM transaction operations.

View File

@ -735,7 +735,7 @@ func (c *Core) writeSqlToLogger(sql *Sql) {
}
// HasTable determine whether the table name exists in the database.
func (c *Core) HasTable(name string) (bool, error) {
func (c *Core) HasTable(ctx context.Context, name string) (bool, error) {
tables, err := c.GetTablesWithCache()
if err != nil {
return false, err

View File

@ -44,3 +44,8 @@ func (d *DriverDefault) PingMaster() error {
func (d *DriverDefault) PingSlave() error {
return nil
}
// GetMigration returns a Migration instance. For default driver, it returns nil.
func (d *DriverDefault) GetMigration() Migration {
return nil
}

View File

@ -527,7 +527,7 @@ func formatWhereHolder(ctx context.Context, db DB, in formatWhereHolderInput) (n
)
// If `Prefix` is given, it checks and retrieves the table name.
if in.Prefix != "" {
hasTable, _ := db.GetCore().HasTable(in.Prefix)
hasTable, _ := db.GetCore().HasTable(ctx, in.Prefix)
if hasTable {
in.Table = in.Prefix
} else {

304
database/gdb_migration.go Normal file
View File

@ -0,0 +1,304 @@
// Copyright GoFrame Author(https://goframe.org). All Rights Reserved.
//
// This Source Code Form is subject to the terms of the MIT License.
// If a copy of the MIT was not distributed with this file,
// You can obtain one at https://github.com/gogf/gf.
package database
import (
"context"
)
// Migration defines the interface for database migration operations.
// It provides methods for creating, altering, and dropping database schema objects.
type Migration interface {
// ===========================================================================
// Auto Migration based on Entity Structs
// ===========================================================================
// AutoMigrate automatically creates or updates tables based on entity structs.
// It analyzes the struct fields and their tags to determine column definitions.
// The entities parameter accepts struct instances or pointers to structs.
AutoMigrate(ctx context.Context, entities ...any) error
// ===========================================================================
// Table Operations
// ===========================================================================
// CreateTable creates a new table with the given name and column definitions.
// The columns parameter is a map where keys are column names and values are column definitions.
CreateTable(ctx context.Context, table string, columns map[string]*ColumnDefinition, options ...TableOption) error
// DropTable drops an existing table from the database.
// If ifExists is true, it won't return an error if the table doesn't exist.
DropTable(ctx context.Context, table string, ifExists ...bool) error
// HasTable checks if a table exists in the database.
HasTable(ctx context.Context, table string) (bool, error)
// RenameTable renames an existing table from oldName to newName.
RenameTable(ctx context.Context, oldName, newName string) error
// TruncateTable removes all records from a table but keeps the table structure.
TruncateTable(ctx context.Context, table string) error
// ===========================================================================
// Column Operations
// ===========================================================================
// AddColumn adds a new column to an existing table.
AddColumn(ctx context.Context, table, column string, definition *ColumnDefinition) error
// DropColumn removes a column from an existing table.
DropColumn(ctx context.Context, table, column string) error
// RenameColumn renames a column in an existing table.
RenameColumn(ctx context.Context, table, oldName, newName string) error
// ModifyColumn modifies an existing column's definition.
ModifyColumn(ctx context.Context, table, column string, definition *ColumnDefinition) error
// HasColumn checks if a column exists in a table.
HasColumn(ctx context.Context, table, column string) (bool, error)
// ===========================================================================
// Index Operations
// ===========================================================================
// CreateIndex creates a new index on the specified table and columns.
CreateIndex(ctx context.Context, table, index string, columns []string, options ...IndexOption) error
// DropIndex drops an existing index from a table.
DropIndex(ctx context.Context, table, index string) error
// HasIndex checks if an index exists on a table.
HasIndex(ctx context.Context, table, index string) (bool, error)
// ===========================================================================
// Foreign Key Operations
// ===========================================================================
// CreateForeignKey creates a foreign key constraint.
CreateForeignKey(ctx context.Context, table, constraint string, columns []string, refTable string, refColumns []string, options ...ForeignKeyOption) error
// DropForeignKey drops a foreign key constraint.
DropForeignKey(ctx context.Context, table, constraint string) error
// HasForeignKey checks if a foreign key constraint exists.
HasForeignKey(ctx context.Context, table, constraint string) (bool, error)
// ===========================================================================
// Schema Operations
// ===========================================================================
// CreateSchema creates a new database schema (namespace).
CreateSchema(ctx context.Context, schema string) error
// DropSchema drops an existing database schema.
DropSchema(ctx context.Context, schema string, cascade ...bool) error
// HasSchema checks if a schema exists.
HasSchema(ctx context.Context, schema string) (bool, error)
}
// ColumnDefinition defines the structure and properties of a database column.
type ColumnDefinition struct {
// Type specifies the database column type (e.g., "VARCHAR(255)", "INT", "TEXT").
Type string
// Null indicates whether the column can contain NULL values.
Null bool
// Default sets the default value for the column.
Default any
// Comment adds a comment to the column.
Comment string
// AutoIncrement enables auto-increment for numeric columns.
AutoIncrement bool
// PrimaryKey marks this column as part of the primary key.
PrimaryKey bool
// Unique marks this column as having unique values.
Unique bool
// Length specifies the length for string types.
Length int
// Precision specifies the precision for decimal/numeric types.
Precision int
// Scale specifies the scale for decimal/numeric types.
Scale int
}
// TableOption defines additional options for table creation.
type TableOption func(*TableOptions)
// TableOptions defines additional options for table creation.
type TableOptions struct {
// Engine specifies the storage engine (MySQL specific).
Engine string
// Charset specifies the character set.
Charset string
// Collation specifies the collation.
Collation string
// Comment adds a comment to the table.
Comment string
// IfNotExists prevents error if table already exists.
IfNotExists bool
}
// WithEngine sets the storage engine for the table (MySQL specific).
func WithEngine(engine string) TableOption {
return func(opts *TableOptions) {
opts.Engine = engine
}
}
// WithCharset sets the character set for the table.
func WithCharset(charset string) TableOption {
return func(opts *TableOptions) {
opts.Charset = charset
}
}
// WithCollation sets the collation for the table.
func WithCollation(collation string) TableOption {
return func(opts *TableOptions) {
opts.Collation = collation
}
}
// WithTableComment adds a comment to the table.
func WithTableComment(comment string) TableOption {
return func(opts *TableOptions) {
opts.Comment = comment
}
}
// WithIfNotExists prevents error if table already exists.
func WithIfNotExists() TableOption {
return func(opts *TableOptions) {
opts.IfNotExists = true
}
}
// IndexOption defines additional options for index creation.
type IndexOption func(*IndexOptions)
// IndexOptions defines additional options for index creation.
type IndexOptions struct {
// Unique marks the index as unique.
Unique bool
// FullText creates a full-text index (MySQL specific).
FullText bool
// Spatial creates a spatial index (MySQL specific).
Spatial bool
// Using specifies the index method (e.g., BTREE, HASH).
Using string
// Comment adds a comment to the index.
Comment string
// IfNotExists prevents error if index already exists.
IfNotExists bool
}
// WithUniqueIndex creates a unique index.
func WithUniqueIndex() IndexOption {
return func(opts *IndexOptions) {
opts.Unique = true
}
}
// WithFullTextIndex creates a full-text index (MySQL specific).
func WithFullTextIndex() IndexOption {
return func(opts *IndexOptions) {
opts.FullText = true
}
}
// WithSpatialIndex creates a spatial index (MySQL specific).
func WithSpatialIndex() IndexOption {
return func(opts *IndexOptions) {
opts.Spatial = true
}
}
// WithIndexUsing specifies the index method.
func WithIndexUsing(method string) IndexOption {
return func(opts *IndexOptions) {
opts.Using = method
}
}
// WithIndexComment adds a comment to the index.
func WithIndexComment(comment string) IndexOption {
return func(opts *IndexOptions) {
opts.Comment = comment
}
}
// WithIndexIfNotExists prevents error if index already exists.
func WithIndexIfNotExists() IndexOption {
return func(opts *IndexOptions) {
opts.IfNotExists = true
}
}
// ForeignKeyOption defines additional options for foreign key creation.
type ForeignKeyOption func(*ForeignKeyOptions)
// ForeignKeyOptions defines additional options for foreign key creation.
type ForeignKeyOptions struct {
// OnDelete specifies the action when referenced row is deleted.
OnDelete string
// OnUpdate specifies the action when referenced row is updated.
OnUpdate string
// Deferrable makes the constraint deferrable (PostgreSQL specific).
Deferrable bool
// InitiallyDeferred sets the constraint to be initially deferred (PostgreSQL specific).
InitiallyDeferred bool
}
// WithOnDelete sets the ON DELETE action for foreign key.
func WithOnDelete(action string) ForeignKeyOption {
return func(opts *ForeignKeyOptions) {
opts.OnDelete = action
}
}
// WithOnUpdate sets the ON UPDATE action for foreign key.
func WithOnUpdate(action string) ForeignKeyOption {
return func(opts *ForeignKeyOptions) {
opts.OnUpdate = action
}
}
// WithDeferrable makes the foreign key constraint deferrable (PostgreSQL specific).
func WithDeferrable() ForeignKeyOption {
return func(opts *ForeignKeyOptions) {
opts.Deferrable = true
}
}
// WithInitiallyDeferred sets the constraint to be initially deferred (PostgreSQL specific).
func WithInitiallyDeferred() ForeignKeyOption {
return func(opts *ForeignKeyOptions) {
opts.InitiallyDeferred = true
}
}

View File

@ -0,0 +1,452 @@
// Copyright GoFrame Author(https://goframe.org). All Rights Reserved.
//
// This Source Code Form is subject to the terms of the MIT License.
// If a copy of the MIT was not distributed with this file,
// You can obtain one at https://github.com/gogf/gf.
package database
import (
"context"
"fmt"
"reflect"
"strings"
"time"
)
// AutoMigrateCore provides automatic migration functionality based on entity structs.
type AutoMigrateCore struct {
db DB
}
// NewAutoMigrateCore creates a new AutoMigrateCore instance.
func NewAutoMigrateCore(db DB) *AutoMigrateCore {
return &AutoMigrateCore{db: db}
}
// AutoMigrate automatically creates or updates tables based on entity structs.
func (am *AutoMigrateCore) AutoMigrate(ctx context.Context, entities ...any) error {
for _, entity := range entities {
if err := am.migrateEntity(ctx, entity); err != nil {
return fmt.Errorf("failed to migrate entity %T: %w", entity, err)
}
}
return nil
}
// migrateEntity migrates a single entity to database.
func (am *AutoMigrateCore) migrateEntity(ctx context.Context, entity any) error {
// Get table name and columns from entity
tableName, columns, err := am.parseEntity(entity)
if err != nil {
return err
}
if len(columns) == 0 {
return fmt.Errorf("no columns found for table %s", tableName)
}
// Check if table exists
hasTable, err := am.db.HasTable(ctx, tableName)
if err != nil {
return fmt.Errorf("failed to check table existence: %w", err)
}
if !hasTable {
// Create table
return am.createTableFromColumns(ctx, tableName, columns)
}
// Update table structure
return am.updateTableStructure(ctx, tableName, columns)
}
// parseEntity parses an entity struct and returns table name and column definitions.
func (am *AutoMigrateCore) parseEntity(entity any) (string, map[string]*ColumnDefinition, error) {
val := reflect.ValueOf(entity)
// Handle pointer
if val.Kind() == reflect.Ptr {
val = val.Elem()
}
if val.Kind() != reflect.Struct {
return "", nil, fmt.Errorf("entity must be a struct or pointer to struct")
}
typ := val.Type()
// Get table name from orm tag or struct name
tableName := am.getTableName(typ)
// Parse columns
columns := make(map[string]*ColumnDefinition)
for i := 0; i < val.NumField(); i++ {
field := typ.Field(i)
fieldValue := val.Field(i)
// Skip unexported fields
if !field.IsExported() {
continue
}
// Parse field to column definition
colName, colDef, err := am.parseField(field, fieldValue)
if err != nil {
return "", nil, fmt.Errorf("failed to parse field %s: %w", field.Name, err)
}
if colName != "" && colDef != nil {
columns[colName] = colDef
}
}
return tableName, columns, nil
}
// getTableName extracts table name from struct tags or generates from struct name.
func (am *AutoMigrateCore) getTableName(typ reflect.Type) string {
// Check for orm tag
if tag, ok := typ.FieldByName("Meta"); ok {
ormTag := tag.Tag.Get("orm")
if ormTag != "" {
// Parse table:name from orm tag
parts := strings.Split(ormTag, ",")
for _, part := range parts {
part = strings.TrimSpace(part)
if strings.HasPrefix(part, "table:") {
return strings.TrimPrefix(part, "table:")
}
}
}
}
// Check for table method (commented out as it requires instance)
// method, hasMethod := typ.MethodByName("TableName")
// if hasMethod {
// // Try to call TableName method if it exists
// }
// Convert struct name to snake_case table name
return camelToSnake(typ.Name())
}
// parseField parses a struct field into a column definition.
func (am *AutoMigrateCore) parseField(field reflect.StructField, fieldValue reflect.Value) (string, *ColumnDefinition, error) {
// Check orm tag
ormTag := field.Tag.Get("orm")
if ormTag == "-" {
// Skip field
return "", nil, nil
}
// Get column name
colName := am.getColumnName(field, ormTag)
// Build column definition
colDef := &ColumnDefinition{
Type: am.getFieldType(field, fieldValue),
Null: true, // Default to nullable
}
// Parse orm tag options
if ormTag != "" {
am.parseOrmTag(colDef, ormTag)
}
// Check for gorm/tag compatibility
if gormTag := field.Tag.Get("gorm"); gormTag != "" {
am.parseGormTag(colDef, gormTag)
}
// Check json tag for field presence
jsonTag := field.Tag.Get("json")
if jsonTag == "-" {
return "", nil, nil
}
return colName, colDef, nil
}
// getColumnName extracts column name from field name or tags.
func (am *AutoMigrateCore) getColumnName(field reflect.StructField, ormTag string) string {
// Check orm tag for explicit column name
if ormTag != "" {
parts := strings.Split(ormTag, ",")
namePart := strings.TrimSpace(parts[0])
if namePart != "" && !strings.Contains(namePart, ":") {
return namePart
}
}
// Use field name converted to snake_case
return camelToSnake(field.Name)
}
// getFieldType determines the database type for a Go field type.
func (am *AutoMigrateCore) getFieldType(field reflect.StructField, fieldValue reflect.Value) string {
// Check for explicit type in orm tag
ormTag := field.Tag.Get("orm")
if ormTag != "" {
parts := strings.Split(ormTag, ",")
for _, part := range parts {
part = strings.TrimSpace(part)
if strings.HasPrefix(part, "type:") {
return strings.TrimPrefix(part, "type:")
}
}
}
// Infer type from Go type
switch fieldValue.Kind() {
case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64:
return "BIGINT"
case reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64:
return "BIGINT UNSIGNED"
case reflect.Float32:
return "FLOAT"
case reflect.Float64:
return "DOUBLE"
case reflect.Bool:
return "BOOLEAN"
case reflect.String:
// Check for length specification
length := am.getStringLength(field)
if length > 0 {
return fmt.Sprintf("VARCHAR(%d)", length)
}
return "TEXT"
case reflect.Struct:
// Handle special types
typeName := fieldValue.Type().String()
switch {
case strings.Contains(typeName, "time.Time"):
return "TIMESTAMP"
default:
return "TEXT"
}
case reflect.Slice:
elemKind := fieldValue.Type().Elem().Kind()
if elemKind == reflect.Uint8 {
return "BLOB"
}
return "JSON"
default:
return "TEXT"
}
}
// getStringLength gets the specified string length from tags.
func (am *AutoMigrateCore) getStringLength(field reflect.StructField) int {
// Check orm tag
ormTag := field.Tag.Get("orm")
if ormTag != "" {
parts := strings.Split(ormTag, ",")
for _, part := range parts {
part = strings.TrimSpace(part)
if strings.HasPrefix(part, "length:") {
var length int
fmt.Sscanf(strings.TrimPrefix(part, "length:"), "%d", &length)
return length
}
}
}
// Check gorm tag
gormTag := field.Tag.Get("gorm")
if gormTag != "" {
parts := strings.Split(gormTag, ";")
for _, part := range parts {
part = strings.TrimSpace(part)
if strings.HasPrefix(part, "size:") {
var size int
fmt.Sscanf(strings.TrimPrefix(part, "size:"), "%d", &size)
return size
}
}
}
return 0
}
// parseOrmTag parses orm tag options.
func (am *AutoMigrateCore) parseOrmTag(colDef *ColumnDefinition, ormTag string) {
parts := strings.Split(ormTag, ",")
for _, part := range parts {
part = strings.TrimSpace(part)
switch {
case part == "pk" || part == "primary_key":
colDef.PrimaryKey = true
colDef.Null = false
case part == "auto_increment":
colDef.AutoIncrement = true
case part == "not_null":
colDef.Null = false
case part == "unique":
colDef.Unique = true
case strings.HasPrefix(part, "default:"):
defaultVal := strings.TrimPrefix(part, "default:")
colDef.Default = defaultVal
case strings.HasPrefix(part, "comment:"):
colDef.Comment = strings.TrimPrefix(part, "comment:")
}
}
}
// parseGormTag parses gorm tag options for compatibility.
func (am *AutoMigrateCore) parseGormTag(colDef *ColumnDefinition, gormTag string) {
parts := strings.Split(gormTag, ";")
for _, part := range parts {
part = strings.TrimSpace(part)
lowerPart := strings.ToLower(part)
switch {
case lowerPart == "primarykey" || lowerPart == "primaryKey":
colDef.PrimaryKey = true
colDef.Null = false
case lowerPart == "autoincrement":
colDef.AutoIncrement = true
case lowerPart == "not null":
colDef.Null = false
case lowerPart == "unique":
colDef.Unique = true
case strings.HasPrefix(lowerPart, "default:"):
defaultVal := strings.TrimPrefix(part, "default:")
colDef.Default = defaultVal
case strings.HasPrefix(lowerPart, "comment:"):
colDef.Comment = strings.TrimPrefix(part, "comment:")
}
}
}
// createTableFromColumns creates a table from column definitions.
func (am *AutoMigrateCore) createTableFromColumns(ctx context.Context, table string, columns map[string]*ColumnDefinition) error {
// Get the migration instance based on database type
migration := am.getMigrationInstance()
if migration == nil {
return fmt.Errorf("failed to get migration instance for database type %s", am.db.GetConfig().Type)
}
return migration.CreateTable(ctx, table, columns)
}
// updateTableStructure updates table structure by comparing with existing columns.
func (am *AutoMigrateCore) updateTableStructure(ctx context.Context, table string, newColumns map[string]*ColumnDefinition) error {
// Get existing columns
existingFields, err := am.db.TableFields(ctx, table)
if err != nil {
return fmt.Errorf("failed to get table fields: %w", err)
}
// Add missing columns
for colName, colDef := range newColumns {
if _, exists := existingFields[colName]; !exists {
// Column doesn't exist, add it
if err := am.addColumn(ctx, table, colName, colDef); err != nil {
return fmt.Errorf("failed to add column %s: %w", colName, err)
}
} else {
// Column exists, check if modification is needed
if err := am.modifyColumnIfNeeded(ctx, table, colName, colDef, existingFields[colName]); err != nil {
return fmt.Errorf("failed to modify column %s: %w", colName, err)
}
}
}
return nil
}
// addColumn adds a new column to table.
func (am *AutoMigrateCore) addColumn(ctx context.Context, table, column string, definition *ColumnDefinition) error {
migration := am.getMigrationInstance()
if migration == nil {
return fmt.Errorf("failed to get migration instance")
}
return migration.AddColumn(ctx, table, column, definition)
}
// modifyColumnIfNeeded checks if column needs modification and applies it.
func (am *AutoMigrateCore) modifyColumnIfNeeded(ctx context.Context, table, column string, newDef *ColumnDefinition, existingField *TableField) error {
// Compare and modify if needed
needsModification := false
// Check type
if !strings.EqualFold(existingField.Type, newDef.Type) {
needsModification = true
}
// Check nullability
if existingField.Null != newDef.Null {
needsModification = true
}
if needsModification {
return am.modifyColumn(ctx, table, column, newDef)
}
return nil
}
// modifyColumn modifies an existing column.
func (am *AutoMigrateCore) modifyColumn(ctx context.Context, table, column string, definition *ColumnDefinition) error {
migration := am.getMigrationInstance()
if migration == nil {
return fmt.Errorf("failed to get migration instance")
}
return migration.ModifyColumn(ctx, table, column, definition)
}
// getMigrationInstance returns the appropriate Migration instance based on database type.
func (am *AutoMigrateCore) getMigrationInstance() Migration {
return am.db.GetMigration()
}
// camelToSnake converts CamelCase to snake_case.
func camelToSnake(s string) string {
var result strings.Builder
for i, r := range s {
if r >= 'A' && r <= 'Z' {
// Add underscore before uppercase letter if:
// 1. It's not the first character
// 2. The previous character is lowercase or the next character is lowercase
if i > 0 {
prevRune := rune(s[i-1])
if (prevRune >= 'a' && prevRune <= 'z') ||
(i+1 < len(s) && rune(s[i+1]) >= 'a' && rune(s[i+1]) <= 'z') {
result.WriteRune('_')
}
}
result.WriteRune(r + 32)
} else {
result.WriteRune(r)
}
}
return result.String()
}
// FormatTime formats time for default values.
func FormatTime(t time.Time) string {
return t.Format("2006-01-02 15:04:05")
}
// IsZeroValue checks if a value is zero value.
func IsZeroValue(v any) bool {
if v == nil {
return true
}
rv := reflect.ValueOf(v)
switch rv.Kind() {
case reflect.Array, reflect.Map, reflect.Slice, reflect.String:
return rv.Len() == 0
case reflect.Bool:
return !rv.Bool()
case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64:
return rv.Int() == 0
case reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64, reflect.Uintptr:
return rv.Uint() == 0
case reflect.Float32, reflect.Float64:
return rv.Float() == 0
case reflect.Interface, reflect.Ptr:
return rv.IsNil()
}
return false
}

View File

@ -0,0 +1,135 @@
// Copyright GoFrame Author(https://goframe.org). All Rights Reserved.
//
// This Source Code Form is subject to the terms of the MIT License.
// If a copy of the MIT was not distributed with this file,
// You can obtain one at https://github.com/gogf/gf.
package database
import (
"context"
"fmt"
"strings"
)
// MigrationCore provides the base implementation for migration operations.
// It serves as a foundation that database-specific drivers can extend.
type MigrationCore struct {
db DB // Database interface for executing SQL statements
}
// NewMigrationCore creates a new MigrationCore instance.
func NewMigrationCore(db DB) *MigrationCore {
return &MigrationCore{db: db}
}
// GetDB returns the underlying database interface.
func (m *MigrationCore) GetDB() DB {
return m.db
}
// ExecuteSQL executes a raw SQL statement.
func (m *MigrationCore) ExecuteSQL(ctx context.Context, sql string, args ...any) error {
_, err := m.db.Exec(ctx, sql, args...)
return err
}
// BuildColumnDefinition builds the column definition SQL string from ColumnDefinition.
func (m *MigrationCore) BuildColumnDefinition(name string, def *ColumnDefinition) string {
var parts []string
parts = append(parts, QuoteIdentifier(name))
parts = append(parts, def.Type)
if !def.Null {
parts = append(parts, "NOT NULL")
}
if def.AutoIncrement {
parts = append(parts, "AUTO_INCREMENT")
}
if def.PrimaryKey {
parts = append(parts, "PRIMARY KEY")
}
if def.Unique && !def.PrimaryKey {
parts = append(parts, "UNIQUE")
}
if def.Default != nil {
defaultValue := formatDefaultValue(def.Default)
parts = append(parts, fmt.Sprintf("DEFAULT %s", defaultValue))
}
return strings.Join(parts, " ")
}
// formatDefaultValue formats the default value for SQL.
func formatDefaultValue(value any) string {
switch v := value.(type) {
case string:
return fmt.Sprintf("'%s'", escapeString(v))
case int, int8, int16, int32, int64, uint, uint8, uint16, uint32, uint64:
return fmt.Sprintf("%d", v)
case float32, float64:
return fmt.Sprintf("%f", v)
case bool:
if v {
return "TRUE"
}
return "FALSE"
case nil:
return "NULL"
default:
return fmt.Sprintf("'%v'", v)
}
}
// escapeString escapes special characters in strings for SQL.
func escapeString(s string) string {
s = strings.ReplaceAll(s, "'", "''")
s = strings.ReplaceAll(s, "\\", "\\\\")
return s
}
// QuoteIdentifier quotes an identifier (table name, column name, etc.) with backticks.
func QuoteIdentifier(name string) string {
if strings.Contains(name, ".") {
parts := strings.Split(name, ".")
for i, part := range parts {
parts[i] = fmt.Sprintf("`%s`", part)
}
return strings.Join(parts, ".")
}
return fmt.Sprintf("`%s`", name)
}
// QuoteIdentifierDouble quotes an identifier with double quotes (for PostgreSQL, Oracle, etc.).
func QuoteIdentifierDouble(name string) string {
if strings.Contains(name, ".") {
parts := strings.Split(name, ".")
for i, part := range parts {
parts[i] = fmt.Sprintf(`"%s"`, part)
}
return strings.Join(parts, ".")
}
return fmt.Sprintf(`"%s"`, name)
}
// BuildIndexColumns builds the column list for index creation.
func (m *MigrationCore) BuildIndexColumns(columns []string) string {
quoted := make([]string, len(columns))
for i, col := range columns {
quoted[i] = QuoteIdentifier(col)
}
return strings.Join(quoted, ", ")
}
// BuildForeignKeyColumns builds the column list for foreign key.
func (m *MigrationCore) BuildForeignKeyColumns(columns []string) string {
quoted := make([]string, len(columns))
for i, col := range columns {
quoted[i] = QuoteIdentifier(col)
}
return "(" + strings.Join(quoted, ", ") + ")"
}

View File

@ -66,7 +66,7 @@ func (m *Model) getModel() *Model {
func (m *Model) mappingAndFilterToTableFields(table string, fields []any, filter bool) []any {
var fieldsTable = table
if fieldsTable != "" {
hasTable, _ := m.db.GetCore().HasTable(fieldsTable)
hasTable, _ := m.db.GetCore().HasTable(m.GetCtx(), fieldsTable)
if !hasTable {
if fieldsTable != m.tablesInit {
// Table/alias unknown (e.g., FieldsPrefix called before LeftJoin), skip filtering.

3
go.mod
View File

@ -22,6 +22,7 @@ require (
github.com/shopspring/decimal v1.3.1
github.com/sijms/go-ora/v2 v2.7.10
github.com/spf13/viper v1.21.0
github.com/stretchr/testify v1.11.1
go.opentelemetry.io/otel v1.42.0
go.opentelemetry.io/otel/trace v1.42.0
golang.org/x/crypto v0.49.0
@ -42,6 +43,7 @@ require (
github.com/clipperhouse/displaywidth v0.10.0 // indirect
github.com/clipperhouse/uax29/v2 v2.6.0 // indirect
github.com/cloudwego/base64x v0.1.6 // indirect
github.com/davecgh/go-spew v1.1.1 // indirect
github.com/dustin/go-humanize v1.0.1 // indirect
github.com/emirpasic/gods/v2 v2.0.0-alpha // indirect
github.com/fatih/color v1.19.0 // indirect
@ -78,6 +80,7 @@ require (
github.com/paulmach/orb v0.7.1 // indirect
github.com/pelletier/go-toml/v2 v2.3.0 // indirect
github.com/pierrec/lz4/v4 v4.1.14 // indirect
github.com/pmezard/go-difflib v1.0.0 // indirect
github.com/quic-go/qpack v0.6.0 // indirect
github.com/quic-go/quic-go v0.59.0 // indirect
github.com/remyoudompheng/bigfft v0.0.0-20230129092748-24d4a6f8daec // indirect

397
session/session.go Normal file
View File

@ -0,0 +1,397 @@
package session
import (
"encoding/json"
"fmt"
"sync"
"time"
"github.com/gin-gonic/gin"
"github.com/google/uuid"
)
// SessionManager Session管理器负责Session的创建、获取和销毁
type SessionManager struct {
sessions map[string]*Session // 存储所有活跃的Sessionkey为SessionID
mutex sync.RWMutex // 读写锁,保证并发安全
maxAge time.Duration // Session最大存活时间
}
// Session 单个Session对象用于存储用户数据
type Session struct {
ID string // Session唯一标识符
Data map[string]interface{} // 存储的数据,键值对形式
CreateTime time.Time // 创建时间
LastAccess time.Time // 最后访问时间
mutex sync.RWMutex // 读写锁,保证并发安全
}
var (
// defaultManager 默认的Session管理器实例
defaultManager *SessionManager
// once 确保只初始化一次
once sync.Once
)
// init 包初始化时自动创建默认管理器
func init() {
InitDefaultManager(30 * time.Minute) // 默认30分钟过期
}
// InitDefaultManager 初始化默认Session管理器
// 参数 duration: Session的最大存活时间
func InitDefaultManager(duration time.Duration) {
once.Do(func() {
defaultManager = NewSessionManager(duration)
})
}
// NewSessionManager 创建一个新的Session管理器
// 参数 duration: Session的最大存活时间
// 返回值: Session管理器指针
func NewSessionManager(duration time.Duration) *SessionManager {
sm := &SessionManager{
sessions: make(map[string]*Session),
maxAge: duration,
}
// 启动定时清理任务每5分钟清理一次过期的Session
go sm.startCleanupTicker(5 * time.Minute)
return sm
}
// startCleanupTicker 启动定时清理过期Session的任务
// 参数 interval: 清理间隔时间
func (sm *SessionManager) startCleanupTicker(interval time.Duration) {
ticker := time.NewTicker(interval)
defer ticker.Stop()
for range ticker.C {
sm.CleanupExpiredSessions()
}
}
// CreateSession 创建新的Session并返回Session ID
// 参数 c: Gin上下文对象
// 返回值: Session ID字符串
func (sm *SessionManager) CreateSession(c *gin.Context) string {
sessionID := uuid.New().String()
session := &Session{
ID: sessionID,
Data: make(map[string]interface{}),
CreateTime: time.Now(),
LastAccess: time.Now(),
}
sm.mutex.Lock()
sm.sessions[sessionID] = session
sm.mutex.Unlock()
// 设置Cookie保存Session ID
c.SetCookie(
"session_id", // Cookie名称
sessionID, // Cookie值
int(sm.maxAge.Seconds()), // 过期时间(秒)
"/", // 路径
"", // 域名(空表示当前域名)
false, // 是否仅HTTPS
true, // 是否HttpOnly防止XSS攻击
)
return sessionID
}
// GetSession 根据Gin上下文获取Session对象
// 参数 c: Gin上下文对象
// 返回值: Session对象指针如果不存在则返回nil
func (sm *SessionManager) GetSession(c *gin.Context) *Session {
sessionID, err := c.Cookie("session_id")
if err != nil {
return nil
}
sm.mutex.RLock()
session, exists := sm.sessions[sessionID]
sm.mutex.RUnlock()
if !exists {
return nil
}
// 检查Session是否过期
if time.Since(session.LastAccess) > sm.maxAge {
sm.DestroySession(sessionID)
return nil
}
// 更新最后访问时间
session.mutex.Lock()
session.LastAccess = time.Now()
session.mutex.Unlock()
return session
}
// GetSessionByID 根据Session ID获取Session对象
// 参数 sessionID: Session的唯一标识符
// 返回值: Session对象指针如果不存在则返回nil
func (sm *SessionManager) GetSessionByID(sessionID string) *Session {
sm.mutex.RLock()
session, exists := sm.sessions[sessionID]
sm.mutex.RUnlock()
if !exists {
return nil
}
// 检查Session是否过期
if time.Since(session.LastAccess) > sm.maxAge {
sm.DestroySession(sessionID)
return nil
}
// 更新最后访问时间
session.mutex.Lock()
session.LastAccess = time.Now()
session.mutex.Unlock()
return session
}
// DestroySession 销毁指定的Session
// 参数 sessionID: Session的唯一标识符
func (sm *SessionManager) DestroySession(sessionID string) {
sm.mutex.Lock()
delete(sm.sessions, sessionID)
sm.mutex.Unlock()
}
// DestroySessionByContext 根据Gin上下文销毁Session
// 参数 c: Gin上下文对象
func (sm *SessionManager) DestroySessionByContext(c *gin.Context) {
sessionID, err := c.Cookie("session_id")
if err != nil {
return
}
sm.DestroySession(sessionID)
// 清除Cookie
c.SetCookie(
"session_id",
"",
-1, // 立即过期
"/",
"",
false,
true,
)
}
// CleanupExpiredSessions 清理所有过期的Session
func (sm *SessionManager) CleanupExpiredSessions() {
now := time.Now()
expiredIDs := make([]string, 0)
sm.mutex.RLock()
for id, session := range sm.sessions {
session.mutex.RLock()
if now.Sub(session.LastAccess) > sm.maxAge {
expiredIDs = append(expiredIDs, id)
}
session.mutex.RUnlock()
}
sm.mutex.RUnlock()
// 删除过期的Session
sm.mutex.Lock()
for _, id := range expiredIDs {
delete(sm.sessions, id)
}
sm.mutex.Unlock()
}
// GetSessionCount 获取当前活跃的Session数量
// 返回值: Session数量
func (sm *SessionManager) GetSessionCount() int {
sm.mutex.RLock()
defer sm.mutex.RUnlock()
return len(sm.sessions)
}
// Set 在Session中设置键值对
// 参数 key: 键名
// 参数 value: 值
func (s *Session) Set(key string, value interface{}) {
s.mutex.Lock()
defer s.mutex.Unlock()
s.Data[key] = value
}
// Get 从Session中获取值
// 参数 key: 键名
// 返回值: 对应的值如果不存在则返回nil
func (s *Session) Get(key string) interface{} {
s.mutex.RLock()
defer s.mutex.RUnlock()
return s.Data[key]
}
// GetString 从Session中获取字符串类型的值
// 参数 key: 键名
// 返回值: 字符串值,如果不存在或类型不匹配则返回空字符串
func (s *Session) GetString(key string) string {
s.mutex.RLock()
defer s.mutex.RUnlock()
if val, ok := s.Data[key]; ok {
if str, ok := val.(string); ok {
return str
}
}
return ""
}
// GetInt 从Session中获取整数类型的值
// 参数 key: 键名
// 返回值: 整数值如果不存在或类型不匹配则返回0
func (s *Session) GetInt(key string) int {
s.mutex.RLock()
defer s.mutex.RUnlock()
if val, ok := s.Data[key]; ok {
switch v := val.(type) {
case int:
return v
case float64: // JSON解码时数字会变成float64
return int(v)
}
}
return 0
}
// GetFloat64 从Session中获取浮点数类型的值
// 参数 key: 键名
// 返回值: 浮点数值如果不存在或类型不匹配则返回0
func (s *Session) GetFloat64(key string) float64 {
s.mutex.RLock()
defer s.mutex.RUnlock()
if val, ok := s.Data[key]; ok {
switch v := val.(type) {
case float64:
return v
case int:
return float64(v)
}
}
return 0
}
// GetBool 从Session中获取布尔类型的值
// 参数 key: 键名
// 返回值: 布尔值如果不存在或类型不匹配则返回false
func (s *Session) GetBool(key string) bool {
s.mutex.RLock()
defer s.mutex.RUnlock()
if val, ok := s.Data[key]; ok {
if b, ok := val.(bool); ok {
return b
}
}
return false
}
// Delete 从Session中删除指定的键
// 参数 key: 要删除的键名
func (s *Session) Delete(key string) {
s.mutex.Lock()
defer s.mutex.Unlock()
delete(s.Data, key)
}
// Clear 清空Session中的所有数据
func (s *Session) Clear() {
s.mutex.Lock()
defer s.mutex.Unlock()
s.Data = make(map[string]interface{})
}
// Has 检查Session中是否存在指定的键
// 参数 key: 键名
// 返回值: 如果存在返回true否则返回false
func (s *Session) Has(key string) bool {
s.mutex.RLock()
defer s.mutex.RUnlock()
_, exists := s.Data[key]
return exists
}
// GetAll 获取Session中的所有数据
// 返回值: 包含所有数据的map副本
func (s *Session) GetAll() map[string]interface{} {
s.mutex.RLock()
defer s.mutex.RUnlock()
// 返回副本,避免外部修改
result := make(map[string]interface{}, len(s.Data))
for k, v := range s.Data {
result[k] = v
}
return result
}
// ToJSON 将Session数据转换为JSON字符串
// 返回值: JSON字符串如果转换失败则返回错误
func (s *Session) ToJSON() (string, error) {
s.mutex.RLock()
defer s.mutex.RUnlock()
data, err := json.Marshal(s.Data)
if err != nil {
return "", fmt.Errorf("序列化Session数据失败: %w", err)
}
return string(data), nil
}
// FromJSON 从JSON字符串恢复Session数据
// 参数 jsonStr: JSON字符串
// 返回值: 如果解析失败则返回错误
func (s *Session) FromJSON(jsonStr string) error {
s.mutex.Lock()
defer s.mutex.Unlock()
var data map[string]interface{}
if err := json.Unmarshal([]byte(jsonStr), &data); err != nil {
return fmt.Errorf("反序列化Session数据失败: %w", err)
}
s.Data = data
return nil
}
// GetDefaultManager 获取默认的Session管理器
// 返回值: 默认Session管理器指针
func GetDefaultManager() *SessionManager {
return defaultManager
}
// CreateSession 使用默认管理器创建Session
// 参数 c: Gin上下文对象
// 返回值: Session ID字符串
func CreateSession(c *gin.Context) string {
return defaultManager.CreateSession(c)
}
// GetSession 使用默认管理器获取Session
// 参数 c: Gin上下文对象
// 返回值: Session对象指针
func GetSession(c *gin.Context) *Session {
return defaultManager.GetSession(c)
}
// DestroySession 使用默认管理器销毁Session
// 参数 c: Gin上下文对象
func DestroySession(c *gin.Context) {
defaultManager.DestroySessionByContext(c)
}