diff --git a/fasthttp/websocket/client.go b/fasthttp/websocket/client.go index 2e782e5..e6e2fe4 100644 --- a/fasthttp/websocket/client.go +++ b/fasthttp/websocket/client.go @@ -94,7 +94,8 @@ func (c *Client) Connect() (readChan <-chan []byte, writeChan chan<- []byte, res c.crw.DisconnectedChan = c.disconnectedChan c.crw.ReconnectedChan = c.reconnectedChan - c.stopWg.Add(1) + c.stopWg.Add(2) + go c.handleReconnect() go c.crw.HandleConnection(conn) return c.readChan, c.writeChan, res, nil @@ -116,6 +117,35 @@ func (c *Client) clientMessage(msg string) string { return fmt.Sprintf("Client[%s]: %s", c.Name, msg) } +func (c *Client) handleReconnect() { + defer func() { + c.stopWg.Done() + }() + +RC_LOOP: + for { + select { + case <-c.disconnectedChan: + case <-c.stopChan: + return + } + + if 0 >= c.ReconnectTryTime { + c.reconnectedChan <- nil + continue RC_LOOP + } + + for indexI := 0; indexI < c.ReconnectTryTime; indexI++ { + conn, _, err := c.connect() + if nil == err { + c.reconnectedChan <- conn + continue RC_LOOP + } + time.Sleep(c.ReconnectInterval) + } + } +} + func (c *Client) connect() (*server.Conn, *http.Response, error) { conn, res, err := c.Dial() if nil != err { diff --git a/net/client.go b/net/client.go index 99b41dd..a75e287 100644 --- a/net/client.go +++ b/net/client.go @@ -66,7 +66,8 @@ func (c *Client) Connect() (readChan <-chan []byte, writeChan chan<- []byte, err c.crw.DisconnectedChan = c.disconnectedChan c.crw.ReconnectedChan = c.reconnectedChan - c.stopWg.Add(1) + c.stopWg.Add(2) + go c.handleReconnect() go c.crw.HandleConnection(conn) return c.readChan, c.writeChan, nil @@ -88,6 +89,35 @@ func (c *Client) clientMessage(msg string) string { return fmt.Sprintf("Client[%s]: %s", c.Name, msg) } +func (c *Client) handleReconnect() { + defer func() { + c.stopWg.Done() + }() + +RC_LOOP: + for { + select { + case <-c.disconnectedChan: + case <-c.stopChan: + return + } + + if 0 >= c.ReconnectTryTime { + c.reconnectedChan <- nil + continue RC_LOOP + } + + for indexI := 0; indexI < c.ReconnectTryTime; indexI++ { + conn, err := c.connect() + if nil == err { + c.reconnectedChan <- conn + continue RC_LOOP + } + time.Sleep(c.ReconnectInterval) + } + } +} + func (c *Client) connect() (*server.Conn, error) { netConn, err := c.Dial() if nil != err {