426 lines
12 KiB
Go
426 lines
12 KiB
Go
package main
|
||
|
||
import (
|
||
"flag"
|
||
"fmt"
|
||
"os"
|
||
"strings"
|
||
|
||
"git.magicany.cc/black1552/gin-base/db/config"
|
||
"git.magicany.cc/black1552/gin-base/db/generator"
|
||
"git.magicany.cc/black1552/gin-base/db/introspector"
|
||
)
|
||
|
||
// init is currently a no-op.
//
// NOTE(review): this was documented as setting the Windows console output
// code page to UTF-8 (65001) to avoid garbled Chinese output, but the body
// makes no such call. Either implement it (e.g. in a windows-only file
// guarded by a build constraint) or remove this function.
func init() {
	// Intentionally empty; see the note above the function.
}
|
||
|
||
// version is the generator's release version. It is shown by -v/-version
// and in the banner printed at the start of each generation run.
const (
	version = "1.0.0"
)
|
||
|
||
func main() {
|
||
// 定义命令行参数
|
||
versionFlag := flag.Bool("version", false, "显示版本号")
|
||
vFlag := flag.Bool("v", false, "显示版本号(简写)")
|
||
helpFlag := flag.Bool("help", false, "显示帮助信息")
|
||
hFlag := flag.Bool("h", false, "显示帮助信息(简写)")
|
||
outputDir := flag.String("o", "./model", "输出目录")
|
||
allFlag := flag.Bool("all", false, "生成所有预设的表(user, product, order)")
|
||
|
||
flag.Usage = func() {
|
||
fmt.Fprintf(os.Stderr, `Magic-ORM 代码生成器 - 快速生成 Model 和 DAO 代码
|
||
|
||
用法:
|
||
gendb [选项] <表名> [列定义...]
|
||
gendb [选项] -all
|
||
|
||
选项:
|
||
`)
|
||
flag.PrintDefaults()
|
||
|
||
fmt.Fprintf(os.Stderr, `
|
||
示例:
|
||
# 生成 user 表代码(自动推断常用列)
|
||
gendb user
|
||
|
||
# 指定输出目录
|
||
gendb -o ./models user product
|
||
|
||
# 自定义列定义
|
||
gendb user id:int64:primary username:string email:string created_at:time.Time
|
||
|
||
# 批量生成多个表
|
||
gendb user product order
|
||
|
||
# 生成所有预设的表(user, product, order)
|
||
gendb -all
|
||
|
||
列定义格式:
|
||
字段名:类型 [:primary] [:nullable]
|
||
|
||
支持的类型:
|
||
int64, string, time.Time, bool, float64, int
|
||
|
||
更多信息:
|
||
https://github.com/your-repo/magic-orm
|
||
`)
|
||
}
|
||
|
||
flag.Parse()
|
||
|
||
// 检查版本参数
|
||
if *versionFlag || *vFlag {
|
||
fmt.Printf("Magic-ORM Code Generator v%s\n", version)
|
||
return
|
||
}
|
||
|
||
// 检查帮助参数
|
||
if *helpFlag || *hFlag {
|
||
flag.Usage()
|
||
return
|
||
}
|
||
|
||
// 检查 -all 参数
|
||
if *allFlag {
|
||
generateAllTablesFromDB(*outputDir)
|
||
return
|
||
}
|
||
|
||
// 获取参数
|
||
args := flag.Args()
|
||
if len(args) == 0 {
|
||
fmt.Fprintln(os.Stderr, "错误:请指定至少一个表名")
|
||
fmt.Fprintln(os.Stderr, "使用 'gendb -h' 查看帮助")
|
||
fmt.Fprintln(os.Stderr, "或者使用 'gendb -all' 生成所有预设表")
|
||
os.Exit(1)
|
||
}
|
||
|
||
tableNames := args
|
||
|
||
// 创建代码生成器
|
||
cg := generator.NewCodeGenerator(*outputDir)
|
||
|
||
fmt.Printf("[Magic-ORM Code Generator v%s]\n", version)
|
||
fmt.Printf("[Output Directory: %s]\n", *outputDir)
|
||
fmt.Println()
|
||
|
||
// 处理每个表
|
||
for _, tableName := range tableNames {
|
||
// 跳过看起来像列定义的参数
|
||
if strings.Contains(tableName, ":") {
|
||
continue
|
||
}
|
||
|
||
fmt.Printf("[Generating table '%s'...]\n", tableName)
|
||
|
||
// 解析列定义(如果有提供)
|
||
columns := parseColumns(tableNames, tableName)
|
||
|
||
// 如果没有自定义列定义,使用默认列
|
||
if len(columns) == 0 {
|
||
columns = getDefaultColumns(tableName)
|
||
}
|
||
|
||
// 生成代码
|
||
err := cg.GenerateAll(tableName, columns)
|
||
if err != nil {
|
||
fmt.Fprintf(os.Stderr, "[Error] Generation failed: %v\n", err)
|
||
continue
|
||
}
|
||
|
||
fmt.Printf("[Success] Generated %s.go and %s_dao.go\n", tableName, tableName)
|
||
}
|
||
|
||
fmt.Println()
|
||
fmt.Println("[Complete] Code generation finished!")
|
||
fmt.Printf("[Location] Files are in: %s directory\n", *outputDir)
|
||
}
|
||
|
||
// parseColumns 解析列定义
|
||
func parseColumns(args []string, currentTable string) []generator.ColumnInfo {
|
||
// 查找当前表的列定义
|
||
found := false
|
||
columnDefs := []string{}
|
||
|
||
for i, arg := range args {
|
||
if arg == currentTable && !found {
|
||
found = true
|
||
// 收集后续的列定义
|
||
for j := i + 1; j < len(args); j++ {
|
||
if strings.Contains(args[j], ":") {
|
||
columnDefs = append(columnDefs, args[j])
|
||
} else {
|
||
break // 遇到下一个表名
|
||
}
|
||
}
|
||
break
|
||
}
|
||
}
|
||
|
||
if len(columnDefs) == 0 {
|
||
return nil
|
||
}
|
||
|
||
columns := []generator.ColumnInfo{}
|
||
for _, def := range columnDefs {
|
||
parts := strings.Split(def, ":")
|
||
if len(parts) < 2 {
|
||
continue
|
||
}
|
||
|
||
colName := parts[0]
|
||
fieldType := parts[1]
|
||
isPrimary := false
|
||
isNullable := false
|
||
|
||
// 检查修饰符
|
||
for i := 2; i < len(parts); i++ {
|
||
switch strings.ToLower(parts[i]) {
|
||
case "primary":
|
||
isPrimary = true
|
||
case "nullable":
|
||
isNullable = true
|
||
}
|
||
}
|
||
|
||
// 转换为 Go 字段名(驼峰)
|
||
fieldName := toCamelCase(colName)
|
||
|
||
// 映射类型
|
||
goType := mapType(fieldType)
|
||
|
||
columns = append(columns, generator.ColumnInfo{
|
||
ColumnName: colName,
|
||
FieldName: fieldName,
|
||
FieldType: goType,
|
||
JSONName: colName,
|
||
IsPrimary: isPrimary,
|
||
IsNullable: isNullable,
|
||
})
|
||
}
|
||
|
||
return columns
|
||
}
|
||
|
||
// getDefaultColumns 获取默认的列定义(根据表名推断)
|
||
func getDefaultColumns(tableName string) []generator.ColumnInfo {
|
||
columns := []generator.ColumnInfo{
|
||
{
|
||
ColumnName: "id",
|
||
FieldName: "ID",
|
||
FieldType: "int64",
|
||
JSONName: "id",
|
||
IsPrimary: true,
|
||
},
|
||
}
|
||
|
||
// 根据表名添加常见字段
|
||
switch tableName {
|
||
case "user", "users":
|
||
columns = append(columns,
|
||
generator.ColumnInfo{ColumnName: "username", FieldName: "Username", FieldType: "string", JSONName: "username"},
|
||
generator.ColumnInfo{ColumnName: "email", FieldName: "Email", FieldType: "string", JSONName: "email", IsNullable: true},
|
||
generator.ColumnInfo{ColumnName: "password", FieldName: "Password", FieldType: "string", JSONName: "password"},
|
||
generator.ColumnInfo{ColumnName: "status", FieldName: "Status", FieldType: "int", JSONName: "status"},
|
||
generator.ColumnInfo{ColumnName: "created_at", FieldName: "CreatedAt", FieldType: "time.Time", JSONName: "created_at"},
|
||
generator.ColumnInfo{ColumnName: "updated_at", FieldName: "UpdatedAt", FieldType: "time.Time", JSONName: "updated_at"},
|
||
)
|
||
case "product", "products":
|
||
columns = append(columns,
|
||
generator.ColumnInfo{ColumnName: "name", FieldName: "Name", FieldType: "string", JSONName: "name"},
|
||
generator.ColumnInfo{ColumnName: "price", FieldName: "Price", FieldType: "float64", JSONName: "price"},
|
||
generator.ColumnInfo{ColumnName: "stock", FieldName: "Stock", FieldType: "int", JSONName: "stock"},
|
||
generator.ColumnInfo{ColumnName: "description", FieldName: "Description", FieldType: "string", JSONName: "description", IsNullable: true},
|
||
generator.ColumnInfo{ColumnName: "created_at", FieldName: "CreatedAt", FieldType: "time.Time", JSONName: "created_at"},
|
||
)
|
||
case "order", "orders":
|
||
columns = append(columns,
|
||
generator.ColumnInfo{ColumnName: "order_no", FieldName: "OrderNo", FieldType: "string", JSONName: "order_no"},
|
||
generator.ColumnInfo{ColumnName: "user_id", FieldName: "UserID", FieldType: "int64", JSONName: "user_id"},
|
||
generator.ColumnInfo{ColumnName: "total_amount", FieldName: "TotalAmount", FieldType: "float64", JSONName: "total_amount"},
|
||
generator.ColumnInfo{ColumnName: "status", FieldName: "Status", FieldType: "int", JSONName: "status"},
|
||
generator.ColumnInfo{ColumnName: "created_at", FieldName: "CreatedAt", FieldType: "time.Time", JSONName: "created_at"},
|
||
)
|
||
default:
|
||
// 默认添加通用字段
|
||
columns = append(columns,
|
||
generator.ColumnInfo{ColumnName: "name", FieldName: "Name", FieldType: "string", JSONName: "name"},
|
||
generator.ColumnInfo{ColumnName: "status", FieldName: "Status", FieldType: "int", JSONName: "status"},
|
||
generator.ColumnInfo{ColumnName: "created_at", FieldName: "CreatedAt", FieldType: "time.Time", JSONName: "created_at"},
|
||
generator.ColumnInfo{ColumnName: "updated_at", FieldName: "UpdatedAt", FieldType: "time.Time", JSONName: "updated_at"},
|
||
)
|
||
}
|
||
|
||
return columns
|
||
}
|
||
|
||
// mapType maps a user-supplied column type name to the Go type used in the
// generated model. Matching is case-insensitive:
//
//	int, int64, integer, bigint     -> int64
//	time.Time, time, datetime       -> time.Time
//	bool, boolean                   -> bool
//	float, float64, double          -> float64
//	string, text, varchar, decimal  -> string
//
// Any other type name falls back to string.
func mapType(typeStr string) string {
	switch strings.ToLower(typeStr) {
	case "int", "int64", "integer", "bigint":
		return "int64"
	case "time.time", "time", "datetime":
		// The input is lower-cased above, so it must be compared against
		// "time.time"; a "time.Time" key could never match.
		return "time.Time"
	case "bool", "boolean":
		return "bool"
	case "float", "float64", "double":
		return "float64"
	case "string", "text", "varchar", "decimal":
		// NOTE(review): decimal presumably maps to string to avoid float
		// rounding; confirm against the generator and driver.
		return "string"
	default:
		return "string"
	}
}
|
||
|
||
// toCamelCase converts a snake_case column name to an exported Go
// identifier: each '_'-separated part is capitalised and the parts are
// joined, e.g. "created_at" -> "CreatedAt".
//   - Empty parts from leading, trailing or repeated underscores are
//     dropped.
//   - Parts that are common Go initialisms are upper-cased in full, so
//     "user_id" -> "UserID". This matches the field names used by
//     getDefaultColumns.
//   - Capitalisation is rune-aware, so a part that starts with a multi-byte
//     UTF-8 character is not corrupted.
func toCamelCase(str string) string {
	var b strings.Builder
	b.Grow(len(str))

	for _, part := range strings.Split(str, "_") {
		if part == "" {
			continue
		}
		if isCommonInitialism(part) {
			b.WriteString(strings.ToUpper(part))
			continue
		}
		// Upper-case the first rune, then copy the rest of the part
		// unchanged starting at the second rune's byte offset.
		for i, r := range part {
			if i == 0 {
				b.WriteString(strings.ToUpper(string(r)))
				continue
			}
			b.WriteString(part[i:])
			break
		}
	}

	return b.String()
}

