From 753d1739659d5561915849c8ec9cd2251410bf89 Mon Sep 17 00:00:00 2001 From: mpl Date: Thu, 12 Sep 2019 16:20:05 +0200 Subject: [PATCH] error pages: do not buffer response when it's not an error --- middlewares/errorpages/error_pages.go | 91 ++++++++++++++++++---- middlewares/errorpages/error_pages_test.go | 12 +++ 2 files changed, 90 insertions(+), 13 deletions(-) diff --git a/middlewares/errorpages/error_pages.go b/middlewares/errorpages/error_pages.go index 10f241bb4..2e5b37556 100644 --- a/middlewares/errorpages/error_pages.go +++ b/middlewares/errorpages/error_pages.go @@ -74,25 +74,28 @@ func (h *Handler) ServeHTTP(w http.ResponseWriter, req *http.Request, next http. return } - recorder := newResponseRecorder(w) - next.ServeHTTP(recorder, req) + catcher := newCodeCatcher(w, h.httpCodeRanges) + next.ServeHTTP(catcher, req) + if !catcher.isError { + return + } // check the recorder code against the configured http status code ranges for _, block := range h.httpCodeRanges { - if recorder.GetCode() >= block[0] && recorder.GetCode() <= block[1] { - log.Errorf("Caught HTTP Status Code %d, returning error page", recorder.GetCode()) + if catcher.code >= block[0] && catcher.code <= block[1] { + log.Errorf("Caught HTTP Status Code %d, returning error page", catcher.code) var query string if len(h.backendQuery) > 0 { query = "/" + strings.TrimPrefix(h.backendQuery, "/") - query = strings.Replace(query, "{status}", strconv.Itoa(recorder.GetCode()), -1) + query = strings.Replace(query, "{status}", strconv.Itoa(catcher.code), -1) } pageReq, err := newRequest(h.backendURL + query) if err != nil { log.Error(err) - w.WriteHeader(recorder.GetCode()) - fmt.Fprint(w, http.StatusText(recorder.GetCode())) + w.WriteHeader(catcher.code) + fmt.Fprint(w, http.StatusText(catcher.code)) return } @@ -102,16 +105,11 @@ func (h *Handler) ServeHTTP(w http.ResponseWriter, req *http.Request, next http. h.backendHandler.ServeHTTP(recorderErrorPage, pageReq.WithContext(req.Context())) utils.CopyHeaders(w.Header(), recorderErrorPage.Header()) - w.WriteHeader(recorder.GetCode()) + w.WriteHeader(catcher.code) w.Write(recorderErrorPage.GetBody().Bytes()) return } } - - // did not catch a configured status code so proceed with the request - utils.CopyHeaders(w.Header(), recorder.Header()) - w.WriteHeader(recorder.GetCode()) - w.Write(recorder.GetBody().Bytes()) } func newRequest(baseURL string) (*http.Request, error) { @@ -129,6 +127,73 @@ func newRequest(baseURL string) (*http.Request, error) { return req, nil } +// codeCatcher is a response writer that detects as soon as possible whether the +// response is a code within the ranges of codes it watches for. If it is, it +// simply drops the data from the response. Otherwise, it forwards it directly to +// the original client (its responseWriter) without any buffering. +type codeCatcher struct { + headerMap http.Header + code int + httpCodeRanges types.HTTPCodeRanges + firstWrite bool + isError bool + responseWriter http.ResponseWriter + err error +} + +func newCodeCatcher(rw http.ResponseWriter, httpCodeRanges types.HTTPCodeRanges) *codeCatcher { + catcher := &codeCatcher{ + headerMap: make(http.Header), + code: http.StatusOK, // If backend does not call WriteHeader on us, we consider it's a 200. + responseWriter: rw, + httpCodeRanges: httpCodeRanges, + firstWrite: true, + } + return catcher +} + +func (cc *codeCatcher) Header() http.Header { + if cc.headerMap == nil { + cc.headerMap = make(http.Header) + } + + return cc.headerMap +} + +func (cc *codeCatcher) Write(buf []byte) (int, error) { + if cc.err != nil { + return 0, cc.err + } + if !cc.firstWrite { + if cc.isError { + // We don't care about the contents of the response, + // since we want to serve the ones from the error page, + // so we just drop them. + return len(buf), nil + } + return cc.responseWriter.Write(buf) + } + + for _, block := range cc.httpCodeRanges { + if cc.code >= block[0] && cc.code <= block[1] { + cc.isError = true + break + } + } + cc.firstWrite = false + if !cc.isError { + utils.CopyHeaders(cc.responseWriter.Header(), cc.Header()) + cc.responseWriter.WriteHeader(cc.code) + } else { + return len(buf), nil + } + return cc.responseWriter.Write(buf) +} + +func (cc *codeCatcher) WriteHeader(code int) { + cc.code = code +} + type responseRecorder interface { http.ResponseWriter http.Flusher diff --git a/middlewares/errorpages/error_pages_test.go b/middlewares/errorpages/error_pages_test.go index 9cf19d87d..ecca07773 100644 --- a/middlewares/errorpages/error_pages_test.go +++ b/middlewares/errorpages/error_pages_test.go @@ -34,6 +34,18 @@ func TestHandler(t *testing.T) { assert.Contains(t, recorder.Body.String(), http.StatusText(http.StatusOK)) }, }, + { + desc: "no error, but not a 200", + errorPage: &types.ErrorPage{Backend: "error", Query: "/test", Status: []string{"500-501", "503-599"}}, + backendCode: http.StatusPartialContent, + backendErrorHandler: http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + fmt.Fprintln(w, "My error page.") + }), + validate: func(t *testing.T, recorder *httptest.ResponseRecorder) { + assert.Equal(t, http.StatusPartialContent, recorder.Code, "HTTP status") + assert.Contains(t, recorder.Body.String(), http.StatusText(http.StatusPartialContent)) + }, + }, { desc: "in the range", errorPage: &types.ErrorPage{Backend: "error", Query: "/test", Status: []string{"500-501", "503-599"}},