diff --git a/server/websocket.go b/server/websocket.go index 705dc39e..b16edf0a 100644 --- a/server/websocket.go +++ b/server/websocket.go @@ -692,33 +692,33 @@ func (s *Server) wsUpgrade(w http.ResponseWriter, r *http.Request) (*wsUpgradeRe // From https://tools.ietf.org/html/rfc6455#section-4.2.1 // Point 1. if r.Method != "GET" { - return nil, wsReturnHTTPError(w, http.StatusMethodNotAllowed, "request method must be GET") + return nil, wsReturnHTTPError(w, r, http.StatusMethodNotAllowed, "request method must be GET") } // Point 2. if r.Host == _EMPTY_ { - return nil, wsReturnHTTPError(w, http.StatusBadRequest, "'Host' missing in request") + return nil, wsReturnHTTPError(w, r, http.StatusBadRequest, "'Host' missing in request") } // Point 3. if !wsHeaderContains(r.Header, "Upgrade", "websocket") { - return nil, wsReturnHTTPError(w, http.StatusBadRequest, "invalid value for header 'Upgrade'") + return nil, wsReturnHTTPError(w, r, http.StatusBadRequest, "invalid value for header 'Upgrade'") } // Point 4. if !wsHeaderContains(r.Header, "Connection", "Upgrade") { - return nil, wsReturnHTTPError(w, http.StatusBadRequest, "invalid value for header 'Connection'") + return nil, wsReturnHTTPError(w, r, http.StatusBadRequest, "invalid value for header 'Connection'") } // Point 5. key := r.Header.Get("Sec-Websocket-Key") if key == _EMPTY_ { - return nil, wsReturnHTTPError(w, http.StatusBadRequest, "key missing") + return nil, wsReturnHTTPError(w, r, http.StatusBadRequest, "key missing") } // Point 6. if !wsHeaderContains(r.Header, "Sec-Websocket-Version", "13") { - return nil, wsReturnHTTPError(w, http.StatusBadRequest, "invalid version") + return nil, wsReturnHTTPError(w, r, http.StatusBadRequest, "invalid version") } // Others are optional // Point 7. if err := s.websocket.checkOrigin(r); err != nil { - return nil, wsReturnHTTPError(w, http.StatusForbidden, fmt.Sprintf("origin not allowed: %v", err)) + return nil, wsReturnHTTPError(w, r, http.StatusForbidden, fmt.Sprintf("origin not allowed: %v", err)) } // Point 8. // We don't have protocols, so ignore. @@ -738,11 +738,11 @@ func (s *Server) wsUpgrade(w http.ResponseWriter, r *http.Request) (*wsUpgradeRe if conn != nil { conn.Close() } - return nil, wsReturnHTTPError(w, http.StatusInternalServerError, err.Error()) + return nil, wsReturnHTTPError(w, r, http.StatusInternalServerError, err.Error()) } if brw.Reader.Buffered() > 0 { conn.Close() - return nil, wsReturnHTTPError(w, http.StatusBadRequest, "client sent data before handshake is complete") + return nil, wsReturnHTTPError(w, r, http.StatusBadRequest, "client sent data before handshake is complete") } var buf [1024]byte @@ -841,8 +841,8 @@ func wsPMCExtensionSupport(header http.Header, checkPMCOnly bool) (bool, bool) { // Send an HTTP error with the given `status`` to the given http response writer `w`. // Return an error created based on the `reason` string. -func wsReturnHTTPError(w http.ResponseWriter, status int, reason string) error { - err := fmt.Errorf("websocket handshake error: %s", reason) +func wsReturnHTTPError(w http.ResponseWriter, r *http.Request, status int, reason string) error { + err := fmt.Errorf("%s - websocket handshake error: %s", r.RemoteAddr, reason) w.Header().Set("Sec-Websocket-Version", "13") http.Error(w, http.StatusText(status), status) return err diff --git a/server/websocket_test.go b/server/websocket_test.go index 92f032a3..8c0c2d04 100644 --- a/server/websocket_test.go +++ b/server/websocket_test.go @@ -2382,7 +2382,7 @@ func TestWSServerReportUpgradeFailure(t *testing.T) { logger := &captureErrorLogger{errCh: make(chan string, 1)} s.SetLogger(logger, false, false) - addr := fmt.Sprintf("%s:%d", o.Websocket.Host, o.Websocket.Port) + addr := fmt.Sprintf("127.0.0.1:%d", o.Websocket.Port) req := testWSCreateValidReq() req.URL, _ = url.Parse("wss://" + addr) @@ -2417,6 +2417,11 @@ func TestWSServerReportUpgradeFailure(t *testing.T) { if !strings.Contains(e, "invalid value for header 'Connection'") { t.Fatalf("Unexpected error: %v", e) } + // The client IP's local should be printed as a remote from server perspective. + clientIP := wsc.LocalAddr().String() + if !strings.HasPrefix(e, clientIP) { + t.Fatalf("IP should have been logged, it was not: %v", e) + } case <-time.After(time.Second): t.Fatalf("Should have timed-out") }