package client

import (
	"bufio"
	"bytes"
	"crypto/tls"
	"encoding/base64"
	"errors"
	"fmt"
	"io"
	"io/ioutil"
	"net"
	"net/http"
	"net/url"
	"strings"
	"sync"
	"sync/atomic"
	"time"

	"git.loafle.net/commons/logging-go"
	"git.loafle.net/commons/server-go/socket"
	"git.loafle.net/commons/server-go/socket/client"
	"git.loafle.net/commons/server-go/socket/web"
)

// errMalformedURL is returned when the configured URL is not a valid
// ws:// or wss:// URL.
var errMalformedURL = errors.New("malformed ws or wss URL")

// Connectors is a WebSocket client connector. It dials the server given by
// URL, performs the WebSocket opening handshake, hands the resulting
// connection to a socket.ClientReadWriter and retries the connection when the
// read/writer reports that it was lost.
type Connectors struct {
	client.Connectors
	socket.ClientConnHandlers
	socket.ReadWriteHandlers

	// URL is the ws:// or wss:// address of the server.
	URL string `json:"url,omitempty"`
	// RequestHeader, if non-nil, returns extra headers sent with the
	// handshake request. A "Host" entry overrides the request host; the
	// WebSocket protocol headers themselves are rejected as duplicates.
	RequestHeader func() http.Header `json:"-"`
	// Subprotocols is sent to the server in the Sec-WebSocket-Protocol
	// header when non-empty.
	Subprotocols []string `json:"subprotocols,omitempty"`

	// CookieJar specifies the cookie jar.
	// If CookieJar is nil, cookies are not sent in requests and ignored
	// in responses.
	CookieJar http.CookieJar `json:"-"`

	// ResponseHandler, if non-nil, is called with the handshake response
	// after every successful connection and reconnection.
	ResponseHandler func(*http.Response) `json:"-"`

	// NetDial specifies the dial function for creating TCP connections. If
	// NetDial is nil, a net.Dialer whose deadline is derived from the
	// handshake timeout is used.
	NetDial func(network, addr string) (net.Conn, error) `json:"-"`

	// Proxy specifies a function to return a proxy for a given
	// Request. If the function returns a non-nil error, the
	// request is aborted with the provided error.
	// If Proxy returns a nil *URL, no proxy is used. Validate replaces a nil
	// Proxy with http.ProxyFromEnvironment.
	Proxy func(*http.Request) (*url.URL, error) `json:"-"`

	// serverURL is the parsed URL (scheme mapped to http/https), set by Validate.
	serverURL *url.URL

	stopChan chan struct{}
	stopWg   sync.WaitGroup

	readChan         chan socket.SocketMessage
	writeChan        chan socket.SocketMessage
	disconnectedChan chan struct{}
	reconnectedChan  chan socket.Conn

	crw socket.ClientReadWriter

	// validated holds true once Validate has completed successfully.
	validated atomic.Value
}

// Connect dials the server, wires the message channels into the
// ClientReadWriter and starts the connection-handling and reconnect
// goroutines. It returns the read and write message channels (buffer size
// 256); both are closed by the connector when the connection is lost and
// cannot be re-established.
//
// Connect fails if the connector is already connected (until Disconnect is
// called) or if the initial dial/handshake fails.
func (c *Connectors) Connect() (readChan <-chan socket.SocketMessage, writeChan chan<- socket.SocketMessage, err error) {
	var (
		conn socket.Conn
		res  *http.Response
	)

	if c.stopChan != nil {
		return nil, nil, fmt.Errorf("%s already connected", c.logHeader())
	}

	conn, res, err = c.connect()
	if nil != err {
		return nil, nil, err
	}

	resH := c.ResponseHandler
	if nil != resH {
		resH(res)
	}

	c.readChan = make(chan socket.SocketMessage, 256)
	c.writeChan = make(chan socket.SocketMessage, 256)
	c.disconnectedChan = make(chan struct{})
	c.reconnectedChan = make(chan socket.Conn)
	c.stopChan = make(chan struct{})

	c.crw.ReadwriteHandler = c
	c.crw.ReadChan = c.readChan
	c.crw.WriteChan = c.writeChan
	c.crw.ClientStopChan = c.stopChan
	c.crw.ClientStopWg = &c.stopWg
	c.crw.DisconnectedChan = c.disconnectedChan
	c.crw.ReconnectedChan = c.reconnectedChan

	c.stopWg.Add(1)
	go c.handleReconnect()
	c.stopWg.Add(1)
	go c.crw.HandleConnection(conn)

	return c.readChan, c.writeChan, nil
}

