This commit is contained in:
crusader 2017-11-07 19:03:40 +09:00
commit 072d6f11a8
22 changed files with 3128 additions and 0 deletions

70
.gitignore vendored Normal file
View File

@ -0,0 +1,70 @@
# Created by .ignore support plugin (hsz.mobi)
### JetBrains template
# Covers JetBrains IDEs: IntelliJ, RubyMine, PhpStorm, AppCode, PyCharm, CLion, Android Studio and Webstorm
# Reference: https://intellij-support.jetbrains.com/hc/en-us/articles/206544839
# User-specific stuff:
.idea/**/workspace.xml
.idea/**/tasks.xml
.idea/dictionaries
# Sensitive or high-churn files:
.idea/**/dataSources/
.idea/**/dataSources.ids
.idea/**/dataSources.xml
.idea/**/dataSources.local.xml
.idea/**/sqlDataSources.xml
.idea/**/dynamic.xml
.idea/**/uiDesigner.xml
# Gradle:
.idea/**/gradle.xml
.idea/**/libraries
# Mongo Explorer plugin:
.idea/**/mongoSettings.xml
## File-based project format:
*.iws
## Plugin-specific files:
# IntelliJ
/out/
# mpeltonen/sbt-idea plugin
.idea_modules/
# JIRA plugin
atlassian-ide-plugin.xml
# Crashlytics plugin (for Android Studio and IntelliJ)
com_crashlytics_export_strings.xml
crashlytics.properties
crashlytics-build.properties
fabric.properties
### Go template
# Binaries for programs and plugins
*.exe
*.dll
*.so
*.dylib
# Test binary, build with `go test -c`
*.test
# Output of the go coverage tool, specifically when used with LiteIDE
*.out
# Project-local glide cache, RE: https://github.com/Masterminds/glide/issues/736
.glide/
.idea/
*.iml
vendor/
glide.lock
.DS_Store
dist/
debug

32
.vscode/launch.json vendored Normal file
View File

@ -0,0 +1,32 @@
{
"version": "0.2.0",
"configurations": [
{
"name": "Debug",
"type": "go",
"request": "launch",
"mode": "debug",
"remotePath": "",
"port": 2345,
"host": "127.0.0.1",
"program": "${workspaceRoot}/main.go",
"env": {},
"args": [],
"showLog": true
},
{
"name": "File Debug",
"type": "go",
"request": "launch",
"mode": "debug",
"remotePath": "",
"port": 2345,
"host": "127.0.0.1",
"program": "${fileDirname}",
"env": {},
"args": [],
"showLog": true
}
]
}

3
.vscode/settings.json vendored Normal file
View File

@ -0,0 +1,3 @@
// Place your settings in this file to overwrite default and user settings.
{
}

24
constants.go Normal file
View File

@ -0,0 +1,24 @@
package websocket_fasthttp
const (
// DefaultConcurrency is the maximum number of concurrent connections
// the Server may serve by default (i.e. if Server.Concurrency isn't set).
DefaultConcurrency = 256 * 1024
// DefaultMaxStopWaitTime is default value of server stop Timeout
DefaultMaxStopWaitTime = 0
// DefaultHandshakeTimeout is default value of websocket handshake Timeout
DefaultHandshakeTimeout = 0
// DefaultReadBufferSize is default value of Read Buffer Size
DefaultReadBufferSize = 4096
// DefaultWriteBufferSize is default value of Write Buffer Size
DefaultWriteBufferSize = 4096
// DefaultReadTimeout is default value of read timeout
DefaultReadTimeout = 0
// DefaultWriteTimeout is default value of write timeout
DefaultWriteTimeout = 0
// DefaultEnableCompression is default value of support compression
DefaultEnableCompression = false
)

4
glide.yaml Normal file
View File

@ -0,0 +1,4 @@
package: git.loafle.net/commons_go/websocket_fasthttp
import:
- package: github.com/valyala/fasthttp
version: v20160617

108
listener.go Normal file
View File

@ -0,0 +1,108 @@
package websocket_fasthttp
import (
"fmt"
"net"
"sync/atomic"
"time"
)
type GracefulListener struct {
// inner listener
ln net.Listener
// maximum wait time for graceful shutdown
maxWaitTime time.Duration
// this channel is closed during graceful shutdown on zero open connections.
done chan struct{}
// the number of open connections
connsCount uint64
// becomes non-zero when graceful shutdown starts
shutdown uint64
}
// NewGracefulListener wraps the given listener into 'graceful shutdown' listener.
func newGracefulListener(ln net.Listener, maxWaitTime time.Duration) net.Listener {
return &GracefulListener{
ln: ln,
maxWaitTime: maxWaitTime,
done: make(chan struct{}),
}
}
func (ln *GracefulListener) Accept() (net.Conn, error) {
c, err := ln.ln.Accept()
if err != nil {
return nil, err
}
atomic.AddUint64(&ln.connsCount, 1)
return &gracefulConn{
Conn: c,
ln: ln,
}, nil
}
func (ln *GracefulListener) Addr() net.Addr {
return ln.ln.Addr()
}
// Close closes the inner listener and waits until all the pending open connections
// are closed before returning.
func (ln *GracefulListener) Close() error {
err := ln.ln.Close()
if err != nil {
return nil
}
return ln.waitForZeroConns()
}
func (ln *GracefulListener) waitForZeroConns() error {
atomic.AddUint64(&ln.shutdown, 1)
if atomic.LoadUint64(&ln.connsCount) == 0 {
close(ln.done)
return nil
}
select {
case <-ln.done:
return nil
case <-time.After(ln.maxWaitTime):
return fmt.Errorf("cannot complete graceful shutdown in %s", ln.maxWaitTime)
}
return nil
}
func (ln *GracefulListener) closeConn() {
connsCount := atomic.AddUint64(&ln.connsCount, ^uint64(0))
if atomic.LoadUint64(&ln.shutdown) != 0 && connsCount == 0 {
close(ln.done)
}
}
type gracefulConn struct {
net.Conn
ln *GracefulListener
}
func (c *gracefulConn) Close() error {
err := c.Conn.Close()
if err != nil {
return err
}
c.ln.closeConn()
return nil
}

167
server.go Normal file
View File

