This commit is contained in:
crusader 2018-04-04 14:31:10 +09:00
parent f85b949e66
commit 67bc6d5d03
5 changed files with 553 additions and 41 deletions

View File

@ -5,6 +5,8 @@ const (
// the Server may serve by default (i.e. if Server.Concurrency isn't set). // the Server may serve by default (i.e. if Server.Concurrency isn't set).
DefaultConcurrency = 256 * 1024 DefaultConcurrency = 256 * 1024
DefaultKeepAlive = 0
// DefaultHandshakeTimeout is default value of websocket handshake Timeout // DefaultHandshakeTimeout is default value of websocket handshake Timeout
DefaultHandshakeTimeout = 0 DefaultHandshakeTimeout = 0

View File

@ -13,8 +13,10 @@ import (
"net/http" "net/http"
"net/url" "net/url"
"strings" "strings"
"sync"
"time" "time"
logging "git.loafle.net/commons/logging-go"
server "git.loafle.net/commons/server-go" server "git.loafle.net/commons/server-go"
"git.loafle.net/commons/server-go/internal" "git.loafle.net/commons/server-go/internal"
) )
@ -27,6 +29,11 @@ type Client struct {
URL string URL string
RequestHeader http.Header RequestHeader http.Header
Subprotocols []string
// Jar specifies the cookie jar.
// If Jar is nil, cookies are not sent in requests and ignored
// in responses.
CookieJar http.CookieJar
// NetDial specifies the dial function for creating TCP connections. If // NetDial specifies the dial function for creating TCP connections. If
// NetDial is nil, net.Dial is used. // NetDial is nil, net.Dial is used.
@ -45,7 +52,6 @@ type Client struct {
// do not limit the size of the messages that can be sent or received. // do not limit the size of the messages that can be sent or received.
ReadBufferSize, WriteBufferSize int ReadBufferSize, WriteBufferSize int
// Subprotocols specifies the client's requested subprotocols. // Subprotocols specifies the client's requested subprotocols.
Subprotocols []string
// EnableCompression specifies if the client should attempt to negotiate // EnableCompression specifies if the client should attempt to negotiate
// per message compression (RFC 7692). Setting this value to true does not // per message compression (RFC 7692). Setting this value to true does not
@ -53,11 +59,6 @@ type Client struct {
// takeover" modes are supported. // takeover" modes are supported.
EnableCompression bool EnableCompression bool
// Jar specifies the cookie jar.
// If Jar is nil, cookies are not sent in requests and ignored
// in responses.
CookieJar http.CookieJar
// MaxMessageSize is the maximum size for a message read from the peer. If a // MaxMessageSize is the maximum size for a message read from the peer. If a
// message exceeds the limit, the connection sends a close frame to the peer // message exceeds the limit, the connection sends a close frame to the peer
// and returns ErrReadLimit to the application. // and returns ErrReadLimit to the application.
@ -78,6 +79,197 @@ type Client struct {
PingPeriod time.Duration PingPeriod time.Duration
serverURL *url.URL serverURL *url.URL
stopChan chan struct{}
stopWg sync.WaitGroup
conn *internal.Conn
readChan chan []byte
writeChan chan []byte
}
func (c *Client) Connect() (readChan <-chan []byte, writeChan chan<- []byte, res *http.Response, err error) {
var (
conn *internal.Conn
)
if c.stopChan != nil {
return nil, nil, nil, fmt.Errorf(c.clientMessage("already running. Stop it before starting it again"))
}
err = c.Validate()
if nil != err {
return nil, nil, nil, err
}
conn, res, err = c.connect()
if nil != err {
return nil, nil, nil, err
}
c.readChan = make(chan []byte, 256)
c.writeChan = make(chan []byte, 256)
c.stopChan = make(chan struct{})
c.stopWg.Add(1)
go c.handleConnection(conn)
return c.readChan, c.writeChan, res, nil
}
func (c *Client) Disconnect() error {
if c.stopChan == nil {
return fmt.Errorf(c.clientMessage("must be started before stopping it"))
}
close(c.stopChan)
c.stopWg.Wait()
c.stopChan = nil
return nil
}
func (c *Client) clientMessage(msg string) string {
return fmt.Sprintf("Client[%s]: %s", c.Name, msg)
}
func (c *Client) connect() (*internal.Conn, *http.Response, error) {
conn, res, err := c.Dial()
if nil != err {
return nil, nil, err
}
conn.SetCloseHandler(func(code int, text string) error {
logging.Logger().Debugf("close")
return nil
})
return conn, res, nil
}
func (c *Client) handleConnection(conn *internal.Conn) {
defer func() {
if nil != conn {
conn.Close()
}
logging.Logger().Infof(c.clientMessage("disconnected"))
c.stopWg.Done()
}()
logging.Logger().Infof(c.clientMessage("connected"))
stopChan := make(chan struct{})
readerDoneChan := make(chan struct{})
writerDoneChan := make(chan struct{})
go handleClientRead(c, conn, stopChan, readerDoneChan)
go handleClientWrite(c, conn, stopChan, writerDoneChan)
select {
case <-readerDoneChan:
close(stopChan)
conn.Close()
<-writerDoneChan
conn = nil
case <-writerDoneChan:
close(stopChan)
conn.Close()
<-readerDoneChan
conn = nil
case <-c.stopChan:
close(stopChan)
conn.Close()
<-readerDoneChan
<-writerDoneChan
conn = nil
}
}
func handleClientRead(c *Client, conn *internal.Conn, doneChan chan<- struct{}, stopChan <-chan struct{}) {
defer func() {
close(doneChan)
}()
conn.SetReadLimit(c.MaxMessageSize)
conn.SetReadDeadline(time.Now().Add(c.ReadTimeout))
conn.SetPongHandler(func(string) error {
conn.SetReadDeadline(time.Now().Add(c.PongTimeout))
return nil
})
var (
message []byte
err error
)
for {
readMessageChan := make(chan struct{})
go func() {
_, message, err = conn.ReadMessage()
if err != nil {
if internal.IsUnexpectedCloseError(err, internal.CloseGoingAway, internal.CloseAbnormalClosure) {
logging.Logger().Debugf(c.clientMessage(fmt.Sprintf("Read error %v", err)))
}
}
close(readMessageChan)
}()
select {
case <-c.stopChan:
<-readMessageChan
break
case <-readMessageChan:
}
if nil != err {
select {
case <-c.stopChan:
break
case <-time.After(time.Second):
}
continue
}
c.readChan <- message
}
}
func handleClientWrite(c *Client, conn *internal.Conn, doneChan chan<- struct{}, stopChan <-chan struct{}) {
defer func() {
close(doneChan)
}()
ticker := time.NewTicker(c.PingPeriod)
defer func() {
ticker.Stop()
}()
for {
select {
case message, ok := <-c.writeChan:
conn.SetWriteDeadline(time.Now().Add(c.WriteTimeout))
if !ok {
conn.WriteMessage(internal.CloseMessage, []byte{})
return
}
w, err := conn.NextWriter(internal.TextMessage)
if err != nil {
return
}
w.Write(message)
if err := w.Close(); nil != err {
return
}
case <-ticker.C:
conn.SetWriteDeadline(time.Now().Add(c.PingTimeout))
if err := conn.WriteMessage(internal.PingMessage, nil); nil != err {
return
}
case <-c.stopChan:
break
}
}
} }
func (c *Client) Dial() (*internal.Conn, *http.Response, error) { func (c *Client) Dial() (*internal.Conn, *http.Response, error) {
@ -325,15 +517,33 @@ func (c *Client) Validate() error {
c.Proxy = http.ProxyFromEnvironment c.Proxy = http.ProxyFromEnvironment
} }
if 0 > c.HandshakeTimeout { if c.HandshakeTimeout <= 0 {
c.HandshakeTimeout = server.DefaultHandshakeTimeout c.HandshakeTimeout = server.DefaultHandshakeTimeout
} }
if 0 > c.ReadBufferSize { if c.MaxMessageSize <= 0 {
c.MaxMessageSize = server.DefaultMaxMessageSize
}
if c.ReadBufferSize <= 0 {
c.ReadBufferSize = server.DefaultReadBufferSize c.ReadBufferSize = server.DefaultReadBufferSize
} }
if 0 > c.WriteBufferSize { if c.WriteBufferSize <= 0 {
c.WriteBufferSize = server.DefaultWriteBufferSize c.WriteBufferSize = server.DefaultWriteBufferSize
} }
if c.ReadTimeout <= 0 {
c.ReadTimeout = server.DefaultReadTimeout
}
if c.WriteTimeout <= 0 {
c.WriteTimeout = server.DefaultWriteTimeout
}
if c.PongTimeout <= 0 {
c.PongTimeout = server.DefaultPongTimeout
}
if c.PingTimeout <= 0 {
c.PingTimeout = server.DefaultPingTimeout
}
if c.PingPeriod <= 0 {
c.PingPeriod = server.DefaultPingPeriod
}
return nil return nil
} }

View File

@ -199,33 +199,60 @@ func (s *Server) handleConnection(servlet Servlet, servletCtx server.ServletCtx,
addr := conn.RemoteAddr() addr := conn.RemoteAddr()
defer func() { defer func() {
s.connections.Delete(conn) if nil != conn {
conn.Close() conn.Close()
}
servlet.OnDisconnect(servletCtx) servlet.OnDisconnect(servletCtx)
logging.Logger().Infof(s.serverMessage(fmt.Sprintf("Client[%s] has been disconnected", addr))) logging.Logger().Infof(s.serverMessage(fmt.Sprintf("Client[%s] has been disconnected", addr)))
s.stopWg.Done() s.stopWg.Done()
}() }()
logging.Logger().Infof(s.serverMessage(fmt.Sprintf("Client[%s] has been connected", addr))) logging.Logger().Infof(s.serverMessage(fmt.Sprintf("Client[%s] has been connected", addr)))
s.connections.Store(conn, true) s.connections.Store(conn, true)
defer s.connections.Delete(conn)
servlet.OnConnect(servletCtx, conn) servlet.OnConnect(servletCtx, conn)
servletStopChan := make(chan struct{}) stopChan := make(chan struct{})
doneChan := make(chan struct{}) servletDoneChan := make(chan struct{})
readChan := make(chan []byte) readChan := make(chan []byte)
writeChan := make(chan []byte) writeChan := make(chan []byte)
go servlet.Handle(servletCtx, doneChan, servletStopChan, readChan, writeChan) readerDoneChan := make(chan struct{})
go handleRead(s, conn, readChan) writerDoneChan := make(chan struct{})
go handleWrite(s, conn, writeChan)
go servlet.Handle(servletCtx, stopChan, servletDoneChan, readChan, writeChan)
go handleRead(s, conn, stopChan, readerDoneChan, readChan)
go handleWrite(s, conn, stopChan, writerDoneChan, writeChan)
select { select {
case <-doneChan: case <-readerDoneChan:
close(servletStopChan) close(stopChan)
conn.Close()
<-writerDoneChan
<-servletDoneChan
conn = nil
case <-writerDoneChan:
close(stopChan)
conn.Close()
<-readerDoneChan
<-servletDoneChan
conn = nil
case <-servletDoneChan:
close(stopChan)
conn.Close()
<-readerDoneChan
<-writerDoneChan
conn = nil
case <-s.stopChan: case <-s.stopChan:
close(servletStopChan) close(stopChan)
<-doneChan conn.Close()
<-readerDoneChan
<-writerDoneChan
<-servletDoneChan
conn = nil
} }
} }
@ -233,7 +260,11 @@ func (s *Server) onError(ctx *fasthttp.RequestCtx, status int, reason error) {
s.ServerHandler.OnError(s.ctx, ctx, status, reason) s.ServerHandler.OnError(s.ctx, ctx, status, reason)
} }
func handleRead(s *Server, conn *internal.Conn, readChan chan []byte) { func handleRead(s *Server, conn *internal.Conn, doneChan chan<- struct{}, stopChan <-chan struct{}, readChan chan []byte) {
defer func() {
close(doneChan)
}()
conn.SetReadLimit(s.ServerHandler.GetMaxMessageSize()) conn.SetReadLimit(s.ServerHandler.GetMaxMessageSize())
conn.SetReadDeadline(time.Now().Add(s.ServerHandler.GetReadTimeout())) conn.SetReadDeadline(time.Now().Add(s.ServerHandler.GetReadTimeout()))
conn.SetPongHandler(func(string) error { conn.SetPongHandler(func(string) error {
@ -241,19 +272,49 @@ func handleRead(s *Server, conn *internal.Conn, readChan chan []byte) {
return nil return nil
}) })
var (
message []byte
err error
)
for { for {
_, message, err := conn.ReadMessage() readMessageChan := make(chan struct{})
go func() {
_, message, err = conn.ReadMessage()
if err != nil { if err != nil {
if internal.IsUnexpectedCloseError(err, internal.CloseGoingAway, internal.CloseAbnormalClosure) { if internal.IsUnexpectedCloseError(err, internal.CloseGoingAway, internal.CloseAbnormalClosure) {
logging.Logger().Debugf(s.serverMessage(fmt.Sprintf("Read error %v", err))) logging.Logger().Debugf(s.serverMessage(fmt.Sprintf("Read error %v", err)))
} }
break
} }
close(readMessageChan)
}()
select {
case <-s.stopChan:
<-readMessageChan
break
case <-readMessageChan:
}
if nil != err {
select {
case <-s.stopChan:
break
case <-time.After(time.Second):
}
continue
}
readChan <- message readChan <- message
} }
} }
func handleWrite(s *Server, conn *internal.Conn, writeChan chan []byte) { func handleWrite(s *Server, conn *internal.Conn, doneChan chan<- struct{}, stopChan <-chan struct{}, writeChan chan []byte) {
defer func() {
close(doneChan)
}()
ticker := time.NewTicker(s.ServerHandler.GetPingPeriod()) ticker := time.NewTicker(s.ServerHandler.GetPingPeriod())
defer func() { defer func() {
ticker.Stop() ticker.Stop()
@ -281,6 +342,8 @@ func handleWrite(s *Server, conn *internal.Conn, writeChan chan []byte) {
if err := conn.WriteMessage(internal.PingMessage, nil); nil != err { if err := conn.WriteMessage(internal.PingMessage, nil); nil != err {
return return
} }
case <-s.stopChan:
break
} }
} }
} }

