diff --git a/systemctl_test.go b/systemctl_test.go index eb45b4c..32e14e2 100644 --- a/systemctl_test.go +++ b/systemctl_test.go @@ -248,6 +248,93 @@ func TestIsEnabled(t *testing.T) { } func TestMask(t *testing.T) { + errCases := []struct { + unit string + err error + opts Options + runAsUser bool + }{ + /* Run these tests only as an unpriviledged user */ + + //try nonexistant unit in user mode as user + {"nonexistant", ErrDoesNotExist, Options{UserMode: true}, true}, + // try existing unit in user mode as user + {"syncthing", nil, Options{UserMode: true}, true}, + // try nonexisting unit in system mode as user + {"nonexistant", ErrDoesNotExist, Options{UserMode: false}, true}, + // try existing unit in system mode as user + {"nginx", ErrInsufficientPermissions, Options{UserMode: false}, true}, + + /* End user tests*/ + + /* Run these tests only as a superuser */ + + // try nonexistant unit in system mode as system + {"nonexistant", ErrDoesNotExist, Options{UserMode: false}, false}, + // try existing unit in system mode as system + {"nginx", ErrBusFailure, Options{UserMode: true}, false}, + // try existing unit in system mode as system + {"nginx", nil, Options{UserMode: false}, false}, + + /* End superuser tests*/ + + } + for _, tc := range errCases { + t.Run(fmt.Sprintf("%s as %s", tc.unit, userString), func(t *testing.T) { + if (userString == "root" || userString == "system") && tc.runAsUser { + t.Skip("skipping user test while running as superuser") + } else if (userString != "root" && userString != "system") && !tc.runAsUser { + t.Skip("skipping superuser test while running as user") + } + ctx, cancel := context.WithTimeout(context.Background(), 1*time.Second) + defer cancel() + err := Mask(ctx, tc.unit, tc.opts) + if err != tc.err { + t.Errorf("error is %v, but should have been %v", err, tc.err) + } + Unmask(ctx, tc.unit, tc.opts) + }) + } + t.Run(fmt.Sprintf("test double masking existing"), func(t *testing.T) { + unit := "nginx" + userMode := false + if userString != "root" && userString != "system" { + userMode = true + unit = "syncthing" + } + opts := Options{UserMode: userMode} + ctx, cancel := context.WithTimeout(context.Background(), 1*time.Second) + defer cancel() + err := Mask(ctx, unit, opts) + if err != nil { + t.Errorf("error on initial masking is %v, but should have been %v", err, nil) + } + err = Mask(ctx, unit, opts) + if err != nil { + t.Errorf("error on second masking is %v, but should have been %v", err, nil) + } + Unmask(ctx, unit, opts) + + }) + t.Run(fmt.Sprintf("test double masking nonexisting"), func(t *testing.T) { + unit := "nonexistant" + userMode := false + if userString != "root" && userString != "system" { + userMode = true + } + opts := Options{UserMode: userMode} + ctx, cancel := context.WithTimeout(context.Background(), 1*time.Second) + defer cancel() + err := Mask(ctx, unit, opts) + if err != ErrDoesNotExist { + t.Errorf("error on initial masking is %v, but should have been %v", err, ErrDoesNotExist) + } + err = Mask(ctx, unit, opts) + if err != nil { + t.Errorf("error on second masking is %v, but should have been %v", err, nil) + } + Unmask(ctx, unit, opts) + }) }