diff --git a/web/fasthttp/server-handler.go b/web/fasthttp/server-handler.go index d1abaed..cd3bb5d 100644 --- a/web/fasthttp/server-handler.go +++ b/web/fasthttp/server-handler.go @@ -1,6 +1,10 @@ package fasthttp import ( + "fmt" + "strings" + + logging "git.loafle.net/commons/logging-go" "git.loafle.net/commons/server-go" "git.loafle.net/commons/server-go/web" @@ -21,6 +25,10 @@ type ServerHandler interface { type ServerHandlers struct { web.ServerHandlers + NotFoundServelt Servlet + + // path = context only. + // ex) /auth => /auth, /auth/member => /auth servlets map[string]Servlet } @@ -53,23 +61,27 @@ func (sh *ServerHandlers) Destroy(serverCtx server.ServerCtx) { func (sh *ServerHandlers) OnError(serverCtx server.ServerCtx, ctx *fasthttp.RequestCtx, status int, reason error) { } -func (sh *ServerHandlers) RegisterServlet(path string, servlet Servlet) { +func (sh *ServerHandlers) RegisterServlet(contextPath string, servlet Servlet) { if nil == sh.servlets { sh.servlets = make(map[string]Servlet) } - sh.servlets[path] = servlet + sh.servlets[contextPath] = servlet } func (sh *ServerHandlers) Servlet(serverCtx server.ServerCtx, ctx *fasthttp.RequestCtx) Servlet { - path := string(ctx.Path()) + contextPath, err := getContextPath(string(ctx.Path())) + if nil != err { + logging.Logger().Warnf("%v", err) + return sh.NotFoundServelt + } var servlet Servlet - if path == "" && len(sh.servlets) == 1 { + if contextPath == "" && len(sh.servlets) == 1 { for _, s := range sh.servlets { servlet = s } - } else if servlet = sh.servlets[path]; nil == servlet { - return nil + } else if servlet = sh.servlets[contextPath]; nil == servlet { + return sh.NotFoundServelt } return servlet @@ -84,5 +96,29 @@ func (sh *ServerHandlers) Validate() error { return err } + if nil == sh.NotFoundServelt { + return fmt.Errorf("NotFoundServelt must to set") + } + return nil } + +func getContextPath(path string) (string, error) { + p := strings.TrimSpace(path) + + if !strings.HasPrefix(p, "/") { + return "", fmt.Errorf("The path[%s] must started /", path) + } + + if strings.HasSuffix(p, "/") { + cpl := len(p) - 1 + p = p[:cpl] + } + + components := strings.Split(p, "/") + if 0 == len(components) { + return "", fmt.Errorf("The path[%s] is not invalid", path) + } + + return fmt.Sprintf("/%s", components[0]), nil +}