diff --git a/client/call.go b/client/call.go new file mode 100644 index 0000000..167bb05 --- /dev/null +++ b/client/call.go @@ -0,0 +1,68 @@ +package client + +import ( + "sync" + "sync/atomic" + "time" +) + +var callStatePool sync.Pool +var zeroTime time.Time + +type CallState struct { + ID interface{} + Method string + Args interface{} + Result interface{} + Error error + DoneChan chan *CallState + + canceled uint32 +} + +func (cs *CallState) done() { + select { + case cs.DoneChan <- cs: + // ok + default: + // We don't want to block here. It is the caller's responsibility to make + // sure the channel has enough buffer space. See comment in Go(). + // if debugLog { + // log.Println("rpc: discarding Call reply due to insufficient Done chan capacity") + // } + } +} + +// Cancel cancels async call. +// +// Canceled call isn't sent to the server unless it is already sent there. +// Canceled call may successfully complete if it has been already sent +// to the server before Cancel call. +// +// It is safe calling this function multiple times from concurrently +// running goroutines. +func (cs *CallState) Cancel() { + atomic.StoreUint32(&cs.canceled, 1) +} + +func (cs *CallState) IsCanceled() bool { + return atomic.LoadUint32(&cs.canceled) != 0 +} + +func retainCallState() *CallState { + v := callStatePool.Get() + if v == nil { + return &CallState{} + } + return v.(*CallState) +} + +func releaseCallState(cs *CallState) { + cs.Method = "" + cs.Args = nil + cs.Result = nil + cs.Error = nil + cs.DoneChan = nil + + callStatePool.Put(cs) +} diff --git a/client/client.go b/client/client.go new file mode 100644 index 0000000..4a4d77f --- /dev/null +++ b/client/client.go @@ -0,0 +1,339 @@ +package client + +import ( + "fmt" + "io" + "runtime" + "sync" + "sync/atomic" + "time" + + "git.loafle.net/commons_go/rpc/protocol" +) + +func New(ch ClientHandler) Client { + c := &client{ + ch: ch, + } + return c +} + +type Client interface { + Start(rwc io.ReadWriteCloser) + Stop() + Notify(method string, args interface{}) error + Call(method string, args interface{}, result interface{}) error + CallTimeout(method string, args interface{}, result interface{}, timeout time.Duration) error +} + +type client struct { + ch ClientHandler + + rwc io.ReadWriteCloser + + pendingRequestsCount uint32 + + requestQueueChan chan *CallState + + stopChan chan struct{} + stopWg sync.WaitGroup +} + +func (c *client) Start(rwc io.ReadWriteCloser) { + c.ch.Validate() + + if nil == rwc { + panic("RWC(io.ReadWriteCloser) must be specified.") + } + + if c.stopChan != nil { + panic("Client: the given client is already started. Call Client.Stop() before calling Client.Start() again!") + } + + c.rwc = rwc + c.stopChan = make(chan struct{}) + c.requestQueueChan = make(chan *CallState, c.ch.GetPendingRequests()) + + c.stopWg.Add(1) + go handleRPC(c) +} + +func (c *client) Stop() { + if c.stopChan == nil { + panic("Client: the client must be started before stopping it") + } + close(c.stopChan) + c.stopWg.Wait() + c.stopChan = nil +} + +func (c *client) Notify(method string, args interface{}) error { + _, err := c.send(method, args, nil, false, true) + return err +} + +func (c *client) Call(method string, args interface{}, result interface{}) error { + return c.CallTimeout(method, args, result, c.ch.GetRequestTimeout()) +} + +func (c *client) CallTimeout(method string, args interface{}, result interface{}, timeout time.Duration) (err error) { + var cs *CallState + if cs, err = c.send(method, args, result, true, true); nil != err { + return + } + + t := retainTimer(timeout) + + select { + case <-cs.DoneChan: + result, err = cs.Result, cs.Error + releaseCallState(cs) + case <-t.C: + cs.Cancel() + err = getClientTimeoutError(c, timeout) + } + + releaseTimer(t) + + return nil +} + +func (c *client) send(method string, args interface{}, result interface{}, hasResponse bool, usePool bool) (cs *CallState, err error) { + if !hasResponse { + usePool = true + } + + if usePool { + cs = retainCallState() + } else { + cs = &CallState{} + } + + cs.Method = method + cs.Args = args + + if hasResponse { + cs.ID = c.ch.GetRequestID() + cs.Result = result + cs.DoneChan = make(chan *CallState, 1) + } + + select { + case c.requestQueueChan <- cs: + return cs, nil + default: + // Try substituting the oldest async request by the new one + // on requests' queue overflow. + // This increases the chances for new request to succeed + // without timeout. + if !hasResponse { + // Immediately notify the caller not interested + // in the response on requests' queue overflow, since + // there are no other ways to notify it later. + releaseCallState(cs) + return nil, getClientOverflowError(c) + } + + select { + case rcs := <-c.requestQueueChan: + if rcs.DoneChan != nil { + rcs.Error = getClientOverflowError(c) + //close(rcs.DoneChan) + rcs.done() + } else { + releaseCallState(rcs) + } + default: + } + + select { + case c.requestQueueChan <- cs: + return cs, nil + default: + // Release m even if usePool = true, since m wasn't exposed + // to the caller yet. + releaseCallState(cs) + return nil, getClientOverflowError(c) + } + } +} + +func handleRPC(c *client) { + defer c.stopWg.Done() + + stopChan := make(chan struct{}) + + pendingRequests := make(map[interface{}]*CallState) + var pendingRequestsLock sync.Mutex + + writerDone := make(chan error, 1) + go rpcWriter(c, pendingRequests, &pendingRequestsLock, stopChan, writerDone) + + readerDone := make(chan error, 1) + go rpcReader(c, pendingRequests, &pendingRequestsLock, readerDone) + + var err error + + select { + case err = <-writerDone: + close(stopChan) + c.rwc.Close() + <-readerDone + case err = <-readerDone: + close(stopChan) + c.rwc.Close() + <-writerDone + case <-c.stopChan: + close(stopChan) + c.rwc.Close() + <-readerDone + <-writerDone + } + + if err != nil { + //c.LogError("%s", err) + err = &ClientError{ + Connection: true, + err: err, + } + } + +} + +func rpcWriter(c *client, pendingRequests map[interface{}]*CallState, pendingRequestsLock *sync.Mutex, stopChan <-chan struct{}, writerDone chan<- error) { + var err error + defer func() { + writerDone <- err + }() + + for { + var cs *CallState + + select { + case cs = <-c.requestQueueChan: + default: + // Give the last chance for ready goroutines filling c.requestsChan :) + runtime.Gosched() + + select { + case <-stopChan: + return + case cs = <-c.requestQueueChan: + } + } + + if cs.IsCanceled() { + if nil != cs.DoneChan { + // cs.Error = ErrCanceled + // close(m.done) + cs.done() + } else { + releaseCallState(cs) + } + continue + } + + if nil != cs.DoneChan { + pendingRequestsLock.Lock() + n := len(pendingRequests) + pendingRequests[cs.ID] = cs + pendingRequestsLock.Unlock() + atomic.AddUint32(&c.pendingRequestsCount, 1) + + if n > 10*c.ch.GetPendingRequests() { + err = fmt.Errorf("Client: The server didn't return %d responses yet. Closing server connection in order to prevent client resource leaks", n) + return + } + } + + if nil == cs.DoneChan { + releaseCallState(cs) + } + + if err = c.ch.GetCodec().Write(c.rwc, cs.Method, cs.Args, cs.ID); nil != err { + err = fmt.Errorf("Client: Cannot send request to wire: [%s]", err) + return + } + } +} + +func rpcReader(c *client, pendingRequests map[interface{}]*CallState, pendingRequestsLock *sync.Mutex, readerDone chan<- error) { + var err error + defer func() { + if r := recover(); r != nil { + if err == nil { + err = fmt.Errorf("Client: Panic when reading data from server: %v", r) + } + } + readerDone <- err + }() + + for { + crn, err := c.ch.GetCodec().NewResponseOrNotify(c.rwc) + if nil != err { + err = fmt.Errorf("Client: Cannot decode response or notify: [%s]", err) + return + } + + if crn.IsResponse() { + err = responseHandle(c, crn.GetResponse(), pendingRequests, pendingRequestsLock) + } else { + err = notifyHandle(c, crn.GetNotify()) + } + if nil != err { + return + } + } + +} + +func responseHandle(c *client, codecResponse protocol.ClientCodecResponse, pendingRequests map[interface{}]*CallState, pendingRequestsLock *sync.Mutex) error { + pendingRequestsLock.Lock() + cs, ok := pendingRequests[codecResponse.ID()] + if ok { + delete(pendingRequests, codecResponse.ID()) + } + pendingRequestsLock.Unlock() + + if !ok { + return fmt.Errorf("Client: Unexpected ID=[%v] obtained from server", codecResponse.ID()) + } + + atomic.AddUint32(&c.pendingRequestsCount, ^uint32(0)) + + cs.Result = codecResponse.Result() + if err := codecResponse.Error(); nil != err { + // cs.Error = &ClientError{ + // Server: true, + // err: fmt.Errorf("gorpc.Client: [%s]. Server error: [%s]", c.Addr, wr.Error), + // } + } + + cs.done() + + return nil +} + +func notifyHandle(c *client, codecNotify protocol.ClientCodecNotify) error { + _, err := c.ch.GetRPCRegistry().Invoke(codecNotify) + + return err +} + +func getClientTimeoutError(c *client, timeout time.Duration) error { + err := fmt.Errorf("Client: Cannot obtain response during timeout=%s", timeout) + //c.LogError("%s", err) + return &ClientError{ + Timeout: true, + err: err, + } +} + +func getClientOverflowError(c *client) error { + err := fmt.Errorf("Client: Requests' queue with size=%d is overflown. Try increasing Client.PendingRequests value", cap(c.requestQueueChan)) + //c.LogError("%s", err) + return &ClientError{ + Overflow: true, + err: err, + } +} diff --git a/client/client_handler.go b/client/client_handler.go new file mode 100644 index 0000000..49874b1 --- /dev/null +++ b/client/client_handler.go @@ -0,0 +1,25 @@ +package client + +import ( + "time" + + "git.loafle.net/commons_go/rpc" + "git.loafle.net/commons_go/rpc/protocol" +) + +type ClientHandler interface { + OnStart() + OnStop() + + GetContentType() string + GetCodec() protocol.ClientCodec + GetRPCRegistry() rpc.Registry + GetRequestTimeout() time.Duration + GetPendingRequests() int + + GetRequestID() interface{} + Send() + Validate() + + addWrite(cs *CallState) +} diff --git a/client/client_handlers.go b/client/client_handlers.go new file mode 100644 index 0000000..e6dea91 --- /dev/null +++ b/client/client_handlers.go @@ -0,0 +1,80 @@ +package client + +import ( + "sync" + "time" + + "git.loafle.net/commons_go/rpc" + "git.loafle.net/commons_go/rpc/protocol" +) + +type ClientHandlers struct { + ContentType string + Codec protocol.ClientCodec + // Maximum request time. + // Default value is DefaultRequestTimeout. + RequestTimeout time.Duration + // The maximum number of pending requests in the queue. + // + // The number of pending requsts should exceed the expected number + // of concurrent goroutines calling client's methods. + // Otherwise a lot of ClientError.Overflow errors may appear. + // + // Default is DefaultPendingMessages. + PendingRequests int + + RPCRegistry rpc.Registry + + requestID uint64 + requestIDMtx sync.Mutex +} + +func (ch *ClientHandlers) OnStart() { + // no op +} + +func (ch *ClientHandlers) OnStop() { + // no op +} + +func (ch *ClientHandlers) GetContentType() string { + return ch.ContentType +} + +func (ch *ClientHandlers) GetCodec() protocol.ClientCodec { + return ch.Codec +} + +func (ch *ClientHandlers) GetRPCRegistry() rpc.Registry { + return ch.RPCRegistry +} + +func (ch *ClientHandlers) GetRequestTimeout() time.Duration { + return ch.RequestTimeout +} + +func (ch *ClientHandlers) GetPendingRequests() int { + return ch.PendingRequests +} + +func (ch *ClientHandlers) GetRequestID() interface{} { + var id uint64 + ch.requestIDMtx.Lock() + ch.requestID++ + id = ch.requestID + ch.requestIDMtx.Unlock() + + return id +} + +func (ch *ClientHandlers) Validate() { + if "" == ch.ContentType { + panic("ContentType must be specified.") + } + if ch.RequestTimeout <= 0 { + ch.RequestTimeout = DefaultRequestTimeout + } + if ch.PendingRequests <= 0 { + ch.PendingRequests = DefaultPendingMessages + } +} diff --git a/client/constants.go b/client/constants.go new file mode 100644 index 0000000..a741697 --- /dev/null +++ b/client/constants.go @@ -0,0 +1,12 @@ +package client + +import "time" + +const ( + // DefaultRequestTimeout is the default timeout for client request. + DefaultRequestTimeout = 20 * time.Second + + // DefaultPendingMessages is the default number of pending messages + // handled by Client and Server. + DefaultPendingMessages = 32 * 1024 +) diff --git a/client/error.go b/client/error.go new file mode 100644 index 0000000..e10be8d --- /dev/null +++ b/client/error.go @@ -0,0 +1,26 @@ +package client + +// ClientError is an error Client methods can return. +type ClientError struct { + // Set if the error is timeout-related. + Timeout bool + + // Set if the error is connection-related. + Connection bool + + // Set if the error is server-related. + Server bool + + // Set if the error is related to internal resources' overflow. + // Increase PendingRequests if you see a lot of such errors. + Overflow bool + + // May be set if AsyncResult.Cancel is called. + Canceled bool + + err error +} + +func (e *ClientError) Error() string { + return e.err.Error() +} diff --git a/client/timer.go b/client/timer.go new file mode 100644 index 0000000..9ad91b1 --- /dev/null +++ b/client/timer.go @@ -0,0 +1,34 @@ +package client + +import ( + "sync" + "time" +) + +var timerPool sync.Pool + +func retainTimer(timeout time.Duration) *time.Timer { + tv := timerPool.Get() + if tv == nil { + return time.NewTimer(timeout) + } + + t := tv.(*time.Timer) + if t.Reset(timeout) { + panic("BUG: Active timer trapped into retainTimer()") + } + return t +} + +func releaseTimer(t *time.Timer) { + if !t.Stop() { + // Collect possibly added time from the channel + // if timer has been stopped and nobody collected its' value. + select { + case <-t.C: + default: + } + } + + timerPool.Put(t) +} diff --git a/encode/encode.go b/encode/encode.go index e198fd6..544c99d 100644 --- a/encode/encode.go +++ b/encode/encode.go @@ -2,10 +2,11 @@ package encode import "io" -// Encoder interface contains the encoder for http response. +// Encoder interface contains the encoder for response. // Eg. gzip, flate compressions. type Encoder interface { Encode(w io.Writer) io.Writer + Decode(r io.Reader) io.Reader } type encoder struct { @@ -15,4 +16,8 @@ func (_ *encoder) Encode(w io.Writer) io.Writer { return w } +func (_ *encoder) Decode(r io.Reader) io.Reader { + return r +} + var DefaultEncoder = &encoder{} diff --git a/encode/selector.go b/encode/selector.go index 2c749b9..30f6c56 100644 --- a/encode/selector.go +++ b/encode/selector.go @@ -7,13 +7,18 @@ import "io" // figure out client capabilities. // Eg. "Accept-Encoding" tells about supported compressions. type EncoderSelector interface { - Select(r io.Reader) Encoder + SelectByReader(r io.Reader) Encoder + SelectByWriter(w io.Writer) Encoder } type encoderSelector struct { } -func (_ *encoderSelector) Select(_ io.Reader) Encoder { +func (_ *encoderSelector) SelectByReader(_ io.Reader) Encoder { + return DefaultEncoder +} + +func (_ *encoderSelector) SelectByWriter(_ io.Writer) Encoder { return DefaultEncoder } diff --git a/protocol/client_codec.go b/protocol/client_codec.go new file mode 100644 index 0000000..ecee9c5 --- /dev/null +++ b/protocol/client_codec.go @@ -0,0 +1,32 @@ +package protocol + +import ( + "io" +) + +// ClientCodec creates a ClientCodecRequest to process each request. +type ClientCodec interface { + Write(w io.Writer, method string, args interface{}, id interface{}) error + NewResponseOrNotify(rc io.Reader) (ClientCodecResponseOrNotify, error) +} + +// ClientCodecResponseOrNotify encodes a response or notify using a specific +// serialization scheme. +type ClientCodecResponseOrNotify interface { + IsResponse() bool + IsNotify() bool + GetResponse() ClientCodecResponse + GetNotify() ClientCodecNotify + Complete() +} + +type ClientCodecResponse interface { + ID() interface{} + Result() interface{} + Error() error + Complete() +} + +type ClientCodecNotify interface { + RegistryCodec +} diff --git a/protocol/codec.go b/protocol/codec.go deleted file mode 100644 index 02f7642..0000000 --- a/protocol/codec.go +++ /dev/null @@ -1,29 +0,0 @@ -package protocol - -import ( - "io" -) - -// ---------------------------------------------------------------------------- -// Codec -// ---------------------------------------------------------------------------- - -// Codec creates a CodecRequest to process each request. -type Codec interface { - // NewRequest is constructor of new request object - // error io.ErrUnexpectedEOF or io.EOF - NewRequest(rc io.Reader) (CodecRequest, error) -} - -// CodecRequest decodes a request and encodes a response using a specific -// serialization scheme. -type CodecRequest interface { - // Reads the request and returns the RPC method name. - Method() (string, error) - // Reads the request filling the RPC method args. - ReadRequest(interface{}) error - // Writes the response using the RPC method reply. - WriteResponse(io.Writer, interface{}) error - // Writes an error produced by the server. - WriteError(w io.Writer, status int, err error) error -} diff --git a/protocol/json/client.go b/protocol/json/client.go index 48af5d7..bce8333 100644 --- a/protocol/json/client.go +++ b/protocol/json/client.go @@ -2,8 +2,12 @@ package json import ( "encoding/json" + "fmt" "io" - "math/rand" + "sync" + + "git.loafle.net/commons_go/rpc/encode" + "git.loafle.net/commons_go/rpc/protocol" ) // ---------------------------------------------------------------------------- @@ -23,58 +27,167 @@ type clientRequest struct { // The request id. This can be of any type. It is used to match the // response with the request that it is replying to. - ID uint64 `json:"id"` + ID interface{} `json:"id"` } // clientResponse represents a JSON-RPC response returned to a client. type clientResponse struct { Version string `json:"jsonrpc"` Result *json.RawMessage `json:"result"` - Error *json.RawMessage `json:"error"` + Error interface{} `json:"error"` + ID interface{} `json:"id"` } -// EncodeClientRequest encodes parameters for a JSON-RPC client request. -func EncodeClientRequest(method string, args interface{}) ([]byte, error) { - c := &clientRequest{ - Version: Version, - Method: method, - Params: args, - ID: uint64(rand.Int63()), - } - return json.Marshal(c) +// clientRequest represents a JSON-RPC request sent by a client. +type clientNotify struct { + // JSON-RPC protocol. + Version string `json:"jsonrpc"` + + // A String containing the name of the method to be invoked. + Method string `json:"method"` + + // Object to pass as request parameter to the method. + Params *json.RawMessage `json:"params"` } -// EncodeClientNotify encodes parameters for a JSON-RPC client notification. -func EncodeClientNotify(method string, args interface{}) ([]byte, error) { - c := &clientRequest{ - Version: Version, - Method: method, - Params: args, - } - return json.Marshal(c) +// ---------------------------------------------------------------------------- +// Codec +// ---------------------------------------------------------------------------- + +// NewCustomClientCodec returns a new JSON Codec based on passed encoder selector. +func NewCustomClientCodec(encSel encode.EncoderSelector) *ClientCodec { + return &ClientCodec{encSel: encSel} } -// DecodeClientResponse decodes the response body of a client request into -// the interface reply. -func DecodeClientResponse(r io.Reader, reply interface{}) error { - var c clientResponse - if err := json.NewDecoder(r).Decode(&c); err != nil { +// NewClientCodec returns a new JSON Codec. +func NewClientCodec() *ClientCodec { + return NewCustomClientCodec(encode.DefaultEncoderSelector) +} + +// ClientCodec creates a ClientCodecRequest to process each request. +type ClientCodec struct { + encSel encode.EncoderSelector +} + +func (cc *ClientCodec) Write(w io.Writer, method string, args interface{}, id interface{}) error { + req := retainClientRequest(method, args, id) + defer func() { + if nil != req { + releaseClientRequest(req) + } + }() + + encoder := json.NewEncoder(cc.encSel.SelectByWriter(w).Encode(w)) + if err := encoder.Encode(req); nil != err { return err } - if c.Error != nil { - jsonErr := &Error{} - if err := json.Unmarshal(*c.Error, jsonErr); err != nil { - return &Error{ - Code: E_SERVER, - Message: string(*c.Error), - } - } - return jsonErr - } - if c.Result == nil { - return ErrNullResult - } - - return json.Unmarshal(*c.Result, reply) + return nil +} + +// NewResponse returns a ClientCodecResponse. +func (cc *ClientCodec) NewResponseOrNotify(r io.Reader) (protocol.ClientCodecResponseOrNotify, error) { + return newClientCodecResponseOrNotify(r, cc.encSel.SelectByReader(r)) +} + +// newCodecRequest returns a new ServerCodecRequest. +func newClientCodecResponseOrNotify(r io.Reader, encoder encode.Encoder) (protocol.ClientCodecResponseOrNotify, error) { + // Decode the request body and check if RPC method is valid. + var raw json.RawMessage + dec := json.NewDecoder(r) + err := dec.Decode(&raw) + + if err == io.ErrUnexpectedEOF || err == io.EOF { + return nil, err + } + if err != nil { + err = &Error{ + Code: E_PARSE, + Message: err.Error(), + Data: raw, + } + } + + ccrn := retainClientCodecResponseOrNotify() + + if res, err := newClientCodecResponse(raw, dec); nil != err { + notify, err := newClientCodecNotify(raw, dec) + if nil != err { + releaseClientCodecResponseOrNotify(ccrn) + return nil, fmt.Errorf("Is not response or notification [%v]", raw) + } + ccrn.notify = notify + } else { + ccrn.response = res + } + + return ccrn, nil +} + +type ClientCodecResponseOrNotify struct { + notify protocol.ClientCodecNotify + response protocol.ClientCodecResponse +} + +func (ccrn *ClientCodecResponseOrNotify) IsResponse() bool { + return nil != ccrn.response +} + +func (ccrn *ClientCodecResponseOrNotify) IsNotify() bool { + return nil != ccrn.notify +} + +func (ccrn *ClientCodecResponseOrNotify) GetResponse() protocol.ClientCodecResponse { + return ccrn.response +} + +func (ccrn *ClientCodecResponseOrNotify) GetNotify() protocol.ClientCodecNotify { + return ccrn.notify +} + +func (ccrn *ClientCodecResponseOrNotify) Complete() { + if nil != ccrn.notify { + ccrn.notify.Complete() + } + if nil != ccrn.response { + ccrn.response.Complete() + } + releaseClientCodecResponseOrNotify(ccrn) +} + +var clientRequestPool sync.Pool + +func retainClientRequest(method string, params interface{}, id interface{}) *clientRequest { + v := clientRequestPool.Get() + if v == nil { + return &clientRequest{} + } + cr := v.(*clientRequest) + cr.Method = method + cr.Params = params + cr.ID = id + return cr +} + +func releaseClientRequest(cr *clientRequest) { + cr.Method = "" + cr.Params = nil + cr.ID = nil + + clientRequestPool.Put(cr) +} + +var clientCodecResponseOrNotifyPool sync.Pool + +func retainClientCodecResponseOrNotify() *ClientCodecResponseOrNotify { + v := clientCodecResponseOrNotifyPool.Get() + if v == nil { + return &ClientCodecResponseOrNotify{} + } + return v.(*ClientCodecResponseOrNotify) +} + +func releaseClientCodecResponseOrNotify(cr *ClientCodecResponseOrNotify) { + + clientCodecResponseOrNotifyPool.Put(cr) } diff --git a/protocol/json/client_notify.go b/protocol/json/client_notify.go new file mode 100644 index 0000000..158f773 --- /dev/null +++ b/protocol/json/client_notify.go @@ -0,0 +1,92 @@ +package json + +import ( + "encoding/json" + "fmt" + "sync" + + "git.loafle.net/commons_go/rpc/protocol" +) + +// ---------------------------------------------------------------------------- +// ClientCodecNotify +// ---------------------------------------------------------------------------- + +// newCodecRequest returns a new ClientCodecNotify. +func newClientCodecNotify(raw json.RawMessage, decoder *json.Decoder) (protocol.ClientCodecNotify, error) { + // Decode the request body and check if RPC method is valid. + ccn := retainClientCodecNotify() + err := decoder.Decode(&ccn.notify) + if err != nil { + releaseClientCodecNotify(ccn) + return nil, err + } + if "" == ccn.notify.Method { + releaseClientCodecNotify(ccn) + return nil, fmt.Errorf("This is not ClientNotify") + } + + if ccn.notify.Version != Version { + ccn.err = &Error{ + Code: E_INVALID_REQ, + Message: "jsonrpc must be " + Version, + Data: ccn.notify, + } + } + + return ccn, nil +} + +// ClientCodecNotify decodes and encodes a single notification. +type ClientCodecNotify struct { + notify clientNotify + err error +} + +func (ccn *ClientCodecNotify) Method() string { + return ccn.notify.Method +} + +func (ccn *ClientCodecNotify) ReadParams(args interface{}) error { + if ccn.err == nil && ccn.notify.Params != nil { + // Note: if scr.request.Params is nil it's not an error, it's an optional member. + // JSON params structured object. Unmarshal to the args object. + if err := json.Unmarshal(*ccn.notify.Params, args); err != nil { + // Clearly JSON params is not a structured object, + // fallback and attempt an unmarshal with JSON params as + // array value and RPC params is struct. Unmarshal into + // array containing the request struct. + params := [1]interface{}{args} + if err = json.Unmarshal(*ccn.notify.Params, ¶ms); err != nil { + ccn.err = &Error{ + Code: E_INVALID_REQ, + Message: err.Error(), + Data: ccn.notify.Params, + } + } + } + } + return ccn.err +} + +func (ccn *ClientCodecNotify) Complete() { + releaseClientCodecNotify(ccn) +} + +var clientCodecNotifyPool sync.Pool + +func retainClientCodecNotify() *ClientCodecNotify { + v := clientCodecNotifyPool.Get() + if v == nil { + return &ClientCodecNotify{} + } + return v.(*ClientCodecNotify) +} + +func releaseClientCodecNotify(ccn *ClientCodecNotify) { + ccn.notify.Version = "" + ccn.notify.Method = "" + ccn.notify.Params = nil + + clientCodecNotifyPool.Put(ccn) +} diff --git a/protocol/json/client_response.go b/protocol/json/client_response.go new file mode 100644 index 0000000..8c12bda --- /dev/null +++ b/protocol/json/client_response.go @@ -0,0 +1,79 @@ +package json + +import ( + "encoding/json" + "fmt" + "sync" + + "git.loafle.net/commons_go/rpc/protocol" +) + +// ---------------------------------------------------------------------------- +// ClientCodecResponse +// ---------------------------------------------------------------------------- + +// newClientCodecResponse returns a new ClientCodecResponse. +func newClientCodecResponse(raw json.RawMessage, decoder *json.Decoder) (protocol.ClientCodecResponse, error) { + // Decode the request body and check if RPC method is valid. + ccr := retainClientCodecResponse() + err := decoder.Decode(&ccr.response) + if err != nil { + releaseClientCodecResponse(ccr) + return nil, err + } + if nil == ccr.response.ID { + releaseClientCodecResponse(ccr) + return nil, fmt.Errorf("This is not Response") + } + + if ccr.response.Version != Version { + ccr.err = &Error{ + Code: E_INVALID_REQ, + Message: "jsonrpc must be " + Version, + Data: ccr.response, + } + } + + return ccr, nil +} + +// ClientCodecResponse decodes and encodes a single request. +type ClientCodecResponse struct { + response clientResponse + err error +} + +func (ccr *ClientCodecResponse) ID() interface{} { + return ccr.response.ID +} + +func (ccr *ClientCodecResponse) Result() interface{} { + return ccr.response.Result +} + +func (ccr *ClientCodecResponse) Error() error { + return ccr.response.Error.(error) +} + +func (ccr *ClientCodecResponse) Complete() { + releaseClientCodecResponse(ccr) +} + +var clientCodecResponsePool sync.Pool + +func retainClientCodecResponse() *ClientCodecResponse { + v := clientCodecResponsePool.Get() + if v == nil { + return &ClientCodecResponse{} + } + return v.(*ClientCodecResponse) +} + +func releaseClientCodecResponse(ccr *ClientCodecResponse) { + ccr.response.Version = "" + ccr.response.Result = nil + ccr.response.Error = nil + ccr.response.ID = nil + + clientCodecResponsePool.Put(ccr) +} diff --git a/protocol/json/constants.go b/protocol/json/constants.go new file mode 100644 index 0000000..b642d1d --- /dev/null +++ b/protocol/json/constants.go @@ -0,0 +1,5 @@ +package json + +const ( + Version = "2.0" +) diff --git a/protocol/json/server.go b/protocol/json/server.go index 6327a8f..566c1d6 100644 --- a/protocol/json/server.go +++ b/protocol/json/server.go @@ -3,13 +3,13 @@ package json import ( "encoding/json" "io" + "sync" "git.loafle.net/commons_go/rpc/encode" "git.loafle.net/commons_go/rpc/protocol" ) var null = json.RawMessage([]byte("null")) -var Version = "2.0" // ---------------------------------------------------------------------------- // Request and Response @@ -55,34 +55,54 @@ type serverResponse struct { // Codec // ---------------------------------------------------------------------------- -// NewcustomCodec returns a new JSON Codec based on passed encoder selector. -func NewCustomCodec(encSel encode.EncoderSelector) *Codec { - return &Codec{encSel: encSel} +// NewCustomServerCodec returns a new JSON Codec based on passed encoder selector. +func NewCustomServerCodec(encSel encode.EncoderSelector) *ServerCodec { + return &ServerCodec{encSel: encSel} } -// NewCodec returns a new JSON Codec. -func NewCodec() *Codec { - return NewCustomCodec(encode.DefaultEncoderSelector) +// NewServerCodec returns a new JSON Codec. +func NewServerCodec() *ServerCodec { + return NewCustomServerCodec(encode.DefaultEncoderSelector) } -// Codec creates a CodecRequest to process each request. -type Codec struct { - encSel encode.EncoderSelector +// ServerCodec creates a ServerCodecRequest to process each request. +type ServerCodec struct { + encSel encode.EncoderSelector + notifyMtx sync.Mutex + notify clientRequest } -// NewRequest returns a CodecRequest. -func (c *Codec) NewRequest(r io.Reader) (protocol.CodecRequest, error) { - return newCodecRequest(r, c.encSel.Select(r)) +// NewRequest returns a ServerCodecRequest. +func (sc *ServerCodec) NewRequest(r io.Reader) (protocol.ServerCodecRequest, error) { + return newServerCodecRequest(r, sc.encSel.SelectByReader(r)) +} + +// WriteNotify send a notification from server to client. +func (sc *ServerCodec) WriteNotify(w io.Writer, method string, args interface{}) error { + sc.notifyMtx.Lock() + + sc.notify.Version = Version + sc.notify.Method = method + sc.notify.Params = args + + encoder := json.NewEncoder(sc.encSel.SelectByWriter(w).Encode(w)) + err := encoder.Encode(&sc.notify) + sc.notifyMtx.Unlock() + // Not sure in which case will this happen. But seems harmless. + if err != nil { + return err + } + return nil } // ---------------------------------------------------------------------------- -// CodecRequest +// ServerCodecRequest // ---------------------------------------------------------------------------- -// newCodecRequest returns a new CodecRequest. -func newCodecRequest(r io.Reader, encoder encode.Encoder) (protocol.CodecRequest, error) { +// newCodecRequest returns a new ServerCodecRequest. +func newServerCodecRequest(r io.Reader, encoder encode.Encoder) (protocol.ServerCodecRequest, error) { // Decode the request body and check if RPC method is valid. - req := new(serverRequest) + req := retainServerRequest() err := json.NewDecoder(r).Decode(req) if err == io.ErrUnexpectedEOF || err == io.EOF { return nil, err @@ -102,24 +122,29 @@ func newCodecRequest(r io.Reader, encoder encode.Encoder) (protocol.CodecRequest } } - return &CodecRequest{request: req, err: err, encoder: encoder}, nil + return retainServerCodecRequest(req, err, encoder), nil } // CodecRequest decodes and encodes a single request. -type CodecRequest struct { +type ServerCodecRequest struct { request *serverRequest err error encoder encode.Encoder } +// Complete is callback function that end of request. +func (scr *ServerCodecRequest) Complete() { + if nil != scr.request { + releaseServerRequest(scr.request) + } + releaseServerCodecRequest(scr) +} + // Method returns the RPC method for the current request. // // The method uses a dotted notation as in "Service.Method". -func (c *CodecRequest) Method() (string, error) { - if c.err == nil { - return c.request.Method, nil - } - return "", c.err +func (scr *ServerCodecRequest) Method() string { + return scr.request.Method } // ReadRequest fills the request object for the RPC method. @@ -135,40 +160,36 @@ func (c *CodecRequest) Method() (string, error) { // absence of expected names MAY result in an error being // generated. The names MUST match exactly, including // case, to the method's expected parameters. -func (c *CodecRequest) ReadRequest(args interface{}) error { - if c.err == nil && c.request.Params != nil { - // Note: if c.request.Params is nil it's not an error, it's an optional member. +func (scr *ServerCodecRequest) ReadParams(args interface{}) error { + if scr.err == nil && scr.request.Params != nil { + // Note: if scr.request.Params is nil it's not an error, it's an optional member. // JSON params structured object. Unmarshal to the args object. - if err := json.Unmarshal(*c.request.Params, args); err != nil { + if err := json.Unmarshal(*scr.request.Params, args); err != nil { // Clearly JSON params is not a structured object, // fallback and attempt an unmarshal with JSON params as // array value and RPC params is struct. Unmarshal into // array containing the request struct. params := [1]interface{}{args} - if err = json.Unmarshal(*c.request.Params, ¶ms); err != nil { - c.err = &Error{ + if err = json.Unmarshal(*scr.request.Params, ¶ms); err != nil { + scr.err = &Error{ Code: E_INVALID_REQ, Message: err.Error(), - Data: c.request.Params, + Data: scr.request.Params, } } } } - return c.err + return scr.err } // WriteResponse encodes the response and writes it to the ResponseWriter. -func (c *CodecRequest) WriteResponse(w io.Writer, reply interface{}) error { - res := &serverResponse{ - Version: Version, - Result: reply, - ID: c.request.ID, - } - return c.writeServerResponse(w, res) +func (scr *ServerCodecRequest) WriteResponse(w io.Writer, reply interface{}) error { + res := retainServerResponse(Version, reply, nil, scr.request.ID) + return scr.writeServerResponse(w, res) } // WriteError encodes the response and writes it to the ResponseWriter. -func (c *CodecRequest) WriteError(w io.Writer, status int, err error) error { +func (scr *ServerCodecRequest) WriteError(w io.Writer, status int, err error) error { jsonErr, ok := err.(*Error) if !ok { jsonErr = &Error{ @@ -176,18 +197,19 @@ func (c *CodecRequest) WriteError(w io.Writer, status int, err error) error { Message: err.Error(), } } - res := &serverResponse{ - Version: Version, - Error: jsonErr, - ID: c.request.ID, - } - return c.writeServerResponse(w, res) + res := retainServerResponse(Version, nil, jsonErr, scr.request.ID) + return scr.writeServerResponse(w, res) } -func (c *CodecRequest) writeServerResponse(w io.Writer, res *serverResponse) error { +func (scr *ServerCodecRequest) writeServerResponse(w io.Writer, res *serverResponse) error { + defer func() { + if nil != res { + releaseServerResponse(res) + } + }() // ID is null for notifications and they don't have a response. - if c.request.ID != nil { - encoder := json.NewEncoder(c.encoder.Encode(w)) + if scr.request.ID != nil { + encoder := json.NewEncoder(scr.encoder.Encode(w)) err := encoder.Encode(res) // Not sure in which case will this happen. But seems harmless. @@ -200,3 +222,70 @@ func (c *CodecRequest) writeServerResponse(w io.Writer, res *serverResponse) err type EmptyResponse struct { } + +var serverCodecRequestPool sync.Pool + +func retainServerCodecRequest(request *serverRequest, err error, encoder encode.Encoder) *ServerCodecRequest { + v := serverCodecRequestPool.Get() + if v == nil { + return &ServerCodecRequest{} + } + + scr := v.(*ServerCodecRequest) + scr.request = request + scr.err = err + scr.encoder = encoder + + return scr +} + +func releaseServerCodecRequest(scr *ServerCodecRequest) { + scr.request = nil + scr.err = nil + scr.encoder = nil + + serverCodecRequestPool.Put(scr) +} + +var serverRequestPool sync.Pool + +func retainServerRequest() *serverRequest { + v := serverRequestPool.Get() + if v == nil { + return &serverRequest{} + } + return v.(*serverRequest) +} + +func releaseServerRequest(sr *serverRequest) { + sr.Method = "" + sr.Params = nil + sr.ID = nil + + serverRequestPool.Put(sr) +} + +var serverResponsePool sync.Pool + +func retainServerResponse(version string, result interface{}, err *Error, id *json.RawMessage) *serverResponse { + v := serverResponsePool.Get() + if v == nil { + return &serverResponse{} + } + sr := v.(*serverResponse) + sr.Version = version + sr.Result = result + sr.Error = err + sr.ID = id + + return sr +} + +func releaseServerResponse(sr *serverResponse) { + sr.Version = "" + sr.Result = nil + sr.Error = nil + sr.ID = nil + + serverResponsePool.Put(sr) +} diff --git a/protocol/registry_codec.go b/protocol/registry_codec.go new file mode 100644 index 0000000..4446be5 --- /dev/null +++ b/protocol/registry_codec.go @@ -0,0 +1,13 @@ +package protocol + +// ---------------------------------------------------------------------------- +// Codec +// ---------------------------------------------------------------------------- +// RegistryCodec creates a RegistryCodecRequest to process each request. +type RegistryCodec interface { + // Reads the request and returns the RPC method name. + Method() string + // Reads the request filling the RPC method args. + ReadParams(interface{}) error + Complete() +} diff --git a/protocol/server_codec.go b/protocol/server_codec.go new file mode 100644 index 0000000..bffb9c6 --- /dev/null +++ b/protocol/server_codec.go @@ -0,0 +1,18 @@ +package protocol + +import ( + "io" +) + +// ServerCodec creates a ServerCodecRequest to process each request. +type ServerCodec interface { + NewRequest(r io.Reader) (ServerCodecRequest, error) +} + +// ServerCodecRequest decodes a request and encodes a response using a specific +// serialization scheme. +type ServerCodecRequest interface { + RegistryCodec + WriteResponse(w io.Writer, reply interface{}) error + WriteError(w io.Writer, status int, err error) error +} diff --git a/registry.go b/registry.go index 5135e30..f563b7e 100644 --- a/registry.go +++ b/registry.go @@ -1,10 +1,7 @@ package rpc import ( - "fmt" - "io" "reflect" - "strings" "git.loafle.net/commons_go/rpc/protocol" ) @@ -19,38 +16,24 @@ Network connection */ -type WriteHookFunc func(io.Writer) - // NewRPCRegistry returns a new RPC registry. func NewRegistry() Registry { return &rpcRegistry{ - codecs: make(map[string]protocol.Codec), services: new(serviceMap), } } type Registry interface { - RegisterCodec(codec protocol.Codec, contentType string) RegisterService(receiver interface{}, name string) error HasMethod(method string) bool - Invoke(contentType string, reader io.Reader, writer io.Writer, beforeWrite WriteHookFunc, afterWrite WriteHookFunc) error + Invoke(codec protocol.RegistryCodec) (result interface{}, err error) } // RPCRegistry serves registered RPC services using registered codecs. type rpcRegistry struct { - codecs map[string]protocol.Codec services *serviceMap } -// RegisterCodec adds a new codec to the server. -// -// Codecs are defined to process a given serialization scheme, e.g., JSON or -// XML. A codec is chosen based on the "Content-Type" header from the request, -// excluding the charset definition. -func (rr *rpcRegistry) RegisterCodec(codec protocol.Codec, contentType string) { - rr.codecs[strings.ToLower(contentType)] = codec -} - // RegisterService adds a new service to the server. // // The name parameter is optional: if empty it will be inferred from @@ -86,36 +69,15 @@ func (rr *rpcRegistry) HasMethod(method string) bool { // Codecs are defined to process a given serialization scheme, e.g., JSON or // XML. A codec is chosen based on the "Content-Type" header from the request, // excluding the charset definition. -func (rr *rpcRegistry) Invoke(contentType string, r io.Reader, w io.Writer, beforeWrite WriteHookFunc, afterWrite WriteHookFunc) error { - var codec protocol.Codec - if contentType == "" && len(rr.codecs) == 1 { - // If Content-Type is not set and only one codec has been registered, - // then default to that codec. - for _, c := range rr.codecs { - codec = c - } - } else if codec = rr.codecs[strings.ToLower(contentType)]; codec == nil { - return fmt.Errorf("Unrecognized Content-Type: %s", contentType) - } - - // Create a new codec request. - codecReq, errNew := codec.NewRequest(r) - if nil != errNew { - return errNew - } - // Get service method to be called. - method, errMethod := codecReq.Method() - if errMethod != nil { - return write(codecReq, w, beforeWrite, afterWrite, nil, errMethod) - } - serviceSpec, methodSpec, errGet := rr.services.get(method) +func (rr *rpcRegistry) Invoke(codec protocol.RegistryCodec) (result interface{}, err error) { + serviceSpec, methodSpec, errGet := rr.services.get(codec.Method()) if errGet != nil { - return write(codecReq, w, beforeWrite, afterWrite, nil, errGet) + return nil, errGet } // Decode the args. args := reflect.New(methodSpec.argsType) - if errRead := codecReq.ReadRequest(args.Interface()); errRead != nil { - return write(codecReq, w, beforeWrite, afterWrite, nil, errRead) + if errRead := codec.ReadParams(args.Interface()); errRead != nil { + return nil, errRead } // Call the service method. reply := reflect.New(methodSpec.replyType) @@ -133,32 +95,8 @@ func (rr *rpcRegistry) Invoke(contentType string, r io.Reader, w io.Writer, befo } if errResult != nil { - return write(codecReq, w, beforeWrite, afterWrite, nil, errResult) + return nil, errResult } - return write(codecReq, w, beforeWrite, afterWrite, reply.Interface(), nil) -} - -func write(codecReq protocol.CodecRequest, w io.Writer, beforeWrite WriteHookFunc, afterWrite WriteHookFunc, result interface{}, err error) error { - if nil != beforeWrite { - beforeWrite(w) - } - - var wErr error - - if err == nil { - wErr = codecReq.WriteResponse(w, result) - } else { - wErr = codecReq.WriteError(w, 400, err) - } - - if nil != wErr { - return wErr - } - - if nil != afterWrite { - afterWrite(w) - } - - return nil + return reply.Interface(), nil } diff --git a/server/server.go b/server/server.go new file mode 100644 index 0000000..5615d8d --- /dev/null +++ b/server/server.go @@ -0,0 +1,89 @@ +package server + +import ( + "io" + "sync" + + "git.loafle.net/commons_go/rpc/protocol" +) + +func New(sh ServerHandler) Server { + s := &server{ + sh: sh, + } + return s +} + +type Server interface { + Start() + Stop() + Handle(r io.Reader, w io.Writer) error +} + +type server struct { + sh ServerHandler + + stopChan chan struct{} + stopWg sync.WaitGroup +} + +func (s *server) Start() { + if nil == s.sh { + panic("Server: server handler must be specified.") + } + s.sh.Validate() + + if s.stopChan != nil { + panic("Server: server is already running. Stop it before starting it again") + } + s.stopChan = make(chan struct{}) + +} + +func (s *server) Stop() { + if s.stopChan == nil { + panic("Server: server must be started before stopping it") + } + close(s.stopChan) + s.stopWg.Wait() + s.stopChan = nil +} + +func (s *server) Handle(r io.Reader, w io.Writer) error { + contentType := s.sh.GetContentType(r) + codec, err := s.sh.getCodec(contentType) + if nil != err { + return err + } + + var codecReq protocol.ServerCodecRequest + + defer func() { + if nil != codecReq { + codecReq.Complete() + } + }() + + s.sh.OnPreRead(r) + // Create a new codec request. + codecReq, errNew := codec.NewRequest(r) + if nil != errNew { + return errNew + } + s.sh.OnPostRead(r) + + result, err := s.sh.invoke(codecReq) + + if nil != err { + s.sh.OnPreWriteError(w, err) + codecReq.WriteError(w, 400, err) + s.sh.OnPostWriteError(w, err) + return nil + } + + s.sh.OnPreWriteResult(w, result) + codecReq.WriteResponse(w, result) + s.sh.OnPostWriteResult(w, result) + + return nil +} diff --git a/server/server_handler.go b/server/server_handler.go new file mode 100644 index 0000000..53f848c --- /dev/null +++ b/server/server_handler.go @@ -0,0 +1,27 @@ +package server + +import ( + "io" + + "git.loafle.net/commons_go/rpc/protocol" +) + +type ServerHandler interface { + RegisterCodec(codec protocol.ServerCodec, contentType string) + + GetContentType(r io.Reader) string + + OnPreRead(r io.Reader) + OnPostRead(r io.Reader) + + OnPreWriteResult(w io.Writer, result interface{}) + OnPostWriteResult(w io.Writer, result interface{}) + + OnPreWriteError(w io.Writer, err error) + OnPostWriteError(w io.Writer, err error) + + getCodec(contentType string) (protocol.ServerCodec, error) + invoke(codec protocol.RegistryCodec) (result interface{}, err error) + + Validate() +} diff --git a/server/server_handlers.go b/server/server_handlers.go new file mode 100644 index 0000000..a9441c7 --- /dev/null +++ b/server/server_handlers.go @@ -0,0 +1,81 @@ +package server + +import ( + "fmt" + "io" + "strings" + + "git.loafle.net/commons_go/rpc" + "git.loafle.net/commons_go/rpc/protocol" +) + +type ServerHandlers struct { + Registry rpc.Registry + + codecs map[string]protocol.ServerCodec +} + +// RegisterCodec adds a new codec to the server. +// +// Codecs are defined to process a given serialization scheme, e.g., JSON or +// XML. A codec is chosen based on the "Content-Type" header from the request, +// excluding the charset definition. +func (sh *ServerHandlers) RegisterCodec(codec protocol.ServerCodec, contentType string) { + if nil == sh.codecs { + sh.codecs = make(map[string]protocol.ServerCodec) + } + sh.codecs[strings.ToLower(contentType)] = codec +} + +func (sh *ServerHandlers) GetContentType(r io.Reader) string { + return "" +} + +func (sh *ServerHandlers) OnPreRead(r io.Reader) { + // no op +} + +func (sh *ServerHandlers) OnPostRead(r io.Reader) { + // no op +} + +func (sh *ServerHandlers) OnPreWriteResult(w io.Writer, result interface{}) { + // no op +} + +func (sh *ServerHandlers) OnPostWriteResult(w io.Writer, result interface{}) { + // no op +} + +func (sh *ServerHandlers) OnPreWriteError(w io.Writer, err error) { + // no op +} + +func (sh *ServerHandlers) OnPostWriteError(w io.Writer, err error) { + // no op +} + +func (sh *ServerHandlers) Validate() { + if nil == sh.Registry { + panic("Registry(RPCRegistry) must be specified.") + } +} + +func (sh *ServerHandlers) getCodec(contentType string) (protocol.ServerCodec, error) { + var codec protocol.ServerCodec + if contentType == "" && len(sh.codecs) == 1 { + // If Content-Type is not set and only one codec has been registered, + // then default to that codec. + for _, c := range sh.codecs { + codec = c + } + } else if codec = sh.codecs[strings.ToLower(contentType)]; codec == nil { + return nil, fmt.Errorf("Unrecognized Content-Type: %s", contentType) + } + + return codec, nil +} + +func (sh *ServerHandlers) invoke(codec protocol.RegistryCodec) (result interface{}, err error) { + return sh.Registry.Invoke(codec) +} diff --git a/service_map.go b/service_map.go index 0e76035..33cf571 100644 --- a/service_map.go +++ b/service_map.go @@ -31,6 +31,7 @@ type serviceMethod struct { replyType reflect.Type // type of the response argument } + // ---------------------------------------------------------------------------- // serviceMap // ----------------------------------------------------------------------------