diff --git a/server/client.go b/server/client.go index b3f5f1cf..5a64a79c 100644 --- a/server/client.go +++ b/server/client.go @@ -446,6 +446,7 @@ func (c *client) processConnect(arg []byte) error { // Capture these under lock proto := c.opts.Protocol verbose := c.opts.Verbose + lang := c.opts.Lang c.mu.Unlock() if srv != nil { @@ -468,7 +469,15 @@ func (c *client) processConnect(arg []byte) error { // Check client protocol request if it exists. if typ == CLIENT && (proto < ClientProtoZero || proto > ClientProtoInfo) { + c.sendErr(ErrBadClientProtocol.Error()) + c.closeConnection() return ErrBadClientProtocol + } else if typ == ROUTER && lang != "" { + // Way to detect clients that incorrectly connect to the route listen + // port. Client provide Lang in the CONNECT protocol while ROUTEs don't. + c.sendErr(ErrClientConnectedToRoutePort.Error()) + c.closeConnection() + return ErrClientConnectedToRoutePort } // Grab connection name of remote route. diff --git a/server/client_test.go b/server/client_test.go index 630a4af5..46a48064 100644 --- a/server/client_test.go +++ b/server/client_test.go @@ -177,7 +177,7 @@ func TestClientConnect(t *testing.T) { } func TestClientConnectProto(t *testing.T) { - _, c, _ := setupClient() + _, c, r := setupClient() // Basic Connect setting flags, proto should be zero (original proto) connectOp := []byte("CONNECT {\"verbose\":true,\"pedantic\":true,\"ssl_required\":false}\r\n") @@ -210,6 +210,19 @@ func TestClientConnectProto(t *testing.T) { // Illegal Option connectOp = []byte("CONNECT {\"protocol\":22}\r\n") + wg := sync.WaitGroup{} + wg.Add(1) + // The client here is using a pipe, we need to be dequeuing + // data otherwise the server would be blocked trying to send + // the error back to it. + go func() { + defer wg.Done() + for { + if _, _, err := r.ReadLine(); err != nil { + return + } + } + }() err = c.parse(connectOp) if err == nil { t.Fatalf("Expected to receive an error\n") @@ -217,6 +230,7 @@ func TestClientConnectProto(t *testing.T) { if err != ErrBadClientProtocol { t.Fatalf("Expected err of %q, got %q\n", ErrBadClientProtocol, err) } + wg.Wait() } func TestClientPing(t *testing.T) { diff --git a/server/errors.go b/server/errors.go index fee22923..cc22bd1e 100644 --- a/server/errors.go +++ b/server/errors.go @@ -29,4 +29,8 @@ var ( // ErrTooManyConnections signals a client that the maximum number of connections supported by the // server has been reached. ErrTooManyConnections = errors.New("Maximum Connections Exceeded") + + // ErrClientConnectedToRoutePort represents an error condition when a client + // attempted to connect to the route listen port. + ErrClientConnectedToRoutePort = errors.New("Attempted To Connect To Route Port") ) diff --git a/server/routes_test.go b/server/routes_test.go index e0f5f90d..9cd0e30e 100644 --- a/server/routes_test.go +++ b/server/routes_test.go @@ -477,3 +477,46 @@ func TestRouteUseIPv6(t *testing.T) { t.Fatal("Server failed to start route accept loop") } } + +func TestClientConnectToRoutePort(t *testing.T) { + opts := DefaultOptions + opts.Cluster.NoAdvertise = true + s := RunServer(&opts) + defer s.Shutdown() + + url := fmt.Sprintf("nats://%s:%d", opts.Cluster.Host, opts.Cluster.Port) + clientURL := fmt.Sprintf("nats://%s:%d", opts.Host, opts.Port) + total := 100 + for i := 0; i < total; i++ { + nc, err := nats.Connect(url) + if err == nil { + // It is possible that the client reconnects because + // it gets from the initial connect to the route the + // connectedUrl array and may be able to reconnects + // to the client URL. + // If connected, it should be to the client URL + if nc.ConnectedUrl() != clientURL { + t.Fatalf("Expected client to be connected to %v, got %v", clientURL, nc.ConnectedUrl()) + } + nc.Close() + } + // If error, it could be ErrClientConnectingToRoutePort or + // other (EOF, etc)... so not checking for specific one. + } + + // When disabling randomization, the client URL is added to the server + // pool and so should be tried after the connection is closed trying + // to connect to the route port. Connect must always succeed and + // must be connected to client URL. + for i := 0; i < total; i++ { + nc, err := nats.Connect(url, nats.DontRandomize()) + if err == nil { + if nc.ConnectedUrl() != clientURL { + t.Fatalf("Expected client to be connected to %v, got %v", clientURL, nc.ConnectedUrl()) + } + nc.Close() + continue + } + t.Fatalf("Error on connect: %v", err) + } +}