package server import ( "fmt" "log" "sync" "git.loafle.net/overflow_scanner/probe/internal/pubsub" "git.loafle.net/overflow_scanner/probe/internal/rpc" olog "git.loafle.net/overflow/log-go" orp "git.loafle.net/overflow/rpc-go/protocol" orr "git.loafle.net/overflow/rpc-go/registry" "git.loafle.net/overflow/server-go" oss "git.loafle.net/overflow/server-go/socket" ossw "git.loafle.net/overflow/server-go/socket/web" "git.loafle.net/overflow_scanner/probe" uuid "github.com/satori/go.uuid" "github.com/valyala/fasthttp" ) type ScannerServlet interface { ossw.Servlet } type ScannerServlets struct { ossw.Servlets RPCInvoker orr.RPCInvoker RPCServerCodec orp.ServerCodec PubSub *pubsub.PubSub sessions sync.Map subscribeChan chan interface{} } func (s *ScannerServlets) Init(serverCtx server.ServerCtx) error { if err := s.Servlets.Init(serverCtx); nil != err { return err } return nil } func (s *ScannerServlets) OnStart(serverCtx server.ServerCtx) error { if err := s.Servlets.OnStart(serverCtx); nil != err { return err } s.subscribeChan = s.PubSub.Sub("/scanner") go s.handleSubscribe(serverCtx) return nil } func (s *ScannerServlets) OnStop(serverCtx server.ServerCtx) { s.PubSub.Unsub(s.subscribeChan, "/scanner") s.Servlets.OnStop(serverCtx) } func (s *ScannerServlets) Destroy(serverCtx server.ServerCtx) { s.Servlets.Destroy(serverCtx) } func (s *ScannerServlets) Handshake(servletCtx server.ServletCtx, ctx *fasthttp.RequestCtx) (*fasthttp.ResponseHeader, error) { requesterID := string(ctx.QueryArgs().Peek("requesterID")) if "" == requesterID { return nil, fmt.Errorf("requesterID is not valid") } sessionID := uuid.NewV4().String() servletCtx.SetAttribute(probe.SessionTargetIDKey, requesterID) servletCtx.SetAttribute(probe.SessionIDKey, sessionID) return nil, nil } func (s *ScannerServlets) OnConnect(servletCtx server.ServletCtx, conn oss.Conn) { s.Servlets.OnConnect(servletCtx, conn) sessionID := servletCtx.GetAttribute(probe.SessionIDKey) targetID := servletCtx.GetAttribute(probe.SessionTargetIDKey) if nil != sessionID && nil != targetID { s.sessions.Store(sessionID.(string), RetainSession(targetID.(string), servletCtx)) } } func (s *ScannerServlets) OnDisconnect(servletCtx server.ServletCtx) { s.Servlets.OnDisconnect(servletCtx) sessionID := servletCtx.GetAttribute(probe.SessionIDKey) if nil != sessionID { session, ok := s.sessions.Load(sessionID) if ok { ReleaseSession(session.(*Session)) } s.sessions.Delete(sessionID.(string)) } } func (s *ScannerServlets) Handle(servletCtx server.ServletCtx, stopChan <-chan struct{}, doneChan chan<- struct{}, readChan <-chan oss.SocketMessage, writeChan chan<- oss.SocketMessage) { var ( src orp.ServerRequestCodec messageType int message []byte result interface{} resMessageType int resMessage []byte err error ) servletCtx.SetAttribute(probe.SessionWriteChanKey, writeChan) for { select { case socketMessage, ok := <-readChan: if !ok { return } messageType, message = socketMessage() // grpc exec method call src, err = s.RPCServerCodec.NewRequest(messageType, message) if nil != err { olog.Logger().Error(err.Error()) break } if !s.RPCInvoker.HasMethod(src.Method()) { olog.Logger().Error(err.Error()) s.writeError(src, writeChan, orp.E_NO_METHOD, "", err) break } result, err = s.RPCInvoker.Invoke(src) if nil != err { olog.Logger().Error(err.Error()) } if !src.HasResponse() { break } resMessageType, resMessage, err = src.NewResponse(result, err) if nil != err { olog.Logger().Error(err.Error()) s.writeError(src, writeChan, orp.E_INTERNAL, "", err) break } writeChan <- oss.MakeSocketMessage(resMessageType, resMessage) case <-stopChan: return } } } func (s *ScannerServlets) handleSubscribe(serverCtx server.ServerCtx) { var sessions []*Session LOOP: for { select { case msg, ok := <-s.subscribeChan: if !ok { return } _msg, ok := msg.(rpc.RPCMessage) if !ok { log.Print("RPCMessage is not valid") continue LOOP } targets, method, params := _msg() sessions = s.GetSessionsByTargetIDs(targets) if nil == sessions || 0 == len(sessions) { continue LOOP } messageType, message, err := s.RPCServerCodec.NewNotification(method, params) if nil != err { log.Print("RPCMessage is not valid ", _msg, err) continue LOOP } for _, session := range sessions { _writeChan := session.ServletCtx.GetAttribute(probe.SessionWriteChanKey) if nil != _writeChan { writeChan := _writeChan.(chan<- oss.SocketMessage) writeChan <- oss.MakeSocketMessage(messageType, message) } } } } } func (s *ScannerServlets) writeError(src orp.ServerRequestCodec, writeChan chan<- oss.SocketMessage, code orp.ErrorCode, message string, data interface{}) { if !src.HasResponse() { return } pErr := &orp.Error{ Code: code, Message: message, Data: data, } resMessageType, resMessage, err := src.NewResponse(nil, pErr) if nil != err { olog.Logger().Error(err.Error()) return } writeChan <- oss.MakeSocketMessage(resMessageType, resMessage) } func (s *ScannerServlets) GetSessions(sessionIDs []string) []*Session { var sessions []*Session if nil == sessionIDs || 0 == len(sessionIDs) { return sessions } for _, sessionID := range sessionIDs { session, ok := s.sessions.Load(sessionID) if ok { sessions = append(sessions, session.(*Session)) } } return sessions } func (s *ScannerServlets) GetSessionsByTargetIDs(targetIDs []string) []*Session { var sessions []*Session if nil == targetIDs || 0 == len(targetIDs) { return sessions } s.sessions.Range(func(k, v interface{}) bool { session := v.(*Session) for _, targetID := range targetIDs { if session.TargetID == targetID { sessions = append(sessions, session) break } } return true }) return sessions }