407 lines
9.6 KiB
Go
407 lines
9.6 KiB
Go
package introspector
|
||
|
||
import (
|
||
"database/sql"
|
||
"fmt"
|
||
"strings"
|
||
|
||
"git.magicany.cc/black1552/gin-base/db/config"
|
||
_ "github.com/go-sql-driver/mysql"
|
||
)
|
||
|
||
// TableInfo 表信息
|
||
type TableInfo struct {
|
||
TableName string // 表名
|
||
Columns []ColumnInfo // 列信息
|
||
}
|
||
|
||
// ColumnInfo describes a single table column, combining the metadata read
// from the database with the names and types derived for generated Go code.
type ColumnInfo struct {
	// ColumnName is the column name as stored in the database.
	ColumnName string
	// DataType is the database data type of the column.
	DataType string
	// IsNullable reports whether the column accepts NULL.
	IsNullable bool
	// ColumnKey is the key kind of the column (PRI, MUL, ...).
	ColumnKey string
	// ColumnDefault is the column's default value.
	ColumnDefault string
	// Extra carries additional column information (auto_increment, ...).
	Extra string
	// GoType is the Go type the column is mapped to.
	GoType string
	// FieldName is the camel-cased Go field name.
	FieldName string
	// JSONName is the name used in the JSON tag.
	JSONName string
	// IsPrimary reports whether the column is treated as a primary key.
	IsPrimary bool
}
|
||
|
||
// Introspector 数据库结构检查器
|
||
type Introspector struct {
|
||
db *sql.DB
|
||
config *config.DatabaseConfig
|
||
}
|
||
|
||
// NewIntrospector 创建数据库结构检查器
|
||
func NewIntrospector(cfg *config.DatabaseConfig) (*Introspector, error) {
|
||
dsn := cfg.BuildDSN()
|
||
db, err := sql.Open(cfg.GetDriverName(), dsn)
|
||
if err != nil {
|
||
return nil, fmt.Errorf("打开数据库连接失败:%w", err)
|
||
}
|
||
|
||
// 测试连接
|
||
if err := db.Ping(); err != nil {
|
||
return nil, fmt.Errorf("连接数据库失败:%w", err)
|
||
}
|
||
|
||
return &Introspector{
|
||
db: db,
|
||
config: cfg,
|
||
}, nil
|
||
}
|
||
|
||
// Close 关闭数据库连接
|
||
func (i *Introspector) Close() error {
|
||
return i.db.Close()
|
||
}
|
||
|
||
// GetTableNames 获取所有表名
|
||
func (i *Introspector) GetTableNames() ([]string, error) {
|
||
switch i.config.Type {
|
||
case "mysql":
|
||
return i.getMySQLTableNames()
|
||
case "postgres":
|
||
return i.getPostgresTableNames()
|
||
case "sqlite":
|
||
return i.getSQLiteTableNames()
|
||
default:
|
||
return nil, fmt.Errorf("不支持的数据库类型:%s", i.config.Type)
|
||
}
|
||
}
|
||
|
||
// getMySQLTableNames 获取 MySQL 所有表名
|
||
func (i *Introspector) getMySQLTableNames() ([]string, error) {
|
||
query := `
|
||
SELECT TABLE_NAME
|
||
FROM INFORMATION_SCHEMA.TABLES
|
||
WHERE TABLE_SCHEMA = ?
|
||
ORDER BY TABLE_NAME
|
||
`
|
||
|
||
rows, err := i.db.Query(query, i.config.Name)
|
||
if err != nil {
|
||
return nil, fmt.Errorf("查询表名失败:%w", err)
|
||
}
|
||
defer rows.Close()
|
||
|
||
tableNames := []string{}
|
||
for rows.Next() {
|
||
var tableName string
|
||
if err := rows.Scan(&tableName); err != nil {
|
||
return nil, fmt.Errorf("扫描表名失败:%w", err)
|
||
}
|
||
tableNames = append(tableNames, tableName)
|
||
}
|
||
|
||
return tableNames, nil
|
||
}
|
||
|
||
// getPostgresTableNames 获取 PostgreSQL 所有表名
|
||
func (i *Introspector) getPostgresTableNames() ([]string, error) {
|
||
query := `
|
||
SELECT table_name
|
||
FROM information_schema.tables
|
||
WHERE table_schema = 'public'
|
||
ORDER BY table_name
|
||
`
|
||
|
||
rows, err := i.db.Query(query)
|
||
if err != nil {
|
||
return nil, fmt.Errorf("查询表名失败:%w", err)
|
||
}
|
||
defer rows.Close()
|
||
|
||
tableNames := []string{}
|
||
for rows.Next() {
|
||
var tableName string
|
||
if err := rows.Scan(&tableName); err != nil {
|
||
return nil, fmt.Errorf("扫描表名失败:%w", err)
|
||
}
|
||
tableNames = append(tableNames, tableName)
|
||
}
|
||
|
||
return tableNames, nil
|
||
}
|
||
|
||
// getSQLiteTableNames 获取 SQLite 所有表名
|
||
func (i *Introspector) getSQLiteTableNames() ([]string, error) {
|
||
query := `SELECT name FROM sqlite_master WHERE type='table' ORDER BY name`
|
||
|
||
rows, err := i.db.Query(query)
|
||
if err != nil {
|
||
return nil, fmt.Errorf("查询表名失败:%w", err)
|
||
}
|
||
defer rows.Close()
|
||
|
||
tableNames := []string{}
|
||
for rows.Next() {
|
||
var tableName string
|
||
if err := rows.Scan(&tableName); err != nil {
|
||
return nil, fmt.Errorf("扫描表名失败:%w", err)
|
||
}
|
||
// 跳过 SQLite 系统表
|
||
if tableName != "sqlite_sequence" {
|
||
tableNames = append(tableNames, tableName)
|
||
}
|
||
}
|
||
|
||
return tableNames, nil
|
||
}
|
||
|
||
// GetTableInfo 获取表的详细信息
|
||
func (i *Introspector) GetTableInfo(tableName string) (*TableInfo, error) {
|
||
switch i.config.Type {
|
||
case "mysql":
|
||
return i.getMySQLTableInfo(tableName)
|
||
case "postgres":
|
||
return i.getPostgresTableInfo(tableName)
|
||
case "sqlite":
|
||
return i.getSQLiteTableInfo(tableName)
|
||
default:
|
||
return nil, fmt.Errorf("不支持的数据库类型:%s", i.config.Type)
|
||
}
|
||
}
|
||
|
||
// getMySQLTableInfo 获取 MySQL 表信息
|
||
func (i *Introspector) getMySQLTableInfo(tableName string) (*TableInfo, error) {
|
||
query := `
|
||
SELECT COLUMN_NAME, DATA_TYPE, IS_NULLABLE, COLUMN_KEY, COLUMN_DEFAULT, EXTRA
|
||
FROM INFORMATION_SCHEMA.COLUMNS
|
||
WHERE TABLE_SCHEMA = ? AND TABLE_NAME = ?
|
||
ORDER BY ORDINAL_POSITION
|
||
`
|
||
|
||
rows, err := i.db.Query(query, i.config.Name, tableName)
|
||
if err != nil {
|
||
return nil, fmt.Errorf("查询列信息失败:%w", err)
|
||
}
|
||
defer rows.Close()
|
||
|
||
columns := []ColumnInfo{}
|
||
for rows.Next() {
|
||
var col ColumnInfo
|
||
var isNullableStr string // MySQL 返回的是字符串 "YES"/"NO"
|
||
var columnDefault sql.NullString
|
||
|
||
err := rows.Scan(&col.ColumnName, &col.DataType, &isNullableStr, &col.ColumnKey, &columnDefault, &col.Extra)
|
||
if err != nil {
|
||
return nil, fmt.Errorf("扫描列信息失败:%w", err)
|
||
}
|
||
|
||
// 将字符串转换为布尔值
|
||
col.IsNullable = isNullableStr == "YES"
|
||
|
||
// 转换为 Go 类型
|
||
col.GoType = mapMySQLTypeToGoType(col.DataType)
|
||
col.FieldName = toCamelCase(col.ColumnName)
|
||
col.JSONName = col.ColumnName
|
||
col.IsPrimary = col.ColumnKey == "PRI"
|
||
|
||
columns = append(columns, col)
|
||
}
|
||
|
||
return &TableInfo{
|
||
TableName: tableName,
|
||
Columns: columns,
|
||
}, nil
|
||
}
|
||
|
||
// getPostgresTableInfo 获取 PostgreSQL 表信息
|
||
func (i *Introspector) getPostgresTableInfo(tableName string) (*TableInfo, error) {
|
||
query := `
|
||
SELECT column_name, data_type, is_nullable, column_default
|
||
FROM information_schema.columns
|
||
WHERE table_name = $1
|
||
ORDER BY ordinal_position
|
||
`
|
||
|
||
rows, err := i.db.Query(query, tableName)
|
||
if err != nil {
|
||
return nil, fmt.Errorf("查询列信息失败:%w", err)
|
||
}
|
||
defer rows.Close()
|
||
|
||
columns := []ColumnInfo{}
|
||
for rows.Next() {
|
||
var col ColumnInfo
|
||
var columnDefault sql.NullString
|
||
err := rows.Scan(&col.ColumnName, &col.DataType, &col.IsNullable, &columnDefault)
|
||
if err != nil {
|
||
return nil, fmt.Errorf("扫描列信息失败:%w", err)
|
||
}
|
||
|
||
// 转换为 Go 类型
|
||
col.GoType = mapPostgresTypeToGoType(col.DataType)
|
||
col.FieldName = toCamelCase(col.ColumnName)
|
||
col.JSONName = col.ColumnName
|
||
col.IsPrimary = col.ColumnName == "id"
|
||
|
||
columns = append(columns, col)
|
||
}
|
||
|
||
return &TableInfo{
|
||
TableName: tableName,
|
||
Columns: columns,
|
||
}, nil
|
||
}
|
||
|
||
// getSQLiteTableInfo 获取 SQLite 表信息
|
||
func (i *Introspector) getSQLiteTableInfo(tableName string) (*TableInfo, error) {
|
||
query := fmt.Sprintf("PRAGMA table_info(%s)", tableName)
|
||
|
||
rows, err := i.db.Query(query)
|
||
if err != nil {
|
||
return nil, fmt.Errorf("查询列信息失败:%w", err)
|
||
}
|
||
defer rows.Close()
|
||
|
||
columns := []ColumnInfo{}
|
||
for rows.Next() {
|
||
var col ColumnInfo
|
||
var notNull int
|
||
var pk int
|
||
var defaultValue sql.NullString
|
||
|
||
err := rows.Scan(&col.ColumnName, &col.DataType, ¬Null, &defaultValue, &pk, &col.Extra)
|
||
if err != nil {
|
||
return nil, fmt.Errorf("扫描列信息失败:%w", err)
|
||
}
|
||
|
||
col.IsNullable = notNull == 0
|
||
col.IsPrimary = pk > 0
|
||
|
||
// 转换为 Go 类型
|
||
col.GoType = mapSQLiteTypeToGoType(col.DataType)
|
||
col.FieldName = toCamelCase(col.ColumnName)
|
||
col.JSONName = col.ColumnName
|
||
|
||
columns = append(columns, col)
|
||
}
|
||
|
||
return &TableInfo{
|
||
TableName: tableName,
|
||
Columns: columns,
|
||
}, nil
|
||
}
|
||
|
||
// mapMySQLTypeToGoType maps a MySQL data type name (as found in the DATA_TYPE
// column of INFORMATION_SCHEMA.COLUMNS) to the Go type name used for it.
//
// Matching ignores letter case and surrounding whitespace. Integer types map
// to "int64", floating-point types to "float64", date/datetime/timestamp to
// "time.Time", blob and binary/varbinary types to "[]byte", boolean to "bool".
// Everything else, including decimal, time, the character/text types, json and
// any unrecognized name, maps to "string".
func mapMySQLTypeToGoType(dbType string) string {
	switch strings.ToLower(strings.TrimSpace(dbType)) {
	case "tinyint", "smallint", "mediumint", "int", "bigint":
		return "int64"
	case "float", "double":
		return "float64"
	case "date", "datetime", "timestamp":
		return "time.Time"
	case "blob", "tinyblob", "mediumblob", "longblob", "binary", "varbinary":
		// binary/varbinary hold raw bytes just like the blob family.
		return "[]byte"
	case "boolean":
		return "bool"
	case "decimal", "time", "char", "varchar",
		"text", "tinytext", "mediumtext", "longtext", "json":
		return "string"
	default:
		// Unknown types fall back to string.
		return "string"
	}
}
|
||
|
||
// mapPostgresTypeToGoType maps a PostgreSQL data type name to the Go type name
// used for it.
//
// Both the short aliases (varchar, timestamptz, double, ...) and the long
// spellings reported by information_schema.columns.data_type
// ("character varying", "timestamp with time zone", "double precision", ...)
// are recognized. Matching ignores letter case and surrounding whitespace.
// Unrecognized names map to "string".
func mapPostgresTypeToGoType(dbType string) string {
	switch strings.ToLower(strings.TrimSpace(dbType)) {
	case "smallint", "integer", "bigint":
		return "int64"
	case "real", "double", "double precision":
		return "float64"
	case "date",
		"timestamp", "timestamp without time zone",
		"timestamptz", "timestamp with time zone":
		return "time.Time"
	case "bytea":
		return "[]byte"
	case "boolean":
		return "bool"
	case "numeric", "decimal",
		"time", "time without time zone", "time with time zone",
		"char", "character", "varchar", "character varying", "text",
		"json", "jsonb":
		return "string"
	default:
		// Unknown types fall back to string.
		return "string"
	}
}
|
||
|
||
// mapSQLiteTypeToGoType maps a declared SQLite column type to the Go type name
// used for it.
//
// SQLite accepts arbitrary declared type names, so the name is classified with
// SQLite's type-affinity substring rules, ignoring letter case and surrounding
// whitespace:
//
//   - contains "INT"                     -> "int64"
//   - contains "CHAR", "CLOB" or "TEXT"  -> "string"
//   - contains "BLOB"                    -> "[]byte"
//   - contains "REAL", "FLOA" or "DOUB"  -> "float64"
//   - anything else (NUMERIC, empty, ...) -> "string"
//
// The rules are checked in that order, so the first match wins.
func mapSQLiteTypeToGoType(dbType string) string {
	declared := strings.ToUpper(strings.TrimSpace(dbType))

	switch {
	case strings.Contains(declared, "INT"):
		return "int64"
	case strings.Contains(declared, "CHAR"),
		strings.Contains(declared, "CLOB"),
		strings.Contains(declared, "TEXT"):
		return "string"
	case strings.Contains(declared, "BLOB"):
		return "[]byte"
	case strings.Contains(declared, "REAL"),
		strings.Contains(declared, "FLOA"),
		strings.Contains(declared, "DOUB"):
		return "float64"
	default:
		// NUMERIC affinity and undeclared types are represented as string.
		return "string"
	}
}
|
||
|
||
// toCamelCase converts an underscore-separated name to upper camel case:
// the name is split on '_', empty segments are dropped, the first character of
// every remaining segment is upper-cased and the segments are concatenated.
// The rest of each segment is left as is, so "user_id" becomes "UserId".
//
// The first character is handled as a full rune rather than a single byte, so
// a segment starting with a multi-byte UTF-8 character is not corrupted.
func toCamelCase(str string) string {
	var b strings.Builder
	b.Grow(len(str))

	for _, part := range strings.Split(str, "_") {
		if part == "" {
			// Leading, trailing or doubled underscores produce empty segments.
			continue
		}
		runes := []rune(part)
		b.WriteString(strings.ToUpper(string(runes[:1])))
		b.WriteString(string(runes[1:]))
	}

	return b.String()
}
|
||
|
||
// splitByUnderscore splits str on '_' and returns the non-empty segments in
// order. Leading, trailing and consecutive underscores never produce empty
// segments; an empty or underscore-only input yields an empty slice.
func splitByUnderscore(str string) []string {
	return strings.FieldsFunc(str, func(r rune) bool {
		return r == '_'
	})
}
|