// Package servlet implements web-socket servlets that decode RPC requests
// from connected clients, execute them through the overflow gRPC bridge
// (oeg.Exec) and write the encoded replies back over the socket.
package servlet

import (
	"context"
	"sync"

	logging "git.loafle.net/commons/logging-go"
	crp "git.loafle.net/commons/rpc-go/protocol"
	"git.loafle.net/commons/server-go"
	css "git.loafle.net/commons/server-go/socket"
	cssw "git.loafle.net/commons/server-go/socket/web"
	oe "git.loafle.net/overflow/external-go"
	oeg "git.loafle.net/overflow/external-go/grpc"
	og "git.loafle.net/overflow/gateway"
	"github.com/valyala/fasthttp"
	"google.golang.org/grpc/metadata"
)

// RPCServlet is a web-socket servlet that serves RPC traffic.
type RPCServlet interface {
	cssw.Servlet
}

// RPCServlets decodes RPC requests arriving on a web-socket connection with
// RPCServerCodec, executes them via oeg.Exec and writes encoded responses
// back to the connection. When UseSession is set, connected sessions are
// tracked so they can be looked up by session ID or by target ID.
type RPCServlets struct {
	cssw.Servlets

	// RPCServerCodec decodes incoming socket messages into requests and
	// encodes the corresponding responses.
	RPCServerCodec crp.ServerCodec
	// UseSession enables session registration in OnConnect/OnDisconnect.
	UseSession bool

	// sessions maps a session ID (string) to its *Session.
	sessions sync.Map
}

// Handshake returns no extra response headers and no error, so it never
// rejects an upgrade request on its own.
func (s *RPCServlets) Handshake(servletCtx server.ServletCtx, ctx *fasthttp.RequestCtx) (*fasthttp.ResponseHeader, error) {
	return nil, nil
}

// OnConnect delegates to the embedded Servlets and, when UseSession is set,
// registers a session for the connection under its session ID. Connections
// whose session ID or target ID attribute is missing (or not a string) are
// not registered.
func (s *RPCServlets) OnConnect(servletCtx server.ServletCtx, conn css.Conn) {
	s.Servlets.OnConnect(servletCtx, conn)

	if !s.UseSession {
		return
	}

	// Comma-ok assertions: a missing or mistyped attribute skips registration
	// instead of panicking.
	sessionID, ok := servletCtx.GetAttribute(og.SessionIDKey).(string)
	if !ok {
		return
	}
	targetID, ok := servletCtx.GetAttribute(og.SessionTargetIDKey).(string)
	if !ok {
		return
	}

	s.sessions.Store(sessionID, RetainSession(targetID, servletCtx))
}

// OnDisconnect delegates to the embedded Servlets and, when UseSession is
// set, forgets the session registered for the connection's session ID.
func (s *RPCServlets) OnDisconnect(servletCtx server.ServletCtx) {
	s.Servlets.OnDisconnect(servletCtx)

	if !s.UseSession {
		return
	}

	if sessionID, ok := servletCtx.GetAttribute(og.SessionIDKey).(string); ok {
		s.sessions.Delete(sessionID)
	}
}

// Handle runs the request loop for one connection until stopChan fires or
// readChan is closed, and always signals completion on doneChan when it
// returns.
//
// writeChan is published on servletCtx under og.SessionWriteChanKey before
// the loop starts. Each incoming message is served by serveMessage; any
// reply is sent on writeChan.
func (s *RPCServlets) Handle(servletCtx server.ServletCtx, stopChan <-chan struct{}, doneChan chan<- struct{}, readChan <-chan css.SocketMessage, writeChan chan<- css.SocketMessage) {
	defer func() {
		doneChan <- struct{}{}
	}()

	md := s.sessionMetadata(servletCtx)
	servletCtx.SetAttribute(og.SessionWriteChanKey, writeChan)

	for {
		select {
		case socketMessage, ok := <-readChan:
			if !ok {
				return
			}

			reply, hasReply := s.serveMessage(md, socketMessage)
			if !hasReply {
				continue
			}

			// Guard the send with stopChan: if the writer side has stopped
			// draining writeChan, a bare send would block forever and the
			// deferred doneChan signal would never fire.
			select {
			case writeChan <- reply:
			case <-stopChan:
				return
			}
		case <-stopChan:
			return
		}
	}
}

