This commit is contained in:
crusader 2018-04-04 22:28:35 +09:00
parent b48ffe1374
commit aaed2f9fe3
13 changed files with 627 additions and 915 deletions

View File

@ -0,0 +1,42 @@
package server
import (
"time"
)
type ClientConnectionHandler interface {
ConnectionHandler
GetReconnectInterval() time.Duration
GetReconnectTryTime() int
}
type ClientConnectionHandlers struct {
ConnectionHandlers
ReconnectInterval time.Duration
ReconnectTryTime int
}
func (cch *ClientConnectionHandlers) GetReconnectInterval() time.Duration {
return cch.ReconnectInterval
}
func (cch *ClientConnectionHandlers) GetReconnectTryTime() int {
return cch.ReconnectTryTime
}
func (cch *ClientConnectionHandlers) Validate() error {
if err := cch.ConnectionHandlers.Validate(); nil != err {
return err
}
if cch.ReconnectInterval <= 0 {
cch.ReconnectInterval = DefaultReconnectInterval
}
if cch.ReconnectTryTime <= 0 {
cch.ReconnectTryTime = DefaultReconnectTryTime
}
return nil
}

70
client-rwc-handler.go Normal file
View File

@ -0,0 +1,70 @@
package server
import (
"io"
"sync"
logging "git.loafle.net/commons/logging-go"
)
type ClientRWCHandler struct {
ReadwriteHandler ReadWriteHandler
ReadChan chan<- []byte
WriteChan <-chan []byte
DisconnectedChan chan<- struct{}
ReconnectedChan <-chan *Conn
ClientStopChan <-chan struct{}
ClientStopWg *sync.WaitGroup
}
func (crwch *ClientRWCHandler) HandleConnection(conn *Conn) {
defer func() {
if nil != conn {
conn.Close()
}
logging.Logger().Infof("disconnected")
crwch.ClientStopWg.Done()
}()
logging.Logger().Infof("connected")
stopChan := make(chan struct{})
readerDoneChan := make(chan error)
writerDoneChan := make(chan error)
var err error
for {
if nil != err {
if io.EOF == err || io.ErrUnexpectedEOF == err {
crwch.DisconnectedChan <- struct{}{}
newConn := <-crwch.ReconnectedChan
if nil == newConn {
return
}
conn = newConn
} else {
return
}
}
go connReadHandler(crwch.ReadwriteHandler, conn, stopChan, readerDoneChan, crwch.ReadChan)
go connWriteHandler(crwch.ReadwriteHandler, conn, stopChan, writerDoneChan, crwch.WriteChan)
select {
case err = <-readerDoneChan:
close(stopChan)
<-writerDoneChan
case err = <-writerDoneChan:
close(stopChan)
<-readerDoneChan
case <-crwch.ClientStopChan:
close(stopChan)
<-readerDoneChan
<-writerDoneChan
return
}
}
}

65
connection-handler.go Normal file
View File

@ -0,0 +1,65 @@
package server
import (
"crypto/tls"
"fmt"
"net"
"time"
)
type ConnectionHandler interface {
GetConcurrency() int
GetKeepAlive() time.Duration
GetHandshakeTimeout() time.Duration
GetTLSConfig() *tls.Config
Listener(serverCtx ServerCtx) (net.Listener, error)
}
type ConnectionHandlers struct {
ConnectionHandler
// The maximum number of concurrent connections the server may serve.
//
// DefaultConcurrency is used if not set.
Concurrency int
KeepAlive time.Duration
HandshakeTimeout time.Duration
TLSConfig *tls.Config
}
func (ch *ConnectionHandlers) Listener(serverCtx ServerCtx) (net.Listener, error) {
return nil, fmt.Errorf("Method[ConnectionHandler.Listener] is not implemented")
}
func (ch *ConnectionHandlers) GetConcurrency() int {
return ch.Concurrency
}
func (ch *ConnectionHandlers) GetKeepAlive() time.Duration {
return ch.KeepAlive
}
func (ch *ConnectionHandlers) GetHandshakeTimeout() time.Duration {
return ch.HandshakeTimeout
}
func (ch *ConnectionHandlers) GetTLSConfig() *tls.Config {
return ch.TLSConfig
}
func (ch *ConnectionHandlers) Validate() error {
if ch.Concurrency <= 0 {
ch.Concurrency = DefaultConcurrency
}
if ch.KeepAlive <= 0 {
ch.KeepAlive = DefaultKeepAlive
}
if ch.HandshakeTimeout <= 0 {
ch.HandshakeTimeout = DefaultHandshakeTimeout
}
return nil
}

View File

@ -30,4 +30,7 @@ const (
DefaultPingTimeout = 10 * time.Second DefaultPingTimeout = 10 * time.Second
// DefaultPingPeriod is default value of send ping period // DefaultPingPeriod is default value of send ping period
DefaultPingPeriod = (DefaultPingTimeout * 9) / 10 DefaultPingPeriod = (DefaultPingTimeout * 9) / 10
DefaultReconnectInterval = 1 * time.Second
DefaultReconnectTryTime = 10
) )

View File

