diff --git a/const.go b/const.go index d501ca6..6613aab 100644 --- a/const.go +++ b/const.go @@ -33,4 +33,6 @@ const ( DefaultReconnectInterval = 5 * time.Second DefaultReconnectTryTime = 10 + + DefaultCompressionThreshold = 1024 ) diff --git a/socket/compression-handler.go b/socket/compression-handler.go new file mode 100644 index 0000000..3267d8b --- /dev/null +++ b/socket/compression-handler.go @@ -0,0 +1,58 @@ +package socket + +import ( + "errors" + "sync/atomic" + + "git.loafle.net/commons/server-go" +) + +type CompressionHandler interface { + IsEnableCompression() bool + GetCompressionLevel() int + GetCompressionThreshold() int +} + +type CompressionHandlers struct { + EnableCompression bool `json:"enableCompression,omitempty"` + CompressionLevel int `json:"compressionLevel,omitempty"` + CompressionThreshold int `json:"compressionThreshold,omitempty"` + + validated atomic.Value +} + +func (ch *CompressionHandlers) IsEnableCompression() bool { + return ch.EnableCompression +} +func (ch *CompressionHandlers) GetCompressionLevel() int { + return ch.CompressionLevel +} +func (ch *CompressionHandlers) GetCompressionThreshold() int { + return ch.CompressionThreshold +} + +func (ch *CompressionHandlers) Clone() *CompressionHandlers { + return &CompressionHandlers{ + EnableCompression: ch.EnableCompression, + CompressionLevel: ch.CompressionLevel, + CompressionThreshold: ch.CompressionThreshold, + validated: ch.validated, + } +} + +func (ch *CompressionHandlers) Validate() error { + if nil != ch.validated.Load() { + return nil + } + ch.validated.Store(true) + + if !IsValidCompressionLevel(ch.CompressionLevel) { + return errors.New("Socket: invalid compression level") + } + + if ch.CompressionThreshold <= 0 { + ch.CompressionThreshold = server.DefaultCompressionThreshold + } + + return nil +} diff --git a/socket/conn-compression.go b/socket/conn-compression.go index a915899..acc279f 100644 --- a/socket/conn-compression.go +++ b/socket/conn-compression.go @@ -9,9 +9,10 @@ import ( ) const ( - minCompressionLevel = -2 // flate.HuffmanOnly not defined in Go < 1.6 - maxCompressionLevel = flate.BestCompression - defaultCompressionLevel = 1 + minCompressionLevel = -2 // flate.HuffmanOnly not defined in Go < 1.6 + maxCompressionLevel = flate.BestCompression + defaultCompressionLevel = 1 + defaultCompressionThreshold = 1024 ) var ( @@ -33,7 +34,7 @@ func DecompressNoContextTakeover(r io.Reader) io.ReadCloser { return &flateReadWrapper{fr} } -func isValidCompressionLevel(level int) bool { +func IsValidCompressionLevel(level int) bool { return minCompressionLevel <= level && level <= maxCompressionLevel } @@ -46,6 +47,7 @@ func CompressNoContextTakeover(w io.WriteCloser, level int) io.WriteCloser { } else { fw.Reset(tw) } + return &flateWriteWrapper{fw: fw, tw: tw, p: p} } @@ -95,6 +97,7 @@ func (w *flateWriteWrapper) Write(p []byte) (int, error) { if w.fw == nil { return 0, errWriteClosed } + return w.fw.Write(p) } diff --git a/socket/conn-json.go b/socket/conn-json.go index 5683eda..e89edd1 100644 --- a/socket/conn-json.go +++ b/socket/conn-json.go @@ -19,7 +19,7 @@ func WriteJSON(c *SocketConn, v interface{}) error { // See the documentation for encoding/json Marshal for details about the // conversion of Go values to JSON. func (c *SocketConn) WriteJSON(v interface{}) error { - w, err := c.NextWriter(TextMessage) + w, err := c.NextWriter(TextMessage, true) if err != nil { return err } diff --git a/socket/conn-prepared.go b/socket/conn-prepared.go index 30d2dff..6cf1fab 100644 --- a/socket/conn-prepared.go +++ b/socket/conn-prepared.go @@ -22,9 +22,10 @@ type PreparedMessage struct { // prepareKey defines a unique set of options to cache prepared frames in PreparedMessage. type prepareKey struct { - isServer bool - compress bool - compressionLevel int + isServer bool + compress bool + compressionLevel int + compressionThreshold int } // preparedFrame contains data in wire representation. @@ -78,6 +79,7 @@ func (pm *PreparedMessage) frame(key prepareKey) (int, []byte, error) { mu: mu, isServer: key.isServer, compressionLevel: key.compressionLevel, + compressionThreshold: key.compressionThreshold, enableWriteCompression: true, writeBuf: make([]byte, defaultWriteBufferSize+maxFrameHeaderSize), } diff --git a/socket/conn.go b/socket/conn.go index 41548e0..b60e535 100644 --- a/socket/conn.go +++ b/socket/conn.go @@ -126,7 +126,7 @@ type Conn interface { EnableWriteCompression(enable bool) LocalAddr() net.Addr NextReader() (messageType int, r io.Reader, err error) - NextWriter(messageType int) (io.WriteCloser, error) + NextWriter(messageType int, useCompress bool) (io.WriteCloser, error) PingHandler() func(appData string) error PongHandler() func(appData string) error ReadJSON(v interface{}) error @@ -134,6 +134,7 @@ type Conn interface { RemoteAddr() net.Addr SetCloseHandler(h func(code int, text string) error) SetCompressionLevel(level int) error + SetCompressionThreshold(threshold int) SetPingHandler(h func(appData string) error) SetPongHandler(h func(appData string) error) SetReadDeadline(t time.Time) error @@ -144,6 +145,7 @@ type Conn interface { WriteControl(messageType int, data []byte, deadline time.Time) error WriteJSON(v interface{}) error WriteMessage(messageType int, data []byte) error + WriteCompress(messageType int, data []byte) error WritePreparedMessage(pm *PreparedMessage) error } @@ -165,6 +167,7 @@ type SocketConn struct { enableWriteCompression bool compressionLevel int + compressionThreshold int newCompressionWriter func(io.WriteCloser, int) io.WriteCloser // Read fields @@ -254,6 +257,7 @@ func newConnBRW(conn net.Conn, isServer bool, readBufferSize, writeBufferSize in writeBuf: writeBuf, enableWriteCompression: true, compressionLevel: defaultCompressionLevel, + compressionThreshold: defaultCompressionThreshold, } c.SetCloseHandler(nil) c.SetPingHandler(nil) @@ -417,7 +421,31 @@ func (c *SocketConn) prepWrite(messageType int) error { // // There can be at most one open writer on a connection. NextWriter closes the // previous writer if the application has not already done so. -func (c *SocketConn) NextWriter(messageType int) (io.WriteCloser, error) { +// func (c *SocketConn) NextWriter(messageType int) (io.WriteCloser, error) { +// if err := c.prepWrite(messageType); err != nil { +// return nil, err +// } + +// mw := &messageWriter{ +// c: c, +// frameType: messageType, +// pos: maxFrameHeaderSize, +// } +// c.writer = mw +// if c.newCompressionWriter != nil && c.enableWriteCompression && isData(messageType) { +// w := c.newCompressionWriter(c.writer, c.compressionLevel) +// mw.compress = true +// c.writer = w +// } +// return c.writer, nil +// } + +// NextWriterWithUseCompress returns a writer for the next message to send. The writer's Close +// method flushes the complete message to the network. +// +// There can be at most one open writer on a connection. NextWriter closes the +// previous writer if the application has not already done so. +func (c *SocketConn) NextWriter(messageType int, useCompress bool) (io.WriteCloser, error) { if err := c.prepWrite(messageType); err != nil { return nil, err } @@ -428,7 +456,7 @@ func (c *SocketConn) NextWriter(messageType int) (io.WriteCloser, error) { pos: maxFrameHeaderSize, } c.writer = mw - if c.newCompressionWriter != nil && c.enableWriteCompression && isData(messageType) { + if useCompress && c.newCompressionWriter != nil && c.enableWriteCompression && isData(messageType) { w := c.newCompressionWriter(c.writer, c.compressionLevel) mw.compress = true c.writer = w @@ -639,9 +667,10 @@ func (w *messageWriter) Close() error { // WritePreparedMessage writes prepared message into connection. func (c *SocketConn) WritePreparedMessage(pm *PreparedMessage) error { frameType, frameData, err := pm.frame(prepareKey{ - isServer: c.isServer, - compress: c.newCompressionWriter != nil && c.enableWriteCompression && isData(pm.messageType), - compressionLevel: c.compressionLevel, + isServer: c.isServer, + compress: c.newCompressionWriter != nil && c.enableWriteCompression && isData(pm.messageType), + compressionLevel: c.compressionLevel, + compressionThreshold: c.compressionThreshold, }) if err != nil { return err @@ -675,7 +704,43 @@ func (c *SocketConn) WriteMessage(messageType int, data []byte) error { return mw.flushFrame(true, data) } - w, err := c.NextWriter(messageType) + w, err := c.NextWriter(messageType, false) + if err != nil { + return err + } + if _, err = w.Write(data); err != nil { + return err + } + return w.Close() +} + +// WriteMessage is a helper method for getting a writer using NextWriter, +// writing the message and closing the writer. +func (c *SocketConn) WriteCompress(messageType int, data []byte) error { + + if c.isServer && (c.newCompressionWriter == nil || !c.enableWriteCompression || len(data) <= c.compressionThreshold) { + // Fast path with no allocations and single frame. + + if err := c.prepWrite(messageType); err != nil { + return err + } + mw := messageWriter{c: c, frameType: messageType, pos: maxFrameHeaderSize} + n := copy(c.writeBuf[mw.pos:], data) + mw.pos += n + data = data[n:] + return mw.flushFrame(true, data) + } + + var w io.WriteCloser + var err error + length := len(data) + switch { + case length > c.compressionThreshold: + w, err = c.NextWriter(messageType, true) + default: + w, err = c.NextWriter(messageType, false) + } + if err != nil { return err } @@ -1066,13 +1131,17 @@ func (c *SocketConn) EnableWriteCompression(enable bool) { // with the peer. See the compress/flate package for a description of // compression levels. func (c *SocketConn) SetCompressionLevel(level int) error { - if !isValidCompressionLevel(level) { + if !IsValidCompressionLevel(level) { return errors.New("websocket: invalid compression level") } c.compressionLevel = level return nil } +func (c *SocketConn) SetCompressionThreshold(threshold int) { + c.compressionThreshold = threshold +} + // FormatCloseMessage formats closeCode and text as a WebSocket close message. func FormatCloseMessage(closeCode int, text string) []byte { buf := make([]byte, 2+len(text)) diff --git a/socket/readwrite-handler.go b/socket/readwrite-handler.go index aa20b22..64d1c2e 100644 --- a/socket/readwrite-handler.go +++ b/socket/readwrite-handler.go @@ -12,8 +12,6 @@ type ReadWriteHandler interface { GetPongTimeout() time.Duration GetPingTimeout() time.Duration GetPingPeriod() time.Duration - - IsEnableCompression() bool } type ReadWriteHandlers struct { @@ -23,8 +21,6 @@ type ReadWriteHandlers struct { PingTimeout time.Duration `json:"pingTimeout,omitempty"` PingPeriod time.Duration `json:"pingPeriod,omitempty"` - EnableCompression bool `json:"enableCompression,omitempty"` - validated atomic.Value } @@ -37,9 +33,6 @@ func (rwh *ReadWriteHandlers) GetPingTimeout() time.Duration { func (rwh *ReadWriteHandlers) GetPingPeriod() time.Duration { return rwh.PingPeriod } -func (rwh *ReadWriteHandlers) IsEnableCompression() bool { - return rwh.EnableCompression -} func (rwh *ReadWriteHandlers) Clone() *ReadWriteHandlers { return &ReadWriteHandlers{ @@ -47,7 +40,6 @@ func (rwh *ReadWriteHandlers) Clone() *ReadWriteHandlers { PongTimeout: rwh.PongTimeout, PingTimeout: rwh.PingTimeout, PingPeriod: rwh.PingPeriod, - EnableCompression: rwh.EnableCompression, validated: rwh.validated, } } diff --git a/socket/readwrite.go b/socket/readwrite.go index fc644de..05c5694 100644 --- a/socket/readwrite.go +++ b/socket/readwrite.go @@ -2,7 +2,6 @@ package socket import ( "fmt" - "io" "time" logging "git.loafle.net/commons/logging-go" @@ -64,7 +63,6 @@ func connReadHandler(readWriteHandler ReadWriteHandler, conn Conn, stopChan <-ch func connWriteHandler(readWriteHandler ReadWriteHandler, conn Conn, stopChan <-chan struct{}, doneChan chan<- error, writeChan <-chan []byte) { var ( - wc io.WriteCloser message []byte ok bool err error @@ -91,17 +89,12 @@ func connWriteHandler(readWriteHandler ReadWriteHandler, conn Conn, stopChan <-c return } - wc, err = conn.NextWriter(TextMessage) + err = conn.WriteCompress(TextMessage, message) if err != nil { logging.Logger().Debug(err) return } - wc.Write(message) - if err = wc.Close(); nil != err { - logging.Logger().Debug(err) - return - } case <-ticker.C: if 0 < readWriteHandler.GetPingTimeout() { conn.SetWriteDeadline(time.Now().Add(readWriteHandler.GetPingTimeout())) diff --git a/socket/server-handler.go b/socket/server-handler.go index 64d2a74..7e3358f 100644 --- a/socket/server-handler.go +++ b/socket/server-handler.go @@ -9,11 +9,13 @@ import ( type ServerHandler interface { server.ServerHandler ReadWriteHandler + CompressionHandler } type ServerHandlers struct { server.ServerHandlers ReadWriteHandlers + CompressionHandlers validated atomic.Value } @@ -30,6 +32,9 @@ func (sh *ServerHandlers) Validate() error { if err := sh.ReadWriteHandlers.Validate(); nil != err { return err } + if err := sh.CompressionHandlers.Validate(); nil != err { + return err + } return nil } diff --git a/socket/web/server.go b/socket/web/server.go index d3b6cb5..bfa66ea 100644 --- a/socket/web/server.go +++ b/socket/web/server.go @@ -57,12 +57,14 @@ func (s *Server) ListenAndServe() error { } s.upgrader = &Upgrader{ - HandshakeTimeout: s.ServerHandler.GetHandshakeTimeout(), - ReadBufferSize: s.ServerHandler.GetReadBufferSize(), - WriteBufferSize: s.ServerHandler.GetWriteBufferSize(), - CheckOrigin: s.ServerHandler.(ServerHandler).CheckOrigin, - Error: s.onError, - EnableCompression: s.ServerHandler.IsEnableCompression(), + HandshakeTimeout: s.ServerHandler.GetHandshakeTimeout(), + ReadBufferSize: s.ServerHandler.GetReadBufferSize(), + WriteBufferSize: s.ServerHandler.GetWriteBufferSize(), + CheckOrigin: s.ServerHandler.(ServerHandler).CheckOrigin, + Error: s.onError, + EnableCompression: s.ServerHandler.IsEnableCompression(), + CompressionLevel: s.ServerHandler.GetCompressionLevel(), + CompressionThreshold: s.ServerHandler.GetCompressionThreshold(), } if err = s.ServerHandler.Init(s.ctx); nil != err { diff --git a/socket/web/upgrade.go b/socket/web/upgrade.go index 37a86a1..9ff6a81 100644 --- a/socket/web/upgrade.go +++ b/socket/web/upgrade.go @@ -58,6 +58,9 @@ type Upgrader struct { // guarantee that compression will be supported. Currently only "no context // takeover" modes are supported. EnableCompression bool + + CompressionLevel int + CompressionThreshold int } func (u *Upgrader) returnError(ctx *fasthttp.RequestCtx, status int, reason string) (*socket.SocketConn, error) { @@ -195,6 +198,8 @@ func (u *Upgrader) Upgrade(ctx *fasthttp.RequestCtx, responseHeader *fasthttp.Re c := socket.NewConn(netConn, true, u.ReadBufferSize, u.WriteBufferSize) c.SetSubprotocol(subprotocol) if compress { + c.SetCompressionLevel(u.CompressionLevel) + c.SetCompressionThreshold(u.CompressionThreshold) c.SetNewCompressionWriter(socket.CompressNoContextTakeover) c.SetNewDecompressionReader(socket.DecompressNoContextTakeover) }