diff --git a/net/server.go b/net/server.go index 0bc7276..9bd9a0b 100644 --- a/net/server.go +++ b/net/server.go @@ -188,6 +188,10 @@ func (s *Server) handleConnection(servlet Servlet, servletCtx server.ServletCtx, defer s.connections.Delete(conn) servlet.OnConnect(servletCtx, conn) + conn.SetCloseHandler(func(code int, text string) error { + logging.Logger().Debugf("close") + return nil + }) stopChan := make(chan struct{}) servletDoneChan := make(chan struct{}) @@ -205,35 +209,27 @@ func (s *Server) handleConnection(servlet Servlet, servletCtx server.ServletCtx, select { case <-readerDoneChan: close(stopChan) - conn.Close() <-writerDoneChan <-servletDoneChan - conn = nil case <-writerDoneChan: close(stopChan) - conn.Close() <-readerDoneChan <-servletDoneChan - conn = nil case <-servletDoneChan: close(stopChan) - conn.Close() <-readerDoneChan <-writerDoneChan - conn = nil case <-s.stopChan: close(stopChan) - conn.Close() <-readerDoneChan <-writerDoneChan <-servletDoneChan - conn = nil } } -func handleRead(s *Server, conn *server.Conn, doneChan chan<- struct{}, stopChan <-chan struct{}, readChan chan []byte) { +func handleRead(s *Server, conn *server.Conn, stopChan <-chan struct{}, doneChan chan<- struct{}, readChan chan []byte) { defer func() { - close(doneChan) + doneChan <- struct{}{} }() if 0 < s.ServerHandler.GetMaxMessageSize() { @@ -257,37 +253,31 @@ func handleRead(s *Server, conn *server.Conn, doneChan chan<- struct{}, stopChan go func() { _, message, err = conn.ReadMessage() - if err != nil { - if server.IsUnexpectedCloseError(err, server.CloseGoingAway, server.CloseAbnormalClosure) { - logging.Logger().Debugf(s.serverMessage(fmt.Sprintf("Read error %v", err))) - } - } close(readMessageChan) }() select { - case <-s.stopChan: + case <-stopChan: + conn.Close() <-readMessageChan break case <-readMessageChan: } if nil != err { - select { - case <-s.stopChan: - break - case <-time.After(time.Second): + if server.IsUnexpectedCloseError(err, server.CloseGoingAway, server.CloseAbnormalClosure) { + logging.Logger().Debugf(s.serverMessage(fmt.Sprintf("Read error %v", err))) } - continue + break } readChan <- message } } -func handleWrite(s *Server, conn *server.Conn, doneChan chan<- struct{}, stopChan <-chan struct{}, writeChan chan []byte) { +func handleWrite(s *Server, conn *server.Conn, stopChan <-chan struct{}, doneChan chan<- struct{}, writeChan chan []byte) { defer func() { - close(doneChan) + doneChan <- struct{}{} }() ticker := time.NewTicker(s.ServerHandler.GetPingPeriod()) @@ -319,8 +309,8 @@ func handleWrite(s *Server, conn *server.Conn, doneChan chan<- struct{}, stopCha if err := conn.WriteMessage(server.PingMessage, nil); nil != err { return } - case <-s.stopChan: - break + case <-stopChan: + return } } }