// Package client implements an RPC client that sends calls and notifications
// over a single connection obtained from a ClientHandler.
package client

import (
	"fmt"
	"log"
	"net"
	"runtime"
	"sync"
	"sync/atomic"
	"time"

	"git.loafle.net/commons_go/rpc/protocol"
)

// New returns a Client that obtains its connection, codec, limits and
// notification registry from ch. The returned client is not connected until
// Connect is called.
func New(ch ClientHandler) Client {
	c := &client{
		ch: ch,
	}
	return c
}

// Client is an RPC client bound to a single connection.
type Client interface {
	// Connect validates the handler, opens the connection and starts the
	// background I/O goroutines. It panics if the client is already started.
	Connect() error
	// Close stops the background goroutines, closes the connection and waits
	// for shutdown to finish. It panics if the client was not started.
	Close()
	// Notify enqueues a request for which no response is expected.
	Notify(method string, args interface{}) error
	// Call sends a request and waits for the response, using the handler's
	// request timeout.
	Call(method string, args interface{}, result interface{}) error
	// CallTimeout is like Call but waits at most timeout for the response.
	CallTimeout(method string, args interface{}, result interface{}, timeout time.Duration) error
}

type client struct {
	ch   ClientHandler
	conn net.Conn

	// pendingRequestsCount counts entries added to / removed from
	// pendingRequests; updated with sync/atomic.
	pendingRequestsCount uint32
	// pendingRequests holds requests already handed to the writer that are
	// still awaiting a response, keyed by request ID.
	// Guarded by pendingRequestsLock.
	pendingRequests     map[uint64]*CallState
	pendingRequestsLock sync.Mutex

	// requestQueueChan buffers outgoing requests consumed by rpcWriter.
	requestQueueChan chan *CallState

	// stopChan is closed by Close to stop handleRPC; nil while not started.
	stopChan chan struct{}
	// stopWg tracks the handleRPC goroutine so Close can wait for it.
	stopWg sync.WaitGroup
}

// Connect validates the handler, dials via the handler and starts handleRPC.
// It panics if the client is already started.
func (c *client) Connect() error {
	var err error
	c.ch.Validate()
	if c.stopChan != nil {
		panic("RPC Client: the given client is already started. Call Client.Close() before calling Client.Connect() again!")
	}
	if c.conn, err = c.ch.Connect(); nil != err {
		return err
	}
	c.stopChan = make(chan struct{})
	c.requestQueueChan = make(chan *CallState, c.ch.GetPendingRequests())
	c.pendingRequests = make(map[uint64]*CallState)

	// Register handleRPC before starting it so Close can wait for it.
	c.stopWg.Add(1)
	go c.handleRPC()
	return nil
}

// Close signals handleRPC to stop and waits until it has closed the
// connection and failed all pending requests. It panics if not started.
func (c *client) Close() {
	if c.stopChan == nil {
		panic("Client: the client must be started before stopping it")
	}
	close(c.stopChan)
	c.stopWg.Wait()
	c.stopChan = nil
}

// Notify enqueues a request without a response channel; only queue overflow
// is reported synchronously.
func (c *client) Notify(method string, args interface{}) error {
	_, err := c.send(method, args, nil, false, true)
	return err
}

// Call performs CallTimeout with the handler's configured request timeout.
func (c *client) Call(method string, args interface{}, result interface{}) error {
	return c.CallTimeout(method, args, result, c.ch.GetRequestTimeout())
}

// CallTimeout enqueues a request and waits up to timeout for its completion.
// It returns the enqueue error, the error recorded on the call state
// (connection, overflow or server error), or a timeout *ClientError.
func (c *client) CallTimeout(method string, args interface{}, result interface{}, timeout time.Duration) (err error) {
	var cs *CallState
	if cs, err = c.send(method, args, result, true, true); nil != err {
		return err
	}

	t := retainTimer(timeout)
	select {
	case <-cs.DoneChan:
		// NOTE(review): responseHandle overwrites cs.Result with
		// codecResponse.Result(); confirm how that value reaches the caller's
		// result argument.
		err = cs.Error
		releaseCallState(cs)
	case <-t.C:
		// The call state is not released here: the writer or reader may
		// still hold it.
		cs.Cancel()
		err = getClientTimeoutError(c, timeout)
	}
	releaseTimer(t)

	return err
}

