From 6e4ca0dd5c6caac613497dea1921b29245d79abd Mon Sep 17 00:00:00 2001 From: crusader Date: Wed, 12 Jul 2017 21:36:12 +0900 Subject: [PATCH] ing --- websocket/channel.go | 1 - websocket/client.go | 209 ++++++++++++++++++++++++++++++---------- websocket/connection.go | 34 +++++++ websocket/options.go | 22 ++++- websocket/server.go | 72 +++++++++++--- 5 files changed, 264 insertions(+), 74 deletions(-) delete mode 100644 websocket/channel.go create mode 100644 websocket/connection.go diff --git a/websocket/channel.go b/websocket/channel.go deleted file mode 100644 index 708bc8c..0000000 --- a/websocket/channel.go +++ /dev/null @@ -1 +0,0 @@ -package websocket diff --git a/websocket/client.go b/websocket/client.go index 8524bc7..867ee28 100644 --- a/websocket/client.go +++ b/websocket/client.go @@ -1,47 +1,81 @@ package websocket import ( - "io" + "fmt" + "log" + "net" "net/http" "sync" "time" - "github.com/gorilla/websocket" + gWebsocket "github.com/gorilla/websocket" ) +type ClientStatus uint8 + +const ( + CONNECTED ClientStatus = iota + 1 + DISCONNECTED +) + +type ( + // OnDisconnectFunc is callback function that used when client is disconnected + OnDisconnectFunc func(Client) + // OnErrorFunc is callback function that used when error occurred + OnErrorFunc func(string) + // OnMessageFunc is callback function that receives messages from client + OnMessageFunc func([]byte) + // OnFunc is callback function that particular event which fires when a message to this event received + OnFunc interface{} +) + +// Client is interface type Client interface { ID() string - RemoteAddr() string - UserAgent() string - SetWriteDeadline(t time.Time) error - SetReadDeadline(t time.Time) error - SetReadLimit(limit int64) - SetPongHandler(h func(appData string) error) - SetPingHandler(h func(appData string) error) - WriteControl(messageType int, data []byte, deadline time.Time) error - WriteMessage(messageType int, data []byte) error - ReadMessage() (messageType int, p []byte, err error) - NextWriter(messageType int) (io.WriteCloser, error) - IsClosed() bool - Close() error + HTTPRequest() *http.Request + Conn() Connection + Disconnect() error + OnMessage(OnMessageFunc) + OnError(OnErrorFunc) + OnDisconnect(OnDisconnectFunc) + On(string, OnFunc) + initialize() error + destroy() error } type client struct { - id string - server *server - httpRequest *http.Request - conn *websocket.Conn - writeMTX sync.Mutex + id string + status ClientStatus + messageType int + server *server + httpRequest *http.Request + conn Connection + pingTicker *time.Ticker + writeMTX sync.Mutex + onMessageListeners []OnMessageFunc + onErrorListeners []OnErrorFunc + onDisconnectListeners []OnDisconnectFunc + onListeners map[string][]OnFunc } var _ Client = &client{} -func newClient(s *server, r *http.Request, conn *websocket.Conn, clientID string) Client { +func newClient(s *server, r *http.Request, conn Connection, clientID string) Client { c := &client{ - id: clientID, - server: s, - httpRequest: r, - conn: conn, + id: clientID, + status: CONNECTED, + messageType: gWebsocket.TextMessage, + server: s, + httpRequest: r, + conn: conn, + onMessageListeners: make([]OnMessageFunc, 0), + onErrorListeners: make([]OnErrorFunc, 0), + onDisconnectListeners: make([]OnDisconnectFunc, 0), + onListeners: make(map[string][]OnFunc), + } + + if s.options.BinaryMessage { + c.messageType = gWebsocket.BinaryMessage } return c @@ -51,58 +85,127 @@ func (c *client) ID() string { return c.id } -func (c *client) RemoteAddr() string { - return c.httpRequest.RemoteAddr +func (c *client) HTTPRequest() *http.Request { + return c.httpRequest } -func (c *client) UserAgent() string { - return c.httpRequest.UserAgent() +func (c *client) Conn() Connection { + return c.conn } -func (c *client) SetWriteDeadline(t time.Time) error { - return c.conn.SetWriteDeadline(t) +func (c *client) Disconnect() error { + return c.server.Disconnect(c.ID()) } -func (c *client) ID() string { - return c.id +func (c *client) OnDisconnect(cb OnDisconnectFunc) { + c.onDisconnectListeners = append(c.onDisconnectListeners, cb) } -func (c *client) ID() string { - return c.id +func (c *client) OnError(cb OnErrorFunc) { + c.onErrorListeners = append(c.onErrorListeners, cb) } -func (c *client) ID() string { - return c.id +func (c *client) OnMessage(cb OnMessageFunc) { + c.onMessageListeners = append(c.onMessageListeners, cb) } -func (c *client) ID() string { - return c.id +func (c *client) On(event string, cb OnFunc) { + if c.onListeners[event] == nil { + c.onListeners[event] = make([]OnFunc, 0) + } + + c.onListeners[event] = append(c.onListeners[event], cb) } -func (c *client) ID() string { - return c.id +func (c *client) initialize() error { + c.startPingPong() } -func (c *client) ID() string { - return c.id +func (c *client) destroy() error { + c.stopPingPong() + c.status = DISCONNECTED + + for i := range c.onDisconnectListeners { + c.onDisconnectListeners[i](c) + } + + return c.conn.Close() } -func (c *client) ID() string { - return c.id +func (c *client) startPingPong() { + c.conn.SetPingHandler(func(message string) error { + err := c.conn.WriteControl(gWebsocket.PongMessage, []byte("pong"), time.Now().Add(c.server.options.PongTimeout)) + if err == gWebsocket.ErrCloseSent { + return nil + } else if e, ok := err.(net.Error); ok && e.Temporary() { + return nil + } + return err + }) + + c.pingTicker = time.NewTicker(c.server.options.PingPeriod) + go func() { + for { + <-c.pingTicker.C + err := c.conn.WriteControl(gWebsocket.PingMessage, []byte("ping"), time.Now().Add(c.server.options.PingTimeout)) + if err == gWebsocket.ErrCloseSent { + } else if e, ok := err.(net.Error); ok && e.Temporary() { + } + } + }() } -func (c *client) ID() string { - return c.id +func (c *client) stopPingPong() { + c.pingTicker.Stop() } -func (c *client) ID() string { - return c.id +func (c *client) startReading() { + hasReadTimeout := c.server.options.ReadTimeout > 0 + c.conn.SetReadLimit(c.server.options.MaxMessageSize) + c.conn.SetPongHandler(func(message string) error { + if hasReadTimeout { + c.conn.SetReadDeadline(time.Now().Add(c.server.options.ReadTimeout)) + } + + return nil + }) + + defer func() { + c.Disconnect() + }() + + for { + if hasReadTimeout { + c.conn.SetReadDeadline(time.Now().Add(c.server.options.ReadTimeout)) + } + messageType, data, err := c.conn.ReadMessage() + if err != nil { + if gWebsocket.IsUnexpectedCloseError(err, gWebsocket.CloseGoingAway) { + c.EmitError(err.Error()) + } + break + } else { + c.onMessageReceived(messageType, data) + } + + } } -func (c *client) ID() string { - return c.id +func (c *client) onMessageReceived(messageType int, data []byte) { } -func (c *client) ID() string { - return c.id +func (c *client) write(messageType int, data []byte) { + c.writeMTX.Lock() + if writeTimeout := c.server.options.WriteTimeout; writeTimeout > 0 { + // set the write deadline based on the configuration + err := c.conn.SetWriteDeadline(time.Now().Add(writeTimeout)) + log.Println(fmt.Errorf("%v", err)) + } + + err := c.conn.WriteMessage(messageType, data) + c.writeMTX.Unlock() + + if nil != err { + c.Disconnect() + } } diff --git a/websocket/connection.go b/websocket/connection.go new file mode 100644 index 0000000..c514de7 --- /dev/null +++ b/websocket/connection.go @@ -0,0 +1,34 @@ +package websocket + +import ( + "io" + "net" + "time" +) + +// Connection is wrapper of the websocket.Conn +type Connection interface { + Close() error + CloseHandler() func(code int, text string) error + EnableWriteCompression(enable bool) + LocalAddr() net.Addr + NextReader() (messageType int, r io.Reader, err error) + NextWriter(messageType int) (io.WriteCloser, error) + PingHandler() func(appData string) error + PongHandler() func(appData string) error + ReadJSON(v interface{}) error + ReadMessage() (messageType int, p []byte, err error) + RemoteAddr() net.Addr + SetCloseHandler(h func(code int, text string) error) + SetCompressionLevel(level int) error + SetPingHandler(h func(appData string) error) + SetPongHandler(h func(appData string) error) + SetReadDeadline(t time.Time) error + SetReadLimit(limit int64) + SetWriteDeadline(t time.Time) error + Subprotocol() string + UnderlyingConn() net.Conn + WriteControl(messageType int, data []byte, deadline time.Time) error + WriteJSON(v interface{}) error + WriteMessage(messageType int, data []byte) error +} diff --git a/websocket/options.go b/websocket/options.go index afef69d..2086c12 100644 --- a/websocket/options.go +++ b/websocket/options.go @@ -14,6 +14,8 @@ const ( DefaultReadTimeout = 0 // DefaultPongTimeout is default value of Pong Timeout DefaultPongTimeout = 60 * time.Second + // DefaultPingTimeout is default value of Ping Timeout + DefaultPingTimeout = 10 * time.Second // DefaultPingPeriod is default value of Ping Period DefaultPingPeriod = (DefaultPongTimeout * 9) / 10 // DefaultMaxMessageSize is default value of Max Message Size @@ -52,9 +54,10 @@ type Options struct { WriteTimeout time.Duration ReadTimeout time.Duration PongTimeout time.Duration + PingTimeout time.Duration PingPeriod time.Duration MaxMessageSize int64 - BinaryMessages bool + BinaryMessage bool ReadBufferSize int WriteBufferSize int IDGenerator func(*http.Request) string @@ -67,9 +70,10 @@ func (o *Options) Set(main *Options) { main.WriteTimeout = o.WriteTimeout main.ReadTimeout = o.ReadTimeout main.PongTimeout = o.PongTimeout + main.PingTimeout = o.PingTimeout main.PingPeriod = o.PingPeriod main.MaxMessageSize = o.MaxMessageSize - main.BinaryMessages = o.BinaryMessages + main.BinaryMessage = o.BinaryMessage main.ReadBufferSize = o.ReadBufferSize main.WriteBufferSize = o.WriteBufferSize main.IDGenerator = o.IDGenerator @@ -114,6 +118,14 @@ func PongTimeout(val time.Duration) OptionSet { } } +// PingTimeout allowed to send the ping message to the connection +// Default value is 10 * time.Second +func PingTimeout(val time.Duration) OptionSet { + return func(o *Options) { + o.PingTimeout = val + } +} + // PingPeriod send ping messages to the connection with this period. Must be less than PongTimeout // Default value is (PongTimeout * 9) / 10 func PingPeriod(val time.Duration) OptionSet { @@ -130,13 +142,13 @@ func MaxMessageSize(val int64) OptionSet { } } -// BinaryMessages set it to true in order to denotes binary data messages instead of utf-8 text +// BinaryMessage set it to true in order to denotes binary data messages instead of utf-8 text // compatible if you wanna use the Connection's EmitMessage to send a custom binary data to the client, // like a native server-client communication. // defaults to false -func BinaryMessages(val bool) OptionSet { +func BinaryMessage(val bool) OptionSet { return func(o *Options) { - o.BinaryMessages = val + o.BinaryMessage = val } } diff --git a/websocket/server.go b/websocket/server.go index 209a22f..4f083be 100644 --- a/websocket/server.go +++ b/websocket/server.go @@ -10,7 +10,8 @@ import ( ) type ( - OnConnectionFunc func(websocket.Conn) + // OnConnectionFunc is callback function that used when client is connected + OnConnectionFunc func(Client) ) // Server is the websocket server, @@ -18,16 +19,16 @@ type ( type Server interface { Set(...OptionSetter) Handler() http.Handler - HandleConnection(*http.Request, *websocket.Conn) + HandleConnection(*http.Request, Connection) OnConnection(cb OnConnectionFunc) IsConnected(clientID string) bool - GetClient(clientID string) *Client + GetSocket(clientID string) Client Disconnect(clientID string) error } type server struct { options *Options - clients map[string]*client + clients map[string]Client clientMTX sync.Mutex onConnectionListeners []OnConnectionFunc } @@ -44,10 +45,10 @@ func New(setters ...OptionSetter) Server { } // newServer creates a websocket server and returns it -func newServer(setters ...OptionSetter) *server { +func newServer(setters ...OptionSetter) Server { s := &server{ - clients: make(map[string]*client, 100), + clients: make(map[string]Client, 100), onConnectionListeners: make([]OnConnectionFunc, 0), } @@ -55,6 +56,11 @@ func newServer(setters ...OptionSetter) *server { return s } +// Set is function that set option values +func Set(setters ...OptionSetter) { + defaultServer.Set(setters...) +} + func (s *server) Set(setters ...OptionSetter) { for _, setter := range setters { setter.Set(s.options) @@ -63,6 +69,11 @@ func (s *server) Set(setters ...OptionSetter) { s.options.Validate() } +// Handler is the function that used on http request +func Handler() http.Handler { + return defaultServer.Handler() +} + func (s *server) Handler() http.Handler { o := s.options @@ -70,7 +81,8 @@ func (s *server) Handler() http.Handler { ReadBufferSize: o.ReadBufferSize, WriteBufferSize: o.WriteBufferSize, Error: o.Error, - CheckOrigin: o.CheckOrigin} + CheckOrigin: o.CheckOrigin, + } return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { conn, err := upgrader.Upgrade(w, r, w.Header()) @@ -82,7 +94,11 @@ func (s *server) Handler() http.Handler { }) } -func (s *server) HandleConnection(r *http.Request, conn *websocket.Conn) { +func HandleConnection(r *http.Request, conn Connection) { + defaultServer.HandleConnection(r, conn) +} + +func (s *server) HandleConnection(r *http.Request, conn Connection) { clientID := s.options.IDGenerator(r) c := newClient(s, r, conn, clientID) err := s.addClient(clientID, c) @@ -90,21 +106,45 @@ func (s *server) HandleConnection(r *http.Request, conn *websocket.Conn) { log.Println(fmt.Errorf("%v", err)) return } + + for i := range s.onConnectionListeners { + s.onConnectionListeners[i](c) + } +} + +// OnConnection is function that add the callback when client is connected to default Server +func OnConnection(cb OnConnectionFunc) { + defaultServer.OnConnection(cb) } func (s *server) OnConnection(cb OnConnectionFunc) { s.onConnectionListeners = append(s.onConnectionListeners, cb) } -func (s *server) IsConnected(clientID string) bool { - c := s.clients[clientID] - return c != nil +// IsConnected is function that check client is connect +func IsConnected(clientID string) bool { + return defaultServer.IsConnected(clientID) } -func (s *server) GetClient(clientID string) *Client { +func (s *server) IsConnected(clientID string) bool { + soc := s.clients[clientID] + return soc != nil +} + +// GetSocket is function that return client instance +func GetSocket(clientID string) Client { + return defaultServer.GetSocket(clientID) +} + +func (s *server) GetSocket(clientID string) Client { return s.clients[clientID] } +// Disconnect is function that disconnect a client +func Disconnect(clientID string) error { + return defaultServer.Disconnect(clientID) +} + func (s *server) Disconnect(clientID string) error { c := s.clients[clientID] @@ -112,13 +152,15 @@ func (s *server) Disconnect(clientID string) error { return nil } - return nil + delete(s.clients, clientID) + + return c.destroy() } -func (s *server) addClient(clientID string, c *client) error { +func (s *server) addClient(clientID string, c Client) error { s.clientMTX.Lock() if s.clients[clientID] != nil { - return fmt.Errorf("Client[%s] is exist already", clientID) + return fmt.Errorf("ID of Client[%s] is exist already", clientID) } s.clients[clientID] = c s.clientMTX.Unlock()