@ -0,0 +1,167 @@
package websocket_fasthttp
import (
"fmt"
"net"
"net/http"
"sync"
"git.loafle.net/commons_go/logging"
"git.loafle.net/commons_go/websocket_fasthttp/websocket"
"github.com/valyala/fasthttp"
)
type Server interface {
Start() error
Stop()
}
func New(sh ServerHandler) Server {
s := &server{
sh: sh,
}
return s
}
type server struct {
sh ServerHandler
httpServer *fasthttp.Server
upgrader *websocket.Upgrader
listener net.Listener
stopChan chan struct{}
stopWg sync.WaitGroup
}
func (s *server) Start() error {
if nil == s.sh {
panic("Server: server handler must be specified.")
}
s.sh.Validate()
if s.stopChan != nil {
panic("Server: server is already running. Stop it before starting it again")
}
s.httpServer = &fasthttp.Server{
Handler: s.handleRequest,
Name: s.sh.GetName(),
Concurrency: s.sh.GetConcurrency(),
ReadBufferSize: s.sh.GetReadBufferSize(),
WriteBufferSize: s.sh.GetWriteBufferSize(),
ReadTimeout: s.sh.GetReadTimeout(),
WriteTimeout: s.sh.GetWriteTimeout(),
}
s.upgrader = &websocket.Upgrader{
HandshakeTimeout: s.sh.GetHandshakeTimeout(),
ReadBufferSize: s.sh.GetReadBufferSize(),
WriteBufferSize: s.sh.GetWriteBufferSize(),
CheckOrigin: s.sh.CheckOrigin,
Error: s.handleError,
EnableCompression: s.sh.IsEnableCompression(),
}
var err error
var listener net.Listener
if listener, err = s.sh.Listen(); nil != err {
return err
}
s.listener = newGracefulListener(listener, s.sh.GetMaxStopWaitTime())
s.stopChan = make(chan struct{})
s.sh.OnStart()
s.stopWg.Add(1)
go handleServer(s)
return nil
}
func (s *server) Stop() {
if s.stopChan == nil {
panic("Server: server must be started before stopping it")
}
close(s.stopChan)
s.stopWg.Wait()
s.stopChan = nil
s.sh.OnStop()
}
func handleServer(s *server) {
go func() {
defer s.stopWg.Done()
if err := s.httpServer.Serve(s.listener); nil != err {
logging.Logger.Error(fmt.Sprintf("Server: Server err - %v", err))
}
}()
select {
case <-s.stopChan:
s.listener.Close()
return
}
}
func (s *server) handleRequest(ctx *fasthttp.RequestCtx) {
path := string(ctx.Path())
var socketHandler SocketHandler
var err error
if socketHandler, err = s.sh.GetSocketHandler(path); nil != err {
s.handleError(ctx, fasthttp.StatusNotFound, err)
return
}
var responseHeader *fasthttp.ResponseHeader
var allowHandshake bool
if allowHandshake, responseHeader = socketHandler.Handshake(ctx); !allowHandshake {
s.handleError(ctx, http.StatusNotAcceptable, fmt.Errorf("Server: Handshake err"))
return
}
s.upgrader.Upgrade(ctx, responseHeader, func(conn *websocket.Conn, err error) {
if err != nil {
s.handleError(ctx, fasthttp.StatusInternalServerError, err)
return
}
s.stopWg.Add(1)
go handleConnection(s, ctx, conn, socketHandler)
})
}
func (s *server) handleError(ctx *fasthttp.RequestCtx, status int, reason error) {
ctx.Response.Header.Set("Sec-Websocket-Version", "13")
ctx.Error(http.StatusText(status), status)
s.sh.OnError(ctx, status, reason)
}
func handleConnection(s *server, ctx *fasthttp.RequestCtx, conn *websocket.Conn, handler SocketHandler) {
defer s.stopWg.Done()
logging.Logger.Debug(fmt.Sprintf("Server: Client[%s] is connected.", conn.RemoteAddr()))
clientStopChan := make(chan struct{})
handleDoneChan := make(chan struct{})
go handler.Handle(conn, clientStopChan, handleDoneChan)
select {
case <-s.stopChan:
close(clientStopChan)
conn.Close()
<-handleDoneChan
case <-handleDoneChan:
close(clientStopChan)
conn.Close()
logging.Logger.Debug(fmt.Sprintf("Server: Client[%s] is disconnected.", conn.RemoteAddr()))
}
}

32
server_handler.go Normal file
View File

@ -0,0 +1,32 @@
package websocket_fasthttp
import (
"net"
"time"
"github.com/valyala/fasthttp"
)
type ServerHandler interface {
Listen() (net.Listener, error)
CheckOrigin(ctx *fasthttp.RequestCtx) bool
OnError(ctx *fasthttp.RequestCtx, status int, reason error)
OnStart()
OnStop()
RegisterSocketHandler(path string, handler SocketHandler)
GetSocketHandler(path string) (SocketHandler, error)
GetName() string
GetConcurrency() int
GetMaxStopWaitTime() time.Duration
GetHandshakeTimeout() time.Duration
GetReadBufferSize() int
GetWriteBufferSize() int
GetReadTimeout() time.Duration
GetWriteTimeout() time.Duration
IsEnableCompression() bool
Validate()
}

151
server_handlers.go Normal file
View File

@ -0,0 +1,151 @@
package websocket_fasthttp
import (
"errors"
"fmt"
"net"
"time"
"github.com/valyala/fasthttp"
)
type ServerHandlers struct {
// Server name for sending in response headers.
//
// Default server name is used if left blank.
Name string
// The maximum number of concurrent connections the server may serve.
//
// DefaultConcurrency is used if not set.
Concurrency int
MaxStopWaitTime time.Duration
HandshakeTimeout time.Duration
// Per-connection buffer size for requests' reading.
// This also limits the maximum header size.
//
// Increase this buffer if your clients send multi-KB RequestURIs
// and/or multi-KB headers (for example, BIG cookies).
//
// Default buffer size is used if not set.
ReadBufferSize int
// Per-connection buffer size for responses' writing.
//
// Default buffer size is used if not set.
WriteBufferSize int
// Maximum duration for reading the full request (including body).
//
// This also limits the maximum duration for idle keep-alive
// connections.
//
// By default request read timeout is unlimited.
ReadTimeout time.Duration
// Maximum duration for writing the full response (including body).
//
// By default response write timeout is unlimited.
WriteTimeout time.Duration
EnableCompression bool
socketHandlers map[string]SocketHandler
}
func (sh *ServerHandlers) Listen() (net.Listener, error) {
return nil, errors.New("Server: Handler method[Listen] of Server is not implement")
}
func (sh *ServerHandlers) CheckOrigin(ctx *fasthttp.RequestCtx) bool {
return true
}
func (sh *ServerHandlers) OnError(ctx *fasthttp.RequestCtx, status int, reason error) {
// no op
}
func (sh *ServerHandlers) OnStart() {
// no op
}
func (sh *ServerHandlers) OnStop() {
// no op
}
func (sh *ServerHandlers) RegisterSocketHandler(path string, handler SocketHandler) {
if nil == sh.socketHandlers {
sh.socketHandlers = make(map[string]SocketHandler)
}
sh.socketHandlers[path] = handler
}
func (sh *ServerHandlers) GetSocketHandler(path string) (SocketHandler, error) {
var handler SocketHandler
if path == "" && len(sh.socketHandlers) == 1 {
for _, h := range sh.socketHandlers {
handler = h
}
} else if handler = sh.socketHandlers[path]; nil == handler {
return nil, fmt.Errorf("Unrecognized path: %s", path)
}
return handler, nil
}
func (sh *ServerHandlers) GetName() string {
return sh.Name
}
func (sh *ServerHandlers) GetConcurrency() int {
return sh.Concurrency
}
func (sh *ServerHandlers) GetMaxStopWaitTime() time.Duration {
return sh.MaxStopWaitTime
}
func (sh *ServerHandlers) GetHandshakeTimeout() time.Duration {
return sh.HandshakeTimeout
}
func (sh *ServerHandlers) GetReadBufferSize() int {
return sh.ReadBufferSize
}
func (sh *ServerHandlers) GetWriteBufferSize() int {
return sh.WriteBufferSize
}
func (sh *ServerHandlers) GetReadTimeout() time.Duration {
return sh.ReadTimeout
}
func (sh *ServerHandlers) GetWriteTimeout() time.Duration {
return sh.WriteTimeout
}
func (sh *ServerHandlers) IsEnableCompression() bool {
return sh.EnableCompression
}
func (sh *ServerHandlers) Validate() {
if sh.Concurrency <= 0 {
sh.Concurrency = DefaultConcurrency
}
if sh.MaxStopWaitTime <= 0 {
sh.MaxStopWaitTime = DefaultMaxStopWaitTime
}
if sh.HandshakeTimeout <= 0 {
sh.HandshakeTimeout = DefaultHandshakeTimeout
}
if sh.ReadBufferSize <= 0 {
sh.ReadBufferSize = DefaultReadBufferSize
}
if sh.WriteBufferSize <= 0 {
sh.WriteBufferSize = DefaultWriteBufferSize
}
if sh.ReadTimeout <= 0 {
sh.ReadTimeout = DefaultReadTimeout
}
if sh.WriteTimeout <= 0 {
sh.WriteTimeout = DefaultWriteTimeout
}
}

