This commit is contained in:
crusader 2018-04-04 13:01:26 +09:00
parent 16c6290f63
commit f85b949e66
24 changed files with 3495 additions and 203 deletions

View File

@ -1,7 +1,29 @@
package server
const (
// DefaultConcurrency is the maximum number of concurrent connections
// the Server may serve by default (i.e. if Server.Concurrency isn't set).
DefaultConcurrency = 256 * 1024
// DefaultHandshakeTimeout is default value of websocket handshake Timeout
DefaultHandshakeTimeout = 0
DefaultKeepAlive = 0
// DefaultReadBufferSize is default value of Read Buffer Size
DefaultReadBufferSize = 0
// DefaultWriteBufferSize is default value of Write Buffer Size
DefaultWriteBufferSize = 0
// DefaultReadTimeout is default value of read timeout
DefaultReadTimeout = 0
// DefaultWriteTimeout is default value of write timeout
DefaultWriteTimeout = 0
// DefaultEnableCompression is default value of support compression
DefaultEnableCompression = false
// DefaultMaxMessageSize is default size for a message read from the peer
DefaultMaxMessageSize = 4096
// DefaultPongTimeout is default value of websocket pong Timeout
DefaultPongTimeout = 0
// DefaultPingTimeout is default value of websocket ping Timeout
DefaultPingTimeout = 0
// DefaultPingPeriod is default value of send ping period
DefaultPingPeriod = 0
)

View File

@ -0,0 +1,408 @@
package websocket
import (
"bufio"
"bytes"
"crypto/tls"
"encoding/base64"
"errors"
"fmt"
"io"
"io/ioutil"
"net"
"net/http"
"net/url"
"strings"
"time"
server "git.loafle.net/commons/server-go"
"git.loafle.net/commons/server-go/internal"
)
var errMalformedURL = errors.New("malformed ws or wss URL")
type Client struct {
Name string
URL string
RequestHeader http.Header
// NetDial specifies the dial function for creating TCP connections. If
// NetDial is nil, net.Dial is used.
NetDial func(network, addr string) (net.Conn, error)
// Proxy specifies a function to return a proxy for a given
// Request. If the function returns a non-nil error, the
// request is aborted with the provided error.
// If Proxy is nil or returns a nil *URL, no proxy is used.
Proxy func(*http.Request) (*url.URL, error)
TLSConfig *tls.Config
HandshakeTimeout time.Duration
// ReadBufferSize and WriteBufferSize specify I/O buffer sizes. If a buffer
// size is zero, then a useful default size is used. The I/O buffer sizes
// do not limit the size of the messages that can be sent or received.
ReadBufferSize, WriteBufferSize int
// Subprotocols specifies the client's requested subprotocols.
Subprotocols []string
// EnableCompression specifies if the client should attempt to negotiate
// per message compression (RFC 7692). Setting this value to true does not
// guarantee that compression will be supported. Currently only "no context
// takeover" modes are supported.
EnableCompression bool
// Jar specifies the cookie jar.
// If Jar is nil, cookies are not sent in requests and ignored
// in responses.
CookieJar http.CookieJar
// MaxMessageSize is the maximum size for a message read from the peer. If a
// message exceeds the limit, the connection sends a close frame to the peer
// and returns ErrReadLimit to the application.
MaxMessageSize int64
// WriteTimeout is the write deadline on the underlying network
// connection. After a write has timed out, the websocket state is corrupt and
// all future writes will return an error. A zero value for t means writes will
// not time out.
WriteTimeout time.Duration
// ReadTimeout is the read deadline on the underlying network connection.
// After a read has timed out, the websocket connection state is corrupt and
// all future reads will return an error. A zero value for t means reads will
// not time out.
ReadTimeout time.Duration
PongTimeout time.Duration
PingTimeout time.Duration
PingPeriod time.Duration
serverURL *url.URL
}
func (c *Client) Dial() (*internal.Conn, *http.Response, error) {
var (
err error
challengeKey string
netConn net.Conn
)
if err = c.Validate(); nil != err {
return nil, nil, err
}
challengeKey, err = generateChallengeKey()
if err != nil {
return nil, nil, err
}
req := &http.Request{
Method: "GET",
URL: c.serverURL,
Proto: "HTTP/1.1",
ProtoMajor: 1,
ProtoMinor: 1,
Header: make(http.Header),
Host: c.serverURL.Host,
}
// Set the cookies present in the cookie jar of the dialer
if nil != c.CookieJar {
for _, cookie := range c.CookieJar.Cookies(c.serverURL) {
req.AddCookie(cookie)
}
}
// Set the request headers using the capitalization for names and values in
// RFC examples. Although the capitalization shouldn't matter, there are
// servers that depend on it. The Header.Set method is not used because the
// method canonicalizes the header names.
req.Header["Upgrade"] = []string{"websocket"}
req.Header["Connection"] = []string{"Upgrade"}
req.Header["Sec-WebSocket-Key"] = []string{challengeKey}
req.Header["Sec-WebSocket-Version"] = []string{"13"}
if len(c.Subprotocols) > 0 {
req.Header["Sec-WebSocket-Protocol"] = []string{strings.Join(c.Subprotocols, ", ")}
}
for k, vs := range c.RequestHeader {
switch {
case k == "Host":
if len(vs) > 0 {
req.Host = vs[0]
}
case k == "Upgrade" ||
k == "Connection" ||
k == "Sec-Websocket-Key" ||
k == "Sec-Websocket-Version" ||
k == "Sec-Websocket-Extensions" ||
(k == "Sec-Websocket-Protocol" && len(c.Subprotocols) > 0):
return nil, nil, errors.New("websocket: duplicate header not allowed: " + k)
default:
req.Header[k] = vs
}
}
if c.EnableCompression {
req.Header.Set("Sec-Websocket-Extensions", "permessage-deflate; server_no_context_takeover; client_no_context_takeover")
}
hostPort, hostNoPort := hostPortNoPort(c.serverURL)
var proxyURL *url.URL
// Check wether the proxy method has been configured
if nil != c.Proxy {
proxyURL, err = c.Proxy(req)
if err != nil {
return nil, nil, err
}
}
var targetHostPort string
if proxyURL != nil {
targetHostPort, _ = hostPortNoPort(proxyURL)
} else {
targetHostPort = hostPort
}
var deadline time.Time
if 0 != c.HandshakeTimeout {
deadline = time.Now().Add(c.HandshakeTimeout)
}
netDial := c.NetDial
if netDial == nil {
netDialer := &net.Dialer{Deadline: deadline}
netDial = netDialer.Dial
}
netConn, err = netDial("tcp", targetHostPort)
if err != nil {
return nil, nil, err
}
defer func() {
if nil != netConn {
netConn.Close()
}
}()
err = netConn.SetDeadline(deadline)
if nil != err {
return nil, nil, err
}
if nil != proxyURL {
connectHeader := make(http.Header)
if user := proxyURL.User; nil != user {
proxyUser := user.Username()
if proxyPassword, passwordSet := user.Password(); passwordSet {
credential := base64.StdEncoding.EncodeToString([]byte(proxyUser + ":" + proxyPassword))
connectHeader.Set("Proxy-Authorization", "Basic "+credential)
}
}
connectReq := &http.Request{
Method: "CONNECT",
URL: &url.URL{Opaque: hostPort},
Host: hostPort,
Header: connectHeader,
}
connectReq.Write(netConn)
// Read response.
// Okay to use and discard buffered reader here, because
// TLS server will not speak until spoken to.
br := bufio.NewReader(netConn)
resp, err := http.ReadResponse(br, connectReq)
if err != nil {
return nil, nil, err
}
if resp.StatusCode != 200 {
f := strings.SplitN(resp.Status, " ", 2)
return nil, nil, errors.New(f[1])
}
}
if "https" == c.serverURL.Scheme {
cfg := cloneTLSConfig(c.TLSConfig)
if cfg.ServerName == "" {
cfg.ServerName = hostNoPort
}
tlsConn := tls.Client(netConn, cfg)
netConn = tlsConn
if err := tlsConn.Handshake(); err != nil {
return nil, nil, err
}
if !cfg.InsecureSkipVerify {
if err := tlsConn.VerifyHostname(cfg.ServerName); err != nil {
return nil, nil, err
}
}
}
conn := internal.NewConn(netConn, false, c.ReadBufferSize, c.WriteBufferSize)
if err := req.Write(netConn); err != nil {
return nil, nil, err
}
resp, err := http.ReadResponse(conn.BuffReader, req)
if err != nil {
return nil, nil, err
}
if nil != c.CookieJar {
if rc := resp.Cookies(); len(rc) > 0 {
c.CookieJar.SetCookies(c.serverURL, rc)
}
}
if resp.StatusCode != 101 ||
!strings.EqualFold(resp.Header.Get("Upgrade"), "websocket") ||
!strings.EqualFold(resp.Header.Get("Connection"), "upgrade") ||
resp.Header.Get("Sec-Websocket-Accept") != computeAcceptKey(challengeKey) {
// Before closing the network connection on return from this
// function, slurp up some of the response to aid application
// debugging.
buf := make([]byte, 1024)
n, _ := io.ReadFull(resp.Body, buf)
resp.Body = ioutil.NopCloser(bytes.NewReader(buf[:n]))
return nil, resp, internal.ErrBadHandshake
}
for _, ext := range httpParseExtensions(resp.Header) {
if ext[""] != "permessage-deflate" {
continue
}
_, snct := ext["server_no_context_takeover"]
_, cnct := ext["client_no_context_takeover"]
if !snct || !cnct {
return nil, resp, internal.ErrInvalidCompression
}
conn.NewCompressionWriter = internal.CompressNoContextTakeover
conn.NewDecompressionReader = internal.DecompressNoContextTakeover
break
}
resp.Body = ioutil.NopCloser(bytes.NewReader([]byte{}))
conn.Subprotocol = resp.Header.Get("Sec-Websocket-Protocol")
netConn.SetDeadline(time.Time{})
netConn = nil // to avoid close in defer.
return conn, resp, nil
}
func (c *Client) Validate() error {
if "" == c.Name {
c.Name = "Client"
}
if "" == c.URL {
return fmt.Errorf("Client: URL is not valid")
}
u, err := parseURL(c.URL)
if nil != err {
return err
}
switch u.Scheme {
case "ws":
u.Scheme = "http"
case "wss":
u.Scheme = "https"
default:
return errMalformedURL
}
if nil != u.User {
// User name and password are not allowed in websocket URIs.
return errMalformedURL
}
c.serverURL = u
if nil == c.Proxy {
c.Proxy = http.ProxyFromEnvironment
}
if 0 > c.HandshakeTimeout {
c.HandshakeTimeout = server.DefaultHandshakeTimeout
}
if 0 > c.ReadBufferSize {
c.ReadBufferSize = server.DefaultReadBufferSize
}
if 0 > c.WriteBufferSize {
c.WriteBufferSize = server.DefaultWriteBufferSize
}
return nil
}
// parseURL parses the URL.
//
// This function is a replacement for the standard library url.Parse function.
// In Go 1.4 and earlier, url.Parse loses information from the path.
func parseURL(s string) (*url.URL, error) {
// From the RFC:
//
// ws-URI = "ws:" "//" host [ ":" port ] path [ "?" query ]
// wss-URI = "wss:" "//" host [ ":" port ] path [ "?" query ]
var u url.URL
switch {
case strings.HasPrefix(s, "ws://"):
u.Scheme = "ws"
s = s[len("ws://"):]
case strings.HasPrefix(s, "wss://"):
u.Scheme = "wss"
s = s[len("wss://"):]
default:
return nil, errMalformedURL
}
if i := strings.Index(s, "?"); i >= 0 {
u.RawQuery = s[i+1:]
s = s[:i]
}
if i := strings.Index(s, "/"); i >= 0 {
u.Opaque = s[i:]
s = s[:i]
} else {
u.Opaque = "/"
}
u.Host = s
if strings.Contains(u.Host, "@") {
// Don't bother parsing user information because user information is
// not allowed in websocket URIs.
return nil, errMalformedURL
}
return &u, nil
}
func hostPortNoPort(u *url.URL) (hostPort, hostNoPort string) {
hostPort = u.Host
hostNoPort = u.Host
if i := strings.LastIndex(u.Host, ":"); i > strings.LastIndex(u.Host, "]") {
hostNoPort = hostNoPort[:i]
} else {
switch u.Scheme {
case "wss":
hostPort += ":443"
case "https":
hostPort += ":443"
default:
hostPort += ":80"
}
}
return hostPort, hostNoPort
}
func cloneTLSConfig(cfg *tls.Config) *tls.Config {
if cfg == nil {
return &tls.Config{}
}
return cfg.Clone()
}

