diff --git a/clients/client.go b/clients/client.go index 8eca96d..365f652 100644 --- a/clients/client.go +++ b/clients/client.go @@ -58,9 +58,9 @@ func (c *client) RequestCtx() *fasthttp.RequestCtx { func (c *client) run() { hasReadTimeout := c.o.ReadTimeout > 0 c.conn.SetReadLimit(c.o.MaxMessageSize) - defer func() { - c.o.OnDisconnected(c) - }() + // defer func() { + // c.o.OnDisconnected(c) + // }() for { if hasReadTimeout { diff --git a/server.go b/server.go index 46f1164..ce845e5 100644 --- a/server.go +++ b/server.go @@ -33,6 +33,7 @@ func NewServer(o *Options) Server { ReadBufferSize: s._option.ReadBufferSize, WriteBufferSize: s._option.WriteBufferSize, HandshakeTimeout: s._option.HandshakeTimeout, + CheckOrigin: s._option.OnCheckOrigin, } return s @@ -49,12 +50,13 @@ func (s *server) onDisconnected(c clients.Client) { } func (s *server) onConnection(ctx *fasthttp.RequestCtx) { + path := string(ctx.Path()) + s._upgrader.Upgrade(ctx, nil, func(conn *websocket.Conn, err error) { if err != nil { log.Print("upgrade:", err) return } - path := string(ctx.Path()) co, ok := s._handlers[path] if !ok { log.Printf("Path[%s] is not exist.", path) diff --git a/websocket/server.go b/websocket/server.go index 5ec3fa5..00e0fd6 100644 --- a/websocket/server.go +++ b/websocket/server.go @@ -113,9 +113,11 @@ func (u *Upgrader) Upgrade(ctx *fasthttp.RequestCtx, responseHeader *fasthttp.Re return } - if v := responseHeader.Peek("Sec-Websocket-Extensions"); nil == v { - cb(u.returnError(ctx, fasthttp.StatusInternalServerError, "websocket: application specific 'Sec-Websocket-Extensions' headers are unsupported")) - return + if nil != responseHeader { + if v := responseHeader.Peek("Sec-Websocket-Extensions"); nil == v { + cb(u.returnError(ctx, fasthttp.StatusInternalServerError, "websocket: application specific 'Sec-Websocket-Extensions' headers are unsupported")) + return + } } if !tokenListContainsValue(&ctx.Request.Header, "Connection", "upgrade") { @@ -162,50 +164,39 @@ func (u *Upgrader) Upgrade(ctx *fasthttp.RequestCtx, responseHeader *fasthttp.Re } } - ctx.Hijack(func(netConn net.Conn) { - var err error - c := newConn(netConn, true, u.ReadBufferSize, u.WriteBufferSize) - c.subprotocol = subprotocol - if compress { - c.newCompressionWriter = compressNoContextTakeover - c.newDecompressionReader = decompressNoContextTakeover - } - - p := c.writeBuf[:0] - p = append(p, "HTTP/1.1 101 Switching Protocols\r\nUpgrade: websocket\r\nConnection: Upgrade\r\nSec-WebSocket-Accept: "...) - p = append(p, computeAcceptKey(challengeKey)...) - p = append(p, "\r\n"...) - if c.subprotocol != "" { - p = append(p, "Sec-Websocket-Protocol: "...) - p = append(p, c.subprotocol...) - p = append(p, "\r\n"...) - } - if compress { - p = append(p, "Sec-Websocket-Extensions: permessage-deflate; server_no_context_takeover; client_no_context_takeover\r\n"...) - } - + ctx.SetStatusCode(fasthttp.StatusSwitchingProtocols) + ctx.Response.Header.Set("Upgrade", "websocket") + ctx.Response.Header.Set("Connection", "Upgrade") + ctx.Response.Header.Set("Sec-Websocket-Accept", computeAcceptKey(challengeKey)) + if subprotocol != "" { + ctx.Response.Header.Set("Sec-Websocket-Protocol", subprotocol) + } + if compress { + ctx.Response.Header.Set("Sec-Websocket-Extensions", "permessage-deflate; server_no_context_takeover; client_no_context_takeover") + } + if nil != responseHeader { responseHeader.VisitAll(func(key, value []byte) { k := string(key) v := string(value) if k == "Sec-Websocket-Protocol" { return } - - p = append(p, k...) - p = append(p, ": "...) - for i := 0; i < len(v); i++ { - b := v[i] - if b <= 31 { - // prevent response splitting. - b = ' ' - } - p = append(p, b) - } - p = append(p, "\r\n"...) - + ctx.Response.Header.Set(k, v) }) + } - p = append(p, "\r\n"...) + h := &fasthttp.RequestHeader{} + //copy request headers in order to have access inside the Conn after + ctx.Request.Header.CopyTo(h) + + ctx.Hijack(func(netConn net.Conn) { + c := newConn(netConn, true, u.ReadBufferSize, u.WriteBufferSize) + c.SetHeaders(h) + c.subprotocol = subprotocol + if compress { + c.newCompressionWriter = compressNoContextTakeover + c.newDecompressionReader = decompressNoContextTakeover + } // Clear deadlines set by HTTP server. netConn.SetDeadline(time.Time{}) @@ -213,11 +204,7 @@ func (u *Upgrader) Upgrade(ctx *fasthttp.RequestCtx, responseHeader *fasthttp.Re if u.HandshakeTimeout > 0 { netConn.SetWriteDeadline(time.Now().Add(u.HandshakeTimeout)) } - if _, err = netConn.Write(p); err != nil { - netConn.Close() - cb(nil, err) - return - } + if u.HandshakeTimeout > 0 { netConn.SetWriteDeadline(time.Time{}) }