package client

import (
	"fmt"
	"io"
	"reflect"
	"runtime"
	"sync"
	"sync/atomic"
	"time"

	"git.loafle.net/commons_go/logging"
	"git.loafle.net/commons_go/rpc/protocol"
)

// New returns an unconnected Client. ch supplies configuration and lifecycle
// hooks (Validate, Init, Destroy, codec, timeouts); rwcHandler performs the
// actual Connect/WriteRequest/ReadResponse/Disconnect operations.
// Call Connect before issuing requests.
func New(ch ClientHandler, rwcHandler ClientReadWriteCloseHandler) Client {
	return &client{
		ch:         ch,
		rwcHandler: rwcHandler,
	}
}

// Client is an RPC client that multiplexes requests over one connection.
type Client interface {
	// Connect validates both handlers, initializes the client and opens the
	// connection. It panics on a missing handler, on Init failure, or if the
	// client is already started; it returns the error from rwcHandler.Connect.
	Connect() error
	// Close stops the client and blocks until its background goroutines have
	// exited and the connection has been handed to rwcHandler.Disconnect.
	// It panics if the client was not started. Do not call it from code that
	// runs on the client's reader goroutine (e.g. a notification handler
	// invoked through the RPC registry): Close waits for that goroutine.
	Close()
	// Send queues a request that expects no response and waits, without a
	// timeout, until the writer has attempted to send it.
	Send(method string, args ...interface{}) (err error)
	// Call is CallTimeout using the handler's GetRequestTimeout().
	Call(result interface{}, method string, args ...interface{}) error
	// CallTimeout queues a request, waits up to timeout for its response and
	// decodes the response into result.
	CallTimeout(timeout time.Duration, result interface{}, method string, args ...interface{}) (err error)
}

type client struct {
	ctx        ClientContext
	ch         ClientHandler
	rwcHandler ClientReadWriteCloseHandler

	// conn is the value returned by rwcHandler.Connect. It is assigned before
	// the I/O goroutines start and not modified while they run.
	conn interface{}

	pendingRequestsCount uint32
	// pendingRequests maps request IDs to requests that await a response.
	// Guarded by pendingRequestsLock.
	pendingRequests     map[uint64]*RequestState
	pendingRequestsLock sync.Mutex

	requestQueueChan chan *RequestState

	// stopChan is non-nil while the client is started; closing it asks
	// handleRPC to shut down. stopWg tracks the handleRPC goroutine.
	stopChan chan struct{}
	stopWg   sync.WaitGroup
}

func (c *client) Connect() error {
	var err error

	if nil == c.ch {
		panic("RPC Client: Client handler must be specified.")
	}
	c.ch.Validate()

	if nil == c.rwcHandler {
		panic("RPC Client: Client RWC handler must be specified.")
	}
	c.rwcHandler.Validate()

	if c.stopChan != nil {
		panic("RPC Client: the given client is already started. Call Client.Close() before calling Client.Connect() again!")
	}

	c.ctx = c.ch.ClientContext(nil)

	if initErr := c.ch.Init(c.ctx); nil != initErr {
		logging.Logger().Panic(fmt.Sprintf("RPC Client: Initialization of client has been failed %v", initErr))
	}

	if c.conn, err = c.rwcHandler.Connect(c.ctx); nil != err {
		return err
	}

	c.stopChan = make(chan struct{})
	c.requestQueueChan = make(chan *RequestState, c.ch.GetPendingRequests())
	c.pendingRequests = make(map[uint64]*RequestState)

	// Register handleRPC before starting it so Close's Wait cannot miss it.
	c.stopWg.Add(1)
	go c.handleRPC()

	return nil
}

func (c *client) Close() {
	if c.stopChan == nil {
		panic("RPC Client: the client must be started before stopping it")
	}
	c.ch.Destroy(c.ctx)
	close(c.stopChan)
	// Wait for handleRPC to stop the reader/writer and disconnect, so a later
	// Connect cannot overlap with goroutines still using the old connection.
	c.stopWg.Wait()
	c.stopChan = nil

	logging.Logger().Info("RPC Client: stopped")
}

