feat(database): 添加ClickHouse数据库驱动支持
- 实现了ClickHouse数据库驱动程序,支持基本的数据库操作 - 添加了ClickHouse特定的迁移功能,包括表、列、索引的创建和管理 - 集成了ClickHouse的语法特性,如MergeTree引擎和Nullable类型 - 实现了数据库连接池管理和SQL执行接口 - 添加了对系统表查询的支持,用于检查表和列的存在性main v1.0.2023
parent
a083b74f9b
commit
284f9380ed
49
cmd/main.go
49
cmd/main.go
|
|
@ -41,16 +41,59 @@ func main() {
|
|||
user := getStringValue(defaultDbConfig, "user", "root")
|
||||
pass := getStringValue(defaultDbConfig, "pass", "")
|
||||
dbType := getStringValue(defaultDbConfig, "type", "mysql")
|
||||
link := getStringValue(defaultDbConfig, "link", "")
|
||||
|
||||
fmt.Println("=== Gin-Base DAO 代码生成工具 ===")
|
||||
fmt.Printf("📊 数据库: %s\n", name)
|
||||
fmt.Printf("🔧 类型: %s\n", dbType)
|
||||
fmt.Printf("🌐 主机: %s:%s\n\n", host, port)
|
||||
|
||||
// 构建数据库连接字符串
|
||||
link := fmt.Sprintf("mysql:%s:%s@tcp(%s:%s)/%s?charset=utf8&parseTime=true&loc=Local",
|
||||
var connectionInfo string
|
||||
if link == "" {
|
||||
// 如果没有配置 link,则根据数据库类型构建
|
||||
switch dbType {
|
||||
case "sqlite":
|
||||
// SQLite 使用配置文件中的 link 或默认路径
|
||||
connectionInfo = fmt.Sprintf("📁 数据库文件: %s", link)
|
||||
case "mysql", "mariadb":
|
||||
link = fmt.Sprintf("mysql:%s:%s@tcp(%s:%s)/%s?charset=utf8&parseTime=true&loc=Local",
|
||||
user, pass, host, port, name,
|
||||
)
|
||||
connectionInfo = fmt.Sprintf("📊 数据库: %s\n🌐 主机: %s:%s", name, host, port)
|
||||
case "pgsql", "postgresql":
|
||||
link = fmt.Sprintf("pgsql:%s:%s@tcp(%s:%s)/%s?sslmode=disable",
|
||||
user, pass, host, port, name,
|
||||
)
|
||||
connectionInfo = fmt.Sprintf("📊 数据库: %s\n🌐 主机: %s:%s", name, host, port)
|
||||
case "mssql":
|
||||
link = fmt.Sprintf("mssql:%s:%s@tcp(%s:%s)/%s",
|
||||
user, pass, host, port, name,
|
||||
)
|
||||
connectionInfo = fmt.Sprintf("📊 数据库: %s\n🌐 主机: %s:%s", name, host, port)
|
||||
case "oracle":
|
||||
link = fmt.Sprintf("oracle:%s:%s@%s:%s/%s",
|
||||
user, pass, host, port, name,
|
||||
)
|
||||
connectionInfo = fmt.Sprintf("📊 数据库: %s\n🌐 主机: %s:%s", name, host, port)
|
||||
case "clickhouse":
|
||||
link = fmt.Sprintf("clickhouse:%s:%s@tcp(%s:%s)/%s",
|
||||
user, pass, host, port, name,
|
||||
)
|
||||
connectionInfo = fmt.Sprintf("📊 数据库: %s\n🌐 主机: %s:%s", name, host, port)
|
||||
default:
|
||||
fmt.Printf("⚠️ 警告: 未知的数据库类型 %s,尝试使用配置的 link\n", dbType)
|
||||
if link == "" {
|
||||
fmt.Println("❌ 错误: 未配置数据库连接信息")
|
||||
os.Exit(1)
|
||||
}
|
||||
connectionInfo = "🔗 使用自定义连接"
|
||||
}
|
||||
} else {
|
||||
// 使用配置文件中直接提供的 link
|
||||
connectionInfo = fmt.Sprintf("🔗 连接: %s", link)
|
||||
}
|
||||
|
||||
fmt.Println(connectionInfo)
|
||||
fmt.Println()
|
||||
|
||||
// 准备表名参数
|
||||
tablesArg := ""
|
||||
|
|
|
|||
|
|
@ -61,3 +61,13 @@ func (d *Driver) injectNeedParsedSql(ctx context.Context) context.Context {
|
|||
}
|
||||
return context.WithValue(ctx, needParsedSqlInCtx, true)
|
||||
}
|
||||
|
||||
// Migration returns a Migration instance for ClickHouse database operations.
|
||||
func (d *Driver) Migration() *Migration {
|
||||
return NewMigration(d)
|
||||
}
|
||||
|
||||
// GetMigration returns a Migration instance implementing the Migration interface.
|
||||
func (d *Driver) GetMigration() database.Migration {
|
||||
return d.Migration()
|
||||
}
|
||||
|
|
|
|||
|
|
@ -0,0 +1,321 @@
|
|||
// Copyright GoFrame Author(https://goframe.org). All Rights Reserved.
|
||||
//
|
||||
// This Source Code Form is subject to the terms of the MIT License.
|
||||
// If a copy of the MIT was not distributed with this file,
|
||||
// You can obtain one at https://github.com/gogf/gf.
|
||||
|
||||
package clickhouse
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"strings"
|
||||
|
||||
"git.magicany.cc/black1552/gin-base/database"
|
||||
)
|
||||
|
||||
// Migration implements database migration operations for ClickHouse.
|
||||
type Migration struct {
|
||||
*database.MigrationCore
|
||||
*database.AutoMigrateCore
|
||||
}
|
||||
|
||||
// NewMigration creates a new ClickHouse Migration instance.
|
||||
func NewMigration(db database.DB) *Migration {
|
||||
return &Migration{
|
||||
MigrationCore: database.NewMigrationCore(db),
|
||||
AutoMigrateCore: database.NewAutoMigrateCore(db),
|
||||
}
|
||||
}
|
||||
|
||||
// CreateTable creates a new table with the given name and column definitions.
|
||||
func (m *Migration) CreateTable(ctx context.Context, table string, columns map[string]*database.ColumnDefinition, options ...database.TableOption) error {
|
||||
if len(columns) == 0 {
|
||||
return fmt.Errorf("cannot create table without columns")
|
||||
}
|
||||
|
||||
var opts database.TableOptions
|
||||
for _, opt := range options {
|
||||
opt(&opts)
|
||||
}
|
||||
|
||||
var sql strings.Builder
|
||||
sql.WriteString("CREATE TABLE ")
|
||||
if opts.IfNotExists {
|
||||
sql.WriteString("IF NOT EXISTS ")
|
||||
}
|
||||
sql.WriteString(database.QuoteIdentifier(table))
|
||||
sql.WriteString(" (\n")
|
||||
|
||||
// Add columns
|
||||
var colDefs []string
|
||||
for name, def := range columns {
|
||||
colDef := m.buildColumnDefinition(name, def)
|
||||
colDefs = append(colDefs, " "+colDef)
|
||||
}
|
||||
|
||||
sql.WriteString(strings.Join(colDefs, ",\n"))
|
||||
sql.WriteString("\n)")
|
||||
|
||||
// Add engine specification (required for ClickHouse)
|
||||
sql.WriteString(" ENGINE = MergeTree()")
|
||||
|
||||
if opts.Comment != "" {
|
||||
sql.WriteString(fmt.Sprintf(" COMMENT '%s'", escapeString(opts.Comment)))
|
||||
}
|
||||
|
||||
return m.ExecuteSQL(ctx, sql.String())
|
||||
}
|
||||
|
||||
// buildColumnDefinition builds column definition for ClickHouse.
|
||||
func (m *Migration) buildColumnDefinition(name string, def *database.ColumnDefinition) string {
|
||||
var parts []string
|
||||
parts = append(parts, database.QuoteIdentifier(name))
|
||||
|
||||
// Handle ClickHouse-specific types
|
||||
dbType := def.Type
|
||||
parts = append(parts, dbType)
|
||||
|
||||
if !def.Null {
|
||||
// ClickHouse uses Nullable type wrapper
|
||||
if !strings.HasPrefix(dbType, "Nullable(") {
|
||||
// Type is already non-nullable by default in ClickHouse
|
||||
}
|
||||
} else {
|
||||
// Make type nullable if needed
|
||||
if !strings.HasPrefix(dbType, "Nullable(") {
|
||||
parts[len(parts)-1] = fmt.Sprintf("Nullable(%s)", dbType)
|
||||
}
|
||||
}
|
||||
|
||||
if def.Default != nil {
|
||||
defaultValue := formatDefaultValue(def.Default)
|
||||
parts = append(parts, fmt.Sprintf("DEFAULT %s", defaultValue))
|
||||
}
|
||||
|
||||
if def.Comment != "" {
|
||||
parts = append(parts, fmt.Sprintf("COMMENT '%s'", escapeString(def.Comment)))
|
||||
}
|
||||
|
||||
return strings.Join(parts, " ")
|
||||
}
|
||||
|
||||
// DropTable drops an existing table from the database.
|
||||
func (m *Migration) DropTable(ctx context.Context, table string, ifExists ...bool) error {
|
||||
sql := "DROP TABLE "
|
||||
if len(ifExists) > 0 && ifExists[0] {
|
||||
sql += "IF EXISTS "
|
||||
}
|
||||
sql += database.QuoteIdentifier(table)
|
||||
return m.ExecuteSQL(ctx, sql)
|
||||
}
|
||||
|
||||
// HasTable checks if a table exists in the database.
|
||||
func (m *Migration) HasTable(ctx context.Context, table string) (bool, error) {
|
||||
schema := m.GetDB().GetSchema()
|
||||
if schema == "" {
|
||||
schema = "currentDatabase()"
|
||||
} else {
|
||||
schema = fmt.Sprintf("'%s'", schema)
|
||||
}
|
||||
|
||||
query := fmt.Sprintf(
|
||||
"SELECT COUNT(*) FROM system.tables WHERE database = %s AND name = '%s'",
|
||||
schema,
|
||||
table,
|
||||
)
|
||||
|
||||
value, err := m.GetDB().GetValue(ctx, query)
|
||||
if err != nil {
|
||||
return false, err
|
||||
}
|
||||
|
||||
return value.Int() > 0, nil
|
||||
}
|
||||
|
||||
// RenameTable renames an existing table from oldName to newName.
|
||||
func (m *Migration) RenameTable(ctx context.Context, oldName, newName string) error {
|
||||
sql := fmt.Sprintf(
|
||||
"RENAME TABLE %s TO %s",
|
||||
database.QuoteIdentifier(oldName),
|
||||
database.QuoteIdentifier(newName),
|
||||
)
|
||||
return m.ExecuteSQL(ctx, sql)
|
||||
}
|
||||
|
||||
// TruncateTable removes all records from a table but keeps the table structure.
|
||||
func (m *Migration) TruncateTable(ctx context.Context, table string) error {
|
||||
sql := fmt.Sprintf("TRUNCATE TABLE %s", database.QuoteIdentifier(table))
|
||||
return m.ExecuteSQL(ctx, sql)
|
||||
}
|
||||
|
||||
// AddColumn adds a new column to an existing table.
|
||||
func (m *Migration) AddColumn(ctx context.Context, table, column string, definition *database.ColumnDefinition) error {
|
||||
colDef := m.buildColumnDefinition(column, definition)
|
||||
sql := fmt.Sprintf("ALTER TABLE %s ADD COLUMN %s", database.QuoteIdentifier(table), colDef)
|
||||
return m.ExecuteSQL(ctx, sql)
|
||||
}
|
||||
|
||||
// DropColumn removes a column from an existing table.
|
||||
func (m *Migration) DropColumn(ctx context.Context, table, column string) error {
|
||||
sql := fmt.Sprintf(
|
||||
"ALTER TABLE %s DROP COLUMN %s",
|
||||
database.QuoteIdentifier(table),
|
||||
database.QuoteIdentifier(column),
|
||||
)
|
||||
return m.ExecuteSQL(ctx, sql)
|
||||
}
|
||||
|
||||
// RenameColumn renames a column in an existing table.
|
||||
func (m *Migration) RenameColumn(ctx context.Context, table, oldName, newName string) error {
|
||||
sql := fmt.Sprintf(
|
||||
"ALTER TABLE %s RENAME COLUMN %s TO %s",
|
||||
database.QuoteIdentifier(table),
|
||||
database.QuoteIdentifier(oldName),
|
||||
database.QuoteIdentifier(newName),
|
||||
)
|
||||
return m.ExecuteSQL(ctx, sql)
|
||||
}
|
||||
|
||||
// ModifyColumn modifies an existing column's definition.
|
||||
func (m *Migration) ModifyColumn(ctx context.Context, table, column string, definition *database.ColumnDefinition) error {
|
||||
colDef := m.buildColumnDefinition(column, definition)
|
||||
sql := fmt.Sprintf("ALTER TABLE %s MODIFY COLUMN %s", database.QuoteIdentifier(table), colDef)
|
||||
return m.ExecuteSQL(ctx, sql)
|
||||
}
|
||||
|
||||
// HasColumn checks if a column exists in a table.
|
||||
func (m *Migration) HasColumn(ctx context.Context, table, column string) (bool, error) {
|
||||
fields, err := m.GetDB().TableFields(ctx, table)
|
||||
if err != nil {
|
||||
return false, err
|
||||
}
|
||||
_, exists := fields[column]
|
||||
return exists, nil
|
||||
}
|
||||
|
||||
// CreateIndex creates a new index on the specified table and columns.
|
||||
// Note: ClickHouse has limited index support compared to traditional databases.
|
||||
func (m *Migration) CreateIndex(ctx context.Context, table, index string, columns []string, options ...database.IndexOption) error {
|
||||
var opts database.IndexOptions
|
||||
for _, opt := range options {
|
||||
opt(&opts)
|
||||
}
|
||||
|
||||
// ClickHouse uses different indexing mechanisms
|
||||
// For now, we'll use ALTER TABLE ADD INDEX syntax
|
||||
colList := m.BuildIndexColumns(columns)
|
||||
sql := fmt.Sprintf(
|
||||
"ALTER TABLE %s ADD INDEX %s (%s) TYPE minmax GRANULARITY 1",
|
||||
database.QuoteIdentifier(table),
|
||||
database.QuoteIdentifier(index),
|
||||
colList,
|
||||
)
|
||||
|
||||
return m.ExecuteSQL(ctx, sql)
|
||||
}
|
||||
|
||||
// DropIndex drops an existing index from a table.
|
||||
func (m *Migration) DropIndex(ctx context.Context, table, index string) error {
|
||||
sql := fmt.Sprintf(
|
||||
"ALTER TABLE %s DROP INDEX %s",
|
||||
database.QuoteIdentifier(table),
|
||||
database.QuoteIdentifier(index),
|
||||
)
|
||||
return m.ExecuteSQL(ctx, sql)
|
||||
}
|
||||
|
||||
// HasIndex checks if an index exists on a table.
|
||||
func (m *Migration) HasIndex(ctx context.Context, table, index string) (bool, error) {
|
||||
schema := m.GetDB().GetSchema()
|
||||
if schema == "" {
|
||||
schema = "currentDatabase()"
|
||||
} else {
|
||||
schema = fmt.Sprintf("'%s'", schema)
|
||||
}
|
||||
|
||||
query := fmt.Sprintf(
|
||||
"SELECT COUNT(*) FROM system.data_skipping_indices WHERE database = %s AND table = '%s' AND name = '%s'",
|
||||
schema,
|
||||
table,
|
||||
index,
|
||||
)
|
||||
|
||||
value, err := m.GetDB().GetValue(ctx, query)
|
||||
if err != nil {
|
||||
return false, err
|
||||
}
|
||||
|
||||
return value.Int() > 0, nil
|
||||
}
|
||||
|
||||
// CreateForeignKey is not supported in ClickHouse.
|
||||
func (m *Migration) CreateForeignKey(ctx context.Context, table, constraint string, columns []string, refTable string, refColumns []string, options ...database.ForeignKeyOption) error {
|
||||
// ClickHouse does not support foreign keys
|
||||
return fmt.Errorf("ClickHouse does not support foreign key constraints")
|
||||
}
|
||||
|
||||
// DropForeignKey is not supported in ClickHouse.
|
||||
func (m *Migration) DropForeignKey(ctx context.Context, table, constraint string) error {
|
||||
// ClickHouse does not support foreign keys
|
||||
return fmt.Errorf("ClickHouse does not support foreign key constraints")
|
||||
}
|
||||
|
||||
// HasForeignKey always returns false as ClickHouse doesn't support foreign keys.
|
||||
func (m *Migration) HasForeignKey(ctx context.Context, table, constraint string) (bool, error) {
|
||||
return false, nil
|
||||
}
|
||||
|
||||
// CreateSchema creates a new database schema.
|
||||
func (m *Migration) CreateSchema(ctx context.Context, schema string) error {
|
||||
sql := fmt.Sprintf("CREATE DATABASE IF NOT EXISTS %s", database.QuoteIdentifier(schema))
|
||||
return m.ExecuteSQL(ctx, sql)
|
||||
}
|
||||
|
||||
// DropSchema drops an existing database schema.
|
||||
func (m *Migration) DropSchema(ctx context.Context, schema string, cascade ...bool) error {
|
||||
sql := fmt.Sprintf("DROP DATABASE IF EXISTS %s", database.QuoteIdentifier(schema))
|
||||
return m.ExecuteSQL(ctx, sql)
|
||||
}
|
||||
|
||||
// HasSchema checks if a schema exists.
|
||||
func (m *Migration) HasSchema(ctx context.Context, schema string) (bool, error) {
|
||||
query := fmt.Sprintf(
|
||||
"SELECT COUNT(*) FROM system.databases WHERE name = '%s'",
|
||||
schema,
|
||||
)
|
||||
|
||||
value, err := m.GetDB().GetValue(ctx, query)
|
||||
if err != nil {
|
||||
return false, err
|
||||
}
|
||||
|
||||
return value.Int() > 0, nil
|
||||
}
|
||||
|
||||
// formatDefaultValue formats the default value for SQL.
|
||||
func formatDefaultValue(value any) string {
|
||||
switch v := value.(type) {
|
||||
case string:
|
||||
return fmt.Sprintf("'%s'", escapeString(v))
|
||||
case int, int8, int16, int32, int64, uint, uint8, uint16, uint32, uint64:
|
||||
return fmt.Sprintf("%d", v)
|
||||
case float32, float64:
|
||||
return fmt.Sprintf("%f", v)
|
||||
case bool:
|
||||
if v {
|
||||
return "1"
|
||||
}
|
||||
return "0"
|
||||
case nil:
|
||||
return "NULL"
|
||||
default:
|
||||
return fmt.Sprintf("'%v'", v)
|
||||
}
|
||||
}
|
||||
|
||||
// escapeString escapes special characters in strings for SQL.
|
||||
func escapeString(s string) string {
|
||||
s = strings.ReplaceAll(s, "'", "\\'")
|
||||
return s
|
||||
}
|
||||
|
|
@ -46,3 +46,13 @@ func (d *Driver) New(core *database.Core, node *database.ConfigNode) (database.D
|
|||
func (d *Driver) GetChars() (charLeft string, charRight string) {
|
||||
return quoteChar, quoteChar
|
||||
}
|
||||
|
||||
// Migration returns a Migration instance for MSSQL database operations.
|
||||
func (d *Driver) Migration() *Migration {
|
||||
return NewMigration(d)
|
||||
}
|
||||
|
||||
// GetMigration returns a Migration instance implementing the Migration interface.
|
||||
func (d *Driver) GetMigration() database.Migration {
|
||||
return d.Migration()
|
||||
}
|
||||
|
|
|
|||
|
|
@ -0,0 +1,340 @@
|
|||
// Copyright GoFrame Author(https://goframe.org). All Rights Reserved.
|
||||
//
|
||||
// This Source Code Form is subject to the terms of the MIT License.
|
||||
// If a copy of the MIT was not distributed with this file,
|
||||
// You can obtain one at https://github.com/gogf/gf.
|
||||
|
||||
package mssql
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"strings"
|
||||
|
||||
"git.magicany.cc/black1552/gin-base/database"
|
||||
)
|
||||
|
||||
// Migration implements database migration operations for SQL Server.
|
||||
type Migration struct {
|
||||
*database.MigrationCore
|
||||
*database.AutoMigrateCore
|
||||
}
|
||||
|
||||
// NewMigration creates a new SQL Server Migration instance.
|
||||
func NewMigration(db database.DB) *Migration {
|
||||
return &Migration{
|
||||
MigrationCore: database.NewMigrationCore(db),
|
||||
AutoMigrateCore: database.NewAutoMigrateCore(db),
|
||||
}
|
||||
}
|
||||
|
||||
// CreateTable creates a new table with the given name and column definitions.
|
||||
func (m *Migration) CreateTable(ctx context.Context, table string, columns map[string]*database.ColumnDefinition, options ...database.TableOption) error {
|
||||
if len(columns) == 0 {
|
||||
return fmt.Errorf("cannot create table without columns")
|
||||
}
|
||||
|
||||
var opts database.TableOptions
|
||||
for _, opt := range options {
|
||||
opt(&opts)
|
||||
}
|
||||
|
||||
var sql strings.Builder
|
||||
sql.WriteString("CREATE TABLE ")
|
||||
sql.WriteString(database.QuoteIdentifier(table))
|
||||
sql.WriteString(" (\n")
|
||||
|
||||
// Add columns
|
||||
var colDefs []string
|
||||
var primaryKeys []string
|
||||
for name, def := range columns {
|
||||
colDef := m.buildColumnDefinition(name, def)
|
||||
if def.PrimaryKey {
|
||||
primaryKeys = append(primaryKeys, database.QuoteIdentifier(name))
|
||||
}
|
||||
colDefs = append(colDefs, " "+colDef)
|
||||
}
|
||||
|
||||
// Add primary key constraint if needed
|
||||
if len(primaryKeys) > 0 {
|
||||
colDefs = append(colDefs, fmt.Sprintf(" PRIMARY KEY (%s)", strings.Join(primaryKeys, ", ")))
|
||||
}
|
||||
|
||||
sql.WriteString(strings.Join(colDefs, ",\n"))
|
||||
sql.WriteString("\n)")
|
||||
|
||||
return m.ExecuteSQL(ctx, sql.String())
|
||||
}
|
||||
|
||||
// buildColumnDefinition builds column definition for SQL Server.
|
||||
func (m *Migration) buildColumnDefinition(name string, def *database.ColumnDefinition) string {
|
||||
var parts []string
|
||||
parts = append(parts, database.QuoteIdentifier(name))
|
||||
|
||||
// Handle SQL Server-specific types
|
||||
dbType := def.Type
|
||||
if def.AutoIncrement {
|
||||
if dbType == "INT" || dbType == "INTEGER" {
|
||||
dbType = "INT IDENTITY(1,1)"
|
||||
} else if dbType == "BIGINT" {
|
||||
dbType = "BIGINT IDENTITY(1,1)"
|
||||
}
|
||||
}
|
||||
parts = append(parts, dbType)
|
||||
|
||||
if !def.Null {
|
||||
parts = append(parts, "NOT NULL")
|
||||
} else {
|
||||
parts = append(parts, "NULL")
|
||||
}
|
||||
|
||||
if def.Default != nil {
|
||||
defaultValue := formatDefaultValue(def.Default)
|
||||
parts = append(parts, fmt.Sprintf("DEFAULT %s", defaultValue))
|
||||
}
|
||||
|
||||
return strings.Join(parts, " ")
|
||||
}
|
||||
|
||||
// DropTable drops an existing table from the database.
|
||||
func (m *Migration) DropTable(ctx context.Context, table string, ifExists ...bool) error {
|
||||
if len(ifExists) > 0 && ifExists[0] {
|
||||
sql := fmt.Sprintf(
|
||||
"IF OBJECT_ID('%s', 'U') IS NOT NULL DROP TABLE %s",
|
||||
table,
|
||||
database.QuoteIdentifier(table),
|
||||
)
|
||||
return m.ExecuteSQL(ctx, sql)
|
||||
}
|
||||
sql := fmt.Sprintf("DROP TABLE %s", database.QuoteIdentifier(table))
|
||||
return m.ExecuteSQL(ctx, sql)
|
||||
}
|
||||
|
||||
// HasTable checks if a table exists in the database.
|
||||
func (m *Migration) HasTable(ctx context.Context, table string) (bool, error) {
|
||||
query := fmt.Sprintf(
|
||||
"SELECT COUNT(*) FROM INFORMATION_SCHEMA.TABLES WHERE TABLE_NAME = '%s'",
|
||||
table,
|
||||
)
|
||||
|
||||
value, err := m.GetDB().GetValue(ctx, query)
|
||||
if err != nil {
|
||||
return false, err
|
||||
}
|
||||
|
||||
return value.Int() > 0, nil
|
||||
}
|
||||
|
||||
// RenameTable renames an existing table from oldName to newName.
|
||||
func (m *Migration) RenameTable(ctx context.Context, oldName, newName string) error {
|
||||
sql := fmt.Sprintf(
|
||||
"EXEC sp_rename '%s', '%s'",
|
||||
oldName,
|
||||
newName,
|
||||
)
|
||||
return m.ExecuteSQL(ctx, sql)
|
||||
}
|
||||
|
||||
// TruncateTable removes all records from a table but keeps the table structure.
|
||||
func (m *Migration) TruncateTable(ctx context.Context, table string) error {
|
||||
sql := fmt.Sprintf("TRUNCATE TABLE %s", database.QuoteIdentifier(table))
|
||||
return m.ExecuteSQL(ctx, sql)
|
||||
}
|
||||
|
||||
// AddColumn adds a new column to an existing table.
|
||||
func (m *Migration) AddColumn(ctx context.Context, table, column string, definition *database.ColumnDefinition) error {
|
||||
colDef := m.buildColumnDefinition(column, definition)
|
||||
sql := fmt.Sprintf("ALTER TABLE %s ADD %s", database.QuoteIdentifier(table), colDef)
|
||||
return m.ExecuteSQL(ctx, sql)
|
||||
}
|
||||
|
||||
// DropColumn removes a column from an existing table.
|
||||
func (m *Migration) DropColumn(ctx context.Context, table, column string) error {
|
||||
sql := fmt.Sprintf(
|
||||
"ALTER TABLE %s DROP COLUMN %s",
|
||||
database.QuoteIdentifier(table),
|
||||
database.QuoteIdentifier(column),
|
||||
)
|
||||
return m.ExecuteSQL(ctx, sql)
|
||||
}
|
||||
|
||||
// RenameColumn renames a column in an existing table.
|
||||
func (m *Migration) RenameColumn(ctx context.Context, table, oldName, newName string) error {
|
||||
sql := fmt.Sprintf(
|
||||
"EXEC sp_rename '%s.%s', '%s', 'COLUMN'",
|
||||
table,
|
||||
oldName,
|
||||
newName,
|
||||
)
|
||||
return m.ExecuteSQL(ctx, sql)
|
||||
}
|
||||
|
||||
// ModifyColumn modifies an existing column's definition.
|
||||
func (m *Migration) ModifyColumn(ctx context.Context, table, column string, definition *database.ColumnDefinition) error {
|
||||
colDef := m.buildColumnDefinition(column, definition)
|
||||
sql := fmt.Sprintf("ALTER TABLE %s ALTER COLUMN %s", database.QuoteIdentifier(table), colDef)
|
||||
return m.ExecuteSQL(ctx, sql)
|
||||
}
|
||||
|
||||
// HasColumn checks if a column exists in a table.
|
||||
func (m *Migration) HasColumn(ctx context.Context, table, column string) (bool, error) {
|
||||
fields, err := m.GetDB().TableFields(ctx, table)
|
||||
if err != nil {
|
||||
return false, err
|
||||
}
|
||||
_, exists := fields[column]
|
||||
return exists, nil
|
||||
}
|
||||
|
||||
// CreateIndex creates a new index on the specified table and columns.
|
||||
func (m *Migration) CreateIndex(ctx context.Context, table, index string, columns []string, options ...database.IndexOption) error {
|
||||
var opts database.IndexOptions
|
||||
for _, opt := range options {
|
||||
opt(&opts)
|
||||
}
|
||||
|
||||
var sql strings.Builder
|
||||
sql.WriteString("CREATE ")
|
||||
|
||||
if opts.Unique {
|
||||
sql.WriteString("UNIQUE ")
|
||||
}
|
||||
|
||||
sql.WriteString("INDEX ")
|
||||
sql.WriteString(database.QuoteIdentifier(index))
|
||||
sql.WriteString(" ON ")
|
||||
sql.WriteString(database.QuoteIdentifier(table))
|
||||
|
||||
colList := m.BuildIndexColumns(columns)
|
||||
sql.WriteString(fmt.Sprintf(" (%s)", colList))
|
||||
|
||||
return m.ExecuteSQL(ctx, sql.String())
|
||||
}
|
||||
|
||||
// DropIndex drops an existing index from a table.
|
||||
func (m *Migration) DropIndex(ctx context.Context, table, index string) error {
|
||||
sql := fmt.Sprintf("DROP INDEX %s ON %s", database.QuoteIdentifier(index), database.QuoteIdentifier(table))
|
||||
return m.ExecuteSQL(ctx, sql)
|
||||
}
|
||||
|
||||
// HasIndex checks if an index exists on a table.
|
||||
func (m *Migration) HasIndex(ctx context.Context, table, index string) (bool, error) {
|
||||
query := fmt.Sprintf(
|
||||
"SELECT COUNT(*) FROM sys.indexes WHERE name = '%s' AND object_id = OBJECT_ID('%s')",
|
||||
index,
|
||||
table,
|
||||
)
|
||||
|
||||
value, err := m.GetDB().GetValue(ctx, query)
|
||||
if err != nil {
|
||||
return false, err
|
||||
}
|
||||
|
||||
return value.Int() > 0, nil
|
||||
}
|
||||
|
||||
// CreateForeignKey creates a foreign key constraint.
|
||||
func (m *Migration) CreateForeignKey(ctx context.Context, table, constraint string, columns []string, refTable string, refColumns []string, options ...database.ForeignKeyOption) error {
|
||||
var opts database.ForeignKeyOptions
|
||||
for _, opt := range options {
|
||||
opt(&opts)
|
||||
}
|
||||
|
||||
sql := fmt.Sprintf(
|
||||
"ALTER TABLE %s ADD CONSTRAINT %s FOREIGN KEY %s REFERENCES %s %s",
|
||||
database.QuoteIdentifier(table),
|
||||
database.QuoteIdentifier(constraint),
|
||||
m.BuildForeignKeyColumns(columns),
|
||||
database.QuoteIdentifier(refTable),
|
||||
m.BuildForeignKeyColumns(refColumns),
|
||||
)
|
||||
|
||||
if opts.OnDelete != "" {
|
||||
sql += fmt.Sprintf(" ON DELETE %s", opts.OnDelete)
|
||||
}
|
||||
if opts.OnUpdate != "" {
|
||||
sql += fmt.Sprintf(" ON UPDATE %s", opts.OnUpdate)
|
||||
}
|
||||
|
||||
return m.ExecuteSQL(ctx, sql)
|
||||
}
|
||||
|
||||
// DropForeignKey drops a foreign key constraint.
|
||||
func (m *Migration) DropForeignKey(ctx context.Context, table, constraint string) error {
|
||||
sql := fmt.Sprintf(
|
||||
"ALTER TABLE %s DROP CONSTRAINT %s",
|
||||
database.QuoteIdentifier(table),
|
||||
database.QuoteIdentifier(constraint),
|
||||
)
|
||||
return m.ExecuteSQL(ctx, sql)
|
||||
}
|
||||
|
||||
// HasForeignKey checks if a foreign key constraint exists.
|
||||
func (m *Migration) HasForeignKey(ctx context.Context, table, constraint string) (bool, error) {
|
||||
query := fmt.Sprintf(
|
||||
"SELECT COUNT(*) FROM INFORMATION_SCHEMA.TABLE_CONSTRAINTS WHERE CONSTRAINT_NAME = '%s' AND TABLE_NAME = '%s' AND CONSTRAINT_TYPE = 'FOREIGN KEY'",
|
||||
constraint,
|
||||
table,
|
||||
)
|
||||
|
||||
value, err := m.GetDB().GetValue(ctx, query)
|
||||
if err != nil {
|
||||
return false, err
|
||||
}
|
||||
|
||||
return value.Int() > 0, nil
|
||||
}
|
||||
|
||||
// CreateSchema creates a new database schema.
|
||||
func (m *Migration) CreateSchema(ctx context.Context, schema string) error {
|
||||
sql := fmt.Sprintf("CREATE SCHEMA %s", database.QuoteIdentifier(schema))
|
||||
return m.ExecuteSQL(ctx, sql)
|
||||
}
|
||||
|
||||
// DropSchema drops an existing database schema.
|
||||
func (m *Migration) DropSchema(ctx context.Context, schema string, cascade ...bool) error {
|
||||
sql := fmt.Sprintf("DROP SCHEMA %s", database.QuoteIdentifier(schema))
|
||||
return m.ExecuteSQL(ctx, sql)
|
||||
}
|
||||
|
||||
// HasSchema checks if a schema exists.
|
||||
func (m *Migration) HasSchema(ctx context.Context, schema string) (bool, error) {
|
||||
query := fmt.Sprintf(
|
||||
"SELECT COUNT(*) FROM INFORMATION_SCHEMA.SCHEMATA WHERE SCHEMA_NAME = '%s'",
|
||||
schema,
|
||||
)
|
||||
|
||||
value, err := m.GetDB().GetValue(ctx, query)
|
||||
if err != nil {
|
||||
return false, err
|
||||
}
|
||||
|
||||
return value.Int() > 0, nil
|
||||
}
|
||||
|
||||
// formatDefaultValue formats the default value for SQL.
|
||||
func formatDefaultValue(value any) string {
|
||||
switch v := value.(type) {
|
||||
case string:
|
||||
return fmt.Sprintf("'%s'", escapeString(v))
|
||||
case int, int8, int16, int32, int64, uint, uint8, uint16, uint32, uint64:
|
||||
return fmt.Sprintf("%d", v)
|
||||
case float32, float64:
|
||||
return fmt.Sprintf("%f", v)
|
||||
case bool:
|
||||
if v {
|
||||
return "1"
|
||||
}
|
||||
return "0"
|
||||
case nil:
|
||||
return "NULL"
|
||||
default:
|
||||
return fmt.Sprintf("'%v'", v)
|
||||
}
|
||||
}
|
||||
|
||||
// escapeString escapes special characters in strings for SQL.
|
||||
func escapeString(s string) string {
|
||||
s = strings.ReplaceAll(s, "'", "''")
|
||||
return s
|
||||
}
|
||||
|
|
@ -52,3 +52,13 @@ func (d *Driver) New(core *database.Core, node *database.ConfigNode) (database.D
|
|||
func (d *Driver) GetChars() (charLeft string, charRight string) {
|
||||
return quoteChar, quoteChar
|
||||
}
|
||||
|
||||
// Migration returns a Migration instance for MySQL database operations.
|
||||
func (d *Driver) Migration() *Migration {
|
||||
return NewMigration(d)
|
||||
}
|
||||
|
||||
// GetMigration returns a Migration instance implementing the Migration interface.
|
||||
func (d *Driver) GetMigration() database.Migration {
|
||||
return d.Migration()
|
||||
}
|
||||
|
|
|
|||
|
|
@ -0,0 +1,365 @@
|
|||
// Copyright GoFrame Author(https://goframe.org). All Rights Reserved.
|
||||
//
|
||||
// This Source Code Form is subject to the terms of the MIT License.
|
||||
// If a copy of the MIT was not distributed with this file,
|
||||
// You can obtain one at https://github.com/gogf/gf.
|
||||
|
||||
package mysql
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"strings"
|
||||
|
||||
"git.magicany.cc/black1552/gin-base/database"
|
||||
)
|
||||
|
||||
// Migration implements database migration operations for MySQL.
|
||||
type Migration struct {
|
||||
*database.MigrationCore
|
||||
*database.AutoMigrateCore
|
||||
}
|
||||
|
||||
// NewMigration creates a new MySQL Migration instance.
|
||||
func NewMigration(db database.DB) *Migration {
|
||||
return &Migration{
|
||||
MigrationCore: database.NewMigrationCore(db),
|
||||
AutoMigrateCore: database.NewAutoMigrateCore(db),
|
||||
}
|
||||
}
|
||||
|
||||
// CreateTable creates a new table with the given name and column definitions.
|
||||
func (m *Migration) CreateTable(ctx context.Context, table string, columns map[string]*database.ColumnDefinition, options ...database.TableOption) error {
|
||||
if len(columns) == 0 {
|
||||
return fmt.Errorf("cannot create table without columns")
|
||||
}
|
||||
|
||||
var opts database.TableOptions
|
||||
for _, opt := range options {
|
||||
opt(&opts)
|
||||
}
|
||||
|
||||
var sql strings.Builder
|
||||
sql.WriteString("CREATE TABLE ")
|
||||
if opts.IfNotExists {
|
||||
sql.WriteString("IF NOT EXISTS ")
|
||||
}
|
||||
sql.WriteString(database.QuoteIdentifier(table))
|
||||
sql.WriteString(" (\n")
|
||||
|
||||
// Add columns
|
||||
var colDefs []string
|
||||
var primaryKeys []string
|
||||
for name, def := range columns {
|
||||
colDef := m.BuildColumnDefinition(name, def)
|
||||
if def.PrimaryKey {
|
||||
primaryKeys = append(primaryKeys, database.QuoteIdentifier(name))
|
||||
}
|
||||
colDefs = append(colDefs, " "+colDef)
|
||||
}
|
||||
|
||||
// Add primary key constraint if needed
|
||||
if len(primaryKeys) > 0 {
|
||||
colDefs = append(colDefs, fmt.Sprintf(" PRIMARY KEY (%s)", strings.Join(primaryKeys, ", ")))
|
||||
}
|
||||
|
||||
sql.WriteString(strings.Join(colDefs, ",\n"))
|
||||
sql.WriteString("\n)")
|
||||
|
||||
// Add table options
|
||||
if opts.Engine != "" {
|
||||
sql.WriteString(fmt.Sprintf(" ENGINE=%s", opts.Engine))
|
||||
}
|
||||
if opts.Charset != "" {
|
||||
sql.WriteString(fmt.Sprintf(" DEFAULT CHARSET=%s", opts.Charset))
|
||||
}
|
||||
if opts.Collation != "" {
|
||||
sql.WriteString(fmt.Sprintf(" COLLATE=%s", opts.Collation))
|
||||
}
|
||||
if opts.Comment != "" {
|
||||
sql.WriteString(fmt.Sprintf(" COMMENT='%s'", escapeString(opts.Comment)))
|
||||
}
|
||||
|
||||
return m.ExecuteSQL(ctx, sql.String())
|
||||
}
|
||||
|
||||
// DropTable drops an existing table from the database.
|
||||
func (m *Migration) DropTable(ctx context.Context, table string, ifExists ...bool) error {
|
||||
sql := "DROP TABLE "
|
||||
if len(ifExists) > 0 && ifExists[0] {
|
||||
sql += "IF EXISTS "
|
||||
}
|
||||
sql += database.QuoteIdentifier(table)
|
||||
return m.ExecuteSQL(ctx, sql)
|
||||
}
|
||||
|
||||
// HasTable checks if a table exists in the database.
|
||||
func (m *Migration) HasTable(ctx context.Context, table string) (bool, error) {
|
||||
schema := m.GetDB().GetSchema()
|
||||
if schema == "" {
|
||||
schema = "DATABASE()"
|
||||
} else {
|
||||
schema = fmt.Sprintf("'%s'", schema)
|
||||
}
|
||||
|
||||
query := fmt.Sprintf(
|
||||
"SELECT COUNT(*) FROM information_schema.tables WHERE table_schema = %s AND table_name = '%s'",
|
||||
schema,
|
||||
table,
|
||||
)
|
||||
|
||||
value, err := m.GetDB().GetValue(ctx, query)
|
||||
if err != nil {
|
||||
return false, err
|
||||
}
|
||||
|
||||
return value.Int() > 0, nil
|
||||
}
|
||||
|
||||
// RenameTable renames an existing table from oldName to newName.
|
||||
func (m *Migration) RenameTable(ctx context.Context, oldName, newName string) error {
|
||||
sql := fmt.Sprintf(
|
||||
"RENAME TABLE %s TO %s",
|
||||
database.QuoteIdentifier(oldName),
|
||||
database.QuoteIdentifier(newName),
|
||||
)
|
||||
return m.ExecuteSQL(ctx, sql)
|
||||
}
|
||||
|
||||
// TruncateTable removes all records from a table but keeps the table structure.
|
||||
func (m *Migration) TruncateTable(ctx context.Context, table string) error {
|
||||
sql := fmt.Sprintf("TRUNCATE TABLE %s", database.QuoteIdentifier(table))
|
||||
return m.ExecuteSQL(ctx, sql)
|
||||
}
|
||||
|
||||
// AddColumn adds a new column to an existing table.
|
||||
func (m *Migration) AddColumn(ctx context.Context, table, column string, definition *database.ColumnDefinition) error {
|
||||
colDef := m.BuildColumnDefinition(column, definition)
|
||||
sql := fmt.Sprintf("ALTER TABLE %s ADD COLUMN %s", database.QuoteIdentifier(table), colDef)
|
||||
return m.ExecuteSQL(ctx, sql)
|
||||
}
|
||||
|
||||
// DropColumn removes a column from an existing table.
|
||||
func (m *Migration) DropColumn(ctx context.Context, table, column string) error {
|
||||
sql := fmt.Sprintf(
|
||||
"ALTER TABLE %s DROP COLUMN %s",
|
||||
database.QuoteIdentifier(table),
|
||||
database.QuoteIdentifier(column),
|
||||
)
|
||||
return m.ExecuteSQL(ctx, sql)
|
||||
}
|
||||
|
||||
// RenameColumn renames a column in an existing table.
|
||||
func (m *Migration) RenameColumn(ctx context.Context, table, oldName, newName string) error {
|
||||
// MySQL requires the full column definition when renaming
|
||||
fields, err := m.GetDB().TableFields(ctx, table)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
field, ok := fields[oldName]
|
||||
if !ok {
|
||||
return fmt.Errorf("column %s does not exist in table %s", oldName, table)
|
||||
}
|
||||
|
||||
def := &database.ColumnDefinition{
|
||||
Type: field.Type,
|
||||
Null: field.Null,
|
||||
Default: field.Default,
|
||||
Comment: field.Comment,
|
||||
AutoIncrement: strings.Contains(field.Extra, "auto_increment"),
|
||||
}
|
||||
|
||||
colDef := m.BuildColumnDefinition(newName, def)
|
||||
sql := fmt.Sprintf(
|
||||
"ALTER TABLE %s CHANGE COLUMN %s %s",
|
||||
database.QuoteIdentifier(table),
|
||||
database.QuoteIdentifier(oldName),
|
||||
colDef,
|
||||
)
|
||||
return m.ExecuteSQL(ctx, sql)
|
||||
}
|
||||
|
||||
// ModifyColumn modifies an existing column's definition.
|
||||
func (m *Migration) ModifyColumn(ctx context.Context, table, column string, definition *database.ColumnDefinition) error {
|
||||
colDef := m.BuildColumnDefinition(column, definition)
|
||||
sql := fmt.Sprintf("ALTER TABLE %s MODIFY COLUMN %s", database.QuoteIdentifier(table), colDef)
|
||||
return m.ExecuteSQL(ctx, sql)
|
||||
}
|
||||
|
||||
// HasColumn checks if a column exists in a table.
|
||||
func (m *Migration) HasColumn(ctx context.Context, table, column string) (bool, error) {
|
||||
fields, err := m.GetDB().TableFields(ctx, table)
|
||||
if err != nil {
|
||||
return false, err
|
||||
}
|
||||
_, exists := fields[column]
|
||||
return exists, nil
|
||||
}
|
||||
|
||||
// CreateIndex creates a new index on the specified table and columns.
|
||||
func (m *Migration) CreateIndex(ctx context.Context, table, index string, columns []string, options ...database.IndexOption) error {
|
||||
var opts database.IndexOptions
|
||||
for _, opt := range options {
|
||||
opt(&opts)
|
||||
}
|
||||
|
||||
var sql strings.Builder
|
||||
sql.WriteString("CREATE ")
|
||||
|
||||
if opts.Unique {
|
||||
sql.WriteString("UNIQUE ")
|
||||
}
|
||||
if opts.FullText {
|
||||
sql.WriteString("FULLTEXT ")
|
||||
}
|
||||
if opts.Spatial {
|
||||
sql.WriteString("SPATIAL ")
|
||||
}
|
||||
|
||||
sql.WriteString("INDEX ")
|
||||
sql.WriteString(database.QuoteIdentifier(index))
|
||||
sql.WriteString(" ON ")
|
||||
sql.WriteString(database.QuoteIdentifier(table))
|
||||
|
||||
colList := m.BuildIndexColumns(columns)
|
||||
sql.WriteString(fmt.Sprintf(" (%s)", colList))
|
||||
|
||||
if opts.Using != "" {
|
||||
sql.WriteString(fmt.Sprintf(" USING %s", opts.Using))
|
||||
}
|
||||
|
||||
return m.ExecuteSQL(ctx, sql.String())
|
||||
}
|
||||
|
||||
// DropIndex drops an existing index from a table.
|
||||
func (m *Migration) DropIndex(ctx context.Context, table, index string) error {
|
||||
sql := fmt.Sprintf(
|
||||
"DROP INDEX %s ON %s",
|
||||
database.QuoteIdentifier(index),
|
||||
database.QuoteIdentifier(table),
|
||||
)
|
||||
return m.ExecuteSQL(ctx, sql)
|
||||
}
|
||||
|
||||
// HasIndex checks if an index exists on a table.
|
||||
func (m *Migration) HasIndex(ctx context.Context, table, index string) (bool, error) {
|
||||
schema := m.GetDB().GetSchema()
|
||||
if schema == "" {
|
||||
schema = "DATABASE()"
|
||||
} else {
|
||||
schema = fmt.Sprintf("'%s'", schema)
|
||||
}
|
||||
|
||||
query := fmt.Sprintf(
|
||||
"SELECT COUNT(*) FROM information_schema.statistics WHERE table_schema = %s AND table_name = '%s' AND index_name = '%s'",
|
||||
schema,
|
||||
table,
|
||||
index,
|
||||
)
|
||||
|
||||
value, err := m.GetDB().GetValue(ctx, query)
|
||||
if err != nil {
|
||||
return false, err
|
||||
}
|
||||
|
||||
return value.Int() > 0, nil
|
||||
}
|
||||
|
||||
// CreateForeignKey creates a foreign key constraint.
|
||||
func (m *Migration) CreateForeignKey(ctx context.Context, table, constraint string, columns []string, refTable string, refColumns []string, options ...database.ForeignKeyOption) error {
|
||||
var opts database.ForeignKeyOptions
|
||||
for _, opt := range options {
|
||||
opt(&opts)
|
||||
}
|
||||
|
||||
sql := fmt.Sprintf(
|
||||
"ALTER TABLE %s ADD CONSTRAINT %s FOREIGN KEY %s REFERENCES %s %s",
|
||||
database.QuoteIdentifier(table),
|
||||
database.QuoteIdentifier(constraint),
|
||||
m.BuildForeignKeyColumns(columns),
|
||||
database.QuoteIdentifier(refTable),
|
||||
m.BuildForeignKeyColumns(refColumns),
|
||||
)
|
||||
|
||||
if opts.OnDelete != "" {
|
||||
sql += fmt.Sprintf(" ON DELETE %s", opts.OnDelete)
|
||||
}
|
||||
if opts.OnUpdate != "" {
|
||||
sql += fmt.Sprintf(" ON UPDATE %s", opts.OnUpdate)
|
||||
}
|
||||
|
||||
return m.ExecuteSQL(ctx, sql)
|
||||
}
|
||||
|
||||
// DropForeignKey drops a foreign key constraint.
|
||||
func (m *Migration) DropForeignKey(ctx context.Context, table, constraint string) error {
|
||||
sql := fmt.Sprintf(
|
||||
"ALTER TABLE %s DROP FOREIGN KEY %s",
|
||||
database.QuoteIdentifier(table),
|
||||
database.QuoteIdentifier(constraint),
|
||||
)
|
||||
return m.ExecuteSQL(ctx, sql)
|
||||
}
|
||||
|
||||
// HasForeignKey checks if a foreign key constraint exists.
|
||||
func (m *Migration) HasForeignKey(ctx context.Context, table, constraint string) (bool, error) {
|
||||
schema := m.GetDB().GetSchema()
|
||||
if schema == "" {
|
||||
schema = "DATABASE()"
|
||||
} else {
|
||||
schema = fmt.Sprintf("'%s'", schema)
|
||||
}
|
||||
|
||||
query := fmt.Sprintf(
|
||||
"SELECT COUNT(*) FROM information_schema.table_constraints WHERE constraint_schema = %s AND table_name = '%s' AND constraint_name = '%s' AND constraint_type = 'FOREIGN KEY'",
|
||||
schema,
|
||||
table,
|
||||
constraint,
|
||||
)
|
||||
|
||||
value, err := m.GetDB().GetValue(ctx, query)
|
||||
if err != nil {
|
||||
return false, err
|
||||
}
|
||||
|
||||
return value.Int() > 0, nil
|
||||
}
|
||||
|
||||
// CreateSchema creates a new database schema.
|
||||
func (m *Migration) CreateSchema(ctx context.Context, schema string) error {
|
||||
sql := fmt.Sprintf("CREATE DATABASE IF NOT EXISTS %s", database.QuoteIdentifier(schema))
|
||||
return m.ExecuteSQL(ctx, sql)
|
||||
}
|
||||
|
||||
// DropSchema drops an existing database schema.
|
||||
func (m *Migration) DropSchema(ctx context.Context, schema string, cascade ...bool) error {
|
||||
sql := "DROP DATABASE "
|
||||
if len(cascade) > 0 && cascade[0] {
|
||||
sql += "IF EXISTS "
|
||||
}
|
||||
sql += database.QuoteIdentifier(schema)
|
||||
return m.ExecuteSQL(ctx, sql)
|
||||
}
|
||||
|
||||
// HasSchema checks if a schema exists.
|
||||
func (m *Migration) HasSchema(ctx context.Context, schema string) (bool, error) {
|
||||
query := fmt.Sprintf(
|
||||
"SELECT COUNT(*) FROM information_schema.schemata WHERE schema_name = '%s'",
|
||||
schema,
|
||||
)
|
||||
|
||||
value, err := m.GetDB().GetValue(ctx, query)
|
||||
if err != nil {
|
||||
return false, err
|
||||
}
|
||||
|
||||
return value.Int() > 0, nil
|
||||
}
|
||||
|
||||
// escapeString escapes special characters in strings for SQL.
|
||||
func escapeString(s string) string {
|
||||
s = strings.ReplaceAll(s, "'", "''")
|
||||
s = strings.ReplaceAll(s, "\\", "\\\\")
|
||||
return s
|
||||
}
|
||||
|
|
@ -44,3 +44,13 @@ func (d *Driver) New(core *database.Core, node *database.ConfigNode) (database.D
|
|||
func (d *Driver) GetChars() (charLeft string, charRight string) {
|
||||
return quoteChar, quoteChar
|
||||
}
|
||||
|
||||
// Migration returns a Migration instance for Oracle database operations.
|
||||
func (d *Driver) Migration() *Migration {
|
||||
return NewMigration(d)
|
||||
}
|
||||
|
||||
// GetMigration returns a Migration instance implementing the Migration interface.
|
||||
func (d *Driver) GetMigration() database.Migration {
|
||||
return d.Migration()
|
||||
}
|
||||
|
|
|
|||
|
|
@ -0,0 +1,351 @@
|
|||
// Copyright GoFrame Author(https://goframe.org). All Rights Reserved.
|
||||
//
|
||||
// This Source Code Form is subject to the terms of the MIT License.
|
||||
// If a copy of the MIT was not distributed with this file,
|
||||
// You can obtain one at https://github.com/gogf/gf.
|
||||
|
||||
package oracle
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"strings"
|
||||
|
||||
"git.magicany.cc/black1552/gin-base/database"
|
||||
)
|
||||
|
||||
// Migration implements database migration operations for Oracle.
|
||||
type Migration struct {
|
||||
*database.MigrationCore
|
||||
*database.AutoMigrateCore
|
||||
}
|
||||
|
||||
// NewMigration creates a new Oracle Migration instance.
|
||||
func NewMigration(db database.DB) *Migration {
|
||||
return &Migration{
|
||||
MigrationCore: database.NewMigrationCore(db),
|
||||
AutoMigrateCore: database.NewAutoMigrateCore(db),
|
||||
}
|
||||
}
|
||||
|
||||
// CreateTable creates a new table with the given name and column definitions.
|
||||
func (m *Migration) CreateTable(ctx context.Context, table string, columns map[string]*database.ColumnDefinition, options ...database.TableOption) error {
|
||||
if len(columns) == 0 {
|
||||
return fmt.Errorf("cannot create table without columns")
|
||||
}
|
||||
|
||||
var opts database.TableOptions
|
||||
for _, opt := range options {
|
||||
opt(&opts)
|
||||
}
|
||||
|
||||
var sql strings.Builder
|
||||
sql.WriteString("CREATE TABLE ")
|
||||
sql.WriteString(database.QuoteIdentifierDouble(table))
|
||||
sql.WriteString(" (\n")
|
||||
|
||||
// Add columns
|
||||
var colDefs []string
|
||||
var primaryKeys []string
|
||||
for name, def := range columns {
|
||||
colDef := m.buildColumnDefinition(name, def)
|
||||
if def.PrimaryKey {
|
||||
primaryKeys = append(primaryKeys, database.QuoteIdentifierDouble(name))
|
||||
}
|
||||
colDefs = append(colDefs, " "+colDef)
|
||||
}
|
||||
|
||||
// Add primary key constraint if needed
|
||||
if len(primaryKeys) > 0 {
|
||||
colDefs = append(colDefs, fmt.Sprintf(" PRIMARY KEY (%s)", strings.Join(primaryKeys, ", ")))
|
||||
}
|
||||
|
||||
sql.WriteString(strings.Join(colDefs, ",\n"))
|
||||
sql.WriteString("\n)")
|
||||
|
||||
return m.ExecuteSQL(ctx, sql.String())
|
||||
}
|
||||
|
||||
// buildColumnDefinition builds column definition for Oracle.
|
||||
func (m *Migration) buildColumnDefinition(name string, def *database.ColumnDefinition) string {
|
||||
var parts []string
|
||||
parts = append(parts, database.QuoteIdentifierDouble(name))
|
||||
|
||||
// Handle Oracle-specific types
|
||||
dbType := def.Type
|
||||
parts = append(parts, dbType)
|
||||
|
||||
if !def.Null {
|
||||
parts = append(parts, "NOT NULL")
|
||||
}
|
||||
|
||||
if def.Default != nil {
|
||||
defaultValue := formatDefaultValue(def.Default)
|
||||
parts = append(parts, fmt.Sprintf("DEFAULT %s", defaultValue))
|
||||
}
|
||||
|
||||
return strings.Join(parts, " ")
|
||||
}
|
||||
|
||||
// DropTable drops an existing table from the database.
|
||||
func (m *Migration) DropTable(ctx context.Context, table string, ifExists ...bool) error {
|
||||
if len(ifExists) > 0 && ifExists[0] {
|
||||
sql := fmt.Sprintf(
|
||||
"BEGIN EXECUTE IMMEDIATE 'DROP TABLE %s'; EXCEPTION WHEN OTHERS THEN IF SQLCODE != -942 THEN RAISE; END IF; END;",
|
||||
database.QuoteIdentifierDouble(table),
|
||||
)
|
||||
return m.ExecuteSQL(ctx, sql)
|
||||
}
|
||||
sql := fmt.Sprintf("DROP TABLE %s", database.QuoteIdentifierDouble(table))
|
||||
return m.ExecuteSQL(ctx, sql)
|
||||
}
|
||||
|
||||
// HasTable checks if a table exists in the database.
|
||||
func (m *Migration) HasTable(ctx context.Context, table string) (bool, error) {
|
||||
query := fmt.Sprintf(
|
||||
"SELECT COUNT(*) FROM USER_TABLES WHERE TABLE_NAME = '%s'",
|
||||
strings.ToUpper(table),
|
||||
)
|
||||
|
||||
value, err := m.GetDB().GetValue(ctx, query)
|
||||
if err != nil {
|
||||
return false, err
|
||||
}
|
||||
|
||||
return value.Int() > 0, nil
|
||||
}
|
||||
|
||||
// RenameTable renames an existing table from oldName to newName.
|
||||
func (m *Migration) RenameTable(ctx context.Context, oldName, newName string) error {
|
||||
sql := fmt.Sprintf(
|
||||
"ALTER TABLE %s RENAME TO %s",
|
||||
database.QuoteIdentifierDouble(oldName),
|
||||
database.QuoteIdentifierDouble(newName),
|
||||
)
|
||||
return m.ExecuteSQL(ctx, sql)
|
||||
}
|
||||
|
||||
// TruncateTable removes all records from a table but keeps the table structure.
|
||||
func (m *Migration) TruncateTable(ctx context.Context, table string) error {
|
||||
sql := fmt.Sprintf("TRUNCATE TABLE %s", database.QuoteIdentifierDouble(table))
|
||||
return m.ExecuteSQL(ctx, sql)
|
||||
}
|
||||
|
||||
// AddColumn adds a new column to an existing table.
|
||||
func (m *Migration) AddColumn(ctx context.Context, table, column string, definition *database.ColumnDefinition) error {
|
||||
colDef := m.buildColumnDefinition(column, definition)
|
||||
sql := fmt.Sprintf("ALTER TABLE %s ADD %s", database.QuoteIdentifierDouble(table), colDef)
|
||||
return m.ExecuteSQL(ctx, sql)
|
||||
}
|
||||
|
||||
// DropColumn removes a column from an existing table.
|
||||
func (m *Migration) DropColumn(ctx context.Context, table, column string) error {
|
||||
sql := fmt.Sprintf(
|
||||
"ALTER TABLE %s DROP COLUMN %s",
|
||||
database.QuoteIdentifierDouble(table),
|
||||
database.QuoteIdentifierDouble(column),
|
||||
)
|
||||
return m.ExecuteSQL(ctx, sql)
|
||||
}
|
||||
|
||||
// RenameColumn renames a column in an existing table.
|
||||
func (m *Migration) RenameColumn(ctx context.Context, table, oldName, newName string) error {
|
||||
sql := fmt.Sprintf(
|
||||
"ALTER TABLE %s RENAME COLUMN %s TO %s",
|
||||
database.QuoteIdentifierDouble(table),
|
||||
database.QuoteIdentifierDouble(oldName),
|
||||
database.QuoteIdentifierDouble(newName),
|
||||
)
|
||||
return m.ExecuteSQL(ctx, sql)
|
||||
}
|
||||
|
||||
// ModifyColumn modifies an existing column's definition.
|
||||
func (m *Migration) ModifyColumn(ctx context.Context, table, column string, definition *database.ColumnDefinition) error {
|
||||
colDef := m.buildColumnDefinition(column, definition)
|
||||
sql := fmt.Sprintf("ALTER TABLE %s MODIFY %s", database.QuoteIdentifierDouble(table), colDef)
|
||||
return m.ExecuteSQL(ctx, sql)
|
||||
}
|
||||
|
||||
// HasColumn checks if a column exists in a table.
|
||||
func (m *Migration) HasColumn(ctx context.Context, table, column string) (bool, error) {
|
||||
fields, err := m.GetDB().TableFields(ctx, table)
|
||||
if err != nil {
|
||||
return false, err
|
||||
}
|
||||
_, exists := fields[column]
|
||||
return exists, nil
|
||||
}
|
||||
|
||||
// CreateIndex creates a new index on the specified table and columns.
|
||||
func (m *Migration) CreateIndex(ctx context.Context, table, index string, columns []string, options ...database.IndexOption) error {
|
||||
var opts database.IndexOptions
|
||||
for _, opt := range options {
|
||||
opt(&opts)
|
||||
}
|
||||
|
||||
var sql strings.Builder
|
||||
sql.WriteString("CREATE ")
|
||||
|
||||
if opts.Unique {
|
||||
sql.WriteString("UNIQUE ")
|
||||
}
|
||||
|
||||
sql.WriteString("INDEX ")
|
||||
sql.WriteString(database.QuoteIdentifierDouble(index))
|
||||
sql.WriteString(" ON ")
|
||||
sql.WriteString(database.QuoteIdentifierDouble(table))
|
||||
|
||||
colList := m.BuildIndexColumnsWithDouble(columns)
|
||||
sql.WriteString(fmt.Sprintf(" (%s)", colList))
|
||||
|
||||
return m.ExecuteSQL(ctx, sql.String())
|
||||
}
|
||||
|
||||
// BuildIndexColumnsWithDouble builds the column list for index creation using double quotes.
|
||||
func (m *Migration) BuildIndexColumnsWithDouble(columns []string) string {
|
||||
quoted := make([]string, len(columns))
|
||||
for i, col := range columns {
|
||||
quoted[i] = database.QuoteIdentifierDouble(col)
|
||||
}
|
||||
return strings.Join(quoted, ", ")
|
||||
}
|
||||
|
||||
// DropIndex drops an existing index from a table.
|
||||
func (m *Migration) DropIndex(ctx context.Context, table, index string) error {
|
||||
sql := fmt.Sprintf("DROP INDEX %s", database.QuoteIdentifierDouble(index))
|
||||
return m.ExecuteSQL(ctx, sql)
|
||||
}
|
||||
|
||||
// HasIndex checks if an index exists on a table.
|
||||
func (m *Migration) HasIndex(ctx context.Context, table, index string) (bool, error) {
|
||||
query := fmt.Sprintf(
|
||||
"SELECT COUNT(*) FROM USER_INDEXES WHERE INDEX_NAME = '%s' AND TABLE_NAME = '%s'",
|
||||
strings.ToUpper(index),
|
||||
strings.ToUpper(table),
|
||||
)
|
||||
|
||||
value, err := m.GetDB().GetValue(ctx, query)
|
||||
if err != nil {
|
||||
return false, err
|
||||
}
|
||||
|
||||
return value.Int() > 0, nil
|
||||
}
|
||||
|
||||
// CreateForeignKey creates a foreign key constraint.
|
||||
func (m *Migration) CreateForeignKey(ctx context.Context, table, constraint string, columns []string, refTable string, refColumns []string, options ...database.ForeignKeyOption) error {
|
||||
var opts database.ForeignKeyOptions
|
||||
for _, opt := range options {
|
||||
opt(&opts)
|
||||
}
|
||||
|
||||
sql := fmt.Sprintf(
|
||||
"ALTER TABLE %s ADD CONSTRAINT %s FOREIGN KEY %s REFERENCES %s %s",
|
||||
database.QuoteIdentifierDouble(table),
|
||||
database.QuoteIdentifierDouble(constraint),
|
||||
m.BuildForeignKeyColumnsWithDouble(columns),
|
||||
database.QuoteIdentifierDouble(refTable),
|
||||
m.BuildForeignKeyColumnsWithDouble(refColumns),
|
||||
)
|
||||
|
||||
if opts.OnDelete != "" {
|
||||
sql += fmt.Sprintf(" ON DELETE %s", opts.OnDelete)
|
||||
}
|
||||
|
||||
return m.ExecuteSQL(ctx, sql)
|
||||
}
|
||||
|
||||
// BuildForeignKeyColumnsWithDouble builds the column list for foreign key using double quotes.
|
||||
func (m *Migration) BuildForeignKeyColumnsWithDouble(columns []string) string {
|
||||
quoted := make([]string, len(columns))
|
||||
for i, col := range columns {
|
||||
quoted[i] = database.QuoteIdentifierDouble(col)
|
||||
}
|
||||
return "(" + strings.Join(quoted, ", ") + ")"
|
||||
}
|
||||
|
||||
// DropForeignKey drops a foreign key constraint.
|
||||
func (m *Migration) DropForeignKey(ctx context.Context, table, constraint string) error {
|
||||
sql := fmt.Sprintf(
|
||||
"ALTER TABLE %s DROP CONSTRAINT %s",
|
||||
database.QuoteIdentifierDouble(table),
|
||||
database.QuoteIdentifierDouble(constraint),
|
||||
)
|
||||
return m.ExecuteSQL(ctx, sql)
|
||||
}
|
||||
|
||||
// HasForeignKey checks if a foreign key constraint exists.
|
||||
func (m *Migration) HasForeignKey(ctx context.Context, table, constraint string) (bool, error) {
|
||||
query := fmt.Sprintf(
|
||||
"SELECT COUNT(*) FROM USER_CONSTRAINTS WHERE CONSTRAINT_NAME = '%s' AND TABLE_NAME = '%s' AND CONSTRAINT_TYPE = 'R'",
|
||||
strings.ToUpper(constraint),
|
||||
strings.ToUpper(table),
|
||||
)
|
||||
|
||||
value, err := m.GetDB().GetValue(ctx, query)
|
||||
if err != nil {
|
||||
return false, err
|
||||
}
|
||||
|
||||
return value.Int() > 0, nil
|
||||
}
|
||||
|
||||
// CreateSchema creates a new database schema (user in Oracle).
|
||||
func (m *Migration) CreateSchema(ctx context.Context, schema string) error {
|
||||
// In Oracle, schema is equivalent to user
|
||||
sql := fmt.Sprintf("CREATE USER %s IDENTIFIED BY password", database.QuoteIdentifierDouble(schema))
|
||||
return m.ExecuteSQL(ctx, sql)
|
||||
}
|
||||
|
||||
// DropSchema drops an existing database schema (user in Oracle).
|
||||
func (m *Migration) DropSchema(ctx context.Context, schema string, cascade ...bool) error {
|
||||
sql := "DROP USER "
|
||||
if len(cascade) > 0 && cascade[0] {
|
||||
sql += database.QuoteIdentifierDouble(schema) + " CASCADE"
|
||||
} else {
|
||||
sql += database.QuoteIdentifierDouble(schema)
|
||||
}
|
||||
return m.ExecuteSQL(ctx, sql)
|
||||
}
|
||||
|
||||
// HasSchema checks if a schema exists.
|
||||
func (m *Migration) HasSchema(ctx context.Context, schema string) (bool, error) {
|
||||
query := fmt.Sprintf(
|
||||
"SELECT COUNT(*) FROM ALL_USERS WHERE USERNAME = '%s'",
|
||||
strings.ToUpper(schema),
|
||||
)
|
||||
|
||||
value, err := m.GetDB().GetValue(ctx, query)
|
||||
if err != nil {
|
||||
return false, err
|
||||
}
|
||||
|
||||
return value.Int() > 0, nil
|
||||
}
|
||||
|
||||
// formatDefaultValue formats the default value for SQL.
|
||||
func formatDefaultValue(value any) string {
|
||||
switch v := value.(type) {
|
||||
case string:
|
||||
return fmt.Sprintf("'%s'", escapeString(v))
|
||||
case int, int8, int16, int32, int64, uint, uint8, uint16, uint32, uint64:
|
||||
return fmt.Sprintf("%d", v)
|
||||
case float32, float64:
|
||||
return fmt.Sprintf("%f", v)
|
||||
case bool:
|
||||
if v {
|
||||
return "1"
|
||||
}
|
||||
return "0"
|
||||
case nil:
|
||||
return "NULL"
|
||||
default:
|
||||
return fmt.Sprintf("'%v'", v)
|
||||
}
|
||||
}
|
||||
|
||||
// escapeString escapes special characters in strings for SQL.
|
||||
func escapeString(s string) string {
|
||||
s = strings.ReplaceAll(s, "'", "''")
|
||||
return s
|
||||
}
|
||||
|
|
@ -48,3 +48,13 @@ func (d *Driver) New(core *database.Core, node *database.ConfigNode) (database.D
|
|||
func (d *Driver) GetChars() (charLeft string, charRight string) {
|
||||
return quoteChar, quoteChar
|
||||
}
|
||||
|
||||
// Migration returns a Migration instance for PostgreSQL database operations.
|
||||
func (d *Driver) Migration() *Migration {
|
||||
return NewMigration(d)
|
||||
}
|
||||
|
||||
// GetMigration returns a Migration instance implementing the Migration interface.
|
||||
func (d *Driver) GetMigration() database.Migration {
|
||||
return d.Migration()
|
||||
}
|
||||
|
|
|
|||
|
|
@ -0,0 +1,481 @@
|
|||
// Copyright GoFrame Author(https://goframe.org). All Rights Reserved.
|
||||
//
|
||||
// This Source Code Form is subject to the terms of the MIT License.
|
||||
// If a copy of the MIT was not distributed with this file,
|
||||
// You can obtain one at https://github.com/gogf/gf.
|
||||
|
||||
package pgsql
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"strings"
|
||||
|
||||
"git.magicany.cc/black1552/gin-base/database"
|
||||
)
|
||||
|
||||
// Migration implements database migration operations for PostgreSQL.
|
||||
type Migration struct {
|
||||
*database.MigrationCore
|
||||
*database.AutoMigrateCore
|
||||
}
|
||||
|
||||
// NewMigration creates a new PostgreSQL Migration instance.
|
||||
func NewMigration(db database.DB) *Migration {
|
||||
return &Migration{
|
||||
MigrationCore: database.NewMigrationCore(db),
|
||||
AutoMigrateCore: database.NewAutoMigrateCore(db),
|
||||
}
|
||||
}
|
||||
|
||||
// CreateTable creates a new table with the given name and column definitions.
|
||||
func (m *Migration) CreateTable(ctx context.Context, table string, columns map[string]*database.ColumnDefinition, options ...database.TableOption) error {
|
||||
if len(columns) == 0 {
|
||||
return fmt.Errorf("cannot create table without columns")
|
||||
}
|
||||
|
||||
var opts database.TableOptions
|
||||
for _, opt := range options {
|
||||
opt(&opts)
|
||||
}
|
||||
|
||||
var sql strings.Builder
|
||||
sql.WriteString("CREATE TABLE ")
|
||||
if opts.IfNotExists {
|
||||
sql.WriteString("IF NOT EXISTS ")
|
||||
}
|
||||
sql.WriteString(database.QuoteIdentifierDouble(table))
|
||||
sql.WriteString(" (\n")
|
||||
|
||||
// Add columns
|
||||
var colDefs []string
|
||||
var primaryKeys []string
|
||||
for name, def := range columns {
|
||||
colDef := m.buildColumnDefinition(name, def)
|
||||
if def.PrimaryKey {
|
||||
primaryKeys = append(primaryKeys, database.QuoteIdentifierDouble(name))
|
||||
}
|
||||
colDefs = append(colDefs, " "+colDef)
|
||||
}
|
||||
|
||||
// Add primary key constraint if needed
|
||||
if len(primaryKeys) > 0 {
|
||||
colDefs = append(colDefs, fmt.Sprintf(" PRIMARY KEY (%s)", strings.Join(primaryKeys, ", ")))
|
||||
}
|
||||
|
||||
sql.WriteString(strings.Join(colDefs, ",\n"))
|
||||
sql.WriteString("\n)")
|
||||
|
||||
// Add table comment if provided
|
||||
if opts.Comment != "" {
|
||||
commentSQL := fmt.Sprintf(
|
||||
"COMMENT ON TABLE %s IS '%s'",
|
||||
database.QuoteIdentifierDouble(table),
|
||||
escapeString(opts.Comment),
|
||||
)
|
||||
if err := m.ExecuteSQL(ctx, commentSQL); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
|
||||
return m.ExecuteSQL(ctx, sql.String())
|
||||
}
|
||||
|
||||
// buildColumnDefinition builds column definition for PostgreSQL.
|
||||
func (m *Migration) buildColumnDefinition(name string, def *database.ColumnDefinition) string {
|
||||
var parts []string
|
||||
parts = append(parts, database.QuoteIdentifierDouble(name))
|
||||
|
||||
// Handle PostgreSQL-specific types
|
||||
dbType := def.Type
|
||||
if def.AutoIncrement {
|
||||
if dbType == "INT" || dbType == "INTEGER" {
|
||||
dbType = "SERIAL"
|
||||
} else if dbType == "BIGINT" {
|
||||
dbType = "BIGSERIAL"
|
||||
}
|
||||
}
|
||||
parts = append(parts, dbType)
|
||||
|
||||
if def.PrimaryKey {
|
||||
// Primary key is handled separately
|
||||
} else {
|
||||
if !def.Null {
|
||||
parts = append(parts, "NOT NULL")
|
||||
}
|
||||
|
||||
if def.Unique {
|
||||
parts = append(parts, "UNIQUE")
|
||||
}
|
||||
|
||||
if def.Default != nil {
|
||||
defaultValue := formatDefaultValue(def.Default)
|
||||
parts = append(parts, fmt.Sprintf("DEFAULT %s", defaultValue))
|
||||
}
|
||||
}
|
||||
|
||||
return strings.Join(parts, " ")
|
||||
}
|
||||
|
||||
// DropTable drops an existing table from the database.
|
||||
func (m *Migration) DropTable(ctx context.Context, table string, ifExists ...bool) error {
|
||||
sql := "DROP TABLE "
|
||||
if len(ifExists) > 0 && ifExists[0] {
|
||||
sql += "IF EXISTS "
|
||||
}
|
||||
sql += database.QuoteIdentifierDouble(table)
|
||||
return m.ExecuteSQL(ctx, sql)
|
||||
}
|
||||
|
||||
// HasTable checks if a table exists in the database.
|
||||
func (m *Migration) HasTable(ctx context.Context, table string) (bool, error) {
|
||||
schema := m.GetDB().GetSchema()
|
||||
if schema == "" {
|
||||
schema = "current_schema()"
|
||||
} else {
|
||||
schema = fmt.Sprintf("'%s'", schema)
|
||||
}
|
||||
|
||||
query := fmt.Sprintf(
|
||||
"SELECT COUNT(*) FROM information_schema.tables WHERE table_schema = %s AND table_name = '%s'",
|
||||
schema,
|
||||
table,
|
||||
)
|
||||
|
||||
value, err := m.GetDB().GetValue(ctx, query)
|
||||
if err != nil {
|
||||
return false, err
|
||||
}
|
||||
|
||||
return value.Int() > 0, nil
|
||||
}
|
||||
|
||||
// RenameTable renames an existing table from oldName to newName.
|
||||
func (m *Migration) RenameTable(ctx context.Context, oldName, newName string) error {
|
||||
sql := fmt.Sprintf(
|
||||
"ALTER TABLE %s RENAME TO %s",
|
||||
database.QuoteIdentifierDouble(oldName),
|
||||
database.QuoteIdentifierDouble(newName),
|
||||
)
|
||||
return m.ExecuteSQL(ctx, sql)
|
||||
}
|
||||
|
||||
// TruncateTable removes all records from a table but keeps the table structure.
|
||||
func (m *Migration) TruncateTable(ctx context.Context, table string) error {
|
||||
sql := fmt.Sprintf("TRUNCATE TABLE %s", database.QuoteIdentifierDouble(table))
|
||||
return m.ExecuteSQL(ctx, sql)
|
||||
}
|
||||
|
||||
// AddColumn adds a new column to an existing table.
|
||||
func (m *Migration) AddColumn(ctx context.Context, table, column string, definition *database.ColumnDefinition) error {
|
||||
colDef := m.buildColumnDefinition(column, definition)
|
||||
sql := fmt.Sprintf("ALTER TABLE %s ADD COLUMN %s", database.QuoteIdentifierDouble(table), colDef)
|
||||
return m.ExecuteSQL(ctx, sql)
|
||||
}
|
||||
|
||||
// DropColumn removes a column from an existing table.
|
||||
func (m *Migration) DropColumn(ctx context.Context, table, column string) error {
|
||||
sql := fmt.Sprintf(
|
||||
"ALTER TABLE %s DROP COLUMN %s",
|
||||
database.QuoteIdentifierDouble(table),
|
||||
database.QuoteIdentifierDouble(column),
|
||||
)
|
||||
return m.ExecuteSQL(ctx, sql)
|
||||
}
|
||||
|
||||
// RenameColumn renames a column in an existing table.
|
||||
func (m *Migration) RenameColumn(ctx context.Context, table, oldName, newName string) error {
|
||||
sql := fmt.Sprintf(
|
||||
"ALTER TABLE %s RENAME COLUMN %s TO %s",
|
||||
database.QuoteIdentifierDouble(table),
|
||||
database.QuoteIdentifierDouble(oldName),
|
||||
database.QuoteIdentifierDouble(newName),
|
||||
)
|
||||
return m.ExecuteSQL(ctx, sql)
|
||||
}
|
||||
|
||||
// ModifyColumn modifies an existing column's definition.
|
||||
func (m *Migration) ModifyColumn(ctx context.Context, table, column string, definition *database.ColumnDefinition) error {
|
||||
// PostgreSQL requires multiple ALTER statements for different modifications
|
||||
var statements []string
|
||||
|
||||
if definition.Type != "" {
|
||||
statements = append(statements, fmt.Sprintf(
|
||||
"ALTER TABLE %s ALTER COLUMN %s TYPE %s",
|
||||
database.QuoteIdentifierDouble(table),
|
||||
database.QuoteIdentifierDouble(column),
|
||||
definition.Type,
|
||||
))
|
||||
}
|
||||
|
||||
if definition.Default != nil {
|
||||
defaultValue := formatDefaultValue(definition.Default)
|
||||
statements = append(statements, fmt.Sprintf(
|
||||
"ALTER TABLE %s ALTER COLUMN %s SET DEFAULT %s",
|
||||
database.QuoteIdentifierDouble(table),
|
||||
database.QuoteIdentifierDouble(column),
|
||||
defaultValue,
|
||||
))
|
||||
} else {
|
||||
statements = append(statements, fmt.Sprintf(
|
||||
"ALTER TABLE %s ALTER COLUMN %s DROP DEFAULT",
|
||||
database.QuoteIdentifierDouble(table),
|
||||
database.QuoteIdentifierDouble(column),
|
||||
))
|
||||
}
|
||||
|
||||
if !definition.Null {
|
||||
statements = append(statements, fmt.Sprintf(
|
||||
"ALTER TABLE %s ALTER COLUMN %s SET NOT NULL",
|
||||
database.QuoteIdentifierDouble(table),
|
||||
database.QuoteIdentifierDouble(column),
|
||||
))
|
||||
} else {
|
||||
statements = append(statements, fmt.Sprintf(
|
||||
"ALTER TABLE %s ALTER COLUMN %s DROP NOT NULL",
|
||||
database.QuoteIdentifierDouble(table),
|
||||
database.QuoteIdentifierDouble(column),
|
||||
))
|
||||
}
|
||||
|
||||
// Execute all statements
|
||||
for _, stmt := range statements {
|
||||
if err := m.ExecuteSQL(ctx, stmt); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// HasColumn checks if a column exists in a table.
|
||||
func (m *Migration) HasColumn(ctx context.Context, table, column string) (bool, error) {
|
||||
fields, err := m.GetDB().TableFields(ctx, table)
|
||||
if err != nil {
|
||||
return false, err
|
||||
}
|
||||
_, exists := fields[column]
|
||||
return exists, nil
|
||||
}
|
||||
|
||||
// CreateIndex creates a new index on the specified table and columns.
|
||||
func (m *Migration) CreateIndex(ctx context.Context, table, index string, columns []string, options ...database.IndexOption) error {
|
||||
var opts database.IndexOptions
|
||||
for _, opt := range options {
|
||||
opt(&opts)
|
||||
}
|
||||
|
||||
var sql strings.Builder
|
||||
sql.WriteString("CREATE ")
|
||||
|
||||
if opts.Unique {
|
||||
sql.WriteString("UNIQUE ")
|
||||
}
|
||||
|
||||
sql.WriteString("INDEX ")
|
||||
sql.WriteString(database.QuoteIdentifierDouble(index))
|
||||
sql.WriteString(" ON ")
|
||||
sql.WriteString(database.QuoteIdentifierDouble(table))
|
||||
|
||||
colList := m.BuildIndexColumnsWithDouble(columns)
|
||||
sql.WriteString(fmt.Sprintf(" (%s)", colList))
|
||||
|
||||
if opts.Using != "" {
|
||||
sql.WriteString(fmt.Sprintf(" USING %s", opts.Using))
|
||||
}
|
||||
|
||||
if opts.Comment != "" {
|
||||
// Comment will be added after index creation
|
||||
}
|
||||
|
||||
err := m.ExecuteSQL(ctx, sql.String())
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// Add comment if provided
|
||||
if opts.Comment != "" {
|
||||
commentSQL := fmt.Sprintf(
|
||||
"COMMENT ON INDEX %s IS '%s'",
|
||||
database.QuoteIdentifierDouble(index),
|
||||
escapeString(opts.Comment),
|
||||
)
|
||||
return m.ExecuteSQL(ctx, commentSQL)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// BuildIndexColumnsWithDouble builds the column list for index creation using double quotes.
|
||||
func (m *Migration) BuildIndexColumnsWithDouble(columns []string) string {
|
||||
quoted := make([]string, len(columns))
|
||||
for i, col := range columns {
|
||||
quoted[i] = database.QuoteIdentifierDouble(col)
|
||||
}
|
||||
return strings.Join(quoted, ", ")
|
||||
}
|
||||
|
||||
// DropIndex drops an existing index from a table.
|
||||
func (m *Migration) DropIndex(ctx context.Context, table, index string) error {
|
||||
sql := fmt.Sprintf("DROP INDEX %s", database.QuoteIdentifierDouble(index))
|
||||
return m.ExecuteSQL(ctx, sql)
|
||||
}
|
||||
|
||||
// HasIndex checks if an index exists on a table.
|
||||
func (m *Migration) HasIndex(ctx context.Context, table, index string) (bool, error) {
|
||||
schema := m.GetDB().GetSchema()
|
||||
if schema == "" {
|
||||
schema = "current_schema()"
|
||||
} else {
|
||||
schema = fmt.Sprintf("'%s'", schema)
|
||||
}
|
||||
|
||||
query := fmt.Sprintf(
|
||||
"SELECT COUNT(*) FROM pg_indexes WHERE schemaname = %s AND tablename = '%s' AND indexname = '%s'",
|
||||
schema,
|
||||
table,
|
||||
index,
|
||||
)
|
||||
|
||||
value, err := m.GetDB().GetValue(ctx, query)
|
||||
if err != nil {
|
||||
return false, err
|
||||
}
|
||||
|
||||
return value.Int() > 0, nil
|
||||
}
|
||||
|
||||
// CreateForeignKey creates a foreign key constraint.
|
||||
func (m *Migration) CreateForeignKey(ctx context.Context, table, constraint string, columns []string, refTable string, refColumns []string, options ...database.ForeignKeyOption) error {
|
||||
var opts database.ForeignKeyOptions
|
||||
for _, opt := range options {
|
||||
opt(&opts)
|
||||
}
|
||||
|
||||
sql := fmt.Sprintf(
|
||||
"ALTER TABLE %s ADD CONSTRAINT %s FOREIGN KEY %s REFERENCES %s %s",
|
||||
database.QuoteIdentifierDouble(table),
|
||||
database.QuoteIdentifierDouble(constraint),
|
||||
m.BuildForeignKeyColumnsWithDouble(columns),
|
||||
database.QuoteIdentifierDouble(refTable),
|
||||
m.BuildForeignKeyColumnsWithDouble(refColumns),
|
||||
)
|
||||
|
||||
if opts.OnDelete != "" {
|
||||
sql += fmt.Sprintf(" ON DELETE %s", opts.OnDelete)
|
||||
}
|
||||
if opts.OnUpdate != "" {
|
||||
sql += fmt.Sprintf(" ON UPDATE %s", opts.OnUpdate)
|
||||
}
|
||||
if opts.Deferrable {
|
||||
sql += " DEFERRABLE"
|
||||
if opts.InitiallyDeferred {
|
||||
sql += " INITIALLY DEFERRED"
|
||||
}
|
||||
}
|
||||
|
||||
return m.ExecuteSQL(ctx, sql)
|
||||
}
|
||||
|
||||
// BuildForeignKeyColumnsWithDouble builds the column list for foreign key using double quotes.
|
||||
func (m *Migration) BuildForeignKeyColumnsWithDouble(columns []string) string {
|
||||
quoted := make([]string, len(columns))
|
||||
for i, col := range columns {
|
||||
quoted[i] = database.QuoteIdentifierDouble(col)
|
||||
}
|
||||
return "(" + strings.Join(quoted, ", ") + ")"
|
||||
}
|
||||
|
||||
// DropForeignKey drops a foreign key constraint.
|
||||
func (m *Migration) DropForeignKey(ctx context.Context, table, constraint string) error {
|
||||
sql := fmt.Sprintf(
|
||||
"ALTER TABLE %s DROP CONSTRAINT %s",
|
||||
database.QuoteIdentifierDouble(table),
|
||||
database.QuoteIdentifierDouble(constraint),
|
||||
)
|
||||
return m.ExecuteSQL(ctx, sql)
|
||||
}
|
||||
|
||||
// HasForeignKey checks if a foreign key constraint exists.
|
||||
func (m *Migration) HasForeignKey(ctx context.Context, table, constraint string) (bool, error) {
|
||||
schema := m.GetDB().GetSchema()
|
||||
if schema == "" {
|
||||
schema = "current_schema()"
|
||||
} else {
|
||||
schema = fmt.Sprintf("'%s'", schema)
|
||||
}
|
||||
|
||||
query := fmt.Sprintf(
|
||||
"SELECT COUNT(*) FROM information_schema.table_constraints WHERE constraint_schema = %s AND table_name = '%s' AND constraint_name = '%s' AND constraint_type = 'FOREIGN KEY'",
|
||||
schema,
|
||||
table,
|
||||
constraint,
|
||||
)
|
||||
|
||||
value, err := m.GetDB().GetValue(ctx, query)
|
||||
if err != nil {
|
||||
return false, err
|
||||
}
|
||||
|
||||
return value.Int() > 0, nil
|
||||
}
|
||||
|
||||
// CreateSchema creates a new database schema.
|
||||
func (m *Migration) CreateSchema(ctx context.Context, schema string) error {
|
||||
sql := fmt.Sprintf("CREATE SCHEMA IF NOT EXISTS %s", database.QuoteIdentifierDouble(schema))
|
||||
return m.ExecuteSQL(ctx, sql)
|
||||
}
|
||||
|
||||
// DropSchema drops an existing database schema.
|
||||
func (m *Migration) DropSchema(ctx context.Context, schema string, cascade ...bool) error {
|
||||
sql := "DROP SCHEMA "
|
||||
if len(cascade) > 0 && cascade[0] {
|
||||
sql += "IF EXISTS "
|
||||
sql += database.QuoteIdentifierDouble(schema) + " CASCADE"
|
||||
} else {
|
||||
sql += "IF EXISTS " + database.QuoteIdentifierDouble(schema)
|
||||
}
|
||||
return m.ExecuteSQL(ctx, sql)
|
||||
}
|
||||
|
||||
// HasSchema checks if a schema exists.
|
||||
func (m *Migration) HasSchema(ctx context.Context, schema string) (bool, error) {
|
||||
query := fmt.Sprintf(
|
||||
"SELECT COUNT(*) FROM information_schema.schemata WHERE schema_name = '%s'",
|
||||
schema,
|
||||
)
|
||||
|
||||
value, err := m.GetDB().GetValue(ctx, query)
|
||||
if err != nil {
|
||||
return false, err
|
||||
}
|
||||
|
||||
return value.Int() > 0, nil
|
||||
}
|
||||
|
||||
// formatDefaultValue formats the default value for SQL.
|
||||
func formatDefaultValue(value any) string {
|
||||
switch v := value.(type) {
|
||||
case string:
|
||||
return fmt.Sprintf("'%s'", escapeString(v))
|
||||
case int, int8, int16, int32, int64, uint, uint8, uint16, uint32, uint64:
|
||||
return fmt.Sprintf("%d", v)
|
||||
case float32, float64:
|
||||
return fmt.Sprintf("%f", v)
|
||||
case bool:
|
||||
if v {
|
||||
return "TRUE"
|
||||
}
|
||||
return "FALSE"
|
||||
case nil:
|
||||
return "NULL"
|
||||
default:
|
||||
return fmt.Sprintf("'%v'", v)
|
||||
}
|
||||
}
|
||||
|
||||
// escapeString escapes special characters in strings for SQL.
|
||||
func escapeString(s string) string {
|
||||
s = strings.ReplaceAll(s, "'", "''")
|
||||
return s
|
||||
}
|
||||
|
|
@ -45,3 +45,13 @@ func (d *Driver) New(core *database.Core, node *database.ConfigNode) (database.D
|
|||
func (d *Driver) GetChars() (charLeft string, charRight string) {
|
||||
return quoteChar, quoteChar
|
||||
}
|
||||
|
||||
// Migration returns a Migration instance for SQLite database operations.
|
||||
func (d *Driver) Migration() *Migration {
|
||||
return NewMigration(d)
|
||||
}
|
||||
|
||||
// GetMigration returns a Migration instance implementing the Migration interface.
|
||||
func (d *Driver) GetMigration() database.Migration {
|
||||
return d.Migration()
|
||||
}
|
||||
|
|
|
|||
|
|
@ -0,0 +1,322 @@
|
|||
// Copyright GoFrame Author(https://goframe.org). All Rights Reserved.
|
||||
//
|
||||
// This Source Code Form is subject to the terms of the MIT License.
|
||||
// If a copy of the MIT was not distributed with this file,
|
||||
// You can obtain one at https://github.com/gogf/gf.
|
||||
|
||||
package sqlite
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"strings"
|
||||
|
||||
"git.magicany.cc/black1552/gin-base/database"
|
||||
)
|
||||
|
||||
// Migration implements database migration operations for SQLite.
|
||||
type Migration struct {
|
||||
*database.MigrationCore
|
||||
*database.AutoMigrateCore
|
||||
}
|
||||
|
||||
// NewMigration creates a new SQLite Migration instance.
|
||||
func NewMigration(db database.DB) *Migration {
|
||||
return &Migration{
|
||||
MigrationCore: database.NewMigrationCore(db),
|
||||
AutoMigrateCore: database.NewAutoMigrateCore(db),
|
||||
}
|
||||
}
|
||||
|
||||
// CreateTable creates a new table with the given name and column definitions.
|
||||
func (m *Migration) CreateTable(ctx context.Context, table string, columns map[string]*database.ColumnDefinition, options ...database.TableOption) error {
|
||||
if len(columns) == 0 {
|
||||
return fmt.Errorf("cannot create table without columns")
|
||||
}
|
||||
|
||||
var opts database.TableOptions
|
||||
for _, opt := range options {
|
||||
opt(&opts)
|
||||
}
|
||||
|
||||
var sql strings.Builder
|
||||
sql.WriteString("CREATE TABLE ")
|
||||
if opts.IfNotExists {
|
||||
sql.WriteString("IF NOT EXISTS ")
|
||||
}
|
||||
sql.WriteString(database.QuoteIdentifier(table))
|
||||
sql.WriteString(" (\n")
|
||||
|
||||
// Add columns
|
||||
var colDefs []string
|
||||
var primaryKeys []string
|
||||
for name, def := range columns {
|
||||
colDef := m.buildColumnDefinition(name, def)
|
||||
if def.PrimaryKey {
|
||||
primaryKeys = append(primaryKeys, database.QuoteIdentifier(name))
|
||||
}
|
||||
colDefs = append(colDefs, " "+colDef)
|
||||
}
|
||||
|
||||
// Add primary key constraint if needed
|
||||
if len(primaryKeys) > 0 {
|
||||
colDefs = append(colDefs, fmt.Sprintf(" PRIMARY KEY (%s)", strings.Join(primaryKeys, ", ")))
|
||||
}
|
||||
|
||||
sql.WriteString(strings.Join(colDefs, ",\n"))
|
||||
sql.WriteString("\n)")
|
||||
|
||||
return m.ExecuteSQL(ctx, sql.String())
|
||||
}
|
||||
|
||||
// buildColumnDefinition builds column definition for SQLite.
|
||||
func (m *Migration) buildColumnDefinition(name string, def *database.ColumnDefinition) string {
|
||||
var parts []string
|
||||
parts = append(parts, database.QuoteIdentifier(name))
|
||||
|
||||
// Handle SQLite-specific types
|
||||
dbType := def.Type
|
||||
if def.AutoIncrement && def.PrimaryKey {
|
||||
if dbType == "INT" || dbType == "INTEGER" {
|
||||
dbType = "INTEGER"
|
||||
}
|
||||
}
|
||||
parts = append(parts, dbType)
|
||||
|
||||
if def.PrimaryKey && def.AutoIncrement {
|
||||
parts = append(parts, "PRIMARY KEY AUTOINCREMENT")
|
||||
} else {
|
||||
if !def.Null {
|
||||
parts = append(parts, "NOT NULL")
|
||||
}
|
||||
|
||||
if def.Unique && !def.PrimaryKey {
|
||||
parts = append(parts, "UNIQUE")
|
||||
}
|
||||
|
||||
if def.Default != nil {
|
||||
defaultValue := formatDefaultValue(def.Default)
|
||||
parts = append(parts, fmt.Sprintf("DEFAULT %s", defaultValue))
|
||||
}
|
||||
}
|
||||
|
||||
return strings.Join(parts, " ")
|
||||
}
|
||||
|
||||
// DropTable drops an existing table from the database.
|
||||
func (m *Migration) DropTable(ctx context.Context, table string, ifExists ...bool) error {
|
||||
sql := "DROP TABLE "
|
||||
if len(ifExists) > 0 && ifExists[0] {
|
||||
sql += "IF EXISTS "
|
||||
}
|
||||
sql += database.QuoteIdentifier(table)
|
||||
return m.ExecuteSQL(ctx, sql)
|
||||
}
|
||||
|
||||
// HasTable checks if a table exists in the database.
|
||||
func (m *Migration) HasTable(ctx context.Context, table string) (bool, error) {
|
||||
query := fmt.Sprintf(
|
||||
"SELECT COUNT(*) FROM sqlite_master WHERE type='table' AND name='%s'",
|
||||
table,
|
||||
)
|
||||
|
||||
value, err := m.GetDB().GetValue(ctx, query)
|
||||
if err != nil {
|
||||
return false, err
|
||||
}
|
||||
|
||||
return value.Int() > 0, nil
|
||||
}
|
||||
|
||||
// RenameTable renames an existing table from oldName to newName.
|
||||
func (m *Migration) RenameTable(ctx context.Context, oldName, newName string) error {
|
||||
sql := fmt.Sprintf(
|
||||
"ALTER TABLE %s RENAME TO %s",
|
||||
database.QuoteIdentifier(oldName),
|
||||
database.QuoteIdentifier(newName),
|
||||
)
|
||||
return m.ExecuteSQL(ctx, sql)
|
||||
}
|
||||
|
||||
// TruncateTable removes all records from a table but keeps the table structure.
|
||||
func (m *Migration) TruncateTable(ctx context.Context, table string) error {
|
||||
sql := fmt.Sprintf("DELETE FROM %s", database.QuoteIdentifier(table))
|
||||
return m.ExecuteSQL(ctx, sql)
|
||||
}
|
||||
|
||||
// AddColumn adds a new column to an existing table.
|
||||
func (m *Migration) AddColumn(ctx context.Context, table, column string, definition *database.ColumnDefinition) error {
|
||||
colDef := m.buildColumnDefinition(column, definition)
|
||||
sql := fmt.Sprintf("ALTER TABLE %s ADD COLUMN %s", database.QuoteIdentifier(table), colDef)
|
||||
return m.ExecuteSQL(ctx, sql)
|
||||
}
|
||||
|
||||
// DropColumn removes a column from an existing table.
|
||||
// Note: SQLite has limited ALTER TABLE support. This may require table recreation.
|
||||
func (m *Migration) DropColumn(ctx context.Context, table, column string) error {
|
||||
// SQLite 3.35.0+ supports DROP COLUMN
|
||||
sql := fmt.Sprintf(
|
||||
"ALTER TABLE %s DROP COLUMN %s",
|
||||
database.QuoteIdentifier(table),
|
||||
database.QuoteIdentifier(column),
|
||||
)
|
||||
return m.ExecuteSQL(ctx, sql)
|
||||
}
|
||||
|
||||
// RenameColumn renames a column in an existing table.
|
||||
func (m *Migration) RenameColumn(ctx context.Context, table, oldName, newName string) error {
|
||||
sql := fmt.Sprintf(
|
||||
"ALTER TABLE %s RENAME COLUMN %s TO %s",
|
||||
database.QuoteIdentifier(table),
|
||||
database.QuoteIdentifier(oldName),
|
||||
database.QuoteIdentifierDouble(newName),
|
||||
)
|
||||
return m.ExecuteSQL(ctx, sql)
|
||||
}
|
||||
|
||||
// ModifyColumn modifies an existing column's definition.
|
||||
// Note: SQLite requires table recreation for most column modifications.
|
||||
func (m *Migration) ModifyColumn(ctx context.Context, table, column string, definition *database.ColumnDefinition) error {
|
||||
// SQLite has very limited ALTER TABLE support
|
||||
// This would typically require recreating the table
|
||||
return fmt.Errorf("SQLite does not support MODIFY COLUMN directly, table recreation required")
|
||||
}
|
||||
|
||||
// HasColumn checks if a column exists in a table.
|
||||
func (m *Migration) HasColumn(ctx context.Context, table, column string) (bool, error) {
|
||||
fields, err := m.GetDB().TableFields(ctx, table)
|
||||
if err != nil {
|
||||
return false, err
|
||||
}
|
||||
_, exists := fields[column]
|
||||
return exists, nil
|
||||
}
|
||||
|
||||
// CreateIndex creates a new index on the specified table and columns.
|
||||
func (m *Migration) CreateIndex(ctx context.Context, table, index string, columns []string, options ...database.IndexOption) error {
|
||||
var opts database.IndexOptions
|
||||
for _, opt := range options {
|
||||
opt(&opts)
|
||||
}
|
||||
|
||||
var sql strings.Builder
|
||||
sql.WriteString("CREATE ")
|
||||
|
||||
if opts.Unique {
|
||||
sql.WriteString("UNIQUE ")
|
||||
}
|
||||
|
||||
sql.WriteString("INDEX ")
|
||||
if opts.IfNotExists {
|
||||
sql.WriteString("IF NOT EXISTS ")
|
||||
}
|
||||
sql.WriteString(database.QuoteIdentifier(index))
|
||||
sql.WriteString(" ON ")
|
||||
sql.WriteString(database.QuoteIdentifier(table))
|
||||
|
||||
colList := m.BuildIndexColumns(columns)
|
||||
sql.WriteString(fmt.Sprintf(" (%s)", colList))
|
||||
|
||||
return m.ExecuteSQL(ctx, sql.String())
|
||||
}
|
||||
|
||||
// DropIndex drops an existing index from a table.
|
||||
func (m *Migration) DropIndex(ctx context.Context, table, index string) error {
|
||||
sql := fmt.Sprintf("DROP INDEX IF EXISTS %s", database.QuoteIdentifier(index))
|
||||
return m.ExecuteSQL(ctx, sql)
|
||||
}
|
||||
|
||||
// HasIndex checks if an index exists on a table.
|
||||
func (m *Migration) HasIndex(ctx context.Context, table, index string) (bool, error) {
|
||||
query := fmt.Sprintf(
|
||||
"SELECT COUNT(*) FROM sqlite_master WHERE type='index' AND name='%s' AND tbl_name='%s'",
|
||||
index,
|
||||
table,
|
||||
)
|
||||
|
||||
value, err := m.GetDB().GetValue(ctx, query)
|
||||
if err != nil {
|
||||
return false, err
|
||||
}
|
||||
|
||||
return value.Int() > 0, nil
|
||||
}
|
||||
|
||||
// CreateForeignKey creates a foreign key constraint.
|
||||
// Note: SQLite requires foreign keys to be enabled with PRAGMA foreign_keys = ON
|
||||
func (m *Migration) CreateForeignKey(ctx context.Context, table, constraint string, columns []string, refTable string, refColumns []string, options ...database.ForeignKeyOption) error {
|
||||
// SQLite doesn't support adding foreign keys to existing tables
|
||||
// Foreign keys must be defined during table creation
|
||||
return fmt.Errorf("SQLite does not support adding foreign keys to existing tables")
|
||||
}
|
||||
|
||||
// DropForeignKey drops a foreign key constraint.
|
||||
func (m *Migration) DropForeignKey(ctx context.Context, table, constraint string) error {
|
||||
// SQLite doesn't support dropping foreign keys from existing tables
|
||||
return fmt.Errorf("SQLite does not support dropping foreign keys from existing tables")
|
||||
}
|
||||
|
||||
// HasForeignKey checks if a foreign key constraint exists.
|
||||
func (m *Migration) HasForeignKey(ctx context.Context, table, constraint string) (bool, error) {
|
||||
// Query pragma to check foreign keys
|
||||
query := fmt.Sprintf("PRAGMA foreign_key_list(%s)", database.QuoteIdentifier(table))
|
||||
|
||||
result, err := m.GetDB().GetAll(ctx, query)
|
||||
if err != nil {
|
||||
return false, err
|
||||
}
|
||||
|
||||
// Check if constraint exists in the result
|
||||
for _, row := range result {
|
||||
if id, ok := row["id"]; ok && id.String() == constraint {
|
||||
return true, nil
|
||||
}
|
||||
}
|
||||
|
||||
return false, nil
|
||||
}
|
||||
|
||||
// CreateSchema is not applicable for SQLite (single database file).
|
||||
func (m *Migration) CreateSchema(ctx context.Context, schema string) error {
|
||||
// SQLite doesn't support schemas
|
||||
return fmt.Errorf("SQLite does not support schemas")
|
||||
}
|
||||
|
||||
// DropSchema is not applicable for SQLite (single database file).
|
||||
func (m *Migration) DropSchema(ctx context.Context, schema string, cascade ...bool) error {
|
||||
// SQLite doesn't support schemas
|
||||
return fmt.Errorf("SQLite does not support schemas")
|
||||
}
|
||||
|
||||
// HasSchema checks if a schema exists (always returns false for SQLite).
|
||||
func (m *Migration) HasSchema(ctx context.Context, schema string) (bool, error) {
|
||||
// SQLite doesn't support schemas
|
||||
return false, nil
|
||||
}
|
||||
|
||||
// formatDefaultValue formats the default value for SQL.
|
||||
func formatDefaultValue(value any) string {
|
||||
switch v := value.(type) {
|
||||
case string:
|
||||
return fmt.Sprintf("'%s'", escapeString(v))
|
||||
case int, int8, int16, int32, int64, uint, uint8, uint16, uint32, uint64:
|
||||
return fmt.Sprintf("%d", v)
|
||||
case float32, float64:
|
||||
return fmt.Sprintf("%f", v)
|
||||
case bool:
|
||||
if v {
|
||||
return "1"
|
||||
}
|
||||
return "0"
|
||||
case nil:
|
||||
return "NULL"
|
||||
default:
|
||||
return fmt.Sprintf("'%v'", v)
|
||||
}
|
||||
}
|
||||
|
||||
// escapeString escapes special characters in strings for SQL.
|
||||
func escapeString(s string) string {
|
||||
s = strings.ReplaceAll(s, "'", "''")
|
||||
return s
|
||||
}
|
||||
|
|
@ -319,6 +319,10 @@ type DB interface {
|
|||
// If no schema is specified, it uses the default schema.
|
||||
Tables(ctx context.Context, schema ...string) (tables []string, err error)
|
||||
|
||||
// HasTable checks if a table exists in the database.
|
||||
// It returns true if the table exists, false otherwise.
|
||||
HasTable(ctx context.Context, table string) (bool, error)
|
||||
|
||||
// TableFields returns detailed information about all fields in the specified table.
|
||||
// The returned map keys are field names and values contain field metadata.
|
||||
TableFields(ctx context.Context, table string, schema ...string) (map[string]*TableField, error)
|
||||
|
|
@ -345,6 +349,14 @@ type DB interface {
|
|||
// OrderRandomFunction returns the SQL function for random ordering.
|
||||
// The implementation is database-specific (e.g., RAND() for MySQL).
|
||||
OrderRandomFunction() string
|
||||
|
||||
// ===========================================================================
|
||||
// Migration support.
|
||||
// ===========================================================================
|
||||
|
||||
// GetMigration returns a Migration instance for database schema operations.
|
||||
// The returned Migration can be used to create, alter, and drop tables, columns, indexes, etc.
|
||||
GetMigration() Migration
|
||||
}
|
||||
|
||||
// TX defines the interfaces for ORM transaction operations.
|
||||
|
|
|
|||
|
|
@ -735,7 +735,7 @@ func (c *Core) writeSqlToLogger(sql *Sql) {
|
|||
}
|
||||
|
||||
// HasTable determine whether the table name exists in the database.
|
||||
func (c *Core) HasTable(name string) (bool, error) {
|
||||
func (c *Core) HasTable(ctx context.Context, name string) (bool, error) {
|
||||
tables, err := c.GetTablesWithCache()
|
||||
if err != nil {
|
||||
return false, err
|
||||
|
|
|
|||
|
|
@ -44,3 +44,8 @@ func (d *DriverDefault) PingMaster() error {
|
|||
func (d *DriverDefault) PingSlave() error {
|
||||
return nil
|
||||
}
|
||||
|
||||
// GetMigration returns a Migration instance. For default driver, it returns nil.
|
||||
func (d *DriverDefault) GetMigration() Migration {
|
||||
return nil
|
||||
}
|
||||
|
|
|
|||
|
|
@ -527,7 +527,7 @@ func formatWhereHolder(ctx context.Context, db DB, in formatWhereHolderInput) (n
|
|||
)
|
||||
// If `Prefix` is given, it checks and retrieves the table name.
|
||||
if in.Prefix != "" {
|
||||
hasTable, _ := db.GetCore().HasTable(in.Prefix)
|
||||
hasTable, _ := db.GetCore().HasTable(ctx, in.Prefix)
|
||||
if hasTable {
|
||||
in.Table = in.Prefix
|
||||
} else {
|
||||
|
|
|
|||
|
|
@ -0,0 +1,304 @@
|
|||
// Copyright GoFrame Author(https://goframe.org). All Rights Reserved.
|
||||
//
|
||||
// This Source Code Form is subject to the terms of the MIT License.
|
||||
// If a copy of the MIT was not distributed with this file,
|
||||
// You can obtain one at https://github.com/gogf/gf.
|
||||
|
||||
package database
|
||||
|
||||
import (
|
||||
"context"
|
||||
)
|
||||
|
||||
// Migration defines the interface for database migration operations.
|
||||
// It provides methods for creating, altering, and dropping database schema objects.
|
||||
type Migration interface {
|
||||
// ===========================================================================
|
||||
// Auto Migration based on Entity Structs
|
||||
// ===========================================================================
|
||||
|
||||
// AutoMigrate automatically creates or updates tables based on entity structs.
|
||||
// It analyzes the struct fields and their tags to determine column definitions.
|
||||
// The entities parameter accepts struct instances or pointers to structs.
|
||||
AutoMigrate(ctx context.Context, entities ...any) error
|
||||
|
||||
// ===========================================================================
|
||||
// Table Operations
|
||||
// ===========================================================================
|
||||
|
||||
// CreateTable creates a new table with the given name and column definitions.
|
||||
// The columns parameter is a map where keys are column names and values are column definitions.
|
||||
CreateTable(ctx context.Context, table string, columns map[string]*ColumnDefinition, options ...TableOption) error
|
||||
|
||||
// DropTable drops an existing table from the database.
|
||||
// If ifExists is true, it won't return an error if the table doesn't exist.
|
||||
DropTable(ctx context.Context, table string, ifExists ...bool) error
|
||||
|
||||
// HasTable checks if a table exists in the database.
|
||||
HasTable(ctx context.Context, table string) (bool, error)
|
||||
|
||||
// RenameTable renames an existing table from oldName to newName.
|
||||
RenameTable(ctx context.Context, oldName, newName string) error
|
||||
|
||||
// TruncateTable removes all records from a table but keeps the table structure.
|
||||
TruncateTable(ctx context.Context, table string) error
|
||||
|
||||
// ===========================================================================
|
||||
// Column Operations
|
||||
// ===========================================================================
|
||||
|
||||
// AddColumn adds a new column to an existing table.
|
||||
AddColumn(ctx context.Context, table, column string, definition *ColumnDefinition) error
|
||||
|
||||
// DropColumn removes a column from an existing table.
|
||||
DropColumn(ctx context.Context, table, column string) error
|
||||
|
||||
// RenameColumn renames a column in an existing table.
|
||||
RenameColumn(ctx context.Context, table, oldName, newName string) error
|
||||
|
||||
// ModifyColumn modifies an existing column's definition.
|
||||
ModifyColumn(ctx context.Context, table, column string, definition *ColumnDefinition) error
|
||||
|
||||
// HasColumn checks if a column exists in a table.
|
||||
HasColumn(ctx context.Context, table, column string) (bool, error)
|
||||
|
||||
// ===========================================================================
|
||||
// Index Operations
|
||||
// ===========================================================================
|
||||
|
||||
// CreateIndex creates a new index on the specified table and columns.
|
||||
CreateIndex(ctx context.Context, table, index string, columns []string, options ...IndexOption) error
|
||||
|
||||
// DropIndex drops an existing index from a table.
|
||||
DropIndex(ctx context.Context, table, index string) error
|
||||
|
||||
// HasIndex checks if an index exists on a table.
|
||||
HasIndex(ctx context.Context, table, index string) (bool, error)
|
||||
|
||||
// ===========================================================================
|
||||
// Foreign Key Operations
|
||||
// ===========================================================================
|
||||
|
||||
// CreateForeignKey creates a foreign key constraint.
|
||||
CreateForeignKey(ctx context.Context, table, constraint string, columns []string, refTable string, refColumns []string, options ...ForeignKeyOption) error
|
||||
|
||||
// DropForeignKey drops a foreign key constraint.
|
||||
DropForeignKey(ctx context.Context, table, constraint string) error
|
||||
|
||||
// HasForeignKey checks if a foreign key constraint exists.
|
||||
HasForeignKey(ctx context.Context, table, constraint string) (bool, error)
|
||||
|
||||
// ===========================================================================
|
||||
// Schema Operations
|
||||
// ===========================================================================
|
||||
|
||||
// CreateSchema creates a new database schema (namespace).
|
||||
CreateSchema(ctx context.Context, schema string) error
|
||||
|
||||
// DropSchema drops an existing database schema.
|
||||
DropSchema(ctx context.Context, schema string, cascade ...bool) error
|
||||
|
||||
// HasSchema checks if a schema exists.
|
||||
HasSchema(ctx context.Context, schema string) (bool, error)
|
||||
}
|
||||
|
||||
// ColumnDefinition defines the structure and properties of a database column.
|
||||
type ColumnDefinition struct {
|
||||
// Type specifies the database column type (e.g., "VARCHAR(255)", "INT", "TEXT").
|
||||
Type string
|
||||
|
||||
// Null indicates whether the column can contain NULL values.
|
||||
Null bool
|
||||
|
||||
// Default sets the default value for the column.
|
||||
Default any
|
||||
|
||||
// Comment adds a comment to the column.
|
||||
Comment string
|
||||
|
||||
// AutoIncrement enables auto-increment for numeric columns.
|
||||
AutoIncrement bool
|
||||
|
||||
// PrimaryKey marks this column as part of the primary key.
|
||||
PrimaryKey bool
|
||||
|
||||
// Unique marks this column as having unique values.
|
||||
Unique bool
|
||||
|
||||
// Length specifies the length for string types.
|
||||
Length int
|
||||
|
||||
// Precision specifies the precision for decimal/numeric types.
|
||||
Precision int
|
||||
|
||||
// Scale specifies the scale for decimal/numeric types.
|
||||
Scale int
|
||||
}
|
||||
|
||||
// TableOption defines additional options for table creation.
|
||||
type TableOption func(*TableOptions)
|
||||
|
||||
// TableOptions defines additional options for table creation.
|
||||
type TableOptions struct {
|
||||
// Engine specifies the storage engine (MySQL specific).
|
||||
Engine string
|
||||
|
||||
// Charset specifies the character set.
|
||||
Charset string
|
||||
|
||||
// Collation specifies the collation.
|
||||
Collation string
|
||||
|
||||
// Comment adds a comment to the table.
|
||||
Comment string
|
||||
|
||||
// IfNotExists prevents error if table already exists.
|
||||
IfNotExists bool
|
||||
}
|
||||
|
||||
// WithEngine sets the storage engine for the table (MySQL specific).
|
||||
func WithEngine(engine string) TableOption {
|
||||
return func(opts *TableOptions) {
|
||||
opts.Engine = engine
|
||||
}
|
||||
}
|
||||
|
||||
// WithCharset sets the character set for the table.
|
||||
func WithCharset(charset string) TableOption {
|
||||
return func(opts *TableOptions) {
|
||||
opts.Charset = charset
|
||||
}
|
||||
}
|
||||
|
||||
// WithCollation sets the collation for the table.
|
||||
func WithCollation(collation string) TableOption {
|
||||
return func(opts *TableOptions) {
|
||||
opts.Collation = collation
|
||||
}
|
||||
}
|
||||
|
||||
// WithTableComment adds a comment to the table.
|
||||
func WithTableComment(comment string) TableOption {
|
||||
return func(opts *TableOptions) {
|
||||
opts.Comment = comment
|
||||
}
|
||||
}
|
||||
|
||||
// WithIfNotExists prevents error if table already exists.
|
||||
func WithIfNotExists() TableOption {
|
||||
return func(opts *TableOptions) {
|
||||
opts.IfNotExists = true
|
||||
}
|
||||
}
|
||||
|
||||
// IndexOption defines additional options for index creation.
|
||||
type IndexOption func(*IndexOptions)
|
||||
|
||||
// IndexOptions defines additional options for index creation.
|
||||
type IndexOptions struct {
|
||||
// Unique marks the index as unique.
|
||||
Unique bool
|
||||
|
||||
// FullText creates a full-text index (MySQL specific).
|
||||
FullText bool
|
||||
|
||||
// Spatial creates a spatial index (MySQL specific).
|
||||
Spatial bool
|
||||
|
||||
// Using specifies the index method (e.g., BTREE, HASH).
|
||||
Using string
|
||||
|
||||
// Comment adds a comment to the index.
|
||||
Comment string
|
||||
|
||||
// IfNotExists prevents error if index already exists.
|
||||
IfNotExists bool
|
||||
}
|
||||
|
||||
// WithUniqueIndex creates a unique index.
|
||||
func WithUniqueIndex() IndexOption {
|
||||
return func(opts *IndexOptions) {
|
||||
opts.Unique = true
|
||||
}
|
||||
}
|
||||
|
||||
// WithFullTextIndex creates a full-text index (MySQL specific).
|
||||
func WithFullTextIndex() IndexOption {
|
||||
return func(opts *IndexOptions) {
|
||||
opts.FullText = true
|
||||
}
|
||||
}
|
||||
|
||||
// WithSpatialIndex creates a spatial index (MySQL specific).
|
||||
func WithSpatialIndex() IndexOption {
|
||||
return func(opts *IndexOptions) {
|
||||
opts.Spatial = true
|
||||
}
|
||||
}
|
||||
|
||||
// WithIndexUsing specifies the index method.
|
||||
func WithIndexUsing(method string) IndexOption {
|
||||
return func(opts *IndexOptions) {
|
||||
opts.Using = method
|
||||
}
|
||||
}
|
||||
|
||||
// WithIndexComment adds a comment to the index.
|
||||
func WithIndexComment(comment string) IndexOption {
|
||||
return func(opts *IndexOptions) {
|
||||
opts.Comment = comment
|
||||
}
|
||||
}
|
||||
|
||||
// WithIndexIfNotExists prevents error if index already exists.
|
||||
func WithIndexIfNotExists() IndexOption {
|
||||
return func(opts *IndexOptions) {
|
||||
opts.IfNotExists = true
|
||||
}
|
||||
}
|
||||
|
||||
// ForeignKeyOption defines additional options for foreign key creation.
|
||||
type ForeignKeyOption func(*ForeignKeyOptions)
|
||||
|
||||
// ForeignKeyOptions defines additional options for foreign key creation.
|
||||
type ForeignKeyOptions struct {
|
||||
// OnDelete specifies the action when referenced row is deleted.
|
||||
OnDelete string
|
||||
|
||||
// OnUpdate specifies the action when referenced row is updated.
|
||||
OnUpdate string
|
||||
|
||||
// Deferrable makes the constraint deferrable (PostgreSQL specific).
|
||||
Deferrable bool
|
||||
|
||||
// InitiallyDeferred sets the constraint to be initially deferred (PostgreSQL specific).
|
||||
InitiallyDeferred bool
|
||||
}
|
||||
|
||||
// WithOnDelete sets the ON DELETE action for foreign key.
|
||||
func WithOnDelete(action string) ForeignKeyOption {
|
||||
return func(opts *ForeignKeyOptions) {
|
||||
opts.OnDelete = action
|
||||
}
|
||||
}
|
||||
|
||||
// WithOnUpdate sets the ON UPDATE action for foreign key.
|
||||
func WithOnUpdate(action string) ForeignKeyOption {
|
||||
return func(opts *ForeignKeyOptions) {
|
||||
opts.OnUpdate = action
|
||||
}
|
||||
}
|
||||
|
||||
// WithDeferrable makes the foreign key constraint deferrable (PostgreSQL specific).
|
||||
func WithDeferrable() ForeignKeyOption {
|
||||
return func(opts *ForeignKeyOptions) {
|
||||
opts.Deferrable = true
|
||||
}
|
||||
}
|
||||
|
||||
// WithInitiallyDeferred sets the constraint to be initially deferred (PostgreSQL specific).
|
||||
func WithInitiallyDeferred() ForeignKeyOption {
|
||||
return func(opts *ForeignKeyOptions) {
|
||||
opts.InitiallyDeferred = true
|
||||
}
|
||||
}
|
||||
|
|
@ -0,0 +1,452 @@
|
|||
// Copyright GoFrame Author(https://goframe.org). All Rights Reserved.
|
||||
//
|
||||
// This Source Code Form is subject to the terms of the MIT License.
|
||||
// If a copy of the MIT was not distributed with this file,
|
||||
// You can obtain one at https://github.com/gogf/gf.
|
||||
|
||||
package database
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"reflect"
|
||||
"strings"
|
||||
"time"
|
||||
)
|
||||
|
||||
// AutoMigrateCore provides automatic migration functionality based on entity structs.
|
||||
type AutoMigrateCore struct {
|
||||
db DB
|
||||
}
|
||||
|
||||
// NewAutoMigrateCore creates a new AutoMigrateCore instance.
|
||||
func NewAutoMigrateCore(db DB) *AutoMigrateCore {
|
||||
return &AutoMigrateCore{db: db}
|
||||
}
|
||||
|
||||
// AutoMigrate automatically creates or updates tables based on entity structs.
|
||||
func (am *AutoMigrateCore) AutoMigrate(ctx context.Context, entities ...any) error {
|
||||
for _, entity := range entities {
|
||||
if err := am.migrateEntity(ctx, entity); err != nil {
|
||||
return fmt.Errorf("failed to migrate entity %T: %w", entity, err)
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// migrateEntity migrates a single entity to database.
|
||||
func (am *AutoMigrateCore) migrateEntity(ctx context.Context, entity any) error {
|
||||
// Get table name and columns from entity
|
||||
tableName, columns, err := am.parseEntity(entity)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if len(columns) == 0 {
|
||||
return fmt.Errorf("no columns found for table %s", tableName)
|
||||
}
|
||||
|
||||
// Check if table exists
|
||||
hasTable, err := am.db.HasTable(ctx, tableName)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to check table existence: %w", err)
|
||||
}
|
||||
|
||||
if !hasTable {
|
||||
// Create table
|
||||
return am.createTableFromColumns(ctx, tableName, columns)
|
||||
}
|
||||
|
||||
// Update table structure
|
||||
return am.updateTableStructure(ctx, tableName, columns)
|
||||
}
|
||||
|
||||
// parseEntity parses an entity struct and returns table name and column definitions.
|
||||
func (am *AutoMigrateCore) parseEntity(entity any) (string, map[string]*ColumnDefinition, error) {
|
||||
val := reflect.ValueOf(entity)
|
||||
|
||||
// Handle pointer
|
||||
if val.Kind() == reflect.Ptr {
|
||||
val = val.Elem()
|
||||
}
|
||||
|
||||
if val.Kind() != reflect.Struct {
|
||||
return "", nil, fmt.Errorf("entity must be a struct or pointer to struct")
|
||||
}
|
||||
|
||||
typ := val.Type()
|
||||
|
||||
// Get table name from orm tag or struct name
|
||||
tableName := am.getTableName(typ)
|
||||
|
||||
// Parse columns
|
||||
columns := make(map[string]*ColumnDefinition)
|
||||
for i := 0; i < val.NumField(); i++ {
|
||||
field := typ.Field(i)
|
||||
fieldValue := val.Field(i)
|
||||
|
||||
// Skip unexported fields
|
||||
if !field.IsExported() {
|
||||
continue
|
||||
}
|
||||
|
||||
// Parse field to column definition
|
||||
colName, colDef, err := am.parseField(field, fieldValue)
|
||||
if err != nil {
|
||||
return "", nil, fmt.Errorf("failed to parse field %s: %w", field.Name, err)
|
||||
}
|
||||
|
||||
if colName != "" && colDef != nil {
|
||||
columns[colName] = colDef
|
||||
}
|
||||
}
|
||||
|
||||
return tableName, columns, nil
|
||||
}
|
||||
|
||||
// getTableName extracts table name from struct tags or generates from struct name.
|
||||
func (am *AutoMigrateCore) getTableName(typ reflect.Type) string {
|
||||
// Check for orm tag
|
||||
if tag, ok := typ.FieldByName("Meta"); ok {
|
||||
ormTag := tag.Tag.Get("orm")
|
||||
if ormTag != "" {
|
||||
// Parse table:name from orm tag
|
||||
parts := strings.Split(ormTag, ",")
|
||||
for _, part := range parts {
|
||||
part = strings.TrimSpace(part)
|
||||
if strings.HasPrefix(part, "table:") {
|
||||
return strings.TrimPrefix(part, "table:")
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Check for table method (commented out as it requires instance)
|
||||
// method, hasMethod := typ.MethodByName("TableName")
|
||||
// if hasMethod {
|
||||
// // Try to call TableName method if it exists
|
||||
// }
|
||||
|
||||
// Convert struct name to snake_case table name
|
||||
return camelToSnake(typ.Name())
|
||||
}
|
||||
|
||||
// parseField parses a struct field into a column definition.
|
||||
func (am *AutoMigrateCore) parseField(field reflect.StructField, fieldValue reflect.Value) (string, *ColumnDefinition, error) {
|
||||
// Check orm tag
|
||||
ormTag := field.Tag.Get("orm")
|
||||
if ormTag == "-" {
|
||||
// Skip field
|
||||
return "", nil, nil
|
||||
}
|
||||
|
||||
// Get column name
|
||||
colName := am.getColumnName(field, ormTag)
|
||||
|
||||
// Build column definition
|
||||
colDef := &ColumnDefinition{
|
||||
Type: am.getFieldType(field, fieldValue),
|
||||
Null: true, // Default to nullable
|
||||
}
|
||||
|
||||
// Parse orm tag options
|
||||
if ormTag != "" {
|
||||
am.parseOrmTag(colDef, ormTag)
|
||||
}
|
||||
|
||||
// Check for gorm/tag compatibility
|
||||
if gormTag := field.Tag.Get("gorm"); gormTag != "" {
|
||||
am.parseGormTag(colDef, gormTag)
|
||||
}
|
||||
|
||||
// Check json tag for field presence
|
||||
jsonTag := field.Tag.Get("json")
|
||||
if jsonTag == "-" {
|
||||
return "", nil, nil
|
||||
}
|
||||
|
||||
return colName, colDef, nil
|
||||
}
|
||||
|
||||
// getColumnName extracts column name from field name or tags.
|
||||
func (am *AutoMigrateCore) getColumnName(field reflect.StructField, ormTag string) string {
|
||||
// Check orm tag for explicit column name
|
||||
if ormTag != "" {
|
||||
parts := strings.Split(ormTag, ",")
|
||||
namePart := strings.TrimSpace(parts[0])
|
||||
if namePart != "" && !strings.Contains(namePart, ":") {
|
||||
return namePart
|
||||
}
|
||||
}
|
||||
|
||||
// Use field name converted to snake_case
|
||||
return camelToSnake(field.Name)
|
||||
}
|
||||
|
||||
// getFieldType determines the database type for a Go field type.
|
||||
func (am *AutoMigrateCore) getFieldType(field reflect.StructField, fieldValue reflect.Value) string {
|
||||
// Check for explicit type in orm tag
|
||||
ormTag := field.Tag.Get("orm")
|
||||
if ormTag != "" {
|
||||
parts := strings.Split(ormTag, ",")
|
||||
for _, part := range parts {
|
||||
part = strings.TrimSpace(part)
|
||||
if strings.HasPrefix(part, "type:") {
|
||||
return strings.TrimPrefix(part, "type:")
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Infer type from Go type
|
||||
switch fieldValue.Kind() {
|
||||
case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64:
|
||||
return "BIGINT"
|
||||
case reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64:
|
||||
return "BIGINT UNSIGNED"
|
||||
case reflect.Float32:
|
||||
return "FLOAT"
|
||||
case reflect.Float64:
|
||||
return "DOUBLE"
|
||||
case reflect.Bool:
|
||||
return "BOOLEAN"
|
||||
case reflect.String:
|
||||
// Check for length specification
|
||||
length := am.getStringLength(field)
|
||||
if length > 0 {
|
||||
return fmt.Sprintf("VARCHAR(%d)", length)
|
||||
}
|
||||
return "TEXT"
|
||||
case reflect.Struct:
|
||||
// Handle special types
|
||||
typeName := fieldValue.Type().String()
|
||||
switch {
|
||||
case strings.Contains(typeName, "time.Time"):
|
||||
return "TIMESTAMP"
|
||||
default:
|
||||
return "TEXT"
|
||||
}
|
||||
case reflect.Slice:
|
||||
elemKind := fieldValue.Type().Elem().Kind()
|
||||
if elemKind == reflect.Uint8 {
|
||||
return "BLOB"
|
||||
}
|
||||
return "JSON"
|
||||
default:
|
||||
return "TEXT"
|
||||
}
|
||||
}
|
||||
|
||||
// getStringLength gets the specified string length from tags.
|
||||
func (am *AutoMigrateCore) getStringLength(field reflect.StructField) int {
|
||||
// Check orm tag
|
||||
ormTag := field.Tag.Get("orm")
|
||||
if ormTag != "" {
|
||||
parts := strings.Split(ormTag, ",")
|
||||
for _, part := range parts {
|
||||
part = strings.TrimSpace(part)
|
||||
if strings.HasPrefix(part, "length:") {
|
||||
var length int
|
||||
fmt.Sscanf(strings.TrimPrefix(part, "length:"), "%d", &length)
|
||||
return length
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Check gorm tag
|
||||
gormTag := field.Tag.Get("gorm")
|
||||
if gormTag != "" {
|
||||
parts := strings.Split(gormTag, ";")
|
||||
for _, part := range parts {
|
||||
part = strings.TrimSpace(part)
|
||||
if strings.HasPrefix(part, "size:") {
|
||||
var size int
|
||||
fmt.Sscanf(strings.TrimPrefix(part, "size:"), "%d", &size)
|
||||
return size
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return 0
|
||||
}
|
||||
|
||||
// parseOrmTag parses orm tag options.
|
||||
func (am *AutoMigrateCore) parseOrmTag(colDef *ColumnDefinition, ormTag string) {
|
||||
parts := strings.Split(ormTag, ",")
|
||||
for _, part := range parts {
|
||||
part = strings.TrimSpace(part)
|
||||
switch {
|
||||
case part == "pk" || part == "primary_key":
|
||||
colDef.PrimaryKey = true
|
||||
colDef.Null = false
|
||||
case part == "auto_increment":
|
||||
colDef.AutoIncrement = true
|
||||
case part == "not_null":
|
||||
colDef.Null = false
|
||||
case part == "unique":
|
||||
colDef.Unique = true
|
||||
case strings.HasPrefix(part, "default:"):
|
||||
defaultVal := strings.TrimPrefix(part, "default:")
|
||||
colDef.Default = defaultVal
|
||||
case strings.HasPrefix(part, "comment:"):
|
||||
colDef.Comment = strings.TrimPrefix(part, "comment:")
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// parseGormTag parses gorm tag options for compatibility.
|
||||
func (am *AutoMigrateCore) parseGormTag(colDef *ColumnDefinition, gormTag string) {
|
||||
parts := strings.Split(gormTag, ";")
|
||||
for _, part := range parts {
|
||||
part = strings.TrimSpace(part)
|
||||
lowerPart := strings.ToLower(part)
|
||||
switch {
|
||||
case lowerPart == "primarykey" || lowerPart == "primaryKey":
|
||||
colDef.PrimaryKey = true
|
||||
colDef.Null = false
|
||||
case lowerPart == "autoincrement":
|
||||
colDef.AutoIncrement = true
|
||||
case lowerPart == "not null":
|
||||
colDef.Null = false
|
||||
case lowerPart == "unique":
|
||||
colDef.Unique = true
|
||||
case strings.HasPrefix(lowerPart, "default:"):
|
||||
defaultVal := strings.TrimPrefix(part, "default:")
|
||||
colDef.Default = defaultVal
|
||||
case strings.HasPrefix(lowerPart, "comment:"):
|
||||
colDef.Comment = strings.TrimPrefix(part, "comment:")
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// createTableFromColumns creates a table from column definitions.
|
||||
func (am *AutoMigrateCore) createTableFromColumns(ctx context.Context, table string, columns map[string]*ColumnDefinition) error {
|
||||
// Get the migration instance based on database type
|
||||
migration := am.getMigrationInstance()
|
||||
if migration == nil {
|
||||
return fmt.Errorf("failed to get migration instance for database type %s", am.db.GetConfig().Type)
|
||||
}
|
||||
|
||||
return migration.CreateTable(ctx, table, columns)
|
||||
}
|
||||
|
||||
// updateTableStructure updates table structure by comparing with existing columns.
|
||||
func (am *AutoMigrateCore) updateTableStructure(ctx context.Context, table string, newColumns map[string]*ColumnDefinition) error {
|
||||
// Get existing columns
|
||||
existingFields, err := am.db.TableFields(ctx, table)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to get table fields: %w", err)
|
||||
}
|
||||
|
||||
// Add missing columns
|
||||
for colName, colDef := range newColumns {
|
||||
if _, exists := existingFields[colName]; !exists {
|
||||
// Column doesn't exist, add it
|
||||
if err := am.addColumn(ctx, table, colName, colDef); err != nil {
|
||||
return fmt.Errorf("failed to add column %s: %w", colName, err)
|
||||
}
|
||||
} else {
|
||||
// Column exists, check if modification is needed
|
||||
if err := am.modifyColumnIfNeeded(ctx, table, colName, colDef, existingFields[colName]); err != nil {
|
||||
return fmt.Errorf("failed to modify column %s: %w", colName, err)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// addColumn adds a new column to table.
|
||||
func (am *AutoMigrateCore) addColumn(ctx context.Context, table, column string, definition *ColumnDefinition) error {
|
||||
migration := am.getMigrationInstance()
|
||||
if migration == nil {
|
||||
return fmt.Errorf("failed to get migration instance")
|
||||
}
|
||||
return migration.AddColumn(ctx, table, column, definition)
|
||||
}
|
||||
|
||||
// modifyColumnIfNeeded checks if column needs modification and applies it.
|
||||
func (am *AutoMigrateCore) modifyColumnIfNeeded(ctx context.Context, table, column string, newDef *ColumnDefinition, existingField *TableField) error {
|
||||
// Compare and modify if needed
|
||||
needsModification := false
|
||||
|
||||
// Check type
|
||||
if !strings.EqualFold(existingField.Type, newDef.Type) {
|
||||
needsModification = true
|
||||
}
|
||||
|
||||
// Check nullability
|
||||
if existingField.Null != newDef.Null {
|
||||
needsModification = true
|
||||
}
|
||||
|
||||
if needsModification {
|
||||
return am.modifyColumn(ctx, table, column, newDef)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// modifyColumn modifies an existing column.
|
||||
func (am *AutoMigrateCore) modifyColumn(ctx context.Context, table, column string, definition *ColumnDefinition) error {
|
||||
migration := am.getMigrationInstance()
|
||||
if migration == nil {
|
||||
return fmt.Errorf("failed to get migration instance")
|
||||
}
|
||||
return migration.ModifyColumn(ctx, table, column, definition)
|
||||
}
|
||||
|
||||
// getMigrationInstance returns the appropriate Migration instance based on database type.
|
||||
func (am *AutoMigrateCore) getMigrationInstance() Migration {
|
||||
return am.db.GetMigration()
|
||||
}
|
||||
|
||||
// camelToSnake converts CamelCase to snake_case.
|
||||
func camelToSnake(s string) string {
|
||||
var result strings.Builder
|
||||
for i, r := range s {
|
||||
if r >= 'A' && r <= 'Z' {
|
||||
// Add underscore before uppercase letter if:
|
||||
// 1. It's not the first character
|
||||
// 2. The previous character is lowercase or the next character is lowercase
|
||||
if i > 0 {
|
||||
prevRune := rune(s[i-1])
|
||||
if (prevRune >= 'a' && prevRune <= 'z') ||
|
||||
(i+1 < len(s) && rune(s[i+1]) >= 'a' && rune(s[i+1]) <= 'z') {
|
||||
result.WriteRune('_')
|
||||
}
|
||||
}
|
||||
result.WriteRune(r + 32)
|
||||
} else {
|
||||
result.WriteRune(r)
|
||||
}
|
||||
}
|
||||
return result.String()
|
||||
}
|
||||
|
||||
// FormatTime formats time for default values.
|
||||
func FormatTime(t time.Time) string {
|
||||
return t.Format("2006-01-02 15:04:05")
|
||||
}
|
||||
|
||||
// IsZeroValue checks if a value is zero value.
|
||||
func IsZeroValue(v any) bool {
|
||||
if v == nil {
|
||||
return true
|
||||
}
|
||||
rv := reflect.ValueOf(v)
|
||||
switch rv.Kind() {
|
||||
case reflect.Array, reflect.Map, reflect.Slice, reflect.String:
|
||||
return rv.Len() == 0
|
||||
case reflect.Bool:
|
||||
return !rv.Bool()
|
||||
case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64:
|
||||
return rv.Int() == 0
|
||||
case reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64, reflect.Uintptr:
|
||||
return rv.Uint() == 0
|
||||
case reflect.Float32, reflect.Float64:
|
||||
return rv.Float() == 0
|
||||
case reflect.Interface, reflect.Ptr:
|
||||
return rv.IsNil()
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
|
@ -0,0 +1,135 @@
|
|||
// Copyright GoFrame Author(https://goframe.org). All Rights Reserved.
|
||||
//
|
||||
// This Source Code Form is subject to the terms of the MIT License.
|
||||
// If a copy of the MIT was not distributed with this file,
|
||||
// You can obtain one at https://github.com/gogf/gf.
|
||||
|
||||
package database
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"strings"
|
||||
)
|
||||
|
||||
// MigrationCore provides the base implementation for migration operations.
|
||||
// It serves as a foundation that database-specific drivers can extend.
|
||||
type MigrationCore struct {
|
||||
db DB // Database interface for executing SQL statements
|
||||
}
|
||||
|
||||
// NewMigrationCore creates a new MigrationCore instance.
|
||||
func NewMigrationCore(db DB) *MigrationCore {
|
||||
return &MigrationCore{db: db}
|
||||
}
|
||||
|
||||
// GetDB returns the underlying database interface.
|
||||
func (m *MigrationCore) GetDB() DB {
|
||||
return m.db
|
||||
}
|
||||
|
||||
// ExecuteSQL executes a raw SQL statement.
|
||||
func (m *MigrationCore) ExecuteSQL(ctx context.Context, sql string, args ...any) error {
|
||||
_, err := m.db.Exec(ctx, sql, args...)
|
||||
return err
|
||||
}
|
||||
|
||||
// BuildColumnDefinition builds the column definition SQL string from ColumnDefinition.
|
||||
func (m *MigrationCore) BuildColumnDefinition(name string, def *ColumnDefinition) string {
|
||||
var parts []string
|
||||
parts = append(parts, QuoteIdentifier(name))
|
||||
parts = append(parts, def.Type)
|
||||
|
||||
if !def.Null {
|
||||
parts = append(parts, "NOT NULL")
|
||||
}
|
||||
|
||||
if def.AutoIncrement {
|
||||
parts = append(parts, "AUTO_INCREMENT")
|
||||
}
|
||||
|
||||
if def.PrimaryKey {
|
||||
parts = append(parts, "PRIMARY KEY")
|
||||
}
|
||||
|
||||
if def.Unique && !def.PrimaryKey {
|
||||
parts = append(parts, "UNIQUE")
|
||||
}
|
||||
|
||||
if def.Default != nil {
|
||||
defaultValue := formatDefaultValue(def.Default)
|
||||
parts = append(parts, fmt.Sprintf("DEFAULT %s", defaultValue))
|
||||
}
|
||||
|
||||
return strings.Join(parts, " ")
|
||||
}
|
||||
|
||||
// formatDefaultValue formats the default value for SQL.
|
||||
func formatDefaultValue(value any) string {
|
||||
switch v := value.(type) {
|
||||
case string:
|
||||
return fmt.Sprintf("'%s'", escapeString(v))
|
||||
case int, int8, int16, int32, int64, uint, uint8, uint16, uint32, uint64:
|
||||
return fmt.Sprintf("%d", v)
|
||||
case float32, float64:
|
||||
return fmt.Sprintf("%f", v)
|
||||
case bool:
|
||||
if v {
|
||||
return "TRUE"
|
||||
}
|
||||
return "FALSE"
|
||||
case nil:
|
||||
return "NULL"
|
||||
default:
|
||||
return fmt.Sprintf("'%v'", v)
|
||||
}
|
||||
}
|
||||
|
||||
// escapeString escapes special characters in strings for SQL.
|
||||
func escapeString(s string) string {
|
||||
s = strings.ReplaceAll(s, "'", "''")
|
||||
s = strings.ReplaceAll(s, "\\", "\\\\")
|
||||
return s
|
||||
}
|
||||
|
||||
// QuoteIdentifier quotes an identifier (table name, column name, etc.) with backticks.
|
||||
func QuoteIdentifier(name string) string {
|
||||
if strings.Contains(name, ".") {
|
||||
parts := strings.Split(name, ".")
|
||||
for i, part := range parts {
|
||||
parts[i] = fmt.Sprintf("`%s`", part)
|
||||
}
|
||||
return strings.Join(parts, ".")
|
||||
}
|
||||
return fmt.Sprintf("`%s`", name)
|
||||
}
|
||||
|
||||
// QuoteIdentifierDouble quotes an identifier with double quotes (for PostgreSQL, Oracle, etc.).
|
||||
func QuoteIdentifierDouble(name string) string {
|
||||
if strings.Contains(name, ".") {
|
||||
parts := strings.Split(name, ".")
|
||||
for i, part := range parts {
|
||||
parts[i] = fmt.Sprintf(`"%s"`, part)
|
||||
}
|
||||
return strings.Join(parts, ".")
|
||||
}
|
||||
return fmt.Sprintf(`"%s"`, name)
|
||||
}
|
||||
|
||||
// BuildIndexColumns builds the column list for index creation.
|
||||
func (m *MigrationCore) BuildIndexColumns(columns []string) string {
|
||||
quoted := make([]string, len(columns))
|
||||
for i, col := range columns {
|
||||
quoted[i] = QuoteIdentifier(col)
|
||||
}
|
||||
return strings.Join(quoted, ", ")
|
||||
}
|
||||
|
||||
// BuildForeignKeyColumns builds the column list for foreign key.
|
||||
func (m *MigrationCore) BuildForeignKeyColumns(columns []string) string {
|
||||
quoted := make([]string, len(columns))
|
||||
for i, col := range columns {
|
||||
quoted[i] = QuoteIdentifier(col)
|
||||
}
|
||||
return "(" + strings.Join(quoted, ", ") + ")"
|
||||
}
|
||||
|
|
@ -66,7 +66,7 @@ func (m *Model) getModel() *Model {
|
|||
func (m *Model) mappingAndFilterToTableFields(table string, fields []any, filter bool) []any {
|
||||
var fieldsTable = table
|
||||
if fieldsTable != "" {
|
||||
hasTable, _ := m.db.GetCore().HasTable(fieldsTable)
|
||||
hasTable, _ := m.db.GetCore().HasTable(m.GetCtx(), fieldsTable)
|
||||
if !hasTable {
|
||||
if fieldsTable != m.tablesInit {
|
||||
// Table/alias unknown (e.g., FieldsPrefix called before LeftJoin), skip filtering.
|
||||
|
|
|
|||
3
go.mod
3
go.mod
|
|
@ -22,6 +22,7 @@ require (
|
|||
github.com/shopspring/decimal v1.3.1
|
||||
github.com/sijms/go-ora/v2 v2.7.10
|
||||
github.com/spf13/viper v1.21.0
|
||||
github.com/stretchr/testify v1.11.1
|
||||
go.opentelemetry.io/otel v1.42.0
|
||||
go.opentelemetry.io/otel/trace v1.42.0
|
||||
golang.org/x/crypto v0.49.0
|
||||
|
|
@ -42,6 +43,7 @@ require (
|
|||
github.com/clipperhouse/displaywidth v0.10.0 // indirect
|
||||
github.com/clipperhouse/uax29/v2 v2.6.0 // indirect
|
||||
github.com/cloudwego/base64x v0.1.6 // indirect
|
||||
github.com/davecgh/go-spew v1.1.1 // indirect
|
||||
github.com/dustin/go-humanize v1.0.1 // indirect
|
||||
github.com/emirpasic/gods/v2 v2.0.0-alpha // indirect
|
||||
github.com/fatih/color v1.19.0 // indirect
|
||||
|
|
@ -78,6 +80,7 @@ require (
|
|||
github.com/paulmach/orb v0.7.1 // indirect
|
||||
github.com/pelletier/go-toml/v2 v2.3.0 // indirect
|
||||
github.com/pierrec/lz4/v4 v4.1.14 // indirect
|
||||
github.com/pmezard/go-difflib v1.0.0 // indirect
|
||||
github.com/quic-go/qpack v0.6.0 // indirect
|
||||
github.com/quic-go/quic-go v0.59.0 // indirect
|
||||
github.com/remyoudompheng/bigfft v0.0.0-20230129092748-24d4a6f8daec // indirect
|
||||
|
|
|
|||
|
|
@ -0,0 +1,397 @@
|
|||
package session
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
"github.com/google/uuid"
|
||||
)
|
||||
|
||||
// SessionManager Session管理器,负责Session的创建、获取和销毁
|
||||
type SessionManager struct {
|
||||
sessions map[string]*Session // 存储所有活跃的Session,key为SessionID
|
||||
mutex sync.RWMutex // 读写锁,保证并发安全
|
||||
maxAge time.Duration // Session最大存活时间
|
||||
}
|
||||
|
||||
// Session 单个Session对象,用于存储用户数据
|
||||
type Session struct {
|
||||
ID string // Session唯一标识符
|
||||
Data map[string]interface{} // 存储的数据,键值对形式
|
||||
CreateTime time.Time // 创建时间
|
||||
LastAccess time.Time // 最后访问时间
|
||||
mutex sync.RWMutex // 读写锁,保证并发安全
|
||||
}
|
||||
|
||||
var (
|
||||
// defaultManager 默认的Session管理器实例
|
||||
defaultManager *SessionManager
|
||||
// once 确保只初始化一次
|
||||
once sync.Once
|
||||
)
|
||||
|
||||
// init 包初始化时自动创建默认管理器
|
||||
func init() {
|
||||
InitDefaultManager(30 * time.Minute) // 默认30分钟过期
|
||||
}
|
||||
|
||||
// InitDefaultManager 初始化默认Session管理器
|
||||
// 参数 duration: Session的最大存活时间
|
||||
func InitDefaultManager(duration time.Duration) {
|
||||
once.Do(func() {
|
||||
defaultManager = NewSessionManager(duration)
|
||||
})
|
||||
}
|
||||
|
||||
// NewSessionManager 创建一个新的Session管理器
|
||||
// 参数 duration: Session的最大存活时间
|
||||
// 返回值: Session管理器指针
|
||||
func NewSessionManager(duration time.Duration) *SessionManager {
|
||||
sm := &SessionManager{
|
||||
sessions: make(map[string]*Session),
|
||||
maxAge: duration,
|
||||
}
|
||||
|
||||
// 启动定时清理任务,每5分钟清理一次过期的Session
|
||||
go sm.startCleanupTicker(5 * time.Minute)
|
||||
|
||||
return sm
|
||||
}
|
||||
|
||||
// startCleanupTicker 启动定时清理过期Session的任务
|
||||
// 参数 interval: 清理间隔时间
|
||||
func (sm *SessionManager) startCleanupTicker(interval time.Duration) {
|
||||
ticker := time.NewTicker(interval)
|
||||
defer ticker.Stop()
|
||||
|
||||
for range ticker.C {
|
||||
sm.CleanupExpiredSessions()
|
||||
}
|
||||
}
|
||||
|
||||
// CreateSession 创建新的Session并返回Session ID
|
||||
// 参数 c: Gin上下文对象
|
||||
// 返回值: Session ID字符串
|
||||
func (sm *SessionManager) CreateSession(c *gin.Context) string {
|
||||
sessionID := uuid.New().String()
|
||||
|
||||
session := &Session{
|
||||
ID: sessionID,
|
||||
Data: make(map[string]interface{}),
|
||||
CreateTime: time.Now(),
|
||||
LastAccess: time.Now(),
|
||||
}
|
||||
|
||||
sm.mutex.Lock()
|
||||
sm.sessions[sessionID] = session
|
||||
sm.mutex.Unlock()
|
||||
|
||||
// 设置Cookie,保存Session ID
|
||||
c.SetCookie(
|
||||
"session_id", // Cookie名称
|
||||
sessionID, // Cookie值
|
||||
int(sm.maxAge.Seconds()), // 过期时间(秒)
|
||||
"/", // 路径
|
||||
"", // 域名(空表示当前域名)
|
||||
false, // 是否仅HTTPS
|
||||
true, // 是否HttpOnly(防止XSS攻击)
|
||||
)
|
||||
|
||||
return sessionID
|
||||
}
|
||||
|
||||
// GetSession 根据Gin上下文获取Session对象
|
||||
// 参数 c: Gin上下文对象
|
||||
// 返回值: Session对象指针,如果不存在则返回nil
|
||||
func (sm *SessionManager) GetSession(c *gin.Context) *Session {
|
||||
sessionID, err := c.Cookie("session_id")
|
||||
if err != nil {
|
||||
return nil
|
||||
}
|
||||
|
||||
sm.mutex.RLock()
|
||||
session, exists := sm.sessions[sessionID]
|
||||
sm.mutex.RUnlock()
|
||||
|
||||
if !exists {
|
||||
return nil
|
||||
}
|
||||
|
||||
// 检查Session是否过期
|
||||
if time.Since(session.LastAccess) > sm.maxAge {
|
||||
sm.DestroySession(sessionID)
|
||||
return nil
|
||||
}
|
||||
|
||||
// 更新最后访问时间
|
||||
session.mutex.Lock()
|
||||
session.LastAccess = time.Now()
|
||||
session.mutex.Unlock()
|
||||
|
||||
return session
|
||||
}
|
||||
|
||||
// GetSessionByID 根据Session ID获取Session对象
|
||||
// 参数 sessionID: Session的唯一标识符
|
||||
// 返回值: Session对象指针,如果不存在则返回nil
|
||||
func (sm *SessionManager) GetSessionByID(sessionID string) *Session {
|
||||
sm.mutex.RLock()
|
||||
session, exists := sm.sessions[sessionID]
|
||||
sm.mutex.RUnlock()
|
||||
|
||||
if !exists {
|
||||
return nil
|
||||
}
|
||||
|
||||
// 检查Session是否过期
|
||||
if time.Since(session.LastAccess) > sm.maxAge {
|
||||
sm.DestroySession(sessionID)
|
||||
return nil
|
||||
}
|
||||
|
||||
// 更新最后访问时间
|
||||
session.mutex.Lock()
|
||||
session.LastAccess = time.Now()
|
||||
session.mutex.Unlock()
|
||||
|
||||
return session
|
||||
}
|
||||
|
||||
// DestroySession 销毁指定的Session
|
||||
// 参数 sessionID: Session的唯一标识符
|
||||
func (sm *SessionManager) DestroySession(sessionID string) {
|
||||
sm.mutex.Lock()
|
||||
delete(sm.sessions, sessionID)
|
||||
sm.mutex.Unlock()
|
||||
}
|
||||
|
||||
// DestroySessionByContext 根据Gin上下文销毁Session
|
||||
// 参数 c: Gin上下文对象
|
||||
func (sm *SessionManager) DestroySessionByContext(c *gin.Context) {
|
||||
sessionID, err := c.Cookie("session_id")
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
|
||||
sm.DestroySession(sessionID)
|
||||
|
||||
// 清除Cookie
|
||||
c.SetCookie(
|
||||
"session_id",
|
||||
"",
|
||||
-1, // 立即过期
|
||||
"/",
|
||||
"",
|
||||
false,
|
||||
true,
|
||||
)
|
||||
}
|
||||
|
||||
// CleanupExpiredSessions 清理所有过期的Session
|
||||
func (sm *SessionManager) CleanupExpiredSessions() {
|
||||
now := time.Now()
|
||||
expiredIDs := make([]string, 0)
|
||||
|
||||
sm.mutex.RLock()
|
||||
for id, session := range sm.sessions {
|
||||
session.mutex.RLock()
|
||||
if now.Sub(session.LastAccess) > sm.maxAge {
|
||||
expiredIDs = append(expiredIDs, id)
|
||||
}
|
||||
session.mutex.RUnlock()
|
||||
}
|
||||
sm.mutex.RUnlock()
|
||||
|
||||
// 删除过期的Session
|
||||
sm.mutex.Lock()
|
||||
for _, id := range expiredIDs {
|
||||
delete(sm.sessions, id)
|
||||
}
|
||||
sm.mutex.Unlock()
|
||||
}
|
||||
|
||||
// GetSessionCount 获取当前活跃的Session数量
|
||||
// 返回值: Session数量
|
||||
func (sm *SessionManager) GetSessionCount() int {
|
||||
sm.mutex.RLock()
|
||||
defer sm.mutex.RUnlock()
|
||||
return len(sm.sessions)
|
||||
}
|
||||
|
||||
// Set 在Session中设置键值对
|
||||
// 参数 key: 键名
|
||||
// 参数 value: 值
|
||||
func (s *Session) Set(key string, value interface{}) {
|
||||
s.mutex.Lock()
|
||||
defer s.mutex.Unlock()
|
||||
s.Data[key] = value
|
||||
}
|
||||
|
||||
// Get 从Session中获取值
|
||||
// 参数 key: 键名
|
||||
// 返回值: 对应的值,如果不存在则返回nil
|
||||
func (s *Session) Get(key string) interface{} {
|
||||
s.mutex.RLock()
|
||||
defer s.mutex.RUnlock()
|
||||
return s.Data[key]
|
||||
}
|
||||
|
||||
// GetString 从Session中获取字符串类型的值
|
||||
// 参数 key: 键名
|
||||
// 返回值: 字符串值,如果不存在或类型不匹配则返回空字符串
|
||||
func (s *Session) GetString(key string) string {
|
||||
s.mutex.RLock()
|
||||
defer s.mutex.RUnlock()
|
||||
if val, ok := s.Data[key]; ok {
|
||||
if str, ok := val.(string); ok {
|
||||
return str
|
||||
}
|
||||
}
|
||||
return ""
|
||||
}
|
||||
|
||||
// GetInt 从Session中获取整数类型的值
|
||||
// 参数 key: 键名
|
||||
// 返回值: 整数值,如果不存在或类型不匹配则返回0
|
||||
func (s *Session) GetInt(key string) int {
|
||||
s.mutex.RLock()
|
||||
defer s.mutex.RUnlock()
|
||||
if val, ok := s.Data[key]; ok {
|
||||
switch v := val.(type) {
|
||||
case int:
|
||||
return v
|
||||
case float64: // JSON解码时数字会变成float64
|
||||
return int(v)
|
||||
}
|
||||
}
|
||||
return 0
|
||||
}
|
||||
|
||||
// GetFloat64 从Session中获取浮点数类型的值
|
||||
// 参数 key: 键名
|
||||
// 返回值: 浮点数值,如果不存在或类型不匹配则返回0
|
||||
func (s *Session) GetFloat64(key string) float64 {
|
||||
s.mutex.RLock()
|
||||
defer s.mutex.RUnlock()
|
||||
if val, ok := s.Data[key]; ok {
|
||||
switch v := val.(type) {
|
||||
case float64:
|
||||
return v
|
||||
case int:
|
||||
return float64(v)
|
||||
}
|
||||
}
|
||||
return 0
|
||||
}
|
||||
|
||||
// GetBool 从Session中获取布尔类型的值
|
||||
// 参数 key: 键名
|
||||
// 返回值: 布尔值,如果不存在或类型不匹配则返回false
|
||||
func (s *Session) GetBool(key string) bool {
|
||||
s.mutex.RLock()
|
||||
defer s.mutex.RUnlock()
|
||||
if val, ok := s.Data[key]; ok {
|
||||
if b, ok := val.(bool); ok {
|
||||
return b
|
||||
}
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
// Delete 从Session中删除指定的键
|
||||
// 参数 key: 要删除的键名
|
||||
func (s *Session) Delete(key string) {
|
||||
s.mutex.Lock()
|
||||
defer s.mutex.Unlock()
|
||||
delete(s.Data, key)
|
||||
}
|
||||
|
||||
// Clear 清空Session中的所有数据
|
||||
func (s *Session) Clear() {
|
||||
s.mutex.Lock()
|
||||
defer s.mutex.Unlock()
|
||||
s.Data = make(map[string]interface{})
|
||||
}
|
||||
|
||||
// Has 检查Session中是否存在指定的键
|
||||
// 参数 key: 键名
|
||||
// 返回值: 如果存在返回true,否则返回false
|
||||
func (s *Session) Has(key string) bool {
|
||||
s.mutex.RLock()
|
||||
defer s.mutex.RUnlock()
|
||||
_, exists := s.Data[key]
|
||||
return exists
|
||||
}
|
||||
|
||||
// GetAll 获取Session中的所有数据
|
||||
// 返回值: 包含所有数据的map副本
|
||||
func (s *Session) GetAll() map[string]interface{} {
|
||||
s.mutex.RLock()
|
||||
defer s.mutex.RUnlock()
|
||||
|
||||
// 返回副本,避免外部修改
|
||||
result := make(map[string]interface{}, len(s.Data))
|
||||
for k, v := range s.Data {
|
||||
result[k] = v
|
||||
}
|
||||
return result
|
||||
}
|
||||
|
||||
// ToJSON 将Session数据转换为JSON字符串
|
||||
// 返回值: JSON字符串,如果转换失败则返回错误
|
||||
func (s *Session) ToJSON() (string, error) {
|
||||
s.mutex.RLock()
|
||||
defer s.mutex.RUnlock()
|
||||
|
||||
data, err := json.Marshal(s.Data)
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("序列化Session数据失败: %w", err)
|
||||
}
|
||||
|
||||
return string(data), nil
|
||||
}
|
||||
|
||||
// FromJSON 从JSON字符串恢复Session数据
|
||||
// 参数 jsonStr: JSON字符串
|
||||
// 返回值: 如果解析失败则返回错误
|
||||
func (s *Session) FromJSON(jsonStr string) error {
|
||||
s.mutex.Lock()
|
||||
defer s.mutex.Unlock()
|
||||
|
||||
var data map[string]interface{}
|
||||
if err := json.Unmarshal([]byte(jsonStr), &data); err != nil {
|
||||
return fmt.Errorf("反序列化Session数据失败: %w", err)
|
||||
}
|
||||
|
||||
s.Data = data
|
||||
return nil
|
||||
}
|
||||
|
||||
// GetDefaultManager 获取默认的Session管理器
|
||||
// 返回值: 默认Session管理器指针
|
||||
func GetDefaultManager() *SessionManager {
|
||||
return defaultManager
|
||||
}
|
||||
|
||||
// CreateSession 使用默认管理器创建Session
|
||||
// 参数 c: Gin上下文对象
|
||||
// 返回值: Session ID字符串
|
||||
func CreateSession(c *gin.Context) string {
|
||||
return defaultManager.CreateSession(c)
|
||||
}
|
||||
|
||||
// GetSession 使用默认管理器获取Session
|
||||
// 参数 c: Gin上下文对象
|
||||
// 返回值: Session对象指针
|
||||
func GetSession(c *gin.Context) *Session {
|
||||
return defaultManager.GetSession(c)
|
||||
}
|
||||
|
||||
// DestroySession 使用默认管理器销毁Session
|
||||
// 参数 c: Gin上下文对象
|
||||
func DestroySession(c *gin.Context) {
|
||||
defaultManager.DestroySessionByContext(c)
|
||||
}
|
||||
Loading…
Reference in New Issue