package servlet import ( "fmt" "reflect" cdr "git.loafle.net/commons/di-go/registry" crr "git.loafle.net/commons/rest-go/registry" "git.loafle.net/commons/server-go" csw "git.loafle.net/commons/server-go/web" cswf "git.loafle.net/commons/server-go/web/fasthttp" oca "git.loafle.net/overflow/commons-go/annotation" "github.com/valyala/fasthttp" ) type MethodMapping struct { Method string ParamKeys []string } type RESTServlet interface { cswf.Servlet RegisterRESTServices(services []interface{}) error } type RESTServlets struct { cswf.Servlets RESTRegistry crr.RESTRegistry MethodMapping map[string]map[string]*MethodMapping } func (s *RESTServlets) Init(serverCtx server.ServerCtx) error { if err := s.Servlets.Init(serverCtx); nil != err { return err } if nil == s.MethodMapping { s.MethodMapping = make(map[string]map[string]*MethodMapping) } return nil } func (s *RESTServlets) Handle(servletCtx server.ServletCtx, ctx *fasthttp.RequestCtx) *csw.Error { method := string(ctx.Method()) path := string(ctx.Path()) es, ok := s.MethodMapping[method] if !ok { return csw.NewError(fasthttp.StatusNotFound, fmt.Errorf("Not Found for [%s]%s", method, path)) } mapping, ok := es[path] if !ok { return csw.NewError(fasthttp.StatusNotFound, fmt.Errorf("Not Found for [%s]%s", method, path)) } switch method { case "GET": return s.HandleGet(servletCtx, ctx, mapping) case "POST": return s.HandlePost(servletCtx, ctx, mapping) } return nil } func (s *RESTServlets) HandleGet(servletCtx server.ServletCtx, ctx *fasthttp.RequestCtx, mapping *MethodMapping) *csw.Error { params := make([]string, 0) if nil != mapping.ParamKeys { qargs := ctx.QueryArgs() if nil == qargs { return csw.NewError(fasthttp.StatusBadRequest, fmt.Errorf("Parameter is not valied")) } for _, k := range mapping.ParamKeys { buf := qargs.Peek(k) if nil == buf { return csw.NewError(fasthttp.StatusBadRequest, fmt.Errorf("Parameter for %s is not valied", k)) } params = append(params, string(buf)) } } _, err := s.RESTRegistry.Invoke(mapping.Method, params, servletCtx, ctx) if nil != err { return csw.NewError(fasthttp.StatusInternalServerError, err) } return nil } func (s *RESTServlets) HandlePost(servletCtx server.ServletCtx, ctx *fasthttp.RequestCtx, mapping *MethodMapping) *csw.Error { params := make([]string, 0) if nil != mapping.ParamKeys { pargs := ctx.PostArgs() if nil == pargs { return csw.NewError(fasthttp.StatusBadRequest, fmt.Errorf("Parameter is not valied")) } for _, k := range mapping.ParamKeys { buf := pargs.Peek(k) if nil == buf { return csw.NewError(fasthttp.StatusBadRequest, fmt.Errorf("Parameter for %s is not valied", k)) } params = append(params, string(buf)) } } _, err := s.RESTRegistry.Invoke(mapping.Method, params, servletCtx, ctx) if nil != err { return csw.NewError(fasthttp.StatusInternalServerError, err) } return nil } func (s *RESTServlets) RegisterRESTServices(services []interface{}) error { if nil == services || 0 == len(services) { return nil } s.MethodMapping = make(map[string]map[string]*MethodMapping) s.RESTRegistry = crr.NewRESTRegistry() LOOP: for _, service := range services { t := reflect.TypeOf(service) ta := cdr.GetTypeAnnotation(t, oca.RESTServiceAnnotationType) if nil == ta { return fmt.Errorf("Service[%s] is not RESTService, use @RESTService", t.Elem().Name()) } s.RESTRegistry.RegisterService(service, "") mas := cdr.GetMethodAnnotations(t, oca.RequestMappingAnnotationType) if nil == mas || 0 == len(mas) { continue LOOP } for methodName, v := range mas { ma := v.(*oca.RequestMappingAnnotation) mm, ok := s.MethodMapping[ma.Method] if !ok { mm = make(map[string]*MethodMapping) s.MethodMapping[ma.Method] = mm } _, ok = mm[ma.Entry] if ok { return fmt.Errorf("Mapping of method[%s], entry[%s] is exist already", ma.Method, ma.Entry) } mm[ma.Entry] = &MethodMapping{ Method: fmt.Sprintf("%s.%s", t.Elem().Name(), methodName), ParamKeys: ma.Params, } } } return nil }