package pool import ( "context" "encoding/json" "fmt" "sync" "time" "github.com/dgraph-io/badger/v4" ) // ConnType 连接类型 type ConnType string const ( ConnTypeWebSocket ConnType = "websocket" ConnTypeTCP ConnType = "tcp" ) // ConnectionInfo 连接信息 type ConnectionInfo struct { ID string `json:"id"` Type ConnType `json:"type"` Address string `json:"address"` IsActive bool `json:"isActive"` LastUsed time.Time `json:"lastUsed"` CreatedAt time.Time `json:"createdAt"` // 额外的连接数据,根据不同类型存储不同的信息 Data map[string]interface{} `json:"data"` } // BadgerPool BadgerDB连接池 type BadgerPool struct { db *badger.DB mutex sync.RWMutex ctx context.Context cancel context.CancelFunc // 内存缓存,提高并发性能 cache map[string]*ConnectionInfo } // NewBadgerPool 创建BadgerDB连接池 func NewBadgerPool(badgerDir string) (*BadgerPool, error) { ctx, cancel := context.WithCancel(context.Background()) db, err := badger.Open(badger.DefaultOptions(badgerDir)) if err != nil { cancel() return nil, fmt.Errorf("failed to open badger db: %w", err) } return &BadgerPool{ db: db, ctx: ctx, cancel: cancel, cache: make(map[string]*ConnectionInfo), }, nil } // Close 关闭连接池 func (p *BadgerPool) Close() error { p.cancel() return p.db.Close() } // Add 添加连接 func (p *BadgerPool) Add(conn *ConnectionInfo) error { p.mutex.Lock() defer p.mutex.Unlock() // 序列化连接信息 data, err := json.Marshal(conn) if err != nil { return fmt.Errorf("failed to marshal connection info: %w", err) } // 存储到BadgerDB err = p.db.Update(func(txn *badger.Txn) error { return txn.Set([]byte(conn.ID), data) }) if err != nil { return fmt.Errorf("failed to store connection: %w", err) } // 更新内存缓存 p.cache[conn.ID] = conn return nil } // Get 获取连接 func (p *BadgerPool) Get(connID string) (*ConnectionInfo, error) { p.mutex.RLock() // 先从内存缓存获取 if conn, ok := p.cache[connID]; ok { p.mutex.RUnlock() return conn, nil } p.mutex.RUnlock() // 从BadgerDB获取 var connInfo ConnectionInfo err := p.db.View(func(txn *badger.Txn) error { item, err := txn.Get([]byte(connID)) if err != nil { return err } return item.Value(func(val []byte) error { return json.Unmarshal(val, &connInfo) }) }) if err != nil { if err == badger.ErrKeyNotFound { return nil, nil } return nil, fmt.Errorf("failed to get connection: %w", err) } // 更新内存缓存 p.mutex.Lock() p.cache[connID] = &connInfo p.mutex.Unlock() return &connInfo, nil } // Remove 移除连接 func (p *BadgerPool) Remove(connID string) error { p.mutex.Lock() defer p.mutex.Unlock() // 从BadgerDB删除 err := p.db.Update(func(txn *badger.Txn) error { return txn.Delete([]byte(connID)) }) if err != nil { return fmt.Errorf("failed to remove connection: %w", err) } // 从内存缓存删除 delete(p.cache, connID) return nil } // Update 更新连接信息 func (p *BadgerPool) Update(conn *ConnectionInfo) error { p.mutex.Lock() defer p.mutex.Unlock() // 序列化连接信息 data, err := json.Marshal(conn) if err != nil { return fmt.Errorf("failed to marshal connection info: %w", err) } // 存储到BadgerDB err = p.db.Update(func(txn *badger.Txn) error { return txn.Set([]byte(conn.ID), data) }) if err != nil { return fmt.Errorf("failed to update connection: %w", err) } // 更新内存缓存 p.cache[conn.ID] = conn return nil } // GetAll 获取所有连接 func (p *BadgerPool) GetAll() ([]*ConnectionInfo, error) { p.mutex.RLock() // 如果内存缓存不为空,直接返回缓存 if len(p.cache) > 0 { conns := make([]*ConnectionInfo, 0, len(p.cache)) for _, conn := range p.cache { conns = append(conns, conn) } p.mutex.RUnlock() return conns, nil } p.mutex.RUnlock() // 从BadgerDB获取所有连接 var conns []*ConnectionInfo err := p.db.View(func(txn *badger.Txn) error { opts := badger.DefaultIteratorOptions opts.PrefetchSize = 10 it := txn.NewIterator(opts) defer it.Close() for it.Rewind(); it.Valid(); it.Next() { item := it.Item() var connInfo ConnectionInfo err := item.Value(func(val []byte) error { return json.Unmarshal(val, &connInfo) }) if err != nil { return err } conns = append(conns, &connInfo) } return nil }) if err != nil { return nil, fmt.Errorf("failed to get all connections: %w", err) } // 更新内存缓存 p.mutex.Lock() for _, conn := range conns { p.cache[conn.ID] = conn } p.mutex.Unlock() return conns, nil } // GetAllConnIDs 获取所有在线连接的ID列表 func (p *BadgerPool) GetAllConnIDs() ([]string, error) { p.mutex.RLock() // 如果内存缓存不为空,从缓存中提取在线连接的ID if len(p.cache) > 0 { ids := make([]string, 0, len(p.cache)) for id, conn := range p.cache { if conn.IsActive { ids = append(ids, id) } } p.mutex.RUnlock() return ids, nil } p.mutex.RUnlock() // 从BadgerDB获取所有在线连接的ID var ids []string err := p.db.View(func(txn *badger.Txn) error { opts := badger.DefaultIteratorOptions opts.PrefetchSize = 10 it := txn.NewIterator(opts) defer it.Close() for it.Rewind(); it.Valid(); it.Next() { item := it.Item() var connInfo ConnectionInfo err := item.Value(func(val []byte) error { return json.Unmarshal(val, &connInfo) }) if err != nil { return err } if connInfo.IsActive { ids = append(ids, string(item.Key())) } } return nil }) if err != nil { return nil, fmt.Errorf("failed to get all connections: %w", err) } return ids, nil } // GetByType 根据类型获取连接 func (p *BadgerPool) GetByType(connType ConnType) ([]*ConnectionInfo, error) { allConns, err := p.GetAll() if err != nil { return nil, err } var filtered []*ConnectionInfo for _, conn := range allConns { if conn.Type == connType { filtered = append(filtered, conn) } } return filtered, nil } // Count 获取连接数量 func (p *BadgerPool) Count() (int, error) { p.mutex.RLock() // 如果内存缓存不为空,直接返回缓存大小 if len(p.cache) > 0 { count := len(p.cache) p.mutex.RUnlock() return count, nil } p.mutex.RUnlock() // 从BadgerDB统计数量 var count int err := p.db.View(func(txn *badger.Txn) error { opts := badger.DefaultIteratorOptions opts.PrefetchSize = 10 it := txn.NewIterator(opts) defer it.Close() for it.Rewind(); it.Valid(); it.Next() { count++ } return nil }) if err != nil { return 0, fmt.Errorf("failed to count connections: %w", err) } return count, nil } // CleanupInactive 清理不活跃的连接 func (p *BadgerPool) CleanupInactive(duration time.Duration) error { allConns, err := p.GetAll() if err != nil { return err } now := time.Now() for _, conn := range allConns { if !conn.IsActive || now.Sub(conn.LastUsed) > duration { if err := p.Remove(conn.ID); err != nil { return err } } } return nil }