2
socket.go Normal file
View File

@ -0,0 +1,2 @@
package websocket_fasthttp

11
socket_handler.go Normal file
View File

@ -0,0 +1,11 @@
package websocket_fasthttp
import (
"git.loafle.net/commons_go/websocket_fasthttp/websocket"
"github.com/valyala/fasthttp"
)
type SocketHandler interface {
Handshake(ctx *fasthttp.RequestCtx) (bool, *fasthttp.ResponseHeader)
Handle(conn *websocket.Conn, stopChan <-chan struct{}, doneChan chan<- struct{})
}

20
socket_handlers.go Normal file
View File

@ -0,0 +1,20 @@
package websocket_fasthttp
import (
"git.loafle.net/commons_go/websocket_fasthttp/websocket"
"github.com/valyala/fasthttp"
)
type SocketHandlers struct {
}
func (sh *SocketHandlers) Handshake(ctx *fasthttp.RequestCtx) (bool, *fasthttp.ResponseHeader) {
return true, nil
}
func (sh *SocketHandlers) Handle(conn *websocket.Conn, stopChan <-chan struct{}, doneChan chan<- struct{}) {
// no op
}
func (sh *SocketHandlers) Validate() {
}

392
websocket/client.go Normal file
View File

@ -0,0 +1,392 @@
// Copyright 2013 The Gorilla WebSocket Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package websocket
import (
"bufio"
"bytes"
"crypto/tls"
"encoding/base64"
"errors"
"io"
"io/ioutil"
"net"
"net/http"
"net/url"
"strings"
"time"
)
// ErrBadHandshake is returned when the server response to opening handshake is
// invalid.
var ErrBadHandshake = errors.New("websocket: bad handshake")
var errInvalidCompression = errors.New("websocket: invalid compression negotiation")
// NewClient creates a new client connection using the given net connection.
// The URL u specifies the host and request URI. Use requestHeader to specify
// the origin (Origin), subprotocols (Sec-WebSocket-Protocol) and cookies
// (Cookie). Use the response.Header to get the selected subprotocol
// (Sec-WebSocket-Protocol) and cookies (Set-Cookie).
//
// If the WebSocket handshake fails, ErrBadHandshake is returned along with a
// non-nil *http.Response so that callers can handle redirects, authentication,
// etc.
//
// Deprecated: Use Dialer instead.
func NewClient(netConn net.Conn, u *url.URL, requestHeader http.Header, readBufSize, writeBufSize int) (c *Conn, response *http.Response, err error) {
d := Dialer{
ReadBufferSize: readBufSize,
WriteBufferSize: writeBufSize,
NetDial: func(net, addr string) (net.Conn, error) {
return netConn, nil
},
}
return d.Dial(u.String(), requestHeader)
}
// A Dialer contains options for connecting to WebSocket server.
type Dialer struct {
// NetDial specifies the dial function for creating TCP connections. If
// NetDial is nil, net.Dial is used.
NetDial func(network, addr string) (net.Conn, error)
// Proxy specifies a function to return a proxy for a given
// Request. If the function returns a non-nil error, the
// request is aborted with the provided error.
// If Proxy is nil or returns a nil *URL, no proxy is used.
Proxy func(*http.Request) (*url.URL, error)
// TLSClientConfig specifies the TLS configuration to use with tls.Client.
// If nil, the default configuration is used.
TLSClientConfig *tls.Config
// HandshakeTimeout specifies the duration for the handshake to complete.
HandshakeTimeout time.Duration
// ReadBufferSize and WriteBufferSize specify I/O buffer sizes. If a buffer
// size is zero, then a useful default size is used. The I/O buffer sizes
// do not limit the size of the messages that can be sent or received.
ReadBufferSize, WriteBufferSize int
// Subprotocols specifies the client's requested subprotocols.
Subprotocols []string
// EnableCompression specifies if the client should attempt to negotiate
// per message compression (RFC 7692). Setting this value to true does not
// guarantee that compression will be supported. Currently only "no context
// takeover" modes are supported.
EnableCompression bool
// Jar specifies the cookie jar.
// If Jar is nil, cookies are not sent in requests and ignored
// in responses.
Jar http.CookieJar
}
var errMalformedURL = errors.New("malformed ws or wss URL")
// parseURL parses the URL.
//
// This function is a replacement for the standard library url.Parse function.
// In Go 1.4 and earlier, url.Parse loses information from the path.
func parseURL(s string) (*url.URL, error) {
// From the RFC:
//
// ws-URI = "ws:" "//" host [ ":" port ] path [ "?" query ]
// wss-URI = "wss:" "//" host [ ":" port ] path [ "?" query ]
var u url.URL
switch {
case strings.HasPrefix(s, "ws://"):
u.Scheme = "ws"
s = s[len("ws://"):]
case strings.HasPrefix(s, "wss://"):
u.Scheme = "wss"
s = s[len("wss://"):]
default:
return nil, errMalformedURL
}
if i := strings.Index(s, "?"); i >= 0 {
u.RawQuery = s[i+1:]
s = s[:i]
}
if i := strings.Index(s, "/"); i >= 0 {
u.Opaque = s[i:]
s = s[:i]
} else {
u.Opaque = "/"
}
u.Host = s
if strings.Contains(u.Host, "@") {
// Don't bother parsing user information because user information is
// not allowed in websocket URIs.
return nil, errMalformedURL
}
return &u, nil
}
func hostPortNoPort(u *url.URL) (hostPort, hostNoPort string) {
hostPort = u.Host
hostNoPort = u.Host
if i := strings.LastIndex(u.Host, ":"); i > strings.LastIndex(u.Host, "]") {
hostNoPort = hostNoPort[:i]
} else {
switch u.Scheme {
case "wss":
hostPort += ":443"
case "https":
hostPort += ":443"
default:
hostPort += ":80"
}
}
return hostPort, hostNoPort
}
// DefaultDialer is a dialer with all fields set to the default zero values.
var DefaultDialer = &Dialer{
Proxy: http.ProxyFromEnvironment,
}
// Dial creates a new client connection. Use requestHeader to specify the
// origin (Origin), subprotocols (Sec-WebSocket-Protocol) and cookies (Cookie).
// Use the response.Header to get the selected subprotocol
// (Sec-WebSocket-Protocol) and cookies (Set-Cookie).
//
// If the WebSocket handshake fails, ErrBadHandshake is returned along with a
// non-nil *http.Response so that callers can handle redirects, authentication,
// etcetera. The response body may not contain the entire response and does not
// need to be closed by the application.
func (d *Dialer) Dial(urlStr string, requestHeader http.Header) (*Conn, *http.Response, error) {
if d == nil {
d = &Dialer{
Proxy: http.ProxyFromEnvironment,
}
}
challengeKey, err := generateChallengeKey()
if err != nil {
return nil, nil, err
}
u, err := parseURL(urlStr)
if err != nil {
return nil, nil, err
}
switch u.Scheme {
case "ws":
u.Scheme = "http"
case "wss":
u.Scheme = "https"
default:
return nil, nil, errMalformedURL
}
if u.User != nil {
// User name and password are not allowed in websocket URIs.
return nil, nil, errMalformedURL
}
req := &http.Request{
Method: "GET",
URL: u,
Proto: "HTTP/1.1",
ProtoMajor: 1,
ProtoMinor: 1,
Header: make(http.Header),
Host: u.Host,
}
// Set the cookies present in the cookie jar of the dialer
if d.Jar != nil {
for _, cookie := range d.Jar.Cookies(u) {
req.AddCookie(cookie)
}
}
// Set the request headers using the capitalization for names and values in
// RFC examples. Although the capitalization shouldn't matter, there are
// servers that depend on it. The Header.Set method is not used because the
// method canonicalizes the header names.
req.Header["Upgrade"] = []string{"websocket"}
req.Header["Connection"] = []string{"Upgrade"}
req.Header["Sec-WebSocket-Key"] = []string{challengeKey}
req.Header["Sec-WebSocket-Version"] = []string{"13"}
if len(d.Subprotocols) > 0 {
req.Header["Sec-WebSocket-Protocol"] = []string{strings.Join(d.Subprotocols, ", ")}
}
for k, vs := range requestHeader {
switch {
case k == "Host":
if len(vs) > 0 {
req.Host = vs[0]
}
case k == "Upgrade" ||
k == "Connection" ||
k == "Sec-Websocket-Key" ||
k == "Sec-Websocket-Version" ||
k == "Sec-Websocket-Extensions" ||
(k == "Sec-Websocket-Protocol" && len(d.Subprotocols) > 0):
return nil, nil, errors.New("websocket: duplicate header not allowed: " + k)
default:
req.Header[k] = vs
}
}
if d.EnableCompression {
req.Header.Set("Sec-Websocket-Extensions", "permessage-deflate; server_no_context_takeover; client_no_context_takeover")
}
hostPort, hostNoPort := hostPortNoPort(u)
var proxyURL *url.URL
// Check wether the proxy method has been configured
if d.Proxy != nil {
proxyURL, err = d.Proxy(req)
}
if err != nil {
return nil, nil, err
}
var targetHostPort string
if proxyURL != nil {
targetHostPort, _ = hostPortNoPort(proxyURL)
} else {
targetHostPort = hostPort
}
var deadline time.Time
if d.HandshakeTimeout != 0 {
deadline = time.Now().Add(d.HandshakeTimeout)
}
netDial := d.NetDial
if netDial == nil {
netDialer := &net.Dialer{Deadline: deadline}
netDial = netDialer.Dial
}
netConn, err := netDial("tcp", targetHostPort)
if err != nil {
return nil, nil, err
}
defer func() {
if netConn != nil {
netConn.Close()
}
}()
if err := netConn.SetDeadline(deadline); err != nil {
return nil, nil, err
}
if proxyURL != nil {
connectHeader := make(http.Header)
if user := proxyURL.User; user != nil {
proxyUser := user.Username()
if proxyPassword, passwordSet := user.Password(); passwordSet {
credential := base64.StdEncoding.EncodeToString([]byte(proxyUser + ":" + proxyPassword))
connectHeader.Set("Proxy-Authorization", "Basic "+credential)
}
}
connectReq := &http.Request{
Method: "CONNECT",
URL: &url.URL{Opaque: hostPort},
Host: hostPort,
Header: connectHeader,
}
connectReq.Write(netConn)
// Read response.
// Okay to use and discard buffered reader here, because
// TLS server will not speak until spoken to.
br := bufio.NewReader(netConn)
resp, err := http.ReadResponse(br, connectReq)
if err != nil {
return nil, nil, err
}
if resp.StatusCode != 200 {
f := strings.SplitN(resp.Status, " ", 2)
return nil, nil, errors.New(f[1])
}
}
if u.Scheme == "https" {
cfg := cloneTLSConfig(d.TLSClientConfig)
if cfg.ServerName == "" {
cfg.ServerName = hostNoPort
}
tlsConn := tls.Client(netConn, cfg)
netConn = tlsConn
if err := tlsConn.Handshake(); err != nil {
return nil, nil, err
}
if !cfg.InsecureSkipVerify {
if err := tlsConn.VerifyHostname(cfg.ServerName); err != nil {
return nil, nil, err
}
}
}
conn := newConn(netConn, false, d.ReadBufferSize, d.WriteBufferSize)
if err := req.Write(netConn); err != nil {
return nil, nil, err
}
resp, err := http.ReadResponse(conn.br, req)
if err != nil {
return nil, nil, err
}
if d.Jar != nil {
if rc := resp.Cookies(); len(rc) > 0 {
d.Jar.SetCookies(u, rc)
}
}
if resp.StatusCode != 101 ||
!strings.EqualFold(resp.Header.Get("Upgrade"), "websocket") ||
!strings.EqualFold(resp.Header.Get("Connection"), "upgrade") ||
resp.Header.Get("Sec-Websocket-Accept") != computeAcceptKey(challengeKey) {
// Before closing the network connection on return from this
// function, slurp up some of the response to aid application
// debugging.
buf := make([]byte, 1024)
n, _ := io.ReadFull(resp.Body, buf)
resp.Body = ioutil.NopCloser(bytes.NewReader(buf[:n]))
return nil, resp, ErrBadHandshake
}
for _, ext := range httpParseExtensions(resp.Header) {
if ext[""] != "permessage-deflate" {
continue
}
_, snct := ext["server_no_context_takeover"]
_, cnct := ext["client_no_context_takeover"]
if !snct || !cnct {
return nil, resp, errInvalidCompression
}
conn.newCompressionWriter = compressNoContextTakeover
conn.newDecompressionReader = decompressNoContextTakeover
break
}
resp.Body = ioutil.NopCloser(bytes.NewReader([]byte{}))
conn.subprotocol = resp.Header.Get("Sec-Websocket-Protocol")
netConn.SetDeadline(time.Time{})
netConn = nil // to avoid close in defer.
return conn, resp, nil
}

