// Copyright 2020 The NATS Authors // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. // You may obtain a copy of the License at // // http://www.apache.org/licenses/LICENSE-2.0 // // Unless required by applicable law or agreed to in writing, software // distributed under the License is distributed on an "AS IS" BASIS, // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. // See the License for the specific language governing permissions and // limitations under the License. package server import ( "bufio" "bytes" "compress/flate" "crypto/tls" "encoding/base64" "encoding/binary" "encoding/json" "errors" "fmt" "io" "math/rand" "net" "net/http" "net/url" "reflect" "strings" "sync" "testing" "time" "github.com/nats-io/jwt/v2" "github.com/nats-io/nkeys" ) type testReader struct { buf []byte pos int max int err error } func (tr *testReader) Read(p []byte) (int, error) { if tr.err != nil { return 0, tr.err } n := len(tr.buf) - tr.pos if n == 0 { return 0, nil } if n > len(p) { n = len(p) } if tr.max > 0 && n > tr.max { n = tr.max } copy(p, tr.buf[tr.pos:tr.pos+n]) tr.pos += n return n, nil } func TestWSGet(t *testing.T) { rb := []byte("012345") tr := &testReader{buf: []byte("6789")} for _, test := range []struct { name string pos int needed int newpos int trmax int result string reterr bool }{ {"fromrb1", 0, 3, 3, 4, "012", false}, // Partial from read buffer {"fromrb2", 3, 2, 5, 4, "34", false}, // Partial from read buffer {"fromrb3", 5, 1, 6, 4, "5", false}, // Partial from read buffer {"fromtr1", 4, 4, 6, 4, "4567", false}, // Partial from read buffer + some of ioReader {"fromtr2", 4, 6, 6, 4, "456789", false}, // Partial from read buffer + all of ioReader {"fromtr3", 4, 6, 6, 2, "456789", false}, // Partial from read buffer + all of ioReader with several reads {"fromtr4", 4, 6, 6, 2, "", true}, // ioReader returns error } { t.Run(test.name, func(t *testing.T) { 
tr.pos = 0 tr.max = test.trmax if test.reterr { tr.err = fmt.Errorf("on purpose") } res, np, err := wsGet(tr, rb, test.pos, test.needed) if test.reterr { if err == nil { t.Fatalf("Expected error, got none") } if err.Error() != "on purpose" { t.Fatalf("Unexpected error: %v", err) } if np != 0 || res != nil { t.Fatalf("Unexpected returned values: res=%v n=%v", res, np) } return } if err != nil { t.Fatalf("Error on get: %v", err) } if np != test.newpos { t.Fatalf("Expected pos=%v, got %v", test.newpos, np) } if string(res) != test.result { t.Fatalf("Invalid returned content: %s", res) } }) } } func TestWSIsControlFrame(t *testing.T) { for _, test := range []struct { name string code wsOpCode isControl bool }{ {"binary", wsBinaryMessage, false}, {"text", wsTextMessage, false}, {"ping", wsPingMessage, true}, {"pong", wsPongMessage, true}, {"close", wsCloseMessage, true}, } { t.Run(test.name, func(t *testing.T) { if res := wsIsControlFrame(test.code); res != test.isControl { t.Fatalf("Expected %q isControl to be %v, got %v", test.name, test.isControl, res) } }) } } func testWSSimpleMask(key, buf []byte) { for i := 0; i < len(buf); i++ { buf[i] ^= key[i&3] } } func TestWSUnmask(t *testing.T) { key := []byte{1, 2, 3, 4} orgBuf := []byte("this is a clear text") mask := func() []byte { t.Helper() buf := append([]byte(nil), orgBuf...) testWSSimpleMask(key, buf) // First ensure that the content is masked. 
if bytes.Equal(buf, orgBuf) { t.Fatalf("Masking did not do anything: %q", buf) } return buf } ri := &wsReadInfo{mask: true} ri.init() copy(ri.mkey[:], key) buf := mask() // Unmask in one call ri.unmask(buf) if !bytes.Equal(buf, orgBuf) { t.Fatalf("Unmask error, expected %q, got %q", orgBuf, buf) } // Unmask in multiple calls buf = mask() ri.mkpos = 0 ri.unmask(buf[:3]) ri.unmask(buf[3:11]) ri.unmask(buf[11:]) if !bytes.Equal(buf, orgBuf) { t.Fatalf("Unmask error, expected %q, got %q", orgBuf, buf) } } func TestWSCreateCloseMessage(t *testing.T) { for _, test := range []struct { name string status int psize int truncated bool }{ {"fits", wsCloseStatusInternalSrvError, 10, false}, {"truncated", wsCloseStatusProtocolError, wsMaxControlPayloadSize + 10, true}, } { t.Run(test.name, func(t *testing.T) { payload := make([]byte, test.psize) for i := 0; i < len(payload); i++ { payload[i] = byte('A' + (i % 26)) } res := wsCreateCloseMessage(test.status, string(payload)) if status := binary.BigEndian.Uint16(res[:2]); int(status) != test.status { t.Fatalf("Expected status to be %v, got %v", test.status, status) } psize := len(res) - 2 if !test.truncated { if int(psize) != test.psize { t.Fatalf("Expected size to be %v, got %v", test.psize, psize) } if !bytes.Equal(res[2:], payload) { t.Fatalf("Unexpected result: %q", res[2:]) } return } // Since the payload of a close message contains a 2 byte status, the // actual max text size will be wsMaxControlPayloadSize-2 if int(psize) != wsMaxControlPayloadSize-2 { t.Fatalf("Expected size to be capped to %v, got %v", wsMaxControlPayloadSize-2, psize) } if string(res[len(res)-3:]) != "..." 
{ t.Fatalf("Expected res to have `...` at the end, got %q", res[4:]) } }) } } func TestWSCreateFrameHeader(t *testing.T) { for _, test := range []struct { name string frameType wsOpCode compressed bool len int }{ {"uncompressed 10", wsBinaryMessage, false, 10}, {"uncompressed 600", wsTextMessage, false, 600}, {"uncompressed 100000", wsTextMessage, false, 100000}, {"compressed 10", wsBinaryMessage, true, 10}, {"compressed 600", wsBinaryMessage, true, 600}, {"compressed 100000", wsTextMessage, true, 100000}, } { t.Run(test.name, func(t *testing.T) { res, _ := wsCreateFrameHeader(false, test.compressed, test.frameType, test.len) // The server is always sending the message has a single frame, // so the "final" bit should be set. expected := byte(test.frameType) | wsFinalBit if test.compressed { expected |= wsRsv1Bit } if b := res[0]; b != expected { t.Fatalf("Expected first byte to be %v, got %v", expected, b) } switch { case test.len <= 125: if len(res) != 2 { t.Fatalf("Frame len should be 2, got %v", len(res)) } if res[1] != byte(test.len) { t.Fatalf("Expected len to be in second byte and be %v, got %v", test.len, res[1]) } case test.len < 65536: // 1+1+2 if len(res) != 4 { t.Fatalf("Frame len should be 4, got %v", len(res)) } if res[1] != 126 { t.Fatalf("Second byte value should be 126, got %v", res[1]) } if rl := binary.BigEndian.Uint16(res[2:]); int(rl) != test.len { t.Fatalf("Expected len to be %v, got %v", test.len, rl) } default: // 1+1+8 if len(res) != 10 { t.Fatalf("Frame len should be 10, got %v", len(res)) } if res[1] != 127 { t.Fatalf("Second byte value should be 127, got %v", res[1]) } if rl := binary.BigEndian.Uint64(res[2:]); int(rl) != test.len { t.Fatalf("Expected len to be %v, got %v", test.len, rl) } } }) } } func testWSCreateClientMsg(frameType wsOpCode, frameNum int, final, compressed bool, payload []byte) []byte { if compressed { buf := &bytes.Buffer{} compressor, _ := flate.NewWriter(buf, 1) compressor.Write(payload) compressor.Flush() payload = 
buf.Bytes() // The last 4 bytes are dropped payload = payload[:len(payload)-4] } frame := make([]byte, 14+len(payload)) if frameNum == 1 { frame[0] = byte(frameType) } if final { frame[0] |= wsFinalBit } if compressed { frame[0] |= wsRsv1Bit } pos := 1 lenPayload := len(payload) switch { case lenPayload <= 125: frame[pos] = byte(lenPayload) | wsMaskBit pos++ case lenPayload < 65536: frame[pos] = 126 | wsMaskBit binary.BigEndian.PutUint16(frame[2:], uint16(lenPayload)) pos += 3 default: frame[1] = 127 | wsMaskBit binary.BigEndian.PutUint64(frame[2:], uint64(lenPayload)) pos += 9 } key := []byte{1, 2, 3, 4} copy(frame[pos:], key) pos += 4 copy(frame[pos:], payload) testWSSimpleMask(key, frame[pos:]) pos += lenPayload return frame[:pos] } func testWSSetupForRead() (*client, *wsReadInfo, *testReader) { ri := &wsReadInfo{mask: true} ri.init() tr := &testReader{} opts := DefaultOptions() opts.MaxPending = MAX_PENDING_SIZE s := &Server{opts: opts} c := &client{srv: s, ws: &websocket{}} c.initClient() return c, ri, tr } func TestWSReadUncompressedFrames(t *testing.T) { c, ri, tr := testWSSetupForRead() // Create 2 WS messages pl1 := []byte("first message") wsmsg1 := testWSCreateClientMsg(wsBinaryMessage, 1, true, false, pl1) pl2 := []byte("second message") wsmsg2 := testWSCreateClientMsg(wsBinaryMessage, 1, true, false, pl2) // Add both in single buffer orgrb := append([]byte(nil), wsmsg1...) orgrb = append(orgrb, wsmsg2...) rb := append([]byte(nil), orgrb...) bufs, err := c.wsRead(ri, tr, rb) if err != nil { t.Fatalf("Unexpected error: %v", err) } if n := len(bufs); n != 2 { t.Fatalf("Expected 2 buffers, got %v", n) } if !bytes.Equal(bufs[0], pl1) { t.Fatalf("Unexpected content for buffer 1: %s", bufs[0]) } if !bytes.Equal(bufs[1], pl2) { t.Fatalf("Unexpected content for buffer 2: %s", bufs[1]) } // Now reset and try with the read buffer not containing full ws frame c, ri, tr = testWSSetupForRead() rb = append([]byte(nil), orgrb...) // Frame is 1+1+4+'first message'. 
So say we pass with rb of 11 bytes, // then we should get "first" bufs, err = c.wsRead(ri, tr, rb[:11]) if err != nil { t.Fatalf("Unexpected error: %v", err) } if n := len(bufs); n != 1 { t.Fatalf("Unexpected buffer returned: %v", n) } if string(bufs[0]) != "first" { t.Fatalf("Unexpected content: %q", bufs[0]) } // Call again with more data.. bufs, err = c.wsRead(ri, tr, rb[11:32]) if err != nil { t.Fatalf("Unexpected error: %v", err) } if n := len(bufs); n != 2 { t.Fatalf("Unexpected buffer returned: %v", n) } if string(bufs[0]) != " message" { t.Fatalf("Unexpected content: %q", bufs[0]) } if string(bufs[1]) != "second " { t.Fatalf("Unexpected content: %q", bufs[1]) } // Call with the rest bufs, err = c.wsRead(ri, tr, rb[32:]) if err != nil { t.Fatalf("Unexpected error: %v", err) } if n := len(bufs); n != 1 { t.Fatalf("Unexpected buffer returned: %v", n) } if string(bufs[0]) != "message" { t.Fatalf("Unexpected content: %q", bufs[0]) } } func TestWSReadCompressedFrames(t *testing.T) { c, ri, tr := testWSSetupForRead() uncompressed := []byte("this is the uncompress data") wsmsg1 := testWSCreateClientMsg(wsBinaryMessage, 1, true, true, uncompressed) rb := append([]byte(nil), wsmsg1...) // Call with some but not all of the payload bufs, err := c.wsRead(ri, tr, rb[:10]) if err != nil { t.Fatalf("Unexpected error: %v", err) } if n := len(bufs); n != 0 { t.Fatalf("Unexpected buffer returned: %v", n) } // Call with the rest, only then should we get the uncompressed data. bufs, err = c.wsRead(ri, tr, rb[10:]) if err != nil { t.Fatalf("Unexpected error: %v", err) } if n := len(bufs); n != 1 { t.Fatalf("Unexpected buffer returned: %v", n) } if !bytes.Equal(bufs[0], uncompressed) { t.Fatalf("Unexpected content: %s", bufs[0]) } // Stress the fact that we use a pool and want to make sure // that if we get a decompressor from the pool, it is properly reset // with the buffer to decompress. // Since we unmask the read buffer, reset it now and fill it // with 10 compressed frames. 
rb = nil for i := 0; i < 10; i++ { rb = append(rb, wsmsg1...) } bufs, err = c.wsRead(ri, tr, rb) if err != nil { t.Fatalf("Unexpected error: %v", err) } if n := len(bufs); n != 10 { t.Fatalf("Unexpected buffer returned: %v", n) } // Compress a message and send it in several frames. buf := &bytes.Buffer{} compressor, _ := flate.NewWriter(buf, 1) compressor.Write(uncompressed) compressor.Flush() compressed := buf.Bytes() // The last 4 bytes are dropped compressed = compressed[:len(compressed)-4] ncomp := 10 frag1 := testWSCreateClientMsg(wsBinaryMessage, 1, false, false, compressed[:ncomp]) frag1[0] |= wsRsv1Bit frag2 := testWSCreateClientMsg(wsBinaryMessage, 2, true, false, compressed[ncomp:]) rb = append([]byte(nil), frag1...) rb = append(rb, frag2...) bufs, err = c.wsRead(ri, tr, rb) if err != nil { t.Fatalf("Unexpected error: %v", err) } if n := len(bufs); n != 1 { t.Fatalf("Unexpected buffer returned: %v", n) } if !bytes.Equal(bufs[0], uncompressed) { t.Fatalf("Unexpected content: %s", bufs[0]) } } func TestWSReadCompressedFrameCorrupted(t *testing.T) { c, ri, tr := testWSSetupForRead() uncompressed := []byte("this is the uncompress data") wsmsg1 := testWSCreateClientMsg(wsBinaryMessage, 1, true, true, uncompressed) copy(wsmsg1[10:], []byte{1, 2, 3, 4}) rb := append([]byte(nil), wsmsg1...) 
bufs, err := c.wsRead(ri, tr, rb) if err == nil || !strings.Contains(err.Error(), "corrupt") { t.Fatalf("Expected error about corrupted data, got %v", err) } if n := len(bufs); n != 0 { t.Fatalf("Expected no buffer, got %v", n) } } func TestWSReadVariousFrameSizes(t *testing.T) { for _, test := range []struct { name string size int }{ {"tiny", 100}, {"medium", 1000}, {"large", 70000}, } { t.Run(test.name, func(t *testing.T) { c, ri, tr := testWSSetupForRead() uncompressed := make([]byte, test.size) for i := 0; i < len(uncompressed); i++ { uncompressed[i] = 'A' + byte(i%26) } wsmsg1 := testWSCreateClientMsg(wsBinaryMessage, 1, true, false, uncompressed) rb := append([]byte(nil), wsmsg1...) bufs, err := c.wsRead(ri, tr, rb) if err != nil { t.Fatalf("Unexpected error: %v", err) } if n := len(bufs); n != 1 { t.Fatalf("Unexpected buffer returned: %v", n) } if !bytes.Equal(bufs[0], uncompressed) { t.Fatalf("Unexpected content: %s", bufs[0]) } }) } } func TestWSReadFragmentedFrames(t *testing.T) { c, ri, tr := testWSSetupForRead() payloads := []string{"first", "second", "third"} var rb []byte for i := 0; i < len(payloads); i++ { final := i == len(payloads)-1 frag := testWSCreateClientMsg(wsBinaryMessage, i+1, final, false, []byte(payloads[i])) rb = append(rb, frag...) } bufs, err := c.wsRead(ri, tr, rb) if err != nil { t.Fatalf("Unexpected error: %v", err) } if n := len(bufs); n != 3 { t.Fatalf("Unexpected buffer returned: %v", n) } for i, expected := range payloads { if string(bufs[i]) != expected { t.Fatalf("Unexpected content for buf=%v: %s", i, bufs[i]) } } } func TestWSReadPartialFrameHeaderAtEndOfReadBuffer(t *testing.T) { c, ri, tr := testWSSetupForRead() msg1 := testWSCreateClientMsg(wsBinaryMessage, 1, true, false, []byte("msg1")) msg2 := testWSCreateClientMsg(wsBinaryMessage, 1, true, false, []byte("msg2")) rb := append([]byte(nil), msg1...) rb = append(rb, msg2...) // We will pass the first frame + the first byte of the next frame. 
rbl := rb[:len(msg1)+1] // Make the io reader return the rest of the frame tr.buf = rb[len(msg1)+1:] bufs, err := c.wsRead(ri, tr, rbl) if err != nil { t.Fatalf("Unexpected error: %v", err) } if n := len(bufs); n != 1 { t.Fatalf("Unexpected buffer returned: %v", n) } // We should not have asked to the io reader more than what is needed for reading // the frame header. Since we had already the first byte in the read buffer, // tr.pos should be 1(size)+4(key)=5 if tr.pos != 5 { t.Fatalf("Expected reader pos to be 5, got %v", tr.pos) } } func TestWSReadPingFrame(t *testing.T) { for _, test := range []struct { name string payload []byte }{ {"without payload", nil}, {"with payload", []byte("optional payload")}, } { t.Run(test.name, func(t *testing.T) { c, ri, tr := testWSSetupForRead() ping := testWSCreateClientMsg(wsPingMessage, 1, true, false, test.payload) rb := append([]byte(nil), ping...) bufs, err := c.wsRead(ri, tr, rb) if err != nil { t.Fatalf("Unexpected error: %v", err) } if n := len(bufs); n != 0 { t.Fatalf("Unexpected buffer returned: %v", n) } // A PONG should have been queued with the payload of the ping c.mu.Lock() nb, _ := c.collapsePtoNB() c.mu.Unlock() if n := len(nb); n == 0 { t.Fatalf("Expected buffers, got %v", n) } if expected := 2 + len(test.payload); expected != len(nb[0]) { t.Fatalf("Expected buffer to be %v bytes long, got %v", expected, len(nb[0])) } b := nb[0][0] if b&wsFinalBit == 0 { t.Fatalf("Control frame should have been the final flag, it was not set: %v", b) } if b&byte(wsPongMessage) == 0 { t.Fatalf("Should have been a PONG, it wasn't: %v", b) } if len(test.payload) > 0 { if !bytes.Equal(nb[0][2:], test.payload) { t.Fatalf("Unexpected content: %s", nb[0][2:]) } } }) } } func TestWSReadPongFrame(t *testing.T) { for _, test := range []struct { name string payload []byte }{ {"without payload", nil}, {"with payload", []byte("optional payload")}, } { t.Run(test.name, func(t *testing.T) { c, ri, tr := testWSSetupForRead() pong := 
testWSCreateClientMsg(wsPongMessage, 1, true, false, test.payload) rb := append([]byte(nil), pong...) bufs, err := c.wsRead(ri, tr, rb) if err != nil { t.Fatalf("Unexpected error: %v", err) } if n := len(bufs); n != 0 { t.Fatalf("Unexpected buffer returned: %v", n) } // Nothing should be sent... c.mu.Lock() nb, _ := c.collapsePtoNB() c.mu.Unlock() if n := len(nb); n != 0 { t.Fatalf("Expected no buffer, got %v", n) } }) } } func TestWSReadCloseFrame(t *testing.T) { for _, test := range []struct { name string payload []byte }{ {"without payload", nil}, {"with payload", []byte("optional payload")}, } { t.Run(test.name, func(t *testing.T) { c, ri, tr := testWSSetupForRead() // a close message has a status in 2 bytes + optional payload payload := make([]byte, 2+len(test.payload)) binary.BigEndian.PutUint16(payload[:2], wsCloseStatusNormalClosure) if len(test.payload) > 0 { copy(payload[2:], test.payload) } close := testWSCreateClientMsg(wsCloseMessage, 1, true, false, payload) // Have a normal frame prior to close to make sure that wsRead returns // the normal frame along with io.EOF to indicate that wsCloseMessage was received. msg := testWSCreateClientMsg(wsBinaryMessage, 1, true, false, []byte("msg")) rb := append([]byte(nil), msg...) rb = append(rb, close...) bufs, err := c.wsRead(ri, tr, rb) // It is expected that wsRead returns io.EOF on processing a close. if err != io.EOF { t.Fatalf("Unexpected error: %v", err) } if n := len(bufs); n != 1 { t.Fatalf("Unexpected buffer returned: %v", n) } if string(bufs[0]) != "msg" { t.Fatalf("Unexpected content: %s", bufs[0]) } // A CLOSE should have been queued with the payload of the original close message. 
c.mu.Lock() nb, _ := c.collapsePtoNB() c.mu.Unlock() if n := len(nb); n == 0 { t.Fatalf("Expected buffers, got %v", n) } if expected := 2 + 2 + len(test.payload); expected != len(nb[0]) { t.Fatalf("Expected buffer to be %v bytes long, got %v", expected, len(nb[0])) } b := nb[0][0] if b&wsFinalBit == 0 { t.Fatalf("Control frame should have been the final flag, it was not set: %v", b) } if b&byte(wsCloseMessage) == 0 { t.Fatalf("Should have been a CLOSE, it wasn't: %v", b) } if status := binary.BigEndian.Uint16(nb[0][2:4]); status != wsCloseStatusNormalClosure { t.Fatalf("Expected status to be %v, got %v", wsCloseStatusNormalClosure, status) } if len(test.payload) > 0 { if !bytes.Equal(nb[0][4:], test.payload) { t.Fatalf("Unexpected content: %s", nb[0][4:]) } } }) } } func TestWSReadControlFrameBetweebFragmentedFrames(t *testing.T) { c, ri, tr := testWSSetupForRead() frag1 := testWSCreateClientMsg(wsBinaryMessage, 1, false, false, []byte("first")) frag2 := testWSCreateClientMsg(wsBinaryMessage, 2, true, false, []byte("second")) ctrl := testWSCreateClientMsg(wsPongMessage, 1, true, false, nil) rb := append([]byte(nil), frag1...) rb = append(rb, ctrl...) rb = append(rb, frag2...) 
bufs, err := c.wsRead(ri, tr, rb) if err != nil { t.Fatalf("Unexpected error: %v", err) } if n := len(bufs); n != 2 { t.Fatalf("Unexpected buffer returned: %v", n) } if string(bufs[0]) != "first" { t.Fatalf("Unexpected content: %s", bufs[0]) } if string(bufs[1]) != "second" { t.Fatalf("Unexpected content: %s", bufs[1]) } } func TestWSCloseFrameWithPartialOrInvalid(t *testing.T) { c, ri, tr := testWSSetupForRead() // a close message has a status in 2 bytes + optional payload payloadTxt := []byte("hello") payload := make([]byte, 2+len(payloadTxt)) binary.BigEndian.PutUint16(payload[:2], wsCloseStatusNormalClosure) copy(payload[2:], payloadTxt) closeMsg := testWSCreateClientMsg(wsCloseMessage, 1, true, false, payload) // We will pass to wsRead a buffer of small capacity that contains // only 1 byte. closeFirtByte := []byte{closeMsg[0]} // Make the io reader return the rest of the frame tr.buf = closeMsg[1:] bufs, err := c.wsRead(ri, tr, closeFirtByte[:]) // It is expected that wsRead returns io.EOF on processing a close. if err != io.EOF { t.Fatalf("Unexpected error: %v", err) } if n := len(bufs); n != 0 { t.Fatalf("Unexpected buffer returned: %v", n) } // A CLOSE should have been queued with the payload of the original close message. 
c.mu.Lock() nb, _ := c.collapsePtoNB() c.mu.Unlock() if n := len(nb); n == 0 { t.Fatalf("Expected buffers, got %v", n) } if expected := 2 + 2 + len(payloadTxt); expected != len(nb[0]) { t.Fatalf("Expected buffer to be %v bytes long, got %v", expected, len(nb[0])) } b := nb[0][0] if b&wsFinalBit == 0 { t.Fatalf("Control frame should have been the final flag, it was not set: %v", b) } if b&byte(wsCloseMessage) == 0 { t.Fatalf("Should have been a CLOSE, it wasn't: %v", b) } if status := binary.BigEndian.Uint16(nb[0][2:4]); status != wsCloseStatusNormalClosure { t.Fatalf("Expected status to be %v, got %v", wsCloseStatusNormalClosure, status) } if !bytes.Equal(nb[0][4:], payloadTxt) { t.Fatalf("Unexpected content: %s", nb[0][4:]) } // Now test close with invalid status size (1 instead of 2 bytes) c, ri, tr = testWSSetupForRead() payload[0] = 100 binary.BigEndian.PutUint16(payload, wsCloseStatusNormalClosure) closeMsg = testWSCreateClientMsg(wsCloseMessage, 1, true, false, payload[:1]) // We will pass to wsRead a buffer of small capacity that contains // only 1 byte. closeFirtByte = []byte{closeMsg[0]} // Make the io reader return the rest of the frame tr.buf = closeMsg[1:] bufs, err = c.wsRead(ri, tr, closeFirtByte[:]) // It is expected that wsRead returns io.EOF on processing a close. if err != io.EOF { t.Fatalf("Unexpected error: %v", err) } if n := len(bufs); n != 0 { t.Fatalf("Unexpected buffer returned: %v", n) } // A CLOSE should have been queued with the payload of the original close message. 
c.mu.Lock() nb, _ = c.collapsePtoNB() c.mu.Unlock() if n := len(nb); n == 0 { t.Fatalf("Expected buffers, got %v", n) } if expected := 2 + 2; expected != len(nb[0]) { t.Fatalf("Expected buffer to be %v bytes long, got %v", expected, len(nb[0])) } b = nb[0][0] if b&wsFinalBit == 0 { t.Fatalf("Control frame should have been the final flag, it was not set: %v", b) } if b&byte(wsCloseMessage) == 0 { t.Fatalf("Should have been a CLOSE, it wasn't: %v", b) } // Since satus was not valid, we should get wsCloseStatusNoStatusReceived if status := binary.BigEndian.Uint16(nb[0][2:4]); status != wsCloseStatusNoStatusReceived { t.Fatalf("Expected status to be %v, got %v", wsCloseStatusNoStatusReceived, status) } if len(nb[0][:]) != 4 { t.Fatalf("Unexpected content: %s", nb[0][2:]) } } func TestWSReadGetErrors(t *testing.T) { tr := &testReader{err: fmt.Errorf("on purpose")} for _, test := range []struct { lenPayload int rbextra int }{ {10, 1}, {10, 3}, {200, 1}, {200, 2}, {200, 5}, {70000, 1}, {70000, 5}, {70000, 13}, } { t.Run("", func(t *testing.T) { c, ri, _ := testWSSetupForRead() msg := testWSCreateClientMsg(wsBinaryMessage, 1, true, false, []byte("msg")) frame := testWSCreateClientMsg(wsBinaryMessage, 1, true, false, make([]byte, test.lenPayload)) rb := append([]byte(nil), msg...) rb = append(rb, frame...) 
bufs, err := c.wsRead(ri, tr, rb[:len(msg)+test.rbextra]) if err == nil || err.Error() != "on purpose" { t.Fatalf("Expected 'on purpose' error, got %v", err) } if n := len(bufs); n != 1 { t.Fatalf("Unexpected buffer returned: %v", n) } if string(bufs[0]) != "msg" { t.Fatalf("Unexpected content: %s", bufs[0]) } }) } } func TestWSHandleControlFrameErrors(t *testing.T) { c, ri, tr := testWSSetupForRead() tr.err = fmt.Errorf("on purpose") // a close message has a status in 2 bytes + optional payload text := []byte("this is a close message") payload := make([]byte, 2+len(text)) binary.BigEndian.PutUint16(payload[:2], wsCloseStatusNormalClosure) copy(payload[2:], text) ctrl := testWSCreateClientMsg(wsCloseMessage, 1, true, false, payload) bufs, err := c.wsRead(ri, tr, ctrl[:len(ctrl)-4]) if err == nil || err.Error() != "on purpose" { t.Fatalf("Expected 'on purpose' error, got %v", err) } if n := len(bufs); n != 0 { t.Fatalf("Unexpected buffer returned: %v", n) } // Alter the content of close message. It is supposed to be valid utf-8. c, ri, tr = testWSSetupForRead() cp := append([]byte(nil), payload...) 
cp[10] = 0xF1 ctrl = testWSCreateClientMsg(wsCloseMessage, 1, true, false, cp) bufs, err = c.wsRead(ri, tr, ctrl) // We should still receive an EOF but the message enqueued to the client // should contain wsCloseStatusInvalidPayloadData and the error about invalid utf8 if err != io.EOF { t.Fatalf("Unexpected error: %v", err) } if n := len(bufs); n != 0 { t.Fatalf("Unexpected buffer returned: %v", n) } c.mu.Lock() nb, _ := c.collapsePtoNB() c.mu.Unlock() if n := len(nb); n == 0 { t.Fatalf("Expected buffers, got %v", n) } b := nb[0][0] if b&wsFinalBit == 0 { t.Fatalf("Control frame should have been the final flag, it was not set: %v", b) } if b&byte(wsCloseMessage) == 0 { t.Fatalf("Should have been a CLOSE, it wasn't: %v", b) } if status := binary.BigEndian.Uint16(nb[0][2:4]); status != wsCloseStatusInvalidPayloadData { t.Fatalf("Expected status to be %v, got %v", wsCloseStatusInvalidPayloadData, status) } if !bytes.Contains(nb[0][4:], []byte("utf8")) { t.Fatalf("Unexpected content: %s", nb[0][4:]) } } func TestWSReadErrors(t *testing.T) { for _, test := range []struct { cframe func() []byte err string nbufs int }{ { func() []byte { msg := testWSCreateClientMsg(wsBinaryMessage, 1, true, false, []byte("hello")) msg[1] &= ^byte(wsMaskBit) return msg }, "mask bit missing", 1, }, { func() []byte { return testWSCreateClientMsg(wsPingMessage, 1, true, false, make([]byte, 200)) }, "control frame length bigger than maximum allowed", 1, }, { func() []byte { return testWSCreateClientMsg(wsPingMessage, 1, false, false, []byte("hello")) }, "control frame does not have final bit set", 1, }, { func() []byte { frag1 := testWSCreateClientMsg(wsBinaryMessage, 1, false, false, []byte("frag1")) newMsg := testWSCreateClientMsg(wsBinaryMessage, 1, true, false, []byte("new message")) all := append([]byte(nil), frag1...) all = append(all, newMsg...) 
return all }, "new message started before final frame for previous message was received", 2, }, { func() []byte { frame := testWSCreateClientMsg(wsBinaryMessage, 1, true, false, []byte("frame")) frag := testWSCreateClientMsg(wsBinaryMessage, 2, false, false, []byte("continuation")) all := append([]byte(nil), frame...) all = append(all, frag...) return all }, "invalid continuation frame", 2, }, { func() []byte { return testWSCreateClientMsg(wsBinaryMessage, 2, false, true, []byte("frame")) }, "invalid continuation frame", 1, }, { func() []byte { return testWSCreateClientMsg(99, 1, false, false, []byte("hello")) }, "unknown opcode", 1, }, } { t.Run(test.err, func(t *testing.T) { c, ri, tr := testWSSetupForRead() // Add a valid message first msg := testWSCreateClientMsg(wsBinaryMessage, 1, true, false, []byte("hello")) // Then add the bad frame bad := test.cframe() // Add them both to a read buffer rb := append([]byte(nil), msg...) rb = append(rb, bad...) bufs, err := c.wsRead(ri, tr, rb) if err == nil || !strings.Contains(err.Error(), test.err) { t.Fatalf("Expected error to contain %q, got %q", test.err, err.Error()) } if n := len(bufs); n != test.nbufs { t.Fatalf("Unexpected number of buffers: %v", n) } if string(bufs[0]) != "hello" { t.Fatalf("Unexpected content: %s", bufs[0]) } }) } } func TestWSEnqueueCloseMsg(t *testing.T) { for _, test := range []struct { reason ClosedState status int }{ {ClientClosed, wsCloseStatusNormalClosure}, {AuthenticationTimeout, wsCloseStatusPolicyViolation}, {AuthenticationViolation, wsCloseStatusPolicyViolation}, {SlowConsumerPendingBytes, wsCloseStatusPolicyViolation}, {SlowConsumerWriteDeadline, wsCloseStatusPolicyViolation}, {MaxAccountConnectionsExceeded, wsCloseStatusPolicyViolation}, {MaxConnectionsExceeded, wsCloseStatusPolicyViolation}, {MaxControlLineExceeded, wsCloseStatusPolicyViolation}, {MaxSubscriptionsExceeded, wsCloseStatusPolicyViolation}, {MissingAccount, wsCloseStatusPolicyViolation}, {AuthenticationExpired, 
wsCloseStatusPolicyViolation}, {Revocation, wsCloseStatusPolicyViolation}, {TLSHandshakeError, wsCloseStatusTLSHandshake}, {ParseError, wsCloseStatusProtocolError}, {ProtocolViolation, wsCloseStatusProtocolError}, {BadClientProtocolVersion, wsCloseStatusProtocolError}, {MaxPayloadExceeded, wsCloseStatusMessageTooBig}, {ServerShutdown, wsCloseStatusGoingAway}, {WriteError, wsCloseStatusAbnormalClosure}, {ReadError, wsCloseStatusAbnormalClosure}, {StaleConnection, wsCloseStatusAbnormalClosure}, {ClosedState(254), wsCloseStatusInternalSrvError}, } { t.Run(test.reason.String(), func(t *testing.T) { c, _, _ := testWSSetupForRead() c.wsEnqueueCloseMessage(test.reason) c.mu.Lock() nb, _ := c.collapsePtoNB() c.mu.Unlock() if n := len(nb); n != 1 { t.Fatalf("Expected 1 buffer, got %v", n) } b := nb[0][0] if b&wsFinalBit == 0 { t.Fatalf("Control frame should have been the final flag, it was not set: %v", b) } if b&byte(wsCloseMessage) == 0 { t.Fatalf("Should have been a CLOSE, it wasn't: %v", b) } if status := binary.BigEndian.Uint16(nb[0][2:4]); int(status) != test.status { t.Fatalf("Expected status to be %v, got %v", test.status, status) } if string(nb[0][4:]) != test.reason.String() { t.Fatalf("Unexpected content: %s", nb[0][4:]) } }) } } type testResponseWriter struct { http.ResponseWriter buf bytes.Buffer headers http.Header err error brw *bufio.ReadWriter conn *testWSFakeNetConn } func (trw *testResponseWriter) Write(p []byte) (int, error) { return trw.buf.Write(p) } func (trw *testResponseWriter) WriteHeader(status int) { trw.buf.WriteString(fmt.Sprintf("%v", status)) } func (trw *testResponseWriter) Header() http.Header { if trw.headers == nil { trw.headers = make(http.Header) } return trw.headers } type testWSFakeNetConn struct { net.Conn wbuf bytes.Buffer err error wsOpened bool isClosed bool deadlineCleared bool } func (c *testWSFakeNetConn) Write(p []byte) (int, error) { if c.err != nil { return 0, c.err } return c.wbuf.Write(p) } func (c *testWSFakeNetConn) 
SetDeadline(t time.Time) error { if t.IsZero() { c.deadlineCleared = true } return nil } func (c *testWSFakeNetConn) Close() error { c.isClosed = true return nil } func (trw *testResponseWriter) Hijack() (net.Conn, *bufio.ReadWriter, error) { if trw.conn == nil { trw.conn = &testWSFakeNetConn{} } trw.conn.wsOpened = true if trw.brw == nil { trw.brw = bufio.NewReadWriter(bufio.NewReader(trw.conn), bufio.NewWriter(trw.conn)) } return trw.conn, trw.brw, trw.err } func testWSOptions() *Options { opts := DefaultOptions() opts.DisableShortFirstPing = true opts.Websocket.Host = "127.0.0.1" opts.Websocket.Port = -1 opts.NoSystemAccount = true var err error tc := &TLSConfigOpts{ CertFile: "./configs/certs/server.pem", KeyFile: "./configs/certs/key.pem", } opts.Websocket.TLSConfig, err = GenTLSConfig(tc) if err != nil { panic(err) } return opts } func testWSCreateValidReq() *http.Request { req := &http.Request{ Method: "GET", Host: "localhost", Proto: "HTTP/1.1", } req.Header = make(http.Header) req.Header.Set("Upgrade", "websocket") req.Header.Set("Connection", "Upgrade") req.Header.Set("Sec-Websocket-Key", "dGhlIHNhbXBsZSBub25jZQ==") req.Header.Set("Sec-Websocket-Version", "13") return req } func TestWSCheckOrigin(t *testing.T) { notSameOrigin := false sameOrigin := true allowedListEmpty := []string{} someList := []string{"http://host1.com", "http://host2.com:1234"} for _, test := range []struct { name string sameOrigin bool origins []string reqHost string reqTLS bool origin string err string }{ {"any", notSameOrigin, allowedListEmpty, "", false, "http://any.host.com", ""}, {"same origin ok", sameOrigin, allowedListEmpty, "host.com", false, "http://host.com:80", ""}, {"same origin bad host", sameOrigin, allowedListEmpty, "host.com", false, "http://other.host.com", "not same origin"}, {"same origin bad port", sameOrigin, allowedListEmpty, "host.com", false, "http://host.com:81", "not same origin"}, {"same origin bad scheme", sameOrigin, allowedListEmpty, "host.com", true, 
"http://host.com", "not same origin"}, {"same origin bad uri", sameOrigin, allowedListEmpty, "host.com", false, "@@@://invalid:url:1234", "invalid URI"}, {"same origin bad url", sameOrigin, allowedListEmpty, "host.com", false, "http://invalid:url:1234", "too many colons"}, {"same origin bad req host", sameOrigin, allowedListEmpty, "invalid:url:1234", false, "http://host.com", "too many colons"}, {"no origin same origin ignored", sameOrigin, allowedListEmpty, "", false, "", ""}, {"no origin list ignored", sameOrigin, someList, "", false, "", ""}, {"no origin same origin and list ignored", sameOrigin, someList, "", false, "", ""}, {"allowed from list", notSameOrigin, someList, "", false, "http://host2.com:1234", ""}, {"allowed with different path", notSameOrigin, someList, "", false, "http://host1.com/some/path", ""}, {"list bad port", notSameOrigin, someList, "", false, "http://host1.com:1234", "not in the allowed list"}, {"list bad scheme", notSameOrigin, someList, "", false, "https://host2.com:1234", "not in the allowed list"}, } { t.Run(test.name, func(t *testing.T) { opts := DefaultOptions() opts.Websocket.SameOrigin = test.sameOrigin opts.Websocket.AllowedOrigins = test.origins s := &Server{opts: opts} s.wsSetOriginOptions(&opts.Websocket) req := testWSCreateValidReq() req.Host = test.reqHost if test.reqTLS { req.TLS = &tls.ConnectionState{} } if test.origin != "" { req.Header.Set("Origin", test.origin) } err := s.websocket.checkOrigin(req) if test.err == "" && err != nil { t.Fatalf("Unexpected error: %v", err) } else if test.err != "" && (err == nil || !strings.Contains(err.Error(), test.err)) { t.Fatalf("Expected error %q, got %v", test.err, err) } }) } } func TestWSUpgradeValidationErrors(t *testing.T) { for _, test := range []struct { name string setup func() (*Options, *testResponseWriter, *http.Request) err string status int }{ { "bad method", func() (*Options, *testResponseWriter, *http.Request) { opts := testWSOptions() req := testWSCreateValidReq() 
req.Method = "POST" return opts, nil, req }, "must be GET", http.StatusMethodNotAllowed, }, { "no host", func() (*Options, *testResponseWriter, *http.Request) { opts := testWSOptions() req := testWSCreateValidReq() req.Host = "" return opts, nil, req }, "'Host' missing in request", http.StatusBadRequest, }, { "invalid upgrade header", func() (*Options, *testResponseWriter, *http.Request) { opts := testWSOptions() req := testWSCreateValidReq() req.Header.Del("Upgrade") return opts, nil, req }, "invalid value for header 'Upgrade'", http.StatusBadRequest, }, { "invalid connection header", func() (*Options, *testResponseWriter, *http.Request) { opts := testWSOptions() req := testWSCreateValidReq() req.Header.Del("Connection") return opts, nil, req }, "invalid value for header 'Connection'", http.StatusBadRequest, }, { "no key", func() (*Options, *testResponseWriter, *http.Request) { opts := testWSOptions() req := testWSCreateValidReq() req.Header.Del("Sec-Websocket-Key") return opts, nil, req }, "key missing", http.StatusBadRequest, }, { "empty key", func() (*Options, *testResponseWriter, *http.Request) { opts := testWSOptions() req := testWSCreateValidReq() req.Header.Set("Sec-Websocket-Key", "") return opts, nil, req }, "key missing", http.StatusBadRequest, }, { "missing version", func() (*Options, *testResponseWriter, *http.Request) { opts := testWSOptions() req := testWSCreateValidReq() req.Header.Del("Sec-Websocket-Version") return opts, nil, req }, "invalid version", http.StatusBadRequest, }, { "wrong version", func() (*Options, *testResponseWriter, *http.Request) { opts := testWSOptions() req := testWSCreateValidReq() req.Header.Set("Sec-Websocket-Version", "99") return opts, nil, req }, "invalid version", http.StatusBadRequest, }, { "origin", func() (*Options, *testResponseWriter, *http.Request) { opts := testWSOptions() opts.Websocket.SameOrigin = true req := testWSCreateValidReq() req.Header.Set("Origin", "http://bad.host.com") return opts, nil, req }, 
"origin not allowed", http.StatusForbidden, }, { "hijack error", func() (*Options, *testResponseWriter, *http.Request) { opts := testWSOptions() rw := &testResponseWriter{err: fmt.Errorf("on purpose")} req := testWSCreateValidReq() return opts, rw, req }, "on purpose", http.StatusInternalServerError, }, { "hijack buffered data", func() (*Options, *testResponseWriter, *http.Request) { opts := testWSOptions() buf := &bytes.Buffer{} buf.WriteString("some data") rw := &testResponseWriter{ conn: &testWSFakeNetConn{}, brw: bufio.NewReadWriter(bufio.NewReader(buf), bufio.NewWriter(nil)), } tmp := [1]byte{} io.ReadAtLeast(rw.brw, tmp[:1], 1) req := testWSCreateValidReq() return opts, rw, req }, "client sent data before handshake is complete", http.StatusBadRequest, }, } { t.Run(test.name, func(t *testing.T) { opts, rw, req := test.setup() if rw == nil { rw = &testResponseWriter{} } s := &Server{opts: opts} s.wsSetOriginOptions(&opts.Websocket) res, err := s.wsUpgrade(rw, req) if err == nil || !strings.Contains(err.Error(), test.err) { t.Fatalf("Should get error %q, got %v", test.err, err) } if res != nil { t.Fatalf("Should not have returned a result, got %v", res) } expected := fmt.Sprintf("%v%s\n", test.status, http.StatusText(test.status)) if got := rw.buf.String(); got != expected { t.Fatalf("Expected %q got %q", expected, got) } // Check that if the connection was opened, it is now closed. 
if rw.conn != nil && rw.conn.wsOpened && !rw.conn.isClosed { t.Fatal("Connection was opened, but has not been closed") } }) } } func TestWSUpgradeResponseWriteError(t *testing.T) { opts := testWSOptions() s := &Server{opts: opts} expectedErr := errors.New("on purpose") rw := &testResponseWriter{ conn: &testWSFakeNetConn{err: expectedErr}, } req := testWSCreateValidReq() res, err := s.wsUpgrade(rw, req) if err != expectedErr { t.Fatalf("Should get error %q, got %v", expectedErr.Error(), err) } if res != nil { t.Fatalf("Should not have returned a result, got %v", res) } if !rw.conn.isClosed { t.Fatal("Connection should have been closed") } } func TestWSUpgradeConnDeadline(t *testing.T) { opts := testWSOptions() opts.Websocket.HandshakeTimeout = time.Second s := &Server{opts: opts} rw := &testResponseWriter{} req := testWSCreateValidReq() res, err := s.wsUpgrade(rw, req) if res == nil || err != nil { t.Fatalf("Unexpected error: %v", err) } if rw.conn.isClosed { t.Fatal("Connection should NOT have been closed") } if !rw.conn.deadlineCleared { t.Fatal("Connection deadline should have been cleared after handshake") } } func TestWSCompressNegotiation(t *testing.T) { // No compression on the server, but client asks opts := testWSOptions() s := &Server{opts: opts} rw := &testResponseWriter{} req := testWSCreateValidReq() req.Header.Set("Sec-Websocket-Extensions", "permessage-deflate") res, err := s.wsUpgrade(rw, req) if res == nil || err != nil { t.Fatalf("Unexpected error: %v", err) } // The http response should not contain "permessage-deflate" output := rw.conn.wbuf.String() if strings.Contains(output, "permessage-deflate") { t.Fatalf("Compression disabled in server so response to client should not contain extension, got %s", output) } // Option in the server and client, so compression should be negotiated. 
s.opts.Websocket.Compression = true rw = &testResponseWriter{} res, err = s.wsUpgrade(rw, req) if res == nil || err != nil { t.Fatalf("Unexpected error: %v", err) } // The http response should not contain "permessage-deflate" output = rw.conn.wbuf.String() if !strings.Contains(output, "permessage-deflate") { t.Fatalf("Compression in server and client request, so response should contain extension, got %s", output) } // Option in server but not asked by the client, so response should not contain "permessage-deflate" rw = &testResponseWriter{} req.Header.Del("Sec-Websocket-Extensions") res, err = s.wsUpgrade(rw, req) if res == nil || err != nil { t.Fatalf("Unexpected error: %v", err) } // The http response should not contain "permessage-deflate" output = rw.conn.wbuf.String() if strings.Contains(output, "permessage-deflate") { t.Fatalf("Compression in server but not in client, so response to client should not contain extension, got %s", output) } } func TestWSParseOptions(t *testing.T) { for _, test := range []struct { name string content string checkOpt func(*WebsocketOpts) error err string }{ // Negative tests {"bad type", "websocket: []", nil, "to be a map"}, {"bad listen", "websocket: { listen: [] }", nil, "port or host:port"}, {"bad port", `websocket: { port: "abc" }`, nil, "not int64"}, {"bad host", `websocket: { host: 123 }`, nil, "not string"}, {"bad advertise type", `websocket: { advertise: 123 }`, nil, "not string"}, {"bad tls", `websocket: { tls: 123 }`, nil, "not map[string]interface {}"}, {"bad same origin", `websocket: { same_origin: "abc" }`, nil, "not bool"}, {"bad allowed origins type", `websocket: { allowed_origins: {} }`, nil, "unsupported type"}, {"bad allowed origins values", `websocket: { allowed_origins: [ {} ] }`, nil, "unsupported type in array"}, {"bad handshake timeout type", `websocket: { handshake_timeout: [] }`, nil, "unsupported type"}, {"bad handshake timeout duration", `websocket: { handshake_timeout: "abc" }`, nil, "invalid 
duration"}, {"unknown field", `websocket: { this_does_not_exist: 123 }`, nil, "unknown"}, // Positive tests {"listen port only", `websocket { listen: 1234 }`, func(wo *WebsocketOpts) error { if wo.Port != 1234 { return fmt.Errorf("expected 1234, got %v", wo.Port) } return nil }, ""}, {"listen host and port", `websocket { listen: "localhost:1234" }`, func(wo *WebsocketOpts) error { if wo.Host != "localhost" || wo.Port != 1234 { return fmt.Errorf("expected localhost:1234, got %v:%v", wo.Host, wo.Port) } return nil }, ""}, {"host", `websocket { host: "localhost" }`, func(wo *WebsocketOpts) error { if wo.Host != "localhost" { return fmt.Errorf("expected localhost, got %v", wo.Host) } return nil }, ""}, {"port", `websocket { port: 1234 }`, func(wo *WebsocketOpts) error { if wo.Port != 1234 { return fmt.Errorf("expected 1234, got %v", wo.Port) } return nil }, ""}, {"advertise", `websocket { advertise: "host:1234" }`, func(wo *WebsocketOpts) error { if wo.Advertise != "host:1234" { return fmt.Errorf("expected %q, got %q", "host:1234", wo.Advertise) } return nil }, ""}, {"same origin", `websocket { same_origin: true }`, func(wo *WebsocketOpts) error { if !wo.SameOrigin { return fmt.Errorf("expected same_origin==true, got %v", wo.SameOrigin) } return nil }, ""}, {"allowed origins one only", `websocket { allowed_origins: "https://host.com/" }`, func(wo *WebsocketOpts) error { expected := []string{"https://host.com/"} if !reflect.DeepEqual(wo.AllowedOrigins, expected) { return fmt.Errorf("expected allowed origins to be %q, got %q", expected, wo.AllowedOrigins) } return nil }, ""}, {"allowed origins array", ` websocket { allowed_origins: [ "https://host1.com/" "https://host2.com/" ] } `, func(wo *WebsocketOpts) error { expected := []string{"https://host1.com/", "https://host2.com/"} if !reflect.DeepEqual(wo.AllowedOrigins, expected) { return fmt.Errorf("expected allowed origins to be %q, got %q", expected, wo.AllowedOrigins) } return nil }, ""}, {"handshake timeout in whole 
seconds", `websocket { handshake_timeout: 3 }`, func(wo *WebsocketOpts) error { if wo.HandshakeTimeout != 3*time.Second { return fmt.Errorf("expected handshake to be 3s, got %v", wo.HandshakeTimeout) } return nil }, ""}, {"handshake timeout n duration", `websocket { handshake_timeout: "4s" }`, func(wo *WebsocketOpts) error { if wo.HandshakeTimeout != 4*time.Second { return fmt.Errorf("expected handshake to be 4s, got %v", wo.HandshakeTimeout) } return nil }, ""}, {"tls config", ` websocket { tls { cert_file: "./configs/certs/server.pem" key_file: "./configs/certs/key.pem" } } `, func(wo *WebsocketOpts) error { if wo.TLSConfig == nil { return fmt.Errorf("TLSConfig should have been set") } return nil }, ""}, {"compression", ` websocket { compression: true } `, func(wo *WebsocketOpts) error { if !wo.Compression { return fmt.Errorf("Compression should have been set") } return nil }, ""}, {"jwt cookie", ` websocket { jwt_cookie: "jwtcookie" } `, func(wo *WebsocketOpts) error { if wo.JWTCookie != "jwtcookie" { return fmt.Errorf("Invalid JWTCookie value: %q", wo.JWTCookie) } return nil }, ""}, {"no auth user", ` websocket { no_auth_user: "noauthuser" } `, func(wo *WebsocketOpts) error { if wo.NoAuthUser != "noauthuser" { return fmt.Errorf("Invalid NoAuthUser value: %q", wo.NoAuthUser) } return nil }, ""}, {"auth block", ` websocket { authorization { user: "webuser" password: "pwd" token: "token" timeout: 2.0 } } `, func(wo *WebsocketOpts) error { if wo.Username != "webuser" || wo.Password != "pwd" || wo.Token != "token" || wo.AuthTimeout != 2.0 { return fmt.Errorf("Invalid auth block: %+v", wo) } return nil }, ""}, {"auth timeout as int", ` websocket { authorization { timeout: 2 } } `, func(wo *WebsocketOpts) error { if wo.AuthTimeout != 2.0 { return fmt.Errorf("Invalid auth timeout: %v", wo.AuthTimeout) } return nil }, ""}, } { t.Run(test.name, func(t *testing.T) { conf := createConfFile(t, []byte(test.content)) o, err := ProcessConfigFile(conf) if test.err != _EMPTY_ { 
if err == nil || !strings.Contains(err.Error(), test.err) { t.Fatalf("For content: %q, expected error about %q, got %v", test.content, test.err, err) } return } else if err != nil { t.Fatalf("Unexpected error for content %q: %v", test.content, err) } if err := test.checkOpt(&o.Websocket); err != nil { t.Fatalf("Incorrect option for content %q: %v", test.content, err.Error()) } }) } } func TestWSValidateOptions(t *testing.T) { nwso := DefaultOptions() wso := testWSOptions() for _, test := range []struct { name string getOpts func() *Options err string }{ {"websocket disabled", func() *Options { return nwso.Clone() }, ""}, {"no tls", func() *Options { o := wso.Clone(); o.Websocket.TLSConfig = nil; return o }, "requires TLS configuration"}, {"bad url in allowed list", func() *Options { o := wso.Clone() o.Websocket.AllowedOrigins = []string{"http://this:is:bad:url"} return o }, "unable to parse"}, {"missing trusted configuration", func() *Options { o := wso.Clone() o.Websocket.JWTCookie = "jwt" return o }, "keys configuration is required"}, {"websocket username not allowed if users specified", func() *Options { o := wso.Clone() o.Nkeys = []*NkeyUser{{Nkey: "abc"}} o.Websocket.Username = "b" o.Websocket.Password = "pwd" return o }, "websocket authentication username not compatible with presence of users/nkeys"}, {"websocket token not allowed if users specified", func() *Options { o := wso.Clone() o.Nkeys = []*NkeyUser{{Nkey: "abc"}} o.Websocket.Token = "mytoken" return o }, "websocket authentication token not compatible with presence of users/nkeys"}, } { t.Run(test.name, func(t *testing.T) { err := validateWebsocketOptions(test.getOpts()) if test.err == "" && err != nil { t.Fatalf("Unexpected error: %v", err) } else if test.err != "" && (err == nil || !strings.Contains(err.Error(), test.err)) { t.Fatalf("Expected error to contain %q, got %v", test.err, err) } }) } } func TestWSSetOriginOptions(t *testing.T) { o := testWSOptions() for _, test := range []struct { content 
string err string }{ {"@@@://host.com/", "invalid URI"}, {"http://this:is:bad:url/", "invalid port"}, } { t.Run(test.err, func(t *testing.T) { o.Websocket.AllowedOrigins = []string{test.content} s := &Server{} l := &captureErrorLogger{errCh: make(chan string, 1)} s.SetLogger(l, false, false) s.wsSetOriginOptions(&o.Websocket) select { case e := <-l.errCh: if !strings.Contains(e, test.err) { t.Fatalf("Unexpected error: %v", e) } case <-time.After(50 * time.Millisecond): t.Fatalf("Did not get the error") } }) } } type captureFatalLogger struct { DummyLogger fatalCh chan string } func (l *captureFatalLogger) Fatalf(format string, v ...interface{}) { select { case l.fatalCh <- fmt.Sprintf(format, v...): default: } } func TestWSFailureToStartServer(t *testing.T) { // Create a listener to use a port l, err := net.Listen("tcp", "127.0.0.1:0") if err != nil { t.Fatalf("Error listening: %v", err) } defer l.Close() o := testWSOptions() // Make sure we don't have unnecessary listen ports opened. o.HTTPPort = 0 o.Cluster.Port = 0 o.Gateway.Name = "" o.Gateway.Port = 0 o.LeafNode.Port = 0 o.Websocket.Port = l.Addr().(*net.TCPAddr).Port s, err := NewServer(o) if err != nil { t.Fatalf("Error creating server: %v", err) } defer s.Shutdown() logger := &captureFatalLogger{fatalCh: make(chan string, 1)} s.SetLogger(logger, false, false) wg := sync.WaitGroup{} wg.Add(1) go func() { s.Start() wg.Done() }() select { case e := <-logger.fatalCh: if !strings.Contains(e, "Unable to listen") { t.Fatalf("Unexpected error: %v", e) } case <-time.After(2 * time.Second): t.Fatalf("Should have reported a fatal error") } // Since this is a test and the process does not actually // exit on Fatal error, wait for the client port to be // ready so when we shutdown we don't leave the accept // loop hanging. 
checkFor(t, time.Second, 15*time.Millisecond, func() error { s.mu.Lock() ready := s.listener != nil s.mu.Unlock() if !ready { return fmt.Errorf("client accept loop not started yet") } return nil }) s.Shutdown() wg.Wait() } func TestWSAbnormalFailureOfWebServer(t *testing.T) { o := testWSOptions() s := RunServer(o) defer s.Shutdown() logger := &captureFatalLogger{fatalCh: make(chan string, 1)} s.SetLogger(logger, false, false) // Now close the WS listener to cause a WebServer error s.mu.Lock() s.websocket.listener.Close() s.mu.Unlock() select { case e := <-logger.fatalCh: if !strings.Contains(e, "websocket listener error") { t.Fatalf("Unexpected error: %v", e) } case <-time.After(2 * time.Second): t.Fatalf("Should have reported a fatal error") } } type testWSClientOptions struct { compress, web bool host string port int extraHeaders map[string][]string noTLS bool path string } func testNewWSClient(t testing.TB, o testWSClientOptions) (net.Conn, *bufio.Reader, []byte) { t.Helper() c, br, info, err := testNewWSClientWithError(t, o) if err != nil { t.Fatal(err) } return c, br, info } func testNewWSClientWithError(t testing.TB, o testWSClientOptions) (net.Conn, *bufio.Reader, []byte, error) { addr := fmt.Sprintf("%s:%d", o.host, o.port) wsc, err := net.Dial("tcp", addr) if err != nil { return nil, nil, nil, fmt.Errorf("Error creating ws connection: %v", err) } if !o.noTLS { wsc = tls.Client(wsc, &tls.Config{InsecureSkipVerify: true}) wsc.SetDeadline(time.Now().Add(time.Second)) if err := wsc.(*tls.Conn).Handshake(); err != nil { return nil, nil, nil, fmt.Errorf("Error during handshake: %v", err) } wsc.SetDeadline(time.Time{}) } req := testWSCreateValidReq() if o.compress { req.Header.Set("Sec-Websocket-Extensions", "permessage-deflate") } if o.web { req.Header.Set("User-Agent", "Mozilla/5.0") } if len(o.extraHeaders) > 0 { for hdr, values := range o.extraHeaders { if len(values) == 0 { req.Header.Set(hdr, _EMPTY_) continue } req.Header.Set(hdr, values[0]) for i := 1; i 
< len(values); i++ { req.Header.Add(hdr, values[i]) } } } req.URL, _ = url.Parse("wss://" + addr + o.path) if err := req.Write(wsc); err != nil { return nil, nil, nil, fmt.Errorf("Error sending request: %v", err) } br := bufio.NewReader(wsc) resp, err := http.ReadResponse(br, req) if err != nil { return nil, nil, nil, fmt.Errorf("Error reading response: %v", err) } defer resp.Body.Close() if resp.StatusCode != http.StatusSwitchingProtocols { return nil, nil, nil, fmt.Errorf("Expected response status %v, got %v", http.StatusSwitchingProtocols, resp.StatusCode) } var info []byte if o.path == mqttWSPath { if v := resp.Header[wsSecProto]; len(v) != 1 || v[0] != wsMQTTSecProtoVal { return nil, nil, nil, fmt.Errorf("No mqtt protocol in header: %v", resp.Header) } } else { // Wait for the INFO info = testWSReadFrame(t, br) if !bytes.HasPrefix(info, []byte("INFO {")) { return nil, nil, nil, fmt.Errorf("Expected INFO, got %s", info) } } return wsc, br, info, nil } type testClaimsOptions struct { nac *jwt.AccountClaims nuc *jwt.UserClaims connectRequest interface{} dontSign bool expectAnswer string } func testWSWithClaims(t *testing.T, s *Server, o testWSClientOptions, tclm testClaimsOptions) (kp nkeys.KeyPair, conn net.Conn, rdr *bufio.Reader, auth_was_required bool) { t.Helper() okp, _ := nkeys.FromSeed(oSeed) akp, _ := nkeys.CreateAccount() apub, _ := akp.PublicKey() if tclm.nac == nil { tclm.nac = jwt.NewAccountClaims(apub) } else { tclm.nac.Subject = apub } ajwt, err := tclm.nac.Encode(okp) if err != nil { t.Fatalf("Error generating account JWT: %v", err) } nkp, _ := nkeys.CreateUser() pub, _ := nkp.PublicKey() if tclm.nuc == nil { tclm.nuc = jwt.NewUserClaims(pub) } else { tclm.nuc.Subject = pub } jwt, err := tclm.nuc.Encode(akp) if err != nil { t.Fatalf("Error generating user JWT: %v", err) } addAccountToMemResolver(s, apub, ajwt) c, cr, l := testNewWSClient(t, o) var info struct { Nonce string `json:"nonce,omitempty"` AuthRequired bool 
`json:"auth_required,omitempty"` } if err := json.Unmarshal([]byte(l[5:]), &info); err != nil { t.Fatal(err) } if info.AuthRequired { cs := "" if tclm.connectRequest != nil { customReq, err := json.Marshal(tclm.connectRequest) if err != nil { t.Fatal(err) } // PING needed to flush the +OK/-ERR to us. cs = fmt.Sprintf("CONNECT %v\r\nPING\r\n", string(customReq)) } else if !tclm.dontSign { // Sign Nonce sigraw, _ := nkp.Sign([]byte(info.Nonce)) sig := base64.RawURLEncoding.EncodeToString(sigraw) cs = fmt.Sprintf("CONNECT {\"jwt\":%q,\"sig\":\"%s\",\"verbose\":true,\"pedantic\":true}\r\nPING\r\n", jwt, sig) } else { cs = fmt.Sprintf("CONNECT {\"jwt\":%q,\"verbose\":true,\"pedantic\":true}\r\nPING\r\n", jwt) } wsmsg := testWSCreateClientMsg(wsBinaryMessage, 1, true, false, []byte(cs)) c.Write(wsmsg) l = testWSReadFrame(t, cr) if !strings.HasPrefix(string(l), tclm.expectAnswer) { t.Fatalf("Expected %q, got %q", tclm.expectAnswer, l) } } return akp, c, cr, info.AuthRequired } func setupAddTrusted(o *Options) { kp, _ := nkeys.FromSeed(oSeed) pub, _ := kp.PublicKey() o.TrustedKeys = []string{pub} } func setupAddCookie(o *Options) { o.Websocket.JWTCookie = "jwt" } func testWSCreateClientGetInfo(t testing.TB, compress, web bool, host string, port int) (net.Conn, *bufio.Reader, []byte) { t.Helper() return testNewWSClient(t, testWSClientOptions{ compress: compress, web: web, host: host, port: port, }) } func testWSCreateClient(t testing.TB, compress, web bool, host string, port int) (net.Conn, *bufio.Reader) { wsc, br, _ := testWSCreateClientGetInfo(t, compress, web, host, port) // Send CONNECT and PING wsmsg := testWSCreateClientMsg(wsBinaryMessage, 1, true, compress, []byte("CONNECT {\"verbose\":false,\"protocol\":1}\r\nPING\r\n")) if _, err := wsc.Write(wsmsg); err != nil { t.Fatalf("Error sending message: %v", err) } // Wait for the PONG if msg := testWSReadFrame(t, br); !bytes.HasPrefix(msg, []byte("PONG\r\n")) { t.Fatalf("Expected PONG, got %s", msg) } return wsc, br } 
func testWSReadFrame(t testing.TB, br *bufio.Reader) []byte { t.Helper() fh := [2]byte{} if _, err := io.ReadAtLeast(br, fh[:2], 2); err != nil { t.Fatalf("Error reading frame: %v", err) } fc := fh[0]&wsRsv1Bit != 0 sb := fh[1] size := 0 switch { case sb <= 125: size = int(sb) case sb == 126: tmp := [2]byte{} if _, err := io.ReadAtLeast(br, tmp[:2], 2); err != nil { t.Fatalf("Error reading frame: %v", err) } size = int(binary.BigEndian.Uint16(tmp[:2])) case sb == 127: tmp := [8]byte{} if _, err := io.ReadAtLeast(br, tmp[:8], 8); err != nil { t.Fatalf("Error reading frame: %v", err) } size = int(binary.BigEndian.Uint64(tmp[:8])) } buf := make([]byte, size) if _, err := io.ReadAtLeast(br, buf, size); err != nil { t.Fatalf("Error reading frame: %v", err) } if !fc { return buf } buf = append(buf, 0x00, 0x00, 0xff, 0xff, 0x01, 0x00, 0x00, 0xff, 0xff) dbr := bytes.NewBuffer(buf) d := flate.NewReader(dbr) uncompressed, err := io.ReadAll(d) if err != nil { t.Fatalf("Error reading frame: %v", err) } return uncompressed } func TestWSPubSub(t *testing.T) { for _, test := range []struct { name string compression bool }{ {"no compression", false}, {"compression", true}, } { t.Run(test.name, func(t *testing.T) { o := testWSOptions() if test.compression { o.Websocket.Compression = true } s := RunServer(o) defer s.Shutdown() // Create a regular client to subscribe nc := natsConnect(t, s.ClientURL()) defer nc.Close() nsub := natsSubSync(t, nc, "foo") checkExpectedSubs(t, 1, s) // Now create a WS client and send a message on "foo" wsc, br := testWSCreateClient(t, test.compression, false, o.Websocket.Host, o.Websocket.Port) defer wsc.Close() // Send a WS message for "PUB foo 2\r\nok\r\n" wsmsg := testWSCreateClientMsg(wsBinaryMessage, 1, true, false, []byte("PUB foo 7\r\nfrom ws\r\n")) if _, err := wsc.Write(wsmsg); err != nil { t.Fatalf("Error sending message: %v", err) } // Now check that message is received msg := natsNexMsg(t, nsub, time.Second) if string(msg.Data) != "from ws" { 
t.Fatalf("Expected message to be %q, got %q", "ok", string(msg.Data)) } // Now do reverse, create a subscription on WS client on bar wsmsg = testWSCreateClientMsg(wsBinaryMessage, 1, true, false, []byte("SUB bar 1\r\n")) if _, err := wsc.Write(wsmsg); err != nil { t.Fatalf("Error sending subscription: %v", err) } // Wait for it to be registered on server checkExpectedSubs(t, 2, s) // Now publish from NATS connection and verify received on WS client natsPub(t, nc, "bar", []byte("from nats")) natsFlush(t, nc) // Check for the "from nats" message... // Set some deadline so we are not stuck forever on failure wsc.SetReadDeadline(time.Now().Add(10 * time.Second)) ok := 0 for { line, _, err := br.ReadLine() if err != nil { t.Fatalf("Error reading: %v", err) } // Note that this works even in compression test because those // texts are likely not to be compressed, but compression code is // still executed. if ok == 0 && bytes.Contains(line, []byte("MSG bar 1 9")) { ok = 1 continue } else if ok == 1 && bytes.Contains(line, []byte("from nats")) { ok = 2 break } } }) } } func TestWSTLSConnection(t *testing.T) { o := testWSOptions() s := RunServer(o) defer s.Shutdown() addr := fmt.Sprintf("%s:%d", o.Websocket.Host, o.Websocket.Port) for _, test := range []struct { name string useTLS bool status int }{ {"client uses TLS", true, http.StatusSwitchingProtocols}, {"client does not use TLS", false, http.StatusBadRequest}, } { t.Run(test.name, func(t *testing.T) { wsc, err := net.Dial("tcp", addr) if err != nil { t.Fatalf("Error creating ws connection: %v", err) } defer wsc.Close() if test.useTLS { wsc = tls.Client(wsc, &tls.Config{InsecureSkipVerify: true}) if err := wsc.(*tls.Conn).Handshake(); err != nil { t.Fatalf("Error during handshake: %v", err) } } req := testWSCreateValidReq() var scheme string if test.useTLS { scheme = "s" } req.URL, _ = url.Parse("ws" + scheme + "://" + addr) if err := req.Write(wsc); err != nil { t.Fatalf("Error sending request: %v", err) } br := 
bufio.NewReader(wsc) resp, err := http.ReadResponse(br, req) if err != nil { t.Fatalf("Error reading response: %v", err) } defer resp.Body.Close() if resp.StatusCode != test.status { t.Fatalf("Expected status %v, got %v", test.status, resp.StatusCode) } }) } } func TestWSTLSVerifyClientCert(t *testing.T) { o := testWSOptions() tc := &TLSConfigOpts{ CertFile: "../test/configs/certs/server-cert.pem", KeyFile: "../test/configs/certs/server-key.pem", CaFile: "../test/configs/certs/ca.pem", Verify: true, } tlsc, err := GenTLSConfig(tc) if err != nil { t.Fatalf("Error creating tls config: %v", err) } o.Websocket.TLSConfig = tlsc s := RunServer(o) defer s.Shutdown() addr := fmt.Sprintf("%s:%d", o.Websocket.Host, o.Websocket.Port) for _, test := range []struct { name string provideCert bool }{ {"client provides cert", true}, {"client does not provide cert", false}, } { t.Run(test.name, func(t *testing.T) { wsc, err := net.Dial("tcp", addr) if err != nil { t.Fatalf("Error creating ws connection: %v", err) } defer wsc.Close() tlsc := &tls.Config{} if test.provideCert { tc := &TLSConfigOpts{ CertFile: "../test/configs/certs/client-cert.pem", KeyFile: "../test/configs/certs/client-key.pem", } var err error tlsc, err = GenTLSConfig(tc) if err != nil { t.Fatalf("Error generating tls config: %v", err) } } tlsc.InsecureSkipVerify = true wsc = tls.Client(wsc, tlsc) if err := wsc.(*tls.Conn).Handshake(); err != nil { t.Fatalf("Error during handshake: %v", err) } req := testWSCreateValidReq() req.URL, _ = url.Parse("wss://" + addr) if err := req.Write(wsc); err != nil { t.Fatalf("Error sending request: %v", err) } br := bufio.NewReader(wsc) resp, err := http.ReadResponse(br, req) if resp != nil { resp.Body.Close() } if !test.provideCert { if err == nil { t.Fatal("Expected error, did not get one") } else if !strings.Contains(err.Error(), "bad certificate") { t.Fatalf("Unexpected error: %v", err) } return } if err != nil { t.Fatalf("Unexpected error: %v", err) } if resp.StatusCode != 
http.StatusSwitchingProtocols { t.Fatalf("Expected status %v, got %v", http.StatusSwitchingProtocols, resp.StatusCode) } }) } } func testCreateAllowedConnectionTypes(list []string) map[string]struct{} { if len(list) == 0 { return nil } m := make(map[string]struct{}, len(list)) for _, l := range list { m[l] = struct{}{} } return m } func TestWSTLSVerifyAndMap(t *testing.T) { accName := "MyAccount" acc := NewAccount(accName) certUserName := "CN=example.com,OU=NATS.io" users := []*User{{Username: certUserName, Account: acc}} for _, test := range []struct { name string filtering bool provideCert bool }{ {"no filtering, client provides cert", false, true}, {"no filtering, client does not provide cert", false, false}, {"filtering, client provides cert", true, true}, {"filtering, client does not provide cert", true, false}, {"no users override, client provides cert", false, true}, {"no users override, client does not provide cert", false, false}, {"users override, client provides cert", true, true}, {"users override, client does not provide cert", true, false}, } { t.Run(test.name, func(t *testing.T) { o := testWSOptions() o.Accounts = []*Account{acc} o.Users = users if test.filtering { o.Users[0].AllowedConnectionTypes = testCreateAllowedConnectionTypes([]string{jwt.ConnectionTypeStandard, jwt.ConnectionTypeWebsocket}) } tc := &TLSConfigOpts{ CertFile: "../test/configs/certs/tlsauth/server.pem", KeyFile: "../test/configs/certs/tlsauth/server-key.pem", CaFile: "../test/configs/certs/tlsauth/ca.pem", Verify: true, } tlsc, err := GenTLSConfig(tc) if err != nil { t.Fatalf("Error creating tls config: %v", err) } o.Websocket.TLSConfig = tlsc o.Websocket.TLSMap = true s := RunServer(o) defer s.Shutdown() addr := fmt.Sprintf("%s:%d", o.Websocket.Host, o.Websocket.Port) wsc, err := net.Dial("tcp", addr) if err != nil { t.Fatalf("Error creating ws connection: %v", err) } defer wsc.Close() tlscc := &tls.Config{} if test.provideCert { tc := &TLSConfigOpts{ CertFile: 
"../test/configs/certs/tlsauth/client.pem", KeyFile: "../test/configs/certs/tlsauth/client-key.pem", } var err error tlscc, err = GenTLSConfig(tc) if err != nil { t.Fatalf("Error generating tls config: %v", err) } } tlscc.InsecureSkipVerify = true wsc = tls.Client(wsc, tlscc) if err := wsc.(*tls.Conn).Handshake(); err != nil { t.Fatalf("Error during handshake: %v", err) } req := testWSCreateValidReq() req.URL, _ = url.Parse("wss://" + addr) if err := req.Write(wsc); err != nil { t.Fatalf("Error sending request: %v", err) } br := bufio.NewReader(wsc) resp, err := http.ReadResponse(br, req) if resp != nil { resp.Body.Close() } if !test.provideCert { if err == nil { t.Fatal("Expected error, did not get one") } else if !strings.Contains(err.Error(), "bad certificate") { t.Fatalf("Unexpected error: %v", err) } return } if err != nil { t.Fatalf("Unexpected error: %v", err) } if resp.StatusCode != http.StatusSwitchingProtocols { t.Fatalf("Expected status %v, got %v", http.StatusSwitchingProtocols, resp.StatusCode) } // Wait for the INFO l := testWSReadFrame(t, br) if !bytes.HasPrefix(l, []byte("INFO {")) { t.Fatalf("Expected INFO, got %s", l) } var info serverInfo if err := json.Unmarshal(l[5:], &info); err != nil { t.Fatalf("Unable to unmarshal info: %v", err) } // Send CONNECT and PING wsmsg := testWSCreateClientMsg(wsBinaryMessage, 1, true, false, []byte("CONNECT {\"verbose\":false,\"protocol\":1}\r\nPING\r\n")) if _, err := wsc.Write(wsmsg); err != nil { t.Fatalf("Error sending message: %v", err) } // Wait for the PONG if msg := testWSReadFrame(t, br); !bytes.HasPrefix(msg, []byte("PONG\r\n")) { t.Fatalf("Expected PONG, got %s", msg) } var uname string var accname string c := s.getClient(info.CID) if c != nil { c.mu.Lock() uname = c.opts.Username if c.acc != nil { accname = c.acc.GetName() } c.mu.Unlock() } if uname != certUserName { t.Fatalf("Expected username %q, got %q", certUserName, uname) } if accname != accName { t.Fatalf("Expected account %q, got %v", accName, 
accname)
			}
		})
	}
}

