diff --git a/config/fun.go b/config/fun.go index 856abf5..7ce8604 100644 --- a/config/fun.go +++ b/config/fun.go @@ -7,6 +7,7 @@ import ( "git.magicany.cc/black1552/gin-base/log" "git.magicany.cc/black1552/gin-base/utils" + "github.com/fsnotify/fsnotify" "github.com/gogf/gf/v2/container/gvar" "github.com/gogf/gf/v2/os/gfile" "github.com/spf13/viper" @@ -35,22 +36,29 @@ func init() { return } log.Info("配置文件是否为空", utils.EmptyFile(configPath)) - setDefault() + SetDefault() err = viper.WriteConfig() if err != nil { log.Error("保存配置文件失败: ", err) return } + } else { + err = viper.ReadInConfig() + if err != nil { + log.Error("读取配置文件失败: ", err) + return + } } + viper.OnConfigChange(func(in fsnotify.Event) { + log.Info("配置文件已修改") + }) } -func setDefault() { +func SetDefault() { viper.Set("SERVER.addr", "127.0.0.1:8080") - viper.Set("DATABASE.host", "127.0.0.1") - viper.Set("DATABASE.port", 3306) - viper.Set("DATABASE.username", "root") - viper.Set("DATABASE.password", "") - viper.Set("DATABASE.name", "") + viper.Set("SERVER.mode", "release") + viper.Set("DATABASE.type", "mysql") + viper.Set("DATABASE.dns", "user:pass@tcp(127.0.0.1:3306)/dbname?charset=utf8mb4&parseTime=True&loc=Local") viper.Set("JWT.secret", "SET-YOUR-SECRET") viper.Set("JWT.expire", 86400) } @@ -88,18 +96,16 @@ func SetConfigMap(value map[string]any) error { } func GetConfigValue(key string, def ...any) *gvar.Var { - va := gvar.New(viper.Get(key)) - if va.IsEmpty() && len(def) > 0 { + value := gvar.New(viper.Get(key)) + if value.IsEmpty() && len(def) > 0 { return gvar.New(def[0]) } - return va + return value } -func Unmarshal(s any) (any, error) { - err := viper.Unmarshal(s) - if err != nil { - return nil, err - } - return s, nil +func Unmarshal[T any]() (*T, error) { + var s T + err := viper.Unmarshal(&s) + return &s, err } func GetAllConfig() map[string]any { diff --git a/config/structs.go b/config/structs.go index be8ffdf..8e61de0 100644 --- a/config/structs.go +++ b/config/structs.go @@ -1,29 +1,21 @@ package config type BaseConfig struct { - Server ServerConfig `toml:"SERVER"` - Database DataBaseConfig `toml:"DATABASE"` - Jwt JwtConfig `toml:"JWT"` - Logger Logger `toml:"LOGGER"` + Server ServerConfig `mapstructure:"SERVER"` + Database DataBaseConfig `mapstructure:"DATABASE"` + Jwt JwtConfig `mapstructure:"JWT"` } type ServerConfig struct { - Addr string `toml:"addr"` + Addr string `mapstructure:"addr"` + Mode string `mapstructure:"mode"` } type DataBaseConfig struct { - Host string `toml:"host"` - Port string `toml:"port"` - User string `toml:"user"` - Pwd string `toml:"pwd"` - Name string `toml:"name"` + Dns string `mapstructure:"dns"` + Type string `mapstructure:"type"` } type JwtConfig struct { - Secret string `toml:"secret"` - Expire int64 `toml:"expire"` -} - -type Logger struct { - Level string `toml:"level"` - Path string `toml:"path"` + Secret string `mapstructure:"secret"` + Expire int64 `mapstructure:"expire"` } diff --git a/database/database.go b/database/database.go index 542508b..195c22e 100644 --- a/database/database.go +++ b/database/database.go @@ -1,5 +1,63 @@ package database -func init() { +import ( + "database/sql" + "git.magicany.cc/black1552/gin-base/config" + "git.magicany.cc/black1552/gin-base/log" + "github.com/gogf/gf/v2/frame/g" + "gorm.io/driver/mysql" + "gorm.io/driver/sqlite" + "gorm.io/gorm" +) + +var ( + Type gorm.Dialector + Db *gorm.DB + err error + sqlDb *sql.DB + dns = config.GetConfigValue("database.dns", "") +) + +func init() { + if g.IsEmpty(dns) { + log.Error("gormDns未配置", "请检查配置文件") + return + } + switch config.GetConfigValue("database.type", "mysql").String() { + case "mysql": + log.Info("使用mysql数据库") + mysqlInit() + case "sqlite": + log.Info("使用sqlite数据库") + sqliteInit() + } + Db, err = gorm.Open(Type, &gorm.Config{}) + if err != nil { + log.Error("数据库连接失败: ", err) + return + } + sqlDb, err = Db.DB() + if err != nil { + log.Error("获取sqlDb失败", err) + return + } + if err = sqlDb.Ping(); err != nil { + log.Error("数据库未正常连接", err) + return + } +} + +func mysqlInit() { + Type = mysql.New(mysql.Config{ + DSN: config.GetConfigValue("database.dns", "").String(), + DefaultStringSize: 255, // string 类型字段的默认长度 + DisableDatetimePrecision: true, // 禁用 datetime 精度,MySQL 5.6 之前的数据库不支持 + DontSupportRenameIndex: true, // 重命名索引时采用删除并新建的方式,MySQL 5.7 之前的数据库和 MariaDB 不支持重命名索引 + SkipInitializeWithVersion: false, // 根据当前 MySQL 版本自动配置 + }) +} + +func sqliteInit() { + Type = sqlite.Open(config.GetConfigValue("database.dns", "").String()) } diff --git a/database/index.go b/database/index.go new file mode 100644 index 0000000..636bab8 --- /dev/null +++ b/database/index.go @@ -0,0 +1 @@ +package database diff --git a/database/migrate.go b/database/migrate.go new file mode 100644 index 0000000..43e6c85 --- /dev/null +++ b/database/migrate.go @@ -0,0 +1,44 @@ +package database + +import ( + "git.magicany.cc/black1552/gin-base/log" + "github.com/gogf/gf/v2/frame/g" +) + +func SetAutoMigrate(models ...interface{}) { + if g.IsNil(Db) { + log.Error("数据库连接失败") + return + } + Db = Db.Set("gorm:table_options", "ENGINE=InnoDB") + err := Db.AutoMigrate(models...) + if err != nil { + log.Error("数据库迁移失败", err) + } +} +func RenameColumn(dst interface{}, name, newName string) { + if Db.Migrator().HasColumn(dst, name) { + err := Db.Migrator().RenameColumn(dst, name, newName) + if err != nil { + log.Error("数据库修改字段失败", err) + return + } + } else { + log.Info("数据库字段不存在", name) + } +} + +// DropColumn +// 删除字段 +// 例:DropColumn(&User{}, "Sex") +func DropColumn(dst interface{}, name string) { + if Db.Migrator().HasColumn(dst, name) { + err := Db.Migrator().DropColumn(dst, name) + if err != nil { + log.Error("数据库删除字段失败", err) + return + } + } else { + log.Info("数据库字段不存在", name) + } +} diff --git a/go.mod b/go.mod index 57c65a7..62b3214 100644 --- a/go.mod +++ b/go.mod @@ -3,31 +3,51 @@ module git.magicany.cc/black1552/gin-base go 1.25 require ( + github.com/eclipse/paho.mqtt.golang v1.5.1 + github.com/fsnotify/fsnotify v1.9.0 github.com/gin-gonic/gin v1.7.7 github.com/gogf/gf/v2 v2.10.0 + github.com/gorilla/websocket v1.5.3 github.com/spf13/viper v1.21.0 - github.com/stretchr/testify v1.11.1 gopkg.in/natefinch/lumberjack.v2 v2.2.1 + gorm.io/driver/mysql v1.6.0 + gorm.io/driver/sqlite v1.6.0 + gorm.io/gorm v1.31.1 ) require ( - github.com/davecgh/go-spew v1.1.1 // indirect + filippo.io/edwards25519 v1.1.0 // indirect + github.com/BurntSushi/toml v1.5.0 // indirect + github.com/clbanning/mxj/v2 v2.7.0 // indirect github.com/emirpasic/gods/v2 v2.0.0-alpha // indirect - github.com/fsnotify/fsnotify v1.9.0 // indirect + github.com/fatih/color v1.18.0 // indirect github.com/gin-contrib/sse v0.1.0 // indirect + github.com/go-logr/logr v1.4.3 // indirect + github.com/go-logr/stdr v1.2.2 // indirect github.com/go-playground/locales v0.13.0 // indirect github.com/go-playground/universal-translator v0.17.0 // indirect github.com/go-playground/validator/v10 v10.4.1 // indirect + github.com/go-sql-driver/mysql v1.8.1 // indirect github.com/go-viper/mapstructure/v2 v2.4.0 // indirect github.com/golang/protobuf v1.3.3 // indirect + github.com/google/uuid v1.6.0 // indirect + github.com/grokify/html-strip-tags-go v0.1.0 // indirect + github.com/jinzhu/inflection v1.0.0 // indirect + github.com/jinzhu/now v1.1.5 // indirect github.com/json-iterator/go v1.1.9 // indirect github.com/leodido/go-urn v1.2.0 // indirect + github.com/magiconair/properties v1.8.10 // indirect + github.com/mattn/go-colorable v0.1.13 // indirect github.com/mattn/go-isatty v0.0.20 // indirect + github.com/mattn/go-runewidth v0.0.16 // indirect + github.com/mattn/go-sqlite3 v1.14.22 // indirect github.com/modern-go/concurrent v0.0.0-20180228061459-e0a39a4cb421 // indirect github.com/modern-go/reflect2 v0.0.0-20180701023420-4b7aa43c6742 // indirect + github.com/olekukonko/errors v1.1.0 // indirect + github.com/olekukonko/ll v0.0.9 // indirect + github.com/olekukonko/tablewriter v1.1.0 // indirect github.com/pelletier/go-toml/v2 v2.2.4 // indirect - github.com/pmezard/go-difflib v1.0.0 // indirect - github.com/rogpeppe/go-internal v1.13.1 // indirect + github.com/rivo/uniseg v0.2.0 // indirect github.com/sagikazarmark/locafero v0.11.0 // indirect github.com/sourcegraph/conc v0.3.1-0.20240121214520-5f936abd7ae8 // indirect github.com/spf13/afero v1.15.0 // indirect @@ -35,13 +55,17 @@ require ( github.com/spf13/pflag v1.0.10 // indirect github.com/subosito/gotenv v1.6.0 // indirect github.com/ugorji/go/codec v1.1.7 // indirect + go.opentelemetry.io/auto/sdk v1.1.0 // indirect go.opentelemetry.io/otel v1.38.0 // indirect + go.opentelemetry.io/otel/metric v1.38.0 // indirect + go.opentelemetry.io/otel/sdk v1.38.0 // indirect go.opentelemetry.io/otel/trace v1.38.0 // indirect go.yaml.in/yaml/v3 v3.0.4 // indirect - golang.org/x/crypto v0.0.0-20200622213623-75b288015ac9 // indirect + golang.org/x/crypto v0.42.0 // indirect + golang.org/x/net v0.44.0 // indirect + golang.org/x/sync v0.19.0 // indirect golang.org/x/sys v0.40.0 // indirect golang.org/x/text v0.33.0 // indirect - gopkg.in/check.v1 v1.0.0-20201130134442-10cb98267c6c // indirect gopkg.in/yaml.v2 v2.2.8 // indirect gopkg.in/yaml.v3 v3.0.1 // indirect ) diff --git a/go.sum b/go.sum index 177a705..655ed83 100644 --- a/go.sum +++ b/go.sum @@ -1,3 +1,5 @@ +filippo.io/edwards25519 v1.1.0 h1:FNf4tywRC1HmFuKW5xopWpigGjJKiJSV0Cqo0cJWDaA= +filippo.io/edwards25519 v1.1.0/go.mod h1:BxyFTGdWcka3PhytdK4V28tE5sGfRvvvRV7EaN4VDT4= github.com/BurntSushi/toml v1.5.0 h1:W5quZX/G/csjUnuI8SUYlsHs9M38FC7znL0lIO+DvMg= github.com/BurntSushi/toml v1.5.0/go.mod h1:ukJfTF/6rtPPRCnwkur4qwRxa8vTRFBF0uk2lLoLwho= github.com/clbanning/mxj/v2 v2.7.0 h1:WA/La7UGCanFe5NpHF0Q3DNtnCsVoxbPKuyBNHWRyME= @@ -5,6 +7,8 @@ github.com/clbanning/mxj/v2 v2.7.0/go.mod h1:hNiWqW14h+kc+MdF9C6/YoRfjEJoR3ou6tn github.com/davecgh/go-spew v1.1.0/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c= github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= +github.com/eclipse/paho.mqtt.golang v1.5.1 h1:/VSOv3oDLlpqR2Epjn1Q7b2bSTplJIeV2ISgCl2W7nE= +github.com/eclipse/paho.mqtt.golang v1.5.1/go.mod h1:1/yJCneuyOoCOzKSsOTUc0AJfpsItBGWvYpBLimhArU= github.com/emirpasic/gods/v2 v2.0.0-alpha h1:dwFlh8pBg1VMOXWGipNMRt8v96dKAIvBehtCt6OtunU= github.com/emirpasic/gods/v2 v2.0.0-alpha/go.mod h1:W0y4M2dtBB9U5z3YlghmpuUhiaZT2h6yoeE+C1sCp6A= github.com/fatih/color v1.18.0 h1:S8gINlzdQ840/4pfAwic/ZE0djQEH3wM94VfqLTZcOM= @@ -17,6 +21,7 @@ github.com/gin-contrib/sse v0.1.0 h1:Y/yl/+YNO8GZSjAhjMsSuLt29uWRFHdHYUb5lYOV9qE github.com/gin-contrib/sse v0.1.0/go.mod h1:RHrZQHXnP2xjPF+u1gW/2HnVO7nvIa9PG3Gm+fLHvGI= github.com/gin-gonic/gin v1.7.7 h1:3DoBmSbJbZAWqXJC3SLjAPfutPJJRN1U5pALB7EeTTs= github.com/gin-gonic/gin v1.7.7/go.mod h1:axIBovoeJpVj8S3BwE0uPMTeReE4+AfFtqpqaZ1qq1U= +github.com/go-logr/logr v1.2.2/go.mod h1:jdQByPbusPIv2/zmleS9BjJVeZ6kBagPoEUsqbVz/1A= github.com/go-logr/logr v1.4.3 h1:CjnDlHq8ikf6E492q6eKboGOC0T8CDaOvkHCIg8idEI= github.com/go-logr/logr v1.4.3/go.mod h1:9T104GzyrTigFIr8wt5mBrctHMim0Nb2HLGrmQ40KvY= github.com/go-logr/stdr v1.2.2 h1:hSWxHoqTgW2S2qGc0LTAI563KZ5YKYRhT3MFKZMbjag= @@ -29,6 +34,8 @@ github.com/go-playground/universal-translator v0.17.0 h1:icxd5fm+REJzpZx7ZfpaD87 github.com/go-playground/universal-translator v0.17.0/go.mod h1:UkSxE5sNxxRwHyU+Scu5vgOQjsIJAF8j9muTVoKLVtA= github.com/go-playground/validator/v10 v10.4.1 h1:pH2c5ADXtd66mxoE0Zm9SUhxE20r7aM3F26W0hOn+GE= github.com/go-playground/validator/v10 v10.4.1/go.mod h1:nlOn6nFhuKACm19sB/8EGNn9GlaMV7XkbRSipzJ0Ii4= +github.com/go-sql-driver/mysql v1.8.1 h1:LedoTUt/eveggdHS9qUFC1EFSa8bU2+1pZjSRpvNJ1Y= +github.com/go-sql-driver/mysql v1.8.1/go.mod h1:wEBSXgmK//2ZFJyE+qWnIsVGmvmEKlqwuVSjsCm7DZg= github.com/go-viper/mapstructure/v2 v2.4.0 h1:EBsztssimR/CONLSZZ04E8qAkxNYq4Qp9LvH92wZUgs= github.com/go-viper/mapstructure/v2 v2.4.0/go.mod h1:oJDH3BJKyqBA2TXFhDsKDGDTlndYOZ6rGS0BRZIxGhM= github.com/gogf/gf/v2 v2.10.0 h1:rzDROlyqGMe/eM6dCalSR8dZOuMIdLhmxKSH1DGhbFs= @@ -44,13 +51,14 @@ github.com/gorilla/websocket v1.5.3 h1:saDtZ6Pbx/0u+bgYQ3q96pZgCzfhKXGPqt7kZ72aN github.com/gorilla/websocket v1.5.3/go.mod h1:YR8l580nyteQvAITg2hZ9XVh4b55+EU/adAjf1fMHhE= github.com/grokify/html-strip-tags-go v0.1.0 h1:03UrQLjAny8xci+R+qjCce/MYnpNXCtgzltlQbOBae4= github.com/grokify/html-strip-tags-go v0.1.0/go.mod h1:ZdzgfHEzAfz9X6Xe5eBLVblWIxXfYSQ40S/VKrAOGpc= +github.com/jinzhu/inflection v1.0.0 h1:K317FqzuhWc8YvSVlFMCCUb36O/S9MCKRDI7QkRKD/E= +github.com/jinzhu/inflection v1.0.0/go.mod h1:h+uFLlag+Qp1Va5pdKtLDYj+kHp5pxUVkryuEj+Srlc= +github.com/jinzhu/now v1.1.5 h1:/o9tlHleP7gOFmsnYNz3RGnqzefHA47wQpKrrdTIwXQ= +github.com/jinzhu/now v1.1.5/go.mod h1:d3SSVoowX0Lcu0IBviAWJpolVfI5UJVZZ7cO71lE/z8= github.com/json-iterator/go v1.1.9 h1:9yzud/Ht36ygwatGx56VwCZtlI/2AD15T1X2sjSuGns= github.com/json-iterator/go v1.1.9/go.mod h1:KdQUCv79m/52Kvf8AW2vK1V8akMuk1QjK/uOdHXbAo4= -github.com/kr/pretty v0.2.1/go.mod h1:ipq/a2n7PKx3OHsz4KJII5eveXtPO4qwEXGdVfWzfnI= github.com/kr/pretty v0.3.1 h1:flRD4NNwYAUpkphVc1HcthR4KEIFJ65n8Mw5qdRn3LE= github.com/kr/pretty v0.3.1/go.mod h1:hoEshYVHaxMs3cyo3Yncou5ZscifuDolrwPKZanG3xk= -github.com/kr/pty v1.1.1/go.mod h1:pFQYn66WHrOpPYNljwOMqo10TkYh1fy3cYio2l3bCsQ= -github.com/kr/text v0.1.0/go.mod h1:4Jbv+DJW3UT/LiOwJeYQe1efqtUx/iVham/4vfdArNI= github.com/kr/text v0.2.0 h1:5Nx0Ya0ZqY2ygV366QzturHI13Jq95ApcVaJBhpS+AY= github.com/kr/text v0.2.0/go.mod h1:eLer722TekiGuMkidMxC/pM04lWEeraHUUmBw8l2grE= github.com/leodido/go-urn v1.2.0 h1:hpXL4XnriNwQ/ABnpepYM/1vCLWNDfUNts8dX3xTG6Y= @@ -60,10 +68,13 @@ github.com/magiconair/properties v1.8.10/go.mod h1:Dhd985XPs7jluiymwWYZ0G4Z61jb3 github.com/mattn/go-colorable v0.1.13 h1:fFA4WZxdEF4tXPZVKMLwD8oUnCTTo08duU7wxecdEvA= github.com/mattn/go-colorable v0.1.13/go.mod h1:7S9/ev0klgBDR4GtXTXX8a3vIGJpMovkB8vQcUbaXHg= github.com/mattn/go-isatty v0.0.12/go.mod h1:cbi8OIDigv2wuxKPP5vlRcQ1OAZbq2CE4Kysco4FUpU= +github.com/mattn/go-isatty v0.0.16/go.mod h1:kYGgaQfpe5nmfYZH+SKPsOc2e4SrIfOl2e/yFXSvRLM= github.com/mattn/go-isatty v0.0.20 h1:xfD0iDuEKnDkl03q4limB+vH+GxLEtL/jb4xVJSWWEY= github.com/mattn/go-isatty v0.0.20/go.mod h1:W+V8PltTTMOvKvAeJH7IuucS94S2C6jfK/D7dTCTo3Y= github.com/mattn/go-runewidth v0.0.16 h1:E5ScNMtiwvlvB5paMFdw9p4kSQzbXFikJ5SQO6TULQc= github.com/mattn/go-runewidth v0.0.16/go.mod h1:Jdepj2loyihRzMpdS35Xk/zdY8IAYHsh153qUoGf23w= +github.com/mattn/go-sqlite3 v1.14.22 h1:2gZY6PC6kBnID23Tichd1K+Z0oS6nE/XwU+Vz/5o4kU= +github.com/mattn/go-sqlite3 v1.14.22/go.mod h1:Uh1q+B4BYcTPb+yiD3kU8Ct7aC0hY9fxUwlHK0RXw+Y= github.com/modern-go/concurrent v0.0.0-20180228061459-e0a39a4cb421 h1:ZqeYNhU3OHLH3mGKHDcjJRFFRrJa6eAM5H+CtDdOsPc= github.com/modern-go/concurrent v0.0.0-20180228061459-e0a39a4cb421/go.mod h1:6dJC0mAP4ikYIbvyc7fijjWJddQyLn8Ig3JB5CqoB9Q= github.com/modern-go/reflect2 v0.0.0-20180701023420-4b7aa43c6742 h1:Esafd1046DLDQ0W1YjYsBW+p8U2u7vzgW2SQVmlNazg= @@ -112,19 +123,27 @@ go.opentelemetry.io/otel/metric v1.38.0 h1:Kl6lzIYGAh5M159u9NgiRkmoMKjvbsKtYRwgf go.opentelemetry.io/otel/metric v1.38.0/go.mod h1:kB5n/QoRM8YwmUahxvI3bO34eVtQf2i4utNVLr9gEmI= go.opentelemetry.io/otel/sdk v1.38.0 h1:l48sr5YbNf2hpCUj/FoGhW9yDkl+Ma+LrVl8qaM5b+E= go.opentelemetry.io/otel/sdk v1.38.0/go.mod h1:ghmNdGlVemJI3+ZB5iDEuk4bWA3GkTpW+DOoZMYBVVg= +go.opentelemetry.io/otel/sdk/metric v1.38.0 h1:aSH66iL0aZqo//xXzQLYozmWrXxyFkBJ6qT5wthqPoM= +go.opentelemetry.io/otel/sdk/metric v1.38.0/go.mod h1:dg9PBnW9XdQ1Hd6ZnRz689CbtrUp0wMMs9iPcgT9EZA= go.opentelemetry.io/otel/trace v1.38.0 h1:Fxk5bKrDZJUH+AMyyIXGcFAPah0oRcT+LuNtJrmcNLE= go.opentelemetry.io/otel/trace v1.38.0/go.mod h1:j1P9ivuFsTceSWe1oY+EeW3sc+Pp42sO++GHkg4wwhs= +go.uber.org/goleak v1.3.0 h1:2K3zAYmnTNqV73imy9J1T3WC+gmCePx2hEGkimedGto= +go.uber.org/goleak v1.3.0/go.mod h1:CoHD4mav9JJNrW/WLlf7HGZPjdw8EucARQHekz1X6bE= go.yaml.in/yaml/v3 v3.0.4 h1:tfq32ie2Jv2UxXFdLJdh3jXuOzWiL1fo0bu/FbuKpbc= go.yaml.in/yaml/v3 v3.0.4/go.mod h1:DhzuOOF2ATzADvBadXxruRBLzYTpT36CKvDb3+aBEFg= golang.org/x/crypto v0.0.0-20190308221718-c2843e01d9a2/go.mod h1:djNgcEr1/C05ACkg1iLfiJU5Ep61QUkGW8qpdssI0+w= -golang.org/x/crypto v0.0.0-20200622213623-75b288015ac9 h1:psW17arqaxU48Z5kZ0CQnkZWQJsqcURM6tKiBApRjXI= golang.org/x/crypto v0.0.0-20200622213623-75b288015ac9/go.mod h1:LzIPMQfyMNhhGPhUkYOs5KpL4U8rLKemX1yGLhDgUto= +golang.org/x/crypto v0.42.0 h1:chiH31gIWm57EkTXpwnqf8qeuMUi0yekh6mT2AvFlqI= +golang.org/x/crypto v0.42.0/go.mod h1:4+rDnOTJhQCx2q7/j6rAN5XDw8kPjeaXEUR2eL94ix8= golang.org/x/net v0.0.0-20190404232315-eb5bcb51f2a3/go.mod h1:t9HGtf8HONx5eT2rtn7q6eTqICYqUVnKs3thJo3Qplg= -golang.org/x/net v0.40.0 h1:79Xs7wF06Gbdcg4kdCCIQArK11Z1hr5POQ6+fIYHNuY= -golang.org/x/net v0.40.0/go.mod h1:y0hY0exeL2Pku80/zKK7tpntoX23cqL3Oa6njdgRtds= +golang.org/x/net v0.44.0 h1:evd8IRDyfNBMBTTY5XRF1vaZlD+EmWx6x8PkhR04H/I= +golang.org/x/net v0.44.0/go.mod h1:ECOoLqd5U3Lhyeyo/QDCEVQ4sNgYsqvCZ722XogGieY= +golang.org/x/sync v0.19.0 h1:vV+1eWNmZ5geRlYjzm2adRgW2/mcpevXNg50YZtPCE4= +golang.org/x/sync v0.19.0/go.mod h1:9KTHXmSnoGruLpwFjVSX0lNNA75CykiMECbovNTZqGI= golang.org/x/sys v0.0.0-20190215142949-d0b11bdaac8a/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY= golang.org/x/sys v0.0.0-20190412213103-97732733099d/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= golang.org/x/sys v0.0.0-20200116001909-b77594299b42/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= +golang.org/x/sys v0.0.0-20220811171246-fbc7d0a398ab/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/sys v0.6.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/sys v0.40.0 h1:DBZZqJ2Rkml6QMQsZywtnjnnGvHza6BTfYFWY9kjEWQ= golang.org/x/sys v0.40.0/go.mod h1:OgkHotnGiDImocRcuBABYBEXf8A9a87e/uXjp9XT3ks= @@ -143,3 +162,9 @@ gopkg.in/yaml.v2 v2.2.8 h1:obN1ZagJSUGI0Ek/LBmuj4SNLPfIny3KsKFopxRdj10= gopkg.in/yaml.v2 v2.2.8/go.mod h1:hI93XBmqTisBFMUTm0b8Fm+jr3Dg1NNxqwp+5A1VGuI= gopkg.in/yaml.v3 v3.0.1 h1:fxVm/GzAzEWqLHuvctI91KS9hhNmmWOoWu0XTYJS7CA= gopkg.in/yaml.v3 v3.0.1/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM= +gorm.io/driver/mysql v1.6.0 h1:eNbLmNTpPpTOVZi8MMxCi2aaIm0ZpInbORNXDwyLGvg= +gorm.io/driver/mysql v1.6.0/go.mod h1:D/oCC2GWK3M/dqoLxnOlaNKmXz8WNTfcS9y5ovaSqKo= +gorm.io/driver/sqlite v1.6.0 h1:WHRRrIiulaPiPFmDcod6prc4l2VGVWHz80KspNsxSfQ= +gorm.io/driver/sqlite v1.6.0/go.mod h1:AO9V1qIQddBESngQUKWL9yoH93HIeA1X6V633rBwyT8= +gorm.io/gorm v1.31.1 h1:7CA8FTFz/gRfgqgpeKIBcervUn3xSyPUmr6B2WXJ7kg= +gorm.io/gorm v1.31.1/go.mod h1:XyQVbO2k6YkOis7C2437jSit3SsDK72s7n7rsSHd+Gs= diff --git a/main.go b/main.go index 339e0a0..ff43765 100644 --- a/main.go +++ b/main.go @@ -3,25 +3,17 @@ package main import ( "git.magicany.cc/black1552/gin-base/config" "git.magicany.cc/black1552/gin-base/log" - "git.magicany.cc/black1552/gin-base/response" - "github.com/gin-gonic/gin" + "git.magicany.cc/black1552/gin-base/server" ) // TIP

To run your response, right-click the response and select Run.

Alternatively, click // the icon in the gutter and select the Run menu item from here.

func main() { - err := config.LoadConfigFromFile() + g := server.New() + cf, err := config.Unmarshal[config.BaseConfig]() if err != nil { - log.Info("err: ", err.Error()) + log.Error("转换配置失败", err) } - err = config.SetConfigValue("SERVICE.user-service", "127.0.0.1:3001") - if err != nil { - log.Info("err: ", err.Error()) - } - err = config.SetConfigMap(map[string]any{ - "SERVICE.product-service": "127.0.0.1:3002", - "SERVICE.order-service": "127.0.0.1:3003", - }) - response.Success(&gin.Context{}).SetMsg("text").End() - log.Info("ceshi", config.GetAllConfig()) + log.Info("启动服务:", cf.Server.Addr) + server.Run(g) } diff --git a/mqtt/client/mqtt.go b/mqtt/client/mqtt.go new file mode 100644 index 0000000..dbb23d4 --- /dev/null +++ b/mqtt/client/mqtt.go @@ -0,0 +1,220 @@ +package client + +import ( + "context" + "fmt" + "sync" + "time" + + mqtt "github.com/eclipse/paho.mqtt.golang" + "github.com/gogf/gf/v2/os/glog" +) + +// Client MQTT客户端结构 +type Client struct { + client mqtt.Client + opts *mqtt.ClientOptions + ctx context.Context + subscribed map[string]byte + subMutex sync.RWMutex + callback mqtt.MessageHandler + // 错误处理相关 + onConnectionLost func(error) + onReconnect func() + onConnect func() + onSubscriptionError func(error) + onPublishError func(error) +} + +// NewClientWithAuth 创建带用户名密码认证的MQTT客户端 +func NewClientWithAuth(ctx context.Context, broker, clientId, username, password string) *Client { + c := &Client{ + ctx: ctx, + subscribed: make(map[string]byte), + } + + opts := mqtt.NewClientOptions() + opts.AddBroker(broker) + opts.SetClientID(clientId) + opts.SetUsername(username) + opts.SetPassword(password) + // 设置连接丢失处理函数 + opts.SetConnectionLostHandler(func(client mqtt.Client, err error) { + glog.Error(ctx, "MQTT连接断开:", err) + if c.onConnectionLost != nil { + c.onConnectionLost(err) + } + }) + + // 设置重连处理函数 + opts.SetReconnectingHandler(func(client mqtt.Client, opts *mqtt.ClientOptions) { + glog.Info(ctx, "MQTT客户端正在尝试重连...") + if c.onReconnect != nil { + c.onReconnect() + } + }) + + // 设置连接成功处理函数,重连成功后重新订阅主题 + opts.SetOnConnectHandler(func(client mqtt.Client) { + glog.Info(ctx, "MQTT客户端重新连接成功") + if c.onConnect != nil { + c.onConnect() + } + // 重连成功后重新订阅主题 + go c.resubscribe() + }) + + // 设置其他选项... + + mqttClient := mqtt.NewClient(opts) + c.client = mqttClient + c.opts = opts + return c +} + +// Connect 连接到MQTT服务器 +func (c *Client) Connect() error { + glog.Info(c.ctx, "开始连接到MQTT服务器...") + token := c.client.Connect() + glog.Info(c.ctx, "等待连接完成...") + + // 使用更长的超时时间,避免网络延迟导致连接失败 + if token.WaitTimeout(30 * time.Second) { + glog.Info(c.ctx, "连接操作完成") + if token.Error() != nil { + err := fmt.Errorf("连接到MQTT代理时发生错误: %w", token.Error()) + glog.Error(c.ctx, "连接MQTT服务器时发生错误:", token.Error()) + return err + } + } else { + // 连接超时 + err := fmt.Errorf("连接到MQTT服务器超时") + glog.Error(c.ctx, "连接MQTT服务器超时") + return err + } + glog.Info(c.ctx, "成功连接到MQTT服务器") + return nil +} + +// Disconnect 断开MQTT连接 +func (c *Client) Disconnect() { + if c.client.IsConnected() { + c.client.Disconnect(250) + glog.Info(c.ctx, "已断开MQTT连接") + } +} + +// SubscribeMultiple 同时订阅多个主题 +func (c *Client) SubscribeMultiple(topics map[string]byte, callback mqtt.MessageHandler) error { + // 保存订阅信息 + c.subMutex.Lock() + for topic, qos := range topics { + c.subscribed[topic] = qos + } + c.callback = callback + c.subMutex.Unlock() + + token := c.client.SubscribeMultiple(topics, callback) + // 增加订阅超时时间 + if token.WaitTimeout(30*time.Second) && token.Error() != nil { + err := fmt.Errorf("同时订阅多个主题出现错误: %w", token.Error()) + glog.Error(c.ctx, "订阅主题时发生错误:", token.Error()) + if c.onSubscriptionError != nil { + c.onSubscriptionError(err) + } + return err + } + + // 检查订阅是否成功 + if token.WaitTimeout(30 * time.Second) { + glog.Info(c.ctx, "成功订阅主题:", topics) + } else { + err := fmt.Errorf("订阅主题超时: %v", topics) + glog.Error(c.ctx, "订阅主题超时:", topics) + if c.onSubscriptionError != nil { + c.onSubscriptionError(err) + } + return err + } + + return nil +} + +// Publish 发布消息 +func (c *Client) Publish(topic string, qos byte, retained bool, payload interface{}) error { + token := c.client.Publish(topic, qos, retained, payload) + // 增加发布超时时间 + if token.WaitTimeout(30*time.Second) && token.Error() != nil { + err := fmt.Errorf("发送消息到主题%s出现错误: %w", topic, token.Error()) + glog.Error(c.ctx, "发布消息到主题", topic, "时发生错误:", token.Error()) + if c.onPublishError != nil { + c.onPublishError(err) + } + return err + } + + // 检查发布是否成功 + if token.WaitTimeout(30 * time.Second) { + glog.Info(c.ctx, "成功发布消息到主题:", topic) + } else { + err := fmt.Errorf("发布消息到主题%s超时", topic) + glog.Error(c.ctx, "发布消息到主题超时:", topic) + if c.onPublishError != nil { + c.onPublishError(err) + } + return err + } + + return nil +} + +// resubscribe 重新订阅主题 +func (c *Client) resubscribe() { + c.subMutex.RLock() + defer c.subMutex.RUnlock() + + if len(c.subscribed) == 0 { + glog.Info(c.ctx, "没有需要重新订阅的主题") + return + } + + // 复制订阅信息避免并发问题 + topics := make(map[string]byte) + for topic, qos := range c.subscribed { + topics[topic] = qos + } + + glog.Info(c.ctx, "开始重新订阅主题:", topics) + // 增加重新订阅的超时时间 + token := c.client.SubscribeMultiple(topics, c.callback) + if token.WaitTimeout(30 * time.Second) { + if token.Error() != nil { + err := fmt.Errorf("重新订阅主题时发生错误: %w", token.Error()) + glog.Error(c.ctx, "重新订阅主题时发生错误:", token.Error()) + if c.onSubscriptionError != nil { + c.onSubscriptionError(err) + } + } else { + glog.Info(c.ctx, "重新订阅主题成功") + } + } else { + err := fmt.Errorf("重新订阅主题超时") + glog.Error(c.ctx, "重新订阅主题超时") + if c.onSubscriptionError != nil { + c.onSubscriptionError(err) + } + } +} + +// IsConnected 检查是否连接 +func (c *Client) IsConnected() bool { + isConnected := c.client.IsConnected() + glog.Debug(c.ctx, "检查连接状态:", isConnected) + return isConnected +} + +// Subscribe 单独订阅一个主题 +func (c *Client) Subscribe(topic string, qos byte, callback mqtt.MessageHandler) error { + topics := map[string]byte{topic: qos} + return c.SubscribeMultiple(topics, callback) +} diff --git a/server/server.go b/server/server.go index 61b3e86..24a256a 100644 --- a/server/server.go +++ b/server/server.go @@ -1,15 +1,71 @@ package server import ( + "context" + "errors" + "fmt" + "net/http" + "os" + "os/signal" + "strings" + "syscall" + "time" + "git.magicany.cc/black1552/gin-base/config" + "git.magicany.cc/black1552/gin-base/log" "git.magicany.cc/black1552/gin-base/middleware" "github.com/gin-gonic/gin" ) +// New 创建一个gin实例 +// 默认使用全局异常处理中间件 +// 默认使用跨域中间件 +// @Return *gin.Engine 可以使用创建路由 func New() *gin.Engine { + gin.SetMode(config.GetConfigValue("server.mode", "debug").String()) g := gin.New() g.Use(middleware.ErrorHandler()) g.Use(middleware.CORSMiddleware()) - g.Run(config.GetConfigValue("server.addr", "127.0.0.1:8080").String()) return g } + +// Run 启动服务 +// @Param *gin.Engine 路由实例 +// 设置监听挂壁 +func Run(g *gin.Engine) { + s := &http.Server{ + Addr: config.GetConfigValue("server.addr", ":8080").String(), + Handler: g, + ReadTimeout: 60 * time.Second, + ReadHeaderTimeout: 60 * time.Second, + WriteTimeout: 60 * time.Second, + IdleTimeout: 60 * time.Second, + MaxHeaderBytes: 20 * 1024 * 1024, + } + go func() { + if err := s.ListenAndServe(); err != nil && !errors.Is(err, http.ErrServerClosed) { + log.Error("服务器启动失败:", err) + } + }() + log.Info("服务器启动成功....") + if strings.Contains(s.Addr, "127.0.0.1") || strings.Contains(s.Addr, "0.0.0.0") || strings.Contains(s.Addr, "locahost") { + log.Info("请使用打开:", fmt.Sprintf("http://%s\n", s.Addr)) + } else { + log.Info("请使用打开:", fmt.Sprintf("http://localhost%s\n", s.Addr)) + } + // 等待中断信号以优雅地关闭服务器 + quit := make(chan os.Signal, 1) + signal.Notify(quit, syscall.SIGINT, syscall.SIGTERM) + <-quit + + log.Info("正在关闭服务器...") + + // 设置5秒的超时时间用于关闭服务器 + ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second) + defer cancel() + + if err := s.Shutdown(ctx); err != nil { + log.Error("服务器强制关闭") + } + log.Info("服务器已退出") +} diff --git a/tcp/example.go b/tcp/example.go new file mode 100644 index 0000000..cd54038 --- /dev/null +++ b/tcp/example.go @@ -0,0 +1,47 @@ +package tcp + +import ( + "fmt" + "time" +) + +// Example 展示如何使用TCP服务 +func Example() { + // 创建配置 + config := &TcpPoolConfig{ + BufferSize: 2048, + MaxConnections: 100000, + ConnectTimeout: time.Second * 5, + ReadTimeout: time.Second * 30, + WriteTimeout: time.Second * 10, + MaxIdleTime: time.Minute * 5, + } + + // 创建TCP服务器 + server := NewTCPServer("0.0.0.0:8888", config) + + // 设置消息处理函数 + server.SetMessageHandler(func(conn *TcpConnection, msg *TcpMessage) error { + fmt.Printf("Received message from %s: %s\n", conn.Id, string(msg.Data)) + + // 回显消息 + return server.SendTo(conn.Id, []byte(fmt.Sprintf("Echo: %s", msg.Data))) + }) + + // 启动服务器 + if err := server.Start(); err != nil { + fmt.Printf("Failed to start server: %v\n", err) + return + } + + // 运行10秒后停止 + fmt.Println("TCP server started. Running for 10 seconds...") + time.Sleep(time.Second * 10) + + // 停止服务器 + if err := server.Stop(); err != nil { + fmt.Printf("Failed to stop server: %v\n", err) + } + + fmt.Println("TCP server stopped.") +} diff --git a/tcp/tcp.go b/tcp/tcp.go new file mode 100644 index 0000000..73b5bd3 --- /dev/null +++ b/tcp/tcp.go @@ -0,0 +1,280 @@ +package tcp + +import ( + "context" + "fmt" + "sync" + "time" + + "github.com/gogf/gf/v2/frame/g" + "github.com/gogf/gf/v2/net/gtcp" + "github.com/gogf/gf/v2/os/glog" + "github.com/gogf/gf/v2/os/grpool" + "github.com/gogf/gf/v2/os/gtime" +) + +// MessageHandler 消息处理函数类型 +type MessageHandler func(conn *TcpConnection, msg *TcpMessage) error + +// TCPServer TCP服务器结构 +type TCPServer struct { + Address string + Config *TcpPoolConfig + Listener *gtcp.Server + Connection *ConnectionPool + Logger *glog.Logger + MessageHandler MessageHandler + ctx context.Context + cancel context.CancelFunc + wg sync.WaitGroup +} + +// ConnectionPool 连接池结构 +type ConnectionPool struct { + connections map[string]*TcpConnection + mutex sync.RWMutex + config *TcpPoolConfig + logger *glog.Logger +} + +// NewTCPServer 创建一个新的TCP服务器 +func NewTCPServer(address string, config *TcpPoolConfig) *TCPServer { + logger := g.Log(address) + ctx, cancel := context.WithCancel(context.Background()) + + pool := &ConnectionPool{ + connections: make(map[string]*TcpConnection), + config: config, + logger: logger, + } + + server := &TCPServer{ + Address: address, + Config: config, + Connection: pool, + Logger: logger, + ctx: ctx, + cancel: cancel, + } + + server.Listener = gtcp.NewServer(address, server.handleConnection) + return server +} + +// SetMessageHandler 设置消息处理函数 +func (s *TCPServer) SetMessageHandler(handler MessageHandler) { + s.MessageHandler = handler +} + +// Start 启动TCP服务器 +func (s *TCPServer) Start() error { + s.Logger.Info(s.ctx, fmt.Sprintf("TCP server starting on %s", s.Address)) + go func() { + s.wg.Add(1) + defer s.wg.Done() + if err := s.Listener.Run(); err != nil { + s.Logger.Error(s.ctx, fmt.Sprintf("TCP server stopped with error: %v", err)) + } + }() + return nil +} + +// Stop 停止TCP服务器 +func (s *TCPServer) Stop() error { + s.Logger.Info(s.ctx, "TCP server stopping...") + s.cancel() + s.Listener.Close() + s.wg.Wait() + s.Connection.Clear() + s.Logger.Info(s.ctx, "TCP server stopped") + return nil +} + +// handleConnection 处理新连接 +func (s *TCPServer) handleConnection(conn *gtcp.Conn) { + // 生成连接ID + connID := fmt.Sprintf("%s_%d", conn.RemoteAddr().String(), gtime.TimestampNano()) + + // 创建连接对象 + tcpConn := &TcpConnection{ + Id: connID, + Address: conn.RemoteAddr().String(), + Server: *conn, + IsActive: true, + LastUsed: time.Now(), + CreatedAt: time.Now(), + } + + // 添加到连接池 + s.Connection.Add(tcpConn) + s.Logger.Info(s.ctx, fmt.Sprintf("New connection established: %s", connID)) + + // 启动消息接收协程 + go s.receiveMessages(tcpConn) +} + +// receiveMessages 接收消息 +func (s *TCPServer) receiveMessages(conn *TcpConnection) { + defer func() { + if err := recover(); err != nil { + s.Logger.Error(s.ctx, fmt.Sprintf("Panic in receiveMessages: %v", err)) + } + s.Connection.Remove(conn.Id) + conn.Server.Close() + s.Logger.Info(s.ctx, fmt.Sprintf("Connection closed: %s", conn.Id)) + }() + + buffer := make([]byte, s.Config.BufferSize) + for { + select { + case <-s.ctx.Done(): + return + default: + // 设置读取超时 + conn.Server.SetReadDeadline(time.Now().Add(s.Config.ReadTimeout)) + + // 读取数据 + n, err := conn.Server.Read(buffer) + if err != nil { + s.Logger.Error(s.ctx, fmt.Sprintf("Read error from %s: %v", conn.Id, err)) + return + } + + if n > 0 { + // 更新最后使用时间 + conn.Mutex.Lock() + conn.LastUsed = time.Now() + conn.Mutex.Unlock() + + // 处理消息 + data := make([]byte, n) + copy(data, buffer[:n]) + + msg := &TcpMessage{ + Id: fmt.Sprintf("msg_%d", gtime.TimestampNano()), + ConnId: conn.Id, + Data: data, + Timestamp: time.Now(), + IsSend: false, + } + + // 使用协程池处理消息,避免阻塞 + grpool.AddWithRecover(s.ctx, func(ctx context.Context) { + if s.MessageHandler != nil { + if err := s.MessageHandler(conn, msg); err != nil { + s.Logger.Error(s.ctx, fmt.Sprintf("Message handling error: %v", err)) + } + } + }, func(ctx context.Context, err error) { + s.Logger.Error(ctx, fmt.Sprintf("Message handling error: %v", err)) + }) + } + } + } +} + +// SendTo 发送消息到指定连接 +func (s *TCPServer) SendTo(connID string, data []byte) error { + conn := s.Connection.Get(connID) + if conn == nil { + return fmt.Errorf("connection not found: %s", connID) + } + return s.sendMessage(conn, data) +} + +// SendToAll 发送消息到所有连接 +func (s *TCPServer) SendToAll(data []byte) error { + conns := s.Connection.GetAll() + for _, conn := range conns { + if err := s.sendMessage(conn, data); err != nil { + s.Logger.Error(s.ctx, fmt.Sprintf("Send to %s failed: %v", conn.Id, err)) + // 继续发送给其他连接 + } + } + return nil +} + +// sendMessage 发送消息 +func (s *TCPServer) sendMessage(conn *TcpConnection, data []byte) error { + conn.Mutex.Lock() + defer conn.Mutex.Unlock() + + // 设置写入超时 + conn.Server.SetWriteDeadline(time.Now().Add(s.Config.WriteTimeout)) + + // 发送数据 + _, err := conn.Server.Write(data) + if err != nil { + return err + } + + // 更新最后使用时间 + conn.LastUsed = time.Now() + return nil +} + +// Kick 强制退出客户端 +func (s *TCPServer) Kick(connID string) error { + conn := s.Connection.Get(connID) + if conn == nil { + return fmt.Errorf("connection not found: %s", connID) + } + + // 关闭连接 + conn.Server.Close() + // 从连接池移除 + s.Connection.Remove(connID) + + s.Logger.Info(s.ctx, fmt.Sprintf("Kicked connection: %s", connID)) + return nil +} + +// Add 添加连接到连接池 +func (p *ConnectionPool) Add(conn *TcpConnection) { + p.mutex.Lock() + defer p.mutex.Unlock() + p.connections[conn.Id] = conn +} + +// Get 获取连接 +func (p *ConnectionPool) Get(connID string) *TcpConnection { + p.mutex.RLock() + defer p.mutex.RUnlock() + return p.connections[connID] +} + +// GetAll 获取所有连接 +func (p *ConnectionPool) GetAll() []*TcpConnection { + p.mutex.RLock() + defer p.mutex.RUnlock() + + conns := make([]*TcpConnection, 0, len(p.connections)) + for _, conn := range p.connections { + conns = append(conns, conn) + } + return conns +} + +// Remove 从连接池移除连接 +func (p *ConnectionPool) Remove(connID string) { + p.mutex.Lock() + defer p.mutex.Unlock() + delete(p.connections, connID) +} + +// Clear 清空连接池 +func (p *ConnectionPool) Clear() { + p.mutex.Lock() + defer p.mutex.Unlock() + for connID, conn := range p.connections { + conn.Server.Close() + delete(p.connections, connID) + } +} + +// Count 获取连接数量 +func (p *ConnectionPool) Count() int { + p.mutex.RLock() + defer p.mutex.RUnlock() + return len(p.connections) +} diff --git a/tcp/tcpConfig.go b/tcp/tcpConfig.go new file mode 100644 index 0000000..91f425c --- /dev/null +++ b/tcp/tcpConfig.go @@ -0,0 +1,38 @@ +package tcp + +import ( + "sync" + "time" + + "github.com/gogf/gf/v2/net/gtcp" +) + +// TcpPoolConfig TCP连接池配置 +type TcpPoolConfig struct { + BufferSize int `json:"bufferSize"` // 缓冲区大小 + MaxConnections int `json:"maxConnections"` // 最大连接数 + ConnectTimeout time.Duration `json:"connectTimeout"` // 连接超时时间 + ReadTimeout time.Duration `json:"readTimeout"` // 读取超时时间 + WriteTimeout time.Duration `json:"writeTimeout"` // 写入超时时间 + MaxIdleTime time.Duration `json:"maxIdleTime"` // 最大空闲时间 +} + +// TcpConnection TCP连接结构 +type TcpConnection struct { + Id string `json:"id"` // 连接ID + Address string `json:"address"` // 连接地址 + Server gtcp.Conn `json:"server"` // 实际连接 + IsActive bool `json:"isActive"` // 是否活跃 + LastUsed time.Time `json:"lastUsed"` // 最后使用时间 + CreatedAt time.Time `json:"createdAt"` // 创建时间 + Mutex sync.RWMutex `json:"-"` // 读写锁 +} + +// TcpMessage TCP消息结构 +type TcpMessage struct { + Id string `json:"id"` // 消息ID + ConnId string `json:"connId"` // 连接ID + Data []byte `json:"data"` // 消息数据 + Timestamp time.Time `json:"timestamp"` // 时间戳 + IsSend bool `json:"isSend"` // 是否是发送的消息 +} diff --git a/ws/example.go b/ws/example.go new file mode 100644 index 0000000..1420de9 --- /dev/null +++ b/ws/example.go @@ -0,0 +1,73 @@ +package ws + +import ( + "log" + "net/http" + "time" + + "github.com/gogf/gf/v2/util/gconv" +) + +var manager = NewWs() + +func NewWs() *Manager { + // 1. 自定义配置(可选,也可使用默认配置) + customConfig := &Config{ + AllowAllOrigins: true, + HeartbeatInterval: 20 * time.Second, // 20秒发一次心跳 + HeartbeatTimeout: 40 * time.Second, // 40秒超时 + } + + // 2. 创建管理器 + m := NewManager(customConfig) + + // 3. 覆盖业务回调(核心:自定义消息处理逻辑) + // 连接建立回调 + m.OnConnect = func(connID string) { + log.Printf("业务回调:连接[%s]上线,当前在线数:%d", connID, m.GetOnlineCount()) + // 欢迎消息 + _ = m.SendToConn(connID, []byte("欢迎连接WebSocket服务!")) + } + + // 收到消息回调 + m.OnMessage = func(connID string, msgType int, data any) { + log.Printf("业务回调:收到连接[%s]消息:%s", connID, gconv.String(data)) + // 示例:echo回复 + reply := []byte("服务端回复:" + gconv.String(data)) + _ = m.SendToConn(connID, reply) + + // 示例:广播消息给所有连接 + _ = m.Broadcast([]byte("广播:" + connID + "说:" + gconv.String(data))) + } + + // 连接断开回调 + m.OnDisconnect = func(connID string, err error) { + log.Printf("业务回调:连接[%s]下线,原因:%v,当前在线数:%d", connID, err, m.GetOnlineCount()) + } + return m +} +func Upgrade(w http.ResponseWriter, r *http.Request, connID string) { + _, err := manager.Upgrade(w, r, connID) + if err != nil { + log.Printf("升级连接失败:%v", err) + http.Error(w, err.Error(), http.StatusInternalServerError) + return + } +} +func main() { + // 4. 注册WebSocket路由 + http.HandleFunc("/ws", func(w http.ResponseWriter, r *http.Request) { + // 自定义连接ID(示例:使用请求参数中的user_id) + connID := r.URL.Query().Get("user_id") + if connID == "" { + http.Error(w, "user_id不能为空", http.StatusBadRequest) + return + } + // 升级连接 + Upgrade(w, r, connID) + }) + + // 5. 启动服务 + log.Println("WebSocket服务启动:http://localhost:8080/ws") + log.Fatal(http.ListenAndServe(":8080", nil)) +} diff --git a/ws/websocket.go b/ws/websocket.go new file mode 100644 index 0000000..e02c0be --- /dev/null +++ b/ws/websocket.go @@ -0,0 +1,488 @@ +package ws + +import ( + "context" + "errors" + "fmt" + "log" + "net/http" + "sync" + "time" + + "github.com/gogf/gf/v2/encoding/gjson" + "github.com/gogf/gf/v2/os/gctx" + "github.com/gogf/gf/v2/os/gtime" + "github.com/gogf/gf/v2/os/gtimer" + "github.com/gogf/gf/v2/text/gstr" + "github.com/gogf/gf/v2/util/gconv" + "github.com/gorilla/websocket" +) + +// 常量定义:默认配置 +const ( + // DefaultReadBufferSize 默认读写缓冲区大小(字节) + DefaultReadBufferSize = 1024 + DefaultWriteBufferSize = 1024 + // DefaultHeartbeatInterval 默认心跳间隔(秒):每30秒发送一次心跳 + DefaultHeartbeatInterval = 30 * time.Second + // DefaultHeartbeatTimeout 默认心跳超时(秒):60秒未收到客户端心跳响应则关闭连接 + DefaultHeartbeatTimeout = 60 * time.Second + // DefaultReadTimeout 默认读写超时(秒) + DefaultReadTimeout = 60 * time.Second + DefaultWriteTimeout = 10 * time.Second + // MessageTypeText 消息类型 + MessageTypeText = websocket.TextMessage + MessageTypeBinary = websocket.BinaryMessage + // HeartbeatMaxRetry 心跳最大重试次数 + HeartbeatMaxRetry = 3 +) + +// Config WebSocket服务端配置 +type Config struct { + // 读写缓冲区大小 + ReadBufferSize int + WriteBufferSize int + // 跨域配置:是否允许所有跨域(生产环境建议指定Origin) + AllowAllOrigins bool + // 允许的跨域Origin列表(AllowAllOrigins=false时生效) + AllowedOrigins []string + // 心跳配置 + HeartbeatInterval time.Duration // 心跳发送间隔 + HeartbeatTimeout time.Duration // 心跳超时时间 + // 读写超时 + ReadTimeout time.Duration + WriteTimeout time.Duration + MsgType int // 发送消息的默认类型 + HeartbeatValue string // 心跳消息的标识字段值(如"heartbeat"、"pong") + HeartbeatKey string // 心跳消息的标识字段名(如"type") +} + +// 默认配置 +func DefaultConfig() *Config { + return &Config{ + ReadBufferSize: DefaultReadBufferSize, + WriteBufferSize: DefaultWriteBufferSize, + AllowAllOrigins: true, + AllowedOrigins: []string{}, + HeartbeatInterval: DefaultHeartbeatInterval, + HeartbeatTimeout: DefaultHeartbeatTimeout, + ReadTimeout: DefaultReadTimeout, + WriteTimeout: DefaultWriteTimeout, + MsgType: MessageTypeText, + HeartbeatValue: "heartbeat", + HeartbeatKey: "type", // 心跳消息的标识字段名,默认"type" + } +} + +// Connection WebSocket连接结构体 +type Connection struct { + conn *websocket.Conn // 底层连接 + connID string // 唯一连接ID + manager *Manager // 所属管理器 + createTime time.Time // 连接创建时间 + heartbeatChan time.Time // 心跳通道(用于检测客户端响应) + heartbeatTime *gtimer.Entry + ctx context.Context // 上下文 + cancel context.CancelFunc // 上下文取消函数 + writeMutex sync.Mutex // 写消息互斥锁(防止并发写) + heartbeatRetry int // 心跳发送重试次数 +} + +// Manager WebSocket连接管理器 +type Manager struct { + config *Config // 配置 + upgrader *websocket.Upgrader // HTTP升级器 + connections map[string]*Connection // 所有在线连接(connID -> Connection) + mutex sync.RWMutex // 读写锁(保护connections) + // 业务回调:收到消息时触发(用户自定义处理逻辑) + OnMessage func(connID string, msgType int, data any) + // 业务回调:连接建立时触发 + OnConnect func(connID string) + // 业务回调:连接关闭时触发 + OnDisconnect func(connID string, err error) +} + +// Merge 合并配置,用传入的配置覆盖非零值部分 +func (c *Config) Merge(other *Config) *Config { + result := *c // 复制当前配置 + + if other == nil { + return &result + } + + if other.ReadBufferSize > 0 { + result.ReadBufferSize = other.ReadBufferSize + } + if other.WriteBufferSize > 0 { + result.WriteBufferSize = other.WriteBufferSize + } + if other.HeartbeatInterval > 0 { + result.HeartbeatInterval = other.HeartbeatInterval + } + if other.HeartbeatTimeout > 0 { + result.HeartbeatTimeout = other.HeartbeatTimeout + } + if other.ReadTimeout > 0 { + result.ReadTimeout = other.ReadTimeout + } + if other.WriteTimeout > 0 { + result.WriteTimeout = other.WriteTimeout + } + if other.AllowAllOrigins { + result.AllowAllOrigins = other.AllowAllOrigins + } + if other.HeartbeatValue != "" { + result.HeartbeatValue = other.HeartbeatValue + } + if other.HeartbeatKey != "" { + result.HeartbeatKey = other.HeartbeatKey + } + if len(other.AllowedOrigins) > 0 { + result.AllowedOrigins = other.AllowedOrigins + } + if other.MsgType != 0 { + result.MsgType = other.MsgType + } + + return &result +} + +// NewManager 创建连接管理器 +func NewManager(config *Config) *Manager { + defaultConfig := DefaultConfig() + finalConfig := defaultConfig.Merge(config) + // 初始化升级器 + upgrader := &websocket.Upgrader{ + ReadBufferSize: config.ReadBufferSize, + WriteBufferSize: config.WriteBufferSize, + CheckOrigin: func(r *http.Request) bool { + // 跨域检查 + if config.AllowAllOrigins { + return true + } + origin := r.Header.Get("Origin") + for _, allowed := range finalConfig.AllowedOrigins { + if origin == allowed { + return true + } + } + return false + }, + } + + return &Manager{ + config: finalConfig, + upgrader: upgrader, + connections: make(map[string]*Connection), + mutex: sync.RWMutex{}, + // 默认回调(用户可覆盖) + OnMessage: func(connID string, msgType int, data any) { + log.Printf("[默认回调] 收到连接[%s]消息:%s", connID, gconv.String(data)) + }, + OnConnect: func(connID string) { + log.Printf("[默认回调] 连接[%s]已建立", connID) + }, + OnDisconnect: func(connID string, err error) { + log.Printf("[默认回调] 连接[%s]已关闭:%v", connID, err) + }, + } +} + +// Upgrade HTTP升级为WebSocket连接 +// connID:自定义连接唯一ID(如用户ID、设备ID) +func (m *Manager) Upgrade(w http.ResponseWriter, r *http.Request, connID string) (*Connection, error) { + if connID == "" { + return nil, errors.New("连接ID不能为空") + } + + // 检查连接ID是否已存在 + m.mutex.RLock() + _, exists := m.connections[connID] + m.mutex.RUnlock() + if exists { + return nil, fmt.Errorf("连接ID[%s]已存在", connID) + } + + // 升级HTTP连接 + conn, err := m.upgrader.Upgrade(w, r, nil) + if err != nil { + return nil, fmt.Errorf("升级WebSocket失败:%w", err) + } + + // 创建上下文(用于优雅关闭) + ctx, cancel := context.WithCancel(context.Background()) + + // 创建连接实例 + wsConn := &Connection{ + conn: conn, + connID: connID, + manager: m, + createTime: time.Now(), + heartbeatChan: time.Now(), // 缓冲1,防止阻塞 + ctx: ctx, + cancel: cancel, + writeMutex: sync.Mutex{}, + heartbeatRetry: 0, + } + wsConn.heartbeatTime = gtimer.AddSingleton(gctx.New(), m.config.HeartbeatTimeout, func(ctx context.Context) { + log.Printf("[心跳检测] 连接[%s]已关闭:心跳超时", wsConn.connID) + wsConn.heartbeatTime.Close() + wsConn.heartbeatTime.Stop() + wsConn.heartbeatTime = nil + wsConn.ctx.Done() + wsConn.Close(fmt.Errorf("心跳超时")) + }) + // 添加到管理器 + m.mutex.Lock() + m.connections[connID] = wsConn + m.mutex.Unlock() + + // 触发连接建立回调 + m.OnConnect(connID) + + // 启动读消息协程 + go wsConn.ReadPump() + // 启动写消息协程(处理异步发送) + go wsConn.WritePump() + // 启动心跳检测协程 + go wsConn.Heartbeat() + + return wsConn, nil +} + +// ReadPump 读取客户端消息(持续运行) +func (c *Connection) ReadPump() { + defer func() { + // 发生panic时关闭连接 + if err := recover(); err != nil { + log.Printf("连接[%s]读消息协程panic:%v", c.connID, err) + } + // 关闭连接并清理 + c.Close(fmt.Errorf("读消息协程退出")) + }() + + // 循环读取消息 + for { + select { + case <-c.ctx.Done(): + return // 上下文已取消,退出 + default: + // 设置读超时(每次读取前重置,防止长时间无消息超时) + c.conn.SetReadDeadline(time.Now().Add(c.manager.config.ReadTimeout)) + // 读取客户端消息 + msgType, data, err := c.conn.ReadMessage() + if err != nil { + // 区分正常关闭和异常错误 + var closeErr *websocket.CloseError + if errors.As(err, &closeErr) { + c.Close(fmt.Errorf("客户端主动关闭:%s(代码:%d)", closeErr.Text, closeErr.Code)) + } else { + c.Close(fmt.Errorf("读取消息失败:%w", err)) + } + return + } + + // 尝试解析JSON格式的心跳消息(精准判断,替代包含判断) + isHeartbeat := false + // 先尝试解析为JSON对象 + var msgMap map[string]interface{} + if err := gjson.DecodeTo(data, &msgMap); err == nil { + // 获取心跳标识字段的值 + heartbeatValue := gconv.String(msgMap[c.manager.config.HeartbeatKey]) + if heartbeatValue == c.manager.config.HeartbeatValue { + isHeartbeat = true + } + } else { + // 非JSON格式,降级为包含判断(兼容纯文本心跳) + str := gconv.String(data) + if gstr.Contains(str, c.manager.config.HeartbeatValue) { + isHeartbeat = true + } + } + if isHeartbeat { + log.Printf("[心跳] 收到连接[%s]心跳消息:%s", c.connID, string(data)) + // 心跳消息:重置重试次数 + 发送心跳信号 + 重置读超时 + js, err := gjson.Encode(&Msg[any]{c.manager.config.HeartbeatValue, nil, gtime.Timestamp()}) + if err != nil { + log.Printf("[心跳] 客户端[%s]json编码失败", c.connID) + continue + } + err = c.Send(js) + if err != nil { + log.Printf("[心跳] 客户端[%s]发送心跳消息失败", c.connID) + continue + } + c.heartbeatTime.Reset() + continue // 跳过业务回调 + } + + // 非心跳消息:触发业务回调 + c.manager.OnMessage(c.connID, msgType, data) + } + } +} + +type Msg[T any] struct { + Type string `json:"type"` + Data T `json:"data"` + Timestamp int64 `json:"timestamp"` +} + +// WritePump 处理异步写消息(持续运行) +// 扩展为监听写队列,防止消息丢失 +func (c *Connection) WritePump() { + defer func() { + if err := recover(); err != nil { + log.Printf("连接[%s]写消息协程panic:%v", c.connID, err) + } + }() + + // 暂时保持简化,实际可扩展为带缓冲的写队列 + <-c.ctx.Done() +} + +// Heartbeat 心跳检测(持续运行) +func (c *Connection) Heartbeat() { + defer func() { + if err := recover(); err != nil { + log.Printf("连接[%s]心跳协程panic:%v", c.connID, err) + } + }() + c.heartbeatTime.Start() +} + +// Send 发送消息到客户端(线程安全) +func (c *Connection) Send(data []byte) error { + select { + case <-c.ctx.Done(): + return errors.New("连接已关闭,无法发送消息") + default: + // 加锁防止并发写 + c.writeMutex.Lock() + defer c.writeMutex.Unlock() + + // 设置写超时 + c.conn.SetWriteDeadline(time.Now().Add(c.manager.config.WriteTimeout)) + + // 发送消息(使用连接的默认类型,支持动态调整) + err := c.conn.WriteMessage(c.manager.config.MsgType, data) + if err != nil { + return fmt.Errorf("发送消息失败:%w", err) + } + return nil + } +} + +// Close 关闭连接(优雅清理) +func (c *Connection) Close(err error) { + // 防止重复关闭 + select { + case <-c.ctx.Done(): + return + default: + } + + // 取消上下文(终止所有协程) + c.cancel() + + // 关闭底层连接(友好关闭) + _ = c.conn.WriteMessage(websocket.CloseMessage, websocket.FormatCloseMessage(websocket.CloseNormalClosure, err.Error())) + _ = c.conn.Close() + + // 从管理器移除 + c.manager.mutex.Lock() + delete(c.manager.connections, c.connID) + c.manager.mutex.Unlock() + + // 触发断开回调 + c.manager.OnDisconnect(c.connID, err) + + log.Printf("连接[%s]已关闭,当前在线数:%d,原因:%v", c.connID, c.manager.GetOnlineCount(), err) +} + +// GetOnlineCount 获取在线连接数 +func (m *Manager) GetOnlineCount() int { + m.mutex.RLock() + defer m.mutex.RUnlock() + return len(m.connections) +} + +// Broadcast 广播消息到所有在线连接 +func (m *Manager) Broadcast(data []byte) error { + m.mutex.RLock() + defer m.mutex.RUnlock() + + if len(m.connections) == 0 { + return errors.New("无在线连接") + } + + // 并发发送(非阻塞) + var wg sync.WaitGroup + var errMsg string + + for _, conn := range m.connections { + wg.Add(1) + go func(c *Connection) { + defer wg.Done() + if err := c.Send(data); err != nil { + errMsg += fmt.Sprintf("连接[%s]广播失败:%v;", c.connID, err) + } + }(conn) + } + + wg.Wait() + + if errMsg != "" { + return errors.New(errMsg) + } + return nil +} + +// SendToConn 定向发送消息到指定连接 +func (m *Manager) SendToConn(connID string, data []byte) error { + m.mutex.RLock() + conn, exists := m.connections[connID] + m.mutex.RUnlock() + + if !exists { + return fmt.Errorf("连接[%s]不存在", connID) + } + + return conn.Send(data) +} + +func (m *Manager) GetAllConn() map[string]*Connection { + m.mutex.RLock() + defer m.mutex.RUnlock() + // 返回副本,防止外部修改 + connCopy := make(map[string]*Connection, len(m.connections)) + for k, v := range m.connections { + connCopy[k] = v + } + return connCopy +} + +func (m *Manager) GetConn(connID string) *Connection { + m.mutex.RLock() + defer m.mutex.RUnlock() + return m.connections[connID] +} + +// CloseAll 关闭所有连接 +func (m *Manager) CloseAll() { + m.mutex.RLock() + connIDs := make([]string, 0, len(m.connections)) + for connID := range m.connections { + connIDs = append(connIDs, connID) + } + m.mutex.RUnlock() + + for _, connID := range connIDs { + m.mutex.RLock() + conn := m.connections[connID] + m.mutex.RUnlock() + if conn != nil { + conn.Close(errors.New("服务端主动关闭所有连接")) + } + } +}