package web

import (
	"context"
	"fmt"
	"net"
	"net/http"
	"sync"
	"sync/atomic"

	"git.loafle.net/commons/logging-go"
	"git.loafle.net/commons/server-go"
	"git.loafle.net/commons/server-go/socket"
	"github.com/valyala/fasthttp"
)

// Server accepts HTTP requests with fasthttp, upgrades them through an
// Upgrader, and hands every resulting socket.SocketConn, together with the
// Servlet chosen by ServerHandler, to a socket.ServerReadWriter.
//
// ServerHandler must be set before ListenAndServe is called.
type Server struct {
	ServerHandler ServerHandler

	ctx      server.ServerCtx
	stopChan chan struct{}
	stopWg   sync.WaitGroup

	srw socket.ServerReadWriter

	hs       *fasthttp.Server
	upgrader *Upgrader
}

// ListenAndServe validates ServerHandler, builds the fasthttp server and the
// Upgrader from the handler's settings, runs ServerHandler.Init, obtains the
// listener from ServerHandler.Listener and serves on it. It blocks until
// Shutdown is called or the fasthttp server stops with an error.
//
// It returns an error if ServerHandler is nil, if the server is already
// running, if ServerHandler.ServerCtx returns nil, or if Init, Listener,
// OnStart or serving fails.
func (s *Server) ListenAndServe() error {
	var (
		err      error
		listener net.Listener
	)

	if nil == s.ServerHandler {
		return fmt.Errorf("%s server handler must be specified", s.logHeader())
	}
	// NOTE(review): any result of Validate is discarded here — confirm whether
	// it reports problems through a return value that should be checked.
	s.ServerHandler.Validate()

	if s.stopChan != nil {
		return fmt.Errorf("%s already running. Stop it before starting it again", s.logHeader())
	}

	s.ctx = s.ServerHandler.ServerCtx()
	if nil == s.ctx {
		return fmt.Errorf("%s ServerCtx is nil", s.logHeader())
	}

	s.hs = &fasthttp.Server{
		Handler:         s.httpHandler,
		Name:            s.ServerHandler.GetName(),
		Concurrency:     s.ServerHandler.GetConcurrency(),
		ReadBufferSize:  s.ServerHandler.GetReadBufferSize(),
		WriteBufferSize: s.ServerHandler.GetWriteBufferSize(),
		ReadTimeout:     s.ServerHandler.GetReadTimeout(),
		WriteTimeout:    s.ServerHandler.GetWriteTimeout(),
	}

	s.upgrader = &Upgrader{
		HandshakeTimeout:  s.ServerHandler.GetHandshakeTimeout(),
		ReadBufferSize:    s.ServerHandler.GetReadBufferSize(),
		WriteBufferSize:   s.ServerHandler.GetWriteBufferSize(),
		CheckOrigin:       s.ServerHandler.CheckOrigin,
		Error:             s.onError,
		EnableCompression: s.ServerHandler.IsEnableCompression(),
	}

	if err = s.ServerHandler.Init(s.ctx); nil != err {
		logging.Logger().Errorf("%s Init has been failed %v", s.logHeader(), err)
		return err
	}

	if listener, err = s.ServerHandler.Listener(s.ctx); nil != err {
		// Init has already succeeded, but Shutdown cannot run yet because
		// stopChan is still nil, so pair Init with Destroy here.
		s.ServerHandler.Destroy(s.ctx)
		return err
	}

	s.stopChan = make(chan struct{})

	s.srw.ReadwriteHandler = s.ServerHandler
	s.srw.ServerStopChan = s.stopChan
	s.srw.ServerStopWg = &s.stopWg

	// Released by the deferred Done in handleServer.
	s.stopWg.Add(1)
	return s.handleServer(listener)
}

// Shutdown signals a running server to stop. It closes stopChan, waits on
// stopWg (handleServer plus one entry per connection upgraded in
// httpHandler), calls ServerHandler.Destroy and marks the server as stopped,
// so that ListenAndServe may be called again.
//
// It returns an error if the server has not been started.
// NOTE(review): ctx is currently not consulted; Shutdown waits without a deadline.
func (s *Server) Shutdown(ctx context.Context) error {
	if s.stopChan == nil {
		return fmt.Errorf("%s must be started before stopping it", s.logHeader())
	}

	close(s.stopChan)
	s.stopWg.Wait()

	s.ServerHandler.Destroy(s.ctx)

	s.stopChan = nil

	return nil
}

