package websocket

import (
	"bufio"
	"bytes"
	"crypto/tls"
	"encoding/base64"
	"errors"
	"fmt"
	"io"
	"io/ioutil"
	"net"
	"net/http"
	"net/url"
	"strings"
	"sync"
	"time"

	"git.loafle.net/commons/logging-go"
	"git.loafle.net/commons/server-go"
)

var errMalformedURL = errors.New("malformed ws or wss URL")

// Client is a websocket client. After Connect succeeds, messages read from the
// peer are delivered on the returned read channel, and messages sent on the
// returned write channel are written to the peer as text messages.
type Client struct {
	// Name identifies the client in log and error messages. Validate sets it
	// to "Client" when empty.
	Name string

	// URL is the server address. It must use the ws:// or wss:// scheme and
	// must not contain user information.
	URL string

	// RequestHeader holds extra headers sent with the handshake request. A
	// "Host" entry overrides the request host. Upgrade, Connection,
	// Sec-Websocket-Key, Sec-Websocket-Version, Sec-Websocket-Extensions (and
	// Sec-Websocket-Protocol when Subprotocols is set) are rejected, because
	// Dial sets those itself.
	RequestHeader http.Header

	// Subprotocols specifies the client's requested subprotocols.
	Subprotocols []string

	// CookieJar specifies the cookie jar.
	// If CookieJar is nil, cookies are not sent in requests and ignored
	// in responses.
	CookieJar http.CookieJar

	// NetDial specifies the dial function for creating TCP connections. If
	// NetDial is nil, a net.Dialer bounded by HandshakeTimeout is used.
	NetDial func(network, addr string) (net.Conn, error)

	// Proxy specifies a function to return a proxy for a given
	// Request. If the function returns a non-nil error, the
	// request is aborted with the provided error.
	// If Proxy returns a nil *URL, no proxy is used. Note that Validate
	// replaces a nil Proxy with http.ProxyFromEnvironment.
	Proxy func(*http.Request) (*url.URL, error)

	// TLSConfig is used for wss:// connections. It is cloned before use. When
	// its ServerName is empty, the host from URL is used.
	TLSConfig *tls.Config

	// HandshakeTimeout bounds dialing plus the opening handshake, including any
	// proxy CONNECT exchange and the TLS handshake. Validate applies
	// server.DefaultHandshakeTimeout when it is <= 0.
	HandshakeTimeout time.Duration

	// ReadBufferSize and WriteBufferSize specify I/O buffer sizes. Validate
	// applies server.DefaultReadBufferSize / server.DefaultWriteBufferSize when
	// they are <= 0. The I/O buffer sizes do not limit the size of the messages
	// that can be sent or received.
	ReadBufferSize, WriteBufferSize int

	// EnableCompression specifies if the client should attempt to negotiate
	// per message compression (RFC 7692). Setting this value to true does not
	// guarantee that compression will be supported. Currently only "no context
	// takeover" modes are supported.
	EnableCompression bool

	// MaxMessageSize is the maximum size for a message read from the peer. If a
	// message exceeds the limit, the connection sends a close frame to the peer
	// and returns ErrReadLimit to the application. Validate applies
	// server.DefaultMaxMessageSize when it is <= 0.
	MaxMessageSize int64

	// WriteTimeout is the write deadline applied before each data message is
	// written to the underlying network connection. After a write has timed
	// out, the websocket state is corrupt and all future writes will return an
	// error. Validate applies server.DefaultWriteTimeout when it is <= 0, so
	// writes always have a deadline.
	WriteTimeout time.Duration

	// ReadTimeout is the initial read deadline on the underlying network
	// connection. Each received pong then extends the deadline by PongTimeout.
	// After a read has timed out, the websocket connection state is corrupt and
	// all future reads will return an error. Validate applies
	// server.DefaultReadTimeout when it is <= 0.
	ReadTimeout time.Duration

	// PongTimeout is how far each received pong pushes out the read deadline.
	PongTimeout time.Duration

	// PingTimeout is the write deadline used when sending a ping.
	PingTimeout time.Duration

	// PingPeriod is the interval between pings sent to the peer.
	PingPeriod time.Duration

	serverURL *url.URL
	stopChan  chan struct{}
	stopWg    sync.WaitGroup

	conn      *server.Conn
	readChan  chan []byte
	writeChan chan []byte
}

