diff --git a/README.md b/README.md index 02bd2bb7..304bd862 100644 --- a/README.md +++ b/README.md @@ -152,6 +152,7 @@ Server Options: -ms,--https_port Use port for https monitoring -c, --config Configuration file -sl,--signal [=] Send signal to gnatsd process (stop, quit, reopen, reload) + --client_advertise Client URL to advertise to other servers Logging Options: -l, --log File to redirect log output @@ -178,6 +179,7 @@ Cluster Options: --routes Routes to solicit and connect --cluster Cluster URL for solicited routes --no_advertise Advertise known cluster IPs to clients + --cluster_advertise Cluster URL to advertise to other servers --connect_retries For implicit routes, number of connect retries @@ -286,6 +288,23 @@ The `--routes` flag specifies the NATS URL for one or more servers in the cluste Previous releases required you to build the complete mesh using the `--routes` flag. To define your cluster in the current release, please follow the "Basic example" as described below. +Suppose that server srvA is connected to server srvB. A bi-directional route exists between srvA and srvB. A new server, srvC, connects to srvA.
+When accepting the connection, srvA will gossip the address of srvC to srvB so that srvB connects to srvC, completing the full mesh.
+The URL that srvB will use to connect to srvC is the result of the TCP remote address that srvA got from its connection to srvC. + +It is possible to advertise with `--cluster_advertise` a different address than the one used in `--cluster`. + +In the previous example, if srvC uses a `--cluster_advertise` URL, this is what srvA will gossip to srvB in order to connect to srvC. + +NOTE: The advertise address should really result in a connection to srvC. Providing an address that would result in a connection to a different NATS Server would prevent the formation of a full-mesh cluster! + +As part of the gossip protocol, a server will also send to the other servers the URL clients should connect to.
+The URL is the one defined in the `listen` parameter, or, if 0.0.0.0 or :: is specified, the resolved non-local IP addresses for the "any" interface. + +If those addresses are not reachable from the outside world where the clients are running, the administrator can use the `--no_advertise` option to disable servers gossiping those URLs.
+Another option is to provide a `--client_advertise` URL to use instead. If this option is specified (and advertise has not been disabled), then the server will advertise this URL to other servers instead of its `listen` address (or resolved IPs when listen is 0.0.0.0 or ::). + + ### Basic example NATS makes building the full mesh easy. Simply designate a server to be a *seed* server. All other servers in the cluster simply specify the *seed* server as its server's routes option as indicated below. diff --git a/main.go b/main.go index 810580c6..c33a0bed 100644 --- a/main.go +++ b/main.go @@ -21,6 +21,7 @@ Server Options: -ms,--https_port Use port for https monitoring -c, --config Configuration file -sl,--signal [=] Send signal to gnatsd process (stop, quit, reopen, reload) + --client_advertise Client URL to advertise to other servers Logging Options: -l, --log File to redirect log output @@ -47,6 +48,7 @@ Cluster Options: --routes Routes to solicit and connect --cluster Cluster URL for solicited routes --no_advertise Advertise known cluster IPs to clients + --cluster_advertise Cluster URL to advertise to other servers --connect_retries For implicit routes, number of connect retries diff --git a/server/log_test.go b/server/log_test.go index f93f2477..113495ef 100644 --- a/server/log_test.go +++ b/server/log_test.go @@ -8,6 +8,7 @@ import ( "os" "runtime" "strings" + "sync" "testing" "github.com/nats-io/gnatsd/logger" @@ -64,28 +65,41 @@ func TestSetLogger(t *testing.T) { } type DummyLogger struct { + sync.Mutex msg string } -func (dl *DummyLogger) checkContent(t *testing.T, expectedStr string) { - if dl.msg != expectedStr { - stackFatalf(t, "Expected log to be: %v, got %v", expectedStr, dl.msg) +func (l *DummyLogger) checkContent(t *testing.T, expectedStr string) { + l.Lock() + defer l.Unlock() + if l.msg != expectedStr { + stackFatalf(t, "Expected log to be: %v, got %v", expectedStr, l.msg) } } func (l *DummyLogger) Noticef(format string, v ...interface{}) { + 
l.Lock() + defer l.Unlock() l.msg = fmt.Sprintf(format, v...) } func (l *DummyLogger) Errorf(format string, v ...interface{}) { + l.Lock() + defer l.Unlock() l.msg = fmt.Sprintf(format, v...) } func (l *DummyLogger) Fatalf(format string, v ...interface{}) { + l.Lock() + defer l.Unlock() l.msg = fmt.Sprintf(format, v...) } func (l *DummyLogger) Debugf(format string, v ...interface{}) { + l.Lock() + defer l.Unlock() l.msg = fmt.Sprintf(format, v...) } func (l *DummyLogger) Tracef(format string, v ...interface{}) { + l.Lock() + defer l.Unlock() l.msg = fmt.Sprintf(format, v...) } diff --git a/server/opts.go b/server/opts.go index ea282ae6..75484857 100644 --- a/server/opts.go +++ b/server/opts.go @@ -30,49 +30,51 @@ type ClusterOpts struct { TLSTimeout float64 `json:"-"` TLSConfig *tls.Config `json:"-"` ListenStr string `json:"-"` + Advertise string `json:"-"` NoAdvertise bool `json:"-"` ConnectRetries int `json:"-"` } // Options block for gnatsd server. type Options struct { - ConfigFile string `json:"-"` - Host string `json:"addr"` - Port int `json:"port"` - Trace bool `json:"-"` - Debug bool `json:"-"` - NoLog bool `json:"-"` - NoSigs bool `json:"-"` - Logtime bool `json:"-"` - MaxConn int `json:"max_connections"` - Users []*User `json:"-"` - Username string `json:"-"` - Password string `json:"-"` - Authorization string `json:"-"` - PingInterval time.Duration `json:"ping_interval"` - MaxPingsOut int `json:"ping_max"` - HTTPHost string `json:"http_host"` - HTTPPort int `json:"http_port"` - HTTPSPort int `json:"https_port"` - AuthTimeout float64 `json:"auth_timeout"` - MaxControlLine int `json:"max_control_line"` - MaxPayload int `json:"max_payload"` - Cluster ClusterOpts `json:"cluster"` - ProfPort int `json:"-"` - PidFile string `json:"-"` - LogFile string `json:"-"` - Syslog bool `json:"-"` - RemoteSyslog string `json:"-"` - Routes []*url.URL `json:"-"` - RoutesStr string `json:"-"` - TLSTimeout float64 `json:"tls_timeout"` - TLS bool `json:"-"` - TLSVerify bool 
`json:"-"` - TLSCert string `json:"-"` - TLSKey string `json:"-"` - TLSCaCert string `json:"-"` - TLSConfig *tls.Config `json:"-"` - WriteDeadline time.Duration `json:"-"` + ConfigFile string `json:"-"` + Host string `json:"addr"` + Port int `json:"port"` + ClientAdvertise string `json:"-"` + Trace bool `json:"-"` + Debug bool `json:"-"` + NoLog bool `json:"-"` + NoSigs bool `json:"-"` + Logtime bool `json:"-"` + MaxConn int `json:"max_connections"` + Users []*User `json:"-"` + Username string `json:"-"` + Password string `json:"-"` + Authorization string `json:"-"` + PingInterval time.Duration `json:"ping_interval"` + MaxPingsOut int `json:"ping_max"` + HTTPHost string `json:"http_host"` + HTTPPort int `json:"http_port"` + HTTPSPort int `json:"https_port"` + AuthTimeout float64 `json:"auth_timeout"` + MaxControlLine int `json:"max_control_line"` + MaxPayload int `json:"max_payload"` + Cluster ClusterOpts `json:"cluster"` + ProfPort int `json:"-"` + PidFile string `json:"-"` + LogFile string `json:"-"` + Syslog bool `json:"-"` + RemoteSyslog string `json:"-"` + Routes []*url.URL `json:"-"` + RoutesStr string `json:"-"` + TLSTimeout float64 `json:"tls_timeout"` + TLS bool `json:"-"` + TLSVerify bool `json:"-"` + TLSCert string `json:"-"` + TLSKey string `json:"-"` + TLSCaCert string `json:"-"` + TLSConfig *tls.Config `json:"-"` + WriteDeadline time.Duration `json:"-"` CustomClientAuthentication Authentication `json:"-"` CustomRouterAuthentication Authentication `json:"-"` @@ -202,6 +204,8 @@ func (o *Options) ProcessConfigFile(configFile string) error { } o.Host = hp.host o.Port = hp.port + case "client_advertise": + o.ClientAdvertise = v.(string) case "port": o.Port = int(v.(int64)) case "host", "net": @@ -387,6 +391,8 @@ func parseCluster(cm map[string]interface{}, opts *Options) error { opts.Cluster.TLSConfig.ClientAuth = tls.RequireAndVerifyClientCert opts.Cluster.TLSConfig.RootCAs = opts.Cluster.TLSConfig.ClientCAs opts.Cluster.TLSTimeout = tc.Timeout + case 
"cluster_advertise", "advertise": + opts.Cluster.Advertise = mv.(string) case "no_advertise": opts.Cluster.NoAdvertise = mv.(bool) case "connect_retries": @@ -720,6 +726,9 @@ func MergeOptions(fileOpts, flagOpts *Options) *Options { if flagOpts.Host != "" { opts.Host = flagOpts.Host } + if flagOpts.ClientAdvertise != "" { + opts.ClientAdvertise = flagOpts.ClientAdvertise + } if flagOpts.Username != "" { opts.Username = flagOpts.Username } @@ -759,6 +768,9 @@ func MergeOptions(fileOpts, flagOpts *Options) *Options { if flagOpts.Cluster.ConnectRetries != 0 { opts.Cluster.ConnectRetries = flagOpts.Cluster.ConnectRetries } + if flagOpts.Cluster.Advertise != "" { + opts.Cluster.Advertise = flagOpts.Cluster.Advertise + } if flagOpts.RoutesStr != "" { mergeRoutes(&opts, flagOpts) } @@ -942,6 +954,7 @@ func ConfigureOptions(fs *flag.FlagSet, args []string, printVersion, printHelp, fs.StringVar(&opts.Host, "addr", "", "Network host to listen on.") fs.StringVar(&opts.Host, "a", "", "Network host to listen on.") fs.StringVar(&opts.Host, "net", "", "Network host to listen on.") + fs.StringVar(&opts.ClientAdvertise, "client_advertise", "", "Client URL to advertise to other servers.") fs.BoolVar(&opts.Debug, "D", false, "Enable Debug logging.") fs.BoolVar(&opts.Debug, "debug", false, "Enable Debug logging.") fs.BoolVar(&opts.Trace, "V", false, "Enable Trace logging.") @@ -974,6 +987,7 @@ func ConfigureOptions(fs *flag.FlagSet, args []string, printVersion, printHelp, fs.StringVar(&opts.RoutesStr, "routes", "", "Routes to actively solicit a connection.") fs.StringVar(&opts.Cluster.ListenStr, "cluster", "", "Cluster url from which members can solicit routes.") fs.StringVar(&opts.Cluster.ListenStr, "cluster_listen", "", "Cluster url from which members can solicit routes.") + fs.StringVar(&opts.Cluster.Advertise, "cluster_advertise", "", "Cluster URL to advertise to other servers.") fs.BoolVar(&opts.Cluster.NoAdvertise, "no_advertise", false, "Advertise known cluster IPs to 
clients.") fs.IntVar(&opts.Cluster.ConnectRetries, "connect_retries", 0, "For implicit routes, number of connect retries") fs.BoolVar(&showTLSHelp, "help_tls", false, "TLS help.") @@ -1171,6 +1185,7 @@ func overrideCluster(opts *Options) error { opts.Cluster.Username = "" opts.Cluster.Password = "" } + return nil } diff --git a/server/reload.go b/server/reload.go index 0593cdbf..8d2ca2e8 100644 --- a/server/reload.go +++ b/server/reload.go @@ -237,7 +237,7 @@ func (c *clusterOption) Apply(server *Server) { server.routeInfo.SSLRequired = tlsRequired server.routeInfo.TLSVerify = tlsRequired server.routeInfo.AuthRequired = c.newValue.Username != "" - server.generateRouteInfoJSON() + server.setRouteInfoHostPortAndIP() server.mu.Unlock() server.Noticef("Reloaded: cluster") } @@ -407,6 +407,20 @@ func (w *writeDeadlineOption) Apply(server *Server) { server.Noticef("Reloaded: write_deadline = %s", w.newValue) } +// clientAdvertiseOption implements the option interface for the `client_advertise` setting. +type clientAdvertiseOption struct { + noopOption + newValue string +} + +// Apply the setting by updating the server info and regenerate the infoJSON byte array. +func (c *clientAdvertiseOption) Apply(server *Server) { + server.mu.Lock() + server.setInfoHostPortAndGenerateJSON() + server.mu.Unlock() + server.Noticef("Reload: client_advertise = %s", c.newValue) +} + // Reload reads the current configuration file and applies any supported // changes. This returns an error if the server was not started with a config // file or an option which doesn't support hot-swapping was changed. @@ -422,11 +436,25 @@ func (s *Server) Reload() error { // TODO: Dump previous good config to a .bak file? return err } + clientOrgPort := s.clientActualPort + clusterOrgPort := s.clusterActualPort s.mu.Unlock() // Apply flags over config file settings. 
newOpts = MergeOptions(newOpts, FlagSnapshot) processOptions(newOpts) + + // processOptions sets Port to 0 if set to -1 (RANDOM port) + // If that's the case, set it to the saved value when the accept loop was + // created. + if newOpts.Port == 0 { + newOpts.Port = clientOrgPort + } + // We don't do that for cluster, so check against -1. + if newOpts.Cluster.Port == -1 { + newOpts.Cluster.Port = clusterOrgPort + } + if err := s.reloadOptions(newOpts); err != nil { return err } @@ -518,6 +546,15 @@ func (s *Server) diffOptions(newOpts *Options) ([]option, error) { diffOpts = append(diffOpts, &maxPingsOutOption{newValue: newValue.(int)}) case "writedeadline": diffOpts = append(diffOpts, &writeDeadlineOption{newValue: newValue.(time.Duration)}) + case "clientadvertise": + cliAdv := newValue.(string) + if cliAdv != "" { + // Validate ClientAdvertise syntax + if _, _, err := parseHostPort(cliAdv, 0); err != nil { + return nil, fmt.Errorf("invalid ClientAdvertise value of %s, err=%v", cliAdv, err) + } + } + diffOpts = append(diffOpts, &clientAdvertiseOption{newValue: cliAdv}) case "nolog": // Ignore NoLog option since it's not parsed and only used in // testing. 
@@ -612,6 +649,12 @@ func validateClusterOpts(old, new ClusterOpts) error { return fmt.Errorf("Config reload not supported for cluster port: old=%d, new=%d", old.Port, new.Port) } + // Validate Cluster.Advertise syntax + if new.Advertise != "" { + if _, _, err := parseHostPort(new.Advertise, 0); err != nil { + return fmt.Errorf("invalid Cluster.Advertise value of %s, err=%v", new.Advertise, err) + } + } return nil } diff --git a/server/reload_test.go b/server/reload_test.go index 18ee6d66..5861c3fd 100644 --- a/server/reload_test.go +++ b/server/reload_test.go @@ -3,6 +3,7 @@ package server import ( + "encoding/json" "fmt" "io/ioutil" "net" @@ -187,7 +188,7 @@ func TestConfigReload(t *testing.T) { if err := ioutil.WriteFile(platformConf, content, 0666); err != nil { t.Fatalf("Unable to write config file: %v", err) } - server, opts, config := newServerWithSymlinkConfig(t, "tmp.conf", "./configs/reload/test.conf") + server, opts, config := runServerWithSymlinkConfig(t, "tmp.conf", "./configs/reload/test.conf") defer os.Remove(config) defer server.Shutdown() @@ -200,6 +201,7 @@ func TestConfigReload(t *testing.T) { AuthTimeout: 1.0, Debug: false, Trace: false, + NoLog: true, Logtime: false, MaxControlLine: 1024, MaxPayload: 1048576, @@ -209,12 +211,12 @@ func TestConfigReload(t *testing.T) { WriteDeadline: 2 * time.Second, Cluster: ClusterOpts{ Host: "localhost", - Port: -1, + Port: server.ClusterAddr().Port, }, } processOptions(golden) - if !reflect.DeepEqual(golden, server.getOpts()) { + if !reflect.DeepEqual(golden, opts) { t.Fatalf("Options are incorrect.\nexpected: %+v\ngot: %+v", golden, opts) } @@ -1372,6 +1374,166 @@ func TestConfigReloadClusterRoutes(t *testing.T) { } } +func TestConfigReloadClusterAdvertise(t *testing.T) { + conf := "routeadv.conf" + if err := ioutil.WriteFile(conf, []byte(` + listen: "0.0.0.0:-1" + cluster: { + listen: "0.0.0.0:-1" + } + `), 0666); err != nil { + t.Fatalf("Error creating config file: %v", err) + } + defer os.Remove(conf) + 
opts, err := ProcessConfigFile(conf) + if err != nil { + t.Fatalf("Error processing config file: %v", err) + } + opts.NoLog = true + s := RunServer(opts) + defer s.Shutdown() + + orgClusterPort := s.ClusterAddr().Port + + updateConfig := func(content string) { + if err := ioutil.WriteFile(conf, []byte(content), 0666); err != nil { + stackFatalf(t, "Error creating config file: %v", err) + } + if err := s.Reload(); err != nil { + stackFatalf(t, "Error on reload: %v", err) + } + } + + verify := func(expectedHost string, expectedPort int, expectedIP string) { + s.mu.Lock() + routeInfo := s.routeInfo + routeInfoJSON := Info{} + err = json.Unmarshal(s.routeInfoJSON[5:], &routeInfoJSON) // Skip "INFO " + s.mu.Unlock() + if err != nil { + t.Fatalf("Error on Unmarshal: %v", err) + } + if routeInfo.Host != expectedHost || routeInfo.Port != expectedPort || routeInfo.IP != expectedIP { + t.Fatalf("Expected host/port/IP to be %s:%v, %q, got %s:%d, %q", + expectedHost, expectedPort, expectedIP, routeInfo.Host, routeInfo.Port, routeInfo.IP) + } + // Check that server routeInfoJSON was updated too + if !reflect.DeepEqual(routeInfo, routeInfoJSON) { + t.Fatalf("Expected routeInfoJSON to be %+v, got %+v", routeInfo, routeInfoJSON) + } + } + + // Update config with cluster_advertise + updateConfig(` + listen: "0.0.0.0:-1" + cluster: { + listen: "0.0.0.0:-1" + cluster_advertise: "me:1" + } + `) + verify("me", 1, "nats-route://me:1/") + + // Update config with cluster_advertise (no port specified) + updateConfig(` + listen: "0.0.0.0:-1" + cluster: { + listen: "0.0.0.0:-1" + cluster_advertise: "me" + } + `) + verify("me", orgClusterPort, fmt.Sprintf("nats-route://me:%d/", orgClusterPort)) + + // Update config with cluster_advertise (-1 port specified) + updateConfig(` + listen: "0.0.0.0:-1" + cluster: { + listen: "0.0.0.0:-1" + cluster_advertise: "me:-1" + } + `) + verify("me", orgClusterPort, fmt.Sprintf("nats-route://me:%d/", orgClusterPort)) + + // Update to remove cluster_advertise 
+ updateConfig(` + listen: "0.0.0.0:-1" + cluster: { + listen: "0.0.0.0:-1" + } + `) + verify("0.0.0.0", orgClusterPort, "") +} + +func TestConfigReloadClientAdvertise(t *testing.T) { + conf := "clientadv.conf" + if err := ioutil.WriteFile(conf, []byte(`listen: "0.0.0.0:-1"`), 0666); err != nil { + t.Fatalf("Error creating config file: %v", err) + } + defer os.Remove(conf) + opts, err := ProcessConfigFile(conf) + if err != nil { + stackFatalf(t, "Error processing config file: %v", err) + } + opts.NoLog = true + s := RunServer(opts) + defer s.Shutdown() + + orgPort := s.Addr().(*net.TCPAddr).Port + + updateConfig := func(content string) { + if err := ioutil.WriteFile(conf, []byte(content), 0666); err != nil { + stackFatalf(t, "Error creating config file: %v", err) + } + if err := s.Reload(); err != nil { + stackFatalf(t, "Error on reload: %v", err) + } + } + + verify := func(expectedHost string, expectedPort int) { + s.mu.Lock() + info := s.info + infoJSON := Info{clientConnectURLs: make(map[string]struct{})} + err := json.Unmarshal(s.infoJSON[5:len(s.infoJSON)-2], &infoJSON) // Skip INFO + s.mu.Unlock() + if err != nil { + stackFatalf(t, "Error on Unmarshal: %v", err) + } + if info.Host != expectedHost || info.Port != expectedPort { + stackFatalf(t, "Expected host/port to be %s:%d, got %s:%d", + expectedHost, expectedPort, info.Host, info.Port) + } + // Check that server infoJSON was updated too + if !reflect.DeepEqual(info, infoJSON) { + stackFatalf(t, "Expected infoJSON to be %+v, got %+v", info, infoJSON) + } + } + + // Update config with ClientAdvertise (port specified) + updateConfig(` + listen: "0.0.0.0:-1" + client_advertise: "me:1" + `) + verify("me", 1) + + // Update config with ClientAdvertise (no port specified) + updateConfig(` + listen: "0.0.0.0:-1" + client_advertise: "me" + `) + verify("me", orgPort) + + // Update config with ClientAdvertise (-1 port specified) + updateConfig(` + listen: "0.0.0.0:-1" + client_advertise: "me:-1" + `) + verify("me", 
orgPort) + + // Now remove ClientAdvertise to check that original values + // are restored. + updateConfig(`listen: "0.0.0.0:-1"`) + verify("0.0.0.0", orgPort) +} + // Ensure Reload supports changing the max connections. Test this by starting a // server with no max connections, connecting two clients, reloading with a // max connections of one, and ensuring one client is disconnected. diff --git a/server/route.go b/server/route.go index bd4fa25d..20f6d75f 100644 --- a/server/route.go +++ b/server/route.go @@ -145,16 +145,22 @@ func (c *client) processRouteInfo(info *Info) { // sendInfo will be false if the route that we just accepted // is the only route there is. if sendInfo { - // Need to get the remote IP address. - c.mu.Lock() - switch conn := c.nc.(type) { - case *net.TCPConn, *tls.Conn: - addr := conn.RemoteAddr().(*net.TCPAddr) - info.IP = fmt.Sprintf("nats-route://%s/", net.JoinHostPort(addr.IP.String(), strconv.Itoa(info.Port))) - default: - info.IP = c.route.url.String() + // The incoming INFO from the route will have IP set + // if it has Cluster.Advertise. In that case, use that + // otherwise contruct it from the remote TCP address. + if info.IP == "" { + // Need to get the remote IP address. + c.mu.Lock() + switch conn := c.nc.(type) { + case *net.TCPConn, *tls.Conn: + addr := conn.RemoteAddr().(*net.TCPAddr) + info.IP = fmt.Sprintf("nats-route://%s/", net.JoinHostPort(addr.IP.String(), + strconv.Itoa(info.Port))) + default: + info.IP = c.route.url.String() + } + c.mu.Unlock() } - c.mu.Unlock() // Now let the known servers know about this new route s.forwardNewRouteInfoToKnownServers(info) } @@ -612,6 +618,12 @@ func (s *Server) broadcastUnSubscribe(sub *subscription) { } func (s *Server) routeAcceptLoop(ch chan struct{}) { + defer func() { + if ch != nil { + close(ch) + } + }() + // Snapshot server options. 
opts := s.getOpts() @@ -631,19 +643,16 @@ func (s *Server) routeAcceptLoop(ch chan struct{}) { s.Noticef("Listening for route connections on %s", hp) l, e := net.Listen("tcp", hp) if e != nil { - // We need to close this channel to avoid a deadlock - close(ch) s.Fatalf("Error listening on router port: %d - %v", opts.Cluster.Port, e) return } + s.mu.Lock() // Check for TLSConfig tlsReq := opts.Cluster.TLSConfig != nil info := Info{ ID: s.info.ID, Version: s.info.Version, - Host: opts.Cluster.Host, - Port: l.Addr().(*net.TCPAddr).Port, AuthRequired: false, TLSRequired: tlsReq, SSLRequired: tlsReq, @@ -651,20 +660,33 @@ func (s *Server) routeAcceptLoop(ch chan struct{}) { MaxPayload: s.info.MaxPayload, ClientConnectURLs: clientConnectURLs, } + // If we have selected a random port... + if port == 0 { + // Write resolved port back to options. + opts.Cluster.Port = l.Addr().(*net.TCPAddr).Port + } + // Keep track of actual listen port. This will be needed in case of + // config reload. + s.clusterActualPort = opts.Cluster.Port // Check for Auth items if opts.Cluster.Username != "" { info.AuthRequired = true } s.routeInfo = info - s.generateRouteInfoJSON() - + // Possibly override Host/Port and set IP based on Cluster.Advertise + if err := s.setRouteInfoHostPortAndIP(); err != nil { + s.Fatalf("Error setting route INFO with Cluster.Advertise value of %s, err=%v", s.opts.Cluster.Advertise, err) + l.Close() + s.mu.Unlock() + return + } // Setup state that can enable shutdown - s.mu.Lock() s.routeListener = l s.mu.Unlock() // Let them know we are up close(ch) + ch = nil tmpDelay := ACCEPT_MIN_SLEEP @@ -694,6 +716,26 @@ func (s *Server) routeAcceptLoop(ch chan struct{}) { s.done <- true } +// Similar to setInfoHostPortAndGenerateJSON, but for routeInfo. 
+func (s *Server) setRouteInfoHostPortAndIP() error { + if s.opts.Cluster.Advertise != "" { + advHost, advPort, err := parseHostPort(s.opts.Cluster.Advertise, s.opts.Cluster.Port) + if err != nil { + return err + } + s.routeInfo.Host = advHost + s.routeInfo.Port = advPort + s.routeInfo.IP = fmt.Sprintf("nats-route://%s/", net.JoinHostPort(advHost, strconv.Itoa(advPort))) + } else { + s.routeInfo.Host = s.opts.Cluster.Host + s.routeInfo.Port = s.opts.Cluster.Port + s.routeInfo.IP = "" + } + // (re)generate the routeInfoJSON byte array + s.generateRouteInfoJSON() + return nil +} + // StartRouting will start the accept loop on the cluster host:port // and will actively try to connect to listed routes. func (s *Server) StartRouting(clientListenReady chan struct{}) { diff --git a/server/routes_test.go b/server/routes_test.go index 1b396d35..6d7a367a 100644 --- a/server/routes_test.go +++ b/server/routes_test.go @@ -53,6 +53,130 @@ func TestRouteConfig(t *testing.T) { } } +func TestClusterAdvertise(t *testing.T) { + lst, err := net.Listen("tcp", "localhost:0") + if err != nil { + t.Fatalf("Error starting listener: %v", err) + } + ch := make(chan error) + go func() { + c, err := lst.Accept() + if err != nil { + ch <- err + return + } + c.Close() + ch <- nil + }() + + optsA, _ := ProcessConfigFile("./configs/seed.conf") + optsA.NoSigs, optsA.NoLog = true, true + srvA := RunServer(optsA) + defer srvA.Shutdown() + + srvARouteURL := fmt.Sprintf("nats://%s:%d", optsA.Cluster.Host, srvA.ClusterAddr().Port) + optsB := nextServerOpts(optsA) + optsB.Routes = RoutesFromStr(srvARouteURL) + + srvB := RunServer(optsB) + defer srvB.Shutdown() + + // Wait for these 2 to connect to each other + checkClusterFormed(t, srvA, srvB) + + // Now start server C that connects to A. A should ask B to connect to C, + // based on C's URL. But since C configures a Cluster.Advertise, it will connect + // to our listener. 
+ optsC := nextServerOpts(optsB) + optsC.Cluster.Advertise = lst.Addr().String() + optsC.ClientAdvertise = "me:1" + optsC.Routes = RoutesFromStr(srvARouteURL) + + srvC := RunServer(optsC) + defer srvC.Shutdown() + + select { + case e := <-ch: + if e != nil { + t.Fatalf("Error: %v", e) + } + case <-time.After(2 * time.Second): + t.Fatalf("Test timed out") + } +} + +func TestClusterAdvertiseErrorOnStartup(t *testing.T) { + opts := DefaultOptions() + // Set invalid address + opts.Cluster.Advertise = "addr:::123" + s := New(opts) + defer s.Shutdown() + dl := &DummyLogger{} + s.SetLogger(dl, false, false) + + // Start will keep running, so start in a go-routine. + wg := &sync.WaitGroup{} + wg.Add(1) + go func() { + s.Start() + wg.Done() + }() + msg := "" + ok := false + timeout := time.Now().Add(2 * time.Second) + for time.Now().Before(timeout) { + dl.Lock() + msg = dl.msg + dl.Unlock() + if strings.Contains(msg, "Cluster.Advertise") { + ok = true + break + } + } + if !ok { + t.Fatalf("Did not get expected error, got %v", msg) + } + s.Shutdown() + wg.Wait() +} + +func TestClientAdvertise(t *testing.T) { + optsA, _ := ProcessConfigFile("./configs/seed.conf") + optsA.NoSigs, optsA.NoLog = true, true + + srvA := RunServer(optsA) + defer srvA.Shutdown() + + optsB := nextServerOpts(optsA) + optsB.Routes = RoutesFromStr(fmt.Sprintf("nats://%s:%d", optsA.Cluster.Host, optsA.Cluster.Port)) + optsB.ClientAdvertise = "me:1" + srvB := RunServer(optsB) + defer srvB.Shutdown() + + checkClusterFormed(t, srvA, srvB) + + nc, err := nats.Connect(fmt.Sprintf("nats://%s:%d", optsA.Host, optsA.Port)) + if err != nil { + t.Fatalf("Error on connect: %v", err) + } + defer nc.Close() + timeout := time.Now().Add(time.Second) + good := false + for time.Now().Before(timeout) { + ds := nc.DiscoveredServers() + if len(ds) == 1 { + if ds[0] == "nats://me:1" { + good = true + break + } + } + time.Sleep(15 * time.Millisecond) + } + if !good { + t.Fatalf("Did not get expected discovered servers: %v", 
nc.DiscoveredServers()) + } +} + func TestServerRoutesWithClients(t *testing.T) { optsA, _ := ProcessConfigFile("./configs/srv_a.conf") optsB, _ := ProcessConfigFile("./configs/srv_b.conf") diff --git a/server/server.go b/server/server.go index 99b92983..d314f3c2 100644 --- a/server/server.go +++ b/server/server.go @@ -86,6 +86,12 @@ type Server struct { debug int32 } + // These store the real client/cluster listen ports. They are + // required during config reload to reset the Options (after + // reload) to the actual listen port values. + clientActualPort int + clusterActualPort int + // Used by tests to check that http.Servers do // not set any timeout. monitoringServer *http.Server @@ -137,6 +143,12 @@ func New(opts *Options) *Server { s.mu.Lock() defer s.mu.Unlock() + // This is normally done in the AcceptLoop, once the + // listener has been created (possibly with random port), + // but since some tests may expect the INFO to be properly + // set after New(), let's do it now. + s.setInfoHostPortAndGenerateJSON() + // For tracking clients s.clients = make(map[uint64]*client) @@ -155,7 +167,6 @@ func New(opts *Options) *Server { // Used to setup Authorization. s.configureAuthorization() - s.generateServerInfoJSON() s.handleSignals() return s @@ -410,19 +421,19 @@ func (s *Server) AcceptLoop(clr chan struct{}) { // to 0 at the beginning this function. So we need to get the actual port if opts.Port == 0 { // Write resolved port back to options. - _, port, err := net.SplitHostPort(l.Addr().String()) - if err != nil { - s.Fatalf("Error parsing server address (%s): %s", l.Addr().String(), e) - s.mu.Unlock() - return - } - portNum, err := strconv.Atoi(port) - if err != nil { - s.Fatalf("Error parsing server address (%s): %s", l.Addr().String(), e) - s.mu.Unlock() - return - } - opts.Port = portNum + opts.Port = l.Addr().(*net.TCPAddr).Port + } + // Keep track of actual listen port. This will be needed in case of + // config reload. 
+ s.clientActualPort = opts.Port + + // Now that port has been set (if it was set to RANDOM), set the + // server's info Host/Port with either values from Options or + // ClientAdvertise. Also generate the JSON byte array. + if err := s.setInfoHostPortAndGenerateJSON(); err != nil { + s.Fatalf("Error setting server INFO with ClientAdvertise value of %s, err=%v", s.opts.ClientAdvertise, err) + s.mu.Unlock() + return } s.mu.Unlock() @@ -458,6 +469,31 @@ func (s *Server) AcceptLoop(clr chan struct{}) { s.done <- true } +// This function sets the server's info Host/Port based on server Options. +// Note that this function may be called during config reload, this is why +// Host/Port may be reset to original Options if the ClientAdvertise option +// is not set (since it may have previously been). +// The function then generates the server infoJSON. +func (s *Server) setInfoHostPortAndGenerateJSON() error { + // When this function is called, opts.Port is set to the actual listen + // port (if option was originally set to RANDOM), even during a config + // reload. So use of s.opts.Port is safe. + if s.opts.ClientAdvertise != "" { + h, p, err := parseHostPort(s.opts.ClientAdvertise, s.opts.Port) + if err != nil { + return err + } + s.info.Host = h + s.info.Port = p + } else { + s.info.Host = s.opts.Host + s.info.Port = s.opts.Port + } + // (re)generate the infoJSON byte array. + s.generateServerInfoJSON() + return nil +} + // StartProfiler is called to enable dynamic profiling. func (s *Server) StartProfiler() { // Snapshot server options. @@ -984,6 +1020,7 @@ func (s *Server) startGoRoutine(f func()) { // getClientConnectURLs returns suitable URLs for clients to connect to the listen // port based on the server options' Host and Port. If the Host corresponds to // "any" interfaces, this call returns the list of resolved IP addresses. 
+// If ClientAdvertise is set, returns the client advertise host and port func (s *Server) getClientConnectURLs() []string { // Snapshot server options. opts := s.getOpts() @@ -991,45 +1028,52 @@ func (s *Server) getClientConnectURLs() []string { s.mu.Lock() defer s.mu.Unlock() - sPort := strconv.Itoa(opts.Port) urls := make([]string, 0, 1) - ipAddr, err := net.ResolveIPAddr("ip", opts.Host) - // If the host is "any" (0.0.0.0 or ::), get specific IPs from available - // interfaces. - if err == nil && ipAddr.IP.IsUnspecified() { - var ip net.IP - ifaces, _ := net.Interfaces() - for _, i := range ifaces { - addrs, _ := i.Addrs() - for _, addr := range addrs { - switch v := addr.(type) { - case *net.IPNet: - ip = v.IP - case *net.IPAddr: - ip = v.IP + // short circuit if client advertise is set + if opts.ClientAdvertise != "" { + // just use the info host/port. This is updated in s.New() + urls = append(urls, net.JoinHostPort(s.info.Host, strconv.Itoa(s.info.Port))) + } else { + sPort := strconv.Itoa(opts.Port) + ipAddr, err := net.ResolveIPAddr("ip", opts.Host) + // If the host is "any" (0.0.0.0 or ::), get specific IPs from available + // interfaces. + if err == nil && ipAddr.IP.IsUnspecified() { + var ip net.IP + ifaces, _ := net.Interfaces() + for _, i := range ifaces { + addrs, _ := i.Addrs() + for _, addr := range addrs { + switch v := addr.(type) { + case *net.IPNet: + ip = v.IP + case *net.IPAddr: + ip = v.IP + } + // Skip non global unicast addresses + if !ip.IsGlobalUnicast() || ip.IsUnspecified() { + ip = nil + continue + } + urls = append(urls, net.JoinHostPort(ip.String(), sPort)) } - // Skip non global unicast addresses - if !ip.IsGlobalUnicast() || ip.IsUnspecified() { - ip = nil - continue - } - urls = append(urls, net.JoinHostPort(ip.String(), sPort)) + } + } + if err != nil || len(urls) == 0 { + // We are here if s.opts.Host is not "0.0.0.0" nor "::", or if for some + // reason we could not add any URL in the loop above. 
+ // We had a case where a Windows VM was hosed and would have err == nil + // and not add any address in the array in the loop above, and we + // ended-up returning 0.0.0.0, which is problematic for Windows clients. + // Check for 0.0.0.0 or :: specifically, and ignore if that's the case. + if opts.Host == "0.0.0.0" || opts.Host == "::" { + s.Errorf("Address %q can not be resolved properly", opts.Host) + } else { + urls = append(urls, net.JoinHostPort(opts.Host, sPort)) } } } - if err != nil || len(urls) == 0 { - // We are here if s.opts.Host is not "0.0.0.0" nor "::", or if for some - // reason we could not add any URL in the loop above. - // We had a case where a Windows VM was hosed and would have err == nil - // and not add any address in the array in the loop above, and we - // ended-up returning 0.0.0.0, which is problematic for Windows clients. - // Check for 0.0.0.0 or :: specifically, and ignore if that's the case. - if opts.Host == "0.0.0.0" || opts.Host == "::" { - s.Errorf("Address %q can not be resolved properly", opts.Host) - } else { - urls = append(urls, net.JoinHostPort(opts.Host, sPort)) - } - } + return urls } diff --git a/server/server_test.go b/server/server_test.go index 1ca9d1d1..c9e8ece4 100644 --- a/server/server_test.go +++ b/server/server_test.go @@ -213,6 +213,71 @@ func TestGetConnectURLs(t *testing.T) { } } +func TestClientAdvertiseConnectURL(t *testing.T) { + opts := DefaultOptions() + opts.Port = 4222 + opts.ClientAdvertise = "nats.example.com" + s := New(opts) + defer s.Shutdown() + + urls := s.getClientConnectURLs() + if len(urls) != 1 { + t.Fatalf("Expected to get one url, got none: %v with ClientAdvertise %v", + opts.Host, opts.ClientAdvertise) + } + if urls[0] != "nats.example.com:4222" { + t.Fatalf("Expected to get '%s', got: '%v'", "nats.example.com:4222", urls[0]) + } + s.Shutdown() + + opts.ClientAdvertise = "nats.example.com:7777" + s = New(opts) + urls = s.getClientConnectURLs() + if len(urls) != 1 { + t.Fatalf("Expected 
to get one url, got none: %v with ClientAdvertise %v", + opts.Host, opts.ClientAdvertise) + } + if urls[0] != "nats.example.com:7777" { + t.Fatalf("Expected 'nats.example.com:7777', got: '%v'", urls[0]) + } + if s.info.Host != "nats.example.com" { + t.Fatalf("Expected host to be set to nats.example.com") + } + if s.info.Port != 7777 { + t.Fatalf("Expected port to be set to 7777") + } + s.Shutdown() + + opts = DefaultOptions() + opts.Port = 0 + opts.ClientAdvertise = "nats.example.com:7777" + s = New(opts) + if s.info.Host != "nats.example.com" && s.info.Port != 7777 { + t.Fatalf("Expected Client Advertise Host:Port to be nats.example.com:7777, got: %s:%d", + s.info.Host, s.info.Port) + } + s.Shutdown() +} + +func TestClientAdvertiseErrorOnStartup(t *testing.T) { + opts := DefaultOptions() + // Set invalid address + opts.ClientAdvertise = "addr:::123" + s := New(opts) + defer s.Shutdown() + dl := &DummyLogger{} + s.SetLogger(dl, false, false) + + // Expect this to return due to failure + s.Start() + dl.Lock() + msg := dl.msg + dl.Unlock() + if !strings.Contains(msg, "ClientAdvertise") { + t.Fatalf("Unexpected error: %v", msg) + } +} + func TestNoDeadlockOnStartFailure(t *testing.T) { opts := DefaultOptions() opts.Host = "x.x.x.x" // bad host diff --git a/server/util.go b/server/util.go index 06730b4d..05208c73 100644 --- a/server/util.go +++ b/server/util.go @@ -3,6 +3,11 @@ package server import ( + "errors" + "fmt" + "net" + "strconv" + "strings" "time" "github.com/nats-io/nuid" @@ -68,3 +73,28 @@ func secondsToDuration(seconds float64) time.Duration { ttl := seconds * float64(time.Second) return time.Duration(ttl) } + +// Parse a host/port string with a default port to use +// if none (or 0 or -1) is specified in `hostPort` string. 
+func parseHostPort(hostPort string, defaultPort int) (host string, port int, err error) { + if hostPort != "" { + host, sPort, err := net.SplitHostPort(hostPort) + switch err.(type) { + case *net.AddrError: + // try appending the current port + host, sPort, err = net.SplitHostPort(fmt.Sprintf("%s:%d", hostPort, defaultPort)) + } + if err != nil { + return "", -1, err + } + port, err = strconv.Atoi(strings.TrimSpace(sPort)) + if err != nil { + return "", -1, err + } + if port == 0 || port == -1 { + port = defaultPort + } + return strings.TrimSpace(host), port, nil + } + return "", -1, errors.New("No hostport specified") +} diff --git a/server/util_test.go b/server/util_test.go index 24ba9611..393819c8 100644 --- a/server/util_test.go +++ b/server/util_test.go @@ -30,6 +30,42 @@ func TestParseSInt64(t *testing.T) { } } +func TestParseHostPort(t *testing.T) { + check := func(hostPort string, defaultPort int, expectedHost string, expectedPort int, expectedErr bool) { + h, p, err := parseHostPort(hostPort, defaultPort) + if expectedErr { + if err == nil { + stackFatalf(t, "Expected an error, did not get one") + } + // expected error, so we are done + return + } + if !expectedErr && err != nil { + stackFatalf(t, "Unexpected error: %v", err) + } + if expectedHost != h { + stackFatalf(t, "Expected host %q, got %q", expectedHost, h) + } + if expectedPort != p { + stackFatalf(t, "Expected port %d, got %d", expectedPort, p) + } + } + check("addr:1234", 5678, "addr", 1234, false) + check(" addr:1234 ", 5678, "addr", 1234, false) + check(" addr : 1234 ", 5678, "addr", 1234, false) + check("addr", 5678, "addr", 5678, false) + check(" addr ", 5678, "addr", 5678, false) + check("addr:-1", 5678, "addr", 5678, false) + check(" addr:-1 ", 5678, "addr", 5678, false) + check(" addr : -1 ", 5678, "addr", 5678, false) + check("addr:0", 5678, "addr", 5678, false) + check(" addr:0 ", 5678, "addr", 5678, false) + check(" addr : 0 ", 5678, "addr", 5678, false) + check("addr:addr", 0, "", 0, 
true) + check("addr:::1234", 0, "", 0, true) + check("", 0, "", 0, true) +} + func BenchmarkParseInt(b *testing.B) { b.SetBytes(1) n := "12345678" diff --git a/test/proto_test.go b/test/proto_test.go index c58c8257..aa5f42fa 100644 --- a/test/proto_test.go +++ b/test/proto_test.go @@ -3,6 +3,7 @@ package test import ( + "encoding/json" "testing" "time" @@ -281,3 +282,25 @@ func TestControlLineMaximums(t *testing.T) { send(pubTooLong) expect(errRe) } + +func TestServerInfoWithClientAdvertise(t *testing.T) { + opts := DefaultTestOptions + opts.Port = PROTO_TEST_PORT + opts.ClientAdvertise = "me:1" + s := RunServer(&opts) + defer s.Shutdown() + + c := createClientConn(t, opts.Host, PROTO_TEST_PORT) + defer c.Close() + + buf := expectResult(t, c, infoRe) + js := infoRe.FindAllSubmatch(buf, 1)[0][1] + var sinfo server.Info + err := json.Unmarshal(js, &sinfo) + if err != nil { + t.Fatalf("Could not unmarshal INFO json: %v\n", err) + } + if sinfo.Host != "me" || sinfo.Port != 1 { + t.Fatalf("Expected INFO Host:Port to be me:1, got %s:%d", sinfo.Host, sinfo.Port) + } +}