package websocket import ( "fmt" "log" "net/http" "sync" "github.com/gorilla/websocket" ) type ( // OnConnectionFunc is callback function that used when client is connected OnConnectionFunc func(Client) ) // Server is the websocket server, // listens on the config's port, the critical part is the event OnConnection type Server interface { Set(...OptionSetter) Handler() http.Handler HandleConnection(*http.Request, Connection) OnConnection(cb OnConnectionFunc) IsConnected(clientID string) bool GetSocket(clientID string) Client Disconnect(clientID string) error } type server struct { options *Options clients map[string]Client clientMTX sync.Mutex onConnectionListeners []OnConnectionFunc } var _ Server = &server{} var defaultServer = newServer() // server implementation // New creates a websocket server and returns it func New(setters ...OptionSetter) Server { return newServer(setters...) } // newServer creates a websocket server and returns it func newServer(setters ...OptionSetter) Server { s := &server{ clients: make(map[string]Client, 100), onConnectionListeners: make([]OnConnectionFunc, 0), } s.Set(setters...) return s } // Set is function that set option values func Set(setters ...OptionSetter) { defaultServer.Set(setters...) } func (s *server) Set(setters ...OptionSetter) { for _, setter := range setters { setter.Set(s.options) } s.options.Validate() } // Handler is the function that used on http request func Handler() http.Handler { return defaultServer.Handler() } func (s *server) Handler() http.Handler { o := s.options upgrader := websocket.Upgrader{ ReadBufferSize: o.ReadBufferSize, WriteBufferSize: o.WriteBufferSize, Error: o.Error, CheckOrigin: o.CheckOrigin, } return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { conn, err := upgrader.Upgrade(w, r, w.Header()) if err != nil { http.Error(w, "Websocket Error: "+err.Error(), http.StatusServiceUnavailable) return } s.HandleConnection(r, conn) }) } func HandleConnection(r *http.Request, conn Connection) { defaultServer.HandleConnection(r, conn) } func (s *server) HandleConnection(r *http.Request, conn Connection) { clientID := s.options.IDGenerator(r) c := newClient(s, r, conn, clientID) err := s.addClient(clientID, c) if nil != err { log.Println(fmt.Errorf("%v", err)) return } for i := range s.onConnectionListeners { s.onConnectionListeners[i](c) } } // OnConnection is function that add the callback when client is connected to default Server func OnConnection(cb OnConnectionFunc) { defaultServer.OnConnection(cb) } func (s *server) OnConnection(cb OnConnectionFunc) { s.onConnectionListeners = append(s.onConnectionListeners, cb) } // IsConnected is function that check client is connect func IsConnected(clientID string) bool { return defaultServer.IsConnected(clientID) } func (s *server) IsConnected(clientID string) bool { soc := s.clients[clientID] return soc != nil } // GetSocket is function that return client instance func GetSocket(clientID string) Client { return defaultServer.GetSocket(clientID) } func (s *server) GetSocket(clientID string) Client { return s.clients[clientID] } // Disconnect is function that disconnect a client func Disconnect(clientID string) error { return defaultServer.Disconnect(clientID) } func (s *server) Disconnect(clientID string) error { c := s.clients[clientID] if nil == c { return nil } delete(s.clients, clientID) return c.destroy() } func (s *server) addClient(clientID string, c Client) error { s.clientMTX.Lock() if s.clients[clientID] != nil { return fmt.Errorf("ID of Client[%s] is exist already", clientID) } s.clients[clientID] = c s.clientMTX.Unlock() return nil }