diff --git a/client-readwriter.go b/client-readwriter.go index 60c7328..30980fd 100644 --- a/client-readwriter.go +++ b/client-readwriter.go @@ -12,12 +12,12 @@ type ClientReadWriter struct { ReadChan chan<- []byte WriteChan <-chan []byte DisconnectedChan chan<- struct{} - ReconnectedChan <-chan *Conn + ReconnectedChan <-chan Conn ClientStopChan <-chan struct{} ClientStopWg *sync.WaitGroup } -func (crw *ClientReadWriter) HandleConnection(conn *Conn) { +func (crw *ClientReadWriter) HandleConnection(conn Conn) { defer func() { if nil != conn { diff --git a/conn-json.go b/conn-json.go index 5c1a539..8803cdb 100644 --- a/conn-json.go +++ b/conn-json.go @@ -10,7 +10,7 @@ import ( ) // WriteJSON is deprecated, use c.WriteJSON instead. -func WriteJSON(c *Conn, v interface{}) error { +func WriteJSON(c *SocketConn, v interface{}) error { return c.WriteJSON(v) } @@ -18,7 +18,7 @@ func WriteJSON(c *Conn, v interface{}) error { // // See the documentation for encoding/json Marshal for details about the // conversion of Go values to JSON. -func (c *Conn) WriteJSON(v interface{}) error { +func (c *SocketConn) WriteJSON(v interface{}) error { w, err := c.NextWriter(TextMessage) if err != nil { return err @@ -32,7 +32,7 @@ func (c *Conn) WriteJSON(v interface{}) error { } // ReadJSON is deprecated, use c.ReadJSON instead. -func ReadJSON(c *Conn, v interface{}) error { +func ReadJSON(c *SocketConn, v interface{}) error { return c.ReadJSON(v) } @@ -41,7 +41,7 @@ func ReadJSON(c *Conn, v interface{}) error { // // See the documentation for the encoding/json Unmarshal function for details // about the conversion of JSON to a Go value. -func (c *Conn) ReadJSON(v interface{}) error { +func (c *SocketConn) ReadJSON(v interface{}) error { _, r, err := c.NextReader() if err != nil { return err diff --git a/conn-prepared.go b/conn-prepared.go index a88ad60..c48e6ce 100644 --- a/conn-prepared.go +++ b/conn-prepared.go @@ -73,7 +73,7 @@ func (pm *PreparedMessage) frame(key prepareKey) (int, []byte, error) { mu := make(chan bool, 1) mu <- true var nc prepareConn - c := &Conn{ + c := &SocketConn{ conn: &nc, mu: mu, isServer: key.isServer, @@ -83,7 +83,7 @@ func (pm *PreparedMessage) frame(key prepareKey) (int, []byte, error) { } if key.compress { - c.NewCompressionWriter = CompressNoContextTakeover + c.newCompressionWriter = CompressNoContextTakeover } err = c.WriteMessage(pm.messageType, pm.data) frame.data = nc.buf.Bytes() diff --git a/conn-read.go b/conn-read.go index 71d80a3..87fa254 100644 --- a/conn-read.go +++ b/conn-read.go @@ -8,7 +8,7 @@ package server import "io" -func (c *Conn) read(n int) ([]byte, error) { +func (c *SocketConn) read(n int) ([]byte, error) { p, err := c.BuffReader.Peek(n) if err == io.EOF { err = errUnexpectedEOF diff --git a/conn.go b/conn.go index d500c1b..0047e73 100644 --- a/conn.go +++ b/conn.go @@ -120,11 +120,38 @@ func isValidReceivedCloseCode(code int) bool { return validReceivedCloseCodes[code] || (code >= 3000 && code <= 4999) } -// The Conn type represents a WebSocket connection. -type Conn struct { +type Conn interface { + Close() error + CloseHandler() func(code int, text string) error + EnableWriteCompression(enable bool) + LocalAddr() net.Addr + NextReader() (messageType int, r io.Reader, err error) + NextWriter(messageType int) (io.WriteCloser, error) + PingHandler() func(appData string) error + PongHandler() func(appData string) error + ReadJSON(v interface{}) error + ReadMessage() (messageType int, p []byte, err error) + RemoteAddr() net.Addr + SetCloseHandler(h func(code int, text string) error) + SetCompressionLevel(level int) error + SetPingHandler(h func(appData string) error) + SetPongHandler(h func(appData string) error) + SetReadDeadline(t time.Time) error + SetReadLimit(limit int64) + SetWriteDeadline(t time.Time) error + Subprotocol() string + UnderlyingConn() net.Conn + WriteControl(messageType int, data []byte, deadline time.Time) error + WriteJSON(v interface{}) error + WriteMessage(messageType int, data []byte) error + WritePreparedMessage(pm *PreparedMessage) error +} + +// The SocketConn type represents a WebSocket connection. +type SocketConn struct { conn net.Conn isServer bool - Subprotocol string + subprotocol string // Write fields mu chan bool // used as mutex to protect write to conn @@ -138,7 +165,7 @@ type Conn struct { enableWriteCompression bool compressionLevel int - NewCompressionWriter func(io.WriteCloser, int) io.WriteCloser + newCompressionWriter func(io.WriteCloser, int) io.WriteCloser // Read fields reader io.ReadCloser // the current reader returned to the application @@ -157,10 +184,10 @@ type Conn struct { messageReader *messageReader // the current low-level reader readDecompress bool // whether last read frame had RSV1 set - NewDecompressionReader func(io.Reader) io.ReadCloser + newDecompressionReader func(io.Reader) io.ReadCloser } -func NewConn(conn net.Conn, isServer bool, readBufferSize, writeBufferSize int) *Conn { +func NewConn(conn net.Conn, isServer bool, readBufferSize, writeBufferSize int) *SocketConn { return newConnBRW(conn, isServer, readBufferSize, writeBufferSize, nil) } @@ -173,7 +200,7 @@ func (wh *writeHook) Write(p []byte) (int, error) { return len(p), nil } -func newConnBRW(conn net.Conn, isServer bool, readBufferSize, writeBufferSize int, brw *bufio.ReadWriter) *Conn { +func newConnBRW(conn net.Conn, isServer bool, readBufferSize, writeBufferSize int, brw *bufio.ReadWriter) *SocketConn { mu := make(chan bool, 1) mu <- true @@ -218,7 +245,7 @@ func newConnBRW(conn net.Conn, isServer bool, readBufferSize, writeBufferSize in writeBuf = make([]byte, writeBufferSize+maxFrameHeaderSize) } - c := &Conn{ + c := &SocketConn{ isServer: isServer, BuffReader: br, conn: conn, @@ -235,23 +262,38 @@ func newConnBRW(conn net.Conn, isServer bool, readBufferSize, writeBufferSize in } // Close closes the underlying network connection without sending or waiting for a close frame. -func (c *Conn) Close() error { +func (c *SocketConn) Close() error { return c.conn.Close() } +func (c *SocketConn) Subprotocol() string { + return c.subprotocol +} + +func (c *SocketConn) SetSubprotocol(subprotocol string) { + c.subprotocol = subprotocol +} + // LocalAddr returns the local network address. -func (c *Conn) LocalAddr() net.Addr { +func (c *SocketConn) LocalAddr() net.Addr { return c.conn.LocalAddr() } // RemoteAddr returns the remote network address. -func (c *Conn) RemoteAddr() net.Addr { +func (c *SocketConn) RemoteAddr() net.Addr { return c.conn.RemoteAddr() } +func (c *SocketConn) SetNewCompressionWriter(w func(io.WriteCloser, int) io.WriteCloser) { + c.newCompressionWriter = w +} +func (c *SocketConn) SetNewDecompressionReader(r func(io.Reader) io.ReadCloser) { + c.newDecompressionReader = r +} + // Write methods -func (c *Conn) writeFatal(err error) error { +func (c *SocketConn) writeFatal(err error) error { err = hideTempErr(err) c.writeErrMu.Lock() if c.writeErr == nil { @@ -261,7 +303,7 @@ func (c *Conn) writeFatal(err error) error { return err } -func (c *Conn) write(frameType int, deadline time.Time, bufs ...[]byte) error { +func (c *SocketConn) write(frameType int, deadline time.Time, bufs ...[]byte) error { <-c.mu defer func() { c.mu <- true }() @@ -290,7 +332,7 @@ func (c *Conn) write(frameType int, deadline time.Time, bufs ...[]byte) error { // WriteControl writes a control message with the given deadline. The allowed // message types are CloseMessage, PingMessage and PongMessage. -func (c *Conn) WriteControl(messageType int, data []byte, deadline time.Time) error { +func (c *SocketConn) WriteControl(messageType int, data []byte, deadline time.Time) error { if !isControl(messageType) { return errBadWriteOpCode } @@ -351,7 +393,7 @@ func (c *Conn) WriteControl(messageType int, data []byte, deadline time.Time) er return err } -func (c *Conn) prepWrite(messageType int) error { +func (c *SocketConn) prepWrite(messageType int) error { // Close previous writer if not already closed by the application. It's // probably better to return an error in this situation, but we cannot // change this without breaking existing applications. @@ -375,7 +417,7 @@ func (c *Conn) prepWrite(messageType int) error { // // There can be at most one open writer on a connection. NextWriter closes the // previous writer if the application has not already done so. -func (c *Conn) NextWriter(messageType int) (io.WriteCloser, error) { +func (c *SocketConn) NextWriter(messageType int) (io.WriteCloser, error) { if err := c.prepWrite(messageType); err != nil { return nil, err } @@ -386,8 +428,8 @@ func (c *Conn) NextWriter(messageType int) (io.WriteCloser, error) { pos: maxFrameHeaderSize, } c.writer = mw - if c.NewCompressionWriter != nil && c.enableWriteCompression && isData(messageType) { - w := c.NewCompressionWriter(c.writer, c.compressionLevel) + if c.newCompressionWriter != nil && c.enableWriteCompression && isData(messageType) { + w := c.newCompressionWriter(c.writer, c.compressionLevel) mw.compress = true c.writer = w } @@ -395,7 +437,7 @@ func (c *Conn) NextWriter(messageType int) (io.WriteCloser, error) { } type messageWriter struct { - c *Conn + c *SocketConn compress bool // whether next call to flushFrame should set RSV1 pos int // end of data in writeBuf. frameType int // type of the current frame. @@ -595,10 +637,10 @@ func (w *messageWriter) Close() error { } // WritePreparedMessage writes prepared message into connection. -func (c *Conn) WritePreparedMessage(pm *PreparedMessage) error { +func (c *SocketConn) WritePreparedMessage(pm *PreparedMessage) error { frameType, frameData, err := pm.frame(prepareKey{ isServer: c.isServer, - compress: c.NewCompressionWriter != nil && c.enableWriteCompression && isData(pm.messageType), + compress: c.newCompressionWriter != nil && c.enableWriteCompression && isData(pm.messageType), compressionLevel: c.compressionLevel, }) if err != nil { @@ -618,9 +660,9 @@ func (c *Conn) WritePreparedMessage(pm *PreparedMessage) error { // WriteMessage is a helper method for getting a writer using NextWriter, // writing the message and closing the writer. -func (c *Conn) WriteMessage(messageType int, data []byte) error { +func (c *SocketConn) WriteMessage(messageType int, data []byte) error { - if c.isServer && (c.NewCompressionWriter == nil || !c.enableWriteCompression) { + if c.isServer && (c.newCompressionWriter == nil || !c.enableWriteCompression) { // Fast path with no allocations and single frame. if err := c.prepWrite(messageType); err != nil { @@ -647,14 +689,14 @@ func (c *Conn) WriteMessage(messageType int, data []byte) error { // connection. After a write has timed out, the websocket state is corrupt and // all future writes will return an error. A zero value for t means writes will // not time out. -func (c *Conn) SetWriteDeadline(t time.Time) error { +func (c *SocketConn) SetWriteDeadline(t time.Time) error { c.writeDeadline = t return nil } // Read methods -func (c *Conn) advanceFrame() (int, error) { +func (c *SocketConn) advanceFrame() (int, error) { // 1. Skip remainder of previous frame. @@ -677,7 +719,7 @@ func (c *Conn) advanceFrame() (int, error) { c.readRemaining = int64(p[1] & 0x7f) c.readDecompress = false - if c.NewDecompressionReader != nil && (p[0]&rsv1Bit) != 0 { + if c.newDecompressionReader != nil && (p[0]&rsv1Bit) != 0 { c.readDecompress = true p[0] &^= rsv1Bit } @@ -800,7 +842,7 @@ func (c *Conn) advanceFrame() (int, error) { return frameType, nil } -func (c *Conn) handleProtocolError(message string) error { +func (c *SocketConn) handleProtocolError(message string) error { c.WriteControl(CloseMessage, FormatCloseMessage(CloseProtocolError, message), time.Now().Add(writeWait)) return errors.New("websocket: " + message) } @@ -815,7 +857,7 @@ func (c *Conn) handleProtocolError(message string) error { // returns a non-nil error value. Errors returned from this method are // permanent. Once this method returns a non-nil error, all subsequent calls to // this method return the same error. -func (c *Conn) NextReader() (messageType int, r io.Reader, err error) { +func (c *SocketConn) NextReader() (messageType int, r io.Reader, err error) { // Close previous reader, only relevant for decompression. if c.reader != nil { c.reader.Close() @@ -835,7 +877,7 @@ func (c *Conn) NextReader() (messageType int, r io.Reader, err error) { c.messageReader = &messageReader{c} c.reader = c.messageReader if c.readDecompress { - c.reader = c.NewDecompressionReader(c.reader) + c.reader = c.newDecompressionReader(c.reader) } return frameType, c.reader, nil } @@ -852,7 +894,7 @@ func (c *Conn) NextReader() (messageType int, r io.Reader, err error) { return noFrame, nil, c.readErr } -type messageReader struct{ c *Conn } +type messageReader struct{ c *SocketConn } func (r *messageReader) Read(b []byte) (int, error) { c := r.c @@ -905,7 +947,7 @@ func (r *messageReader) Close() error { // ReadMessage is a helper method for getting a reader using NextReader and // reading from that reader to a buffer. -func (c *Conn) ReadMessage() (messageType int, p []byte, err error) { +func (c *SocketConn) ReadMessage() (messageType int, p []byte, err error) { var r io.Reader messageType, r, err = c.NextReader() if err != nil { @@ -919,19 +961,19 @@ func (c *Conn) ReadMessage() (messageType int, p []byte, err error) { // After a read has timed out, the websocket connection state is corrupt and // all future reads will return an error. A zero value for t means reads will // not time out. -func (c *Conn) SetReadDeadline(t time.Time) error { +func (c *SocketConn) SetReadDeadline(t time.Time) error { return c.conn.SetReadDeadline(t) } // SetReadLimit sets the maximum size for a message read from the peer. If a // message exceeds the limit, the connection sends a close frame to the peer // and returns ErrReadLimit to the application. -func (c *Conn) SetReadLimit(limit int64) { +func (c *SocketConn) SetReadLimit(limit int64) { c.readLimit = limit } // CloseHandler returns the current close handler -func (c *Conn) CloseHandler() func(code int, text string) error { +func (c *SocketConn) CloseHandler() func(code int, text string) error { return c.handleClose } @@ -948,7 +990,7 @@ func (c *Conn) CloseHandler() func(code int, text string) error { // normal error handling. Applications should only set a close handler when the // application must perform some action before sending a close frame back to // the peer. -func (c *Conn) SetCloseHandler(h func(code int, text string) error) { +func (c *SocketConn) SetCloseHandler(h func(code int, text string) error) { if h == nil { h = func(code int, text string) error { message := []byte{} @@ -963,7 +1005,7 @@ func (c *Conn) SetCloseHandler(h func(code int, text string) error) { } // PingHandler returns the current ping handler -func (c *Conn) PingHandler() func(appData string) error { +func (c *SocketConn) PingHandler() func(appData string) error { return c.handlePing } @@ -973,7 +1015,7 @@ func (c *Conn) PingHandler() func(appData string) error { // // The application must read the connection to process ping messages as // described in the section on Control Frames above. -func (c *Conn) SetPingHandler(h func(appData string) error) { +func (c *SocketConn) SetPingHandler(h func(appData string) error) { if h == nil { h = func(message string) error { err := c.WriteControl(PongMessage, []byte(message), time.Now().Add(writeWait)) @@ -989,7 +1031,7 @@ func (c *Conn) SetPingHandler(h func(appData string) error) { } // PongHandler returns the current pong handler -func (c *Conn) PongHandler() func(appData string) error { +func (c *SocketConn) PongHandler() func(appData string) error { return c.handlePong } @@ -999,7 +1041,7 @@ func (c *Conn) PongHandler() func(appData string) error { // // The application must read the connection to process ping messages as // described in the section on Control Frames above. -func (c *Conn) SetPongHandler(h func(appData string) error) { +func (c *SocketConn) SetPongHandler(h func(appData string) error) { if h == nil { h = func(string) error { return nil } } @@ -1008,14 +1050,14 @@ func (c *Conn) SetPongHandler(h func(appData string) error) { // UnderlyingConn returns the internal net.Conn. This can be used to further // modifications to connection specific flags. -func (c *Conn) UnderlyingConn() net.Conn { +func (c *SocketConn) UnderlyingConn() net.Conn { return c.conn } // EnableWriteCompression enables and disables write compression of // subsequent text and binary messages. This function is a noop if // compression was not negotiated with the peer. -func (c *Conn) EnableWriteCompression(enable bool) { +func (c *SocketConn) EnableWriteCompression(enable bool) { c.enableWriteCompression = enable } @@ -1023,7 +1065,7 @@ func (c *Conn) EnableWriteCompression(enable bool) { // binary messages. This function is a noop if compression was not negotiated // with the peer. See the compress/flate package for a description of // compression levels. -func (c *Conn) SetCompressionLevel(level int) error { +func (c *SocketConn) SetCompressionLevel(level int) error { if !isValidCompressionLevel(level) { return errors.New("websocket: invalid compression level") } diff --git a/const.go b/const.go index d7179f3..d501ca6 100644 --- a/const.go +++ b/const.go @@ -31,6 +31,6 @@ const ( // DefaultPingPeriod is default value of send ping period DefaultPingPeriod = (DefaultPingTimeout * 9) / 10 - DefaultReconnectInterval = 1 * time.Second + DefaultReconnectInterval = 5 * time.Second DefaultReconnectTryTime = 10 ) diff --git a/fasthttp/websocket/client.go b/fasthttp/websocket/client.go index 6895734..711c09e 100644 --- a/fasthttp/websocket/client.go +++ b/fasthttp/websocket/client.go @@ -56,14 +56,14 @@ type Client struct { writeChan chan []byte disconnectedChan chan struct{} - reconnectedChan chan *server.Conn + reconnectedChan chan server.Conn crw server.ClientReadWriter } func (c *Client) Connect() (readChan <-chan []byte, writeChan chan<- []byte, res *http.Response, err error) { var ( - conn *server.Conn + conn server.Conn ) if c.stopChan != nil { @@ -83,7 +83,7 @@ func (c *Client) Connect() (readChan <-chan []byte, writeChan chan<- []byte, res c.readChan = make(chan []byte, 256) c.writeChan = make(chan []byte, 256) c.disconnectedChan = make(chan struct{}) - c.reconnectedChan = make(chan *server.Conn) + c.reconnectedChan = make(chan server.Conn) c.stopChan = make(chan struct{}) c.crw.ReadwriteHandler = c @@ -152,7 +152,7 @@ RC_LOOP: } } -func (c *Client) connect() (*server.Conn, *http.Response, error) { +func (c *Client) connect() (server.Conn, *http.Response, error) { conn, res, err := c.Dial() if nil != err { return nil, nil, err @@ -165,7 +165,7 @@ func (c *Client) connect() (*server.Conn, *http.Response, error) { return conn, res, nil } -func (c *Client) Dial() (*server.Conn, *http.Response, error) { +func (c *Client) Dial() (server.Conn, *http.Response, error) { var ( err error challengeKey string @@ -364,13 +364,13 @@ func (c *Client) Dial() (*server.Conn, *http.Response, error) { if !snct || !cnct { return nil, resp, server.ErrInvalidCompression } - conn.NewCompressionWriter = server.CompressNoContextTakeover - conn.NewDecompressionReader = server.DecompressNoContextTakeover + conn.SetNewCompressionWriter(server.CompressNoContextTakeover) + conn.SetNewDecompressionReader(server.DecompressNoContextTakeover) break } resp.Body = ioutil.NopCloser(bytes.NewReader([]byte{})) - conn.Subprotocol = resp.Header.Get("Sec-Websocket-Protocol") + conn.SetSubprotocol(resp.Header.Get("Sec-Websocket-Protocol")) netConn.SetDeadline(time.Time{}) netConn = nil // to avoid close in defer. diff --git a/fasthttp/websocket/server-handler.go b/fasthttp/websocket/server-handler.go index 77f49ec..71ca0b8 100644 --- a/fasthttp/websocket/server-handler.go +++ b/fasthttp/websocket/server-handler.go @@ -14,7 +14,7 @@ type ServerHandler interface { OnError(serverCtx server.ServerCtx, ctx *fasthttp.RequestCtx, status int, reason error) RegisterServlet(path string, servlet Servlet) - Servlet(path string) Servlet + Servlet(serverCtx server.ServerCtx, ctx *fasthttp.RequestCtx) Servlet CheckOrigin(ctx *fasthttp.RequestCtx) bool } @@ -51,18 +51,6 @@ func (sh *ServerHandlers) Destroy(serverCtx server.ServerCtx) { sh.ServerHandlers.Destroy(serverCtx) } -func (sh *ServerHandlers) OnPing(msg string) error { - return nil -} - -func (sh *ServerHandlers) OnPong(msg string) error { - return nil -} - -func (sh *ServerHandlers) OnClose(code int, text string) error { - return nil -} - func (sh *ServerHandlers) OnError(serverCtx server.ServerCtx, ctx *fasthttp.RequestCtx, status int, reason error) { ctx.Response.Header.Set("Sec-Websocket-Version", "13") ctx.Error(http.StatusText(status), status) @@ -75,7 +63,9 @@ func (sh *ServerHandlers) RegisterServlet(path string, servlet Servlet) { sh.servlets[path] = servlet } -func (sh *ServerHandlers) Servlet(path string) Servlet { +func (sh *ServerHandlers) Servlet(serverCtx server.ServerCtx, ctx *fasthttp.RequestCtx) Servlet { + path := string(ctx.Path()) + var servlet Servlet if path == "" && len(sh.servlets) == 1 { for _, s := range sh.servlets { diff --git a/fasthttp/websocket/server.go b/fasthttp/websocket/server.go index 15176bb..0b6f2b9 100644 --- a/fasthttp/websocket/server.go +++ b/fasthttp/websocket/server.go @@ -150,7 +150,6 @@ func (s *Server) handleServer(listener net.Listener) error { } func (s *Server) httpHandler(ctx *fasthttp.RequestCtx) { - path := string(ctx.Path()) var ( servlet Servlet err error @@ -165,7 +164,7 @@ func (s *Server) httpHandler(ctx *fasthttp.RequestCtx) { } } - if servlet = s.ServerHandler.(ServerHandler).Servlet(path); nil == servlet { + if servlet = s.ServerHandler.(ServerHandler).Servlet(s.ctx, ctx); nil == servlet { s.onError(ctx, fasthttp.StatusInternalServerError, err) return } @@ -178,7 +177,7 @@ func (s *Server) httpHandler(ctx *fasthttp.RequestCtx) { return } - s.upgrader.Upgrade(ctx, responseHeader, func(conn *server.Conn, err error) { + s.upgrader.Upgrade(ctx, responseHeader, func(conn *server.SocketConn, err error) { if err != nil { s.onError(ctx, fasthttp.StatusInternalServerError, err) return diff --git a/fasthttp/websocket/servlet.go b/fasthttp/websocket/servlet.go index c694413..e17c9ac 100644 --- a/fasthttp/websocket/servlet.go +++ b/fasthttp/websocket/servlet.go @@ -27,18 +27,18 @@ func (s *Servlets) Destroy(serverCtx server.ServerCtx) { // } -func (s *Servlets) OnConnect(servletCtx server.ServletCtx, conn *server.Conn) { - // -} - -func (s *Servlets) OnDisconnect(servletCtx server.ServletCtx) { - // -} - func (s *Servlets) Handshake(servletCtx server.ServletCtx, ctx *fasthttp.RequestCtx) (*fasthttp.ResponseHeader, error) { return nil, nil } +func (s *Servlets) OnConnect(servletCtx server.ServletCtx, conn server.Conn) { + // +} + func (s *Servlets) Handle(servletCtx server.ServletCtx, stopChan <-chan struct{}, doneChan chan<- struct{}, readChan <-chan []byte, writeChan chan<- []byte) { } + +func (s *Servlets) OnDisconnect(servletCtx server.ServletCtx) { + // +} diff --git a/fasthttp/websocket/upgrade.go b/fasthttp/websocket/upgrade.go index c9725cc..dc7c8e7 100644 --- a/fasthttp/websocket/upgrade.go +++ b/fasthttp/websocket/upgrade.go @@ -16,7 +16,7 @@ import ( ) type ( - OnUpgradeFunc func(*server.Conn, error) + OnUpgradeFunc func(*server.SocketConn, error) ) // HandshakeError describes an error with the handshake from the peer. @@ -60,7 +60,7 @@ type Upgrader struct { EnableCompression bool } -func (u *Upgrader) returnError(ctx *fasthttp.RequestCtx, status int, reason string) (*server.Conn, error) { +func (u *Upgrader) returnError(ctx *fasthttp.RequestCtx, status int, reason string) (*server.SocketConn, error) { err := HandshakeError{reason} if u.Error != nil { u.Error(ctx, status, err) @@ -193,10 +193,10 @@ func (u *Upgrader) Upgrade(ctx *fasthttp.RequestCtx, responseHeader *fasthttp.Re ctx.Hijack(func(netConn net.Conn) { c := server.NewConn(netConn, true, u.ReadBufferSize, u.WriteBufferSize) - c.Subprotocol = subprotocol + c.SetSubprotocol(subprotocol) if compress { - c.NewCompressionWriter = server.CompressNoContextTakeover - c.NewDecompressionReader = server.DecompressNoContextTakeover + c.SetNewCompressionWriter(server.CompressNoContextTakeover) + c.SetNewDecompressionReader(server.DecompressNoContextTakeover) } // Clear deadlines set by HTTP server. diff --git a/net/client.go b/net/client.go index 12e3c8c..c1cf94b 100644 --- a/net/client.go +++ b/net/client.go @@ -28,14 +28,14 @@ type Client struct { writeChan chan []byte disconnectedChan chan struct{} - reconnectedChan chan *server.Conn + reconnectedChan chan server.Conn crw server.ClientReadWriter } func (c *Client) Connect() (readChan <-chan []byte, writeChan chan<- []byte, err error) { var ( - conn *server.Conn + conn server.Conn ) if c.stopChan != nil { @@ -55,7 +55,7 @@ func (c *Client) Connect() (readChan <-chan []byte, writeChan chan<- []byte, err c.readChan = make(chan []byte, 256) c.writeChan = make(chan []byte, 256) c.disconnectedChan = make(chan struct{}) - c.reconnectedChan = make(chan *server.Conn) + c.reconnectedChan = make(chan server.Conn) c.stopChan = make(chan struct{}) c.crw.ReadwriteHandler = c @@ -124,7 +124,7 @@ RC_LOOP: } } -func (c *Client) connect() (*server.Conn, error) { +func (c *Client) connect() (server.Conn, error) { netConn, err := c.Dial() if nil != err { return nil, err diff --git a/net/server-handler.go b/net/server-handler.go index 37e1557..e946437 100644 --- a/net/server-handler.go +++ b/net/server-handler.go @@ -12,7 +12,7 @@ type ServerHandler interface { OnError(serverCtx server.ServerCtx, conn net.Conn, status int, reason error) RegisterServlet(servlet Servlet) - Servlet() Servlet + Servlet(serverCtx server.ServerCtx, conn net.Conn) Servlet } type ServerHandlers struct { @@ -50,7 +50,7 @@ func (sh *ServerHandlers) RegisterServlet(servlet Servlet) { sh.servlet = servlet } -func (sh *ServerHandlers) Servlet() Servlet { +func (sh *ServerHandlers) Servlet(serverCtx server.ServerCtx, conn net.Conn) Servlet { return sh.servlet } diff --git a/net/server.go b/net/server.go index 2693f34..a4356cf 100644 --- a/net/server.go +++ b/net/server.go @@ -32,7 +32,7 @@ func (s *Server) ListenAndServe() error { listener net.Listener ) if nil == s.ServerHandler { - return fmt.Errorf("Server: server handler must be specified") + return fmt.Errorf(s.serverMessage("server handler must be specified")) } if err = s.ServerHandler.Validate(); nil != err { return err @@ -90,11 +90,8 @@ func (s *Server) handleServer(listener net.Listener) error { if nil != listener { listener.Close() } - s.ServerHandler.OnStop(s.ctx) - logging.Logger().Infof(s.serverMessage("Stopped")) - s.stopWg.Done() }() @@ -144,7 +141,7 @@ func (s *Server) handleServer(listener net.Listener) error { } } - servlet := s.ServerHandler.(ServerHandler).Servlet() + servlet := s.ServerHandler.(ServerHandler).Servlet(s.ctx, netConn) if nil == servlet { logging.Logger().Errorf(s.serverMessage("Servlet is nil")) continue diff --git a/net/servlet.go b/net/servlet.go index 248be6a..fe0d54c 100644 --- a/net/servlet.go +++ b/net/servlet.go @@ -28,18 +28,18 @@ func (s *Servlets) Destroy(serverCtx server.ServerCtx) { // } -func (s *Servlets) OnConnect(servletCtx server.ServletCtx, conn *server.Conn) { - // -} - -func (s *Servlets) OnDisconnect(servletCtx server.ServletCtx) { - // -} - func (s *Servlets) Handshake(servletCtx server.ServletCtx, conn net.Conn) error { return nil } +func (s *Servlets) OnConnect(servletCtx server.ServletCtx, conn server.Conn) { + // +} + func (s *Servlets) Handle(servletCtx server.ServletCtx, stopChan <-chan struct{}, doneChan chan<- struct{}, readChan <-chan []byte, writeChan chan<- []byte) { } + +func (s *Servlets) OnDisconnect(servletCtx server.ServletCtx) { + // +} diff --git a/readwrite-handler.go b/readwrite-handler.go index 68b94a9..fdff5f8 100644 --- a/readwrite-handler.go +++ b/readwrite-handler.go @@ -56,15 +56,12 @@ type ReadWriteHandlers struct { func (rwh *ReadWriteHandlers) GetMaxMessageSize() int64 { return rwh.MaxMessageSize } - func (rwh *ReadWriteHandlers) GetReadBufferSize() int { return rwh.ReadBufferSize } - func (rwh *ReadWriteHandlers) GetWriteBufferSize() int { return rwh.WriteBufferSize } - func (rwh *ReadWriteHandlers) GetReadTimeout() time.Duration { return rwh.ReadTimeout } @@ -80,11 +77,9 @@ func (rwh *ReadWriteHandlers) GetPingTimeout() time.Duration { func (rwh *ReadWriteHandlers) GetPingPeriod() time.Duration { return rwh.PingPeriod } - func (rwh *ReadWriteHandlers) IsEnableCompression() bool { return rwh.EnableCompression } - func (rwh *ReadWriteHandlers) Validate() error { if rwh.MaxMessageSize <= 0 { rwh.MaxMessageSize = DefaultMaxMessageSize diff --git a/readwrite.go b/readwrite.go index 0c946b3..f9ef71f 100644 --- a/readwrite.go +++ b/readwrite.go @@ -6,7 +6,7 @@ import ( "time" ) -func connReadHandler(readWriteHandler ReadWriteHandler, conn *Conn, stopChan <-chan struct{}, doneChan chan<- error, readChan chan<- []byte) { +func connReadHandler(readWriteHandler ReadWriteHandler, conn Conn, stopChan <-chan struct{}, doneChan chan<- error, readChan chan<- []byte) { var ( message []byte err error @@ -53,7 +53,7 @@ func connReadHandler(readWriteHandler ReadWriteHandler, conn *Conn, stopChan <-c } } -func connWriteHandler(readWriteHandler ReadWriteHandler, conn *Conn, stopChan <-chan struct{}, doneChan chan<- error, writeChan <-chan []byte) { +func connWriteHandler(readWriteHandler ReadWriteHandler, conn Conn, stopChan <-chan struct{}, doneChan chan<- error, writeChan <-chan []byte) { var ( wc io.WriteCloser message []byte diff --git a/server-readwriter.go b/server-readwriter.go index 33d3918..7388d5c 100644 --- a/server-readwriter.go +++ b/server-readwriter.go @@ -23,7 +23,7 @@ func (srw *ServerReadWriter) ConnectionSize() int { return sz } -func (srw *ServerReadWriter) HandleConnection(servlet Servlet, servletCtx ServletCtx, conn *Conn) { +func (srw *ServerReadWriter) HandleConnection(servlet Servlet, servletCtx ServletCtx, conn Conn) { addr := conn.RemoteAddr() defer func() { @@ -55,9 +55,9 @@ func (srw *ServerReadWriter) HandleConnection(servlet Servlet, servletCtx Servle readerDoneChan := make(chan error) writerDoneChan := make(chan error) - go servlet.Handle(servletCtx, stopChan, servletDoneChan, readChan, writeChan) go connReadHandler(srw.ReadwriteHandler, conn, stopChan, readerDoneChan, readChan) go connWriteHandler(srw.ReadwriteHandler, conn, stopChan, writerDoneChan, writeChan) + go servlet.Handle(servletCtx, stopChan, servletDoneChan, readChan, writeChan) select { case <-readerDoneChan: diff --git a/servlet.go b/servlet.go index c1b6d16..5bfd00f 100644 --- a/servlet.go +++ b/servlet.go @@ -6,7 +6,7 @@ type Servlet interface { Init(serverCtx ServerCtx) error Destroy(serverCtx ServerCtx) - OnConnect(servletCtx ServletCtx, conn *Conn) + OnConnect(servletCtx ServletCtx, conn Conn) Handle(servletCtx ServletCtx, stopChan <-chan struct{}, doneChan chan<- struct{}, readChan <-chan []byte, writeChan chan<- []byte) OnDisconnect(servletCtx ServletCtx) }