// TestWSHandshakeTimeout checks that a client that does not complete the TLS
// handshake within Websocket.HandshakeTimeout is cut off by the server, and
// that the server logs a "websocket: " prefixed error mentioning the timeout.
func TestWSHandshakeTimeout(t *testing.T) {
	o := testWSOptions()
	o.Websocket.HandshakeTimeout = time.Millisecond
	tc := &TLSConfigOpts{
		CertFile: "./configs/certs/server.pem",
		KeyFile:  "./configs/certs/key.pem",
	}
	tlsc, err := GenTLSConfig(tc)
	if err != nil {
		t.Fatalf("Error generating tls config: %v", err)
	}
	o.Websocket.TLSConfig = tlsc
	s := RunServer(o)
	defer s.Shutdown()

	logger := &captureErrorLogger{errCh: make(chan string, 1)}
	s.SetLogger(logger, false, false)

	addr := fmt.Sprintf("%s:%d", o.Websocket.Host, o.Websocket.Port)
	wsc, err := net.Dial("tcp", addr)
	if err != nil {
		t.Fatalf("Error creating ws connection: %v", err)
	}
	defer wsc.Close()

	// Delay the handshake
	wsc = tls.Client(wsc, &tls.Config{InsecureSkipVerify: true})
	time.Sleep(20 * time.Millisecond)
	// We expect error since the server should have cut us off
	if err := wsc.(*tls.Conn).Handshake(); err == nil {
		t.Fatal("Expected error during handshake")
	}

	// Check that server logs error
	select {
	case e := <-logger.errCh:
		// Check that log starts with "websocket: "
		if !strings.HasPrefix(e, "websocket: ") {
			t.Fatalf("Wrong log line start: %s", e)
		}
		if !strings.Contains(e, "timeout") {
			t.Fatalf("Unexpected error: %v", e)
		}
	case <-time.After(time.Second):
		t.Fatalf("Should have timed-out")
	}
}

