gf-common/server/ws/websocket.go

489 lines
13 KiB
Go
Raw Blame History

This file contains ambiguous Unicode characters!

This file contains ambiguous Unicode characters that may be confused with others in your current locale. If your use case is intentional and legitimate, you can safely ignore this warning. Use the Escape button to highlight these characters.

package ws
import (
"context"
"errors"
"fmt"
"log"
"net/http"
"sync"
"time"
"github.com/gogf/gf/v2/encoding/gjson"
"github.com/gogf/gf/v2/os/gctx"
"github.com/gogf/gf/v2/os/gtime"
"github.com/gogf/gf/v2/os/gtimer"
"github.com/gogf/gf/v2/text/gstr"
"github.com/gogf/gf/v2/util/gconv"
"github.com/gorilla/websocket"
)
// 常量定义:默认配置
const (
// 默认读写缓冲区大小(字节)
DefaultReadBufferSize = 1024
DefaultWriteBufferSize = 1024
// 默认心跳间隔每30秒发送一次心跳
DefaultHeartbeatInterval = 30 * time.Second
// 默认心跳超时60秒未收到客户端心跳响应则关闭连接
DefaultHeartbeatTimeout = 60 * time.Second
// 默认读写超时(秒)
DefaultReadTimeout = 60 * time.Second
DefaultWriteTimeout = 10 * time.Second
// 消息类型
MessageTypeText = websocket.TextMessage
MessageTypeBinary = websocket.BinaryMessage
// 心跳最大重试次数
HeartbeatMaxRetry = 3
)
// Config WebSocket服务端配置
type Config struct {
// 读写缓冲区大小
ReadBufferSize int
WriteBufferSize int
// 跨域配置是否允许所有跨域生产环境建议指定Origin
AllowAllOrigins bool
// 允许的跨域Origin列表AllowAllOrigins=false时生效
AllowedOrigins []string
// 心跳配置
HeartbeatInterval time.Duration // 心跳发送间隔
HeartbeatTimeout time.Duration // 心跳超时时间
// 读写超时
ReadTimeout time.Duration
WriteTimeout time.Duration
MsgType int // 发送消息的默认类型
HeartbeatValue string // 心跳消息的标识字段值(如"heartbeat"、"pong"
HeartbeatKey string // 心跳消息的标识字段名(如"type"
}
// 默认配置
func DefaultConfig() *Config {
return &Config{
ReadBufferSize: DefaultReadBufferSize,
WriteBufferSize: DefaultWriteBufferSize,
AllowAllOrigins: true,
AllowedOrigins: []string{},
HeartbeatInterval: DefaultHeartbeatInterval,
HeartbeatTimeout: DefaultHeartbeatTimeout,
ReadTimeout: DefaultReadTimeout,
WriteTimeout: DefaultWriteTimeout,
MsgType: MessageTypeText,
HeartbeatValue: "heartbeat",
HeartbeatKey: "type", // 心跳消息的标识字段名,默认"type"
}
}
// Connection WebSocket连接结构体
type Connection struct {
conn *websocket.Conn // 底层连接
connID string // 唯一连接ID
manager *Manager // 所属管理器
createTime time.Time // 连接创建时间
heartbeatChan time.Time // 心跳通道(用于检测客户端响应)
heartbeatTime *gtimer.Entry
ctx context.Context // 上下文
cancel context.CancelFunc // 上下文取消函数
writeMutex sync.Mutex // 写消息互斥锁(防止并发写)
heartbeatRetry int // 心跳发送重试次数
}
// Manager WebSocket连接管理器
type Manager struct {
config *Config // 配置
upgrader *websocket.Upgrader // HTTP升级器
connections map[string]*Connection // 所有在线连接connID -> Connection
mutex sync.RWMutex // 读写锁保护connections
// 业务回调:收到消息时触发(用户自定义处理逻辑)
OnMessage func(connID string, msgType int, data any)
// 业务回调:连接建立时触发
OnConnect func(connID string)
// 业务回调:连接关闭时触发
OnDisconnect func(connID string, err error)
}
// Merge 合并配置,用传入的配置覆盖非零值部分
func (c *Config) Merge(other *Config) *Config {
result := *c // 复制当前配置
if other == nil {
return &result
}
if other.ReadBufferSize > 0 {
result.ReadBufferSize = other.ReadBufferSize
}
if other.WriteBufferSize > 0 {
result.WriteBufferSize = other.WriteBufferSize
}
if other.HeartbeatInterval > 0 {
result.HeartbeatInterval = other.HeartbeatInterval
}
if other.HeartbeatTimeout > 0 {
result.HeartbeatTimeout = other.HeartbeatTimeout
}
if other.ReadTimeout > 0 {
result.ReadTimeout = other.ReadTimeout
}
if other.WriteTimeout > 0 {
result.WriteTimeout = other.WriteTimeout
}
if other.AllowAllOrigins {
result.AllowAllOrigins = other.AllowAllOrigins
}
if other.HeartbeatValue != "" {
result.HeartbeatValue = other.HeartbeatValue
}
if other.HeartbeatKey != "" {
result.HeartbeatKey = other.HeartbeatKey
}
if len(other.AllowedOrigins) > 0 {
result.AllowedOrigins = other.AllowedOrigins
}
if other.MsgType != 0 {
result.MsgType = other.MsgType
}
return &result
}
// NewManager 创建连接管理器
func NewManager(config *Config) *Manager {
defaultConfig := DefaultConfig()
finalConfig := defaultConfig.Merge(config)
// 初始化升级器
upgrader := &websocket.Upgrader{
ReadBufferSize: config.ReadBufferSize,
WriteBufferSize: config.WriteBufferSize,
CheckOrigin: func(r *http.Request) bool {
// 跨域检查
if config.AllowAllOrigins {
return true
}
origin := r.Header.Get("Origin")
for _, allowed := range finalConfig.AllowedOrigins {
if origin == allowed {
return true
}
}
return false
},
}
return &Manager{
config: finalConfig,
upgrader: upgrader,
connections: make(map[string]*Connection),
mutex: sync.RWMutex{},
// 默认回调(用户可覆盖)
OnMessage: func(connID string, msgType int, data any) {
log.Printf("[默认回调] 收到连接[%s]消息:%s", connID, gconv.String(data))
},
OnConnect: func(connID string) {
log.Printf("[默认回调] 连接[%s]已建立", connID)
},
OnDisconnect: func(connID string, err error) {
log.Printf("[默认回调] 连接[%s]已关闭:%v", connID, err)
},
}
}
// Upgrade HTTP升级为WebSocket连接
// connID自定义连接唯一ID如用户ID、设备ID
func (m *Manager) Upgrade(w http.ResponseWriter, r *http.Request, connID string) (*Connection, error) {
if connID == "" {
return nil, errors.New("连接ID不能为空")
}
// 检查连接ID是否已存在
m.mutex.RLock()
_, exists := m.connections[connID]
m.mutex.RUnlock()
if exists {
return nil, fmt.Errorf("连接ID[%s]已存在", connID)
}
// 升级HTTP连接
conn, err := m.upgrader.Upgrade(w, r, nil)
if err != nil {
return nil, fmt.Errorf("升级WebSocket失败%w", err)
}
// 创建上下文(用于优雅关闭)
ctx, cancel := context.WithCancel(context.Background())
// 创建连接实例
wsConn := &Connection{
conn: conn,
connID: connID,
manager: m,
createTime: time.Now(),
heartbeatChan: time.Now(), // 缓冲1防止阻塞
ctx: ctx,
cancel: cancel,
writeMutex: sync.Mutex{},
heartbeatRetry: 0,
}
wsConn.heartbeatTime = gtimer.AddSingleton(gctx.New(), m.config.HeartbeatTimeout, func(ctx context.Context) {
log.Printf("[心跳检测] 连接[%s]已关闭:心跳超时", wsConn.connID)
wsConn.heartbeatTime.Close()
wsConn.heartbeatTime.Stop()
wsConn.heartbeatTime = nil
wsConn.ctx.Done()
wsConn.Close(fmt.Errorf("心跳超时"))
})
// 添加到管理器
m.mutex.Lock()
m.connections[connID] = wsConn
m.mutex.Unlock()
// 触发连接建立回调
m.OnConnect(connID)
// 启动读消息协程
go wsConn.ReadPump()
// 启动写消息协程(处理异步发送)
go wsConn.WritePump()
// 启动心跳检测协程
go wsConn.Heartbeat()
return wsConn, nil
}
// ReadPump 读取客户端消息(持续运行)
func (c *Connection) ReadPump() {
defer func() {
// 发生panic时关闭连接
if err := recover(); err != nil {
log.Printf("连接[%s]读消息协程panic%v", c.connID, err)
}
// 关闭连接并清理
c.Close(fmt.Errorf("读消息协程退出"))
}()
// 循环读取消息
for {
select {
case <-c.ctx.Done():
return // 上下文已取消,退出
default:
// 设置读超时(每次读取前重置,防止长时间无消息超时)
c.conn.SetReadDeadline(time.Now().Add(c.manager.config.ReadTimeout))
// 读取客户端消息
msgType, data, err := c.conn.ReadMessage()
if err != nil {
// 区分正常关闭和异常错误
var closeErr *websocket.CloseError
if errors.As(err, &closeErr) {
c.Close(fmt.Errorf("客户端主动关闭:%s代码%d", closeErr.Text, closeErr.Code))
} else {
c.Close(fmt.Errorf("读取消息失败:%w", err))
}
return
}
// 尝试解析JSON格式的心跳消息精准判断替代包含判断
isHeartbeat := false
// 先尝试解析为JSON对象
var msgMap map[string]interface{}
if err := gjson.DecodeTo(data, &msgMap); err == nil {
// 获取心跳标识字段的值
heartbeatValue := gconv.String(msgMap[c.manager.config.HeartbeatKey])
if heartbeatValue == c.manager.config.HeartbeatValue {
isHeartbeat = true
}
} else {
// 非JSON格式降级为包含判断兼容纯文本心跳
str := gconv.String(data)
if gstr.Contains(str, c.manager.config.HeartbeatValue) {
isHeartbeat = true
}
}
if isHeartbeat {
log.Printf("[心跳] 收到连接[%s]心跳消息:%s", c.connID, string(data))
// 心跳消息:重置重试次数 + 发送心跳信号 + 重置读超时
js, err := gjson.Encode(&Msg[any]{c.manager.config.HeartbeatValue, nil, gtime.Timestamp()})
if err != nil {
log.Printf("[心跳] 客户端[%s]json编码失败", c.connID)
continue
}
err = c.Send(js)
if err != nil {
log.Printf("[心跳] 客户端[%s]发送心跳消息失败", c.connID)
continue
}
c.heartbeatTime.Reset()
continue // 跳过业务回调
}
// 非心跳消息:触发业务回调
c.manager.OnMessage(c.connID, msgType, data)
}
}
}
type Msg[T any] struct {
Type string `json:"type"`
Data T `json:"data"`
Timestamp int64 `json:"timestamp"`
}
// WritePump 处理异步写消息(持续运行)
// 扩展为监听写队列,防止消息丢失
func (c *Connection) WritePump() {
defer func() {
if err := recover(); err != nil {
log.Printf("连接[%s]写消息协程panic%v", c.connID, err)
}
}()
// 暂时保持简化,实际可扩展为带缓冲的写队列
<-c.ctx.Done()
}
// Heartbeat 心跳检测(持续运行)
func (c *Connection) Heartbeat() {
defer func() {
if err := recover(); err != nil {
log.Printf("连接[%s]心跳协程panic%v", c.connID, err)
}
}()
c.heartbeatTime.Start()
}
// Send 发送消息到客户端(线程安全)
func (c *Connection) Send(data []byte) error {
select {
case <-c.ctx.Done():
return errors.New("连接已关闭,无法发送消息")
default:
// 加锁防止并发写
c.writeMutex.Lock()
defer c.writeMutex.Unlock()
// 设置写超时
c.conn.SetWriteDeadline(time.Now().Add(c.manager.config.WriteTimeout))
// 发送消息(使用连接的默认类型,支持动态调整)
err := c.conn.WriteMessage(c.manager.config.MsgType, data)
if err != nil {
return fmt.Errorf("发送消息失败:%w", err)
}
return nil
}
}
// Close 关闭连接(优雅清理)
func (c *Connection) Close(err error) {
// 防止重复关闭
select {
case <-c.ctx.Done():
return
default:
}
// 取消上下文(终止所有协程)
c.cancel()
// 关闭底层连接(友好关闭)
_ = c.conn.WriteMessage(websocket.CloseMessage, websocket.FormatCloseMessage(websocket.CloseNormalClosure, err.Error()))
_ = c.conn.Close()
// 从管理器移除
c.manager.mutex.Lock()
delete(c.manager.connections, c.connID)
c.manager.mutex.Unlock()
// 触发断开回调
c.manager.OnDisconnect(c.connID, err)
log.Printf("连接[%s]已关闭,当前在线数:%d原因%v", c.connID, c.manager.GetOnlineCount(), err)
}
// GetOnlineCount 获取在线连接数
func (m *Manager) GetOnlineCount() int {
m.mutex.RLock()
defer m.mutex.RUnlock()
return len(m.connections)
}
// Broadcast 广播消息到所有在线连接
func (m *Manager) Broadcast(data []byte) error {
m.mutex.RLock()
defer m.mutex.RUnlock()
if len(m.connections) == 0 {
return errors.New("无在线连接")
}
// 并发发送(非阻塞)
var wg sync.WaitGroup
var errMsg string
for _, conn := range m.connections {
wg.Add(1)
go func(c *Connection) {
defer wg.Done()
if err := c.Send(data); err != nil {
errMsg += fmt.Sprintf("连接[%s]广播失败:%v", c.connID, err)
}
}(conn)
}
wg.Wait()
if errMsg != "" {
return errors.New(errMsg)
}
return nil
}
// SendToConn 定向发送消息到指定连接
func (m *Manager) SendToConn(connID string, data []byte) error {
m.mutex.RLock()
conn, exists := m.connections[connID]
m.mutex.RUnlock()
if !exists {
return fmt.Errorf("连接[%s]不存在", connID)
}
return conn.Send(data)
}
func (m *Manager) GetAllConn() map[string]*Connection {
m.mutex.RLock()
defer m.mutex.RUnlock()
// 返回副本,防止外部修改
connCopy := make(map[string]*Connection, len(m.connections))
for k, v := range m.connections {
connCopy[k] = v
}
return connCopy
}
func (m *Manager) GetConn(connID string) *Connection {
m.mutex.RLock()
defer m.mutex.RUnlock()
return m.connections[connID]
}
// CloseAll 关闭所有连接
func (m *Manager) CloseAll() {
m.mutex.RLock()
connIDs := make([]string, 0, len(m.connections))
for connID := range m.connections {
connIDs = append(connIDs, connID)
}
m.mutex.RUnlock()
for _, connID := range connIDs {
m.mutex.RLock()
conn := m.connections[connID]
m.mutex.RUnlock()
if conn != nil {
conn.Close(errors.New("服务端主动关闭所有连接"))
}
}
}