View File

@ -0,0 +1,101 @@
package websocket
import (
"net/http"
"git.loafle.net/commons/server-go"
"github.com/valyala/fasthttp"
)
type ServerHandler interface {
server.ServerHandler
OnError(serverCtx server.ServerCtx, ctx *fasthttp.RequestCtx, status int, reason error)
RegisterServlet(path string, servlet Servlet)
Servlet(path string) Servlet
CheckOrigin(ctx *fasthttp.RequestCtx) bool
}
type ServerHandlers struct {
server.ServerHandlers
servlets map[string]Servlet
}
func (sh *ServerHandlers) Init(serverCtx server.ServerCtx) error {
if err := sh.ServerHandlers.Init(serverCtx); nil != err {
return err
}
if nil != sh.servlets {
for _, servlet := range sh.servlets {
if err := servlet.Init(serverCtx); nil != err {
return err
}
}
}
return nil
}
func (sh *ServerHandlers) Destroy(serverCtx server.ServerCtx) {
if nil != sh.servlets {
for _, servlet := range sh.servlets {
servlet.Destroy(serverCtx)
}
}
sh.ServerHandlers.Destroy(serverCtx)
}
func (sh *ServerHandlers) OnPing(msg string) error {
return nil
}
func (sh *ServerHandlers) OnPong(msg string) error {
return nil
}
func (sh *ServerHandlers) OnClose(code int, text string) error {
return nil
}
func (sh *ServerHandlers) OnError(serverCtx server.ServerCtx, ctx *fasthttp.RequestCtx, status int, reason error) {
ctx.Response.Header.Set("Sec-Websocket-Version", "13")
ctx.Error(http.StatusText(status), status)
}
func (sh *ServerHandlers) RegisterServlet(path string, servlet Servlet) {
if nil == sh.servlets {
sh.servlets = make(map[string]Servlet)
}
sh.servlets[path] = servlet
}
func (sh *ServerHandlers) Servlet(path string) Servlet {
var servlet Servlet
if path == "" && len(sh.servlets) == 1 {
for _, s := range sh.servlets {
servlet = s
}
} else if servlet = sh.servlets[path]; nil == servlet {
return nil
}
return servlet
}
func (sh *ServerHandlers) CheckOrigin(ctx *fasthttp.RequestCtx) bool {
return true
}
func (sh *ServerHandlers) Validate() error {
if err := sh.ServerHandlers.Validate(); nil != err {
return err
}
return nil
}

View File