// Connect validates the configuration, dials the server and starts the
// goroutines that move messages between the connection and the returned
// channels. Both channels are buffered with 256 slots.
//
// The read channel is never closed by the client. Closing the write channel
// makes the client send a close frame and shut the connection down.
//
// Connect fails if the client is already running. The client counts as running
// until Disconnect is called, even if the peer has closed the connection.
func (c *Client) Connect() (readChan <-chan []byte, writeChan chan<- []byte, res *http.Response, err error) {
	var (
		conn *server.Conn
	)

	if c.stopChan != nil {
		// errors.New rather than fmt.Errorf: the message is not a format string
		// and may contain '%' via Name.
		return nil, nil, nil, errors.New(c.clientMessage("already running. Stop it before starting it again"))
	}

	err = c.Validate()
	if nil != err {
		return nil, nil, nil, err
	}

	conn, res, err = c.connect()
	if nil != err {
		return nil, nil, nil, err
	}

	c.readChan = make(chan []byte, 256)
	c.writeChan = make(chan []byte, 256)

	c.stopChan = make(chan struct{})
	c.stopWg.Add(1)
	go c.handleConnection(conn)

	return c.readChan, c.writeChan, res, nil
}

// Disconnect signals the connection goroutines to stop and waits for them to
// finish. The underlying connection has been closed by the time it returns. It
// returns an error if the client is not running.
func (c *Client) Disconnect() error {
	if c.stopChan == nil {
		return errors.New(c.clientMessage("must be started before stopping it"))
	}
	close(c.stopChan)
	c.stopWg.Wait()

	c.stopChan = nil

	return nil
}

// clientMessage prefixes msg with the client's name for logs and errors.
func (c *Client) clientMessage(msg string) string {
	return fmt.Sprintf("Client[%s]: %s", c.Name, msg)
}

// connect dials the server and installs a close handler that only logs.
func (c *Client) connect() (*server.Conn, *http.Response, error) {
	conn, res, err := c.Dial()
	if nil != err {
		return nil, nil, err
	}

	conn.SetCloseHandler(func(code int, text string) error {
		logging.Logger().Debugf("close")
		return nil
	})

	return conn, res, nil
}

// handleConnection runs the reader and writer goroutines for conn.
//
// It stops both goroutines as soon as either one finishes or Disconnect is
// called. It then closes conn and marks the client's wait group done.
func (c *Client) handleConnection(conn *server.Conn) {
	defer func() {
		if nil != conn {
			conn.Close()
		}
		logging.Logger().Infof("%s", c.clientMessage("disconnected"))
		c.stopWg.Done()
	}()

	logging.Logger().Infof("%s", c.clientMessage("connected"))

	stopChan := make(chan struct{})
	readerDoneChan := make(chan struct{})
	writerDoneChan := make(chan struct{})

	go handleClientRead(c, conn, stopChan, readerDoneChan)
	go handleClientWrite(c, conn, stopChan, writerDoneChan)

	// Whichever side finishes first triggers shutdown of the other. The done
	// channels are unbuffered, so each one must be received exactly once.
	select {
	case <-readerDoneChan:
		close(stopChan)
		<-writerDoneChan
	case <-writerDoneChan:
		close(stopChan)
		<-readerDoneChan
	case <-c.stopChan:
		close(stopChan)
		<-readerDoneChan
		<-writerDoneChan
	}
}

// handleClientRead reads messages from conn and forwards them to c.readChan.
// It returns when a read fails or stopChan is closed, then signals doneChan.
func handleClientRead(c *Client, conn *server.Conn, stopChan <-chan struct{}, doneChan chan<- struct{}) {
	defer func() {
		doneChan <- struct{}{}
	}()

	if 0 < c.MaxMessageSize {
		conn.SetReadLimit(c.MaxMessageSize)
	}
	if 0 < c.ReadTimeout {
		conn.SetReadDeadline(time.Now().Add(c.ReadTimeout))
	}
	conn.SetPongHandler(func(string) error {
		conn.SetReadDeadline(time.Now().Add(c.PongTimeout))
		return nil
	})

	var (
		message []byte
		err     error
	)

	for {
		readMessageChan := make(chan struct{})

		// ReadMessage blocks. Run it in a goroutine so that stopChan can be
		// observed while the read is pending.
		go func() {
			_, message, err = conn.ReadMessage()
			close(readMessageChan)
		}()

		select {
		case <-stopChan:
			// Expire the read deadline so the pending ReadMessage returns now,
			// instead of holding up shutdown until the deadline passes. The
			// connection is being torn down, so the resulting error is
			// irrelevant. NOTE(review): assumes server.Conn.SetReadDeadline
			// delegates to the net.Conn (as in gorilla), whose methods are safe
			// for concurrent use.
			conn.SetReadDeadline(time.Now())
			<-readMessageChan
			return
		case <-readMessageChan:
		}

		if nil != err {
			if server.IsUnexpectedCloseError(err, server.CloseGoingAway, server.CloseAbnormalClosure) {
				logging.Logger().Debugf("%s", c.clientMessage(fmt.Sprintf("Read error %v", err)))
			}
			return
		}

		// Do not block forever on a consumer that stopped draining readChan.
		// Otherwise Disconnect would deadlock waiting for this goroutine.
		select {
		case c.readChan <- message:
		case <-stopChan:
			return
		}
	}
}

