This commit is contained in:
crusader 2017-07-11 17:05:06 +09:00
parent a9e720e7ab
commit 2f96f3bf23
5 changed files with 284 additions and 133 deletions

1
websocket/channel.go Normal file
View File

@ -0,0 +1 @@
package websocket

108
websocket/client.go Normal file
View File

@ -0,0 +1,108 @@
package websocket
import (
"io"
"net/http"
"sync"
"time"
"github.com/gorilla/websocket"
)
type Client interface {
ID() string
RemoteAddr() string
UserAgent() string
SetWriteDeadline(t time.Time) error
SetReadDeadline(t time.Time) error
SetReadLimit(limit int64)
SetPongHandler(h func(appData string) error)
SetPingHandler(h func(appData string) error)
WriteControl(messageType int, data []byte, deadline time.Time) error
WriteMessage(messageType int, data []byte) error
ReadMessage() (messageType int, p []byte, err error)
NextWriter(messageType int) (io.WriteCloser, error)
IsClosed() bool
Close() error
}
type client struct {
id string
server *server
httpRequest *http.Request
conn *websocket.Conn
writeMTX sync.Mutex
}
var _ Client = &client{}
func newClient(s *server, r *http.Request, conn *websocket.Conn, clientID string) Client {
c := &client{
id: clientID,
server: s,
httpRequest: r,
conn: conn,
}
return c
}
func (c *client) ID() string {
return c.id
}
func (c *client) RemoteAddr() string {
return c.httpRequest.RemoteAddr
}
func (c *client) UserAgent() string {
return c.httpRequest.UserAgent()
}
func (c *client) SetWriteDeadline(t time.Time) error {
return c.conn.SetWriteDeadline(t)
}
func (c *client) ID() string {
return c.id
}
func (c *client) ID() string {
return c.id
}
func (c *client) ID() string {
return c.id
}
func (c *client) ID() string {
return c.id
}
func (c *client) ID() string {
return c.id
}
func (c *client) ID() string {
return c.id
}
func (c *client) ID() string {
return c.id
}
func (c *client) ID() string {
return c.id
}
func (c *client) ID() string {
return c.id
}
func (c *client) ID() string {
return c.id
}
func (c *client) ID() string {
return c.id
}

View File

@ -1,12 +0,0 @@
package websocket
import (
"net/http"
"sync"
)
type connection struct {
id string
httpRequest http.Request
writeMTX sync.Mutex
}

View File

