diff --git a/server.go b/server.go index 888b9ae..74b4f94 100644 --- a/server.go +++ b/server.go @@ -3,6 +3,7 @@ package overflow_gateway_websocket import ( "log" "net/http" + "sync" "git.loafle.net/overflow/overflow_gateway_websocket/websocket" "github.com/valyala/fasthttp" @@ -20,6 +21,7 @@ type server struct { _upgrader *websocket.Upgrader _handlers map[string]*ClientOptions _clients map[string]Client + _cMtx sync.Mutex } func NewServer(o *ServerOptions) Server { @@ -53,25 +55,26 @@ func (s *server) onDisconnected(c Client) { func (s *server) onConnection(ctx *fasthttp.RequestCtx) { path := string(ctx.Path()) + co, ok := s._handlers[path] + if !ok { + ctx.Response.Header.Set("Sec-Websocket-Version", "13") + ctx.Error(http.StatusText(fasthttp.StatusNotFound), fasthttp.StatusNotFound) + + log.Printf("Path[%s] is not exist.", path) + return + } s._upgrader.Upgrade(ctx, nil, func(conn *websocket.Conn, err error) { if err != nil { log.Print("upgrade:", err) return } - co, ok := s._handlers[path] - if !ok { - ctx.Response.Header.Set("Sec-Websocket-Version", "13") - ctx.Error(http.StatusText(fasthttp.StatusNotFound), fasthttp.StatusNotFound) - - log.Printf("Path[%s] is not exist.", path) - conn.Close() - return - } - + s._cMtx.Lock() cid := s._option.IDGenerator(ctx) c := NewClient(cid, path, co, conn) s._clients[cid] = c + s._cMtx.Unlock() + s._option.OnConnection(path, c) }) }