This commit is contained in:
crusader 2017-11-28 18:33:52 +09:00
parent 2aebcbab5e
commit 96b71dc5d5
5 changed files with 144 additions and 86 deletions

View File

@ -1,42 +1,43 @@
package socket package socket
import ( import (
"io" "net"
"git.loafle.net/commons_go/rpc"
"git.loafle.net/commons_go/rpc/protocol" "git.loafle.net/commons_go/rpc/protocol"
) )
type ServletHandlers struct { type ServletHandlers struct {
} }
func (sh *ServletHandlers) GetRequest(codec protocol.ServerCodec, reader interface{}) (protocol.ServerRequestCodec, error) { func (sh *ServletHandlers) GetRequest(servletCTX rpc.ServletContext, codec protocol.ServerCodec, conn interface{}) (protocol.ServerRequestCodec, error) {
r := reader.(io.Reader) nConn := conn.(net.Conn)
requestCodec, err := codec.NewRequest(r) requestCodec, err := codec.NewRequest(nConn)
return requestCodec, err return requestCodec, err
} }
func (sh *ServletHandlers) SendResponse(requestCodec protocol.ServerRequestCodec, writer interface{}, result interface{}, err error) error { func (sh *ServletHandlers) SendResponse(servletCTX rpc.ServletContext, conn interface{}, requestCodec protocol.ServerRequestCodec, result interface{}, err error) error {
w := writer.(io.Writer) nConn := conn.(net.Conn)
if nil != err { if nil != err {
if lerr := requestCodec.WriteError(w, 500, err); nil != lerr { if wErr := requestCodec.WriteError(nConn, 500, err); nil != wErr {
return wErr
} }
} else { } else {
if err := requestCodec.WriteResponse(w, result); nil != err { if wErr := requestCodec.WriteResponse(nConn, result); nil != wErr {
return wErr
} }
} }
return nil return nil
} }
func (sh *ServletHandlers) SendNotification(codec protocol.ServerCodec, writer interface{}, method string, args ...interface{}) error { func (sh *ServletHandlers) SendNotification(servletCTX rpc.ServletContext, conn interface{}, codec protocol.ServerCodec, method string, args ...interface{}) error {
w := writer.(io.Writer) nConn := conn.(net.Conn)
if err := codec.WriteNotification(w, method, args); nil != err {
if wErr := codec.WriteNotification(nConn, method, args); nil != wErr {
return wErr
} }
return nil return nil

View File

@ -1,57 +1,57 @@
package fasthttp package fasthttp
import ( import (
"fmt" "github.com/gorilla/websocket"
"git.loafle.net/commons_go/rpc" "git.loafle.net/commons_go/rpc"
"git.loafle.net/commons_go/rpc/protocol" "git.loafle.net/commons_go/rpc/protocol"
"git.loafle.net/commons_go/websocket_fasthttp/websocket" cwf "git.loafle.net/commons_go/websocket_fasthttp"
) )
type ServletHandlers struct { type ServletHandlers struct {
} }
func (sh *ServletHandlers) GetRequest(servletCTX rpc.ServletContext, codec protocol.ServerCodec, reader interface{}) (protocol.ServerRequestCodec, error) { func (sh *ServletHandlers) GetRequest(servletCTX rpc.ServletContext, codec protocol.ServerCodec, conn interface{}) (protocol.ServerRequestCodec, error) {
conn := reader.(*websocket.Conn) soc := conn.(cwf.Socket)
_, r, err := conn.NextReader() _, r, err := soc.NextReader()
requestCodec, err := codec.NewRequest(r) requestCodec, err := codec.NewRequest(r)
return requestCodec, err return requestCodec, err
} }
func (sh *ServletHandlers) SendResponse(servletCTX rpc.ServletContext, requestCodec protocol.ServerRequestCodec, writer interface{}, result interface{}, err error) error { func (sh *ServletHandlers) SendResponse(servletCTX rpc.ServletContext, conn interface{}, requestCodec protocol.ServerRequestCodec, result interface{}, err error) error {
conn := writer.(*websocket.Conn) soc := conn.(cwf.Socket)
wc, lerr := conn.NextWriter(websocket.TextMessage)
if nil != lerr {
wc, wErr := soc.NextWriter(websocket.TextMessage)
if nil != wErr {
return wErr
} }
if nil != err { if nil != err {
if lerr := requestCodec.WriteError(wc, 500, err); nil != lerr { if wErr := requestCodec.WriteError(wc, 500, err); nil != wErr {
return wErr
} }
} else { } else {
if err := requestCodec.WriteResponse(wc, result); nil != err { if wErr := requestCodec.WriteResponse(wc, result); nil != wErr {
return wErr
} }
} }
return fmt.Errorf("Servlet Handler: SendResponse is not implemented") return nil
} }
func (sh *ServletHandlers) SendNotification(servletCTX rpc.ServletContext, codec protocol.ServerCodec, writer interface{}, method string, args ...interface{}) error { func (sh *ServletHandlers) SendNotification(servletCTX rpc.ServletContext, conn interface{}, codec protocol.ServerCodec, method string, args ...interface{}) error {
conn := writer.(*websocket.Conn) soc := conn.(cwf.Socket)
wc, lerr := conn.NextWriter(websocket.TextMessage)
if nil != lerr {
wc, wErr := soc.NextWriter(websocket.TextMessage)
if nil != wErr {
return wErr
} }
if err := codec.WriteNotification(wc, method, args); nil != err { if wErr := codec.WriteNotification(wc, method, args); nil != wErr {
return wErr
} }
return fmt.Errorf("Servlet Handler: SendNotification is not implemented") return nil
} }

View File

@ -2,6 +2,8 @@ package rpc
import ( import (
"fmt" "fmt"
"io"
"runtime"
"sync" "sync"
"git.loafle.net/commons_go/logging" "git.loafle.net/commons_go/logging"
@ -16,7 +18,7 @@ func NewServlet(sh ServletHandler) Servlet {
} }
type Servlet interface { type Servlet interface {
Start(parentCTX cuc.Context, reader interface{}, writer interface{}) error Start(parentCTX cuc.Context, conn interface{}) error
Stop() Stop()
Send(method string, args ...interface{}) (err error) Send(method string, args ...interface{}) (err error)
@ -29,15 +31,14 @@ type servlet struct {
sh ServletHandler sh ServletHandler
messageQueueChan chan *messageState messageQueueChan chan *messageState
reader interface{} conn interface{}
writer interface{}
serverCodec protocol.ServerCodec serverCodec protocol.ServerCodec
stopChan chan struct{} stopChan chan struct{}
stopWg sync.WaitGroup stopWg sync.WaitGroup
} }
func (s *servlet) Start(parentCTX cuc.Context, reader interface{}, writer interface{}, doneChan chan<- struct{}) error { func (s *servlet) Start(parentCTX cuc.Context, conn interface{}) error {
if nil == s.sh { if nil == s.sh {
panic("Servlet: servlet handler must be specified.") panic("Servlet: servlet handler must be specified.")
} }
@ -53,8 +54,7 @@ func (s *servlet) Start(parentCTX cuc.Context, reader interface{}, writer interf
return err return err
} }
s.reader = reader s.conn = conn
s.writer = writer
s.serverCodec = sc s.serverCodec = sc
if err := s.sh.Init(s.ctx); nil != err { if err := s.sh.Init(s.ctx); nil != err {
@ -65,7 +65,7 @@ func (s *servlet) Start(parentCTX cuc.Context, reader interface{}, writer interf
s.messageQueueChan = make(chan *messageState, s.sh.GetPendingMessages()) s.messageQueueChan = make(chan *messageState, s.sh.GetPendingMessages())
s.stopWg.Add(1) s.stopWg.Add(1)
go handleServlet(s, doneChan) go handleServlet(s)
return nil return nil
} }
@ -81,8 +81,7 @@ func (s *servlet) Stop() {
s.messageQueueChan = nil s.messageQueueChan = nil
s.reader = nil s.conn = nil
s.writer = nil
s.serverCodec = nil s.serverCodec = nil
logging.Logger().Info(fmt.Sprintf("Servlet is stopped")) logging.Logger().Info(fmt.Sprintf("Servlet is stopped"))
@ -102,17 +101,56 @@ func (s *servlet) Context() ServletContext {
return s.ctx return s.ctx
} }
func handleServlet(s *servlet, doneChan chan<- struct{}) { func handleServlet(s *servlet) {
defer s.stopWg.Done() defer s.stopWg.Done()
messageStopChan := make(chan struct{}) subStopChan := make(chan struct{})
messageDoneChan := make(chan struct{})
go handleMessage(s, messageStopChan, messageDoneChan) readerDone := make(chan error, 1)
go handleReader(s, subStopChan, readerDone)
writerDone := make(chan error, 1)
go handleWriter(s, subStopChan, writerDone)
var err error
select {
case err = <-readerDone:
close(subStopChan)
<-writerDone
case err = <-writerDone:
close(subStopChan)
<-readerDone
case <-s.stopChan:
close(subStopChan)
<-readerDone
<-writerDone
}
if err != nil {
logging.Logger().Error(fmt.Sprintf("RPC Server: servlet error %v", err))
}
}
func handleReader(s *servlet, stopChan chan struct{}, doneChan chan error) {
var err error
defer func() {
if r := recover(); r != nil {
if err == nil {
err = fmt.Errorf("RPC Server: Panic when reading request from client: %v", r)
}
}
doneChan <- err
}()
for { for {
requestCodec, err := s.sh.GetRequest(s.ctx, s.serverCodec, s.reader) requestCodec, err := s.sh.GetRequest(s.ctx, s.serverCodec, s.conn)
if nil != err { if nil != err {
if err == io.ErrUnexpectedEOF || err == io.EOF {
err = fmt.Errorf("RPC Server: disconnected from client")
return
}
logging.Logger().Error(fmt.Sprintf("RPC Server: Cannot read request: [%s]", err))
continue continue
} }
@ -120,12 +158,59 @@ func handleServlet(s *servlet, doneChan chan<- struct{}) {
go handleRequest(s, requestCodec) go handleRequest(s, requestCodec)
select { select {
case <-s.stopChan: case <-stopChan:
err = fmt.Errorf("RPC Server: reading request stopped because get stop channel")
return
default: default:
} }
} }
} }
func handleWriter(s *servlet, stopChan chan struct{}, doneChan chan error) {
var err error
defer func() {
if r := recover(); r != nil {
if err == nil {
err = fmt.Errorf("RPC Server: Panic when writing message to client: %v", r)
}
}
doneChan <- err
}()
for {
var ms *messageState
select {
case ms = <-s.messageQueueChan:
default:
// Give the last chance for ready goroutines filling s.messageQueueChan :)
runtime.Gosched()
select {
case <-stopChan:
err = fmt.Errorf("RPC Server: writing message stopped because get stop channel")
return
case ms = <-s.messageQueueChan:
}
}
switch ms.messageType {
case protocol.MessageTypeResponse:
if err := s.sh.SendResponse(s.ctx, s.conn, ms.res.requestCodec, ms.res.result, ms.res.err); nil != err {
logging.Logger().Error(fmt.Sprintf("RPC Server: response message error %v", err))
}
ms.res.requestCodec.Close()
case protocol.MessageTypeNotification:
if err := s.sh.SendNotification(s.ctx, s.conn, s.serverCodec, ms.noti.method, ms.noti.args...); nil != err {
logging.Logger().Error(fmt.Sprintf("RPC Server: response message error %v", err))
}
default:
}
}
}
func handleRequest(s *servlet, requestCodec protocol.ServerRequestCodec) { func handleRequest(s *servlet, requestCodec protocol.ServerRequestCodec) {
defer func() { defer func() {
s.stopWg.Done() s.stopWg.Done()
@ -141,34 +226,6 @@ func handleRequest(s *servlet, requestCodec protocol.ServerRequestCodec) {
s.messageQueueChan <- ms s.messageQueueChan <- ms
} }
func handleMessage(s *servlet) {
defer func() {
s.stopWg.Done()
}()
for {
select {
case ms := <-s.messageQueueChan:
switch ms.messageType {
case protocol.MessageTypeResponse:
if err := s.sh.SendResponse(s.ctx, ms.res.requestCodec, s.writer, ms.res.result, ms.res.err); nil != err {
logging.Logger().Error(fmt.Sprintf("RPC Server: response message error %v", err))
}
ms.res.requestCodec.Close()
case protocol.MessageTypeNotification:
if err := s.sh.SendNotification(s.ctx, s.serverCodec, s.writer, ms.noti.method, ms.noti.args...); nil != err {
logging.Logger().Error(fmt.Sprintf("RPC Server: response message error %v", err))
}
default:
}
releaseMessageState(ms)
case <-s.stopChan:
return
}
}
}
type messageState struct { type messageState struct {
messageType protocol.MessageType messageType protocol.MessageType
res messageResponse res messageResponse

View File

@ -10,10 +10,10 @@ type ServletHandler interface {
Init(servletCTX ServletContext) error Init(servletCTX ServletContext) error
GetRequest(servletCTX ServletContext, codec protocol.ServerCodec, reader interface{}) (protocol.ServerRequestCodec, error) GetRequest(servletCTX ServletContext, codec protocol.ServerCodec, conn interface{}) (protocol.ServerRequestCodec, error)
Invoke(servletCTX ServletContext, requestCodec protocol.RegistryCodec) (result interface{}, err error) Invoke(servletCTX ServletContext, requestCodec protocol.RegistryCodec) (result interface{}, err error)
SendResponse(servletCTX ServletContext, requestCodec protocol.ServerRequestCodec, writer interface{}, result interface{}, err error) error SendResponse(servletCTX ServletContext, conn interface{}, requestCodec protocol.ServerRequestCodec, result interface{}, err error) error
SendNotification(servletCTX ServletContext, codec protocol.ServerCodec, writer interface{}, method string, args ...interface{}) error SendNotification(servletCTX ServletContext, conn interface{}, codec protocol.ServerCodec, method string, args ...interface{}) error
Destroy(servletCTX ServletContext) Destroy(servletCTX ServletContext)

View File

@ -29,7 +29,7 @@ func (sh *ServletHandlers) Init(servletCTX ServletContext) error {
return nil return nil
} }
func (sh *ServletHandlers) GetRequest(servletCTX ServletContext, codec protocol.ServerCodec, reader interface{}) (protocol.ServerRequestCodec, error) { func (sh *ServletHandlers) GetRequest(servletCTX ServletContext, codec protocol.ServerCodec, conn interface{}) (protocol.ServerRequestCodec, error) {
return nil, fmt.Errorf("Servlet Handler: GetRequest is not implemented") return nil, fmt.Errorf("Servlet Handler: GetRequest is not implemented")
} }
@ -37,11 +37,11 @@ func (sh *ServletHandlers) Invoke(servletCTX ServletContext, requestCodec protoc
return nil, fmt.Errorf("Servlet Handler: Invoke is not implemented") return nil, fmt.Errorf("Servlet Handler: Invoke is not implemented")
} }
func (sh *ServletHandlers) SendResponse(servletCTX ServletContext, requestCodec protocol.ServerRequestCodec, writer interface{}, result interface{}, err error) error { func (sh *ServletHandlers) SendResponse(servletCTX ServletContext, conn interface{}, requestCodec protocol.ServerRequestCodec, result interface{}, err error) error {
return fmt.Errorf("Servlet Handler: SendResponse is not implemented") return fmt.Errorf("Servlet Handler: SendResponse is not implemented")
} }
func (sh *ServletHandlers) SendNotification(servletCTX ServletContext, codec protocol.ServerCodec, writer interface{}, method string, args ...interface{}) error { func (sh *ServletHandlers) SendNotification(servletCTX ServletContext, conn interface{}, codec protocol.ServerCodec, method string, args ...interface{}) error {
return fmt.Errorf("Servlet Handler: SendNotification is not implemented") return fmt.Errorf("Servlet Handler: SendNotification is not implemented")
} }