diff --git a/client-connection-handler.go b/client-connection-handler.go new file mode 100644 index 0000000..4b5505c --- /dev/null +++ b/client-connection-handler.go @@ -0,0 +1,42 @@ +package server + +import ( + "time" +) + +type ClientConnectionHandler interface { + ConnectionHandler + GetReconnectInterval() time.Duration + GetReconnectTryTime() int +} + +type ClientConnectionHandlers struct { + ConnectionHandlers + + ReconnectInterval time.Duration + ReconnectTryTime int +} + +func (cch *ClientConnectionHandlers) GetReconnectInterval() time.Duration { + return cch.ReconnectInterval +} + +func (cch *ClientConnectionHandlers) GetReconnectTryTime() int { + return cch.ReconnectTryTime +} + +func (cch *ClientConnectionHandlers) Validate() error { + if err := cch.ConnectionHandlers.Validate(); nil != err { + return err + } + + if cch.ReconnectInterval <= 0 { + cch.ReconnectInterval = DefaultReconnectInterval + } + + if cch.ReconnectTryTime <= 0 { + cch.ReconnectTryTime = DefaultReconnectTryTime + } + + return nil +} diff --git a/client-rwc-handler.go b/client-rwc-handler.go new file mode 100644 index 0000000..9451d29 --- /dev/null +++ b/client-rwc-handler.go @@ -0,0 +1,70 @@ +package server + +import ( + "io" + "sync" + + logging "git.loafle.net/commons/logging-go" +) + +type ClientRWCHandler struct { + ReadwriteHandler ReadWriteHandler + ReadChan chan<- []byte + WriteChan <-chan []byte + DisconnectedChan chan<- struct{} + ReconnectedChan <-chan *Conn + ClientStopChan <-chan struct{} + ClientStopWg *sync.WaitGroup +} + +func (crwch *ClientRWCHandler) HandleConnection(conn *Conn) { + + defer func() { + if nil != conn { + conn.Close() + } + logging.Logger().Infof("disconnected") + crwch.ClientStopWg.Done() + }() + + logging.Logger().Infof("connected") + + stopChan := make(chan struct{}) + + readerDoneChan := make(chan error) + writerDoneChan := make(chan error) + + var err error + + for { + if nil != err { + if io.EOF == err || io.ErrUnexpectedEOF == err { + crwch.DisconnectedChan <- struct{}{} + newConn := <-crwch.ReconnectedChan + if nil == newConn { + return + } + conn = newConn + } else { + return + } + } + + go connReadHandler(crwch.ReadwriteHandler, conn, stopChan, readerDoneChan, crwch.ReadChan) + go connWriteHandler(crwch.ReadwriteHandler, conn, stopChan, writerDoneChan, crwch.WriteChan) + + select { + case err = <-readerDoneChan: + close(stopChan) + <-writerDoneChan + case err = <-writerDoneChan: + close(stopChan) + <-readerDoneChan + case <-crwch.ClientStopChan: + close(stopChan) + <-readerDoneChan + <-writerDoneChan + return + } + } +} diff --git a/connection-handler.go b/connection-handler.go new file mode 100644 index 0000000..20bca64 --- /dev/null +++ b/connection-handler.go @@ -0,0 +1,65 @@ +package server + +import ( + "crypto/tls" + "fmt" + "net" + "time" +) + +type ConnectionHandler interface { + GetConcurrency() int + GetKeepAlive() time.Duration + GetHandshakeTimeout() time.Duration + GetTLSConfig() *tls.Config + + Listener(serverCtx ServerCtx) (net.Listener, error) +} + +type ConnectionHandlers struct { + ConnectionHandler + + // The maximum number of concurrent connections the server may serve. + // + // DefaultConcurrency is used if not set. + Concurrency int + KeepAlive time.Duration + HandshakeTimeout time.Duration + TLSConfig *tls.Config +} + +func (ch *ConnectionHandlers) Listener(serverCtx ServerCtx) (net.Listener, error) { + return nil, fmt.Errorf("Method[ConnectionHandler.Listener] is not implemented") +} + +func (ch *ConnectionHandlers) GetConcurrency() int { + return ch.Concurrency +} + +func (ch *ConnectionHandlers) GetKeepAlive() time.Duration { + return ch.KeepAlive +} + +func (ch *ConnectionHandlers) GetHandshakeTimeout() time.Duration { + return ch.HandshakeTimeout +} + +func (ch *ConnectionHandlers) GetTLSConfig() *tls.Config { + return ch.TLSConfig +} + +func (ch *ConnectionHandlers) Validate() error { + if ch.Concurrency <= 0 { + ch.Concurrency = DefaultConcurrency + } + + if ch.KeepAlive <= 0 { + ch.KeepAlive = DefaultKeepAlive + } + + if ch.HandshakeTimeout <= 0 { + ch.HandshakeTimeout = DefaultHandshakeTimeout + } + + return nil +} diff --git a/const.go b/const.go index a31a1c5..d7179f3 100644 --- a/const.go +++ b/const.go @@ -30,4 +30,7 @@ const ( DefaultPingTimeout = 10 * time.Second // DefaultPingPeriod is default value of send ping period DefaultPingPeriod = (DefaultPingTimeout * 9) / 10 + + DefaultReconnectInterval = 1 * time.Second + DefaultReconnectTryTime = 10 ) diff --git a/fasthttp/websocket/client.go b/fasthttp/websocket/client.go index 26e0308..58f0d6e 100644 --- a/fasthttp/websocket/client.go +++ b/fasthttp/websocket/client.go @@ -23,6 +23,9 @@ import ( var errMalformedURL = errors.New("malformed ws or wss URL") type Client struct { + server.ClientConnectionHandlers + server.ReadWriteHandlers + Name string URL string @@ -44,46 +47,18 @@ type Client struct { // If Proxy is nil or returns a nil *URL, no proxy is used. Proxy func(*http.Request) (*url.URL, error) - TLSConfig *tls.Config - HandshakeTimeout time.Duration - // ReadBufferSize and WriteBufferSize specify I/O buffer sizes. If a buffer - // size is zero, then a useful default size is used. The I/O buffer sizes - // do not limit the size of the messages that can be sent or received. - ReadBufferSize, WriteBufferSize int - // Subprotocols specifies the client's requested subprotocols. - - // EnableCompression specifies if the client should attempt to negotiate - // per message compression (RFC 7692). Setting this value to true does not - // guarantee that compression will be supported. Currently only "no context - // takeover" modes are supported. - EnableCompression bool - - // MaxMessageSize is the maximum size for a message read from the peer. If a - // message exceeds the limit, the connection sends a close frame to the peer - // and returns ErrReadLimit to the application. - MaxMessageSize int64 - // WriteTimeout is the write deadline on the underlying network - // connection. After a write has timed out, the websocket state is corrupt and - // all future writes will return an error. A zero value for t means writes will - // not time out. - WriteTimeout time.Duration - // ReadTimeout is the read deadline on the underlying network connection. - // After a read has timed out, the websocket connection state is corrupt and - // all future reads will return an error. A zero value for t means reads will - // not time out. - ReadTimeout time.Duration - - PongTimeout time.Duration - PingTimeout time.Duration - PingPeriod time.Duration - serverURL *url.URL - stopChan chan struct{} - stopWg sync.WaitGroup - conn *server.Conn + stopChan chan struct{} + stopWg sync.WaitGroup + readChan chan []byte writeChan chan []byte + + disconnectedChan chan struct{} + reconnectedChan chan *server.Conn + + crwch server.ClientRWCHandler } func (c *Client) Connect() (readChan <-chan []byte, writeChan chan<- []byte, res *http.Response, err error) { @@ -107,10 +82,20 @@ func (c *Client) Connect() (readChan <-chan []byte, writeChan chan<- []byte, res c.readChan = make(chan []byte, 256) c.writeChan = make(chan []byte, 256) - + c.disconnectedChan = make(chan struct{}) + c.reconnectedChan = make(chan *server.Conn) c.stopChan = make(chan struct{}) + + c.crwch.ReadwriteHandler = c + c.crwch.ReadChan = c.readChan + c.crwch.WriteChan = c.writeChan + c.crwch.ClientStopChan = c.stopChan + c.crwch.ClientStopWg = &c.stopWg + c.crwch.DisconnectedChan = c.disconnectedChan + c.crwch.ReconnectedChan = c.reconnectedChan + c.stopWg.Add(1) - go c.handleConnection(conn) + go c.crwch.HandleConnection(conn) return c.readChan, c.writeChan, res, nil } @@ -144,124 +129,6 @@ func (c *Client) connect() (*server.Conn, *http.Response, error) { return conn, res, nil } -func (c *Client) handleConnection(conn *server.Conn) { - defer func() { - if nil != conn { - conn.Close() - } - logging.Logger().Infof(c.clientMessage("disconnected")) - c.stopWg.Done() - }() - - logging.Logger().Infof(c.clientMessage("connected")) - - stopChan := make(chan struct{}) - - readerDoneChan := make(chan struct{}) - writerDoneChan := make(chan struct{}) - - go handleClientRead(c, conn, stopChan, readerDoneChan) - go handleClientWrite(c, conn, stopChan, writerDoneChan) - - select { - case <-readerDoneChan: - close(stopChan) - <-writerDoneChan - case <-writerDoneChan: - close(stopChan) - <-readerDoneChan - case <-c.stopChan: - close(stopChan) - <-readerDoneChan - <-writerDoneChan - } -} - -func handleClientRead(c *Client, conn *server.Conn, stopChan <-chan struct{}, doneChan chan<- struct{}) { - defer func() { - doneChan <- struct{}{} - }() - - if 0 < c.MaxMessageSize { - conn.SetReadLimit(c.MaxMessageSize) - } - if 0 < c.ReadTimeout { - conn.SetReadDeadline(time.Now().Add(c.ReadTimeout)) - } - conn.SetPongHandler(func(string) error { - conn.SetReadDeadline(time.Now().Add(c.PongTimeout)) - return nil - }) - - var ( - message []byte - err error - ) - - for { - readMessageChan := make(chan struct{}) - - go func() { - _, message, err = conn.ReadMessage() - close(readMessageChan) - }() - - select { - case <-stopChan: - <-readMessageChan - return - case <-readMessageChan: - } - - if nil != err { - if server.IsUnexpectedCloseError(err, server.CloseGoingAway, server.CloseAbnormalClosure) { - logging.Logger().Debugf(c.clientMessage(fmt.Sprintf("Read error %v", err))) - } - return - } - - c.readChan <- message - } -} - -func handleClientWrite(c *Client, conn *server.Conn, stopChan <-chan struct{}, doneChan chan<- struct{}) { - defer func() { - doneChan <- struct{}{} - }() - - ticker := time.NewTicker(c.PingPeriod) - defer func() { - ticker.Stop() - }() - for { - select { - case message, ok := <-c.writeChan: - conn.SetWriteDeadline(time.Now().Add(c.WriteTimeout)) - if !ok { - conn.WriteMessage(server.CloseMessage, []byte{}) - return - } - - w, err := conn.NextWriter(server.TextMessage) - if err != nil { - return - } - w.Write(message) - - if err := w.Close(); nil != err { - return - } - case <-ticker.C: - conn.SetWriteDeadline(time.Now().Add(c.PingTimeout)) - if err := conn.WriteMessage(server.PingMessage, nil); nil != err { - return - } - case <-stopChan: - return - } - } -} - func (c *Client) Dial() (*server.Conn, *http.Response, error) { var ( err error @@ -476,6 +343,13 @@ func (c *Client) Dial() (*server.Conn, *http.Response, error) { } func (c *Client) Validate() error { + if err := c.ClientConnectionHandlers.Validate(); nil != err { + return err + } + if err := c.ReadWriteHandlers.Validate(); nil != err { + return err + } + if "" == c.Name { c.Name = "Client" } @@ -507,34 +381,6 @@ func (c *Client) Validate() error { c.Proxy = http.ProxyFromEnvironment } - if c.HandshakeTimeout <= 0 { - c.HandshakeTimeout = server.DefaultHandshakeTimeout - } - if c.MaxMessageSize <= 0 { - c.MaxMessageSize = server.DefaultMaxMessageSize - } - if c.ReadBufferSize <= 0 { - c.ReadBufferSize = server.DefaultReadBufferSize - } - if c.WriteBufferSize <= 0 { - c.WriteBufferSize = server.DefaultWriteBufferSize - } - if c.ReadTimeout <= 0 { - c.ReadTimeout = server.DefaultReadTimeout - } - if c.WriteTimeout <= 0 { - c.WriteTimeout = server.DefaultWriteTimeout - } - if c.PongTimeout <= 0 { - c.PongTimeout = server.DefaultPongTimeout - } - if c.PingTimeout <= 0 { - c.PingTimeout = server.DefaultPingTimeout - } - if c.PingPeriod <= 0 { - c.PingPeriod = server.DefaultPingPeriod - } - return nil } diff --git a/fasthttp/websocket/server.go b/fasthttp/websocket/server.go index 779e86d..0142c4e 100644 --- a/fasthttp/websocket/server.go +++ b/fasthttp/websocket/server.go @@ -6,7 +6,6 @@ import ( "net" "net/http" "sync" - "time" "git.loafle.net/commons/logging-go" "git.loafle.net/commons/server-go" @@ -16,14 +15,14 @@ import ( type Server struct { ServerHandler ServerHandler - ctx server.ServerCtx + ctx server.ServerCtx + stopChan chan struct{} + stopWg sync.WaitGroup + + srwch server.ServerRWCHandler hs *fasthttp.Server upgrader *Upgrader - - connections sync.Map - stopChan chan struct{} - stopWg sync.WaitGroup } func (s *Server) ListenAndServe() error { @@ -59,7 +58,7 @@ func (s *Server) ListenAndServe() error { HandshakeTimeout: s.ServerHandler.GetHandshakeTimeout(), ReadBufferSize: s.ServerHandler.GetReadBufferSize(), WriteBufferSize: s.ServerHandler.GetWriteBufferSize(), - CheckOrigin: s.ServerHandler.CheckOrigin, + CheckOrigin: s.ServerHandler.(ServerHandler).CheckOrigin, Error: s.onError, EnableCompression: s.ServerHandler.IsEnableCompression(), } @@ -73,13 +72,18 @@ func (s *Server) ListenAndServe() error { } s.stopChan = make(chan struct{}) + + s.srwch.ReadwriteHandler = s.ServerHandler + s.srwch.ServerStopChan = s.stopChan + s.srwch.ServerStopWg = &s.stopWg + s.stopWg.Add(1) return s.handleServer(listener) } func (s *Server) Shutdown(ctx context.Context) error { if s.stopChan == nil { - return fmt.Errorf(s.serverMessage("server must be started before stopping it")) + return fmt.Errorf("server must be started before stopping it") } close(s.stopChan) s.stopWg.Wait() @@ -91,15 +95,6 @@ func (s *Server) Shutdown(ctx context.Context) error { return nil } -func (s *Server) ConnectionSize() int { - var sz int - s.connections.Range(func(k, v interface{}) bool { - sz++ - return true - }) - return sz -} - func (s *Server) serverMessage(msg string) string { return fmt.Sprintf("Server[%s]: %s", s.ServerHandler.GetName(), msg) } @@ -162,7 +157,7 @@ func (s *Server) httpHandler(ctx *fasthttp.RequestCtx) { ) if 0 < s.ServerHandler.GetConcurrency() { - sz := s.ConnectionSize() + sz := s.srwch.ConnectionSize() if sz >= s.ServerHandler.GetConcurrency() { logging.Logger().Warnf(s.serverMessage(fmt.Sprintf("max connections size %d, refuse", sz))) s.onError(ctx, fasthttp.StatusServiceUnavailable, err) @@ -170,7 +165,7 @@ func (s *Server) httpHandler(ctx *fasthttp.RequestCtx) { } } - if servlet = s.ServerHandler.Servlet(path); nil == servlet { + if servlet = s.ServerHandler.(ServerHandler).Servlet(path); nil == servlet { s.onError(ctx, fasthttp.StatusInternalServerError, err) return } @@ -190,154 +185,10 @@ func (s *Server) httpHandler(ctx *fasthttp.RequestCtx) { } s.stopWg.Add(1) - go s.handleConnection(servlet, servletCtx, conn) + go s.srwch.HandleConnection(servlet, servletCtx, conn) }) } -func (s *Server) handleConnection(servlet Servlet, servletCtx server.ServletCtx, conn *server.Conn) { - addr := conn.RemoteAddr() - - defer func() { - if nil != conn { - conn.Close() - } - servlet.OnDisconnect(servletCtx) - logging.Logger().Infof(s.serverMessage(fmt.Sprintf("Client[%s] has been disconnected", addr))) - s.stopWg.Done() - }() - - logging.Logger().Infof(s.serverMessage(fmt.Sprintf("Client[%s] has been connected", addr))) - - s.connections.Store(conn, true) - defer s.connections.Delete(conn) - - servlet.OnConnect(servletCtx, conn) - conn.SetCloseHandler(func(code int, text string) error { - logging.Logger().Debugf("close") - return nil - }) - - stopChan := make(chan struct{}) - servletDoneChan := make(chan struct{}) - - readChan := make(chan []byte) - writeChan := make(chan []byte) - - readerDoneChan := make(chan struct{}) - writerDoneChan := make(chan struct{}) - - go servlet.Handle(servletCtx, stopChan, servletDoneChan, readChan, writeChan) - go handleRead(s, conn, stopChan, readerDoneChan, readChan) - go handleWrite(s, conn, stopChan, writerDoneChan, writeChan) - - select { - case <-readerDoneChan: - close(stopChan) - <-writerDoneChan - <-servletDoneChan - case <-writerDoneChan: - close(stopChan) - <-readerDoneChan - <-servletDoneChan - case <-servletDoneChan: - close(stopChan) - <-readerDoneChan - <-writerDoneChan - case <-s.stopChan: - close(stopChan) - <-readerDoneChan - <-writerDoneChan - <-servletDoneChan - } -} - func (s *Server) onError(ctx *fasthttp.RequestCtx, status int, reason error) { - s.ServerHandler.OnError(s.ctx, ctx, status, reason) -} - -func handleRead(s *Server, conn *server.Conn, stopChan <-chan struct{}, doneChan chan<- struct{}, readChan chan []byte) { - defer func() { - doneChan <- struct{}{} - }() - - if 0 < s.ServerHandler.GetMaxMessageSize() { - conn.SetReadLimit(s.ServerHandler.GetMaxMessageSize()) - } - if 0 < s.ServerHandler.GetReadTimeout() { - conn.SetReadDeadline(time.Now().Add(s.ServerHandler.GetReadTimeout())) - } - conn.SetPongHandler(func(string) error { - conn.SetReadDeadline(time.Now().Add(s.ServerHandler.GetPongTimeout())) - return nil - }) - - var ( - message []byte - err error - ) - - for { - readMessageChan := make(chan struct{}) - - go func() { - _, message, err = conn.ReadMessage() - close(readMessageChan) - }() - - select { - case <-s.stopChan: - <-readMessageChan - return - case <-readMessageChan: - } - - if nil != err { - if server.IsUnexpectedCloseError(err, server.CloseGoingAway, server.CloseAbnormalClosure) { - logging.Logger().Debugf(s.serverMessage(fmt.Sprintf("Read error %v", err))) - } - return - } - - readChan <- message - } -} - -func handleWrite(s *Server, conn *server.Conn, stopChan <-chan struct{}, doneChan chan<- struct{}, writeChan chan []byte) { - defer func() { - doneChan <- struct{}{} - }() - - ticker := time.NewTicker(s.ServerHandler.GetPingPeriod()) - defer func() { - ticker.Stop() - }() - for { - select { - case message, ok := <-writeChan: - if 0 < s.ServerHandler.GetWriteTimeout() { - conn.SetWriteDeadline(time.Now().Add(s.ServerHandler.GetWriteTimeout())) - } - if !ok { - conn.WriteMessage(server.CloseMessage, []byte{}) - return - } - - w, err := conn.NextWriter(server.TextMessage) - if err != nil { - return - } - w.Write(message) - - if err := w.Close(); nil != err { - return - } - case <-ticker.C: - conn.SetWriteDeadline(time.Now().Add(s.ServerHandler.GetPingTimeout())) - if err := conn.WriteMessage(server.PingMessage, nil); nil != err { - return - } - case <-s.stopChan: - return - } - } + s.ServerHandler.(ServerHandler).OnError(s.ctx, ctx, status, reason) } diff --git a/net/client.go b/net/client.go index 8f8829b..eeec2e2 100644 --- a/net/client.go +++ b/net/client.go @@ -12,52 +12,25 @@ import ( ) type Client struct { + server.ClientConnectionHandlers + server.ReadWriteHandlers + Name string - Network string - Address string - TLSConfig *tls.Config - HandshakeTimeout time.Duration - KeepAlive time.Duration - LocalAddress net.Addr + Network string + Address string + LocalAddress net.Addr - MaxMessageSize int64 - // Per-connection buffer size for requests' reading. - // This also limits the maximum header size. - // - // Increase this buffer if your clients send multi-KB RequestURIs - // and/or multi-KB headers (for example, BIG cookies). - // - // Default buffer size is used if not set. - ReadBufferSize int - // Per-connection buffer size for responses' writing. - // - // Default buffer size is used if not set. - WriteBufferSize int - // Maximum duration for reading the full request (including body). - // - // This also limits the maximum duration for idle keep-alive - // connections. - // - // By default request read timeout is unlimited. - ReadTimeout time.Duration + stopChan chan struct{} + stopWg sync.WaitGroup - // Maximum duration for writing the full response (including body). - // - // By default response write timeout is unlimited. - WriteTimeout time.Duration - - PongTimeout time.Duration - PingTimeout time.Duration - PingPeriod time.Duration - - EnableCompression bool - - stopChan chan struct{} - stopWg sync.WaitGroup - conn *server.Conn readChan chan []byte writeChan chan []byte + + disconnectedChan chan struct{} + reconnectedChan chan *server.Conn + + crwch server.ClientRWCHandler } func (c *Client) Connect() (readChan <-chan []byte, writeChan chan<- []byte, err error) { @@ -81,10 +54,20 @@ func (c *Client) Connect() (readChan <-chan []byte, writeChan chan<- []byte, err c.readChan = make(chan []byte, 256) c.writeChan = make(chan []byte, 256) - + c.disconnectedChan = make(chan struct{}) + c.reconnectedChan = make(chan *server.Conn) c.stopChan = make(chan struct{}) + + c.crwch.ReadwriteHandler = c + c.crwch.ReadChan = c.readChan + c.crwch.WriteChan = c.writeChan + c.crwch.ClientStopChan = c.stopChan + c.crwch.ClientStopWg = &c.stopWg + c.crwch.DisconnectedChan = c.disconnectedChan + c.crwch.ReconnectedChan = c.reconnectedChan + c.stopWg.Add(1) - go c.handleConnection(conn) + go c.crwch.HandleConnection(conn) return c.readChan, c.writeChan, nil } @@ -119,126 +102,6 @@ func (c *Client) connect() (*server.Conn, error) { return conn, nil } -func (c *Client) handleConnection(conn *server.Conn) { - defer func() { - if nil != conn { - conn.Close() - } - logging.Logger().Infof(c.clientMessage("disconnected")) - c.stopWg.Done() - }() - - logging.Logger().Infof(c.clientMessage("connected")) - - stopChan := make(chan struct{}) - - readerDoneChan := make(chan struct{}) - writerDoneChan := make(chan struct{}) - - go handleClientRead(c, conn, stopChan, readerDoneChan) - go handleClientWrite(c, conn, stopChan, writerDoneChan) - - select { - case <-readerDoneChan: - close(stopChan) - <-writerDoneChan - case <-writerDoneChan: - close(stopChan) - <-readerDoneChan - case <-c.stopChan: - close(stopChan) - <-readerDoneChan - <-writerDoneChan - } -} - -func handleClientRead(c *Client, conn *server.Conn, stopChan <-chan struct{}, doneChan chan<- struct{}) { - defer func() { - doneChan <- struct{}{} - }() - - if 0 < c.MaxMessageSize { - conn.SetReadLimit(c.MaxMessageSize) - } - if 0 < c.ReadTimeout { - conn.SetReadDeadline(time.Now().Add(c.ReadTimeout)) - } - conn.SetPongHandler(func(string) error { - conn.SetReadDeadline(time.Now().Add(c.PongTimeout)) - return nil - }) - - var ( - message []byte - err error - ) - - for { - readMessageChan := make(chan struct{}) - - go func() { - _, message, err = conn.ReadMessage() - close(readMessageChan) - }() - - select { - case <-stopChan: - <-readMessageChan - return - case <-readMessageChan: - } - - if nil != err { - if server.IsUnexpectedCloseError(err, server.CloseGoingAway, server.CloseAbnormalClosure) { - logging.Logger().Debugf(c.clientMessage(fmt.Sprintf("Read error %v", err))) - } - return - } - - c.readChan <- message - } -} - -func handleClientWrite(c *Client, conn *server.Conn, stopChan <-chan struct{}, doneChan chan<- struct{}) { - defer func() { - doneChan <- struct{}{} - }() - - ticker := time.NewTicker(c.PingPeriod) - defer func() { - ticker.Stop() - }() - for { - select { - case message, ok := <-c.writeChan: - if 0 < c.WriteTimeout { - conn.SetWriteDeadline(time.Now().Add(c.WriteTimeout)) - } - if !ok { - conn.WriteMessage(server.CloseMessage, []byte{}) - return - } - - w, err := conn.NextWriter(server.TextMessage) - if err != nil { - return - } - w.Write(message) - - if err := w.Close(); nil != err { - return - } - case <-ticker.C: - conn.SetWriteDeadline(time.Now().Add(c.PingTimeout)) - if err := conn.WriteMessage(server.PingMessage, nil); nil != err { - return - } - case <-stopChan: - return - } - } -} - func (c *Client) Dial() (net.Conn, error) { if err := c.Validate(); nil != err { return nil, err @@ -279,6 +142,13 @@ func (c *Client) Dial() (net.Conn, error) { } func (c *Client) Validate() error { + if err := c.ClientConnectionHandlers.Validate(); nil != err { + return err + } + if err := c.ReadWriteHandlers.Validate(); nil != err { + return err + } + if "" == c.Name { c.Name = "Client" } @@ -291,33 +161,5 @@ func (c *Client) Validate() error { return fmt.Errorf("Client: Address is not valid") } - if c.HandshakeTimeout <= 0 { - c.HandshakeTimeout = server.DefaultHandshakeTimeout - } - if c.MaxMessageSize <= 0 { - c.MaxMessageSize = server.DefaultMaxMessageSize - } - if c.ReadBufferSize <= 0 { - c.ReadBufferSize = server.DefaultReadBufferSize - } - if c.WriteBufferSize <= 0 { - c.WriteBufferSize = server.DefaultWriteBufferSize - } - if c.ReadTimeout <= 0 { - c.ReadTimeout = server.DefaultReadTimeout - } - if c.WriteTimeout <= 0 { - c.WriteTimeout = server.DefaultWriteTimeout - } - if c.PongTimeout <= 0 { - c.PongTimeout = server.DefaultPongTimeout - } - if c.PingTimeout <= 0 { - c.PingTimeout = server.DefaultPingTimeout - } - if c.PingPeriod <= 0 { - c.PingPeriod = server.DefaultPingPeriod - } - return nil } diff --git a/net/server.go b/net/server.go index 806c99e..04868af 100644 --- a/net/server.go +++ b/net/server.go @@ -1,8 +1,10 @@ package net import ( + "context" "fmt" "net" + "sync" "sync/atomic" "time" @@ -10,24 +12,19 @@ import ( "git.loafle.net/commons/server-go" ) -type Server interface { - server.Server +type Server struct { + ServerHandler ServerHandler + + ctx server.ServerCtx + stopChan chan struct{} + stopWg sync.WaitGroup + + srwch server.ServerRWCHandler } -func NewServer(serverHandler ServerHandler) Server { - s := &netServer{} - s.ServerHandler = serverHandler - - return s -} - -type netServer struct { - server.Servers -} - -func (s *netServer) ListenAndServe() error { - if s.StopChan != nil { - return fmt.Errorf(s.ServerMessage("already running. Stop it before starting it again")) +func (s *Server) ListenAndServe() error { + if s.stopChan != nil { + return fmt.Errorf(s.serverMessage("already running. Stop it before starting it again")) } var ( @@ -41,25 +38,48 @@ func (s *netServer) ListenAndServe() error { return err } - s.ServerCtx = s.ServerHandler.ServerCtx() - if nil == s.ServerCtx { - return fmt.Errorf(s.ServerMessage("ServerCtx is nil")) + s.ctx = s.ServerHandler.ServerCtx() + if nil == s.ctx { + return fmt.Errorf(s.serverMessage("ServerCtx is nil")) } - if err = s.ServerHandler.Init(s.ServerCtx); nil != err { + if err = s.ServerHandler.Init(s.ctx); nil != err { return err } - if listener, err = s.ServerHandler.Listener(s.ServerCtx); nil != err { + if listener, err = s.ServerHandler.Listener(s.ctx); nil != err { return err } - s.StopChan = make(chan struct{}) - s.StopWg.Add(1) + s.stopChan = make(chan struct{}) + + s.srwch.ReadwriteHandler = s.ServerHandler + s.srwch.ServerStopChan = s.stopChan + s.srwch.ServerStopWg = &s.stopWg + + s.stopWg.Add(1) return s.handleServer(listener) } -func (s *netServer) handleServer(listener net.Listener) error { +func (s *Server) Shutdown(ctx context.Context) error { + if s.stopChan == nil { + return fmt.Errorf("server must be started before stopping it") + } + close(s.stopChan) + s.stopWg.Wait() + + s.ServerHandler.Destroy(s.ctx) + + s.stopChan = nil + + return nil +} + +func (s *Server) serverMessage(msg string) string { + return fmt.Sprintf("Server[%s]: %s", s.ServerHandler.GetName(), msg) +} + +func (s *Server) handleServer(listener net.Listener) error { var ( stopping atomic.Value netConn net.Conn @@ -71,18 +91,18 @@ func (s *netServer) handleServer(listener net.Listener) error { listener.Close() } - s.ServerHandler.OnStop(s.ServerCtx) + s.ServerHandler.OnStop(s.ctx) - logging.Logger().Infof(s.ServerMessage("Stopped")) + logging.Logger().Infof(s.serverMessage("Stopped")) - s.StopWg.Done() + s.stopWg.Done() }() - if err = s.ServerHandler.OnStart(s.ServerCtx); nil != err { + if err = s.ServerHandler.OnStart(s.ctx); nil != err { return err } - logging.Logger().Infof(s.ServerMessage("Started")) + logging.Logger().Infof(s.serverMessage("Started")) for { acceptChan := make(chan struct{}) @@ -90,14 +110,14 @@ func (s *netServer) handleServer(listener net.Listener) error { go func() { if netConn, err = listener.Accept(); err != nil { if nil == stopping.Load() { - logging.Logger().Errorf(s.ServerMessage(fmt.Sprintf("%v", err))) + logging.Logger().Errorf(s.serverMessage(fmt.Sprintf("%v", err))) } } close(acceptChan) }() select { - case <-s.StopChan: + case <-s.stopChan: stopping.Store(true) listener.Close() <-acceptChan @@ -108,7 +128,7 @@ func (s *netServer) handleServer(listener net.Listener) error { if nil != err { select { - case <-s.StopChan: + case <-s.stopChan: return nil case <-time.After(time.Second): } @@ -116,9 +136,9 @@ func (s *netServer) handleServer(listener net.Listener) error { } if 0 < s.ServerHandler.GetConcurrency() { - sz := s.ConnectionSize() + sz := s.srwch.ConnectionSize() if sz >= s.ServerHandler.GetConcurrency() { - logging.Logger().Warnf(s.ServerMessage(fmt.Sprintf("max connections size %d, refuse", sz))) + logging.Logger().Warnf(s.serverMessage(fmt.Sprintf("max connections size %d, refuse", sz))) netConn.Close() continue } @@ -126,24 +146,24 @@ func (s *netServer) handleServer(listener net.Listener) error { servlet := s.ServerHandler.(ServerHandler).Servlet() if nil == servlet { - logging.Logger().Errorf(s.ServerMessage("Servlet is nil")) + logging.Logger().Errorf(s.serverMessage("Servlet is nil")) continue } - servletCtx := servlet.ServletCtx(s.ServerCtx) + servletCtx := servlet.ServletCtx(s.ctx) if nil == servletCtx { - logging.Logger().Errorf(s.ServerMessage("ServletCtx is nil")) + logging.Logger().Errorf(s.serverMessage("ServletCtx is nil")) continue } if err := servlet.Handshake(servletCtx, netConn); nil != err { - logging.Logger().Infof(s.ServerMessage(fmt.Sprintf("Handshaking of Client[%s] has been failed %v", netConn.RemoteAddr(), err))) + logging.Logger().Infof(s.serverMessage(fmt.Sprintf("Handshaking of Client[%s] has been failed %v", netConn.RemoteAddr(), err))) continue } conn := server.NewConn(netConn, true, s.ServerHandler.GetReadBufferSize(), s.ServerHandler.GetWriteBufferSize()) - s.StopWg.Add(1) - go s.HandleConnection(servlet, servletCtx, conn) + s.stopWg.Add(1) + go s.srwch.HandleConnection(servlet, servletCtx, conn) } } diff --git a/readwrite-handler.go b/readwrite-handler.go new file mode 100644 index 0000000..68b94a9 --- /dev/null +++ b/readwrite-handler.go @@ -0,0 +1,115 @@ +package server + +import ( + "time" +) + +type ReadWriteHandler interface { + GetMaxMessageSize() int64 + GetReadBufferSize() int + GetWriteBufferSize() int + GetReadTimeout() time.Duration + GetWriteTimeout() time.Duration + GetPongTimeout() time.Duration + GetPingTimeout() time.Duration + GetPingPeriod() time.Duration + + IsEnableCompression() bool +} + +type ReadWriteHandlers struct { + ReadWriteHandler + + MaxMessageSize int64 + // Per-connection buffer size for requests' reading. + // This also limits the maximum header size. + // + // Increase this buffer if your clients send multi-KB RequestURIs + // and/or multi-KB headers (for example, BIG cookies). + // + // Default buffer size is used if not set. + ReadBufferSize int + // Per-connection buffer size for responses' writing. + // + // Default buffer size is used if not set. + WriteBufferSize int + // Maximum duration for reading the full request (including body). + // + // This also limits the maximum duration for idle keep-alive + // connections. + // + // By default request read timeout is unlimited. + ReadTimeout time.Duration + + // Maximum duration for writing the full response (including body). + // + // By default response write timeout is unlimited. + WriteTimeout time.Duration + + PongTimeout time.Duration + PingTimeout time.Duration + PingPeriod time.Duration + + EnableCompression bool +} + +func (rwh *ReadWriteHandlers) GetMaxMessageSize() int64 { + return rwh.MaxMessageSize +} + +func (rwh *ReadWriteHandlers) GetReadBufferSize() int { + return rwh.ReadBufferSize +} + +func (rwh *ReadWriteHandlers) GetWriteBufferSize() int { + return rwh.WriteBufferSize +} + +func (rwh *ReadWriteHandlers) GetReadTimeout() time.Duration { + return rwh.ReadTimeout +} +func (rwh *ReadWriteHandlers) GetWriteTimeout() time.Duration { + return rwh.WriteTimeout +} +func (rwh *ReadWriteHandlers) GetPongTimeout() time.Duration { + return rwh.PongTimeout +} +func (rwh *ReadWriteHandlers) GetPingTimeout() time.Duration { + return rwh.PingTimeout +} +func (rwh *ReadWriteHandlers) GetPingPeriod() time.Duration { + return rwh.PingPeriod +} + +func (rwh *ReadWriteHandlers) IsEnableCompression() bool { + return rwh.EnableCompression +} + +func (rwh *ReadWriteHandlers) Validate() error { + if rwh.MaxMessageSize <= 0 { + rwh.MaxMessageSize = DefaultMaxMessageSize + } + if rwh.ReadBufferSize <= 0 { + rwh.ReadBufferSize = DefaultReadBufferSize + } + if rwh.WriteBufferSize <= 0 { + rwh.WriteBufferSize = DefaultWriteBufferSize + } + if rwh.ReadTimeout <= 0 { + rwh.ReadTimeout = DefaultReadTimeout + } + if rwh.WriteTimeout <= 0 { + rwh.WriteTimeout = DefaultWriteTimeout + } + if rwh.PongTimeout <= 0 { + rwh.PongTimeout = DefaultPongTimeout + } + if rwh.PingTimeout <= 0 { + rwh.PingTimeout = DefaultPingTimeout + } + if rwh.PingPeriod <= 0 { + rwh.PingPeriod = (rwh.PingTimeout * 9) / 10 + } + + return nil +} diff --git a/readwrite.go b/readwrite.go new file mode 100644 index 0000000..0c946b3 --- /dev/null +++ b/readwrite.go @@ -0,0 +1,101 @@ +package server + +import ( + "fmt" + "io" + "time" +) + +func connReadHandler(readWriteHandler ReadWriteHandler, conn *Conn, stopChan <-chan struct{}, doneChan chan<- error, readChan chan<- []byte) { + var ( + message []byte + err error + ) + + defer func() { + doneChan <- err + }() + + if 0 < readWriteHandler.GetMaxMessageSize() { + conn.SetReadLimit(readWriteHandler.GetMaxMessageSize()) + } + if 0 < readWriteHandler.GetReadTimeout() { + conn.SetReadDeadline(time.Now().Add(readWriteHandler.GetReadTimeout())) + } + conn.SetPongHandler(func(string) error { + conn.SetReadDeadline(time.Now().Add(readWriteHandler.GetPongTimeout())) + return nil + }) + + for { + readMessageChan := make(chan struct{}) + + go func() { + _, message, err = conn.ReadMessage() + close(readMessageChan) + }() + + select { + case <-stopChan: + <-readMessageChan + return + case <-readMessageChan: + } + + if nil != err { + if IsUnexpectedCloseError(err, CloseGoingAway, CloseAbnormalClosure) { + err = fmt.Errorf("Read error %v", err) + } + return + } + + readChan <- message + } +} + +func connWriteHandler(readWriteHandler ReadWriteHandler, conn *Conn, stopChan <-chan struct{}, doneChan chan<- error, writeChan <-chan []byte) { + var ( + wc io.WriteCloser + message []byte + ok bool + err error + ) + + defer func() { + doneChan <- err + }() + + ticker := time.NewTicker(readWriteHandler.GetPingPeriod()) + defer func() { + ticker.Stop() + }() + for { + select { + case message, ok = <-writeChan: + if 0 < readWriteHandler.GetWriteTimeout() { + conn.SetWriteDeadline(time.Now().Add(readWriteHandler.GetWriteTimeout())) + } + if !ok { + conn.WriteMessage(CloseMessage, []byte{}) + return + } + + wc, err = conn.NextWriter(TextMessage) + if err != nil { + return + } + wc.Write(message) + + if err = wc.Close(); nil != err { + return + } + case <-ticker.C: + conn.SetWriteDeadline(time.Now().Add(readWriteHandler.GetPingTimeout())) + if err = conn.WriteMessage(PingMessage, nil); nil != err { + return + } + case <-stopChan: + return + } + } +} diff --git a/server-handler.go b/server-handler.go index 601768b..e83fd4b 100644 --- a/server-handler.go +++ b/server-handler.go @@ -1,26 +1,10 @@ package server -import ( - "fmt" - "net" - "time" -) - type ServerHandler interface { + ConnectionHandler + ReadWriteHandler + GetName() string - GetConcurrency() int - GetHandshakeTimeout() time.Duration - GetMaxMessageSize() int64 - GetReadBufferSize() int - GetWriteBufferSize() int - GetReadTimeout() time.Duration - GetWriteTimeout() time.Duration - GetPongTimeout() time.Duration - GetPingTimeout() time.Duration - GetPingPeriod() time.Duration - - IsEnableCompression() bool - ServerCtx() ServerCtx Init(serverCtx ServerCtx) error @@ -28,57 +12,18 @@ type ServerHandler interface { OnStop(serverCtx ServerCtx) Destroy(serverCtx ServerCtx) - Listener(serverCtx ServerCtx) (net.Listener, error) - Validate() error } type ServerHandlers struct { ServerHandler + ConnectionHandlers + ReadWriteHandlers // Server name for sending in response headers. // // Default server name is used if left blank. Name string - - // The maximum number of concurrent connections the server may serve. - // - // DefaultConcurrency is used if not set. - Concurrency int - - HandshakeTimeout time.Duration - - MaxMessageSize int64 - // Per-connection buffer size for requests' reading. - // This also limits the maximum header size. - // - // Increase this buffer if your clients send multi-KB RequestURIs - // and/or multi-KB headers (for example, BIG cookies). - // - // Default buffer size is used if not set. - ReadBufferSize int - // Per-connection buffer size for responses' writing. - // - // Default buffer size is used if not set. - WriteBufferSize int - // Maximum duration for reading the full request (including body). - // - // This also limits the maximum duration for idle keep-alive - // connections. - // - // By default request read timeout is unlimited. - ReadTimeout time.Duration - - // Maximum duration for writing the full response (including body). - // - // By default response write timeout is unlimited. - WriteTimeout time.Duration - - PongTimeout time.Duration - PingTimeout time.Duration - PingPeriod time.Duration - - EnableCompression bool } func (sh *ServerHandlers) ServerCtx() ServerCtx { @@ -101,90 +46,21 @@ func (sh *ServerHandlers) Destroy(serverCtx ServerCtx) { } -func (sh *ServerHandlers) Listener(serverCtx ServerCtx) (net.Listener, error) { - return nil, fmt.Errorf("Server: Method[ServerHandler.Listener] is not implemented") -} - func (sh *ServerHandlers) GetName() string { return sh.Name } -func (sh *ServerHandlers) GetConcurrency() int { - return sh.Concurrency -} - -func (sh *ServerHandlers) GetHandshakeTimeout() time.Duration { - return sh.HandshakeTimeout -} - -func (sh *ServerHandlers) GetMaxMessageSize() int64 { - return sh.MaxMessageSize -} - -func (sh *ServerHandlers) GetReadBufferSize() int { - return sh.ReadBufferSize -} - -func (sh *ServerHandlers) GetWriteBufferSize() int { - return sh.WriteBufferSize -} - -func (sh *ServerHandlers) GetReadTimeout() time.Duration { - return sh.ReadTimeout -} -func (sh *ServerHandlers) GetWriteTimeout() time.Duration { - return sh.WriteTimeout -} -func (sh *ServerHandlers) GetPongTimeout() time.Duration { - return sh.PongTimeout -} -func (sh *ServerHandlers) GetPingTimeout() time.Duration { - return sh.PingTimeout -} -func (sh *ServerHandlers) GetPingPeriod() time.Duration { - return sh.PingPeriod -} - -func (sh *ServerHandlers) IsEnableCompression() bool { - return sh.EnableCompression -} - func (sh *ServerHandlers) Validate() error { + if err := sh.ConnectionHandlers.Validate(); nil != err { + return err + } + if err := sh.ReadWriteHandlers.Validate(); nil != err { + return err + } + if "" == sh.Name { sh.Name = "Server" } - if sh.Concurrency <= 0 { - sh.Concurrency = DefaultConcurrency - } - - if sh.HandshakeTimeout <= 0 { - sh.HandshakeTimeout = DefaultHandshakeTimeout - } - if sh.MaxMessageSize <= 0 { - sh.MaxMessageSize = DefaultMaxMessageSize - } - if sh.ReadBufferSize <= 0 { - sh.ReadBufferSize = DefaultReadBufferSize - } - if sh.WriteBufferSize <= 0 { - sh.WriteBufferSize = DefaultWriteBufferSize - } - if sh.ReadTimeout <= 0 { - sh.ReadTimeout = DefaultReadTimeout - } - if sh.WriteTimeout <= 0 { - sh.WriteTimeout = DefaultWriteTimeout - } - if sh.PongTimeout <= 0 { - sh.PongTimeout = DefaultPongTimeout - } - if sh.PingTimeout <= 0 { - sh.PingTimeout = DefaultPingTimeout - } - if sh.PingPeriod <= 0 { - sh.PingPeriod = (sh.PingTimeout * 9) / 10 - } - return nil } diff --git a/server-rwc-handler.go b/server-rwc-handler.go new file mode 100644 index 0000000..9ed7680 --- /dev/null +++ b/server-rwc-handler.go @@ -0,0 +1,81 @@ +package server + +import ( + "sync" + + logging "git.loafle.net/commons/logging-go" +) + +type ServerRWCHandler struct { + connections sync.Map + + ReadwriteHandler ReadWriteHandler + ServerStopChan <-chan struct{} + ServerStopWg *sync.WaitGroup +} + +func (srwch *ServerRWCHandler) ConnectionSize() int { + var sz int + srwch.connections.Range(func(k, v interface{}) bool { + sz++ + return true + }) + return sz +} + +func (srwch *ServerRWCHandler) HandleConnection(servlet Servlet, servletCtx ServletCtx, conn *Conn) { + addr := conn.RemoteAddr() + + defer func() { + if nil != conn { + conn.Close() + } + servlet.OnDisconnect(servletCtx) + logging.Logger().Infof("Client[%s] has been disconnected", addr) + srwch.ServerStopWg.Done() + }() + + logging.Logger().Infof("Client[%s] has been connected", addr) + + srwch.connections.Store(conn, true) + defer srwch.connections.Delete(conn) + + servlet.OnConnect(servletCtx, conn) + conn.SetCloseHandler(func(code int, text string) error { + logging.Logger().Debugf("close") + return nil + }) + + stopChan := make(chan struct{}) + servletDoneChan := make(chan struct{}) + + readChan := make(chan []byte) + writeChan := make(chan []byte) + + readerDoneChan := make(chan error) + writerDoneChan := make(chan error) + + go servlet.Handle(servletCtx, stopChan, servletDoneChan, readChan, writeChan) + go connReadHandler(srwch.ReadwriteHandler, conn, stopChan, readerDoneChan, readChan) + go connWriteHandler(srwch.ReadwriteHandler, conn, stopChan, writerDoneChan, writeChan) + + select { + case <-readerDoneChan: + close(stopChan) + <-writerDoneChan + <-servletDoneChan + case <-writerDoneChan: + close(stopChan) + <-readerDoneChan + <-servletDoneChan + case <-servletDoneChan: + close(stopChan) + <-readerDoneChan + <-writerDoneChan + case <-srwch.ServerStopChan: + close(stopChan) + <-readerDoneChan + <-writerDoneChan + <-servletDoneChan + } +} diff --git a/server.go b/server.go deleted file mode 100644 index 13efc35..0000000 --- a/server.go +++ /dev/null @@ -1,200 +0,0 @@ -package server - -import ( - "context" - "fmt" - "sync" - "time" - - logging "git.loafle.net/commons/logging-go" -) - -type Server interface { - ListenAndServe() error - Shutdown(ctx context.Context) error - ConnectionSize() int - - HandleConnection(servlet Servlet, servletCtx ServletCtx, conn *Conn) - HandleRead(conn *Conn, stopChan <-chan struct{}, doneChan chan<- struct{}, readChan chan []byte) - HandleWrite(conn *Conn, stopChan <-chan struct{}, doneChan chan<- struct{}, writeChan chan []byte) -} - -type Servers struct { - ServerHandler ServerHandler - - ServerCtx ServerCtx - Connections sync.Map - StopChan chan struct{} - StopWg sync.WaitGroup -} - -func (s *Servers) Shutdown(ctx context.Context) error { - if s.StopChan == nil { - return fmt.Errorf("server must be started before stopping it") - } - close(s.StopChan) - s.StopWg.Wait() - - s.ServerHandler.Destroy(s.ServerCtx) - - s.StopChan = nil - - return nil -} - -func (s *Servers) ConnectionSize() int { - var sz int - s.Connections.Range(func(k, v interface{}) bool { - sz++ - return true - }) - return sz -} - -func (s *Servers) ServerMessage(msg string) string { - return fmt.Sprintf("Server[%s]: %s", s.ServerHandler.GetName(), msg) -} - -func (s *Servers) HandleConnection(servlet Servlet, servletCtx ServletCtx, conn *Conn) { - addr := conn.RemoteAddr() - - defer func() { - if nil != conn { - conn.Close() - } - servlet.OnDisconnect(servletCtx) - logging.Logger().Infof(s.ServerMessage(fmt.Sprintf("Client[%s] has been disconnected", addr))) - s.StopWg.Done() - }() - - logging.Logger().Infof(s.ServerMessage(fmt.Sprintf("Client[%s] has been connected", addr))) - - s.Connections.Store(conn, true) - defer s.Connections.Delete(conn) - - servlet.OnConnect(servletCtx, conn) - conn.SetCloseHandler(func(code int, text string) error { - logging.Logger().Debugf("close") - return nil - }) - - stopChan := make(chan struct{}) - servletDoneChan := make(chan struct{}) - - readChan := make(chan []byte) - writeChan := make(chan []byte) - - readerDoneChan := make(chan struct{}) - writerDoneChan := make(chan struct{}) - - go servlet.Handle(servletCtx, stopChan, servletDoneChan, readChan, writeChan) - go s.HandleRead(conn, stopChan, readerDoneChan, readChan) - go s.HandleWrite(conn, stopChan, writerDoneChan, writeChan) - - select { - case <-readerDoneChan: - close(stopChan) - <-writerDoneChan - <-servletDoneChan - case <-writerDoneChan: - close(stopChan) - <-readerDoneChan - <-servletDoneChan - case <-servletDoneChan: - close(stopChan) - <-readerDoneChan - <-writerDoneChan - case <-s.StopChan: - close(stopChan) - <-readerDoneChan - <-writerDoneChan - <-servletDoneChan - } -} - -func (s *Servers) HandleRead(conn *Conn, stopChan <-chan struct{}, doneChan chan<- struct{}, readChan chan []byte) { - defer func() { - doneChan <- struct{}{} - }() - - if 0 < s.ServerHandler.GetMaxMessageSize() { - conn.SetReadLimit(s.ServerHandler.GetMaxMessageSize()) - } - if 0 < s.ServerHandler.GetReadTimeout() { - conn.SetReadDeadline(time.Now().Add(s.ServerHandler.GetReadTimeout())) - } - conn.SetPongHandler(func(string) error { - conn.SetReadDeadline(time.Now().Add(s.ServerHandler.GetPongTimeout())) - return nil - }) - - var ( - message []byte - err error - ) - - for { - readMessageChan := make(chan struct{}) - - go func() { - _, message, err = conn.ReadMessage() - close(readMessageChan) - }() - - select { - case <-s.StopChan: - <-readMessageChan - return - case <-readMessageChan: - } - - if nil != err { - if IsUnexpectedCloseError(err, CloseGoingAway, CloseAbnormalClosure) { - logging.Logger().Debugf(s.ServerMessage(fmt.Sprintf("Read error %v", err))) - } - return - } - - readChan <- message - } -} - -func (s *Servers) HandleWrite(conn *Conn, stopChan <-chan struct{}, doneChan chan<- struct{}, writeChan chan []byte) { - defer func() { - doneChan <- struct{}{} - }() - - ticker := time.NewTicker(s.ServerHandler.GetPingPeriod()) - defer func() { - ticker.Stop() - }() - for { - select { - case message, ok := <-writeChan: - if 0 < s.ServerHandler.GetWriteTimeout() { - conn.SetWriteDeadline(time.Now().Add(s.ServerHandler.GetWriteTimeout())) - } - if !ok { - conn.WriteMessage(CloseMessage, []byte{}) - return - } - - w, err := conn.NextWriter(TextMessage) - if err != nil { - return - } - w.Write(message) - - if err := w.Close(); nil != err { - return - } - case <-ticker.C: - conn.SetWriteDeadline(time.Now().Add(s.ServerHandler.GetPingTimeout())) - if err := conn.WriteMessage(PingMessage, nil); nil != err { - return - } - case <-stopChan: - return - } - } -}