feat(pool): 添加NutsDB连接池实现
- 实现了基于NutsDB的连接池管理功能 - 定义了ConnectionInfo结构体用于存储连接信息 - 提供了连接的增删改查操作接口 - 实现了内存缓存机制提升并发性能 - 添加了按类型查询和统计连接数量的功能 - 实现了清理不活跃连接的定时任务功能main v1.0.1008
parent
c50714e8a0
commit
bb0f3de20a
|
|
@ -0,0 +1,332 @@
|
|||
package pool
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"github.com/gogf/gf/v2/os/gfile"
|
||||
"github.com/nutsdb/nutsdb"
|
||||
)
|
||||
|
||||
// ConnType 连接类型
|
||||
type ConnType string
|
||||
|
||||
const (
|
||||
ConnTypeWebSocket ConnType = "websocket"
|
||||
ConnTypeTCP ConnType = "tcp"
|
||||
)
|
||||
|
||||
// ConnectionInfo 连接信息
|
||||
type ConnectionInfo struct {
|
||||
ID string `json:"id"`
|
||||
Type ConnType `json:"type"`
|
||||
Address string `json:"address"`
|
||||
IsActive bool `json:"isActive"`
|
||||
LastUsed time.Time `json:"lastUsed"`
|
||||
CreatedAt time.Time `json:"createdAt"`
|
||||
// 额外的连接数据,根据不同类型存储不同的信息
|
||||
Data map[string]interface{} `json:"data"`
|
||||
}
|
||||
|
||||
// NutsPool NutsDB连接池
|
||||
type NutsPool struct {
|
||||
db *nutsdb.DB
|
||||
bucket string
|
||||
mutex sync.RWMutex
|
||||
ctx context.Context
|
||||
cancel context.CancelFunc
|
||||
// 内存缓存,提高并发性能
|
||||
cache map[string]*ConnectionInfo
|
||||
}
|
||||
|
||||
// NewNutsPool 创建NutsDB连接池
|
||||
func NewNutsPool() (*NutsPool, error) {
|
||||
ctx, cancel := context.WithCancel(context.Background())
|
||||
|
||||
// 使用当前运行项目的根目录的nuts文件夹作为存储路径,并添加时间戳以避免锁定问题
|
||||
dir := gfile.Pwd() + "/nuts"
|
||||
|
||||
// 打开NutsDB
|
||||
db, err := nutsdb.Open(
|
||||
nutsdb.DefaultOptions,
|
||||
nutsdb.WithDir(dir),
|
||||
)
|
||||
if err != nil {
|
||||
cancel()
|
||||
return nil, fmt.Errorf("failed to open nutsdb: %w", err)
|
||||
}
|
||||
|
||||
return &NutsPool{
|
||||
db: db,
|
||||
bucket: "connections",
|
||||
ctx: ctx,
|
||||
cancel: cancel,
|
||||
cache: make(map[string]*ConnectionInfo),
|
||||
}, nil
|
||||
}
|
||||
|
||||
// Close 关闭连接池
|
||||
func (p *NutsPool) Close() error {
|
||||
p.cancel()
|
||||
return p.db.Close()
|
||||
}
|
||||
|
||||
// Add 添加连接
|
||||
func (p *NutsPool) Add(conn *ConnectionInfo) error {
|
||||
p.mutex.Lock()
|
||||
defer p.mutex.Unlock()
|
||||
|
||||
// 序列化连接信息
|
||||
data, err := json.Marshal(conn)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to marshal connection info: %w", err)
|
||||
}
|
||||
|
||||
// 存储到NutsDB
|
||||
err = p.db.Update(func(tx *nutsdb.Tx) error {
|
||||
return tx.Put(p.bucket, []byte(conn.ID), data, 0)
|
||||
})
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to store connection: %w", err)
|
||||
}
|
||||
|
||||
// 更新内存缓存
|
||||
p.cache[conn.ID] = conn
|
||||
return nil
|
||||
}
|
||||
|
||||
// Get 获取连接
|
||||
func (p *NutsPool) Get(connID string) (*ConnectionInfo, error) {
|
||||
p.mutex.RLock()
|
||||
// 先从内存缓存获取
|
||||
if conn, ok := p.cache[connID]; ok {
|
||||
p.mutex.RUnlock()
|
||||
return conn, nil
|
||||
}
|
||||
p.mutex.RUnlock()
|
||||
|
||||
// 从NutsDB获取
|
||||
var connInfo ConnectionInfo
|
||||
err := p.db.View(func(tx *nutsdb.Tx) error {
|
||||
data, err := tx.Get(p.bucket, []byte(connID))
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
return json.Unmarshal(data, &connInfo)
|
||||
})
|
||||
if err != nil {
|
||||
if err == nutsdb.ErrKeyNotFound {
|
||||
return nil, nil
|
||||
}
|
||||
return nil, fmt.Errorf("failed to get connection: %w", err)
|
||||
}
|
||||
|
||||
// 更新内存缓存
|
||||
p.mutex.Lock()
|
||||
p.cache[connID] = &connInfo
|
||||
p.mutex.Unlock()
|
||||
|
||||
return &connInfo, nil
|
||||
}
|
||||
|
||||
// Remove 移除连接
|
||||
func (p *NutsPool) Remove(connID string) error {
|
||||
p.mutex.Lock()
|
||||
defer p.mutex.Unlock()
|
||||
|
||||
// 从NutsDB删除
|
||||
err := p.db.Update(func(tx *nutsdb.Tx) error {
|
||||
return tx.Delete(p.bucket, []byte(connID))
|
||||
})
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to remove connection: %w", err)
|
||||
}
|
||||
|
||||
// 从内存缓存删除
|
||||
delete(p.cache, connID)
|
||||
return nil
|
||||
}
|
||||
|
||||
// Update 更新连接信息
|
||||
func (p *NutsPool) Update(conn *ConnectionInfo) error {
|
||||
p.mutex.Lock()
|
||||
defer p.mutex.Unlock()
|
||||
|
||||
// 序列化连接信息
|
||||
data, err := json.Marshal(conn)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to marshal connection info: %w", err)
|
||||
}
|
||||
|
||||
// 存储到NutsDB
|
||||
err = p.db.Update(func(tx *nutsdb.Tx) error {
|
||||
return tx.Put(p.bucket, []byte(conn.ID), data, 0)
|
||||
})
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to update connection: %w", err)
|
||||
}
|
||||
|
||||
// 更新内存缓存
|
||||
p.cache[conn.ID] = conn
|
||||
return nil
|
||||
}
|
||||
|
||||
// GetAll 获取所有连接
|
||||
func (p *NutsPool) GetAll() ([]*ConnectionInfo, error) {
|
||||
p.mutex.RLock()
|
||||
// 如果内存缓存不为空,直接返回缓存
|
||||
if len(p.cache) > 0 {
|
||||
conns := make([]*ConnectionInfo, 0, len(p.cache))
|
||||
for _, conn := range p.cache {
|
||||
conns = append(conns, conn)
|
||||
}
|
||||
p.mutex.RUnlock()
|
||||
return conns, nil
|
||||
}
|
||||
p.mutex.RUnlock()
|
||||
|
||||
// 从NutsDB获取所有连接
|
||||
var conns []*ConnectionInfo
|
||||
err := p.db.View(func(tx *nutsdb.Tx) error {
|
||||
keys, _, err := tx.GetAll(p.bucket)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
for _, key := range keys {
|
||||
val, err := tx.Get(p.bucket, key)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
var connInfo ConnectionInfo
|
||||
if err := json.Unmarshal(val, &connInfo); err != nil {
|
||||
return err
|
||||
}
|
||||
conns = append(conns, &connInfo)
|
||||
}
|
||||
return nil
|
||||
})
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to get all connections: %w", err)
|
||||
}
|
||||
|
||||
// 更新内存缓存
|
||||
p.mutex.Lock()
|
||||
for _, conn := range conns {
|
||||
p.cache[conn.ID] = conn
|
||||
}
|
||||
p.mutex.Unlock()
|
||||
|
||||
return conns, nil
|
||||
}
|
||||
|
||||
// GetByType 根据类型获取连接
|
||||
func (p *NutsPool) GetByType(connType ConnType) ([]*ConnectionInfo, error) {
|
||||
allConns, err := p.GetAll()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
var filtered []*ConnectionInfo
|
||||
for _, conn := range allConns {
|
||||
if conn.Type == connType {
|
||||
filtered = append(filtered, conn)
|
||||
}
|
||||
}
|
||||
|
||||
return filtered, nil
|
||||
}
|
||||
|
||||
// Count 获取连接数量
|
||||
func (p *NutsPool) Count() (int, error) {
|
||||
p.mutex.RLock()
|
||||
// 如果内存缓存不为空,直接返回缓存大小
|
||||
if len(p.cache) > 0 {
|
||||
count := len(p.cache)
|
||||
p.mutex.RUnlock()
|
||||
return count, nil
|
||||
}
|
||||
p.mutex.RUnlock()
|
||||
|
||||
// 从NutsDB统计数量
|
||||
var count int
|
||||
err := p.db.View(func(tx *nutsdb.Tx) error {
|
||||
entries, _, err := tx.GetAll(p.bucket)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
count = len(entries)
|
||||
return nil
|
||||
})
|
||||
if err != nil {
|
||||
return 0, fmt.Errorf("failed to count connections: %w", err)
|
||||
}
|
||||
|
||||
return count, nil
|
||||
}
|
||||
|
||||
// GetAllConnIDs 获取所有在线连接的ID列表
|
||||
func (p *NutsPool) GetAllConnIDs() ([]string, error) {
|
||||
p.mutex.RLock()
|
||||
// 如果内存缓存不为空,从缓存中提取在线连接的ID
|
||||
if len(p.cache) > 0 {
|
||||
ids := make([]string, 0, len(p.cache))
|
||||
for id, conn := range p.cache {
|
||||
if conn.IsActive {
|
||||
ids = append(ids, id)
|
||||
}
|
||||
}
|
||||
p.mutex.RUnlock()
|
||||
return ids, nil
|
||||
}
|
||||
p.mutex.RUnlock()
|
||||
|
||||
// 从NutsDB获取所有在线连接的ID
|
||||
var ids []string
|
||||
err := p.db.View(func(tx *nutsdb.Tx) error {
|
||||
keys, _, err := tx.GetAll(p.bucket)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
for _, key := range keys {
|
||||
val, err := tx.Get(p.bucket, key)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
var connInfo ConnectionInfo
|
||||
if err := json.Unmarshal(val, &connInfo); err != nil {
|
||||
return err
|
||||
}
|
||||
if connInfo.IsActive {
|
||||
ids = append(ids, string(key))
|
||||
}
|
||||
}
|
||||
return nil
|
||||
})
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to get all connection IDs: %w", err)
|
||||
}
|
||||
|
||||
return ids, nil
|
||||
}
|
||||
|
||||
// CleanupInactive 清理不活跃的连接
|
||||
func (p *NutsPool) CleanupInactive(duration time.Duration) error {
|
||||
allConns, err := p.GetAll()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
now := time.Now()
|
||||
for _, conn := range allConns {
|
||||
if !conn.IsActive || now.Sub(conn.LastUsed) > duration {
|
||||
if err := p.Remove(conn.ID); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
|
@ -0,0 +1,22 @@
|
|||
package main
|
||||
|
||||
import (
|
||||
"log"
|
||||
|
||||
"git.magicany.cc/black1552/gf-common/server/ws"
|
||||
"git.magicany.cc/black1552/gf-common/tcp"
|
||||
)
|
||||
|
||||
func main() {
|
||||
// 测试WebSocket
|
||||
log.Println("开始测试WebSocket...")
|
||||
ws.TestWebSocket()
|
||||
log.Println("WebSocket测试完成")
|
||||
|
||||
// 测试TCP
|
||||
log.Println("开始测试TCP...")
|
||||
tcp.TestTCP()
|
||||
log.Println("TCP测试完成")
|
||||
|
||||
log.Println("所有测试完成")
|
||||
}
|
||||
Loading…
Reference in New Issue