440 lines
9.3 KiB
Go
440 lines
9.3 KiB
Go
package client
|
|
|
|
import (
|
|
"fmt"
|
|
"io"
|
|
"log"
|
|
"reflect"
|
|
"runtime"
|
|
"sync"
|
|
"sync/atomic"
|
|
"time"
|
|
|
|
"git.loafle.net/commons_go/logging"
|
|
"git.loafle.net/commons_go/rpc/protocol"
|
|
)
|
|
|
|
func New(ch ClientHandler, rwcHandler ClientReadWriteCloseHandler) Client {
|
|
c := &client{
|
|
ch: ch,
|
|
rwcHandler: rwcHandler,
|
|
}
|
|
return c
|
|
}
|
|
|
|
// Client is an RPC client bound to a single connection.
type Client interface {
	// Connect validates the configured handlers, opens the connection and
	// starts the background goroutine that serves requests and responses.
	Connect() error
	// Close stops the background processing started by Connect.
	Close()

	// Send issues a request that expects no response. In the default
	// implementation it returns once the request has been written (or has
	// failed), not when the server has processed it.
	Send(method string, args ...interface{}) (err error)
	// Call is CallTimeout using the handler's default request timeout.
	Call(result interface{}, method string, args ...interface{}) error
	// CallTimeout issues a request and waits at most timeout for its
	// response, which is decoded into result.
	CallTimeout(timeout time.Duration, result interface{}, method string, args ...interface{}) (err error)
}
|
|
|
|
type client struct {
|
|
ctx ClientContext
|
|
ch ClientHandler
|
|
rwcHandler ClientReadWriteCloseHandler
|
|
|
|
conn interface{}
|
|
|
|
pendingRequestsCount uint32
|
|
pendingRequests map[uint64]*RequestState
|
|
pendingRequestsLock sync.Mutex
|
|
|
|
requestQueueChan chan *RequestState
|
|
|
|
stopChan chan struct{}
|
|
stopWg sync.WaitGroup
|
|
|
|
requestMtx sync.Mutex
|
|
responseMtx sync.Mutex
|
|
}
|
|
|
|
func (c *client) Connect() error {
|
|
var err error
|
|
|
|
if nil == c.ch {
|
|
return fmt.Errorf("RPC Client: Client handler must be specified")
|
|
}
|
|
c.ch.Validate()
|
|
|
|
if nil == c.rwcHandler {
|
|
return fmt.Errorf("RPC Client: Client RWC handler must be specified")
|
|
}
|
|
c.rwcHandler.Validate()
|
|
|
|
if c.stopChan != nil {
|
|
return fmt.Errorf("RPC Client: the given client is already started. Call Client.Stop() before calling Client.Start() again")
|
|
}
|
|
c.ctx = c.ch.ClientContext(nil)
|
|
|
|
if err := c.ch.Init(c.ctx); nil != err {
|
|
return fmt.Errorf("RPC Client: Initialization of client has been failed %v", err)
|
|
}
|
|
|
|
if c.conn, err = c.rwcHandler.Connect(c.ctx); nil != err {
|
|
return err
|
|
}
|
|
|
|
c.stopChan = make(chan struct{})
|
|
c.requestQueueChan = make(chan *RequestState, c.ch.GetPendingRequests())
|
|
c.pendingRequests = make(map[uint64]*RequestState)
|
|
|
|
go c.handleRPC()
|
|
|
|
return nil
|
|
}
|
|
|
|
func (c *client) Close() {
|
|
if c.stopChan == nil {
|
|
logging.Logger().Warnf("RPC Client: the client must be started before stopping it")
|
|
return
|
|
}
|
|
|
|
c.ch.Destroy(c.ctx)
|
|
|
|
close(c.stopChan)
|
|
c.stopWg.Wait()
|
|
c.stopChan = nil
|
|
|
|
logging.Logger().Infof("RPC Client: stopped")
|
|
}
|
|
|
|
func (c *client) Send(method string, args ...interface{}) (err error) {
|
|
var rs *RequestState
|
|
if rs, err = c.send(true, false, nil, method, args...); nil != err {
|
|
return
|
|
}
|
|
|
|
select {
|
|
case <-rs.DoneChan:
|
|
err = rs.Error
|
|
releaseCallState(rs)
|
|
}
|
|
|
|
return
|
|
}
|
|
|
|
func (c *client) Call(result interface{}, method string, args ...interface{}) error {
|
|
return c.CallTimeout(c.ch.GetRequestTimeout(), result, method, args...)
|
|
}
|
|
|
|
func (c *client) CallTimeout(timeout time.Duration, result interface{}, method string, args ...interface{}) (err error) {
|
|
var rs *RequestState
|
|
if rs, err = c.send(true, true, result, method, args...); nil != err {
|
|
return
|
|
}
|
|
|
|
t := retainTimer(timeout)
|
|
|
|
select {
|
|
case <-rs.DoneChan:
|
|
result, err = rs.Result, rs.Error
|
|
releaseCallState(rs)
|
|
case <-t.C:
|
|
rs.Cancel()
|
|
err = getClientTimeoutError(c, timeout)
|
|
}
|
|
|
|
releaseTimer(t)
|
|
|
|
return
|
|
}
|
|
|
|
func (c *client) send(usePool bool, hasResponse bool, result interface{}, method string, args ...interface{}) (rs *RequestState, err error) {
|
|
if !hasResponse {
|
|
usePool = true
|
|
}
|
|
|
|
if usePool {
|
|
rs = retainRequestState()
|
|
} else {
|
|
rs = &RequestState{}
|
|
}
|
|
|
|
rs.hasResponse = hasResponse
|
|
rs.Method = method
|
|
rs.Args = args
|
|
rs.DoneChan = make(chan *RequestState, 1)
|
|
|
|
if hasResponse {
|
|
rs.ID = c.ch.GetRequestID()
|
|
rs.Result = result
|
|
}
|
|
|
|
select {
|
|
case c.requestQueueChan <- rs:
|
|
return rs, nil
|
|
default:
|
|
// Try substituting the oldest async request by the new one
|
|
// on requests' queue overflow.
|
|
// This increases the chances for new request to succeed
|
|
// without timeout.
|
|
if !hasResponse {
|
|
// Immediately notify the caller not interested
|
|
// in the response on requests' queue overflow, since
|
|
// there are no other ways to notify it later.
|
|
releaseCallState(rs)
|
|
return nil, getClientOverflowError(c)
|
|
}
|
|
|
|
select {
|
|
case rcs := <-c.requestQueueChan:
|
|
if rcs.DoneChan != nil {
|
|
rcs.Error = getClientOverflowError(c)
|
|
//close(rcs.DoneChan)
|
|
rcs.Done()
|
|
} else {
|
|
releaseCallState(rcs)
|
|
}
|
|
default:
|
|
}
|
|
|
|
select {
|
|
case c.requestQueueChan <- rs:
|
|
return rs, nil
|
|
default:
|
|
// Release m even if usePool = true, since m wasn't exposed
|
|
// to the caller yet.
|
|
releaseCallState(rs)
|
|
return nil, getClientOverflowError(c)
|
|
}
|
|
}
|
|
}
|
|
|
|
func (c *client) handleRPC() {
|
|
subStopChan := make(chan struct{})
|
|
|
|
writerDone := make(chan error, 1)
|
|
go c.rpcWriter(subStopChan, writerDone)
|
|
|
|
readerDone := make(chan error, 1)
|
|
go c.rpcReader(readerDone)
|
|
|
|
var err error
|
|
|
|
select {
|
|
case err = <-writerDone:
|
|
close(subStopChan)
|
|
<-readerDone
|
|
case err = <-readerDone:
|
|
close(subStopChan)
|
|
<-writerDone
|
|
case <-c.stopChan:
|
|
close(subStopChan)
|
|
<-readerDone
|
|
<-writerDone
|
|
}
|
|
|
|
if nil != c.conn {
|
|
c.rwcHandler.Disconnect(c.ctx, c.conn)
|
|
}
|
|
|
|
if err != nil {
|
|
//c.LogError("%s", err)
|
|
log.Printf("handleRPC: %v", err)
|
|
err = &ClientError{
|
|
Connection: true,
|
|
Err: err,
|
|
}
|
|
}
|
|
|
|
for _, rs := range c.pendingRequests {
|
|
atomic.AddUint32(&c.pendingRequestsCount, ^uint32(0))
|
|
rs.Error = err
|
|
if rs.DoneChan != nil {
|
|
rs.Done()
|
|
}
|
|
}
|
|
|
|
}
|
|
|
|
func (c *client) rpcWriter(stopChan <-chan struct{}, writerDone chan<- error) {
|
|
var err error
|
|
defer func() {
|
|
writerDone <- err
|
|
}()
|
|
|
|
for {
|
|
var rs *RequestState
|
|
|
|
select {
|
|
case rs = <-c.requestQueueChan:
|
|
default:
|
|
// Give the last chance for ready goroutines filling c.requestsChan :)
|
|
runtime.Gosched()
|
|
|
|
select {
|
|
case <-stopChan:
|
|
return
|
|
case rs = <-c.requestQueueChan:
|
|
}
|
|
}
|
|
|
|
if rs.IsCanceled() {
|
|
if nil != rs.DoneChan {
|
|
// rs.Error = ErrCanceled
|
|
// close(m.done)
|
|
rs.Done()
|
|
} else {
|
|
releaseCallState(rs)
|
|
}
|
|
continue
|
|
}
|
|
|
|
if rs.hasResponse {
|
|
c.pendingRequestsLock.Lock()
|
|
n := len(c.pendingRequests)
|
|
c.pendingRequests[rs.ID] = rs
|
|
c.pendingRequestsLock.Unlock()
|
|
atomic.AddUint32(&c.pendingRequestsCount, 1)
|
|
|
|
if n > 10*c.ch.GetPendingRequests() {
|
|
err = fmt.Errorf("Client: The server didn't return %d responses yet. Closing server connection in order to prevent client resource leaks", n)
|
|
logging.Logger().Error(err.Error())
|
|
continue
|
|
}
|
|
}
|
|
|
|
var requestID interface{}
|
|
if 0 < rs.ID {
|
|
requestID = rs.ID
|
|
}
|
|
|
|
if nil == c.conn {
|
|
err = io.EOF
|
|
return
|
|
}
|
|
|
|
c.requestMtx.Lock()
|
|
err = c.rwcHandler.WriteRequest(c.ctx, c.ch.GetCodec(), c.conn, rs.Method, rs.Args, requestID)
|
|
c.requestMtx.Unlock()
|
|
if !rs.hasResponse {
|
|
rs.Error = err
|
|
rs.Done()
|
|
}
|
|
if nil != err {
|
|
if err == io.ErrUnexpectedEOF || err == io.EOF {
|
|
logging.Logger().Infof("Client: disconnected from server")
|
|
return
|
|
}
|
|
|
|
err = fmt.Errorf("Client: Cannot send request to wire: [%s]", err)
|
|
logging.Logger().Error(err)
|
|
continue
|
|
}
|
|
}
|
|
}
|
|
|
|
func (c *client) rpcReader(readerDone chan<- error) {
|
|
var err error
|
|
defer func() {
|
|
if r := recover(); r != nil {
|
|
if err == nil {
|
|
err = fmt.Errorf("Client: Panic when reading data from server: %v", r)
|
|
}
|
|
}
|
|
readerDone <- err
|
|
}()
|
|
|
|
for {
|
|
if nil == c.conn {
|
|
err = io.EOF
|
|
return
|
|
}
|
|
c.responseMtx.Lock()
|
|
resCodec, err := c.rwcHandler.ReadResponse(c.ctx, c.ch.GetCodec(), c.conn)
|
|
c.responseMtx.Unlock()
|
|
if nil != err {
|
|
if err == io.ErrUnexpectedEOF || err == io.EOF {
|
|
logging.Logger().Infof("Client: disconnected from server")
|
|
return
|
|
}
|
|
logging.Logger().Errorf("Client: Cannot decode response or notify: [%s]", err)
|
|
continue
|
|
}
|
|
|
|
if nil != resCodec.ID() {
|
|
err = c.handleResponse(resCodec)
|
|
} else {
|
|
err = c.handleNotification(resCodec)
|
|
}
|
|
|
|
if nil != err {
|
|
logging.Logger().Error(err.Error())
|
|
continue
|
|
}
|
|
}
|
|
|
|
}
|
|
|
|
func (c *client) handleResponse(resCodec protocol.ClientResponseCodec) error {
|
|
c.pendingRequestsLock.Lock()
|
|
id := reflect.ValueOf(resCodec.ID()).Convert(uint64Type).Uint()
|
|
|
|
rs, ok := c.pendingRequests[id]
|
|
if ok {
|
|
delete(c.pendingRequests, id)
|
|
}
|
|
c.pendingRequestsLock.Unlock()
|
|
|
|
if !ok {
|
|
return fmt.Errorf("Client: Unexpected ID=[%v] obtained from server", resCodec.ID())
|
|
}
|
|
|
|
atomic.AddUint32(&c.pendingRequestsCount, ^uint32(0))
|
|
|
|
if err := resCodec.Result(rs.Result); nil != err {
|
|
logging.Logger().Errorf("responseHandle:%v", err)
|
|
}
|
|
if err := resCodec.Error(); nil != err {
|
|
logging.Logger().Errorf("responseHandle:%v", err)
|
|
// rs.Error = &ClientError{
|
|
// Server: true,
|
|
// err: fmt.Errorf("gorpc.Client: [%s]. Server error: [%s]", c.Addr, wr.Error),
|
|
// }
|
|
}
|
|
|
|
rs.Done()
|
|
|
|
return nil
|
|
}
|
|
|
|
func (c *client) handleNotification(resCodec protocol.ClientResponseCodec) error {
|
|
notiCodec, err := resCodec.Notification()
|
|
if nil != err {
|
|
return err
|
|
}
|
|
|
|
if nil == c.ch.GetRPCInvoker() {
|
|
params, err := notiCodec.Params()
|
|
if nil != err {
|
|
return err
|
|
}
|
|
return fmt.Errorf("Client: Get Notification[method: %s, params: %v]. But RPC registry is not specified", notiCodec.Method(), params)
|
|
}
|
|
|
|
_, err = c.ch.GetRPCInvoker().Invoke(notiCodec)
|
|
|
|
return err
|
|
}
|
|
|
|
func getClientTimeoutError(c *client, timeout time.Duration) error {
|
|
err := fmt.Errorf("Client: Cannot obtain response during timeout=%s", timeout)
|
|
//c.LogError("%s", err)
|
|
return &ClientError{
|
|
Timeout: true,
|
|
Err: err,
|
|
}
|
|
}
|
|
|
|
func getClientOverflowError(c *client) error {
|
|
err := fmt.Errorf("Client: Requests' queue with size=%d is overflown. Try increasing Client.PendingRequests value", cap(c.requestQueueChan))
|
|
//c.LogError("%s", err)
|
|
return &ClientError{
|
|
Overflow: true,
|
|
Err: err,
|
|
}
|
|
}
|