rpc/servlet.go
crusader 91fb8f2093 ing
2018-03-23 17:06:36 +09:00

276 lines
5.8 KiB
Go

package rpc
import (
"fmt"
"io"
"runtime"
"sync"
"git.loafle.net/commons_go/logging"
"git.loafle.net/commons_go/rpc/protocol"
cuc "git.loafle.net/commons_go/util/context"
)
// NewServlet builds a Servlet that delegates lifecycle and invocation work to
// sh and message reading/writing to rwcSH. The handlers are not checked here;
// Start rejects a nil handler.
func NewServlet(sh ServletHandler, rwcSH ServletReadWriteCloseHandler) Servlet {
	servlet := new(rpcServlet)
	servlet.sh = sh
	servlet.rwcSH = rwcSH
	return servlet
}
type Servlet interface {
Start(parentCTX cuc.Context, conn interface{}, doneChan chan<- error) error
Stop()
Send(method string, args ...interface{}) (err error)
Context() ServletContext
}
type rpcServlet struct {
ctx ServletContext
sh ServletHandler
rwcSH ServletReadWriteCloseHandler
responseQueueChan chan *responseState
conn interface{}
serverCodec protocol.ServerCodec
decoder interface{}
stopChan chan struct{}
stopWg sync.WaitGroup
requestMtx sync.Mutex
responseMTX sync.Mutex
}
// Context returns the ServletContext built by Start (the zero value before
// Start has created it).
func (s *rpcServlet) Context() ServletContext {
	ctx := s.ctx
	return ctx
}
// Start validates the handlers, builds the servlet context, codec and decoder
// for conn, runs the handler's Init hook, and then launches the supervisor
// goroutine. When the supervisor exits, the run's final error is sent on
// doneChan.
//
// Start returns an error without launching anything when:
//   - a handler is missing,
//   - the servlet is already running,
//   - the context has no string value under ContentTypeKey,
//   - the codec lookup fails,
//   - the decoder cannot be built, or
//   - Init fails.
func (s *rpcServlet) Start(parentCTX cuc.Context, conn interface{}, doneChan chan<- error) error {
	if nil == s.sh {
		return fmt.Errorf("RPC Servlet: servlet handler must be specified")
	}
	s.sh.Validate()

	if nil == s.rwcSH {
		return fmt.Errorf("RPC Servlet: servlet RWC handler must be specified")
	}
	s.rwcSH.Validate()

	if s.stopChan != nil {
		return fmt.Errorf("RPC Servlet: servlet is already running. Stop it before starting it again")
	}

	s.ctx = s.sh.ServletContext(parentCTX)

	// Checked assertion: a missing or non-string attribute must surface as an
	// error to the caller instead of a panic.
	contentTypeAttr := s.ctx.GetAttribute(ContentTypeKey)
	contentType, ok := contentTypeAttr.(string)
	if !ok {
		return fmt.Errorf("RPC Servlet: content type must be a string but got %T", contentTypeAttr)
	}

	sc, err := s.sh.getCodec(contentType)
	if nil != err {
		return err
	}

	s.conn = conn
	s.serverCodec = sc
	s.decoder, err = s.rwcSH.NewDecoder(s.ctx, sc, conn)
	if nil != err {
		// Keep the underlying cause; it was previously discarded.
		return fmt.Errorf("RPC Servlet: Cannot build rpc decoder %v", err)
	}

	err = s.sh.Init(s.ctx)
	if nil != err {
		return fmt.Errorf("RPC Servlet: Initialization of servlet has been failed %v", err)
	}

	s.responseQueueChan = make(chan *responseState, s.sh.GetPendingResponses())
	s.stopChan = make(chan struct{})
	s.stopWg.Add(1)
	go handleServlet(s, doneChan)

	return nil
}
// Send queues a server-initiated notification for method with args. The
// writer goroutine delivers it through WriteNotification.
//
// Send returns an error instead of blocking forever in two cases:
//   - The servlet is not running. The queue is nil before Start and after
//     Stop, and a send on a nil channel never completes.
//   - The servlet is stopped while Send is waiting for queue space.
//
// NOTE(review): the queue and stop-channel fields are read without
// synchronization against Stop; callers should not race Send with Stop.
func (s *rpcServlet) Send(method string, args ...interface{}) error {
	queue, stop := s.responseQueueChan, s.stopChan
	if nil == queue || nil == stop {
		return fmt.Errorf("RPC Servlet: servlet is not running")
	}

	rs := &responseState{
		noti: &notification{
			method: method,
			args:   args,
		},
	}

	select {
	case queue <- rs:
		return nil
	case <-stop:
		return fmt.Errorf("RPC Servlet: servlet is stopped")
	}
}
// Stop shuts the servlet down. It:
//   - requests shutdown by closing stopChan,
//   - waits for the supervisor goroutine and every in-flight request handler
//     (tracked by stopWg),
//   - calls the handler's Destroy hook,
//   - clears the per-run state so that Start can be called again.
//
// Calling Stop on a servlet that is not running only logs a warning.
//
// NOTE(review): the reader only notices the stop between reads. If it is
// blocked inside ReadRequest, Stop waits until that read returns. Presumably
// the connection is closed elsewhere to unblock it; confirm.
func (s *rpcServlet) Stop() {
	if s.stopChan == nil {
		logging.Logger().Warnf("RPC Servlet: RPC Servlet must be started before stopping it")
		return
	}

	close(s.stopChan)
	s.stopWg.Wait()

	s.sh.Destroy(s.ctx)

	s.stopChan = nil
	s.responseQueueChan = nil
	s.conn = nil
	s.serverCodec = nil
	// The decoder was built from conn in Start. Drop it along with the
	// connection so a stopped servlet does not keep it reachable.
	s.decoder = nil

	logging.Logger().Infof("RPC Servlet: RPC Servlet is stopped")
}
// handleServlet supervises one run of s. It starts the reader and writer
// goroutines and waits until either of them finishes or s.stopChan is closed.
// It then signals the remaining workers through a shared stop channel and
// waits for them.
//
// The error from whichever worker finished first (nil on an external stop) is
// logged. It is sent on doneChan after this goroutine's stopWg slot has been
// released.
func handleServlet(s *rpcServlet, doneChan chan<- error) {
	var err error
	logging.Logger().Infof("RPC Servlet: RPC Servlet is started")
	defer func() {
		s.stopWg.Done()
		doneChan <- err
	}()

	workersStop := make(chan struct{})
	readerResult := make(chan error, 1)
	writerResult := make(chan error, 1)
	go handleReader(s, workersStop, readerResult)
	go handleWriter(s, workersStop, writerResult)

	// Record which workers still have to report after the first event.
	var pending []chan error
	select {
	case err = <-readerResult:
		pending = []chan error{writerResult}
	case err = <-writerResult:
		pending = []chan error{readerResult}
	case <-s.stopChan:
		pending = []chan error{readerResult, writerResult}
	}

	close(workersStop)
	for _, result := range pending {
		<-result
	}

	if nil != err {
		logging.Logger().Errorf("RPC Servlet: servlet error %v", err)
	}
}
// handleReader reads requests from s until the connection ends or stopChan is
// closed. Each request is dispatched to its own handleRequest goroutine,
// tracked by s.stopWg. The reason the reader stopped is sent on doneChan.
//
// Read errors other than io.EOF and io.ErrUnexpectedEOF are logged, and the
// loop keeps reading.
func handleReader(s *rpcServlet, stopChan chan struct{}, doneChan chan error) {
	logging.Logger().Debugf("RPC Servlet: Reader of Servlet is started")
	var err error
	defer func() {
		logging.Logger().Debugf("RPC Servlet: Reader of Servlet is stopped")
		doneChan <- err
	}()

	for {
		// Checked before every read, including after a failed read, so that
		// a run of read errors cannot keep the reader from noticing a stop.
		select {
		case <-stopChan:
			err = fmt.Errorf("RPC Servlet: Reading request stopped because get stop channel")
			return
		default:
		}

		if nil == s.conn {
			err = fmt.Errorf("RPC Servlet: Disconnected from client")
			return
		}

		// readErr is deliberately distinct from err. err is the value the
		// deferred send reports, so it must not be shadowed inside the loop.
		s.requestMtx.Lock()
		requestCodec, readErr := s.rwcSH.ReadRequest(s.ctx, s.serverCodec, s.decoder)
		s.requestMtx.Unlock()

		if nil != readErr {
			if readErr == io.ErrUnexpectedEOF || readErr == io.EOF {
				err = fmt.Errorf("RPC Servlet: Disconnected from client")
				return
			}
			logging.Logger().Errorf("RPC Servlet: Cannot read request: [%s]", readErr)
			continue
		}

		s.stopWg.Add(1)
		go handleRequest(s, requestCodec)
	}
}
// handleWriter drains s.responseQueueChan. An entry that carries a request
// codec is written as a response; any other entry is written as a
// notification.
//
// Queued entries take priority: stopChan is only consulted once the queue has
// been found empty. Write failures are logged and the loop continues. The
// reason the writer stopped is sent on doneChan.
func handleWriter(s *rpcServlet, stopChan chan struct{}, doneChan chan error) {
	logging.Logger().Debugf("RPC Servlet: Writer of Servlet is started")
	var err error
	defer func() {
		logging.Logger().Debugf("RPC Servlet: Writer of Servlet is stopped")
		doneChan <- err
	}()

	for {
		var rs *responseState
		select {
		case rs = <-s.responseQueueChan:
		default:
			// Yield once so goroutines that are about to enqueue get a chance
			// before this select can also observe the stop signal.
			runtime.Gosched()
			select {
			case rs = <-s.responseQueueChan:
			case <-stopChan:
				err = fmt.Errorf("RPC Servlet: Writing message stopped because get stop channel")
				return
			}
		}

		if s.conn == nil {
			err = fmt.Errorf("RPC Servlet: Disconnected from client")
			return
		}

		s.responseMTX.Lock()
		switch {
		case rs.requestCodec != nil:
			if writeErr := s.rwcSH.WriteResponse(s.ctx, s.conn, rs.requestCodec, rs.result, rs.err); writeErr != nil {
				logging.Logger().Errorf("RPC Servlet: response error %v", writeErr)
			}
		default:
			if writeErr := s.rwcSH.WriteNotification(s.ctx, s.conn, s.serverCodec, rs.noti.method, rs.noti.args); writeErr != nil {
				logging.Logger().Errorf("RPC Servlet: notification error %v", writeErr)
			}
		}
		s.responseMTX.Unlock()
	}
}
// handleRequest invokes one decoded request through the servlet handler. If
// the request expects a response, it queues the result for the writer. It
// always releases the s.stopWg slot that the reader acquired for it.
func handleRequest(s *rpcServlet, requestCodec protocol.ServerRequestCodec) {
	defer s.stopWg.Done()

	result, err := s.sh.Invoke(s.ctx, requestCodec)
	if !requestCodec.HasResponse() {
		return
	}

	rs := &responseState{
		requestCodec: requestCodec,
		result:       result,
		err:          err,
	}

	// Stop waits on stopWg. A handler blocked forever on a full queue (for
	// example after the writer has exited) would deadlock Stop, so give up
	// once stopChan is closed. The read of s.stopChan is ordered before Stop
	// resets the field, because Stop only does that after stopWg.Wait returns.
	select {
	case s.responseQueueChan <- rs:
	case <-s.stopChan:
	}
}
type (
	// responseState is one item on the writer's queue. It is either a reply to
	// a request (requestCodec non-nil, with result and err) or a
	// server-initiated notification (noti).
	responseState struct {
		requestCodec protocol.ServerRequestCodec
		result       interface{}
		err          error
		noti         *notification
	}

	// notification holds the method name and arguments passed to Send.
	notification struct {
		method string
		args   []interface{}
	}
)