@ -0,0 +1,286 @@
package websocket
import (
"context"
"fmt"
"net"
"net/http"
"sync"
"time"
"git.loafle.net/commons/logging-go"
"git.loafle.net/commons/server-go"
"git.loafle.net/commons/server-go/internal"
"github.com/valyala/fasthttp"
)
type Server struct {
ServerHandler ServerHandler
ctx server.ServerCtx
hs *fasthttp.Server
upgrader *Upgrader
connections sync.Map
stopChan chan struct{}
stopWg sync.WaitGroup
}
func (s *Server) ListenAndServe() error {
var (
err error
listener net.Listener
)
if nil == s.ServerHandler {
return fmt.Errorf("Server: server handler must be specified")
}
s.ServerHandler.Validate()
if s.stopChan != nil {
return fmt.Errorf(s.serverMessage("already running. Stop it before starting it again"))
}
s.ctx = s.ServerHandler.ServerCtx()
if nil == s.ctx {
return fmt.Errorf(s.serverMessage("ServerCtx is nil"))
}
s.hs = &fasthttp.Server{
Handler: s.httpHandler,
Name: s.ServerHandler.GetName(),
Concurrency: s.ServerHandler.GetConcurrency(),
ReadBufferSize: s.ServerHandler.GetReadBufferSize(),
WriteBufferSize: s.ServerHandler.GetWriteBufferSize(),
ReadTimeout: s.ServerHandler.GetReadTimeout(),
WriteTimeout: s.ServerHandler.GetWriteTimeout(),
}
s.upgrader = &Upgrader{
HandshakeTimeout: s.ServerHandler.GetHandshakeTimeout(),
ReadBufferSize: s.ServerHandler.GetReadBufferSize(),
WriteBufferSize: s.ServerHandler.GetWriteBufferSize(),
CheckOrigin: s.ServerHandler.CheckOrigin,
Error: s.onError,
EnableCompression: s.ServerHandler.IsEnableCompression(),
}
if err = s.ServerHandler.Init(s.ctx); nil != err {
return err
}
if listener, err = s.ServerHandler.Listener(s.ctx); nil != err {
return err
}
s.stopChan = make(chan struct{})
s.stopWg.Add(1)
return s.handleServer(listener)
}
func (s *Server) Shutdown(ctx context.Context) error {
if s.stopChan == nil {
return fmt.Errorf(s.serverMessage("server must be started before stopping it"))
}
close(s.stopChan)
s.stopWg.Wait()
s.ServerHandler.Destroy(s.ctx)
s.stopChan = nil
return nil
}
func (s *Server) ConnectionSize() int {
var sz int
s.connections.Range(func(k, v interface{}) bool {
sz++
return true
})
return sz
}
func (s *Server) serverMessage(msg string) string {
return fmt.Sprintf("Server[%s]: %s", s.ServerHandler.GetName(), msg)
}
func (s *Server) handleServer(listener net.Listener) error {
var (
err error
)
errChan := make(chan error)
defer func() {
if nil != listener {
listener.Close()
}
s.stopWg.Done()
}()
go func() {
if err := s.hs.Serve(listener); nil != err {
errChan <- err
return
}
close(errChan)
}()
select {
case err, _ := <-errChan:
if nil != err {
return err
}
}
defer func() {
s.ServerHandler.OnStop(s.ctx)
logging.Logger().Infof(s.serverMessage("Stopped"))
}()
if err = s.ServerHandler.OnStart(s.ctx); nil != err {
return err
}
logging.Logger().Infof(s.serverMessage("Started"))
select {
case <-s.stopChan:
listener.Close()
listener = nil
}
return nil
}
func (s *Server) httpHandler(ctx *fasthttp.RequestCtx) {
path := string(ctx.Path())
var (
servlet Servlet
err error
)
if 0 < s.ServerHandler.GetConcurrency() {
sz := s.ConnectionSize()
if sz >= s.ServerHandler.GetConcurrency() {
logging.Logger().Warnf(s.serverMessage(fmt.Sprintf("max connections size %d, refuse", sz)))
s.onError(ctx, fasthttp.StatusServiceUnavailable, err)
return
}
}
if servlet = s.ServerHandler.Servlet(path); nil == servlet {
s.onError(ctx, fasthttp.StatusInternalServerError, err)
return
}
var responseHeader *fasthttp.ResponseHeader
servletCtx := servlet.ServletCtx(s.ctx)
if responseHeader, err = servlet.Handshake(servletCtx, ctx); nil != err {
s.onError(ctx, http.StatusNotAcceptable, fmt.Errorf("Handshake err: %v", err))
return
}
s.upgrader.Upgrade(ctx, responseHeader, func(conn *internal.Conn, err error) {
if err != nil {
s.onError(ctx, fasthttp.StatusInternalServerError, err)
return
}
s.stopWg.Add(1)
go s.handleConnection(servlet, servletCtx, conn)
})
}
func (s *Server) handleConnection(servlet Servlet, servletCtx server.ServletCtx, conn *internal.Conn) {
addr := conn.RemoteAddr()
defer func() {
s.connections.Delete(conn)
conn.Close()
servlet.OnDisconnect(servletCtx)
logging.Logger().Infof(s.serverMessage(fmt.Sprintf("Client[%s] has been disconnected", addr)))
s.stopWg.Done()
}()
logging.Logger().Infof(s.serverMessage(fmt.Sprintf("Client[%s] has been connected", addr)))
s.connections.Store(conn, true)
servlet.OnConnect(servletCtx, conn)
servletStopChan := make(chan struct{})
doneChan := make(chan struct{})
readChan := make(chan []byte)
writeChan := make(chan []byte)
go servlet.Handle(servletCtx, doneChan, servletStopChan, readChan, writeChan)
go handleRead(s, conn, readChan)
go handleWrite(s, conn, writeChan)
select {
case <-doneChan:
close(servletStopChan)
case <-s.stopChan:
close(servletStopChan)
<-doneChan
}
}
func (s *Server) onError(ctx *fasthttp.RequestCtx, status int, reason error) {
s.ServerHandler.OnError(s.ctx, ctx, status, reason)
}
func handleRead(s *Server, conn *internal.Conn, readChan chan []byte) {
conn.SetReadLimit(s.ServerHandler.GetMaxMessageSize())
conn.SetReadDeadline(time.Now().Add(s.ServerHandler.GetReadTimeout()))
conn.SetPongHandler(func(string) error {
conn.SetReadDeadline(time.Now().Add(s.ServerHandler.GetPongTimeout()))
return nil
})
for {
_, message, err := conn.ReadMessage()
if err != nil {
if internal.IsUnexpectedCloseError(err, internal.CloseGoingAway, internal.CloseAbnormalClosure) {
logging.Logger().Debugf(s.serverMessage(fmt.Sprintf("Read error %v", err)))
}
break
}
readChan <- message
}
}
func handleWrite(s *Server, conn *internal.Conn, writeChan chan []byte) {
ticker := time.NewTicker(s.ServerHandler.GetPingPeriod())
defer func() {
ticker.Stop()
}()
for {
select {
case message, ok := <-writeChan:
conn.SetWriteDeadline(time.Now().Add(s.ServerHandler.GetWriteTimeout()))
if !ok {
conn.WriteMessage(internal.CloseMessage, []byte{})
return
}
w, err := conn.NextWriter(internal.TextMessage)
if err != nil {
return
}
w.Write(message)
if err := w.Close(); nil != err {
return
}
case <-ticker.C:
conn.SetWriteDeadline(time.Now().Add(s.ServerHandler.GetPingTimeout()))
if err := conn.WriteMessage(internal.PingMessage, nil); nil != err {
return
}
}
}
}

View File

@ -0,0 +1,12 @@
package websocket
import (
"git.loafle.net/commons/server-go"
"github.com/valyala/fasthttp"
)
type Servlet interface {
server.Servlet
Handshake(servletCtx server.ServletCtx, ctx *fasthttp.RequestCtx) (*fasthttp.ResponseHeader, error)
}

View File

