package net import ( "crypto/tls" "fmt" "net" "sync" "time" "git.loafle.net/commons/logging-go" "git.loafle.net/commons/server-go" ) type Client struct { server.ClientConnHandlers server.ReadWriteHandlers Name string Network string Address string LocalAddress net.Addr stopChan chan struct{} stopWg sync.WaitGroup readChan chan []byte writeChan chan []byte disconnectedChan chan struct{} reconnectedChan chan *server.Conn crw server.ClientReadWriter } func (c *Client) Connect() (readChan <-chan []byte, writeChan chan<- []byte, err error) { var ( conn *server.Conn ) if c.stopChan != nil { return nil, nil, fmt.Errorf(c.clientMessage("already running. Stop it before starting it again")) } err = c.Validate() if nil != err { return nil, nil, err } conn, err = c.connect() if nil != err { return nil, nil, err } c.readChan = make(chan []byte, 256) c.writeChan = make(chan []byte, 256) c.disconnectedChan = make(chan struct{}) c.reconnectedChan = make(chan *server.Conn) c.stopChan = make(chan struct{}) c.crw.ReadwriteHandler = c c.crw.ReadChan = c.readChan c.crw.WriteChan = c.writeChan c.crw.ClientStopChan = c.stopChan c.crw.ClientStopWg = &c.stopWg c.crw.DisconnectedChan = c.disconnectedChan c.crw.ReconnectedChan = c.reconnectedChan c.stopWg.Add(2) go c.handleReconnect() go c.crw.HandleConnection(conn) return c.readChan, c.writeChan, nil } func (c *Client) Disconnect() error { if c.stopChan == nil { return fmt.Errorf(c.clientMessage("must be started before stopping it")) } close(c.stopChan) c.stopWg.Wait() c.stopChan = nil return nil } func (c *Client) clientMessage(msg string) string { return fmt.Sprintf("Client[%s]: %s", c.Name, msg) } func (c *Client) handleReconnect() { defer func() { c.stopWg.Done() }() RC_LOOP: for { select { case <-c.disconnectedChan: case <-c.stopChan: return } if 0 >= c.ReconnectTryTime { c.reconnectedChan <- nil continue RC_LOOP } logging.Logger().Debugf("connection lost") for indexI := 0; indexI < c.ReconnectTryTime; indexI++ { logging.Logger().Debugf("trying reconnect[%d]", indexI) conn, err := c.connect() if nil == err { logging.Logger().Debugf("reconnected") c.reconnectedChan <- conn continue RC_LOOP } time.Sleep(c.ReconnectInterval) } logging.Logger().Debugf("reconnecting has been failed") } } func (c *Client) connect() (*server.Conn, error) { netConn, err := c.Dial() if nil != err { return nil, err } conn := server.NewConn(netConn, false, c.ReadBufferSize, c.WriteBufferSize) conn.SetCloseHandler(func(code int, text string) error { logging.Logger().Debugf("close") return nil }) return conn, nil } func (c *Client) Dial() (net.Conn, error) { if err := c.Validate(); nil != err { return nil, err } var deadline time.Time if 0 != c.HandshakeTimeout { deadline = time.Now().Add(c.HandshakeTimeout) } d := &net.Dialer{ KeepAlive: c.KeepAlive, Deadline: deadline, LocalAddr: c.LocalAddress, } conn, err := d.Dial(c.Network, c.Address) if nil != err { return nil, err } if nil != c.TLSConfig { cfg := c.TLSConfig.Clone() tlsConn := tls.Client(conn, cfg) if err := tlsConn.Handshake(); err != nil { tlsConn.Close() return nil, err } if !cfg.InsecureSkipVerify { if err := tlsConn.VerifyHostname(cfg.ServerName); err != nil { return nil, err } } conn = tlsConn } return conn, nil } func (c *Client) Validate() error { if err := c.ClientConnHandlers.Validate(); nil != err { return err } if err := c.ReadWriteHandlers.Validate(); nil != err { return err } if "" == c.Name { c.Name = "Client" } if "" == c.Network { return fmt.Errorf("Client: Network is not valid") } if "" == c.Address { return fmt.Errorf("Client: Address is not valid") } return nil }