// isCommonInitialism reports whether part, compared case-insensitively, is
// a common Go initialism that should be written in all caps.
func isCommonInitialism(part string) bool {
	switch strings.ToLower(part) {
	case "id", "uuid", "url", "uri", "ip", "api", "http", "https", "json", "xml", "sql", "html":
		return true
	}
	return false
}
|
||
|
||
// generateAllTablesFromDB 从数据库读取所有表并生成代码
|
||
func generateAllTablesFromDB(outputDir string) {
|
||
fmt.Printf("[Magic-ORM Code Generator v%s]\n", version)
|
||
fmt.Println()
|
||
|
||
// 1. 加载配置文件
|
||
fmt.Println("[Step 1] Loading configuration file...")
|
||
cfg, err := loadDatabaseConfig()
|
||
if err != nil {
|
||
fmt.Fprintf(os.Stderr, "[Error] Failed to load config: %v\n", err)
|
||
os.Exit(1)
|
||
}
|
||
fmt.Printf("[Info] Database type: %s\n", cfg.Type)
|
||
fmt.Printf("[Info] Database name: %s\n", cfg.Name)
|
||
fmt.Println()
|
||
|
||
// 2. 连接数据库并获取所有表
|
||
fmt.Println("[Step 2] Connecting to database and fetching table structure...")
|
||
intro, err := introspector.NewIntrospector(cfg)
|
||
if err != nil {
|
||
fmt.Fprintf(os.Stderr, "[Error] Failed to connect to database: %v\n", err)
|
||
os.Exit(1)
|
||
}
|
||
defer intro.Close()
|
||
|
||
tableNames, err := intro.GetTableNames()
|
||
if err != nil {
|
||
fmt.Fprintf(os.Stderr, "[Error] Failed to get table names: %v\n", err)
|
||
os.Exit(1)
|
||
}
|
||
|
||
fmt.Printf("[Info] Found %d tables\n", len(tableNames))
|
||
fmt.Println()
|
||
|
||
// 3. 创建代码生成器
|
||
cg := generator.NewCodeGenerator(outputDir)
|
||
|
||
// 4. 为每个表生成代码
|
||
for _, tableName := range tableNames {
|
||
fmt.Printf("[Generating] Table '%s'...\n", tableName)
|
||
|
||
// 获取表详细信息
|
||
tableInfo, err := intro.GetTableInfo(tableName)
|
||
if err != nil {
|
||
fmt.Fprintf(os.Stderr, "[Error] Failed to get table info: %v\n", err)
|
||
continue
|
||
}
|
||
|
||
// 转换为 generator.ColumnInfo
|
||
columns := make([]generator.ColumnInfo, len(tableInfo.Columns))
|
||
for i, col := range tableInfo.Columns {
|
||
columns[i] = generator.ColumnInfo{
|
||
ColumnName: col.ColumnName,
|
||
FieldName: col.FieldName,
|
||
FieldType: col.GoType,
|
||
JSONName: col.JSONName,
|
||
IsPrimary: col.IsPrimary,
|
||
IsNullable: col.IsNullable,
|
||
}
|
||
}
|
||
|
||
// 生成代码
|
||
err = cg.GenerateAll(tableName, columns)
|
||
if err != nil {
|
||
fmt.Fprintf(os.Stderr, "[Error] Generation failed: %v\n", err)
|
||
continue
|
||
}
|
||
|
||
fmt.Printf("[Success] Generated %s.go and %s_dao.go\n", tableName, tableName)
|
||
}
|
||
|
||
fmt.Println()
|
||
fmt.Println("[Complete] Code generation finished!")
|
||
fmt.Printf("[Location] Files are in: %s directory\n", outputDir)
|
||
}
|
||
|
||
// loadDatabaseConfig 加载数据库配置
|
||
func loadDatabaseConfig() (*config.DatabaseConfig, error) {
|
||
// 自动查找配置文件
|
||
configPath, err := config.FindConfigFile("")
|
||
if err != nil {
|
||
return nil, fmt.Errorf("查找配置文件失败:%w", err)
|
||
}
|
||
|
||
fmt.Printf("[Info] Using config file: %s\n", configPath)
|
||
|
||
// 从文件加载配置
|
||
cfg, err := config.LoadFromFile(configPath)
|
||
if err != nil {
|
||
return nil, fmt.Errorf("加载配置文件失败:%w", err)
|
||
}
|
||
|
||
return &cfg.Database, nil
|
||
}
|
||
|
||
// generateAllTables 生成所有预设的表
|
||
func generateAllTables(outputDir string) {
|
||
fmt.Printf("🚀 Magic-ORM 代码生成器 v%s\n", version)
|
||
fmt.Printf("📁 输出目录:%s\n", outputDir)
|
||
fmt.Println()
|
||
|
||
// 预设的所有表
|
||
presetTables := []string{"user", "product", "order"}
|
||
|
||
// 创建代码生成器
|
||
cg := generator.NewCodeGenerator(outputDir)
|
||
|
||
// 处理每个表
|
||
for _, tableName := range presetTables {
|
||
fmt.Printf("📝 生成表 '%s' 的代码...\n", tableName)
|
||
|
||
// 使用默认列定义
|
||
columns := getDefaultColumns(tableName)
|
||
|
||
// 生成代码
|
||
err := cg.GenerateAll(tableName, columns)
|
||
if err != nil {
|
||
fmt.Fprintf(os.Stderr, "❌ 生成失败:%v\n", err)
|
||
continue
|
||
}
|
||
|
||
fmt.Printf("✅ 成功生成 %s.go 和 %s_dao.go\n", tableName, tableName)
|
||
}
|
||
|
||
fmt.Println()
|
||
fmt.Println("✨ 代码生成完成!")
|
||
fmt.Printf("📂 生成的文件在:%s 目录下\n", outputDir)
|
||
}
|