View File

@ -4,7 +4,12 @@ import (
"crypto/tls" "crypto/tls"
"fmt" "fmt"
"net" "net"
"sync"
"time" "time"
logging "git.loafle.net/commons/logging-go"
server "git.loafle.net/commons/server-go"
"git.loafle.net/commons/server-go/internal"
) )
type Client struct { type Client struct {
@ -17,7 +22,229 @@ type Client struct {
KeepAlive time.Duration KeepAlive time.Duration
LocalAddress net.Addr LocalAddress net.Addr
MaxConnections int MaxMessageSize int64
// Per-connection buffer size for requests' reading.
// This also limits the maximum header size.
//
// Increase this buffer if your clients send multi-KB RequestURIs
// and/or multi-KB headers (for example, BIG cookies).
//
// Default buffer size is used if not set.
ReadBufferSize int
// Per-connection buffer size for responses' writing.
//
// Default buffer size is used if not set.
WriteBufferSize int
// Maximum duration for reading the full request (including body).
//
// This also limits the maximum duration for idle keep-alive
// connections.
//
// By default request read timeout is unlimited.
ReadTimeout time.Duration
// Maximum duration for writing the full response (including body).
//
// By default response write timeout is unlimited.
WriteTimeout time.Duration
PongTimeout time.Duration
PingTimeout time.Duration
PingPeriod time.Duration
EnableCompression bool
stopChan chan struct{}
stopWg sync.WaitGroup
conn *internal.Conn
readChan chan []byte
writeChan chan []byte
}
func (c *Client) Connect() (readChan <-chan []byte, writeChan chan<- []byte, err error) {
var (
conn *internal.Conn
)
if c.stopChan != nil {
return nil, nil, fmt.Errorf(c.clientMessage("already running. Stop it before starting it again"))
}
err = c.Validate()
if nil != err {
return nil, nil, err
}
conn, err = c.connect()
if nil != err {
return nil, nil, err
}
c.readChan = make(chan []byte, 256)
c.writeChan = make(chan []byte, 256)
c.stopChan = make(chan struct{})
c.stopWg.Add(1)
go c.handleConnection(conn)
return c.readChan, c.writeChan, nil
}
func (c *Client) Disconnect() error {
if c.stopChan == nil {
return fmt.Errorf(c.clientMessage("must be started before stopping it"))
}
close(c.stopChan)
c.stopWg.Wait()
c.stopChan = nil
return nil
}
func (c *Client) clientMessage(msg string) string {
return fmt.Sprintf("Client[%s]: %s", c.Name, msg)
}
func (c *Client) connect() (*internal.Conn, error) {
netConn, err := c.Dial()
if nil != err {
return nil, err
}
conn := internal.NewConn(netConn, false, c.ReadBufferSize, c.WriteBufferSize)
conn.SetCloseHandler(func(code int, text string) error {
logging.Logger().Debugf("close")
return nil
})
return conn, nil
}
func (c *Client) handleConnection(conn *internal.Conn) {
defer func() {
if nil != conn {
conn.Close()
}
logging.Logger().Infof(c.clientMessage("disconnected"))
c.stopWg.Done()
}()
logging.Logger().Infof(c.clientMessage("connected"))
stopChan := make(chan struct{})
readerDoneChan := make(chan struct{})
writerDoneChan := make(chan struct{})
go handleClientRead(c, conn, stopChan, readerDoneChan)
go handleClientWrite(c, conn, stopChan, writerDoneChan)
select {
case <-readerDoneChan:
close(stopChan)
conn.Close()
<-writerDoneChan
conn = nil
case <-writerDoneChan:
close(stopChan)
conn.Close()
<-readerDoneChan
conn = nil
case <-c.stopChan:
close(stopChan)
conn.Close()
<-readerDoneChan
<-writerDoneChan
conn = nil
}
}
func handleClientRead(c *Client, conn *internal.Conn, doneChan chan<- struct{}, stopChan <-chan struct{}) {
defer func() {
close(doneChan)
}()
conn.SetReadLimit(c.MaxMessageSize)
conn.SetReadDeadline(time.Now().Add(c.ReadTimeout))
conn.SetPongHandler(func(string) error {
conn.SetReadDeadline(time.Now().Add(c.PongTimeout))
return nil
})
var (
message []byte
err error
)
for {
readMessageChan := make(chan struct{})
go func() {
_, message, err = conn.ReadMessage()
if err != nil {
if internal.IsUnexpectedCloseError(err, internal.CloseGoingAway, internal.CloseAbnormalClosure) {
logging.Logger().Debugf(c.clientMessage(fmt.Sprintf("Read error %v", err)))
}
}
close(readMessageChan)
}()
select {
case <-c.stopChan:
<-readMessageChan
break
case <-readMessageChan:
}
if nil != err {
select {
case <-c.stopChan:
break
case <-time.After(time.Second):
}
continue
}
c.readChan <- message
}
}
func handleClientWrite(c *Client, conn *internal.Conn, doneChan chan<- struct{}, stopChan <-chan struct{}) {
defer func() {
close(doneChan)
}()
ticker := time.NewTicker(c.PingPeriod)
defer func() {
ticker.Stop()
}()
for {
select {
case message, ok := <-c.writeChan:
conn.SetWriteDeadline(time.Now().Add(c.WriteTimeout))
if !ok {
conn.WriteMessage(internal.CloseMessage, []byte{})
return
}
w, err := conn.NextWriter(internal.TextMessage)
if err != nil {
return
}
w.Write(message)
if err := w.Close(); nil != err {
return
}
case <-ticker.C:
conn.SetWriteDeadline(time.Now().Add(c.PingTimeout))
if err := conn.WriteMessage(internal.PingMessage, nil); nil != err {
return
}
case <-c.stopChan:
break
}
}
} }
func (c *Client) Dial() (net.Conn, error) { func (c *Client) Dial() (net.Conn, error) {
@ -72,15 +299,32 @@ func (c *Client) Validate() error {
return fmt.Errorf("Client: Address is not valid") return fmt.Errorf("Client: Address is not valid")
} }
if 0 >= c.MaxConnections { if c.HandshakeTimeout <= 0 {
c.MaxConnections = 1 c.HandshakeTimeout = server.DefaultHandshakeTimeout
} }
if c.MaxMessageSize <= 0 {
if 0 >= c.KeepAlive { c.MaxMessageSize = server.DefaultMaxMessageSize
c.KeepAlive = DefaultKeepAlive
} }
if 0 >= c.HandshakeTimeout { if c.ReadBufferSize <= 0 {
c.HandshakeTimeout = DefaultHandshakeTimeout c.ReadBufferSize = server.DefaultReadBufferSize
}
if c.WriteBufferSize <= 0 {
c.WriteBufferSize = server.DefaultWriteBufferSize
}
if c.ReadTimeout <= 0 {
c.ReadTimeout = server.DefaultReadTimeout
}
if c.WriteTimeout <= 0 {
c.WriteTimeout = server.DefaultWriteTimeout
}
if c.PongTimeout <= 0 {
c.PongTimeout = server.DefaultPongTimeout
}
if c.PingTimeout <= 0 {
c.PingTimeout = server.DefaultPingTimeout
}
if c.PingPeriod <= 0 {
c.PingPeriod = server.DefaultPingPeriod
} }
return nil return nil

View File

@ -1,7 +0,0 @@
package net
const (
DefaultHandshakeTimeout = 0
DefaultKeepAlive = 0
)