This commit is contained in:
crusader 2018-06-27 21:31:50 +09:00
parent 1cae23cf78
commit 1c6419e15b
11 changed files with 169 additions and 38 deletions

View File

@ -33,4 +33,6 @@ const (
DefaultReconnectInterval = 5 * time.Second DefaultReconnectInterval = 5 * time.Second
DefaultReconnectTryTime = 10 DefaultReconnectTryTime = 10
DefaultCompressionThreshold = 1024
) )

View File

@ -0,0 +1,58 @@
package socket
import (
"errors"
"sync/atomic"
"git.loafle.net/commons/server-go"
)
type CompressionHandler interface {
IsEnableCompression() bool
GetCompressionLevel() int
GetCompressionThreshold() int
}
type CompressionHandlers struct {
EnableCompression bool `json:"enableCompression,omitempty"`
CompressionLevel int `json:"compressionLevel,omitempty"`
CompressionThreshold int `json:"compressionThreshold,omitempty"`
validated atomic.Value
}
func (ch *CompressionHandlers) IsEnableCompression() bool {
return ch.EnableCompression
}
func (ch *CompressionHandlers) GetCompressionLevel() int {
return ch.CompressionLevel
}
func (ch *CompressionHandlers) GetCompressionThreshold() int {
return ch.CompressionThreshold
}
func (ch *CompressionHandlers) Clone() *CompressionHandlers {
return &CompressionHandlers{
EnableCompression: ch.EnableCompression,
CompressionLevel: ch.CompressionLevel,
CompressionThreshold: ch.CompressionThreshold,
validated: ch.validated,
}
}
func (ch *CompressionHandlers) Validate() error {
if nil != ch.validated.Load() {
return nil
}
ch.validated.Store(true)
if !IsValidCompressionLevel(ch.CompressionLevel) {
return errors.New("Socket: invalid compression level")
}
if ch.CompressionThreshold <= 0 {
ch.CompressionThreshold = server.DefaultCompressionThreshold
}
return nil
}

View File

@ -12,6 +12,7 @@ const (
minCompressionLevel = -2 // flate.HuffmanOnly not defined in Go < 1.6 minCompressionLevel = -2 // flate.HuffmanOnly not defined in Go < 1.6
maxCompressionLevel = flate.BestCompression maxCompressionLevel = flate.BestCompression
defaultCompressionLevel = 1 defaultCompressionLevel = 1
defaultCompressionThreshold = 1024
) )
var ( var (
@ -33,7 +34,7 @@ func DecompressNoContextTakeover(r io.Reader) io.ReadCloser {
return &flateReadWrapper{fr} return &flateReadWrapper{fr}
} }
func isValidCompressionLevel(level int) bool { func IsValidCompressionLevel(level int) bool {
return minCompressionLevel <= level && level <= maxCompressionLevel return minCompressionLevel <= level && level <= maxCompressionLevel
} }
@ -46,6 +47,7 @@ func CompressNoContextTakeover(w io.WriteCloser, level int) io.WriteCloser {
} else { } else {
fw.Reset(tw) fw.Reset(tw)
} }
return &flateWriteWrapper{fw: fw, tw: tw, p: p} return &flateWriteWrapper{fw: fw, tw: tw, p: p}
} }
@ -95,6 +97,7 @@ func (w *flateWriteWrapper) Write(p []byte) (int, error) {
if w.fw == nil { if w.fw == nil {
return 0, errWriteClosed return 0, errWriteClosed
} }
return w.fw.Write(p) return w.fw.Write(p)
} }

View File