16
websocket/client_clone.go Normal file
View File

@ -0,0 +1,16 @@
// Copyright 2013 The Gorilla WebSocket Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
// +build go1.8
package websocket
import "crypto/tls"
func cloneTLSConfig(cfg *tls.Config) *tls.Config {
if cfg == nil {
return &tls.Config{}
}
return cfg.Clone()
}

148
websocket/compression.go Normal file
View File

@ -0,0 +1,148 @@
// Copyright 2017 The Gorilla WebSocket Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package websocket
import (
"compress/flate"
"errors"
"io"
"strings"
"sync"
)
const (
minCompressionLevel = -2 // flate.HuffmanOnly not defined in Go < 1.6
maxCompressionLevel = flate.BestCompression
defaultCompressionLevel = 1
)
var (
flateWriterPools [maxCompressionLevel - minCompressionLevel + 1]sync.Pool
flateReaderPool = sync.Pool{New: func() interface{} {
return flate.NewReader(nil)
}}
)
func decompressNoContextTakeover(r io.Reader) io.ReadCloser {
const tail =
// Add four bytes as specified in RFC
"\x00\x00\xff\xff" +
// Add final block to squelch unexpected EOF error from flate reader.
"\x01\x00\x00\xff\xff"
fr, _ := flateReaderPool.Get().(io.ReadCloser)
fr.(flate.Resetter).Reset(io.MultiReader(r, strings.NewReader(tail)), nil)
return &flateReadWrapper{fr}
}
func isValidCompressionLevel(level int) bool {
return minCompressionLevel <= level && level <= maxCompressionLevel
}
func compressNoContextTakeover(w io.WriteCloser, level int) io.WriteCloser {
p := &flateWriterPools[level-minCompressionLevel]
tw := &truncWriter{w: w}
fw, _ := p.Get().(*flate.Writer)
if fw == nil {
fw, _ = flate.NewWriter(tw, level)
} else {
fw.Reset(tw)
}
return &flateWriteWrapper{fw: fw, tw: tw, p: p}
}
// truncWriter is an io.Writer that writes all but the last four bytes of the
// stream to another io.Writer.
type truncWriter struct {
w io.WriteCloser
n int
p [4]byte
}
func (w *truncWriter) Write(p []byte) (int, error) {
n := 0
// fill buffer first for simplicity.
if w.n < len(w.p) {
n = copy(w.p[w.n:], p)
p = p[n:]
w.n += n
if len(p) == 0 {
return n, nil
}
}
m := len(p)
if m > len(w.p) {
m = len(w.p)
}
if nn, err := w.w.Write(w.p[:m]); err != nil {
return n + nn, err
}
copy(w.p[:], w.p[m:])
copy(w.p[len(w.p)-m:], p[len(p)-m:])
nn, err := w.w.Write(p[:len(p)-m])
return n + nn, err
}
type flateWriteWrapper struct {
fw *flate.Writer
tw *truncWriter
p *sync.Pool
}
func (w *flateWriteWrapper) Write(p []byte) (int, error) {
if w.fw == nil {
return 0, errWriteClosed
}
return w.fw.Write(p)
}
func (w *flateWriteWrapper) Close() error {
if w.fw == nil {
return errWriteClosed
}
err1 := w.fw.Flush()
w.p.Put(w.fw)
w.fw = nil
if w.tw.p != [4]byte{0, 0, 0xff, 0xff} {
return errors.New("websocket: internal error, unexpected bytes at end of flate stream")
}
err2 := w.tw.w.Close()
if err1 != nil {
return err1
}
return err2
}
type flateReadWrapper struct {
fr io.ReadCloser
}
func (r *flateReadWrapper) Read(p []byte) (int, error) {
if r.fr == nil {
return 0, io.ErrClosedPipe
}
n, err := r.fr.Read(p)
if err == io.EOF {
// Preemptively place the reader back in the pool. This helps with
// scenarios where the application does not call NextReader() soon after
// this final read.
r.Close()
}
return n, err
}
func (r *flateReadWrapper) Close() error {
if r.fr == nil {
return io.ErrClosedPipe
}
err := r.fr.Close()
flateReaderPool.Put(r.fr)
r.fr = nil
return err
}

