263 lines
5.9 KiB
Go
263 lines
5.9 KiB
Go
package server
|
|
|
|
import (
|
|
"fmt"
|
|
"log"
|
|
"sync"
|
|
|
|
"git.loafle.net/overflow_scanner/probe/internal/pubsub"
|
|
"git.loafle.net/overflow_scanner/probe/internal/rpc"
|
|
|
|
olog "git.loafle.net/overflow/log-go"
|
|
orp "git.loafle.net/overflow/rpc-go/protocol"
|
|
orr "git.loafle.net/overflow/rpc-go/registry"
|
|
"git.loafle.net/overflow/server-go"
|
|
oss "git.loafle.net/overflow/server-go/socket"
|
|
ossw "git.loafle.net/overflow/server-go/socket/web"
|
|
"git.loafle.net/overflow_scanner/probe"
|
|
uuid "github.com/satori/go.uuid"
|
|
"github.com/valyala/fasthttp"
|
|
)
|
|
|
|
type ScannerServlet interface {
|
|
ossw.Servlet
|
|
}
|
|
|
|
type ScannerServlets struct {
|
|
ossw.Servlets
|
|
|
|
RPCInvoker orr.RPCInvoker
|
|
RPCServerCodec orp.ServerCodec
|
|
PubSub *pubsub.PubSub
|
|
|
|
sessions sync.Map
|
|
|
|
subscribeChan chan interface{}
|
|
}
|
|
|
|
func (s *ScannerServlets) Init(serverCtx server.ServerCtx) error {
|
|
if err := s.Servlets.Init(serverCtx); nil != err {
|
|
return err
|
|
}
|
|
|
|
return nil
|
|
}
|
|
|
|
func (s *ScannerServlets) OnStart(serverCtx server.ServerCtx) error {
|
|
if err := s.Servlets.OnStart(serverCtx); nil != err {
|
|
return err
|
|
}
|
|
|
|
s.subscribeChan = s.PubSub.Sub("/scanner")
|
|
go s.handleSubscribe(serverCtx)
|
|
|
|
return nil
|
|
}
|
|
|
|
func (s *ScannerServlets) OnStop(serverCtx server.ServerCtx) {
|
|
|
|
s.PubSub.Unsub(s.subscribeChan, "/scanner")
|
|
|
|
s.Servlets.OnStop(serverCtx)
|
|
}
|
|
|
|
func (s *ScannerServlets) Destroy(serverCtx server.ServerCtx) {
|
|
|
|
s.Servlets.Destroy(serverCtx)
|
|
}
|
|
|
|
func (s *ScannerServlets) Handshake(servletCtx server.ServletCtx, ctx *fasthttp.RequestCtx) (*fasthttp.ResponseHeader, error) {
|
|
requesterID := string(ctx.QueryArgs().Peek("requesterID"))
|
|
if "" == requesterID {
|
|
return nil, fmt.Errorf("requesterID is not valid")
|
|
}
|
|
|
|
sessionID := uuid.NewV4().String()
|
|
|
|
servletCtx.SetAttribute(probe.SessionTargetIDKey, requesterID)
|
|
servletCtx.SetAttribute(probe.SessionIDKey, sessionID)
|
|
|
|
return nil, nil
|
|
}
|
|
|
|
func (s *ScannerServlets) OnConnect(servletCtx server.ServletCtx, conn oss.Conn) {
|
|
s.Servlets.OnConnect(servletCtx, conn)
|
|
|
|
sessionID := servletCtx.GetAttribute(probe.SessionIDKey)
|
|
targetID := servletCtx.GetAttribute(probe.SessionTargetIDKey)
|
|
|
|
if nil != sessionID && nil != targetID {
|
|
s.sessions.Store(sessionID.(string), RetainSession(targetID.(string), servletCtx))
|
|
}
|
|
}
|
|
|
|
func (s *ScannerServlets) OnDisconnect(servletCtx server.ServletCtx) {
|
|
s.Servlets.OnDisconnect(servletCtx)
|
|
|
|
sessionID := servletCtx.GetAttribute(probe.SessionIDKey)
|
|
if nil != sessionID {
|
|
session, ok := s.sessions.Load(sessionID)
|
|
if ok {
|
|
ReleaseSession(session.(*Session))
|
|
}
|
|
s.sessions.Delete(sessionID.(string))
|
|
}
|
|
}
|
|
|
|
func (s *ScannerServlets) Handle(servletCtx server.ServletCtx,
|
|
stopChan <-chan struct{}, doneChan chan<- struct{},
|
|
readChan <-chan oss.SocketMessage, writeChan chan<- oss.SocketMessage) {
|
|
|
|
var (
|
|
src orp.ServerRequestCodec
|
|
messageType int
|
|
message []byte
|
|
result interface{}
|
|
resMessageType int
|
|
resMessage []byte
|
|
err error
|
|
)
|
|
|
|
servletCtx.SetAttribute(probe.SessionWriteChanKey, writeChan)
|
|
|
|
for {
|
|
select {
|
|
case socketMessage, ok := <-readChan:
|
|
if !ok {
|
|
return
|
|
}
|
|
messageType, message = socketMessage()
|
|
// grpc exec method call
|
|
src, err = s.RPCServerCodec.NewRequest(messageType, message)
|
|
if nil != err {
|
|
olog.Logger().Error(err.Error())
|
|
break
|
|
}
|
|
|
|
if !s.RPCInvoker.HasMethod(src.Method()) {
|
|
s.writeError(src, writeChan, orp.E_NO_METHOD, "", fmt.Errorf("%s is not exist", src.Method()))
|
|
break
|
|
}
|
|
|
|
result, err = s.RPCInvoker.Invoke(src)
|
|
if nil != err {
|
|
olog.Logger().Error(err.Error())
|
|
}
|
|
|
|
if !src.HasResponse() {
|
|
break
|
|
}
|
|
resMessageType, resMessage, err = src.NewResponse(result, err)
|
|
if nil != err {
|
|
olog.Logger().Error(err.Error())
|
|
s.writeError(src, writeChan, orp.E_INTERNAL, "", err)
|
|
break
|
|
}
|
|
|
|
writeChan <- oss.MakeSocketMessage(resMessageType, resMessage)
|
|
case <-stopChan:
|
|
return
|
|
}
|
|
}
|
|
}
|
|
|
|
func (s *ScannerServlets) handleSubscribe(serverCtx server.ServerCtx) {
|
|
var sessions []*Session
|
|
|
|
LOOP:
|
|
for {
|
|
select {
|
|
case msg, ok := <-s.subscribeChan:
|
|
if !ok {
|
|
return
|
|
}
|
|
|
|
_msg, ok := msg.(rpc.RPCMessage)
|
|
if !ok {
|
|
log.Print("RPCMessage is not valid")
|
|
continue LOOP
|
|
}
|
|
|
|
targets, method, params := _msg()
|
|
|
|
sessions = s.GetSessionsByTargetIDs(targets)
|
|
|
|
if nil == sessions || 0 == len(sessions) {
|
|
continue LOOP
|
|
}
|
|
|
|
messageType, message, err := s.RPCServerCodec.NewNotification(method, params)
|
|
if nil != err {
|
|
log.Print("RPCMessage is not valid ", _msg, err)
|
|
continue LOOP
|
|
}
|
|
|
|
for _, session := range sessions {
|
|
_writeChan := session.ServletCtx.GetAttribute(probe.SessionWriteChanKey)
|
|
if nil != _writeChan {
|
|
writeChan := _writeChan.(chan<- oss.SocketMessage)
|
|
writeChan <- oss.MakeSocketMessage(messageType, message)
|
|
}
|
|
}
|
|
}
|
|
}
|
|
}
|
|
|
|
func (s *ScannerServlets) writeError(src orp.ServerRequestCodec, writeChan chan<- oss.SocketMessage, code orp.ErrorCode, message string, data interface{}) {
|
|
if !src.HasResponse() {
|
|
return
|
|
}
|
|
|
|
pErr := &orp.Error{
|
|
Code: code,
|
|
Message: message,
|
|
Data: data,
|
|
}
|
|
|
|
resMessageType, resMessage, err := src.NewResponse(nil, pErr)
|
|
if nil != err {
|
|
olog.Logger().Error(err.Error())
|
|
return
|
|
}
|
|
writeChan <- oss.MakeSocketMessage(resMessageType, resMessage)
|
|
}
|
|
|
|
func (s *ScannerServlets) GetSessions(sessionIDs []string) []*Session {
|
|
var sessions []*Session
|
|
|
|
if nil == sessionIDs || 0 == len(sessionIDs) {
|
|
return sessions
|
|
}
|
|
|
|
for _, sessionID := range sessionIDs {
|
|
session, ok := s.sessions.Load(sessionID)
|
|
if ok {
|
|
sessions = append(sessions, session.(*Session))
|
|
}
|
|
}
|
|
|
|
return sessions
|
|
}
|
|
|
|
func (s *ScannerServlets) GetSessionsByTargetIDs(targetIDs []string) []*Session {
|
|
var sessions []*Session
|
|
if nil == targetIDs || 0 == len(targetIDs) {
|
|
return sessions
|
|
}
|
|
|
|
s.sessions.Range(func(k, v interface{}) bool {
|
|
session := v.(*Session)
|
|
|
|
for _, targetID := range targetIDs {
|
|
if session.TargetID == targetID {
|
|
sessions = append(sessions, session)
|
|
break
|
|
}
|
|
}
|
|
|
|
return true
|
|
})
|
|
|
|
return sessions
|
|
}
|