diff --git a/middlewares/errorpages/error_pages.go b/middlewares/errorpages/error_pages.go index 16da5584f..7480ccdc1 100644 --- a/middlewares/errorpages/error_pages.go +++ b/middlewares/errorpages/error_pages.go @@ -19,7 +19,10 @@ import ( ) // Compile time validation that the response recorder implements http interfaces correctly. -var _ middlewares.Stateful = &responseRecorderWithCloseNotify{} +var ( + _ middlewares.Stateful = &responseRecorderWithCloseNotify{} + _ middlewares.Stateful = &codeCatcherWithCloseNotify{} +) // Handler is a middleware that provides the custom error pages type Handler struct { @@ -76,26 +79,27 @@ func (h *Handler) ServeHTTP(w http.ResponseWriter, req *http.Request, next http. catcher := newCodeCatcher(w, h.httpCodeRanges) next.ServeHTTP(catcher, req) - if !catcher.isError { + if !catcher.isFilteredCode() { return } // check the recorder code against the configured http status code ranges + code := catcher.getCode() for _, block := range h.httpCodeRanges { - if catcher.code >= block[0] && catcher.code <= block[1] { - log.Errorf("Caught HTTP Status Code %d, returning error page", catcher.code) + if code >= block[0] && code <= block[1] { + log.Errorf("Caught HTTP Status Code %d, returning error page", code) var query string if len(h.backendQuery) > 0 { query = "/" + strings.TrimPrefix(h.backendQuery, "/") - query = strings.Replace(query, "{status}", strconv.Itoa(catcher.code), -1) + query = strings.Replace(query, "{status}", strconv.Itoa(code), -1) } pageReq, err := newRequest(h.backendURL + query) if err != nil { log.Error(err) - w.WriteHeader(catcher.code) - fmt.Fprint(w, http.StatusText(catcher.code)) + w.WriteHeader(code) + fmt.Fprint(w, http.StatusText(code)) return } @@ -105,7 +109,7 @@ func (h *Handler) ServeHTTP(w http.ResponseWriter, req *http.Request, next http. h.backendHandler.ServeHTTP(recorderErrorPage, pageReq.WithContext(req.Context())) utils.CopyHeaders(w.Header(), recorderErrorPage.Header()) - w.WriteHeader(catcher.code) + w.WriteHeader(code) w.Write(recorderErrorPage.GetBody().Bytes()) return } @@ -127,21 +131,39 @@ func newRequest(baseURL string) (*http.Request, error) { return req, nil } +type responseInterceptor interface { + http.ResponseWriter + http.Flusher + getCode() int + isFilteredCode() bool +} + // codeCatcher is a response writer that detects as soon as possible whether the // response is a code within the ranges of codes it watches for. If it is, it // simply drops the data from the response. Otherwise, it forwards it directly to // the original client (its responseWriter) without any buffering. type codeCatcher struct { - headerMap http.Header - code int - httpCodeRanges types.HTTPCodeRanges - firstWrite bool - isError bool - responseWriter http.ResponseWriter - err error + headerMap http.Header + code int + httpCodeRanges types.HTTPCodeRanges + firstWrite bool + caughtFilteredCode bool + responseWriter http.ResponseWriter + headersSent bool + err error } -func newCodeCatcher(rw http.ResponseWriter, httpCodeRanges types.HTTPCodeRanges) *codeCatcher { +type codeCatcherWithCloseNotify struct { + *codeCatcher +} + +// CloseNotify returns a channel that receives at most a +// single value (true) when the client connection has gone away. +func (cc *codeCatcherWithCloseNotify) CloseNotify() <-chan bool { + return cc.responseWriter.(http.CloseNotifier).CloseNotify() +} + +func newCodeCatcher(rw http.ResponseWriter, httpCodeRanges types.HTTPCodeRanges) responseInterceptor { catcher := &codeCatcher{ headerMap: make(http.Header), code: http.StatusOK, // If backend does not call WriteHeader on us, we consider it's a 200. @@ -149,6 +171,9 @@ func newCodeCatcher(rw http.ResponseWriter, httpCodeRanges types.HTTPCodeRanges) httpCodeRanges: httpCodeRanges, firstWrite: true, } + if _, ok := rw.(http.CloseNotifier); ok { + return &codeCatcherWithCloseNotify{catcher} + } return catcher } @@ -160,12 +185,19 @@ func (cc *codeCatcher) Header() http.Header { return cc.headerMap } +func (cc *codeCatcher) getCode() int { + return cc.code +} + +// isFilteredCode returns whether the codeCatcher received a response code among the ones it is watching, +// and for which the response should be deferred to the error handler. +func (cc *codeCatcher) isFilteredCode() bool { + return cc.caughtFilteredCode +} + func (cc *codeCatcher) Write(buf []byte) (int, error) { - if cc.err != nil { - return 0, cc.err - } if !cc.firstWrite { - if cc.isError { + if cc.caughtFilteredCode { // We don't care about the contents of the response, // since we want to serve the ones from the error page, // so we just drop them. @@ -173,25 +205,38 @@ func (cc *codeCatcher) Write(buf []byte) (int, error) { } return cc.responseWriter.Write(buf) } - - for _, block := range cc.httpCodeRanges { - if cc.code >= block[0] && cc.code <= block[1] { - cc.isError = true - break - } - } cc.firstWrite = false - if !cc.isError { - utils.CopyHeaders(cc.responseWriter.Header(), cc.Header()) - cc.responseWriter.WriteHeader(cc.code) - } else { + + // If WriteHeader was already called from the caller, this is a NOOP. + // Otherwise, cc.code is actually a 200 here. + cc.WriteHeader(cc.code) + + if cc.caughtFilteredCode { return len(buf), nil } return cc.responseWriter.Write(buf) } func (cc *codeCatcher) WriteHeader(code int) { + if cc.headersSent || cc.caughtFilteredCode { + return + } + cc.code = code + for _, block := range cc.httpCodeRanges { + if cc.code >= block[0] && cc.code <= block[1] { + cc.caughtFilteredCode = true + break + } + } + // it will be up to the other response recorder to send the headers, + // so it is out of our hands now. + if cc.caughtFilteredCode { + return + } + utils.CopyHeaders(cc.responseWriter.Header(), cc.Header()) + cc.responseWriter.WriteHeader(cc.code) + cc.headersSent = true } // Hijack hijacks the connection @@ -204,6 +249,10 @@ func (cc *codeCatcher) Hijack() (net.Conn, *bufio.ReadWriter, error) { // Flush sends any buffered data to the client. func (cc *codeCatcher) Flush() { + // If WriteHeader was already called from the caller, this is a NOOP. + // Otherwise, cc.code is actually a 200 here. + cc.WriteHeader(cc.code) + if flusher, ok := cc.responseWriter.(http.Flusher); ok { flusher.Flush() } diff --git a/middlewares/errorpages/error_pages_test.go b/middlewares/errorpages/error_pages_test.go index ecca07773..6c5505091 100644 --- a/middlewares/errorpages/error_pages_test.go +++ b/middlewares/errorpages/error_pages_test.go @@ -46,6 +46,18 @@ func TestHandler(t *testing.T) { assert.Contains(t, recorder.Body.String(), http.StatusText(http.StatusPartialContent)) }, }, + { + desc: "a 304, so no Write called", + errorPage: &types.ErrorPage{Backend: "error", Query: "/test", Status: []string{"500-501", "503-599"}}, + backendCode: http.StatusNotModified, + backendErrorHandler: http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + fmt.Fprintln(w, "whatever, should not be called") + }), + validate: func(t *testing.T, recorder *httptest.ResponseRecorder) { + assert.Equal(t, http.StatusNotModified, recorder.Code, "HTTP status") + assert.Contains(t, recorder.Body.String(), "") + }, + }, { desc: "in the range", errorPage: &types.ErrorPage{Backend: "error", Query: "/test", Status: []string{"500-501", "503-599"}}, @@ -120,6 +132,9 @@ func TestHandler(t *testing.T) { handler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { w.WriteHeader(test.backendCode) + if test.backendCode == http.StatusNotModified { + return + } fmt.Fprintln(w, http.StatusText(test.backendCode)) })