From 0c7808f8e3aa0b42b82a602dd5c0d86c2f42fa1a Mon Sep 17 00:00:00 2001 From: crusader Date: Thu, 26 Oct 2017 16:21:35 +0900 Subject: [PATCH] ing --- .gitignore | 70 ++++++++++++++++++++++++ adapter/fasthttp/adapter.go | 51 ++++++++++++++++++ adapter/http/adapter.go | 58 ++++++++++++++++++++ adapter/http/gzip_encode.go | 87 ++++++++++++++++++++++++++++++ adapter/http/handler.go | 83 ---------------------------- encode/encode.go | 18 +++++++ encode/selector.go | 20 +++++++ encoder_selector.go | 38 ------------- glide.yaml | 4 ++ {codec => protocol}/codec.go | 8 +-- protocol/json/client.go | 80 +++++++++++++++++++++++++++ {codec => protocol}/json/error.go | 0 {codec => protocol}/json/server.go | 44 ++++++++------- registry.go | 78 +++++++++++++++++---------- 14 files changed, 465 insertions(+), 174 deletions(-) create mode 100644 .gitignore create mode 100644 adapter/fasthttp/adapter.go create mode 100644 adapter/http/adapter.go create mode 100644 adapter/http/gzip_encode.go delete mode 100644 adapter/http/handler.go create mode 100644 encode/encode.go create mode 100644 encode/selector.go delete mode 100644 encoder_selector.go create mode 100644 glide.yaml rename {codec => protocol}/codec.go (81%) create mode 100644 protocol/json/client.go rename {codec => protocol}/json/error.go (100%) rename {codec => protocol}/json/server.go (83%) diff --git a/.gitignore b/.gitignore new file mode 100644 index 0000000..a5cf485 --- /dev/null +++ b/.gitignore @@ -0,0 +1,70 @@ +# Created by .ignore support plugin (hsz.mobi) +### JetBrains template +# Covers JetBrains IDEs: IntelliJ, RubyMine, PhpStorm, AppCode, PyCharm, CLion, Android Studio and Webstorm +# Reference: https://intellij-support.jetbrains.com/hc/en-us/articles/206544839 + +# User-specific stuff: +.idea/**/workspace.xml +.idea/**/tasks.xml +.idea/dictionaries + +# Sensitive or high-churn files: +.idea/**/dataSources/ +.idea/**/dataSources.ids +.idea/**/dataSources.xml +.idea/**/dataSources.local.xml +.idea/**/sqlDataSources.xml +.idea/**/dynamic.xml +.idea/**/uiDesigner.xml + +# Gradle: +.idea/**/gradle.xml +.idea/**/libraries + +# Mongo Explorer plugin: +.idea/**/mongoSettings.xml + +## File-based project format: +*.iws + +## Plugin-specific files: + +# IntelliJ +/out/ + +# mpeltonen/sbt-idea plugin +.idea_modules/ + +# JIRA plugin +atlassian-ide-plugin.xml + +# Crashlytics plugin (for Android Studio and IntelliJ) +com_crashlytics_export_strings.xml +crashlytics.properties +crashlytics-build.properties +fabric.properties +### Go template +# Binaries for programs and plugins +*.exe +*.dll +*.so +*.dylib + +# Test binary, build with `go test -c` +*.test + +# Output of the go coverage tool, specifically when used with LiteIDE +*.out + +# Project-local glide cache, RE: https://github.com/Masterminds/glide/issues/736 +.glide/ +.idea/ +*.iml + +vendor/ +glide.lock +.DS_Store +dist/ +debug + + diff --git a/adapter/fasthttp/adapter.go b/adapter/fasthttp/adapter.go new file mode 100644 index 0000000..842d513 --- /dev/null +++ b/adapter/fasthttp/adapter.go @@ -0,0 +1,51 @@ +package fasthttp + +import ( + "fmt" + "io" + "strings" + + "git.loafle.net/commons_go/rpc" + "github.com/valyala/fasthttp" +) + +type FastHTTPAdapter struct { + registry rpc.Registry +} + +// FastHTTPHandler +func (a *FastHTTPAdapter) FastHTTPHandler(ctx *fasthttp.RequestCtx) { + if !ctx.IsPost() { + WriteError(ctx, 405, "rpc: POST method required, received "+r.Method) + return + } + + contentType := string(ctx.Request.Header.ContentType()) + idx := strings.Index(contentType, ";") + if idx != -1 { + contentType = contentType[:idx] + } + + err := a.registry.Invoke(contentType, ctx.PostBody(), ctx, beforeWrite, afterWrite) + + if nil != err { + WriteError(w, 400, err.Error()) + } + +} + +func beforeWrite(w io.Writer) { + ctx := w.(*fasthttp.RequestCtx) + ctx.Response.Header.Set("x-content-type-options", "nosniff") + ctx.SetContentType("application/json; charset=utf-8") +} + +func afterWrite(w io.Writer) { + +} + +func writeError(ctx *fasthttp.RequestCtx, status int, msg string) { + ctx.SetStatusCode(status) + ctx.SetContentType("text/plain; charset=utf-8") + fmt.Fprint(ctx, msg) +} diff --git a/adapter/http/adapter.go b/adapter/http/adapter.go new file mode 100644 index 0000000..c831cc6 --- /dev/null +++ b/adapter/http/adapter.go @@ -0,0 +1,58 @@ +package http + +import ( + "fmt" + "io" + "net/http" + "strings" + + "git.loafle.net/commons_go/rpc" +) + +type HTTPAdapter struct { + registry rpc.Registry +} + +func NewAdapter(registry rpc.Registry) *HTTPAdapter { + return &HTTPAdapter{ + registry: registry, + } +} + +// ServeHTTP +func (a *HTTPAdapter) ServeHTTP(w http.ResponseWriter, r *http.Request) { + if r.Method != "POST" { + writeError(w, 405, "rpc: POST method required, received "+r.Method) + return + } + + contentType := r.Header.Get("Content-Type") + idx := strings.Index(contentType, ";") + if idx != -1 { + contentType = contentType[:idx] + } + + err := a.registry.Invoke(contentType, r.Body, w, beforeWrite, afterWrite) + r.Body.Close() + + if nil != err { + writeError(w, 400, err.Error()) + } + +} + +func beforeWrite(w io.Writer) { + writer := w.(http.ResponseWriter) + writer.Header().Set("x-content-type-options", "nosniff") + writer.Header().Set("Content-Type", "application/json; charset=utf-8") +} + +func afterWrite(w io.Writer) { + +} + +func writeError(w http.ResponseWriter, status int, msg string) { + w.WriteHeader(status) + w.Header().Set("Content-Type", "text/plain; charset=utf-8") + fmt.Fprint(w, msg) +} diff --git a/adapter/http/gzip_encode.go b/adapter/http/gzip_encode.go new file mode 100644 index 0000000..ffc779c --- /dev/null +++ b/adapter/http/gzip_encode.go @@ -0,0 +1,87 @@ +package http + +import ( + "compress/flate" + "compress/gzip" + "io" + "net/http" + "strings" + "unicode" + + "git.loafle.net/commons_go/rpc/encode" +) + +// gzipWriter writes and closes the gzip writer. +type gzipWriter struct { + w *gzip.Writer +} + +func (gw *gzipWriter) Write(p []byte) (n int, err error) { + defer gw.w.Close() + return gw.w.Write(p) +} + +// gzipEncoder implements the gzip compressed http encoder. +type gzipEncoder struct { +} + +func (enc *gzipEncoder) Encode(w http.ResponseWriter) io.Writer { + w.Header().Set("Content-Encoding", "gzip") + return &gzipWriter{gzip.NewWriter(w)} +} + +// flateWriter writes and closes the flate writer. +type flateWriter struct { + w *flate.Writer +} + +func (fw *flateWriter) Write(p []byte) (n int, err error) { + defer fw.w.Close() + return fw.w.Write(p) +} + +// flateEncoder implements the flate compressed http encoder. +type flateEncoder struct { +} + +func (enc *flateEncoder) Encode(w http.ResponseWriter) io.Writer { + fw, err := flate.NewWriter(w, flate.DefaultCompression) + if err != nil { + return w + } + w.Header().Set("Content-Encoding", "deflate") + return &flateWriter{fw} +} + +// CompressionSelector generates the compressed http encoder. +type CompressionSelector struct { +} + +// acceptedEnc returns the first compression type in "Accept-Encoding" header +// field of the request. +func acceptedEnc(req *http.Request) string { + encHeader := req.Header.Get("Accept-Encoding") + if encHeader == "" { + return "" + } + encTypes := strings.FieldsFunc(encHeader, func(r rune) bool { + return unicode.IsSpace(r) || r == ',' + }) + for _, enc := range encTypes { + if enc == "gzip" || enc == "deflate" { + return enc + } + } + return "" +} + +// Select method selects the correct compression encoder based on http HEADER. +func (_ *CompressionSelector) Select(r *http.Request) Encoder { + switch acceptedEnc(r) { + case "gzip": + return &gzipEncoder{} + case "flate": + return &flateEncoder{} + } + return encode.DefaultEncoder +} diff --git a/adapter/http/handler.go b/adapter/http/handler.go deleted file mode 100644 index 0815884..0000000 --- a/adapter/http/handler.go +++ /dev/null @@ -1,83 +0,0 @@ -package http - -import ( - "fmt" - "net/http" - "reflect" - "strings" -) - -type HTTPAdapter struct { -} - -// ServeHTTP -func (a *HTTPAdapter) ServeHTTP(w http.ResponseWriter, r *http.Request) { - if r.Method != "POST" { - WriteError(w, 405, "rpc: POST method required, received "+r.Method) - return - } - contentType := r.Header.Get("Content-Type") - idx := strings.Index(contentType, ";") - if idx != -1 { - contentType = contentType[:idx] - } - var codec Codec - if contentType == "" && len(s.codecs) == 1 { - // If Content-Type is not set and only one codec has been registered, - // then default to that codec. - for _, c := range s.codecs { - codec = c - } - } else if codec = s.codecs[strings.ToLower(contentType)]; codec == nil { - WriteError(w, 415, "rpc: unrecognized Content-Type: "+contentType) - return - } - // Create a new codec request. - codecReq := codec.NewRequest(r) - // Get service method to be called. - method, errMethod := codecReq.Method() - if errMethod != nil { - codecReq.WriteError(w, 400, errMethod) - return - } - serviceSpec, methodSpec, errGet := s.services.get(method) - if errGet != nil { - codecReq.WriteError(w, 400, errGet) - return - } - // Decode the args. - args := reflect.New(methodSpec.argsType) - if errRead := codecReq.ReadRequest(args.Interface()); errRead != nil { - codecReq.WriteError(w, 400, errRead) - return - } - // Call the service method. - reply := reflect.New(methodSpec.replyType) - errValue := methodSpec.method.Func.Call([]reflect.Value{ - serviceSpec.rcvr, - reflect.ValueOf(r), - args, - reply, - }) - // Cast the result to error if needed. - var errResult error - errInter := errValue[0].Interface() - if errInter != nil { - errResult = errInter.(error) - } - // Prevents Internet Explorer from MIME-sniffing a response away - // from the declared content-type - w.Header().Set("x-content-type-options", "nosniff") - // Encode the response. - if errResult == nil { - codecReq.WriteResponse(w, reply.Interface()) - } else { - codecReq.WriteError(w, 400, errResult) - } -} - -func WriteError(w http.ResponseWriter, status int, msg string) { - w.WriteHeader(status) - w.Header().Set("Content-Type", "text/plain; charset=utf-8") - fmt.Fprint(w, msg) -} diff --git a/encode/encode.go b/encode/encode.go new file mode 100644 index 0000000..e198fd6 --- /dev/null +++ b/encode/encode.go @@ -0,0 +1,18 @@ +package encode + +import "io" + +// Encoder interface contains the encoder for http response. +// Eg. gzip, flate compressions. +type Encoder interface { + Encode(w io.Writer) io.Writer +} + +type encoder struct { +} + +func (_ *encoder) Encode(w io.Writer) io.Writer { + return w +} + +var DefaultEncoder = &encoder{} diff --git a/encode/selector.go b/encode/selector.go new file mode 100644 index 0000000..2c749b9 --- /dev/null +++ b/encode/selector.go @@ -0,0 +1,20 @@ +package encode + +import "io" + +// EncoderSelector interface provides a way to select encoder using the http +// request. Typically people can use this to check HEADER of the request and +// figure out client capabilities. +// Eg. "Accept-Encoding" tells about supported compressions. +type EncoderSelector interface { + Select(r io.Reader) Encoder +} + +type encoderSelector struct { +} + +func (_ *encoderSelector) Select(_ io.Reader) Encoder { + return DefaultEncoder +} + +var DefaultEncoderSelector = &encoderSelector{} diff --git a/encoder_selector.go b/encoder_selector.go deleted file mode 100644 index ceea276..0000000 --- a/encoder_selector.go +++ /dev/null @@ -1,38 +0,0 @@ -package rpc - -import ( - "io" - "net/http" -) - -// Encoder interface contains the encoder for http response. -// Eg. gzip, flate compressions. -type Encoder interface { - Encode(w http.ResponseWriter) io.Writer -} - -type encoder struct { -} - -func (_ *encoder) Encode(w http.ResponseWriter) io.Writer { - return w -} - -var DefaultEncoder = &encoder{} - -// EncoderSelector interface provides a way to select encoder using the http -// request. Typically people can use this to check HEADER of the request and -// figure out client capabilities. -// Eg. "Accept-Encoding" tells about supported compressions. -type EncoderSelector interface { - Select(r *http.Request) Encoder -} - -type encoderSelector struct { -} - -func (_ *encoderSelector) Select(_ *http.Request) Encoder { - return DefaultEncoder -} - -var DefaultEncoderSelector = &encoderSelector{} diff --git a/glide.yaml b/glide.yaml new file mode 100644 index 0000000..d5c64ce --- /dev/null +++ b/glide.yaml @@ -0,0 +1,4 @@ +package: git.loafle.net/commons_go/rpc +import: +- package: github.com/valyala/fasthttp + version: v20160617 diff --git a/codec/codec.go b/protocol/codec.go similarity index 81% rename from codec/codec.go rename to protocol/codec.go index 29e00f2..b47133c 100644 --- a/codec/codec.go +++ b/protocol/codec.go @@ -1,4 +1,4 @@ -package codec +package protocol import ( "io" @@ -10,7 +10,7 @@ import ( // Codec creates a CodecRequest to process each request. type Codec interface { - NewRequest(r io.Reader) (CodecRequest, bool) + NewRequest(rc io.Reader) CodecRequest } // CodecRequest decodes a request and encodes a response using a specific @@ -21,7 +21,7 @@ type CodecRequest interface { // Reads the request filling the RPC method args. ReadRequest(interface{}) error // Writes the response using the RPC method reply. - WriteResponse(io.Writer, interface{}) + WriteResponse(io.Writer, interface{}) error // Writes an error produced by the server. - WriteError(w io.Writer, status int, err error) + WriteError(w io.Writer, status int, err error) error } diff --git a/protocol/json/client.go b/protocol/json/client.go new file mode 100644 index 0000000..e59093a --- /dev/null +++ b/protocol/json/client.go @@ -0,0 +1,80 @@ +package json + +import ( + "encoding/json" + "io" + "math/rand" +) + +// ---------------------------------------------------------------------------- +// Request and Response +// ---------------------------------------------------------------------------- + +// clientRequest represents a JSON-RPC request sent by a client. +type clientRequest struct { + // JSON-RPC protocol. + Version string `json:"jsonrpc"` + + // A String containing the name of the method to be invoked. + Method string `json:"method"` + + // Object to pass as request parameter to the method. + Params interface{} `json:"params"` + + // The request id. This can be of any type. It is used to match the + // response with the request that it is replying to. + ID uint64 `json:"id"` +} + +// clientResponse represents a JSON-RPC response returned to a client. +type clientResponse struct { + Version string `json:"jsonrpc"` + Result *json.RawMessage `json:"result"` + Error *json.RawMessage `json:"error"` +} + +// EncodeClientRequest encodes parameters for a JSON-RPC client request. +func EncodeClientRequest(method string, args interface{}) ([]byte, error) { + c := &clientRequest{ + Version: "2.0", + Method: method, + Params: args, + ID: uint64(rand.Int63()), + } + return json.Marshal(c) +} + +// EncodeClientNotify encodes parameters for a JSON-RPC client notification. +func EncodeClientNotify(method string, args interface{}) ([]byte, error) { + c := &clientRequest{ + Version: "2.0", + Method: method, + Params: args, + } + return json.Marshal(c) +} + +// DecodeClientResponse decodes the response body of a client request into +// the interface reply. +func DecodeClientResponse(r io.Reader, reply interface{}) error { + var c clientResponse + if err := json.NewDecoder(r).Decode(&c); err != nil { + return err + } + if c.Error != nil { + jsonErr := &Error{} + if err := json.Unmarshal(*c.Error, jsonErr); err != nil { + return &Error{ + Code: E_SERVER, + Message: string(*c.Error), + } + } + return jsonErr + } + + if c.Result == nil { + return ErrNullResult + } + + return json.Unmarshal(*c.Result, reply) +} diff --git a/codec/json/error.go b/protocol/json/error.go similarity index 100% rename from codec/json/error.go rename to protocol/json/error.go diff --git a/codec/json/server.go b/protocol/json/server.go similarity index 83% rename from codec/json/server.go rename to protocol/json/server.go index f55a370..e7e165d 100644 --- a/codec/json/server.go +++ b/protocol/json/server.go @@ -2,7 +2,10 @@ package json import ( "encoding/json" - "net/http" + "io" + + "git.loafle.net/commons_go/rpc/encode" + "git.loafle.net/commons_go/rpc/protocol" ) var null = json.RawMessage([]byte("null")) @@ -53,22 +56,22 @@ type serverResponse struct { // ---------------------------------------------------------------------------- // NewcustomCodec returns a new JSON Codec based on passed encoder selector. -func NewCustomCodec(encSel rpc.EncoderSelector) *Codec { +func NewCustomCodec(encSel encode.EncoderSelector) *Codec { return &Codec{encSel: encSel} } // NewCodec returns a new JSON Codec. func NewCodec() *Codec { - return NewCustomCodec(rpc.DefaultEncoderSelector) + return NewCustomCodec(encode.DefaultEncoderSelector) } // Codec creates a CodecRequest to process each request. type Codec struct { - encSel rpc.EncoderSelector + encSel encode.EncoderSelector } // NewRequest returns a CodecRequest. -func (c *Codec) NewRequest(r *http.Request) rpc.CodecRequest { +func (c *Codec) NewRequest(r io.Reader) protocol.CodecRequest { return newCodecRequest(r, c.encSel.Select(r)) } @@ -77,10 +80,10 @@ func (c *Codec) NewRequest(r *http.Request) rpc.CodecRequest { // ---------------------------------------------------------------------------- // newCodecRequest returns a new CodecRequest. -func newCodecRequest(r *http.Request, encoder rpc.Encoder) rpc.CodecRequest { +func newCodecRequest(r io.Reader, encoder encode.Encoder) protocol.CodecRequest { // Decode the request body and check if RPC method is valid. req := new(serverRequest) - err := json.NewDecoder(r.Body).Decode(req) + err := json.NewDecoder(r).Decode(req) if err != nil { err = &Error{ Code: E_PARSE, @@ -95,7 +98,7 @@ func newCodecRequest(r *http.Request, encoder rpc.Encoder) rpc.CodecRequest { Data: req, } } - r.Body.Close() + return &CodecRequest{request: req, err: err, encoder: encoder} } @@ -103,7 +106,7 @@ func newCodecRequest(r *http.Request, encoder rpc.Encoder) rpc.CodecRequest { type CodecRequest struct { request *serverRequest err error - encoder rpc.Encoder + encoder encode.Encoder } // Method returns the RPC method for the current request. @@ -152,16 +155,17 @@ func (c *CodecRequest) ReadRequest(args interface{}) error { } // WriteResponse encodes the response and writes it to the ResponseWriter. -func (c *CodecRequest) WriteResponse(w http.ResponseWriter, reply interface{}) { +func (c *CodecRequest) WriteResponse(w io.Writer, reply interface{}) error { res := &serverResponse{ Version: Version, Result: reply, - Id: c.request.Id, + ID: c.request.ID, } - c.writeServerResponse(w, res) + return c.writeServerResponse(w, res) } -func (c *CodecRequest) WriteError(w http.ResponseWriter, status int, err error) { +// WriteError encodes the response and writes it to the ResponseWriter. +func (c *CodecRequest) WriteError(w io.Writer, status int, err error) error { jsonErr, ok := err.(*Error) if !ok { jsonErr = &Error{ @@ -172,23 +176,23 @@ func (c *CodecRequest) WriteError(w http.ResponseWriter, status int, err error) res := &serverResponse{ Version: Version, Error: jsonErr, - Id: c.request.Id, + ID: c.request.ID, } - c.writeServerResponse(w, res) + return c.writeServerResponse(w, res) } -func (c *CodecRequest) writeServerResponse(w http.ResponseWriter, res *serverResponse) { - // Id is null for notifications and they don't have a response. - if c.request.Id != nil { - w.Header().Set("Content-Type", "application/json; charset=utf-8") +func (c *CodecRequest) writeServerResponse(w io.Writer, res *serverResponse) error { + // ID is null for notifications and they don't have a response. + if c.request.ID != nil { encoder := json.NewEncoder(c.encoder.Encode(w)) err := encoder.Encode(res) // Not sure in which case will this happen. But seems harmless. if err != nil { - rpc.WriteError(w, 400, err.Error()) + return err } } + return nil } type EmptyResponse struct { diff --git a/registry.go b/registry.go index df93320..6758b98 100644 --- a/registry.go +++ b/registry.go @@ -6,7 +6,7 @@ import ( "reflect" "strings" - "git.loafle.net/commons_go/rpc/codec" + "git.loafle.net/commons_go/rpc/protocol" ) /** @@ -19,24 +19,26 @@ Network connection */ +type WriteHookFunc func(io.Writer) + // NewRPCRegistry returns a new RPC registry. func NewRegistry() Registry { return &rpcRegistry{ - codecs: make(map[string]codec.Codec), + codecs: make(map[string]protocol.Codec), services: new(serviceMap), } } type Registry interface { - RegisterCodec(codec codec.Codec, contentType string) + RegisterCodec(codec protocol.Codec, contentType string) RegisterService(receiver interface{}, name string) error HasMethod(method string) bool - Invoke(contentType string, reader io.Reader, writer io.Writer) error + Invoke(contentType string, reader io.Reader, writer io.Writer, beforeWrite WriteHookFunc, afterWrite WriteHookFunc) error } // RPCRegistry serves registered RPC services using registered codecs. type rpcRegistry struct { - codecs map[string]codec.Codec + codecs map[string]protocol.Codec services *serviceMap } @@ -45,8 +47,8 @@ type rpcRegistry struct { // Codecs are defined to process a given serialization scheme, e.g., JSON or // XML. A codec is chosen based on the "Content-Type" header from the request, // excluding the charset definition. -func (r *rpcRegistry) RegisterCodec(codec codec.Codec, contentType string) { - r.codecs[strings.ToLower(contentType)] = codec +func (rr *rpcRegistry) RegisterCodec(codec protocol.Codec, contentType string) { + rr.codecs[strings.ToLower(contentType)] = codec } // RegisterService adds a new service to the server. @@ -65,15 +67,15 @@ func (r *rpcRegistry) RegisterCodec(codec codec.Codec, contentType string) { // - The method has return type error. // // All other methods are ignored. -func (r *rpcRegistry) RegisterService(receiver interface{}, name string) error { - return r.services.register(receiver, name) +func (rr *rpcRegistry) RegisterService(receiver interface{}, name string) error { + return rr.services.register(receiver, name) } // HasMethod returns true if the given method is registered. // // The method uses a dotted notation as in "Service.Method". -func (r *rpcRegistry) HasMethod(method string) bool { - if _, _, err := r.services.get(method); err == nil { +func (rr *rpcRegistry) HasMethod(method string) bool { + if _, _, err := rr.services.get(method); err == nil { return true } return false @@ -84,32 +86,33 @@ func (r *rpcRegistry) HasMethod(method string) bool { // Codecs are defined to process a given serialization scheme, e.g., JSON or // XML. A codec is chosen based on the "Content-Type" header from the request, // excluding the charset definition. -func (r *rpcRegistry) Invoke(contentType string, reader io.Reader, writer io.Writer) error { - var codec codec.Codec - if contentType == "" && len(r.codecs) == 1 { +func (rr *rpcRegistry) Invoke(contentType string, r io.Reader, w io.Writer, beforeWrite WriteHookFunc, afterWrite WriteHookFunc) error { + var codec protocol.Codec + if contentType == "" && len(rr.codecs) == 1 { // If Content-Type is not set and only one codec has been registered, // then default to that codec. - for _, c := range r.codecs { + for _, c := range rr.codecs { codec = c } - } else if codec = r.codecs[strings.ToLower(contentType)]; codec == nil { + } else if codec = rr.codecs[strings.ToLower(contentType)]; codec == nil { return fmt.Errorf("Unrecognized Content-Type: %s", contentType) } + // Create a new codec request. - codecReq, hasResponse := codec.NewRequest(reader) + codecReq := codec.NewRequest(r) // Get service method to be called. method, errMethod := codecReq.Method() if errMethod != nil { - return errMethod + return write(codecReq, w, beforeWrite, afterWrite, nil, errMethod) } - serviceSpec, methodSpec, errGet := r.services.get(method) + serviceSpec, methodSpec, errGet := rr.services.get(method) if errGet != nil { - return errGet + return write(codecReq, w, beforeWrite, afterWrite, nil, errGet) } // Decode the args. args := reflect.New(methodSpec.argsType) if errRead := codecReq.ReadRequest(args.Interface()); errRead != nil { - return errRead + return write(codecReq, w, beforeWrite, afterWrite, nil, errRead) } // Call the service method. reply := reflect.New(methodSpec.replyType) @@ -120,10 +123,6 @@ func (r *rpcRegistry) Invoke(contentType string, reader io.Reader, writer io.Wri reply, }) - if !hasResponse { - return nil - } - // Cast the result to error if needed. var errResult error errInter := errValue[0].Interface() @@ -131,11 +130,32 @@ func (r *rpcRegistry) Invoke(contentType string, reader io.Reader, writer io.Wri errResult = errInter.(error) } - // Encode the response. - if errResult == nil { - codecReq.WriteResponse(writer, reply.Interface()) + if errResult != nil { + return write(codecReq, w, beforeWrite, afterWrite, nil, errResult) + } + + return write(codecReq, w, beforeWrite, afterWrite, reply.Interface(), nil) +} + +func write(codecReq protocol.CodecRequest, w io.Writer, beforeWrite WriteHookFunc, afterWrite WriteHookFunc, result interface{}, err error) error { + if nil != beforeWrite { + beforeWrite(w) + } + + var wErr error + + if err == nil { + wErr = codecReq.WriteResponse(w, result) } else { - codecReq.WriteError(writer, 400, errResult) + wErr = codecReq.WriteError(w, 400, err) + } + + if nil != wErr { + return wErr + } + + if nil != afterWrite { + afterWrite(w) } return nil