diff --git a/server.go b/server.go index e52e95b..d45204c 100644 --- a/server.go +++ b/server.go @@ -159,6 +159,7 @@ func handleConnection(s *server, soc Socket, socketHandler SocketHandler) { logging.Logger().Debug(fmt.Sprintf("Server: Client[%s] is connected.", soc.RemoteAddr())) soc = socketHandler.OnConnect(soc) + socketHandler.addSocket(soc) clientStopChan := make(chan struct{}) handleDoneChan := make(chan struct{}) @@ -171,8 +172,9 @@ func handleConnection(s *server, soc Socket, socketHandler SocketHandler) { <-handleDoneChan case <-handleDoneChan: close(clientStopChan) - logging.Logger().Debug(fmt.Sprintf("Server: Client[%s] is disconnected.", soc.RemoteAddr())) socketHandler.OnDisconnect(soc) + logging.Logger().Debug(fmt.Sprintf("Server: Client[%s] is disconnected.", soc.RemoteAddr())) + socketHandler.removeSocket(soc) soc.Close() } } diff --git a/socket_handler.go b/socket_handler.go index 7e43931..381499f 100644 --- a/socket_handler.go +++ b/socket_handler.go @@ -26,9 +26,9 @@ type SocketHandler interface { // If you override ths method, must call // // func (sh *SocketHandler) OnConnect(soc cwf.Socket) cwf.Socket { - // ... + // soc = sh.SocketHandlers.OnConnect(newSoc) // newSoc := ... - // return sh.SocketHandlers.OnConnect(newSoc) + // return newSoc // } OnConnect(soc Socket) Socket Handle(soc Socket, stopChan <-chan struct{}, doneChan chan<- struct{}) @@ -67,4 +67,7 @@ type SocketHandler interface { // ... // } Validate() + + addSocket(soc Socket) + removeSocket(soc Socket) } diff --git a/socket_handlers.go b/socket_handlers.go index 28dc67f..7b9b48c 100644 --- a/socket_handlers.go +++ b/socket_handlers.go @@ -40,7 +40,7 @@ func (sh *SocketHandlers) Handshake(ctx *fasthttp.RequestCtx) (id string, extens } func (sh *SocketHandlers) OnConnect(soc Socket) Socket { - sh.sockets[soc.ID()] = soc + return soc } @@ -49,7 +49,7 @@ func (sh *SocketHandlers) Handle(soc Socket, stopChan <-chan struct{}, doneChan } func (sh *SocketHandlers) OnDisconnect(soc Socket) { - delete(sh.sockets, soc.ID()) + } func (sh *SocketHandlers) Destroy() { @@ -102,3 +102,11 @@ func (sh *SocketHandlers) Validate() { sh.PingPeriod = DefaultPingPeriod } } + +func (sh *SocketHandlers) addSocket(soc Socket) { + sh.sockets[soc.ID()] = soc +} + +func (sh *SocketHandlers) removeSocket(soc Socket) { + delete(sh.sockets, soc.ID()) +}