This commit is contained in:
crusader 2018-04-05 01:51:34 +09:00
parent 2b4c818883
commit 899fc53ded
19 changed files with 142 additions and 119 deletions

View File

@ -12,12 +12,12 @@ type ClientReadWriter struct {
ReadChan chan<- []byte ReadChan chan<- []byte
WriteChan <-chan []byte WriteChan <-chan []byte
DisconnectedChan chan<- struct{} DisconnectedChan chan<- struct{}
ReconnectedChan <-chan *Conn ReconnectedChan <-chan Conn
ClientStopChan <-chan struct{} ClientStopChan <-chan struct{}
ClientStopWg *sync.WaitGroup ClientStopWg *sync.WaitGroup
} }
func (crw *ClientReadWriter) HandleConnection(conn *Conn) { func (crw *ClientReadWriter) HandleConnection(conn Conn) {
defer func() { defer func() {
if nil != conn { if nil != conn {

View File

@ -10,7 +10,7 @@ import (
) )
// WriteJSON is deprecated, use c.WriteJSON instead. // WriteJSON is deprecated, use c.WriteJSON instead.
func WriteJSON(c *Conn, v interface{}) error { func WriteJSON(c *SocketConn, v interface{}) error {
return c.WriteJSON(v) return c.WriteJSON(v)
} }
@ -18,7 +18,7 @@ func WriteJSON(c *Conn, v interface{}) error {
// //
// See the documentation for encoding/json Marshal for details about the // See the documentation for encoding/json Marshal for details about the
// conversion of Go values to JSON. // conversion of Go values to JSON.
func (c *Conn) WriteJSON(v interface{}) error { func (c *SocketConn) WriteJSON(v interface{}) error {
w, err := c.NextWriter(TextMessage) w, err := c.NextWriter(TextMessage)
if err != nil { if err != nil {
return err return err
@ -32,7 +32,7 @@ func (c *Conn) WriteJSON(v interface{}) error {
} }
// ReadJSON is deprecated, use c.ReadJSON instead. // ReadJSON is deprecated, use c.ReadJSON instead.
func ReadJSON(c *Conn, v interface{}) error { func ReadJSON(c *SocketConn, v interface{}) error {
return c.ReadJSON(v) return c.ReadJSON(v)
} }
@ -41,7 +41,7 @@ func ReadJSON(c *Conn, v interface{}) error {
// //
// See the documentation for the encoding/json Unmarshal function for details // See the documentation for the encoding/json Unmarshal function for details
// about the conversion of JSON to a Go value. // about the conversion of JSON to a Go value.
func (c *Conn) ReadJSON(v interface{}) error { func (c *SocketConn) ReadJSON(v interface{}) error {
_, r, err := c.NextReader() _, r, err := c.NextReader()
if err != nil { if err != nil {
return err return err

View File

@ -73,7 +73,7 @@ func (pm *PreparedMessage) frame(key prepareKey) (int, []byte, error) {
mu := make(chan bool, 1) mu := make(chan bool, 1)
mu <- true mu <- true
var nc prepareConn var nc prepareConn
c := &Conn{ c := &SocketConn{
conn: &nc, conn: &nc,
mu: mu, mu: mu,
isServer: key.isServer, isServer: key.isServer,
@ -83,7 +83,7 @@ func (pm *PreparedMessage) frame(key prepareKey) (int, []byte, error) {
} }
if key.compress { if key.compress {
c.NewCompressionWriter = CompressNoContextTakeover c.newCompressionWriter = CompressNoContextTakeover
} }
err = c.WriteMessage(pm.messageType, pm.data) err = c.WriteMessage(pm.messageType, pm.data)
frame.data = nc.buf.Bytes() frame.data = nc.buf.Bytes()

View File

@ -8,7 +8,7 @@ package server
import "io" import "io"
func (c *Conn) read(n int) ([]byte, error) { func (c *SocketConn) read(n int) ([]byte, error) {
p, err := c.BuffReader.Peek(n) p, err := c.BuffReader.Peek(n)
if err == io.EOF { if err == io.EOF {
err = errUnexpectedEOF err = errUnexpectedEOF

126
conn.go
View File

@ -120,11 +120,38 @@ func isValidReceivedCloseCode(code int) bool {
return validReceivedCloseCodes[code] || (code >= 3000 && code <= 4999) return validReceivedCloseCodes[code] || (code >= 3000 && code <= 4999)
} }
// The Conn type represents a WebSocket connection. type Conn interface {
type Conn struct { Close() error
CloseHandler() func(code int, text string) error
EnableWriteCompression(enable bool)
LocalAddr() net.Addr
NextReader() (messageType int, r io.Reader, err error)
NextWriter(messageType int) (io.WriteCloser, error)
PingHandler() func(appData string) error
PongHandler() func(appData string) error
ReadJSON(v interface{}) error
ReadMessage() (messageType int, p []byte, err error)
RemoteAddr() net.Addr
SetCloseHandler(h func(code int, text string) error)
SetCompressionLevel(level int) error
SetPingHandler(h func(appData string) error)
SetPongHandler(h func(appData string) error)
SetReadDeadline(t time.Time) error
SetReadLimit(limit int64)
SetWriteDeadline(t time.Time) error
Subprotocol() string
UnderlyingConn() net.Conn
WriteControl(messageType int, data []byte, deadline time.Time) error
WriteJSON(v interface{}) error
WriteMessage(messageType int, data []byte) error
WritePreparedMessage(pm *PreparedMessage) error
}
// The SocketConn type represents a WebSocket connection.
type SocketConn struct {
conn net.Conn conn net.Conn
isServer bool isServer bool
Subprotocol string subprotocol string
// Write fields // Write fields
mu chan bool // used as mutex to protect write to conn mu chan bool // used as mutex to protect write to conn
@ -138,7 +165,7 @@ type Conn struct {
enableWriteCompression bool enableWriteCompression bool
compressionLevel int compressionLevel int
NewCompressionWriter func(io.WriteCloser, int) io.WriteCloser newCompressionWriter func(io.WriteCloser, int) io.WriteCloser
// Read fields // Read fields
reader io.ReadCloser // the current reader returned to the application reader io.ReadCloser // the current reader returned to the application
@ -157,10 +184,10 @@ type Conn struct {
messageReader *messageReader // the current low-level reader messageReader *messageReader // the current low-level reader
readDecompress bool // whether last read frame had RSV1 set readDecompress bool // whether last read frame had RSV1 set
NewDecompressionReader func(io.Reader) io.ReadCloser newDecompressionReader func(io.Reader) io.ReadCloser
} }
func NewConn(conn net.Conn, isServer bool, readBufferSize, writeBufferSize int) *Conn { func NewConn(conn net.Conn, isServer bool, readBufferSize, writeBufferSize int) *SocketConn {
return newConnBRW(conn, isServer, readBufferSize, writeBufferSize, nil) return newConnBRW(conn, isServer, readBufferSize, writeBufferSize, nil)
} }
@ -173,7 +200,7 @@ func (wh *writeHook) Write(p []byte) (int, error) {
return len(p), nil return len(p), nil
} }
func newConnBRW(conn net.Conn, isServer bool, readBufferSize, writeBufferSize int, brw *bufio.ReadWriter) *Conn { func newConnBRW(conn net.Conn, isServer bool, readBufferSize, writeBufferSize int, brw *bufio.ReadWriter) *SocketConn {
mu := make(chan bool, 1) mu := make(chan bool, 1)
mu <- true mu <- true
@ -218,7 +245,7 @@ func newConnBRW(conn net.Conn, isServer bool, readBufferSize, writeBufferSize in
writeBuf = make([]byte, writeBufferSize+maxFrameHeaderSize) writeBuf = make([]byte, writeBufferSize+maxFrameHeaderSize)
} }
c := &Conn{ c := &SocketConn{
isServer: isServer, isServer: isServer,
BuffReader: br, BuffReader: br,
conn: conn, conn: conn,
@ -235,23 +262,38 @@ func newConnBRW(conn net.Conn, isServer bool, readBufferSize, writeBufferSize in
} }
// Close closes the underlying network connection without sending or waiting for a close frame. // Close closes the underlying network connection without sending or waiting for a close frame.
func (c *Conn) Close() error { func (c *SocketConn) Close() error {
return c.conn.Close() return c.conn.Close()
} }
func (c *SocketConn) Subprotocol() string {
return c.subprotocol
}
func (c *SocketConn) SetSubprotocol(subprotocol string) {
c.subprotocol = subprotocol
}
// LocalAddr returns the local network address. // LocalAddr returns the local network address.
func (c *Conn) LocalAddr() net.Addr { func (c *SocketConn) LocalAddr() net.Addr {
return c.conn.LocalAddr() return c.conn.LocalAddr()
} }
// RemoteAddr returns the remote network address. // RemoteAddr returns the remote network address.
func (c *Conn) RemoteAddr() net.Addr { func (c *SocketConn) RemoteAddr() net.Addr {
return c.conn.RemoteAddr() return c.conn.RemoteAddr()
} }
func (c *SocketConn) SetNewCompressionWriter(w func(io.WriteCloser, int) io.WriteCloser) {
c.newCompressionWriter = w
}
func (c *SocketConn) SetNewDecompressionReader(r func(io.Reader) io.ReadCloser) {
c.newDecompressionReader = r
}
// Write methods // Write methods
func (c *Conn) writeFatal(err error) error { func (c *SocketConn) writeFatal(err error) error {
err = hideTempErr(err) err = hideTempErr(err)
c.writeErrMu.Lock() c.writeErrMu.Lock()
if c.writeErr == nil { if c.writeErr == nil {
@ -261,7 +303,7 @@ func (c *Conn) writeFatal(err error) error {
return err return err
} }
func (c *Conn) write(frameType int, deadline time.Time, bufs ...[]byte) error { func (c *SocketConn) write(frameType int, deadline time.Time, bufs ...[]byte) error {
<-c.mu <-c.mu
defer func() { c.mu <- true }() defer func() { c.mu <- true }()
@ -290,7 +332,7 @@ func (c *Conn) write(frameType int, deadline time.Time, bufs ...[]byte) error {
// WriteControl writes a control message with the given deadline. The allowed // WriteControl writes a control message with the given deadline. The allowed
// message types are CloseMessage, PingMessage and PongMessage. // message types are CloseMessage, PingMessage and PongMessage.
func (c *Conn) WriteControl(messageType int, data []byte, deadline time.Time) error { func (c *SocketConn) WriteControl(messageType int, data []byte, deadline time.Time) error {
if !isControl(messageType) { if !isControl(messageType) {
return errBadWriteOpCode return errBadWriteOpCode
} }
@ -351,7 +393,7 @@ func (c *Conn) WriteControl(messageType int, data []byte, deadline time.Time) er
return err return err
} }
func (c *Conn) prepWrite(messageType int) error { func (c *SocketConn) prepWrite(messageType int) error {
// Close previous writer if not already closed by the application. It's // Close previous writer if not already closed by the application. It's
// probably better to return an error in this situation, but we cannot // probably better to return an error in this situation, but we cannot
// change this without breaking existing applications. // change this without breaking existing applications.
@ -375,7 +417,7 @@ func (c *Conn) prepWrite(messageType int) error {
// //
// There can be at most one open writer on a connection. NextWriter closes the // There can be at most one open writer on a connection. NextWriter closes the
// previous writer if the application has not already done so. // previous writer if the application has not already done so.
func (c *Conn) NextWriter(messageType int) (io.WriteCloser, error) { func (c *SocketConn) NextWriter(messageType int) (io.WriteCloser, error) {
if err := c.prepWrite(messageType); err != nil { if err := c.prepWrite(messageType); err != nil {
return nil, err return nil, err
} }
@ -386,8 +428,8 @@ func (c *Conn) NextWriter(messageType int) (io.WriteCloser, error) {
pos: maxFrameHeaderSize, pos: maxFrameHeaderSize,
} }
c.writer = mw c.writer = mw
if c.NewCompressionWriter != nil && c.enableWriteCompression && isData(messageType) { if c.newCompressionWriter != nil && c.enableWriteCompression && isData(messageType) {
w := c.NewCompressionWriter(c.writer, c.compressionLevel) w := c.newCompressionWriter(c.writer, c.compressionLevel)
mw.compress = true mw.compress = true
c.writer = w c.writer = w
} }
@ -395,7 +437,7 @@ func (c *Conn) NextWriter(messageType int) (io.WriteCloser, error) {
} }
type messageWriter struct { type messageWriter struct {
c *Conn c *SocketConn
compress bool // whether next call to flushFrame should set RSV1 compress bool // whether next call to flushFrame should set RSV1
pos int // end of data in writeBuf. pos int // end of data in writeBuf.
frameType int // type of the current frame. frameType int // type of the current frame.
@ -595,10 +637,10 @@ func (w *messageWriter) Close() error {
} }
// WritePreparedMessage writes prepared message into connection. // WritePreparedMessage writes prepared message into connection.
func (c *Conn) WritePreparedMessage(pm *PreparedMessage) error { func (c *SocketConn) WritePreparedMessage(pm *PreparedMessage) error {
frameType, frameData, err := pm.frame(prepareKey{ frameType, frameData, err := pm.frame(prepareKey{
isServer: c.isServer, isServer: c.isServer,
compress: c.NewCompressionWriter != nil && c.enableWriteCompression && isData(pm.messageType), compress: c.newCompressionWriter != nil && c.enableWriteCompression && isData(pm.messageType),
compressionLevel: c.compressionLevel, compressionLevel: c.compressionLevel,
}) })
if err != nil { if err != nil {
@ -618,9 +660,9 @@ func (c *Conn) WritePreparedMessage(pm *PreparedMessage) error {
// WriteMessage is a helper method for getting a writer using NextWriter, // WriteMessage is a helper method for getting a writer using NextWriter,
// writing the message and closing the writer. // writing the message and closing the writer.
func (c *Conn) WriteMessage(messageType int, data []byte) error { func (c *SocketConn) WriteMessage(messageType int, data []byte) error {
if c.isServer && (c.NewCompressionWriter == nil || !c.enableWriteCompression) { if c.isServer && (c.newCompressionWriter == nil || !c.enableWriteCompression) {
// Fast path with no allocations and single frame. // Fast path with no allocations and single frame.
if err := c.prepWrite(messageType); err != nil { if err := c.prepWrite(messageType); err != nil {
@ -647,14 +689,14 @@ func (c *Conn) WriteMessage(messageType int, data []byte) error {
// connection. After a write has timed out, the websocket state is corrupt and // connection. After a write has timed out, the websocket state is corrupt and
// all future writes will return an error. A zero value for t means writes will // all future writes will return an error. A zero value for t means writes will
// not time out. // not time out.
func (c *Conn) SetWriteDeadline(t time.Time) error { func (c *SocketConn) SetWriteDeadline(t time.Time) error {
c.writeDeadline = t c.writeDeadline = t
return nil return nil
} }
// Read methods // Read methods
func (c *Conn) advanceFrame() (int, error) { func (c *SocketConn) advanceFrame() (int, error) {
// 1. Skip remainder of previous frame. // 1. Skip remainder of previous frame.
@ -677,7 +719,7 @@ func (c *Conn) advanceFrame() (int, error) {
c.readRemaining = int64(p[1] & 0x7f) c.readRemaining = int64(p[1] & 0x7f)
c.readDecompress = false c.readDecompress = false
if c.NewDecompressionReader != nil && (p[0]&rsv1Bit) != 0 { if c.newDecompressionReader != nil && (p[0]&rsv1Bit) != 0 {
c.readDecompress = true c.readDecompress = true
p[0] &^= rsv1Bit p[0] &^= rsv1Bit
} }
@ -800,7 +842,7 @@ func (c *Conn) advanceFrame() (int, error) {
return frameType, nil return frameType, nil
} }
func (c *Conn) handleProtocolError(message string) error { func (c *SocketConn) handleProtocolError(message string) error {
c.WriteControl(CloseMessage, FormatCloseMessage(CloseProtocolError, message), time.Now().Add(writeWait)) c.WriteControl(CloseMessage, FormatCloseMessage(CloseProtocolError, message), time.Now().Add(writeWait))
return errors.New("websocket: " + message) return errors.New("websocket: " + message)
} }
@ -815,7 +857,7 @@ func (c *Conn) handleProtocolError(message string) error {
// returns a non-nil error value. Errors returned from this method are // returns a non-nil error value. Errors returned from this method are
// permanent. Once this method returns a non-nil error, all subsequent calls to // permanent. Once this method returns a non-nil error, all subsequent calls to
// this method return the same error. // this method return the same error.
func (c *Conn) NextReader() (messageType int, r io.Reader, err error) { func (c *SocketConn) NextReader() (messageType int, r io.Reader, err error) {
// Close previous reader, only relevant for decompression. // Close previous reader, only relevant for decompression.
if c.reader != nil { if c.reader != nil {
c.reader.Close() c.reader.Close()
@ -835,7 +877,7 @@ func (c *Conn) NextReader() (messageType int, r io.Reader, err error) {
c.messageReader = &messageReader{c} c.messageReader = &messageReader{c}
c.reader = c.messageReader c.reader = c.messageReader
if c.readDecompress { if c.readDecompress {
c.reader = c.NewDecompressionReader(c.reader) c.reader = c.newDecompressionReader(c.reader)
} }
return frameType, c.reader, nil return frameType, c.reader, nil
} }
@ -852,7 +894,7 @@ func (c *Conn) NextReader() (messageType int, r io.Reader, err error) {
return noFrame, nil, c.readErr return noFrame, nil, c.readErr
} }
type messageReader struct{ c *Conn } type messageReader struct{ c *SocketConn }
func (r *messageReader) Read(b []byte) (int, error) { func (r *messageReader) Read(b []byte) (int, error) {
c := r.c c := r.c
@ -905,7 +947,7 @@ func (r *messageReader) Close() error {
// ReadMessage is a helper method for getting a reader using NextReader and // ReadMessage is a helper method for getting a reader using NextReader and
// reading from that reader to a buffer. // reading from that reader to a buffer.
func (c *Conn) ReadMessage() (messageType int, p []byte, err error) { func (c *SocketConn) ReadMessage() (messageType int, p []byte, err error) {
var r io.Reader var r io.Reader
messageType, r, err = c.NextReader() messageType, r, err = c.NextReader()
if err != nil { if err != nil {
@ -919,19 +961,19 @@ func (c *Conn) ReadMessage() (messageType int, p []byte, err error) {
// After a read has timed out, the websocket connection state is corrupt and // After a read has timed out, the websocket connection state is corrupt and
// all future reads will return an error. A zero value for t means reads will // all future reads will return an error. A zero value for t means reads will
// not time out. // not time out.
func (c *Conn) SetReadDeadline(t time.Time) error { func (c *SocketConn) SetReadDeadline(t time.Time) error {
return c.conn.SetReadDeadline(t) return c.conn.SetReadDeadline(t)
} }
// SetReadLimit sets the maximum size for a message read from the peer. If a // SetReadLimit sets the maximum size for a message read from the peer. If a
// message exceeds the limit, the connection sends a close frame to the peer // message exceeds the limit, the connection sends a close frame to the peer
// and returns ErrReadLimit to the application. // and returns ErrReadLimit to the application.
func (c *Conn) SetReadLimit(limit int64) { func (c *SocketConn) SetReadLimit(limit int64) {
c.readLimit = limit c.readLimit = limit
} }
// CloseHandler returns the current close handler // CloseHandler returns the current close handler
func (c *Conn) CloseHandler() func(code int, text string) error { func (c *SocketConn) CloseHandler() func(code int, text string) error {
return c.handleClose return c.handleClose
} }
@ -948,7 +990,7 @@ func (c *Conn) CloseHandler() func(code int, text string) error {
// normal error handling. Applications should only set a close handler when the // normal error handling. Applications should only set a close handler when the
// application must perform some action before sending a close frame back to // application must perform some action before sending a close frame back to
// the peer. // the peer.
func (c *Conn) SetCloseHandler(h func(code int, text string) error) { func (c *SocketConn) SetCloseHandler(h func(code int, text string) error) {
if h == nil { if h == nil {
h = func(code int, text string) error { h = func(code int, text string) error {
message := []byte{} message := []byte{}
@ -963,7 +1005,7 @@ func (c *Conn) SetCloseHandler(h func(code int, text string) error) {
} }
// PingHandler returns the current ping handler // PingHandler returns the current ping handler
func (c *Conn) PingHandler() func(appData string) error { func (c *SocketConn) PingHandler() func(appData string) error {
return c.handlePing return c.handlePing
} }
@ -973,7 +1015,7 @@ func (c *Conn) PingHandler() func(appData string) error {
// //
// The application must read the connection to process ping messages as // The application must read the connection to process ping messages as
// described in the section on Control Frames above. // described in the section on Control Frames above.
func (c *Conn) SetPingHandler(h func(appData string) error) { func (c *SocketConn) SetPingHandler(h func(appData string) error) {
if h == nil { if h == nil {
h = func(message string) error { h = func(message string) error {
err := c.WriteControl(PongMessage, []byte(message), time.Now().Add(writeWait)) err := c.WriteControl(PongMessage, []byte(message), time.Now().Add(writeWait))
@ -989,7 +1031,7 @@ func (c *Conn) SetPingHandler(h func(appData string) error) {
} }
// PongHandler returns the current pong handler // PongHandler returns the current pong handler
func (c *Conn) PongHandler() func(appData string) error { func (c *SocketConn) PongHandler() func(appData string) error {
return c.handlePong return c.handlePong
} }
@ -999,7 +1041,7 @@ func (c *Conn) PongHandler() func(appData string) error {
// //
// The application must read the connection to process ping messages as // The application must read the connection to process ping messages as
// described in the section on Control Frames above. // described in the section on Control Frames above.
func (c *Conn) SetPongHandler(h func(appData string) error) { func (c *SocketConn) SetPongHandler(h func(appData string) error) {
if h == nil { if h == nil {
h = func(string) error { return nil } h = func(string) error { return nil }
} }
@ -1008,14 +1050,14 @@ func (c *Conn) SetPongHandler(h func(appData string) error) {
// UnderlyingConn returns the internal net.Conn. This can be used to further // UnderlyingConn returns the internal net.Conn. This can be used to further
// modifications to connection specific flags. // modifications to connection specific flags.
func (c *Conn) UnderlyingConn() net.Conn { func (c *SocketConn) UnderlyingConn() net.Conn {
return c.conn return c.conn
} }
// EnableWriteCompression enables and disables write compression of // EnableWriteCompression enables and disables write compression of
// subsequent text and binary messages. This function is a noop if // subsequent text and binary messages. This function is a noop if
// compression was not negotiated with the peer. // compression was not negotiated with the peer.
func (c *Conn) EnableWriteCompression(enable bool) { func (c *SocketConn) EnableWriteCompression(enable bool) {
c.enableWriteCompression = enable c.enableWriteCompression = enable
} }
@ -1023,7 +1065,7 @@ func (c *Conn) EnableWriteCompression(enable bool) {
// binary messages. This function is a noop if compression was not negotiated // binary messages. This function is a noop if compression was not negotiated
// with the peer. See the compress/flate package for a description of // with the peer. See the compress/flate package for a description of
// compression levels. // compression levels.
func (c *Conn) SetCompressionLevel(level int) error { func (c *SocketConn) SetCompressionLevel(level int) error {
if !isValidCompressionLevel(level) { if !isValidCompressionLevel(level) {
return errors.New("websocket: invalid compression level") return errors.New("websocket: invalid compression level")
} }

View File

@ -31,6 +31,6 @@ const (
// DefaultPingPeriod is default value of send ping period // DefaultPingPeriod is default value of send ping period
DefaultPingPeriod = (DefaultPingTimeout * 9) / 10 DefaultPingPeriod = (DefaultPingTimeout * 9) / 10
DefaultReconnectInterval = 1 * time.Second DefaultReconnectInterval = 5 * time.Second
DefaultReconnectTryTime = 10 DefaultReconnectTryTime = 10
) )

View File

@ -56,14 +56,14 @@ type Client struct {
writeChan chan []byte writeChan chan []byte
disconnectedChan chan struct{} disconnectedChan chan struct{}
reconnectedChan chan *server.Conn reconnectedChan chan server.Conn
crw server.ClientReadWriter crw server.ClientReadWriter
} }
func (c *Client) Connect() (readChan <-chan []byte, writeChan chan<- []byte, res *http.Response, err error) { func (c *Client) Connect() (readChan <-chan []byte, writeChan chan<- []byte, res *http.Response, err error) {
var ( var (
conn *server.Conn conn server.Conn
) )
if c.stopChan != nil { if c.stopChan != nil {
@ -83,7 +83,7 @@ func (c *Client) Connect() (readChan <-chan []byte, writeChan chan<- []byte, res
c.readChan = make(chan []byte, 256) c.readChan = make(chan []byte, 256)
c.writeChan = make(chan []byte, 256) c.writeChan = make(chan []byte, 256)
c.disconnectedChan = make(chan struct{}) c.disconnectedChan = make(chan struct{})
c.reconnectedChan = make(chan *server.Conn) c.reconnectedChan = make(chan server.Conn)
c.stopChan = make(chan struct{}) c.stopChan = make(chan struct{})
c.crw.ReadwriteHandler = c c.crw.ReadwriteHandler = c
@ -152,7 +152,7 @@ RC_LOOP:
} }
} }
func (c *Client) connect() (*server.Conn, *http.Response, error) { func (c *Client) connect() (server.Conn, *http.Response, error) {
conn, res, err := c.Dial() conn, res, err := c.Dial()
if nil != err { if nil != err {
return nil, nil, err return nil, nil, err
@ -165,7 +165,7 @@ func (c *Client) connect() (*server.Conn, *http.Response, error) {
return conn, res, nil return conn, res, nil
} }
func (c *Client) Dial() (*server.Conn, *http.Response, error) { func (c *Client) Dial() (server.Conn, *http.Response, error) {
var ( var (
err error err error
challengeKey string challengeKey string
@ -364,13 +364,13 @@ func (c *Client) Dial() (*server.Conn, *http.Response, error) {
if !snct || !cnct { if !snct || !cnct {
return nil, resp, server.ErrInvalidCompression return nil, resp, server.ErrInvalidCompression
} }
conn.NewCompressionWriter = server.CompressNoContextTakeover conn.SetNewCompressionWriter(server.CompressNoContextTakeover)
conn.NewDecompressionReader = server.DecompressNoContextTakeover conn.SetNewDecompressionReader(server.DecompressNoContextTakeover)
break break
} }
resp.Body = ioutil.NopCloser(bytes.NewReader([]byte{})) resp.Body = ioutil.NopCloser(bytes.NewReader([]byte{}))
conn.Subprotocol = resp.Header.Get("Sec-Websocket-Protocol") conn.SetSubprotocol(resp.Header.Get("Sec-Websocket-Protocol"))
netConn.SetDeadline(time.Time{}) netConn.SetDeadline(time.Time{})
netConn = nil // to avoid close in defer. netConn = nil // to avoid close in defer.

View File

@ -14,7 +14,7 @@ type ServerHandler interface {
OnError(serverCtx server.ServerCtx, ctx *fasthttp.RequestCtx, status int, reason error) OnError(serverCtx server.ServerCtx, ctx *fasthttp.RequestCtx, status int, reason error)
RegisterServlet(path string, servlet Servlet) RegisterServlet(path string, servlet Servlet)
Servlet(path string) Servlet Servlet(serverCtx server.ServerCtx, ctx *fasthttp.RequestCtx) Servlet
CheckOrigin(ctx *fasthttp.RequestCtx) bool CheckOrigin(ctx *fasthttp.RequestCtx) bool
} }
@ -51,18 +51,6 @@ func (sh *ServerHandlers) Destroy(serverCtx server.ServerCtx) {
sh.ServerHandlers.Destroy(serverCtx) sh.ServerHandlers.Destroy(serverCtx)
} }
func (sh *ServerHandlers) OnPing(msg string) error {
return nil
}
func (sh *ServerHandlers) OnPong(msg string) error {
return nil
}
func (sh *ServerHandlers) OnClose(code int, text string) error {
return nil
}
func (sh *ServerHandlers) OnError(serverCtx server.ServerCtx, ctx *fasthttp.RequestCtx, status int, reason error) { func (sh *ServerHandlers) OnError(serverCtx server.ServerCtx, ctx *fasthttp.RequestCtx, status int, reason error) {
ctx.Response.Header.Set("Sec-Websocket-Version", "13") ctx.Response.Header.Set("Sec-Websocket-Version", "13")
ctx.Error(http.StatusText(status), status) ctx.Error(http.StatusText(status), status)
@ -75,7 +63,9 @@ func (sh *ServerHandlers) RegisterServlet(path string, servlet Servlet) {
sh.servlets[path] = servlet sh.servlets[path] = servlet
} }
func (sh *ServerHandlers) Servlet(path string) Servlet { func (sh *ServerHandlers) Servlet(serverCtx server.ServerCtx, ctx *fasthttp.RequestCtx) Servlet {
path := string(ctx.Path())
var servlet Servlet var servlet Servlet
if path == "" && len(sh.servlets) == 1 { if path == "" && len(sh.servlets) == 1 {
for _, s := range sh.servlets { for _, s := range sh.servlets {

View File

@ -150,7 +150,6 @@ func (s *Server) handleServer(listener net.Listener) error {
} }
func (s *Server) httpHandler(ctx *fasthttp.RequestCtx) { func (s *Server) httpHandler(ctx *fasthttp.RequestCtx) {
path := string(ctx.Path())
var ( var (
servlet Servlet servlet Servlet
err error err error
@ -165,7 +164,7 @@ func (s *Server) httpHandler(ctx *fasthttp.RequestCtx) {
} }
} }
if servlet = s.ServerHandler.(ServerHandler).Servlet(path); nil == servlet { if servlet = s.ServerHandler.(ServerHandler).Servlet(s.ctx, ctx); nil == servlet {
s.onError(ctx, fasthttp.StatusInternalServerError, err) s.onError(ctx, fasthttp.StatusInternalServerError, err)
return return
} }
@ -178,7 +177,7 @@ func (s *Server) httpHandler(ctx *fasthttp.RequestCtx) {
return return
} }
s.upgrader.Upgrade(ctx, responseHeader, func(conn *server.Conn, err error) { s.upgrader.Upgrade(ctx, responseHeader, func(conn *server.SocketConn, err error) {
if err != nil { if err != nil {
s.onError(ctx, fasthttp.StatusInternalServerError, err) s.onError(ctx, fasthttp.StatusInternalServerError, err)
return return

View File

@ -27,18 +27,18 @@ func (s *Servlets) Destroy(serverCtx server.ServerCtx) {
// //
} }
func (s *Servlets) OnConnect(servletCtx server.ServletCtx, conn *server.Conn) {
//
}
func (s *Servlets) OnDisconnect(servletCtx server.ServletCtx) {
//
}
func (s *Servlets) Handshake(servletCtx server.ServletCtx, ctx *fasthttp.RequestCtx) (*fasthttp.ResponseHeader, error) { func (s *Servlets) Handshake(servletCtx server.ServletCtx, ctx *fasthttp.RequestCtx) (*fasthttp.ResponseHeader, error) {
return nil, nil return nil, nil
} }
func (s *Servlets) OnConnect(servletCtx server.ServletCtx, conn server.Conn) {
//
}
func (s *Servlets) Handle(servletCtx server.ServletCtx, stopChan <-chan struct{}, doneChan chan<- struct{}, readChan <-chan []byte, writeChan chan<- []byte) { func (s *Servlets) Handle(servletCtx server.ServletCtx, stopChan <-chan struct{}, doneChan chan<- struct{}, readChan <-chan []byte, writeChan chan<- []byte) {
} }
func (s *Servlets) OnDisconnect(servletCtx server.ServletCtx) {
//
}

View File

@ -16,7 +16,7 @@ import (
) )
type ( type (
OnUpgradeFunc func(*server.Conn, error) OnUpgradeFunc func(*server.SocketConn, error)
) )
// HandshakeError describes an error with the handshake from the peer. // HandshakeError describes an error with the handshake from the peer.
@ -60,7 +60,7 @@ type Upgrader struct {
EnableCompression bool EnableCompression bool
} }
func (u *Upgrader) returnError(ctx *fasthttp.RequestCtx, status int, reason string) (*server.Conn, error) { func (u *Upgrader) returnError(ctx *fasthttp.RequestCtx, status int, reason string) (*server.SocketConn, error) {
err := HandshakeError{reason} err := HandshakeError{reason}
if u.Error != nil { if u.Error != nil {
u.Error(ctx, status, err) u.Error(ctx, status, err)
@ -193,10 +193,10 @@ func (u *Upgrader) Upgrade(ctx *fasthttp.RequestCtx, responseHeader *fasthttp.Re
ctx.Hijack(func(netConn net.Conn) { ctx.Hijack(func(netConn net.Conn) {
c := server.NewConn(netConn, true, u.ReadBufferSize, u.WriteBufferSize) c := server.NewConn(netConn, true, u.ReadBufferSize, u.WriteBufferSize)
c.Subprotocol = subprotocol c.SetSubprotocol(subprotocol)
if compress { if compress {
c.NewCompressionWriter = server.CompressNoContextTakeover c.SetNewCompressionWriter(server.CompressNoContextTakeover)
c.NewDecompressionReader = server.DecompressNoContextTakeover c.SetNewDecompressionReader(server.DecompressNoContextTakeover)
} }
// Clear deadlines set by HTTP server. // Clear deadlines set by HTTP server.

View File

@ -28,14 +28,14 @@ type Client struct {
writeChan chan []byte writeChan chan []byte
disconnectedChan chan struct{} disconnectedChan chan struct{}
reconnectedChan chan *server.Conn reconnectedChan chan server.Conn
crw server.ClientReadWriter crw server.ClientReadWriter
} }
func (c *Client) Connect() (readChan <-chan []byte, writeChan chan<- []byte, err error) { func (c *Client) Connect() (readChan <-chan []byte, writeChan chan<- []byte, err error) {
var ( var (
conn *server.Conn conn server.Conn
) )
if c.stopChan != nil { if c.stopChan != nil {
@ -55,7 +55,7 @@ func (c *Client) Connect() (readChan <-chan []byte, writeChan chan<- []byte, err
c.readChan = make(chan []byte, 256) c.readChan = make(chan []byte, 256)
c.writeChan = make(chan []byte, 256) c.writeChan = make(chan []byte, 256)
c.disconnectedChan = make(chan struct{}) c.disconnectedChan = make(chan struct{})
c.reconnectedChan = make(chan *server.Conn) c.reconnectedChan = make(chan server.Conn)
c.stopChan = make(chan struct{}) c.stopChan = make(chan struct{})
c.crw.ReadwriteHandler = c c.crw.ReadwriteHandler = c
@ -124,7 +124,7 @@ RC_LOOP:
} }
} }
func (c *Client) connect() (*server.Conn, error) { func (c *Client) connect() (server.Conn, error) {
netConn, err := c.Dial() netConn, err := c.Dial()
if nil != err { if nil != err {
return nil, err return nil, err

View File

@ -12,7 +12,7 @@ type ServerHandler interface {
OnError(serverCtx server.ServerCtx, conn net.Conn, status int, reason error) OnError(serverCtx server.ServerCtx, conn net.Conn, status int, reason error)
RegisterServlet(servlet Servlet) RegisterServlet(servlet Servlet)
Servlet() Servlet Servlet(serverCtx server.ServerCtx, conn net.Conn) Servlet
} }
type ServerHandlers struct { type ServerHandlers struct {
@ -50,7 +50,7 @@ func (sh *ServerHandlers) RegisterServlet(servlet Servlet) {
sh.servlet = servlet sh.servlet = servlet
} }
func (sh *ServerHandlers) Servlet() Servlet { func (sh *ServerHandlers) Servlet(serverCtx server.ServerCtx, conn net.Conn) Servlet {
return sh.servlet return sh.servlet
} }

View File

@ -32,7 +32,7 @@ func (s *Server) ListenAndServe() error {
listener net.Listener listener net.Listener
) )
if nil == s.ServerHandler { if nil == s.ServerHandler {
return fmt.Errorf("Server: server handler must be specified") return fmt.Errorf(s.serverMessage("server handler must be specified"))
} }
if err = s.ServerHandler.Validate(); nil != err { if err = s.ServerHandler.Validate(); nil != err {
return err return err
@ -90,11 +90,8 @@ func (s *Server) handleServer(listener net.Listener) error {
if nil != listener { if nil != listener {
listener.Close() listener.Close()
} }
s.ServerHandler.OnStop(s.ctx) s.ServerHandler.OnStop(s.ctx)
logging.Logger().Infof(s.serverMessage("Stopped")) logging.Logger().Infof(s.serverMessage("Stopped"))
s.stopWg.Done() s.stopWg.Done()
}() }()
@ -144,7 +141,7 @@ func (s *Server) handleServer(listener net.Listener) error {
} }
} }
servlet := s.ServerHandler.(ServerHandler).Servlet() servlet := s.ServerHandler.(ServerHandler).Servlet(s.ctx, netConn)
if nil == servlet { if nil == servlet {
logging.Logger().Errorf(s.serverMessage("Servlet is nil")) logging.Logger().Errorf(s.serverMessage("Servlet is nil"))
continue continue

View File

@ -28,18 +28,18 @@ func (s *Servlets) Destroy(serverCtx server.ServerCtx) {
// //
} }
func (s *Servlets) OnConnect(servletCtx server.ServletCtx, conn *server.Conn) {
//
}
func (s *Servlets) OnDisconnect(servletCtx server.ServletCtx) {
//
}
func (s *Servlets) Handshake(servletCtx server.ServletCtx, conn net.Conn) error { func (s *Servlets) Handshake(servletCtx server.ServletCtx, conn net.Conn) error {
return nil return nil
} }
func (s *Servlets) OnConnect(servletCtx server.ServletCtx, conn server.Conn) {
//
}
func (s *Servlets) Handle(servletCtx server.ServletCtx, stopChan <-chan struct{}, doneChan chan<- struct{}, readChan <-chan []byte, writeChan chan<- []byte) { func (s *Servlets) Handle(servletCtx server.ServletCtx, stopChan <-chan struct{}, doneChan chan<- struct{}, readChan <-chan []byte, writeChan chan<- []byte) {
} }
func (s *Servlets) OnDisconnect(servletCtx server.ServletCtx) {
//
}

View File

@ -56,15 +56,12 @@ type ReadWriteHandlers struct {
func (rwh *ReadWriteHandlers) GetMaxMessageSize() int64 { func (rwh *ReadWriteHandlers) GetMaxMessageSize() int64 {
return rwh.MaxMessageSize return rwh.MaxMessageSize
} }
func (rwh *ReadWriteHandlers) GetReadBufferSize() int { func (rwh *ReadWriteHandlers) GetReadBufferSize() int {
return rwh.ReadBufferSize return rwh.ReadBufferSize
} }
func (rwh *ReadWriteHandlers) GetWriteBufferSize() int { func (rwh *ReadWriteHandlers) GetWriteBufferSize() int {
return rwh.WriteBufferSize return rwh.WriteBufferSize
} }
func (rwh *ReadWriteHandlers) GetReadTimeout() time.Duration { func (rwh *ReadWriteHandlers) GetReadTimeout() time.Duration {
return rwh.ReadTimeout return rwh.ReadTimeout
} }
@ -80,11 +77,9 @@ func (rwh *ReadWriteHandlers) GetPingTimeout() time.Duration {
func (rwh *ReadWriteHandlers) GetPingPeriod() time.Duration { func (rwh *ReadWriteHandlers) GetPingPeriod() time.Duration {
return rwh.PingPeriod return rwh.PingPeriod
} }
func (rwh *ReadWriteHandlers) IsEnableCompression() bool { func (rwh *ReadWriteHandlers) IsEnableCompression() bool {
return rwh.EnableCompression return rwh.EnableCompression
} }
func (rwh *ReadWriteHandlers) Validate() error { func (rwh *ReadWriteHandlers) Validate() error {
if rwh.MaxMessageSize <= 0 { if rwh.MaxMessageSize <= 0 {
rwh.MaxMessageSize = DefaultMaxMessageSize rwh.MaxMessageSize = DefaultMaxMessageSize

View File

@ -6,7 +6,7 @@ import (
"time" "time"
) )
func connReadHandler(readWriteHandler ReadWriteHandler, conn *Conn, stopChan <-chan struct{}, doneChan chan<- error, readChan chan<- []byte) { func connReadHandler(readWriteHandler ReadWriteHandler, conn Conn, stopChan <-chan struct{}, doneChan chan<- error, readChan chan<- []byte) {
var ( var (
message []byte message []byte
err error err error
@ -53,7 +53,7 @@ func connReadHandler(readWriteHandler ReadWriteHandler, conn *Conn, stopChan <-c
} }
} }
func connWriteHandler(readWriteHandler ReadWriteHandler, conn *Conn, stopChan <-chan struct{}, doneChan chan<- error, writeChan <-chan []byte) { func connWriteHandler(readWriteHandler ReadWriteHandler, conn Conn, stopChan <-chan struct{}, doneChan chan<- error, writeChan <-chan []byte) {
var ( var (
wc io.WriteCloser wc io.WriteCloser
message []byte message []byte

View File

@ -23,7 +23,7 @@ func (srw *ServerReadWriter) ConnectionSize() int {
return sz return sz
} }
func (srw *ServerReadWriter) HandleConnection(servlet Servlet, servletCtx ServletCtx, conn *Conn) { func (srw *ServerReadWriter) HandleConnection(servlet Servlet, servletCtx ServletCtx, conn Conn) {
addr := conn.RemoteAddr() addr := conn.RemoteAddr()
defer func() { defer func() {
@ -55,9 +55,9 @@ func (srw *ServerReadWriter) HandleConnection(servlet Servlet, servletCtx Servle
readerDoneChan := make(chan error) readerDoneChan := make(chan error)
writerDoneChan := make(chan error) writerDoneChan := make(chan error)
go servlet.Handle(servletCtx, stopChan, servletDoneChan, readChan, writeChan)
go connReadHandler(srw.ReadwriteHandler, conn, stopChan, readerDoneChan, readChan) go connReadHandler(srw.ReadwriteHandler, conn, stopChan, readerDoneChan, readChan)
go connWriteHandler(srw.ReadwriteHandler, conn, stopChan, writerDoneChan, writeChan) go connWriteHandler(srw.ReadwriteHandler, conn, stopChan, writerDoneChan, writeChan)
go servlet.Handle(servletCtx, stopChan, servletDoneChan, readChan, writeChan)
select { select {
case <-readerDoneChan: case <-readerDoneChan:

View File

@ -6,7 +6,7 @@ type Servlet interface {
Init(serverCtx ServerCtx) error Init(serverCtx ServerCtx) error
Destroy(serverCtx ServerCtx) Destroy(serverCtx ServerCtx)
OnConnect(servletCtx ServletCtx, conn *Conn) OnConnect(servletCtx ServletCtx, conn Conn)
Handle(servletCtx ServletCtx, stopChan <-chan struct{}, doneChan chan<- struct{}, readChan <-chan []byte, writeChan chan<- []byte) Handle(servletCtx ServletCtx, stopChan <-chan struct{}, doneChan chan<- struct{}, readChan <-chan []byte, writeChan chan<- []byte)
OnDisconnect(servletCtx ServletCtx) OnDisconnect(servletCtx ServletCtx)
} }