@ -0,0 +1,277 @@
// Copyright 2013 The Gorilla WebSocket Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package websocket
import (
"net"
"net/http"
"net/url"
"strings"
"time"
"git.loafle.net/commons/server-go/internal"
"github.com/valyala/fasthttp"
)
type (
OnUpgradeFunc func(*internal.Conn, error)
)
// HandshakeError describes an error with the handshake from the peer.
type HandshakeError struct {
message string
}
func (e HandshakeError) Error() string { return e.message }
// Upgrader specifies parameters for upgrading an HTTP connection to a
// WebSocket connection.
type Upgrader struct {
// HandshakeTimeout specifies the duration for the handshake to complete.
HandshakeTimeout time.Duration
// ReadBufferSize and WriteBufferSize specify I/O buffer sizes. If a buffer
// size is zero, then buffers allocated by the HTTP server are used. The
// I/O buffer sizes do not limit the size of the messages that can be sent
// or received.
ReadBufferSize, WriteBufferSize int
// Subprotocols specifies the server's supported protocols in order of
// preference. If this field is set, then the Upgrade method negotiates a
// subprotocol by selecting the first match in this list with a protocol
// requested by the client.
Subprotocols []string
// Error specifies the function for generating HTTP error responses. If Error
// is nil, then http.Error is used to generate the HTTP response.
Error func(ctx *fasthttp.RequestCtx, status int, reason error)
// CheckOrigin returns true if the request Origin header is acceptable. If
// CheckOrigin is nil, the host in the Origin header must not be set or
// must match the host of the request.
CheckOrigin func(ctx *fasthttp.RequestCtx) bool
// EnableCompression specify if the server should attempt to negotiate per
// message compression (RFC 7692). Setting this value to true does not
// guarantee that compression will be supported. Currently only "no context
// takeover" modes are supported.
EnableCompression bool
}
func (u *Upgrader) returnError(ctx *fasthttp.RequestCtx, status int, reason string) (*internal.Conn, error) {
err := HandshakeError{reason}
if u.Error != nil {
u.Error(ctx, status, err)
} else {
ctx.Response.Header.Set("Sec-Websocket-Version", "13")
ctx.Error(http.StatusText(status), status)
}
return nil, err
}
// checkSameOrigin returns true if the origin is not set or is equal to the request host.
func checkSameOrigin(ctx *fasthttp.RequestCtx) bool {
origin := string(ctx.Request.Header.Peek("Origin"))
if len(origin) == 0 {
return true
}
u, err := url.Parse(origin)
if err != nil {
return false
}
return u.Host == string(ctx.Host())
}
func (u *Upgrader) selectSubprotocol(ctx *fasthttp.RequestCtx, responseHeader *fasthttp.ResponseHeader) string {
if u.Subprotocols != nil {
clientProtocols := Subprotocols(ctx)
for _, serverProtocol := range u.Subprotocols {
for _, clientProtocol := range clientProtocols {
if clientProtocol == serverProtocol {
return clientProtocol
}
}
}
} else if responseHeader != nil {
return string(responseHeader.Peek("Sec-Websocket-Protocol"))
}
return ""
}
// Upgrade upgrades the HTTP server connection to the WebSocket protocol.
//
// The responseHeader is included in the response to the client's upgrade
// request. Use the responseHeader to specify cookies (Set-Cookie) and the
// application negotiated subprotocol (Sec-Websocket-Protocol).
//
// If the upgrade fails, then Upgrade replies to the client with an HTTP error
// response.
func (u *Upgrader) Upgrade(ctx *fasthttp.RequestCtx, responseHeader *fasthttp.ResponseHeader, cb OnUpgradeFunc) {
if !ctx.IsGet() {
cb(u.returnError(ctx, fasthttp.StatusMethodNotAllowed, "websocket: not a websocket handshake: request method is not GET"))
return
}
if nil != responseHeader {
if v := responseHeader.Peek("Sec-Websocket-Extensions"); nil != v {
cb(u.returnError(ctx, fasthttp.StatusInternalServerError, "websocket: application specific 'Sec-Websocket-Extensions' headers are unsupported"))
return
}
}
if !tokenListContainsValue(&ctx.Request.Header, "Connection", "upgrade") {
cb(u.returnError(ctx, fasthttp.StatusBadRequest, "websocket: not a websocket handshake: 'upgrade' token not found in 'Connection' header"))
return
}
if !tokenListContainsValue(&ctx.Request.Header, "Upgrade", "websocket") {
cb(u.returnError(ctx, fasthttp.StatusBadRequest, "websocket: not a websocket handshake: 'websocket' token not found in 'Upgrade' header"))
return
}
if !tokenListContainsValue(&ctx.Request.Header, "Sec-Websocket-Version", "13") {
cb(u.returnError(ctx, fasthttp.StatusBadRequest, "websocket: unsupported version: 13 not found in 'Sec-Websocket-Version' header"))
return
}
checkOrigin := u.CheckOrigin
if checkOrigin == nil {
checkOrigin = checkSameOrigin
}
if !checkOrigin(ctx) {
cb(u.returnError(ctx, fasthttp.StatusForbidden, "websocket: 'Origin' header value not allowed"))
return
}
challengeKey := string(ctx.Request.Header.Peek("Sec-Websocket-Key"))
if challengeKey == "" {
cb(u.returnError(ctx, fasthttp.StatusBadRequest, "websocket: not a websocket handshake: `Sec-Websocket-Key' header is missing or blank"))
return
}
subprotocol := u.selectSubprotocol(ctx, responseHeader)
// Negotiate PMCE
var compress bool
if u.EnableCompression {
for _, ext := range parseExtensions(&ctx.Request.Header) {
if ext[""] != "permessage-deflate" {
continue
}
compress = true
break
}
}
ctx.SetStatusCode(fasthttp.StatusSwitchingProtocols)
ctx.Response.Header.Set("Upgrade", "websocket")
ctx.Response.Header.Set("Connection", "Upgrade")
ctx.Response.Header.Set("Sec-Websocket-Accept", computeAcceptKey(challengeKey))
if subprotocol != "" {
ctx.Response.Header.Set("Sec-Websocket-Protocol", subprotocol)
}
if compress {
ctx.Response.Header.Set("Sec-Websocket-Extensions", "permessage-deflate; server_no_context_takeover; client_no_context_takeover")
}
if nil != responseHeader {
responseHeader.VisitAll(func(key, value []byte) {
k := string(key)
v := string(value)
if k == "Sec-Websocket-Protocol" {
return
}
ctx.Response.Header.Set(k, v)
})
}
h := &fasthttp.RequestHeader{}
//copy request headers in order to have access inside the Conn after
ctx.Request.Header.CopyTo(h)
ctx.Hijack(func(netConn net.Conn) {
c := internal.NewConn(netConn, true, u.ReadBufferSize, u.WriteBufferSize)
c.Subprotocol = subprotocol
if compress {
c.NewCompressionWriter = internal.CompressNoContextTakeover
c.NewDecompressionReader = internal.DecompressNoContextTakeover
}
// Clear deadlines set by HTTP server.
netConn.SetDeadline(time.Time{})
if u.HandshakeTimeout > 0 {
netConn.SetWriteDeadline(time.Now().Add(u.HandshakeTimeout))
}
if u.HandshakeTimeout > 0 {
netConn.SetWriteDeadline(time.Time{})
}
cb(c, nil)
})
}
// Upgrade upgrades the HTTP server connection to the WebSocket protocol.
//
// This function is deprecated, use websocket.Upgrader instead.
//
// The application is responsible for checking the request origin before
// calling Upgrade. An example implementation of the same origin policy is:
//
// if req.Header.Get("Origin") != "http://"+req.Host {
// http.Error(w, "Origin not allowed", 403)
// return
// }
//
// If the endpoint supports subprotocols, then the application is responsible
// for negotiating the protocol used on the connection. Use the Subprotocols()
// function to get the subprotocols requested by the client. Use the
// Sec-Websocket-Protocol response header to specify the subprotocol selected
// by the application.
//
// The responseHeader is included in the response to the client's upgrade
// request. Use the responseHeader to specify cookies (Set-Cookie) and the
// negotiated subprotocol (Sec-Websocket-Protocol).
//
// The connection buffers IO to the underlying network connection. The
// readBufSize and writeBufSize parameters specify the size of the buffers to
// use. Messages can be larger than the buffers.
//
// If the request is not a valid WebSocket handshake, then Upgrade returns an
// error of type HandshakeError. Applications should handle this error by
// replying to the client with an HTTP error response.
func Upgrade(ctx *fasthttp.RequestCtx, responseHeader *fasthttp.ResponseHeader, readBufSize, writeBufSize int, cb OnUpgradeFunc) {
u := Upgrader{ReadBufferSize: readBufSize, WriteBufferSize: writeBufSize}
u.Error = func(ctx *fasthttp.RequestCtx, status int, reason error) {
// don't return errors to maintain backwards compatibility
}
u.CheckOrigin = func(ctx *fasthttp.RequestCtx) bool {
// allow all connections by default
return true
}
u.Upgrade(ctx, responseHeader, cb)
}
// Subprotocols returns the subprotocols requested by the client in the
// Sec-Websocket-Protocol header.
func Subprotocols(ctx *fasthttp.RequestCtx) []string {
h := strings.TrimSpace(string(ctx.Request.Header.Peek("Sec-Websocket-Protocol")))
if h == "" {
return nil
}
protocols := strings.Split(h, ",")
for i := range protocols {
protocols[i] = strings.TrimSpace(protocols[i])
}
return protocols
}
// IsWebSocketUpgrade returns true if the client requested upgrade to the
// WebSocket protocol.
func IsWebSocketUpgrade(ctx *fasthttp.RequestCtx) bool {
return tokenListContainsValue(&ctx.Request.Header, "Connection", "upgrade") &&
tokenListContainsValue(&ctx.Request.Header, "Upgrade", "websocket")
}

272
fasthttp/websocket/util.go Normal file
View File

@ -0,0 +1,272 @@
// Copyright 2013 The Gorilla WebSocket Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package websocket
import (
"crypto/rand"
"crypto/sha1"
"encoding/base64"
"io"
"net/http"
"strings"
"github.com/valyala/fasthttp"
)
var keyGUID = []byte("258EAFA5-E914-47DA-95CA-C5AB0DC85B11")
func computeAcceptKey(challengeKey string) string {
h := sha1.New()
h.Write([]byte(challengeKey))
h.Write(keyGUID)
return base64.StdEncoding.EncodeToString(h.Sum(nil))
}
func generateChallengeKey() (string, error) {
p := make([]byte, 16)
if _, err := io.ReadFull(rand.Reader, p); err != nil {
return "", err
}
return base64.StdEncoding.EncodeToString(p), nil
}
// Octet types from RFC 2616.
var octetTypes [256]byte
const (
isTokenOctet = 1 << iota
isSpaceOctet
)
func init() {
// From RFC 2616
//
// OCTET = <any 8-bit sequence of data>
// CHAR = <any US-ASCII character (octets 0 - 127)>
// CTL = <any US-ASCII control character (octets 0 - 31) and DEL (127)>
// CR = <US-ASCII CR, carriage return (13)>
// LF = <US-ASCII LF, linefeed (10)>
// SP = <US-ASCII SP, space (32)>
// HT = <US-ASCII HT, horizontal-tab (9)>
// <"> = <US-ASCII double-quote mark (34)>
// CRLF = CR LF
// LWS = [CRLF] 1*( SP | HT )
// TEXT = <any OCTET except CTLs, but including LWS>
// separators = "(" | ")" | "<" | ">" | "@" | "," | ";" | ":" | "\" | <">
// | "/" | "[" | "]" | "?" | "=" | "{" | "}" | SP | HT
// token = 1*<any CHAR except CTLs or separators>
// qdtext = <any TEXT except <">>
for c := 0; c < 256; c++ {
var t byte
isCtl := c <= 31 || c == 127
isChar := 0 <= c && c <= 127
isSeparator := strings.IndexRune(" \t\"(),/:;<=>?@[]\\{}", rune(c)) >= 0
if strings.IndexRune(" \t\r\n", rune(c)) >= 0 {
t |= isSpaceOctet
}
if isChar && !isCtl && !isSeparator {
t |= isTokenOctet
}
octetTypes[c] = t
}
}
func skipSpace(s string) (rest string) {
i := 0
for ; i < len(s); i++ {
if octetTypes[s[i]]&isSpaceOctet == 0 {
break
}
}
return s[i:]
}
func nextToken(s string) (token string, rest string) {
i := 0
for ; i < len(s); i++ {
if octetTypes[s[i]]&isTokenOctet == 0 {
break
}
}
return s[:i], s[i:]
}
func nextTokenOrQuoted(s string) (value string, rest string) {
if !strings.HasPrefix(s, "\"") {
return nextToken(s)
}
s = s[1:]
for i := 0; i < len(s); i++ {
switch s[i] {
case '"':
return s[:i], s[i+1:]
case '\\':
p := make([]byte, len(s)-1)
j := copy(p, s[:i])
escape := true
for i = i + 1; i < len(s); i++ {
b := s[i]
switch {
case escape:
escape = false
p[j] = b
j++
case b == '\\':
escape = true
case b == '"':
return string(p[:j]), s[i+1:]
default:
p[j] = b
j++
}
}
return "", ""
}
}
return "", ""
}
// tokenListContainsValue returns true if the 1#token header with the given
// name contains token.
func tokenListContainsValue(header *fasthttp.RequestHeader, name string, value string) bool {
s := string(header.Peek(name))
for {
var t string
t, s = nextToken(skipSpace(s))
if t == "" {
break
}
s = skipSpace(s)
if s != "" && s[0] != ',' {
break
}
if strings.EqualFold(t, value) {
return true
}
if s == "" {
break
}
s = s[1:]
}
return false
}
// parseExtensiosn parses WebSocket extensions from a header.
func parseExtensions(header *fasthttp.RequestHeader) []map[string]string {
// From RFC 6455:
//
// Sec-WebSocket-Extensions = extension-list
// extension-list = 1#extension
// extension = extension-token *( ";" extension-param )
// extension-token = registered-token
// registered-token = token
// extension-param = token [ "=" (token | quoted-string) ]
// ;When using the quoted-string syntax variant, the value
// ;after quoted-string unescaping MUST conform to the
// ;'token' ABNF.
s := string(header.Peek("Sec-Websocket-Extensions"))
var result []map[string]string
headers:
for {
var t string
t, s = nextToken(skipSpace(s))
if t == "" {
break headers
}
ext := map[string]string{"": t}
for {
s = skipSpace(s)
if !strings.HasPrefix(s, ";") {
break
}
var k string
k, s = nextToken(skipSpace(s[1:]))
if k == "" {
break headers
}
s = skipSpace(s)
var v string
if strings.HasPrefix(s, "=") {
v, s = nextTokenOrQuoted(skipSpace(s[1:]))
s = skipSpace(s)
}
if s != "" && s[0] != ',' && s[0] != ';' {
break headers
}
ext[k] = v
}
if s != "" && s[0] != ',' {
break headers
}
result = append(result, ext)
if s == "" {
break headers
}
s = s[1:]
}
return result
}
// parseExtensiosn parses WebSocket extensions from a header.
func httpParseExtensions(header http.Header) []map[string]string {
// From RFC 6455:
//
// Sec-WebSocket-Extensions = extension-list
// extension-list = 1#extension
// extension = extension-token *( ";" extension-param )
// extension-token = registered-token
// registered-token = token
// extension-param = token [ "=" (token | quoted-string) ]
// ;When using the quoted-string syntax variant, the value
// ;after quoted-string unescaping MUST conform to the
// ;'token' ABNF.
var result []map[string]string
headers:
for _, s := range header["Sec-Websocket-Extensions"] {
for {
var t string
t, s = nextToken(skipSpace(s))
if t == "" {
continue headers
}
ext := map[string]string{"": t}
for {
s = skipSpace(s)
if !strings.HasPrefix(s, ";") {
break
}
var k string
k, s = nextToken(skipSpace(s[1:]))
if k == "" {
continue headers
}
s = skipSpace(s)
var v string
if strings.HasPrefix(s, "=") {
v, s = nextTokenOrQuoted(skipSpace(s[1:]))
s = skipSpace(s)
}
if s != "" && s[0] != ',' && s[0] != ';' {
continue headers
}
ext[k] = v
}
if s != "" && s[0] != ',' {
continue headers
}
result = append(result, ext)
if s == "" {
continue headers
}
s = s[1:]
}
}
return result
}

