4 Commits

Author SHA1 Message Date
21fce7918e add linux build tags to restrict compilation to linux targets 2025-02-13 16:49:49 -08:00
54f4f7a235 add IsSystemd checker 2025-02-12 19:21:48 -08:00
fa15432121 add ability to list all units 2024-08-08 15:37:04 -07:00
7bd5bef0cb fix broken error filtration 2023-06-28 23:43:20 -07:00
7 changed files with 122 additions and 15 deletions

View File

@@ -1,3 +1,5 @@
//go:build linux
package systemctl package systemctl
import ( import (

View File

@@ -1,8 +1,11 @@
//go:build linux
package systemctl package systemctl
import ( import (
"context" "context"
"errors" "errors"
"os"
"strconv" "strconv"
"strings" "strings"
"time" "time"
@@ -55,6 +58,34 @@ func GetPID(ctx context.Context, unit string, opts Options) (int, error) {
return strconv.Atoi(value) return strconv.Atoi(value)
} }
func GetUnits(ctx context.Context, opts Options) ([]Unit, error) {
args := []string{"list-units", "--all", "--no-legend", "--full", "--no-pager"}
if opts.UserMode {
args = append(args, "--user")
}
stdout, stderr, _, err := execute(ctx, args)
if err != nil {
return []Unit{}, errors.Join(err, filterErr(stderr))
}
lines := strings.Split(stdout, "\n")
units := []Unit{}
for _, line := range lines {
entry := strings.Fields(line)
if len(entry) < 4 {
continue
}
unit := Unit{
Name: entry[0],
Load: entry[1],
Active: entry[2],
Sub: entry[3],
Description: strings.Join(entry[4:], " "),
}
units = append(units, unit)
}
return units, nil
}
func GetMaskedUnits(ctx context.Context, opts Options) ([]string, error) { func GetMaskedUnits(ctx context.Context, opts Options) ([]string, error) {
args := []string{"list-unit-files", "--state=masked"} args := []string{"list-unit-files", "--state=masked"}
if opts.UserMode { if opts.UserMode {
@@ -84,6 +115,15 @@ func GetMaskedUnits(ctx context.Context, opts Options) ([]string, error) {
return units, nil return units, nil
} }
// check if systemd is the current init system
func IsSystemd() (bool, error) {
b, err := os.ReadFile("/proc/1/comm")
if err != nil {
return false, err
}
return strings.TrimSpace(string(b)) == "systemd", nil
}
// check if a service is masked // check if a service is masked
func IsMasked(ctx context.Context, unit string, opts Options) (bool, error) { func IsMasked(ctx context.Context, unit string, opts Options) (bool, error) {
units, err := GetMaskedUnits(ctx, opts) units, err := GetMaskedUnits(ctx, opts)

View File

@@ -238,6 +238,54 @@ func TestGetMemoryUsage(t *testing.T) {
}) })
} }
func TestGetUnits(t *testing.T) {
type testCase struct {
err error
runAsUser bool
opts Options
}
testCases := []testCase{{
// Run these tests only as a user
runAsUser: true,
opts: Options{UserMode: true},
err: nil,
}}
for _, tc := range testCases {
t.Run(fmt.Sprintf("as %s", userString), func(t *testing.T) {
if (userString == "root" || userString == "system") && tc.runAsUser {
t.Skip("skipping user test while running as superuser")
} else if (userString != "root" && userString != "system") && !tc.runAsUser {
t.Skip("skipping superuser test while running as user")
}
ctx, cancel := context.WithTimeout(context.Background(), 1*time.Second)
defer cancel()
units, err := GetUnits(ctx, tc.opts)
if !errors.Is(err, tc.err) {
t.Errorf("error is %v, but should have been %v", err, tc.err)
}
if len(units) == 0 {
t.Errorf("Expected at least one unit, but got none")
}
unit := units[0]
if unit.Name == "" {
t.Errorf("Expected unit name to be non-empty, but got empty")
}
if unit.Load == "" {
t.Errorf("Expected unit load state to be non-empty, but got empty")
}
if unit.Active == "" {
t.Errorf("Expected unit active state to be non-empty, but got empty")
}
if unit.Sub == "" {
t.Errorf("Expected unit sub state to be non-empty, but got empty")
}
if unit.Description == "" {
t.Errorf("Expected unit description to be non-empty, but got empty")
}
})
}
}
func TestGetPID(t *testing.T) { func TestGetPID(t *testing.T) {
type testCase struct { type testCase struct {
unit string unit string

View File

@@ -1,5 +1,15 @@
//go:build linux
package systemctl package systemctl
type Options struct { type Options struct {
UserMode bool UserMode bool
} }
type Unit struct {
Name string
Load string
Active string
Sub string
Description string
}

View File

@@ -1,3 +1,5 @@
//go:build linux
package systemctl package systemctl
import ( import (

View File

@@ -42,7 +42,6 @@ func TestMain(m *testing.M) {
func TestDaemonReload(t *testing.T) { func TestDaemonReload(t *testing.T) {
testCases := []struct { testCases := []struct {
unit string
err error err error
opts Options opts Options
runAsUser bool runAsUser bool
@@ -50,22 +49,26 @@ func TestDaemonReload(t *testing.T) {
/* Run these tests only as a user */ /* Run these tests only as a user */
// fail to reload system daemon as user // fail to reload system daemon as user
{"", ErrInsufficientPermissions, Options{UserMode: false}, true}, {ErrInsufficientPermissions, Options{UserMode: false}, true},
// reload user's scope daemon // reload user's scope daemon
{"", nil, Options{UserMode: true}, true}, {nil, Options{UserMode: true}, true},
/* End user tests*/ /* End user tests*/
/* Run these tests only as a superuser */ /* Run these tests only as a superuser */
// succeed to reload daemon // succeed to reload daemon
{"", nil, Options{UserMode: false}, false}, {nil, Options{UserMode: false}, false},
// fail to connect to user bus as system // fail to connect to user bus as system
{"", ErrBusFailure, Options{UserMode: true}, false}, {ErrBusFailure, Options{UserMode: true}, false},
/* End superuser tests*/ /* End superuser tests*/
} }
for _, tc := range testCases { for _, tc := range testCases {
t.Run(fmt.Sprintf("%s as %s", tc.unit, userString), func(t *testing.T) { mode := "user"
if tc.opts.UserMode == false {
mode = "system"
}
t.Run(fmt.Sprintf("DaemonReload as %s, %s mode", userString, mode), func(t *testing.T) {
if (userString == "root" || userString == "system") && tc.runAsUser { if (userString == "root" || userString == "system") && tc.runAsUser {
t.Skip("skipping user test while running as superuser") t.Skip("skipping user test while running as superuser")
} else if (userString != "root" && userString != "system") && !tc.runAsUser { } else if (userString != "root" && userString != "system") && !tc.runAsUser {

20
util.go
View File

@@ -1,3 +1,5 @@
//go:build linux
package systemctl package systemctl
import ( import (
@@ -52,23 +54,23 @@ func execute(ctx context.Context, args []string) (string, string, int, error) {
func filterErr(stderr string) error { func filterErr(stderr string) error {
switch { switch {
case strings.Contains(`does not exist`, stderr): case strings.Contains(stderr, `does not exist`):
return errors.Join(ErrDoesNotExist, fmt.Errorf("stderr: %s", stderr)) return errors.Join(ErrDoesNotExist, fmt.Errorf("stderr: %s", stderr))
case strings.Contains(`not found.`, stderr): case strings.Contains(stderr, `not found.`):
return errors.Join(ErrDoesNotExist, fmt.Errorf("stderr: %s", stderr)) return errors.Join(ErrDoesNotExist, fmt.Errorf("stderr: %s", stderr))
case strings.Contains(`not loaded.`, stderr): case strings.Contains(stderr, `not loaded.`):
return errors.Join(ErrUnitNotLoaded, fmt.Errorf("stderr: %s", stderr)) return errors.Join(ErrUnitNotLoaded, fmt.Errorf("stderr: %s", stderr))
case strings.Contains(`No such file or directory`, stderr): case strings.Contains(stderr, `No such file or directory`):
return errors.Join(ErrDoesNotExist, fmt.Errorf("stderr: %s", stderr)) return errors.Join(ErrDoesNotExist, fmt.Errorf("stderr: %s", stderr))
case strings.Contains(`Interactive authentication required`, stderr): case strings.Contains(stderr, `Interactive authentication required`):
return errors.Join(ErrInsufficientPermissions, fmt.Errorf("stderr: %s", stderr)) return errors.Join(ErrInsufficientPermissions, fmt.Errorf("stderr: %s", stderr))
case strings.Contains(`Access denied`, stderr): case strings.Contains(stderr, `Access denied`):
return errors.Join(ErrInsufficientPermissions, fmt.Errorf("stderr: %s", stderr)) return errors.Join(ErrInsufficientPermissions, fmt.Errorf("stderr: %s", stderr))
case strings.Contains(`DBUS_SESSION_BUS_ADDRESS`, stderr): case strings.Contains(stderr, `DBUS_SESSION_BUS_ADDRESS`):
return errors.Join(ErrBusFailure, fmt.Errorf("stderr: %s", stderr)) return errors.Join(ErrBusFailure, fmt.Errorf("stderr: %s", stderr))
case strings.Contains(`is masked`, stderr): case strings.Contains(stderr, `is masked`):
return errors.Join(ErrMasked, fmt.Errorf("stderr: %s", stderr)) return errors.Join(ErrMasked, fmt.Errorf("stderr: %s", stderr))
case strings.Contains(`Failed`, stderr): case strings.Contains(stderr, `Failed`):
return errors.Join(ErrUnspecified, fmt.Errorf("stderr: %s", stderr)) return errors.Join(ErrUnspecified, fmt.Errorf("stderr: %s", stderr))
default: default:
return nil return nil