// Copyright 2013 The Gorilla WebSocket Authors. All rights reserved. // Use of this source code is governed by a BSD-style // license that can be found in the LICENSE file. package websocket import ( "net" "net/http" "net/url" "strings" "time" "github.com/valyala/fasthttp" ) type ( OnUpgradeFunc func(*Conn, error) ) // HandshakeError describes an error with the handshake from the peer. type HandshakeError struct { message string } func (e HandshakeError) Error() string { return e.message } // Upgrader specifies parameters for upgrading an HTTP connection to a // WebSocket connection. type Upgrader struct { // HandshakeTimeout specifies the duration for the handshake to complete. HandshakeTimeout time.Duration // ReadBufferSize and WriteBufferSize specify I/O buffer sizes. If a buffer // size is zero, then buffers allocated by the HTTP server are used. The // I/O buffer sizes do not limit the size of the messages that can be sent // or received. ReadBufferSize, WriteBufferSize int // Subprotocols specifies the server's supported protocols in order of // preference. If this field is set, then the Upgrade method negotiates a // subprotocol by selecting the first match in this list with a protocol // requested by the client. Subprotocols []string // Error specifies the function for generating HTTP error responses. If Error // is nil, then http.Error is used to generate the HTTP response. Error func(ctx *fasthttp.RequestCtx, status int, reason error) // CheckOrigin returns true if the request Origin header is acceptable. If // CheckOrigin is nil, the host in the Origin header must not be set or // must match the host of the request. CheckOrigin func(ctx *fasthttp.RequestCtx) bool // EnableCompression specify if the server should attempt to negotiate per // message compression (RFC 7692). Setting this value to true does not // guarantee that compression will be supported. Currently only "no context // takeover" modes are supported. EnableCompression bool } func (u *Upgrader) returnError(ctx *fasthttp.RequestCtx, status int, reason string) (*Conn, error) { err := HandshakeError{reason} if u.Error != nil { u.Error(ctx, status, err) } else { ctx.Response.Header.Set("Sec-Websocket-Version", "13") ctx.Error(http.StatusText(status), status) } return nil, err } // checkSameOrigin returns true if the origin is not set or is equal to the request host. func checkSameOrigin(ctx *fasthttp.RequestCtx) bool { origin := string(ctx.Request.Header.Peek("Origin")) if len(origin) == 0 { return true } u, err := url.Parse(origin) if err != nil { return false } return u.Host == string(ctx.Host()) } func (u *Upgrader) selectSubprotocol(ctx *fasthttp.RequestCtx, responseHeader *fasthttp.ResponseHeader) string { if u.Subprotocols != nil { clientProtocols := Subprotocols(ctx) for _, serverProtocol := range u.Subprotocols { for _, clientProtocol := range clientProtocols { if clientProtocol == serverProtocol { return clientProtocol } } } } else if responseHeader != nil { return string(responseHeader.Peek("Sec-Websocket-Protocol")) } return "" } // Upgrade upgrades the HTTP server connection to the WebSocket protocol. // // The responseHeader is included in the response to the client's upgrade // request. Use the responseHeader to specify cookies (Set-Cookie) and the // application negotiated subprotocol (Sec-Websocket-Protocol). // // If the upgrade fails, then Upgrade replies to the client with an HTTP error // response. func (u *Upgrader) Upgrade(ctx *fasthttp.RequestCtx, responseHeader *fasthttp.ResponseHeader, cb OnUpgradeFunc) { if !ctx.IsGet() { cb(u.returnError(ctx, fasthttp.StatusMethodNotAllowed, "websocket: not a websocket handshake: request method is not GET")) return } if nil != responseHeader { if v := responseHeader.Peek("Sec-Websocket-Extensions"); nil == v { cb(u.returnError(ctx, fasthttp.StatusInternalServerError, "websocket: application specific 'Sec-Websocket-Extensions' headers are unsupported")) return } } if !tokenListContainsValue(&ctx.Request.Header, "Connection", "upgrade") { cb(u.returnError(ctx, fasthttp.StatusBadRequest, "websocket: not a websocket handshake: 'upgrade' token not found in 'Connection' header")) return } if !tokenListContainsValue(&ctx.Request.Header, "Upgrade", "websocket") { cb(u.returnError(ctx, fasthttp.StatusBadRequest, "websocket: not a websocket handshake: 'websocket' token not found in 'Upgrade' header")) return } if !tokenListContainsValue(&ctx.Request.Header, "Sec-Websocket-Version", "13") { cb(u.returnError(ctx, fasthttp.StatusBadRequest, "websocket: unsupported version: 13 not found in 'Sec-Websocket-Version' header")) return } checkOrigin := u.CheckOrigin if checkOrigin == nil { checkOrigin = checkSameOrigin } if !checkOrigin(ctx) { cb(u.returnError(ctx, fasthttp.StatusForbidden, "websocket: 'Origin' header value not allowed")) return } challengeKey := string(ctx.Request.Header.Peek("Sec-Websocket-Key")) if challengeKey == "" { cb(u.returnError(ctx, fasthttp.StatusBadRequest, "websocket: not a websocket handshake: `Sec-Websocket-Key' header is missing or blank")) return } subprotocol := u.selectSubprotocol(ctx, responseHeader) // Negotiate PMCE var compress bool if u.EnableCompression { for _, ext := range parseExtensions(&ctx.Request.Header) { if ext[""] != "permessage-deflate" { continue } compress = true break } } ctx.SetStatusCode(fasthttp.StatusSwitchingProtocols) ctx.Response.Header.Set("Upgrade", "websocket") ctx.Response.Header.Set("Connection", "Upgrade") ctx.Response.Header.Set("Sec-Websocket-Accept", computeAcceptKey(challengeKey)) if subprotocol != "" { ctx.Response.Header.Set("Sec-Websocket-Protocol", subprotocol) } if compress { ctx.Response.Header.Set("Sec-Websocket-Extensions", "permessage-deflate; server_no_context_takeover; client_no_context_takeover") } if nil != responseHeader { responseHeader.VisitAll(func(key, value []byte) { k := string(key) v := string(value) if k == "Sec-Websocket-Protocol" { return } ctx.Response.Header.Set(k, v) }) } h := &fasthttp.RequestHeader{} //copy request headers in order to have access inside the Conn after ctx.Request.Header.CopyTo(h) ctx.Hijack(func(netConn net.Conn) { c := newConn(netConn, true, u.ReadBufferSize, u.WriteBufferSize) c.SetHeaders(h) c.subprotocol = subprotocol if compress { c.newCompressionWriter = compressNoContextTakeover c.newDecompressionReader = decompressNoContextTakeover } // Clear deadlines set by HTTP server. netConn.SetDeadline(time.Time{}) if u.HandshakeTimeout > 0 { netConn.SetWriteDeadline(time.Now().Add(u.HandshakeTimeout)) } if u.HandshakeTimeout > 0 { netConn.SetWriteDeadline(time.Time{}) } cb(c, nil) }) } // Upgrade upgrades the HTTP server connection to the WebSocket protocol. // // This function is deprecated, use websocket.Upgrader instead. // // The application is responsible for checking the request origin before // calling Upgrade. An example implementation of the same origin policy is: // // if req.Header.Get("Origin") != "http://"+req.Host { // http.Error(w, "Origin not allowed", 403) // return // } // // If the endpoint supports subprotocols, then the application is responsible // for negotiating the protocol used on the connection. Use the Subprotocols() // function to get the subprotocols requested by the client. Use the // Sec-Websocket-Protocol response header to specify the subprotocol selected // by the application. // // The responseHeader is included in the response to the client's upgrade // request. Use the responseHeader to specify cookies (Set-Cookie) and the // negotiated subprotocol (Sec-Websocket-Protocol). // // The connection buffers IO to the underlying network connection. The // readBufSize and writeBufSize parameters specify the size of the buffers to // use. Messages can be larger than the buffers. // // If the request is not a valid WebSocket handshake, then Upgrade returns an // error of type HandshakeError. Applications should handle this error by // replying to the client with an HTTP error response. func Upgrade(ctx *fasthttp.RequestCtx, responseHeader *fasthttp.ResponseHeader, readBufSize, writeBufSize int, cb OnUpgradeFunc) { u := Upgrader{ReadBufferSize: readBufSize, WriteBufferSize: writeBufSize} u.Error = func(ctx *fasthttp.RequestCtx, status int, reason error) { // don't return errors to maintain backwards compatibility } u.CheckOrigin = func(ctx *fasthttp.RequestCtx) bool { // allow all connections by default return true } u.Upgrade(ctx, responseHeader, cb) } // Subprotocols returns the subprotocols requested by the client in the // Sec-Websocket-Protocol header. func Subprotocols(ctx *fasthttp.RequestCtx) []string { h := strings.TrimSpace(string(ctx.Request.Header.Peek("Sec-Websocket-Protocol"))) if h == "" { return nil } protocols := strings.Split(h, ",") for i := range protocols { protocols[i] = strings.TrimSpace(protocols[i]) } return protocols } // IsWebSocketUpgrade returns true if the client requested upgrade to the // WebSocket protocol. func IsWebSocketUpgrade(ctx *fasthttp.RequestCtx) bool { return tokenListContainsValue(&ctx.Request.Header, "Connection", "upgrade") && tokenListContainsValue(&ctx.Request.Header, "Upgrade", "websocket") }