diff --git a/registry/definition.go b/registry/definition.go index 4120569..e85269c 100644 --- a/registry/definition.go +++ b/registry/definition.go @@ -16,7 +16,7 @@ type TypeDefinition struct { RealType reflect.Type TypeAnnotations map[string]cda.Annotation - MethodAnnotations map[string]cda.Annotation + MethodAnnotations map[string]map[string]cda.Annotation Fields []*FieldDefinition } @@ -48,6 +48,25 @@ func (td *TypeDefinition) GetAnnotationByType(at reflect.Type, includeEmbedding return nil } +func (td *TypeDefinition) GetMethodAnnotationByType(at reflect.Type, methodName string) cda.Annotation { + if nil == td.MethodAnnotations { + return nil + } + + ms, ok := td.MethodAnnotations[methodName] + if !ok { + return nil + } + + for _, v := range ms { + if at == reflect.TypeOf(v) { + return v + } + } + + return nil +} + func checkAnnotation(t reflect.Type, st reflect.Type) bool { rt, _, _ := cur.GetTypeInfo(t) if reflect.Struct != rt.Kind() { diff --git a/registry/registry.go b/registry/registry.go index 1170785..d12c7db 100644 --- a/registry/registry.go +++ b/registry/registry.go @@ -27,6 +27,9 @@ type ComponentRegistry interface { GetInstanceByName(name string) (interface{}, error) // GetInstancesByAnnotationName(n string) ([]interface{}, error) GetInstancesByAnnotationType(t reflect.Type) ([]interface{}, error) + + GetTypeAnnotation(instanceType reflect.Type, annotationType reflect.Type) cda.Annotation + GetMethodAnnotation(instanceType reflect.Type, annotationType reflect.Type, methodName string) cda.Annotation } func newRegistry() ComponentRegistry { @@ -259,6 +262,29 @@ func (cr *defaultComponentRegistry) GetInstancesByAnnotationType(t reflect.Type) return instances, nil } +func GetTypeAnnotation(instanceType reflect.Type, annotationType reflect.Type) cda.Annotation { + return registry.GetTypeAnnotation(instanceType, annotationType) +} +func (cr *defaultComponentRegistry) GetTypeAnnotation(instanceType reflect.Type, annotationType reflect.Type) cda.Annotation { + def, ok := cr.definitionByType[instanceType] + if !ok { + return nil + } + + return def.GetAnnotationByType(annotationType, false) +} +func GetMethodAnnotation(instanceType reflect.Type, annotationType reflect.Type, methodName string) cda.Annotation { + return registry.GetMethodAnnotation(instanceType, annotationType, methodName) +} +func (cr *defaultComponentRegistry) GetMethodAnnotation(instanceType reflect.Type, annotationType reflect.Type, methodName string) cda.Annotation { + def, ok := cr.definitionByType[instanceType] + if !ok { + return nil + } + + return def.GetMethodAnnotationByType(annotationType, methodName) +} + func (cr *defaultComponentRegistry) buildDefinition(t reflect.Type) (*TypeDefinition, error) { if nil == t { return nil, fmt.Errorf("t[reflect.Type] is nil") @@ -350,7 +376,10 @@ func parseMethodAnnotation(f *reflect.StructField, td *TypeDefinition) { } if nil != as && 0 < len(as) { - td.MethodAnnotations = as + if nil == td.MethodAnnotations { + td.MethodAnnotations = make(map[string]map[string]cda.Annotation, 0) + } + td.MethodAnnotations[f.Name] = as } }