func (c *client) Send(method string, args ...interface{}) error {
	rs, err := c.send(true, false, nil, method, args...)
	if nil != err {
		return err
	}

	// No timeout here: this returns once rpcWriter (or the cancel/overflow
	// paths) calls rs.Done().
	<-rs.DoneChan
	err = rs.Error
	releaseCallState(rs)

	return err
}

func (c *client) Call(result interface{}, method string, args ...interface{}) error {
	return c.CallTimeout(c.ch.GetRequestTimeout(), result, method, args...)
}

func (c *client) CallTimeout(timeout time.Duration, result interface{}, method string, args ...interface{}) (err error) {
	var rs *RequestState
	if rs, err = c.send(true, true, result, method, args...); nil != err {
		return err
	}

	t := retainTimer(timeout)
	select {
	case <-rs.DoneChan:
		// rs.Result is the caller's result value (set in send); handleResponse
		// decodes into it, and rs.Error carries any failure.
		err = rs.Error
		releaseCallState(rs)
	case <-t.C:
		rs.Cancel()
		err = getClientTimeoutError(c, timeout)
	}
	releaseTimer(t)

	return err
}

// send builds a RequestState and enqueues it for rpcWriter. When the queue is
// full it evicts the oldest queued request (completing it with an overflow
// error) to make room, and fails with an overflow error if that still does
// not help.
func (c *client) send(usePool bool, hasResponse bool, result interface{}, method string, args ...interface{}) (rs *RequestState, err error) {
	if !hasResponse {
		usePool = true
	}

	if usePool {
		rs = retainRequestState()
	} else {
		rs = &RequestState{}
	}

	rs.hasResponse = hasResponse
	rs.Method = method
	rs.Args = args
	rs.DoneChan = make(chan *RequestState, 1)

	if hasResponse {
		rs.ID = c.ch.GetRequestID()
		rs.Result = result
	}

	select {
	case c.requestQueueChan <- rs:
		return rs, nil
	default:
		// Try substituting the oldest async request by the new one
		// on requests' queue overflow.
		// This increases the chances for new request to succeed
		// without timeout.
		if !hasResponse {
			// Immediately notify the caller not interested
			// in the response on requests' queue overflow, since
			// there are no other ways to notify it later.
			releaseCallState(rs)
			return nil, getClientOverflowError(c)
		}

		select {
		case rcs := <-c.requestQueueChan:
			if rcs.DoneChan != nil {
				rcs.Error = getClientOverflowError(c)
				rcs.Done()
			} else {
				releaseCallState(rcs)
			}
		default:
		}

		select {
		case c.requestQueueChan <- rs:
			return rs, nil
		default:
			// Release rs even if usePool = true, since rs wasn't exposed
			// to the caller yet.
			releaseCallState(rs)
			return nil, getClientOverflowError(c)
		}
	}
}

// handleRPC runs the reader and writer goroutines for one connection, tears
// the connection down when either of them stops or Close is requested, and
// then completes every still-pending request.
func (c *client) handleRPC() {
	defer c.stopWg.Done()

	subStopChan := make(chan struct{})
	writerDone := make(chan error, 1)
	go c.rpcWriter(subStopChan, writerDone)
	readerDone := make(chan error, 1)
	go c.rpcReader(subStopChan, readerDone)

	// The reader blocks in ReadResponse until the connection goes away, so
	// disconnect BEFORE waiting for it. The writer is always waited for first
	// so Disconnect never runs concurrently with WriteRequest.
	// NOTE(review): assumes rwcHandler.Disconnect makes a blocked
	// ReadResponse return — confirm against the handler implementations.
	var err error
	select {
	case err = <-writerDone:
		close(subStopChan)
		c.disconnect()
		<-readerDone
	case err = <-readerDone:
		close(subStopChan)
		<-writerDone
		c.disconnect()
	case <-c.stopChan:
		close(subStopChan)
		<-writerDone
		c.disconnect()
		<-readerDone
	}

	if nil != err {
		logging.Logger().Error(fmt.Sprintf("handleRPC: %v", err))
		err = &ClientError{
			Connection: true,
			Err:        err,
		}
	}

	// Both I/O goroutines have exited; swap the table out under the lock so
	// every access to pendingRequests stays guarded.
	c.pendingRequestsLock.Lock()
	pending := c.pendingRequests
	c.pendingRequests = make(map[uint64]*RequestState)
	c.pendingRequestsLock.Unlock()

	for _, rs := range pending {
		atomic.AddUint32(&c.pendingRequestsCount, ^uint32(0))
		rs.Error = err
		if rs.DoneChan != nil {
			rs.Done()
		}
	}
}

