package client

import (
	"context"
	"fmt"
	"reflect"
	"sync"
	"time"

	"go.uber.org/zap"

	olog "git.loafle.net/overflow/log-go"
	"git.loafle.net/overflow/rpc-go/protocol"
	css "git.loafle.net/overflow/server-go/socket"
	cssc "git.loafle.net/overflow/server-go/socket/client"
)

// uint64Type is the reflect.Type that response IDs are converted to before
// they are looked up in Client.pendingRequests.
var uint64Type = reflect.TypeOf(uint64(0))

// Client is an RPC client driven by a ClientHandler. Requests are queued on
// requestQueueChan, written to the connector by handleSend, and matched with
// their responses through pendingRequests by handleResponse.
type Client struct {
	// ClientHandler supplies the connector, codec, invoker and lifecycle
	// hooks. It must be set before Start is called.
	ClientHandler ClientHandler

	ctx      cssc.ClientCtx
	stopChan chan struct{} // non-nil while the client is running
	stopWg   sync.WaitGroup

	requestIDMu      sync.Mutex // guards requestID
	requestID        uint64
	requestQueueChan chan *requestState
	pendingRequests  sync.Map // request id -> *requestState awaiting a response
}

// Start validates and initialises the ClientHandler, connects through its
// connector and launches the goroutine that drives sending and receiving.
//
// It returns an error if the client is already running, if ClientHandler is
// nil or fails validation or Init, if the handler yields a nil ClientCtx, or
// if the connector cannot connect.
func (c *Client) Start() error {
	if c.stopChan != nil {
		return fmt.Errorf("%s already running. Stop it before starting it again", c.logHeader())
	}
	if nil == c.ClientHandler {
		return fmt.Errorf("%s ClientHandler must be specified", c.logHeader())
	}
	if err := c.ClientHandler.Validate(); nil != err {
		return fmt.Errorf("%s validate error %v", c.logHeader(), err)
	}

	c.ctx = c.ClientHandler.ClientCtx()
	if nil == c.ctx {
		return fmt.Errorf("%s ClientCtx is nil", c.logHeader())
	}
	if err := c.ClientHandler.Init(c.ctx); nil != err {
		return fmt.Errorf("%s Init error %v", c.logHeader(), err)
	}

	readChan, writeChan, err := c.ClientHandler.GetConnector().Connect()
	if nil != err {
		return err
	}

	c.requestQueueChan = make(chan *requestState, c.ClientHandler.GetPendingRequestCount())
	c.stopChan = make(chan struct{})

	c.stopWg.Add(1)
	go c.handleClient(readChan, writeChan)

	return nil
}

// Stop signals the client goroutine to terminate, waits for it to finish and
// then calls Destroy on the ClientHandler. It returns an error if the client
// has not been started.
//
// NOTE(review): ctx is accepted but not consulted; Stop blocks until the
// client goroutine has exited.
func (c *Client) Stop(ctx context.Context) error {
	if c.stopChan == nil {
		return fmt.Errorf("%s must be started before stopping it", c.logHeader())
	}

	close(c.stopChan)
	c.stopWg.Wait()

	c.ClientHandler.Destroy(c.ctx)

	c.stopChan = nil

	return nil
}

// logHeader returns the prefix used for this client's log and error messages.
// It tolerates a nil ClientHandler because Start and Stop build their error
// messages with it before a handler is known to be present.
func (c *Client) logHeader() string {
	if nil == c.ClientHandler {
		return "RPC Client[]:"
	}
	return fmt.Sprintf("RPC Client[%s]:", c.ClientHandler.GetName())
}

// Send queues a notification (a request that expects no response) for method
// with params. It returns an error if the request queue is full.
func (c *Client) Send(method string, params ...interface{}) error {
	_, err := c.internalSend(false, nil, method, params...)
	return err
}

// Call invokes method with params and waits for the response using the
// handler's default request timeout. See CallTimeout.
func (c *Client) Call(result interface{}, method string, params ...interface{}) error {
	return c.CallTimeout(c.ClientHandler.GetRequestTimeout(), result, method, params...)
}

// CallTimeout queues a request for method with params and waits up to timeout
// for it to complete. result is handed to the response codec's Result method
// when the response arrives.
//
// It returns the queueing error, the error recorded on the request, or a
// timeout error if no completion is signalled in time.
func (c *Client) CallTimeout(timeout time.Duration, result interface{}, method string, params ...interface{}) error {
	rs, err := c.internalSend(true, result, method, params...)
	if nil != err {
		return err
	}

	t := retainTimer(timeout)
	defer releaseTimer(t)

	select {
	case <-rs.doneChan:
		// The request is finished and no longer referenced by the queue or
		// by pendingRequests, so its state can be handed back.
		defer releaseRequestState(rs)
		if nil != rs.clientError {
			return rs.clientError
		}
		return nil
	case <-t.C:
		rs.cancel()
		// Drop the pending entry so a late response cannot find it and the
		// map does not grow without bound.
		c.pendingRequests.Delete(rs.id)
		// rs is deliberately not released here: it may still sit in the
		// request queue or be in use by handleSend/handleResponse.
		// Presumably releaseRequestState recycles the object, which would
		// let those goroutines touch a reused state.
		return newError(method, params, fmt.Errorf("%s Timeout", c.logHeader()))
	}
}

