package servlet

import (
	"encoding/json"
	"fmt"
	"reflect"

	cdr "git.loafle.net/commons/di-go/registry"
	crr "git.loafle.net/commons/rest-go/registry"
	"git.loafle.net/commons/server-go"
	csw "git.loafle.net/commons/server-go/web"
	cswf "git.loafle.net/commons/server-go/web/fasthttp"
	occa "git.loafle.net/overflow/commons-go/core/annotation"
	"github.com/valyala/fasthttp"
)

// MethodMapping describes how one (HTTP method, request path) route is
// dispatched to a service method held in the REST registry.
type MethodMapping struct {
	// Method is the registry target, built as "TypeName.MethodName".
	Method string
	// ParamKeys lists, in invocation order, the request parameter names whose
	// values are passed to Method: query arguments for GET, top-level JSON
	// object keys of the body for POST.
	ParamKeys []string
}

// RESTServlet is a fasthttp servlet onto which annotated REST services can be
// registered.
type RESTServlet interface {
	cswf.Servlet

	// RegisterRESTServices registers the given @RESTService instances and
	// their @RequestMapping routes.
	RegisterRESTServices(services []interface{}) error
}

// RESTServlets routes requests, by HTTP method and request path, to service
// methods registered through RegisterRESTServices. It is intended to satisfy
// RESTServlet.
type RESTServlets struct {
	cswf.Servlets

	restRegistry crr.RESTRegistry
	// methodMapping maps HTTP method -> request path (the annotation's Entry)
	// -> dispatch target.
	methodMapping map[string]map[string]*MethodMapping
}

// Init initializes the embedded Servlets and allocates the route table if no
// services have been registered yet.
func (s *RESTServlets) Init(serverCtx server.ServerCtx) error {
	if err := s.Servlets.Init(serverCtx); nil != err {
		return err
	}
	if nil == s.methodMapping {
		s.methodMapping = make(map[string]map[string]*MethodMapping)
	}
	return nil
}

// Handle dispatches a request to the service method mapped to its HTTP method
// and request path. It returns a 404 error when no mapping exists. It returns
// a 501 error when a mapping exists for a method this servlet cannot dispatch;
// only GET and POST are supported.
func (s *RESTServlets) Handle(servletCtx server.ServletCtx, ctx *fasthttp.RequestCtx) *csw.Error {
	method := string(ctx.Method())
	path := string(ctx.Path())
	requestPath := s.RequestPath(ctx)

	// Indexing a missing (nil) inner map is safe in Go and yields ok == false.
	mapping, ok := s.methodMapping[method][requestPath]
	if !ok {
		return csw.NewError(fasthttp.StatusNotFound, fmt.Errorf("Not Found for [%s]%s", method, path))
	}

	switch method {
	case "GET":
		return s.HandleGet(servletCtx, ctx, mapping)
	case "POST":
		return s.HandlePost(servletCtx, ctx, mapping)
	default:
		// @RequestMapping may name any method, but only GET and POST know how
		// to extract parameters. Report it instead of returning success
		// without invoking the service.
		return csw.NewError(fasthttp.StatusNotImplemented, fmt.Errorf("Not Implemented for [%s]%s", method, path))
	}
}

// HandleGet invokes mapping.Method with the string values of the query
// arguments named by mapping.ParamKeys, in that order. A missing argument
// yields a 400 error; an invocation error yields a 500 error.
//
// The value returned by Invoke is discarded; presumably the service writes
// the response through ctx (TODO confirm against crr.RESTRegistry).
func (s *RESTServlets) HandleGet(servletCtx server.ServletCtx, ctx *fasthttp.RequestCtx, mapping *MethodMapping) *csw.Error {
	params := make([]string, 0, len(mapping.ParamKeys))
	if 0 < len(mapping.ParamKeys) {
		qargs := ctx.QueryArgs()
		if nil == qargs {
			return csw.NewError(fasthttp.StatusBadRequest, fmt.Errorf("Parameter is not valid"))
		}
		for _, k := range mapping.ParamKeys {
			// Has distinguishes "absent" from "present but empty" (?k=).
			// Peek alone may return nil for an empty value, which would
			// wrongly reject a legitimately empty argument.
			if !qargs.Has(k) {
				return csw.NewError(fasthttp.StatusBadRequest, fmt.Errorf("Parameter for %s is not valid", k))
			}
			params = append(params, string(qargs.Peek(k)))
		}
	}

	if _, err := s.restRegistry.Invoke(mapping.Method, params, servletCtx, ctx); nil != err {
		return csw.NewError(fasthttp.StatusInternalServerError, err)
	}
	return nil
}

