This commit is contained in:
crusader 2017-07-12 21:36:12 +09:00
parent 2f96f3bf23
commit 6e4ca0dd5c
5 changed files with 264 additions and 74 deletions

View File

@ -1 +0,0 @@
package websocket

View File

@ -1,47 +1,81 @@
package websocket package websocket
import ( import (
"io" "fmt"
"log"
"net"
"net/http" "net/http"
"sync" "sync"
"time" "time"
"github.com/gorilla/websocket" gWebsocket "github.com/gorilla/websocket"
) )
type ClientStatus uint8
const (
CONNECTED ClientStatus = iota + 1
DISCONNECTED
)
type (
// OnDisconnectFunc is callback function that used when client is disconnected
OnDisconnectFunc func(Client)
// OnErrorFunc is callback function that used when error occurred
OnErrorFunc func(string)
// OnMessageFunc is callback function that receives messages from client
OnMessageFunc func([]byte)
// OnFunc is callback function that particular event which fires when a message to this event received
OnFunc interface{}
)
// Client is interface
type Client interface { type Client interface {
ID() string ID() string
RemoteAddr() string HTTPRequest() *http.Request
UserAgent() string Conn() Connection
SetWriteDeadline(t time.Time) error Disconnect() error
SetReadDeadline(t time.Time) error OnMessage(OnMessageFunc)
SetReadLimit(limit int64) OnError(OnErrorFunc)
SetPongHandler(h func(appData string) error) OnDisconnect(OnDisconnectFunc)
SetPingHandler(h func(appData string) error) On(string, OnFunc)
WriteControl(messageType int, data []byte, deadline time.Time) error initialize() error
WriteMessage(messageType int, data []byte) error destroy() error
ReadMessage() (messageType int, p []byte, err error)
NextWriter(messageType int) (io.WriteCloser, error)
IsClosed() bool
Close() error
} }
type client struct { type client struct {
id string id string
server *server status ClientStatus
httpRequest *http.Request messageType int
conn *websocket.Conn server *server
writeMTX sync.Mutex httpRequest *http.Request
conn Connection
pingTicker *time.Ticker
writeMTX sync.Mutex
onMessageListeners []OnMessageFunc
onErrorListeners []OnErrorFunc
onDisconnectListeners []OnDisconnectFunc
onListeners map[string][]OnFunc
} }
var _ Client = &client{} var _ Client = &client{}
func newClient(s *server, r *http.Request, conn *websocket.Conn, clientID string) Client { func newClient(s *server, r *http.Request, conn Connection, clientID string) Client {
c := &client{ c := &client{
id: clientID, id: clientID,
server: s, status: CONNECTED,
httpRequest: r, messageType: gWebsocket.TextMessage,
conn: conn, server: s,
httpRequest: r,
conn: conn,
onMessageListeners: make([]OnMessageFunc, 0),
onErrorListeners: make([]OnErrorFunc, 0),
onDisconnectListeners: make([]OnDisconnectFunc, 0),
onListeners: make(map[string][]OnFunc),
}
if s.options.BinaryMessage {
c.messageType = gWebsocket.BinaryMessage
} }
return c return c
@ -51,58 +85,127 @@ func (c *client) ID() string {
return c.id return c.id
} }
func (c *client) RemoteAddr() string { func (c *client) HTTPRequest() *http.Request {
return c.httpRequest.RemoteAddr return c.httpRequest
} }
func (c *client) UserAgent() string { func (c *client) Conn() Connection {
return c.httpRequest.UserAgent() return c.conn
} }
func (c *client) SetWriteDeadline(t time.Time) error { func (c *client) Disconnect() error {
return c.conn.SetWriteDeadline(t) return c.server.Disconnect(c.ID())
} }
func (c *client) ID() string { func (c *client) OnDisconnect(cb OnDisconnectFunc) {
return c.id c.onDisconnectListeners = append(c.onDisconnectListeners, cb)
} }
func (c *client) ID() string { func (c *client) OnError(cb OnErrorFunc) {
return c.id c.onErrorListeners = append(c.onErrorListeners, cb)
} }
func (c *client) ID() string { func (c *client) OnMessage(cb OnMessageFunc) {
return c.id c.onMessageListeners = append(c.onMessageListeners, cb)
} }
func (c *client) ID() string { func (c *client) On(event string, cb OnFunc) {
return c.id if c.onListeners[event] == nil {
c.onListeners[event] = make([]OnFunc, 0)
}
c.onListeners[event] = append(c.onListeners[event], cb)
} }
func (c *client) ID() string { func (c *client) initialize() error {
return c.id c.startPingPong()
} }
func (c *client) ID() string { func (c *client) destroy() error {
return c.id c.stopPingPong()
c.status = DISCONNECTED
for i := range c.onDisconnectListeners {
c.onDisconnectListeners[i](c)
}
return c.conn.Close()
} }
func (c *client) ID() string { func (c *client) startPingPong() {
return c.id c.conn.SetPingHandler(func(message string) error {
err := c.conn.WriteControl(gWebsocket.PongMessage, []byte("pong"), time.Now().Add(c.server.options.PongTimeout))
if err == gWebsocket.ErrCloseSent {
return nil
} else if e, ok := err.(net.Error); ok && e.Temporary() {
return nil
}
return err
})
c.pingTicker = time.NewTicker(c.server.options.PingPeriod)
go func() {
for {
<-c.pingTicker.C
err := c.conn.WriteControl(gWebsocket.PingMessage, []byte("ping"), time.Now().Add(c.server.options.PingTimeout))
if err == gWebsocket.ErrCloseSent {
} else if e, ok := err.(net.Error); ok && e.Temporary() {
}
}
}()
} }
func (c *client) ID() string { func (c *client) stopPingPong() {
return c.id c.pingTicker.Stop()
} }
func (c *client) ID() string { func (c *client) startReading() {
return c.id hasReadTimeout := c.server.options.ReadTimeout > 0
c.conn.SetReadLimit(c.server.options.MaxMessageSize)
c.conn.SetPongHandler(func(message string) error {
if hasReadTimeout {
c.conn.SetReadDeadline(time.Now().Add(c.server.options.ReadTimeout))
}
return nil
})
defer func() {
c.Disconnect()
}()
for {
if hasReadTimeout {
c.conn.SetReadDeadline(time.Now().Add(c.server.options.ReadTimeout))
}
messageType, data, err := c.conn.ReadMessage()
if err != nil {
if gWebsocket.IsUnexpectedCloseError(err, gWebsocket.CloseGoingAway) {
c.EmitError(err.Error())
}
break
} else {
c.onMessageReceived(messageType, data)
}
}
} }
func (c *client) ID() string { func (c *client) onMessageReceived(messageType int, data []byte) {
return c.id
} }
func (c *client) ID() string { func (c *client) write(messageType int, data []byte) {
return c.id c.writeMTX.Lock()
if writeTimeout := c.server.options.WriteTimeout; writeTimeout > 0 {
// set the write deadline based on the configuration
err := c.conn.SetWriteDeadline(time.Now().Add(writeTimeout))
log.Println(fmt.Errorf("%v", err))
}
err := c.conn.WriteMessage(messageType, data)
c.writeMTX.Unlock()
if nil != err {
c.Disconnect()
}
} }

34
websocket/connection.go Normal file
View File

@ -0,0 +1,34 @@
package websocket
import (
"io"
"net"
"time"
)
// Connection is wrapper of the websocket.Conn
type Connection interface {
Close() error
CloseHandler() func(code int, text string) error
EnableWriteCompression(enable bool)
LocalAddr() net.Addr
NextReader() (messageType int, r io.Reader, err error)
NextWriter(messageType int) (io.WriteCloser, error)
PingHandler() func(appData string) error
PongHandler() func(appData string) error
ReadJSON(v interface{}) error
ReadMessage() (messageType int, p []byte, err error)
RemoteAddr() net.Addr
SetCloseHandler(h func(code int, text string) error)
SetCompressionLevel(level int) error
SetPingHandler(h func(appData string) error)
SetPongHandler(h func(appData string) error)
SetReadDeadline(t time.Time) error
SetReadLimit(limit int64)
SetWriteDeadline(t time.Time) error
Subprotocol() string
UnderlyingConn() net.Conn
WriteControl(messageType int, data []byte, deadline time.Time) error
WriteJSON(v interface{}) error
WriteMessage(messageType int, data []byte) error
}

View File

@ -14,6 +14,8 @@ const (
DefaultReadTimeout = 0 DefaultReadTimeout = 0
// DefaultPongTimeout is default value of Pong Timeout // DefaultPongTimeout is default value of Pong Timeout
DefaultPongTimeout = 60 * time.Second DefaultPongTimeout = 60 * time.Second
// DefaultPingTimeout is default value of Ping Timeout
DefaultPingTimeout = 10 * time.Second
// DefaultPingPeriod is default value of Ping Period // DefaultPingPeriod is default value of Ping Period
DefaultPingPeriod = (DefaultPongTimeout * 9) / 10 DefaultPingPeriod = (DefaultPongTimeout * 9) / 10
// DefaultMaxMessageSize is default value of Max Message Size // DefaultMaxMessageSize is default value of Max Message Size
@ -52,9 +54,10 @@ type Options struct {
WriteTimeout time.Duration WriteTimeout time.Duration
ReadTimeout time.Duration ReadTimeout time.Duration
PongTimeout time.Duration PongTimeout time.Duration
PingTimeout time.Duration
PingPeriod time.Duration PingPeriod time.Duration
MaxMessageSize int64 MaxMessageSize int64
BinaryMessages bool BinaryMessage bool
ReadBufferSize int ReadBufferSize int
WriteBufferSize int WriteBufferSize int
IDGenerator func(*http.Request) string IDGenerator func(*http.Request) string
@ -67,9 +70,10 @@ func (o *Options) Set(main *Options) {
main.WriteTimeout = o.WriteTimeout main.WriteTimeout = o.WriteTimeout
main.ReadTimeout = o.ReadTimeout main.ReadTimeout = o.ReadTimeout
main.PongTimeout = o.PongTimeout main.PongTimeout = o.PongTimeout
main.PingTimeout = o.PingTimeout
main.PingPeriod = o.PingPeriod main.PingPeriod = o.PingPeriod
main.MaxMessageSize = o.MaxMessageSize main.MaxMessageSize = o.MaxMessageSize
main.BinaryMessages = o.BinaryMessages main.BinaryMessage = o.BinaryMessage
main.ReadBufferSize = o.ReadBufferSize main.ReadBufferSize = o.ReadBufferSize
main.WriteBufferSize = o.WriteBufferSize main.WriteBufferSize = o.WriteBufferSize
main.IDGenerator = o.IDGenerator main.IDGenerator = o.IDGenerator
@ -114,6 +118,14 @@ func PongTimeout(val time.Duration) OptionSet {
} }
} }
// PingTimeout allowed to send the ping message to the connection
// Default value is 10 * time.Second
func PingTimeout(val time.Duration) OptionSet {
return func(o *Options) {
o.PingTimeout = val
}
}
// PingPeriod send ping messages to the connection with this period. Must be less than PongTimeout // PingPeriod send ping messages to the connection with this period. Must be less than PongTimeout
// Default value is (PongTimeout * 9) / 10 // Default value is (PongTimeout * 9) / 10
func PingPeriod(val time.Duration) OptionSet { func PingPeriod(val time.Duration) OptionSet {
@ -130,13 +142,13 @@ func MaxMessageSize(val int64) OptionSet {
} }
} }
// BinaryMessages set it to true in order to denotes binary data messages instead of utf-8 text // BinaryMessage set it to true in order to denotes binary data messages instead of utf-8 text
// compatible if you wanna use the Connection's EmitMessage to send a custom binary data to the client, // compatible if you wanna use the Connection's EmitMessage to send a custom binary data to the client,
// like a native server-client communication. // like a native server-client communication.
// defaults to false // defaults to false
func BinaryMessages(val bool) OptionSet { func BinaryMessage(val bool) OptionSet {
return func(o *Options) { return func(o *Options) {
o.BinaryMessages = val o.BinaryMessage = val
} }
} }

