diff --git a/constants.go b/constants.go index a42c25e..1b90101 100644 --- a/constants.go +++ b/constants.go @@ -9,9 +9,11 @@ import ( var ( ServletSocketKey = cuc.ContextKey("ServletSocket") - ClientTypeKey = cuc.ContextKey("ClientType") - SocketIDKey = cuc.ContextKey("SocketID") - TargetIDKey = cuc.ContextKey("TargetID") + + GRPCMetadataKey = cuc.ContextKey("GRPCMetadata") + ClientTypeKey = cuc.ContextKey("ClientType") + SocketIDKey = cuc.ContextKey("SocketID") + TargetIDKey = cuc.ContextKey("TargetID") ) const ( diff --git a/internal/server/rpc/gateway_rpc_servlet_handlers.go b/internal/server/rpc/gateway_rpc_servlet_handlers.go index 9f42e19..fb86dae 100644 --- a/internal/server/rpc/gateway_rpc_servlet_handlers.go +++ b/internal/server/rpc/gateway_rpc_servlet_handlers.go @@ -24,11 +24,17 @@ type GatewayRPCServletHandlers struct { } func (sh *GatewayRPCServletHandlers) Invoke(servletCTX rpc.ServletContext, requestCodec protocol.RegistryCodec) (result interface{}, err error) { - md := metadata.Pairs( - oogw.GRPCClientTypeKey, servletCTX.GetAttribute(oogw.SocketIDKey).(string), - oogw.GRPCSessionIDKey, servletCTX.GetAttribute(oogw.SocketIDKey).(string), - oogw.GRPCTargetIDKey, servletCTX.GetAttribute(oogw.SocketIDKey).(string)) - grpcCTX := metadata.NewOutgoingContext(context.Background(), md) + md := servletCTX.GetAttribute(oogw.GRPCMetadataKey) + if nil == md { + md = metadata.Pairs( + oogw.GRPCClientTypeKey, servletCTX.GetAttribute(oogw.SocketIDKey).(string), + oogw.GRPCSessionIDKey, servletCTX.GetAttribute(oogw.SocketIDKey).(string), + oogw.GRPCTargetIDKey, servletCTX.GetAttribute(oogw.SocketIDKey).(string)) + + servletCTX.SetAttribute(oogw.GRPCMetadataKey, md) + } + + grpcCTX := metadata.NewOutgoingContext(context.Background(), md.(metadata.MD)) params, err := requestCodec.Params() if nil != err {