This commit is contained in:
crusader 2018-04-06 12:17:55 +09:00
parent 851891b8b0
commit 8422254849

View File

@ -1,6 +1,10 @@
package fasthttp package fasthttp
import ( import (
"fmt"
"strings"
logging "git.loafle.net/commons/logging-go"
"git.loafle.net/commons/server-go" "git.loafle.net/commons/server-go"
"git.loafle.net/commons/server-go/web" "git.loafle.net/commons/server-go/web"
@ -21,6 +25,10 @@ type ServerHandler interface {
type ServerHandlers struct { type ServerHandlers struct {
web.ServerHandlers web.ServerHandlers
NotFoundServelt Servlet
// path = context only.
// ex) /auth => /auth, /auth/member => /auth
servlets map[string]Servlet servlets map[string]Servlet
} }
@ -53,23 +61,27 @@ func (sh *ServerHandlers) Destroy(serverCtx server.ServerCtx) {
func (sh *ServerHandlers) OnError(serverCtx server.ServerCtx, ctx *fasthttp.RequestCtx, status int, reason error) { func (sh *ServerHandlers) OnError(serverCtx server.ServerCtx, ctx *fasthttp.RequestCtx, status int, reason error) {
} }
func (sh *ServerHandlers) RegisterServlet(path string, servlet Servlet) { func (sh *ServerHandlers) RegisterServlet(contextPath string, servlet Servlet) {
if nil == sh.servlets { if nil == sh.servlets {
sh.servlets = make(map[string]Servlet) sh.servlets = make(map[string]Servlet)
} }
sh.servlets[path] = servlet sh.servlets[contextPath] = servlet
} }
func (sh *ServerHandlers) Servlet(serverCtx server.ServerCtx, ctx *fasthttp.RequestCtx) Servlet { func (sh *ServerHandlers) Servlet(serverCtx server.ServerCtx, ctx *fasthttp.RequestCtx) Servlet {
path := string(ctx.Path()) contextPath, err := getContextPath(string(ctx.Path()))
if nil != err {
logging.Logger().Warnf("%v", err)
return sh.NotFoundServelt
}
var servlet Servlet var servlet Servlet
if path == "" && len(sh.servlets) == 1 { if contextPath == "" && len(sh.servlets) == 1 {
for _, s := range sh.servlets { for _, s := range sh.servlets {
servlet = s servlet = s
} }
} else if servlet = sh.servlets[path]; nil == servlet { } else if servlet = sh.servlets[contextPath]; nil == servlet {
return nil return sh.NotFoundServelt
} }
return servlet return servlet
@ -84,5 +96,29 @@ func (sh *ServerHandlers) Validate() error {
return err return err
} }
if nil == sh.NotFoundServelt {
return fmt.Errorf("NotFoundServelt must to set")
}
return nil return nil
} }
func getContextPath(path string) (string, error) {
p := strings.TrimSpace(path)
if !strings.HasPrefix(p, "/") {
return "", fmt.Errorf("The path[%s] must started /", path)
}
if strings.HasSuffix(p, "/") {
cpl := len(p) - 1
p = p[:cpl]
}
components := strings.Split(p, "/")
if 0 == len(components) {
return "", fmt.Errorf("The path[%s] is not invalid", path)
}
return fmt.Sprintf("/%s", components[0]), nil
}