From 41ec9359fcb334aa46d7b05345cc003918a562de Mon Sep 17 00:00:00 2001 From: Derek Collison Date: Wed, 5 May 2021 22:06:43 -0700 Subject: [PATCH] Update client to released version Signed-off-by: Derek Collison --- go.mod | 2 +- go.sum | 4 +- vendor/github.com/nats-io/nats.go/.travis.yml | 5 +- vendor/github.com/nats-io/nats.go/README.md | 160 ++-- vendor/github.com/nats-io/nats.go/go_test.mod | 2 +- vendor/github.com/nats-io/nats.go/go_test.sum | 17 + vendor/github.com/nats-io/nats.go/js.go | 326 ++++++-- vendor/github.com/nats-io/nats.go/jsm.go | 13 +- vendor/github.com/nats-io/nats.go/nats.go | 596 +++++++++------ vendor/github.com/nats-io/nats.go/ws.go | 700 ++++++++++++++++++ vendor/modules.txt | 2 +- 11 files changed, 1446 insertions(+), 381 deletions(-) create mode 100644 vendor/github.com/nats-io/nats.go/ws.go diff --git a/go.mod b/go.mod index d9a490b0..9a537ab8 100644 --- a/go.mod +++ b/go.mod @@ -7,7 +7,7 @@ require ( github.com/klauspost/compress v1.11.12 github.com/minio/highwayhash v1.0.1 github.com/nats-io/jwt/v2 v2.0.2 - github.com/nats-io/nats.go v1.10.1-0.20210419223411-20527524c393 + github.com/nats-io/nats.go v1.11.0 github.com/nats-io/nkeys v0.3.0 github.com/nats-io/nuid v1.0.1 golang.org/x/crypto v0.0.0-20210314154223-e6e6c4f2bb5b diff --git a/go.sum b/go.sum index 7df6a1ba..a713a0e8 100644 --- a/go.sum +++ b/go.sum @@ -16,8 +16,8 @@ github.com/nats-io/jwt v1.2.2 h1:w3GMTO969dFg+UOKTmmyuu7IGdusK+7Ytlt//OYH/uU= github.com/nats-io/jwt v1.2.2/go.mod h1:/xX356yQA6LuXI9xWW7mZNpxgF2mBmGecH+Fj34sP5Q= github.com/nats-io/jwt/v2 v2.0.2 h1:ejVCLO8gu6/4bOKIHQpmB5UhhUJfAQw55yvLWpfmKjI= github.com/nats-io/jwt/v2 v2.0.2/go.mod h1:VRP+deawSXyhNjXmxPCHskrR6Mq50BqpEI5SEcNiGlY= -github.com/nats-io/nats.go v1.10.1-0.20210419223411-20527524c393 h1:GQxfDz4otI9mde5QqJlpyRNpa2tfURHiPy0YLf7hy4c= -github.com/nats-io/nats.go v1.10.1-0.20210419223411-20527524c393/go.mod h1:BPko4oXsySz4aSWeFgOHLZs3G4Jq4ZAyE6/zMCxRT6w= +github.com/nats-io/nats.go v1.11.0 h1:L263PZkrmkRJRJT2YHU8GwWWvEvmr9/LUKuJTXsF32k= +github.com/nats-io/nats.go v1.11.0/go.mod h1:BPko4oXsySz4aSWeFgOHLZs3G4Jq4ZAyE6/zMCxRT6w= github.com/nats-io/nkeys v0.2.0/go.mod h1:XdZpAbhgyyODYqjTawOnIOI7VlbKSarI9Gfy1tqEu/s= github.com/nats-io/nkeys v0.3.0 h1:cgM5tL53EvYRU+2YLXIK0G2mJtK12Ft9oeooSZMA2G8= github.com/nats-io/nkeys v0.3.0/go.mod h1:gvUNGjVcM2IPr5rCsRsC6Wb3Hr2CQAm08dsxtV6A5y4= diff --git a/vendor/github.com/nats-io/nats.go/.travis.yml b/vendor/github.com/nats-io/nats.go/.travis.yml index f43acfe0..89c5c11f 100644 --- a/vendor/github.com/nats-io/nats.go/.travis.yml +++ b/vendor/github.com/nats-io/nats.go/.travis.yml @@ -15,6 +15,5 @@ before_script: - find . -type f -name "*.go" | xargs misspell -error -locale US - staticcheck ./... script: -- go test -modfile=go_test.mod -race ./... -- go test -modfile=go_test.mod -v -run=TestNoRace -p=1 ./... -- if [[ "$TRAVIS_GO_VERSION" =~ 1.16 ]]; then ./scripts/cov.sh TRAVIS; else go test -race -v -p=1 ./... --failfast; fi +- go test -modfile=go_test.mod -v -run=TestNoRace -p=1 ./... --failfast +- if [[ "$TRAVIS_GO_VERSION" =~ 1.16 ]]; then ./scripts/cov.sh TRAVIS; else go test -modfile=go_test.mod -race -v -p=1 ./... --failfast; fi diff --git a/vendor/github.com/nats-io/nats.go/README.md b/vendor/github.com/nats-io/nats.go/README.md index d86ba668..f6ecfc50 100644 --- a/vendor/github.com/nats-io/nats.go/README.md +++ b/vendor/github.com/nats-io/nats.go/README.md @@ -21,7 +21,7 @@ When using or transitioning to Go modules support: ```bash # Go client latest or explicit version go get github.com/nats-io/nats.go/@latest -go get github.com/nats-io/nats.go/@v1.10.0 +go get github.com/nats-io/nats.go/@v1.11.0 # For latest NATS Server, add /v2 at the end go get github.com/nats-io/nats-server/v2 @@ -82,6 +82,85 @@ nc.Drain() nc.Close() ``` +## JetStream Basic Usage + +```go +import "github.com/nats-io/nats.go" + +// Connect to NATS +nc, _ := nats.Connect(nats.DefaultURL) + +// Create JetStream Context +js, _ := nc.JetStream(nats.PublishAsyncMaxPending(256)) + +// Simple Stream Publisher +js.Publish("ORDERS.scratch", []byte("hello")) + +// Simple Async Stream Publisher +for i := 0; i < 500; i++ { + js.PublishAsync("ORDERS.scratch", []byte("hello")) +} +select { +case <-js.PublishAsyncComplete(): +case <-time.After(5 * time.Second): + fmt.Println("Did not resolve in time") +} + +// Simple Async Ephemeral Consumer +js.Subscribe("ORDERS.*", func(m *nats.Msg) { + fmt.Printf("Received a JetStream message: %s\n", string(m.Data)) +}) + +// Simple Sync Durable Consumer (optional SubOpts at the end) +sub, err := js.SubscribeSync("ORDERS.*", nats.Durable("MONITOR"), nats.MaxDeliver(3)) +m, err := sub.NextMsg(timeout) + +// Simple Pull Consumer +sub, err := js.PullSubscribe("ORDERS.*", "MONITOR") +msgs, err := sub.Fetch(10) + +// Unsubscribe +sub.Unsubscribe() + +// Drain +sub.Drain() +``` + +## JetStream Basic Management + +```go +import "github.com/nats-io/nats.go" + +// Connect to NATS +nc, _ := nats.Connect(nats.DefaultURL) + +// Create JetStream Context +js, _ := nc.JetStream() + +// Create a Stream +js.AddStream(&nats.StreamConfig{ + Name: "ORDERS", + Subjects: []string{"ORDERS.*"}, +}) + +// Update a Stream +js.UpdateStream(&nats.StreamConfig{ + Name: "ORDERS", + MaxBytes: 8, +}) + +// Create a Consumer +js.AddConsumer("ORDERS", &nats.ConsumerConfig{ + Durable: "MONITOR", +}) + +// Delete Consumer +js.DeleteConsumer("ORDERS", "MONITOR") + +// Delete Stream +js.DeleteStream("ORDERS") +``` + ## Encoded Connections ```go @@ -422,85 +501,6 @@ resp := &response{} err := c.RequestWithContext(ctx, "foo", req, resp) ``` -## JetStream Basic Usage - -```go -import "github.com/nats-io/nats.go" - -// Connect to NATS -nc, _ := nats.Connect(nats.DefaultURL) - -// Create JetStream Context -js, _ := nc.JetStream(nats.PublishAsyncMaxPending(256)) - -// Simple Stream Publisher -js.Publish("ORDERS.scratch", []byte("hello")) - -// Simple Async Stream Publisher -for i := 0; i < 500; i++ { - js.PublishAsync("ORDERS.scratch", []byte("hello")) -} -select { -case <-js.PublishAsyncComplete(): -case <-time.After(5 * time.Second): - fmt.Println("Did not resolve in time") -} - -// Simple Async Ephemeral Consumer -js.Subscribe("ORDERS.*", func(m *nats.Msg) { - fmt.Printf("Received a JetStream message: %s\n", string(m.Data)) -}) - -// Simple Sync Durable Consumer (optional SubOpts at the end) -sub, err := js.SubscribeSync("ORDERS.*", nats.Durable("MONITOR"), nats.MaxDeliver(3)) -m, err := sub.NextMsg(timeout) - -// Simple Pull Consumer -sub, err := js.PullSubscribe("ORDERS.*", "MONITOR") -msgs, err := sub.Fetch(10) - -// Unsubscribe -sub.Unsubscribe() - -// Drain -sub.Drain() -``` - -## JetStream Basic Management - -```go -import "github.com/nats-io/nats.go" - -// Connect to NATS -nc, _ := nats.Connect(nats.DefaultURL) - -// Create JetStream Context -js, _ := nc.JetStream() - -// Create a Stream -js.AddStream(&nats.StreamConfig{ - Name: "ORDERS", - Subjects: []string{"ORDERS.*"}, -}) - -// Update a Stream -js.UpdateStream(&nats.StreamConfig{ - Name: "ORDERS", - MaxBytes: 8, -}) - -// Create a Consumer -js.AddConsumer("ORDERS", &nats.ConsumerConfig{ - Durable: "MONITOR", -}) - -// Delete Consumer -js.DeleteConsumer("ORDERS", "MONITOR") - -// Delete Stream -js.DeleteStream("ORDERS") -``` - ## License Unless otherwise noted, the NATS source files are distributed diff --git a/vendor/github.com/nats-io/nats.go/go_test.mod b/vendor/github.com/nats-io/nats.go/go_test.mod index f07a6629..72e30d62 100644 --- a/vendor/github.com/nats-io/nats.go/go_test.mod +++ b/vendor/github.com/nats-io/nats.go/go_test.mod @@ -4,7 +4,7 @@ go 1.15 require ( github.com/golang/protobuf v1.4.2 - github.com/nats-io/nats-server/v2 v2.2.1 + github.com/nats-io/nats-server/v2 v2.2.3-0.20210501163444-670f44f1e82e github.com/nats-io/nkeys v0.3.0 github.com/nats-io/nuid v1.0.1 google.golang.org/protobuf v1.23.0 diff --git a/vendor/github.com/nats-io/nats.go/go_test.sum b/vendor/github.com/nats-io/nats.go/go_test.sum index 0e1fafb6..7567402c 100644 --- a/vendor/github.com/nats-io/nats.go/go_test.sum +++ b/vendor/github.com/nats-io/nats.go/go_test.sum @@ -35,6 +35,22 @@ github.com/nats-io/nats-server/v2 v2.1.8-0.20210227190344-51550e242af8/go.mod h1 github.com/nats-io/nats-server/v2 v2.2.1-0.20210330155036-61cbd74e213d/go.mod h1:eKlAaGmSQHZMFQA6x56AaP5/Bl9N3mWF4awyT2TTpzc= github.com/nats-io/nats-server/v2 v2.2.1 h1:QaWKih9qAa1kod7xXy0G1ry0AEUGmDEaptaiqzuO1e8= github.com/nats-io/nats-server/v2 v2.2.1/go.mod h1:A+5EOqdnhH7FvLxtAK6SEDx6hyHriVOwf+FT/eEV99c= +github.com/nats-io/nats-server/v2 v2.2.2-0.20210421001316-7ac0ff667439 h1:wbm+DoCrBx3XUkfgfnzSGKGKXSSnR8z0EzaH8iEsYT4= +github.com/nats-io/nats-server/v2 v2.2.2-0.20210421001316-7ac0ff667439/go.mod h1:A+5EOqdnhH7FvLxtAK6SEDx6hyHriVOwf+FT/eEV99c= +github.com/nats-io/nats-server/v2 v2.2.2-0.20210421031524-a3f66508dd3a h1:Ihh+7S9hHb3zn4nibE9EV8P3Ed7OrH4TlGXHqIUYDfk= +github.com/nats-io/nats-server/v2 v2.2.2-0.20210421031524-a3f66508dd3a/go.mod h1:aF2IwMZdYktJswITm41c/k66uCHjTvpTxGQ7+d4cPeg= +github.com/nats-io/nats-server/v2 v2.2.2-0.20210421135834-a9607573b30c h1:URcPI+y2OIGWM1pKzHhHTvRItB0Czlv3dzuJA0rklvk= +github.com/nats-io/nats-server/v2 v2.2.2-0.20210421135834-a9607573b30c/go.mod h1:aF2IwMZdYktJswITm41c/k66uCHjTvpTxGQ7+d4cPeg= +github.com/nats-io/nats-server/v2 v2.2.2-0.20210421164150-3d928c847a0c h1:cbbxAcABuk2WdXKRm9VezFcGsceRhls4VCmQ/2aRJjQ= +github.com/nats-io/nats-server/v2 v2.2.2-0.20210421164150-3d928c847a0c/go.mod h1:aF2IwMZdYktJswITm41c/k66uCHjTvpTxGQ7+d4cPeg= +github.com/nats-io/nats-server/v2 v2.2.2-0.20210421195432-ea21e86996f7 h1:wcd++VZMdwDpQ7P1VXJ7NpAwtgdlxcjFLZ12Y/pL8Nw= +github.com/nats-io/nats-server/v2 v2.2.2-0.20210421195432-ea21e86996f7/go.mod h1:aF2IwMZdYktJswITm41c/k66uCHjTvpTxGQ7+d4cPeg= +github.com/nats-io/nats-server/v2 v2.2.2-0.20210421215445-a48a39251636 h1:iy6c/tV66xi5DT9WLUu9rJ8uQj8Kf7kmwHAqlYfczP4= +github.com/nats-io/nats-server/v2 v2.2.2-0.20210421215445-a48a39251636/go.mod h1:aF2IwMZdYktJswITm41c/k66uCHjTvpTxGQ7+d4cPeg= +github.com/nats-io/nats-server/v2 v2.2.2-0.20210421232642-f2d3f5fb81d0 h1:e2MoeAShQE/oOSjkkV6J6R+l5ugbfkXI5spxgQykgoM= +github.com/nats-io/nats-server/v2 v2.2.2-0.20210421232642-f2d3f5fb81d0/go.mod h1:aF2IwMZdYktJswITm41c/k66uCHjTvpTxGQ7+d4cPeg= +github.com/nats-io/nats-server/v2 v2.2.3-0.20210501163444-670f44f1e82e h1:Hvpz1/Epth4q7LnaU0U9SqMFd8grUMFTL8LMO5HFVok= +github.com/nats-io/nats-server/v2 v2.2.3-0.20210501163444-670f44f1e82e/go.mod h1:aF2IwMZdYktJswITm41c/k66uCHjTvpTxGQ7+d4cPeg= github.com/nats-io/nats.go v1.10.0/go.mod h1:AjGArbfyR50+afOUotNX2Xs5SYHf+CoOa5HH1eEl2HE= github.com/nats-io/nats.go v1.10.1-0.20200531124210-96f2130e4d55/go.mod h1:ARiFsjW9DVxk48WJbO3OSZ2DG8fjkMi7ecLmXoY/n9I= github.com/nats-io/nats.go v1.10.1-0.20200606002146-fc6fed82929a/go.mod h1:8eAIv96Mo9QW6Or40jUHejS7e4VwZ3VRYD6Sf0BTDp4= @@ -43,6 +59,7 @@ github.com/nats-io/nats.go v1.10.1-0.20210127212649-5b4924938a9a/go.mod h1:Sa3kL github.com/nats-io/nats.go v1.10.1-0.20210211000709-75ded9c77585/go.mod h1:uBWnCKg9luW1g7hgzPxUjHFRI40EuTSX7RCzgnc74Jk= github.com/nats-io/nats.go v1.10.1-0.20210228004050-ed743748acac/go.mod h1:hxFvLNbNmT6UppX5B5Tr/r3g+XSwGjJzFn6mxPNJEHc= github.com/nats-io/nats.go v1.10.1-0.20210330225420-a0b1f60162f8/go.mod h1:Zq9IEHy7zurF0kFbU5aLIknnFI7guh8ijHk+2v+Vf5g= +github.com/nats-io/nats.go v1.10.1-0.20210419223411-20527524c393/go.mod h1:BPko4oXsySz4aSWeFgOHLZs3G4Jq4ZAyE6/zMCxRT6w= github.com/nats-io/nkeys v0.1.3/go.mod h1:xpnFELMwJABBLVhffcfd1MZx6VsNRFpEugbxziKVo7w= github.com/nats-io/nkeys v0.1.4/go.mod h1:XdZpAbhgyyODYqjTawOnIOI7VlbKSarI9Gfy1tqEu/s= github.com/nats-io/nkeys v0.2.0/go.mod h1:XdZpAbhgyyODYqjTawOnIOI7VlbKSarI9Gfy1tqEu/s= diff --git a/vendor/github.com/nats-io/nats.go/js.go b/vendor/github.com/nats-io/nats.go/js.go index 626bd819..7b70bb1e 100644 --- a/vendor/github.com/nats-io/nats.go/js.go +++ b/vendor/github.com/nats-io/nats.go/js.go @@ -20,7 +20,7 @@ import ( "encoding/json" "errors" "fmt" - "net/http" + "math/rand" "strconv" "strings" "sync" @@ -146,6 +146,7 @@ type js struct { pafs map[string]*pubAckFuture stc chan struct{} dch chan struct{} + rr *rand.Rand } type jsOpts struct { @@ -283,7 +284,7 @@ func (js *js) PublishMsg(m *Msg, opts ...PubOpt) (*PubAck, error) { var o pubOpts if len(opts) > 0 { if m.Header == nil { - m.Header = http.Header{} + m.Header = Header{} } for _, opt := range opts { if err := opt.configurePublish(&o); err != nil { @@ -401,6 +402,27 @@ func (paf *pubAckFuture) Msg() *Msg { return paf.msg } +// pullSubscribe creates the wildcard subscription used per pull subscriber +// to make fetch requests. +func (js *js) pullSubscribe(subj string) (*Subscription, error) { + jsi := &jsSub{js: js, pull: true} + + // Similar to async request handler we create a wildcard subscription for making requests, + // though we do not use the token based approach since we cannot match the response to + // the requestor due to JS subject being remapped on delivery. Instead, we just use an array + // of channels similar to how ping/pong interval is handled and send the message to the first + // available requestor via a channel. + jsi.rr = rand.New(rand.NewSource(time.Now().UnixNano())) + jsi.rpre = fmt.Sprintf("%s.", NewInbox()) + sub, err := js.nc.Subscribe(fmt.Sprintf("%s*", jsi.rpre), jsi.handleFetch) + if err != nil { + return nil, err + } + jsi.psub = sub + + return &Subscription{Subject: subj, conn: js.nc, typ: PullSubscription, jsi: jsi}, nil +} + // For quick token lookup etc. const aReplyPreLen = 14 const aReplyTokensize = 6 @@ -422,10 +444,11 @@ func (js *js) newAsyncReply() string { return _EMPTY_ } js.rsub = sub + js.rr = rand.New(rand.NewSource(time.Now().UnixNano())) } var sb strings.Builder sb.WriteString(js.rpre) - rn := js.nc.respRand.Int63() + rn := js.rr.Int63() var b [aReplyTokensize]byte for i, l := 0, rn; i < len(b); i++ { b[i] = rdigits[l%base] @@ -584,7 +607,7 @@ func (js *js) PublishMsgAsync(m *Msg, opts ...PubOpt) (PubAckFuture, error) { var o pubOpts if len(opts) > 0 { if m.Header == nil { - m.Header = http.Header{} + m.Header = Header{} } for _, opt := range opts { if err := opt.configurePublish(&o); err != nil { @@ -754,7 +777,8 @@ func (ctx ContextOpt) configureAck(opts *ackOpts) error { return nil } -// Context returns an option that can be used to configure a context. +// Context returns an option that can be used to configure a context for APIs +// that are context aware such as those part of the JetStream interface. func Context(ctx context.Context) ContextOpt { return ContextOpt{ctx} } @@ -811,7 +835,15 @@ type nextRequest struct { // jsSub includes JetStream subscription info. type jsSub struct { - js *js + js *js + + // To setup request mux handler for pull subscribers. + mu sync.RWMutex + psub *Subscription + rpre string + rr *rand.Rand + freqs []chan *Msg + consumer string stream string deliver string @@ -820,17 +852,84 @@ type jsSub struct { attached bool // Heartbeats and Flow Control handling from push consumers. - hbs bool - fc bool - - // cmeta is holds metadata from a push consumer when HBs are enabled. - cmeta atomic.Value + hbs bool + fc bool + cmeta string + fcs map[uint64]string } -// controlMetadata is metadata used to be able to detect sequence mismatch -// errors in push based consumers that have heartbeats enabled. -type controlMetadata struct { - meta string +// newFetchReply generates a unique inbox used for a fetch request. +func (jsi *jsSub) newFetchReply() string { + jsi.mu.Lock() + rpre := jsi.rpre + rn := jsi.rr.Int63() + jsi.mu.Unlock() + var sb strings.Builder + sb.WriteString(rpre) + var b [aReplyTokensize]byte + for i, l := 0, rn; i < len(b); i++ { + b[i] = rdigits[l%base] + l /= base + } + sb.Write(b[:]) + return sb.String() +} + +// handleFetch is delivered a message requested by pull subscribers +// when calling Fetch. +func (jsi *jsSub) handleFetch(m *Msg) { + jsi.mu.Lock() + if len(jsi.freqs) == 0 { + nc := jsi.js.nc + sub := jsi.psub + nc.mu.Lock() + errCB := nc.Opts.AsyncErrorCB + err := fmt.Errorf("nats: fetch response delivered but requestor has gone away") + if errCB != nil { + nc.ach.push(func() { errCB(nc, sub, err) }) + } + nc.mu.Unlock() + jsi.mu.Unlock() + return + } + mch := jsi.freqs[0] + if len(jsi.freqs) > 1 { + jsi.freqs = append(jsi.freqs[:0], jsi.freqs[1:]...) + } else { + jsi.freqs = jsi.freqs[:0] + } + jsi.mu.Unlock() + mch <- m +} + +// fetchNoWait makes a request to get a single message using no wait. +func (jsi *jsSub) fetchNoWait(ctx context.Context, subj string, payload []byte) (*Msg, error) { + nc := jsi.js.nc + m := NewMsg(subj) + m.Reply = jsi.newFetchReply() + m.Data = payload + + mch := make(chan *Msg, 1) + jsi.mu.Lock() + jsi.freqs = append(jsi.freqs, mch) + jsi.mu.Unlock() + if err := nc.PublishMsg(m); err != nil { + return nil, err + } + + var ok bool + var msg *Msg + + select { + case msg, ok = <-mch: + if !ok { + return nil, ErrConnectionClosed + } + case <-ctx.Done(): + return nil, ctx.Err() + } + + return msg, nil } func (jsi *jsSub) unsubscribe(drainMode bool) error { @@ -839,6 +938,11 @@ func (jsi *jsSub) unsubscribe(drainMode bool) error { // consumers when using drain mode. return nil } + // Clear the extra async pull subscription used for fetch requests. + if jsi.psub != nil { + jsi.psub.Drain() + } + js := jsi.js return js.DeleteConsumer(jsi.stream, jsi.consumer) } @@ -979,12 +1083,18 @@ func (js *js) subscribe(subj, queue string, cb MsgHandler, ch chan *Msg, isSync } if isPullMode { - sub = &Subscription{Subject: subj, conn: js.nc, typ: PullSubscription, jsi: &jsSub{js: js, pull: isPullMode}} + sub, err = js.pullSubscribe(subj) } else { sub, err = js.nc.subscribe(deliver, queue, cb, ch, isSync, &jsSub{js: js, hbs: hasHeartbeats, fc: hasFC}) - if err != nil { - return nil, err - } + } + if err != nil { + return nil, err + } + + // With flow control enabled async subscriptions we will disable msgs + // limits, and set a larger pending bytes limit by default. + if !isPullMode && cb != nil && hasFC { + sub.SetPendingLimits(DefaultSubPendingMsgsLimit*16, DefaultSubPendingBytesLimit) } // If we are creating or updating let's process that request. @@ -1020,31 +1130,30 @@ func (js *js) subscribe(subj, queue string, cb MsgHandler, ch chan *Msg, isSync resp, err := js.nc.Request(js.apiSubj(ccSubj), j, js.opts.wait) if err != nil { + sub.Drain() if err == ErrNoResponders { err = ErrJetStreamNotEnabled } - sub.Unsubscribe() return nil, err } - var cinfo consumerResponse err = json.Unmarshal(resp.Data, &cinfo) if err != nil { - sub.Unsubscribe() + sub.Drain() return nil, err } info = cinfo.ConsumerInfo if cinfo.Error != nil { // Remove interest from previous subscribe since it // may have an incorrect delivery subject. - sub.Unsubscribe() + sub.Drain() // Multiple subscribers could compete in creating the first consumer // that will be shared using the same durable name. If this happens, then // do a lookup of the consumer info and resubscribe using the latest info. - if consumer != _EMPTY_ && strings.Contains(cinfo.Error.Description, `consumer already exists`) { + if consumer != _EMPTY_ && (strings.Contains(cinfo.Error.Description, `consumer already exists`) || strings.Contains(cinfo.Error.Description, `consumer name already in use`)) { info, err = js.ConsumerInfo(stream, consumer) - if err != nil && err.Error() != "nats: consumer not found" { + if err != nil { return nil, err } ccfg = &info.Config @@ -1056,6 +1165,10 @@ func (js *js) subscribe(subj, queue string, cb MsgHandler, ch chan *Msg, isSync // Use the deliver subject from latest consumer config to attach. if ccfg.DeliverSubject != _EMPTY_ { + // We can't reuse the channel, so if one was passed, we need to create a new one. + if ch != nil { + ch = make(chan *Msg, cap(ch)) + } sub, err = js.nc.subscribe(ccfg.DeliverSubject, queue, cb, ch, isSync, &jsSub{js: js, hbs: hasHeartbeats, fc: hasFC}) if err != nil { @@ -1071,49 +1184,130 @@ func (js *js) subscribe(subj, queue string, cb MsgHandler, ch chan *Msg, isSync consumer = info.Name deliver = info.Config.DeliverSubject } + sub.mu.Lock() sub.jsi.stream = stream sub.jsi.consumer = consumer sub.jsi.durable = isDurable sub.jsi.attached = attached sub.jsi.deliver = deliver + sub.mu.Unlock() return sub, nil } -func (nc *Conn) processControlFlow(msg *Msg, s *Subscription, jsi *jsSub) { - // If it is a flow control message then have to ack. - if msg.Reply != "" { - nc.publish(msg.Reply, _EMPTY_, nil, nil) - } else if jsi.hbs { - // Process heartbeat received, get latest control metadata if present. - var ctrl *controlMetadata - cmeta := jsi.cmeta.Load() - if cmeta == nil { - return - } +// ErrConsumerSequenceMismatch represents an error from a consumer +// that received a Heartbeat including sequence different to the +// one expected from the view of the client. +type ErrConsumerSequenceMismatch struct { + // StreamResumeSequence is the stream sequence from where the consumer + // should resume consuming from the stream. + StreamResumeSequence uint64 - ctrl = cmeta.(*controlMetadata) - tokens, err := getMetadataFields(ctrl.meta) - if err != nil { - return - } - // Consumer sequence - dseq := tokens[6] - ldseq := msg.Header.Get(lastConsumerSeqHdr) + // ConsumerSequence is the sequence of the consumer that is behind. + ConsumerSequence uint64 - // Detect consumer sequence mismatch and whether - // should restart the consumer. - if ldseq != dseq { - // Dispatch async error including details such as - // from where the consumer could be restarted. - sseq := parseNum(tokens[5]) - ecs := &ErrConsumerSequenceMismatch{ - StreamResumeSequence: uint64(sseq), - ConsumerSequence: parseNum(dseq), - LastConsumerSequence: parseNum(ldseq), - } - nc.handleConsumerSequenceMismatch(s, ecs) + // LastConsumerSequence is the sequence of the consumer when the heartbeat + // was received. + LastConsumerSequence uint64 +} + +func (ecs *ErrConsumerSequenceMismatch) Error() string { + return fmt.Sprintf("nats: sequence mismatch for consumer at sequence %d (%d sequences behind), should restart consumer from stream sequence %d", + ecs.ConsumerSequence, + ecs.LastConsumerSequence-ecs.ConsumerSequence, + ecs.StreamResumeSequence, + ) +} + +// isControlMessage will return true if this is an empty control status message. +func isControlMessage(msg *Msg) bool { + return len(msg.Data) == 0 && msg.Header.Get(statusHdr) == controlMsg +} + +func (jsi *jsSub) trackSequences(reply string) { + jsi.mu.Lock() + jsi.cmeta = reply + jsi.mu.Unlock() +} + +// checkForFlowControlResponse will check to see if we should send a flow control response +// based on the delivered index. +// Lock should be held. +func (sub *Subscription) checkForFlowControlResponse(delivered uint64) { + jsi, nc := sub.jsi, sub.conn + if jsi == nil { + return + } + + jsi.mu.Lock() + defer jsi.mu.Unlock() + + if len(jsi.fcs) == 0 { + return + } + + if reply := jsi.fcs[delivered]; reply != _EMPTY_ { + delete(jsi.fcs, delivered) + nc.Publish(reply, nil) + } +} + +// Record an inbound flow control message. +func (jsi *jsSub) scheduleFlowControlResponse(dfuture uint64, reply string) { + jsi.mu.Lock() + if jsi.fcs == nil { + jsi.fcs = make(map[uint64]string) + } + jsi.fcs[dfuture] = reply + jsi.mu.Unlock() +} + +// handleConsumerSequenceMismatch will send an async error that can be used to restart a push based consumer. +func (nc *Conn) handleConsumerSequenceMismatch(sub *Subscription, err error) { + nc.mu.Lock() + errCB := nc.Opts.AsyncErrorCB + if errCB != nil { + nc.ach.push(func() { errCB(nc, sub, err) }) + } + nc.mu.Unlock() +} + +// processControlFlow will automatically respond to control messages sent by the server. +func (nc *Conn) processSequenceMismatch(msg *Msg, s *Subscription, jsi *jsSub) { + // Process heartbeat received, get latest control metadata if present. + jsi.mu.RLock() + ctrl := jsi.cmeta + jsi.mu.RUnlock() + + if ctrl == _EMPTY_ { + return + } + + tokens, err := getMetadataFields(ctrl) + if err != nil { + return + } + + // Consumer sequence. + var ldseq string + dseq := tokens[6] + hdr := msg.Header[lastConsumerSeqHdr] + if len(hdr) == 1 { + ldseq = hdr[0] + } + + // Detect consumer sequence mismatch and whether + // should restart the consumer. + if ldseq != dseq { + // Dispatch async error including details such as + // from where the consumer could be restarted. + sseq := parseNum(tokens[5]) + ecs := &ErrConsumerSequenceMismatch{ + StreamResumeSequence: uint64(sseq), + ConsumerSequence: uint64(parseNum(dseq)), + LastConsumerSequence: uint64(parseNum(ldseq)), } + nc.handleConsumerSequenceMismatch(s, ecs) } } @@ -1364,7 +1558,8 @@ func (sub *Subscription) Fetch(batch int, opts ...PullOpt) ([]*Msg, error) { } sub.mu.Lock() - if sub.jsi == nil || sub.typ != PullSubscription { + jsi := sub.jsi + if jsi == nil || sub.typ != PullSubscription { sub.mu.Unlock() return nil, ErrTypeSubscription } @@ -1437,9 +1632,10 @@ func (sub *Subscription) Fetch(batch int, opts ...PullOpt) ([]*Msg, error) { msgs = make([]*Msg, 0) ) - // In case of only one message, then can already handle with built-in request functions. if batch == 1 { - resp, err := nc.oldRequestWithContext(ctx, reqNext, nil, req) + // To optimize single message no wait fetch, we use a shared wildcard + // subscription per pull subscriber to wait for the response. + resp, err := jsi.fetchNoWait(ctx, reqNext, req) if err != nil { return nil, checkCtxErr(err) } @@ -1646,12 +1842,12 @@ func (js *js) getConsumerInfoContext(ctx context.Context, stream, consumer strin return info.ConsumerInfo, nil } -func (m *Msg) checkReply() (*js, bool, error) { +func (m *Msg) checkReply() (*js, *jsSub, error) { if m == nil || m.Sub == nil { - return nil, false, ErrMsgNotBound + return nil, nil, ErrMsgNotBound } if m.Reply == "" { - return nil, false, ErrMsgNoReply + return nil, nil, ErrMsgNoReply } sub := m.Sub sub.mu.Lock() @@ -1659,13 +1855,13 @@ func (m *Msg) checkReply() (*js, bool, error) { sub.mu.Unlock() // Not using a JS context. - return nil, false, nil + return nil, nil, nil } js := sub.jsi.js - isPullMode := sub.jsi.pull + jsi := sub.jsi sub.mu.Unlock() - return js, isPullMode, nil + return js, jsi, nil } // ackReply handles all acks. Will do the right thing for pull and sync mode. diff --git a/vendor/github.com/nats-io/nats.go/jsm.go b/vendor/github.com/nats-io/nats.go/jsm.go index 6b1900fa..b6a3f8b7 100644 --- a/vendor/github.com/nats-io/nats.go/jsm.go +++ b/vendor/github.com/nats-io/nats.go/jsm.go @@ -18,7 +18,6 @@ import ( "encoding/json" "errors" "fmt" - "net/http" "strings" "time" ) @@ -532,6 +531,10 @@ func (js *js) AddStream(cfg *StreamConfig, opts ...JSOpt) (*StreamInfo, error) { return nil, ErrStreamNameRequired } + if strings.Contains(cfg.Name, ".") { + return nil, ErrInvalidStreamName + } + req, err := json.Marshal(cfg) if err != nil { return nil, err @@ -555,6 +558,10 @@ func (js *js) AddStream(cfg *StreamConfig, opts ...JSOpt) (*StreamInfo, error) { type streamInfoResponse = streamCreateResponse func (js *js) StreamInfo(stream string, opts ...JSOpt) (*StreamInfo, error) { + if strings.Contains(stream, ".") { + return nil, ErrInvalidStreamName + } + o, cancel, err := getJSContextOpts(js.opts, opts...) if err != nil { return nil, err @@ -701,7 +708,7 @@ type apiMsgGetRequest struct { type RawStreamMsg struct { Subject string Sequence uint64 - Header http.Header + Header Header Data []byte Time time.Time } @@ -757,7 +764,7 @@ func (js *js) GetMsg(name string, seq uint64, opts ...JSOpt) (*RawStreamMsg, err msg := resp.Message - var hdr http.Header + var hdr Header if msg.Header != nil { hdr, err = decodeHeadersMsg(msg.Header) if err != nil { diff --git a/vendor/github.com/nats-io/nats.go/nats.go b/vendor/github.com/nats-io/nats.go/nats.go index 3f025261..88f52810 100644 --- a/vendor/github.com/nats-io/nats.go/nats.go +++ b/vendor/github.com/nats-io/nats.go/nats.go @@ -426,6 +426,10 @@ type Options struct { // is established, and if a ClosedHandler is set, it will be invoked if // it fails to connect (after exhausting the MaxReconnect attempts). RetryOnFailedConnect bool + + // For websocket connections, indicates to the server that the connection + // supports compression. If the server does too, then data will be compressed. + Compression bool } const ( @@ -466,8 +470,8 @@ type Conn struct { current *srv urls map[string]struct{} // Keep track of all known URLs (used by processInfo) conn net.Conn - bw *bufio.Writer - pending *bytes.Buffer + bw *natsWriter + br *natsReader fch chan struct{} info serverInfo ssid int64 @@ -484,6 +488,7 @@ type Conn struct { pout int ar bool // abort reconnect rqch chan struct{} + ws bool // true if a websocket connection // New style response handler respSub string // The wildcard subject @@ -496,6 +501,21 @@ type Conn struct { jsLastCheck time.Time } +type natsReader struct { + r io.Reader + buf []byte + off int + n int +} + +type natsWriter struct { + w io.Writer + bufs []byte + limit int + pending *bytes.Buffer + plimit int +} + // Subscription represents interest in a given subject. type Subscription struct { mu sync.Mutex @@ -562,7 +582,7 @@ type Subscription struct { type Msg struct { Subject string Reply string - Header http.Header + Header Header Data []byte Sub *Subscription next *Msg @@ -582,7 +602,7 @@ func (m *Msg) headerBytes() ([]byte, error) { return nil, ErrBadHeaderMsg } - err = m.Header.Write(&b) + err = http.Header(m.Header).Write(&b) if err != nil { return nil, ErrBadHeaderMsg } @@ -676,6 +696,8 @@ type MsgHandler func(msg *Msg) // The url can contain username/password semantics. e.g. nats://derek:pass@localhost:4222 // Comma separated arrays are also supported, e.g. urlA, urlB. // Options start with the defaults but can be overridden. +// To connect to a NATS Server's websocket port, use the `ws` or `wss` scheme, such as +// `ws://localhost:8080`. Note that websocket schemes cannot be mixed with others (nats/tls). func Connect(url string, options ...Option) (*Conn, error) { opts := GetDefaultOptions() opts.Servers = processUrlString(url) @@ -1070,6 +1092,15 @@ func RetryOnFailedConnect(retry bool) Option { } } +// Compression is an Option to indicate if this connection supports +// compression. Currently only supported for Websocket connections. +func Compression(enabled bool) Option { + return func(o *Options) error { + o.Compression = enabled + return nil + } +} + // Handler processing // SetDisconnectHandler will set the disconnect event handler. @@ -1194,6 +1225,9 @@ func (o Options) Connect() (*Conn, error) { nc.Opts.AsyncErrorCB = defaultErrHandler } + // Create reader/writer + nc.newReaderWriter() + if err := nc.connect(); err != nil { return nil, err } @@ -1228,6 +1262,8 @@ const ( _HPUB_P_ = "HPUB " ) +var _CRLF_BYTES_ = []byte(_CRLF_) + const ( _OK_OP_ = "+OK" _ERR_OP_ = "-ERR" @@ -1355,6 +1391,12 @@ func (nc *Conn) setupServerPool() error { // Helper function to return scheme func (nc *Conn) connScheme() string { + if nc.ws { + if nc.Opts.Secure { + return wsSchemeTLS + } + return wsScheme + } if nc.Opts.Secure { return tlsScheme } @@ -1391,6 +1433,16 @@ func (nc *Conn) addURLToPool(sURL string, implicit, saveTLSName bool) error { sURL += defaultPortString } + isWS := isWebsocketScheme(u) + // We don't support mix and match of websocket and non websocket URLs. + // If this is the first URL, then we accept and switch the global state + // to websocket. After that, we will know how to reject mixed URLs. + if len(nc.srvPool) == 0 { + nc.ws = isWS + } else if isWS && !nc.ws || !isWS && nc.ws { + return fmt.Errorf("mixing of websocket and non websocket URLs is not allowed") + } + var tlsName string if implicit { curl := nc.current.url @@ -1428,12 +1480,145 @@ func (nc *Conn) shufflePool(offset int) { } } -func (nc *Conn) newBuffer() *bufio.Writer { +func (nc *Conn) newReaderWriter() { + nc.br = &natsReader{ + buf: make([]byte, defaultBufSize), + off: -1, + } + nc.bw = &natsWriter{ + limit: defaultBufSize, + plimit: nc.Opts.ReconnectBufSize, + } +} + +func (nc *Conn) bindToNewConn() { + bw := nc.bw + bw.w, bw.bufs = nc.newWriter(), nil + br := nc.br + br.r, br.n, br.off = nc.conn, 0, -1 +} + +func (nc *Conn) newWriter() io.Writer { var w io.Writer = nc.conn if nc.Opts.FlusherTimeout > 0 { w = &timeoutWriter{conn: nc.conn, timeout: nc.Opts.FlusherTimeout} } - return bufio.NewWriterSize(w, defaultBufSize) + return w +} + +func (w *natsWriter) appendString(str string) error { + return w.appendBufs([]byte(str)) +} + +func (w *natsWriter) appendBufs(bufs ...[]byte) error { + for _, buf := range bufs { + if len(buf) == 0 { + continue + } + if w.pending != nil { + w.pending.Write(buf) + } else { + w.bufs = append(w.bufs, buf...) + } + } + if w.pending == nil && len(w.bufs) >= w.limit { + return w.flush() + } + return nil +} + +func (w *natsWriter) writeDirect(strs ...string) error { + for _, str := range strs { + if _, err := w.w.Write([]byte(str)); err != nil { + return err + } + } + return nil +} + +func (w *natsWriter) flush() error { + // If a pending buffer is set, we don't flush. Code that needs to + // write directly to the socket, by-passing buffers during (re)connect, + // will use the writeDirect() API. + if w.pending != nil { + return nil + } + // Do not skip calling w.w.Write() here if len(w.bufs) is 0 because + // the actual writer (if websocket for instance) may have things + // to do such as sending control frames, etc.. + _, err := w.w.Write(w.bufs) + w.bufs = w.bufs[:0] + return err +} + +func (w *natsWriter) buffered() int { + if w.pending != nil { + return w.pending.Len() + } + return len(w.bufs) +} + +func (w *natsWriter) switchToPending() { + w.pending = new(bytes.Buffer) +} + +func (w *natsWriter) flushPendingBuffer() error { + if w.pending == nil || w.pending.Len() == 0 { + return nil + } + _, err := w.w.Write(w.pending.Bytes()) + // Reset the pending buffer at this point because we don't want + // to take the risk of sending duplicates or partials. + w.pending.Reset() + return err +} + +func (w *natsWriter) atLimitIfUsingPending() bool { + if w.pending == nil { + return false + } + return w.pending.Len() >= w.plimit +} + +func (w *natsWriter) doneWithPending() { + w.pending = nil +} + +func (r *natsReader) Read() ([]byte, error) { + if r.off >= 0 { + off := r.off + r.off = -1 + return r.buf[off:r.n], nil + } + var err error + r.n, err = r.r.Read(r.buf) + return r.buf[:r.n], err +} + +func (r *natsReader) ReadString(delim byte) (string, error) { + var s string +build_string: + // First look if we have something in the buffer + if r.off >= 0 { + i := bytes.IndexByte(r.buf[r.off:r.n], delim) + if i >= 0 { + end := r.off + i + 1 + s += string(r.buf[r.off:end]) + r.off = end + if r.off >= r.n { + r.off = -1 + } + return s, nil + } + // We did not find the delim, so will have to read more. + s += string(r.buf[r.off:r.n]) + r.off = -1 + } + if _, err := r.Read(); err != nil { + return s, err + } + r.off = 0 + goto build_string } // createConn will connect to the server and wrap the appropriate @@ -1488,11 +1673,13 @@ func (nc *Conn) createConn() (err error) { return err } - if nc.pending != nil && nc.bw != nil { - // Move to pending buffer. - nc.bw.Flush() + // If scheme starts with "ws" then branch out to websocket code. + if isWebsocketScheme(u) { + return nc.wsInitHandshake(u) } - nc.bw = nc.newBuffer() + + // Reset reader/writer to this new TCP connection + nc.bindToNewConn() return nil } @@ -1519,7 +1706,7 @@ func (nc *Conn) makeTLSConn() error { if err := conn.Handshake(); err != nil { return err } - nc.bw = nc.newBuffer() + nc.bindToNewConn() return nil } @@ -1669,7 +1856,7 @@ func (nc *Conn) processConnectInit() error { // Main connect function. Will connect to the nats-server func (nc *Conn) connect() error { - var returnedErr error + var err error // Create actual socket connection // For first connect we walk all servers in the pool and try @@ -1681,7 +1868,7 @@ func (nc *Conn) connect() error { for i := 0; i < len(nc.srvPool); i++ { nc.current = nc.srvPool[i] - if err := nc.createConn(); err == nil { + if err = nc.createConn(); err == nil { // This was moved out of processConnectInit() because // that function is now invoked from doReconnect() too. nc.setup() @@ -1692,10 +1879,8 @@ func (nc *Conn) connect() error { nc.current.didConnect = true nc.current.reconnects = 0 nc.current.lastErr = nil - returnedErr = nil break } else { - returnedErr = err nc.mu.Unlock() nc.close(DISCONNECTED, false, err) nc.mu.Lock() @@ -1707,32 +1892,28 @@ func (nc *Conn) connect() error { // Cancel out default connection refused, will trigger the // No servers error conditional if strings.Contains(err.Error(), "connection refused") { - returnedErr = nil + err = nil } } } - if returnedErr == nil && nc.status != CONNECTED { - returnedErr = ErrNoServers + if err == nil && nc.status != CONNECTED { + err = ErrNoServers } - if returnedErr == nil { + if err == nil { nc.initc = false } else if nc.Opts.RetryOnFailedConnect { nc.setup() nc.status = RECONNECTING - nc.pending = new(bytes.Buffer) - if nc.bw == nil { - nc.bw = nc.newBuffer() - } - nc.bw.Reset(nc.pending) + nc.bw.switchToPending() go nc.doReconnect(ErrNoServers) - returnedErr = nil + err = nil } else { nc.current = nil } - return returnedErr + return err } // This will check to see if the connection should be @@ -1785,6 +1966,12 @@ func (nc *Conn) processExpectedInfo() error { return ErrNkeysNotSupported } + // For websocket connections, we already switched to TLS if need be, + // so we are done here. + if nc.ws { + return nil + } + return nc.checkForSecure() } @@ -1792,7 +1979,7 @@ func (nc *Conn) processExpectedInfo() error { // and kicking the flush Go routine. These writes are protected. func (nc *Conn) sendProto(proto string) { nc.mu.Lock() - nc.bw.WriteString(proto) + nc.bw.appendString(proto) nc.kickFlusher() nc.mu.Unlock() } @@ -1887,21 +2074,8 @@ func (nc *Conn) sendConnect() error { return err } - // Write the protocol into the buffer - _, err = nc.bw.WriteString(cProto) - if err != nil { - return err - } - - // Add to the buffer the PING protocol - _, err = nc.bw.WriteString(pingProto) - if err != nil { - return err - } - - // Flush the buffer - err = nc.bw.Flush() - if err != nil { + // Write the protocol and PING directly to the underlying writer. + if err := nc.bw.writeDirect(cProto, pingProto); err != nil { return err } @@ -1958,27 +2132,9 @@ func (nc *Conn) sendConnect() error { return nil } -// reads a protocol one byte at a time. +// reads a protocol line. func (nc *Conn) readProto() (string, error) { - var ( - _buf = [10]byte{} - buf = _buf[:0] - b = [1]byte{} - protoEnd = byte('\n') - ) - for { - if _, err := nc.conn.Read(b[:1]); err != nil { - // Do not report EOF error - if err == io.EOF { - return string(buf), nil - } - return "", err - } - buf = append(buf, b[0]) - if b[0] == protoEnd { - return string(buf), nil - } - } + return nc.br.ReadString('\n') } // A control protocol line. @@ -1988,8 +2144,7 @@ type control struct { // Read a control line and process the intended op. func (nc *Conn) readOp(c *control) error { - br := bufio.NewReaderSize(nc.conn, defaultBufSize) - line, err := br.ReadString('\n') + line, err := nc.readProto() if err != nil { return err } @@ -2012,13 +2167,8 @@ func parseControl(line string, c *control) { // flushReconnectPending will push the pending items that were // gathered while we were in a RECONNECTING state to the socket. -func (nc *Conn) flushReconnectPendingItems() { - if nc.pending == nil { - return - } - if nc.pending.Len() > 0 { - nc.bw.Write(nc.pending.Bytes()) - } +func (nc *Conn) flushReconnectPendingItems() error { + return nc.bw.flushPendingBuffer() } // Stops the ping timer if set. @@ -2044,9 +2194,6 @@ func (nc *Conn) doReconnect(err error) { // can't do defer here. nc.mu.Lock() - // Clear any queued pongs, e.g. pending flush calls. - nc.clearPendingFlushCalls() - // Clear any errors. nc.err = nil // Perform appropriate callback if needed for a disconnect. @@ -2156,9 +2303,6 @@ func (nc *Conn) doReconnect(err error) { break } nc.status = RECONNECTING - // Reset the buffered writer to the pending buffer - // (was set to a buffered writer on nc.conn in createConn) - nc.bw.Reset(nc.pending) continue } @@ -2174,14 +2318,9 @@ func (nc *Conn) doReconnect(err error) { nc.resendSubscriptions() // Now send off and clear pending buffer - nc.flushReconnectPendingItems() - - // Flush the buffer - nc.err = nc.bw.Flush() + nc.err = nc.flushReconnectPendingItems() if nc.err != nil { nc.status = RECONNECTING - // Reset the buffered writer to the pending buffer (bytes.Buffer). - nc.bw.Reset(nc.pending) // Stop the ping timer (if set) nc.stopPingTimer() // Since processConnectInit() returned without error, the @@ -2192,7 +2331,7 @@ func (nc *Conn) doReconnect(err error) { } // Done with the pending buffer - nc.pending = nil + nc.bw.doneWithPending() // This is where we are truly connected. nc.status = CONNECTED @@ -2238,14 +2377,15 @@ func (nc *Conn) processOpErr(err error) { // Stop ping timer if set nc.stopPingTimer() if nc.conn != nil { - nc.bw.Flush() nc.conn.Close() nc.conn = nil } // Create pending buffer before reconnecting. - nc.pending = new(bytes.Buffer) - nc.bw.Reset(nc.pending) + nc.bw.switchToPending() + + // Clear any queued pongs, e.g. pending flush calls. + nc.clearPendingFlushCalls() go nc.doReconnect(err) nc.mu.Unlock() @@ -2332,20 +2472,19 @@ func (nc *Conn) readLoop() { nc.ps = &parseState{} } conn := nc.conn + br := nc.br nc.mu.Unlock() if conn == nil { return } - // Stack based buffer. - b := make([]byte, defaultBufSize) - for { - if n, err := conn.Read(b); err != nil { - nc.processOpErr(err) - break - } else if err = nc.parse(b[:n]); err != nil { + buf, err := br.Read() + if err == nil { + err = nc.parse(buf) + } + if err != nil { nc.processOpErr(err) break } @@ -2400,6 +2539,9 @@ func (nc *Conn) waitForMsgs(s *Subscription) { if !s.closed { s.delivered++ delivered = s.delivered + if s.jsi != nil && s.jsi.fc && len(s.jsi.fcs) > 0 { + s.checkForFlowControlResponse(delivered) + } } s.mu.Unlock() @@ -2466,10 +2608,11 @@ func (nc *Conn) processMsg(data []byte) { copy(msgPayload, data) // Check if we have headers encoded here. - var h http.Header + var h Header var err error - var ctrl bool + var ctrlMsg bool var hasFC bool + var hasHBs bool if nc.ps.ma.hdr > 0 { hbuf := msgPayload[:nc.ps.ma.hdr] @@ -2491,40 +2634,40 @@ func (nc *Conn) processMsg(data []byte) { sub.mu.Lock() - // Skip flow control messages in case of using a JetStream context. - jsi := sub.jsi - if jsi != nil { - ctrl = isControlMessage(m) - hasFC = jsi.fc - } - // Check if closed. if sub.closed { sub.mu.Unlock() return } - // Subscription internal stats (applicable only for non ChanSubscription's) - if sub.typ != ChanSubscription { - sub.pMsgs++ - if sub.pMsgs > sub.pMsgsMax { - sub.pMsgsMax = sub.pMsgs - } - sub.pBytes += len(m.Data) - if sub.pBytes > sub.pBytesMax { - sub.pBytesMax = sub.pBytes - } - - // Check for a Slow Consumer - if (sub.pMsgsLimit > 0 && sub.pMsgs > sub.pMsgsLimit) || - (sub.pBytesLimit > 0 && sub.pBytes > sub.pBytesLimit) { - goto slowConsumer - } + // Skip flow control messages in case of using a JetStream context. + jsi := sub.jsi + if jsi != nil { + ctrlMsg, hasHBs, hasFC = isControlMessage(m), jsi.hbs, jsi.fc } - // We have two modes of delivery. One is the channel, used by channel - // subscribers and syncSubscribers, the other is a linked list for async. - if !ctrl { + // Skip processing if this is a control message. + if !ctrlMsg { + // Subscription internal stats (applicable only for non ChanSubscription's) + if sub.typ != ChanSubscription { + sub.pMsgs++ + if sub.pMsgs > sub.pMsgsMax { + sub.pMsgsMax = sub.pMsgs + } + sub.pBytes += len(m.Data) + if sub.pBytes > sub.pBytesMax { + sub.pBytesMax = sub.pBytes + } + + // Check for a Slow Consumer + if (sub.pMsgsLimit > 0 && sub.pMsgs > sub.pMsgsLimit) || + (sub.pBytesLimit > 0 && sub.pBytes > sub.pBytesLimit) { + goto slowConsumer + } + } + + // We have two modes of delivery. One is the channel, used by channel + // subscribers and syncSubscribers, the other is a linked list for async. if sub.mch != nil { select { case sub.mch <- m: @@ -2544,8 +2687,19 @@ func (nc *Conn) processMsg(data []byte) { sub.pTail = m } } - if hasFC { - jsi.trackSequences(m) + if jsi != nil && hasHBs { + // Store the ACK metadata from the message to + // compare later on with the received heartbeat. + jsi.trackSequences(m.Reply) + } + } else if hasFC && m.Reply != _EMPTY_ { + // This is a flow control message. + // If we have no pending, go ahead and send in place. + if sub.pMsgs == 0 { + nc.Publish(m.Reply, nil) + } else { + // Schedule a reply after the previous message is delivered. + jsi.scheduleFlowControlResponse(sub.delivered+uint64(sub.pMsgs), m.Reply) } } @@ -2554,10 +2708,9 @@ func (nc *Conn) processMsg(data []byte) { sub.mu.Unlock() - // Handle flow control and heartbeat messages automatically - // for JetStream Push consumers. - if ctrl { - nc.processControlFlow(m, sub, jsi) + // Handle control heartbeat messages. + if ctrlMsg && hasHBs && m.Reply == _EMPTY_ { + nc.processSequenceMismatch(m, sub, jsi) } return @@ -2642,12 +2795,12 @@ func (nc *Conn) flusher() { nc.mu.Lock() // Check to see if we should bail out. - if !nc.isConnected() || nc.isConnecting() || bw != nc.bw || conn != nc.conn { + if !nc.isConnected() || nc.isConnecting() || conn != nc.conn { nc.mu.Unlock() return } - if bw.Buffered() > 0 { - if err := bw.Flush(); err != nil { + if bw.buffered() > 0 { + if err := bw.flush(); err != nil { if nc.err == nil { nc.err = err } @@ -2671,7 +2824,7 @@ func (nc *Conn) processPong() { nc.mu.Lock() if len(nc.pongs) > 0 { ch = nc.pongs[0] - nc.pongs = nc.pongs[1:] + nc.pongs = append(nc.pongs[:0], nc.pongs[1:]...) } nc.pout = 0 nc.mu.Unlock() @@ -2862,11 +3015,52 @@ func (nc *Conn) Publish(subj string, data []byte) error { return nc.publish(subj, _EMPTY_, nil, data) } +// Header represents the optional Header for a NATS message, +// based on the implementation of http.Header. +type Header map[string][]string + +// Add adds the key, value pair to the header. It is case-sensitive +// and appends to any existing values associated with key. +func (h Header) Add(key, value string) { + h[key] = append(h[key], value) +} + +// Set sets the header entries associated with key to the single +// element value. It is case-sensitive and replaces any existing +// values associated with key. +func (h Header) Set(key, value string) { + h[key] = []string{value} +} + +// Get gets the first value associated with the given key. +// It is case-sensitive. +func (h Header) Get(key string) string { + if h == nil { + return _EMPTY_ + } + if v := h[key]; v != nil { + return v[0] + } + return _EMPTY_ +} + +// Values returns all values associated with the given key. +// It is case-sensitive. +func (h Header) Values(key string) []string { + return h[key] +} + +// Del deletes the values associated with a key. +// It is case-sensitive. +func (h Header) Del(key string) { + delete(h, key) +} + // NewMsg creates a message for publishing that will use headers. func NewMsg(subject string) *Msg { return &Msg{ Subject: subject, - Header: make(http.Header), + Header: make(Header), } } @@ -2885,7 +3079,7 @@ const ( ) // decodeHeadersMsg will decode and headers. -func decodeHeadersMsg(data []byte) (http.Header, error) { +func decodeHeadersMsg(data []byte) (Header, error) { tp := textproto.NewReader(bufio.NewReader(bytes.NewReader(data))) l, err := tp.ReadLine() if err != nil || len(l) < hdrPreEnd || l[:hdrPreEnd] != hdrLine[:hdrPreEnd] { @@ -2910,7 +3104,7 @@ func decodeHeadersMsg(data []byte) (http.Header, error) { mh.Add(descrHdr, description) } } - return http.Header(mh), nil + return Header(mh), nil } // readMIMEHeader returns a MIMEHeader that preserves the @@ -3026,14 +3220,9 @@ func (nc *Conn) publish(subj, reply string, hdr, data []byte) error { // Check if we are reconnecting, and if so check if // we have exceeded our reconnect outbound buffer limits. - if nc.isReconnecting() { - // Flush to underlying buffer. - nc.bw.Flush() - // Check if we are over - if nc.pending.Len() >= nc.Opts.ReconnectBufSize { - nc.mu.Unlock() - return ErrReconnectBufExceeded - } + if nc.bw.atLimitIfUsingPending() { + nc.mu.Unlock() + return ErrReconnectBufExceeded } var mh []byte @@ -3087,19 +3276,7 @@ func (nc *Conn) publish(subj, reply string, hdr, data []byte) error { mh = append(mh, b[i:]...) mh = append(mh, _CRLF_...) - _, err := nc.bw.Write(mh) - if err == nil { - if hdr != nil { - _, err = nc.bw.Write(hdr) - } - if err == nil { - _, err = nc.bw.Write(data) - } - } - if err == nil { - _, err = nc.bw.WriteString(_CRLF_) - } - if err != nil { + if err := nc.bw.appendBufs(mh, hdr, data, _CRLF_BYTES_); err != nil { nc.mu.Unlock() return err } @@ -3485,10 +3662,11 @@ func (nc *Conn) subscribeLocked(subj, queue string, cb MsgHandler, ch chan *Msg, // If we have an async callback, start up a sub specific // Go routine to deliver the messages. + var sr bool if cb != nil { sub.typ = AsyncSubscription sub.pCond = sync.NewCond(&sub.mu) - go nc.waitForMsgs(sub) + sr = true } else if !isSync { sub.typ = ChanSubscription sub.mch = ch @@ -3503,14 +3681,16 @@ func (nc *Conn) subscribeLocked(subj, queue string, cb MsgHandler, ch chan *Msg, nc.subs[sub.sid] = sub nc.subsMu.Unlock() + // Let's start the go routine now that it is fully setup and registered. + if sr { + go nc.waitForMsgs(sub) + } + // We will send these for all subs when we reconnect // so that we can suppress here if reconnecting. if !nc.isReconnecting() { - fmt.Fprintf(nc.bw, subProto, subj, queue, sub.sid) - // Kick flusher if needed. - if len(nc.fch) == 0 { - nc.kickFlusher() - } + nc.bw.appendString(fmt.Sprintf(subProto, subj, queue, sub.sid)) + nc.kickFlusher() } return sub, nil @@ -3702,7 +3882,9 @@ func (nc *Conn) unsubscribe(sub *Subscription, max int, drainMode bool) error { maxStr := _EMPTY_ if max > 0 { + s.mu.Lock() s.max = uint64(max) + s.mu.Unlock() maxStr = strconv.Itoa(max) } else if !drainMode { nc.removeSub(s) @@ -3715,63 +3897,16 @@ func (nc *Conn) unsubscribe(sub *Subscription, max int, drainMode bool) error { // We will send these for all subs when we reconnect // so that we can suppress here. if !nc.isReconnecting() { - fmt.Fprintf(nc.bw, unsubProto, s.sid, maxStr) + nc.bw.appendString(fmt.Sprintf(unsubProto, s.sid, maxStr)) + nc.kickFlusher() } return nil } -// ErrConsumerSequenceMismatch represents an error from a consumer -// that received a Heartbeat including sequence different to the -// one expected from the view of the client. -type ErrConsumerSequenceMismatch struct { - // StreamResumeSequence is the stream sequence from where the consumer - // should resume consuming from the stream. - StreamResumeSequence uint64 - - // ConsumerSequence is the sequence of the consumer that is behind. - ConsumerSequence int64 - - // LastConsumerSequence is the sequence of the consumer when the heartbeat - // was received. - LastConsumerSequence int64 -} - -func (ecs *ErrConsumerSequenceMismatch) Error() string { - return fmt.Sprintf("nats: sequence mismatch for consumer at sequence %d (%d sequences behind), should restart consumer from stream sequence %d", - ecs.ConsumerSequence, - ecs.LastConsumerSequence-ecs.ConsumerSequence, - ecs.StreamResumeSequence, - ) -} - -// handleConsumerSequenceMismatch will send an async error that can be used to restart a push based consumer. -func (nc *Conn) handleConsumerSequenceMismatch(sub *Subscription, err error) { - nc.mu.Lock() - errCB := nc.Opts.AsyncErrorCB - if errCB != nil { - nc.ach.push(func() { errCB(nc, sub, err) }) - } - nc.mu.Unlock() -} - -func isControlMessage(msg *Msg) bool { - return len(msg.Data) == 0 && msg.Header.Get(statusHdr) == controlMsg -} - -func (jsi *jsSub) trackSequences(msg *Msg) { - var ctrl *controlMetadata - if cmeta := jsi.cmeta.Load(); cmeta == nil { - ctrl = &controlMetadata{} - } else { - ctrl = cmeta.(*controlMetadata) - } - ctrl.meta = msg.Reply - jsi.cmeta.Store(ctrl) -} - // NextMsg will return the next message available to a synchronous subscriber // or block until one is available. An error is returned if the subscription is invalid (ErrBadSubscription), -// the connection is closed (ErrConnectionClosed), or the timeout is reached (ErrTimeout). +// the connection is closed (ErrConnectionClosed), the timeout is reached (ErrTimeout), +// or if there were no responders (ErrNoResponders) when used in the context of a request/reply. func (s *Subscription) NextMsg(timeout time.Duration) (*Msg, error) { if s == nil { return nil, ErrBadSubscription @@ -3875,6 +4010,10 @@ func (s *Subscription) processNextMsgDelivered(msg *Msg) error { // Update some stats. s.delivered++ delivered := s.delivered + if s.jsi != nil && s.jsi.fc && len(s.jsi.fcs) > 0 { + s.checkForFlowControlResponse(delivered) + } + if s.typ == SyncSubscription { s.pMsgs-- s.pBytes -= len(msg.Data) @@ -3892,6 +4031,9 @@ func (s *Subscription) processNextMsgDelivered(msg *Msg) error { nc.mu.Unlock() } } + if len(msg.Data) == 0 && msg.Header.Get(statusHdr) == noResponders { + return ErrNoResponders + } return nil } @@ -4081,9 +4223,9 @@ func (nc *Conn) removeFlushEntry(ch chan struct{}) bool { // The lock must be held entering this function. func (nc *Conn) sendPing(ch chan struct{}) { nc.pongs = append(nc.pongs, ch) - nc.bw.WriteString(pingProto) + nc.bw.appendString(pingProto) // Flush in place. - nc.bw.Flush() + nc.bw.flush() } // This will fire periodically and send a client origin @@ -4180,7 +4322,7 @@ func (nc *Conn) Buffered() (int, error) { if nc.isClosed() || nc.bw == nil { return -1, ErrConnectionClosed } - return nc.bw.Buffered(), nil + return nc.bw.buffered(), nil } // resendSubscriptions will send our subscription state back to the @@ -4206,16 +4348,16 @@ func (nc *Conn) resendSubscriptions() { // reached the max, if so unsubscribe. if adjustedMax == 0 { s.mu.Unlock() - fmt.Fprintf(nc.bw, unsubProto, s.sid, _EMPTY_) + nc.bw.writeDirect(fmt.Sprintf(unsubProto, s.sid, _EMPTY_)) continue } } s.mu.Unlock() - fmt.Fprintf(nc.bw, subProto, s.Subject, s.Queue, s.sid) + nc.bw.writeDirect(fmt.Sprintf(subProto, s.Subject, s.Queue, s.sid)) if adjustedMax > 0 { maxStr := strconv.Itoa(int(adjustedMax)) - fmt.Fprintf(nc.bw, unsubProto, s.sid, maxStr) + nc.bw.writeDirect(fmt.Sprintf(unsubProto, s.sid, maxStr)) } } } @@ -4287,7 +4429,7 @@ func (nc *Conn) close(status Status, doCBs bool, err error) { nc.conn = nil } else if nc.conn != nil { // Go ahead and make sure we have flushed the outbound - nc.bw.Flush() + nc.bw.flush() defer nc.conn.Close() } @@ -4343,6 +4485,12 @@ func (nc *Conn) close(status Status, doCBs bool, err error) { // all blocking calls, such as Flush() and NextMsg() func (nc *Conn) Close() { if nc != nil { + // This will be a no-op if the connection was not websocket. + // We do this here as opposed to inside close() because we want + // to do this only for the final user-driven close of the client. + // Otherwise, we would need to change close() to pass a boolean + // indicating that this is the case. + nc.wsClose() nc.close(CLOSED, !nc.Opts.NoCallbacksAfterClientClose, nil) } } @@ -4382,7 +4530,7 @@ func (nc *Conn) drainConnection() { if nc.isConnecting() || nc.isReconnecting() { nc.mu.Unlock() // Move to closed state. - nc.close(CLOSED, true, nil) + nc.Close() return } @@ -4462,12 +4610,10 @@ func (nc *Conn) drainConnection() { err := nc.FlushTimeout(5 * time.Second) if err != nil { pushErr(err) - nc.close(CLOSED, true, nil) - return } // Move to closed state. - nc.close(CLOSED, true, nil) + nc.Close() } // Drain will put a connection into a drain state. All subscriptions will @@ -4483,7 +4629,7 @@ func (nc *Conn) Drain() error { } if nc.isConnecting() || nc.isReconnecting() { nc.mu.Unlock() - nc.close(CLOSED, true, nil) + nc.Close() return ErrConnectionReconnecting } if nc.isDraining() { diff --git a/vendor/github.com/nats-io/nats.go/ws.go b/vendor/github.com/nats-io/nats.go/ws.go new file mode 100644 index 00000000..eb0c7d88 --- /dev/null +++ b/vendor/github.com/nats-io/nats.go/ws.go @@ -0,0 +1,700 @@ +// Copyright 2021 The NATS Authors +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package nats + +import ( + "bufio" + "bytes" + "compress/flate" + "crypto/rand" + "crypto/sha1" + "encoding/base64" + "encoding/binary" + "errors" + "fmt" + "io" + "io/ioutil" + mrand "math/rand" + "net/http" + "net/url" + "strings" + "time" + "unicode/utf8" +) + +type wsOpCode int + +const ( + // From https://tools.ietf.org/html/rfc6455#section-5.2 + wsTextMessage = wsOpCode(1) + wsBinaryMessage = wsOpCode(2) + wsCloseMessage = wsOpCode(8) + wsPingMessage = wsOpCode(9) + wsPongMessage = wsOpCode(10) + + wsFinalBit = 1 << 7 + wsRsv1Bit = 1 << 6 // Used for compression, from https://tools.ietf.org/html/rfc7692#section-6 + wsRsv2Bit = 1 << 5 + wsRsv3Bit = 1 << 4 + + wsMaskBit = 1 << 7 + + wsContinuationFrame = 0 + wsMaxFrameHeaderSize = 14 + wsMaxControlPayloadSize = 125 + + // From https://tools.ietf.org/html/rfc6455#section-11.7 + wsCloseStatusNormalClosure = 1000 + wsCloseStatusNoStatusReceived = 1005 + wsCloseStatusAbnormalClosure = 1006 + wsCloseStatusInvalidPayloadData = 1007 + + wsScheme = "ws" + wsSchemeTLS = "wss" + + wsPMCExtension = "permessage-deflate" // per-message compression + wsPMCSrvNoCtx = "server_no_context_takeover" + wsPMCCliNoCtx = "client_no_context_takeover" + wsPMCReqHeaderValue = wsPMCExtension + "; " + wsPMCSrvNoCtx + "; " + wsPMCCliNoCtx +) + +// From https://tools.ietf.org/html/rfc6455#section-1.3 +var wsGUID = []byte("258EAFA5-E914-47DA-95CA-C5AB0DC85B11") + +// As per https://tools.ietf.org/html/rfc7692#section-7.2.2 +// add 0x00, 0x00, 0xff, 0xff and then a final block so that flate reader +// does not report unexpected EOF. +var compressFinalBlock = []byte{0x00, 0x00, 0xff, 0xff, 0x01, 0x00, 0x00, 0xff, 0xff} + +type websocketReader struct { + r io.Reader + pending [][]byte + ib []byte + ff bool + fc bool + dc io.ReadCloser + nc *Conn +} + +type websocketWriter struct { + w io.Writer + compress bool + compressor *flate.Writer + ctrlFrames [][]byte // pending frames that should be sent at the next Write() + cm []byte // close message that needs to be sent when everything else has been sent + cmDone bool // a close message has been added or sent (never going back to false) + noMoreSend bool // if true, even if there is a Write() call, we should not send anything +} + +type decompressorBuffer struct { + buf []byte + rem int + off int + final bool +} + +func newDecompressorBuffer(buf []byte) *decompressorBuffer { + return &decompressorBuffer{buf: buf, rem: len(buf)} +} + +func (d *decompressorBuffer) Read(p []byte) (int, error) { + if d.buf == nil { + return 0, io.EOF + } + lim := d.rem + if len(p) < lim { + lim = len(p) + } + n := copy(p, d.buf[d.off:d.off+lim]) + d.off += n + d.rem -= n + d.checkRem() + return n, nil +} + +func (d *decompressorBuffer) checkRem() { + if d.rem != 0 { + return + } + if !d.final { + d.buf = compressFinalBlock + d.off = 0 + d.rem = len(d.buf) + d.final = true + } else { + d.buf = nil + } +} + +func (d *decompressorBuffer) ReadByte() (byte, error) { + if d.buf == nil { + return 0, io.EOF + } + b := d.buf[d.off] + d.off++ + d.rem-- + d.checkRem() + return b, nil +} + +func wsNewReader(r io.Reader) *websocketReader { + return &websocketReader{r: r, ff: true} +} + +func (r *websocketReader) Read(p []byte) (int, error) { + var err error + var buf []byte + + if l := len(r.ib); l > 0 { + buf = r.ib + r.ib = nil + } else { + if len(r.pending) > 0 { + return r.drainPending(p), nil + } + + // Get some data from the underlying reader. + n, err := r.r.Read(p) + if err != nil { + return 0, err + } + buf = p[:n] + } + + // Now parse this and decode frames. We will possibly read more to + // ensure that we get a full frame. + var ( + tmpBuf []byte + pos int + max = len(buf) + rem = 0 + ) + for pos < max { + b0 := buf[pos] + frameType := wsOpCode(b0 & 0xF) + final := b0&wsFinalBit != 0 + compressed := b0&wsRsv1Bit != 0 + pos++ + + tmpBuf, pos, err = wsGet(r.r, buf, pos, 1) + if err != nil { + return 0, err + } + b1 := tmpBuf[0] + + // Store size in case it is < 125 + rem = int(b1 & 0x7F) + + switch frameType { + case wsPingMessage, wsPongMessage, wsCloseMessage: + if rem > wsMaxControlPayloadSize { + return 0, fmt.Errorf( + fmt.Sprintf("control frame length bigger than maximum allowed of %v bytes", + wsMaxControlPayloadSize)) + } + if compressed { + return 0, errors.New("control frame should not be compressed") + } + if !final { + return 0, errors.New("control frame does not have final bit set") + } + case wsTextMessage, wsBinaryMessage: + if !r.ff { + return 0, errors.New("new message started before final frame for previous message was received") + } + r.ff = final + r.fc = compressed + case wsContinuationFrame: + // Compressed bit must be only set in the first frame + if r.ff || compressed { + return 0, errors.New("invalid continuation frame") + } + r.ff = final + default: + return 0, fmt.Errorf("unknown opcode %v", frameType) + } + + // If the encoded size is <= 125, then `rem` is simply the remainder size of the + // frame. If it is 126, then the actual size is encoded as a uint16. For larger + // frames, `rem` will initially be 127 and the actual size is encoded as a uint64. + switch rem { + case 126: + tmpBuf, pos, err = wsGet(r.r, buf, pos, 2) + if err != nil { + return 0, err + } + rem = int(binary.BigEndian.Uint16(tmpBuf)) + case 127: + tmpBuf, pos, err = wsGet(r.r, buf, pos, 8) + if err != nil { + return 0, err + } + rem = int(binary.BigEndian.Uint64(tmpBuf)) + } + + // Handle control messages in place... + if wsIsControlFrame(frameType) { + pos, err = r.handleControlFrame(frameType, buf, pos, rem) + if err != nil { + return 0, err + } + rem = 0 + continue + } + + var b []byte + b, pos, err = wsGet(r.r, buf, pos, rem) + if err != nil { + return 0, err + } + rem = 0 + if r.fc { + br := newDecompressorBuffer(b) + if r.dc == nil { + r.dc = flate.NewReader(br) + } else { + r.dc.(flate.Resetter).Reset(br, nil) + } + // TODO: When Go 1.15 support is dropped, replace with io.ReadAll() + b, err = ioutil.ReadAll(r.dc) + if err != nil { + return 0, err + } + r.fc = false + } + r.pending = append(r.pending, b) + } + // At this point we should have pending slices. + return r.drainPending(p), nil +} + +func (r *websocketReader) drainPending(p []byte) int { + var n int + var max = len(p) + + for i, buf := range r.pending { + if n+len(buf) <= max { + copy(p[n:], buf) + n += len(buf) + } else { + // Is there room left? + if n < max { + // Write the partial and update this slice. + rem := max - n + copy(p[n:], buf[:rem]) + n += rem + r.pending[i] = buf[rem:] + } + // These are the remaining slices that will need to be used at + // the next Read() call. + r.pending = r.pending[i:] + return n + } + } + r.pending = r.pending[:0] + return n +} + +func wsGet(r io.Reader, buf []byte, pos, needed int) ([]byte, int, error) { + avail := len(buf) - pos + if avail >= needed { + return buf[pos : pos+needed], pos + needed, nil + } + b := make([]byte, needed) + start := copy(b, buf[pos:]) + for start != needed { + n, err := r.Read(b[start:cap(b)]) + start += n + if err != nil { + return b, start, err + } + } + return b, pos + avail, nil +} + +func (r *websocketReader) handleControlFrame(frameType wsOpCode, buf []byte, pos, rem int) (int, error) { + var payload []byte + var err error + + statusPos := pos + if rem > 0 { + payload, pos, err = wsGet(r.r, buf, pos, rem) + if err != nil { + return pos, err + } + } + switch frameType { + case wsCloseMessage: + status := wsCloseStatusNoStatusReceived + body := "" + // If there is a payload, it should contain 2 unsigned bytes + // that represent the status code and then optional payload. + if len(payload) >= 2 { + status = int(binary.BigEndian.Uint16(buf[statusPos : statusPos+2])) + body = string(buf[statusPos+2 : statusPos+len(payload)]) + if body != "" && !utf8.ValidString(body) { + // https://tools.ietf.org/html/rfc6455#section-5.5.1 + // If body is present, it must be a valid utf8 + status = wsCloseStatusInvalidPayloadData + body = "invalid utf8 body in close frame" + } + } + r.nc.wsEnqueueCloseMsg(status, body) + // Return io.EOF so that readLoop will close the connection as ClientClosed + // after processing pending buffers. + return pos, io.EOF + case wsPingMessage: + r.nc.wsEnqueueControlMsg(wsPongMessage, payload) + case wsPongMessage: + // Nothing to do.. + } + return pos, nil +} + +func (w *websocketWriter) Write(p []byte) (int, error) { + if w.noMoreSend { + return 0, nil + } + var total int + var n int + var err error + // If there are control frames, they can be sent now. Actually spec says + // that they should be sent ASAP, so we will send before any application data. + if len(w.ctrlFrames) > 0 { + n, err = w.writeCtrlFrames() + if err != nil { + return n, err + } + total += n + } + // Do the following only if there is something to send. + // We will end with checking for need to send close message. + if len(p) > 0 { + if w.compress { + buf := &bytes.Buffer{} + if w.compressor == nil { + w.compressor, _ = flate.NewWriter(buf, flate.BestSpeed) + } else { + w.compressor.Reset(buf) + } + w.compressor.Write(p) + w.compressor.Close() + b := buf.Bytes() + p = b[:len(b)-4] + } + fh, key := wsCreateFrameHeader(w.compress, wsBinaryMessage, len(p)) + wsMaskBuf(key, p) + n, err = w.w.Write(fh) + total += n + if err == nil { + n, err = w.w.Write(p) + total += n + } + } + if err == nil && w.cm != nil { + n, err = w.writeCloseMsg() + total += n + } + return total, err +} + +func (w *websocketWriter) writeCtrlFrames() (int, error) { + var ( + n int + total int + i int + err error + ) + for ; i < len(w.ctrlFrames); i++ { + buf := w.ctrlFrames[i] + n, err = w.w.Write(buf) + total += n + if err != nil { + break + } + } + if i != len(w.ctrlFrames) { + w.ctrlFrames = w.ctrlFrames[i+1:] + } else { + w.ctrlFrames = w.ctrlFrames[:0] + } + return total, err +} + +func (w *websocketWriter) writeCloseMsg() (int, error) { + n, err := w.w.Write(w.cm) + w.cm, w.noMoreSend = nil, true + return n, err +} + +func wsMaskBuf(key, buf []byte) { + for i := 0; i < len(buf); i++ { + buf[i] ^= key[i&3] + } +} + +// Create the frame header. +// Encodes the frame type and optional compression flag, and the size of the payload. +func wsCreateFrameHeader(compressed bool, frameType wsOpCode, l int) ([]byte, []byte) { + fh := make([]byte, wsMaxFrameHeaderSize) + n, key := wsFillFrameHeader(fh, compressed, frameType, l) + return fh[:n], key +} + +func wsFillFrameHeader(fh []byte, compressed bool, frameType wsOpCode, l int) (int, []byte) { + var n int + b := byte(frameType) + b |= wsFinalBit + if compressed { + b |= wsRsv1Bit + } + b1 := byte(wsMaskBit) + switch { + case l <= 125: + n = 2 + fh[0] = b + fh[1] = b1 | byte(l) + case l < 65536: + n = 4 + fh[0] = b + fh[1] = b1 | 126 + binary.BigEndian.PutUint16(fh[2:], uint16(l)) + default: + n = 10 + fh[0] = b + fh[1] = b1 | 127 + binary.BigEndian.PutUint64(fh[2:], uint64(l)) + } + var key []byte + var keyBuf [4]byte + if _, err := io.ReadFull(rand.Reader, keyBuf[:4]); err != nil { + kv := mrand.Int31() + binary.LittleEndian.PutUint32(keyBuf[:4], uint32(kv)) + } + copy(fh[n:], keyBuf[:4]) + key = fh[n : n+4] + n += 4 + return n, key +} + +func (nc *Conn) wsInitHandshake(u *url.URL) error { + compress := nc.Opts.Compression + tlsRequired := u.Scheme == wsSchemeTLS || nc.Opts.Secure || nc.Opts.TLSConfig != nil + // Do TLS here as needed. + if tlsRequired { + if err := nc.makeTLSConn(); err != nil { + return err + } + } else { + nc.bindToNewConn() + } + + var err error + + // For http request, we need the passed URL to contain either http or https scheme. + scheme := "http" + if tlsRequired { + scheme = "https" + } + ustr := fmt.Sprintf("%s://%s", scheme, u.Host) + u, err = url.Parse(ustr) + if err != nil { + return err + } + req := &http.Request{ + Method: "GET", + URL: u, + Proto: "HTTP/1.1", + ProtoMajor: 1, + ProtoMinor: 1, + Header: make(http.Header), + Host: u.Host, + } + wsKey, err := wsMakeChallengeKey() + if err != nil { + return err + } + + req.Header["Upgrade"] = []string{"websocket"} + req.Header["Connection"] = []string{"Upgrade"} + req.Header["Sec-WebSocket-Key"] = []string{wsKey} + req.Header["Sec-WebSocket-Version"] = []string{"13"} + if compress { + req.Header.Add("Sec-WebSocket-Extensions", wsPMCReqHeaderValue) + } + if err := req.Write(nc.conn); err != nil { + return err + } + + var resp *http.Response + + br := bufio.NewReaderSize(nc.conn, 4096) + nc.conn.SetReadDeadline(time.Now().Add(nc.Opts.Timeout)) + resp, err = http.ReadResponse(br, req) + if err == nil && + (resp.StatusCode != 101 || + !strings.EqualFold(resp.Header.Get("Upgrade"), "websocket") || + !strings.EqualFold(resp.Header.Get("Connection"), "upgrade") || + resp.Header.Get("Sec-Websocket-Accept") != wsAcceptKey(wsKey)) { + + err = fmt.Errorf("invalid websocket connection") + } + // Check compression extension... + if err == nil && compress { + // Check that not only permessage-deflate extension is present, but that + // we also have server and client no context take over. + srvCompress, noCtxTakeover := wsPMCExtensionSupport(resp.Header) + + // If server does not support compression, then simply disable it in our side. + if !srvCompress { + compress = false + } else if !noCtxTakeover { + err = fmt.Errorf("compression negotiation error") + } + } + if resp != nil { + resp.Body.Close() + } + nc.conn.SetReadDeadline(time.Time{}) + if err != nil { + return err + } + + wsr := wsNewReader(nc.br.r) + wsr.nc = nc + // We have to slurp whatever is in the bufio reader and copy to br.r + if n := br.Buffered(); n != 0 { + wsr.ib, _ = br.Peek(n) + } + nc.br.r = wsr + nc.bw.w = &websocketWriter{w: nc.bw.w, compress: compress} + nc.ws = true + return nil +} + +func (nc *Conn) wsClose() { + nc.mu.Lock() + defer nc.mu.Unlock() + if !nc.ws { + return + } + nc.wsEnqueueCloseMsgLocked(wsCloseStatusNormalClosure, _EMPTY_) +} + +func (nc *Conn) wsEnqueueCloseMsg(status int, payload string) { + // In some low-level unit tests it will happen... + if nc == nil { + return + } + nc.mu.Lock() + nc.wsEnqueueCloseMsgLocked(status, payload) + nc.mu.Unlock() +} + +func (nc *Conn) wsEnqueueCloseMsgLocked(status int, payload string) { + wr, ok := nc.bw.w.(*websocketWriter) + if !ok || wr.cmDone { + return + } + statusAndPayloadLen := 2 + len(payload) + frame := make([]byte, 2+4+statusAndPayloadLen) + n, key := wsFillFrameHeader(frame, false, wsCloseMessage, statusAndPayloadLen) + // Set the status + binary.BigEndian.PutUint16(frame[n:], uint16(status)) + // If there is a payload, copy + if len(payload) > 0 { + copy(frame[n+2:], payload) + } + // Mask status + payload + wsMaskBuf(key, frame[n:n+statusAndPayloadLen]) + wr.cm = frame + wr.cmDone = true + nc.bw.flush() +} + +func (nc *Conn) wsEnqueueControlMsg(frameType wsOpCode, payload []byte) { + // In some low-level unit tests it will happen... + if nc == nil { + return + } + fh, key := wsCreateFrameHeader(false, frameType, len(payload)) + nc.mu.Lock() + wr, ok := nc.bw.w.(*websocketWriter) + if !ok { + nc.mu.Unlock() + return + } + wr.ctrlFrames = append(wr.ctrlFrames, fh) + if len(payload) > 0 { + wsMaskBuf(key, payload) + wr.ctrlFrames = append(wr.ctrlFrames, payload) + } + nc.bw.flush() + nc.mu.Unlock() +} + +func wsPMCExtensionSupport(header http.Header) (bool, bool) { + for _, extensionList := range header["Sec-Websocket-Extensions"] { + extensions := strings.Split(extensionList, ",") + for _, extension := range extensions { + extension = strings.Trim(extension, " \t") + params := strings.Split(extension, ";") + for i, p := range params { + p = strings.Trim(p, " \t") + if strings.EqualFold(p, wsPMCExtension) { + var snc bool + var cnc bool + for j := i + 1; j < len(params); j++ { + p = params[j] + p = strings.Trim(p, " \t") + if strings.EqualFold(p, wsPMCSrvNoCtx) { + snc = true + } else if strings.EqualFold(p, wsPMCCliNoCtx) { + cnc = true + } + if snc && cnc { + return true, true + } + } + return true, false + } + } + } + } + return false, false +} + +func wsMakeChallengeKey() (string, error) { + p := make([]byte, 16) + if _, err := io.ReadFull(rand.Reader, p); err != nil { + return "", err + } + return base64.StdEncoding.EncodeToString(p), nil +} + +func wsAcceptKey(key string) string { + h := sha1.New() + h.Write([]byte(key)) + h.Write(wsGUID) + return base64.StdEncoding.EncodeToString(h.Sum(nil)) +} + +// Returns true if the op code corresponds to a control frame. +func wsIsControlFrame(frameType wsOpCode) bool { + return frameType >= wsCloseMessage +} + +func isWebsocketScheme(u *url.URL) bool { + return u.Scheme == wsScheme || u.Scheme == wsSchemeTLS +} diff --git a/vendor/modules.txt b/vendor/modules.txt index fa549e09..50825562 100644 --- a/vendor/modules.txt +++ b/vendor/modules.txt @@ -9,7 +9,7 @@ github.com/minio/highwayhash # github.com/nats-io/jwt/v2 v2.0.2 ## explicit github.com/nats-io/jwt/v2 -# github.com/nats-io/nats.go v1.10.1-0.20210419223411-20527524c393 +# github.com/nats-io/nats.go v1.11.0 ## explicit github.com/nats-io/nats.go github.com/nats-io/nats.go/encoders/builtin