// logHeader returns the prefix used in this server's log lines and error
// messages. It must not panic when ServerHandler is nil, because it is also
// used to report that very condition.
func (s *Server) logHeader() string {
	if nil == s.ServerHandler {
		return "Server:"
	}
	return fmt.Sprintf("Server[%s]:", s.ServerHandler.GetName())
}

// handleServer runs OnStart, then serves s.hs on listener in a separate
// goroutine. It blocks until Serve returns or stopChan is closed. On exit it
// closes the listener (unless it was already closed while stopping), calls
// OnStop and releases the stopWg slot taken by ListenAndServe.
func (s *Server) handleServer(listener net.Listener) error {
	// stopping is set before the listener is closed on a requested stop, so
	// that whatever Serve returns in that case is not reported as a failure.
	var stopping atomic.Value

	defer func() {
		if nil != listener {
			listener.Close()
		}
		s.ServerHandler.OnStop(s.ctx)
		logging.Logger().Infof("%s Stopped", s.logHeader())
		s.stopWg.Done()
	}()

	if err := s.ServerHandler.OnStart(s.ctx); nil != err {
		logging.Logger().Errorf("%s OnStart has been failed %v", s.logHeader(), err)
		return err
	}

	hsCloseChan := make(chan error)
	go func() {
		if err := s.hs.Serve(listener); nil != err && nil == stopping.Load() {
			hsCloseChan <- err
			return
		}
		hsCloseChan <- nil
	}()

	logging.Logger().Infof("%s Started", s.logHeader())

	select {
	case err := <-hsCloseChan:
		if nil != err {
			return err
		}
	case <-s.stopChan:
		stopping.Store(true)
		listener.Close()
		<-hsCloseChan
		// Already closed above; keep the deferred cleanup from closing it again.
		listener = nil
	}

	return nil
}

// httpHandler is the fasthttp request handler. It takes these steps in order:
//   - If GetConcurrency is positive and srw already holds that many
//     connections, it refuses the request with 503.
//   - It picks a Servlet via ServerHandler.Servlet and answers 500 if none is
//     returned.
//   - It runs the servlet handshake and answers 406 if that fails.
//   - It upgrades the request. Each successfully upgraded connection is added
//     to stopWg and handed to srw.HandleConnection.
func (s *Server) httpHandler(ctx *fasthttp.RequestCtx) {
	var (
		servlet Servlet
		err     error
	)

	if limit := s.ServerHandler.GetConcurrency(); 0 < limit {
		if sz := s.srw.ConnectionSize(); sz >= limit {
			logging.Logger().Warnf("%s max connections size %d, refuse", s.logHeader(), sz)
			s.onError(ctx, fasthttp.StatusServiceUnavailable, fmt.Errorf("max connections size %d reached", sz))
			return
		}
	}

	if servlet = s.ServerHandler.Servlet(s.ctx, ctx); nil == servlet {
		s.onError(ctx, fasthttp.StatusInternalServerError, fmt.Errorf("servlet is nil"))
		return
	}

	var responseHeader *fasthttp.ResponseHeader
	servletCtx := servlet.ServletCtx(s.ctx)
	if responseHeader, err = servlet.Handshake(servletCtx, ctx); nil != err {
		s.onError(ctx, http.StatusNotAcceptable, fmt.Errorf("Handshake err: %w", err))
		return
	}

	s.upgrader.Upgrade(ctx, responseHeader, func(conn *socket.SocketConn, err error) {
		if err != nil {
			s.onError(ctx, fasthttp.StatusInternalServerError, err)
			return
		}

		// NOTE(review): fasthttp may still run handlers after handleServer has
		// called Done during Shutdown; an Add from a zero counter concurrent
		// with stopWg.Wait violates sync.WaitGroup's contract — confirm.
		s.stopWg.Add(1)
		s.srw.HandleConnection(servlet, servletCtx, conn)
	})
}

// onError forwards a request-level failure, with the server context, to
// ServerHandler.OnError. It is also installed as the Upgrader's Error callback.
func (s *Server) onError(ctx *fasthttp.RequestCtx, status int, reason error) {
	s.ServerHandler.OnError(s.ctx, ctx, status, reason)
}