diff --git a/server/client.go b/server/client.go index b99864f8..cfa0103b 100644 --- a/server/client.go +++ b/server/client.go @@ -202,8 +202,9 @@ type client struct { // Struct for PING initiation from the server. type pinfo struct { - tmr *time.Timer - out int + tmr *time.Timer + last time.Time + out int } // outbound holds pending data for a socket. @@ -1461,6 +1462,10 @@ func (c *client) processPing() { } c.sendPong() + // Record this to suppress us sending one if this + // is withing a given time interval for activity. + c.ping.last = time.Now() + // If not a CLIENT, we are done if c.kind != CLIENT { c.mu.Unlock() @@ -2635,10 +2640,15 @@ func (c *client) processPingTimer() { c.Debugf("%s Ping Timer", c.typeString()) - // If we have had activity within the PingInterval no - // need to send a ping. - if delta := time.Since(c.last); delta < c.srv.getOpts().PingInterval { - c.Debugf("Delaying PING due to activity %v ago", delta.Round(time.Second)) + // If we have had activity within the PingInterval then + // there is no need to send a ping. This can be client data + // or if we received a ping from the other side. + pingInterval := c.srv.getOpts().PingInterval + + if delta := time.Since(c.last); delta < pingInterval { + c.Debugf("Delaying PING due to client activity %v ago", delta.Round(time.Second)) + } else if delta := time.Since(c.ping.last); delta < pingInterval { + c.Debugf("Delaying PING due to remote ping %v ago", delta.Round(time.Second)) } else { // Check for violation if c.ping.out+1 > c.srv.getOpts().MaxPingsOut { @@ -2655,6 +2665,20 @@ func (c *client) processPingTimer() { c.setPingTimer() } +// Lock should be held +// We Randomize the first one by an offset up to 20%, e.g. 2m ~= max 24s. +// This is because the clients by default are usually setting same interval +// and we have alot of cross ping/pongs between clients and servers. +// We will now suppress the server ping/pong if we have received a client ping. +func (c *client) setFirstPingTimer(pingInterval time.Duration) { + if c.srv == nil { + return + } + addDelay := rand.Int63n(int64(pingInterval / 5)) + d := pingInterval + time.Duration(addDelay) + c.ping.tmr = time.AfterFunc(d, c.processPingTimer) +} + // Lock should be held func (c *client) setPingTimer() { if c.srv == nil { diff --git a/server/server.go b/server/server.go index 681fdccf..9811383f 100644 --- a/server/server.go +++ b/server/server.go @@ -1609,8 +1609,8 @@ func (s *Server) createClient(conn net.Conn) *client { // Do final client initialization - // Set the Ping timer - c.setPingTimer() + // Set the First Ping timer. + c.setFirstPingTimer(opts.PingInterval) // Spin up the read loop. s.startGoRoutine(func() { c.readLoop() }) diff --git a/test/ping_test.go b/test/ping_test.go index fa58c903..8807450c 100644 --- a/test/ping_test.go +++ b/test/ping_test.go @@ -187,3 +187,42 @@ func TestUnpromptedPong(t *testing.T) { t.Fatal("timeout: Expected to have connection closed") } } + +func TestPingSuppresion(t *testing.T) { + pingInterval := 100 * time.Millisecond + + opts := DefaultTestOptions + opts.Port = PING_TEST_PORT + opts.PingInterval = pingInterval + + s := RunServer(&opts) + defer s.Shutdown() + + c := createClientConn(t, "127.0.0.1", PING_TEST_PORT) + defer c.Close() + + connectTime := time.Now() + + send, expect := setupConn(t, c) + + expect(pingRe) + pingTime := time.Since(connectTime) + + // Should be > 100 but less then 120(ish) + if pingTime < pingInterval { + t.Fatalf("pingTime too low: %v", pingTime) + } + // +5 is just for fudging in case things are slow in the testing system. + if pingTime > pingInterval+(pingInterval/5)+5 { + t.Fatalf("pingTime too high: %v", pingTime) + } + + time.Sleep(pingInterval / 2) + + // Sending a PING should suppress. + send("PING\r\n") + expect(pongRe) + + // This will wait for + expectNothingTimeout(t, c, time.Now().Add(100*time.Millisecond)) +} diff --git a/test/test.go b/test/test.go index 0518a299..1455f5a8 100644 --- a/test/test.go +++ b/test/test.go @@ -350,8 +350,12 @@ func expectDisconnect(t *testing.T, c net.Conn) { } func expectNothing(t tLogger, c net.Conn) { + expectNothingTimeout(t, c, time.Now().Add(100*time.Millisecond)) +} + +func expectNothingTimeout(t tLogger, c net.Conn, dl time.Time) { expBuf := make([]byte, 32) - c.SetReadDeadline(time.Now().Add(100 * time.Millisecond)) + c.SetReadDeadline(dl) n, err := c.Read(expBuf) c.SetReadDeadline(time.Time{}) if err == nil && n > 0 {