@ -19,7 +19,7 @@ func WriteJSON(c *SocketConn, v interface{}) error {
// See the documentation for encoding/json Marshal for details about the // See the documentation for encoding/json Marshal for details about the
// conversion of Go values to JSON. // conversion of Go values to JSON.
func (c *SocketConn) WriteJSON(v interface{}) error { func (c *SocketConn) WriteJSON(v interface{}) error {
w, err := c.NextWriter(TextMessage) w, err := c.NextWriter(TextMessage, true)
if err != nil { if err != nil {
return err return err
} }

View File

@ -25,6 +25,7 @@ type prepareKey struct {
isServer bool isServer bool
compress bool compress bool
compressionLevel int compressionLevel int
compressionThreshold int
} }
// preparedFrame contains data in wire representation. // preparedFrame contains data in wire representation.
@ -78,6 +79,7 @@ func (pm *PreparedMessage) frame(key prepareKey) (int, []byte, error) {
mu: mu, mu: mu,
isServer: key.isServer, isServer: key.isServer,
compressionLevel: key.compressionLevel, compressionLevel: key.compressionLevel,
compressionThreshold: key.compressionThreshold,
enableWriteCompression: true, enableWriteCompression: true,
writeBuf: make([]byte, defaultWriteBufferSize+maxFrameHeaderSize), writeBuf: make([]byte, defaultWriteBufferSize+maxFrameHeaderSize),
} }

View File

@ -126,7 +126,7 @@ type Conn interface {
EnableWriteCompression(enable bool) EnableWriteCompression(enable bool)
LocalAddr() net.Addr LocalAddr() net.Addr
NextReader() (messageType int, r io.Reader, err error) NextReader() (messageType int, r io.Reader, err error)
NextWriter(messageType int) (io.WriteCloser, error) NextWriter(messageType int, useCompress bool) (io.WriteCloser, error)
PingHandler() func(appData string) error PingHandler() func(appData string) error
PongHandler() func(appData string) error PongHandler() func(appData string) error
ReadJSON(v interface{}) error ReadJSON(v interface{}) error
@ -134,6 +134,7 @@ type Conn interface {
RemoteAddr() net.Addr RemoteAddr() net.Addr
SetCloseHandler(h func(code int, text string) error) SetCloseHandler(h func(code int, text string) error)
SetCompressionLevel(level int) error SetCompressionLevel(level int) error
SetCompressionThreshold(threshold int)
SetPingHandler(h func(appData string) error) SetPingHandler(h func(appData string) error)
SetPongHandler(h func(appData string) error) SetPongHandler(h func(appData string) error)
SetReadDeadline(t time.Time) error SetReadDeadline(t time.Time) error
@ -144,6 +145,7 @@ type Conn interface {
WriteControl(messageType int, data []byte, deadline time.Time) error WriteControl(messageType int, data []byte, deadline time.Time) error
WriteJSON(v interface{}) error WriteJSON(v interface{}) error
WriteMessage(messageType int, data []byte) error WriteMessage(messageType int, data []byte) error
WriteCompress(messageType int, data []byte) error
WritePreparedMessage(pm *PreparedMessage) error WritePreparedMessage(pm *PreparedMessage) error
} }
@ -165,6 +167,7 @@ type SocketConn struct {
enableWriteCompression bool enableWriteCompression bool
compressionLevel int compressionLevel int
compressionThreshold int
newCompressionWriter func(io.WriteCloser, int) io.WriteCloser newCompressionWriter func(io.WriteCloser, int) io.WriteCloser
// Read fields // Read fields
@ -254,6 +257,7 @@ func newConnBRW(conn net.Conn, isServer bool, readBufferSize, writeBufferSize in
writeBuf: writeBuf, writeBuf: writeBuf,
enableWriteCompression: true, enableWriteCompression: true,
compressionLevel: defaultCompressionLevel, compressionLevel: defaultCompressionLevel,
compressionThreshold: defaultCompressionThreshold,
} }
c.SetCloseHandler(nil) c.SetCloseHandler(nil)
c.SetPingHandler(nil) c.SetPingHandler(nil)
@ -417,7 +421,31 @@ func (c *SocketConn) prepWrite(messageType int) error {
// //
// There can be at most one open writer on a connection. NextWriter closes the // There can be at most one open writer on a connection. NextWriter closes the
// previous writer if the application has not already done so. // previous writer if the application has not already done so.
func (c *SocketConn) NextWriter(messageType int) (io.WriteCloser, error) { // func (c *SocketConn) NextWriter(messageType int) (io.WriteCloser, error) {
// if err := c.prepWrite(messageType); err != nil {
// return nil, err
// }
// mw := &messageWriter{
// c: c,
// frameType: messageType,
// pos: maxFrameHeaderSize,
// }
// c.writer = mw
// if c.newCompressionWriter != nil && c.enableWriteCompression && isData(messageType) {
// w := c.newCompressionWriter(c.writer, c.compressionLevel)
// mw.compress = true
// c.writer = w
// }
// return c.writer, nil
// }
// NextWriterWithUseCompress returns a writer for the next message to send. The writer's Close
// method flushes the complete message to the network.
//
// There can be at most one open writer on a connection. NextWriter closes the
// previous writer if the application has not already done so.
func (c *SocketConn) NextWriter(messageType int, useCompress bool) (io.WriteCloser, error) {
if err := c.prepWrite(messageType); err != nil { if err := c.prepWrite(messageType); err != nil {
return nil, err return nil, err
} }
@ -428,7 +456,7 @@ func (c *SocketConn) NextWriter(messageType int) (io.WriteCloser, error) {
pos: maxFrameHeaderSize, pos: maxFrameHeaderSize,
} }
c.writer = mw c.writer = mw
if c.newCompressionWriter != nil && c.enableWriteCompression && isData(messageType) { if useCompress && c.newCompressionWriter != nil && c.enableWriteCompression && isData(messageType) {
w := c.newCompressionWriter(c.writer, c.compressionLevel) w := c.newCompressionWriter(c.writer, c.compressionLevel)
mw.compress = true mw.compress = true
c.writer = w c.writer = w
@ -642,6 +670,7 @@ func (c *SocketConn) WritePreparedMessage(pm *PreparedMessage) error {
isServer: c.isServer, isServer: c.isServer,
compress: c.newCompressionWriter != nil && c.enableWriteCompression && isData(pm.messageType), compress: c.newCompressionWriter != nil && c.enableWriteCompression && isData(pm.messageType),
compressionLevel: c.compressionLevel, compressionLevel: c.compressionLevel,
compressionThreshold: c.compressionThreshold,
}) })
if err != nil { if err != nil {
return err return err
@ -675,7 +704,43 @@ func (c *SocketConn) WriteMessage(messageType int, data []byte) error {
return mw.flushFrame(true, data) return mw.flushFrame(true, data)
} }
w, err := c.NextWriter(messageType) w, err := c.NextWriter(messageType, false)
if err != nil {
return err
}
if _, err = w.Write(data); err != nil {
return err
}
return w.Close()
}
// WriteMessage is a helper method for getting a writer using NextWriter,
// writing the message and closing the writer.
func (c *SocketConn) WriteCompress(messageType int, data []byte) error {
if c.isServer && (c.newCompressionWriter == nil || !c.enableWriteCompression || len(data) <= c.compressionThreshold) {
// Fast path with no allocations and single frame.
if err := c.prepWrite(messageType); err != nil {
return err
}
mw := messageWriter{c: c, frameType: messageType, pos: maxFrameHeaderSize}
n := copy(c.writeBuf[mw.pos:], data)
mw.pos += n
data = data[n:]
return mw.flushFrame(true, data)
}
var w io.WriteCloser
var err error
length := len(data)
switch {
case length > c.compressionThreshold:
w, err = c.NextWriter(messageType, true)
default:
w, err = c.NextWriter(messageType, false)
}
if err != nil { if err != nil {
return err return err
} }
@ -1066,13 +1131,17 @@ func (c *SocketConn) EnableWriteCompression(enable bool) {
// with the peer. See the compress/flate package for a description of // with the peer. See the compress/flate package for a description of
// compression levels. // compression levels.
func (c *SocketConn) SetCompressionLevel(level int) error { func (c *SocketConn) SetCompressionLevel(level int) error {
if !isValidCompressionLevel(level) { if !IsValidCompressionLevel(level) {
return errors.New("websocket: invalid compression level") return errors.New("websocket: invalid compression level")
} }
c.compressionLevel = level c.compressionLevel = level
return nil return nil
} }
func (c *SocketConn) SetCompressionThreshold(threshold int) {
c.compressionThreshold = threshold
}
// FormatCloseMessage formats closeCode and text as a WebSocket close message. // FormatCloseMessage formats closeCode and text as a WebSocket close message.
func FormatCloseMessage(closeCode int, text string) []byte { func FormatCloseMessage(closeCode int, text string) []byte {
buf := make([]byte, 2+len(text)) buf := make([]byte, 2+len(text))

View File

@ -12,8 +12,6 @@ type ReadWriteHandler interface {
GetPongTimeout() time.Duration GetPongTimeout() time.Duration
GetPingTimeout() time.Duration GetPingTimeout() time.Duration
GetPingPeriod() time.Duration GetPingPeriod() time.Duration
IsEnableCompression() bool
} }
type ReadWriteHandlers struct { type ReadWriteHandlers struct {
@ -23,8 +21,6 @@ type ReadWriteHandlers struct {
PingTimeout time.Duration `json:"pingTimeout,omitempty"` PingTimeout time.Duration `json:"pingTimeout,omitempty"`
PingPeriod time.Duration `json:"pingPeriod,omitempty"` PingPeriod time.Duration `json:"pingPeriod,omitempty"`
EnableCompression bool `json:"enableCompression,omitempty"`
validated atomic.Value validated atomic.Value
} }
@ -37,9 +33,6 @@ func (rwh *ReadWriteHandlers) GetPingTimeout() time.Duration {
func (rwh *ReadWriteHandlers) GetPingPeriod() time.Duration { func (rwh *ReadWriteHandlers) GetPingPeriod() time.Duration {
return rwh.PingPeriod return rwh.PingPeriod
} }
func (rwh *ReadWriteHandlers) IsEnableCompression() bool {
return rwh.EnableCompression
}
func (rwh *ReadWriteHandlers) Clone() *ReadWriteHandlers { func (rwh *ReadWriteHandlers) Clone() *ReadWriteHandlers {
return &ReadWriteHandlers{ return &ReadWriteHandlers{
@ -47,7 +40,6 @@ func (rwh *ReadWriteHandlers) Clone() *ReadWriteHandlers {
PongTimeout: rwh.PongTimeout, PongTimeout: rwh.PongTimeout,
PingTimeout: rwh.PingTimeout, PingTimeout: rwh.PingTimeout,
PingPeriod: rwh.PingPeriod, PingPeriod: rwh.PingPeriod,
EnableCompression: rwh.EnableCompression,
validated: rwh.validated, validated: rwh.validated,
} }
} }

View File

@ -2,7 +2,6 @@ package socket
import ( import (
"fmt" "fmt"
"io"
"time" "time"
logging "git.loafle.net/commons/logging-go" logging "git.loafle.net/commons/logging-go"
@ -64,7 +63,6 @@ func connReadHandler(readWriteHandler ReadWriteHandler, conn Conn, stopChan <-ch
func connWriteHandler(readWriteHandler ReadWriteHandler, conn Conn, stopChan <-chan struct{}, doneChan chan<- error, writeChan <-chan []byte) { func connWriteHandler(readWriteHandler ReadWriteHandler, conn Conn, stopChan <-chan struct{}, doneChan chan<- error, writeChan <-chan []byte) {
var ( var (
wc io.WriteCloser
message []byte message []byte
ok bool ok bool
err error err error
@ -91,17 +89,12 @@ func connWriteHandler(readWriteHandler ReadWriteHandler, conn Conn, stopChan <-c
return return
} }
wc, err = conn.NextWriter(TextMessage) err = conn.WriteCompress(TextMessage, message)
if err != nil { if err != nil {
logging.Logger().Debug(err) logging.Logger().Debug(err)
return return
} }
wc.Write(message)
if err = wc.Close(); nil != err {
logging.Logger().Debug(err)
return
}
case <-ticker.C: case <-ticker.C:
if 0 < readWriteHandler.GetPingTimeout() { if 0 < readWriteHandler.GetPingTimeout() {
conn.SetWriteDeadline(time.Now().Add(readWriteHandler.GetPingTimeout())) conn.SetWriteDeadline(time.Now().Add(readWriteHandler.GetPingTimeout()))

View File

@ -9,11 +9,13 @@ import (
type ServerHandler interface { type ServerHandler interface {
server.ServerHandler server.ServerHandler
ReadWriteHandler ReadWriteHandler
CompressionHandler
} }
type ServerHandlers struct { type ServerHandlers struct {
server.ServerHandlers server.ServerHandlers
ReadWriteHandlers ReadWriteHandlers
CompressionHandlers
validated atomic.Value validated atomic.Value
} }
@ -30,6 +32,9 @@ func (sh *ServerHandlers) Validate() error {
if err := sh.ReadWriteHandlers.Validate(); nil != err { if err := sh.ReadWriteHandlers.Validate(); nil != err {
return err return err
} }
if err := sh.CompressionHandlers.Validate(); nil != err {
return err
}
return nil return nil
} }

View File

@ -63,6 +63,8 @@ func (s *Server) ListenAndServe() error {
CheckOrigin: s.ServerHandler.(ServerHandler).CheckOrigin, CheckOrigin: s.ServerHandler.(ServerHandler).CheckOrigin,
Error: s.onError, Error: s.onError,
EnableCompression: s.ServerHandler.IsEnableCompression(), EnableCompression: s.ServerHandler.IsEnableCompression(),
CompressionLevel: s.ServerHandler.GetCompressionLevel(),
CompressionThreshold: s.ServerHandler.GetCompressionThreshold(),
} }
if err = s.ServerHandler.Init(s.ctx); nil != err { if err = s.ServerHandler.Init(s.ctx); nil != err {

View File

@ -58,6 +58,9 @@ type Upgrader struct {
// guarantee that compression will be supported. Currently only "no context // guarantee that compression will be supported. Currently only "no context
// takeover" modes are supported. // takeover" modes are supported.
EnableCompression bool EnableCompression bool
CompressionLevel int
CompressionThreshold int
} }
func (u *Upgrader) returnError(ctx *fasthttp.RequestCtx, status int, reason string) (*socket.SocketConn, error) { func (u *Upgrader) returnError(ctx *fasthttp.RequestCtx, status int, reason string) (*socket.SocketConn, error) {
@ -195,6 +198,8 @@ func (u *Upgrader) Upgrade(ctx *fasthttp.RequestCtx, responseHeader *fasthttp.Re
c := socket.NewConn(netConn, true, u.ReadBufferSize, u.WriteBufferSize) c := socket.NewConn(netConn, true, u.ReadBufferSize, u.WriteBufferSize)
c.SetSubprotocol(subprotocol) c.SetSubprotocol(subprotocol)
if compress { if compress {
c.SetCompressionLevel(u.CompressionLevel)
c.SetCompressionThreshold(u.CompressionThreshold)
c.SetNewCompressionWriter(socket.CompressNoContextTakeover) c.SetNewCompressionWriter(socket.CompressNoContextTakeover)
c.SetNewDecompressionReader(socket.DecompressNoContextTakeover) c.SetNewDecompressionReader(socket.DecompressNoContextTakeover)
} }