package overflow_gateway_websocket

import (
	"context"
	"net/http"
	"sync"

	"git.loafle.net/commons_go/logging"
	channelUtil "git.loafle.net/commons_go/util/channel"
	"git.loafle.net/overflow/overflow_gateway_websocket/websocket"
	"github.com/valyala/fasthttp"
	"go.uber.org/zap"
)

// socketsChannelAction is a socket-registry change sent from connection
// handlers to listenHandler. The embedded Action's Type selects create or
// delete; s is the socket being added or removed.
type socketsChannelAction struct {
	channelUtil.Action
	s Socket
}

// Server accepts websocket connections and dispatches each one to the
// SocketHandler registered for the request path.
type Server interface {
	// ListenAndServe starts the socket-registry goroutine and serves HTTP on
	// addr, returning whatever fasthttp.ListenAndServe returns.
	ListenAndServe(addr string) error
	// HandleSocket registers o for requests whose path equals pattern exactly.
	HandleSocket(pattern string, o SocketHandler)
}

type server struct {
	_ctx      context.Context
	_logger   *zap.Logger
	_sh       ServerHandler
	_upgrader *websocket.Upgrader
	// _handlers maps an exact request path to its SocketHandler.
	_handlers map[string]SocketHandler
	// _sockets maps socket ID to the live socket. It is written only by
	// listenHandler and read by Sockets; both hold _socketsMtx.
	_sockets    map[string]Socket
	_socketsMtx sync.RWMutex
	// _socketsCh carries registry changes to listenHandler (unbuffered).
	_socketsCh chan socketsChannelAction
}

// NewServer builds a Server whose lifetime is bound to ctx and whose
// behaviour (upgrade limits, origin check, ID generation, callbacks) is
// supplied by sh.
//
// sh.Validate() is called first (presumably to check or fill in sh's
// configuration; its exact behaviour is defined elsewhere). The websocket
// upgrader is configured from sh's getters, and upgrade errors are routed
// through onError.
func NewServer(ctx context.Context, sh ServerHandler) Server {
	sh.Validate()

	s := &server{
		_ctx:       ctx,
		_logger:    logging.WithContext(ctx),
		_sh:        sh,
		_handlers:  make(map[string]SocketHandler, 1),
		_sockets:   make(map[string]Socket, 100),
		_socketsCh: make(chan socketsChannelAction),
	}

	s._upgrader = &websocket.Upgrader{
		HandshakeTimeout:  s._sh.GetHandshakeTimeout(),
		ReadBufferSize:    s._sh.GetReadBufferSize(),
		WriteBufferSize:   s._sh.GetWriteBufferSize(),
		CheckOrigin:       s._sh.OnCheckOrigin,
		Error:             s.onError,
		EnableCompression: s._sh.GetEnableCompression(),
	}

	return s
}

// addSocket asks listenHandler to register soc under soc.ID().
func (s *server) addSocket(soc Socket) {
	ca := socketsChannelAction{s: soc}
	ca.Type = channelUtil.ActionTypeCreate
	s.sendSocketAction(ca)
}

// removeSocket asks listenHandler to drop the entry for soc.ID().
func (s *server) removeSocket(soc Socket) {
	ca := socketsChannelAction{s: soc}
	ca.Type = channelUtil.ActionTypeDelete
	s.sendSocketAction(ca)
}

// sendSocketAction hands ca to listenHandler. _socketsCh is unbuffered and
// listenHandler stops receiving once s._ctx is done. The send therefore also
// waits on s._ctx.Done(), so connection goroutines are not blocked forever
// after cancellation. In that case the change is dropped.
func (s *server) sendSocketAction(ca socketsChannelAction) {
	select {
	case s._socketsCh <- ca:
	case <-s._ctx.Done():
	}
}

