From f85b949e66ab073ba709f2eeb20499873ee0fcf6 Mon Sep 17 00:00:00 2001 From: crusader Date: Wed, 4 Apr 2018 13:01:26 +0900 Subject: [PATCH] ing --- const.go | 24 +- fasthttp/websocket/client.go | 408 ++++++++++ fasthttp/websocket/server-handler.go | 101 +++ fasthttp/websocket/server.go | 286 +++++++ fasthttp/websocket/servlet.go | 12 + fasthttp/websocket/upgrade.go | 277 +++++++ fasthttp/websocket/util.go | 272 +++++++ internal/compression.go | 144 ++++ internal/conn-json.go | 55 ++ internal/conn-read.go | 18 + internal/conn.go | 1040 ++++++++++++++++++++++++++ internal/error.go | 111 +++ internal/mask.go | 49 ++ internal/prepared.go | 100 +++ client.go => net/client.go | 2 +- net/const.go | 7 + net/server-handler.go | 67 ++ net/server.go | 321 ++++++++ net/servlet.go | 13 + server-context.go | 12 +- server-handler.go | 172 ++++- server.go | 166 ---- servlet-context.go | 28 + servlet.go | 13 +- 24 files changed, 3495 insertions(+), 203 deletions(-) create mode 100644 fasthttp/websocket/client.go create mode 100644 fasthttp/websocket/server-handler.go create mode 100644 fasthttp/websocket/server.go create mode 100644 fasthttp/websocket/servlet.go create mode 100644 fasthttp/websocket/upgrade.go create mode 100644 fasthttp/websocket/util.go create mode 100644 internal/compression.go create mode 100644 internal/conn-json.go create mode 100644 internal/conn-read.go create mode 100644 internal/conn.go create mode 100644 internal/error.go create mode 100644 internal/mask.go create mode 100644 internal/prepared.go rename client.go => net/client.go (98%) create mode 100644 net/const.go create mode 100644 net/server-handler.go create mode 100644 net/server.go create mode 100644 net/servlet.go delete mode 100644 server.go create mode 100644 servlet-context.go diff --git a/const.go b/const.go index bf9ead9..40cd029 100644 --- a/const.go +++ b/const.go @@ -1,7 +1,29 @@ package server const ( + // DefaultConcurrency is the maximum number of concurrent connections + // the Server may serve by default (i.e. if Server.Concurrency isn't set). + DefaultConcurrency = 256 * 1024 + + // DefaultHandshakeTimeout is default value of websocket handshake Timeout DefaultHandshakeTimeout = 0 - DefaultKeepAlive = 0 + // DefaultReadBufferSize is default value of Read Buffer Size + DefaultReadBufferSize = 0 + // DefaultWriteBufferSize is default value of Write Buffer Size + DefaultWriteBufferSize = 0 + // DefaultReadTimeout is default value of read timeout + DefaultReadTimeout = 0 + // DefaultWriteTimeout is default value of write timeout + DefaultWriteTimeout = 0 + // DefaultEnableCompression is default value of support compression + DefaultEnableCompression = false + // DefaultMaxMessageSize is default size for a message read from the peer + DefaultMaxMessageSize = 4096 + // DefaultPongTimeout is default value of websocket pong Timeout + DefaultPongTimeout = 0 + // DefaultPingTimeout is default value of websocket ping Timeout + DefaultPingTimeout = 0 + // DefaultPingPeriod is default value of send ping period + DefaultPingPeriod = 0 ) diff --git a/fasthttp/websocket/client.go b/fasthttp/websocket/client.go new file mode 100644 index 0000000..4643352 --- /dev/null +++ b/fasthttp/websocket/client.go @@ -0,0 +1,408 @@ +package websocket + +import ( + "bufio" + "bytes" + "crypto/tls" + "encoding/base64" + "errors" + "fmt" + "io" + "io/ioutil" + "net" + "net/http" + "net/url" + "strings" + "time" + + server "git.loafle.net/commons/server-go" + "git.loafle.net/commons/server-go/internal" +) + +var errMalformedURL = errors.New("malformed ws or wss URL") + +type Client struct { + Name string + + URL string + + RequestHeader http.Header + + // NetDial specifies the dial function for creating TCP connections. If + // NetDial is nil, net.Dial is used. + NetDial func(network, addr string) (net.Conn, error) + + // Proxy specifies a function to return a proxy for a given + // Request. If the function returns a non-nil error, the + // request is aborted with the provided error. + // If Proxy is nil or returns a nil *URL, no proxy is used. + Proxy func(*http.Request) (*url.URL, error) + + TLSConfig *tls.Config + HandshakeTimeout time.Duration + // ReadBufferSize and WriteBufferSize specify I/O buffer sizes. If a buffer + // size is zero, then a useful default size is used. The I/O buffer sizes + // do not limit the size of the messages that can be sent or received. + ReadBufferSize, WriteBufferSize int + // Subprotocols specifies the client's requested subprotocols. + Subprotocols []string + + // EnableCompression specifies if the client should attempt to negotiate + // per message compression (RFC 7692). Setting this value to true does not + // guarantee that compression will be supported. Currently only "no context + // takeover" modes are supported. + EnableCompression bool + + // Jar specifies the cookie jar. + // If Jar is nil, cookies are not sent in requests and ignored + // in responses. + CookieJar http.CookieJar + + // MaxMessageSize is the maximum size for a message read from the peer. If a + // message exceeds the limit, the connection sends a close frame to the peer + // and returns ErrReadLimit to the application. + MaxMessageSize int64 + // WriteTimeout is the write deadline on the underlying network + // connection. After a write has timed out, the websocket state is corrupt and + // all future writes will return an error. A zero value for t means writes will + // not time out. + WriteTimeout time.Duration + // ReadTimeout is the read deadline on the underlying network connection. + // After a read has timed out, the websocket connection state is corrupt and + // all future reads will return an error. A zero value for t means reads will + // not time out. + ReadTimeout time.Duration + + PongTimeout time.Duration + PingTimeout time.Duration + PingPeriod time.Duration + + serverURL *url.URL +} + +func (c *Client) Dial() (*internal.Conn, *http.Response, error) { + var ( + err error + challengeKey string + netConn net.Conn + ) + + if err = c.Validate(); nil != err { + return nil, nil, err + } + + challengeKey, err = generateChallengeKey() + if err != nil { + return nil, nil, err + } + + req := &http.Request{ + Method: "GET", + URL: c.serverURL, + Proto: "HTTP/1.1", + ProtoMajor: 1, + ProtoMinor: 1, + Header: make(http.Header), + Host: c.serverURL.Host, + } + + // Set the cookies present in the cookie jar of the dialer + if nil != c.CookieJar { + for _, cookie := range c.CookieJar.Cookies(c.serverURL) { + req.AddCookie(cookie) + } + } + + // Set the request headers using the capitalization for names and values in + // RFC examples. Although the capitalization shouldn't matter, there are + // servers that depend on it. The Header.Set method is not used because the + // method canonicalizes the header names. + req.Header["Upgrade"] = []string{"websocket"} + req.Header["Connection"] = []string{"Upgrade"} + req.Header["Sec-WebSocket-Key"] = []string{challengeKey} + req.Header["Sec-WebSocket-Version"] = []string{"13"} + if len(c.Subprotocols) > 0 { + req.Header["Sec-WebSocket-Protocol"] = []string{strings.Join(c.Subprotocols, ", ")} + } + for k, vs := range c.RequestHeader { + switch { + case k == "Host": + if len(vs) > 0 { + req.Host = vs[0] + } + case k == "Upgrade" || + k == "Connection" || + k == "Sec-Websocket-Key" || + k == "Sec-Websocket-Version" || + k == "Sec-Websocket-Extensions" || + (k == "Sec-Websocket-Protocol" && len(c.Subprotocols) > 0): + return nil, nil, errors.New("websocket: duplicate header not allowed: " + k) + default: + req.Header[k] = vs + } + } + + if c.EnableCompression { + req.Header.Set("Sec-Websocket-Extensions", "permessage-deflate; server_no_context_takeover; client_no_context_takeover") + } + + hostPort, hostNoPort := hostPortNoPort(c.serverURL) + + var proxyURL *url.URL + // Check wether the proxy method has been configured + if nil != c.Proxy { + proxyURL, err = c.Proxy(req) + if err != nil { + return nil, nil, err + } + } + + var targetHostPort string + if proxyURL != nil { + targetHostPort, _ = hostPortNoPort(proxyURL) + } else { + targetHostPort = hostPort + } + + var deadline time.Time + if 0 != c.HandshakeTimeout { + deadline = time.Now().Add(c.HandshakeTimeout) + } + + netDial := c.NetDial + if netDial == nil { + netDialer := &net.Dialer{Deadline: deadline} + netDial = netDialer.Dial + } + + netConn, err = netDial("tcp", targetHostPort) + if err != nil { + return nil, nil, err + } + + defer func() { + if nil != netConn { + netConn.Close() + } + }() + + err = netConn.SetDeadline(deadline) + if nil != err { + return nil, nil, err + } + + if nil != proxyURL { + connectHeader := make(http.Header) + if user := proxyURL.User; nil != user { + proxyUser := user.Username() + if proxyPassword, passwordSet := user.Password(); passwordSet { + credential := base64.StdEncoding.EncodeToString([]byte(proxyUser + ":" + proxyPassword)) + connectHeader.Set("Proxy-Authorization", "Basic "+credential) + } + } + connectReq := &http.Request{ + Method: "CONNECT", + URL: &url.URL{Opaque: hostPort}, + Host: hostPort, + Header: connectHeader, + } + + connectReq.Write(netConn) + + // Read response. + // Okay to use and discard buffered reader here, because + // TLS server will not speak until spoken to. + br := bufio.NewReader(netConn) + resp, err := http.ReadResponse(br, connectReq) + if err != nil { + return nil, nil, err + } + if resp.StatusCode != 200 { + f := strings.SplitN(resp.Status, " ", 2) + return nil, nil, errors.New(f[1]) + } + } + + if "https" == c.serverURL.Scheme { + cfg := cloneTLSConfig(c.TLSConfig) + if cfg.ServerName == "" { + cfg.ServerName = hostNoPort + } + tlsConn := tls.Client(netConn, cfg) + netConn = tlsConn + if err := tlsConn.Handshake(); err != nil { + return nil, nil, err + } + if !cfg.InsecureSkipVerify { + if err := tlsConn.VerifyHostname(cfg.ServerName); err != nil { + return nil, nil, err + } + } + } + + conn := internal.NewConn(netConn, false, c.ReadBufferSize, c.WriteBufferSize) + + if err := req.Write(netConn); err != nil { + return nil, nil, err + } + + resp, err := http.ReadResponse(conn.BuffReader, req) + if err != nil { + return nil, nil, err + } + + if nil != c.CookieJar { + if rc := resp.Cookies(); len(rc) > 0 { + c.CookieJar.SetCookies(c.serverURL, rc) + } + } + + if resp.StatusCode != 101 || + !strings.EqualFold(resp.Header.Get("Upgrade"), "websocket") || + !strings.EqualFold(resp.Header.Get("Connection"), "upgrade") || + resp.Header.Get("Sec-Websocket-Accept") != computeAcceptKey(challengeKey) { + // Before closing the network connection on return from this + // function, slurp up some of the response to aid application + // debugging. + buf := make([]byte, 1024) + n, _ := io.ReadFull(resp.Body, buf) + resp.Body = ioutil.NopCloser(bytes.NewReader(buf[:n])) + return nil, resp, internal.ErrBadHandshake + } + + for _, ext := range httpParseExtensions(resp.Header) { + if ext[""] != "permessage-deflate" { + continue + } + _, snct := ext["server_no_context_takeover"] + _, cnct := ext["client_no_context_takeover"] + if !snct || !cnct { + return nil, resp, internal.ErrInvalidCompression + } + conn.NewCompressionWriter = internal.CompressNoContextTakeover + conn.NewDecompressionReader = internal.DecompressNoContextTakeover + break + } + + resp.Body = ioutil.NopCloser(bytes.NewReader([]byte{})) + conn.Subprotocol = resp.Header.Get("Sec-Websocket-Protocol") + + netConn.SetDeadline(time.Time{}) + netConn = nil // to avoid close in defer. + + return conn, resp, nil +} + +func (c *Client) Validate() error { + if "" == c.Name { + c.Name = "Client" + } + + if "" == c.URL { + return fmt.Errorf("Client: URL is not valid") + } + + u, err := parseURL(c.URL) + if nil != err { + return err + } + switch u.Scheme { + case "ws": + u.Scheme = "http" + case "wss": + u.Scheme = "https" + default: + return errMalformedURL + } + if nil != u.User { + // User name and password are not allowed in websocket URIs. + return errMalformedURL + } + + c.serverURL = u + + if nil == c.Proxy { + c.Proxy = http.ProxyFromEnvironment + } + + if 0 > c.HandshakeTimeout { + c.HandshakeTimeout = server.DefaultHandshakeTimeout + } + if 0 > c.ReadBufferSize { + c.ReadBufferSize = server.DefaultReadBufferSize + } + if 0 > c.WriteBufferSize { + c.WriteBufferSize = server.DefaultWriteBufferSize + } + + return nil +} + +// parseURL parses the URL. +// +// This function is a replacement for the standard library url.Parse function. +// In Go 1.4 and earlier, url.Parse loses information from the path. +func parseURL(s string) (*url.URL, error) { + // From the RFC: + // + // ws-URI = "ws:" "//" host [ ":" port ] path [ "?" query ] + // wss-URI = "wss:" "//" host [ ":" port ] path [ "?" query ] + var u url.URL + switch { + case strings.HasPrefix(s, "ws://"): + u.Scheme = "ws" + s = s[len("ws://"):] + case strings.HasPrefix(s, "wss://"): + u.Scheme = "wss" + s = s[len("wss://"):] + default: + return nil, errMalformedURL + } + + if i := strings.Index(s, "?"); i >= 0 { + u.RawQuery = s[i+1:] + s = s[:i] + } + + if i := strings.Index(s, "/"); i >= 0 { + u.Opaque = s[i:] + s = s[:i] + } else { + u.Opaque = "/" + } + + u.Host = s + + if strings.Contains(u.Host, "@") { + // Don't bother parsing user information because user information is + // not allowed in websocket URIs. + return nil, errMalformedURL + } + + return &u, nil +} + +func hostPortNoPort(u *url.URL) (hostPort, hostNoPort string) { + hostPort = u.Host + hostNoPort = u.Host + if i := strings.LastIndex(u.Host, ":"); i > strings.LastIndex(u.Host, "]") { + hostNoPort = hostNoPort[:i] + } else { + switch u.Scheme { + case "wss": + hostPort += ":443" + case "https": + hostPort += ":443" + default: + hostPort += ":80" + } + } + return hostPort, hostNoPort +} + +func cloneTLSConfig(cfg *tls.Config) *tls.Config { + if cfg == nil { + return &tls.Config{} + } + return cfg.Clone() +} diff --git a/fasthttp/websocket/server-handler.go b/fasthttp/websocket/server-handler.go new file mode 100644 index 0000000..77f49ec --- /dev/null +++ b/fasthttp/websocket/server-handler.go @@ -0,0 +1,101 @@ +package websocket + +import ( + "net/http" + + "git.loafle.net/commons/server-go" + + "github.com/valyala/fasthttp" +) + +type ServerHandler interface { + server.ServerHandler + + OnError(serverCtx server.ServerCtx, ctx *fasthttp.RequestCtx, status int, reason error) + + RegisterServlet(path string, servlet Servlet) + Servlet(path string) Servlet + + CheckOrigin(ctx *fasthttp.RequestCtx) bool +} + +type ServerHandlers struct { + server.ServerHandlers + + servlets map[string]Servlet +} + +func (sh *ServerHandlers) Init(serverCtx server.ServerCtx) error { + if err := sh.ServerHandlers.Init(serverCtx); nil != err { + return err + } + + if nil != sh.servlets { + for _, servlet := range sh.servlets { + if err := servlet.Init(serverCtx); nil != err { + return err + } + } + } + + return nil +} + +func (sh *ServerHandlers) Destroy(serverCtx server.ServerCtx) { + if nil != sh.servlets { + for _, servlet := range sh.servlets { + servlet.Destroy(serverCtx) + } + } + + sh.ServerHandlers.Destroy(serverCtx) +} + +func (sh *ServerHandlers) OnPing(msg string) error { + return nil +} + +func (sh *ServerHandlers) OnPong(msg string) error { + return nil +} + +func (sh *ServerHandlers) OnClose(code int, text string) error { + return nil +} + +func (sh *ServerHandlers) OnError(serverCtx server.ServerCtx, ctx *fasthttp.RequestCtx, status int, reason error) { + ctx.Response.Header.Set("Sec-Websocket-Version", "13") + ctx.Error(http.StatusText(status), status) +} + +func (sh *ServerHandlers) RegisterServlet(path string, servlet Servlet) { + if nil == sh.servlets { + sh.servlets = make(map[string]Servlet) + } + sh.servlets[path] = servlet +} + +func (sh *ServerHandlers) Servlet(path string) Servlet { + var servlet Servlet + if path == "" && len(sh.servlets) == 1 { + for _, s := range sh.servlets { + servlet = s + } + } else if servlet = sh.servlets[path]; nil == servlet { + return nil + } + + return servlet +} + +func (sh *ServerHandlers) CheckOrigin(ctx *fasthttp.RequestCtx) bool { + return true +} + +func (sh *ServerHandlers) Validate() error { + if err := sh.ServerHandlers.Validate(); nil != err { + return err + } + + return nil +} diff --git a/fasthttp/websocket/server.go b/fasthttp/websocket/server.go new file mode 100644 index 0000000..4e8dad1 --- /dev/null +++ b/fasthttp/websocket/server.go @@ -0,0 +1,286 @@ +package websocket + +import ( + "context" + "fmt" + "net" + "net/http" + "sync" + "time" + + "git.loafle.net/commons/logging-go" + "git.loafle.net/commons/server-go" + "git.loafle.net/commons/server-go/internal" + "github.com/valyala/fasthttp" +) + +type Server struct { + ServerHandler ServerHandler + + ctx server.ServerCtx + + hs *fasthttp.Server + upgrader *Upgrader + + connections sync.Map + stopChan chan struct{} + stopWg sync.WaitGroup +} + +func (s *Server) ListenAndServe() error { + var ( + err error + listener net.Listener + ) + if nil == s.ServerHandler { + return fmt.Errorf("Server: server handler must be specified") + } + s.ServerHandler.Validate() + + if s.stopChan != nil { + return fmt.Errorf(s.serverMessage("already running. Stop it before starting it again")) + } + + s.ctx = s.ServerHandler.ServerCtx() + if nil == s.ctx { + return fmt.Errorf(s.serverMessage("ServerCtx is nil")) + } + + s.hs = &fasthttp.Server{ + Handler: s.httpHandler, + Name: s.ServerHandler.GetName(), + Concurrency: s.ServerHandler.GetConcurrency(), + ReadBufferSize: s.ServerHandler.GetReadBufferSize(), + WriteBufferSize: s.ServerHandler.GetWriteBufferSize(), + ReadTimeout: s.ServerHandler.GetReadTimeout(), + WriteTimeout: s.ServerHandler.GetWriteTimeout(), + } + + s.upgrader = &Upgrader{ + HandshakeTimeout: s.ServerHandler.GetHandshakeTimeout(), + ReadBufferSize: s.ServerHandler.GetReadBufferSize(), + WriteBufferSize: s.ServerHandler.GetWriteBufferSize(), + CheckOrigin: s.ServerHandler.CheckOrigin, + Error: s.onError, + EnableCompression: s.ServerHandler.IsEnableCompression(), + } + + if err = s.ServerHandler.Init(s.ctx); nil != err { + return err + } + + if listener, err = s.ServerHandler.Listener(s.ctx); nil != err { + return err + } + + s.stopChan = make(chan struct{}) + s.stopWg.Add(1) + return s.handleServer(listener) +} + +func (s *Server) Shutdown(ctx context.Context) error { + if s.stopChan == nil { + return fmt.Errorf(s.serverMessage("server must be started before stopping it")) + } + close(s.stopChan) + s.stopWg.Wait() + + s.ServerHandler.Destroy(s.ctx) + + s.stopChan = nil + + return nil +} + +func (s *Server) ConnectionSize() int { + var sz int + s.connections.Range(func(k, v interface{}) bool { + sz++ + return true + }) + return sz +} + +func (s *Server) serverMessage(msg string) string { + return fmt.Sprintf("Server[%s]: %s", s.ServerHandler.GetName(), msg) +} + +func (s *Server) handleServer(listener net.Listener) error { + var ( + err error + ) + + errChan := make(chan error) + + defer func() { + if nil != listener { + listener.Close() + } + s.stopWg.Done() + }() + + go func() { + if err := s.hs.Serve(listener); nil != err { + errChan <- err + return + } + close(errChan) + }() + + select { + case err, _ := <-errChan: + if nil != err { + return err + } + } + + defer func() { + s.ServerHandler.OnStop(s.ctx) + + logging.Logger().Infof(s.serverMessage("Stopped")) + }() + + if err = s.ServerHandler.OnStart(s.ctx); nil != err { + return err + } + + logging.Logger().Infof(s.serverMessage("Started")) + + select { + case <-s.stopChan: + listener.Close() + listener = nil + } + + return nil +} + +func (s *Server) httpHandler(ctx *fasthttp.RequestCtx) { + path := string(ctx.Path()) + var ( + servlet Servlet + err error + ) + + if 0 < s.ServerHandler.GetConcurrency() { + sz := s.ConnectionSize() + if sz >= s.ServerHandler.GetConcurrency() { + logging.Logger().Warnf(s.serverMessage(fmt.Sprintf("max connections size %d, refuse", sz))) + s.onError(ctx, fasthttp.StatusServiceUnavailable, err) + return + } + } + + if servlet = s.ServerHandler.Servlet(path); nil == servlet { + s.onError(ctx, fasthttp.StatusInternalServerError, err) + return + } + + var responseHeader *fasthttp.ResponseHeader + servletCtx := servlet.ServletCtx(s.ctx) + + if responseHeader, err = servlet.Handshake(servletCtx, ctx); nil != err { + s.onError(ctx, http.StatusNotAcceptable, fmt.Errorf("Handshake err: %v", err)) + return + } + + s.upgrader.Upgrade(ctx, responseHeader, func(conn *internal.Conn, err error) { + if err != nil { + s.onError(ctx, fasthttp.StatusInternalServerError, err) + return + } + + s.stopWg.Add(1) + go s.handleConnection(servlet, servletCtx, conn) + }) +} + +func (s *Server) handleConnection(servlet Servlet, servletCtx server.ServletCtx, conn *internal.Conn) { + addr := conn.RemoteAddr() + + defer func() { + s.connections.Delete(conn) + conn.Close() + servlet.OnDisconnect(servletCtx) + logging.Logger().Infof(s.serverMessage(fmt.Sprintf("Client[%s] has been disconnected", addr))) + s.stopWg.Done() + }() + + logging.Logger().Infof(s.serverMessage(fmt.Sprintf("Client[%s] has been connected", addr))) + s.connections.Store(conn, true) + servlet.OnConnect(servletCtx, conn) + + servletStopChan := make(chan struct{}) + doneChan := make(chan struct{}) + + readChan := make(chan []byte) + writeChan := make(chan []byte) + + go servlet.Handle(servletCtx, doneChan, servletStopChan, readChan, writeChan) + go handleRead(s, conn, readChan) + go handleWrite(s, conn, writeChan) + + select { + case <-doneChan: + close(servletStopChan) + case <-s.stopChan: + close(servletStopChan) + <-doneChan + } +} + +func (s *Server) onError(ctx *fasthttp.RequestCtx, status int, reason error) { + s.ServerHandler.OnError(s.ctx, ctx, status, reason) +} + +func handleRead(s *Server, conn *internal.Conn, readChan chan []byte) { + conn.SetReadLimit(s.ServerHandler.GetMaxMessageSize()) + conn.SetReadDeadline(time.Now().Add(s.ServerHandler.GetReadTimeout())) + conn.SetPongHandler(func(string) error { + conn.SetReadDeadline(time.Now().Add(s.ServerHandler.GetPongTimeout())) + return nil + }) + + for { + _, message, err := conn.ReadMessage() + if err != nil { + if internal.IsUnexpectedCloseError(err, internal.CloseGoingAway, internal.CloseAbnormalClosure) { + logging.Logger().Debugf(s.serverMessage(fmt.Sprintf("Read error %v", err))) + } + break + } + readChan <- message + } +} + +func handleWrite(s *Server, conn *internal.Conn, writeChan chan []byte) { + ticker := time.NewTicker(s.ServerHandler.GetPingPeriod()) + defer func() { + ticker.Stop() + }() + for { + select { + case message, ok := <-writeChan: + conn.SetWriteDeadline(time.Now().Add(s.ServerHandler.GetWriteTimeout())) + if !ok { + conn.WriteMessage(internal.CloseMessage, []byte{}) + return + } + + w, err := conn.NextWriter(internal.TextMessage) + if err != nil { + return + } + w.Write(message) + + if err := w.Close(); nil != err { + return + } + case <-ticker.C: + conn.SetWriteDeadline(time.Now().Add(s.ServerHandler.GetPingTimeout())) + if err := conn.WriteMessage(internal.PingMessage, nil); nil != err { + return + } + } + } +} diff --git a/fasthttp/websocket/servlet.go b/fasthttp/websocket/servlet.go new file mode 100644 index 0000000..7ce2deb --- /dev/null +++ b/fasthttp/websocket/servlet.go @@ -0,0 +1,12 @@ +package websocket + +import ( + "git.loafle.net/commons/server-go" + "github.com/valyala/fasthttp" +) + +type Servlet interface { + server.Servlet + + Handshake(servletCtx server.ServletCtx, ctx *fasthttp.RequestCtx) (*fasthttp.ResponseHeader, error) +} diff --git a/fasthttp/websocket/upgrade.go b/fasthttp/websocket/upgrade.go new file mode 100644 index 0000000..e68b7d4 --- /dev/null +++ b/fasthttp/websocket/upgrade.go @@ -0,0 +1,277 @@ +// Copyright 2013 The Gorilla WebSocket Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +package websocket + +import ( + "net" + "net/http" + "net/url" + "strings" + "time" + + "git.loafle.net/commons/server-go/internal" + "github.com/valyala/fasthttp" +) + +type ( + OnUpgradeFunc func(*internal.Conn, error) +) + +// HandshakeError describes an error with the handshake from the peer. +type HandshakeError struct { + message string +} + +func (e HandshakeError) Error() string { return e.message } + +// Upgrader specifies parameters for upgrading an HTTP connection to a +// WebSocket connection. +type Upgrader struct { + // HandshakeTimeout specifies the duration for the handshake to complete. + HandshakeTimeout time.Duration + + // ReadBufferSize and WriteBufferSize specify I/O buffer sizes. If a buffer + // size is zero, then buffers allocated by the HTTP server are used. The + // I/O buffer sizes do not limit the size of the messages that can be sent + // or received. + ReadBufferSize, WriteBufferSize int + + // Subprotocols specifies the server's supported protocols in order of + // preference. If this field is set, then the Upgrade method negotiates a + // subprotocol by selecting the first match in this list with a protocol + // requested by the client. + Subprotocols []string + + // Error specifies the function for generating HTTP error responses. If Error + // is nil, then http.Error is used to generate the HTTP response. + Error func(ctx *fasthttp.RequestCtx, status int, reason error) + + // CheckOrigin returns true if the request Origin header is acceptable. If + // CheckOrigin is nil, the host in the Origin header must not be set or + // must match the host of the request. + CheckOrigin func(ctx *fasthttp.RequestCtx) bool + + // EnableCompression specify if the server should attempt to negotiate per + // message compression (RFC 7692). Setting this value to true does not + // guarantee that compression will be supported. Currently only "no context + // takeover" modes are supported. + EnableCompression bool +} + +func (u *Upgrader) returnError(ctx *fasthttp.RequestCtx, status int, reason string) (*internal.Conn, error) { + err := HandshakeError{reason} + if u.Error != nil { + u.Error(ctx, status, err) + } else { + ctx.Response.Header.Set("Sec-Websocket-Version", "13") + ctx.Error(http.StatusText(status), status) + } + return nil, err +} + +// checkSameOrigin returns true if the origin is not set or is equal to the request host. +func checkSameOrigin(ctx *fasthttp.RequestCtx) bool { + origin := string(ctx.Request.Header.Peek("Origin")) + if len(origin) == 0 { + return true + } + u, err := url.Parse(origin) + if err != nil { + return false + } + return u.Host == string(ctx.Host()) +} + +func (u *Upgrader) selectSubprotocol(ctx *fasthttp.RequestCtx, responseHeader *fasthttp.ResponseHeader) string { + if u.Subprotocols != nil { + clientProtocols := Subprotocols(ctx) + for _, serverProtocol := range u.Subprotocols { + for _, clientProtocol := range clientProtocols { + if clientProtocol == serverProtocol { + return clientProtocol + } + } + } + } else if responseHeader != nil { + return string(responseHeader.Peek("Sec-Websocket-Protocol")) + } + return "" +} + +// Upgrade upgrades the HTTP server connection to the WebSocket protocol. +// +// The responseHeader is included in the response to the client's upgrade +// request. Use the responseHeader to specify cookies (Set-Cookie) and the +// application negotiated subprotocol (Sec-Websocket-Protocol). +// +// If the upgrade fails, then Upgrade replies to the client with an HTTP error +// response. +func (u *Upgrader) Upgrade(ctx *fasthttp.RequestCtx, responseHeader *fasthttp.ResponseHeader, cb OnUpgradeFunc) { + if !ctx.IsGet() { + cb(u.returnError(ctx, fasthttp.StatusMethodNotAllowed, "websocket: not a websocket handshake: request method is not GET")) + return + } + + if nil != responseHeader { + if v := responseHeader.Peek("Sec-Websocket-Extensions"); nil != v { + cb(u.returnError(ctx, fasthttp.StatusInternalServerError, "websocket: application specific 'Sec-Websocket-Extensions' headers are unsupported")) + return + } + } + + if !tokenListContainsValue(&ctx.Request.Header, "Connection", "upgrade") { + cb(u.returnError(ctx, fasthttp.StatusBadRequest, "websocket: not a websocket handshake: 'upgrade' token not found in 'Connection' header")) + return + } + + if !tokenListContainsValue(&ctx.Request.Header, "Upgrade", "websocket") { + cb(u.returnError(ctx, fasthttp.StatusBadRequest, "websocket: not a websocket handshake: 'websocket' token not found in 'Upgrade' header")) + return + } + + if !tokenListContainsValue(&ctx.Request.Header, "Sec-Websocket-Version", "13") { + cb(u.returnError(ctx, fasthttp.StatusBadRequest, "websocket: unsupported version: 13 not found in 'Sec-Websocket-Version' header")) + return + } + + checkOrigin := u.CheckOrigin + if checkOrigin == nil { + checkOrigin = checkSameOrigin + } + if !checkOrigin(ctx) { + cb(u.returnError(ctx, fasthttp.StatusForbidden, "websocket: 'Origin' header value not allowed")) + return + } + + challengeKey := string(ctx.Request.Header.Peek("Sec-Websocket-Key")) + if challengeKey == "" { + cb(u.returnError(ctx, fasthttp.StatusBadRequest, "websocket: not a websocket handshake: `Sec-Websocket-Key' header is missing or blank")) + return + } + + subprotocol := u.selectSubprotocol(ctx, responseHeader) + + // Negotiate PMCE + var compress bool + if u.EnableCompression { + for _, ext := range parseExtensions(&ctx.Request.Header) { + if ext[""] != "permessage-deflate" { + continue + } + compress = true + break + } + } + + ctx.SetStatusCode(fasthttp.StatusSwitchingProtocols) + ctx.Response.Header.Set("Upgrade", "websocket") + ctx.Response.Header.Set("Connection", "Upgrade") + ctx.Response.Header.Set("Sec-Websocket-Accept", computeAcceptKey(challengeKey)) + if subprotocol != "" { + ctx.Response.Header.Set("Sec-Websocket-Protocol", subprotocol) + } + if compress { + ctx.Response.Header.Set("Sec-Websocket-Extensions", "permessage-deflate; server_no_context_takeover; client_no_context_takeover") + } + if nil != responseHeader { + responseHeader.VisitAll(func(key, value []byte) { + k := string(key) + v := string(value) + if k == "Sec-Websocket-Protocol" { + return + } + ctx.Response.Header.Set(k, v) + }) + } + + h := &fasthttp.RequestHeader{} + + //copy request headers in order to have access inside the Conn after + ctx.Request.Header.CopyTo(h) + + ctx.Hijack(func(netConn net.Conn) { + c := internal.NewConn(netConn, true, u.ReadBufferSize, u.WriteBufferSize) + c.Subprotocol = subprotocol + if compress { + c.NewCompressionWriter = internal.CompressNoContextTakeover + c.NewDecompressionReader = internal.DecompressNoContextTakeover + } + + // Clear deadlines set by HTTP server. + netConn.SetDeadline(time.Time{}) + + if u.HandshakeTimeout > 0 { + netConn.SetWriteDeadline(time.Now().Add(u.HandshakeTimeout)) + } + + if u.HandshakeTimeout > 0 { + netConn.SetWriteDeadline(time.Time{}) + } + + cb(c, nil) + }) +} + +// Upgrade upgrades the HTTP server connection to the WebSocket protocol. +// +// This function is deprecated, use websocket.Upgrader instead. +// +// The application is responsible for checking the request origin before +// calling Upgrade. An example implementation of the same origin policy is: +// +// if req.Header.Get("Origin") != "http://"+req.Host { +// http.Error(w, "Origin not allowed", 403) +// return +// } +// +// If the endpoint supports subprotocols, then the application is responsible +// for negotiating the protocol used on the connection. Use the Subprotocols() +// function to get the subprotocols requested by the client. Use the +// Sec-Websocket-Protocol response header to specify the subprotocol selected +// by the application. +// +// The responseHeader is included in the response to the client's upgrade +// request. Use the responseHeader to specify cookies (Set-Cookie) and the +// negotiated subprotocol (Sec-Websocket-Protocol). +// +// The connection buffers IO to the underlying network connection. The +// readBufSize and writeBufSize parameters specify the size of the buffers to +// use. Messages can be larger than the buffers. +// +// If the request is not a valid WebSocket handshake, then Upgrade returns an +// error of type HandshakeError. Applications should handle this error by +// replying to the client with an HTTP error response. +func Upgrade(ctx *fasthttp.RequestCtx, responseHeader *fasthttp.ResponseHeader, readBufSize, writeBufSize int, cb OnUpgradeFunc) { + u := Upgrader{ReadBufferSize: readBufSize, WriteBufferSize: writeBufSize} + u.Error = func(ctx *fasthttp.RequestCtx, status int, reason error) { + // don't return errors to maintain backwards compatibility + } + u.CheckOrigin = func(ctx *fasthttp.RequestCtx) bool { + // allow all connections by default + return true + } + u.Upgrade(ctx, responseHeader, cb) +} + +// Subprotocols returns the subprotocols requested by the client in the +// Sec-Websocket-Protocol header. +func Subprotocols(ctx *fasthttp.RequestCtx) []string { + h := strings.TrimSpace(string(ctx.Request.Header.Peek("Sec-Websocket-Protocol"))) + if h == "" { + return nil + } + protocols := strings.Split(h, ",") + for i := range protocols { + protocols[i] = strings.TrimSpace(protocols[i]) + } + return protocols +} + +// IsWebSocketUpgrade returns true if the client requested upgrade to the +// WebSocket protocol. +func IsWebSocketUpgrade(ctx *fasthttp.RequestCtx) bool { + return tokenListContainsValue(&ctx.Request.Header, "Connection", "upgrade") && + tokenListContainsValue(&ctx.Request.Header, "Upgrade", "websocket") +} diff --git a/fasthttp/websocket/util.go b/fasthttp/websocket/util.go new file mode 100644 index 0000000..29e113b --- /dev/null +++ b/fasthttp/websocket/util.go @@ -0,0 +1,272 @@ +// Copyright 2013 The Gorilla WebSocket Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +package websocket + +import ( + "crypto/rand" + "crypto/sha1" + "encoding/base64" + "io" + "net/http" + "strings" + + "github.com/valyala/fasthttp" +) + +var keyGUID = []byte("258EAFA5-E914-47DA-95CA-C5AB0DC85B11") + +func computeAcceptKey(challengeKey string) string { + h := sha1.New() + h.Write([]byte(challengeKey)) + h.Write(keyGUID) + return base64.StdEncoding.EncodeToString(h.Sum(nil)) +} + +func generateChallengeKey() (string, error) { + p := make([]byte, 16) + if _, err := io.ReadFull(rand.Reader, p); err != nil { + return "", err + } + return base64.StdEncoding.EncodeToString(p), nil +} + +// Octet types from RFC 2616. +var octetTypes [256]byte + +const ( + isTokenOctet = 1 << iota + isSpaceOctet +) + +func init() { + // From RFC 2616 + // + // OCTET = + // CHAR = + // CTL = + // CR = + // LF = + // SP = + // HT = + // <"> = + // CRLF = CR LF + // LWS = [CRLF] 1*( SP | HT ) + // TEXT = + // separators = "(" | ")" | "<" | ">" | "@" | "," | ";" | ":" | "\" | <"> + // | "/" | "[" | "]" | "?" | "=" | "{" | "}" | SP | HT + // token = 1* + // qdtext = > + + for c := 0; c < 256; c++ { + var t byte + isCtl := c <= 31 || c == 127 + isChar := 0 <= c && c <= 127 + isSeparator := strings.IndexRune(" \t\"(),/:;<=>?@[]\\{}", rune(c)) >= 0 + if strings.IndexRune(" \t\r\n", rune(c)) >= 0 { + t |= isSpaceOctet + } + if isChar && !isCtl && !isSeparator { + t |= isTokenOctet + } + octetTypes[c] = t + } +} + +func skipSpace(s string) (rest string) { + i := 0 + for ; i < len(s); i++ { + if octetTypes[s[i]]&isSpaceOctet == 0 { + break + } + } + return s[i:] +} + +func nextToken(s string) (token string, rest string) { + i := 0 + for ; i < len(s); i++ { + if octetTypes[s[i]]&isTokenOctet == 0 { + break + } + } + return s[:i], s[i:] +} + +func nextTokenOrQuoted(s string) (value string, rest string) { + if !strings.HasPrefix(s, "\"") { + return nextToken(s) + } + s = s[1:] + for i := 0; i < len(s); i++ { + switch s[i] { + case '"': + return s[:i], s[i+1:] + case '\\': + p := make([]byte, len(s)-1) + j := copy(p, s[:i]) + escape := true + for i = i + 1; i < len(s); i++ { + b := s[i] + switch { + case escape: + escape = false + p[j] = b + j++ + case b == '\\': + escape = true + case b == '"': + return string(p[:j]), s[i+1:] + default: + p[j] = b + j++ + } + } + return "", "" + } + } + return "", "" +} + +// tokenListContainsValue returns true if the 1#token header with the given +// name contains token. +func tokenListContainsValue(header *fasthttp.RequestHeader, name string, value string) bool { + s := string(header.Peek(name)) + for { + var t string + t, s = nextToken(skipSpace(s)) + if t == "" { + break + } + s = skipSpace(s) + if s != "" && s[0] != ',' { + break + } + if strings.EqualFold(t, value) { + return true + } + if s == "" { + break + } + s = s[1:] + } + return false +} + +// parseExtensiosn parses WebSocket extensions from a header. +func parseExtensions(header *fasthttp.RequestHeader) []map[string]string { + + // From RFC 6455: + // + // Sec-WebSocket-Extensions = extension-list + // extension-list = 1#extension + // extension = extension-token *( ";" extension-param ) + // extension-token = registered-token + // registered-token = token + // extension-param = token [ "=" (token | quoted-string) ] + // ;When using the quoted-string syntax variant, the value + // ;after quoted-string unescaping MUST conform to the + // ;'token' ABNF. + s := string(header.Peek("Sec-Websocket-Extensions")) + var result []map[string]string +headers: + for { + var t string + t, s = nextToken(skipSpace(s)) + if t == "" { + break headers + } + ext := map[string]string{"": t} + for { + s = skipSpace(s) + if !strings.HasPrefix(s, ";") { + break + } + var k string + k, s = nextToken(skipSpace(s[1:])) + if k == "" { + break headers + } + s = skipSpace(s) + var v string + if strings.HasPrefix(s, "=") { + v, s = nextTokenOrQuoted(skipSpace(s[1:])) + s = skipSpace(s) + } + if s != "" && s[0] != ',' && s[0] != ';' { + break headers + } + ext[k] = v + } + if s != "" && s[0] != ',' { + break headers + } + result = append(result, ext) + if s == "" { + break headers + } + s = s[1:] + } + + return result +} + +// parseExtensiosn parses WebSocket extensions from a header. +func httpParseExtensions(header http.Header) []map[string]string { + + // From RFC 6455: + // + // Sec-WebSocket-Extensions = extension-list + // extension-list = 1#extension + // extension = extension-token *( ";" extension-param ) + // extension-token = registered-token + // registered-token = token + // extension-param = token [ "=" (token | quoted-string) ] + // ;When using the quoted-string syntax variant, the value + // ;after quoted-string unescaping MUST conform to the + // ;'token' ABNF. + + var result []map[string]string +headers: + for _, s := range header["Sec-Websocket-Extensions"] { + for { + var t string + t, s = nextToken(skipSpace(s)) + if t == "" { + continue headers + } + ext := map[string]string{"": t} + for { + s = skipSpace(s) + if !strings.HasPrefix(s, ";") { + break + } + var k string + k, s = nextToken(skipSpace(s[1:])) + if k == "" { + continue headers + } + s = skipSpace(s) + var v string + if strings.HasPrefix(s, "=") { + v, s = nextTokenOrQuoted(skipSpace(s[1:])) + s = skipSpace(s) + } + if s != "" && s[0] != ',' && s[0] != ';' { + continue headers + } + ext[k] = v + } + if s != "" && s[0] != ',' { + continue headers + } + result = append(result, ext) + if s == "" { + continue headers + } + s = s[1:] + } + } + return result +} diff --git a/internal/compression.go b/internal/compression.go new file mode 100644 index 0000000..7b5586d --- /dev/null +++ b/internal/compression.go @@ -0,0 +1,144 @@ +package internal + +import ( + "compress/flate" + "errors" + "io" + "strings" + "sync" +) + +const ( + minCompressionLevel = -2 // flate.HuffmanOnly not defined in Go < 1.6 + maxCompressionLevel = flate.BestCompression + defaultCompressionLevel = 1 +) + +var ( + flateWriterPools [maxCompressionLevel - minCompressionLevel + 1]sync.Pool + flateReaderPool = sync.Pool{New: func() interface{} { + return flate.NewReader(nil) + }} +) + +func DecompressNoContextTakeover(r io.Reader) io.ReadCloser { + const tail = + // Add four bytes as specified in RFC + "\x00\x00\xff\xff" + + // Add final block to squelch unexpected EOF error from flate reader. + "\x01\x00\x00\xff\xff" + + fr, _ := flateReaderPool.Get().(io.ReadCloser) + fr.(flate.Resetter).Reset(io.MultiReader(r, strings.NewReader(tail)), nil) + return &flateReadWrapper{fr} +} + +func isValidCompressionLevel(level int) bool { + return minCompressionLevel <= level && level <= maxCompressionLevel +} + +func CompressNoContextTakeover(w io.WriteCloser, level int) io.WriteCloser { + p := &flateWriterPools[level-minCompressionLevel] + tw := &truncWriter{w: w} + fw, _ := p.Get().(*flate.Writer) + if fw == nil { + fw, _ = flate.NewWriter(tw, level) + } else { + fw.Reset(tw) + } + return &flateWriteWrapper{fw: fw, tw: tw, p: p} +} + +// truncWriter is an io.Writer that writes all but the last four bytes of the +// stream to another io.Writer. +type truncWriter struct { + w io.WriteCloser + n int + p [4]byte +} + +func (w *truncWriter) Write(p []byte) (int, error) { + n := 0 + + // fill buffer first for simplicity. + if w.n < len(w.p) { + n = copy(w.p[w.n:], p) + p = p[n:] + w.n += n + if len(p) == 0 { + return n, nil + } + } + + m := len(p) + if m > len(w.p) { + m = len(w.p) + } + + if nn, err := w.w.Write(w.p[:m]); err != nil { + return n + nn, err + } + + copy(w.p[:], w.p[m:]) + copy(w.p[len(w.p)-m:], p[len(p)-m:]) + nn, err := w.w.Write(p[:len(p)-m]) + return n + nn, err +} + +type flateWriteWrapper struct { + fw *flate.Writer + tw *truncWriter + p *sync.Pool +} + +func (w *flateWriteWrapper) Write(p []byte) (int, error) { + if w.fw == nil { + return 0, errWriteClosed + } + return w.fw.Write(p) +} + +func (w *flateWriteWrapper) Close() error { + if w.fw == nil { + return errWriteClosed + } + err1 := w.fw.Flush() + w.p.Put(w.fw) + w.fw = nil + if w.tw.p != [4]byte{0, 0, 0xff, 0xff} { + return errors.New("websocket: internal error, unexpected bytes at end of flate stream") + } + err2 := w.tw.w.Close() + if err1 != nil { + return err1 + } + return err2 +} + +type flateReadWrapper struct { + fr io.ReadCloser +} + +func (r *flateReadWrapper) Read(p []byte) (int, error) { + if r.fr == nil { + return 0, io.ErrClosedPipe + } + n, err := r.fr.Read(p) + if err == io.EOF { + // Preemptively place the reader back in the pool. This helps with + // scenarios where the application does not call NextReader() soon after + // this final read. + r.Close() + } + return n, err +} + +func (r *flateReadWrapper) Close() error { + if r.fr == nil { + return io.ErrClosedPipe + } + err := r.fr.Close() + flateReaderPool.Put(r.fr) + r.fr = nil + return err +} diff --git a/internal/conn-json.go b/internal/conn-json.go new file mode 100644 index 0000000..d003ac4 --- /dev/null +++ b/internal/conn-json.go @@ -0,0 +1,55 @@ +// Copyright 2013 The Gorilla WebSocket Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +package internal + +import ( + "encoding/json" + "io" +) + +// WriteJSON is deprecated, use c.WriteJSON instead. +func WriteJSON(c *Conn, v interface{}) error { + return c.WriteJSON(v) +} + +// WriteJSON writes the JSON encoding of v to the connection. +// +// See the documentation for encoding/json Marshal for details about the +// conversion of Go values to JSON. +func (c *Conn) WriteJSON(v interface{}) error { + w, err := c.NextWriter(TextMessage) + if err != nil { + return err + } + err1 := json.NewEncoder(w).Encode(v) + err2 := w.Close() + if err1 != nil { + return err1 + } + return err2 +} + +// ReadJSON is deprecated, use c.ReadJSON instead. +func ReadJSON(c *Conn, v interface{}) error { + return c.ReadJSON(v) +} + +// ReadJSON reads the next JSON-encoded message from the connection and stores +// it in the value pointed to by v. +// +// See the documentation for the encoding/json Unmarshal function for details +// about the conversion of JSON to a Go value. +func (c *Conn) ReadJSON(v interface{}) error { + _, r, err := c.NextReader() + if err != nil { + return err + } + err = json.NewDecoder(r).Decode(v) + if err == io.EOF { + // One value is expected in the message. + err = io.ErrUnexpectedEOF + } + return err +} diff --git a/internal/conn-read.go b/internal/conn-read.go new file mode 100644 index 0000000..ebd3ed2 --- /dev/null +++ b/internal/conn-read.go @@ -0,0 +1,18 @@ +// Copyright 2016 The Gorilla WebSocket Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +// +build go1.5 + +package internal + +import "io" + +func (c *Conn) read(n int) ([]byte, error) { + p, err := c.BuffReader.Peek(n) + if err == io.EOF { + err = errUnexpectedEOF + } + c.BuffReader.Discard(len(p)) + return p, err +} diff --git a/internal/conn.go b/internal/conn.go new file mode 100644 index 0000000..c62fe39 --- /dev/null +++ b/internal/conn.go @@ -0,0 +1,1040 @@ +package internal + +import ( + "bufio" + "encoding/binary" + "errors" + "io" + "io/ioutil" + "math/rand" + "net" + "strconv" + "sync" + "time" + "unicode/utf8" +) + +const ( + // Frame header byte 0 bits from Section 5.2 of RFC 6455 + finalBit = 1 << 7 + rsv1Bit = 1 << 6 + rsv2Bit = 1 << 5 + rsv3Bit = 1 << 4 + + // Frame header byte 1 bits from Section 5.2 of RFC 6455 + maskBit = 1 << 7 + + maxFrameHeaderSize = 2 + 8 + 4 // Fixed header + length + mask + maxControlFramePayloadSize = 125 + + writeWait = time.Second + + defaultReadBufferSize = 4096 + defaultWriteBufferSize = 4096 + + continuationFrame = 0 + noFrame = -1 +) + +// Close codes defined in RFC 6455, section 11.7. +const ( + CloseNormalClosure = 1000 + CloseGoingAway = 1001 + CloseProtocolError = 1002 + CloseUnsupportedData = 1003 + CloseNoStatusReceived = 1005 + CloseAbnormalClosure = 1006 + CloseInvalidFramePayloadData = 1007 + ClosePolicyViolation = 1008 + CloseMessageTooBig = 1009 + CloseMandatoryExtension = 1010 + CloseInternalServerErr = 1011 + CloseServiceRestart = 1012 + CloseTryAgainLater = 1013 + CloseTLSHandshake = 1015 +) + +// The message types are defined in RFC 6455, section 11.8. +const ( + // TextMessage denotes a text data message. The text message payload is + // interpreted as UTF-8 encoded text data. + TextMessage = 1 + + // BinaryMessage denotes a binary data message. + BinaryMessage = 2 + + // CloseMessage denotes a close control message. The optional message + // payload contains a numeric code and text. Use the FormatCloseMessage + // function to format a close message payload. + CloseMessage = 8 + + // PingMessage denotes a ping control message. The optional message payload + // is UTF-8 encoded text. + PingMessage = 9 + + // PongMessage denotes a ping control message. The optional message payload + // is UTF-8 encoded text. + PongMessage = 10 +) + +func newMaskKey() [4]byte { + n := rand.Uint32() + return [4]byte{byte(n), byte(n >> 8), byte(n >> 16), byte(n >> 24)} +} + +func hideTempErr(err error) error { + if e, ok := err.(net.Error); ok && e.Temporary() { + err = &netError{msg: e.Error(), timeout: e.Timeout()} + } + return err +} + +func isControl(frameType int) bool { + return frameType == CloseMessage || frameType == PingMessage || frameType == PongMessage +} + +func isData(frameType int) bool { + return frameType == TextMessage || frameType == BinaryMessage +} + +var validReceivedCloseCodes = map[int]bool{ + // see http://www.iana.org/assignments/websocket/websocket.xhtml#close-code-number + + CloseNormalClosure: true, + CloseGoingAway: true, + CloseProtocolError: true, + CloseUnsupportedData: true, + CloseNoStatusReceived: false, + CloseAbnormalClosure: false, + CloseInvalidFramePayloadData: true, + ClosePolicyViolation: true, + CloseMessageTooBig: true, + CloseMandatoryExtension: true, + CloseInternalServerErr: true, + CloseServiceRestart: true, + CloseTryAgainLater: true, + CloseTLSHandshake: false, +} + +func isValidReceivedCloseCode(code int) bool { + return validReceivedCloseCodes[code] || (code >= 3000 && code <= 4999) +} + +// The Conn type represents a WebSocket connection. +type Conn struct { + conn net.Conn + isServer bool + Subprotocol string + + // Write fields + mu chan bool // used as mutex to protect write to conn + writeBuf []byte // frame is constructed in this buffer. + writeDeadline time.Time + writer io.WriteCloser // the current writer returned to the application + isWriting bool // for best-effort concurrent write detection + + writeErrMu sync.Mutex + writeErr error + + enableWriteCompression bool + compressionLevel int + NewCompressionWriter func(io.WriteCloser, int) io.WriteCloser + + // Read fields + reader io.ReadCloser // the current reader returned to the application + readErr error + BuffReader *bufio.Reader + readRemaining int64 // bytes remaining in current frame. + readFinal bool // true the current message has more frames. + readLength int64 // Message size. + readLimit int64 // Maximum message size. + readMaskPos int + readMaskKey [4]byte + handlePong func(string) error + handlePing func(string) error + handleClose func(int, string) error + readErrCount int + messageReader *messageReader // the current low-level reader + + readDecompress bool // whether last read frame had RSV1 set + NewDecompressionReader func(io.Reader) io.ReadCloser +} + +func NewConn(conn net.Conn, isServer bool, readBufferSize, writeBufferSize int) *Conn { + return newConnBRW(conn, isServer, readBufferSize, writeBufferSize, nil) +} + +type writeHook struct { + p []byte +} + +func (wh *writeHook) Write(p []byte) (int, error) { + wh.p = p + return len(p), nil +} + +func newConnBRW(conn net.Conn, isServer bool, readBufferSize, writeBufferSize int, brw *bufio.ReadWriter) *Conn { + mu := make(chan bool, 1) + mu <- true + + var br *bufio.Reader + if readBufferSize == 0 && brw != nil && brw.Reader != nil { + // Reuse the supplied bufio.Reader if the buffer has a useful size. + // This code assumes that peek on a reader returns + // bufio.Reader.buf[:0]. + brw.Reader.Reset(conn) + if p, err := brw.Reader.Peek(0); err == nil && cap(p) >= 256 { + br = brw.Reader + } + } + if br == nil { + if readBufferSize == 0 { + readBufferSize = defaultReadBufferSize + } + if readBufferSize < maxControlFramePayloadSize { + readBufferSize = maxControlFramePayloadSize + } + br = bufio.NewReaderSize(conn, readBufferSize) + } + + var writeBuf []byte + if writeBufferSize == 0 && brw != nil && brw.Writer != nil { + // Use the bufio.Writer's buffer if the buffer has a useful size. This + // code assumes that bufio.Writer.buf[:1] is passed to the + // bufio.Writer's underlying writer. + var wh writeHook + brw.Writer.Reset(&wh) + brw.Writer.WriteByte(0) + brw.Flush() + if cap(wh.p) >= maxFrameHeaderSize+256 { + writeBuf = wh.p[:cap(wh.p)] + } + } + + if writeBuf == nil { + if writeBufferSize == 0 { + writeBufferSize = defaultWriteBufferSize + } + writeBuf = make([]byte, writeBufferSize+maxFrameHeaderSize) + } + + c := &Conn{ + isServer: isServer, + BuffReader: br, + conn: conn, + mu: mu, + readFinal: true, + writeBuf: writeBuf, + enableWriteCompression: true, + compressionLevel: defaultCompressionLevel, + } + c.SetCloseHandler(nil) + c.SetPingHandler(nil) + c.SetPongHandler(nil) + return c +} + +// Close closes the underlying network connection without sending or waiting for a close frame. +func (c *Conn) Close() error { + return c.conn.Close() +} + +// LocalAddr returns the local network address. +func (c *Conn) LocalAddr() net.Addr { + return c.conn.LocalAddr() +} + +// RemoteAddr returns the remote network address. +func (c *Conn) RemoteAddr() net.Addr { + return c.conn.RemoteAddr() +} + +// Write methods + +func (c *Conn) writeFatal(err error) error { + err = hideTempErr(err) + c.writeErrMu.Lock() + if c.writeErr == nil { + c.writeErr = err + } + c.writeErrMu.Unlock() + return err +} + +func (c *Conn) write(frameType int, deadline time.Time, bufs ...[]byte) error { + <-c.mu + defer func() { c.mu <- true }() + + c.writeErrMu.Lock() + err := c.writeErr + c.writeErrMu.Unlock() + if err != nil { + return err + } + + c.conn.SetWriteDeadline(deadline) + for _, buf := range bufs { + if len(buf) > 0 { + _, err := c.conn.Write(buf) + if err != nil { + return c.writeFatal(err) + } + } + } + + if frameType == CloseMessage { + c.writeFatal(ErrCloseSent) + } + return nil +} + +// WriteControl writes a control message with the given deadline. The allowed +// message types are CloseMessage, PingMessage and PongMessage. +func (c *Conn) WriteControl(messageType int, data []byte, deadline time.Time) error { + if !isControl(messageType) { + return errBadWriteOpCode + } + if len(data) > maxControlFramePayloadSize { + return errInvalidControlFrame + } + + b0 := byte(messageType) | finalBit + b1 := byte(len(data)) + if !c.isServer { + b1 |= maskBit + } + + buf := make([]byte, 0, maxFrameHeaderSize+maxControlFramePayloadSize) + buf = append(buf, b0, b1) + + if c.isServer { + buf = append(buf, data...) + } else { + key := newMaskKey() + buf = append(buf, key[:]...) + buf = append(buf, data...) + maskBytes(key, 0, buf[6:]) + } + + d := time.Hour * 1000 + if !deadline.IsZero() { + d = deadline.Sub(time.Now()) + if d < 0 { + return errWriteTimeout + } + } + + timer := time.NewTimer(d) + select { + case <-c.mu: + timer.Stop() + case <-timer.C: + return errWriteTimeout + } + defer func() { c.mu <- true }() + + c.writeErrMu.Lock() + err := c.writeErr + c.writeErrMu.Unlock() + if err != nil { + return err + } + + c.conn.SetWriteDeadline(deadline) + _, err = c.conn.Write(buf) + if err != nil { + return c.writeFatal(err) + } + if messageType == CloseMessage { + c.writeFatal(ErrCloseSent) + } + return err +} + +func (c *Conn) prepWrite(messageType int) error { + // Close previous writer if not already closed by the application. It's + // probably better to return an error in this situation, but we cannot + // change this without breaking existing applications. + if c.writer != nil { + c.writer.Close() + c.writer = nil + } + + if !isControl(messageType) && !isData(messageType) { + return errBadWriteOpCode + } + + c.writeErrMu.Lock() + err := c.writeErr + c.writeErrMu.Unlock() + return err +} + +// NextWriter returns a writer for the next message to send. The writer's Close +// method flushes the complete message to the network. +// +// There can be at most one open writer on a connection. NextWriter closes the +// previous writer if the application has not already done so. +func (c *Conn) NextWriter(messageType int) (io.WriteCloser, error) { + if err := c.prepWrite(messageType); err != nil { + return nil, err + } + + mw := &messageWriter{ + c: c, + frameType: messageType, + pos: maxFrameHeaderSize, + } + c.writer = mw + if c.NewCompressionWriter != nil && c.enableWriteCompression && isData(messageType) { + w := c.NewCompressionWriter(c.writer, c.compressionLevel) + mw.compress = true + c.writer = w + } + return c.writer, nil +} + +type messageWriter struct { + c *Conn + compress bool // whether next call to flushFrame should set RSV1 + pos int // end of data in writeBuf. + frameType int // type of the current frame. + err error +} + +func (w *messageWriter) fatal(err error) error { + if w.err != nil { + w.err = err + w.c.writer = nil + } + return err +} + +// flushFrame writes buffered data and extra as a frame to the network. The +// final argument indicates that this is the last frame in the message. +func (w *messageWriter) flushFrame(final bool, extra []byte) error { + c := w.c + length := w.pos - maxFrameHeaderSize + len(extra) + + // Check for invalid control frames. + if isControl(w.frameType) && + (!final || length > maxControlFramePayloadSize) { + return w.fatal(errInvalidControlFrame) + } + + b0 := byte(w.frameType) + if final { + b0 |= finalBit + } + if w.compress { + b0 |= rsv1Bit + } + w.compress = false + + b1 := byte(0) + if !c.isServer { + b1 |= maskBit + } + + // Assume that the frame starts at beginning of c.writeBuf. + framePos := 0 + if c.isServer { + // Adjust up if mask not included in the header. + framePos = 4 + } + + switch { + case length >= 65536: + c.writeBuf[framePos] = b0 + c.writeBuf[framePos+1] = b1 | 127 + binary.BigEndian.PutUint64(c.writeBuf[framePos+2:], uint64(length)) + case length > 125: + framePos += 6 + c.writeBuf[framePos] = b0 + c.writeBuf[framePos+1] = b1 | 126 + binary.BigEndian.PutUint16(c.writeBuf[framePos+2:], uint16(length)) + default: + framePos += 8 + c.writeBuf[framePos] = b0 + c.writeBuf[framePos+1] = b1 | byte(length) + } + + if !c.isServer { + key := newMaskKey() + copy(c.writeBuf[maxFrameHeaderSize-4:], key[:]) + maskBytes(key, 0, c.writeBuf[maxFrameHeaderSize:w.pos]) + if len(extra) > 0 { + return c.writeFatal(errors.New("websocket: internal error, extra used in client mode")) + } + } + + // Write the buffers to the connection with best-effort detection of + // concurrent writes. See the concurrency section in the package + // documentation for more info. + + if c.isWriting { + panic("concurrent write to websocket connection") + } + c.isWriting = true + + err := c.write(w.frameType, c.writeDeadline, c.writeBuf[framePos:w.pos], extra) + + if !c.isWriting { + panic("concurrent write to websocket connection") + } + c.isWriting = false + + if err != nil { + return w.fatal(err) + } + + if final { + c.writer = nil + return nil + } + + // Setup for next frame. + w.pos = maxFrameHeaderSize + w.frameType = continuationFrame + return nil +} + +func (w *messageWriter) ncopy(max int) (int, error) { + n := len(w.c.writeBuf) - w.pos + if n <= 0 { + if err := w.flushFrame(false, nil); err != nil { + return 0, err + } + n = len(w.c.writeBuf) - w.pos + } + if n > max { + n = max + } + return n, nil +} + +func (w *messageWriter) Write(p []byte) (int, error) { + if w.err != nil { + return 0, w.err + } + + if len(p) > 2*len(w.c.writeBuf) && w.c.isServer { + // Don't buffer large messages. + err := w.flushFrame(false, p) + if err != nil { + return 0, err + } + return len(p), nil + } + + nn := len(p) + for len(p) > 0 { + n, err := w.ncopy(len(p)) + if err != nil { + return 0, err + } + copy(w.c.writeBuf[w.pos:], p[:n]) + w.pos += n + p = p[n:] + } + return nn, nil +} + +func (w *messageWriter) WriteString(p string) (int, error) { + if w.err != nil { + return 0, w.err + } + + nn := len(p) + for len(p) > 0 { + n, err := w.ncopy(len(p)) + if err != nil { + return 0, err + } + copy(w.c.writeBuf[w.pos:], p[:n]) + w.pos += n + p = p[n:] + } + return nn, nil +} + +func (w *messageWriter) ReadFrom(r io.Reader) (nn int64, err error) { + if w.err != nil { + return 0, w.err + } + for { + if w.pos == len(w.c.writeBuf) { + err = w.flushFrame(false, nil) + if err != nil { + break + } + } + var n int + n, err = r.Read(w.c.writeBuf[w.pos:]) + w.pos += n + nn += int64(n) + if err != nil { + if err == io.EOF { + err = nil + } + break + } + } + return nn, err +} + +func (w *messageWriter) Close() error { + if w.err != nil { + return w.err + } + if err := w.flushFrame(true, nil); err != nil { + return err + } + w.err = errWriteClosed + return nil +} + +// WritePreparedMessage writes prepared message into connection. +func (c *Conn) WritePreparedMessage(pm *PreparedMessage) error { + frameType, frameData, err := pm.frame(prepareKey{ + isServer: c.isServer, + compress: c.NewCompressionWriter != nil && c.enableWriteCompression && isData(pm.messageType), + compressionLevel: c.compressionLevel, + }) + if err != nil { + return err + } + if c.isWriting { + panic("concurrent write to websocket connection") + } + c.isWriting = true + err = c.write(frameType, c.writeDeadline, frameData, nil) + if !c.isWriting { + panic("concurrent write to websocket connection") + } + c.isWriting = false + return err +} + +// WriteMessage is a helper method for getting a writer using NextWriter, +// writing the message and closing the writer. +func (c *Conn) WriteMessage(messageType int, data []byte) error { + + if c.isServer && (c.NewCompressionWriter == nil || !c.enableWriteCompression) { + // Fast path with no allocations and single frame. + + if err := c.prepWrite(messageType); err != nil { + return err + } + mw := messageWriter{c: c, frameType: messageType, pos: maxFrameHeaderSize} + n := copy(c.writeBuf[mw.pos:], data) + mw.pos += n + data = data[n:] + return mw.flushFrame(true, data) + } + + w, err := c.NextWriter(messageType) + if err != nil { + return err + } + if _, err = w.Write(data); err != nil { + return err + } + return w.Close() +} + +// SetWriteDeadline sets the write deadline on the underlying network +// connection. After a write has timed out, the websocket state is corrupt and +// all future writes will return an error. A zero value for t means writes will +// not time out. +func (c *Conn) SetWriteDeadline(t time.Time) error { + c.writeDeadline = t + return nil +} + +// Read methods + +func (c *Conn) advanceFrame() (int, error) { + + // 1. Skip remainder of previous frame. + + if c.readRemaining > 0 { + if _, err := io.CopyN(ioutil.Discard, c.BuffReader, c.readRemaining); err != nil { + return noFrame, err + } + } + + // 2. Read and parse first two bytes of frame header. + + p, err := c.read(2) + if err != nil { + return noFrame, err + } + + final := p[0]&finalBit != 0 + frameType := int(p[0] & 0xf) + mask := p[1]&maskBit != 0 + c.readRemaining = int64(p[1] & 0x7f) + + c.readDecompress = false + if c.NewDecompressionReader != nil && (p[0]&rsv1Bit) != 0 { + c.readDecompress = true + p[0] &^= rsv1Bit + } + + if rsv := p[0] & (rsv1Bit | rsv2Bit | rsv3Bit); rsv != 0 { + return noFrame, c.handleProtocolError("unexpected reserved bits 0x" + strconv.FormatInt(int64(rsv), 16)) + } + + switch frameType { + case CloseMessage, PingMessage, PongMessage: + if c.readRemaining > maxControlFramePayloadSize { + return noFrame, c.handleProtocolError("control frame length > 125") + } + if !final { + return noFrame, c.handleProtocolError("control frame not final") + } + case TextMessage, BinaryMessage: + if !c.readFinal { + return noFrame, c.handleProtocolError("message start before final message frame") + } + c.readFinal = final + case continuationFrame: + if c.readFinal { + return noFrame, c.handleProtocolError("continuation after final message frame") + } + c.readFinal = final + default: + return noFrame, c.handleProtocolError("unknown opcode " + strconv.Itoa(frameType)) + } + + // 3. Read and parse frame length. + + switch c.readRemaining { + case 126: + p, err := c.read(2) + if err != nil { + return noFrame, err + } + c.readRemaining = int64(binary.BigEndian.Uint16(p)) + case 127: + p, err := c.read(8) + if err != nil { + return noFrame, err + } + c.readRemaining = int64(binary.BigEndian.Uint64(p)) + } + + // 4. Handle frame masking. + + if mask != c.isServer { + return noFrame, c.handleProtocolError("incorrect mask flag") + } + + if mask { + c.readMaskPos = 0 + p, err := c.read(len(c.readMaskKey)) + if err != nil { + return noFrame, err + } + copy(c.readMaskKey[:], p) + } + + // 5. For text and binary messages, enforce read limit and return. + + if frameType == continuationFrame || frameType == TextMessage || frameType == BinaryMessage { + + c.readLength += c.readRemaining + if c.readLimit > 0 && c.readLength > c.readLimit { + c.WriteControl(CloseMessage, FormatCloseMessage(CloseMessageTooBig, ""), time.Now().Add(writeWait)) + return noFrame, ErrReadLimit + } + + return frameType, nil + } + + // 6. Read control frame payload. + + var payload []byte + if c.readRemaining > 0 { + payload, err = c.read(int(c.readRemaining)) + c.readRemaining = 0 + if err != nil { + return noFrame, err + } + if c.isServer { + maskBytes(c.readMaskKey, 0, payload) + } + } + + // 7. Process control frame payload. + + switch frameType { + case PongMessage: + if err := c.handlePong(string(payload)); err != nil { + return noFrame, err + } + case PingMessage: + if err := c.handlePing(string(payload)); err != nil { + return noFrame, err + } + case CloseMessage: + closeCode := CloseNoStatusReceived + closeText := "" + if len(payload) >= 2 { + closeCode = int(binary.BigEndian.Uint16(payload)) + if !isValidReceivedCloseCode(closeCode) { + return noFrame, c.handleProtocolError("invalid close code") + } + closeText = string(payload[2:]) + if !utf8.ValidString(closeText) { + return noFrame, c.handleProtocolError("invalid utf8 payload in close frame") + } + } + if err := c.handleClose(closeCode, closeText); err != nil { + return noFrame, err + } + return noFrame, &CloseError{Code: closeCode, Text: closeText} + } + + return frameType, nil +} + +func (c *Conn) handleProtocolError(message string) error { + c.WriteControl(CloseMessage, FormatCloseMessage(CloseProtocolError, message), time.Now().Add(writeWait)) + return errors.New("websocket: " + message) +} + +// NextReader returns the next data message received from the peer. The +// returned messageType is either TextMessage or BinaryMessage. +// +// There can be at most one open reader on a connection. NextReader discards +// the previous message if the application has not already consumed it. +// +// Applications must break out of the application's read loop when this method +// returns a non-nil error value. Errors returned from this method are +// permanent. Once this method returns a non-nil error, all subsequent calls to +// this method return the same error. +func (c *Conn) NextReader() (messageType int, r io.Reader, err error) { + // Close previous reader, only relevant for decompression. + if c.reader != nil { + c.reader.Close() + c.reader = nil + } + + c.messageReader = nil + c.readLength = 0 + + for c.readErr == nil { + frameType, err := c.advanceFrame() + if err != nil { + c.readErr = hideTempErr(err) + break + } + if frameType == TextMessage || frameType == BinaryMessage { + c.messageReader = &messageReader{c} + c.reader = c.messageReader + if c.readDecompress { + c.reader = c.NewDecompressionReader(c.reader) + } + return frameType, c.reader, nil + } + } + + // Applications that do handle the error returned from this method spin in + // tight loop on connection failure. To help application developers detect + // this error, panic on repeated reads to the failed connection. + c.readErrCount++ + if c.readErrCount >= 1000 { + panic("repeated read on failed websocket connection") + } + + return noFrame, nil, c.readErr +} + +type messageReader struct{ c *Conn } + +func (r *messageReader) Read(b []byte) (int, error) { + c := r.c + if c.messageReader != r { + return 0, io.EOF + } + + for c.readErr == nil { + + if c.readRemaining > 0 { + if int64(len(b)) > c.readRemaining { + b = b[:c.readRemaining] + } + n, err := c.BuffReader.Read(b) + c.readErr = hideTempErr(err) + if c.isServer { + c.readMaskPos = maskBytes(c.readMaskKey, c.readMaskPos, b[:n]) + } + c.readRemaining -= int64(n) + if c.readRemaining > 0 && c.readErr == io.EOF { + c.readErr = errUnexpectedEOF + } + return n, c.readErr + } + + if c.readFinal { + c.messageReader = nil + return 0, io.EOF + } + + frameType, err := c.advanceFrame() + switch { + case err != nil: + c.readErr = hideTempErr(err) + case frameType == TextMessage || frameType == BinaryMessage: + c.readErr = errors.New("websocket: internal error, unexpected text or binary in Reader") + } + } + + err := c.readErr + if err == io.EOF && c.messageReader == r { + err = errUnexpectedEOF + } + return 0, err +} + +func (r *messageReader) Close() error { + return nil +} + +// ReadMessage is a helper method for getting a reader using NextReader and +// reading from that reader to a buffer. +func (c *Conn) ReadMessage() (messageType int, p []byte, err error) { + var r io.Reader + messageType, r, err = c.NextReader() + if err != nil { + return messageType, nil, err + } + p, err = ioutil.ReadAll(r) + return messageType, p, err +} + +// SetReadDeadline sets the read deadline on the underlying network connection. +// After a read has timed out, the websocket connection state is corrupt and +// all future reads will return an error. A zero value for t means reads will +// not time out. +func (c *Conn) SetReadDeadline(t time.Time) error { + return c.conn.SetReadDeadline(t) +} + +// SetReadLimit sets the maximum size for a message read from the peer. If a +// message exceeds the limit, the connection sends a close frame to the peer +// and returns ErrReadLimit to the application. +func (c *Conn) SetReadLimit(limit int64) { + c.readLimit = limit +} + +// CloseHandler returns the current close handler +func (c *Conn) CloseHandler() func(code int, text string) error { + return c.handleClose +} + +// SetCloseHandler sets the handler for close messages received from the peer. +// The code argument to h is the received close code or CloseNoStatusReceived +// if the close message is empty. The default close handler sends a close frame +// back to the peer. +// +// The application must read the connection to process close messages as +// described in the section on Control Frames above. +// +// The connection read methods return a CloseError when a close frame is +// received. Most applications should handle close messages as part of their +// normal error handling. Applications should only set a close handler when the +// application must perform some action before sending a close frame back to +// the peer. +func (c *Conn) SetCloseHandler(h func(code int, text string) error) { + if h == nil { + h = func(code int, text string) error { + message := []byte{} + if code != CloseNoStatusReceived { + message = FormatCloseMessage(code, "") + } + c.WriteControl(CloseMessage, message, time.Now().Add(writeWait)) + return nil + } + } + c.handleClose = h +} + +// PingHandler returns the current ping handler +func (c *Conn) PingHandler() func(appData string) error { + return c.handlePing +} + +// SetPingHandler sets the handler for ping messages received from the peer. +// The appData argument to h is the PING frame application data. The default +// ping handler sends a pong to the peer. +// +// The application must read the connection to process ping messages as +// described in the section on Control Frames above. +func (c *Conn) SetPingHandler(h func(appData string) error) { + if h == nil { + h = func(message string) error { + err := c.WriteControl(PongMessage, []byte(message), time.Now().Add(writeWait)) + if err == ErrCloseSent { + return nil + } else if e, ok := err.(net.Error); ok && e.Temporary() { + return nil + } + return err + } + } + c.handlePing = h +} + +// PongHandler returns the current pong handler +func (c *Conn) PongHandler() func(appData string) error { + return c.handlePong +} + +// SetPongHandler sets the handler for pong messages received from the peer. +// The appData argument to h is the PONG frame application data. The default +// pong handler does nothing. +// +// The application must read the connection to process ping messages as +// described in the section on Control Frames above. +func (c *Conn) SetPongHandler(h func(appData string) error) { + if h == nil { + h = func(string) error { return nil } + } + c.handlePong = h +} + +// UnderlyingConn returns the internal net.Conn. This can be used to further +// modifications to connection specific flags. +func (c *Conn) UnderlyingConn() net.Conn { + return c.conn +} + +// EnableWriteCompression enables and disables write compression of +// subsequent text and binary messages. This function is a noop if +// compression was not negotiated with the peer. +func (c *Conn) EnableWriteCompression(enable bool) { + c.enableWriteCompression = enable +} + +// SetCompressionLevel sets the flate compression level for subsequent text and +// binary messages. This function is a noop if compression was not negotiated +// with the peer. See the compress/flate package for a description of +// compression levels. +func (c *Conn) SetCompressionLevel(level int) error { + if !isValidCompressionLevel(level) { + return errors.New("websocket: invalid compression level") + } + c.compressionLevel = level + return nil +} + +// FormatCloseMessage formats closeCode and text as a WebSocket close message. +func FormatCloseMessage(closeCode int, text string) []byte { + buf := make([]byte, 2+len(text)) + binary.BigEndian.PutUint16(buf, uint16(closeCode)) + copy(buf[2:], text) + return buf +} diff --git a/internal/error.go b/internal/error.go new file mode 100644 index 0000000..17890c8 --- /dev/null +++ b/internal/error.go @@ -0,0 +1,111 @@ +package internal + +import ( + "errors" + "io" + "strconv" +) + +// ErrCloseSent is returned when the application writes a message to the +// connection after sending a close message. +var ErrCloseSent = errors.New("socket: close sent") + +// ErrReadLimit is returned when reading a message that is larger than the +// read limit set for the connection. +var ErrReadLimit = errors.New("socket: read limit exceeded") + +// netError satisfies the net Error interface. +type netError struct { + msg string + temporary bool + timeout bool +} + +func (e *netError) Error() string { return e.msg } +func (e *netError) Temporary() bool { return e.temporary } +func (e *netError) Timeout() bool { return e.timeout } + +// CloseError represents close frame. +type CloseError struct { + + // Code is defined in RFC 6455, section 11.7. + Code int + + // Text is the optional text payload. + Text string +} + +func (e *CloseError) Error() string { + s := []byte("socket: close ") + s = strconv.AppendInt(s, int64(e.Code), 10) + switch e.Code { + case CloseNormalClosure: + s = append(s, " (normal)"...) + case CloseGoingAway: + s = append(s, " (going away)"...) + case CloseProtocolError: + s = append(s, " (protocol error)"...) + case CloseUnsupportedData: + s = append(s, " (unsupported data)"...) + case CloseNoStatusReceived: + s = append(s, " (no status)"...) + case CloseAbnormalClosure: + s = append(s, " (abnormal closure)"...) + case CloseInvalidFramePayloadData: + s = append(s, " (invalid payload data)"...) + case ClosePolicyViolation: + s = append(s, " (policy violation)"...) + case CloseMessageTooBig: + s = append(s, " (message too big)"...) + case CloseMandatoryExtension: + s = append(s, " (mandatory extension missing)"...) + case CloseInternalServerErr: + s = append(s, " (internal server error)"...) + case CloseTLSHandshake: + s = append(s, " (TLS handshake error)"...) + } + if e.Text != "" { + s = append(s, ": "...) + s = append(s, e.Text...) + } + return string(s) +} + +// IsCloseError returns boolean indicating whether the error is a *CloseError +// with one of the specified codes. +func IsCloseError(err error, codes ...int) bool { + if e, ok := err.(*CloseError); ok { + for _, code := range codes { + if e.Code == code { + return true + } + } + } + return false +} + +// IsUnexpectedCloseError returns boolean indicating whether the error is a +// *CloseError with a code not in the list of expected codes. +func IsUnexpectedCloseError(err error, expectedCodes ...int) bool { + if e, ok := err.(*CloseError); ok { + for _, code := range expectedCodes { + if e.Code == code { + return false + } + } + return true + } + return false +} + +var ( + errWriteTimeout = &netError{msg: "socket: write timeout", timeout: true, temporary: true} + errUnexpectedEOF = &CloseError{Code: CloseAbnormalClosure, Text: io.ErrUnexpectedEOF.Error()} + errBadWriteOpCode = errors.New("socket: bad write message type") + errWriteClosed = errors.New("socket: write closed") + errInvalidControlFrame = errors.New("socket: invalid control frame") + // ErrBadHandshake is returned when the server response to opening handshake is + // invalid. + ErrBadHandshake = errors.New("socket: bad handshake") + ErrInvalidCompression = errors.New("socket: invalid compression negotiation") +) diff --git a/internal/mask.go b/internal/mask.go new file mode 100644 index 0000000..2f8acc6 --- /dev/null +++ b/internal/mask.go @@ -0,0 +1,49 @@ +package internal + +import "unsafe" + +const wordSize = int(unsafe.Sizeof(uintptr(0))) + +func maskBytes(key [4]byte, pos int, b []byte) int { + + // Mask one byte at a time for small buffers. + if len(b) < 2*wordSize { + for i := range b { + b[i] ^= key[pos&3] + pos++ + } + return pos & 3 + } + + // Mask one byte at a time to word boundary. + if n := int(uintptr(unsafe.Pointer(&b[0]))) % wordSize; n != 0 { + n = wordSize - n + for i := range b[:n] { + b[i] ^= key[pos&3] + pos++ + } + b = b[n:] + } + + // Create aligned word size key. + var k [wordSize]byte + for i := range k { + k[i] = key[(pos+i)&3] + } + kw := *(*uintptr)(unsafe.Pointer(&k)) + + // Mask one word at a time. + n := (len(b) / wordSize) * wordSize + for i := 0; i < n; i += wordSize { + *(*uintptr)(unsafe.Pointer(uintptr(unsafe.Pointer(&b[0])) + uintptr(i))) ^= kw + } + + // Mask one byte at a time for remaining bytes. + b = b[n:] + for i := range b { + b[i] ^= key[pos&3] + pos++ + } + + return pos & 3 +} diff --git a/internal/prepared.go b/internal/prepared.go new file mode 100644 index 0000000..f5ccef6 --- /dev/null +++ b/internal/prepared.go @@ -0,0 +1,100 @@ +package internal + +import ( + "bytes" + "net" + "sync" + "time" +) + +// PreparedMessage caches on the wire representations of a message payload. +// Use PreparedMessage to efficiently send a message payload to multiple +// connections. PreparedMessage is especially useful when compression is used +// because the CPU and memory expensive compression operation can be executed +// once for a given set of compression options. +type PreparedMessage struct { + messageType int + data []byte + err error + mu sync.Mutex + frames map[prepareKey]*preparedFrame +} + +// prepareKey defines a unique set of options to cache prepared frames in PreparedMessage. +type prepareKey struct { + isServer bool + compress bool + compressionLevel int +} + +// preparedFrame contains data in wire representation. +type preparedFrame struct { + once sync.Once + data []byte +} + +// NewPreparedMessage returns an initialized PreparedMessage. You can then send +// it to connection using WritePreparedMessage method. Valid wire +// representation will be calculated lazily only once for a set of current +// connection options. +func NewPreparedMessage(messageType int, data []byte) (*PreparedMessage, error) { + pm := &PreparedMessage{ + messageType: messageType, + frames: make(map[prepareKey]*preparedFrame), + data: data, + } + + // Prepare a plain server frame. + _, frameData, err := pm.frame(prepareKey{isServer: true, compress: false}) + if err != nil { + return nil, err + } + + // To protect against caller modifying the data argument, remember the data + // copied to the plain server frame. + pm.data = frameData[len(frameData)-len(data):] + return pm, nil +} + +func (pm *PreparedMessage) frame(key prepareKey) (int, []byte, error) { + pm.mu.Lock() + frame, ok := pm.frames[key] + if !ok { + frame = &preparedFrame{} + pm.frames[key] = frame + } + pm.mu.Unlock() + + var err error + frame.once.Do(func() { + // Prepare a frame using a 'fake' connection. + // TODO: Refactor code in conn.go to allow more direct construction of + // the frame. + mu := make(chan bool, 1) + mu <- true + var nc prepareConn + c := &Conn{ + conn: &nc, + mu: mu, + isServer: key.isServer, + compressionLevel: key.compressionLevel, + enableWriteCompression: true, + writeBuf: make([]byte, defaultWriteBufferSize+maxFrameHeaderSize), + } + + if key.compress { + c.NewCompressionWriter = CompressNoContextTakeover + } + err = c.WriteMessage(pm.messageType, pm.data) + frame.data = nc.buf.Bytes() + }) + return pm.messageType, frame.data, err +} + +type prepareConn struct { + buf bytes.Buffer + net.Conn +} + +func (pc *prepareConn) Write(p []byte) (int, error) { return pc.buf.Write(p) } +func (pc *prepareConn) SetWriteDeadline(t time.Time) error { return nil } diff --git a/client.go b/net/client.go similarity index 98% rename from client.go rename to net/client.go index d21faa5..20dd31d 100644 --- a/client.go +++ b/net/client.go @@ -1,4 +1,4 @@ -package server +package net import ( "crypto/tls" diff --git a/net/const.go b/net/const.go new file mode 100644 index 0000000..4f3d4ad --- /dev/null +++ b/net/const.go @@ -0,0 +1,7 @@ +package net + +const ( + DefaultHandshakeTimeout = 0 + + DefaultKeepAlive = 0 +) diff --git a/net/server-handler.go b/net/server-handler.go new file mode 100644 index 0000000..37e1557 --- /dev/null +++ b/net/server-handler.go @@ -0,0 +1,67 @@ +package net + +import ( + "net" + + "git.loafle.net/commons/server-go" +) + +type ServerHandler interface { + server.ServerHandler + + OnError(serverCtx server.ServerCtx, conn net.Conn, status int, reason error) + + RegisterServlet(servlet Servlet) + Servlet() Servlet +} + +type ServerHandlers struct { + server.ServerHandlers + + servlet Servlet +} + +func (sh *ServerHandlers) Init(serverCtx server.ServerCtx) error { + if err := sh.ServerHandlers.Init(serverCtx); nil != err { + return err + } + + if nil != sh.servlet { + if err := sh.servlet.Init(serverCtx); nil != err { + return err + } + } + + return nil +} + +func (sh *ServerHandlers) Destroy(serverCtx server.ServerCtx) { + if nil != sh.servlet { + sh.servlet.Destroy(serverCtx) + } + + sh.ServerHandlers.Destroy(serverCtx) +} + +func (sh *ServerHandlers) OnError(serverCtx server.ServerCtx, conn net.Conn, status int, reason error) { +} + +func (sh *ServerHandlers) RegisterServlet(servlet Servlet) { + sh.servlet = servlet +} + +func (sh *ServerHandlers) Servlet() Servlet { + return sh.servlet +} + +func (sh *ServerHandlers) Validate() error { + if err := sh.ServerHandlers.Validate(); nil != err { + return err + } + + if 0 >= sh.Concurrency { + sh.Concurrency = 0 + } + + return nil +} diff --git a/net/server.go b/net/server.go new file mode 100644 index 0000000..01c08f5 --- /dev/null +++ b/net/server.go @@ -0,0 +1,321 @@ +package net + +import ( + "context" + "fmt" + "net" + "sync" + "sync/atomic" + "time" + + "git.loafle.net/commons/logging-go" + "git.loafle.net/commons/server-go" + "git.loafle.net/commons/server-go/internal" +) + +type Server struct { + ServerHandler ServerHandler + + ctx server.ServerCtx + connections sync.Map + stopChan chan struct{} + stopWg sync.WaitGroup +} + +func (s *Server) ListenAndServe() error { + if s.stopChan != nil { + return fmt.Errorf(s.serverMessage("already running. Stop it before starting it again")) + } + + var ( + err error + listener net.Listener + ) + if nil == s.ServerHandler { + return fmt.Errorf("Server: server handler must be specified") + } + if err = s.ServerHandler.Validate(); nil != err { + return err + } + + s.ctx = s.ServerHandler.ServerCtx() + if nil == s.ctx { + return fmt.Errorf(s.serverMessage("ServerCtx is nil")) + } + + if err = s.ServerHandler.Init(s.ctx); nil != err { + return err + } + + if listener, err = s.ServerHandler.Listener(s.ctx); nil != err { + return err + } + + s.stopChan = make(chan struct{}) + s.stopWg.Add(1) + return s.handleServer(listener) +} + +func (s *Server) Shutdown(ctx context.Context) error { + if s.stopChan == nil { + return fmt.Errorf(s.serverMessage("server must be started before stopping it")) + } + close(s.stopChan) + s.stopWg.Wait() + + s.ServerHandler.Destroy(s.ctx) + + s.stopChan = nil + + return nil +} + +func (s *Server) ConnectionSize() int { + var sz int + s.connections.Range(func(k, v interface{}) bool { + sz++ + return true + }) + return sz +} + +func (s *Server) serverMessage(msg string) string { + return fmt.Sprintf("Server[%s]: %s", s.ServerHandler.GetName(), msg) +} + +func (s *Server) handleServer(listener net.Listener) error { + var ( + stopping atomic.Value + netConn net.Conn + err error + ) + + defer func() { + if nil != listener { + listener.Close() + } + + s.ServerHandler.OnStop(s.ctx) + + logging.Logger().Infof(s.serverMessage("Stopped")) + + s.stopWg.Done() + }() + + if err = s.ServerHandler.OnStart(s.ctx); nil != err { + return err + } + + logging.Logger().Infof(s.serverMessage("Started")) + + for { + acceptChan := make(chan struct{}) + + go func() { + if netConn, err = listener.Accept(); err != nil { + if nil == stopping.Load() { + logging.Logger().Errorf(s.serverMessage(fmt.Sprintf("%v", err))) + } + } + close(acceptChan) + }() + + select { + case <-s.stopChan: + stopping.Store(true) + listener.Close() + <-acceptChan + listener = nil + return nil + case <-acceptChan: + } + + if nil != err { + select { + case <-s.stopChan: + return nil + case <-time.After(time.Second): + } + continue + } + + if 0 < s.ServerHandler.GetConcurrency() { + sz := s.ConnectionSize() + if sz >= s.ServerHandler.GetConcurrency() { + logging.Logger().Warnf(s.serverMessage(fmt.Sprintf("max connections size %d, refuse", sz))) + netConn.Close() + continue + } + } + + servlet := s.ServerHandler.Servlet() + if nil == servlet { + logging.Logger().Errorf(s.serverMessage("Servlet is nil")) + continue + } + + servletCtx := servlet.ServletCtx(s.ctx) + if nil == servletCtx { + logging.Logger().Errorf(s.serverMessage("ServletCtx is nil")) + continue + } + + if err := servlet.Handshake(servletCtx, netConn); nil != err { + logging.Logger().Infof(s.serverMessage(fmt.Sprintf("Handshaking of Client[%s] has been failed %v", netConn.RemoteAddr(), err))) + continue + } + + conn := internal.NewConn(netConn, true, s.ServerHandler.GetReadBufferSize(), s.ServerHandler.GetWriteBufferSize()) + + s.stopWg.Add(1) + go s.handleConnection(servlet, servletCtx, conn) + } +} + +func (s *Server) handleConnection(servlet Servlet, servletCtx server.ServletCtx, conn *internal.Conn) { + addr := conn.RemoteAddr() + + defer func() { + if nil != conn { + conn.Close() + } + logging.Logger().Infof(s.serverMessage(fmt.Sprintf("Client[%s] has been disconnected", addr))) + s.stopWg.Done() + }() + + logging.Logger().Infof(s.serverMessage(fmt.Sprintf("Client[%s] has been connected", addr))) + + s.connections.Store(conn, true) + defer s.connections.Delete(conn) + + servlet.OnConnect(servletCtx, conn) + + stopChan := make(chan struct{}) + servletDoneChan := make(chan struct{}) + + readChan := make(chan []byte, 256) + writeChan := make(chan []byte, 256) + + readerDoneChan := make(chan struct{}) + writerDoneChan := make(chan struct{}) + + go servlet.Handle(servletCtx, stopChan, servletDoneChan, readChan, writeChan) + go handleRead(s, conn, stopChan, readerDoneChan, readChan) + go handleWrite(s, conn, stopChan, writerDoneChan, writeChan) + + select { + case <-readerDoneChan: + close(stopChan) + conn.Close() + <-writerDoneChan + <-servletDoneChan + conn = nil + case <-writerDoneChan: + close(stopChan) + conn.Close() + <-readerDoneChan + <-servletDoneChan + conn = nil + case <-servletDoneChan: + close(stopChan) + conn.Close() + <-readerDoneChan + <-writerDoneChan + conn = nil + case <-s.stopChan: + close(stopChan) + conn.Close() + <-readerDoneChan + <-writerDoneChan + <-servletDoneChan + conn = nil + } +} + +func handleRead(s *Server, conn *internal.Conn, doneChan chan<- struct{}, stopChan <-chan struct{}, readChan chan []byte) { + defer func() { + close(doneChan) + }() + + conn.SetReadLimit(s.ServerHandler.GetMaxMessageSize()) + conn.SetReadDeadline(time.Now().Add(s.ServerHandler.GetReadTimeout())) + conn.SetPongHandler(func(string) error { + conn.SetReadDeadline(time.Now().Add(s.ServerHandler.GetPongTimeout())) + return nil + }) + + var ( + message []byte + err error + ) + + for { + readMessageChan := make(chan struct{}) + + go func() { + _, message, err = conn.ReadMessage() + if err != nil { + if internal.IsUnexpectedCloseError(err, internal.CloseGoingAway, internal.CloseAbnormalClosure) { + logging.Logger().Debugf(s.serverMessage(fmt.Sprintf("Read error %v", err))) + } + } + close(readMessageChan) + }() + + select { + case <-s.stopChan: + <-readMessageChan + break + case <-readMessageChan: + } + + if nil != err { + select { + case <-s.stopChan: + break + case <-time.After(time.Second): + } + continue + } + + readChan <- message + } +} + +func handleWrite(s *Server, conn *internal.Conn, doneChan chan<- struct{}, stopChan <-chan struct{}, writeChan chan []byte) { + defer func() { + close(doneChan) + }() + + ticker := time.NewTicker(s.ServerHandler.GetPingPeriod()) + defer func() { + ticker.Stop() + }() + for { + select { + case message, ok := <-writeChan: + conn.SetWriteDeadline(time.Now().Add(s.ServerHandler.GetWriteTimeout())) + if !ok { + conn.WriteMessage(internal.CloseMessage, []byte{}) + return + } + + w, err := conn.NextWriter(internal.TextMessage) + if err != nil { + return + } + w.Write(message) + + if err := w.Close(); nil != err { + return + } + case <-ticker.C: + conn.SetWriteDeadline(time.Now().Add(s.ServerHandler.GetPingTimeout())) + if err := conn.WriteMessage(internal.PingMessage, nil); nil != err { + return + } + case <-s.stopChan: + break + } + } +} diff --git a/net/servlet.go b/net/servlet.go new file mode 100644 index 0000000..19db5f4 --- /dev/null +++ b/net/servlet.go @@ -0,0 +1,13 @@ +package net + +import ( + "net" + + "git.loafle.net/commons/server-go" +) + +type Servlet interface { + server.Servlet + + Handshake(servletCtx server.ServletCtx, conn net.Conn) error +} diff --git a/server-context.go b/server-context.go index a405214..4551a8c 100644 --- a/server-context.go +++ b/server-context.go @@ -4,16 +4,16 @@ import ( cuc "git.loafle.net/commons/util-go/context" ) -type ServerContext interface { +type ServerCtx interface { cuc.Context } -func NewServerContext(parent cuc.Context) ServerContext { - return &serverContext{ - ServerContext: cuc.NewContext(parent), +func NewServerCtx(parent cuc.Context) ServerCtx { + return &serverCtx{ + Context: cuc.NewContext(parent), } } -type serverContext struct { - ServerContext +type serverCtx struct { + cuc.Context } diff --git a/server-handler.go b/server-handler.go index 87c762f..7bcd61e 100644 --- a/server-handler.go +++ b/server-handler.go @@ -3,70 +3,188 @@ package server import ( "fmt" "net" + "time" ) type ServerHandler interface { GetName() string - GetMaxConnections() int + GetConcurrency() int + GetHandshakeTimeout() time.Duration + GetMaxMessageSize() int64 + GetReadBufferSize() int + GetWriteBufferSize() int + GetReadTimeout() time.Duration + GetWriteTimeout() time.Duration + GetPongTimeout() time.Duration + GetPingTimeout() time.Duration + GetPingPeriod() time.Duration - ServerContext() ServerContext + IsEnableCompression() bool - Init(serverCTX ServerContext) error - OnStart(serverCTX ServerContext) error - OnError(serverCTX ServerContext, conn net.Conn, status int, reason error) - OnStop(serverCTX ServerContext) - Destroy(serverCTX ServerContext) + ServerCtx() ServerCtx - Listen(serverCTX ServerContext) (net.Listener, error) - Servlet() Servlet + Init(serverCtx ServerCtx) error + OnStart(serverCtx ServerCtx) error + OnStop(serverCtx ServerCtx) + Destroy(serverCtx ServerCtx) - Validate() + Listener(serverCtx ServerCtx) (net.Listener, error) + + Validate() error } type ServerHandlers struct { ServerHandler - Name string - MaxConnections int + // Server name for sending in response headers. + // + // Default server name is used if left blank. + Name string + + // The maximum number of concurrent connections the server may serve. + // + // DefaultConcurrency is used if not set. + Concurrency int + + HandshakeTimeout time.Duration + + MaxMessageSize int64 + // Per-connection buffer size for requests' reading. + // This also limits the maximum header size. + // + // Increase this buffer if your clients send multi-KB RequestURIs + // and/or multi-KB headers (for example, BIG cookies). + // + // Default buffer size is used if not set. + ReadBufferSize int + // Per-connection buffer size for responses' writing. + // + // Default buffer size is used if not set. + WriteBufferSize int + // Maximum duration for reading the full request (including body). + // + // This also limits the maximum duration for idle keep-alive + // connections. + // + // By default request read timeout is unlimited. + ReadTimeout time.Duration + + // Maximum duration for writing the full response (including body). + // + // By default response write timeout is unlimited. + WriteTimeout time.Duration + + PongTimeout time.Duration + PingTimeout time.Duration + PingPeriod time.Duration + + EnableCompression bool } -func (sh *ServerHandlers) ServerContext() ServerContext { +func (sh *ServerHandlers) ServerCtx() ServerCtx { return nil } -func (sh *ServerHandlers) Init(serverCTX ServerContext) error { +func (sh *ServerHandlers) Init(serverCtx ServerCtx) error { return nil } -func (sh *ServerHandlers) OnStart(serverCTX ServerContext) error { +func (sh *ServerHandlers) OnStart(serverCtx ServerCtx) error { return nil } -func (sh *ServerHandlers) OnError(serverCTX ServerContext, conn net.Conn, status int, reason error) { -} - -func (sh *ServerHandlers) OnStop(serverCTX ServerContext) { +func (sh *ServerHandlers) OnStop(serverCtx ServerCtx) { } -func (sh *ServerHandlers) Destroy(serverCTX ServerContext) { +func (sh *ServerHandlers) Destroy(serverCtx ServerCtx) { } -func (sh *ServerHandlers) Listen(serverCTX ServerContext) (net.Listener, error) { - return nil, fmt.Errorf("Server: Method[ServerHandler.Listen] is not implemented") +func (sh *ServerHandlers) Listener(serverCtx ServerCtx) (net.Listener, error) { + return nil, fmt.Errorf("Server: Method[ServerHandler.Listener] is not implemented") } -func (sh *ServerHandlers) Servlet() Servlet { - return nil +func (sh *ServerHandlers) GetName() string { + return sh.Name } -func (sh *ServerHandlers) Validate() { +func (sh *ServerHandlers) GetConcurrency() int { + return sh.Concurrency +} + +func (sh *ServerHandlers) GetHandshakeTimeout() time.Duration { + return sh.HandshakeTimeout +} + +func (sh *ServerHandlers) GetMaxMessageSize() int64 { + return sh.MaxMessageSize +} + +func (sh *ServerHandlers) GetReadBufferSize() int { + return sh.ReadBufferSize +} + +func (sh *ServerHandlers) GetWriteBufferSize() int { + return sh.WriteBufferSize +} + +func (sh *ServerHandlers) GetReadTimeout() time.Duration { + return sh.ReadTimeout +} +func (sh *ServerHandlers) GetWriteTimeout() time.Duration { + return sh.WriteTimeout +} +func (sh *ServerHandlers) GetPongTimeout() time.Duration { + return sh.PongTimeout +} +func (sh *ServerHandlers) GetPingTimeout() time.Duration { + return sh.PingTimeout +} +func (sh *ServerHandlers) GetPingPeriod() time.Duration { + return sh.PingPeriod +} + +func (sh *ServerHandlers) IsEnableCompression() bool { + return sh.EnableCompression +} + +func (sh *ServerHandlers) Validate() error { if "" == sh.Name { sh.Name = "Server" } - if 0 >= sh.MaxConnections { - sh.MaxConnections = 0 + if sh.Concurrency <= 0 { + sh.Concurrency = DefaultConcurrency } + + if sh.HandshakeTimeout <= 0 { + sh.HandshakeTimeout = DefaultHandshakeTimeout + } + if sh.MaxMessageSize <= 0 { + sh.MaxMessageSize = DefaultMaxMessageSize + } + if sh.ReadBufferSize <= 0 { + sh.ReadBufferSize = DefaultReadBufferSize + } + if sh.WriteBufferSize <= 0 { + sh.WriteBufferSize = DefaultWriteBufferSize + } + if sh.ReadTimeout <= 0 { + sh.ReadTimeout = DefaultReadTimeout + } + if sh.WriteTimeout <= 0 { + sh.WriteTimeout = DefaultWriteTimeout + } + if sh.PongTimeout <= 0 { + sh.PongTimeout = DefaultPongTimeout + } + if sh.PingTimeout <= 0 { + sh.PingTimeout = DefaultPingTimeout + } + if sh.PingPeriod <= 0 { + sh.PingPeriod = DefaultPingPeriod + } + + return nil } diff --git a/server.go b/server.go deleted file mode 100644 index fae6f5e..0000000 --- a/server.go +++ /dev/null @@ -1,166 +0,0 @@ -package server - -import ( - "context" - "fmt" - "net" - "sync" - "sync/atomic" - "time" - - "git.loafle.net/commons/logging-go" -) - -type Server struct { - ServerHandler ServerHandler - - ctx ServerContext - servlets sync.Map - stopChan chan struct{} - stopWg sync.WaitGroup -} - -func (s *Server) ListenAndServe() error { - if s.stopChan != nil { - return fmt.Errorf("Server: server is already running. Stop it before starting it again") - } - - var ( - err error - listener net.Listener - ) - if nil == s.ServerHandler { - panic("Server: server handler must be specified.") - } - s.ServerHandler.Validate() - - s.ctx = s.ServerHandler.ServerContext() - if nil == s.ctx { - return fmt.Errorf("Server: ServerContext is nil") - } - - if err = s.ServerHandler.Init(s.ctx); nil != err { - return fmt.Errorf("Server: Initialization of server has been failed %v", err) - } - - if listener, err = s.ServerHandler.Listen(s.ctx); nil != err { - return err - } - - s.stopChan = make(chan struct{}) - s.stopWg.Add(1) - return s.handleLoop(listener) -} - -func (s *Server) Shutdown(ctx context.Context) error { - if s.stopChan == nil { - return fmt.Errorf("Server: server must be started before stopping it") - } - close(s.stopChan) - s.stopWg.Wait() - - s.ServerHandler.Destroy(s.ctx) - - s.stopChan = nil - - return nil -} - -func (s *Server) ConnectionSize() int { - var sz int - s.servlets.Range(func(k, v interface{}) bool { - sz++ - return true - }) - return sz -} - -func (s *Server) handleLoop(listener net.Listener) error { - var ( - stopping atomic.Value - conn net.Conn - err error - ) - - defer func() { - s.ServerHandler.OnStop(s.ctx) - - s.stopWg.Done() - }() - - if err = s.ServerHandler.OnStart(s.ctx); nil != err { - return err - } - - for { - acceptChan := make(chan struct{}) - - go func() { - if conn, err = listener.Accept(); err != nil { - if nil == stopping.Load() { - logging.Logger().Errorf("Server: Cannot accept new connection: [%s]", err) - } - } - close(acceptChan) - }() - - select { - case <-s.stopChan: - stopping.Store(true) - listener.Close() - <-acceptChan - return nil - case <-acceptChan: - } - - if nil != err { - select { - case <-s.stopChan: - return nil - case <-time.After(time.Second): - } - continue - } - - if 0 < s.ServerHandler.GetMaxConnections() { - sz := s.ConnectionSize() - if sz >= s.ServerHandler.GetMaxConnections() { - logging.Logger().Warnf("max connections size %d, refuse\n", sz) - conn.Close() - continue - } - } - - s.stopWg.Add(1) - go s.handleConnection(conn) - } -} - -func (s *Server) handleConnection(conn net.Conn) { - servlet := s.ServerHandler.Servlet() - - defer func() { - s.servlets.Delete(servlet) - s.stopWg.Done() - }() - - if nil == servlet { - logging.Logger().Errorf("Server: Servlet is nil") - } - s.servlets.Store(servlet, true) - - servletStopChan := make(chan struct{}) - doneChan := make(chan struct{}) - - go servlet.Handle(s.ctx, conn, doneChan, servletStopChan) - - select { - case <-doneChan: - close(servletStopChan) - conn.Close() - case <-s.stopChan: - close(servletStopChan) - conn.Close() - <-doneChan - } -} diff --git a/servlet-context.go b/servlet-context.go new file mode 100644 index 0000000..ce7939d --- /dev/null +++ b/servlet-context.go @@ -0,0 +1,28 @@ +package server + +import ( + cuc "git.loafle.net/commons/util-go/context" +) + +type ServletCtx interface { + cuc.Context + + ServerCtx() ServerCtx +} + +func NewServletContext(parent cuc.Context, serverCtx ServerCtx) ServletCtx { + return &servletCtx{ + Context: cuc.NewContext(parent), + serverCtx: serverCtx, + } +} + +type servletCtx struct { + cuc.Context + + serverCtx ServerCtx +} + +func (sc *servletCtx) ServerCtx() ServerCtx { + return sc.serverCtx +} diff --git a/servlet.go b/servlet.go index 7853f72..794445e 100644 --- a/servlet.go +++ b/servlet.go @@ -1,7 +1,16 @@ package server -import "net" +import ( + "git.loafle.net/commons/server-go/internal" +) type Servlet interface { - Handle(serverCTX ServerContext, conn net.Conn, doneChan chan<- struct{}, stopChan <-chan struct{}) + ServletCtx(serverCtx ServerCtx) ServletCtx + + Init(serverCtx ServerCtx) error + Destroy(serverCtx ServerCtx) + + OnConnect(servletCtx ServletCtx, conn *internal.Conn) + Handle(servletCtx ServletCtx, stopChan <-chan struct{}, doneChan chan<- struct{}, readChan <-chan []byte, writeChan chan<- []byte) + OnDisconnect(servletCtx ServletCtx) }