// TestWSServerReportUpgradeFailure sends an upgrade request without the
// required "Connection" header and checks that the server answers with
// 400 Bad Request and logs the failure prefixed by the client's address.
func TestWSServerReportUpgradeFailure(t *testing.T) {
	o := testWSOptions()
	s := RunServer(o)
	defer s.Shutdown()

	logger := &captureErrorLogger{errCh: make(chan string, 1)}
	s.SetLogger(logger, false, false)

	addr := fmt.Sprintf("127.0.0.1:%d", o.Websocket.Port)
	req := testWSCreateValidReq()
	req.URL, _ = url.Parse("wss://" + addr)

	wsc, err := net.Dial("tcp", addr)
	if err != nil {
		t.Fatalf("Error creating ws connection: %v", err)
	}
	defer wsc.Close()
	wsc = tls.Client(wsc, &tls.Config{InsecureSkipVerify: true})
	if err := wsc.(*tls.Conn).Handshake(); err != nil {
		t.Fatalf("Error during handshake: %v", err)
	}
	// Remove a required field from the request to have it fail
	req.Header.Del("Connection")
	// Send the request
	if err := req.Write(wsc); err != nil {
		t.Fatalf("Error sending request: %v", err)
	}
	br := bufio.NewReader(wsc)
	resp, err := http.ReadResponse(br, req)
	if err != nil {
		t.Fatalf("Error reading response: %v", err)
	}
	defer resp.Body.Close()
	if resp.StatusCode != http.StatusBadRequest {
		t.Fatalf("Expected status %v, got %v", http.StatusBadRequest, resp.StatusCode)
	}

	// Check that server logs error
	select {
	case e := <-logger.errCh:
		if !strings.Contains(e, "invalid value for header 'Connection'") {
			t.Fatalf("Unexpected error: %v", e)
		}
		// The client IP's local should be printed as a remote from server perspective.
		clientIP := wsc.LocalAddr().String()
		if !strings.HasPrefix(e, clientIP) {
			t.Fatalf("IP should have been logged, it was not: %v", e)
		}
	case <-time.After(time.Second):
		t.Fatalf("Should have timed-out")
	}
}

