From b16df4e1a985394d062a367af6776cc0b8b227cc Mon Sep 17 00:00:00 2001 From: Tai Groot Date: Sun, 1 Mar 2026 23:44:16 +0000 Subject: [PATCH] fix(safety): eliminate TOCTOU race in readFile, guard WriteConfig, DRY getters MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit - readFile now opens the file first, then stats via the fd (no race between stat and open). Uses toml.NewDecoder instead of DecodeFile. - WriteConfig returns an error if no config file has been set. - YAML WriteConfig now calls enc.Close() to flush properly. - Extract resolve() helper to deduplicate the combinedConfig→envConfig fallback pattern across all 9 getter methods. --- getters.go | 68 ++++++++++++++++++++++-------------------------------- jety.go | 45 ++++++++++++++++++++---------------- 2 files changed, 53 insertions(+), 60 deletions(-) diff --git a/getters.go b/getters.go index dbb6690..9b345bd 100644 --- a/getters.go +++ b/getters.go @@ -7,15 +7,24 @@ import ( "time" ) +// resolve looks up a key in combinedConfig, falling back to envConfig. +func (c *ConfigManager) resolve(key string) (ConfigMap, bool) { + lower := strings.ToLower(key) + if v, ok := c.combinedConfig[lower]; ok { + return v, true + } + if v, ok := c.envConfig[lower]; ok { + return v, true + } + return ConfigMap{}, false +} + func (c *ConfigManager) Get(key string) any { c.mutex.RLock() defer c.mutex.RUnlock() - v, ok := c.combinedConfig[strings.ToLower(key)] + v, ok := c.resolve(key) if !ok { - v, ok = c.envConfig[strings.ToLower(key)] - if !ok { - return nil - } + return nil } return v.Value } @@ -23,12 +32,9 @@ func (c *ConfigManager) Get(key string) any { func (c *ConfigManager) GetBool(key string) bool { c.mutex.RLock() defer c.mutex.RUnlock() - v, ok := c.combinedConfig[strings.ToLower(key)] + v, ok := c.resolve(key) if !ok { - v, ok = c.envConfig[strings.ToLower(key)] - if !ok { - return false - } + return false } val := v.Value switch val := val.(type) { @@ -54,12 +60,9 @@ func (c *ConfigManager) GetBool(key string) bool { func (c *ConfigManager) GetDuration(key string) time.Duration { c.mutex.RLock() defer c.mutex.RUnlock() - v, ok := c.combinedConfig[strings.ToLower(key)] + v, ok := c.resolve(key) if !ok { - v, ok = c.envConfig[strings.ToLower(key)] - if !ok { - return 0 - } + return 0 } val := v.Value switch val := val.(type) { @@ -89,12 +92,9 @@ func (c *ConfigManager) GetDuration(key string) time.Duration { func (c *ConfigManager) GetString(key string) string { c.mutex.RLock() defer c.mutex.RUnlock() - v, ok := c.combinedConfig[strings.ToLower(key)] + v, ok := c.resolve(key) if !ok { - v, ok = c.envConfig[strings.ToLower(key)] - if !ok { - return "" - } + return "" } switch val := v.Value.(type) { @@ -108,12 +108,9 @@ func (c *ConfigManager) GetString(key string) string { func (c *ConfigManager) GetStringMap(key string) map[string]any { c.mutex.RLock() defer c.mutex.RUnlock() - v, ok := c.combinedConfig[strings.ToLower(key)] + v, ok := c.resolve(key) if !ok { - v, ok = c.envConfig[strings.ToLower(key)] - if !ok { - return nil - } + return nil } switch val := v.Value.(type) { case map[string]any: @@ -126,12 +123,9 @@ func (c *ConfigManager) GetStringMap(key string) map[string]any { func (c *ConfigManager) GetStringSlice(key string) []string { c.mutex.RLock() defer c.mutex.RUnlock() - v, ok := c.combinedConfig[strings.ToLower(key)] + v, ok := c.resolve(key) if !ok { - v, ok = c.envConfig[strings.ToLower(key)] - if !ok { - return nil - } + return nil } switch val := v.Value.(type) { case []string: @@ -155,12 +149,9 @@ func (c *ConfigManager) GetStringSlice(key string) []string { func (c *ConfigManager) GetInt(key string) int { c.mutex.RLock() defer c.mutex.RUnlock() - v, ok := c.combinedConfig[strings.ToLower(key)] + v, ok := c.resolve(key) if !ok { - v, ok = c.envConfig[strings.ToLower(key)] - if !ok { - return 0 - } + return 0 } switch val := v.Value.(type) { case int: @@ -187,12 +178,9 @@ func (c *ConfigManager) GetInt(key string) int { func (c *ConfigManager) GetIntSlice(key string) []int { c.mutex.RLock() defer c.mutex.RUnlock() - v, ok := c.combinedConfig[strings.ToLower(key)] + v, ok := c.resolve(key) if !ok { - v, ok = c.envConfig[strings.ToLower(key)] - if !ok { - return nil - } + return nil } switch val := v.Value.(type) { case []int: diff --git a/jety.go b/jety.go index 0f42691..9eee510 100644 --- a/jety.go +++ b/jety.go @@ -160,6 +160,9 @@ func (c *ConfigManager) collapse() { func (c *ConfigManager) WriteConfig() error { c.mutex.RLock() defer c.mutex.RUnlock() + if c.configFileUsed == "" { + return errors.New("no config file specified") + } flattenedConfig := make(map[string]any) for _, v := range c.combinedConfig { flattenedConfig[v.Key] = v.Value @@ -181,8 +184,10 @@ func (c *ConfigManager) WriteConfig() error { } defer f.Close() enc := yaml.NewEncoder(f) - err = enc.Encode(flattenedConfig) - return err + if err = enc.Encode(flattenedConfig); err != nil { + return err + } + return enc.Close() case ConfigTypeJSON: f, err := os.Create(c.configFileUsed) if err != nil { @@ -272,33 +277,33 @@ func (c *ConfigManager) ReadInConfig() error { } func readFile(filename string, fileType configType) (map[string]any, error) { - fileData := make(map[string]any) - if d, err := os.Stat(filename); os.IsNotExist(err) { - return nil, ErrConfigFileNotFound - } else if d.Size() == 0 { + f, err := os.Open(filename) + if err != nil { + if os.IsNotExist(err) { + return nil, ErrConfigFileNotFound + } + return nil, err + } + defer f.Close() + + info, err := f.Stat() + if err != nil { + return nil, err + } + if info.Size() == 0 { return nil, ErrConfigFileEmpty } + fileData := make(map[string]any) switch fileType { case ConfigTypeTOML: - _, err := toml.DecodeFile(filename, &fileData) + _, err := toml.NewDecoder(f).Decode(&fileData) return fileData, err case ConfigTypeYAML: - f, err := os.Open(filename) - if err != nil { - return nil, err - } - defer f.Close() - d := yaml.NewDecoder(f) - err = d.Decode(&fileData) + err := yaml.NewDecoder(f).Decode(&fileData) return fileData, err case ConfigTypeJSON: - f, err := os.Open(filename) - if err != nil { - return nil, err - } - defer f.Close() - err = json.NewDecoder(f).Decode(&fileData) + err := json.NewDecoder(f).Decode(&fileData) return fileData, err default: return nil, fmt.Errorf("config type %s not supported", fileType)