diff --git a/annotation/annotation.go b/annotation/annotation.go index 0e08579..2cec4c3 100644 --- a/annotation/annotation.go +++ b/annotation/annotation.go @@ -67,7 +67,7 @@ type Annotation interface { // @Inject(name? string) // @Resource(name? string) -func ParseAnnotation(tag reflect.StructTag) (map[string]Annotation, error) { +func ParseAnnotation(tag reflect.StructTag) (map[reflect.Type]Annotation, error) { s := strings.Trim(tag.Get(AnnotationTag), " ") if "" == s { return nil, nil @@ -82,13 +82,13 @@ func ParseAnnotation(tag reflect.StructTag) (map[string]Annotation, error) { return nil, nil } - rKVs := make(map[string]Annotation, 0) + rKVs := make(map[reflect.Type]Annotation, 0) for name, attributes := range annotations { - annotation, err := newAnnotation(name, attributes) + t, annotation, err := newAnnotation(name, attributes) if nil != err { return nil, err } - rKVs[name] = annotation + rKVs[t] = annotation } return rKVs, nil @@ -155,10 +155,10 @@ func splitAnnotationAttribute(s string) (map[string]string, error) { return ss, nil } -func newAnnotation(name string, attributes map[string]string) (Annotation, error) { +func newAnnotation(name string, attributes map[string]string) (reflect.Type, Annotation, error) { def, ok := annotationRegistry[name] if !ok { - return nil, fmt.Errorf("There is no annotation[%s]", name) + return nil, nil, fmt.Errorf("There is no annotation[%s]", name) } v := reflect.New(def.rt) @@ -167,7 +167,7 @@ func newAnnotation(name string, attributes map[string]string) (Annotation, error if nil != attributes { setMetaAttributes(def, v.Elem(), attributes) } - return i, nil + return def.t, i, nil } // func parseAnnotationItem(a string) (name string, attributes map[string]string, err error) { diff --git a/registry/definition.go b/registry/definition.go index c50b10e..29e884f 100644 --- a/registry/definition.go +++ b/registry/definition.go @@ -15,24 +15,20 @@ type TypeDefinition struct { Type reflect.Type RealType reflect.Type - TypeAnnotations map[string]cda.Annotation - MethodAnnotations map[string]map[string]cda.Annotation + TypeAnnotations map[reflect.Type]cda.Annotation + MethodAnnotations map[string]map[reflect.Type]cda.Annotation Fields []*FieldDefinition } -func (td *TypeDefinition) GetAnnotation(name string) cda.Annotation { - if nil == td.TypeAnnotations { - return nil - } - - return td.TypeAnnotations[name] -} - func (td *TypeDefinition) GetTypeAnnotationByType(at reflect.Type, includeEmbedding bool) cda.Annotation { if nil == td.TypeAnnotations { return nil } + if !includeEmbedding { + return td.TypeAnnotations[at] + } + for _, v := range td.TypeAnnotations { if at == reflect.TypeOf(v) { return v @@ -58,13 +54,24 @@ func (td *TypeDefinition) GetMethodAnnotationByType(at reflect.Type, methodName return nil } - for _, v := range ms { - if at == reflect.TypeOf(v) { - return v - } + return ms[at] +} + +func (td *TypeDefinition) GetMethodAnnotationsByType(at reflect.Type) map[string]cda.Annotation { + if nil == td.MethodAnnotations { + return nil } - return nil + mas := make(map[string]cda.Annotation) + for k, v := range td.MethodAnnotations { + a, ok := v[at] + if !ok { + continue + } + mas[k] = a + } + + return mas } func checkAnnotation(t reflect.Type, st reflect.Type) bool { @@ -96,14 +103,7 @@ type FieldDefinition struct { Type reflect.Type RealType reflect.Type - Annotations map[string]cda.Annotation -} - -func (fd *FieldDefinition) GetAnnotation(name string) cda.Annotation { - if nil == fd.Annotations { - return nil - } - return fd.Annotations[name] + Annotations map[reflect.Type]cda.Annotation } func (fd *FieldDefinition) GetAnnotationByType(at reflect.Type, includeEmbedding bool) cda.Annotation { @@ -111,6 +111,10 @@ func (fd *FieldDefinition) GetAnnotationByType(at reflect.Type, includeEmbedding return nil } + if !includeEmbedding { + return fd.Annotations[at] + } + for _, v := range fd.Annotations { if at == reflect.TypeOf(v) { return v diff --git a/registry/registry.go b/registry/registry.go index cc407ba..bc477c9 100644 --- a/registry/registry.go +++ b/registry/registry.go @@ -29,6 +29,7 @@ type ComponentRegistry interface { GetTypeAnnotation(instanceType reflect.Type, annotationType reflect.Type) cda.Annotation GetMethodAnnotation(instanceType reflect.Type, annotationType reflect.Type, methodName string) cda.Annotation + GetMethodAnnotations(instanceType reflect.Type, annotationType reflect.Type) map[string]cda.Annotation } func newRegistry() ComponentRegistry { @@ -215,30 +216,6 @@ func (cr *defaultComponentRegistry) GetInstances(ts []reflect.Type) ([]interface return instances, nil } -// GetInstancesByAnnotationName returns instance of annotated -// n must be name of registered annotation -// func GetInstancesByAnnotationName(n string) ([]interface{}, error) { -// return registry.GetInstancesByAnnotationName(n) -// } -// func (cr *defaultComponentRegistry) GetInstancesByAnnotationName(n string) ([]interface{}, error) { -// var ( -// i interface{} -// err error -// ) -// instances := make([]interface{}, 0) - -// for _, td := range cr.definitionByType { -// if nil != td.GetAnnotation(n) { -// if i, err = cr.GetInstance(td.Type); nil != err { -// return nil, err -// } -// instances = append(instances, i) -// } -// } - -// return instances, nil -// } - func GetInstancesByAnnotationType(t reflect.Type) ([]interface{}, error) { return registry.GetInstancesByAnnotationType(t) } @@ -286,6 +263,20 @@ func (cr *defaultComponentRegistry) GetMethodAnnotation(instanceType reflect.Typ return def.GetMethodAnnotationByType(annotationType, methodName) } +func GetMethodAnnotations(instanceType reflect.Type, annotationType reflect.Type) map[string]cda.Annotation { + return registry.GetMethodAnnotations(instanceType, annotationType) +} +func (cr *defaultComponentRegistry) GetMethodAnnotations(instanceType reflect.Type, annotationType reflect.Type) map[string]cda.Annotation { + rt, _, _ := cur.GetTypeInfo(instanceType) + + def, ok := cr.definitionByType[rt] + if !ok { + return nil + } + + return def.GetMethodAnnotationsByType(annotationType) +} + func (cr *defaultComponentRegistry) buildDefinition(t reflect.Type) (*TypeDefinition, error) { if nil == t { return nil, fmt.Errorf("t[reflect.Type] is nil") @@ -384,7 +375,7 @@ func parseMethodAnnotation(f *reflect.StructField, td *TypeDefinition) bool { if nil != as && 0 < len(as) { if nil == td.MethodAnnotations { - td.MethodAnnotations = make(map[string]map[string]cda.Annotation, 0) + td.MethodAnnotations = make(map[string]map[reflect.Type]cda.Annotation, 0) } td.MethodAnnotations[f.Name[1:]] = as return true diff --git a/registry/registry_test.go b/registry/registry_test.go index 3f64d4c..b8f954b 100644 --- a/registry/registry_test.go +++ b/registry/registry_test.go @@ -50,7 +50,7 @@ func TestRegisterType(t *testing.T) { } type AService struct { - annotation.TypeAnnotation `annotation:"@Component(name='dkdkdf', methods='[1, 2]')"` + annotation.TypeAnnotation `annotation:"@Component(name='dkdkdf')"` NameA string }