package websocket import ( "fmt" "log" "net/http" "sync" "github.com/gorilla/websocket" ) type ( OnConnectionFunc func(websocket.Conn) ) // Server is the websocket server, // listens on the config's port, the critical part is the event OnConnection type Server interface { Set(...OptionSetter) Handler() http.Handler HandleConnection(*http.Request, *websocket.Conn) OnConnection(cb OnConnectionFunc) IsConnected(clientID string) bool GetClient(clientID string) *Client Disconnect(clientID string) error } type server struct { options *Options clients map[string]*client clientMTX sync.Mutex onConnectionListeners []OnConnectionFunc } var _ Server = &server{} var defaultServer = newServer() // server implementation // New creates a websocket server and returns it func New(setters ...OptionSetter) Server { return newServer(setters...) } // newServer creates a websocket server and returns it func newServer(setters ...OptionSetter) *server { s := &server{ clients: make(map[string]*client, 100), onConnectionListeners: make([]OnConnectionFunc, 0), } s.Set(setters...) return s } func (s *server) Set(setters ...OptionSetter) { for _, setter := range setters { setter.Set(s.options) } s.options.Validate() } func (s *server) Handler() http.Handler { o := s.options upgrader := websocket.Upgrader{ ReadBufferSize: o.ReadBufferSize, WriteBufferSize: o.WriteBufferSize, Error: o.Error, CheckOrigin: o.CheckOrigin} return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { conn, err := upgrader.Upgrade(w, r, w.Header()) if err != nil { http.Error(w, "Websocket Error: "+err.Error(), http.StatusServiceUnavailable) return } s.HandleConnection(r, conn) }) } func (s *server) HandleConnection(r *http.Request, conn *websocket.Conn) { clientID := s.options.IDGenerator(r) c := newClient(s, r, conn, clientID) err := s.addClient(clientID, c) if nil != err { log.Println(fmt.Errorf("%v", err)) return } } func (s *server) OnConnection(cb OnConnectionFunc) { s.onConnectionListeners = append(s.onConnectionListeners, cb) } func (s *server) IsConnected(clientID string) bool { c := s.clients[clientID] return c != nil } func (s *server) GetClient(clientID string) *Client { return s.clients[clientID] } func (s *server) Disconnect(clientID string) error { c := s.clients[clientID] if nil == c { return nil } return nil } func (s *server) addClient(clientID string, c *client) error { s.clientMTX.Lock() if s.clients[clientID] != nil { return fmt.Errorf("Client[%s] is exist already", clientID) } s.clients[clientID] = c s.clientMTX.Unlock() return nil }