// disconnect hands the current connection, if any, to rwcHandler.Disconnect.
func (c *client) disconnect() {
	if nil != c.conn {
		c.rwcHandler.Disconnect(c.ctx, c.conn)
	}
}

// rpcWriter drains requestQueueChan and writes each request to the
// connection until stopChan is closed (and the queue is empty) or the
// connection is lost. The error that ended it is sent on writerDone.
func (c *client) rpcWriter(stopChan <-chan struct{}, writerDone chan<- error) {
	var err error
	defer func() {
		writerDone <- err
	}()

	for {
		var rs *RequestState

		select {
		case rs = <-c.requestQueueChan:
		default:
			// Give the last chance for ready goroutines filling c.requestQueueChan :)
			runtime.Gosched()

			select {
			case <-stopChan:
				return
			case rs = <-c.requestQueueChan:
			}
		}

		if rs.IsCanceled() {
			if nil != rs.DoneChan {
				rs.Done()
			} else {
				releaseCallState(rs)
			}
			continue
		}

		// Snapshot what is needed after the write: once a request has been
		// completed, its owner may release it back to the pool.
		hasResponse := rs.hasResponse
		id := rs.ID

		if hasResponse {
			c.pendingRequestsLock.Lock()
			n := len(c.pendingRequests)
			c.pendingRequests[id] = rs
			c.pendingRequestsLock.Unlock()
			atomic.AddUint32(&c.pendingRequestsCount, 1)

			if n > 10*c.ch.GetPendingRequests() {
				err = fmt.Errorf("Client: The server didn't return %d responses yet. Closing server connection in order to prevent client resource leaks", n)
				logging.Logger().Error(err.Error())
				// Actually stop: handleRPC then disconnects and fails every
				// pending request (including rs) instead of letting the
				// table grow without bound.
				return
			}
		}

		var requestID interface{}
		if 0 < id {
			requestID = id
		}

		if nil == c.conn {
			err = io.EOF
			return
		}

		writeErr := c.rwcHandler.WriteRequest(c.ctx, c.ch.GetCodec(), c.conn, rs.Method, rs.Args, requestID)
		if !hasResponse {
			rs.Error = writeErr
			rs.Done()
		}
		if nil == writeErr {
			continue
		}

		if writeErr == io.ErrUnexpectedEOF || writeErr == io.EOF {
			logging.Logger().Info("Client: disconnected from server")
			err = writeErr
			return
		}

		writeErr = fmt.Errorf("Client: Cannot send request to wire: [%s]", writeErr)
		logging.Logger().Error(writeErr.Error())
		if hasResponse {
			// The request never reached the server, so no response will
			// arrive; complete it now instead of leaving it pending.
			c.failPendingRequest(id, rs, writeErr)
		}
	}
}

// failPendingRequest completes rs with err if rs is still registered under id
// in the pending-request table, removing it from the table. It is a no-op if
// the request was already completed (e.g. by handleResponse).
func (c *client) failPendingRequest(id uint64, rs *RequestState, err error) {
	c.pendingRequestsLock.Lock()
	registered, ok := c.pendingRequests[id]
	ok = ok && registered == rs
	if ok {
		delete(c.pendingRequests, id)
	}
	c.pendingRequestsLock.Unlock()

	if !ok {
		return
	}

	atomic.AddUint32(&c.pendingRequestsCount, ^uint32(0))
	rs.Error = err
	if nil != rs.DoneChan {
		rs.Done()
	}
}

