From 82f92e09395ea3706971f0ca4db7467e65ce2a91 Mon Sep 17 00:00:00 2001 From: Tyler Treat Date: Tue, 27 Jun 2017 11:25:34 -0500 Subject: [PATCH] Refactor signalling and add tests --- main.go | 40 ++++---- server/const.go | 11 +++ server/service_test.go | 22 +++++ server/signal.go | 28 ++++-- server/signal_test.go | 194 +++++++++++++++++++++++++++++++++++++++ server/signal_windows.go | 8 +- 6 files changed, 271 insertions(+), 32 deletions(-) create mode 100644 server/service_test.go diff --git a/main.go b/main.go index d4c1e849..4a38c630 100644 --- a/main.go +++ b/main.go @@ -159,25 +159,9 @@ func main() { // Snapshot flag options. server.FlagSnapshot = opts.Clone() - // Process signal. + // Process signal control. if signal != "" { - var ( - pid = -1 - commandAndPid = strings.Split(signal, "=") - ) - if l := len(commandAndPid); l == 2 { - p, err := strconv.Atoi(commandAndPid[1]) - if err != nil { - usage() - } - pid = p - } else if l > 2 { - usage() - } - if err := server.ProcessSignal(commandAndPid[0], pid); err != nil { - server.PrintAndDie(err.Error()) - } - os.Exit(0) + processSignal(signal) } // Parse config if given @@ -295,3 +279,23 @@ func configureClusterOpts(opts *server.Options) error { return nil } + +func processSignal(signal string) { + var ( + pid = -1 + commandAndPid = strings.Split(signal, "=") + ) + if l := len(commandAndPid); l == 2 { + p, err := strconv.Atoi(commandAndPid[1]) + if err != nil { + usage() + } + pid = p + } else if l > 2 { + usage() + } + if err := server.ProcessSignal(server.Command(commandAndPid[0]), pid); err != nil { + server.PrintAndDie(err.Error()) + } + os.Exit(0) +} diff --git a/server/const.go b/server/const.go index f4336f3d..ab8558f3 100644 --- a/server/const.go +++ b/server/const.go @@ -6,6 +6,17 @@ import ( "time" ) +// Command is a signal used to control a running gnatsd process. +type Command string + +// Valid Command values. +const ( + CommandStop = Command("stop") + CommandQuit = Command("quit") + CommandReopen = Command("reopen") + CommandReload = Command("reload") +) + const ( // VERSION is the current version for the server. VERSION = "0.9.6" diff --git a/server/service_test.go b/server/service_test.go new file mode 100644 index 00000000..da1d632e --- /dev/null +++ b/server/service_test.go @@ -0,0 +1,22 @@ +// +build !windows +// Copyright 2012-2017 Apcera Inc. All rights reserved. + +package server + +import ( + "testing" + "time" +) + +func TestRun(t *testing.T) { + s := New(DefaultOptions()) + go func() { + if !s.ReadyForConnections(time.Second) { + t.Fatal("Failed to start server in time") + } + s.Shutdown() + }() + if err := Run(s); err != nil { + t.Fatalf("Run failed: %v", err) + } +} diff --git a/server/signal.go b/server/signal.go index 8c5d0caa..8f0e2090 100644 --- a/server/signal.go +++ b/server/signal.go @@ -48,7 +48,7 @@ func (s *Server) handleSignals() { // ProcessSignal sends the given signal command to the given process. If pid is // -1, this will send the signal to the single running instance of gnatsd. If // multiple instances are running, it returns an error. -func ProcessSignal(command string, pid int) (err error) { +func ProcessSignal(command Command, pid int) (err error) { if pid == -1 { pids, err := resolvePids() if err != nil { @@ -70,14 +70,14 @@ func ProcessSignal(command string, pid int) (err error) { } switch command { - case "stop": - err = syscall.Kill(pid, syscall.SIGKILL) - case "quit": - err = syscall.Kill(pid, syscall.SIGINT) - case "reopen": - err = syscall.Kill(pid, syscall.SIGUSR1) - case "reload": - err = syscall.Kill(pid, syscall.SIGHUP) + case CommandStop: + err = kill(pid, syscall.SIGKILL) + case CommandQuit: + err = kill(pid, syscall.SIGINT) + case CommandReopen: + err = kill(pid, syscall.SIGUSR1) + case CommandReload: + err = kill(pid, syscall.SIGHUP) default: err = fmt.Errorf("unknown signal %q", command) } @@ -88,7 +88,7 @@ func ProcessSignal(command string, pid int) (err error) { func resolvePids() ([]int, error) { // If pgrep isn't available, this will just bail out and the user will be // required to specify a pid. - output, err := exec.Command("pgrep", processName).Output() + output, err := pgrep() if err != nil { switch err.(type) { case *exec.ExitError: @@ -120,3 +120,11 @@ func resolvePids() ([]int, error) { } return pids, nil } + +var kill = func(pid int, signal syscall.Signal) error { + return syscall.Kill(pid, signal) +} + +var pgrep = func() ([]byte, error) { + return exec.Command("pgrep", processName).Output() +} diff --git a/server/signal_test.go b/server/signal_test.go index 16d55306..2da2bde9 100644 --- a/server/signal_test.go +++ b/server/signal_test.go @@ -4,6 +4,8 @@ package server import ( + "errors" + "fmt" "io/ioutil" "os" "strings" @@ -81,3 +83,195 @@ func TestSignalToReloadConfig(t *testing.T) { t.Fatalf("Reloaded is incorrect.\nexpected: 1\ngot: %d", reloaded) } } + +func TestProcessSignalNoProcesses(t *testing.T) { + err := ProcessSignal(CommandStop, -1) + if err == nil { + t.Fatal("Expected error") + } + expectedStr := "no gnatsd processes running" + if err.Error() != expectedStr { + t.Fatalf("Error is incorrect.\nexpected: %s\ngot: %s", expectedStr, err.Error()) + } +} + +func TestProcessSignalMultipleProcesses(t *testing.T) { + pid := os.Getpid() + pgrepBefore := pgrep + pgrep = func() ([]byte, error) { + return []byte(fmt.Sprintf("123\n456\n%d\n", pid)), nil + } + defer func() { + pgrep = pgrepBefore + }() + + err := ProcessSignal(CommandStop, -1) + if err == nil { + t.Fatal("Expected error") + } + expectedStr := "multiple gnatsd processes running:\n123\n456" + if err.Error() != expectedStr { + t.Fatalf("Error is incorrect.\nexpected: %s\ngot: %s", expectedStr, err.Error()) + } +} + +func TestProcessSignalPgrepError(t *testing.T) { + pgrepBefore := pgrep + pgrep = func() ([]byte, error) { + return nil, errors.New("error") + } + defer func() { + pgrep = pgrepBefore + }() + + err := ProcessSignal(CommandStop, -1) + if err == nil { + t.Fatal("Expected error") + } + expectedStr := "unable to resolve pid, try providing one" + if err.Error() != expectedStr { + t.Fatalf("Error is incorrect.\nexpected: %s\ngot: %s", expectedStr, err.Error()) + } +} + +func TestProcessSignalPgrepMangled(t *testing.T) { + pgrepBefore := pgrep + pgrep = func() ([]byte, error) { + return []byte("12x"), nil + } + defer func() { + pgrep = pgrepBefore + }() + + err := ProcessSignal(CommandStop, -1) + if err == nil { + t.Fatal("Expected error") + } + expectedStr := "unable to resolve pid, try providing one" + if err.Error() != expectedStr { + t.Fatalf("Error is incorrect.\nexpected: %s\ngot: %s", expectedStr, err.Error()) + } +} + +func TestProcessSignalResolveSingleProcess(t *testing.T) { + pid := os.Getpid() + pgrepBefore := pgrep + pgrep = func() ([]byte, error) { + return []byte(fmt.Sprintf("123\n%d\n", pid)), nil + } + defer func() { + pgrep = pgrepBefore + }() + killBefore := kill + called := false + kill = func(pid int, signal syscall.Signal) error { + called = true + if pid != 123 { + t.Fatalf("pid is incorrect.\nexpected: 123\ngot: %d", pid) + } + if signal != syscall.SIGKILL { + t.Fatalf("signal is incorrect.\nexpected: killed\ngot: %v", signal) + } + return nil + } + defer func() { + kill = killBefore + }() + + if err := ProcessSignal(CommandStop, -1); err != nil { + t.Fatalf("ProcessSignal failed: %v", err) + } + + if !called { + t.Fatal("Expected kill to be called") + } +} + +func TestProcessSignalInvalidCommand(t *testing.T) { + err := ProcessSignal(Command("invalid"), 123) + if err == nil { + t.Fatal("Expected error") + } + expectedStr := "unknown signal \"invalid\"" + if err.Error() != expectedStr { + t.Fatalf("Error is incorrect.\nexpected: %s\ngot: %s", expectedStr, err.Error()) + } +} + +func TestProcessSignalQuitProcess(t *testing.T) { + killBefore := kill + called := false + kill = func(pid int, signal syscall.Signal) error { + called = true + if pid != 123 { + t.Fatalf("pid is incorrect.\nexpected: 123\ngot: %d", pid) + } + if signal != syscall.SIGINT { + t.Fatalf("signal is incorrect.\nexpected: interrupt\ngot: %v", signal) + } + return nil + } + defer func() { + kill = killBefore + }() + + if err := ProcessSignal(CommandQuit, 123); err != nil { + t.Fatalf("ProcessSignal failed: %v", err) + } + + if !called { + t.Fatal("Expected kill to be called") + } +} + +func TestProcessSignalReopenProcess(t *testing.T) { + killBefore := kill + called := false + kill = func(pid int, signal syscall.Signal) error { + called = true + if pid != 123 { + t.Fatalf("pid is incorrect.\nexpected: 123\ngot: %d", pid) + } + if signal != syscall.SIGUSR1 { + t.Fatalf("signal is incorrect.\nexpected: user defined signal 1\ngot: %v", signal) + } + return nil + } + defer func() { + kill = killBefore + }() + + if err := ProcessSignal(CommandReopen, 123); err != nil { + t.Fatalf("ProcessSignal failed: %v", err) + } + + if !called { + t.Fatal("Expected kill to be called") + } +} + +func TestProcessSignalReloadProcess(t *testing.T) { + killBefore := kill + called := false + kill = func(pid int, signal syscall.Signal) error { + called = true + if pid != 123 { + t.Fatalf("pid is incorrect.\nexpected: 123\ngot: %d", pid) + } + if signal != syscall.SIGHUP { + t.Fatalf("signal is incorrect.\nexpected: hangup\ngot: %v", signal) + } + return nil + } + defer func() { + kill = killBefore + }() + + if err := ProcessSignal(CommandReload, 123); err != nil { + t.Fatalf("ProcessSignal failed: %v", err) + } + + if !called { + t.Fatal("Expected kill to be called") + } +} diff --git a/server/signal_windows.go b/server/signal_windows.go index e8f81c10..f4780b35 100644 --- a/server/signal_windows.go +++ b/server/signal_windows.go @@ -34,7 +34,7 @@ func (s *Server) handleSignals() { // ProcessSignal sends the given signal command to the running gnatsd service. // If pid is not -1 or if there is no gnatsd service running, it returns an // error. -func ProcessSignal(command string, pid int) error { +func ProcessSignal(command Command, pid int) error { if pid != -1 { return errors.New("cannot signal pid on Windows") } @@ -57,13 +57,13 @@ func ProcessSignal(command string, pid int) error { ) switch command { - case "stop", "quit": + case CommandStop, CommandQuit: cmd = svc.Stop to = svc.Stopped - case "reopen": + case CommandReopen: cmd = reopenLogCmd to = svc.Running - case "reload": + case CommandReload: cmd = svc.ParamChange to = svc.Running default: