diff --git a/pool/badger.go b/pool/badger.go index 9438a0b..16ac8c7 100644 --- a/pool/badger.go +++ b/pool/badger.go @@ -215,6 +215,51 @@ func (p *BadgerPool) GetAll() ([]*ConnectionInfo, error) { return conns, nil } +// GetAllConnIDs 获取所有在线连接的ID列表 +func (p *BadgerPool) GetAllConnIDs() ([]string, error) { + p.mutex.RLock() + // 如果内存缓存不为空,从缓存中提取在线连接的ID + if len(p.cache) > 0 { + ids := make([]string, 0, len(p.cache)) + for id, conn := range p.cache { + if conn.IsActive { + ids = append(ids, id) + } + } + p.mutex.RUnlock() + return ids, nil + } + p.mutex.RUnlock() + + // 从BadgerDB获取所有在线连接的ID + var ids []string + err := p.db.View(func(txn *badger.Txn) error { + opts := badger.DefaultIteratorOptions + opts.PrefetchSize = 10 + it := txn.NewIterator(opts) + defer it.Close() + + for it.Rewind(); it.Valid(); it.Next() { + item := it.Item() + var connInfo ConnectionInfo + err := item.Value(func(val []byte) error { + return json.Unmarshal(val, &connInfo) + }) + if err != nil { + return err + } + if connInfo.IsActive { + ids = append(ids, string(item.Key())) + } + } + return nil + }) + if err != nil { + return nil, fmt.Errorf("failed to get all connections: %w", err) + } + return ids, nil +} + // GetByType 根据类型获取连接 func (p *BadgerPool) GetByType(connType ConnType) ([]*ConnectionInfo, error) { allConns, err := p.GetAll()