diff --git a/middlewares/retry.go b/middlewares/retry.go index 79a05d900..5166ca8f6 100644 --- a/middlewares/retry.go +++ b/middlewares/retry.go @@ -44,9 +44,7 @@ func (retry *Retry) ServeHTTP(rw http.ResponseWriter, r *http.Request) { attempts := 1 for { - attemptsExhausted := attempts >= retry.attempts - - shouldRetry := !attemptsExhausted + shouldRetry := attempts < retry.attempts retryResponseWriter := newRetryResponseWriter(rw, shouldRetry) // Disable retries when the backend already received request data @@ -99,6 +97,7 @@ type retryResponseWriter interface { func newRetryResponseWriter(rw http.ResponseWriter, shouldRetry bool) retryResponseWriter { responseWriter := &retryResponseWriterWithoutCloseNotify{ responseWriter: rw, + headers: make(http.Header), shouldRetry: shouldRetry, } if _, ok := rw.(http.CloseNotifier); ok { @@ -109,6 +108,7 @@ func newRetryResponseWriter(rw http.ResponseWriter, shouldRetry bool) retryRespo type retryResponseWriterWithoutCloseNotify struct { responseWriter http.ResponseWriter + headers http.Header shouldRetry bool } @@ -121,10 +121,7 @@ func (rr *retryResponseWriterWithoutCloseNotify) DisableRetries() { } func (rr *retryResponseWriterWithoutCloseNotify) Header() http.Header { - if rr.ShouldRetry() { - return make(http.Header) - } - return rr.responseWriter.Header() + return rr.headers } func (rr *retryResponseWriterWithoutCloseNotify) Write(buf []byte) (int, error) { @@ -147,6 +144,16 @@ func (rr *retryResponseWriterWithoutCloseNotify) WriteHeader(code int) { if rr.ShouldRetry() { return } + + // In that case retry case is set to false which means we at least managed + // to write headers to the backend : we are not going to perform any further retry. + // So it is now safe to alter current response headers with headers collected during + // the latest try before writing headers to client. + headers := rr.responseWriter.Header() + for header, value := range rr.headers { + headers[header] = value + } + rr.responseWriter.WriteHeader(code) } diff --git a/middlewares/retry_test.go b/middlewares/retry_test.go index 9c51c567f..30160be33 100644 --- a/middlewares/retry_test.go +++ b/middlewares/retry_test.go @@ -1,8 +1,10 @@ package middlewares import ( + "fmt" "net/http" "net/http/httptest" + "net/http/httptrace" "strings" "testing" @@ -256,3 +258,45 @@ func TestRetryWithFlush(t *testing.T) { t.Errorf("Wrong body %q want %q", responseRecorder.Body.String(), "FULL DATA") } } + +func TestMultipleRetriesShouldNotLooseHeaders(t *testing.T) { + attempt := 0 + expectedHeaderName := "X-Foo-Test-2" + expectedHeaderValue := "bar" + + next := http.HandlerFunc(func(rw http.ResponseWriter, req *http.Request) { + headerName := fmt.Sprintf("X-Foo-Test-%d", attempt) + rw.Header().Add(headerName, expectedHeaderValue) + if attempt < 2 { + attempt++ + return + } + + // Request has been successfully written to backend + trace := httptrace.ContextClientTrace(req.Context()) + trace.WroteHeaders() + + // And we decide to answer to client + rw.WriteHeader(http.StatusNoContent) + }) + + retry := NewRetry(3, next, &countingRetryListener{}) + responseRecorder := httptest.NewRecorder() + retry.ServeHTTP(responseRecorder, &http.Request{}) + + headerValue := responseRecorder.Header().Get(expectedHeaderName) + + // Validate if we have the correct header + if headerValue != expectedHeaderValue { + t.Errorf("Expected to have %s for header %s, got %s", expectedHeaderValue, expectedHeaderName, headerValue) + } + + // Validate that we don't have headers from previous attempts + for i := 0; i < attempt; i++ { + headerName := fmt.Sprintf("X-Foo-Test-%d", i) + headerValue = responseRecorder.Header().Get("headerName") + if headerValue != "" { + t.Errorf("Expected no value for header %s, got %s", headerName, headerValue) + } + } +}