package ws import ( "context" "errors" "fmt" "log" "net/http" "sync" "time" "github.com/gogf/gf/v2/encoding/gjson" "github.com/gogf/gf/v2/os/gctx" "github.com/gogf/gf/v2/os/gtime" "github.com/gogf/gf/v2/os/gtimer" "github.com/gogf/gf/v2/text/gstr" "github.com/gogf/gf/v2/util/gconv" "github.com/gorilla/websocket" ) // 常量定义:默认配置 const ( // 默认读写缓冲区大小(字节) DefaultReadBufferSize = 1024 DefaultWriteBufferSize = 1024 // 默认心跳间隔(秒):每30秒发送一次心跳 DefaultHeartbeatInterval = 30 * time.Second // 默认心跳超时(秒):60秒未收到客户端心跳响应则关闭连接 DefaultHeartbeatTimeout = 60 * time.Second // 默认读写超时(秒) DefaultReadTimeout = 60 * time.Second DefaultWriteTimeout = 10 * time.Second // 消息类型 MessageTypeText = websocket.TextMessage MessageTypeBinary = websocket.BinaryMessage // 心跳最大重试次数 HeartbeatMaxRetry = 3 ) // Config WebSocket服务端配置 type Config struct { // 读写缓冲区大小 ReadBufferSize int WriteBufferSize int // 跨域配置:是否允许所有跨域(生产环境建议指定Origin) AllowAllOrigins bool // 允许的跨域Origin列表(AllowAllOrigins=false时生效) AllowedOrigins []string // 心跳配置 HeartbeatInterval time.Duration // 心跳发送间隔 HeartbeatTimeout time.Duration // 心跳超时时间 // 读写超时 ReadTimeout time.Duration WriteTimeout time.Duration MsgType int // 发送消息的默认类型 HeartbeatValue string // 心跳消息的标识字段值(如"heartbeat"、"pong") HeartbeatKey string // 心跳消息的标识字段名(如"type") } // 默认配置 func DefaultConfig() *Config { return &Config{ ReadBufferSize: DefaultReadBufferSize, WriteBufferSize: DefaultWriteBufferSize, AllowAllOrigins: true, AllowedOrigins: []string{}, HeartbeatInterval: DefaultHeartbeatInterval, HeartbeatTimeout: DefaultHeartbeatTimeout, ReadTimeout: DefaultReadTimeout, WriteTimeout: DefaultWriteTimeout, MsgType: MessageTypeText, HeartbeatValue: "heartbeat", HeartbeatKey: "type", // 心跳消息的标识字段名,默认"type" } } // Connection WebSocket连接结构体 type Connection struct { conn *websocket.Conn // 底层连接 connID string // 唯一连接ID manager *Manager // 所属管理器 createTime time.Time // 连接创建时间 heartbeatChan time.Time // 心跳通道(用于检测客户端响应) heartbeatTime *gtimer.Entry ctx context.Context // 上下文 cancel context.CancelFunc // 上下文取消函数 writeMutex sync.Mutex // 写消息互斥锁(防止并发写) heartbeatRetry int // 心跳发送重试次数 } // Manager WebSocket连接管理器 type Manager struct { config *Config // 配置 upgrader *websocket.Upgrader // HTTP升级器 connections map[string]*Connection // 所有在线连接(connID -> Connection) mutex sync.RWMutex // 读写锁(保护connections) // 业务回调:收到消息时触发(用户自定义处理逻辑) OnMessage func(connID string, msgType int, data any) // 业务回调:连接建立时触发 OnConnect func(connID string) // 业务回调:连接关闭时触发 OnDisconnect func(connID string, err error) } // Merge 合并配置,用传入的配置覆盖非零值部分 func (c *Config) Merge(other *Config) *Config { result := *c // 复制当前配置 if other == nil { return &result } if other.ReadBufferSize > 0 { result.ReadBufferSize = other.ReadBufferSize } if other.WriteBufferSize > 0 { result.WriteBufferSize = other.WriteBufferSize } if other.HeartbeatInterval > 0 { result.HeartbeatInterval = other.HeartbeatInterval } if other.HeartbeatTimeout > 0 { result.HeartbeatTimeout = other.HeartbeatTimeout } if other.ReadTimeout > 0 { result.ReadTimeout = other.ReadTimeout } if other.WriteTimeout > 0 { result.WriteTimeout = other.WriteTimeout } if other.AllowAllOrigins { result.AllowAllOrigins = other.AllowAllOrigins } if other.HeartbeatValue != "" { result.HeartbeatValue = other.HeartbeatValue } if other.HeartbeatKey != "" { result.HeartbeatKey = other.HeartbeatKey } if len(other.AllowedOrigins) > 0 { result.AllowedOrigins = other.AllowedOrigins } if other.MsgType != 0 { result.MsgType = other.MsgType } return &result } // NewManager 创建连接管理器 func NewManager(config *Config) *Manager { defaultConfig := DefaultConfig() finalConfig := defaultConfig.Merge(config) // 初始化升级器 upgrader := &websocket.Upgrader{ ReadBufferSize: config.ReadBufferSize, WriteBufferSize: config.WriteBufferSize, CheckOrigin: func(r *http.Request) bool { // 跨域检查 if config.AllowAllOrigins { return true } origin := r.Header.Get("Origin") for _, allowed := range finalConfig.AllowedOrigins { if origin == allowed { return true } } return false }, } return &Manager{ config: finalConfig, upgrader: upgrader, connections: make(map[string]*Connection), mutex: sync.RWMutex{}, // 默认回调(用户可覆盖) OnMessage: func(connID string, msgType int, data any) { log.Printf("[默认回调] 收到连接[%s]消息:%s", connID, gconv.String(data)) }, OnConnect: func(connID string) { log.Printf("[默认回调] 连接[%s]已建立", connID) }, OnDisconnect: func(connID string, err error) { log.Printf("[默认回调] 连接[%s]已关闭:%v", connID, err) }, } } // Upgrade HTTP升级为WebSocket连接 // connID:自定义连接唯一ID(如用户ID、设备ID) func (m *Manager) Upgrade(w http.ResponseWriter, r *http.Request, connID string) (*Connection, error) { if connID == "" { return nil, errors.New("连接ID不能为空") } // 检查连接ID是否已存在 m.mutex.RLock() _, exists := m.connections[connID] m.mutex.RUnlock() if exists { return nil, fmt.Errorf("连接ID[%s]已存在", connID) } // 升级HTTP连接 conn, err := m.upgrader.Upgrade(w, r, nil) if err != nil { return nil, fmt.Errorf("升级WebSocket失败:%w", err) } // 创建上下文(用于优雅关闭) ctx, cancel := context.WithCancel(context.Background()) // 创建连接实例 wsConn := &Connection{ conn: conn, connID: connID, manager: m, createTime: time.Now(), heartbeatChan: time.Now(), // 缓冲1,防止阻塞 ctx: ctx, cancel: cancel, writeMutex: sync.Mutex{}, heartbeatRetry: 0, } wsConn.heartbeatTime = gtimer.AddSingleton(gctx.New(), m.config.HeartbeatTimeout, func(ctx context.Context) { log.Printf("[心跳检测] 连接[%s]已关闭:心跳超时", wsConn.connID) wsConn.heartbeatTime.Close() wsConn.heartbeatTime.Stop() wsConn.heartbeatTime = nil wsConn.ctx.Done() wsConn.Close(fmt.Errorf("心跳超时")) }) // 添加到管理器 m.mutex.Lock() m.connections[connID] = wsConn m.mutex.Unlock() // 触发连接建立回调 m.OnConnect(connID) // 启动读消息协程 go wsConn.ReadPump() // 启动写消息协程(处理异步发送) go wsConn.WritePump() // 启动心跳检测协程 go wsConn.Heartbeat() return wsConn, nil } // ReadPump 读取客户端消息(持续运行) func (c *Connection) ReadPump() { defer func() { // 发生panic时关闭连接 if err := recover(); err != nil { log.Printf("连接[%s]读消息协程panic:%v", c.connID, err) } // 关闭连接并清理 c.Close(fmt.Errorf("读消息协程退出")) }() // 循环读取消息 for { select { case <-c.ctx.Done(): return // 上下文已取消,退出 default: // 设置读超时(每次读取前重置,防止长时间无消息超时) c.conn.SetReadDeadline(time.Now().Add(c.manager.config.ReadTimeout)) // 读取客户端消息 msgType, data, err := c.conn.ReadMessage() if err != nil { // 区分正常关闭和异常错误 var closeErr *websocket.CloseError if errors.As(err, &closeErr) { c.Close(fmt.Errorf("客户端主动关闭:%s(代码:%d)", closeErr.Text, closeErr.Code)) } else { c.Close(fmt.Errorf("读取消息失败:%w", err)) } return } // 尝试解析JSON格式的心跳消息(精准判断,替代包含判断) isHeartbeat := false // 先尝试解析为JSON对象 var msgMap map[string]interface{} if err := gjson.DecodeTo(data, &msgMap); err == nil { // 获取心跳标识字段的值 heartbeatValue := gconv.String(msgMap[c.manager.config.HeartbeatKey]) if heartbeatValue == c.manager.config.HeartbeatValue { isHeartbeat = true } } else { // 非JSON格式,降级为包含判断(兼容纯文本心跳) str := gconv.String(data) if gstr.Contains(str, c.manager.config.HeartbeatValue) { isHeartbeat = true } } if isHeartbeat { log.Printf("[心跳] 收到连接[%s]心跳消息:%s", c.connID, string(data)) // 心跳消息:重置重试次数 + 发送心跳信号 + 重置读超时 js, err := gjson.Encode(&Msg[any]{c.manager.config.HeartbeatValue, nil, gtime.Timestamp()}) if err != nil { log.Printf("[心跳] 客户端[%s]json编码失败", c.connID) continue } err = c.Send(js) if err != nil { log.Printf("[心跳] 客户端[%s]发送心跳消息失败", c.connID) continue } c.heartbeatTime.Reset() continue // 跳过业务回调 } // 非心跳消息:触发业务回调 c.manager.OnMessage(c.connID, msgType, data) } } } type Msg[T any] struct { Type string `json:"type"` Data T `json:"data"` Timestamp int64 `json:"timestamp"` } // WritePump 处理异步写消息(持续运行) // 扩展为监听写队列,防止消息丢失 func (c *Connection) WritePump() { defer func() { if err := recover(); err != nil { log.Printf("连接[%s]写消息协程panic:%v", c.connID, err) } }() // 暂时保持简化,实际可扩展为带缓冲的写队列 <-c.ctx.Done() } // Heartbeat 心跳检测(持续运行) func (c *Connection) Heartbeat() { defer func() { if err := recover(); err != nil { log.Printf("连接[%s]心跳协程panic:%v", c.connID, err) } }() c.heartbeatTime.Start() } // Send 发送消息到客户端(线程安全) func (c *Connection) Send(data []byte) error { select { case <-c.ctx.Done(): return errors.New("连接已关闭,无法发送消息") default: // 加锁防止并发写 c.writeMutex.Lock() defer c.writeMutex.Unlock() // 设置写超时 c.conn.SetWriteDeadline(time.Now().Add(c.manager.config.WriteTimeout)) // 发送消息(使用连接的默认类型,支持动态调整) err := c.conn.WriteMessage(c.manager.config.MsgType, data) if err != nil { return fmt.Errorf("发送消息失败:%w", err) } return nil } } // Close 关闭连接(优雅清理) func (c *Connection) Close(err error) { // 防止重复关闭 select { case <-c.ctx.Done(): return default: } // 取消上下文(终止所有协程) c.cancel() // 关闭底层连接(友好关闭) _ = c.conn.WriteMessage(websocket.CloseMessage, websocket.FormatCloseMessage(websocket.CloseNormalClosure, err.Error())) _ = c.conn.Close() // 从管理器移除 c.manager.mutex.Lock() delete(c.manager.connections, c.connID) c.manager.mutex.Unlock() // 触发断开回调 c.manager.OnDisconnect(c.connID, err) log.Printf("连接[%s]已关闭,当前在线数:%d,原因:%v", c.connID, c.manager.GetOnlineCount(), err) } // GetOnlineCount 获取在线连接数 func (m *Manager) GetOnlineCount() int { m.mutex.RLock() defer m.mutex.RUnlock() return len(m.connections) } // Broadcast 广播消息到所有在线连接 func (m *Manager) Broadcast(data []byte) error { m.mutex.RLock() defer m.mutex.RUnlock() if len(m.connections) == 0 { return errors.New("无在线连接") } // 并发发送(非阻塞) var wg sync.WaitGroup var errMsg string for _, conn := range m.connections { wg.Add(1) go func(c *Connection) { defer wg.Done() if err := c.Send(data); err != nil { errMsg += fmt.Sprintf("连接[%s]广播失败:%v;", c.connID, err) } }(conn) } wg.Wait() if errMsg != "" { return errors.New(errMsg) } return nil } // SendToConn 定向发送消息到指定连接 func (m *Manager) SendToConn(connID string, data []byte) error { m.mutex.RLock() conn, exists := m.connections[connID] m.mutex.RUnlock() if !exists { return fmt.Errorf("连接[%s]不存在", connID) } return conn.Send(data) } func (m *Manager) GetAllConn() map[string]*Connection { m.mutex.RLock() defer m.mutex.RUnlock() // 返回副本,防止外部修改 connCopy := make(map[string]*Connection, len(m.connections)) for k, v := range m.connections { connCopy[k] = v } return connCopy } func (m *Manager) GetConn(connID string) *Connection { m.mutex.RLock() defer m.mutex.RUnlock() return m.connections[connID] } // CloseAll 关闭所有连接 func (m *Manager) CloseAll() { m.mutex.RLock() connIDs := make([]string, 0, len(m.connections)) for connID := range m.connections { connIDs = append(connIDs, connID) } m.mutex.RUnlock() for _, connID := range connIDs { m.mutex.RLock() conn := m.connections[connID] m.mutex.RUnlock() if conn != nil { conn.Close(errors.New("服务端主动关闭所有连接")) } } }