// Disconnect closes the stop channel, waits for the goroutines started by
// Connect to finish and resets the connector so it can connect again.
// It returns an error if the connector is not connected.
func (c *Connectors) Disconnect() error {
	if c.stopChan == nil {
		return fmt.Errorf("%s must be connected before disconnection it", c.logHeader())
	}
	close(c.stopChan)
	c.stopWg.Wait()

	c.stopChan = nil

	return nil
}

// logHeader returns the prefix used in log and error messages.
func (c *Connectors) logHeader() string {
	return fmt.Sprintf("Connector[%s]:", c.Name)
}

// onDisconnected tears down after the connection is permanently lost: it
// closes the read and write channels, sends nil on reconnectedChan
// (presumably telling the read/writer no new connection follows — confirm
// against socket.ClientReadWriter) and runs OnDisconnected, if set, in a new
// goroutine.
func (c *Connectors) onDisconnected() {
	close(c.readChan)
	close(c.writeChan)

	c.reconnectedChan <- nil

	onDisconnected := c.OnDisconnected
	if nil != onDisconnected {
		go func() {
			onDisconnected(c)
		}()
	}
}

// handleReconnect runs until stopChan is closed or reconnection gives up.
// Each signal on disconnectedChan starts up to GetReconnectTryTime() dial
// attempts separated by GetReconnectInterval(); a new connection is handed
// back on reconnectedChan. When reconnection is disabled (try count <= 0) or
// every attempt fails, onDisconnected is called and the goroutine exits.
//
// NOTE(review): neither the sleep between attempts nor the sends on
// reconnectedChan observe stopChan, so Disconnect may block until the retry
// loop finishes.
func (c *Connectors) handleReconnect() {
	defer func() {
		c.stopWg.Done()
	}()

RC_LOOP:
	for {
		select {
		case <-c.disconnectedChan:
		case <-c.stopChan:
			return
		}

		if 0 >= c.GetReconnectTryTime() {
			c.onDisconnected()
			return
		}

		logging.Logger().Debugf("%s connection lost", c.logHeader())

		for indexI := 0; indexI < c.GetReconnectTryTime(); indexI++ {
			logging.Logger().Debugf("%s trying reconnect[%d]", c.logHeader(), indexI)

			conn, res, err := c.connect()
			if nil == err {
				resH := c.ResponseHandler
				if nil != resH {
					resH(res)
				}

				logging.Logger().Debugf("%s reconnected", c.logHeader())
				c.reconnectedChan <- conn
				continue RC_LOOP
			}
			time.Sleep(c.GetReconnectInterval())
		}
		logging.Logger().Debugf("%s reconnecting has been failed", c.logHeader())

		c.onDisconnected()
		return
	}
}

// connect dials the server and installs a close handler that only logs.
func (c *Connectors) connect() (socket.Conn, *http.Response, error) {
	conn, res, err := c.dial()
	if nil != err {
		return nil, nil, err
	}

	conn.SetCloseHandler(func(code int, text string) error {
		logging.Logger().Debugf("%s close", c.logHeader())
		return nil
	})

	return conn, res, nil
}