// TestWSCloseMsgSendOnConnectionClose checks that when the server closes a
// websocket client for a protocol violation, it first sends a close frame
// carrying wsCloseStatusProtocolError followed by the close reason text.
func TestWSCloseMsgSendOnConnectionClose(t *testing.T) {
	o := testWSOptions()
	s := RunServer(o)
	defer s.Shutdown()

	wsc, br := testWSCreateClient(t, false, false, o.Websocket.Host, o.Websocket.Port)
	defer wsc.Close()

	checkClientsCount(t, s, 1)
	var c *client
	s.mu.Lock()
	for _, cli := range s.clients {
		c = cli
		break
	}
	s.mu.Unlock()

	c.closeConnection(ProtocolViolation)
	msg := testWSReadFrame(t, br)
	if len(msg) < 2 {
		t.Fatalf("Should have 2 bytes to represent the status, got %v", msg)
	}
	// The first 2 bytes of the close payload are the big-endian status code.
	if sc := int(binary.BigEndian.Uint16(msg[:2])); sc != wsCloseStatusProtocolError {
		t.Fatalf("Expected status to be %v, got %v", wsCloseStatusProtocolError, sc)
	}
	expectedPayload := ProtocolViolation.String()
	if p := string(msg[2:]); p != expectedPayload {
		t.Fatalf("Expected payload to be %q, got %q", expectedPayload, p)
	}
}

