diff --git a/main.go b/main.go index f491a5df..5d86f165 100644 --- a/main.go +++ b/main.go @@ -8,7 +8,6 @@ import ( "net" "net/url" "os" - "strings" "github.com/nats-io/gnatsd/auth" "github.com/nats-io/gnatsd/logger" @@ -138,16 +137,13 @@ func main() { // Process args looking for non-flag options, // 'version' and 'help' only for now - if len(flag.Args()) > 0 { - switch strings.ToLower(flag.Args()[0]) { - case "version": - server.PrintServerAndExit() - case "help": - usage() - default: - // Unrecognized command - usage() - } + showVersion, showHelp, err := server.ProcessCommandLineArgs(flag.CommandLine) + if err != nil { + server.PrintAndDie(err.Error() + usageStr) + } else if showVersion { + server.PrintServerAndExit() + } else if showHelp { + usage() } // Parse config if given diff --git a/server/server.go b/server/server.go index 14dd800c..ee888e16 100644 --- a/server/server.go +++ b/server/server.go @@ -6,6 +6,7 @@ import ( "bufio" "crypto/tls" "encoding/json" + "flag" "fmt" "io/ioutil" "net" @@ -13,6 +14,7 @@ import ( "os" "runtime" "strconv" + "strings" "sync" "time" @@ -180,6 +182,25 @@ func PrintServerAndExit() { os.Exit(0) } +// ProcessCommandLineArgs takes the command line arguments +// validating and setting flags for handling in case any +// sub command was present. +func ProcessCommandLineArgs(cmd *flag.FlagSet) (showVersion bool, showHelp bool, err error) { + if len(cmd.Args()) > 0 { + arg := cmd.Args()[0] + switch strings.ToLower(arg) { + case "version": + return true, false, nil + case "help": + return false, true, nil + default: + return false, false, fmt.Errorf("Unrecognized command: %q\n", arg) + } + } + + return false, false, nil +} + // Protected check on running state func (s *Server) isRunning() bool { s.mu.Lock() diff --git a/server/server_test.go b/server/server_test.go index 141302ed..18f304e7 100644 --- a/server/server_test.go +++ b/server/server_test.go @@ -8,6 +8,7 @@ import ( "strings" "testing" "time" + "flag" "github.com/nats-io/go-nats" ) @@ -246,3 +247,50 @@ func TestMaxConnections(t *testing.T) { t.Fatal("Expected connection to fail") } } + +func TestProcessCommandLineArgs(t *testing.T) { + var host string + var port int + cmd := flag.NewFlagSet("gnatsd", flag.ExitOnError) + cmd.StringVar(&host, "a", "0.0.0.0", "Host.") + cmd.IntVar(&port, "p", 4222, "Port.") + + cmd.Parse([]string{"-a", "127.0.0.1", "-p", "9090"}) + showVersion, showHelp, err := ProcessCommandLineArgs(cmd) + if err != nil { + t.Errorf("Expected no errors, got: %s", err) + } + if showVersion || showHelp { + t.Errorf("Expected not having to handle subcommands") + } + + cmd.Parse([]string{"version"}) + showVersion, showHelp, err = ProcessCommandLineArgs(cmd) + if err != nil { + t.Errorf("Expected no errors, got: %s", err) + } + if !showVersion { + t.Errorf("Expected having to handle version command") + } + if showHelp { + t.Errorf("Expected not having to handle help command") + } + + cmd.Parse([]string{"help"}) + showVersion, showHelp, err = ProcessCommandLineArgs(cmd) + if err != nil { + t.Errorf("Expected no errors, got: %s", err) + } + if showVersion { + t.Errorf("Expected not having to handle version command") + } + if !showHelp { + t.Errorf("Expected having to handle help command") + } + + cmd.Parse([]string{"foo", "-p", "9090"}) + _, _, err = ProcessCommandLineArgs(cmd) + if err == nil { + t.Errorf("Expected an error handling the command arguments") + } +}