forked from Ivasoft/traefik
error pages: do not buffer response when it's not an error
This commit is contained in:
@@ -74,25 +74,28 @@ func (h *Handler) ServeHTTP(w http.ResponseWriter, req *http.Request, next http.
|
||||
return
|
||||
}
|
||||
|
||||
recorder := newResponseRecorder(w)
|
||||
next.ServeHTTP(recorder, req)
|
||||
catcher := newCodeCatcher(w, h.httpCodeRanges)
|
||||
next.ServeHTTP(catcher, req)
|
||||
if !catcher.isError {
|
||||
return
|
||||
}
|
||||
|
||||
// check the recorder code against the configured http status code ranges
|
||||
for _, block := range h.httpCodeRanges {
|
||||
if recorder.GetCode() >= block[0] && recorder.GetCode() <= block[1] {
|
||||
log.Errorf("Caught HTTP Status Code %d, returning error page", recorder.GetCode())
|
||||
if catcher.code >= block[0] && catcher.code <= block[1] {
|
||||
log.Errorf("Caught HTTP Status Code %d, returning error page", catcher.code)
|
||||
|
||||
var query string
|
||||
if len(h.backendQuery) > 0 {
|
||||
query = "/" + strings.TrimPrefix(h.backendQuery, "/")
|
||||
query = strings.Replace(query, "{status}", strconv.Itoa(recorder.GetCode()), -1)
|
||||
query = strings.Replace(query, "{status}", strconv.Itoa(catcher.code), -1)
|
||||
}
|
||||
|
||||
pageReq, err := newRequest(h.backendURL + query)
|
||||
if err != nil {
|
||||
log.Error(err)
|
||||
w.WriteHeader(recorder.GetCode())
|
||||
fmt.Fprint(w, http.StatusText(recorder.GetCode()))
|
||||
w.WriteHeader(catcher.code)
|
||||
fmt.Fprint(w, http.StatusText(catcher.code))
|
||||
return
|
||||
}
|
||||
|
||||
@@ -102,16 +105,11 @@ func (h *Handler) ServeHTTP(w http.ResponseWriter, req *http.Request, next http.
|
||||
h.backendHandler.ServeHTTP(recorderErrorPage, pageReq.WithContext(req.Context()))
|
||||
|
||||
utils.CopyHeaders(w.Header(), recorderErrorPage.Header())
|
||||
w.WriteHeader(recorder.GetCode())
|
||||
w.WriteHeader(catcher.code)
|
||||
w.Write(recorderErrorPage.GetBody().Bytes())
|
||||
return
|
||||
}
|
||||
}
|
||||
|
||||
// did not catch a configured status code so proceed with the request
|
||||
utils.CopyHeaders(w.Header(), recorder.Header())
|
||||
w.WriteHeader(recorder.GetCode())
|
||||
w.Write(recorder.GetBody().Bytes())
|
||||
}
|
||||
|
||||
func newRequest(baseURL string) (*http.Request, error) {
|
||||
@@ -129,6 +127,73 @@ func newRequest(baseURL string) (*http.Request, error) {
|
||||
return req, nil
|
||||
}
|
||||
|
||||
// codeCatcher is a response writer that detects as soon as possible whether the
|
||||
// response is a code within the ranges of codes it watches for. If it is, it
|
||||
// simply drops the data from the response. Otherwise, it forwards it directly to
|
||||
// the original client (its responseWriter) without any buffering.
|
||||
type codeCatcher struct {
|
||||
headerMap http.Header
|
||||
code int
|
||||
httpCodeRanges types.HTTPCodeRanges
|
||||
firstWrite bool
|
||||
isError bool
|
||||
responseWriter http.ResponseWriter
|
||||
err error
|
||||
}
|
||||
|
||||
func newCodeCatcher(rw http.ResponseWriter, httpCodeRanges types.HTTPCodeRanges) *codeCatcher {
|
||||
catcher := &codeCatcher{
|
||||
headerMap: make(http.Header),
|
||||
code: http.StatusOK, // If backend does not call WriteHeader on us, we consider it's a 200.
|
||||
responseWriter: rw,
|
||||
httpCodeRanges: httpCodeRanges,
|
||||
firstWrite: true,
|
||||
}
|
||||
return catcher
|
||||
}
|
||||
|
||||
func (cc *codeCatcher) Header() http.Header {
|
||||
if cc.headerMap == nil {
|
||||
cc.headerMap = make(http.Header)
|
||||
}
|
||||
|
||||
return cc.headerMap
|
||||
}
|
||||
|
||||
func (cc *codeCatcher) Write(buf []byte) (int, error) {
|
||||
if cc.err != nil {
|
||||
return 0, cc.err
|
||||
}
|
||||
if !cc.firstWrite {
|
||||
if cc.isError {
|
||||
// We don't care about the contents of the response,
|
||||
// since we want to serve the ones from the error page,
|
||||
// so we just drop them.
|
||||
return len(buf), nil
|
||||
}
|
||||
return cc.responseWriter.Write(buf)
|
||||
}
|
||||
|
||||
for _, block := range cc.httpCodeRanges {
|
||||
if cc.code >= block[0] && cc.code <= block[1] {
|
||||
cc.isError = true
|
||||
break
|
||||
}
|
||||
}
|
||||
cc.firstWrite = false
|
||||
if !cc.isError {
|
||||
utils.CopyHeaders(cc.responseWriter.Header(), cc.Header())
|
||||
cc.responseWriter.WriteHeader(cc.code)
|
||||
} else {
|
||||
return len(buf), nil
|
||||
}
|
||||
return cc.responseWriter.Write(buf)
|
||||
}
|
||||
|
||||
func (cc *codeCatcher) WriteHeader(code int) {
|
||||
cc.code = code
|
||||
}
|
||||
|
||||
type responseRecorder interface {
|
||||
http.ResponseWriter
|
||||
http.Flusher
|
||||
|
||||
@@ -34,6 +34,18 @@ func TestHandler(t *testing.T) {
|
||||
assert.Contains(t, recorder.Body.String(), http.StatusText(http.StatusOK))
|
||||
},
|
||||
},
|
||||
{
|
||||
desc: "no error, but not a 200",
|
||||
errorPage: &types.ErrorPage{Backend: "error", Query: "/test", Status: []string{"500-501", "503-599"}},
|
||||
backendCode: http.StatusPartialContent,
|
||||
backendErrorHandler: http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
fmt.Fprintln(w, "My error page.")
|
||||
}),
|
||||
validate: func(t *testing.T, recorder *httptest.ResponseRecorder) {
|
||||
assert.Equal(t, http.StatusPartialContent, recorder.Code, "HTTP status")
|
||||
assert.Contains(t, recorder.Body.String(), http.StatusText(http.StatusPartialContent))
|
||||
},
|
||||
},
|
||||
{
|
||||
desc: "in the range",
|
||||
errorPage: &types.ErrorPage{Backend: "error", Query: "/test", Status: []string{"500-501", "503-599"}},
|
||||
|
||||
Reference in New Issue
Block a user