This commit is contained in:
crusader 2018-06-29 17:53:39 +09:00
parent 1c6419e15b
commit e34e7030c2
8 changed files with 41 additions and 132 deletions

View File

@ -1,58 +0,0 @@
package socket
import (
"errors"
"sync/atomic"
"git.loafle.net/commons/server-go"
)
type CompressionHandler interface {
IsEnableCompression() bool
GetCompressionLevel() int
GetCompressionThreshold() int
}
type CompressionHandlers struct {
EnableCompression bool `json:"enableCompression,omitempty"`
CompressionLevel int `json:"compressionLevel,omitempty"`
CompressionThreshold int `json:"compressionThreshold,omitempty"`
validated atomic.Value
}
func (ch *CompressionHandlers) IsEnableCompression() bool {
return ch.EnableCompression
}
func (ch *CompressionHandlers) GetCompressionLevel() int {
return ch.CompressionLevel
}
func (ch *CompressionHandlers) GetCompressionThreshold() int {
return ch.CompressionThreshold
}
func (ch *CompressionHandlers) Clone() *CompressionHandlers {
return &CompressionHandlers{
EnableCompression: ch.EnableCompression,
CompressionLevel: ch.CompressionLevel,
CompressionThreshold: ch.CompressionThreshold,
validated: ch.validated,
}
}
func (ch *CompressionHandlers) Validate() error {
if nil != ch.validated.Load() {
return nil
}
ch.validated.Store(true)
if !IsValidCompressionLevel(ch.CompressionLevel) {
return errors.New("Socket: invalid compression level")
}
if ch.CompressionThreshold <= 0 {
ch.CompressionThreshold = server.DefaultCompressionThreshold
}
return nil
}

View File

