package server

import (
	"fmt"
	"log"
	"net/http"
	"sync"

	serverGrpc "git.loafle.net/overflow/overflow_api_server/golang"
	"git.loafle.net/overflow/overflow_service_websocket/config"
	"git.loafle.net/overflow/overflow_service_websocket/pool"
	grpcPool "git.loafle.net/overflow/overflow_service_websocket/pool/grpc"
	"github.com/gorilla/websocket"
	"google.golang.org/grpc"
)

type (
	// OnConnectionFunc is a callback invoked with each newly registered
	// Client, before the client's initialize step runs.
	OnConnectionFunc func(Client)
)

// Server is the websocket server. It upgrades HTTP requests to websocket
// connections, tracks the resulting clients by ID, and notifies registered
// OnConnection listeners about each new client.
type Server interface {
	// GetOptions returns the options the server was created with.
	GetOptions() *Options
	// GetGRPCPool returns the pool of gRPC API-server clients.
	GetGRPCPool() pool.Pool
	// HTTPHandler returns an http.Handler that upgrades requests to
	// websocket connections and registers them via HandleConnection.
	HTTPHandler() http.Handler
	// HandleConnection registers conn, accepted for request r, as a new client.
	HandleConnection(*http.Request, Connection)
	// OnConnection registers cb to be called for every new client.
	OnConnection(cb OnConnectionFunc)
	// IsConnected reports whether a client with clientID is registered.
	IsConnected(clientID string) bool
	// GetClient returns the client registered under clientID, or nil.
	GetClient(clientID string) Client
	// Disconnect unregisters and destroys the client registered under
	// clientID. It returns nil if no such client is registered.
	Disconnect(clientID string) error
}

type server struct {
	options *Options

	// clients maps client IDs to registered clients; every access must hold
	// clientMTX because connections are handled concurrently.
	clients   map[string]Client
	clientMTX sync.Mutex

	// onConnectionListeners is appended to by OnConnection and read by
	// HandleConnection without locking, so listeners should be registered
	// before the server starts accepting connections.
	onConnectionListeners []OnConnectionFunc

	grpcPool pool.Pool
}

// server implementation

// New creates a websocket server configured by o and returns it.
// A nil o is replaced by a zero-value Options before validation.
func New(o *Options) Server {
	return newServer(o)
}

// newServer builds the server, validates its options and creates the gRPC
// client pool. It terminates the process via log.Fatal if the pool cannot
// be created.
func newServer(o *Options) Server {
	s := &server{
		clients:               make(map[string]Client, 100),
		onConnectionListeners: make([]OnConnectionFunc, 0),
	}

	if nil == o {
		o = &Options{}
	}
	o.Validate()
	s.options = o

	// NOTE(review): 1 and 5 look like initial/maximum pool sizes — confirm
	// against pool/grpc.New. Each pooled connection dials the API server at
	// the configured address and is wrapped in an OverflowApiServerClient.
	p, err := grpcPool.New(1, 5,
		func(conn *grpc.ClientConn) (interface{}, error) {
			return serverGrpc.NewOverflowApiServerClient(conn), nil
		},
		func() (*grpc.ClientConn, error) {
			return grpc.Dial(config.GetConfig().GRpc.Addr, grpc.WithInsecure())
		},
	)
	if nil != err {
		// log.Fatal exits the process; no return is needed.
		log.Fatal(err)
	}
	s.grpcPool = p

	return s
}

// GetOptions returns the validated options of the server.
func (s *server) GetOptions() *Options {
	return s.options
}

// GetGRPCPool returns the gRPC client pool created in newServer.
func (s *server) GetGRPCPool() pool.Pool {
	return s.grpcPool
}

// HTTPHandler returns a handler that upgrades each request to a websocket
// connection using the buffer sizes and callbacks from the server options,
// then hands the connection to HandleConnection.
func (s *server) HTTPHandler() http.Handler {
	o := s.options
	upgrader := websocket.Upgrader{
		ReadBufferSize:  o.ReadBufferSize,
		WriteBufferSize: o.WriteBufferSize,
		Error:           o.OnError,
		CheckOrigin:     o.OnCheckOrigin,
	}

	return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
		conn, err := upgrader.Upgrade(w, r, w.Header())
		if err != nil {
			// Upgrade has already replied to the client (through the
			// upgrader's Error callback or http.Error), or the connection
			// was hijacked; writing another response here would be a
			// superfluous WriteHeader.
			return
		}
		s.HandleConnection(r, conn)
	})
}

// HandleConnection assigns an ID to the connection via the configured
// IDGenerator, registers a new client for it, notifies all OnConnection
// listeners and finally initializes the client. If the ID is already taken,
// the error is logged and the connection is not registered.
func (s *server) HandleConnection(r *http.Request, conn Connection) {
	clientID := s.options.IDGenerator(r)
	c := newClient(s, r, conn, clientID)

	if err := s.addClient(clientID, c); nil != err {
		// NOTE(review): the rejected client/conn is not released here —
		// confirm whether conn should be closed on a duplicate ID.
		log.Println(err)
		return
	}

	for _, listener := range s.onConnectionListeners {
		listener(c)
	}

	c.initialize()
}

// OnConnection appends cb to the listeners called for each new client.
func (s *server) OnConnection(cb OnConnectionFunc) {
	s.onConnectionListeners = append(s.onConnectionListeners, cb)
}

// IsConnected reports whether a client is registered under clientID.
func (s *server) IsConnected(clientID string) bool {
	return s.GetClient(clientID) != nil
}

// GetClient returns the client registered under clientID, or nil if none is.
func (s *server) GetClient(clientID string) Client {
	s.clientMTX.Lock()
	defer s.clientMTX.Unlock()

	return s.clients[clientID]
}

// Disconnect removes the client registered under clientID and destroys it,
// returning destroy's error. It returns nil when no such client exists.
func (s *server) Disconnect(clientID string) error {
	s.clientMTX.Lock()
	c := s.clients[clientID]
	if nil != c {
		delete(s.clients, clientID)
	}
	s.clientMTX.Unlock()

	if nil == c {
		return nil
	}
	// destroy runs after clientMTX is released so it cannot deadlock if it
	// calls back into the server.
	return c.destroy()
}

// addClient registers c under clientID, failing if the ID is already in use.
func (s *server) addClient(clientID string, c Client) error {
	s.clientMTX.Lock()
	// defer guarantees the lock is released on the duplicate-ID path too.
	defer s.clientMTX.Unlock()

	if s.clients[clientID] != nil {
		return fmt.Errorf("ID of Client[%s] is exist already", clientID)
	}
	s.clients[clientID] = c

	return nil
}