package rpc import ( "fmt" "io" "runtime" "sync" "git.loafle.net/commons_go/logging" "git.loafle.net/commons_go/rpc/protocol" cuc "git.loafle.net/commons_go/util/context" ) func NewServlet(sh ServletHandler, rwcSH ServletReadWriteCloseHandler) Servlet { return &rpcServlet{ sh: sh, rwcSH: rwcSH, } } type Servlet interface { Start(parentCTX cuc.Context, conn interface{}, doneChan chan<- error) error Stop() Send(method string, args ...interface{}) (err error) Context() ServletContext } type rpcServlet struct { ctx ServletContext sh ServletHandler rwcSH ServletReadWriteCloseHandler responseQueueChan chan *responseState doneChan chan<- error conn interface{} serverCodec protocol.ServerCodec stopChan chan struct{} stopWg sync.WaitGroup } func (s *rpcServlet) Start(parentCTX cuc.Context, conn interface{}, doneChan chan<- error) error { if nil == s.sh { panic("Servlet: servlet handler must be specified.") } s.sh.Validate() if nil == s.rwcSH { panic("Servlet: servlet RWC handler must be specified.") } s.rwcSH.Validate() if s.stopChan != nil { return fmt.Errorf("Servlet: servlet is already running. Stop it before starting it again") } s.ctx = s.sh.ServletContext(parentCTX) sc, err := s.sh.getCodec(s.ctx.GetAttribute(ContentTypeKey).(string)) if nil != err { return err } s.doneChan = doneChan s.conn = conn s.serverCodec = sc if err := s.sh.Init(s.ctx); nil != err { logging.Logger().Panic(fmt.Sprintf("Servlet: Initialization of servlet has been failed %v", err)) } s.stopChan = make(chan struct{}) s.responseQueueChan = make(chan *responseState, s.sh.GetPendingResponses()) s.stopWg.Add(1) go handleServlet(s) return nil } func (s *rpcServlet) Stop() { if s.stopChan == nil { panic("Server: server must be started before stopping it") } close(s.stopChan) s.stopWg.Wait() s.sh.Destroy(s.ctx) s.stopChan = nil s.responseQueueChan = nil s.conn = nil s.serverCodec = nil logging.Logger().Info(fmt.Sprintf("Servlet is stopped")) } func (s *rpcServlet) Send(method string, args ...interface{}) (err error) { noti := ¬ification{ method: method, args: args, } rs := &responseState{ noti: noti, } s.responseQueueChan <- rs return nil } func (s *rpcServlet) Context() ServletContext { return s.ctx } func handleServlet(s *rpcServlet) { var err error logging.Logger().Info(fmt.Sprintf("Servlet is started")) defer func() { s.stopWg.Done() s.Stop() s.doneChan <- err }() subStopChan := make(chan struct{}) readerDone := make(chan error, 1) go handleReader(s, subStopChan, readerDone) writerDone := make(chan error, 1) go handleWriter(s, subStopChan, writerDone) select { case err = <-readerDone: close(subStopChan) <-writerDone case err = <-writerDone: close(subStopChan) <-readerDone case <-s.stopChan: close(subStopChan) <-readerDone <-writerDone } if err != nil { logging.Logger().Error(fmt.Sprintf("RPC Server: servlet error %v", err)) } } func handleReader(s *rpcServlet, stopChan chan struct{}, doneChan chan error) { logging.Logger().Debug(fmt.Sprintf("reader of Servlet is started")) var err error defer func() { logging.Logger().Debug(fmt.Sprintf("reader of Servlet is stopped")) if r := recover(); r != nil { if err == nil { err = fmt.Errorf("RPC Server: Panic when reading request from client: %v", r) } } doneChan <- err }() for { if nil == s.conn { err = fmt.Errorf("RPC Server: disconnected from client") return } requestCodec, err := s.rwcSH.ReadRequest(s.ctx, s.serverCodec, s.conn) if nil != err { if err == io.ErrUnexpectedEOF || err == io.EOF { err = fmt.Errorf("RPC Server: disconnected from client") return } logging.Logger().Error(fmt.Sprintf("RPC Server: Cannot read request: [%s]", err)) continue } s.stopWg.Add(1) go handleRequest(s, requestCodec) select { case <-stopChan: err = fmt.Errorf("RPC Server: reading request stopped because get stop channel") return default: } } } func handleWriter(s *rpcServlet, stopChan chan struct{}, doneChan chan error) { logging.Logger().Debug(fmt.Sprintf("writer of Servlet is started")) var err error defer func() { logging.Logger().Debug(fmt.Sprintf("writer of Servlet is stopped")) if r := recover(); r != nil { if err == nil { err = fmt.Errorf("RPC Server: Panic when writing response to client: %v", r) } } doneChan <- err }() for { var rs *responseState select { case rs = <-s.responseQueueChan: default: // Give the last chance for ready goroutines filling s.responseQueueChan :) runtime.Gosched() select { case <-stopChan: err = fmt.Errorf("RPC Server: writing message stopped because get stop channel") return case rs = <-s.responseQueueChan: } } if nil == s.conn { err = fmt.Errorf("RPC Server: disconnected from client") return } if nil != rs.requestCodec { if err := s.rwcSH.WriteResponse(s.ctx, s.conn, rs.requestCodec, rs.result, rs.err); nil != err { logging.Logger().Error(fmt.Sprintf("RPC Server: response error %v", err)) } } else { if err := s.rwcSH.WriteNotification(s.ctx, s.conn, s.serverCodec, rs.noti.method, rs.noti.args); nil != err { logging.Logger().Error(fmt.Sprintf("RPC Server: notification error %v", err)) } } } } func handleRequest(s *rpcServlet, requestCodec protocol.ServerRequestCodec) { defer func() { s.stopWg.Done() }() result, err := s.sh.Invoke(s.ctx, requestCodec) rs := &responseState{ requestCodec: requestCodec, result: result, err: err, } s.responseQueueChan <- rs } type responseState struct { requestCodec protocol.ServerRequestCodec result interface{} noti *notification err error } type notification struct { method string args interface{} }