gin-base/db/cmd/gendb/main.go

426 lines
12 KiB
Go
Raw Blame History

This file contains ambiguous Unicode characters!

This file contains ambiguous Unicode characters that may be confused with others in your current locale. If your use case is intentional and legitimate, you can safely ignore this warning. Use the Escape button to highlight these characters.

package main
import (
"flag"
"fmt"
"os"
"strings"
"git.magicany.cc/black1552/gin-base/db/config"
"git.magicany.cc/black1552/gin-base/db/generator"
"git.magicany.cc/black1552/gin-base/db/introspector"
)
// init is currently a no-op placeholder.
//
// NOTE(review): an earlier comment claimed this sets the Windows console
// output code page to UTF-8 (65001) to avoid garbled Chinese output, but the
// body makes no such call. Doing that requires a Windows-only source file
// (guarded by a windows build constraint) that calls kernel32's
// SetConsoleOutputCP. Until then, consoles that are not UTF-8 may render the
// Chinese help text incorrectly.
func init() {}

// version is the generator version printed by -version/-v and in the
// progress banners.
const version = "1.0.0"
func main() {
// 定义命令行参数
versionFlag := flag.Bool("version", false, "显示版本号")
vFlag := flag.Bool("v", false, "显示版本号(简写)")
helpFlag := flag.Bool("help", false, "显示帮助信息")
hFlag := flag.Bool("h", false, "显示帮助信息(简写)")
outputDir := flag.String("o", "./model", "输出目录")
allFlag := flag.Bool("all", false, "生成所有预设的表user, product, order")
flag.Usage = func() {
fmt.Fprintf(os.Stderr, `Magic-ORM 代码生成器 - 快速生成 Model 和 DAO 代码
用法:
gendb [选项] <表名> [列定义...]
gendb [选项] -all
选项:
`)
flag.PrintDefaults()
fmt.Fprintf(os.Stderr, `
示例:
# 生成 user 表代码(自动推断常用列)
gendb user
# 指定输出目录
gendb -o ./models user product
# 自定义列定义
gendb user id:int64:primary username:string email:string created_at:time.Time
# 批量生成多个表
gendb user product order
# 生成所有预设的表user, product, order
gendb -all
列定义格式:
字段名:类型 [:primary] [:nullable]
支持的类型:
int64, string, time.Time, bool, float64, int
更多信息:
https://github.com/your-repo/magic-orm
`)
}
flag.Parse()
// 检查版本参数
if *versionFlag || *vFlag {
fmt.Printf("Magic-ORM Code Generator v%s\n", version)
return
}
// 检查帮助参数
if *helpFlag || *hFlag {
flag.Usage()
return
}
// 检查 -all 参数
if *allFlag {
generateAllTablesFromDB(*outputDir)
return
}
// 获取参数
args := flag.Args()
if len(args) == 0 {
fmt.Fprintln(os.Stderr, "错误:请指定至少一个表名")
fmt.Fprintln(os.Stderr, "使用 'gendb -h' 查看帮助")
fmt.Fprintln(os.Stderr, "或者使用 'gendb -all' 生成所有预设表")
os.Exit(1)
}
tableNames := args
// 创建代码生成器
cg := generator.NewCodeGenerator(*outputDir)
fmt.Printf("[Magic-ORM Code Generator v%s]\n", version)
fmt.Printf("[Output Directory: %s]\n", *outputDir)
fmt.Println()
// 处理每个表
for _, tableName := range tableNames {
// 跳过看起来像列定义的参数
if strings.Contains(tableName, ":") {
continue
}
fmt.Printf("[Generating table '%s'...]\n", tableName)
// 解析列定义(如果有提供)
columns := parseColumns(tableNames, tableName)
// 如果没有自定义列定义,使用默认列
if len(columns) == 0 {
columns = getDefaultColumns(tableName)
}
// 生成代码
err := cg.GenerateAll(tableName, columns)
if err != nil {
fmt.Fprintf(os.Stderr, "[Error] Generation failed: %v\n", err)
continue
}
fmt.Printf("[Success] Generated %s.go and %s_dao.go\n", tableName, tableName)
}
fmt.Println()
fmt.Println("[Complete] Code generation finished!")
fmt.Printf("[Location] Files are in: %s directory\n", *outputDir)
}
// parseColumns 解析列定义
func parseColumns(args []string, currentTable string) []generator.ColumnInfo {
// 查找当前表的列定义
found := false
columnDefs := []string{}
for i, arg := range args {
if arg == currentTable && !found {
found = true
// 收集后续的列定义
for j := i + 1; j < len(args); j++ {
if strings.Contains(args[j], ":") {
columnDefs = append(columnDefs, args[j])
} else {
break // 遇到下一个表名
}
}
break
}
}
if len(columnDefs) == 0 {
return nil
}
columns := []generator.ColumnInfo{}
for _, def := range columnDefs {
parts := strings.Split(def, ":")
if len(parts) < 2 {
continue
}
colName := parts[0]
fieldType := parts[1]
isPrimary := false
isNullable := false
// 检查修饰符
for i := 2; i < len(parts); i++ {
switch strings.ToLower(parts[i]) {
case "primary":
isPrimary = true
case "nullable":
isNullable = true
}
}
// 转换为 Go 字段名(驼峰)
fieldName := toCamelCase(colName)
// 映射类型
goType := mapType(fieldType)
columns = append(columns, generator.ColumnInfo{
ColumnName: colName,
FieldName: fieldName,
FieldType: goType,
JSONName: colName,
IsPrimary: isPrimary,
IsNullable: isNullable,
})
}
return columns
}
// getDefaultColumns returns the fallback column set used when a table is
// requested without inline column definitions.
//
// Every set starts with an int64 primary key "id" (field ID). The remaining
// columns are chosen by exact table name:
//   - "user"/"users", "product"/"products" and "order"/"orders" get preset
//     business columns;
//   - any other name gets a generic name/status/created_at/updated_at layout.
//
// JSON names always equal the column names.
func getDefaultColumns(tableName string) []generator.ColumnInfo {
	// col builds a non-primary column whose JSON name mirrors its column name.
	col := func(column, field, goType string, nullable bool) generator.ColumnInfo {
		return generator.ColumnInfo{
			ColumnName: column,
			FieldName:  field,
			FieldType:  goType,
			JSONName:   column,
			IsNullable: nullable,
		}
	}

	// Columns shared by several presets.
	var (
		name      = col("name", "Name", "string", false)
		status    = col("status", "Status", "int", false)
		createdAt = col("created_at", "CreatedAt", "time.Time", false)
		updatedAt = col("updated_at", "UpdatedAt", "time.Time", false)
	)

	var extra []generator.ColumnInfo
	switch tableName {
	case "user", "users":
		extra = []generator.ColumnInfo{
			col("username", "Username", "string", false),
			col("email", "Email", "string", true),
			col("password", "Password", "string", false),
			status,
			createdAt,
			updatedAt,
		}
	case "product", "products":
		extra = []generator.ColumnInfo{
			name,
			col("price", "Price", "float64", false),
			col("stock", "Stock", "int", false),
			col("description", "Description", "string", true),
			createdAt,
		}
	case "order", "orders":
		extra = []generator.ColumnInfo{
			col("order_no", "OrderNo", "string", false),
			col("user_id", "UserID", "int64", false),
			col("total_amount", "TotalAmount", "float64", false),
			status,
			createdAt,
		}
	default:
		extra = []generator.ColumnInfo{name, status, createdAt, updatedAt}
	}

	primary := generator.ColumnInfo{
		ColumnName: "id",
		FieldName:  "ID",
		FieldType:  "int64",
		JSONName:   "id",
		IsPrimary:  true,
	}
	return append([]generator.ColumnInfo{primary}, extra...)
}
// mapType normalizes a user-supplied column type name to the Go type used in
// generated code. Matching is case-insensitive.
//
//   - int, int64, integer, bigint   -> int64
//   - time.Time, time, datetime     -> time.Time
//   - bool, boolean                 -> bool
//   - float, float64, double        -> float64
//   - string, text, varchar, decimal, and anything unrecognized -> string
//
// Keys are compared in lower case because the input is lowercased first.
// That is why "time.Time" is matched as "time.time". Before this change a
// mixed-case map key made that entry unreachable, and the advertised "int64"
// type was missing, so both fell through to string.
func mapType(typeStr string) string {
	switch strings.ToLower(typeStr) {
	case "int", "int64", "integer", "bigint":
		return "int64"
	case "time.time", "time", "datetime":
		return "time.Time"
	case "bool", "boolean":
		return "bool"
	case "float", "float64", "double":
		return "float64"
	default:
		// "string", "text", "varchar" and unknown types map to string.
		// "decimal" also maps to string, presumably so exact precision is
		// not lost through float64.
		return "string"
	}
}
// toCamelCase 转换为驼峰命名
func toCamelCase(str string) string {
parts := strings.Split(str, "_")
result := ""
for _, part := range parts {
if len(part) > 0 {
result += strings.ToUpper(string(part[0])) + part[1:]
}
}
return result
}
// generateAllTablesFromDB generates Model and DAO code for every table in the
// database described by the project configuration file, writing files under
// outputDir. Progress goes to stdout and errors to stderr.
//
// The process exits with status 1 if any of these fail:
//   - loading the configuration;
//   - opening the database;
//   - listing its tables;
//   - generating code for at least one table.
func generateAllTablesFromDB(outputDir string) {
	if code := runGenerateAllTablesFromDB(outputDir); code != 0 {
		os.Exit(code)
	}
}