// getRequestID returns the next request ID. IDs start at 1, so 0 is never
// issued. The counter is guarded by a mutex because Call may be invoked from
// several goroutines at once.
func (c *Client) getRequestID() uint64 {
	c.requestIDMu.Lock()
	defer c.requestIDMu.Unlock()

	c.requestID++
	return c.requestID
}

// internalSend builds a requestState and places it on the request queue
// without blocking.
//
// When hasResponse is true the state gets a fresh ID, the result target and a
// buffered done channel. If the queue is full, a notification fails
// immediately; a call evicts the oldest queued request (failing it with an
// overflow error) and retries once before giving up.
func (c *Client) internalSend(hasResponse bool, result interface{}, method string, params ...interface{}) (*requestState, error) {
	rs := retainRequestState()
	rs.method = method
	rs.params = params

	if hasResponse {
		rs.id = c.getRequestID()
		rs.result = result
		rs.doneChan = make(chan *requestState, 1)
	}

	select {
	case c.requestQueueChan <- rs:
		return rs, nil
	default:
	}

	if !hasResponse {
		releaseRequestState(rs)
		return nil, newError(method, params, fmt.Errorf("%s Request Queue overflow", c.logHeader()))
	}

	// Make room by dropping the oldest queued request, if there is one.
	select {
	case oldRS := <-c.requestQueueChan:
		if nil != oldRS.doneChan {
			oldRS.setError(fmt.Errorf("%s Request Queue overflow", c.logHeader()))
			oldRS.done()
		} else {
			releaseRequestState(oldRS)
		}
	default:
	}

	select {
	case c.requestQueueChan <- rs:
		return rs, nil
	default:
		releaseRequestState(rs)
		return nil, newError(method, params, fmt.Errorf("%s Request Queue overflow", c.logHeader()))
	}
}

// handleClient runs for the lifetime of a started client. It calls OnStart,
// launches the send and receive loops and waits until either loop ends or
// Stop is requested, then shuts the other loop(s) down. On exit it
// disconnects the connector, calls OnStop and releases stopWg.
func (c *Client) handleClient(readChan <-chan css.SocketMessage, writeChan chan<- css.SocketMessage) {
	defer func() {
		if err := c.ClientHandler.GetConnector().Disconnect(); nil != err {
			olog.Logger().Warn(err.Error())
		}

		c.ClientHandler.OnStop(c.ctx)

		olog.Logger().Info(fmt.Sprintf("%s Stopped", c.logHeader()))

		c.stopWg.Done()
	}()

	if err := c.ClientHandler.OnStart(c.ctx); nil != err {
		olog.Logger().Error(err.Error())
		return
	}

	stopChan := make(chan struct{})
	sendDoneChan := make(chan error)
	receiveDoneChan := make(chan error)

	go c.handleSend(stopChan, sendDoneChan, writeChan)
	go c.handleReceive(stopChan, receiveDoneChan, readChan)

	olog.Logger().Info(fmt.Sprintf("%s Started", c.logHeader()))

	select {
	case <-sendDoneChan:
		close(stopChan)
		<-receiveDoneChan
	case <-receiveDoneChan:
		close(stopChan)
		<-sendDoneChan
	case <-c.stopChan:
		close(stopChan)
		<-sendDoneChan
		<-receiveDoneChan
	}
}

// handleSend drains the request queue and writes each request to writeChan
// until the queue is closed or stopChan is closed. It signals completion on
// doneChan; the value sent is always nil because per-request failures are
// reported on the individual request instead.
func (c *Client) handleSend(stopChan <-chan struct{}, doneChan chan<- error, writeChan chan<- css.SocketMessage) {
	defer func() {
		doneChan <- nil
	}()

	for {
		select {
		case rs, ok := <-c.requestQueueChan:
			if !ok {
				return
			}
			c.sendRequest(rs, writeChan)
		case <-stopChan:
			return
		}
	}
}

