diff --git a/.idea/GOHCache.xml b/.idea/GOHCache.xml
index 2bdc17b..bbc7165 100644
--- a/.idea/GOHCache.xml
+++ b/.idea/GOHCache.xml
@@ -17,6 +17,20 @@
+
+
+
+
+
+
+
+
+
+
+
+
+
+
@@ -34,6 +48,8 @@
+
+
@@ -41,17 +57,49 @@
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
@@ -83,28 +131,30 @@
+
+
+
+
+
+
+
+
+
-
+
-
-
-
-
-
-
-
-
+
@@ -125,6 +175,8 @@
+
+
@@ -132,6 +184,7 @@
+
@@ -139,6 +192,7 @@
+
@@ -146,24 +200,11 @@
+
-
-
-
-
-
-
-
-
-
-
-
-
-
-
@@ -178,10 +219,140 @@
+
+
+
+
+
+
+
-
+
\ No newline at end of file
diff --git a/pool/common.go b/pool/common.go
new file mode 100644
index 0000000..fc08dae
--- /dev/null
+++ b/pool/common.go
@@ -0,0 +1,25 @@
+package pool
+
+import (
+ "time"
+)
+
+// ConnType 连接类型
+type ConnType string
+
+const (
+ ConnTypeWebSocket ConnType = "websocket"
+ ConnTypeTCP ConnType = "tcp"
+)
+
+// ConnectionInfo 连接信息
+type ConnectionInfo struct {
+ ID string `json:"id" gorm:"primaryKey"`
+ Type ConnType `json:"type" gorm:"index"`
+ Address string `json:"address"`
+ IsActive bool `json:"isActive" gorm:"index"`
+ LastUsed time.Time `json:"lastUsed"`
+ CreatedAt time.Time `json:"createdAt"`
+ // 额外的连接数据,根据不同类型存储不同的信息
+ Data map[string]interface{} `json:"data" gorm:"-"`
+}
diff --git a/pool/sqlite.go b/pool/sqlite.go
new file mode 100644
index 0000000..f22aacc
--- /dev/null
+++ b/pool/sqlite.go
@@ -0,0 +1,247 @@
+package pool
+
+import (
+ "context"
+ "fmt"
+ "sync"
+ "time"
+
+ "git.magicany.cc/black1552/gf-common/db"
+ "gorm.io/gorm"
+)
+
+// SQLitePool SQLite连接池
+type SQLitePool struct {
+ db *gorm.DB
+ mutex sync.RWMutex
+ ctx context.Context
+ cancel context.CancelFunc
+ // 内存缓存,提高并发性能
+ cache map[string]*ConnectionInfo
+}
+
+// NewSQLitePool 创建SQLite连接池
+func NewSQLitePool() (*SQLitePool, error) {
+ ctx, cancel := context.WithCancel(context.Background())
+
+ // 检查数据库连接是否正常
+ if db.Db == nil {
+ return nil, fmt.Errorf("database connection is not initialized")
+ }
+
+ // 自动迁移ConnectionInfo模型
+ err := db.Db.AutoMigrate(&ConnectionInfo{})
+ if err != nil {
+ cancel()
+ return nil, fmt.Errorf("failed to migrate connection info model: %w", err)
+ }
+
+ return &SQLitePool{
+ db: db.Db,
+ ctx: ctx,
+ cancel: cancel,
+ cache: make(map[string]*ConnectionInfo),
+ }, nil
+}
+
+// Close 关闭连接池
+func (p *SQLitePool) Close() error {
+ p.cancel()
+ // SQLite连接由db包管理,不需要在这里关闭
+ return nil
+}
+
+// Add 添加连接
+func (p *SQLitePool) Add(conn *ConnectionInfo) error {
+ p.mutex.Lock()
+ defer p.mutex.Unlock()
+
+ // 存储到SQLite
+ result := p.db.Create(conn)
+ if result.Error != nil {
+ return fmt.Errorf("failed to store connection: %w", result.Error)
+ }
+
+ // 更新内存缓存
+ p.cache[conn.ID] = conn
+ return nil
+}
+
+// Get 获取连接
+func (p *SQLitePool) Get(connID string) (*ConnectionInfo, error) {
+ p.mutex.RLock()
+ // 先从内存缓存获取
+ if conn, ok := p.cache[connID]; ok {
+ p.mutex.RUnlock()
+ return conn, nil
+ }
+ p.mutex.RUnlock()
+
+ // 从SQLite获取
+ var connInfo ConnectionInfo
+ result := p.db.First(&connInfo, "id = ?", connID)
+ if result.Error != nil {
+ if result.Error == gorm.ErrRecordNotFound {
+ return nil, nil
+ }
+ return nil, fmt.Errorf("failed to get connection: %w", result.Error)
+ }
+
+ // 更新内存缓存
+ p.mutex.Lock()
+ p.cache[connID] = &connInfo
+ p.mutex.Unlock()
+
+ return &connInfo, nil
+}
+
+// Remove 移除连接
+func (p *SQLitePool) Remove(connID string) error {
+ p.mutex.Lock()
+ defer p.mutex.Unlock()
+
+ // 从SQLite删除
+ result := p.db.Delete(&ConnectionInfo{}, "id = ?", connID)
+ if result.Error != nil {
+ return fmt.Errorf("failed to remove connection: %w", result.Error)
+ }
+
+ // 从内存缓存删除
+ delete(p.cache, connID)
+ return nil
+}
+
+// Update 更新连接信息
+func (p *SQLitePool) Update(conn *ConnectionInfo) error {
+ p.mutex.Lock()
+ defer p.mutex.Unlock()
+
+ // 存储到SQLite
+ result := p.db.Save(conn)
+ if result.Error != nil {
+ return fmt.Errorf("failed to update connection: %w", result.Error)
+ }
+
+ // 更新内存缓存
+ p.cache[conn.ID] = conn
+ return nil
+}
+
+// GetAll 获取所有连接
+func (p *SQLitePool) GetAll() ([]*ConnectionInfo, error) {
+ p.mutex.RLock()
+ // 如果内存缓存不为空,直接返回缓存
+ if len(p.cache) > 0 {
+ conns := make([]*ConnectionInfo, 0, len(p.cache))
+ for _, conn := range p.cache {
+ conns = append(conns, conn)
+ }
+ p.mutex.RUnlock()
+ return conns, nil
+ }
+ p.mutex.RUnlock()
+
+ // 从SQLite获取所有连接
+ var conns []*ConnectionInfo
+ result := p.db.Find(&conns)
+ if result.Error != nil {
+ return nil, fmt.Errorf("failed to get all connections: %w", result.Error)
+ }
+
+ // 更新内存缓存
+ p.mutex.Lock()
+ for _, conn := range conns {
+ p.cache[conn.ID] = conn
+ }
+ p.mutex.Unlock()
+
+ return conns, nil
+}
+
+// GetByType 根据类型获取连接
+func (p *SQLitePool) GetByType(connType ConnType) ([]*ConnectionInfo, error) {
+ allConns, err := p.GetAll()
+ if err != nil {
+ return nil, err
+ }
+
+ var filtered []*ConnectionInfo
+ for _, conn := range allConns {
+ if conn.Type == connType {
+ filtered = append(filtered, conn)
+ }
+ }
+
+ return filtered, nil
+}
+
+// Count 获取连接数量
+func (p *SQLitePool) Count() (int, error) {
+ p.mutex.RLock()
+ // 如果内存缓存不为空,直接返回缓存大小
+ if len(p.cache) > 0 {
+ count := len(p.cache)
+ p.mutex.RUnlock()
+ return count, nil
+ }
+ p.mutex.RUnlock()
+
+ // 从SQLite统计数量
+ var count int64
+ result := p.db.Model(&ConnectionInfo{}).Count(&count)
+ if result.Error != nil {
+ return 0, fmt.Errorf("failed to count connections: %w", result.Error)
+ }
+
+ return int(count), nil
+}
+
+// GetAllConnIDs 获取所有在线连接的ID列表
+func (p *SQLitePool) GetAllConnIDs() ([]string, error) {
+ p.mutex.RLock()
+ // 如果内存缓存不为空,从缓存中提取在线连接的ID
+ if len(p.cache) > 0 {
+ ids := make([]string, 0, len(p.cache))
+ for id, conn := range p.cache {
+ if conn.IsActive {
+ ids = append(ids, id)
+ }
+ }
+ p.mutex.RUnlock()
+ return ids, nil
+ }
+ p.mutex.RUnlock()
+
+ // 从SQLite获取所有在线连接的ID
+ var conns []*ConnectionInfo
+ result := p.db.Where("is_active = ?", true).Find(&conns)
+ if result.Error != nil {
+ return nil, fmt.Errorf("failed to get all connection IDs: %w", result.Error)
+ }
+
+ ids := make([]string, 0, len(conns))
+ for _, conn := range conns {
+ ids = append(ids, conn.ID)
+ }
+
+ return ids, nil
+}
+
+// CleanupInactive 清理不活跃的连接
+func (p *SQLitePool) CleanupInactive(duration time.Duration) error {
+ allConns, err := p.GetAll()
+ if err != nil {
+ return err
+ }
+
+ now := time.Now()
+ for _, conn := range allConns {
+ if !conn.IsActive || now.Sub(conn.LastUsed) > duration {
+ if err := p.Remove(conn.ID); err != nil {
+ return err
+ }
+ }
+ }
+
+ return nil
+}
diff --git a/tcp/example.go b/tcp/example.go
index cd54038..7f0b8d0 100644
--- a/tcp/example.go
+++ b/tcp/example.go
@@ -18,7 +18,11 @@ func Example() {
}
// 创建TCP服务器
- server := NewTCPServer("0.0.0.0:8888", config)
+ server, err := NewTCPServer("0.0.0.0:8888", config)
+ if err != nil {
+ fmt.Printf("Failed to create server: %v\n", err)
+ return
+ }
// 设置消息处理函数
server.SetMessageHandler(func(conn *TcpConnection, msg *TcpMessage) error {
@@ -45,3 +49,48 @@ func Example() {
fmt.Println("TCP server stopped.")
}
+
+// TestTCP 测试TCP连接
+func TestTCP() {
+ fmt.Println("=== 测试TCP连接 ===")
+ fmt.Println("1. 创建TCP服务器配置")
+ config := &TcpPoolConfig{
+ BufferSize: 2048,
+ MaxConnections: 100000,
+ ConnectTimeout: time.Second * 5,
+ ReadTimeout: time.Second * 30,
+ WriteTimeout: time.Second * 10,
+ MaxIdleTime: time.Minute * 5,
+ }
+ fmt.Println("2. 创建TCP服务器")
+ server, err := NewTCPServer("0.0.0.0:8888", config)
+ if err != nil {
+ fmt.Printf("创建服务器失败:%v\n", err)
+ return
+ }
+ fmt.Println("3. 服务器创建成功")
+ fmt.Println("4. 获取在线连接数")
+ count := server.Connection.Count()
+ fmt.Printf("当前在线连接数:%d\n", count)
+ fmt.Println("5. 获取所有在线连接ID")
+ connIDs, err := server.GetAllConnIDs()
+ if err != nil {
+ fmt.Printf("获取在线连接ID失败:%v\n", err)
+ } else {
+ fmt.Printf("在线连接ID:%v\n", connIDs)
+ }
+ fmt.Println("6. 启动服务器")
+ if err := server.Start(); err != nil {
+ fmt.Printf("启动服务器失败:%v\n", err)
+ return
+ }
+ fmt.Println("7. 服务器启动成功,运行2秒后停止")
+ time.Sleep(time.Second * 2)
+ fmt.Println("8. 停止服务器")
+ if err := server.Stop(); err != nil {
+ fmt.Printf("停止服务器失败:%v\n", err)
+ } else {
+ fmt.Println("服务器停止成功")
+ }
+ fmt.Println("=== TCP测试完成 ===")
+}
diff --git a/tcp/tcp.go b/tcp/tcp.go
index 73b5bd3..a2c1354 100644
--- a/tcp/tcp.go
+++ b/tcp/tcp.go
@@ -6,6 +6,7 @@ import (
"sync"
"time"
+ "git.magicany.cc/black1552/gf-common/pool"
"github.com/gogf/gf/v2/frame/g"
"github.com/gogf/gf/v2/net/gtcp"
"github.com/gogf/gf/v2/os/glog"
@@ -32,18 +33,26 @@ type TCPServer struct {
// ConnectionPool 连接池结构
type ConnectionPool struct {
connections map[string]*TcpConnection
+ sqlitePool *pool.SQLitePool
mutex sync.RWMutex
config *TcpPoolConfig
logger *glog.Logger
}
// NewTCPServer 创建一个新的TCP服务器
-func NewTCPServer(address string, config *TcpPoolConfig) *TCPServer {
+func NewTCPServer(address string, config *TcpPoolConfig) (*TCPServer, error) {
logger := g.Log(address)
ctx, cancel := context.WithCancel(context.Background())
+ // 初始化SQLite连接池
+ sqlitePool, err := pool.NewSQLitePool()
+ if err != nil {
+ return nil, fmt.Errorf("failed to create sqlite pool: %w", err)
+ }
+
pool := &ConnectionPool{
connections: make(map[string]*TcpConnection),
+ sqlitePool: sqlitePool,
config: config,
logger: logger,
}
@@ -58,7 +67,7 @@ func NewTCPServer(address string, config *TcpPoolConfig) *TCPServer {
}
server.Listener = gtcp.NewServer(address, server.handleConnection)
- return server
+ return server, nil
}
// SetMessageHandler 设置消息处理函数
@@ -86,6 +95,11 @@ func (s *TCPServer) Stop() error {
s.Listener.Close()
s.wg.Wait()
s.Connection.Clear()
+ // 关闭SQLite连接池
+ if err := s.Connection.sqlitePool.Close(); err != nil {
+ s.Logger.Error(s.ctx, fmt.Sprintf("Failed to close SQLite pool: %v", err))
+ // 不影响服务器停止,仅记录错误
+ }
s.Logger.Info(s.ctx, "TCP server stopped")
return nil
}
@@ -109,6 +123,23 @@ func (s *TCPServer) handleConnection(conn *gtcp.Conn) {
s.Connection.Add(tcpConn)
s.Logger.Info(s.ctx, fmt.Sprintf("New connection established: %s", connID))
+ // 存储到SQLite
+ connInfo := &pool.ConnectionInfo{
+ ID: connID,
+ Type: pool.ConnTypeTCP,
+ Address: conn.RemoteAddr().String(),
+ IsActive: true,
+ LastUsed: time.Now(),
+ CreatedAt: time.Now(),
+ Data: map[string]interface{}{
+ "localAddress": conn.LocalAddr().String(),
+ },
+ }
+ if err := s.Connection.sqlitePool.Add(connInfo); err != nil {
+ s.Logger.Error(s.ctx, fmt.Sprintf("Failed to store connection to SQLite: %v", err))
+ // 不影响连接建立,仅记录错误
+ }
+
// 启动消息接收协程
go s.receiveMessages(tcpConn)
}
@@ -121,6 +152,11 @@ func (s *TCPServer) receiveMessages(conn *TcpConnection) {
}
s.Connection.Remove(conn.Id)
conn.Server.Close()
+ // 从SQLite移除
+ if err := s.Connection.sqlitePool.Remove(conn.Id); err != nil {
+ s.Logger.Error(s.ctx, fmt.Sprintf("Failed to remove connection from SQLite: %v", err))
+ // 不影响连接关闭,仅记录错误
+ }
s.Logger.Info(s.ctx, fmt.Sprintf("Connection closed: %s", conn.Id))
}()
@@ -142,10 +178,21 @@ func (s *TCPServer) receiveMessages(conn *TcpConnection) {
if n > 0 {
// 更新最后使用时间
+ now := time.Now()
conn.Mutex.Lock()
- conn.LastUsed = time.Now()
+ conn.LastUsed = now
conn.Mutex.Unlock()
+ // 更新SQLite中的连接信息
+ connInfo, err := s.Connection.sqlitePool.Get(conn.Id)
+ if err == nil && connInfo != nil {
+ connInfo.LastUsed = now
+ if err := s.Connection.sqlitePool.Update(connInfo); err != nil {
+ s.Logger.Error(s.ctx, fmt.Sprintf("Failed to update connection in SQLite: %v", err))
+ // 不影响消息处理,仅记录错误
+ }
+ }
+
// 处理消息
data := make([]byte, n)
copy(data, buffer[:n])
@@ -209,7 +256,19 @@ func (s *TCPServer) sendMessage(conn *TcpConnection, data []byte) error {
}
// 更新最后使用时间
- conn.LastUsed = time.Now()
+ now := time.Now()
+ conn.LastUsed = now
+
+ // 更新SQLite中的连接信息
+ connInfo, err := s.Connection.sqlitePool.Get(conn.Id)
+ if err == nil && connInfo != nil {
+ connInfo.LastUsed = now
+ if err := s.Connection.sqlitePool.Update(connInfo); err != nil {
+ s.Logger.Error(s.ctx, fmt.Sprintf("Failed to update connection in SQLite: %v", err))
+ // 不影响消息发送,仅记录错误
+ }
+ }
+
return nil
}
@@ -224,11 +283,21 @@ func (s *TCPServer) Kick(connID string) error {
conn.Server.Close()
// 从连接池移除
s.Connection.Remove(connID)
+ // 从SQLite移除
+ if err := s.Connection.sqlitePool.Remove(connID); err != nil {
+ s.Logger.Error(s.ctx, fmt.Sprintf("Failed to remove connection from SQLite: %v", err))
+ // 不影响连接关闭,仅记录错误
+ }
s.Logger.Info(s.ctx, fmt.Sprintf("Kicked connection: %s", connID))
return nil
}
+// GetAllConnIDs 获取所有在线连接的ID列表
+func (s *TCPServer) GetAllConnIDs() ([]string, error) {
+ return s.Connection.GetAllConnIDs()
+}
+
// Add 添加连接到连接池
func (p *ConnectionPool) Add(conn *TcpConnection) {
p.mutex.Lock()
@@ -278,3 +347,8 @@ func (p *ConnectionPool) Count() int {
defer p.mutex.RUnlock()
return len(p.connections)
}
+
+// GetAllConnIDs 获取所有在线连接的ID列表
+func (p *ConnectionPool) GetAllConnIDs() ([]string, error) {
+ return p.sqlitePool.GetAllConnIDs()
+}
diff --git a/ws/example.go b/ws/example.go
index 1420de9..06fd21f 100644
--- a/ws/example.go
+++ b/ws/example.go
@@ -19,7 +19,10 @@ func NewWs() *Manager {
}
// 2. 创建管理器
- m := NewManager(customConfig)
+ m, err := NewManager(customConfig)
+ if err != nil {
+ log.Fatalf("Failed to create manager: %v", err)
+ }
// 3. 覆盖业务回调(核心:自定义消息处理逻辑)
// 连接建立回调
@@ -71,3 +74,35 @@ func main() {
log.Println("WebSocket服务启动:http://localhost:8080/ws")
log.Fatal(http.ListenAndServe(":8080", nil))
}
+
+// TestWebSocket 测试WebSocket连接
+func TestWebSocket() {
+ log.Println("=== 测试WebSocket连接 ===")
+ log.Println("1. 创建WebSocket管理器")
+ m, err := NewManager(DefaultConfig())
+ if err != nil {
+ log.Fatalf("创建管理器失败:%v", err)
+ }
+ log.Println("2. 管理器创建成功")
+ log.Println("3. 获取在线连接数")
+ count, err := m.sqlitePool.Count()
+ if err != nil {
+ log.Printf("获取在线连接数失败:%v", err)
+ } else {
+ log.Printf("当前在线连接数:%d", count)
+ }
+ log.Println("4. 获取所有在线连接ID")
+ connIDs, err := m.GetAllConnIDs()
+ if err != nil {
+ log.Printf("获取在线连接ID失败:%v", err)
+ } else {
+ log.Printf("在线连接ID:%v", connIDs)
+ }
+ log.Println("5. 关闭管理器")
+ if err := m.Close(); err != nil {
+ log.Printf("关闭管理器失败:%v", err)
+ } else {
+ log.Println("管理器关闭成功")
+ }
+ log.Println("=== WebSocket测试完成 ===")
+}
diff --git a/ws/websocket.go b/ws/websocket.go
index e02c0be..b7c45a2 100644
--- a/ws/websocket.go
+++ b/ws/websocket.go
@@ -9,6 +9,7 @@ import (
"sync"
"time"
+ "git.magicany.cc/black1552/gin-base/pool"
"github.com/gogf/gf/v2/encoding/gjson"
"github.com/gogf/gf/v2/os/gctx"
"github.com/gogf/gf/v2/os/gtime"
@@ -20,20 +21,20 @@ import (
// 常量定义:默认配置
const (
- // DefaultReadBufferSize 默认读写缓冲区大小(字节)
+ // 默认读写缓冲区大小(字节)
DefaultReadBufferSize = 1024
DefaultWriteBufferSize = 1024
- // DefaultHeartbeatInterval 默认心跳间隔(秒):每30秒发送一次心跳
+ // 默认心跳间隔(秒):每30秒发送一次心跳
DefaultHeartbeatInterval = 30 * time.Second
- // DefaultHeartbeatTimeout 默认心跳超时(秒):60秒未收到客户端心跳响应则关闭连接
+ // 默认心跳超时(秒):60秒未收到客户端心跳响应则关闭连接
DefaultHeartbeatTimeout = 60 * time.Second
- // DefaultReadTimeout 默认读写超时(秒)
+ // 默认读写超时(秒)
DefaultReadTimeout = 60 * time.Second
DefaultWriteTimeout = 10 * time.Second
- // MessageTypeText 消息类型
+ // 消息类型
MessageTypeText = websocket.TextMessage
MessageTypeBinary = websocket.BinaryMessage
- // HeartbeatMaxRetry 心跳最大重试次数
+ // 心跳最大重试次数
HeartbeatMaxRetry = 3
)
@@ -92,7 +93,8 @@ type Connection struct {
type Manager struct {
config *Config // 配置
upgrader *websocket.Upgrader // HTTP升级器
- connections map[string]*Connection // 所有在线连接(connID -> Connection)
+ connections map[string]*Connection // 内存中的连接(connID -> Connection)
+ sqlitePool *pool.SQLitePool // SQLite连接池
mutex sync.RWMutex // 读写锁(保护connections)
// 业务回调:收到消息时触发(用户自定义处理逻辑)
OnMessage func(connID string, msgType int, data any)
@@ -148,16 +150,16 @@ func (c *Config) Merge(other *Config) *Config {
}
// NewManager 创建连接管理器
-func NewManager(config *Config) *Manager {
+func NewManager(config *Config) (*Manager, error) {
defaultConfig := DefaultConfig()
finalConfig := defaultConfig.Merge(config)
// 初始化升级器
upgrader := &websocket.Upgrader{
- ReadBufferSize: config.ReadBufferSize,
- WriteBufferSize: config.WriteBufferSize,
+ ReadBufferSize: finalConfig.ReadBufferSize,
+ WriteBufferSize: finalConfig.WriteBufferSize,
CheckOrigin: func(r *http.Request) bool {
// 跨域检查
- if config.AllowAllOrigins {
+ if finalConfig.AllowAllOrigins {
return true
}
origin := r.Header.Get("Origin")
@@ -170,10 +172,17 @@ func NewManager(config *Config) *Manager {
},
}
+ // 初始化SQLite连接池
+ sqlitePool, err := pool.NewSQLitePool()
+ if err != nil {
+ return nil, fmt.Errorf("failed to create sqlite pool: %w", err)
+ }
+
return &Manager{
config: finalConfig,
upgrader: upgrader,
connections: make(map[string]*Connection),
+ sqlitePool: sqlitePool,
mutex: sync.RWMutex{},
// 默认回调(用户可覆盖)
OnMessage: func(connID string, msgType int, data any) {
@@ -185,7 +194,7 @@ func NewManager(config *Config) *Manager {
OnDisconnect: func(connID string, err error) {
log.Printf("[默认回调] 连接[%s]已关闭:%v", connID, err)
},
- }
+ }, nil
}
// Upgrade HTTP升级为WebSocket连接
@@ -237,6 +246,24 @@ func (m *Manager) Upgrade(w http.ResponseWriter, r *http.Request, connID string)
m.connections[connID] = wsConn
m.mutex.Unlock()
+ // 存储到SQLite
+ connInfo := &pool.ConnectionInfo{
+ ID: connID,
+ Type: pool.ConnTypeWebSocket,
+ Address: r.RemoteAddr,
+ IsActive: true,
+ LastUsed: time.Now(),
+ CreatedAt: time.Now(),
+ Data: map[string]interface{}{
+ "origin": r.Header.Get("Origin"),
+ "userAgent": r.Header.Get("User-Agent"),
+ },
+ }
+ if err := m.sqlitePool.Add(connInfo); err != nil {
+ log.Printf("[错误] 存储连接到SQLite失败:%v", err)
+ // 不影响连接建立,仅记录错误
+ }
+
// 触发连接建立回调
m.OnConnect(connID)
@@ -282,6 +309,18 @@ func (c *Connection) ReadPump() {
return
}
+ // 更新最后使用时间
+ now := time.Now()
+ // 从SQLite获取连接信息并更新
+ connInfo, err := c.manager.sqlitePool.Get(c.connID)
+ if err == nil && connInfo != nil {
+ connInfo.LastUsed = now
+ if err := c.manager.sqlitePool.Update(connInfo); err != nil {
+ log.Printf("[错误] 更新SQLite连接信息失败:%v", err)
+ // 不影响消息处理,仅记录错误
+ }
+ }
+
// 尝试解析JSON格式的心跳消息(精准判断,替代包含判断)
isHeartbeat := false
// 先尝试解析为JSON对象
@@ -369,6 +408,19 @@ func (c *Connection) Send(data []byte) error {
if err != nil {
return fmt.Errorf("发送消息失败:%w", err)
}
+
+ // 更新最后使用时间
+ now := time.Now()
+ // 从SQLite获取连接信息并更新
+ connInfo, err := c.manager.sqlitePool.Get(c.connID)
+ if err == nil && connInfo != nil {
+ connInfo.LastUsed = now
+ if err := c.manager.sqlitePool.Update(connInfo); err != nil {
+ log.Printf("[错误] 更新SQLite连接信息失败:%v", err)
+ // 不影响消息发送,仅记录错误
+ }
+ }
+
return nil
}
}
@@ -394,6 +446,12 @@ func (c *Connection) Close(err error) {
delete(c.manager.connections, c.connID)
c.manager.mutex.Unlock()
+ // 从SQLite移除
+ if err := c.manager.sqlitePool.Remove(c.connID); err != nil {
+ log.Printf("[错误] 从SQLite移除连接失败:%v", err)
+ // 不影响连接关闭,仅记录错误
+ }
+
// 触发断开回调
c.manager.OnDisconnect(c.connID, err)
@@ -462,12 +520,18 @@ func (m *Manager) GetAllConn() map[string]*Connection {
return connCopy
}
+// GetConn 获取指定连接
func (m *Manager) GetConn(connID string) *Connection {
m.mutex.RLock()
defer m.mutex.RUnlock()
return m.connections[connID]
}
+// GetAllConnIDs 获取所有在线连接的ID列表
+func (m *Manager) GetAllConnIDs() ([]string, error) {
+ return m.sqlitePool.GetAllConnIDs()
+}
+
// CloseAll 关闭所有连接
func (m *Manager) CloseAll() {
m.mutex.RLock()
@@ -486,3 +550,14 @@ func (m *Manager) CloseAll() {
}
}
}
+
+// Close 关闭管理器,清理资源
+func (m *Manager) Close() error {
+ // 关闭所有连接
+ m.CloseAll()
+ // 关闭SQLite连接池
+ if m.sqlitePool != nil {
+ return m.sqlitePool.Close()
+ }
+ return nil
+}