This commit is contained in:
crusader 2018-04-12 10:23:40 +09:00
parent 05c5f95a40
commit 89a2ffcef0
5 changed files with 55 additions and 50 deletions

View File

@ -1,6 +1,8 @@
package socket package client
type Client interface { type Connector interface {
Connect() (readChan <-chan []byte, writeChan chan<- []byte, err error) Connect() (readChan <-chan []byte, writeChan chan<- []byte, err error)
Disconnect() error Disconnect() error
Validate() error
} }

View File

@ -1,4 +1,4 @@
package net package client
import ( import (
"crypto/tls" "crypto/tls"
@ -8,11 +8,12 @@ import (
"time" "time"
"git.loafle.net/commons/logging-go" "git.loafle.net/commons/logging-go"
"git.loafle.net/commons/server-go/client"
"git.loafle.net/commons/server-go/socket" "git.loafle.net/commons/server-go/socket"
) )
type Client struct { type Connectors struct {
socket.Client client.Connector
socket.ClientConnHandlers socket.ClientConnHandlers
socket.ReadWriteHandlers socket.ReadWriteHandlers
@ -35,13 +36,13 @@ type Client struct {
crw socket.ClientReadWriter crw socket.ClientReadWriter
} }
func (c *Client) Connect() (readChan <-chan []byte, writeChan chan<- []byte, err error) { func (c *Connectors) Connect() (readChan <-chan []byte, writeChan chan<- []byte, err error) {
var ( var (
conn socket.Conn conn socket.Conn
) )
if c.stopChan != nil { if c.stopChan != nil {
return nil, nil, fmt.Errorf(c.clientMessage("already running. Stop it before starting it again")) return nil, nil, fmt.Errorf("%s already connected", c.logHeader())
} }
err = c.Validate() err = c.Validate()
@ -75,9 +76,9 @@ func (c *Client) Connect() (readChan <-chan []byte, writeChan chan<- []byte, err
return c.readChan, c.writeChan, nil return c.readChan, c.writeChan, nil
} }
func (c *Client) Disconnect() error { func (c *Connectors) Disconnect() error {
if c.stopChan == nil { if c.stopChan == nil {
return fmt.Errorf(c.clientMessage("must be started before stopping it")) return fmt.Errorf("%s must be connected before disconnection it", c.logHeader())
} }
close(c.stopChan) close(c.stopChan)
c.stopWg.Wait() c.stopWg.Wait()
@ -87,11 +88,11 @@ func (c *Client) Disconnect() error {
return nil return nil
} }
func (c *Client) clientMessage(msg string) string { func (c *Connectors) logHeader() string {
return fmt.Sprintf("Client[%s]: %s", c.Name, msg) return fmt.Sprintf("Connector[%s]: ", c.Name)
} }
func (c *Client) handleReconnect() { func (c *Connectors) handleReconnect() {
defer func() { defer func() {
c.stopWg.Done() c.stopWg.Done()
}() }()
@ -109,10 +110,10 @@ RC_LOOP:
continue RC_LOOP continue RC_LOOP
} }
logging.Logger().Debugf("connection lost") logging.Logger().Debugf("%s connection lost", c.logHeader())
for indexI := 0; indexI < c.ReconnectTryTime; indexI++ { for indexI := 0; indexI < c.ReconnectTryTime; indexI++ {
logging.Logger().Debugf("trying reconnect[%d]", indexI) logging.Logger().Debugf("%s trying reconnect[%d]", c.logHeader(), indexI)
conn, err := c.connect() conn, err := c.connect()
if nil == err { if nil == err {
@ -122,11 +123,11 @@ RC_LOOP:
} }
time.Sleep(c.ReconnectInterval) time.Sleep(c.ReconnectInterval)
} }
logging.Logger().Debugf("reconnecting has been failed") logging.Logger().Debugf("%s reconnecting has been failed", c.logHeader())
} }
} }
func (c *Client) connect() (socket.Conn, error) { func (c *Connectors) connect() (socket.Conn, error) {
netConn, err := c.dial() netConn, err := c.dial()
if nil != err { if nil != err {
return nil, err return nil, err
@ -134,13 +135,13 @@ func (c *Client) connect() (socket.Conn, error) {
conn := socket.NewConn(netConn, false, c.ReadBufferSize, c.WriteBufferSize) conn := socket.NewConn(netConn, false, c.ReadBufferSize, c.WriteBufferSize)
conn.SetCloseHandler(func(code int, text string) error { conn.SetCloseHandler(func(code int, text string) error {
logging.Logger().Debugf("close") logging.Logger().Debugf("%s close", c.logHeader())
return nil return nil
}) })
return conn, nil return conn, nil
} }
func (c *Client) dial() (net.Conn, error) { func (c *Connectors) dial() (net.Conn, error) {
if err := c.Validate(); nil != err { if err := c.Validate(); nil != err {
return nil, err return nil, err
} }
@ -179,7 +180,7 @@ func (c *Client) dial() (net.Conn, error) {
return conn, nil return conn, nil
} }
func (c *Client) Validate() error { func (c *Connectors) Validate() error {
if err := c.ClientConnHandlers.Validate(); nil != err { if err := c.ClientConnHandlers.Validate(); nil != err {
return err return err
} }
@ -188,15 +189,15 @@ func (c *Client) Validate() error {
} }
if "" == c.Name { if "" == c.Name {
c.Name = "Client" c.Name = "Connector"
} }
if "" == c.Network { if "" == c.Network {
return fmt.Errorf("Client: Network is not valid") return fmt.Errorf("%s Network is not valid", c.logHeader())
} }
if "" == c.Address { if "" == c.Address {
return fmt.Errorf("Client: Address is not valid") return fmt.Errorf("%s Address is not valid", c.logHeader())
} }
return nil return nil

View File

@ -1,4 +1,4 @@
package web package client
import ( import (
"bufio" "bufio"
@ -17,13 +17,15 @@ import (
"time" "time"
"git.loafle.net/commons/logging-go" "git.loafle.net/commons/logging-go"
"git.loafle.net/commons/server-go/client"
"git.loafle.net/commons/server-go/socket" "git.loafle.net/commons/server-go/socket"
"git.loafle.net/commons/server-go/socket/web"
) )
var errMalformedURL = errors.New("malformed ws or wss URL") var errMalformedURL = errors.New("malformed ws or wss URL")
type Client struct { type Connectors struct {
socket.Client client.Connector
socket.ClientConnHandlers socket.ClientConnHandlers
socket.ReadWriteHandlers socket.ReadWriteHandlers
@ -64,14 +66,14 @@ type Client struct {
crw socket.ClientReadWriter crw socket.ClientReadWriter
} }
func (c *Client) Connect() (readChan <-chan []byte, writeChan chan<- []byte, err error) { func (c *Connectors) Connect() (readChan <-chan []byte, writeChan chan<- []byte, err error) {
var ( var (
conn socket.Conn conn socket.Conn
res *http.Response res *http.Response
) )
if c.stopChan != nil { if c.stopChan != nil {
return nil, nil, fmt.Errorf(c.clientMessage("already running. Stop it before starting it again")) return nil, nil, fmt.Errorf("%s already connected", c.logHeader())
} }
err = c.Validate() err = c.Validate()
@ -109,9 +111,9 @@ func (c *Client) Connect() (readChan <-chan []byte, writeChan chan<- []byte, err
return c.readChan, c.writeChan, nil return c.readChan, c.writeChan, nil
} }
func (c *Client) Disconnect() error { func (c *Connectors) Disconnect() error {
if c.stopChan == nil { if c.stopChan == nil {
return fmt.Errorf(c.clientMessage("must be started before stopping it")) return fmt.Errorf("%s must be connected before disconnection it", c.logHeader())
} }
close(c.stopChan) close(c.stopChan)
c.stopWg.Wait() c.stopWg.Wait()
@ -121,11 +123,11 @@ func (c *Client) Disconnect() error {
return nil return nil
} }
func (c *Client) clientMessage(msg string) string { func (c *Connectors) logHeader() string {
return fmt.Sprintf("Client[%s]: %s", c.Name, msg) return fmt.Sprintf("Connector[%s]:", c.Name)
} }
func (c *Client) handleReconnect() { func (c *Connectors) handleReconnect() {
defer func() { defer func() {
c.stopWg.Done() c.stopWg.Done()
}() }()
@ -143,10 +145,10 @@ RC_LOOP:
continue RC_LOOP continue RC_LOOP
} }
logging.Logger().Debugf("connection lost") logging.Logger().Debugf("%s connection lost", c.logHeader())
for indexI := 0; indexI < c.ReconnectTryTime; indexI++ { for indexI := 0; indexI < c.ReconnectTryTime; indexI++ {
logging.Logger().Debugf("trying reconnect[%d]", indexI) logging.Logger().Debugf("%s trying reconnect[%d]", c.logHeader(), indexI)
conn, res, err := c.connect() conn, res, err := c.connect()
if nil == err { if nil == err {
@ -155,30 +157,30 @@ RC_LOOP:
resH(res) resH(res)
} }
logging.Logger().Debugf("reconnected") logging.Logger().Debugf("%s reconnected", c.logHeader())
c.reconnectedChan <- conn c.reconnectedChan <- conn
continue RC_LOOP continue RC_LOOP
} }
time.Sleep(c.ReconnectInterval) time.Sleep(c.ReconnectInterval)
} }
logging.Logger().Debugf("reconnecting has been failed") logging.Logger().Debugf("%s reconnecting has been failed", c.logHeader())
} }
} }
func (c *Client) connect() (socket.Conn, *http.Response, error) { func (c *Connectors) connect() (socket.Conn, *http.Response, error) {
conn, res, err := c.dial() conn, res, err := c.dial()
if nil != err { if nil != err {
return nil, nil, err return nil, nil, err
} }
conn.SetCloseHandler(func(code int, text string) error { conn.SetCloseHandler(func(code int, text string) error {
logging.Logger().Debugf("close") logging.Logger().Debugf("%s close", c.logHeader())
return nil return nil
}) })
return conn, res, nil return conn, res, nil
} }
func (c *Client) dial() (socket.Conn, *http.Response, error) { func (c *Connectors) dial() (socket.Conn, *http.Response, error) {
var ( var (
err error err error
challengeKey string challengeKey string
@ -189,7 +191,7 @@ func (c *Client) dial() (socket.Conn, *http.Response, error) {
return nil, nil, err return nil, nil, err
} }
challengeKey, err = generateChallengeKey() challengeKey, err = web.GenerateChallengeKey()
if err != nil { if err != nil {
return nil, nil, err return nil, nil, err
} }
@ -234,7 +236,7 @@ func (c *Client) dial() (socket.Conn, *http.Response, error) {
k == "Sec-Websocket-Version" || k == "Sec-Websocket-Version" ||
k == "Sec-Websocket-Extensions" || k == "Sec-Websocket-Extensions" ||
(k == "Sec-Websocket-Protocol" && len(c.Subprotocols) > 0): (k == "Sec-Websocket-Protocol" && len(c.Subprotocols) > 0):
return nil, nil, errors.New("websocket: duplicate header not allowed: " + k) return nil, nil, fmt.Errorf("%s duplicate header not allowed: %s", c.logHeader(), k)
default: default:
req.Header[k] = vs req.Header[k] = vs
} }
@ -358,7 +360,7 @@ func (c *Client) dial() (socket.Conn, *http.Response, error) {
if resp.StatusCode != 101 || if resp.StatusCode != 101 ||
!strings.EqualFold(resp.Header.Get("Upgrade"), "websocket") || !strings.EqualFold(resp.Header.Get("Upgrade"), "websocket") ||
!strings.EqualFold(resp.Header.Get("Connection"), "upgrade") || !strings.EqualFold(resp.Header.Get("Connection"), "upgrade") ||
resp.Header.Get("Sec-Websocket-Accept") != computeAcceptKey(challengeKey) { resp.Header.Get("Sec-Websocket-Accept") != web.ComputeAcceptKey(challengeKey) {
// Before closing the network connection on return from this // Before closing the network connection on return from this
// function, slurp up some of the response to aid application // function, slurp up some of the response to aid application
// debugging. // debugging.
@ -368,7 +370,7 @@ func (c *Client) dial() (socket.Conn, *http.Response, error) {
return nil, resp, socket.ErrBadHandshake return nil, resp, socket.ErrBadHandshake
} }
for _, ext := range httpParseExtensions(resp.Header) { for _, ext := range web.HttpParseExtensions(resp.Header) {
if ext[""] != "permessage-deflate" { if ext[""] != "permessage-deflate" {
continue continue
} }
@ -391,7 +393,7 @@ func (c *Client) dial() (socket.Conn, *http.Response, error) {
return conn, resp, nil return conn, resp, nil
} }
func (c *Client) Validate() error { func (c *Connectors) Validate() error {
if err := c.ClientConnHandlers.Validate(); nil != err { if err := c.ClientConnHandlers.Validate(); nil != err {
return err return err
} }
@ -400,11 +402,11 @@ func (c *Client) Validate() error {
} }
if "" == c.Name { if "" == c.Name {
c.Name = "Client" c.Name = "Connector"
} }
if "" == c.URL { if "" == c.URL {
return fmt.Errorf("Client: URL is not valid") return fmt.Errorf("%s URL is not valid", c.logHeader())
} }
u, err := parseURL(c.URL) u, err := parseURL(c.URL)

View File

@ -168,7 +168,7 @@ func (u *Upgrader) Upgrade(ctx *fasthttp.RequestCtx, responseHeader *fasthttp.Re
ctx.SetStatusCode(fasthttp.StatusSwitchingProtocols) ctx.SetStatusCode(fasthttp.StatusSwitchingProtocols)
ctx.Response.Header.Set("Upgrade", "websocket") ctx.Response.Header.Set("Upgrade", "websocket")
ctx.Response.Header.Set("Connection", "Upgrade") ctx.Response.Header.Set("Connection", "Upgrade")
ctx.Response.Header.Set("Sec-Websocket-Accept", computeAcceptKey(challengeKey)) ctx.Response.Header.Set("Sec-Websocket-Accept", ComputeAcceptKey(challengeKey))
if subprotocol != "" { if subprotocol != "" {
ctx.Response.Header.Set("Sec-Websocket-Protocol", subprotocol) ctx.Response.Header.Set("Sec-Websocket-Protocol", subprotocol)
} }

View File

@ -17,14 +17,14 @@ import (
var keyGUID = []byte("258EAFA5-E914-47DA-95CA-C5AB0DC85B11") var keyGUID = []byte("258EAFA5-E914-47DA-95CA-C5AB0DC85B11")
func computeAcceptKey(challengeKey string) string { func ComputeAcceptKey(challengeKey string) string {
h := sha1.New() h := sha1.New()
h.Write([]byte(challengeKey)) h.Write([]byte(challengeKey))
h.Write(keyGUID) h.Write(keyGUID)
return base64.StdEncoding.EncodeToString(h.Sum(nil)) return base64.StdEncoding.EncodeToString(h.Sum(nil))
} }
func generateChallengeKey() (string, error) { func GenerateChallengeKey() (string, error) {
p := make([]byte, 16) p := make([]byte, 16)
if _, err := io.ReadFull(rand.Reader, p); err != nil { if _, err := io.ReadFull(rand.Reader, p); err != nil {
return "", err return "", err
@ -213,7 +213,7 @@ headers:
} }
// parseExtensiosn parses WebSocket extensions from a header. // parseExtensiosn parses WebSocket extensions from a header.
func httpParseExtensions(header http.Header) []map[string]string { func HttpParseExtensions(header http.Header) []map[string]string {
// From RFC 6455: // From RFC 6455:
// //