diff --git a/client/client.go b/client/client.go new file mode 100644 index 0000000..c826db6 --- /dev/null +++ b/client/client.go @@ -0,0 +1,305 @@ +package client + +import ( + "context" + "fmt" + "reflect" + "sync" + "time" + + logging "git.loafle.net/commons/logging-go" + "git.loafle.net/commons/rpc-go/protocol" + "git.loafle.net/commons/rpc-go/registry" +) + +var uint64Type = reflect.TypeOf(uint64(0)) + +type Client struct { + Codec protocol.ClientCodec + RPCInvoker registry.RPCInvoker + + stopChan chan struct{} + stopWg sync.WaitGroup + + requestID uint64 + requestQueueChan chan *requestState + + pendingRequests sync.Map +} + +func (c *Client) Start(readChan <-chan []byte, writeChan chan<- []byte) error { + if c.stopChan != nil { + return fmt.Errorf("Client: already running. Stop it before starting it again") + } + + if nil == c.Codec { + return fmt.Errorf("Client: Codec is not valid") + } + + c.stopChan = make(chan struct{}) + + c.stopWg.Add(1) + go c.handleClient(readChan, writeChan) + + return nil +} + +func (c *Client) Stop(ctx context.Context) error { + if c.stopChan == nil { + return fmt.Errorf("Client: must be started before stopping it") + } + close(c.stopChan) + c.stopWg.Wait() + c.stopChan = nil + + return nil +} + +func (c *Client) Send(method string, params ...interface{}) error { + rs, err := c.internalSend(false, nil, method, params...) + if nil != err { + return err + } + + defer releaseRequestState(rs) + select { + case <-rs.doneChan: + if nil != rs.clientError { + return rs.clientError + } + } + return nil +} + +func (c *Client) Call(result interface{}, method string, params ...interface{}) error { + return c.CallTimeout(10, result, method, params...) +} + +func (c *Client) CallTimeout(timeout time.Duration, result interface{}, method string, params ...interface{}) error { + rs, err := c.internalSend(true, result, method, params...) + if nil != err { + return err + } + + t := retainTimer(timeout) + defer func() { + releaseRequestState(rs) + releaseTimer(t) + }() + + select { + case <-rs.doneChan: + result = rs.result + return rs.clientError + case <-t.C: + rs.cancel() + return newError(method, params, fmt.Errorf("Timeout")) + } +} + +func (c *Client) getRequestID() uint64 { + c.requestID++ + return c.requestID +} + +func (c *Client) internalSend(hasResponse bool, result interface{}, method string, params ...interface{}) (*requestState, error) { + rs := retainRequestState() + + rs.method = method + rs.params = params + rs.doneChan = make(chan *requestState, 1) + if hasResponse { + rs.id = c.getRequestID() + rs.result = result + } + + select { + case c.requestQueueChan <- rs: + return rs, nil + default: + if !hasResponse { + releaseRequestState(rs) + return nil, newError(method, params, fmt.Errorf("Request Queue overflow")) + } + select { + case oldRS := <-c.requestQueueChan: + if nil != oldRS.doneChan { + oldRS.setError(fmt.Errorf("Request Queue overflow")) + oldRS.done() + } else { + releaseRequestState(oldRS) + } + default: + } + select { + case c.requestQueueChan <- rs: + return rs, nil + default: + releaseRequestState(rs) + return nil, newError(method, params, fmt.Errorf("Request Queue overflow")) + } + } +} + +func (c *Client) handleClient(readChan <-chan []byte, writeChan chan<- []byte) { + defer func() { + c.stopWg.Done() + }() + + stopChan := make(chan struct{}) + sendDoneChan := make(chan error) + receiveDoneChan := make(chan error) + + go c.handleSend(stopChan, sendDoneChan, writeChan) + go c.handleReceive(stopChan, receiveDoneChan, readChan) + + select { + case <-sendDoneChan: + close(stopChan) + <-receiveDoneChan + case <-receiveDoneChan: + close(stopChan) + <-sendDoneChan + case <-c.stopChan: + close(stopChan) + <-sendDoneChan + <-receiveDoneChan + } +} + +func (c *Client) handleSend(stopChan <-chan struct{}, doneChan chan<- error, writeChan chan<- []byte) { + var ( + rs *requestState + id interface{} + message []byte + err error + ok bool + ) + + defer func() { + doneChan <- err + }() + +LOOP: + for { + select { + case rs, ok = <-c.requestQueueChan: + if !ok { + return + } + if rs.isCanceled() { + if nil != rs.doneChan { + rs.done() + } else { + releaseRequestState(rs) + } + continue LOOP + } + + id = nil + if 0 < rs.id { + id = rs.id + } + message, err = c.Codec.NewRequest(rs.method, rs.params, id) + if nil != err { + rs.setError(err) + rs.done() + continue LOOP + } + + select { + case writeChan <- message: + default: + rs.setError(fmt.Errorf("Client: cannot send request")) + rs.done() + continue LOOP + } + + if 0 < rs.id { + c.pendingRequests.Store(rs.id, rs) + } + case <-c.stopChan: + return + } + } +} + +func (c *Client) handleReceive(stopChan <-chan struct{}, doneChan chan<- error, readChan <-chan []byte) { + var ( + message []byte + err error + ok bool + ) + + defer func() { + doneChan <- err + }() + +LOOP: + for { + select { + case message, ok = <-readChan: + if !ok { + return + } + resCodec, err := c.Codec.NewResponse(message) + if nil != err { + continue LOOP + } + if nil == resCodec.ID() { + // notification + notiCodec, err := resCodec.Notification() + if nil != err { + logging.Logger().Warnf("Client: notification error %v", err) + continue LOOP + } + + c.stopWg.Add(1) + go c.handleNotification(notiCodec) + } else { + // response + c.stopWg.Add(1) + go c.handleResponse(resCodec) + } + case <-stopChan: + return + } + } + +} + +func (c *Client) handleResponse(resCodec protocol.ClientResponseCodec) { + defer func() { + c.stopWg.Done() + }() + id := reflect.ValueOf(resCodec.ID()).Convert(uint64Type).Uint() + _rs, ok := c.pendingRequests.Load(id) + if !ok { + logging.Logger().Warnf("Client: unexpected ID=[%d] obtained from server", id) + return + } + rs := _rs.(*requestState) + rs.setError(resCodec.Error()) + err := resCodec.Result(rs.result) + if nil != err { + rs.setError(err) + } + + rs.done() +} + +func (c *Client) handleNotification(notiCodec protocol.ClientNotificationCodec) { + defer func() { + c.stopWg.Done() + }() + + if nil == c.RPCInvoker { + logging.Logger().Warnf("Client: received notification method[%s] but RPC Invoker is not exist", notiCodec.Method()) + return + } + + _, err := c.RPCInvoker.Invoke(notiCodec) + if nil != err { + logging.Logger().Errorf("Client: invoking of notification method[%s] has been failed %v", notiCodec.Method(), err) + } +} diff --git a/client/error.go b/client/error.go new file mode 100644 index 0000000..ec9ad8e --- /dev/null +++ b/client/error.go @@ -0,0 +1,20 @@ +package client + +func newError(method string, params []interface{}, err error) *Error { + return &Error{ + Method: method, + Params: params, + Err: err, + } +} + +type Error struct { + Method string + Params []interface{} + + Err error +} + +func (e *Error) Error() string { + return e.Err.Error() +} diff --git a/client/request-state.go b/client/request-state.go new file mode 100644 index 0000000..fe4b034 --- /dev/null +++ b/client/request-state.go @@ -0,0 +1,93 @@ +package client + +import ( + "sync" + "sync/atomic" + "time" +) + +type requestState struct { + id uint64 + method string + params []interface{} + result interface{} + clientError *Error + doneChan chan *requestState + + canceled atomic.Value +} + +func (rs *requestState) done() { + select { + case rs.doneChan <- rs: + default: + } +} + +func (rs *requestState) cancel() { + rs.canceled.Store(true) +} + +func (rs *requestState) isCanceled() bool { + v := rs.canceled.Load() + if nil == v { + return false + } + vv := v.(bool) + + return vv +} + +func (rs *requestState) setError(err error) { + rs.clientError = newError(rs.method, rs.params, err) +} + +var requestStatePool sync.Pool + +func retainRequestState() *requestState { + v := requestStatePool.Get() + if v == nil { + return &requestState{} + } + return v.(*requestState) +} + +func releaseRequestState(rs *requestState) { + rs.id = 0 + rs.method = "" + rs.params = nil + rs.result = nil + rs.clientError = nil + rs.doneChan = nil + rs.canceled.Store(false) + + requestStatePool.Put(rs) +} + +var timerPool sync.Pool + +func retainTimer(timeout time.Duration) *time.Timer { + tv := timerPool.Get() + if tv == nil { + return time.NewTimer(timeout) + } + + t := tv.(*time.Timer) + if t.Reset(timeout) { + panic("Client: Active timer trapped into retainTimer()") + } + return t +} + +func releaseTimer(t *time.Timer) { + if !t.Stop() { + // Collect possibly added time from the channel + // if timer has been stopped and nobody collected its' value. + select { + case <-t.C: + default: + } + } + + timerPool.Put(t) +}