package rpc

import (
	"fmt"
	"io"
	"runtime"
	"sync"

	"git.loafle.net/commons_go/logging"
	"git.loafle.net/commons_go/rpc/protocol"
	cuc "git.loafle.net/commons_go/util/context"
)

// NewServlet returns a Servlet backed by the given handler. The servlet is
// idle until Start is called.
func NewServlet(sh ServletHandler) Servlet {
	return &servlet{
		sh: sh,
	}
}

// Servlet drives the request/response and notification traffic of a single
// connection on behalf of a ServletHandler.
type Servlet interface {
	// Start prepares the servlet for conn and launches its processing
	// goroutine. The final error of that goroutine (nil when none was
	// recorded) is sent to doneChan when it terminates.
	Start(parentCTX cuc.Context, conn interface{}, doneChan chan<- error) error
	// Stop signals the processing goroutines to finish, waits for them and
	// releases the per-connection state. It panics if the servlet is not
	// running.
	Stop()
	// Send queues a notification carrying method and args for delivery by
	// the writer goroutine.
	Send(method string, args ...interface{}) (err error)
	// Context returns the servlet context created by Start.
	Context() ServletContext
}

// servlet is the default Servlet implementation.
type servlet struct {
	ctx ServletContext
	sh  ServletHandler

	// messageQueueChan carries outgoing responses and notifications to the
	// writer goroutine.
	messageQueueChan chan *messageState

	doneChan    chan<- error
	conn        interface{}
	serverCodec protocol.ServerCodec

	// stopChan is non-nil while the servlet is running; closing it asks
	// handleServlet to shut down.
	stopChan chan struct{}
	// stopWg tracks handleServlet and every in-flight handleRequest.
	stopWg sync.WaitGroup
}

// Start validates the handler, resolves the codec from the ContentTypeKey
// attribute of the servlet context, initializes the handler and launches
// handleServlet in its own goroutine.
//
// It panics when no handler was supplied, and the unchecked type assertion
// panics when the ContentTypeKey attribute is not a string. An Init failure is
// reported through the logger's Panic method. An error is returned when the
// servlet is already running or when the codec lookup fails.
func (s *servlet) Start(parentCTX cuc.Context, conn interface{}, doneChan chan<- error) error {
	if nil == s.sh {
		panic("Servlet: servlet handler must be specified.")
	}
	s.sh.Validate()

	if s.stopChan != nil {
		return fmt.Errorf("Servlet: servlet is already running. Stop it before starting it again")
	}

	servletCTX := s.sh.ServletContext(parentCTX)

	sc, err := s.sh.getCodec(servletCTX.GetAttribute(ContentTypeKey).(string))
	if nil != err {
		return err
	}

	// Keep the context: every later handler call and Context() read s.ctx.
	s.ctx = servletCTX
	s.doneChan = doneChan
	s.conn = conn
	s.serverCodec = sc

	if err := s.sh.Init(s.ctx); nil != err {
		logging.Logger().Panic(fmt.Sprintf("Servlet: Initialization of servlet has been failed %v", err))
	}

	s.stopChan = make(chan struct{})
	s.messageQueueChan = make(chan *messageState, s.sh.GetPendingMessages())

	s.stopWg.Add(1)
	go handleServlet(s)

	return nil
}

// Stop closes the stop channel, waits for all tracked goroutines, calls the
// handler's Destroy and clears the per-connection fields so the servlet can
// be started again. It panics if the servlet has not been started.
func (s *servlet) Stop() {
	if s.stopChan == nil {
		panic("Server: server must be started before stopping it")
	}
	close(s.stopChan)
	s.stopWg.Wait()
	s.stopChan = nil

	s.sh.Destroy(s.ctx)

	s.messageQueueChan = nil
	s.conn = nil
	s.serverCodec = nil

	logging.Logger().Info("Servlet is stopped")
}

// Send wraps method and args in a notification message and places it on the
// outgoing queue. It always returns nil.
//
// NOTE(review): the send blocks while the queue is full, and blocks forever
// when the queue is nil (servlet not started or already stopped) — confirm
// callers only use it on a running servlet.
func (s *servlet) Send(method string, args ...interface{}) (err error) {
	ms := retainMessageState(protocol.MessageTypeNotification)
	ms.noti.method = method
	ms.noti.args = args

	s.messageQueueChan <- ms

	return nil
}

// Context returns the servlet context stored by Start.
func (s *servlet) Context() ServletContext {
	return s.ctx
}

// handleServlet runs the reader and writer goroutines for s and waits until
// one of them finishes or the servlet is asked to stop. The error of whichever
// side finished first is logged and sent to s.doneChan; on an explicit stop
// nil is sent.
func handleServlet(s *servlet) {
	var err error

	defer func() {
		// NOTE(review): this send happens before stopWg.Done, so with an
		// unbuffered doneChan and no receiver, Stop's Wait cannot return —
		// confirm the caller always drains doneChan.
		s.doneChan <- err
		s.stopWg.Done()
	}()

	subStopChan := make(chan struct{})

	// Both done channels are buffered so a worker can always report and exit.
	readerDone := make(chan error, 1)
	go handleReader(s, subStopChan, readerDone)

	writerDone := make(chan error, 1)
	go handleWriter(s, subStopChan, writerDone)

	select {
	case err = <-readerDone:
		close(subStopChan)
		<-writerDone
	case err = <-writerDone:
		close(subStopChan)
		<-readerDone
	case <-s.stopChan:
		close(subStopChan)
		<-readerDone
		<-writerDone
	}

	if err != nil {
		logging.Logger().Error(fmt.Sprintf("RPC Server: servlet error %v", err))
	}
}

