1
0
mirror of https://github.com/taigrr/log-socket synced 2026-04-10 12:11:23 -07:00

1 Commits

Author SHA1 Message Date
9384170eb0 feat(ws): add read pump for disconnect detection, GetContext, Level.String
- Add WebSocket read pump in LogSocketHandler so client disconnects are
  detected promptly instead of only on the next WriteMessage call
- Add Client.GetContext(ctx) for context-aware blocking reads
- Implement fmt.Stringer on Level type (Level.String())
- Add tests for GetContext, GetContext cancellation, and Level.String()
2026-04-07 10:08:50 +00:00
5 changed files with 147 additions and 16 deletions

View File

@@ -1,6 +1,7 @@
package log package log
import ( import (
"context"
"errors" "errors"
"fmt" "fmt"
"os" "os"
@@ -153,6 +154,7 @@ func (c *Client) SetLogLevel(level Level) {
c.LogLevel = level c.LogLevel = level
} }
// Get blocks until a log entry is available and returns it.
func (c *Client) Get() Entry { func (c *Client) Get() Entry {
if !c.initialized { if !c.initialized {
panic(errors.New("cannot get logs for uninitialized client, did you use CreateClient?")) panic(errors.New("cannot get logs for uninitialized client, did you use CreateClient?"))
@@ -160,6 +162,21 @@ func (c *Client) Get() Entry {
return <-c.writer return <-c.writer
} }
// GetContext blocks until a log entry is available or ctx is cancelled.
// The second return value is false when the context was cancelled before
// an entry arrived.
func (c *Client) GetContext(ctx context.Context) (Entry, bool) {
if !c.initialized {
panic(errors.New("cannot get logs for uninitialized client, did you use CreateClient?"))
}
select {
case e := <-c.writer:
return e, true
case <-ctx.Done():
return Entry{}, false
}
}
// Trace prints out logs on trace level // Trace prints out logs on trace level
func Trace(args ...any) { func Trace(args ...any) {
output := fmt.Sprint(args...) output := fmt.Sprint(args...)

View File

@@ -1,6 +1,7 @@
package log package log
import ( import (
"context"
"strconv" "strconv"
"sync" "sync"
"testing" "testing"
@@ -378,8 +379,8 @@ func TestMultiNamespaceClient(t *testing.T) {
authLogger := NewLogger("auth") authLogger := NewLogger("auth")
dbLogger := NewLogger("database") dbLogger := NewLogger("database")
dbLogger.Info("db message") // filtered out dbLogger.Info("db message") // filtered out
apiLogger.Info("api message") // should arrive apiLogger.Info("api message") // should arrive
authLogger.Info("auth message") // should arrive authLogger.Info("auth message") // should arrive
e1, ok := getEntry(c, time.Second) e1, ok := getEntry(c, time.Second)
@@ -550,6 +551,68 @@ func TestMatchesNamespace(t *testing.T) {
c2.Destroy() c2.Destroy()
} }
// TestGetContext verifies context cancellation stops blocking Get.
func TestGetContext(t *testing.T) {
c := CreateClient(DefaultNamespace)
c.SetLogLevel(LTrace)
ctx, cancel := context.WithCancel(context.Background())
cancel() // cancel immediately
_, ok := c.GetContext(ctx)
if ok {
t.Error("expected GetContext to return false on cancelled context")
}
c.Destroy()
}
// TestGetContextReceivesEntry verifies GetContext delivers entries normally.
func TestGetContextReceivesEntry(t *testing.T) {
c := CreateClient(DefaultNamespace)
c.SetLogLevel(LTrace)
go func() {
time.Sleep(10 * time.Millisecond)
Info("context entry")
}()
ctx, cancel := context.WithTimeout(context.Background(), 2*time.Second)
defer cancel()
e, ok := c.GetContext(ctx)
if !ok {
t.Fatal("expected GetContext to return entry")
}
if e.Output != "context entry" {
t.Errorf("output = %q, want %q", e.Output, "context entry")
}
c.Destroy()
}
// TestLevelString verifies the Level.String() method.
func TestLevelString(t *testing.T) {
tests := []struct {
level Level
want string
}{
{LTrace, "TRACE"},
{LDebug, "DEBUG"},
{LInfo, "INFO"},
{LNotice, "NOTICE"},
{LWarn, "WARN"},
{LError, "ERROR"},
{LPanic, "PANIC"},
{LFatal, "FATAL"},
{Level(99), "UNKNOWN"},
}
for _, tt := range tests {
got := tt.level.String()
if got != tt.want {
t.Errorf("Level(%d).String() = %q, want %q", tt.level, got, tt.want)
}
}
}
func TestFlush(t *testing.T) { func TestFlush(t *testing.T) {
defer Flush() defer Flush()
} }

View File

@@ -15,6 +15,31 @@ const (
const DefaultNamespace = "default" const DefaultNamespace = "default"
// String returns the human-readable name of the log level (e.g. "INFO").
// It implements [fmt.Stringer].
func (l Level) String() string {
switch l {
case LTrace:
return "TRACE"
case LDebug:
return "DEBUG"
case LInfo:
return "INFO"
case LNotice:
return "NOTICE"
case LWarn:
return "WARN"
case LError:
return "ERROR"
case LPanic:
return "PANIC"
case LFatal:
return "FATAL"
default:
return "UNKNOWN"
}
}
type ( type (
LogWriter chan Entry LogWriter chan Entry
Level int Level int

View File

@@ -1,6 +1,7 @@
package ws package ws
import ( import (
"context"
"encoding/json" "encoding/json"
"net/http" "net/http"
"strings" "strings"
@@ -11,36 +12,61 @@ import (
var upgrader = websocket.Upgrader{} // use default options var upgrader = websocket.Upgrader{} // use default options
// SetUpgrader replaces the default [websocket.Upgrader] used by
// [LogSocketHandler].
func SetUpgrader(u websocket.Upgrader) { func SetUpgrader(u websocket.Upgrader) {
upgrader = u upgrader = u
} }
// LogSocketHandler upgrades the HTTP connection to a WebSocket and streams
// log entries to the client. An optional "namespaces" query parameter
// (comma-separated) filters which namespaces the client receives.
func LogSocketHandler(w http.ResponseWriter, r *http.Request) { func LogSocketHandler(w http.ResponseWriter, r *http.Request) {
// Get namespaces from query parameter, comma-separated // Get namespaces from query parameter, comma-separated.
// Empty or missing means all namespaces // Empty or missing means all namespaces.
namespacesParam := r.URL.Query().Get("namespaces") namespacesParam := r.URL.Query().Get("namespaces")
var namespaces []string var namespaces []string
if namespacesParam != "" { if namespacesParam != "" {
namespaces = strings.Split(namespacesParam, ",") namespaces = strings.Split(namespacesParam, ",")
} }
c, err := upgrader.Upgrade(w, r, nil) conn, err := upgrader.Upgrade(w, r, nil)
if err != nil { if err != nil {
logger.Error("upgrade:", err) logger.Error("upgrade:", err)
return return
} }
defer c.Close() defer conn.Close()
lc := logger.CreateClient(namespaces...) lc := logger.CreateClient(namespaces...)
defer lc.Destroy() defer lc.Destroy()
lc.SetLogLevel(logger.LTrace) lc.SetLogLevel(logger.LTrace)
logger.Info("Websocket client attached.") logger.Info("Websocket client attached.")
// Start a read pump so the server detects client disconnects promptly.
// Without this, a disconnected client is only noticed when WriteMessage
// fails, which can be delayed indefinitely when no logs are produced.
ctx, cancel := context.WithCancel(r.Context())
defer cancel()
go func() {
defer cancel()
for {
if _, _, err := conn.ReadMessage(); err != nil {
return
}
}
}()
for { for {
logEvent := lc.Get() entry, ok := lc.GetContext(ctx)
logJSON, _ := json.Marshal(logEvent) if !ok {
err = c.WriteMessage(websocket.TextMessage, logJSON) // Context cancelled — client disconnected.
if err != nil { return
}
logJSON, _ := json.Marshal(entry)
if err := conn.WriteMessage(websocket.TextMessage, logJSON); err != nil {
logger.Warn("write:", err) logger.Warn("write:", err)
break return
} }
} }
} }