package overflow_gateway_websocket

import (
	"fmt"
	"net/http"
	"sync"

	"git.loafle.net/overflow/overflow_gateway_websocket/websocket"
	"github.com/valyala/fasthttp"
)

// Server accepts WebSocket connections and dispatches each one to the
// SocketOptions registered for the request's exact URL path.
type Server interface {
	// ListenAndServe starts the socket bookkeeping goroutine and then serves
	// HTTP on addr via fasthttp, returning whatever fasthttp returns.
	ListenAndServe(addr string) error
	// HandleSocket registers o for requests whose path equals pattern.
	// The handler table is not synchronized with request handling, so
	// register all patterns before calling ListenAndServe.
	HandleSocket(pattern string, o *SocketOptions)
}

// server is the default Server implementation.
type server struct {
	_option   *ServerOptions
	_upgrader *websocket.Upgrader
	// _handlers maps an exact request path to its (validated) socket options.
	_handlers map[string]*SocketOptions

	// _sockets holds live sockets keyed by Socket.ID(). It is mutated only by
	// listenHandler, under _socketsMtx, so other goroutines may read it while
	// holding the same lock.
	_sockets    map[string]Socket
	_socketsMtx sync.Mutex

	// _addSocketCh and _removeSocketCh are unbuffered: a send blocks until
	// listenHandler receives it.
	_addSocketCh    chan Socket
	_removeSocketCh chan Socket
}

// NewServer creates a Server from o. The options are passed through
// o.Validate() and the WebSocket upgrader is configured from the result.
func NewServer(o *ServerOptions) Server {
	s := &server{
		_option:         o.Validate(),
		_handlers:       make(map[string]*SocketOptions, 1),
		_sockets:        make(map[string]Socket, 100),
		_addSocketCh:    make(chan Socket),
		_removeSocketCh: make(chan Socket),
	}

	s._upgrader = &websocket.Upgrader{
		HandshakeTimeout:  s._option.HandshakeTimeout,
		ReadBufferSize:    s._option.ReadBufferSize,
		WriteBufferSize:   s._option.WriteBufferSize,
		CheckOrigin:       s._option.OnCheckOrigin,
		Error:             s.onError,
		EnableCompression: s._option.EnableCompression,
	}

	return s
}

// addSocket hands soc to listenHandler for registration. It blocks until
// listenHandler receives it, so it must only be called while ListenAndServe
// is running.
func (s *server) addSocket(soc Socket) {
	s._addSocketCh <- soc
}

// removeSocket hands soc to listenHandler for unregistration. Like addSocket,
// it blocks until listenHandler receives it.
func (s *server) removeSocket(soc Socket) {
	s._removeSocketCh <- soc
}

// onError writes an HTTP error reply containing the standard status text for
// status (advertising WebSocket protocol version 13) and then forwards the
// failure to the user's OnError hook. It is also installed as the upgrader's
// Error callback.
func (s *server) onError(ctx *fasthttp.RequestCtx, status int, reason error) {
	ctx.Response.Header.Set("Sec-Websocket-Version", "13")
	ctx.Error(http.StatusText(status), status)

	s._option.OnError(ctx, status, reason)
}

// onDisconnected unregisters soc and then notifies the user's OnDisconnected
// hook. HandleSocket installs it on every SocketOptions it registers.
func (s *server) onDisconnected(soc Socket) {
	s.removeSocket(soc)

	s._option.OnDisconnected(soc)
}

// onConnection is the fasthttp request handler. It looks up the socket
// options for the exact request path, replies 404 if none is registered, and
// otherwise upgrades the connection, registers the new socket, notifies the
// OnConnection hook and runs the socket.
func (s *server) onConnection(ctx *fasthttp.RequestCtx) {
	path := string(ctx.Path())

	co, ok := s._handlers[path]
	if !ok {
		// Give the OnError hook a real reason instead of nil so it can be
		// logged or inspected without a nil check.
		s.onError(ctx, fasthttp.StatusNotFound,
			fmt.Errorf("no socket handler registered for path %q", path))
		return
	}

	s._upgrader.Upgrade(ctx, nil, func(conn *websocket.Conn, err error) {
		if err != nil {
			s.onError(ctx, fasthttp.StatusInternalServerError, err)
			return
		}

		// NOTE(review): if the upgrader invokes this callback after the
		// fasthttp handler has returned (e.g. from a hijack handler), ctx
		// may no longer be valid here — confirm against the websocket package.
		id := s._option.IDGenerator(ctx)
		soc := NewSocket(id, path, co, conn)
		s.addSocket(soc)

		s._option.OnConnection(soc)

		// presumably blocks for the lifetime of the connection — TODO confirm
		soc.run()
	})
}

// listenHandler serializes socket registration and removal requests coming
// from addSocket/removeSocket into _sockets. It loops forever; there is no
// stop signal, so it outlives ListenAndServe if that returns.
func (s *server) listenHandler() {
	for {
		select {
		case soc := <-s._addSocketCh:
			s._socketsMtx.Lock()
			s._sockets[soc.ID()] = soc
			s._socketsMtx.Unlock()

		case soc := <-s._removeSocketCh:
			s._socketsMtx.Lock()
			delete(s._sockets, soc.ID())
			s._socketsMtx.Unlock()
		}
	}
}

// Sockets returns a snapshot of the currently registered sockets keyed by ID.
// The map is copied under _socketsMtx, so callers can read or modify it
// without racing listenHandler; it does not reflect later connects or
// disconnects.
func (s *server) Sockets() map[string]Socket {
	s._socketsMtx.Lock()
	defer s._socketsMtx.Unlock()

	snapshot := make(map[string]Socket, len(s._sockets))
	for id, soc := range s._sockets {
		snapshot[id] = soc
	}
	return snapshot
}

// HandleSocket registers o for the exact request path pattern. It overwrites
// o.onDisconnected so the server can unregister sockets, and stores the result
// of o.Validate().
func (s *server) HandleSocket(pattern string, o *SocketOptions) {
	o.onDisconnected = s.onDisconnected

	s._handlers[pattern] = o.Validate()
}

// ListenAndServe starts listenHandler and then serves HTTP on addr with
// onConnection as the request handler.
func (s *server) ListenAndServe(addr string) error {
	go s.listenHandler()

	return fasthttp.ListenAndServe(addr, s.onConnection)
}