// TestWSAdvertise checks that an invalid Websocket.Advertise value prevents
// the server from starting, and that valid advertise addresses of clustered
// servers are propagated to websocket clients through INFO updates as
// servers join and leave.
func TestWSAdvertise(t *testing.T) {
	o := testWSOptions()
	o.Cluster.Port = 0
	o.HTTPPort = 0
	o.Websocket.Advertise = "xxx:host:yyy"
	s, err := NewServer(o)
	if err != nil {
		t.Fatalf("Unexpected error: %v", err)
	}
	defer s.Shutdown()
	l := &captureFatalLogger{fatalCh: make(chan string, 1)}
	s.SetLogger(l, false, false)
	go s.Start()
	select {
	case e := <-l.fatalCh:
		if !strings.Contains(e, "Unable to get websocket connect URLs") {
			t.Fatalf("Unexpected error: %q", e)
		}
	case <-time.After(time.Second):
		t.Fatal("Should have failed to start")
	}
	s.Shutdown()

	o1 := testWSOptions()
	o1.Websocket.Advertise = "host1:1234"
	s1 := RunServer(o1)
	defer s1.Shutdown()

	wsc, br := testWSCreateClient(t, false, false, o1.Websocket.Host, o1.Websocket.Port)
	defer wsc.Close()

	o2 := testWSOptions()
	o2.Websocket.Advertise = "host2:5678"
	o2.Routes = RoutesFromStr(fmt.Sprintf("nats://%s:%d", o1.Cluster.Host, o1.Cluster.Port))
	s2 := RunServer(o2)
	defer s2.Shutdown()

	// checkInfo reads the next INFO frame and verifies that its connect URLs
	// are exactly the expected ones (in any order).
	checkInfo := func(expected []string) {
		t.Helper()
		infob := testWSReadFrame(t, br)
		info := &Info{}
		// Skip the "INFO " prefix.
		if err := json.Unmarshal(infob[5:], info); err != nil {
			t.Fatalf("Unable to unmarshal info: %v", err)
		}
		if n := len(info.ClientConnectURLs); n != len(expected) {
			t.Fatalf("Unexpected info: %+v", info)
		}
		good := 0
		for _, u := range info.ClientConnectURLs {
			for _, eu := range expected {
				if u == eu {
					good++
				}
			}
		}
		if good != len(expected) {
			t.Fatalf("Unexpected connect urls: %q", info.ClientConnectURLs)
		}
	}
	checkInfo([]string{"host1:1234", "host2:5678"})

	// Now shutdown s2 and expect another INFO
	s2.Shutdown()
	checkInfo([]string{"host1:1234"})

	// Restart with another advertise and check that it gets updated
	o2.Websocket.Advertise = "host3:9012"
	s2 = RunServer(o2)
	defer s2.Shutdown()
	checkInfo([]string{"host1:1234", "host3:9012"})
}

// TestWSFrameOutbound checks how collapsePtoNB turns pending outbound buffers
// into websocket frames (header buffer followed by payload buffers), with and
// without write masking, including the splitting of large payloads into
// frames of at most wsFrameSizeForBrowsers bytes for browser clients.
func TestWSFrameOutbound(t *testing.T) {
	for _, test := range []struct {
		name         string
		maskingWrite bool
	}{
		{"no write masking", false},
		{"write masking", true},
	} {
		t.Run(test.name, func(t *testing.T) {
			c, _, _ := testWSSetupForRead()
			c.ws.maskwrite = test.maskingWrite

			// When masking, the 4-byte key is at the end of the frame header.
			getKey := func(buf []byte) []byte {
				return buf[len(buf)-4:]
			}
			// resetOut clears the outbound/framing state between scenarios
			// and marks the client as a browser client.
			resetOut := func() {
				c.out.pb = 0
				c.ws.fs = 0
				c.ws.frames = nil
				c.ws.browser = true
			}

			var bufs net.Buffers
			bufs = append(bufs, []byte("this "))
			bufs = append(bufs, []byte("is "))
			bufs = append(bufs, []byte("a "))
			bufs = append(bufs, []byte("set "))
			bufs = append(bufs, []byte("of "))
			bufs = append(bufs, []byte("buffers"))
			en := 2
			for _, b := range bufs {
				en += len(b)
			}
			if test.maskingWrite {
				en += 4
			}
			c.mu.Lock()
			c.out.nb = bufs
			res, n := c.collapsePtoNB()
			c.mu.Unlock()
			if n != int64(en) {
				t.Fatalf("Expected size to be %v, got %v", en, n)
			}
			if eb := 1 + len(bufs); eb != len(res) {
				t.Fatalf("Expected %v buffers, got %v", eb, len(res))
			}
			var ob []byte
			for i := 1; i < len(res); i++ {
				ob = append(ob, res[i]...)
			}
			if test.maskingWrite {
				wsMaskBuf(getKey(res[0]), ob)
			}
			if !bytes.Equal(ob, []byte("this is a set of buffers")) {
				t.Fatalf("Unexpected outbound: %q", ob)
			}

			// A buffer bigger than the browser frame size gets split.
			bufs = nil
			resetOut()
			bufs = append(bufs, []byte("some smaller "))
			bufs = append(bufs, []byte("buffers"))
			bufs = append(bufs, make([]byte, wsFrameSizeForBrowsers+10))
			bufs = append(bufs, []byte("then some more"))
			en = 2 + len(bufs[0]) + len(bufs[1])
			en += 4 + len(bufs[2]) - 10
			en += 2 + len(bufs[3]) + 10
			c.mu.Lock()
			c.out.nb = bufs
			res, n = c.collapsePtoNB()
			c.mu.Unlock()
			if test.maskingWrite {
				en += 3 * 4
			}
			if n != int64(en) {
				t.Fatalf("Expected size to be %v, got %v", en, n)
			}
			if len(res) != 8 {
				t.Fatalf("Unexpected number of outbound buffers: %v", len(res))
			}
			if len(res[4]) != wsFrameSizeForBrowsers {
				t.Fatalf("Big frame should have been limited to %v, got %v", wsFrameSizeForBrowsers, len(res[4]))
			}
			if len(res[6]) != 10 {
				t.Fatalf("Frame 6 should have the partial of 10 bytes, got %v", len(res[6]))
			}
			if test.maskingWrite {
				b := &bytes.Buffer{}
				key := getKey(res[0])
				b.Write(res[1])
				b.Write(res[2])
				ud := b.Bytes()
				wsMaskBuf(key, ud)
				if string(ud) != "some smaller buffers" {
					t.Fatalf("Unexpected result: %q", ud)
				}

				b.Reset()
				key = getKey(res[3])
				b.Write(res[4])
				ud = b.Bytes()
				wsMaskBuf(key, ud)
				for i := 0; i < len(ud); i++ {
					if ud[i] != 0 {
						t.Fatalf("Unexpected result: %v", ud)
					}
				}

				b.Reset()
				key = getKey(res[5])
				b.Write(res[6])
				b.Write(res[7])
				ud = b.Bytes()
				wsMaskBuf(key, ud)
				for i := 0; i < len(ud[:10]); i++ {
					if ud[i] != 0 {
						t.Fatalf("Unexpected result: %v", ud[:10])
					}
				}
				if string(ud[10:]) != "then some more" {
					t.Fatalf("Unexpected result: %q", ud[10:])
				}
			}

			bufs = nil
			resetOut()
			bufs = append(bufs, []byte("some smaller "))
			bufs = append(bufs, []byte("buffers"))
			// Have one of the exact max size
			bufs = append(bufs, make([]byte, wsFrameSizeForBrowsers))
			bufs = append(bufs, []byte("then some more"))
			en = 2 + len(bufs[0]) + len(bufs[1])
			en += 4 + len(bufs[2])
			en += 2 + len(bufs[3])
			c.mu.Lock()
			c.out.nb = bufs
			res, n = c.collapsePtoNB()
			c.mu.Unlock()
			if test.maskingWrite {
				en += 3 * 4
			}
			if n != int64(en) {
				t.Fatalf("Expected size to be %v, got %v", en, n)
			}
			if len(res) != 7 {
				t.Fatalf("Unexpected number of outbound buffers: %v", len(res))
			}
			if len(res[4]) != wsFrameSizeForBrowsers {
				t.Fatalf("Big frame should have been limited to %v, got %v", wsFrameSizeForBrowsers, len(res[4]))
			}
			if test.maskingWrite {
				key := getKey(res[5])
				wsMaskBuf(key, res[6])
			}
			if string(res[6]) != "then some more" {
				t.Fatalf("Frame 6 incorrect: %q", res[6])
			}

			bufs = nil
			resetOut()
			bufs = append(bufs, []byte("some smaller "))
			bufs = append(bufs, []byte("buffers"))
			// Have one of the exact max size, and last in the list
			bufs = append(bufs, make([]byte, wsFrameSizeForBrowsers))
			en = 2 + len(bufs[0]) + len(bufs[1])
			en += 4 + len(bufs[2])
			c.mu.Lock()
			c.out.nb = bufs
			res, n = c.collapsePtoNB()
			c.mu.Unlock()
			if test.maskingWrite {
				en += 2 * 4
			}
			if n != int64(en) {
				t.Fatalf("Expected size to be %v, got %v", en, n)
			}
			if len(res) != 5 {
				t.Fatalf("Unexpected number of outbound buffers: %v", len(res))
			}
			if len(res[4]) != wsFrameSizeForBrowsers {
				t.Fatalf("Big frame should have been limited to %v, got %v", wsFrameSizeForBrowsers, len(res[4]))
			}

			bufs = nil
			resetOut()
			bufs = append(bufs, []byte("some smaller buffer"))
			bufs = append(bufs, make([]byte, wsFrameSizeForBrowsers-5))
			bufs = append(bufs, []byte("then some more"))
			en = 2 + len(bufs[0])
			en += 4 + len(bufs[1])
			en += 2 + len(bufs[2])
			c.mu.Lock()
			c.out.nb = bufs
			res, n = c.collapsePtoNB()
			c.mu.Unlock()
			if test.maskingWrite {
				en += 3 * 4
			}
			if n != int64(en) {
				t.Fatalf("Expected size to be %v, got %v", en, n)
			}
			if len(res) != 6 {
				t.Fatalf("Unexpected number of outbound buffers: %v", len(res))
			}
			if len(res[3]) != wsFrameSizeForBrowsers-5 {
				// Report the values actually compared (expected size and res[3]).
				t.Fatalf("Big frame should have been limited to %v, got %v", wsFrameSizeForBrowsers-5, len(res[3]))
			}
			if test.maskingWrite {
				key := getKey(res[4])
				wsMaskBuf(key, res[5])
			}
			if string(res[5]) != "then some more" {
				t.Fatalf("Frame 6 incorrect %q", res[5])
			}

			bufs = nil
			resetOut()
			bufs = append(bufs, make([]byte, wsFrameSizeForBrowsers+100))
			c.mu.Lock()
			c.out.nb = bufs
			res, _ = c.collapsePtoNB()
			c.mu.Unlock()
			if len(res) != 4 {
				t.Fatalf("Unexpected number of frames: %v", len(res))
			}
		})
	}
}

// TestWSWebrowserClient checks that a client detected as a web browser is
// flagged as such, and that a large message delivered to it is split into
// several frames.
func TestWSWebrowserClient(t *testing.T) {
	o := testWSOptions()
	s := RunServer(o)
	defer s.Shutdown()

	wsc, br := testWSCreateClient(t, false, true, o.Websocket.Host, o.Websocket.Port)
	defer wsc.Close()

	checkClientsCount(t, s, 1)
	var c *client
	s.mu.Lock()
	for _, cli := range s.clients {
		c = cli
		break
	}
	s.mu.Unlock()

	proto := testWSCreateClientMsg(wsBinaryMessage, 1, true, false, []byte("SUB foo 1\r\nPING\r\n"))
	if _, err := wsc.Write(proto); err != nil {
		t.Fatalf("Error sending message: %v", err)
	}
	if res := testWSReadFrame(t, br); !bytes.Equal(res, []byte(pongProto)) {
		t.Fatalf("Expected PONG back")
	}

	c.mu.Lock()
	ok := c.isWebsocket() && c.ws.browser
	c.mu.Unlock()
	if !ok {
		t.Fatalf("Client is not marked as web browser client")
	}

	nc := natsConnect(t, s.ClientURL())
	defer nc.Close()

	// Send a big message and check that it is received in smaller frames
	psize := 204813
	if err := nc.Publish("foo", make([]byte, psize)); err != nil {
		t.Fatalf("Error on publish: %v", err)
	}
	if err := nc.Flush(); err != nil {
		t.Fatalf("Error on flush: %v", err)
	}

	// The delivered protocol is "MSG foo 1 <size>\r\n<payload>\r\n" (sid 1
	// from the SUB above), so include the sid when computing its size.
	rsize := psize + len(fmt.Sprintf("MSG foo 1 %d\r\n\r\n", psize))
	nframes := 0
	for total := 0; total < rsize; nframes++ {
		res := testWSReadFrame(t, br)
		total += len(res)
	}
	if expected := psize / wsFrameSizeForBrowsers; expected > nframes {
		t.Fatalf("Expected %v frames, got %v", expected, nframes)
	}
}

