343 lines
7.5 KiB
Go
343 lines
7.5 KiB
Go
package client
|
|
|
|
import (
	"context"
	"fmt"
	"reflect"
	"sync"
	"sync/atomic"
	"time"

	"go.uber.org/zap"

	olog "git.loafle.net/overflow/log-go"
	"git.loafle.net/overflow/rpc-go/protocol"
	css "git.loafle.net/overflow/server-go/socket"
	cssc "git.loafle.net/overflow/server-go/socket/client"
)
|
|
|
|
// uint64Type is the reflect.Type of uint64. It is resolved once at package
// initialization and used by handleResponse to convert response IDs.
var uint64Type = reflect.TypeOf((*uint64)(nil)).Elem()
|
|
|
|
type Client struct {
|
|
ClientHandler ClientHandler
|
|
|
|
ctx cssc.ClientCtx
|
|
stopChan chan struct{}
|
|
stopWg sync.WaitGroup
|
|
|
|
requestID uint64
|
|
requestQueueChan chan *requestState
|
|
|
|
pendingRequests sync.Map
|
|
}
|
|
|
|
func (c *Client) Start() error {
|
|
if c.stopChan != nil {
|
|
return fmt.Errorf("%s already running. Stop it before starting it again", c.logHeader())
|
|
}
|
|
|
|
if nil == c.ClientHandler {
|
|
return fmt.Errorf("%s ClientHandler must be specified", c.logHeader())
|
|
}
|
|
if err := c.ClientHandler.Validate(); nil != err {
|
|
return fmt.Errorf("%s validate error %v", c.logHeader(), err)
|
|
}
|
|
|
|
c.ctx = c.ClientHandler.ClientCtx()
|
|
if nil == c.ctx {
|
|
return fmt.Errorf("%s ServerCtx is nil", c.logHeader())
|
|
}
|
|
|
|
if err := c.ClientHandler.Init(c.ctx); nil != err {
|
|
return fmt.Errorf("%s Init error %v", c.logHeader(), err)
|
|
}
|
|
|
|
readChan, writeChan, err := c.ClientHandler.GetConnector().Connect()
|
|
if nil != err {
|
|
return err
|
|
}
|
|
|
|
c.requestQueueChan = make(chan *requestState, c.ClientHandler.GetPendingRequestCount())
|
|
c.stopChan = make(chan struct{})
|
|
|
|
c.stopWg.Add(1)
|
|
go c.handleClient(readChan, writeChan)
|
|
|
|
return nil
|
|
}
|
|
|
|
func (c *Client) Stop(ctx context.Context) error {
|
|
if c.stopChan == nil {
|
|
return fmt.Errorf("%s must be started before stopping it", c.logHeader())
|
|
}
|
|
close(c.stopChan)
|
|
c.stopWg.Wait()
|
|
|
|
c.ClientHandler.Destroy(c.ctx)
|
|
|
|
c.stopChan = nil
|
|
|
|
return nil
|
|
}
|
|
|
|
func (c *Client) logHeader() string {
|
|
return fmt.Sprintf("RPC Client[%s]:", c.ClientHandler.GetName())
|
|
}
|
|
|
|
func (c *Client) Send(method string, params ...interface{}) error {
|
|
_, err := c.internalSend(false, nil, method, params...)
|
|
|
|
return err
|
|
}
|
|
|
|
func (c *Client) Call(result interface{}, method string, params ...interface{}) error {
|
|
return c.CallTimeout(c.ClientHandler.GetRequestTimeout(), result, method, params...)
|
|
}
|
|
|
|
func (c *Client) CallTimeout(timeout time.Duration, result interface{}, method string, params ...interface{}) error {
|
|
rs, err := c.internalSend(true, result, method, params...)
|
|
if nil != err {
|
|
return err
|
|
}
|
|
|
|
t := retainTimer(timeout)
|
|
defer func() {
|
|
releaseRequestState(rs)
|
|
releaseTimer(t)
|
|
}()
|
|
|
|
select {
|
|
case <-rs.doneChan:
|
|
if nil != rs.clientError {
|
|
return rs.clientError
|
|
}
|
|
result = rs.result
|
|
return nil
|
|
case <-t.C:
|
|
rs.cancel()
|
|
return newError(method, params, fmt.Errorf("%s Timeout", c.logHeader()))
|
|
}
|
|
}
|
|
|
|
func (c *Client) getRequestID() uint64 {
|
|
c.requestID++
|
|
return c.requestID
|
|
}
|
|
|
|
func (c *Client) internalSend(hasResponse bool, result interface{}, method string, params ...interface{}) (*requestState, error) {
|
|
rs := retainRequestState()
|
|
|
|
rs.method = method
|
|
rs.params = params
|
|
if hasResponse {
|
|
rs.id = c.getRequestID()
|
|
rs.result = result
|
|
rs.doneChan = make(chan *requestState, 1)
|
|
}
|
|
|
|
select {
|
|
case c.requestQueueChan <- rs:
|
|
return rs, nil
|
|
default:
|
|
if !hasResponse {
|
|
releaseRequestState(rs)
|
|
return nil, newError(method, params, fmt.Errorf("%s Request Queue overflow", c.logHeader()))
|
|
}
|
|
select {
|
|
case oldRS := <-c.requestQueueChan:
|
|
if nil != oldRS.doneChan {
|
|
oldRS.setError(fmt.Errorf("%s Request Queue overflow", c.logHeader()))
|
|
oldRS.done()
|
|
} else {
|
|
releaseRequestState(oldRS)
|
|
}
|
|
default:
|
|
}
|
|
select {
|
|
case c.requestQueueChan <- rs:
|
|
return rs, nil
|
|
default:
|
|
releaseRequestState(rs)
|
|
return nil, newError(method, params, fmt.Errorf("%s Request Queue overflow", c.logHeader()))
|
|
}
|
|
}
|
|
}
|
|
|
|
func (c *Client) handleClient(readChan <-chan css.SocketMessage, writeChan chan<- css.SocketMessage) {
|
|
defer func() {
|
|
|
|
if err := c.ClientHandler.GetConnector().Disconnect(); nil != err {
|
|
olog.Logger().Warn(err.Error())
|
|
}
|
|
|
|
c.ClientHandler.OnStop(c.ctx)
|
|
olog.Logger().Info(fmt.Sprintf("%s Stopped", c.logHeader()))
|
|
c.stopWg.Done()
|
|
}()
|
|
|
|
if err := c.ClientHandler.OnStart(c.ctx); nil != err {
|
|
olog.Logger().Error(err.Error())
|
|
return
|
|
}
|
|
|
|
stopChan := make(chan struct{})
|
|
sendDoneChan := make(chan error)
|
|
receiveDoneChan := make(chan error)
|
|
|
|
go c.handleSend(stopChan, sendDoneChan, writeChan)
|
|
go c.handleReceive(stopChan, receiveDoneChan, readChan)
|
|
|
|
olog.Logger().Info(fmt.Sprintf("%s Started", c.logHeader()))
|
|
|
|
select {
|
|
case <-sendDoneChan:
|
|
close(stopChan)
|
|
<-receiveDoneChan
|
|
case <-receiveDoneChan:
|
|
close(stopChan)
|
|
<-sendDoneChan
|
|
case <-c.stopChan:
|
|
close(stopChan)
|
|
<-sendDoneChan
|
|
<-receiveDoneChan
|
|
}
|
|
}
|
|
|
|
func (c *Client) handleSend(stopChan <-chan struct{}, doneChan chan<- error, writeChan chan<- css.SocketMessage) {
|
|
var (
|
|
rs *requestState
|
|
id interface{}
|
|
messageType int
|
|
message []byte
|
|
err error
|
|
ok bool
|
|
)
|
|
|
|
defer func() {
|
|
doneChan <- err
|
|
}()
|
|
|
|
LOOP:
|
|
for {
|
|
select {
|
|
case rs, ok = <-c.requestQueueChan:
|
|
if !ok {
|
|
return
|
|
}
|
|
if rs.isCanceled() {
|
|
if nil != rs.doneChan {
|
|
rs.done()
|
|
} else {
|
|
releaseRequestState(rs)
|
|
}
|
|
continue LOOP
|
|
}
|
|
|
|
id = nil
|
|
if 0 < rs.id {
|
|
id = rs.id
|
|
}
|
|
messageType, message, err = c.ClientHandler.GetRPCCodec().NewRequest(rs.method, rs.params, id)
|
|
if nil != err {
|
|
rs.setError(err)
|
|
rs.done()
|
|
continue LOOP
|
|
}
|
|
|
|
select {
|
|
case writeChan <- css.MakeSocketMessage(messageType, message):
|
|
default:
|
|
rs.setError(fmt.Errorf("%s cannot send request", c.logHeader()))
|
|
rs.done()
|
|
continue LOOP
|
|
}
|
|
|
|
if 0 < rs.id {
|
|
c.pendingRequests.Store(rs.id, rs)
|
|
}
|
|
case <-stopChan:
|
|
return
|
|
}
|
|
}
|
|
}
|
|
|
|
func (c *Client) handleReceive(stopChan <-chan struct{}, doneChan chan<- error, readChan <-chan css.SocketMessage) {
|
|
var (
|
|
socketMessage css.SocketMessage
|
|
messageType int
|
|
message []byte
|
|
err error
|
|
ok bool
|
|
)
|
|
|
|
defer func() {
|
|
doneChan <- err
|
|
}()
|
|
|
|
LOOP:
|
|
for {
|
|
select {
|
|
case socketMessage, ok = <-readChan:
|
|
if !ok {
|
|
return
|
|
}
|
|
|
|
messageType, message = socketMessage()
|
|
|
|
resCodec, err := c.ClientHandler.GetRPCCodec().NewResponse(messageType, message)
|
|
if nil != err {
|
|
olog.Logger().Debug(err.Error())
|
|
continue LOOP
|
|
}
|
|
|
|
if resCodec.IsNotification() {
|
|
// notification
|
|
notiCodec, err := resCodec.Notification()
|
|
if nil != err {
|
|
olog.Logger().Warn(fmt.Sprintf("%s notification error %v", c.logHeader()), zap.Error(err))
|
|
continue LOOP
|
|
}
|
|
|
|
go c.handleNotification(notiCodec)
|
|
} else {
|
|
// response
|
|
go c.handleResponse(resCodec)
|
|
}
|
|
case <-stopChan:
|
|
return
|
|
}
|
|
}
|
|
|
|
}
|
|
|
|
func (c *Client) handleResponse(resCodec protocol.ClientResponseCodec) {
|
|
id := reflect.ValueOf(resCodec.ID()).Convert(uint64Type).Uint()
|
|
_rs, ok := c.pendingRequests.Load(id)
|
|
if !ok {
|
|
olog.Logger().Warn(fmt.Sprintf("%s unexpected ID=[%d] obtained from server", c.logHeader()), zap.Uint64("id", id))
|
|
return
|
|
}
|
|
rs := _rs.(*requestState)
|
|
if nil != resCodec.Error() {
|
|
rs.setError(resCodec.Error())
|
|
} else {
|
|
err := resCodec.Result(rs.result)
|
|
if nil != err {
|
|
rs.setError(err)
|
|
}
|
|
}
|
|
|
|
rs.done()
|
|
}
|
|
|
|
func (c *Client) handleNotification(notiCodec protocol.ClientNotificationCodec) {
|
|
if nil == c.ClientHandler.GetRPCInvoker() {
|
|
olog.Logger().Warn(fmt.Sprintf("%s received notification method[%s] but RPC Invoker is not exist", c.logHeader(), notiCodec.Method()))
|
|
return
|
|
}
|
|
|
|
_, err := c.ClientHandler.GetRPCInvoker().Invoke(notiCodec)
|
|
if nil != err {
|
|
olog.Logger().Error(fmt.Sprintf("%s invoking of notification method[%s] has been failed %v", c.logHeader(), notiCodec.Method(), err))
|
|
}
|
|
}
|