1168
websocket/conn.go Normal file

File diff suppressed because it is too large Load Diff

18
websocket/conn_read.go Normal file
View File

@ -0,0 +1,18 @@
// Copyright 2016 The Gorilla WebSocket Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
// +build go1.5
package websocket
import "io"
func (c *Conn) read(n int) ([]byte, error) {
p, err := c.br.Peek(n)
if err == io.EOF {
err = errUnexpectedEOF
}
c.br.Discard(len(p))
return p, err
}

55
websocket/json.go Normal file
View File

@ -0,0 +1,55 @@
// Copyright 2013 The Gorilla WebSocket Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package websocket
import (
"encoding/json"
"io"
)
// WriteJSON is deprecated, use c.WriteJSON instead.
func WriteJSON(c *Conn, v interface{}) error {
return c.WriteJSON(v)
}
// WriteJSON writes the JSON encoding of v to the connection.
//
// See the documentation for encoding/json Marshal for details about the
// conversion of Go values to JSON.
func (c *Conn) WriteJSON(v interface{}) error {
w, err := c.NextWriter(TextMessage)
if err != nil {
return err
}
err1 := json.NewEncoder(w).Encode(v)
err2 := w.Close()
if err1 != nil {
return err1
}
return err2
}
// ReadJSON is deprecated, use c.ReadJSON instead.
func ReadJSON(c *Conn, v interface{}) error {
return c.ReadJSON(v)
}
// ReadJSON reads the next JSON-encoded message from the connection and stores
// it in the value pointed to by v.
//
// See the documentation for the encoding/json Unmarshal function for details
// about the conversion of JSON to a Go value.
func (c *Conn) ReadJSON(v interface{}) error {
_, r, err := c.NextReader()
if err != nil {
return err
}
err = json.NewDecoder(r).Decode(v)
if err == io.EOF {
// One value is expected in the message.
err = io.ErrUnexpectedEOF
}
return err
}

55
websocket/mask.go Normal file
View File

