diff --git a/server.go b/server.go index 40d7be1..7b07cae 100644 --- a/server.go +++ b/server.go @@ -19,11 +19,10 @@ type Server interface { Context() ServerContext } -func New(serverCTX ServerContext, sh ServerHandler) Server { +func New(sh ServerHandler) Server { s := &server{ sh: sh, } - s.ctx = serverCTX return s } @@ -74,6 +73,7 @@ func (s *server) Start() error { } var err error + s.ctx = s.sh.ServerContext() if err = s.sh.Init(s.ctx); nil != err { logging.Logger().Panic(fmt.Sprintf("Server: Initialization of server has been failed %v", err)) @@ -153,7 +153,8 @@ func (s *server) handleRequest(ctx *fasthttp.RequestCtx) { return } - soc := newSocket(s.ctx, socketID, conn, socketHandler) + socketCTX := socketHandler.SocketContext(s.ctx) + soc := newSocket(socketHandler, socketCTX, conn, socketID) s.stopWg.Add(1) handleConnection(s, soc, socketHandler) diff --git a/server_context.go b/server_context.go index 0305dd1..8c66fcc 100644 --- a/server_context.go +++ b/server_context.go @@ -12,9 +12,9 @@ type serverContext struct { cuc.Context } -func NewServerContext(parent cuc.Context) ServerContext { +func newServerContext() ServerContext { sCTX := &serverContext{} - sCTX.Context = cuc.NewContext(parent) + sCTX.Context = cuc.NewContext(nil) return sCTX } diff --git a/server_handler.go b/server_handler.go index 0a2c62a..928629c 100644 --- a/server_handler.go +++ b/server_handler.go @@ -8,6 +8,7 @@ import ( ) type ServerHandler interface { + ServerContext() ServerContext // Init invoked before the server is started // If you override ths method, must call // diff --git a/server_handlers.go b/server_handlers.go index bbab29a..9aed952 100644 --- a/server_handlers.go +++ b/server_handlers.go @@ -55,6 +55,10 @@ type ServerHandlers struct { socketHandlers map[string]SocketHandler } +func (sh *ServerHandlers) ServerContext() ServerContext { + return newServerContext() +} + func (sh *ServerHandlers) Init(serverCTX ServerContext) error { if nil != sh.socketHandlers { for _, socketHandler := range sh.socketHandlers { diff --git a/socket.go b/socket.go index eedfcc4..e74b452 100644 --- a/socket.go +++ b/socket.go @@ -147,16 +147,16 @@ type Socket interface { Context() SocketContext } -func newSocket(serverCTX ServerContext, id string, conn *websocket.Conn, sh SocketHandler) Socket { +func newSocket(socketHandler SocketHandler, socketCTX SocketContext, conn *websocket.Conn, id string) Socket { s := retainSocket() s.Conn = conn - s.sh = sh + s.sh = socketHandler s.id = id - s.SetReadLimit(sh.GetMaxMessageSize()) - if 0 < sh.GetReadTimeout() { - s.SetReadDeadline(time.Now().Add(sh.GetReadTimeout() * time.Second)) + s.SetReadLimit(socketHandler.GetMaxMessageSize()) + if 0 < socketHandler.GetReadTimeout() { + s.SetReadDeadline(time.Now().Add(socketHandler.GetReadTimeout() * time.Second)) } - s.ctx = newSocketContext(serverCTX) + s.ctx = socketCTX return s } diff --git a/socket_handler.go b/socket_handler.go index a487fcb..48a3201 100644 --- a/socket_handler.go +++ b/socket_handler.go @@ -22,6 +22,7 @@ type SocketHandler interface { // Handshake do handshake client and server // id is identity of client socket. if id is "", disallow connection Handshake(serverCTX ServerContext, ctx *fasthttp.RequestCtx) (id string, extensionsHeader *fasthttp.ResponseHeader) + SocketContext(serverCTX ServerContext) SocketContext // OnConnect invoked when client is connected // If you override ths method, must call // diff --git a/socket_handlers.go b/socket_handlers.go index 74b2aa9..ca3fd39 100644 --- a/socket_handlers.go +++ b/socket_handlers.go @@ -35,10 +35,14 @@ func (sh *SocketHandlers) Init(serverCTX ServerContext) error { return nil } -func (sh *SocketHandlers) Handshake(ctx *fasthttp.RequestCtx) (id string, extensionsHeader *fasthttp.ResponseHeader) { +func (sh *SocketHandlers) Handshake(serverCTX ServerContext, ctx *fasthttp.RequestCtx) (id string, extensionsHeader *fasthttp.ResponseHeader) { return "", nil } +func (sh *SocketHandlers) SocketContext(serverCTX ServerContext) SocketContext { + return newSocketContext(serverCTX) +} + func (sh *SocketHandlers) OnConnect(soc Socket) Socket { return soc