// testWSWrappedConn wraps a client's net.Conn so that tests can record the
// bytes the server writes to it and, when partial is set, simulate short
// writes. buf and partial are protected by mu.
type testWSWrappedConn struct {
	net.Conn
	mu      sync.RWMutex
	buf     *bytes.Buffer
	partial bool
}

// Write forwards p to the wrapped connection and records in wc.buf the bytes
// that were actually written. When wc.partial is set and p is longer than 10
// bytes, only the first 10 bytes are written and io.ErrShortWrite is returned.
// An error from the wrapped connection is returned to the caller.
func (wc *testWSWrappedConn) Write(p []byte) (int, error) {
	wc.mu.Lock()
	defer wc.mu.Unlock()
	var err error
	n := len(p)
	if wc.partial && n > 10 {
		n = 10
		err = io.ErrShortWrite
	}
	written, werr := wc.Conn.Write(p[:n])
	// Only record what reached the wrapped connection.
	wc.buf.Write(p[:written])
	if werr != nil {
		return written, werr
	}
	return n, err
}

// TestWSCompressionBasic checks that, with compression negotiated, a message
// is sent as a single compressed frame (RSV1 bit set, deflate trailer
// stripped), while a small message is sent uncompressed.
func TestWSCompressionBasic(t *testing.T) {
	payload := "This is the content of a message that will be compresseddddddddddddddddddddd."
	msgProto := fmt.Sprintf("MSG foo 1 %d\r\n%s\r\n", len(payload), payload)
	cbuf := &bytes.Buffer{}
	compressor, err := flate.NewWriter(cbuf, flate.BestSpeed)
	if err != nil {
		t.Fatalf("Error creating compressor: %v", err)
	}
	compressor.Write([]byte(msgProto))
	compressor.Flush()
	compressed := cbuf.Bytes()
	// The last 4 bytes are dropped
	compressed = compressed[:len(compressed)-4]

	o := testWSOptions()
	o.Websocket.Compression = true
	s := RunServer(o)
	defer s.Shutdown()

	c, br := testWSCreateClient(t, true, false, o.Websocket.Host, o.Websocket.Port)
	defer c.Close()

	proto := testWSCreateClientMsg(wsBinaryMessage, 1, true, true, []byte("SUB foo 1\r\nPING\r\n"))
	c.Write(proto)
	l := testWSReadFrame(t, br)
	if !bytes.Equal(l, []byte(pongProto)) {
		t.Fatalf("Expected PONG, got %q", l)
	}

	var wc *testWSWrappedConn
	s.mu.Lock()
	for _, c := range s.clients {
		c.mu.Lock()
		wc = &testWSWrappedConn{Conn: c.nc, buf: &bytes.Buffer{}}
		c.nc = wc
		c.mu.Unlock()
	}
	s.mu.Unlock()

	// sentBytes returns a copy of what the server wrote through the wrapped
	// connection. The copy is taken under the lock so that concurrent writes
	// cannot modify it while it is being inspected.
	sentBytes := func() []byte {
		wc.mu.RLock()
		defer wc.mu.RUnlock()
		return append([]byte(nil), wc.buf.Bytes()...)
	}

	nc := natsConnect(t, s.ClientURL())
	defer nc.Close()
	natsPub(t, nc, "foo", []byte(payload))

	res := &bytes.Buffer{}
	for total := 0; total < len(msgProto); {
		l := testWSReadFrame(t, br)
		n, _ := res.Write(l)
		total += n
	}
	if !bytes.Equal([]byte(msgProto), res.Bytes()) {
		t.Fatalf("Unexpected result: %q", res)
	}

	// Now check the wrapped connection buffer to check that data was actually compressed.
	sent := sentBytes()
	if bytes.Contains(sent, []byte(payload)) {
		t.Fatalf("Looks like frame was not compressed: %q", sent)
	}
	if len(sent) < 2 {
		t.Fatalf("Expected at least a frame header, got %q", sent)
	}
	header := sent[:2]
	body := sent[2:]
	expectedB0 := byte(wsBinaryMessage) | wsFinalBit | wsRsv1Bit
	expectedPS := len(compressed)
	expectedB1 := byte(expectedPS)

	if b := header[0]; b != expectedB0 {
		t.Fatalf("Expected first byte to be %v, got %v", expectedB0, b)
	}
	if b := header[1]; b != expectedB1 {
		t.Fatalf("Expected second byte to be %v, got %v", expectedB1, b)
	}
	if len(body) != expectedPS {
		t.Fatalf("Expected payload length to be %v, got %v", expectedPS, len(body))
	}
	if !bytes.Equal(body, compressed) {
		t.Fatalf("Unexpected compress body: %q", body)
	}

	wc.mu.Lock()
	wc.buf.Reset()
	wc.mu.Unlock()

	payload = "small"
	natsPub(t, nc, "foo", []byte(payload))
	msgProto = fmt.Sprintf("MSG foo 1 %d\r\n%s\r\n", len(payload), payload)
	res = &bytes.Buffer{}
	for total := 0; total < len(msgProto); {
		l := testWSReadFrame(t, br)
		n, _ := res.Write(l)
		total += n
	}
	if !bytes.Equal([]byte(msgProto), res.Bytes()) {
		t.Fatalf("Unexpected result: %q", res)
	}
	sent = sentBytes()
	if !bytes.HasSuffix(sent, []byte(msgProto)) {
		t.Fatalf("Looks like frame was compressed: %q", sent)
	}
}

// TestWSCompressionWithPartialWrite checks that compressed messages are still
// delivered intact when the server's writes to the connection are short, even
// with a PING from the client interleaved. It also checks that the server's
// pending outbound and framing state drains back to zero.
func TestWSCompressionWithPartialWrite(t *testing.T) {
	payload := "This is the content of a message that will be compresseddddddddddddddddddddd."
	msgProto := fmt.Sprintf("MSG foo 1 %d\r\n%s\r\n", len(payload), payload)

	o := testWSOptions()
	o.Websocket.Compression = true
	s := RunServer(o)
	defer s.Shutdown()

	c, br := testWSCreateClient(t, true, false, o.Websocket.Host, o.Websocket.Port)
	defer c.Close()

	proto := testWSCreateClientMsg(wsBinaryMessage, 1, true, true, []byte("SUB foo 1\r\nPING\r\n"))
	c.Write(proto)
	l := testWSReadFrame(t, br)
	if !bytes.Equal(l, []byte(pongProto)) {
		t.Fatalf("Expected PONG, got %q", l)
	}

	pingPayload := []byte("my ping")
	pingFromWSClient := testWSCreateClientMsg(wsPingMessage, 1, true, false, pingPayload)

	var wc *testWSWrappedConn
	var ws *client
	s.mu.Lock()
	for _, c := range s.clients {
		ws = c
		c.mu.Lock()
		wc = &testWSWrappedConn{
			Conn: c.nc,
			buf:  &bytes.Buffer{},
		}
		c.nc = wc
		c.mu.Unlock()
		break
	}
	s.mu.Unlock()

	wc.mu.Lock()
	wc.partial = true
	wc.mu.Unlock()

	nc := natsConnect(t, s.ClientURL())
	defer nc.Close()

	expected := &bytes.Buffer{}
	for i := 0; i < 10; i++ {
		if i > 0 {
			time.Sleep(10 * time.Millisecond)
		}
		expected.Write([]byte(msgProto))
		natsPub(t, nc, "foo", []byte(payload))
		if i == 1 {
			c.Write(pingFromWSClient)
		}
	}

	var gotPingResponse bool
	res := &bytes.Buffer{}
	for total := 0; total < 10*len(msgProto); {
		l := testWSReadFrame(t, br)
		if bytes.Equal(l, pingPayload) {
			gotPingResponse = true
		} else {
			n, _ := res.Write(l)
			total += n
		}
	}
	if !bytes.Equal(expected.Bytes(), res.Bytes()) {
		t.Fatalf("Unexpected result: %q", res)
	}
	if !gotPingResponse {
		t.Fatal("Did not get the ping response")
	}

	checkFor(t, time.Second, 15*time.Millisecond, func() error {
		ws.mu.Lock()
		pb := ws.out.pb
		wf := ws.ws.frames
		fs := ws.ws.fs
		ws.mu.Unlock()
		if pb != 0 || len(wf) != 0 || fs != 0 {
			return fmt.Errorf("Expected pb, wf and fs to be 0, got %v, %v, %v", pb, wf, fs)
		}
		return nil
	})
}

// TestWSCompressionFrameSizeLimit checks that, for a browser client with
// compression, a compressed payload is split into frames no larger than
// wsFrameSizeForBrowsers, only the first frame carries the RSV1 (compressed)
// bit, and the concatenated payloads decompress to the original data.
func TestWSCompressionFrameSizeLimit(t *testing.T) {
	for _, test := range []struct {
		name      string
		maskWrite bool
		noLimit   bool
	}{
		{"no write masking", false, false},
		{"write masking", true, false},
	} {
		t.Run(test.name, func(t *testing.T) {
			opts := testWSOptions()
			opts.MaxPending = MAX_PENDING_SIZE
			s := &Server{opts: opts}
			c := &client{srv: s, ws: &websocket{compress: true, browser: true, nocompfrag: test.noLimit, maskwrite: test.maskWrite}}
			c.initClient()

			// Random bytes do not compress well, which keeps the compressed
			// output large enough to need several frames.
			uncompressedPayload := make([]byte, 2*wsFrameSizeForBrowsers)
			for i := 0; i < len(uncompressedPayload); i++ {
				uncompressedPayload[i] = byte(rand.Intn(256))
			}

			c.mu.Lock()
			c.out.nb = append(net.Buffers(nil), uncompressedPayload)
			nb, _ := c.collapsePtoNB()
			c.mu.Unlock()

			if test.noLimit && len(nb) != 2 {
				t.Fatalf("There should be only 2 buffers, the header and payload, got %v", len(nb))
			}

			bb := &bytes.Buffer{}
			var key []byte
			// Buffers alternate: even indexes are frame headers, odd are payloads.
			for i, b := range nb {
				if !test.noLimit {
					// Frame header buffers are always very small. The payload should
					// not be more than wsFrameSizeForBrowsers bytes since that is the
					// frame size limit for browser clients.
					if len(b) > wsFrameSizeForBrowsers {
						t.Fatalf("Frame size too big: %v (%q)", len(b), b)
					}
				}
				if test.maskWrite {
					if i%2 == 0 {
						key = b[len(b)-4:]
					} else {
						wsMaskBuf(key, b)
					}
				}
				// Check frame headers for the proper formatting.
				if i%2 == 0 {
					// Only the first frame should have the compress bit set.
					if b[0]&wsRsv1Bit != 0 {
						if i > 0 {
							t.Fatalf("Compressed bit should not be in continuation frame")
						}
					} else if i == 0 {
						t.Fatalf("Compressed bit missing")
					}
				} else {
					if test.noLimit {
						// Since the payload is likely not well compressed, we are expecting
						// the length to be > wsFrameSizeForBrowsers
						if len(b) <= wsFrameSizeForBrowsers {
							t.Fatalf("Expected frame to be bigger, got %v", len(b))
						}
					}
					// Collect the payload
					bb.Write(b)
				}
			}
			buf := bb.Bytes()
			// Restore the stripped deflate sync marker and add a final empty
			// block so that the flate reader can terminate.
			buf = append(buf, 0x00, 0x00, 0xff, 0xff, 0x01, 0x00, 0x00, 0xff, 0xff)
			dbr := bytes.NewBuffer(buf)
			d := flate.NewReader(dbr)
			uncompressed, err := io.ReadAll(d)
			if err != nil {
				t.Fatalf("Error reading frame: %v", err)
			}
			if !bytes.Equal(uncompressed, uncompressedPayload) {
				t.Fatalf("Unexpected uncompressed data: %q", uncompressed)
			}
		})
	}
}

// TestWSBasicAuth checks user/password authentication for websocket clients,
// with top-level credentials, websocket-specific credentials, and websocket
// credentials overriding top-level ones.
func TestWSBasicAuth(t *testing.T) {
	for _, test := range []struct {
		name string
		opts func() *Options
		user string
		pass string
		err  string
	}{
		{
			"top level auth, no override, wrong u/p",
			func() *Options {
				o := testWSOptions()
				o.Username = "normal"
				o.Password = "client"
				return o
			},
			"websocket", "client", "-ERR 'Authorization Violation'",
		},
		{
			"top level auth, no override, correct u/p",
			func() *Options {
				o := testWSOptions()
				o.Username = "normal"
				o.Password = "client"
				return o
			},
			"normal", "client", "",
		},
		{
			"no top level auth, ws auth, wrong u/p",
			func() *Options {
				o := testWSOptions()
				o.Websocket.Username = "websocket"
				o.Websocket.Password = "client"
				return o
			},
			"normal", "client", "-ERR 'Authorization Violation'",
		},
		{
			"no top level auth, ws auth, correct u/p",
			func() *Options {
				o := testWSOptions()
				o.Websocket.Username = "websocket"
				o.Websocket.Password = "client"
				return o
			},
			"websocket", "client", "",
		},
		{
			"top level auth, ws override, wrong u/p",
			func() *Options {
				o := testWSOptions()
				o.Username = "normal"
				o.Password = "client"
				o.Websocket.Username = "websocket"
				o.Websocket.Password = "client"
				return o
			},
			"normal", "client", "-ERR 'Authorization Violation'",
		},
		{
			"top level auth, ws override, correct u/p",
			func() *Options {
				o := testWSOptions()
				o.Username = "normal"
				o.Password = "client"
				o.Websocket.Username = "websocket"
				o.Websocket.Password = "client"
				return o
			},
			"websocket", "client", "",
		},
	} {
		t.Run(test.name, func(t *testing.T) {
			o := test.opts()
			s := RunServer(o)
			defer s.Shutdown()

			wsc, br, _ := testWSCreateClientGetInfo(t, false, false, o.Websocket.Host, o.Websocket.Port)
			defer wsc.Close()

			connectProto := fmt.Sprintf("CONNECT {\"verbose\":false,\"protocol\":1,\"user\":\"%s\",\"pass\":\"%s\"}\r\nPING\r\n",
				test.user, test.pass)

			wsmsg := testWSCreateClientMsg(wsBinaryMessage, 1, true, false, []byte(connectProto))
			if _, err := wsc.Write(wsmsg); err != nil {
				t.Fatalf("Error sending message: %v", err)
			}
			msg := testWSReadFrame(t, br)
			if test.err == "" && !bytes.HasPrefix(msg, []byte("PONG\r\n")) {
				t.Fatalf("Expected to receive PONG, got %q", msg)
			} else if test.err != "" && !bytes.HasPrefix(msg, []byte(test.err)) {
				t.Fatalf("Expected to receive %q, got %q", test.err, msg)
			}
		})
	}
}

// TestWSAuthTimeout checks that Websocket.AuthTimeout, when set, takes
// precedence over the top-level AuthTimeout for websocket clients.
func TestWSAuthTimeout(t *testing.T) {
	for _, test := range []struct {
		name string
		at   float64
		wat  float64
		err  string
	}{
		{"use top-level auth timeout", 10.0, 0.0, ""},
		{"use websocket auth timeout", 10.0, 0.05, "-ERR 'Authentication Timeout'"},
	} {
		t.Run(test.name, func(t *testing.T) {
			o := testWSOptions()
			o.AuthTimeout = test.at
			o.Websocket.Username = "websocket"
			o.Websocket.Password = "client"
			o.Websocket.AuthTimeout = test.wat
			s := RunServer(o)
			defer s.Shutdown()

			wsc, br, l := testWSCreateClientGetInfo(t, false, false, o.Websocket.Host, o.Websocket.Port)
			defer wsc.Close()

			var info serverInfo
			// Skip the "INFO " prefix.
			if err := json.Unmarshal([]byte(l[5:]), &info); err != nil {
				t.Fatalf("Unable to unmarshal info: %v", err)
			}
			// Make sure that we are told that auth is required.
			if !info.AuthRequired {
				t.Fatalf("Expected auth required, was not: %q", l)
			}
			start := time.Now()
			// Wait before sending connect
			time.Sleep(100 * time.Millisecond)
			connectProto := "CONNECT {\"verbose\":false,\"protocol\":1,\"user\":\"websocket\",\"pass\":\"client\"}\r\nPING\r\n"
			wsmsg := testWSCreateClientMsg(wsBinaryMessage, 1, true, false, []byte(connectProto))
			if _, err := wsc.Write(wsmsg); err != nil {
				t.Fatalf("Error sending message: %v", err)
			}
			msg := testWSReadFrame(t, br)
			if test.err != "" && !bytes.HasPrefix(msg, []byte(test.err)) {
				t.Fatalf("Expected to receive %q error, got %q", test.err, msg)
			} else if test.err == "" && !bytes.HasPrefix(msg, []byte("PONG\r\n")) {
				t.Fatalf("Unexpected error: %q", msg)
			}
			if dur := time.Since(start); dur > time.Second {
				t.Fatalf("Too long to get timeout error: %v", dur)
			}
		})
	}
}

