gf-common/pool/nutsdb.go

333 lines
6.9 KiB
Go
Raw Blame History

This file contains ambiguous Unicode characters!

This file contains ambiguous Unicode characters that may be confused with others in your current locale. If your use case is intentional and legitimate, you can safely ignore this warning. Use the Escape button to highlight these characters.

package pool
import (
"context"
"encoding/json"
"fmt"
"sync"
"time"
"github.com/gogf/gf/v2/os/gfile"
"github.com/nutsdb/nutsdb"
)
// ConnType identifies the transport kind of a pooled connection.
type ConnType string

// Known connection types.
const (
	// ConnTypeWebSocket marks a WebSocket connection.
	ConnTypeWebSocket ConnType = "websocket"
	// ConnTypeTCP marks a TCP connection.
	ConnTypeTCP ConnType = "tcp"
)
// ConnectionInfo describes one registered connection. Add and Update
// serialize it with encoding/json before storing it in NutsDB, so the JSON
// tags below define the persisted format; field order and tags must not change
// casually.
type ConnectionInfo struct {
// ID is the key under which the record is cached and stored in NutsDB.
ID string `json:"id"`
// Type is the transport kind (see ConnType); GetByType filters on it.
Type ConnType `json:"type"`
// Address of the connection. NOTE(review): the format (host:port, URL, ...) is not established in this file — confirm with callers.
Address string `json:"address"`
// IsActive marks the connection as online: GetAllConnIDs only returns active
// IDs, and CleanupInactive removes records where it is false.
IsActive bool `json:"isActive"`
// LastUsed is compared by CleanupInactive against its duration argument to
// detect idle connections.
LastUsed time.Time `json:"lastUsed"`
// CreatedAt records when the connection was created. NOTE(review): not read by any method in this file.
CreatedAt time.Time `json:"createdAt"`
// Data holds extra connection data; what it stores differs by connection type.
Data map[string]interface{} `json:"data"`
}
// NutsPool stores ConnectionInfo records as JSON in a NutsDB bucket and keeps
// an in-memory cache of them keyed by connection ID.
type NutsPool struct {
// db is the NutsDB handle opened by NewNutsPool and closed by Close.
db *nutsdb.DB
// bucket is the NutsDB bucket holding the JSON-encoded records.
bucket string
// mutex guards cache. Add, Update and Remove hold the write lock across their
// NutsDB transaction so the store and the cache change together.
mutex sync.RWMutex
// ctx/cancel are created in NewNutsPool; Close calls cancel.
// NOTE(review): ctx is not read by any method in this file, and storing a
// context in a struct is discouraged — confirm it is unused before removing.
ctx context.Context
cancel context.CancelFunc
// cache maps connection ID to its record, to speed up concurrent reads.
cache map[string]*ConnectionInfo
}
// NewNutsPool opens (or creates) the NutsDB store in the "nuts" directory
// under the current working directory and returns a pool backed by the
// "connections" bucket.
//
// The directory is fixed (no timestamp or other suffix is added), so every
// pool opened from the same working directory uses the same store.
//
// The in-memory cache is pre-populated from the store. GetAll, Count and
// GetAllConnIDs treat a non-empty cache as authoritative. Starting with an
// empty cache over a non-empty store would make them ignore every persisted
// record once the first Add or Get put something into the cache.
func NewNutsPool() (*NutsPool, error) {
	ctx, cancel := context.WithCancel(context.Background())

	// Store data in the "nuts" folder of the current working directory.
	dir := gfile.Join(gfile.Pwd(), "nuts")

	db, err := nutsdb.Open(
		nutsdb.DefaultOptions,
		nutsdb.WithDir(dir),
	)
	if err != nil {
		cancel()
		return nil, fmt.Errorf("failed to open nutsdb: %w", err)
	}

	pool := &NutsPool{
		db:     db,
		bucket: "connections",
		ctx:    ctx,
		cancel: cancel,
		cache:  make(map[string]*ConnectionInfo),
	}

	// Best-effort warm-up: GetAll fills the cache from NutsDB when it is empty.
	// If it fails (for example on a store with no data yet, depending on the
	// NutsDB version, or because of an undecodable record), the pool starts with
	// an empty cache, as it always did before, and readers fall back to NutsDB.
	_, _ = pool.GetAll()

	return pool, nil
}
// Close shuts the pool down. It first cancels the pool's context, then closes
// the NutsDB handle and returns that call's error. The cache is left untouched.
func (p *NutsPool) Close() error {
	p.cancel()
	if err := p.db.Close(); err != nil {
		return err
	}
	return nil
}
// Add stores conn as JSON in NutsDB under conn.ID (Put with ttl 0, i.e.
// nutsdb.Persistent) and caches the conn pointer itself. An existing record
// with the same ID is overwritten.
//
// The cached pointer is the caller's. Later mutations are visible to cache
// readers but are not persisted until Update is called.
//
// A nil conn is rejected with an error. Previously it panicked inside the
// NutsDB transaction while the pool mutex was held.
func (p *NutsPool) Add(conn *ConnectionInfo) error {
	if conn == nil {
		return fmt.Errorf("failed to add connection: connection info is nil")
	}

	// Serialize before locking; this does not touch pool state.
	data, err := json.Marshal(conn)
	if err != nil {
		return fmt.Errorf("failed to marshal connection info: %w", err)
	}

	// Hold the lock across the DB write and the cache write so both change
	// together relative to other pool writers.
	p.mutex.Lock()
	defer p.mutex.Unlock()

	err = p.db.Update(func(tx *nutsdb.Tx) error {
		return tx.Put(p.bucket, []byte(conn.ID), data, 0)
	})
	if err != nil {
		return fmt.Errorf("failed to store connection: %w", err)
	}

	p.cache[conn.ID] = conn
	return nil
}
// Get returns the record stored under connID, or (nil, nil) when NutsDB
// reports nutsdb.ErrKeyNotFound. Other errors are wrapped and returned.
//
// The cache is checked first under the read lock. On a miss, the write lock is
// held across the NutsDB read and the cache fill. Otherwise a concurrent Remove
// could run between the read and the fill, and its deletion would be undone by
// caching the stale record.
//
// The returned pointer is shared with the cache.
func (p *NutsPool) Get(connID string) (*ConnectionInfo, error) {
	p.mutex.RLock()
	cached, ok := p.cache[connID]
	p.mutex.RUnlock()
	if ok {
		return cached, nil
	}

	p.mutex.Lock()
	defer p.mutex.Unlock()

	// Another goroutine may have filled the entry while we waited for the lock.
	if cached, ok := p.cache[connID]; ok {
		return cached, nil
	}

	var connInfo ConnectionInfo
	err := p.db.View(func(tx *nutsdb.Tx) error {
		data, err := tx.Get(p.bucket, []byte(connID))
		if err != nil {
			return err
		}
		return json.Unmarshal(data, &connInfo)
	})
	if err != nil {
		if err == nutsdb.ErrKeyNotFound {
			return nil, nil
		}
		return nil, fmt.Errorf("failed to get connection: %w", err)
	}

	p.cache[connID] = &connInfo
	return &connInfo, nil
}
// Remove deletes the record stored under connID from NutsDB and then drops it
// from the cache. If the NutsDB delete fails, the cache entry is kept and the
// wrapped error is returned. NOTE(review): presumably this includes the case
// where connID is not stored, depending on NutsDB's Delete semantics.
func (p *NutsPool) Remove(connID string) error {
	key := []byte(connID)

	p.mutex.Lock()
	defer p.mutex.Unlock()

	if err := p.db.Update(func(tx *nutsdb.Tx) error {
		return tx.Delete(p.bucket, key)
	}); err != nil {
		return fmt.Errorf("failed to remove connection: %w", err)
	}

	delete(p.cache, connID)
	return nil
}
// Update overwrites the stored record for conn.ID with the current contents of
// conn and caches the conn pointer. It performs the same Put as Add (ttl 0),
// so it also creates the record if none existed.
//
// A nil conn is rejected with an error. Previously it panicked inside the
// NutsDB transaction while the pool mutex was held.
func (p *NutsPool) Update(conn *ConnectionInfo) error {
	if conn == nil {
		return fmt.Errorf("failed to update connection: connection info is nil")
	}

	// Serialize before locking; this does not touch pool state.
	data, err := json.Marshal(conn)
	if err != nil {
		return fmt.Errorf("failed to marshal connection info: %w", err)
	}

	// Hold the lock across the DB write and the cache write so both change
	// together relative to other pool writers.
	p.mutex.Lock()
	defer p.mutex.Unlock()

	err = p.db.Update(func(tx *nutsdb.Tx) error {
		return tx.Put(p.bucket, []byte(conn.ID), data, 0)
	})
	if err != nil {
		return fmt.Errorf("failed to update connection: %w", err)
	}

	p.cache[conn.ID] = conn
	return nil
}
// GetAll returns every known connection record.
//
// If the cache holds any entries, it is treated as authoritative and a
// snapshot of it is returned in map iteration order. Only when the cache is
// empty is the NutsDB bucket read. The decoded records then populate the cache.
// On any NutsDB or decoding error, nothing is cached and (nil, err) is
// returned. If no records are decoded, the result is nil.
//
// The DB fallback runs under the write lock (double-checked). Otherwise a
// concurrent Remove could run between the read and the cache fill, and the
// removed record would be cached again.
func (p *NutsPool) GetAll() ([]*ConnectionInfo, error) {
	// snapshot copies the cached pointers; the caller must hold p.mutex.
	snapshot := func() []*ConnectionInfo {
		conns := make([]*ConnectionInfo, 0, len(p.cache))
		for _, conn := range p.cache {
			conns = append(conns, conn)
		}
		return conns
	}

	p.mutex.RLock()
	if len(p.cache) > 0 {
		conns := snapshot()
		p.mutex.RUnlock()
		return conns, nil
	}
	p.mutex.RUnlock()

	p.mutex.Lock()
	defer p.mutex.Unlock()

	// Another goroutine may have filled the cache while we waited for the lock.
	if len(p.cache) > 0 {
		return snapshot(), nil
	}

	var conns []*ConnectionInfo
	err := p.db.View(func(tx *nutsdb.Tx) error {
		// tx.GetAll already returns the values next to the keys, so no extra
		// tx.Get per key is needed.
		_, values, err := tx.GetAll(p.bucket)
		if err != nil {
			return err
		}
		for _, val := range values {
			var connInfo ConnectionInfo
			if err := json.Unmarshal(val, &connInfo); err != nil {
				return err
			}
			conns = append(conns, &connInfo)
		}
		return nil
	})
	if err != nil {
		return nil, fmt.Errorf("failed to get all connections: %w", err)
	}

	for _, conn := range conns {
		p.cache[conn.ID] = conn
	}
	return conns, nil
}
// GetByType returns the records from GetAll whose Type equals connType,
// preserving GetAll's order. Errors from GetAll are returned unchanged. The
// result is nil when nothing matches.
func (p *NutsPool) GetByType(connType ConnType) ([]*ConnectionInfo, error) {
	all, err := p.GetAll()
	if err != nil {
		return nil, err
	}

	var matches []*ConnectionInfo
	for i := range all {
		if all[i].Type != connType {
			continue
		}
		matches = append(matches, all[i])
	}
	return matches, nil
}
// Count reports how many records the pool holds. When the cache is non-empty,
// its size is returned. Otherwise the keys in the NutsDB bucket are counted;
// this path does not populate the cache.
func (p *NutsPool) Count() (int, error) {
	p.mutex.RLock()
	cached := len(p.cache)
	p.mutex.RUnlock()
	if cached > 0 {
		return cached, nil
	}

	total := 0
	if err := p.db.View(func(tx *nutsdb.Tx) error {
		keys, _, err := tx.GetAll(p.bucket)
		if err != nil {
			return err
		}
		total = len(keys)
		return nil
	}); err != nil {
		return 0, fmt.Errorf("failed to count connections: %w", err)
	}
	return total, nil
}
// GetAllConnIDs returns the IDs of records whose IsActive flag is set.
//
// When the cache is non-empty, the IDs come from the cache keys; the result is
// non-nil and possibly empty. Otherwise the NutsDB bucket is scanned and the
// stored keys of active records are returned (nil if there are none). This
// path does not populate the cache.
func (p *NutsPool) GetAllConnIDs() ([]string, error) {
	p.mutex.RLock()
	if len(p.cache) > 0 {
		ids := make([]string, 0, len(p.cache))
		for id, conn := range p.cache {
			if conn.IsActive {
				ids = append(ids, id)
			}
		}
		p.mutex.RUnlock()
		return ids, nil
	}
	p.mutex.RUnlock()

	var ids []string
	err := p.db.View(func(tx *nutsdb.Tx) error {
		keys, values, err := tx.GetAll(p.bucket)
		if err != nil {
			return err
		}
		// keys and values are parallel slices: decode values[i] directly
		// instead of issuing a second tx.Get for keys[i].
		for i, key := range keys {
			var connInfo ConnectionInfo
			if err := json.Unmarshal(values[i], &connInfo); err != nil {
				return err
			}
			if connInfo.IsActive {
				ids = append(ids, string(key))
			}
		}
		return nil
	})
	if err != nil {
		return nil, fmt.Errorf("failed to get all connection IDs: %w", err)
	}
	return ids, nil
}
// CleanupInactive removes every record that is flagged inactive or whose
// LastUsed lies more than duration before now. It stops at the first error
// from GetAll or Remove and returns it unchanged. Records removed before that
// error stay removed.
func (p *NutsPool) CleanupInactive(duration time.Duration) error {
	conns, err := p.GetAll()
	if err != nil {
		return err
	}

	now := time.Now()
	stale := func(c *ConnectionInfo) bool {
		return !c.IsActive || now.Sub(c.LastUsed) > duration
	}

	for _, c := range conns {
		if !stale(c) {
			continue
		}
		if err := p.Remove(c.ID); err != nil {
			return err
		}
	}
	return nil
}