// send builds a CallState and enqueues it for rpcWriter. Requests that expect
// a response get an ID from the handler and a 1-slot DoneChan. On queue
// overflow, the oldest queued request may be evicted (and failed) to make room.
func (c *client) send(method string, args interface{}, result interface{}, hasResponse bool, usePool bool) (cs *CallState, err error) {
	if !hasResponse {
		usePool = true
	}

	if usePool {
		cs = retainCallState()
	} else {
		cs = &CallState{}
	}

	cs.Method = method
	cs.Args = args

	if hasResponse {
		cs.ID = c.ch.GetRequestID()
		cs.Result = result
		cs.DoneChan = make(chan *CallState, 1)
	}

	select {
	case c.requestQueueChan <- cs:
		return cs, nil
	default:
		// Try substituting the oldest async request by the new one
		// on requests' queue overflow.
		// This increases the chances for new request to succeed
		// without timeout.
		if !hasResponse {
			// Immediately notify the caller not interested
			// in the response on requests' queue overflow, since
			// there are no other ways to notify it later.
			releaseCallState(cs)
			return nil, getClientOverflowError(c)
		}

		select {
		case rcs := <-c.requestQueueChan:
			if rcs.DoneChan != nil {
				rcs.Error = getClientOverflowError(c)
				rcs.done()
			} else {
				releaseCallState(rcs)
			}
		default:
		}

		select {
		case c.requestQueueChan <- cs:
			return cs, nil
		default:
			// Release cs even if usePool = true, since cs wasn't exposed
			// to the caller yet.
			releaseCallState(cs)
			return nil, getClientOverflowError(c)
		}
	}
}

// handleRPC runs the reader and writer goroutines, waits for either to fail
// or for Close, closes the connection, and fails every pending request.
func (c *client) handleRPC() {
	defer c.stopWg.Done()

	subStopChan := make(chan struct{})

	writerDone := make(chan error, 1)
	go c.rpcWriter(subStopChan, writerDone)

	readerDone := make(chan error, 1)
	go c.rpcReader(readerDone)

	// The connection must be closed BEFORE waiting on the reader: rpcReader
	// only returns when a read fails, and closing the conn is what unblocks
	// it (and any in-flight Write in rpcWriter). The Close error is ignored
	// because the connection is being torn down anyway.
	var err error
	select {
	case err = <-writerDone:
		close(subStopChan)
		c.conn.Close()
		<-readerDone
	case err = <-readerDone:
		close(subStopChan)
		c.conn.Close()
		<-writerDone
	case <-c.stopChan:
		close(subStopChan)
		c.conn.Close()
		<-readerDone
		<-writerDone
	}

	if err != nil {
		log.Printf("handleRPC: %v", err)
	} else {
		// Stopped via Close: pending calls never got a response, so they
		// must not complete with a nil error.
		err = fmt.Errorf("Client: The client was closed before the server responded")
	}
	err = &ClientError{
		Connection: true,
		err:        err,
	}

	c.pendingRequestsLock.Lock()
	for id, cs := range c.pendingRequests {
		delete(c.pendingRequests, id)
		atomic.AddUint32(&c.pendingRequestsCount, ^uint32(0))
		cs.Error = err
		if cs.DoneChan != nil {
			cs.done()
		}
	}
	c.pendingRequestsLock.Unlock()
}

