diff --git a/fasthttp/websocket/client.go b/fasthttp/websocket/client.go index 711c09e..8f8ec7a 100644 --- a/fasthttp/websocket/client.go +++ b/fasthttp/websocket/client.go @@ -37,6 +37,8 @@ type Client struct { // in responses. CookieJar http.CookieJar + ResponseHandler func(*http.Response) + // NetDial specifies the dial function for creating TCP connections. If // NetDial is nil, net.Dial is used. NetDial func(network, addr string) (net.Conn, error) @@ -61,23 +63,28 @@ type Client struct { crw server.ClientReadWriter } -func (c *Client) Connect() (readChan <-chan []byte, writeChan chan<- []byte, res *http.Response, err error) { +func (c *Client) Connect() (readChan <-chan []byte, writeChan chan<- []byte, err error) { var ( conn server.Conn + res *http.Response ) if c.stopChan != nil { - return nil, nil, nil, fmt.Errorf(c.clientMessage("already running. Stop it before starting it again")) + return nil, nil, fmt.Errorf(c.clientMessage("already running. Stop it before starting it again")) } err = c.Validate() if nil != err { - return nil, nil, nil, err + return nil, nil, err } conn, res, err = c.connect() if nil != err { - return nil, nil, nil, err + return nil, nil, err + } + resH := c.ResponseHandler + if nil != resH { + resH(res) } c.readChan = make(chan []byte, 256) @@ -98,7 +105,7 @@ func (c *Client) Connect() (readChan <-chan []byte, writeChan chan<- []byte, res go c.handleReconnect() go c.crw.HandleConnection(conn) - return c.readChan, c.writeChan, res, nil + return c.readChan, c.writeChan, nil } func (c *Client) Disconnect() error { @@ -140,8 +147,13 @@ RC_LOOP: for indexI := 0; indexI < c.ReconnectTryTime; indexI++ { logging.Logger().Debugf("trying reconnect[%d]", indexI) - conn, _, err := c.connect() + conn, res, err := c.connect() if nil == err { + resH := c.ResponseHandler + if nil != resH { + resH(res) + } + logging.Logger().Debugf("reconnected") c.reconnectedChan <- conn continue RC_LOOP @@ -153,7 +165,7 @@ RC_LOOP: } func (c *Client) connect() (server.Conn, *http.Response, error) { - conn, res, err := c.Dial() + conn, res, err := c.dial() if nil != err { return nil, nil, err } @@ -165,7 +177,7 @@ func (c *Client) connect() (server.Conn, *http.Response, error) { return conn, res, nil } -func (c *Client) Dial() (server.Conn, *http.Response, error) { +func (c *Client) dial() (server.Conn, *http.Response, error) { var ( err error challengeKey string diff --git a/net/client.go b/net/client.go index c1cf94b..2fed7cc 100644 --- a/net/client.go +++ b/net/client.go @@ -125,7 +125,7 @@ RC_LOOP: } func (c *Client) connect() (server.Conn, error) { - netConn, err := c.Dial() + netConn, err := c.dial() if nil != err { return nil, err } @@ -138,7 +138,7 @@ func (c *Client) connect() (server.Conn, error) { return conn, nil } -func (c *Client) Dial() (net.Conn, error) { +func (c *Client) dial() (net.Conn, error) { if err := c.Validate(); nil != err { return nil, err }