diff --git a/server/auth.go b/server/auth.go index 1ee7b417..27107822 100644 --- a/server/auth.go +++ b/server/auth.go @@ -100,6 +100,24 @@ type RoutePermissions struct { Export *SubjectPermission `json:"export"` } +// GatewayPermissions are similar to RoutePermissions +type GatewayPermissions = RoutePermissions + +// clone will clone a RoutePermissions object +func (rp *RoutePermissions) clone() *RoutePermissions { + if rp == nil { + return nil + } + clone := &RoutePermissions{} + if rp.Import != nil { + clone.Import = rp.Import.clone() + } + if rp.Export != nil { + clone.Export = rp.Export.clone() + } + return clone +} + // clone will clone an individual subject permission. func (p *SubjectPermission) clone() *SubjectPermission { if p == nil { @@ -216,6 +234,8 @@ func (s *Server) checkAuthentication(c *client) bool { return s.isClientAuthorized(c) case ROUTER: return s.isRouterAuthorized(c) + case GATEWAY: + return s.isGatewayAuthorized(c) default: return false } @@ -401,6 +421,19 @@ func (s *Server) isRouterAuthorized(c *client) bool { return true } +// isGatewayAuthorized checks optional gateway authorization which can be nil or username/password. +func (s *Server) isGatewayAuthorized(c *client) bool { + // Snapshot server options. + opts := s.getOpts() + if opts.Gateway.Username == "" { + return true + } + if opts.Gateway.Username != c.opts.Username { + return false + } + return comparePasswords(opts.Gateway.Password, c.opts.Password) +} + // Support for bcrypt stored passwords and tokens. const bcryptPrefix = "$2a$" diff --git a/server/client.go b/server/client.go index df7f6326..3fa889d0 100644 --- a/server/client.go +++ b/server/client.go @@ -36,6 +36,8 @@ const ( CLIENT = iota // ROUTER is another router in the cluster. ROUTER + // GATEWAY is a link between 2 clusters. 
+ GATEWAY ) const ( @@ -77,6 +79,7 @@ const ( handshakeComplete // For TLS clients, indicate that the handshake is complete clearConnection // Marks that clearConnection has already been called. flushOutbound // Marks client as having a flushOutbound call in progress. + noReconnect // Indicate that on close, this connection should not attempt a reconnect ) // set the flag (would be equivalent to set the boolean to true) @@ -133,6 +136,7 @@ const ( RouteRemoved ServerShutdown AuthenticationExpired + WrongGateway ) type client struct { @@ -168,6 +172,8 @@ type client struct { route *route + gw *gateway + debug bool trace bool echo bool @@ -358,6 +364,8 @@ func (c *client) initClient() { c.ncs = fmt.Sprintf("%s - cid:%d", conn, c.cid) case ROUTER: c.ncs = fmt.Sprintf("%s - rid:%d", conn, c.cid) + case GATEWAY: + c.ncs = fmt.Sprintf("%s - gid:%d", conn, c.cid) } } @@ -875,8 +883,11 @@ func (c *client) processInfo(arg []byte) error { if err := json.Unmarshal(arg, &info); err != nil { return err } - if c.typ == ROUTER { + switch c.typ { + case ROUTER: c.processRouteInfo(&info) + case GATEWAY: + c.processGatewayInfo(&info) } return nil } @@ -887,6 +898,8 @@ func (c *client) processErr(errStr string) { c.Errorf("Client Error %s", errStr) case ROUTER: c.Errorf("Route Error %s", errStr) + case GATEWAY: + c.Errorf("Gateway Error %s", errStr) } c.closeConnection(ParseError) } @@ -944,7 +957,6 @@ func (c *client) processConnect(arg []byte) error { } c.last = time.Now() typ := c.typ - r := c.route srv := c.srv // Moved unmarshalling of clients' Options under the lock. // The client has already been added to the server map, so it is possible @@ -968,14 +980,17 @@ func (c *client) processConnect(arg []byte) error { c.mu.Unlock() if srv != nil { - // As soon as c.opts is unmarshalled and if the proto is at - // least ClientProtoInfo, we need to increment the following counter. - // This is decremented when client is removed from the server's - // clients map. 
- if proto >= ClientProtoInfo { - srv.mu.Lock() - srv.cproto++ - srv.mu.Unlock() + // Applicable to clients only: + if typ == CLIENT { + // As soon as c.opts is unmarshalled and if the proto is at + // least ClientProtoInfo, we need to increment the following counter. + // This is decremented when client is removed from the server's + // clients map. + if proto >= ClientProtoInfo { + srv.mu.Lock() + srv.cproto++ + srv.mu.Unlock() + } } // Check for Auth @@ -1020,33 +1035,23 @@ func (c *client) processConnect(arg []byte) error { } - // Check client protocol request if it exists. - if typ == CLIENT && (proto < ClientProtoZero || proto > ClientProtoInfo) { - c.sendErr(ErrBadClientProtocol.Error()) - c.closeConnection(BadClientProtocolVersion) - return ErrBadClientProtocol - } else if typ == ROUTER && lang != "" { - // Way to detect clients that incorrectly connect to the route listen - // port. Client provide Lang in the CONNECT protocol while ROUTEs don't. - c.sendErr(ErrClientConnectedToRoutePort.Error()) - c.closeConnection(WrongPort) - return ErrClientConnectedToRoutePort - } - - // Grab connection name of remote route. - if typ == ROUTER && r != nil { - var routePerms *RoutePermissions - if srv != nil { - routePerms = srv.getOpts().Cluster.Permissions + switch typ { + case CLIENT: + // Check client protocol request if it exists. 
+ if proto < ClientProtoZero || proto > ClientProtoInfo { + c.sendErr(ErrBadClientProtocol.Error()) + c.closeConnection(BadClientProtocolVersion) + return ErrBadClientProtocol } - c.mu.Lock() - c.route.remoteID = c.opts.Name - c.setRoutePermissions(routePerms) - c.mu.Unlock() - } - - if verbose { - c.sendOK() + if verbose { + c.sendOK() + } + case ROUTER: + // Delegate the rest of processing to the route + return c.processRouteConnect(srv, arg, lang) + case GATEWAY: + // Delegate the rest of processing to the gateway + return c.processGatewayConnect(arg) } return nil } @@ -1961,10 +1966,13 @@ func isServiceReply(reply []byte) bool { // This will decide to call the client code or router code. func (c *client) processInboundMsg(msg []byte) { - if c.typ == CLIENT { + switch c.typ { + case CLIENT: c.processInboundClientMsg(msg) - } else { + case ROUTER: c.processInboundRoutedMsg(msg) + case GATEWAY: + c.processInboundGatewayMsg(msg) } } @@ -2040,6 +2048,11 @@ func (c *client) processInboundClientMsg(msg []byte) { if len(r.psubs)+len(r.qsubs) > 0 { c.processMsgResults(c.acc, r, msg, c.pa.subject, c.pa.reply) } + + // Now deal with gateways + if c.srv.gateway.enabled { + c.sendMsgToGateways(msg, c.pa.subject, c.pa.reply, len(r.qsubs) == 0) + } } // This checks and process import services by doing the mapping and sending the @@ -2129,6 +2142,9 @@ func (c *client) processMsgResults(acc *Account, r *SublistResult, msg, subject, } c.addSubToRouteTargets(sub) continue + } else if sub.client.typ == GATEWAY { + // Never send to gateway from here. + continue } // Check for stream import mapped subs. These apply to local subs only. if sub.im != nil && sub.im.prefix != "" { @@ -2154,6 +2170,12 @@ func (c *client) processMsgResults(acc *Account, r *SublistResult, msg, subject, // guidance on which queue groups we should deliver to. qf := c.pa.queues + // For gateway connections, we still want to send messages to routes + // even if there is no queue filters. 
+ if c.typ == GATEWAY && qf == nil { + goto sendToRoutes + } + // Check to see if we have our own rand yet. Global rand // has contention with lots of clients, etc. if c.in.prand == nil { @@ -2231,6 +2253,8 @@ func (c *client) processMsgResults(acc *Account, r *SublistResult, msg, subject, } } +sendToRoutes: + // If no messages for routes return here. if len(c.in.rts) == 0 { return @@ -2396,6 +2420,8 @@ func (c *client) typeString() string { return "Client" case ROUTER: return "Router" + case GATEWAY: + return "Gateway" } return "Unknown Type" } @@ -2483,14 +2509,30 @@ func (c *client) closeConnection(reason ClosedState) { return } - c.Debugf("%s connection closed", c.typeString()) + // Be consistent with the creation: for routes and gateways, + // we use Noticef on create, so use that too for delete. + if c.typ == ROUTER || c.typ == GATEWAY { + c.Noticef("%s connection closed", c.typeString()) + } else { + c.Debugf("%s connection closed", c.typeString()) + } c.clearAuthTimer() c.clearPingTimer() c.clearConnection(reason) c.nc = nil - ctype := c.typ + var ( + retryImplicit bool + connectURLs []string + gwName string + gwIsOutbound bool + gwCfg *gatewayCfg + ctype = c.typ + srv = c.srv + noReconnect = c.flags.isSet(noReconnect) + acc = c.acc + ) // Snapshot for use if we are a client connection. // FIXME(dlc) - we can just stub in a new one for client @@ -2504,22 +2546,19 @@ func (c *client) closeConnection(reason ClosedState) { subs = append(subs, sub) } } - srv := c.srv - var ( - routeClosed bool - retryImplicit bool - connectURLs []string - ) if c.route != nil { - routeClosed = c.route.closed - if !routeClosed { + if !noReconnect { retryImplicit = c.route.retry } connectURLs = c.route.connectURLs } + if ctype == GATEWAY { + gwName = c.gw.name + gwIsOutbound = c.gw.outbound + gwCfg = c.gw.cfg + } - acc := c.acc c.mu.Unlock() // Remove clients subscriptions. 
@@ -2572,8 +2611,9 @@ func (c *client) closeConnection(reason ClosedState) { } } - // Don't reconnect routes that are being closed. - if routeClosed { + // Don't reconnect connections that have been marked with + // the no reconnect flag. + if noReconnect { return } @@ -2608,19 +2648,35 @@ func (c *client) closeConnection(reason ClosedState) { // server shutdown. srv.startGoRoutine(func() { srv.reConnectToRoute(rurl, rtype) }) } + } else if srv != nil && ctype == GATEWAY && gwIsOutbound { + if gwCfg != nil { + srv.Debugf("Attempting reconnect for gateway %q", gwName) + // Run this as a go routine since we may be called within + // the solicitGateway itself if there was an error during + // the creation of the gateway connection. + srv.startGoRoutine(func() { srv.reconnectGateway(gwCfg) }) + } else { + srv.Debugf("Gateway %q not in configuration, not attempting reconnect", gwName) + } } } -// If the client is a route connection, sets the `closed` flag to true -// to prevent any reconnecting attempt when c.closeConnection() is called. -func (c *client) setRouteNoReconnectOnClose() { +// Set the noReconnect flag. This is used before a call to closeConnection() +// to prevent the connection to reconnect (routes, gateways). +func (c *client) setNoReconnect() { c.mu.Lock() - if c.route != nil { - c.route.closed = true - } + c.flags.set(noReconnect) c.mu.Unlock() } +// Returns the client's RTT value with the protection of the client's lock. +func (c *client) getRTTValue() time.Duration { + c.mu.Lock() + rtt := c.rtt + c.mu.Unlock() + return rtt +} + // Logging functionality scoped to a client or route. func (c *client) Errorf(format string, v ...interface{}) { diff --git a/server/errors.go b/server/errors.go index 859154e3..f40d9440 100644 --- a/server/errors.go +++ b/server/errors.go @@ -83,6 +83,15 @@ var ( // ErrServiceImportAuthorization is returned when a service import is not authorized. 
ErrServiceImportAuthorization = errors.New("Service Import Not Authorized") + + // ErrClientOrRouteConnectedToGatewayPort represents an error condition when + // a client or route attempted to connect to the Gateway port. + ErrClientOrRouteConnectedToGatewayPort = errors.New("Attempted To Connect To Gateway Port") + + // ErrWrongGateway represents an error condition when a server receives a connect + // request from a remote Gateway with a destination name that does not match the server's + // Gateway's name. + ErrWrongGateway = errors.New("Wrong Gateway") ) // configErr is a configuration error. diff --git a/server/gateway.go b/server/gateway.go new file mode 100644 index 00000000..f29d7dd1 --- /dev/null +++ b/server/gateway.go @@ -0,0 +1,1707 @@ +// Copyright 2018 The NATS Authors +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+ +package server + +import ( + "bytes" + "crypto/tls" + "encoding/json" + "fmt" + "math/rand" + "net" + "net/url" + "sort" + "strconv" + "sync" + "sync/atomic" + "time" +) + +const ( + defaultSolicitGatewaysDelay = time.Second + defaultGatewayConnectDelay = time.Second + defaultGatewayReconnectDelay = time.Second +) + +var ( + gatewayConnectDelay = defaultGatewayConnectDelay + gatewayReconnectDelay = defaultGatewayReconnectDelay +) + +const ( + gatewayCmdGossip byte = 1 +) + +type srvGateway struct { + sync.RWMutex + enabled bool // Immutable, true if both a name and port are configured + name string // Name of the Gateway on this server + out map[string]*client // outbound gateways + in map[uint64]*client // inbound gateways + remotes map[string]*gatewayCfg // Config of remote gateways + URLs map[string]struct{} // Set of all Gateway URLs in the cluster + URL string // This server gateway URL (after possible random port is resolved) + info *Info // Gateway Info protocol + infoJSON []byte // Marshal'ed Info protocol + defPerms *GatewayPermissions // Default permissions (when accepting an unknown remote gateway) + rqs map[string]*subscription // Map of remote queue subscriptions (key is account+subject+queue) + runknown bool // Rejects unknown (not configured) gateway connections + resolver netResolver // Used to resolve host name before calling net.Dial() +} + +type gatewayCfg struct { + sync.RWMutex + *RemoteGatewayOpts + urls map[string]*url.URL + connAttempts int + implicit bool +} + +// Struct for client's gateway related fields +type gateway struct { + name string + outbound bool + cfg *gatewayCfg + // Saved in createGateway() since we delay send of CONNECT and INFO + // for outbound connections + connectURL *url.URL + infoJSON []byte + // As an outbound connection, this indicates if there is no interest + // for an account/subject. + // As an inbound connection, this indicates if we have sent a no-interest + // protocol to the remote gateway. 
+ noInterest sync.Map + // Queue subscriptions interest for this gateway. + qsubsInterest sync.Map +} + +// clone returns a deep copy of the RemoteGatewayOpts object +func (r *RemoteGatewayOpts) clone() *RemoteGatewayOpts { + if r == nil { + return nil + } + clone := &RemoteGatewayOpts{ + Name: r.Name, + URLs: deepCopyURLs(r.URLs), + } + if r.TLSConfig != nil { + clone.TLSConfig = r.TLSConfig.Clone() + clone.TLSTimeout = r.TLSTimeout + } + return clone +} + +// Ensure that gateway is properly configured. +func validateGatewayOptions(o *Options) error { + if o.Gateway.Name == "" && o.Gateway.Port == 0 { + return nil + } + if o.Gateway.Name == "" { + return fmt.Errorf("gateway has no name") + } + if o.Gateway.Port == 0 { + return fmt.Errorf("gateway %q has no port specified (select -1 for random port)", o.Gateway.Name) + } + for i, g := range o.Gateway.Gateways { + if g.Name == "" { + return fmt.Errorf("gateway in the list %d has no name", i) + } + if len(g.URLs) == 0 { + return fmt.Errorf("gateway %q has no URL", g.Name) + } + } + return nil +} + +// Initialize the s.gateway structure. We do this even if the server +// does not have a gateway configured. In some part of the code, the +// server will check the number of outbound gateways, etc.. and so +// we don't have to check if s.gateway is nil or not. 
+func newGateway(opts *Options) (*srvGateway, error) { + gateway := &srvGateway{ + name: opts.Gateway.Name, + out: make(map[string]*client), + in: make(map[uint64]*client), + remotes: make(map[string]*gatewayCfg), + URLs: make(map[string]struct{}), + rqs: make(map[string]*subscription), + resolver: opts.Gateway.resolver, + runknown: opts.Gateway.RejectUnknown, + } + gateway.Lock() + defer gateway.Unlock() + + if gateway.resolver == nil { + gateway.resolver = netResolver(net.DefaultResolver) + } + + // Copy default permissions (works if DefaultPermissions is nil) + gateway.defPerms = opts.Gateway.DefaultPermissions.clone() + + // Create remote gateways + for _, rgo := range opts.Gateway.Gateways { + cfg := &gatewayCfg{ + RemoteGatewayOpts: rgo.clone(), + urls: make(map[string]*url.URL, len(rgo.URLs)), + } + if opts.Gateway.TLSConfig != nil && cfg.TLSConfig == nil { + cfg.TLSConfig = opts.Gateway.TLSConfig.Clone() + } + if cfg.TLSTimeout == 0 { + cfg.TLSTimeout = opts.Gateway.TLSTimeout + } + for _, u := range rgo.URLs { + cfg.urls[u.Host] = u + } + gateway.remotes[cfg.Name] = cfg + } + + gateway.enabled = opts.Gateway.Name != "" && opts.Gateway.Port != 0 + return gateway, nil +} + +// Returns the Gateway's name of this server. +func (g *srvGateway) getName() string { + g.RLock() + n := g.name + g.RUnlock() + return n +} + +// Returns the Gateway URLs of all servers in the local cluster. +// This is used to send to other cluster this server connects to. +// The gateway read-lock is held on entry +func (g *srvGateway) getURLs() []string { + a := make([]string, 0, len(g.URLs)) + for u := range g.URLs { + a = append(a, u) + } + return a +} + +// Returns if this server rejects connections from gateways that are not +// explicitly configured. +func (g *srvGateway) rejectUnknown() bool { + g.RLock() + reject := g.runknown + g.RUnlock() + return reject +} + +// Starts the gateways accept loop and solicit explicit gateways +// after an initial delay. 
This delay is meant to give a chance to +// the cluster to form and this server gathers gateway URLs for this +// cluster in order to send that as part of the connect/info process. +func (s *Server) startGateways() { + // Spin up the accept loop + ch := make(chan struct{}) + go s.gatewayAcceptLoop(ch) + <-ch + + // Delay start of creation of gateways to give a chance + // to the local cluster to form. + s.startGoRoutine(func() { + defer s.grWG.Done() + + dur := s.getOpts().gatewaysSolicitDelay + if dur == 0 { + dur = defaultSolicitGatewaysDelay + } + + select { + case <-time.After(dur): + s.solicitGateways() + case <-s.quitCh: + return + } + }) +} + +// This is the gateways accept loop. This runs as a go-routine. +// The listen specification is resolved (if use of random port), +// then a listener is started. After that, this routine enters +// a loop (until the server is shutdown) accepting incoming +// gateway connections. +func (s *Server) gatewayAcceptLoop(ch chan struct{}) { + defer func() { + if ch != nil { + close(ch) + } + }() + + // Snapshot server options. + opts := s.getOpts() + + port := opts.Gateway.Port + if port == -1 { + port = 0 + } + + hp := net.JoinHostPort(opts.Gateway.Host, strconv.Itoa(port)) + l, e := net.Listen("tcp", hp) + if e != nil { + s.Fatalf("Error listening on gateway port: %d - %v", opts.Gateway.Port, e) + return + } + s.Noticef("Listening for gateways connections on %s", + net.JoinHostPort(opts.Gateway.Host, strconv.Itoa(l.Addr().(*net.TCPAddr).Port))) + + s.mu.Lock() + tlsReq := opts.Gateway.TLSConfig != nil + authRequired := opts.Gateway.Username != "" + info := &Info{ + ID: s.info.ID, + Version: s.info.Version, + AuthRequired: authRequired, + TLSRequired: tlsReq, + TLSVerify: tlsReq, + MaxPayload: s.info.MaxPayload, + Gateway: opts.Gateway.Name, + } + // If we have selected a random port... + if port == 0 { + // Write resolved port back to options. 
+ opts.Gateway.Port = l.Addr().(*net.TCPAddr).Port + } + // Keep track of actual listen port. This will be needed in case of + // config reload. + s.gatewayActualPort = opts.Gateway.Port + // Possibly override Host/Port based on Gateway.Advertise + if err := s.setGatewayInfoHostPort(info, opts); err != nil { + s.Fatalf("Error setting gateway INFO with Gateway.Advertise value of %s, err=%v", opts.Gateway.Advertise, err) + l.Close() + s.mu.Unlock() + return + } + // Setup state that can enable shutdown + s.gatewayListener = l + s.mu.Unlock() + + // Let them know we are up + close(ch) + ch = nil + + tmpDelay := ACCEPT_MIN_SLEEP + + for s.isRunning() { + conn, err := l.Accept() + if err != nil { + tmpDelay = s.acceptError("Gateway", err, tmpDelay) + continue + } + tmpDelay = ACCEPT_MIN_SLEEP + s.startGoRoutine(func() { + s.createGateway(nil, nil, conn) + s.grWG.Done() + }) + } + s.Debugf("Gateway accept loop exiting..") + s.done <- true +} + +// Similar to setInfoHostPortAndGenerateJSON, but for gatewayInfo. +func (s *Server) setGatewayInfoHostPort(info *Info, o *Options) error { + if o.Gateway.Advertise != "" { + advHost, advPort, err := parseHostPort(o.Gateway.Advertise, o.Gateway.Port) + if err != nil { + return err + } + info.Host = advHost + info.Port = advPort + } else { + info.Host = o.Gateway.Host + info.Port = o.Gateway.Port + } + gw := s.gateway + gw.Lock() + delete(gw.URLs, gw.URL) + gw.URL = net.JoinHostPort(info.Host, strconv.Itoa(info.Port)) + gw.URLs[gw.URL] = struct{}{} + gw.info = info + info.GatewayURL = gw.URL + // (re)generate the gatewayInfoJSON byte array + gw.generateInfoJSON() + gw.Unlock() + return nil +} + +// Generates the Gateway INFO protocol. +// The gateway lock is held on entry +func (g *srvGateway) generateInfoJSON() { + b, err := json.Marshal(g.info) + if err != nil { + panic(err) + } + g.infoJSON = []byte(fmt.Sprintf(InfoProto, b)) +} + +// Goes through the list of registered gateways and try to connect to those. 
+// The list (remotes) initially contains the explicit remote gateways, +// but the list is augmented with any implicit (discovered) gateway. Therefore, +// this function only solicits explicit ones. +func (s *Server) solicitGateways() { + gw := s.gateway + gw.RLock() + defer gw.RUnlock() + for _, cfg := range gw.remotes { + // Since we delay the creation of gateways, it is + // possible that server starts to receive inbound from + // other clusters and in turn create outbounds. So here + // we create only the ones that are configured. + if !cfg.isImplicit() { + cfg := cfg // Create new instance for the goroutine. + s.startGoRoutine(func() { + s.solicitGateway(cfg) + s.grWG.Done() + }) + } + } +} + +// Reconnect to the gateway after a little wait period. For explicit +// gateways, we also wait for the default reconnect time. +func (s *Server) reconnectGateway(cfg *gatewayCfg) { + defer s.grWG.Done() + + delay := time.Duration(rand.Intn(100)) * time.Millisecond + if !cfg.isImplicit() { + delay += gatewayReconnectDelay + } + select { + case <-time.After(delay): + case <-s.quitCh: + return + } + s.solicitGateway(cfg) +} + +// This function will loop trying to connect to any URL attached +// to the given Gateway. It will return once a connection has been created. +func (s *Server) solicitGateway(cfg *gatewayCfg) { + var ( + opts = s.getOpts() + isImplicit = cfg.isImplicit() + urls = cfg.getURLs() + attempts int + ) + for s.isRunning() && len(urls) > 0 { + // Iteration is random + for _, u := range urls { + address, err := s.getRandomIP(s.gateway.resolver, u.Host) + if err != nil { + s.Errorf("Error getting IP for %s: %v", u.Host, err) + continue + } + s.Debugf("Trying to connect to gateway %q at %s", cfg.Name, address) + conn, err := net.DialTimeout("tcp", address, DEFAULT_ROUTE_DIAL) + if err == nil { + // We could connect, create the gateway connection and return. 
+ s.createGateway(cfg, u, conn) + return + } + s.Errorf("Error trying to connect to gateway: %v", err) + // Break this loop if server is being shutdown... + if !s.isRunning() { + break + } + } + if isImplicit { + attempts++ + if opts.Gateway.ConnectRetries == 0 || attempts > opts.Gateway.ConnectRetries { + s.gateway.Lock() + delete(s.gateway.remotes, cfg.Name) + s.gateway.Unlock() + return + } + } + select { + case <-s.quitCh: + return + case <-time.After(gatewayConnectDelay): + continue + } + } +} + +// Called when a gateway connection is either accepted or solicited. +// If accepted, the gateway is marked as inbound. +// If solicited, the gateway is marked as outbound. +func (s *Server) createGateway(cfg *gatewayCfg, url *url.URL, conn net.Conn) { + // Snapshot server options. + opts := s.getOpts() + tlsRequired := opts.Gateway.TLSConfig != nil + + c := &client{srv: s, nc: conn, typ: GATEWAY} + + // Are we creating the gateway based on the configuration + solicit := cfg != nil + + // Generate INFO to send + s.gateway.RLock() + // Make a copy + info := *s.gateway.info + info.GatewayURLs = s.gateway.getURLs() + s.gateway.RUnlock() + b, _ := json.Marshal(&info) + infoJSON := []byte(fmt.Sprintf(InfoProto, b)) + + // Perform some initialization under the client lock + c.mu.Lock() + c.initClient() + c.gw = &gateway{} + c.in.rcache = make(map[string]*routeCache, maxRouteCacheSize) + if solicit { + // This is an outbound gateway connection + c.gw.outbound = true + c.gw.name = cfg.Name + c.gw.cfg = cfg + cfg.bumpConnAttempts() + // Since we are delaying the connect until after receiving + // the remote's INFO protocol, save the URL we need to connect to. 
+ c.gw.connectURL = url + c.gw.infoJSON = infoJSON + c.Noticef("Creating outbound gateway connection to %q", cfg.Name) + } else { + // Inbound gateway connection + c.Noticef("Processing inbound gateway connection") + } + + // Check for TLS + if tlsRequired { + var timeout float64 + // If we solicited, we will act like the client, otherwise the server. + if solicit { + c.Debugf("Starting TLS gateway client handshake") + // Specify the ServerName we are expecting. + host, _, _ := net.SplitHostPort(url.Host) + tlsConfig := cfg.TLSConfig + tlsConfig.ServerName = host + c.nc = tls.Client(c.nc, tlsConfig) + timeout = cfg.TLSTimeout + } else { + c.Debugf("Starting TLS gateway server handshake") + c.nc = tls.Server(c.nc, opts.Gateway.TLSConfig) + timeout = opts.Gateway.TLSTimeout + } + + conn := c.nc.(*tls.Conn) + + // Setup the timeout + ttl := secondsToDuration(timeout) + time.AfterFunc(ttl, func() { tlsTimeout(c, conn) }) + conn.SetReadDeadline(time.Now().Add(ttl)) + + c.mu.Unlock() + if err := conn.Handshake(); err != nil { + c.Errorf("TLS gateway handshake error: %v", err) + c.sendErr("Secure Connection - TLS Required") + c.closeConnection(TLSHandshakeError) + return + } + // Reset the read deadline + conn.SetReadDeadline(time.Time{}) + + // Re-Grab lock + c.mu.Lock() + + // Verify that the connection did not go away while we released the lock. + if c.nc == nil { + c.mu.Unlock() + return + } + } + + // Do final client initialization + + // Register in temp map for now until gateway properly registered + // in out or in gateways. + if !s.addToTempClients(c.cid, c) { + c.mu.Unlock() + c.closeConnection(ServerShutdown) + return + } + + // Only send if we accept a connection. Will send CONNECT+INFO as an + // outbound only after processing peer's INFO protocol. + if !solicit { + c.sendInfo(infoJSON) + } + + // Spin up the read loop. + s.startGoRoutine(c.readLoop) + + // Spin up the write loop. 
+ s.startGoRoutine(c.writeLoop) + + if tlsRequired { + c.Debugf("TLS handshake complete") + cs := c.nc.(*tls.Conn).ConnectionState() + c.Debugf("TLS version %s, cipher suite %s", tlsVersion(cs.Version), tlsCipher(cs.CipherSuite)) + } + + // Set the Ping timer after sending connect and info. + c.setPingTimer() + + c.mu.Unlock() +} + +// Builds and sends the CONNECT protocol for a gateway. +func (c *client) sendGatewayConnect() { + tlsRequired := c.srv.getOpts().Gateway.TLSConfig != nil + url := c.gw.connectURL + c.gw.connectURL = nil + var user, pass string + if userInfo := url.User; userInfo != nil { + user = userInfo.Username() + pass, _ = userInfo.Password() + } + cinfo := connectInfo{ + Verbose: false, + Pedantic: false, + User: user, + Pass: pass, + TLS: tlsRequired, + Name: c.srv.info.ID, + Gateway: c.srv.getGatewayName(), + } + b, err := json.Marshal(cinfo) + if err != nil { + panic(err) + } + c.sendProto([]byte(fmt.Sprintf(ConProto, b)), true) +} + +// Process the CONNECT protocol from a gateway connection. +// Returns an error to the connection if the CONNECT is not from a gateway +// (for instance a client or route connecting to the gateway port), or +// if the destination does not match the gateway name of this server. +func (c *client) processGatewayConnect(arg []byte) error { + connect := &connectInfo{} + if err := json.Unmarshal(arg, connect); err != nil { + return err + } + + // Coming from a client or a route, reject + if connect.Gateway == "" { + errTxt := ErrClientOrRouteConnectedToGatewayPort.Error() + c.Errorf(errTxt) + c.sendErr(errTxt) + c.closeConnection(WrongPort) + return ErrClientOrRouteConnectedToGatewayPort + } + + c.mu.Lock() + s := c.srv + c.mu.Unlock() + + // If we reject unknown gateways, make sure we have it configured, + // otherwise return an error. 
+ if s.gateway.rejectUnknown() && s.getRemoteGateway(connect.Gateway) == nil { + c.Errorf("Rejecting connection from gateway %q", connect.Gateway) + c.sendErr(fmt.Sprintf("Connection to gateway %q rejected", s.getGatewayName())) + c.closeConnection(WrongGateway) + return ErrWrongGateway + } + + return nil +} + +// Process the INFO protocol from a gateway connection. +// +// If the gateway connection is an outbound (this server initiated the connection), +// this function checks that the incoming INFO contains the Gateway field. If empty, +// it means that this is a response from an older server or that this server connected +// to the wrong port. +// The outbound gateway may also receive a gossip INFO protocol from the remote gateway, +// indicating other gateways that the remote knows about. This server will try to connect +// to those gateways (if not explicitly configured or already implicitly connected). +// In both cases (explicit or implicit), the local cluster is notified about the existence +// of this new gateway. This allows servers in the cluster to ensure that they have an +// outbound connection to this gateway. +// +// For an inbound gateway, the gateway is simply registered and the info protocol +// is saved to be used after processing the CONNECT. +func (c *client) processGatewayInfo(info *Info) { + var ( + gwName string + cfg *gatewayCfg + ) + c.mu.Lock() + s := c.srv + cid := c.cid + + // Check if this is the first INFO. (this call sets the flag if not already set). + isFirstINFO := c.flags.setIfNotSet(infoReceived) + + isOutbound := c.gw.outbound + if isOutbound { + gwName = c.gw.name + cfg = c.gw.cfg + } else if isFirstINFO { + c.gw.name = info.Gateway + } + c.opts.Name = info.ID + c.mu.Unlock() + + // For an outbound connection... + if isOutbound { + // Check content of INFO for fields indicating that it comes from a gateway. + // If we incorrectly connect to the wrong port (client or route), we won't + // have the Gateway field set. 
+ if info.Gateway == "" { + errTxt := fmt.Sprintf("Attempt to connect to gateway %q using wrong port", gwName) + s.Errorf(errTxt) + c.sendErr(errTxt) + c.closeConnection(WrongPort) + return + } + // Check that the gateway name we got is what we expect + if info.Gateway != gwName { + // Unless this is the very first INFO, it may be ok if this is + // a gossip request to connect to other gateways. + if !isFirstINFO && info.GatewayCmd == gatewayCmdGossip { + // If we are configured to reject unknown, do not attempt to + // connect to one that we don't have configured. + if s.gateway.rejectUnknown() && s.getRemoteGateway(info.Gateway) == nil { + return + } + s.processImplicitGateway(info) + return + } + // Otherwise, this is a failure... + // We are reporting this error in the log... + c.Errorf("Failing connection to gateway %q, remote gateway name is %q", + gwName, info.Gateway) + // ...and sending this back to the remote so that the error + // makes more sense in the remote server's log. + c.sendErr(fmt.Sprintf("Connection from %q rejected, wanted to connect to %q, this is %q", + s.getGatewayName(), gwName, info.Gateway)) + c.closeConnection(WrongGateway) + return + } + + // Possibly add URLs that we get from the INFO protocol. + cfg.updateURLs(info.GatewayURLs) + + // If this is the first INFO, send our connect + if isFirstINFO { + // Note, if we want to support NKeys, then we would get the nonce + // from this INFO protocol and can sign it in the CONNECT we are + // going to send now. + c.mu.Lock() + c.sendGatewayConnect() + c.Debugf("Gateway connect protocol sent to %q", gwName) + // Send INFO too + c.sendInfo(c.gw.infoJSON) + c.gw.infoJSON = nil + c.mu.Unlock() + + // Register as an outbound gateway.. if we had a protocol to ack our connect, + // then we should do that when process that ack. 
+ s.registerOutboundGateway(gwName, c) + c.Noticef("Outbound gateway connection to %q (%s) registered", gwName, info.ID) + // Now that the outbound gateway is registered, we can remove from temp map. + s.removeFromTempClients(cid) + } + + // Flood local cluster with information about this gateway. + // Servers in this cluster will ensure that they have (or otherwise create) + // an outbound connection to this gateway. + s.forwardNewGatewayToLocalCluster(info) + + } else if isFirstINFO { + // This is the first INFO of an inbound connection... + + s.registerInboundGateway(cid, c) + c.Noticef("Inbound gateway connection from %q (%s) registered", info.Gateway, info.ID) + + // Now that it is registered, we can remove from temp map. + s.removeFromTempClients(cid) + + // Send our QSubs + s.sendQueueSubsToGateway(c) + + // Initiate outbound connection. This function will behave correctly if + // we have already one. + s.processImplicitGateway(info) + + // Send back to the server that initiated this gateway connection the + // list of all remote gateways known on this server. + s.gossipGatewaysToInboundGateway(info.Gateway, c) + } +} + +// Sends to the given inbound gateway connection a gossip INFO protocol +// for each gateway known by this server. This allows for a "full mesh" +// of gateways. +func (s *Server) gossipGatewaysToInboundGateway(gwName string, c *client) { + gw := s.gateway + gw.RLock() + defer gw.RUnlock() + for gwCfgName, cfg := range gw.remotes { + // Skip the gateway that we just created + if gwCfgName == gwName { + continue + } + info := Info{ + ID: s.info.ID, + GatewayCmd: gatewayCmdGossip, + } + urls := cfg.getURLsAsStrings() + if len(urls) > 0 { + info.Gateway = gwCfgName + info.GatewayURLs = urls + b, _ := json.Marshal(&info) + c.mu.Lock() + c.sendProto([]byte(fmt.Sprintf(InfoProto, b)), true) + c.mu.Unlock() + } + } +} + +// Sends the INFO protocol of a gateway to all routes known by this server. 
+func (s *Server) forwardNewGatewayToLocalCluster(oinfo *Info) { + // Need tp protect s.routes here, so use server's lock + s.mu.Lock() + defer s.mu.Unlock() + + // We don't really need the ID to be set, but, we need to make sure + // that it is not set to the server ID so that if we were to connect + // to an older server that does not expect a "gateway" INFO, it + // would think that it needs to create an implicit route (since info.ID + // would not match the route's remoteID), but will fail to do so because + // the sent protocol will not have host/port defined. + info := &Info{ + ID: "GW" + s.info.ID, + Gateway: oinfo.Gateway, + GatewayURLs: oinfo.GatewayURLs, + GatewayCmd: gatewayCmdGossip, + } + b, _ := json.Marshal(info) + infoJSON := []byte(fmt.Sprintf(InfoProto, b)) + + for _, r := range s.routes { + r.mu.Lock() + r.sendInfo(infoJSON) + r.mu.Unlock() + } +} + +// Sends queue subscriptions interest to remote gateway. +func (s *Server) sendQueueSubsToGateway(c *client) { + var ( + accsa = [256]*Account{} + accs = accsa[:0] + rqsa = [4096]*subscription{} + rqs = rqsa[:0] + bufa = [32 * 1024]byte{} + buf = bufa[:0] + ) + // Collect accounts + s.mu.Lock() + for _, acc := range s.accounts { + accs = append(accs, acc) + } + s.mu.Unlock() + + // Collect remote queue subs + s.gateway.RLock() + for _, sub := range s.gateway.rqs { + rqs = append(rqs, sub) + } + s.gateway.RUnlock() + + //TODO: Buffer may get too big so should send it by chunks... + + // Build proto for local subs + for _, acc := range accs { + acc.mu.RLock() + for subAndQueue, rm := range acc.rm { + if rm.qi > 0 { + buf = append(buf, rSubBytes...) + buf = append(buf, acc.Name...) + buf = append(buf, ' ') + buf = append(buf, subAndQueue...) + buf = append(buf, ' ', '1') + buf = append(buf, CR_LF...) + } + } + acc.mu.RUnlock() + } + // Now with the remote queue subs + for _, qsub := range rqs { + buf = append(buf, rSubBytes...) 
+ // The remote queue sub sid is account+' '+subject+' '+queue name + buf = append(buf, qsub.sid...) + buf = append(buf, ' ', '1') + buf = append(buf, CR_LF...) + } + + // Send + if len(buf) > 0 { + c.mu.Lock() + c.queueOutbound(buf) + c.flushSignal() + closed := c.flags.isSet(clearConnection) + if !closed { + c.Debugf("Sent queue subscriptions to gateway") + } + c.mu.Unlock() + } +} + +// This is invoked when getting an INFO protocol for gateway on the ROUTER port. +// This function will then execute appropriate function based on the command +// contained in the protocol. +func (s *Server) processGatewayInfoFromRoute(info *Info, routeSrvID string, route *client) { + switch info.GatewayCmd { + case gatewayCmdGossip: + s.processImplicitGateway(info) + default: + s.Errorf("Unknown command %d from server %v", info.GatewayCmd, routeSrvID) + } +} + +// Sends INFO protocols to the given route connection for each known Gateway. +// These will be processed by the route and delegated to the gateway code to +// imvoke processImplicitGateway. +func (s *Server) sendGatewayConfigsToRoute(route *client) { + gw := s.gateway + gw.RLock() + defer gw.RUnlock() + // Send only to gateways for which we have actual outbound connection to. + if len(gw.out) == 0 { + return + } + // Check forwardNewGatewayToLocalCluster() as to why we set ID this way. + info := Info{ + ID: "GW" + s.info.ID, + GatewayCmd: gatewayCmdGossip, + } + for gwCfgName, cfg := range gw.remotes { + if _, exist := gw.out[gwCfgName]; !exist { + continue + } + urls := cfg.getURLsAsStrings() + if len(urls) > 0 { + info.Gateway = gwCfgName + info.GatewayURLs = urls + b, _ := json.Marshal(&info) + route.mu.Lock() + route.sendProto([]byte(fmt.Sprintf(InfoProto, b)), true) + route.mu.Unlock() + } + } +} + +// Initiates a gateway connection using the info contained in the INFO protocol. 
+// If a gateway with the same name is already registered (either because explicitly +// configured, or already implicitly connected), this function will augmment the +// remote URLs with URLs present in the info protocol and return. +// Otherwise, this function will register this remote (to prevent multiple connections +// to the same remote) and call solicitGateway (which will run in a different go-routine). +func (s *Server) processImplicitGateway(info *Info) { + s.gateway.Lock() + defer s.gateway.Unlock() + // Name of the gateway to connect to is the Info.Gateway field. + gwName := info.Gateway + // Check if we already have this config, and if so, we are done + cfg := s.gateway.remotes[gwName] + if cfg != nil { + // However, possibly augment the list of URLs with the given + // info.GatewayURLs content. + cfg.Lock() + cfg.addURLs(info.GatewayURLs) + cfg.Unlock() + return + } + opts := s.getOpts() + cfg = &gatewayCfg{ + RemoteGatewayOpts: &RemoteGatewayOpts{Name: gwName}, + urls: make(map[string]*url.URL, len(info.GatewayURLs)), + implicit: true, + } + if opts.Gateway.TLSConfig != nil { + cfg.TLSConfig = opts.Gateway.TLSConfig.Clone() + cfg.TLSTimeout = opts.Gateway.TLSTimeout + } + + // Since we know we don't have URLs (no config, so just based on what we + // get from INFO), directly call addURLs(). We don't need locking since + // we just created that structure and no one else has access to it yet. + cfg.addURLs(info.GatewayURLs) + // If there is no URL, we can't proceed. 
+ if len(cfg.urls) == 0 { + return + } + s.gateway.remotes[gwName] = cfg + s.startGoRoutine(func() { + s.solicitGateway(cfg) + s.grWG.Done() + }) +} + +// Returns the number of outbound gateway connections +func (s *Server) numOutboundGateways() int { + s.gateway.RLock() + n := len(s.gateway.out) + s.gateway.RUnlock() + return n +} + +// Returns the number of inbound gateway connections +func (s *Server) numInboundGateways() int { + s.gateway.RLock() + n := len(s.gateway.in) + s.gateway.RUnlock() + return n +} + +// Returns the remoteGateway (if any) that has the given `name` +func (s *Server) getRemoteGateway(name string) *gatewayCfg { + s.gateway.RLock() + cfg := s.gateway.remotes[name] + s.gateway.RUnlock() + return cfg +} + +// Used in tests +func (g *gatewayCfg) bumpConnAttempts() { + g.Lock() + g.connAttempts++ + g.Unlock() +} + +// Used in tests +func (g *gatewayCfg) getConnAttempts() int { + g.Lock() + ca := g.connAttempts + g.Unlock() + return ca +} + +// Used in tests +func (g *gatewayCfg) resetConnAttempts() { + g.Lock() + g.connAttempts = 0 + g.Unlock() +} + +// Returns if this remote gateway is implicit or not. +func (g *gatewayCfg) isImplicit() bool { + g.RLock() + ii := g.implicit + g.RUnlock() + return ii +} + +// getURLs returns an array of URLs in random order suitable for +// an iteration to try to connect. +func (g *gatewayCfg) getURLs() []*url.URL { + g.RLock() + a := make([]*url.URL, 0, len(g.urls)) + for _, u := range g.urls { + a = append(a, u) + } + g.RUnlock() + return a +} + +// Similar to getURLs but returns the urls as an array of strings. +func (g *gatewayCfg) getURLsAsStrings() []string { + g.RLock() + a := make([]string, 0, len(g.urls)) + for _, u := range g.urls { + a = append(a, u.Host) + } + g.RUnlock() + return a +} + +// updateURLs creates the urls map with the content of the config's URLs array +// and the given array that we get from the INFO protocol. 
+func (g *gatewayCfg) updateURLs(infoURLs []string) { + g.Lock() + // Clear the map... + g.urls = make(map[string]*url.URL, len(g.URLs)+len(infoURLs)) + // Add the urls from the config URLs array. + for _, u := range g.URLs { + g.urls[u.Host] = u + } + // Then add the ones from the infoURLs array we got. + g.addURLs(infoURLs) + g.Unlock() +} + +// add URLs from the given array to the urls map only if not already present. +// remoteGateway write lock is assumed to be held on entry. +func (g *gatewayCfg) addURLs(infoURLs []string) { + for _, iu := range infoURLs { + if _, present := g.urls[iu]; !present { + // Urls in Info.GatewayURLs come without scheme. Add it to parse + // the url (otherwise it fails). + if u, err := url.Parse(fmt.Sprintf("nats://%s", iu)); err == nil { + // But use u.Host for the key. + g.urls[u.Host] = u + } + } + } +} + +// Adds this URL to the set of Gateway URLs +// Server lock held on entry +func (s *Server) addGatewayURL(urlStr string) { + s.gateway.Lock() + s.gateway.URLs[urlStr] = struct{}{} + s.gateway.Unlock() +} + +// Remove this URL from the set of gateway URLs +// Server lock held on entry +func (s *Server) removeGatewayURL(urlStr string) { + s.gateway.Lock() + delete(s.gateway.URLs, urlStr) + s.gateway.Unlock() +} + +// This returns the URL of the Gateway listen spec, or empty string +// if the server has no gateway configured. +func (s *Server) getGatewayURL() string { + s.gateway.RLock() + url := s.gateway.URL + s.gateway.RUnlock() + return url +} + +// Returns this server gateway name. +// Same than calling s.gateway.getName() +func (s *Server) getGatewayName() string { + return s.gateway.getName() +} + +// All gateway connections (outbound and inbound) are put in the given map. 
+func (s *Server) getAllGatewayConnections(conns map[uint64]*client) { + gw := s.gateway + gw.RLock() + for _, c := range gw.out { + c.mu.Lock() + cid := c.cid + c.mu.Unlock() + conns[cid] = c + } + for cid, c := range gw.in { + conns[cid] = c + } + gw.RUnlock() +} + +// Register the given gateway connection (*client) in the inbound gateways +// map with the given name as the key. +func (s *Server) registerInboundGateway(cid uint64, gwc *client) { + s.gateway.Lock() + s.gateway.in[cid] = gwc + s.gateway.Unlock() +} + +// Register the given gateway connection (*client) in the outbound gateways +// map with the given name as the key. +func (s *Server) registerOutboundGateway(name string, gwc *client) { + s.gateway.Lock() + s.gateway.out[name] = gwc + s.gateway.Unlock() +} + +// Returns the outbound gateway connection (*client) with the given name, +// or nil if not found +func (s *Server) getOutboundGateway(name string) *client { + s.gateway.RLock() + gwc := s.gateway.out[name] + s.gateway.RUnlock() + return gwc +} + +// Returns all outbound gateway connections in the provided array +func (s *Server) getOutboundGateways(a *[]*client) { + s.gateway.RLock() + for _, gwc := range s.gateway.out { + *a = append(*a, gwc) + } + s.gateway.RUnlock() +} + +// Returns all inbound gateway connections in the provided array +func (s *Server) getInboundGateways(a *[]*client) { + s.gateway.RLock() + for _, gwc := range s.gateway.in { + *a = append(*a, gwc) + } + s.gateway.RUnlock() +} + +// This is invoked when a gateway connection is closed and the server +// is removing this connection from its state. +func (s *Server) removeRemoteGateway(c *client) { + c.mu.Lock() + cid := c.cid + isOutbound := c.gw.outbound + gwName := c.gw.name + c.mu.Unlock() + + gw := s.gateway + gw.Lock() + if isOutbound { + delete(gw.out, gwName) + } else { + delete(gw.in, cid) + } + gw.Unlock() + s.removeFromTempClients(cid) +} + +// GatewayAddr returns the net.Addr object for the gateway listener. 
+func (s *Server) GatewayAddr() *net.TCPAddr { + s.mu.Lock() + defer s.mu.Unlock() + if s.gatewayListener == nil { + return nil + } + return s.gatewayListener.Addr().(*net.TCPAddr) +} + +// A- protocol received when sending messages on an account that +// the remote gateway has no interest in. Mark this account +// with a "no interest" marker to prevent further messages send. +func (c *client) processGatewayAccountUnsub(accName string) { + // Just to indicate activity around "subscriptions" events. + c.in.subs++ + c.gw.noInterest.Store(accName, nil) +} + +// A+ protocol received by remote gateway if it had previously +// sent an A-. Clear the "no interest" marker for this account. +func (c *client) processGatewayAccountSub(accName string) error { + // Just to indicate activity around "subscriptions" events. + c.in.subs++ + c.gw.noInterest.Delete(accName) + return nil +} + +// RS- protocol received when sending messages on an subject +// that the remote gateway has no interest in (but knows about the account). +// Mark this subject with a "no interest" marker to prevent further +// messages send. +func (c *client) processGatewaySubjectUnsub(arg []byte) error { + accName, subject, queue, err := c.parseUnsubProto(arg) + if err != nil { + return fmt.Errorf("processGatewaySubjectUnsub %s", err.Error()) + } + // Queue subs are treated differently. + if queue != nil { + sli, _ := c.gw.qsubsInterest.Load(accName) + if sli != nil { + sl := sli.(*Sublist) + r := sl.Match(string(subject)) + if len(r.qsubs) > 0 { + for i := 0; i < len(r.qsubs); i++ { + qsubs := r.qsubs[i] + if bytes.Equal(qsubs[0].queue, queue) { + sl.Remove(qsubs[0]) + if sl.Count() == 0 { + c.gw.qsubsInterest.Delete(accName) + } + } + } + } + } + } else { + // For a given gateway, we will receive the A+,A-,RS+ and RS- + // from the same tcp connection/go routine. 
So although there + // may be different go-routines doing lookups on this map, + // we are guaranteed that there is no Store/Delete happening + // in parallel. + sni, noInterest := c.gw.noInterest.Load(accName) + // Do we have a no-interest at the account level? If so, + // don't bother setting no-interest on that subject. + if noInterest && sni == nil { + return nil + } + var ( + subjNoInterest *sync.Map + store bool + ) + if sni == nil { + subjNoInterest = &sync.Map{} + store = true + } else { + subjNoInterest = sni.(*sync.Map) + } + subjNoInterest.Store(string(subject), struct{}{}) + if store { + c.gw.noInterest.Store(accName, subjNoInterest) + } + } + return nil +} + +// For plain subs, RS+ protocol received by remote gateway if it +// had previously sent a RS-. Clear the "no interest" marker for +// this subject (under this account). +// For queue subs, register interest from remote gateway. +func (c *client) processGatewaySubjectSub(arg []byte) error { + c.traceInOp("RS+", arg) + + // Indicate activity. + c.in.subs++ + + var ( + queue []byte + qw int32 + ) + + args := splitArg(arg) + switch len(args) { + case 2: + case 4: + queue = args[2] + qw = int32(parseSize(args[3])) + default: + return fmt.Errorf("processGatewaySubjectSub Parse Error: '%s'", arg) + } + accName := args[0] + subject := args[1] + + if queue != nil { + var ( + sl *Sublist + store bool + ) + sli, _ := c.gw.qsubsInterest.Load(string(accName)) + if sli != nil { + sl = sli.(*Sublist) + } else { + sl = NewSublist() + store = true + } + // Copy subject and queue to avoid referencing a possibly + // big underlying buffer. + cbuf := make([]byte, len(subject)+len(queue)) + copy(cbuf, subject) + copy(cbuf[len(subject):], queue) + sub := &subscription{client: c, subject: cbuf[:len(subject)], queue: cbuf[len(subject):], qw: qw} + sl.Insert(sub) + if store { + c.gw.qsubsInterest.Store(string(accName), sl) + } + } else { + // See remark from processGatewaySubjectUnsub(). 
+ sni, _ := c.gw.noInterest.Load(string(accName)) + // If this account is not even present, or if there is no + // specific subject with no-interest, we are done. + if sni == nil { + return nil + } + subjNoInterest := sni.(*sync.Map) + subjNoInterest.Delete(string(subject)) + } + return nil +} + +// Returns true if the remote gateway has no known no-interest +// on the account and/or subject. Returns false if we know that +// we should not be sending messages on that account/subject. +func (c *client) hasInterest(acc, subj string) bool { + // If there is no value for the `acc` key, then it means that + // there is interest (or no known no-interest). + sni, accountNoInterest := c.gw.noInterest.Load(acc) + if !accountNoInterest { + return true + } + // If what is stored is nil, then it means that there is + // no interest at the account level, so we return false. + if sni == nil { + return false + } + // If there is a value, sni is itself a sync.Map for subject + // no-interest (in that account). + subjectNoInterest := sni.(*sync.Map) + if _, noInterest := subjectNoInterest.Load(subj); noInterest { + return false + } + // There is interest + return true +} + +// This is invoked when an account is registered. We check if we did send +// to remote gateways a "no interest" in the past when receiving messages. +// If we did, we send to the remote gateway an A+ protocol (see +// processGatewayAccountSub()). +func (s *Server) endAccountNoInterestForGateways(accName string) { + var ( + _gws [4]*client + gws = _gws[:0] + ) + s.getInboundGateways(&gws) + if len(gws) == 0 { + return + } + var ( + protoa = [256]byte{} + proto = protoa[:0] + ) + proto = append(proto, aSubBytes...) + proto = append(proto, accName...) + proto = append(proto, CR_LF...) 
+ for _, c := range gws { + if _, noInterest := c.gw.noInterest.Load(accName); noInterest { + c.gw.noInterest.Delete(accName) + c.mu.Lock() + if c.trace { + c.traceOutOp("", proto[:len(proto)-LEN_CR_LF]) + } + c.sendProto(proto, true) + c.mu.Unlock() + } + } +} + +// This is invoked when a subscription is registered. We check if we did +// send to remote gateways a "no interest" in the past when receiving messages. +// If we did, we send the remote gateway a RS+ protocol. +func (s *Server) sendSubInterestToGateways(accName string, sub *subscription) { + // If this is a remote queue sub, we need to keep track of it now. + if sub.queue != nil && sub.client != nil && sub.client.typ == ROUTER { + s.gateway.Lock() + s.gateway.rqs[string(sub.sid)] = sub + s.gateway.Unlock() + } + var ( + _gws [4]*client + gws = _gws[:0] + ) + s.getInboundGateways(&gws) + if len(gws) == 0 { + return + } + var ( + protoa = [512]byte{} + proto = protoa[:0] + ) + for _, c := range gws { + // Always send for queues. Note that by contract this function + // is called only once per account/subject/queue. + sendProto := sub.queue != nil + // We still want to execute following code even if this is a + // queue and we are going to send the proto anyway, because + // it may clear the noInterest map if we are going from + // a no-interest to a queue sub. + if sni, _ := c.gw.noInterest.Load(accName); sni != nil { + sn := sni.(*sync.Map) + if _, noInterest := sn.Load(string(sub.subject)); noInterest { + sn.Delete(string(sub.subject)) + sendProto = true + } + } + if sendProto { + if len(proto) == 0 { + proto = append(proto, rSubBytes...) + proto = append(proto, accName...) + proto = append(proto, ' ') + proto = append(proto, sub.subject...) + if sub.queue != nil { + proto = append(proto, ' ') + proto = append(proto, sub.queue...) + // For now, just use 1 for the weight + proto = append(proto, ' ', '1') + } + proto = append(proto, CR_LF...) 
+ } + c.mu.Lock() + if c.trace { + c.traceOutOp("", proto[:len(proto)-LEN_CR_LF]) + } + c.sendProto(proto, true) + c.mu.Unlock() + } + } +} + +func (s *Server) sendQueueUnsubToGateways(accName string, qsub *subscription) { + // If this is a remote queue sub, need to remove from remote queue subs map + if qsub.client != nil && qsub.client.typ == ROUTER { + s.gateway.Lock() + delete(s.gateway.rqs, string(qsub.sid)) + s.gateway.Unlock() + } + var ( + _gws [4]*client + gws = _gws[:0] + ) + s.getInboundGateways(&gws) + if len(gws) == 0 { + return + } + var ( + protoa = [512]byte{} + proto = protoa[:0] + ) + proto = append(proto, rUnsubBytes...) + proto = append(proto, accName...) + proto = append(proto, ' ') + proto = append(proto, qsub.subject...) + proto = append(proto, ' ') + proto = append(proto, qsub.queue...) + proto = append(proto, CR_LF...) + for _, c := range gws { + c.mu.Lock() + if c.trace { + c.traceOutOp("", proto[:len(proto)-LEN_CR_LF]) + } + c.sendProto(proto, true) + c.mu.Unlock() + } +} + +// May send a message to all outbound gateways. It is possible +// that message is not sent to a given gateway if for instance +// it is known that this gateway has no interest in account or +// subject, etc.. +func (c *client) sendMsgToGateways(msg, subject, reply []byte, doQueues bool) { + var ( + gwsa [4]*client + gws = gwsa[:0] + s = c.srv + ) + s.getOutboundGateways(&gws) + if len(gws) == 0 { + return + } + var ( + subj = string(subject) + queuesa = [512]byte{} + queues = queuesa[:0] + groups = map[string]struct{}{} + accName = c.acc.Name + ) + if doQueues { + // Order the gws by lowest RTT + sort.Slice(gws, func(i, j int) bool { + return gws[i].getRTTValue() < gws[j].getRTTValue() + }) + } + for _, gwc := range gws { + // hasInterest is not covering queue subs + psubInterest := gwc.hasInterest(accName, subj) + if !psubInterest && !doQueues { + continue + } + mh := c.msgb[:msgHeadProtoLen] + mh = append(mh, accName...) 
+ mh = append(mh, ' ') + mh = append(mh, subject...) + mh = append(mh, ' ') + + hasQueues := false + if doQueues { + queues = queuesa[:0] + sli, _ := gwc.gw.qsubsInterest.Load(accName) + if sli != nil { + sl := sli.(*Sublist) + r := sl.Match(subj) + if len(r.qsubs) > 0 { + for i := 0; i < len(r.qsubs); i++ { + qsubs := r.qsubs[i] + if len(qsubs) > 0 { + queue := qsubs[0].queue + if _, exists := groups[string(queue)]; !exists { + groups[string(queue)] = struct{}{} + queues = append(queues, queue...) + queues = append(queues, ' ') + hasQueues = true + } + } + } + } + } + if hasQueues { + if reply != nil { + mh = append(mh, "+ "...) // Signal that there is a reply. + mh = append(mh, reply...) + mh = append(mh, ' ') + } else { + mh = append(mh, "| "...) // Only queues + } + mh = append(mh, queues...) + } else if !psubInterest { + continue + } + } + if !hasQueues && reply != nil { + mh = append(mh, reply...) + mh = append(mh, ' ') + } + mh = append(mh, c.pa.szb...) + mh = append(mh, CR_LF...) + sub := subscription{client: gwc} + c.deliverMsg(&sub, mh, msg) + } +} + +func (c *client) processInboundGatewayMsg(msg []byte) { + // Update statistics + c.in.msgs++ + // The msg includes the CR_LF, so pull back out for accounting. + c.in.bytes += len(msg) - LEN_CR_LF + + if c.trace { + c.traceMsg(msg) + } + + if c.opts.Verbose { + c.sendOK() + } + + // Mostly under testing scenarios. + if c.srv == nil { + return + } + + var ( + acc *Account + rc *routeCache + r *SublistResult + ok bool + ) + + // Check our cache first. + if rc, ok = c.in.rcache[string(c.pa.rcache)]; ok { + // Check the genid to see if it's still valid. + if genid := atomic.LoadUint64(&rc.acc.sl.genid); genid != rc.genid { + ok = false + delete(c.in.rcache, string(c.pa.rcache)) + } else { + acc = rc.acc + r = rc.results + } + } + + if !ok { + // Match correct account and sublist. 
+ acc = c.srv.LookupAccount(string(c.pa.account)) + if acc == nil { + c.Debugf("Unknown account %q for routed message on subject: %q", c.pa.account, c.pa.subject) + // Send A- only once... + if _, sent := c.gw.noInterest.Load(string(c.pa.account)); !sent { + // Send an A- protocol, but keep track that we have sent a "no interest" + // for that account, so that if later this account gets registered, + // we need to send an A+ to this remote gateway. + c.gw.noInterest.Store(string(c.pa.account), nil) + var ( + protoa = [256]byte{} + proto = protoa[:0] + ) + proto = append(proto, aUnsubBytes...) + proto = append(proto, c.pa.account...) + if c.trace { + c.traceOutOp("", proto) + } + proto = append(proto, CR_LF...) + c.mu.Lock() + c.sendProto(proto, true) + c.mu.Unlock() + } + return + } + + // Match against the account sublist. + r = acc.sl.Match(string(c.pa.subject)) + + // Store in our cache + c.in.rcache[string(c.pa.rcache)] = &routeCache{acc, r, atomic.LoadUint64(&acc.sl.genid)} + + // Check if we need to prune. + if len(c.in.rcache) > maxRouteCacheSize { + c.pruneRouteCache() + } + } + + // Check to see if we need to map/route to another account. + if acc.imports.services != nil { + c.checkForImportServices(acc, msg) + } + + // Check for no interest, short circuit if so. + // This is the fanout scale. + if len(r.psubs)+len(r.qsubs) == 0 { + sendProto := false + // Send an RS- protocol, but keep track that we have said "no interest" + // for that account/subject, so that if later there is a subscription on + // this subject, we need to send an R+ to remote gateways. 
+ si, _ := c.gw.noInterest.Load(string(c.pa.account)) + if si == nil { + m := &sync.Map{} + m.Store(string(c.pa.subject), struct{}{}) + c.gw.noInterest.Store(string(c.pa.account), m) + sendProto = true + } else { + m := si.(*sync.Map) + if _, alreadySent := m.Load(string(c.pa.subject)); !alreadySent { + m.Store(string(c.pa.subject), struct{}{}) + sendProto = true + } + } + if sendProto { + var ( + protoa = [512]byte{} + proto = protoa[:0] + ) + proto = append(proto, rUnsubBytes...) + proto = append(proto, c.pa.account...) + proto = append(proto, ' ') + proto = append(proto, c.pa.subject...) + c.mu.Lock() + if c.trace { + c.traceOutOp("", proto) + } + proto = append(proto, CR_LF...) + c.sendProto(proto, true) + c.mu.Unlock() + } + return + } + + // Check to see if we have a routed message with a service reply. + if isServiceReply(c.pa.reply) && acc != nil { + // Need to add a sub here for local interest to send a response back + // to the originating server/requestor where it will be re-mapped. + sid := make([]byte, 0, len(acc.Name)+len(c.pa.reply)+1) + sid = append(sid, acc.Name...) + sid = append(sid, ' ') + sid = append(sid, c.pa.reply...) + // Copy off the reply since otherwise we are referencing a buffer that will be reused. 
+ reply := make([]byte, len(c.pa.reply)) + copy(reply, c.pa.reply) + sub := &subscription{client: c, subject: reply, sid: sid, max: 1} + if err := acc.sl.Insert(sub); err != nil { + c.Errorf("Could not insert subscription: %v", err) + } else { + ttl := acc.AutoExpireTTL() + c.mu.Lock() + c.subs[string(sid)] = sub + c.addReplySubTimeout(acc, sub, ttl) + c.mu.Unlock() + } + } + + c.processMsgResults(acc, r, msg, c.pa.subject, c.pa.reply) +} diff --git a/server/gateway_test.go b/server/gateway_test.go new file mode 100644 index 00000000..5dd09865 --- /dev/null +++ b/server/gateway_test.go @@ -0,0 +1,2352 @@ +// Copyright 2018 The NATS Authors +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package server + +import ( + "context" + "crypto/tls" + "encoding/json" + "fmt" + "net" + "net/url" + "runtime" + "strings" + "sync" + "sync/atomic" + "testing" + "time" + + "github.com/nats-io/gnatsd/logger" + "github.com/nats-io/go-nats" +) + +func init() { + gatewayConnectDelay = 15 * time.Millisecond + gatewayReconnectDelay = 15 * time.Millisecond +} + +// Wait for the expected number of outbound gateways, or fails. 
+func waitForOutboundGateways(t *testing.T, s *Server, expected int, timeout time.Duration) { + t.Helper() + checkFor(t, timeout, 15*time.Millisecond, func() error { + if n := s.numOutboundGateways(); n != expected { + return fmt.Errorf("Expected %v outbound gateway(s), got %v", expected, n) + } + return nil + }) +} + +// Wait for the expected number of inbound gateways, or fails. +func waitForInboundGateways(t *testing.T, s *Server, expected int, timeout time.Duration) { + t.Helper() + checkFor(t, timeout, 15*time.Millisecond, func() error { + if n := s.numInboundGateways(); n != expected { + return fmt.Errorf("Expected %v inbound gateway(s), got %v", expected, n) + } + return nil + }) +} + +func waitForGatewayFailedConnect(t *testing.T, s *Server, gwName string, expectFailure bool, timeout time.Duration) { + t.Helper() + checkFor(t, timeout, 15*time.Millisecond, func() error { + var c int + cfg := s.getRemoteGateway(gwName) + if cfg != nil { + c = cfg.getConnAttempts() + } + if expectFailure && c <= 1 { + return fmt.Errorf("Expected several attempts to connect, got %v", c) + } else if !expectFailure && c > 1 { + return fmt.Errorf("Expected single attempt to connect, got %v", c) + } + return nil + }) +} + +func waitCh(t *testing.T, ch chan bool, errTxt string) { + t.Helper() + select { + case <-ch: + return + case <-time.After(5 * time.Second): + t.Fatalf(errTxt) + } +} + +func natsConnect(t *testing.T, url string, options ...nats.Option) *nats.Conn { + t.Helper() + nc, err := nats.Connect(url, options...) 
+ if err != nil { + t.Fatalf("Error on connect: %v", err) + } + return nc +} + +func natsSub(t *testing.T, nc *nats.Conn, subj string, cb nats.MsgHandler) *nats.Subscription { + t.Helper() + sub, err := nc.Subscribe(subj, cb) + if err != nil { + t.Fatalf("Error on subscribe: %v", err) + } + return sub +} + +func natsQueueSub(t *testing.T, nc *nats.Conn, subj, queue string, cb nats.MsgHandler) *nats.Subscription { + t.Helper() + sub, err := nc.QueueSubscribe(subj, queue, cb) + if err != nil { + t.Fatalf("Error on subscribe: %v", err) + } + return sub +} + +func natsFlush(t *testing.T, nc *nats.Conn) { + t.Helper() + if err := nc.Flush(); err != nil { + t.Fatalf("Error on flush: %v", err) + } +} + +func natsPub(t *testing.T, nc *nats.Conn, subj string, payload []byte) { + t.Helper() + if err := nc.Publish(subj, payload); err != nil { + t.Fatalf("Error on publish: %v", err) + } +} + +func natsPubReq(t *testing.T, nc *nats.Conn, subj, reply string, payload []byte) { + t.Helper() + if err := nc.PublishRequest(subj, reply, payload); err != nil { + t.Fatalf("Error on publish: %v", err) + } +} + +func natsUnsub(t *testing.T, sub *nats.Subscription) { + t.Helper() + if err := sub.Unsubscribe(); err != nil { + t.Fatalf("Error on unsubscribe: %v", err) + } +} + +func testDefaultOptionsForGateway(name string) *Options { + o := DefaultOptions() + o.Gateway.Name = name + o.Gateway.Host = "127.0.0.1" + o.Gateway.Port = -1 + o.Gateway.DefaultPermissions = &GatewayPermissions{ + Import: &SubjectPermission{Allow: []string{">"}}, + Export: &SubjectPermission{Allow: []string{">"}}, + } + o.gatewaysSolicitDelay = 15 * time.Millisecond + return o +} + +func runGatewayServer(o *Options) *Server { + s := RunServer(o) + s.SetLogger(&DummyLogger{}, true, true) + return s +} + +func testGatewayOptionsFromToWithServers(t *testing.T, org, dst string, servers ...*Server) *Options { + t.Helper() + o := testDefaultOptionsForGateway(org) + gw := &RemoteGatewayOpts{ + Name: dst, + Permissions: 
&GatewayPermissions{ + Import: &SubjectPermission{Allow: []string{">"}}, + Export: &SubjectPermission{Allow: []string{">"}}, + }, + } + for _, s := range servers { + us := fmt.Sprintf("nats://127.0.0.1:%d", s.GatewayAddr().Port) + u, err := url.Parse(us) + if err != nil { + t.Fatalf("Error parsing url: %v", err) + } + gw.URLs = append(gw.URLs, u) + } + o.Gateway.Gateways = append(o.Gateway.Gateways, gw) + return o +} + +func testAddGatewayURLs(t *testing.T, o *Options, dst string, urls []string) { + t.Helper() + gw := &RemoteGatewayOpts{ + Name: dst, + Permissions: &GatewayPermissions{ + Import: &SubjectPermission{Allow: []string{">"}}, + Export: &SubjectPermission{Allow: []string{">"}}, + }, + } + for _, us := range urls { + u, err := url.Parse(us) + if err != nil { + t.Fatalf("Error parsing url: %v", err) + } + gw.URLs = append(gw.URLs, u) + } + o.Gateway.Gateways = append(o.Gateway.Gateways, gw) +} + +func testGatewayOptionsFromToWithURLs(t *testing.T, org, dst string, urls []string) *Options { + o := testDefaultOptionsForGateway(org) + testAddGatewayURLs(t, o, dst, urls) + return o +} + +func testGatewayOptionsWithTLS(t *testing.T, name string) *Options { + t.Helper() + o := testDefaultOptionsForGateway(name) + var ( + tc = &TLSConfigOpts{} + err error + ) + if name == "A" { + tc.CertFile = "../test/configs/certs/srva-cert.pem" + tc.KeyFile = "../test/configs/certs/srva-key.pem" + } else { + tc.CertFile = "../test/configs/certs/srvb-cert.pem" + tc.KeyFile = "../test/configs/certs/srvb-key.pem" + } + tc.CaFile = "../test/configs/certs/ca.pem" + o.Gateway.TLSConfig, err = GenTLSConfig(tc) + if err != nil { + t.Fatalf("Error generating TLS config: %v", err) + } + o.Gateway.TLSConfig.ClientAuth = tls.RequireAndVerifyClientCert + o.Gateway.TLSConfig.RootCAs = o.Gateway.TLSConfig.ClientCAs + o.Gateway.TLSTimeout = 2.0 + return o +} + +func testGatewayOptionsFromToWithTLS(t *testing.T, org, dst string, urls []string) *Options { + o := testGatewayOptionsWithTLS(t, org) 
// TestGatewayBasic checks the basic lifecycle of a gateway between cluster A
// (s1, explicitly configured to reach B) and cluster B (s2, no remote
// gateway configured): initial connect in both directions, teardown when s2
// stops, reconnect when s2 restarts, and finally removal of A from s2 once
// s1 is restarted without its gateway to B.
func TestGatewayBasic(t *testing.T) {
	o2 := testDefaultOptionsForGateway("B")
	o2.Gateway.ConnectRetries = 0
	s2 := runGatewayServer(o2)
	defer s2.Shutdown()

	o1 := testGatewayOptionsFromToWithServers(t, "A", "B", s2)
	s1 := runGatewayServer(o1)
	defer s1.Shutdown()

	// s1 should have an outbound gateway to s2.
	waitForOutboundGateways(t, s1, 1, time.Second)
	// s2 should have an inbound gateway
	waitForInboundGateways(t, s2, 1, time.Second)
	// and an outbound too
	waitForOutboundGateways(t, s2, 1, time.Second)

	// Stop s2 server
	s2.Shutdown()

	// gateway should go away
	waitForOutboundGateways(t, s1, 0, time.Second)
	waitForInboundGateways(t, s1, 0, time.Second)

	// Restart server. The defer binds the new s2 value, so both the old and
	// the new server instance get a Shutdown call when the test exits.
	s2 = runGatewayServer(o2)
	defer s2.Shutdown()

	// gateway should reconnect
	waitForOutboundGateways(t, s1, 1, 2*time.Second)
	waitForOutboundGateways(t, s2, 1, 2*time.Second)
	waitForInboundGateways(t, s1, 1, 2*time.Second)
	waitForInboundGateways(t, s2, 1, 2*time.Second)

	// Shutdown s1, remove the gateway from A to B and restart.
	s1.Shutdown()
	// When s2 detects the connection is closed, it will attempt
	// to reconnect once (even though its gateway to A is implicit,
	// i.e. not part of o2's configuration).
	// Wait a bit before restarting s1. For Windows, we need to wait
	// more than the dialTimeout before restarting the server.
	wait := 500 * time.Millisecond
	if runtime.GOOS == "windows" {
		wait = 1200 * time.Millisecond
	}
	time.Sleep(wait)
	// Restart s1 without gateway to B.
	o1.Gateway.Gateways = nil
	s1 = runGatewayServer(o1)
	defer s1.Shutdown()

	// s1 should not have any outbound nor inbound
	waitForOutboundGateways(t, s1, 0, 2*time.Second)
	waitForInboundGateways(t, s1, 0, 2*time.Second)

	// Same for s2
	waitForOutboundGateways(t, s2, 0, 2*time.Second)
	waitForInboundGateways(t, s2, 0, 2*time.Second)

	// Verify that s2 no longer has A gateway in its list
	checkFor(t, time.Second, 15*time.Millisecond, func() error {
		if s2.getRemoteGateway("A") != nil {
			return fmt.Errorf("Gateway A should have been removed from s2")
		}
		return nil
	})
}
+ waitForOutboundGateways(t, s1, 1, 2*time.Second) + // s2 should have an inbound gateway + waitForInboundGateways(t, s2, 1, 2*time.Second) + + s1.Shutdown() + // Make sure that server can be shutdown while waiting + // for that initial solicit delay + o1.gatewaysSolicitDelay = 2 * time.Second + s1 = runGatewayServer(o1) + start = time.Now() + s1.Shutdown() + if dur := time.Since(start); dur >= 2*time.Second { + t.Fatalf("Looks like shutdown was delayed: %v", dur) + } +} + +func TestGatewaySolicitDelayWithImplicitOutbounds(t *testing.T) { + // Cause a situation where A connects to B, and because of + // delay of solicit gateways set on B, we want to make sure + // that B does not end-up with 2 connections to A. + o2 := testDefaultOptionsForGateway("B") + o2.gatewaysSolicitDelay = 500 * time.Millisecond + s2 := runGatewayServer(o2) + defer s2.Shutdown() + + o1 := testGatewayOptionsFromToWithServers(t, "A", "B", s2) + s1 := runGatewayServer(o1) + defer s1.Shutdown() + + // s1 should have an outbound and inbound gateway to s2. + waitForOutboundGateways(t, s1, 1, 2*time.Second) + // s2 should have an inbound gateway + waitForInboundGateways(t, s2, 1, 2*time.Second) + // Wait for more than s2 solicit delay + time.Sleep(750 * time.Millisecond) + // The way we store outbound (map key'ed by gw name), we would + // not know if we had created 2 (since the newer would replace + // the older in the map). But if a second connection was made, + // then s1 would have 2 inbounds. So check it has only 1. 
+ waitForInboundGateways(t, s1, 1, time.Second) +} + +type slowResolver struct{} + +func (r *slowResolver) LookupHost(ctx context.Context, h string) ([]string, error) { + time.Sleep(500 * time.Millisecond) + return []string{h}, nil +} + +func TestGatewaySolicitShutdown(t *testing.T) { + var urls []string + for i := 0; i < 5; i++ { + u := fmt.Sprintf("nats://127.0.0.1:%d", 1234+i) + urls = append(urls, u) + } + o1 := testGatewayOptionsFromToWithURLs(t, "A", "B", urls) + o1.Gateway.resolver = &slowResolver{} + s1 := runGatewayServer(o1) + defer s1.Shutdown() + + time.Sleep(o1.gatewaysSolicitDelay + 10*time.Millisecond) + + start := time.Now() + s1.Shutdown() + if dur := time.Since(start); dur > 1200*time.Millisecond { + t.Fatalf("Took too long to shutdown: %v", dur) + } +} + +func TestGatewayListenError(t *testing.T) { + o2 := testDefaultOptionsForGateway("B") + s2 := runGatewayServer(o2) + defer s2.Shutdown() + + o1 := testDefaultOptionsForGateway("A") + o1.Gateway.Port = s2.GatewayAddr().Port + s1 := New(o1) + defer s1.Shutdown() + wg := sync.WaitGroup{} + wg.Add(1) + go func() { + s1.Start() + wg.Done() + }() + // We call Fatalf on listen error, but since there is no actual logger + // associated, we just check that the listener is not created. 
// TestGatewayAdvertise sets A's Gateway.Advertise to the gateway address of
// an unrelated server (s3, gateway "C") and checks the resulting topology:
// A still connects out to B, but B never establishes its own outbound
// connection back to A, and C ends up with no gateway connections at all.
func TestGatewayAdvertise(t *testing.T) {
	o3 := testDefaultOptionsForGateway("C")
	s3 := runGatewayServer(o3)
	defer s3.Shutdown()

	o2 := testDefaultOptionsForGateway("B")
	s2 := runGatewayServer(o2)
	defer s2.Shutdown()

	o1 := testGatewayOptionsFromToWithServers(t, "A", "B", s2)
	// Set the advertise so that this points to C
	o1.Gateway.Advertise = fmt.Sprintf("127.0.0.1:%d", s3.GatewayAddr().Port)
	s1 := runGatewayServer(o1)
	defer s1.Shutdown()

	// We should have outbound from s1 to s2
	waitForOutboundGateways(t, s1, 1, time.Second)
	// But no inbound from s2
	waitForInboundGateways(t, s1, 0, time.Second)

	// And since B tries to connect to A but reaches C, it should fail to connect,
	// and without connect retries, stop trying. So no outbound for s2, and no
	// inbound/outbound for s3.
	// s2 still has the one inbound connection that s1 created.
	waitForInboundGateways(t, s2, 1, time.Second)
	waitForOutboundGateways(t, s2, 0, time.Second)
	waitForInboundGateways(t, s3, 0, time.Second)
	waitForOutboundGateways(t, s3, 0, time.Second)
}
+ time.Sleep(100 * time.Millisecond) + addr := s1.GatewayAddr() + if addr != nil { + t.Fatal("Listener should not have been created") + } + s1.Shutdown() + wg.Wait() +} + +func TestGatewayAuth(t *testing.T) { + o2 := testDefaultOptionsForGateway("B") + o2.Gateway.Username = "me" + o2.Gateway.Password = "pwd" + s2 := runGatewayServer(o2) + defer s2.Shutdown() + + o1 := testGatewayOptionsFromToWithURLs(t, "A", "B", []string{fmt.Sprintf("nats://me:pwd@127.0.0.1:%d", s2.GatewayAddr().Port)}) + s1 := runGatewayServer(o1) + defer s1.Shutdown() + + // s1 should have an outbound gateway to s2. + waitForOutboundGateways(t, s1, 1, time.Second) + // s2 should have an inbound gateway + waitForInboundGateways(t, s2, 1, time.Second) + + s2.Shutdown() + s1.Shutdown() + + o2.Gateway.Username = "me" + o2.Gateway.Password = "wrong" + s2 = runGatewayServer(o2) + defer s2.Shutdown() + + s1 = runGatewayServer(o1) + defer s1.Shutdown() + + // Connection should fail... + waitForGatewayFailedConnect(t, s1, "B", true, 2*time.Second) + + s2.Shutdown() + s1.Shutdown() + o2.Gateway.Username = "wrong" + o2.Gateway.Password = "pwd" + s2 = runGatewayServer(o2) + defer s2.Shutdown() + + s1 = runGatewayServer(o1) + defer s1.Shutdown() + + // Connection should fail... + waitForGatewayFailedConnect(t, s1, "B", true, 2*time.Second) +} + +func TestGatewayTLS(t *testing.T) { + o2 := testGatewayOptionsWithTLS(t, "B") + s2 := runGatewayServer(o2) + defer s2.Shutdown() + + o1 := testGatewayOptionsFromToWithTLS(t, "A", "B", []string{fmt.Sprintf("nats://127.0.0.1:%d", s2.GatewayAddr().Port)}) + s1 := runGatewayServer(o1) + defer s1.Shutdown() + + // s1 should have an outbound gateway to s2. 
+ waitForOutboundGateways(t, s1, 1, time.Second) + // s2 should have an inbound gateway + waitForInboundGateways(t, s2, 1, time.Second) + + // Stop s2 server + s2.Shutdown() + + // gateway should go away + waitForOutboundGateways(t, s1, 0, time.Second) + waitForInboundGateways(t, s2, 0, time.Second) + + // Restart server + s2 = runGatewayServer(o2) + defer s2.Shutdown() + + // gateway should reconnect + waitForOutboundGateways(t, s1, 1, 2*time.Second) + waitForInboundGateways(t, s2, 1, 2*time.Second) + + s1.Shutdown() + // Wait for s2 to lose connections to s1. + waitForOutboundGateways(t, s2, 0, 2*time.Second) + waitForInboundGateways(t, s2, 0, 2*time.Second) + + // Make an explicit TLS config for remote gateway config "B" + // on cluster A. + o1.Gateway.Gateways[0].TLSConfig = o1.Gateway.TLSConfig.Clone() + // Make the TLSTimeout so small that it should fail to connect. + smallTimeout := 0.00000001 + o1.Gateway.Gateways[0].TLSTimeout = smallTimeout + s1 = runGatewayServer(o1) + defer s1.Shutdown() + + // s2 should be able to create connection to s1 though + waitForOutboundGateways(t, s2, 1, 2*time.Second) + // Check that s1 reports connection failures + waitForGatewayFailedConnect(t, s1, "B", true, 2*time.Second) + + // Check that TLSConfig from s1's remote "B" is based on + // what we have configured. + cfg := s1.getRemoteGateway("B") + var ( + tlsConfig *tls.Config + timeout float64 + ) + cfg.RLock() + if cfg.TLSConfig != nil { + tlsConfig = cfg.TLSConfig.Clone() + } + timeout = cfg.TLSTimeout + cfg.RUnlock() + if tlsConfig.ServerName != "127.0.0.1" { + t.Fatalf("Expected server name to be localhost, got %v", tlsConfig.ServerName) + } + if timeout != smallTimeout { + t.Fatalf("Expected tls timeout to be %v, got %v", smallTimeout, timeout) + } + s1.Shutdown() + // Wait for s2 to lose connections to s1. 
+ waitForOutboundGateways(t, s2, 0, 2*time.Second) + waitForInboundGateways(t, s2, 0, 2*time.Second) + + // Remove explicit TLSTimeout from gateway "B" and check that + // we use the A's spec one. + o1.Gateway.Gateways[0].TLSTimeout = 0 + s1 = runGatewayServer(o1) + defer s1.Shutdown() + + waitForOutboundGateways(t, s1, 1, time.Second) + waitForOutboundGateways(t, s2, 1, time.Second) + waitForInboundGateways(t, s1, 1, time.Second) + waitForInboundGateways(t, s2, 1, time.Second) + + cfg = s1.getRemoteGateway("B") + cfg.RLock() + timeout = cfg.TLSTimeout + cfg.RUnlock() + if timeout != o1.Gateway.TLSTimeout { + t.Fatalf("Expected tls timeout to be %v, got %v", o1.Gateway.TLSTimeout, timeout) + } +} + +func TestGatewayTLSErrors(t *testing.T) { + o2 := testDefaultOptionsForGateway("B") + s2 := runGatewayServer(o2) + defer s2.Shutdown() + + o1 := testGatewayOptionsFromToWithTLS(t, "A", "B", []string{fmt.Sprintf("nats://127.0.0.1:%d", s2.ClusterAddr().Port)}) + s1 := runGatewayServer(o1) + defer s1.Shutdown() + + // Expect s1 to have a failed to connect count > 0 + waitForGatewayFailedConnect(t, s1, "B", true, 2*time.Second) +} + +func TestGatewayWrongDestination(t *testing.T) { + // Start a server with a gateway named "C" + o2 := testDefaultOptionsForGateway("C") + s2 := runGatewayServer(o2) + defer s2.Shutdown() + + // Configure a gateway to "B", but since we are connecting to "C"... + o1 := testGatewayOptionsFromToWithServers(t, "A", "B", s2) + s1 := runGatewayServer(o1) + defer s1.Shutdown() + + // we should not be able to connect. + waitForGatewayFailedConnect(t, s1, "B", true, time.Second) + + // Shutdown s2 and fix the gateway name. + // s1 should then connect ok and failed connect should be cleared. + s2.Shutdown() + + // Reset the conn attempts + cfg := s1.getRemoteGateway("B") + cfg.resetConnAttempts() + + o2.Gateway.Name = "B" + s2 = runGatewayServer(o2) + defer s2.Shutdown() + + // At some point, the number of failed connect count should be reset to 0. 
// TestGatewayCreateImplicit starts a two-server cluster B (s2, s3), then a
// server s1 in cluster A configured with both B servers as gateway URLs. It
// checks that every server ends up with one outbound gateway connection,
// that s1 has two inbound ones, and that s2 and s3 do not both report an
// inbound connection from s1.
func TestGatewayCreateImplicit(t *testing.T) {
	// Create a regular cluster of 2 servers
	o2 := testDefaultOptionsForGateway("B")
	s2 := runGatewayServer(o2)
	defer s2.Shutdown()

	o3 := testDefaultOptionsForGateway("B")
	o3.Routes = RoutesFromStr(fmt.Sprintf("nats://127.0.0.1:%d", s2.ClusterAddr().Port))
	s3 := runGatewayServer(o3)
	defer s3.Shutdown()

	checkClusterFormed(t, s2, s3)

	// Now start s1 that creates a Gateway connection to s2 or s3
	o1 := testGatewayOptionsFromToWithServers(t, "A", "B", s2, s3)
	s1 := runGatewayServer(o1)
	defer s1.Shutdown()

	// We should have an outbound gateway connection on ALL servers.
	waitForOutboundGateways(t, s1, 1, 2*time.Second)
	waitForOutboundGateways(t, s2, 1, 2*time.Second)
	waitForOutboundGateways(t, s3, 1, 2*time.Second)

	// Server s1 must have 2 inbound ones
	waitForInboundGateways(t, s1, 2, 2*time.Second)

	// However, s1 may have created the outbound to s2 or s3. It is possible that
	// either s2 or s3 does not have an inbound connection.
	// The check below only rejects the case where one of them reports exactly
	// one inbound while the other reports a non-zero count.
	s2Inbound := s2.numInboundGateways()
	s3Inbound := s3.numInboundGateways()
	if (s2Inbound == 1 && s3Inbound != 0) || (s3Inbound == 1 && s2Inbound != 0) {
		t.Fatalf("Unexpected inbound for s2/s3: %v/%v", s2Inbound, s3Inbound)
	}
}
// TestGatewayURLsFromClusterSentInINFO verifies that a server which was
// configured with a single URL for a remote cluster learns the other gateway
// URLs of that cluster: s1 (A) only knows s2 (B), yet after s2 is shut down
// it reconnects to s3 (also B) and its remote gateway "B" holds 2 URLs.
func TestGatewayURLsFromClusterSentInINFO(t *testing.T) {
	o2 := testDefaultOptionsForGateway("B")
	s2 := runGatewayServer(o2)
	defer s2.Shutdown()

	o3 := testDefaultOptionsForGateway("B")
	o3.Routes = RoutesFromStr(fmt.Sprintf("nats://127.0.0.1:%d", s2.ClusterAddr().Port))
	s3 := runGatewayServer(o3)
	defer s3.Shutdown()

	checkClusterFormed(t, s2, s3)

	// Now start s1 that creates a Gateway connection to s2
	o1 := testGatewayOptionsFromToWithServers(t, "A", "B", s2)
	s1 := runGatewayServer(o1)
	defer s1.Shutdown()

	// Make sure we have proper outbound/inbound
	waitForOutboundGateways(t, s1, 1, time.Second)
	waitForOutboundGateways(t, s2, 1, time.Second)
	waitForOutboundGateways(t, s3, 1, time.Second)

	// Although s1 connected to s2 and knew only about s2, it should have
	// received the list of gateway URLs in the B cluster. So if we shutdown
	// server s2, it should be able to reconnect to s3.
	s2.Shutdown()
	// Wait for s3 to register that s2 is gone.
	checkNumRoutes(t, s3, 0)
	// s1 should have reconnected to s3 because it learned about it
	// when connecting earlier to s2.
	waitForOutboundGateways(t, s1, 1, 2*time.Second)
	// Also make sure that the gateway's urls map has 2 urls.
	gw := s1.getRemoteGateway("B")
	if gw == nil {
		t.Fatal("Did not find gateway B")
	}
	// gw.urls is read under the remote gateway's read lock.
	gw.RLock()
	l := len(gw.urls)
	gw.RUnlock()
	if l != 2 {
		t.Fatalf("S1 should have 2 urls, got %v", l)
	}
}
// TestGatewayRejectUnknown verifies that a server started with
// Gateway.RejectUnknown (s1, cluster A, configured only for B) never ends up
// with an outbound connection to, or a registered remote gateway for, a
// cluster it was not configured with (C), both when C appears after A is
// running and after A is restarted.
func TestGatewayRejectUnknown(t *testing.T) {
	o2 := testDefaultOptionsForGateway("B")
	s2 := runGatewayServer(o2)
	defer s2.Shutdown()

	// Create a gateway from A to B, but configure A to reject non configured ones.
	o1 := testGatewayOptionsFromToWithServers(t, "A", "B", s2)
	o1.Gateway.RejectUnknown = true
	s1 := runGatewayServer(o1)
	defer s1.Shutdown()

	// Wait for outbound/inbound to be created.
	waitForOutboundGateways(t, s1, 1, time.Second)
	waitForOutboundGateways(t, s2, 1, time.Second)
	waitForInboundGateways(t, s1, 1, time.Second)
	waitForInboundGateways(t, s2, 1, time.Second)

	// Create gateway C to B. B will tell C to connect to A,
	// which A should reject.
	o3 := testGatewayOptionsFromToWithServers(t, "C", "B", s2)
	s3 := runGatewayServer(o3)
	defer s3.Shutdown()

	// s3 should have outbound to B, but not to A
	waitForOutboundGateways(t, s3, 1, time.Second)
	// s2 should have 2 inbounds (one from s1 one from s3)
	waitForInboundGateways(t, s2, 2, time.Second)

	// s1 should have single outbound/inbound with s2.
	waitForOutboundGateways(t, s1, 1, time.Second)
	waitForInboundGateways(t, s1, 1, time.Second)

	// It should not have a registered remote gateway with C (s3)
	if s1.getOutboundGateway("C") != nil {
		t.Fatalf("A should not have outbound gateway to C")
	}
	if s1.getRemoteGateway("C") != nil {
		t.Fatalf("A should not have a registered remote gateway to C")
	}

	// Restart s1 and this time, B will tell A to connect to C.
	// But A will not even attempt that since it does not have
	// C configured.
	s1.Shutdown()
	// With s1 gone, s2 is left with its connections to/from s3 only.
	waitForOutboundGateways(t, s2, 1, time.Second)
	waitForInboundGateways(t, s2, 1, time.Second)
	s1 = runGatewayServer(o1)
	defer s1.Shutdown()
	waitForOutboundGateways(t, s2, 2, time.Second)
	waitForInboundGateways(t, s2, 2, time.Second)
	waitForOutboundGateways(t, s1, 1, time.Second)
	waitForInboundGateways(t, s1, 1, time.Second)
	waitForOutboundGateways(t, s3, 1, time.Second)
	waitForInboundGateways(t, s3, 1, time.Second)
	// It should not have a registered remote gateway with C (s3)
	if s1.getOutboundGateway("C") != nil {
		t.Fatalf("A should not have outbound gateway to C")
	}
	if s1.getRemoteGateway("C") != nil {
		t.Fatalf("A should not have a registered remote gateway to C")
	}
}
+ s1.Shutdown() + time.Sleep(250 * time.Millisecond) + waitForOutboundGateways(t, s1, 0, time.Second) + waitForOutboundGateways(t, s2, 0, time.Second) + waitForInboundGateways(t, s2, 0, time.Second) +} + +func TestGatewayDontSendSubInterest(t *testing.T) { + o2 := testDefaultOptionsForGateway("B") + s2 := runGatewayServer(o2) + defer s2.Shutdown() + + o1 := testGatewayOptionsFromToWithServers(t, "A", "B", s2) + s1 := runGatewayServer(o1) + defer s1.Shutdown() + + waitForOutboundGateways(t, s1, 1, time.Second) + waitForOutboundGateways(t, s2, 1, time.Second) + + s2Url := fmt.Sprintf("nats://127.0.0.1:%d", o2.Port) + subnc := natsConnect(t, s2Url) + defer subnc.Close() + natsSub(t, subnc, "foo", func(_ *nats.Msg) {}) + natsFlush(t, subnc) + + checkExpectedSubs(t, 1, s2) + // Subscription should not be sent to s1 + checkExpectedSubs(t, 0, s1) + + // Restart s1 + s1.Shutdown() + s1 = runGatewayServer(o1) + defer s1.Shutdown() + waitForOutboundGateways(t, s1, 1, time.Second) + waitForOutboundGateways(t, s2, 1, time.Second) + + checkExpectedSubs(t, 1, s2) + checkExpectedSubs(t, 0, s1) +} + +func setAccountUserPassInOptions(o *Options, accName, username, password string) { + acc := &Account{Name: accName} + o.Accounts = append(o.Accounts, acc) + o.Users = append(o.Users, &User{Username: username, Password: password, Account: acc}) +} + +func TestGatewayAccountInterest(t *testing.T) { + o2 := testDefaultOptionsForGateway("B") + s2 := runGatewayServer(o2) + defer s2.Shutdown() + + o1 := testGatewayOptionsFromToWithServers(t, "A", "B", s2) + setAccountUserPassInOptions(o1, "$foo", "ivan", "password") + s1 := runGatewayServer(o1) + defer s1.Shutdown() + + // Make this server initiate connection to A, so it is faster + // when restarting it at the end of this test. 
+ o3 := testGatewayOptionsFromToWithServers(t, "C", "A", s1) + setAccountUserPassInOptions(o3, "$foo", "ivan", "password") + s3 := runGatewayServer(o3) + defer s3.Shutdown() + + waitForOutboundGateways(t, s1, 2, time.Second) + waitForOutboundGateways(t, s2, 2, time.Second) + waitForOutboundGateways(t, s3, 2, time.Second) + + s1Url := fmt.Sprintf("nats://ivan:password@127.0.0.1:%d", o1.Port) + nc := natsConnect(t, s1Url) + defer nc.Close() + natsPub(t, nc, "foo", []byte("hello")) + natsFlush(t, nc) + + // On first send, the message should be sent. + checkCount := func(t *testing.T, c *client, expected int) { + t.Helper() + c.mu.Lock() + out := c.outMsgs + c.mu.Unlock() + if int(out) != expected { + t.Fatalf("Expected %d message(s) to be sent over, got %v", expected, out) + } + } + gwcb := s1.getOutboundGateway("B") + checkCount(t, gwcb, 1) + gwcc := s1.getOutboundGateway("C") + checkCount(t, gwcc, 1) + + // S2 should have sent a protocol indicating no interest. + checkFor(t, time.Second, 15*time.Millisecond, func() error { + _, noInterest := gwcb.gw.noInterest.Load("$foo") + if !noInterest { + return fmt.Errorf("Did not receive account no interest") + } + return nil + }) + // Second send should not go through to B + natsPub(t, nc, "foo", []byte("hello")) + natsFlush(t, nc) + checkCount(t, gwcb, 1) + // it won't go to C, not because there is no account interest, + // but because there is no interest on the subject. + checkCount(t, gwcc, 1) + + // Add account to S2, this should clear the no interest for that account. 
// TestGatewaySubjectInterest checks how subject-level "no interest" for
// account "$foo" is tracked on A's (s1) outbound gateway connection to B
// (s2). It publishes on "foo" from a client on s1 and uses the connection's
// outMsgs counter to tell whether each message was sent over the gateway:
// after B reports no interest, after a subscription is created and removed
// on B, after B restarts, and after a subscription is created on a second
// server of cluster B (s2bis).
func TestGatewaySubjectInterest(t *testing.T) {
	o1 := testDefaultOptionsForGateway("A")
	setAccountUserPassInOptions(o1, "$foo", "ivan", "password")
	s1 := runGatewayServer(o1)
	defer s1.Shutdown()

	o2 := testGatewayOptionsFromToWithServers(t, "B", "A", s1)
	setAccountUserPassInOptions(o2, "$foo", "ivan", "password")
	s2 := runGatewayServer(o2)
	defer s2.Shutdown()

	waitForOutboundGateways(t, s1, 1, time.Second)
	waitForOutboundGateways(t, s2, 1, time.Second)

	s1Url := fmt.Sprintf("nats://ivan:password@127.0.0.1:%d", o1.Port)
	nc := natsConnect(t, s1Url)
	defer nc.Close()
	natsPub(t, nc, "foo", []byte("hello"))
	natsFlush(t, nc)

	// On first send, the message should be sent.
	// checkCount compares the connection's outMsgs counter, read under the
	// client lock, with the expected number of messages.
	checkCount := func(t *testing.T, c *client, expected int) {
		t.Helper()
		c.mu.Lock()
		out := c.outMsgs
		c.mu.Unlock()
		if int(out) != expected {
			t.Fatalf("Expected %d message(s) to be sent over, got %v", expected, out)
		}
	}
	gwcb := s1.getOutboundGateway("B")
	checkCount(t, gwcb, 1)

	// S2 should have sent a protocol indicating no subject interest.
	// The closure reads the gwcb variable at call time, so it follows the
	// reassignment of gwcb made after B is restarted further down.
	checkForNoInterest := func(t *testing.T) {
		t.Helper()
		checkFor(t, time.Second, 15*time.Millisecond, func() error {
			sni, _ := gwcb.gw.noInterest.Load("$foo")
			if sni == nil {
				return fmt.Errorf("Did not receive subject no-interest")
			}
			if _, subjNoInterest := sni.(*sync.Map).Load("foo"); !subjNoInterest {
				return fmt.Errorf("Did not receive subject no-interest")
			}
			return nil
		})
	}
	checkForNoInterest(t)
	// Second send should not go through to B
	natsPub(t, nc, "foo", []byte("hello"))
	natsFlush(t, nc)
	checkCount(t, gwcb, 1)

	// Now create subscription interest on B (s2)
	s2Url := fmt.Sprintf("nats://ivan:password@127.0.0.1:%d", o2.Port)
	ncb := natsConnect(t, s2Url)
	defer ncb.Close()
	sub := natsSub(t, ncb, "foo", func(_ *nats.Msg) {})
	natsFlush(t, ncb)
	checkExpectedSubs(t, 1, s2)
	checkExpectedSubs(t, 0, s1)

	// This should clear the no interest for this subject
	checkFor(t, time.Second, 15*time.Millisecond, func() error {
		sni, _ := gwcb.gw.noInterest.Load("$foo")
		if sni == nil {
			return fmt.Errorf("Did not receive subject no interest")
		}
		if _, noInterest := sni.(*sync.Map).Load("foo"); noInterest {
			return fmt.Errorf("No-interest on foo should have been cleared")
		}
		return nil
	})
	// Third send should go through to B now that there is interest
	natsPub(t, nc, "foo", []byte("hello"))
	natsFlush(t, nc)
	checkCount(t, gwcb, 2)

	// Now unsubscribe, there won't be an UNSUB sent to the gateway.
	natsUnsub(t, sub)
	natsFlush(t, ncb)
	checkExpectedSubs(t, 0, s2)
	checkExpectedSubs(t, 0, s1)

	// So now sending a message should go over, but then we should get an RS-
	natsPub(t, nc, "foo", []byte("hello"))
	natsFlush(t, nc)
	checkCount(t, gwcb, 3)

	checkForNoInterest(t)

	// Send one more time and now it should not go to B
	natsPub(t, nc, "foo", []byte("hello"))
	natsFlush(t, nc)
	checkCount(t, gwcb, 3)

	// Restart B and that should clear everything on A
	s2.Shutdown()
	s2 = runGatewayServer(o2)
	defer s2.Shutdown()

	waitForOutboundGateways(t, s1, 1, time.Second)
	waitForOutboundGateways(t, s2, 1, time.Second)

	// Refresh gwcb: the count is expected to start again from 0 on the
	// connection returned after the restart.
	gwcb = s1.getOutboundGateway("B")
	checkCount(t, gwcb, 0)
	natsPub(t, nc, "foo", []byte("hello"))
	natsFlush(t, nc)
	checkCount(t, gwcb, 1)

	checkForNoInterest(t)

	natsPub(t, nc, "foo", []byte("hello"))
	natsFlush(t, nc)
	checkCount(t, gwcb, 1)

	// Add a node to B cluster and subscribe there.
	// We want to ensure that the no-interest is cleared
	// when s2 receives remote SUB from s2bis
	o2bis := testGatewayOptionsFromToWithServers(t, "B", "A", s1)
	setAccountUserPassInOptions(o2bis, "$foo", "ivan", "password")
	o2bis.Routes = RoutesFromStr(fmt.Sprintf("nats://127.0.0.1:%d", s2.ClusterAddr().Port))
	s2bis := runGatewayServer(o2bis)
	defer s2bis.Shutdown()

	checkClusterFormed(t, s2, s2bis)

	// Make sure all outbound gateway connections are setup
	waitForOutboundGateways(t, s1, 1, time.Second)
	waitForOutboundGateways(t, s2, 1, time.Second)
	waitForOutboundGateways(t, s2bis, 1, time.Second)

	// A should have 2 inbound
	waitForInboundGateways(t, s1, 2, time.Second)

	// Create sub on s2bis
	ncb2bis := natsConnect(t, fmt.Sprintf("nats://ivan:password@127.0.0.1:%d", o2bis.Port))
	defer ncb2bis.Close()
	natsSub(t, ncb2bis, "foo", func(_ *nats.Msg) {})
	natsFlush(t, ncb2bis)

	// Wait for subscription to be registered locally on s2bis and remotely on s2
	checkExpectedSubs(t, 1, s2, s2bis)

	// Check that subject no-interest on A was cleared.
	// NOTE(review): unlike the other polling closures in this test, a nil
	// entry here aborts the test immediately via t.Fatalf instead of being
	// returned as an error and retried.
	checkFor(t, time.Second, 15*time.Millisecond, func() error {
		gw := s1.getOutboundGateway("B").gw
		sni, _ := gw.noInterest.Load("$foo")
		if sni == nil {
			t.Fatalf("Account is marked as no-interest, it should not")
		}
		sn := sni.(*sync.Map)
		if _, noInterest := sn.Load("foo"); noInterest {
			return fmt.Errorf("No-interest still registered")
		}
		return nil
	})

	// Now publish. Remember, s1 has outbound gateway to s2, and s2 does not
	// have a local subscription and has previously sent a no-interest on "foo".
	// We check that this has been cleared due to the interest on s2bis.
	natsPub(t, nc, "foo", []byte("hello"))
	natsFlush(t, nc)
	checkCount(t, gwcb, 2)
}
+ natsPub(t, nc1, "foo", []byte("cycle")) + natsFlush(t, nc1) + time.Sleep(100 * time.Millisecond) + if c := atomic.LoadInt32(&count); c != 2 { + t.Fatalf("Expected only 2 messages, got %v", c) + } +} + +func TestGatewayQueueSub(t *testing.T) { + o2 := testDefaultOptionsForGateway("B") + s2 := runGatewayServer(o2) + defer s2.Shutdown() + + o1 := testGatewayOptionsFromToWithServers(t, "A", "B", s2) + s1 := runGatewayServer(o1) + defer s1.Shutdown() + + waitForOutboundGateways(t, s1, 1, time.Second) + waitForOutboundGateways(t, s2, 1, time.Second) + + s2Url := fmt.Sprintf("nats://127.0.0.1:%d", o2.Port) + nc2 := natsConnect(t, s2Url) + defer nc2.Close() + + count2 := int32(0) + cb2 := func(_ *nats.Msg) { + atomic.AddInt32(&count2, 1) + } + qsubOnB := natsQueueSub(t, nc2, "foo", "bar", cb2) + natsFlush(t, nc2) + + s1Url := fmt.Sprintf("nats://127.0.0.1:%d", o1.Port) + nc1 := natsConnect(t, s1Url) + defer nc1.Close() + + count1 := int32(0) + cb1 := func(_ *nats.Msg) { + atomic.AddInt32(&count1, 1) + } + qsubOnA := natsQueueSub(t, nc1, "foo", "bar", cb1) + natsFlush(t, nc1) + + // Make sure subs are registered on each server + checkExpectedSubs(t, 1, s1, s2) + + total := 100 + send := func(t *testing.T, nc *nats.Conn) { + t.Helper() + for i := 0; i < total; i++ { + // Alternate with adding a reply + if i%2 == 0 { + natsPubReq(t, nc, "foo", "reply", []byte("msg")) + } else { + natsPub(t, nc, "foo", []byte("msg")) + } + } + natsFlush(t, nc) + } + send(t, nc1) + + check := func(t *testing.T, count *int32, expected int) { + t.Helper() + checkFor(t, 2*time.Second, 15*time.Millisecond, func() error { + if n := int(atomic.LoadInt32(count)); n != expected { + return fmt.Errorf("Expected to get %v messages, got %v", expected, n) + } + return nil + }) + } + check(t, &count1, total) + check(t, &count2, 0) + + // Now send from the other side + send(t, nc2) + check(t, &count1, total) + check(t, &count2, total) + + // Reset counters + atomic.StoreInt32(&count1, 0) + 
atomic.StoreInt32(&count2, 0) + + // Stop qsub on A, and send messages to A, they should + // be routed to B. + qsubOnA.Unsubscribe() + checkExpectedSubs(t, 0, s1) + send(t, nc1) + check(t, &count1, 0) + check(t, &count2, total) + + // Reset counters + atomic.StoreInt32(&count1, 0) + atomic.StoreInt32(&count2, 0) + + // Create a C gateway now + o3 := testGatewayOptionsFromToWithServers(t, "C", "B", s2) + s3 := runGatewayServer(o3) + defer s3.Shutdown() + + waitForOutboundGateways(t, s1, 2, time.Second) + waitForOutboundGateways(t, s2, 2, time.Second) + waitForOutboundGateways(t, s3, 2, time.Second) + + waitForInboundGateways(t, s1, 2, time.Second) + waitForInboundGateways(t, s2, 2, time.Second) + waitForInboundGateways(t, s3, 2, time.Second) + + // Create another qsub "bar" + s3Url := fmt.Sprintf("nats://127.0.0.1:%d", o3.Port) + nc3 := natsConnect(t, s3Url) + defer nc3.Close() + // Associate this with count1 (since A qsub is no longer running) + natsQueueSub(t, nc3, "foo", "bar", cb1) + natsFlush(t, nc3) + checkExpectedSubs(t, 1, s3) + + // Artificially bump the RTT from A to C so that + // the code should favor sending to B. 
+ gwcC := s1.getOutboundGateway("C") + gwcC.mu.Lock() + gwcC.rtt = 10 * time.Second + gwcC.mu.Unlock() + + send(t, nc1) + check(t, &count1, 0) + check(t, &count2, total) + + // Add a new group on s3 that should receive all messages + count3 := int32(0) + cb3 := func(_ *nats.Msg) { + atomic.AddInt32(&count3, 1) + } + natsQueueSub(t, nc3, "foo", "baz", cb3) + natsFlush(t, nc3) + checkExpectedSubs(t, 2, s3) + + // Reset counters + atomic.StoreInt32(&count1, 0) + atomic.StoreInt32(&count2, 0) + + // Make the RTTs equal + gwcC.mu.Lock() + gwcC.rtt = time.Second + gwcC.mu.Unlock() + + gwcB := s1.getOutboundGateway("B") + gwcB.mu.Lock() + gwcB.rtt = time.Second + gwcB.mu.Unlock() + + send(t, nc1) + // Group baz should receive all messages + check(t, &count3, total) + + // Since RTT are equal, messages should be distributed + c1 := 0 + c2 := 0 + tc := 0 + deadline := time.Now().Add(5 * time.Second) + for time.Now().Before(deadline) { + c1 = int(atomic.LoadInt32(&count1)) + c2 = int(atomic.LoadInt32(&count2)) + tc = c1 + c2 + if tc == total { + break + } + time.Sleep(15 * time.Millisecond) + } + if tc != total { + t.Fatalf("Expected %v messages, got %v", total, tc) + } + // Messages should not have gone to only one GW. + if c1 == 0 || c2 == 0 { + t.Fatalf("Messages went to only one GW, c1=%v c2=%v", c1, c2) + } + + // Unsubscribe qsub on B and C should receive + // all messages on count1 and count3. + qsubOnB.Unsubscribe() + checkExpectedSubs(t, 0, s2) + + // gwcB should have the qsubs interest map empty now. 
+ checkFor(t, 2*time.Second, 15*time.Millisecond, func() error { + gwcB.mu.Lock() + _, exists := gwcB.gw.qsubsInterest.Load(globalAccountName) + gwcB.mu.Unlock() + if exists { + return fmt.Errorf("Qsub interest for account should have been removed") + } + return nil + }) + + // Reset counters + atomic.StoreInt32(&count1, 0) + atomic.StoreInt32(&count2, 0) + atomic.StoreInt32(&count3, 0) + + send(t, nc1) + check(t, &count1, total) + check(t, &count3, total) +} + +func TestGatewaySendQSubsOnGatewayConnect(t *testing.T) { + o2 := testDefaultOptionsForGateway("B") + s2 := runGatewayServer(o2) + defer s2.Shutdown() + + s2Url := fmt.Sprintf("nats://127.0.0.1:%d", o2.Port) + subnc := natsConnect(t, s2Url) + defer subnc.Close() + + ch := make(chan bool, 1) + cb := func(_ *nats.Msg) { + ch <- true + } + natsQueueSub(t, subnc, "foo", "bar", cb) + natsFlush(t, subnc) + + // Now start s1 that creates a gateway to s2 + o1 := testGatewayOptionsFromToWithServers(t, "A", "B", s2) + s1 := runGatewayServer(o1) + defer s1.Shutdown() + + waitForOutboundGateways(t, s1, 1, time.Second) + waitForOutboundGateways(t, s2, 1, time.Second) + + // Publish from s1, message should be received on s2. 
+ pubnc := natsConnect(t, fmt.Sprintf("nats://127.0.0.1:%d", o1.Port)) + defer pubnc.Close() + // Publish 1 message + natsPub(t, pubnc, "foo", []byte("hello")) + waitCh(t, ch, "Did not get out message") + pubnc.Close() + + s1.Shutdown() + s1 = runGatewayServer(o1) + defer s1.Shutdown() + + waitForOutboundGateways(t, s1, 1, time.Second) + waitForOutboundGateways(t, s2, 1, time.Second) + + pubnc = natsConnect(t, fmt.Sprintf("nats://127.0.0.1:%d", o1.Port)) + defer pubnc.Close() + // Publish 1 message + natsPub(t, pubnc, "foo", []byte("hello")) + waitCh(t, ch, "Did not get out message") +} + +func TestGatewaySendRemoteQSubs(t *testing.T) { + ob1 := testDefaultOptionsForGateway("B") + sb1 := runGatewayServer(ob1) + defer sb1.Shutdown() + + ob2 := testDefaultOptionsForGateway("B") + ob2.Routes = RoutesFromStr(fmt.Sprintf("nats://%s:%d", ob1.Cluster.Host, ob1.Cluster.Port)) + sb2 := runGatewayServer(ob2) + + checkClusterFormed(t, sb1, sb2) + + sbURL := fmt.Sprintf("nats://127.0.0.1:%d", ob2.Port) + subnc := natsConnect(t, sbURL) + defer subnc.Close() + + ch := make(chan bool, 1) + cb := func(_ *nats.Msg) { + ch <- true + } + qsub1 := natsQueueSub(t, subnc, "foo", "bar", cb) + qsub2 := natsQueueSub(t, subnc, "foo", "bar", cb) + natsFlush(t, subnc) + + // There will be 2 local qsubs on the sb2 server where the client is connected + checkExpectedSubs(t, 2, sb2) + // But only 1 remote on sb1 + checkExpectedSubs(t, 1, sb1) + + // Now start s1 that creates a gateway to sb1 (the one that does not have the local QSub) + oa := testGatewayOptionsFromToWithServers(t, "A", "B", sb1) + sa := runGatewayServer(oa) + defer sa.Shutdown() + + waitForOutboundGateways(t, sa, 1, time.Second) + waitForOutboundGateways(t, sb1, 1, time.Second) + waitForOutboundGateways(t, sb2, 1, time.Second) + + // Publish from s1, message should be received on s2. 
+ saURL := fmt.Sprintf("nats://127.0.0.1:%d", oa.Port) + pubnc := natsConnect(t, saURL) + defer pubnc.Close() + // Publish 1 message + natsPub(t, pubnc, "foo", []byte("hello")) + natsFlush(t, pubnc) + waitCh(t, ch, "Did not get out message") + + // Unsubscribe 1 qsub + natsUnsub(t, qsub1) + natsFlush(t, subnc) + // There should be only 1 local qsub on sb2 now, and the remote should still exist on sb1 + checkExpectedSubs(t, 1, sb1, sb2) + + // Publish 1 message + natsPub(t, pubnc, "foo", []byte("hello")) + natsFlush(t, pubnc) + waitCh(t, ch, "Did not get out message") + + // Unsubscribe the remaining + natsUnsub(t, qsub2) + natsFlush(t, subnc) + + // No more subs now on both sb1 and sb2 + checkExpectedSubs(t, 0, sb1, sb2) + + // Server sb1 should not have qsub in its rqs map + checkFor(t, time.Second, 15*time.Millisecond, func() error { + sb1.gateway.RLock() + entry := sb1.gateway.rqs[fmt.Sprintf("%s foo bar", globalAccountName)] + sb1.gateway.RUnlock() + if entry != nil { + return fmt.Errorf("Map rqs should not have an entry, got %#v", entry) + } + return nil + }) + + // Let's wait for A to receive the unsubscribe + checkFor(t, time.Second, 15*time.Millisecond, func() error { + gw := sa.getOutboundGateway("B").gw + if sli, _ := gw.qsubsInterest.Load(globalAccountName); sli != nil { + return fmt.Errorf("Interest still present") + } + return nil + }) + + // Now send a message, and it should not be sent because of qsubs, + // but will be sent as optimistic psub send. 
+ natsPub(t, pubnc, "foo", []byte("hello")) + natsFlush(t, pubnc) + + // Get the gateway connection from A (sa) to B (sb1) + gw := sa.getOutboundGateway("B") + gw.mu.Lock() + out := gw.outMsgs + gw.mu.Unlock() + if out != 3 { + t.Fatalf("Expected 2 out messages, got %v", out) + } + + // Wait for the no interest to be received by A + checkFor(t, time.Second, 15*time.Millisecond, func() error { + gw := sa.getOutboundGateway("B").gw + if ni, _ := gw.noInterest.Load(globalAccountName); ni == nil { + return fmt.Errorf("No-interest still not registered") + } + return nil + }) + + // Send another message, now it should not get out. + natsPub(t, pubnc, "foo", []byte("hello")) + natsFlush(t, pubnc) + + // Get the gateway connection from A (sa) to B (sb1) + gw.mu.Lock() + out = gw.outMsgs + gw.mu.Unlock() + if out != 3 { + t.Fatalf("Expected 2 out messages, got %v", out) + } + + // Restart A + pubnc.Close() + sa.Shutdown() + sa = runGatewayServer(oa) + defer sa.Shutdown() + + waitForOutboundGateways(t, sa, 1, time.Second) + + // Check qsubs interest should be empty + checkFor(t, time.Second, 15*time.Millisecond, func() error { + gw := sa.getOutboundGateway("B").gw + if sli, _ := gw.qsubsInterest.Load(globalAccountName); sli != nil { + return fmt.Errorf("Interest still present") + } + return nil + }) +} + +func TestGatewayComplexSetup(t *testing.T) { + doLog := false + + // This test will have the following setup: + // --- means route connection + // === means gateway connection + // [o] is outbound + // [i] is inbound + // Each server has an outbound connection to the other cluster. + // It may have 0 or more inbound connection(s).
+ // + // Cluster A Cluster B + // sa1 [o]===========>[i] + // | [i]<===========[o] + // | sb1 ------- sb2 + // | [i] [o] + // sa2 [o]=============^ || + // [i]<========================|| + ob1 := testDefaultOptionsForGateway("B") + sb1 := runGatewayServer(ob1) + defer sb1.Shutdown() + if doLog { + sb1.SetLogger(logger.NewTestLogger("[B1] - ", true), true, true) + } + + oa1 := testGatewayOptionsFromToWithServers(t, "A", "B", sb1) + sa1 := runGatewayServer(oa1) + defer sa1.Shutdown() + if doLog { + sa1.SetLogger(logger.NewTestLogger("[A1] - ", true), true, true) + } + + waitForOutboundGateways(t, sa1, 1, time.Second) + waitForOutboundGateways(t, sb1, 1, time.Second) + + waitForInboundGateways(t, sa1, 1, time.Second) + waitForInboundGateways(t, sb1, 1, time.Second) + + oa2 := testGatewayOptionsFromToWithServers(t, "A", "B", sb1) + oa2.Routes = RoutesFromStr(fmt.Sprintf("nats://127.0.0.1:%d", sa1.ClusterAddr().Port)) + sa2 := runGatewayServer(oa2) + defer sa2.Shutdown() + if doLog { + sa2.SetLogger(logger.NewTestLogger("[A2] - ", true), true, true) + } + + checkClusterFormed(t, sa1, sa2) + + waitForOutboundGateways(t, sa2, 1, time.Second) + waitForInboundGateways(t, sb1, 2, time.Second) + + ob2 := testGatewayOptionsFromToWithServers(t, "B", "A", sa2) + ob2.Routes = RoutesFromStr(fmt.Sprintf("nats://127.0.0.1:%d", sb1.ClusterAddr().Port)) + var sb2 *Server + for { + sb2 = runGatewayServer(ob2) + defer sb2.Shutdown() + + checkClusterFormed(t, sb1, sb2) + + waitForOutboundGateways(t, sb2, 1, time.Second) + waitForInboundGateways(t, sb2, 0, time.Second) + // For this test, we want the outbound to be to sa2, so if we don't have that, + // restart sb2 until we get lucky. 
+ time.Sleep(100 * time.Millisecond) + if sa2.numInboundGateways() == 0 { + sb2.Shutdown() + sb2 = nil + } else { + break + } + } + if doLog { + sb2.SetLogger(logger.NewTestLogger("[B2] - ", true), true, true) + } + + ch := make(chan bool, 1) + cb := func(_ *nats.Msg) { + ch <- true + } + + // Create a subscription on sa1 and sa2. + ncsa1 := natsConnect(t, fmt.Sprintf("nats://127.0.0.1:%d", oa1.Port)) + defer ncsa1.Close() + sub1 := natsSub(t, ncsa1, "foo", cb) + natsFlush(t, ncsa1) + + ncsa2 := natsConnect(t, fmt.Sprintf("nats://127.0.0.1:%d", oa2.Port)) + defer ncsa2.Close() + sub2 := natsSub(t, ncsa2, "foo", cb) + natsFlush(t, ncsa2) + + // sa1 will have 1 local, one remote (from sa2), same for sa2. + checkExpectedSubs(t, 2, sa1, sa2) + + // Connect to sb2 and send 1 message + ncsb2 := natsConnect(t, fmt.Sprintf("nats://127.0.0.1:%d", ob2.Port)) + defer ncsb2.Close() + natsPub(t, ncsb2, "foo", []byte("hello")) + natsFlush(t, ncsb2) + + for i := 0; i < 2; i++ { + waitCh(t, ch, "Did not get our message") + } + + // Unsubscribe sub2, and send 1, should still get it. 
+ natsUnsub(t, sub2) + natsFlush(t, ncsa2) + natsPub(t, ncsb2, "foo", []byte("hello")) + natsFlush(t, ncsb2) + waitCh(t, ch, "Did not get our message") + + // Unsubscribe sub1, all server's sublist should be empty + sub1.Unsubscribe() + natsFlush(t, ncsa1) + + checkExpectedSubs(t, 0, sa1, sa2, sb1, sb2) + + // Create queue subs + total := 100 + c1 := int32(0) + c2 := int32(0) + c3 := int32(0) + tc := int32(0) + natsQueueSub(t, ncsa1, "foo", "bar", func(_ *nats.Msg) { + atomic.AddInt32(&c1, 1) + if c := atomic.AddInt32(&tc, 1); int(c) == total { + ch <- true + } + }) + natsFlush(t, ncsa1) + natsQueueSub(t, ncsa2, "foo", "bar", func(_ *nats.Msg) { + atomic.AddInt32(&c2, 1) + if c := atomic.AddInt32(&tc, 1); int(c) == total { + ch <- true + } + }) + natsFlush(t, ncsa2) + checkExpectedSubs(t, 2, sa1, sa2) + + qsubOnB2 := natsQueueSub(t, ncsb2, "foo", "bar", func(_ *nats.Msg) { + atomic.AddInt32(&c3, 1) + if c := atomic.AddInt32(&tc, 1); int(c) == total { + ch <- true + } + }) + natsFlush(t, ncsb2) + checkExpectedSubs(t, 1, sb2) + + // Publish all messages. The queue sub on cluster B should receive all + // messages. + for i := 0; i < total; i++ { + natsPub(t, ncsb2, "foo", []byte("msg")) + } + natsFlush(t, ncsb2) + + waitCh(t, ch, "Did not get all our queue messages") + if n := int(atomic.LoadInt32(&c1)); n != 0 { + t.Fatalf("No message should have been received by qsub1, got %v", n) + } + if n := int(atomic.LoadInt32(&c2)); n != 0 { + t.Fatalf("No message should have been received by qsub2, got %v", n) + } + if n := int(atomic.LoadInt32(&c3)); n != total { + t.Fatalf("All messages should have been delivered to qsub on B, got %v", n) + } + + // Reset counters + atomic.StoreInt32(&c1, 0) + atomic.StoreInt32(&c2, 0) + atomic.StoreInt32(&c3, 0) + atomic.StoreInt32(&tc, 0) + + // Now send from cluster A, messages should be distributed to qsubs on A. 
+ for i := 0; i < total; i++ { + natsPub(t, ncsa1, "foo", []byte("msg")) + } + natsFlush(t, ncsa1) + + expectedLow := int(float32(total/2) * 0.7) + expectedHigh := int(float32(total/2) * 1.3) + checkCount := func(t *testing.T, count *int32) { + t.Helper() + c := int(atomic.LoadInt32(count)) + if c < expectedLow || c > expectedHigh { + t.Fatalf("Expected value to be between %v/%v, got %v", expectedLow, expectedHigh, c) + } + } + waitCh(t, ch, "Did not get all our queue messages") + checkCount(t, &c1) + checkCount(t, &c2) + + // Now unsubscribe sub on B and reset counters + natsUnsub(t, qsubOnB2) + checkExpectedSubs(t, 0, sb2) + atomic.StoreInt32(&c1, 0) + atomic.StoreInt32(&c2, 0) + atomic.StoreInt32(&c3, 0) + atomic.StoreInt32(&tc, 0) + // Publish from cluster B, messages should be delivered to cluster A. + for i := 0; i < total; i++ { + natsPub(t, ncsb2, "foo", []byte("msg")) + } + natsFlush(t, ncsb2) + + waitCh(t, ch, "Did not get all our queue messages") + if n := int(atomic.LoadInt32(&c3)); n != 0 { + t.Fatalf("There should not have been messages on unsubscribed sub, got %v", n) + } + checkCount(t, &c1) + checkCount(t, &c2) +} + +func TestGatewayMsgSentOnlyOnce(t *testing.T) { + o2 := testDefaultOptionsForGateway("B") + s2 := runGatewayServer(o2) + defer s2.Shutdown() + + o1 := testGatewayOptionsFromToWithServers(t, "A", "B", s2) + s1 := runGatewayServer(o1) + defer s1.Shutdown() + + waitForOutboundGateways(t, s1, 1, time.Second) + waitForOutboundGateways(t, s2, 1, time.Second) + + s2Url := fmt.Sprintf("nats://127.0.0.1:%d", o2.Port) + nc2 := natsConnect(t, s2Url) + defer nc2.Close() + + s1Url := fmt.Sprintf("nats://127.0.0.1:%d", o1.Port) + nc1 := natsConnect(t, s1Url) + defer nc1.Close() + + ch := make(chan bool, 1) + count := int32(0) + expected := int32(4) + cb := func(_ *nats.Msg) { + if c := atomic.AddInt32(&count, 1); c == expected { + ch <- true + } + } + + // On s1, create 2 plain subs, 2 queue members for group + // "bar" and 1 for group "baz". 
+ natsSub(t, nc1, ">", cb) + natsSub(t, nc1, "foo", cb) + natsQueueSub(t, nc1, "foo", "bar", cb) + natsQueueSub(t, nc1, "foo", "bar", cb) + natsQueueSub(t, nc1, "foo", "baz", cb) + natsFlush(t, nc1) + + // Ensure subs registered in S1 + checkExpectedSubs(t, 5, s1) + + // From s2, send 1 message, s1 should receive 1 only, + // and total we should get the callback notified 4 times. + natsPub(t, nc2, "foo", []byte("hello")) + natsFlush(t, nc2) + + waitCh(t, ch, "Did not get our messages") + // Verify that count is still 4 + if c := atomic.LoadInt32(&count); c != expected { + t.Fatalf("Expected %v messages, got %v", expected, c) + } + // Check s2 outbound connection stats. It should say that it + // sent only 1 message. + c := s2.getOutboundGateway("A") + if c == nil { + t.Fatalf("S2 outbound gateway not found") + } + c.mu.Lock() + out := c.outMsgs + c.mu.Unlock() + if out != 1 { + t.Fatalf("Expected s2's outbound gateway to have sent a single message, got %v", out) + } + // Now check s1's inbound gateway + s1.gateway.RLock() + c = nil + for _, ci := range s1.gateway.in { + c = ci + break + } + s1.gateway.RUnlock() + if c == nil { + t.Fatalf("S1 inbound gateway not found") + } + c.mu.Lock() + in := c.inMsgs + c.mu.Unlock() + if in != 1 { + t.Fatalf("Expected s1's inbound gateway to have received a single message, got %v", in) + } +} + +type checkErrorLogger struct { + DummyLogger + checkErrorStr string + gotError bool +} + +func (l *checkErrorLogger) Errorf(format string, args ...interface{}) { + l.DummyLogger.Errorf(format, args...)
+ l.Lock() + if strings.Contains(l.msg, l.checkErrorStr) { + l.gotError = true + } + l.Unlock() +} + +func TestGatewayRoutedServerWithoutGatewayConfigured(t *testing.T) { + o2 := testDefaultOptionsForGateway("B") + s2 := runGatewayServer(o2) + defer s2.Shutdown() + + o1 := testGatewayOptionsFromToWithServers(t, "A", "B", s2) + s1 := runGatewayServer(o1) + defer s1.Shutdown() + + waitForOutboundGateways(t, s1, 1, time.Second) + waitForOutboundGateways(t, s2, 1, time.Second) + + o3 := DefaultOptions() + o3.Routes = RoutesFromStr(fmt.Sprintf("nats://127.0.0.1:%d", s2.ClusterAddr().Port)) + s3 := New(o3) + defer s3.Shutdown() + l := &checkErrorLogger{checkErrorStr: "not configured"} + s3.SetLogger(l, true, true) + wg := sync.WaitGroup{} + wg.Add(1) + go func() { + s3.Start() + wg.Done() + }() + + checkClusterFormed(t, s2, s3) + + // Check that server s3 does not panic when being notified + // about the A gateway, but report an error. + deadline := time.Now().Add(2 * time.Second) + gotIt := false + for time.Now().Before(deadline) { + l.Lock() + gotIt = l.gotError + l.Unlock() + if gotIt { + break + } + time.Sleep(15 * time.Millisecond) + } + if !gotIt { + t.Fatalf("Should have reported error about gateway not configured") + } + + s3.Shutdown() + wg.Wait() +} + +func TestGatewaySendsToNonLocalSubs(t *testing.T) { + ob1 := testDefaultOptionsForGateway("B") + sb1 := runGatewayServer(ob1) + defer sb1.Shutdown() + + oa1 := testGatewayOptionsFromToWithServers(t, "A", "B", sb1) + sa1 := runGatewayServer(oa1) + defer sa1.Shutdown() + + waitForOutboundGateways(t, sa1, 1, time.Second) + waitForOutboundGateways(t, sb1, 1, time.Second) + + waitForInboundGateways(t, sa1, 1, time.Second) + waitForInboundGateways(t, sb1, 1, time.Second) + + oa2 := testGatewayOptionsFromToWithServers(t, "A", "B", sb1) + oa2.Routes = RoutesFromStr(fmt.Sprintf("nats://127.0.0.1:%d", sa1.ClusterAddr().Port)) + sa2 := runGatewayServer(oa2) + defer sa2.Shutdown() + + checkClusterFormed(t, sa1, sa2) + + 
waitForOutboundGateways(t, sa2, 1, time.Second) + waitForInboundGateways(t, sb1, 2, time.Second) + + ch := make(chan bool, 1) + // Create an interest of sa2 + ncSub := natsConnect(t, fmt.Sprintf("nats://127.0.0.1:%d", oa2.Port)) + defer ncSub.Close() + natsSub(t, ncSub, "foo", func(_ *nats.Msg) { ch <- true }) + natsFlush(t, ncSub) + checkExpectedSubs(t, 1, sa1, sa2) + + // Produce a message from sb1, make sure it can be received. + ncPub := natsConnect(t, fmt.Sprintf("nats://127.0.0.1:%d", ob1.Port)) + defer ncPub.Close() + natsPub(t, ncPub, "foo", []byte("hello")) + waitCh(t, ch, "Did not get our message") + + ncSub.Close() + ncPub.Close() + checkExpectedSubs(t, 0, sa1, sa2) + + // Now create sb2 that has a route to sb1 and gateway connects to sa2. + ob2 := testGatewayOptionsFromToWithServers(t, "B", "A", sa2) + ob2.Routes = RoutesFromStr(fmt.Sprintf("nats://127.0.0.1:%d", sb1.ClusterAddr().Port)) + sb2 := runGatewayServer(ob2) + defer sb2.Shutdown() + + ncSub = natsConnect(t, fmt.Sprintf("nats://127.0.0.1:%d", oa1.Port)) + defer ncSub.Close() + natsSub(t, ncSub, "foo", func(_ *nats.Msg) { ch <- true }) + natsFlush(t, ncSub) + checkExpectedSubs(t, 1, sa1, sa2) + + ncPub = natsConnect(t, fmt.Sprintf("nats://127.0.0.1:%d", ob2.Port)) + defer ncPub.Close() + natsPub(t, ncPub, "foo", []byte("hello")) + waitCh(t, ch, "Did not get our message") +} + +func TestGatewayUnknownGatewayCommand(t *testing.T) { + o1 := testDefaultOptionsForGateway("A") + s1 := runGatewayServer(o1) + defer s1.Shutdown() + + l := &checkErrorLogger{checkErrorStr: "Unknown command"} + s1.SetLogger(l, true, true) + + o2 := testDefaultOptionsForGateway("A") + o2.Routes = RoutesFromStr(fmt.Sprintf("nats://127.0.0.1:%d", s1.ClusterAddr().Port)) + s2 := runGatewayServer(o2) + defer s2.Shutdown() + + checkClusterFormed(t, s1, s2) + + var route *client + s2.mu.Lock() + for _, r := range s2.routes { + route = r + break + } + s2.mu.Unlock() + + route.mu.Lock() + info := &Info{ + Gateway: "B", + GatewayCmd: 
255, + } + b, _ := json.Marshal(info) + route.sendProto([]byte(fmt.Sprintf(InfoProto, b)), true) + route.mu.Unlock() + + checkFor(t, time.Second, 15*time.Millisecond, func() error { + l.Lock() + gotIt := l.gotError + l.Unlock() + if gotIt { + return nil + } + return fmt.Errorf("Did not get expected error") + }) +} + +func TestGatewayRandomIP(t *testing.T) { + ob := testDefaultOptionsForGateway("B") + sb := runGatewayServer(ob) + defer sb.Shutdown() + + oa := testGatewayOptionsFromToWithURLs(t, "A", "B", + []string{ + "nats://noport", + fmt.Sprintf("nats://localhost:%d", sb.GatewayAddr().Port), + }) + // Create a dummy resolver that returns error since we + // don't provide any IP. The code should then use the configured + // url (localhost:port) and try with that, which in this case + // should work. + oa.Gateway.resolver = &myDummyDNSResolver{} + sa := runGatewayServer(oa) + defer sa.Shutdown() + + waitForOutboundGateways(t, sa, 1, 2*time.Second) + waitForOutboundGateways(t, sb, 1, 2*time.Second) +} + +/* +func TestGatewayPermissions(t *testing.T) { + bo := testDefaultOptionsForGateway("B") + sb := runGatewayServer(bo) + defer sb.Shutdown() + + ao := testGatewayOptionsFromToWithServers(t, "A", "B", sb) + // test setup by default sets import and export to ">". + // For this test, we override. + ao.Gateway.Permissions.Import = &SubjectPermission{Allow: []string{"foo"}} + ao.Gateway.Permissions.Export = &SubjectPermission{Allow: []string{"bar"}} + sa := runGatewayServer(ao) + defer sa.Shutdown() + + waitForOutboundGateways(t, sa, 1, time.Second) + waitForOutboundGateways(t, sb, 1, time.Second) + + // Create client connections, one of cluster A, one on B. + nca := natsConnect(t, fmt.Sprintf("nats://127.0.0.1:%d", ao.Port)) + defer nca.Close() + ncb := natsConnect(t, fmt.Sprintf("nats://127.0.0.1:%d", bo.Port)) + defer ncb.Close() + + ch := make(chan bool, 1) + cb := func(m *nats.Msg) { + ch <- true + } + + // Check import permissions... 
+ + // Create a local sub on "foo" on Cluster A. + natsSub(t, nca, "foo", cb) + natsFlush(t, nca) + + // Message should be received + + natsPub(t, ncb, "foo", []byte("message from B to A on foo")) + natsFlush(t, ncb) + + waitCh(t, ch, "Did not get message on foo") + + // Create a sub on "baz" on Cluster A, no message should be received on that one + natsSub(t, nca, "baz", cb) + natsFlush(t, nca) + + natsPub(t, ncb, "baz", []byte("message from B to A on baz")) + natsFlush(t, ncb) + + select { + case <-ch: + t.Fatalf("Message should not have been received") + case <-time.After(250 * time.Millisecond): + // still no message, we are ok. + } + + // Check export permissions now... + + // Create sub on "bar" on Cluster B + natsSub(t, ncb, "bar", cb) + natsFlush(t, ncb) + + // Send from A, that should be possible so message should be received. + natsPub(t, nca, "bar", []byte("message from A to B on bar")) + natsFlush(t, nca) + + waitCh(t, ch, "Did not get message on bar") + + // Create a sub on "bozo" on Cluster B + natsSub(t, ncb, "bozo", cb) + natsFlush(t, ncb) + + // That message should not be received + natsPub(t, nca, "bozo", []byte("message from A to B on bozo")) + natsFlush(t, nca) + + select { + case <-ch: + t.Fatalf("Message should not have been received") + case <-time.After(250 * time.Millisecond): + // still no message, we are ok. + } +} + +func TestGatewayDefaultPermissions(t *testing.T) { + bo := testDefaultOptionsForGateway("B") + // test setup by default sets import and export to ">". + // For this test, we override. 
+ bo.Gateway.DefaultPermissions.Import = &SubjectPermission{Allow: []string{"foo"}} + bo.Gateway.DefaultPermissions.Export = &SubjectPermission{Allow: []string{"bar", "baz"}} + sb := runGatewayServer(bo) + defer sb.Shutdown() + + ao := testGatewayOptionsFromToWithServers(t, "A", "B", sb) + sa := runGatewayServer(ao) + defer sa.Shutdown() + + waitForOutboundGateways(t, sa, 1, time.Second) + waitForOutboundGateways(t, sb, 1, time.Second) + + // Check permissions on cluster B. Since there was no + // explicit gateway defined to A, when server on B accepted + // the gateway connection, it should have "inherited" the + // default permissions. + gw := sb.getRemoteGateway("A") + if gw == nil { + t.Fatalf("There should be a remote gateway for A") + } + gw.RLock() + impc := gw.imports.Count() + expc := gw.exports.Count() + gw.RUnlock() + + if impc != 1 { + t.Fatalf("Expected import sublist to be size 1, got %v", impc) + } + if expc != 2 { + t.Fatalf("Expected export sublist to be size 2, got %v", expc) + } + + // Check server on Cluster A too. By default, tests create + // gateway permissions and remote gateways default permissions + // for import/export to be ">". 
+ gw = sa.getRemoteGateway("B") + if gw == nil { + t.Fatalf("There should be a remote gateway for B") + } + gw.RLock() + impc = gw.imports.Count() + expc = gw.exports.Count() + gw.RUnlock() + if impc != 1 { + t.Fatalf("Expected import sublist to be size 1, got %v", impc) + } + if expc != 1 { + t.Fatalf("Expected export sublist to be size 1, got %v", expc) + } +} +*/ diff --git a/server/monitor.go b/server/monitor.go index c181487d..c5bb4d71 100644 --- a/server/monitor.go +++ b/server/monitor.go @@ -1045,6 +1045,8 @@ func (reason ClosedState) String() string { return "Server Shutdown" case AuthenticationExpired: return "Authentication Expired" + case WrongGateway: + return "Wrong Gateway" } return "Unknown State" } diff --git a/server/opts.go b/server/opts.go index a84d77ce..f6104cb9 100644 --- a/server/opts.go +++ b/server/opts.go @@ -14,6 +14,7 @@ package server import ( + "context" "crypto/tls" "crypto/x509" "errors" @@ -47,6 +48,35 @@ type ClusterOpts struct { ConnectRetries int `json:"-"` } +// GatewayOpts are options for gateways. +type GatewayOpts struct { + Name string `json:"name"` + Host string `json:"addr,omitempty"` + Port int `json:"port,omitempty"` + Username string `json:"-"` + Password string `json:"-"` + AuthTimeout float64 `json:"auth_timeout,omitempty"` + TLSConfig *tls.Config `json:"-"` + TLSTimeout float64 `json:"tls_timeout,omitempty"` + Advertise string `json:"advertise,omitempty"` + ConnectRetries int `json:"connect_retries,omitempty"` + DefaultPermissions *GatewayPermissions `json:"default_permissions,omitempty"` + Gateways []*RemoteGatewayOpts `json:"gateways,omitempty"` + RejectUnknown bool `json:"reject_unknown,omitempty"` + + // Not exported, for tests. 
+ resolver netResolver +} + +// RemoteGatewayOpts are options for connecting to a remote gateway +type RemoteGatewayOpts struct { + Name string `json:"name"` + TLSConfig *tls.Config `json:"-"` + TLSTimeout float64 `json:"tls_timeout,omitempty"` + URLs []*url.URL `json:"urls,omitempty"` + Permissions *GatewayPermissions `json:"permissions,omitempty"` +} + // Options block for gnatsd server. type Options struct { ConfigFile string `json:"-"` @@ -77,6 +107,7 @@ type Options struct { MaxPayload int `json:"max_payload"` MaxPending int64 `json:"max_pending"` Cluster ClusterOpts `json:"cluster,omitempty"` + Gateway GatewayOpts `json:"gateway,omitempty"` ProfPort int `json:"-"` PidFile string `json:"-"` PortsFileDir string `json:"-"` @@ -103,6 +134,13 @@ type Options struct { // CheckConfig configuration file syntax test was successful and exit. CheckConfig bool `json:"-"` + + // private fields, used for testing + gatewaysSolicitDelay time.Duration +} + +type netResolver interface { + LookupHost(ctx context.Context, host string) ([]string, error) } // Clone performs a deep copy of the Options struct, returning a new clone @@ -127,12 +165,7 @@ func (o *Options) Clone() *Options { } if o.Routes != nil { - clone.Routes = make([]*url.URL, len(o.Routes)) - for i, route := range o.Routes { - routeCopy := &url.URL{} - *routeCopy = *route - clone.Routes[i] = routeCopy - } + clone.Routes = deepCopyURLs(o.Routes) } if o.TLSConfig != nil { clone.TLSConfig = o.TLSConfig.Clone() @@ -140,9 +173,31 @@ func (o *Options) Clone() *Options { if o.Cluster.TLSConfig != nil { clone.Cluster.TLSConfig = o.Cluster.TLSConfig.Clone() } + if o.Gateway.TLSConfig != nil { + clone.Gateway.TLSConfig = o.Gateway.TLSConfig.Clone() + } + if len(o.Gateway.Gateways) > 0 { + clone.Gateway.Gateways = make([]*RemoteGatewayOpts, len(o.Gateway.Gateways)) + for i, g := range o.Gateway.Gateways { + clone.Gateway.Gateways[i] = g.clone() + } + } return clone } +func deepCopyURLs(urls []*url.URL) []*url.URL { + if urls 
== nil { + return nil + } + curls := make([]*url.URL, len(urls)) + for i, u := range urls { + cu := &url.URL{} + *cu = *u + curls[i] = cu + } + return curls +} + // Configuration file authorization section. type authorization struct { // Singles @@ -350,6 +405,11 @@ func (o *Options) ProcessConfigFile(configFile string) error { errors = append(errors, err) continue } + case "gateway": + if err := parseGateway(tk, o, &errors, &warnings); err != nil { + errors = append(errors, err) + continue + } case "logfile", "log_file": o.LogFile = v.(string) case "syslog": @@ -377,7 +437,7 @@ func (o *Options) ProcessConfigFile(configFile string) error { case "ping_max": o.MaxPingsOut = int(v.(int64)) case "tls": - tc, err := parseTLS(tk, o) + tc, err := parseTLS(tk) if err != nil { errors = append(errors, err) continue @@ -556,35 +616,20 @@ func parseCluster(v interface{}, opts *Options, errors *[]error, warnings *[]err } case "routes": ra := mv.([]interface{}) - opts.Routes = make([]*url.URL, 0, len(ra)) - for _, r := range ra { - tk, r := unwrapValue(r) - routeURL := r.(string) - url, err := url.Parse(routeURL) - if err != nil { - err := &configErr{tk, fmt.Sprintf("error parsing route url [%q]", routeURL)} - *errors = append(*errors, err) - continue - } - opts.Routes = append(opts.Routes, url) + routes, errs := parseURLs(ra, "route") + if errs != nil { + *errors = append(*errors, errs...) + continue } + opts.Routes = routes case "tls": - tc, err := parseTLS(tk, opts) + config, timeout, err := getTLSConfig(tk) if err != nil { *errors = append(*errors, err) continue } - if opts.Cluster.TLSConfig, err = GenTLSConfig(tc); err != nil { - err := &configErr{tk, err.Error()} - *errors = append(*errors, err) - continue - } - // For clusters, we will force strict verification. We also act - // as both client and server, so will mirror the rootCA to the - // clientCA pool. 
- opts.Cluster.TLSConfig.ClientAuth = tls.RequireAndVerifyClientCert - opts.Cluster.TLSConfig.RootCAs = opts.Cluster.TLSConfig.ClientCAs - opts.Cluster.TLSTimeout = tc.Timeout + opts.Cluster.TLSConfig = config + opts.Cluster.TLSTimeout = timeout case "cluster_advertise", "advertise": opts.Cluster.Advertise = mv.(string) case "no_advertise": @@ -592,7 +637,7 @@ func parseCluster(v interface{}, opts *Options, errors *[]error, warnings *[]err case "connect_retries": opts.Cluster.ConnectRetries = int(mv.(int64)) case "permissions": - perms, err := parseUserPermissions(mv, opts, errors, warnings) + perms, err := parseUserPermissions(mv, errors, warnings) if err != nil { *errors = append(*errors, err) continue @@ -615,6 +660,206 @@ func parseCluster(v interface{}, opts *Options, errors *[]error, warnings *[]err return nil } +func parseURLs(a []interface{}, typ string) ([]*url.URL, []error) { + var ( + errors []error + urls = make([]*url.URL, 0, len(a)) + ) + for _, u := range a { + tk, u := unwrapValue(u) + sURL := u.(string) + url, err := parseURL(sURL, typ) + if err != nil { + err := &configErr{tk, err.Error()} + errors = append(errors, err) + continue + } + urls = append(urls, url) + } + return urls, errors +} + +func parseURL(u string, typ string) (*url.URL, error) { + urlStr := strings.TrimSpace(u) + url, err := url.Parse(urlStr) + if err != nil { + return nil, fmt.Errorf("error parsing %s url [%q]", typ, urlStr) + } + return url, nil +} + +func parseGateway(v interface{}, o *Options, errors *[]error, warnings *[]error) error { + tk, v := unwrapValue(v) + gm, ok := v.(map[string]interface{}) + if !ok { + return &configErr{tk, fmt.Sprintf("Expected gateway to be a map, got %T", v)} + } + for mk, mv := range gm { + // Again, unwrap token value if line check is required. 
+ tk, mv = unwrapValue(mv) + switch strings.ToLower(mk) { + case "name": + o.Gateway.Name = mv.(string) + case "listen": + hp, err := parseListen(mv) + if err != nil { + err := &configErr{tk, err.Error()} + *errors = append(*errors, err) + continue + } + o.Gateway.Host = hp.host + o.Gateway.Port = hp.port + case "port": + o.Gateway.Port = int(mv.(int64)) + case "host", "net": + o.Gateway.Host = mv.(string) + case "authorization": + auth, err := parseAuthorization(tk, o, errors, warnings) + if err != nil { + *errors = append(*errors, err) + continue + } + if auth.users != nil { + *errors = append(*errors, &configErr{tk, "Gateway authorization does not allow multiple users"}) + continue + } + o.Gateway.Username = auth.user + o.Gateway.Password = auth.pass + o.Gateway.AuthTimeout = auth.timeout + case "tls": + config, timeout, err := getTLSConfig(tk) + if err != nil { + *errors = append(*errors, err) + continue + } + o.Gateway.TLSConfig = config + o.Gateway.TLSTimeout = timeout + case "advertise": + o.Gateway.Advertise = mv.(string) + case "connect_retries": + o.Gateway.ConnectRetries = int(mv.(int64)) + case "default_permissions": + perms, err := parseGatewayPermissions(mv, errors, warnings) + if err != nil { + *errors = append(*errors, err) + continue + } + o.Gateway.DefaultPermissions = perms + case "gateways": + gateways, err := parseGateways(mv, errors, warnings) + if err != nil { + return err + } + o.Gateway.Gateways = gateways + case "reject_unknown": + o.Gateway.RejectUnknown = mv.(bool) + default: + if !tk.IsUsedVariable() { + err := &unknownConfigFieldErr{ + field: mk, + configErr: configErr{ + token: tk, + }, + } + *errors = append(*errors, err) + continue + } + } + } + return nil +} + +// Parse TLS and returns a TLSConfig and TLSTimeout. +// Used by cluster and gateway parsing. 
+func getTLSConfig(tk token) (*tls.Config, float64, error) { + tc, err := parseTLS(tk) + if err != nil { + return nil, 0, err + } + config, err := GenTLSConfig(tc) + if err != nil { + err := &configErr{tk, err.Error()} + return nil, 0, err + } + // For clusters/gateways, we will force strict verification. We also act + // as both client and server, so will mirror the rootCA to the + // clientCA pool. + config.ClientAuth = tls.RequireAndVerifyClientCert + config.RootCAs = config.ClientCAs + return config, tc.Timeout, nil +} + +func parseGateways(v interface{}, errors *[]error, warnings *[]error) ([]*RemoteGatewayOpts, error) { + tk, v := unwrapValue(v) + // Make sure we have an array + ga, ok := v.([]interface{}) + if !ok { + return nil, &configErr{tk, fmt.Sprintf("Expected gateways field to be an array, got %T", v)} + } + gateways := []*RemoteGatewayOpts{} + for _, g := range ga { + tk, g = unwrapValue(g) + // Check its a map/struct + gm, ok := g.(map[string]interface{}) + if !ok { + *errors = append(*errors, &configErr{tk, fmt.Sprintf("Expected gateway entry to be a map/struct, got %v", g)}) + continue + } + gateway := &RemoteGatewayOpts{} + for k, v := range gm { + tk, v = unwrapValue(v) + switch strings.ToLower(k) { + case "name": + gateway.Name = v.(string) + case "tls": + tls, timeout, err := getTLSConfig(tk) + if err != nil { + *errors = append(*errors, err) + continue + } + gateway.TLSConfig = tls + gateway.TLSTimeout = timeout + case "url": + url, err := parseURL(v.(string), "gateway") + if err != nil { + *errors = append(*errors, &configErr{tk, err.Error()}) + continue + } + gateway.URLs = append(gateway.URLs, url) + case "urls": + urls, errs := parseURLs(v.([]interface{}), "gateway") + if errs != nil { + for _, e := range errs { + *errors = append(*errors, e) + } + continue + } + gateway.URLs = urls + case "permissions": + perms, err := parseGatewayPermissions(v, errors, warnings) + if err != nil { + *errors = append(*errors, err) + continue + } + 
gateway.Permissions = perms + default: + if !tk.IsUsedVariable() { + err := &unknownConfigFieldErr{ + field: k, + configErr: configErr{ + token: tk, + }, + } + *errors = append(*errors, err) + continue + } + } + } + gateways = append(gateways, gateway) + } + return gateways, nil +} + // Sets cluster's permissions based on given pub/sub permissions, // doing the appropriate translation. func setClusterPermissions(opts *ClusterOpts, perms *Permissions) { @@ -1168,7 +1413,7 @@ func parseAuthorization(v interface{}, opts *Options, errors *[]error, warnings auth.users = users auth.nkeys = nkeys case "default_permission", "default_permissions", "permissions": - permissions, err := parseUserPermissions(tk, opts, errors, warnings) + permissions, err := parseUserPermissions(tk, errors, warnings) if err != nil { *errors = append(*errors, err) continue @@ -1243,7 +1488,7 @@ func parseUsers(mv interface{}, opts *Options, errors *[]error, warnings *[]erro case "pass", "password": user.Password = v.(string) case "permission", "permissions", "authorization": - perms, err = parseUserPermissions(tk, opts, errors, warnings) + perms, err = parseUserPermissions(tk, errors, warnings) if err != nil { *errors = append(*errors, err) continue @@ -1292,7 +1537,7 @@ func parseUsers(mv interface{}, opts *Options, errors *[]error, warnings *[]erro } // Helper function to parse user/account permissions -func parseUserPermissions(mv interface{}, opts *Options, errors, warnings *[]error) (*Permissions, error) { +func parseUserPermissions(mv interface{}, errors, warnings *[]error) (*Permissions, error) { var ( tk token p = &Permissions{} @@ -1310,14 +1555,14 @@ func parseUserPermissions(mv interface{}, opts *Options, errors, warnings *[]err // Import is Publish // Export is Subscribe case "pub", "publish", "import": - perms, err := parseVariablePermissions(v, opts, errors, warnings) + perms, err := parseVariablePermissions(v, errors, warnings) if err != nil { *errors = append(*errors, err) 
continue } p.Publish = perms case "sub", "subscribe", "export": - perms, err := parseVariablePermissions(v, opts, errors, warnings) + perms, err := parseVariablePermissions(v, errors, warnings) if err != nil { *errors = append(*errors, err) continue @@ -1334,17 +1579,57 @@ func parseUserPermissions(mv interface{}, opts *Options, errors, warnings *[]err } // Top level parser for authorization configurations. -func parseVariablePermissions(v interface{}, opts *Options, errors, warnings *[]error) (*SubjectPermission, error) { +func parseVariablePermissions(v interface{}, errors, warnings *[]error) (*SubjectPermission, error) { switch vv := v.(type) { case map[string]interface{}: // New style with allow and/or deny properties. - return parseSubjectPermission(vv, opts, errors, warnings) + return parseSubjectPermission(vv, errors, warnings) default: // Old style return parseOldPermissionStyle(v, errors, warnings) } } +// Helper function to parse gateway permissions +func parseGatewayPermissions(v interface{}, errors *[]error, warnings *[]error) (*GatewayPermissions, error) { + tk, v := unwrapValue(v) + pm, ok := v.(map[string]interface{}) + if !ok { + return nil, &configErr{tk, fmt.Sprintf("Expected permissions to be a map/struct, got %+v", v)} + } + perms := &GatewayPermissions{} + for k, v := range pm { + tk, v := unwrapValue(v) + switch strings.ToLower(k) { + case "import": + sp, err := parseVariablePermissions(v, errors, warnings) + if err != nil { + *errors = append(*errors, err) + continue + } + perms.Import = sp + case "export": + sp, err := parseVariablePermissions(v, errors, warnings) + if err != nil { + *errors = append(*errors, err) + continue + } + perms.Export = sp + default: + if !tk.IsUsedVariable() { + err := &unknownConfigFieldErr{ + field: k, + configErr: configErr{ + token: tk, + }, + } + *errors = append(*errors, err) + } + } + } + return perms, nil +} + // Helper function to parse subject singeltons and/or arrays func parseSubjects(v interface{}, 
errors, warnings *[]error) ([]string, error) { tk, v := unwrapValue(v) @@ -1384,7 +1669,7 @@ func parseOldPermissionStyle(v interface{}, errors, warnings *[]error) (*Subject } // Helper function to parse new style authorization into a SubjectPermission with Allow and Deny. -func parseSubjectPermission(v interface{}, opts *Options, errors, warnings *[]error) (*SubjectPermission, error) { +func parseSubjectPermission(v interface{}, errors, warnings *[]error) (*SubjectPermission, error) { m := v.(map[string]interface{}) if len(m) == 0 { return nil, nil @@ -1458,7 +1743,7 @@ func parseCurvePreferences(curveName string) (tls.CurveID, error) { } // Helper function to parse TLS configs. -func parseTLS(v interface{}, opts *Options) (*TLSConfigOpts, error) { +func parseTLS(v interface{}) (*TLSConfigOpts, error) { var ( tlsm map[string]interface{} tc = TLSConfigOpts{} @@ -1831,6 +2116,11 @@ func processOptions(opts *Options) { if opts.LameDuckDuration == 0 { opts.LameDuckDuration = DEFAULT_LAME_DUCK_DURATION } + if opts.Gateway.Port != 0 { + if opts.Gateway.Host == "" { + opts.Gateway.Host = DEFAULT_HOST + } + } } // ConfigureOptions accepts a flag set and augment it with NATS Server diff --git a/server/opts_test.go b/server/opts_test.go index 4a6432a8..fe210e50 100644 --- a/server/opts_test.go +++ b/server/opts_test.go @@ -973,6 +973,13 @@ func TestOptionsClone(t *testing.T) { NoAdvertise: true, ConnectRetries: 2, }, + Gateway: GatewayOpts{ + Name: "A", + Gateways: []*RemoteGatewayOpts{ + {Name: "B", URLs: []*url.URL{&url.URL{Scheme: "nats", Host: "host:5222"}}}, + {Name: "C"}, + }, + }, WriteDeadline: 3 * time.Second, Routes: []*url.URL{&url.URL{}}, Users: []*User{&User{Username: "foo", Password: "bar"}}, @@ -989,6 +996,14 @@ func TestOptionsClone(t *testing.T) { if reflect.DeepEqual(opts, clone) { t.Fatal("Expected Options to be different") } + + opts.Gateway.Gateways[0].URLs[0] = nil + if reflect.DeepEqual(opts.Gateway.Gateways[0], clone.Gateway.Gateways[0]) { + 
t.Fatal("Expected Options to be different") + } + if clone.Gateway.Gateways[0].URLs[0].Host != "host:5222" { + t.Fatalf("Unexpected URL: %v", clone.Gateway.Gateways[0].URLs[0]) + } } func TestOptionsCloneNilLists(t *testing.T) { @@ -1449,3 +1464,478 @@ func TestAccountUsersLoadedProperly(t *testing.T) { check(t) } } + +func TestParsingGateways(t *testing.T) { + content := ` + gateway { + name: "A" + listen: "127.0.0.1:4444" + host: "127.0.0.1" + port: 4444 + authorization { + user: "ivan" + password: "pwd" + timeout: 2.0 + } + tls { + cert_file: "./configs/certs/server.pem" + key_file: "./configs/certs/key.pem" + timeout: 3.0 + } + advertise: "me:1" + connect_retries: 10 + gateways: [ + { + name: "B" + urls: ["nats://user1:pwd1@host2:5222", "nats://user1:pwd1@host3:6222"] + } + { + name: "C" + url: "nats://host4:7222" + } + ] + } + ` + file := "server_config_gateways.conf" + defer os.Remove(file) + if err := ioutil.WriteFile(file, []byte(content), 0600); err != nil { + t.Fatalf("Error writing config file: %v", err) + } + opts, err := ProcessConfigFile(file) + if err != nil { + t.Fatalf("Error processing file: %v", err) + } + + expected := &GatewayOpts{ + Name: "A", + Host: "127.0.0.1", + Port: 4444, + Username: "ivan", + Password: "pwd", + AuthTimeout: 2.0, + Advertise: "me:1", + ConnectRetries: 10, + TLSTimeout: 3.0, + } + u1, _ := url.Parse("nats://user1:pwd1@host2:5222") + u2, _ := url.Parse("nats://user1:pwd1@host3:6222") + urls := []*url.URL{u1, u2} + gw := &RemoteGatewayOpts{ + Name: "B", + URLs: urls, + } + expected.Gateways = append(expected.Gateways, gw) + + u1, _ = url.Parse("nats://host4:7222") + urls = []*url.URL{u1} + gw = &RemoteGatewayOpts{ + Name: "C", + URLs: urls, + } + expected.Gateways = append(expected.Gateways, gw) + + // Just make sure that TLSConfig is set.. we have aother test + // to check proper generating TLSConfig from config file... 
+ if opts.Gateway.TLSConfig == nil { + t.Fatalf("Expected TLSConfig, got none") + } + opts.Gateway.TLSConfig = nil + if !reflect.DeepEqual(&opts.Gateway, expected) { + t.Fatalf("Expected %v, got %v", expected, opts.Gateway) + } +} + +func TestParsingGatewaysErrors(t *testing.T) { + for _, test := range []struct { + name string + content string + expectedErr string + }{ + { + "bad_type", + `gateway: "bad_type"`, + "Expected gateway to be a map", + }, + { + "bad_listen", + `gateway { + name: "A" + port: -1 + listen: "bad::address" + }`, + "parse address", + }, + { + "bad_auth", + `gateway { + name: "A" + port: -1 + authorization { + users { + } + } + }`, + "be an array", + }, + { + "unknown_field", + `gateway { + name: "A" + port: -1 + reject_unknown: true + unknown_field: 1 + }`, + "unknown field", + }, + { + "users_not_supported", + `gateway { + name: "A" + port: -1 + authorization { + users [ + {user: alice, password: foo} + {user: bob, password: bar} + ] + } + }`, + "does not allow multiple users", + }, + { + "tls_error", + `gateway { + name: "A" + port: -1 + tls { + cert_file: 123 + } + }`, + "to be filename", + }, + { + "tls_gen_error", + `gateway { + name: "A" + port: -1 + tls { + cert_file: "./configs/certs/server.pem" + } + }`, + "certificate/key pair", + }, + { + "gateways_needs_to_be_an_array", + `gateway { + name: "A" + gateways { + name: "B" + } + }`, + "Expected gateways field to be an array", + }, + { + "gateways_entry_needs_to_be_a_map", + `gateway { + name: "A" + gateways [ + "g1", "g2" + ] + }`, + "Expected gateway entry to be a map", + }, + { + "bad_url", + `gateway { + name: "A" + gateways [ + { + name: "B" + url: "nats://wrong url" + } + ] + }`, + "error parsing gateway url", + }, + { + "bad_urls", + `gateway { + name: "A" + gateways [ + { + name: "B" + urls: ["nats://wrong url", "nats://host:5222"] + } + ] + }`, + "error parsing gateway url", + }, + { + "gateway_tls_error", + `gateway { + name: "A" + port: -1 + gateways [ + { + name: "B" + tls { 
+ cert_file: 123 + } + } + ] + }`, + "to be filename", + }, + { + "gateway_unknon_field", + `gateway { + name: "A" + port: -1 + gateways [ + { + name: "B" + unknown_field: 1 + } + ] + }`, + "unknown field", + }, + { + "default_permissions_bad_type", + `gateway { + name: "A" + default_permissions: shoul_be_a_map + }`, + "Expected permissions to be a map/struct", + }, + { + "default_permissions_unknown_field", + `gateway { + name: "A" + default_permissions { + import: "foo" + export: ["bar", "baz"] + unknown_field: 1 + } + }`, + "unknown field", + }, + { + "default_permissions_bad_import_subjects", + `gateway { + name: "A" + default_permissions { + import: { + "foo" + "bar" + } + } + }`, + "only 'allow' or 'deny' are permitted", + }, + { + "default_permissions_bad_import_allow_subjects", + `gateway { + name: "A" + default_permissions { + import: { + allow: { + "foo" + "bar" + } + } + } + }`, + "Expected subject permissions to be a subject, or array of subjects", + }, + { + "default_permissions_bad_import_deny_subjects", + `gateway { + name: "A" + default_permissions { + import: { + deny: { + "foo" + "bar" + } + } + } + }`, + "Expected subject permissions to be a subject, or array of subjects", + }, + { + "default_permissions_bad_export_subjects", + `gateway { + name: "A" + default_permissions { + export: { + "foo" + "bar" + } + } + }`, + "only 'allow' or 'deny' are permitted", + }, + { + "default_permissions_bad_export_allow_subjects", + `gateway { + name: "A" + default_permissions { + export: { + allow { + "foo" + "bar" + } + } + } + }`, + "Expected subject permissions to be a subject, or array of subjects", + }, + { + "default_permissions_bad_export_deny_subjects", + `gateway { + name: "A" + default_permissions { + export: { + deny { + "foo" + "bar" + } + } + } + }`, + "Expected subject permissions to be a subject, or array of subjects", + }, + { + "gateways_permissions_bad_type", + `gateway { + name: "A" + gateways [ + { + name: "B" + url: "nats://localhost:4222" 
+ permissions: should_be_a_map + } + ] + }`, + "Expected permissions to be a map/struct", + }, + { + "gateways_permissions_unknown_field", + `gateway { + name: "A" + gateways [ + { + name: "B" + url: "nats://localhost:4222" + permissions { + import: "foo" + export: ["bar", "baz"] + unknown_field: 1 + } + } + ] + }`, + "unknown field", + }, + { + "gateways_permissions_bad_import_subjects", + `gateway { + name: "A" + default_permissions { + import: { + "foo" + "bar" + } + } + }`, + "only 'allow' or 'deny' are permitted", + }, + { + "gateways_permissions_bad_import_allow_subjects", + `gateway { + name: "A" + default_permissions { + import: { + allow { + "foo" + "bar" + } + } + } + }`, + "Expected subject permissions to be a subject, or array of subjects", + }, + { + "gateways_permissions_bad_import_deny_subjects", + `gateway { + name: "A" + default_permissions { + import: { + deny { + "foo" + "bar" + } + } + } + }`, + "Expected subject permissions to be a subject, or array of subjects", + }, + { + "gateways_permissions_bad_export_subjects", + `gateway { + name: "A" + default_permissions { + export: { + "foo" + "bar" + } + } + }`, + "only 'allow' or 'deny' are permitted", + }, + { + "gateways_permissions_bad_export_allow_subjects", + `gateway { + name: "A" + default_permissions { + import: { + allow { + "foo" + "bar" + } + } + } + }`, + "Expected subject permissions to be a subject, or array of subjects", + }, + { + "gateways_permissions_bad_export_deny_subjects", + `gateway { + name: "A" + default_permissions { + import: { + deny { + "foo" + "bar" + } + } + } + }`, + "Expected subject permissions to be a subject, or array of subjects", + }, + } { + t.Run(test.name, func(t *testing.T) { + file := fmt.Sprintf("server_config_gateways_%s.conf", test.name) + defer os.Remove(file) + if err := ioutil.WriteFile(file, []byte(test.content), 0600); err != nil { + t.Fatalf("Error writing config file: %v", err) + } + _, err := ProcessConfigFile(file) + if err == nil { + 
t.Fatalf("Expected to fail, did not. Content:\n%s\n", test.content) + } else if !strings.Contains(err.Error(), test.expectedErr) { + t.Fatalf("Expected error containing %q, got %q, for content:\n%s\n", test.expectedErr, err, test.content) + } + }) + } +} diff --git a/server/parser.go b/server/parser.go index ee33daaa..6c66f402 100644 --- a/server/parser.go +++ b/server/parser.go @@ -388,10 +388,13 @@ func (c *client) parse(buf []byte) error { arg = buf[c.as : i-c.drop] } var err error - if c.typ == CLIENT { + switch c.typ { + case CLIENT: err = c.processSub(arg) - } else { + case ROUTER: err = c.processRemoteSub(arg) + case GATEWAY: + err = c.processGatewaySubjectSub(arg) } if err != nil { return err @@ -476,10 +479,13 @@ func (c *client) parse(buf []byte) error { arg = buf[c.as : i-c.drop] } var err error - if c.typ == CLIENT { + switch c.typ { + case CLIENT: err = c.processUnsub(arg) - } else { + case ROUTER: err = c.processRemoteUnsub(arg) + case GATEWAY: + err = c.processGatewaySubjectUnsub(arg) } if err != nil { return err diff --git a/server/reload.go b/server/reload.go index 8a02d79c..48fbd1f2 100644 --- a/server/reload.go +++ b/server/reload.go @@ -304,7 +304,7 @@ func (r *routesOption) Apply(server *Server) { client.mu.Unlock() if url != nil && urlsAreEqual(url, remove) { // Do not attempt to reconnect when route is removed. - client.setRouteNoReconnectOnClose() + client.setNoReconnect() client.closeConnection(RouteRemoved) server.Noticef("Removed route %v", remove) } @@ -552,8 +552,13 @@ func (s *Server) diffOptions(newOpts *Options) ([]option, error) { ) for i := 0; i < oldConfig.NumField(); i++ { + field := oldConfig.Type().Field(i) + // field.PkgPath is empty for exported fields, and is not for unexported ones. + // We skip the unexported fields. 
+ if field.PkgPath != "" { + continue + } var ( - field = oldConfig.Type().Field(i) oldValue = oldConfig.Field(i).Interface() newValue = newConfig.Field(i).Interface() changed = !reflect.DeepEqual(oldValue, newValue) @@ -758,7 +763,7 @@ func (s *Server) reloadAuthorization() { // because in the later case, we don't have the user name/password // of the remote server. if !route.isSolicitedRoute() && !s.isRouterAuthorized(route) { - route.setRouteNoReconnectOnClose() + route.setNoReconnect() route.authViolation() } } diff --git a/server/route.go b/server/route.go index 12defb60..b58d9072 100644 --- a/server/route.go +++ b/server/route.go @@ -50,6 +50,14 @@ const ( RouteProtoV2 ) +// Include the space for the proto +var ( + aSubBytes = []byte{'A', '+', ' '} + aUnsubBytes = []byte{'A', '-', ' '} + rSubBytes = []byte{'R', 'S', '+', ' '} + rUnsubBytes = []byte{'R', 'S', '-', ' '} +) + // Used by tests var testRouteProto = RouteProtoV2 @@ -61,9 +69,9 @@ type route struct { url *url.URL authRequired bool tlsRequired bool - closed bool connectURLs []string replySubs map[*subscription]*time.Timer + gatewayURL string } type connectInfo struct { @@ -74,6 +82,7 @@ type connectInfo struct { Pass string `json:"pass,omitempty"` TLS bool `json:"tls_required"` Name string `json:"name"` + Gateway string `json:"gateway,omitempty"` } // Route protocol constants @@ -130,16 +139,20 @@ func (c *client) removeReplySubTimeout(sub *subscription) { } func (c *client) processAccountSub(arg []byte) error { - // Placeholder in case we add in to the protocol active senders of - // informtation. For now we do not do account interest propagation. c.traceInOp("A+", arg) + accName := string(arg) + if c.typ == GATEWAY { + return c.processGatewayAccountSub(accName) + } return nil } func (c *client) processAccountUnsub(arg []byte) { - // Placeholder in case we add in to the protocol active senders of - // informtation. For now we do not do account interest propagation. 
c.traceInOp("A-", arg) + accName := string(arg) + if c.typ == GATEWAY { + c.processGatewayAccountUnsub(accName) + } } // Process an inbound RMSG specification from the remote route. @@ -376,6 +389,20 @@ func (c *client) processRouteInfo(info *Info) { s := c.srv remoteID := c.route.remoteID + // Check if this is an INFO for gateways... + if info.Gateway != "" { + c.mu.Unlock() + // If this server has no gateway configured, report error and return. + if !s.gateway.enabled { + // FIXME: Should this be a Fatalf()? + s.Errorf("Received information about gateway %q from %s, but gateway is not configured", + info.Gateway, remoteID) + return + } + s.processGatewayInfoFromRoute(info, remoteID, c) + return + } + // We receive an INFO from a server that informs us about another server, // so the info.ID in the INFO protocol does not match the ID of this route. if remoteID != "" && remoteID != info.ID { @@ -404,6 +431,7 @@ func (c *client) processRouteInfo(info *Info) { // Copy over important information. c.route.authRequired = info.AuthRequired c.route.tlsRequired = info.TLSRequired + c.route.gatewayURL = info.GatewayURL // If this is an update due to config reload on the remote server, // need to possibly send local subs to the remote server. @@ -446,6 +474,9 @@ func (c *client) processRouteInfo(info *Info) { // Send our subs to the other side. s.sendSubsToRoute(c) + // Send info about the known gateways to this route. + s.sendGatewayConfigsToRoute(c) + // sendInfo will be false if the route that we just accepted // is the only route there is. if sendInfo { @@ -691,36 +722,44 @@ func (c *client) removeRemoteSubs() { } } -// Indicates no more interest in the given account/subject for the remote side. -func (c *client) processRemoteUnsub(arg []byte) (err error) { +func (c *client) parseUnsubProto(arg []byte) (string, []byte, []byte, error) { c.traceInOp("RS-", arg) // Indicate any activity, so pub and sub or unsubs. 
c.in.subs++ + args := splitArg(arg) + var ( + accountName string + subject []byte + queue []byte + ) + switch len(args) { + case 2: + case 3: + queue = args[2] + default: + return "", nil, nil, fmt.Errorf("Parse Error: '%s'", arg) + } + subject = args[1] + accountName = string(args[0]) + return accountName, subject, queue, nil +} + +// Indicates no more interest in the given account/subject for the remote side. +func (c *client) processRemoteUnsub(arg []byte) (err error) { srv := c.srv if srv == nil { return nil } - - args := splitArg(arg) - sub := &subscription{client: c} - - switch len(args) { - case 2: - sub.queue = nil - case 3: - sub.queue = args[2] - default: - return fmt.Errorf("processRemoteUnSub Parse Error: '%s'", arg) + accountName, subject, _, err := c.parseUnsubProto(arg) + if err != nil { + return fmt.Errorf("processRemoteUnsub %s", err.Error()) } - sub.subject = args[1] - // Lookup the account - accountName := string(args[0]) acc := c.srv.LookupAccount(accountName) if acc == nil { - c.Debugf("Unknown account %q for subject %q", accountName, sub.subject) + c.Debugf("Unknown account %q for subject %q", accountName, subject) // Mark this account as not interested since we received a RS- and we // do not have any record of it. return nil @@ -732,16 +771,24 @@ func (c *client) processRemoteUnsub(arg []byte) (err error) { return nil } + sendToGWs := false // We store local subs by account and subject and optionally queue name. // RS- will have the arg exactly as the key. 
key := string(arg) - if sub, ok := c.subs[key]; ok { + sub, ok := c.subs[key] + if ok { delete(c.subs, key) acc.sl.Remove(sub) c.removeReplySubTimeout(sub) + // Send only for queue subs + sendToGWs = srv.gateway.enabled && sub.queue != nil } c.mu.Unlock() + if sendToGWs { + srv.sendQueueUnsubToGateways(accountName, sub) + } + if c.opts.Verbose { c.sendOK() } @@ -818,6 +865,7 @@ func (c *client) processRemoteSub(argo []byte) (err error) { } key := string(sub.sid) osub := c.subs[key] + sendToGWs := false if osub == nil { c.subs[string(key)] = sub // Now place into the account sl. @@ -828,6 +876,7 @@ func (c *client) processRemoteSub(argo []byte) (err error) { c.sendErr("Invalid Subscription") return nil } + sendToGWs = srv.gateway.enabled } else if sub.queue != nil { // For a queue we need to update the weight. atomic.StoreInt32(&osub.qw, sub.qw) @@ -838,6 +887,15 @@ func (c *client) processRemoteSub(argo []byte) (err error) { if c.opts.Verbose { c.sendOK() } + if sendToGWs { + // For a plain sub, this will send an RS+ to gateways only if + // we had previously sent an RS-. In other words, we don't send + // an RS+ per plain sub. + // For queue subs, we will send an RS+, but if we are here, we + // know there is a single qsub per account/subject/queue: + // sendToGWs is true only if we did not find that key before. + srv.sendSubInterestToGateways(acc.Name, sub) + } return nil } @@ -972,11 +1030,11 @@ func (c *client) sendRouteSubOrUnSubProtos(subs []*subscription, isSubProto, tra } as := len(buf) if isSubProto { - buf = append(buf, []byte("RS+ ")...) + buf = append(buf, rSubBytes...) } else { - buf = append(buf, []byte("RS- ")...) + buf = append(buf, rUnsubBytes...) } - buf = append(buf, []byte(accName)...) + buf = append(buf, accName...) buf = append(buf, ' ') buf = append(buf, sub.subject...) 
if len(sub.queue) > 0 { @@ -997,7 +1055,7 @@ func (c *client) sendRouteSubOrUnSubProtos(subs []*subscription, isSubProto, tra if trace { c.traceOutOp("", buf[as:]) } - buf = append(buf, []byte(CR_LF)...) + buf = append(buf, CR_LF...) } if !closed && len(buf) > 0 { c.queueOutbound(buf) @@ -1103,15 +1161,9 @@ func (s *Server) createRoute(conn net.Conn, rURL *url.URL) *client { // to the client (connection) to be closed, leaving this readLoop // uinterrupted, causing the Shutdown() to wait indefinitively. // We need to store the client in a special map, under a special lock. - s.grMu.Lock() - running := s.grRunning - if running { - s.grTmpClients[c.cid] = c - } - s.grMu.Unlock() - if !running { + if !s.addToTempClients(c.cid, c) { c.mu.Unlock() - c.setRouteNoReconnectOnClose() + c.setNoReconnect() c.closeConnection(ServerShutdown) return nil } @@ -1124,7 +1176,7 @@ func (s *Server) createRoute(conn net.Conn, rURL *url.URL) *client { } // Spin up the read loop. - s.startGoRoutine(func() { c.readLoop() }) + s.startGoRoutine(c.readLoop) // Spin up the write loop. s.startGoRoutine(c.writeLoop) @@ -1173,13 +1225,16 @@ func (s *Server) addRoute(c *client, info *Info) (bool, bool) { cid := c.cid c.mu.Unlock() - // Remove from the temporary map - s.grMu.Lock() - delete(s.grTmpClients, cid) - s.grMu.Unlock() + // Now that we have registered the route, we can remove from the temp map. + s.removeFromTempClients(cid) // we don't need to send if the only route is the one we just accepted. sendInfo = len(s.routes) > 1 + + // If the INFO contains a Gateway URL, add it to the list for our cluster. + if info.GatewayURL != "" { + s.addGatewayURL(info.GatewayURL) + } } s.mu.Unlock() @@ -1251,8 +1306,11 @@ func (s *Server) updateRouteSubscriptionMap(acc *Account, sub *subscription, del } // We always update for a queue subscriber since we need to send our relative weight. 
- var entry *rme - var ok bool + var ( + entry *rme + ok bool + added bool + ) // Always update if a queue subscriber. update := qi > 0 @@ -1271,6 +1329,7 @@ func (s *Server) updateRouteSubscriptionMap(acc *Account, sub *subscription, del entry = &rme{qi, delta} rm[string(key)] = entry update = true // Adding for normal sub means update. + added = true } if entry != nil { entryN = entry.n @@ -1295,8 +1354,20 @@ func (s *Server) updateRouteSubscriptionMap(acc *Account, sub *subscription, del // subscribes with a smaller weight. if entryN > 0 { s.broadcastSubscribe(sub) + // Here we want to send RS+ only when going from 0 to 1 + if s.gateway.enabled && added && entryN == 1 { + // Sends an RS+ to gateways if this is a queue sub, + // or if a plain sub but only if we had previously + // sent an RS- to the gateways. + s.sendSubInterestToGateways(acc.Name, sub) + } } else { s.broadcastUnSubscribe(sub) + // Last of the queue member of this group, so send to + // gateways. + if s.gateway.enabled && sub.queue != nil { + s.sendQueueUnsubToGateways(acc.Name, sub) + } } } @@ -1373,6 +1444,7 @@ func (s *Server) routeAcceptLoop(ch chan struct{}) { TLSVerify: tlsReq, MaxPayload: s.info.MaxPayload, Proto: proto, + GatewayURL: s.getGatewayURL(), } // Set this if only if advertise is not disabled if !opts.Cluster.NoAdvertise { @@ -1416,17 +1488,7 @@ func (s *Server) routeAcceptLoop(ch chan struct{}) { for s.isRunning() { conn, err := l.Accept() if err != nil { - if ne, ok := err.(net.Error); ok && ne.Temporary() { - s.Debugf("Temporary Route Accept Errorf(%v), sleeping %dms", - ne, tmpDelay/time.Millisecond) - time.Sleep(tmpDelay) - tmpDelay *= 2 - if tmpDelay > ACCEPT_MAX_SLEEP { - tmpDelay = ACCEPT_MAX_SLEEP - } - } else if s.isRunning() { - s.Noticef("Accept error: %v", err) - } + tmpDelay = s.acceptError("Route", err, tmpDelay) continue } tmpDelay = ACCEPT_MIN_SLEEP @@ -1563,3 +1625,62 @@ func (s *Server) solicitRoutes(routes []*url.URL) { s.startGoRoutine(func() { 
s.connectToRoute(route, true) }) } } + +func (c *client) processRouteConnect(srv *Server, arg []byte, lang string) error { + // Way to detect clients that incorrectly connect to the route listen + // port. Client provide Lang in the CONNECT protocol while ROUTEs don't. + if lang != "" { + errTxt := ErrClientConnectedToRoutePort.Error() + c.Errorf(errTxt) + c.sendErr(errTxt) + c.closeConnection(WrongPort) + return ErrClientConnectedToRoutePort + } + // Unmarshal as a route connect protocol + proto := &connectInfo{} + if err := json.Unmarshal(arg, proto); err != nil { + return err + } + // Reject if this has Gateway which means that it would be from a gateway + // connection that incorrectly connects to the Route port. + if proto.Gateway != "" { + errTxt := fmt.Sprintf("Rejecting connection from gateway %q on the Route port", proto.Gateway) + c.Errorf(errTxt) + c.sendErr(errTxt) + c.closeConnection(WrongGateway) + return ErrWrongGateway + } + var perms *RoutePermissions + if srv != nil { + perms = srv.getOpts().Cluster.Permissions + } + // Grab connection name of remote route. + c.mu.Lock() + c.route.remoteID = c.opts.Name + c.setRoutePermissions(perms) + c.mu.Unlock() + return nil +} + +func (s *Server) removeRoute(c *client) { + var rID string + c.mu.Lock() + cid := c.cid + r := c.route + if r != nil { + rID = r.remoteID + } + c.mu.Unlock() + s.mu.Lock() + delete(s.routes, cid) + if r != nil { + rc, ok := s.remotes[rID] + // Only delete it if it is us.. 
+ if ok && c == rc { + delete(s.remotes, rID) + } + s.removeGatewayURL(r.gatewayURL) + } + s.removeFromTempClients(cid) + s.mu.Unlock() +} diff --git a/server/server.go b/server/server.go index de5b823e..594415ae 100644 --- a/server/server.go +++ b/server/server.go @@ -15,6 +15,7 @@ package server import ( "bytes" + "context" "crypto/tls" "encoding/json" "flag" @@ -68,6 +69,12 @@ type Info struct { // Route Specific Import *SubjectPermission `json:"import,omitempty"` Export *SubjectPermission `json:"export,omitempty"` + + // Gateways Specific + Gateway string `json:"gateway,omitempty"` // Name of the origin Gateway (sent by gateway's INFO) + GatewayURLs []string `json:"gateway_urls,omitempty"` // Gateway URLs in the originating cluster (sent by gateway's INFO) + GatewayURL string `json:"gateway_url,omitempty"` // Gateway URL on that server (sent by route's INFO) + GatewayCmd byte `json:"gateway_cmd,omitempty"` // Command code for the receiving server to know what to do } // Server is our main struct. @@ -128,11 +135,16 @@ type Server struct { lastCURLsUpdate int64 + // For Gateways + gatewayListener net.Listener // Accept listener + gateway *srvGateway + // These store the real client/cluster listen ports. They are // required during config reload to reset the Options (after // reload) to the actual listen port values. clientActualPort int clusterActualPort int + gatewayActualPort int // Use during reload oldClusterPerms *RoutePermissions @@ -167,6 +179,16 @@ func New(opts *Options) *Server { tlsReq := opts.TLSConfig != nil verify := (tlsReq && opts.TLSConfig.ClientAuth == tls.RequireAndVerifyClientCert) + // Validate some options. This is here because we cannot assume that + // server will always be started with configuration parsing (that could + // report issues). Its options can be (incorrectly) set by hand when + // server is embedded. If there is an error, return nil. 
+ // TODO: Should probably have a new NewServer() API that returns (*Server, error) + // so user can know what's wrong. + if err := validateOptions(opts); err != nil { + return nil + } + info := Info{ ID: genID(), Version: VERSION, @@ -208,6 +230,17 @@ func New(opts *Options) *Server { // Used internally for quick look-ups. s.clientConnectURLsMap = make(map[string]struct{}) + // Call this even if there is no gateway defined. It will + // initialize the structure so we don't have to check for + // it to be nil or not in various places in the code. + // Do this before calling registerAccount() since registerAccount + // may try to send things to gateways. + gws, err := newGateway(opts) + if err != nil { + return nil + } + s.gateway = gws + // For tracking accounts s.accounts = make(map[string]*Account) @@ -245,6 +278,12 @@ func New(opts *Options) *Server { return s } +func validateOptions(o *Options) error { + // Check that gateway is properly configured. Returns no error + // if there is no gateway defined. + return validateGatewayOptions(o) +} + func (s *Server) getOpts() *Options { s.optsMu.RLock() opts := s.opts @@ -463,11 +502,16 @@ func (s *Server) registerAccount(acc *Account) { // already created (global account), so use locking and // make sure we create only if needed. acc.mu.Lock() - if acc.rm == nil && s.opts != nil && s.opts.Cluster.Port != 0 { + if acc.rm == nil && s.opts != nil && (s.opts.Cluster.Port != 0 || s.opts.Gateway.Port != 0) { acc.rm = make(map[string]*rme, 256) } acc.mu.Unlock() s.accounts[acc.Name] = acc + if s.gateway.enabled { + // Check and possibly send an A+ to gateways for which + // we had sent an A- because account did not exist at that time. + s.endAccountNoInterestForGateways(acc.Name) + } } // LookupAccount is a public function to return the account structure @@ -605,6 +649,13 @@ func (s *Server) Start() { return } + // Start up gateway if needed. 
Do this before starting the routes, because + // we want to resolve the gateway host:port so that this information can + // be sent to other routes. + if opts.Gateway.Port != 0 { + s.startGateways() + } + // The Routing routine needs to wait for the client listen // port to be opened and potential ephemeral port selected. clientListenReady := make(chan struct{}) @@ -663,9 +714,10 @@ func (s *Server) Shutdown() { s.grMu.Unlock() // Copy off the routes for i, r := range s.routes { - r.setRouteNoReconnectOnClose() conns[i] = r } + // Copy off the gateways + s.getAllGatewayConnections(conns) // Number of done channel responses we expect. doneExpected := 0 @@ -684,6 +736,13 @@ func (s *Server) Shutdown() { s.routeListener = nil } + // Kick Gateway AcceptLoop() + if s.gatewayListener != nil { + doneExpected++ + s.gatewayListener.Close() + s.gatewayListener = nil + } + // Kick HTTP monitoring if its running if s.http != nil { doneExpected++ @@ -704,6 +763,7 @@ func (s *Server) Shutdown() { // Close client and route connections for _, c := range conns { + c.setNoReconnect() c.closeConnection(ServerShutdown) } @@ -759,7 +819,7 @@ func (s *Server) AcceptLoop(clr chan struct{}) { s.Noticef("TLS required for client connections") } - s.Debugf("Server id is %s", s.info.ID) + s.Noticef("Server id is %s", s.info.ID) s.Noticef("Server is ready") // Setup state that can enable shutdown @@ -804,17 +864,7 @@ func (s *Server) AcceptLoop(clr chan struct{}) { <-s.quitCh return } - if ne, ok := err.(net.Error); ok && ne.Temporary() { - s.Errorf("Temporary Client Accept Error (%v), sleeping %dms", - ne, tmpDelay/time.Millisecond) - time.Sleep(tmpDelay) - tmpDelay *= 2 - if tmpDelay > ACCEPT_MAX_SLEEP { - tmpDelay = ACCEPT_MAX_SLEEP - } - } else if s.isRunning() { - s.Errorf("Client Accept Error: %v", err) - } + tmpDelay = s.acceptError("Client", err, tmpDelay) continue } tmpDelay = ACCEPT_MIN_SLEEP @@ -1306,42 +1356,45 @@ func tlsCipher(cs uint16) string { // Remove a client or route from 
our internal accounting. func (s *Server) removeClient(c *client) { - var rID string - c.mu.Lock() - cid := c.cid - typ := c.typ - r := c.route - if r != nil { - rID = r.remoteID - } - updateProtoInfoCount := false - if typ == CLIENT && c.opts.Protocol >= ClientProtoInfo { - updateProtoInfoCount = true - } - c.mu.Unlock() - - s.mu.Lock() - switch typ { + // type is immutable, so can check without lock + switch c.typ { case CLIENT: + c.mu.Lock() + cid := c.cid + updateProtoInfoCount := false + if c.typ == CLIENT && c.opts.Protocol >= ClientProtoInfo { + updateProtoInfoCount = true + } + c.mu.Unlock() + + s.mu.Lock() delete(s.clients, cid) if updateProtoInfoCount { s.cproto-- } + s.mu.Unlock() case ROUTER: - delete(s.routes, cid) - if r != nil { - rc, ok := s.remotes[rID] - // Only delete it if it is us.. - if ok && c == rc { - delete(s.remotes, rID) - } - } - // Remove from temporary map in case it is there. - s.grMu.Lock() - delete(s.grTmpClients, cid) - s.grMu.Unlock() + s.removeRoute(c) + case GATEWAY: + s.removeRemoteGateway(c) } - s.mu.Unlock() +} + +func (s *Server) removeFromTempClients(cid uint64) { + s.grMu.Lock() + delete(s.grTmpClients, cid) + s.grMu.Unlock() +} + +func (s *Server) addToTempClients(cid uint64, c *client) bool { + added := false + s.grMu.Lock() + if s.grRunning { + s.grTmpClients[cid] = c + added = true + } + s.grMu.Unlock() + return added } ///////////////////////////////////////////////////////////////// @@ -1452,7 +1505,7 @@ func (s *Server) ReadyForConnections(dur time.Duration) bool { end := time.Now().Add(dur) for time.Now().Before(end) { s.mu.Lock() - ok := s.listener != nil && (opts.Cluster.Port == 0 || s.routeListener != nil) + ok := s.listener != nil && (opts.Cluster.Port == 0 || s.routeListener != nil) && (opts.Gateway.Name == "" || s.gatewayListener != nil) s.mu.Unlock() if ok { return true @@ -1838,3 +1891,52 @@ func (s *Server) lameDuckMode() { } s.Shutdown() } + +// If given error is a net.Error and is temporary, sleeps for 
the given +// delay and double it, but cap it to ACCEPT_MAX_SLEEP. The sleep is +// interrupted if the server is shutdown. +// An error message is displayed depending on the type of error. +// Returns the new (or unchanged) delay. +func (s *Server) acceptError(acceptName string, err error, tmpDelay time.Duration) time.Duration { + if ne, ok := err.(net.Error); ok && ne.Temporary() { + s.Errorf("Temporary %s Accept Error(%v), sleeping %dms", acceptName, ne, tmpDelay/time.Millisecond) + select { + case <-time.After(tmpDelay): + case <-s.quitCh: + return tmpDelay + } + tmpDelay *= 2 + if tmpDelay > ACCEPT_MAX_SLEEP { + tmpDelay = ACCEPT_MAX_SLEEP + } + } else if s.isRunning() { + s.Errorf("%s Accept error: %v", acceptName, err) + } + return tmpDelay +} + +func (s *Server) getRandomIP(resolver netResolver, url string) (string, error) { + host, port, err := net.SplitHostPort(url) + if err != nil { + return "", err + } + ips, err := resolver.LookupHost(context.Background(), host) + if err != nil { + return "", fmt.Errorf("lookup for host %q: %v", host, err) + } + var address string + if len(ips) == 0 { + s.Warnf("Unable to get IP for %s, will try with %s: %v", host, url, err) + address = url + } else { + var ip string + if len(ips) == 1 { + ip = ips[0] + } else { + ip = ips[rand.Int31n(int32(len(ips)))] + } + // add the port + address = net.JoinHostPort(ip, port) + } + return address, nil +} diff --git a/server/server_test.go b/server/server_test.go index 6f82fe11..925b8991 100644 --- a/server/server_test.go +++ b/server/server_test.go @@ -14,11 +14,14 @@ package server import ( + "context" "flag" "fmt" "net" + "net/url" "os" "strings" + "sync" "sync/atomic" "testing" "time" @@ -847,3 +850,169 @@ func TestLameDuckMode(t *testing.T) { t.Fatalf("Expected client to reconnect only once, got %v", n) } } + +func TestServerValidateGatewaysOptions(t *testing.T) { + baseOpt := testDefaultOptionsForGateway("A") + u, _ := url.Parse("host:5222") + g := &RemoteGatewayOpts{ + URLs: 
[]*url.URL{u}, + } + baseOpt.Gateway.Gateways = append(baseOpt.Gateway.Gateways, g) + + for _, test := range []struct { + name string + opts func() *Options + expectedErr string + }{ + { + name: "gateway_has_no_name", + opts: func() *Options { + o := baseOpt.Clone() + o.Gateway.Name = "" + return o + }, + expectedErr: "has no name", + }, + { + name: "gateway_has_no_port", + opts: func() *Options { + o := baseOpt.Clone() + o.Gateway.Port = 0 + return o + }, + expectedErr: "no port specified", + }, + { + name: "gateway_dst_has_no_name", + opts: func() *Options { + o := baseOpt.Clone() + return o + }, + expectedErr: "has no name", + }, + { + name: "gateway_dst_urls_is_nil", + opts: func() *Options { + o := baseOpt.Clone() + o.Gateway.Gateways[0].Name = "B" + o.Gateway.Gateways[0].URLs = nil + return o + }, + expectedErr: "has no URL", + }, + { + name: "gateway_dst_urls_is_empty", + opts: func() *Options { + o := baseOpt.Clone() + o.Gateway.Gateways[0].Name = "B" + o.Gateway.Gateways[0].URLs = []*url.URL{} + return o + }, + expectedErr: "has no URL", + }, + } { + t.Run(test.name, func(t *testing.T) { + if err := validateOptions(test.opts()); err == nil || !strings.Contains(err.Error(), test.expectedErr) { + t.Fatalf("Expected error about %q, got %v", test.expectedErr, err) + } + }) + } +} + +func TestAcceptError(t *testing.T) { + o := DefaultOptions() + s := New(o) + s.mu.Lock() + s.running = true + s.mu.Unlock() + defer s.Shutdown() + orgDelay := time.Hour + delay := s.acceptError("Test", fmt.Errorf("any error"), orgDelay) + if delay != orgDelay { + t.Fatalf("With this type of error, delay should have stayed same, got %v", delay) + } + + // Create any net.Error and make it a temporary + ne := &net.DNSError{IsTemporary: true} + orgDelay = 10 * time.Millisecond + delay = s.acceptError("Test", ne, orgDelay) + if delay != 2*orgDelay { + t.Fatalf("Expected delay to double, got %v", delay) + } + // Now check the max + orgDelay = 60 * ACCEPT_MAX_SLEEP / 100 + delay = 
s.acceptError("Test", ne, orgDelay) + if delay != ACCEPT_MAX_SLEEP { + t.Fatalf("Expected delay to double, got %v", delay) + } + wg := sync.WaitGroup{} + wg.Add(1) + start := time.Now() + go func() { + s.acceptError("Test", ne, orgDelay) + wg.Done() + }() + time.Sleep(100 * time.Millisecond) + // This should kick out the sleep in acceptError + s.Shutdown() + if dur := time.Since(start); dur >= ACCEPT_MAX_SLEEP { + t.Fatalf("Shutdown took too long: %v", dur) + } +} + +type myDummyDNSResolver struct { + ips []string + err error +} + +func (r *myDummyDNSResolver) LookupHost(ctx context.Context, host string) ([]string, error) { + if r.err != nil { + return nil, r.err + } + return r.ips, nil +} + +func TestGetRandomIP(t *testing.T) { + s := &Server{} + resolver := &myDummyDNSResolver{} + // no port... + if _, err := s.getRandomIP(resolver, "noport"); err == nil || !strings.Contains(err.Error(), "port") { + t.Fatalf("Expected error about port missing, got %v", err) + } + resolver.err = fmt.Errorf("on purpose") + if _, err := s.getRandomIP(resolver, "localhost:4222"); err == nil || !strings.Contains(err.Error(), "on purpose") { + t.Fatalf("Expected error about no port, got %v", err) + } + resolver.err = nil + a, err := s.getRandomIP(resolver, "localhost:4222") + if err != nil { + t.Fatalf("Unexpected error: %v", err) + } + if a != "localhost:4222" { + t.Fatalf("Expected address to be %q, got %q", "localhost:4222", a) + } + resolver.ips = []string{"1.2.3.4"} + a, err = s.getRandomIP(resolver, "localhost:4222") + if err != nil { + t.Fatalf("Unexpected error: %v", err) + } + if a != "1.2.3.4:4222" { + t.Fatalf("Expected address to be %q, got %q", "1.2.3.4:4222", a) + } + // Check for randomness + resolver.ips = []string{"1.2.3.4", "2.2.3.4", "3.2.3.4"} + dist := [3]int{} + for i := 0; i < 100; i++ { + ip, err := s.getRandomIP(resolver, "localhost:4222") + if err != nil { + t.Fatalf("Unexpected error: %v", err) + } + v := int(ip[0]-'0') - 1 + dist[v]++ + } + for i, d := range 
dist { + if d < 23 || d > 43 { + t.Fatalf("Unexpected distribution for ip %v, got %v", i, d) + } + } +} diff --git a/server/split_test.go b/server/split_test.go index ef8e45a2..c610f54b 100644 --- a/server/split_test.go +++ b/server/split_test.go @@ -24,7 +24,11 @@ func TestSplitBufferSubOp(t *testing.T) { defer cli.Close() defer trash.Close() - s := &Server{gacc: &Account{Name: globalAccountName}, accounts: make(map[string]*Account)} + gws, err := newGateway(DefaultOptions()) + if err != nil { + t.Fatalf("Error creating gateways: %v", err) + } + s := &Server{gacc: &Account{Name: globalAccountName}, accounts: make(map[string]*Account), gateway: gws} s.registerAccount(s.gacc) c := &client{srv: s, acc: s.gacc, subs: make(map[string]*subscription), nc: cli} @@ -61,7 +65,7 @@ func TestSplitBufferSubOp(t *testing.T) { } func TestSplitBufferUnsubOp(t *testing.T) { - s := &Server{gacc: &Account{Name: globalAccountName}, accounts: make(map[string]*Account)} + s := &Server{gacc: &Account{Name: globalAccountName}, accounts: make(map[string]*Account), gateway: &srvGateway{}} s.registerAccount(s.gacc) c := &client{srv: s, acc: s.gacc, subs: make(map[string]*subscription)} diff --git a/test/gateway_test.go b/test/gateway_test.go new file mode 100644 index 00000000..4336871e --- /dev/null +++ b/test/gateway_test.go @@ -0,0 +1,317 @@ +// Copyright 2018 The NATS Authors +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+ +package test + +import ( + "fmt" + "net" + "regexp" + "testing" + + "github.com/nats-io/gnatsd/server" +) + +func testDefaultOptionsForGateway(name string) *server.Options { + o := DefaultTestOptions + o.Gateway.Name = name + o.Gateway.Host = "127.0.0.1" + o.Gateway.Port = -1 + return &o +} + +func runGatewayServer(o *server.Options) *server.Server { + s := RunServer(o) + return s +} + +func createGatewayConn(t testing.TB, host string, port int) net.Conn { + t.Helper() + return createClientConn(t, host, port) +} + +func setupGatewayConn(t testing.TB, c net.Conn, org, dst string) (sendFun, expectFun) { + t.Helper() + dstInfo := checkInfoMsg(t, c) + if dstInfo.Gateway != dst { + t.Fatalf("Expected to connect to %q, got %q", dst, dstInfo.Gateway) + } + cs := fmt.Sprintf("CONNECT {\"verbose\":%v,\"pedantic\":%v,\"tls_required\":%v,\"gateway\":%q}\r\n", + false, false, false, org) + sendProto(t, c, cs) + sendProto(t, c, fmt.Sprintf("INFO {\"gateway\":%q}\r\n", org)) + return sendCommand(t, c), expectCommand(t, c) +} + +func expectNumberOfProtos(t *testing.T, expFn expectFun, proto *regexp.Regexp, expected int) { + t.Helper() + for count := 0; count != expected; { + buf := expFn(proto) + count += len(proto.FindAllSubmatch(buf, -1)) + if count > expected { + t.Fatalf("Expected %v matches, got %v", expected, count) + } + } +} + +func TestGatewayAccountInterest(t *testing.T) { + ob := testDefaultOptionsForGateway("B") + sb := runGatewayServer(ob) + defer sb.Shutdown() + + gA := createGatewayConn(t, ob.Gateway.Host, ob.Gateway.Port) + defer gA.Close() + + gASend, gAExpect := setupGatewayConn(t, gA, "A", "B") + gASend("PING\r\n") + gAExpect(pongRe) + + // Sending a bunch of messages. On the first, "B" will send an A- + // protocol. + for i := 0; i < 100; i++ { + gASend("RMSG $foo foo 2\r\nok\r\n") + } + // We expect single A- followed by PONG. If "B" was sending more + // this expect call would fail. 
+ gAExpect(aunsubRe) + gASend("PING\r\n") + gAExpect(pongRe) + + // Start gateway C that connects to B + gC := createGatewayConn(t, ob.Gateway.Host, ob.Gateway.Port) + defer gC.Close() + + gCSend, gCExpect := setupGatewayConn(t, gC, "C", "B") + gCSend("PING\r\n") + gCExpect(pongRe) + // Send more messages, C should get A-, but A should not (already + // got it). + for i := 0; i < 100; i++ { + gCSend("RMSG $foo foo 2\r\nok\r\n") + } + gCExpect(aunsubRe) + gCSend("PING\r\n") + gCExpect(pongRe) + expectNothing(t, gA) + + // Restart one of the gateway, and resend a message, verify + // that it receives A- (things get cleared on reconnect) + gC.Close() + gC = createGatewayConn(t, ob.Gateway.Host, ob.Gateway.Port) + defer gC.Close() + + gCSend, gCExpect = setupGatewayConn(t, gC, "C", "B") + gCSend("PING\r\n") + gCExpect(pongRe) + gCSend("RMSG $foo foo 2\r\nok\r\n") + gCExpect(aunsubRe) + gCSend("PING\r\n") + gCExpect(pongRe) + expectNothing(t, gA) + + // Close again and re-create, but this time don't send anything. + gC.Close() + gC = createGatewayConn(t, ob.Gateway.Host, ob.Gateway.Port) + defer gC.Close() + + gCSend, gCExpect = setupGatewayConn(t, gC, "C", "B") + gCSend("PING\r\n") + gCExpect(pongRe) + + // Now register the $foo account on B, A should receive an A+ + // because B knows that it previously sent an A-, but since + // it did not send one to C, C should not receive the A+. 
+ sb.RegisterAccount("$foo") + gAExpect(asubRe) + expectNothing(t, gC) +} + +func TestGatewaySubjectInterest(t *testing.T) { + ob := testDefaultOptionsForGateway("B") + fooAcc := &server.Account{Name: "$foo"} + ob.Accounts = []*server.Account{fooAcc} + ob.Users = []*server.User{&server.User{Username: "ivan", Password: "password", Account: fooAcc}} + sb := runGatewayServer(ob) + defer sb.Shutdown() + + gA := createGatewayConn(t, ob.Gateway.Host, ob.Gateway.Port) + defer gA.Close() + + gASend, gAExpect := setupGatewayConn(t, gA, "A", "B") + gASend("PING\r\n") + gAExpect(pongRe) + + for i := 0; i < 100; i++ { + gASend("RMSG $foo foo 2\r\nok\r\n") + } + // We expect single RS- followed by PONG. If "B" was sending more + // this expect call would fail. + gAExpect(runsubRe) + gASend("PING\r\n") + gAExpect(pongRe) + + // Start gateway C that connects to B + gC := createGatewayConn(t, ob.Gateway.Host, ob.Gateway.Port) + defer gC.Close() + + gCSend, gCExpect := setupGatewayConn(t, gC, "C", "B") + gCSend("PING\r\n") + gCExpect(pongRe) + // Send more messages, C should get RS-, but A should not (already + // got it). + for i := 0; i < 100; i++ { + gCSend("RMSG $foo foo 2\r\nok\r\n") + } + gCExpect(runsubRe) + gCSend("PING\r\n") + gCExpect(pongRe) + expectNothing(t, gA) + + // Restart one of the gateway, and resend a message, verify + // that it receives RS- (things get cleared on reconnect) + gC.Close() + gC = createGatewayConn(t, ob.Gateway.Host, ob.Gateway.Port) + defer gC.Close() + + gCSend, gCExpect = setupGatewayConn(t, gC, "C", "B") + gCSend("PING\r\n") + gCExpect(pongRe) + gCSend("RMSG $foo foo 2\r\nok\r\n") + gCExpect(runsubRe) + gCSend("PING\r\n") + gCExpect(pongRe) + expectNothing(t, gA) + + // Close again and re-create, but this time don't send anything. 
+ gC.Close() + gC = createGatewayConn(t, ob.Gateway.Host, ob.Gateway.Port) + defer gC.Close() + + gCSend, gCExpect = setupGatewayConn(t, gC, "C", "B") + gCSend("PING\r\n") + gCExpect(pongRe) + + // Now register a subscription on foo for account $foo on B. + // A should receive a RS+ because B knows that it previously + // sent a RS-, but since it did not send one to C, C should + // not receive the RS+. + client := createClientConn(t, ob.Host, ob.Port) + defer client.Close() + + clientSend, clientExpect := setupConnWithUserPass(t, client, "ivan", "password") + clientSend("SUB foo 1\r\nSUB foo 2\r\n") + // Also subscribe to subject that was not used before, + // so there should be no RS+ for this one. + clientSend("SUB bar 3\r\nPING\r\n") + clientExpect(pongRe) + + gAExpect(rsubRe) + expectNothing(t, gC) + // Check that we get only one protocol + expectNothing(t, gA) + + // Unsubscribe the 2 subs on foo, expect to receive nothing. + clientSend("UNSUB 1\r\nUNSUB 2\r\nPING\r\n") + clientExpect(pongRe) + + expectNothing(t, gC) + expectNothing(t, gA) + + gC.Close() + + // Send on foo, should get an RS- + gASend("RMSG $foo foo 2\r\nok\r\n") + gAExpect(runsubRe) + // Subscribe on foo, should get an RS+ that removes the no-interest + clientSend("SUB foo 1\r\nPING\r\n") + clientExpect(pongRe) + gAExpect(rsubRe) + // Send on bar, message should be received. 
+ gASend("RMSG $foo bar 2\r\nok\r\n") + clientExpect(msgRe) +} + +func TestGatewayQueue(t *testing.T) { + ob := testDefaultOptionsForGateway("B") + fooAcc := &server.Account{Name: "$foo"} + ob.Accounts = []*server.Account{fooAcc} + ob.Users = []*server.User{&server.User{Username: "ivan", Password: "password", Account: fooAcc}} + sb := runGatewayServer(ob) + defer sb.Shutdown() + + gA := createGatewayConn(t, ob.Gateway.Host, ob.Gateway.Port) + defer gA.Close() + + gASend, gAExpect := setupGatewayConn(t, gA, "A", "B") + gASend("PING\r\n") + gAExpect(pongRe) + + client := createClientConn(t, ob.Host, ob.Port) + defer client.Close() + clientSend, clientExpect := setupConnWithUserPass(t, client, "ivan", "password") + + // Create one queue sub on foo.* for group bar. + clientSend("SUB foo.* bar 1\r\nPING\r\n") + clientExpect(pongRe) + // Expect RS+ + gAExpect(rsubRe) + // Add another queue sub on same group + clientSend("SUB foo.* bar 2\r\nPING\r\n") + clientExpect(pongRe) + // Should not receive another RS+ for that one + expectNothing(t, gA) + // However, if subject is different, we can expect to receive another RS+ + clientSend("SUB foo.> bar 3\r\nPING\r\n") + clientExpect(pongRe) + gAExpect(rsubRe) + + // Unsub one of the foo.* qsub, no RS- should be received + clientSend("UNSUB 1\r\nPING\r\n") + clientExpect(pongRe) + expectNothing(t, gA) + // Remove the other one, now we should get the RS- + clientSend("UNSUB 2\r\nPING\r\n") + clientExpect(pongRe) + gAExpect(runsubRe) + // Remove last one + clientSend("UNSUB 3\r\nPING\r\n") + clientExpect(pongRe) + gAExpect(runsubRe) + + // Create some queues and check that interest is sent + // when GW reconnects. 
+ clientSend("SUB foo bar 4\r\n") + gAExpect(rsubRe) + clientSend("SUB foo baz 5\r\n") + gAExpect(rsubRe) + clientSend("SUB foo bat 6\r\n") + gAExpect(rsubRe) + // There is already one on foo/bar, so nothing sent + clientSend("SUB foo bar 7\r\n") + expectNothing(t, gA) + // Add regular sub that should not cause RS+ + clientSend("SUB foo 8\r\n") + expectNothing(t, gA) + + // Recreate gA + gA.Close() + gA = createGatewayConn(t, ob.Gateway.Host, ob.Gateway.Port) + defer gA.Close() + gASend, gAExpect = setupGatewayConn(t, gA, "A", "B") + // A should receive 3 RS+ + expectNumberOfProtos(t, gAExpect, rsubRe, 3) + // Nothing more + expectNothing(t, gA) + gASend("PING\r\n") + gAExpect(pongRe) +} diff --git a/test/test.go b/test/test.go index 03ee4351..4ec1dae6 100644 --- a/test/test.go +++ b/test/test.go @@ -221,6 +221,14 @@ func setupConnWithAccount(t tLogger, c net.Conn, account string) (sendFun, expec return sendCommand(t, c), expectCommand(t, c) } +func setupConnWithUserPass(t tLogger, c net.Conn, username, password string) (sendFun, expectFun) { + checkInfoMsg(t, c) + cs := fmt.Sprintf("CONNECT {\"verbose\":%v,\"pedantic\":%v,\"tls_required\":%v,\"protocol\":1,\"user\":%q,\"pass\":%q}\r\n", + false, false, false, username, password) + sendProto(t, c, cs) + return sendCommand(t, c), expectCommand(t, c) +} + type sendFun func(string) type expectFun func(*regexp.Regexp) []byte @@ -260,6 +268,8 @@ var ( rsubRe = regexp.MustCompile(`RS\+\s+([^\s]+)\s+([^\s]+)\s*([^\s]+)?\s*(\d+)?\r\n`) runsubRe = regexp.MustCompile(`RS\-\s+([^\s]+)\s+([^\s]+)\s*([^\s]+)?\r\n`) rmsgRe = regexp.MustCompile(`(?:(?:RMSG\s+([^\s]+)\s+([^\s]+)\s+(?:([|+]\s+([\w\s]+)|[^\s]+)[^\S\r\n]+)?(\d+)\s*\r\n([^\\r\\n]*?)\r\n)+?)`) + asubRe = regexp.MustCompile(`A\+\s+([^\r\n]+)\r\n`) + aunsubRe = regexp.MustCompile(`A\-\s+([^\r\n]+)\r\n`) ) const (