@ -23,6 +23,9 @@ import (
var errMalformedURL = errors.New("malformed ws or wss URL") var errMalformedURL = errors.New("malformed ws or wss URL")
type Client struct { type Client struct {
server.ClientConnectionHandlers
server.ReadWriteHandlers
Name string Name string
URL string URL string
@ -44,46 +47,18 @@ type Client struct {
// If Proxy is nil or returns a nil *URL, no proxy is used. // If Proxy is nil or returns a nil *URL, no proxy is used.
Proxy func(*http.Request) (*url.URL, error) Proxy func(*http.Request) (*url.URL, error)
TLSConfig *tls.Config
HandshakeTimeout time.Duration
// ReadBufferSize and WriteBufferSize specify I/O buffer sizes. If a buffer
// size is zero, then a useful default size is used. The I/O buffer sizes
// do not limit the size of the messages that can be sent or received.
ReadBufferSize, WriteBufferSize int
// Subprotocols specifies the client's requested subprotocols.
// EnableCompression specifies if the client should attempt to negotiate
// per message compression (RFC 7692). Setting this value to true does not
// guarantee that compression will be supported. Currently only "no context
// takeover" modes are supported.
EnableCompression bool
// MaxMessageSize is the maximum size for a message read from the peer. If a
// message exceeds the limit, the connection sends a close frame to the peer
// and returns ErrReadLimit to the application.
MaxMessageSize int64
// WriteTimeout is the write deadline on the underlying network
// connection. After a write has timed out, the websocket state is corrupt and
// all future writes will return an error. A zero value for t means writes will
// not time out.
WriteTimeout time.Duration
// ReadTimeout is the read deadline on the underlying network connection.
// After a read has timed out, the websocket connection state is corrupt and
// all future reads will return an error. A zero value for t means reads will
// not time out.
ReadTimeout time.Duration
PongTimeout time.Duration
PingTimeout time.Duration
PingPeriod time.Duration
serverURL *url.URL serverURL *url.URL
stopChan chan struct{} stopChan chan struct{}
stopWg sync.WaitGroup stopWg sync.WaitGroup
conn *server.Conn
readChan chan []byte readChan chan []byte
writeChan chan []byte writeChan chan []byte
disconnectedChan chan struct{}
reconnectedChan chan *server.Conn
crwch server.ClientRWCHandler
} }
func (c *Client) Connect() (readChan <-chan []byte, writeChan chan<- []byte, res *http.Response, err error) { func (c *Client) Connect() (readChan <-chan []byte, writeChan chan<- []byte, res *http.Response, err error) {
@ -107,10 +82,20 @@ func (c *Client) Connect() (readChan <-chan []byte, writeChan chan<- []byte, res
c.readChan = make(chan []byte, 256) c.readChan = make(chan []byte, 256)
c.writeChan = make(chan []byte, 256) c.writeChan = make(chan []byte, 256)
c.disconnectedChan = make(chan struct{})
c.reconnectedChan = make(chan *server.Conn)
c.stopChan = make(chan struct{}) c.stopChan = make(chan struct{})
c.crwch.ReadwriteHandler = c
c.crwch.ReadChan = c.readChan
c.crwch.WriteChan = c.writeChan
c.crwch.ClientStopChan = c.stopChan
c.crwch.ClientStopWg = &c.stopWg
c.crwch.DisconnectedChan = c.disconnectedChan
c.crwch.ReconnectedChan = c.reconnectedChan
c.stopWg.Add(1) c.stopWg.Add(1)
go c.handleConnection(conn) go c.crwch.HandleConnection(conn)
return c.readChan, c.writeChan, res, nil return c.readChan, c.writeChan, res, nil
} }
@ -144,124 +129,6 @@ func (c *Client) connect() (*server.Conn, *http.Response, error) {
return conn, res, nil return conn, res, nil
} }
func (c *Client) handleConnection(conn *server.Conn) {
defer func() {
if nil != conn {
conn.Close()
}
logging.Logger().Infof(c.clientMessage("disconnected"))
c.stopWg.Done()
}()
logging.Logger().Infof(c.clientMessage("connected"))
stopChan := make(chan struct{})
readerDoneChan := make(chan struct{})
writerDoneChan := make(chan struct{})
go handleClientRead(c, conn, stopChan, readerDoneChan)
go handleClientWrite(c, conn, stopChan, writerDoneChan)
select {
case <-readerDoneChan:
close(stopChan)
<-writerDoneChan
case <-writerDoneChan:
close(stopChan)
<-readerDoneChan
case <-c.stopChan:
close(stopChan)
<-readerDoneChan
<-writerDoneChan
}
}
func handleClientRead(c *Client, conn *server.Conn, stopChan <-chan struct{}, doneChan chan<- struct{}) {
defer func() {
doneChan <- struct{}{}
}()
if 0 < c.MaxMessageSize {
conn.SetReadLimit(c.MaxMessageSize)
}
if 0 < c.ReadTimeout {
conn.SetReadDeadline(time.Now().Add(c.ReadTimeout))
}
conn.SetPongHandler(func(string) error {
conn.SetReadDeadline(time.Now().Add(c.PongTimeout))
return nil
})
var (
message []byte
err error
)
for {
readMessageChan := make(chan struct{})
go func() {
_, message, err = conn.ReadMessage()
close(readMessageChan)
}()
select {
case <-stopChan:
<-readMessageChan
return
case <-readMessageChan:
}
if nil != err {
if server.IsUnexpectedCloseError(err, server.CloseGoingAway, server.CloseAbnormalClosure) {
logging.Logger().Debugf(c.clientMessage(fmt.Sprintf("Read error %v", err)))
}
return
}
c.readChan <- message
}
}
func handleClientWrite(c *Client, conn *server.Conn, stopChan <-chan struct{}, doneChan chan<- struct{}) {
defer func() {
doneChan <- struct{}{}
}()
ticker := time.NewTicker(c.PingPeriod)
defer func() {
ticker.Stop()
}()
for {
select {
case message, ok := <-c.writeChan:
conn.SetWriteDeadline(time.Now().Add(c.WriteTimeout))
if !ok {
conn.WriteMessage(server.CloseMessage, []byte{})
return
}
w, err := conn.NextWriter(server.TextMessage)
if err != nil {
return
}
w.Write(message)
if err := w.Close(); nil != err {
return
}
case <-ticker.C:
conn.SetWriteDeadline(time.Now().Add(c.PingTimeout))
if err := conn.WriteMessage(server.PingMessage, nil); nil != err {
return
}
case <-stopChan:
return
}
}
}
func (c *Client) Dial() (*server.Conn, *http.Response, error) { func (c *Client) Dial() (*server.Conn, *http.Response, error) {
var ( var (
err error err error
@ -476,6 +343,13 @@ func (c *Client) Dial() (*server.Conn, *http.Response, error) {
} }
func (c *Client) Validate() error { func (c *Client) Validate() error {
if err := c.ClientConnectionHandlers.Validate(); nil != err {
return err
}
if err := c.ReadWriteHandlers.Validate(); nil != err {
return err
}
if "" == c.Name { if "" == c.Name {
c.Name = "Client" c.Name = "Client"
} }
@ -507,34 +381,6 @@ func (c *Client) Validate() error {
c.Proxy = http.ProxyFromEnvironment c.Proxy = http.ProxyFromEnvironment
} }
if c.HandshakeTimeout <= 0 {
c.HandshakeTimeout = server.DefaultHandshakeTimeout
}
if c.MaxMessageSize <= 0 {
c.MaxMessageSize = server.DefaultMaxMessageSize
}
if c.ReadBufferSize <= 0 {
c.ReadBufferSize = server.DefaultReadBufferSize
}
if c.WriteBufferSize <= 0 {
c.WriteBufferSize = server.DefaultWriteBufferSize
}
if c.ReadTimeout <= 0 {
c.ReadTimeout = server.DefaultReadTimeout
}
if c.WriteTimeout <= 0 {
c.WriteTimeout = server.DefaultWriteTimeout
}
if c.PongTimeout <= 0 {
c.PongTimeout = server.DefaultPongTimeout
}
if c.PingTimeout <= 0 {
c.PingTimeout = server.DefaultPingTimeout
}
if c.PingPeriod <= 0 {
c.PingPeriod = server.DefaultPingPeriod
}
return nil return nil
} }

View File

@ -6,7 +6,6 @@ import (
"net" "net"
"net/http" "net/http"
"sync" "sync"
"time"
"git.loafle.net/commons/logging-go" "git.loafle.net/commons/logging-go"
"git.loafle.net/commons/server-go" "git.loafle.net/commons/server-go"
@ -17,13 +16,13 @@ type Server struct {
ServerHandler ServerHandler ServerHandler ServerHandler
ctx server.ServerCtx ctx server.ServerCtx
stopChan chan struct{}
stopWg sync.WaitGroup
srwch server.ServerRWCHandler
hs *fasthttp.Server hs *fasthttp.Server
upgrader *Upgrader upgrader *Upgrader
connections sync.Map
stopChan chan struct{}
stopWg sync.WaitGroup
} }
func (s *Server) ListenAndServe() error { func (s *Server) ListenAndServe() error {
@ -59,7 +58,7 @@ func (s *Server) ListenAndServe() error {
HandshakeTimeout: s.ServerHandler.GetHandshakeTimeout(), HandshakeTimeout: s.ServerHandler.GetHandshakeTimeout(),
ReadBufferSize: s.ServerHandler.GetReadBufferSize(), ReadBufferSize: s.ServerHandler.GetReadBufferSize(),
WriteBufferSize: s.ServerHandler.GetWriteBufferSize(), WriteBufferSize: s.ServerHandler.GetWriteBufferSize(),
CheckOrigin: s.ServerHandler.CheckOrigin, CheckOrigin: s.ServerHandler.(ServerHandler).CheckOrigin,
Error: s.onError, Error: s.onError,
EnableCompression: s.ServerHandler.IsEnableCompression(), EnableCompression: s.ServerHandler.IsEnableCompression(),
} }
@ -73,13 +72,18 @@ func (s *Server) ListenAndServe() error {
} }
s.stopChan = make(chan struct{}) s.stopChan = make(chan struct{})
s.srwch.ReadwriteHandler = s.ServerHandler
s.srwch.ServerStopChan = s.stopChan
s.srwch.ServerStopWg = &s.stopWg
s.stopWg.Add(1) s.stopWg.Add(1)
return s.handleServer(listener) return s.handleServer(listener)
} }
func (s *Server) Shutdown(ctx context.Context) error { func (s *Server) Shutdown(ctx context.Context) error {
if s.stopChan == nil { if s.stopChan == nil {
return fmt.Errorf(s.serverMessage("server must be started before stopping it")) return fmt.Errorf("server must be started before stopping it")
} }
close(s.stopChan) close(s.stopChan)
s.stopWg.Wait() s.stopWg.Wait()
@ -91,15 +95,6 @@ func (s *Server) Shutdown(ctx context.Context) error {
return nil return nil
} }
func (s *Server) ConnectionSize() int {
var sz int
s.connections.Range(func(k, v interface{}) bool {
sz++
return true
})
return sz
}
func (s *Server) serverMessage(msg string) string { func (s *Server) serverMessage(msg string) string {
return fmt.Sprintf("Server[%s]: %s", s.ServerHandler.GetName(), msg) return fmt.Sprintf("Server[%s]: %s", s.ServerHandler.GetName(), msg)
} }
@ -162,7 +157,7 @@ func (s *Server) httpHandler(ctx *fasthttp.RequestCtx) {
) )
if 0 < s.ServerHandler.GetConcurrency() { if 0 < s.ServerHandler.GetConcurrency() {
sz := s.ConnectionSize() sz := s.srwch.ConnectionSize()
if sz >= s.ServerHandler.GetConcurrency() { if sz >= s.ServerHandler.GetConcurrency() {
logging.Logger().Warnf(s.serverMessage(fmt.Sprintf("max connections size %d, refuse", sz))) logging.Logger().Warnf(s.serverMessage(fmt.Sprintf("max connections size %d, refuse", sz)))
s.onError(ctx, fasthttp.StatusServiceUnavailable, err) s.onError(ctx, fasthttp.StatusServiceUnavailable, err)
@ -170,7 +165,7 @@ func (s *Server) httpHandler(ctx *fasthttp.RequestCtx) {
} }
} }
if servlet = s.ServerHandler.Servlet(path); nil == servlet { if servlet = s.ServerHandler.(ServerHandler).Servlet(path); nil == servlet {
s.onError(ctx, fasthttp.StatusInternalServerError, err) s.onError(ctx, fasthttp.StatusInternalServerError, err)
return return
} }
@ -190,154 +185,10 @@ func (s *Server) httpHandler(ctx *fasthttp.RequestCtx) {
} }
s.stopWg.Add(1) s.stopWg.Add(1)
go s.handleConnection(servlet, servletCtx, conn) go s.srwch.HandleConnection(servlet, servletCtx, conn)
}) })
} }
func (s *Server) handleConnection(servlet Servlet, servletCtx server.ServletCtx, conn *server.Conn) {
addr := conn.RemoteAddr()
defer func() {
if nil != conn {
conn.Close()
}
servlet.OnDisconnect(servletCtx)
logging.Logger().Infof(s.serverMessage(fmt.Sprintf("Client[%s] has been disconnected", addr)))
s.stopWg.Done()
}()
logging.Logger().Infof(s.serverMessage(fmt.Sprintf("Client[%s] has been connected", addr)))
s.connections.Store(conn, true)
defer s.connections.Delete(conn)
servlet.OnConnect(servletCtx, conn)
conn.SetCloseHandler(func(code int, text string) error {
logging.Logger().Debugf("close")
return nil
})
stopChan := make(chan struct{})
servletDoneChan := make(chan struct{})
readChan := make(chan []byte)
writeChan := make(chan []byte)
readerDoneChan := make(chan struct{})
writerDoneChan := make(chan struct{})
go servlet.Handle(servletCtx, stopChan, servletDoneChan, readChan, writeChan)
go handleRead(s, conn, stopChan, readerDoneChan, readChan)
go handleWrite(s, conn, stopChan, writerDoneChan, writeChan)
select {
case <-readerDoneChan:
close(stopChan)
<-writerDoneChan
<-servletDoneChan
case <-writerDoneChan:
close(stopChan)
<-readerDoneChan
<-servletDoneChan
case <-servletDoneChan:
close(stopChan)
<-readerDoneChan
<-writerDoneChan
case <-s.stopChan:
close(stopChan)
<-readerDoneChan
<-writerDoneChan
<-servletDoneChan
}
}
func (s *Server) onError(ctx *fasthttp.RequestCtx, status int, reason error) { func (s *Server) onError(ctx *fasthttp.RequestCtx, status int, reason error) {
s.ServerHandler.OnError(s.ctx, ctx, status, reason) s.ServerHandler.(ServerHandler).OnError(s.ctx, ctx, status, reason)
}
func handleRead(s *Server, conn *server.Conn, stopChan <-chan struct{}, doneChan chan<- struct{}, readChan chan []byte) {
defer func() {
doneChan <- struct{}{}
}()
if 0 < s.ServerHandler.GetMaxMessageSize() {
conn.SetReadLimit(s.ServerHandler.GetMaxMessageSize())
}
if 0 < s.ServerHandler.GetReadTimeout() {
conn.SetReadDeadline(time.Now().Add(s.ServerHandler.GetReadTimeout()))
}
conn.SetPongHandler(func(string) error {
conn.SetReadDeadline(time.Now().Add(s.ServerHandler.GetPongTimeout()))
return nil
})
var (
message []byte
err error
)
for {
readMessageChan := make(chan struct{})
go func() {
_, message, err = conn.ReadMessage()
close(readMessageChan)
}()
select {
case <-s.stopChan:
<-readMessageChan
return
case <-readMessageChan:
}
if nil != err {
if server.IsUnexpectedCloseError(err, server.CloseGoingAway, server.CloseAbnormalClosure) {
logging.Logger().Debugf(s.serverMessage(fmt.Sprintf("Read error %v", err)))
}
return
}
readChan <- message
}
}
func handleWrite(s *Server, conn *server.Conn, stopChan <-chan struct{}, doneChan chan<- struct{}, writeChan chan []byte) {
defer func() {
doneChan <- struct{}{}
}()
ticker := time.NewTicker(s.ServerHandler.GetPingPeriod())
defer func() {
ticker.Stop()
}()
for {
select {
case message, ok := <-writeChan:
if 0 < s.ServerHandler.GetWriteTimeout() {
conn.SetWriteDeadline(time.Now().Add(s.ServerHandler.GetWriteTimeout()))
}
if !ok {
conn.WriteMessage(server.CloseMessage, []byte{})
return
}
w, err := conn.NextWriter(server.TextMessage)
if err != nil {
return
}
w.Write(message)
if err := w.Close(); nil != err {
return
}
case <-ticker.C:
conn.SetWriteDeadline(time.Now().Add(s.ServerHandler.GetPingTimeout()))
if err := conn.WriteMessage(server.PingMessage, nil); nil != err {
return
}
case <-s.stopChan:
return
}
}
} }