// dial opens the TCP connection (optionally through an HTTP CONNECT proxy),
// performs the TLS handshake for https, sends the WebSocket upgrade request
// and validates the server's response. On success the returned response has
// an empty body; on a bad handshake it carries up to 1024 bytes of the
// original body for debugging. The underlying net.Conn is closed on any
// error.
func (c *Connectors) dial() (socket.Conn, *http.Response, error) {
	// serverURL is only set by Validate; fail cleanly instead of panicking on
	// a nil dereference below.
	if nil == c.serverURL {
		return nil, nil, fmt.Errorf("%s must be validated before connecting", c.logHeader())
	}

	var (
		err          error
		challengeKey string
		netConn      net.Conn
	)

	challengeKey, err = web.GenerateChallengeKey()
	if err != nil {
		return nil, nil, err
	}

	req := &http.Request{
		Method:     "GET",
		URL:        c.serverURL,
		Proto:      "HTTP/1.1",
		ProtoMajor: 1,
		ProtoMinor: 1,
		Header:     make(http.Header),
		Host:       c.serverURL.Host,
	}

	cookieJar := c.CookieJar
	// Set the cookies present in the cookie jar of the dialer.
	if nil != cookieJar {
		for _, cookie := range cookieJar.Cookies(c.serverURL) {
			req.AddCookie(cookie)
		}
	}

	// Set the request headers using the capitalization for names and values in
	// RFC examples. Although the capitalization shouldn't matter, there are
	// servers that depend on it. The Header.Set method is not used because the
	// method canonicalizes the header names.
	req.Header["Upgrade"] = []string{"websocket"}
	req.Header["Connection"] = []string{"Upgrade"}
	req.Header["Sec-WebSocket-Key"] = []string{challengeKey}
	req.Header["Sec-WebSocket-Version"] = []string{"13"}
	subprotocols := c.Subprotocols
	if len(subprotocols) > 0 {
		req.Header["Sec-WebSocket-Protocol"] = []string{strings.Join(subprotocols, ", ")}
	}

	// RequestHeader is optional; calling a nil func value would panic.
	if requestHeader := c.RequestHeader; nil != requestHeader {
		for k, vs := range requestHeader() {
			switch {
			case k == "Host":
				if len(vs) > 0 {
					req.Host = vs[0]
				}
			case k == "Upgrade" ||
				k == "Connection" ||
				k == "Sec-Websocket-Key" ||
				k == "Sec-Websocket-Version" ||
				k == "Sec-Websocket-Extensions" ||
				(k == "Sec-Websocket-Protocol" && len(subprotocols) > 0):
				return nil, nil, fmt.Errorf("%s duplicate header not allowed: %s", c.logHeader(), k)
			default:
				req.Header[k] = vs
			}
		}
	}

	if c.IsEnableCompression() {
		req.Header.Set("Sec-Websocket-Extensions", "permessage-deflate; server_no_context_takeover; client_no_context_takeover")
	}

	hostPort, hostNoPort := hostPortNoPort(c.serverURL)

	var proxyURL *url.URL
	// Check whether a proxy function has been configured.
	proxy := c.Proxy
	if nil != proxy {
		proxyURL, err = proxy(req)
		if err != nil {
			return nil, nil, err
		}
	}

	var targetHostPort string
	if proxyURL != nil {
		targetHostPort, _ = hostPortNoPort(proxyURL)
	} else {
		targetHostPort = hostPort
	}

	// A zero deadline means no timeout.
	var deadline time.Time
	handshakeTimeout := c.GetHandshakeTimeout()
	if 0 != handshakeTimeout {
		deadline = time.Now().Add(handshakeTimeout)
	}

	netDial := c.NetDial
	if netDial == nil {
		netDialer := &net.Dialer{Deadline: deadline}
		netDial = netDialer.Dial
	}

	netConn, err = netDial("tcp", targetHostPort)
	if err != nil {
		return nil, nil, err
	}

	// Close the connection on every error path; it is set to nil just before
	// the successful return.
	defer func() {
		if nil != netConn {
			netConn.Close()
		}
	}()

	err = netConn.SetDeadline(deadline)
	if nil != err {
		return nil, nil, err
	}

	if nil != proxyURL {
		connectHeader := make(http.Header)
		if user := proxyURL.User; nil != user {
			proxyUser := user.Username()
			if proxyPassword, passwordSet := user.Password(); passwordSet {
				credential := base64.StdEncoding.EncodeToString([]byte(proxyUser + ":" + proxyPassword))
				connectHeader.Set("Proxy-Authorization", "Basic "+credential)
			}
		}
		connectReq := &http.Request{
			Method: "CONNECT",
			URL:    &url.URL{Opaque: hostPort},
			Host:   hostPort,
			Header: connectHeader,
		}

		if err := connectReq.Write(netConn); err != nil {
			return nil, nil, err
		}

		// Read response.
		// Okay to use and discard buffered reader here, because
		// TLS server will not speak until spoken to.
		br := bufio.NewReader(netConn)
		resp, err := http.ReadResponse(br, connectReq)
		if err != nil {
			return nil, nil, err
		}

		if resp.StatusCode != 200 {
			// resp.Status is "<code> <reason>", but a proxy may omit the
			// reason phrase; avoid indexing past the end of the split.
			f := strings.SplitN(resp.Status, " ", 2)
			if len(f) < 2 {
				return nil, nil, errors.New(resp.Status)
			}
			return nil, nil, errors.New(f[1])
		}
	}

	if "https" == c.serverURL.Scheme {
		cfg := cloneTLSConfig(c.GetTLSConfig())
		if cfg.ServerName == "" {
			cfg.ServerName = hostNoPort
		}
		tlsConn := tls.Client(netConn, cfg)
		netConn = tlsConn
		if err := tlsConn.Handshake(); err != nil {
			return nil, nil, err
		}
		if !cfg.InsecureSkipVerify {
			if err := tlsConn.VerifyHostname(cfg.ServerName); err != nil {
				return nil, nil, err
			}
		}
	}

	conn := socket.NewConn(netConn, false, c.GetReadBufferSize(), c.GetWriteBufferSize())

	if err := req.Write(netConn); err != nil {
		return nil, nil, err
	}

	resp, err := http.ReadResponse(conn.BuffReader, req)
	if err != nil {
		return nil, nil, err
	}

	if nil != cookieJar {
		if rc := resp.Cookies(); len(rc) > 0 {
			cookieJar.SetCookies(c.serverURL, rc)
		}
	}

	if resp.StatusCode != 101 ||
		!strings.EqualFold(resp.Header.Get("Upgrade"), "websocket") ||
		!strings.EqualFold(resp.Header.Get("Connection"), "upgrade") ||
		resp.Header.Get("Sec-Websocket-Accept") != web.ComputeAcceptKey(challengeKey) {
		// Before closing the network connection on return from this
		// function, slurp up some of the response to aid application
		// debugging.
		buf := make([]byte, 1024)
		n, _ := io.ReadFull(resp.Body, buf)
		resp.Body = ioutil.NopCloser(bytes.NewReader(buf[:n]))
		return nil, resp, socket.ErrBadHandshake
	}

	// Only the no-context-takeover variant of permessage-deflate is supported.
	for _, ext := range web.HttpParseExtensions(resp.Header) {
		if ext[""] != "permessage-deflate" {
			continue
		}
		_, snct := ext["server_no_context_takeover"]
		_, cnct := ext["client_no_context_takeover"]
		if !snct || !cnct {
			return nil, resp, socket.ErrInvalidCompression
		}
		conn.SetNewCompressionWriter(socket.CompressNoContextTakeover)
		conn.SetNewDecompressionReader(socket.DecompressNoContextTakeover)
		break
	}

	resp.Body = ioutil.NopCloser(bytes.NewReader([]byte{}))
	conn.SetSubprotocol(resp.Header.Get("Sec-Websocket-Protocol"))

	netConn.SetDeadline(time.Time{})
	netConn = nil // to avoid close in defer.
	return conn, resp, nil
}

