diff --git a/config.go b/config.go index e601134..8b24c38 100644 --- a/config.go +++ b/config.go @@ -2,1099 +2,123 @@ package config import ( "bytes" - "encoding/csv" - "fmt" "io" - "log" "os" "path/filepath" - "reflect" - "strings" - "time" - - "github.com/fsnotify/fsnotify" - "github.com/mitchellh/mapstructure" - "github.com/spf13/afero" - "github.com/spf13/cast" - "github.com/spf13/pflag" ) -// SupportedExts are universally supported extensions. -var SupportedExts = []string{"json", "toml", "yaml", "yml", "properties", "props", "prop", "hcl"} - -type UnsupportedConfigError string - -// Error returns the formatted configuration error. -func (str UnsupportedConfigError) Error() string { - return fmt.Sprintf("Unsupported Config Type %q", string(str)) -} - -// ConfigFileNotFoundError denotes failing to find configuration file. -type ConfigFileNotFoundError struct { - name, locations string -} - -// Error returns the formatted configuration error. -func (fnfe ConfigFileNotFoundError) Error() string { - return fmt.Sprintf("Config File %q Not Found in %q", fnfe.name, fnfe.locations) -} - -type CircularReferenceAliasError struct { - alias, key, realKey string -} - -// Error returns the formatted configuration error. -func (crae CircularReferenceAliasError) Error() string { - return fmt.Sprintf("Creating circular reference alias[%q] for key[%q]:realKey[%q]", crae.alias, crae.key, crae.realKey) -} +const ( + ConfigEnvPrefix = "CONFIG_ENV_PREFIX" + ConfigTagEnv = "env" +) type Configurator interface { - OnConfigChange(run func(in fsnotify.Event)) - WatchConfig() - SetConfigType(in string) - SetConfigFile(in string) + SetConfigPath(in string) error + Load(target interface{}, files ...string) error + LoadReader(target interface{}, ext string, in io.Reader) error + Save(target interface{}, file string, overWrite bool) error + SetEnvPrefix(in string) - SetConfigName(in string) - AddConfigPath(in string) error - ReadConfig(in io.Reader) error - ReadInConfig() error - AutomaticEnv() - SetEnvKeyReplacer(r *strings.Replacer) - RegisterAlias(alias string, key string) error - SetDefault(key string, value interface{}) - Set(key string, value interface{}) - IsSet(key string) bool - Get(key string) interface{} - GetString(key string) string - GetBool(key string) bool - GetInt(key string) int - GetInt64(key string) int64 - GetFloat64(key string) float64 - GetTime(key string) time.Time - GetDuration(key string) time.Duration - GetStringSlice(key string) []string - GetStringMap(key string) map[string]interface{} - GetStringMapString(key string) map[string]string - GetStringMapStringSlice(key string) map[string][]string - GetSizeInBytes(key string) uint - Sub(key string) Configurator - UnmarshalKey(key string, rawVal interface{}) error - Unmarshal(rawVal interface{}) error - UnmarshalExact(rawVal interface{}) error - Marshal(configType string) ([]byte, error) - BindPFlags(flags *pflag.FlagSet) error - BindPFlag(key string, flag *pflag.Flag) error - BindFlagValues(flags FlagValueSet) (err error) - BindFlagValue(key string, flag FlagValue) error - BindEnv(input ...string) error - AllSettings() map[string]interface{} - AllKeys() []string } -type configurator struct { - keyDelim string - configPaths []string - // The filesystem to read config from. - fs afero.Fs +type config struct { + configPath string // Name of file to look for inside the path - configName string - configFile string - configType string - envPrefix string - - automaticEnvApplied bool - envKeyReplacer *strings.Replacer - - config map[string]interface{} - override map[string]interface{} - defaults map[string]interface{} - kvstore map[string]interface{} - pflags map[string]FlagValue - env map[string]string - aliases map[string]string - typeByDefValue bool - - onConfigChange func(fsnotify.Event) -} - -var _c *configurator - -func init() { - _c = New().(*configurator) + envPrefix string } func New() Configurator { - c := new(configurator) - c.keyDelim = "." - c.configName = "config" - c.fs = afero.NewOsFs() - c.config = make(map[string]interface{}) - c.override = make(map[string]interface{}) - c.defaults = make(map[string]interface{}) - c.kvstore = make(map[string]interface{}) - c.pflags = make(map[string]FlagValue) - c.env = make(map[string]string) - c.aliases = make(map[string]string) - c.typeByDefValue = false - - return c + return &config{} } -func OnConfigChange(run func(in fsnotify.Event)) { _c.OnConfigChange(run) } -func (c *configurator) OnConfigChange(run func(in fsnotify.Event)) { - c.onConfigChange = run -} +var _c *config -func WatchConfig() { _c.WatchConfig() } -func (c *configurator) WatchConfig() { - go func() { - watcher, err := fsnotify.NewWatcher() - if err != nil { - log.Fatal(err) - } - defer watcher.Close() - - // we have to watch the entire directory to pick up renames/atomic saves in a cross-platform way - filename, err := c.getConfigFile() - if err != nil { - log.Println("error:", err) - return - } - - configFile := filepath.Clean(filename) - configDir, _ := filepath.Split(configFile) - - done := make(chan bool) - go func() { - for { - select { - case event := <-watcher.Events: - // we only care about the config file - if filepath.Clean(event.Name) == configFile { - if event.Op&fsnotify.Write == fsnotify.Write || event.Op&fsnotify.Create == fsnotify.Create { - err := c.ReadInConfig() - if err != nil { - log.Println("error:", err) - } - c.onConfigChange(event) - } - } - case err := <-watcher.Errors: - log.Println("error:", err) - } - } - }() - - watcher.Add(configDir) - <-done - }() -} - -func SetConfigType(in string) { _c.SetConfigType(in) } -func (c *configurator) SetConfigType(in string) { - if in != "" { - c.configType = in - } -} - -func SetConfigFile(in string) { _c.SetConfigFile(in) } -func (c *configurator) SetConfigFile(in string) { - if in != "" { - c.configFile = in - } -} - -func SetEnvPrefix(in string) { _c.SetEnvPrefix(in) } -func (c *configurator) SetEnvPrefix(in string) { - if in != "" { - c.envPrefix = in - } -} - -// SetConfigName sets name for the config file. -// Does not include extension. -func SetConfigName(in string) { _c.SetConfigName(in) } -func (c *configurator) SetConfigName(in string) { - if in != "" { - c.configName = in - c.configFile = "" - } -} - -// AddConfigPath adds a path for Viper to search for the config file in. -// Can be called multiple times to define multiple search paths. -func AddConfigPath(in string) error { return _c.AddConfigPath(in) } -func (c *configurator) AddConfigPath(in string) error { +// SetConfigPath set a path to search for the config file in. +func SetConfigPath(in string) error { return _c.SetConfigPath(in) } +func (c *config) SetConfigPath(in string) error { if in != "" { absin, err := absPathify(in) if nil != err { return err } - if !stringInSlice(absin, c.configPaths) { - c.configPaths = append(c.configPaths, absin) - } + c.configPath = absin } return nil } -func ReadConfig(in io.Reader) error { return _c.ReadConfig(in) } -func (c *configurator) ReadConfig(in io.Reader) error { - c.config = make(map[string]interface{}) - return c.unmarshalReader(in, c.config) -} - -// ReadInConfig will discover and load the configuration file from disk -// and key/value stores, searching in one of the defined paths. -func ReadInConfig() error { return _c.ReadInConfig() } -func (c *configurator) ReadInConfig() error { - filename, err := c.getConfigFile() - if err != nil { - return err +// SetEnvPrefix set a prefix to search for the env variable. +func SetEnvPrefix(in string) { _c.SetEnvPrefix(in) } +func (c *config) SetEnvPrefix(in string) { + if in != "" { + c.envPrefix = in } - - if !stringInSlice(c.getConfigType(), SupportedExts) { - return UnsupportedConfigError(c.getConfigType()) - } - - file, err := afero.ReadFile(c.fs, filename) - if err != nil { - return err - } - - config := make(map[string]interface{}) - - err = c.unmarshalReader(bytes.NewReader(file), config) - if err != nil { - return err - } - - c.config = config - return nil } -// AutomaticEnv has Viper check ENV variables for all. -// keys set in config, default & flags -func AutomaticEnv() { _c.AutomaticEnv() } -func (c *configurator) AutomaticEnv() { - c.automaticEnvApplied = true +// Load will unmarshal from configuration file from disk +func Load(target interface{}, files ...string) error { + return _c.Load(target, files...) } +func (c *config) Load(target interface{}, files ...string) error { + filenames := c.getConfigFiles(files...) -// SetEnvKeyReplacer sets the strings.Replacer on the viper object -// Useful for mapping an environmental variable to a key that does -// not match it. -func SetEnvKeyReplacer(r *strings.Replacer) { _c.SetEnvKeyReplacer(r) } -func (c *configurator) SetEnvKeyReplacer(r *strings.Replacer) { - c.envKeyReplacer = r -} - -func RegisterAlias(alias string, key string) error { return _c.RegisterAlias(alias, key) } -func (c *configurator) RegisterAlias(alias string, key string) error { - return c.registerAlias(alias, strings.ToLower(key)) -} - -func (c *configurator) registerAlias(alias string, key string) error { - alias = strings.ToLower(alias) - if alias != key && alias != c.realKey(key) { - _, exists := c.aliases[alias] - - if !exists { - // if we alias something that exists in one of the maps to another - // name, we'll never be able to get that value using the original - // name, so move the config value to the new realkey. - if val, ok := c.config[alias]; ok { - delete(c.config, alias) - c.config[key] = val - } - if val, ok := c.kvstore[alias]; ok { - delete(c.kvstore, alias) - c.kvstore[key] = val - } - if val, ok := c.defaults[alias]; ok { - delete(c.defaults, alias) - c.defaults[key] = val - } - if val, ok := c.override[alias]; ok { - delete(c.override, alias) - c.override[key] = val - } - c.aliases[alias] = key - } - } else { - return CircularReferenceAliasError{alias: alias, key: key, realKey: c.realKey(key)} - } - return nil -} - -func SetDefault(key string, value interface{}) { _c.SetDefault(key, value) } -func (c *configurator) SetDefault(key string, value interface{}) { - // If alias passed in, then set the proper default - key = c.realKey(strings.ToLower(key)) - value = toCaseInsensitiveValue(value) - - path := strings.Split(key, c.keyDelim) - lastKey := strings.ToLower(path[len(path)-1]) - deepestMap := deepSearch(c.defaults, path[0:len(path)-1]) - - // set innermost value - deepestMap[lastKey] = value -} - -func Set(key string, value interface{}) { _c.Set(key, value) } -func (c *configurator) Set(key string, value interface{}) { - // If alias passed in, then set the proper override - key = c.realKey(strings.ToLower(key)) - value = toCaseInsensitiveValue(value) - - path := strings.Split(key, c.keyDelim) - lastKey := strings.ToLower(path[len(path)-1]) - deepestMap := deepSearch(c.override, path[0:len(path)-1]) - - // set innermost value - deepestMap[lastKey] = value -} - -func IsSet(key string) bool { return _c.IsSet(key) } -func (c *configurator) IsSet(key string) bool { - lcaseKey := strings.ToLower(key) - val := c.find(lcaseKey) - return val != nil -} - -func Get(key string) interface{} { return _c.Get(key) } -func (c *configurator) Get(key string) interface{} { - lcaseKey := strings.ToLower(key) - val := c.find(lcaseKey) - if val == nil { - return nil - } - - if c.typeByDefValue { - // TODO(bep) this branch isn't covered by a single test. - valType := val - path := strings.Split(lcaseKey, c.keyDelim) - defVal := c.searchMap(c.defaults, path) - if defVal != nil { - valType = defVal - } - - switch valType.(type) { - case bool: - return cast.ToBool(val) - case string: - return cast.ToString(val) - case int64, int32, int16, int8, int: - return cast.ToInt(val) - case float64, float32: - return cast.ToFloat64(val) - case time.Time: - return cast.ToTime(val) - case time.Duration: - return cast.ToDuration(val) - case []string: - return cast.ToStringSlice(val) + for _, file := range filenames { + if err := unmarshalFile(target, file); err != nil { + return err } } - return val + return unmarshalTags(target, c.getENVPrefix()) } -// GetString returns the value associated with the key as a string. -func GetString(key string) string { return _c.GetString(key) } -func (c *configurator) GetString(key string) string { - return cast.ToString(c.Get(key)) +// LoadReader will unmarshal from configuration bytes +func LoadReader(target interface{}, ext string, in io.Reader) error { + return _c.LoadReader(target, ext, in) +} +func (c *config) LoadReader(target interface{}, ext string, in io.Reader) error { + buf := new(bytes.Buffer) + buf.ReadFrom(in) + + return unmarshalData(target, ext, buf.Bytes()) } -// GetBool returns the value associated with the key as a boolean. -func GetBool(key string) bool { return _c.GetBool(key) } -func (c *configurator) GetBool(key string) bool { - return cast.ToBool(c.Get(key)) +// Save store to configuration file from disk +func Save(target interface{}, file string, overWrite bool) error { + return _c.Save(target, file, overWrite) +} +func (c *config) Save(target interface{}, file string, overWrite bool) error { + return marshalFile(target, file, overWrite) } -// GetInt returns the value associated with the key as an integer. -func GetInt(key string) int { return _c.GetInt(key) } -func (c *configurator) GetInt(key string) int { - return cast.ToInt(c.Get(key)) -} +// 1. file +// 2. configPath/file +func (c *config) getConfigFiles(files ...string) []string { + var results []string -// GetInt64 returns the value associated with the key as an integer. -func GetInt64(key string) int64 { return _c.GetInt64(key) } -func (c *configurator) GetInt64(key string) int64 { - return cast.ToInt64(c.Get(key)) -} - -// GetFloat64 returns the value associated with the key as a float64. -func GetFloat64(key string) float64 { return _c.GetFloat64(key) } -func (c *configurator) GetFloat64(key string) float64 { - return cast.ToFloat64(c.Get(key)) -} - -// GetTime returns the value associated with the key as time. -func GetTime(key string) time.Time { return _c.GetTime(key) } -func (c *configurator) GetTime(key string) time.Time { - return cast.ToTime(c.Get(key)) -} - -// GetDuration returns the value associated with the key as a duration. -func GetDuration(key string) time.Duration { return _c.GetDuration(key) } -func (c *configurator) GetDuration(key string) time.Duration { - return cast.ToDuration(c.Get(key)) -} - -// GetStringSlice returns the value associated with the key as a slice of strings. -func GetStringSlice(key string) []string { return _c.GetStringSlice(key) } -func (c *configurator) GetStringSlice(key string) []string { - return cast.ToStringSlice(c.Get(key)) -} - -// GetStringMap returns the value associated with the key as a map of interfaces. -func GetStringMap(key string) map[string]interface{} { return _c.GetStringMap(key) } -func (c *configurator) GetStringMap(key string) map[string]interface{} { - return cast.ToStringMap(c.Get(key)) -} - -// GetStringMapString returns the value associated with the key as a map of strings. -func GetStringMapString(key string) map[string]string { return _c.GetStringMapString(key) } -func (c *configurator) GetStringMapString(key string) map[string]string { - return cast.ToStringMapString(c.Get(key)) -} - -// GetStringMapStringSlice returns the value associated with the key as a map to a slice of strings. -func GetStringMapStringSlice(key string) map[string][]string { return _c.GetStringMapStringSlice(key) } -func (c *configurator) GetStringMapStringSlice(key string) map[string][]string { - return cast.ToStringMapStringSlice(c.Get(key)) -} - -// GetSizeInBytes returns the size of the value associated with the given key -// in bytes. -func GetSizeInBytes(key string) uint { return _c.GetSizeInBytes(key) } -func (c *configurator) GetSizeInBytes(key string) uint { - sizeStr := cast.ToString(c.Get(key)) - return parseSizeInBytes(sizeStr) -} - -// Sub returns new Viper instance representing a sub tree of this instance. -// Sub is case-insensitive for a key. -func Sub(key string) Configurator { return _c.Sub(key) } -func (c *configurator) Sub(key string) Configurator { - subv := New() - data := c.Get(key) - if data == nil { - return nil + for _, file := range files { + // check configuration + if exists(file) { + results = append(results, file) + } } - if reflect.TypeOf(data).Kind() == reflect.Map { - subv.(*configurator).config = cast.ToStringMap(data) - return subv + for _, file := range files { + // check configuration + pFile := filepath.Join(c.configPath, file) + if exists(pFile) { + results = append(results, pFile) + } } - return nil + + return results } -func (c *configurator) getConfigFile() (string, error) { - // if explicitly set, then use it - if c.configFile != "" { - return c.configFile, nil - } - - cf, err := c.findConfigFile() - if err != nil { - return "", err - } - - c.configFile = cf - return c.getConfigFile() -} - -func (c *configurator) getConfigType() string { - if c.configType != "" { - return c.configType - } - - cf, err := c.getConfigFile() - if err != nil { +func (c *config) getENVPrefix() string { + if c.envPrefix == "" { + if prefix := os.Getenv(ConfigEnvPrefix); prefix != "" { + return prefix + } return "" } - - ext := filepath.Ext(cf) - - if len(ext) > 1 { - return ext[1:] - } - - return "" -} - -func (c *configurator) findConfigFile() (string, error) { - for _, cp := range c.configPaths { - file := c.searchInPath(cp) - if file != "" { - return file, nil - } - } - return "", ConfigFileNotFoundError{c.configName, fmt.Sprintf("%s", c.configPaths)} -} - -func (c *configurator) searchInPath(in string) (filename string) { - for _, ext := range SupportedExts { - if b, _ := exists(filepath.Join(in, c.configName+"."+ext)); b { - return filepath.Join(in, c.configName+"."+ext) - } - } - - return "" -} - -func unmarshalReader(in io.Reader, config map[string]interface{}) error { - return _c.unmarshalReader(in, config) -} - -func (c *configurator) unmarshalReader(in io.Reader, config map[string]interface{}) error { - return unmarshallConfigReader(in, config, c.getConfigType()) -} - -func (c *configurator) insensitiviseMaps() { - insensitiviseMap(c.config) - insensitiviseMap(c.defaults) - insensitiviseMap(c.override) - insensitiviseMap(c.kvstore) -} - -func (c *configurator) realKey(key string) string { - newkey, exists := c.aliases[key] - if exists { - return c.realKey(newkey) - } - return key -} - -func (c *configurator) searchMap(source map[string]interface{}, path []string) interface{} { - if len(path) == 0 { - return source - } - - next, ok := source[path[0]] - if ok { - // Fast path - if len(path) == 1 { - return next - } - - // Nested case - switch next.(type) { - case map[interface{}]interface{}: - return c.searchMap(cast.ToStringMap(next), path[1:]) - case map[string]interface{}: - // Type assertion is safe here since it is only reached - // if the type of `next` is the same as the type being asserted - return c.searchMap(next.(map[string]interface{}), path[1:]) - default: - // got a value but nested key expected, return "nil" for not found - return nil - } - } - return nil -} - -// Marshal marshals the config to []byte. -func Marshal(configType string) ([]byte, error) { return _c.Marshal(configType) } -func (c *configurator) Marshal(configType string) ([]byte, error) { - - return marshallConfig(c.config, configType) -} - -// UnmarshalKey takes a single key and unmarshals it into a Struct. -func UnmarshalKey(key string, rawVal interface{}) error { return _c.UnmarshalKey(key, rawVal) } -func (c *configurator) UnmarshalKey(key string, rawVal interface{}) error { - err := decode(c.Get(key), defaultDecoderConfig(rawVal)) - - if err != nil { - return err - } - - c.insensitiviseMaps() - - return nil -} - -// Unmarshal unmarshals the config into a Struct. Make sure that the tags -// on the fields of the structure are properly set. -func Unmarshal(rawVal interface{}) error { return _c.Unmarshal(rawVal) } -func (c *configurator) Unmarshal(rawVal interface{}) error { - err := decode(c.AllSettings(), defaultDecoderConfig(rawVal)) - - if err != nil { - return err - } - - c.insensitiviseMaps() - - return nil -} - -// defaultDecoderConfig returns default mapsstructure.DecoderConfig with suppot -// of time.Duration values -func defaultDecoderConfig(output interface{}) *mapstructure.DecoderConfig { - return &mapstructure.DecoderConfig{ - Metadata: nil, - Result: output, - WeaklyTypedInput: true, - DecodeHook: mapstructure.StringToTimeDurationHookFunc(), - } -} - -// A wrapper around mapstructure.Decode that mimics the WeakDecode functionality -func decode(input interface{}, config *mapstructure.DecoderConfig) error { - decoder, err := mapstructure.NewDecoder(config) - if err != nil { - return err - } - return decoder.Decode(input) -} - -// UnmarshalExact unmarshals the config into a Struct, erroring if a field is nonexistent -// in the destination struct. -func (c *configurator) UnmarshalExact(rawVal interface{}) error { - config := defaultDecoderConfig(rawVal) - config.ErrorUnused = true - - err := decode(c.AllSettings(), config) - - if err != nil { - return err - } - - c.insensitiviseMaps() - - return nil -} - -// BindPFlags binds a full flag set to the configuration, using each flag's long -// name as the config key. -func BindPFlags(flags *pflag.FlagSet) error { return _c.BindPFlags(flags) } -func (c *configurator) BindPFlags(flags *pflag.FlagSet) error { - return c.BindFlagValues(pflagValueSet{flags}) -} - -// BindPFlag binds a specific key to a pflag (as used by cobra). -// Example (where serverCmd is a Cobra instance): -// -// serverCmd.Flags().Int("port", 1138, "Port to run Application server on") -// Viper.BindPFlag("port", serverCmd.Flags().Lookup("port")) -// -func BindPFlag(key string, flag *pflag.Flag) error { return _c.BindPFlag(key, flag) } -func (c *configurator) BindPFlag(key string, flag *pflag.Flag) error { - return c.BindFlagValue(key, pflagValue{flag}) -} - -// BindFlagValues binds a full FlagValue set to the configuration, using each flag's long -// name as the config key. -func BindFlagValues(flags FlagValueSet) error { return _c.BindFlagValues(flags) } -func (c *configurator) BindFlagValues(flags FlagValueSet) (err error) { - flags.VisitAll(func(flag FlagValue) { - if err = c.BindFlagValue(flag.Name(), flag); err != nil { - return - } - }) - return nil -} - -// BindFlagValue binds a specific key to a FlagValue. -// Example (where serverCmd is a Cobra instance): -// -// serverCmd.Flags().Int("port", 1138, "Port to run Application server on") -// Viper.BindFlagValue("port", serverCmd.Flags().Lookup("port")) -// -func BindFlagValue(key string, flag FlagValue) error { return _c.BindFlagValue(key, flag) } -func (c *configurator) BindFlagValue(key string, flag FlagValue) error { - if flag == nil { - return fmt.Errorf("flag for %q is nil", key) - } - c.pflags[strings.ToLower(key)] = flag - return nil -} - -// BindEnv binds a Viper key to a ENV variable. -// ENV variables are case sensitive. -// If only a key is provided, it will use the env key matching the key, uppercased. -// EnvPrefix will be used when set when env name is not provided. -func BindEnv(input ...string) error { return _c.BindEnv(input...) } -func (c *configurator) BindEnv(input ...string) error { - var key, envkey string - if len(input) == 0 { - return fmt.Errorf("BindEnv missing key to bind to") - } - - key = strings.ToLower(input[0]) - - if len(input) == 1 { - envkey = c.mergeWithEnvPrefix(key) - } else { - envkey = input[1] - } - - c.env[key] = envkey - - return nil -} - -func (c *configurator) find(lcaseKey string) interface{} { - - var ( - val interface{} - exists bool - path = strings.Split(lcaseKey, c.keyDelim) - nested = len(path) > 1 - ) - - // compute the path through the nested maps to the nested value - if nested && c.isPathShadowedInDeepMap(path, castMapStringToMapInterface(c.aliases)) != "" { - return nil - } - - // if the requested key is an alias, then return the proper key - lcaseKey = c.realKey(lcaseKey) - path = strings.Split(lcaseKey, c.keyDelim) - nested = len(path) > 1 - - // Set() override first - val = c.searchMap(c.override, path) - if val != nil { - return val - } - if nested && c.isPathShadowedInDeepMap(path, c.override) != "" { - return nil - } - - // PFlag override next - flag, exists := c.pflags[lcaseKey] - if exists && flag.HasChanged() { - switch flag.ValueType() { - case "int", "int8", "int16", "int32", "int64": - return cast.ToInt(flag.ValueString()) - case "bool": - return cast.ToBool(flag.ValueString()) - case "stringSlice": - s := strings.TrimPrefix(flag.ValueString(), "[") - s = strings.TrimSuffix(s, "]") - res, _ := readAsCSV(s) - return res - default: - return flag.ValueString() - } - } - if nested && c.isPathShadowedInFlatMap(path, c.pflags) != "" { - return nil - } - - // Env override next - if c.automaticEnvApplied { - // even if it hasn't been registered, if automaticEnv is used, - // check any Get request - if val = c.getEnv(c.mergeWithEnvPrefix(lcaseKey)); val != "" { - return val - } - if nested && c.isPathShadowedInAutoEnv(path) != "" { - return nil - } - } - envkey, exists := c.env[lcaseKey] - if exists { - if val = c.getEnv(envkey); val != "" { - return val - } - } - if nested && c.isPathShadowedInFlatMap(path, c.env) != "" { - return nil - } - - // Config file next - val = c.searchMapWithPathPrefixes(c.config, path) - if val != nil { - return val - } - if nested && c.isPathShadowedInDeepMap(path, c.config) != "" { - return nil - } - - // K/V store next - val = c.searchMap(c.kvstore, path) - if val != nil { - return val - } - if nested && c.isPathShadowedInDeepMap(path, c.kvstore) != "" { - return nil - } - - // Default next - val = c.searchMap(c.defaults, path) - if val != nil { - return val - } - if nested && c.isPathShadowedInDeepMap(path, c.defaults) != "" { - return nil - } - - // last chance: if no other value is returned and a flag does exist for the value, - // get the flag's value even if the flag's value has not changed - if flag, exists := c.pflags[lcaseKey]; exists { - switch flag.ValueType() { - case "int", "int8", "int16", "int32", "int64": - return cast.ToInt(flag.ValueString()) - case "bool": - return cast.ToBool(flag.ValueString()) - case "stringSlice": - s := strings.TrimPrefix(flag.ValueString(), "[") - s = strings.TrimSuffix(s, "]") - res, _ := readAsCSV(s) - return res - default: - return flag.ValueString() - } - } - // last item, no need to check shadowing - - return nil -} - -func (c *configurator) searchMapWithPathPrefixes(source map[string]interface{}, path []string) interface{} { - if len(path) == 0 { - return source - } - - // search for path prefixes, starting from the longest one - for i := len(path); i > 0; i-- { - prefixKey := strings.ToLower(strings.Join(path[0:i], c.keyDelim)) - - next, ok := source[prefixKey] - if ok { - // Fast path - if i == len(path) { - return next - } - - // Nested case - var val interface{} - switch next.(type) { - case map[interface{}]interface{}: - val = c.searchMapWithPathPrefixes(cast.ToStringMap(next), path[i:]) - case map[string]interface{}: - // Type assertion is safe here since it is only reached - // if the type of `next` is the same as the type being asserted - val = c.searchMapWithPathPrefixes(next.(map[string]interface{}), path[i:]) - default: - // got a value but nested key expected, do nothing and look for next prefix - } - if val != nil { - return val - } - } - } - - // not found - return nil -} - -func (c *configurator) isPathShadowedInDeepMap(path []string, m map[string]interface{}) string { - var parentVal interface{} - for i := 1; i < len(path); i++ { - parentVal = c.searchMap(m, path[0:i]) - if parentVal == nil { - // not found, no need to add more path elements - return "" - } - switch parentVal.(type) { - case map[interface{}]interface{}: - continue - case map[string]interface{}: - continue - default: - // parentVal is a regular value which shadows "path" - return strings.Join(path[0:i], c.keyDelim) - } - } - return "" -} - -func (c *configurator) isPathShadowedInFlatMap(path []string, mi interface{}) string { - // unify input map - var m map[string]interface{} - switch mi.(type) { - case map[string]string, map[string]FlagValue: - m = cast.ToStringMap(mi) - default: - return "" - } - - // scan paths - var parentKey string - for i := 1; i < len(path); i++ { - parentKey = strings.Join(path[0:i], c.keyDelim) - if _, ok := m[parentKey]; ok { - return parentKey - } - } - return "" -} - -func (c *configurator) isPathShadowedInAutoEnv(path []string) string { - var parentKey string - var val string - for i := 1; i < len(path); i++ { - parentKey = strings.Join(path[0:i], c.keyDelim) - if val = c.getEnv(c.mergeWithEnvPrefix(parentKey)); val != "" { - return parentKey - } - } - return "" -} - -func (c *configurator) getEnv(key string) string { - if c.envKeyReplacer != nil { - key = c.envKeyReplacer.Replace(key) - } - return os.Getenv(key) -} - -func (c *configurator) mergeWithEnvPrefix(in string) string { - if c.envPrefix != "" { - return strings.ToUpper(c.envPrefix + "_" + in) - } - - return strings.ToUpper(in) -} - -func castToMapStringInterface( - src map[interface{}]interface{}) map[string]interface{} { - tgt := map[string]interface{}{} - for k, v := range src { - tgt[fmt.Sprintf("%v", k)] = v - } - return tgt -} - -func castMapStringToMapInterface(src map[string]string) map[string]interface{} { - tgt := map[string]interface{}{} - for k, v := range src { - tgt[k] = v - } - return tgt -} - -func castMapFlagToMapInterface(src map[string]FlagValue) map[string]interface{} { - tgt := map[string]interface{}{} - for k, v := range src { - tgt[k] = v - } - return tgt -} - -func readAsCSV(val string) ([]string, error) { - if val == "" { - return []string{}, nil - } - stringReader := strings.NewReader(val) - csvReader := csv.NewReader(stringReader) - return csvReader.Read() -} - -// AllSettings merges all settings and returns them as a map[string]interface{}. -func AllSettings() map[string]interface{} { return _c.AllSettings() } -func (c *configurator) AllSettings() map[string]interface{} { - m := map[string]interface{}{} - // start from the list of keys, and construct the map one value at a time - for _, k := range c.AllKeys() { - value := c.Get(k) - if value == nil { - // should not happen, since AllKeys() returns only keys holding a value, - // check just in case anything changes - continue - } - path := strings.Split(k, c.keyDelim) - lastKey := strings.ToLower(path[len(path)-1]) - deepestMap := deepSearch(m, path[0:len(path)-1]) - // set innermost value - deepestMap[lastKey] = value - } - return m -} - -// AllKeys returns all keys holding a value, regardless of where they are set. -// Nested keys are returned with a v.keyDelim (= ".") separator -func AllKeys() []string { return _c.AllKeys() } -func (c *configurator) AllKeys() []string { - m := map[string]bool{} - // add all paths, by order of descending priority to ensure correct shadowing - m = c.flattenAndMergeMap(m, castMapStringToMapInterface(c.aliases), "") - m = c.flattenAndMergeMap(m, c.override, "") - m = c.mergeFlatMap(m, castMapFlagToMapInterface(c.pflags)) - m = c.mergeFlatMap(m, castMapStringToMapInterface(c.env)) - m = c.flattenAndMergeMap(m, c.config, "") - m = c.flattenAndMergeMap(m, c.kvstore, "") - m = c.flattenAndMergeMap(m, c.defaults, "") - - // convert set of paths to list - a := []string{} - for x := range m { - a = append(a, x) - } - return a -} - -// flattenAndMergeMap recursively flattens the given map into a map[string]bool -// of key paths (used as a set, easier to manipulate than a []string): -// - each path is merged into a single key string, delimited with v.keyDelim (= ".") -// - if a path is shadowed by an earlier value in the initial shadow map, -// it is skipped. -// The resulting set of paths is merged to the given shadow set at the same time. -func (c *configurator) flattenAndMergeMap(shadow map[string]bool, m map[string]interface{}, prefix string) map[string]bool { - if shadow != nil && prefix != "" && shadow[prefix] { - // prefix is shadowed => nothing more to flatten - return shadow - } - if shadow == nil { - shadow = make(map[string]bool) - } - - var m2 map[string]interface{} - if prefix != "" { - prefix += c.keyDelim - } - for k, val := range m { - fullKey := prefix + k - switch val.(type) { - case map[string]interface{}: - m2 = val.(map[string]interface{}) - case map[interface{}]interface{}: - m2 = cast.ToStringMap(val) - default: - // immediate value - shadow[strings.ToLower(fullKey)] = true - continue - } - // recursively merge to shadow map - shadow = c.flattenAndMergeMap(shadow, m2, fullKey) - } - return shadow -} - -// mergeFlatMap merges the given maps, excluding values of the second map -// shadowed by values from the first map. -func (c *configurator) mergeFlatMap(shadow map[string]bool, m map[string]interface{}) map[string]bool { - // scan keys -outer: - for k, _ := range m { - path := strings.Split(k, c.keyDelim) - // scan intermediate paths - var parentKey string - for i := 1; i < len(path); i++ { - parentKey = strings.Join(path[0:i], c.keyDelim) - if shadow[parentKey] { - // path is shadowed, continue - continue outer - } - } - // add key - shadow[strings.ToLower(k)] = true - } - return shadow + return c.envPrefix } diff --git a/flags.go b/flags.go deleted file mode 100644 index 324dfec..0000000 --- a/flags.go +++ /dev/null @@ -1,57 +0,0 @@ -package config - -import "github.com/spf13/pflag" - -// FlagValueSet is an interface that users can implement -// to bind a set of flags to viper. -type FlagValueSet interface { - VisitAll(fn func(FlagValue)) -} - -// FlagValue is an interface that users can implement -// to bind different flags to viper. -type FlagValue interface { - HasChanged() bool - Name() string - ValueString() string - ValueType() string -} - -// pflagValueSet is a wrapper around *pflag.ValueSet -// that implements FlagValueSet. -type pflagValueSet struct { - flags *pflag.FlagSet -} - -// VisitAll iterates over all *pflag.Flag inside the *pflag.FlagSet. -func (p pflagValueSet) VisitAll(fn func(flag FlagValue)) { - p.flags.VisitAll(func(flag *pflag.Flag) { - fn(pflagValue{flag}) - }) -} - -// pflagValue is a wrapper aroung *pflag.flag -// that implements FlagValue -type pflagValue struct { - flag *pflag.Flag -} - -// HasChanges returns whether the flag has changes or not. -func (p pflagValue) HasChanged() bool { - return p.flag.Changed -} - -// Name returns the name of the flag. -func (p pflagValue) Name() string { - return p.flag.Name -} - -// ValueString returns the value of the flag as a string. -func (p pflagValue) ValueString() string { - return p.flag.Value.String() -} - -// ValueType returns the type of the flag as a string. -func (p pflagValue) ValueType() string { - return p.flag.Value.Type() -} diff --git a/glide.yaml b/glide.yaml index cbdf1ce..d985129 100644 --- a/glide.yaml +++ b/glide.yaml @@ -1,16 +1,4 @@ package: git.loafle.net/commons_go/config import: -- package: github.com/spf13/afero +- package: github.com/BurntSushi/toml - package: gopkg.in/yaml.v2 -- package: github.com/hashicorp/hcl -- package: github.com/pelletier/go-toml - version: v1.0.0 -- package: github.com/magiconair/properties - version: v1.7.3 -- package: github.com/spf13/cast - version: v1.1.0 -- package: github.com/spf13/pflag - version: v1.0.0 -- package: github.com/mitchellh/mapstructure -- package: github.com/fsnotify/fsnotify - version: v1.4.2 diff --git a/util.go b/util.go index 4a09872..594bd14 100644 --- a/util.go +++ b/util.go @@ -3,41 +3,19 @@ package config import ( "bytes" "encoding/json" + "errors" "fmt" - "io" + "io/ioutil" "os" "path/filepath" + "reflect" "runtime" "strings" - "unicode" - "github.com/hashicorp/hcl" - "github.com/magiconair/properties" - toml "github.com/pelletier/go-toml" - "github.com/spf13/cast" + "github.com/BurntSushi/toml" yaml "gopkg.in/yaml.v2" ) -// ConfigParseError denotes failing to parse configuration file. -type ConfigParseError struct { - err error -} - -// Error returns the formatted configuration error. -func (pe ConfigParseError) Error() string { - return fmt.Sprintf("While parsing config: %s", pe.err.Error()) -} - -// ConfigMarshalError denotes failing to marshal configuration. -type ConfigMarshalError struct { - err error -} - -// Error returns the formatted marshal error. -func (me ConfigMarshalError) Error() string { - return fmt.Sprintf("While marshaling config: %s", me.err.Error()) -} - func absPathify(inPath string) (string, error) { if strings.HasPrefix(inPath, "$HOME") { inPath = userHomeDir() + inPath[5:] @@ -71,221 +49,157 @@ func userHomeDir() string { return os.Getenv("HOME") } -func stringInSlice(a string, list []string) bool { - for _, b := range list { - if b == a { - return true - } +func exists(path string) bool { + if fileInfo, err := os.Stat(path); err == nil && fileInfo.Mode().IsRegular() { + return true } return false } -func exists(path string) (bool, error) { - _, err := _c.fs.Stat(path) - if err == nil { - return true, nil +func unmarshalFile(target interface{}, file string) error { + data, err := ioutil.ReadFile(file) + if err != nil { + return err } - if os.IsNotExist(err) { - return false, nil - } - return false, err + + return unmarshalData(target, filepath.Ext(file), data) } -func marshallConfig(c map[string]interface{}, configType string) ([]byte, error) { - var ( - buf []byte - err error - ) - - switch strings.ToLower(configType) { - case "yaml", "yml": - if buf, err = yaml.Marshal(c); err != nil { - return nil, ConfigMarshalError{err} +func unmarshalData(target interface{}, ext string, data []byte) error { + switch ext { + case ".yaml", ".yml": + return yaml.Unmarshal(data, target) + case ".toml": + return toml.Unmarshal(data, target) + case ".json": + return json.Unmarshal(data, target) + default: + if toml.Unmarshal(data, target) != nil { + if json.Unmarshal(data, target) != nil { + if yaml.Unmarshal(data, target) != nil { + return errors.New("failed to decode config") + } + } } - - case "json": - if buf, err = json.Marshal(c); err != nil { - return nil, ConfigMarshalError{err} - } - - case "toml": - if buf, err = toml.Marshal(c); err != nil { - return nil, ConfigMarshalError{err} - } - - } - - return buf, nil -} - -func unmarshallConfigReader(in io.Reader, c map[string]interface{}, configType string) error { - buf := new(bytes.Buffer) - buf.ReadFrom(in) - - switch strings.ToLower(configType) { - case "yaml", "yml": - if err := yaml.Unmarshal(buf.Bytes(), &c); err != nil { - return ConfigParseError{err} - } - - case "json": - if err := json.Unmarshal(buf.Bytes(), &c); err != nil { - return ConfigParseError{err} - } - - case "hcl": - obj, err := hcl.Parse(string(buf.Bytes())) - if err != nil { - return ConfigParseError{err} - } - if err = hcl.DecodeObject(&c, obj); err != nil { - return ConfigParseError{err} - } - - case "toml": - tree, err := toml.LoadReader(buf) - if err != nil { - return ConfigParseError{err} - } - tmap := tree.ToMap() - for k, v := range tmap { - c[k] = v - } - - case "properties", "props", "prop": - var p *properties.Properties - var err error - if p, err = properties.Load(buf.Bytes(), properties.UTF8); err != nil { - return ConfigParseError{err} - } - for _, key := range p.Keys() { - value, _ := p.Get(key) - // recursively build nested maps - path := strings.Split(key, ".") - lastKey := strings.ToLower(path[len(path)-1]) - deepestMap := deepSearch(c, path[0:len(path)-1]) - // set innermost value - deepestMap[lastKey] = value - } - } - - insensitiviseMap(c) - return nil -} - -func insensitiviseMap(m map[string]interface{}) { - for key, val := range m { - switch val.(type) { - case map[interface{}]interface{}: - // nested map: cast and recursively insensitivise - val = cast.ToStringMap(val) - insensitiviseMap(val.(map[string]interface{})) - case map[string]interface{}: - // nested map: recursively insensitivise - insensitiviseMap(val.(map[string]interface{})) - } - - lower := strings.ToLower(key) - if key != lower { - // remove old key (not lower-cased) - delete(m, key) - } - // update map - m[lower] = val + return nil } } -func deepSearch(m map[string]interface{}, path []string) map[string]interface{} { - for _, k := range path { - m2, ok := m[k] - if !ok { - // intermediate key does not exist - // => create it and continue from there - m3 := make(map[string]interface{}) - m[k] = m3 - m = m3 +func unmarshalTags(target interface{}, prefixes ...string) error { + targetValue := reflect.Indirect(reflect.ValueOf(target)) + if targetValue.Kind() != reflect.Struct { + return errors.New("invalid config, should be struct") + } + + targetType := targetValue.Type() + for i := 0; i < targetType.NumField(); i++ { + var ( + envNames []string + fieldStruct = targetType.Field(i) + field = targetValue.Field(i) + envName = fieldStruct.Tag.Get(ConfigTagEnv) // read configuration from shell env + ) + if !field.CanAddr() || !field.CanInterface() { continue } - m3, ok := m2.(map[string]interface{}) - if !ok { - // intermediate key is a value - // => replace with a new map - m3 = make(map[string]interface{}) - m[k] = m3 + if envName == "" { + envNames = append(envNames, strings.Join(append(prefixes, fieldStruct.Name), "_")) // Configor_DB_Name + envNames = append(envNames, strings.ToUpper(strings.Join(append(prefixes, fieldStruct.Name), "_"))) // CONFIGOR_DB_NAME + } else { + envNames = []string{envName} } - // continue search from here - m = m3 - } - return m -} - -func toCaseInsensitiveValue(value interface{}) interface{} { - switch v := value.(type) { - case map[interface{}]interface{}: - value = copyAndInsensitiviseMap(cast.ToStringMap(v)) - case map[string]interface{}: - value = copyAndInsensitiviseMap(v) - } - - return value -} - -func copyAndInsensitiviseMap(m map[string]interface{}) map[string]interface{} { - nm := make(map[string]interface{}) - - for key, val := range m { - lkey := strings.ToLower(key) - switch v := val.(type) { - case map[interface{}]interface{}: - nm[lkey] = copyAndInsensitiviseMap(cast.ToStringMap(v)) - case map[string]interface{}: - nm[lkey] = copyAndInsensitiviseMap(v) - default: - nm[lkey] = v + // Load From Shell ENV + for _, env := range envNames { + if value := os.Getenv(env); value != "" { + if err := yaml.Unmarshal([]byte(value), field.Addr().Interface()); err != nil { + return err + } + break + } + } + if isBlank := reflect.DeepEqual(field.Interface(), reflect.Zero(field.Type()).Interface()); isBlank { + // Set default configuration if blank + if value := fieldStruct.Tag.Get("default"); value != "" { + if err := yaml.Unmarshal([]byte(value), field.Addr().Interface()); err != nil { + return err + } + } else if fieldStruct.Tag.Get("required") == "true" { + // return error if it is required but blank + return fmt.Errorf("Field[%s] is required", fieldStruct.Name) + } + } + for field.Kind() == reflect.Ptr { + field = field.Elem() + } + if field.Kind() == reflect.Struct { + if err := unmarshalTags(field.Addr().Interface(), getPrefixForStruct(prefixes, &fieldStruct)...); err != nil { + return err + } } - } - return nm -} - -func safeMul(a, b uint) uint { - c := a * b - if a > 1 && b > 1 && c/b != a { - return 0 - } - return c -} - -func parseSizeInBytes(sizeStr string) uint { - sizeStr = strings.TrimSpace(sizeStr) - lastChar := len(sizeStr) - 1 - multiplier := uint(1) - - if lastChar > 0 { - if sizeStr[lastChar] == 'b' || sizeStr[lastChar] == 'B' { - if lastChar > 1 { - switch unicode.ToLower(rune(sizeStr[lastChar-1])) { - case 'k': - multiplier = 1 << 10 - sizeStr = strings.TrimSpace(sizeStr[:lastChar-1]) - case 'm': - multiplier = 1 << 20 - sizeStr = strings.TrimSpace(sizeStr[:lastChar-1]) - case 'g': - multiplier = 1 << 30 - sizeStr = strings.TrimSpace(sizeStr[:lastChar-1]) - default: - multiplier = 1 - sizeStr = strings.TrimSpace(sizeStr[:lastChar]) + if field.Kind() == reflect.Slice { + for i := 0; i < field.Len(); i++ { + if reflect.Indirect(field.Index(i)).Kind() == reflect.Struct { + if err := unmarshalTags(field.Index(i).Addr().Interface(), append(getPrefixForStruct(prefixes, &fieldStruct), fmt.Sprint(i))...); err != nil { + return err + } } } } } - size := cast.ToInt(sizeStr) - if size < 0 { - size = 0 + return nil +} + +func getPrefixForStruct(prefixes []string, fieldStruct *reflect.StructField) []string { + if fieldStruct.Anonymous && fieldStruct.Tag.Get("anonymous") == "true" { + return prefixes + } + return append(prefixes, fieldStruct.Name) +} + +func marshalFile(target interface{}, file string, overWrite bool) error { + var f *os.File + var err error + if exists(file) { + if !overWrite { + return fmt.Errorf("Config: File[%s] is exist", file) + } + if f, err = os.Open(file); nil != err { + return err + } + } else { + if f, err = os.Create(file); nil != err { + return err + } } - return safeMul(uint(size), multiplier) + var b []byte + if b, err = marshal(target, filepath.Ext(file)); nil != err { + return err + } + + if _, err = f.Write(b); nil != err { + return err + } + return nil + +} + +func marshal(target interface{}, ext string) ([]byte, error) { + switch ext { + case ".yaml", ".yml": + return yaml.Marshal(target) + case ".toml": + var buf bytes.Buffer + enc := toml.NewEncoder(&buf) + if err := enc.Encode(target); nil != err { + return nil, err + } + return buf.Bytes(), nil + case ".json": + return json.Marshal(target) + default: + return nil, fmt.Errorf("Config: Not supported extention[%s]", ext) + } }