check stderr before reporting back error 1

This commit is contained in:
Tai Groot 2021-05-14 16:55:17 -07:00
parent 2517125d98
commit feff1f6edd
Signed by: taigrr
GPG Key ID: D00C269A87614812
5 changed files with 68 additions and 12 deletions

View File

@ -40,17 +40,17 @@ import (
"log" "log"
"time" "time"
"github.com/taigrr/systemctl/v1" "github.com/taigrr/systemctl"
) )
func main() { func main() {
ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second) ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second)
defer cancel() defer cancel()
userMode := false
// Equivalent to `systemctl enable dhcpd` with a 10 second timeout // Equivalent to `systemctl enable nginx` with a 10 second timeout
err := systemctl.Enable(ctx, "dhcpd") err := systemctl.Enable(ctx, "nginx", userMode)
if err != nil { if err != nil {
log.Fatalf("unable to enable unit %s: %v", "dhcpd", err) log.Fatalf("unable to enable unit %s: %v", "nginx", err)
} }
} }
``` ```

View File

@ -17,4 +17,6 @@ var (
// ErrDoesNotExist means the unit specified doesn't exist or can't be found // ErrDoesNotExist means the unit specified doesn't exist or can't be found
ErrDoesNotExist = errors.New("Unit does not exist") ErrDoesNotExist = errors.New("Unit does not exist")
// ErrUnspecified means something in the stderr output contains the word `Failed`, but not a known case
ErrUnspecified = errors.New("Unknown error")
) )

View File

@ -2,8 +2,7 @@ package systemctl
import ( import (
"context" "context"
"errors" "fmt"
"strconv"
) )
// TODO // TODO
@ -44,22 +43,36 @@ func Enable(ctx context.Context, unit string, usermode bool) error {
args[1] = "--user" args[1] = "--user"
} }
_, stderr, code, err := execute(ctx, args) _, stderr, code, err := execute(ctx, args)
customErr := filterErr(stderr)
if err != nil { if customErr != nil {
return err return customErr
} }
err = filterErr(stderr)
if err != nil { if err != nil {
return err return err
} }
if code != 0 { if code != 0 {
return errors.New("received error code " + strconv.Itoa(code)) return fmt.Errorf("received error code %d for stderr `%s`: %w", code, stderr, ErrUnspecified)
} }
return nil return nil
} }
// TODO // TODO
func Disable(ctx context.Context, unit string, usermode bool) error { func Disable(ctx context.Context, unit string, usermode bool) error {
var args = []string{"disable", "--system", unit}
if usermode {
args[1] = "--user"
}
_, stderr, code, err := execute(ctx, args)
customErr := filterErr(stderr)
if customErr != nil {
return customErr
}
if err != nil {
return err
}
if code != 0 {
return fmt.Errorf("received error code %d for stderr `%s`: %w", code, stderr, ErrUnspecified)
}
return nil return nil
} }

28
systemctl_test.go Normal file
View File

@ -0,0 +1,28 @@
package systemctl
import (
"context"
"testing"
"time"
)
func TestEnableNonexistant(t *testing.T) {
unit := "nonexistant"
ctx, cancel := context.WithTimeout(context.Background(), 1*time.Second)
defer cancel()
err := Enable(ctx, unit, true)
if err != ErrDoesNotExist {
t.Errorf("error is %v, but should have been %v", err, ErrDoesNotExist)
}
}
func TestEnableNoPermissions(t *testing.T) {
unit := "nginx"
ctx, cancel := context.WithTimeout(context.Background(), 1*time.Second)
defer cancel()
err := Enable(ctx, unit, false)
if err != ErrInsufficientPermissions {
t.Errorf("error is %v, but should have been %v", err, ErrInsufficientPermissions)
}
}

13
util.go
View File

@ -45,6 +45,19 @@ func filterErr(stderr string) error {
if matched { if matched {
return ErrDoesNotExist return ErrDoesNotExist
} }
matched, _ = regexp.MatchString(`Interactive authentication required`, stderr)
if matched {
return ErrInsufficientPermissions
}
matched, _ = regexp.MatchString(`Access denied`, stderr)
if matched {
return ErrInsufficientPermissions
}
matched, _ = regexp.MatchString(`Failed`, stderr)
if matched {
return ErrUnspecified
}
return nil return nil
} }