// sessionMetadata builds the outgoing gRPC metadata (client type, session ID,
// target ID) from the connection attributes. It returns nil unless all three
// attributes are present with the expected types.
func (s *RPCServlets) sessionMetadata(servletCtx server.ServletCtx) metadata.MD {
	clientType, ok := servletCtx.GetAttribute(og.SessionClientTypeKey).(oe.ClientType)
	if !ok {
		return nil
	}
	sessionID, ok := servletCtx.GetAttribute(og.SessionIDKey).(string)
	if !ok {
		return nil
	}
	targetID, ok := servletCtx.GetAttribute(og.SessionTargetIDKey).(string)
	if !ok {
		return nil
	}

	return metadata.Pairs(
		oe.GRPCClientTypeKey.String(), clientType.String(),
		oe.GRPCSessionIDKey.String(), sessionID,
		oe.GRPCTargetIDKey.String(), targetID,
	)
}

// serveMessage decodes one socket message, executes the RPC it carries with
// md attached as outgoing gRPC metadata, and returns the encoded reply.
// The boolean is false when nothing should be written back: the message
// could not be decoded, the request expects no response, or encoding the
// error reply itself failed.
func (s *RPCServlets) serveMessage(md metadata.MD, socketMessage css.SocketMessage) (css.SocketMessage, bool) {
	messageType, message := socketMessage()

	src, err := s.RPCServerCodec.NewRequest(messageType, message)
	if nil != err {
		// Without a decoded request there is nothing to address a reply to.
		logging.Logger().Error(err)
		return nil, false
	}

	method := src.Method()
	params, err := src.Params()
	if nil != err {
		logging.Logger().Error(err)
		// NOTE(review): err is passed as Data as-is; depending on the codec an
		// error value may not serialize meaningfully — consider err.Error().
		return s.errorMessage(src, crp.E_BAD_PARAMS, "", err)
	}

	grpcCtx := metadata.NewOutgoingContext(context.Background(), md)
	grpcReply, err := oeg.Exec(grpcCtx, method, params...)
	if nil != err {
		// Logged here; the error is also forwarded to the client through
		// NewResponseWithString below.
		logging.Logger().Error(err)
	}

	if !src.HasResponse() {
		return nil, false
	}

	replyType, replyBuff, err := src.NewResponseWithString(grpcReply, err)
	if nil != err {
		logging.Logger().Error(err)
		return s.errorMessage(src, crp.E_INTERNAL, "", err)
	}

	return css.MakeSocketMessage(replyType, replyBuff), true
}

// errorMessage encodes an RPC error response for src. The boolean is false
// when src expects no response or the error response could not be encoded
// (the encoding error is logged).
func (s *RPCServlets) errorMessage(src crp.ServerRequestCodec, code crp.ErrorCode, message string, data interface{}) (css.SocketMessage, bool) {
	if !src.HasResponse() {
		return nil, false
	}

	pErr := &crp.Error{
		Code:    code,
		Message: message,
		Data:    data,
	}

	messageType, buf, err := src.NewResponse(nil, pErr)
	if nil != err {
		logging.Logger().Error(err)
		return nil, false
	}

	return css.MakeSocketMessage(messageType, buf), true
}

// writeError encodes an RPC error response for src and sends it on
// writeChan. Nothing is sent when src expects no response or encoding fails.
// The send blocks until writeChan accepts the message.
func (s *RPCServlets) writeError(src crp.ServerRequestCodec, writeChan chan<- css.SocketMessage, code crp.ErrorCode, message string, data interface{}) {
	if msg, ok := s.errorMessage(src, code, message, data); ok {
		writeChan <- msg
	}
}

// GetSessions returns the registered sessions for the given session IDs,
// silently skipping IDs with no registered session. It returns nil when
// sessionIDs is empty or nothing matches.
func (s *RPCServlets) GetSessions(sessionIDs []string) []*Session {
	var sessions []*Session

	for _, sessionID := range sessionIDs {
		if session, ok := s.sessions.Load(sessionID); ok {
			sessions = append(sessions, session.(*Session))
		}
	}

	return sessions
}

// GetSessionsByTargetIDs returns every registered session whose TargetID is
// one of targetIDs. Each session appears at most once, in sync.Map.Range
// order. It returns nil when targetIDs is empty or nothing matches.
func (s *RPCServlets) GetSessionsByTargetIDs(targetIDs []string) []*Session {
	var sessions []*Session

	if 0 == len(targetIDs) {
		return sessions
	}

	// Set lookup replaces a nested scan over targetIDs for every session.
	wanted := make(map[string]struct{}, len(targetIDs))
	for _, targetID := range targetIDs {
		wanted[targetID] = struct{}{}
	}

	s.sessions.Range(func(_, v interface{}) bool {
		session := v.(*Session)
		if _, ok := wanted[session.TargetID]; ok {
			sessions = append(sessions, session)
		}
		return true
	})

	return sessions
}