diff --git a/server/client_test.go b/server/client_test.go index c2568b0a..32c009bc 100644 --- a/server/client_test.go +++ b/server/client_test.go @@ -84,6 +84,16 @@ func setupClient() (*Server, *client, *bufio.Reader) { return s, c, cr } +func checkClientsCount(t *testing.T, s *Server, expected int) { + t.Helper() + checkFor(t, 2*time.Second, 15*time.Millisecond, func() error { + if nc := s.NumClients(); nc != expected { + return fmt.Errorf("The number of expected connections was %v, got %v", expected, nc) + } + return nil + }) +} + func TestClientCreateAndInfo(t *testing.T) { c, l := setUpClientWithResponse() @@ -575,22 +585,8 @@ func TestClientDoesNotAddSubscriptionsWhenConnectionClosed(t *testing.T) { func TestClientMapRemoval(t *testing.T) { s, c, _ := setupClient() c.nc.Close() - end := time.Now().Add(1 * time.Second) - for time.Now().Before(end) { - s.mu.Lock() - lsc := len(s.clients) - s.mu.Unlock() - if lsc > 0 { - time.Sleep(5 * time.Millisecond) - } - } - s.mu.Lock() - lsc := len(s.clients) - s.mu.Unlock() - if lsc > 0 { - t.Fatal("Client still in server map") - } + checkClientsCount(t, s, 0) } func TestAuthorizationTimeout(t *testing.T) { @@ -721,23 +717,15 @@ func TestTLSCloseClientConnection(t *testing.T) { t.Fatalf("Unexpected error reading PONG: %v", err) } - getClient := func() *client { - s.mu.Lock() - defer s.mu.Unlock() - for _, c := range s.clients { - return c - } - return nil - } - // Wait for client to be registered. - timeout := time.Now().Add(5 * time.Second) + // Check that client is registered. + checkClientsCount(t, s, 1) var cli *client - for time.Now().Before(timeout) { - cli = getClient() - if cli != nil { - break - } + s.mu.Lock() + for _, c := range s.clients { + cli = c + break } + s.mu.Unlock() if cli == nil { t.Fatal("Did not register client on time") } @@ -1009,15 +997,13 @@ func TestQueueAutoUnsubscribe(t *testing.T) { } nc.Flush() - wait := time.Now().Add(5 * time.Second) - for time.Now().Before(wait) { + checkFor(t, 5*time.Second, 10*time.Millisecond, func() error { nbar := atomic.LoadInt32(&rbar) nbaz := atomic.LoadInt32(&rbaz) if nbar == expected && nbaz == expected { - return + return nil } - time.Sleep(10 * time.Millisecond) - } - t.Fatalf("Did not receive all %d queue messages, received %d for 'bar' and %d for 'baz'\n", - expected, atomic.LoadInt32(&rbar), atomic.LoadInt32(&rbaz)) + return fmt.Errorf("Did not receive all %d queue messages, received %d for 'bar' and %d for 'baz'", + expected, atomic.LoadInt32(&rbar), atomic.LoadInt32(&rbaz)) + }) } diff --git a/server/closed_conns_test.go b/server/closed_conns_test.go index 6e625bd3..e3a86a64 100644 --- a/server/closed_conns_test.go +++ b/server/closed_conns_test.go @@ -23,28 +23,24 @@ import ( nats "github.com/nats-io/go-nats" ) -func closedConnsEqual(s *Server, num int, wait time.Duration) bool { - end := time.Now().Add(wait) - for time.Now().Before(end) { - if s.numClosedConns() == num { - break +func checkClosedConns(t *testing.T, s *Server, num int, wait time.Duration) { + t.Helper() + checkFor(t, wait, 5*time.Millisecond, func() error { + if nc := s.numClosedConns(); nc != num { + return fmt.Errorf("Closed conns expected to be %v, got %v", num, nc) } - time.Sleep(5 * time.Millisecond) - } - n := s.numClosedConns() - return n == num + return nil + }) } -func totalClosedConnsEqual(s *Server, num uint64, wait time.Duration) bool { - end := time.Now().Add(wait) - for time.Now().Before(end) { - if s.totalClosedConns() == num { - break +func checkTotalClosedConns(t *testing.T, s *Server, num uint64, wait time.Duration) { + t.Helper() + checkFor(t, wait, 5*time.Millisecond, func() error { + if nc := s.totalClosedConns(); nc != num { + return fmt.Errorf("Total closed conns expected to be %v, got %v", num, nc) } - time.Sleep(5 * time.Millisecond) - } - n := s.totalClosedConns() - return n == num + return nil + }) } func TestClosedConnsAccounting(t *testing.T) { @@ -62,9 +58,7 @@ func TestClosedConnsAccounting(t *testing.T) { } nc.Close() - if !closedConnsEqual(s, 1, wait) { - t.Fatalf("Closed conns expected to be 1, got %d\n", s.numClosedConns()) - } + checkClosedConns(t, s, 1, wait) conns := s.closedClients() if lc := len(conns); lc != 1 { @@ -81,22 +75,11 @@ func TestClosedConnsAccounting(t *testing.T) { t.Fatalf("Error on connect: %v", err) } nc.Close() - // FIXME: For now just sleep a bit to ensure that closed connections - // are added in the expected order for tests down below where we - // check for cid. - time.Sleep(15 * time.Millisecond) + checkTotalClosedConns(t, s, uint64(i+2), wait) } - if !closedConnsEqual(s, opts.MaxClosedClients, wait) { - t.Fatalf("Closed conns expected to be %d, got %d\n", - opts.MaxClosedClients, - s.numClosedConns()) - } - - if !totalClosedConnsEqual(s, 22, wait) { - t.Fatalf("Closed conns expected to be 22, got %d\n", - s.numClosedConns()) - } + checkClosedConns(t, s, opts.MaxClosedClients, wait) + checkTotalClosedConns(t, s, 22, wait) conns = s.closedClients() if lc := len(conns); lc != opts.MaxClosedClients { @@ -135,10 +118,7 @@ func TestClosedConnsSubsAccounting(t *testing.T) { nc.Flush() nc.Close() - if !closedConnsEqual(s, 1, 20*time.Millisecond) { - t.Fatalf("Closed conns expected to be 1, got %d\n", - s.numClosedConns()) - } + checkClosedConns(t, s, 1, 20*time.Millisecond) conns := s.closedClients() if lc := len(conns); lc != 1 { t.Fatalf("len(conns) expected to be 1, got %d\n", lc) @@ -170,9 +150,7 @@ func TestClosedAuthorizationTimeout(t *testing.T) { } defer conn.Close() - if !closedConnsEqual(s, 1, 2*time.Second) { - t.Fatalf("Closed conns expected to be 1, got %d\n", s.numClosedConns()) - } + checkClosedConns(t, s, 1, 2*time.Second) conns := s.closedClients() if lc := len(conns); lc != 1 { t.Fatalf("len(conns) expected to be %d, got %d\n", 1, lc) @@ -195,9 +173,7 @@ func TestClosedAuthorizationViolation(t *testing.T) { t.Fatal("Expected failure for connection") } - if !closedConnsEqual(s, 1, 2*time.Second) { - t.Fatalf("Closed conns expected to be 1, got %d\n", s.numClosedConns()) - } + checkClosedConns(t, s, 1, 2*time.Second) conns := s.closedClients() if lc := len(conns); lc != 1 { t.Fatalf("len(conns) expected to be %d, got %d\n", 1, lc) @@ -228,9 +204,7 @@ func TestClosedUPAuthorizationViolation(t *testing.T) { t.Fatal("Expected failure for connection") } - if !closedConnsEqual(s, 2, 2*time.Second) { - t.Fatalf("Closed conns expected to be 2, got %d\n", s.numClosedConns()) - } + checkClosedConns(t, s, 2, 2*time.Second) conns := s.closedClients() if lc := len(conns); lc != 2 { t.Fatalf("len(conns) expected to be %d, got %d\n", 2, lc) @@ -259,9 +233,7 @@ func TestClosedMaxPayload(t *testing.T) { pub := fmt.Sprintf("PUB foo.bar 1024\r\n") conn.Write([]byte(pub)) - if !closedConnsEqual(s, 1, 2*time.Second) { - t.Fatalf("Closed conns expected to be 1, got %d\n", s.numClosedConns()) - } + checkClosedConns(t, s, 1, 2*time.Second) conns := s.closedClients() if lc := len(conns); lc != 1 { t.Fatalf("len(conns) expected to be %d, got %d\n", 1, lc) @@ -308,9 +280,7 @@ func TestClosedSlowConsumerWriteDeadline(t *testing.T) { } // At this point server should have closed connection c. - if !closedConnsEqual(s, 1, 2*time.Second) { - t.Fatalf("Closed conns expected to be 1, got %d\n", s.numClosedConns()) - } + checkClosedConns(t, s, 1, 2*time.Second) conns := s.closedClients() if lc := len(conns); lc != 1 { t.Fatalf("len(conns) expected to be %d, got %d\n", 1, lc) @@ -357,9 +327,7 @@ func TestClosedSlowConsumerPendingBytes(t *testing.T) { } // At this point server should have closed connection c. - if !closedConnsEqual(s, 1, 2*time.Second) { - t.Fatalf("Closed conns expected to be 1, got %d\n", s.numClosedConns()) - } + checkClosedConns(t, s, 1, 2*time.Second) conns := s.closedClients() if lc := len(conns); lc != 1 { t.Fatalf("len(conns) expected to be %d, got %d\n", 1, lc) @@ -384,9 +352,7 @@ func TestClosedTLSHandshake(t *testing.T) { t.Fatal("Expected failure for connection") } - if !closedConnsEqual(s, 1, 2*time.Second) { - t.Fatalf("Closed conns expected to be 1, got %d\n", s.numClosedConns()) - } + checkClosedConns(t, s, 1, 2*time.Second) conns := s.closedClients() if lc := len(conns); lc != 1 { t.Fatalf("len(conns) expected to be %d, got %d\n", 1, lc) diff --git a/server/monitor_test.go b/server/monitor_test.go index 6884f6e3..47e7ef34 100644 --- a/server/monitor_test.go +++ b/server/monitor_test.go @@ -222,30 +222,6 @@ func pollConz(t *testing.T, s *Server, mode int, url string, opts *ConnzOptions) return c } -func waitForClientConnCount(t *testing.T, s *Server, count int) { - timeout := time.Now().Add(2 * time.Second) - for time.Now().Before(timeout) { - if s.NumClients() == count { - return - } - time.Sleep(15 * time.Millisecond) - } - stackFatalf(t, "The number of expected connections was %v, got %v", count, s.NumClients()) -} - -func waitForClosedClientConnCount(t *testing.T, s *Server, count int) { - timeout := time.Now().Add(2 * time.Second) - for time.Now().Before(timeout) { - if s.numClosedConns() == count { - return - } - time.Sleep(15 * time.Millisecond) - } - stackFatalf(t, - "The number of expected closed connections was %v, got %v", - count, s.numClosedConns()) -} - func TestConnz(t *testing.T) { s := runMonitorServer() defer s.Shutdown() @@ -344,7 +320,7 @@ func TestConnz(t *testing.T) { for mode := 0; mode < 2; mode++ { testConnz(mode) - waitForClientConnCount(t, s, 0) + checkClientsCount(t, s, 0) } // Test JSONP @@ -503,7 +479,7 @@ func TestConnzRTT(t *testing.T) { for mode := 0; mode < 2; mode++ { testRTT(mode) - waitForClientConnCount(t, s, 0) + checkClientsCount(t, s, 0) } } @@ -975,7 +951,7 @@ func TestConnzWithRoutes(t *testing.T) { sc := RunServer(opts) defer sc.Shutdown() - time.Sleep(time.Second) + checkClusterFormed(t, s, sc) url := fmt.Sprintf("http://localhost:%d/", s.MonitorAddr().Port) for mode := 0; mode < 2; mode++ { @@ -994,6 +970,8 @@ func TestConnzWithRoutes(t *testing.T) { defer nc.Close() nc.Subscribe("hello.bar", func(m *nats.Msg) {}) + nc.Flush() + checkExpectedSubs(t, 1, s, sc) // Now check routez urls := []string{"routez", "routez?subs=1"} @@ -1236,7 +1214,7 @@ func TestConnzClosedConnsRace(t *testing.T) { urlWithoutSubs := fmt.Sprintf("http://localhost:%d/connz?state=closed", s.MonitorAddr().Port) urlWithSubs := urlWithoutSubs + "&subs=true" - waitForClosedClientConnCount(t, s, 100) + checkClosedConns(t, s, 100, 2*time.Second) wg := &sync.WaitGroup{} @@ -1437,7 +1415,7 @@ func TestConnzTLSInHandshake(t *testing.T) { defer c.Close() // Wait for the connection to be registered - waitForClientConnCount(t, s, 1) + checkClientsCount(t, s, 1) start := time.Now() endpoint := fmt.Sprintf("http://%s:%d/connz", opts.HTTPHost, s.MonitorAddr().Port) diff --git a/server/reload_test.go b/server/reload_test.go index 7ad1b5d5..bfabed78 100644 --- a/server/reload_test.go +++ b/server/reload_test.go @@ -1441,18 +1441,7 @@ func TestConfigReloadClusterRemoveSolicitedRoutes(t *testing.T) { srvb.Shutdown() // Wait til route is dropped. - numRoutes := 0 - deadline := time.Now().Add(2 * time.Second) - for time.Now().Before(deadline) { - if numRoutes = srva.NumRoutes(); numRoutes == 0 { - break - } else { - time.Sleep(100 * time.Millisecond) - } - } - if numRoutes != 0 { - t.Fatalf("Expected 0 routes for server A, got %d", numRoutes) - } + checkNumRoutes(t, srva, 0) // Now change config for server A to not solicit a route to server B. createSymlink(t, srvaConfig, "./configs/reload/srv_a_4.conf") @@ -1466,7 +1455,8 @@ func TestConfigReloadClusterRemoveSolicitedRoutes(t *testing.T) { defer srvb.Shutdown() // We should not have a cluster formed here. - deadline = time.Now().Add(2 * DEFAULT_ROUTE_RECONNECT) + numRoutes := 0 + deadline := time.Now().Add(2 * DEFAULT_ROUTE_RECONNECT) for time.Now().Before(deadline) { if numRoutes = srva.NumRoutes(); numRoutes != 0 { break diff --git a/server/routes_test.go b/server/routes_test.go index 4f430d8a..2b40438a 100644 --- a/server/routes_test.go +++ b/server/routes_test.go @@ -14,7 +14,6 @@ package server import ( - "errors" "fmt" "net" "net/url" @@ -29,6 +28,16 @@ import ( "github.com/nats-io/go-nats" ) +func checkNumRoutes(t *testing.T, s *Server, expected int) { + t.Helper() + checkFor(t, 5*time.Second, 15*time.Millisecond, func() error { + if nr := s.NumRoutes(); nr != expected { + return fmt.Errorf("Expected %v routes, got %v", expected, nr) + } + return nil + }) +} + func TestRouteConfig(t *testing.T) { opts, err := ProcessConfigFile("./configs/cluster.conf") if err != nil { @@ -134,21 +143,15 @@ func TestClusterAdvertiseErrorOnStartup(t *testing.T) { s.Start() wg.Done() }() - msg := "" - ok := false - timeout := time.Now().Add(2 * time.Second) - for time.Now().Before(timeout) { + checkFor(t, 2*time.Second, 15*time.Millisecond, func() error { dl.Lock() - msg = dl.msg + msg := dl.msg dl.Unlock() if strings.Contains(msg, "Cluster.Advertise") { - ok = true - break + return nil } - } - if !ok { - t.Fatalf("Did not get expected error, got %v", msg) - } + return fmt.Errorf("Did not get expected error, got %v", msg) + }) s.Shutdown() wg.Wait() } @@ -173,21 +176,15 @@ func TestClientAdvertise(t *testing.T) { t.Fatalf("Error on connect: %v", err) } defer nc.Close() - timeout := time.Now().Add(time.Second) - good := false - for time.Now().Before(timeout) { + checkFor(t, time.Second, 15*time.Millisecond, func() error { ds := nc.DiscoveredServers() if len(ds) == 1 { if ds[0] == "nats://me:1" { - good = true - break + return nil } } - time.Sleep(15 * time.Millisecond) - } - if !good { - t.Fatalf("Did not get expected discovered servers: %v", nc.DiscoveredServers()) - } + return fmt.Errorf("Did not get expected discovered servers: %v", nc.DiscoveredServers()) + }) } func TestServerRoutesWithClients(t *testing.T) { @@ -284,27 +281,16 @@ func TestServerRoutesWithAuthAndBCrypt(t *testing.T) { // Helper function to check that a cluster is formed func checkClusterFormed(t *testing.T, servers ...*Server) { - // Wait for the cluster to form - var err string + t.Helper() expectedNumRoutes := len(servers) - 1 - maxTime := time.Now().Add(10 * time.Second) - for time.Now().Before(maxTime) { - err = "" + checkFor(t, 10*time.Second, 100*time.Millisecond, func() error { for _, s := range servers { if numRoutes := s.NumRoutes(); numRoutes != expectedNumRoutes { - err = fmt.Sprintf("Expected %d routes for server %q, got %d", expectedNumRoutes, s.ID(), numRoutes) - break + return fmt.Errorf("Expected %d routes for server %q, got %d", expectedNumRoutes, s.ID(), numRoutes) } } - if err != "" { - time.Sleep(100 * time.Millisecond) - } else { - break - } - } - if err != "" { - stackFatalf(t, "%s", err) - } + return nil + }) } // Helper function to generate next opts to make sure no port conflicts etc. @@ -485,27 +471,16 @@ func TestChainedSolicitWorks(t *testing.T) { // Helper function to check that a server (or list of servers) have the // expected number of subscriptions. -func checkExpectedSubs(expected int, servers ...*Server) error { - var err string - maxTime := time.Now().Add(10 * time.Second) - for time.Now().Before(maxTime) { - err = "" +func checkExpectedSubs(t *testing.T, expected int, servers ...*Server) { + t.Helper() + checkFor(t, 10*time.Second, 10*time.Millisecond, func() error { for _, s := range servers { if numSubs := int(s.NumSubscriptions()); numSubs != expected { - err = fmt.Sprintf("Expected %d subscriptions for server %q, got %d", expected, s.ID(), numSubs) - break + return fmt.Errorf("Expected %d subscriptions for server %q, got %d", expected, s.ID(), numSubs) } } - if err != "" { - time.Sleep(10 * time.Millisecond) - } else { - break - } - } - if err != "" { - return errors.New(err) - } - return nil + return nil + }) } func TestTLSChainedSolicitWorks(t *testing.T) { @@ -546,9 +521,7 @@ func TestTLSChainedSolicitWorks(t *testing.T) { defer srvB.Shutdown() checkClusterFormed(t, srvSeed, srvA, srvB) - if err := checkExpectedSubs(1, srvA, srvB); err != nil { - t.Fatalf("%v", err) - } + checkExpectedSubs(t, 1, srvA, srvB) urlB := fmt.Sprintf("nats://%s:%d/", optsB.Host, srvB.Addr().(*net.TCPAddr).Port) @@ -583,17 +556,7 @@ func TestRouteTLSHandshakeError(t *testing.T) { time.Sleep(500 * time.Millisecond) - maxTime := time.Now().Add(1 * time.Second) - for time.Now().Before(maxTime) { - if srv.NumRoutes() > 0 { - time.Sleep(100 * time.Millisecond) - continue - } - break - } - if srv.NumRoutes() > 0 { - t.Fatal("Route should have failed") - } + checkNumRoutes(t, srv, 0) } func TestBlockedShutdownOnRouteAcceptLoopFailure(t *testing.T) { @@ -826,12 +789,8 @@ func TestServerPoolUpdatedWhenRouteGoesAway(t *testing.T) { // Don't use discovered here, but Servers to have the full list. // Also, there may be cases where the mesh is not formed yet, // so try again on failure. - var ( - ds []string - timeout = time.Now().Add(5 * time.Second) - ) - for time.Now().Before(timeout) { - ds = nc.Servers() + checkFor(t, 5*time.Second, 50*time.Millisecond, func() error { + ds := nc.Servers() if len(ds) == len(expected) { m := make(map[string]struct{}, len(ds)) for _, url := range ds { @@ -845,12 +804,11 @@ func TestServerPoolUpdatedWhenRouteGoesAway(t *testing.T) { } } if ok { - return + return nil } } - time.Sleep(50 * time.Millisecond) - } - stackFatalf(t, "Expected %v, got %v", expected, ds) + return fmt.Errorf("Expected %v, got %v", expected, ds) + }) } // Verify that we now know about s2 checkPool([]string{s1Url, s2Url}) @@ -972,8 +930,7 @@ func TestRoutedQueueAutoUnsubscribe(t *testing.T) { c.Flush() } - wait := time.Now().Add(10 * time.Second) - for time.Now().Before(wait) { + checkFor(t, 10*time.Second, 100*time.Millisecond, func() error { nbar := atomic.LoadInt32(&rbar) nbaz := atomic.LoadInt32(&rbaz) if nbar == expected && nbaz == expected { @@ -986,15 +943,14 @@ func TestRoutedQueueAutoUnsubscribe(t *testing.T) { nrqsb := len(srvB.rqsubs) srvB.rqsMu.RUnlock() if nrqsa != 0 || nrqsb != 0 { - t.Fatalf("Expected rqs mappings to have cleared, but got A:%d, B:%d\n", + return fmt.Errorf("Expected rqs mappings to have cleared, but got A:%d, B:%d", nrqsa, nrqsb) } - return + return nil } - time.Sleep(100 * time.Millisecond) - } - t.Fatalf("Did not receive all %d queue messages, received %d for 'bar' and %d for 'baz'\n", - expected, atomic.LoadInt32(&rbar), atomic.LoadInt32(&rbaz)) + return fmt.Errorf("Did not receive all %d queue messages, received %d for 'bar' and %d for 'baz'", + expected, atomic.LoadInt32(&rbar), atomic.LoadInt32(&rbaz)) + }) } func TestRouteFailedConnRemovedFromTmpMap(t *testing.T) { @@ -1012,8 +968,16 @@ func TestRouteFailedConnRemovedFromTmpMap(t *testing.T) { // Start this way to increase chance of having the two connect // to each other at the same time. This will cause one of the // route to be dropped. - go srvA.Start() - go srvB.Start() + wg := &sync.WaitGroup{} + wg.Add(2) + go func() { + srvA.Start() + wg.Done() + }() + go func() { + srvB.Start() + wg.Done() + }() checkClusterFormed(t, srvA, srvB) @@ -1028,4 +992,8 @@ func TestRouteFailedConnRemovedFromTmpMap(t *testing.T) { } checkMap(srvA) checkMap(srvB) + + srvB.Shutdown() + srvA.Shutdown() + wg.Wait() } diff --git a/server/server_test.go b/server/server_test.go index 0947b100..bd5c59a4 100644 --- a/server/server_test.go +++ b/server/server_test.go @@ -25,6 +25,22 @@ import ( "github.com/nats-io/go-nats" ) +func checkFor(t *testing.T, totalWait, sleepDur time.Duration, f func() error) { + t.Helper() + timeout := time.Now().Add(totalWait) + var err error + for time.Now().Before(timeout) { + err = f() + if err == nil { + return + } + time.Sleep(sleepDur) + } + if err != nil { + t.Fatal(err.Error()) + } +} + func DefaultOptions() *Options { return &Options{ Host: "localhost", @@ -593,18 +609,11 @@ func TestCustomRouterAuthentication(t *testing.T) { opts2.Routes = RoutesFromStr(fmt.Sprintf("nats://invalid@127.0.0.1:%d", clusterPort)) s2 := RunServer(opts2) defer s2.Shutdown() - timeout := time.Now().Add(2 * time.Second) - nr := 0 - for time.Now().Before(timeout) { - nr = s2.NumRoutes() - if nr == 0 { - break - } - time.Sleep(15 * time.Millisecond) - } - if nr != 0 { - t.Fatalf("Expected no route, got %v", nr) - } + + // s2 will attempt to connect to s, which should reject. + // Keep in mind that s2 will try again... + time.Sleep(50 * time.Millisecond) + checkNumRoutes(t, s2, 0) opts3 := DefaultOptions() opts3.Cluster.Host = "127.0.0.1" @@ -612,9 +621,7 @@ func TestCustomRouterAuthentication(t *testing.T) { s3 := RunServer(opts3) defer s3.Shutdown() checkClusterFormed(t, s, s3) - if nr := s3.NumRoutes(); nr != 1 { - t.Fatalf("Expected 1 route, got %v", nr) - } + checkNumRoutes(t, s3, 1) } func TestMonitoringNoTimeout(t *testing.T) { diff --git a/test/cluster_test.go b/test/cluster_test.go index 718a04c8..c7d183b8 100644 --- a/test/cluster_test.go +++ b/test/cluster_test.go @@ -25,27 +25,26 @@ import ( // Helper function to check that a cluster is formed func checkClusterFormed(t *testing.T, servers ...*server.Server) { - // Wait for the cluster to form - var err string + t.Helper() expectedNumRoutes := len(servers) - 1 - maxTime := time.Now().Add(5 * time.Second) - for time.Now().Before(maxTime) { - err = "" + checkFor(t, 10*time.Second, 100*time.Millisecond, func() error { for _, s := range servers { if numRoutes := s.NumRoutes(); numRoutes != expectedNumRoutes { - err = fmt.Sprintf("Expected %d routes for server %q, got %d", expectedNumRoutes, s.ID(), numRoutes) - break + return fmt.Errorf("Expected %d routes for server %q, got %d", expectedNumRoutes, s.ID(), numRoutes) } } - if err != "" { - time.Sleep(100 * time.Millisecond) - } else { - break + return nil + }) +} + +func checkNumRoutes(t *testing.T, s *server.Server, expected int) { + t.Helper() + checkFor(t, 5*time.Second, 15*time.Millisecond, func() error { + if nr := s.NumRoutes(); nr != expected { + return fmt.Errorf("Expected %v routes, got %v", expected, nr) } - } - if err != "" { - t.Fatalf("%s", err) - } + return nil + }) } // Helper function to check that a server (or list of servers) have the diff --git a/test/route_discovery_test.go b/test/route_discovery_test.go index 1a49323a..5455bbef 100644 --- a/test/route_discovery_test.go +++ b/test/route_discovery_test.go @@ -297,37 +297,16 @@ func TestStressSeedSolicitWorks(t *testing.T) { serversInfo := []serverInfo{{s1, opts}, {s2, s2Opts}, {s3, s3Opts}, {s4, s4Opts}} - var err error - maxTime := time.Now().Add(5 * time.Second) - for time.Now().Before(maxTime) { + checkFor(t, 5*time.Second, 100*time.Millisecond, func() error { for j := 0; j < len(serversInfo); j++ { - err = checkConnected(t, serversInfo, j, true) - // If error, start this for loop from beginning - if err != nil { - // Sleep a bit before the next attempt - time.Sleep(100 * time.Millisecond) - break + if err := checkConnected(t, serversInfo, j, true); err != nil { + return err } } - // All servers checked ok, we are done, otherwise, try again - // until time is up - if err == nil { - break - } - } - // Report error - if err != nil { - t.Fatalf("Error: %v", err) - } + return nil + }) }() - maxTime := time.Now().Add(2 * time.Second) - for time.Now().Before(maxTime) { - if s1.NumRoutes() > 0 { - time.Sleep(10 * time.Millisecond) - } else { - break - } - } + checkNumRoutes(t, s1, 0) } } @@ -436,37 +415,16 @@ func TestStressChainedSolicitWorks(t *testing.T) { serversInfo := []serverInfo{{s1, opts}, {s2, s2Opts}, {s3, s3Opts}, {s4, s4Opts}} - var err error - maxTime := time.Now().Add(5 * time.Second) - for time.Now().Before(maxTime) { + checkFor(t, 5*time.Second, 100*time.Millisecond, func() error { for j := 0; j < len(serversInfo); j++ { - err = checkConnected(t, serversInfo, j, false) - // If error, start this for loop from beginning - if err != nil { - // Sleep a bit before the next attempt - time.Sleep(100 * time.Millisecond) - break + if err := checkConnected(t, serversInfo, j, false); err != nil { + return err } } - // All servers checked ok, we are done, otherwise, try again - // until time is up - if err == nil { - break - } - } - // Report error - if err != nil { - t.Fatalf("Error: %v", err) - } + return nil + }) }() - maxTime := time.Now().Add(2 * time.Second) - for time.Now().Before(maxTime) { - if s1.NumRoutes() > 0 { - time.Sleep(10 * time.Millisecond) - } else { - break - } - } + checkNumRoutes(t, s1, 0) } } diff --git a/test/test_test.go b/test/test_test.go index 48af00b3..ac041747 100644 --- a/test/test_test.go +++ b/test/test_test.go @@ -18,8 +18,25 @@ import ( "strings" "sync" "testing" + "time" ) +func checkFor(t *testing.T, totalWait, sleepDur time.Duration, f func() error) { + t.Helper() + timeout := time.Now().Add(totalWait) + var err error + for time.Now().Before(timeout) { + err = f() + if err == nil { + return + } + time.Sleep(sleepDur) + } + if err != nil { + t.Fatal(err.Error()) + } +} + type dummyLogger struct { sync.Mutex msg string