This commit is contained in:
crusader 2018-04-05 02:08:07 +09:00
parent 899fc53ded
commit ad158313e5
2 changed files with 22 additions and 10 deletions

View File

@ -37,6 +37,8 @@ type Client struct {
// in responses. // in responses.
CookieJar http.CookieJar CookieJar http.CookieJar
ResponseHandler func(*http.Response)
// NetDial specifies the dial function for creating TCP connections. If // NetDial specifies the dial function for creating TCP connections. If
// NetDial is nil, net.Dial is used. // NetDial is nil, net.Dial is used.
NetDial func(network, addr string) (net.Conn, error) NetDial func(network, addr string) (net.Conn, error)
@ -61,23 +63,28 @@ type Client struct {
crw server.ClientReadWriter crw server.ClientReadWriter
} }
func (c *Client) Connect() (readChan <-chan []byte, writeChan chan<- []byte, res *http.Response, err error) { func (c *Client) Connect() (readChan <-chan []byte, writeChan chan<- []byte, err error) {
var ( var (
conn server.Conn conn server.Conn
res *http.Response
) )
if c.stopChan != nil { if c.stopChan != nil {
return nil, nil, nil, fmt.Errorf(c.clientMessage("already running. Stop it before starting it again")) return nil, nil, fmt.Errorf(c.clientMessage("already running. Stop it before starting it again"))
} }
err = c.Validate() err = c.Validate()
if nil != err { if nil != err {
return nil, nil, nil, err return nil, nil, err
} }
conn, res, err = c.connect() conn, res, err = c.connect()
if nil != err { if nil != err {
return nil, nil, nil, err return nil, nil, err
}
resH := c.ResponseHandler
if nil != resH {
resH(res)
} }
c.readChan = make(chan []byte, 256) c.readChan = make(chan []byte, 256)
@ -98,7 +105,7 @@ func (c *Client) Connect() (readChan <-chan []byte, writeChan chan<- []byte, res
go c.handleReconnect() go c.handleReconnect()
go c.crw.HandleConnection(conn) go c.crw.HandleConnection(conn)
return c.readChan, c.writeChan, res, nil return c.readChan, c.writeChan, nil
} }
func (c *Client) Disconnect() error { func (c *Client) Disconnect() error {
@ -140,8 +147,13 @@ RC_LOOP:
for indexI := 0; indexI < c.ReconnectTryTime; indexI++ { for indexI := 0; indexI < c.ReconnectTryTime; indexI++ {
logging.Logger().Debugf("trying reconnect[%d]", indexI) logging.Logger().Debugf("trying reconnect[%d]", indexI)
conn, _, err := c.connect() conn, res, err := c.connect()
if nil == err { if nil == err {
resH := c.ResponseHandler
if nil != resH {
resH(res)
}
logging.Logger().Debugf("reconnected") logging.Logger().Debugf("reconnected")
c.reconnectedChan <- conn c.reconnectedChan <- conn
continue RC_LOOP continue RC_LOOP
@ -153,7 +165,7 @@ RC_LOOP:
} }
func (c *Client) connect() (server.Conn, *http.Response, error) { func (c *Client) connect() (server.Conn, *http.Response, error) {
conn, res, err := c.Dial() conn, res, err := c.dial()
if nil != err { if nil != err {
return nil, nil, err return nil, nil, err
} }
@ -165,7 +177,7 @@ func (c *Client) connect() (server.Conn, *http.Response, error) {
return conn, res, nil return conn, res, nil
} }
func (c *Client) Dial() (server.Conn, *http.Response, error) { func (c *Client) dial() (server.Conn, *http.Response, error) {
var ( var (
err error err error
challengeKey string challengeKey string

View File

@ -125,7 +125,7 @@ RC_LOOP:
} }
func (c *Client) connect() (server.Conn, error) { func (c *Client) connect() (server.Conn, error) {
netConn, err := c.Dial() netConn, err := c.dial()
if nil != err { if nil != err {
return nil, err return nil, err
} }
@ -138,7 +138,7 @@ func (c *Client) connect() (server.Conn, error) {
return conn, nil return conn, nil
} }
func (c *Client) Dial() (net.Conn, error) { func (c *Client) dial() (net.Conn, error) {
if err := c.Validate(); nil != err { if err := c.Validate(); nil != err {
return nil, err return nil, err
} }