This commit is contained in:
crusader 2017-08-24 23:50:14 +09:00
parent 2ed09895ff
commit 67dbc5cd28
3 changed files with 37 additions and 48 deletions

View File

@ -58,9 +58,9 @@ func (c *client) RequestCtx() *fasthttp.RequestCtx {
func (c *client) run() {
hasReadTimeout := c.o.ReadTimeout > 0
c.conn.SetReadLimit(c.o.MaxMessageSize)
defer func() {
c.o.OnDisconnected(c)
}()
// defer func() {
// c.o.OnDisconnected(c)
// }()
for {
if hasReadTimeout {

View File

@ -33,6 +33,7 @@ func NewServer(o *Options) Server {
ReadBufferSize: s._option.ReadBufferSize,
WriteBufferSize: s._option.WriteBufferSize,
HandshakeTimeout: s._option.HandshakeTimeout,
CheckOrigin: s._option.OnCheckOrigin,
}
return s
@ -49,12 +50,13 @@ func (s *server) onDisconnected(c clients.Client) {
}
func (s *server) onConnection(ctx *fasthttp.RequestCtx) {
path := string(ctx.Path())
s._upgrader.Upgrade(ctx, nil, func(conn *websocket.Conn, err error) {
if err != nil {
log.Print("upgrade:", err)
return
}
path := string(ctx.Path())
co, ok := s._handlers[path]
if !ok {
log.Printf("Path[%s] is not exist.", path)

View File

@ -113,9 +113,11 @@ func (u *Upgrader) Upgrade(ctx *fasthttp.RequestCtx, responseHeader *fasthttp.Re
return
}
if v := responseHeader.Peek("Sec-Websocket-Extensions"); nil == v {
cb(u.returnError(ctx, fasthttp.StatusInternalServerError, "websocket: application specific 'Sec-Websocket-Extensions' headers are unsupported"))
return
if nil != responseHeader {
if v := responseHeader.Peek("Sec-Websocket-Extensions"); nil == v {
cb(u.returnError(ctx, fasthttp.StatusInternalServerError, "websocket: application specific 'Sec-Websocket-Extensions' headers are unsupported"))
return
}
}
if !tokenListContainsValue(&ctx.Request.Header, "Connection", "upgrade") {
@ -162,50 +164,39 @@ func (u *Upgrader) Upgrade(ctx *fasthttp.RequestCtx, responseHeader *fasthttp.Re
}
}
ctx.Hijack(func(netConn net.Conn) {
var err error
c := newConn(netConn, true, u.ReadBufferSize, u.WriteBufferSize)
c.subprotocol = subprotocol
if compress {
c.newCompressionWriter = compressNoContextTakeover
c.newDecompressionReader = decompressNoContextTakeover
}
p := c.writeBuf[:0]
p = append(p, "HTTP/1.1 101 Switching Protocols\r\nUpgrade: websocket\r\nConnection: Upgrade\r\nSec-WebSocket-Accept: "...)
p = append(p, computeAcceptKey(challengeKey)...)
p = append(p, "\r\n"...)
if c.subprotocol != "" {
p = append(p, "Sec-Websocket-Protocol: "...)
p = append(p, c.subprotocol...)
p = append(p, "\r\n"...)
}
if compress {
p = append(p, "Sec-Websocket-Extensions: permessage-deflate; server_no_context_takeover; client_no_context_takeover\r\n"...)
}
ctx.SetStatusCode(fasthttp.StatusSwitchingProtocols)
ctx.Response.Header.Set("Upgrade", "websocket")
ctx.Response.Header.Set("Connection", "Upgrade")
ctx.Response.Header.Set("Sec-Websocket-Accept", computeAcceptKey(challengeKey))
if subprotocol != "" {
ctx.Response.Header.Set("Sec-Websocket-Protocol", subprotocol)
}
if compress {
ctx.Response.Header.Set("Sec-Websocket-Extensions", "permessage-deflate; server_no_context_takeover; client_no_context_takeover")
}
if nil != responseHeader {
responseHeader.VisitAll(func(key, value []byte) {
k := string(key)
v := string(value)
if k == "Sec-Websocket-Protocol" {
return
}
p = append(p, k...)
p = append(p, ": "...)
for i := 0; i < len(v); i++ {
b := v[i]
if b <= 31 {
// prevent response splitting.
b = ' '
}
p = append(p, b)
}
p = append(p, "\r\n"...)
ctx.Response.Header.Set(k, v)
})
}
p = append(p, "\r\n"...)
h := &fasthttp.RequestHeader{}
//copy request headers in order to have access inside the Conn after
ctx.Request.Header.CopyTo(h)
ctx.Hijack(func(netConn net.Conn) {
c := newConn(netConn, true, u.ReadBufferSize, u.WriteBufferSize)
c.SetHeaders(h)
c.subprotocol = subprotocol
if compress {
c.newCompressionWriter = compressNoContextTakeover
c.newDecompressionReader = decompressNoContextTakeover
}
// Clear deadlines set by HTTP server.
netConn.SetDeadline(time.Time{})
@ -213,11 +204,7 @@ func (u *Upgrader) Upgrade(ctx *fasthttp.RequestCtx, responseHeader *fasthttp.Re
if u.HandshakeTimeout > 0 {
netConn.SetWriteDeadline(time.Now().Add(u.HandshakeTimeout))
}
if _, err = netConn.Write(p); err != nil {
netConn.Close()
cb(nil, err)
return
}
if u.HandshakeTimeout > 0 {
netConn.SetWriteDeadline(time.Time{})
}