144
internal/compression.go Normal file
View File

@ -0,0 +1,144 @@
package internal
import (
"compress/flate"
"errors"
"io"
"strings"
"sync"
)
const (
minCompressionLevel = -2 // flate.HuffmanOnly not defined in Go < 1.6
maxCompressionLevel = flate.BestCompression
defaultCompressionLevel = 1
)
var (
flateWriterPools [maxCompressionLevel - minCompressionLevel + 1]sync.Pool
flateReaderPool = sync.Pool{New: func() interface{} {
return flate.NewReader(nil)
}}
)
func DecompressNoContextTakeover(r io.Reader) io.ReadCloser {
const tail =
// Add four bytes as specified in RFC
"\x00\x00\xff\xff" +
// Add final block to squelch unexpected EOF error from flate reader.
"\x01\x00\x00\xff\xff"
fr, _ := flateReaderPool.Get().(io.ReadCloser)
fr.(flate.Resetter).Reset(io.MultiReader(r, strings.NewReader(tail)), nil)
return &flateReadWrapper{fr}
}
func isValidCompressionLevel(level int) bool {
return minCompressionLevel <= level && level <= maxCompressionLevel
}
func CompressNoContextTakeover(w io.WriteCloser, level int) io.WriteCloser {
p := &flateWriterPools[level-minCompressionLevel]
tw := &truncWriter{w: w}
fw, _ := p.Get().(*flate.Writer)
if fw == nil {
fw, _ = flate.NewWriter(tw, level)
} else {
fw.Reset(tw)
}
return &flateWriteWrapper{fw: fw, tw: tw, p: p}
}
// truncWriter is an io.Writer that writes all but the last four bytes of the
// stream to another io.Writer.
type truncWriter struct {
w io.WriteCloser
n int
p [4]byte
}
func (w *truncWriter) Write(p []byte) (int, error) {
n := 0
// fill buffer first for simplicity.
if w.n < len(w.p) {
n = copy(w.p[w.n:], p)
p = p[n:]
w.n += n
if len(p) == 0 {
return n, nil
}
}
m := len(p)
if m > len(w.p) {
m = len(w.p)
}
if nn, err := w.w.Write(w.p[:m]); err != nil {
return n + nn, err
}
copy(w.p[:], w.p[m:])
copy(w.p[len(w.p)-m:], p[len(p)-m:])
nn, err := w.w.Write(p[:len(p)-m])
return n + nn, err
}
type flateWriteWrapper struct {
fw *flate.Writer
tw *truncWriter
p *sync.Pool
}
func (w *flateWriteWrapper) Write(p []byte) (int, error) {
if w.fw == nil {
return 0, errWriteClosed
}
return w.fw.Write(p)
}
func (w *flateWriteWrapper) Close() error {
if w.fw == nil {
return errWriteClosed
}
err1 := w.fw.Flush()
w.p.Put(w.fw)
w.fw = nil
if w.tw.p != [4]byte{0, 0, 0xff, 0xff} {
return errors.New("websocket: internal error, unexpected bytes at end of flate stream")
}
err2 := w.tw.w.Close()
if err1 != nil {
return err1
}
return err2
}
type flateReadWrapper struct {
fr io.ReadCloser
}
func (r *flateReadWrapper) Read(p []byte) (int, error) {
if r.fr == nil {
return 0, io.ErrClosedPipe
}
n, err := r.fr.Read(p)
if err == io.EOF {
// Preemptively place the reader back in the pool. This helps with
// scenarios where the application does not call NextReader() soon after
// this final read.
r.Close()
}
return n, err
}
func (r *flateReadWrapper) Close() error {
if r.fr == nil {
return io.ErrClosedPipe
}
err := r.fr.Close()
flateReaderPool.Put(r.fr)
r.fr = nil
return err
}

55
internal/conn-json.go Normal file
View File

@ -0,0 +1,55 @@
// Copyright 2013 The Gorilla WebSocket Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package internal
import (
"encoding/json"
"io"
)
// WriteJSON is deprecated, use c.WriteJSON instead.
func WriteJSON(c *Conn, v interface{}) error {
return c.WriteJSON(v)
}
// WriteJSON writes the JSON encoding of v to the connection.
//
// See the documentation for encoding/json Marshal for details about the
// conversion of Go values to JSON.
func (c *Conn) WriteJSON(v interface{}) error {
w, err := c.NextWriter(TextMessage)
if err != nil {
return err
}
err1 := json.NewEncoder(w).Encode(v)
err2 := w.Close()
if err1 != nil {
return err1
}
return err2
}
// ReadJSON is deprecated, use c.ReadJSON instead.
func ReadJSON(c *Conn, v interface{}) error {
return c.ReadJSON(v)
}
// ReadJSON reads the next JSON-encoded message from the connection and stores
// it in the value pointed to by v.
//
// See the documentation for the encoding/json Unmarshal function for details
// about the conversion of JSON to a Go value.
func (c *Conn) ReadJSON(v interface{}) error {
_, r, err := c.NextReader()
if err != nil {
return err
}
err = json.NewDecoder(r).Decode(v)
if err == io.EOF {
// One value is expected in the message.
err = io.ErrUnexpectedEOF
}
return err
}

18
internal/conn-read.go Normal file
View File

@ -0,0 +1,18 @@
// Copyright 2016 The Gorilla WebSocket Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
// +build go1.5
package internal
import "io"
func (c *Conn) read(n int) ([]byte, error) {
p, err := c.BuffReader.Peek(n)
if err == io.EOF {
err = errUnexpectedEOF
}
c.BuffReader.Discard(len(p))
return p, err
}

1040
internal/conn.go Normal file

File diff suppressed because it is too large Load Diff

111
internal/error.go Normal file
View File

