package fasthttp

import (
	"fmt"
	"strings"
	"sync/atomic"

	logging "git.loafle.net/commons/logging-go"
	"git.loafle.net/commons/server-go"
	"git.loafle.net/commons/server-go/web"
	"github.com/valyala/fasthttp"
)

// ServerHandler extends web.ServerHandler with the hooks needed to dispatch
// fasthttp requests to servlets.
type ServerHandler interface {
	web.ServerHandler

	// OnError reports err for the request carried by ctx.
	OnError(serverCtx server.ServerCtx, ctx *fasthttp.RequestCtx, err *web.Error)

	// RegisterServlet associates servlet with the given path.
	RegisterServlet(path string, servlet Servlet)
	// Servlet returns the servlet that should handle the request carried by
	// ctx.
	Servlet(serverCtx server.ServerCtx, ctx *fasthttp.RequestCtx) Servlet

	// CheckOrigin reports whether the request carried by ctx is acceptable.
	CheckOrigin(ctx *fasthttp.RequestCtx) bool
}

// ServerHandlers is a ServerHandler building block that embeds
// web.ServerHandlers and routes requests to servlets by the first segment of
// the request path.
type ServerHandlers struct {
	web.ServerHandlers

	// ErrorServelt, when set, handles every error passed to OnError.
	// NOTE(review): the name looks like a typo of "ErrorServlet"; it is kept
	// because it is part of the exported API.
	ErrorServelt Servlet `json:"-"`

	// servlets maps a context path to its servlet. The key is the context
	// (first path segment) only.
	// ex) /auth => /auth, /auth/member => /auth
	servlets map[string]Servlet

	// validated is non-nil once Validate has completed successfully.
	validated atomic.Value
}

// Init initializes the embedded handlers and then every registered servlet.
// It stops at, and returns, the first error encountered.
func (sh *ServerHandlers) Init(serverCtx server.ServerCtx) error {
	if err := sh.ServerHandlers.Init(serverCtx); err != nil {
		return err
	}

	// Ranging over a nil map is a no-op, so no nil guard is needed.
	for _, servlet := range sh.servlets {
		if err := servlet.Init(serverCtx); err != nil {
			return err
		}
	}

	return nil
}

// OnStart starts the embedded handlers and then every registered servlet.
// It stops at, and returns, the first error encountered.
func (sh *ServerHandlers) OnStart(serverCtx server.ServerCtx) error {
	if err := sh.ServerHandlers.OnStart(serverCtx); err != nil {
		return err
	}

	for _, servlet := range sh.servlets {
		if err := servlet.OnStart(serverCtx); err != nil {
			return err
		}
	}

	return nil
}

// OnStop stops every registered servlet and then the embedded handlers,
// i.e. in the reverse order of OnStart.
func (sh *ServerHandlers) OnStop(serverCtx server.ServerCtx) {
	for _, servlet := range sh.servlets {
		servlet.OnStop(serverCtx)
	}

	sh.ServerHandlers.OnStop(serverCtx)
}

// Destroy destroys every registered servlet and then the embedded handlers,
// i.e. in the reverse order of Init.
func (sh *ServerHandlers) Destroy(serverCtx server.ServerCtx) {
	for _, servlet := range sh.servlets {
		servlet.Destroy(serverCtx)
	}

	sh.ServerHandlers.Destroy(serverCtx)
}

// OnError delegates to ErrorServelt when one is configured, exposing err to it
// through the web.ErrorKey attribute of a fresh servlet context. Otherwise it
// writes the cause's message with the error's status code to the response.
func (sh *ServerHandlers) OnError(serverCtx server.ServerCtx, ctx *fasthttp.RequestCtx, err *web.Error) {
	if sh.ErrorServelt != nil {
		servletCtx := sh.ErrorServelt.ServletCtx(serverCtx)
		servletCtx.SetAttribute(web.ErrorKey, err)
		sh.ErrorServelt.Handle(servletCtx, ctx)
		return
	}

	ctx.Error(err.Cause.Error(), err.Code)
}

// RegisterServlet records servlet under contextPath and tells the servlet
// which context path it was registered with. A later registration for the
// same contextPath replaces the earlier one.
func (sh *ServerHandlers) RegisterServlet(contextPath string, servlet Servlet) {
	// Writing to a nil map panics, so the map is created lazily.
	if sh.servlets == nil {
		sh.servlets = make(map[string]Servlet)
	}

	servlet.setContextPath(contextPath)
	sh.servlets[contextPath] = servlet
}

// Servlet returns the servlet registered for the context path (first path
// segment) of the request. It logs a warning and returns nil when the path is
// malformed or no servlet is registered for it.
func (sh *ServerHandlers) Servlet(serverCtx server.ServerCtx, ctx *fasthttp.RequestCtx) Servlet {
	path := string(ctx.Path())

	contextPath, err := getContextPath(path)
	if err != nil {
		logging.Logger().Warnf("Bad Request %v", err)
		return nil
	}

	// Reading from a nil map is safe and yields the zero value (nil).
	servlet := sh.servlets[contextPath]
	if servlet == nil {
		logging.Logger().Warnf("Servlet is not exist for url[%s]", path)
		return nil
	}

	return servlet
}

// CheckOrigin always returns true.
func (sh *ServerHandlers) CheckOrigin(ctx *fasthttp.RequestCtx) bool {
	return true
}

// Validate validates the embedded handlers. Once a call has succeeded,
// subsequent calls return nil without validating again.
//
// The validated flag is stored only after validation succeeds, so a failed
// validation is reported again on the next call instead of being remembered as
// a success. There is no locking: calls that overlap before the flag is stored
// may each run the embedded validation.
func (sh *ServerHandlers) Validate() error {
	if sh.validated.Load() != nil {
		return nil
	}

	if err := sh.ServerHandlers.Validate(); err != nil {
		return err
	}

	sh.validated.Store(true)

	return nil
}

// getContextPath returns the first segment of path with a leading slash,
// e.g. "/auth" for "/auth", "/auth/" and "/auth/member", and "/" for "/".
// Surrounding whitespace is ignored. It returns an error when the trimmed
// path does not start with "/".
func getContextPath(path string) (string, error) {
	p := strings.TrimSpace(path)
	if !strings.HasPrefix(p, "/") {
		return "", fmt.Errorf("path[%s] must started /", path)
	}

	// Keep only what lies between the leading slash and the next slash. A
	// trailing slash never affects the first segment, so it needs no special
	// handling.
	segment := p[1:]
	if i := strings.IndexByte(segment, '/'); i >= 0 {
		segment = segment[:i]
	}

	return "/" + segment, nil
}