// TestWSTokenAuth checks token authentication for websocket clients, with a
// top-level token, a websocket-specific token, and a websocket token
// overriding the top-level one.
func TestWSTokenAuth(t *testing.T) {
	for _, test := range []struct {
		name  string
		opts  func() *Options
		token string
		err   string
	}{
		{
			"top level auth, no override, wrong token",
			func() *Options {
				o := testWSOptions()
				o.Authorization = "goodtoken"
				return o
			},
			"badtoken", "-ERR 'Authorization Violation'",
		},
		{
			"top level auth, no override, correct token",
			func() *Options {
				o := testWSOptions()
				o.Authorization = "goodtoken"
				return o
			},
			"goodtoken", "",
		},
		{
			"no top level auth, ws auth, wrong token",
			func() *Options {
				o := testWSOptions()
				o.Websocket.Token = "goodtoken"
				return o
			},
			"badtoken", "-ERR 'Authorization Violation'",
		},
		{
			"no top level auth, ws auth, correct token",
			func() *Options {
				o := testWSOptions()
				o.Websocket.Token = "goodtoken"
				return o
			},
			"goodtoken", "",
		},
		{
			"top level auth, ws override, wrong token",
			func() *Options {
				o := testWSOptions()
				o.Authorization = "clienttoken"
				o.Websocket.Token = "websockettoken"
				return o
			},
			"clienttoken", "-ERR 'Authorization Violation'",
		},
		{
			"top level auth, ws override, correct token",
			func() *Options {
				o := testWSOptions()
				o.Authorization = "clienttoken"
				o.Websocket.Token = "websockettoken"
				return o
			},
			"websockettoken", "",
		},
	} {
		t.Run(test.name, func(t *testing.T) {
			o := test.opts()
			s := RunServer(o)
			defer s.Shutdown()

			wsc, br, _ := testWSCreateClientGetInfo(t, false, false, o.Websocket.Host, o.Websocket.Port)
			defer wsc.Close()

			connectProto := fmt.Sprintf("CONNECT {\"verbose\":false,\"protocol\":1,\"auth_token\":\"%s\"}\r\nPING\r\n",
				test.token)

			wsmsg := testWSCreateClientMsg(wsBinaryMessage, 1, true, false, []byte(connectProto))
			if _, err := wsc.Write(wsmsg); err != nil {
				t.Fatalf("Error sending message: %v", err)
			}
			msg := testWSReadFrame(t, br)
			if test.err == "" && !bytes.HasPrefix(msg, []byte("PONG\r\n")) {
				t.Fatalf("Expected to receive PONG, got %q", msg)
			} else if test.err != "" && !bytes.HasPrefix(msg, []byte(test.err)) {
				t.Fatalf("Expected to receive %q, got %q", test.err, msg)
			}
		})
	}
}

// TestWSBindToProperAccount checks that a websocket client that authenticates
// as a user of account "a" is bound to that account: a message it publishes
// is received by a regular client of the same account.
func TestWSBindToProperAccount(t *testing.T) {
	conf := createConfFile(t, []byte(fmt.Sprintf(`
		listen: "127.0.0.1:-1"
		accounts {
			a {
				users [
					{user: a, password: pwd, allowed_connection_types: ["%s", "%s"]}
				]
			}
			b {
				users [
					{user: b, password: pwd}
				]
			}
		}
		websocket {
			listen: "127.0.0.1:-1"
			no_tls: true
		}
	`, jwt.ConnectionTypeStandard, strings.ToLower(jwt.ConnectionTypeWebsocket)))) // on purpose use lower case to ensure that it is converted.
	s, o := RunServerWithConfig(conf)
	defer s.Shutdown()

	nc := natsConnect(t, fmt.Sprintf("nats://a:pwd@127.0.0.1:%d", o.Port))
	defer nc.Close()

	sub := natsSubSync(t, nc, "foo")

	wsc, br, _ := testNewWSClient(t, testWSClientOptions{host: o.Websocket.Host, port: o.Websocket.Port, noTLS: true})
	// Release the websocket connection when the test ends.
	defer wsc.Close()
	// Send CONNECT and PING
	wsmsg := testWSCreateClientMsg(wsBinaryMessage, 1, true, false,
		[]byte(fmt.Sprintf("CONNECT {\"verbose\":false,\"protocol\":1,\"user\":\"%s\",\"pass\":\"%s\"}\r\nPING\r\n", "a", "pwd")))
	if _, err := wsc.Write(wsmsg); err != nil {
		t.Fatalf("Error sending message: %v", err)
	}
	// Wait for the PONG
	if msg := testWSReadFrame(t, br); !bytes.HasPrefix(msg, []byte("PONG\r\n")) {
		t.Fatalf("Expected PONG, got %s", msg)
	}

	wsmsg = testWSCreateClientMsg(wsBinaryMessage, 1, true, false, []byte("PUB foo 7\r\nfrom ws\r\n"))
	if _, err := wsc.Write(wsmsg); err != nil {
		t.Fatalf("Error sending message: %v", err)
	}

	natsNexMsg(t, sub, time.Second)
}

// TestWSUsersAuth checks authentication of websocket clients against the
// configured Users list, including filtering by AllowedConnectionTypes.
func TestWSUsersAuth(t *testing.T) {
	// newUsers returns a fresh user list for every subtest. Some options
	// functions below set AllowedConnectionTypes on the user, and sharing a
	// single *User across subtests would leak that setting between them.
	newUsers := func() []*User {
		return []*User{{Username: "user", Password: "pwd"}}
	}
	for _, test := range []struct {
		name string
		opts func() *Options
		user string
		pass string
		err  string
	}{
		{
			"no filtering, wrong user",
			func() *Options {
				o := testWSOptions()
				o.Users = newUsers()
				return o
			},
			"wronguser", "pwd", "-ERR 'Authorization Violation'",
		},
		{
			"no filtering, correct user",
			func() *Options {
				o := testWSOptions()
				o.Users = newUsers()
				return o
			},
			"user", "pwd", "",
		},
		{
			"filtering, user not allowed",
			func() *Options {
				o := testWSOptions()
				o.Users = newUsers()
				// Only allowed for regular clients
				o.Users[0].AllowedConnectionTypes = testCreateAllowedConnectionTypes([]string{jwt.ConnectionTypeStandard})
				return o
			},
			"user", "pwd", "-ERR 'Authorization Violation'",
		},
		{
			"filtering, user allowed",
			func() *Options {
				o := testWSOptions()
				o.Users = newUsers()
				o.Users[0].AllowedConnectionTypes = testCreateAllowedConnectionTypes([]string{jwt.ConnectionTypeStandard, jwt.ConnectionTypeWebsocket})
				return o
			},
			"user", "pwd", "",
		},
		{
			"filtering, wrong password",
			func() *Options {
				o := testWSOptions()
				o.Users = newUsers()
				o.Users[0].AllowedConnectionTypes = testCreateAllowedConnectionTypes([]string{jwt.ConnectionTypeStandard, jwt.ConnectionTypeWebsocket})
				return o
			},
			"user", "badpassword", "-ERR 'Authorization Violation'",
		},
	} {
		t.Run(test.name, func(t *testing.T) {
			o := test.opts()
			s := RunServer(o)
			defer s.Shutdown()

			wsc, br, _ := testWSCreateClientGetInfo(t, false, false, o.Websocket.Host, o.Websocket.Port)
			defer wsc.Close()

			connectProto := fmt.Sprintf("CONNECT {\"verbose\":false,\"protocol\":1,\"user\":\"%s\",\"pass\":\"%s\"}\r\nPING\r\n",
				test.user, test.pass)

			wsmsg := testWSCreateClientMsg(wsBinaryMessage, 1, true, false, []byte(connectProto))
			if _, err := wsc.Write(wsmsg); err != nil {
				t.Fatalf("Error sending message: %v", err)
			}
			msg := testWSReadFrame(t, br)
			if test.err == "" && !bytes.HasPrefix(msg, []byte("PONG\r\n")) {
				t.Fatalf("Expected to receive PONG, got %q", msg)
			} else if test.err != "" && !bytes.HasPrefix(msg, []byte(test.err)) {
				t.Fatalf("Expected to receive %q, got %q", test.err, msg)
			}
		})
	}
}

// TestWSNoAuthUserValidation checks that NewServer rejects a
// Websocket.NoAuthUser that is not one of the configured users, even when the
// top-level NoAuthUser is valid.
func TestWSNoAuthUserValidation(t *testing.T) {
	o := testWSOptions()
	o.Users = []*User{{Username: "user", Password: "pwd"}}
	// Should fail because it is not part of o.Users.
	o.Websocket.NoAuthUser = "notfound"
	if _, err := NewServer(o); err == nil || !strings.Contains(err.Error(), "not present as user") {
		t.Fatalf("Expected error saying not present as user, got %v", err)
	}
	// Set a valid no auth user for global options, but still should fail because
	// of o.Websocket.NoAuthUser
	o.NoAuthUser = "user"
	o.Websocket.NoAuthUser = "notfound"
	if _, err := NewServer(o); err == nil || !strings.Contains(err.Error(), "not present as user") {
		t.Fatalf("Expected error saying not present as user, got %v", err)
	}
}

// TestWSNoAuthUser checks which user and account a websocket client is bound
// to, with or without credentials and with or without a websocket-specific
// NoAuthUser overriding the top-level one.
func TestWSNoAuthUser(t *testing.T) {
	for _, test := range []struct {
		name         string
		override     bool
		useAuth      bool
		expectedUser string
		expectedAcc  string
	}{
		{"no override, no user provided", false, false, "noauth", "normal"},
		{"no override, user provided", false, true, "user", "normal"},
		{"override, no user provided", true, false, "wsnoauth", "websocket"},
		{"override, user provided", true, true, "wsuser", "websocket"},
	} {
		t.Run(test.name, func(t *testing.T) {
			o := testWSOptions()
			normalAcc := NewAccount("normal")
			websocketAcc := NewAccount("websocket")
			o.Accounts = []*Account{normalAcc, websocketAcc}
			o.Users = []*User{
				{Username: "noauth", Password: "pwd", Account: normalAcc},
				{Username: "user", Password: "pwd", Account: normalAcc},
				{Username: "wsnoauth", Password: "pwd", Account: websocketAcc},
				{Username: "wsuser", Password: "pwd", Account: websocketAcc},
			}
			o.NoAuthUser = "noauth"
			if test.override {
				o.Websocket.NoAuthUser = "wsnoauth"
			}
			s := RunServer(o)
			defer s.Shutdown()

			wsc, br, l := testWSCreateClientGetInfo(t, false, false, o.Websocket.Host, o.Websocket.Port)
			defer wsc.Close()

			var info serverInfo
			// Skip the "INFO " prefix.
			if err := json.Unmarshal([]byte(l[5:]), &info); err != nil {
				t.Fatalf("Unable to unmarshal info: %v", err)
			}

			var connectProto string
			if test.useAuth {
				connectProto = fmt.Sprintf("CONNECT {\"verbose\":false,\"protocol\":1,\"user\":\"%s\",\"pass\":\"pwd\"}\r\nPING\r\n",
					test.expectedUser)
			} else {
				connectProto = "CONNECT {\"verbose\":false,\"protocol\":1}\r\nPING\r\n"
			}
			wsmsg := testWSCreateClientMsg(wsBinaryMessage, 1, true, false, []byte(connectProto))
			if _, err := wsc.Write(wsmsg); err != nil {
				t.Fatalf("Error sending message: %v", err)
			}
			msg := testWSReadFrame(t, br)
			if !bytes.HasPrefix(msg, []byte("PONG\r\n")) {
				t.Fatalf("Unexpected error: %q", msg)
			}

			c := s.getClient(info.CID)
			if c == nil {
				t.Fatalf("Client %v not found", info.CID)
			}
			c.mu.Lock()
			uname := c.opts.Username
			aname := c.acc.GetName()
			c.mu.Unlock()
			if uname != test.expectedUser {
				t.Fatalf("Expected selected user to be %q, got %q", test.expectedUser, uname)
			}
			if aname != test.expectedAcc {
				t.Fatalf("Expected selected account to be %q, got %q", test.expectedAcc, aname)
			}
		})
	}
}

// TestWSNkeyAuth checks nkey authentication for websocket clients: the client
// signs the server's nonce, and access is granted or denied based on the
// configured nkeys and their AllowedConnectionTypes.
func TestWSNkeyAuth(t *testing.T) {
	// newUserKP creates a user key pair and returns it with its public key.
	newUserKP := func() (nkeys.KeyPair, string) {
		t.Helper()
		kp, err := nkeys.CreateUser()
		if err != nil {
			t.Fatalf("Error creating user key pair: %v", err)
		}
		pub, err := kp.PublicKey()
		if err != nil {
			t.Fatalf("Error getting public key: %v", err)
		}
		return kp, pub
	}
	nkp, pub := newUserKP()
	wsnkp, wspub := newUserKP()
	badkp, badpub := newUserKP()

	for _, test := range []struct {
		name string
		opts func() *Options
		nkey string
		kp   nkeys.KeyPair
		err  string
	}{
		{
			"no filtering, wrong nkey",
			func() *Options {
				o := testWSOptions()
				o.Nkeys = []*NkeyUser{{Nkey: pub}}
				return o
			},
			badpub, badkp, "-ERR 'Authorization Violation'",
		},
		{
			"no filtering, correct nkey",
			func() *Options {
				o := testWSOptions()
				o.Nkeys = []*NkeyUser{{Nkey: pub}}
				return o
			},
			pub, nkp, "",
		},
		{
			"filtering, nkey not allowed",
			func() *Options {
				o := testWSOptions()
				o.Nkeys = []*NkeyUser{
					{
						Nkey:                   pub,
						AllowedConnectionTypes: testCreateAllowedConnectionTypes([]string{jwt.ConnectionTypeStandard}),
					},
					{
						Nkey:                   wspub,
						AllowedConnectionTypes: testCreateAllowedConnectionTypes([]string{jwt.ConnectionTypeWebsocket}),
					},
				}
				return o
			},
			pub, nkp, "-ERR 'Authorization Violation'",
		},
		{
			"filtering, correct nkey",
			func() *Options {
				o := testWSOptions()
				o.Nkeys = []*NkeyUser{
					{Nkey: pub},
					{
						Nkey:                   wspub,
						AllowedConnectionTypes: testCreateAllowedConnectionTypes([]string{jwt.ConnectionTypeStandard, jwt.ConnectionTypeWebsocket}),
					},
				}
				return o
			},
			wspub, wsnkp, "",
		},
		{
			"filtering, wrong nkey",
			func() *Options {
				o := testWSOptions()
				o.Nkeys = []*NkeyUser{
					{
						Nkey:                   wspub,
						AllowedConnectionTypes: testCreateAllowedConnectionTypes([]string{jwt.ConnectionTypeStandard, jwt.ConnectionTypeWebsocket}),
					},
				}
				return o
			},
			badpub, badkp, "-ERR 'Authorization Violation'",
		},
	} {
		t.Run(test.name, func(t *testing.T) {
			o := test.opts()
			s := RunServer(o)
			defer s.Shutdown()

			wsc, br, infoMsg := testWSCreateClientGetInfo(t, false, false, o.Websocket.Host, o.Websocket.Port)
			defer wsc.Close()

			// Sign Nonce
			var info nonceInfo
			// Skip the "INFO " prefix.
			if err := json.Unmarshal([]byte(infoMsg[5:]), &info); err != nil {
				t.Fatalf("Unable to unmarshal info: %v", err)
			}
			sigraw, err := test.kp.Sign([]byte(info.Nonce))
			if err != nil {
				t.Fatalf("Error signing nonce: %v", err)
			}
			sig := base64.RawURLEncoding.EncodeToString(sigraw)

			connectProto := fmt.Sprintf("CONNECT {\"verbose\":false,\"protocol\":1,\"nkey\":\"%s\",\"sig\":\"%s\"}\r\nPING\r\n", test.nkey, sig)

			wsmsg := testWSCreateClientMsg(wsBinaryMessage, 1, true, false, []byte(connectProto))
			if _, err := wsc.Write(wsmsg); err != nil {
				t.Fatalf("Error sending message: %v", err)
			}
			msg := testWSReadFrame(t, br)
			if test.err == "" && !bytes.HasPrefix(msg, []byte("PONG\r\n")) {
				t.Fatalf("Expected to receive PONG, got %q", msg)
			} else if test.err != "" && !bytes.HasPrefix(msg, []byte(test.err)) {
				t.Fatalf("Expected to receive %q, got %q", test.err, msg)
			}
		})
	}
}

// TestWSJWTWithAllowedConnectionTypes checks that a JWT user's
// AllowedConnectionTypes accepts or rejects websocket connections. The
// websocket type is matched case-insensitively, and unknown types are ignored.
func TestWSJWTWithAllowedConnectionTypes(t *testing.T) {
	o := testWSOptions()
	setupAddTrusted(o)
	s := RunServer(o)
	// Register shutdown before further setup so the server is always released.
	defer s.Shutdown()
	buildMemAccResolver(s)

	for _, test := range []struct {
		name            string
		connectionTypes []string
		expectedAnswer  string
	}{
		{"not allowed", []string{jwt.ConnectionTypeStandard}, "-ERR"},
		{"allowed", []string{jwt.ConnectionTypeStandard, strings.ToLower(jwt.ConnectionTypeWebsocket)}, "+OK"},
		{"allowed with unknown", []string{jwt.ConnectionTypeWebsocket, "SomeNewType"}, "+OK"},
		{"not allowed with unknown", []string{"SomeNewType"}, "-ERR"},
	} {
		t.Run(test.name, func(t *testing.T) {
			nuc := newJWTTestUserClaims()
			nuc.AllowedConnectionTypes = test.connectionTypes
			claimOpt := testClaimsOptions{
				nuc:          nuc,
				expectAnswer: test.expectedAnswer,
			}
			_, c, _, _ := testWSWithClaims(t, s, testWSClientOptions{host: o.Websocket.Host, port: o.Websocket.Port}, claimOpt)
			c.Close()
		})
	}
}

