diff --git a/server/client.go b/server/client.go index d774f042..02b6058a 100644 --- a/server/client.go +++ b/server/client.go @@ -1,4 +1,4 @@ -// Copyright 2012-2019 The NATS Authors +// Copyright 2012-2020 The NATS Authors // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. // You may obtain a copy of the License at diff --git a/server/errors.go b/server/errors.go index 9472f30d..cf5f2549 100644 --- a/server/errors.go +++ b/server/errors.go @@ -1,4 +1,4 @@ -// Copyright 2012-2019 The NATS Authors +// Copyright 2012-2020 The NATS Authors // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. // You may obtain a copy of the License at @@ -66,9 +66,9 @@ var ( // attempted to connect to the leaf node listen port. ErrClientConnectedToLeafNodePort = errors.New("attempted to connect to leaf node port") - // ErrLeafConnectedToClientPort represents an error condition when a client - // attempted to connect to the leaf node listen port. - ErrLeafConnectedToClientPort = errors.New("attempted to connect to client port") + // ErrConnectedToWrongPort represents an error condition when a connection is attempted + // to the wrong listen port (for instance a LeafNode to a client port, etc...) + ErrConnectedToWrongPort = errors.New("attempted to connect to wrong port") // ErrAccountExists is returned when an account is attempted to be registered // but already exists. diff --git a/server/leafnode.go b/server/leafnode.go index b8a5e6b3..ccef467d 100644 --- a/server/leafnode.go +++ b/server/leafnode.go @@ -1,4 +1,4 @@ -// Copyright 2019 The NATS Authors +// Copyright 2019-2020 The NATS Authors // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. // You may obtain a copy of the License at @@ -627,18 +627,17 @@ func (s *Server) createLeafNode(conn net.Conn, remote *leafNodeCfg) *client { c.nc.SetReadDeadline(time.Time{}) c.mu.Unlock() - // Error will be handled below, so ignore here. - err = c.parse([]byte(info)) - if err != nil { - c.Debugf("Error reading remote leafnode's INFO: %s", err) - c.closeConnection(ReadError) + // Handle only connection to wrong port here, others will be handled below. + if err := c.parse([]byte(info)); err == ErrConnectedToWrongPort { + c.Errorf(err.Error()) + c.closeConnection(WrongPort) return nil } c.mu.Lock() if !c.flags.isSet(infoReceived) { c.mu.Unlock() - c.Debugf("Did not get the remote leafnode's INFO, timed-out") + c.Errorf("Did not get the remote leafnode's INFO, timed-out") c.closeConnection(ReadError) return nil } @@ -798,15 +797,32 @@ func (c *client) processLeafnodeInfo(info *Info) error { return nil } - // Prevent connecting to client port. - if info.ClientConnectURLs != nil { - return ErrLeafConnectedToClientPort - } - // Mark that the INFO protocol has been received. // Note: For now, only the initial INFO has a nonce. We // will probably do auto key rotation at some point. if c.flags.setIfNotSet(infoReceived) { + // Prevent connecting to non leafnode port. Need to do this only for + // the first INFO, not for async INFO updates... + // + // Content of INFO sent by the server when accepting a tcp connection. + // ------------------------------------------------------------------- + // Listen Port Of | CID | ClientConnectURLs | LeafNodeURLs | Gateway | + // ------------------------------------------------------------------- + // CLIENT | X* | X** | | | + // ROUTE | | X** | X*** | | + // GATEWAY | | | | X | + // LEAFNODE | X | | X | | + // ------------------------------------------------------------------- + // * Not on older servers. + // ** Not if "no advertise" is enabled. + // *** Not if leafnode's "no advertise" is enabled. + // + // As seen from above, a solicited LeafNode connection should receive + // from the remote server an INFO with CID and LeafNodeURLs. Anything + // else should be considered an attempt to connect to a wrong port. + if c.leaf.remote != nil && (info.CID == 0 || info.LeafNodeURLs == nil) { + return ErrConnectedToWrongPort + } // Capture a nonce here. c.nonce = []byte(info.Nonce) if info.TLSRequired && c.leaf.remote != nil { diff --git a/server/leafnode_test.go b/server/leafnode_test.go index f7f7d484..55649bc1 100644 --- a/server/leafnode_test.go +++ b/server/leafnode_test.go @@ -1,4 +1,4 @@ -// Copyright 2019 The NATS Authors +// Copyright 2019-2020 The NATS Authors // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. // You may obtain a copy of the License at @@ -870,65 +870,82 @@ func TestLeafCloseTLSConnection(t *testing.T) { ch <- true } -type captureDebugErrorLogger struct { - DummyLogger - errCh chan string -} - -func (l *captureDebugErrorLogger) Debugf(format string, v ...interface{}) { - select { - case l.errCh <- fmt.Sprintf(format, v...): - default: - } -} - func TestLeafNodeRemoteWrongPort(t *testing.T) { - port := 8786 + for _, test1 := range []struct { + name string + clusterAdvertise bool + leafnodeAdvertise bool + }{ + {"advertise_on", false, false}, + {"cluster_no_advertise", true, false}, + {"leafnode_no_advertise", false, true}, + } { + t.Run(test1.name, func(t *testing.T) { + oa := DefaultOptions() + // Make sure we have all ports (client, route, gateway) and we will try + // to create a leafnode to connection to each and make sure we get the error. + oa.Cluster.NoAdvertise = test1.clusterAdvertise + oa.Cluster.Host = "127.0.0.1" + oa.Cluster.Port = -1 + oa.Gateway.Host = "127.0.0.1" + oa.Gateway.Port = -1 + oa.Gateway.Name = "A" + oa.LeafNode.Host = "127.0.0.1" + oa.LeafNode.Port = -1 + oa.LeafNode.NoAdvertise = test1.leafnodeAdvertise + oa.Accounts = []*Account{NewAccount("sys")} + oa.SystemAccount = "sys" + sa := RunServer(oa) + defer sa.Shutdown() - // Server with the wrong config against other server client's port. - leafURL, _ := url.Parse(fmt.Sprintf("nats://127.0.0.1:%d", port)) - oa := DefaultOptions() - oa.Port = -1 - oa.PingInterval = 15 * time.Millisecond - oa.LeafNode.Remotes = []*RemoteLeafOpts{{URLs: []*url.URL{leafURL}}} - sa := RunServer(oa) - defer sa.Shutdown() - l := &captureDebugErrorLogger{errCh: make(chan string, 10)} - sa.SetLogger(l, true, true) + ob := DefaultOptions() + ob.Cluster.NoAdvertise = test1.clusterAdvertise + ob.Cluster.Host = "127.0.0.1" + ob.Cluster.Port = -1 + ob.Routes = RoutesFromStr(fmt.Sprintf("nats://%s:%d", oa.Cluster.Host, oa.Cluster.Port)) + ob.Gateway.Host = "127.0.0.1" + ob.Gateway.Port = -1 + ob.Gateway.Name = "A" + ob.LeafNode.Host = "127.0.0.1" + ob.LeafNode.Port = -1 + ob.LeafNode.NoAdvertise = test1.leafnodeAdvertise + ob.Accounts = []*Account{NewAccount("sys")} + ob.SystemAccount = "sys" + sb := RunServer(ob) + defer sb.Shutdown() - // Make a cluster so that connect_urls is gossiped to clients. - ob := DefaultOptions() - ob.PingInterval = 15 * time.Millisecond - ob.Host = "127.0.0.1" - ob.Port = -1 - ob.Cluster = ClusterOpts{ - Host: "127.0.0.1", - Port: -1, - } - sb := RunServer(ob) - defer sb.Shutdown() + checkClusterFormed(t, sa, sb) - oc := DefaultOptions() - oc.PingInterval = 15 * time.Millisecond - oc.Host = "127.0.0.1" - oc.Port = port - oc.Cluster = ClusterOpts{ - Host: "127.0.0.1", - Port: -1, - } - routeURL, _ := url.Parse(fmt.Sprintf("nats://127.0.0.1:%d", ob.Cluster.Port)) - oc.Routes = []*url.URL{routeURL} - sc := RunServer(oc) - defer sc.Shutdown() + for _, test := range []struct { + name string + port int + }{ + {"client", oa.Port}, + {"cluster", oa.Cluster.Port}, + {"gateway", oa.Gateway.Port}, + } { + t.Run(test.name, func(t *testing.T) { + oc := DefaultOptions() + // Server with the wrong config against non leafnode port. + leafURL, _ := url.Parse(fmt.Sprintf("nats://127.0.0.1:%d", test.port)) + oc.LeafNode.Remotes = []*RemoteLeafOpts{{URLs: []*url.URL{leafURL}}} + oc.LeafNode.ReconnectInterval = 5 * time.Millisecond + sc := RunServer(oc) + defer sc.Shutdown() + l := &captureErrorLogger{errCh: make(chan string, 10)} + sc.SetLogger(l, true, true) - for { - select { - case e := <-l.errCh: - if strings.Contains(e, `attempted to connect to client port`) { - return + select { + case e := <-l.errCh: + if strings.Contains(e, ErrConnectedToWrongPort.Error()) { + return + } + case <-time.After(2 * time.Second): + t.Fatalf("Did not get any error about connecting to wrong port for %q - %q", + test1.name, test.name) + } + }) } - case <-time.After(2 * time.Second): - t.Fatalf("Did not get any error about connecting to client port") - } + }) } }