package net import ( "crypto/tls" "fmt" "net" "sync" "time" "git.loafle.net/commons/logging-go" "git.loafle.net/commons/server-go" ) type Client struct { Name string Network string Address string TLSConfig *tls.Config HandshakeTimeout time.Duration KeepAlive time.Duration LocalAddress net.Addr MaxMessageSize int64 // Per-connection buffer size for requests' reading. // This also limits the maximum header size. // // Increase this buffer if your clients send multi-KB RequestURIs // and/or multi-KB headers (for example, BIG cookies). // // Default buffer size is used if not set. ReadBufferSize int // Per-connection buffer size for responses' writing. // // Default buffer size is used if not set. WriteBufferSize int // Maximum duration for reading the full request (including body). // // This also limits the maximum duration for idle keep-alive // connections. // // By default request read timeout is unlimited. ReadTimeout time.Duration // Maximum duration for writing the full response (including body). // // By default response write timeout is unlimited. WriteTimeout time.Duration PongTimeout time.Duration PingTimeout time.Duration PingPeriod time.Duration EnableCompression bool stopChan chan struct{} stopWg sync.WaitGroup conn *server.Conn readChan chan []byte writeChan chan []byte } func (c *Client) Connect() (readChan <-chan []byte, writeChan chan<- []byte, err error) { var ( conn *server.Conn ) if c.stopChan != nil { return nil, nil, fmt.Errorf(c.clientMessage("already running. Stop it before starting it again")) } err = c.Validate() if nil != err { return nil, nil, err } conn, err = c.connect() if nil != err { return nil, nil, err } c.readChan = make(chan []byte, 256) c.writeChan = make(chan []byte, 256) c.stopChan = make(chan struct{}) c.stopWg.Add(1) go c.handleConnection(conn) return c.readChan, c.writeChan, nil } func (c *Client) Disconnect() error { if c.stopChan == nil { return fmt.Errorf(c.clientMessage("must be started before stopping it")) } close(c.stopChan) c.stopWg.Wait() c.stopChan = nil return nil } func (c *Client) clientMessage(msg string) string { return fmt.Sprintf("Client[%s]: %s", c.Name, msg) } func (c *Client) connect() (*server.Conn, error) { netConn, err := c.Dial() if nil != err { return nil, err } conn := server.NewConn(netConn, false, c.ReadBufferSize, c.WriteBufferSize) conn.SetCloseHandler(func(code int, text string) error { logging.Logger().Debugf("close") return nil }) return conn, nil } func (c *Client) handleConnection(conn *server.Conn) { defer func() { if nil != conn { conn.Close() } logging.Logger().Infof(c.clientMessage("disconnected")) c.stopWg.Done() }() logging.Logger().Infof(c.clientMessage("connected")) stopChan := make(chan struct{}) readerDoneChan := make(chan struct{}) writerDoneChan := make(chan struct{}) go handleClientRead(c, conn, stopChan, readerDoneChan) go handleClientWrite(c, conn, stopChan, writerDoneChan) select { case <-readerDoneChan: close(stopChan) <-writerDoneChan case <-writerDoneChan: close(stopChan) <-readerDoneChan case <-c.stopChan: close(stopChan) <-readerDoneChan <-writerDoneChan } } func handleClientRead(c *Client, conn *server.Conn, stopChan <-chan struct{}, doneChan chan<- struct{}) { defer func() { doneChan <- struct{}{} }() if 0 < c.MaxMessageSize { conn.SetReadLimit(c.MaxMessageSize) } if 0 < c.ReadTimeout { conn.SetReadDeadline(time.Now().Add(c.ReadTimeout)) } conn.SetPongHandler(func(string) error { conn.SetReadDeadline(time.Now().Add(c.PongTimeout)) return nil }) var ( message []byte err error ) for { readMessageChan := make(chan struct{}) go func() { _, message, err = conn.ReadMessage() close(readMessageChan) }() select { case <-stopChan: <-readMessageChan return case <-readMessageChan: } if nil != err { if server.IsUnexpectedCloseError(err, server.CloseGoingAway, server.CloseAbnormalClosure) { logging.Logger().Debugf(c.clientMessage(fmt.Sprintf("Read error %v", err))) } return } c.readChan <- message } } func handleClientWrite(c *Client, conn *server.Conn, stopChan <-chan struct{}, doneChan chan<- struct{}) { defer func() { doneChan <- struct{}{} }() ticker := time.NewTicker(c.PingPeriod) defer func() { ticker.Stop() }() for { select { case message, ok := <-c.writeChan: if 0 < c.WriteTimeout { conn.SetWriteDeadline(time.Now().Add(c.WriteTimeout)) } if !ok { conn.WriteMessage(server.CloseMessage, []byte{}) return } w, err := conn.NextWriter(server.TextMessage) if err != nil { return } w.Write(message) if err := w.Close(); nil != err { return } case <-ticker.C: conn.SetWriteDeadline(time.Now().Add(c.PingTimeout)) if err := conn.WriteMessage(server.PingMessage, nil); nil != err { return } case <-stopChan: return } } } func (c *Client) Dial() (net.Conn, error) { if err := c.Validate(); nil != err { return nil, err } var deadline time.Time if 0 != c.HandshakeTimeout { deadline = time.Now().Add(c.HandshakeTimeout) } d := &net.Dialer{ KeepAlive: c.KeepAlive, Deadline: deadline, LocalAddr: c.LocalAddress, } conn, err := d.Dial(c.Network, c.Address) if nil != err { return nil, err } if nil != c.TLSConfig { cfg := c.TLSConfig.Clone() tlsConn := tls.Client(conn, cfg) if err := tlsConn.Handshake(); err != nil { tlsConn.Close() return nil, err } if !cfg.InsecureSkipVerify { if err := tlsConn.VerifyHostname(cfg.ServerName); err != nil { return nil, err } } conn = tlsConn } return conn, nil } func (c *Client) Validate() error { if "" == c.Name { c.Name = "Client" } if "" == c.Network { return fmt.Errorf("Client: Network is not valid") } if "" == c.Address { return fmt.Errorf("Client: Address is not valid") } if c.HandshakeTimeout <= 0 { c.HandshakeTimeout = server.DefaultHandshakeTimeout } if c.MaxMessageSize <= 0 { c.MaxMessageSize = server.DefaultMaxMessageSize } if c.ReadBufferSize <= 0 { c.ReadBufferSize = server.DefaultReadBufferSize } if c.WriteBufferSize <= 0 { c.WriteBufferSize = server.DefaultWriteBufferSize } if c.ReadTimeout <= 0 { c.ReadTimeout = server.DefaultReadTimeout } if c.WriteTimeout <= 0 { c.WriteTimeout = server.DefaultWriteTimeout } if c.PongTimeout <= 0 { c.PongTimeout = server.DefaultPongTimeout } if c.PingTimeout <= 0 { c.PingTimeout = server.DefaultPingTimeout } if c.PingPeriod <= 0 { c.PingPeriod = server.DefaultPingPeriod } return nil }