diff --git a/servlet/rest-servlet.go b/servlet/rest-servlet.go index 7ff34e4..63796c9 100644 --- a/servlet/rest-servlet.go +++ b/servlet/rest-servlet.go @@ -2,12 +2,14 @@ package servlet import ( "fmt" + "reflect" - crp "git.loafle.net/commons/rpc-go/protocol" - crr "git.loafle.net/commons/rpc-go/registry" + cdr "git.loafle.net/commons/di-go/registry" + crr "git.loafle.net/commons/rest-go/registry" "git.loafle.net/commons/server-go" csw "git.loafle.net/commons/server-go/web" cswf "git.loafle.net/commons/server-go/web/fasthttp" + oca "git.loafle.net/overflow/commons-go/annotation" "github.com/valyala/fasthttp" ) @@ -18,43 +20,55 @@ type MethodMapping struct { type RESTServlet interface { cswf.Servlet + + RegisterRESTServices(services []interface{}) error } type RESTServlets struct { cswf.Servlets - ServerCodec crp.ServerCodec - RPCInvoker crr.RPCInvoker + RESTRegistry crr.RESTRegistry - MethodMapping map[string]MethodMapping + MethodMapping map[string]map[string]*MethodMapping } -func (s *RESTServlets) Handle(servletCtx server.ServletCtx, ctx *fasthttp.RequestCtx) *csw.Error { - method := string(ctx.Method()) - - switch method { - case "GET": - return s.HandleGet(servletCtx, ctx) - case "POST": - return s.HandlePost(servletCtx, ctx) +func (s *RESTServlets) Init(serverCtx server.ServerCtx) error { + if err := s.Servlets.Init(serverCtx); nil != err { + return err + } + if nil == s.MethodMapping { + s.MethodMapping = make(map[string]map[string]*MethodMapping) } return nil } -func (s *RESTServlets) HandleGet(servletCtx server.ServletCtx, ctx *fasthttp.RequestCtx) *csw.Error { +func (s *RESTServlets) Handle(servletCtx server.ServletCtx, ctx *fasthttp.RequestCtx) *csw.Error { + method := string(ctx.Method()) path := string(ctx.Path()) - if nil == s.MethodMapping { - return csw.NewError(fasthttp.StatusNotFound, fmt.Errorf("Not Found for %s", path)) - } - mapping, ok := s.MethodMapping[path] + + es, ok := s.MethodMapping[method] if !ok { - return csw.NewError(fasthttp.StatusNotFound, fmt.Errorf("Not Found for %s", path)) + return csw.NewError(fasthttp.StatusNotFound, fmt.Errorf("Not Found for [%s]%s", method, path)) + } + mapping, ok := es[path] + if !ok { + return csw.NewError(fasthttp.StatusNotFound, fmt.Errorf("Not Found for [%s]%s", method, path)) } - method := mapping.Method + switch method { + case "GET": + return s.HandleGet(servletCtx, ctx, mapping) + case "POST": + return s.HandlePost(servletCtx, ctx, mapping) + } + + return nil +} + +func (s *RESTServlets) HandleGet(servletCtx server.ServletCtx, ctx *fasthttp.RequestCtx, mapping *MethodMapping) *csw.Error { params := make([]string, 0) - if nil != mapping.ParamKeys && 0 < len(mapping.ParamKeys) { + if nil != mapping.ParamKeys { qargs := ctx.QueryArgs() if nil == qargs { return csw.NewError(fasthttp.StatusBadRequest, fmt.Errorf("Parameter is not valied")) @@ -69,50 +83,81 @@ func (s *RESTServlets) HandleGet(servletCtx server.ServletCtx, ctx *fasthttp.Req } } - reqCodec, err := s.ServerCodec.NewRequestWithString(method, params, nil) - if nil != err { - return csw.NewError(fasthttp.StatusBadRequest, err) - } + _, err := s.RESTRegistry.Invoke(mapping.Method, params, servletCtx, ctx) - reply, err := s.RPCInvoker.Invoke(reqCodec, servletCtx, ctx) - - if nil == reply && nil == err { - return nil - } - - buf, err := reqCodec.NewResponseWithString(reply.(string), err) if nil != err { return csw.NewError(fasthttp.StatusInternalServerError, err) } - ctx.SetBody(buf) - return nil } -func (s *RESTServlets) HandlePost(servletCtx server.ServletCtx, ctx *fasthttp.RequestCtx) *csw.Error { - buf := ctx.PostBody() - if nil == buf { - return csw.NewError(fasthttp.StatusBadRequest, fmt.Errorf("Parameter is not valied")) +func (s *RESTServlets) HandlePost(servletCtx server.ServletCtx, ctx *fasthttp.RequestCtx, mapping *MethodMapping) *csw.Error { + params := make([]string, 0) + if nil != mapping.ParamKeys { + pargs := ctx.PostArgs() + if nil == pargs { + return csw.NewError(fasthttp.StatusBadRequest, fmt.Errorf("Parameter is not valied")) + } + + for _, k := range mapping.ParamKeys { + buf := pargs.Peek(k) + if nil == buf { + return csw.NewError(fasthttp.StatusBadRequest, fmt.Errorf("Parameter for %s is not valied", k)) + } + params = append(params, string(buf)) + } } - reqCodec, err := s.ServerCodec.NewRequest(buf) - if nil != err { - return csw.NewError(fasthttp.StatusBadRequest, err) - } + _, err := s.RESTRegistry.Invoke(mapping.Method, params, servletCtx, ctx) - reply, err := s.RPCInvoker.Invoke(reqCodec, servletCtx, ctx) - - if nil == reply && nil == err { - return nil - } - - buf, err = reqCodec.NewResponseWithString(reply.(string), err) if nil != err { return csw.NewError(fasthttp.StatusInternalServerError, err) } - ctx.SetBody(buf) - + return nil +} + +func (s *RESTServlets) RegisterRESTServices(services []interface{}) error { + if nil == services || 0 == len(services) { + return nil + } + + s.MethodMapping = make(map[string]map[string]*MethodMapping) + s.RESTRegistry = crr.NewRESTRegistry() + +LOOP: + for _, service := range services { + t := reflect.TypeOf(service) + ta := cdr.GetTypeAnnotation(t, oca.RESTServiceAnnotationType) + if nil == ta { + return fmt.Errorf("Service[%s] is not RESTService, use @RESTService", t.Elem().Name()) + } + s.RESTRegistry.RegisterService(service, "") + + mas := cdr.GetMethodAnnotations(t, oca.RequestMappingAnnotationType) + if nil == mas || 0 == len(mas) { + continue LOOP + } + + for methodName, v := range mas { + ma := v.(*oca.RequestMappingAnnotation) + mm, ok := s.MethodMapping[ma.Method] + if !ok { + mm = make(map[string]*MethodMapping) + s.MethodMapping[ma.Method] = mm + } + + _, ok = mm[ma.Entry] + if ok { + return fmt.Errorf("Mapping of method[%s], entry[%s] is exist already", ma.Method, ma.Entry) + } + + mm[ma.Entry] = &MethodMapping{ + Method: fmt.Sprintf("%s.%s", t.Elem().Name(), methodName), + ParamKeys: ma.Params, + } + } + } return nil }