// HandlePost decodes the request body as a JSON object and invokes
// mapping.Method with the raw JSON values of the keys named by
// mapping.ParamKeys, in that order. An empty or undecodable body, or a
// missing key, yields a 400 error; an invocation error yields a 500 error.
// The body is only read when ParamKeys is non-empty.
func (s *RESTServlets) HandlePost(servletCtx server.ServletCtx, ctx *fasthttp.RequestCtx, mapping *MethodMapping) *csw.Error {
	params := make([][]byte, 0, len(mapping.ParamKeys))
	if 0 < len(mapping.ParamKeys) {
		buf := ctx.PostBody()
		if 0 == len(buf) {
			return csw.NewError(fasthttp.StatusBadRequest, fmt.Errorf("Parameter is not valid"))
		}

		var jsonMap map[string]json.RawMessage
		if err := json.Unmarshal(buf, &jsonMap); nil != err {
			return csw.NewError(fasthttp.StatusBadRequest, fmt.Errorf("Parameter is not valid %v", err))
		}

		for _, k := range mapping.ParamKeys {
			v, ok := jsonMap[k]
			if !ok {
				return csw.NewError(fasthttp.StatusBadRequest, fmt.Errorf("Parameter for %s is not valid", k))
			}
			params = append(params, v)
		}
	}

	if _, err := s.restRegistry.InvokeWithBytes(mapping.Method, params, servletCtx, ctx); nil != err {
		return csw.NewError(fasthttp.StatusInternalServerError, err)
	}
	return nil
}

// RegisterRESTServices registers each service in a fresh REST registry and
// builds the route table from the service's @RequestMapping method
// annotations.
//
// Each service must be a pointer whose type carries @RESTService. A duplicate
// (method, entry) route is an error. A successful call replaces any
// previously registered services and routes. When an error is returned, the
// servlet's existing registrations are left untouched. Calling with no
// services is a no-op.
func (s *RESTServlets) RegisterRESTServices(services []interface{}) error {
	if 0 == len(services) {
		return nil
	}

	// Build into locals and swap in only on success, so a failing service
	// cannot leave the servlet half-registered.
	methodMapping := make(map[string]map[string]*MethodMapping)
	restRegistry := crr.NewRESTRegistry()

	for _, service := range services {
		t := reflect.TypeOf(service)
		// t.Elem() below panics for non-pointer kinds, and TypeOf(nil) is nil.
		if nil == t || reflect.Ptr != t.Kind() {
			return fmt.Errorf("Service[%T] is not a pointer", service)
		}
		typeName := t.Elem().Name()

		if nil == cdr.GetTypeAnnotation(t, occa.RESTServiceAnnotationType) {
			return fmt.Errorf("Service[%s] is not RESTService, use @RESTService", typeName)
		}
		// NOTE(review): any result of RegisterService is discarded here, as
		// before; confirm whether it reports errors that should be returned.
		restRegistry.RegisterService(service, "")

		for methodName, v := range cdr.GetMethodAnnotations(t, occa.RequestMappingAnnotationType) {
			ma, ok := v.(*occa.RequestMappingAnnotation)
			if !ok {
				return fmt.Errorf("Service[%s] method[%s] has unexpected annotation %T", typeName, methodName, v)
			}

			mm, ok := methodMapping[ma.Method]
			if !ok {
				mm = make(map[string]*MethodMapping)
				methodMapping[ma.Method] = mm
			}
			if _, exists := mm[ma.Entry]; exists {
				return fmt.Errorf("Mapping of method[%s], entry[%s] is exist already", ma.Method, ma.Entry)
			}
			mm[ma.Entry] = &MethodMapping{
				Method:    fmt.Sprintf("%s.%s", typeName, methodName),
				ParamKeys: ma.Params,
			}
		}
	}

	s.methodMapping = methodMapping
	s.restRegistry = restRegistry
	return nil
}