View File

@ -10,7 +10,8 @@ import (
) )
type ( type (
OnConnectionFunc func(websocket.Conn) // OnConnectionFunc is callback function that used when client is connected
OnConnectionFunc func(Client)
) )
// Server is the websocket server, // Server is the websocket server,
@ -18,16 +19,16 @@ type (
type Server interface { type Server interface {
Set(...OptionSetter) Set(...OptionSetter)
Handler() http.Handler Handler() http.Handler
HandleConnection(*http.Request, *websocket.Conn) HandleConnection(*http.Request, Connection)
OnConnection(cb OnConnectionFunc) OnConnection(cb OnConnectionFunc)
IsConnected(clientID string) bool IsConnected(clientID string) bool
GetClient(clientID string) *Client GetSocket(clientID string) Client
Disconnect(clientID string) error Disconnect(clientID string) error
} }
type server struct { type server struct {
options *Options options *Options
clients map[string]*client clients map[string]Client
clientMTX sync.Mutex clientMTX sync.Mutex
onConnectionListeners []OnConnectionFunc onConnectionListeners []OnConnectionFunc
} }
@ -44,10 +45,10 @@ func New(setters ...OptionSetter) Server {
} }
// newServer creates a websocket server and returns it // newServer creates a websocket server and returns it
func newServer(setters ...OptionSetter) *server { func newServer(setters ...OptionSetter) Server {
s := &server{ s := &server{
clients: make(map[string]*client, 100), clients: make(map[string]Client, 100),
onConnectionListeners: make([]OnConnectionFunc, 0), onConnectionListeners: make([]OnConnectionFunc, 0),
} }
@ -55,6 +56,11 @@ func newServer(setters ...OptionSetter) *server {
return s return s
} }
// Set is function that set option values
func Set(setters ...OptionSetter) {
defaultServer.Set(setters...)
}
func (s *server) Set(setters ...OptionSetter) { func (s *server) Set(setters ...OptionSetter) {
for _, setter := range setters { for _, setter := range setters {
setter.Set(s.options) setter.Set(s.options)
@ -63,6 +69,11 @@ func (s *server) Set(setters ...OptionSetter) {
s.options.Validate() s.options.Validate()
} }
// Handler is the function that used on http request
func Handler() http.Handler {
return defaultServer.Handler()
}
func (s *server) Handler() http.Handler { func (s *server) Handler() http.Handler {
o := s.options o := s.options
@ -70,7 +81,8 @@ func (s *server) Handler() http.Handler {
ReadBufferSize: o.ReadBufferSize, ReadBufferSize: o.ReadBufferSize,
WriteBufferSize: o.WriteBufferSize, WriteBufferSize: o.WriteBufferSize,
Error: o.Error, Error: o.Error,
CheckOrigin: o.CheckOrigin} CheckOrigin: o.CheckOrigin,
}
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
conn, err := upgrader.Upgrade(w, r, w.Header()) conn, err := upgrader.Upgrade(w, r, w.Header())
@ -82,7 +94,11 @@ func (s *server) Handler() http.Handler {
}) })
} }
func (s *server) HandleConnection(r *http.Request, conn *websocket.Conn) { func HandleConnection(r *http.Request, conn Connection) {
defaultServer.HandleConnection(r, conn)
}
func (s *server) HandleConnection(r *http.Request, conn Connection) {
clientID := s.options.IDGenerator(r) clientID := s.options.IDGenerator(r)
c := newClient(s, r, conn, clientID) c := newClient(s, r, conn, clientID)
err := s.addClient(clientID, c) err := s.addClient(clientID, c)
@ -90,21 +106,45 @@ func (s *server) HandleConnection(r *http.Request, conn *websocket.Conn) {
log.Println(fmt.Errorf("%v", err)) log.Println(fmt.Errorf("%v", err))
return return
} }
for i := range s.onConnectionListeners {
s.onConnectionListeners[i](c)
}
}
// OnConnection is function that add the callback when client is connected to default Server
func OnConnection(cb OnConnectionFunc) {
defaultServer.OnConnection(cb)
} }
func (s *server) OnConnection(cb OnConnectionFunc) { func (s *server) OnConnection(cb OnConnectionFunc) {
s.onConnectionListeners = append(s.onConnectionListeners, cb) s.onConnectionListeners = append(s.onConnectionListeners, cb)
} }
func (s *server) IsConnected(clientID string) bool { // IsConnected is function that check client is connect
c := s.clients[clientID] func IsConnected(clientID string) bool {
return c != nil return defaultServer.IsConnected(clientID)
} }
func (s *server) GetClient(clientID string) *Client { func (s *server) IsConnected(clientID string) bool {
soc := s.clients[clientID]
return soc != nil
}
// GetSocket is function that return client instance
func GetSocket(clientID string) Client {
return defaultServer.GetSocket(clientID)
}
func (s *server) GetSocket(clientID string) Client {
return s.clients[clientID] return s.clients[clientID]
} }
// Disconnect is function that disconnect a client
func Disconnect(clientID string) error {
return defaultServer.Disconnect(clientID)
}
func (s *server) Disconnect(clientID string) error { func (s *server) Disconnect(clientID string) error {
c := s.clients[clientID] c := s.clients[clientID]
@ -112,13 +152,15 @@ func (s *server) Disconnect(clientID string) error {
return nil return nil
} }
return nil delete(s.clients, clientID)
return c.destroy()
} }
func (s *server) addClient(clientID string, c *client) error { func (s *server) addClient(clientID string, c Client) error {
s.clientMTX.Lock() s.clientMTX.Lock()
if s.clients[clientID] != nil { if s.clients[clientID] != nil {
return fmt.Errorf("Client[%s] is exist already", clientID) return fmt.Errorf("ID of Client[%s] is exist already", clientID)
} }
s.clients[clientID] = c s.clients[clientID] = c
s.clientMTX.Unlock() s.clientMTX.Unlock()