Merge pull request #2613 from nats-io/ws_compression

[FIXED] Websocket: issue with compression and Safari
This commit is contained in:
Ivan Kozlovic
2021-10-12 13:14:04 -06:00
committed by GitHub
2 changed files with 83 additions and 21 deletions

View File

@@ -58,6 +58,7 @@ const (
wsMaxFrameHeaderSize = 14 // Since LeafNode may need to behave as a client
wsMaxControlPayloadSize = 125
wsFrameSizeForBrowsers = 4096 // From experiment, webrowsers behave better with limited frame size
wsCompressThreshold = 64 // Don't compress for small buffer(s)
wsCloseSatusSize = 2
// From https://tools.ietf.org/html/rfc6455#section-11.7
@@ -107,6 +108,7 @@ type websocket struct {
compress bool
closeSent bool
browser bool
nocompfrag bool // No fragment for compressed frames
maskread bool
maskwrite bool
compressor *flate.Writer
@@ -693,7 +695,7 @@ func (s *Server) wsUpgrade(w http.ResponseWriter, r *http.Request) (*wsUpgradeRe
return nil, wsReturnHTTPError(w, http.StatusMethodNotAllowed, "request method must be GET")
}
// Point 2.
if r.Host == "" {
if r.Host == _EMPTY_ {
return nil, wsReturnHTTPError(w, http.StatusBadRequest, "'Host' missing in request")
}
// Point 3.
@@ -706,7 +708,7 @@ func (s *Server) wsUpgrade(w http.ResponseWriter, r *http.Request) (*wsUpgradeRe
}
// Point 5.
key := r.Header.Get("Sec-Websocket-Key")
if key == "" {
if key == _EMPTY_ {
return nil, wsReturnHTTPError(w, http.StatusBadRequest, "key missing")
}
// Point 6.
@@ -771,10 +773,16 @@ func (s *Server) wsUpgrade(w http.ResponseWriter, r *http.Request) (*wsUpgradeRe
ws := &websocket{compress: compress, maskread: !noMasking}
if kind == CLIENT {
// Indicate if this is likely coming from a browser.
if ua := r.Header.Get("User-Agent"); ua != "" && strings.HasPrefix(ua, "Mozilla/") {
if ua := r.Header.Get("User-Agent"); ua != _EMPTY_ && strings.HasPrefix(ua, "Mozilla/") {
ws.browser = true
// Disable fragmentation of compressed frames for Safari browsers.
// Unfortunately, you could be running Chrome on macOS and this
// string will contain "Safari/" (along "Chrome/"). However, what
// I have found is that actual Safari browser also have "Version/".
// So make the combination of the two.
ws.nocompfrag = ws.compress && strings.Contains(ua, "Version/") && strings.Contains(ua, "Safari/")
}
if opts.Websocket.JWTCookie != "" {
if opts.Websocket.JWTCookie != _EMPTY_ {
if c, err := r.Cookie(opts.Websocket.JWTCookie); err == nil && c != nil {
ws.cookieJwt = c.Value
}
@@ -926,7 +934,7 @@ func wsAcceptKey(key string) string {
func wsMakeChallengeKey() (string, error) {
p := make([]byte, 16)
if _, err := io.ReadFull(rand.Reader, p); err != nil {
return "", err
return _EMPTY_, err
}
return base64.StdEncoding.EncodeToString(p), nil
}
@@ -965,7 +973,7 @@ func validateWebsocketOptions(o *Options) error {
}
}
// Using JWT requires Trusted Keys
if wo.JWTCookie != "" {
if wo.JWTCookie != _EMPTY_ {
if len(o.TrustedOperators) == 0 && len(o.TrustedKeys) == 0 {
return fmt.Errorf("trusted operators or trusted keys configuration is required for JWT authentication via cookie %q", wo.JWTCookie)
}
@@ -1104,7 +1112,7 @@ func (s *Server) startWebsocketServer() {
Addr: hp,
Handler: mux,
ReadTimeout: o.HandshakeTimeout,
ErrorLog: log.New(&wsCaptureHTTPServerLog{s}, "", 0),
ErrorLog: log.New(&wsCaptureHTTPServerLog{s}, _EMPTY_, 0),
}
s.websocket.server = hs
s.websocket.listener = hl
@@ -1251,8 +1259,8 @@ func (cl *wsCaptureHTTPServerLog) Write(p []byte) (int, error) {
func (c *client) wsCollapsePtoNB() (net.Buffers, int64) {
var nb net.Buffers
var total = 0
var mfs = 0
var mfs int
var usz int
if c.ws.browser {
mfs = wsFrameSizeForBrowsers
}
@@ -1267,7 +1275,21 @@ func (c *client) wsCollapsePtoNB() (net.Buffers, int64) {
// Start with possible already framed buffers (that we could have
// got from partials or control messages such as ws pings or pongs).
bufs := c.ws.frames
if c.ws.compress && len(nb) > 0 {
compress := c.ws.compress
if compress && len(nb) > 0 {
// First, make sure we don't compress for very small cumulative buffers.
for _, b := range nb {
usz += len(b)
}
if usz <= wsCompressThreshold {
compress = false
}
}
if compress && len(nb) > 0 {
// Overwrite mfs if this connection does not support fragmented compressed frames.
if mfs > 0 && c.ws.nocompfrag {
mfs = 0
}
buf := &bytes.Buffer{}
cp := c.ws.compressor
@@ -1277,13 +1299,15 @@ func (c *client) wsCollapsePtoNB() (net.Buffers, int64) {
} else {
cp.Reset(buf)
}
var usz int
var csz int
for _, b := range nb {
usz += len(b)
cp.Write(b)
}
cp.Close()
if err := cp.Flush(); err != nil {
c.Errorf("Error during compression: %v", err)
c.markConnAsClosed(WriteError)
return nil, 0
}
b := buf.Bytes()
p := b[:len(b)-4]
if mfs > 0 && len(p) > mfs {
@@ -1319,6 +1343,7 @@ func (c *client) wsCollapsePtoNB() (net.Buffers, int64) {
c.out.pb += int64(csz) - int64(usz)
c.ws.fs += int64(csz)
} else if len(nb) > 0 {
var total int
if mfs > 0 {
// We are limiting the frame size.
startFrame := func() int {

View File

@@ -2838,7 +2838,7 @@ func TestWSCompressionBasic(t *testing.T) {
cbuf := &bytes.Buffer{}
compressor, _ := flate.NewWriter(cbuf, flate.BestSpeed)
compressor.Write([]byte(msgProto))
compressor.Close()
compressor.Flush()
compressed := cbuf.Bytes()
// The last 4 bytes are dropped
compressed = compressed[:len(compressed)-4]
@@ -2907,6 +2907,29 @@ func TestWSCompressionBasic(t *testing.T) {
if !bytes.Equal(body, compressed) {
t.Fatalf("Unexpected compress body: %q", body)
}
wc.mu.Lock()
wc.buf.Reset()
wc.mu.Unlock()
payload = "small"
natsPub(t, nc, "foo", []byte(payload))
msgProto = fmt.Sprintf("MSG foo 1 %d\r\n%s\r\n", len(payload), payload)
res = &bytes.Buffer{}
for total := 0; total < len(msgProto); {
l := testWSReadFrame(t, br)
n, _ := res.Write(l)
total += n
}
if !bytes.Equal([]byte(msgProto), res.Bytes()) {
t.Fatalf("Unexpected result: %q", res)
}
wc.mu.RLock()
res = wc.buf
wc.mu.RUnlock()
if !bytes.HasSuffix(res.Bytes(), []byte(msgProto)) {
t.Fatalf("Looks like frame was compressed: %q", res.Bytes())
}
}
func TestWSCompressionWithPartialWrite(t *testing.T) {
@@ -3001,15 +3024,16 @@ func TestWSCompressionFrameSizeLimit(t *testing.T) {
for _, test := range []struct {
name string
maskWrite bool
noLimit bool
}{
{"no write masking", false},
{"write masking", true},
{"no write masking", false, false},
{"write masking", true, false},
} {
t.Run(test.name, func(t *testing.T) {
opts := testWSOptions()
opts.MaxPending = MAX_PENDING_SIZE
s := &Server{opts: opts}
c := &client{srv: s, ws: &websocket{compress: true, browser: true, maskwrite: test.maskWrite}}
c := &client{srv: s, ws: &websocket{compress: true, browser: true, nocompfrag: test.noLimit, maskwrite: test.maskWrite}}
c.initClient()
uncompressedPayload := make([]byte, 2*wsFrameSizeForBrowsers)
@@ -3022,13 +3046,19 @@ func TestWSCompressionFrameSizeLimit(t *testing.T) {
nb, _ := c.collapsePtoNB()
c.mu.Unlock()
if test.noLimit && len(nb) != 2 {
t.Fatalf("There should be only 2 buffers, the header and payload, got %v", len(nb))
}
bb := &bytes.Buffer{}
var key []byte
for i, b := range nb {
// frame header buffer are always very small. The payload should not be more
// than 10 bytes since that is what we passed as the limit.
if len(b) > wsFrameSizeForBrowsers {
t.Fatalf("Frame size too big: %v (%q)", len(b), b)
if !test.noLimit {
// frame header buffer are always very small. The payload should not be more
// than 10 bytes since that is what we passed as the limit.
if len(b) > wsFrameSizeForBrowsers {
t.Fatalf("Frame size too big: %v (%q)", len(b), b)
}
}
if test.maskWrite {
if i%2 == 0 {
@@ -3048,6 +3078,13 @@ func TestWSCompressionFrameSizeLimit(t *testing.T) {
t.Fatalf("Compressed bit missing")
}
} else {
if test.noLimit {
// Since the payload is likely not well compressed, we are expecting
// the length to be > wsFrameSizeForBrowsers
if len(b) <= wsFrameSizeForBrowsers {
t.Fatalf("Expected frame to be bigger, got %v", len(b))
}
}
// Collect the payload
bb.Write(b)
}