163 lines
3.4 KiB
Go
163 lines
3.4 KiB
Go
package fasthttp
|
|
|
|
import (
|
|
"fmt"
|
|
"strings"
|
|
"sync/atomic"
|
|
|
|
olog "git.loafle.net/overflow/log-go"
|
|
"git.loafle.net/overflow/server-go"
|
|
"git.loafle.net/overflow/server-go/web"
|
|
|
|
"github.com/valyala/fasthttp"
|
|
)
|
|
|
|
type ServerHandler interface {
|
|
web.ServerHandler
|
|
|
|
OnError(serverCtx server.ServerCtx, ctx *fasthttp.RequestCtx, err *web.Error)
|
|
|
|
RegisterServlet(path string, servlet Servlet)
|
|
Servlet(serverCtx server.ServerCtx, ctx *fasthttp.RequestCtx) Servlet
|
|
|
|
CheckOrigin(ctx *fasthttp.RequestCtx) bool
|
|
}
|
|
|
|
type ServerHandlers struct {
|
|
web.ServerHandlers
|
|
|
|
ErrorServelt Servlet `json:"-"`
|
|
|
|
// path = context only.
|
|
// ex) /auth => /auth, /auth/member => /auth
|
|
servlets map[string]Servlet
|
|
|
|
validated atomic.Value
|
|
}
|
|
|
|
func (sh *ServerHandlers) Init(serverCtx server.ServerCtx) error {
|
|
if err := sh.ServerHandlers.Init(serverCtx); nil != err {
|
|
return err
|
|
}
|
|
|
|
if nil != sh.servlets {
|
|
for _, servlet := range sh.servlets {
|
|
if err := servlet.Init(serverCtx); nil != err {
|
|
return err
|
|
}
|
|
}
|
|
}
|
|
|
|
return nil
|
|
}
|
|
|
|
func (sh *ServerHandlers) OnStart(serverCtx server.ServerCtx) error {
|
|
if err := sh.ServerHandlers.OnStart(serverCtx); nil != err {
|
|
return err
|
|
}
|
|
|
|
if nil != sh.servlets {
|
|
for _, servlet := range sh.servlets {
|
|
if err := servlet.OnStart(serverCtx); nil != err {
|
|
return err
|
|
}
|
|
}
|
|
}
|
|
|
|
return nil
|
|
}
|
|
|
|
func (sh *ServerHandlers) OnStop(serverCtx server.ServerCtx) {
|
|
if nil != sh.servlets {
|
|
for _, servlet := range sh.servlets {
|
|
servlet.OnStop(serverCtx)
|
|
}
|
|
}
|
|
|
|
sh.ServerHandlers.OnStop(serverCtx)
|
|
}
|
|
|
|
func (sh *ServerHandlers) Destroy(serverCtx server.ServerCtx) {
|
|
if nil != sh.servlets {
|
|
for _, servlet := range sh.servlets {
|
|
servlet.Destroy(serverCtx)
|
|
}
|
|
}
|
|
|
|
sh.ServerHandlers.Destroy(serverCtx)
|
|
}
|
|
|
|
func (sh *ServerHandlers) OnError(serverCtx server.ServerCtx, ctx *fasthttp.RequestCtx, err *web.Error) {
|
|
if nil != sh.ErrorServelt {
|
|
servletCtx := sh.ErrorServelt.ServletCtx(serverCtx)
|
|
servletCtx.SetAttribute(web.ErrorKey, err)
|
|
sh.ErrorServelt.Handle(servletCtx, ctx)
|
|
return
|
|
}
|
|
|
|
ctx.Error(err.Cause.Error(), err.Code)
|
|
}
|
|
|
|
func (sh *ServerHandlers) RegisterServlet(contextPath string, servlet Servlet) {
|
|
if nil == sh.servlets {
|
|
sh.servlets = make(map[string]Servlet)
|
|
}
|
|
servlet.setContextPath(contextPath)
|
|
sh.servlets[contextPath] = servlet
|
|
}
|
|
|
|
func (sh *ServerHandlers) Servlet(serverCtx server.ServerCtx, ctx *fasthttp.RequestCtx) Servlet {
|
|
path := string(ctx.Path())
|
|
contextPath, err := getContextPath(path)
|
|
if nil != err {
|
|
olog.Logger().Warnf("Bad Request %v", err)
|
|
return nil
|
|
}
|
|
|
|
var servlet Servlet
|
|
if servlet = sh.servlets[contextPath]; nil == servlet {
|
|
olog.Logger().Warnf("Servlet is not exist for url[%s]", path)
|
|
return nil
|
|
}
|
|
|
|
return servlet
|
|
}
|
|
|
|
func (sh *ServerHandlers) CheckOrigin(ctx *fasthttp.RequestCtx) bool {
|
|
return true
|
|
}
|
|
|
|
func (sh *ServerHandlers) Validate() error {
|
|
if nil != sh.validated.Load() {
|
|
return nil
|
|
}
|
|
sh.validated.Store(true)
|
|
|
|
if err := sh.ServerHandlers.Validate(); nil != err {
|
|
return err
|
|
}
|
|
|
|
return nil
|
|
}
|
|
|
|
// getContextPath returns the context path of a request path: a leading "/"
// followed by the first path segment. Surrounding whitespace is ignored.
//
//	"/auth"        => "/auth"
//	"/auth/member" => "/auth"
//	"/auth/"       => "/auth"
//	"/"            => "/"
//
// It returns an error when the trimmed path does not start with "/".
func getContextPath(path string) (string, error) {
	p := strings.TrimSpace(path)

	if !strings.HasPrefix(p, "/") {
		return "", fmt.Errorf("path[%s] must start with /", path)
	}

	// Everything after the leading slash, cut at the next slash (if any).
	// Trailing slashes need no special handling since only the first
	// segment is kept.
	segment := p[1:]
	if i := strings.IndexByte(segment, '/'); i >= 0 {
		segment = segment[:i]
	}

	return "/" + segment, nil
}
|