diff --git a/internal/testhelper/logging.go b/internal/testhelper/logging.go new file mode 100644 index 00000000..2ae9ce6b --- /dev/null +++ b/internal/testhelper/logging.go @@ -0,0 +1,127 @@ +// Copyright 2019-2021 The NATS Authors +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package testhelper + +// These routines need to be accessible in both the server and test +// directories, and tests importing a package don't get exported symbols from +// _test.go files in the imported package, so we put them here where they can +// be used freely. + +import ( + "fmt" + "strings" + "sync" + "testing" +) + +type DummyLogger struct { + sync.Mutex + Msg string + AllMsgs []string +} + +func (l *DummyLogger) CheckContent(t *testing.T, expectedStr string) { + t.Helper() + l.Lock() + defer l.Unlock() + if l.Msg != expectedStr { + t.Fatalf("Expected log to be: %v, got %v", expectedStr, l.Msg) + } +} + +func (l *DummyLogger) aggregate() { + if l.AllMsgs != nil { + l.AllMsgs = append(l.AllMsgs, l.Msg) + } +} + +func (l *DummyLogger) Noticef(format string, v ...interface{}) { + l.Lock() + defer l.Unlock() + l.Msg = fmt.Sprintf(format, v...) + l.aggregate() +} +func (l *DummyLogger) Errorf(format string, v ...interface{}) { + l.Lock() + defer l.Unlock() + l.Msg = fmt.Sprintf(format, v...) + l.aggregate() +} +func (l *DummyLogger) Warnf(format string, v ...interface{}) { + l.Lock() + defer l.Unlock() + l.Msg = fmt.Sprintf(format, v...) + l.aggregate() +} +func (l *DummyLogger) Fatalf(format string, v ...interface{}) { + l.Lock() + defer l.Unlock() + l.Msg = fmt.Sprintf(format, v...) + l.aggregate() +} +func (l *DummyLogger) Debugf(format string, v ...interface{}) { + l.Lock() + defer l.Unlock() + l.Msg = fmt.Sprintf(format, v...) + l.aggregate() +} +func (l *DummyLogger) Tracef(format string, v ...interface{}) { + l.Lock() + defer l.Unlock() + l.Msg = fmt.Sprintf(format, v...) + l.aggregate() +} + +// NewDummyLogger creates a dummy logger and allows to ask for logs to be +// retained instead of just keeping the most recent. Use retain to provide an +// initial size estimate on messages (not to provide a max capacity). +func NewDummyLogger(retain uint) *DummyLogger { + l := &DummyLogger{} + if retain > 0 { + l.AllMsgs = make([]string, 0, retain) + } + return l +} + +func (l *DummyLogger) Drain() { + l.Lock() + defer l.Unlock() + if l.AllMsgs == nil { + return + } + l.AllMsgs = make([]string, 0, len(l.AllMsgs)) +} + +func (l *DummyLogger) CheckForProhibited(t *testing.T, reason, needle string) { + t.Helper() + l.Lock() + defer l.Unlock() + + if l.AllMsgs == nil { + t.Fatal("DummyLogger.CheckForProhibited called without AllMsgs being collected") + } + + // Collect _all_ matches, rather than have to re-test repeatedly. + // This will particularly help with less deterministic tests with multiple matches. + shouldFail := false + for i := range l.AllMsgs { + if strings.Contains(l.AllMsgs[i], needle) { + t.Errorf("log contains %s: %v", reason, l.AllMsgs[i]) + shouldFail = true + } + } + if shouldFail { + t.FailNow() + } +} diff --git a/server/accounts.go b/server/accounts.go index 545cff60..f6292b6e 100644 --- a/server/accounts.go +++ b/server/accounts.go @@ -3473,11 +3473,11 @@ func (ur *URLAccResolver) Fetch(name string) (string, error) { url := ur.url + name resp, err := ur.c.Get(url) if err != nil { - return _EMPTY_, fmt.Errorf("could not fetch <%q>: %v", url, err) + return _EMPTY_, fmt.Errorf("could not fetch <%q>: %v", redactURLString(url), err) } else if resp == nil { - return _EMPTY_, fmt.Errorf("could not fetch <%q>: no response", url) + return _EMPTY_, fmt.Errorf("could not fetch <%q>: no response", redactURLString(url)) } else if resp.StatusCode != http.StatusOK { - return _EMPTY_, fmt.Errorf("could not fetch <%q>: %v", url, resp.Status) + return _EMPTY_, fmt.Errorf("could not fetch <%q>: %v", redactURLString(url), resp.Status) } defer resp.Body.Close() body, err := ioutil.ReadAll(resp.Body) diff --git a/server/client.go b/server/client.go index a5bef56d..70ba50f3 100644 --- a/server/client.go +++ b/server/client.go @@ -4826,10 +4826,10 @@ func (c *client) reconnect() { srv.Debugf("Not attempting reconnect for solicited route, already connected to \"%s\"", rid) return } else if rid == srv.info.ID { - srv.Debugf("Detected route to self, ignoring %q", rurl) + srv.Debugf("Detected route to self, ignoring %q", rurl.Redacted()) return } else if rtype != Implicit || retryImplicit { - srv.Debugf("Attempting reconnect for solicited route \"%s\"", rurl) + srv.Debugf("Attempting reconnect for solicited route \"%s\"", rurl.Redacted()) // Keep track of this go-routine so we can wait for it on // server shutdown. srv.startGoRoutine(func() { srv.reConnectToRoute(rurl, rtype) }) diff --git a/server/client_test.go b/server/client_test.go index a233d26c..dfe03843 100644 --- a/server/client_test.go +++ b/server/client_test.go @@ -1819,7 +1819,7 @@ func TestTraceMsg(t *testing.T) { c.traceMsg(ut.Msg) - got := c.srv.logging.logger.(*DummyLogger).msg + got := c.srv.logging.logger.(*DummyLogger).Msg if !reflect.DeepEqual(ut.Wanted, got) { t.Errorf("Desc: %s. Msg %q. Traced msg want: %s, got: %s", ut.Desc, ut.Msg, ut.Wanted, got) } @@ -2427,15 +2427,15 @@ func TestClientConnectionName(t *testing.T) { checkLog := func(suffix string) { t.Helper() l.Lock() - msg := l.msg + msg := l.Msg l.Unlock() if strings.Contains(msg, "(MISSING)") { t.Fatalf("conn name was not escaped properly, got MISSING: %s", msg) } - if !strings.Contains(l.msg, test.kindStr) { + if !strings.Contains(l.Msg, test.kindStr) { t.Fatalf("expected kind to be %q, got: %s", test.kindStr, msg) } - if !strings.HasSuffix(l.msg, suffix) { + if !strings.HasSuffix(l.Msg, suffix) { t.Fatalf("expected statement to end with %q, got %s", suffix, msg) } } diff --git a/server/gateway_test.go b/server/gateway_test.go index 4d4d6f7e..4280ef23 100644 --- a/server/gateway_test.go +++ b/server/gateway_test.go @@ -2946,7 +2946,7 @@ type checkErrorLogger struct { func (l *checkErrorLogger) Errorf(format string, args ...interface{}) { l.DummyLogger.Errorf(format, args...) l.Lock() - if strings.Contains(l.msg, l.checkErrorStr) { + if strings.Contains(l.Msg, l.checkErrorStr) { l.gotError = true } l.Unlock() diff --git a/server/leafnode.go b/server/leafnode.go index dd0616d0..6bb3cecf 100644 --- a/server/leafnode.go +++ b/server/leafnode.go @@ -320,7 +320,7 @@ func validateLeafNode(o *Options) error { } } if !ok { - return fmt.Errorf("remote leaf node configuration cannot have a mix of websocket and non-websocket urls: %q", rcfg.URLs) + return fmt.Errorf("remote leaf node configuration cannot have a mix of websocket and non-websocket urls: %q", redactURLList(rcfg.URLs)) } } } @@ -1187,6 +1187,7 @@ func (c *client) doUpdateLNURLs(cfg *leafNodeCfg, scheme string, URLs []string) for _, surl := range URLs { url, err := url.Parse(fmt.Sprintf("%s://%s", scheme, surl)) if err != nil { + // As per below, the URLs we receive should not have contained URL info, so this should be safe to log. c.Errorf("Error parsing url %q: %v", surl, err) continue } diff --git a/server/leafnode_test.go b/server/leafnode_test.go index ebc6f4dc..5d1e19b5 100644 --- a/server/leafnode_test.go +++ b/server/leafnode_test.go @@ -32,6 +32,8 @@ import ( jwt "github.com/nats-io/jwt/v2" "github.com/nats-io/nats.go" + + "github.com/nats-io/nats-server/v2/internal/testhelper" ) type captureLeafNodeRandomIPLogger struct { @@ -409,7 +411,12 @@ func TestLeafNodeAccountNotFound(t *testing.T) { // This test ensures that we can connect using proper user/password // to a LN URL that was discovered through the INFO protocol. +// We also check that the password doesn't leak to debug/trace logs. func TestLeafNodeBasicAuthFailover(t *testing.T) { + // Something a little longer than "pwd" to prevent false positives amongst many log lines; + // don't make it complex enough to be subject to %-escaping, we want a simple needle search. + fatalPassword := "pwdfatal" + content := ` listen: "127.0.0.1:-1" cluster { @@ -421,18 +428,18 @@ func TestLeafNodeBasicAuthFailover(t *testing.T) { listen: "127.0.0.1:-1" authorization { user: foo - password: pwd + password: %s timeout: 1 } } ` - conf := createConfFile(t, []byte(fmt.Sprintf(content, ""))) + conf := createConfFile(t, []byte(fmt.Sprintf(content, "", fatalPassword))) defer removeFile(t, conf) sb1, ob1 := RunServerWithConfig(conf) defer sb1.Shutdown() - conf = createConfFile(t, []byte(fmt.Sprintf(content, fmt.Sprintf("routes: [nats://127.0.0.1:%d]", ob1.Cluster.Port)))) + conf = createConfFile(t, []byte(fmt.Sprintf(content, fmt.Sprintf("routes: [nats://127.0.0.1:%d]", ob1.Cluster.Port), fatalPassword))) defer removeFile(t, conf) sb2, _ := RunServerWithConfig(conf) @@ -450,17 +457,20 @@ func TestLeafNodeBasicAuthFailover(t *testing.T) { remotes [ { account: "foo" - url: "nats://foo:pwd@127.0.0.1:%d" + url: "nats://foo:%s@127.0.0.1:%d" } ] } ` - conf = createConfFile(t, []byte(fmt.Sprintf(content, ob1.LeafNode.Port))) + conf = createConfFile(t, []byte(fmt.Sprintf(content, fatalPassword, ob1.LeafNode.Port))) defer removeFile(t, conf) sa, _ := RunServerWithConfig(conf) defer sa.Shutdown() + l := testhelper.NewDummyLogger(100) + sa.SetLogger(l, true, true) // we want debug & trace logs, to check for passwords in them + checkLeafNodeConnected(t, sa) // Shutdown sb1, sa should reconnect to sb2 @@ -471,6 +481,10 @@ func TestLeafNodeBasicAuthFailover(t *testing.T) { // Should be able to reconnect checkLeafNodeConnected(t, sa) + + // Look at all our logs for the password; at time of writing it doesn't appear + // but we want to safe-guard against it. + l.CheckForProhibited(t, "fatal password", fatalPassword) } func TestLeafNodeRTT(t *testing.T) { diff --git a/server/log_test.go b/server/log_test.go index 2014ea62..23e20338 100644 --- a/server/log_test.go +++ b/server/log_test.go @@ -15,15 +15,14 @@ package server import ( "bytes" - "fmt" "io/ioutil" "os" "runtime" "strings" - "sync" "testing" "time" + "github.com/nats-io/nats-server/v2/internal/testhelper" "github.com/nats-io/nats-server/v2/logger" ) @@ -47,22 +46,22 @@ func TestSetLogger(t *testing.T) { // Check traces expectedStr := "This is a Notice" server.Noticef(expectedStr) - dl.checkContent(t, expectedStr) + dl.CheckContent(t, expectedStr) expectedStr = "This is an Error" server.Errorf(expectedStr) - dl.checkContent(t, expectedStr) + dl.CheckContent(t, expectedStr) expectedStr = "This is a Fatal" server.Fatalf(expectedStr) - dl.checkContent(t, expectedStr) + dl.CheckContent(t, expectedStr) expectedStr = "This is a Debug" server.Debugf(expectedStr) - dl.checkContent(t, expectedStr) + dl.CheckContent(t, expectedStr) expectedStr = "This is a Trace" server.Tracef(expectedStr) - dl.checkContent(t, expectedStr) + dl.CheckContent(t, expectedStr) expectedStr = "This is a Warning" server.Tracef(expectedStr) - dl.checkContent(t, expectedStr) + dl.CheckContent(t, expectedStr) // Make sure that we can reset to fal server.SetLogger(dl, false, false) @@ -73,56 +72,14 @@ func TestSetLogger(t *testing.T) { t.Fatalf("Expected trace 0, got %v", server.logging.trace) } // Now, Debug and Trace should not produce anything - dl.msg = "" + dl.Msg = "" server.Debugf("This Debug should not be traced") - dl.checkContent(t, "") + dl.CheckContent(t, "") server.Tracef("This Trace should not be traced") - dl.checkContent(t, "") + dl.CheckContent(t, "") } -type DummyLogger struct { - sync.Mutex - msg string -} - -func (l *DummyLogger) checkContent(t *testing.T, expectedStr string) { - l.Lock() - defer l.Unlock() - if l.msg != expectedStr { - stackFatalf(t, "Expected log to be: %v, got %v", expectedStr, l.msg) - } -} - -func (l *DummyLogger) Noticef(format string, v ...interface{}) { - l.Lock() - defer l.Unlock() - l.msg = fmt.Sprintf(format, v...) -} -func (l *DummyLogger) Errorf(format string, v ...interface{}) { - l.Lock() - defer l.Unlock() - l.msg = fmt.Sprintf(format, v...) -} -func (l *DummyLogger) Warnf(format string, v ...interface{}) { - l.Lock() - defer l.Unlock() - l.msg = fmt.Sprintf(format, v...) -} -func (l *DummyLogger) Fatalf(format string, v ...interface{}) { - l.Lock() - defer l.Unlock() - l.msg = fmt.Sprintf(format, v...) -} -func (l *DummyLogger) Debugf(format string, v ...interface{}) { - l.Lock() - defer l.Unlock() - l.msg = fmt.Sprintf(format, v...) -} -func (l *DummyLogger) Tracef(format string, v ...interface{}) { - l.Lock() - defer l.Unlock() - l.msg = fmt.Sprintf(format, v...) -} +type DummyLogger = testhelper.DummyLogger func TestReOpenLogFile(t *testing.T) { // We can't rename the file log when still opened on Windows, so skip @@ -140,7 +97,7 @@ func TestReOpenLogFile(t *testing.T) { dl := &DummyLogger{} s.SetLogger(dl, false, false) s.ReOpenLogFile() - dl.checkContent(t, "File log re-open ignored, not a file logger") + dl.CheckContent(t, "File log re-open ignored, not a file logger") // Set a File log s.opts.LogFile = "test.log" @@ -247,7 +204,7 @@ func TestNoPasswordsFromConnectTrace(t *testing.T) { opts.PingInterval = 2 * time.Minute setBaselineOptions(opts) s := &Server{opts: opts} - dl := &DummyLogger{} + dl := testhelper.NewDummyLogger(100) s.SetLogger(dl, false, true) _ = s.logging.logger.(*DummyLogger) @@ -265,13 +222,7 @@ func TestNoPasswordsFromConnectTrace(t *testing.T) { t.Fatalf("Received error: %v\n", err) } - dl.Lock() - hasPass := strings.Contains(dl.msg, "s3cr3t") - dl.Unlock() - - if hasPass { - t.Fatalf("Password detected in log output: %s", dl.msg) - } + dl.CheckForProhibited(t, "password found", "s3cr3t") } func TestRemovePassFromTrace(t *testing.T) { diff --git a/server/opts.go b/server/opts.go index e733a495..5543ca84 100644 --- a/server/opts.go +++ b/server/opts.go @@ -1472,6 +1472,8 @@ func parseURL(u string, typ string) (*url.URL, error) { urlStr := strings.TrimSpace(u) url, err := url.Parse(urlStr) if err != nil { + // Security note: if it's not well-formed but still reached us, then we're going to log as-is which might include password information here. + // If the URL parses, we don't log the credentials ever, but if it doesn't even parse we don't have a sane way to redact. return nil, fmt.Errorf("error parsing %s url [%q]", typ, urlStr) } return url, nil diff --git a/server/util.go b/server/util.go index 88b3f635..3113bd80 100644 --- a/server/util.go +++ b/server/util.go @@ -221,3 +221,44 @@ func natsDialTimeout(network, address string, timeout time.Duration) (net.Conn, } return d.Dial(network, address) } + +// redactURLList() returns a copy of a list of URL pointers where each item +// in the list will either be the same pointer if the URL does not contain a +// password, or to a new object if there is a password. +// The intended use-case is for logging lists of URLs safely. +func redactURLList(unredacted []*url.URL) []*url.URL { + r := make([]*url.URL, len(unredacted)) + // In the common case of no passwords, if we don't let the new object leave + // this function then GC should be easier. + needCopy := false + for i := range unredacted { + if unredacted[i] == nil { + r[i] = nil + continue + } + if _, has := unredacted[i].User.Password(); !has { + r[i] = unredacted[i] + continue + } + needCopy = true + ru := *unredacted[i] + ru.User = url.UserPassword(ru.User.Username(), "xxxxx") + r[i] = &ru + } + if needCopy { + return r + } + return unredacted +} + +// redactURLString() attempts to redact a URL string. +func redactURLString(raw string) string { + if !strings.ContainsRune(raw, '@') { + return raw + } + u, err := url.Parse(raw) + if err != nil { + return raw + } + return u.Redacted() +} diff --git a/server/util_test.go b/server/util_test.go index ec84315b..23a47e72 100644 --- a/server/util_test.go +++ b/server/util_test.go @@ -17,6 +17,7 @@ import ( "math" "math/rand" "net/url" + "reflect" "strconv" "sync" "testing" @@ -147,12 +148,48 @@ func TestComma(t *testing.T) { {"-10", comma(-10), "-10"}, } + failed := false for _, test := range l { if test.got != test.exp { t.Errorf("On %v, expected '%v', but got '%v'", test.name, test.exp, test.got) + failed = true } } + if failed { + t.FailNow() + } +} + +func TestURLRedaction(t *testing.T) { + redactionFromTo := []struct { + Full string + Safe string + }{ + {"nats://foo:bar@example.org", "nats://foo:xxxxx@example.org"}, + {"nats://foo@example.org", "nats://foo@example.org"}, + {"nats://example.org", "nats://example.org"}, + {"nats://example.org/foo?bar=1", "nats://example.org/foo?bar=1"}, + } + var err error + listFull := make([]*url.URL, len(redactionFromTo)) + listSafe := make([]*url.URL, len(redactionFromTo)) + for i := range redactionFromTo { + r := redactURLString(redactionFromTo[i].Full) + if r != redactionFromTo[i].Safe { + t.Fatalf("Redacting URL [index %d] %q, expected %q got %q", i, redactionFromTo[i].Full, redactionFromTo[i].Safe, r) + } + if listFull[i], err = url.Parse(redactionFromTo[i].Full); err != nil { + t.Fatalf("Redacting URL index %d parse Full failed: %v", i, err) + } + if listSafe[i], err = url.Parse(redactionFromTo[i].Safe); err != nil { + t.Fatalf("Redacting URL index %d parse Safe failed: %v", i, err) + } + } + results := redactURLList(listFull) + if !reflect.DeepEqual(results, listSafe) { + t.Fatalf("Redacting URL list did not compare equal, even after each URL did") + } } func BenchmarkParseInt(b *testing.B) { diff --git a/test/routes_test.go b/test/routes_test.go index 27f9c824..982ba398 100644 --- a/test/routes_test.go +++ b/test/routes_test.go @@ -24,6 +24,7 @@ import ( "testing" "time" + "github.com/nats-io/nats-server/v2/internal/testhelper" "github.com/nats-io/nats-server/v2/server" "github.com/nats-io/nats.go" ) @@ -34,6 +35,10 @@ func runRouteServer(t *testing.T) (*server.Server, *server.Options) { return RunServerWithConfig("./configs/cluster.conf") } +func runRouteServerOverrides(t *testing.T, cbo func(*server.Options), cbs func(*server.Server)) (*server.Server, *server.Options) { + return RunServerWithConfigOverrides("./configs/cluster.conf", cbo, cbs) +} + func TestRouterListeningSocket(t *testing.T) { s, opts := runRouteServer(t) defer s.Shutdown() @@ -93,7 +98,11 @@ func TestSendRouteInfoOnConnect(t *testing.T) { } func TestRouteToSelf(t *testing.T) { - s, opts := runRouteServer(t) + l := testhelper.NewDummyLogger(100) + s, opts := runRouteServerOverrides(t, nil, + func(s *server.Server) { + s.SetLogger(l, true, true) + }) defer s.Shutdown() rc := createRouteConn(t, opts.Cluster.Host, opts.Cluster.Port) @@ -123,6 +132,8 @@ func TestRouteToSelf(t *testing.T) { if _, err := rc.Read(buf); err == nil { t.Fatal("Expected route connection to be closed") } + // This should have been removed by removePassFromTrace(), but we also check debug logs here + l.CheckForProhibited(t, "route authorization password found", "top_secret") } func TestSendRouteSubAndUnsub(t *testing.T) { @@ -375,7 +386,11 @@ func TestRouteQueueSemantics(t *testing.T) { } func TestSolicitRouteReconnect(t *testing.T) { - s, opts := runRouteServer(t) + l := testhelper.NewDummyLogger(100) + s, opts := runRouteServerOverrides(t, nil, + func(s *server.Server) { + s.SetLogger(l, true, true) + }) defer s.Shutdown() rURL := opts.Routes[0] @@ -388,6 +403,9 @@ func TestSolicitRouteReconnect(t *testing.T) { // We expect to get called back.. route = acceptRouteConn(t, rURL.Host, 2*server.DEFAULT_ROUTE_CONNECT) route.Close() + + // Now we want to check for the debug logs when it tries to reconnect + l.CheckForProhibited(t, "route authorization password found", ":bar") } func TestMultipleRoutesSameId(t *testing.T) { diff --git a/test/test.go b/test/test.go index a872b04b..b27a6605 100644 --- a/test/test.go +++ b/test/test.go @@ -72,6 +72,10 @@ var ( // RunServer starts a new Go routine based server func RunServer(opts *server.Options) *server.Server { + return RunServerCallback(opts, nil) +} + +func RunServerCallback(opts *server.Options, callback func(*server.Server)) *server.Server { if opts == nil { opts = &DefaultTestOptions } @@ -89,6 +93,10 @@ func RunServer(opts *server.Options) *server.Server { s.ConfigureLogger() } + if callback != nil { + callback(s) + } + // Run server in Go routine. go s.Start() @@ -115,6 +123,17 @@ func RunServerWithConfig(configFile string) (srv *server.Server, opts *server.Op return } +// RunServerWithConfigOverrides starts a new Go routine based server with a configuration file, +// providing a callback to update the options configured. +func RunServerWithConfigOverrides(configFile string, optsCallback func(*server.Options), svrCallback func(*server.Server)) (srv *server.Server, opts *server.Options) { + opts = LoadConfig(configFile) + if optsCallback != nil { + optsCallback(opts) + } + srv = RunServerCallback(opts, svrCallback) + return +} + func stackFatalf(t tLogger, f string, args ...interface{}) { lines := make([]string, 0, 32) msg := fmt.Sprintf(f, args...)