mirror of
https://github.com/gogrlx/nats-server.git
synced 2026-04-02 03:38:42 -07:00
[IMPROVED] Websocket: Add client IP in websocket upgrade failures
The error message would now look like this: ``` [8672] 2021/11/01 10:56:50.251985 [ERR] [::1]:59279 - websocket handshake error: invalid value for header 'Upgrade' ``` (without this change the part `[::1]:59279 - ` would not be present) Signed-off-by: Ivan Kozlovic <ivan@synadia.com>
This commit is contained in:
@@ -692,33 +692,33 @@ func (s *Server) wsUpgrade(w http.ResponseWriter, r *http.Request) (*wsUpgradeRe
|
||||
// From https://tools.ietf.org/html/rfc6455#section-4.2.1
|
||||
// Point 1.
|
||||
if r.Method != "GET" {
|
||||
return nil, wsReturnHTTPError(w, http.StatusMethodNotAllowed, "request method must be GET")
|
||||
return nil, wsReturnHTTPError(w, r, http.StatusMethodNotAllowed, "request method must be GET")
|
||||
}
|
||||
// Point 2.
|
||||
if r.Host == _EMPTY_ {
|
||||
return nil, wsReturnHTTPError(w, http.StatusBadRequest, "'Host' missing in request")
|
||||
return nil, wsReturnHTTPError(w, r, http.StatusBadRequest, "'Host' missing in request")
|
||||
}
|
||||
// Point 3.
|
||||
if !wsHeaderContains(r.Header, "Upgrade", "websocket") {
|
||||
return nil, wsReturnHTTPError(w, http.StatusBadRequest, "invalid value for header 'Upgrade'")
|
||||
return nil, wsReturnHTTPError(w, r, http.StatusBadRequest, "invalid value for header 'Upgrade'")
|
||||
}
|
||||
// Point 4.
|
||||
if !wsHeaderContains(r.Header, "Connection", "Upgrade") {
|
||||
return nil, wsReturnHTTPError(w, http.StatusBadRequest, "invalid value for header 'Connection'")
|
||||
return nil, wsReturnHTTPError(w, r, http.StatusBadRequest, "invalid value for header 'Connection'")
|
||||
}
|
||||
// Point 5.
|
||||
key := r.Header.Get("Sec-Websocket-Key")
|
||||
if key == _EMPTY_ {
|
||||
return nil, wsReturnHTTPError(w, http.StatusBadRequest, "key missing")
|
||||
return nil, wsReturnHTTPError(w, r, http.StatusBadRequest, "key missing")
|
||||
}
|
||||
// Point 6.
|
||||
if !wsHeaderContains(r.Header, "Sec-Websocket-Version", "13") {
|
||||
return nil, wsReturnHTTPError(w, http.StatusBadRequest, "invalid version")
|
||||
return nil, wsReturnHTTPError(w, r, http.StatusBadRequest, "invalid version")
|
||||
}
|
||||
// Others are optional
|
||||
// Point 7.
|
||||
if err := s.websocket.checkOrigin(r); err != nil {
|
||||
return nil, wsReturnHTTPError(w, http.StatusForbidden, fmt.Sprintf("origin not allowed: %v", err))
|
||||
return nil, wsReturnHTTPError(w, r, http.StatusForbidden, fmt.Sprintf("origin not allowed: %v", err))
|
||||
}
|
||||
// Point 8.
|
||||
// We don't have protocols, so ignore.
|
||||
@@ -738,11 +738,11 @@ func (s *Server) wsUpgrade(w http.ResponseWriter, r *http.Request) (*wsUpgradeRe
|
||||
if conn != nil {
|
||||
conn.Close()
|
||||
}
|
||||
return nil, wsReturnHTTPError(w, http.StatusInternalServerError, err.Error())
|
||||
return nil, wsReturnHTTPError(w, r, http.StatusInternalServerError, err.Error())
|
||||
}
|
||||
if brw.Reader.Buffered() > 0 {
|
||||
conn.Close()
|
||||
return nil, wsReturnHTTPError(w, http.StatusBadRequest, "client sent data before handshake is complete")
|
||||
return nil, wsReturnHTTPError(w, r, http.StatusBadRequest, "client sent data before handshake is complete")
|
||||
}
|
||||
|
||||
var buf [1024]byte
|
||||
@@ -841,8 +841,8 @@ func wsPMCExtensionSupport(header http.Header, checkPMCOnly bool) (bool, bool) {
|
||||
|
||||
// Send an HTTP error with the given `status`` to the given http response writer `w`.
|
||||
// Return an error created based on the `reason` string.
|
||||
func wsReturnHTTPError(w http.ResponseWriter, status int, reason string) error {
|
||||
err := fmt.Errorf("websocket handshake error: %s", reason)
|
||||
func wsReturnHTTPError(w http.ResponseWriter, r *http.Request, status int, reason string) error {
|
||||
err := fmt.Errorf("%s - websocket handshake error: %s", r.RemoteAddr, reason)
|
||||
w.Header().Set("Sec-Websocket-Version", "13")
|
||||
http.Error(w, http.StatusText(status), status)
|
||||
return err
|
||||
|
||||
@@ -2382,7 +2382,7 @@ func TestWSServerReportUpgradeFailure(t *testing.T) {
|
||||
logger := &captureErrorLogger{errCh: make(chan string, 1)}
|
||||
s.SetLogger(logger, false, false)
|
||||
|
||||
addr := fmt.Sprintf("%s:%d", o.Websocket.Host, o.Websocket.Port)
|
||||
addr := fmt.Sprintf("127.0.0.1:%d", o.Websocket.Port)
|
||||
req := testWSCreateValidReq()
|
||||
req.URL, _ = url.Parse("wss://" + addr)
|
||||
|
||||
@@ -2417,6 +2417,11 @@ func TestWSServerReportUpgradeFailure(t *testing.T) {
|
||||
if !strings.Contains(e, "invalid value for header 'Connection'") {
|
||||
t.Fatalf("Unexpected error: %v", e)
|
||||
}
|
||||
// The client IP's local should be printed as a remote from server perspective.
|
||||
clientIP := wsc.LocalAddr().String()
|
||||
if !strings.HasPrefix(e, clientIP) {
|
||||
t.Fatalf("IP should have been logged, it was not: %v", e)
|
||||
}
|
||||
case <-time.After(time.Second):
|
||||
t.Fatalf("Should have timed-out")
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user