// parseURL parses the URL.
//
// This function is a replacement for the standard library url.Parse function.
// In Go 1.4 and earlier, url.Parse loses information from the path.
func parseURL(s string) (*url.URL, error) {
	// From the RFC:
	//
	// ws-URI = "ws:" "//" host [ ":" port ] path [ "?" query ]
	// wss-URI = "wss:" "//" host [ ":" port ] path [ "?" query ]
	var u url.URL
	switch {
	case strings.HasPrefix(s, "ws://"):
		u.Scheme = "ws"
		s = s[len("ws://"):]
	case strings.HasPrefix(s, "wss://"):
		u.Scheme = "wss"
		s = s[len("wss://"):]
	default:
		return nil, errMalformedURL
	}

	if i := strings.Index(s, "?"); i >= 0 {
		u.RawQuery = s[i+1:]
		s = s[:i]
	}

	if i := strings.Index(s, "/"); i >= 0 {
		u.Opaque = s[i:]
		s = s[:i]
	} else {
		u.Opaque = "/"
	}

	u.Host = s

	if strings.Contains(u.Host, "@") {
		// Don't bother parsing user information because user information is
		// not allowed in websocket URIs.
		return nil, errMalformedURL
	}

	return &u, nil
}

// hostPortNoPort returns u.Host with a port (the scheme's default port is
// appended when none is present: 443 for wss/https, 80 otherwise) and u.Host
// with any explicit port removed.
func hostPortNoPort(u *url.URL) (hostPort, hostNoPort string) {
	hostPort = u.Host
	hostNoPort = u.Host
	// A colon after the last ']' separates the port (handles IPv6 literals).
	if i := strings.LastIndex(u.Host, ":"); i > strings.LastIndex(u.Host, "]") {
		hostNoPort = hostNoPort[:i]
	} else {
		switch u.Scheme {
		case "wss":
			hostPort += ":443"
		case "https":
			hostPort += ":443"
		default:
			hostPort += ":80"
		}
	}
	return hostPort, hostNoPort
}

