diff --git a/const.go b/const.go index 40cd029..baa29c5 100644 --- a/const.go +++ b/const.go @@ -5,6 +5,8 @@ const ( // the Server may serve by default (i.e. if Server.Concurrency isn't set). DefaultConcurrency = 256 * 1024 + DefaultKeepAlive = 0 + // DefaultHandshakeTimeout is default value of websocket handshake Timeout DefaultHandshakeTimeout = 0 diff --git a/fasthttp/websocket/client.go b/fasthttp/websocket/client.go index 4643352..386a868 100644 --- a/fasthttp/websocket/client.go +++ b/fasthttp/websocket/client.go @@ -13,8 +13,10 @@ import ( "net/http" "net/url" "strings" + "sync" "time" + logging "git.loafle.net/commons/logging-go" server "git.loafle.net/commons/server-go" "git.loafle.net/commons/server-go/internal" ) @@ -27,6 +29,11 @@ type Client struct { URL string RequestHeader http.Header + Subprotocols []string + // Jar specifies the cookie jar. + // If Jar is nil, cookies are not sent in requests and ignored + // in responses. + CookieJar http.CookieJar // NetDial specifies the dial function for creating TCP connections. If // NetDial is nil, net.Dial is used. @@ -45,7 +52,6 @@ type Client struct { // do not limit the size of the messages that can be sent or received. ReadBufferSize, WriteBufferSize int // Subprotocols specifies the client's requested subprotocols. - Subprotocols []string // EnableCompression specifies if the client should attempt to negotiate // per message compression (RFC 7692). Setting this value to true does not @@ -53,11 +59,6 @@ type Client struct { // takeover" modes are supported. EnableCompression bool - // Jar specifies the cookie jar. - // If Jar is nil, cookies are not sent in requests and ignored - // in responses. - CookieJar http.CookieJar - // MaxMessageSize is the maximum size for a message read from the peer. If a // message exceeds the limit, the connection sends a close frame to the peer // and returns ErrReadLimit to the application. @@ -78,6 +79,197 @@ type Client struct { PingPeriod time.Duration serverURL *url.URL + + stopChan chan struct{} + stopWg sync.WaitGroup + conn *internal.Conn + readChan chan []byte + writeChan chan []byte +} + +func (c *Client) Connect() (readChan <-chan []byte, writeChan chan<- []byte, res *http.Response, err error) { + var ( + conn *internal.Conn + ) + + if c.stopChan != nil { + return nil, nil, nil, fmt.Errorf(c.clientMessage("already running. Stop it before starting it again")) + } + + err = c.Validate() + if nil != err { + return nil, nil, nil, err + } + + conn, res, err = c.connect() + if nil != err { + return nil, nil, nil, err + } + + c.readChan = make(chan []byte, 256) + c.writeChan = make(chan []byte, 256) + + c.stopChan = make(chan struct{}) + c.stopWg.Add(1) + go c.handleConnection(conn) + + return c.readChan, c.writeChan, res, nil +} + +func (c *Client) Disconnect() error { + if c.stopChan == nil { + return fmt.Errorf(c.clientMessage("must be started before stopping it")) + } + close(c.stopChan) + c.stopWg.Wait() + + c.stopChan = nil + + return nil +} + +func (c *Client) clientMessage(msg string) string { + return fmt.Sprintf("Client[%s]: %s", c.Name, msg) +} + +func (c *Client) connect() (*internal.Conn, *http.Response, error) { + conn, res, err := c.Dial() + if nil != err { + return nil, nil, err + } + + conn.SetCloseHandler(func(code int, text string) error { + logging.Logger().Debugf("close") + return nil + }) + return conn, res, nil +} + +func (c *Client) handleConnection(conn *internal.Conn) { + defer func() { + if nil != conn { + conn.Close() + } + logging.Logger().Infof(c.clientMessage("disconnected")) + c.stopWg.Done() + }() + + logging.Logger().Infof(c.clientMessage("connected")) + + stopChan := make(chan struct{}) + + readerDoneChan := make(chan struct{}) + writerDoneChan := make(chan struct{}) + + go handleClientRead(c, conn, stopChan, readerDoneChan) + go handleClientWrite(c, conn, stopChan, writerDoneChan) + + select { + case <-readerDoneChan: + close(stopChan) + conn.Close() + <-writerDoneChan + conn = nil + case <-writerDoneChan: + close(stopChan) + conn.Close() + <-readerDoneChan + conn = nil + case <-c.stopChan: + close(stopChan) + conn.Close() + <-readerDoneChan + <-writerDoneChan + conn = nil + } +} + +func handleClientRead(c *Client, conn *internal.Conn, doneChan chan<- struct{}, stopChan <-chan struct{}) { + defer func() { + close(doneChan) + }() + + conn.SetReadLimit(c.MaxMessageSize) + conn.SetReadDeadline(time.Now().Add(c.ReadTimeout)) + conn.SetPongHandler(func(string) error { + conn.SetReadDeadline(time.Now().Add(c.PongTimeout)) + return nil + }) + + var ( + message []byte + err error + ) + + for { + readMessageChan := make(chan struct{}) + + go func() { + _, message, err = conn.ReadMessage() + if err != nil { + if internal.IsUnexpectedCloseError(err, internal.CloseGoingAway, internal.CloseAbnormalClosure) { + logging.Logger().Debugf(c.clientMessage(fmt.Sprintf("Read error %v", err))) + } + } + close(readMessageChan) + }() + + select { + case <-c.stopChan: + <-readMessageChan + break + case <-readMessageChan: + } + + if nil != err { + select { + case <-c.stopChan: + break + case <-time.After(time.Second): + } + continue + } + + c.readChan <- message + } +} + +func handleClientWrite(c *Client, conn *internal.Conn, doneChan chan<- struct{}, stopChan <-chan struct{}) { + defer func() { + close(doneChan) + }() + + ticker := time.NewTicker(c.PingPeriod) + defer func() { + ticker.Stop() + }() + for { + select { + case message, ok := <-c.writeChan: + conn.SetWriteDeadline(time.Now().Add(c.WriteTimeout)) + if !ok { + conn.WriteMessage(internal.CloseMessage, []byte{}) + return + } + + w, err := conn.NextWriter(internal.TextMessage) + if err != nil { + return + } + w.Write(message) + + if err := w.Close(); nil != err { + return + } + case <-ticker.C: + conn.SetWriteDeadline(time.Now().Add(c.PingTimeout)) + if err := conn.WriteMessage(internal.PingMessage, nil); nil != err { + return + } + case <-c.stopChan: + break + } + } } func (c *Client) Dial() (*internal.Conn, *http.Response, error) { @@ -325,15 +517,33 @@ func (c *Client) Validate() error { c.Proxy = http.ProxyFromEnvironment } - if 0 > c.HandshakeTimeout { + if c.HandshakeTimeout <= 0 { c.HandshakeTimeout = server.DefaultHandshakeTimeout } - if 0 > c.ReadBufferSize { + if c.MaxMessageSize <= 0 { + c.MaxMessageSize = server.DefaultMaxMessageSize + } + if c.ReadBufferSize <= 0 { c.ReadBufferSize = server.DefaultReadBufferSize } - if 0 > c.WriteBufferSize { + if c.WriteBufferSize <= 0 { c.WriteBufferSize = server.DefaultWriteBufferSize } + if c.ReadTimeout <= 0 { + c.ReadTimeout = server.DefaultReadTimeout + } + if c.WriteTimeout <= 0 { + c.WriteTimeout = server.DefaultWriteTimeout + } + if c.PongTimeout <= 0 { + c.PongTimeout = server.DefaultPongTimeout + } + if c.PingTimeout <= 0 { + c.PingTimeout = server.DefaultPingTimeout + } + if c.PingPeriod <= 0 { + c.PingPeriod = server.DefaultPingPeriod + } return nil } diff --git a/fasthttp/websocket/server.go b/fasthttp/websocket/server.go index 4e8dad1..44a6a36 100644 --- a/fasthttp/websocket/server.go +++ b/fasthttp/websocket/server.go @@ -199,33 +199,60 @@ func (s *Server) handleConnection(servlet Servlet, servletCtx server.ServletCtx, addr := conn.RemoteAddr() defer func() { - s.connections.Delete(conn) - conn.Close() + if nil != conn { + conn.Close() + } servlet.OnDisconnect(servletCtx) logging.Logger().Infof(s.serverMessage(fmt.Sprintf("Client[%s] has been disconnected", addr))) s.stopWg.Done() }() logging.Logger().Infof(s.serverMessage(fmt.Sprintf("Client[%s] has been connected", addr))) + s.connections.Store(conn, true) + defer s.connections.Delete(conn) + servlet.OnConnect(servletCtx, conn) - servletStopChan := make(chan struct{}) - doneChan := make(chan struct{}) + stopChan := make(chan struct{}) + servletDoneChan := make(chan struct{}) readChan := make(chan []byte) writeChan := make(chan []byte) - go servlet.Handle(servletCtx, doneChan, servletStopChan, readChan, writeChan) - go handleRead(s, conn, readChan) - go handleWrite(s, conn, writeChan) + readerDoneChan := make(chan struct{}) + writerDoneChan := make(chan struct{}) + + go servlet.Handle(servletCtx, stopChan, servletDoneChan, readChan, writeChan) + go handleRead(s, conn, stopChan, readerDoneChan, readChan) + go handleWrite(s, conn, stopChan, writerDoneChan, writeChan) select { - case <-doneChan: - close(servletStopChan) + case <-readerDoneChan: + close(stopChan) + conn.Close() + <-writerDoneChan + <-servletDoneChan + conn = nil + case <-writerDoneChan: + close(stopChan) + conn.Close() + <-readerDoneChan + <-servletDoneChan + conn = nil + case <-servletDoneChan: + close(stopChan) + conn.Close() + <-readerDoneChan + <-writerDoneChan + conn = nil case <-s.stopChan: - close(servletStopChan) - <-doneChan + close(stopChan) + conn.Close() + <-readerDoneChan + <-writerDoneChan + <-servletDoneChan + conn = nil } } @@ -233,7 +260,11 @@ func (s *Server) onError(ctx *fasthttp.RequestCtx, status int, reason error) { s.ServerHandler.OnError(s.ctx, ctx, status, reason) } -func handleRead(s *Server, conn *internal.Conn, readChan chan []byte) { +func handleRead(s *Server, conn *internal.Conn, doneChan chan<- struct{}, stopChan <-chan struct{}, readChan chan []byte) { + defer func() { + close(doneChan) + }() + conn.SetReadLimit(s.ServerHandler.GetMaxMessageSize()) conn.SetReadDeadline(time.Now().Add(s.ServerHandler.GetReadTimeout())) conn.SetPongHandler(func(string) error { @@ -241,19 +272,49 @@ func handleRead(s *Server, conn *internal.Conn, readChan chan []byte) { return nil }) + var ( + message []byte + err error + ) + for { - _, message, err := conn.ReadMessage() - if err != nil { - if internal.IsUnexpectedCloseError(err, internal.CloseGoingAway, internal.CloseAbnormalClosure) { - logging.Logger().Debugf(s.serverMessage(fmt.Sprintf("Read error %v", err))) + readMessageChan := make(chan struct{}) + + go func() { + _, message, err = conn.ReadMessage() + if err != nil { + if internal.IsUnexpectedCloseError(err, internal.CloseGoingAway, internal.CloseAbnormalClosure) { + logging.Logger().Debugf(s.serverMessage(fmt.Sprintf("Read error %v", err))) + } } + close(readMessageChan) + }() + + select { + case <-s.stopChan: + <-readMessageChan break + case <-readMessageChan: } + + if nil != err { + select { + case <-s.stopChan: + break + case <-time.After(time.Second): + } + continue + } + readChan <- message } } -func handleWrite(s *Server, conn *internal.Conn, writeChan chan []byte) { +func handleWrite(s *Server, conn *internal.Conn, doneChan chan<- struct{}, stopChan <-chan struct{}, writeChan chan []byte) { + defer func() { + close(doneChan) + }() + ticker := time.NewTicker(s.ServerHandler.GetPingPeriod()) defer func() { ticker.Stop() @@ -281,6 +342,8 @@ func handleWrite(s *Server, conn *internal.Conn, writeChan chan []byte) { if err := conn.WriteMessage(internal.PingMessage, nil); nil != err { return } + case <-s.stopChan: + break } } } diff --git a/net/client.go b/net/client.go index 20dd31d..b7c3438 100644 --- a/net/client.go +++ b/net/client.go @@ -4,7 +4,12 @@ import ( "crypto/tls" "fmt" "net" + "sync" "time" + + logging "git.loafle.net/commons/logging-go" + server "git.loafle.net/commons/server-go" + "git.loafle.net/commons/server-go/internal" ) type Client struct { @@ -17,7 +22,229 @@ type Client struct { KeepAlive time.Duration LocalAddress net.Addr - MaxConnections int + MaxMessageSize int64 + // Per-connection buffer size for requests' reading. + // This also limits the maximum header size. + // + // Increase this buffer if your clients send multi-KB RequestURIs + // and/or multi-KB headers (for example, BIG cookies). + // + // Default buffer size is used if not set. + ReadBufferSize int + // Per-connection buffer size for responses' writing. + // + // Default buffer size is used if not set. + WriteBufferSize int + // Maximum duration for reading the full request (including body). + // + // This also limits the maximum duration for idle keep-alive + // connections. + // + // By default request read timeout is unlimited. + ReadTimeout time.Duration + + // Maximum duration for writing the full response (including body). + // + // By default response write timeout is unlimited. + WriteTimeout time.Duration + + PongTimeout time.Duration + PingTimeout time.Duration + PingPeriod time.Duration + + EnableCompression bool + + stopChan chan struct{} + stopWg sync.WaitGroup + conn *internal.Conn + readChan chan []byte + writeChan chan []byte +} + +func (c *Client) Connect() (readChan <-chan []byte, writeChan chan<- []byte, err error) { + var ( + conn *internal.Conn + ) + + if c.stopChan != nil { + return nil, nil, fmt.Errorf(c.clientMessage("already running. Stop it before starting it again")) + } + + err = c.Validate() + if nil != err { + return nil, nil, err + } + + conn, err = c.connect() + if nil != err { + return nil, nil, err + } + + c.readChan = make(chan []byte, 256) + c.writeChan = make(chan []byte, 256) + + c.stopChan = make(chan struct{}) + c.stopWg.Add(1) + go c.handleConnection(conn) + + return c.readChan, c.writeChan, nil +} + +func (c *Client) Disconnect() error { + if c.stopChan == nil { + return fmt.Errorf(c.clientMessage("must be started before stopping it")) + } + close(c.stopChan) + c.stopWg.Wait() + + c.stopChan = nil + + return nil +} + +func (c *Client) clientMessage(msg string) string { + return fmt.Sprintf("Client[%s]: %s", c.Name, msg) +} + +func (c *Client) connect() (*internal.Conn, error) { + netConn, err := c.Dial() + if nil != err { + return nil, err + } + + conn := internal.NewConn(netConn, false, c.ReadBufferSize, c.WriteBufferSize) + conn.SetCloseHandler(func(code int, text string) error { + logging.Logger().Debugf("close") + return nil + }) + return conn, nil +} + +func (c *Client) handleConnection(conn *internal.Conn) { + defer func() { + if nil != conn { + conn.Close() + } + logging.Logger().Infof(c.clientMessage("disconnected")) + c.stopWg.Done() + }() + + logging.Logger().Infof(c.clientMessage("connected")) + + stopChan := make(chan struct{}) + + readerDoneChan := make(chan struct{}) + writerDoneChan := make(chan struct{}) + + go handleClientRead(c, conn, stopChan, readerDoneChan) + go handleClientWrite(c, conn, stopChan, writerDoneChan) + + select { + case <-readerDoneChan: + close(stopChan) + conn.Close() + <-writerDoneChan + conn = nil + case <-writerDoneChan: + close(stopChan) + conn.Close() + <-readerDoneChan + conn = nil + case <-c.stopChan: + close(stopChan) + conn.Close() + <-readerDoneChan + <-writerDoneChan + conn = nil + } +} + +func handleClientRead(c *Client, conn *internal.Conn, doneChan chan<- struct{}, stopChan <-chan struct{}) { + defer func() { + close(doneChan) + }() + + conn.SetReadLimit(c.MaxMessageSize) + conn.SetReadDeadline(time.Now().Add(c.ReadTimeout)) + conn.SetPongHandler(func(string) error { + conn.SetReadDeadline(time.Now().Add(c.PongTimeout)) + return nil + }) + + var ( + message []byte + err error + ) + + for { + readMessageChan := make(chan struct{}) + + go func() { + _, message, err = conn.ReadMessage() + if err != nil { + if internal.IsUnexpectedCloseError(err, internal.CloseGoingAway, internal.CloseAbnormalClosure) { + logging.Logger().Debugf(c.clientMessage(fmt.Sprintf("Read error %v", err))) + } + } + close(readMessageChan) + }() + + select { + case <-c.stopChan: + <-readMessageChan + break + case <-readMessageChan: + } + + if nil != err { + select { + case <-c.stopChan: + break + case <-time.After(time.Second): + } + continue + } + + c.readChan <- message + } +} + +func handleClientWrite(c *Client, conn *internal.Conn, doneChan chan<- struct{}, stopChan <-chan struct{}) { + defer func() { + close(doneChan) + }() + + ticker := time.NewTicker(c.PingPeriod) + defer func() { + ticker.Stop() + }() + for { + select { + case message, ok := <-c.writeChan: + conn.SetWriteDeadline(time.Now().Add(c.WriteTimeout)) + if !ok { + conn.WriteMessage(internal.CloseMessage, []byte{}) + return + } + + w, err := conn.NextWriter(internal.TextMessage) + if err != nil { + return + } + w.Write(message) + + if err := w.Close(); nil != err { + return + } + case <-ticker.C: + conn.SetWriteDeadline(time.Now().Add(c.PingTimeout)) + if err := conn.WriteMessage(internal.PingMessage, nil); nil != err { + return + } + case <-c.stopChan: + break + } + } } func (c *Client) Dial() (net.Conn, error) { @@ -72,15 +299,32 @@ func (c *Client) Validate() error { return fmt.Errorf("Client: Address is not valid") } - if 0 >= c.MaxConnections { - c.MaxConnections = 1 + if c.HandshakeTimeout <= 0 { + c.HandshakeTimeout = server.DefaultHandshakeTimeout } - - if 0 >= c.KeepAlive { - c.KeepAlive = DefaultKeepAlive + if c.MaxMessageSize <= 0 { + c.MaxMessageSize = server.DefaultMaxMessageSize } - if 0 >= c.HandshakeTimeout { - c.HandshakeTimeout = DefaultHandshakeTimeout + if c.ReadBufferSize <= 0 { + c.ReadBufferSize = server.DefaultReadBufferSize + } + if c.WriteBufferSize <= 0 { + c.WriteBufferSize = server.DefaultWriteBufferSize + } + if c.ReadTimeout <= 0 { + c.ReadTimeout = server.DefaultReadTimeout + } + if c.WriteTimeout <= 0 { + c.WriteTimeout = server.DefaultWriteTimeout + } + if c.PongTimeout <= 0 { + c.PongTimeout = server.DefaultPongTimeout + } + if c.PingTimeout <= 0 { + c.PingTimeout = server.DefaultPingTimeout + } + if c.PingPeriod <= 0 { + c.PingPeriod = server.DefaultPingPeriod } return nil diff --git a/net/const.go b/net/const.go deleted file mode 100644 index 4f3d4ad..0000000 --- a/net/const.go +++ /dev/null @@ -1,7 +0,0 @@ -package net - -const ( - DefaultHandshakeTimeout = 0 - - DefaultKeepAlive = 0 -)