package websocket import ( "context" "fmt" "net" "net/http" "sync" "time" "git.loafle.net/commons/logging-go" "git.loafle.net/commons/server-go" "github.com/valyala/fasthttp" ) type Server struct { ServerHandler ServerHandler ctx server.ServerCtx hs *fasthttp.Server upgrader *Upgrader connections sync.Map stopChan chan struct{} stopWg sync.WaitGroup } func (s *Server) ListenAndServe() error { var ( err error listener net.Listener ) if nil == s.ServerHandler { return fmt.Errorf("Server: server handler must be specified") } s.ServerHandler.Validate() if s.stopChan != nil { return fmt.Errorf(s.serverMessage("already running. Stop it before starting it again")) } s.ctx = s.ServerHandler.ServerCtx() if nil == s.ctx { return fmt.Errorf(s.serverMessage("ServerCtx is nil")) } s.hs = &fasthttp.Server{ Handler: s.httpHandler, Name: s.ServerHandler.GetName(), Concurrency: s.ServerHandler.GetConcurrency(), ReadBufferSize: s.ServerHandler.GetReadBufferSize(), WriteBufferSize: s.ServerHandler.GetWriteBufferSize(), ReadTimeout: s.ServerHandler.GetReadTimeout(), WriteTimeout: s.ServerHandler.GetWriteTimeout(), } s.upgrader = &Upgrader{ HandshakeTimeout: s.ServerHandler.GetHandshakeTimeout(), ReadBufferSize: s.ServerHandler.GetReadBufferSize(), WriteBufferSize: s.ServerHandler.GetWriteBufferSize(), CheckOrigin: s.ServerHandler.CheckOrigin, Error: s.onError, EnableCompression: s.ServerHandler.IsEnableCompression(), } if err = s.ServerHandler.Init(s.ctx); nil != err { return err } if listener, err = s.ServerHandler.Listener(s.ctx); nil != err { return err } s.stopChan = make(chan struct{}) s.stopWg.Add(1) return s.handleServer(listener) } func (s *Server) Shutdown(ctx context.Context) error { if s.stopChan == nil { return fmt.Errorf(s.serverMessage("server must be started before stopping it")) } close(s.stopChan) s.stopWg.Wait() s.ServerHandler.Destroy(s.ctx) s.stopChan = nil return nil } func (s *Server) ConnectionSize() int { var sz int s.connections.Range(func(k, v interface{}) bool { sz++ return true }) return sz } func (s *Server) serverMessage(msg string) string { return fmt.Sprintf("Server[%s]: %s", s.ServerHandler.GetName(), msg) } func (s *Server) handleServer(listener net.Listener) error { var ( err error ) errChan := make(chan error) defer func() { if nil != listener { listener.Close() } s.stopWg.Done() }() go func() { if err := s.hs.Serve(listener); nil != err { errChan <- err return } close(errChan) }() select { case err, _ := <-errChan: if nil != err { return err } } defer func() { s.ServerHandler.OnStop(s.ctx) logging.Logger().Infof(s.serverMessage("Stopped")) }() if err = s.ServerHandler.OnStart(s.ctx); nil != err { return err } logging.Logger().Infof(s.serverMessage("Started")) select { case <-s.stopChan: listener.Close() listener = nil } return nil } func (s *Server) httpHandler(ctx *fasthttp.RequestCtx) { path := string(ctx.Path()) var ( servlet Servlet err error ) if 0 < s.ServerHandler.GetConcurrency() { sz := s.ConnectionSize() if sz >= s.ServerHandler.GetConcurrency() { logging.Logger().Warnf(s.serverMessage(fmt.Sprintf("max connections size %d, refuse", sz))) s.onError(ctx, fasthttp.StatusServiceUnavailable, err) return } } if servlet = s.ServerHandler.Servlet(path); nil == servlet { s.onError(ctx, fasthttp.StatusInternalServerError, err) return } var responseHeader *fasthttp.ResponseHeader servletCtx := servlet.ServletCtx(s.ctx) if responseHeader, err = servlet.Handshake(servletCtx, ctx); nil != err { s.onError(ctx, http.StatusNotAcceptable, fmt.Errorf("Handshake err: %v", err)) return } s.upgrader.Upgrade(ctx, responseHeader, func(conn *server.Conn, err error) { if err != nil { s.onError(ctx, fasthttp.StatusInternalServerError, err) return } s.stopWg.Add(1) go s.handleConnection(servlet, servletCtx, conn) }) } func (s *Server) handleConnection(servlet Servlet, servletCtx server.ServletCtx, conn *server.Conn) { addr := conn.RemoteAddr() defer func() { if nil != conn { conn.Close() } servlet.OnDisconnect(servletCtx) logging.Logger().Infof(s.serverMessage(fmt.Sprintf("Client[%s] has been disconnected", addr))) s.stopWg.Done() }() logging.Logger().Infof(s.serverMessage(fmt.Sprintf("Client[%s] has been connected", addr))) s.connections.Store(conn, true) defer s.connections.Delete(conn) servlet.OnConnect(servletCtx, conn) conn.SetCloseHandler(func(code int, text string) error { logging.Logger().Debugf("close") return nil }) stopChan := make(chan struct{}) servletDoneChan := make(chan struct{}) readChan := make(chan []byte) writeChan := make(chan []byte) readerDoneChan := make(chan struct{}) writerDoneChan := make(chan struct{}) go servlet.Handle(servletCtx, stopChan, servletDoneChan, readChan, writeChan) go handleRead(s, conn, stopChan, readerDoneChan, readChan) go handleWrite(s, conn, stopChan, writerDoneChan, writeChan) select { case <-readerDoneChan: close(stopChan) <-writerDoneChan <-servletDoneChan case <-writerDoneChan: close(stopChan) <-readerDoneChan <-servletDoneChan case <-servletDoneChan: close(stopChan) <-readerDoneChan <-writerDoneChan case <-s.stopChan: close(stopChan) <-readerDoneChan <-writerDoneChan <-servletDoneChan } } func (s *Server) onError(ctx *fasthttp.RequestCtx, status int, reason error) { s.ServerHandler.OnError(s.ctx, ctx, status, reason) } func handleRead(s *Server, conn *server.Conn, stopChan <-chan struct{}, doneChan chan<- struct{}, readChan chan []byte) { defer func() { doneChan <- struct{}{} }() if 0 < s.ServerHandler.GetMaxMessageSize() { conn.SetReadLimit(s.ServerHandler.GetMaxMessageSize()) } if 0 < s.ServerHandler.GetReadTimeout() { conn.SetReadDeadline(time.Now().Add(s.ServerHandler.GetReadTimeout())) } conn.SetPongHandler(func(string) error { conn.SetReadDeadline(time.Now().Add(s.ServerHandler.GetPongTimeout())) return nil }) var ( message []byte err error ) for { readMessageChan := make(chan struct{}) go func() { _, message, err = conn.ReadMessage() close(readMessageChan) }() select { case <-s.stopChan: <-readMessageChan return case <-readMessageChan: } if nil != err { if server.IsUnexpectedCloseError(err, server.CloseGoingAway, server.CloseAbnormalClosure) { logging.Logger().Debugf(s.serverMessage(fmt.Sprintf("Read error %v", err))) } return } readChan <- message } } func handleWrite(s *Server, conn *server.Conn, stopChan <-chan struct{}, doneChan chan<- struct{}, writeChan chan []byte) { defer func() { doneChan <- struct{}{} }() ticker := time.NewTicker(s.ServerHandler.GetPingPeriod()) defer func() { ticker.Stop() }() for { select { case message, ok := <-writeChan: if 0 < s.ServerHandler.GetWriteTimeout() { conn.SetWriteDeadline(time.Now().Add(s.ServerHandler.GetWriteTimeout())) } if !ok { conn.WriteMessage(server.CloseMessage, []byte{}) return } w, err := conn.NextWriter(server.TextMessage) if err != nil { return } w.Write(message) if err := w.Close(); nil != err { return } case <-ticker.C: conn.SetWriteDeadline(time.Now().Add(s.ServerHandler.GetPingTimeout())) if err := conn.WriteMessage(server.PingMessage, nil); nil != err { return } case <-s.stopChan: return } } }