From c4c05732f55d5fe21a42eb7ddd4b6909ec066d93 Mon Sep 17 00:00:00 2001 From: Tai Groot Date: Sat, 24 Jan 2026 20:29:43 -0500 Subject: [PATCH] update to fixup some race conditions --- README.md | 119 +++++- default.go | 4 +- doc.go | 31 ++ getters.go | 37 +- go.mod | 2 +- jety.go | 61 ++- jety_test.go | 1146 ++++++++++++++++++++++++++++++++++++++++++++++++++ 7 files changed, 1358 insertions(+), 42 deletions(-) create mode 100644 doc.go create mode 100644 jety_test.go diff --git a/README.md b/README.md index 3990736..1d84c39 100644 --- a/README.md +++ b/README.md @@ -2,8 +2,121 @@ JSON, ENV, TOML, YAML -This is a package for collapsing multiple configuration stores (env+json, env+yaml, env+toml) and writing them back to a centralized config. +A lightweight Go configuration management library supporting JSON, ENV, TOML, and YAML formats. +It provides viper-like `AutomaticEnv` functionality with fewer dependencies. +Originally built to support [grlx](http://github.com/gogrlx/grlx). -It should behave similarly to the AutomaticEnv functionality of viper, but without some of the extra heft of the depedendencies it carries. +## Installation -The inital purpose of this repo is to support the configuration requirements of [grlx](http://github.com/gogrlx/grlx), but development may continue to expand until more viper use cases and functionality are covered. +```bash +go get github.com/taigrr/jety +``` + +Requires Go 1.25.5 or later. + +## Quick Start + +```go +package main + +import "github.com/taigrr/jety" + +func main() { + // Set defaults + jety.SetDefault("port", 8080) + jety.SetDefault("host", "localhost") + + // Environment variables are loaded automatically + // e.g., PORT=9000 overrides the default + + // Read from config file + jety.SetConfigFile("config.toml") + jety.SetConfigType("toml") + if err := jety.ReadInConfig(); err != nil { + // handle error + } + + // Get values (config file > env > default) + port := jety.GetInt("port") + host := jety.GetString("host") +} +``` + +## Features + +- **Multiple formats**: JSON, TOML, YAML +- **Automatic env loading**: Environment variables loaded on init +- **Prefix filtering**: Filter env vars by prefix (e.g., `MYAPP_`) +- **Case-insensitive keys**: Keys normalized to lowercase +- **Type coercion**: Getters handle type conversion gracefully +- **Thread-safe**: Safe for concurrent access +- **Config precedence**: config file > environment > defaults + +## Migration Guide + +### From v0.x to v1.x + +#### Breaking Changes + +1. **`WriteConfig()` now returns `error`** + + ```go + // Before + jety.WriteConfig() + + // After + if err := jety.WriteConfig(); err != nil { + // handle error + } + // Or if you want to ignore the error: + _ = jety.WriteConfig() + ``` + +2. **Go 1.25.5 minimum required** + + Update your Go version or pin to an older jety release. + +#### Non-Breaking Improvements + +- Getters (`GetBool`, `GetInt`, `GetDuration`) now return zero values instead of panicking on unknown types +- Added `int64` support in `GetInt`, `GetIntSlice`, and `GetDuration` +- Improved env var parsing (handles values containing `=`) + +## API + +### Configuration + +| Function | Description | +| --------------------- | --------------------------------------------- | +| `SetConfigFile(path)` | Set config file path | +| `SetConfigDir(dir)` | Set config directory | +| `SetConfigName(name)` | Set config file name (without extension) | +| `SetConfigType(type)` | Set config type: `"toml"`, `"yaml"`, `"json"` | +| `ReadInConfig()` | Read config file | +| `WriteConfig()` | Write config to file | + +### Values + +| Function | Description | +| ------------------------ | ------------------------ | +| `Set(key, value)` | Set a value | +| `SetDefault(key, value)` | Set a default value | +| `Get(key)` | Get raw value | +| `GetString(key)` | Get as string | +| `GetInt(key)` | Get as int | +| `GetBool(key)` | Get as bool | +| `GetDuration(key)` | Get as time.Duration | +| `GetStringSlice(key)` | Get as []string | +| `GetIntSlice(key)` | Get as []int | +| `GetStringMap(key)` | Get as map[string]string | + +### Environment + +| Function | Description | +| ----------------------- | --------------------------------------------------- | +| `WithEnvPrefix(prefix)` | Filter env vars by prefix (strips prefix from keys) | +| `SetEnvPrefix(prefix)` | Set prefix for env var lookups | + +## License + +See [LICENSE](LICENSE) file. diff --git a/default.go b/default.go index f50c71f..8e36d87 100644 --- a/default.go +++ b/default.go @@ -40,8 +40,8 @@ func Set(key string, value any) { defaultConfigManager.Set(key, value) } -func WriteConfig() { - defaultConfigManager.WriteConfig() +func WriteConfig() error { + return defaultConfigManager.WriteConfig() } func ConfigFileUsed() string { diff --git a/doc.go b/doc.go new file mode 100644 index 0000000..e6e7885 --- /dev/null +++ b/doc.go @@ -0,0 +1,31 @@ +// Package jety provides configuration management supporting JSON, ENV, TOML, and YAML formats. +// +// It offers viper-like AutomaticEnv functionality with minimal dependencies, allowing +// configuration to be loaded from files and environment variables with automatic merging. +// +// Configuration sources are layered with the following precedence (highest to lowest): +// - Values set via Set() or SetString()/SetBool() +// - Environment variables (optionally filtered by prefix) +// - Values from config file via ReadInConfig() +// - Default values set via SetDefault() +// +// Basic usage: +// +// jety.SetConfigFile("/etc/myapp/config.yaml") +// jety.SetConfigType("yaml") +// jety.SetEnvPrefix("MYAPP_") +// jety.SetDefault("port", 8080) +// +// if err := jety.ReadInConfig(); err != nil { +// log.Fatal(err) +// } +// +// port := jety.GetInt("port") +// +// For multiple independent configurations, create separate ConfigManager instances: +// +// cm := jety.NewConfigManager() +// cm.SetConfigFile("/etc/myapp/config.toml") +// cm.SetConfigType("toml") +// cm.ReadInConfig() +package jety diff --git a/getters.go b/getters.go index 33cad95..b45c224 100644 --- a/getters.go +++ b/getters.go @@ -35,29 +35,19 @@ func (c *ConfigManager) GetBool(key string) bool { case bool: return val case string: - if strings.ToLower(val) == "true" { - return true - } - return false + return strings.EqualFold(val, "true") case int: - if val == 0 { - return false - } - return true - case float32, float64: - if val == 0 { - return false - } - return true + return val != 0 + case float32: + return val != 0 + case float64: + return val != 0 + case time.Duration: + return val > 0 case nil: return false - case time.Duration: - if val == 0 || val < 0 { - return false - } - return true default: - return val.(bool) + return false } } @@ -83,6 +73,8 @@ func (c *ConfigManager) GetDuration(key string) time.Duration { return d case int: return time.Duration(val) + case int64: + return time.Duration(val) case float32: return time.Duration(val) case float64: @@ -90,8 +82,7 @@ func (c *ConfigManager) GetDuration(key string) time.Duration { case nil: return 0 default: - return val.(time.Duration) - + return 0 } } @@ -174,6 +165,8 @@ func (c *ConfigManager) GetInt(key string) int { switch val := v.Value.(type) { case int: return val + case int64: + return int(val) case string: i, err := strconv.Atoi(val) if err != nil { @@ -210,6 +203,8 @@ func (c *ConfigManager) GetIntSlice(key string) []int { switch v := v.(type) { case int: ret = append(ret, v) + case int64: + ret = append(ret, int(v)) case string: i, err := strconv.Atoi(v) if err != nil { diff --git a/go.mod b/go.mod index b956201..fb68b5c 100644 --- a/go.mod +++ b/go.mod @@ -1,6 +1,6 @@ module github.com/taigrr/jety -go 1.21.3 +go 1.25.5 require ( github.com/BurntSushi/toml v1.3.2 diff --git a/jety.go b/jety.go index b105279..10e1350 100644 --- a/jety.go +++ b/jety.go @@ -4,7 +4,9 @@ import ( "encoding/json" "errors" "fmt" + "maps" "os" + "path/filepath" "strings" "sync" @@ -55,9 +57,12 @@ func NewConfigManager() *ConfigManager { cm.envPrefix = "" envSet := os.Environ() for _, env := range envSet { - kv := strings.Split(env, "=") - lower := strings.ToLower(kv[0]) - cm.envConfig[lower] = ConfigMap{Key: kv[0], Value: kv[1]} + key, value, found := strings.Cut(env, "=") + if !found { + continue + } + lower := strings.ToLower(key) + cm.envConfig[lower] = ConfigMap{Key: key, Value: value} } return &cm } @@ -65,17 +70,20 @@ func NewConfigManager() *ConfigManager { func (c *ConfigManager) WithEnvPrefix(prefix string) *ConfigManager { c.mutex.Lock() defer c.mutex.Unlock() - c.envPrefix = prefix envSet := os.Environ() c.envConfig = make(map[string]ConfigMap) for _, env := range envSet { - kv := strings.Split(env, "=") - if strings.HasPrefix(kv[0], prefix) { - withoutPrefix := strings.TrimPrefix(kv[0], prefix) + key, value, found := strings.Cut(env, "=") + if !found { + continue + } + if withoutPrefix, ok := strings.CutPrefix(key, prefix); ok { lower := strings.ToLower(withoutPrefix) - c.envConfig[lower] = ConfigMap{Key: withoutPrefix, Value: kv[1]} + c.envConfig[lower] = ConfigMap{Key: withoutPrefix, Value: value} } } + // Don't set envPrefix since keys are already stripped of prefix + c.envPrefix = "" return c } @@ -92,8 +100,8 @@ func (c *ConfigManager) UseExplicitDefaults(enable bool) { } func (c *ConfigManager) collapse() { - c.mutex.RLock() - defer c.mutex.RUnlock() + c.mutex.Lock() + defer c.mutex.Unlock() ccm := make(map[string]ConfigMap) for k, v := range c.defaultConfig { ccm[k] = v @@ -101,9 +109,7 @@ func (c *ConfigManager) collapse() { ccm[k] = c.envConfig[k] } } - for k, v := range c.mapConfig { - ccm[k] = v - } + maps.Copy(ccm, c.mapConfig) c.combinedConfig = ccm } @@ -147,6 +153,8 @@ func (c *ConfigManager) WriteConfig() error { } func (c *ConfigManager) SetConfigType(configType string) error { + c.mutex.Lock() + defer c.mutex.Unlock() switch configType { case "toml": c.configType = ConfigTypeTOML @@ -161,12 +169,34 @@ func (c *ConfigManager) SetConfigType(configType string) error { } func (c *ConfigManager) SetEnvPrefix(prefix string) { + c.mutex.Lock() + defer c.mutex.Unlock() c.envPrefix = prefix } func (c *ConfigManager) ReadInConfig() error { - // assume config = map[string]any - confFileData, err := readFile(c.configFileUsed, c.configType) + c.mutex.RLock() + configFile := c.configFileUsed + if configFile == "" && c.configPath != "" && c.configName != "" { + ext := "" + switch c.configType { + case ConfigTypeTOML: + ext = ".toml" + case ConfigTypeYAML: + ext = ".yaml" + case ConfigTypeJSON: + ext = ".json" + } + configFile = filepath.Join(c.configPath, c.configName+ext) + } + configType := c.configType + c.mutex.RUnlock() + + if configFile == "" { + return errors.New("no config file specified: use SetConfigFile or SetConfigDir + SetConfigName") + } + + confFileData, err := readFile(configFile, configType) if err != nil { return err } @@ -177,6 +207,7 @@ func (c *ConfigManager) ReadInConfig() error { } c.mutex.Lock() c.mapConfig = conf + c.configFileUsed = configFile c.mutex.Unlock() c.collapse() return nil diff --git a/jety_test.go b/jety_test.go new file mode 100644 index 0000000..a52a9e2 --- /dev/null +++ b/jety_test.go @@ -0,0 +1,1146 @@ +package jety + +import ( + "os" + "path/filepath" + "sync" + "testing" + "time" +) + +// Test file contents +const ( + tomlConfig = ` +port = 8080 +host = "localhost" +debug = true +timeout = "30s" +rate = 1.5 +tags = ["api", "v1"] +counts = [1, 2, 3] + +[database] +host = "db.example.com" +port = 5432 +` + + yamlConfig = ` +port: 9090 +host: "yaml-host" +debug: false +timeout: "1m" +rate: 2.5 +tags: + - web + - v2 +counts: + - 10 + - 20 +database: + host: "yaml-db.example.com" + port: 3306 +` + + jsonConfig = `{ + "port": 7070, + "host": "json-host", + "debug": true, + "timeout": "15s", + "rate": 3.5, + "tags": ["json", "v3"], + "counts": [100, 200], + "database": { + "host": "json-db.example.com", + "port": 27017 + } +}` +) + +func TestNewConfigManager(t *testing.T) { + cm := NewConfigManager() + if cm == nil { + t.Fatal("NewConfigManager returned nil") + } + if cm.envConfig == nil { + t.Error("envConfig not initialized") + } + if cm.mapConfig == nil { + t.Error("mapConfig not initialized") + } + if cm.defaultConfig == nil { + t.Error("defaultConfig not initialized") + } + if cm.combinedConfig == nil { + t.Error("combinedConfig not initialized") + } +} + +func TestSetAndGetString(t *testing.T) { + cm := NewConfigManager() + cm.Set("name", "test-value") + + got := cm.GetString("name") + if got != "test-value" { + t.Errorf("GetString() = %q, want %q", got, "test-value") + } + + // Case insensitive + got = cm.GetString("NAME") + if got != "test-value" { + t.Errorf("GetString(NAME) = %q, want %q", got, "test-value") + } +} + +func TestSetAndGetInt(t *testing.T) { + cm := NewConfigManager() + + tests := []struct { + name string + value any + want int + }{ + {"int", 42, 42}, + {"string", "123", 123}, + {"float64", 99.9, 99}, + {"float32", float32(50.5), 50}, + {"invalid string", "not-a-number", 0}, + {"nil", nil, 0}, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + cm.Set("key", tt.value) + got := cm.GetInt("key") + if got != tt.want { + t.Errorf("GetInt() = %d, want %d", got, tt.want) + } + }) + } +} + +func TestSetAndGetBool(t *testing.T) { + cm := NewConfigManager() + + tests := []struct { + name string + value any + want bool + }{ + {"true bool", true, true}, + {"false bool", false, false}, + {"string true", "true", true}, + {"string TRUE", "TRUE", true}, + {"string false", "false", false}, + {"string other", "yes", false}, + {"int zero", 0, false}, + {"int nonzero", 1, true}, + {"float32 zero", float32(0), false}, + {"float32 nonzero", float32(1.5), true}, + {"float64 zero", float64(0), false}, + {"float64 nonzero", float64(1.5), true}, + {"duration zero", time.Duration(0), false}, + {"duration positive", time.Second, true}, + {"nil", nil, false}, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + cm.Set("key", tt.value) + got := cm.GetBool("key") + if got != tt.want { + t.Errorf("GetBool() = %v, want %v", got, tt.want) + } + }) + } +} + +func TestSetAndGetDuration(t *testing.T) { + cm := NewConfigManager() + + tests := []struct { + name string + value any + want time.Duration + }{ + {"duration", 5 * time.Second, 5 * time.Second}, + {"string", "10s", 10 * time.Second}, + {"string minutes", "2m", 2 * time.Minute}, + {"invalid string", "not-duration", 0}, + {"int", 1000, time.Duration(1000)}, + {"int64", int64(2000), time.Duration(2000)}, + {"float64", float64(3000), time.Duration(3000)}, + {"float32", float32(4000), time.Duration(4000)}, + {"nil", nil, 0}, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + cm.Set("key", tt.value) + got := cm.GetDuration("key") + if got != tt.want { + t.Errorf("GetDuration() = %v, want %v", got, tt.want) + } + }) + } +} + +func TestSetAndGetStringSlice(t *testing.T) { + cm := NewConfigManager() + + // Direct string slice + cm.Set("tags", []string{"a", "b", "c"}) + got := cm.GetStringSlice("tags") + if len(got) != 3 || got[0] != "a" || got[1] != "b" || got[2] != "c" { + t.Errorf("GetStringSlice() = %v, want [a b c]", got) + } + + // []any slice + cm.Set("mixed", []any{"x", 123, "z"}) + got = cm.GetStringSlice("mixed") + if len(got) != 3 || got[0] != "x" || got[1] != "123" || got[2] != "z" { + t.Errorf("GetStringSlice() = %v, want [x 123 z]", got) + } + + // Non-slice returns nil + cm.Set("notslice", "single") + got = cm.GetStringSlice("notslice") + if got != nil { + t.Errorf("GetStringSlice() = %v, want nil", got) + } +} + +func TestSetAndGetIntSlice(t *testing.T) { + cm := NewConfigManager() + + // Direct int slice + cm.Set("nums", []int{1, 2, 3}) + got := cm.GetIntSlice("nums") + if len(got) != 3 || got[0] != 1 || got[1] != 2 || got[2] != 3 { + t.Errorf("GetIntSlice() = %v, want [1 2 3]", got) + } + + // []any slice with mixed types + cm.Set("mixed", []any{10, "20", float64(30), float32(40)}) + got = cm.GetIntSlice("mixed") + if len(got) != 4 || got[0] != 10 || got[1] != 20 || got[2] != 30 || got[3] != 40 { + t.Errorf("GetIntSlice() = %v, want [10 20 30 40]", got) + } + + // Invalid entries skipped + cm.Set("invalid", []any{1, "not-a-number", nil, 2}) + got = cm.GetIntSlice("invalid") + if len(got) != 2 || got[0] != 1 || got[1] != 2 { + t.Errorf("GetIntSlice() = %v, want [1 2]", got) + } +} + +func TestSetAndGetStringMap(t *testing.T) { + cm := NewConfigManager() + + m := map[string]any{"foo": "bar", "num": 123} + cm.Set("config", m) + got := cm.GetStringMap("config") + if got["foo"] != "bar" || got["num"] != 123 { + t.Errorf("GetStringMap() = %v, want %v", got, m) + } + + // Non-map returns nil + cm.Set("notmap", "string") + got = cm.GetStringMap("notmap") + if got != nil { + t.Errorf("GetStringMap() = %v, want nil", got) + } +} + +func TestGet(t *testing.T) { + cm := NewConfigManager() + cm.Set("key", "value") + + got := cm.Get("key") + if got != "value" { + t.Errorf("Get() = %v, want %q", got, "value") + } + + // Non-existent key + got = cm.Get("nonexistent") + if got != nil { + t.Errorf("Get(nonexistent) = %v, want nil", got) + } +} + +func TestSetDefault(t *testing.T) { + cm := NewConfigManager() + + cm.SetDefault("port", 8080) + if got := cm.GetInt("port"); got != 8080 { + t.Errorf("GetInt(port) = %d, want 8080", got) + } + + // Set overrides default + cm.Set("port", 9090) + if got := cm.GetInt("port"); got != 9090 { + t.Errorf("GetInt(port) after Set = %d, want 9090", got) + } + + // New default doesn't override existing value + cm.SetDefault("port", 7070) + if got := cm.GetInt("port"); got != 9090 { + t.Errorf("GetInt(port) after second SetDefault = %d, want 9090", got) + } +} + +func TestReadTOMLConfig(t *testing.T) { + dir := t.TempDir() + configFile := filepath.Join(dir, "config.toml") + if err := os.WriteFile(configFile, []byte(tomlConfig), 0o644); err != nil { + t.Fatal(err) + } + + cm := NewConfigManager() + cm.SetConfigFile(configFile) + if err := cm.SetConfigType("toml"); err != nil { + t.Fatal(err) + } + + if err := cm.ReadInConfig(); err != nil { + t.Fatalf("ReadInConfig() error = %v", err) + } + + if got := cm.GetInt("port"); got != 8080 { + t.Errorf("GetInt(port) = %d, want 8080", got) + } + if got := cm.GetString("host"); got != "localhost" { + t.Errorf("GetString(host) = %q, want %q", got, "localhost") + } + if got := cm.GetBool("debug"); got != true { + t.Errorf("GetBool(debug) = %v, want true", got) + } + if got := cm.ConfigFileUsed(); got != configFile { + t.Errorf("ConfigFileUsed() = %q, want %q", got, configFile) + } +} + +func TestReadYAMLConfig(t *testing.T) { + dir := t.TempDir() + configFile := filepath.Join(dir, "config.yaml") + if err := os.WriteFile(configFile, []byte(yamlConfig), 0o644); err != nil { + t.Fatal(err) + } + + cm := NewConfigManager() + cm.SetConfigFile(configFile) + if err := cm.SetConfigType("yaml"); err != nil { + t.Fatal(err) + } + + if err := cm.ReadInConfig(); err != nil { + t.Fatalf("ReadInConfig() error = %v", err) + } + + if got := cm.GetInt("port"); got != 9090 { + t.Errorf("GetInt(port) = %d, want 9090", got) + } + if got := cm.GetString("host"); got != "yaml-host" { + t.Errorf("GetString(host) = %q, want %q", got, "yaml-host") + } + if got := cm.GetBool("debug"); got != false { + t.Errorf("GetBool(debug) = %v, want false", got) + } +} + +func TestReadJSONConfig(t *testing.T) { + dir := t.TempDir() + configFile := filepath.Join(dir, "config.json") + if err := os.WriteFile(configFile, []byte(jsonConfig), 0o644); err != nil { + t.Fatal(err) + } + + cm := NewConfigManager() + cm.SetConfigFile(configFile) + if err := cm.SetConfigType("json"); err != nil { + t.Fatal(err) + } + + if err := cm.ReadInConfig(); err != nil { + t.Fatalf("ReadInConfig() error = %v", err) + } + + if got := cm.GetInt("port"); got != 7070 { + t.Errorf("GetInt(port) = %d, want 7070", got) + } + if got := cm.GetString("host"); got != "json-host" { + t.Errorf("GetString(host) = %q, want %q", got, "json-host") + } + if got := cm.GetBool("debug"); got != true { + t.Errorf("GetBool(debug) = %v, want true", got) + } +} + +func TestReadConfigWithDirAndName(t *testing.T) { + dir := t.TempDir() + configFile := filepath.Join(dir, "myconfig.yaml") + if err := os.WriteFile(configFile, []byte(yamlConfig), 0o644); err != nil { + t.Fatal(err) + } + + cm := NewConfigManager() + cm.SetConfigDir(dir) + cm.SetConfigName("myconfig") + if err := cm.SetConfigType("yaml"); err != nil { + t.Fatal(err) + } + + if err := cm.ReadInConfig(); err != nil { + t.Fatalf("ReadInConfig() error = %v", err) + } + + if got := cm.GetInt("port"); got != 9090 { + t.Errorf("GetInt(port) = %d, want 9090", got) + } + if got := cm.ConfigFileUsed(); got != configFile { + t.Errorf("ConfigFileUsed() = %q, want %q", got, configFile) + } +} + +func TestReadConfigNoFileSpecified(t *testing.T) { + cm := NewConfigManager() + if err := cm.SetConfigType("yaml"); err != nil { + t.Fatal(err) + } + + err := cm.ReadInConfig() + if err == nil { + t.Error("ReadInConfig() expected error, got nil") + } +} + +func TestReadConfigFileNotFound(t *testing.T) { + cm := NewConfigManager() + cm.SetConfigFile("/nonexistent/path/config.yaml") + if err := cm.SetConfigType("yaml"); err != nil { + t.Fatal(err) + } + + err := cm.ReadInConfig() + if err != ErrConfigFileNotFound { + t.Errorf("ReadInConfig() error = %v, want ErrConfigFileNotFound", err) + } +} + +func TestReadConfigFileEmpty(t *testing.T) { + dir := t.TempDir() + configFile := filepath.Join(dir, "empty.yaml") + if err := os.WriteFile(configFile, []byte{}, 0o644); err != nil { + t.Fatal(err) + } + + cm := NewConfigManager() + cm.SetConfigFile(configFile) + if err := cm.SetConfigType("yaml"); err != nil { + t.Fatal(err) + } + + err := cm.ReadInConfig() + if err != ErrConfigFileEmpty { + t.Errorf("ReadInConfig() error = %v, want ErrConfigFileEmpty", err) + } +} + +func TestWriteConfig(t *testing.T) { + dir := t.TempDir() + + tests := []struct { + name string + configType string + ext string + }{ + {"TOML", "toml", ".toml"}, + {"YAML", "yaml", ".yaml"}, + {"JSON", "json", ".json"}, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + configFile := filepath.Join(dir, "write_test"+tt.ext) + + cm := NewConfigManager() + cm.SetConfigFile(configFile) + if err := cm.SetConfigType(tt.configType); err != nil { + t.Fatal(err) + } + + cm.Set("port", 8080) + cm.Set("host", "example.com") + + if err := cm.WriteConfig(); err != nil { + t.Fatalf("WriteConfig() error = %v", err) + } + + // Read it back + cm2 := NewConfigManager() + cm2.SetConfigFile(configFile) + if err := cm2.SetConfigType(tt.configType); err != nil { + t.Fatal(err) + } + if err := cm2.ReadInConfig(); err != nil { + t.Fatalf("ReadInConfig() error = %v", err) + } + + if got := cm2.GetInt("port"); got != 8080 { + t.Errorf("GetInt(port) = %d, want 8080", got) + } + if got := cm2.GetString("host"); got != "example.com" { + t.Errorf("GetString(host) = %q, want %q", got, "example.com") + } + }) + } +} + +func TestSetConfigTypeInvalid(t *testing.T) { + cm := NewConfigManager() + err := cm.SetConfigType("xml") + if err == nil { + t.Error("SetConfigType(xml) expected error, got nil") + } +} + +func TestEnvPrefix(t *testing.T) { + // Set env vars BEFORE creating ConfigManager + os.Setenv("TESTAPP_PORT", "3000") + os.Setenv("TESTAPP_HOST", "envhost") + os.Setenv("OTHER_VAR", "other") + defer func() { + os.Unsetenv("TESTAPP_PORT") + os.Unsetenv("TESTAPP_HOST") + os.Unsetenv("OTHER_VAR") + }() + + // Create new manager AFTER setting env vars, then apply prefix + cm := NewConfigManager().WithEnvPrefix("TESTAPP_") + + if got := cm.GetString("port"); got != "3000" { + t.Errorf("GetString(port) = %q, want %q", got, "3000") + } + if got := cm.GetString("host"); got != "envhost" { + t.Errorf("GetString(host) = %q, want %q", got, "envhost") + } + // OTHER_VAR should not be accessible without prefix + if got := cm.GetString("other_var"); got != "" { + t.Errorf("GetString(other_var) = %q, want empty", got) + } +} + +func TestEnvVarWithEqualsInValue(t *testing.T) { + os.Setenv("TEST_CONN", "host=localhost;user=admin") + defer os.Unsetenv("TEST_CONN") + + cm := NewConfigManager() + if got := cm.GetString("test_conn"); got != "host=localhost;user=admin" { + t.Errorf("GetString(test_conn) = %q, want %q", got, "host=localhost;user=admin") + } +} + +func TestEnvOverridesDefault(t *testing.T) { + os.Setenv("MYPORT", "5000") + defer os.Unsetenv("MYPORT") + + cm := NewConfigManager() + cm.SetDefault("myport", 8080) + + if got := cm.GetInt("myport"); got != 5000 { + t.Errorf("GetInt(myport) = %d, want 5000 (from env)", got) + } +} + +func TestConfigFileOverridesEnv(t *testing.T) { + os.Setenv("PORT", "5000") + defer os.Unsetenv("PORT") + + dir := t.TempDir() + configFile := filepath.Join(dir, "config.yaml") + if err := os.WriteFile(configFile, []byte("port: 9000"), 0o644); err != nil { + t.Fatal(err) + } + + cm := NewConfigManager() + cm.SetConfigFile(configFile) + if err := cm.SetConfigType("yaml"); err != nil { + t.Fatal(err) + } + cm.SetDefault("port", 8080) + + if err := cm.ReadInConfig(); err != nil { + t.Fatal(err) + } + + // Config file should override env and default + if got := cm.GetInt("port"); got != 9000 { + t.Errorf("GetInt(port) = %d, want 9000 (from file)", got) + } +} + +func TestCaseInsensitiveKeys(t *testing.T) { + cm := NewConfigManager() + cm.Set("MyKey", "value") + + tests := []string{"MyKey", "mykey", "MYKEY", "mYkEy"} + for _, key := range tests { + if got := cm.GetString(key); got != "value" { + t.Errorf("GetString(%q) = %q, want %q", key, got, "value") + } + } +} + +func TestGetNonExistentKey(t *testing.T) { + cm := NewConfigManager() + + if got := cm.GetString("nonexistent"); got != "" { + t.Errorf("GetString(nonexistent) = %q, want empty", got) + } + if got := cm.GetInt("nonexistent"); got != 0 { + t.Errorf("GetInt(nonexistent) = %d, want 0", got) + } + if got := cm.GetBool("nonexistent"); got != false { + t.Errorf("GetBool(nonexistent) = %v, want false", got) + } + if got := cm.GetDuration("nonexistent"); got != 0 { + t.Errorf("GetDuration(nonexistent) = %v, want 0", got) + } + if got := cm.GetStringSlice("nonexistent"); got != nil { + t.Errorf("GetStringSlice(nonexistent) = %v, want nil", got) + } + if got := cm.GetIntSlice("nonexistent"); got != nil { + t.Errorf("GetIntSlice(nonexistent) = %v, want nil", got) + } + if got := cm.GetStringMap("nonexistent"); got != nil { + t.Errorf("GetStringMap(nonexistent) = %v, want nil", got) + } +} + +func TestGetBoolUnknownType(t *testing.T) { + cm := NewConfigManager() + cm.Set("key", struct{}{}) + + // Should not panic, should return false + got := cm.GetBool("key") + if got != false { + t.Errorf("GetBool(struct) = %v, want false", got) + } +} + +func TestGetDurationUnknownType(t *testing.T) { + cm := NewConfigManager() + cm.Set("key", struct{}{}) + + // Should not panic, should return 0 + got := cm.GetDuration("key") + if got != 0 { + t.Errorf("GetDuration(struct) = %v, want 0", got) + } +} + +func TestConcurrentAccess(t *testing.T) { + cm := NewConfigManager() + var wg sync.WaitGroup + + // Concurrent writes + for i := range 100 { + wg.Add(1) + go func(n int) { + defer wg.Done() + cm.Set("key", n) + cm.SetDefault("default", n) + cm.SetString("str", "value") + cm.SetBool("bool", true) + }(i) + } + + // Concurrent reads + for range 100 { + wg.Go(func() { + _ = cm.GetInt("key") + _ = cm.GetString("str") + _ = cm.GetBool("bool") + _ = cm.Get("key") + _ = cm.ConfigFileUsed() + }) + } + + wg.Wait() +} + +func TestConcurrentReadWrite(t *testing.T) { + dir := t.TempDir() + configFile := filepath.Join(dir, "concurrent.yaml") + if err := os.WriteFile(configFile, []byte("port: 8080"), 0o644); err != nil { + t.Fatal(err) + } + + cm := NewConfigManager() + cm.SetConfigFile(configFile) + if err := cm.SetConfigType("yaml"); err != nil { + t.Fatal(err) + } + + var wg sync.WaitGroup + + // Reader goroutines + for range 50 { + wg.Go(func() { + for range 10 { + _ = cm.GetInt("port") + _ = cm.GetString("host") + } + }) + } + + // Writer goroutines + for i := range 50 { + wg.Add(1) + go func(n int) { + defer wg.Done() + for range 10 { + cm.Set("port", n) + cm.SetDefault("host", "localhost") + } + }(i) + } + + // Config operations + for range 10 { + wg.Go(func() { + _ = cm.ReadInConfig() + }) + } + + wg.Wait() +} + +// Package-level function tests (default.go) + +func TestPackageLevelFunctions(t *testing.T) { + // Reset default manager for this test + defaultConfigManager = NewConfigManager() + + dir := t.TempDir() + configFile := filepath.Join(dir, "pkg_test.yaml") + if err := os.WriteFile(configFile, []byte("port: 8888\nhost: pkghost"), 0o644); err != nil { + t.Fatal(err) + } + + SetConfigFile(configFile) + if err := SetConfigType("yaml"); err != nil { + t.Fatal(err) + } + SetDefault("timeout", "30s") + + if err := ReadInConfig(); err != nil { + t.Fatalf("ReadInConfig() error = %v", err) + } + + // Set() must be called AFTER ReadInConfig to override file values + Set("debug", true) + + if got := GetInt("port"); got != 8888 { + t.Errorf("GetInt(port) = %d, want 8888", got) + } + if got := GetString("host"); got != "pkghost" { + t.Errorf("GetString(host) = %q, want %q", got, "pkghost") + } + if got := GetBool("debug"); got != true { + t.Errorf("GetBool(debug) = %v, want true", got) + } + if got := GetDuration("timeout"); got != 30*time.Second { + t.Errorf("GetDuration(timeout) = %v, want 30s", got) + } + if got := ConfigFileUsed(); got != configFile { + t.Errorf("ConfigFileUsed() = %q, want %q", got, configFile) + } +} + +func TestUseExplicitDefaults(t *testing.T) { + cm := NewConfigManager() + cm.UseExplicitDefaults(true) + + // Just verify it doesn't panic and the field is set + cm.SetDefault("key", "value") + if got := cm.GetString("key"); got != "value" { + t.Errorf("GetString(key) = %q, want %q", got, "value") + } +} + +func TestSetString(t *testing.T) { + cm := NewConfigManager() + cm.SetString("name", "test") + + if got := cm.GetString("name"); got != "test" { + t.Errorf("GetString(name) = %q, want %q", got, "test") + } +} + +func TestSetBool(t *testing.T) { + cm := NewConfigManager() + cm.SetBool("enabled", true) + + if got := cm.GetBool("enabled"); got != true { + t.Errorf("GetBool(enabled) = %v, want true", got) + } +} + +func TestWriteConfigUnsupportedType(t *testing.T) { + dir := t.TempDir() + configFile := filepath.Join(dir, "test.txt") + + cm := NewConfigManager() + cm.SetConfigFile(configFile) + // Don't set config type + + err := cm.WriteConfig() + if err == nil { + t.Error("WriteConfig() expected error for unsupported type, got nil") + } +} + +func TestSetEnvPrefix(t *testing.T) { + cm := NewConfigManager() + cm.SetEnvPrefix("PREFIX_") + + // Verify it doesn't panic + if cm.envPrefix != "PREFIX_" { + t.Errorf("envPrefix = %q, want %q", cm.envPrefix, "PREFIX_") + } +} + +func TestDeeplyNestedConfig(t *testing.T) { + const nestedYAML = ` +app: + name: myapp + server: + host: localhost + port: 8080 + tls: + enabled: true + cert: /path/to/cert.pem + key: /path/to/key.pem + database: + primary: + host: db1.example.com + port: 5432 + credentials: + username: admin + password: secret + replicas: + - host: db2.example.com + port: 5432 + - host: db3.example.com + port: 5432 + features: + - name: feature1 + enabled: true + config: + timeout: 30s + retries: 3 + - name: feature2 + enabled: false +` + + const nestedTOML = ` +[app] +name = "myapp" + +[app.server] +host = "localhost" +port = 8080 + +[app.server.tls] +enabled = true +cert = "/path/to/cert.pem" +key = "/path/to/key.pem" + +[app.database.primary] +host = "db1.example.com" +port = 5432 + +[app.database.primary.credentials] +username = "admin" +password = "secret" + +[[app.database.replicas]] +host = "db2.example.com" +port = 5432 + +[[app.database.replicas]] +host = "db3.example.com" +port = 5432 + +[[app.features]] +name = "feature1" +enabled = true + +[app.features.config] +timeout = "30s" +retries = 3 + +[[app.features]] +name = "feature2" +enabled = false +` + + const nestedJSON = `{ + "app": { + "name": "myapp", + "server": { + "host": "localhost", + "port": 8080, + "tls": { + "enabled": true, + "cert": "/path/to/cert.pem", + "key": "/path/to/key.pem" + } + }, + "database": { + "primary": { + "host": "db1.example.com", + "port": 5432, + "credentials": { + "username": "admin", + "password": "secret" + } + }, + "replicas": [ + {"host": "db2.example.com", "port": 5432}, + {"host": "db3.example.com", "port": 5432} + ] + }, + "features": [ + { + "name": "feature1", + "enabled": true, + "config": { + "timeout": "30s", + "retries": 3 + } + }, + { + "name": "feature2", + "enabled": false + } + ] + } +}` + + tests := []struct { + name string + configType string + content string + ext string + }{ + {"YAML", "yaml", nestedYAML, ".yaml"}, + {"TOML", "toml", nestedTOML, ".toml"}, + {"JSON", "json", nestedJSON, ".json"}, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + dir := t.TempDir() + configFile := filepath.Join(dir, "nested"+tt.ext) + if err := os.WriteFile(configFile, []byte(tt.content), 0o644); err != nil { + t.Fatal(err) + } + + cm := NewConfigManager() + cm.SetConfigFile(configFile) + if err := cm.SetConfigType(tt.configType); err != nil { + t.Fatal(err) + } + + if err := cm.ReadInConfig(); err != nil { + t.Fatalf("ReadInConfig() error = %v", err) + } + + // Test that we can retrieve the top-level nested structure + app := cm.GetStringMap("app") + if app == nil { + t.Fatal("GetStringMap(app) = nil, want nested map") + } + + // Verify app.name exists + if name, ok := app["name"].(string); !ok || name != "myapp" { + t.Errorf("app.name = %v, want %q", app["name"], "myapp") + } + + // Verify nested server config + server, ok := app["server"].(map[string]any) + if !ok { + t.Fatalf("app.server is not a map: %T", app["server"]) + } + if server["host"] != "localhost" { + t.Errorf("app.server.host = %v, want %q", server["host"], "localhost") + } + + // Verify deeply nested TLS config + tls, ok := server["tls"].(map[string]any) + if !ok { + t.Fatalf("app.server.tls is not a map: %T", server["tls"]) + } + if tls["enabled"] != true { + t.Errorf("app.server.tls.enabled = %v, want true", tls["enabled"]) + } + if tls["cert"] != "/path/to/cert.pem" { + t.Errorf("app.server.tls.cert = %v, want %q", tls["cert"], "/path/to/cert.pem") + } + + // Verify database.primary.credentials (4 levels deep) + database, ok := app["database"].(map[string]any) + if !ok { + t.Fatalf("app.database is not a map: %T", app["database"]) + } + primary, ok := database["primary"].(map[string]any) + if !ok { + t.Fatalf("app.database.primary is not a map: %T", database["primary"]) + } + creds, ok := primary["credentials"].(map[string]any) + if !ok { + t.Fatalf("app.database.primary.credentials is not a map: %T", primary["credentials"]) + } + if creds["username"] != "admin" { + t.Errorf("credentials.username = %v, want %q", creds["username"], "admin") + } + + // Verify array of nested objects (replicas) + // TOML decodes to []map[string]interface{}, YAML/JSON to []any + var replicaHost any + switch r := database["replicas"].(type) { + case []any: + if len(r) != 2 { + t.Errorf("len(replicas) = %d, want 2", len(r)) + } + if len(r) > 0 { + replica0, ok := r[0].(map[string]any) + if !ok { + t.Fatalf("replicas[0] is not a map: %T", r[0]) + } + replicaHost = replica0["host"] + } + case []map[string]any: + if len(r) != 2 { + t.Errorf("len(replicas) = %d, want 2", len(r)) + } + if len(r) > 0 { + replicaHost = r[0]["host"] + } + default: + t.Fatalf("app.database.replicas unexpected type: %T", database["replicas"]) + } + if replicaHost != "db2.example.com" { + t.Errorf("replicas[0].host = %v, want %q", replicaHost, "db2.example.com") + } + + // Verify features array with nested config + // TOML decodes to []map[string]interface{}, YAML/JSON to []any + var featureName any + switch f := app["features"].(type) { + case []any: + if len(f) < 1 { + t.Fatal("features slice is empty") + } + feature0, ok := f[0].(map[string]any) + if !ok { + t.Fatalf("features[0] is not a map: %T", f[0]) + } + featureName = feature0["name"] + case []map[string]any: + if len(f) < 1 { + t.Fatal("features slice is empty") + } + featureName = f[0]["name"] + default: + t.Fatalf("app.features unexpected type: %T", app["features"]) + } + if featureName != "feature1" { + t.Errorf("features[0].name = %v, want %q", featureName, "feature1") + } + }) + } +} + +func TestDeeplyNestedWriteConfig(t *testing.T) { + dir := t.TempDir() + + tests := []struct { + name string + configType string + ext string + }{ + {"YAML", "yaml", ".yaml"}, + {"TOML", "toml", ".toml"}, + {"JSON", "json", ".json"}, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + configFile := filepath.Join(dir, "nested_write"+tt.ext) + + cm := NewConfigManager() + cm.SetConfigFile(configFile) + if err := cm.SetConfigType(tt.configType); err != nil { + t.Fatal(err) + } + + // Set a deeply nested structure + nested := map[string]any{ + "server": map[string]any{ + "host": "localhost", + "port": 8080, + "tls": map[string]any{ + "enabled": true, + "cert": "/path/to/cert.pem", + }, + }, + "database": map[string]any{ + "primary": map[string]any{ + "host": "db.example.com", + "port": 5432, + }, + }, + } + cm.Set("app", nested) + + if err := cm.WriteConfig(); err != nil { + t.Fatalf("WriteConfig() error = %v", err) + } + + // Read it back + cm2 := NewConfigManager() + cm2.SetConfigFile(configFile) + if err := cm2.SetConfigType(tt.configType); err != nil { + t.Fatal(err) + } + if err := cm2.ReadInConfig(); err != nil { + t.Fatalf("ReadInConfig() error = %v", err) + } + + // Verify nested structure was preserved + app := cm2.GetStringMap("app") + if app == nil { + t.Fatal("GetStringMap(app) = nil after read") + } + + server, ok := app["server"].(map[string]any) + if !ok { + t.Fatalf("app.server is not a map: %T", app["server"]) + } + if server["host"] != "localhost" { + t.Errorf("app.server.host = %v, want %q", server["host"], "localhost") + } + + tls, ok := server["tls"].(map[string]any) + if !ok { + t.Fatalf("app.server.tls is not a map: %T", server["tls"]) + } + if tls["enabled"] != true { + t.Errorf("app.server.tls.enabled = %v, want true", tls["enabled"]) + } + }) + } +}