From 96b71dc5d5885bbf1f09fea6871256b515d10c3e Mon Sep 17 00:00:00 2001 From: crusader Date: Tue, 28 Nov 2017 18:33:52 +0900 Subject: [PATCH] ing --- connection/socket/servlet_handlers.go | 29 ++-- .../websocket/fasthttp/servlet_handlers.go | 46 +++--- servlet.go | 143 ++++++++++++------ servlet_handler.go | 6 +- servlet_handlers.go | 6 +- 5 files changed, 144 insertions(+), 86 deletions(-) diff --git a/connection/socket/servlet_handlers.go b/connection/socket/servlet_handlers.go index 1c5beaf..09b8525 100644 --- a/connection/socket/servlet_handlers.go +++ b/connection/socket/servlet_handlers.go @@ -1,42 +1,43 @@ package socket import ( - "io" + "net" + "git.loafle.net/commons_go/rpc" "git.loafle.net/commons_go/rpc/protocol" ) type ServletHandlers struct { } -func (sh *ServletHandlers) GetRequest(codec protocol.ServerCodec, reader interface{}) (protocol.ServerRequestCodec, error) { - r := reader.(io.Reader) - requestCodec, err := codec.NewRequest(r) +func (sh *ServletHandlers) GetRequest(servletCTX rpc.ServletContext, codec protocol.ServerCodec, conn interface{}) (protocol.ServerRequestCodec, error) { + nConn := conn.(net.Conn) + requestCodec, err := codec.NewRequest(nConn) return requestCodec, err } -func (sh *ServletHandlers) SendResponse(requestCodec protocol.ServerRequestCodec, writer interface{}, result interface{}, err error) error { - w := writer.(io.Writer) +func (sh *ServletHandlers) SendResponse(servletCTX rpc.ServletContext, conn interface{}, requestCodec protocol.ServerRequestCodec, result interface{}, err error) error { + nConn := conn.(net.Conn) if nil != err { - if lerr := requestCodec.WriteError(w, 500, err); nil != lerr { - + if wErr := requestCodec.WriteError(nConn, 500, err); nil != wErr { + return wErr } } else { - if err := requestCodec.WriteResponse(w, result); nil != err { - + if wErr := requestCodec.WriteResponse(nConn, result); nil != wErr { + return wErr } } return nil } -func (sh *ServletHandlers) SendNotification(codec protocol.ServerCodec, writer interface{}, method string, args ...interface{}) error { - w := writer.(io.Writer) - - if err := codec.WriteNotification(w, method, args); nil != err { +func (sh *ServletHandlers) SendNotification(servletCTX rpc.ServletContext, conn interface{}, codec protocol.ServerCodec, method string, args ...interface{}) error { + nConn := conn.(net.Conn) + if wErr := codec.WriteNotification(nConn, method, args); nil != wErr { + return wErr } return nil diff --git a/connection/websocket/fasthttp/servlet_handlers.go b/connection/websocket/fasthttp/servlet_handlers.go index 9e74b86..33958c6 100644 --- a/connection/websocket/fasthttp/servlet_handlers.go +++ b/connection/websocket/fasthttp/servlet_handlers.go @@ -1,57 +1,57 @@ package fasthttp import ( - "fmt" + "github.com/gorilla/websocket" "git.loafle.net/commons_go/rpc" "git.loafle.net/commons_go/rpc/protocol" - "git.loafle.net/commons_go/websocket_fasthttp/websocket" + cwf "git.loafle.net/commons_go/websocket_fasthttp" ) type ServletHandlers struct { } -func (sh *ServletHandlers) GetRequest(servletCTX rpc.ServletContext, codec protocol.ServerCodec, reader interface{}) (protocol.ServerRequestCodec, error) { - conn := reader.(*websocket.Conn) - _, r, err := conn.NextReader() +func (sh *ServletHandlers) GetRequest(servletCTX rpc.ServletContext, codec protocol.ServerCodec, conn interface{}) (protocol.ServerRequestCodec, error) { + soc := conn.(cwf.Socket) + _, r, err := soc.NextReader() requestCodec, err := codec.NewRequest(r) return requestCodec, err } -func (sh *ServletHandlers) SendResponse(servletCTX rpc.ServletContext, requestCodec protocol.ServerRequestCodec, writer interface{}, result interface{}, err error) error { - conn := writer.(*websocket.Conn) - - wc, lerr := conn.NextWriter(websocket.TextMessage) - if nil != lerr { +func (sh *ServletHandlers) SendResponse(servletCTX rpc.ServletContext, conn interface{}, requestCodec protocol.ServerRequestCodec, result interface{}, err error) error { + soc := conn.(cwf.Socket) + wc, wErr := soc.NextWriter(websocket.TextMessage) + if nil != wErr { + return wErr } if nil != err { - if lerr := requestCodec.WriteError(wc, 500, err); nil != lerr { - + if wErr := requestCodec.WriteError(wc, 500, err); nil != wErr { + return wErr } } else { - if err := requestCodec.WriteResponse(wc, result); nil != err { - + if wErr := requestCodec.WriteResponse(wc, result); nil != wErr { + return wErr } } - return fmt.Errorf("Servlet Handler: SendResponse is not implemented") + return nil } -func (sh *ServletHandlers) SendNotification(servletCTX rpc.ServletContext, codec protocol.ServerCodec, writer interface{}, method string, args ...interface{}) error { - conn := writer.(*websocket.Conn) - - wc, lerr := conn.NextWriter(websocket.TextMessage) - if nil != lerr { +func (sh *ServletHandlers) SendNotification(servletCTX rpc.ServletContext, conn interface{}, codec protocol.ServerCodec, method string, args ...interface{}) error { + soc := conn.(cwf.Socket) + wc, wErr := soc.NextWriter(websocket.TextMessage) + if nil != wErr { + return wErr } - if err := codec.WriteNotification(wc, method, args); nil != err { - + if wErr := codec.WriteNotification(wc, method, args); nil != wErr { + return wErr } - return fmt.Errorf("Servlet Handler: SendNotification is not implemented") + return nil } diff --git a/servlet.go b/servlet.go index 0a5b61c..c113e66 100644 --- a/servlet.go +++ b/servlet.go @@ -2,6 +2,8 @@ package rpc import ( "fmt" + "io" + "runtime" "sync" "git.loafle.net/commons_go/logging" @@ -16,7 +18,7 @@ func NewServlet(sh ServletHandler) Servlet { } type Servlet interface { - Start(parentCTX cuc.Context, reader interface{}, writer interface{}) error + Start(parentCTX cuc.Context, conn interface{}) error Stop() Send(method string, args ...interface{}) (err error) @@ -29,15 +31,14 @@ type servlet struct { sh ServletHandler messageQueueChan chan *messageState - reader interface{} - writer interface{} + conn interface{} serverCodec protocol.ServerCodec stopChan chan struct{} stopWg sync.WaitGroup } -func (s *servlet) Start(parentCTX cuc.Context, reader interface{}, writer interface{}, doneChan chan<- struct{}) error { +func (s *servlet) Start(parentCTX cuc.Context, conn interface{}) error { if nil == s.sh { panic("Servlet: servlet handler must be specified.") } @@ -53,8 +54,7 @@ func (s *servlet) Start(parentCTX cuc.Context, reader interface{}, writer interf return err } - s.reader = reader - s.writer = writer + s.conn = conn s.serverCodec = sc if err := s.sh.Init(s.ctx); nil != err { @@ -65,7 +65,7 @@ func (s *servlet) Start(parentCTX cuc.Context, reader interface{}, writer interf s.messageQueueChan = make(chan *messageState, s.sh.GetPendingMessages()) s.stopWg.Add(1) - go handleServlet(s, doneChan) + go handleServlet(s) return nil } @@ -81,8 +81,7 @@ func (s *servlet) Stop() { s.messageQueueChan = nil - s.reader = nil - s.writer = nil + s.conn = nil s.serverCodec = nil logging.Logger().Info(fmt.Sprintf("Servlet is stopped")) @@ -102,17 +101,56 @@ func (s *servlet) Context() ServletContext { return s.ctx } -func handleServlet(s *servlet, doneChan chan<- struct{}) { +func handleServlet(s *servlet) { defer s.stopWg.Done() - messageStopChan := make(chan struct{}) - messageDoneChan := make(chan struct{}) + subStopChan := make(chan struct{}) - go handleMessage(s, messageStopChan, messageDoneChan) + readerDone := make(chan error, 1) + go handleReader(s, subStopChan, readerDone) + + writerDone := make(chan error, 1) + go handleWriter(s, subStopChan, writerDone) + + var err error + + select { + case err = <-readerDone: + close(subStopChan) + <-writerDone + case err = <-writerDone: + close(subStopChan) + <-readerDone + case <-s.stopChan: + close(subStopChan) + <-readerDone + <-writerDone + } + + if err != nil { + logging.Logger().Error(fmt.Sprintf("RPC Server: servlet error %v", err)) + } +} + +func handleReader(s *servlet, stopChan chan struct{}, doneChan chan error) { + var err error + defer func() { + if r := recover(); r != nil { + if err == nil { + err = fmt.Errorf("RPC Server: Panic when reading request from client: %v", r) + } + } + doneChan <- err + }() for { - requestCodec, err := s.sh.GetRequest(s.ctx, s.serverCodec, s.reader) + requestCodec, err := s.sh.GetRequest(s.ctx, s.serverCodec, s.conn) if nil != err { + if err == io.ErrUnexpectedEOF || err == io.EOF { + err = fmt.Errorf("RPC Server: disconnected from client") + return + } + logging.Logger().Error(fmt.Sprintf("RPC Server: Cannot read request: [%s]", err)) continue } @@ -120,12 +158,59 @@ func handleServlet(s *servlet, doneChan chan<- struct{}) { go handleRequest(s, requestCodec) select { - case <-s.stopChan: + case <-stopChan: + err = fmt.Errorf("RPC Server: reading request stopped because get stop channel") + return default: } } } +func handleWriter(s *servlet, stopChan chan struct{}, doneChan chan error) { + var err error + defer func() { + if r := recover(); r != nil { + if err == nil { + err = fmt.Errorf("RPC Server: Panic when writing message to client: %v", r) + } + } + doneChan <- err + }() + + for { + var ms *messageState + + select { + case ms = <-s.messageQueueChan: + default: + // Give the last chance for ready goroutines filling s.messageQueueChan :) + runtime.Gosched() + + select { + case <-stopChan: + err = fmt.Errorf("RPC Server: writing message stopped because get stop channel") + return + case ms = <-s.messageQueueChan: + } + } + + switch ms.messageType { + case protocol.MessageTypeResponse: + if err := s.sh.SendResponse(s.ctx, s.conn, ms.res.requestCodec, ms.res.result, ms.res.err); nil != err { + logging.Logger().Error(fmt.Sprintf("RPC Server: response message error %v", err)) + } + ms.res.requestCodec.Close() + case protocol.MessageTypeNotification: + if err := s.sh.SendNotification(s.ctx, s.conn, s.serverCodec, ms.noti.method, ms.noti.args...); nil != err { + logging.Logger().Error(fmt.Sprintf("RPC Server: response message error %v", err)) + } + default: + } + + } + +} + func handleRequest(s *servlet, requestCodec protocol.ServerRequestCodec) { defer func() { s.stopWg.Done() @@ -141,34 +226,6 @@ func handleRequest(s *servlet, requestCodec protocol.ServerRequestCodec) { s.messageQueueChan <- ms } -func handleMessage(s *servlet) { - defer func() { - s.stopWg.Done() - }() - - for { - select { - case ms := <-s.messageQueueChan: - switch ms.messageType { - case protocol.MessageTypeResponse: - if err := s.sh.SendResponse(s.ctx, ms.res.requestCodec, s.writer, ms.res.result, ms.res.err); nil != err { - logging.Logger().Error(fmt.Sprintf("RPC Server: response message error %v", err)) - } - ms.res.requestCodec.Close() - case protocol.MessageTypeNotification: - if err := s.sh.SendNotification(s.ctx, s.serverCodec, s.writer, ms.noti.method, ms.noti.args...); nil != err { - logging.Logger().Error(fmt.Sprintf("RPC Server: response message error %v", err)) - } - default: - } - - releaseMessageState(ms) - case <-s.stopChan: - return - } - } -} - type messageState struct { messageType protocol.MessageType res messageResponse diff --git a/servlet_handler.go b/servlet_handler.go index d8f1d4a..c46c98b 100644 --- a/servlet_handler.go +++ b/servlet_handler.go @@ -10,10 +10,10 @@ type ServletHandler interface { Init(servletCTX ServletContext) error - GetRequest(servletCTX ServletContext, codec protocol.ServerCodec, reader interface{}) (protocol.ServerRequestCodec, error) + GetRequest(servletCTX ServletContext, codec protocol.ServerCodec, conn interface{}) (protocol.ServerRequestCodec, error) Invoke(servletCTX ServletContext, requestCodec protocol.RegistryCodec) (result interface{}, err error) - SendResponse(servletCTX ServletContext, requestCodec protocol.ServerRequestCodec, writer interface{}, result interface{}, err error) error - SendNotification(servletCTX ServletContext, codec protocol.ServerCodec, writer interface{}, method string, args ...interface{}) error + SendResponse(servletCTX ServletContext, conn interface{}, requestCodec protocol.ServerRequestCodec, result interface{}, err error) error + SendNotification(servletCTX ServletContext, conn interface{}, codec protocol.ServerCodec, method string, args ...interface{}) error Destroy(servletCTX ServletContext) diff --git a/servlet_handlers.go b/servlet_handlers.go index 57a8063..093ef9e 100644 --- a/servlet_handlers.go +++ b/servlet_handlers.go @@ -29,7 +29,7 @@ func (sh *ServletHandlers) Init(servletCTX ServletContext) error { return nil } -func (sh *ServletHandlers) GetRequest(servletCTX ServletContext, codec protocol.ServerCodec, reader interface{}) (protocol.ServerRequestCodec, error) { +func (sh *ServletHandlers) GetRequest(servletCTX ServletContext, codec protocol.ServerCodec, conn interface{}) (protocol.ServerRequestCodec, error) { return nil, fmt.Errorf("Servlet Handler: GetRequest is not implemented") } @@ -37,11 +37,11 @@ func (sh *ServletHandlers) Invoke(servletCTX ServletContext, requestCodec protoc return nil, fmt.Errorf("Servlet Handler: Invoke is not implemented") } -func (sh *ServletHandlers) SendResponse(servletCTX ServletContext, requestCodec protocol.ServerRequestCodec, writer interface{}, result interface{}, err error) error { +func (sh *ServletHandlers) SendResponse(servletCTX ServletContext, conn interface{}, requestCodec protocol.ServerRequestCodec, result interface{}, err error) error { return fmt.Errorf("Servlet Handler: SendResponse is not implemented") } -func (sh *ServletHandlers) SendNotification(servletCTX ServletContext, codec protocol.ServerCodec, writer interface{}, method string, args ...interface{}) error { +func (sh *ServletHandlers) SendNotification(servletCTX ServletContext, conn interface{}, codec protocol.ServerCodec, method string, args ...interface{}) error { return fmt.Errorf("Servlet Handler: SendNotification is not implemented") }