diff --git a/registry.go b/registry.go index 53634d2..4da4922 100644 --- a/registry.go +++ b/registry.go @@ -77,19 +77,21 @@ func (rr *rpcRegistry) Invoke(codec protocol.RegistryCodec) (result interface{}, // Decode the args. var in []reflect.Value - paramValues, paramInstances := methodSpec.getParams() + paramValues := methodSpec.paramValues + if nil != paramValues { - in = make([]reflect.Value, len(paramValues)+1) - if errRead := codec.ReadParams(paramInstances); errRead != nil { + params := methodSpec.getInterfaces() + if errRead := codec.ReadParams(params); errRead != nil { return nil, errRead } + in = make([]reflect.Value, len(paramValues)+1) for indexI := 0; indexI < len(paramValues); indexI++ { in[indexI+1] = paramValues[indexI] } } else { in = make([]reflect.Value, 1) } - in[0] = serviceSpec.rcvr + in[0] = serviceSpec.rcvrV // Call the service method. returnValues := methodSpec.method.Func.Call(in) diff --git a/service_map.go b/service_map.go index abe6d8f..5a3e0c1 100644 --- a/service_map.go +++ b/service_map.go @@ -19,30 +19,30 @@ var ( // ---------------------------------------------------------------------------- type service struct { - name string // name of service - rcvr reflect.Value // receiver of methods for the service - rcvrType reflect.Type // type of the receiver - methods map[string]*serviceMethod // registered methods + name string // name of service + rcvrV reflect.Value // receiver of methods for the service + rcvrT reflect.Type // type of the receiver + methods map[string]*serviceMethod // registered methods } type serviceMethod struct { - method reflect.Method // receiver method - paramTypes []reflect.Type // type of the request argument - returnType reflect.Type // type of the response argument + method reflect.Method // receiver method + paramTypes []reflect.Type // type of the request argument + paramValues []reflect.Value + returnType reflect.Type // type of the response argument + returnValue reflect.Value } -func (sm *serviceMethod) getParams() (values []reflect.Value, instances []interface{}) { - if nil == sm.paramTypes || 0 == len(sm.paramTypes) { - return nil, nil +func (sm *serviceMethod) getInterfaces() (instances []interface{}) { + if nil == sm.paramValues || 0 == len(sm.paramValues) { + return nil } - pCount := len(sm.paramTypes) - values = make([]reflect.Value, pCount) + pCount := len(sm.paramValues) instances = make([]interface{}, pCount) for indexI := 0; indexI < pCount; indexI++ { - values[indexI] = reflect.New(sm.paramTypes[indexI]) - instances[indexI] = values[indexI].Interface() + instances[indexI] = sm.paramValues[indexI].Interface() } return @@ -62,69 +62,77 @@ type serviceMap struct { func (m *serviceMap) register(rcvr interface{}, name string) error { // Setup service. s := &service{ - name: name, - rcvr: reflect.ValueOf(rcvr), - rcvrType: reflect.TypeOf(rcvr), - methods: make(map[string]*serviceMethod), + name: name, + rcvrV: reflect.ValueOf(rcvr), + rcvrT: reflect.TypeOf(rcvr), + methods: make(map[string]*serviceMethod), } if name == "" { - s.name = reflect.Indirect(s.rcvr).Type().Name() + s.name = reflect.Indirect(s.rcvrV).Type().Name() if !isExported(s.name) { return fmt.Errorf("rpc: type %q is not exported", s.name) } } if s.name == "" { return fmt.Errorf("rpc: no service name for type %q", - s.rcvrType.String()) + s.rcvrT.String()) } + var err error // Setup methods. Loop: - for i := 0; i < s.rcvrType.NumMethod(); i++ { - method := s.rcvrType.Method(i) - mtype := method.Type + for i := 0; i < s.rcvrT.NumMethod(); i++ { + m := s.rcvrT.Method(i) + mt := m.Type // Method must be exported. - if method.PkgPath != "" { + if m.PkgPath != "" { continue } var paramTypes []reflect.Type + var paramValues []reflect.Value + var returnType reflect.Type + var returnValue reflect.Value - mCount := mtype.NumIn() - if 0 < mCount { - paramTypes = make([]reflect.Type, mCount) - for indexI := 1; indexI < mCount; indexI++ { - param := mtype.In(indexI) - if !isExportedOrBuiltin(param) { - continue Loop + pCount := mt.NumIn() - 1 + + if 0 < pCount { + paramTypes = make([]reflect.Type, pCount) + paramValues = make([]reflect.Value, pCount) + + for indexI := 0; indexI < pCount; indexI++ { + if paramTypes[indexI], paramValues[indexI], err = validateType(mt.In(indexI + 1)); nil != err { + return err } - paramTypes[indexI] = param.Elem() } } - var returnType reflect.Type - switch mtype.NumOut() { + switch mt.NumOut() { case 1: - if returnType := mtype.Out(0); returnType != typeOfError { + if t := mt.Out(0); t != typeOfError { continue Loop } case 2: - if returnType := mtype.Out(0); !isExportedOrBuiltin(returnType) { + if t := mt.Out(0); !isExportedOrBuiltin(t) { continue Loop } - if returnType := mtype.Out(1); returnType != typeOfError { + if t := mt.Out(1); t != typeOfError { continue Loop } - returnType = mtype.Out(0).Elem() + if returnType, returnValue, err = validateType(mt.Out(0)); nil != err { + return err + } default: continue } - s.methods[method.Name] = &serviceMethod{ - method: method, - paramTypes: paramTypes, - returnType: returnType, + s.methods[m.Name] = &serviceMethod{ + method: m, + paramTypes: paramTypes, + paramValues: paramValues, + returnType: returnType, + returnValue: returnValue, } } if len(s.methods) == 0 { @@ -142,6 +150,22 @@ Loop: return nil } +func validateType(t reflect.Type) (rt reflect.Type, rv reflect.Value, err error) { + if t.Kind() == reflect.Struct { + err = fmt.Errorf("Type is Struct. Pass by reference, i.e. %s", t) + return + } + if t.Kind() == reflect.Ptr { + rt = t.Elem() + } + + rv = reflect.New(rt) + if rt.Kind() != reflect.Struct { + rv = reflect.Indirect(rv) + } + return +} + // get returns a registered service given a method name. // // The method name uses a dotted notation as in "Service.Method".