// rpcWriter drains requestQueueChan, registers response-bearing requests in
// pendingRequests and writes each request to the connection. The first error
// (or nil on stop) is delivered on writerDone.
func (c *client) rpcWriter(stopChan <-chan struct{}, writerDone chan<- error) {
	var err error
	defer func() {
		writerDone <- err
	}()

	for {
		var cs *CallState

		select {
		case cs = <-c.requestQueueChan:
		default:
			// Give the last chance for ready goroutines filling c.requestQueueChan :)
			runtime.Gosched()

			select {
			case <-stopChan:
				return
			case cs = <-c.requestQueueChan:
			}
		}

		if cs.IsCanceled() {
			if nil != cs.DoneChan {
				cs.done()
			} else {
				releaseCallState(cs)
			}
			continue
		}

		if nil != cs.DoneChan {
			c.pendingRequestsLock.Lock()
			n := len(c.pendingRequests)
			c.pendingRequests[cs.ID] = cs
			c.pendingRequestsLock.Unlock()
			atomic.AddUint32(&c.pendingRequestsCount, 1)

			if n > 10*c.ch.GetPendingRequests() {
				err = fmt.Errorf("Client: The server didn't return %d responses yet. Closing server connection in order to prevent client resource leaks", n)
				return
			}
		}

		// Write before releasing: a notification's CallState goes back to the
		// pool, after which its fields may be reset or reused elsewhere.
		err = c.ch.GetCodec().Write(c.conn, cs.Method, cs.Args, cs.ID)
		if nil == cs.DoneChan {
			releaseCallState(cs)
		}
		if nil != err {
			err = fmt.Errorf("Client: Cannot send request to wire: [%s]", err)
			return
		}
	}
}

// rpcReader decodes responses and notifications from the connection until an
// error occurs; that error (or a recovered panic) is delivered on readerDone.
func (c *client) rpcReader(readerDone chan<- error) {
	var err error
	defer func() {
		if r := recover(); r != nil {
			if err == nil {
				err = fmt.Errorf("Client: Panic when reading data from server: %v", r)
			}
		}
		readerDone <- err
	}()

	for {
		// Use a distinct name so the outer err (sent by the deferred func)
		// is the one assigned below, not a shadowed copy.
		crn, readErr := c.ch.GetCodec().NewResponseOrNotify(c.conn)
		if nil != readErr {
			err = fmt.Errorf("Client: Cannot decode response or notify: [%s]", readErr)
			return
		}

		if crn.IsResponse() {
			err = c.responseHandle(crn.GetResponse())
		} else {
			err = c.notifyHandle(crn.GetNotify())
		}
		if nil != err {
			return
		}
	}
}

// responseHandle completes the pending request matching the response ID.
// An ID that is not a uint64 or is not pending is returned as an error.
func (c *client) responseHandle(codecResponse protocol.ClientCodecResponse) error {
	rawID := codecResponse.ID()
	id, ok := rawID.(uint64)
	if !ok {
		return fmt.Errorf("Client: Unexpected ID=[%v] of type %T obtained from server", rawID, rawID)
	}

	c.pendingRequestsLock.Lock()
	cs, ok := c.pendingRequests[id]
	if ok {
		delete(c.pendingRequests, id)
	}
	c.pendingRequestsLock.Unlock()

	if !ok {
		return fmt.Errorf("Client: Unexpected ID=[%v] obtained from server", rawID)
	}

	atomic.AddUint32(&c.pendingRequestsCount, ^uint32(0))

	cs.Result = codecResponse.Result()
	if serverErr := codecResponse.Error(); nil != serverErr {
		// Surface the server-side error to the caller instead of dropping it.
		cs.Error = &ClientError{
			err: fmt.Errorf("Client: Server error: [%v]", serverErr),
		}
	}
	cs.done()

	return nil
}

// notifyHandle dispatches a server notification to the handler's registry.
func (c *client) notifyHandle(codecNotify protocol.ClientCodecNotify) error {
	_, err := c.ch.GetRPCRegistry().Invoke(codecNotify)
	return err
}

// getClientTimeoutError builds the *ClientError returned when no response
// arrives within timeout.
func getClientTimeoutError(c *client, timeout time.Duration) error {
	err := fmt.Errorf("Client: Cannot obtain response during timeout=%s", timeout)
	return &ClientError{
		Timeout: true,
		err:     err,
	}
}

// getClientOverflowError builds the *ClientError returned when the request
// queue is full.
func getClientOverflowError(c *client) error {
	err := fmt.Errorf("Client: Requests' queue with size=%d is overflown. Try increasing Client.PendingRequests value", cap(c.requestQueueChan))
	return &ClientError{
		Overflow: true,
		err:      err,
	}
}