package websocket import ( "bufio" "bytes" "crypto/tls" "encoding/base64" "errors" "fmt" "io" "io/ioutil" "net" "net/http" "net/url" "strings" "sync" "time" "git.loafle.net/commons/logging-go" "git.loafle.net/commons/server-go" ) var errMalformedURL = errors.New("malformed ws or wss URL") type Client struct { server.ClientConnHandlers server.ReadWriteHandlers Name string URL string RequestHeader http.Header Subprotocols []string // Jar specifies the cookie jar. // If Jar is nil, cookies are not sent in requests and ignored // in responses. CookieJar http.CookieJar // NetDial specifies the dial function for creating TCP connections. If // NetDial is nil, net.Dial is used. NetDial func(network, addr string) (net.Conn, error) // Proxy specifies a function to return a proxy for a given // Request. If the function returns a non-nil error, the // request is aborted with the provided error. // If Proxy is nil or returns a nil *URL, no proxy is used. Proxy func(*http.Request) (*url.URL, error) serverURL *url.URL stopChan chan struct{} stopWg sync.WaitGroup readChan chan []byte writeChan chan []byte disconnectedChan chan struct{} reconnectedChan chan *server.Conn crw server.ClientReadWriter } func (c *Client) Connect() (readChan <-chan []byte, writeChan chan<- []byte, res *http.Response, err error) { var ( conn *server.Conn ) if c.stopChan != nil { return nil, nil, nil, fmt.Errorf(c.clientMessage("already running. Stop it before starting it again")) } err = c.Validate() if nil != err { return nil, nil, nil, err } conn, res, err = c.connect() if nil != err { return nil, nil, nil, err } c.readChan = make(chan []byte, 256) c.writeChan = make(chan []byte, 256) c.disconnectedChan = make(chan struct{}) c.reconnectedChan = make(chan *server.Conn) c.stopChan = make(chan struct{}) c.crw.ReadwriteHandler = c c.crw.ReadChan = c.readChan c.crw.WriteChan = c.writeChan c.crw.ClientStopChan = c.stopChan c.crw.ClientStopWg = &c.stopWg c.crw.DisconnectedChan = c.disconnectedChan c.crw.ReconnectedChan = c.reconnectedChan c.stopWg.Add(2) go c.handleReconnect() go c.crw.HandleConnection(conn) return c.readChan, c.writeChan, res, nil } func (c *Client) Disconnect() error { if c.stopChan == nil { return fmt.Errorf(c.clientMessage("must be started before stopping it")) } close(c.stopChan) c.stopWg.Wait() c.stopChan = nil return nil } func (c *Client) clientMessage(msg string) string { return fmt.Sprintf("Client[%s]: %s", c.Name, msg) } func (c *Client) handleReconnect() { defer func() { c.stopWg.Done() }() RC_LOOP: for { select { case <-c.disconnectedChan: case <-c.stopChan: return } if 0 >= c.ReconnectTryTime { c.reconnectedChan <- nil continue RC_LOOP } for indexI := 0; indexI < c.ReconnectTryTime; indexI++ { conn, _, err := c.connect() if nil == err { c.reconnectedChan <- conn continue RC_LOOP } time.Sleep(c.ReconnectInterval) } } } func (c *Client) connect() (*server.Conn, *http.Response, error) { conn, res, err := c.Dial() if nil != err { return nil, nil, err } conn.SetCloseHandler(func(code int, text string) error { logging.Logger().Debugf("close") return nil }) return conn, res, nil } func (c *Client) Dial() (*server.Conn, *http.Response, error) { var ( err error challengeKey string netConn net.Conn ) if err = c.Validate(); nil != err { return nil, nil, err } challengeKey, err = generateChallengeKey() if err != nil { return nil, nil, err } req := &http.Request{ Method: "GET", URL: c.serverURL, Proto: "HTTP/1.1", ProtoMajor: 1, ProtoMinor: 1, Header: make(http.Header), Host: c.serverURL.Host, } // Set the cookies present in the cookie jar of the dialer if nil != c.CookieJar { for _, cookie := range c.CookieJar.Cookies(c.serverURL) { req.AddCookie(cookie) } } // Set the request headers using the capitalization for names and values in // RFC examples. Although the capitalization shouldn't matter, there are // servers that depend on it. The Header.Set method is not used because the // method canonicalizes the header names. req.Header["Upgrade"] = []string{"websocket"} req.Header["Connection"] = []string{"Upgrade"} req.Header["Sec-WebSocket-Key"] = []string{challengeKey} req.Header["Sec-WebSocket-Version"] = []string{"13"} if len(c.Subprotocols) > 0 { req.Header["Sec-WebSocket-Protocol"] = []string{strings.Join(c.Subprotocols, ", ")} } for k, vs := range c.RequestHeader { switch { case k == "Host": if len(vs) > 0 { req.Host = vs[0] } case k == "Upgrade" || k == "Connection" || k == "Sec-Websocket-Key" || k == "Sec-Websocket-Version" || k == "Sec-Websocket-Extensions" || (k == "Sec-Websocket-Protocol" && len(c.Subprotocols) > 0): return nil, nil, errors.New("websocket: duplicate header not allowed: " + k) default: req.Header[k] = vs } } if c.EnableCompression { req.Header.Set("Sec-Websocket-Extensions", "permessage-deflate; server_no_context_takeover; client_no_context_takeover") } hostPort, hostNoPort := hostPortNoPort(c.serverURL) var proxyURL *url.URL // Check wether the proxy method has been configured if nil != c.Proxy { proxyURL, err = c.Proxy(req) if err != nil { return nil, nil, err } } var targetHostPort string if proxyURL != nil { targetHostPort, _ = hostPortNoPort(proxyURL) } else { targetHostPort = hostPort } var deadline time.Time if 0 != c.HandshakeTimeout { deadline = time.Now().Add(c.HandshakeTimeout) } netDial := c.NetDial if netDial == nil { netDialer := &net.Dialer{Deadline: deadline} netDial = netDialer.Dial } netConn, err = netDial("tcp", targetHostPort) if err != nil { return nil, nil, err } defer func() { if nil != netConn { netConn.Close() } }() err = netConn.SetDeadline(deadline) if nil != err { return nil, nil, err } if nil != proxyURL { connectHeader := make(http.Header) if user := proxyURL.User; nil != user { proxyUser := user.Username() if proxyPassword, passwordSet := user.Password(); passwordSet { credential := base64.StdEncoding.EncodeToString([]byte(proxyUser + ":" + proxyPassword)) connectHeader.Set("Proxy-Authorization", "Basic "+credential) } } connectReq := &http.Request{ Method: "CONNECT", URL: &url.URL{Opaque: hostPort}, Host: hostPort, Header: connectHeader, } connectReq.Write(netConn) // Read response. // Okay to use and discard buffered reader here, because // TLS server will not speak until spoken to. br := bufio.NewReader(netConn) resp, err := http.ReadResponse(br, connectReq) if err != nil { return nil, nil, err } if resp.StatusCode != 200 { f := strings.SplitN(resp.Status, " ", 2) return nil, nil, errors.New(f[1]) } } if "https" == c.serverURL.Scheme { cfg := cloneTLSConfig(c.TLSConfig) if cfg.ServerName == "" { cfg.ServerName = hostNoPort } tlsConn := tls.Client(netConn, cfg) netConn = tlsConn if err := tlsConn.Handshake(); err != nil { return nil, nil, err } if !cfg.InsecureSkipVerify { if err := tlsConn.VerifyHostname(cfg.ServerName); err != nil { return nil, nil, err } } } conn := server.NewConn(netConn, false, c.ReadBufferSize, c.WriteBufferSize) if err := req.Write(netConn); err != nil { return nil, nil, err } resp, err := http.ReadResponse(conn.BuffReader, req) if err != nil { return nil, nil, err } if nil != c.CookieJar { if rc := resp.Cookies(); len(rc) > 0 { c.CookieJar.SetCookies(c.serverURL, rc) } } if resp.StatusCode != 101 || !strings.EqualFold(resp.Header.Get("Upgrade"), "websocket") || !strings.EqualFold(resp.Header.Get("Connection"), "upgrade") || resp.Header.Get("Sec-Websocket-Accept") != computeAcceptKey(challengeKey) { // Before closing the network connection on return from this // function, slurp up some of the response to aid application // debugging. buf := make([]byte, 1024) n, _ := io.ReadFull(resp.Body, buf) resp.Body = ioutil.NopCloser(bytes.NewReader(buf[:n])) return nil, resp, server.ErrBadHandshake } for _, ext := range httpParseExtensions(resp.Header) { if ext[""] != "permessage-deflate" { continue } _, snct := ext["server_no_context_takeover"] _, cnct := ext["client_no_context_takeover"] if !snct || !cnct { return nil, resp, server.ErrInvalidCompression } conn.NewCompressionWriter = server.CompressNoContextTakeover conn.NewDecompressionReader = server.DecompressNoContextTakeover break } resp.Body = ioutil.NopCloser(bytes.NewReader([]byte{})) conn.Subprotocol = resp.Header.Get("Sec-Websocket-Protocol") netConn.SetDeadline(time.Time{}) netConn = nil // to avoid close in defer. return conn, resp, nil } func (c *Client) Validate() error { if err := c.ClientConnHandlers.Validate(); nil != err { return err } if err := c.ReadWriteHandlers.Validate(); nil != err { return err } if "" == c.Name { c.Name = "Client" } if "" == c.URL { return fmt.Errorf("Client: URL is not valid") } u, err := parseURL(c.URL) if nil != err { return err } switch u.Scheme { case "ws": u.Scheme = "http" case "wss": u.Scheme = "https" default: return errMalformedURL } if nil != u.User { // User name and password are not allowed in websocket URIs. return errMalformedURL } c.serverURL = u if nil == c.Proxy { c.Proxy = http.ProxyFromEnvironment } return nil } // parseURL parses the URL. // // This function is a replacement for the standard library url.Parse function. // In Go 1.4 and earlier, url.Parse loses information from the path. func parseURL(s string) (*url.URL, error) { // From the RFC: // // ws-URI = "ws:" "//" host [ ":" port ] path [ "?" query ] // wss-URI = "wss:" "//" host [ ":" port ] path [ "?" query ] var u url.URL switch { case strings.HasPrefix(s, "ws://"): u.Scheme = "ws" s = s[len("ws://"):] case strings.HasPrefix(s, "wss://"): u.Scheme = "wss" s = s[len("wss://"):] default: return nil, errMalformedURL } if i := strings.Index(s, "?"); i >= 0 { u.RawQuery = s[i+1:] s = s[:i] } if i := strings.Index(s, "/"); i >= 0 { u.Opaque = s[i:] s = s[:i] } else { u.Opaque = "/" } u.Host = s if strings.Contains(u.Host, "@") { // Don't bother parsing user information because user information is // not allowed in websocket URIs. return nil, errMalformedURL } return &u, nil } func hostPortNoPort(u *url.URL) (hostPort, hostNoPort string) { hostPort = u.Host hostNoPort = u.Host if i := strings.LastIndex(u.Host, ":"); i > strings.LastIndex(u.Host, "]") { hostNoPort = hostNoPort[:i] } else { switch u.Scheme { case "wss": hostPort += ":443" case "https": hostPort += ":443" default: hostPort += ":80" } } return hostPort, hostNoPort } func cloneTLSConfig(cfg *tls.Config) *tls.Config { if cfg == nil { return &tls.Config{} } return cfg.Clone() }