From a610f0b2a1864a048be20c88c4623458a948dcfd Mon Sep 17 00:00:00 2001 From: Julien Salleyron Date: Mon, 23 Mar 2020 15:18:04 +0100 Subject: [PATCH] Force http/1.1 for upgrade (Traefik v1) --- server/server.go | 5 +++- server/server_configuration.go | 6 ++++ server/server_loadbalancer.go | 54 ++++++++++++++++++---------------- server/smart_roundtripper.go | 43 +++++++++++++++++++++++++++ 4 files changed, 81 insertions(+), 27 deletions(-) create mode 100644 server/smart_roundtripper.go diff --git a/server/server.go b/server/server.go index a98e7348d..14056681d 100644 --- a/server/server.go +++ b/server/server.go @@ -218,7 +218,10 @@ func NewServer(globalConfiguration configuration.GlobalConfiguration, provider p log.Errorf("failed to create HTTP transport: %v", err) } - server.defaultForwardingRoundTripper = transport + server.defaultForwardingRoundTripper, err = newSmartRoundTripper(transport) + if err != nil { + log.Errorf("Failed to create HTTP transport: %v", err) + } server.tracingMiddleware = globalConfiguration.Tracing if server.tracingMiddleware != nil && server.tracingMiddleware.Backend != "" { diff --git a/server/server_configuration.go b/server/server_configuration.go index 3dbc7f08e..d5c96f4f7 100644 --- a/server/server_configuration.go +++ b/server/server_configuration.go @@ -237,11 +237,17 @@ func (s *Server) buildForwarder(entryPointName string, entryPoint *configuration } } + var tlsConfig *tls.Config + if smartRt, ok := roundTripper.(*smartRoundTripper); ok { + tlsConfig = smartRt.GetTLSClientConfig() + } + var fwd http.Handler fwd, err = forward.New( forward.Stream(true), forward.PassHostHeader(frontend.PassHostHeader), forward.RoundTripper(roundTripper), + forward.WebsocketTLSClientConfig(tlsConfig), forward.Rewriter(rewriter), forward.ResponseModifier(responseModifier), forward.BufferPool(s.bufferPool), diff --git a/server/server_loadbalancer.go b/server/server_loadbalancer.go index 27c228742..3d7659216 100644 --- a/server/server_loadbalancer.go +++ b/server/server_loadbalancer.go @@ -228,23 +228,29 @@ func (s *Server) configureLBServers(lb healthcheck.BalancerHandler, backend *typ // getRoundTripper will either use server.defaultForwardingRoundTripper or create a new one // given a custom TLS configuration is passed and the passTLSCert option is set to true. func (s *Server) getRoundTripper(entryPointName string, passTLSCert bool, tls *traefiktls.TLS) (http.RoundTripper, error) { - if passTLSCert { - tlsConfig, err := createClientTLSConfig(entryPointName, tls) - if err != nil { - return nil, fmt.Errorf("failed to create TLSClientConfig: %v", err) - } - tlsConfig.InsecureSkipVerify = s.globalConfiguration.InsecureSkipVerify - - transport, err := createHTTPTransport(s.globalConfiguration) - if err != nil { - return nil, fmt.Errorf("failed to create HTTP transport: %v", err) - } - - transport.TLSClientConfig = tlsConfig - return transport, nil + if !passTLSCert { + return s.defaultForwardingRoundTripper, nil } - return s.defaultForwardingRoundTripper, nil + tlsConfig, err := createClientTLSConfig(entryPointName, tls) + if err != nil { + return nil, fmt.Errorf("failed to create TLSClientConfig: %v", err) + } + tlsConfig.InsecureSkipVerify = s.globalConfiguration.InsecureSkipVerify + + transport, err := createHTTPTransport(s.globalConfiguration) + if err != nil { + return nil, fmt.Errorf("failed to create HTTP transport: %v", err) + } + + transport.TLSClientConfig = tlsConfig + + smartTransport, err := newSmartRoundTripper(transport) + if err != nil { + return nil, err + } + + return smartTransport, nil } // createHTTPTransport creates an http.Transport configured with the GlobalConfiguration settings. @@ -285,25 +291,21 @@ func createHTTPTransport(globalConfiguration configuration.GlobalConfiguration) transport.ResponseHeaderTimeout = time.Duration(globalConfiguration.ForwardingTimeouts.ResponseHeaderTimeout) } - if globalConfiguration.InsecureSkipVerify { - transport.TLSClientConfig = &tls.Config{InsecureSkipVerify: true} - } - - if len(globalConfiguration.RootCAs) > 0 { + if globalConfiguration.InsecureSkipVerify || len(globalConfiguration.RootCAs) > 0 { transport.TLSClientConfig = &tls.Config{ - RootCAs: createRootCACertPool(globalConfiguration.RootCAs), + InsecureSkipVerify: globalConfiguration.InsecureSkipVerify, + RootCAs: createRootCACertPool(globalConfiguration.RootCAs), } } - err := http2.ConfigureTransport(transport) - if err != nil { - return nil, err - } - return transport, nil } func createRootCACertPool(rootCAs traefiktls.FilesOrContents) *x509.CertPool { + if len(rootCAs) == 0 { + return nil + } + roots := x509.NewCertPool() for _, cert := range rootCAs { diff --git a/server/smart_roundtripper.go b/server/smart_roundtripper.go new file mode 100644 index 000000000..f8a9fd361 --- /dev/null +++ b/server/smart_roundtripper.go @@ -0,0 +1,43 @@ +package server + +import ( + "crypto/tls" + "net/http" + + "golang.org/x/net/http/httpguts" + "golang.org/x/net/http2" +) + +func newSmartRoundTripper(transport *http.Transport) (http.RoundTripper, error) { + transportHTTP1 := transport.Clone() + + err := http2.ConfigureTransport(transport) + if err != nil { + return nil, err + } + + return &smartRoundTripper{ + http2: transport, + http: transportHTTP1, + }, nil +} + +// smartRoundTripper implements RoundTrip while making sure that HTTP/2 is not used +// with protocols that start with a Connection Upgrade, such as SPDY or Websocket. +type smartRoundTripper struct { + http2 *http.Transport + http *http.Transport +} + +func (m *smartRoundTripper) RoundTrip(req *http.Request) (*http.Response, error) { + // If we have a connection upgrade, we don't use HTTP2 + if httpguts.HeaderValuesContainsToken(req.Header["Connection"], "Upgrade") { + return m.http.RoundTrip(req) + } + + return m.http2.RoundTrip(req) +} + +func (m *smartRoundTripper) GetTLSClientConfig() *tls.Config { + return m.http2.TLSClientConfig +}