package client import ( "fmt" "io" "log" "net" "reflect" "runtime" "sync" "sync/atomic" "time" "git.loafle.net/commons_go/logging" "git.loafle.net/commons_go/rpc/protocol" ) func New(ch ClientHandler) Client { c := &client{ ch: ch, } return c } type Client interface { Connect() error Close() Notify(method string, args ...interface{}) (err error) Call(result interface{}, method string, args ...interface{}) error CallTimeout(timeout time.Duration, result interface{}, method string, args ...interface{}) (err error) } type client struct { ch ClientHandler conn net.Conn pendingRequestsCount uint32 pendingRequests map[uint64]*CallState pendingRequestsLock sync.Mutex requestQueueChan chan *CallState stopChan chan struct{} stopWg sync.WaitGroup } func (c *client) Connect() error { var err error c.ch.Validate() if c.stopChan != nil { panic("RPC Client: the given client is already started. Call Client.Stop() before calling Client.Start() again!") } if c.conn, err = c.ch.Connect(); nil != err { return err } c.stopChan = make(chan struct{}) c.requestQueueChan = make(chan *CallState, c.ch.GetPendingRequests()) c.pendingRequests = make(map[uint64]*CallState) go c.handleRPC() return nil } func (c *client) Close() { if c.stopChan == nil { panic("Client: the client must be started before stopping it") } close(c.stopChan) c.stopWg.Wait() c.stopChan = nil } func (c *client) Notify(method string, args ...interface{}) (err error) { var cs *CallState if cs, err = c.send(true, false, nil, method, args...); nil != err { return } select { case <-cs.DoneChan: err = cs.Error ReleaseCallState(cs) } return } func (c *client) Call(result interface{}, method string, args ...interface{}) error { return c.CallTimeout(c.ch.GetRequestTimeout(), result, method, args...) } func (c *client) CallTimeout(timeout time.Duration, result interface{}, method string, args ...interface{}) (err error) { var cs *CallState if cs, err = c.send(true, true, result, method, args...); nil != err { return } t := retainTimer(timeout) select { case <-cs.DoneChan: result, err = cs.Result, cs.Error ReleaseCallState(cs) case <-t.C: cs.Cancel() err = getClientTimeoutError(c, timeout) } releaseTimer(t) return } func (c *client) send(usePool bool, hasResponse bool, result interface{}, method string, args ...interface{}) (cs *CallState, err error) { if !hasResponse { usePool = true } if usePool { cs = RetainCallState() } else { cs = &CallState{} } cs.hasResponse = hasResponse cs.Method = method cs.Args = args cs.DoneChan = make(chan *CallState, 1) if hasResponse { cs.ID = c.ch.GetRequestID() cs.Result = result } select { case c.requestQueueChan <- cs: return cs, nil default: // Try substituting the oldest async request by the new one // on requests' queue overflow. // This increases the chances for new request to succeed // without timeout. if !hasResponse { // Immediately notify the caller not interested // in the response on requests' queue overflow, since // there are no other ways to notify it later. ReleaseCallState(cs) return nil, getClientOverflowError(c) } select { case rcs := <-c.requestQueueChan: if rcs.DoneChan != nil { rcs.Error = getClientOverflowError(c) //close(rcs.DoneChan) rcs.Done() } else { ReleaseCallState(rcs) } default: } select { case c.requestQueueChan <- cs: return cs, nil default: // Release m even if usePool = true, since m wasn't exposed // to the caller yet. ReleaseCallState(cs) return nil, getClientOverflowError(c) } } } func (c *client) handleRPC() { subStopChan := make(chan struct{}) writerDone := make(chan error, 1) go c.rpcWriter(subStopChan, writerDone) readerDone := make(chan error, 1) go c.rpcReader(readerDone) var err error select { case err = <-writerDone: close(subStopChan) <-readerDone case err = <-readerDone: close(subStopChan) <-writerDone case <-c.stopChan: close(subStopChan) <-readerDone <-writerDone } c.conn.Close() if err != nil { //c.LogError("%s", err) log.Printf("handleRPC: %v", err) err = &ClientError{ Connection: true, Err: err, } } for _, cs := range c.pendingRequests { atomic.AddUint32(&c.pendingRequestsCount, ^uint32(0)) cs.Error = err if cs.DoneChan != nil { cs.Done() } } } func (c *client) rpcWriter(stopChan <-chan struct{}, writerDone chan<- error) { var err error defer func() { writerDone <- err }() for { var cs *CallState select { case cs = <-c.requestQueueChan: default: // Give the last chance for ready goroutines filling c.requestsChan :) runtime.Gosched() select { case <-stopChan: return case cs = <-c.requestQueueChan: } } if cs.IsCanceled() { if nil != cs.DoneChan { // cs.Error = ErrCanceled // close(m.done) cs.Done() } else { ReleaseCallState(cs) } continue } if cs.hasResponse { c.pendingRequestsLock.Lock() n := len(c.pendingRequests) c.pendingRequests[cs.ID] = cs c.pendingRequestsLock.Unlock() atomic.AddUint32(&c.pendingRequestsCount, 1) if n > 10*c.ch.GetPendingRequests() { err = fmt.Errorf("Client: The server didn't return %d responses yet. Closing server connection in order to prevent client resource leaks", n) logging.Logger().Error(err.Error()) continue } } var requestID interface{} if 0 < cs.ID { requestID = cs.ID } err = c.ch.GetCodec().Write(c.conn, cs.Method, cs.Args, requestID) if !cs.hasResponse { cs.Error = err cs.Done() } if nil != err { if err == io.ErrUnexpectedEOF || err == io.EOF { logging.Logger().Info("Client: disconnected from server") return } err = fmt.Errorf("Client: Cannot send request to wire: [%s]", err) logging.Logger().Error(err.Error()) continue } } } func (c *client) rpcReader(readerDone chan<- error) { var err error defer func() { if r := recover(); r != nil { if err == nil { err = fmt.Errorf("Client: Panic when reading data from server: %v", r) } } readerDone <- err }() for { crn, err := c.ch.GetCodec().NewResponseOrNotify(c.conn) if nil != err { if err == io.ErrUnexpectedEOF || err == io.EOF { logging.Logger().Info("Client: disconnected from server") return } err = fmt.Errorf("Client: Cannot decode response or notify: [%s]", err) logging.Logger().Error(err.Error()) continue } if crn.IsResponse() { err = c.responseHandle(crn.GetResponse()) } else { err = c.notifyHandle(crn.GetNotify()) } if nil != err { logging.Logger().Error(err.Error()) continue } } } func (c *client) responseHandle(codecResponse protocol.ClientCodecResponse) error { c.pendingRequestsLock.Lock() id := reflect.ValueOf(codecResponse.ID()).Convert(uint64Type).Uint() cs, ok := c.pendingRequests[id] if ok { delete(c.pendingRequests, id) } c.pendingRequestsLock.Unlock() if !ok { return fmt.Errorf("Client: Unexpected ID=[%v] obtained from server", codecResponse.ID()) } atomic.AddUint32(&c.pendingRequestsCount, ^uint32(0)) if err := codecResponse.Result(cs.Result); nil != err { log.Printf("responseHandle:%v", err) } if err := codecResponse.Error(); nil != err { log.Printf("responseHandle:%v", err) // cs.Error = &ClientError{ // Server: true, // err: fmt.Errorf("gorpc.Client: [%s]. Server error: [%s]", c.Addr, wr.Error), // } } cs.Done() return nil } func (c *client) notifyHandle(codecNotify protocol.ClientCodecNotify) error { _, err := c.ch.GetRPCRegistry().Invoke(codecNotify) return err } func getClientTimeoutError(c *client, timeout time.Duration) error { err := fmt.Errorf("Client: Cannot obtain response during timeout=%s", timeout) //c.LogError("%s", err) return &ClientError{ Timeout: true, Err: err, } } func getClientOverflowError(c *client) error { err := fmt.Errorf("Client: Requests' queue with size=%d is overflown. Try increasing Client.PendingRequests value", cap(c.requestQueueChan)) //c.LogError("%s", err) return &ClientError{ Overflow: true, Err: err, } }