Actually send header and code during WriteHeader, if needed

Co-authored-by: Julien Salleyron <julien.salleyron@gmail.com>
This commit is contained in:
mpl
2019-09-20 18:42:03 +02:00
committed by Traefiker Bot
parent 226f20b626
commit 216710864e
2 changed files with 95 additions and 31 deletions

View File

@@ -19,7 +19,10 @@ import (
)
// Compile time validation that the response recorder implements http interfaces correctly.
var _ middlewares.Stateful = &responseRecorderWithCloseNotify{}
var (
_ middlewares.Stateful = &responseRecorderWithCloseNotify{}
_ middlewares.Stateful = &codeCatcherWithCloseNotify{}
)
// Handler is a middleware that provides the custom error pages
type Handler struct {
@@ -76,26 +79,27 @@ func (h *Handler) ServeHTTP(w http.ResponseWriter, req *http.Request, next http.
catcher := newCodeCatcher(w, h.httpCodeRanges)
next.ServeHTTP(catcher, req)
if !catcher.isError {
if !catcher.isFilteredCode() {
return
}
// check the recorder code against the configured http status code ranges
code := catcher.getCode()
for _, block := range h.httpCodeRanges {
if catcher.code >= block[0] && catcher.code <= block[1] {
log.Errorf("Caught HTTP Status Code %d, returning error page", catcher.code)
if code >= block[0] && code <= block[1] {
log.Errorf("Caught HTTP Status Code %d, returning error page", code)
var query string
if len(h.backendQuery) > 0 {
query = "/" + strings.TrimPrefix(h.backendQuery, "/")
query = strings.Replace(query, "{status}", strconv.Itoa(catcher.code), -1)
query = strings.Replace(query, "{status}", strconv.Itoa(code), -1)
}
pageReq, err := newRequest(h.backendURL + query)
if err != nil {
log.Error(err)
w.WriteHeader(catcher.code)
fmt.Fprint(w, http.StatusText(catcher.code))
w.WriteHeader(code)
fmt.Fprint(w, http.StatusText(code))
return
}
@@ -105,7 +109,7 @@ func (h *Handler) ServeHTTP(w http.ResponseWriter, req *http.Request, next http.
h.backendHandler.ServeHTTP(recorderErrorPage, pageReq.WithContext(req.Context()))
utils.CopyHeaders(w.Header(), recorderErrorPage.Header())
w.WriteHeader(catcher.code)
w.WriteHeader(code)
w.Write(recorderErrorPage.GetBody().Bytes())
return
}
@@ -127,21 +131,39 @@ func newRequest(baseURL string) (*http.Request, error) {
return req, nil
}
type responseInterceptor interface {
http.ResponseWriter
http.Flusher
getCode() int
isFilteredCode() bool
}
// codeCatcher is a response writer that detects as soon as possible whether the
// response is a code within the ranges of codes it watches for. If it is, it
// simply drops the data from the response. Otherwise, it forwards it directly to
// the original client (its responseWriter) without any buffering.
type codeCatcher struct {
headerMap http.Header
code int
httpCodeRanges types.HTTPCodeRanges
firstWrite bool
isError bool
responseWriter http.ResponseWriter
err error
headerMap http.Header
code int
httpCodeRanges types.HTTPCodeRanges
firstWrite bool
caughtFilteredCode bool
responseWriter http.ResponseWriter
headersSent bool
err error
}
func newCodeCatcher(rw http.ResponseWriter, httpCodeRanges types.HTTPCodeRanges) *codeCatcher {
type codeCatcherWithCloseNotify struct {
*codeCatcher
}
// CloseNotify returns a channel that receives at most a
// single value (true) when the client connection has gone away.
func (cc *codeCatcherWithCloseNotify) CloseNotify() <-chan bool {
return cc.responseWriter.(http.CloseNotifier).CloseNotify()
}
func newCodeCatcher(rw http.ResponseWriter, httpCodeRanges types.HTTPCodeRanges) responseInterceptor {
catcher := &codeCatcher{
headerMap: make(http.Header),
code: http.StatusOK, // If backend does not call WriteHeader on us, we consider it's a 200.
@@ -149,6 +171,9 @@ func newCodeCatcher(rw http.ResponseWriter, httpCodeRanges types.HTTPCodeRanges)
httpCodeRanges: httpCodeRanges,
firstWrite: true,
}
if _, ok := rw.(http.CloseNotifier); ok {
return &codeCatcherWithCloseNotify{catcher}
}
return catcher
}
@@ -160,12 +185,19 @@ func (cc *codeCatcher) Header() http.Header {
return cc.headerMap
}
func (cc *codeCatcher) getCode() int {
return cc.code
}
// isFilteredCode returns whether the codeCatcher received a response code among the ones it is watching,
// and for which the response should be deferred to the error handler.
func (cc *codeCatcher) isFilteredCode() bool {
return cc.caughtFilteredCode
}
func (cc *codeCatcher) Write(buf []byte) (int, error) {
if cc.err != nil {
return 0, cc.err
}
if !cc.firstWrite {
if cc.isError {
if cc.caughtFilteredCode {
// We don't care about the contents of the response,
// since we want to serve the ones from the error page,
// so we just drop them.
@@ -173,25 +205,38 @@ func (cc *codeCatcher) Write(buf []byte) (int, error) {
}
return cc.responseWriter.Write(buf)
}
for _, block := range cc.httpCodeRanges {
if cc.code >= block[0] && cc.code <= block[1] {
cc.isError = true
break
}
}
cc.firstWrite = false
if !cc.isError {
utils.CopyHeaders(cc.responseWriter.Header(), cc.Header())
cc.responseWriter.WriteHeader(cc.code)
} else {
// If WriteHeader was already called from the caller, this is a NOOP.
// Otherwise, cc.code is actually a 200 here.
cc.WriteHeader(cc.code)
if cc.caughtFilteredCode {
return len(buf), nil
}
return cc.responseWriter.Write(buf)
}
func (cc *codeCatcher) WriteHeader(code int) {
if cc.headersSent || cc.caughtFilteredCode {
return
}
cc.code = code
for _, block := range cc.httpCodeRanges {
if cc.code >= block[0] && cc.code <= block[1] {
cc.caughtFilteredCode = true
break
}
}
// it will be up to the other response recorder to send the headers,
// so it is out of our hands now.
if cc.caughtFilteredCode {
return
}
utils.CopyHeaders(cc.responseWriter.Header(), cc.Header())
cc.responseWriter.WriteHeader(cc.code)
cc.headersSent = true
}
// Hijack hijacks the connection
@@ -204,6 +249,10 @@ func (cc *codeCatcher) Hijack() (net.Conn, *bufio.ReadWriter, error) {
// Flush sends any buffered data to the client.
func (cc *codeCatcher) Flush() {
// If WriteHeader was already called from the caller, this is a NOOP.
// Otherwise, cc.code is actually a 200 here.
cc.WriteHeader(cc.code)
if flusher, ok := cc.responseWriter.(http.Flusher); ok {
flusher.Flush()
}

View File

@@ -46,6 +46,18 @@ func TestHandler(t *testing.T) {
assert.Contains(t, recorder.Body.String(), http.StatusText(http.StatusPartialContent))
},
},
{
desc: "a 304, so no Write called",
errorPage: &types.ErrorPage{Backend: "error", Query: "/test", Status: []string{"500-501", "503-599"}},
backendCode: http.StatusNotModified,
backendErrorHandler: http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
fmt.Fprintln(w, "whatever, should not be called")
}),
validate: func(t *testing.T, recorder *httptest.ResponseRecorder) {
assert.Equal(t, http.StatusNotModified, recorder.Code, "HTTP status")
assert.Contains(t, recorder.Body.String(), "")
},
},
{
desc: "in the range",
errorPage: &types.ErrorPage{Backend: "error", Query: "/test", Status: []string{"500-501", "503-599"}},
@@ -120,6 +132,9 @@ func TestHandler(t *testing.T) {
handler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
w.WriteHeader(test.backendCode)
if test.backendCode == http.StatusNotModified {
return
}
fmt.Fprintln(w, http.StatusText(test.backendCode))
})