@ -0,0 +1,55 @@
// Copyright 2016 The Gorilla WebSocket Authors. All rights reserved. Use of
// this source code is governed by a BSD-style license that can be found in the
// LICENSE file.
// +build !appengine
package websocket
import "unsafe"
const wordSize = int(unsafe.Sizeof(uintptr(0)))
func maskBytes(key [4]byte, pos int, b []byte) int {
// Mask one byte at a time for small buffers.
if len(b) < 2*wordSize {
for i := range b {
b[i] ^= key[pos&3]
pos++
}
return pos & 3
}
// Mask one byte at a time to word boundary.
if n := int(uintptr(unsafe.Pointer(&b[0]))) % wordSize; n != 0 {
n = wordSize - n
for i := range b[:n] {
b[i] ^= key[pos&3]
pos++
}
b = b[n:]
}
// Create aligned word size key.
var k [wordSize]byte
for i := range k {
k[i] = key[(pos+i)&3]
}
kw := *(*uintptr)(unsafe.Pointer(&k))
// Mask one word at a time.
n := (len(b) / wordSize) * wordSize
for i := 0; i < n; i += wordSize {
*(*uintptr)(unsafe.Pointer(uintptr(unsafe.Pointer(&b[0])) + uintptr(i))) ^= kw
}
// Mask one byte at a time for remaining bytes.
b = b[n:]
for i := range b {
b[i] ^= key[pos&3]
pos++
}
return pos & 3
}

103
websocket/prepared.go Normal file
View File

@ -0,0 +1,103 @@
// Copyright 2017 The Gorilla WebSocket Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package websocket
import (
"bytes"
"net"
"sync"
"time"
)
// PreparedMessage caches on the wire representations of a message payload.
// Use PreparedMessage to efficiently send a message payload to multiple
// connections. PreparedMessage is especially useful when compression is used
// because the CPU and memory expensive compression operation can be executed
// once for a given set of compression options.
type PreparedMessage struct {
messageType int
data []byte
err error
mu sync.Mutex
frames map[prepareKey]*preparedFrame
}
// prepareKey defines a unique set of options to cache prepared frames in PreparedMessage.
type prepareKey struct {
isServer bool
compress bool
compressionLevel int
}
// preparedFrame contains data in wire representation.
type preparedFrame struct {
once sync.Once
data []byte
}
// NewPreparedMessage returns an initialized PreparedMessage. You can then send
// it to connection using WritePreparedMessage method. Valid wire
// representation will be calculated lazily only once for a set of current
// connection options.
func NewPreparedMessage(messageType int, data []byte) (*PreparedMessage, error) {
pm := &PreparedMessage{
messageType: messageType,
frames: make(map[prepareKey]*preparedFrame),
data: data,
}
// Prepare a plain server frame.
_, frameData, err := pm.frame(prepareKey{isServer: true, compress: false})
if err != nil {
return nil, err
}
// To protect against caller modifying the data argument, remember the data
// copied to the plain server frame.
pm.data = frameData[len(frameData)-len(data):]
return pm, nil
}
func (pm *PreparedMessage) frame(key prepareKey) (int, []byte, error) {
pm.mu.Lock()
frame, ok := pm.frames[key]
if !ok {
frame = &preparedFrame{}
pm.frames[key] = frame
}
pm.mu.Unlock()
var err error
frame.once.Do(func() {
// Prepare a frame using a 'fake' connection.
// TODO: Refactor code in conn.go to allow more direct construction of
// the frame.
mu := make(chan bool, 1)
mu <- true
var nc prepareConn
c := &Conn{
conn: &nc,
mu: mu,
isServer: key.isServer,
compressionLevel: key.compressionLevel,
enableWriteCompression: true,
writeBuf: make([]byte, defaultWriteBufferSize+maxFrameHeaderSize),
}
if key.compress {
c.newCompressionWriter = compressNoContextTakeover
}
err = c.WriteMessage(pm.messageType, pm.data)
frame.data = nc.buf.Bytes()
})
return pm.messageType, frame.data, err
}
type prepareConn struct {
buf bytes.Buffer
net.Conn
}
func (pc *prepareConn) Write(p []byte) (int, error) { return pc.buf.Write(p) }
func (pc *prepareConn) SetWriteDeadline(t time.Time) error { return nil }

277
websocket/server.go Normal file
View File

