286 lines
6.0 KiB
Go
286 lines
6.0 KiB
Go
package pool
|
|
|
|
import (
|
|
"context"
|
|
"encoding/json"
|
|
"fmt"
|
|
"path/filepath"
|
|
"sync"
|
|
"time"
|
|
|
|
"github.com/dgraph-io/badger/v4"
|
|
"github.com/gogf/gf/v2/os/gfile"
|
|
)
|
|
|
|
// ConnType names the kind of transport a pooled connection uses.
type ConnType string

// Connection types recognised by the pool.
const (
	// ConnTypeWebSocket is the type value for WebSocket connections.
	ConnTypeWebSocket = ConnType("websocket")
	// ConnTypeTCP is the type value for TCP connections.
	ConnTypeTCP = ConnType("tcp")
)
|
|
|
|
// ConnectionInfo 连接信息
|
|
type ConnectionInfo struct {
|
|
ID string `json:"id"`
|
|
Type ConnType `json:"type"`
|
|
Address string `json:"address"`
|
|
IsActive bool `json:"isActive"`
|
|
LastUsed time.Time `json:"lastUsed"`
|
|
CreatedAt time.Time `json:"createdAt"`
|
|
// 额外的连接数据,根据不同类型存储不同的信息
|
|
Data map[string]interface{} `json:"data"`
|
|
}
|
|
|
|
// BadgerPool BadgerDB连接池
|
|
type BadgerPool struct {
|
|
db *badger.DB
|
|
mutex sync.RWMutex
|
|
ctx context.Context
|
|
cancel context.CancelFunc
|
|
// 内存缓存,提高并发性能
|
|
cache map[string]*ConnectionInfo
|
|
}
|
|
|
|
// NewBadgerPool 创建BadgerDB连接池
|
|
func NewBadgerPool(badgerDir string) (*BadgerPool, error) {
|
|
ctx, cancel := context.WithCancel(context.Background())
|
|
|
|
db, err := badger.Open(badger.DefaultOptions(badgerDir))
|
|
if err != nil {
|
|
cancel()
|
|
return nil, fmt.Errorf("failed to open badger db: %w", err)
|
|
}
|
|
|
|
return &BadgerPool{
|
|
db: db,
|
|
ctx: ctx,
|
|
cancel: cancel,
|
|
cache: make(map[string]*ConnectionInfo),
|
|
}, nil
|
|
}
|
|
|
|
// Close 关闭连接池
|
|
func (p *BadgerPool) Close() error {
|
|
p.cancel()
|
|
return p.db.Close()
|
|
}
|
|
|
|
// Add 添加连接
|
|
func (p *BadgerPool) Add(conn *ConnectionInfo) error {
|
|
p.mutex.Lock()
|
|
defer p.mutex.Unlock()
|
|
|
|
// 序列化连接信息
|
|
data, err := json.Marshal(conn)
|
|
if err != nil {
|
|
return fmt.Errorf("failed to marshal connection info: %w", err)
|
|
}
|
|
|
|
// 存储到BadgerDB
|
|
err = p.db.Update(func(txn *badger.Txn) error {
|
|
return txn.Set([]byte(conn.ID), data)
|
|
})
|
|
if err != nil {
|
|
return fmt.Errorf("failed to store connection: %w", err)
|
|
}
|
|
|
|
// 更新内存缓存
|
|
p.cache[conn.ID] = conn
|
|
return nil
|
|
}
|
|
|
|
// Get 获取连接
|
|
func (p *BadgerPool) Get(connID string) (*ConnectionInfo, error) {
|
|
p.mutex.RLock()
|
|
// 先从内存缓存获取
|
|
if conn, ok := p.cache[connID]; ok {
|
|
p.mutex.RUnlock()
|
|
return conn, nil
|
|
}
|
|
p.mutex.RUnlock()
|
|
|
|
// 从BadgerDB获取
|
|
var connInfo ConnectionInfo
|
|
err := p.db.View(func(txn *badger.Txn) error {
|
|
item, err := txn.Get([]byte(connID))
|
|
if err != nil {
|
|
return err
|
|
}
|
|
return item.Value(func(val []byte) error {
|
|
return json.Unmarshal(val, &connInfo)
|
|
})
|
|
})
|
|
if err != nil {
|
|
if err == badger.ErrKeyNotFound {
|
|
return nil, nil
|
|
}
|
|
return nil, fmt.Errorf("failed to get connection: %w", err)
|
|
}
|
|
|
|
// 更新内存缓存
|
|
p.mutex.Lock()
|
|
p.cache[connID] = &connInfo
|
|
p.mutex.Unlock()
|
|
|
|
return &connInfo, nil
|
|
}
|
|
|
|
// Remove 移除连接
|
|
func (p *BadgerPool) Remove(connID string) error {
|
|
p.mutex.Lock()
|
|
defer p.mutex.Unlock()
|
|
|
|
// 从BadgerDB删除
|
|
err := p.db.Update(func(txn *badger.Txn) error {
|
|
return txn.Delete([]byte(connID))
|
|
})
|
|
if err != nil {
|
|
return fmt.Errorf("failed to remove connection: %w", err)
|
|
}
|
|
|
|
// 从内存缓存删除
|
|
delete(p.cache, connID)
|
|
return nil
|
|
}
|
|
|
|
// Update 更新连接信息
|
|
func (p *BadgerPool) Update(conn *ConnectionInfo) error {
|
|
p.mutex.Lock()
|
|
defer p.mutex.Unlock()
|
|
|
|
// 序列化连接信息
|
|
data, err := json.Marshal(conn)
|
|
if err != nil {
|
|
return fmt.Errorf("failed to marshal connection info: %w", err)
|
|
}
|
|
|
|
// 存储到BadgerDB
|
|
err = p.db.Update(func(txn *badger.Txn) error {
|
|
return txn.Set([]byte(conn.ID), data)
|
|
})
|
|
if err != nil {
|
|
return fmt.Errorf("failed to update connection: %w", err)
|
|
}
|
|
|
|
// 更新内存缓存
|
|
p.cache[conn.ID] = conn
|
|
return nil
|
|
}
|
|
|
|
// GetAll 获取所有连接
|
|
func (p *BadgerPool) GetAll() ([]*ConnectionInfo, error) {
|
|
p.mutex.RLock()
|
|
// 如果内存缓存不为空,直接返回缓存
|
|
if len(p.cache) > 0 {
|
|
conns := make([]*ConnectionInfo, 0, len(p.cache))
|
|
for _, conn := range p.cache {
|
|
conns = append(conns, conn)
|
|
}
|
|
p.mutex.RUnlock()
|
|
return conns, nil
|
|
}
|
|
p.mutex.RUnlock()
|
|
|
|
// 从BadgerDB获取所有连接
|
|
var conns []*ConnectionInfo
|
|
err := p.db.View(func(txn *badger.Txn) error {
|
|
opts := badger.DefaultIteratorOptions
|
|
opts.PrefetchSize = 10
|
|
it := txn.NewIterator(opts)
|
|
defer it.Close()
|
|
|
|
for it.Rewind(); it.Valid(); it.Next() {
|
|
item := it.Item()
|
|
var connInfo ConnectionInfo
|
|
err := item.Value(func(val []byte) error {
|
|
return json.Unmarshal(val, &connInfo)
|
|
})
|
|
if err != nil {
|
|
return err
|
|
}
|
|
conns = append(conns, &connInfo)
|
|
}
|
|
return nil
|
|
})
|
|
if err != nil {
|
|
return nil, fmt.Errorf("failed to get all connections: %w", err)
|
|
}
|
|
|
|
// 更新内存缓存
|
|
p.mutex.Lock()
|
|
for _, conn := range conns {
|
|
p.cache[conn.ID] = conn
|
|
}
|
|
p.mutex.Unlock()
|
|
|
|
return conns, nil
|
|
}
|
|
|
|
// GetByType 根据类型获取连接
|
|
func (p *BadgerPool) GetByType(connType ConnType) ([]*ConnectionInfo, error) {
|
|
allConns, err := p.GetAll()
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
|
|
var filtered []*ConnectionInfo
|
|
for _, conn := range allConns {
|
|
if conn.Type == connType {
|
|
filtered = append(filtered, conn)
|
|
}
|
|
}
|
|
|
|
return filtered, nil
|
|
}
|
|
|
|
// Count 获取连接数量
|
|
func (p *BadgerPool) Count() (int, error) {
|
|
p.mutex.RLock()
|
|
// 如果内存缓存不为空,直接返回缓存大小
|
|
if len(p.cache) > 0 {
|
|
count := len(p.cache)
|
|
p.mutex.RUnlock()
|
|
return count, nil
|
|
}
|
|
p.mutex.RUnlock()
|
|
|
|
// 从BadgerDB统计数量
|
|
var count int
|
|
err := p.db.View(func(txn *badger.Txn) error {
|
|
opts := badger.DefaultIteratorOptions
|
|
opts.PrefetchSize = 10
|
|
it := txn.NewIterator(opts)
|
|
defer it.Close()
|
|
|
|
for it.Rewind(); it.Valid(); it.Next() {
|
|
count++
|
|
}
|
|
return nil
|
|
})
|
|
if err != nil {
|
|
return 0, fmt.Errorf("failed to count connections: %w", err)
|
|
}
|
|
|
|
return count, nil
|
|
}
|
|
|
|
// CleanupInactive 清理不活跃的连接
|
|
func (p *BadgerPool) CleanupInactive(duration time.Duration) error {
|
|
allConns, err := p.GetAll()
|
|
if err != nil {
|
|
return err
|
|
}
|
|
|
|
now := time.Now()
|
|
for _, conn := range allConns {
|
|
if !conn.IsActive || now.Sub(conn.LastUsed) > duration {
|
|
if err := p.Remove(conn.ID); err != nil {
|
|
return err
|
|
}
|
|
}
|
|
}
|
|
|
|
return nil
|
|
}
|