diff --git a/client/client.go b/client/client.go index 2f323a8..56fe72d 100644 --- a/client/client.go +++ b/client/client.go @@ -222,7 +222,9 @@ func (c *client) handleRPC() { <-writerDone } - c.rwcHandler.Disconnect(c.ctx, c.conn) + if nil != c.conn { + c.rwcHandler.Disconnect(c.ctx, c.conn) + } if err != nil { //c.LogError("%s", err) @@ -295,6 +297,11 @@ func (c *client) rpcWriter(stopChan <-chan struct{}, writerDone chan<- error) { requestID = rs.ID } + if nil == c.conn { + err = io.EOF + return + } + err = c.rwcHandler.WriteRequest(c.ctx, c.ch.GetCodec(), c.conn, rs.Method, rs.Args, requestID) if !rs.hasResponse { rs.Error = err @@ -325,6 +332,10 @@ func (c *client) rpcReader(readerDone chan<- error) { }() for { + if nil == c.conn { + err = io.EOF + return + } resCodec, err := c.rwcHandler.ReadResponse(c.ctx, c.ch.GetCodec(), c.conn) if nil != err { if err == io.ErrUnexpectedEOF || err == io.EOF { diff --git a/client/rwc/socket/client_rwc_handlers.go b/client/rwc/socket/client_rwc_handlers.go index f8d3dd1..5944db5 100644 --- a/client/rwc/socket/client_rwc_handlers.go +++ b/client/rwc/socket/client_rwc_handlers.go @@ -1,18 +1,13 @@ package socket import ( - "io" - - "git.loafle.net/commons_go/logging" "git.loafle.net/commons_go/rpc/client" "git.loafle.net/commons_go/rpc/protocol" "git.loafle.net/commons_go/server" ) -func New(address string) client.ClientReadWriteCloseHandler { - return &ClientReadWriteCloseHandlers{ - Address: address, - } +func New() client.ClientReadWriteCloseHandler { + return &ClientReadWriteCloseHandlers{} } type ClientReadWriteCloseHandlers struct { @@ -20,10 +15,6 @@ type ClientReadWriteCloseHandlers struct { } func (crwch *ClientReadWriteCloseHandlers) ReadResponse(clientCTX client.ClientContext, codec protocol.ClientCodec, conn interface{}) (protocol.ClientResponseCodec, error) { - if nil == conn { - return nil, io.EOF - } - soc := conn.(server.Socket) resCodec, err := codec.NewResponse(soc) @@ -31,13 +22,9 @@ func (crwch *ClientReadWriteCloseHandlers) ReadResponse(clientCTX client.ClientC } func (crwch *ClientReadWriteCloseHandlers) WriteRequest(clientCTX client.ClientContext, codec protocol.ClientCodec, conn interface{}, method string, params interface{}, id interface{}) error { - if nil == conn { - return io.EOF - } - soc := conn.(server.Socket) - if wErr := codec.WriteRequest(soc, method, params); nil != wErr { + if wErr := codec.WriteRequest(soc, method, params, id); nil != wErr { return wErr } @@ -45,16 +32,10 @@ func (crwch *ClientReadWriteCloseHandlers) WriteRequest(clientCTX client.ClientC } func (crwch *ClientReadWriteCloseHandlers) Disconnect(clientCTX client.ClientContext, conn interface{}) { - if nil == conn { - return - } - soc := conn.(server.Socket) soc.Close() } func (crwch *ClientReadWriteCloseHandlers) Validate() { - if "" == crwch.Address { - logging.Logger().Panic("RPC Client RWC Handler: Address must be specified") - } + crwch.ClientReadWriteCloseHandlers.Validate() } diff --git a/client/rwc/websocket/fasthttp/client_rwc_handlers.go b/client/rwc/websocket/fasthttp/client_rwc_handlers.go index dade377..8deae97 100644 --- a/client/rwc/websocket/fasthttp/client_rwc_handlers.go +++ b/client/rwc/websocket/fasthttp/client_rwc_handlers.go @@ -1,8 +1,6 @@ package fasthttp import ( - "io" - "github.com/gorilla/websocket" "git.loafle.net/commons_go/rpc/client" @@ -19,10 +17,6 @@ type ClientReadWriteCloseHandlers struct { } func (crwch *ClientReadWriteCloseHandlers) ReadResponse(clientCTX client.ClientContext, codec protocol.ClientCodec, conn interface{}) (protocol.ClientResponseCodec, error) { - if nil == conn { - return nil, io.EOF - } - soc := conn.(cwf.Socket) _, r, err := soc.NextReader() @@ -32,10 +26,6 @@ func (crwch *ClientReadWriteCloseHandlers) ReadResponse(clientCTX client.ClientC } func (crwch *ClientReadWriteCloseHandlers) WriteRequest(clientCTX client.ClientContext, codec protocol.ClientCodec, conn interface{}, method string, params interface{}, id interface{}) error { - if nil == conn { - return io.EOF - } - soc := conn.(cwf.Socket) wc, wErr := soc.NextWriter(websocket.TextMessage) @@ -46,7 +36,7 @@ func (crwch *ClientReadWriteCloseHandlers) WriteRequest(clientCTX client.ClientC wc.Close() }() - if wErr := codec.WriteRequest(wc, method, params); nil != wErr { + if wErr := codec.WriteRequest(wc, method, params, id); nil != wErr { return wErr } @@ -54,14 +44,10 @@ func (crwch *ClientReadWriteCloseHandlers) WriteRequest(clientCTX client.ClientC } func (crwch *ClientReadWriteCloseHandlers) Disconnect(clientCTX client.ClientContext, conn interface{}) { - if nil == conn { - return - } - soc := conn.(cwf.Socket) soc.Close() } func (crwch *ClientReadWriteCloseHandlers) Validate() { - + crwch.ClientReadWriteCloseHandlers.Validate() } diff --git a/server/rwc/socket/servlet_rwc_handlers.go b/server/rwc/socket/servlet_rwc_handlers.go index 14724e5..d37e3f3 100644 --- a/server/rwc/socket/servlet_rwc_handlers.go +++ b/server/rwc/socket/servlet_rwc_handlers.go @@ -1,8 +1,6 @@ package socket import ( - "io" - "git.loafle.net/commons_go/rpc" "git.loafle.net/commons_go/rpc/protocol" "git.loafle.net/commons_go/server" @@ -17,10 +15,6 @@ type ServletReadWriteCloseHandlers struct { } func (srwch *ServletReadWriteCloseHandlers) ReadRequest(servletCTX rpc.ServletContext, codec protocol.ServerCodec, conn interface{}) (protocol.ServerRequestCodec, error) { - if nil == conn { - return nil, io.EOF - } - soc := conn.(server.Socket) reqCodec, err := codec.NewRequest(soc) @@ -28,10 +22,6 @@ func (srwch *ServletReadWriteCloseHandlers) ReadRequest(servletCTX rpc.ServletCo } func (srwch *ServletReadWriteCloseHandlers) WriteResponse(servletCTX rpc.ServletContext, conn interface{}, reqCodec protocol.ServerRequestCodec, result interface{}, err error) error { - if nil == conn { - return io.EOF - } - soc := conn.(server.Socket) if nil != err { @@ -48,10 +38,6 @@ func (srwch *ServletReadWriteCloseHandlers) WriteResponse(servletCTX rpc.Servlet } func (srwch *ServletReadWriteCloseHandlers) WriteNotification(servletCTX rpc.ServletContext, conn interface{}, codec protocol.ServerCodec, method string, params interface{}) error { - if nil == conn { - return io.EOF - } - soc := conn.(server.Socket) if wErr := codec.WriteNotification(soc, method, params); nil != wErr { @@ -60,3 +46,7 @@ func (srwch *ServletReadWriteCloseHandlers) WriteNotification(servletCTX rpc.Ser return nil } + +func (srwch *ServletReadWriteCloseHandlers) Validate() { + srwch.ServletReadWriteCloseHandlers.Validate() +} diff --git a/server/rwc/websocket/fasthttp/servlet_rwc_handlers.go b/server/rwc/websocket/fasthttp/servlet_rwc_handlers.go index cc60128..a44cc07 100644 --- a/server/rwc/websocket/fasthttp/servlet_rwc_handlers.go +++ b/server/rwc/websocket/fasthttp/servlet_rwc_handlers.go @@ -1,8 +1,6 @@ package fasthttp import ( - "io" - "git.loafle.net/commons_go/rpc" "git.loafle.net/commons_go/rpc/protocol" cwf "git.loafle.net/commons_go/websocket_fasthttp" @@ -18,10 +16,6 @@ type ServletReadWriteCloseHandlers struct { } func (srwch *ServletReadWriteCloseHandlers) ReadRequest(servletCTX rpc.ServletContext, codec protocol.ServerCodec, conn interface{}) (protocol.ServerRequestCodec, error) { - if nil == conn { - return nil, io.EOF - } - soc := conn.(cwf.Socket) _, r, err := soc.NextReader() @@ -31,10 +25,6 @@ func (srwch *ServletReadWriteCloseHandlers) ReadRequest(servletCTX rpc.ServletCo } func (srwch *ServletReadWriteCloseHandlers) WriteResponse(servletCTX rpc.ServletContext, conn interface{}, requestCodec protocol.ServerRequestCodec, result interface{}, err error) error { - if nil == conn { - return io.EOF - } - soc := conn.(cwf.Socket) wc, wErr := soc.NextWriter(websocket.TextMessage) @@ -60,10 +50,6 @@ func (srwch *ServletReadWriteCloseHandlers) WriteResponse(servletCTX rpc.Servlet } func (srwch *ServletReadWriteCloseHandlers) WriteNotification(servletCTX rpc.ServletContext, conn interface{}, codec protocol.ServerCodec, method string, params interface{}) error { - if nil == conn { - return io.EOF - } - soc := conn.(cwf.Socket) wc, wErr := soc.NextWriter(websocket.TextMessage) @@ -77,3 +63,7 @@ func (srwch *ServletReadWriteCloseHandlers) WriteNotification(servletCTX rpc.Ser return nil } + +func (srwch *ServletReadWriteCloseHandlers) Validate() { + srwch.ServletReadWriteCloseHandlers.Validate() +} diff --git a/servlet.go b/servlet.go index c4fa917..71fcebc 100644 --- a/servlet.go +++ b/servlet.go @@ -163,6 +163,11 @@ func handleReader(s *rpcServlet, stopChan chan struct{}, doneChan chan error) { }() for { + if nil == s.conn { + err = fmt.Errorf("RPC Server: disconnected from client") + return + } + requestCodec, err := s.rwcSH.ReadRequest(s.ctx, s.serverCodec, s.conn) if nil != err { if err == io.ErrUnexpectedEOF || err == io.EOF { @@ -213,6 +218,11 @@ func handleWriter(s *rpcServlet, stopChan chan struct{}, doneChan chan error) { } } + if nil == s.conn { + err = fmt.Errorf("RPC Server: disconnected from client") + return + } + if nil != rs.requestCodec { if err := s.rwcSH.WriteResponse(s.ctx, s.conn, rs.requestCodec, rs.result, rs.err); nil != err { logging.Logger().Error(fmt.Sprintf("RPC Server: response error %v", err))