This commit is contained in:
crusader 2017-12-01 12:42:03 +09:00
parent 0e01920fe9
commit 53a7e28c86
4 changed files with 38 additions and 20 deletions

View File

@ -1,6 +1,7 @@
package client
import (
"crypto/tls"
"net"
"sync"
@ -31,20 +32,30 @@ func NewSocket(sb SocketBuilder, parentContext cuc.Context) (Socket, error) {
}
sh.Validate()
d := &net.Dialer{}
d.Timeout = sb.GetTimeout()
d.KeepAlive = sb.GetKeepAlive()
d.LocalAddr = sb.GetLocalAddress()
network := sb.GetNetwork()
address := sb.GetAddress()
conn, err := sb.Dial(d, network, address)
conn, err := sb.Dial(network, address)
if nil != err {
return nil, err
}
tlsConfig := sb.GetTLSConfig()
if nil != tlsConfig {
cfg := tlsConfig.Clone()
tlsConn := tls.Client(conn, cfg)
if err := tlsConn.Handshake(); err != nil {
tlsConn.Close()
return nil, err
}
if !cfg.InsecureSkipVerify {
if err := tlsConn.VerifyHostname(cfg.ServerName); err != nil {
return nil, err
}
}
conn = tlsConn
}
sh.OnConnect(sc, conn)
s := retainSocket()

View File

@ -12,14 +12,14 @@ type SocketBuilder interface {
SocketContext(parent cuc.Context) SocketContext
SocketHandler() SocketHandler
Dial(dialer *net.Dialer, network, address string) (net.Conn, error)
Dial(network, address string) (net.Conn, error)
GetNetwork() string
GetAddress() string
GetTLSConfig() *tls.Config
GetHandshakeTimeout() time.Duration
GetKeepAlive() time.Duration
GetTimeout() time.Duration
GetLocalAddress() net.Addr
// Validate is check handler value

View File

@ -15,8 +15,8 @@ type SocketBuilders struct {
Address string
TLSConfig *tls.Config
HandshakeTimeout time.Duration
KeepAlive time.Duration
Timeout time.Duration
LocalAddress net.Addr
}
@ -41,20 +41,27 @@ func (sb *SocketBuilders) GetTLSConfig() *tls.Config {
return sb.TLSConfig
}
func (sb *SocketBuilders) Dial(dialer *net.Dialer, network, address string) (net.Conn, error) {
if nil == sb.TLSConfig {
return dialer.Dial(network, address)
func (sb *SocketBuilders) Dial(network, address string) (net.Conn, error) {
var deadline time.Time
if 0 != sb.HandshakeTimeout {
deadline = time.Now().Add(sb.HandshakeTimeout)
}
return tls.DialWithDialer(dialer, network, address, sb.TLSConfig)
d := &net.Dialer{
KeepAlive: sb.KeepAlive,
Deadline: deadline,
LocalAddr: sb.LocalAddress,
}
return d.Dial(network, address)
}
func (sb *SocketBuilders) GetKeepAlive() time.Duration {
return sb.KeepAlive
}
func (sb *SocketBuilders) GetTimeout() time.Duration {
return sb.Timeout
func (sb *SocketBuilders) GetHandshakeTimeout() time.Duration {
return sb.HandshakeTimeout
}
func (sb *SocketBuilders) GetLocalAddress() net.Addr {
@ -72,8 +79,8 @@ func (sb *SocketBuilders) Validate() {
if 0 >= sb.KeepAlive {
sb.KeepAlive = cs.DefaultKeepAlive
}
if 0 >= sb.Timeout {
sb.Timeout = cs.DefaultConnectTimeout
if 0 >= sb.HandshakeTimeout {
sb.HandshakeTimeout = cs.DefaultHandshakeTimeout
}
}

View File

@ -31,7 +31,7 @@ const (
// DefaultWriteTimeout is default value of write timeout
DefaultWriteTimeout = 0
DefaultConnectTimeout = 0
DefaultHandshakeTimeout = 0
DefaultKeepAlive = 0
)