diff --git a/client/socket.go b/client/socket.go index 71ec183..d6865fb 100644 --- a/client/socket.go +++ b/client/socket.go @@ -3,11 +3,13 @@ package client import ( "io" "net" + "net/http" "sync" "time" + "git.loafle.net/commons_go/logging" + "git.loafle.net/commons_go/websocket_fasthttp/websocket" - "github.com/valyala/fasthttp" ) type Socket interface { @@ -129,28 +131,58 @@ type Socket interface { // compression levels. SetCompressionLevel(level int) error - // SetHeaders sets response headers - SetHeaders(h *fasthttp.ResponseHeader) - // Header returns header by key Header(key string) (value string) // Headers returns the RequestHeader struct - Headers() *fasthttp.ResponseHeader + Headers() http.Header } -func newSocket(socketHandler SocketHandler, socketCTX SocketContext, conn *websocket.Conn, id string) Socket { - s := retainSocket() - s.Conn = conn - s.sh = socketHandler - s.ctx = socketCTX +func NewSocket(sb SocketBuilder, parentContext cuc.Context) (Socket, error) { + if nil == sb { + logging.Logger().Panic("Client Socket: SocketBuilder must be specified") + } + sb.Validate() - s.SetReadLimit(socketHandler.GetMaxMessageSize()) - if 0 < socketHandler.GetReadTimeout() { - s.SetReadDeadline(time.Now().Add(socketHandler.GetReadTimeout() * time.Second)) + sc := sb.SocketContext(parentContext) + if nil == sc { + logging.Logger().Panic("Client Socket: SocketContext must be specified") } - return s + sh := sb.SocketHandler() + if nil == sh { + logging.Logger().Panic("Client Socket: SocketHandler must be specified") + } + sh.Validate() + + d := &websocket.Dialer{} + d.NetDial = sb.Dial + d.Proxy = sb.UseProxy + d.TLSClientConfig = sb.GetTLSConfig() + d.HandshakeTimeout = sb.GetHandshakeTimeout() + d.ReadBufferSize = sb.GetReadBufferSize() + d.WriteBufferSize = sb.GetWriteBufferSize() + d.Subprotocols = sb.GetSubProtocols() + d.EnableCompression = sb.IsEnableCompression() + d.Jar = sb.GetRequestCookie() + + url := sb.GetURL() + reqHeader := sb.GetRequestHeader() + + conn, res, err := d.Dial(url.String(), reqHeader) + if nil != err { + return nil, err + } + + sh.OnConnect(sc, res) + + s := retainSocket() + s.Conn = conn + s.ctx = sc + s.sh = sh + s.resHeader = res.Header + + return s, nil } type fasthttpWebSocket struct { @@ -159,7 +191,7 @@ type fasthttpWebSocket struct { ctx SocketContext sh SocketHandler - resHeaders *fasthttp.ResponseHeader + resHeader http.Header } func (s *fasthttpWebSocket) Context() SocketContext { @@ -182,23 +214,20 @@ func (s *fasthttpWebSocket) WriteMessage(messageType int, data []byte) error { return s.Conn.WriteMessage(messageType, data) } -func (s *fasthttpWebSocket) SetHeaders(h *fasthttp.ResponseHeader) { - s.resHeaders = h -} - func (s *fasthttpWebSocket) Header(key string) (value string) { - if nil == s.resHeaders { + if nil == s.resHeader { return "" } - return string(s.resHeaders.Peek(key)) + return s.resHeader.Get(key) } -func (s *fasthttpWebSocket) Headers() *fasthttp.ResponseHeader { - return s.resHeaders +func (s *fasthttpWebSocket) Headers() http.Header { + return s.resHeader } func (s *fasthttpWebSocket) Close() error { err := s.Conn.Close() + s.sh.OnDisconnect(s) releaseSocket(s) return err } diff --git a/client/socket_builder.go b/client/socket_builder.go new file mode 100644 index 0000000..357267f --- /dev/null +++ b/client/socket_builder.go @@ -0,0 +1,37 @@ +package client + +import ( + "crypto/tls" + "net" + "net/http" + "net/url" + "time" + + cuc "git.loafle.net/commons_go/util/context" +) + +type SocketBuilder interface { + SocketContext(parent cuc.Context) SocketContext + SocketHandler() SocketHandler + + GetURL() *url.URL + GetRequestCookie() http.CookieJar + GetRequestHeader() http.Header + GetSubProtocols() []string + IsEnableCompression() bool + UseProxy(req *http.Request) (*url.URL, error) + GetHandshakeTimeout() time.Duration + Dial(network, addr string) (net.Conn, error) + GetTLSConfig() *tls.Config + GetReadBufferSize() int + GetWriteBufferSize() int + + // Validate is check handler value + // If you override ths method, must call + // + // func (sh *SocketHandlers) Validate() { + // sh.SocketHandlers.Validate() + // ... + // } + Validate() +} diff --git a/client/socket_builders.go b/client/socket_builders.go new file mode 100644 index 0000000..1908dd4 --- /dev/null +++ b/client/socket_builders.go @@ -0,0 +1,101 @@ +package client + +import ( + "crypto/tls" + "net" + "net/http" + "net/url" + "time" + + "git.loafle.net/commons_go/logging" + cuc "git.loafle.net/commons_go/util/context" + cwf "git.loafle.net/commons_go/websocket_fasthttp" +) + +type SocketBuilders struct { + URL *url.URL + RequestCookie http.CookieJar + RequestHeader http.Header + SubProtocols []string + EnableCompression bool + HandshakeTimeout time.Duration + TLSConfig *tls.Config + ReadBufferSize int + WriteBufferSize int +} + +func (sb *SocketBuilders) SocketContext(parent cuc.Context) SocketContext { + return newSocketContext(parent) +} + +func (sb *SocketBuilders) SocketHandler() SocketHandler { + return NewSocketHandler() +} + +func (sb *SocketBuilders) UseProxy(req *http.Request) (*url.URL, error) { + return http.ProxyFromEnvironment(req) +} + +func (sb *SocketBuilders) Dial(network, addr string) (net.Conn, error) { + var deadline time.Time + if 0 != sb.HandshakeTimeout { + deadline = time.Now().Add(sb.HandshakeTimeout) + } + + netDialer := &net.Dialer{Deadline: deadline} + + return netDialer.Dial(network, addr) +} + +func (sb *SocketBuilders) GetHandshakeTimeout() time.Duration { + return sb.HandshakeTimeout +} + +func (sb *SocketBuilders) GetURL() *url.URL { + return sb.URL +} + +func (sb *SocketBuilders) GetRequestCookie() http.CookieJar { + return sb.RequestCookie +} + +func (sb *SocketBuilders) GetRequestHeader() http.Header { + return sb.RequestHeader +} + +func (sb *SocketBuilders) GetSubProtocols() []string { + return sb.SubProtocols +} + +func (sb *SocketBuilders) IsEnableCompression() bool { + return sb.EnableCompression +} + +func (sb *SocketBuilders) GetTLSConfig() *tls.Config { + return sb.TLSConfig +} + +func (sb *SocketBuilders) GetReadBufferSize() int { + return sb.ReadBufferSize +} + +func (sb *SocketBuilders) GetWriteBufferSize() int { + return sb.WriteBufferSize +} + +func (sb *SocketBuilders) Validate() { + if nil == sb.URL { + logging.Logger().Panic("Client Socket: URL must be specified") + } + + if 0 >= sb.HandshakeTimeout { + sb.HandshakeTimeout = cwf.DefaultHandshakeTimeout + } + if 0 >= sb.ReadBufferSize { + sb.ReadBufferSize = cwf.DefaultReadBufferSize + } + if 0 >= sb.WriteBufferSize { + sb.WriteBufferSize = cwf.DefaultWriteBufferSize + } + +} diff --git a/client/socket_handler.go b/client/socket_handler.go index 7195c63..c48ab97 100644 --- a/client/socket_handler.go +++ b/client/socket_handler.go @@ -1,29 +1,20 @@ package client import ( - "crypto/tls" - "net" "net/http" - "net/url" "time" - - cuc "git.loafle.net/commons_go/util/context" ) type SocketHandler interface { - SocketContext(parent cuc.Context) SocketContext + OnConnect(socketContext SocketContext, res *http.Response) + OnDisconnect(soc Socket) - GetURL() *url.URL - GetRequestCookie() http.CookieJar - GetRequestHeader() http.Header - GetSubProtocols() []string - EnableCompression() bool - UseProxy(req *http.Request) (*url.URL, error) - GetHandshakeTimeout() time.Duration - Dial(network, addr string) (net.Conn, error) - GetTLSConfig() *tls.Config - GetReadBufferSize() int - GetWriteBufferSize() int + GetMaxMessageSize() int64 + GetWriteTimeout() time.Duration + GetReadTimeout() time.Duration + GetPongTimeout() time.Duration + GetPingTimeout() time.Duration + GetPingPeriod() time.Duration // Validate is check handler value // If you override ths method, must call diff --git a/client/socket_handlers.go b/client/socket_handlers.go index 64fcc86..5d0dc81 100644 --- a/client/socket_handlers.go +++ b/client/socket_handlers.go @@ -1,75 +1,81 @@ package client import ( - "crypto/tls" - "net" "net/http" - "net/url" "time" - cuc "git.loafle.net/commons_go/util/context" + cwf "git.loafle.net/commons_go/websocket_fasthttp" ) type SocketHandlers struct { - URL *url.URL - RequestCookie http.CookieJar - RequestHeader http.Header - SubProtocols []string - EnableCompression bool - HandshakeTimeout time.Duration - TLSConfig *tls.Config - ReadBufferSize int - WriteBufferSize int + // MaxMessageSize is the maximum size for a message read from the peer. If a + // message exceeds the limit, the connection sends a close frame to the peer + // and returns ErrReadLimit to the application. + MaxMessageSize int64 + // WriteTimeout is the write deadline on the underlying network + // connection. After a write has timed out, the websocket state is corrupt and + // all future writes will return an error. A zero value for t means writes will + // not time out. + WriteTimeout time.Duration + // ReadTimeout is the read deadline on the underlying network connection. + // After a read has timed out, the websocket connection state is corrupt and + // all future reads will return an error. A zero value for t means reads will + // not time out. + ReadTimeout time.Duration + + PongTimeout time.Duration + PingTimeout time.Duration + PingPeriod time.Duration } -func (sh *SocketHandlers) SocketContext(parent cuc.Context) SocketContext { - return newSocketContext(parent) +func (sh *SocketHandlers) OnConnect(socketContext SocketContext, res *http.Response) { + // no op } -func (sh *SocketHandlers) GetURL() *url.URL { - return sh.URL +func (sh *SocketHandlers) OnDisconnect(soc Socket) { + // no op } -func (sh *SocketHandlers) GetRequestCookie() http.CookieJar { - return sh.RequestCookie +func (sh *SocketHandlers) GetMaxMessageSize() int64 { + return sh.MaxMessageSize } - -func (sh *SocketHandlers) GetRequestHeader() http.Header { - return sh.RequestHeader +func (sh *SocketHandlers) GetWriteTimeout() time.Duration { + return sh.WriteTimeout } - -func (sh *SocketHandlers) GetSubProtocols() []string { - return sh.SubProtocols +func (sh *SocketHandlers) GetReadTimeout() time.Duration { + return sh.ReadTimeout } - -func (sh *SocketHandlers) EnableCompression() bool { - return sh.EnableCompression +func (sh *SocketHandlers) GetPongTimeout() time.Duration { + return sh.PongTimeout } - -func (sh *SocketHandlers) UseProxy(req *http.Request) (*url.URL, error) { - return nil, nil +func (sh *SocketHandlers) GetPingTimeout() time.Duration { + return sh.PingTimeout } - -func (sh *SocketHandlers) GetHandshakeTimeout() time.Duration { - return sh.HandshakeTimeout -} - -func (sh *SocketHandlers) Dial(network, addr string) (net.Conn, error) { - return nil, nil -} - -func (sh *SocketHandlers) GetTLSConfig() *tls.Config { - return sh.TLSConfig -} - -func (sh *SocketHandlers) GetReadBufferSize() int { - return sh.ReadBufferSize -} - -func (sh *SocketHandlers) GetWriteBufferSize() int { - return sh.WriteBufferSize +func (sh *SocketHandlers) GetPingPeriod() time.Duration { + return sh.PingPeriod } func (sh *SocketHandlers) Validate() { - + if sh.MaxMessageSize <= 0 { + sh.MaxMessageSize = cwf.DefaultMaxMessageSize + } + if sh.WriteTimeout <= 0 { + sh.WriteTimeout = cwf.DefaultWriteTimeout + } + if sh.ReadTimeout <= 0 { + sh.ReadTimeout = cwf.DefaultReadTimeout + } + if sh.PongTimeout <= 0 { + sh.PongTimeout = cwf.DefaultPongTimeout + } + if sh.PingTimeout <= 0 { + sh.PingTimeout = cwf.DefaultPingTimeout + } + if sh.PingPeriod <= 0 { + sh.PingPeriod = cwf.DefaultPingPeriod + } +} + +func NewSocketHandler() SocketHandler { + return &SocketHandlers{} } diff --git a/constants.go b/constants.go index fef84e7..996e9f9 100644 --- a/constants.go +++ b/constants.go @@ -10,9 +10,9 @@ const ( // DefaultHandshakeTimeout is default value of websocket handshake Timeout DefaultHandshakeTimeout = 0 // DefaultReadBufferSize is default value of Read Buffer Size - DefaultReadBufferSize = 4096 + DefaultReadBufferSize = 0 // DefaultWriteBufferSize is default value of Write Buffer Size - DefaultWriteBufferSize = 4096 + DefaultWriteBufferSize = 0 // DefaultReadTimeout is default value of read timeout DefaultReadTimeout = 0 // DefaultWriteTimeout is default value of write timeout