// rpcReader reads responses and notifications from the connection and
// dispatches them until the connection reports EOF, a read fails after
// stopChan was closed, or a panic occurs. The terminating error (nil when
// stopped on request) is sent on readerDone.
func (c *client) rpcReader(stopChan <-chan struct{}, readerDone chan<- error) {
	// err is assigned only immediately before returning, so the deferred
	// send reports exactly why the reader stopped.
	var err error
	defer func() {
		if r := recover(); r != nil {
			if nil == err {
				err = fmt.Errorf("Client: Panic when reading data from server: %v", r)
			}
		}
		readerDone <- err
	}()

	for {
		if nil == c.conn {
			err = io.EOF
			return
		}

		resCodec, readErr := c.rwcHandler.ReadResponse(c.ctx, c.ch.GetCodec(), c.conn)
		if nil != readErr {
			if readErr == io.ErrUnexpectedEOF || readErr == io.EOF {
				logging.Logger().Info("Client: disconnected from server")
				err = readErr
				return
			}

			select {
			case <-stopChan:
				// handleRPC is shutting down and has disconnected; further
				// reads would only fail again.
				return
			default:
			}

			logging.Logger().Error(fmt.Sprintf("Client: Cannot decode response or notify: [%s]", readErr))
			continue
		}

		var handleErr error
		if nil != resCodec.ID() {
			handleErr = c.handleResponse(resCodec)
		} else {
			handleErr = c.handleNotification(resCodec)
		}

		if nil != handleErr {
			logging.Logger().Error(handleErr.Error())
		}
	}
}

// handleResponse matches a response to its pending request by ID, stores the
// server error or the decoded result on it, and completes it.
func (c *client) handleResponse(resCodec protocol.ClientResponseCodec) error {
	rawID := resCodec.ID()

	// Convert outside the lock: Convert panics on unsupported ID types, and a
	// panic while holding pendingRequestsLock would wedge the client.
	idValue := reflect.ValueOf(rawID)
	if !idValue.IsValid() || !idValue.Type().ConvertibleTo(uint64Type) {
		return fmt.Errorf("Client: Unexpected ID=[%v] obtained from server", rawID)
	}
	id := idValue.Convert(uint64Type).Uint()

	c.pendingRequestsLock.Lock()
	rs, ok := c.pendingRequests[id]
	if ok {
		delete(c.pendingRequests, id)
	}
	c.pendingRequestsLock.Unlock()

	if !ok {
		return fmt.Errorf("Client: Unexpected ID=[%v] obtained from server", rawID)
	}

	atomic.AddUint32(&c.pendingRequestsCount, ^uint32(0))

	// Surface failures to the caller instead of only logging them.
	// Decoding is skipped when the caller passed no result target (Call(nil, ...)).
	if serverErr := resCodec.Error(); nil != serverErr {
		rs.Error = fmt.Errorf("Client: Server returned error: [%v]", serverErr)
	} else if nil != rs.Result {
		if decodeErr := resCodec.Result(rs.Result); nil != decodeErr {
			rs.Error = fmt.Errorf("Client: Cannot decode result: [%v]", decodeErr)
		}
	}

	rs.Done()

	return nil
}

// handleNotification dispatches a server notification to the RPC registry.
// It returns an error when no registry is configured or the invocation fails.
func (c *client) handleNotification(resCodec protocol.ClientResponseCodec) error {
	notiCodec, err := resCodec.Notification()
	if nil != err {
		return err
	}

	registry := c.ch.GetRPCRegistry()
	if nil == registry {
		params, err := notiCodec.Params()
		if nil != err {
			return err
		}
		return fmt.Errorf("Client: Get Notification[method: %s, params: %v]. But RPC registry is not specified", notiCodec.Method(), params)
	}

	_, err = registry.Invoke(notiCodec)

	return err
}

func getClientTimeoutError(c *client, timeout time.Duration) error {
	err := fmt.Errorf("Client: Cannot obtain response during timeout=%s", timeout)
	return &ClientError{
		Timeout: true,
		Err:     err,
	}
}

func getClientOverflowError(c *client) error {
	err := fmt.Errorf("Client: Requests' queue with size=%d is overflown. Try increasing Client.PendingRequests value", cap(c.requestQueueChan))
	return &ClientError{
		Overflow: true,
		Err:      err,
	}
}