This commit is contained in:
crusader 2017-11-01 12:10:39 +09:00
parent 63b5613089
commit b4b379c891
3 changed files with 49 additions and 62 deletions

View File

@ -3,6 +3,7 @@ package client
import ( import (
"fmt" "fmt"
"io" "io"
"log"
"runtime" "runtime"
"sync" "sync"
"sync/atomic" "sync/atomic"
@ -19,8 +20,9 @@ func New(ch ClientHandler) Client {
} }
type Client interface { type Client interface {
Start(rwc io.ReadWriteCloser) Connect() error
Stop() Close()
Notify(method string, args interface{}) error Notify(method string, args interface{}) error
Call(method string, args interface{}, result interface{}) error Call(method string, args interface{}, result interface{}) error
CallTimeout(method string, args interface{}, result interface{}, timeout time.Duration) error CallTimeout(method string, args interface{}, result interface{}, timeout time.Duration) error
@ -32,6 +34,8 @@ type client struct {
rwc io.ReadWriteCloser rwc io.ReadWriteCloser
pendingRequestsCount uint32 pendingRequestsCount uint32
pendingRequests map[interface{}]*CallState
pendingRequestsLock sync.Mutex
requestQueueChan chan *CallState requestQueueChan chan *CallState
@ -39,28 +43,27 @@ type client struct {
stopWg sync.WaitGroup stopWg sync.WaitGroup
} }
func (c *client) Start(rwc io.ReadWriteCloser) { func (c *client) Connect() error {
var err error
c.ch.Validate() c.ch.Validate()
if nil == rwc {
panic("RWC(io.ReadWriteCloser) must be specified.")
}
if c.stopChan != nil { if c.stopChan != nil {
panic("Client: the given client is already started. Call Client.Stop() before calling Client.Start() again!") panic("RPC Client: the given client is already started. Call Client.Stop() before calling Client.Start() again!")
} }
c.rwc = rwc if c.rwc, err = c.ch.Connect(); nil != err {
return err
}
c.stopChan = make(chan struct{}) c.stopChan = make(chan struct{})
c.requestQueueChan = make(chan *CallState, c.ch.GetPendingRequests()) c.requestQueueChan = make(chan *CallState, c.ch.GetPendingRequests())
c.pendingRequests = make(map[interface{}]*CallState)
c.ch.OnStart() go c.handleRPC()
c.stopWg.Add(1) return nil
go handleRPC(c)
} }
func (c *client) Stop() { func (c *client) Close() {
if c.stopChan == nil { if c.stopChan == nil {
panic("Client: the client must be started before stopping it") panic("Client: the client must be started before stopping it")
} }
@ -160,47 +163,42 @@ func (c *client) send(method string, args interface{}, result interface{}, hasRe
} }
} }
func handleRPC(c *client) { func (c *client) handleRPC() {
defer c.stopWg.Done() subStopChan := make(chan struct{})
stopChan := make(chan struct{})
pendingRequests := make(map[interface{}]*CallState)
var pendingRequestsLock sync.Mutex
writerDone := make(chan error, 1) writerDone := make(chan error, 1)
go rpcWriter(c, pendingRequests, &pendingRequestsLock, stopChan, writerDone) go c.rpcWriter(subStopChan, writerDone)
readerDone := make(chan error, 1) readerDone := make(chan error, 1)
go rpcReader(c, pendingRequests, &pendingRequestsLock, readerDone) go c.rpcReader(readerDone)
var err error var err error
select { select {
case err = <-writerDone: case err = <-writerDone:
close(stopChan) close(subStopChan)
c.rwc.Close()
<-readerDone <-readerDone
case err = <-readerDone: case err = <-readerDone:
close(stopChan) close(subStopChan)
c.rwc.Close()
<-writerDone <-writerDone
case <-c.stopChan: case <-c.stopChan:
close(stopChan) close(subStopChan)
c.rwc.Close()
<-readerDone <-readerDone
<-writerDone <-writerDone
} }
c.rwc.Close()
if err != nil { if err != nil {
//c.LogError("%s", err) //c.LogError("%s", err)
log.Printf("handleRPC: %v", err)
err = &ClientError{ err = &ClientError{
Connection: true, Connection: true,
err: err, err: err,
} }
} }
for _, cs := range pendingRequests { for _, cs := range c.pendingRequests {
atomic.AddUint32(&c.pendingRequestsCount, ^uint32(0)) atomic.AddUint32(&c.pendingRequestsCount, ^uint32(0))
cs.Error = err cs.Error = err
if cs.DoneChan != nil { if cs.DoneChan != nil {
@ -210,7 +208,7 @@ func handleRPC(c *client) {
} }
func rpcWriter(c *client, pendingRequests map[interface{}]*CallState, pendingRequestsLock *sync.Mutex, stopChan <-chan struct{}, writerDone chan<- error) { func (c *client) rpcWriter(stopChan <-chan struct{}, writerDone chan<- error) {
var err error var err error
defer func() { defer func() {
writerDone <- err writerDone <- err
@ -244,10 +242,10 @@ func rpcWriter(c *client, pendingRequests map[interface{}]*CallState, pendingReq
} }
if nil != cs.DoneChan { if nil != cs.DoneChan {
pendingRequestsLock.Lock() c.pendingRequestsLock.Lock()
n := len(pendingRequests) n := len(c.pendingRequests)
pendingRequests[cs.ID] = cs c.pendingRequests[cs.ID] = cs
pendingRequestsLock.Unlock() c.pendingRequestsLock.Unlock()
atomic.AddUint32(&c.pendingRequestsCount, 1) atomic.AddUint32(&c.pendingRequestsCount, 1)
if n > 10*c.ch.GetPendingRequests() { if n > 10*c.ch.GetPendingRequests() {
@ -267,7 +265,7 @@ func rpcWriter(c *client, pendingRequests map[interface{}]*CallState, pendingReq
} }
} }
func rpcReader(c *client, pendingRequests map[interface{}]*CallState, pendingRequestsLock *sync.Mutex, readerDone chan<- error) { func (c *client) rpcReader(readerDone chan<- error) {
var err error var err error
defer func() { defer func() {
if r := recover(); r != nil { if r := recover(); r != nil {
@ -286,9 +284,9 @@ func rpcReader(c *client, pendingRequests map[interface{}]*CallState, pendingReq
} }
if crn.IsResponse() { if crn.IsResponse() {
err = responseHandle(c, crn.GetResponse(), pendingRequests, pendingRequestsLock) err = c.responseHandle(crn.GetResponse())
} else { } else {
err = notifyHandle(c, crn.GetNotify()) err = c.notifyHandle(crn.GetNotify())
} }
if nil != err { if nil != err {
return return
@ -297,13 +295,13 @@ func rpcReader(c *client, pendingRequests map[interface{}]*CallState, pendingReq
} }
func responseHandle(c *client, codecResponse protocol.ClientCodecResponse, pendingRequests map[interface{}]*CallState, pendingRequestsLock *sync.Mutex) error { func (c *client) responseHandle(codecResponse protocol.ClientCodecResponse) error {
pendingRequestsLock.Lock() c.pendingRequestsLock.Lock()
cs, ok := pendingRequests[codecResponse.ID()] cs, ok := c.pendingRequests[codecResponse.ID()]
if ok { if ok {
delete(pendingRequests, codecResponse.ID()) delete(c.pendingRequests, codecResponse.ID())
} }
pendingRequestsLock.Unlock() c.pendingRequestsLock.Unlock()
if !ok { if !ok {
return fmt.Errorf("Client: Unexpected ID=[%v] obtained from server", codecResponse.ID()) return fmt.Errorf("Client: Unexpected ID=[%v] obtained from server", codecResponse.ID())
@ -324,7 +322,7 @@ func responseHandle(c *client, codecResponse protocol.ClientCodecResponse, pendi
return nil return nil
} }
func notifyHandle(c *client, codecNotify protocol.ClientCodecNotify) error { func (c *client) notifyHandle(codecNotify protocol.ClientCodecNotify) error {
_, err := c.ch.GetRPCRegistry().Invoke(codecNotify) _, err := c.ch.GetRPCRegistry().Invoke(codecNotify)
return err return err

View File

@ -1,6 +1,7 @@
package client package client
import ( import (
"io"
"time" "time"
"git.loafle.net/commons_go/rpc" "git.loafle.net/commons_go/rpc"
@ -8,11 +9,9 @@ import (
) )
type ClientHandler interface { type ClientHandler interface {
OnStart() Connect() (io.ReadWriteCloser, error)
OnStop()
GetContentType() string
GetCodec() protocol.ClientCodec GetCodec() protocol.ClientCodec
GetRPCRegistry() rpc.Registry GetRPCRegistry() rpc.Registry
GetRequestTimeout() time.Duration GetRequestTimeout() time.Duration
GetPendingRequests() int GetPendingRequests() int

View File

@ -1,6 +1,8 @@
package client package client
import ( import (
"errors"
"io"
"sync" "sync"
"time" "time"
@ -9,7 +11,6 @@ import (
) )
type ClientHandlers struct { type ClientHandlers struct {
ContentType string
Codec protocol.ClientCodec Codec protocol.ClientCodec
// Maximum request time. // Maximum request time.
// Default value is DefaultRequestTimeout. // Default value is DefaultRequestTimeout.
@ -29,16 +30,8 @@ type ClientHandlers struct {
requestIDMtx sync.Mutex requestIDMtx sync.Mutex
} }
func (ch *ClientHandlers) OnStart() { func (ch *ClientHandlers) Connect() (io.ReadWriteCloser, error) {
// no op return nil, errors.New("RPC Client: ClientHandlers method[Connect] is not implement")
}
func (ch *ClientHandlers) OnStop() {
// no op
}
func (ch *ClientHandlers) GetContentType() string {
return ch.ContentType
} }
func (ch *ClientHandlers) GetCodec() protocol.ClientCodec { func (ch *ClientHandlers) GetCodec() protocol.ClientCodec {
@ -68,9 +61,6 @@ func (ch *ClientHandlers) GetRequestID() interface{} {
} }
func (ch *ClientHandlers) Validate() { func (ch *ClientHandlers) Validate() {
if "" == ch.ContentType {
panic("ContentType must be specified.")
}
if ch.RequestTimeout <= 0 { if ch.RequestTimeout <= 0 {
ch.RequestTimeout = DefaultRequestTimeout ch.RequestTimeout = DefaultRequestTimeout
} }