Adds MergeConfig functionality
This patch adds the `MergeConfig` and `MergeInConfig` functions to
enable reading new configuration files via a merge strategy rather
than replace. For example, take the following as the base YAML for a
configuration:
hello:
pop: 37890
world:
- us
- uk
- fr
- de
Now imagine we want to read the following, new configuration data:
hello:
pop: 45000
universe:
- mw
- ad
fu: bar
Using the standard `ReadConfig` function the value returned by the
nested key `hello.world` would no longer be present after the second
configuration is read. This is because the `ReadConfig` function and
its relatives replace nested structures entirely.
The new `MergeConfig` function would produce the following config
after the second YAML snippet was merged with the first:
hello:
pop: 45000
world:
- us
- uk
- fr
- de
universe:
- mw
- ad
fu: bar
Examples showing how this works can be found in the two unit tests
named `TestMergeConfig` and `TestMergeConfigNoMerge`.
diff --git a/viper.go b/viper.go
index b253c75..7c98c6f 100644
--- a/viper.go
+++ b/viper.go
@@ -901,12 +901,124 @@
return v.unmarshalReader(bytes.NewReader(file), v.config)
}
+// MergeInConfig merges a new configuration with an existing config.
+func MergeInConfig() error { return v.MergeInConfig() }
+func (v *Viper) MergeInConfig() error {
+ jww.INFO.Println("Attempting to merge in config file")
+ if !stringInSlice(v.getConfigType(), SupportedExts) {
+ return UnsupportedConfigError(v.getConfigType())
+ }
+
+ file, err := ioutil.ReadFile(v.getConfigFile())
+ if err != nil {
+ return err
+ }
+
+ return v.MergeConfig(bytes.NewReader(file))
+}
+
+// Viper will read a configuration file, setting existing keys to nil if the
+// key does not exist in the file.
func ReadConfig(in io.Reader) error { return v.ReadConfig(in) }
func (v *Viper) ReadConfig(in io.Reader) error {
v.config = make(map[string]interface{})
return v.unmarshalReader(in, v.config)
}
+// MergeConfig merges a new configuration with an existing config.
+func MergeConfig(in io.Reader) error { return v.MergeConfig(in) }
+func (v *Viper) MergeConfig(in io.Reader) error {
+ if v.config == nil {
+ v.config = make(map[string]interface{})
+ }
+ cfg := make(map[string]interface{})
+ if err := v.unmarshalReader(in, cfg); err != nil {
+ return err
+ }
+ mergeMaps(cfg, v.config, nil)
+ return nil
+}
+
+func keyExists(k string, m map[string]interface{}) string {
+ lk := strings.ToLower(k)
+ for mk := range m {
+ lmk := strings.ToLower(mk)
+ if lmk == lk {
+ return mk
+ }
+ }
+ return ""
+}
+
+func castToMapStringInterface(
+ src map[interface{}]interface{}) map[string]interface{} {
+ tgt := map[string]interface{}{}
+ for k, v := range src {
+ tgt[fmt.Sprintf("%v", k)] = v
+ }
+ return tgt
+}
+
+// mergeMaps merges two maps. The `itgt` parameter is for handling go-yaml's
+// insistence on parsing nested structures as `map[interface{}]interface{}`
+// instead of using a `string` as the key for nest structures beyond one level
+// deep. Both map types are supported as there is a go-yaml fork that uses
+// `map[string]interface{}` instead.
+func mergeMaps(
+ src, tgt map[string]interface{}, itgt map[interface{}]interface{}) {
+ for sk, sv := range src {
+ tk := keyExists(sk, tgt)
+ if tk == "" {
+ jww.TRACE.Printf("tk=\"\", tgt[%s]=%v", sk, sv)
+ tgt[sk] = sv
+ if itgt != nil {
+ itgt[sk] = sv
+ }
+ continue
+ }
+
+ tv, ok := tgt[tk]
+ if !ok {
+ jww.TRACE.Printf("tgt[%s] != ok, tgt[%s]=%v", tk, sk, sv)
+ tgt[sk] = sv
+ if itgt != nil {
+ itgt[sk] = sv
+ }
+ continue
+ }
+
+ svType := reflect.TypeOf(sv)
+ tvType := reflect.TypeOf(tv)
+ if svType != tvType {
+ jww.ERROR.Printf(
+ "svType != tvType; key=%s, st=%v, tt=%v, sv=%v, tv=%v",
+ sk, svType, tvType, sv, tv)
+ continue
+ }
+
+ jww.TRACE.Printf("processing key=%s, st=%v, tt=%v, sv=%v, tv=%v",
+ sk, svType, tvType, sv, tv)
+
+ switch ttv := tv.(type) {
+ case map[interface{}]interface{}:
+ jww.TRACE.Printf("merging maps (must convert)")
+ tsv := sv.(map[interface{}]interface{})
+ ssv := castToMapStringInterface(tsv)
+ stv := castToMapStringInterface(ttv)
+ mergeMaps(ssv, stv, ttv)
+ case map[string]interface{}:
+ jww.TRACE.Printf("merging maps")
+ mergeMaps(sv.(map[string]interface{}), ttv, nil)
+ default:
+ jww.TRACE.Printf("setting value")
+ tgt[tk] = sv
+ if itgt != nil {
+ itgt[tk] = sv
+ }
+ }
+ }
+}
+
// func ReadBufConfig(buf *bytes.Buffer) error { return v.ReadBufConfig(buf) }
// func (v *Viper) ReadBufConfig(buf *bytes.Buffer) error {
// v.config = make(map[string]interface{})
diff --git a/viper_test.go b/viper_test.go
index 0c71549..e9bb3f4 100644
--- a/viper_test.go
+++ b/viper_test.go
@@ -737,3 +737,101 @@
assert.Equal(t, subv, (*Viper)(nil))
}
+var yamlMergeExampleTgt = []byte(`
+hello:
+ pop: 37890
+ world:
+ - us
+ - uk
+ - fr
+ - de
+`)
+
+var yamlMergeExampleSrc = []byte(`
+hello:
+ pop: 45000
+ universe:
+ - mw
+ - ad
+fu: bar
+`)
+
+func TestMergeConfig(t *testing.T) {
+ v := New()
+ v.SetConfigType("yml")
+ if err := v.ReadConfig(bytes.NewBuffer(yamlMergeExampleTgt)); err != nil {
+ t.Fatal(err)
+ }
+
+ if pop := v.GetInt("hello.pop"); pop != 37890 {
+ t.Fatalf("pop != 37890, = %d", pop)
+ }
+
+ if world := v.GetStringSlice("hello.world"); len(world) != 4 {
+ t.Fatalf("len(world) != 4, = %d", len(world))
+ }
+
+ if fu := v.GetString("fu"); fu != "" {
+ t.Fatalf("fu != \"\", = %s", fu)
+ }
+
+ if err := v.MergeConfig(bytes.NewBuffer(yamlMergeExampleSrc)); err != nil {
+ t.Fatal(err)
+ }
+
+ if pop := v.GetInt("hello.pop"); pop != 45000 {
+ t.Fatalf("pop != 45000, = %d", pop)
+ }
+
+ if world := v.GetStringSlice("hello.world"); len(world) != 4 {
+ t.Fatalf("len(world) != 4, = %d", len(world))
+ }
+
+ if universe := v.GetStringSlice("hello.universe"); len(universe) != 2 {
+ t.Fatalf("len(universe) != 2, = %d", len(universe))
+ }
+
+ if fu := v.GetString("fu"); fu != "bar" {
+ t.Fatalf("fu != \"bar\", = %s", fu)
+ }
+}
+
+func TestMergeConfigNoMerge(t *testing.T) {
+ v := New()
+ v.SetConfigType("yml")
+ if err := v.ReadConfig(bytes.NewBuffer(yamlMergeExampleTgt)); err != nil {
+ t.Fatal(err)
+ }
+
+ if pop := v.GetInt("hello.pop"); pop != 37890 {
+ t.Fatalf("pop != 37890, = %d", pop)
+ }
+
+ if world := v.GetStringSlice("hello.world"); len(world) != 4 {
+ t.Fatalf("len(world) != 4, = %d", len(world))
+ }
+
+ if fu := v.GetString("fu"); fu != "" {
+ t.Fatalf("fu != \"\", = %s", fu)
+ }
+
+ if err := v.ReadConfig(bytes.NewBuffer(yamlMergeExampleSrc)); err != nil {
+ t.Fatal(err)
+ }
+
+ if pop := v.GetInt("hello.pop"); pop != 45000 {
+ t.Fatalf("pop != 45000, = %d", pop)
+ }
+
+ if world := v.GetStringSlice("hello.world"); len(world) != 0 {
+ t.Fatalf("len(world) != 0, = %d", len(world))
+ }
+
+ if universe := v.GetStringSlice("hello.universe"); len(universe) != 2 {
+ t.Fatalf("len(universe) != 2, = %d", len(universe))
+ }
+
+ if fu := v.GetString("fu"); fu != "bar" {
+ t.Fatalf("fu != \"bar\", = %s", fu)
+ }
+}