// handleClientWrite writes messages from c.writeChan to conn as text messages
// and sends a ping every PingPeriod. If writeChan is closed it sends a close
// frame. It returns on any write error or when stopChan is closed, then signals
// doneChan.
func handleClientWrite(c *Client, conn *server.Conn, stopChan <-chan struct{}, doneChan chan<- struct{}) {
	defer func() {
		doneChan <- struct{}{}
	}()

	ticker := time.NewTicker(c.PingPeriod)
	defer func() {
		ticker.Stop()
	}()

	for {
		select {
		case message, ok := <-c.writeChan:
			conn.SetWriteDeadline(time.Now().Add(c.WriteTimeout))
			if !ok {
				// Best effort: the connection is shutting down regardless.
				conn.WriteMessage(server.CloseMessage, []byte{})
				return
			}

			w, err := conn.NextWriter(server.TextMessage)
			if err != nil {
				return
			}
			if _, err := w.Write(message); nil != err {
				return
			}

			if err := w.Close(); nil != err {
				return
			}
		case <-ticker.C:
			conn.SetWriteDeadline(time.Now().Add(c.PingTimeout))
			if err := conn.WriteMessage(server.PingMessage, nil); nil != err {
				return
			}
		case <-stopChan:
			return
		}
	}
}

// Dial validates the configuration and performs the websocket opening
// handshake against c.URL. It goes through the proxy returned by c.Proxy, if
// any, and uses TLS for wss:// URLs.
//
// On success it returns the established connection and the handshake
// response; the response body is replaced with an empty reader. If the server
// rejects the upgrade, Dial returns server.ErrBadHandshake together with the
// response, whose body then holds up to the first 1024 bytes the server sent.
func (c *Client) Dial() (*server.Conn, *http.Response, error) {
	var (
		err          error
		challengeKey string
		netConn      net.Conn
	)

	if err = c.Validate(); nil != err {
		return nil, nil, err
	}

	challengeKey, err = generateChallengeKey()
	if err != nil {
		return nil, nil, err
	}

	req := &http.Request{
		Method:     "GET",
		URL:        c.serverURL,
		Proto:      "HTTP/1.1",
		ProtoMajor: 1,
		ProtoMinor: 1,
		Header:     make(http.Header),
		Host:       c.serverURL.Host,
	}

	// Set the cookies present in the client's cookie jar.
	if nil != c.CookieJar {
		for _, cookie := range c.CookieJar.Cookies(c.serverURL) {
			req.AddCookie(cookie)
		}
	}

	// Set the request headers using the capitalization for names and values in
	// RFC examples. Although the capitalization shouldn't matter, there are
	// servers that depend on it. The Header.Set method is not used because the
	// method canonicalizes the header names.
	req.Header["Upgrade"] = []string{"websocket"}
	req.Header["Connection"] = []string{"Upgrade"}
	req.Header["Sec-WebSocket-Key"] = []string{challengeKey}
	req.Header["Sec-WebSocket-Version"] = []string{"13"}
	if len(c.Subprotocols) > 0 {
		req.Header["Sec-WebSocket-Protocol"] = []string{strings.Join(c.Subprotocols, ", ")}
	}
	for k, vs := range c.RequestHeader {
		switch {
		case k == "Host":
			if len(vs) > 0 {
				req.Host = vs[0]
			}
		case k == "Upgrade" ||
			k == "Connection" ||
			k == "Sec-Websocket-Key" ||
			k == "Sec-Websocket-Version" ||
			k == "Sec-Websocket-Extensions" ||
			(k == "Sec-Websocket-Protocol" && len(c.Subprotocols) > 0):
			return nil, nil, errors.New("websocket: duplicate header not allowed: " + k)
		default:
			req.Header[k] = vs
		}
	}

	if c.EnableCompression {
		req.Header.Set("Sec-Websocket-Extensions", "permessage-deflate; server_no_context_takeover; client_no_context_takeover")
	}

	hostPort, hostNoPort := hostPortNoPort(c.serverURL)

	var proxyURL *url.URL
	// Check whether a proxy function has been configured.
	if nil != c.Proxy {
		proxyURL, err = c.Proxy(req)
		if err != nil {
			return nil, nil, err
		}
	}

	var targetHostPort string
	if proxyURL != nil {
		targetHostPort, _ = hostPortNoPort(proxyURL)
	} else {
		targetHostPort = hostPort
	}

	var deadline time.Time
	if 0 != c.HandshakeTimeout {
		deadline = time.Now().Add(c.HandshakeTimeout)
	}

	netDial := c.NetDial
	if netDial == nil {
		netDialer := &net.Dialer{Deadline: deadline}
		netDial = netDialer.Dial
	}

	netConn, err = netDial("tcp", targetHostPort)
	if err != nil {
		return nil, nil, err
	}

	// Close the network connection on every error path. netConn is set to nil
	// just before the successful return so that this becomes a no-op.
	defer func() {
		if nil != netConn {
			netConn.Close()
		}
	}()

	err = netConn.SetDeadline(deadline)
	if nil != err {
		return nil, nil, err
	}

	if nil != proxyURL {
		connectHeader := make(http.Header)
		if user := proxyURL.User; nil != user {
			proxyUser := user.Username()
			if proxyPassword, passwordSet := user.Password(); passwordSet {
				credential := base64.StdEncoding.EncodeToString([]byte(proxyUser + ":" + proxyPassword))
				connectHeader.Set("Proxy-Authorization", "Basic "+credential)
			}
		}
		connectReq := &http.Request{
			Method: "CONNECT",
			URL:    &url.URL{Opaque: hostPort},
			Host:   hostPort,
			Header: connectHeader,
		}

		if err = connectReq.Write(netConn); nil != err {
			return nil, nil, err
		}

		// Read response.
		// Okay to use and discard buffered reader here, because
		// TLS server will not speak until spoken to.
		br := bufio.NewReader(netConn)
		resp, err := http.ReadResponse(br, connectReq)
		if err != nil {
			return nil, nil, err
		}

		if resp.StatusCode != 200 {
			// Status is "<code> <reason>", but the reason phrase may be missing.
			// Fall back to the whole status instead of indexing past the end.
			msg := resp.Status
			if f := strings.SplitN(resp.Status, " ", 2); len(f) == 2 && f[1] != "" {
				msg = f[1]
			}
			return nil, nil, errors.New(msg)
		}
	}

	if "https" == c.serverURL.Scheme {
		cfg := cloneTLSConfig(c.TLSConfig)
		if cfg.ServerName == "" {
			cfg.ServerName = hostNoPort
		}
		tlsConn := tls.Client(netConn, cfg)
		netConn = tlsConn
		if err := tlsConn.Handshake(); err != nil {
			return nil, nil, err
		}
		if !cfg.InsecureSkipVerify {
			if err := tlsConn.VerifyHostname(cfg.ServerName); err != nil {
				return nil, nil, err
			}
		}
	}

	conn := server.NewConn(netConn, false, c.ReadBufferSize, c.WriteBufferSize)

	if err := req.Write(netConn); err != nil {
		return nil, nil, err
	}

	resp, err := http.ReadResponse(conn.BuffReader, req)
	if err != nil {
		return nil, nil, err
	}

	if nil != c.CookieJar {
		if rc := resp.Cookies(); len(rc) > 0 {
			c.CookieJar.SetCookies(c.serverURL, rc)
		}
	}

	if resp.StatusCode != 101 ||
		!strings.EqualFold(resp.Header.Get("Upgrade"), "websocket") ||
		!strings.EqualFold(resp.Header.Get("Connection"), "upgrade") ||
		resp.Header.Get("Sec-Websocket-Accept") != computeAcceptKey(challengeKey) {
		// Before closing the network connection on return from this
		// function, slurp up some of the response to aid application
		// debugging.
		buf := make([]byte, 1024)
		n, _ := io.ReadFull(resp.Body, buf)
		resp.Body = ioutil.NopCloser(bytes.NewReader(buf[:n]))
		return nil, resp, server.ErrBadHandshake
	}

	// Only the "no context takeover" variant of permessage-deflate is
	// supported. Any other negotiated parameters are rejected.
	for _, ext := range httpParseExtensions(resp.Header) {
		if ext[""] != "permessage-deflate" {
			continue
		}
		_, snct := ext["server_no_context_takeover"]
		_, cnct := ext["client_no_context_takeover"]
		if !snct || !cnct {
			return nil, resp, server.ErrInvalidCompression
		}
		conn.NewCompressionWriter = server.CompressNoContextTakeover
		conn.NewDecompressionReader = server.DecompressNoContextTakeover
		break
	}

	resp.Body = ioutil.NopCloser(bytes.NewReader([]byte{}))
	conn.Subprotocol = resp.Header.Get("Sec-Websocket-Protocol")

	// Clear the handshake deadline; the read/write loops manage their own.
	netConn.SetDeadline(time.Time{})
	netConn = nil // to avoid close in defer.
	return conn, resp, nil
}