// handleReader reads requests from the connection in a loop and dispatches
// each one to its own handleRequest goroutine. It terminates on EOF, on a
// closed stopChan, or on a panic, and reports the reason on doneChan.
func handleReader(s *servlet, stopChan chan struct{}, doneChan chan error) {
	var err error

	defer func() {
		if r := recover(); r != nil {
			if err == nil {
				err = fmt.Errorf("RPC Server: Panic when reading request from client: %v", r)
			}
		}
		doneChan <- err
	}()

	for {
		// readErr is deliberately a separate variable: declaring err with :=
		// here would shadow the function-level err the deferred send reports.
		requestCodec, readErr := s.sh.GetRequest(s.ctx, s.serverCodec, s.conn)
		if nil != readErr {
			if readErr == io.ErrUnexpectedEOF || readErr == io.EOF {
				err = fmt.Errorf("RPC Server: disconnected from client")
				return
			}
			logging.Logger().Error(fmt.Sprintf("RPC Server: Cannot read request: [%s]", readErr))
		} else {
			s.stopWg.Add(1)
			go handleRequest(s, requestCodec)
		}

		// Checked after failed reads too, so a persistently failing
		// connection cannot keep the loop alive once a stop was requested.
		select {
		case <-stopChan:
			err = fmt.Errorf("RPC Server: reading request stopped because get stop channel")
			return
		default:
		}
	}
}

// handleWriter drains s.messageQueueChan and hands each message to the
// handler's SendResponse or SendNotification. Send failures are logged and do
// not end the loop. It terminates when stopChan is closed while the queue is
// empty, or on a panic, and reports the reason on doneChan.
func handleWriter(s *servlet, stopChan chan struct{}, doneChan chan error) {
	var err error

	defer func() {
		if r := recover(); r != nil {
			if err == nil {
				err = fmt.Errorf("RPC Server: Panic when writing message to client: %v", r)
			}
		}
		doneChan <- err
	}()

	for {
		var ms *messageState

		select {
		case ms = <-s.messageQueueChan:
		default:
			// Give the last chance for ready goroutines filling s.messageQueueChan :)
			runtime.Gosched()

			select {
			case <-stopChan:
				err = fmt.Errorf("RPC Server: writing message stopped because get stop channel")
				return
			case ms = <-s.messageQueueChan:
			}
		}

		switch ms.messageType {
		case protocol.MessageTypeResponse:
			if sendErr := s.sh.SendResponse(s.ctx, s.conn, ms.res.requestCodec, ms.res.result, ms.res.err); nil != sendErr {
				logging.Logger().Error(fmt.Sprintf("RPC Server: response message error %v", sendErr))
			}
			ms.res.requestCodec.Close()
		case protocol.MessageTypeNotification:
			if sendErr := s.sh.SendNotification(s.ctx, s.conn, s.serverCodec, ms.noti.method, ms.noti.args...); nil != sendErr {
				logging.Logger().Error(fmt.Sprintf("RPC Server: notification message error %v", sendErr))
			}
		default:
		}

		// The message is no longer referenced; return it to the pool so
		// retainMessageState can reuse it.
		releaseMessageState(ms)
	}
}

// handleRequest invokes the handler for one request and queues the result (or
// error) as a response message for the writer goroutine.
//
// NOTE(review): the queue send blocks while the queue is full; if the writer
// has already exited this goroutine would keep stopWg from reaching zero —
// confirm the queue capacity and shutdown order make that impossible.
func handleRequest(s *servlet, requestCodec protocol.ServerRequestCodec) {
	defer s.stopWg.Done()

	result, err := s.sh.Invoke(s.ctx, requestCodec)

	ms := retainMessageState(protocol.MessageTypeResponse)
	ms.res.requestCodec = requestCodec
	ms.res.result = result
	ms.res.err = err

	s.messageQueueChan <- ms
}

// messageState is one outgoing message. messageType selects which of res or
// noti the writer reads.
type messageState struct {
	messageType protocol.MessageType
	res         messageResponse
	noti        messageNotification
}

// messageResponse holds the outcome of an invoked request together with the
// request codec it answers.
type messageResponse struct {
	requestCodec protocol.ServerRequestCodec
	result       interface{}
	err          error
}

// messageNotification holds a server-initiated notification.
type messageNotification struct {
	method string
	args   []interface{}
}

// messageStatePool recycles messageState values. It has no New function, so
// Get returns nil when the pool is empty.
var messageStatePool sync.Pool

// retainMessageState returns a messageState with messageType set, taken from
// the pool when one is available and freshly allocated otherwise.
func retainMessageState(messageType protocol.MessageType) *messageState {
	var ms *messageState

	v := messageStatePool.Get()
	if v == nil {
		ms = &messageState{}
	} else {
		ms = v.(*messageState)
	}

	ms.messageType = messageType

	return ms
}

// releaseMessageState clears every field of ms and returns it to the pool.
// The caller must not use ms afterwards.
func releaseMessageState(ms *messageState) {
	ms.messageType = protocol.MessageTypeUnknown

	ms.res.requestCodec = nil
	ms.res.result = nil
	ms.res.err = nil

	ms.noti.method = ""
	ms.noti.args = nil

	messageStatePool.Put(ms)
}