gf-common/pool/badger.go

332 lines
7.0 KiB
Go
Raw Blame History

This file contains ambiguous Unicode characters!

This file contains ambiguous Unicode characters that may be confused with others in your current locale. If your use case is intentional and legitimate, you can safely ignore this warning. Use the Escape button to highlight these characters.

package pool
import (
"context"
"encoding/json"
"fmt"
"path/filepath"
"sync"
"time"
"github.com/dgraph-io/badger/v4"
"github.com/gogf/gf/v2/os/gfile"
)
// ConnType identifies the transport used by a pooled connection.
type ConnType string

// Known connection types.
const (
	// ConnTypeWebSocket marks a WebSocket connection.
	ConnTypeWebSocket ConnType = "websocket"
	// ConnTypeTCP marks a plain TCP connection.
	ConnTypeTCP ConnType = "tcp"
)
// ConnectionInfo is the record kept for one pooled connection. The pool
// stores it JSON-encoded in BadgerDB under its ID.
type ConnectionInfo struct {
	// Identity and addressing.
	ID      string   `json:"id"`
	Type    ConnType `json:"type"`
	Address string   `json:"address"`

	// Liveness bookkeeping.
	IsActive  bool      `json:"isActive"`
	LastUsed  time.Time `json:"lastUsed"`
	CreatedAt time.Time `json:"createdAt"`

	// Data holds extra connection data whose contents differ per
	// connection type.
	Data map[string]interface{} `json:"data"`
}
// BadgerPool is a connection registry persisted in BadgerDB and fronted by
// an in-memory cache of decoded records.
type BadgerPool struct {
	// db is the backing store: JSON-encoded ConnectionInfo values keyed by
	// connection ID.
	db *badger.DB
	// mutex guards cache. The pool's writers also hold it across their
	// BadgerDB update so that the store and the cache change together.
	mutex sync.RWMutex
	// ctx is the pool's lifetime context; Close calls cancel to end it.
	// NOTE(review): nothing in this file appears to read ctx — confirm
	// whether other code depends on it.
	ctx    context.Context
	cancel context.CancelFunc
	// cache maps connection ID to its decoded record so that hot lookups
	// avoid a BadgerDB read.
	cache map[string]*ConnectionInfo
}
// NewBadgerPool 创建BadgerDB连接池
func NewBadgerPool() (*BadgerPool, error) {
ctx, cancel := context.WithCancel(context.Background())
db, err := badger.Open(badger.DefaultOptions(filepath.Join(gfile.Pwd(), "badger")))
if err != nil {
cancel()
return nil, fmt.Errorf("failed to open badger db: %w", err)
}
return &BadgerPool{
db: db,
ctx: ctx,
cancel: cancel,
cache: make(map[string]*ConnectionInfo),
}, nil
}
// Close shuts the pool down. It cancels the pool's context first, then
// closes the BadgerDB store and reports the result of that close.
func (p *BadgerPool) Close() error {
	p.cancel()
	if err := p.db.Close(); err != nil {
		return err
	}
	return nil
}
// Add persists conn in BadgerDB under conn.ID and caches it. A record that
// already exists under the same ID is overwritten.
//
// The cache keeps the caller's pointer. Later in-place changes to *conn are
// visible through Get/GetAll but are not persisted until Update is called.
// A nil conn is reported as an error instead of panicking.
func (p *BadgerPool) Add(conn *ConnectionInfo) error {
	if conn == nil {
		return fmt.Errorf("failed to store connection: nil connection info")
	}

	// Encoding only reads conn, so do it before taking the pool lock. This
	// keeps the critical section down to the DB write and the cache update.
	data, err := json.Marshal(conn)
	if err != nil {
		return fmt.Errorf("failed to marshal connection info: %w", err)
	}

	p.mutex.Lock()
	defer p.mutex.Unlock()

	err = p.db.Update(func(txn *badger.Txn) error {
		return txn.Set([]byte(conn.ID), data)
	})
	if err != nil {
		return fmt.Errorf("failed to store connection: %w", err)
	}

	// Cache only after the store accepted the write, so the cache never
	// holds a record that failed to persist.
	p.cache[conn.ID] = conn
	return nil
}
// Get returns the connection stored under connID, or (nil, nil) when no
// such record exists.
//
// Hits are served from the in-memory cache. A miss reads BadgerDB and
// caches the decoded record. The returned pointer is the cached instance,
// shared with every other caller; persist changes through Update rather
// than mutating it in place.
func (p *BadgerPool) Get(connID string) (*ConnectionInfo, error) {
	p.mutex.RLock()
	conn, ok := p.cache[connID]
	p.mutex.RUnlock()
	if ok {
		return conn, nil
	}

	// The miss path runs entirely under the write lock. If BadgerDB were
	// read unlocked and the result cached afterwards, a Remove committing in
	// between would be undone: the deleted record would be re-inserted into
	// the cache and then served by Get/GetAll indefinitely.
	p.mutex.Lock()
	defer p.mutex.Unlock()

	// Another goroutine may have filled the entry while we waited.
	if conn, ok := p.cache[connID]; ok {
		return conn, nil
	}

	var connInfo ConnectionInfo
	err := p.db.View(func(txn *badger.Txn) error {
		item, err := txn.Get([]byte(connID))
		if err != nil {
			return err
		}
		return item.Value(func(val []byte) error {
			return json.Unmarshal(val, &connInfo)
		})
	})
	if err == badger.ErrKeyNotFound {
		return nil, nil
	}
	if err != nil {
		return nil, fmt.Errorf("failed to get connection: %w", err)
	}

	p.cache[connID] = &connInfo
	return &connInfo, nil
}
// Remove deletes the record stored under connID from BadgerDB. Only if
// BadgerDB reports success does it also evict the entry from the cache.
// Both steps run under the pool's write lock.
func (p *BadgerPool) Remove(connID string) error {
	key := []byte(connID)

	p.mutex.Lock()
	defer p.mutex.Unlock()

	if err := p.db.Update(func(txn *badger.Txn) error {
		return txn.Delete(key)
	}); err != nil {
		return fmt.Errorf("failed to remove connection: %w", err)
	}

	delete(p.cache, connID)
	return nil
}
// Update overwrites the stored record for conn.ID and refreshes the cache.
// No existing record is required: an unknown ID is inserted, just as Add
// would do.
//
// The cache keeps the caller's pointer. A nil conn is reported as an error
// instead of panicking.
func (p *BadgerPool) Update(conn *ConnectionInfo) error {
	if conn == nil {
		return fmt.Errorf("failed to update connection: nil connection info")
	}

	// Encoding only reads conn, so do it before taking the pool lock. This
	// keeps the critical section down to the DB write and the cache update.
	data, err := json.Marshal(conn)
	if err != nil {
		return fmt.Errorf("failed to marshal connection info: %w", err)
	}

	p.mutex.Lock()
	defer p.mutex.Unlock()

	err = p.db.Update(func(txn *badger.Txn) error {
		return txn.Set([]byte(conn.ID), data)
	})
	if err != nil {
		return fmt.Errorf("failed to update connection: %w", err)
	}

	// Refresh the cache only after the store accepted the write.
	p.cache[conn.ID] = conn
	return nil
}
// GetAll returns every known connection.
//
// A non-empty cache is treated as authoritative and returned directly, in
// map iteration order. Otherwise all records are decoded from BadgerDB and
// added to the cache. The returned pointers are shared with the cache.
func (p *BadgerPool) GetAll() ([]*ConnectionInfo, error) {
	// Snapshot the cache under the read lock; nil means the cache was empty.
	cached := func() []*ConnectionInfo {
		p.mutex.RLock()
		defer p.mutex.RUnlock()
		if len(p.cache) == 0 {
			return nil
		}
		out := make([]*ConnectionInfo, 0, len(p.cache))
		for _, c := range p.cache {
			out = append(out, c)
		}
		return out
	}()
	if cached != nil {
		return cached, nil
	}

	var loaded []*ConnectionInfo
	err := p.db.View(func(txn *badger.Txn) error {
		iterOpts := badger.DefaultIteratorOptions
		iterOpts.PrefetchSize = 10
		iter := txn.NewIterator(iterOpts)
		defer iter.Close()
		for iter.Rewind(); iter.Valid(); iter.Next() {
			info := new(ConnectionInfo)
			if err := iter.Item().Value(func(raw []byte) error {
				return json.Unmarshal(raw, info)
			}); err != nil {
				return err
			}
			loaded = append(loaded, info)
		}
		return nil
	})
	if err != nil {
		return nil, fmt.Errorf("failed to get all connections: %w", err)
	}

	p.mutex.Lock()
	for _, info := range loaded {
		p.cache[info.ID] = info
	}
	p.mutex.Unlock()

	return loaded, nil
}
// GetByType returns the connections from GetAll whose Type equals
// connType. The result is nil when nothing matches.
func (p *BadgerPool) GetByType(connType ConnType) ([]*ConnectionInfo, error) {
	all, err := p.GetAll()
	if err != nil {
		return nil, err
	}

	var matches []*ConnectionInfo
	for i := range all {
		if all[i].Type != connType {
			continue
		}
		matches = append(matches, all[i])
	}
	return matches, nil
}
// Count returns the number of stored connections.
//
// A non-empty cache is trusted and its size is returned. Otherwise the keys
// in BadgerDB are counted; this path does not populate the cache.
func (p *BadgerPool) Count() (int, error) {
	p.mutex.RLock()
	cached := len(p.cache)
	p.mutex.RUnlock()
	if cached > 0 {
		return cached, nil
	}

	var count int
	err := p.db.View(func(txn *badger.Txn) error {
		opts := badger.DefaultIteratorOptions
		// Counting needs keys only, so use key-only iteration. Prefetching
		// values would load data that is never read.
		opts.PrefetchValues = false
		it := txn.NewIterator(opts)
		defer it.Close()
		for it.Rewind(); it.Valid(); it.Next() {
			count++
		}
		return nil
	})
	if err != nil {
		return 0, fmt.Errorf("failed to count connections: %w", err)
	}
	return count, nil
}
// GetAllConnIDs returns the IDs of all connections whose IsActive flag is
// set.
//
// A non-empty cache is treated as authoritative and scanned. The result is
// then a non-nil, possibly empty slice. Otherwise every record in BadgerDB
// is decoded and checked; that path returns nil when no connection is
// active and does not populate the cache.
func (p *BadgerPool) GetAllConnIDs() ([]string, error) {
	p.mutex.RLock()
	if len(p.cache) != 0 {
		defer p.mutex.RUnlock()
		active := make([]string, 0, len(p.cache))
		for id, c := range p.cache {
			if !c.IsActive {
				continue
			}
			active = append(active, id)
		}
		return active, nil
	}
	p.mutex.RUnlock()

	var active []string
	err := p.db.View(func(txn *badger.Txn) error {
		iterOpts := badger.DefaultIteratorOptions
		iterOpts.PrefetchSize = 10
		iter := txn.NewIterator(iterOpts)
		defer iter.Close()
		for iter.Rewind(); iter.Valid(); iter.Next() {
			item := iter.Item()
			var info ConnectionInfo
			if err := item.Value(func(raw []byte) error {
				return json.Unmarshal(raw, &info)
			}); err != nil {
				return err
			}
			if info.IsActive {
				active = append(active, string(item.Key()))
			}
		}
		return nil
	})
	if err != nil {
		return nil, fmt.Errorf("failed to get all connection IDs: %w", err)
	}
	return active, nil
}
// CleanupInactive removes every connection that is flagged inactive, or
// whose LastUsed is more than duration before now. Removal stops at the
// first failure, whose error is returned; connections removed before that
// point stay removed.
func (p *BadgerPool) CleanupInactive(duration time.Duration) error {
	conns, err := p.GetAll()
	if err != nil {
		return err
	}

	now := time.Now()
	for _, c := range conns {
		stale := now.Sub(c.LastUsed) > duration
		if c.IsActive && !stale {
			continue
		}
		if err := p.Remove(c.ID); err != nil {
			return err
		}
	}
	return nil
}