diff --git a/server/leafnode.go b/server/leafnode.go index d1e11960..129770f5 100644 --- a/server/leafnode.go +++ b/server/leafnode.go @@ -2047,10 +2047,10 @@ func (c *client) leafNodeSolicitWSConnection(opts *Options, rURL *url.URL, remot req.Header["Sec-WebSocket-Key"] = []string{wsKey} req.Header["Sec-WebSocket-Version"] = []string{"13"} if compress { - req.Header.Add("Sec-WebSocket-Extensions", wsPMCExtension+wsNoCtxTakeOver) + req.Header.Add("Sec-WebSocket-Extensions", wsPMCReqHeaderValue) } if noMasking { - req.Header.Add("Sec-WebSocket-Extensions", wsNoMaskingExtension) + req.Header.Add(wsNoMaskingHeader, wsNoMaskingValue) } if err := req.Write(c.nc); err != nil { return nil, WriteError, err @@ -2069,22 +2069,24 @@ func (c *client) leafNodeSolicitWSConnection(opts *Options, rURL *url.URL, remot err = fmt.Errorf("invalid websocket connection") } - if err == nil && (c.ws.compress || noMasking) { - // Check extensions... + // Check compression extension... + if err == nil && c.ws.compress { + // Check that not only permessage-deflate extension is present, but that + // we also have server and client no context take over. + srvCompress, noCtxTakeover := wsPMCExtensionSupport(resp.Header, false) - srvCompress, srvNoMasking := wsClientWantedExtensions(resp.Header) - - // We said to the otherside that we support compression. Now check that - // the other side said that it supports compression too. - if c.ws.compress && !srvCompress { - // No extension, or does not contain the indication that per-message - // compression is supported, so disable on our side. + // If server does not support compression, then simply disable it in our side. + if !srvCompress { c.ws.compress = false + } else if !noCtxTakeover { + err = fmt.Errorf("compression negotiation error") } - - // Same for no masking... - if noMasking && !srvNoMasking { - // Need to mask our writes as any client would do. + } + // Same for no masking... + if err == nil && noMasking { + // Check if server accepts no masking + if resp.Header.Get(wsNoMaskingHeader) != wsNoMaskingValue { + // Nope, need to mask our writes as any client would do. c.ws.maskwrite = true } } diff --git a/server/leafnode_test.go b/server/leafnode_test.go index f279f94e..788a8f2d 100644 --- a/server/leafnode_test.go +++ b/server/leafnode_test.go @@ -2621,6 +2621,7 @@ func TestLeafNodeWSNoMaskingRejected(t *testing.T) { defer s.Shutdown() lo := testDefaultRemoteLeafNodeWSOptions(t, o, false) + lo.LeafNode.Remotes[0].Websocket.NoMasking = true ln := RunServer(lo) defer ln.Shutdown() diff --git a/server/websocket.go b/server/websocket.go index abf7bc6d..6e074642 100644 --- a/server/websocket.go +++ b/server/websocket.go @@ -81,9 +81,14 @@ const ( wsSchemePrefix = "ws" wsSchemePrefixTLS = "wss" - wsNoMaskingExtension = "no-masking" - wsPMCExtension = "permessage-deflate" // per-message compression - wsNoCtxTakeOver = "; server_no_context_takeover; client_no_context_takeover; " + wsNoMaskingHeader = "Nats-No-Masking" + wsNoMaskingValue = "true" + wsNoMaskingFullResponse = wsNoMaskingHeader + ": " + wsNoMaskingValue + CR_LF + wsPMCExtension = "permessage-deflate" // per-message compression + wsPMCSrvNoCtx = "server_no_context_takeover" + wsPMCCliNoCtx = "client_no_context_takeover" + wsPMCReqHeaderValue = wsPMCExtension + "; " + wsPMCSrvNoCtx + "; " + wsPMCCliNoCtx + wsPMCFullResponse = "Sec-WebSocket-Extensions: " + wsPMCExtension + "; " + wsPMCSrvNoCtx + "; " + wsPMCCliNoCtx + _CRLF_ ) var decompressorPool sync.Pool @@ -631,12 +636,14 @@ func (s *Server) wsUpgrade(w http.ResponseWriter, r *http.Request) (*wsUpgradeRe // Point 8. // We don't have protocols, so ignore. // Point 9. - // Extensions, only support for compression and no-masking at the moment - wantsCompress, wantsNoMasking := wsClientWantedExtensions(r.Header) - // We will use compression only if both agree - compress := opts.Websocket.Compression && wantsCompress + // Extensions, only support for compression at the moment + compress := opts.Websocket.Compression + if compress { + // Simply check if permessage-deflate extension is present. + compress, _ = wsPMCExtensionSupport(r.Header, true) + } // We will do masking if asked (unless we reject for tests) - noMasking := wantsNoMasking && !wsTestRejectNoMasking + noMasking := r.Header.Get(wsNoMaskingHeader) == wsNoMaskingValue && !wsTestRejectNoMasking h := w.(http.Hijacker) conn, brw, err := h.Hijack() @@ -658,16 +665,11 @@ func (s *Server) wsUpgrade(w http.ResponseWriter, r *http.Request) (*wsUpgradeRe p = append(p, "HTTP/1.1 101 Switching Protocols\r\nUpgrade: websocket\r\nConnection: Upgrade\r\nSec-WebSocket-Accept: "...) p = append(p, wsAcceptKey(key)...) p = append(p, _CRLF_...) - if compress || noMasking { - p = append(p, "Sec-WebSocket-Extensions: "...) - if compress { - p = append(p, wsPMCExtension...) - p = append(p, wsNoCtxTakeOver...) - } - if noMasking { - p = append(p, wsNoMaskingExtension...) - } - p = append(p, CR_LF...) + if compress { + p = append(p, wsPMCFullResponse...) + } + if noMasking { + p = append(p, wsNoMaskingFullResponse...) } p = append(p, _CRLF_...) @@ -710,28 +712,38 @@ func wsHeaderContains(header http.Header, name string, value string) bool { return false } -// Return if known extensions are wanted by the client. -func wsClientWantedExtensions(header http.Header) (bool, bool) { - var compress bool - var noMasking bool - +func wsPMCExtensionSupport(header http.Header, checkPMCOnly bool) (bool, bool) { for _, extensionList := range header["Sec-Websocket-Extensions"] { extensions := strings.Split(extensionList, ",") for _, extension := range extensions { extension = strings.Trim(extension, " \t") params := strings.Split(extension, ";") - for _, p := range params { - p = strings.ToLower(strings.Trim(p, " \t")) - switch p { - case wsPMCExtension: - compress = true - case wsNoMaskingExtension: - noMasking = true + for i, p := range params { + p = strings.Trim(p, " \t") + if strings.EqualFold(p, wsPMCExtension) { + if checkPMCOnly { + return true, false + } + var snc bool + var cnc bool + for j := i + 1; j < len(params); j++ { + p = params[j] + p = strings.Trim(p, " \t") + if strings.EqualFold(p, wsPMCSrvNoCtx) { + snc = true + } else if strings.EqualFold(p, wsPMCCliNoCtx) { + cnc = true + } + if snc && cnc { + return true, true + } + } + return true, false } } } } - return compress, noMasking + return false, false } // Send an HTTP error with the given `status`` to the given http response writer `w`.