package rpc import ( "fmt" "sync" "git.loafle.net/commons_go/logging" "git.loafle.net/commons_go/rpc/protocol" ) func NewServlet(sh ServletHandler) Servlet { return &servlet{ sh: sh, } } type Servlet interface { Start(contentType string, reader interface{}, writer interface{}) error Stop() Send(method string, args ...interface{}) (err error) } type servlet struct { sh ServletHandler contentType string reader interface{} writer interface{} serverCodec protocol.ServerCodec messageQueueChan chan *messageState stopChan chan struct{} stopWg sync.WaitGroup } func (s *servlet) Start(contentType string, reader interface{}, writer interface{}) error { if nil == s.sh { panic("Servlet: servlet handler must be specified.") } s.sh.Validate() if s.stopChan != nil { panic("Servlet: servlet is already running. Stop it before starting it again") } sc, err := s.sh.getCodec(contentType) if nil != err { return err } s.contentType = contentType s.reader = reader s.writer = writer s.serverCodec = sc if err := s.sh.Init(); nil != err { logging.Logger().Panic(fmt.Sprintf("Servlet: Initialization of servlet has been failed %v", err)) } s.stopChan = make(chan struct{}) s.messageQueueChan = make(chan *messageState, s.sh.GetPendingMessages()) s.stopWg.Add(1) go handleServlet(s) return nil } func (s *servlet) Stop() { if s.stopChan == nil { panic("Server: server must be started before stopping it") } close(s.stopChan) s.stopWg.Wait() s.stopChan = nil s.sh.Destroy() s.contentType = "" s.reader = nil s.writer = nil s.serverCodec = nil logging.Logger().Info(fmt.Sprintf("Servlet is stopped")) } func (s *servlet) Send(method string, args ...interface{}) (err error) { ms := retainMessageState(protocol.MessageTypeNotification) ms.noti.method = method ms.noti.args = args s.messageQueueChan <- ms return nil } func handleServlet(s *servlet) { defer s.stopWg.Done() s.stopWg.Add(1) go handleMessage(s) for { requestCodec, err := s.sh.GetRequest(s.serverCodec, s.reader) if nil != err { continue } s.stopWg.Add(1) go handleRequest(s, requestCodec) select { case <-s.stopChan: default: } } } func handleRequest(s *servlet, requestCodec protocol.ServerRequestCodec) { defer func() { s.stopWg.Done() }() result, err := s.sh.Invoke(requestCodec) ms := retainMessageState(protocol.MessageTypeResponse) ms.res.requestCodec = requestCodec ms.res.result = result ms.res.err = err s.messageQueueChan <- ms } func handleMessage(s *servlet) { defer func() { s.stopWg.Done() }() for { select { case ms := <-s.messageQueueChan: switch ms.messageType { case protocol.MessageTypeResponse: if err := s.sh.SendResponse(ms.res.requestCodec, s.writer, ms.res.result, ms.res.err); nil != err { logging.Logger().Error(fmt.Sprintf("RPC Server: response message error %v", err)) } ms.res.requestCodec.Close() case protocol.MessageTypeNotification: if err := s.sh.SendNotification(s.serverCodec, s.writer, ms.noti.method, ms.noti.args...); nil != err { logging.Logger().Error(fmt.Sprintf("RPC Server: response message error %v", err)) } default: } releaseMessageState(ms) case <-s.stopChan: return } } } type messageState struct { messageType protocol.MessageType res messageResponse noti messageNotification } type messageResponse struct { requestCodec protocol.ServerRequestCodec result interface{} err error } type messageNotification struct { method string args []interface{} } var messageStatePool sync.Pool func retainMessageState(messageType protocol.MessageType) *messageState { var ms *messageState v := messageStatePool.Get() if v == nil { ms = &messageState{} } else { ms = v.(*messageState) } ms.messageType = messageType return ms } func releaseMessageState(ms *messageState) { ms.messageType = protocol.MessageTypeUnknown ms.res.requestCodec = nil ms.res.result = nil ms.res.err = nil ms.noti.method = "" ms.noti.args = nil messageStatePool.Put(ms) }