// config/util.go

package config
import (
	"bytes"
	"encoding/json"
	"errors"
	"fmt"
	"io/ioutil"
	"os"
	"path/filepath"
	"reflect"
	"runtime"
	"strings"

	"github.com/BurntSushi/toml"
	yaml "gopkg.in/yaml.v2"
)
// absPathify expands a leading "$HOME" or "$VAR" reference in inPath and
// returns the cleaned absolute form of the result.
//
// "$HOME" is only recognised as a whole path element ("$HOME" or
// "$HOME<sep>..."). Any other leading "$NAME" is looked up with os.Getenv,
// where NAME runs up to the first os.PathSeparator (or to the end of the
// string when there is none). Unset variables expand to the empty string.
// Relative results are resolved against the current working directory via
// filepath.Abs, whose error is returned unchanged.
func absPathify(inPath string) (string, error) {
	const homeVar = "$HOME"
	// Require a separator (or end of string) after "$HOME" so that variables
	// such as "$HOMEPATH" are not rewritten to home + "PATH".
	if strings.HasPrefix(inPath, homeVar) &&
		(len(inPath) == len(homeVar) || os.IsPathSeparator(inPath[len(homeVar)])) {
		inPath = userHomeDir() + inPath[len(homeVar):]
	}

	if strings.HasPrefix(inPath, "$") {
		end := strings.Index(inPath, string(os.PathSeparator))
		if end < 0 {
			// No separator: the whole remainder is the variable name.
			// Slicing with end == -1 would panic.
			end = len(inPath)
		}
		inPath = os.Getenv(inPath[1:end]) + inPath[end:]
	}

	if filepath.IsAbs(inPath) {
		return filepath.Clean(inPath), nil
	}
	// filepath.Abs already returns a cleaned path, or "" plus the error.
	return filepath.Abs(inPath)
}

