package server import ( "context" "encoding/json" "fmt" "io" "log" "net/http" "strings" "sync" "time" serverGrpc "git.loafle.net/overflow/overflow_api_server/golang" "github.com/gorilla/websocket" ) type ClientStatus uint8 const ( CONNECTED ClientStatus = iota + 1 DISCONNECTED ) type ( // OnDisconnectFunc is callback function that used when client is disconnected OnDisconnectFunc func(Client) // OnErrorFunc is callback function that used when error occurred OnErrorFunc func(error) ) // Client is interface type Client interface { ID() string HTTPRequest() *http.Request Conn() Connection Disconnect() error OnDisconnect(OnDisconnectFunc) OnError(OnErrorFunc) initialize() error destroy() error } type client struct { id string status ClientStatus messageType int server Server httpRequest *http.Request conn Connection pingTicker *time.Ticker writeMTX sync.Mutex onDisconnectListeners []OnDisconnectFunc onErrorListeners []OnErrorFunc } var _ Client = &client{} func newClient(s Server, r *http.Request, conn Connection, clientID string) Client { c := &client{ id: clientID, status: CONNECTED, messageType: websocket.TextMessage, server: s, httpRequest: r, conn: conn, onDisconnectListeners: make([]OnDisconnectFunc, 0), onErrorListeners: make([]OnErrorFunc, 0), } if s.GetOptions().BinaryMessage { c.messageType = websocket.BinaryMessage } return c } func (c *client) ID() string { return c.id } func (c *client) HTTPRequest() *http.Request { return c.httpRequest } func (c *client) Conn() Connection { return c.conn } func (c *client) Disconnect() error { return c.server.Disconnect(c.ID()) } func (c *client) OnDisconnect(cb OnDisconnectFunc) { c.onDisconnectListeners = append(c.onDisconnectListeners, cb) } func (c *client) OnError(cb OnErrorFunc) { c.onErrorListeners = append(c.onErrorListeners, cb) } func (c *client) initialize() error { c.status = CONNECTED c.startPing() c.startReading() return nil } func (c *client) destroy() error { c.status = DISCONNECTED c.pingTicker.Stop() for _, cb := range c.onDisconnectListeners { cb(c) } return c.conn.Close() } func (c *client) startPing() { c.pingTicker = time.NewTicker(c.server.GetOptions().PingPeriod) go func() { for { <-c.pingTicker.C if err := c.conn.WriteControl(websocket.PingMessage, []byte{}, time.Now().Add(c.server.GetOptions().PingTimeout)); err != nil { log.Println("ping:", err) } } }() } func (c *client) startReading() { hasReadTimeout := c.server.GetOptions().ReadTimeout > 0 c.conn.SetReadLimit(c.server.GetOptions().MaxMessageSize) c.conn.SetPongHandler(func(message string) error { if hasReadTimeout { c.conn.SetReadDeadline(time.Now().Add(c.server.GetOptions().PongTimeout)) } return nil }) defer func() { c.Disconnect() }() for { if hasReadTimeout { c.conn.SetReadDeadline(time.Now().Add(c.server.GetOptions().ReadTimeout)) } // messageType, data, err := c.conn.ReadMessage() messageType, r, err := c.conn.NextReader() if err != nil { if websocket.IsUnexpectedCloseError(err, websocket.CloseGoingAway) { c.fireError(err) } break } else { c.onMessageReceived(messageType, r) } } } func (c *client) onMessageReceived(messageType int, r io.Reader) { req := new(Request) err := json.NewDecoder(r).Decode(req) if err != nil { log.Println(err) } parts := strings.Split(req.Method, ".") si := &serverGrpc.ServerInput{ Target: parts[0], Method: parts[1], Params: req.Params, } grpcPool, err := c.server.GetGRPCPool().Get() if nil != err { c.writeError(req, NewError(E_INTERNAL, err, nil)) } defer c.server.GetGRPCPool().Put(grpcPool) grpcClient := grpcPool.(serverGrpc.OverflowApiServerClient) out, err := grpcClient.Exec(context.Background(), si) if err != nil { c.writeError(req, NewError(E_SERVER, err, err)) } c.writeResult(req, out.Result) } func (c *client) fireError(err error) { for _, cb := range c.onErrorListeners { cb(err) } } func (c *client) writeResult(r *Request, result string) { c.writeMTX.Lock() if writeTimeout := c.server.GetOptions().WriteTimeout; writeTimeout > 0 { // set the write deadline based on the configuration err := c.conn.SetWriteDeadline(time.Now().Add(writeTimeout)) log.Println(fmt.Errorf("%v", err)) } res := &Response{ Protocol: r.Protocol, ID: r.ID, Result: &result, Error: nil, } jRes, err := json.Marshal(res) if nil != err { log.Println(fmt.Errorf("%v", err)) } err = c.conn.WriteMessage(c.messageType, jRes) c.writeMTX.Unlock() if nil != err { _ = c.Disconnect() } } func (c *client) writeError(r *Request, perr *Error) { c.writeMTX.Lock() if writeTimeout := c.server.GetOptions().WriteTimeout; writeTimeout > 0 { // set the write deadline based on the configuration err := c.conn.SetWriteDeadline(time.Now().Add(writeTimeout)) log.Println(fmt.Errorf("%v", err)) } res := &Response{ Protocol: r.Protocol, ID: r.ID, Result: nil, Error: perr, } jRes, err := json.Marshal(res) if nil != err { log.Println(fmt.Errorf("%v", err)) } err = c.conn.WriteMessage(c.messageType, jRes) c.writeMTX.Unlock() if nil != err { _ = c.Disconnect() } }