This commit is contained in:
crusader 2017-11-08 19:15:09 +09:00
parent 760e6b8ae1
commit b5a6a7280c
5 changed files with 252 additions and 15 deletions

View File

@ -13,12 +13,18 @@ const (
DefaultReadBufferSize = 4096
// DefaultWriteBufferSize is default value of Write Buffer Size
DefaultWriteBufferSize = 4096
// DefaultReadTimeout is default value of read timeout
DefaultReadTimeout = 0
// DefaultWriteTimeout is default value of write timeout
DefaultWriteTimeout = 0
// DefaultEnableCompression is default value of support compression
DefaultEnableCompression = false
// DefaultMaxMessageSize is default size for a message read from the peer
DefaultMaxMessageSize = 4096
// DefaultPongTimeout is default value of websocket pong Timeout
DefaultPongTimeout = 0
// DefaultPingTimeout is default value of websocket ping Timeout
DefaultPingTimeout = 0
// DefaultPingPeriod is default value of send ping period
DefaultPingPeriod = 0
)

View File

@ -132,8 +132,10 @@ func (s *server) handleRequest(ctx *fasthttp.RequestCtx) {
return
}
soc := newSocket(conn, socketHandler)
s.stopWg.Add(1)
go handleConnection(s, conn, socketHandler)
handleConnection(s, soc)
})
}
@ -145,24 +147,24 @@ func (s *server) handleError(ctx *fasthttp.RequestCtx, status int, reason error)
s.sh.OnError(ctx, status, reason)
}
func handleConnection(s *server, conn *websocket.Conn, handler SocketHandler) {
func handleConnection(s *server, soc *Socket) {
defer s.stopWg.Done()
logging.Logger.Debug(fmt.Sprintf("Server: Client[%s] is connected.", conn.RemoteAddr()))
logging.Logger.Debug(fmt.Sprintf("Server: Client[%s] is connected.", soc.RemoteAddr()))
clientStopChan := make(chan struct{})
handleDoneChan := make(chan struct{})
go handler.Handle(conn, clientStopChan, handleDoneChan)
go handler.Handle(soc, clientStopChan, handleDoneChan)
select {
case <-s.stopChan:
close(clientStopChan)
conn.Close()
soc.Close()
<-handleDoneChan
case <-handleDoneChan:
close(clientStopChan)
logging.Logger.Debug(fmt.Sprintf("Server: Client[%s] is disconnected.", conn.RemoteAddr()))
conn.Close()
logging.Logger.Debug(fmt.Sprintf("Server: Client[%s] is disconnected.", soc.RemoteAddr()))
soc.Close()
}
}

164
socket.go Normal file
View File

@ -0,0 +1,164 @@
package websocket_fasthttp
import (
"io"
"net"
"sync"
"time"
"git.loafle.net/commons_go/websocket_fasthttp/websocket"
)
func newSocket(conn *websocket.Conn, sh SocketHandler) *Socket {
s := retainSocket()
s.Conn = conn
s.sh = sh
s.SetReadLimit(sh.GetMaxMessageSize())
if 0 < sh.GetReadTimeout() {
s.SetReadDeadline(time.Now().Add(sh.GetReadTimeout() * time.Second))
}
return s
}
type Socket struct {
*websocket.Conn
sh SocketHandler
sc *SocketConn
}
func (s *Socket) WaitRequest() (*SocketConn, error) {
if nil != s.sc {
releaseSocketConn(s.sc)
s.sc = nil
}
var mt int
var err error
var r io.Reader
if mt, r, err = s.NextReader(); nil != err {
return nil, err
}
sc := retainSocketConn()
sc.s = s
sc.MessageType = mt
sc.r = r
return sc, nil
}
func (s *Socket) NextWriter(messageType int) (io.WriteCloser, error) {
if 0 < s.sh.GetWriteTimeout() {
s.SetWriteDeadline(time.Now().Add(s.sh.GetWriteTimeout() * time.Second))
}
return s.Conn.NextWriter(messageType)
}
func (s *Socket) WriteMessage(messageType int, data []byte) error {
if 0 < s.sh.GetWriteTimeout() {
s.SetWriteDeadline(time.Now().Add(s.sh.GetWriteTimeout() * time.Second))
}
return s.Conn.WriteMessage(messageType, data)
}
func (s *Socket) Close() error {
err := s.Conn.Close()
releaseSocket(s)
return err
}
type SocketConn struct {
net.Conn
s *Socket
MessageType int
r io.Reader
wc io.WriteCloser
}
func (sc *SocketConn) Read(b []byte) (n int, err error) {
return sc.r.Read(b)
}
func (sc *SocketConn) Write(b []byte) (n int, err error) {
if nil == sc.wc {
var err error
if sc.wc, err = sc.s.NextWriter(sc.MessageType); nil != err {
return 0, err
}
}
return sc.wc.Write(b)
}
func (sc *SocketConn) Close() error {
var err error
if sc.wc != nil {
err = sc.wc.Close()
}
releaseSocketConn(sc)
sc.s.sc = nil
return err
}
func (sc *SocketConn) LocalAddr() net.Addr {
return sc.s.LocalAddr()
}
func (sc *SocketConn) RemoteAddr() net.Addr {
return sc.s.RemoteAddr()
}
func (sc *SocketConn) SetDeadline(t time.Time) error {
if err := sc.s.SetReadDeadline(t); nil != err {
return err
}
if err := sc.s.SetWriteDeadline(t); nil != err {
return err
}
return nil
}
func (sc *SocketConn) SetReadDeadline(t time.Time) error {
return sc.s.SetReadDeadline(t)
}
func (sc *SocketConn) SetWriteDeadline(t time.Time) error {
return sc.s.SetWriteDeadline(t)
}
var socketPool sync.Pool
func retainSocket() *Socket {
v := socketPool.Get()
if v == nil {
return &Socket{}
}
return v.(*Socket)
}
func releaseSocket(s *Socket) {
s.sh = nil
s.sc = nil
socketPool.Put(s)
}
var socketConnPool sync.Pool
func retainSocketConn() *SocketConn {
v := socketConnPool.Get()
if v == nil {
return &SocketConn{}
}
return v.(*SocketConn)
}
func releaseSocketConn(sc *SocketConn) {
sc.s = nil
sc.r = nil
sc.wc = nil
socketConnPool.Put(sc)
}

View File

@ -1,11 +1,21 @@
package websocket_fasthttp
import (
"git.loafle.net/commons_go/websocket_fasthttp/websocket"
"time"
"github.com/valyala/fasthttp"
)
type SocketHandler interface {
Handshake(ctx *fasthttp.RequestCtx) (connectable bool, extensionsHeader *fasthttp.ResponseHeader)
Handle(conn *websocket.Conn, stopChan <-chan struct{}, doneChan chan<- struct{})
Handle(soc *Socket, stopChan <-chan struct{}, doneChan chan<- struct{})
GetMaxMessageSize() int64
GetWriteTimeout() time.Duration
GetReadTimeout() time.Duration
GetPongTimeout() time.Duration
GetPingTimeout() time.Duration
GetPingPeriod() time.Duration
Validate()
}

View File

@ -1,21 +1,76 @@
package websocket_fasthttp
import (
"git.loafle.net/commons_go/websocket_fasthttp/websocket"
"time"
"github.com/valyala/fasthttp"
)
type SocketHandlers struct {
// MaxMessageSize is the maximum size for a message read from the peer. If a
// message exceeds the limit, the connection sends a close frame to the peer
// and returns ErrReadLimit to the application.
MaxMessageSize int64
// WriteTimeout is the write deadline on the underlying network
// connection. After a write has timed out, the websocket state is corrupt and
// all future writes will return an error. A zero value for t means writes will
// not time out.
WriteTimeout time.Duration
// ReadTimeout is the read deadline on the underlying network connection.
// After a read has timed out, the websocket connection state is corrupt and
// all future reads will return an error. A zero value for t means reads will
// not time out.
ReadTimeout time.Duration
PongTimeout time.Duration
PingTimeout time.Duration
PingPeriod time.Duration
}
func (sh *SocketHandlers) Handshake(ctx *fasthttp.RequestCtx) (bool, *fasthttp.ResponseHeader) {
return true, nil
}
func (sh *SocketHandlers) Handle(conn *websocket.Conn, stopChan <-chan struct{}, doneChan chan<- struct{}) {
func (sh *SocketHandlers) Handle(soc *Socket, stopChan <-chan struct{}, doneChan chan<- struct{}) {
// no op
}
func (sh *SocketHandlers) Validate() {
func (sh *SocketHandlers) GetMaxMessageSize() int64 {
return sh.MaxMessageSize
}
func (sh *SocketHandlers) GetWriteTimeout() time.Duration {
return sh.WriteTimeout
}
func (sh *SocketHandlers) GetReadTimeout() time.Duration {
return sh.ReadTimeout
}
func (sh *SocketHandlers) GetPongTimeout() time.Duration {
return sh.PongTimeout
}
func (sh *SocketHandlers) GetPingTimeout() time.Duration {
return sh.PingTimeout
}
func (sh *SocketHandlers) GetPingPeriod() time.Duration {
return sh.PingPeriod
}
func (sh *SocketHandlers) Validate() {
if sh.MaxMessageSize <= 0 {
sh.MaxMessageSize = DefaultMaxMessageSize
}
if sh.WriteTimeout <= 0 {
sh.WriteTimeout = DefaultWriteTimeout
}
if sh.ReadTimeout <= 0 {
sh.ReadTimeout = DefaultReadTimeout
}
if sh.PongTimeout <= 0 {
sh.PongTimeout = DefaultPongTimeout
}
if sh.PingTimeout <= 0 {
sh.PingTimeout = DefaultPingTimeout
}
if sh.PingPeriod <= 0 {
sh.PingPeriod = DefaultPingPeriod
}
}