View File

@ -12,52 +12,25 @@ import (
) )
type Client struct { type Client struct {
server.ClientConnectionHandlers
server.ReadWriteHandlers
Name string Name string
Network string Network string
Address string Address string
TLSConfig *tls.Config
HandshakeTimeout time.Duration
KeepAlive time.Duration
LocalAddress net.Addr LocalAddress net.Addr
MaxMessageSize int64
// Per-connection buffer size for requests' reading.
// This also limits the maximum header size.
//
// Increase this buffer if your clients send multi-KB RequestURIs
// and/or multi-KB headers (for example, BIG cookies).
//
// Default buffer size is used if not set.
ReadBufferSize int
// Per-connection buffer size for responses' writing.
//
// Default buffer size is used if not set.
WriteBufferSize int
// Maximum duration for reading the full request (including body).
//
// This also limits the maximum duration for idle keep-alive
// connections.
//
// By default request read timeout is unlimited.
ReadTimeout time.Duration
// Maximum duration for writing the full response (including body).
//
// By default response write timeout is unlimited.
WriteTimeout time.Duration
PongTimeout time.Duration
PingTimeout time.Duration
PingPeriod time.Duration
EnableCompression bool
stopChan chan struct{} stopChan chan struct{}
stopWg sync.WaitGroup stopWg sync.WaitGroup
conn *server.Conn
readChan chan []byte readChan chan []byte
writeChan chan []byte writeChan chan []byte
disconnectedChan chan struct{}
reconnectedChan chan *server.Conn
crwch server.ClientRWCHandler
} }
func (c *Client) Connect() (readChan <-chan []byte, writeChan chan<- []byte, err error) { func (c *Client) Connect() (readChan <-chan []byte, writeChan chan<- []byte, err error) {
@ -81,10 +54,20 @@ func (c *Client) Connect() (readChan <-chan []byte, writeChan chan<- []byte, err
c.readChan = make(chan []byte, 256) c.readChan = make(chan []byte, 256)
c.writeChan = make(chan []byte, 256) c.writeChan = make(chan []byte, 256)
c.disconnectedChan = make(chan struct{})
c.reconnectedChan = make(chan *server.Conn)
c.stopChan = make(chan struct{}) c.stopChan = make(chan struct{})
c.crwch.ReadwriteHandler = c
c.crwch.ReadChan = c.readChan
c.crwch.WriteChan = c.writeChan
c.crwch.ClientStopChan = c.stopChan
c.crwch.ClientStopWg = &c.stopWg
c.crwch.DisconnectedChan = c.disconnectedChan
c.crwch.ReconnectedChan = c.reconnectedChan
c.stopWg.Add(1) c.stopWg.Add(1)
go c.handleConnection(conn) go c.crwch.HandleConnection(conn)
return c.readChan, c.writeChan, nil return c.readChan, c.writeChan, nil
} }
@ -119,126 +102,6 @@ func (c *Client) connect() (*server.Conn, error) {
return conn, nil return conn, nil
} }
func (c *Client) handleConnection(conn *server.Conn) {
defer func() {
if nil != conn {
conn.Close()
}
logging.Logger().Infof(c.clientMessage("disconnected"))
c.stopWg.Done()
}()
logging.Logger().Infof(c.clientMessage("connected"))
stopChan := make(chan struct{})
readerDoneChan := make(chan struct{})
writerDoneChan := make(chan struct{})
go handleClientRead(c, conn, stopChan, readerDoneChan)
go handleClientWrite(c, conn, stopChan, writerDoneChan)
select {
case <-readerDoneChan:
close(stopChan)
<-writerDoneChan
case <-writerDoneChan:
close(stopChan)
<-readerDoneChan
case <-c.stopChan:
close(stopChan)
<-readerDoneChan
<-writerDoneChan
}
}
func handleClientRead(c *Client, conn *server.Conn, stopChan <-chan struct{}, doneChan chan<- struct{}) {
defer func() {
doneChan <- struct{}{}
}()
if 0 < c.MaxMessageSize {
conn.SetReadLimit(c.MaxMessageSize)
}
if 0 < c.ReadTimeout {
conn.SetReadDeadline(time.Now().Add(c.ReadTimeout))
}
conn.SetPongHandler(func(string) error {
conn.SetReadDeadline(time.Now().Add(c.PongTimeout))
return nil
})
var (
message []byte
err error
)
for {
readMessageChan := make(chan struct{})
go func() {
_, message, err = conn.ReadMessage()
close(readMessageChan)
}()
select {
case <-stopChan:
<-readMessageChan
return
case <-readMessageChan:
}
if nil != err {
if server.IsUnexpectedCloseError(err, server.CloseGoingAway, server.CloseAbnormalClosure) {
logging.Logger().Debugf(c.clientMessage(fmt.Sprintf("Read error %v", err)))
}
return
}
c.readChan <- message
}
}
func handleClientWrite(c *Client, conn *server.Conn, stopChan <-chan struct{}, doneChan chan<- struct{}) {
defer func() {
doneChan <- struct{}{}
}()
ticker := time.NewTicker(c.PingPeriod)
defer func() {
ticker.Stop()
}()
for {
select {
case message, ok := <-c.writeChan:
if 0 < c.WriteTimeout {
conn.SetWriteDeadline(time.Now().Add(c.WriteTimeout))
}
if !ok {
conn.WriteMessage(server.CloseMessage, []byte{})
return
}
w, err := conn.NextWriter(server.TextMessage)
if err != nil {
return
}
w.Write(message)
if err := w.Close(); nil != err {
return
}
case <-ticker.C:
conn.SetWriteDeadline(time.Now().Add(c.PingTimeout))
if err := conn.WriteMessage(server.PingMessage, nil); nil != err {
return
}
case <-stopChan:
return
}
}
}
func (c *Client) Dial() (net.Conn, error) { func (c *Client) Dial() (net.Conn, error) {
if err := c.Validate(); nil != err { if err := c.Validate(); nil != err {
return nil, err return nil, err
@ -279,6 +142,13 @@ func (c *Client) Dial() (net.Conn, error) {
} }
func (c *Client) Validate() error { func (c *Client) Validate() error {
if err := c.ClientConnectionHandlers.Validate(); nil != err {
return err
}
if err := c.ReadWriteHandlers.Validate(); nil != err {
return err
}
if "" == c.Name { if "" == c.Name {
c.Name = "Client" c.Name = "Client"
} }
@ -291,33 +161,5 @@ func (c *Client) Validate() error {
return fmt.Errorf("Client: Address is not valid") return fmt.Errorf("Client: Address is not valid")
} }
if c.HandshakeTimeout <= 0 {
c.HandshakeTimeout = server.DefaultHandshakeTimeout
}
if c.MaxMessageSize <= 0 {
c.MaxMessageSize = server.DefaultMaxMessageSize
}
if c.ReadBufferSize <= 0 {
c.ReadBufferSize = server.DefaultReadBufferSize
}
if c.WriteBufferSize <= 0 {
c.WriteBufferSize = server.DefaultWriteBufferSize
}
if c.ReadTimeout <= 0 {
c.ReadTimeout = server.DefaultReadTimeout
}
if c.WriteTimeout <= 0 {
c.WriteTimeout = server.DefaultWriteTimeout
}
if c.PongTimeout <= 0 {
c.PongTimeout = server.DefaultPongTimeout
}
if c.PingTimeout <= 0 {
c.PingTimeout = server.DefaultPingTimeout
}
if c.PingPeriod <= 0 {
c.PingPeriod = server.DefaultPingPeriod
}
return nil return nil
} }

View File

@ -1,8 +1,10 @@
package net package net
import ( import (
"context"
"fmt" "fmt"
"net" "net"
"sync"
"sync/atomic" "sync/atomic"
"time" "time"
@ -10,24 +12,19 @@ import (
"git.loafle.net/commons/server-go" "git.loafle.net/commons/server-go"
) )
type Server interface { type Server struct {
server.Server ServerHandler ServerHandler
ctx server.ServerCtx
stopChan chan struct{}
stopWg sync.WaitGroup
srwch server.ServerRWCHandler
} }
func NewServer(serverHandler ServerHandler) Server { func (s *Server) ListenAndServe() error {
s := &netServer{} if s.stopChan != nil {
s.ServerHandler = serverHandler return fmt.Errorf(s.serverMessage("already running. Stop it before starting it again"))
return s
}
type netServer struct {
server.Servers
}
func (s *netServer) ListenAndServe() error {
if s.StopChan != nil {
return fmt.Errorf(s.ServerMessage("already running. Stop it before starting it again"))
} }
var ( var (
@ -41,25 +38,48 @@ func (s *netServer) ListenAndServe() error {
return err return err
} }
s.ServerCtx = s.ServerHandler.ServerCtx() s.ctx = s.ServerHandler.ServerCtx()
if nil == s.ServerCtx { if nil == s.ctx {
return fmt.Errorf(s.ServerMessage("ServerCtx is nil")) return fmt.Errorf(s.serverMessage("ServerCtx is nil"))
} }
if err = s.ServerHandler.Init(s.ServerCtx); nil != err { if err = s.ServerHandler.Init(s.ctx); nil != err {
return err return err
} }
if listener, err = s.ServerHandler.Listener(s.ServerCtx); nil != err { if listener, err = s.ServerHandler.Listener(s.ctx); nil != err {
return err return err
} }
s.StopChan = make(chan struct{}) s.stopChan = make(chan struct{})
s.StopWg.Add(1)
s.srwch.ReadwriteHandler = s.ServerHandler
s.srwch.ServerStopChan = s.stopChan
s.srwch.ServerStopWg = &s.stopWg
s.stopWg.Add(1)
return s.handleServer(listener) return s.handleServer(listener)
} }
func (s *netServer) handleServer(listener net.Listener) error { func (s *Server) Shutdown(ctx context.Context) error {
if s.stopChan == nil {
return fmt.Errorf("server must be started before stopping it")
}
close(s.stopChan)
s.stopWg.Wait()
s.ServerHandler.Destroy(s.ctx)
s.stopChan = nil
return nil
}
func (s *Server) serverMessage(msg string) string {
return fmt.Sprintf("Server[%s]: %s", s.ServerHandler.GetName(), msg)
}
func (s *Server) handleServer(listener net.Listener) error {
var ( var (
stopping atomic.Value stopping atomic.Value
netConn net.Conn netConn net.Conn
@ -71,18 +91,18 @@ func (s *netServer) handleServer(listener net.Listener) error {
listener.Close() listener.Close()
} }
s.ServerHandler.OnStop(s.ServerCtx) s.ServerHandler.OnStop(s.ctx)
logging.Logger().Infof(s.ServerMessage("Stopped")) logging.Logger().Infof(s.serverMessage("Stopped"))
s.StopWg.Done() s.stopWg.Done()
}() }()
if err = s.ServerHandler.OnStart(s.ServerCtx); nil != err { if err = s.ServerHandler.OnStart(s.ctx); nil != err {
return err return err
} }
logging.Logger().Infof(s.ServerMessage("Started")) logging.Logger().Infof(s.serverMessage("Started"))
for { for {
acceptChan := make(chan struct{}) acceptChan := make(chan struct{})
@ -90,14 +110,14 @@ func (s *netServer) handleServer(listener net.Listener) error {
go func() { go func() {
if netConn, err = listener.Accept(); err != nil { if netConn, err = listener.Accept(); err != nil {
if nil == stopping.Load() { if nil == stopping.Load() {
logging.Logger().Errorf(s.ServerMessage(fmt.Sprintf("%v", err))) logging.Logger().Errorf(s.serverMessage(fmt.Sprintf("%v", err)))
} }
} }
close(acceptChan) close(acceptChan)
}() }()
select { select {
case <-s.StopChan: case <-s.stopChan:
stopping.Store(true) stopping.Store(true)
listener.Close() listener.Close()
<-acceptChan <-acceptChan
@ -108,7 +128,7 @@ func (s *netServer) handleServer(listener net.Listener) error {
if nil != err { if nil != err {
select { select {
case <-s.StopChan: case <-s.stopChan:
return nil return nil
case <-time.After(time.Second): case <-time.After(time.Second):
} }
@ -116,9 +136,9 @@ func (s *netServer) handleServer(listener net.Listener) error {
} }
if 0 < s.ServerHandler.GetConcurrency() { if 0 < s.ServerHandler.GetConcurrency() {
sz := s.ConnectionSize() sz := s.srwch.ConnectionSize()
if sz >= s.ServerHandler.GetConcurrency() { if sz >= s.ServerHandler.GetConcurrency() {
logging.Logger().Warnf(s.ServerMessage(fmt.Sprintf("max connections size %d, refuse", sz))) logging.Logger().Warnf(s.serverMessage(fmt.Sprintf("max connections size %d, refuse", sz)))
netConn.Close() netConn.Close()
continue continue
} }
@ -126,24 +146,24 @@ func (s *netServer) handleServer(listener net.Listener) error {
servlet := s.ServerHandler.(ServerHandler).Servlet() servlet := s.ServerHandler.(ServerHandler).Servlet()
if nil == servlet { if nil == servlet {
logging.Logger().Errorf(s.ServerMessage("Servlet is nil")) logging.Logger().Errorf(s.serverMessage("Servlet is nil"))
continue continue
} }
servletCtx := servlet.ServletCtx(s.ServerCtx) servletCtx := servlet.ServletCtx(s.ctx)
if nil == servletCtx { if nil == servletCtx {
logging.Logger().Errorf(s.ServerMessage("ServletCtx is nil")) logging.Logger().Errorf(s.serverMessage("ServletCtx is nil"))
continue continue
} }
if err := servlet.Handshake(servletCtx, netConn); nil != err { if err := servlet.Handshake(servletCtx, netConn); nil != err {
logging.Logger().Infof(s.ServerMessage(fmt.Sprintf("Handshaking of Client[%s] has been failed %v", netConn.RemoteAddr(), err))) logging.Logger().Infof(s.serverMessage(fmt.Sprintf("Handshaking of Client[%s] has been failed %v", netConn.RemoteAddr(), err)))
continue continue
} }
conn := server.NewConn(netConn, true, s.ServerHandler.GetReadBufferSize(), s.ServerHandler.GetWriteBufferSize()) conn := server.NewConn(netConn, true, s.ServerHandler.GetReadBufferSize(), s.ServerHandler.GetWriteBufferSize())
s.StopWg.Add(1) s.stopWg.Add(1)
go s.HandleConnection(servlet, servletCtx, conn) go s.srwch.HandleConnection(servlet, servletCtx, conn)
} }
} }

115
readwrite-handler.go Normal file
View File

@ -0,0 +1,115 @@
package server
import (
"time"
)
type ReadWriteHandler interface {
GetMaxMessageSize() int64
GetReadBufferSize() int
GetWriteBufferSize() int
GetReadTimeout() time.Duration
GetWriteTimeout() time.Duration
GetPongTimeout() time.Duration
GetPingTimeout() time.Duration
GetPingPeriod() time.Duration
IsEnableCompression() bool
}
type ReadWriteHandlers struct {
ReadWriteHandler
MaxMessageSize int64
// Per-connection buffer size for requests' reading.
// This also limits the maximum header size.
//
// Increase this buffer if your clients send multi-KB RequestURIs
// and/or multi-KB headers (for example, BIG cookies).
//
// Default buffer size is used if not set.
ReadBufferSize int
// Per-connection buffer size for responses' writing.
//
// Default buffer size is used if not set.
WriteBufferSize int
// Maximum duration for reading the full request (including body).
//
// This also limits the maximum duration for idle keep-alive
// connections.
//
// By default request read timeout is unlimited.
ReadTimeout time.Duration
// Maximum duration for writing the full response (including body).
//
// By default response write timeout is unlimited.
WriteTimeout time.Duration
PongTimeout time.Duration
PingTimeout time.Duration
PingPeriod time.Duration
EnableCompression bool
}
func (rwh *ReadWriteHandlers) GetMaxMessageSize() int64 {
return rwh.MaxMessageSize
}
func (rwh *ReadWriteHandlers) GetReadBufferSize() int {
return rwh.ReadBufferSize
}
func (rwh *ReadWriteHandlers) GetWriteBufferSize() int {
return rwh.WriteBufferSize
}
func (rwh *ReadWriteHandlers) GetReadTimeout() time.Duration {
return rwh.ReadTimeout
}
func (rwh *ReadWriteHandlers) GetWriteTimeout() time.Duration {
return rwh.WriteTimeout
}
func (rwh *ReadWriteHandlers) GetPongTimeout() time.Duration {
return rwh.PongTimeout
}
func (rwh *ReadWriteHandlers) GetPingTimeout() time.Duration {
return rwh.PingTimeout
}
func (rwh *ReadWriteHandlers) GetPingPeriod() time.Duration {
return rwh.PingPeriod
}
func (rwh *ReadWriteHandlers) IsEnableCompression() bool {
return rwh.EnableCompression
}
func (rwh *ReadWriteHandlers) Validate() error {
if rwh.MaxMessageSize <= 0 {
rwh.MaxMessageSize = DefaultMaxMessageSize
}
if rwh.ReadBufferSize <= 0 {
rwh.ReadBufferSize = DefaultReadBufferSize
}
if rwh.WriteBufferSize <= 0 {
rwh.WriteBufferSize = DefaultWriteBufferSize
}
if rwh.ReadTimeout <= 0 {
rwh.ReadTimeout = DefaultReadTimeout
}
if rwh.WriteTimeout <= 0 {
rwh.WriteTimeout = DefaultWriteTimeout
}
if rwh.PongTimeout <= 0 {
rwh.PongTimeout = DefaultPongTimeout
}
if rwh.PingTimeout <= 0 {
rwh.PingTimeout = DefaultPingTimeout
}
if rwh.PingPeriod <= 0 {
rwh.PingPeriod = (rwh.PingTimeout * 9) / 10
}
return nil
}

101
readwrite.go Normal file
View File

@ -0,0 +1,101 @@
package server
import (
"fmt"
"io"
"time"
)
func connReadHandler(readWriteHandler ReadWriteHandler, conn *Conn, stopChan <-chan struct{}, doneChan chan<- error, readChan chan<- []byte) {
var (
message []byte
err error
)
defer func() {
doneChan <- err
}()
if 0 < readWriteHandler.GetMaxMessageSize() {
conn.SetReadLimit(readWriteHandler.GetMaxMessageSize())
}
if 0 < readWriteHandler.GetReadTimeout() {
conn.SetReadDeadline(time.Now().Add(readWriteHandler.GetReadTimeout()))
}
conn.SetPongHandler(func(string) error {
conn.SetReadDeadline(time.Now().Add(readWriteHandler.GetPongTimeout()))
return nil
})
for {
readMessageChan := make(chan struct{})
go func() {
_, message, err = conn.ReadMessage()
close(readMessageChan)
}()
select {
case <-stopChan:
<-readMessageChan
return
case <-readMessageChan:
}
if nil != err {
if IsUnexpectedCloseError(err, CloseGoingAway, CloseAbnormalClosure) {
err = fmt.Errorf("Read error %v", err)
}
return
}
readChan <- message
}
}
func connWriteHandler(readWriteHandler ReadWriteHandler, conn *Conn, stopChan <-chan struct{}, doneChan chan<- error, writeChan <-chan []byte) {
var (
wc io.WriteCloser
message []byte
ok bool
err error
)
defer func() {
doneChan <- err
}()
ticker := time.NewTicker(readWriteHandler.GetPingPeriod())
defer func() {
ticker.Stop()
}()
for {
select {
case message, ok = <-writeChan:
if 0 < readWriteHandler.GetWriteTimeout() {
conn.SetWriteDeadline(time.Now().Add(readWriteHandler.GetWriteTimeout()))
}
if !ok {
conn.WriteMessage(CloseMessage, []byte{})
return
}
wc, err = conn.NextWriter(TextMessage)
if err != nil {
return
}
wc.Write(message)
if err = wc.Close(); nil != err {
return
}
case <-ticker.C:
conn.SetWriteDeadline(time.Now().Add(readWriteHandler.GetPingTimeout()))
if err = conn.WriteMessage(PingMessage, nil); nil != err {
return
}
case <-stopChan:
return
}
}
}

View File

@ -1,26 +1,10 @@
package server package server
import (
"fmt"
"net"
"time"
)
type ServerHandler interface { type ServerHandler interface {
ConnectionHandler
ReadWriteHandler
GetName() string GetName() string
GetConcurrency() int
GetHandshakeTimeout() time.Duration
GetMaxMessageSize() int64
GetReadBufferSize() int
GetWriteBufferSize() int
GetReadTimeout() time.Duration
GetWriteTimeout() time.Duration
GetPongTimeout() time.Duration
GetPingTimeout() time.Duration
GetPingPeriod() time.Duration
IsEnableCompression() bool
ServerCtx() ServerCtx ServerCtx() ServerCtx
Init(serverCtx ServerCtx) error Init(serverCtx ServerCtx) error
@ -28,57 +12,18 @@ type ServerHandler interface {
OnStop(serverCtx ServerCtx) OnStop(serverCtx ServerCtx)
Destroy(serverCtx ServerCtx) Destroy(serverCtx ServerCtx)
Listener(serverCtx ServerCtx) (net.Listener, error)
Validate() error Validate() error
} }
type ServerHandlers struct { type ServerHandlers struct {
ServerHandler ServerHandler
ConnectionHandlers
ReadWriteHandlers
// Server name for sending in response headers. // Server name for sending in response headers.
// //
// Default server name is used if left blank. // Default server name is used if left blank.
Name string Name string
// The maximum number of concurrent connections the server may serve.
//
// DefaultConcurrency is used if not set.
Concurrency int
HandshakeTimeout time.Duration
MaxMessageSize int64
// Per-connection buffer size for requests' reading.
// This also limits the maximum header size.
//
// Increase this buffer if your clients send multi-KB RequestURIs
// and/or multi-KB headers (for example, BIG cookies).
//
// Default buffer size is used if not set.
ReadBufferSize int
// Per-connection buffer size for responses' writing.
//
// Default buffer size is used if not set.
WriteBufferSize int
// Maximum duration for reading the full request (including body).
//
// This also limits the maximum duration for idle keep-alive
// connections.
//
// By default request read timeout is unlimited.
ReadTimeout time.Duration
// Maximum duration for writing the full response (including body).
//
// By default response write timeout is unlimited.
WriteTimeout time.Duration
PongTimeout time.Duration
PingTimeout time.Duration
PingPeriod time.Duration
EnableCompression bool
} }
func (sh *ServerHandlers) ServerCtx() ServerCtx { func (sh *ServerHandlers) ServerCtx() ServerCtx {
@ -101,90 +46,21 @@ func (sh *ServerHandlers) Destroy(serverCtx ServerCtx) {
} }
func (sh *ServerHandlers) Listener(serverCtx ServerCtx) (net.Listener, error) {
return nil, fmt.Errorf("Server: Method[ServerHandler.Listener] is not implemented")
}
func (sh *ServerHandlers) GetName() string { func (sh *ServerHandlers) GetName() string {
return sh.Name return sh.Name
} }
func (sh *ServerHandlers) GetConcurrency() int {
return sh.Concurrency
}
func (sh *ServerHandlers) GetHandshakeTimeout() time.Duration {
return sh.HandshakeTimeout
}
func (sh *ServerHandlers) GetMaxMessageSize() int64 {
return sh.MaxMessageSize
}
func (sh *ServerHandlers) GetReadBufferSize() int {
return sh.ReadBufferSize
}
func (sh *ServerHandlers) GetWriteBufferSize() int {
return sh.WriteBufferSize
}
func (sh *ServerHandlers) GetReadTimeout() time.Duration {
return sh.ReadTimeout
}
func (sh *ServerHandlers) GetWriteTimeout() time.Duration {
return sh.WriteTimeout
}
func (sh *ServerHandlers) GetPongTimeout() time.Duration {
return sh.PongTimeout
}
func (sh *ServerHandlers) GetPingTimeout() time.Duration {
return sh.PingTimeout
}
func (sh *ServerHandlers) GetPingPeriod() time.Duration {
return sh.PingPeriod
}
func (sh *ServerHandlers) IsEnableCompression() bool {
return sh.EnableCompression
}
func (sh *ServerHandlers) Validate() error { func (sh *ServerHandlers) Validate() error {
if err := sh.ConnectionHandlers.Validate(); nil != err {
return err
}
if err := sh.ReadWriteHandlers.Validate(); nil != err {
return err
}
if "" == sh.Name { if "" == sh.Name {
sh.Name = "Server" sh.Name = "Server"
} }
if sh.Concurrency <= 0 {
sh.Concurrency = DefaultConcurrency
}
if sh.HandshakeTimeout <= 0 {
sh.HandshakeTimeout = DefaultHandshakeTimeout
}
if sh.MaxMessageSize <= 0 {
sh.MaxMessageSize = DefaultMaxMessageSize
}
if sh.ReadBufferSize <= 0 {
sh.ReadBufferSize = DefaultReadBufferSize
}
if sh.WriteBufferSize <= 0 {
sh.WriteBufferSize = DefaultWriteBufferSize
}
if sh.ReadTimeout <= 0 {
sh.ReadTimeout = DefaultReadTimeout
}
if sh.WriteTimeout <= 0 {
sh.WriteTimeout = DefaultWriteTimeout
}
if sh.PongTimeout <= 0 {
sh.PongTimeout = DefaultPongTimeout
}
if sh.PingTimeout <= 0 {
sh.PingTimeout = DefaultPingTimeout
}
if sh.PingPeriod <= 0 {
sh.PingPeriod = (sh.PingTimeout * 9) / 10
}
return nil return nil
} }

81
server-rwc-handler.go Normal file
View File

@ -0,0 +1,81 @@
package server
import (
"sync"
logging "git.loafle.net/commons/logging-go"
)
type ServerRWCHandler struct {
connections sync.Map
ReadwriteHandler ReadWriteHandler
ServerStopChan <-chan struct{}
ServerStopWg *sync.WaitGroup
}
func (srwch *ServerRWCHandler) ConnectionSize() int {
var sz int
srwch.connections.Range(func(k, v interface{}) bool {
sz++
return true
})
return sz
}
func (srwch *ServerRWCHandler) HandleConnection(servlet Servlet, servletCtx ServletCtx, conn *Conn) {
addr := conn.RemoteAddr()
defer func() {
if nil != conn {
conn.Close()
}
servlet.OnDisconnect(servletCtx)
logging.Logger().Infof("Client[%s] has been disconnected", addr)
srwch.ServerStopWg.Done()
}()
logging.Logger().Infof("Client[%s] has been connected", addr)
srwch.connections.Store(conn, true)
defer srwch.connections.Delete(conn)
servlet.OnConnect(servletCtx, conn)
conn.SetCloseHandler(func(code int, text string) error {
logging.Logger().Debugf("close")
return nil
})
stopChan := make(chan struct{})
servletDoneChan := make(chan struct{})
readChan := make(chan []byte)
writeChan := make(chan []byte)
readerDoneChan := make(chan error)
writerDoneChan := make(chan error)
go servlet.Handle(servletCtx, stopChan, servletDoneChan, readChan, writeChan)
go connReadHandler(srwch.ReadwriteHandler, conn, stopChan, readerDoneChan, readChan)
go connWriteHandler(srwch.ReadwriteHandler, conn, stopChan, writerDoneChan, writeChan)
select {
case <-readerDoneChan:
close(stopChan)
<-writerDoneChan
<-servletDoneChan
case <-writerDoneChan:
close(stopChan)
<-readerDoneChan
<-servletDoneChan
case <-servletDoneChan:
close(stopChan)
<-readerDoneChan
<-writerDoneChan
case <-srwch.ServerStopChan:
close(stopChan)
<-readerDoneChan
<-writerDoneChan
<-servletDoneChan
}
}

200
server.go
View File

@ -1,200 +0,0 @@
package server
import (
"context"
"fmt"
"sync"
"time"
logging "git.loafle.net/commons/logging-go"
)
type Server interface {
ListenAndServe() error
Shutdown(ctx context.Context) error
ConnectionSize() int
HandleConnection(servlet Servlet, servletCtx ServletCtx, conn *Conn)
HandleRead(conn *Conn, stopChan <-chan struct{}, doneChan chan<- struct{}, readChan chan []byte)
HandleWrite(conn *Conn, stopChan <-chan struct{}, doneChan chan<- struct{}, writeChan chan []byte)
}
type Servers struct {
ServerHandler ServerHandler
ServerCtx ServerCtx
Connections sync.Map
StopChan chan struct{}
StopWg sync.WaitGroup
}
func (s *Servers) Shutdown(ctx context.Context) error {
if s.StopChan == nil {
return fmt.Errorf("server must be started before stopping it")
}
close(s.StopChan)
s.StopWg.Wait()
s.ServerHandler.Destroy(s.ServerCtx)
s.StopChan = nil
return nil
}
func (s *Servers) ConnectionSize() int {
var sz int
s.Connections.Range(func(k, v interface{}) bool {
sz++
return true
})
return sz
}
func (s *Servers) ServerMessage(msg string) string {
return fmt.Sprintf("Server[%s]: %s", s.ServerHandler.GetName(), msg)
}
func (s *Servers) HandleConnection(servlet Servlet, servletCtx ServletCtx, conn *Conn) {
addr := conn.RemoteAddr()
defer func() {
if nil != conn {
conn.Close()
}
servlet.OnDisconnect(servletCtx)
logging.Logger().Infof(s.ServerMessage(fmt.Sprintf("Client[%s] has been disconnected", addr)))
s.StopWg.Done()
}()
logging.Logger().Infof(s.ServerMessage(fmt.Sprintf("Client[%s] has been connected", addr)))
s.Connections.Store(conn, true)
defer s.Connections.Delete(conn)
servlet.OnConnect(servletCtx, conn)
conn.SetCloseHandler(func(code int, text string) error {
logging.Logger().Debugf("close")
return nil
})
stopChan := make(chan struct{})
servletDoneChan := make(chan struct{})
readChan := make(chan []byte)
writeChan := make(chan []byte)
readerDoneChan := make(chan struct{})
writerDoneChan := make(chan struct{})
go servlet.Handle(servletCtx, stopChan, servletDoneChan, readChan, writeChan)
go s.HandleRead(conn, stopChan, readerDoneChan, readChan)
go s.HandleWrite(conn, stopChan, writerDoneChan, writeChan)
select {
case <-readerDoneChan:
close(stopChan)
<-writerDoneChan
<-servletDoneChan
case <-writerDoneChan:
close(stopChan)
<-readerDoneChan
<-servletDoneChan
case <-servletDoneChan:
close(stopChan)
<-readerDoneChan
<-writerDoneChan
case <-s.StopChan:
close(stopChan)
<-readerDoneChan
<-writerDoneChan
<-servletDoneChan
}
}
func (s *Servers) HandleRead(conn *Conn, stopChan <-chan struct{}, doneChan chan<- struct{}, readChan chan []byte) {
defer func() {
doneChan <- struct{}{}
}()
if 0 < s.ServerHandler.GetMaxMessageSize() {
conn.SetReadLimit(s.ServerHandler.GetMaxMessageSize())
}
if 0 < s.ServerHandler.GetReadTimeout() {
conn.SetReadDeadline(time.Now().Add(s.ServerHandler.GetReadTimeout()))
}
conn.SetPongHandler(func(string) error {
conn.SetReadDeadline(time.Now().Add(s.ServerHandler.GetPongTimeout()))
return nil
})
var (
message []byte
err error
)
for {
readMessageChan := make(chan struct{})
go func() {
_, message, err = conn.ReadMessage()
close(readMessageChan)
}()
select {
case <-s.StopChan:
<-readMessageChan
return
case <-readMessageChan:
}
if nil != err {
if IsUnexpectedCloseError(err, CloseGoingAway, CloseAbnormalClosure) {
logging.Logger().Debugf(s.ServerMessage(fmt.Sprintf("Read error %v", err)))
}
return
}
readChan <- message
}
}
func (s *Servers) HandleWrite(conn *Conn, stopChan <-chan struct{}, doneChan chan<- struct{}, writeChan chan []byte) {
defer func() {
doneChan <- struct{}{}
}()
ticker := time.NewTicker(s.ServerHandler.GetPingPeriod())
defer func() {
ticker.Stop()
}()
for {
select {
case message, ok := <-writeChan:
if 0 < s.ServerHandler.GetWriteTimeout() {
conn.SetWriteDeadline(time.Now().Add(s.ServerHandler.GetWriteTimeout()))
}
if !ok {
conn.WriteMessage(CloseMessage, []byte{})
return
}
w, err := conn.NextWriter(TextMessage)
if err != nil {
return
}
w.Write(message)
if err := w.Close(); nil != err {
return
}
case <-ticker.C:
conn.SetWriteDeadline(time.Now().Add(s.ServerHandler.GetPingTimeout()))
if err := conn.WriteMessage(PingMessage, nil); nil != err {
return
}
case <-stopChan:
return
}
}
}