@ -0,0 +1,277 @@
// Copyright 2013 The Gorilla WebSocket Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package websocket
import (
"net"
"net/http"
"net/url"
"strings"
"time"
"github.com/valyala/fasthttp"
)
type (
OnUpgradeFunc func(*Conn, error)
)
// HandshakeError describes an error with the handshake from the peer.
type HandshakeError struct {
message string
}
func (e HandshakeError) Error() string { return e.message }
// Upgrader specifies parameters for upgrading an HTTP connection to a
// WebSocket connection.
type Upgrader struct {
// HandshakeTimeout specifies the duration for the handshake to complete.
HandshakeTimeout time.Duration
// ReadBufferSize and WriteBufferSize specify I/O buffer sizes. If a buffer
// size is zero, then buffers allocated by the HTTP server are used. The
// I/O buffer sizes do not limit the size of the messages that can be sent
// or received.
ReadBufferSize, WriteBufferSize int
// Subprotocols specifies the server's supported protocols in order of
// preference. If this field is set, then the Upgrade method negotiates a
// subprotocol by selecting the first match in this list with a protocol
// requested by the client.
Subprotocols []string
// Error specifies the function for generating HTTP error responses. If Error
// is nil, then http.Error is used to generate the HTTP response.
Error func(ctx *fasthttp.RequestCtx, status int, reason error)
// CheckOrigin returns true if the request Origin header is acceptable. If
// CheckOrigin is nil, the host in the Origin header must not be set or
// must match the host of the request.
CheckOrigin func(ctx *fasthttp.RequestCtx) bool
// EnableCompression specify if the server should attempt to negotiate per
// message compression (RFC 7692). Setting this value to true does not
// guarantee that compression will be supported. Currently only "no context
// takeover" modes are supported.
EnableCompression bool
}
func (u *Upgrader) returnError(ctx *fasthttp.RequestCtx, status int, reason string) (*Conn, error) {
err := HandshakeError{reason}
if u.Error != nil {
u.Error(ctx, status, err)
} else {
ctx.Response.Header.Set("Sec-Websocket-Version", "13")
ctx.Error(http.StatusText(status), status)
}
return nil, err
}
// checkSameOrigin returns true if the origin is not set or is equal to the request host.
func checkSameOrigin(ctx *fasthttp.RequestCtx) bool {
origin := string(ctx.Request.Header.Peek("Origin"))
if len(origin) == 0 {
return true
}
u, err := url.Parse(origin)
if err != nil {
return false
}
return u.Host == string(ctx.Host())
}
func (u *Upgrader) selectSubprotocol(ctx *fasthttp.RequestCtx, responseHeader *fasthttp.ResponseHeader) string {
if u.Subprotocols != nil {
clientProtocols := Subprotocols(ctx)
for _, serverProtocol := range u.Subprotocols {
for _, clientProtocol := range clientProtocols {
if clientProtocol == serverProtocol {
return clientProtocol
}
}
}
} else if responseHeader != nil {
return string(responseHeader.Peek("Sec-Websocket-Protocol"))
}
return ""
}
// Upgrade upgrades the HTTP server connection to the WebSocket protocol.
//
// The responseHeader is included in the response to the client's upgrade
// request. Use the responseHeader to specify cookies (Set-Cookie) and the
// application negotiated subprotocol (Sec-Websocket-Protocol).
//
// If the upgrade fails, then Upgrade replies to the client with an HTTP error
// response.
func (u *Upgrader) Upgrade(ctx *fasthttp.RequestCtx, responseHeader *fasthttp.ResponseHeader, cb OnUpgradeFunc) {
if !ctx.IsGet() {
cb(u.returnError(ctx, fasthttp.StatusMethodNotAllowed, "websocket: not a websocket handshake: request method is not GET"))
return
}
if nil != responseHeader {
if v := responseHeader.Peek("Sec-Websocket-Extensions"); nil == v {
cb(u.returnError(ctx, fasthttp.StatusInternalServerError, "websocket: application specific 'Sec-Websocket-Extensions' headers are unsupported"))
return
}
}
if !tokenListContainsValue(&ctx.Request.Header, "Connection", "upgrade") {
cb(u.returnError(ctx, fasthttp.StatusBadRequest, "websocket: not a websocket handshake: 'upgrade' token not found in 'Connection' header"))
return
}
if !tokenListContainsValue(&ctx.Request.Header, "Upgrade", "websocket") {
cb(u.returnError(ctx, fasthttp.StatusBadRequest, "websocket: not a websocket handshake: 'websocket' token not found in 'Upgrade' header"))
return
}
if !tokenListContainsValue(&ctx.Request.Header, "Sec-Websocket-Version", "13") {
cb(u.returnError(ctx, fasthttp.StatusBadRequest, "websocket: unsupported version: 13 not found in 'Sec-Websocket-Version' header"))
return
}
checkOrigin := u.CheckOrigin
if checkOrigin == nil {
checkOrigin = checkSameOrigin
}
if !checkOrigin(ctx) {
cb(u.returnError(ctx, fasthttp.StatusForbidden, "websocket: 'Origin' header value not allowed"))
return
}
challengeKey := string(ctx.Request.Header.Peek("Sec-Websocket-Key"))
if challengeKey == "" {
cb(u.returnError(ctx, fasthttp.StatusBadRequest, "websocket: not a websocket handshake: `Sec-Websocket-Key' header is missing or blank"))
return
}
subprotocol := u.selectSubprotocol(ctx, responseHeader)
// Negotiate PMCE
var compress bool
if u.EnableCompression {
for _, ext := range parseExtensions(&ctx.Request.Header) {
if ext[""] != "permessage-deflate" {
continue
}
compress = true
break
}
}
ctx.SetStatusCode(fasthttp.StatusSwitchingProtocols)
ctx.Response.Header.Set("Upgrade", "websocket")
ctx.Response.Header.Set("Connection", "Upgrade")
ctx.Response.Header.Set("Sec-Websocket-Accept", computeAcceptKey(challengeKey))
if subprotocol != "" {
ctx.Response.Header.Set("Sec-Websocket-Protocol", subprotocol)
}
if compress {
ctx.Response.Header.Set("Sec-Websocket-Extensions", "permessage-deflate; server_no_context_takeover; client_no_context_takeover")
}
if nil != responseHeader {
responseHeader.VisitAll(func(key, value []byte) {
k := string(key)
v := string(value)
if k == "Sec-Websocket-Protocol" {
return
}
ctx.Response.Header.Set(k, v)
})
}
h := &fasthttp.RequestHeader{}
//copy request headers in order to have access inside the Conn after
ctx.Request.Header.CopyTo(h)
ctx.Hijack(func(netConn net.Conn) {
c := newConn(netConn, true, u.ReadBufferSize, u.WriteBufferSize)
c.SetHeaders(h)
c.subprotocol = subprotocol
if compress {
c.newCompressionWriter = compressNoContextTakeover
c.newDecompressionReader = decompressNoContextTakeover
}
// Clear deadlines set by HTTP server.
netConn.SetDeadline(time.Time{})
if u.HandshakeTimeout > 0 {
netConn.SetWriteDeadline(time.Now().Add(u.HandshakeTimeout))
}
if u.HandshakeTimeout > 0 {
netConn.SetWriteDeadline(time.Time{})
}
cb(c, nil)
})
}
// Upgrade upgrades the HTTP server connection to the WebSocket protocol.
//
// This function is deprecated, use websocket.Upgrader instead.
//
// The application is responsible for checking the request origin before
// calling Upgrade. An example implementation of the same origin policy is:
//
// if req.Header.Get("Origin") != "http://"+req.Host {
// http.Error(w, "Origin not allowed", 403)
// return
// }
//
// If the endpoint supports subprotocols, then the application is responsible
// for negotiating the protocol used on the connection. Use the Subprotocols()
// function to get the subprotocols requested by the client. Use the
// Sec-Websocket-Protocol response header to specify the subprotocol selected
// by the application.
//
// The responseHeader is included in the response to the client's upgrade
// request. Use the responseHeader to specify cookies (Set-Cookie) and the
// negotiated subprotocol (Sec-Websocket-Protocol).
//
// The connection buffers IO to the underlying network connection. The
// readBufSize and writeBufSize parameters specify the size of the buffers to
// use. Messages can be larger than the buffers.
//
// If the request is not a valid WebSocket handshake, then Upgrade returns an
// error of type HandshakeError. Applications should handle this error by
// replying to the client with an HTTP error response.
func Upgrade(ctx *fasthttp.RequestCtx, responseHeader *fasthttp.ResponseHeader, readBufSize, writeBufSize int, cb OnUpgradeFunc) {
u := Upgrader{ReadBufferSize: readBufSize, WriteBufferSize: writeBufSize}
u.Error = func(ctx *fasthttp.RequestCtx, status int, reason error) {
// don't return errors to maintain backwards compatibility
}
u.CheckOrigin = func(ctx *fasthttp.RequestCtx) bool {
// allow all connections by default
return true
}
u.Upgrade(ctx, responseHeader, cb)
}
// Subprotocols returns the subprotocols requested by the client in the
// Sec-Websocket-Protocol header.
func Subprotocols(ctx *fasthttp.RequestCtx) []string {
h := strings.TrimSpace(string(ctx.Request.Header.Peek("Sec-Websocket-Protocol")))
if h == "" {
return nil
}
protocols := strings.Split(h, ",")
for i := range protocols {
protocols[i] = strings.TrimSpace(protocols[i])
}
return protocols
}
// IsWebSocketUpgrade returns true if the client requested upgrade to the
// WebSocket protocol.
func IsWebSocketUpgrade(ctx *fasthttp.RequestCtx) bool {
return tokenListContainsValue(&ctx.Request.Header, "Connection", "upgrade") &&
tokenListContainsValue(&ctx.Request.Header, "Upgrade", "websocket")
}

