This commit is contained in:
crusader 2018-04-12 14:55:01 +09:00
parent 17ef0e1833
commit 9b450e0195
9 changed files with 183 additions and 103 deletions

View File

@ -4,5 +4,30 @@ type Connector interface {
Connect() (readChan <-chan []byte, writeChan chan<- []byte, err error)
Disconnect() error
GetName() string
Clone() Connector
Validate() error
}
type Connectors struct {
Name string `json:"name"`
}
func (c *Connectors) GetName() string {
return c.Name
}
func (c *Connectors) Clone() *Connectors {
return &Connectors{
Name: c.Name,
}
}
func (c *Connectors) Validate() error {
if "" == c.Name {
c.Name = "Connector"
}
return nil
}

View File

@ -52,6 +52,17 @@ func (ch *ConnectionHandlers) GetTLSConfig() *tls.Config {
return ch.TLSConfig
}
func (ch *ConnectionHandlers) Clone() *ConnectionHandlers {
return &ConnectionHandlers{
Network: ch.Network,
Address: ch.Address,
Concurrency: ch.Concurrency,
KeepAlive: ch.KeepAlive,
HandshakeTimeout: ch.HandshakeTimeout,
TLSConfig: ch.TLSConfig,
}
}
func (ch *ConnectionHandlers) Validate() error {
if ch.Concurrency <= 0 {
ch.Concurrency = DefaultConcurrency

View File

@ -13,7 +13,7 @@ type ReadWriteHandler interface {
}
type ReadWriteHandlers struct {
MaxMessageSize int64 `json:"maxMessageSize"`
MaxMessageSize int64 `json:"maxMessageSize"`
// Per-connection buffer size for requests' reading.
// This also limits the maximum header size.
//
@ -21,23 +21,23 @@ type ReadWriteHandlers struct {
// and/or multi-KB headers (for example, BIG cookies).
//
// Default buffer size is used if not set.
ReadBufferSize int `json:"readBufferSize"`
ReadBufferSize int `json:"readBufferSize"`
// Per-connection buffer size for responses' writing.
//
// Default buffer size is used if not set.
WriteBufferSize int `json:"writeBufferSize"`
WriteBufferSize int `json:"writeBufferSize"`
// Maximum duration for reading the full request (including body).
//
// This also limits the maximum duration for idle keep-alive
// connections.
//
// By default request read timeout is unlimited.
ReadTimeout time.Duration `json:"readTimeout"`
ReadTimeout time.Duration `json:"readTimeout"`
// Maximum duration for writing the full response (including body).
//
// By default response write timeout is unlimited.
WriteTimeout time.Duration `json:"writeTimeout"`
WriteTimeout time.Duration `json:"writeTimeout"`
}
func (rwh *ReadWriteHandlers) GetMaxMessageSize() int64 {
@ -56,6 +56,16 @@ func (rwh *ReadWriteHandlers) GetWriteTimeout() time.Duration {
return rwh.WriteTimeout
}
func (rwh *ReadWriteHandlers) Clone() *ReadWriteHandlers {
return &ReadWriteHandlers{
MaxMessageSize: rwh.MaxMessageSize,
ReadBufferSize: rwh.ReadBufferSize,
WriteBufferSize: rwh.WriteBufferSize,
ReadTimeout: rwh.ReadTimeout,
WriteTimeout: rwh.WriteTimeout,
}
}
func (rwh *ReadWriteHandlers) Validate() error {
if rwh.MaxMessageSize <= 0 {
rwh.MaxMessageSize = DefaultMaxMessageSize

View File

@ -20,7 +20,7 @@ type ServerHandlers struct {
// Server name for sending in response headers.
//
// Default server name is used if left blank.
Name string `json:"name"`
Name string `json:"name"`
}
func (sh *ServerHandlers) ServerCtx() ServerCtx {

View File

@ -16,8 +16,8 @@ type ClientConnHandler interface {
type ClientConnHandlers struct {
server.ConnectionHandlers
ReconnectInterval time.Duration `json:"reconnectInterval"`
ReconnectTryTime int `json:"reconnectTryTime"`
ReconnectInterval time.Duration `json:"reconnectInterval"`
ReconnectTryTime int `json:"reconnectTryTime"`
}
func (cch *ClientConnHandlers) GetReconnectInterval() time.Duration {
@ -28,6 +28,14 @@ func (cch *ClientConnHandlers) GetReconnectTryTime() int {
return cch.ReconnectTryTime
}
func (cch *ClientConnHandlers) Clone() *ClientConnHandlers {
return &ClientConnHandlers{
ConnectionHandlers: *cch.ConnectionHandlers.Clone(),
ReconnectInterval: cch.ReconnectInterval,
ReconnectTryTime: cch.ReconnectTryTime,
}
}
func (cch *ClientConnHandlers) Validate() error {
if err := cch.ConnectionHandlers.Validate(); nil != err {
return err

View File

@ -13,13 +13,10 @@ import (
)
type Connectors struct {
client.Connector
client.Connectors
socket.ClientConnHandlers
socket.ReadWriteHandlers
Name string `json:"name"`
Network string `json:"network"`
Address string `json:"address"`
LocalAddress net.Addr
@ -41,12 +38,11 @@ func (c *Connectors) Connect() (readChan <-chan []byte, writeChan chan<- []byte,
conn socket.Conn
)
if c.stopChan != nil {
if nil != c.stopChan {
return nil, nil, fmt.Errorf("%s already connected", c.logHeader())
}
err = c.Validate()
if nil != err {
if err := c.Validate(); nil != err {
return nil, nil, err
}
@ -105,14 +101,14 @@ RC_LOOP:
return
}
if 0 >= c.ReconnectTryTime {
if 0 >= c.GetReconnectTryTime() {
c.reconnectedChan <- nil
continue RC_LOOP
}
logging.Logger().Debugf("%s connection lost", c.logHeader())
for indexI := 0; indexI < c.ReconnectTryTime; indexI++ {
for indexI := 0; indexI < c.GetReconnectTryTime(); indexI++ {
logging.Logger().Debugf("%s trying reconnect[%d]", c.logHeader(), indexI)
conn, err := c.connect()
@ -121,7 +117,7 @@ RC_LOOP:
c.reconnectedChan <- conn
continue RC_LOOP
}
time.Sleep(c.ReconnectInterval)
time.Sleep(c.GetReconnectInterval())
}
logging.Logger().Debugf("%s reconnecting has been failed", c.logHeader())
}
@ -133,7 +129,7 @@ func (c *Connectors) connect() (socket.Conn, error) {
return nil, err
}
conn := socket.NewConn(netConn, false, c.ReadBufferSize, c.WriteBufferSize)
conn := socket.NewConn(netConn, false, c.GetReadBufferSize(), c.GetWriteBufferSize())
conn.SetCloseHandler(func(code int, text string) error {
logging.Logger().Debugf("%s close", c.logHeader())
return nil
@ -142,17 +138,13 @@ func (c *Connectors) connect() (socket.Conn, error) {
}
func (c *Connectors) dial() (net.Conn, error) {
if err := c.Validate(); nil != err {
return nil, err
}
var deadline time.Time
if 0 != c.HandshakeTimeout {
deadline = time.Now().Add(c.HandshakeTimeout)
if 0 != c.GetHandshakeTimeout() {
deadline = time.Now().Add(c.GetHandshakeTimeout())
}
d := &net.Dialer{
KeepAlive: c.KeepAlive,
KeepAlive: c.GetKeepAlive(),
Deadline: deadline,
LocalAddr: c.LocalAddress,
}
@ -162,8 +154,8 @@ func (c *Connectors) dial() (net.Conn, error) {
return nil, err
}
if nil != c.TLSConfig {
cfg := c.TLSConfig.Clone()
if nil != c.GetTLSConfig() {
cfg := c.GetTLSConfig().Clone()
tlsConn := tls.Client(conn, cfg)
if err := tlsConn.Handshake(); err != nil {
tlsConn.Close()
@ -180,7 +172,21 @@ func (c *Connectors) dial() (net.Conn, error) {
return conn, nil
}
func (c *Connectors) Clone() *Connectors {
return &Connectors{
Connectors: *c.Connectors.Clone(),
ClientConnHandlers: *c.ClientConnHandlers.Clone(),
ReadWriteHandlers: *c.ReadWriteHandlers.Clone(),
Network: c.Network,
Address: c.Address,
LocalAddress: c.LocalAddress,
}
}
func (c *Connectors) Validate() error {
if err := c.Connectors.Validate(); nil != err {
return err
}
if err := c.ClientConnHandlers.Validate(); nil != err {
return err
}
@ -188,10 +194,6 @@ func (c *Connectors) Validate() error {
return err
}
if "" == c.Name {
c.Name = "Connector"
}
if "" == c.Network {
return fmt.Errorf("%s Network is not valid", c.logHeader())
}

View File

@ -37,6 +37,17 @@ func (rwh *ReadWriteHandlers) GetPingPeriod() time.Duration {
func (rwh *ReadWriteHandlers) IsEnableCompression() bool {
return rwh.EnableCompression
}
func (rwh *ReadWriteHandlers) Clone() *ReadWriteHandlers {
return &ReadWriteHandlers{
ReadWriteHandlers: *rwh.ReadWriteHandlers.Clone(),
PongTimeout: rwh.PongTimeout,
PingTimeout: rwh.PingTimeout,
PingPeriod: rwh.PingPeriod,
EnableCompression: rwh.EnableCompression,
}
}
func (rwh *ReadWriteHandlers) Validate() error {
if err := rwh.ReadWriteHandlers.Validate(); nil != err {
return err

View File

@ -25,12 +25,10 @@ import (
var errMalformedURL = errors.New("malformed ws or wss URL")
type Connectors struct {
client.Connector
client.Connectors
socket.ClientConnHandlers
socket.ReadWriteHandlers
Name string `json:"name"`
URL string `json:"url"`
RequestHeader http.Header
@ -75,9 +73,7 @@ func (c *Connectors) Connect() (readChan <-chan []byte, writeChan chan<- []byte,
if c.stopChan != nil {
return nil, nil, fmt.Errorf("%s already connected", c.logHeader())
}
err = c.Validate()
if nil != err {
if err := c.Validate(); nil != err {
return nil, nil, err
}
@ -140,14 +136,14 @@ RC_LOOP:
return
}
if 0 >= c.ReconnectTryTime {
if 0 >= c.GetReconnectTryTime() {
c.reconnectedChan <- nil
continue RC_LOOP
}
logging.Logger().Debugf("%s connection lost", c.logHeader())
for indexI := 0; indexI < c.ReconnectTryTime; indexI++ {
for indexI := 0; indexI < c.GetReconnectTryTime(); indexI++ {
logging.Logger().Debugf("%s trying reconnect[%d]", c.logHeader(), indexI)
conn, res, err := c.connect()
@ -161,7 +157,7 @@ RC_LOOP:
c.reconnectedChan <- conn
continue RC_LOOP
}
time.Sleep(c.ReconnectInterval)
time.Sleep(c.GetReconnectInterval())
}
logging.Logger().Debugf("%s reconnecting has been failed", c.logHeader())
}
@ -187,10 +183,6 @@ func (c *Connectors) dial() (socket.Conn, *http.Response, error) {
netConn net.Conn
)
if err = c.Validate(); nil != err {
return nil, nil, err
}
challengeKey, err = web.GenerateChallengeKey()
if err != nil {
return nil, nil, err
@ -206,9 +198,10 @@ func (c *Connectors) dial() (socket.Conn, *http.Response, error) {
Host: c.serverURL.Host,
}
cookieJar := c.CookieJar
// Set the cookies present in the cookie jar of the dialer
if nil != c.CookieJar {
for _, cookie := range c.CookieJar.Cookies(c.serverURL) {
if nil != cookieJar {
for _, cookie := range cookieJar.Cookies(c.serverURL) {
req.AddCookie(cookie)
}
}
@ -221,8 +214,11 @@ func (c *Connectors) dial() (socket.Conn, *http.Response, error) {
req.Header["Connection"] = []string{"Upgrade"}
req.Header["Sec-WebSocket-Key"] = []string{challengeKey}
req.Header["Sec-WebSocket-Version"] = []string{"13"}
if len(c.Subprotocols) > 0 {
req.Header["Sec-WebSocket-Protocol"] = []string{strings.Join(c.Subprotocols, ", ")}
subprotocols := c.Subprotocols
if len(subprotocols) > 0 {
req.Header["Sec-WebSocket-Protocol"] = []string{strings.Join(subprotocols, ", ")}
}
for k, vs := range c.RequestHeader {
switch {
@ -235,14 +231,14 @@ func (c *Connectors) dial() (socket.Conn, *http.Response, error) {
k == "Sec-Websocket-Key" ||
k == "Sec-Websocket-Version" ||
k == "Sec-Websocket-Extensions" ||
(k == "Sec-Websocket-Protocol" && len(c.Subprotocols) > 0):
(k == "Sec-Websocket-Protocol" && len(subprotocols) > 0):
return nil, nil, fmt.Errorf("%s duplicate header not allowed: %s", c.logHeader(), k)
default:
req.Header[k] = vs
}
}
if c.EnableCompression {
if c.IsEnableCompression() {
req.Header.Set("Sec-Websocket-Extensions", "permessage-deflate; server_no_context_takeover; client_no_context_takeover")
}
@ -250,8 +246,9 @@ func (c *Connectors) dial() (socket.Conn, *http.Response, error) {
var proxyURL *url.URL
// Check wether the proxy method has been configured
if nil != c.Proxy {
proxyURL, err = c.Proxy(req)
proxy := c.Proxy
if nil != proxy {
proxyURL, err = proxy(req)
if err != nil {
return nil, nil, err
}
@ -265,8 +262,9 @@ func (c *Connectors) dial() (socket.Conn, *http.Response, error) {
}
var deadline time.Time
if 0 != c.HandshakeTimeout {
deadline = time.Now().Add(c.HandshakeTimeout)
handshakeTimeout := c.GetHandshakeTimeout()
if 0 != handshakeTimeout {
deadline = time.Now().Add(handshakeTimeout)
}
netDial := c.NetDial
@ -324,7 +322,7 @@ func (c *Connectors) dial() (socket.Conn, *http.Response, error) {
}
if "https" == c.serverURL.Scheme {
cfg := cloneTLSConfig(c.TLSConfig)
cfg := cloneTLSConfig(c.GetTLSConfig())
if cfg.ServerName == "" {
cfg.ServerName = hostNoPort
}
@ -340,7 +338,7 @@ func (c *Connectors) dial() (socket.Conn, *http.Response, error) {
}
}
conn := socket.NewConn(netConn, false, c.ReadBufferSize, c.WriteBufferSize)
conn := socket.NewConn(netConn, false, c.GetReadBufferSize(), c.GetWriteBufferSize())
if err := req.Write(netConn); err != nil {
return nil, nil, err
@ -351,9 +349,9 @@ func (c *Connectors) dial() (socket.Conn, *http.Response, error) {
return nil, nil, err
}
if nil != c.CookieJar {
if nil != cookieJar {
if rc := resp.Cookies(); len(rc) > 0 {
c.CookieJar.SetCookies(c.serverURL, rc)
cookieJar.SetCookies(c.serverURL, rc)
}
}
@ -393,48 +391,6 @@ func (c *Connectors) dial() (socket.Conn, *http.Response, error) {
return conn, resp, nil
}
func (c *Connectors) Validate() error {
if err := c.ClientConnHandlers.Validate(); nil != err {
return err
}
if err := c.ReadWriteHandlers.Validate(); nil != err {
return err
}
if "" == c.Name {
c.Name = "Connector"
}
if "" == c.URL {
return fmt.Errorf("%s URL is not valid", c.logHeader())
}
u, err := parseURL(c.URL)
if nil != err {
return err
}
switch u.Scheme {
case "ws":
u.Scheme = "http"
case "wss":
u.Scheme = "https"
default:
return errMalformedURL
}
if nil != u.User {
// User name and password are not allowed in websocket URIs.
return errMalformedURL
}
c.serverURL = u
if nil == c.Proxy {
c.Proxy = http.ProxyFromEnvironment
}
return nil
}
// parseURL parses the URL.
//
// This function is a replacement for the standard library url.Parse function.
@ -503,3 +459,60 @@ func cloneTLSConfig(cfg *tls.Config) *tls.Config {
}
return cfg.Clone()
}
func (c *Connectors) Clone() *Connectors {
return &Connectors{
Connectors: *c.Connectors.Clone(),
ClientConnHandlers: *c.ClientConnHandlers.Clone(),
ReadWriteHandlers: *c.ReadWriteHandlers.Clone(),
URL: c.URL,
RequestHeader: c.RequestHeader,
Subprotocols: c.Subprotocols,
CookieJar: c.CookieJar,
ResponseHandler: c.ResponseHandler,
NetDial: c.NetDial,
Proxy: c.Proxy,
serverURL: c.serverURL,
}
}
func (c *Connectors) Validate() error {
if err := c.Connectors.Validate(); nil != err {
return err
}
if err := c.ClientConnHandlers.Validate(); nil != err {
return err
}
if err := c.ReadWriteHandlers.Validate(); nil != err {
return err
}
if "" == c.URL {
return fmt.Errorf("URL is not valid")
}
u, err := parseURL(c.URL)
if nil != err {
return err
}
switch u.Scheme {
case "ws":
u.Scheme = "http"
case "wss":
u.Scheme = "https"
default:
return errMalformedURL
}
if nil != u.User {
// User name and password are not allowed in websocket URIs.
return errMalformedURL
}
c.serverURL = u
if nil == c.Proxy {
c.Proxy = http.ProxyFromEnvironment
}
return nil
}

View File

@ -136,7 +136,7 @@ func getContextPath(path string) (string, error) {
p := strings.TrimSpace(path)
if !strings.HasPrefix(p, "/") {
return "", fmt.Errorf("The path[%s] must started /", path)
return "", fmt.Errorf("path[%s] must started /", path)
}
p = p[1:]
@ -147,7 +147,7 @@ func getContextPath(path string) (string, error) {
components := strings.Split(p, "/")
if 0 == len(components) {
return "", fmt.Errorf("The path[%s] is not invalid", path)
return "", fmt.Errorf("path[%s] is not invalid", path)
}
return fmt.Sprintf("/%s", components[0]), nil