// userHomeDir returns the current user's home directory as reported by the
// environment. On Windows it uses HOMEDRIVE+HOMEPATH, falling back to
// USERPROFILE when both are empty. Everywhere else it uses HOME. The result
// is "" when the relevant variables are unset.
func userHomeDir() string {
	if runtime.GOOS != "windows" {
		return os.Getenv("HOME")
	}
	if home := os.Getenv("HOMEDRIVE") + os.Getenv("HOMEPATH"); home != "" {
		return home
	}
	return os.Getenv("USERPROFILE")
}
// exists reports whether path names an existing regular file. Directories,
// other non-regular entries, and paths that cannot be stat'ed all yield false.
func exists(path string) bool {
	info, err := os.Stat(path)
	if err != nil {
		return false
	}
	return info.Mode().IsRegular()
}
func unmarshalFile(target interface{}, file string) error {
data, err := ioutil.ReadFile(file)
if err != nil {
return err
2017-08-30 09:24:22 +00:00
}
2017-08-30 11:28:19 +00:00
return unmarshalData(target, filepath.Ext(file), data)
}
2017-08-30 11:28:19 +00:00
func unmarshalData(target interface{}, ext string, data []byte) error {
switch ext {
case ".yaml", ".yml":
return yaml.Unmarshal(data, target)
case ".toml":
return toml.Unmarshal(data, target)
case ".json":
return json.Unmarshal(data, target)
default:
if toml.Unmarshal(data, target) != nil {
if json.Unmarshal(data, target) != nil {
if yaml.Unmarshal(data, target) != nil {
return errors.New("failed to decode config")
}
}
2017-08-30 11:28:19 +00:00
}
return nil
2017-08-30 11:28:19 +00:00
}
}
func unmarshalTags(target interface{}, prefixes ...string) error {
targetValue := reflect.Indirect(reflect.ValueOf(target))
if targetValue.Kind() != reflect.Struct {
return errors.New("invalid config, should be struct")
}
2017-08-30 09:24:22 +00:00
targetType := targetValue.Type()
for i := 0; i < targetType.NumField(); i++ {
var (
envNames []string
fieldStruct = targetType.Field(i)
field = targetValue.Field(i)
envName = fieldStruct.Tag.Get(ConfigTagEnv) // read configuration from shell env
)
if !field.CanAddr() || !field.CanInterface() {
continue
2017-08-30 09:24:22 +00:00
}
if envName == "" {
envNames = append(envNames, strings.Join(append(prefixes, fieldStruct.Name), "_")) // Configor_DB_Name
envNames = append(envNames, strings.ToUpper(strings.Join(append(prefixes, fieldStruct.Name), "_"))) // CONFIGOR_DB_NAME
} else {
envNames = []string{envName}
}
// Load From Shell ENV
for _, env := range envNames {
if value := os.Getenv(env); value != "" {
if err := yaml.Unmarshal([]byte(value), field.Addr().Interface()); err != nil {
return err
}
break
}
2017-08-30 09:24:22 +00:00
}
if isBlank := reflect.DeepEqual(field.Interface(), reflect.Zero(field.Type()).Interface()); isBlank {
// Set default configuration if blank
if value := fieldStruct.Tag.Get("default"); value != "" {
if err := yaml.Unmarshal([]byte(value), field.Addr().Interface()); err != nil {
return err
}
} else if fieldStruct.Tag.Get("required") == "true" {
// return error if it is required but blank
return fmt.Errorf("Field[%s] is required", fieldStruct.Name)
}
2017-08-30 09:24:22 +00:00
}
for field.Kind() == reflect.Ptr {
field = field.Elem()
2017-08-30 09:24:22 +00:00
}
if field.Kind() == reflect.Struct {
if err := unmarshalTags(field.Addr().Interface(), getPrefixForStruct(prefixes, &fieldStruct)...); err != nil {
return err
}
2017-08-30 09:24:22 +00:00
}
if field.Kind() == reflect.Slice {
for i := 0; i < field.Len(); i++ {
if reflect.Indirect(field.Index(i)).Kind() == reflect.Struct {
if err := unmarshalTags(field.Index(i).Addr().Interface(), append(getPrefixForStruct(prefixes, &fieldStruct), fmt.Sprint(i))...); err != nil {
return err
}
}
}
2017-08-30 09:24:22 +00:00
}
}
return nil
}
// getPrefixForStruct returns the env-name prefix chain to use for the nested
// struct held in fieldStruct. An embedded field tagged `anonymous:"true"`
// adds no segment of its own. Every other field appends its Go field name.
// The result comes from append, so it may share a backing array with prefixes.
func getPrefixForStruct(prefixes []string, fieldStruct *reflect.StructField) []string {
	flatten := fieldStruct.Anonymous && fieldStruct.Tag.Get("anonymous") == "true"
	if !flatten {
		prefixes = append(prefixes, fieldStruct.Name)
	}
	return prefixes
}
func marshalFile(target interface{}, file string, overWrite bool) error {
var f *os.File
var err error
if exists(file) {
if !overWrite {
return fmt.Errorf("Config: File[%s] is exist", file)
2017-08-30 09:24:22 +00:00
}
if f, err = os.Open(file); nil != err {
return err
}
} else {
if f, err = os.Create(file); nil != err {
return err
2017-08-30 09:24:22 +00:00
}
}
var b []byte
if b, err = marshal(target, filepath.Ext(file)); nil != err {
return err
2017-08-30 09:24:22 +00:00
}
if _, err = f.Write(b); nil != err {
return err
2017-08-30 09:24:22 +00:00
}
return nil
2017-08-30 09:24:22 +00:00
}
func marshal(target interface{}, ext string) ([]byte, error) {
switch ext {
case ".yaml", ".yml":
return yaml.Marshal(target)
case ".toml":
var buf bytes.Buffer
enc := toml.NewEncoder(&buf)
if err := enc.Encode(target); nil != err {
return nil, err
2017-08-30 09:24:22 +00:00
}
return buf.Bytes(), nil
case ".json":
return json.Marshal(target)
default:
return nil, fmt.Errorf("Config: Not supported extention[%s]", ext)
2017-08-30 09:24:22 +00:00
}
}