This commit is contained in:
crusader 2018-04-06 00:15:29 +09:00
parent ad158313e5
commit 7e7d6258c0
31 changed files with 473 additions and 108 deletions

View File

@ -10,16 +10,9 @@ type ReadWriteHandler interface {
GetWriteBufferSize() int GetWriteBufferSize() int
GetReadTimeout() time.Duration GetReadTimeout() time.Duration
GetWriteTimeout() time.Duration GetWriteTimeout() time.Duration
GetPongTimeout() time.Duration
GetPingTimeout() time.Duration
GetPingPeriod() time.Duration
IsEnableCompression() bool
} }
type ReadWriteHandlers struct { type ReadWriteHandlers struct {
ReadWriteHandler
MaxMessageSize int64 MaxMessageSize int64
// Per-connection buffer size for requests' reading. // Per-connection buffer size for requests' reading.
// This also limits the maximum header size. // This also limits the maximum header size.
@ -45,12 +38,6 @@ type ReadWriteHandlers struct {
// //
// By default response write timeout is unlimited. // By default response write timeout is unlimited.
WriteTimeout time.Duration WriteTimeout time.Duration
PongTimeout time.Duration
PingTimeout time.Duration
PingPeriod time.Duration
EnableCompression bool
} }
func (rwh *ReadWriteHandlers) GetMaxMessageSize() int64 { func (rwh *ReadWriteHandlers) GetMaxMessageSize() int64 {
@ -68,18 +55,7 @@ func (rwh *ReadWriteHandlers) GetReadTimeout() time.Duration {
func (rwh *ReadWriteHandlers) GetWriteTimeout() time.Duration { func (rwh *ReadWriteHandlers) GetWriteTimeout() time.Duration {
return rwh.WriteTimeout return rwh.WriteTimeout
} }
func (rwh *ReadWriteHandlers) GetPongTimeout() time.Duration {
return rwh.PongTimeout
}
func (rwh *ReadWriteHandlers) GetPingTimeout() time.Duration {
return rwh.PingTimeout
}
func (rwh *ReadWriteHandlers) GetPingPeriod() time.Duration {
return rwh.PingPeriod
}
func (rwh *ReadWriteHandlers) IsEnableCompression() bool {
return rwh.EnableCompression
}
func (rwh *ReadWriteHandlers) Validate() error { func (rwh *ReadWriteHandlers) Validate() error {
if rwh.MaxMessageSize <= 0 { if rwh.MaxMessageSize <= 0 {
rwh.MaxMessageSize = DefaultMaxMessageSize rwh.MaxMessageSize = DefaultMaxMessageSize
@ -96,15 +72,6 @@ func (rwh *ReadWriteHandlers) Validate() error {
if rwh.WriteTimeout <= 0 { if rwh.WriteTimeout <= 0 {
rwh.WriteTimeout = DefaultWriteTimeout rwh.WriteTimeout = DefaultWriteTimeout
} }
if rwh.PongTimeout <= 0 {
rwh.PongTimeout = DefaultPongTimeout
}
if rwh.PingTimeout <= 0 {
rwh.PingTimeout = DefaultPingTimeout
}
if rwh.PingPeriod <= 0 {
rwh.PingPeriod = (rwh.PingTimeout * 9) / 10
}
return nil return nil
} }

View File

@ -1,9 +1,6 @@
package server package server
type ServerHandler interface { type ServerHandler interface {
ConnectionHandler
ReadWriteHandler
GetName() string GetName() string
ServerCtx() ServerCtx ServerCtx() ServerCtx
@ -16,9 +13,6 @@ type ServerHandler interface {
} }
type ServerHandlers struct { type ServerHandlers struct {
ConnectionHandlers
ReadWriteHandlers
// Server name for sending in response headers. // Server name for sending in response headers.
// //
// Default server name is used if left blank. // Default server name is used if left blank.
@ -50,13 +44,6 @@ func (sh *ServerHandlers) GetName() string {
} }
func (sh *ServerHandlers) Validate() error { func (sh *ServerHandlers) Validate() error {
if err := sh.ConnectionHandlers.Validate(); nil != err {
return err
}
if err := sh.ReadWriteHandlers.Validate(); nil != err {
return err
}
if "" == sh.Name { if "" == sh.Name {
sh.Name = "Server" sh.Name = "Server"
} }

View File

@ -5,8 +5,4 @@ type Servlet interface {
Init(serverCtx ServerCtx) error Init(serverCtx ServerCtx) error
Destroy(serverCtx ServerCtx) Destroy(serverCtx ServerCtx)
OnConnect(servletCtx ServletCtx, conn Conn)
Handle(servletCtx ServletCtx, stopChan <-chan struct{}, doneChan chan<- struct{}, readChan <-chan []byte, writeChan chan<- []byte)
OnDisconnect(servletCtx ServletCtx)
} }

View File

@ -1,4 +1,4 @@
package server package socket
import ( import (
"io" "io"

View File

@ -1,4 +1,4 @@
package server package socket
import ( import (
"compress/flate" "compress/flate"

View File

@ -1,4 +1,4 @@
package server package socket
import ( import (
"errors" "errors"

View File

@ -2,7 +2,7 @@
// Use of this source code is governed by a BSD-style // Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file. // license that can be found in the LICENSE file.
package server package socket
import ( import (
"encoding/json" "encoding/json"

View File

@ -1,4 +1,4 @@
package server package socket
import "unsafe" import "unsafe"

View File

@ -1,4 +1,4 @@
package server package socket
import ( import (
"bytes" "bytes"

View File

@ -4,7 +4,7 @@
// +build go1.5 // +build go1.5
package server package socket
import "io" import "io"

View File

@ -1,4 +1,4 @@
package server package socket
import ( import (
"bufio" "bufio"

View File

@ -9,11 +9,12 @@ import (
"git.loafle.net/commons/logging-go" "git.loafle.net/commons/logging-go"
"git.loafle.net/commons/server-go" "git.loafle.net/commons/server-go"
"git.loafle.net/commons/server-go/socket"
) )
type Client struct { type Client struct {
server.ClientConnHandlers server.ClientConnHandlers
server.ReadWriteHandlers socket.ReadWriteHandlers
Name string Name string
@ -28,14 +29,14 @@ type Client struct {
writeChan chan []byte writeChan chan []byte
disconnectedChan chan struct{} disconnectedChan chan struct{}
reconnectedChan chan server.Conn reconnectedChan chan socket.Conn
crw server.ClientReadWriter crw socket.ClientReadWriter
} }
func (c *Client) Connect() (readChan <-chan []byte, writeChan chan<- []byte, err error) { func (c *Client) Connect() (readChan <-chan []byte, writeChan chan<- []byte, err error) {
var ( var (
conn server.Conn conn socket.Conn
) )
if c.stopChan != nil { if c.stopChan != nil {
@ -55,7 +56,7 @@ func (c *Client) Connect() (readChan <-chan []byte, writeChan chan<- []byte, err
c.readChan = make(chan []byte, 256) c.readChan = make(chan []byte, 256)
c.writeChan = make(chan []byte, 256) c.writeChan = make(chan []byte, 256)
c.disconnectedChan = make(chan struct{}) c.disconnectedChan = make(chan struct{})
c.reconnectedChan = make(chan server.Conn) c.reconnectedChan = make(chan socket.Conn)
c.stopChan = make(chan struct{}) c.stopChan = make(chan struct{})
c.crw.ReadwriteHandler = c c.crw.ReadwriteHandler = c
@ -124,13 +125,13 @@ RC_LOOP:
} }
} }
func (c *Client) connect() (server.Conn, error) { func (c *Client) connect() (socket.Conn, error) {
netConn, err := c.dial() netConn, err := c.dial()
if nil != err { if nil != err {
return nil, err return nil, err
} }
conn := server.NewConn(netConn, false, c.ReadBufferSize, c.WriteBufferSize) conn := socket.NewConn(netConn, false, c.ReadBufferSize, c.WriteBufferSize)
conn.SetCloseHandler(func(code int, text string) error { conn.SetCloseHandler(func(code int, text string) error {
logging.Logger().Debugf("close") logging.Logger().Debugf("close")
return nil return nil

View File

@ -4,10 +4,11 @@ import (
"net" "net"
"git.loafle.net/commons/server-go" "git.loafle.net/commons/server-go"
"git.loafle.net/commons/server-go/socket"
) )
type ServerHandler interface { type ServerHandler interface {
server.ServerHandler socket.ServerHandler
OnError(serverCtx server.ServerCtx, conn net.Conn, status int, reason error) OnError(serverCtx server.ServerCtx, conn net.Conn, status int, reason error)
@ -59,9 +60,5 @@ func (sh *ServerHandlers) Validate() error {
return err return err
} }
if 0 >= sh.Concurrency {
sh.Concurrency = 0
}
return nil return nil
} }

View File

@ -10,6 +10,7 @@ import (
"git.loafle.net/commons/logging-go" "git.loafle.net/commons/logging-go"
"git.loafle.net/commons/server-go" "git.loafle.net/commons/server-go"
"git.loafle.net/commons/server-go/socket"
) )
type Server struct { type Server struct {
@ -19,7 +20,7 @@ type Server struct {
stopChan chan struct{} stopChan chan struct{}
stopWg sync.WaitGroup stopWg sync.WaitGroup
srw server.ServerReadWriter srw socket.ServerReadWriter
} }
func (s *Server) ListenAndServe() error { func (s *Server) ListenAndServe() error {
@ -158,7 +159,7 @@ func (s *Server) handleServer(listener net.Listener) error {
continue continue
} }
conn := server.NewConn(netConn, true, s.ServerHandler.GetReadBufferSize(), s.ServerHandler.GetWriteBufferSize()) conn := socket.NewConn(netConn, true, s.ServerHandler.GetReadBufferSize(), s.ServerHandler.GetWriteBufferSize())
s.stopWg.Add(1) s.stopWg.Add(1)
go s.srw.HandleConnection(servlet, servletCtx, conn) go s.srw.HandleConnection(servlet, servletCtx, conn)

View File

@ -4,10 +4,11 @@ import (
"net" "net"
"git.loafle.net/commons/server-go" "git.loafle.net/commons/server-go"
"git.loafle.net/commons/server-go/socket"
) )
type Servlet interface { type Servlet interface {
server.Servlet socket.Servlet
Handshake(servletCtx server.ServletCtx, conn net.Conn) error Handshake(servletCtx server.ServletCtx, conn net.Conn) error
} }
@ -32,7 +33,7 @@ func (s *Servlets) Handshake(servletCtx server.ServletCtx, conn net.Conn) error
return nil return nil
} }
func (s *Servlets) OnConnect(servletCtx server.ServletCtx, conn server.Conn) { func (s *Servlets) OnConnect(servletCtx server.ServletCtx, conn socket.Conn) {
// //
} }

View File

@ -0,0 +1,56 @@
package socket
import (
"time"
"git.loafle.net/commons/server-go"
)
type ReadWriteHandler interface {
server.ReadWriteHandler
GetPongTimeout() time.Duration
GetPingTimeout() time.Duration
GetPingPeriod() time.Duration
IsEnableCompression() bool
}
type ReadWriteHandlers struct {
server.ReadWriteHandlers
PongTimeout time.Duration
PingTimeout time.Duration
PingPeriod time.Duration
EnableCompression bool
}
func (rwh *ReadWriteHandlers) GetPongTimeout() time.Duration {
return rwh.PongTimeout
}
func (rwh *ReadWriteHandlers) GetPingTimeout() time.Duration {
return rwh.PingTimeout
}
func (rwh *ReadWriteHandlers) GetPingPeriod() time.Duration {
return rwh.PingPeriod
}
func (rwh *ReadWriteHandlers) IsEnableCompression() bool {
return rwh.EnableCompression
}
func (rwh *ReadWriteHandlers) Validate() error {
if err := rwh.ReadWriteHandlers.Validate(); nil != err {
return err
}
if rwh.PongTimeout <= 0 {
rwh.PongTimeout = server.DefaultPongTimeout
}
if rwh.PingTimeout <= 0 {
rwh.PingTimeout = server.DefaultPingTimeout
}
if rwh.PingPeriod <= 0 {
rwh.PingPeriod = (rwh.PingTimeout * 9) / 10
}
return nil
}

View File

@ -1,4 +1,4 @@
package server package socket
import ( import (
"fmt" "fmt"

31
socket/server-handler.go Normal file
View File

@ -0,0 +1,31 @@
package socket
import (
"git.loafle.net/commons/server-go"
)
type ServerHandler interface {
server.ServerHandler
server.ConnectionHandler
ReadWriteHandler
}
type ServerHandlers struct {
server.ServerHandlers
server.ConnectionHandlers
ReadWriteHandlers
}
func (sh *ServerHandlers) Validate() error {
if err := sh.ServerHandlers.Validate(); nil != err {
return err
}
if err := sh.ConnectionHandlers.Validate(); nil != err {
return err
}
if err := sh.ReadWriteHandlers.Validate(); nil != err {
return err
}
return nil
}

View File

@ -1,9 +1,10 @@
package server package socket
import ( import (
"sync" "sync"
logging "git.loafle.net/commons/logging-go" "git.loafle.net/commons/logging-go"
"git.loafle.net/commons/server-go"
) )
type ServerReadWriter struct { type ServerReadWriter struct {
@ -23,7 +24,7 @@ func (srw *ServerReadWriter) ConnectionSize() int {
return sz return sz
} }
func (srw *ServerReadWriter) HandleConnection(servlet Servlet, servletCtx ServletCtx, conn Conn) { func (srw *ServerReadWriter) HandleConnection(servlet Servlet, servletCtx server.ServletCtx, conn Conn) {
addr := conn.RemoteAddr() addr := conn.RemoteAddr()
defer func() { defer func() {

13
socket/servlet.go Normal file
View File

@ -0,0 +1,13 @@
package socket
import (
"git.loafle.net/commons/server-go"
)
type Servlet interface {
server.Servlet
OnConnect(servletCtx server.ServletCtx, conn Conn)
Handle(servletCtx server.ServletCtx, stopChan <-chan struct{}, doneChan chan<- struct{}, readChan <-chan []byte, writeChan chan<- []byte)
OnDisconnect(servletCtx server.ServletCtx)
}

View File

@ -1,4 +1,4 @@
package websocket package web
import ( import (
"bufio" "bufio"
@ -18,13 +18,14 @@ import (
"git.loafle.net/commons/logging-go" "git.loafle.net/commons/logging-go"
"git.loafle.net/commons/server-go" "git.loafle.net/commons/server-go"
"git.loafle.net/commons/server-go/socket"
) )
var errMalformedURL = errors.New("malformed ws or wss URL") var errMalformedURL = errors.New("malformed ws or wss URL")
type Client struct { type Client struct {
server.ClientConnHandlers server.ClientConnHandlers
server.ReadWriteHandlers socket.ReadWriteHandlers
Name string Name string
@ -58,14 +59,14 @@ type Client struct {
writeChan chan []byte writeChan chan []byte
disconnectedChan chan struct{} disconnectedChan chan struct{}
reconnectedChan chan server.Conn reconnectedChan chan socket.Conn
crw server.ClientReadWriter crw socket.ClientReadWriter
} }
func (c *Client) Connect() (readChan <-chan []byte, writeChan chan<- []byte, err error) { func (c *Client) Connect() (readChan <-chan []byte, writeChan chan<- []byte, err error) {
var ( var (
conn server.Conn conn socket.Conn
res *http.Response res *http.Response
) )
@ -90,7 +91,7 @@ func (c *Client) Connect() (readChan <-chan []byte, writeChan chan<- []byte, err
c.readChan = make(chan []byte, 256) c.readChan = make(chan []byte, 256)
c.writeChan = make(chan []byte, 256) c.writeChan = make(chan []byte, 256)
c.disconnectedChan = make(chan struct{}) c.disconnectedChan = make(chan struct{})
c.reconnectedChan = make(chan server.Conn) c.reconnectedChan = make(chan socket.Conn)
c.stopChan = make(chan struct{}) c.stopChan = make(chan struct{})
c.crw.ReadwriteHandler = c c.crw.ReadwriteHandler = c
@ -164,7 +165,7 @@ RC_LOOP:
} }
} }
func (c *Client) connect() (server.Conn, *http.Response, error) { func (c *Client) connect() (socket.Conn, *http.Response, error) {
conn, res, err := c.dial() conn, res, err := c.dial()
if nil != err { if nil != err {
return nil, nil, err return nil, nil, err
@ -177,7 +178,7 @@ func (c *Client) connect() (server.Conn, *http.Response, error) {
return conn, res, nil return conn, res, nil
} }
func (c *Client) dial() (server.Conn, *http.Response, error) { func (c *Client) dial() (socket.Conn, *http.Response, error) {
var ( var (
err error err error
challengeKey string challengeKey string
@ -337,7 +338,7 @@ func (c *Client) dial() (server.Conn, *http.Response, error) {
} }
} }
conn := server.NewConn(netConn, false, c.ReadBufferSize, c.WriteBufferSize) conn := socket.NewConn(netConn, false, c.ReadBufferSize, c.WriteBufferSize)
if err := req.Write(netConn); err != nil { if err := req.Write(netConn); err != nil {
return nil, nil, err return nil, nil, err
@ -364,7 +365,7 @@ func (c *Client) dial() (server.Conn, *http.Response, error) {
buf := make([]byte, 1024) buf := make([]byte, 1024)
n, _ := io.ReadFull(resp.Body, buf) n, _ := io.ReadFull(resp.Body, buf)
resp.Body = ioutil.NopCloser(bytes.NewReader(buf[:n])) resp.Body = ioutil.NopCloser(bytes.NewReader(buf[:n]))
return nil, resp, server.ErrBadHandshake return nil, resp, socket.ErrBadHandshake
} }
for _, ext := range httpParseExtensions(resp.Header) { for _, ext := range httpParseExtensions(resp.Header) {
@ -374,10 +375,10 @@ func (c *Client) dial() (server.Conn, *http.Response, error) {
_, snct := ext["server_no_context_takeover"] _, snct := ext["server_no_context_takeover"]
_, cnct := ext["client_no_context_takeover"] _, cnct := ext["client_no_context_takeover"]
if !snct || !cnct { if !snct || !cnct {
return nil, resp, server.ErrInvalidCompression return nil, resp, socket.ErrInvalidCompression
} }
conn.SetNewCompressionWriter(server.CompressNoContextTakeover) conn.SetNewCompressionWriter(socket.CompressNoContextTakeover)
conn.SetNewDecompressionReader(server.DecompressNoContextTakeover) conn.SetNewDecompressionReader(socket.DecompressNoContextTakeover)
break break
} }

View File

@ -1,15 +1,16 @@
package websocket package web
import ( import (
"net/http" "net/http"
"git.loafle.net/commons/server-go" "git.loafle.net/commons/server-go"
"git.loafle.net/commons/server-go/socket"
"github.com/valyala/fasthttp" "github.com/valyala/fasthttp"
) )
type ServerHandler interface { type ServerHandler interface {
server.ServerHandler socket.ServerHandler
OnError(serverCtx server.ServerCtx, ctx *fasthttp.RequestCtx, status int, reason error) OnError(serverCtx server.ServerCtx, ctx *fasthttp.RequestCtx, status int, reason error)

View File

@ -1,4 +1,4 @@
package websocket package web
import ( import (
"context" "context"
@ -9,6 +9,7 @@ import (
"git.loafle.net/commons/logging-go" "git.loafle.net/commons/logging-go"
"git.loafle.net/commons/server-go" "git.loafle.net/commons/server-go"
"git.loafle.net/commons/server-go/socket"
"github.com/valyala/fasthttp" "github.com/valyala/fasthttp"
) )
@ -19,7 +20,7 @@ type Server struct {
stopChan chan struct{} stopChan chan struct{}
stopWg sync.WaitGroup stopWg sync.WaitGroup
srw server.ServerReadWriter srw socket.ServerReadWriter
hs *fasthttp.Server hs *fasthttp.Server
upgrader *Upgrader upgrader *Upgrader
@ -177,7 +178,7 @@ func (s *Server) httpHandler(ctx *fasthttp.RequestCtx) {
return return
} }
s.upgrader.Upgrade(ctx, responseHeader, func(conn *server.SocketConn, err error) { s.upgrader.Upgrade(ctx, responseHeader, func(conn *socket.SocketConn, err error) {
if err != nil { if err != nil {
s.onError(ctx, fasthttp.StatusInternalServerError, err) s.onError(ctx, fasthttp.StatusInternalServerError, err)
return return

View File

@ -1,12 +1,13 @@
package websocket package web
import ( import (
"git.loafle.net/commons/server-go" "git.loafle.net/commons/server-go"
"git.loafle.net/commons/server-go/socket"
"github.com/valyala/fasthttp" "github.com/valyala/fasthttp"
) )
type Servlet interface { type Servlet interface {
server.Servlet socket.Servlet
Handshake(servletCtx server.ServletCtx, ctx *fasthttp.RequestCtx) (*fasthttp.ResponseHeader, error) Handshake(servletCtx server.ServletCtx, ctx *fasthttp.RequestCtx) (*fasthttp.ResponseHeader, error)
} }
@ -31,7 +32,7 @@ func (s *Servlets) Handshake(servletCtx server.ServletCtx, ctx *fasthttp.Request
return nil, nil return nil, nil
} }
func (s *Servlets) OnConnect(servletCtx server.ServletCtx, conn server.Conn) { func (s *Servlets) OnConnect(servletCtx server.ServletCtx, conn socket.Conn) {
// //
} }

View File

@ -2,7 +2,7 @@
// Use of this source code is governed by a BSD-style // Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file. // license that can be found in the LICENSE file.
package websocket package web
import ( import (
"net" "net"
@ -11,12 +11,12 @@ import (
"strings" "strings"
"time" "time"
"git.loafle.net/commons/server-go" "git.loafle.net/commons/server-go/socket"
"github.com/valyala/fasthttp" "github.com/valyala/fasthttp"
) )
type ( type (
OnUpgradeFunc func(*server.SocketConn, error) OnUpgradeFunc func(*socket.SocketConn, error)
) )
// HandshakeError describes an error with the handshake from the peer. // HandshakeError describes an error with the handshake from the peer.
@ -60,7 +60,7 @@ type Upgrader struct {
EnableCompression bool EnableCompression bool
} }
func (u *Upgrader) returnError(ctx *fasthttp.RequestCtx, status int, reason string) (*server.SocketConn, error) { func (u *Upgrader) returnError(ctx *fasthttp.RequestCtx, status int, reason string) (*socket.SocketConn, error) {
err := HandshakeError{reason} err := HandshakeError{reason}
if u.Error != nil { if u.Error != nil {
u.Error(ctx, status, err) u.Error(ctx, status, err)
@ -192,11 +192,11 @@ func (u *Upgrader) Upgrade(ctx *fasthttp.RequestCtx, responseHeader *fasthttp.Re
ctx.Request.Header.CopyTo(h) ctx.Request.Header.CopyTo(h)
ctx.Hijack(func(netConn net.Conn) { ctx.Hijack(func(netConn net.Conn) {
c := server.NewConn(netConn, true, u.ReadBufferSize, u.WriteBufferSize) c := socket.NewConn(netConn, true, u.ReadBufferSize, u.WriteBufferSize)
c.SetSubprotocol(subprotocol) c.SetSubprotocol(subprotocol)
if compress { if compress {
c.SetNewCompressionWriter(server.CompressNoContextTakeover) c.SetNewCompressionWriter(socket.CompressNoContextTakeover)
c.SetNewDecompressionReader(server.DecompressNoContextTakeover) c.SetNewDecompressionReader(socket.DecompressNoContextTakeover)
} }
// Clear deadlines set by HTTP server. // Clear deadlines set by HTTP server.

View File

@ -2,7 +2,7 @@
// Use of this source code is governed by a BSD-style // Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file. // license that can be found in the LICENSE file.
package websocket package web
import ( import (
"crypto/rand" "crypto/rand"

View File

@ -0,0 +1,92 @@
package fasthttp
import (
"net/http"
"git.loafle.net/commons/server-go"
"git.loafle.net/commons/server-go/web"
"github.com/valyala/fasthttp"
)
type ServerHandler interface {
web.ServerHandler
OnError(serverCtx server.ServerCtx, ctx *fasthttp.RequestCtx, status int, reason error)
RegisterServlet(path string, servlet Servlet)
Servlet(serverCtx server.ServerCtx, ctx *fasthttp.RequestCtx) Servlet
CheckOrigin(ctx *fasthttp.RequestCtx) bool
}
type ServerHandlers struct {
web.ServerHandlers
servlets map[string]Servlet
}
func (sh *ServerHandlers) Init(serverCtx server.ServerCtx) error {
if err := sh.ServerHandlers.Init(serverCtx); nil != err {
return err
}
if nil != sh.servlets {
for _, servlet := range sh.servlets {
if err := servlet.Init(serverCtx); nil != err {
return err
}
}
}
return nil
}
func (sh *ServerHandlers) Destroy(serverCtx server.ServerCtx) {
if nil != sh.servlets {
for _, servlet := range sh.servlets {
servlet.Destroy(serverCtx)
}
}
sh.ServerHandlers.Destroy(serverCtx)
}
func (sh *ServerHandlers) OnError(serverCtx server.ServerCtx, ctx *fasthttp.RequestCtx, status int, reason error) {
ctx.Response.Header.Set("Sec-Websocket-Version", "13")
ctx.Error(http.StatusText(status), status)
}
func (sh *ServerHandlers) RegisterServlet(path string, servlet Servlet) {
if nil == sh.servlets {
sh.servlets = make(map[string]Servlet)
}
sh.servlets[path] = servlet
}
func (sh *ServerHandlers) Servlet(serverCtx server.ServerCtx, ctx *fasthttp.RequestCtx) Servlet {
path := string(ctx.Path())
var servlet Servlet
if path == "" && len(sh.servlets) == 1 {
for _, s := range sh.servlets {
servlet = s
}
} else if servlet = sh.servlets[path]; nil == servlet {
return nil
}
return servlet
}
func (sh *ServerHandlers) CheckOrigin(ctx *fasthttp.RequestCtx) bool {
return true
}
func (sh *ServerHandlers) Validate() error {
if err := sh.ServerHandlers.Validate(); nil != err {
return err
}
return nil
}

150
web/fasthttp/server.go Normal file
View File

@ -0,0 +1,150 @@
package fasthttp
import (
"context"
"fmt"
"net"
"sync"
"git.loafle.net/commons/logging-go"
"git.loafle.net/commons/server-go"
"github.com/valyala/fasthttp"
)
type Server struct {
ServerHandler ServerHandler
ctx server.ServerCtx
stopChan chan struct{}
stopWg sync.WaitGroup
hs *fasthttp.Server
}
func (s *Server) ListenAndServe() error {
var (
err error
listener net.Listener
)
if nil == s.ServerHandler {
return fmt.Errorf("Server: server handler must be specified")
}
s.ServerHandler.Validate()
if s.stopChan != nil {
return fmt.Errorf(s.serverMessage("already running. Stop it before starting it again"))
}
s.ctx = s.ServerHandler.ServerCtx()
if nil == s.ctx {
return fmt.Errorf(s.serverMessage("ServerCtx is nil"))
}
s.hs = &fasthttp.Server{
Handler: s.httpHandler,
Name: s.ServerHandler.GetName(),
Concurrency: s.ServerHandler.GetConcurrency(),
ReadBufferSize: s.ServerHandler.GetReadBufferSize(),
WriteBufferSize: s.ServerHandler.GetWriteBufferSize(),
ReadTimeout: s.ServerHandler.GetReadTimeout(),
WriteTimeout: s.ServerHandler.GetWriteTimeout(),
}
if err = s.ServerHandler.Init(s.ctx); nil != err {
return err
}
if listener, err = s.ServerHandler.Listener(s.ctx); nil != err {
return err
}
s.stopChan = make(chan struct{})
s.stopWg.Add(1)
return s.handleServer(listener)
}
func (s *Server) Shutdown(ctx context.Context) error {
if s.stopChan == nil {
return fmt.Errorf("server must be started before stopping it")
}
close(s.stopChan)
s.stopWg.Wait()
s.ServerHandler.Destroy(s.ctx)
s.stopChan = nil
return nil
}
func (s *Server) serverMessage(msg string) string {
return fmt.Sprintf("Server[%s]: %s", s.ServerHandler.GetName(), msg)
}
func (s *Server) handleServer(listener net.Listener) error {
var (
err error
)
errChan := make(chan error)
defer func() {
if nil != listener {
listener.Close()
}
s.stopWg.Done()
}()
go func() {
if err := s.hs.Serve(listener); nil != err {
errChan <- err
return
}
close(errChan)
}()
select {
case err, _ := <-errChan:
if nil != err {
return err
}
}
defer func() {
s.ServerHandler.OnStop(s.ctx)
logging.Logger().Infof(s.serverMessage("Stopped"))
}()
if err = s.ServerHandler.OnStart(s.ctx); nil != err {
return err
}
logging.Logger().Infof(s.serverMessage("Started"))
select {
case <-s.stopChan:
listener.Close()
listener = nil
}
return nil
}
func (s *Server) httpHandler(ctx *fasthttp.RequestCtx) {
var (
servlet Servlet
err error
)
if servlet = s.ServerHandler.(ServerHandler).Servlet(s.ctx, ctx); nil == servlet {
s.onError(ctx, fasthttp.StatusInternalServerError, err)
return
}
}
func (s *Server) onError(ctx *fasthttp.RequestCtx, status int, reason error) {
s.ServerHandler.(ServerHandler).OnError(s.ctx, ctx, status, reason)
}

33
web/fasthttp/servlet.go Normal file
View File

@ -0,0 +1,33 @@
package fasthttp
import (
"git.loafle.net/commons/server-go"
"git.loafle.net/commons/server-go/web"
"github.com/valyala/fasthttp"
)
type Servlet interface {
web.Servlet
Handle(servletCtx server.ServletCtx, ctx *fasthttp.RequestCtx)
}
type Servlets struct {
Servlet
}
func (s *Servlets) ServletCtx(serverCtx server.ServerCtx) server.ServletCtx {
return server.NewServletContext(nil, serverCtx)
}
func (s *Servlets) Init(serverCtx server.ServerCtx) error {
return nil
}
func (s *Servlets) Destroy(serverCtx server.ServerCtx) {
//
}
func (s *Servlets) Handle(servletCtx server.ServletCtx, ctx *fasthttp.RequestCtx) {
}

26
web/server-handler.go Normal file
View File

@ -0,0 +1,26 @@
package web
import (
"git.loafle.net/commons/server-go"
)
type ServerHandler interface {
server.ServerHandler
server.ConnectionHandler
}
type ServerHandlers struct {
server.ServerHandlers
server.ConnectionHandlers
}
func (sh *ServerHandlers) Validate() error {
if err := sh.ServerHandlers.Validate(); nil != err {
return err
}
if err := sh.ConnectionHandlers.Validate(); nil != err {
return err
}
return nil
}

9
web/servlet.go Normal file
View File

@ -0,0 +1,9 @@
package web
import (
"git.loafle.net/commons/server-go"
)
type Servlet interface {
server.Servlet
}