diff --git a/servlet/auth-servlet.go b/servlet/auth-servlet.go index 6dc765b..4b31a9d 100644 --- a/servlet/auth-servlet.go +++ b/servlet/auth-servlet.go @@ -27,7 +27,7 @@ type AuthServlet interface { type AuthServlets struct { ogrs.RPCServlets - connections sync.Map + sessions sync.Map } func (s *AuthServlets) Init(serverCtx server.ServerCtx) error { @@ -132,7 +132,7 @@ func (s *AuthServlets) OnConnect(servletCtx server.ServletCtx, conn socket.Conn) sessionID := servletCtx.GetAttribute(og.SessionIDKey) targetID := servletCtx.GetAttribute(og.SessionTargetIDKey) if nil != sessionID && nil != targetID { - s.connections.Store(sessionID.(string), retainConnection(targetID.(string), servletCtx)) + s.sessions.Store(sessionID.(string), ogrs.RetainSession(targetID.(string), servletCtx)) } } @@ -141,7 +141,7 @@ func (s *AuthServlets) OnDisconnect(servletCtx server.ServletCtx) { sessionID := servletCtx.GetAttribute(og.SessionIDKey) if nil != sessionID { - s.connections.Delete(sessionID.(string)) + s.sessions.Delete(sessionID.(string)) } } @@ -156,13 +156,13 @@ func (s *AuthServlets) handleSubscribe(serverCtx server.ServerCtx, subscribeChan switch msg.TargetType { case ogs.PROBE: for _, targetID := range msg.Targets { - _connections := s.getProbeConnections(targetID) - if nil == _connections || 0 == len(_connections) { + _sessions := s.getAuthSessions(targetID) + if nil == _sessions || 0 == len(_sessions) { break } - for _, _connection := range _connections { - _writeChan := _connection.servletCtx.GetAttribute(og.SessionWriteChanKey) + for _, _session := range _sessions { + _writeChan := _session.ServletCtx.GetAttribute(og.SessionWriteChanKey) if nil != _writeChan { writeChan := _writeChan.(chan<- []byte) writeChan <- msg.Message @@ -175,34 +175,16 @@ func (s *AuthServlets) handleSubscribe(serverCtx server.ServerCtx, subscribeChan } } -func (s *AuthServlets) getProbeConnections(targetID string) []*connection { - var connections []*connection +func (s *AuthServlets) getAuthSessions(targetID string) []*ogrs.Session { + var sessions []*ogrs.Session - s.connections.Range(func(k, v interface{}) bool { - _connection := v.(*connection) - if _connection.targetID == targetID { - connections = append(connections, _connection) + s.sessions.Range(func(k, v interface{}) bool { + session := v.(*ogrs.Session) + if session.TargetID == targetID { + sessions = append(sessions, session) } return true }) - return connections -} - -type connection struct { - targetID string - servletCtx server.ServletCtx -} - -var connectionPool sync.Pool - -func retainConnection(targetID string, servletCtx server.ServletCtx) *connection { - return nil -} - -func releaseConnection(_connection *connection) { - _connection.targetID = "" - _connection.servletCtx = nil - - connectionPool.Put(_connection) + return sessions } diff --git a/servlet/probe-servlet.go b/servlet/probe-servlet.go index 16a82f4..9b57490 100644 --- a/servlet/probe-servlet.go +++ b/servlet/probe-servlet.go @@ -113,7 +113,7 @@ func (s *ProbeServlets) OnConnect(servletCtx server.ServletCtx, conn socket.Conn sessionID := servletCtx.GetAttribute(og.SessionIDKey) targetID := servletCtx.GetAttribute(og.SessionTargetIDKey) if nil != sessionID && nil != targetID { - s.connections.Store(sessionID.(string), retainConnection(targetID.(string), servletCtx)) + s.connections.Store(sessionID.(string), ogrs.RetainSession(targetID.(string), servletCtx)) } } @@ -137,13 +137,13 @@ func (s *ProbeServlets) handleSubscribe(serverCtx server.ServerCtx, subscribeCha switch msg.TargetType { case ogs.PROBE: for _, targetID := range msg.Targets { - _connections := s.getProbeConnections(targetID) - if nil == _connections || 0 == len(_connections) { + _sessions := s.getProbeSessions(targetID) + if nil == _sessions || 0 == len(_sessions) { break } - for _, _connection := range _connections { - _writeChan := _connection.servletCtx.GetAttribute(og.SessionWriteChanKey) + for _, _session := range _sessions { + _writeChan := _session.ServletCtx.GetAttribute(og.SessionWriteChanKey) if nil != _writeChan { writeChan := _writeChan.(chan<- []byte) writeChan <- msg.Message @@ -155,34 +155,16 @@ func (s *ProbeServlets) handleSubscribe(serverCtx server.ServerCtx, subscribeCha } } -func (s *ProbeServlets) getProbeConnections(targetID string) []*connection { - var connections []*connection +func (s *ProbeServlets) getProbeSessions(targetID string) []*ogrs.Session { + var sessions []*ogrs.Session s.connections.Range(func(k, v interface{}) bool { - _connection := v.(*connection) - if _connection.targetID == targetID { - connections = append(connections, _connection) + session := v.(*ogrs.Session) + if session.TargetID == targetID { + sessions = append(sessions, session) } return true }) - return connections + return sessions } - -//type connection struct { -// targetID string -// servletCtx server.ServletCtx -//} -// -//var connectionPool sync.Pool -// -//func retainConnection(targetID string, servletCtx server.ServletCtx) *connection { -// return nil -//} -// -//func releaseConnection(_connection *connection) { -// _connection.targetID = "" -// _connection.servletCtx = nil -// -// connectionPool.Put(_connection) -//}