From 38cee1fb4252d963468b987577c3c0f8d79ef153 Mon Sep 17 00:00:00 2001 From: Tai Groot Date: Sat, 7 Mar 2026 11:33:59 +0000 Subject: [PATCH] test(ws,browser): add test suites for ws and browser packages - Add ws/server_test.go: WebSocket connection, namespace filtering, non-upgrade rejection, SetUpgrader - Add ws/namespaces_test.go: NamespacesHandler response format and content - Add browser/browser_test.go: template rendering, ws URL construction, trailing slash handling - Fix goimports whitespace in ws/server.go --- browser/browser_test.go | 53 +++++++++++++++++ ws/namespaces_test.go | 62 ++++++++++++++++++++ ws/server.go | 2 +- ws/server_test.go | 127 ++++++++++++++++++++++++++++++++++++++++ 4 files changed, 243 insertions(+), 1 deletion(-) create mode 100644 browser/browser_test.go create mode 100644 ws/namespaces_test.go create mode 100644 ws/server_test.go diff --git a/browser/browser_test.go b/browser/browser_test.go new file mode 100644 index 0000000..40b1122 --- /dev/null +++ b/browser/browser_test.go @@ -0,0 +1,53 @@ +package browser + +import ( + "net/http" + "net/http/httptest" + "strings" + "testing" +) + +func TestLogSocketViewHandler_HTTP(t *testing.T) { + req := httptest.NewRequest(http.MethodGet, "http://localhost:8080/", nil) + w := httptest.NewRecorder() + LogSocketViewHandler(w, req) + + if w.Code != http.StatusOK { + t.Fatalf("status = %d, want 200", w.Code) + } + + body := w.Body.String() + // html/template escapes forward slashes in JS context + if !strings.Contains(body, `ws:\/\/localhost:8080\/ws`) { + t.Error("response should contain escaped ws://localhost:8080/ws URL") + } + if !strings.Contains(body, "") { + t.Error("response should contain HTML doctype") + } +} + +func TestLogSocketViewHandler_CustomPath(t *testing.T) { + req := httptest.NewRequest(http.MethodGet, "http://myhost:9090/dashboard/", nil) + w := httptest.NewRecorder() + LogSocketViewHandler(w, req) + + body := w.Body.String() + if !strings.Contains(body, `ws:\/\/myhost:9090\/dashboard\/ws`) { + t.Error("expected escaped ws://myhost:9090/dashboard/ws in body") + } +} + +func TestLogSocketViewHandler_TrailingSlashTrimmed(t *testing.T) { + req := httptest.NewRequest(http.MethodGet, "http://example.com/", nil) + w := httptest.NewRecorder() + LogSocketViewHandler(w, req) + + body := w.Body.String() + // Should NOT have double slash before ws + if strings.Contains(body, `\/\/ws`) { + t.Error("should not have double slash before /ws") + } + if !strings.Contains(body, `ws:\/\/example.com\/ws`) { + t.Error("expected escaped ws://example.com/ws in body") + } +} diff --git a/ws/namespaces_test.go b/ws/namespaces_test.go new file mode 100644 index 0000000..c131e0b --- /dev/null +++ b/ws/namespaces_test.go @@ -0,0 +1,62 @@ +package ws + +import ( + "encoding/json" + "net/http" + "net/http/httptest" + "testing" + + logger "github.com/taigrr/log-socket/v2/log" +) + +func TestNamespacesHandler(t *testing.T) { + // Log to a known namespace to ensure it appears + nsLogger := logger.NewLogger("ns-handler-test") + nsLogger.Info("register namespace") + + req := httptest.NewRequest(http.MethodGet, "/api/namespaces", nil) + w := httptest.NewRecorder() + NamespacesHandler(w, req) + + if w.Code != http.StatusOK { + t.Fatalf("status = %d, want 200", w.Code) + } + + ct := w.Header().Get("Content-Type") + if ct != "application/json" { + t.Errorf("Content-Type = %q, want application/json", ct) + } + + var result struct { + Namespaces []string `json:"namespaces"` + } + if err := json.Unmarshal(w.Body.Bytes(), &result); err != nil { + t.Fatalf("failed to unmarshal response: %v", err) + } + + found := false + for _, ns := range result.Namespaces { + if ns == "ns-handler-test" { + found = true + break + } + } + if !found { + t.Errorf("namespace 'ns-handler-test' not found in %v", result.Namespaces) + } +} + +func TestNamespacesHandler_ResponseFormat(t *testing.T) { + req := httptest.NewRequest(http.MethodGet, "/api/namespaces", nil) + w := httptest.NewRecorder() + NamespacesHandler(w, req) + + var result map[string]interface{} + if err := json.Unmarshal(w.Body.Bytes(), &result); err != nil { + t.Fatalf("response is not valid JSON: %v", err) + } + + if _, ok := result["namespaces"]; !ok { + t.Error("response missing 'namespaces' key") + } +} diff --git a/ws/server.go b/ws/server.go index e7e3a24..4c7027f 100644 --- a/ws/server.go +++ b/ws/server.go @@ -23,7 +23,7 @@ func LogSocketHandler(w http.ResponseWriter, r *http.Request) { if namespacesParam != "" { namespaces = strings.Split(namespacesParam, ",") } - + c, err := upgrader.Upgrade(w, r, nil) if err != nil { logger.Error("upgrade:", err) diff --git a/ws/server_test.go b/ws/server_test.go new file mode 100644 index 0000000..1421fe0 --- /dev/null +++ b/ws/server_test.go @@ -0,0 +1,127 @@ +package ws + +import ( + "encoding/json" + "net/http" + "net/http/httptest" + "strings" + "testing" + "time" + + "github.com/gorilla/websocket" + logger "github.com/taigrr/log-socket/v2/log" +) + +func TestSetUpgrader(t *testing.T) { + custom := websocket.Upgrader{ + ReadBufferSize: 2048, + WriteBufferSize: 2048, + } + SetUpgrader(custom) + if upgrader.ReadBufferSize != 2048 { + t.Errorf("ReadBufferSize = %d, want 2048", upgrader.ReadBufferSize) + } + if upgrader.WriteBufferSize != 2048 { + t.Errorf("WriteBufferSize = %d, want 2048", upgrader.WriteBufferSize) + } + // Reset to default + SetUpgrader(websocket.Upgrader{}) +} + +func TestLogSocketHandler_NonWebSocket(t *testing.T) { + // A non-upgrade request should fail gracefully (upgrader returns error) + req := httptest.NewRequest(http.MethodGet, "/ws", nil) + w := httptest.NewRecorder() + LogSocketHandler(w, req) + // The upgrader should return a 400-level error for non-websocket requests + if w.Code == http.StatusOK || w.Code == http.StatusSwitchingProtocols { + t.Errorf("expected error status for non-websocket request, got %d", w.Code) + } +} + +func TestLogSocketHandler_WebSocket(t *testing.T) { + // Set upgrader with permissive origin check for testing + SetUpgrader(websocket.Upgrader{ + CheckOrigin: func(r *http.Request) bool { return true }, + }) + defer SetUpgrader(websocket.Upgrader{}) + + server := httptest.NewServer(http.HandlerFunc(LogSocketHandler)) + defer server.Close() + + wsURL := "ws" + strings.TrimPrefix(server.URL, "http") + "/ws" + conn, _, err := websocket.DefaultDialer.Dial(wsURL, nil) + if err != nil { + t.Fatalf("failed to connect: %v", err) + } + defer conn.Close() + + // Send a log entry and verify it arrives over the websocket + testLogger := logger.NewLogger("ws-test") + testLogger.Info("test message for websocket") + + // Read messages until we find our test entry (the handler itself + // logs "Websocket client attached." which may arrive first) + conn.SetReadDeadline(time.Now().Add(2 * time.Second)) + var found bool + for i := 0; i < 10; i++ { + _, message, err := conn.ReadMessage() + if err != nil { + t.Fatalf("failed to read message: %v", err) + } + var entry logger.Entry + if err := json.Unmarshal(message, &entry); err != nil { + t.Fatalf("failed to unmarshal entry: %v", err) + } + if entry.Namespace == "ws-test" && entry.Level == "INFO" { + found = true + break + } + } + if !found { + t.Error("did not receive expected log entry with namespace ws-test") + } +} + +func TestLogSocketHandler_NamespaceFilter(t *testing.T) { + SetUpgrader(websocket.Upgrader{ + CheckOrigin: func(r *http.Request) bool { return true }, + }) + defer SetUpgrader(websocket.Upgrader{}) + + server := httptest.NewServer(http.HandlerFunc(LogSocketHandler)) + defer server.Close() + + // Connect with namespace filter + wsURL := "ws" + strings.TrimPrefix(server.URL, "http") + "/ws?namespaces=filtered-ns" + conn, _, err := websocket.DefaultDialer.Dial(wsURL, nil) + if err != nil { + t.Fatalf("failed to connect: %v", err) + } + defer conn.Close() + + // Send a log to a different namespace — it should NOT be received + otherLogger := logger.NewLogger("other-ns") + otherLogger.Info("should not arrive") + + // Send a log to the filtered namespace — it SHOULD be received + filteredLogger := logger.NewLogger("filtered-ns") + filteredLogger.Info("should arrive") + + conn.SetReadDeadline(time.Now().Add(2 * time.Second)) + _, message, err := conn.ReadMessage() + if err != nil { + t.Fatalf("failed to read message: %v", err) + } + + var entry logger.Entry + if err := json.Unmarshal(message, &entry); err != nil { + t.Fatalf("failed to unmarshal entry: %v", err) + } + if entry.Namespace != "filtered-ns" { + t.Errorf("namespace = %q, want filtered-ns", entry.Namespace) + } + if !strings.Contains(entry.Output, "should arrive") { + t.Errorf("output = %q, want to contain 'should arrive'", entry.Output) + } +}