gin-base/database/gdb_migration_auto.go

453 lines
13 KiB
Go

// Copyright GoFrame Author(https://goframe.org). All Rights Reserved.
//
// This Source Code Form is subject to the terms of the MIT License.
// If a copy of the MIT was not distributed with this file,
// You can obtain one at https://github.com/gogf/gf.
package database
import (
"context"
"fmt"
"reflect"
"strings"
"time"
)
// AutoMigrateCore provides automatic migration functionality based on entity structs.
//
// It reflects over entity structs to derive a table name and column
// definitions, then creates the table or adjusts its columns through the
// Migration obtained from the wrapped database handle.
type AutoMigrateCore struct {
// db is the database handle used to inspect existing tables and fields and
// to obtain the Migration that performs the schema changes.
db DB
}
// NewAutoMigrateCore returns an AutoMigrateCore bound to the given database
// handle. The handle is stored as-is; nothing is validated or contacted here.
func NewAutoMigrateCore(db DB) *AutoMigrateCore {
	core := new(AutoMigrateCore)
	core.db = db
	return core
}
// AutoMigrate creates or updates one table per entity, handling the entities
// in argument order. It stops at the first entity that fails and returns that
// failure wrapped with the entity's Go type; it returns nil when every entity
// (or none at all) was processed without error.
func (am *AutoMigrateCore) AutoMigrate(ctx context.Context, entities ...any) error {
	for idx := range entities {
		err := am.migrateEntity(ctx, entities[idx])
		if err == nil {
			continue
		}
		return fmt.Errorf("failed to migrate entity %T: %w", entities[idx], err)
	}
	return nil
}
// migrateEntity synchronises the table of a single entity.
//
// The entity is parsed into a table name and column definitions; an entity
// that yields no columns is rejected. If the table does not exist yet it is
// created from the columns, otherwise the existing table is updated.
func (am *AutoMigrateCore) migrateEntity(ctx context.Context, entity any) error {
	tableName, columns, err := am.parseEntity(entity)
	switch {
	case err != nil:
		return err
	case len(columns) == 0:
		return fmt.Errorf("no columns found for table %s", tableName)
	}

	exists, err := am.db.HasTable(ctx, tableName)
	if err != nil {
		return fmt.Errorf("failed to check table existence: %w", err)
	}
	if exists {
		// Table is already there: reconcile its columns.
		return am.updateTableStructure(ctx, tableName, columns)
	}
	// No such table yet: create it from scratch.
	return am.createTableFromColumns(ctx, tableName, columns)
}
// parseEntity reflects over an entity and returns its table name together
// with the column definitions derived from its exported fields, keyed by
// column name.
//
// entity must be a struct or a pointer to a struct; anything else (including
// an untyped nil) yields an error. A typed nil pointer such as (*User)(nil) is
// accepted: only type information and tags are needed, so a zero value of the
// pointed-to struct stands in for it.
//
// Unexported fields are ignored, as are fields for which parseField returns
// an empty name or a nil definition. A field named "Meta" whose orm tag
// carries a "table:" option is table metadata (it is what getTableName reads)
// and therefore never becomes a column.
func (am *AutoMigrateCore) parseEntity(entity any) (string, map[string]*ColumnDefinition, error) {
	val := reflect.ValueOf(entity)
	if val.Kind() == reflect.Ptr {
		if val.IsNil() {
			// Elem() of a nil pointer is the invalid Value; build a zero
			// value of the element type instead so the type can be inspected.
			val = reflect.New(val.Type().Elem()).Elem()
		} else {
			val = val.Elem()
		}
	}
	if val.Kind() != reflect.Struct {
		return "", nil, fmt.Errorf("entity must be a struct or pointer to struct")
	}

	// isTableMeta reports whether field is the metadata holder that names
	// the table rather than a real column.
	isTableMeta := func(field reflect.StructField) bool {
		if field.Name != "Meta" {
			return false
		}
		for _, opt := range strings.Split(field.Tag.Get("orm"), ",") {
			if strings.HasPrefix(strings.TrimSpace(opt), "table:") {
				return true
			}
		}
		return false
	}

	typ := val.Type()
	tableName := am.getTableName(typ)

	columns := make(map[string]*ColumnDefinition, typ.NumField())
	for i := 0; i < typ.NumField(); i++ {
		field := typ.Field(i)
		if !field.IsExported() || isTableMeta(field) {
			continue
		}
		colName, colDef, err := am.parseField(field, val.Field(i))
		if err != nil {
			return "", nil, fmt.Errorf("failed to parse field %s: %w", field.Name, err)
		}
		// parseField signals "skip this field" with an empty name / nil definition.
		if colName != "" && colDef != nil {
			columns[colName] = colDef
		}
	}
	return tableName, columns, nil
}
// getTableName resolves the table name for a struct type.
//
// If the struct has a field named "Meta" whose orm tag contains a
// comma-separated "table:<name>" option, <name> is returned verbatim.
// Otherwise the struct's own name is converted to snake_case. A TableName
// method on the entity is not consulted.
//
// typ must be a struct type (FieldByName panics on other kinds).
func (am *AutoMigrateCore) getTableName(typ reflect.Type) string {
	const tableOption = "table:"

	if metaField, found := typ.FieldByName("Meta"); found {
		for _, option := range strings.Split(metaField.Tag.Get("orm"), ",") {
			option = strings.TrimSpace(option)
			if strings.HasPrefix(option, tableOption) {
				return option[len(tableOption):]
			}
		}
	}
	return camelToSnake(typ.Name())
}
// parseField turns one struct field into a column name and definition.
//
// A field tagged orm:"-" or json:"-" is skipped, which is signalled by
// returning an empty name and a nil definition. Otherwise the definition
// starts out nullable with the type chosen by getFieldType, and is then
// refined first by the orm tag and afterwards by the gorm tag (so gorm
// options are applied last). The returned error is always nil.
//
// NOTE(review): skipping on json:"-" also drops fields that are merely hidden
// from JSON output (for example password hashes) — confirm this is intended.
func (am *AutoMigrateCore) parseField(field reflect.StructField, fieldValue reflect.Value) (string, *ColumnDefinition, error) {
	ormTag := field.Tag.Get("orm")
	if ormTag == "-" || field.Tag.Get("json") == "-" {
		return "", nil, nil
	}

	colDef := &ColumnDefinition{
		Type: am.getFieldType(field, fieldValue),
		Null: true, // columns are nullable unless a tag says otherwise
	}
	if ormTag != "" {
		am.parseOrmTag(colDef, ormTag)
	}
	if gormTag := field.Tag.Get("gorm"); gormTag != "" {
		am.parseGormTag(colDef, gormTag)
	}

	return am.getColumnName(field, ormTag), colDef, nil
}
// getColumnName returns the column name for a struct field.
//
// The first comma-separated element of the orm tag is taken as an explicit
// column name, unless it is empty, is a "key:value" option, or is one of the
// bare option flags understood by parseOrmTag (pk, primary_key,
// auto_increment, not_null, unique). Without that exclusion a tag such as
// orm:"pk,auto_increment" would name the column "pk". In all other cases the
// Go field name converted to snake_case is used.
func (am *AutoMigrateCore) getColumnName(field reflect.StructField, ormTag string) string {
	name, _, _ := strings.Cut(ormTag, ",")
	name = strings.TrimSpace(name)

	switch name {
	case "", "pk", "primary_key", "auto_increment", "not_null", "unique":
		// Nothing usable in first position: it is empty or an option flag.
	default:
		if !strings.Contains(name, ":") {
			return name
		}
	}
	return camelToSnake(field.Name)
}
// getFieldType determines the database column type for a struct field.
//
// An explicit "type:<spec>" option in the orm tag wins. Because the tag is
// split on commas, a spec with a parenthesised argument list such as
// DECIMAL(10,2) is stitched back together until its parentheses balance.
//
// Otherwise the type is inferred from the Go type, looking through any number
// of pointer levels so that optional fields (*int64, *time.Time, ...) map to
// the same column type as their non-pointer form:
//
//	signed integers   -> BIGINT
//	unsigned integers -> BIGINT UNSIGNED
//	float32 / float64 -> FLOAT / DOUBLE
//	bool              -> BOOLEAN
//	string            -> VARCHAR(n) when getStringLength reports n > 0, else TEXT
//	struct            -> TIMESTAMP when the type name contains "time.Time", else TEXT
//	[]byte            -> BLOB
//	other slices      -> JSON
//	anything else     -> TEXT
func (am *AutoMigrateCore) getFieldType(field reflect.StructField, fieldValue reflect.Value) string {
	if ormTag := field.Tag.Get("orm"); ormTag != "" {
		parts := strings.Split(ormTag, ",")
		for i, part := range parts {
			part = strings.TrimSpace(part)
			if !strings.HasPrefix(part, "type:") {
				continue
			}
			typeSpec := strings.TrimPrefix(part, "type:")
			// Re-attach pieces that the comma split tore out of "(...)".
			for _, rest := range parts[i+1:] {
				if strings.Count(typeSpec, "(") <= strings.Count(typeSpec, ")") {
					break
				}
				typeSpec += "," + rest
			}
			return strings.TrimSpace(typeSpec)
		}
	}

	// Work on the static type: it is available even for nil pointers, whose
	// value cannot be dereferenced.
	typ := field.Type
	if fieldValue.IsValid() {
		typ = fieldValue.Type()
	}
	if typ == nil {
		return "TEXT"
	}
	for typ.Kind() == reflect.Ptr {
		typ = typ.Elem()
	}

	switch typ.Kind() {
	case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64:
		return "BIGINT"
	case reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64:
		return "BIGINT UNSIGNED"
	case reflect.Float32:
		return "FLOAT"
	case reflect.Float64:
		return "DOUBLE"
	case reflect.Bool:
		return "BOOLEAN"
	case reflect.String:
		if length := am.getStringLength(field); length > 0 {
			return fmt.Sprintf("VARCHAR(%d)", length)
		}
		return "TEXT"
	case reflect.Struct:
		// Substring match on purpose: it recognises time.Time as well as
		// any other type whose qualified name ends the same way.
		if strings.Contains(typ.String(), "time.Time") {
			return "TIMESTAMP"
		}
		return "TEXT"
	case reflect.Slice:
		if typ.Elem().Kind() == reflect.Uint8 {
			return "BLOB"
		}
		return "JSON"
	default:
		return "TEXT"
	}
}
// getStringLength returns the string length requested through struct tags.
//
// The orm tag (comma-separated, option "length:<n>") is checked first, then
// the gorm tag (semicolon-separated, option "size:<n>"). The first matching
// option decides the result even if its number cannot be parsed, in which
// case the result is 0. 0 is also returned when neither tag has such an
// option.
func (am *AutoMigrateCore) getStringLength(field reflect.StructField) int {
	// lookup scans one tag for the first option starting with prefix and
	// reports whether such an option was present.
	lookup := func(tagKey, separator, prefix string) (int, bool) {
		for _, option := range strings.Split(field.Tag.Get(tagKey), separator) {
			option = strings.TrimSpace(option)
			if !strings.HasPrefix(option, prefix) {
				continue
			}
			var n int
			// A scan failure is deliberately ignored: n then stays 0.
			_, _ = fmt.Sscanf(option[len(prefix):], "%d", &n)
			return n, true
		}
		return 0, false
	}

	if n, found := lookup("orm", ",", "length:"); found {
		return n
	}
	if n, found := lookup("gorm", ";", "size:"); found {
		return n
	}
	return 0
}
// parseOrmTag applies the options of an orm tag to colDef.
//
// The tag is a comma-separated list; each element is trimmed and matched
// case-sensitively:
//
//	pk, primary_key -> primary key, not nullable
//	auto_increment  -> auto increment
//	not_null        -> not nullable
//	unique          -> unique
//	default:<v>     -> default value <v>
//	comment:<text>  -> column comment <text>
//
// Any other element (for example a column name) is ignored.
func (am *AutoMigrateCore) parseOrmTag(colDef *ColumnDefinition, ormTag string) {
	for _, raw := range strings.Split(ormTag, ",") {
		option := strings.TrimSpace(raw)
		switch option {
		case "pk", "primary_key":
			colDef.PrimaryKey, colDef.Null = true, false
		case "auto_increment":
			colDef.AutoIncrement = true
		case "not_null":
			colDef.Null = false
		case "unique":
			colDef.Unique = true
		default:
			if strings.HasPrefix(option, "default:") {
				colDef.Default = strings.TrimPrefix(option, "default:")
			} else if strings.HasPrefix(option, "comment:") {
				colDef.Comment = strings.TrimPrefix(option, "comment:")
			}
		}
	}
}
// parseGormTag applies the options of a gorm tag to colDef, for compatibility
// with structs written for GORM.
//
// The tag is a semicolon-separated list; each element is trimmed and its
// keyword is matched case-insensitively:
//
//	primaryKey, primary_key -> primary key, not nullable
//	autoIncrement           -> auto increment
//	not null                -> not nullable
//	unique                  -> unique
//	default:<v>             -> default value <v>
//	comment:<text>          -> column comment <text>
//
// The value after "default:" / "comment:" keeps its original spelling, and
// the keyword is stripped regardless of how it was capitalised. Any other
// element is ignored.
func (am *AutoMigrateCore) parseGormTag(colDef *ColumnDefinition, gormTag string) {
	// valueAfter returns the remainder of option when it starts with prefix,
	// comparing the prefix without regard to case.
	valueAfter := func(option, prefix string) (string, bool) {
		if len(option) < len(prefix) || !strings.EqualFold(option[:len(prefix)], prefix) {
			return "", false
		}
		return option[len(prefix):], true
	}

	for _, raw := range strings.Split(gormTag, ";") {
		option := strings.TrimSpace(raw)
		switch {
		case strings.EqualFold(option, "primaryKey"), strings.EqualFold(option, "primary_key"):
			colDef.PrimaryKey = true
			colDef.Null = false
		case strings.EqualFold(option, "autoIncrement"):
			colDef.AutoIncrement = true
		case strings.EqualFold(option, "not null"):
			colDef.Null = false
		case strings.EqualFold(option, "unique"):
			colDef.Unique = true
		default:
			if value, ok := valueAfter(option, "default:"); ok {
				colDef.Default = value
			} else if value, ok := valueAfter(option, "comment:"); ok {
				colDef.Comment = value
			}
		}
	}
}
// createTableFromColumns creates table with the given column definitions by
// delegating to the Migration returned by getMigrationInstance. If no
// Migration is available, an error naming the configured database type is
// returned instead.
func (am *AutoMigrateCore) createTableFromColumns(ctx context.Context, table string, columns map[string]*ColumnDefinition) error {
	if migration := am.getMigrationInstance(); migration != nil {
		return migration.CreateTable(ctx, table, columns)
	}
	return fmt.Errorf("failed to get migration instance for database type %s", am.db.GetConfig().Type)
}
// updateTableStructure reconciles an existing table with the wanted columns.
//
// The table's current fields are fetched once. Every wanted column that is
// missing is added; every column that already exists is handed to
// modifyColumnIfNeeded together with its current field. Columns present in
// the table but absent from newColumns are left alone. Processing stops at
// the first error.
//
// NOTE(review): newColumns is a map, so the order in which columns are added
// varies between runs.
func (am *AutoMigrateCore) updateTableStructure(ctx context.Context, table string, newColumns map[string]*ColumnDefinition) error {
	existingFields, err := am.db.TableFields(ctx, table)
	if err != nil {
		return fmt.Errorf("failed to get table fields: %w", err)
	}

	for colName, colDef := range newColumns {
		current, present := existingFields[colName]
		if present {
			if err := am.modifyColumnIfNeeded(ctx, table, colName, colDef, current); err != nil {
				return fmt.Errorf("failed to modify column %s: %w", colName, err)
			}
			continue
		}
		if err := am.addColumn(ctx, table, colName, colDef); err != nil {
			return fmt.Errorf("failed to add column %s: %w", colName, err)
		}
	}
	return nil
}
// addColumn adds column to table using the given definition by delegating to
// the Migration returned by getMigrationInstance. It fails when no Migration
// is available.
func (am *AutoMigrateCore) addColumn(ctx context.Context, table, column string, definition *ColumnDefinition) error {
	if migration := am.getMigrationInstance(); migration != nil {
		return migration.AddColumn(ctx, table, column, definition)
	}
	return fmt.Errorf("failed to get migration instance")
}
// modifyColumnIfNeeded alters an existing column when it differs from the
// wanted definition.
//
// Only two properties are compared: the type string (ignoring case) and
// nullability. If both agree nothing happens and nil is returned; otherwise
// the column is modified to newDef via modifyColumn.
//
// NOTE(review): the type comparison is purely textual, so a database that
// reports types in a different spelling than the generated one (display
// widths, aliases) would trigger a modification on every run — confirm
// against the TableFields output of the target databases.
func (am *AutoMigrateCore) modifyColumnIfNeeded(ctx context.Context, table, column string, newDef *ColumnDefinition, existingField *TableField) error {
	sameType := strings.EqualFold(existingField.Type, newDef.Type)
	sameNullability := existingField.Null == newDef.Null
	if sameType && sameNullability {
		return nil
	}
	return am.modifyColumn(ctx, table, column, newDef)
}
// modifyColumn changes an existing column of table to match definition by
// delegating to the Migration returned by getMigrationInstance. It fails when
// no Migration is available.
func (am *AutoMigrateCore) modifyColumn(ctx context.Context, table, column string, definition *ColumnDefinition) error {
	if migration := am.getMigrationInstance(); migration != nil {
		return migration.ModifyColumn(ctx, table, column, definition)
	}
	return fmt.Errorf("failed to get migration instance")
}
// getMigrationInstance returns the Migration supplied by the wrapped database
// handle. Choosing an implementation is left entirely to db.GetMigration;
// callers in this file treat a nil result as "no migration available".
func (am *AutoMigrateCore) getMigrationInstance() Migration {
	migration := am.db.GetMigration()
	return migration
}
// camelToSnake converts a CamelCase identifier to snake_case.
//
// Every ASCII upper-case letter is lower-cased. An underscore is inserted in
// front of it when it is not the first character and either the preceding
// byte or the following byte is an ASCII lower-case letter, so runs of
// capitals stay together: "UserID" -> "user_id", "HTTPServer" ->
// "http_server". All other characters, including non-ASCII ones, are copied
// unchanged.
func camelToSnake(s string) string {
	isLower := func(b byte) bool { return b >= 'a' && b <= 'z' }

	var out strings.Builder
	for pos, ch := range s {
		if ch < 'A' || ch > 'Z' {
			out.WriteRune(ch)
			continue
		}
		if pos > 0 {
			// ch is a single byte here, so pos+1 is the start of the next character.
			followedByLower := pos+1 < len(s) && isLower(s[pos+1])
			if isLower(s[pos-1]) || followedByLower {
				out.WriteByte('_')
			}
		}
		out.WriteRune(ch + ('a' - 'A'))
	}
	return out.String()
}
// FormatTime renders t as "YYYY-MM-DD HH:MM:SS" (24-hour clock, no fractional
// seconds, no zone), using t's own location. The result is meant for use as a
// default value.
func FormatTime(t time.Time) string {
	const layout = "2006-01-02 15:04:05"
	buf := make([]byte, 0, len(layout))
	return string(t.AppendFormat(buf, layout))
}
// IsZeroValue reports whether v is nil, empty, or the zero value of its type.
//
//   - nil                                    -> true
//   - string, slice, map                     -> true when the length is 0
//     (so a non-nil empty slice or map also counts)
//   - bool, integers, floats, complex        -> true when false / equal to 0
//   - pointer, interface, chan, func,
//     unsafe.Pointer                         -> true when nil
//   - struct, array                          -> true when every field or
//     element holds its zero value
//
// Structs, arrays, channels, functions and complex numbers used to be
// reported as non-zero unconditionally; they are now examined.
func IsZeroValue(v any) bool {
	if v == nil {
		return true
	}
	rv := reflect.ValueOf(v)
	switch rv.Kind() {
	case reflect.Map, reflect.Slice, reflect.String:
		return rv.Len() == 0
	case reflect.Bool:
		return !rv.Bool()
	case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64:
		return rv.Int() == 0
	case reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64, reflect.Uintptr:
		return rv.Uint() == 0
	case reflect.Float32, reflect.Float64:
		// Compared numerically so that negative zero counts as zero too.
		return rv.Float() == 0
	case reflect.Complex64, reflect.Complex128:
		return rv.Complex() == 0
	case reflect.Interface, reflect.Ptr, reflect.Chan, reflect.Func, reflect.UnsafePointer:
		return rv.IsNil()
	case reflect.Array, reflect.Struct:
		return rv.IsZero()
	}
	return false
}