[IMPROVED] Websocket: Add client IP in websocket upgrade failures

The error message would now look like this:
```
[8672] 2021/11/01 10:56:50.251985 [ERR] [::1]:59279 - websocket handshake error: invalid value for header 'Upgrade'
```

(without this change the part `[::1]:59279 - ` would not be present)

Signed-off-by: Ivan Kozlovic <ivan@synadia.com>
This commit is contained in:
Ivan Kozlovic
2021-11-01 10:54:22 -06:00
parent 530ea6a5c3
commit dbfff14d3b
2 changed files with 17 additions and 12 deletions

View File

@@ -692,33 +692,33 @@ func (s *Server) wsUpgrade(w http.ResponseWriter, r *http.Request) (*wsUpgradeRe
// From https://tools.ietf.org/html/rfc6455#section-4.2.1
// Point 1.
if r.Method != "GET" {
return nil, wsReturnHTTPError(w, http.StatusMethodNotAllowed, "request method must be GET")
return nil, wsReturnHTTPError(w, r, http.StatusMethodNotAllowed, "request method must be GET")
}
// Point 2.
if r.Host == _EMPTY_ {
return nil, wsReturnHTTPError(w, http.StatusBadRequest, "'Host' missing in request")
return nil, wsReturnHTTPError(w, r, http.StatusBadRequest, "'Host' missing in request")
}
// Point 3.
if !wsHeaderContains(r.Header, "Upgrade", "websocket") {
return nil, wsReturnHTTPError(w, http.StatusBadRequest, "invalid value for header 'Upgrade'")
return nil, wsReturnHTTPError(w, r, http.StatusBadRequest, "invalid value for header 'Upgrade'")
}
// Point 4.
if !wsHeaderContains(r.Header, "Connection", "Upgrade") {
return nil, wsReturnHTTPError(w, http.StatusBadRequest, "invalid value for header 'Connection'")
return nil, wsReturnHTTPError(w, r, http.StatusBadRequest, "invalid value for header 'Connection'")
}
// Point 5.
key := r.Header.Get("Sec-Websocket-Key")
if key == _EMPTY_ {
return nil, wsReturnHTTPError(w, http.StatusBadRequest, "key missing")
return nil, wsReturnHTTPError(w, r, http.StatusBadRequest, "key missing")
}
// Point 6.
if !wsHeaderContains(r.Header, "Sec-Websocket-Version", "13") {
return nil, wsReturnHTTPError(w, http.StatusBadRequest, "invalid version")
return nil, wsReturnHTTPError(w, r, http.StatusBadRequest, "invalid version")
}
// Others are optional
// Point 7.
if err := s.websocket.checkOrigin(r); err != nil {
return nil, wsReturnHTTPError(w, http.StatusForbidden, fmt.Sprintf("origin not allowed: %v", err))
return nil, wsReturnHTTPError(w, r, http.StatusForbidden, fmt.Sprintf("origin not allowed: %v", err))
}
// Point 8.
// We don't have protocols, so ignore.
@@ -738,11 +738,11 @@ func (s *Server) wsUpgrade(w http.ResponseWriter, r *http.Request) (*wsUpgradeRe
if conn != nil {
conn.Close()
}
return nil, wsReturnHTTPError(w, http.StatusInternalServerError, err.Error())
return nil, wsReturnHTTPError(w, r, http.StatusInternalServerError, err.Error())
}
if brw.Reader.Buffered() > 0 {
conn.Close()
return nil, wsReturnHTTPError(w, http.StatusBadRequest, "client sent data before handshake is complete")
return nil, wsReturnHTTPError(w, r, http.StatusBadRequest, "client sent data before handshake is complete")
}
var buf [1024]byte
@@ -841,8 +841,8 @@ func wsPMCExtensionSupport(header http.Header, checkPMCOnly bool) (bool, bool) {
// Send an HTTP error with the given `status`` to the given http response writer `w`.
// Return an error created based on the `reason` string.
func wsReturnHTTPError(w http.ResponseWriter, status int, reason string) error {
err := fmt.Errorf("websocket handshake error: %s", reason)
func wsReturnHTTPError(w http.ResponseWriter, r *http.Request, status int, reason string) error {
err := fmt.Errorf("%s - websocket handshake error: %s", r.RemoteAddr, reason)
w.Header().Set("Sec-Websocket-Version", "13")
http.Error(w, http.StatusText(status), status)
return err

View File

@@ -2382,7 +2382,7 @@ func TestWSServerReportUpgradeFailure(t *testing.T) {
logger := &captureErrorLogger{errCh: make(chan string, 1)}
s.SetLogger(logger, false, false)
addr := fmt.Sprintf("%s:%d", o.Websocket.Host, o.Websocket.Port)
addr := fmt.Sprintf("127.0.0.1:%d", o.Websocket.Port)
req := testWSCreateValidReq()
req.URL, _ = url.Parse("wss://" + addr)
@@ -2417,6 +2417,11 @@ func TestWSServerReportUpgradeFailure(t *testing.T) {
if !strings.Contains(e, "invalid value for header 'Connection'") {
t.Fatalf("Unexpected error: %v", e)
}
// The client IP's local should be printed as a remote from server perspective.
clientIP := wsc.LocalAddr().String()
if !strings.HasPrefix(e, clientIP) {
t.Fatalf("IP should have been logged, it was not: %v", e)
}
case <-time.After(time.Second):
t.Fatalf("Should have timed-out")
}