@ -0,0 +1,111 @@
package internal
import (
"errors"
"io"
"strconv"
)
// ErrCloseSent is returned when the application writes a message to the
// connection after sending a close message.
var ErrCloseSent = errors.New("socket: close sent")
// ErrReadLimit is returned when reading a message that is larger than the
// read limit set for the connection.
var ErrReadLimit = errors.New("socket: read limit exceeded")
// netError satisfies the net Error interface.
type netError struct {
msg string
temporary bool
timeout bool
}
func (e *netError) Error() string { return e.msg }
func (e *netError) Temporary() bool { return e.temporary }
func (e *netError) Timeout() bool { return e.timeout }
// CloseError represents close frame.
type CloseError struct {
// Code is defined in RFC 6455, section 11.7.
Code int
// Text is the optional text payload.
Text string
}
func (e *CloseError) Error() string {
s := []byte("socket: close ")
s = strconv.AppendInt(s, int64(e.Code), 10)
switch e.Code {
case CloseNormalClosure:
s = append(s, " (normal)"...)
case CloseGoingAway:
s = append(s, " (going away)"...)
case CloseProtocolError:
s = append(s, " (protocol error)"...)
case CloseUnsupportedData:
s = append(s, " (unsupported data)"...)
case CloseNoStatusReceived:
s = append(s, " (no status)"...)
case CloseAbnormalClosure:
s = append(s, " (abnormal closure)"...)
case CloseInvalidFramePayloadData:
s = append(s, " (invalid payload data)"...)
case ClosePolicyViolation:
s = append(s, " (policy violation)"...)
case CloseMessageTooBig:
s = append(s, " (message too big)"...)
case CloseMandatoryExtension:
s = append(s, " (mandatory extension missing)"...)
case CloseInternalServerErr:
s = append(s, " (internal server error)"...)
case CloseTLSHandshake:
s = append(s, " (TLS handshake error)"...)
}
if e.Text != "" {
s = append(s, ": "...)
s = append(s, e.Text...)
}
return string(s)
}
// IsCloseError returns boolean indicating whether the error is a *CloseError
// with one of the specified codes.
func IsCloseError(err error, codes ...int) bool {
if e, ok := err.(*CloseError); ok {
for _, code := range codes {
if e.Code == code {
return true
}
}
}
return false
}
// IsUnexpectedCloseError returns boolean indicating whether the error is a
// *CloseError with a code not in the list of expected codes.
func IsUnexpectedCloseError(err error, expectedCodes ...int) bool {
if e, ok := err.(*CloseError); ok {
for _, code := range expectedCodes {
if e.Code == code {
return false
}
}
return true
}
return false
}
var (
errWriteTimeout = &netError{msg: "socket: write timeout", timeout: true, temporary: true}
errUnexpectedEOF = &CloseError{Code: CloseAbnormalClosure, Text: io.ErrUnexpectedEOF.Error()}
errBadWriteOpCode = errors.New("socket: bad write message type")
errWriteClosed = errors.New("socket: write closed")
errInvalidControlFrame = errors.New("socket: invalid control frame")
// ErrBadHandshake is returned when the server response to opening handshake is
// invalid.
ErrBadHandshake = errors.New("socket: bad handshake")
ErrInvalidCompression = errors.New("socket: invalid compression negotiation")
)

49
internal/mask.go Normal file
View File

@ -0,0 +1,49 @@
package internal
import "unsafe"
const wordSize = int(unsafe.Sizeof(uintptr(0)))
func maskBytes(key [4]byte, pos int, b []byte) int {
// Mask one byte at a time for small buffers.
if len(b) < 2*wordSize {
for i := range b {
b[i] ^= key[pos&3]
pos++
}
return pos & 3
}
// Mask one byte at a time to word boundary.
if n := int(uintptr(unsafe.Pointer(&b[0]))) % wordSize; n != 0 {
n = wordSize - n
for i := range b[:n] {
b[i] ^= key[pos&3]
pos++
}
b = b[n:]
}
// Create aligned word size key.
var k [wordSize]byte
for i := range k {
k[i] = key[(pos+i)&3]
}
kw := *(*uintptr)(unsafe.Pointer(&k))
// Mask one word at a time.
n := (len(b) / wordSize) * wordSize
for i := 0; i < n; i += wordSize {
*(*uintptr)(unsafe.Pointer(uintptr(unsafe.Pointer(&b[0])) + uintptr(i))) ^= kw
}
// Mask one byte at a time for remaining bytes.
b = b[n:]
for i := range b {
b[i] ^= key[pos&3]
pos++
}
return pos & 3
}

100
internal/prepared.go Normal file
View File

@ -0,0 +1,100 @@
package internal
import (
"bytes"
"net"
"sync"
"time"
)
// PreparedMessage caches on the wire representations of a message payload.
// Use PreparedMessage to efficiently send a message payload to multiple
// connections. PreparedMessage is especially useful when compression is used
// because the CPU and memory expensive compression operation can be executed
// once for a given set of compression options.
type PreparedMessage struct {
messageType int
data []byte
err error
mu sync.Mutex
frames map[prepareKey]*preparedFrame
}
// prepareKey defines a unique set of options to cache prepared frames in PreparedMessage.
type prepareKey struct {
isServer bool
compress bool
compressionLevel int
}
// preparedFrame contains data in wire representation.
type preparedFrame struct {
once sync.Once
data []byte
}
// NewPreparedMessage returns an initialized PreparedMessage. You can then send
// it to connection using WritePreparedMessage method. Valid wire
// representation will be calculated lazily only once for a set of current
// connection options.
func NewPreparedMessage(messageType int, data []byte) (*PreparedMessage, error) {
pm := &PreparedMessage{
messageType: messageType,
frames: make(map[prepareKey]*preparedFrame),
data: data,
}
// Prepare a plain server frame.
_, frameData, err := pm.frame(prepareKey{isServer: true, compress: false})
if err != nil {
return nil, err
}
// To protect against caller modifying the data argument, remember the data
// copied to the plain server frame.
pm.data = frameData[len(frameData)-len(data):]
return pm, nil
}
func (pm *PreparedMessage) frame(key prepareKey) (int, []byte, error) {
pm.mu.Lock()
frame, ok := pm.frames[key]
if !ok {
frame = &preparedFrame{}
pm.frames[key] = frame
}
pm.mu.Unlock()
var err error
frame.once.Do(func() {
// Prepare a frame using a 'fake' connection.
// TODO: Refactor code in conn.go to allow more direct construction of
// the frame.
mu := make(chan bool, 1)
mu <- true
var nc prepareConn
c := &Conn{
conn: &nc,
mu: mu,
isServer: key.isServer,
compressionLevel: key.compressionLevel,
enableWriteCompression: true,
writeBuf: make([]byte, defaultWriteBufferSize+maxFrameHeaderSize),
}
if key.compress {
c.NewCompressionWriter = CompressNoContextTakeover
}
err = c.WriteMessage(pm.messageType, pm.data)
frame.data = nc.buf.Bytes()
})
return pm.messageType, frame.data, err
}
type prepareConn struct {
buf bytes.Buffer
net.Conn
}
func (pc *prepareConn) Write(p []byte) (int, error) { return pc.buf.Write(p) }
func (pc *prepareConn) SetWriteDeadline(t time.Time) error { return nil }

View File

