diff --git a/server/client.go b/server/client.go index b90c6eff..e1f28b07 100644 --- a/server/client.go +++ b/server/client.go @@ -3961,13 +3961,11 @@ func (c *client) teardownConn() { } if srv != nil { - // This is a route that disconnected, but we are not in lame duck mode... - if (len(connectURLs) > 0 || len(wsConnectURLs) > 0) && !srv.isLameDuckMode() { - // Unless disabled, possibly update the server's INFO protocol - // and send to clients that know how to handle async INFOs. - if !srv.getOpts().Cluster.NoAdvertise { - srv.removeConnectURLsAndSendINFOToClients(connectURLs, wsConnectURLs) - } + // If this is a route that disconnected, possibly send an INFO with + // the updated list of connect URLs to clients that know how to + // handle async INFOs. + if (len(connectURLs) > 0 || len(wsConnectURLs) > 0) && !srv.getOpts().Cluster.NoAdvertise { + srv.removeConnectURLsAndSendINFOToClients(connectURLs, wsConnectURLs) } // Unregister diff --git a/server/client_test.go b/server/client_test.go index 6fd4cb34..aaff319d 100644 --- a/server/client_test.go +++ b/server/client_test.go @@ -36,14 +36,16 @@ import ( ) type serverInfo struct { - ID string `json:"server_id"` - Host string `json:"host"` - Port uint `json:"port"` - Version string `json:"version"` - AuthRequired bool `json:"auth_required"` - TLSRequired bool `json:"tls_required"` - MaxPayload int64 `json:"max_payload"` - Headers bool `json:"headers"` + ID string `json:"server_id"` + Host string `json:"host"` + Port uint `json:"port"` + Version string `json:"version"` + AuthRequired bool `json:"auth_required"` + TLSRequired bool `json:"tls_required"` + MaxPayload int64 `json:"max_payload"` + Headers bool `json:"headers"` + ConnectURLs []string `json:"connect_urls,omitempty"` + LameDuckMode bool `json:"ldm,omitempty"` } type testAsyncClient struct { diff --git a/server/route.go b/server/route.go index 330ce80e..0a69a74b 100644 --- a/server/route.go +++ b/server/route.go @@ -400,51 +400,79 @@ func (c *client) processRouteInfo(info *Info) { } s := c.srv - remoteID := c.route.remoteID - - // Check if this is an INFO for gateways... - if info.Gateway != "" { - c.mu.Unlock() - // If this server has no gateway configured, report error and return. - if !s.gateway.enabled { - // FIXME: Should this be a Fatalf()? - s.Errorf("Received information about gateway %q from %s, but gateway is not configured", - info.Gateway, remoteID) - return - } - s.processGatewayInfoFromRoute(info, remoteID, c) - return - } - - // We receive an INFO from a server that informs us about another server, - // so the info.ID in the INFO protocol does not match the ID of this route. - if remoteID != "" && remoteID != info.ID { - c.mu.Unlock() - - // Process this implicit route. We will check that it is not an explicit - // route and/or that it has not been connected already. - s.processImplicitRoute(info) - return - } - - // Need to set this for the detection of the route to self to work - // in closeConnection(). - c.route.remoteID = info.ID - - // Get the route's proto version - c.opts.Protocol = info.Proto // Detect route to self. - if c.route.remoteID == s.info.ID { + if info.ID == s.info.ID { + // Need to set this so that the close does the right thing + c.route.remoteID = info.ID c.mu.Unlock() c.closeConnection(DuplicateRoute) return } + // If this is an async INFO from an existing route... + if c.flags.isSet(infoReceived) { + remoteID := c.route.remoteID + + // Check if this is an INFO for gateways... + if info.Gateway != "" { + c.mu.Unlock() + // If this server has no gateway configured, report error and return. + if !s.gateway.enabled { + // FIXME: Should this be a Fatalf()? + s.Errorf("Received information about gateway %q from %s, but gateway is not configured", + info.Gateway, remoteID) + return + } + s.processGatewayInfoFromRoute(info, remoteID, c) + return + } + + // We receive an INFO from a server that informs us about another server, + // so the info.ID in the INFO protocol does not match the ID of this route. + if remoteID != "" && remoteID != info.ID { + c.mu.Unlock() + + // Process this implicit route. We will check that it is not an explicit + // route and/or that it has not been connected already. + s.processImplicitRoute(info) + return + } + + var connectURLs []string + var wsConnectURLs []string + + // If we are notified that the remote is going into LDM mode, capture route's connectURLs. + if info.LameDuckMode { + connectURLs = c.route.connectURLs + wsConnectURLs = c.route.wsConnURLs + } else { + // If this is an update due to config reload on the remote server, + // need to possibly send local subs to the remote server. + c.updateRemoteRoutePerms(sl, info) + } + c.mu.Unlock() + + // If the remote is going into LDM and there are client connect URLs + // associated with this route and we are allowed to advertise, remove + // those URLs and update our clients. + if (len(connectURLs) > 0 || len(wsConnectURLs) > 0) && !s.getOpts().Cluster.NoAdvertise { + s.removeConnectURLsAndSendINFOToClients(connectURLs, wsConnectURLs) + } + return + } + + // Mark that the INFO protocol has been received, so we can detect updates. + c.flags.set(infoReceived) + + // Get the route's proto version + c.opts.Protocol = info.Proto + // Headers c.headers = supportsHeaders && info.Headers // Copy over important information. + c.route.remoteID = info.ID c.route.authRequired = info.AuthRequired c.route.tlsRequired = info.TLSRequired c.route.gatewayURL = info.GatewayURL @@ -456,14 +484,6 @@ func (c *client) processRouteInfo(info *Info) { // Compute the hash of this route based on remoteID c.route.hash = string(getHash(info.ID)) - // If this is an update due to config reload on the remote server, - // need to possibly send local subs to the remote server. - if c.flags.isSet(infoReceived) { - c.updateRemoteRoutePerms(sl, info) - c.mu.Unlock() - return - } - // Copy over permissions as well. c.opts.Import = info.Import c.opts.Export = info.Export @@ -483,10 +503,6 @@ func (c *client) processRouteInfo(info *Info) { c.route.url = url } - // Mark that the INFO protocol has been received. Will allow - // to detect INFO updates. - c.flags.set(infoReceived) - // Check to see if we have this remote already registered. // This can happen when both servers have routes to each other. c.mu.Unlock() @@ -577,6 +593,7 @@ func (s *Server) sendAsyncInfoToClients(regCli, wsCli bool) { if s.cproto == 0 || s.shutdown { return } + info := s.copyInfo() for _, c := range s.clients { c.mu.Lock() @@ -589,7 +606,7 @@ func (s *Server) sendAsyncInfoToClients(regCli, wsCli bool) { c.flags.isSet(firstPongSent) { // sendInfo takes care of checking if the connection is still // valid or not, so don't duplicate tests here. - c.enqueueProto(c.generateClientInfoJSON(s.copyInfo())) + c.enqueueProto(c.generateClientInfoJSON(info)) } c.mu.Unlock() } @@ -653,6 +670,11 @@ func (s *Server) forwardNewRouteInfoToKnownServers(info *Info) { s.mu.Lock() defer s.mu.Unlock() + // Note: nonce is not used in routes. + // That being said, the info we get is the initial INFO which + // contains a nonce, but we now forward this to existing routes, + // so clear it now. + info.Nonce = _EMPTY_ b, _ := json.Marshal(info) infoJSON := []byte(fmt.Sprintf(InfoProto, b)) @@ -1092,7 +1114,14 @@ func (s *Server) createRoute(conn net.Conn, rURL *url.URL) *client { // Grab server variables s.mu.Lock() + // New proto wants a nonce (although not used in routes, that is, not signed in CONNECT) + var raw [nonceLen]byte + nonce := raw[:] + s.generateNonce(nonce) + s.routeInfo.Nonce = string(nonce) s.generateRouteInfoJSON() + // Clear now that it has been serialized. Will prevent nonce to be included in async INFO that we may send. + s.routeInfo.Nonce = _EMPTY_ infoJSON := s.routeInfoJSON authRequired := s.routeInfo.AuthRequired tlsRequired := s.routeInfo.TLSRequired diff --git a/server/server.go b/server/server.go index 0028a38b..38923cf2 100644 --- a/server/server.go +++ b/server/server.go @@ -82,6 +82,7 @@ type Info struct { Cluster string `json:"cluster,omitempty"` ClientConnectURLs []string `json:"connect_urls,omitempty"` // Contains URLs a client can connect to. WSConnectURLs []string `json:"ws_connect_urls,omitempty"` // Contains URLs a ws client can connect to. + LameDuckMode bool `json:"ldm,omitempty"` // Route Specific Import *SubjectPermission `json:"import,omitempty"` @@ -620,11 +621,6 @@ func (s *Server) checkResolvePreloads() { } func (s *Server) generateRouteInfoJSON() { - // New proto wants a nonce. - var raw [nonceLen]byte - nonce := raw[:] - s.generateNonce(nonce) - s.routeInfo.Nonce = string(nonce) b, _ := json.Marshal(s.routeInfo) pcs := [][]byte{[]byte("INFO"), b, []byte(CR_LF)} s.routeInfoJSON = bytes.Join(pcs, []byte(" ")) @@ -1835,13 +1831,6 @@ func (s *Server) copyInfo() Info { if len(info.WSConnectURLs) > 0 { info.WSConnectURLs = append([]string(nil), s.info.WSConnectURLs...) } - if s.nonceRequired() { - // Nonce handling - var raw [nonceLen]byte - nonce := raw[:] - s.generateNonce(nonce) - info.Nonce = string(nonce) - } return info } @@ -1864,6 +1853,13 @@ func (s *Server) createClient(conn net.Conn, ws *websocket) *client { // Grab JSON info string s.mu.Lock() info := s.copyInfo() + if s.nonceRequired() { + // Nonce handling + var raw [nonceLen]byte + nonce := raw[:] + s.generateNonce(nonce) + info.Nonce = string(nonce) + } c.nonce = []byte(info.Nonce) s.totalClients++ s.mu.Unlock() @@ -2691,6 +2687,8 @@ func (s *Server) lameDuckMode() { s.ldmCh = make(chan bool, 1) s.listener.Close() s.listener = nil + s.sendLDMToRoutes() + s.sendLDMToClients() s.mu.Unlock() // Wait for accept loop to be done to make sure that no new @@ -2770,6 +2768,50 @@ func (s *Server) lameDuckMode() { s.Shutdown() } +// Send an INFO update to routes with the indication that this server is in LDM mode. +// Server lock is held on entry. +func (s *Server) sendLDMToRoutes() { + s.routeInfo.LameDuckMode = true + s.generateRouteInfoJSON() + for _, r := range s.routes { + r.mu.Lock() + r.enqueueProto(s.routeInfoJSON) + r.mu.Unlock() + } + // Clear now so that we notify only once, should we have to send other INFOs. + s.routeInfo.LameDuckMode = false +} + +// Send an INFO update to clients with the indication that this server is in +// LDM mode and with only URLs of other nodes. +// Server lock is held on entry. +func (s *Server) sendLDMToClients() { + s.info.LameDuckMode = true + // Clear this so that if there are further updates, we don't send our URLs. + s.clientConnectURLs = s.clientConnectURLs[:0] + if s.websocket.connectURLs != nil { + s.websocket.connectURLs = s.websocket.connectURLs[:0] + } + // Reset content first. + s.info.ClientConnectURLs = s.info.ClientConnectURLs[:0] + s.info.WSConnectURLs = s.info.WSConnectURLs[:0] + // Only add the other nodes if we are allowed to. + if !s.getOpts().Cluster.NoAdvertise { + for url := range s.clientConnectURLsMap { + s.info.ClientConnectURLs = append(s.info.ClientConnectURLs, url) + } + for url := range s.websocket.connectURLsMap { + s.info.WSConnectURLs = append(s.info.WSConnectURLs, url) + } + } + // Send to all registered clients that support async INFO protocols. + s.sendAsyncInfoToClients(true, true) + // We now clear the info.LameDuckMode flag so that if there are + // cluster updates and we send the INFO, we don't have the boolean + // set which would cause multiple LDM notifications to clients. + s.info.LameDuckMode = false +} + // If given error is a net.Error and is temporary, sleeps for the given // delay and double it, but cap it to ACCEPT_MAX_SLEEP. The sleep is // interrupted if the server is shutdown. diff --git a/server/server_test.go b/server/server_test.go index e55bc689..9d377734 100644 --- a/server/server_test.go +++ b/server/server_test.go @@ -14,16 +14,20 @@ package server import ( + "bufio" "bytes" "context" "crypto/tls" + "encoding/json" "flag" "fmt" "io/ioutil" "net" "net/url" "os" + "reflect" "runtime" + "sort" "strings" "sync" "sync/atomic" @@ -873,6 +877,153 @@ func TestLameDuckMode(t *testing.T) { }) } +func TestLameDuckModeInfo(t *testing.T) { + // Ensure that initial delay is set very high so that we can + // check that some events occur as expected before the client + // is disconnected. + atomic.StoreInt64(&lameDuckModeInitialDelay, int64(5*time.Second)) + defer atomic.StoreInt64(&lameDuckModeInitialDelay, lameDuckModeDefaultInitialDelay) + + optsA := DefaultOptions() + optsA.Cluster.Host = "127.0.0.1" + optsA.Cluster.Port = -1 + optsA.LameDuckDuration = 50 * time.Millisecond + optsA.DisableShortFirstPing = true + srvA := RunServer(optsA) + defer srvA.Shutdown() + + curla := fmt.Sprintf("127.0.0.1:%d", optsA.Port) + c, err := net.Dial("tcp", curla) + if err != nil { + t.Fatalf("Error connecting: %v", err) + } + defer c.Close() + + client := bufio.NewReaderSize(c, maxBufSize) + + getInfo := func() *serverInfo { + t.Helper() + l, err := client.ReadString('\n') + if err != nil { + t.Fatalf("Error receiving info from server: %v\n", err) + } + var info serverInfo + if err = json.Unmarshal([]byte(l[5:]), &info); err != nil { + t.Fatalf("Could not parse INFO json: %v\n", err) + } + return &info + } + getInfo() + c.Write([]byte("CONNECT {\"protocol\":1,\"verbose\":false}\r\nPING\r\n")) + client.ReadString('\n') + + optsB := DefaultOptions() + optsB.Routes = RoutesFromStr(fmt.Sprintf("nats://127.0.0.1:%d", srvA.ClusterAddr().Port)) + srvB := RunServer(optsB) + defer srvB.Shutdown() + + checkClusterFormed(t, srvA, srvB) + + checkConnectURLs := func(expected []string) *serverInfo { + t.Helper() + sort.Strings(expected) + si := getInfo() + sort.Strings(si.ConnectURLs) + if !reflect.DeepEqual(expected, si.ConnectURLs) { + t.Fatalf("Expected %q, got %q", expected, si.ConnectURLs) + } + return si + } + + curlb := fmt.Sprintf("127.0.0.1:%d", optsB.Port) + expected := []string{curla, curlb} + checkConnectURLs(expected) + + optsC := DefaultOptions() + optsC.Routes = RoutesFromStr(fmt.Sprintf("nats://127.0.0.1:%d", srvA.ClusterAddr().Port)) + srvC := RunServer(optsC) + defer srvC.Shutdown() + + checkClusterFormed(t, srvA, srvB, srvC) + + curlc := fmt.Sprintf("127.0.0.1:%d", optsC.Port) + expected = append(expected, curlc) + checkConnectURLs(expected) + + optsD := DefaultOptions() + optsD.Routes = RoutesFromStr(fmt.Sprintf("nats://127.0.0.1:%d", srvA.ClusterAddr().Port)) + srvD := RunServer(optsD) + defer srvD.Shutdown() + + checkClusterFormed(t, srvA, srvB, srvC, srvD) + + curld := fmt.Sprintf("127.0.0.1:%d", optsD.Port) + expected = append(expected, curld) + checkConnectURLs(expected) + + // Now lame duck server A and C. We should have client connected to A + // receive info that A is in LDM without A's URL, but also receive + // an update with C's URL gone. + // But first we need to create a client to C because otherwise the + // LDM signal will just shut it down because it would have no client. + nc, err := nats.Connect(srvC.ClientURL()) + if err != nil { + t.Fatalf("Failed to connect: %v", err) + } + defer nc.Close() + nc.Flush() + + start := time.Now() + wg := sync.WaitGroup{} + wg.Add(2) + go func() { + defer wg.Done() + srvA.lameDuckMode() + }() + + expected = []string{curlb, curlc, curld} + si := checkConnectURLs(expected) + if !si.LameDuckMode { + t.Fatal("Expected LameDuckMode to be true, it was not") + } + + // Start LDM for server C. This should send an update to A + // which in turn should remove C from the list of URLs and + // update its client. + go func() { + defer wg.Done() + srvC.lameDuckMode() + }() + + expected = []string{curlb, curld} + si = checkConnectURLs(expected) + // This update should not say that it is LDM. + if si.LameDuckMode { + t.Fatal("Expected LameDuckMode to be false, it was true") + } + + // Now shutdown D, and we also should get an update. + srvD.Shutdown() + + expected = []string{curlb} + si = checkConnectURLs(expected) + // This update should not say that it is LDM. + if si.LameDuckMode { + t.Fatal("Expected LameDuckMode to be false, it was true") + } + if time.Since(start) > 2*time.Second { + t.Fatalf("Did not get the expected events prior of server A and C shutting down") + } + + c.Close() + nc.Close() + // Don't need to wait for actual disconnect of clients, + // so shutdown the servers now. + srvA.Shutdown() + srvC.Shutdown() + wg.Wait() +} + func TestServerValidateGatewaysOptions(t *testing.T) { baseOpt := testDefaultOptionsForGateway("A") u, _ := url.Parse("host:5222")