fix(safety): eliminate TOCTOU race in readFile, guard WriteConfig, DRY getters

- readFile now opens the file first, then stats via the fd (no race
  between stat and open). Uses toml.NewDecoder instead of DecodeFile.
- WriteConfig returns an error if no config file has been set.
- YAML WriteConfig now calls enc.Close() to flush properly.
- Extract resolve() helper to deduplicate the combinedConfig→envConfig
  fallback pattern across all 9 getter methods.
This commit is contained in:
2026-03-01 23:44:16 +00:00
parent 5aadc84d50
commit b16df4e1a9
2 changed files with 53 additions and 60 deletions

View File

@@ -7,15 +7,24 @@ import (
"time" "time"
) )
// resolve looks up a key in combinedConfig, falling back to envConfig.
func (c *ConfigManager) resolve(key string) (ConfigMap, bool) {
lower := strings.ToLower(key)
if v, ok := c.combinedConfig[lower]; ok {
return v, true
}
if v, ok := c.envConfig[lower]; ok {
return v, true
}
return ConfigMap{}, false
}
func (c *ConfigManager) Get(key string) any { func (c *ConfigManager) Get(key string) any {
c.mutex.RLock() c.mutex.RLock()
defer c.mutex.RUnlock() defer c.mutex.RUnlock()
v, ok := c.combinedConfig[strings.ToLower(key)] v, ok := c.resolve(key)
if !ok { if !ok {
v, ok = c.envConfig[strings.ToLower(key)] return nil
if !ok {
return nil
}
} }
return v.Value return v.Value
} }
@@ -23,12 +32,9 @@ func (c *ConfigManager) Get(key string) any {
func (c *ConfigManager) GetBool(key string) bool { func (c *ConfigManager) GetBool(key string) bool {
c.mutex.RLock() c.mutex.RLock()
defer c.mutex.RUnlock() defer c.mutex.RUnlock()
v, ok := c.combinedConfig[strings.ToLower(key)] v, ok := c.resolve(key)
if !ok { if !ok {
v, ok = c.envConfig[strings.ToLower(key)] return false
if !ok {
return false
}
} }
val := v.Value val := v.Value
switch val := val.(type) { switch val := val.(type) {
@@ -54,12 +60,9 @@ func (c *ConfigManager) GetBool(key string) bool {
func (c *ConfigManager) GetDuration(key string) time.Duration { func (c *ConfigManager) GetDuration(key string) time.Duration {
c.mutex.RLock() c.mutex.RLock()
defer c.mutex.RUnlock() defer c.mutex.RUnlock()
v, ok := c.combinedConfig[strings.ToLower(key)] v, ok := c.resolve(key)
if !ok { if !ok {
v, ok = c.envConfig[strings.ToLower(key)] return 0
if !ok {
return 0
}
} }
val := v.Value val := v.Value
switch val := val.(type) { switch val := val.(type) {
@@ -89,12 +92,9 @@ func (c *ConfigManager) GetDuration(key string) time.Duration {
func (c *ConfigManager) GetString(key string) string { func (c *ConfigManager) GetString(key string) string {
c.mutex.RLock() c.mutex.RLock()
defer c.mutex.RUnlock() defer c.mutex.RUnlock()
v, ok := c.combinedConfig[strings.ToLower(key)] v, ok := c.resolve(key)
if !ok { if !ok {
v, ok = c.envConfig[strings.ToLower(key)] return ""
if !ok {
return ""
}
} }
switch val := v.Value.(type) { switch val := v.Value.(type) {
@@ -108,12 +108,9 @@ func (c *ConfigManager) GetString(key string) string {
func (c *ConfigManager) GetStringMap(key string) map[string]any { func (c *ConfigManager) GetStringMap(key string) map[string]any {
c.mutex.RLock() c.mutex.RLock()
defer c.mutex.RUnlock() defer c.mutex.RUnlock()
v, ok := c.combinedConfig[strings.ToLower(key)] v, ok := c.resolve(key)
if !ok { if !ok {
v, ok = c.envConfig[strings.ToLower(key)] return nil
if !ok {
return nil
}
} }
switch val := v.Value.(type) { switch val := v.Value.(type) {
case map[string]any: case map[string]any:
@@ -126,12 +123,9 @@ func (c *ConfigManager) GetStringMap(key string) map[string]any {
func (c *ConfigManager) GetStringSlice(key string) []string { func (c *ConfigManager) GetStringSlice(key string) []string {
c.mutex.RLock() c.mutex.RLock()
defer c.mutex.RUnlock() defer c.mutex.RUnlock()
v, ok := c.combinedConfig[strings.ToLower(key)] v, ok := c.resolve(key)
if !ok { if !ok {
v, ok = c.envConfig[strings.ToLower(key)] return nil
if !ok {
return nil
}
} }
switch val := v.Value.(type) { switch val := v.Value.(type) {
case []string: case []string:
@@ -155,12 +149,9 @@ func (c *ConfigManager) GetStringSlice(key string) []string {
func (c *ConfigManager) GetInt(key string) int { func (c *ConfigManager) GetInt(key string) int {
c.mutex.RLock() c.mutex.RLock()
defer c.mutex.RUnlock() defer c.mutex.RUnlock()
v, ok := c.combinedConfig[strings.ToLower(key)] v, ok := c.resolve(key)
if !ok { if !ok {
v, ok = c.envConfig[strings.ToLower(key)] return 0
if !ok {
return 0
}
} }
switch val := v.Value.(type) { switch val := v.Value.(type) {
case int: case int:
@@ -187,12 +178,9 @@ func (c *ConfigManager) GetInt(key string) int {
func (c *ConfigManager) GetIntSlice(key string) []int { func (c *ConfigManager) GetIntSlice(key string) []int {
c.mutex.RLock() c.mutex.RLock()
defer c.mutex.RUnlock() defer c.mutex.RUnlock()
v, ok := c.combinedConfig[strings.ToLower(key)] v, ok := c.resolve(key)
if !ok { if !ok {
v, ok = c.envConfig[strings.ToLower(key)] return nil
if !ok {
return nil
}
} }
switch val := v.Value.(type) { switch val := v.Value.(type) {
case []int: case []int:

45
jety.go
View File

@@ -160,6 +160,9 @@ func (c *ConfigManager) collapse() {
func (c *ConfigManager) WriteConfig() error { func (c *ConfigManager) WriteConfig() error {
c.mutex.RLock() c.mutex.RLock()
defer c.mutex.RUnlock() defer c.mutex.RUnlock()
if c.configFileUsed == "" {
return errors.New("no config file specified")
}
flattenedConfig := make(map[string]any) flattenedConfig := make(map[string]any)
for _, v := range c.combinedConfig { for _, v := range c.combinedConfig {
flattenedConfig[v.Key] = v.Value flattenedConfig[v.Key] = v.Value
@@ -181,8 +184,10 @@ func (c *ConfigManager) WriteConfig() error {
} }
defer f.Close() defer f.Close()
enc := yaml.NewEncoder(f) enc := yaml.NewEncoder(f)
err = enc.Encode(flattenedConfig) if err = enc.Encode(flattenedConfig); err != nil {
return err return err
}
return enc.Close()
case ConfigTypeJSON: case ConfigTypeJSON:
f, err := os.Create(c.configFileUsed) f, err := os.Create(c.configFileUsed)
if err != nil { if err != nil {
@@ -272,33 +277,33 @@ func (c *ConfigManager) ReadInConfig() error {
} }
func readFile(filename string, fileType configType) (map[string]any, error) { func readFile(filename string, fileType configType) (map[string]any, error) {
fileData := make(map[string]any) f, err := os.Open(filename)
if d, err := os.Stat(filename); os.IsNotExist(err) { if err != nil {
return nil, ErrConfigFileNotFound if os.IsNotExist(err) {
} else if d.Size() == 0 { return nil, ErrConfigFileNotFound
}
return nil, err
}
defer f.Close()
info, err := f.Stat()
if err != nil {
return nil, err
}
if info.Size() == 0 {
return nil, ErrConfigFileEmpty return nil, ErrConfigFileEmpty
} }
fileData := make(map[string]any)
switch fileType { switch fileType {
case ConfigTypeTOML: case ConfigTypeTOML:
_, err := toml.DecodeFile(filename, &fileData) _, err := toml.NewDecoder(f).Decode(&fileData)
return fileData, err return fileData, err
case ConfigTypeYAML: case ConfigTypeYAML:
f, err := os.Open(filename) err := yaml.NewDecoder(f).Decode(&fileData)
if err != nil {
return nil, err
}
defer f.Close()
d := yaml.NewDecoder(f)
err = d.Decode(&fileData)
return fileData, err return fileData, err
case ConfigTypeJSON: case ConfigTypeJSON:
f, err := os.Open(filename) err := json.NewDecoder(f).Decode(&fileData)
if err != nil {
return nil, err
}
defer f.Close()
err = json.NewDecoder(f).Decode(&fileData)
return fileData, err return fileData, err
default: default:
return nil, fmt.Errorf("config type %s not supported", fileType) return nil, fmt.Errorf("config type %s not supported", fileType)