482 lines
13 KiB
Go
482 lines
13 KiB
Go
// Copyright GoFrame Author(https://goframe.org). All Rights Reserved.
|
|
//
|
|
// This Source Code Form is subject to the terms of the MIT License.
|
|
// If a copy of the MIT was not distributed with this file,
|
|
// You can obtain one at https://github.com/gogf/gf.
|
|
|
|
package pgsql
|
|
|
|
import (
|
|
"context"
|
|
"fmt"
|
|
"strings"
|
|
|
|
"git.magicany.cc/black1552/gin-base/database"
|
|
)
|
|
|
|
// Migration implements database migration operations for PostgreSQL.
|
|
type Migration struct {
|
|
*database.MigrationCore
|
|
*database.AutoMigrateCore
|
|
}
|
|
|
|
// NewMigration creates a new PostgreSQL Migration instance.
|
|
func NewMigration(db database.DB) *Migration {
|
|
return &Migration{
|
|
MigrationCore: database.NewMigrationCore(db),
|
|
AutoMigrateCore: database.NewAutoMigrateCore(db),
|
|
}
|
|
}
|
|
|
|
// CreateTable creates a new table with the given name and column definitions.
|
|
func (m *Migration) CreateTable(ctx context.Context, table string, columns map[string]*database.ColumnDefinition, options ...database.TableOption) error {
|
|
if len(columns) == 0 {
|
|
return fmt.Errorf("cannot create table without columns")
|
|
}
|
|
|
|
var opts database.TableOptions
|
|
for _, opt := range options {
|
|
opt(&opts)
|
|
}
|
|
|
|
var sql strings.Builder
|
|
sql.WriteString("CREATE TABLE ")
|
|
if opts.IfNotExists {
|
|
sql.WriteString("IF NOT EXISTS ")
|
|
}
|
|
sql.WriteString(database.QuoteIdentifierDouble(table))
|
|
sql.WriteString(" (\n")
|
|
|
|
// Add columns
|
|
var colDefs []string
|
|
var primaryKeys []string
|
|
for name, def := range columns {
|
|
colDef := m.buildColumnDefinition(name, def)
|
|
if def.PrimaryKey {
|
|
primaryKeys = append(primaryKeys, database.QuoteIdentifierDouble(name))
|
|
}
|
|
colDefs = append(colDefs, " "+colDef)
|
|
}
|
|
|
|
// Add primary key constraint if needed
|
|
if len(primaryKeys) > 0 {
|
|
colDefs = append(colDefs, fmt.Sprintf(" PRIMARY KEY (%s)", strings.Join(primaryKeys, ", ")))
|
|
}
|
|
|
|
sql.WriteString(strings.Join(colDefs, ",\n"))
|
|
sql.WriteString("\n)")
|
|
|
|
// Add table comment if provided
|
|
if opts.Comment != "" {
|
|
commentSQL := fmt.Sprintf(
|
|
"COMMENT ON TABLE %s IS '%s'",
|
|
database.QuoteIdentifierDouble(table),
|
|
escapeString(opts.Comment),
|
|
)
|
|
if err := m.ExecuteSQL(ctx, commentSQL); err != nil {
|
|
return err
|
|
}
|
|
}
|
|
|
|
return m.ExecuteSQL(ctx, sql.String())
|
|
}
|
|
|
|
// buildColumnDefinition builds column definition for PostgreSQL.
|
|
func (m *Migration) buildColumnDefinition(name string, def *database.ColumnDefinition) string {
|
|
var parts []string
|
|
parts = append(parts, database.QuoteIdentifierDouble(name))
|
|
|
|
// Handle PostgreSQL-specific types
|
|
dbType := def.Type
|
|
if def.AutoIncrement {
|
|
if dbType == "INT" || dbType == "INTEGER" {
|
|
dbType = "SERIAL"
|
|
} else if dbType == "BIGINT" {
|
|
dbType = "BIGSERIAL"
|
|
}
|
|
}
|
|
parts = append(parts, dbType)
|
|
|
|
if def.PrimaryKey {
|
|
// Primary key is handled separately
|
|
} else {
|
|
if !def.Null {
|
|
parts = append(parts, "NOT NULL")
|
|
}
|
|
|
|
if def.Unique {
|
|
parts = append(parts, "UNIQUE")
|
|
}
|
|
|
|
if def.Default != nil {
|
|
defaultValue := formatDefaultValue(def.Default)
|
|
parts = append(parts, fmt.Sprintf("DEFAULT %s", defaultValue))
|
|
}
|
|
}
|
|
|
|
return strings.Join(parts, " ")
|
|
}
|
|
|
|
// DropTable drops an existing table from the database.
|
|
func (m *Migration) DropTable(ctx context.Context, table string, ifExists ...bool) error {
|
|
sql := "DROP TABLE "
|
|
if len(ifExists) > 0 && ifExists[0] {
|
|
sql += "IF EXISTS "
|
|
}
|
|
sql += database.QuoteIdentifierDouble(table)
|
|
return m.ExecuteSQL(ctx, sql)
|
|
}
|
|
|
|
// HasTable checks if a table exists in the database.
|
|
func (m *Migration) HasTable(ctx context.Context, table string) (bool, error) {
|
|
schema := m.GetDB().GetSchema()
|
|
if schema == "" {
|
|
schema = "current_schema()"
|
|
} else {
|
|
schema = fmt.Sprintf("'%s'", schema)
|
|
}
|
|
|
|
query := fmt.Sprintf(
|
|
"SELECT COUNT(*) FROM information_schema.tables WHERE table_schema = %s AND table_name = '%s'",
|
|
schema,
|
|
table,
|
|
)
|
|
|
|
value, err := m.GetDB().GetValue(ctx, query)
|
|
if err != nil {
|
|
return false, err
|
|
}
|
|
|
|
return value.Int() > 0, nil
|
|
}
|
|
|
|
// RenameTable renames an existing table from oldName to newName.
|
|
func (m *Migration) RenameTable(ctx context.Context, oldName, newName string) error {
|
|
sql := fmt.Sprintf(
|
|
"ALTER TABLE %s RENAME TO %s",
|
|
database.QuoteIdentifierDouble(oldName),
|
|
database.QuoteIdentifierDouble(newName),
|
|
)
|
|
return m.ExecuteSQL(ctx, sql)
|
|
}
|
|
|
|
// TruncateTable removes all records from a table but keeps the table structure.
|
|
func (m *Migration) TruncateTable(ctx context.Context, table string) error {
|
|
sql := fmt.Sprintf("TRUNCATE TABLE %s", database.QuoteIdentifierDouble(table))
|
|
return m.ExecuteSQL(ctx, sql)
|
|
}
|
|
|
|
// AddColumn adds a new column to an existing table.
|
|
func (m *Migration) AddColumn(ctx context.Context, table, column string, definition *database.ColumnDefinition) error {
|
|
colDef := m.buildColumnDefinition(column, definition)
|
|
sql := fmt.Sprintf("ALTER TABLE %s ADD COLUMN %s", database.QuoteIdentifierDouble(table), colDef)
|
|
return m.ExecuteSQL(ctx, sql)
|
|
}
|
|
|
|
// DropColumn removes a column from an existing table.
|
|
func (m *Migration) DropColumn(ctx context.Context, table, column string) error {
|
|
sql := fmt.Sprintf(
|
|
"ALTER TABLE %s DROP COLUMN %s",
|
|
database.QuoteIdentifierDouble(table),
|
|
database.QuoteIdentifierDouble(column),
|
|
)
|
|
return m.ExecuteSQL(ctx, sql)
|
|
}
|
|
|
|
// RenameColumn renames a column in an existing table.
|
|
func (m *Migration) RenameColumn(ctx context.Context, table, oldName, newName string) error {
|
|
sql := fmt.Sprintf(
|
|
"ALTER TABLE %s RENAME COLUMN %s TO %s",
|
|
database.QuoteIdentifierDouble(table),
|
|
database.QuoteIdentifierDouble(oldName),
|
|
database.QuoteIdentifierDouble(newName),
|
|
)
|
|
return m.ExecuteSQL(ctx, sql)
|
|
}
|
|
|
|
// ModifyColumn modifies an existing column's definition.
|
|
func (m *Migration) ModifyColumn(ctx context.Context, table, column string, definition *database.ColumnDefinition) error {
|
|
// PostgreSQL requires multiple ALTER statements for different modifications
|
|
var statements []string
|
|
|
|
if definition.Type != "" {
|
|
statements = append(statements, fmt.Sprintf(
|
|
"ALTER TABLE %s ALTER COLUMN %s TYPE %s",
|
|
database.QuoteIdentifierDouble(table),
|
|
database.QuoteIdentifierDouble(column),
|
|
definition.Type,
|
|
))
|
|
}
|
|
|
|
if definition.Default != nil {
|
|
defaultValue := formatDefaultValue(definition.Default)
|
|
statements = append(statements, fmt.Sprintf(
|
|
"ALTER TABLE %s ALTER COLUMN %s SET DEFAULT %s",
|
|
database.QuoteIdentifierDouble(table),
|
|
database.QuoteIdentifierDouble(column),
|
|
defaultValue,
|
|
))
|
|
} else {
|
|
statements = append(statements, fmt.Sprintf(
|
|
"ALTER TABLE %s ALTER COLUMN %s DROP DEFAULT",
|
|
database.QuoteIdentifierDouble(table),
|
|
database.QuoteIdentifierDouble(column),
|
|
))
|
|
}
|
|
|
|
if !definition.Null {
|
|
statements = append(statements, fmt.Sprintf(
|
|
"ALTER TABLE %s ALTER COLUMN %s SET NOT NULL",
|
|
database.QuoteIdentifierDouble(table),
|
|
database.QuoteIdentifierDouble(column),
|
|
))
|
|
} else {
|
|
statements = append(statements, fmt.Sprintf(
|
|
"ALTER TABLE %s ALTER COLUMN %s DROP NOT NULL",
|
|
database.QuoteIdentifierDouble(table),
|
|
database.QuoteIdentifierDouble(column),
|
|
))
|
|
}
|
|
|
|
// Execute all statements
|
|
for _, stmt := range statements {
|
|
if err := m.ExecuteSQL(ctx, stmt); err != nil {
|
|
return err
|
|
}
|
|
}
|
|
|
|
return nil
|
|
}
|
|
|
|
// HasColumn checks if a column exists in a table.
|
|
func (m *Migration) HasColumn(ctx context.Context, table, column string) (bool, error) {
|
|
fields, err := m.GetDB().TableFields(ctx, table)
|
|
if err != nil {
|
|
return false, err
|
|
}
|
|
_, exists := fields[column]
|
|
return exists, nil
|
|
}
|
|
|
|
// CreateIndex creates a new index on the specified table and columns.
|
|
func (m *Migration) CreateIndex(ctx context.Context, table, index string, columns []string, options ...database.IndexOption) error {
|
|
var opts database.IndexOptions
|
|
for _, opt := range options {
|
|
opt(&opts)
|
|
}
|
|
|
|
var sql strings.Builder
|
|
sql.WriteString("CREATE ")
|
|
|
|
if opts.Unique {
|
|
sql.WriteString("UNIQUE ")
|
|
}
|
|
|
|
sql.WriteString("INDEX ")
|
|
sql.WriteString(database.QuoteIdentifierDouble(index))
|
|
sql.WriteString(" ON ")
|
|
sql.WriteString(database.QuoteIdentifierDouble(table))
|
|
|
|
colList := m.BuildIndexColumnsWithDouble(columns)
|
|
sql.WriteString(fmt.Sprintf(" (%s)", colList))
|
|
|
|
if opts.Using != "" {
|
|
sql.WriteString(fmt.Sprintf(" USING %s", opts.Using))
|
|
}
|
|
|
|
if opts.Comment != "" {
|
|
// Comment will be added after index creation
|
|
}
|
|
|
|
err := m.ExecuteSQL(ctx, sql.String())
|
|
if err != nil {
|
|
return err
|
|
}
|
|
|
|
// Add comment if provided
|
|
if opts.Comment != "" {
|
|
commentSQL := fmt.Sprintf(
|
|
"COMMENT ON INDEX %s IS '%s'",
|
|
database.QuoteIdentifierDouble(index),
|
|
escapeString(opts.Comment),
|
|
)
|
|
return m.ExecuteSQL(ctx, commentSQL)
|
|
}
|
|
|
|
return nil
|
|
}
|
|
|
|
// BuildIndexColumnsWithDouble builds the column list for index creation using double quotes.
|
|
func (m *Migration) BuildIndexColumnsWithDouble(columns []string) string {
|
|
quoted := make([]string, len(columns))
|
|
for i, col := range columns {
|
|
quoted[i] = database.QuoteIdentifierDouble(col)
|
|
}
|
|
return strings.Join(quoted, ", ")
|
|
}
|
|
|
|
// DropIndex drops an existing index from a table.
|
|
func (m *Migration) DropIndex(ctx context.Context, table, index string) error {
|
|
sql := fmt.Sprintf("DROP INDEX %s", database.QuoteIdentifierDouble(index))
|
|
return m.ExecuteSQL(ctx, sql)
|
|
}
|
|
|
|
// HasIndex checks if an index exists on a table.
|
|
func (m *Migration) HasIndex(ctx context.Context, table, index string) (bool, error) {
|
|
schema := m.GetDB().GetSchema()
|
|
if schema == "" {
|
|
schema = "current_schema()"
|
|
} else {
|
|
schema = fmt.Sprintf("'%s'", schema)
|
|
}
|
|
|
|
query := fmt.Sprintf(
|
|
"SELECT COUNT(*) FROM pg_indexes WHERE schemaname = %s AND tablename = '%s' AND indexname = '%s'",
|
|
schema,
|
|
table,
|
|
index,
|
|
)
|
|
|
|
value, err := m.GetDB().GetValue(ctx, query)
|
|
if err != nil {
|
|
return false, err
|
|
}
|
|
|
|
return value.Int() > 0, nil
|
|
}
|
|
|
|
// CreateForeignKey creates a foreign key constraint.
|
|
func (m *Migration) CreateForeignKey(ctx context.Context, table, constraint string, columns []string, refTable string, refColumns []string, options ...database.ForeignKeyOption) error {
|
|
var opts database.ForeignKeyOptions
|
|
for _, opt := range options {
|
|
opt(&opts)
|
|
}
|
|
|
|
sql := fmt.Sprintf(
|
|
"ALTER TABLE %s ADD CONSTRAINT %s FOREIGN KEY %s REFERENCES %s %s",
|
|
database.QuoteIdentifierDouble(table),
|
|
database.QuoteIdentifierDouble(constraint),
|
|
m.BuildForeignKeyColumnsWithDouble(columns),
|
|
database.QuoteIdentifierDouble(refTable),
|
|
m.BuildForeignKeyColumnsWithDouble(refColumns),
|
|
)
|
|
|
|
if opts.OnDelete != "" {
|
|
sql += fmt.Sprintf(" ON DELETE %s", opts.OnDelete)
|
|
}
|
|
if opts.OnUpdate != "" {
|
|
sql += fmt.Sprintf(" ON UPDATE %s", opts.OnUpdate)
|
|
}
|
|
if opts.Deferrable {
|
|
sql += " DEFERRABLE"
|
|
if opts.InitiallyDeferred {
|
|
sql += " INITIALLY DEFERRED"
|
|
}
|
|
}
|
|
|
|
return m.ExecuteSQL(ctx, sql)
|
|
}
|
|
|
|
// BuildForeignKeyColumnsWithDouble builds the column list for foreign key using double quotes.
|
|
func (m *Migration) BuildForeignKeyColumnsWithDouble(columns []string) string {
|
|
quoted := make([]string, len(columns))
|
|
for i, col := range columns {
|
|
quoted[i] = database.QuoteIdentifierDouble(col)
|
|
}
|
|
return "(" + strings.Join(quoted, ", ") + ")"
|
|
}
|
|
|
|
// DropForeignKey drops a foreign key constraint.
|
|
func (m *Migration) DropForeignKey(ctx context.Context, table, constraint string) error {
|
|
sql := fmt.Sprintf(
|
|
"ALTER TABLE %s DROP CONSTRAINT %s",
|
|
database.QuoteIdentifierDouble(table),
|
|
database.QuoteIdentifierDouble(constraint),
|
|
)
|
|
return m.ExecuteSQL(ctx, sql)
|
|
}
|
|
|
|
// HasForeignKey checks if a foreign key constraint exists.
|
|
func (m *Migration) HasForeignKey(ctx context.Context, table, constraint string) (bool, error) {
|
|
schema := m.GetDB().GetSchema()
|
|
if schema == "" {
|
|
schema = "current_schema()"
|
|
} else {
|
|
schema = fmt.Sprintf("'%s'", schema)
|
|
}
|
|
|
|
query := fmt.Sprintf(
|
|
"SELECT COUNT(*) FROM information_schema.table_constraints WHERE constraint_schema = %s AND table_name = '%s' AND constraint_name = '%s' AND constraint_type = 'FOREIGN KEY'",
|
|
schema,
|
|
table,
|
|
constraint,
|
|
)
|
|
|
|
value, err := m.GetDB().GetValue(ctx, query)
|
|
if err != nil {
|
|
return false, err
|
|
}
|
|
|
|
return value.Int() > 0, nil
|
|
}
|
|
|
|
// CreateSchema creates a new database schema.
|
|
func (m *Migration) CreateSchema(ctx context.Context, schema string) error {
|
|
sql := fmt.Sprintf("CREATE SCHEMA IF NOT EXISTS %s", database.QuoteIdentifierDouble(schema))
|
|
return m.ExecuteSQL(ctx, sql)
|
|
}
|
|
|
|
// DropSchema drops an existing database schema.
|
|
func (m *Migration) DropSchema(ctx context.Context, schema string, cascade ...bool) error {
|
|
sql := "DROP SCHEMA "
|
|
if len(cascade) > 0 && cascade[0] {
|
|
sql += "IF EXISTS "
|
|
sql += database.QuoteIdentifierDouble(schema) + " CASCADE"
|
|
} else {
|
|
sql += "IF EXISTS " + database.QuoteIdentifierDouble(schema)
|
|
}
|
|
return m.ExecuteSQL(ctx, sql)
|
|
}
|
|
|
|
// HasSchema checks if a schema exists.
|
|
func (m *Migration) HasSchema(ctx context.Context, schema string) (bool, error) {
|
|
query := fmt.Sprintf(
|
|
"SELECT COUNT(*) FROM information_schema.schemata WHERE schema_name = '%s'",
|
|
schema,
|
|
)
|
|
|
|
value, err := m.GetDB().GetValue(ctx, query)
|
|
if err != nil {
|
|
return false, err
|
|
}
|
|
|
|
return value.Int() > 0, nil
|
|
}
|
|
|
|
// formatDefaultValue renders value as a SQL literal suitable for a DEFAULT
// clause:
//   - nil         -> NULL
//   - string      -> single-quoted, with embedded quotes doubled
//   - bool        -> TRUE / FALSE
//   - any int/uint kind -> decimal digits
//   - float32/64  -> the shortest representation that round-trips (%g), e.g.
//     1.5, 1e-07. NaN and ±Inf are not numeric literals in PostgreSQL, so they
//     are emitted as the quoted strings 'NaN', 'Infinity' and '-Infinity'.
//   - anything else -> its %v text, single-quoted and escaped
func formatDefaultValue(value any) string {
	switch v := value.(type) {
	case nil:
		return "NULL"
	case string:
		return "'" + escapeString(v) + "'"
	case bool:
		if v {
			return "TRUE"
		}
		return "FALSE"
	case int, int8, int16, int32, int64, uint, uint8, uint16, uint32, uint64:
		return fmt.Sprintf("%d", v)
	case float32, float64:
		// %g keeps full precision (the former %f truncated to 6 decimals, so
		// 1e-7 became 0.000000) and formats float32 at 32-bit precision.
		s := fmt.Sprintf("%g", v)
		switch s {
		case "NaN":
			return "'NaN'"
		case "+Inf":
			return "'Infinity'"
		case "-Inf":
			return "'-Infinity'"
		}
		return s
	default:
		// Escape so values whose text contains a quote stay a single literal.
		return "'" + escapeString(fmt.Sprint(v)) + "'"
	}
}

// escapeString makes s safe to place between single quotes in a SQL string
// literal by doubling every single quote. Backslashes are left untouched,
// which assumes the server runs with standard_conforming_strings = on (the
// PostgreSQL default) — confirm for the target deployment.
func escapeString(s string) string {
	return strings.ReplaceAll(s, "'", "''")
}
|