// Copyright GoFrame Author(https://goframe.org). All Rights Reserved. // // This Source Code Form is subject to the terms of the MIT License. // If a copy of the MIT was not distributed with this file, // You can obtain one at https://github.com/gogf/gf. package mssql import ( "context" "fmt" "strings" "git.magicany.cc/black1552/gin-base/database" ) // Migration implements database migration operations for SQL Server. type Migration struct { *database.MigrationCore *database.AutoMigrateCore } // NewMigration creates a new SQL Server Migration instance. func NewMigration(db database.DB) *Migration { return &Migration{ MigrationCore: database.NewMigrationCore(db), AutoMigrateCore: database.NewAutoMigrateCore(db), } } // CreateTable creates a new table with the given name and column definitions. func (m *Migration) CreateTable(ctx context.Context, table string, columns map[string]*database.ColumnDefinition, options ...database.TableOption) error { if len(columns) == 0 { return fmt.Errorf("cannot create table without columns") } var opts database.TableOptions for _, opt := range options { opt(&opts) } var sql strings.Builder sql.WriteString("CREATE TABLE ") sql.WriteString(database.QuoteIdentifier(table)) sql.WriteString(" (\n") // Add columns var colDefs []string var primaryKeys []string for name, def := range columns { colDef := m.buildColumnDefinition(name, def) if def.PrimaryKey { primaryKeys = append(primaryKeys, database.QuoteIdentifier(name)) } colDefs = append(colDefs, " "+colDef) } // Add primary key constraint if needed if len(primaryKeys) > 0 { colDefs = append(colDefs, fmt.Sprintf(" PRIMARY KEY (%s)", strings.Join(primaryKeys, ", "))) } sql.WriteString(strings.Join(colDefs, ",\n")) sql.WriteString("\n)") return m.ExecuteSQL(ctx, sql.String()) } // buildColumnDefinition builds column definition for SQL Server. func (m *Migration) buildColumnDefinition(name string, def *database.ColumnDefinition) string { var parts []string parts = append(parts, database.QuoteIdentifier(name)) // Handle SQL Server-specific types dbType := def.Type if def.AutoIncrement { if dbType == "INT" || dbType == "INTEGER" { dbType = "INT IDENTITY(1,1)" } else if dbType == "BIGINT" { dbType = "BIGINT IDENTITY(1,1)" } } parts = append(parts, dbType) if !def.Null { parts = append(parts, "NOT NULL") } else { parts = append(parts, "NULL") } if def.Default != nil { defaultValue := formatDefaultValue(def.Default) parts = append(parts, fmt.Sprintf("DEFAULT %s", defaultValue)) } return strings.Join(parts, " ") } // DropTable drops an existing table from the database. func (m *Migration) DropTable(ctx context.Context, table string, ifExists ...bool) error { if len(ifExists) > 0 && ifExists[0] { sql := fmt.Sprintf( "IF OBJECT_ID('%s', 'U') IS NOT NULL DROP TABLE %s", table, database.QuoteIdentifier(table), ) return m.ExecuteSQL(ctx, sql) } sql := fmt.Sprintf("DROP TABLE %s", database.QuoteIdentifier(table)) return m.ExecuteSQL(ctx, sql) } // HasTable checks if a table exists in the database. func (m *Migration) HasTable(ctx context.Context, table string) (bool, error) { query := fmt.Sprintf( "SELECT COUNT(*) FROM INFORMATION_SCHEMA.TABLES WHERE TABLE_NAME = '%s'", table, ) value, err := m.GetDB().GetValue(ctx, query) if err != nil { return false, err } return value.Int() > 0, nil } // RenameTable renames an existing table from oldName to newName. func (m *Migration) RenameTable(ctx context.Context, oldName, newName string) error { sql := fmt.Sprintf( "EXEC sp_rename '%s', '%s'", oldName, newName, ) return m.ExecuteSQL(ctx, sql) } // TruncateTable removes all records from a table but keeps the table structure. func (m *Migration) TruncateTable(ctx context.Context, table string) error { sql := fmt.Sprintf("TRUNCATE TABLE %s", database.QuoteIdentifier(table)) return m.ExecuteSQL(ctx, sql) } // AddColumn adds a new column to an existing table. func (m *Migration) AddColumn(ctx context.Context, table, column string, definition *database.ColumnDefinition) error { colDef := m.buildColumnDefinition(column, definition) sql := fmt.Sprintf("ALTER TABLE %s ADD %s", database.QuoteIdentifier(table), colDef) return m.ExecuteSQL(ctx, sql) } // DropColumn removes a column from an existing table. func (m *Migration) DropColumn(ctx context.Context, table, column string) error { sql := fmt.Sprintf( "ALTER TABLE %s DROP COLUMN %s", database.QuoteIdentifier(table), database.QuoteIdentifier(column), ) return m.ExecuteSQL(ctx, sql) } // RenameColumn renames a column in an existing table. func (m *Migration) RenameColumn(ctx context.Context, table, oldName, newName string) error { sql := fmt.Sprintf( "EXEC sp_rename '%s.%s', '%s', 'COLUMN'", table, oldName, newName, ) return m.ExecuteSQL(ctx, sql) } // ModifyColumn modifies an existing column's definition. func (m *Migration) ModifyColumn(ctx context.Context, table, column string, definition *database.ColumnDefinition) error { colDef := m.buildColumnDefinition(column, definition) sql := fmt.Sprintf("ALTER TABLE %s ALTER COLUMN %s", database.QuoteIdentifier(table), colDef) return m.ExecuteSQL(ctx, sql) } // HasColumn checks if a column exists in a table. func (m *Migration) HasColumn(ctx context.Context, table, column string) (bool, error) { fields, err := m.GetDB().TableFields(ctx, table) if err != nil { return false, err } _, exists := fields[column] return exists, nil } // CreateIndex creates a new index on the specified table and columns. func (m *Migration) CreateIndex(ctx context.Context, table, index string, columns []string, options ...database.IndexOption) error { var opts database.IndexOptions for _, opt := range options { opt(&opts) } var sql strings.Builder sql.WriteString("CREATE ") if opts.Unique { sql.WriteString("UNIQUE ") } sql.WriteString("INDEX ") sql.WriteString(database.QuoteIdentifier(index)) sql.WriteString(" ON ") sql.WriteString(database.QuoteIdentifier(table)) colList := m.BuildIndexColumns(columns) sql.WriteString(fmt.Sprintf(" (%s)", colList)) return m.ExecuteSQL(ctx, sql.String()) } // DropIndex drops an existing index from a table. func (m *Migration) DropIndex(ctx context.Context, table, index string) error { sql := fmt.Sprintf("DROP INDEX %s ON %s", database.QuoteIdentifier(index), database.QuoteIdentifier(table)) return m.ExecuteSQL(ctx, sql) } // HasIndex checks if an index exists on a table. func (m *Migration) HasIndex(ctx context.Context, table, index string) (bool, error) { query := fmt.Sprintf( "SELECT COUNT(*) FROM sys.indexes WHERE name = '%s' AND object_id = OBJECT_ID('%s')", index, table, ) value, err := m.GetDB().GetValue(ctx, query) if err != nil { return false, err } return value.Int() > 0, nil } // CreateForeignKey creates a foreign key constraint. func (m *Migration) CreateForeignKey(ctx context.Context, table, constraint string, columns []string, refTable string, refColumns []string, options ...database.ForeignKeyOption) error { var opts database.ForeignKeyOptions for _, opt := range options { opt(&opts) } sql := fmt.Sprintf( "ALTER TABLE %s ADD CONSTRAINT %s FOREIGN KEY %s REFERENCES %s %s", database.QuoteIdentifier(table), database.QuoteIdentifier(constraint), m.BuildForeignKeyColumns(columns), database.QuoteIdentifier(refTable), m.BuildForeignKeyColumns(refColumns), ) if opts.OnDelete != "" { sql += fmt.Sprintf(" ON DELETE %s", opts.OnDelete) } if opts.OnUpdate != "" { sql += fmt.Sprintf(" ON UPDATE %s", opts.OnUpdate) } return m.ExecuteSQL(ctx, sql) } // DropForeignKey drops a foreign key constraint. func (m *Migration) DropForeignKey(ctx context.Context, table, constraint string) error { sql := fmt.Sprintf( "ALTER TABLE %s DROP CONSTRAINT %s", database.QuoteIdentifier(table), database.QuoteIdentifier(constraint), ) return m.ExecuteSQL(ctx, sql) } // HasForeignKey checks if a foreign key constraint exists. func (m *Migration) HasForeignKey(ctx context.Context, table, constraint string) (bool, error) { query := fmt.Sprintf( "SELECT COUNT(*) FROM INFORMATION_SCHEMA.TABLE_CONSTRAINTS WHERE CONSTRAINT_NAME = '%s' AND TABLE_NAME = '%s' AND CONSTRAINT_TYPE = 'FOREIGN KEY'", constraint, table, ) value, err := m.GetDB().GetValue(ctx, query) if err != nil { return false, err } return value.Int() > 0, nil } // CreateSchema creates a new database schema. func (m *Migration) CreateSchema(ctx context.Context, schema string) error { sql := fmt.Sprintf("CREATE SCHEMA %s", database.QuoteIdentifier(schema)) return m.ExecuteSQL(ctx, sql) } // DropSchema drops an existing database schema. func (m *Migration) DropSchema(ctx context.Context, schema string, cascade ...bool) error { sql := fmt.Sprintf("DROP SCHEMA %s", database.QuoteIdentifier(schema)) return m.ExecuteSQL(ctx, sql) } // HasSchema checks if a schema exists. func (m *Migration) HasSchema(ctx context.Context, schema string) (bool, error) { query := fmt.Sprintf( "SELECT COUNT(*) FROM INFORMATION_SCHEMA.SCHEMATA WHERE SCHEMA_NAME = '%s'", schema, ) value, err := m.GetDB().GetValue(ctx, query) if err != nil { return false, err } return value.Int() > 0, nil } // formatDefaultValue formats the default value for SQL. func formatDefaultValue(value any) string { switch v := value.(type) { case string: return fmt.Sprintf("'%s'", escapeString(v)) case int, int8, int16, int32, int64, uint, uint8, uint16, uint32, uint64: return fmt.Sprintf("%d", v) case float32, float64: return fmt.Sprintf("%f", v) case bool: if v { return "1" } return "0" case nil: return "NULL" default: return fmt.Sprintf("'%v'", v) } } // escapeString escapes special characters in strings for SQL. func escapeString(s string) string { s = strings.ReplaceAll(s, "'", "''") return s }