ing
This commit is contained in:
parent
b48ffe1374
commit
aaed2f9fe3
42
client-connection-handler.go
Normal file
42
client-connection-handler.go
Normal file
|
@ -0,0 +1,42 @@
|
|||
package server
|
||||
|
||||
import (
|
||||
"time"
|
||||
)
|
||||
|
||||
type ClientConnectionHandler interface {
|
||||
ConnectionHandler
|
||||
GetReconnectInterval() time.Duration
|
||||
GetReconnectTryTime() int
|
||||
}
|
||||
|
||||
type ClientConnectionHandlers struct {
|
||||
ConnectionHandlers
|
||||
|
||||
ReconnectInterval time.Duration
|
||||
ReconnectTryTime int
|
||||
}
|
||||
|
||||
func (cch *ClientConnectionHandlers) GetReconnectInterval() time.Duration {
|
||||
return cch.ReconnectInterval
|
||||
}
|
||||
|
||||
func (cch *ClientConnectionHandlers) GetReconnectTryTime() int {
|
||||
return cch.ReconnectTryTime
|
||||
}
|
||||
|
||||
func (cch *ClientConnectionHandlers) Validate() error {
|
||||
if err := cch.ConnectionHandlers.Validate(); nil != err {
|
||||
return err
|
||||
}
|
||||
|
||||
if cch.ReconnectInterval <= 0 {
|
||||
cch.ReconnectInterval = DefaultReconnectInterval
|
||||
}
|
||||
|
||||
if cch.ReconnectTryTime <= 0 {
|
||||
cch.ReconnectTryTime = DefaultReconnectTryTime
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
70
client-rwc-handler.go
Normal file
70
client-rwc-handler.go
Normal file
|
@ -0,0 +1,70 @@
|
|||
package server
|
||||
|
||||
import (
|
||||
"io"
|
||||
"sync"
|
||||
|
||||
logging "git.loafle.net/commons/logging-go"
|
||||
)
|
||||
|
||||
type ClientRWCHandler struct {
|
||||
ReadwriteHandler ReadWriteHandler
|
||||
ReadChan chan<- []byte
|
||||
WriteChan <-chan []byte
|
||||
DisconnectedChan chan<- struct{}
|
||||
ReconnectedChan <-chan *Conn
|
||||
ClientStopChan <-chan struct{}
|
||||
ClientStopWg *sync.WaitGroup
|
||||
}
|
||||
|
||||
func (crwch *ClientRWCHandler) HandleConnection(conn *Conn) {
|
||||
|
||||
defer func() {
|
||||
if nil != conn {
|
||||
conn.Close()
|
||||
}
|
||||
logging.Logger().Infof("disconnected")
|
||||
crwch.ClientStopWg.Done()
|
||||
}()
|
||||
|
||||
logging.Logger().Infof("connected")
|
||||
|
||||
stopChan := make(chan struct{})
|
||||
|
||||
readerDoneChan := make(chan error)
|
||||
writerDoneChan := make(chan error)
|
||||
|
||||
var err error
|
||||
|
||||
for {
|
||||
if nil != err {
|
||||
if io.EOF == err || io.ErrUnexpectedEOF == err {
|
||||
crwch.DisconnectedChan <- struct{}{}
|
||||
newConn := <-crwch.ReconnectedChan
|
||||
if nil == newConn {
|
||||
return
|
||||
}
|
||||
conn = newConn
|
||||
} else {
|
||||
return
|
||||
}
|
||||
}
|
||||
|
||||
go connReadHandler(crwch.ReadwriteHandler, conn, stopChan, readerDoneChan, crwch.ReadChan)
|
||||
go connWriteHandler(crwch.ReadwriteHandler, conn, stopChan, writerDoneChan, crwch.WriteChan)
|
||||
|
||||
select {
|
||||
case err = <-readerDoneChan:
|
||||
close(stopChan)
|
||||
<-writerDoneChan
|
||||
case err = <-writerDoneChan:
|
||||
close(stopChan)
|
||||
<-readerDoneChan
|
||||
case <-crwch.ClientStopChan:
|
||||
close(stopChan)
|
||||
<-readerDoneChan
|
||||
<-writerDoneChan
|
||||
return
|
||||
}
|
||||
}
|
||||
}
|
65
connection-handler.go
Normal file
65
connection-handler.go
Normal file
|
@ -0,0 +1,65 @@
|
|||
package server
|
||||
|
||||
import (
|
||||
"crypto/tls"
|
||||
"fmt"
|
||||
"net"
|
||||
"time"
|
||||
)
|
||||
|
||||
type ConnectionHandler interface {
|
||||
GetConcurrency() int
|
||||
GetKeepAlive() time.Duration
|
||||
GetHandshakeTimeout() time.Duration
|
||||
GetTLSConfig() *tls.Config
|
||||
|
||||
Listener(serverCtx ServerCtx) (net.Listener, error)
|
||||
}
|
||||
|
||||
type ConnectionHandlers struct {
|
||||
ConnectionHandler
|
||||
|
||||
// The maximum number of concurrent connections the server may serve.
|
||||
//
|
||||
// DefaultConcurrency is used if not set.
|
||||
Concurrency int
|
||||
KeepAlive time.Duration
|
||||
HandshakeTimeout time.Duration
|
||||
TLSConfig *tls.Config
|
||||
}
|
||||
|
||||
func (ch *ConnectionHandlers) Listener(serverCtx ServerCtx) (net.Listener, error) {
|
||||
return nil, fmt.Errorf("Method[ConnectionHandler.Listener] is not implemented")
|
||||
}
|
||||
|
||||
func (ch *ConnectionHandlers) GetConcurrency() int {
|
||||
return ch.Concurrency
|
||||
}
|
||||
|
||||
func (ch *ConnectionHandlers) GetKeepAlive() time.Duration {
|
||||
return ch.KeepAlive
|
||||
}
|
||||
|
||||
func (ch *ConnectionHandlers) GetHandshakeTimeout() time.Duration {
|
||||
return ch.HandshakeTimeout
|
||||
}
|
||||
|
||||
func (ch *ConnectionHandlers) GetTLSConfig() *tls.Config {
|
||||
return ch.TLSConfig
|
||||
}
|
||||
|
||||
func (ch *ConnectionHandlers) Validate() error {
|
||||
if ch.Concurrency <= 0 {
|
||||
ch.Concurrency = DefaultConcurrency
|
||||
}
|
||||
|
||||
if ch.KeepAlive <= 0 {
|
||||
ch.KeepAlive = DefaultKeepAlive
|
||||
}
|
||||
|
||||
if ch.HandshakeTimeout <= 0 {
|
||||
ch.HandshakeTimeout = DefaultHandshakeTimeout
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
3
const.go
3
const.go
|
@ -30,4 +30,7 @@ const (
|
|||
DefaultPingTimeout = 10 * time.Second
|
||||
// DefaultPingPeriod is default value of send ping period
|
||||
DefaultPingPeriod = (DefaultPingTimeout * 9) / 10
|
||||
|
||||
DefaultReconnectInterval = 1 * time.Second
|
||||
DefaultReconnectTryTime = 10
|
||||
)
|
||||
|
|
|
@ -23,6 +23,9 @@ import (
|
|||
var errMalformedURL = errors.New("malformed ws or wss URL")
|
||||
|
||||
type Client struct {
|
||||
server.ClientConnectionHandlers
|
||||
server.ReadWriteHandlers
|
||||
|
||||
Name string
|
||||
|
||||
URL string
|
||||
|
@ -44,46 +47,18 @@ type Client struct {
|
|||
// If Proxy is nil or returns a nil *URL, no proxy is used.
|
||||
Proxy func(*http.Request) (*url.URL, error)
|
||||
|
||||
TLSConfig *tls.Config
|
||||
HandshakeTimeout time.Duration
|
||||
// ReadBufferSize and WriteBufferSize specify I/O buffer sizes. If a buffer
|
||||
// size is zero, then a useful default size is used. The I/O buffer sizes
|
||||
// do not limit the size of the messages that can be sent or received.
|
||||
ReadBufferSize, WriteBufferSize int
|
||||
// Subprotocols specifies the client's requested subprotocols.
|
||||
|
||||
// EnableCompression specifies if the client should attempt to negotiate
|
||||
// per message compression (RFC 7692). Setting this value to true does not
|
||||
// guarantee that compression will be supported. Currently only "no context
|
||||
// takeover" modes are supported.
|
||||
EnableCompression bool
|
||||
|
||||
// MaxMessageSize is the maximum size for a message read from the peer. If a
|
||||
// message exceeds the limit, the connection sends a close frame to the peer
|
||||
// and returns ErrReadLimit to the application.
|
||||
MaxMessageSize int64
|
||||
// WriteTimeout is the write deadline on the underlying network
|
||||
// connection. After a write has timed out, the websocket state is corrupt and
|
||||
// all future writes will return an error. A zero value for t means writes will
|
||||
// not time out.
|
||||
WriteTimeout time.Duration
|
||||
// ReadTimeout is the read deadline on the underlying network connection.
|
||||
// After a read has timed out, the websocket connection state is corrupt and
|
||||
// all future reads will return an error. A zero value for t means reads will
|
||||
// not time out.
|
||||
ReadTimeout time.Duration
|
||||
|
||||
PongTimeout time.Duration
|
||||
PingTimeout time.Duration
|
||||
PingPeriod time.Duration
|
||||
|
||||
serverURL *url.URL
|
||||
|
||||
stopChan chan struct{}
|
||||
stopWg sync.WaitGroup
|
||||
conn *server.Conn
|
||||
stopChan chan struct{}
|
||||
stopWg sync.WaitGroup
|
||||
|
||||
readChan chan []byte
|
||||
writeChan chan []byte
|
||||
|
||||
disconnectedChan chan struct{}
|
||||
reconnectedChan chan *server.Conn
|
||||
|
||||
crwch server.ClientRWCHandler
|
||||
}
|
||||
|
||||
func (c *Client) Connect() (readChan <-chan []byte, writeChan chan<- []byte, res *http.Response, err error) {
|
||||
|
@ -107,10 +82,20 @@ func (c *Client) Connect() (readChan <-chan []byte, writeChan chan<- []byte, res
|
|||
|
||||
c.readChan = make(chan []byte, 256)
|
||||
c.writeChan = make(chan []byte, 256)
|
||||
|
||||
c.disconnectedChan = make(chan struct{})
|
||||
c.reconnectedChan = make(chan *server.Conn)
|
||||
c.stopChan = make(chan struct{})
|
||||
|
||||
c.crwch.ReadwriteHandler = c
|
||||
c.crwch.ReadChan = c.readChan
|
||||
c.crwch.WriteChan = c.writeChan
|
||||
c.crwch.ClientStopChan = c.stopChan
|
||||
c.crwch.ClientStopWg = &c.stopWg
|
||||
c.crwch.DisconnectedChan = c.disconnectedChan
|
||||
c.crwch.ReconnectedChan = c.reconnectedChan
|
||||
|
||||
c.stopWg.Add(1)
|
||||
go c.handleConnection(conn)
|
||||
go c.crwch.HandleConnection(conn)
|
||||
|
||||
return c.readChan, c.writeChan, res, nil
|
||||
}
|
||||
|
@ -144,124 +129,6 @@ func (c *Client) connect() (*server.Conn, *http.Response, error) {
|
|||
return conn, res, nil
|
||||
}
|
||||
|
||||
func (c *Client) handleConnection(conn *server.Conn) {
|
||||
defer func() {
|
||||
if nil != conn {
|
||||
conn.Close()
|
||||
}
|
||||
logging.Logger().Infof(c.clientMessage("disconnected"))
|
||||
c.stopWg.Done()
|
||||
}()
|
||||
|
||||
logging.Logger().Infof(c.clientMessage("connected"))
|
||||
|
||||
stopChan := make(chan struct{})
|
||||
|
||||
readerDoneChan := make(chan struct{})
|
||||
writerDoneChan := make(chan struct{})
|
||||
|
||||
go handleClientRead(c, conn, stopChan, readerDoneChan)
|
||||
go handleClientWrite(c, conn, stopChan, writerDoneChan)
|
||||
|
||||
select {
|
||||
case <-readerDoneChan:
|
||||
close(stopChan)
|
||||
<-writerDoneChan
|
||||
case <-writerDoneChan:
|
||||
close(stopChan)
|
||||
<-readerDoneChan
|
||||
case <-c.stopChan:
|
||||
close(stopChan)
|
||||
<-readerDoneChan
|
||||
<-writerDoneChan
|
||||
}
|
||||
}
|
||||
|
||||
func handleClientRead(c *Client, conn *server.Conn, stopChan <-chan struct{}, doneChan chan<- struct{}) {
|
||||
defer func() {
|
||||
doneChan <- struct{}{}
|
||||
}()
|
||||
|
||||
if 0 < c.MaxMessageSize {
|
||||
conn.SetReadLimit(c.MaxMessageSize)
|
||||
}
|
||||
if 0 < c.ReadTimeout {
|
||||
conn.SetReadDeadline(time.Now().Add(c.ReadTimeout))
|
||||
}
|
||||
conn.SetPongHandler(func(string) error {
|
||||
conn.SetReadDeadline(time.Now().Add(c.PongTimeout))
|
||||
return nil
|
||||
})
|
||||
|
||||
var (
|
||||
message []byte
|
||||
err error
|
||||
)
|
||||
|
||||
for {
|
||||
readMessageChan := make(chan struct{})
|
||||
|
||||
go func() {
|
||||
_, message, err = conn.ReadMessage()
|
||||
close(readMessageChan)
|
||||
}()
|
||||
|
||||
select {
|
||||
case <-stopChan:
|
||||
<-readMessageChan
|
||||
return
|
||||
case <-readMessageChan:
|
||||
}
|
||||
|
||||
if nil != err {
|
||||
if server.IsUnexpectedCloseError(err, server.CloseGoingAway, server.CloseAbnormalClosure) {
|
||||
logging.Logger().Debugf(c.clientMessage(fmt.Sprintf("Read error %v", err)))
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
c.readChan <- message
|
||||
}
|
||||
}
|
||||
|
||||
func handleClientWrite(c *Client, conn *server.Conn, stopChan <-chan struct{}, doneChan chan<- struct{}) {
|
||||
defer func() {
|
||||
doneChan <- struct{}{}
|
||||
}()
|
||||
|
||||
ticker := time.NewTicker(c.PingPeriod)
|
||||
defer func() {
|
||||
ticker.Stop()
|
||||
}()
|
||||
for {
|
||||
select {
|
||||
case message, ok := <-c.writeChan:
|
||||
conn.SetWriteDeadline(time.Now().Add(c.WriteTimeout))
|
||||
if !ok {
|
||||
conn.WriteMessage(server.CloseMessage, []byte{})
|
||||
return
|
||||
}
|
||||
|
||||
w, err := conn.NextWriter(server.TextMessage)
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
w.Write(message)
|
||||
|
||||
if err := w.Close(); nil != err {
|
||||
return
|
||||
}
|
||||
case <-ticker.C:
|
||||
conn.SetWriteDeadline(time.Now().Add(c.PingTimeout))
|
||||
if err := conn.WriteMessage(server.PingMessage, nil); nil != err {
|
||||
return
|
||||
}
|
||||
case <-stopChan:
|
||||
return
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func (c *Client) Dial() (*server.Conn, *http.Response, error) {
|
||||
var (
|
||||
err error
|
||||
|
@ -476,6 +343,13 @@ func (c *Client) Dial() (*server.Conn, *http.Response, error) {
|
|||
}
|
||||
|
||||
func (c *Client) Validate() error {
|
||||
if err := c.ClientConnectionHandlers.Validate(); nil != err {
|
||||
return err
|
||||
}
|
||||
if err := c.ReadWriteHandlers.Validate(); nil != err {
|
||||
return err
|
||||
}
|
||||
|
||||
if "" == c.Name {
|
||||
c.Name = "Client"
|
||||
}
|
||||
|
@ -507,34 +381,6 @@ func (c *Client) Validate() error {
|
|||
c.Proxy = http.ProxyFromEnvironment
|
||||
}
|
||||
|
||||
if c.HandshakeTimeout <= 0 {
|
||||
c.HandshakeTimeout = server.DefaultHandshakeTimeout
|
||||
}
|
||||
if c.MaxMessageSize <= 0 {
|
||||
c.MaxMessageSize = server.DefaultMaxMessageSize
|
||||
}
|
||||
if c.ReadBufferSize <= 0 {
|
||||
c.ReadBufferSize = server.DefaultReadBufferSize
|
||||
}
|
||||
if c.WriteBufferSize <= 0 {
|
||||
c.WriteBufferSize = server.DefaultWriteBufferSize
|
||||
}
|
||||
if c.ReadTimeout <= 0 {
|
||||
c.ReadTimeout = server.DefaultReadTimeout
|
||||
}
|
||||
if c.WriteTimeout <= 0 {
|
||||
c.WriteTimeout = server.DefaultWriteTimeout
|
||||
}
|
||||
if c.PongTimeout <= 0 {
|
||||
c.PongTimeout = server.DefaultPongTimeout
|
||||
}
|
||||
if c.PingTimeout <= 0 {
|
||||
c.PingTimeout = server.DefaultPingTimeout
|
||||
}
|
||||
if c.PingPeriod <= 0 {
|
||||
c.PingPeriod = server.DefaultPingPeriod
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
|
|
|
@ -6,7 +6,6 @@ import (
|
|||
"net"
|
||||
"net/http"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"git.loafle.net/commons/logging-go"
|
||||
"git.loafle.net/commons/server-go"
|
||||
|
@ -16,14 +15,14 @@ import (
|
|||
type Server struct {
|
||||
ServerHandler ServerHandler
|
||||
|
||||
ctx server.ServerCtx
|
||||
ctx server.ServerCtx
|
||||
stopChan chan struct{}
|
||||
stopWg sync.WaitGroup
|
||||
|
||||
srwch server.ServerRWCHandler
|
||||
|
||||
hs *fasthttp.Server
|
||||
upgrader *Upgrader
|
||||
|
||||
connections sync.Map
|
||||
stopChan chan struct{}
|
||||
stopWg sync.WaitGroup
|
||||
}
|
||||
|
||||
func (s *Server) ListenAndServe() error {
|
||||
|
@ -59,7 +58,7 @@ func (s *Server) ListenAndServe() error {
|
|||
HandshakeTimeout: s.ServerHandler.GetHandshakeTimeout(),
|
||||
ReadBufferSize: s.ServerHandler.GetReadBufferSize(),
|
||||
WriteBufferSize: s.ServerHandler.GetWriteBufferSize(),
|
||||
CheckOrigin: s.ServerHandler.CheckOrigin,
|
||||
CheckOrigin: s.ServerHandler.(ServerHandler).CheckOrigin,
|
||||
Error: s.onError,
|
||||
EnableCompression: s.ServerHandler.IsEnableCompression(),
|
||||
}
|
||||
|
@ -73,13 +72,18 @@ func (s *Server) ListenAndServe() error {
|
|||
}
|
||||
|
||||
s.stopChan = make(chan struct{})
|
||||
|
||||
s.srwch.ReadwriteHandler = s.ServerHandler
|
||||
s.srwch.ServerStopChan = s.stopChan
|
||||
s.srwch.ServerStopWg = &s.stopWg
|
||||
|
||||
s.stopWg.Add(1)
|
||||
return s.handleServer(listener)
|
||||
}
|
||||
|
||||
func (s *Server) Shutdown(ctx context.Context) error {
|
||||
if s.stopChan == nil {
|
||||
return fmt.Errorf(s.serverMessage("server must be started before stopping it"))
|
||||
return fmt.Errorf("server must be started before stopping it")
|
||||
}
|
||||
close(s.stopChan)
|
||||
s.stopWg.Wait()
|
||||
|
@ -91,15 +95,6 @@ func (s *Server) Shutdown(ctx context.Context) error {
|
|||
return nil
|
||||
}
|
||||
|
||||
func (s *Server) ConnectionSize() int {
|
||||
var sz int
|
||||
s.connections.Range(func(k, v interface{}) bool {
|
||||
sz++
|
||||
return true
|
||||
})
|
||||
return sz
|
||||
}
|
||||
|
||||
func (s *Server) serverMessage(msg string) string {
|
||||
return fmt.Sprintf("Server[%s]: %s", s.ServerHandler.GetName(), msg)
|
||||
}
|
||||
|
@ -162,7 +157,7 @@ func (s *Server) httpHandler(ctx *fasthttp.RequestCtx) {
|
|||
)
|
||||
|
||||
if 0 < s.ServerHandler.GetConcurrency() {
|
||||
sz := s.ConnectionSize()
|
||||
sz := s.srwch.ConnectionSize()
|
||||
if sz >= s.ServerHandler.GetConcurrency() {
|
||||
logging.Logger().Warnf(s.serverMessage(fmt.Sprintf("max connections size %d, refuse", sz)))
|
||||
s.onError(ctx, fasthttp.StatusServiceUnavailable, err)
|
||||
|
@ -170,7 +165,7 @@ func (s *Server) httpHandler(ctx *fasthttp.RequestCtx) {
|
|||
}
|
||||
}
|
||||
|
||||
if servlet = s.ServerHandler.Servlet(path); nil == servlet {
|
||||
if servlet = s.ServerHandler.(ServerHandler).Servlet(path); nil == servlet {
|
||||
s.onError(ctx, fasthttp.StatusInternalServerError, err)
|
||||
return
|
||||
}
|
||||
|
@ -190,154 +185,10 @@ func (s *Server) httpHandler(ctx *fasthttp.RequestCtx) {
|
|||
}
|
||||
|
||||
s.stopWg.Add(1)
|
||||
go s.handleConnection(servlet, servletCtx, conn)
|
||||
go s.srwch.HandleConnection(servlet, servletCtx, conn)
|
||||
})
|
||||
}
|
||||
|
||||
func (s *Server) handleConnection(servlet Servlet, servletCtx server.ServletCtx, conn *server.Conn) {
|
||||
addr := conn.RemoteAddr()
|
||||
|
||||
defer func() {
|
||||
if nil != conn {
|
||||
conn.Close()
|
||||
}
|
||||
servlet.OnDisconnect(servletCtx)
|
||||
logging.Logger().Infof(s.serverMessage(fmt.Sprintf("Client[%s] has been disconnected", addr)))
|
||||
s.stopWg.Done()
|
||||
}()
|
||||
|
||||
logging.Logger().Infof(s.serverMessage(fmt.Sprintf("Client[%s] has been connected", addr)))
|
||||
|
||||
s.connections.Store(conn, true)
|
||||
defer s.connections.Delete(conn)
|
||||
|
||||
servlet.OnConnect(servletCtx, conn)
|
||||
conn.SetCloseHandler(func(code int, text string) error {
|
||||
logging.Logger().Debugf("close")
|
||||
return nil
|
||||
})
|
||||
|
||||
stopChan := make(chan struct{})
|
||||
servletDoneChan := make(chan struct{})
|
||||
|
||||
readChan := make(chan []byte)
|
||||
writeChan := make(chan []byte)
|
||||
|
||||
readerDoneChan := make(chan struct{})
|
||||
writerDoneChan := make(chan struct{})
|
||||
|
||||
go servlet.Handle(servletCtx, stopChan, servletDoneChan, readChan, writeChan)
|
||||
go handleRead(s, conn, stopChan, readerDoneChan, readChan)
|
||||
go handleWrite(s, conn, stopChan, writerDoneChan, writeChan)
|
||||
|
||||
select {
|
||||
case <-readerDoneChan:
|
||||
close(stopChan)
|
||||
<-writerDoneChan
|
||||
<-servletDoneChan
|
||||
case <-writerDoneChan:
|
||||
close(stopChan)
|
||||
<-readerDoneChan
|
||||
<-servletDoneChan
|
||||
case <-servletDoneChan:
|
||||
close(stopChan)
|
||||
<-readerDoneChan
|
||||
<-writerDoneChan
|
||||
case <-s.stopChan:
|
||||
close(stopChan)
|
||||
<-readerDoneChan
|
||||
<-writerDoneChan
|
||||
<-servletDoneChan
|
||||
}
|
||||
}
|
||||
|
||||
func (s *Server) onError(ctx *fasthttp.RequestCtx, status int, reason error) {
|
||||
s.ServerHandler.OnError(s.ctx, ctx, status, reason)
|
||||
}
|
||||
|
||||
func handleRead(s *Server, conn *server.Conn, stopChan <-chan struct{}, doneChan chan<- struct{}, readChan chan []byte) {
|
||||
defer func() {
|
||||
doneChan <- struct{}{}
|
||||
}()
|
||||
|
||||
if 0 < s.ServerHandler.GetMaxMessageSize() {
|
||||
conn.SetReadLimit(s.ServerHandler.GetMaxMessageSize())
|
||||
}
|
||||
if 0 < s.ServerHandler.GetReadTimeout() {
|
||||
conn.SetReadDeadline(time.Now().Add(s.ServerHandler.GetReadTimeout()))
|
||||
}
|
||||
conn.SetPongHandler(func(string) error {
|
||||
conn.SetReadDeadline(time.Now().Add(s.ServerHandler.GetPongTimeout()))
|
||||
return nil
|
||||
})
|
||||
|
||||
var (
|
||||
message []byte
|
||||
err error
|
||||
)
|
||||
|
||||
for {
|
||||
readMessageChan := make(chan struct{})
|
||||
|
||||
go func() {
|
||||
_, message, err = conn.ReadMessage()
|
||||
close(readMessageChan)
|
||||
}()
|
||||
|
||||
select {
|
||||
case <-s.stopChan:
|
||||
<-readMessageChan
|
||||
return
|
||||
case <-readMessageChan:
|
||||
}
|
||||
|
||||
if nil != err {
|
||||
if server.IsUnexpectedCloseError(err, server.CloseGoingAway, server.CloseAbnormalClosure) {
|
||||
logging.Logger().Debugf(s.serverMessage(fmt.Sprintf("Read error %v", err)))
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
readChan <- message
|
||||
}
|
||||
}
|
||||
|
||||
func handleWrite(s *Server, conn *server.Conn, stopChan <-chan struct{}, doneChan chan<- struct{}, writeChan chan []byte) {
|
||||
defer func() {
|
||||
doneChan <- struct{}{}
|
||||
}()
|
||||
|
||||
ticker := time.NewTicker(s.ServerHandler.GetPingPeriod())
|
||||
defer func() {
|
||||
ticker.Stop()
|
||||
}()
|
||||
for {
|
||||
select {
|
||||
case message, ok := <-writeChan:
|
||||
if 0 < s.ServerHandler.GetWriteTimeout() {
|
||||
conn.SetWriteDeadline(time.Now().Add(s.ServerHandler.GetWriteTimeout()))
|
||||
}
|
||||
if !ok {
|
||||
conn.WriteMessage(server.CloseMessage, []byte{})
|
||||
return
|
||||
}
|
||||
|
||||
w, err := conn.NextWriter(server.TextMessage)
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
w.Write(message)
|
||||
|
||||
if err := w.Close(); nil != err {
|
||||
return
|
||||
}
|
||||
case <-ticker.C:
|
||||
conn.SetWriteDeadline(time.Now().Add(s.ServerHandler.GetPingTimeout()))
|
||||
if err := conn.WriteMessage(server.PingMessage, nil); nil != err {
|
||||
return
|
||||
}
|
||||
case <-s.stopChan:
|
||||
return
|
||||
}
|
||||
}
|
||||
s.ServerHandler.(ServerHandler).OnError(s.ctx, ctx, status, reason)
|
||||
}
|
||||
|
|
222
net/client.go
222
net/client.go
|
@ -12,52 +12,25 @@ import (
|
|||
)
|
||||
|
||||
type Client struct {
|
||||
server.ClientConnectionHandlers
|
||||
server.ReadWriteHandlers
|
||||
|
||||
Name string
|
||||
|
||||
Network string
|
||||
Address string
|
||||
TLSConfig *tls.Config
|
||||
HandshakeTimeout time.Duration
|
||||
KeepAlive time.Duration
|
||||
LocalAddress net.Addr
|
||||
Network string
|
||||
Address string
|
||||
LocalAddress net.Addr
|
||||
|
||||
MaxMessageSize int64
|
||||
// Per-connection buffer size for requests' reading.
|
||||
// This also limits the maximum header size.
|
||||
//
|
||||
// Increase this buffer if your clients send multi-KB RequestURIs
|
||||
// and/or multi-KB headers (for example, BIG cookies).
|
||||
//
|
||||
// Default buffer size is used if not set.
|
||||
ReadBufferSize int
|
||||
// Per-connection buffer size for responses' writing.
|
||||
//
|
||||
// Default buffer size is used if not set.
|
||||
WriteBufferSize int
|
||||
// Maximum duration for reading the full request (including body).
|
||||
//
|
||||
// This also limits the maximum duration for idle keep-alive
|
||||
// connections.
|
||||
//
|
||||
// By default request read timeout is unlimited.
|
||||
ReadTimeout time.Duration
|
||||
stopChan chan struct{}
|
||||
stopWg sync.WaitGroup
|
||||
|
||||
// Maximum duration for writing the full response (including body).
|
||||
//
|
||||
// By default response write timeout is unlimited.
|
||||
WriteTimeout time.Duration
|
||||
|
||||
PongTimeout time.Duration
|
||||
PingTimeout time.Duration
|
||||
PingPeriod time.Duration
|
||||
|
||||
EnableCompression bool
|
||||
|
||||
stopChan chan struct{}
|
||||
stopWg sync.WaitGroup
|
||||
conn *server.Conn
|
||||
readChan chan []byte
|
||||
writeChan chan []byte
|
||||
|
||||
disconnectedChan chan struct{}
|
||||
reconnectedChan chan *server.Conn
|
||||
|
||||
crwch server.ClientRWCHandler
|
||||
}
|
||||
|
||||
func (c *Client) Connect() (readChan <-chan []byte, writeChan chan<- []byte, err error) {
|
||||
|
@ -81,10 +54,20 @@ func (c *Client) Connect() (readChan <-chan []byte, writeChan chan<- []byte, err
|
|||
|
||||
c.readChan = make(chan []byte, 256)
|
||||
c.writeChan = make(chan []byte, 256)
|
||||
|
||||
c.disconnectedChan = make(chan struct{})
|
||||
c.reconnectedChan = make(chan *server.Conn)
|
||||
c.stopChan = make(chan struct{})
|
||||
|
||||
c.crwch.ReadwriteHandler = c
|
||||
c.crwch.ReadChan = c.readChan
|
||||
c.crwch.WriteChan = c.writeChan
|
||||
c.crwch.ClientStopChan = c.stopChan
|
||||
c.crwch.ClientStopWg = &c.stopWg
|
||||
c.crwch.DisconnectedChan = c.disconnectedChan
|
||||
c.crwch.ReconnectedChan = c.reconnectedChan
|
||||
|
||||
c.stopWg.Add(1)
|
||||
go c.handleConnection(conn)
|
||||
go c.crwch.HandleConnection(conn)
|
||||
|
||||
return c.readChan, c.writeChan, nil
|
||||
}
|
||||
|
@ -119,126 +102,6 @@ func (c *Client) connect() (*server.Conn, error) {
|
|||
return conn, nil
|
||||
}
|
||||
|
||||
func (c *Client) handleConnection(conn *server.Conn) {
|
||||
defer func() {
|
||||
if nil != conn {
|
||||
conn.Close()
|
||||
}
|
||||
logging.Logger().Infof(c.clientMessage("disconnected"))
|
||||
c.stopWg.Done()
|
||||
}()
|
||||
|
||||
logging.Logger().Infof(c.clientMessage("connected"))
|
||||
|
||||
stopChan := make(chan struct{})
|
||||
|
||||
readerDoneChan := make(chan struct{})
|
||||
writerDoneChan := make(chan struct{})
|
||||
|
||||
go handleClientRead(c, conn, stopChan, readerDoneChan)
|
||||
go handleClientWrite(c, conn, stopChan, writerDoneChan)
|
||||
|
||||
select {
|
||||
case <-readerDoneChan:
|
||||
close(stopChan)
|
||||
<-writerDoneChan
|
||||
case <-writerDoneChan:
|
||||
close(stopChan)
|
||||
<-readerDoneChan
|
||||
case <-c.stopChan:
|
||||
close(stopChan)
|
||||
<-readerDoneChan
|
||||
<-writerDoneChan
|
||||
}
|
||||
}
|
||||
|
||||
func handleClientRead(c *Client, conn *server.Conn, stopChan <-chan struct{}, doneChan chan<- struct{}) {
|
||||
defer func() {
|
||||
doneChan <- struct{}{}
|
||||
}()
|
||||
|
||||
if 0 < c.MaxMessageSize {
|
||||
conn.SetReadLimit(c.MaxMessageSize)
|
||||
}
|
||||
if 0 < c.ReadTimeout {
|
||||
conn.SetReadDeadline(time.Now().Add(c.ReadTimeout))
|
||||
}
|
||||
conn.SetPongHandler(func(string) error {
|
||||
conn.SetReadDeadline(time.Now().Add(c.PongTimeout))
|
||||
return nil
|
||||
})
|
||||
|
||||
var (
|
||||
message []byte
|
||||
err error
|
||||
)
|
||||
|
||||
for {
|
||||
readMessageChan := make(chan struct{})
|
||||
|
||||
go func() {
|
||||
_, message, err = conn.ReadMessage()
|
||||
close(readMessageChan)
|
||||
}()
|
||||
|
||||
select {
|
||||
case <-stopChan:
|
||||
<-readMessageChan
|
||||
return
|
||||
case <-readMessageChan:
|
||||
}
|
||||
|
||||
if nil != err {
|
||||
if server.IsUnexpectedCloseError(err, server.CloseGoingAway, server.CloseAbnormalClosure) {
|
||||
logging.Logger().Debugf(c.clientMessage(fmt.Sprintf("Read error %v", err)))
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
c.readChan <- message
|
||||
}
|
||||
}
|
||||
|
||||
func handleClientWrite(c *Client, conn *server.Conn, stopChan <-chan struct{}, doneChan chan<- struct{}) {
|
||||
defer func() {
|
||||
doneChan <- struct{}{}
|
||||
}()
|
||||
|
||||
ticker := time.NewTicker(c.PingPeriod)
|
||||
defer func() {
|
||||
ticker.Stop()
|
||||
}()
|
||||
for {
|
||||
select {
|
||||
case message, ok := <-c.writeChan:
|
||||
if 0 < c.WriteTimeout {
|
||||
conn.SetWriteDeadline(time.Now().Add(c.WriteTimeout))
|
||||
}
|
||||
if !ok {
|
||||
conn.WriteMessage(server.CloseMessage, []byte{})
|
||||
return
|
||||
}
|
||||
|
||||
w, err := conn.NextWriter(server.TextMessage)
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
w.Write(message)
|
||||
|
||||
if err := w.Close(); nil != err {
|
||||
return
|
||||
}
|
||||
case <-ticker.C:
|
||||
conn.SetWriteDeadline(time.Now().Add(c.PingTimeout))
|
||||
if err := conn.WriteMessage(server.PingMessage, nil); nil != err {
|
||||
return
|
||||
}
|
||||
case <-stopChan:
|
||||
return
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func (c *Client) Dial() (net.Conn, error) {
|
||||
if err := c.Validate(); nil != err {
|
||||
return nil, err
|
||||
|
@ -279,6 +142,13 @@ func (c *Client) Dial() (net.Conn, error) {
|
|||
}
|
||||
|
||||
func (c *Client) Validate() error {
|
||||
if err := c.ClientConnectionHandlers.Validate(); nil != err {
|
||||
return err
|
||||
}
|
||||
if err := c.ReadWriteHandlers.Validate(); nil != err {
|
||||
return err
|
||||
}
|
||||
|
||||
if "" == c.Name {
|
||||
c.Name = "Client"
|
||||
}
|
||||
|
@ -291,33 +161,5 @@ func (c *Client) Validate() error {
|
|||
return fmt.Errorf("Client: Address is not valid")
|
||||
}
|
||||
|
||||
if c.HandshakeTimeout <= 0 {
|
||||
c.HandshakeTimeout = server.DefaultHandshakeTimeout
|
||||
}
|
||||
if c.MaxMessageSize <= 0 {
|
||||
c.MaxMessageSize = server.DefaultMaxMessageSize
|
||||
}
|
||||
if c.ReadBufferSize <= 0 {
|
||||
c.ReadBufferSize = server.DefaultReadBufferSize
|
||||
}
|
||||
if c.WriteBufferSize <= 0 {
|
||||
c.WriteBufferSize = server.DefaultWriteBufferSize
|
||||
}
|
||||
if c.ReadTimeout <= 0 {
|
||||
c.ReadTimeout = server.DefaultReadTimeout
|
||||
}
|
||||
if c.WriteTimeout <= 0 {
|
||||
c.WriteTimeout = server.DefaultWriteTimeout
|
||||
}
|
||||
if c.PongTimeout <= 0 {
|
||||
c.PongTimeout = server.DefaultPongTimeout
|
||||
}
|
||||
if c.PingTimeout <= 0 {
|
||||
c.PingTimeout = server.DefaultPingTimeout
|
||||
}
|
||||
if c.PingPeriod <= 0 {
|
||||
c.PingPeriod = server.DefaultPingPeriod
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
|
100
net/server.go
100
net/server.go
|
@ -1,8 +1,10 @@
|
|||
package net
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"net"
|
||||
"sync"
|
||||
"sync/atomic"
|
||||
"time"
|
||||
|
||||
|
@ -10,24 +12,19 @@ import (
|
|||
"git.loafle.net/commons/server-go"
|
||||
)
|
||||
|
||||
type Server interface {
|
||||
server.Server
|
||||
type Server struct {
|
||||
ServerHandler ServerHandler
|
||||
|
||||
ctx server.ServerCtx
|
||||
stopChan chan struct{}
|
||||
stopWg sync.WaitGroup
|
||||
|
||||
srwch server.ServerRWCHandler
|
||||
}
|
||||
|
||||
func NewServer(serverHandler ServerHandler) Server {
|
||||
s := &netServer{}
|
||||
s.ServerHandler = serverHandler
|
||||
|
||||
return s
|
||||
}
|
||||
|
||||
type netServer struct {
|
||||
server.Servers
|
||||
}
|
||||
|
||||
func (s *netServer) ListenAndServe() error {
|
||||
if s.StopChan != nil {
|
||||
return fmt.Errorf(s.ServerMessage("already running. Stop it before starting it again"))
|
||||
func (s *Server) ListenAndServe() error {
|
||||
if s.stopChan != nil {
|
||||
return fmt.Errorf(s.serverMessage("already running. Stop it before starting it again"))
|
||||
}
|
||||
|
||||
var (
|
||||
|
@ -41,25 +38,48 @@ func (s *netServer) ListenAndServe() error {
|
|||
return err
|
||||
}
|
||||
|
||||
s.ServerCtx = s.ServerHandler.ServerCtx()
|
||||
if nil == s.ServerCtx {
|
||||
return fmt.Errorf(s.ServerMessage("ServerCtx is nil"))
|
||||
s.ctx = s.ServerHandler.ServerCtx()
|
||||
if nil == s.ctx {
|
||||
return fmt.Errorf(s.serverMessage("ServerCtx is nil"))
|
||||
}
|
||||
|
||||
if err = s.ServerHandler.Init(s.ServerCtx); nil != err {
|
||||
if err = s.ServerHandler.Init(s.ctx); nil != err {
|
||||
return err
|
||||
}
|
||||
|
||||
if listener, err = s.ServerHandler.Listener(s.ServerCtx); nil != err {
|
||||
if listener, err = s.ServerHandler.Listener(s.ctx); nil != err {
|
||||
return err
|
||||
}
|
||||
|
||||
s.StopChan = make(chan struct{})
|
||||
s.StopWg.Add(1)
|
||||
s.stopChan = make(chan struct{})
|
||||
|
||||
s.srwch.ReadwriteHandler = s.ServerHandler
|
||||
s.srwch.ServerStopChan = s.stopChan
|
||||
s.srwch.ServerStopWg = &s.stopWg
|
||||
|
||||
s.stopWg.Add(1)
|
||||
return s.handleServer(listener)
|
||||
}
|
||||
|
||||
func (s *netServer) handleServer(listener net.Listener) error {
|
||||
func (s *Server) Shutdown(ctx context.Context) error {
|
||||
if s.stopChan == nil {
|
||||
return fmt.Errorf("server must be started before stopping it")
|
||||
}
|
||||
close(s.stopChan)
|
||||
s.stopWg.Wait()
|
||||
|
||||
s.ServerHandler.Destroy(s.ctx)
|
||||
|
||||
s.stopChan = nil
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (s *Server) serverMessage(msg string) string {
|
||||
return fmt.Sprintf("Server[%s]: %s", s.ServerHandler.GetName(), msg)
|
||||
}
|
||||
|
||||
func (s *Server) handleServer(listener net.Listener) error {
|
||||
var (
|
||||
stopping atomic.Value
|
||||
netConn net.Conn
|
||||
|
@ -71,18 +91,18 @@ func (s *netServer) handleServer(listener net.Listener) error {
|
|||
listener.Close()
|
||||
}
|
||||
|
||||
s.ServerHandler.OnStop(s.ServerCtx)
|
||||
s.ServerHandler.OnStop(s.ctx)
|
||||
|
||||
logging.Logger().Infof(s.ServerMessage("Stopped"))
|
||||
logging.Logger().Infof(s.serverMessage("Stopped"))
|
||||
|
||||
s.StopWg.Done()
|
||||
s.stopWg.Done()
|
||||
}()
|
||||
|
||||
if err = s.ServerHandler.OnStart(s.ServerCtx); nil != err {
|
||||
if err = s.ServerHandler.OnStart(s.ctx); nil != err {
|
||||
return err
|
||||
}
|
||||
|
||||
logging.Logger().Infof(s.ServerMessage("Started"))
|
||||
logging.Logger().Infof(s.serverMessage("Started"))
|
||||
|
||||
for {
|
||||
acceptChan := make(chan struct{})
|
||||
|
@ -90,14 +110,14 @@ func (s *netServer) handleServer(listener net.Listener) error {
|
|||
go func() {
|
||||
if netConn, err = listener.Accept(); err != nil {
|
||||
if nil == stopping.Load() {
|
||||
logging.Logger().Errorf(s.ServerMessage(fmt.Sprintf("%v", err)))
|
||||
logging.Logger().Errorf(s.serverMessage(fmt.Sprintf("%v", err)))
|
||||
}
|
||||
}
|
||||
close(acceptChan)
|
||||
}()
|
||||
|
||||
select {
|
||||
case <-s.StopChan:
|
||||
case <-s.stopChan:
|
||||
stopping.Store(true)
|
||||
listener.Close()
|
||||
<-acceptChan
|
||||
|
@ -108,7 +128,7 @@ func (s *netServer) handleServer(listener net.Listener) error {
|
|||
|
||||
if nil != err {
|
||||
select {
|
||||
case <-s.StopChan:
|
||||
case <-s.stopChan:
|
||||
return nil
|
||||
case <-time.After(time.Second):
|
||||
}
|
||||
|
@ -116,9 +136,9 @@ func (s *netServer) handleServer(listener net.Listener) error {
|
|||
}
|
||||
|
||||
if 0 < s.ServerHandler.GetConcurrency() {
|
||||
sz := s.ConnectionSize()
|
||||
sz := s.srwch.ConnectionSize()
|
||||
if sz >= s.ServerHandler.GetConcurrency() {
|
||||
logging.Logger().Warnf(s.ServerMessage(fmt.Sprintf("max connections size %d, refuse", sz)))
|
||||
logging.Logger().Warnf(s.serverMessage(fmt.Sprintf("max connections size %d, refuse", sz)))
|
||||
netConn.Close()
|
||||
continue
|
||||
}
|
||||
|
@ -126,24 +146,24 @@ func (s *netServer) handleServer(listener net.Listener) error {
|
|||
|
||||
servlet := s.ServerHandler.(ServerHandler).Servlet()
|
||||
if nil == servlet {
|
||||
logging.Logger().Errorf(s.ServerMessage("Servlet is nil"))
|
||||
logging.Logger().Errorf(s.serverMessage("Servlet is nil"))
|
||||
continue
|
||||
}
|
||||
|
||||
servletCtx := servlet.ServletCtx(s.ServerCtx)
|
||||
servletCtx := servlet.ServletCtx(s.ctx)
|
||||
if nil == servletCtx {
|
||||
logging.Logger().Errorf(s.ServerMessage("ServletCtx is nil"))
|
||||
logging.Logger().Errorf(s.serverMessage("ServletCtx is nil"))
|
||||
continue
|
||||
}
|
||||
|
||||
if err := servlet.Handshake(servletCtx, netConn); nil != err {
|
||||
logging.Logger().Infof(s.ServerMessage(fmt.Sprintf("Handshaking of Client[%s] has been failed %v", netConn.RemoteAddr(), err)))
|
||||
logging.Logger().Infof(s.serverMessage(fmt.Sprintf("Handshaking of Client[%s] has been failed %v", netConn.RemoteAddr(), err)))
|
||||
continue
|
||||
}
|
||||
|
||||
conn := server.NewConn(netConn, true, s.ServerHandler.GetReadBufferSize(), s.ServerHandler.GetWriteBufferSize())
|
||||
|
||||
s.StopWg.Add(1)
|
||||
go s.HandleConnection(servlet, servletCtx, conn)
|
||||
s.stopWg.Add(1)
|
||||
go s.srwch.HandleConnection(servlet, servletCtx, conn)
|
||||
}
|
||||
}
|
||||
|
|
115
readwrite-handler.go
Normal file
115
readwrite-handler.go
Normal file
|
@ -0,0 +1,115 @@
|
|||
package server
|
||||
|
||||
import (
|
||||
"time"
|
||||
)
|
||||
|
||||
type ReadWriteHandler interface {
|
||||
GetMaxMessageSize() int64
|
||||
GetReadBufferSize() int
|
||||
GetWriteBufferSize() int
|
||||
GetReadTimeout() time.Duration
|
||||
GetWriteTimeout() time.Duration
|
||||
GetPongTimeout() time.Duration
|
||||
GetPingTimeout() time.Duration
|
||||
GetPingPeriod() time.Duration
|
||||
|
||||
IsEnableCompression() bool
|
||||
}
|
||||
|
||||
type ReadWriteHandlers struct {
|
||||
ReadWriteHandler
|
||||
|
||||
MaxMessageSize int64
|
||||
// Per-connection buffer size for requests' reading.
|
||||
// This also limits the maximum header size.
|
||||
//
|
||||
// Increase this buffer if your clients send multi-KB RequestURIs
|
||||
// and/or multi-KB headers (for example, BIG cookies).
|
||||
//
|
||||
// Default buffer size is used if not set.
|
||||
ReadBufferSize int
|
||||
// Per-connection buffer size for responses' writing.
|
||||
//
|
||||
// Default buffer size is used if not set.
|
||||
WriteBufferSize int
|
||||
// Maximum duration for reading the full request (including body).
|
||||
//
|
||||
// This also limits the maximum duration for idle keep-alive
|
||||
// connections.
|
||||
//
|
||||
// By default request read timeout is unlimited.
|
||||
ReadTimeout time.Duration
|
||||
|
||||
// Maximum duration for writing the full response (including body).
|
||||
//
|
||||
// By default response write timeout is unlimited.
|
||||
WriteTimeout time.Duration
|
||||
|
||||
PongTimeout time.Duration
|
||||
PingTimeout time.Duration
|
||||
PingPeriod time.Duration
|
||||
|
||||
EnableCompression bool
|
||||
}
|
||||
|
||||
func (rwh *ReadWriteHandlers) GetMaxMessageSize() int64 {
|
||||
return rwh.MaxMessageSize
|
||||
}
|
||||
|
||||
func (rwh *ReadWriteHandlers) GetReadBufferSize() int {
|
||||
return rwh.ReadBufferSize
|
||||
}
|
||||
|
||||
func (rwh *ReadWriteHandlers) GetWriteBufferSize() int {
|
||||
return rwh.WriteBufferSize
|
||||
}
|
||||
|
||||
func (rwh *ReadWriteHandlers) GetReadTimeout() time.Duration {
|
||||
return rwh.ReadTimeout
|
||||
}
|
||||
func (rwh *ReadWriteHandlers) GetWriteTimeout() time.Duration {
|
||||
return rwh.WriteTimeout
|
||||
}
|
||||
func (rwh *ReadWriteHandlers) GetPongTimeout() time.Duration {
|
||||
return rwh.PongTimeout
|
||||
}
|
||||
func (rwh *ReadWriteHandlers) GetPingTimeout() time.Duration {
|
||||
return rwh.PingTimeout
|
||||
}
|
||||
func (rwh *ReadWriteHandlers) GetPingPeriod() time.Duration {
|
||||
return rwh.PingPeriod
|
||||
}
|
||||
|
||||
func (rwh *ReadWriteHandlers) IsEnableCompression() bool {
|
||||
return rwh.EnableCompression
|
||||
}
|
||||
|
||||
func (rwh *ReadWriteHandlers) Validate() error {
|
||||
if rwh.MaxMessageSize <= 0 {
|
||||
rwh.MaxMessageSize = DefaultMaxMessageSize
|
||||
}
|
||||
if rwh.ReadBufferSize <= 0 {
|
||||
rwh.ReadBufferSize = DefaultReadBufferSize
|
||||
}
|
||||
if rwh.WriteBufferSize <= 0 {
|
||||
rwh.WriteBufferSize = DefaultWriteBufferSize
|
||||
}
|
||||
if rwh.ReadTimeout <= 0 {
|
||||
rwh.ReadTimeout = DefaultReadTimeout
|
||||
}
|
||||
if rwh.WriteTimeout <= 0 {
|
||||
rwh.WriteTimeout = DefaultWriteTimeout
|
||||
}
|
||||
if rwh.PongTimeout <= 0 {
|
||||
rwh.PongTimeout = DefaultPongTimeout
|
||||
}
|
||||
if rwh.PingTimeout <= 0 {
|
||||
rwh.PingTimeout = DefaultPingTimeout
|
||||
}
|
||||
if rwh.PingPeriod <= 0 {
|
||||
rwh.PingPeriod = (rwh.PingTimeout * 9) / 10
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
101
readwrite.go
Normal file
101
readwrite.go
Normal file
|
@ -0,0 +1,101 @@
|
|||
package server
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"io"
|
||||
"time"
|
||||
)
|
||||
|
||||
func connReadHandler(readWriteHandler ReadWriteHandler, conn *Conn, stopChan <-chan struct{}, doneChan chan<- error, readChan chan<- []byte) {
|
||||
var (
|
||||
message []byte
|
||||
err error
|
||||
)
|
||||
|
||||
defer func() {
|
||||
doneChan <- err
|
||||
}()
|
||||
|
||||
if 0 < readWriteHandler.GetMaxMessageSize() {
|
||||
conn.SetReadLimit(readWriteHandler.GetMaxMessageSize())
|
||||
}
|
||||
if 0 < readWriteHandler.GetReadTimeout() {
|
||||
conn.SetReadDeadline(time.Now().Add(readWriteHandler.GetReadTimeout()))
|
||||
}
|
||||
conn.SetPongHandler(func(string) error {
|
||||
conn.SetReadDeadline(time.Now().Add(readWriteHandler.GetPongTimeout()))
|
||||
return nil
|
||||
})
|
||||
|
||||
for {
|
||||
readMessageChan := make(chan struct{})
|
||||
|
||||
go func() {
|
||||
_, message, err = conn.ReadMessage()
|
||||
close(readMessageChan)
|
||||
}()
|
||||
|
||||
select {
|
||||
case <-stopChan:
|
||||
<-readMessageChan
|
||||
return
|
||||
case <-readMessageChan:
|
||||
}
|
||||
|
||||
if nil != err {
|
||||
if IsUnexpectedCloseError(err, CloseGoingAway, CloseAbnormalClosure) {
|
||||
err = fmt.Errorf("Read error %v", err)
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
readChan <- message
|
||||
}
|
||||
}
|
||||
|
||||
func connWriteHandler(readWriteHandler ReadWriteHandler, conn *Conn, stopChan <-chan struct{}, doneChan chan<- error, writeChan <-chan []byte) {
|
||||
var (
|
||||
wc io.WriteCloser
|
||||
message []byte
|
||||
ok bool
|
||||
err error
|
||||
)
|
||||
|
||||
defer func() {
|
||||
doneChan <- err
|
||||
}()
|
||||
|
||||
ticker := time.NewTicker(readWriteHandler.GetPingPeriod())
|
||||
defer func() {
|
||||
ticker.Stop()
|
||||
}()
|
||||
for {
|
||||
select {
|
||||
case message, ok = <-writeChan:
|
||||
if 0 < readWriteHandler.GetWriteTimeout() {
|
||||
conn.SetWriteDeadline(time.Now().Add(readWriteHandler.GetWriteTimeout()))
|
||||
}
|
||||
if !ok {
|
||||
conn.WriteMessage(CloseMessage, []byte{})
|
||||
return
|
||||
}
|
||||
|
||||
wc, err = conn.NextWriter(TextMessage)
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
wc.Write(message)
|
||||
|
||||
if err = wc.Close(); nil != err {
|
||||
return
|
||||
}
|
||||
case <-ticker.C:
|
||||
conn.SetWriteDeadline(time.Now().Add(readWriteHandler.GetPingTimeout()))
|
||||
if err = conn.WriteMessage(PingMessage, nil); nil != err {
|
||||
return
|
||||
}
|
||||
case <-stopChan:
|
||||
return
|
||||
}
|
||||
}
|
||||
}
|
|
@ -1,26 +1,10 @@
|
|||
package server
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"net"
|
||||
"time"
|
||||
)
|
||||
|
||||
type ServerHandler interface {
|
||||
ConnectionHandler
|
||||
ReadWriteHandler
|
||||
|
||||
GetName() string
|
||||
GetConcurrency() int
|
||||
GetHandshakeTimeout() time.Duration
|
||||
GetMaxMessageSize() int64
|
||||
GetReadBufferSize() int
|
||||
GetWriteBufferSize() int
|
||||
GetReadTimeout() time.Duration
|
||||
GetWriteTimeout() time.Duration
|
||||
GetPongTimeout() time.Duration
|
||||
GetPingTimeout() time.Duration
|
||||
GetPingPeriod() time.Duration
|
||||
|
||||
IsEnableCompression() bool
|
||||
|
||||
ServerCtx() ServerCtx
|
||||
|
||||
Init(serverCtx ServerCtx) error
|
||||
|
@ -28,57 +12,18 @@ type ServerHandler interface {
|
|||
OnStop(serverCtx ServerCtx)
|
||||
Destroy(serverCtx ServerCtx)
|
||||
|
||||
Listener(serverCtx ServerCtx) (net.Listener, error)
|
||||
|
||||
Validate() error
|
||||
}
|
||||
|
||||
type ServerHandlers struct {
|
||||
ServerHandler
|
||||
ConnectionHandlers
|
||||
ReadWriteHandlers
|
||||
|
||||
// Server name for sending in response headers.
|
||||
//
|
||||
// Default server name is used if left blank.
|
||||
Name string
|
||||
|
||||
// The maximum number of concurrent connections the server may serve.
|
||||
//
|
||||
// DefaultConcurrency is used if not set.
|
||||
Concurrency int
|
||||
|
||||
HandshakeTimeout time.Duration
|
||||
|
||||
MaxMessageSize int64
|
||||
// Per-connection buffer size for requests' reading.
|
||||
// This also limits the maximum header size.
|
||||
//
|
||||
// Increase this buffer if your clients send multi-KB RequestURIs
|
||||
// and/or multi-KB headers (for example, BIG cookies).
|
||||
//
|
||||
// Default buffer size is used if not set.
|
||||
ReadBufferSize int
|
||||
// Per-connection buffer size for responses' writing.
|
||||
//
|
||||
// Default buffer size is used if not set.
|
||||
WriteBufferSize int
|
||||
// Maximum duration for reading the full request (including body).
|
||||
//
|
||||
// This also limits the maximum duration for idle keep-alive
|
||||
// connections.
|
||||
//
|
||||
// By default request read timeout is unlimited.
|
||||
ReadTimeout time.Duration
|
||||
|
||||
// Maximum duration for writing the full response (including body).
|
||||
//
|
||||
// By default response write timeout is unlimited.
|
||||
WriteTimeout time.Duration
|
||||
|
||||
PongTimeout time.Duration
|
||||
PingTimeout time.Duration
|
||||
PingPeriod time.Duration
|
||||
|
||||
EnableCompression bool
|
||||
}
|
||||
|
||||
func (sh *ServerHandlers) ServerCtx() ServerCtx {
|
||||
|
@ -101,90 +46,21 @@ func (sh *ServerHandlers) Destroy(serverCtx ServerCtx) {
|
|||
|
||||
}
|
||||
|
||||
func (sh *ServerHandlers) Listener(serverCtx ServerCtx) (net.Listener, error) {
|
||||
return nil, fmt.Errorf("Server: Method[ServerHandler.Listener] is not implemented")
|
||||
}
|
||||
|
||||
func (sh *ServerHandlers) GetName() string {
|
||||
return sh.Name
|
||||
}
|
||||
|
||||
func (sh *ServerHandlers) GetConcurrency() int {
|
||||
return sh.Concurrency
|
||||
}
|
||||
|
||||
func (sh *ServerHandlers) GetHandshakeTimeout() time.Duration {
|
||||
return sh.HandshakeTimeout
|
||||
}
|
||||
|
||||
func (sh *ServerHandlers) GetMaxMessageSize() int64 {
|
||||
return sh.MaxMessageSize
|
||||
}
|
||||
|
||||
func (sh *ServerHandlers) GetReadBufferSize() int {
|
||||
return sh.ReadBufferSize
|
||||
}
|
||||
|
||||
func (sh *ServerHandlers) GetWriteBufferSize() int {
|
||||
return sh.WriteBufferSize
|
||||
}
|
||||
|
||||
func (sh *ServerHandlers) GetReadTimeout() time.Duration {
|
||||
return sh.ReadTimeout
|
||||
}
|
||||
func (sh *ServerHandlers) GetWriteTimeout() time.Duration {
|
||||
return sh.WriteTimeout
|
||||
}
|
||||
func (sh *ServerHandlers) GetPongTimeout() time.Duration {
|
||||
return sh.PongTimeout
|
||||
}
|
||||
func (sh *ServerHandlers) GetPingTimeout() time.Duration {
|
||||
return sh.PingTimeout
|
||||
}
|
||||
func (sh *ServerHandlers) GetPingPeriod() time.Duration {
|
||||
return sh.PingPeriod
|
||||
}
|
||||
|
||||
func (sh *ServerHandlers) IsEnableCompression() bool {
|
||||
return sh.EnableCompression
|
||||
}
|
||||
|
||||
func (sh *ServerHandlers) Validate() error {
|
||||
if err := sh.ConnectionHandlers.Validate(); nil != err {
|
||||
return err
|
||||
}
|
||||
if err := sh.ReadWriteHandlers.Validate(); nil != err {
|
||||
return err
|
||||
}
|
||||
|
||||
if "" == sh.Name {
|
||||
sh.Name = "Server"
|
||||
}
|
||||
|
||||
if sh.Concurrency <= 0 {
|
||||
sh.Concurrency = DefaultConcurrency
|
||||
}
|
||||
|
||||
if sh.HandshakeTimeout <= 0 {
|
||||
sh.HandshakeTimeout = DefaultHandshakeTimeout
|
||||
}
|
||||
if sh.MaxMessageSize <= 0 {
|
||||
sh.MaxMessageSize = DefaultMaxMessageSize
|
||||
}
|
||||
if sh.ReadBufferSize <= 0 {
|
||||
sh.ReadBufferSize = DefaultReadBufferSize
|
||||
}
|
||||
if sh.WriteBufferSize <= 0 {
|
||||
sh.WriteBufferSize = DefaultWriteBufferSize
|
||||
}
|
||||
if sh.ReadTimeout <= 0 {
|
||||
sh.ReadTimeout = DefaultReadTimeout
|
||||
}
|
||||
if sh.WriteTimeout <= 0 {
|
||||
sh.WriteTimeout = DefaultWriteTimeout
|
||||
}
|
||||
if sh.PongTimeout <= 0 {
|
||||
sh.PongTimeout = DefaultPongTimeout
|
||||
}
|
||||
if sh.PingTimeout <= 0 {
|
||||
sh.PingTimeout = DefaultPingTimeout
|
||||
}
|
||||
if sh.PingPeriod <= 0 {
|
||||
sh.PingPeriod = (sh.PingTimeout * 9) / 10
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
|
81
server-rwc-handler.go
Normal file
81
server-rwc-handler.go
Normal file
|
@ -0,0 +1,81 @@
|
|||
package server
|
||||
|
||||
import (
|
||||
"sync"
|
||||
|
||||
logging "git.loafle.net/commons/logging-go"
|
||||
)
|
||||
|
||||
type ServerRWCHandler struct {
|
||||
connections sync.Map
|
||||
|
||||
ReadwriteHandler ReadWriteHandler
|
||||
ServerStopChan <-chan struct{}
|
||||
ServerStopWg *sync.WaitGroup
|
||||
}
|
||||
|
||||
func (srwch *ServerRWCHandler) ConnectionSize() int {
|
||||
var sz int
|
||||
srwch.connections.Range(func(k, v interface{}) bool {
|
||||
sz++
|
||||
return true
|
||||
})
|
||||
return sz
|
||||
}
|
||||
|
||||
func (srwch *ServerRWCHandler) HandleConnection(servlet Servlet, servletCtx ServletCtx, conn *Conn) {
|
||||
addr := conn.RemoteAddr()
|
||||
|
||||
defer func() {
|
||||
if nil != conn {
|
||||
conn.Close()
|
||||
}
|
||||
servlet.OnDisconnect(servletCtx)
|
||||
logging.Logger().Infof("Client[%s] has been disconnected", addr)
|
||||
srwch.ServerStopWg.Done()
|
||||
}()
|
||||
|
||||
logging.Logger().Infof("Client[%s] has been connected", addr)
|
||||
|
||||
srwch.connections.Store(conn, true)
|
||||
defer srwch.connections.Delete(conn)
|
||||
|
||||
servlet.OnConnect(servletCtx, conn)
|
||||
conn.SetCloseHandler(func(code int, text string) error {
|
||||
logging.Logger().Debugf("close")
|
||||
return nil
|
||||
})
|
||||
|
||||
stopChan := make(chan struct{})
|
||||
servletDoneChan := make(chan struct{})
|
||||
|
||||
readChan := make(chan []byte)
|
||||
writeChan := make(chan []byte)
|
||||
|
||||
readerDoneChan := make(chan error)
|
||||
writerDoneChan := make(chan error)
|
||||
|
||||
go servlet.Handle(servletCtx, stopChan, servletDoneChan, readChan, writeChan)
|
||||
go connReadHandler(srwch.ReadwriteHandler, conn, stopChan, readerDoneChan, readChan)
|
||||
go connWriteHandler(srwch.ReadwriteHandler, conn, stopChan, writerDoneChan, writeChan)
|
||||
|
||||
select {
|
||||
case <-readerDoneChan:
|
||||
close(stopChan)
|
||||
<-writerDoneChan
|
||||
<-servletDoneChan
|
||||
case <-writerDoneChan:
|
||||
close(stopChan)
|
||||
<-readerDoneChan
|
||||
<-servletDoneChan
|
||||
case <-servletDoneChan:
|
||||
close(stopChan)
|
||||
<-readerDoneChan
|
||||
<-writerDoneChan
|
||||
case <-srwch.ServerStopChan:
|
||||
close(stopChan)
|
||||
<-readerDoneChan
|
||||
<-writerDoneChan
|
||||
<-servletDoneChan
|
||||
}
|
||||
}
|
200
server.go
200
server.go
|
@ -1,200 +0,0 @@
|
|||
package server
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
logging "git.loafle.net/commons/logging-go"
|
||||
)
|
||||
|
||||
type Server interface {
|
||||
ListenAndServe() error
|
||||
Shutdown(ctx context.Context) error
|
||||
ConnectionSize() int
|
||||
|
||||
HandleConnection(servlet Servlet, servletCtx ServletCtx, conn *Conn)
|
||||
HandleRead(conn *Conn, stopChan <-chan struct{}, doneChan chan<- struct{}, readChan chan []byte)
|
||||
HandleWrite(conn *Conn, stopChan <-chan struct{}, doneChan chan<- struct{}, writeChan chan []byte)
|
||||
}
|
||||
|
||||
type Servers struct {
|
||||
ServerHandler ServerHandler
|
||||
|
||||
ServerCtx ServerCtx
|
||||
Connections sync.Map
|
||||
StopChan chan struct{}
|
||||
StopWg sync.WaitGroup
|
||||
}
|
||||
|
||||
func (s *Servers) Shutdown(ctx context.Context) error {
|
||||
if s.StopChan == nil {
|
||||
return fmt.Errorf("server must be started before stopping it")
|
||||
}
|
||||
close(s.StopChan)
|
||||
s.StopWg.Wait()
|
||||
|
||||
s.ServerHandler.Destroy(s.ServerCtx)
|
||||
|
||||
s.StopChan = nil
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (s *Servers) ConnectionSize() int {
|
||||
var sz int
|
||||
s.Connections.Range(func(k, v interface{}) bool {
|
||||
sz++
|
||||
return true
|
||||
})
|
||||
return sz
|
||||
}
|
||||
|
||||
func (s *Servers) ServerMessage(msg string) string {
|
||||
return fmt.Sprintf("Server[%s]: %s", s.ServerHandler.GetName(), msg)
|
||||
}
|
||||
|
||||
func (s *Servers) HandleConnection(servlet Servlet, servletCtx ServletCtx, conn *Conn) {
|
||||
addr := conn.RemoteAddr()
|
||||
|
||||
defer func() {
|
||||
if nil != conn {
|
||||
conn.Close()
|
||||
}
|
||||
servlet.OnDisconnect(servletCtx)
|
||||
logging.Logger().Infof(s.ServerMessage(fmt.Sprintf("Client[%s] has been disconnected", addr)))
|
||||
s.StopWg.Done()
|
||||
}()
|
||||
|
||||
logging.Logger().Infof(s.ServerMessage(fmt.Sprintf("Client[%s] has been connected", addr)))
|
||||
|
||||
s.Connections.Store(conn, true)
|
||||
defer s.Connections.Delete(conn)
|
||||
|
||||
servlet.OnConnect(servletCtx, conn)
|
||||
conn.SetCloseHandler(func(code int, text string) error {
|
||||
logging.Logger().Debugf("close")
|
||||
return nil
|
||||
})
|
||||
|
||||
stopChan := make(chan struct{})
|
||||
servletDoneChan := make(chan struct{})
|
||||
|
||||
readChan := make(chan []byte)
|
||||
writeChan := make(chan []byte)
|
||||
|
||||
readerDoneChan := make(chan struct{})
|
||||
writerDoneChan := make(chan struct{})
|
||||
|
||||
go servlet.Handle(servletCtx, stopChan, servletDoneChan, readChan, writeChan)
|
||||
go s.HandleRead(conn, stopChan, readerDoneChan, readChan)
|
||||
go s.HandleWrite(conn, stopChan, writerDoneChan, writeChan)
|
||||
|
||||
select {
|
||||
case <-readerDoneChan:
|
||||
close(stopChan)
|
||||
<-writerDoneChan
|
||||
<-servletDoneChan
|
||||
case <-writerDoneChan:
|
||||
close(stopChan)
|
||||
<-readerDoneChan
|
||||
<-servletDoneChan
|
||||
case <-servletDoneChan:
|
||||
close(stopChan)
|
||||
<-readerDoneChan
|
||||
<-writerDoneChan
|
||||
case <-s.StopChan:
|
||||
close(stopChan)
|
||||
<-readerDoneChan
|
||||
<-writerDoneChan
|
||||
<-servletDoneChan
|
||||
}
|
||||
}
|
||||
|
||||
func (s *Servers) HandleRead(conn *Conn, stopChan <-chan struct{}, doneChan chan<- struct{}, readChan chan []byte) {
|
||||
defer func() {
|
||||
doneChan <- struct{}{}
|
||||
}()
|
||||
|
||||
if 0 < s.ServerHandler.GetMaxMessageSize() {
|
||||
conn.SetReadLimit(s.ServerHandler.GetMaxMessageSize())
|
||||
}
|
||||
if 0 < s.ServerHandler.GetReadTimeout() {
|
||||
conn.SetReadDeadline(time.Now().Add(s.ServerHandler.GetReadTimeout()))
|
||||
}
|
||||
conn.SetPongHandler(func(string) error {
|
||||
conn.SetReadDeadline(time.Now().Add(s.ServerHandler.GetPongTimeout()))
|
||||
return nil
|
||||
})
|
||||
|
||||
var (
|
||||
message []byte
|
||||
err error
|
||||
)
|
||||
|
||||
for {
|
||||
readMessageChan := make(chan struct{})
|
||||
|
||||
go func() {
|
||||
_, message, err = conn.ReadMessage()
|
||||
close(readMessageChan)
|
||||
}()
|
||||
|
||||
select {
|
||||
case <-s.StopChan:
|
||||
<-readMessageChan
|
||||
return
|
||||
case <-readMessageChan:
|
||||
}
|
||||
|
||||
if nil != err {
|
||||
if IsUnexpectedCloseError(err, CloseGoingAway, CloseAbnormalClosure) {
|
||||
logging.Logger().Debugf(s.ServerMessage(fmt.Sprintf("Read error %v", err)))
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
readChan <- message
|
||||
}
|
||||
}
|
||||
|
||||
func (s *Servers) HandleWrite(conn *Conn, stopChan <-chan struct{}, doneChan chan<- struct{}, writeChan chan []byte) {
|
||||
defer func() {
|
||||
doneChan <- struct{}{}
|
||||
}()
|
||||
|
||||
ticker := time.NewTicker(s.ServerHandler.GetPingPeriod())
|
||||
defer func() {
|
||||
ticker.Stop()
|
||||
}()
|
||||
for {
|
||||
select {
|
||||
case message, ok := <-writeChan:
|
||||
if 0 < s.ServerHandler.GetWriteTimeout() {
|
||||
conn.SetWriteDeadline(time.Now().Add(s.ServerHandler.GetWriteTimeout()))
|
||||
}
|
||||
if !ok {
|
||||
conn.WriteMessage(CloseMessage, []byte{})
|
||||
return
|
||||
}
|
||||
|
||||
w, err := conn.NextWriter(TextMessage)
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
w.Write(message)
|
||||
|
||||
if err := w.Close(); nil != err {
|
||||
return
|
||||
}
|
||||
case <-ticker.C:
|
||||
conn.SetWriteDeadline(time.Now().Add(s.ServerHandler.GetPingTimeout()))
|
||||
if err := conn.WriteMessage(PingMessage, nil); nil != err {
|
||||
return
|
||||
}
|
||||
case <-stopChan:
|
||||
return
|
||||
}
|
||||
}
|
||||
}
|
Loading…
Reference in New Issue
Block a user