rpc/client/client.go
crusader 2286a5021e ing
2017-11-26 19:15:51 +09:00

402 lines
8.3 KiB
Go

package client
import (
"fmt"
"io"
"log"
"net"
"reflect"
"runtime"
"sync"
"sync/atomic"
"time"
"git.loafle.net/commons_go/logging"
"git.loafle.net/commons_go/rpc/protocol"
)
// New returns a Client that relies on ch for its configuration, connection
// setup and message encoding. The returned client is idle until Connect is
// called.
func New(ch ClientHandler) Client {
	return &client{ch: ch}
}
// Client is the RPC client API; it is implemented by *client in this file.
type Client interface {
// Connect opens the connection via the ClientHandler and starts processing
// queued requests. It returns the handler's connection error, if any.
Connect() error
// Close stops the request processing started by Connect.
Close()
// Send issues method(args...) as a request that expects no response and
// waits until that request has been completed (written or failed).
Send(method string, args ...interface{}) (err error)
// Call issues method(args...), waits for the response and decodes it into
// result, using the handler's default request timeout.
Call(result interface{}, method string, args ...interface{}) error
// CallTimeout is like Call but waits at most timeout for the response.
CallTimeout(timeout time.Duration, result interface{}, method string, args ...interface{}) (err error)
}
// client is the Client implementation returned by New.
type client struct {
// ch supplies the connection, codec, request IDs, limits and RPC registry.
ch ClientHandler
// conn is the connection returned by ch.Connect in Connect.
conn net.Conn
// pendingRequestsCount tracks requests awaiting a response; it is only
// modified with sync/atomic.
pendingRequestsCount uint32
// pendingRequests maps request IDs to requests that were queued with
// hasResponse and are waiting for a response; guarded by pendingRequestsLock.
pendingRequests map[uint64]*RequestState
pendingRequestsLock sync.Mutex
// requestQueueChan is the buffered queue (capacity ch.GetPendingRequests())
// that send fills and rpcWriter drains.
requestQueueChan chan *RequestState
// stopChan is closed by Close to stop handleRPC; it is nil while the client
// is not started.
stopChan chan struct{}
// stopWg is waited on by Close.
// NOTE(review): nothing in the visible code calls stopWg.Add — confirm
// whether Close is meant to wait for handleRPC.
stopWg sync.WaitGroup
}
// Connect validates the handler, opens the connection through
// ClientHandler.Connect and starts the goroutine (handleRPC) that writes
// queued requests and reads responses/notifications.
//
// It panics if the client is already started; Close must be called before
// Connect can be called again. If the handler fails to connect, that error is
// returned and the client stays unstarted.
//
// NOTE(review): stopWg is not incremented here, so Close does not wait for
// handleRPC to finish — confirm whether that is intended.
func (c *client) Connect() error {
	var err error

	c.ch.Validate()

	if c.stopChan != nil {
		panic("RPC Client: the given client is already started. Call Client.Close() before calling Client.Connect() again!")
	}

	if c.conn, err = c.ch.Connect(); nil != err {
		return err
	}

	c.stopChan = make(chan struct{})
	c.requestQueueChan = make(chan *RequestState, c.ch.GetPendingRequests())
	c.pendingRequests = make(map[uint64]*RequestState)

	go c.handleRPC()

	return nil
}
// Close signals handleRPC to shut down by closing stopChan, waits on stopWg
// and then marks the client as stopped so that Connect may be called again.
//
// It panics if the client has not been started with Connect.
// NOTE(review): nothing in the visible code calls stopWg.Add, so the Wait
// presumably returns immediately instead of waiting for handleRPC — confirm.
func (c *client) Close() {
	stop := c.stopChan
	if nil == stop {
		panic("Client: the client must be started before stopping it")
	}

	close(stop)
	c.stopWg.Wait()
	c.stopChan = nil
}
// Send queues method(args...) as a request that expects no response and blocks
// until that request is completed, either by rpcWriter after the write attempt
// or by a failure path such as queue overflow.
//
// It returns the queueing error from send, or the error recorded on the
// request (nil when the write succeeded). The request state is returned to
// the pool before Send returns.
func (c *client) Send(method string, args ...interface{}) error {
	cs, err := c.send(true, false, nil, method, args...)
	if nil != err {
		return err
	}

	<-cs.DoneChan
	sendErr := cs.Error
	releaseCallState(cs)
	return sendErr
}
// Call is CallTimeout with the handler's default request timeout
// (ClientHandler.GetRequestTimeout).
func (c *client) Call(result interface{}, method string, args ...interface{}) error {
	timeout := c.ch.GetRequestTimeout()
	return c.CallTimeout(timeout, result, method, args...)
}
// CallTimeout queues method(args...) as a request that expects a response and
// waits at most timeout for it to complete.
//
// result is stored on the request by send, and the response handler decodes
// into it directly, so nothing has to be copied back here.
//
// It returns the queueing error from send, the error recorded on the request,
// or a timeout ClientError. On timeout the request is canceled but not
// released to the pool, because the writer or the pending-request map may
// still reference it.
func (c *client) CallTimeout(timeout time.Duration, result interface{}, method string, args ...interface{}) (err error) {
	var cs *RequestState
	if cs, err = c.send(true, true, result, method, args...); nil != err {
		return err
	}

	t := retainTimer(timeout)
	select {
	case <-cs.DoneChan:
		// Only the error needs to be read before cs goes back to the pool.
		err = cs.Error
		releaseCallState(cs)
	case <-t.C:
		cs.Cancel()
		err = getClientTimeoutError(c, timeout)
	}
	releaseTimer(t)

	return err
}
// send builds a RequestState for method(args...) and puts it on
// requestQueueChan for rpcWriter.
//
// Requests without a response always come from the pool; otherwise usePool
// decides. Requests with a response get a fresh ID from the handler and
// carry result so that the response can be decoded into it.
//
// When the queue is full:
//   - a request without a response fails immediately with an overflow error,
//     because there is no later point at which its caller could be notified;
//   - a request with a response evicts the oldest queued request (completing
//     it with an overflow error) and retries once, improving its chance to
//     finish before its timeout.
func (c *client) send(usePool bool, hasResponse bool, result interface{}, method string, args ...interface{}) (cs *RequestState, err error) {
	if usePool || !hasResponse {
		cs = retainRequestState()
	} else {
		cs = &RequestState{}
	}

	cs.hasResponse = hasResponse
	cs.Method = method
	cs.Args = args
	cs.DoneChan = make(chan *RequestState, 1)
	if hasResponse {
		cs.ID = c.ch.GetRequestID()
		cs.Result = result
	}

	// Fast path: the queue has room.
	select {
	case c.requestQueueChan <- cs:
		return cs, nil
	default:
	}

	if !hasResponse {
		releaseCallState(cs)
		return nil, getClientOverflowError(c)
	}

	// Make room by failing the oldest queued request.
	select {
	case evicted := <-c.requestQueueChan:
		if nil == evicted.DoneChan {
			releaseCallState(evicted)
		} else {
			evicted.Error = getClientOverflowError(c)
			evicted.Done()
		}
	default:
	}

	select {
	case c.requestQueueChan <- cs:
		return cs, nil
	default:
	}

	// cs was never handed to the caller, so it can go straight back to the pool.
	releaseCallState(cs)
	return nil, getClientOverflowError(c)
}
// handleRPC serves one connection. It starts rpcWriter and rpcReader, waits
// until one of them stops or Close is called, tears the connection down and
// then completes every outstanding request with a connection ClientError.
//
// c.conn is closed before waiting for the remaining goroutine: rpcReader is
// blocked reading c.conn and does not watch any stop channel, so it only
// returns once that read fails.
func (c *client) handleRPC() {
	subStopChan := make(chan struct{})
	writerDone := make(chan error, 1)
	go c.rpcWriter(subStopChan, writerDone)
	readerDone := make(chan error, 1)
	go c.rpcReader(readerDone)

	var err error
	select {
	case err = <-writerDone:
		close(subStopChan)
		c.conn.Close()
		<-readerDone
	case err = <-readerDone:
		close(subStopChan)
		c.conn.Close()
		<-writerDone
	case <-c.stopChan:
		close(subStopChan)
		c.conn.Close()
		// The errors caused by closing the connection are expected here.
		<-readerDone
		<-writerDone
	}

	if err != nil {
		log.Printf("handleRPC: %v", err)
	} else {
		// err is nil when handleRPC was stopped through c.stopChan (or a
		// goroutine exited without reporting an error). Outstanding callers
		// still need a non-nil error, otherwise Call would report success
		// without a result.
		err = fmt.Errorf("Client: the client has been closed")
	}
	err = &ClientError{
		Connection: true,
		Err:        err,
	}

	c.failPendingRequests(err)
	c.failQueuedRequests(err)
}

