340 lines
7.3 KiB
Go
340 lines
7.3 KiB
Go
|
package client
|
||
|
|
||
|
import (
|
||
|
"fmt"
|
||
|
"io"
|
||
|
"runtime"
|
||
|
"sync"
|
||
|
"sync/atomic"
|
||
|
"time"
|
||
|
|
||
|
"git.loafle.net/commons_go/rpc/protocol"
|
||
|
)
|
||
|
|
||
|
func New(ch ClientHandler) Client {
|
||
|
c := &client{
|
||
|
ch: ch,
|
||
|
}
|
||
|
return c
|
||
|
}
|
||
|
|
||
|
// Client is the public RPC client API: lifecycle control plus notifications
// (no response expected) and calls (request/response).
type Client interface {
	// Start attaches the client to rwc and begins processing in the background.
	// The implementation returned by New panics if rwc is nil or if the client
	// is already started.
	Start(rwc io.ReadWriteCloser)

	// Stop shuts the client down and waits for its background work to finish.
	Stop()

	// Notify sends method with args without waiting for a response.
	Notify(method string, args interface{}) error

	// Call is CallTimeout using the handler's default request timeout.
	Call(method string, args interface{}, result interface{}) error

	// CallTimeout sends method with args and waits up to timeout for the
	// response.
	CallTimeout(method string, args interface{}, result interface{}, timeout time.Duration) error
}
|
||
|
|
||
|
type client struct {
|
||
|
ch ClientHandler
|
||
|
|
||
|
rwc io.ReadWriteCloser
|
||
|
|
||
|
pendingRequestsCount uint32
|
||
|
|
||
|
requestQueueChan chan *CallState
|
||
|
|
||
|
stopChan chan struct{}
|
||
|
stopWg sync.WaitGroup
|
||
|
}
|
||
|
|
||
|
func (c *client) Start(rwc io.ReadWriteCloser) {
|
||
|
c.ch.Validate()
|
||
|
|
||
|
if nil == rwc {
|
||
|
panic("RWC(io.ReadWriteCloser) must be specified.")
|
||
|
}
|
||
|
|
||
|
if c.stopChan != nil {
|
||
|
panic("Client: the given client is already started. Call Client.Stop() before calling Client.Start() again!")
|
||
|
}
|
||
|
|
||
|
c.rwc = rwc
|
||
|
c.stopChan = make(chan struct{})
|
||
|
c.requestQueueChan = make(chan *CallState, c.ch.GetPendingRequests())
|
||
|
|
||
|
c.stopWg.Add(1)
|
||
|
go handleRPC(c)
|
||
|
}
|
||
|
|
||
|
func (c *client) Stop() {
|
||
|
if c.stopChan == nil {
|
||
|
panic("Client: the client must be started before stopping it")
|
||
|
}
|
||
|
close(c.stopChan)
|
||
|
c.stopWg.Wait()
|
||
|
c.stopChan = nil
|
||
|
}
|
||
|
|
||
|
func (c *client) Notify(method string, args interface{}) error {
|
||
|
_, err := c.send(method, args, nil, false, true)
|
||
|
return err
|
||
|
}
|
||
|
|
||
|
func (c *client) Call(method string, args interface{}, result interface{}) error {
|
||
|
return c.CallTimeout(method, args, result, c.ch.GetRequestTimeout())
|
||
|
}
|
||
|
|
||
|
func (c *client) CallTimeout(method string, args interface{}, result interface{}, timeout time.Duration) (err error) {
|
||
|
var cs *CallState
|
||
|
if cs, err = c.send(method, args, result, true, true); nil != err {
|
||
|
return
|
||
|
}
|
||
|
|
||
|
t := retainTimer(timeout)
|
||
|
|
||
|
select {
|
||
|
case <-cs.DoneChan:
|
||
|
result, err = cs.Result, cs.Error
|
||
|
releaseCallState(cs)
|
||
|
case <-t.C:
|
||
|
cs.Cancel()
|
||
|
err = getClientTimeoutError(c, timeout)
|
||
|
}
|
||
|
|
||
|
releaseTimer(t)
|
||
|
|
||
|
return nil
|
||
|
}
|
||
|
|
||
|
func (c *client) send(method string, args interface{}, result interface{}, hasResponse bool, usePool bool) (cs *CallState, err error) {
|
||
|
if !hasResponse {
|
||
|
usePool = true
|
||
|
}
|
||
|
|
||
|
if usePool {
|
||
|
cs = retainCallState()
|
||
|
} else {
|
||
|
cs = &CallState{}
|
||
|
}
|
||
|
|
||
|
cs.Method = method
|
||
|
cs.Args = args
|
||
|
|
||
|
if hasResponse {
|
||
|
cs.ID = c.ch.GetRequestID()
|
||
|
cs.Result = result
|
||
|
cs.DoneChan = make(chan *CallState, 1)
|
||
|
}
|
||
|
|
||
|
select {
|
||
|
case c.requestQueueChan <- cs:
|
||
|
return cs, nil
|
||
|
default:
|
||
|
// Try substituting the oldest async request by the new one
|
||
|
// on requests' queue overflow.
|
||
|
// This increases the chances for new request to succeed
|
||
|
// without timeout.
|
||
|
if !hasResponse {
|
||
|
// Immediately notify the caller not interested
|
||
|
// in the response on requests' queue overflow, since
|
||
|
// there are no other ways to notify it later.
|
||
|
releaseCallState(cs)
|
||
|
return nil, getClientOverflowError(c)
|
||
|
}
|
||
|
|
||
|
select {
|
||
|
case rcs := <-c.requestQueueChan:
|
||
|
if rcs.DoneChan != nil {
|
||
|
rcs.Error = getClientOverflowError(c)
|
||
|
//close(rcs.DoneChan)
|
||
|
rcs.done()
|
||
|
} else {
|
||
|
releaseCallState(rcs)
|
||
|
}
|
||
|
default:
|
||
|
}
|
||
|
|
||
|
select {
|
||
|
case c.requestQueueChan <- cs:
|
||
|
return cs, nil
|
||
|
default:
|
||
|
// Release m even if usePool = true, since m wasn't exposed
|
||
|
// to the caller yet.
|
||
|
releaseCallState(cs)
|
||
|
return nil, getClientOverflowError(c)
|
||
|
}
|
||
|
}
|
||
|
}
|
||
|
|
||
|
func handleRPC(c *client) {
|
||
|
defer c.stopWg.Done()
|
||
|
|
||
|
stopChan := make(chan struct{})
|
||
|
|
||
|
pendingRequests := make(map[interface{}]*CallState)
|
||
|
var pendingRequestsLock sync.Mutex
|
||
|
|
||
|
writerDone := make(chan error, 1)
|
||
|
go rpcWriter(c, pendingRequests, &pendingRequestsLock, stopChan, writerDone)
|
||
|
|
||
|
readerDone := make(chan error, 1)
|
||
|
go rpcReader(c, pendingRequests, &pendingRequestsLock, readerDone)
|
||
|
|
||
|
var err error
|
||
|
|
||
|
select {
|
||
|
case err = <-writerDone:
|
||
|
close(stopChan)
|
||
|
c.rwc.Close()
|
||
|
<-readerDone
|
||
|
case err = <-readerDone:
|
||
|
close(stopChan)
|
||
|
c.rwc.Close()
|
||
|
<-writerDone
|
||
|
case <-c.stopChan:
|
||
|
close(stopChan)
|
||
|
c.rwc.Close()
|
||
|
<-readerDone
|
||
|
<-writerDone
|
||
|
}
|
||
|
|
||
|
if err != nil {
|
||
|
//c.LogError("%s", err)
|
||
|
err = &ClientError{
|
||
|
Connection: true,
|
||
|
err: err,
|
||
|
}
|
||
|
}
|
||
|
|
||
|
}
|
||
|
|
||
|
func rpcWriter(c *client, pendingRequests map[interface{}]*CallState, pendingRequestsLock *sync.Mutex, stopChan <-chan struct{}, writerDone chan<- error) {
|
||
|
var err error
|
||
|
defer func() {
|
||
|
writerDone <- err
|
||
|
}()
|
||
|
|
||
|
for {
|
||
|
var cs *CallState
|
||
|
|
||
|
select {
|
||
|
case cs = <-c.requestQueueChan:
|
||
|
default:
|
||
|
// Give the last chance for ready goroutines filling c.requestsChan :)
|
||
|
runtime.Gosched()
|
||
|
|
||
|
select {
|
||
|
case <-stopChan:
|
||
|
return
|
||
|
case cs = <-c.requestQueueChan:
|
||
|
}
|
||
|
}
|
||
|
|
||
|
if cs.IsCanceled() {
|
||
|
if nil != cs.DoneChan {
|
||
|
// cs.Error = ErrCanceled
|
||
|
// close(m.done)
|
||
|
cs.done()
|
||
|
} else {
|
||
|
releaseCallState(cs)
|
||
|
}
|
||
|
continue
|
||
|
}
|
||
|
|
||
|
if nil != cs.DoneChan {
|
||
|
pendingRequestsLock.Lock()
|
||
|
n := len(pendingRequests)
|
||
|
pendingRequests[cs.ID] = cs
|
||
|
pendingRequestsLock.Unlock()
|
||
|
atomic.AddUint32(&c.pendingRequestsCount, 1)
|
||
|
|
||
|
if n > 10*c.ch.GetPendingRequests() {
|
||
|
err = fmt.Errorf("Client: The server didn't return %d responses yet. Closing server connection in order to prevent client resource leaks", n)
|
||
|
return
|
||
|
}
|
||
|
}
|
||
|
|
||
|
if nil == cs.DoneChan {
|
||
|
releaseCallState(cs)
|
||
|
}
|
||
|
|
||
|
if err = c.ch.GetCodec().Write(c.rwc, cs.Method, cs.Args, cs.ID); nil != err {
|
||
|
err = fmt.Errorf("Client: Cannot send request to wire: [%s]", err)
|
||
|
return
|
||
|
}
|
||
|
}
|
||
|
}
|
||
|
|
||
|
func rpcReader(c *client, pendingRequests map[interface{}]*CallState, pendingRequestsLock *sync.Mutex, readerDone chan<- error) {
|
||
|
var err error
|
||
|
defer func() {
|
||
|
if r := recover(); r != nil {
|
||
|
if err == nil {
|
||
|
err = fmt.Errorf("Client: Panic when reading data from server: %v", r)
|
||
|
}
|
||
|
}
|
||
|
readerDone <- err
|
||
|
}()
|
||
|
|
||
|
for {
|
||
|
crn, err := c.ch.GetCodec().NewResponseOrNotify(c.rwc)
|
||
|
if nil != err {
|
||
|
err = fmt.Errorf("Client: Cannot decode response or notify: [%s]", err)
|
||
|
return
|
||
|
}
|
||
|
|
||
|
if crn.IsResponse() {
|
||
|
err = responseHandle(c, crn.GetResponse(), pendingRequests, pendingRequestsLock)
|
||
|
} else {
|
||
|
err = notifyHandle(c, crn.GetNotify())
|
||
|
}
|
||
|
if nil != err {
|
||
|
return
|
||
|
}
|
||
|
}
|
||
|
|
||
|
}
|
||
|
|
||
|
func responseHandle(c *client, codecResponse protocol.ClientCodecResponse, pendingRequests map[interface{}]*CallState, pendingRequestsLock *sync.Mutex) error {
|
||
|
pendingRequestsLock.Lock()
|
||
|
cs, ok := pendingRequests[codecResponse.ID()]
|
||
|
if ok {
|
||
|
delete(pendingRequests, codecResponse.ID())
|
||
|
}
|
||
|
pendingRequestsLock.Unlock()
|
||
|
|
||
|
if !ok {
|
||
|
return fmt.Errorf("Client: Unexpected ID=[%v] obtained from server", codecResponse.ID())
|
||
|
}
|
||
|
|
||
|
atomic.AddUint32(&c.pendingRequestsCount, ^uint32(0))
|
||
|
|
||
|
cs.Result = codecResponse.Result()
|
||
|
if err := codecResponse.Error(); nil != err {
|
||
|
// cs.Error = &ClientError{
|
||
|
// Server: true,
|
||
|
// err: fmt.Errorf("gorpc.Client: [%s]. Server error: [%s]", c.Addr, wr.Error),
|
||
|
// }
|
||
|
}
|
||
|
|
||
|
cs.done()
|
||
|
|
||
|
return nil
|
||
|
}
|
||
|
|
||
|
func notifyHandle(c *client, codecNotify protocol.ClientCodecNotify) error {
|
||
|
_, err := c.ch.GetRPCRegistry().Invoke(codecNotify)
|
||
|
|
||
|
return err
|
||
|
}
|
||
|
|
||
|
func getClientTimeoutError(c *client, timeout time.Duration) error {
|
||
|
err := fmt.Errorf("Client: Cannot obtain response during timeout=%s", timeout)
|
||
|
//c.LogError("%s", err)
|
||
|
return &ClientError{
|
||
|
Timeout: true,
|
||
|
err: err,
|
||
|
}
|
||
|
}
|
||
|
|
||
|
func getClientOverflowError(c *client) error {
|
||
|
err := fmt.Errorf("Client: Requests' queue with size=%d is overflown. Try increasing Client.PendingRequests value", cap(c.requestQueueChan))
|
||
|
//c.LogError("%s", err)
|
||
|
return &ClientError{
|
||
|
Overflow: true,
|
||
|
err: err,
|
||
|
}
|
||
|
}
|