package web import ( "net/http" "sync/atomic" "git.loafle.net/overflow/server-go" "git.loafle.net/overflow/server-go/socket" "github.com/valyala/fasthttp" ) type ServerHandler interface { socket.ServerHandler OnError(serverCtx server.ServerCtx, ctx *fasthttp.RequestCtx, status int, reason error) RegisterServlet(path string, servlet Servlet) Servlet(serverCtx server.ServerCtx, ctx *fasthttp.RequestCtx) Servlet CheckOrigin(ctx *fasthttp.RequestCtx) bool } type ServerHandlers struct { socket.ServerHandlers servlets map[string]Servlet validated atomic.Value } func (sh *ServerHandlers) Init(serverCtx server.ServerCtx) error { if err := sh.ServerHandlers.Init(serverCtx); nil != err { return err } if nil != sh.servlets { for _, servlet := range sh.servlets { if err := servlet.Init(serverCtx); nil != err { return err } } } return nil } func (sh *ServerHandlers) OnStart(serverCtx server.ServerCtx) error { if err := sh.ServerHandlers.OnStart(serverCtx); nil != err { return err } if nil != sh.servlets { for _, servlet := range sh.servlets { if err := servlet.OnStart(serverCtx); nil != err { return err } } } return nil } func (sh *ServerHandlers) OnStop(serverCtx server.ServerCtx) { if nil != sh.servlets { for _, servlet := range sh.servlets { servlet.OnStop(serverCtx) } } sh.ServerHandlers.OnStop(serverCtx) } func (sh *ServerHandlers) Destroy(serverCtx server.ServerCtx) { if nil != sh.servlets { for _, servlet := range sh.servlets { servlet.Destroy(serverCtx) } } sh.ServerHandlers.Destroy(serverCtx) } func (sh *ServerHandlers) OnError(serverCtx server.ServerCtx, ctx *fasthttp.RequestCtx, status int, reason error) { ctx.Response.Header.Set("Sec-Websocket-Version", "13") ctx.Error(http.StatusText(status), status) } func (sh *ServerHandlers) RegisterServlet(path string, servlet Servlet) { if nil == sh.servlets { sh.servlets = make(map[string]Servlet) } sh.servlets[path] = servlet } func (sh *ServerHandlers) Servlet(serverCtx server.ServerCtx, ctx *fasthttp.RequestCtx) Servlet { path := string(ctx.Path()) var servlet Servlet if path == "" && len(sh.servlets) == 1 { for _, s := range sh.servlets { servlet = s } } else if servlet = sh.servlets[path]; nil == servlet { return nil } return servlet } func (sh *ServerHandlers) CheckOrigin(ctx *fasthttp.RequestCtx) bool { return true } func (sh *ServerHandlers) Validate() error { if nil != sh.validated.Load() { return nil } sh.validated.Store(true) if err := sh.ServerHandlers.Validate(); nil != err { return err } return nil }