This commit is contained in:
crusader 2018-04-12 14:55:01 +09:00
parent 17ef0e1833
commit 9b450e0195
9 changed files with 183 additions and 103 deletions

View File

@ -4,5 +4,30 @@ type Connector interface {
Connect() (readChan <-chan []byte, writeChan chan<- []byte, err error) Connect() (readChan <-chan []byte, writeChan chan<- []byte, err error)
Disconnect() error Disconnect() error
GetName() string
Clone() Connector
Validate() error Validate() error
} }
type Connectors struct {
Name string `json:"name"`
}
func (c *Connectors) GetName() string {
return c.Name
}
func (c *Connectors) Clone() *Connectors {
return &Connectors{
Name: c.Name,
}
}
func (c *Connectors) Validate() error {
if "" == c.Name {
c.Name = "Connector"
}
return nil
}

View File

@ -52,6 +52,17 @@ func (ch *ConnectionHandlers) GetTLSConfig() *tls.Config {
return ch.TLSConfig return ch.TLSConfig
} }
func (ch *ConnectionHandlers) Clone() *ConnectionHandlers {
return &ConnectionHandlers{
Network: ch.Network,
Address: ch.Address,
Concurrency: ch.Concurrency,
KeepAlive: ch.KeepAlive,
HandshakeTimeout: ch.HandshakeTimeout,
TLSConfig: ch.TLSConfig,
}
}
func (ch *ConnectionHandlers) Validate() error { func (ch *ConnectionHandlers) Validate() error {
if ch.Concurrency <= 0 { if ch.Concurrency <= 0 {
ch.Concurrency = DefaultConcurrency ch.Concurrency = DefaultConcurrency

View File

@ -56,6 +56,16 @@ func (rwh *ReadWriteHandlers) GetWriteTimeout() time.Duration {
return rwh.WriteTimeout return rwh.WriteTimeout
} }
func (rwh *ReadWriteHandlers) Clone() *ReadWriteHandlers {
return &ReadWriteHandlers{
MaxMessageSize: rwh.MaxMessageSize,
ReadBufferSize: rwh.ReadBufferSize,
WriteBufferSize: rwh.WriteBufferSize,
ReadTimeout: rwh.ReadTimeout,
WriteTimeout: rwh.WriteTimeout,
}
}
func (rwh *ReadWriteHandlers) Validate() error { func (rwh *ReadWriteHandlers) Validate() error {
if rwh.MaxMessageSize <= 0 { if rwh.MaxMessageSize <= 0 {
rwh.MaxMessageSize = DefaultMaxMessageSize rwh.MaxMessageSize = DefaultMaxMessageSize

View File

@ -28,6 +28,14 @@ func (cch *ClientConnHandlers) GetReconnectTryTime() int {
return cch.ReconnectTryTime return cch.ReconnectTryTime
} }
func (cch *ClientConnHandlers) Clone() *ClientConnHandlers {
return &ClientConnHandlers{
ConnectionHandlers: *cch.ConnectionHandlers.Clone(),
ReconnectInterval: cch.ReconnectInterval,
ReconnectTryTime: cch.ReconnectTryTime,
}
}
func (cch *ClientConnHandlers) Validate() error { func (cch *ClientConnHandlers) Validate() error {
if err := cch.ConnectionHandlers.Validate(); nil != err { if err := cch.ConnectionHandlers.Validate(); nil != err {
return err return err

View File

@ -13,13 +13,10 @@ import (
) )
type Connectors struct { type Connectors struct {
client.Connector client.Connectors
socket.ClientConnHandlers socket.ClientConnHandlers
socket.ReadWriteHandlers socket.ReadWriteHandlers
Name string `json:"name"`
Network string `json:"network"` Network string `json:"network"`
Address string `json:"address"` Address string `json:"address"`
LocalAddress net.Addr LocalAddress net.Addr
@ -41,12 +38,11 @@ func (c *Connectors) Connect() (readChan <-chan []byte, writeChan chan<- []byte,
conn socket.Conn conn socket.Conn
) )
if c.stopChan != nil { if nil != c.stopChan {
return nil, nil, fmt.Errorf("%s already connected", c.logHeader()) return nil, nil, fmt.Errorf("%s already connected", c.logHeader())
} }
err = c.Validate() if err := c.Validate(); nil != err {
if nil != err {
return nil, nil, err return nil, nil, err
} }
@ -105,14 +101,14 @@ RC_LOOP:
return return
} }
if 0 >= c.ReconnectTryTime { if 0 >= c.GetReconnectTryTime() {
c.reconnectedChan <- nil c.reconnectedChan <- nil
continue RC_LOOP continue RC_LOOP
} }
logging.Logger().Debugf("%s connection lost", c.logHeader()) logging.Logger().Debugf("%s connection lost", c.logHeader())
for indexI := 0; indexI < c.ReconnectTryTime; indexI++ { for indexI := 0; indexI < c.GetReconnectTryTime(); indexI++ {
logging.Logger().Debugf("%s trying reconnect[%d]", c.logHeader(), indexI) logging.Logger().Debugf("%s trying reconnect[%d]", c.logHeader(), indexI)
conn, err := c.connect() conn, err := c.connect()
@ -121,7 +117,7 @@ RC_LOOP:
c.reconnectedChan <- conn c.reconnectedChan <- conn
continue RC_LOOP continue RC_LOOP
} }
time.Sleep(c.ReconnectInterval) time.Sleep(c.GetReconnectInterval())
} }
logging.Logger().Debugf("%s reconnecting has been failed", c.logHeader()) logging.Logger().Debugf("%s reconnecting has been failed", c.logHeader())
} }
@ -133,7 +129,7 @@ func (c *Connectors) connect() (socket.Conn, error) {
return nil, err return nil, err
} }
conn := socket.NewConn(netConn, false, c.ReadBufferSize, c.WriteBufferSize) conn := socket.NewConn(netConn, false, c.GetReadBufferSize(), c.GetWriteBufferSize())
conn.SetCloseHandler(func(code int, text string) error { conn.SetCloseHandler(func(code int, text string) error {
logging.Logger().Debugf("%s close", c.logHeader()) logging.Logger().Debugf("%s close", c.logHeader())
return nil return nil
@ -142,17 +138,13 @@ func (c *Connectors) connect() (socket.Conn, error) {
} }
func (c *Connectors) dial() (net.Conn, error) { func (c *Connectors) dial() (net.Conn, error) {
if err := c.Validate(); nil != err {
return nil, err
}
var deadline time.Time var deadline time.Time
if 0 != c.HandshakeTimeout { if 0 != c.GetHandshakeTimeout() {
deadline = time.Now().Add(c.HandshakeTimeout) deadline = time.Now().Add(c.GetHandshakeTimeout())
} }
d := &net.Dialer{ d := &net.Dialer{
KeepAlive: c.KeepAlive, KeepAlive: c.GetKeepAlive(),
Deadline: deadline, Deadline: deadline,
LocalAddr: c.LocalAddress, LocalAddr: c.LocalAddress,
} }
@ -162,8 +154,8 @@ func (c *Connectors) dial() (net.Conn, error) {
return nil, err return nil, err
} }
if nil != c.TLSConfig { if nil != c.GetTLSConfig() {
cfg := c.TLSConfig.Clone() cfg := c.GetTLSConfig().Clone()
tlsConn := tls.Client(conn, cfg) tlsConn := tls.Client(conn, cfg)
if err := tlsConn.Handshake(); err != nil { if err := tlsConn.Handshake(); err != nil {
tlsConn.Close() tlsConn.Close()
@ -180,7 +172,21 @@ func (c *Connectors) dial() (net.Conn, error) {
return conn, nil return conn, nil
} }
func (c *Connectors) Clone() *Connectors {
return &Connectors{
Connectors: *c.Connectors.Clone(),
ClientConnHandlers: *c.ClientConnHandlers.Clone(),
ReadWriteHandlers: *c.ReadWriteHandlers.Clone(),
Network: c.Network,
Address: c.Address,
LocalAddress: c.LocalAddress,
}
}
func (c *Connectors) Validate() error { func (c *Connectors) Validate() error {
if err := c.Connectors.Validate(); nil != err {
return err
}
if err := c.ClientConnHandlers.Validate(); nil != err { if err := c.ClientConnHandlers.Validate(); nil != err {
return err return err
} }
@ -188,10 +194,6 @@ func (c *Connectors) Validate() error {
return err return err
} }
if "" == c.Name {
c.Name = "Connector"
}
if "" == c.Network { if "" == c.Network {
return fmt.Errorf("%s Network is not valid", c.logHeader()) return fmt.Errorf("%s Network is not valid", c.logHeader())
} }

View File

@ -37,6 +37,17 @@ func (rwh *ReadWriteHandlers) GetPingPeriod() time.Duration {
func (rwh *ReadWriteHandlers) IsEnableCompression() bool { func (rwh *ReadWriteHandlers) IsEnableCompression() bool {
return rwh.EnableCompression return rwh.EnableCompression
} }
func (rwh *ReadWriteHandlers) Clone() *ReadWriteHandlers {
return &ReadWriteHandlers{
ReadWriteHandlers: *rwh.ReadWriteHandlers.Clone(),
PongTimeout: rwh.PongTimeout,
PingTimeout: rwh.PingTimeout,
PingPeriod: rwh.PingPeriod,
EnableCompression: rwh.EnableCompression,
}
}
func (rwh *ReadWriteHandlers) Validate() error { func (rwh *ReadWriteHandlers) Validate() error {
if err := rwh.ReadWriteHandlers.Validate(); nil != err { if err := rwh.ReadWriteHandlers.Validate(); nil != err {
return err return err

View File

@ -25,12 +25,10 @@ import (
var errMalformedURL = errors.New("malformed ws or wss URL") var errMalformedURL = errors.New("malformed ws or wss URL")
type Connectors struct { type Connectors struct {
client.Connector client.Connectors
socket.ClientConnHandlers socket.ClientConnHandlers
socket.ReadWriteHandlers socket.ReadWriteHandlers
Name string `json:"name"`
URL string `json:"url"` URL string `json:"url"`
RequestHeader http.Header RequestHeader http.Header
@ -75,9 +73,7 @@ func (c *Connectors) Connect() (readChan <-chan []byte, writeChan chan<- []byte,
if c.stopChan != nil { if c.stopChan != nil {
return nil, nil, fmt.Errorf("%s already connected", c.logHeader()) return nil, nil, fmt.Errorf("%s already connected", c.logHeader())
} }
if err := c.Validate(); nil != err {
err = c.Validate()
if nil != err {
return nil, nil, err return nil, nil, err
} }
@ -140,14 +136,14 @@ RC_LOOP:
return return
} }
if 0 >= c.ReconnectTryTime { if 0 >= c.GetReconnectTryTime() {
c.reconnectedChan <- nil c.reconnectedChan <- nil
continue RC_LOOP continue RC_LOOP
} }
logging.Logger().Debugf("%s connection lost", c.logHeader()) logging.Logger().Debugf("%s connection lost", c.logHeader())
for indexI := 0; indexI < c.ReconnectTryTime; indexI++ { for indexI := 0; indexI < c.GetReconnectTryTime(); indexI++ {
logging.Logger().Debugf("%s trying reconnect[%d]", c.logHeader(), indexI) logging.Logger().Debugf("%s trying reconnect[%d]", c.logHeader(), indexI)
conn, res, err := c.connect() conn, res, err := c.connect()
@ -161,7 +157,7 @@ RC_LOOP:
c.reconnectedChan <- conn c.reconnectedChan <- conn
continue RC_LOOP continue RC_LOOP
} }
time.Sleep(c.ReconnectInterval) time.Sleep(c.GetReconnectInterval())
} }
logging.Logger().Debugf("%s reconnecting has been failed", c.logHeader()) logging.Logger().Debugf("%s reconnecting has been failed", c.logHeader())
} }
@ -187,10 +183,6 @@ func (c *Connectors) dial() (socket.Conn, *http.Response, error) {
netConn net.Conn netConn net.Conn
) )
if err = c.Validate(); nil != err {
return nil, nil, err
}
challengeKey, err = web.GenerateChallengeKey() challengeKey, err = web.GenerateChallengeKey()
if err != nil { if err != nil {
return nil, nil, err return nil, nil, err
@ -206,9 +198,10 @@ func (c *Connectors) dial() (socket.Conn, *http.Response, error) {
Host: c.serverURL.Host, Host: c.serverURL.Host,
} }
cookieJar := c.CookieJar
// Set the cookies present in the cookie jar of the dialer // Set the cookies present in the cookie jar of the dialer
if nil != c.CookieJar { if nil != cookieJar {
for _, cookie := range c.CookieJar.Cookies(c.serverURL) { for _, cookie := range cookieJar.Cookies(c.serverURL) {
req.AddCookie(cookie) req.AddCookie(cookie)
} }
} }
@ -221,8 +214,11 @@ func (c *Connectors) dial() (socket.Conn, *http.Response, error) {
req.Header["Connection"] = []string{"Upgrade"} req.Header["Connection"] = []string{"Upgrade"}
req.Header["Sec-WebSocket-Key"] = []string{challengeKey} req.Header["Sec-WebSocket-Key"] = []string{challengeKey}
req.Header["Sec-WebSocket-Version"] = []string{"13"} req.Header["Sec-WebSocket-Version"] = []string{"13"}
if len(c.Subprotocols) > 0 {
req.Header["Sec-WebSocket-Protocol"] = []string{strings.Join(c.Subprotocols, ", ")} subprotocols := c.Subprotocols
if len(subprotocols) > 0 {
req.Header["Sec-WebSocket-Protocol"] = []string{strings.Join(subprotocols, ", ")}
} }
for k, vs := range c.RequestHeader { for k, vs := range c.RequestHeader {
switch { switch {
@ -235,14 +231,14 @@ func (c *Connectors) dial() (socket.Conn, *http.Response, error) {
k == "Sec-Websocket-Key" || k == "Sec-Websocket-Key" ||
k == "Sec-Websocket-Version" || k == "Sec-Websocket-Version" ||
k == "Sec-Websocket-Extensions" || k == "Sec-Websocket-Extensions" ||
(k == "Sec-Websocket-Protocol" && len(c.Subprotocols) > 0): (k == "Sec-Websocket-Protocol" && len(subprotocols) > 0):
return nil, nil, fmt.Errorf("%s duplicate header not allowed: %s", c.logHeader(), k) return nil, nil, fmt.Errorf("%s duplicate header not allowed: %s", c.logHeader(), k)
default: default:
req.Header[k] = vs req.Header[k] = vs
} }
} }
if c.EnableCompression { if c.IsEnableCompression() {
req.Header.Set("Sec-Websocket-Extensions", "permessage-deflate; server_no_context_takeover; client_no_context_takeover") req.Header.Set("Sec-Websocket-Extensions", "permessage-deflate; server_no_context_takeover; client_no_context_takeover")
} }
@ -250,8 +246,9 @@ func (c *Connectors) dial() (socket.Conn, *http.Response, error) {
var proxyURL *url.URL var proxyURL *url.URL
// Check wether the proxy method has been configured // Check wether the proxy method has been configured
if nil != c.Proxy { proxy := c.Proxy
proxyURL, err = c.Proxy(req) if nil != proxy {
proxyURL, err = proxy(req)
if err != nil { if err != nil {
return nil, nil, err return nil, nil, err
} }
@ -265,8 +262,9 @@ func (c *Connectors) dial() (socket.Conn, *http.Response, error) {
} }
var deadline time.Time var deadline time.Time
if 0 != c.HandshakeTimeout { handshakeTimeout := c.GetHandshakeTimeout()
deadline = time.Now().Add(c.HandshakeTimeout) if 0 != handshakeTimeout {
deadline = time.Now().Add(handshakeTimeout)
} }
netDial := c.NetDial netDial := c.NetDial
@ -324,7 +322,7 @@ func (c *Connectors) dial() (socket.Conn, *http.Response, error) {
} }
if "https" == c.serverURL.Scheme { if "https" == c.serverURL.Scheme {
cfg := cloneTLSConfig(c.TLSConfig) cfg := cloneTLSConfig(c.GetTLSConfig())
if cfg.ServerName == "" { if cfg.ServerName == "" {
cfg.ServerName = hostNoPort cfg.ServerName = hostNoPort
} }
@ -340,7 +338,7 @@ func (c *Connectors) dial() (socket.Conn, *http.Response, error) {
} }
} }
conn := socket.NewConn(netConn, false, c.ReadBufferSize, c.WriteBufferSize) conn := socket.NewConn(netConn, false, c.GetReadBufferSize(), c.GetWriteBufferSize())
if err := req.Write(netConn); err != nil { if err := req.Write(netConn); err != nil {
return nil, nil, err return nil, nil, err
@ -351,9 +349,9 @@ func (c *Connectors) dial() (socket.Conn, *http.Response, error) {
return nil, nil, err return nil, nil, err
} }
if nil != c.CookieJar { if nil != cookieJar {
if rc := resp.Cookies(); len(rc) > 0 { if rc := resp.Cookies(); len(rc) > 0 {
c.CookieJar.SetCookies(c.serverURL, rc) cookieJar.SetCookies(c.serverURL, rc)
} }
} }
@ -393,48 +391,6 @@ func (c *Connectors) dial() (socket.Conn, *http.Response, error) {
return conn, resp, nil return conn, resp, nil
} }
func (c *Connectors) Validate() error {
if err := c.ClientConnHandlers.Validate(); nil != err {
return err
}
if err := c.ReadWriteHandlers.Validate(); nil != err {
return err
}
if "" == c.Name {
c.Name = "Connector"
}
if "" == c.URL {
return fmt.Errorf("%s URL is not valid", c.logHeader())
}
u, err := parseURL(c.URL)
if nil != err {
return err
}
switch u.Scheme {
case "ws":
u.Scheme = "http"
case "wss":
u.Scheme = "https"
default:
return errMalformedURL
}
if nil != u.User {
// User name and password are not allowed in websocket URIs.
return errMalformedURL
}
c.serverURL = u
if nil == c.Proxy {
c.Proxy = http.ProxyFromEnvironment
}
return nil
}
// parseURL parses the URL. // parseURL parses the URL.
// //
// This function is a replacement for the standard library url.Parse function. // This function is a replacement for the standard library url.Parse function.
@ -503,3 +459,60 @@ func cloneTLSConfig(cfg *tls.Config) *tls.Config {
} }
return cfg.Clone() return cfg.Clone()
} }
func (c *Connectors) Clone() *Connectors {
return &Connectors{
Connectors: *c.Connectors.Clone(),
ClientConnHandlers: *c.ClientConnHandlers.Clone(),
ReadWriteHandlers: *c.ReadWriteHandlers.Clone(),
URL: c.URL,
RequestHeader: c.RequestHeader,
Subprotocols: c.Subprotocols,
CookieJar: c.CookieJar,
ResponseHandler: c.ResponseHandler,
NetDial: c.NetDial,
Proxy: c.Proxy,
serverURL: c.serverURL,
}
}
func (c *Connectors) Validate() error {
if err := c.Connectors.Validate(); nil != err {
return err
}
if err := c.ClientConnHandlers.Validate(); nil != err {
return err
}
if err := c.ReadWriteHandlers.Validate(); nil != err {
return err
}
if "" == c.URL {
return fmt.Errorf("URL is not valid")
}
u, err := parseURL(c.URL)
if nil != err {
return err
}
switch u.Scheme {
case "ws":
u.Scheme = "http"
case "wss":
u.Scheme = "https"
default:
return errMalformedURL
}
if nil != u.User {
// User name and password are not allowed in websocket URIs.
return errMalformedURL
}
c.serverURL = u
if nil == c.Proxy {
c.Proxy = http.ProxyFromEnvironment
}
return nil
}

View File

@ -136,7 +136,7 @@ func getContextPath(path string) (string, error) {
p := strings.TrimSpace(path) p := strings.TrimSpace(path)
if !strings.HasPrefix(p, "/") { if !strings.HasPrefix(p, "/") {
return "", fmt.Errorf("The path[%s] must started /", path) return "", fmt.Errorf("path[%s] must started /", path)
} }
p = p[1:] p = p[1:]
@ -147,7 +147,7 @@ func getContextPath(path string) (string, error) {
components := strings.Split(p, "/") components := strings.Split(p, "/")
if 0 == len(components) { if 0 == len(components) {
return "", fmt.Errorf("The path[%s] is not invalid", path) return "", fmt.Errorf("path[%s] is not invalid", path)
} }
return fmt.Sprintf("/%s", components[0]), nil return fmt.Sprintf("/%s", components[0]), nil