281 lines
6.6 KiB
Go
281 lines
6.6 KiB
Go
package tcp
|
|
|
|
import (
|
|
"context"
|
|
"fmt"
|
|
"sync"
|
|
"time"
|
|
|
|
"github.com/gogf/gf/v2/frame/g"
|
|
"github.com/gogf/gf/v2/net/gtcp"
|
|
"github.com/gogf/gf/v2/os/glog"
|
|
"github.com/gogf/gf/v2/os/grpool"
|
|
"github.com/gogf/gf/v2/os/gtime"
|
|
)
|
|
|
|
// MessageHandler 消息处理函数类型
|
|
type MessageHandler func(conn *TcpConnection, msg *TcpMessage) error
|
|
|
|
// TCPServer TCP服务器结构
|
|
type TCPServer struct {
|
|
Address string
|
|
Config *TcpPoolConfig
|
|
Listener *gtcp.Server
|
|
Connection *ConnectionPool
|
|
Logger *glog.Logger
|
|
MessageHandler MessageHandler
|
|
ctx context.Context
|
|
cancel context.CancelFunc
|
|
wg sync.WaitGroup
|
|
}
|
|
|
|
// ConnectionPool 连接池结构
|
|
type ConnectionPool struct {
|
|
connections map[string]*TcpConnection
|
|
mutex sync.RWMutex
|
|
config *TcpPoolConfig
|
|
logger *glog.Logger
|
|
}
|
|
|
|
// NewTCPServer 创建一个新的TCP服务器
|
|
func NewTCPServer(address string, config *TcpPoolConfig) *TCPServer {
|
|
logger := g.Log(address)
|
|
ctx, cancel := context.WithCancel(context.Background())
|
|
|
|
pool := &ConnectionPool{
|
|
connections: make(map[string]*TcpConnection),
|
|
config: config,
|
|
logger: logger,
|
|
}
|
|
|
|
server := &TCPServer{
|
|
Address: address,
|
|
Config: config,
|
|
Connection: pool,
|
|
Logger: logger,
|
|
ctx: ctx,
|
|
cancel: cancel,
|
|
}
|
|
|
|
server.Listener = gtcp.NewServer(address, server.handleConnection)
|
|
return server
|
|
}
|
|
|
|
// SetMessageHandler 设置消息处理函数
|
|
func (s *TCPServer) SetMessageHandler(handler MessageHandler) {
|
|
s.MessageHandler = handler
|
|
}
|
|
|
|
// Start 启动TCP服务器
|
|
func (s *TCPServer) Start() error {
|
|
s.Logger.Info(s.ctx, fmt.Sprintf("TCP server starting on %s", s.Address))
|
|
go func() {
|
|
s.wg.Add(1)
|
|
defer s.wg.Done()
|
|
if err := s.Listener.Run(); err != nil {
|
|
s.Logger.Error(s.ctx, fmt.Sprintf("TCP server stopped with error: %v", err))
|
|
}
|
|
}()
|
|
return nil
|
|
}
|
|
|
|
// Stop 停止TCP服务器
|
|
func (s *TCPServer) Stop() error {
|
|
s.Logger.Info(s.ctx, "TCP server stopping...")
|
|
s.cancel()
|
|
s.Listener.Close()
|
|
s.wg.Wait()
|
|
s.Connection.Clear()
|
|
s.Logger.Info(s.ctx, "TCP server stopped")
|
|
return nil
|
|
}
|
|
|
|
// handleConnection 处理新连接
|
|
func (s *TCPServer) handleConnection(conn *gtcp.Conn) {
|
|
// 生成连接ID
|
|
connID := fmt.Sprintf("%s_%d", conn.RemoteAddr().String(), gtime.TimestampNano())
|
|
|
|
// 创建连接对象
|
|
tcpConn := &TcpConnection{
|
|
Id: connID,
|
|
Address: conn.RemoteAddr().String(),
|
|
Server: *conn,
|
|
IsActive: true,
|
|
LastUsed: time.Now(),
|
|
CreatedAt: time.Now(),
|
|
}
|
|
|
|
// 添加到连接池
|
|
s.Connection.Add(tcpConn)
|
|
s.Logger.Info(s.ctx, fmt.Sprintf("New connection established: %s", connID))
|
|
|
|
// 启动消息接收协程
|
|
go s.receiveMessages(tcpConn)
|
|
}
|
|
|
|
// receiveMessages 接收消息
|
|
func (s *TCPServer) receiveMessages(conn *TcpConnection) {
|
|
defer func() {
|
|
if err := recover(); err != nil {
|
|
s.Logger.Error(s.ctx, fmt.Sprintf("Panic in receiveMessages: %v", err))
|
|
}
|
|
s.Connection.Remove(conn.Id)
|
|
conn.Server.Close()
|
|
s.Logger.Info(s.ctx, fmt.Sprintf("Connection closed: %s", conn.Id))
|
|
}()
|
|
|
|
buffer := make([]byte, s.Config.BufferSize)
|
|
for {
|
|
select {
|
|
case <-s.ctx.Done():
|
|
return
|
|
default:
|
|
// 设置读取超时
|
|
conn.Server.SetReadDeadline(time.Now().Add(s.Config.ReadTimeout))
|
|
|
|
// 读取数据
|
|
n, err := conn.Server.Read(buffer)
|
|
if err != nil {
|
|
s.Logger.Error(s.ctx, fmt.Sprintf("Read error from %s: %v", conn.Id, err))
|
|
return
|
|
}
|
|
|
|
if n > 0 {
|
|
// 更新最后使用时间
|
|
conn.Mutex.Lock()
|
|
conn.LastUsed = time.Now()
|
|
conn.Mutex.Unlock()
|
|
|
|
// 处理消息
|
|
data := make([]byte, n)
|
|
copy(data, buffer[:n])
|
|
|
|
msg := &TcpMessage{
|
|
Id: fmt.Sprintf("msg_%d", gtime.TimestampNano()),
|
|
ConnId: conn.Id,
|
|
Data: data,
|
|
Timestamp: time.Now(),
|
|
IsSend: false,
|
|
}
|
|
|
|
// 使用协程池处理消息,避免阻塞
|
|
grpool.AddWithRecover(s.ctx, func(ctx context.Context) {
|
|
if s.MessageHandler != nil {
|
|
if err := s.MessageHandler(conn, msg); err != nil {
|
|
s.Logger.Error(s.ctx, fmt.Sprintf("Message handling error: %v", err))
|
|
}
|
|
}
|
|
}, func(ctx context.Context, err error) {
|
|
s.Logger.Error(ctx, fmt.Sprintf("Message handling error: %v", err))
|
|
})
|
|
}
|
|
}
|
|
}
|
|
}
|
|
|
|
// SendTo 发送消息到指定连接
|
|
func (s *TCPServer) SendTo(connID string, data []byte) error {
|
|
conn := s.Connection.Get(connID)
|
|
if conn == nil {
|
|
return fmt.Errorf("connection not found: %s", connID)
|
|
}
|
|
return s.sendMessage(conn, data)
|
|
}
|
|
|
|
// SendToAll 发送消息到所有连接
|
|
func (s *TCPServer) SendToAll(data []byte) error {
|
|
conns := s.Connection.GetAll()
|
|
for _, conn := range conns {
|
|
if err := s.sendMessage(conn, data); err != nil {
|
|
s.Logger.Error(s.ctx, fmt.Sprintf("Send to %s failed: %v", conn.Id, err))
|
|
// 继续发送给其他连接
|
|
}
|
|
}
|
|
return nil
|
|
}
|
|
|
|
// sendMessage 发送消息
|
|
func (s *TCPServer) sendMessage(conn *TcpConnection, data []byte) error {
|
|
conn.Mutex.Lock()
|
|
defer conn.Mutex.Unlock()
|
|
|
|
// 设置写入超时
|
|
conn.Server.SetWriteDeadline(time.Now().Add(s.Config.WriteTimeout))
|
|
|
|
// 发送数据
|
|
_, err := conn.Server.Write(data)
|
|
if err != nil {
|
|
return err
|
|
}
|
|
|
|
// 更新最后使用时间
|
|
conn.LastUsed = time.Now()
|
|
return nil
|
|
}
|
|
|
|
// Kick 强制退出客户端
|
|
func (s *TCPServer) Kick(connID string) error {
|
|
conn := s.Connection.Get(connID)
|
|
if conn == nil {
|
|
return fmt.Errorf("connection not found: %s", connID)
|
|
}
|
|
|
|
// 关闭连接
|
|
conn.Server.Close()
|
|
// 从连接池移除
|
|
s.Connection.Remove(connID)
|
|
|
|
s.Logger.Info(s.ctx, fmt.Sprintf("Kicked connection: %s", connID))
|
|
return nil
|
|
}
|
|
|
|
// Add 添加连接到连接池
|
|
func (p *ConnectionPool) Add(conn *TcpConnection) {
|
|
p.mutex.Lock()
|
|
defer p.mutex.Unlock()
|
|
p.connections[conn.Id] = conn
|
|
}
|
|
|
|
// Get 获取连接
|
|
func (p *ConnectionPool) Get(connID string) *TcpConnection {
|
|
p.mutex.RLock()
|
|
defer p.mutex.RUnlock()
|
|
return p.connections[connID]
|
|
}
|
|
|
|
// GetAll 获取所有连接
|
|
func (p *ConnectionPool) GetAll() []*TcpConnection {
|
|
p.mutex.RLock()
|
|
defer p.mutex.RUnlock()
|
|
|
|
conns := make([]*TcpConnection, 0, len(p.connections))
|
|
for _, conn := range p.connections {
|
|
conns = append(conns, conn)
|
|
}
|
|
return conns
|
|
}
|
|
|
|
// Remove 从连接池移除连接
|
|
func (p *ConnectionPool) Remove(connID string) {
|
|
p.mutex.Lock()
|
|
defer p.mutex.Unlock()
|
|
delete(p.connections, connID)
|
|
}
|
|
|
|
// Clear 清空连接池
|
|
func (p *ConnectionPool) Clear() {
|
|
p.mutex.Lock()
|
|
defer p.mutex.Unlock()
|
|
for connID, conn := range p.connections {
|
|
conn.Server.Close()
|
|
delete(p.connections, connID)
|
|
}
|
|
}
|
|
|
|
// Count 获取连接数量
|
|
func (p *ConnectionPool) Count() int {
|
|
p.mutex.RLock()
|
|
defer p.mutex.RUnlock()
|
|
return len(p.connections)
|
|
}
|