// onError writes an HTTP error response with the given status, then forwards
// the failure to the ServerHandler. reason may be nil (onConnection passes
// nil for an unknown path).
func (s *server) onError(ctx *fasthttp.RequestCtx, status int, reason error) {
	ctx.Response.Header.Set("Sec-Websocket-Version", "13")
	ctx.Error(http.StatusText(status), status)
	s._sh.OnError(ctx, status, reason)
}

// onDisconnected is installed as the per-handler disconnect hook by
// HandleSocket. It unregisters soc, then notifies the ServerHandler.
func (s *server) onDisconnected(soc Socket) {
	s.removeSocket(soc)
	s._sh.OnDisconnected(soc)
}

// onConnection is the fasthttp request handler. It looks up the SocketHandler
// for the exact request path (404 if none), upgrades the connection, then
// creates, registers and runs the Socket.
func (s *server) onConnection(ctx *fasthttp.RequestCtx) {
	path := string(ctx.Path())
	co, ok := s._handlers[path]
	if !ok {
		s.onError(ctx, fasthttp.StatusNotFound, nil)
		return
	}

	// NOTE(review): ctx is used inside the upgrade callback (onError and
	// OnIDGenerate). If that callback runs after fasthttp has released or
	// hijacked the RequestCtx, this is unsafe. Confirm against the websocket
	// package's Upgrade implementation.
	s._upgrader.Upgrade(ctx, nil, func(conn *websocket.Conn, err error) {
		if err != nil {
			s.onError(ctx, fasthttp.StatusInternalServerError, err)
			return
		}

		id := s._sh.OnIDGenerate(ctx)
		soc := NewSocket(s._ctx, id, path, co, conn)
		s.addSocket(soc)
		s._sh.OnConnection(soc)

		soc.run()
	})
}

// listenHandler applies registry changes received on _socketsCh until s._ctx
// is done. ListenAndServe starts it in its own goroutine.
func (s *server) listenHandler() {
	for {
		select {
		case <-s._ctx.Done():
			return
		case ca := <-s._socketsCh:
			s.applySocketAction(ca)
		}
	}
}

// applySocketAction performs one create/delete on _sockets under the write
// lock. The socket's ID is read before locking so no foreign code runs while
// the lock is held.
func (s *server) applySocketAction(ca socketsChannelAction) {
	id := ca.s.ID()

	s._socketsMtx.Lock()
	defer s._socketsMtx.Unlock()

	switch ca.Type {
	case channelUtil.ActionTypeCreate:
		s._sockets[id] = ca.s
	case channelUtil.ActionTypeDelete:
		delete(s._sockets, id)
	}
}

// Sockets returns a snapshot of the registered sockets keyed by socket ID.
// The returned map is a copy: it is safe to use from any goroutine. It does
// not reflect later connects or disconnects, and changing it does not affect
// the registry.
func (s *server) Sockets() map[string]Socket {
	s._socketsMtx.RLock()
	defer s._socketsMtx.RUnlock()

	snapshot := make(map[string]Socket, len(s._sockets))
	for id, soc := range s._sockets {
		snapshot[id] = soc
	}
	return snapshot
}

// HandleSocket registers soch for requests whose path equals pattern. It
// wires the server's disconnect hook into soch and calls soch.Validate().
//
// The type assertion below panics if soch is not a *SocketHandlers. Intended
// to be called before ListenAndServe: _handlers is read by request
// goroutines without locking.
func (s *server) HandleSocket(pattern string, soch SocketHandler) {
	soch.(*SocketHandlers).onDisconnected = s.onDisconnected
	soch.Validate()

	s._handlers[pattern] = soch
}

// ListenAndServe starts the socket-registry goroutine (it stops when the
// server's context is done) and then serves HTTP on addr with onConnection.
// fasthttp.ListenAndServe is not given the server's context.
func (s *server) ListenAndServe(addr string) error {
	go s.listenHandler()

	return fasthttp.ListenAndServe(addr, s.onConnection)
}