272
websocket/util.go Normal file
View File

@ -0,0 +1,272 @@
// Copyright 2013 The Gorilla WebSocket Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package websocket
import (
"crypto/rand"
"crypto/sha1"
"encoding/base64"
"io"
"net/http"
"strings"
"github.com/valyala/fasthttp"
)
var keyGUID = []byte("258EAFA5-E914-47DA-95CA-C5AB0DC85B11")
func computeAcceptKey(challengeKey string) string {
h := sha1.New()
h.Write([]byte(challengeKey))
h.Write(keyGUID)
return base64.StdEncoding.EncodeToString(h.Sum(nil))
}
func generateChallengeKey() (string, error) {
p := make([]byte, 16)
if _, err := io.ReadFull(rand.Reader, p); err != nil {
return "", err
}
return base64.StdEncoding.EncodeToString(p), nil
}
// Octet types from RFC 2616.
var octetTypes [256]byte
const (
isTokenOctet = 1 << iota
isSpaceOctet
)
func init() {
// From RFC 2616
//
// OCTET = <any 8-bit sequence of data>
// CHAR = <any US-ASCII character (octets 0 - 127)>
// CTL = <any US-ASCII control character (octets 0 - 31) and DEL (127)>
// CR = <US-ASCII CR, carriage return (13)>
// LF = <US-ASCII LF, linefeed (10)>
// SP = <US-ASCII SP, space (32)>
// HT = <US-ASCII HT, horizontal-tab (9)>
// <"> = <US-ASCII double-quote mark (34)>
// CRLF = CR LF
// LWS = [CRLF] 1*( SP | HT )
// TEXT = <any OCTET except CTLs, but including LWS>
// separators = "(" | ")" | "<" | ">" | "@" | "," | ";" | ":" | "\" | <">
// | "/" | "[" | "]" | "?" | "=" | "{" | "}" | SP | HT
// token = 1*<any CHAR except CTLs or separators>
// qdtext = <any TEXT except <">>
for c := 0; c < 256; c++ {
var t byte
isCtl := c <= 31 || c == 127
isChar := 0 <= c && c <= 127
isSeparator := strings.IndexRune(" \t\"(),/:;<=>?@[]\\{}", rune(c)) >= 0
if strings.IndexRune(" \t\r\n", rune(c)) >= 0 {
t |= isSpaceOctet
}
if isChar && !isCtl && !isSeparator {
t |= isTokenOctet
}
octetTypes[c] = t
}
}
func skipSpace(s string) (rest string) {
i := 0
for ; i < len(s); i++ {
if octetTypes[s[i]]&isSpaceOctet == 0 {
break
}
}
return s[i:]
}
func nextToken(s string) (token string, rest string) {
i := 0
for ; i < len(s); i++ {
if octetTypes[s[i]]&isTokenOctet == 0 {
break
}
}
return s[:i], s[i:]
}
func nextTokenOrQuoted(s string) (value string, rest string) {
if !strings.HasPrefix(s, "\"") {
return nextToken(s)
}
s = s[1:]
for i := 0; i < len(s); i++ {
switch s[i] {
case '"':
return s[:i], s[i+1:]
case '\\':
p := make([]byte, len(s)-1)
j := copy(p, s[:i])
escape := true
for i = i + 1; i < len(s); i++ {
b := s[i]
switch {
case escape:
escape = false
p[j] = b
j += 1
case b == '\\':
escape = true
case b == '"':
return string(p[:j]), s[i+1:]
default:
p[j] = b
j += 1
}
}
return "", ""
}
}
return "", ""
}
// tokenListContainsValue returns true if the 1#token header with the given
// name contains token.
func tokenListContainsValue(header *fasthttp.RequestHeader, name string, value string) bool {
s := string(header.Peek(name))
for {
var t string
t, s = nextToken(skipSpace(s))
if t == "" {
break
}
s = skipSpace(s)
if s != "" && s[0] != ',' {
break
}
if strings.EqualFold(t, value) {
return true
}
if s == "" {
break
}
s = s[1:]
}
return false
}
// parseExtensiosn parses WebSocket extensions from a header.
func parseExtensions(header *fasthttp.RequestHeader) []map[string]string {
// From RFC 6455:
//
// Sec-WebSocket-Extensions = extension-list
// extension-list = 1#extension
// extension = extension-token *( ";" extension-param )
// extension-token = registered-token
// registered-token = token
// extension-param = token [ "=" (token | quoted-string) ]
// ;When using the quoted-string syntax variant, the value
// ;after quoted-string unescaping MUST conform to the
// ;'token' ABNF.
s := string(header.Peek("Sec-Websocket-Extensions"))
var result []map[string]string
headers:
for {
var t string
t, s = nextToken(skipSpace(s))
if t == "" {
break headers
}
ext := map[string]string{"": t}
for {
s = skipSpace(s)
if !strings.HasPrefix(s, ";") {
break
}
var k string
k, s = nextToken(skipSpace(s[1:]))
if k == "" {
break headers
}
s = skipSpace(s)
var v string
if strings.HasPrefix(s, "=") {
v, s = nextTokenOrQuoted(skipSpace(s[1:]))
s = skipSpace(s)
}
if s != "" && s[0] != ',' && s[0] != ';' {
break headers
}
ext[k] = v
}
if s != "" && s[0] != ',' {
break headers
}
result = append(result, ext)
if s == "" {
break headers
}
s = s[1:]
}
return result
}
// parseExtensiosn parses WebSocket extensions from a header.
func httpParseExtensions(header http.Header) []map[string]string {
// From RFC 6455:
//
// Sec-WebSocket-Extensions = extension-list
// extension-list = 1#extension
// extension = extension-token *( ";" extension-param )
// extension-token = registered-token
// registered-token = token
// extension-param = token [ "=" (token | quoted-string) ]
// ;When using the quoted-string syntax variant, the value
// ;after quoted-string unescaping MUST conform to the
// ;'token' ABNF.
var result []map[string]string
headers:
for _, s := range header["Sec-Websocket-Extensions"] {
for {
var t string
t, s = nextToken(skipSpace(s))
if t == "" {
continue headers
}
ext := map[string]string{"": t}
for {
s = skipSpace(s)
if !strings.HasPrefix(s, ";") {
break
}
var k string
k, s = nextToken(skipSpace(s[1:]))
if k == "" {
continue headers
}
s = skipSpace(s)
var v string
if strings.HasPrefix(s, "=") {
v, s = nextTokenOrQuoted(skipSpace(s[1:]))
s = skipSpace(s)
}
if s != "" && s[0] != ',' && s[0] != ';' {
continue headers
}
ext[k] = v
}
if s != "" && s[0] != ',' {
continue headers
}
result = append(result, ext)
if s == "" {
continue headers
}
s = s[1:]
}
}
return result
}