Refactor signalling and add tests

This commit is contained in:
Tyler Treat
2017-06-27 11:25:34 -05:00
parent 4b05302e93
commit 82f92e0939
6 changed files with 271 additions and 32 deletions

40
main.go
View File

@@ -159,25 +159,9 @@ func main() {
// Snapshot flag options.
server.FlagSnapshot = opts.Clone()
// Process signal.
// Process signal control.
if signal != "" {
var (
pid = -1
commandAndPid = strings.Split(signal, "=")
)
if l := len(commandAndPid); l == 2 {
p, err := strconv.Atoi(commandAndPid[1])
if err != nil {
usage()
}
pid = p
} else if l > 2 {
usage()
}
if err := server.ProcessSignal(commandAndPid[0], pid); err != nil {
server.PrintAndDie(err.Error())
}
os.Exit(0)
processSignal(signal)
}
// Parse config if given
@@ -295,3 +279,23 @@ func configureClusterOpts(opts *server.Options) error {
return nil
}
func processSignal(signal string) {
var (
pid = -1
commandAndPid = strings.Split(signal, "=")
)
if l := len(commandAndPid); l == 2 {
p, err := strconv.Atoi(commandAndPid[1])
if err != nil {
usage()
}
pid = p
} else if l > 2 {
usage()
}
if err := server.ProcessSignal(server.Command(commandAndPid[0]), pid); err != nil {
server.PrintAndDie(err.Error())
}
os.Exit(0)
}

View File

@@ -6,6 +6,17 @@ import (
"time"
)
// Command is a signal used to control a running gnatsd process.
type Command string
// Valid Command values.
const (
CommandStop = Command("stop")
CommandQuit = Command("quit")
CommandReopen = Command("reopen")
CommandReload = Command("reload")
)
const (
// VERSION is the current version for the server.
VERSION = "0.9.6"

22
server/service_test.go Normal file
View File

@@ -0,0 +1,22 @@
// +build !windows
// Copyright 2012-2017 Apcera Inc. All rights reserved.
package server
import (
"testing"
"time"
)
func TestRun(t *testing.T) {
s := New(DefaultOptions())
go func() {
if !s.ReadyForConnections(time.Second) {
t.Fatal("Failed to start server in time")
}
s.Shutdown()
}()
if err := Run(s); err != nil {
t.Fatalf("Run failed: %v", err)
}
}

View File

@@ -48,7 +48,7 @@ func (s *Server) handleSignals() {
// ProcessSignal sends the given signal command to the given process. If pid is
// -1, this will send the signal to the single running instance of gnatsd. If
// multiple instances are running, it returns an error.
func ProcessSignal(command string, pid int) (err error) {
func ProcessSignal(command Command, pid int) (err error) {
if pid == -1 {
pids, err := resolvePids()
if err != nil {
@@ -70,14 +70,14 @@ func ProcessSignal(command string, pid int) (err error) {
}
switch command {
case "stop":
err = syscall.Kill(pid, syscall.SIGKILL)
case "quit":
err = syscall.Kill(pid, syscall.SIGINT)
case "reopen":
err = syscall.Kill(pid, syscall.SIGUSR1)
case "reload":
err = syscall.Kill(pid, syscall.SIGHUP)
case CommandStop:
err = kill(pid, syscall.SIGKILL)
case CommandQuit:
err = kill(pid, syscall.SIGINT)
case CommandReopen:
err = kill(pid, syscall.SIGUSR1)
case CommandReload:
err = kill(pid, syscall.SIGHUP)
default:
err = fmt.Errorf("unknown signal %q", command)
}
@@ -88,7 +88,7 @@ func ProcessSignal(command string, pid int) (err error) {
func resolvePids() ([]int, error) {
// If pgrep isn't available, this will just bail out and the user will be
// required to specify a pid.
output, err := exec.Command("pgrep", processName).Output()
output, err := pgrep()
if err != nil {
switch err.(type) {
case *exec.ExitError:
@@ -120,3 +120,11 @@ func resolvePids() ([]int, error) {
}
return pids, nil
}
var kill = func(pid int, signal syscall.Signal) error {
return syscall.Kill(pid, signal)
}
var pgrep = func() ([]byte, error) {
return exec.Command("pgrep", processName).Output()
}

View File

@@ -4,6 +4,8 @@
package server
import (
"errors"
"fmt"
"io/ioutil"
"os"
"strings"
@@ -81,3 +83,195 @@ func TestSignalToReloadConfig(t *testing.T) {
t.Fatalf("Reloaded is incorrect.\nexpected: 1\ngot: %d", reloaded)
}
}
func TestProcessSignalNoProcesses(t *testing.T) {
err := ProcessSignal(CommandStop, -1)
if err == nil {
t.Fatal("Expected error")
}
expectedStr := "no gnatsd processes running"
if err.Error() != expectedStr {
t.Fatalf("Error is incorrect.\nexpected: %s\ngot: %s", expectedStr, err.Error())
}
}
func TestProcessSignalMultipleProcesses(t *testing.T) {
pid := os.Getpid()
pgrepBefore := pgrep
pgrep = func() ([]byte, error) {
return []byte(fmt.Sprintf("123\n456\n%d\n", pid)), nil
}
defer func() {
pgrep = pgrepBefore
}()
err := ProcessSignal(CommandStop, -1)
if err == nil {
t.Fatal("Expected error")
}
expectedStr := "multiple gnatsd processes running:\n123\n456"
if err.Error() != expectedStr {
t.Fatalf("Error is incorrect.\nexpected: %s\ngot: %s", expectedStr, err.Error())
}
}
func TestProcessSignalPgrepError(t *testing.T) {
pgrepBefore := pgrep
pgrep = func() ([]byte, error) {
return nil, errors.New("error")
}
defer func() {
pgrep = pgrepBefore
}()
err := ProcessSignal(CommandStop, -1)
if err == nil {
t.Fatal("Expected error")
}
expectedStr := "unable to resolve pid, try providing one"
if err.Error() != expectedStr {
t.Fatalf("Error is incorrect.\nexpected: %s\ngot: %s", expectedStr, err.Error())
}
}
func TestProcessSignalPgrepMangled(t *testing.T) {
pgrepBefore := pgrep
pgrep = func() ([]byte, error) {
return []byte("12x"), nil
}
defer func() {
pgrep = pgrepBefore
}()
err := ProcessSignal(CommandStop, -1)
if err == nil {
t.Fatal("Expected error")
}
expectedStr := "unable to resolve pid, try providing one"
if err.Error() != expectedStr {
t.Fatalf("Error is incorrect.\nexpected: %s\ngot: %s", expectedStr, err.Error())
}
}
func TestProcessSignalResolveSingleProcess(t *testing.T) {
pid := os.Getpid()
pgrepBefore := pgrep
pgrep = func() ([]byte, error) {
return []byte(fmt.Sprintf("123\n%d\n", pid)), nil
}
defer func() {
pgrep = pgrepBefore
}()
killBefore := kill
called := false
kill = func(pid int, signal syscall.Signal) error {
called = true
if pid != 123 {
t.Fatalf("pid is incorrect.\nexpected: 123\ngot: %d", pid)
}
if signal != syscall.SIGKILL {
t.Fatalf("signal is incorrect.\nexpected: killed\ngot: %v", signal)
}
return nil
}
defer func() {
kill = killBefore
}()
if err := ProcessSignal(CommandStop, -1); err != nil {
t.Fatalf("ProcessSignal failed: %v", err)
}
if !called {
t.Fatal("Expected kill to be called")
}
}
func TestProcessSignalInvalidCommand(t *testing.T) {
err := ProcessSignal(Command("invalid"), 123)
if err == nil {
t.Fatal("Expected error")
}
expectedStr := "unknown signal \"invalid\""
if err.Error() != expectedStr {
t.Fatalf("Error is incorrect.\nexpected: %s\ngot: %s", expectedStr, err.Error())
}
}
func TestProcessSignalQuitProcess(t *testing.T) {
killBefore := kill
called := false
kill = func(pid int, signal syscall.Signal) error {
called = true
if pid != 123 {
t.Fatalf("pid is incorrect.\nexpected: 123\ngot: %d", pid)
}
if signal != syscall.SIGINT {
t.Fatalf("signal is incorrect.\nexpected: interrupt\ngot: %v", signal)
}
return nil
}
defer func() {
kill = killBefore
}()
if err := ProcessSignal(CommandQuit, 123); err != nil {
t.Fatalf("ProcessSignal failed: %v", err)
}
if !called {
t.Fatal("Expected kill to be called")
}
}
func TestProcessSignalReopenProcess(t *testing.T) {
killBefore := kill
called := false
kill = func(pid int, signal syscall.Signal) error {
called = true
if pid != 123 {
t.Fatalf("pid is incorrect.\nexpected: 123\ngot: %d", pid)
}
if signal != syscall.SIGUSR1 {
t.Fatalf("signal is incorrect.\nexpected: user defined signal 1\ngot: %v", signal)
}
return nil
}
defer func() {
kill = killBefore
}()
if err := ProcessSignal(CommandReopen, 123); err != nil {
t.Fatalf("ProcessSignal failed: %v", err)
}
if !called {
t.Fatal("Expected kill to be called")
}
}
func TestProcessSignalReloadProcess(t *testing.T) {
killBefore := kill
called := false
kill = func(pid int, signal syscall.Signal) error {
called = true
if pid != 123 {
t.Fatalf("pid is incorrect.\nexpected: 123\ngot: %d", pid)
}
if signal != syscall.SIGHUP {
t.Fatalf("signal is incorrect.\nexpected: hangup\ngot: %v", signal)
}
return nil
}
defer func() {
kill = killBefore
}()
if err := ProcessSignal(CommandReload, 123); err != nil {
t.Fatalf("ProcessSignal failed: %v", err)
}
if !called {
t.Fatal("Expected kill to be called")
}
}

View File

@@ -34,7 +34,7 @@ func (s *Server) handleSignals() {
// ProcessSignal sends the given signal command to the running gnatsd service.
// If pid is not -1 or if there is no gnatsd service running, it returns an
// error.
func ProcessSignal(command string, pid int) error {
func ProcessSignal(command Command, pid int) error {
if pid != -1 {
return errors.New("cannot signal pid on Windows")
}
@@ -57,13 +57,13 @@ func ProcessSignal(command string, pid int) error {
)
switch command {
case "stop", "quit":
case CommandStop, CommandQuit:
cmd = svc.Stop
to = svc.Stopped
case "reopen":
case CommandReopen:
cmd = reopenLogCmd
to = svc.Running
case "reload":
case CommandReload:
cmd = svc.ParamChange
to = svc.Running
default: