diff --git a/servlet/auth-servlet.go b/servlet/auth-servlet.go index 4b31a9d..620241e 100644 --- a/servlet/auth-servlet.go +++ b/servlet/auth-servlet.go @@ -100,6 +100,10 @@ func (s *AuthServlets) Handshake(servletCtx server.ServletCtx, ctx *fasthttp.Req extHeader := &fasthttp.ResponseHeader{} extHeader.Add(ocnc.HTTPResponseHeaderKey_NoAuthProbe_SetTempProbeKey, nap.TempProbeKey) + servletCtx.SetAttribute(og.SessionIDKey, nap.TempProbeKey) + servletCtx.SetAttribute(og.SessionClientTypeKey, og.PROBE) + servletCtx.SetAttribute(og.SessionTargetIDKey, nap.TempProbeKey) + return extHeader, nil case ocnc.HTTPRequestHeaderValue_NoAuthProbe_Method_Connect: bTempProbeKey := ctx.Request.Header.Peek(ocnc.HTTPRequestHeaderKey_NoAuthProbe_TempProbeKey) @@ -146,6 +150,7 @@ func (s *AuthServlets) OnDisconnect(servletCtx server.ServletCtx) { } func (s *AuthServlets) handleSubscribe(serverCtx server.ServerCtx, subscribeChan <-chan *ogs.Message) { +LOOP: for { select { case msg, ok := <-subscribeChan: @@ -155,36 +160,36 @@ func (s *AuthServlets) handleSubscribe(serverCtx server.ServerCtx, subscribeChan switch msg.TargetType { case ogs.PROBE: - for _, targetID := range msg.Targets { - _sessions := s.getAuthSessions(targetID) - if nil == _sessions || 0 == len(_sessions) { - break - } - - for _, _session := range _sessions { - _writeChan := _session.ServletCtx.GetAttribute(og.SessionWriteChanKey) - if nil != _writeChan { - writeChan := _writeChan.(chan<- []byte) - writeChan <- msg.Message - } - } + sessions := s.getAuthSessions(msg.Targets) + if nil == sessions || 0 == len(sessions) { + continue LOOP } + for _, session := range sessions { + _writeChan := session.ServletCtx.GetAttribute(og.SessionWriteChanKey) + if nil != _writeChan { + writeChan := _writeChan.(chan<- []byte) + writeChan <- msg.Message + } + } } } } } -func (s *AuthServlets) getAuthSessions(targetID string) []*ogrs.Session { +func (s *AuthServlets) getAuthSessions(targetIDs []string) []*ogrs.Session { var sessions []*ogrs.Session - s.sessions.Range(func(k, v interface{}) bool { - session := v.(*ogrs.Session) - if session.TargetID == targetID { - sessions = append(sessions, session) + if nil == targetIDs || 0 == len(targetIDs) { + return sessions + } + + for _, targetID := range targetIDs { + session, ok := s.sessions.Load(targetID) + if ok { + sessions = append(sessions, session.(*ogrs.Session)) } - return true - }) + } return sessions } diff --git a/servlet/probe-servlet.go b/servlet/probe-servlet.go index 9b57490..bf4d076 100644 --- a/servlet/probe-servlet.go +++ b/servlet/probe-servlet.go @@ -26,7 +26,7 @@ type ProbeServlet interface { type ProbeServlets struct { ogrs.RPCServlets - connections sync.Map + sessions sync.Map } func (s *ProbeServlets) Init(serverCtx server.ServerCtx) error { @@ -113,7 +113,7 @@ func (s *ProbeServlets) OnConnect(servletCtx server.ServletCtx, conn socket.Conn sessionID := servletCtx.GetAttribute(og.SessionIDKey) targetID := servletCtx.GetAttribute(og.SessionTargetIDKey) if nil != sessionID && nil != targetID { - s.connections.Store(sessionID.(string), ogrs.RetainSession(targetID.(string), servletCtx)) + s.sessions.Store(sessionID.(string), ogrs.RetainSession(targetID.(string), servletCtx)) } } @@ -122,11 +122,13 @@ func (s *ProbeServlets) OnDisconnect(servletCtx server.ServletCtx) { sessionID := servletCtx.GetAttribute(og.SessionIDKey) if nil != sessionID { - s.connections.Delete(sessionID.(string)) + s.sessions.Delete(sessionID.(string)) } } func (s *ProbeServlets) handleSubscribe(serverCtx server.ServerCtx, subscribeChan <-chan *ogs.Message) { + +LOOP: for { select { case msg, ok := <-subscribeChan: @@ -136,18 +138,16 @@ func (s *ProbeServlets) handleSubscribe(serverCtx server.ServerCtx, subscribeCha switch msg.TargetType { case ogs.PROBE: - for _, targetID := range msg.Targets { - _sessions := s.getProbeSessions(targetID) - if nil == _sessions || 0 == len(_sessions) { - break - } + sessions := s.getProbeSessions(msg.Targets) + if nil == sessions || 0 == len(sessions) { + continue LOOP + } - for _, _session := range _sessions { - _writeChan := _session.ServletCtx.GetAttribute(og.SessionWriteChanKey) - if nil != _writeChan { - writeChan := _writeChan.(chan<- []byte) - writeChan <- msg.Message - } + for _, session := range sessions { + _writeChan := session.ServletCtx.GetAttribute(og.SessionWriteChanKey) + if nil != _writeChan { + writeChan := _writeChan.(chan<- []byte) + writeChan <- msg.Message } } } @@ -155,16 +155,19 @@ func (s *ProbeServlets) handleSubscribe(serverCtx server.ServerCtx, subscribeCha } } -func (s *ProbeServlets) getProbeSessions(targetID string) []*ogrs.Session { +func (s *ProbeServlets) getProbeSessions(targetIDs []string) []*ogrs.Session { var sessions []*ogrs.Session - s.connections.Range(func(k, v interface{}) bool { - session := v.(*ogrs.Session) - if session.TargetID == targetID { - sessions = append(sessions, session) + if nil == targetIDs || 0 == len(targetIDs) { + return sessions + } + + for _, targetID := range targetIDs { + session, ok := s.sessions.Load(targetID) + if ok { + sessions = append(sessions, session.(*ogrs.Session)) } - return true - }) + } return sessions }