Server support for headers between routes

Signed-off-by: Derek Collison <derek@nats.io>
This commit is contained in:
Derek Collison
2020-05-13 19:31:06 -07:00
parent d8b475c4b0
commit f5ceab339a
9 changed files with 397 additions and 52 deletions

View File

@@ -2457,14 +2457,69 @@ func (c *client) checkDenySub(subject string) bool {
return false
}
func (c *client) msgHeader(subj []byte, sub *subscription, reply []byte) []byte {
// Create a message header for routes or leafnodes. Header aware.
func (c *client) msgHeaderForRouteOrLeaf(subj, reply []byte, rt *routeTarget, acc *Account) []byte {
hasHeader := c.pa.hdr > 0
canReceiveHeader := rt.sub.client.headers
kind := rt.sub.client.kind
mh := c.msgb[:msgHeadProtoLen]
if kind == ROUTER {
// Router (and Gateway) nodes are RMSG. Set here since leafnodes may rewrite.
mh[0] = 'R'
mh = append(mh, acc.Name...)
mh = append(mh, ' ')
} else {
// Leaf nodes are LMSG
mh[0] = 'L'
// Remap subject if its a shadow subscription, treat like a normal client.
if rt.sub.im != nil && rt.sub.im.prefix != "" {
mh = append(mh, rt.sub.im.prefix...)
}
}
mh = append(mh, subj...)
mh = append(mh, ' ')
if len(rt.qs) > 0 {
if reply != nil {
mh = append(mh, "+ "...) // Signal that there is a reply.
mh = append(mh, reply...)
mh = append(mh, ' ')
} else {
mh = append(mh, "| "...) // Only queues
}
mh = append(mh, rt.qs...)
} else if reply != nil {
mh = append(mh, reply...)
mh = append(mh, ' ')
}
if hasHeader {
if canReceiveHeader {
mh[0] = 'H'
mh = append(mh, c.pa.hdb...)
mh = append(mh, ' ')
mh = append(mh, c.pa.szb...)
} else {
// If we are here we need to truncate the payload size
nsz := strconv.Itoa(c.pa.size - c.pa.hdr)
mh = append(mh, nsz...)
}
} else {
mh = append(mh, c.pa.szb...)
}
mh = append(mh, _CRLF_...)
return mh
}
// Create a message header for clients. Header aware.
func (c *client) msgHeader(subj, reply []byte, sub *subscription) []byte {
// See if we should do headers. We have to have a headers msg and
// the client we are going to deliver to needs to support headers as well.
// TODO(dlc) - This should only be for client connections, but should we check?
doHeaders := c.pa.hdr > 0 && sub.client != nil && sub.client.headers
hasHeader := c.pa.hdr > 0
canReceiveHeader := sub.client != nil && sub.client.headers
var mh []byte
if doHeaders {
if hasHeader && canReceiveHeader {
mh = c.msgb[:msgHeadProtoLen]
mh[0] = 'H'
} else {
@@ -2481,12 +2536,19 @@ func (c *client) msgHeader(subj []byte, sub *subscription, reply []byte) []byte
mh = append(mh, reply...)
mh = append(mh, ' ')
}
if doHeaders {
mh = append(mh, c.pa.hdb...)
mh = append(mh, ' ')
if hasHeader {
if canReceiveHeader {
mh = append(mh, c.pa.hdb...)
mh = append(mh, ' ')
mh = append(mh, c.pa.szb...)
} else {
// If we are here we need to truncate the payload size
nsz := strconv.Itoa(c.pa.size - c.pa.hdr)
mh = append(mh, nsz...)
}
} else {
mh = append(mh, c.pa.szb...)
}
mh = append(mh, c.pa.szb...)
mh = append(mh, _CRLF_...)
return mh
}
@@ -2592,6 +2654,8 @@ func (c *client) deliverMsg(sub *subscription, subject, mh, msg []byte, gwrply b
// Check here if we have a header with our message. If this client can not
// support we need to strip the headers from the payload.
// The actual header would have been processed correctluy for us, so just
// need to update payload.
if c.pa.hdr > 0 && !sub.client.headers {
msg = msg[c.pa.hdr:]
}
@@ -3324,7 +3388,7 @@ func (c *client) processMsgResults(acc *Account, r *SublistResult, msg, deliver,
dsubj = append(dsubj, subj...)
}
// Normal delivery
mh := c.msgHeader(dsubj, sub, creply)
mh := c.msgHeader(dsubj, creply, sub)
didDeliver = c.deliverMsg(sub, subj, mh, msg, rplyHasGWPrefix) || didDeliver
}
@@ -3437,7 +3501,7 @@ func (c *client) processMsgResults(acc *Account, r *SublistResult, msg, deliver,
}
// "rreply" will be stripped of the $GNR prefix (if present)
// for client connections only.
mh := c.msgHeader(dsubj, sub, rreply)
mh := c.msgHeader(dsubj, rreply, sub)
if c.deliverMsg(sub, subject, mh, msg, rplyHasGWPrefix) {
didDeliver = true
// Clear rsub
@@ -3479,39 +3543,7 @@ sendToRoutesOrLeafs:
// We have inline structs for memory layout and cache coherency.
for i := range c.in.rts {
rt := &c.in.rts[i]
kind := rt.sub.client.kind
mh := c.msgb[:msgHeadProtoLen]
if kind == ROUTER {
// Router (and Gateway) nodes are RMSG. Set here since leafnodes may rewrite.
mh[0] = 'R'
mh = append(mh, acc.Name...)
mh = append(mh, ' ')
} else {
// Leaf nodes are LMSG
mh[0] = 'L'
// Remap subject if its a shadow subscription, treat like a normal client.
if rt.sub.im != nil && rt.sub.im.prefix != "" {
mh = append(mh, rt.sub.im.prefix...)
}
}
mh = append(mh, subject...)
mh = append(mh, ' ')
if len(rt.qs) > 0 {
if reply != nil {
mh = append(mh, "+ "...) // Signal that there is a reply.
mh = append(mh, reply...)
mh = append(mh, ' ')
} else {
mh = append(mh, "| "...) // Only queues
}
mh = append(mh, rt.qs...)
} else if reply != nil {
mh = append(mh, reply...)
mh = append(mh, ' ')
}
mh = append(mh, c.pa.szb...)
mh = append(mh, _CRLF_...)
mh := c.msgHeaderForRouteOrLeaf(subject, reply, rt, acc)
didDeliver = c.deliverMsg(rt.sub, subject, mh, msg, false) || didDeliver
}
return didDeliver, queues

View File

@@ -348,10 +348,13 @@ func TestClientHeaderDeliverStrippedMsg(t *testing.T) {
if matches[SID_INDEX] != "1" {
t.Fatalf("Did not get correct sid: '%s'\n", matches[SID_INDEX])
}
if matches[LEN_INDEX] != "14" {
if matches[LEN_INDEX] != "2" {
t.Fatalf("Did not get correct msg length: '%s'\n", matches[LEN_INDEX])
}
checkPayload(br, []byte("OK\r\n"), t)
if br.Buffered() != 0 {
t.Fatalf("Expected no extra bytes to be buffered, got %d", br.Buffered())
}
}
func TestClientHeaderDeliverQueueSubStrippedMsg(t *testing.T) {
@@ -394,7 +397,7 @@ func TestClientHeaderDeliverQueueSubStrippedMsg(t *testing.T) {
if matches[SID_INDEX] != "1" {
t.Fatalf("Did not get correct sid: '%s'\n", matches[SID_INDEX])
}
if matches[LEN_INDEX] != "14" {
if matches[LEN_INDEX] != "2" {
t.Fatalf("Did not get correct msg length: '%s'\n", matches[LEN_INDEX])
}
checkPayload(br, []byte("OK\r\n"), t)

View File

@@ -127,6 +127,12 @@ const (
// MAX_MSG_ARGS Maximum possible number of arguments from MSG proto.
MAX_MSG_ARGS = 4
// MAX_RMSG_ARGS Maximum possible number of arguments from RMSG proto.
MAX_RMSG_ARGS = 6
// MAX_HMSG_ARGS Maximum possible number of arguments from HMSG proto.
MAX_HMSG_ARGS = 7
// MAX_PUB_ARGS Maximum possible number of arguments from PUB proto.
MAX_PUB_ARGS = 3

View File

@@ -1442,6 +1442,11 @@ func (c *client) processLeafUnsub(arg []byte) error {
return nil
}
func (c *client) processLeafHeaderMsgArgs(arg []byte) error {
fmt.Printf("arg is %q\n", arg)
return nil
}
func (c *client) processLeafMsgArgs(arg []byte) error {
// Unroll splitArgs to avoid runtime/heap issues
a := [MAX_MSG_ARGS][]byte{}

View File

@@ -68,6 +68,11 @@ const (
OP_HPUB
OP_HPUB_SPC
HPUB_ARG
OP_HM
OP_HMS
OP_HMSG
OP_HMSG_SPC
HMSG_ARG
OP_P
OP_PU
OP_PUB
@@ -191,6 +196,8 @@ func (c *client) parse(buf []byte) error {
switch b {
case 'P', 'p':
c.state = OP_HP
case 'M', 'm':
c.state = OP_HM
default:
goto parseErr
}
@@ -254,6 +261,73 @@ func (c *client) parse(buf []byte) error {
c.argBuf = append(c.argBuf, b)
}
}
case OP_HM:
switch b {
case 'S', 's':
c.state = OP_HMS
default:
goto parseErr
}
case OP_HMS:
switch b {
case 'G', 'g':
c.state = OP_HMSG
default:
goto parseErr
}
case OP_HMSG:
switch b {
case ' ', '\t':
c.state = OP_HMSG_SPC
default:
goto parseErr
}
case OP_HMSG_SPC:
switch b {
case ' ', '\t':
continue
default:
c.state = HMSG_ARG
c.as = i
}
case HMSG_ARG:
switch b {
case '\r':
c.drop = 1
case '\n':
var arg []byte
if c.argBuf != nil {
arg = c.argBuf
c.argBuf = nil
} else {
arg = buf[c.as : i-c.drop]
}
var err error
if c.kind == ROUTER || c.kind == GATEWAY {
if trace {
c.traceInOp("HMSG", arg)
}
err = c.processRoutedHeaderMsgArgs(arg)
} else if c.kind == LEAF {
if trace {
c.traceInOp("HMSG", arg)
}
err = c.processLeafHeaderMsgArgs(arg)
}
if err != nil {
return err
}
c.drop, c.as, c.state = 0, i+1, MSG_PAYLOAD
// jump ahead with the index. If this overruns
// what is left we fall out and process split
// buffer.
i = c.as + c.pa.size - LEN_CR_LF
default:
if c.argBuf != nil {
c.argBuf = append(c.argBuf, b)
}
}
case OP_P:
switch b {
case 'U', 'u':

View File

@@ -86,6 +86,7 @@ type connectInfo struct {
User string `json:"user,omitempty"`
Pass string `json:"pass,omitempty"`
TLS bool `json:"tls_required"`
Headers bool `json:"headers"`
Name string `json:"name"`
Gateway string `json:"gateway,omitempty"`
}
@@ -154,10 +155,91 @@ func (c *client) processAccountUnsub(arg []byte) {
}
}
// Process an inbound HMSG specification from the remote route.
func (c *client) processRoutedHeaderMsgArgs(arg []byte) error {
// Unroll splitArgs to avoid runtime/heap issues
a := [MAX_HMSG_ARGS][]byte{}
args := a[:0]
start := -1
for i, b := range arg {
switch b {
case ' ', '\t', '\r', '\n':
if start >= 0 {
args = append(args, arg[start:i])
start = -1
}
default:
if start < 0 {
start = i
}
}
}
if start >= 0 {
args = append(args, arg[start:])
}
c.pa.arg = arg
switch len(args) {
case 0, 1, 2, 3:
return fmt.Errorf("processRoutedHeaderMsgArgs Parse Error: '%s'", args)
case 4:
c.pa.reply = nil
c.pa.queues = nil
c.pa.hdb = args[2]
c.pa.hdr = parseSize(args[2])
c.pa.szb = args[3]
c.pa.size = parseSize(args[3])
case 5:
c.pa.reply = args[2]
c.pa.queues = nil
c.pa.hdb = args[3]
c.pa.hdr = parseSize(args[3])
c.pa.szb = args[4]
c.pa.size = parseSize(args[4])
default:
// args[2] is our reply indicator. Should be + or | normally.
if len(args[2]) != 1 {
return fmt.Errorf("processRoutedHeaderMsgArgs Bad or Missing Reply Indicator: '%s'", args[2])
}
switch args[2][0] {
case '+':
c.pa.reply = args[3]
case '|':
c.pa.reply = nil
default:
return fmt.Errorf("processRoutedHeaderMsgArgs Bad or Missing Reply Indicator: '%s'", args[2])
}
// Grab header size.
c.pa.hdb = args[len(args)-2]
c.pa.hdr = parseSize(c.pa.hdb)
// Grab size.
c.pa.szb = args[len(args)-1]
c.pa.size = parseSize(c.pa.szb)
// Grab queue names.
if c.pa.reply != nil {
c.pa.queues = args[4 : len(args)-1]
} else {
c.pa.queues = args[3 : len(args)-1]
}
}
if c.pa.size < 0 {
return fmt.Errorf("processRoutedHeaderMsgArgs Bad or Missing Size: '%s'", args)
}
// Common ones processed after check for arg length
c.pa.account = args[0]
c.pa.subject = args[1]
c.pa.pacache = arg[:len(args[0])+len(args[1])+1]
return nil
}
// Process an inbound RMSG specification from the remote route.
func (c *client) processRoutedMsgArgs(arg []byte) error {
// Unroll splitArgs to avoid runtime/heap issues
a := [MAX_MSG_ARGS][]byte{}
a := [MAX_RMSG_ARGS][]byte{}
args := a[:0]
start := -1
for i, b := range arg {
@@ -276,6 +358,7 @@ func (c *client) sendRouteConnect(tlsRequired bool) {
Pass: pass,
TLS: tlsRequired,
Name: c.srv.info.ID,
Headers: c.srv.supportsHeaders(),
}
b, err := json.Marshal(cinfo)
@@ -299,6 +382,8 @@ func (c *client) processRouteInfo(info *Info) {
sl := gacc.sl
gacc.mu.RUnlock()
supportsHeaders := c.srv.supportsHeaders()
c.mu.Lock()
// Connection can be closed at any time (by auth timeout, etc).
// Does not make sense to continue here if connection is gone.
@@ -349,6 +434,9 @@ func (c *client) processRouteInfo(info *Info) {
return
}
// Headers
c.headers = supportsHeaders && info.Headers
// Copy over important information.
c.route.authRequired = info.AuthRequired
c.route.tlsRequired = info.TLSRequired
@@ -1396,6 +1484,7 @@ func (s *Server) routeAcceptLoop(ch chan struct{}) {
MaxPayload: s.info.MaxPayload,
Proto: proto,
GatewayURL: s.getGatewayURL(),
Headers: s.supportsHeaders(),
}
// Set this if only if advertise is not disabled
if !opts.Cluster.NoAdvertise {
@@ -1618,10 +1707,14 @@ func (c *client) processRouteConnect(srv *Server, arg []byte, lang string) error
if srv != nil {
perms = srv.getOpts().Cluster.Permissions
}
supportsHeaders := c.srv.supportsHeaders()
// Grab connection name of remote route.
c.mu.Lock()
c.route.remoteID = c.opts.Name
c.setRoutePermissions(perms)
c.headers = supportsHeaders && proto.Headers
c.mu.Unlock()
return nil
}

View File

@@ -50,6 +50,78 @@ func TestNewRouteInfoOnConnect(t *testing.T) {
if info.Nonce == "" {
t.Fatalf("Expected a non empty nonce in new route INFO")
}
// By default headers should be true.
if !info.Headers {
t.Fatalf("Expected to have headers on by default")
}
}
func TestNewRouteHeaderSupport(t *testing.T) {
srvA, srvB, optsA, optsB := runServers(t)
defer srvA.Shutdown()
defer srvB.Shutdown()
clientA := createClientConn(t, optsA.Host, optsA.Port)
defer clientA.Close()
clientB := createClientConn(t, optsB.Host, optsB.Port)
defer clientB.Close()
sendA, expectA := setupHeaderConn(t, clientA)
sendA("SUB foo bar 22\r\n")
sendA("PING\r\n")
expectA(pongRe)
if err := checkExpectedSubs(1, srvA, srvB); err != nil {
t.Fatalf("%v", err)
}
sendB, expectB := setupHeaderConn(t, clientB)
// Can not have \r\n in payload fyi for regex.
sendB("HPUB foo reply 12 14\r\nK1:V1,K2:V2 ok\r\n")
sendB("PING\r\n")
expectB(pongRe)
expectHeaderMsgs := expectHeaderMsgsCommand(t, expectA)
matches := expectHeaderMsgs(1)
checkHmsg(t, matches[0], "foo", "22", "reply", "12", "14", "K1:V1,K2:V2 ", "ok")
}
func TestNewRouteHeaderSupportOldAndNew(t *testing.T) {
optsA := LoadConfig("./configs/srv_a.conf")
optsA.NoHeaderSupport = true
srvA := RunServer(optsA)
defer srvA.Shutdown()
srvB, optsB := RunServerWithConfig("./configs/srv_b.conf")
defer srvB.Shutdown()
checkClusterFormed(t, srvA, srvB)
clientA := createClientConn(t, optsA.Host, optsA.Port)
defer clientA.Close()
clientB := createClientConn(t, optsB.Host, optsB.Port)
defer clientB.Close()
sendA, expectA := setupHeaderConn(t, clientA)
sendA("SUB foo bar 22\r\n")
sendA("PING\r\n")
expectA(pongRe)
if err := checkExpectedSubs(1, srvA, srvB); err != nil {
t.Fatalf("%v", err)
}
sendB, expectB := setupHeaderConn(t, clientB)
// Can not have \r\n in payload fyi for regex.
sendB("HPUB foo reply 12 14\r\nK1:V1,K2:V2 ok\r\n")
sendB("PING\r\n")
expectB(pongRe)
expectMsgs := expectMsgsCommand(t, expectA)
matches := expectMsgs(1)
checkMsg(t, matches[0], "foo", "22", "reply", "2", "ok")
}
func TestNewRouteConnectSubs(t *testing.T) {

View File

@@ -675,7 +675,7 @@ func TestServiceLatencyWithQueueSubscribersAndNames(t *testing.T) {
sc.setupLatencyTracking(t, 100)
selectServer := func() *server.Options {
si, ci := rand.Int63n(int64(numServers)), rand.Int63n(int64(numServers))
si, ci := rand.Int63n(int64(numServers)), rand.Int63n(int64(numClusters))
return sc.clusters[ci].opts[si]
}

View File

@@ -22,6 +22,7 @@ import (
"net"
"regexp"
"runtime"
"strconv"
"strings"
"testing"
"time"
@@ -41,7 +42,7 @@ var DefaultTestOptions = server.Options{
Port: 4222,
NoLog: true,
NoSigs: true,
MaxControlLine: 2048,
MaxControlLine: 4096,
DisableShortFirstPing: true,
}
@@ -195,12 +196,21 @@ func checkInfoMsg(t tLogger, c net.Conn) server.Info {
return sinfo
}
func doConnect(t tLogger, c net.Conn, verbose, pedantic, ssl bool) {
func doHeadersConnect(t tLogger, c net.Conn, verbose, pedantic, ssl, headers bool) {
checkInfoMsg(t, c)
cs := fmt.Sprintf("CONNECT {\"verbose\":%v,\"pedantic\":%v,\"tls_required\":%v}\r\n", verbose, pedantic, ssl)
cs := fmt.Sprintf("CONNECT {\"verbose\":%v,\"pedantic\":%v,\"tls_required\":%v,\"headers\":%v}\r\n",
verbose, pedantic, ssl, headers)
sendProto(t, c, cs)
}
func doConnect(t tLogger, c net.Conn, verbose, pedantic, ssl bool) {
doHeadersConnect(t, c, verbose, pedantic, ssl, false)
}
func doDefaultHeadersConnect(t tLogger, c net.Conn) {
doHeadersConnect(t, c, false, false, false, true)
}
func doDefaultConnect(t tLogger, c net.Conn) {
// Basic Connect
doConnect(t, c, false, false, false)
@@ -227,6 +237,11 @@ func setupRoute(t tLogger, c net.Conn, opts *server.Options) (sendFun, expectFun
return setupRouteEx(t, c, opts, id)
}
func setupHeaderConn(t tLogger, c net.Conn) (sendFun, expectFun) {
doDefaultHeadersConnect(t, c)
return sendCommand(t, c), expectCommand(t, c)
}
func setupConn(t tLogger, c net.Conn) (sendFun, expectFun) {
doDefaultConnect(t, c)
return sendCommand(t, c), expectCommand(t, c)
@@ -286,6 +301,7 @@ var (
infoRe = regexp.MustCompile(`INFO\s+([^\r\n]+)\r\n`)
pingRe = regexp.MustCompile(`^PING\r\n`)
pongRe = regexp.MustCompile(`^PONG\r\n`)
hmsgRe = regexp.MustCompile(`(?:(?:HMSG\s+([^\s]+)\s+([^\s]+)\s+(([^\s]+)[^\S\r\n]+)?(\d+)\s+(\d+)\s*\r\n([^\\r\\n]*?)\r\n)+?)`)
msgRe = regexp.MustCompile(`(?:(?:MSG\s+([^\s]+)\s+([^\s]+)\s+(([^\s]+)[^\S\r\n]+)?(\d+)\s*\r\n([^\\r\\n]*?)\r\n)+?)`)
rawMsgRe = regexp.MustCompile(`(?:(?:MSG\s+([^\s]+)\s+([^\s]+)\s+(([^\s]+)[^\S\r\n]+)?(\d+)\s*\r\n(.*?)))`)
okRe = regexp.MustCompile(`\A\+OK\r\n`)
@@ -308,6 +324,10 @@ const (
replyIndex = 4
lenIndex = 5
msgIndex = 6
// Headers
hlenIndex = 5
tlenIndex = 6
hmsgIndex = 7
// Routed Messages
accIndex = 1
@@ -327,7 +347,6 @@ func expectResult(t tLogger, c net.Conn, re *regexp.Regexp) []byte {
stackFatalf(t, "Error reading from conn: %v\n", err)
}
buf := expBuf[:n]
if !re.Match(buf) {
stackFatalf(t, "Response did not match expected: \n\tReceived:'%q'\n\tExpected:'%s'", buf, re)
}
@@ -416,6 +435,35 @@ func checkLmsg(t tLogger, m [][]byte, subject, replyAndQueues, len, msg string)
}
}
// This will check that we got what we expected from a header message.
func checkHmsg(t tLogger, m [][]byte, subject, sid, reply, hlen, len, hdr, msg string) {
if string(m[subIndex]) != subject {
stackFatalf(t, "Did not get correct subject: expected '%s' got '%s'\n", subject, m[subIndex])
}
if sid != "" && string(m[sidIndex]) != sid {
stackFatalf(t, "Did not get correct sid: expected '%s' got '%s'\n", sid, m[sidIndex])
}
if string(m[replyIndex]) != reply {
stackFatalf(t, "Did not get correct reply: expected '%s' got '%s'\n", reply, m[replyIndex])
}
if string(m[hlenIndex]) != hlen {
stackFatalf(t, "Did not get correct header length: expected '%s' got '%s'\n", hlen, m[hlenIndex])
}
if string(m[tlenIndex]) != len {
stackFatalf(t, "Did not get correct msg length: expected '%s' got '%s'\n", len, m[tlenIndex])
}
// Extract the payload and break up the headers and msg.
payload := string(m[hmsgIndex])
hi, _ := strconv.Atoi(hlen)
rhdr, rmsg := payload[:hi], payload[hi:]
if rhdr != hdr {
stackFatalf(t, "Did not get correct headers: expected '%s' got '%s'\n", hdr, rhdr)
}
if rmsg != msg {
stackFatalf(t, "Did not get correct msg: expected '%s' got '%s'\n", msg, rmsg)
}
}
// Closure for expectMsgs
func expectRmsgsCommand(t tLogger, ef expectFun) func(int) [][][]byte {
return func(expected int) [][][]byte {
@@ -428,6 +476,18 @@ func expectRmsgsCommand(t tLogger, ef expectFun) func(int) [][][]byte {
}
}
// Closure for expectHMsgs
func expectHeaderMsgsCommand(t tLogger, ef expectFun) func(int) [][][]byte {
return func(expected int) [][][]byte {
buf := ef(hmsgRe)
matches := hmsgRe.FindAllSubmatch(buf, -1)
if len(matches) != expected {
stackFatalf(t, "Did not get correct # msgs: %d vs %d\n", len(matches), expected)
}
return matches
}
}
// Closure for expectMsgs
func expectMsgsCommand(t tLogger, ef expectFun) func(int) [][][]byte {
return func(expected int) [][][]byte {