package websocket_fasthttp import ( "fmt" "net" "net/http" "sync" "git.loafle.net/commons_go/logging" "git.loafle.net/commons_go/websocket_fasthttp/websocket" "github.com/valyala/fasthttp" ) type Server interface { Start() error Stop() Context() ServerContext } func New(sh ServerHandler) Server { s := &server{ sh: sh, } s.ctx = newServerContext() return s } type server struct { ctx ServerContext sh ServerHandler httpServer *fasthttp.Server upgrader *websocket.Upgrader listener net.Listener stopChan chan struct{} stopWg sync.WaitGroup } func (s *server) Start() error { if nil == s.sh { logging.Logger().Panic("Server: server handler must be specified.") } s.sh.Validate() if s.stopChan != nil { logging.Logger().Panic("Server: server is already running. Stop it before starting it again") } s.httpServer = &fasthttp.Server{ Handler: s.handleRequest, Name: s.sh.GetName(), Concurrency: s.sh.GetConcurrency(), ReadBufferSize: s.sh.GetReadBufferSize(), WriteBufferSize: s.sh.GetWriteBufferSize(), ReadTimeout: s.sh.GetReadTimeout(), WriteTimeout: s.sh.GetWriteTimeout(), } s.upgrader = &websocket.Upgrader{ HandshakeTimeout: s.sh.GetHandshakeTimeout(), ReadBufferSize: s.sh.GetReadBufferSize(), WriteBufferSize: s.sh.GetWriteBufferSize(), CheckOrigin: s.sh.CheckOrigin, Error: s.handleError, EnableCompression: s.sh.IsEnableCompression(), } var err error if err = s.sh.Init(s.ctx); nil != err { logging.Logger().Panic(fmt.Sprintf("Server: Initialization of server has been failed %v", err)) } var listener net.Listener if listener, err = s.sh.Listen(s.ctx); nil != err { return err } s.listener = newGracefulListener(listener, s.sh.GetMaxStopWaitTime()) s.stopChan = make(chan struct{}) s.stopWg.Add(1) go handleServer(s) return nil } func (s *server) Stop() { if s.stopChan == nil { logging.Logger().Panic("Server: server must be started before stopping it") } close(s.stopChan) s.stopWg.Wait() s.stopChan = nil s.sh.OnStop(s.ctx) logging.Logger().Info(fmt.Sprintf("Server[%s] is stopped", s.sh.GetName())) } func (s *server) Context() ServerContext { return s.ctx } func handleServer(s *server) { go func() { defer s.stopWg.Done() if err := s.httpServer.Serve(s.listener); nil != err { logging.Logger().Error(fmt.Sprintf("Server: Server err - %v", err)) } }() logging.Logger().Info(fmt.Sprintf("Server[%s] is started", s.sh.GetName())) s.sh.OnStart(s.ctx) select { case <-s.stopChan: s.listener.Close() return } } func (s *server) handleRequest(ctx *fasthttp.RequestCtx) { path := string(ctx.Path()) var socketHandler SocketHandler var err error if socketHandler, err = s.sh.GetSocketHandler(path); nil != err { s.handleError(ctx, fasthttp.StatusNotFound, err) return } var responseHeader *fasthttp.ResponseHeader var socketID string if socketID, responseHeader = socketHandler.Handshake(s.ctx, ctx); "" == socketID { s.handleError(ctx, http.StatusNotAcceptable, fmt.Errorf("Server: Handshake err")) return } s.upgrader.Upgrade(ctx, responseHeader, func(conn *websocket.Conn, err error) { if err != nil { s.handleError(ctx, fasthttp.StatusInternalServerError, err) return } soc := newSocket(s.ctx, socketID, conn, socketHandler) s.stopWg.Add(1) handleConnection(s, soc, socketHandler) }) } func (s *server) handleError(ctx *fasthttp.RequestCtx, status int, reason error) { ctx.Response.Header.Set("Sec-Websocket-Version", "13") ctx.Error(http.StatusText(status), status) s.sh.OnError(s.ctx, ctx, status, reason) } func handleConnection(s *server, soc Socket, socketHandler SocketHandler) { defer s.stopWg.Done() logging.Logger().Debug(fmt.Sprintf("Server: Client[%s] is connected.", soc.RemoteAddr())) soc = socketHandler.OnConnect(soc) socketHandler.addSocket(soc) clientStopChan := make(chan struct{}) handleDoneChan := make(chan struct{}) go socketHandler.Handle(soc, clientStopChan, handleDoneChan) select { case <-s.stopChan: close(clientStopChan) <-handleDoneChan case <-handleDoneChan: close(clientStopChan) socketHandler.OnDisconnect(soc) logging.Logger().Debug(fmt.Sprintf("Server: Client[%s] is disconnected.", soc.RemoteAddr())) socketHandler.removeSocket(soc) soc.Close() } }