probe/server/scanner-servlet.go
crusader 7db69d4551 ing
2018-09-14 02:14:57 +09:00

263 lines
5.9 KiB
Go

package server
import (
"fmt"
"log"
"sync"
"git.loafle.net/overflow_scanner/probe/internal/pubsub"
"git.loafle.net/overflow_scanner/probe/internal/rpc"
olog "git.loafle.net/overflow/log-go"
orp "git.loafle.net/overflow/rpc-go/protocol"
orr "git.loafle.net/overflow/rpc-go/registry"
"git.loafle.net/overflow/server-go"
oss "git.loafle.net/overflow/server-go/socket"
ossw "git.loafle.net/overflow/server-go/socket/web"
"git.loafle.net/overflow_scanner/probe"
uuid "github.com/satori/go.uuid"
"github.com/valyala/fasthttp"
)
type (
	// ScannerServlet is the servlet interface of the scanner endpoint. It embeds
	// ossw.Servlet and declares no additional methods.
	ScannerServlet interface {
		ossw.Servlet
	}

	// ScannerServlets answers RPC requests on top of the embedded
	// ossw.Servlets, keeps one Session per connected requester, and relays
	// messages published on the "/scanner" PubSub topic to those sessions as
	// RPC notifications.
	ScannerServlets struct {
		ossw.Servlets

		// RPCInvoker checks for and invokes the methods named by requests.
		RPCInvoker orr.RPCInvoker
		// RPCServerCodec decodes requests and encodes notifications.
		RPCServerCodec orp.ServerCodec
		// PubSub supplies the "/scanner" messages relayed to sessions.
		PubSub *pubsub.PubSub

		// sessions maps a session ID (string) to its *Session; entries are
		// added in OnConnect and removed in OnDisconnect.
		sessions sync.Map
		// subscribeChan is the "/scanner" subscription opened in OnStart.
		subscribeChan chan interface{}
	}
)
// Init initialises the embedded ossw.Servlets and reports its error, if any.
// ScannerServlets itself needs no further initialisation.
func (s *ScannerServlets) Init(serverCtx server.ServerCtx) error {
	return s.Servlets.Init(serverCtx)
}
// OnStart starts the embedded servlets. On success it subscribes to the
// "/scanner" PubSub topic and launches handleSubscribe, which relays every
// published message to the matching sessions.
func (s *ScannerServlets) OnStart(serverCtx server.ServerCtx) error {
	const topic = "/scanner"

	err := s.Servlets.OnStart(serverCtx)
	if nil != err {
		return err
	}

	s.subscribeChan = s.PubSub.Sub(topic)
	go s.handleSubscribe(serverCtx)

	return nil
}
// OnStop drops the "/scanner" subscription taken in OnStart, then stops the
// embedded servlets.
// NOTE(review): handleSubscribe only exits when subscribeChan is closed;
// presumably PubSub.Unsub closes it — confirm in the pubsub package.
func (s *ScannerServlets) OnStop(serverCtx server.ServerCtx) {
	const topic = "/scanner"

	s.PubSub.Unsub(s.subscribeChan, topic)
	s.Servlets.OnStop(serverCtx)
}
// Destroy hands teardown to the embedded ossw.Servlets. ScannerServlets adds
// no teardown of its own at this stage.
func (s *ScannerServlets) Destroy(serverCtx server.ServerCtx) {
	s.Servlets.Destroy(serverCtx)
}
// Handshake accepts a connection only when the request carries a non-empty
// "requesterID" query argument.
//
// On success it records the requester ID under probe.SessionTargetIDKey and a
// freshly generated v4 UUID under probe.SessionIDKey on servletCtx, which
// OnConnect later uses to register the Session. It returns a nil header.
func (s *ScannerServlets) Handshake(servletCtx server.ServletCtx, ctx *fasthttp.RequestCtx) (*fasthttp.ResponseHeader, error) {
	rawRequesterID := ctx.QueryArgs().Peek("requesterID")
	if len(rawRequesterID) == 0 {
		return nil, fmt.Errorf("requesterID is not valid")
	}

	// Peek's bytes belong to the request, so copy them into a string before
	// storing them beyond the handshake.
	newSessionID := uuid.NewV4().String()
	servletCtx.SetAttribute(probe.SessionTargetIDKey, string(rawRequesterID))
	servletCtx.SetAttribute(probe.SessionIDKey, newSessionID)

	return nil, nil
}
// OnConnect notifies the embedded servlets of the new connection. It then
// registers a Session (built by RetainSession) in s.sessions, keyed by the
// session ID that Handshake stored on servletCtx.
//
// Nothing is registered when either handshake attribute is missing.
func (s *ScannerServlets) OnConnect(servletCtx server.ServletCtx, conn oss.Conn) {
	s.Servlets.OnConnect(servletCtx, conn)

	rawSessionID := servletCtx.GetAttribute(probe.SessionIDKey)
	rawTargetID := servletCtx.GetAttribute(probe.SessionTargetIDKey)
	if nil == rawSessionID || nil == rawTargetID {
		return
	}

	key := rawSessionID.(string)
	session := RetainSession(rawTargetID.(string), servletCtx)
	s.sessions.Store(key, session)
}
// OnDisconnect notifies the embedded servlets that the connection is gone. It
// then removes this connection's Session from s.sessions and releases it with
// ReleaseSession.
//
// The entry is deleted from the table *before* ReleaseSession runs. Lookups
// that start afterwards (GetSessions, GetSessionsByTargetIDs — the latter is
// used by the handleSubscribe goroutine) therefore cannot return a Session
// that has already been released.
// NOTE(review): a lookup that obtained the Session just before the delete can
// still race with ReleaseSession; that depends on ReleaseSession's semantics,
// which are defined elsewhere.
func (s *ScannerServlets) OnDisconnect(servletCtx server.ServletCtx) {
	s.Servlets.OnDisconnect(servletCtx)

	sessionID := servletCtx.GetAttribute(probe.SessionIDKey)
	if nil == sessionID {
		return
	}

	key := sessionID.(string)
	session, ok := s.sessions.Load(key)
	// Unpublish first, then release.
	s.sessions.Delete(key)
	if ok {
		ReleaseSession(session.(*Session))
	}
}
// Handle serves RPC requests for one connection until readChan is closed or
// stopChan fires.
//
// Each message read from readChan is decoded with s.RPCServerCodec and
// dispatched through s.RPCInvoker (see dispatchRequest). When the request
// expects a response, the encoded response, or an RPC error, is written to
// writeChan. writeChan is also published on servletCtx under
// probe.SessionWriteChanKey so handleSubscribe can push notifications to this
// connection.
//
// Every write to writeChan also watches stopChan. A stop request therefore
// cannot be missed while the handler is blocked on a writer that no longer
// drains the channel. doneChan is closed when Handle returns.
// NOTE(review): assumes the server waits on doneChan for this handler to exit
// (it is send-only here and was previously never signalled) — confirm against
// server-go.
func (s *ScannerServlets) Handle(servletCtx server.ServletCtx,
	stopChan <-chan struct{}, doneChan chan<- struct{},
	readChan <-chan oss.SocketMessage, writeChan chan<- oss.SocketMessage) {
	defer close(doneChan)

	servletCtx.SetAttribute(probe.SessionWriteChanKey, writeChan)
	// Withdraw the write channel when this handler returns. handleSubscribe
	// skips sessions whose attribute is nil, so it stops queueing notifications
	// for this connection (assuming SetAttribute(key, nil) makes GetAttribute
	// return nil). This defer runs before doneChan is closed.
	defer servletCtx.SetAttribute(probe.SessionWriteChanKey, nil)

	for {
		select {
		case socketMessage, ok := <-readChan:
			if !ok {
				return
			}
			response := s.dispatchRequest(socketMessage)
			if nil == response {
				continue
			}
			select {
			case writeChan <- response:
			case <-stopChan:
				return
			}
		case <-stopChan:
			return
		}
	}
}

