This commit is contained in:
crusader 2018-04-10 22:17:46 +09:00
parent 8f775f2bad
commit ab5e366a2c

View File

@ -2,12 +2,14 @@ package servlet
import ( import (
"fmt" "fmt"
"reflect"
crp "git.loafle.net/commons/rpc-go/protocol" cdr "git.loafle.net/commons/di-go/registry"
crr "git.loafle.net/commons/rpc-go/registry" crr "git.loafle.net/commons/rest-go/registry"
"git.loafle.net/commons/server-go" "git.loafle.net/commons/server-go"
csw "git.loafle.net/commons/server-go/web" csw "git.loafle.net/commons/server-go/web"
cswf "git.loafle.net/commons/server-go/web/fasthttp" cswf "git.loafle.net/commons/server-go/web/fasthttp"
oca "git.loafle.net/overflow/commons-go/annotation"
"github.com/valyala/fasthttp" "github.com/valyala/fasthttp"
) )
@ -18,43 +20,55 @@ type MethodMapping struct {
type RESTServlet interface { type RESTServlet interface {
cswf.Servlet cswf.Servlet
RegisterRESTServices(services []interface{}) error
} }
type RESTServlets struct { type RESTServlets struct {
cswf.Servlets cswf.Servlets
ServerCodec crp.ServerCodec RESTRegistry crr.RESTRegistry
RPCInvoker crr.RPCInvoker
MethodMapping map[string]MethodMapping MethodMapping map[string]map[string]*MethodMapping
} }
func (s *RESTServlets) Handle(servletCtx server.ServletCtx, ctx *fasthttp.RequestCtx) *csw.Error { func (s *RESTServlets) Init(serverCtx server.ServerCtx) error {
method := string(ctx.Method()) if err := s.Servlets.Init(serverCtx); nil != err {
return err
switch method { }
case "GET": if nil == s.MethodMapping {
return s.HandleGet(servletCtx, ctx) s.MethodMapping = make(map[string]map[string]*MethodMapping)
case "POST":
return s.HandlePost(servletCtx, ctx)
} }
return nil return nil
} }
func (s *RESTServlets) HandleGet(servletCtx server.ServletCtx, ctx *fasthttp.RequestCtx) *csw.Error { func (s *RESTServlets) Handle(servletCtx server.ServletCtx, ctx *fasthttp.RequestCtx) *csw.Error {
method := string(ctx.Method())
path := string(ctx.Path()) path := string(ctx.Path())
if nil == s.MethodMapping {
return csw.NewError(fasthttp.StatusNotFound, fmt.Errorf("Not Found for %s", path)) es, ok := s.MethodMapping[method]
}
mapping, ok := s.MethodMapping[path]
if !ok { if !ok {
return csw.NewError(fasthttp.StatusNotFound, fmt.Errorf("Not Found for %s", path)) return csw.NewError(fasthttp.StatusNotFound, fmt.Errorf("Not Found for [%s]%s", method, path))
}
mapping, ok := es[path]
if !ok {
return csw.NewError(fasthttp.StatusNotFound, fmt.Errorf("Not Found for [%s]%s", method, path))
} }
method := mapping.Method switch method {
case "GET":
return s.HandleGet(servletCtx, ctx, mapping)
case "POST":
return s.HandlePost(servletCtx, ctx, mapping)
}
return nil
}
func (s *RESTServlets) HandleGet(servletCtx server.ServletCtx, ctx *fasthttp.RequestCtx, mapping *MethodMapping) *csw.Error {
params := make([]string, 0) params := make([]string, 0)
if nil != mapping.ParamKeys && 0 < len(mapping.ParamKeys) { if nil != mapping.ParamKeys {
qargs := ctx.QueryArgs() qargs := ctx.QueryArgs()
if nil == qargs { if nil == qargs {
return csw.NewError(fasthttp.StatusBadRequest, fmt.Errorf("Parameter is not valied")) return csw.NewError(fasthttp.StatusBadRequest, fmt.Errorf("Parameter is not valied"))
@ -69,50 +83,81 @@ func (s *RESTServlets) HandleGet(servletCtx server.ServletCtx, ctx *fasthttp.Req
} }
} }
reqCodec, err := s.ServerCodec.NewRequestWithString(method, params, nil) _, err := s.RESTRegistry.Invoke(mapping.Method, params, servletCtx, ctx)
if nil != err {
return csw.NewError(fasthttp.StatusBadRequest, err)
}
reply, err := s.RPCInvoker.Invoke(reqCodec, servletCtx, ctx)
if nil == reply && nil == err {
return nil
}
buf, err := reqCodec.NewResponseWithString(reply.(string), err)
if nil != err { if nil != err {
return csw.NewError(fasthttp.StatusInternalServerError, err) return csw.NewError(fasthttp.StatusInternalServerError, err)
} }
ctx.SetBody(buf)
return nil return nil
} }
func (s *RESTServlets) HandlePost(servletCtx server.ServletCtx, ctx *fasthttp.RequestCtx) *csw.Error { func (s *RESTServlets) HandlePost(servletCtx server.ServletCtx, ctx *fasthttp.RequestCtx, mapping *MethodMapping) *csw.Error {
buf := ctx.PostBody() params := make([]string, 0)
if nil == buf { if nil != mapping.ParamKeys {
return csw.NewError(fasthttp.StatusBadRequest, fmt.Errorf("Parameter is not valied")) pargs := ctx.PostArgs()
if nil == pargs {
return csw.NewError(fasthttp.StatusBadRequest, fmt.Errorf("Parameter is not valied"))
}
for _, k := range mapping.ParamKeys {
buf := pargs.Peek(k)
if nil == buf {
return csw.NewError(fasthttp.StatusBadRequest, fmt.Errorf("Parameter for %s is not valied", k))
}
params = append(params, string(buf))
}
} }
reqCodec, err := s.ServerCodec.NewRequest(buf) _, err := s.RESTRegistry.Invoke(mapping.Method, params, servletCtx, ctx)
if nil != err {
return csw.NewError(fasthttp.StatusBadRequest, err)
}
reply, err := s.RPCInvoker.Invoke(reqCodec, servletCtx, ctx)
if nil == reply && nil == err {
return nil
}
buf, err = reqCodec.NewResponseWithString(reply.(string), err)
if nil != err { if nil != err {
return csw.NewError(fasthttp.StatusInternalServerError, err) return csw.NewError(fasthttp.StatusInternalServerError, err)
} }
ctx.SetBody(buf) return nil
}
func (s *RESTServlets) RegisterRESTServices(services []interface{}) error {
if nil == services || 0 == len(services) {
return nil
}
s.MethodMapping = make(map[string]map[string]*MethodMapping)
s.RESTRegistry = crr.NewRESTRegistry()
LOOP:
for _, service := range services {
t := reflect.TypeOf(service)
ta := cdr.GetTypeAnnotation(t, oca.RESTServiceAnnotationType)
if nil == ta {
return fmt.Errorf("Service[%s] is not RESTService, use @RESTService", t.Elem().Name())
}
s.RESTRegistry.RegisterService(service, "")
mas := cdr.GetMethodAnnotations(t, oca.RequestMappingAnnotationType)
if nil == mas || 0 == len(mas) {
continue LOOP
}
for methodName, v := range mas {
ma := v.(*oca.RequestMappingAnnotation)
mm, ok := s.MethodMapping[ma.Method]
if !ok {
mm = make(map[string]*MethodMapping)
s.MethodMapping[ma.Method] = mm
}
_, ok = mm[ma.Entry]
if ok {
return fmt.Errorf("Mapping of method[%s], entry[%s] is exist already", ma.Method, ma.Entry)
}
mm[ma.Entry] = &MethodMapping{
Method: fmt.Sprintf("%s.%s", t.Elem().Name(), methodName),
ParamKeys: ma.Params,
}
}
}
return nil return nil
} }