// Validate checks the client configuration and fills in defaults. It mutates
// c in place:
//   - sets Name to "Client" when empty;
//   - parses URL into serverURL, mapping ws to http and wss to https;
//   - installs http.ProxyFromEnvironment when Proxy is nil;
//   - replaces every non-positive size or timeout with its server.Default*
//     value.
//
// It returns an error if URL is empty, malformed, uses another scheme, or
// contains user information. Connect and Dial both call it.
func (c *Client) Validate() error {
	if "" == c.Name {
		c.Name = "Client"
	}

	if "" == c.URL {
		return fmt.Errorf("Client: URL is not valid")
	}

	u, err := parseURL(c.URL)
	if nil != err {
		return err
	}

	switch u.Scheme {
	case "ws":
		u.Scheme = "http"
	case "wss":
		u.Scheme = "https"
	default:
		return errMalformedURL
	}

	if nil != u.User {
		// User name and password are not allowed in websocket URIs.
		return errMalformedURL
	}

	c.serverURL = u

	if nil == c.Proxy {
		c.Proxy = http.ProxyFromEnvironment
	}

	if c.HandshakeTimeout <= 0 {
		c.HandshakeTimeout = server.DefaultHandshakeTimeout
	}
	if c.MaxMessageSize <= 0 {
		c.MaxMessageSize = server.DefaultMaxMessageSize
	}
	if c.ReadBufferSize <= 0 {
		c.ReadBufferSize = server.DefaultReadBufferSize
	}
	if c.WriteBufferSize <= 0 {
		c.WriteBufferSize = server.DefaultWriteBufferSize
	}
	if c.ReadTimeout <= 0 {
		c.ReadTimeout = server.DefaultReadTimeout
	}
	if c.WriteTimeout <= 0 {
		c.WriteTimeout = server.DefaultWriteTimeout
	}
	if c.PongTimeout <= 0 {
		c.PongTimeout = server.DefaultPongTimeout
	}
	if c.PingTimeout <= 0 {
		c.PingTimeout = server.DefaultPingTimeout
	}
	if c.PingPeriod <= 0 {
		c.PingPeriod = server.DefaultPingPeriod
	}

	return nil
}