// cloneTLSConfig returns a copy of cfg, or an empty config when cfg is nil.
func cloneTLSConfig(cfg *tls.Config) *tls.Config {
	if cfg == nil {
		return &tls.Config{}
	}
	return cfg.Clone()
}

// Clone returns a new, unconnected Connectors with the same configuration.
// The embedded handler sets are cloned via their own Clone methods; function
// fields, the cookie jar and the Subprotocols slice are shared.
func (c *Connectors) Clone() client.Connector {
	clone := &Connectors{
		Connectors:         *c.Connectors.Clone(),
		ClientConnHandlers: *c.ClientConnHandlers.Clone(),
		ReadWriteHandlers:  *c.ReadWriteHandlers.Clone(),
		URL:                c.URL,
		RequestHeader:      c.RequestHeader,
		Subprotocols:       c.Subprotocols,
		CookieJar:          c.CookieJar,
		ResponseHandler:    c.ResponseHandler,
		NetDial:            c.NetDial,
		Proxy:              c.Proxy,
		serverURL:          c.serverURL,
	}

	// An atomic.Value must not be copied after first use, so carry the
	// validation state over through Load/Store instead of a struct copy.
	if v := c.validated.Load(); nil != v {
		clone.validated.Store(v)
	}

	return clone
}

// Validate checks the embedded configurations and the URL, parses the URL
// into serverURL (mapping ws→http and wss→https) and defaults a nil Proxy to
// http.ProxyFromEnvironment. Once it has succeeded, later calls return nil
// immediately.
func (c *Connectors) Validate() error {
	if nil != c.validated.Load() {
		return nil
	}

	if err := c.Connectors.Validate(); nil != err {
		return err
	}
	if err := c.ClientConnHandlers.Validate(); nil != err {
		return err
	}
	if err := c.ReadWriteHandlers.Validate(); nil != err {
		return err
	}

	if "" == c.URL {
		return fmt.Errorf("URL is not valid")
	}

	u, err := parseURL(c.URL)
	if nil != err {
		return err
	}

	switch u.Scheme {
	case "ws":
		u.Scheme = "http"
	case "wss":
		u.Scheme = "https"
	default:
		return errMalformedURL
	}

	if nil != u.User {
		// User name and password are not allowed in websocket URIs.
		return errMalformedURL
	}

	c.serverURL = u

	if nil == c.Proxy {
		c.Proxy = http.ProxyFromEnvironment
	}

	// Record success only now, so a failed validation is reported again on
	// the next call instead of silently passing with serverURL unset.
	c.validated.Store(true)

	return nil
}