@ -33,20 +33,20 @@ type (
// OptionSetter sets a configuration field to the websocket config // OptionSetter sets a configuration field to the websocket config
// used to help developers to write less and configure only what they really want and nothing else // used to help developers to write less and configure only what they really want and nothing else
OptionSetter interface { OptionSetter interface {
Set(c *Config) Set(o *Options)
} }
// OptionSet implements the OptionSetter // OptionSet implements the OptionSetter
OptionSet func(c *Config) OptionSet func(o *Options)
) )
// Set is the func which makes the OptionSet an OptionSetter, this is used mostly // Set is the func which makes the OptionSet an OptionSetter, this is used mostly
func (o OptionSet) Set(c *Config) { func (os OptionSet) Set(o *Options) {
o(c) os(o)
} }
// Config is configuration of the websocket server // Options is configuration of the websocket server
type Config struct { type Options struct {
Error func(res http.ResponseWriter, req *http.Request, status int, reason error) Error func(res http.ResponseWriter, req *http.Request, status int, reason error)
CheckOrigin func(req *http.Request) bool CheckOrigin func(req *http.Request) bool
WriteTimeout time.Duration WriteTimeout time.Duration
@ -61,72 +61,72 @@ type Config struct {
} }
// Set is the func which makes the OptionSet an OptionSetter, this is used mostly // Set is the func which makes the OptionSet an OptionSetter, this is used mostly
func (c Config) Set(main *Config) { func (o *Options) Set(main *Options) {
main.Error = c.Error main.Error = o.Error
main.CheckOrigin = c.CheckOrigin main.CheckOrigin = o.CheckOrigin
main.WriteTimeout = c.WriteTimeout main.WriteTimeout = o.WriteTimeout
main.ReadTimeout = c.ReadTimeout main.ReadTimeout = o.ReadTimeout
main.PongTimeout = c.PongTimeout main.PongTimeout = o.PongTimeout
main.PingPeriod = c.PingPeriod main.PingPeriod = o.PingPeriod
main.MaxMessageSize = c.MaxMessageSize main.MaxMessageSize = o.MaxMessageSize
main.BinaryMessages = c.BinaryMessages main.BinaryMessages = o.BinaryMessages
main.ReadBufferSize = c.ReadBufferSize main.ReadBufferSize = o.ReadBufferSize
main.WriteBufferSize = c.WriteBufferSize main.WriteBufferSize = o.WriteBufferSize
main.IDGenerator = c.IDGenerator main.IDGenerator = o.IDGenerator
} }
// Error sets the error handler // Error sets the error handler
func Error(val func(res http.ResponseWriter, req *http.Request, status int, reason error)) OptionSet { func Error(val func(res http.ResponseWriter, req *http.Request, status int, reason error)) OptionSet {
return func(c *Config) { return func(o *Options) {
c.Error = val o.Error = val
} }
} }
// CheckOrigin sets a handler which will check if different origin(domains) are allowed to contact with // CheckOrigin sets a handler which will check if different origin(domains) are allowed to contact with
// the websocket server // the websocket server
func CheckOrigin(val func(req *http.Request) bool) OptionSet { func CheckOrigin(val func(req *http.Request) bool) OptionSet {
return func(c *Config) { return func(o *Options) {
c.CheckOrigin = val o.CheckOrigin = val
} }
} }
// WriteTimeout time allowed to write a message to the connection. // WriteTimeout time allowed to write a message to the connection.
// Default value is 15 * time.Second // Default value is 15 * time.Second
func WriteTimeout(val time.Duration) OptionSet { func WriteTimeout(val time.Duration) OptionSet {
return func(c *Config) { return func(o *Options) {
c.WriteTimeout = val o.WriteTimeout = val
} }
} }
// ReadTimeout time allowed to read a message from the connection. // ReadTimeout time allowed to read a message from the connection.
// Default value is 15 * time.Second // Default value is 15 * time.Second
func ReadTimeout(val time.Duration) OptionSet { func ReadTimeout(val time.Duration) OptionSet {
return func(c *Config) { return func(o *Options) {
c.ReadTimeout = val o.ReadTimeout = val
} }
} }
// PongTimeout allowed to read the next pong message from the connection // PongTimeout allowed to read the next pong message from the connection
// Default value is 60 * time.Second // Default value is 60 * time.Second
func PongTimeout(val time.Duration) OptionSet { func PongTimeout(val time.Duration) OptionSet {
return func(c *Config) { return func(o *Options) {
c.PongTimeout = val o.PongTimeout = val
} }
} }
// PingPeriod send ping messages to the connection with this period. Must be less than PongTimeout // PingPeriod send ping messages to the connection with this period. Must be less than PongTimeout
// Default value is (PongTimeout * 9) / 10 // Default value is (PongTimeout * 9) / 10
func PingPeriod(val time.Duration) OptionSet { func PingPeriod(val time.Duration) OptionSet {
return func(c *Config) { return func(o *Options) {
c.PingPeriod = val o.PingPeriod = val
} }
} }
// MaxMessageSize max message size allowed from connection // MaxMessageSize max message size allowed from connection
// Default value is 1024 // Default value is 1024
func MaxMessageSize(val int64) OptionSet { func MaxMessageSize(val int64) OptionSet {
return func(c *Config) { return func(o *Options) {
c.MaxMessageSize = val o.MaxMessageSize = val
} }
} }
@ -135,22 +135,22 @@ func MaxMessageSize(val int64) OptionSet {
// like a native server-client communication. // like a native server-client communication.
// defaults to false // defaults to false
func BinaryMessages(val bool) OptionSet { func BinaryMessages(val bool) OptionSet {
return func(c *Config) { return func(o *Options) {
c.BinaryMessages = val o.BinaryMessages = val
} }
} }
// ReadBufferSize is the buffer size for the underline reader // ReadBufferSize is the buffer size for the underline reader
func ReadBufferSize(val int) OptionSet { func ReadBufferSize(val int) OptionSet {
return func(c *Config) { return func(o *Options) {
c.ReadBufferSize = val o.ReadBufferSize = val
} }
} }
// WriteBufferSize is the buffer size for the underline writer // WriteBufferSize is the buffer size for the underline writer
func WriteBufferSize(val int) OptionSet { func WriteBufferSize(val int) OptionSet {
return func(c *Config) { return func(o *Options) {
c.WriteBufferSize = val o.WriteBufferSize = val
} }
} }
@ -159,56 +159,54 @@ func WriteBufferSize(val int) OptionSet {
// The request is an argument which you can use to generate the ID (from headers for example). // The request is an argument which you can use to generate the ID (from headers for example).
// If empty then the ID is generated by func: uuid.NewV4().String() // If empty then the ID is generated by func: uuid.NewV4().String()
func IDGenerator(val func(*http.Request) string) OptionSet { func IDGenerator(val func(*http.Request) string) OptionSet {
return func(c *Config) { return func(o *Options) {
c.IDGenerator = val o.IDGenerator = val
} }
} }
// Validate validates the configuration // Validate validates the configuration
func (c Config) Validate() Config { func (o *Options) Validate() {
if c.WriteTimeout < 0 { if o.WriteTimeout < 0 {
c.WriteTimeout = DefaultWriteTimeout o.WriteTimeout = DefaultWriteTimeout
} }
if c.ReadTimeout < 0 { if o.ReadTimeout < 0 {
c.ReadTimeout = DefaultReadTimeout o.ReadTimeout = DefaultReadTimeout
} }
if c.PongTimeout < 0 { if o.PongTimeout < 0 {
c.PongTimeout = DefaultPongTimeout o.PongTimeout = DefaultPongTimeout
} }
if c.PingPeriod <= 0 { if o.PingPeriod <= 0 {
c.PingPeriod = DefaultPingPeriod o.PingPeriod = DefaultPingPeriod
} }
if c.MaxMessageSize <= 0 { if o.MaxMessageSize <= 0 {
c.MaxMessageSize = DefaultMaxMessageSize o.MaxMessageSize = DefaultMaxMessageSize
} }
if c.ReadBufferSize <= 0 { if o.ReadBufferSize <= 0 {
c.ReadBufferSize = DefaultReadBufferSize o.ReadBufferSize = DefaultReadBufferSize
} }
if c.WriteBufferSize <= 0 { if o.WriteBufferSize <= 0 {
c.WriteBufferSize = DefaultWriteBufferSize o.WriteBufferSize = DefaultWriteBufferSize
} }
if c.Error == nil { if o.Error == nil {
c.Error = func(res http.ResponseWriter, req *http.Request, status int, reason error) { o.Error = func(res http.ResponseWriter, req *http.Request, status int, reason error) {
} }
} }
if c.CheckOrigin == nil { if o.CheckOrigin == nil {
c.CheckOrigin = func(req *http.Request) bool { o.CheckOrigin = func(req *http.Request) bool {
return true return true
} }
} }
if c.IDGenerator == nil { if o.IDGenerator == nil {
c.IDGenerator = DefaultIDGenerator o.IDGenerator = DefaultIDGenerator
} }
return c
} }

View File

@ -1,70 +1,126 @@
package websocket package websocket
import "net/http" import (
"fmt"
"log"
"net/http"
"sync"
"github.com/gorilla/websocket"
)
type (
OnConnectionFunc func(websocket.Conn)
)
// Server is the websocket server, // Server is the websocket server,
// listens on the config's port, the critical part is the event OnConnection // listens on the config's port, the critical part is the event OnConnection
type Server interface { type Server interface {
// Set sets an option aka configuration field to the websocket server
Set(...OptionSetter) Set(...OptionSetter)
// Handler returns the http.Handler which is setted to the 'Websocket Endpoint path',
// the client should target to this handler's developer's custom path
// ex: http.Handle("/myendpoint", mywebsocket.Handler())
// Handler calls the HandleConnection, so
// Use Handler or HandleConnection manually, DO NOT USE both.
// Note: you can always create your own upgrader which returns an UnderlineConnection and call only the HandleConnection manually (as Iris web framework does)
Handler() http.Handler Handler() http.Handler
// HandleConnection creates & starts to listening to a new connection HandleConnection(*http.Request, *websocket.Conn)
// DO NOT USE Handler() and HandleConnection at the sametime, see Handler for more OnConnection(cb OnConnectionFunc)
// NOTE: You don't need this, this is needed only when we want to 'hijack' the upgrader IsConnected(clientID string) bool
// (used for Iris and fasthttp before Iris v6) GetClient(clientID string) *Client
HandleConnection(*http.Request, UnderlineConnection) Disconnect(clientID string) error
// OnConnection this is the main event you, as developer, will work with each of the websocket connections
OnConnection(cb ConnectionFunc)
/*
connection actions, same as the connection's method,
but these methods accept the connection ID,
which is useful when the developer maps
this id with a database field (using config.IDGenerator).
*/
// IsConnected returns true if the connection with that ID is connected to the server
// useful when you have defined a custom connection id generator (based on a database)
// and you want to check if that connection is already connected (on multiple tabs)
IsConnected(connID string) bool
// Join joins a websocket client to a room,
// first parameter is the room name and the second the connection.ID()
//
// You can use connection.Join("room name") instead.
Join(roomName string, connID string)
// LeaveAll kicks out a connection from ALL of its joined rooms
LeaveAll(connID string)
// Leave leaves a websocket client from a room,
// first parameter is the room name and the second the connection.ID()
//
// You can use connection.Leave("room name") instead.
// Returns true if the connection has actually left from the particular room.
Leave(roomName string, connID string) bool
// GetConnectionsByRoom returns a list of Connection
// are joined to this room.
GetConnectionsByRoom(roomName string) []Connection
// Disconnect force-disconnects a websocket connection
// based on its connection.ID()
// What it does?
// 1. remove the connection from the list
// 2. leave from all joined rooms
// 3. fire the disconnect callbacks, if any
// 4. close the underline connection and return its error, if any.
//
// You can use the connection.Disconnect() instead.
Disconnect(connID string) error
} }
type server struct { type server struct {
options *Options
clients map[string]*client
clientMTX sync.Mutex
onConnectionListeners []OnConnectionFunc
}
var _ Server = &server{}
var defaultServer = newServer()
// server implementation
// New creates a websocket server and returns it
func New(setters ...OptionSetter) Server {
return newServer(setters...)
}
// newServer creates a websocket server and returns it
func newServer(setters ...OptionSetter) *server {
s := &server{
clients: make(map[string]*client, 100),
onConnectionListeners: make([]OnConnectionFunc, 0),
}
s.Set(setters...)
return s
}
func (s *server) Set(setters ...OptionSetter) {
for _, setter := range setters {
setter.Set(s.options)
}
s.options.Validate()
}
func (s *server) Handler() http.Handler {
o := s.options
upgrader := websocket.Upgrader{
ReadBufferSize: o.ReadBufferSize,
WriteBufferSize: o.WriteBufferSize,
Error: o.Error,
CheckOrigin: o.CheckOrigin}
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
conn, err := upgrader.Upgrade(w, r, w.Header())
if err != nil {
http.Error(w, "Websocket Error: "+err.Error(), http.StatusServiceUnavailable)
return
}
s.HandleConnection(r, conn)
})
}
func (s *server) HandleConnection(r *http.Request, conn *websocket.Conn) {
clientID := s.options.IDGenerator(r)
c := newClient(s, r, conn, clientID)
err := s.addClient(clientID, c)
if nil != err {
log.Println(fmt.Errorf("%v", err))
return
}
}
func (s *server) OnConnection(cb OnConnectionFunc) {
s.onConnectionListeners = append(s.onConnectionListeners, cb)
}
func (s *server) IsConnected(clientID string) bool {
c := s.clients[clientID]
return c != nil
}
func (s *server) GetClient(clientID string) *Client {
return s.clients[clientID]
}
func (s *server) Disconnect(clientID string) error {
c := s.clients[clientID]
if nil == c {
return nil
}
return nil
}
func (s *server) addClient(clientID string, c *client) error {
s.clientMTX.Lock()
if s.clients[clientID] != nil {
return fmt.Errorf("Client[%s] is exist already", clientID)
}
s.clients[clientID] = c
s.clientMTX.Unlock()
return nil
} }