// failPendingRequests completes, with err, every request that was registered
// as awaiting a response but never answered, and removes it from the map.
func (c *client) failPendingRequests(err error) {
	c.pendingRequestsLock.Lock()
	pending := make([]*RequestState, 0, len(c.pendingRequests))
	for id, cs := range c.pendingRequests {
		delete(c.pendingRequests, id)
		pending = append(pending, cs)
	}
	c.pendingRequestsLock.Unlock()

	for _, cs := range pending {
		atomic.AddUint32(&c.pendingRequestsCount, ^uint32(0))
		cs.Error = err
		if nil != cs.DoneChan {
			cs.Done()
		}
	}
}

// failQueuedRequests completes, with err, every request still sitting in
// requestQueueChan; without this, callers blocked in Send would wait forever.
// Requests queued after this drain are not covered.
func (c *client) failQueuedRequests(err error) {
	for {
		select {
		case cs := <-c.requestQueueChan:
			if nil != cs.DoneChan {
				cs.Error = err
				cs.Done()
			} else {
				releaseCallState(cs)
			}
		default:
			return
		}
	}
}

// rpcWriter takes requests from c.requestQueueChan and writes them to c.conn
// until stopChan is closed or something makes the connection unusable. The
// reason it stopped (nil when stopped via stopChan) is sent on writerDone.
func (c *client) rpcWriter(stopChan <-chan struct{}, writerDone chan<- error) {
	var err error
	defer func() {
		writerDone <- err
	}()

	for {
		var cs *RequestState
		select {
		case cs = <-c.requestQueueChan:
		default:
			// Give the last chance for ready goroutines filling c.requestQueueChan :)
			runtime.Gosched()
			select {
			case <-stopChan:
				return
			case cs = <-c.requestQueueChan:
			}
		}

		if cs.IsCanceled() {
			if nil != cs.DoneChan {
				cs.Done()
			} else {
				releaseCallState(cs)
			}
			continue
		}

		if cs.hasResponse {
			c.pendingRequestsLock.Lock()
			n := len(c.pendingRequests)
			c.pendingRequests[cs.ID] = cs
			c.pendingRequestsLock.Unlock()
			atomic.AddUint32(&c.pendingRequestsCount, 1)

			if n > 10*c.ch.GetPendingRequests() {
				// Stop so that handleRPC actually closes the connection and
				// fails the pending requests (including cs), as the message says.
				err = fmt.Errorf("Client: The server didn't return %d responses yet. Closing server connection in order to prevent client resource leaks", n)
				return
			}
		}

		var requestID interface{}
		if 0 < cs.ID {
			requestID = cs.ID
		}
		err = c.ch.GetCodec().WriteRequest(c.conn, cs.Method, cs.Args, requestID)
		if !cs.hasResponse {
			cs.Error = err
			cs.Done()
		}
		if nil != err {
			if err == io.ErrUnexpectedEOF || err == io.EOF {
				logging.Logger().Info("Client: disconnected from server")
				return
			}
			// A failed write may leave a partial request on the wire, so the
			// connection is not reused; handleRPC logs this error.
			err = fmt.Errorf("Client: Cannot send request to wire: [%s]", err)
			return
		}
	}
}

