209 lines
4.0 KiB
Go
209 lines
4.0 KiB
Go
package rpc
|
|
|
|
import (
|
|
"fmt"
|
|
"sync"
|
|
|
|
"git.loafle.net/commons_go/logging"
|
|
"git.loafle.net/commons_go/rpc/protocol"
|
|
)
|
|
|
|
func NewServlet(sh ServletHandler) Servlet {
|
|
return &servlet{
|
|
sh: sh,
|
|
}
|
|
}
|
|
|
|
// Servlet serves RPC requests read from a reader and writes responses and
// notifications to a writer.
type Servlet interface {
	// Start binds the servlet to contentType, reader and writer and begins
	// serving. It returns an error when serving cannot begin.
	Start(contentType string, reader interface{}, writer interface{}) error
	// Stop ends serving that was begun by Start.
	Stop()

	// Send queues a notification for method carrying args.
	Send(method string, args ...interface{}) (err error)
}
|
|
|
|
type servlet struct {
|
|
sh ServletHandler
|
|
|
|
contentType string
|
|
reader interface{}
|
|
writer interface{}
|
|
serverCodec protocol.ServerCodec
|
|
|
|
messageQueueChan chan *messageState
|
|
|
|
stopChan chan struct{}
|
|
stopWg sync.WaitGroup
|
|
}
|
|
|
|
// Start binds the servlet to contentType, reader and writer and launches the
// goroutines that read requests and write responses.
//
// It panics if no ServletHandler was supplied or if the servlet is already
// running; both are programmer errors. It returns an error if the handler
// cannot provide a codec for contentType or if the handler's Init fails. In
// the Init case the servlet is left stopped with no bindings, so Start may be
// called again.
func (s *servlet) Start(contentType string, reader interface{}, writer interface{}) error {
	if nil == s.sh {
		panic("Servlet: servlet handler must be specified.")
	}
	s.sh.Validate()

	if s.stopChan != nil {
		panic("Servlet: servlet is already running. Stop it before starting it again")
	}

	sc, err := s.sh.getCodec(contentType)
	if nil != err {
		return err
	}

	s.contentType = contentType
	s.reader = reader
	s.writer = writer
	s.serverCodec = sc

	if err := s.sh.Init(); nil != err {
		// A failing Init is a runtime condition, not a programming error, so
		// report it through the error result instead of panicking. Drop the
		// bindings taken above so a failed Start does not retain them.
		s.contentType = ""
		s.reader = nil
		s.writer = nil
		s.serverCodec = nil
		return fmt.Errorf("Servlet: Initialization of servlet has been failed %v", err)
	}

	s.stopChan = make(chan struct{})
	s.messageQueueChan = make(chan *messageState, s.sh.GetPendingMessages())

	s.stopWg.Add(1)
	go handleServlet(s)

	return nil
}
|
|
|
|
func (s *servlet) Stop() {
|
|
if s.stopChan == nil {
|
|
panic("Server: server must be started before stopping it")
|
|
}
|
|
close(s.stopChan)
|
|
s.stopWg.Wait()
|
|
s.stopChan = nil
|
|
|
|
s.sh.Destroy()
|
|
|
|
s.contentType = ""
|
|
s.reader = nil
|
|
s.writer = nil
|
|
s.serverCodec = nil
|
|
|
|
logging.Logger().Info(fmt.Sprintf("Servlet is stopped"))
|
|
}
|
|
|
|
// Send queues a notification for method with args; handleMessage writes it
// out through the handler.
//
// It returns an error instead of blocking forever when the servlet has not
// been started, or when it is stopped while Send waits for room in the queue.
// NOTE(review): s.stopChan is read without synchronization, as the original
// read of s.messageQueueChan was; concurrent Start/Stop/Send are not guarded.
func (s *servlet) Send(method string, args ...interface{}) (err error) {
	stopChan := s.stopChan
	if nil == stopChan {
		return fmt.Errorf("Servlet: servlet must be started before sending a notification")
	}

	ms := retainMessageState(protocol.MessageTypeNotification)
	ms.noti.method = method
	ms.noti.args = args

	select {
	case s.messageQueueChan <- ms:
		return nil
	case <-stopChan:
		// Nobody will consume the message any more; return it to the pool.
		releaseMessageState(ms)
		return fmt.Errorf("Servlet: servlet has been stopped")
	}
}
|
|
|
|
func handleServlet(s *servlet) {
|
|
defer s.stopWg.Done()
|
|
|
|
s.stopWg.Add(1)
|
|
go handleMessage(s)
|
|
|
|
for {
|
|
requestCodec, err := s.sh.GetRequest(s.serverCodec, s.reader)
|
|
if nil != err {
|
|
continue
|
|
}
|
|
|
|
s.stopWg.Add(1)
|
|
go handleRequest(s, requestCodec)
|
|
|
|
select {
|
|
case <-s.stopChan:
|
|
default:
|
|
}
|
|
}
|
|
}
|
|
|
|
func handleRequest(s *servlet, requestCodec protocol.ServerRequestCodec) {
|
|
defer func() {
|
|
s.stopWg.Done()
|
|
}()
|
|
|
|
result, err := s.sh.Invoke(requestCodec)
|
|
|
|
ms := retainMessageState(protocol.MessageTypeResponse)
|
|
ms.res.requestCodec = requestCodec
|
|
ms.res.result = result
|
|
ms.res.err = err
|
|
|
|
s.messageQueueChan <- ms
|
|
}
|
|
|
|
// handleMessage takes queued responses and notifications off
// messageQueueChan and writes them to s.writer through the handler. It logs
// write failures rather than stopping.
//
// Once stopChan is closed, it discards whatever is still queued, closing the
// request codec of each pending response, and returns.
func handleMessage(s *servlet) {
	defer s.stopWg.Done()

	for {
		select {
		case ms := <-s.messageQueueChan:
			switch ms.messageType {
			case protocol.MessageTypeResponse:
				if err := s.sh.SendResponse(ms.res.requestCodec, s.writer, ms.res.result, ms.res.err); nil != err {
					logging.Logger().Error(fmt.Sprintf("RPC Server: response message error %v", err))
				}
				ms.res.requestCodec.Close()
			case protocol.MessageTypeNotification:
				if err := s.sh.SendNotification(s.serverCodec, s.writer, ms.noti.method, ms.noti.args...); nil != err {
					logging.Logger().Error(fmt.Sprintf("RPC Server: notification message error %v", err))
				}
			}

			releaseMessageState(ms)
		case <-s.stopChan:
			// Drain without writing so pending response codecs are still
			// closed. NOTE(review): a handleRequest racing with shutdown may
			// still enqueue after this drain finishes.
			for {
				select {
				case ms := <-s.messageQueueChan:
					if protocol.MessageTypeResponse == ms.messageType {
						ms.res.requestCodec.Close()
					}
					releaseMessageState(ms)
				default:
					return
				}
			}
		}
	}
}
|
|
|
|
type messageState struct {
|
|
messageType protocol.MessageType
|
|
res messageResponse
|
|
noti messageNotification
|
|
}
|
|
|
|
type messageResponse struct {
|
|
requestCodec protocol.ServerRequestCodec
|
|
result interface{}
|
|
err error
|
|
}
|
|
|
|
// messageNotification is a server-initiated notification queued by Send.
type messageNotification struct {
	method string        // name of the notified method
	args   []interface{} // arguments passed to Send
}
|
|
|
|
// messageStatePool recycles *messageState values between
// retainMessageState and releaseMessageState. It has no New function, so Get
// returns nil when the pool is empty.
var messageStatePool = sync.Pool{}
|
|
|
|
func retainMessageState(messageType protocol.MessageType) *messageState {
|
|
var ms *messageState
|
|
v := messageStatePool.Get()
|
|
if v == nil {
|
|
ms = &messageState{}
|
|
} else {
|
|
ms = v.(*messageState)
|
|
}
|
|
|
|
ms.messageType = messageType
|
|
|
|
return ms
|
|
}
|
|
|
|
func releaseMessageState(ms *messageState) {
|
|
ms.messageType = protocol.MessageTypeUnknown
|
|
|
|
ms.res.requestCodec = nil
|
|
ms.res.result = nil
|
|
ms.res.err = nil
|
|
|
|
ms.noti.method = ""
|
|
ms.noti.args = nil
|
|
|
|
messageStatePool.Put(ms)
|
|
}
|