diff --git a/context/context.go b/context/context.go index 465a763..a549ba2 100644 --- a/context/context.go +++ b/context/context.go @@ -1,6 +1,9 @@ package context -import "sync" +import ( + "reflect" + "sync" +) type ContextKey string @@ -12,22 +15,21 @@ func NewContext(parent Context) Context { c := &defaultContext{ Context: parent, } - c.attributes = make(map[string]interface{}) - + c.attributes = make(map[interface{}]interface{}) return c } type Context interface { Parent() Context - SetAttribute(key string, value interface{}) - GetAttribute(key string) (value interface{}) - RemoveAttribute(key string) - ContainsAttribute(key string) (exist bool) + SetAttribute(key interface{}, value interface{}) + GetAttribute(key interface{}) (value interface{}) + RemoveAttribute(key interface{}) + ContainsAttribute(key interface{}) (exist bool) } type defaultContext struct { Context - attributes map[string]interface{} + attributes map[interface{}]interface{} mtx sync.RWMutex } @@ -36,16 +38,23 @@ func (dc *defaultContext) Parent() Context { return dc.Context } -func (dc *defaultContext) SetAttribute(key string, value interface{}) { +func (dc *defaultContext) SetAttribute(key interface{}, value interface{}) { dc.checkInitialized() + if key == nil { + panic("nil key") + } + if !reflect.TypeOf(key).Comparable() { + panic("key is not comparable") + } + dc.mtx.Lock() defer dc.mtx.Unlock() dc.attributes[key] = value } -func (dc *defaultContext) GetAttribute(key string) (value interface{}) { +func (dc *defaultContext) GetAttribute(key interface{}) (value interface{}) { dc.checkInitialized() dc.mtx.RLock() @@ -61,7 +70,7 @@ func (dc *defaultContext) GetAttribute(key string) (value interface{}) { return dc.Context.GetAttribute(key) } -func (dc *defaultContext) RemoveAttribute(key string) { +func (dc *defaultContext) RemoveAttribute(key interface{}) { dc.checkInitialized() dc.mtx.Lock() @@ -79,7 +88,7 @@ func (dc *defaultContext) RemoveAttribute(key string) { dc.Context.RemoveAttribute(key) } -func (dc *defaultContext) ContainsAttribute(key string) (exist bool) { +func (dc *defaultContext) ContainsAttribute(key interface{}) (exist bool) { dc.checkInitialized() dc.mtx.RLock()