// runGenerateAllTablesFromDB does the work of generateAllTablesFromDB but
// returns a process exit code instead of calling os.Exit. os.Exit skips
// deferred calls, so returning ensures the deferred intro.Close() runs.
func runGenerateAllTablesFromDB(outputDir string) int {
	fmt.Printf("[Magic-ORM Code Generator v%s]\n", version)
	fmt.Println()

	// 1. Load the configuration file.
	fmt.Println("[Step 1] Loading configuration file...")
	cfg, err := loadDatabaseConfig()
	if err != nil {
		fmt.Fprintf(os.Stderr, "[Error] Failed to load config: %v\n", err)
		return 1
	}
	fmt.Printf("[Info] Database type: %s\n", cfg.Type)
	fmt.Printf("[Info] Database name: %s\n", cfg.Name)
	fmt.Println()

	// 2. Connect to the database and list its tables.
	fmt.Println("[Step 2] Connecting to database and fetching table structure...")
	intro, err := introspector.NewIntrospector(cfg)
	if err != nil {
		fmt.Fprintf(os.Stderr, "[Error] Failed to connect to database: %v\n", err)
		return 1
	}
	defer intro.Close()

	tableNames, err := intro.GetTableNames()
	if err != nil {
		fmt.Fprintf(os.Stderr, "[Error] Failed to get table names: %v\n", err)
		return 1
	}
	fmt.Printf("[Info] Found %d tables\n", len(tableNames))
	fmt.Println()

	// 3. Generate code per table; a failing table is reported and skipped.
	cg := generator.NewCodeGenerator(outputDir)
	failed := 0
	for _, tableName := range tableNames {
		fmt.Printf("[Generating] Table '%s'...\n", tableName)

		tableInfo, err := intro.GetTableInfo(tableName)
		if err != nil {
			fmt.Fprintf(os.Stderr, "[Error] Failed to get table info: %v\n", err)
			failed++
			continue
		}

		// Map introspected columns onto the generator's column description.
		columns := make([]generator.ColumnInfo, len(tableInfo.Columns))
		for i, col := range tableInfo.Columns {
			columns[i] = generator.ColumnInfo{
				ColumnName: col.ColumnName,
				FieldName:  col.FieldName,
				FieldType:  col.GoType,
				JSONName:   col.JSONName,
				IsPrimary:  col.IsPrimary,
				IsNullable: col.IsNullable,
			}
		}

		if err := cg.GenerateAll(tableName, columns); err != nil {
			fmt.Fprintf(os.Stderr, "[Error] Generation failed: %v\n", err)
			failed++
			continue
		}
		fmt.Printf("[Success] Generated %s.go and %s_dao.go\n", tableName, tableName)
	}

	fmt.Println()
	fmt.Println("[Complete] Code generation finished!")
	fmt.Printf("[Location] Files are in: %s directory\n", outputDir)
	// Surface partial failure through the exit status.
	if failed > 0 {
		fmt.Fprintf(os.Stderr, "[Error] %d of %d tables failed to generate\n", failed, len(tableNames))
		return 1
	}
	return 0
}
// loadDatabaseConfig locates the project configuration file with
// config.FindConfigFile (passing an empty path hint), prints which file is
// used, loads it with config.LoadFromFile, and returns a pointer to its
// Database section.
//
// Lookup and load errors are returned wrapped with %w, so callers can still
// inspect the underlying cause.
func loadDatabaseConfig() (*config.DatabaseConfig, error) {
	path, findErr := config.FindConfigFile("")
	if findErr != nil {
		return nil, fmt.Errorf("查找配置文件失败:%w", findErr)
	}
	fmt.Printf("[Info] Using config file: %s\n", path)

	loaded, loadErr := config.LoadFromFile(path)
	if loadErr != nil {
		return nil, fmt.Errorf("加载配置文件失败:%w", loadErr)
	}
	return &loaded.Database, nil
}
// generateAllTables generates Model and DAO code for the built-in preset
// tables "user", "product" and "order", using getDefaultColumns for each.
// Files are written under outputDir. A table whose generation fails is
// reported on stderr and skipped.
//
// NOTE(review): nothing in this file calls this function; the -all flag is
// wired to generateAllTablesFromDB. Consider removing it, or exposing it
// behind its own flag, if it is not used elsewhere in the package.
func generateAllTables(outputDir string) {
	fmt.Printf("🚀 Magic-ORM 代码生成器 v%s\n", version)
	fmt.Printf("📁 输出目录:%s\n", outputDir)
	fmt.Println()

	gen := generator.NewCodeGenerator(outputDir)
	for _, name := range [...]string{"user", "product", "order"} {
		fmt.Printf("📝 生成表 '%s' 的代码...\n", name)
		if err := gen.GenerateAll(name, getDefaultColumns(name)); err != nil {
			fmt.Fprintf(os.Stderr, "❌ 生成失败:%v\n", err)
			continue
		}
		fmt.Printf("✅ 成功生成 %s.go 和 %s_dao.go\n", name, name)
	}

	fmt.Println()
	fmt.Println("✨ 代码生成完成!")
	fmt.Printf("📂 生成的文件在:%s 目录下\n", outputDir)
}