@ -1,4 +1,4 @@
package server
package net
import (
"crypto/tls"

7
net/const.go Normal file
View File

@ -0,0 +1,7 @@
package net
const (
DefaultHandshakeTimeout = 0
DefaultKeepAlive = 0
)

67
net/server-handler.go Normal file
View File

@ -0,0 +1,67 @@
package net
import (
"net"
"git.loafle.net/commons/server-go"
)
type ServerHandler interface {
server.ServerHandler
OnError(serverCtx server.ServerCtx, conn net.Conn, status int, reason error)
RegisterServlet(servlet Servlet)
Servlet() Servlet
}
type ServerHandlers struct {
server.ServerHandlers
servlet Servlet
}
func (sh *ServerHandlers) Init(serverCtx server.ServerCtx) error {
if err := sh.ServerHandlers.Init(serverCtx); nil != err {
return err
}
if nil != sh.servlet {
if err := sh.servlet.Init(serverCtx); nil != err {
return err
}
}
return nil
}
func (sh *ServerHandlers) Destroy(serverCtx server.ServerCtx) {
if nil != sh.servlet {
sh.servlet.Destroy(serverCtx)
}
sh.ServerHandlers.Destroy(serverCtx)
}
func (sh *ServerHandlers) OnError(serverCtx server.ServerCtx, conn net.Conn, status int, reason error) {
}
func (sh *ServerHandlers) RegisterServlet(servlet Servlet) {
sh.servlet = servlet
}
func (sh *ServerHandlers) Servlet() Servlet {
return sh.servlet
}
func (sh *ServerHandlers) Validate() error {
if err := sh.ServerHandlers.Validate(); nil != err {
return err
}
if 0 >= sh.Concurrency {
sh.Concurrency = 0
}
return nil
}

321
net/server.go Normal file
View File

@ -0,0 +1,321 @@
package net
import (
"context"
"fmt"
"net"
"sync"
"sync/atomic"
"time"
"git.loafle.net/commons/logging-go"
"git.loafle.net/commons/server-go"
"git.loafle.net/commons/server-go/internal"
)
type Server struct {
ServerHandler ServerHandler
ctx server.ServerCtx
connections sync.Map
stopChan chan struct{}
stopWg sync.WaitGroup
}
func (s *Server) ListenAndServe() error {
if s.stopChan != nil {
return fmt.Errorf(s.serverMessage("already running. Stop it before starting it again"))
}
var (
err error
listener net.Listener
)
if nil == s.ServerHandler {
return fmt.Errorf("Server: server handler must be specified")
}
if err = s.ServerHandler.Validate(); nil != err {
return err
}
s.ctx = s.ServerHandler.ServerCtx()
if nil == s.ctx {
return fmt.Errorf(s.serverMessage("ServerCtx is nil"))
}
if err = s.ServerHandler.Init(s.ctx); nil != err {
return err
}
if listener, err = s.ServerHandler.Listener(s.ctx); nil != err {
return err
}
s.stopChan = make(chan struct{})
s.stopWg.Add(1)
return s.handleServer(listener)
}
func (s *Server) Shutdown(ctx context.Context) error {
if s.stopChan == nil {
return fmt.Errorf(s.serverMessage("server must be started before stopping it"))
}
close(s.stopChan)
s.stopWg.Wait()
s.ServerHandler.Destroy(s.ctx)
s.stopChan = nil
return nil
}
func (s *Server) ConnectionSize() int {
var sz int
s.connections.Range(func(k, v interface{}) bool {
sz++
return true
})
return sz
}
func (s *Server) serverMessage(msg string) string {
return fmt.Sprintf("Server[%s]: %s", s.ServerHandler.GetName(), msg)
}
func (s *Server) handleServer(listener net.Listener) error {
var (
stopping atomic.Value
netConn net.Conn
err error
)
defer func() {
if nil != listener {
listener.Close()
}
s.ServerHandler.OnStop(s.ctx)
logging.Logger().Infof(s.serverMessage("Stopped"))
s.stopWg.Done()
}()
if err = s.ServerHandler.OnStart(s.ctx); nil != err {
return err
}
logging.Logger().Infof(s.serverMessage("Started"))
for {
acceptChan := make(chan struct{})
go func() {
if netConn, err = listener.Accept(); err != nil {
if nil == stopping.Load() {
logging.Logger().Errorf(s.serverMessage(fmt.Sprintf("%v", err)))
}
}
close(acceptChan)
}()
select {
case <-s.stopChan:
stopping.Store(true)
listener.Close()
<-acceptChan
listener = nil
return nil
case <-acceptChan:
}
if nil != err {
select {
case <-s.stopChan:
return nil
case <-time.After(time.Second):
}
continue
}
if 0 < s.ServerHandler.GetConcurrency() {
sz := s.ConnectionSize()
if sz >= s.ServerHandler.GetConcurrency() {
logging.Logger().Warnf(s.serverMessage(fmt.Sprintf("max connections size %d, refuse", sz)))
netConn.Close()
continue
}
}
servlet := s.ServerHandler.Servlet()
if nil == servlet {
logging.Logger().Errorf(s.serverMessage("Servlet is nil"))
continue
}
servletCtx := servlet.ServletCtx(s.ctx)
if nil == servletCtx {
logging.Logger().Errorf(s.serverMessage("ServletCtx is nil"))
continue
}
if err := servlet.Handshake(servletCtx, netConn); nil != err {
logging.Logger().Infof(s.serverMessage(fmt.Sprintf("Handshaking of Client[%s] has been failed %v", netConn.RemoteAddr(), err)))
continue
}
conn := internal.NewConn(netConn, true, s.ServerHandler.GetReadBufferSize(), s.ServerHandler.GetWriteBufferSize())
s.stopWg.Add(1)
go s.handleConnection(servlet, servletCtx, conn)
}
}
func (s *Server) handleConnection(servlet Servlet, servletCtx server.ServletCtx, conn *internal.Conn) {
addr := conn.RemoteAddr()
defer func() {
if nil != conn {
conn.Close()
}
logging.Logger().Infof(s.serverMessage(fmt.Sprintf("Client[%s] has been disconnected", addr)))
s.stopWg.Done()
}()
logging.Logger().Infof(s.serverMessage(fmt.Sprintf("Client[%s] has been connected", addr)))
s.connections.Store(conn, true)
defer s.connections.Delete(conn)
servlet.OnConnect(servletCtx, conn)
stopChan := make(chan struct{})
servletDoneChan := make(chan struct{})
readChan := make(chan []byte, 256)
writeChan := make(chan []byte, 256)
readerDoneChan := make(chan struct{})
writerDoneChan := make(chan struct{})
go servlet.Handle(servletCtx, stopChan, servletDoneChan, readChan, writeChan)
go handleRead(s, conn, stopChan, readerDoneChan, readChan)
go handleWrite(s, conn, stopChan, writerDoneChan, writeChan)
select {
case <-readerDoneChan:
close(stopChan)
conn.Close()
<-writerDoneChan
<-servletDoneChan
conn = nil
case <-writerDoneChan:
close(stopChan)
conn.Close()
<-readerDoneChan
<-servletDoneChan
conn = nil
case <-servletDoneChan:
close(stopChan)
conn.Close()
<-readerDoneChan
<-writerDoneChan
conn = nil
case <-s.stopChan:
close(stopChan)
conn.Close()
<-readerDoneChan
<-writerDoneChan
<-servletDoneChan
conn = nil
}
}
func handleRead(s *Server, conn *internal.Conn, doneChan chan<- struct{}, stopChan <-chan struct{}, readChan chan []byte) {
defer func() {
close(doneChan)
}()
conn.SetReadLimit(s.ServerHandler.GetMaxMessageSize())
conn.SetReadDeadline(time.Now().Add(s.ServerHandler.GetReadTimeout()))
conn.SetPongHandler(func(string) error {
conn.SetReadDeadline(time.Now().Add(s.ServerHandler.GetPongTimeout()))
return nil
})
var (
message []byte
err error
)
for {
readMessageChan := make(chan struct{})
go func() {
_, message, err = conn.ReadMessage()
if err != nil {
if internal.IsUnexpectedCloseError(err, internal.CloseGoingAway, internal.CloseAbnormalClosure) {
logging.Logger().Debugf(s.serverMessage(fmt.Sprintf("Read error %v", err)))
}
}
close(readMessageChan)
}()
select {
case <-s.stopChan:
<-readMessageChan
break
case <-readMessageChan:
}
if nil != err {
select {
case <-s.stopChan:
break
case <-time.After(time.Second):
}
continue
}
readChan <- message
}
}
func handleWrite(s *Server, conn *internal.Conn, doneChan chan<- struct{}, stopChan <-chan struct{}, writeChan chan []byte) {
defer func() {
close(doneChan)
}()
ticker := time.NewTicker(s.ServerHandler.GetPingPeriod())
defer func() {
ticker.Stop()
}()
for {
select {
case message, ok := <-writeChan:
conn.SetWriteDeadline(time.Now().Add(s.ServerHandler.GetWriteTimeout()))
if !ok {
conn.WriteMessage(internal.CloseMessage, []byte{})
return
}
w, err := conn.NextWriter(internal.TextMessage)
if err != nil {
return
}
w.Write(message)
if err := w.Close(); nil != err {
return
}
case <-ticker.C:
conn.SetWriteDeadline(time.Now().Add(s.ServerHandler.GetPingTimeout()))
if err := conn.WriteMessage(internal.PingMessage, nil); nil != err {
return
}
case <-s.stopChan:
break
}
}
}

13
net/servlet.go Normal file
View File

@ -0,0 +1,13 @@
package net
import (
"net"
"git.loafle.net/commons/server-go"
)
type Servlet interface {
server.Servlet
Handshake(servletCtx server.ServletCtx, conn net.Conn) error
}

View File

@ -4,16 +4,16 @@ import (
cuc "git.loafle.net/commons/util-go/context"
)
type ServerContext interface {
type ServerCtx interface {
cuc.Context
}
func NewServerContext(parent cuc.Context) ServerContext {
return &serverContext{
ServerContext: cuc.NewContext(parent),
func NewServerCtx(parent cuc.Context) ServerCtx {
return &serverCtx{
Context: cuc.NewContext(parent),
}
}
type serverContext struct {
ServerContext
type serverCtx struct {
cuc.Context
}

View File

@ -3,70 +3,188 @@ package server
import (
"fmt"
"net"
"time"
)
type ServerHandler interface {
GetName() string
GetMaxConnections() int
GetConcurrency() int
GetHandshakeTimeout() time.Duration
GetMaxMessageSize() int64
GetReadBufferSize() int
GetWriteBufferSize() int
GetReadTimeout() time.Duration
GetWriteTimeout() time.Duration
GetPongTimeout() time.Duration
GetPingTimeout() time.Duration
GetPingPeriod() time.Duration
ServerContext() ServerContext
IsEnableCompression() bool
Init(serverCTX ServerContext) error
OnStart(serverCTX ServerContext) error
OnError(serverCTX ServerContext, conn net.Conn, status int, reason error)
OnStop(serverCTX ServerContext)
Destroy(serverCTX ServerContext)
ServerCtx() ServerCtx
Listen(serverCTX ServerContext) (net.Listener, error)
Servlet() Servlet
Init(serverCtx ServerCtx) error
OnStart(serverCtx ServerCtx) error
OnStop(serverCtx ServerCtx)
Destroy(serverCtx ServerCtx)
Validate()
Listener(serverCtx ServerCtx) (net.Listener, error)
Validate() error
}
type ServerHandlers struct {
ServerHandler
Name string
MaxConnections int
// Server name for sending in response headers.
//
// Default server name is used if left blank.
Name string
// The maximum number of concurrent connections the server may serve.
//
// DefaultConcurrency is used if not set.
Concurrency int
HandshakeTimeout time.Duration
MaxMessageSize int64
// Per-connection buffer size for requests' reading.
// This also limits the maximum header size.
//
// Increase this buffer if your clients send multi-KB RequestURIs
// and/or multi-KB headers (for example, BIG cookies).
//
// Default buffer size is used if not set.
ReadBufferSize int
// Per-connection buffer size for responses' writing.
//
// Default buffer size is used if not set.
WriteBufferSize int
// Maximum duration for reading the full request (including body).
//
// This also limits the maximum duration for idle keep-alive
// connections.
//
// By default request read timeout is unlimited.
ReadTimeout time.Duration
// Maximum duration for writing the full response (including body).
//
// By default response write timeout is unlimited.
WriteTimeout time.Duration
PongTimeout time.Duration
PingTimeout time.Duration
PingPeriod time.Duration
EnableCompression bool
}
func (sh *ServerHandlers) ServerContext() ServerContext {
func (sh *ServerHandlers) ServerCtx() ServerCtx {
return nil
}
func (sh *ServerHandlers) Init(serverCTX ServerContext) error {
func (sh *ServerHandlers) Init(serverCtx ServerCtx) error {
return nil
}
func (sh *ServerHandlers) OnStart(serverCTX ServerContext) error {
func (sh *ServerHandlers) OnStart(serverCtx ServerCtx) error {
return nil
}
func (sh *ServerHandlers) OnError(serverCTX ServerContext, conn net.Conn, status int, reason error) {
}
func (sh *ServerHandlers) OnStop(serverCTX ServerContext) {
func (sh *ServerHandlers) OnStop(serverCtx ServerCtx) {
}
func (sh *ServerHandlers) Destroy(serverCTX ServerContext) {
func (sh *ServerHandlers) Destroy(serverCtx ServerCtx) {
}
func (sh *ServerHandlers) Listen(serverCTX ServerContext) (net.Listener, error) {
return nil, fmt.Errorf("Server: Method[ServerHandler.Listen] is not implemented")
func (sh *ServerHandlers) Listener(serverCtx ServerCtx) (net.Listener, error) {
return nil, fmt.Errorf("Server: Method[ServerHandler.Listener] is not implemented")
}
func (sh *ServerHandlers) Servlet() Servlet {
return nil
func (sh *ServerHandlers) GetName() string {
return sh.Name
}
func (sh *ServerHandlers) Validate() {
func (sh *ServerHandlers) GetConcurrency() int {
return sh.Concurrency
}
func (sh *ServerHandlers) GetHandshakeTimeout() time.Duration {
return sh.HandshakeTimeout
}
func (sh *ServerHandlers) GetMaxMessageSize() int64 {
return sh.MaxMessageSize
}
func (sh *ServerHandlers) GetReadBufferSize() int {
return sh.ReadBufferSize
}
func (sh *ServerHandlers) GetWriteBufferSize() int {
return sh.WriteBufferSize
}
func (sh *ServerHandlers) GetReadTimeout() time.Duration {
return sh.ReadTimeout
}
func (sh *ServerHandlers) GetWriteTimeout() time.Duration {
return sh.WriteTimeout
}
func (sh *ServerHandlers) GetPongTimeout() time.Duration {
return sh.PongTimeout
}
func (sh *ServerHandlers) GetPingTimeout() time.Duration {
return sh.PingTimeout
}
func (sh *ServerHandlers) GetPingPeriod() time.Duration {
return sh.PingPeriod
}
func (sh *ServerHandlers) IsEnableCompression() bool {
return sh.EnableCompression
}
func (sh *ServerHandlers) Validate() error {
if "" == sh.Name {
sh.Name = "Server"
}
if 0 >= sh.MaxConnections {
sh.MaxConnections = 0
if sh.Concurrency <= 0 {
sh.Concurrency = DefaultConcurrency
}
if sh.HandshakeTimeout <= 0 {
sh.HandshakeTimeout = DefaultHandshakeTimeout
}
if sh.MaxMessageSize <= 0 {
sh.MaxMessageSize = DefaultMaxMessageSize
}
if sh.ReadBufferSize <= 0 {
sh.ReadBufferSize = DefaultReadBufferSize
}
if sh.WriteBufferSize <= 0 {
sh.WriteBufferSize = DefaultWriteBufferSize
}
if sh.ReadTimeout <= 0 {
sh.ReadTimeout = DefaultReadTimeout
}
if sh.WriteTimeout <= 0 {
sh.WriteTimeout = DefaultWriteTimeout
}
if sh.PongTimeout <= 0 {
sh.PongTimeout = DefaultPongTimeout
}
if sh.PingTimeout <= 0 {
sh.PingTimeout = DefaultPingTimeout
}
if sh.PingPeriod <= 0 {
sh.PingPeriod = DefaultPingPeriod
}
return nil
}

166
server.go
View File

@ -1,166 +0,0 @@
package server
import (
"context"
"fmt"
"net"
"sync"
"sync/atomic"
"time"
"git.loafle.net/commons/logging-go"
)
type Server struct {
ServerHandler ServerHandler
ctx ServerContext
servlets sync.Map
stopChan chan struct{}
stopWg sync.WaitGroup
}
func (s *Server) ListenAndServe() error {
if s.stopChan != nil {
return fmt.Errorf("Server: server is already running. Stop it before starting it again")
}
var (
err error
listener net.Listener
)
if nil == s.ServerHandler {
panic("Server: server handler must be specified.")
}
s.ServerHandler.Validate()
s.ctx = s.ServerHandler.ServerContext()
if nil == s.ctx {
return fmt.Errorf("Server: ServerContext is nil")
}
if err = s.ServerHandler.Init(s.ctx); nil != err {
return fmt.Errorf("Server: Initialization of server has been failed %v", err)
}
if listener, err = s.ServerHandler.Listen(s.ctx); nil != err {
return err
}
s.stopChan = make(chan struct{})
s.stopWg.Add(1)
return s.handleLoop(listener)
}
func (s *Server) Shutdown(ctx context.Context) error {
if s.stopChan == nil {
return fmt.Errorf("Server: server must be started before stopping it")
}
close(s.stopChan)
s.stopWg.Wait()
s.ServerHandler.Destroy(s.ctx)
s.stopChan = nil
return nil
}
func (s *Server) ConnectionSize() int {
var sz int
s.servlets.Range(func(k, v interface{}) bool {
sz++
return true
})
return sz
}
func (s *Server) handleLoop(listener net.Listener) error {
var (
stopping atomic.Value
conn net.Conn
err error
)
defer func() {
s.ServerHandler.OnStop(s.ctx)
s.stopWg.Done()
}()
if err = s.ServerHandler.OnStart(s.ctx); nil != err {
return err
}
for {
acceptChan := make(chan struct{})
go func() {
if conn, err = listener.Accept(); err != nil {
if nil == stopping.Load() {
logging.Logger().Errorf("Server: Cannot accept new connection: [%s]", err)
}
}
close(acceptChan)
}()
select {
case <-s.stopChan:
stopping.Store(true)
listener.Close()
<-acceptChan
return nil
case <-acceptChan:
}
if nil != err {
select {
case <-s.stopChan:
return nil
case <-time.After(time.Second):
}
continue
}
if 0 < s.ServerHandler.GetMaxConnections() {
sz := s.ConnectionSize()
if sz >= s.ServerHandler.GetMaxConnections() {
logging.Logger().Warnf("max connections size %d, refuse\n", sz)
conn.Close()
continue
}
}
s.stopWg.Add(1)
go s.handleConnection(conn)
}
}
func (s *Server) handleConnection(conn net.Conn) {
servlet := s.ServerHandler.Servlet()
defer func() {
s.servlets.Delete(servlet)
s.stopWg.Done()
}()
if nil == servlet {
logging.Logger().Errorf("Server: Servlet is nil")
}
s.servlets.Store(servlet, true)
servletStopChan := make(chan struct{})
doneChan := make(chan struct{})
go servlet.Handle(s.ctx, conn, doneChan, servletStopChan)
select {
case <-doneChan:
close(servletStopChan)
conn.Close()
case <-s.stopChan:
close(servletStopChan)
conn.Close()
<-doneChan
}
}

28
servlet-context.go Normal file
View File

@ -0,0 +1,28 @@
package server
import (
cuc "git.loafle.net/commons/util-go/context"
)
type ServletCtx interface {
cuc.Context
ServerCtx() ServerCtx
}
func NewServletContext(parent cuc.Context, serverCtx ServerCtx) ServletCtx {
return &servletCtx{
Context: cuc.NewContext(parent),
serverCtx: serverCtx,
}
}
type servletCtx struct {
cuc.Context
serverCtx ServerCtx
}
func (sc *servletCtx) ServerCtx() ServerCtx {
return sc.serverCtx
}

View File

@ -1,7 +1,16 @@
package server
import "net"
import (
"git.loafle.net/commons/server-go/internal"
)
type Servlet interface {
Handle(serverCTX ServerContext, conn net.Conn, doneChan chan<- struct{}, stopChan <-chan struct{})
ServletCtx(serverCtx ServerCtx) ServletCtx
Init(serverCtx ServerCtx) error
Destroy(serverCtx ServerCtx)
OnConnect(servletCtx ServletCtx, conn *internal.Conn)
Handle(servletCtx ServletCtx, stopChan <-chan struct{}, doneChan chan<- struct{}, readChan <-chan []byte, writeChan chan<- []byte)
OnDisconnect(servletCtx ServletCtx)
}