diff --git a/server.go b/server.go index b3463aa..15047cf 100644 --- a/server.go +++ b/server.go @@ -120,8 +120,8 @@ func (s *server) handleRequest(ctx *fasthttp.RequestCtx) { return } var responseHeader *fasthttp.ResponseHeader - var allowHandshake bool - if allowHandshake, responseHeader = socketHandler.Handshake(ctx); !allowHandshake { + var socketID interface{} + if socketID, responseHeader = socketHandler.Handshake(ctx); nil == socketID { s.handleError(ctx, http.StatusNotAcceptable, fmt.Errorf("Server: Handshake err")) return } @@ -132,7 +132,7 @@ func (s *server) handleRequest(ctx *fasthttp.RequestCtx) { return } - soc := newSocket(conn, socketHandler) + soc := newSocket(socketID, conn, socketHandler) s.stopWg.Add(1) handleConnection(s, soc, socketHandler) diff --git a/socket.go b/socket.go index f7b11d5..06ea3ca 100644 --- a/socket.go +++ b/socket.go @@ -9,10 +9,11 @@ import ( "git.loafle.net/commons_go/websocket_fasthttp/websocket" ) -func newSocket(conn *websocket.Conn, sh SocketHandler) *Socket { +func newSocket(id interface{}, conn *websocket.Conn, sh SocketHandler) *Socket { s := retainSocket() s.Conn = conn s.sh = sh + s.id = id s.SetReadLimit(sh.GetMaxMessageSize()) if 0 < sh.GetReadTimeout() { s.SetReadDeadline(time.Now().Add(sh.GetReadTimeout() * time.Second)) @@ -25,9 +26,30 @@ type Socket struct { *websocket.Conn sh SocketHandler + id interface{} + attributes map[interface{}]interface{} + sc *SocketConn } +func (s *Socket) ID() interface{} { + return s.id +} + +func (s *Socket) GetAttribute(key interface{}) interface{} { + if nil == s.attributes { + return nil + } + return s.attributes[key] +} + +func (s *Socket) SetAttribute(key interface{}, value interface{}) { + if nil == s.attributes { + s.attributes = make(map[interface{}]interface{}) + } + s.attributes[key] = value +} + func (s *Socket) WaitRequest() (*SocketConn, error) { if nil != s.sc { releaseSocketConn(s.sc) @@ -141,6 +163,7 @@ func retainSocket() *Socket { func releaseSocket(s *Socket) { s.sh = nil s.sc = nil + s.id = nil socketPool.Put(s) } diff --git a/socket_handler.go b/socket_handler.go index aa485e7..46ebe57 100644 --- a/socket_handler.go +++ b/socket_handler.go @@ -7,7 +7,9 @@ import ( ) type SocketHandler interface { - Handshake(ctx *fasthttp.RequestCtx) (connectable bool, extensionsHeader *fasthttp.ResponseHeader) + // Handshake do handshake client and server + // id is identity of client socket. if id is nil, disallow connection + Handshake(ctx *fasthttp.RequestCtx) (id interface{}, extensionsHeader *fasthttp.ResponseHeader) Handle(soc *Socket, stopChan <-chan struct{}, doneChan chan<- struct{}) GetMaxMessageSize() int64 diff --git a/socket_handlers.go b/socket_handlers.go index 20aa04a..b8ce65b 100644 --- a/socket_handlers.go +++ b/socket_handlers.go @@ -27,8 +27,8 @@ type SocketHandlers struct { PingPeriod time.Duration } -func (sh *SocketHandlers) Handshake(ctx *fasthttp.RequestCtx) (bool, *fasthttp.ResponseHeader) { - return true, nil +func (sh *SocketHandlers) Handshake(ctx *fasthttp.RequestCtx) (id interface{}, extensionsHeader *fasthttp.ResponseHeader) { + return nil, nil } func (sh *SocketHandlers) Handle(soc *Socket, stopChan <-chan struct{}, doneChan chan<- struct{}) {