500 lines
14 KiB
Go
500 lines
14 KiB
Go
// Copyright GoFrame Author(https://goframe.org). All Rights Reserved.
|
|
//
|
|
// This Source Code Form is subject to the terms of the MIT License.
|
|
// If a copy of the MIT was not distributed with this file,
|
|
// You can obtain one at https://github.com/gogf/gf.
|
|
|
|
package database
|
|
|
|
import (
|
|
"context"
|
|
"fmt"
|
|
"os"
|
|
"path/filepath"
|
|
"reflect"
|
|
"strings"
|
|
"time"
|
|
)
|
|
|
|
// AutoMigrateCore provides automatic migration functionality based on entity structs.
|
|
type AutoMigrateCore struct {
|
|
db DB
|
|
}
|
|
|
|
// NewAutoMigrateCore creates a new AutoMigrateCore instance.
|
|
func NewAutoMigrateCore(db DB) *AutoMigrateCore {
|
|
return &AutoMigrateCore{db: db}
|
|
}
|
|
|
|
// AutoMigrate automatically creates or updates tables based on entity structs.
|
|
func (am *AutoMigrateCore) AutoMigrate(ctx context.Context, entities ...any) error {
|
|
for _, entity := range entities {
|
|
if err := am.migrateEntity(ctx, entity); err != nil {
|
|
return fmt.Errorf("failed to migrate entity %T: %w", entity, err)
|
|
}
|
|
}
|
|
return nil
|
|
}
|
|
|
|
// migrateEntity migrates a single entity to database.
|
|
func (am *AutoMigrateCore) migrateEntity(ctx context.Context, entity any) error {
|
|
// Get table name and columns from entity
|
|
tableName, columns, err := am.parseEntity(entity)
|
|
if err != nil {
|
|
return err
|
|
}
|
|
|
|
if len(columns) == 0 {
|
|
return fmt.Errorf("no columns found for table %s", tableName)
|
|
}
|
|
|
|
// For SQLite, ensure the database file directory exists
|
|
if am.db.GetConfig().Type == "sqlite" {
|
|
if err := am.ensureSQLiteDirectory(); err != nil {
|
|
return fmt.Errorf("failed to prepare SQLite database: %w", err)
|
|
}
|
|
}
|
|
|
|
// Check if table exists
|
|
hasTable, err := am.db.HasTable(ctx, tableName)
|
|
if err != nil {
|
|
return fmt.Errorf("failed to check table existence: %w", err)
|
|
}
|
|
|
|
if !hasTable {
|
|
// Create table
|
|
return am.createTableFromColumns(ctx, tableName, columns)
|
|
}
|
|
|
|
// Update table structure
|
|
return am.updateTableStructure(ctx, tableName, columns)
|
|
}
|
|
|
|
// parseEntity parses an entity struct and returns table name and column definitions.
|
|
func (am *AutoMigrateCore) parseEntity(entity any) (string, map[string]*ColumnDefinition, error) {
|
|
val := reflect.ValueOf(entity)
|
|
|
|
// Handle pointer
|
|
if val.Kind() == reflect.Ptr {
|
|
val = val.Elem()
|
|
}
|
|
|
|
if val.Kind() != reflect.Struct {
|
|
return "", nil, fmt.Errorf("entity must be a struct or pointer to struct")
|
|
}
|
|
|
|
typ := val.Type()
|
|
|
|
// Get table name from orm tag or struct name
|
|
tableName := am.getTableName(typ)
|
|
|
|
// Parse columns
|
|
columns := make(map[string]*ColumnDefinition)
|
|
for i := 0; i < val.NumField(); i++ {
|
|
field := typ.Field(i)
|
|
fieldValue := val.Field(i)
|
|
|
|
// Skip unexported fields
|
|
if !field.IsExported() {
|
|
continue
|
|
}
|
|
|
|
// Parse field to column definition
|
|
colName, colDef, err := am.parseField(field, fieldValue)
|
|
if err != nil {
|
|
return "", nil, fmt.Errorf("failed to parse field %s: %w", field.Name, err)
|
|
}
|
|
|
|
if colName != "" && colDef != nil {
|
|
columns[colName] = colDef
|
|
}
|
|
}
|
|
|
|
return tableName, columns, nil
|
|
}
|
|
|
|
// getTableName extracts table name from struct tags or generates from struct name.
|
|
func (am *AutoMigrateCore) getTableName(typ reflect.Type) string {
|
|
// Check for orm tag
|
|
if tag, ok := typ.FieldByName("Meta"); ok {
|
|
ormTag := tag.Tag.Get("orm")
|
|
if ormTag != "" {
|
|
// Parse table:name from orm tag
|
|
parts := strings.Split(ormTag, ",")
|
|
for _, part := range parts {
|
|
part = strings.TrimSpace(part)
|
|
if strings.HasPrefix(part, "table:") {
|
|
return strings.TrimPrefix(part, "table:")
|
|
}
|
|
}
|
|
}
|
|
}
|
|
|
|
// Check for table method (commented out as it requires instance)
|
|
// method, hasMethod := typ.MethodByName("TableName")
|
|
// if hasMethod {
|
|
// // Try to call TableName method if it exists
|
|
// }
|
|
|
|
// Convert struct name to snake_case table name
|
|
return camelToSnake(typ.Name())
|
|
}
|
|
|
|
// parseField parses a struct field into a column definition.
|
|
func (am *AutoMigrateCore) parseField(field reflect.StructField, fieldValue reflect.Value) (string, *ColumnDefinition, error) {
|
|
// Check orm tag
|
|
ormTag := field.Tag.Get("orm")
|
|
if ormTag == "-" {
|
|
// Skip field
|
|
return "", nil, nil
|
|
}
|
|
|
|
// Get column name
|
|
colName := am.getColumnName(field, ormTag)
|
|
|
|
// Build column definition
|
|
colDef := &ColumnDefinition{
|
|
Type: am.getFieldType(field, fieldValue),
|
|
Null: true, // Default to nullable
|
|
}
|
|
|
|
// Parse orm tag options
|
|
if ormTag != "" {
|
|
am.parseOrmTag(colDef, ormTag)
|
|
}
|
|
|
|
// Check for gorm/tag compatibility
|
|
if gormTag := field.Tag.Get("gorm"); gormTag != "" {
|
|
am.parseGormTag(colDef, gormTag)
|
|
}
|
|
|
|
// Check json tag for field presence
|
|
jsonTag := field.Tag.Get("json")
|
|
if jsonTag == "-" {
|
|
return "", nil, nil
|
|
}
|
|
|
|
return colName, colDef, nil
|
|
}
|
|
|
|
// getColumnName extracts column name from field name or tags.
|
|
func (am *AutoMigrateCore) getColumnName(field reflect.StructField, ormTag string) string {
|
|
// Check orm tag for explicit column name
|
|
if ormTag != "" {
|
|
parts := strings.Split(ormTag, ",")
|
|
namePart := strings.TrimSpace(parts[0])
|
|
if namePart != "" && !strings.Contains(namePart, ":") {
|
|
return namePart
|
|
}
|
|
}
|
|
|
|
// Use field name converted to snake_case
|
|
return camelToSnake(field.Name)
|
|
}
|
|
|
|
// getFieldType determines the database type for a Go field type.
|
|
func (am *AutoMigrateCore) getFieldType(field reflect.StructField, fieldValue reflect.Value) string {
|
|
// Check for explicit type in orm tag
|
|
ormTag := field.Tag.Get("orm")
|
|
if ormTag != "" {
|
|
parts := strings.Split(ormTag, ",")
|
|
for _, part := range parts {
|
|
part = strings.TrimSpace(part)
|
|
if strings.HasPrefix(part, "type:") {
|
|
return strings.TrimPrefix(part, "type:")
|
|
}
|
|
}
|
|
}
|
|
|
|
// Infer type from Go type
|
|
switch fieldValue.Kind() {
|
|
case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64:
|
|
return "BIGINT"
|
|
case reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64:
|
|
return "BIGINT UNSIGNED"
|
|
case reflect.Float32:
|
|
return "FLOAT"
|
|
case reflect.Float64:
|
|
return "DOUBLE"
|
|
case reflect.Bool:
|
|
return "BOOLEAN"
|
|
case reflect.String:
|
|
// Check for length specification
|
|
length := am.getStringLength(field)
|
|
if length > 0 {
|
|
return fmt.Sprintf("VARCHAR(%d)", length)
|
|
}
|
|
return "TEXT"
|
|
case reflect.Struct:
|
|
// Handle special types
|
|
typeName := fieldValue.Type().String()
|
|
switch {
|
|
case strings.Contains(typeName, "time.Time"):
|
|
return "TIMESTAMP"
|
|
default:
|
|
return "TEXT"
|
|
}
|
|
case reflect.Slice:
|
|
elemKind := fieldValue.Type().Elem().Kind()
|
|
if elemKind == reflect.Uint8 {
|
|
return "BLOB"
|
|
}
|
|
return "JSON"
|
|
default:
|
|
return "TEXT"
|
|
}
|
|
}
|
|
|
|
// getStringLength gets the specified string length from tags.
|
|
func (am *AutoMigrateCore) getStringLength(field reflect.StructField) int {
|
|
// Check orm tag
|
|
ormTag := field.Tag.Get("orm")
|
|
if ormTag != "" {
|
|
parts := strings.Split(ormTag, ",")
|
|
for _, part := range parts {
|
|
part = strings.TrimSpace(part)
|
|
if strings.HasPrefix(part, "length:") {
|
|
var length int
|
|
fmt.Sscanf(strings.TrimPrefix(part, "length:"), "%d", &length)
|
|
return length
|
|
}
|
|
}
|
|
}
|
|
|
|
// Check gorm tag
|
|
gormTag := field.Tag.Get("gorm")
|
|
if gormTag != "" {
|
|
parts := strings.Split(gormTag, ";")
|
|
for _, part := range parts {
|
|
part = strings.TrimSpace(part)
|
|
if strings.HasPrefix(part, "size:") {
|
|
var size int
|
|
fmt.Sscanf(strings.TrimPrefix(part, "size:"), "%d", &size)
|
|
return size
|
|
}
|
|
}
|
|
}
|
|
|
|
return 0
|
|
}
|
|
|
|
// parseOrmTag parses orm tag options.
|
|
func (am *AutoMigrateCore) parseOrmTag(colDef *ColumnDefinition, ormTag string) {
|
|
parts := strings.Split(ormTag, ",")
|
|
for _, part := range parts {
|
|
part = strings.TrimSpace(part)
|
|
switch {
|
|
case part == "pk" || part == "primary_key":
|
|
colDef.PrimaryKey = true
|
|
colDef.Null = false
|
|
case part == "auto_increment":
|
|
colDef.AutoIncrement = true
|
|
case part == "not_null":
|
|
colDef.Null = false
|
|
case part == "unique":
|
|
colDef.Unique = true
|
|
case strings.HasPrefix(part, "default:"):
|
|
defaultVal := strings.TrimPrefix(part, "default:")
|
|
colDef.Default = defaultVal
|
|
case strings.HasPrefix(part, "comment:"):
|
|
colDef.Comment = strings.TrimPrefix(part, "comment:")
|
|
}
|
|
}
|
|
}
|
|
|
|
// parseGormTag parses gorm tag options for compatibility.
|
|
func (am *AutoMigrateCore) parseGormTag(colDef *ColumnDefinition, gormTag string) {
|
|
parts := strings.Split(gormTag, ";")
|
|
for _, part := range parts {
|
|
part = strings.TrimSpace(part)
|
|
lowerPart := strings.ToLower(part)
|
|
switch {
|
|
case lowerPart == "primarykey" || lowerPart == "primaryKey":
|
|
colDef.PrimaryKey = true
|
|
colDef.Null = false
|
|
case lowerPart == "autoincrement":
|
|
colDef.AutoIncrement = true
|
|
case lowerPart == "not null":
|
|
colDef.Null = false
|
|
case lowerPart == "unique":
|
|
colDef.Unique = true
|
|
case strings.HasPrefix(lowerPart, "default:"):
|
|
defaultVal := strings.TrimPrefix(part, "default:")
|
|
colDef.Default = defaultVal
|
|
case strings.HasPrefix(lowerPart, "comment:"):
|
|
colDef.Comment = strings.TrimPrefix(part, "comment:")
|
|
}
|
|
}
|
|
}
|
|
|
|
// createTableFromColumns creates a table from column definitions.
|
|
func (am *AutoMigrateCore) createTableFromColumns(ctx context.Context, table string, columns map[string]*ColumnDefinition) error {
|
|
// Get the migration instance based on database type
|
|
migration := am.getMigrationInstance()
|
|
if migration == nil {
|
|
return fmt.Errorf("failed to get migration instance for database type %s", am.db.GetConfig().Type)
|
|
}
|
|
|
|
return migration.CreateTable(ctx, table, columns)
|
|
}
|
|
|
|
// updateTableStructure updates table structure by comparing with existing columns.
|
|
func (am *AutoMigrateCore) updateTableStructure(ctx context.Context, table string, newColumns map[string]*ColumnDefinition) error {
|
|
// Get existing columns
|
|
existingFields, err := am.db.TableFields(ctx, table)
|
|
if err != nil {
|
|
return fmt.Errorf("failed to get table fields: %w", err)
|
|
}
|
|
|
|
// Add missing columns
|
|
for colName, colDef := range newColumns {
|
|
if _, exists := existingFields[colName]; !exists {
|
|
// Column doesn't exist, add it
|
|
if err := am.addColumn(ctx, table, colName, colDef); err != nil {
|
|
return fmt.Errorf("failed to add column %s: %w", colName, err)
|
|
}
|
|
} else {
|
|
// Column exists, check if modification is needed
|
|
if err := am.modifyColumnIfNeeded(ctx, table, colName, colDef, existingFields[colName]); err != nil {
|
|
return fmt.Errorf("failed to modify column %s: %w", colName, err)
|
|
}
|
|
}
|
|
}
|
|
|
|
return nil
|
|
}
|
|
|
|
// addColumn adds a new column to table.
|
|
func (am *AutoMigrateCore) addColumn(ctx context.Context, table, column string, definition *ColumnDefinition) error {
|
|
migration := am.getMigrationInstance()
|
|
if migration == nil {
|
|
return fmt.Errorf("failed to get migration instance")
|
|
}
|
|
return migration.AddColumn(ctx, table, column, definition)
|
|
}
|
|
|
|
// modifyColumnIfNeeded checks if column needs modification and applies it.
|
|
func (am *AutoMigrateCore) modifyColumnIfNeeded(ctx context.Context, table, column string, newDef *ColumnDefinition, existingField *TableField) error {
|
|
// Compare and modify if needed
|
|
needsModification := false
|
|
|
|
// Check type
|
|
if !strings.EqualFold(existingField.Type, newDef.Type) {
|
|
needsModification = true
|
|
}
|
|
|
|
// Check nullability
|
|
if existingField.Null != newDef.Null {
|
|
needsModification = true
|
|
}
|
|
|
|
if needsModification {
|
|
return am.modifyColumn(ctx, table, column, newDef)
|
|
}
|
|
|
|
return nil
|
|
}
|
|
|
|
// modifyColumn modifies an existing column.
|
|
func (am *AutoMigrateCore) modifyColumn(ctx context.Context, table, column string, definition *ColumnDefinition) error {
|
|
migration := am.getMigrationInstance()
|
|
if migration == nil {
|
|
return fmt.Errorf("failed to get migration instance")
|
|
}
|
|
return migration.ModifyColumn(ctx, table, column, definition)
|
|
}
|
|
|
|
// getMigrationInstance returns the appropriate Migration instance based on database type.
|
|
func (am *AutoMigrateCore) getMigrationInstance() Migration {
|
|
return am.db.GetMigration()
|
|
}
|
|
|
|
// camelToSnake converts a CamelCase identifier to snake_case.
//
// Every ASCII upper-case letter is lower-cased. An underscore is inserted in
// front of it when it is not the first character and either the preceding
// byte or the following byte is an ASCII lower-case letter, so runs of
// capitals stay together: "UserID" -> "user_id", "HTTPServer" -> "http_server".
// All other characters are copied unchanged.
func camelToSnake(s string) string {
	isLower := func(b byte) bool { return 'a' <= b && b <= 'z' }

	var out strings.Builder
	for pos, ch := range s {
		if ch < 'A' || ch > 'Z' {
			out.WriteRune(ch)
			continue
		}
		if pos > 0 {
			afterLower := isLower(s[pos-1])
			beforeLower := pos+1 < len(s) && isLower(s[pos+1])
			if afterLower || beforeLower {
				out.WriteByte('_')
			}
		}
		out.WriteRune(ch + ('a' - 'A'))
	}
	return out.String()
}
|
|
|
|
// FormatTime renders t as "YYYY-MM-DD HH:MM:SS" (24-hour clock, second
// precision, no time zone), a form suitable for column default values.
func FormatTime(t time.Time) string {
	const layout = "2006-01-02 15:04:05"
	return t.Format(layout)
}
|
|
|
|
// IsZeroValue reports whether v holds no meaningful value.
//
//   - nil is zero.
//   - Maps, slices and strings are zero when their length is 0 (an empty but
//     non-nil map or slice counts as zero).
//   - Booleans are zero when false; integers and floats when equal to 0.
//   - Pointers are zero when nil.
//   - Every other kind (structs, arrays, channels, funcs, complex numbers, ...)
//     is zero when it equals the zero value of its type, so time.Time{} or
//     [2]int{} are zero while a struct with any field set is not.
func IsZeroValue(v any) bool {
	if v == nil {
		return true
	}

	rv := reflect.ValueOf(v)
	switch rv.Kind() {
	case reflect.Map, reflect.Slice, reflect.String:
		return rv.Len() == 0
	case reflect.Bool:
		return !rv.Bool()
	case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64:
		return rv.Int() == 0
	case reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64, reflect.Uintptr:
		return rv.Uint() == 0
	case reflect.Float32, reflect.Float64:
		// Compared numerically so that negative zero also counts as zero.
		return rv.Float() == 0
	case reflect.Interface, reflect.Ptr:
		return rv.IsNil()
	}

	// Remaining kinds have no cheaper test than a comparison with the zero
	// value of their type.
	return rv.IsZero()
}
|
|
|
|
// ensureSQLiteDirectory ensures the SQLite database file directory exists.
|
|
func (am *AutoMigrateCore) ensureSQLiteDirectory() error {
|
|
config := am.db.GetConfig()
|
|
if config.Type != "sqlite" {
|
|
return nil
|
|
}
|
|
|
|
// Get the database file path from config.Name or config.Link
|
|
dbPath := config.Name
|
|
if dbPath == "" && config.Link != "" {
|
|
// Parse link format: sqlite::@file(./data/company.db)
|
|
if strings.Contains(config.Link, "@file(") {
|
|
start := strings.Index(config.Link, "@file(") + 6
|
|
end := strings.Index(config.Link[start:], ")")
|
|
if end > 0 {
|
|
dbPath = config.Link[start : start+end]
|
|
}
|
|
}
|
|
}
|
|
|
|
if dbPath == "" {
|
|
return fmt.Errorf("SQLite database path is empty")
|
|
}
|
|
|
|
// Get directory path
|
|
dir := filepath.Dir(dbPath)
|
|
|
|
// Check if directory exists
|
|
if _, err := os.Stat(dir); os.IsNotExist(err) {
|
|
// Create directory with all parents
|
|
if err := os.MkdirAll(dir, 0755); err != nil {
|
|
return fmt.Errorf("failed to create directory %s: %w", dir, err)
|
|
}
|
|
}
|
|
|
|
return nil
|
|
}
|