// dispatchRequest decodes socketMessage as an RPC request, invokes it, and
// returns the message to send back. It returns nil when nothing should be sent:
// the request could not be decoded, or it does not expect a response.
func (s *ScannerServlets) dispatchRequest(socketMessage oss.SocketMessage) oss.SocketMessage {
	messageType, message := socketMessage()

	src, err := s.RPCServerCodec.NewRequest(messageType, message)
	if nil != err {
		olog.Logger().Error(err.Error())
		return nil
	}

	if !s.RPCInvoker.HasMethod(src.Method()) {
		return s.newErrorResponse(src, orp.E_NO_METHOD, "", fmt.Errorf("%s is not exist", src.Method()))
	}

	// An invocation error is logged here and then encoded into the response.
	result, err := s.RPCInvoker.Invoke(src)
	if nil != err {
		olog.Logger().Error(err.Error())
	}
	if !src.HasResponse() {
		return nil
	}

	resMessageType, resMessage, err := src.NewResponse(result, err)
	if nil != err {
		olog.Logger().Error(err.Error())
		return s.newErrorResponse(src, orp.E_INTERNAL, "", err)
	}
	return oss.MakeSocketMessage(resMessageType, resMessage)
}

// newErrorResponse encodes an RPC error (code, message, data) as the response
// to src. It returns nil when src does not expect a response, or when encoding
// fails; an encoding failure is logged.
func (s *ScannerServlets) newErrorResponse(src orp.ServerRequestCodec, code orp.ErrorCode, message string, data interface{}) oss.SocketMessage {
	if !src.HasResponse() {
		return nil
	}

	resMessageType, resMessage, err := src.NewResponse(nil, &orp.Error{
		Code:    code,
		Message: message,
		Data:    data,
	})
	if nil != err {
		olog.Logger().Error(err.Error())
		return nil
	}
	return oss.MakeSocketMessage(resMessageType, resMessage)
}
// handleSubscribe relays messages published on the "/scanner" topic to the
// sessions they target. It runs until s.subscribeChan is closed.
//
// Each published value must be an rpc.RPCMessage. Anything else is logged and
// skipped. The message's method and params are encoded once as an RPC
// notification. The same bytes are then queued on the write channel of every
// session whose TargetID is listed in the message.
// NOTE(review): the send on each session's write channel blocks, so a
// connection that stops draining its channel stalls delivery to every session.
func (s *ScannerServlets) handleSubscribe(serverCtx server.ServerCtx) {
	for published := range s.subscribeChan {
		rpcMessage, isRPCMessage := published.(rpc.RPCMessage)
		if !isRPCMessage {
			log.Print("RPCMessage is not valid")
			continue
		}

		targets, method, params := rpcMessage()
		recipients := s.GetSessionsByTargetIDs(targets)
		if len(recipients) == 0 {
			continue
		}

		messageType, message, err := s.RPCServerCodec.NewNotification(method, params)
		if nil != err {
			log.Print("RPCMessage is not valid ", rpcMessage, err)
			continue
		}

		for _, recipient := range recipients {
			rawWriteChan := recipient.ServletCtx.GetAttribute(probe.SessionWriteChanKey)
			if nil == rawWriteChan {
				continue
			}
			out := rawWriteChan.(chan<- oss.SocketMessage)
			out <- oss.MakeSocketMessage(messageType, message)
		}
	}
}
// writeError encodes an RPC error (code, message, data) as the response to src
// and sends it on writeChan. The send blocks until writeChan accepts it.
//
// Nothing is sent when src.HasResponse() is false. If the error response
// cannot be encoded, the failure is logged and nothing is sent.
func (s *ScannerServlets) writeError(src orp.ServerRequestCodec, writeChan chan<- oss.SocketMessage, code orp.ErrorCode, message string, data interface{}) {
	if !src.HasResponse() {
		return
	}

	resMessageType, resMessage, err := src.NewResponse(nil, &orp.Error{Code: code, Message: message, Data: data})
	if nil == err {
		writeChan <- oss.MakeSocketMessage(resMessageType, resMessage)
		return
	}
	olog.Logger().Error(err.Error())
}
// GetSessions returns the registered Sessions for the given session IDs, in
// the order the IDs are given. Unknown IDs are skipped. The result is nil when
// sessionIDs is empty or none of the IDs is registered.
func (s *ScannerServlets) GetSessions(sessionIDs []string) []*Session {
	var found []*Session
	for i := range sessionIDs {
		if value, ok := s.sessions.Load(sessionIDs[i]); ok {
			found = append(found, value.(*Session))
		}
	}
	return found
}
// GetSessionsByTargetIDs returns every registered Session whose TargetID
// appears in targetIDs. Each Session is included at most once, even if
// targetIDs repeats an ID. Ordering follows sync.Map.Range and is unspecified.
// The result is nil when targetIDs is empty or nothing matches.
func (s *ScannerServlets) GetSessionsByTargetIDs(targetIDs []string) []*Session {
	var matched []*Session
	if len(targetIDs) == 0 {
		return matched
	}

	wanted := make(map[string]struct{}, len(targetIDs))
	for _, id := range targetIDs {
		wanted[id] = struct{}{}
	}

	s.sessions.Range(func(_, value interface{}) bool {
		session := value.(*Session)
		if _, hit := wanted[session.TargetID]; hit {
			matched = append(matched, session)
		}
		return true
	})
	return matched
}