diff --git a/client/socket.go b/client/socket.go index 9bb1595..d949ac9 100644 --- a/client/socket.go +++ b/client/socket.go @@ -1,6 +1,7 @@ package client import ( + "crypto/tls" "net" "sync" @@ -31,20 +32,30 @@ func NewSocket(sb SocketBuilder, parentContext cuc.Context) (Socket, error) { } sh.Validate() - d := &net.Dialer{} - d.Timeout = sb.GetTimeout() - d.KeepAlive = sb.GetKeepAlive() - d.LocalAddr = sb.GetLocalAddress() - network := sb.GetNetwork() address := sb.GetAddress() - conn, err := sb.Dial(d, network, address) - + conn, err := sb.Dial(network, address) if nil != err { return nil, err } + tlsConfig := sb.GetTLSConfig() + if nil != tlsConfig { + cfg := tlsConfig.Clone() + tlsConn := tls.Client(conn, cfg) + if err := tlsConn.Handshake(); err != nil { + tlsConn.Close() + return nil, err + } + if !cfg.InsecureSkipVerify { + if err := tlsConn.VerifyHostname(cfg.ServerName); err != nil { + return nil, err + } + } + conn = tlsConn + } + sh.OnConnect(sc, conn) s := retainSocket() diff --git a/client/socket_builder.go b/client/socket_builder.go index 148db21..593b1c4 100644 --- a/client/socket_builder.go +++ b/client/socket_builder.go @@ -12,14 +12,14 @@ type SocketBuilder interface { SocketContext(parent cuc.Context) SocketContext SocketHandler() SocketHandler - Dial(dialer *net.Dialer, network, address string) (net.Conn, error) + Dial(network, address string) (net.Conn, error) GetNetwork() string GetAddress() string GetTLSConfig() *tls.Config + GetHandshakeTimeout() time.Duration GetKeepAlive() time.Duration - GetTimeout() time.Duration GetLocalAddress() net.Addr // Validate is check handler value diff --git a/client/socket_builders.go b/client/socket_builders.go index 12c53f1..e497b33 100644 --- a/client/socket_builders.go +++ b/client/socket_builders.go @@ -15,8 +15,8 @@ type SocketBuilders struct { Address string TLSConfig *tls.Config - KeepAlive time.Duration - Timeout time.Duration + HandshakeTimeout time.Duration + KeepAlive time.Duration LocalAddress net.Addr } @@ -41,20 +41,27 @@ func (sb *SocketBuilders) GetTLSConfig() *tls.Config { return sb.TLSConfig } -func (sb *SocketBuilders) Dial(dialer *net.Dialer, network, address string) (net.Conn, error) { - if nil == sb.TLSConfig { - return dialer.Dial(network, address) +func (sb *SocketBuilders) Dial(network, address string) (net.Conn, error) { + var deadline time.Time + if 0 != sb.HandshakeTimeout { + deadline = time.Now().Add(sb.HandshakeTimeout) } - return tls.DialWithDialer(dialer, network, address, sb.TLSConfig) + d := &net.Dialer{ + KeepAlive: sb.KeepAlive, + Deadline: deadline, + LocalAddr: sb.LocalAddress, + } + + return d.Dial(network, address) } func (sb *SocketBuilders) GetKeepAlive() time.Duration { return sb.KeepAlive } -func (sb *SocketBuilders) GetTimeout() time.Duration { - return sb.Timeout +func (sb *SocketBuilders) GetHandshakeTimeout() time.Duration { + return sb.HandshakeTimeout } func (sb *SocketBuilders) GetLocalAddress() net.Addr { @@ -72,8 +79,8 @@ func (sb *SocketBuilders) Validate() { if 0 >= sb.KeepAlive { sb.KeepAlive = cs.DefaultKeepAlive } - if 0 >= sb.Timeout { - sb.Timeout = cs.DefaultConnectTimeout + if 0 >= sb.HandshakeTimeout { + sb.HandshakeTimeout = cs.DefaultHandshakeTimeout } } diff --git a/constants.go b/constants.go index 10422a6..8d93bb8 100644 --- a/constants.go +++ b/constants.go @@ -31,7 +31,7 @@ const ( // DefaultWriteTimeout is default value of write timeout DefaultWriteTimeout = 0 - DefaultConnectTimeout = 0 + DefaultHandshakeTimeout = 0 DefaultKeepAlive = 0 )