diff --git a/server.go b/server.go index 7c1de37..e52e95b 100644 --- a/server.go +++ b/server.go @@ -154,11 +154,11 @@ func (s *server) handleError(ctx *fasthttp.RequestCtx, status int, reason error) s.sh.OnError(ctx, status, reason) } -func handleConnection(s *server, soc *Socket, socketHandler SocketHandler) { +func handleConnection(s *server, soc Socket, socketHandler SocketHandler) { defer s.stopWg.Done() logging.Logger().Debug(fmt.Sprintf("Server: Client[%s] is connected.", soc.RemoteAddr())) - socketHandler.OnConnect(soc) + soc = socketHandler.OnConnect(soc) clientStopChan := make(chan struct{}) handleDoneChan := make(chan struct{}) diff --git a/socket.go b/socket.go index 39280c6..9dd3fcb 100644 --- a/socket.go +++ b/socket.go @@ -7,9 +7,148 @@ import ( "time" "git.loafle.net/commons_go/websocket_fasthttp/websocket" + "github.com/valyala/fasthttp" ) -func newSocket(id string, conn *websocket.Conn, sh SocketHandler) *Socket { +type Socket interface { + // ID returns the identity of the client. + ID() string + // GetAttribute returns a attribute for the key. + GetAttribute(key interface{}) interface{} + // SetAttribute store a attribute for the key. + SetAttribute(key interface{}, value interface{}) + // WaitRequest wait request of client. + WaitRequest() (*SocketConn, error) + + // Subprotocol returns the negotiated protocol for the connection. + Subprotocol() string + + // Close closes the underlying network connection without sending or waiting for a close frame. + Close() error + + // LocalAddr returns the local network address. + LocalAddr() net.Addr + + // RemoteAddr returns the remote network address. + RemoteAddr() net.Addr + + // WriteControl writes a control message with the given deadline. The allowed + // message types are CloseMessage, PingMessage and PongMessage. + WriteControl(messageType int, data []byte, deadline time.Time) error + + // NextWriter returns a writer for the next message to send. The writer's Close + // method flushes the complete message to the network. + // + // There can be at most one open writer on a connection. NextWriter closes the + // previous writer if the application has not already done so. + NextWriter(messageType int) (io.WriteCloser, error) + + // WritePreparedMessage writes prepared message into connection. + WritePreparedMessage(pm *websocket.PreparedMessage) error + + // WriteMessage is a helper method for getting a writer using NextWriter, + // writing the message and closing the writer. + WriteMessage(messageType int, data []byte) error + + // SetWriteDeadline sets the write deadline on the underlying network + // connection. After a write has timed out, the websocket state is corrupt and + // all future writes will return an error. A zero value for t means writes will + // not time out. + SetWriteDeadline(t time.Time) error + + // NextReader returns the next data message received from the peer. The + // returned messageType is either TextMessage or BinaryMessage. + // + // There can be at most one open reader on a connection. NextReader discards + // the previous message if the application has not already consumed it. + // + // Applications must break out of the application's read loop when this method + // returns a non-nil error value. Errors returned from this method are + // permanent. Once this method returns a non-nil error, all subsequent calls to + // this method return the same error. + NextReader() (messageType int, r io.Reader, err error) + + // ReadMessage is a helper method for getting a reader using NextReader and + // reading from that reader to a buffer. + ReadMessage() (messageType int, p []byte, err error) + + // SetReadDeadline sets the read deadline on the underlying network connection. + // After a read has timed out, the websocket connection state is corrupt and + // all future reads will return an error. A zero value for t means reads will + // not time out. + SetReadDeadline(t time.Time) error + + // SetReadLimit sets the maximum size for a message read from the peer. If a + // message exceeds the limit, the connection sends a close frame to the peer + // and returns ErrReadLimit to the application. + SetReadLimit(limit int64) + + // CloseHandler returns the current close handler + CloseHandler() func(code int, text string) error + + // SetCloseHandler sets the handler for close messages received from the peer. + // The code argument to h is the received close code or CloseNoStatusReceived + // if the close message is empty. The default close handler sends a close frame + // back to the peer. + // + // The application must read the connection to process close messages as + // described in the section on Control Frames above. + // + // The connection read methods return a CloseError when a close frame is + // received. Most applications should handle close messages as part of their + // normal error handling. Applications should only set a close handler when the + // application must perform some action before sending a close frame back to + // the peer. + SetCloseHandler(h func(code int, text string) error) + + // PingHandler returns the current ping handler + PingHandler() func(appData string) error + + // SetPingHandler sets the handler for ping messages received from the peer. + // The appData argument to h is the PING frame application data. The default + // ping handler sends a pong to the peer. + // + // The application must read the connection to process ping messages as + // described in the section on Control Frames above. + SetPingHandler(h func(appData string) error) + + // PongHandler returns the current pong handler + PongHandler() func(appData string) error + + // SetPongHandler sets the handler for pong messages received from the peer. + // The appData argument to h is the PONG frame application data. The default + // pong handler does nothing. + // + // The application must read the connection to process ping messages as + // described in the section on Control Frames above. + SetPongHandler(h func(appData string) error) + + // UnderlyingConn returns the internal net.Conn. This can be used to further + // modifications to connection specific flags. + UnderlyingConn() net.Conn + + // EnableWriteCompression enables and disables write compression of + // subsequent text and binary messages. This function is a noop if + // compression was not negotiated with the peer. + EnableWriteCompression(enable bool) + + // SetCompressionLevel sets the flate compression level for subsequent text and + // binary messages. This function is a noop if compression was not negotiated + // with the peer. See the compress/flate package for a description of + // compression levels. + SetCompressionLevel(level int) error + + // SetHeaders sets request headers + SetHeaders(h *fasthttp.RequestHeader) + + // Header returns header by key + Header(key string) (value string) + + // Headers returns the RequestHeader struct + Headers() *fasthttp.RequestHeader +} + +func newSocket(id string, conn *websocket.Conn, sh SocketHandler) Socket { s := retainSocket() s.Conn = conn s.sh = sh @@ -22,7 +161,7 @@ func newSocket(id string, conn *websocket.Conn, sh SocketHandler) *Socket { return s } -type Socket struct { +type fasthttpSocket struct { *websocket.Conn sh SocketHandler @@ -32,25 +171,25 @@ type Socket struct { sc *SocketConn } -func (s *Socket) ID() string { +func (s *fasthttpSocket) ID() string { return s.id } -func (s *Socket) GetAttribute(key interface{}) interface{} { +func (s *fasthttpSocket) GetAttribute(key interface{}) interface{} { if nil == s.attributes { return nil } return s.attributes[key] } -func (s *Socket) SetAttribute(key interface{}, value interface{}) { +func (s *fasthttpSocket) SetAttribute(key interface{}, value interface{}) { if nil == s.attributes { s.attributes = make(map[interface{}]interface{}) } s.attributes[key] = value } -func (s *Socket) WaitRequest() (*SocketConn, error) { +func (s *fasthttpSocket) WaitRequest() (*SocketConn, error) { if nil != s.sc { releaseSocketConn(s.sc) s.sc = nil @@ -72,7 +211,7 @@ func (s *Socket) WaitRequest() (*SocketConn, error) { return s.sc, nil } -func (s *Socket) NextWriter(messageType int) (io.WriteCloser, error) { +func (s *fasthttpSocket) NextWriter(messageType int) (io.WriteCloser, error) { if 0 < s.sh.GetWriteTimeout() { s.SetWriteDeadline(time.Now().Add(s.sh.GetWriteTimeout() * time.Second)) } @@ -80,7 +219,7 @@ func (s *Socket) NextWriter(messageType int) (io.WriteCloser, error) { return s.Conn.NextWriter(messageType) } -func (s *Socket) WriteMessage(messageType int, data []byte) error { +func (s *fasthttpSocket) WriteMessage(messageType int, data []byte) error { if 0 < s.sh.GetWriteTimeout() { s.SetWriteDeadline(time.Now().Add(s.sh.GetWriteTimeout() * time.Second)) } @@ -88,7 +227,7 @@ func (s *Socket) WriteMessage(messageType int, data []byte) error { return s.Conn.WriteMessage(messageType, data) } -func (s *Socket) Close() error { +func (s *fasthttpSocket) Close() error { err := s.Conn.Close() releaseSocket(s) return err @@ -97,7 +236,7 @@ func (s *Socket) Close() error { type SocketConn struct { net.Conn - s *Socket + s *fasthttpSocket MessageType int r io.Reader @@ -152,15 +291,15 @@ func (sc *SocketConn) SetWriteDeadline(t time.Time) error { var socketPool sync.Pool -func retainSocket() *Socket { +func retainSocket() *fasthttpSocket { v := socketPool.Get() if v == nil { - return &Socket{} + return &fasthttpSocket{} } - return v.(*Socket) + return v.(*fasthttpSocket) } -func releaseSocket(s *Socket) { +func releaseSocket(s *fasthttpSocket) { s.sh = nil s.sc = nil s.id = "" diff --git a/socket_handler.go b/socket_handler.go index f61224d..f791596 100644 --- a/socket_handler.go +++ b/socket_handler.go @@ -24,20 +24,21 @@ type SocketHandler interface { // OnConnect invoked when client is connected // If you override ths method, must call // - // func (sh *SocketHandler) OnConnect(soc *cwf.Socket) { - // sh.SocketHandlers.OnConnect(soc) + // func (sh *SocketHandler) OnConnect(soc cwf.Socket) cwf.Socket { // ... + // newSoc := ... + // return sh.SocketHandlers.OnConnect(newSoc) // } - OnConnect(soc *Socket) - Handle(soc *Socket, stopChan <-chan struct{}, doneChan chan<- struct{}) + OnConnect(soc Socket) Socket + Handle(soc Socket, stopChan <-chan struct{}, doneChan chan<- struct{}) // OnDisconnect invoked when client is disconnected // If you override ths method, must call // - // func (sh *SocketHandler) OnDisconnect(soc *cwf.Socket) { + // func (sh *SocketHandler) OnDisconnect(soc cwf.Socket) { // ... // sh.SocketHandlers.OnDisconnect(soc) // } - OnDisconnect(soc *Socket) + OnDisconnect(soc Socket) // Destroy invoked when server is stopped // If you override ths method, must call // @@ -47,8 +48,8 @@ type SocketHandler interface { // } Destroy() - GetSocket(id string) *Socket - GetSockets() map[string]*Socket + GetSocket(id string) Socket + GetSockets() map[string]Socket GetMaxMessageSize() int64 GetWriteTimeout() time.Duration diff --git a/socket_handlers.go b/socket_handlers.go index 0cc662e..28dc67f 100644 --- a/socket_handlers.go +++ b/socket_handlers.go @@ -26,11 +26,11 @@ type SocketHandlers struct { PingTimeout time.Duration PingPeriod time.Duration - sockets map[string]*Socket + sockets map[string]Socket } func (sh *SocketHandlers) Init() error { - sh.sockets = make(map[string]*Socket) + sh.sockets = make(map[string]Socket) return nil } @@ -39,15 +39,16 @@ func (sh *SocketHandlers) Handshake(ctx *fasthttp.RequestCtx) (id string, extens return "", nil } -func (sh *SocketHandlers) OnConnect(soc *Socket) { +func (sh *SocketHandlers) OnConnect(soc Socket) Socket { sh.sockets[soc.ID()] = soc + return soc } -func (sh *SocketHandlers) Handle(soc *Socket, stopChan <-chan struct{}, doneChan chan<- struct{}) { +func (sh *SocketHandlers) Handle(soc Socket, stopChan <-chan struct{}, doneChan chan<- struct{}) { // no op } -func (sh *SocketHandlers) OnDisconnect(soc *Socket) { +func (sh *SocketHandlers) OnDisconnect(soc Socket) { delete(sh.sockets, soc.ID()) } @@ -55,10 +56,10 @@ func (sh *SocketHandlers) Destroy() { // no op } -func (sh *SocketHandlers) GetSocket(id string) *Socket { +func (sh *SocketHandlers) GetSocket(id string) Socket { return sh.sockets[id] } -func (sh *SocketHandlers) GetSockets() map[string]*Socket { +func (sh *SocketHandlers) GetSockets() map[string]Socket { return sh.sockets }