// parseURL parses the URL.
// // This function is a replacement for the standard library url.Parse function. // In Go 1.4 and earlier, url.Parse loses information from the path. func parseURL(s string) (*url.URL, error) { // From the RFC: // // ws-URI = "ws:" "//" host [ ":" port ] path [ "?" query ] // wss-URI = "wss:" "//" host [ ":" port ] path [ "?" query ] var u url.URL switch { case strings.HasPrefix(s, "ws://"): u.Scheme = "ws" s = s[len("ws://"):] case strings.HasPrefix(s, "wss://"): u.Scheme = "wss" s = s[len("wss://"):] default: return nil, errMalformedURL } if i := strings.Index(s, "?"); i >= 0 { u.RawQuery = s[i+1:] s = s[:i] } if i := strings.Index(s, "/"); i >= 0 { u.Opaque = s[i:] s = s[:i] } else { u.Opaque = "/" } u.Host = s if strings.Contains(u.Host, "@") { // Don't bother parsing user information because user information is // not allowed in websocket URIs. return nil, errMalformedURL } return &u, nil } func hostPortNoPort(u *url.URL) (hostPort, hostNoPort string) { hostPort = u.Host hostNoPort = u.Host if i := strings.LastIndex(u.Host, ":"); i > strings.LastIndex(u.Host, "]") { hostNoPort = hostNoPort[:i] } else { switch u.Scheme { case "wss": hostPort += ":443" case "https": hostPort += ":443" default: hostPort += ":80" } } return hostPort, hostNoPort } func cloneTLSConfig(cfg *tls.Config) *tls.Config { if cfg == nil { return &tls.Config{} } return cfg.Clone() }