package pool import ( "context" "encoding/json" "fmt" "sync" "time" "github.com/gogf/gf/v2/os/gfile" "github.com/nutsdb/nutsdb" ) // ConnType 连接类型 type ConnType string const ( ConnTypeWebSocket ConnType = "websocket" ConnTypeTCP ConnType = "tcp" ) // ConnectionInfo 连接信息 type ConnectionInfo struct { ID string `json:"id"` Type ConnType `json:"type"` Address string `json:"address"` IsActive bool `json:"isActive"` LastUsed time.Time `json:"lastUsed"` CreatedAt time.Time `json:"createdAt"` // 额外的连接数据,根据不同类型存储不同的信息 Data map[string]interface{} `json:"data"` } // NutsPool NutsDB连接池 type NutsPool struct { db *nutsdb.DB bucket string mutex sync.RWMutex ctx context.Context cancel context.CancelFunc // 内存缓存,提高并发性能 cache map[string]*ConnectionInfo } // NewNutsPool 创建NutsDB连接池 func NewNutsPool() (*NutsPool, error) { ctx, cancel := context.WithCancel(context.Background()) // 使用当前运行项目的根目录的nuts文件夹作为存储路径,并添加时间戳以避免锁定问题 dir := gfile.Pwd() + "/nuts" // 打开NutsDB db, err := nutsdb.Open( nutsdb.DefaultOptions, nutsdb.WithDir(dir), ) if err != nil { cancel() return nil, fmt.Errorf("failed to open nutsdb: %w", err) } return &NutsPool{ db: db, bucket: "connections", ctx: ctx, cancel: cancel, cache: make(map[string]*ConnectionInfo), }, nil } // Close 关闭连接池 func (p *NutsPool) Close() error { p.cancel() return p.db.Close() } // Add 添加连接 func (p *NutsPool) Add(conn *ConnectionInfo) error { p.mutex.Lock() defer p.mutex.Unlock() // 序列化连接信息 data, err := json.Marshal(conn) if err != nil { return fmt.Errorf("failed to marshal connection info: %w", err) } // 存储到NutsDB err = p.db.Update(func(tx *nutsdb.Tx) error { return tx.Put(p.bucket, []byte(conn.ID), data, 0) }) if err != nil { return fmt.Errorf("failed to store connection: %w", err) } // 更新内存缓存 p.cache[conn.ID] = conn return nil } // Get 获取连接 func (p *NutsPool) Get(connID string) (*ConnectionInfo, error) { p.mutex.RLock() // 先从内存缓存获取 if conn, ok := p.cache[connID]; ok { p.mutex.RUnlock() return conn, nil } p.mutex.RUnlock() // 从NutsDB获取 var connInfo ConnectionInfo err := p.db.View(func(tx *nutsdb.Tx) error { data, err := tx.Get(p.bucket, []byte(connID)) if err != nil { return err } return json.Unmarshal(data, &connInfo) }) if err != nil { if err == nutsdb.ErrKeyNotFound { return nil, nil } return nil, fmt.Errorf("failed to get connection: %w", err) } // 更新内存缓存 p.mutex.Lock() p.cache[connID] = &connInfo p.mutex.Unlock() return &connInfo, nil } // Remove 移除连接 func (p *NutsPool) Remove(connID string) error { p.mutex.Lock() defer p.mutex.Unlock() // 从NutsDB删除 err := p.db.Update(func(tx *nutsdb.Tx) error { return tx.Delete(p.bucket, []byte(connID)) }) if err != nil { return fmt.Errorf("failed to remove connection: %w", err) } // 从内存缓存删除 delete(p.cache, connID) return nil } // Update 更新连接信息 func (p *NutsPool) Update(conn *ConnectionInfo) error { p.mutex.Lock() defer p.mutex.Unlock() // 序列化连接信息 data, err := json.Marshal(conn) if err != nil { return fmt.Errorf("failed to marshal connection info: %w", err) } // 存储到NutsDB err = p.db.Update(func(tx *nutsdb.Tx) error { return tx.Put(p.bucket, []byte(conn.ID), data, 0) }) if err != nil { return fmt.Errorf("failed to update connection: %w", err) } // 更新内存缓存 p.cache[conn.ID] = conn return nil } // GetAll 获取所有连接 func (p *NutsPool) GetAll() ([]*ConnectionInfo, error) { p.mutex.RLock() // 如果内存缓存不为空,直接返回缓存 if len(p.cache) > 0 { conns := make([]*ConnectionInfo, 0, len(p.cache)) for _, conn := range p.cache { conns = append(conns, conn) } p.mutex.RUnlock() return conns, nil } p.mutex.RUnlock() // 从NutsDB获取所有连接 var conns []*ConnectionInfo err := p.db.View(func(tx *nutsdb.Tx) error { keys, _, err := tx.GetAll(p.bucket) if err != nil { return err } for _, key := range keys { val, err := tx.Get(p.bucket, key) if err != nil { return err } var connInfo ConnectionInfo if err := json.Unmarshal(val, &connInfo); err != nil { return err } conns = append(conns, &connInfo) } return nil }) if err != nil { return nil, fmt.Errorf("failed to get all connections: %w", err) } // 更新内存缓存 p.mutex.Lock() for _, conn := range conns { p.cache[conn.ID] = conn } p.mutex.Unlock() return conns, nil } // GetByType 根据类型获取连接 func (p *NutsPool) GetByType(connType ConnType) ([]*ConnectionInfo, error) { allConns, err := p.GetAll() if err != nil { return nil, err } var filtered []*ConnectionInfo for _, conn := range allConns { if conn.Type == connType { filtered = append(filtered, conn) } } return filtered, nil } // Count 获取连接数量 func (p *NutsPool) Count() (int, error) { p.mutex.RLock() // 如果内存缓存不为空,直接返回缓存大小 if len(p.cache) > 0 { count := len(p.cache) p.mutex.RUnlock() return count, nil } p.mutex.RUnlock() // 从NutsDB统计数量 var count int err := p.db.View(func(tx *nutsdb.Tx) error { entries, _, err := tx.GetAll(p.bucket) if err != nil { return err } count = len(entries) return nil }) if err != nil { return 0, fmt.Errorf("failed to count connections: %w", err) } return count, nil } // GetAllConnIDs 获取所有在线连接的ID列表 func (p *NutsPool) GetAllConnIDs() ([]string, error) { p.mutex.RLock() // 如果内存缓存不为空,从缓存中提取在线连接的ID if len(p.cache) > 0 { ids := make([]string, 0, len(p.cache)) for id, conn := range p.cache { if conn.IsActive { ids = append(ids, id) } } p.mutex.RUnlock() return ids, nil } p.mutex.RUnlock() // 从NutsDB获取所有在线连接的ID var ids []string err := p.db.View(func(tx *nutsdb.Tx) error { keys, _, err := tx.GetAll(p.bucket) if err != nil { return err } for _, key := range keys { val, err := tx.Get(p.bucket, key) if err != nil { return err } var connInfo ConnectionInfo if err := json.Unmarshal(val, &connInfo); err != nil { return err } if connInfo.IsActive { ids = append(ids, string(key)) } } return nil }) if err != nil { return nil, fmt.Errorf("failed to get all connection IDs: %w", err) } return ids, nil } // CleanupInactive 清理不活跃的连接 func (p *NutsPool) CleanupInactive(duration time.Duration) error { allConns, err := p.GetAll() if err != nil { return err } now := time.Now() for _, conn := range allConns { if !conn.IsActive || now.Sub(conn.LastUsed) > duration { if err := p.Remove(conn.ID); err != nil { return err } } } return nil }