@ -19,7 +19,7 @@ func WriteJSON(c *SocketConn, v interface{}) error {
// See the documentation for encoding/json Marshal for details about the // See the documentation for encoding/json Marshal for details about the
// conversion of Go values to JSON. // conversion of Go values to JSON.
func (c *SocketConn) WriteJSON(v interface{}) error { func (c *SocketConn) WriteJSON(v interface{}) error {
w, err := c.NextWriter(TextMessage, true) w, err := c.NextWriter(TextMessage)
if err != nil { if err != nil {
return err return err
} }

View File

@ -22,10 +22,9 @@ type PreparedMessage struct {
// prepareKey defines a unique set of options to cache prepared frames in PreparedMessage. // prepareKey defines a unique set of options to cache prepared frames in PreparedMessage.
type prepareKey struct { type prepareKey struct {
isServer bool isServer bool
compress bool compress bool
compressionLevel int compressionLevel int
compressionThreshold int
} }
// preparedFrame contains data in wire representation. // preparedFrame contains data in wire representation.
@ -79,7 +78,6 @@ func (pm *PreparedMessage) frame(key prepareKey) (int, []byte, error) {
mu: mu, mu: mu,
isServer: key.isServer, isServer: key.isServer,
compressionLevel: key.compressionLevel, compressionLevel: key.compressionLevel,
compressionThreshold: key.compressionThreshold,
enableWriteCompression: true, enableWriteCompression: true,
writeBuf: make([]byte, defaultWriteBufferSize+maxFrameHeaderSize), writeBuf: make([]byte, defaultWriteBufferSize+maxFrameHeaderSize),
} }

View File

@ -126,7 +126,7 @@ type Conn interface {
EnableWriteCompression(enable bool) EnableWriteCompression(enable bool)
LocalAddr() net.Addr LocalAddr() net.Addr
NextReader() (messageType int, r io.Reader, err error) NextReader() (messageType int, r io.Reader, err error)
NextWriter(messageType int, useCompress bool) (io.WriteCloser, error) NextWriter(messageType int) (io.WriteCloser, error)
PingHandler() func(appData string) error PingHandler() func(appData string) error
PongHandler() func(appData string) error PongHandler() func(appData string) error
ReadJSON(v interface{}) error ReadJSON(v interface{}) error
@ -134,7 +134,6 @@ type Conn interface {
RemoteAddr() net.Addr RemoteAddr() net.Addr
SetCloseHandler(h func(code int, text string) error) SetCloseHandler(h func(code int, text string) error)
SetCompressionLevel(level int) error SetCompressionLevel(level int) error
SetCompressionThreshold(threshold int)
SetPingHandler(h func(appData string) error) SetPingHandler(h func(appData string) error)
SetPongHandler(h func(appData string) error) SetPongHandler(h func(appData string) error)
SetReadDeadline(t time.Time) error SetReadDeadline(t time.Time) error
@ -167,7 +166,6 @@ type SocketConn struct {
enableWriteCompression bool enableWriteCompression bool
compressionLevel int compressionLevel int
compressionThreshold int
newCompressionWriter func(io.WriteCloser, int) io.WriteCloser newCompressionWriter func(io.WriteCloser, int) io.WriteCloser
// Read fields // Read fields
@ -257,7 +255,6 @@ func newConnBRW(conn net.Conn, isServer bool, readBufferSize, writeBufferSize in
writeBuf: writeBuf, writeBuf: writeBuf,
enableWriteCompression: true, enableWriteCompression: true,
compressionLevel: defaultCompressionLevel, compressionLevel: defaultCompressionLevel,
compressionThreshold: defaultCompressionThreshold,
} }
c.SetCloseHandler(nil) c.SetCloseHandler(nil)
c.SetPingHandler(nil) c.SetPingHandler(nil)
@ -445,7 +442,7 @@ func (c *SocketConn) prepWrite(messageType int) error {
// //
// There can be at most one open writer on a connection. NextWriter closes the // There can be at most one open writer on a connection. NextWriter closes the
// previous writer if the application has not already done so. // previous writer if the application has not already done so.
func (c *SocketConn) NextWriter(messageType int, useCompress bool) (io.WriteCloser, error) { func (c *SocketConn) NextWriter(messageType int) (io.WriteCloser, error) {
if err := c.prepWrite(messageType); err != nil { if err := c.prepWrite(messageType); err != nil {
return nil, err return nil, err
} }
@ -456,7 +453,7 @@ func (c *SocketConn) NextWriter(messageType int, useCompress bool) (io.WriteClos
pos: maxFrameHeaderSize, pos: maxFrameHeaderSize,
} }
c.writer = mw c.writer = mw
if useCompress && c.newCompressionWriter != nil && c.enableWriteCompression && isData(messageType) { if c.newCompressionWriter != nil && c.enableWriteCompression && isData(messageType) {
w := c.newCompressionWriter(c.writer, c.compressionLevel) w := c.newCompressionWriter(c.writer, c.compressionLevel)
mw.compress = true mw.compress = true
c.writer = w c.writer = w
@ -667,10 +664,9 @@ func (w *messageWriter) Close() error {
// WritePreparedMessage writes prepared message into connection. // WritePreparedMessage writes prepared message into connection.
func (c *SocketConn) WritePreparedMessage(pm *PreparedMessage) error { func (c *SocketConn) WritePreparedMessage(pm *PreparedMessage) error {
frameType, frameData, err := pm.frame(prepareKey{ frameType, frameData, err := pm.frame(prepareKey{
isServer: c.isServer, isServer: c.isServer,
compress: c.newCompressionWriter != nil && c.enableWriteCompression && isData(pm.messageType), compress: c.newCompressionWriter != nil && c.enableWriteCompression && isData(pm.messageType),
compressionLevel: c.compressionLevel, compressionLevel: c.compressionLevel,
compressionThreshold: c.compressionThreshold,
}) })
if err != nil { if err != nil {
return err return err
@ -704,43 +700,7 @@ func (c *SocketConn) WriteMessage(messageType int, data []byte) error {
return mw.flushFrame(true, data) return mw.flushFrame(true, data)
} }
w, err := c.NextWriter(messageType, false) w, err := c.NextWriter(messageType)
if err != nil {
return err
}
if _, err = w.Write(data); err != nil {
return err
}
return w.Close()
}
// WriteMessage is a helper method for getting a writer using NextWriter,
// writing the message and closing the writer.
func (c *SocketConn) WriteCompress(messageType int, data []byte) error {
if c.isServer && (c.newCompressionWriter == nil || !c.enableWriteCompression || len(data) <= c.compressionThreshold) {
// Fast path with no allocations and single frame.
if err := c.prepWrite(messageType); err != nil {
return err
}
mw := messageWriter{c: c, frameType: messageType, pos: maxFrameHeaderSize}
n := copy(c.writeBuf[mw.pos:], data)
mw.pos += n
data = data[n:]
return mw.flushFrame(true, data)
}
var w io.WriteCloser
var err error
length := len(data)
switch {
case length > c.compressionThreshold:
w, err = c.NextWriter(messageType, true)
default:
w, err = c.NextWriter(messageType, false)
}
if err != nil { if err != nil {
return err return err
} }
@ -1138,10 +1098,6 @@ func (c *SocketConn) SetCompressionLevel(level int) error {
return nil return nil
} }
func (c *SocketConn) SetCompressionThreshold(threshold int) {
c.compressionThreshold = threshold
}
// FormatCloseMessage formats closeCode and text as a WebSocket close message. // FormatCloseMessage formats closeCode and text as a WebSocket close message.
func FormatCloseMessage(closeCode int, text string) []byte { func FormatCloseMessage(closeCode int, text string) []byte {
buf := make([]byte, 2+len(text)) buf := make([]byte, 2+len(text))

View File

@ -1,6 +1,7 @@
package socket package socket
import ( import (
"errors"
"sync/atomic" "sync/atomic"
"time" "time"
@ -12,6 +13,9 @@ type ReadWriteHandler interface {
GetPongTimeout() time.Duration GetPongTimeout() time.Duration
GetPingTimeout() time.Duration GetPingTimeout() time.Duration
GetPingPeriod() time.Duration GetPingPeriod() time.Duration
IsEnableCompression() bool
GetCompressionLevel() int
} }
type ReadWriteHandlers struct { type ReadWriteHandlers struct {
@ -21,6 +25,9 @@ type ReadWriteHandlers struct {
PingTimeout time.Duration `json:"pingTimeout,omitempty"` PingTimeout time.Duration `json:"pingTimeout,omitempty"`
PingPeriod time.Duration `json:"pingPeriod,omitempty"` PingPeriod time.Duration `json:"pingPeriod,omitempty"`
EnableCompression bool `json:"enableCompression,omitempty"`
CompressionLevel int `json:"compressionLevel,omitempty"`
validated atomic.Value validated atomic.Value
} }
@ -34,12 +41,21 @@ func (rwh *ReadWriteHandlers) GetPingPeriod() time.Duration {
return rwh.PingPeriod return rwh.PingPeriod
} }
func (rwh *ReadWriteHandlers) IsEnableCompression() bool {
return rwh.EnableCompression
}
func (rwh *ReadWriteHandlers) GetCompressionLevel() int {
return rwh.CompressionLevel
}
func (rwh *ReadWriteHandlers) Clone() *ReadWriteHandlers { func (rwh *ReadWriteHandlers) Clone() *ReadWriteHandlers {
return &ReadWriteHandlers{ return &ReadWriteHandlers{
ReadWriteHandlers: *rwh.ReadWriteHandlers.Clone(), ReadWriteHandlers: *rwh.ReadWriteHandlers.Clone(),
PongTimeout: rwh.PongTimeout, PongTimeout: rwh.PongTimeout,
PingTimeout: rwh.PingTimeout, PingTimeout: rwh.PingTimeout,
PingPeriod: rwh.PingPeriod, PingPeriod: rwh.PingPeriod,
EnableCompression: rwh.EnableCompression,
CompressionLevel: rwh.CompressionLevel,
validated: rwh.validated, validated: rwh.validated,
} }
} }
@ -71,5 +87,11 @@ func (rwh *ReadWriteHandlers) Validate() error {
rwh.PingPeriod = rwh.PingPeriod * time.Second rwh.PingPeriod = rwh.PingPeriod * time.Second
} }
if rwh.EnableCompression {
if !IsValidCompressionLevel(rwh.CompressionLevel) {
return errors.New("Socket: invalid compression level")
}
}
return nil return nil
} }

View File

@ -9,13 +9,11 @@ import (
type ServerHandler interface { type ServerHandler interface {
server.ServerHandler server.ServerHandler
ReadWriteHandler ReadWriteHandler
CompressionHandler
} }
type ServerHandlers struct { type ServerHandlers struct {
server.ServerHandlers server.ServerHandlers
ReadWriteHandlers ReadWriteHandlers
CompressionHandlers
validated atomic.Value validated atomic.Value
} }
@ -32,9 +30,6 @@ func (sh *ServerHandlers) Validate() error {
if err := sh.ReadWriteHandlers.Validate(); nil != err { if err := sh.ReadWriteHandlers.Validate(); nil != err {
return err return err
} }
if err := sh.CompressionHandlers.Validate(); nil != err {
return err
}
return nil return nil
} }

View File

@ -57,14 +57,13 @@ func (s *Server) ListenAndServe() error {
} }
s.upgrader = &Upgrader{ s.upgrader = &Upgrader{
HandshakeTimeout: s.ServerHandler.GetHandshakeTimeout(), HandshakeTimeout: s.ServerHandler.GetHandshakeTimeout(),
ReadBufferSize: s.ServerHandler.GetReadBufferSize(), ReadBufferSize: s.ServerHandler.GetReadBufferSize(),
WriteBufferSize: s.ServerHandler.GetWriteBufferSize(), WriteBufferSize: s.ServerHandler.GetWriteBufferSize(),
CheckOrigin: s.ServerHandler.(ServerHandler).CheckOrigin, CheckOrigin: s.ServerHandler.(ServerHandler).CheckOrigin,
Error: s.onError, Error: s.onError,
EnableCompression: s.ServerHandler.IsEnableCompression(), EnableCompression: s.ServerHandler.IsEnableCompression(),
CompressionLevel: s.ServerHandler.GetCompressionLevel(), CompressionLevel: s.ServerHandler.GetCompressionLevel(),
CompressionThreshold: s.ServerHandler.GetCompressionThreshold(),
} }
if err = s.ServerHandler.Init(s.ctx); nil != err { if err = s.ServerHandler.Init(s.ctx); nil != err {

View File

@ -58,9 +58,7 @@ type Upgrader struct {
// guarantee that compression will be supported. Currently only "no context // guarantee that compression will be supported. Currently only "no context
// takeover" modes are supported. // takeover" modes are supported.
EnableCompression bool EnableCompression bool
CompressionLevel int
CompressionLevel int
CompressionThreshold int
} }
func (u *Upgrader) returnError(ctx *fasthttp.RequestCtx, status int, reason string) (*socket.SocketConn, error) { func (u *Upgrader) returnError(ctx *fasthttp.RequestCtx, status int, reason string) (*socket.SocketConn, error) {
@ -199,7 +197,6 @@ func (u *Upgrader) Upgrade(ctx *fasthttp.RequestCtx, responseHeader *fasthttp.Re
c.SetSubprotocol(subprotocol) c.SetSubprotocol(subprotocol)
if compress { if compress {
c.SetCompressionLevel(u.CompressionLevel) c.SetCompressionLevel(u.CompressionLevel)
c.SetCompressionThreshold(u.CompressionThreshold)
c.SetNewCompressionWriter(socket.CompressNoContextTakeover) c.SetNewCompressionWriter(socket.CompressNoContextTakeover)
c.SetNewDecompressionReader(socket.DecompressNoContextTakeover) c.SetNewDecompressionReader(socket.DecompressNoContextTakeover)
} }