diff --git a/pool/nutsdb.go b/pool/nutsdb.go new file mode 100644 index 0000000..8ce4657 --- /dev/null +++ b/pool/nutsdb.go @@ -0,0 +1,332 @@ +package pool + +import ( + "context" + "encoding/json" + "fmt" + "sync" + "time" + + "github.com/gogf/gf/v2/os/gfile" + "github.com/nutsdb/nutsdb" +) + +// ConnType 连接类型 +type ConnType string + +const ( + ConnTypeWebSocket ConnType = "websocket" + ConnTypeTCP ConnType = "tcp" +) + +// ConnectionInfo 连接信息 +type ConnectionInfo struct { + ID string `json:"id"` + Type ConnType `json:"type"` + Address string `json:"address"` + IsActive bool `json:"isActive"` + LastUsed time.Time `json:"lastUsed"` + CreatedAt time.Time `json:"createdAt"` + // 额外的连接数据,根据不同类型存储不同的信息 + Data map[string]interface{} `json:"data"` +} + +// NutsPool NutsDB连接池 +type NutsPool struct { + db *nutsdb.DB + bucket string + mutex sync.RWMutex + ctx context.Context + cancel context.CancelFunc + // 内存缓存,提高并发性能 + cache map[string]*ConnectionInfo +} + +// NewNutsPool 创建NutsDB连接池 +func NewNutsPool() (*NutsPool, error) { + ctx, cancel := context.WithCancel(context.Background()) + + // 使用当前运行项目的根目录的nuts文件夹作为存储路径,并添加时间戳以避免锁定问题 + dir := gfile.Pwd() + "/nuts" + + // 打开NutsDB + db, err := nutsdb.Open( + nutsdb.DefaultOptions, + nutsdb.WithDir(dir), + ) + if err != nil { + cancel() + return nil, fmt.Errorf("failed to open nutsdb: %w", err) + } + + return &NutsPool{ + db: db, + bucket: "connections", + ctx: ctx, + cancel: cancel, + cache: make(map[string]*ConnectionInfo), + }, nil +} + +// Close 关闭连接池 +func (p *NutsPool) Close() error { + p.cancel() + return p.db.Close() +} + +// Add 添加连接 +func (p *NutsPool) Add(conn *ConnectionInfo) error { + p.mutex.Lock() + defer p.mutex.Unlock() + + // 序列化连接信息 + data, err := json.Marshal(conn) + if err != nil { + return fmt.Errorf("failed to marshal connection info: %w", err) + } + + // 存储到NutsDB + err = p.db.Update(func(tx *nutsdb.Tx) error { + return tx.Put(p.bucket, []byte(conn.ID), data, 0) + }) + if err != nil { + return fmt.Errorf("failed to store connection: %w", err) + } + + // 更新内存缓存 + p.cache[conn.ID] = conn + return nil +} + +// Get 获取连接 +func (p *NutsPool) Get(connID string) (*ConnectionInfo, error) { + p.mutex.RLock() + // 先从内存缓存获取 + if conn, ok := p.cache[connID]; ok { + p.mutex.RUnlock() + return conn, nil + } + p.mutex.RUnlock() + + // 从NutsDB获取 + var connInfo ConnectionInfo + err := p.db.View(func(tx *nutsdb.Tx) error { + data, err := tx.Get(p.bucket, []byte(connID)) + if err != nil { + return err + } + return json.Unmarshal(data, &connInfo) + }) + if err != nil { + if err == nutsdb.ErrKeyNotFound { + return nil, nil + } + return nil, fmt.Errorf("failed to get connection: %w", err) + } + + // 更新内存缓存 + p.mutex.Lock() + p.cache[connID] = &connInfo + p.mutex.Unlock() + + return &connInfo, nil +} + +// Remove 移除连接 +func (p *NutsPool) Remove(connID string) error { + p.mutex.Lock() + defer p.mutex.Unlock() + + // 从NutsDB删除 + err := p.db.Update(func(tx *nutsdb.Tx) error { + return tx.Delete(p.bucket, []byte(connID)) + }) + if err != nil { + return fmt.Errorf("failed to remove connection: %w", err) + } + + // 从内存缓存删除 + delete(p.cache, connID) + return nil +} + +// Update 更新连接信息 +func (p *NutsPool) Update(conn *ConnectionInfo) error { + p.mutex.Lock() + defer p.mutex.Unlock() + + // 序列化连接信息 + data, err := json.Marshal(conn) + if err != nil { + return fmt.Errorf("failed to marshal connection info: %w", err) + } + + // 存储到NutsDB + err = p.db.Update(func(tx *nutsdb.Tx) error { + return tx.Put(p.bucket, []byte(conn.ID), data, 0) + }) + if err != nil { + return fmt.Errorf("failed to update connection: %w", err) + } + + // 更新内存缓存 + p.cache[conn.ID] = conn + return nil +} + +// GetAll 获取所有连接 +func (p *NutsPool) GetAll() ([]*ConnectionInfo, error) { + p.mutex.RLock() + // 如果内存缓存不为空,直接返回缓存 + if len(p.cache) > 0 { + conns := make([]*ConnectionInfo, 0, len(p.cache)) + for _, conn := range p.cache { + conns = append(conns, conn) + } + p.mutex.RUnlock() + return conns, nil + } + p.mutex.RUnlock() + + // 从NutsDB获取所有连接 + var conns []*ConnectionInfo + err := p.db.View(func(tx *nutsdb.Tx) error { + keys, _, err := tx.GetAll(p.bucket) + if err != nil { + return err + } + for _, key := range keys { + val, err := tx.Get(p.bucket, key) + if err != nil { + return err + } + var connInfo ConnectionInfo + if err := json.Unmarshal(val, &connInfo); err != nil { + return err + } + conns = append(conns, &connInfo) + } + return nil + }) + if err != nil { + return nil, fmt.Errorf("failed to get all connections: %w", err) + } + + // 更新内存缓存 + p.mutex.Lock() + for _, conn := range conns { + p.cache[conn.ID] = conn + } + p.mutex.Unlock() + + return conns, nil +} + +// GetByType 根据类型获取连接 +func (p *NutsPool) GetByType(connType ConnType) ([]*ConnectionInfo, error) { + allConns, err := p.GetAll() + if err != nil { + return nil, err + } + + var filtered []*ConnectionInfo + for _, conn := range allConns { + if conn.Type == connType { + filtered = append(filtered, conn) + } + } + + return filtered, nil +} + +// Count 获取连接数量 +func (p *NutsPool) Count() (int, error) { + p.mutex.RLock() + // 如果内存缓存不为空,直接返回缓存大小 + if len(p.cache) > 0 { + count := len(p.cache) + p.mutex.RUnlock() + return count, nil + } + p.mutex.RUnlock() + + // 从NutsDB统计数量 + var count int + err := p.db.View(func(tx *nutsdb.Tx) error { + entries, _, err := tx.GetAll(p.bucket) + if err != nil { + return err + } + count = len(entries) + return nil + }) + if err != nil { + return 0, fmt.Errorf("failed to count connections: %w", err) + } + + return count, nil +} + +// GetAllConnIDs 获取所有在线连接的ID列表 +func (p *NutsPool) GetAllConnIDs() ([]string, error) { + p.mutex.RLock() + // 如果内存缓存不为空,从缓存中提取在线连接的ID + if len(p.cache) > 0 { + ids := make([]string, 0, len(p.cache)) + for id, conn := range p.cache { + if conn.IsActive { + ids = append(ids, id) + } + } + p.mutex.RUnlock() + return ids, nil + } + p.mutex.RUnlock() + + // 从NutsDB获取所有在线连接的ID + var ids []string + err := p.db.View(func(tx *nutsdb.Tx) error { + keys, _, err := tx.GetAll(p.bucket) + if err != nil { + return err + } + for _, key := range keys { + val, err := tx.Get(p.bucket, key) + if err != nil { + return err + } + var connInfo ConnectionInfo + if err := json.Unmarshal(val, &connInfo); err != nil { + return err + } + if connInfo.IsActive { + ids = append(ids, string(key)) + } + } + return nil + }) + if err != nil { + return nil, fmt.Errorf("failed to get all connection IDs: %w", err) + } + + return ids, nil +} + +// CleanupInactive 清理不活跃的连接 +func (p *NutsPool) CleanupInactive(duration time.Duration) error { + allConns, err := p.GetAll() + if err != nil { + return err + } + + now := time.Now() + for _, conn := range allConns { + if !conn.IsActive || now.Sub(conn.LastUsed) > duration { + if err := p.Remove(conn.ID); err != nil { + return err + } + } + } + + return nil +} diff --git a/test.go b/test.go new file mode 100644 index 0000000..55ec5e9 --- /dev/null +++ b/test.go @@ -0,0 +1,22 @@ +package main + +import ( + "log" + + "git.magicany.cc/black1552/gf-common/server/ws" + "git.magicany.cc/black1552/gf-common/tcp" +) + +func main() { + // 测试WebSocket + log.Println("开始测试WebSocket...") + ws.TestWebSocket() + log.Println("WebSocket测试完成") + + // 测试TCP + log.Println("开始测试TCP...") + tcp.TestTCP() + log.Println("TCP测试完成") + + log.Println("所有测试完成") +}