diff --git a/net/server.go b/net/server.go index 355fe69..806c99e 100644 --- a/net/server.go +++ b/net/server.go @@ -1,10 +1,8 @@ package net import ( - "context" "fmt" "net" - "sync" "sync/atomic" "time" @@ -12,18 +10,24 @@ import ( "git.loafle.net/commons/server-go" ) -type Server struct { - ServerHandler ServerHandler - - ctx server.ServerCtx - connections sync.Map - stopChan chan struct{} - stopWg sync.WaitGroup +type Server interface { + server.Server } -func (s *Server) ListenAndServe() error { - if s.stopChan != nil { - return fmt.Errorf(s.serverMessage("already running. Stop it before starting it again")) +func NewServer(serverHandler ServerHandler) Server { + s := &netServer{} + s.ServerHandler = serverHandler + + return s +} + +type netServer struct { + server.Servers +} + +func (s *netServer) ListenAndServe() error { + if s.StopChan != nil { + return fmt.Errorf(s.ServerMessage("already running. Stop it before starting it again")) } var ( @@ -37,52 +41,25 @@ func (s *Server) ListenAndServe() error { return err } - s.ctx = s.ServerHandler.ServerCtx() - if nil == s.ctx { - return fmt.Errorf(s.serverMessage("ServerCtx is nil")) + s.ServerCtx = s.ServerHandler.ServerCtx() + if nil == s.ServerCtx { + return fmt.Errorf(s.ServerMessage("ServerCtx is nil")) } - if err = s.ServerHandler.Init(s.ctx); nil != err { + if err = s.ServerHandler.Init(s.ServerCtx); nil != err { return err } - if listener, err = s.ServerHandler.Listener(s.ctx); nil != err { + if listener, err = s.ServerHandler.Listener(s.ServerCtx); nil != err { return err } - s.stopChan = make(chan struct{}) - s.stopWg.Add(1) + s.StopChan = make(chan struct{}) + s.StopWg.Add(1) return s.handleServer(listener) } -func (s *Server) Shutdown(ctx context.Context) error { - if s.stopChan == nil { - return fmt.Errorf(s.serverMessage("server must be started before stopping it")) - } - close(s.stopChan) - s.stopWg.Wait() - - s.ServerHandler.Destroy(s.ctx) - - s.stopChan = nil - - return nil -} - -func (s *Server) ConnectionSize() int { - var sz int - s.connections.Range(func(k, v interface{}) bool { - sz++ - return true - }) - return sz -} - -func (s *Server) serverMessage(msg string) string { - return fmt.Sprintf("Server[%s]: %s", s.ServerHandler.GetName(), msg) -} - -func (s *Server) handleServer(listener net.Listener) error { +func (s *netServer) handleServer(listener net.Listener) error { var ( stopping atomic.Value netConn net.Conn @@ -94,18 +71,18 @@ func (s *Server) handleServer(listener net.Listener) error { listener.Close() } - s.ServerHandler.OnStop(s.ctx) + s.ServerHandler.OnStop(s.ServerCtx) - logging.Logger().Infof(s.serverMessage("Stopped")) + logging.Logger().Infof(s.ServerMessage("Stopped")) - s.stopWg.Done() + s.StopWg.Done() }() - if err = s.ServerHandler.OnStart(s.ctx); nil != err { + if err = s.ServerHandler.OnStart(s.ServerCtx); nil != err { return err } - logging.Logger().Infof(s.serverMessage("Started")) + logging.Logger().Infof(s.ServerMessage("Started")) for { acceptChan := make(chan struct{}) @@ -113,14 +90,14 @@ func (s *Server) handleServer(listener net.Listener) error { go func() { if netConn, err = listener.Accept(); err != nil { if nil == stopping.Load() { - logging.Logger().Errorf(s.serverMessage(fmt.Sprintf("%v", err))) + logging.Logger().Errorf(s.ServerMessage(fmt.Sprintf("%v", err))) } } close(acceptChan) }() select { - case <-s.stopChan: + case <-s.StopChan: stopping.Store(true) listener.Close() <-acceptChan @@ -131,7 +108,7 @@ func (s *Server) handleServer(listener net.Listener) error { if nil != err { select { - case <-s.stopChan: + case <-s.StopChan: return nil case <-time.After(time.Second): } @@ -141,178 +118,32 @@ func (s *Server) handleServer(listener net.Listener) error { if 0 < s.ServerHandler.GetConcurrency() { sz := s.ConnectionSize() if sz >= s.ServerHandler.GetConcurrency() { - logging.Logger().Warnf(s.serverMessage(fmt.Sprintf("max connections size %d, refuse", sz))) + logging.Logger().Warnf(s.ServerMessage(fmt.Sprintf("max connections size %d, refuse", sz))) netConn.Close() continue } } - servlet := s.ServerHandler.Servlet() + servlet := s.ServerHandler.(ServerHandler).Servlet() if nil == servlet { - logging.Logger().Errorf(s.serverMessage("Servlet is nil")) + logging.Logger().Errorf(s.ServerMessage("Servlet is nil")) continue } - servletCtx := servlet.ServletCtx(s.ctx) + servletCtx := servlet.ServletCtx(s.ServerCtx) if nil == servletCtx { - logging.Logger().Errorf(s.serverMessage("ServletCtx is nil")) + logging.Logger().Errorf(s.ServerMessage("ServletCtx is nil")) continue } if err := servlet.Handshake(servletCtx, netConn); nil != err { - logging.Logger().Infof(s.serverMessage(fmt.Sprintf("Handshaking of Client[%s] has been failed %v", netConn.RemoteAddr(), err))) + logging.Logger().Infof(s.ServerMessage(fmt.Sprintf("Handshaking of Client[%s] has been failed %v", netConn.RemoteAddr(), err))) continue } conn := server.NewConn(netConn, true, s.ServerHandler.GetReadBufferSize(), s.ServerHandler.GetWriteBufferSize()) - s.stopWg.Add(1) - go s.handleConnection(servlet, servletCtx, conn) - } -} - -func (s *Server) handleConnection(servlet Servlet, servletCtx server.ServletCtx, conn *server.Conn) { - addr := conn.RemoteAddr() - - defer func() { - if nil != conn { - conn.Close() - } - servlet.OnDisconnect(servletCtx) - logging.Logger().Infof(s.serverMessage(fmt.Sprintf("Client[%s] has been disconnected", addr))) - s.stopWg.Done() - }() - - logging.Logger().Infof(s.serverMessage(fmt.Sprintf("Client[%s] has been connected", addr))) - - s.connections.Store(conn, true) - defer s.connections.Delete(conn) - - servlet.OnConnect(servletCtx, conn) - conn.SetCloseHandler(func(code int, text string) error { - logging.Logger().Debugf("close") - return nil - }) - - stopChan := make(chan struct{}) - servletDoneChan := make(chan struct{}) - - readChan := make(chan []byte, 256) - writeChan := make(chan []byte, 256) - - readerDoneChan := make(chan struct{}) - writerDoneChan := make(chan struct{}) - - go servlet.Handle(servletCtx, stopChan, servletDoneChan, readChan, writeChan) - go handleRead(s, conn, stopChan, readerDoneChan, readChan) - go handleWrite(s, conn, stopChan, writerDoneChan, writeChan) - - select { - case <-readerDoneChan: - close(stopChan) - <-writerDoneChan - <-servletDoneChan - case <-writerDoneChan: - close(stopChan) - <-readerDoneChan - <-servletDoneChan - case <-servletDoneChan: - close(stopChan) - <-readerDoneChan - <-writerDoneChan - case <-s.stopChan: - close(stopChan) - <-readerDoneChan - <-writerDoneChan - <-servletDoneChan - } -} - -func handleRead(s *Server, conn *server.Conn, stopChan <-chan struct{}, doneChan chan<- struct{}, readChan chan []byte) { - defer func() { - doneChan <- struct{}{} - }() - - if 0 < s.ServerHandler.GetMaxMessageSize() { - conn.SetReadLimit(s.ServerHandler.GetMaxMessageSize()) - } - if 0 < s.ServerHandler.GetReadTimeout() { - conn.SetReadDeadline(time.Now().Add(s.ServerHandler.GetReadTimeout())) - } - conn.SetPongHandler(func(string) error { - conn.SetReadDeadline(time.Now().Add(s.ServerHandler.GetPongTimeout())) - return nil - }) - - var ( - message []byte - err error - ) - - for { - readMessageChan := make(chan struct{}) - - go func() { - _, message, err = conn.ReadMessage() - close(readMessageChan) - }() - - select { - case <-stopChan: - conn.Close() - <-readMessageChan - return - case <-readMessageChan: - } - - if nil != err { - if server.IsUnexpectedCloseError(err, server.CloseGoingAway, server.CloseAbnormalClosure) { - logging.Logger().Debugf(s.serverMessage(fmt.Sprintf("Read error %v", err))) - } - return - } - - readChan <- message - } -} - -func handleWrite(s *Server, conn *server.Conn, stopChan <-chan struct{}, doneChan chan<- struct{}, writeChan chan []byte) { - defer func() { - doneChan <- struct{}{} - }() - - ticker := time.NewTicker(s.ServerHandler.GetPingPeriod()) - defer func() { - ticker.Stop() - }() - for { - select { - case message, ok := <-writeChan: - if 0 < s.ServerHandler.GetWriteTimeout() { - conn.SetWriteDeadline(time.Now().Add(s.ServerHandler.GetWriteTimeout())) - } - - if !ok { - conn.WriteMessage(server.CloseMessage, []byte{}) - return - } - - w, err := conn.NextWriter(server.TextMessage) - if err != nil { - return - } - w.Write(message) - - if err := w.Close(); nil != err { - return - } - case <-ticker.C: - conn.SetWriteDeadline(time.Now().Add(s.ServerHandler.GetPingTimeout())) - if err := conn.WriteMessage(server.PingMessage, nil); nil != err { - return - } - case <-stopChan: - return - } + s.StopWg.Add(1) + go s.HandleConnection(servlet, servletCtx, conn) } } diff --git a/server.go b/server.go new file mode 100644 index 0000000..13efc35 --- /dev/null +++ b/server.go @@ -0,0 +1,200 @@ +package server + +import ( + "context" + "fmt" + "sync" + "time" + + logging "git.loafle.net/commons/logging-go" +) + +type Server interface { + ListenAndServe() error + Shutdown(ctx context.Context) error + ConnectionSize() int + + HandleConnection(servlet Servlet, servletCtx ServletCtx, conn *Conn) + HandleRead(conn *Conn, stopChan <-chan struct{}, doneChan chan<- struct{}, readChan chan []byte) + HandleWrite(conn *Conn, stopChan <-chan struct{}, doneChan chan<- struct{}, writeChan chan []byte) +} + +type Servers struct { + ServerHandler ServerHandler + + ServerCtx ServerCtx + Connections sync.Map + StopChan chan struct{} + StopWg sync.WaitGroup +} + +func (s *Servers) Shutdown(ctx context.Context) error { + if s.StopChan == nil { + return fmt.Errorf("server must be started before stopping it") + } + close(s.StopChan) + s.StopWg.Wait() + + s.ServerHandler.Destroy(s.ServerCtx) + + s.StopChan = nil + + return nil +} + +func (s *Servers) ConnectionSize() int { + var sz int + s.Connections.Range(func(k, v interface{}) bool { + sz++ + return true + }) + return sz +} + +func (s *Servers) ServerMessage(msg string) string { + return fmt.Sprintf("Server[%s]: %s", s.ServerHandler.GetName(), msg) +} + +func (s *Servers) HandleConnection(servlet Servlet, servletCtx ServletCtx, conn *Conn) { + addr := conn.RemoteAddr() + + defer func() { + if nil != conn { + conn.Close() + } + servlet.OnDisconnect(servletCtx) + logging.Logger().Infof(s.ServerMessage(fmt.Sprintf("Client[%s] has been disconnected", addr))) + s.StopWg.Done() + }() + + logging.Logger().Infof(s.ServerMessage(fmt.Sprintf("Client[%s] has been connected", addr))) + + s.Connections.Store(conn, true) + defer s.Connections.Delete(conn) + + servlet.OnConnect(servletCtx, conn) + conn.SetCloseHandler(func(code int, text string) error { + logging.Logger().Debugf("close") + return nil + }) + + stopChan := make(chan struct{}) + servletDoneChan := make(chan struct{}) + + readChan := make(chan []byte) + writeChan := make(chan []byte) + + readerDoneChan := make(chan struct{}) + writerDoneChan := make(chan struct{}) + + go servlet.Handle(servletCtx, stopChan, servletDoneChan, readChan, writeChan) + go s.HandleRead(conn, stopChan, readerDoneChan, readChan) + go s.HandleWrite(conn, stopChan, writerDoneChan, writeChan) + + select { + case <-readerDoneChan: + close(stopChan) + <-writerDoneChan + <-servletDoneChan + case <-writerDoneChan: + close(stopChan) + <-readerDoneChan + <-servletDoneChan + case <-servletDoneChan: + close(stopChan) + <-readerDoneChan + <-writerDoneChan + case <-s.StopChan: + close(stopChan) + <-readerDoneChan + <-writerDoneChan + <-servletDoneChan + } +} + +func (s *Servers) HandleRead(conn *Conn, stopChan <-chan struct{}, doneChan chan<- struct{}, readChan chan []byte) { + defer func() { + doneChan <- struct{}{} + }() + + if 0 < s.ServerHandler.GetMaxMessageSize() { + conn.SetReadLimit(s.ServerHandler.GetMaxMessageSize()) + } + if 0 < s.ServerHandler.GetReadTimeout() { + conn.SetReadDeadline(time.Now().Add(s.ServerHandler.GetReadTimeout())) + } + conn.SetPongHandler(func(string) error { + conn.SetReadDeadline(time.Now().Add(s.ServerHandler.GetPongTimeout())) + return nil + }) + + var ( + message []byte + err error + ) + + for { + readMessageChan := make(chan struct{}) + + go func() { + _, message, err = conn.ReadMessage() + close(readMessageChan) + }() + + select { + case <-s.StopChan: + <-readMessageChan + return + case <-readMessageChan: + } + + if nil != err { + if IsUnexpectedCloseError(err, CloseGoingAway, CloseAbnormalClosure) { + logging.Logger().Debugf(s.ServerMessage(fmt.Sprintf("Read error %v", err))) + } + return + } + + readChan <- message + } +} + +func (s *Servers) HandleWrite(conn *Conn, stopChan <-chan struct{}, doneChan chan<- struct{}, writeChan chan []byte) { + defer func() { + doneChan <- struct{}{} + }() + + ticker := time.NewTicker(s.ServerHandler.GetPingPeriod()) + defer func() { + ticker.Stop() + }() + for { + select { + case message, ok := <-writeChan: + if 0 < s.ServerHandler.GetWriteTimeout() { + conn.SetWriteDeadline(time.Now().Add(s.ServerHandler.GetWriteTimeout())) + } + if !ok { + conn.WriteMessage(CloseMessage, []byte{}) + return + } + + w, err := conn.NextWriter(TextMessage) + if err != nil { + return + } + w.Write(message) + + if err := w.Close(); nil != err { + return + } + case <-ticker.C: + conn.SetWriteDeadline(time.Now().Add(s.ServerHandler.GetPingTimeout())) + if err := conn.WriteMessage(PingMessage, nil); nil != err { + return + } + case <-stopChan: + return + } + } +}