diff --git a/server/client.go b/server/client.go index 7ff1c4e2..41622e20 100644 --- a/server/client.go +++ b/server/client.go @@ -1369,7 +1369,9 @@ func computeRTT(start time.Time) time.Duration { return rtt } +// processConnect will process a client connect op. func (c *client) processConnect(arg []byte) error { + supportsHeaders := c.srv.supportsHeaders() c.mu.Lock() // If we can't stop the timer because the callback is in progress... if !c.clearAuthTimer() { @@ -1415,9 +1417,8 @@ func (c *client) processConnect(arg []byte) error { account := c.opts.Account accountNew := c.opts.AccountNew ujwt := c.opts.JWT - // For headers both need to support. If the server supports headers - /// it will have been set to true before this is called. - c.headers = c.headers && c.opts.Headers + // For headers both client and server need to support. + c.headers = supportsHeaders && c.opts.Headers c.mu.Unlock() if srv != nil { @@ -1881,12 +1882,14 @@ func (c *client) processHeaderPub(arg []byte) error { c.pa.reply = nil c.pa.hdr = parseSize(args[1]) c.pa.size = parseSize(args[2]) + c.pa.hdb = args[1] c.pa.szb = args[2] case 4: c.pa.subject = args[0] c.pa.reply = args[1] c.pa.hdr = parseSize(args[2]) c.pa.size = parseSize(args[3]) + c.pa.hdb = args[2] c.pa.szb = args[3] default: return fmt.Errorf("processHeaderPub Parse Error: '%s'", arg) @@ -2454,7 +2457,22 @@ func (c *client) checkDenySub(subject string) bool { return false } -func (c *client) msgHeader(mh []byte, sub *subscription, reply []byte) []byte { +func (c *client) msgHeader(subj []byte, sub *subscription, reply []byte) []byte { + // See if we should do headers. We have to have a headers msg and + // the client we are going to deliver to needs to support headers as well. + // TODO(dlc) - This should only be for client connections, but should we check? + doHeaders := c.pa.hdr > 0 && sub.client != nil && sub.client.headers + + var mh []byte + if doHeaders { + mh = c.msgb[:msgHeadProtoLen] + mh[0] = 'H' + } else { + mh = c.msgb[1:msgHeadProtoLen] + } + mh = append(mh, subj...) + mh = append(mh, ' ') + if len(sub.sid) > 0 { mh = append(mh, sub.sid...) mh = append(mh, ' ') @@ -2463,6 +2481,11 @@ func (c *client) msgHeader(mh []byte, sub *subscription, reply []byte) []byte { mh = append(mh, reply...) mh = append(mh, ' ') } + + if doHeaders { + mh = append(mh, c.pa.hdb...) + mh = append(mh, ' ') + } mh = append(mh, c.pa.szb...) mh = append(mh, _CRLF_...) return mh @@ -2567,6 +2590,12 @@ func (c *client) deliverMsg(sub *subscription, subject, mh, msg []byte, gwrply b } } + // Check here if we have a header with our message. If this client can not + // support we need to strip the headers from the payload. + if c.pa.hdr > 0 && !sub.client.headers { + msg = msg[c.pa.hdr:] + } + // Update statistics // The msg includes the CR_LF, so pull back out for accounting. @@ -3259,11 +3288,10 @@ func (c *client) processMsgResults(acc *Account, r *SublistResult, msg, deliver, var didDeliver bool - // msg header for clients. - msgh := c.msgb[1:msgHeadProtoLen] - msgh = append(msgh, subj...) - msgh = append(msgh, ' ') - si := len(msgh) + // delivery subject for clients + var dsubj []byte + // Used as scratch if mapping + var _dsubj [64]byte // Loop over all normal subscriptions that match. for _, sub := range r.psubs { @@ -3288,17 +3316,15 @@ func (c *client) processMsgResults(acc *Account, r *SublistResult, msg, deliver, } continue } + // Assume delivery subject is normal subject to this point. + dsubj = subj // Check for stream import mapped subs. These apply to local subs only. if sub.im != nil && sub.im.prefix != "" { - // Redo the subject here on the fly. - msgh = c.msgb[1:msgHeadProtoLen] - msgh = append(msgh, sub.im.prefix...) - msgh = append(msgh, subj...) - msgh = append(msgh, ' ') - si = len(msgh) + dsubj = append(_dsubj[:0], sub.im.prefix...) + dsubj = append(dsubj, subj...) } // Normal delivery - mh := c.msgHeader(msgh[:si], sub, creply) + mh := c.msgHeader(dsubj, sub, creply) didDeliver = c.deliverMsg(sub, subj, mh, msg, rplyHasGWPrefix) || didDeliver } @@ -3397,14 +3423,12 @@ func (c *client) processMsgResults(acc *Account, r *SublistResult, msg, deliver, break } - // Check for mapped subs + // Assume delivery subject is normal subject to this point. + dsubj = subj + // Check for stream import mapped subs. These apply to local subs only. if sub.im != nil && sub.im.prefix != "" { - // Redo the subject here on the fly. - msgh = c.msgb[1:msgHeadProtoLen] - msgh = append(msgh, sub.im.prefix...) - msgh = append(msgh, subject...) - msgh = append(msgh, ' ') - si = len(msgh) + dsubj = append(_dsubj[:0], sub.im.prefix...) + dsubj = append(dsubj, subj...) } var rreply = reply @@ -3413,7 +3437,7 @@ func (c *client) processMsgResults(acc *Account, r *SublistResult, msg, deliver, } // "rreply" will be stripped of the $GNR prefix (if present) // for client connections only. - mh := c.msgHeader(msgh[:si], sub, rreply) + mh := c.msgHeader(dsubj, sub, rreply) if c.deliverMsg(sub, subject, mh, msg, rplyHasGWPrefix) { didDeliver = true // Clear rsub diff --git a/server/client_test.go b/server/client_test.go index 91594f6e..8895f7c5 100644 --- a/server/client_test.go +++ b/server/client_test.go @@ -262,6 +262,144 @@ func TestClientHeaderSupport(t *testing.T) { } } +var hmsgPat = regexp.MustCompile(`HMSG\s+([^\s]+)\s+([^\s]+)\s+(([^\s]+)[^\S\r\n]+)?(\d+)[^\S\r\n]+(\d+)\r\n`) + +func TestClientHeaderDeliverMsg(t *testing.T) { + opts := defaultServerOptions + opts.Port = -1 + s := New(&opts) + + c, cr, _ := newClientForServer(s) + defer c.close() + + connect := "CONNECT {\"headers\":true}" + subOp := "SUB foo 1" + pubOp := "HPUB foo 12 14\r\nName:Derek\r\nOK\r\n" + cmd := strings.Join([]string{connect, subOp, pubOp}, "\r\n") + + c.parseAsync(cmd) + l, err := cr.ReadString('\n') + if err != nil { + t.Fatalf("Error receiving msg from server: %v\n", err) + } + + am := hmsgPat.FindAllStringSubmatch(l, -1) + if len(am) == 0 { + t.Fatalf("Did not get a match for %q", l) + } + matches := am[0] + if len(matches) != 7 { + t.Fatalf("Did not get correct # matches: %d vs %d\n", len(matches), 7) + } + if matches[SUB_INDEX] != "foo" { + t.Fatalf("Did not get correct subject: '%s'\n", matches[SUB_INDEX]) + } + if matches[SID_INDEX] != "1" { + t.Fatalf("Did not get correct sid: '%s'\n", matches[SID_INDEX]) + } + if matches[HDR_INDEX] != "12" { + t.Fatalf("Did not get correct msg length: '%s'\n", matches[HDR_INDEX]) + } + if matches[TLEN_INDEX] != "14" { + t.Fatalf("Did not get correct msg length: '%s'\n", matches[TLEN_INDEX]) + } + checkPayload(cr, []byte("Name:Derek\r\nOK\r\n"), t) +} + +var smsgPat = regexp.MustCompile(`^MSG\s+([^\s]+)\s+([^\s]+)\s+(([^\s]+)[^\S\r\n]+)?(\d+)\r\n`) + +func TestClientHeaderDeliverStrippedMsg(t *testing.T) { + opts := defaultServerOptions + opts.Port = -1 + s := New(&opts) + + c, _, _ := newClientForServer(s) + defer c.close() + + b, br, _ := newClientForServer(s) + defer b.close() + + // Does not support headers + b.parseAsync("SUB foo 1\r\nPING\r\n") + if _, err := br.ReadString('\n'); err != nil { + t.Fatalf("Error receiving msg from server: %v\n", err) + } + + connect := "CONNECT {\"headers\":true}" + pubOp := "HPUB foo 12 14\r\nName:Derek\r\nOK\r\n" + cmd := strings.Join([]string{connect, pubOp}, "\r\n") + c.parseAsync(cmd) + // Read from 'b' client. + l, err := br.ReadString('\n') + if err != nil { + t.Fatalf("Error receiving msg from server: %v\n", err) + } + am := smsgPat.FindAllStringSubmatch(l, -1) + if len(am) == 0 { + t.Fatalf("Did not get a correct match for %q", l) + } + matches := am[0] + if len(matches) != 6 { + t.Fatalf("Did not get correct # matches: %d vs %d\n", len(matches), 6) + } + if matches[SUB_INDEX] != "foo" { + t.Fatalf("Did not get correct subject: '%s'\n", matches[SUB_INDEX]) + } + if matches[SID_INDEX] != "1" { + t.Fatalf("Did not get correct sid: '%s'\n", matches[SID_INDEX]) + } + if matches[LEN_INDEX] != "14" { + t.Fatalf("Did not get correct msg length: '%s'\n", matches[LEN_INDEX]) + } + checkPayload(br, []byte("OK\r\n"), t) +} + +func TestClientHeaderDeliverQueueSubStrippedMsg(t *testing.T) { + opts := defaultServerOptions + opts.Port = -1 + s := New(&opts) + + c, _, _ := newClientForServer(s) + defer c.close() + + b, br, _ := newClientForServer(s) + defer b.close() + + // Does not support headers + b.parseAsync("SUB foo bar 1\r\nPING\r\n") + if _, err := br.ReadString('\n'); err != nil { + t.Fatalf("Error receiving msg from server: %v\n", err) + } + + connect := "CONNECT {\"headers\":true}" + pubOp := "HPUB foo 12 14\r\nName:Derek\r\nOK\r\n" + cmd := strings.Join([]string{connect, pubOp}, "\r\n") + c.parseAsync(cmd) + // Read from 'b' client. + l, err := br.ReadString('\n') + if err != nil { + t.Fatalf("Error receiving msg from server: %v\n", err) + } + am := smsgPat.FindAllStringSubmatch(l, -1) + if len(am) == 0 { + t.Fatalf("Did not get a correct match for %q", l) + } + matches := am[0] + if len(matches) != 6 { + t.Fatalf("Did not get correct # matches: %d vs %d\n", len(matches), 6) + } + if matches[SUB_INDEX] != "foo" { + t.Fatalf("Did not get correct subject: '%s'\n", matches[SUB_INDEX]) + } + if matches[SID_INDEX] != "1" { + t.Fatalf("Did not get correct sid: '%s'\n", matches[SID_INDEX]) + } + if matches[LEN_INDEX] != "14" { + t.Fatalf("Did not get correct msg length: '%s'\n", matches[LEN_INDEX]) + } + checkPayload(br, []byte("OK\r\n"), t) +} + func TestNonTLSConnectionState(t *testing.T) { _, c, _ := setupClient() defer c.close() @@ -435,6 +573,8 @@ const ( SID_INDEX = 2 REPLY_INDEX = 4 LEN_INDEX = 5 + HDR_INDEX = 5 + TLEN_INDEX = 6 ) func grabPayload(cr *bufio.Reader, expected int) []byte { diff --git a/server/parser.go b/server/parser.go index 7ce12048..f7e6df90 100644 --- a/server/parser.go +++ b/server/parser.go @@ -36,6 +36,7 @@ type pubArg struct { deliver []byte reply []byte szb []byte + hdb []byte queues [][]byte size int hdr int @@ -367,7 +368,7 @@ func (c *client) parse(buf []byte) error { c.drop, c.as, c.state = 0, i+1, OP_START // Drop all pub args c.pa.arg, c.pa.pacache, c.pa.account, c.pa.subject = nil, nil, nil, nil - c.pa.reply, c.pa.hdr, c.pa.size, c.pa.szb, c.pa.queues = nil, -1, 0, nil, nil + c.pa.reply, c.pa.hdr, c.pa.size, c.pa.szb, c.pa.hdb, c.pa.queues = nil, -1, 0, nil, nil, nil case OP_A: switch b { case '+': diff --git a/server/parser_test.go b/server/parser_test.go index 7afd269d..00528b88 100644 --- a/server/parser_test.go +++ b/server/parser_test.go @@ -306,16 +306,19 @@ func TestParseHeaderPub(t *testing.T) { t.Fatalf("Unexpected: %d : %v\n", c.state, err) } if !bytes.Equal(c.pa.subject, []byte("foo")) { - t.Fatalf("Did not parse subject correctly: 'foo' vs '%s'\n", c.pa.subject) + t.Fatalf("Did not parse subject correctly: 'foo' vs '%s'", c.pa.subject) } if c.pa.reply != nil { - t.Fatalf("Did not parse reply correctly: 'nil' vs '%s'\n", c.pa.reply) + t.Fatalf("Did not parse reply correctly: 'nil' vs '%s'", c.pa.reply) } if c.pa.hdr != 12 { - t.Fatalf("Did not parse msg header size correctly: 12 vs %d\n", c.pa.hdr) + t.Fatalf("Did not parse msg header size correctly: 12 vs %d", c.pa.hdr) + } + if !bytes.Equal(c.pa.hdb, []byte("12")) { + t.Fatalf("Did not parse or capture the header size as bytes correctly: %q", c.pa.hdb) } if c.pa.size != 17 { - t.Fatalf("Did not parse msg size correctly: 17 vs %d\n", c.pa.size) + t.Fatalf("Did not parse msg size correctly: 17 vs %d", c.pa.size) } // Clear snapshots @@ -326,16 +329,19 @@ func TestParseHeaderPub(t *testing.T) { t.Fatalf("Unexpected: %d : %v\n", c.state, err) } if !bytes.Equal(c.pa.subject, []byte("foo")) { - t.Fatalf("Did not parse subject correctly: 'foo' vs '%s'\n", c.pa.subject) + t.Fatalf("Did not parse subject correctly: 'foo' vs '%s'", c.pa.subject) } if !bytes.Equal(c.pa.reply, []byte("INBOX.22")) { - t.Fatalf("Did not parse reply correctly: 'INBOX.22' vs '%s'\n", c.pa.reply) + t.Fatalf("Did not parse reply correctly: 'INBOX.22' vs '%s'", c.pa.reply) } if c.pa.hdr != 12 { - t.Fatalf("Did not parse msg header size correctly: 12 vs %d\n", c.pa.hdr) + t.Fatalf("Did not parse msg header size correctly: 12 vs %d", c.pa.hdr) + } + if !bytes.Equal(c.pa.hdb, []byte("12")) { + t.Fatalf("Did not parse or capture the header size as bytes correctly: %q", c.pa.hdb) } if c.pa.size != 17 { - t.Fatalf("Did not parse msg size correctly: 17 vs %d\n", c.pa.size) + t.Fatalf("Did not parse msg size correctly: 17 vs %d", c.pa.size) } // Clear snapshots @@ -354,6 +360,9 @@ func TestParseHeaderPub(t *testing.T) { if c.pa.hdr != 0 { t.Fatalf("Did not parse msg header size correctly: 0 vs %d\n", c.pa.hdr) } + if !bytes.Equal(c.pa.hdb, []byte("0")) { + t.Fatalf("Did not parse or capture the header size as bytes correctly: %q", c.pa.hdb) + } if c.pa.size != 5 { t.Fatalf("Did not parse msg size correctly: 5 vs %d\n", c.pa.size) } diff --git a/server/server.go b/server/server.go index dc374025..a91c72e1 100644 --- a/server/server.go +++ b/server/server.go @@ -1846,7 +1846,6 @@ func (s *Server) createClient(conn net.Conn) *client { s.mu.Lock() info := s.copyInfo() c.nonce = []byte(info.Nonce) - c.headers = info.Headers s.totalClients++ s.mu.Unlock() @@ -2284,6 +2283,17 @@ func (s *Server) ReadyForConnections(dur time.Duration) bool { return false } +// Quick utility to function to tell if the server supports headers. +func (s *Server) supportsHeaders() bool { + if s == nil { + return false + } + s.mu.Lock() + ans := s.info.Headers + s.mu.Unlock() + return ans +} + // ID returns the server's ID func (s *Server) ID() string { s.mu.Lock()