From fc6df0fbbc74dcd7c202b47b73e36b1e306ddf69 Mon Sep 17 00:00:00 2001 From: Phil Pennock Date: Wed, 27 Oct 2021 12:44:59 -0400 Subject: [PATCH] Redact URLs before logging or returning in error (#2643) * Redact URLs before logging or returning in error This does not affect strings which failed to parse, and in such a scenario there's a mix of "which evil" to accept; we can't sanely find what should be redacted in those cases, so we leave them alone for debugging. The JWT library returns some errors for Operator URLs, but it rejects URLs which contain userinfo, so there can't be passwords in those and they're safe. Fixes #2597 * Test the URL redaction auxiliary functions * End-to-end tests for secrets in debug/trace Create internal/testhelper and move DummyLogger there, so it can be used from the test/ sub-dir too. Let DummyLogger optionally accumulate all log messages, not just retain the last-seen message. Confirm no passwords logged by TestLeafNodeBasicAuthFailover. Change TestNoPasswordsFromConnectTrace to check all trace messages, not just the most recent. Validate existing trace redaction in TestRouteToSelf. * Test for password in solicited route reconnect debug --- internal/testhelper/logging.go | 127 +++++++++++++++++++++++++++++++++ server/accounts.go | 6 +- server/client.go | 4 +- server/client_test.go | 8 +-- server/gateway_test.go | 2 +- server/leafnode.go | 3 +- server/leafnode_test.go | 24 +++++-- server/log_test.go | 77 ++++---------------- server/opts.go | 2 + server/util.go | 41 +++++++++++ server/util_test.go | 37 ++++++++++ test/routes_test.go | 22 +++++- test/test.go | 19 +++++ 13 files changed, 291 insertions(+), 81 deletions(-) create mode 100644 internal/testhelper/logging.go diff --git a/internal/testhelper/logging.go b/internal/testhelper/logging.go new file mode 100644 index 00000000..2ae9ce6b --- /dev/null +++ b/internal/testhelper/logging.go @@ -0,0 +1,127 @@ +// Copyright 2019-2021 The NATS Authors +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package testhelper + +// These routines need to be accessible in both the server and test +// directories, and tests importing a package don't get exported symbols from +// _test.go files in the imported package, so we put them here where they can +// be used freely. + +import ( + "fmt" + "strings" + "sync" + "testing" +) + +type DummyLogger struct { + sync.Mutex + Msg string + AllMsgs []string +} + +func (l *DummyLogger) CheckContent(t *testing.T, expectedStr string) { + t.Helper() + l.Lock() + defer l.Unlock() + if l.Msg != expectedStr { + t.Fatalf("Expected log to be: %v, got %v", expectedStr, l.Msg) + } +} + +func (l *DummyLogger) aggregate() { + if l.AllMsgs != nil { + l.AllMsgs = append(l.AllMsgs, l.Msg) + } +} + +func (l *DummyLogger) Noticef(format string, v ...interface{}) { + l.Lock() + defer l.Unlock() + l.Msg = fmt.Sprintf(format, v...) + l.aggregate() +} +func (l *DummyLogger) Errorf(format string, v ...interface{}) { + l.Lock() + defer l.Unlock() + l.Msg = fmt.Sprintf(format, v...) + l.aggregate() +} +func (l *DummyLogger) Warnf(format string, v ...interface{}) { + l.Lock() + defer l.Unlock() + l.Msg = fmt.Sprintf(format, v...) + l.aggregate() +} +func (l *DummyLogger) Fatalf(format string, v ...interface{}) { + l.Lock() + defer l.Unlock() + l.Msg = fmt.Sprintf(format, v...) + l.aggregate() +} +func (l *DummyLogger) Debugf(format string, v ...interface{}) { + l.Lock() + defer l.Unlock() + l.Msg = fmt.Sprintf(format, v...) + l.aggregate() +} +func (l *DummyLogger) Tracef(format string, v ...interface{}) { + l.Lock() + defer l.Unlock() + l.Msg = fmt.Sprintf(format, v...) + l.aggregate() +} + +// NewDummyLogger creates a dummy logger and allows to ask for logs to be +// retained instead of just keeping the most recent. Use retain to provide an +// initial size estimate on messages (not to provide a max capacity). +func NewDummyLogger(retain uint) *DummyLogger { + l := &DummyLogger{} + if retain > 0 { + l.AllMsgs = make([]string, 0, retain) + } + return l +} + +func (l *DummyLogger) Drain() { + l.Lock() + defer l.Unlock() + if l.AllMsgs == nil { + return + } + l.AllMsgs = make([]string, 0, len(l.AllMsgs)) +} + +func (l *DummyLogger) CheckForProhibited(t *testing.T, reason, needle string) { + t.Helper() + l.Lock() + defer l.Unlock() + + if l.AllMsgs == nil { + t.Fatal("DummyLogger.CheckForProhibited called without AllMsgs being collected") + } + + // Collect _all_ matches, rather than have to re-test repeatedly. + // This will particularly help with less deterministic tests with multiple matches. + shouldFail := false + for i := range l.AllMsgs { + if strings.Contains(l.AllMsgs[i], needle) { + t.Errorf("log contains %s: %v", reason, l.AllMsgs[i]) + shouldFail = true + } + } + if shouldFail { + t.FailNow() + } +} diff --git a/server/accounts.go b/server/accounts.go index 545cff60..f6292b6e 100644 --- a/server/accounts.go +++ b/server/accounts.go @@ -3473,11 +3473,11 @@ func (ur *URLAccResolver) Fetch(name string) (string, error) { url := ur.url + name resp, err := ur.c.Get(url) if err != nil { - return _EMPTY_, fmt.Errorf("could not fetch <%q>: %v", url, err) + return _EMPTY_, fmt.Errorf("could not fetch <%q>: %v", redactURLString(url), err) } else if resp == nil { - return _EMPTY_, fmt.Errorf("could not fetch <%q>: no response", url) + return _EMPTY_, fmt.Errorf("could not fetch <%q>: no response", redactURLString(url)) } else if resp.StatusCode != http.StatusOK { - return _EMPTY_, fmt.Errorf("could not fetch <%q>: %v", url, resp.Status) + return _EMPTY_, fmt.Errorf("could not fetch <%q>: %v", redactURLString(url), resp.Status) } defer resp.Body.Close() body, err := ioutil.ReadAll(resp.Body) diff --git a/server/client.go b/server/client.go index a5bef56d..70ba50f3 100644 --- a/server/client.go +++ b/server/client.go @@ -4826,10 +4826,10 @@ func (c *client) reconnect() { srv.Debugf("Not attempting reconnect for solicited route, already connected to \"%s\"", rid) return } else if rid == srv.info.ID { - srv.Debugf("Detected route to self, ignoring %q", rurl) + srv.Debugf("Detected route to self, ignoring %q", rurl.Redacted()) return } else if rtype != Implicit || retryImplicit { - srv.Debugf("Attempting reconnect for solicited route \"%s\"", rurl) + srv.Debugf("Attempting reconnect for solicited route \"%s\"", rurl.Redacted()) // Keep track of this go-routine so we can wait for it on // server shutdown. srv.startGoRoutine(func() { srv.reConnectToRoute(rurl, rtype) }) diff --git a/server/client_test.go b/server/client_test.go index a233d26c..dfe03843 100644 --- a/server/client_test.go +++ b/server/client_test.go @@ -1819,7 +1819,7 @@ func TestTraceMsg(t *testing.T) { c.traceMsg(ut.Msg) - got := c.srv.logging.logger.(*DummyLogger).msg + got := c.srv.logging.logger.(*DummyLogger).Msg if !reflect.DeepEqual(ut.Wanted, got) { t.Errorf("Desc: %s. Msg %q. Traced msg want: %s, got: %s", ut.Desc, ut.Msg, ut.Wanted, got) } @@ -2427,15 +2427,15 @@ func TestClientConnectionName(t *testing.T) { checkLog := func(suffix string) { t.Helper() l.Lock() - msg := l.msg + msg := l.Msg l.Unlock() if strings.Contains(msg, "(MISSING)") { t.Fatalf("conn name was not escaped properly, got MISSING: %s", msg) } - if !strings.Contains(l.msg, test.kindStr) { + if !strings.Contains(l.Msg, test.kindStr) { t.Fatalf("expected kind to be %q, got: %s", test.kindStr, msg) } - if !strings.HasSuffix(l.msg, suffix) { + if !strings.HasSuffix(l.Msg, suffix) { t.Fatalf("expected statement to end with %q, got %s", suffix, msg) } } diff --git a/server/gateway_test.go b/server/gateway_test.go index 4d4d6f7e..4280ef23 100644 --- a/server/gateway_test.go +++ b/server/gateway_test.go @@ -2946,7 +2946,7 @@ type checkErrorLogger struct { func (l *checkErrorLogger) Errorf(format string, args ...interface{}) { l.DummyLogger.Errorf(format, args...) l.Lock() - if strings.Contains(l.msg, l.checkErrorStr) { + if strings.Contains(l.Msg, l.checkErrorStr) { l.gotError = true } l.Unlock() diff --git a/server/leafnode.go b/server/leafnode.go index dd0616d0..6bb3cecf 100644 --- a/server/leafnode.go +++ b/server/leafnode.go @@ -320,7 +320,7 @@ func validateLeafNode(o *Options) error { } } if !ok { - return fmt.Errorf("remote leaf node configuration cannot have a mix of websocket and non-websocket urls: %q", rcfg.URLs) + return fmt.Errorf("remote leaf node configuration cannot have a mix of websocket and non-websocket urls: %q", redactURLList(rcfg.URLs)) } } } @@ -1187,6 +1187,7 @@ func (c *client) doUpdateLNURLs(cfg *leafNodeCfg, scheme string, URLs []string) for _, surl := range URLs { url, err := url.Parse(fmt.Sprintf("%s://%s", scheme, surl)) if err != nil { + // As per below, the URLs we receive should not have contained URL info, so this should be safe to log. c.Errorf("Error parsing url %q: %v", surl, err) continue } diff --git a/server/leafnode_test.go b/server/leafnode_test.go index ebc6f4dc..5d1e19b5 100644 --- a/server/leafnode_test.go +++ b/server/leafnode_test.go @@ -32,6 +32,8 @@ import ( jwt "github.com/nats-io/jwt/v2" "github.com/nats-io/nats.go" + + "github.com/nats-io/nats-server/v2/internal/testhelper" ) type captureLeafNodeRandomIPLogger struct { @@ -409,7 +411,12 @@ func TestLeafNodeAccountNotFound(t *testing.T) { // This test ensures that we can connect using proper user/password // to a LN URL that was discovered through the INFO protocol. +// We also check that the password doesn't leak to debug/trace logs. func TestLeafNodeBasicAuthFailover(t *testing.T) { + // Something a little longer than "pwd" to prevent false positives amongst many log lines; + // don't make it complex enough to be subject to %-escaping, we want a simple needle search. + fatalPassword := "pwdfatal" + content := ` listen: "127.0.0.1:-1" cluster { @@ -421,18 +428,18 @@ func TestLeafNodeBasicAuthFailover(t *testing.T) { listen: "127.0.0.1:-1" authorization { user: foo - password: pwd + password: %s timeout: 1 } } ` - conf := createConfFile(t, []byte(fmt.Sprintf(content, ""))) + conf := createConfFile(t, []byte(fmt.Sprintf(content, "", fatalPassword))) defer removeFile(t, conf) sb1, ob1 := RunServerWithConfig(conf) defer sb1.Shutdown() - conf = createConfFile(t, []byte(fmt.Sprintf(content, fmt.Sprintf("routes: [nats://127.0.0.1:%d]", ob1.Cluster.Port)))) + conf = createConfFile(t, []byte(fmt.Sprintf(content, fmt.Sprintf("routes: [nats://127.0.0.1:%d]", ob1.Cluster.Port), fatalPassword))) defer removeFile(t, conf) sb2, _ := RunServerWithConfig(conf) @@ -450,17 +457,20 @@ func TestLeafNodeBasicAuthFailover(t *testing.T) { remotes [ { account: "foo" - url: "nats://foo:pwd@127.0.0.1:%d" + url: "nats://foo:%s@127.0.0.1:%d" } ] } ` - conf = createConfFile(t, []byte(fmt.Sprintf(content, ob1.LeafNode.Port))) + conf = createConfFile(t, []byte(fmt.Sprintf(content, fatalPassword, ob1.LeafNode.Port))) defer removeFile(t, conf) sa, _ := RunServerWithConfig(conf) defer sa.Shutdown() + l := testhelper.NewDummyLogger(100) + sa.SetLogger(l, true, true) // we want debug & trace logs, to check for passwords in them + checkLeafNodeConnected(t, sa) // Shutdown sb1, sa should reconnect to sb2 @@ -471,6 +481,10 @@ func TestLeafNodeBasicAuthFailover(t *testing.T) { // Should be able to reconnect checkLeafNodeConnected(t, sa) + + // Look at all our logs for the password; at time of writing it doesn't appear + // but we want to safe-guard against it. + l.CheckForProhibited(t, "fatal password", fatalPassword) } func TestLeafNodeRTT(t *testing.T) { diff --git a/server/log_test.go b/server/log_test.go index 2014ea62..23e20338 100644 --- a/server/log_test.go +++ b/server/log_test.go @@ -15,15 +15,14 @@ package server import ( "bytes" - "fmt" "io/ioutil" "os" "runtime" "strings" - "sync" "testing" "time" + "github.com/nats-io/nats-server/v2/internal/testhelper" "github.com/nats-io/nats-server/v2/logger" ) @@ -47,22 +46,22 @@ func TestSetLogger(t *testing.T) { // Check traces expectedStr := "This is a Notice" server.Noticef(expectedStr) - dl.checkContent(t, expectedStr) + dl.CheckContent(t, expectedStr) expectedStr = "This is an Error" server.Errorf(expectedStr) - dl.checkContent(t, expectedStr) + dl.CheckContent(t, expectedStr) expectedStr = "This is a Fatal" server.Fatalf(expectedStr) - dl.checkContent(t, expectedStr) + dl.CheckContent(t, expectedStr) expectedStr = "This is a Debug" server.Debugf(expectedStr) - dl.checkContent(t, expectedStr) + dl.CheckContent(t, expectedStr) expectedStr = "This is a Trace" server.Tracef(expectedStr) - dl.checkContent(t, expectedStr) + dl.CheckContent(t, expectedStr) expectedStr = "This is a Warning" server.Tracef(expectedStr) - dl.checkContent(t, expectedStr) + dl.CheckContent(t, expectedStr) // Make sure that we can reset to fal server.SetLogger(dl, false, false) @@ -73,56 +72,14 @@ func TestSetLogger(t *testing.T) { t.Fatalf("Expected trace 0, got %v", server.logging.trace) } // Now, Debug and Trace should not produce anything - dl.msg = "" + dl.Msg = "" server.Debugf("This Debug should not be traced") - dl.checkContent(t, "") + dl.CheckContent(t, "") server.Tracef("This Trace should not be traced") - dl.checkContent(t, "") + dl.CheckContent(t, "") } -type DummyLogger struct { - sync.Mutex - msg string -} - -func (l *DummyLogger) checkContent(t *testing.T, expectedStr string) { - l.Lock() - defer l.Unlock() - if l.msg != expectedStr { - stackFatalf(t, "Expected log to be: %v, got %v", expectedStr, l.msg) - } -} - -func (l *DummyLogger) Noticef(format string, v ...interface{}) { - l.Lock() - defer l.Unlock() - l.msg = fmt.Sprintf(format, v...) -} -func (l *DummyLogger) Errorf(format string, v ...interface{}) { - l.Lock() - defer l.Unlock() - l.msg = fmt.Sprintf(format, v...) -} -func (l *DummyLogger) Warnf(format string, v ...interface{}) { - l.Lock() - defer l.Unlock() - l.msg = fmt.Sprintf(format, v...) -} -func (l *DummyLogger) Fatalf(format string, v ...interface{}) { - l.Lock() - defer l.Unlock() - l.msg = fmt.Sprintf(format, v...) -} -func (l *DummyLogger) Debugf(format string, v ...interface{}) { - l.Lock() - defer l.Unlock() - l.msg = fmt.Sprintf(format, v...) -} -func (l *DummyLogger) Tracef(format string, v ...interface{}) { - l.Lock() - defer l.Unlock() - l.msg = fmt.Sprintf(format, v...) -} +type DummyLogger = testhelper.DummyLogger func TestReOpenLogFile(t *testing.T) { // We can't rename the file log when still opened on Windows, so skip @@ -140,7 +97,7 @@ func TestReOpenLogFile(t *testing.T) { dl := &DummyLogger{} s.SetLogger(dl, false, false) s.ReOpenLogFile() - dl.checkContent(t, "File log re-open ignored, not a file logger") + dl.CheckContent(t, "File log re-open ignored, not a file logger") // Set a File log s.opts.LogFile = "test.log" @@ -247,7 +204,7 @@ func TestNoPasswordsFromConnectTrace(t *testing.T) { opts.PingInterval = 2 * time.Minute setBaselineOptions(opts) s := &Server{opts: opts} - dl := &DummyLogger{} + dl := testhelper.NewDummyLogger(100) s.SetLogger(dl, false, true) _ = s.logging.logger.(*DummyLogger) @@ -265,13 +222,7 @@ func TestNoPasswordsFromConnectTrace(t *testing.T) { t.Fatalf("Received error: %v\n", err) } - dl.Lock() - hasPass := strings.Contains(dl.msg, "s3cr3t") - dl.Unlock() - - if hasPass { - t.Fatalf("Password detected in log output: %s", dl.msg) - } + dl.CheckForProhibited(t, "password found", "s3cr3t") } func TestRemovePassFromTrace(t *testing.T) { diff --git a/server/opts.go b/server/opts.go index e733a495..5543ca84 100644 --- a/server/opts.go +++ b/server/opts.go @@ -1472,6 +1472,8 @@ func parseURL(u string, typ string) (*url.URL, error) { urlStr := strings.TrimSpace(u) url, err := url.Parse(urlStr) if err != nil { + // Security note: if it's not well-formed but still reached us, then we're going to log as-is which might include password information here. + // If the URL parses, we don't log the credentials ever, but if it doesn't even parse we don't have a sane way to redact. return nil, fmt.Errorf("error parsing %s url [%q]", typ, urlStr) } return url, nil diff --git a/server/util.go b/server/util.go index 88b3f635..3113bd80 100644 --- a/server/util.go +++ b/server/util.go @@ -221,3 +221,44 @@ func natsDialTimeout(network, address string, timeout time.Duration) (net.Conn, } return d.Dial(network, address) } + +// redactURLList() returns a copy of a list of URL pointers where each item +// in the list will either be the same pointer if the URL does not contain a +// password, or to a new object if there is a password. +// The intended use-case is for logging lists of URLs safely. +func redactURLList(unredacted []*url.URL) []*url.URL { + r := make([]*url.URL, len(unredacted)) + // In the common case of no passwords, if we don't let the new object leave + // this function then GC should be easier. + needCopy := false + for i := range unredacted { + if unredacted[i] == nil { + r[i] = nil + continue + } + if _, has := unredacted[i].User.Password(); !has { + r[i] = unredacted[i] + continue + } + needCopy = true + ru := *unredacted[i] + ru.User = url.UserPassword(ru.User.Username(), "xxxxx") + r[i] = &ru + } + if needCopy { + return r + } + return unredacted +} + +// redactURLString() attempts to redact a URL string. +func redactURLString(raw string) string { + if !strings.ContainsRune(raw, '@') { + return raw + } + u, err := url.Parse(raw) + if err != nil { + return raw + } + return u.Redacted() +} diff --git a/server/util_test.go b/server/util_test.go index ec84315b..23a47e72 100644 --- a/server/util_test.go +++ b/server/util_test.go @@ -17,6 +17,7 @@ import ( "math" "math/rand" "net/url" + "reflect" "strconv" "sync" "testing" @@ -147,12 +148,48 @@ func TestComma(t *testing.T) { {"-10", comma(-10), "-10"}, } + failed := false for _, test := range l { if test.got != test.exp { t.Errorf("On %v, expected '%v', but got '%v'", test.name, test.exp, test.got) + failed = true } } + if failed { + t.FailNow() + } +} + +func TestURLRedaction(t *testing.T) { + redactionFromTo := []struct { + Full string + Safe string + }{ + {"nats://foo:bar@example.org", "nats://foo:xxxxx@example.org"}, + {"nats://foo@example.org", "nats://foo@example.org"}, + {"nats://example.org", "nats://example.org"}, + {"nats://example.org/foo?bar=1", "nats://example.org/foo?bar=1"}, + } + var err error + listFull := make([]*url.URL, len(redactionFromTo)) + listSafe := make([]*url.URL, len(redactionFromTo)) + for i := range redactionFromTo { + r := redactURLString(redactionFromTo[i].Full) + if r != redactionFromTo[i].Safe { + t.Fatalf("Redacting URL [index %d] %q, expected %q got %q", i, redactionFromTo[i].Full, redactionFromTo[i].Safe, r) + } + if listFull[i], err = url.Parse(redactionFromTo[i].Full); err != nil { + t.Fatalf("Redacting URL index %d parse Full failed: %v", i, err) + } + if listSafe[i], err = url.Parse(redactionFromTo[i].Safe); err != nil { + t.Fatalf("Redacting URL index %d parse Safe failed: %v", i, err) + } + } + results := redactURLList(listFull) + if !reflect.DeepEqual(results, listSafe) { + t.Fatalf("Redacting URL list did not compare equal, even after each URL did") + } } func BenchmarkParseInt(b *testing.B) { diff --git a/test/routes_test.go b/test/routes_test.go index 27f9c824..982ba398 100644 --- a/test/routes_test.go +++ b/test/routes_test.go @@ -24,6 +24,7 @@ import ( "testing" "time" + "github.com/nats-io/nats-server/v2/internal/testhelper" "github.com/nats-io/nats-server/v2/server" "github.com/nats-io/nats.go" ) @@ -34,6 +35,10 @@ func runRouteServer(t *testing.T) (*server.Server, *server.Options) { return RunServerWithConfig("./configs/cluster.conf") } +func runRouteServerOverrides(t *testing.T, cbo func(*server.Options), cbs func(*server.Server)) (*server.Server, *server.Options) { + return RunServerWithConfigOverrides("./configs/cluster.conf", cbo, cbs) +} + func TestRouterListeningSocket(t *testing.T) { s, opts := runRouteServer(t) defer s.Shutdown() @@ -93,7 +98,11 @@ func TestSendRouteInfoOnConnect(t *testing.T) { } func TestRouteToSelf(t *testing.T) { - s, opts := runRouteServer(t) + l := testhelper.NewDummyLogger(100) + s, opts := runRouteServerOverrides(t, nil, + func(s *server.Server) { + s.SetLogger(l, true, true) + }) defer s.Shutdown() rc := createRouteConn(t, opts.Cluster.Host, opts.Cluster.Port) @@ -123,6 +132,8 @@ func TestRouteToSelf(t *testing.T) { if _, err := rc.Read(buf); err == nil { t.Fatal("Expected route connection to be closed") } + // This should have been removed by removePassFromTrace(), but we also check debug logs here + l.CheckForProhibited(t, "route authorization password found", "top_secret") } func TestSendRouteSubAndUnsub(t *testing.T) { @@ -375,7 +386,11 @@ func TestRouteQueueSemantics(t *testing.T) { } func TestSolicitRouteReconnect(t *testing.T) { - s, opts := runRouteServer(t) + l := testhelper.NewDummyLogger(100) + s, opts := runRouteServerOverrides(t, nil, + func(s *server.Server) { + s.SetLogger(l, true, true) + }) defer s.Shutdown() rURL := opts.Routes[0] @@ -388,6 +403,9 @@ func TestSolicitRouteReconnect(t *testing.T) { // We expect to get called back.. route = acceptRouteConn(t, rURL.Host, 2*server.DEFAULT_ROUTE_CONNECT) route.Close() + + // Now we want to check for the debug logs when it tries to reconnect + l.CheckForProhibited(t, "route authorization password found", ":bar") } func TestMultipleRoutesSameId(t *testing.T) { diff --git a/test/test.go b/test/test.go index a872b04b..b27a6605 100644 --- a/test/test.go +++ b/test/test.go @@ -72,6 +72,10 @@ var ( // RunServer starts a new Go routine based server func RunServer(opts *server.Options) *server.Server { + return RunServerCallback(opts, nil) +} + +func RunServerCallback(opts *server.Options, callback func(*server.Server)) *server.Server { if opts == nil { opts = &DefaultTestOptions } @@ -89,6 +93,10 @@ func RunServer(opts *server.Options) *server.Server { s.ConfigureLogger() } + if callback != nil { + callback(s) + } + // Run server in Go routine. go s.Start() @@ -115,6 +123,17 @@ func RunServerWithConfig(configFile string) (srv *server.Server, opts *server.Op return } +// RunServerWithConfigOverrides starts a new Go routine based server with a configuration file, +// providing a callback to update the options configured. +func RunServerWithConfigOverrides(configFile string, optsCallback func(*server.Options), svrCallback func(*server.Server)) (srv *server.Server, opts *server.Options) { + opts = LoadConfig(configFile) + if optsCallback != nil { + optsCallback(opts) + } + srv = RunServerCallback(opts, svrCallback) + return +} + func stackFatalf(t tLogger, f string, args ...interface{}) { lines := make([]string, 0, 32) msg := fmt.Sprintf(f, args...)