From b4b379c8918640c77553fcea067c63963a779c73 Mon Sep 17 00:00:00 2001 From: crusader Date: Wed, 1 Nov 2017 12:10:39 +0900 Subject: [PATCH] ing --- client/client.go | 84 +++++++++++++++++++-------------------- client/client_handler.go | 7 ++-- client/client_handlers.go | 20 +++------- 3 files changed, 49 insertions(+), 62 deletions(-) diff --git a/client/client.go b/client/client.go index 4d4d38c..e793514 100644 --- a/client/client.go +++ b/client/client.go @@ -3,6 +3,7 @@ package client import ( "fmt" "io" + "log" "runtime" "sync" "sync/atomic" @@ -19,8 +20,9 @@ func New(ch ClientHandler) Client { } type Client interface { - Start(rwc io.ReadWriteCloser) - Stop() + Connect() error + Close() + Notify(method string, args interface{}) error Call(method string, args interface{}, result interface{}) error CallTimeout(method string, args interface{}, result interface{}, timeout time.Duration) error @@ -32,6 +34,8 @@ type client struct { rwc io.ReadWriteCloser pendingRequestsCount uint32 + pendingRequests map[interface{}]*CallState + pendingRequestsLock sync.Mutex requestQueueChan chan *CallState @@ -39,28 +43,27 @@ type client struct { stopWg sync.WaitGroup } -func (c *client) Start(rwc io.ReadWriteCloser) { +func (c *client) Connect() error { + var err error c.ch.Validate() - if nil == rwc { - panic("RWC(io.ReadWriteCloser) must be specified.") - } - if c.stopChan != nil { - panic("Client: the given client is already started. Call Client.Stop() before calling Client.Start() again!") + panic("RPC Client: the given client is already started. Call Client.Stop() before calling Client.Start() again!") } - c.rwc = rwc + if c.rwc, err = c.ch.Connect(); nil != err { + return err + } c.stopChan = make(chan struct{}) c.requestQueueChan = make(chan *CallState, c.ch.GetPendingRequests()) + c.pendingRequests = make(map[interface{}]*CallState) - c.ch.OnStart() + go c.handleRPC() - c.stopWg.Add(1) - go handleRPC(c) + return nil } -func (c *client) Stop() { +func (c *client) Close() { if c.stopChan == nil { panic("Client: the client must be started before stopping it") } @@ -160,47 +163,42 @@ func (c *client) send(method string, args interface{}, result interface{}, hasRe } } -func handleRPC(c *client) { - defer c.stopWg.Done() - - stopChan := make(chan struct{}) - - pendingRequests := make(map[interface{}]*CallState) - var pendingRequestsLock sync.Mutex +func (c *client) handleRPC() { + subStopChan := make(chan struct{}) writerDone := make(chan error, 1) - go rpcWriter(c, pendingRequests, &pendingRequestsLock, stopChan, writerDone) + go c.rpcWriter(subStopChan, writerDone) readerDone := make(chan error, 1) - go rpcReader(c, pendingRequests, &pendingRequestsLock, readerDone) + go c.rpcReader(readerDone) var err error select { case err = <-writerDone: - close(stopChan) - c.rwc.Close() + close(subStopChan) <-readerDone case err = <-readerDone: - close(stopChan) - c.rwc.Close() + close(subStopChan) <-writerDone case <-c.stopChan: - close(stopChan) - c.rwc.Close() + close(subStopChan) <-readerDone <-writerDone } + c.rwc.Close() + if err != nil { //c.LogError("%s", err) + log.Printf("handleRPC: %v", err) err = &ClientError{ Connection: true, err: err, } } - for _, cs := range pendingRequests { + for _, cs := range c.pendingRequests { atomic.AddUint32(&c.pendingRequestsCount, ^uint32(0)) cs.Error = err if cs.DoneChan != nil { @@ -210,7 +208,7 @@ func handleRPC(c *client) { } -func rpcWriter(c *client, pendingRequests map[interface{}]*CallState, pendingRequestsLock *sync.Mutex, stopChan <-chan struct{}, writerDone chan<- error) { +func (c *client) rpcWriter(stopChan <-chan struct{}, writerDone chan<- error) { var err error defer func() { writerDone <- err @@ -244,10 +242,10 @@ func rpcWriter(c *client, pendingRequests map[interface{}]*CallState, pendingReq } if nil != cs.DoneChan { - pendingRequestsLock.Lock() - n := len(pendingRequests) - pendingRequests[cs.ID] = cs - pendingRequestsLock.Unlock() + c.pendingRequestsLock.Lock() + n := len(c.pendingRequests) + c.pendingRequests[cs.ID] = cs + c.pendingRequestsLock.Unlock() atomic.AddUint32(&c.pendingRequestsCount, 1) if n > 10*c.ch.GetPendingRequests() { @@ -267,7 +265,7 @@ func rpcWriter(c *client, pendingRequests map[interface{}]*CallState, pendingReq } } -func rpcReader(c *client, pendingRequests map[interface{}]*CallState, pendingRequestsLock *sync.Mutex, readerDone chan<- error) { +func (c *client) rpcReader(readerDone chan<- error) { var err error defer func() { if r := recover(); r != nil { @@ -286,9 +284,9 @@ func rpcReader(c *client, pendingRequests map[interface{}]*CallState, pendingReq } if crn.IsResponse() { - err = responseHandle(c, crn.GetResponse(), pendingRequests, pendingRequestsLock) + err = c.responseHandle(crn.GetResponse()) } else { - err = notifyHandle(c, crn.GetNotify()) + err = c.notifyHandle(crn.GetNotify()) } if nil != err { return @@ -297,13 +295,13 @@ func rpcReader(c *client, pendingRequests map[interface{}]*CallState, pendingReq } -func responseHandle(c *client, codecResponse protocol.ClientCodecResponse, pendingRequests map[interface{}]*CallState, pendingRequestsLock *sync.Mutex) error { - pendingRequestsLock.Lock() - cs, ok := pendingRequests[codecResponse.ID()] +func (c *client) responseHandle(codecResponse protocol.ClientCodecResponse) error { + c.pendingRequestsLock.Lock() + cs, ok := c.pendingRequests[codecResponse.ID()] if ok { - delete(pendingRequests, codecResponse.ID()) + delete(c.pendingRequests, codecResponse.ID()) } - pendingRequestsLock.Unlock() + c.pendingRequestsLock.Unlock() if !ok { return fmt.Errorf("Client: Unexpected ID=[%v] obtained from server", codecResponse.ID()) @@ -324,7 +322,7 @@ func responseHandle(c *client, codecResponse protocol.ClientCodecResponse, pendi return nil } -func notifyHandle(c *client, codecNotify protocol.ClientCodecNotify) error { +func (c *client) notifyHandle(codecNotify protocol.ClientCodecNotify) error { _, err := c.ch.GetRPCRegistry().Invoke(codecNotify) return err diff --git a/client/client_handler.go b/client/client_handler.go index 1ebc2da..d510812 100644 --- a/client/client_handler.go +++ b/client/client_handler.go @@ -1,6 +1,7 @@ package client import ( + "io" "time" "git.loafle.net/commons_go/rpc" @@ -8,11 +9,9 @@ import ( ) type ClientHandler interface { - OnStart() - OnStop() - - GetContentType() string + Connect() (io.ReadWriteCloser, error) GetCodec() protocol.ClientCodec + GetRPCRegistry() rpc.Registry GetRequestTimeout() time.Duration GetPendingRequests() int diff --git a/client/client_handlers.go b/client/client_handlers.go index e6dea91..cf46657 100644 --- a/client/client_handlers.go +++ b/client/client_handlers.go @@ -1,6 +1,8 @@ package client import ( + "errors" + "io" "sync" "time" @@ -9,8 +11,7 @@ import ( ) type ClientHandlers struct { - ContentType string - Codec protocol.ClientCodec + Codec protocol.ClientCodec // Maximum request time. // Default value is DefaultRequestTimeout. RequestTimeout time.Duration @@ -29,16 +30,8 @@ type ClientHandlers struct { requestIDMtx sync.Mutex } -func (ch *ClientHandlers) OnStart() { - // no op -} - -func (ch *ClientHandlers) OnStop() { - // no op -} - -func (ch *ClientHandlers) GetContentType() string { - return ch.ContentType +func (ch *ClientHandlers) Connect() (io.ReadWriteCloser, error) { + return nil, errors.New("RPC Client: ClientHandlers method[Connect] is not implement") } func (ch *ClientHandlers) GetCodec() protocol.ClientCodec { @@ -68,9 +61,6 @@ func (ch *ClientHandlers) GetRequestID() interface{} { } func (ch *ClientHandlers) Validate() { - if "" == ch.ContentType { - panic("ContentType must be specified.") - } if ch.RequestTimeout <= 0 { ch.RequestTimeout = DefaultRequestTimeout }