// rpcReader decodes messages from c.conn and dispatches responses to
// handleResponse and notifications to handleNotification. It stops at the
// first read/decode failure (including the one caused by handleRPC closing
// c.conn) and sends that error, or a recovered panic, on readerDone.
// Errors from the individual message handlers are logged and do not stop it.
func (c *client) rpcReader(readerDone chan<- error) {
	var err error
	defer func() {
		if r := recover(); r != nil {
			if err == nil {
				err = fmt.Errorf("Client: Panic when reading data from server: %v", r)
			}
		}
		readerDone <- err
	}()

	for {
		// Use a separate name and assign to the outer err so the deferred
		// send reports why the reader stopped.
		msg, readErr := c.ch.GetCodec().NewMessage(c.conn)
		if nil != readErr {
			err = readErr
			if err == io.ErrUnexpectedEOF || err == io.EOF {
				logging.Logger().Info("Client: disconnected from server")
				return
			}
			// Retrying is not useful: a closed connection fails every further
			// read, and after a decode failure the stream position is
			// presumably unknown.
			err = fmt.Errorf("Client: Cannot decode response or notify: [%s]", err)
			return
		}

		var handleErr error
		switch msg.MessageType() {
		case protocol.MessageTypeResponse:
			handleErr = c.handleResponse(msg)
		case protocol.MessageTypeNotification:
			handleErr = c.handleNotification(msg)
		default:
		}
		if nil != handleErr {
			logging.Logger().Error(handleErr.Error())
		}
	}
}
// handleResponse decodes a response message, finds the pending request with
// the same ID, records the outcome on it and completes it via cs.Done().
//
// A server-reported error, or a failure to decode the result into the
// caller's result value, is stored in cs.Error so Call/CallTimeout return it.
// The result is not decoded when the caller passed a nil result or has
// already given up (cs.IsCanceled()).
//
// It returns an error, and completes no request, when the message cannot be
// decoded, has an unexpected codec type, or carries an ID that is missing,
// non-numeric or unknown.
func (c *client) handleResponse(msg protocol.ClientMessageCodec) error {
	codec, err := msg.MessageCodec()
	if nil != err {
		return err
	}
	resCodec, ok := codec.(protocol.ClientResponseCodec)
	if !ok {
		return fmt.Errorf("Client: Unexpected response codec type %T", codec)
	}

	// Convert the ID before taking the lock: reflect.Value.Convert panics on a
	// nil or non-numeric ID, and a panic while holding pendingRequestsLock
	// would leave it locked for the writer forever.
	rawID := resCodec.ID()
	idValue := reflect.ValueOf(rawID)
	if !idValue.IsValid() || !idValue.Type().ConvertibleTo(uint64Type) {
		return fmt.Errorf("Client: Unexpected ID=[%v] obtained from server", rawID)
	}
	id := idValue.Convert(uint64Type).Uint()

	c.pendingRequestsLock.Lock()
	cs, found := c.pendingRequests[id]
	if found {
		delete(c.pendingRequests, id)
	}
	c.pendingRequestsLock.Unlock()

	if !found {
		return fmt.Errorf("Client: Unexpected ID=[%v] obtained from server", rawID)
	}
	atomic.AddUint32(&c.pendingRequestsCount, ^uint32(0))

	if serverErr := resCodec.Error(); nil != serverErr {
		cs.Error = fmt.Errorf("Client: Server error: [%v]", serverErr)
	} else if nil != cs.Result && !cs.IsCanceled() {
		if decodeErr := resCodec.Result(cs.Result); nil != decodeErr {
			cs.Error = fmt.Errorf("Client: Cannot decode result: [%v]", decodeErr)
		}
	}

	cs.Done()
	return nil
}
// handleNotification decodes a server notification and invokes the matching
// handler through the handler's RPC registry; the invocation's result value
// is discarded.
//
// It returns the decode error, an error for an unexpected codec type (instead
// of panicking on the type assertion), or the registry's invocation error.
func (c *client) handleNotification(msg protocol.ClientMessageCodec) error {
	codec, err := msg.MessageCodec()
	if nil != err {
		return err
	}
	notiCodec, ok := codec.(protocol.ClientNotificationCodec)
	if !ok {
		return fmt.Errorf("Client: Unexpected notification codec type %T", codec)
	}

	_, err = c.ch.GetRPCRegistry().Invoke(notiCodec)
	return err
}
// getClientTimeoutError builds the ClientError (Timeout set) that CallTimeout
// returns when no response arrives within timeout. The c parameter is
// currently unused.
func getClientTimeoutError(c *client, timeout time.Duration) error {
	return &ClientError{
		Timeout: true,
		Err:     fmt.Errorf("Client: Cannot obtain response during timeout=%s", timeout),
	}
}
// getClientOverflowError builds the ClientError (Overflow set) used when the
// request queue is full; the message reports the queue's capacity.
func getClientOverflowError(c *client) error {
	queueSize := cap(c.requestQueueChan)
	return &ClientError{
		Overflow: true,
		Err:      fmt.Errorf("Client: Requests' queue with size=%d is overflown. Try increasing Client.PendingRequests value", queueSize),
	}
}