error pages: do not buffer response when it's not an error

This commit is contained in:
mpl
2019-09-12 16:20:05 +02:00
committed by Traefiker Bot
parent ffd1f122de
commit 753d173965
2 changed files with 90 additions and 13 deletions

View File

@@ -74,25 +74,28 @@ func (h *Handler) ServeHTTP(w http.ResponseWriter, req *http.Request, next http.
return
}
recorder := newResponseRecorder(w)
next.ServeHTTP(recorder, req)
catcher := newCodeCatcher(w, h.httpCodeRanges)
next.ServeHTTP(catcher, req)
if !catcher.isError {
return
}
// check the recorder code against the configured http status code ranges
for _, block := range h.httpCodeRanges {
if recorder.GetCode() >= block[0] && recorder.GetCode() <= block[1] {
log.Errorf("Caught HTTP Status Code %d, returning error page", recorder.GetCode())
if catcher.code >= block[0] && catcher.code <= block[1] {
log.Errorf("Caught HTTP Status Code %d, returning error page", catcher.code)
var query string
if len(h.backendQuery) > 0 {
query = "/" + strings.TrimPrefix(h.backendQuery, "/")
query = strings.Replace(query, "{status}", strconv.Itoa(recorder.GetCode()), -1)
query = strings.Replace(query, "{status}", strconv.Itoa(catcher.code), -1)
}
pageReq, err := newRequest(h.backendURL + query)
if err != nil {
log.Error(err)
w.WriteHeader(recorder.GetCode())
fmt.Fprint(w, http.StatusText(recorder.GetCode()))
w.WriteHeader(catcher.code)
fmt.Fprint(w, http.StatusText(catcher.code))
return
}
@@ -102,16 +105,11 @@ func (h *Handler) ServeHTTP(w http.ResponseWriter, req *http.Request, next http.
h.backendHandler.ServeHTTP(recorderErrorPage, pageReq.WithContext(req.Context()))
utils.CopyHeaders(w.Header(), recorderErrorPage.Header())
w.WriteHeader(recorder.GetCode())
w.WriteHeader(catcher.code)
w.Write(recorderErrorPage.GetBody().Bytes())
return
}
}
// did not catch a configured status code so proceed with the request
utils.CopyHeaders(w.Header(), recorder.Header())
w.WriteHeader(recorder.GetCode())
w.Write(recorder.GetBody().Bytes())
}
func newRequest(baseURL string) (*http.Request, error) {
@@ -129,6 +127,73 @@ func newRequest(baseURL string) (*http.Request, error) {
return req, nil
}
// codeCatcher is a response writer that detects as soon as possible whether the
// response is a code within the ranges of codes it watches for. If it is, it
// simply drops the data from the response. Otherwise, it forwards it directly to
// the original client (its responseWriter) without any buffering.
type codeCatcher struct {
headerMap http.Header
code int
httpCodeRanges types.HTTPCodeRanges
firstWrite bool
isError bool
responseWriter http.ResponseWriter
err error
}
func newCodeCatcher(rw http.ResponseWriter, httpCodeRanges types.HTTPCodeRanges) *codeCatcher {
catcher := &codeCatcher{
headerMap: make(http.Header),
code: http.StatusOK, // If backend does not call WriteHeader on us, we consider it's a 200.
responseWriter: rw,
httpCodeRanges: httpCodeRanges,
firstWrite: true,
}
return catcher
}
func (cc *codeCatcher) Header() http.Header {
if cc.headerMap == nil {
cc.headerMap = make(http.Header)
}
return cc.headerMap
}
func (cc *codeCatcher) Write(buf []byte) (int, error) {
if cc.err != nil {
return 0, cc.err
}
if !cc.firstWrite {
if cc.isError {
// We don't care about the contents of the response,
// since we want to serve the ones from the error page,
// so we just drop them.
return len(buf), nil
}
return cc.responseWriter.Write(buf)
}
for _, block := range cc.httpCodeRanges {
if cc.code >= block[0] && cc.code <= block[1] {
cc.isError = true
break
}
}
cc.firstWrite = false
if !cc.isError {
utils.CopyHeaders(cc.responseWriter.Header(), cc.Header())
cc.responseWriter.WriteHeader(cc.code)
} else {
return len(buf), nil
}
return cc.responseWriter.Write(buf)
}
func (cc *codeCatcher) WriteHeader(code int) {
cc.code = code
}
type responseRecorder interface {
http.ResponseWriter
http.Flusher

View File

@@ -34,6 +34,18 @@ func TestHandler(t *testing.T) {
assert.Contains(t, recorder.Body.String(), http.StatusText(http.StatusOK))
},
},
{
desc: "no error, but not a 200",
errorPage: &types.ErrorPage{Backend: "error", Query: "/test", Status: []string{"500-501", "503-599"}},
backendCode: http.StatusPartialContent,
backendErrorHandler: http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
fmt.Fprintln(w, "My error page.")
}),
validate: func(t *testing.T, recorder *httptest.ResponseRecorder) {
assert.Equal(t, http.StatusPartialContent, recorder.Code, "HTTP status")
assert.Contains(t, recorder.Body.String(), http.StatusText(http.StatusPartialContent))
},
},
{
desc: "in the range",
errorPage: &types.ErrorPage{Backend: "error", Query: "/test", Status: []string{"500-501", "503-599"}},