194 lines
4.3 KiB
Go
194 lines
4.3 KiB
Go
package servlet
|
|
|
|
import (
|
|
"context"
|
|
"sync"
|
|
|
|
logging "git.loafle.net/commons/logging-go"
|
|
crp "git.loafle.net/commons/rpc-go/protocol"
|
|
"git.loafle.net/commons/server-go"
|
|
css "git.loafle.net/commons/server-go/socket"
|
|
cssw "git.loafle.net/commons/server-go/socket/web"
|
|
oe "git.loafle.net/overflow/external-go"
|
|
oeg "git.loafle.net/overflow/external-go/grpc"
|
|
og "git.loafle.net/overflow/gateway"
|
|
"github.com/valyala/fasthttp"
|
|
"google.golang.org/grpc/metadata"
|
|
)
|
|
|
|
type RPCServlet interface {
|
|
cssw.Servlet
|
|
}
|
|
|
|
type RPCServlets struct {
|
|
cssw.Servlets
|
|
|
|
RPCServerCodec crp.ServerCodec
|
|
UseSession bool
|
|
|
|
sessions sync.Map
|
|
}
|
|
|
|
func (s *RPCServlets) Handshake(servletCtx server.ServletCtx, ctx *fasthttp.RequestCtx) (*fasthttp.ResponseHeader, error) {
|
|
return nil, nil
|
|
}
|
|
|
|
func (s *RPCServlets) OnConnect(servletCtx server.ServletCtx, conn css.Conn) {
|
|
s.Servlets.OnConnect(servletCtx, conn)
|
|
|
|
if s.UseSession {
|
|
sessionID := servletCtx.GetAttribute(og.SessionIDKey)
|
|
targetID := servletCtx.GetAttribute(og.SessionTargetIDKey)
|
|
if nil != sessionID && nil != targetID {
|
|
s.sessions.Store(sessionID.(string), RetainSession(targetID.(string), servletCtx))
|
|
}
|
|
}
|
|
}
|
|
|
|
func (s *RPCServlets) OnDisconnect(servletCtx server.ServletCtx) {
|
|
s.Servlets.OnDisconnect(servletCtx)
|
|
|
|
if s.UseSession {
|
|
sessionID := servletCtx.GetAttribute(og.SessionIDKey)
|
|
if nil != sessionID {
|
|
s.sessions.Delete(sessionID.(string))
|
|
}
|
|
}
|
|
}
|
|
|
|
func (s *RPCServlets) Handle(servletCtx server.ServletCtx,
|
|
stopChan <-chan struct{}, doneChan chan<- struct{},
|
|
readChan <-chan []byte, writeChan chan<- []byte) {
|
|
defer func() {
|
|
doneChan <- struct{}{}
|
|
}()
|
|
|
|
var (
|
|
md metadata.MD
|
|
src crp.ServerRequestCodec
|
|
method string
|
|
params []string
|
|
grpcCtx context.Context
|
|
grpcReply string
|
|
replyBuff []byte
|
|
err error
|
|
)
|
|
|
|
_clientType := servletCtx.GetAttribute(og.SessionClientTypeKey)
|
|
_sessionID := servletCtx.GetAttribute(og.SessionIDKey)
|
|
_targetID := servletCtx.GetAttribute(og.SessionTargetIDKey)
|
|
|
|
if nil != _clientType && nil != _sessionID && nil != _targetID {
|
|
md = metadata.Pairs(
|
|
oe.GRPCClientTypeKey.String(), _clientType.(oe.ClientType).String(),
|
|
oe.GRPCSessionIDKey.String(), _sessionID.(string),
|
|
oe.GRPCTargetIDKey.String(), _targetID.(string),
|
|
)
|
|
}
|
|
|
|
servletCtx.SetAttribute(og.SessionWriteChanKey, writeChan)
|
|
|
|
for {
|
|
select {
|
|
case msg, ok := <-readChan:
|
|
if !ok {
|
|
return
|
|
}
|
|
// grpc exec method call
|
|
src, err = s.RPCServerCodec.NewRequest(msg)
|
|
if nil != err {
|
|
logging.Logger().Error(err)
|
|
break
|
|
}
|
|
|
|
method = src.Method()
|
|
params, err = src.Params()
|
|
if nil != err {
|
|
logging.Logger().Error(err)
|
|
s.writeError(src, writeChan, crp.E_BAD_PARAMS, "", err)
|
|
break
|
|
}
|
|
|
|
grpcCtx = metadata.NewOutgoingContext(context.Background(), md)
|
|
grpcReply, err = oeg.Exec(grpcCtx, method, params...)
|
|
if nil != err {
|
|
logging.Logger().Error(err)
|
|
}
|
|
|
|
if !src.HasResponse() {
|
|
break
|
|
}
|
|
replyBuff, err = src.NewResponseWithString(grpcReply, err)
|
|
if nil != err {
|
|
logging.Logger().Error(err)
|
|
s.writeError(src, writeChan, crp.E_INTERNAL, "", err)
|
|
break
|
|
}
|
|
|
|
writeChan <- replyBuff
|
|
case <-stopChan:
|
|
return
|
|
}
|
|
|
|
}
|
|
|
|
}
|
|
|
|
func (s *RPCServlets) writeError(src crp.ServerRequestCodec, writeChan chan<- []byte, code crp.ErrorCode, message string, data interface{}) {
|
|
if !src.HasResponse() {
|
|
return
|
|
}
|
|
|
|
pErr := &crp.Error{
|
|
Code: code,
|
|
Message: message,
|
|
Data: data,
|
|
}
|
|
|
|
buf, err := src.NewResponse(nil, pErr)
|
|
if nil != err {
|
|
logging.Logger().Error(err)
|
|
return
|
|
}
|
|
writeChan <- buf
|
|
}
|
|
|
|
func (s *RPCServlets) GetSessions(sessionIDs []string) []*Session {
|
|
var sessions []*Session
|
|
|
|
if nil == sessionIDs || 0 == len(sessionIDs) {
|
|
return sessions
|
|
}
|
|
|
|
for _, sessionID := range sessionIDs {
|
|
session, ok := s.sessions.Load(sessionID)
|
|
if ok {
|
|
sessions = append(sessions, session.(*Session))
|
|
}
|
|
}
|
|
|
|
return sessions
|
|
}
|
|
|
|
func (s *RPCServlets) GetSessionsByTargetIDs(targetIDs []string) []*Session {
|
|
var sessions []*Session
|
|
if nil == targetIDs || 0 == len(targetIDs) {
|
|
return sessions
|
|
}
|
|
|
|
s.sessions.Range(func(k, v interface{}) bool {
|
|
session := v.(*Session)
|
|
|
|
for _, targetID := range targetIDs {
|
|
if session.TargetID == targetID {
|
|
sessions = append(sessions, session)
|
|
break
|
|
}
|
|
}
|
|
|
|
return true
|
|
})
|
|
|
|
return sessions
|
|
}
|