func TestWSJWTCookieUser(t *testing.T) {
	nucSigFunc := func() *jwt.UserClaims { return newJWTTestUserClaims() }
	nucBearerFunc := func() *jwt.UserClaims {
		ret := newJWTTestUserClaims()
		ret.BearerToken = true
		return ret
	}

	o := testWSOptions()
	setupAddTrusted(o)
	setupAddCookie(o)
	s := RunServer(o)
	buildMemAccResolver(s)
	defer s.Shutdown()

	genJwt := func(t *testing.T, nuc *jwt.UserClaims) string {
		okp, _ := nkeys.FromSeed(oSeed)

		akp, _ := nkeys.CreateAccount()
		apub, _ := akp.PublicKey()

		nac := jwt.NewAccountClaims(apub)
		ajwt, err := nac.Encode(okp)
		if err != nil {
			t.Fatalf("Error generating account JWT: %v", err)
		}

		nkp, _ := nkeys.CreateUser()
		pub, _ := nkp.PublicKey()
		nuc.Subject = pub
		jwt, err := nuc.Encode(akp)
		if err != nil {
			t.Fatalf("Error generating user JWT: %v", err)
		}
		addAccountToMemResolver(s, apub, ajwt)
		return jwt
	}

	cliOpts := testWSClientOptions{
		host: o.Websocket.Host,
		port: o.Websocket.Port,
	}
	for _, test := range []struct {
		name         string
		nuc          *jwt.UserClaims
		opts         func(t *testing.T, claims *jwt.UserClaims) (testWSClientOptions, testClaimsOptions)
		expectAnswer string
	}{
		{
			name: "protocol auth, non-bearer key, with signature",
			nuc:  nucSigFunc(),
			opts: func(t *testing.T, claims *jwt.UserClaims) (testWSClientOptions, testClaimsOptions) {
				return cliOpts, testClaimsOptions{nuc: claims}
			},
			expectAnswer: "+OK",
		},
		{
			name: "protocol auth, non-bearer key, w/o required signature",
			nuc:  nucSigFunc(),
			opts: func(t *testing.T, claims *jwt.UserClaims) (testWSClientOptions, testClaimsOptions) {
				return cliOpts, testClaimsOptions{nuc: claims, dontSign: true}
			},
			expectAnswer: "-ERR",
		},
		{
			name: "protocol auth, bearer key, w/o signature",
			nuc:  nucBearerFunc(),
			opts: func(t *testing.T, claims *jwt.UserClaims) (testWSClientOptions, testClaimsOptions) {
				return cliOpts, testClaimsOptions{nuc: claims, dontSign: true}
			},
			expectAnswer: "+OK",
		},
		{
			name: "cookie auth, non-bearer key, protocol auth fail",
			nuc:  nucSigFunc(),
			opts:
func(t *testing.T, claims *jwt.UserClaims) (testWSClientOptions, testClaimsOptions) { co := cliOpts co.extraHeaders = map[string][]string{} co.extraHeaders["Cookie"] = []string{o.Websocket.JWTCookie + "=" + genJwt(t, claims)} return co, testClaimsOptions{connectRequest: struct{}{}} }, expectAnswer: "-ERR", }, { name: "cookie auth, bearer key, protocol auth success with implied cookie jwt", nuc: nucBearerFunc(), opts: func(t *testing.T, claims *jwt.UserClaims) (testWSClientOptions, testClaimsOptions) { co := cliOpts co.extraHeaders = map[string][]string{} co.extraHeaders["Cookie"] = []string{o.Websocket.JWTCookie + "=" + genJwt(t, claims)} return co, testClaimsOptions{connectRequest: struct{}{}} }, expectAnswer: "+OK", }, { name: "cookie auth, non-bearer key, protocol auth success via override jwt in CONNECT opts", nuc: nucSigFunc(), opts: func(t *testing.T, claims *jwt.UserClaims) (testWSClientOptions, testClaimsOptions) { co := cliOpts co.extraHeaders = map[string][]string{} co.extraHeaders["Cookie"] = []string{o.Websocket.JWTCookie + "=" + genJwt(t, claims)} return co, testClaimsOptions{nuc: nucBearerFunc()} }, expectAnswer: "+OK", }, } { t.Run(test.name, func(t *testing.T) { cliOpt, claimOpt := test.opts(t, test.nuc) claimOpt.expectAnswer = test.expectAnswer _, c, _, _ := testWSWithClaims(t, s, cliOpt, claimOpt) c.Close() }) } s.Shutdown() } func TestWSReloadTLSConfig(t *testing.T) { template := ` listen: "127.0.0.1:-1" websocket { listen: "127.0.0.1:-1" tls { cert_file: '%s' key_file: '%s' ca_file: '../test/configs/certs/ca.pem' } } ` conf := createConfFile(t, []byte(fmt.Sprintf(template, "../test/configs/certs/server-noip.pem", "../test/configs/certs/server-key-noip.pem"))) s, o := RunServerWithConfig(conf) defer s.Shutdown() addr := fmt.Sprintf("127.0.0.1:%d", o.Websocket.Port) wsc, err := net.Dial("tcp", addr) if err != nil { t.Fatalf("Error creating ws connection: %v", err) } defer wsc.Close() tc := &TLSConfigOpts{CaFile: "../test/configs/certs/ca.pem"} 
tlsConfig, err := GenTLSConfig(tc) if err != nil { t.Fatalf("Error generating TLS config: %v", err) } tlsConfig.ServerName = "127.0.0.1" tlsConfig.RootCAs = tlsConfig.ClientCAs tlsConfig.ClientCAs = nil wsc = tls.Client(wsc, tlsConfig.Clone()) if err := wsc.(*tls.Conn).Handshake(); err == nil || !strings.Contains(err.Error(), "SAN") { t.Fatalf("Unexpected error: %v", err) } wsc.Close() reloadUpdateConfig(t, s, conf, fmt.Sprintf(template, "../test/configs/certs/server-cert.pem", "../test/configs/certs/server-key.pem")) wsc, err = net.Dial("tcp", addr) if err != nil { t.Fatalf("Error creating ws connection: %v", err) } defer wsc.Close() wsc = tls.Client(wsc, tlsConfig.Clone()) if err := wsc.(*tls.Conn).Handshake(); err != nil { t.Fatalf("Error on TLS handshake: %v", err) } } type captureClientConnectedLogger struct { DummyLogger ch chan string } func (l *captureClientConnectedLogger) Debugf(format string, v ...interface{}) { msg := fmt.Sprintf(format, v...) if !strings.Contains(msg, "Client connection created") { return } select { case l.ch <- msg: default: } } func TestWSXForwardedFor(t *testing.T) { o := testWSOptions() s := RunServer(o) defer s.Shutdown() l := &captureClientConnectedLogger{ch: make(chan string, 1)} s.SetLogger(l, true, false) for _, test := range []struct { name string headers func() map[string][]string useHdrValue bool expectedValue string }{ {"nil map", func() map[string][]string { return nil }, false, _EMPTY_}, {"empty map", func() map[string][]string { return make(map[string][]string) }, false, _EMPTY_}, {"header present empty value", func() map[string][]string { m := make(map[string][]string) m[wsXForwardedForHeader] = []string{} return m }, false, _EMPTY_}, {"header present invalid IP", func() map[string][]string { m := make(map[string][]string) m[wsXForwardedForHeader] = []string{"not a valid IP"} return m }, false, _EMPTY_}, {"header present one IP", func() map[string][]string { m := make(map[string][]string) m[wsXForwardedForHeader] = 
[]string{"1.2.3.4"} return m }, true, "1.2.3.4"}, {"header present multiple IPs", func() map[string][]string { m := make(map[string][]string) m[wsXForwardedForHeader] = []string{"1.2.3.4", "5.6.7.8"} return m }, true, "1.2.3.4"}, {"header present IPv6", func() map[string][]string { m := make(map[string][]string) m[wsXForwardedForHeader] = []string{"::1"} return m }, true, "[::1]"}, } { t.Run(test.name, func(t *testing.T) { c, r, _ := testNewWSClient(t, testWSClientOptions{ host: o.Websocket.Host, port: o.Websocket.Port, extraHeaders: test.headers(), }) defer c.Close() // Send CONNECT and PING wsmsg := testWSCreateClientMsg(wsBinaryMessage, 1, true, false, []byte("CONNECT {\"verbose\":false,\"protocol\":1}\r\nPING\r\n")) if _, err := c.Write(wsmsg); err != nil { t.Fatalf("Error sending message: %v", err) } // Wait for the PONG if msg := testWSReadFrame(t, r); !bytes.HasPrefix(msg, []byte("PONG\r\n")) { t.Fatalf("Expected PONG, got %s", msg) } select { case d := <-l.ch: ipAndSlash := fmt.Sprintf("%s/", test.expectedValue) if test.useHdrValue { if !strings.HasPrefix(d, ipAndSlash) { t.Fatalf("Expected debug statement to start with: %q, got %q", ipAndSlash, d) } } else if strings.HasPrefix(d, ipAndSlash) { t.Fatalf("Unexpected debug statement: %q", d) } case <-time.After(time.Second): t.Fatal("Did not get connect debug statement") } }) } } // ================================================================== // = Benchmark tests // ================================================================== const testWSBenchSubject = "a" var ch = []byte("abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ0123456789!@$#%^&*()") func sizedString(sz int) string { b := make([]byte, sz) for i := range b { b[i] = ch[rand.Intn(len(ch))] } return string(b) } func sizedStringForCompression(sz int) string { b := make([]byte, sz) c := byte(0) s := 0 for i := range b { if s%20 == 0 { c = ch[rand.Intn(len(ch))] } b[i] = c } return string(b) } func testWSFlushConn(b *testing.B, compress 
bool, c net.Conn, br *bufio.Reader) { buf := testWSCreateClientMsg(wsBinaryMessage, 1, true, compress, []byte(pingProto)) c.Write(buf) c.SetReadDeadline(time.Now().Add(5 * time.Second)) res := testWSReadFrame(b, br) c.SetReadDeadline(time.Time{}) if !bytes.HasPrefix(res, []byte(pongProto)) { b.Fatalf("Failed read of PONG: %s\n", res) } } func wsBenchPub(b *testing.B, numPubs int, compress bool, payload string) { b.StopTimer() opts := testWSOptions() opts.Websocket.Compression = compress s := RunServer(opts) defer s.Shutdown() n := b.N extra := 0 pubProto := []byte(fmt.Sprintf("PUB %s %d\r\n%s\r\n", testWSBenchSubject, len(payload), payload)) singleOpBuf := testWSCreateClientMsg(wsBinaryMessage, 1, true, compress, pubProto) // Simulate client that would buffer messages before framing/sending. // Figure out how many we can fit in one frame based on b.N and length of pubProto const bufSize = 32768 tmpa := [bufSize]byte{} tmp := tmpa[:0] pb := 0 for i := 0; i < b.N; i++ { tmp = append(tmp, pubProto...) 
pb++ if len(tmp) >= bufSize { break } } sendBuf := testWSCreateClientMsg(wsBinaryMessage, 1, true, compress, tmp) n = b.N / pb extra = b.N - (n * pb) wg := sync.WaitGroup{} wg.Add(numPubs) type pub struct { c net.Conn br *bufio.Reader bw *bufio.Writer } var pubs []pub for i := 0; i < numPubs; i++ { wsc, br := testWSCreateClient(b, compress, false, opts.Websocket.Host, opts.Websocket.Port) defer wsc.Close() bw := bufio.NewWriterSize(wsc, bufSize) pubs = append(pubs, pub{wsc, br, bw}) } // Average the amount of bytes sent by iteration avg := len(sendBuf) / pb if extra > 0 { avg += len(singleOpBuf) avg /= 2 } b.SetBytes(int64(numPubs * avg)) b.StartTimer() for i := 0; i < numPubs; i++ { p := pubs[i] go func(p pub) { defer wg.Done() for i := 0; i < n; i++ { p.bw.Write(sendBuf) } for i := 0; i < extra; i++ { p.bw.Write(singleOpBuf) } p.bw.Flush() testWSFlushConn(b, compress, p.c, p.br) }(p) } wg.Wait() b.StopTimer() } func Benchmark_WS_Pubx1_CN_____0b(b *testing.B) { wsBenchPub(b, 1, false, "") } func Benchmark_WS_Pubx1_CY_____0b(b *testing.B) { wsBenchPub(b, 1, true, "") } func Benchmark_WS_Pubx1_CN___128b(b *testing.B) { s := sizedString(128) wsBenchPub(b, 1, false, s) } func Benchmark_WS_Pubx1_CY___128b(b *testing.B) { s := sizedStringForCompression(128) wsBenchPub(b, 1, true, s) } func Benchmark_WS_Pubx1_CN__1024b(b *testing.B) { s := sizedString(1024) wsBenchPub(b, 1, false, s) } func Benchmark_WS_Pubx1_CY__1024b(b *testing.B) { s := sizedStringForCompression(1024) wsBenchPub(b, 1, true, s) } func Benchmark_WS_Pubx1_CN__4096b(b *testing.B) { s := sizedString(4 * 1024) wsBenchPub(b, 1, false, s) } func Benchmark_WS_Pubx1_CY__4096b(b *testing.B) { s := sizedStringForCompression(4 * 1024) wsBenchPub(b, 1, true, s) } func Benchmark_WS_Pubx1_CN__8192b(b *testing.B) { s := sizedString(8 * 1024) wsBenchPub(b, 1, false, s) } func Benchmark_WS_Pubx1_CY__8192b(b *testing.B) { s := sizedStringForCompression(8 * 1024) wsBenchPub(b, 1, true, s) } func 
Benchmark_WS_Pubx1_CN_32768b(b *testing.B) { s := sizedString(32 * 1024) wsBenchPub(b, 1, false, s) } func Benchmark_WS_Pubx1_CY_32768b(b *testing.B) { s := sizedStringForCompression(32 * 1024) wsBenchPub(b, 1, true, s) } func Benchmark_WS_Pubx5_CN_____0b(b *testing.B) { wsBenchPub(b, 5, false, "") } func Benchmark_WS_Pubx5_CY_____0b(b *testing.B) { wsBenchPub(b, 5, true, "") } func Benchmark_WS_Pubx5_CN___128b(b *testing.B) { s := sizedString(128) wsBenchPub(b, 5, false, s) } func Benchmark_WS_Pubx5_CY___128b(b *testing.B) { s := sizedStringForCompression(128) wsBenchPub(b, 5, true, s) } func Benchmark_WS_Pubx5_CN__1024b(b *testing.B) { s := sizedString(1024) wsBenchPub(b, 5, false, s) } func Benchmark_WS_Pubx5_CY__1024b(b *testing.B) { s := sizedStringForCompression(1024) wsBenchPub(b, 5, true, s) } func Benchmark_WS_Pubx5_CN__4096b(b *testing.B) { s := sizedString(4 * 1024) wsBenchPub(b, 5, false, s) } func Benchmark_WS_Pubx5_CY__4096b(b *testing.B) { s := sizedStringForCompression(4 * 1024) wsBenchPub(b, 5, true, s) } func Benchmark_WS_Pubx5_CN__8192b(b *testing.B) { s := sizedString(8 * 1024) wsBenchPub(b, 5, false, s) } func Benchmark_WS_Pubx5_CY__8192b(b *testing.B) { s := sizedStringForCompression(8 * 1024) wsBenchPub(b, 5, true, s) } func Benchmark_WS_Pubx5_CN_32768b(b *testing.B) { s := sizedString(32 * 1024) wsBenchPub(b, 5, false, s) } func Benchmark_WS_Pubx5_CY_32768b(b *testing.B) { s := sizedStringForCompression(32 * 1024) wsBenchPub(b, 5, true, s) } func wsBenchSub(b *testing.B, numSubs int, compress bool, payload string) { b.StopTimer() opts := testWSOptions() opts.Websocket.Compression = compress s := RunServer(opts) defer s.Shutdown() var subs []*bufio.Reader for i := 0; i < numSubs; i++ { wsc, br := testWSCreateClient(b, compress, false, opts.Websocket.Host, opts.Websocket.Port) defer wsc.Close() subProto := testWSCreateClientMsg(wsBinaryMessage, 1, true, compress, []byte(fmt.Sprintf("SUB %s 1\r\nPING\r\n", testWSBenchSubject))) 
wsc.Write(subProto) // Waiting for PONG testWSReadFrame(b, br) subs = append(subs, br) } wg := sync.WaitGroup{} wg.Add(numSubs) // Use regular NATS client to publish messages nc := natsConnect(b, s.ClientURL()) defer nc.Close() b.StartTimer() for i := 0; i < numSubs; i++ { br := subs[i] go func(br *bufio.Reader) { defer wg.Done() for count := 0; count < b.N; { msgs := testWSReadFrame(b, br) count += bytes.Count(msgs, []byte("MSG ")) } }(br) } for i := 0; i < b.N; i++ { natsPub(b, nc, testWSBenchSubject, []byte(payload)) } wg.Wait() b.StopTimer() } func Benchmark_WS_Subx1_CN_____0b(b *testing.B) { wsBenchSub(b, 1, false, "") } func Benchmark_WS_Subx1_CY_____0b(b *testing.B) { wsBenchSub(b, 1, true, "") } func Benchmark_WS_Subx1_CN___128b(b *testing.B) { s := sizedString(128) wsBenchSub(b, 1, false, s) } func Benchmark_WS_Subx1_CY___128b(b *testing.B) { s := sizedStringForCompression(128) wsBenchSub(b, 1, true, s) } func Benchmark_WS_Subx1_CN__1024b(b *testing.B) { s := sizedString(1024) wsBenchSub(b, 1, false, s) } func Benchmark_WS_Subx1_CY__1024b(b *testing.B) { s := sizedStringForCompression(1024) wsBenchSub(b, 1, true, s) } func Benchmark_WS_Subx1_CN__4096b(b *testing.B) { s := sizedString(4096) wsBenchSub(b, 1, false, s) } func Benchmark_WS_Subx1_CY__4096b(b *testing.B) { s := sizedStringForCompression(4096) wsBenchSub(b, 1, true, s) } func Benchmark_WS_Subx1_CN__8192b(b *testing.B) { s := sizedString(8192) wsBenchSub(b, 1, false, s) } func Benchmark_WS_Subx1_CY__8192b(b *testing.B) { s := sizedStringForCompression(8192) wsBenchSub(b, 1, true, s) } func Benchmark_WS_Subx1_CN_32768b(b *testing.B) { s := sizedString(32768) wsBenchSub(b, 1, false, s) } func Benchmark_WS_Subx1_CY_32768b(b *testing.B) { s := sizedStringForCompression(32768) wsBenchSub(b, 1, true, s) } func Benchmark_WS_Subx5_CN_____0b(b *testing.B) { wsBenchSub(b, 5, false, "") } func Benchmark_WS_Subx5_CY_____0b(b *testing.B) { wsBenchSub(b, 5, true, "") } func Benchmark_WS_Subx5_CN___128b(b 
*testing.B) { s := sizedString(128) wsBenchSub(b, 5, false, s) } func Benchmark_WS_Subx5_CY___128b(b *testing.B) { s := sizedStringForCompression(128) wsBenchSub(b, 5, true, s) } func Benchmark_WS_Subx5_CN__1024b(b *testing.B) { s := sizedString(1024) wsBenchSub(b, 5, false, s) } func Benchmark_WS_Subx5_CY__1024b(b *testing.B) { s := sizedStringForCompression(1024) wsBenchSub(b, 5, true, s) } func Benchmark_WS_Subx5_CN__4096b(b *testing.B) { s := sizedString(4096) wsBenchSub(b, 5, false, s) } func Benchmark_WS_Subx5_CY__4096b(b *testing.B) { s := sizedStringForCompression(4096) wsBenchSub(b, 5, true, s) } func Benchmark_WS_Subx5_CN__8192b(b *testing.B) { s := sizedString(8192) wsBenchSub(b, 5, false, s) } func Benchmark_WS_Subx5_CY__8192b(b *testing.B) { s := sizedStringForCompression(8192) wsBenchSub(b, 5, true, s) } func Benchmark_WS_Subx5_CN_32768b(b *testing.B) { s := sizedString(32768) wsBenchSub(b, 5, false, s) } func Benchmark_WS_Subx5_CY_32768b(b *testing.B) { s := sizedStringForCompression(32768) wsBenchSub(b, 5, true, s) }