package rpc import ( "fmt" "io" "runtime" "sync" "git.loafle.net/commons_go/logging" "git.loafle.net/commons_go/rpc/protocol" cuc "git.loafle.net/commons_go/util/context" ) func NewServlet(sh ServletHandler, rwcSH ServletReadWriteCloseHandler) Servlet { return &rpcServlet{ sh: sh, rwcSH: rwcSH, } } type Servlet interface { Start(parentCTX cuc.Context, conn interface{}, doneChan chan<- error) error Stop() Send(method string, args ...interface{}) (err error) Context() ServletContext } type rpcServlet struct { ctx ServletContext sh ServletHandler rwcSH ServletReadWriteCloseHandler responseQueueChan chan *responseState conn interface{} serverCodec protocol.ServerCodec stopChan chan struct{} stopWg sync.WaitGroup } func (s *rpcServlet) Context() ServletContext { return s.ctx } func (s *rpcServlet) Start(parentCTX cuc.Context, conn interface{}, doneChan chan<- error) error { if nil == s.sh { return fmt.Errorf("RPC Servlet: servlet handler must be specified") } s.sh.Validate() if nil == s.rwcSH { return fmt.Errorf("RPC Servlet: servlet RWC handler must be specified") } s.rwcSH.Validate() if s.stopChan != nil { return fmt.Errorf("RPC Servlet: servlet is already running. Stop it before starting it again") } s.ctx = s.sh.ServletContext(parentCTX) sc, err := s.sh.getCodec(s.ctx.GetAttribute(ContentTypeKey).(string)) if nil != err { return err } s.conn = conn s.serverCodec = sc if err := s.sh.Init(s.ctx); nil != err { return fmt.Errorf("RPC Servlet: Initialization of servlet has been failed %v", err) } s.responseQueueChan = make(chan *responseState, s.sh.GetPendingResponses()) s.stopChan = make(chan struct{}) s.stopWg.Add(1) go handleServlet(s, doneChan) return nil } func (s *rpcServlet) Send(method string, args ...interface{}) (err error) { noti := ¬ification{ method: method, args: args, } rs := &responseState{ noti: noti, } s.responseQueueChan <- rs return nil } func (s *rpcServlet) Stop() { if s.stopChan == nil { logging.Logger().Warnf("RPC Servlet: RPC Servlet must be started before stopping it") return } close(s.stopChan) s.stopWg.Wait() s.sh.Destroy(s.ctx) s.stopChan = nil s.responseQueueChan = nil s.conn = nil s.serverCodec = nil logging.Logger().Infof("RPC Servlet: RPC Servlet is stopped") } func handleServlet(s *rpcServlet, doneChan chan<- error) { var err error logging.Logger().Infof("RPC Servlet: RPC Servlet is started") defer func() { s.stopWg.Done() doneChan <- err }() subStopChan := make(chan struct{}) readerDone := make(chan error, 1) go handleReader(s, subStopChan, readerDone) writerDone := make(chan error, 1) go handleWriter(s, subStopChan, writerDone) select { case err = <-readerDone: close(subStopChan) <-writerDone case err = <-writerDone: close(subStopChan) <-readerDone case <-s.stopChan: close(subStopChan) <-readerDone <-writerDone } if err != nil { logging.Logger().Errorf("RPC Servlet: servlet error %v", err) } } func handleReader(s *rpcServlet, stopChan chan struct{}, doneChan chan error) { logging.Logger().Debugf("RPC Servlet: Reader of Servlet is started") var err error defer func() { logging.Logger().Debugf("RPC Servlet: Reader of Servlet is stopped") doneChan <- err }() for { if nil == s.conn { err = fmt.Errorf("RPC Servlet: Disconnected from client") return } requestCodec, err := s.rwcSH.ReadRequest(s.ctx, s.serverCodec, s.conn) if nil != err { if err == io.ErrUnexpectedEOF || err == io.EOF { err = fmt.Errorf("RPC Server: Disconnected from client") return } logging.Logger().Errorf("RPC Servlet: Cannot read request: [%s]", err) continue } s.stopWg.Add(1) go handleRequest(s, requestCodec) select { case <-stopChan: err = fmt.Errorf("RPC Servlet: Reading request stopped because get stop channel") return default: } } } func handleWriter(s *rpcServlet, stopChan chan struct{}, doneChan chan error) { logging.Logger().Debugf("RPC Servlet: Writer of Servlet is started") var err error defer func() { logging.Logger().Debugf("RPC Servlet: Writer of Servlet is stopped") doneChan <- err }() for { var rs *responseState select { case rs = <-s.responseQueueChan: default: // Give the last chance for ready goroutines filling s.responseQueueChan :) runtime.Gosched() select { case <-stopChan: err = fmt.Errorf("RPC Servlet: Writing message stopped because get stop channel") return case rs = <-s.responseQueueChan: } } if nil == s.conn { err = fmt.Errorf("RPC Servlet: Disconnected from client") return } if nil != rs.requestCodec { if err := s.rwcSH.WriteResponse(s.ctx, s.conn, rs.requestCodec, rs.result, rs.err); nil != err { logging.Logger().Errorf("RPC Servlet: response error %v", err) } } else { if err := s.rwcSH.WriteNotification(s.ctx, s.conn, s.serverCodec, rs.noti.method, rs.noti.args); nil != err { logging.Logger().Errorf("RPC Servlet: notification error %v", err) } } } } func handleRequest(s *rpcServlet, requestCodec protocol.ServerRequestCodec) { defer func() { s.stopWg.Done() }() result, err := s.sh.Invoke(s.ctx, requestCodec) if !requestCodec.HasResponse() { return } rs := &responseState{ requestCodec: requestCodec, result: result, err: err, } s.responseQueueChan <- rs } type responseState struct { requestCodec protocol.ServerRequestCodec result interface{} noti *notification err error } type notification struct { method string args []interface{} }