// sendRequest encodes one queued request and attempts a non-blocking write to
// writeChan. Calls are registered in pendingRequests before the write so that
// a response arriving immediately can be matched.
func (c *Client) sendRequest(rs *requestState, writeChan chan<- css.SocketMessage) {
	if rs.isCanceled() {
		c.finishRequest(rs, nil)
		return
	}

	// Capture this before the write: once a call has been written, its state
	// may be completed by another goroutine and must not be read here again.
	hasResponse := 0 < rs.id

	var id interface{}
	if hasResponse {
		id = rs.id
	}

	messageType, message, err := c.ClientHandler.GetRPCCodec().NewRequest(rs.method, rs.params, id)
	if nil != err {
		c.finishRequest(rs, err)
		return
	}

	if hasResponse {
		c.pendingRequests.Store(rs.id, rs)
	}

	select {
	case writeChan <- css.MakeSocketMessage(messageType, message):
	default:
		if hasResponse {
			c.pendingRequests.Delete(rs.id)
		}
		c.finishRequest(rs, fmt.Errorf("%s cannot send request", c.logHeader()))
		return
	}

	if !hasResponse {
		// Nobody waits on a notification, so its state is finished here.
		releaseRequestState(rs)
	}
}

// finishRequest ends a request that will not be written. A call (non-nil
// doneChan) gets err recorded, when non-nil, and its waiter is signalled. A
// notification has no waiter, so err is logged and the state is released.
func (c *Client) finishRequest(rs *requestState, err error) {
	if nil == rs.doneChan {
		if nil != err {
			olog.Logger().Warn(err.Error())
		}
		releaseRequestState(rs)
		return
	}

	if nil != err {
		rs.setError(err)
	}
	rs.done()
}

// handleReceive reads socket messages from readChan until it is closed or
// stopChan is closed, decodes each one and dispatches it in a new goroutine
// as either a notification or a response. Undecodable messages are logged and
// skipped. It signals completion on doneChan with a nil value.
func (c *Client) handleReceive(stopChan <-chan struct{}, doneChan chan<- error, readChan <-chan css.SocketMessage) {
	defer func() {
		doneChan <- nil
	}()

	for {
		select {
		case socketMessage, ok := <-readChan:
			if !ok {
				return
			}

			messageType, message := socketMessage()

			resCodec, err := c.ClientHandler.GetRPCCodec().NewResponse(messageType, message)
			if nil != err {
				olog.Logger().Debug(err.Error())
				continue
			}

			if !resCodec.IsNotification() {
				// response
				go c.handleResponse(resCodec)
				continue
			}

			// notification
			notiCodec, err := resCodec.Notification()
			if nil != err {
				olog.Logger().Warn(fmt.Sprintf("%s notification error %v", c.logHeader(), err), zap.Error(err))
				continue
			}
			go c.handleNotification(notiCodec)
		case <-stopChan:
			return
		}
	}
}

// handleResponse matches a decoded response with its pending request, stores
// either the response error or the decoded result on it and signals the
// waiting caller. Responses whose ID cannot be converted to uint64, or that
// match no pending request, are logged and dropped.
func (c *Client) handleResponse(resCodec protocol.ClientResponseCodec) {
	rawID := resCodec.ID()
	idValue := reflect.ValueOf(rawID)
	// Convert panics on an invalid (nil) or non-convertible value; this runs
	// in its own goroutine, so a panic would take the whole process down.
	if !idValue.IsValid() || !idValue.Type().ConvertibleTo(uint64Type) {
		olog.Logger().Warn(fmt.Sprintf("%s unusable ID=[%v] obtained from server", c.logHeader(), rawID))
		return
	}
	id := idValue.Convert(uint64Type).Uint()

	_rs, ok := c.pendingRequests.Load(id)
	if !ok {
		olog.Logger().Warn(fmt.Sprintf("%s unexpected ID=[%d] obtained from server", c.logHeader(), id), zap.Uint64("id", id))
		return
	}
	// Remove the entry before signalling completion so the map does not grow
	// without bound and the state is unreferenced once the caller resumes.
	c.pendingRequests.Delete(id)

	rs := _rs.(*requestState)

	if resErr := resCodec.Error(); nil != resErr {
		rs.setError(resErr)
	} else if err := resCodec.Result(rs.result); nil != err {
		rs.setError(err)
	}

	rs.done()
}

// handleNotification passes a server notification to the handler's RPC
// invoker. A missing invoker or a failed invocation is logged.
func (c *Client) handleNotification(notiCodec protocol.ClientNotificationCodec) {
	invoker := c.ClientHandler.GetRPCInvoker()
	if nil == invoker {
		olog.Logger().Warn(fmt.Sprintf("%s received notification method[%s] but RPC Invoker is not exist", c.logHeader(), notiCodec.Method()))
		return
	}

	if _, err := invoker.Invoke(notiCodec); nil != err {
		olog.Logger().Error(fmt.Sprintf("%s invoking of notification method[%s] has been failed %v", c.logHeader(), notiCodec.Method(), err))
	}
}