Compare commits


21 Commits

| Author | SHA1 | Message | Date |
|---|---|---|---|
| Ludovic Fernandez | feeb7f81a6 | Prepare Release v1.6.6 | 2018-08-20 14:46:02 +02:00 |
| Damien Duportal | 2beb5236d0 | A tiny rewording on the documentation API's page | 2018-08-20 13:34:03 +02:00 |
| Damien Duportal | f062ee80c8 | Docs: Adding warnings and solution about the configuration exposure | 2018-08-20 12:02:03 +02:00 |
| SALLEYRON Julien | a7bb768e98 | Remove TLS in API | 2018-08-20 11:16:02 +02:00 |
| SALLEYRON Julien | 07be89d6e9 | Update oxy dependency | 2018-08-20 10:38:03 +02:00 |
| NicoMen | d81c4e6d1a | Avoid duplicated ACME resolution | 2018-08-20 09:40:03 +02:00 |
| macros | 60b4095c75 | Set keepalive on TCP socket so idleTimeout works | 2018-08-08 19:12:03 +02:00 |
| Ludovic Fernandez | 7ff6e6b66f | Freeze mkdocs version | 2018-08-06 15:50:03 +02:00 |
| Daniel Tomcej | dbe720f0f1 | Remove unusable --cluster flag | 2018-07-13 17:32:03 +02:00 |
| Jonathan Ballet | f0ab2721a5 | Fix path to the debug pprof API | 2018-07-12 17:58:02 +02:00 |
| Michael | a7c158f0e1 | Fix bad condition in ECS provider | 2018-07-12 17:40:04 +02:00 |
| Ludovic Fernandez | bdc0e3bfcf | Prepare release v1.6.5 | 2018-07-10 17:46:04 +02:00 |
| SALLEYRON Julien | f173ff02e3 | Add a mutex on local store for HTTPChallenges | 2018-07-09 23:28:02 +02:00 |
| SALLEYRON Julien | bacd58ed7b | Add logs when error is generated in error handler | 2018-07-06 10:32:03 +02:00 |
| Or Tzabary | f323df466d | Split the error handling from Consul Catalog (deadlock) | 2018-07-05 15:12:03 +02:00 |
| Fabian Beuke | b1836587f2 | Update keyFile first/last line comment in kv-config.md | 2018-07-04 14:20:03 +02:00 |
| John Yani | dbc3b85cd0 | Minor formatting issue in user-guide | 2018-06-29 17:02:03 +02:00 |
| Jean-Baptiste Doumenjou | 5eda08e9b8 | Better support on same prefix at the same level in the KV | 2018-06-26 16:18:05 +02:00 |
| Ludovic Fernandez | ec6e46e2cb | segment labels: multiple frontends for one backend. | 2018-06-22 19:44:03 +02:00 |
| Michael | aa705dd691 | Create middleware to be able to handle HTTP pipelining correctly | 2018-06-20 09:12:03 +02:00 |
| manu5801 | 1c3e4124f8 | The gandiv5 provider works with wildcard | 2018-06-18 09:26:02 +02:00 |
68 changed files with 1588 additions and 642 deletions

View File

@@ -1,5 +1,39 @@
# Change Log
## [v1.6.6](https://github.com/containous/traefik/tree/v1.6.6) (2018-08-20)
[All Commits](https://github.com/containous/traefik/compare/v1.6.5...v1.6.6)
**Bug fixes:**
- **[acme]** Avoid duplicated ACME resolution ([#3751](https://github.com/containous/traefik/pull/3751) by [nmengin](https://github.com/nmengin))
- **[api]** Remove TLS in API ([#3788](https://github.com/containous/traefik/pull/3788) by [Juliens](https://github.com/Juliens))
- **[cluster]** Remove unusable `--cluster` flag ([#3616](https://github.com/containous/traefik/pull/3616) by [dtomcej](https://github.com/dtomcej))
- **[ecs]** Fix bad condition in ECS provider ([#3609](https://github.com/containous/traefik/pull/3609) by [mmatur](https://github.com/mmatur))
- Set keepalive on TCP socket so idleTimeout works ([#3740](https://github.com/containous/traefik/pull/3740) by [ajardan](https://github.com/ajardan))
**Documentation:**
- A tiny rewording on the documentation API's page ([#3794](https://github.com/containous/traefik/pull/3794) by [dduportal](https://github.com/dduportal))
- Adding warnings and solution about the configuration exposure ([#3790](https://github.com/containous/traefik/pull/3790) by [dduportal](https://github.com/dduportal))
- Fix path to the debug pprof API ([#3608](https://github.com/containous/traefik/pull/3608) by [multani](https://github.com/multani))
**Misc:**
- **[oxy,websocket]** Update oxy dependency ([#3777](https://github.com/containous/traefik/pull/3777) by [Juliens](https://github.com/Juliens))
## [v1.6.5](https://github.com/containous/traefik/tree/v1.6.5) (2018-07-09)
[All Commits](https://github.com/containous/traefik/compare/v1.6.4...v1.6.5)
**Bug fixes:**
- **[acme]** Add a mutex on local store for HTTPChallenges ([#3579](https://github.com/containous/traefik/pull/3579) by [Juliens](https://github.com/Juliens))
- **[consulcatalog]** Split the error handling from Consul Catalog (deadlock) ([#3560](https://github.com/containous/traefik/pull/3560) by [ortz](https://github.com/ortz))
- **[docker]** segment labels: multiple frontends for one backend. ([#3511](https://github.com/containous/traefik/pull/3511) by [ldez](https://github.com/ldez))
- **[kv]** Better support on same prefix at the same level in the KV ([#3532](https://github.com/containous/traefik/pull/3532) by [jbdoumenjou](https://github.com/jbdoumenjou))
- **[logs]** Add logs when error is generated in error handler ([#3567](https://github.com/containous/traefik/pull/3567) by [Juliens](https://github.com/Juliens))
- **[middleware]** Create middleware to be able to handle HTTP pipelining correctly ([#3513](https://github.com/containous/traefik/pull/3513) by [mmatur](https://github.com/mmatur))
**Documentation:**
- **[acme]** The gandiv5 provider works with wildcard ([#3506](https://github.com/containous/traefik/pull/3506) by [manu5801](https://github.com/manu5801))
- **[kv]** Update keyFile first/last line comment in kv-config.md ([#3558](https://github.com/containous/traefik/pull/3558) by [madnight](https://github.com/madnight))
- Minor formatting issue in user-guide ([#3546](https://github.com/containous/traefik/pull/3546) by [Vanuan](https://github.com/Vanuan))
## [v1.6.4](https://github.com/containous/traefik/tree/v1.6.4) (2018-06-15)
[All Commits](https://github.com/containous/traefik/compare/v1.6.3...v1.6.4)

Gopkg.lock (generated)
View File

@@ -257,8 +257,8 @@
[[projects]]
name = "github.com/containous/staert"
packages = ["."]
revision = "cc00c303ccbd2491ddc1dccc9eb7ccadd807557e"
version = "v3.1.0"
revision = "66717a0e0ca950c4b6dc8c87b46da0b8495c6e41"
version = "v3.1.1"
[[projects]]
name = "github.com/containous/traefik-extra-service-fabric"
@@ -1217,7 +1217,7 @@
"roundrobin",
"utils"
]
revision = "d5b73186eed4aa34b52748699ad19e90f61d4059"
revision = "885e42fe04d8e0efa6c18facad4e0fc5757cde9b"
[[projects]]
name = "github.com/vulcand/predicate"
@@ -1679,6 +1679,6 @@
[solve-meta]
analyzer-name = "dep"
analyzer-version = 1
inputs-digest = "ac06fad81167510635546d4e5500b938d61f2eb999bf04d5520d7967b9621f0d"
inputs-digest = "ad34e6336e6f19b82c52e991d22c5b43b9144ed7dc83d7b17197583ace43f346"
solver-name = "gps-cdcl"
solver-version = 1

View File

@@ -62,7 +62,7 @@
[[constraint]]
name = "github.com/containous/staert"
version = "3.1.0"
version = "3.1.1"
[[constraint]]
name = "github.com/containous/traefik-extra-service-fabric"

View File

@@ -11,6 +11,7 @@ import (
"net/http"
"reflect"
"strings"
"sync"
"time"
"github.com/BurntSushi/ty/fun"
@@ -61,6 +62,8 @@ type ACME struct {
jobs *channels.InfiniteChannel
TLSConfig *tls.Config `description:"TLS config in case wildcard certs are used"`
dynamicCerts *safe.Safe
resolvingDomains map[string]struct{}
resolvingDomainsMutex sync.RWMutex
}
func (a *ACME) init() error {
@@ -81,6 +84,10 @@ func (a *ACME) init() error {
a.defaultCertificate = cert
a.jobs = channels.NewInfiniteChannel()
// Init the currently resolved domain map
a.resolvingDomains = make(map[string]struct{})
return nil
}
@@ -502,6 +509,10 @@ func (a *ACME) LoadCertificateForDomains(domains []string) {
if len(uncheckedDomains) == 0 {
return
}
a.addResolvingDomains(uncheckedDomains)
defer a.removeResolvingDomains(uncheckedDomains)
certificate, err := a.getDomainsCertificates(uncheckedDomains)
if err != nil {
log.Errorf("Error getting ACME certificates %+v : %v", uncheckedDomains, err)
@@ -533,6 +544,24 @@ func (a *ACME) LoadCertificateForDomains(domains []string) {
}
}
func (a *ACME) addResolvingDomains(resolvingDomains []string) {
a.resolvingDomainsMutex.Lock()
defer a.resolvingDomainsMutex.Unlock()
for _, domain := range resolvingDomains {
a.resolvingDomains[domain] = struct{}{}
}
}
func (a *ACME) removeResolvingDomains(resolvingDomains []string) {
a.resolvingDomainsMutex.Lock()
defer a.resolvingDomainsMutex.Unlock()
for _, domain := range resolvingDomains {
delete(a.resolvingDomains, domain)
}
}
// Get provided certificate which check a domains list (Main and SANs)
// from static and dynamic provided certificates
func (a *ACME) getProvidedCertificate(domains string) *tls.Certificate {
@@ -568,6 +597,9 @@ func searchProvidedCertificateForDomains(domain string, certs map[string]*tls.Ce
// Get provided certificate which check a domains list (Main and SANs)
// from static and dynamic provided certificates
func (a *ACME) getUncheckedDomains(domains []string, account *Account) []string {
a.resolvingDomainsMutex.RLock()
defer a.resolvingDomainsMutex.RUnlock()
log.Debugf("Looking for provided certificate to validate %s...", domains)
allCerts := make(map[string]*tls.Certificate)
@@ -590,6 +622,13 @@ func (a *ACME) getUncheckedDomains(domains []string, account *Account) []string
}
}
// Get currently resolved domains
for domain := range a.resolvingDomains {
if _, ok := allCerts[domain]; !ok {
allCerts[domain] = &tls.Certificate{}
}
}
// Get Configuration Domains
for i := 0; i < len(a.Domains); i++ {
allCerts[a.Domains[i].Main] = &tls.Certificate{}
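The hunks above add a `resolvingDomains` set guarded by a `sync.RWMutex`: domains currently being resolved are counted as already covered by `getUncheckedDomains`, so a burst of requests for the same host no longer triggers duplicate ACME orders. A minimal standalone sketch of that pattern (identifiers such as `inFlight` and `filterUnchecked` are illustrative, not Traefik's actual names):

```go
package main

import (
	"fmt"
	"sync"
)

// inFlight tracks domains whose certificates are currently being resolved,
// so concurrent requests for the same names do not trigger duplicate orders.
type inFlight struct {
	mu      sync.RWMutex
	domains map[string]struct{}
}

func newInFlight() *inFlight {
	return &inFlight{domains: make(map[string]struct{})}
}

func (f *inFlight) add(domains []string) {
	f.mu.Lock()
	defer f.mu.Unlock()
	for _, d := range domains {
		f.domains[d] = struct{}{}
	}
}

func (f *inFlight) remove(domains []string) {
	f.mu.Lock()
	defer f.mu.Unlock()
	for _, d := range domains {
		delete(f.domains, d)
	}
}

// filterUnchecked keeps only the domains that are not already being resolved.
func (f *inFlight) filterUnchecked(domains []string) []string {
	f.mu.RLock()
	defer f.mu.RUnlock()
	var unchecked []string
	for _, d := range domains {
		if _, resolving := f.domains[d]; !resolving {
			unchecked = append(unchecked, d)
		}
	}
	return unchecked
}

func main() {
	tracker := newInFlight()
	tracker.add([]string{"example.com"})
	defer tracker.remove([]string{"example.com"})

	// Only www.example.com still needs a certificate request.
	fmt.Println(tracker.filterUnchecked([]string{"example.com", "www.example.com"}))
}
```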

View File

@@ -331,9 +331,12 @@ func TestAcme_getUncheckedCertificates(t *testing.T) {
mm["*.containo.us"] = &tls.Certificate{}
mm["traefik.acme.io"] = &tls.Certificate{}
a := ACME{TLSConfig: &tls.Config{NameToCertificate: mm}}
dm := make(map[string]struct{})
dm["*.traefik.wtf"] = struct{}{}
domains := []string{"traefik.containo.us", "trae.containo.us"}
a := ACME{TLSConfig: &tls.Config{NameToCertificate: mm}, resolvingDomains: dm}
domains := []string{"traefik.containo.us", "trae.containo.us", "foo.traefik.wtf"}
uncheckedDomains := a.getUncheckedDomains(domains, nil)
assert.Empty(t, uncheckedDomains)
domains = []string{"traefik.acme.io", "trae.acme.io"}
@@ -351,6 +354,9 @@ func TestAcme_getUncheckedCertificates(t *testing.T) {
account := Account{DomainsCertificate: domainsCertificates}
uncheckedDomains = a.getUncheckedDomains(domains, &account)
assert.Empty(t, uncheckedDomains)
domains = []string{"traefik.containo.us", "trae.containo.us", "traefik.wtf"}
uncheckedDomains = a.getUncheckedDomains(domains, nil)
assert.Len(t, uncheckedDomains, 1)
}
func TestAcme_getProvidedCertificate(t *testing.T) {

View File

@@ -175,7 +175,7 @@ func runCmd(globalConfiguration *configuration.GlobalConfiguration, configFile s
log.Debugf("Global configuration loaded %s", string(jsonConf))
if acme.IsEnabled() {
store := acme.NewLocalStore(acme.Get().Storage)
acme.Get().Store = &store
acme.Get().Store = store
}
svr := server.NewServer(*globalConfiguration, configuration.NewProviderAggregator(globalConfiguration))
if acme.IsEnabled() && acme.Get().OnHostRule {

View File

@@ -58,19 +58,19 @@ const (
// GlobalConfiguration holds global configuration (with providers, etc.).
// It's populated from the traefik configuration file passed as an argument to the binary.
type GlobalConfiguration struct {
LifeCycle *LifeCycle `description:"Timeouts influencing the server life cycle" export:"true"`
GraceTimeOut flaeg.Duration `short:"g" description:"(Deprecated) Duration to give active requests a chance to finish before Traefik stops" export:"true"` // Deprecated
Debug bool `short:"d" description:"Enable debug mode" export:"true"`
CheckNewVersion bool `description:"Periodically check if a new version has been released" export:"true"`
SendAnonymousUsage bool `description:"send periodically anonymous usage statistics" export:"true"`
AccessLogsFile string `description:"(Deprecated) Access logs file" export:"true"` // Deprecated
AccessLog *types.AccessLog `description:"Access log settings" export:"true"`
TraefikLogsFile string `description:"(Deprecated) Traefik logs file. Stdout is used when omitted or empty" export:"true"` // Deprecated
TraefikLog *types.TraefikLog `description:"Traefik log settings" export:"true"`
Tracing *tracing.Tracing `description:"OpenTracing configuration" export:"true"`
LogLevel string `short:"l" description:"Log level" export:"true"`
EntryPoints EntryPoints `description:"Entrypoints definition using format: --entryPoints='Name:http Address::8000 Redirect.EntryPoint:https' --entryPoints='Name:https Address::4442 TLS:tests/traefik.crt,tests/traefik.key;prod/traefik.crt,prod/traefik.key'" export:"true"`
Cluster *types.Cluster `description:"Enable clustering" export:"true"`
LifeCycle *LifeCycle `description:"Timeouts influencing the server life cycle" export:"true"`
GraceTimeOut flaeg.Duration `short:"g" description:"(Deprecated) Duration to give active requests a chance to finish before Traefik stops" export:"true"` // Deprecated
Debug bool `short:"d" description:"Enable debug mode" export:"true"`
CheckNewVersion bool `description:"Periodically check if a new version has been released" export:"true"`
SendAnonymousUsage bool `description:"send periodically anonymous usage statistics" export:"true"`
AccessLogsFile string `description:"(Deprecated) Access logs file" export:"true"` // Deprecated
AccessLog *types.AccessLog `description:"Access log settings" export:"true"`
TraefikLogsFile string `description:"(Deprecated) Traefik logs file. Stdout is used when omitted or empty" export:"true"` // Deprecated
TraefikLog *types.TraefikLog `description:"Traefik log settings" export:"true"`
Tracing *tracing.Tracing `description:"OpenTracing configuration" export:"true"`
LogLevel string `short:"l" description:"Log level" export:"true"`
EntryPoints EntryPoints `description:"Entrypoints definition using format: --entryPoints='Name:http Address::8000 Redirect.EntryPoint:https' --entryPoints='Name:https Address::4442 TLS:tests/traefik.crt,tests/traefik.key;prod/traefik.crt,prod/traefik.key'" export:"true"`
Cluster *types.Cluster
Constraints types.Constraints `description:"Filter services by constraint, matching with service tags" export:"true"`
ACME *acme.ACME `description:"Enable ACME (Let's Encrypt): automatic SSL" export:"true"`
DefaultEntryPoints DefaultEntryPoints `description:"Entrypoints to be used by frontends that do not specify any entrypoint" export:"true"`

View File

@@ -190,7 +190,7 @@ Here is a list of supported `provider`s, that can automate the DNS verification,
| [Exoscale](https://www.exoscale.ch) | `exoscale` | `EXOSCALE_API_KEY`, `EXOSCALE_API_SECRET`, `EXOSCALE_ENDPOINT` | YES |
| [Fast DNS](https://www.akamai.com/) | `fastdns` | `AKAMAI_CLIENT_TOKEN`, `AKAMAI_CLIENT_SECRET`, `AKAMAI_ACCESS_TOKEN` | Not tested yet |
| [Gandi](https://www.gandi.net) | `gandi` | `GANDI_API_KEY` | Not tested yet |
| [Gandi V5](http://doc.livedns.gandi.net) | `gandiv5` | `GANDIV5_API_KEY` | Not tested yet |
| [Gandi V5](http://doc.livedns.gandi.net) | `gandiv5` | `GANDIV5_API_KEY` | YES |
| [Glesys](https://glesys.com/) | `glesys` | `GLESYS_API_USER`, `GLESYS_API_KEY`, `GLESYS_DOMAIN` | Not tested yet |
| [GoDaddy](https://godaddy.com/domains) | `godaddy` | `GODADDY_API_KEY`, `GODADDY_API_SECRET` | Not tested yet |
| [Google Cloud DNS](https://cloud.google.com/dns/docs/) | `gcloud` | `GCE_PROJECT`, `GCE_SERVICE_ACCOUNT_FILE` | YES |

View File

@@ -4,6 +4,9 @@
```toml
# API definition
# Warning: Enabling API will expose Træfik's configuration.
# It is not recommended in production,
# unless secured by authentication and authorizations
[api]
# Name of the related entry point
#
@@ -12,7 +15,7 @@
#
entryPoint = "traefik"
# Enabled Dashboard
# Enable Dashboard
#
# Optional
# Default: true
@@ -21,7 +24,7 @@
# Enable debug mode.
# This will install HTTP handlers to expose Go expvars under /debug/vars and
# pprof profiling data under /debug/pprof.
# pprof profiling data under /debug/pprof/.
# Additionally, the log level will be set to DEBUG.
#
# Optional
@@ -38,6 +41,22 @@ For more customization, see [entry points](/configuration/entrypoints/) document
![Web UI Health](/img/traefik-health.png)
## Security
Enabling the API will expose all configuration elements,
including sensitive data.
It is not recommended in production,
unless secured by authentication and authorizations.
A sane default (but not exhaustive) set of recommendations
would be to apply the following protection mechanisms:
* _At application level:_ enabling HTTP [Basic Authentication](#authentication)
* _At transport level:_ NOT exposing publicly the API's port,
keeping it restricted over internal networks
(restricted networks as in https://en.wikipedia.org/wiki/Principle_of_least_privilege).
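As a generic illustration of those two recommendations — application-level Basic Authentication plus binding the admin port to an internal address only — here is a hedged sketch in plain Go. It is not Traefik code or configuration; in Traefik itself these points are expressed through the configuration described above.

```go
package main

import (
	"crypto/subtle"
	"log"
	"net/http"
)

// basicAuth rejects requests that do not carry the expected credentials
// (the "application level" recommendation).
func basicAuth(user, pass string, next http.Handler) http.Handler {
	return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
		u, p, ok := r.BasicAuth()
		if !ok ||
			subtle.ConstantTimeCompare([]byte(u), []byte(user)) != 1 ||
			subtle.ConstantTimeCompare([]byte(p), []byte(pass)) != 1 {
			w.Header().Set("WWW-Authenticate", `Basic realm="api"`)
			http.Error(w, "unauthorized", http.StatusUnauthorized)
			return
		}
		next.ServeHTTP(w, r)
	})
}

func main() {
	mux := http.NewServeMux()
	mux.HandleFunc("/api", func(w http.ResponseWriter, r *http.Request) {
		w.Write([]byte(`{"status":"ok"}`))
	})

	// The "transport level" recommendation: listen on an internal/loopback
	// address instead of exposing the port publicly.
	log.Fatal(http.ListenAndServe("127.0.0.1:8080", basicAuth("admin", "secret", mux)))
}
```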
## API
| Path | Method | Description |

View File

@@ -18,7 +18,7 @@
# Enable debug mode.
# This will install HTTP handlers to expose Go expvars under /debug/vars and
# pprof profiling data under /debug/pprof.
# pprof profiling data under /debug/pprof/.
# The log level will be set to DEBUG unless `logLevel` is specified.
#
# Optional

View File

@@ -86,6 +86,10 @@ services:
- /var/run/docker.sock:/var/run/docker.sock # So that Traefik can listen to the Docker events
```
!!! warning
Enabling the Web UI with the `--api` flag might expose configuration elements. You can read more about this in the [API/Dashboard's Security section](/configuration/api#security).
**That's it. Now you can launch Træfik!**
Start your `reverse-proxy` with the following command:

View File

@@ -85,9 +85,9 @@ defaultEntryPoints = ["http", "https"]
certFile = """-----BEGIN CERTIFICATE-----
<cert file content>
-----END CERTIFICATE-----"""
keyFile = """-----BEGIN CERTIFICATE-----
keyFile = """-----BEGIN PRIVATE KEY-----
<key file content>
-----END CERTIFICATE-----"""
-----END PRIVATE KEY-----"""
[entryPoints.other-https]
address = ":4443"
[entryPoints.other-https.tls]

View File

@@ -102,7 +102,7 @@ Let's explain this command:
| `--mount type=bind,source=/var/run/docker.sock,target=/var/run/docker.sock` | we bind mount the docker socket where Træfik is scheduled to be able to speak to the daemon. |
| `--network traefik-net` | we attach the Træfik service (and thus the underlying container) to the `traefik-net` network. |
| `--docker` | enable docker provider, and `--docker.swarmMode` to enable the swarm mode on Træfik. |
| `--api | activate the webUI on port 8080 |
| `--api` | activate the webUI on port 8080 |
## Deploy your apps

View File

@@ -50,7 +50,7 @@ start_boulder() {
# Script usage
show_usage() {
echo
echo "USAGE : manage_acme_docker_environment.sh [--start|--stop|--restart]"
echo "USAGE : manage_acme_docker_environment.sh [--dev|--start|--stop|--restart]"
echo
}

View File

@@ -585,21 +585,14 @@ func (s *ConsulSuite) TestSNIDynamicTlsConfig(c *check.C) {
})
c.Assert(err, checker.IsNil)
// wait for traefik
err = try.GetRequest("http://127.0.0.1:8081/api/providers", 60*time.Second, try.BodyContains("MIIEpQIBAAKCAQEA1RducBK6EiFDv3TYB8ZcrfKWRVaSfHzWicO3J5WdST9oS7hG"))
c.Assert(err, checker.IsNil)
req, err := http.NewRequest(http.MethodGet, "https://127.0.0.1:4443/", nil)
c.Assert(err, checker.IsNil)
client := &http.Client{Transport: tr1}
req.Host = tr1.TLSClientConfig.ServerName
req.Header.Set("Host", tr1.TLSClientConfig.ServerName)
req.Header.Set("Accept", "*/*")
var resp *http.Response
resp, err = client.Do(req)
err = try.RequestWithTransport(req, 30*time.Second, tr1, try.HasCn("snitest.com"))
c.Assert(err, checker.IsNil)
cn := resp.TLS.PeerCertificates[0].Subject.CommonName
c.Assert(cn, checker.Equals, "snitest.com")
// now we configure the second keypair in consul and the request for host "snitest.org" will use the second keypair
for key, value := range tlsconfigure2 {
@@ -614,18 +607,12 @@ func (s *ConsulSuite) TestSNIDynamicTlsConfig(c *check.C) {
})
c.Assert(err, checker.IsNil)
// waiting for traefik to pull configuration
err = try.GetRequest("http://127.0.0.1:8081/api/providers", 30*time.Second, try.BodyContains("MIIEogIBAAKCAQEAvG9kL+vF57+MICehzbqcQAUlAOSl5r"))
c.Assert(err, checker.IsNil)
req, err = http.NewRequest(http.MethodGet, "https://127.0.0.1:4443/", nil)
c.Assert(err, checker.IsNil)
client = &http.Client{Transport: tr2}
req.Host = tr2.TLSClientConfig.ServerName
req.Header.Set("Host", tr2.TLSClientConfig.ServerName)
req.Header.Set("Accept", "*/*")
resp, err = client.Do(req)
err = try.RequestWithTransport(req, 30*time.Second, tr2, try.HasCn("snitest.org"))
c.Assert(err, checker.IsNil)
cn = resp.TLS.PeerCertificates[0].Subject.CommonName
c.Assert(cn, checker.Equals, "snitest.org")
}

View File

@@ -532,21 +532,14 @@ func (s *Etcd3Suite) TestSNIDynamicTlsConfig(c *check.C) {
c.Assert(err, checker.IsNil)
defer cmd.Process.Kill()
// wait for Træfik
err = try.GetRequest("http://127.0.0.1:8081/api/providers", 60*time.Second, try.BodyContains(string("MIIEpQIBAAKCAQEA1RducBK6EiFDv3TYB8ZcrfKWRVaSfHzWicO3J5WdST9oS7h")))
c.Assert(err, checker.IsNil)
req, err := http.NewRequest(http.MethodGet, "https://127.0.0.1:4443/", nil)
c.Assert(err, checker.IsNil)
client := &http.Client{Transport: tr1}
req.Host = tr1.TLSClientConfig.ServerName
req.Header.Set("Host", tr1.TLSClientConfig.ServerName)
req.Header.Set("Accept", "*/*")
var resp *http.Response
resp, err = client.Do(req)
err = try.RequestWithTransport(req, 30*time.Second, tr1, try.HasCn("snitest.com"))
c.Assert(err, checker.IsNil)
cn := resp.TLS.PeerCertificates[0].Subject.CommonName
c.Assert(cn, checker.Equals, "snitest.com")
// now we configure the second keypair in etcd and the request for host "snitest.org" will use the second keypair
@@ -562,20 +555,14 @@ func (s *Etcd3Suite) TestSNIDynamicTlsConfig(c *check.C) {
})
c.Assert(err, checker.IsNil)
// waiting for Træfik to pull configuration
err = try.GetRequest("http://127.0.0.1:8081/api/providers", 30*time.Second, try.BodyContains("MIIEogIBAAKCAQEAvG9kL+vF57+MICehzbqcQAUlAOSl5r"))
c.Assert(err, checker.IsNil)
req, err = http.NewRequest(http.MethodGet, "https://127.0.0.1:4443/", nil)
c.Assert(err, checker.IsNil)
client = &http.Client{Transport: tr2}
req.Host = tr2.TLSClientConfig.ServerName
req.Header.Set("Host", tr2.TLSClientConfig.ServerName)
req.Header.Set("Accept", "*/*")
resp, err = client.Do(req)
err = try.RequestWithTransport(req, 30*time.Second, tr2, try.HasCn("snitest.org"))
c.Assert(err, checker.IsNil)
cn = resp.TLS.PeerCertificates[0].Subject.CommonName
c.Assert(cn, checker.Equals, "snitest.org")
}
func (s *Etcd3Suite) TestDeleteSNIDynamicTlsConfig(c *check.C) {
@@ -646,21 +633,14 @@ func (s *Etcd3Suite) TestDeleteSNIDynamicTlsConfig(c *check.C) {
c.Assert(err, checker.IsNil)
defer cmd.Process.Kill()
// wait for Træfik
err = try.GetRequest(traefikWebEtcdURL+"api/providers", 60*time.Second, try.BodyContains(string("MIIEpQIBAAKCAQEA1RducBK6EiFDv3TYB8ZcrfKWRVaSfHzWicO3J5WdST9oS7h")))
c.Assert(err, checker.IsNil)
req, err := http.NewRequest(http.MethodGet, "https://127.0.0.1:4443/", nil)
c.Assert(err, checker.IsNil)
client := &http.Client{Transport: tr1}
req.Host = tr1.TLSClientConfig.ServerName
req.Header.Set("Host", tr1.TLSClientConfig.ServerName)
req.Header.Set("Accept", "*/*")
var resp *http.Response
resp, err = client.Do(req)
err = try.RequestWithTransport(req, 30*time.Second, tr1, try.HasCn("snitest.com"))
c.Assert(err, checker.IsNil)
cn := resp.TLS.PeerCertificates[0].Subject.CommonName
c.Assert(cn, checker.Equals, "snitest.com")
// now we delete the tls cert/key pairs,so the endpoint show use default cert/key pair
for key := range tlsconfigure1 {
@@ -668,18 +648,12 @@ func (s *Etcd3Suite) TestDeleteSNIDynamicTlsConfig(c *check.C) {
c.Assert(err, checker.IsNil)
}
// waiting for Træfik to pull configuration
err = try.GetRequest(traefikWebEtcdURL+"api/providers", 30*time.Second, try.BodyNotContains("MIIEpQIBAAKCAQEA1RducBK6EiFDv3TYB8ZcrfKWRVaSfHzWicO3J5WdST9oS7h"))
c.Assert(err, checker.IsNil)
req, err = http.NewRequest(http.MethodGet, "https://127.0.0.1:4443/", nil)
c.Assert(err, checker.IsNil)
client = &http.Client{Transport: tr1}
req.Host = tr1.TLSClientConfig.ServerName
req.Header.Set("Host", tr1.TLSClientConfig.ServerName)
req.Header.Set("Accept", "*/*")
resp, err = client.Do(req)
err = try.RequestWithTransport(req, 30*time.Second, tr1, try.HasCn("TRAEFIK DEFAULT CERT"))
c.Assert(err, checker.IsNil)
cn = resp.TLS.PeerCertificates[0].Subject.CommonName
c.Assert(cn, checker.Equals, "TRAEFIK DEFAULT CERT")
}

View File

@@ -548,21 +548,14 @@ func (s *EtcdSuite) TestSNIDynamicTlsConfig(c *check.C) {
c.Assert(err, checker.IsNil)
defer cmd.Process.Kill()
// wait for Træfik
err = try.GetRequest("http://127.0.0.1:8081/api/providers", 60*time.Second, try.BodyContains(string("MIIEpQIBAAKCAQEA1RducBK6EiFDv3TYB8ZcrfKWRVaSfHzWicO3J5WdST9oS7h")))
c.Assert(err, checker.IsNil)
req, err := http.NewRequest(http.MethodGet, "https://127.0.0.1:4443/", nil)
c.Assert(err, checker.IsNil)
client := &http.Client{Transport: tr1}
req.Host = tr1.TLSClientConfig.ServerName
req.Header.Set("Host", tr1.TLSClientConfig.ServerName)
req.Header.Set("Accept", "*/*")
var resp *http.Response
resp, err = client.Do(req)
err = try.RequestWithTransport(req, 30*time.Second, tr1, try.HasCn("snitest.com"))
c.Assert(err, checker.IsNil)
cn := resp.TLS.PeerCertificates[0].Subject.CommonName
c.Assert(cn, checker.Equals, "snitest.com")
// now we configure the second keypair in etcd and the request for host "snitest.org" will use the second keypair
@@ -578,18 +571,12 @@ func (s *EtcdSuite) TestSNIDynamicTlsConfig(c *check.C) {
})
c.Assert(err, checker.IsNil)
// waiting for Træfik to pull configuration
err = try.GetRequest("http://127.0.0.1:8081/api/providers", 30*time.Second, try.BodyContains("MIIEogIBAAKCAQEAvG9kL+vF57+MICehzbqcQAUlAOSl5r"))
c.Assert(err, checker.IsNil)
req, err = http.NewRequest(http.MethodGet, "https://127.0.0.1:4443/", nil)
c.Assert(err, checker.IsNil)
client = &http.Client{Transport: tr2}
req.Host = tr2.TLSClientConfig.ServerName
req.Header.Set("Host", tr2.TLSClientConfig.ServerName)
req.Header.Set("Accept", "*/*")
resp, err = client.Do(req)
err = try.RequestWithTransport(req, 30*time.Second, tr2, try.HasCn("snitest.org"))
c.Assert(err, checker.IsNil)
cn = resp.TLS.PeerCertificates[0].Subject.CommonName
c.Assert(cn, checker.Equals, "snitest.org")
}

View File

@@ -2,7 +2,7 @@
[backends]
[backends.backend2]
[backends.backend2.servers.server1]
url = "http://172.17.0.2:80"
url = "http://172.17.0.123:80"
weight = 1
[frontends]

View File

@@ -88,6 +88,31 @@ func HasBody() ResponseCondition {
}
}
// HasCn returns a retry condition function.
// The condition returns an error if the cn is not correct.
func HasCn(cn string) ResponseCondition {
return func(res *http.Response) error {
if res.TLS == nil {
return errors.New("response doesn't have TLS")
}
if len(res.TLS.PeerCertificates) == 0 {
return errors.New("response TLS doesn't have peer certificates")
}
if res.TLS.PeerCertificates[0] == nil {
return errors.New("first peer certificate is nil")
}
commonName := res.TLS.PeerCertificates[0].Subject.CommonName
if cn != commonName {
return fmt.Errorf("common name don't match: %s != %s", cn, commonName)
}
return nil
}
}
// StatusCodeIs returns a retry condition function.
// The condition returns an error if the given response's status code is not the
// given HTTP status code.

View File

@@ -31,7 +31,7 @@ func Sleep(d time.Duration) {
// response body needs to be closed or not. Callers are expected to close on
// their own if the function returns a nil error.
func Response(req *http.Request, timeout time.Duration) (*http.Response, error) {
return doTryRequest(req, timeout)
return doTryRequest(req, timeout, nil)
}
// ResponseUntilStatusCode is like Request, but returns the response for further
@@ -40,7 +40,7 @@ func Response(req *http.Request, timeout time.Duration) (*http.Response, error)
// response body needs to be closed or not. Callers are expected to close on
// their own if the function returns a nil error.
func ResponseUntilStatusCode(req *http.Request, timeout time.Duration, statusCode int) (*http.Response, error) {
return doTryRequest(req, timeout, StatusCodeIs(statusCode))
return doTryRequest(req, timeout, nil, StatusCodeIs(statusCode))
}
// GetRequest is like Do, but runs a request against the given URL and applies
@@ -48,7 +48,7 @@ func ResponseUntilStatusCode(req *http.Request, timeout time.Duration, statusCod
// ResponseCondition may be nil, in which case only the request against the URL must
// succeed.
func GetRequest(url string, timeout time.Duration, conditions ...ResponseCondition) error {
resp, err := doTryGet(url, timeout, conditions...)
resp, err := doTryGet(url, timeout, nil, conditions...)
if resp != nil && resp.Body != nil {
defer resp.Body.Close()
@@ -62,7 +62,21 @@ func GetRequest(url string, timeout time.Duration, conditions ...ResponseConditi
// ResponseCondition may be nil, in which case only the request against the URL must
// succeed.
func Request(req *http.Request, timeout time.Duration, conditions ...ResponseCondition) error {
resp, err := doTryRequest(req, timeout, conditions...)
resp, err := doTryRequest(req, timeout, nil, conditions...)
if resp != nil && resp.Body != nil {
defer resp.Body.Close()
}
return err
}
// RequestWithTransport is like Do, but runs a request against the given URL and applies
// the condition on the response.
// ResponseCondition may be nil, in which case only the request against the URL must
// succeed.
func RequestWithTransport(req *http.Request, timeout time.Duration, transport *http.Transport, conditions ...ResponseCondition) error {
resp, err := doTryRequest(req, timeout, transport, conditions...)
if resp != nil && resp.Body != nil {
defer resp.Body.Close()
@@ -112,24 +126,27 @@ func Do(timeout time.Duration, operation DoCondition) error {
}
}
func doTryGet(url string, timeout time.Duration, conditions ...ResponseCondition) (*http.Response, error) {
func doTryGet(url string, timeout time.Duration, transport *http.Transport, conditions ...ResponseCondition) (*http.Response, error) {
req, err := http.NewRequest(http.MethodGet, url, nil)
if err != nil {
return nil, err
}
return doTryRequest(req, timeout, conditions...)
return doTryRequest(req, timeout, transport, conditions...)
}
func doTryRequest(request *http.Request, timeout time.Duration, conditions ...ResponseCondition) (*http.Response, error) {
return doRequest(Do, timeout, request, conditions...)
func doTryRequest(request *http.Request, timeout time.Duration, transport *http.Transport, conditions ...ResponseCondition) (*http.Response, error) {
return doRequest(Do, timeout, request, transport, conditions...)
}
func doRequest(action timedAction, timeout time.Duration, request *http.Request, conditions ...ResponseCondition) (*http.Response, error) {
func doRequest(action timedAction, timeout time.Duration, request *http.Request, transport *http.Transport, conditions ...ResponseCondition) (*http.Response, error) {
var resp *http.Response
return resp, action(timeout, func() error {
var err error
client := http.DefaultClient
if transport != nil {
client.Transport = transport
}
resp, err = client.Do(request)
if err != nil {
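Taken together, the integration tests above collapse the manual client/response handling into a single call such as `try.RequestWithTransport(req, 30*time.Second, tr1, try.HasCn("snitest.com"))`. A simplified, self-contained sketch of the same idea — polling with a caller-supplied transport until the served certificate's common name matches; the function names below are illustrative, not the actual `try` package API:

```go
package main

import (
	"crypto/tls"
	"errors"
	"fmt"
	"net/http"
	"time"
)

// responseCondition mirrors the shape of the try package's ResponseCondition.
type responseCondition func(*http.Response) error

// hasCN checks the common name of the first TLS peer certificate,
// like the HasCn helper added above.
func hasCN(cn string) responseCondition {
	return func(res *http.Response) error {
		if res.TLS == nil || len(res.TLS.PeerCertificates) == 0 {
			return errors.New("response has no TLS peer certificates")
		}
		if got := res.TLS.PeerCertificates[0].Subject.CommonName; got != cn {
			return fmt.Errorf("common name mismatch: want %q, got %q", cn, got)
		}
		return nil
	}
}

// requestWithTransport retries the request with the given transport until the
// condition holds or the timeout elapses.
func requestWithTransport(req *http.Request, timeout time.Duration, tr *http.Transport, cond responseCondition) error {
	client := &http.Client{Transport: tr}
	deadline := time.Now().Add(timeout)
	for {
		resp, err := client.Do(req)
		if err == nil {
			condErr := cond(resp)
			resp.Body.Close()
			if condErr == nil {
				return nil
			}
			err = condErr
		}
		if time.Now().After(deadline) {
			return fmt.Errorf("condition not met before timeout: %v", err)
		}
		time.Sleep(500 * time.Millisecond)
	}
}

func main() {
	tr := &http.Transport{TLSClientConfig: &tls.Config{
		InsecureSkipVerify: true,
		ServerName:         "snitest.com",
	}}
	req, _ := http.NewRequest(http.MethodGet, "https://127.0.0.1:4443/", nil)
	req.Host = tr.TLSClientConfig.ServerName

	if err := requestWithTransport(req, 30*time.Second, tr, hasCN("snitest.com")); err != nil {
		fmt.Println(err)
	}
}
```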

View File

@@ -0,0 +1,62 @@
package pipelining
import (
"bufio"
"net"
"net/http"
)
// Pipelining returns a middleware
type Pipelining struct {
next http.Handler
}
// NewPipelining returns a new Pipelining instance
func NewPipelining(next http.Handler) *Pipelining {
return &Pipelining{
next: next,
}
}
func (p *Pipelining) ServeHTTP(rw http.ResponseWriter, r *http.Request) {
// https://github.com/golang/go/blob/3d59583836630cf13ec4bfbed977d27b1b7adbdc/src/net/http/server.go#L201-L218
if r.Method == http.MethodPut || r.Method == http.MethodPost {
p.next.ServeHTTP(rw, r)
} else {
p.next.ServeHTTP(&writerWithoutCloseNotify{rw}, r)
}
}
// writerWithoutCloseNotify helps to disable closeNotify
type writerWithoutCloseNotify struct {
W http.ResponseWriter
}
// Header returns the response headers.
func (w *writerWithoutCloseNotify) Header() http.Header {
return w.W.Header()
}
// Write writes the data to the connection as part of an HTTP reply.
func (w *writerWithoutCloseNotify) Write(buf []byte) (int, error) {
return w.W.Write(buf)
}
// WriteHeader sends an HTTP response header with the provided
// status code.
func (w *writerWithoutCloseNotify) WriteHeader(code int) {
w.W.WriteHeader(code)
}
// Flush sends any buffered data to the client.
func (w *writerWithoutCloseNotify) Flush() {
if f, ok := w.W.(http.Flusher); ok {
f.Flush()
}
}
// Hijack hijacks the connection.
func (w *writerWithoutCloseNotify) Hijack() (net.Conn, *bufio.ReadWriter, error) {
return w.W.(http.Hijacker).Hijack()
}

View File

@@ -0,0 +1,69 @@
package pipelining
import (
"net/http"
"net/http/httptest"
"testing"
"github.com/stretchr/testify/assert"
)
type recorderWithCloseNotify struct {
*httptest.ResponseRecorder
}
func (r *recorderWithCloseNotify) CloseNotify() <-chan bool {
panic("implement me")
}
func TestNewPipelining(t *testing.T) {
testCases := []struct {
desc string
HTTPMethod string
implementCloseNotifier bool
}{
{
desc: "should not implement CloseNotifier with GET method",
HTTPMethod: http.MethodGet,
implementCloseNotifier: false,
},
{
desc: "should implement CloseNotifier with PUT method",
HTTPMethod: http.MethodPut,
implementCloseNotifier: true,
},
{
desc: "should implement CloseNotifier with POST method",
HTTPMethod: http.MethodPost,
implementCloseNotifier: true,
},
{
desc: "should not implement CloseNotifier with GET method",
HTTPMethod: http.MethodHead,
implementCloseNotifier: false,
},
{
desc: "should not implement CloseNotifier with PROPFIND method",
HTTPMethod: "PROPFIND",
implementCloseNotifier: false,
},
}
for _, test := range testCases {
test := test
t.Run(test.desc, func(t *testing.T) {
t.Parallel()
nextHandler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
_, ok := w.(http.CloseNotifier)
assert.Equal(t, test.implementCloseNotifier, ok)
w.WriteHeader(http.StatusOK)
})
handler := NewPipelining(nextHandler)
req := httptest.NewRequest(test.HTTPMethod, "http://localhost", nil)
handler.ServeHTTP(&recorderWithCloseNotify{httptest.NewRecorder()}, req)
})
}
}

View File

@@ -34,15 +34,9 @@ func getTokenValue(token, domain string, store Store) []byte {
var result []byte
operation := func() error {
var ok bool
httpChallenges, err := store.GetHTTPChallenges()
if err != nil {
return fmt.Errorf("HTTPChallenges not available : %s", err)
}
if result, ok = httpChallenges[token][domain]; !ok {
return fmt.Errorf("cannot find challenge for token %v", token)
}
return nil
var err error
result, err = store.GetHTTPChallengeToken(token, domain)
return err
}
notify := func(err error, time time.Duration) {
@@ -60,40 +54,9 @@ func getTokenValue(token, domain string, store Store) []byte {
}
func presentHTTPChallenge(domain, token, keyAuth string, store Store) error {
httpChallenges, err := store.GetHTTPChallenges()
if err != nil {
return fmt.Errorf("unable to get HTTPChallenges : %s", err)
}
if httpChallenges == nil {
httpChallenges = map[string]map[string][]byte{}
}
if _, ok := httpChallenges[token]; !ok {
httpChallenges[token] = map[string][]byte{}
}
httpChallenges[token][domain] = []byte(keyAuth)
return store.SaveHTTPChallenges(httpChallenges)
return store.SetHTTPChallengeToken(token, domain, []byte(keyAuth))
}
func cleanUpHTTPChallenge(domain, token string, store Store) error {
httpChallenges, err := store.GetHTTPChallenges()
if err != nil {
return fmt.Errorf("unable to get HTTPChallenges : %s", err)
}
log.Debugf("Challenge CleanUp for domain %s", domain)
if _, ok := httpChallenges[token]; ok {
if _, domainOk := httpChallenges[token][domain]; domainOk {
delete(httpChallenges[token], domain)
}
if len(httpChallenges[token]) == 0 {
delete(httpChallenges, token)
}
return store.SaveHTTPChallenges(httpChallenges)
}
return nil
return store.RemoveHTTPChallengeToken(token, domain)
}

View File

@@ -2,9 +2,11 @@ package acme
import (
"encoding/json"
"fmt"
"io/ioutil"
"os"
"regexp"
"sync"
"github.com/containous/traefik/log"
"github.com/containous/traefik/safe"
@@ -17,11 +19,12 @@ type LocalStore struct {
filename string
storedData *StoredData
SaveDataChan chan *StoredData `json:"-"`
lock sync.RWMutex
}
// NewLocalStore initializes a new LocalStore with a file name
func NewLocalStore(filename string) LocalStore {
store := LocalStore{filename: filename, SaveDataChan: make(chan *StoredData)}
func NewLocalStore(filename string) *LocalStore {
store := &LocalStore{filename: filename, SaveDataChan: make(chan *StoredData)}
store.listenSaveAction()
return store
}
@@ -149,13 +152,59 @@ func (s *LocalStore) SaveCertificates(certificates []*Certificate) error {
return nil
}
// GetHTTPChallenges returns ACME HTTP Challenges list
func (s *LocalStore) GetHTTPChallenges() (map[string]map[string][]byte, error) {
return s.storedData.HTTPChallenges, nil
// GetHTTPChallengeToken Get the http challenge token from the store
func (s *LocalStore) GetHTTPChallengeToken(token, domain string) ([]byte, error) {
s.lock.RLock()
defer s.lock.RUnlock()
if s.storedData.HTTPChallenges == nil {
s.storedData.HTTPChallenges = map[string]map[string][]byte{}
}
if _, ok := s.storedData.HTTPChallenges[token]; !ok {
return nil, fmt.Errorf("cannot find challenge for token %v", token)
}
result, ok := s.storedData.HTTPChallenges[token][domain]
if !ok {
return nil, fmt.Errorf("cannot find challenge for token %v", token)
}
return result, nil
}
// SaveHTTPChallenges stores ACME HTTP Challenges list
func (s *LocalStore) SaveHTTPChallenges(httpChallenges map[string]map[string][]byte) error {
s.storedData.HTTPChallenges = httpChallenges
// SetHTTPChallengeToken Set the http challenge token in the store
func (s *LocalStore) SetHTTPChallengeToken(token, domain string, keyAuth []byte) error {
s.lock.Lock()
defer s.lock.Unlock()
if s.storedData.HTTPChallenges == nil {
s.storedData.HTTPChallenges = map[string]map[string][]byte{}
}
if _, ok := s.storedData.HTTPChallenges[token]; !ok {
s.storedData.HTTPChallenges[token] = map[string][]byte{}
}
s.storedData.HTTPChallenges[token][domain] = []byte(keyAuth)
return nil
}
// RemoveHTTPChallengeToken Remove the http challenge token in the store
func (s *LocalStore) RemoveHTTPChallengeToken(token, domain string) error {
s.lock.Lock()
defer s.lock.Unlock()
if s.storedData.HTTPChallenges == nil {
return nil
}
if _, ok := s.storedData.HTTPChallenges[token]; ok {
if _, domainOk := s.storedData.HTTPChallenges[token][domain]; domainOk {
delete(s.storedData.HTTPChallenges[token], domain)
}
if len(s.storedData.HTTPChallenges[token]) == 0 {
delete(s.storedData.HTTPChallenges, token)
}
}
return nil
}
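The store now owns its lock and exposes token-level operations, so the HTTP challenge provider above no longer loads, mutates, and saves the whole challenge map. A toy in-memory stand-in (illustrative only, not Traefik code) that shows the same contract — callers need no external mutex:

```go
package main

import (
	"fmt"
	"sync"
)

// memStore is a toy stand-in exposing the three token-level methods added to
// the Store interface; the lock lives inside the store, as in LocalStore.
type memStore struct {
	mu     sync.Mutex
	tokens map[string]map[string][]byte // token -> domain -> keyAuth
}

func newMemStore() *memStore {
	return &memStore{tokens: map[string]map[string][]byte{}}
}

func (s *memStore) SetHTTPChallengeToken(token, domain string, keyAuth []byte) error {
	s.mu.Lock()
	defer s.mu.Unlock()
	if s.tokens[token] == nil {
		s.tokens[token] = map[string][]byte{}
	}
	s.tokens[token][domain] = keyAuth
	return nil
}

func (s *memStore) GetHTTPChallengeToken(token, domain string) ([]byte, error) {
	s.mu.Lock()
	defer s.mu.Unlock()
	keyAuth, ok := s.tokens[token][domain]
	if !ok {
		return nil, fmt.Errorf("cannot find challenge for token %v", token)
	}
	return keyAuth, nil
}

func (s *memStore) RemoveHTTPChallengeToken(token, domain string) error {
	s.mu.Lock()
	defer s.mu.Unlock()
	delete(s.tokens[token], domain)
	if len(s.tokens[token]) == 0 {
		delete(s.tokens, token)
	}
	return nil
}

func main() {
	store := newMemStore()

	// present -> serve -> clean up, mirroring the challenge provider above.
	_ = store.SetHTTPChallengeToken("token", "example.com", []byte("keyAuth"))
	if v, err := store.GetHTTPChallengeToken("token", "example.com"); err == nil {
		fmt.Printf("would serve %s\n", v)
	}
	_ = store.RemoveHTTPChallengeToken("token", "example.com")
}
```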

View File

@@ -63,6 +63,8 @@ type Provider struct {
clientMutex sync.Mutex
configFromListenerChan chan types.Configuration
pool *safe.Pool
resolvingDomains map[string]struct{}
resolvingDomainsMutex sync.RWMutex
}
// Certificate is a struct which contains all data needed from an ACME certificate
@@ -127,6 +129,9 @@ func (p *Provider) init() error {
return fmt.Errorf("unable to get ACME certificates : %v", err)
}
// Init the currently resolved domain map
p.resolvingDomains = make(map[string]struct{})
p.watchCertificate()
p.watchNewDomains()
@@ -226,6 +231,9 @@ func (p *Provider) resolveCertificate(domain types.Domain, domainFromConfigurati
return nil, nil
}
p.addResolvingDomains(uncheckedDomains)
defer p.removeResolvingDomains(uncheckedDomains)
log.Debugf("Loading ACME certificates %+v...", uncheckedDomains)
client, err := p.getClient()
if err != nil {
@@ -254,6 +262,24 @@ func (p *Provider) resolveCertificate(domain types.Domain, domainFromConfigurati
return certificate, nil
}
func (p *Provider) removeResolvingDomains(resolvingDomains []string) {
p.resolvingDomainsMutex.Lock()
defer p.resolvingDomainsMutex.Unlock()
for _, domain := range resolvingDomains {
delete(p.resolvingDomains, domain)
}
}
func (p *Provider) addResolvingDomains(resolvingDomains []string) {
p.resolvingDomainsMutex.Lock()
defer p.resolvingDomainsMutex.Unlock()
for _, domain := range resolvingDomains {
p.resolvingDomains[domain] = struct{}{}
}
}
func (p *Provider) getClient() (*acme.Client, error) {
p.clientMutex.Lock()
defer p.clientMutex.Unlock()
@@ -503,6 +529,9 @@ func (p *Provider) AddRoutes(router *mux.Router) {
// Get provided certificate which check a domains list (Main and SANs)
// from static and dynamic provided certificates
func (p *Provider) getUncheckedDomains(domainsToCheck []string, checkConfigurationDomains bool) []string {
p.resolvingDomainsMutex.RLock()
defer p.resolvingDomainsMutex.RUnlock()
log.Debugf("Looking for provided certificate(s) to validate %q...", domainsToCheck)
var allCerts []string
@@ -523,6 +552,11 @@ func (p *Provider) getUncheckedDomains(domainsToCheck []string, checkConfigurati
allCerts = append(allCerts, strings.Join(certificate.Domain.ToStrArray(), ","))
}
// Get currently resolved domains
for domain := range p.resolvingDomains {
allCerts = append(allCerts, domain)
}
// Get Configuration Domains
if checkConfigurationDomains {
for i := 0; i < len(p.Domains); i++ {
@@ -540,8 +574,9 @@ func searchUncheckedDomains(domainsToCheck []string, existentDomains []string) [
uncheckedDomains = append(uncheckedDomains, domainToCheck)
}
}
if len(uncheckedDomains) == 0 {
log.Debugf("No ACME certificate to generate for domains %q.", domainsToCheck)
log.Debugf("No ACME certificate generation required for domains %q.", domainsToCheck)
} else {
log.Debugf("Domains %q need ACME certificates generation for domains %q.", domainsToCheck, strings.Join(uncheckedDomains, ","))
}

View File

@@ -26,6 +26,7 @@ func TestGetUncheckedCertificates(t *testing.T) {
desc string
dynamicCerts *safe.Safe
staticCerts map[string]*tls.Certificate
resolvingDomains map[string]struct{}
acmeCertificates []*Certificate
domains []string
expectedDomains []string
@@ -138,17 +139,55 @@ func TestGetUncheckedCertificates(t *testing.T) {
},
expectedDomains: []string{"traefik.wtf"},
},
{
desc: "all domains already managed by ACME",
domains: []string{"traefik.wtf", "foo.traefik.wtf"},
resolvingDomains: map[string]struct{}{
"traefik.wtf": {},
"foo.traefik.wtf": {},
},
expectedDomains: []string{},
},
{
desc: "one domain already managed by ACME",
domains: []string{"traefik.wtf", "foo.traefik.wtf"},
resolvingDomains: map[string]struct{}{
"traefik.wtf": {},
},
expectedDomains: []string{"foo.traefik.wtf"},
},
{
desc: "wildcard domain already managed by ACME checks the domains",
domains: []string{"bar.traefik.wtf", "foo.traefik.wtf"},
resolvingDomains: map[string]struct{}{
"*.traefik.wtf": {},
},
expectedDomains: []string{},
},
{
desc: "wildcard domain already managed by ACME checks domains and another domain checks one other domain, one domain still unchecked",
domains: []string{"traefik.wtf", "bar.traefik.wtf", "foo.traefik.wtf", "acme.wtf"},
resolvingDomains: map[string]struct{}{
"*.traefik.wtf": {},
"traefik.wtf": {},
},
expectedDomains: []string{"acme.wtf"},
},
}
for _, test := range testCases {
test := test
if test.resolvingDomains == nil {
test.resolvingDomains = make(map[string]struct{})
}
t.Run(test.desc, func(t *testing.T) {
t.Parallel()
acmeProvider := Provider{
dynamicCerts: test.dynamicCerts,
staticCerts: test.staticCerts,
certificates: test.acmeCertificates,
dynamicCerts: test.dynamicCerts,
staticCerts: test.staticCerts,
certificates: test.acmeCertificates,
resolvingDomains: test.resolvingDomains,
}
domains := acmeProvider.getUncheckedDomains(test.domains, false)

View File

@@ -13,6 +13,7 @@ type Store interface {
SaveAccount(*Account) error
GetCertificates() ([]*Certificate, error)
SaveCertificates([]*Certificate) error
GetHTTPChallenges() (map[string]map[string][]byte, error)
SaveHTTPChallenges(map[string]map[string][]byte) error
GetHTTPChallengeToken(token, domain string) ([]byte, error)
SetHTTPChallengeToken(token, domain string, keyAuth []byte) error
RemoveHTTPChallengeToken(token, domain string) error
}

View File

@@ -1,7 +1,6 @@
package consulcatalog
import (
"errors"
"fmt"
"strconv"
"strings"
@@ -155,14 +154,8 @@ func (p *Provider) watch(configurationChan chan<- types.ConfigMessage, stop chan
defer close(stopCh)
defer close(watchCh)
for {
select {
case <-stop:
return nil
case index, ok := <-watchCh:
if !ok {
return errors.New("consul service list nil")
}
safe.Go(func() {
for index := range watchCh {
log.Debug("List of services changed")
nodes, err := p.getNodes(index)
if err != nil {
@@ -173,6 +166,13 @@ func (p *Provider) watch(configurationChan chan<- types.ConfigMessage, stop chan
ProviderName: "consul_catalog",
Configuration: configuration,
}
}
})
for {
select {
case <-stop:
return nil
case err := <-errorCh:
return err
}
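The reworked watch loop moves event handling into its own goroutine and leaves the caller's loop to multiplex only the stop signal and asynchronous errors, which is what removes the deadlock. A self-contained sketch of that pattern (using a plain goroutine instead of Traefik's `safe.Go`, and a dummy index type):

```go
package main

import (
	"fmt"
	"time"
)

// watchLoop sketches the reworked flow: event handling runs in its own
// goroutine ranging over watchCh, while this loop only multiplexes the stop
// signal and asynchronous errors, so a blocked handler cannot deadlock it.
// (The real code launches the consumer with safe.Go and closes watchCh on exit.)
func watchLoop(watchCh <-chan uint64, errorCh <-chan error, stop <-chan struct{}) error {
	go func() {
		for index := range watchCh {
			fmt.Println("list of services changed, rebuilding configuration for index", index)
		}
	}()

	for {
		select {
		case <-stop:
			return nil
		case err := <-errorCh:
			return err
		}
	}
}

func main() {
	watchCh := make(chan uint64)
	errorCh := make(chan error)
	stop := make(chan struct{})

	go func() {
		watchCh <- 1                      // simulate a Consul index change
		time.Sleep(10 * time.Millisecond) // let the consumer log it
		close(stop)                       // then ask the watcher to shut down
	}()

	if err := watchLoop(watchCh, errorCh, stop); err != nil {
		fmt.Println("watch error:", err)
	}
}
```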

View File

@@ -2,6 +2,8 @@ package docker
import (
"context"
"crypto/md5"
"encoding/hex"
"fmt"
"net"
"strconv"
@@ -107,13 +109,11 @@ func (p *Provider) buildConfigurationV2(containersInspected []dockerData) *types
}
func getServiceNameKey(container dockerData, swarmMode bool, segmentName string) string {
serviceNameKey := container.ServiceName
if values, err := label.GetStringMultipleStrict(container.Labels, labelDockerComposeProject, labelDockerComposeService); !swarmMode && err == nil {
serviceNameKey = values[labelDockerComposeService] + values[labelDockerComposeProject]
if swarmMode {
return container.ServiceName + segmentName
}
return serviceNameKey + segmentName
return getServiceName(container) + segmentName
}
func (p *Provider) containerFilter(container dockerData) bool {
@@ -170,7 +170,7 @@ func checkSegmentPort(labels map[string]string, segmentName string) error {
func (p *Provider) getFrontendName(container dockerData, idx int) string {
var name string
if len(container.SegmentName) > 0 {
name = getBackendName(container)
name = container.SegmentName + "-" + getBackendName(container)
} else {
name = p.getFrontendRule(container, container.SegmentLabels) + "-" + strconv.Itoa(idx)
}
@@ -262,17 +262,21 @@ func isBackendLBSwarm(container dockerData) bool {
return label.GetBoolValue(container.Labels, labelBackendLoadBalancerSwarm, false)
}
func getSegmentBackendName(container dockerData) string {
serviceName := container.ServiceName
if values, err := label.GetStringMultipleStrict(container.Labels, labelDockerComposeProject, labelDockerComposeService); err == nil {
serviceName = provider.Normalize(values[labelDockerComposeService] + "_" + values[labelDockerComposeProject])
func getBackendName(container dockerData) string {
if len(container.SegmentName) > 0 {
return getSegmentBackendName(container)
}
return getDefaultBackendName(container)
}
func getSegmentBackendName(container dockerData) string {
serviceName := getServiceName(container)
if value := label.GetStringValue(container.SegmentLabels, label.TraefikBackend, ""); len(value) > 0 {
return provider.Normalize(serviceName + "-" + value)
}
return provider.Normalize(serviceName + "-" + getDefaultBackendName(container) + "-" + container.SegmentName)
return provider.Normalize(serviceName + "-" + container.SegmentName)
}
func getDefaultBackendName(container dockerData) string {
@@ -280,19 +284,17 @@ func getDefaultBackendName(container dockerData) string {
return provider.Normalize(value)
}
if values, err := label.GetStringMultipleStrict(container.Labels, labelDockerComposeProject, labelDockerComposeService); err == nil {
return provider.Normalize(values[labelDockerComposeService] + "_" + values[labelDockerComposeProject])
}
return provider.Normalize(container.ServiceName)
return provider.Normalize(getServiceName(container))
}
func getBackendName(container dockerData) string {
if len(container.SegmentName) > 0 {
return getSegmentBackendName(container)
func getServiceName(container dockerData) string {
serviceName := container.ServiceName
if values, err := label.GetStringMultipleStrict(container.Labels, labelDockerComposeProject, labelDockerComposeService); err == nil {
serviceName = values[labelDockerComposeService] + "_" + values[labelDockerComposeProject]
}
return getDefaultBackendName(container)
return serviceName
}
func getPort(container dockerData) string {
@@ -322,7 +324,7 @@ func getPort(container dockerData) string {
func (p *Provider) getServers(containers []dockerData) map[string]types.Server {
var servers map[string]types.Server
for i, container := range containers {
for _, container := range containers {
ip := p.getIPAddress(container)
if len(ip) == 0 {
log.Warnf("Unable to find the IP address for the container %q: the server is ignored.", container.Name)
@@ -336,16 +338,30 @@ func (p *Provider) getServers(containers []dockerData) map[string]types.Server {
protocol := label.GetStringValue(container.SegmentLabels, label.TraefikProtocol, label.DefaultProtocol)
port := getPort(container)
serverName := "server-" + container.SegmentName + "-" + container.Name
if len(container.SegmentName) > 0 {
serverName += "-" + strconv.Itoa(i)
serverURL := fmt.Sprintf("%s://%s", protocol, net.JoinHostPort(ip, port))
serverName := getServerName(container.Name, serverURL)
if _, exist := servers[serverName]; exist {
log.Debugf("Skipping server %q with the same URL.", serverName)
continue
}
servers[provider.Normalize(serverName)] = types.Server{
URL: fmt.Sprintf("%s://%s", protocol, net.JoinHostPort(ip, port)),
servers[serverName] = types.Server{
URL: serverURL,
Weight: label.GetIntValue(container.SegmentLabels, label.TraefikWeight, label.DefaultWeight),
}
}
return servers
}
func getServerName(containerName, url string) string {
hash := md5.New()
_, err := hash.Write([]byte(url))
if err != nil {
// Impossible case
log.Errorf("Fail to hash server URL %q", url)
}
return provider.Normalize("server-" + containerName + "-" + hex.EncodeToString(hash.Sum(nil)))
}
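The new `getServerName` helper derives the server key from a hash of the server URL, so two containers (or two segments) pointing at the same URL collapse into one entry while distinct IPs or ports stay distinct. A small sketch of the idea (it skips the `provider.Normalize` step used above):

```go
package main

import (
	"crypto/md5"
	"encoding/hex"
	"fmt"
)

// serverName derives a stable key from the container name plus a hash of the
// server URL: identical URLs collapse to the same key (and can be skipped),
// while different IPs or ports produce distinct keys.
func serverName(containerName, url string) string {
	sum := md5.Sum([]byte(url))
	return "server-" + containerName + "-" + hex.EncodeToString(sum[:])
}

func main() {
	fmt.Println(serverName("test1", "http://10.10.10.11:80"))
	fmt.Println(serverName("test1", "http://10.10.10.11:80")) // same URL, same name
	fmt.Println(serverName("test1", "http://10.10.10.11:81")) // different port, new name
}
```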

View File

@@ -55,7 +55,7 @@ func TestDockerBuildConfiguration(t *testing.T) {
expectedBackends: map[string]*types.Backend{
"backend-test": {
Servers: map[string]types.Server{
"server-test": {
"server-test-842895ca2aca17f6ee36ddb2f621194d": {
URL: "http://127.0.0.1:80",
Weight: label.DefaultWeight,
},
@@ -270,7 +270,7 @@ func TestDockerBuildConfiguration(t *testing.T) {
expectedBackends: map[string]*types.Backend{
"backend-foobar": {
Servers: map[string]types.Server{
"server-test1": {
"server-test1-7f6444e0dff3330c8b0ad2bbbd383b0f": {
URL: "https://127.0.0.1:666",
Weight: 12,
},
@@ -372,10 +372,11 @@ func TestDockerBuildConfiguration(t *testing.T) {
expectedBackends: map[string]*types.Backend{
"backend-myService-myProject": {
Servers: map[string]types.Server{
"server-test-0": {
"server-test-0-842895ca2aca17f6ee36ddb2f621194d": {
URL: "http://127.0.0.1:80",
Weight: label.DefaultWeight,
}, "server-test-1": {
},
"server-test-1-48093b9fc43454203aacd2bc4057a08c": {
URL: "http://127.0.0.2:80",
Weight: label.DefaultWeight,
},
@@ -384,7 +385,7 @@ func TestDockerBuildConfiguration(t *testing.T) {
},
"backend-myService2-myProject": {
Servers: map[string]types.Server{
"server-test-2": {
"server-test-2-405767e9733427148cd8dae6c4d331b0": {
URL: "http://127.0.0.3:80",
Weight: label.DefaultWeight,
},
@@ -1055,7 +1056,7 @@ func TestDockerGetServers(t *testing.T) {
})),
},
expected: map[string]types.Server{
"server-test1": {
"server-test1-fb00f762970935200c76ccdaf91458f6": {
URL: "http://10.10.10.10:80",
Weight: 1,
},
@@ -1084,15 +1085,15 @@ func TestDockerGetServers(t *testing.T) {
})),
},
expected: map[string]types.Server{
"server-test1": {
"server-test1-743440b6f4a8ffd8737626215f2c5a33": {
URL: "http://10.10.10.11:80",
Weight: 1,
},
"server-test2": {
"server-test2-547f74bbb5da02b6c8141ce9aa96c13b": {
URL: "http://10.10.10.12:81",
Weight: 1,
},
"server-test3": {
"server-test3-c57fd8b848c814a3f2a4a4c12e13c179": {
URL: "http://10.10.10.13:82",
Weight: 1,
},
@@ -1121,11 +1122,11 @@ func TestDockerGetServers(t *testing.T) {
})),
},
expected: map[string]types.Server{
"server-test2": {
"server-test2-547f74bbb5da02b6c8141ce9aa96c13b": {
URL: "http://10.10.10.12:81",
Weight: 1,
},
"server-test3": {
"server-test3-c57fd8b848c814a3f2a4a4c12e13c179": {
URL: "http://10.10.10.13:82",
Weight: 1,
},

View File

@@ -57,7 +57,7 @@ func TestSwarmBuildConfiguration(t *testing.T) {
expectedBackends: map[string]*types.Backend{
"backend-test": {
Servers: map[string]types.Server{
"server-test": {
"server-test-842895ca2aca17f6ee36ddb2f621194d": {
URL: "http://127.0.0.1:80",
Weight: label.DefaultWeight,
},
@@ -238,7 +238,6 @@ func TestSwarmBuildConfiguration(t *testing.T) {
ReferrerPolicy: "foo",
IsDevelopment: true,
},
Errors: map[string]*types.ErrorPage{
"foo": {
Status: []string{"404"},
@@ -276,7 +275,7 @@ func TestSwarmBuildConfiguration(t *testing.T) {
expectedBackends: map[string]*types.Backend{
"backend-foobar": {
Servers: map[string]types.Server{
"server-test1": {
"server-test1-7f6444e0dff3330c8b0ad2bbbd383b0f": {
URL: "https://127.0.0.1:666",
Weight: 12,
},

View File

@@ -42,22 +42,22 @@ func TestSegmentBuildConfiguration(t *testing.T) {
),
},
expectedFrontends: map[string]*types.Frontend{
"frontend-foo-foo-sauternes": {
Backend: "backend-foo-foo-sauternes",
"frontend-sauternes-foo-sauternes": {
Backend: "backend-foo-sauternes",
PassHostHeader: true,
EntryPoints: []string{"http", "https"},
BasicAuth: []string{},
Routes: map[string]types.Route{
"route-frontend-foo-foo-sauternes": {
"route-frontend-sauternes-foo-sauternes": {
Rule: "Host:foo.docker.localhost",
},
},
},
},
expectedBackends: map[string]*types.Backend{
"backend-foo-foo-sauternes": {
"backend-foo-sauternes": {
Servers: map[string]types.Server{
"server-sauternes-foo-0": {
"server-foo-863563a2e23c95502862016417ee95ea": {
URL: "http://127.0.0.1:2503",
Weight: label.DefaultWeight,
},
@@ -132,8 +132,8 @@ func TestSegmentBuildConfiguration(t *testing.T) {
),
},
expectedFrontends: map[string]*types.Frontend{
"frontend-foo-foo-sauternes": {
Backend: "backend-foo-foo-sauternes",
"frontend-sauternes-foo-sauternes": {
Backend: "backend-foo-sauternes",
EntryPoints: []string{
"http",
"https",
@@ -224,16 +224,16 @@ func TestSegmentBuildConfiguration(t *testing.T) {
},
Routes: map[string]types.Route{
"route-frontend-foo-foo-sauternes": {
"route-frontend-sauternes-foo-sauternes": {
Rule: "Host:foo.docker.localhost",
},
},
},
},
expectedBackends: map[string]*types.Backend{
"backend-foo-foo-sauternes": {
"backend-foo-sauternes": {
Servers: map[string]types.Server{
"server-sauternes-foo-0": {
"server-foo-7f6444e0dff3330c8b0ad2bbbd383b0f": {
URL: "https://127.0.0.1:666",
Weight: 12,
},
@@ -278,7 +278,7 @@ func TestSegmentBuildConfiguration(t *testing.T) {
),
},
expectedFrontends: map[string]*types.Frontend{
"frontend-test1-foobar": {
"frontend-sauternes-test1-foobar": {
Backend: "backend-test1-foobar",
PassHostHeader: false,
Priority: 5000,
@@ -288,18 +288,18 @@ func TestSegmentBuildConfiguration(t *testing.T) {
EntryPoint: "https",
},
Routes: map[string]types.Route{
"route-frontend-test1-foobar": {
"route-frontend-sauternes-test1-foobar": {
Rule: "Path:/mypath",
},
},
},
"frontend-test2-test2-anothersauternes": {
Backend: "backend-test2-test2-anothersauternes",
"frontend-anothersauternes-test2-anothersauternes": {
Backend: "backend-test2-anothersauternes",
PassHostHeader: true,
EntryPoints: []string{},
BasicAuth: []string{},
Routes: map[string]types.Route{
"route-frontend-test2-test2-anothersauternes": {
"route-frontend-anothersauternes-test2-anothersauternes": {
Rule: "Path:/anotherpath",
},
},
@@ -308,16 +308,16 @@ func TestSegmentBuildConfiguration(t *testing.T) {
expectedBackends: map[string]*types.Backend{
"backend-test1-foobar": {
Servers: map[string]types.Server{
"server-sauternes-test1-0": {
"server-test1-79533a101142718f0fdf84c42593c41e": {
URL: "https://127.0.0.1:2503",
Weight: 80,
},
},
CircuitBreaker: nil,
},
"backend-test2-test2-anothersauternes": {
"backend-test2-anothersauternes": {
Servers: map[string]types.Server{
"server-anothersauternes-test2-0": {
"server-test2-e9c1b66f9af919aa46053fbc2391bb4a": {
URL: "http://127.0.0.1:8079",
Weight: 33,
},
@@ -326,6 +326,152 @@ func TestSegmentBuildConfiguration(t *testing.T) {
},
},
},
{
desc: "several segments with the same backend name and same port",
containers: []docker.ContainerJSON{
containerJSON(
name("test1"),
labels(map[string]string{
"traefik.port": "2503",
"traefik.protocol": "https",
"traefik.weight": "80",
"traefik.frontend.entryPoints": "http,https",
"traefik.frontend.redirect.entryPoint": "https",
"traefik.sauternes.backend": "foobar",
"traefik.sauternes.frontend.rule": "Path:/sauternes",
"traefik.sauternes.frontend.priority": "5000",
"traefik.arbois.backend": "foobar",
"traefik.arbois.frontend.rule": "Path:/arbois",
"traefik.arbois.frontend.priority": "3000",
}),
ports(nat.PortMap{
"80/tcp": {},
}),
withNetwork("bridge", ipv4("127.0.0.1")),
),
},
expectedFrontends: map[string]*types.Frontend{
"frontend-sauternes-test1-foobar": {
Backend: "backend-test1-foobar",
PassHostHeader: true,
Priority: 5000,
EntryPoints: []string{"http", "https"},
BasicAuth: []string{},
Redirect: &types.Redirect{
EntryPoint: "https",
},
Routes: map[string]types.Route{
"route-frontend-sauternes-test1-foobar": {
Rule: "Path:/sauternes",
},
},
},
"frontend-arbois-test1-foobar": {
Backend: "backend-test1-foobar",
PassHostHeader: true,
Priority: 3000,
EntryPoints: []string{"http", "https"},
BasicAuth: []string{},
Redirect: &types.Redirect{
EntryPoint: "https",
},
Routes: map[string]types.Route{
"route-frontend-arbois-test1-foobar": {
Rule: "Path:/arbois",
},
},
},
},
expectedBackends: map[string]*types.Backend{
"backend-test1-foobar": {
Servers: map[string]types.Server{
"server-test1-79533a101142718f0fdf84c42593c41e": {
URL: "https://127.0.0.1:2503",
Weight: 80,
},
},
CircuitBreaker: nil,
},
},
},
{
desc: "several segments with the same backend name and different port (wrong behavior)",
containers: []docker.ContainerJSON{
containerJSON(
name("test1"),
labels(map[string]string{
"traefik.protocol": "https",
"traefik.frontend.entryPoints": "http,https",
"traefik.frontend.redirect.entryPoint": "https",
"traefik.sauternes.port": "2503",
"traefik.sauternes.weight": "80",
"traefik.sauternes.backend": "foobar",
"traefik.sauternes.frontend.rule": "Path:/sauternes",
"traefik.sauternes.frontend.priority": "5000",
"traefik.arbois.port": "2504",
"traefik.arbois.weight": "90",
"traefik.arbois.backend": "foobar",
"traefik.arbois.frontend.rule": "Path:/arbois",
"traefik.arbois.frontend.priority": "3000",
}),
ports(nat.PortMap{
"80/tcp": {},
}),
withNetwork("bridge", ipv4("127.0.0.1")),
),
},
expectedFrontends: map[string]*types.Frontend{
"frontend-sauternes-test1-foobar": {
Backend: "backend-test1-foobar",
PassHostHeader: true,
Priority: 5000,
EntryPoints: []string{"http", "https"},
BasicAuth: []string{},
Redirect: &types.Redirect{
EntryPoint: "https",
},
Routes: map[string]types.Route{
"route-frontend-sauternes-test1-foobar": {
Rule: "Path:/sauternes",
},
},
},
"frontend-arbois-test1-foobar": {
Backend: "backend-test1-foobar",
PassHostHeader: true,
Priority: 3000,
EntryPoints: []string{"http", "https"},
BasicAuth: []string{},
Redirect: &types.Redirect{
EntryPoint: "https",
},
Routes: map[string]types.Route{
"route-frontend-arbois-test1-foobar": {
Rule: "Path:/arbois",
},
},
},
},
expectedBackends: map[string]*types.Backend{
"backend-test1-foobar": {
Servers: map[string]types.Server{
"server-test1-79533a101142718f0fdf84c42593c41e": {
URL: "https://127.0.0.1:2503",
Weight: 80,
},
"server-test1-315a41140f1bd825b066e39686c18482": {
URL: "https://127.0.0.1:2504",
Weight: 90,
},
},
CircuitBreaker: nil,
},
},
},
}
provider := &Provider{


@@ -205,7 +205,7 @@ func getFuncFirstStringValueV1(labelName string, defaultValue string) func(insta
// Deprecated
func getFuncFirstBoolValueV1(labelName string, defaultValue bool) func(instances []ecsInstance) bool {
return func(instances []ecsInstance) bool {
if len(instances) < 0 {
if len(instances) == 0 {
return defaultValue
}
return getBoolValueV1(instances[0], labelName, defaultValue)
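The one-character fix above matters because a Go slice length can never be negative, so the old `len(instances) < 0` guard was dead code and an empty instance list fell straight through to `instances[0]` and panicked. A minimal standalone sketch of the corrected guard (simplified types, not the actual provider code):

```go
package main

import "fmt"

type ecsInstance struct {
	labels map[string]string
}

// firstBoolValue returns the label of the first instance, or the default
// when the slice is empty; with the old `< 0` check this branch never ran.
func firstBoolValue(instances []ecsInstance, labelName string, defaultValue bool) bool {
	if len(instances) == 0 {
		return defaultValue
	}
	v, ok := instances[0].labels[labelName]
	if !ok {
		return defaultValue
	}
	return v == "true"
}

func main() {
	fmt.Println(firstBoolValue(nil, "traefik.frontend.passHostHeader", true)) // true, no panic
}
```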


@@ -1,4 +1,4 @@
mkdocs>=0.17.3
pymdown-extensions>=1.4
mkdocs-bootswatch>=0.4.0
mkdocs-material>=2.2.6
mkdocs==0.17.5
pymdown-extensions==4.12
mkdocs-bootswatch==0.5.0
mkdocs-material==2.9.4


@@ -1,13 +1,21 @@
package server
import (
"context"
"io"
"net"
"net/http"
"github.com/containous/traefik/log"
"github.com/containous/traefik/middlewares"
)
// StatusClientClosedRequest non-standard HTTP status code for client disconnection
const StatusClientClosedRequest = 499
// StatusClientClosedRequestText non-standard HTTP status for client disconnection
const StatusClientClosedRequestText = "Client Closed Request"
// RecordingErrorHandler is an error handler, implementing the vulcand/oxy
// error handler interface, which is recording network errors by using the netErrorRecorder.
// In addition it sets a proper HTTP status code and body, depending on the type of error occurred.
@@ -33,8 +41,18 @@ func (eh *RecordingErrorHandler) ServeHTTP(w http.ResponseWriter, req *http.Requ
} else if err == io.EOF {
eh.netErrorRecorder.Record(req.Context())
statusCode = http.StatusBadGateway
} else if err == context.Canceled {
statusCode = StatusClientClosedRequest
}
w.WriteHeader(statusCode)
w.Write([]byte(http.StatusText(statusCode)))
w.Write([]byte(statusText(statusCode)))
log.Debugf("'%d %s' caused by: %v", statusCode, statusText(statusCode), err)
}
func statusText(statusCode int) string {
if statusCode == StatusClientClosedRequest {
return StatusClientClosedRequestText
}
return http.StatusText(statusCode)
}
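The helper keeps the response body and the log line consistent for the non-standard 499 status while deferring to `net/http` for everything else. A standalone sketch of the mapping (constants copied so it compiles on its own):

```go
package main

import (
	"fmt"
	"net/http"
)

const (
	StatusClientClosedRequest     = 499
	StatusClientClosedRequestText = "Client Closed Request"
)

// statusText falls back to the standard library for every code it knows about.
func statusText(statusCode int) string {
	if statusCode == StatusClientClosedRequest {
		return StatusClientClosedRequestText
	}
	return http.StatusText(statusCode)
}

func main() {
	fmt.Println(statusText(StatusClientClosedRequest)) // Client Closed Request
	fmt.Println(statusText(http.StatusBadGateway))     // Bad Gateway
}
```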


@@ -32,6 +32,7 @@ import (
"github.com/containous/traefik/middlewares/accesslog"
mauth "github.com/containous/traefik/middlewares/auth"
"github.com/containous/traefik/middlewares/errorpages"
"github.com/containous/traefik/middlewares/pipelining"
"github.com/containous/traefik/middlewares/redirect"
"github.com/containous/traefik/middlewares/tracing"
"github.com/containous/traefik/provider"
@@ -79,14 +80,112 @@ type Server struct {
bufferPool httputil.BufferPool
}
func newHijackConnectionTracker() *hijackConnectionTracker {
return &hijackConnectionTracker{
conns: make(map[net.Conn]struct{}),
}
}
type hijackConnectionTracker struct {
conns map[net.Conn]struct{}
lock sync.RWMutex
}
// AddHijackedConnection add a connection in the tracked connections list
func (h *hijackConnectionTracker) AddHijackedConnection(conn net.Conn) {
h.lock.Lock()
defer h.lock.Unlock()
h.conns[conn] = struct{}{}
}
// RemoveHijackedConnection remove a connection from the tracked connections list
func (h *hijackConnectionTracker) RemoveHijackedConnection(conn net.Conn) {
h.lock.Lock()
defer h.lock.Unlock()
delete(h.conns, conn)
}
// Shutdown wait for the connection closing
func (h *hijackConnectionTracker) Shutdown(ctx context.Context) error {
ticker := time.NewTicker(500 * time.Millisecond)
defer ticker.Stop()
for {
h.lock.RLock()
if len(h.conns) == 0 {
return nil
}
h.lock.RUnlock()
select {
case <-ctx.Done():
return ctx.Err()
case <-ticker.C:
}
}
}
// Close close all the connections in the tracked connections list
func (h *hijackConnectionTracker) Close() {
for conn := range h.conns {
if err := conn.Close(); err != nil {
log.Errorf("Error while closing Hijacked conn: %v", err)
}
delete(h.conns, conn)
}
}
type serverEntryPoints map[string]*serverEntryPoint
type serverEntryPoint struct {
httpServer *http.Server
listener net.Listener
httpRouter *middlewares.HandlerSwitcher
certs safe.Safe
onDemandListener func(string) (*tls.Certificate, error)
httpServer *http.Server
listener net.Listener
httpRouter *middlewares.HandlerSwitcher
certs safe.Safe
onDemandListener func(string) (*tls.Certificate, error)
hijackConnectionTracker *hijackConnectionTracker
}
func (s serverEntryPoint) Shutdown(ctx context.Context) {
var wg sync.WaitGroup
wg.Add(1)
go func() {
defer wg.Done()
if err := s.httpServer.Shutdown(ctx); err != nil {
if ctx.Err() == context.DeadlineExceeded {
log.Debugf("Wait server shutdown is over due to: %s", err)
err = s.httpServer.Close()
if err != nil {
log.Error(err)
}
}
}
}()
wg.Add(1)
go func() {
defer wg.Done()
if err := s.hijackConnectionTracker.Shutdown(ctx); err != nil {
if ctx.Err() == context.DeadlineExceeded {
log.Debugf("Wait hijack connection is over due to: %s", err)
s.hijackConnectionTracker.Close()
}
}
}()
wg.Wait()
}
// tcpKeepAliveListener sets TCP keep-alive timeouts on accepted
// connections.
type tcpKeepAliveListener struct {
*net.TCPListener
}
func (ln tcpKeepAliveListener) Accept() (net.Conn, error) {
tc, err := ln.AcceptTCP()
if err != nil {
return nil, err
}
tc.SetKeepAlive(true)
tc.SetKeepAlivePeriod(3 * time.Minute)
return tc, nil
}
// NewServer returns an initialized Server.
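The wrapper is the same trick `net/http` uses for its own default listener: without keep-alive probes, a silently dead client never produces traffic, so `idleTimeout` has nothing to act on. A minimal sketch of plugging such a listener into a plain `http.Server` (ports and timeouts are arbitrary, not Traefik's entry-point code):

```go
package main

import (
	"log"
	"net"
	"net/http"
	"time"
)

type tcpKeepAliveListener struct {
	*net.TCPListener
}

func (ln tcpKeepAliveListener) Accept() (net.Conn, error) {
	tc, err := ln.AcceptTCP()
	if err != nil {
		return nil, err
	}
	tc.SetKeepAlive(true)                  // let the kernel probe idle peers
	tc.SetKeepAlivePeriod(3 * time.Minute) // same period as the change above
	return tc, nil
}

func main() {
	ln, err := net.Listen("tcp", "127.0.0.1:8080")
	if err != nil {
		log.Fatal(err)
	}
	srv := &http.Server{
		Handler:     http.NotFoundHandler(),
		IdleTimeout: 90 * time.Second, // effective now that dead peers get detected
	}
	log.Fatal(srv.Serve(tcpKeepAliveListener{ln.(*net.TCPListener)}))
}
```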
@@ -243,10 +342,7 @@ func (s *Server) Stop() {
graceTimeOut := time.Duration(s.globalConfiguration.LifeCycle.GraceTimeOut)
ctx, cancel := context.WithTimeout(context.Background(), graceTimeOut)
log.Debugf("Waiting %s seconds before killing connections on entrypoint %s...", graceTimeOut, serverEntryPointName)
if err := serverEntryPoint.httpServer.Shutdown(ctx); err != nil {
log.Debugf("Wait is over due to: %s", err)
serverEntryPoint.httpServer.Close()
}
serverEntryPoint.Shutdown(ctx)
cancel()
log.Debugf("Entrypoint %s closed", serverEntryPointName)
}(sepn, sep)
@@ -359,9 +455,20 @@ func (s *Server) setupServerEntryPoint(newServerEntryPointName string, newServer
log.Fatal("Error preparing server: ", err)
}
serverEntryPoint := s.serverEntryPoints[newServerEntryPointName]
serverEntryPoint.httpServer = newSrv
serverEntryPoint.listener = listener
serverEntryPoint.hijackConnectionTracker = newHijackConnectionTracker()
serverEntryPoint.httpServer.ConnState = func(conn net.Conn, state http.ConnState) {
switch state {
case http.StateHijacked:
serverEntryPoint.hijackConnectionTracker.AddHijackedConnection(conn)
case http.StateClosed:
serverEntryPoint.hijackConnectionTracker.RemoveHijackedConnection(conn)
}
}
return serverEntryPoint
}
@@ -802,6 +909,8 @@ func (s *Server) prepareServer(entryPointName string, entryPoint *configuration.
return nil, nil, err
}
listener = tcpKeepAliveListener{listener.(*net.TCPListener)}
if entryPoint.ProxyProtocol != nil {
IPs, err := whitelist.NewIP(entryPoint.ProxyProtocol.TrustedIPs, entryPoint.ProxyProtocol.Insecure, false)
if err != nil {
@@ -1006,6 +1115,15 @@ func (s *Server) loadConfig(configurations types.Configurations, globalConfigura
forward.Rewriter(rewriter),
forward.ResponseModifier(responseModifier),
forward.BufferPool(s.bufferPool),
forward.WebsocketConnectionClosedHook(func(req *http.Request, conn net.Conn) {
server := req.Context().Value(http.ServerContextKey).(*http.Server)
if server != nil {
connState := server.ConnState
if connState != nil {
connState(conn, http.StateClosed)
}
}
}),
)
if err != nil {
@@ -1023,6 +1141,8 @@ func (s *Server) loadConfig(configurations types.Configurations, globalConfigura
})
}
fwd = pipelining.NewPipelining(fwd)
var rr *roundrobin.RoundRobin
var saveFrontend http.Handler
if s.accessLoggerMiddleware != nil {
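The `WebsocketConnectionClosedHook` wired up above relies on `http.ServerContextKey`, which the standard library stores in every request context so a handler can reach the `*http.Server` that accepted the connection and replay its `ConnState` callback. A minimal sketch of that mechanism on its own (address and handler are hypothetical):

```go
package main

import (
	"fmt"
	"log"
	"net/http"
)

func main() {
	srv := &http.Server{Addr: "127.0.0.1:8085"}
	srv.Handler = http.HandlerFunc(func(w http.ResponseWriter, req *http.Request) {
		// Every request context carries a pointer to the server that accepted it.
		if s, ok := req.Context().Value(http.ServerContextKey).(*http.Server); ok {
			fmt.Fprintf(w, "served by %s\n", s.Addr)
		}
	})
	log.Fatal(srv.ListenAndServe())
}
```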


@@ -235,7 +235,7 @@ type Configurations map[string]*Configuration
type Configuration struct {
Backends map[string]*Backend `json:"backends,omitempty"`
Frontends map[string]*Frontend `json:"frontends,omitempty"`
TLS []*traefiktls.Configuration `json:"tls,omitempty"`
TLS []*traefiktls.Configuration `json:"-"`
}
// ConfigMessage hold configuration information exchanged between parts of traefik.
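The `json:"-"` tag is what keeps certificate material out of the API responses: `encoding/json` skips the field entirely when marshaling. A reduced sketch of the effect (two-field struct, made-up values):

```go
package main

import (
	"encoding/json"
	"fmt"
)

type Configuration struct {
	Backends map[string]string `json:"backends,omitempty"`
	TLS      []string          `json:"-"` // never serialized, whatever it holds
}

func main() {
	c := Configuration{
		Backends: map[string]string{"backend1": "http://127.0.0.1:8080"},
		TLS:      []string{"-----BEGIN CERTIFICATE----- ..."},
	}
	out, err := json.Marshal(c)
	if err != nil {
		panic(err)
	}
	fmt.Println(string(out)) // {"backends":{"backend1":"http://127.0.0.1:8080"}}
}
```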


@@ -46,16 +46,16 @@ func (kv *KvSource) Parse(cmd *flaeg.Command) (*flaeg.Command, error) {
// LoadConfig loads data from the KV Store into the config structure (given by reference)
func (kv *KvSource) LoadConfig(config interface{}) error {
pairs := map[string][]byte{}
if err := kv.ListRecursive(kv.Prefix, pairs); err != nil {
pairs, err := kv.ListValuedPairWithPrefix(kv.Prefix)
if err != nil {
return err
}
// fmt.Printf("pairs : %#v\n", pairs)
mapStruct, err := generateMapstructure(convertPairs(pairs), kv.Prefix)
if err != nil {
return err
}
// fmt.Printf("mapStruct : %#v\n", mapStruct)
configDecoder := &mapstructure.DecoderConfig{
Metadata: nil,
Result: config,
@@ -77,11 +77,11 @@ func generateMapstructure(pairs []*store.KVPair, prefix string) (map[string]inte
for _, p := range pairs {
// Trim the prefix off our key first
key := strings.TrimPrefix(strings.Trim(p.Key, "/"), strings.Trim(prefix, "/")+"/")
raw, err := processKV(key, p.Value, raw)
var err error
raw, err = processKV(key, p.Value, raw)
if err != nil {
return raw, err
}
}
return raw, nil
}
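Pre-declaring `err` instead of using `:=` in the loop avoids shadowing: with `:=`, the loop body gets its own `raw`, so a helper that returns a replacement map rather than mutating it in place would never update the accumulator in the enclosing scope. A small illustration of the pitfall with made-up helpers:

```go
package main

import "fmt"

// addKey returns a (possibly newly allocated) map with one more entry.
func addKey(m map[string]int, k string) (map[string]int, error) {
	if m == nil {
		m = map[string]int{}
	}
	m[k] = len(m)
	return m, nil
}

func main() {
	var raw map[string]int
	for _, k := range []string{"a", "b"} {
		// Wrong: `raw, err := addKey(raw, k)` would declare a new `raw`
		// scoped to this iteration and leave the outer one nil.
		var err error
		raw, err = addKey(raw, k)
		if err != nil {
			panic(err)
		}
	}
	fmt.Println(raw) // map[a:0 b:1]
}
```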
@@ -313,15 +313,23 @@ func collateKvRecursive(objValue reflect.Value, kv map[string]string, key string
func writeCompressedData(data []byte) (string, error) {
var buffer bytes.Buffer
gzipWriter := gzip.NewWriter(&buffer)
_, err := gzipWriter.Write(data)
if err != nil {
return "", err
}
gzipWriter.Close()
err = gzipWriter.Close()
if err != nil {
return "", err
}
return buffer.String(), nil
}
// ListRecursive lists all key value children under key
// Replaced by ListValuedPairWithPrefix
// Deprecated
func (kv *KvSource) ListRecursive(key string, pairs map[string][]byte) error {
pairsN1, err := kv.List(key, nil)
if err == store.ErrKeyNotFound {
@@ -342,14 +350,37 @@ func (kv *KvSource) ListRecursive(key string, pairs map[string][]byte) error {
return nil
}
for _, p := range pairsN1 {
err := kv.ListRecursive(p.Key, pairs)
if err != nil {
return err
if p.Key != key {
err := kv.ListRecursive(p.Key, pairs)
if err != nil {
return err
}
}
}
return nil
}
// ListValuedPairWithPrefix lists all key value children under key
func (kv *KvSource) ListValuedPairWithPrefix(key string) (map[string][]byte, error) {
pairs := make(map[string][]byte)
pairsN1, err := kv.List(key, nil)
if err == store.ErrKeyNotFound {
return pairs, nil
}
if err != nil {
return pairs, err
}
for _, p := range pairsN1 {
if len(p.Value) > 0 {
pairs[p.Key] = p.Value
}
}
return pairs, nil
}
func convertPairs(pairs map[string][]byte) []*store.KVPair {
slicePairs := make([]*store.KVPair, len(pairs))
i := 0


@@ -2,12 +2,8 @@ package staert
import (
"fmt"
"os"
"path/filepath"
"reflect"
"strings"
"github.com/BurntSushi/toml"
"github.com/containous/flaeg"
)
@@ -24,10 +20,7 @@ type Staert struct {
// NewStaert creates and return a pointer on Staert. Need defaultConfig and defaultPointersConfig given by references
func NewStaert(rootCommand *flaeg.Command) *Staert {
s := Staert{
command: rootCommand,
}
return &s
return &Staert{command: rootCommand}
}
// AddSource adds new Source to Staert, give it by reference
@@ -35,40 +28,31 @@ func (s *Staert) AddSource(src Source) {
s.sources = append(s.sources, src)
}
// getConfig for a flaeg.Command run sources Parse func in the raw
func (s *Staert) parseConfigAllSources(cmd *flaeg.Command) error {
for _, src := range s.sources {
var err error
_, err = src.Parse(cmd)
if err != nil {
return err
}
}
return nil
}
// LoadConfig check which command is called and parses config
// It returns the the parsed config or an error if it fails
func (s *Staert) LoadConfig() (interface{}, error) {
for _, src := range s.sources {
//Type assertion
f, ok := src.(*flaeg.Flaeg)
if ok {
if fCmd, err := f.GetCommand(); err != nil {
// Type assertion
if flg, ok := src.(*flaeg.Flaeg); ok {
fCmd, err := flg.GetCommand()
if err != nil {
return nil, err
} else if s.command != fCmd {
//IF fleag sub-command
}
// if fleag sub-command
if s.command != fCmd {
// if parseAllSources
if fCmd.Metadata["parseAllSources"] == "true" {
//IF parseAllSources
fCmdConfigType := reflect.TypeOf(fCmd.Config)
sCmdConfigType := reflect.TypeOf(s.command.Config)
if fCmdConfigType != sCmdConfigType {
return nil, fmt.Errorf("command %s : Config type doesn't match with root command config type. Expected %s got %s", fCmd.Name, sCmdConfigType.Name(), fCmdConfigType.Name())
return nil, fmt.Errorf("command %s : Config type doesn't match with root command config type. Expected %s got %s",
fCmd.Name, sCmdConfigType.Name(), fCmdConfigType.Name())
}
s.command = fCmd
} else {
// ELSE (not parseAllSources)
s.command, err = f.Parse(fCmd)
// (not parseAllSources)
s.command, err = flg.Parse(fCmd)
return s.command.Config, err
}
}
@@ -78,117 +62,19 @@ func (s *Staert) LoadConfig() (interface{}, error) {
return s.command.Config, err
}
// parseConfigAllSources getConfig for a flaeg.Command run sources Parse func in the raw
func (s *Staert) parseConfigAllSources(cmd *flaeg.Command) error {
for _, src := range s.sources {
_, err := src.Parse(cmd)
if err != nil {
return err
}
}
return nil
}
// Run calls the Run func of the command
// Warning, Run doesn't parse the config
func (s *Staert) Run() error {
return s.command.Run()
}
//TomlSource impement Source
type TomlSource struct {
filename string
dirNfullpath []string
fullpath string
}
// NewTomlSource creates and return a pointer on TomlSource.
// Parameter filename is the file name (without extension type, ".toml" will be added)
// dirNfullpath may contain directories or fullpath to the file.
func NewTomlSource(filename string, dirNfullpath []string) *TomlSource {
return &TomlSource{filename, dirNfullpath, ""}
}
// ConfigFileUsed return config file used
func (ts *TomlSource) ConfigFileUsed() string {
return ts.fullpath
}
func preprocessDir(dirIn string) (string, error) {
dirOut := dirIn
expanded := os.ExpandEnv(dirIn)
dirOut, err := filepath.Abs(expanded)
return dirOut, err
}
func findFile(filename string, dirNfile []string) string {
for _, df := range dirNfile {
if df != "" {
fullPath, _ := preprocessDir(df)
if fileInfo, err := os.Stat(fullPath); err == nil && !fileInfo.IsDir() {
return fullPath
}
fullPath = filepath.Join(fullPath, filename+".toml")
if fileInfo, err := os.Stat(fullPath); err == nil && !fileInfo.IsDir() {
return fullPath
}
}
}
return ""
}
// Parse calls toml.DecodeFile() func
func (ts *TomlSource) Parse(cmd *flaeg.Command) (*flaeg.Command, error) {
ts.fullpath = findFile(ts.filename, ts.dirNfullpath)
if len(ts.fullpath) < 2 {
return cmd, nil
}
metadata, err := toml.DecodeFile(ts.fullpath, cmd.Config)
if err != nil {
return nil, err
}
boolFlags, err := flaeg.GetBoolFlags(cmd.Config)
if err != nil {
return nil, err
}
flaegArgs, hasUnderField, err := generateArgs(metadata, boolFlags)
if err != nil {
return nil, err
}
// fmt.Println(flaegArgs)
err = flaeg.Load(cmd.Config, cmd.DefaultPointersConfig, flaegArgs)
//if err!= missing parser err
if err != nil && err != flaeg.ErrParserNotFound {
return nil, err
}
if hasUnderField {
_, err := toml.DecodeFile(ts.fullpath, cmd.Config)
if err != nil {
return nil, err
}
}
return cmd, nil
}
func generateArgs(metadata toml.MetaData, flags []string) ([]string, bool, error) {
var flaegArgs []string
keys := metadata.Keys()
hasUnderField := false
for i, key := range keys {
// fmt.Println(key)
if metadata.Type(key.String()) == "Hash" {
// TOML hashes correspond to Go structs or maps.
// fmt.Printf("%s could be a ptr on a struct, or a map\n", key)
for j := i; j < len(keys); j++ {
// fmt.Printf("%s =? %s\n", keys[j].String(), "."+key.String())
if strings.Contains(keys[j].String(), key.String()+".") {
hasUnderField = true
break
}
}
match := false
for _, flag := range flags {
if flag == strings.ToLower(key.String()) {
match = true
break
}
}
if match {
flaegArgs = append(flaegArgs, "--"+strings.ToLower(key.String()))
}
}
}
return flaegArgs, hasUnderField, nil
}

vendor/github.com/containous/staert/toml.go (generated, vendored, new file: 121 lines added)

@@ -0,0 +1,121 @@
package staert
import (
"os"
"path/filepath"
"strings"
"github.com/BurntSushi/toml"
"github.com/containous/flaeg"
)
var _ Source = (*TomlSource)(nil)
// TomlSource implement staert.Source
type TomlSource struct {
filename string
dirNFullPath []string
fullPath string
}
// NewTomlSource creates and return a pointer on Source.
// Parameter filename is the file name (without extension type, ".toml" will be added)
// dirNFullPath may contain directories or fullPath to the file.
func NewTomlSource(filename string, dirNFullPath []string) *TomlSource {
return &TomlSource{filename, dirNFullPath, ""}
}
// ConfigFileUsed return config file used
func (ts *TomlSource) ConfigFileUsed() string {
return ts.fullPath
}
// Parse calls toml.DecodeFile() func
func (ts *TomlSource) Parse(cmd *flaeg.Command) (*flaeg.Command, error) {
ts.fullPath = findFile(ts.filename, ts.dirNFullPath)
if len(ts.fullPath) < 2 {
return cmd, nil
}
metadata, err := toml.DecodeFile(ts.fullPath, cmd.Config)
if err != nil {
return nil, err
}
boolFlags, err := flaeg.GetBoolFlags(cmd.Config)
if err != nil {
return nil, err
}
flgArgs, hasUnderField, err := generateArgs(metadata, boolFlags)
if err != nil {
return nil, err
}
err = flaeg.Load(cmd.Config, cmd.DefaultPointersConfig, flgArgs)
if err != nil && err != flaeg.ErrParserNotFound {
return nil, err
}
if hasUnderField {
_, err := toml.DecodeFile(ts.fullPath, cmd.Config)
if err != nil {
return nil, err
}
}
return cmd, nil
}
func preProcessDir(dirIn string) (string, error) {
expanded := os.ExpandEnv(dirIn)
return filepath.Abs(expanded)
}
func findFile(filename string, dirNFile []string) string {
for _, df := range dirNFile {
if df != "" {
fullPath, _ := preProcessDir(df)
if fileInfo, err := os.Stat(fullPath); err == nil && !fileInfo.IsDir() {
return fullPath
}
fullPath = filepath.Join(fullPath, filename+".toml")
if fileInfo, err := os.Stat(fullPath); err == nil && !fileInfo.IsDir() {
return fullPath
}
}
}
return ""
}
func generateArgs(metadata toml.MetaData, flags []string) ([]string, bool, error) {
var flgArgs []string
keys := metadata.Keys()
hasUnderField := false
for i, key := range keys {
if metadata.Type(key.String()) == "Hash" {
// TOML hashes correspond to Go structs or maps.
for j := i; j < len(keys); j++ {
if strings.Contains(keys[j].String(), key.String()+".") {
hasUnderField = true
break
}
}
match := false
for _, flag := range flags {
if flag == strings.ToLower(key.String()) {
match = true
break
}
}
if match {
flgArgs = append(flgArgs, "--"+strings.ToLower(key.String()))
}
}
}
return flgArgs, hasUnderField, nil
}
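Taken together, the rewritten staert sources are consumed roughly as follows; a sketch assuming the usual flaeg/staert wiring, with a made-up command, config struct, and file name:

```go
package main

import (
	"fmt"

	"github.com/containous/flaeg"
	"github.com/containous/staert"
)

type demoConfig struct {
	LogLevel string `description:"Log level"`
}

func main() {
	cfg := &demoConfig{LogLevel: "INFO"}
	rootCmd := &flaeg.Command{
		Name:                  "demo",
		Description:           "demo command",
		Config:                cfg,
		DefaultPointersConfig: &demoConfig{},
		Run: func() error {
			fmt.Println("log level:", cfg.LogLevel)
			return nil
		},
	}

	s := staert.NewStaert(rootCmd)
	// Look for ./demo.toml via findFile and merge it into cfg when present.
	s.AddSource(staert.NewTomlSource("demo", []string{"."}))

	if _, err := s.LoadConfig(); err != nil {
		panic(err)
	}
	if err := s.Run(); err != nil {
		panic(err)
	}
}
```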


@@ -36,13 +36,12 @@ Examples of a buffering middleware:
package buffer
import (
"bufio"
"fmt"
"io"
"io/ioutil"
"net/http"
"bufio"
"net"
"net/http"
"reflect"
"github.com/mailgun/multibuf"
@@ -74,6 +73,8 @@ type Buffer struct {
next http.Handler
errHandler utils.ErrorHandler
log *log.Logger
}
// New returns a new buffer middleware. New() function supports optional functional arguments
@@ -86,6 +87,8 @@ func New(next http.Handler, setters ...optSetter) (*Buffer, error) {
maxResponseBodyBytes: DefaultMaxBodyBytes,
memResponseBodyBytes: DefaultMemBodyBytes,
log: log.StandardLogger(),
}
for _, s := range setters {
if err := s(strm); err != nil {
@@ -99,6 +102,16 @@ func New(next http.Handler, setters ...optSetter) (*Buffer, error) {
return strm, nil
}
// Logger defines the logger the buffer will use.
//
// It defaults to logrus.StandardLogger(), the global logger used by logrus.
func Logger(l *log.Logger) optSetter {
return func(b *Buffer) error {
b.log = l
return nil
}
}
type optSetter func(b *Buffer) error
// CondSetter Conditional setter.
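The new `Logger` option repeats across the updated oxy middlewares: each instance can log through its own logrus logger instead of the package-global one. A usage sketch for the buffer handler (address and size limit are arbitrary):

```go
package main

import (
	"net/http"

	log "github.com/sirupsen/logrus"
	"github.com/vulcand/oxy/buffer"
)

func main() {
	l := log.New() // dedicated logrus instance instead of the global logger
	l.Level = log.DebugLevel

	upstream := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
		w.Write([]byte("ok"))
	})

	b, err := buffer.New(upstream,
		buffer.MaxRequestBodyBytes(1<<20), // cap request bodies at 1 MiB
		buffer.Logger(l),                  // the per-instance logger added above
	)
	if err != nil {
		log.Fatal(err)
	}
	log.Fatal(http.ListenAndServe("127.0.0.1:8081", b))
}
```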
@@ -154,7 +167,7 @@ func MaxRequestBodyBytes(m int64) optSetter {
}
}
// MaxRequestBody bytes sets the maximum request body to be stored in memory
// MemRequestBodyBytes bytes sets the maximum request body to be stored in memory
// buffer middleware will serialize the excess to disk.
func MemRequestBodyBytes(m int64) optSetter {
return func(b *Buffer) error {
@@ -196,8 +209,8 @@ func (b *Buffer) Wrap(next http.Handler) error {
}
func (b *Buffer) ServeHTTP(w http.ResponseWriter, req *http.Request) {
if log.GetLevel() >= log.DebugLevel {
logEntry := log.WithField("Request", utils.DumpHttpRequest(req))
if b.log.Level >= log.DebugLevel {
logEntry := b.log.WithField("Request", utils.DumpHttpRequest(req))
logEntry.Debug("vulcand/oxy/buffer: begin ServeHttp on request")
defer logEntry.Debug("vulcand/oxy/buffer: completed ServeHttp on request")
}
@@ -210,11 +223,11 @@ func (b *Buffer) ServeHTTP(w http.ResponseWriter, req *http.Request) {
// Read the body while keeping limits in mind. This reader controls the maximum bytes
// to read into memory and disk. This reader returns an error if the total request size exceeds the
// prefefined MaxSizeBytes. This can occur if we got chunked request, in this case ContentLength would be set to -1
// predefined MaxSizeBytes. This can occur if we got chunked request, in this case ContentLength would be set to -1
// and the reader would be unbounded bufio in the http.Server
body, err := multibuf.New(req.Body, multibuf.MaxBytes(b.maxRequestBodyBytes), multibuf.MemBytes(b.memRequestBodyBytes))
if err != nil || body == nil {
log.Errorf("vulcand/oxy/buffer: error when reading request body, err: %v", err)
b.log.Errorf("vulcand/oxy/buffer: error when reading request body, err: %v", err)
b.errHandler.ServeHTTP(w, req, err)
return
}
@@ -235,7 +248,7 @@ func (b *Buffer) ServeHTTP(w http.ResponseWriter, req *http.Request) {
// set without content length or using chunked TransferEncoding
totalSize, err := body.Size()
if err != nil {
log.Errorf("vulcand/oxy/buffer: failed to get request size, err: %v", err)
b.log.Errorf("vulcand/oxy/buffer: failed to get request size, err: %v", err)
b.errHandler.ServeHTTP(w, req, err)
return
}
@@ -251,7 +264,7 @@ func (b *Buffer) ServeHTTP(w http.ResponseWriter, req *http.Request) {
// We create a special writer that will limit the response size, buffer it to disk if necessary
writer, err := multibuf.NewWriterOnce(multibuf.MaxBytes(b.maxResponseBodyBytes), multibuf.MemBytes(b.memResponseBodyBytes))
if err != nil {
log.Errorf("vulcand/oxy/buffer: failed create response writer, err: %v", err)
b.log.Errorf("vulcand/oxy/buffer: failed create response writer, err: %v", err)
b.errHandler.ServeHTTP(w, req, err)
return
}
@@ -261,12 +274,13 @@ func (b *Buffer) ServeHTTP(w http.ResponseWriter, req *http.Request) {
header: make(http.Header),
buffer: writer,
responseWriter: w,
log: b.log,
}
defer bw.Close()
b.next.ServeHTTP(bw, outreq)
if bw.hijacked {
log.Debugf("vulcand/oxy/buffer: connection was hijacked downstream. Not taking any action in buffer.")
b.log.Debugf("vulcand/oxy/buffer: connection was hijacked downstream. Not taking any action in buffer.")
return
}
@@ -274,7 +288,7 @@ func (b *Buffer) ServeHTTP(w http.ResponseWriter, req *http.Request) {
if bw.expectBody(outreq) {
rdr, err := writer.Reader()
if err != nil {
log.Errorf("vulcand/oxy/buffer: failed to read response, err: %v", err)
b.log.Errorf("vulcand/oxy/buffer: failed to read response, err: %v", err)
b.errHandler.ServeHTTP(w, req, err)
return
}
@@ -292,17 +306,17 @@ func (b *Buffer) ServeHTTP(w http.ResponseWriter, req *http.Request) {
return
}
attempt += 1
attempt++
if body != nil {
if _, err := body.Seek(0, 0); err != nil {
log.Errorf("vulcand/oxy/buffer: failed to rewind response body, err: %v", err)
b.log.Errorf("vulcand/oxy/buffer: failed to rewind response body, err: %v", err)
b.errHandler.ServeHTTP(w, req, err)
return
}
}
outreq = b.copyRequest(req, body, totalSize)
log.Debugf("vulcand/oxy/buffer: retry Request(%v %v) attempt %v", req.Method, req.URL, attempt)
b.log.Debugf("vulcand/oxy/buffer: retry Request(%v %v) attempt %v", req.Method, req.URL, attempt)
}
}
@@ -339,6 +353,7 @@ type bufferWriter struct {
buffer multibuf.WriterOnce
responseWriter http.ResponseWriter
hijacked bool
log *log.Logger
}
// RFC2616 #4.4
@@ -376,16 +391,16 @@ func (b *bufferWriter) WriteHeader(code int) {
b.code = code
}
//CloseNotifier interface - this allows downstream connections to be terminated when the client terminates.
// CloseNotifier interface - this allows downstream connections to be terminated when the client terminates.
func (b *bufferWriter) CloseNotify() <-chan bool {
if cn, ok := b.responseWriter.(http.CloseNotifier); ok {
return cn.CloseNotify()
}
log.Warningf("Upstream ResponseWriter of type %v does not implement http.CloseNotifier. Returning dummy channel.", reflect.TypeOf(b.responseWriter))
b.log.Warningf("Upstream ResponseWriter of type %v does not implement http.CloseNotifier. Returning dummy channel.", reflect.TypeOf(b.responseWriter))
return make(<-chan bool)
}
//This allows connections to be hijacked for websockets for instance.
// Hijack This allows connections to be hijacked for websockets for instance.
func (b *bufferWriter) Hijack() (net.Conn, *bufio.ReadWriter, error) {
if hi, ok := b.responseWriter.(http.Hijacker); ok {
conn, rw, err := hi.Hijack()
@@ -394,12 +409,12 @@ func (b *bufferWriter) Hijack() (net.Conn, *bufio.ReadWriter, error) {
}
return conn, rw, err
}
log.Warningf("Upstream ResponseWriter of type %v does not implement http.Hijacker. Returning dummy channel.", reflect.TypeOf(b.responseWriter))
b.log.Warningf("Upstream ResponseWriter of type %v does not implement http.Hijacker. Returning dummy channel.", reflect.TypeOf(b.responseWriter))
return nil, nil, fmt.Errorf("The response writer that was wrapped in this proxy, does not implement http.Hijacker. It is of type: %v", reflect.TypeOf(b.responseWriter))
}
type SizeErrHandler struct {
}
// SizeErrHandler Size error handler
type SizeErrHandler struct{}
func (e *SizeErrHandler) ServeHTTP(w http.ResponseWriter, req *http.Request, err error) {
if _, ok := err.(*multibuf.MaxSizeReachedError); ok {


@@ -7,6 +7,7 @@ import (
"github.com/vulcand/predicate"
)
// IsValidExpression check if it's a valid expression
func IsValidExpression(expr string) bool {
_, err := parseExpression(expr)
return err == nil


@@ -3,7 +3,7 @@
// Vulcan circuit breaker watches the error condtion to match
// after which it activates the fallback scenario, e.g. returns the response code
// or redirects the request to another location
//
// Circuit breakers start in the Standby state first, observing responses and watching location metrics.
//
// Once the Circuit breaker condition is met, it enters the "Tripped" state, where it activates fallback scenario
@@ -31,9 +31,8 @@ import (
"sync"
"time"
log "github.com/sirupsen/logrus"
"github.com/mailgun/timetools"
log "github.com/sirupsen/logrus"
"github.com/vulcand/oxy/memmetrics"
"github.com/vulcand/oxy/utils"
)
@@ -63,6 +62,8 @@ type CircuitBreaker struct {
next http.Handler
clock timetools.TimeProvider
log *log.Logger
}
// New creates a new CircuitBreaker middleware
@@ -76,6 +77,7 @@ func New(next http.Handler, expression string, options ...CircuitBreakerOption)
fallbackDuration: defaultFallbackDuration,
recoveryDuration: defaultRecoveryDuration,
fallback: defaultFallback,
log: log.StandardLogger(),
}
for _, s := range options {
@@ -99,9 +101,19 @@ func New(next http.Handler, expression string, options ...CircuitBreakerOption)
return cb, nil
}
// Logger defines the logger the circuit breaker will use.
//
// It defaults to logrus.StandardLogger(), the global logger used by logrus.
func Logger(l *log.Logger) CircuitBreakerOption {
return func(c *CircuitBreaker) error {
c.log = l
return nil
}
}
func (c *CircuitBreaker) ServeHTTP(w http.ResponseWriter, req *http.Request) {
if log.GetLevel() >= log.DebugLevel {
logEntry := log.WithField("Request", utils.DumpHttpRequest(req))
if c.log.Level >= log.DebugLevel {
logEntry := c.log.WithField("Request", utils.DumpHttpRequest(req))
logEntry.Debug("vulcand/oxy/circuitbreaker: begin ServeHttp on request")
defer logEntry.Debug("vulcand/oxy/circuitbreaker: completed ServeHttp on request")
}
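The circuit breaker gets the same per-instance logger treatment; a usage sketch with one of oxy's standard trip expressions (address is arbitrary, package assumed to be the vendored `cbreaker`):

```go
package main

import (
	"net/http"

	log "github.com/sirupsen/logrus"
	"github.com/vulcand/oxy/cbreaker"
)

func main() {
	l := log.New()
	l.Level = log.DebugLevel

	upstream := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
		w.Write([]byte("ok"))
	})

	// Trip when more than half of recent requests end in network errors,
	// and log through the dedicated logger rather than the global one.
	cb, err := cbreaker.New(upstream, "NetworkErrorRatio() > 0.5", cbreaker.Logger(l))
	if err != nil {
		log.Fatal(err)
	}
	log.Fatal(http.ListenAndServe("127.0.0.1:8082", cb))
}
```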
@@ -112,6 +124,7 @@ func (c *CircuitBreaker) ServeHTTP(w http.ResponseWriter, req *http.Request) {
c.serve(w, req)
}
// Wrap sets the next handler to be called by circuit breaker handler.
func (c *CircuitBreaker) Wrap(next http.Handler) {
c.next = next
}
@@ -126,7 +139,7 @@ func (c *CircuitBreaker) activateFallback(w http.ResponseWriter, req *http.Reque
c.m.Lock()
defer c.m.Unlock()
log.Warnf("%v is in error state", c)
c.log.Warnf("%v is in error state", c)
switch c.state {
case stateStandby:
@@ -156,7 +169,7 @@ func (c *CircuitBreaker) activateFallback(w http.ResponseWriter, req *http.Reque
func (c *CircuitBreaker) serve(w http.ResponseWriter, req *http.Request) {
start := c.clock.UtcNow()
p := utils.NewSimpleProxyWriter(w)
p := utils.NewProxyWriterWithLogger(w, c.log)
c.next.ServeHTTP(p, req)
@@ -191,13 +204,13 @@ func (c *CircuitBreaker) exec(s SideEffect) {
}
go func() {
if err := s.Exec(); err != nil {
log.Errorf("%v side effect failure: %v", c, err)
c.log.Errorf("%v side effect failure: %v", c, err)
}
}()
}
func (c *CircuitBreaker) setState(new cbState, until time.Time) {
log.Debugf("%v setting state to %v, until %v", c, new, until)
c.log.Debugf("%v setting state to %v, until %v", c, new, until)
c.state = new
c.until = until
switch new {
@@ -230,7 +243,7 @@ func (c *CircuitBreaker) checkAndSet() {
c.lastCheck = c.clock.UtcNow().Add(c.checkPeriod)
if c.state == stateTripped {
log.Debugf("%v skip set tripped", c)
c.log.Debugf("%v skip set tripped", c)
return
}
@@ -244,7 +257,7 @@ func (c *CircuitBreaker) checkAndSet() {
func (c *CircuitBreaker) setRecovering() {
c.setState(stateRecovering, c.clock.UtcNow().Add(c.recoveryDuration))
c.rc = newRatioController(c.clock, c.recoveryDuration)
c.rc = newRatioController(c.clock, c.recoveryDuration, c.log)
}
// CircuitBreakerOption represents an option you can pass to New.
@@ -296,7 +309,7 @@ func OnTripped(s SideEffect) CircuitBreakerOption {
}
}
// OnTripped sets a SideEffect to run when entering the Standby state.
// OnStandby sets a SideEffect to run when entering the Standby state.
// Only one SideEffect can be set for this hook.
func OnStandby(s SideEffect) CircuitBreakerOption {
return func(c *CircuitBreaker) error {
@@ -346,8 +359,7 @@ const (
var defaultFallback = &fallback{}
type fallback struct {
}
type fallback struct{}
func (f *fallback) ServeHTTP(w http.ResponseWriter, req *http.Request) {
w.WriteHeader(http.StatusServiceUnavailable)


@@ -13,10 +13,12 @@ import (
"github.com/vulcand/oxy/utils"
)
// SideEffect a side effect
type SideEffect interface {
Exec() error
}
// Webhook Web hook
type Webhook struct {
URL string
Method string
@@ -25,11 +27,15 @@ type Webhook struct {
Body []byte
}
// WebhookSideEffect a web hook side effect
type WebhookSideEffect struct {
w Webhook
log *log.Logger
}
func NewWebhookSideEffect(w Webhook) (*WebhookSideEffect, error) {
// NewWebhookSideEffectsWithLogger creates a new WebhookSideEffect
func NewWebhookSideEffectsWithLogger(w Webhook, l *log.Logger) (*WebhookSideEffect, error) {
if w.Method == "" {
return nil, fmt.Errorf("Supply method")
}
@@ -38,7 +44,12 @@ func NewWebhookSideEffect(w Webhook) (*WebhookSideEffect, error) {
return nil, err
}
return &WebhookSideEffect{w: w}, nil
return &WebhookSideEffect{w: w, log: l}, nil
}
// NewWebhookSideEffect creates a new WebhookSideEffect
func NewWebhookSideEffect(w Webhook) (*WebhookSideEffect, error) {
return NewWebhookSideEffectsWithLogger(w, log.StandardLogger())
}
func (w *WebhookSideEffect) getBody() io.Reader {
@@ -51,6 +62,7 @@ func (w *WebhookSideEffect) getBody() io.Reader {
return nil
}
// Exec execute the side effect
func (w *WebhookSideEffect) Exec() error {
r, err := http.NewRequest(w.w.Method, w.w.URL, w.getBody())
if err != nil {
@@ -73,6 +85,6 @@ func (w *WebhookSideEffect) Exec() error {
if err != nil {
return err
}
log.Debugf("%v got response: (%s): %s", w, re.Status, string(body))
w.log.Debugf("%v got response: (%s): %s", w, re.Status, string(body))
return nil
}


@@ -10,26 +10,36 @@ import (
"github.com/vulcand/oxy/utils"
)
// Response response model
type Response struct {
StatusCode int
ContentType string
Body []byte
}
// ResponseFallback fallback response handler
type ResponseFallback struct {
r Response
log *log.Logger
}
func NewResponseFallback(r Response) (*ResponseFallback, error) {
// NewResponseFallbackWithLogger creates a new ResponseFallback
func NewResponseFallbackWithLogger(r Response, l *log.Logger) (*ResponseFallback, error) {
if r.StatusCode == 0 {
return nil, fmt.Errorf("response code should not be 0")
}
return &ResponseFallback{r: r}, nil
return &ResponseFallback{r: r, log: l}, nil
}
// NewResponseFallback creates a new ResponseFallback
func NewResponseFallback(r Response) (*ResponseFallback, error) {
return NewResponseFallbackWithLogger(r, log.StandardLogger())
}
func (f *ResponseFallback) ServeHTTP(w http.ResponseWriter, req *http.Request) {
if log.GetLevel() >= log.DebugLevel {
logEntry := log.WithField("Request", utils.DumpHttpRequest(req))
if f.log.Level >= log.DebugLevel {
logEntry := f.log.WithField("Request", utils.DumpHttpRequest(req))
logEntry.Debug("vulcand/oxy/fallback/response: begin ServeHttp on request")
defer logEntry.Debug("vulcand/oxy/fallback/response: completed ServeHttp on request")
}
@@ -45,27 +55,38 @@ func (f *ResponseFallback) ServeHTTP(w http.ResponseWriter, req *http.Request) {
}
}
// Redirect redirect model
type Redirect struct {
URL string
PreservePath bool
}
// RedirectFallback fallback redirect handler
type RedirectFallback struct {
u *url.URL
r Redirect
u *url.URL
log *log.Logger
}
func NewRedirectFallback(r Redirect) (*RedirectFallback, error) {
// NewRedirectFallbackWithLogger creates a new RedirectFallback
func NewRedirectFallbackWithLogger(r Redirect, l *log.Logger) (*RedirectFallback, error) {
u, err := url.ParseRequestURI(r.URL)
if err != nil {
return nil, err
}
return &RedirectFallback{u: u, r: r}, nil
return &RedirectFallback{r: r, u: u, log: l}, nil
}
// NewRedirectFallback creates a new RedirectFallback
func NewRedirectFallback(r Redirect) (*RedirectFallback, error) {
return NewRedirectFallbackWithLogger(r, log.StandardLogger())
}
func (f *RedirectFallback) ServeHTTP(w http.ResponseWriter, req *http.Request) {
if log.GetLevel() >= log.DebugLevel {
logEntry := log.WithField("Request", utils.DumpHttpRequest(req))
if f.log.Level >= log.DebugLevel {
logEntry := f.log.WithField("Request", utils.DumpHttpRequest(req))
logEntry.Debug("vulcand/oxy/fallback/redirect: begin ServeHttp on request")
defer logEntry.Debug("vulcand/oxy/fallback/redirect: completed ServeHttp on request")
}


@@ -4,7 +4,6 @@ import (
"fmt"
"time"
log "github.com/sirupsen/logrus"
"github.com/vulcand/predicate"
)
@@ -50,7 +49,7 @@ func latencyAtQuantile(quantile float64) toInt {
return func(c *CircuitBreaker) int {
h, err := c.metrics.LatencyHistogram()
if err != nil {
log.Errorf("Failed to get latency histogram, for %v error: %v", c, err)
c.log.Errorf("Failed to get latency histogram, for %v error: %v", c, err)
return 0
}
return int(h.LatencyAtQuantile(quantile) / time.Millisecond)


@@ -19,13 +19,17 @@ type ratioController struct {
tm timetools.TimeProvider
allowed int
denied int
log *log.Logger
}
func newRatioController(tm timetools.TimeProvider, rampUp time.Duration) *ratioController {
func newRatioController(tm timetools.TimeProvider, rampUp time.Duration, log *log.Logger) *ratioController {
return &ratioController{
duration: rampUp,
tm: tm,
start: tm.UtcNow(),
log: log,
}
}
@@ -34,17 +38,17 @@ func (r *ratioController) String() string {
}
func (r *ratioController) allowRequest() bool {
log.Debugf("%v", r)
r.log.Debugf("%v", r)
t := r.targetRatio()
// This condition answers the question - would we satisfy the target ratio if we allow this request?
e := r.computeRatio(r.allowed+1, r.denied)
if e < t {
r.allowed++
log.Debugf("%v allowed", r)
r.log.Debugf("%v allowed", r)
return true
}
r.denied++
log.Debugf("%v denied", r)
r.log.Debugf("%v denied", r)
return false
}


@@ -1,4 +1,4 @@
// package connlimit provides control over simultaneous connections coming from the same source
// Package connlimit provides control over simultaneous connections coming from the same source
package connlimit
import (
@@ -10,7 +10,7 @@ import (
"github.com/vulcand/oxy/utils"
)
// Limiter tracks concurrent connection per token
// ConnLimiter tracks concurrent connection per token
// and is capable of rejecting connections if they are failed
type ConnLimiter struct {
mutex *sync.Mutex
@@ -21,8 +21,10 @@ type ConnLimiter struct {
next http.Handler
errHandler utils.ErrorHandler
log *log.Logger
}
// New creates a new ConnLimiter
func New(next http.Handler, extract utils.SourceExtractor, maxConnections int64, options ...ConnLimitOption) (*ConnLimiter, error) {
if extract == nil {
return nil, fmt.Errorf("Extract function can not be nil")
@@ -33,6 +35,7 @@ func New(next http.Handler, extract utils.SourceExtractor, maxConnections int64,
maxConnections: maxConnections,
connections: make(map[string]int64),
next: next,
log: log.StandardLogger(),
}
for _, o := range options {
@@ -41,11 +44,24 @@ func New(next http.Handler, extract utils.SourceExtractor, maxConnections int64,
}
}
if cl.errHandler == nil {
cl.errHandler = defaultErrHandler
cl.errHandler = &ConnErrHandler{
log: cl.log,
}
}
return cl, nil
}
// Logger defines the logger the connection limiter will use.
//
// It defaults to logrus.StandardLogger(), the global logger used by logrus.
func Logger(l *log.Logger) ConnLimitOption {
return func(cl *ConnLimiter) error {
cl.log = l
return nil
}
}
// Wrap sets the next handler to be called by connexion limiter handler.
func (cl *ConnLimiter) Wrap(h http.Handler) {
cl.next = h
}
@@ -53,12 +69,12 @@ func (cl *ConnLimiter) Wrap(h http.Handler) {
func (cl *ConnLimiter) ServeHTTP(w http.ResponseWriter, r *http.Request) {
token, amount, err := cl.extract.Extract(r)
if err != nil {
log.Errorf("failed to extract source of the connection: %v", err)
cl.log.Errorf("failed to extract source of the connection: %v", err)
cl.errHandler.ServeHTTP(w, r, err)
return
}
if err := cl.acquire(token, amount); err != nil {
log.Debugf("limiting request source %s: %v", token, err)
cl.log.Debugf("limiting request source %s: %v", token, err)
cl.errHandler.ServeHTTP(w, r, err)
return
}
@@ -95,6 +111,7 @@ func (cl *ConnLimiter) release(token string, amount int64) {
}
}
// MaxConnError maximum connections reached error
type MaxConnError struct {
max int64
}
@@ -103,12 +120,14 @@ func (m *MaxConnError) Error() string {
return fmt.Sprintf("max connections reached: %d", m.max)
}
// ConnErrHandler connection limiter error handler
type ConnErrHandler struct {
log *log.Logger
}
func (e *ConnErrHandler) ServeHTTP(w http.ResponseWriter, req *http.Request, err error) {
if log.GetLevel() >= log.DebugLevel {
logEntry := log.WithField("Request", utils.DumpHttpRequest(req))
if e.log.Level >= log.DebugLevel {
logEntry := e.log.WithField("Request", utils.DumpHttpRequest(req))
logEntry.Debug("vulcand/oxy/connlimit: begin ServeHttp on request")
defer logEntry.Debug("vulcand/oxy/connlimit: completed ServeHttp on request")
}
@@ -121,6 +140,7 @@ func (e *ConnErrHandler) ServeHTTP(w http.ResponseWriter, req *http.Request, err
utils.DefaultHandler.ServeHTTP(w, req, err)
}
// ConnLimitOption connection limit option type
type ConnLimitOption func(l *ConnLimiter) error
// ErrorHandler sets error handler of the server
@@ -130,5 +150,3 @@ func ErrorHandler(h utils.ErrorHandler) ConnLimitOption {
return nil
}
}
var defaultErrHandler = &ConnErrHandler{}


@@ -1,12 +1,15 @@
// package forwarder implements http handler that forwards requests to remote server
// Package forward implements http handler that forwards requests to remote server
// and serves back the response
// websocket proxying support based on https://github.com/yhat/wsutil
package forward
import (
"bytes"
"crypto/tls"
"errors"
"fmt"
"io"
"net"
"net/http"
"net/http/httptest"
"net/http/httputil"
@@ -21,7 +24,7 @@ import (
"github.com/vulcand/oxy/utils"
)
// Oxy Logger interface of the internal
// OxyLogger interface of the internal
type OxyLogger interface {
log.FieldLogger
GetLevel() log.Level
@@ -42,8 +45,7 @@ type ReqRewriter interface {
type optSetter func(f *Forwarder) error
// PassHostHeader specifies if a client's Host header field should
// be delegated
// PassHostHeader specifies if a client's Host header field should be delegated
func PassHostHeader(b bool) optSetter {
return func(f *Forwarder) error {
f.httpForwarder.passHost = b
@@ -68,8 +70,7 @@ func Rewriter(r ReqRewriter) optSetter {
}
}
// PassHostHeader specifies if a client's Host header field should
// be delegated
// WebsocketTLSClientConfig define the websocker client TLS configuration
func WebsocketTLSClientConfig(tcc *tls.Config) optSetter {
return func(f *Forwarder) error {
f.httpForwarder.tlsClientConfig = tcc
@@ -120,6 +121,7 @@ func Logger(l log.FieldLogger) optSetter {
}
}
// StateListener defines a state listener for the HTTP forwarder
func StateListener(stateListener UrlForwardingStateListener) optSetter {
return func(f *Forwarder) error {
f.stateListener = stateListener
@@ -127,6 +129,15 @@ func StateListener(stateListener UrlForwardingStateListener) optSetter {
}
}
// WebsocketConnectionClosedHook defines a hook called when websocket connection is closed
func WebsocketConnectionClosedHook(hook func(req *http.Request, conn net.Conn)) optSetter {
return func(f *Forwarder) error {
f.httpForwarder.websocketConnectionClosedHook = hook
return nil
}
}
// ResponseModifier defines a response modifier for the HTTP forwarder
func ResponseModifier(responseModifier func(*http.Response) error) optSetter {
return func(f *Forwarder) error {
f.httpForwarder.modifyResponse = responseModifier
@@ -134,6 +145,7 @@ func ResponseModifier(responseModifier func(*http.Response) error) optSetter {
}
}
// StreamingFlushInterval defines a streaming flush interval for the HTTP forwarder
func StreamingFlushInterval(flushInterval time.Duration) optSetter {
return func(f *Forwarder) error {
f.httpForwarder.flushInterval = flushInterval
@@ -141,11 +153,13 @@ func StreamingFlushInterval(flushInterval time.Duration) optSetter {
}
}
// ErrorHandlingRoundTripper a error handling round tripper
type ErrorHandlingRoundTripper struct {
http.RoundTripper
errorHandler utils.ErrorHandler
}
// RoundTrip executes the round trip
func (rt ErrorHandlingRoundTripper) RoundTrip(req *http.Request) (*http.Response, error) {
res, err := rt.RoundTripper.RoundTrip(req)
if err != nil {
@@ -185,15 +199,19 @@ type httpForwarder struct {
log OxyLogger
bufferPool httputil.BufferPool
bufferPool httputil.BufferPool
websocketConnectionClosedHook func(req *http.Request, conn net.Conn)
}
const defaultFlushInterval = time.Duration(100) * time.Millisecond
// Connection states
const (
defaultFlushInterval = time.Duration(100) * time.Millisecond
StateConnected = iota
StateConnected = iota
StateDisconnected
)
// UrlForwardingStateListener URL forwarding state listener
type UrlForwardingStateListener func(*url.URL, int)
// New creates an instance of Forwarder based on the provided list of configuration options
@@ -293,11 +311,6 @@ func (f *httpForwarder) modifyRequest(outReq *http.Request, target *url.URL) {
outReq.URL.RawQuery = u.RawQuery
outReq.RequestURI = "" // Outgoing request should not have RequestURI
// Do not pass client Host header unless optsetter PassHostHeader is set.
if !f.passHost {
outReq.Host = target.Host
}
outReq.Proto = "HTTP/1.1"
outReq.ProtoMajor = 1
outReq.ProtoMinor = 1
@@ -305,6 +318,11 @@ func (f *httpForwarder) modifyRequest(outReq *http.Request, target *url.URL) {
if f.rewriter != nil {
f.rewriter.Rewrite(outReq)
}
// Do not pass client Host header unless optsetter PassHostHeader is set.
if !f.passHost {
outReq.Host = target.Host
}
}
// serveHTTP forwards websocket traffic
@@ -368,14 +386,40 @@ func (f *httpForwarder) serveWebSocket(w http.ResponseWriter, req *http.Request,
log.Errorf("vulcand/oxy/forward/websocket: Error while upgrading connection : %v", err)
return
}
defer underlyingConn.Close()
defer targetConn.Close()
defer func() {
underlyingConn.Close()
targetConn.Close()
if f.websocketConnectionClosedHook != nil {
f.websocketConnectionClosedHook(req, underlyingConn.UnderlyingConn())
}
}()
errClient := make(chan error, 1)
errBackend := make(chan error, 1)
replicateWebsocketConn := func(dst, src *websocket.Conn, errc chan error) {
forward := func(messageType int, reader io.Reader) error {
writer, err := dst.NextWriter(messageType)
if err != nil {
return err
}
_, err = io.Copy(writer, reader)
if err != nil {
return err
}
return writer.Close()
}
src.SetPingHandler(func(data string) error {
return forward(websocket.PingMessage, bytes.NewReader([]byte(data)))
})
src.SetPongHandler(func(data string) error {
return forward(websocket.PongMessage, bytes.NewReader([]byte(data)))
})
for {
msgType, msg, err := src.ReadMessage()
msgType, reader, err := src.NextReader()
if err != nil {
m := websocket.FormatCloseMessage(websocket.CloseNormalClosure, fmt.Sprintf("%v", err))
@@ -393,11 +437,11 @@ func (f *httpForwarder) serveWebSocket(w http.ResponseWriter, req *http.Request,
}
errc <- err
if m != nil {
dst.WriteMessage(websocket.CloseMessage, m)
forward(websocket.CloseMessage, bytes.NewReader([]byte(m)))
}
break
}
err = dst.WriteMessage(msgType, msg)
err = forward(msgType, reader)
if err != nil {
errc <- err
break
@@ -466,16 +510,6 @@ func (f *httpForwarder) serveHTTP(w http.ResponseWriter, inReq *http.Request, ct
defer logEntry.Debug("vulcand/oxy/forward/http: completed ServeHttp on request")
}
var pw utils.ProxyWriter
// Disable closeNotify when method GET for http pipelining
// Waiting for https://github.com/golang/go/issues/23921
if inReq.Method == http.MethodGet {
pw = utils.NewProxyWriterWithoutCloseNotify(w)
} else {
pw = utils.NewSimpleProxyWriter(w)
}
start := time.Now().UTC()
outReq := new(http.Request)
@@ -490,22 +524,28 @@ func (f *httpForwarder) serveHTTP(w http.ResponseWriter, inReq *http.Request, ct
ModifyResponse: f.modifyResponse,
BufferPool: f.bufferPool,
}
revproxy.ServeHTTP(pw, outReq)
if inReq.TLS != nil {
f.log.Debugf("vulcand/oxy/forward/http: Round trip: %v, code: %v, Length: %v, duration: %v tls:version: %x, tls:resume:%t, tls:csuite:%x, tls:server:%v",
inReq.URL, pw.StatusCode(), pw.GetLength(), time.Now().UTC().Sub(start),
inReq.TLS.Version,
inReq.TLS.DidResume,
inReq.TLS.CipherSuite,
inReq.TLS.ServerName)
if f.log.GetLevel() >= log.DebugLevel {
pw := utils.NewProxyWriter(w)
revproxy.ServeHTTP(pw, outReq)
if inReq.TLS != nil {
f.log.Debugf("vulcand/oxy/forward/http: Round trip: %v, code: %v, Length: %v, duration: %v tls:version: %x, tls:resume:%t, tls:csuite:%x, tls:server:%v",
inReq.URL, pw.StatusCode(), pw.GetLength(), time.Now().UTC().Sub(start),
inReq.TLS.Version,
inReq.TLS.DidResume,
inReq.TLS.CipherSuite,
inReq.TLS.ServerName)
} else {
f.log.Debugf("vulcand/oxy/forward/http: Round trip: %v, code: %v, Length: %v, duration: %v",
inReq.URL, pw.StatusCode(), pw.GetLength(), time.Now().UTC().Sub(start))
}
} else {
f.log.Debugf("vulcand/oxy/forward/http: Round trip: %v, code: %v, Length: %v, duration: %v",
inReq.URL, pw.StatusCode(), pw.GetLength(), time.Now().UTC().Sub(start))
revproxy.ServeHTTP(w, outReq)
}
}
// isWebsocketRequest determines if the specified HTTP request is a
// IsWebsocketRequest determines if the specified HTTP request is a
// websocket handshake request
func IsWebsocketRequest(req *http.Request) bool {
containsHeader := func(name, value string) bool {
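The forwarder options touched in this bump compose as in the sketch below; the backend URL and addresses are made up, and pointing `req.URL` at the target before calling the forwarder is oxy's usual pattern:

```go
package main

import (
	"log"
	"net"
	"net/http"
	"net/url"
	"time"

	"github.com/vulcand/oxy/forward"
)

func main() {
	fwd, err := forward.New(
		forward.PassHostHeader(true),
		forward.StreamingFlushInterval(100*time.Millisecond),
		forward.WebsocketConnectionClosedHook(func(req *http.Request, conn net.Conn) {
			log.Printf("websocket to %s closed", req.URL.Host)
		}),
	)
	if err != nil {
		log.Fatal(err)
	}

	target, _ := url.Parse("http://127.0.0.1:9000") // hypothetical backend
	proxy := http.HandlerFunc(func(w http.ResponseWriter, req *http.Request) {
		req.URL = target // the forwarder sends the request wherever req.URL points
		fwd.ServeHTTP(w, req)
	})
	log.Fatal(http.ListenAndServe("127.0.0.1:8084", proxy))
}
```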


@@ -1,5 +1,6 @@
package forward
// Headers
const (
XForwardedProto = "X-Forwarded-Proto"
XForwardedFor = "X-Forwarded-For"
@@ -22,7 +23,7 @@ const (
SecWebsocketAccept = "Sec-Websocket-Accept"
)
// Hop-by-hop headers. These are removed when sent to the backend.
// HopHeaders Hop-by-hop headers. These are removed when sent to the backend.
// http://www.w3.org/Protocols/rfc2616/rfc2616-sec13.html
// Copied from reverseproxy.go, too bad
var HopHeaders = []string{
@@ -36,6 +37,7 @@ var HopHeaders = []string{
Upgrade,
}
// WebsocketDialHeaders Websocket dial headers
var WebsocketDialHeaders = []string{
Upgrade,
Connection,
@@ -45,6 +47,7 @@ var WebsocketDialHeaders = []string{
SecWebsocketAccept,
}
// WebsocketUpgradeHeaders Websocket upgrade headers
var WebsocketUpgradeHeaders = []string{
Upgrade,
Connection,
@@ -52,6 +55,7 @@ var WebsocketUpgradeHeaders = []string{
SecWebsocketExtensions,
}
// XHeaders X-* headers
var XHeaders = []string{
XForwardedProto,
XForwardedFor,


@@ -8,7 +8,7 @@ import (
"github.com/vulcand/oxy/utils"
)
// Rewriter is responsible for removing hop-by-hop headers and setting forwarding headers
// HeaderRewriter is responsible for removing hop-by-hop headers and setting forwarding headers
type HeaderRewriter struct {
TrustForwardHeader bool
Hostname string
@@ -19,6 +19,7 @@ func ipv6fix(clientIP string) string {
return strings.Split(clientIP, "%")[0]
}
// Rewrite rewrite request headers
func (rw *HeaderRewriter) Rewrite(req *http.Request) {
if !rw.TrustForwardHeader {
utils.RemoveHeaders(req.Header, XHeaders...)
@@ -85,6 +86,10 @@ func forwardedPort(req *http.Request) string {
return port
}
if req.Header.Get(XForwardedProto) == "https" || req.Header.Get(XForwardedProto) == "wss" {
return "443"
}
if req.TLS != nil {
return "443"
}


@@ -6,7 +6,7 @@ import (
"time"
)
// SplitRatios provides simple anomaly detection for requests latencies.
// SplitLatencies provides simple anomaly detection for requests latencies.
// it splits values into good or bad category based on the threshold and the median value.
// If all values are not far from the median, it will return all values in 'good' set.
// Precision is the smallest value to consider, e.g. if set to millisecond, microseconds will be ignored.
@@ -23,10 +23,10 @@ func SplitLatencies(values []time.Duration, precision time.Duration) (good map[t
good, bad = make(map[time.Duration]bool), make(map[time.Duration]bool)
// Note that multiplier makes this function way less sensitive than ratios detector, this is to avoid noise.
vgood, vbad := SplitFloat64(2, 0, ratios)
for r, _ := range vgood {
for r := range vgood {
good[v2r[r]] = true
}
for r, _ := range vbad {
for r := range vbad {
bad[v2r[r]] = true
}
return good, bad


@@ -9,6 +9,7 @@ import (
type rcOptSetter func(*RollingCounter) error
// CounterClock defines a counter clock
func CounterClock(c timetools.TimeProvider) rcOptSetter {
return func(r *RollingCounter) error {
r.clock = c
@@ -16,7 +17,7 @@ func CounterClock(c timetools.TimeProvider) rcOptSetter {
}
}
// Calculates in memory failure rate of an endpoint using rolling window of a predefined size
// RollingCounter Calculates in memory failure rate of an endpoint using rolling window of a predefined size
type RollingCounter struct {
clock timetools.TimeProvider
resolution time.Duration
@@ -57,11 +58,13 @@ func NewCounter(buckets int, resolution time.Duration, options ...rcOptSetter) (
return rc, nil
}
// Append append a counter
func (c *RollingCounter) Append(o *RollingCounter) error {
c.Inc(int(o.Count()))
return nil
}
// Clone clone a counter
func (c *RollingCounter) Clone() *RollingCounter {
c.cleanup()
other := &RollingCounter{
@@ -75,6 +78,7 @@ func (c *RollingCounter) Clone() *RollingCounter {
return other
}
// Reset reset a counter
func (c *RollingCounter) Reset() {
c.lastBucket = -1
c.countedBuckets = 0
@@ -84,27 +88,33 @@ func (c *RollingCounter) Reset() {
}
}
// CountedBuckets gets counted buckets
func (c *RollingCounter) CountedBuckets() int {
return c.countedBuckets
}
// Count counts
func (c *RollingCounter) Count() int64 {
c.cleanup()
return c.sum()
}
// Resolution gets resolution
func (c *RollingCounter) Resolution() time.Duration {
return c.resolution
}
// Buckets gets buckets
func (c *RollingCounter) Buckets() int {
return len(c.values)
}
// WindowSize gets windows size
func (c *RollingCounter) WindowSize() time.Duration {
return time.Duration(len(c.values)) * c.resolution
}
// Inc increment counter
func (c *RollingCounter) Inc(v int) {
c.cleanup()
c.incBucketValue(v)


@@ -20,6 +20,7 @@ type HDRHistogram struct {
h *hdrhistogram.Histogram
}
// NewHDRHistogram creates a new HDRHistogram
func NewHDRHistogram(low, high int64, sigfigs int) (h *HDRHistogram, err error) {
defer func() {
if msg := recover(); msg != nil {
@@ -34,37 +35,42 @@ func NewHDRHistogram(low, high int64, sigfigs int) (h *HDRHistogram, err error)
}, nil
}
func (r *HDRHistogram) Export() *HDRHistogram {
var hist *hdrhistogram.Histogram = nil
if r.h != nil {
snapshot := r.h.Export()
// Export export a HDRHistogram
func (h *HDRHistogram) Export() *HDRHistogram {
var hist *hdrhistogram.Histogram
if h.h != nil {
snapshot := h.h.Export()
hist = hdrhistogram.Import(snapshot)
}
return &HDRHistogram{low: r.low, high: r.high, sigfigs: r.sigfigs, h: hist}
return &HDRHistogram{low: h.low, high: h.high, sigfigs: h.sigfigs, h: hist}
}
// Returns latency at quantile with microsecond precision
// LatencyAtQuantile sets latency at quantile with microsecond precision
func (h *HDRHistogram) LatencyAtQuantile(q float64) time.Duration {
return time.Duration(h.ValueAtQuantile(q)) * time.Microsecond
}
// Records latencies with microsecond precision
// RecordLatencies Records latencies with microsecond precision
func (h *HDRHistogram) RecordLatencies(d time.Duration, n int64) error {
return h.RecordValues(int64(d/time.Microsecond), n)
}
// Reset reset a HDRHistogram
func (h *HDRHistogram) Reset() {
h.h.Reset()
}
// ValueAtQuantile sets value at quantile
func (h *HDRHistogram) ValueAtQuantile(q float64) int64 {
return h.h.ValueAtQuantile(q)
}
// RecordValues sets record values
func (h *HDRHistogram) RecordValues(v, n int64) error {
return h.h.RecordValues(v, n)
}
// Merge merge a HDRHistogram
func (h *HDRHistogram) Merge(other *HDRHistogram) error {
if other == nil {
return fmt.Errorf("other is nil")
@@ -75,6 +81,7 @@ func (h *HDRHistogram) Merge(other *HDRHistogram) error {
type rhOptSetter func(r *RollingHDRHistogram) error
// RollingClock sets a clock
func RollingClock(clock timetools.TimeProvider) rhOptSetter {
return func(r *RollingHDRHistogram) error {
r.clock = clock
@@ -82,7 +89,7 @@ func RollingClock(clock timetools.TimeProvider) rhOptSetter {
}
}
// RollingHistogram holds multiple histograms and rotates every period.
// RollingHDRHistogram holds multiple histograms and rotates every period.
// It provides resulting histogram as a result of a call of 'Merged' function.
type RollingHDRHistogram struct {
idx int
@@ -96,6 +103,7 @@ type RollingHDRHistogram struct {
clock timetools.TimeProvider
}
// NewRollingHDRHistogram created a new RollingHDRHistogram
func NewRollingHDRHistogram(low, high int64, sigfigs int, period time.Duration, bucketCount int, options ...rhOptSetter) (*RollingHDRHistogram, error) {
rh := &RollingHDRHistogram{
bucketCount: bucketCount,
@@ -127,6 +135,7 @@ func NewRollingHDRHistogram(low, high int64, sigfigs int, period time.Duration,
return rh, nil
}
// Export export a RollingHDRHistogram
func (r *RollingHDRHistogram) Export() *RollingHDRHistogram {
export := &RollingHDRHistogram{}
export.idx = r.idx
@@ -147,6 +156,7 @@ func (r *RollingHDRHistogram) Export() *RollingHDRHistogram {
return export
}
// Append appends another RollingHDRHistogram
func (r *RollingHDRHistogram) Append(o *RollingHDRHistogram) error {
if r.bucketCount != o.bucketCount || r.period != o.period || r.low != o.low || r.high != o.high || r.sigfigs != o.sigfigs {
return fmt.Errorf("can't merge")
@@ -160,6 +170,7 @@ func (r *RollingHDRHistogram) Append(o *RollingHDRHistogram) error {
return nil
}
// Reset resets the RollingHDRHistogram
func (r *RollingHDRHistogram) Reset() {
r.idx = 0
r.lastRoll = r.clock.UtcNow()
@@ -173,6 +184,7 @@ func (r *RollingHDRHistogram) rotate() {
r.buckets[r.idx].Reset()
}
// Merged gets merged histogram
func (r *RollingHDRHistogram) Merged() (*HDRHistogram, error) {
m, err := NewHDRHistogram(r.low, r.high, r.sigfigs)
if err != nil {
@@ -194,10 +206,12 @@ func (r *RollingHDRHistogram) getHist() *HDRHistogram {
return r.buckets[r.idx]
}
// RecordLatencies records latencies
func (r *RollingHDRHistogram) RecordLatencies(v time.Duration, n int64) error {
return r.getHist().RecordLatencies(v, n)
}
// RecordValues records values
func (r *RollingHDRHistogram) RecordValues(v, n int64) error {
return r.getHist().RecordValues(v, n)
}
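
// A minimal usage sketch of the rolling histogram API shown above. It is an
// editorial illustration rather than vendored code, and it assumes this
// package lives at the import path "github.com/vulcand/oxy/memmetrics"; it
// relies only on the constructors and methods visible in the hunks above.
package main

import (
	"fmt"
	"time"

	"github.com/vulcand/oxy/memmetrics"
)

func main() {
	// 6 sub-histograms of 10s each, values between 1µs and 1h, 2 significant figures.
	h, err := memmetrics.NewRollingHDRHistogram(1, 3600000000, 2, 10*time.Second, 6)
	if err != nil {
		panic(err)
	}
	// Record a couple of observed latencies into the current bucket.
	if err := h.RecordLatencies(25*time.Millisecond, 1); err != nil {
		panic(err)
	}
	if err := h.RecordLatencies(40*time.Millisecond, 1); err != nil {
		panic(err)
	}
	// Merged combines all rotating buckets into a single HDRHistogram.
	m, err := h.Merged()
	if err != nil {
		panic(err)
	}
	fmt.Println(m.LatencyAtQuantile(99))
}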

View File

@@ -8,6 +8,7 @@ import (
type ratioOptSetter func(r *RatioCounter) error
// RatioClock sets a clock
func RatioClock(clock timetools.TimeProvider) ratioOptSetter {
return func(r *RatioCounter) error {
r.clock = clock
@@ -22,6 +23,7 @@ type RatioCounter struct {
b *RollingCounter
}
// NewRatioCounter creates a new RatioCounter
func NewRatioCounter(buckets int, resolution time.Duration, options ...ratioOptSetter) (*RatioCounter, error) {
rc := &RatioCounter{}
@@ -50,39 +52,48 @@ func NewRatioCounter(buckets int, resolution time.Duration, options ...ratioOptS
return rc, nil
}
// Reset resets the counter
func (r *RatioCounter) Reset() {
r.a.Reset()
r.b.Reset()
}
// IsReady returns true if the counter is ready
func (r *RatioCounter) IsReady() bool {
return r.a.countedBuckets+r.b.countedBuckets >= len(r.a.values)
}
// CountA gets count A
func (r *RatioCounter) CountA() int64 {
return r.a.Count()
}
// CountB gets count B
func (r *RatioCounter) CountB() int64 {
return r.b.Count()
}
// Resolution gets resolution
func (r *RatioCounter) Resolution() time.Duration {
return r.a.Resolution()
}
// Buckets gets buckets
func (r *RatioCounter) Buckets() int {
return r.a.Buckets()
}
// WindowSize gets the window size
func (r *RatioCounter) WindowSize() time.Duration {
return r.a.WindowSize()
}
// ProcessedCount gets processed count
func (r *RatioCounter) ProcessedCount() int64 {
return r.CountA() + r.CountB()
}
// Ratio gets ratio
func (r *RatioCounter) Ratio() float64 {
a := r.a.Count()
b := r.b.Count()
@@ -93,28 +104,34 @@ func (r *RatioCounter) Ratio() float64 {
return float64(a) / float64(a+b)
}
// IncA increments counter A
func (r *RatioCounter) IncA(v int) {
r.a.Inc(v)
}
// IncB increments counter B
func (r *RatioCounter) IncB(v int) {
r.b.Inc(v)
}
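
// An illustrative sketch (not vendored code) of the RatioCounter shown above,
// assuming the import path "github.com/vulcand/oxy/memmetrics". Failures are
// recorded with IncA and successes with IncB, so Ratio reports the failure rate.
package main

import (
	"fmt"
	"time"

	"github.com/vulcand/oxy/memmetrics"
)

func main() {
	// 10 buckets of 1 second each: a 10-second sliding window.
	rc, err := memmetrics.NewRatioCounter(10, time.Second)
	if err != nil {
		panic(err)
	}
	rc.IncA(1) // e.g. one failed request
	rc.IncB(3) // e.g. three successful requests

	// Ratio is a/(a+b): 1/4 = 0.25 here. IsReady reports whether enough
	// buckets have been observed to trust the ratio.
	fmt.Println(rc.Ratio(), rc.IsReady())
}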
// TestMeter is a test meter
type TestMeter struct {
Rate float64
NotReady bool
WindowSize time.Duration
}
// GetWindowSize gets the window size
func (tm *TestMeter) GetWindowSize() time.Duration {
return tm.WindowSize
}
// IsReady returns true if the meter is ready
func (tm *TestMeter) IsReady() bool {
return !tm.NotReady
}
// GetRate gets rate
func (tm *TestMeter) GetRate() float64 {
return tm.Rate
}

View File

@@ -29,10 +29,16 @@ type RTMetrics struct {
type rrOptSetter func(r *RTMetrics) error
// NewRTMetricsFn builder function type
type NewRTMetricsFn func() (*RTMetrics, error)
// NewCounterFn builder function type
type NewCounterFn func() (*RollingCounter, error)
// NewRollingHistogramFn builder function type
type NewRollingHistogramFn func() (*RollingHDRHistogram, error)
// RTCounter sets a builder function for Counter
func RTCounter(new NewCounterFn) rrOptSetter {
return func(r *RTMetrics) error {
r.newCounter = new
@@ -40,13 +46,15 @@ func RTCounter(new NewCounterFn) rrOptSetter {
}
}
func RTHistogram(new NewRollingHistogramFn) rrOptSetter {
// RTHistogram sets a builder function for RollingHistogram
func RTHistogram(fn NewRollingHistogramFn) rrOptSetter {
return func(r *RTMetrics) error {
r.newHist = new
r.newHist = fn
return nil
}
}
// RTClock sets a clock
func RTClock(clock timetools.TimeProvider) rrOptSetter {
return func(r *RTMetrics) error {
r.clock = clock
@@ -103,7 +111,7 @@ func NewRTMetrics(settings ...rrOptSetter) (*RTMetrics, error) {
return m, nil
}
// Returns a new RTMetrics which is a copy of the current one
// Export returns a new RTMetrics which is a copy of the current one
func (m *RTMetrics) Export() *RTMetrics {
m.statusCodesLock.RLock()
defer m.statusCodesLock.RUnlock()
@@ -130,11 +138,12 @@ func (m *RTMetrics) Export() *RTMetrics {
return export
}
// CounterWindowSize gets the total window size
func (m *RTMetrics) CounterWindowSize() time.Duration {
return m.total.WindowSize()
}
// GetNetworkErrorRatio calculates the amont of network errors such as time outs and dropped connection
// NetworkErrorRatio calculates the amount of network errors such as timeouts and dropped connections
// that occurred in the given time window compared to the total requests count.
func (m *RTMetrics) NetworkErrorRatio() float64 {
if m.total.Count() == 0 {
@@ -143,7 +152,7 @@ func (m *RTMetrics) NetworkErrorRatio() float64 {
return float64(m.netErrors.Count()) / float64(m.total.Count())
}
// GetResponseCodeRatio calculates ratio of count(startA to endA) / count(startB to endB)
// ResponseCodeRatio calculates ratio of count(startA to endA) / count(startB to endB)
func (m *RTMetrics) ResponseCodeRatio(startA, endA, startB, endB int) float64 {
a := int64(0)
b := int64(0)
@@ -163,6 +172,7 @@ func (m *RTMetrics) ResponseCodeRatio(startA, endA, startB, endB int) float64 {
return 0
}
// Append appends a metric
func (m *RTMetrics) Append(other *RTMetrics) error {
if m == other {
return errors.New("RTMetrics cannot append to self")
@@ -196,6 +206,7 @@ func (m *RTMetrics) Append(other *RTMetrics) error {
return m.histogram.Append(copied.histogram)
}
// Record records a metric
func (m *RTMetrics) Record(code int, duration time.Duration) {
m.total.Inc(1)
if code == http.StatusGatewayTimeout || code == http.StatusBadGateway {
@@ -205,17 +216,17 @@ func (m *RTMetrics) Record(code int, duration time.Duration) {
m.recordLatency(duration)
}
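
// An illustrative sketch (not vendored code) of the RTMetrics recorder above,
// assuming the import path "github.com/vulcand/oxy/memmetrics". It records a
// few responses and reads back the derived counts and ratios.
package main

import (
	"fmt"
	"net/http"
	"time"

	"github.com/vulcand/oxy/memmetrics"
)

func main() {
	m, err := memmetrics.NewRTMetrics()
	if err != nil {
		panic(err)
	}
	// 502/504 responses are also counted as network errors (see Record above).
	m.Record(http.StatusOK, 20*time.Millisecond)
	m.Record(http.StatusOK, 30*time.Millisecond)
	m.Record(http.StatusBadGateway, 100*time.Millisecond)

	fmt.Println(m.TotalCount())                          // 3
	fmt.Println(m.NetworkErrorRatio())                   // 1 error out of 3 requests
	fmt.Println(m.ResponseCodeRatio(500, 600, 200, 300)) // 5xx vs 2xx
}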
// GetTotalCount returns total count of processed requests collected.
// TotalCount returns total count of processed requests collected.
func (m *RTMetrics) TotalCount() int64 {
return m.total.Count()
}
// GetNetworkErrorCount returns total count of processed requests observed
// NetworkErrorCount returns the total count of network errors observed
func (m *RTMetrics) NetworkErrorCount() int64 {
return m.netErrors.Count()
}
// GetStatusCodesCounts returns map with counts of the response codes
// StatusCodesCounts returns map with counts of the response codes
func (m *RTMetrics) StatusCodesCounts() map[int]int64 {
sc := make(map[int]int64)
m.statusCodesLock.RLock()
@@ -228,13 +239,14 @@ func (m *RTMetrics) StatusCodesCounts() map[int]int64 {
return sc
}
// GetLatencyHistogram computes and returns resulting histogram with latencies observed.
// LatencyHistogram computes and returns resulting histogram with latencies observed.
func (m *RTMetrics) LatencyHistogram() (*HDRHistogram, error) {
m.histogramLock.Lock()
defer m.histogramLock.Unlock()
return m.histogram.Merged()
}
// Reset resets the metrics
func (m *RTMetrics) Reset() {
m.statusCodesLock.Lock()
defer m.statusCodesLock.Unlock()
@@ -284,7 +296,7 @@ const (
counterResolution = time.Second
histMin = 1
histMax = 3600000000 // 1 hour in microseconds
histSignificantFigures = 2 // signigicant figures (1% precision)
histSignificantFigures = 2 // significant figures (1% precision)
histBuckets = 6 // number of sub-histograms in a rolling histogram
histPeriod = 10 * time.Second // roll time
)

View File

@@ -7,6 +7,7 @@ import (
"github.com/mailgun/timetools"
)
// UndefinedDelay default delay
const UndefinedDelay = -1
// rate defines token bucket parameters.
@@ -20,7 +21,7 @@ func (r *rate) String() string {
return fmt.Sprintf("rate(%v/%v, burst=%v)", r.average, r.period, r.burst)
}
// Implements token bucket algorithm (http://en.wikipedia.org/wiki/Token_bucket)
// tokenBucket implements the token bucket algorithm (http://en.wikipedia.org/wiki/Token_bucket)
type tokenBucket struct {
// The time period controlled by the bucket in nanoseconds.
period time.Duration
@@ -63,7 +64,7 @@ func (tb *tokenBucket) consume(tokens int64) (time.Duration, error) {
tb.updateAvailableTokens()
tb.lastConsumed = 0
if tokens > tb.burst {
return UndefinedDelay, fmt.Errorf("Requested tokens larger than max tokens")
return UndefinedDelay, fmt.Errorf("requested tokens larger than max tokens")
}
if tb.availableTokens < tokens {
return tb.timeTillAvailable(tokens), nil
@@ -83,11 +84,11 @@ func (tb *tokenBucket) rollback() {
tb.lastConsumed = 0
}
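
// The tokenBucket type above is unexported, so the sketch below (an editorial
// illustration, not vendored code) only reproduces its refill arithmetic with
// the standard library: one token becomes available every period/average, and
// requests above the burst size are rejected outright.
package main

import (
	"fmt"
	"time"
)

func main() {
	// rate(100/1s, burst=200)
	period := time.Second
	average := int64(100)
	burst := int64(200)

	// timePerToken, as computed in update() below: 10ms per token.
	timePerToken := time.Duration(int64(period) / average)
	fmt.Println(timePerToken)

	// Asking for more tokens than the burst can never succeed, mirroring the
	// "requested tokens larger than max tokens" error in consume() above.
	fmt.Println(250 > burst)

	// With 100 tokens available, a request for 150 has to wait for the
	// missing 50 tokens to be refilled: 50 * 10ms = 500ms.
	available, requested := int64(100), int64(150)
	fmt.Println(time.Duration(requested-available) * timePerToken)
}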
// Update modifies `average` and `burst` fields of the token bucket according
// update modifies `average` and `burst` fields of the token bucket according
// to the provided `Rate`
func (tb *tokenBucket) update(rate *rate) error {
if rate.period != tb.period {
return fmt.Errorf("Period mismatch: %v != %v", tb.period, rate.period)
return fmt.Errorf("period mismatch: %v != %v", tb.period, rate.period)
}
tb.timePerToken = time.Duration(int64(tb.period) / rate.average)
tb.burst = rate.burst

View File

@@ -2,11 +2,11 @@ package ratelimit
import (
"fmt"
"sort"
"strings"
"time"
"github.com/mailgun/timetools"
"sort"
)
// TokenBucketSet represents a set of TokenBucket covering different time periods.
@@ -16,7 +16,7 @@ type TokenBucketSet struct {
clock timetools.TimeProvider
}
// newTokenBucketSet creates a `TokenBucketSet` from the specified `rates`.
// NewTokenBucketSet creates a `TokenBucketSet` from the specified `rates`.
func NewTokenBucketSet(rates *RateSet, clock timetools.TimeProvider) *TokenBucketSet {
tbs := new(TokenBucketSet)
tbs.clock = clock
@@ -54,9 +54,10 @@ func (tbs *TokenBucketSet) Update(rates *RateSet) {
}
}
// Consume consumes tokens
func (tbs *TokenBucketSet) Consume(tokens int64) (time.Duration, error) {
var maxDelay time.Duration = UndefinedDelay
var firstErr error = nil
var firstErr error
for _, tokenBucket := range tbs.buckets {
// We keep calling `Consume` even after a error is returned for one of
// buckets because that allows us to simplify the rollback procedure,
@@ -80,6 +81,7 @@ func (tbs *TokenBucketSet) Consume(tokens int64) (time.Duration, error) {
return maxDelay, firstErr
}
// GetMaxPeriod returns the max period
func (tbs *TokenBucketSet) GetMaxPeriod() time.Duration {
return tbs.maxPeriod
}

View File

@@ -1,4 +1,4 @@
// Tokenbucket based request rate limiter
// Package ratelimit provides a token bucket based request rate limiter
package ratelimit
import (
@@ -13,6 +13,7 @@ import (
"github.com/vulcand/oxy/utils"
)
// DefaultCapacity default capacity
const DefaultCapacity = 65536
// RateSet maintains a set of rates. It can contain only one rate per period at a time.
@@ -31,15 +32,15 @@ func NewRateSet() *RateSet {
// set then the new rate overrides the old one.
func (rs *RateSet) Add(period time.Duration, average int64, burst int64) error {
if period <= 0 {
return fmt.Errorf("Invalid period: %v", period)
return fmt.Errorf("invalid period: %v", period)
}
if average <= 0 {
return fmt.Errorf("Invalid average: %v", average)
return fmt.Errorf("invalid average: %v", average)
}
if burst <= 0 {
return fmt.Errorf("Invalid burst: %v", burst)
return fmt.Errorf("invalid burst: %v", burst)
}
rs.m[period] = &rate{period, average, burst}
rs.m[period] = &rate{period: period, average: average, burst: burst}
return nil
}
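
// An illustrative sketch (not vendored code) of the RateSet semantics above:
// one rate per period, and adding a rate for an existing period overrides the
// old one. It assumes the import path "github.com/vulcand/oxy/ratelimit".
package main

import (
	"fmt"
	"time"

	"github.com/vulcand/oxy/ratelimit"
)

func main() {
	rates := ratelimit.NewRateSet()

	// 100 requests per second on average, bursts of up to 200.
	if err := rates.Add(time.Second, 100, 200); err != nil {
		panic(err)
	}
	// A rate for a different period coexists with the first one.
	if err := rates.Add(time.Minute, 1000, 2000); err != nil {
		panic(err)
	}
	// Re-adding the 1s period replaces the earlier 100/200 rate.
	if err := rates.Add(time.Second, 50, 100); err != nil {
		panic(err)
	}
	fmt.Println(rates)
}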
@@ -47,12 +48,15 @@ func (rs *RateSet) String() string {
return fmt.Sprint(rs.m)
}
// RateExtractor rate extractor
type RateExtractor interface {
Extract(r *http.Request) (*RateSet, error)
}
// RateExtractorFunc rate extractor function type
type RateExtractorFunc func(r *http.Request) (*RateSet, error)
// Extract extracts from the request
func (e RateExtractorFunc) Extract(r *http.Request) (*RateSet, error) {
return e(r)
}
@@ -68,20 +72,24 @@ type TokenLimiter struct {
errHandler utils.ErrorHandler
capacity int
next http.Handler
log *log.Logger
}
// New constructs a `TokenLimiter` middleware instance.
func New(next http.Handler, extract utils.SourceExtractor, defaultRates *RateSet, opts ...TokenLimiterOption) (*TokenLimiter, error) {
if defaultRates == nil || len(defaultRates.m) == 0 {
return nil, fmt.Errorf("Provide default rates")
return nil, fmt.Errorf("provide default rates")
}
if extract == nil {
return nil, fmt.Errorf("Provide extract function")
return nil, fmt.Errorf("provide extract function")
}
tl := &TokenLimiter{
next: next,
defaultRates: defaultRates,
extract: extract,
log: log.StandardLogger(),
}
for _, o := range opts {
@@ -98,6 +106,17 @@ func New(next http.Handler, extract utils.SourceExtractor, defaultRates *RateSet
return tl, nil
}
// Logger defines the logger the token limiter will use.
//
// It defaults to logrus.StandardLogger(), the global logger used by logrus.
func Logger(l *log.Logger) TokenLimiterOption {
return func(tl *TokenLimiter) error {
tl.log = l
return nil
}
}
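
// An end-to-end sketch (editorial illustration, not vendored code) of the
// TokenLimiter with the new Logger option, assuming the import paths
// "github.com/vulcand/oxy/ratelimit", "github.com/vulcand/oxy/utils" and the
// logrus logger used by this package.
package main

import (
	"net/http"
	"time"

	"github.com/sirupsen/logrus"
	"github.com/vulcand/oxy/ratelimit"
	"github.com/vulcand/oxy/utils"
)

func main() {
	backend := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
		w.Write([]byte("hello"))
	})

	// Limit each client IP to 100 requests per second, bursting to 200.
	rates := ratelimit.NewRateSet()
	if err := rates.Add(time.Second, 100, 200); err != nil {
		panic(err)
	}
	extractor, err := utils.NewExtractor("client.ip")
	if err != nil {
		panic(err)
	}
	limiter, err := ratelimit.New(backend, extractor, rates,
		ratelimit.Logger(logrus.StandardLogger()))
	if err != nil {
		panic(err)
	}
	panic(http.ListenAndServe(":8080", limiter))
}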
// Wrap sets the next handler to be called by token limiter handler.
func (tl *TokenLimiter) Wrap(next http.Handler) {
tl.next = next
}
@@ -110,7 +129,7 @@ func (tl *TokenLimiter) ServeHTTP(w http.ResponseWriter, req *http.Request) {
}
if err := tl.consumeRates(req, source, amount); err != nil {
log.Warnf("limiting request %v %v, limit: %v", req.Method, req.URL, err)
tl.log.Warnf("limiting request %v %v, limit: %v", req.Method, req.URL, err)
tl.errHandler.ServeHTTP(w, req, err)
return
}
@@ -155,7 +174,7 @@ func (tl *TokenLimiter) resolveRates(req *http.Request) *RateSet {
rates, err := tl.extractRates.Extract(req)
if err != nil {
log.Errorf("Failed to retrieve rates: %v", err)
tl.log.Errorf("Failed to retrieve rates: %v", err)
return tl.defaultRates
}
@@ -167,6 +186,7 @@ func (tl *TokenLimiter) resolveRates(req *http.Request) *RateSet {
return rates
}
// MaxRateError max rate error
type MaxRateError struct {
delay time.Duration
}
@@ -175,19 +195,21 @@ func (m *MaxRateError) Error() string {
return fmt.Sprintf("max rate reached: retry-in %v", m.delay)
}
type RateErrHandler struct {
}
// RateErrHandler error handler
type RateErrHandler struct{}
func (e *RateErrHandler) ServeHTTP(w http.ResponseWriter, req *http.Request, err error) {
if rerr, ok := err.(*MaxRateError); ok {
w.Header().Set("Retry-After", fmt.Sprintf("%.0f", rerr.delay.Seconds()))
w.Header().Set("X-Retry-In", rerr.delay.String())
w.WriteHeader(429)
w.WriteHeader(http.StatusTooManyRequests)
w.Write([]byte(err.Error()))
return
}
utils.DefaultHandler.ServeHTTP(w, req, err)
}
// TokenLimiterOption token limiter option type
type TokenLimiterOption func(l *TokenLimiter) error
// ErrorHandler sets error handler of the server
@@ -198,6 +220,7 @@ func ErrorHandler(h utils.ErrorHandler) TokenLimiterOption {
}
}
// ExtractRates sets the rate extractor
func ExtractRates(e RateExtractor) TokenLimiterOption {
return func(cl *TokenLimiter) error {
cl.extractRates = e
@@ -205,6 +228,7 @@ func ExtractRates(e RateExtractor) TokenLimiterOption {
}
}
// Clock sets the clock
func Clock(clock timetools.TimeProvider) TokenLimiterOption {
return func(cl *TokenLimiter) error {
cl.clock = clock
@@ -212,6 +236,7 @@ func Clock(clock timetools.TimeProvider) TokenLimiterOption {
}
}
// Capacity sets the capacity
func Capacity(cap int) TokenLimiterOption {
return func(cl *TokenLimiter) error {
if cap <= 0 {

View File

@@ -2,4 +2,5 @@ package roundrobin
import "net/http"
// RequestRewriteListener is called when a request is rewritten
type RequestRewriteListener func(oldReq *http.Request, newReq *http.Request)

View File

@@ -16,13 +16,14 @@ import (
// RebalancerOption - functional option setter for rebalancer
type RebalancerOption func(*Rebalancer) error
// Meter measures server peformance and returns it's relative value via rating
// Meter measures server performance and returns its relative value via rating
type Meter interface {
Rating() float64
Record(int, time.Duration)
IsReady() bool
}
// NewMeterFn is the type of functions that create a new Meter
type NewMeterFn func() (Meter, error)
// Rebalancer increases weights on servers that perform better than others. It also rolls back to original weights
@@ -52,8 +53,11 @@ type Rebalancer struct {
stickySession *StickySession
requestRewriteListener RequestRewriteListener
log *log.Logger
}
// RebalancerClock sets a clock
func RebalancerClock(clock timetools.TimeProvider) RebalancerOption {
return func(r *Rebalancer) error {
r.clock = clock
@@ -61,6 +65,7 @@ func RebalancerClock(clock timetools.TimeProvider) RebalancerOption {
}
}
// RebalancerBackoff sets a backoff duration
func RebalancerBackoff(d time.Duration) RebalancerOption {
return func(r *Rebalancer) error {
r.backoffDuration = d
@@ -68,6 +73,7 @@ func RebalancerBackoff(d time.Duration) RebalancerOption {
}
}
// RebalancerMeter sets a Meter builder function
func RebalancerMeter(newMeter NewMeterFn) RebalancerOption {
return func(r *Rebalancer) error {
r.newMeter = newMeter
@@ -83,6 +89,7 @@ func RebalancerErrorHandler(h utils.ErrorHandler) RebalancerOption {
}
}
// RebalancerStickySession sets a sticky session
func RebalancerStickySession(stickySession *StickySession) RebalancerOption {
return func(r *Rebalancer) error {
r.stickySession = stickySession
@@ -90,7 +97,7 @@ func RebalancerStickySession(stickySession *StickySession) RebalancerOption {
}
}
// RebalancerErrorHandler is a functional argument that sets error handler of the server
// RebalancerRequestRewriteListener is a functional argument that sets the request rewrite listener of the rebalancer
func RebalancerRequestRewriteListener(rrl RequestRewriteListener) RebalancerOption {
return func(r *Rebalancer) error {
r.requestRewriteListener = rrl
@@ -98,11 +105,14 @@ func RebalancerRequestRewriteListener(rrl RequestRewriteListener) RebalancerOpti
}
}
// NewRebalancer creates a new Rebalancer
func NewRebalancer(handler balancerHandler, opts ...RebalancerOption) (*Rebalancer, error) {
rb := &Rebalancer{
mtx: &sync.Mutex{},
next: handler,
stickySession: nil,
log: log.StandardLogger(),
}
for _, o := range opts {
if err := o(rb); err != nil {
@@ -134,6 +144,17 @@ func NewRebalancer(handler balancerHandler, opts ...RebalancerOption) (*Rebalanc
return rb, nil
}
// RebalancerLogger defines the logger the rebalancer will use.
//
// It defaults to logrus.StandardLogger(), the global logger used by logrus.
func RebalancerLogger(l *log.Logger) RebalancerOption {
return func(rb *Rebalancer) error {
rb.log = l
return nil
}
}
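
// An illustrative sketch (not vendored code) showing how a Rebalancer wraps a
// round robin balancer. It assumes the import paths
// "github.com/vulcand/oxy/roundrobin" and "github.com/vulcand/oxy/forward",
// where forward.New() is assumed to build oxy's HTTP forwarding handler.
package main

import (
	"net/http"
	"net/url"
	"time"

	"github.com/sirupsen/logrus"
	"github.com/vulcand/oxy/forward"
	"github.com/vulcand/oxy/roundrobin"
)

func main() {
	// The forwarder proxies each request to the URL chosen by the balancer.
	fwd, err := forward.New()
	if err != nil {
		panic(err)
	}
	lb, err := roundrobin.New(fwd)
	if err != nil {
		panic(err)
	}
	// The rebalancer adjusts the round robin weights based on observed
	// performance and uses the configured logger for its debug output.
	rb, err := roundrobin.NewRebalancer(lb,
		roundrobin.RebalancerBackoff(10*time.Second),
		roundrobin.RebalancerLogger(logrus.StandardLogger()))
	if err != nil {
		panic(err)
	}
	a, _ := url.Parse("http://127.0.0.1:8081")
	b, _ := url.Parse("http://127.0.0.1:8082")
	if err := rb.UpsertServer(a); err != nil {
		panic(err)
	}
	if err := rb.UpsertServer(b); err != nil {
		panic(err)
	}
	panic(http.ListenAndServe(":8080", rb))
}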
// Servers gets all servers
func (rb *Rebalancer) Servers() []*url.URL {
rb.mtx.Lock()
defer rb.mtx.Unlock()
@@ -142,13 +163,13 @@ func (rb *Rebalancer) Servers() []*url.URL {
}
func (rb *Rebalancer) ServeHTTP(w http.ResponseWriter, req *http.Request) {
if log.GetLevel() >= log.DebugLevel {
logEntry := log.WithField("Request", utils.DumpHttpRequest(req))
if rb.log.Level >= log.DebugLevel {
logEntry := rb.log.WithField("Request", utils.DumpHttpRequest(req))
logEntry.Debug("vulcand/oxy/roundrobin/rebalancer: begin ServeHttp on request")
defer logEntry.Debug("vulcand/oxy/roundrobin/rebalancer: completed ServeHttp on request")
}
pw := utils.NewSimpleProxyWriter(w)
pw := utils.NewProxyWriter(w)
start := rb.clock.UtcNow()
// make shallow copy of request before changing anything to avoid side effects
@@ -169,25 +190,25 @@ func (rb *Rebalancer) ServeHTTP(w http.ResponseWriter, req *http.Request) {
}
if !stuck {
url, err := rb.next.NextServer()
fwdURL, err := rb.next.NextServer()
if err != nil {
rb.errHandler.ServeHTTP(w, req, err)
return
}
if log.GetLevel() >= log.DebugLevel {
//log which backend URL we're sending this request to
log.WithFields(log.Fields{"Request": utils.DumpHttpRequest(req), "ForwardURL": url}).Debugf("vulcand/oxy/roundrobin/rebalancer: Forwarding this request to URL")
// log which backend URL we're sending this request to
log.WithFields(log.Fields{"Request": utils.DumpHttpRequest(req), "ForwardURL": fwdURL}).Debugf("vulcand/oxy/roundrobin/rebalancer: Forwarding this request to URL")
}
if rb.stickySession != nil {
rb.stickySession.StickBackend(url, &w)
rb.stickySession.StickBackend(fwdURL, &w)
}
newReq.URL = url
newReq.URL = fwdURL
}
//Emit event to a listener if one exists
// Emit event to a listener if one exists
if rb.requestRewriteListener != nil {
rb.requestRewriteListener(req, &newReq)
}
@@ -215,6 +236,7 @@ func (rb *Rebalancer) reset() {
rb.ratings = make([]float64, len(rb.servers))
}
// Wrap sets the next handler to be called by rebalancer handler.
func (rb *Rebalancer) Wrap(next balancerHandler) error {
if rb.next != nil {
return fmt.Errorf("already bound to %T", rb.next)
@@ -223,6 +245,7 @@ func (rb *Rebalancer) Wrap(next balancerHandler) error {
return nil
}
// UpsertServer upserts a server
func (rb *Rebalancer) UpsertServer(u *url.URL, options ...ServerOption) error {
rb.mtx.Lock()
defer rb.mtx.Unlock()
@@ -239,6 +262,7 @@ func (rb *Rebalancer) UpsertServer(u *url.URL, options ...ServerOption) error {
return nil
}
// RemoveServer removes a server
func (rb *Rebalancer) RemoveServer(u *url.URL) error {
rb.mtx.Lock()
defer rb.mtx.Unlock()
@@ -289,7 +313,7 @@ func (rb *Rebalancer) findServer(u *url.URL) (*rbServer, int) {
return nil, -1
}
// Called on every load balancer ServeHTTP call, returns the suggested weights
// adjustWeights is called on every load balancer ServeHTTP call; it returns the suggested
// weights and can adjust them if needed.
func (rb *Rebalancer) adjustWeights() {
rb.mtx.Lock()
@@ -319,7 +343,7 @@ func (rb *Rebalancer) adjustWeights() {
func (rb *Rebalancer) applyWeights() {
for _, srv := range rb.servers {
log.Debugf("upsert server %v, weight %v", srv.url, srv.curWeight)
rb.log.Debugf("upsert server %v, weight %v", srv.url, srv.curWeight)
rb.next.UpsertServer(srv.url, Weight(srv.curWeight))
}
}
@@ -331,7 +355,7 @@ func (rb *Rebalancer) setMarkedWeights() bool {
if srv.good {
weight := increase(srv.curWeight)
if weight <= FSMMaxWeight {
log.Debugf("increasing weight of %v from %v to %v", srv.url, srv.curWeight, weight)
rb.log.Debugf("increasing weight of %v from %v to %v", srv.url, srv.curWeight, weight)
srv.curWeight = weight
changed = true
}
@@ -378,7 +402,7 @@ func (rb *Rebalancer) markServers() bool {
}
}
if len(g) != 0 && len(b) != 0 {
log.Debugf("bad: %v good: %v, ratings: %v", b, g, rb.ratings)
rb.log.Debugf("bad: %v good: %v, ratings: %v", b, g, rb.ratings)
}
return len(g) != 0 && len(b) != 0
}
@@ -433,9 +457,8 @@ func decrease(target, current int) int {
adjusted := current / FSMGrowFactor
if adjusted < target {
return target
} else {
return adjusted
}
return adjusted
}
// rebalancer server record that keeps track of the original weight supplied by user
@@ -448,9 +471,9 @@ type rbServer struct {
}
const (
// This is the maximum weight that handler will set for the server
// FSMMaxWeight is the maximum weight that handler will set for the server
FSMMaxWeight = 4096
// Multiplier for the server weight
// FSMGrowFactor is the multiplier for the server weight
FSMGrowFactor = 4
)
@@ -460,10 +483,12 @@ type codeMeter struct {
codeE int
}
// Rating gets ratio
func (n *codeMeter) Rating() float64 {
return n.r.Ratio()
}
// Record records a meter
func (n *codeMeter) Record(code int, d time.Duration) {
if code >= n.codeS && code < n.codeE {
n.r.IncA(1)
@@ -472,6 +497,7 @@ func (n *codeMeter) Record(code int, d time.Duration) {
}
}
// IsReady returns true if the counter is ready
func (n *codeMeter) IsReady() bool {
return n.r.IsReady()
}

View File

@@ -1,4 +1,4 @@
// package roundrobin implements dynamic weighted round robin load balancer http handler
// Package roundrobin implements dynamic weighted round robin load balancer http handler
package roundrobin
import (
@@ -30,6 +30,7 @@ func ErrorHandler(h utils.ErrorHandler) LBOption {
}
}
// EnableStickySession enables sticky sessions
func EnableStickySession(stickySession *StickySession) LBOption {
return func(s *RoundRobin) error {
s.stickySession = stickySession
@@ -37,7 +38,7 @@ func EnableStickySession(stickySession *StickySession) LBOption {
}
}
// ErrorHandler is a functional argument that sets error handler of the server
// RoundRobinRequestRewriteListener is a functional argument that sets the request rewrite listener of the load balancer
func RoundRobinRequestRewriteListener(rrl RequestRewriteListener) LBOption {
return func(s *RoundRobin) error {
s.requestRewriteListener = rrl
@@ -45,6 +46,7 @@ func RoundRobinRequestRewriteListener(rrl RequestRewriteListener) LBOption {
}
}
// RoundRobin implements dynamic weighted round robin load balancer http handler
type RoundRobin struct {
mutex *sync.Mutex
next http.Handler
@@ -55,8 +57,11 @@ type RoundRobin struct {
currentWeight int
stickySession *StickySession
requestRewriteListener RequestRewriteListener
log *log.Logger
}
// New creates a new RoundRobin
func New(next http.Handler, opts ...LBOption) (*RoundRobin, error) {
rr := &RoundRobin{
next: next,
@@ -64,6 +69,8 @@ func New(next http.Handler, opts ...LBOption) (*RoundRobin, error) {
mutex: &sync.Mutex{},
servers: []*server{},
stickySession: nil,
log: log.StandardLogger(),
}
for _, o := range opts {
if err := o(rr); err != nil {
@@ -76,13 +83,24 @@ func New(next http.Handler, opts ...LBOption) (*RoundRobin, error) {
return rr, nil
}
// RoundRobinLogger defines the logger the round robin load balancer will use.
//
// It defaults to logrus.StandardLogger(), the global logger used by logrus.
func RoundRobinLogger(l *log.Logger) LBOption {
return func(r *RoundRobin) error {
r.log = l
return nil
}
}
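
// An illustrative sketch (not vendored code) of the weighted round robin with
// the new RoundRobinLogger option. It assumes the import paths
// "github.com/vulcand/oxy/roundrobin" and "github.com/vulcand/oxy/forward",
// where forward.New() is assumed to build oxy's HTTP forwarding handler.
package main

import (
	"net/http"
	"net/url"

	"github.com/sirupsen/logrus"
	"github.com/vulcand/oxy/forward"
	"github.com/vulcand/oxy/roundrobin"
)

func main() {
	fwd, err := forward.New()
	if err != nil {
		panic(err)
	}
	lb, err := roundrobin.New(fwd, roundrobin.RoundRobinLogger(logrus.StandardLogger()))
	if err != nil {
		panic(err)
	}
	a, _ := url.Parse("http://127.0.0.1:8081")
	b, _ := url.Parse("http://127.0.0.1:8082")
	// Server a receives roughly three times as many requests as server b.
	if err := lb.UpsertServer(a, roundrobin.Weight(3)); err != nil {
		panic(err)
	}
	if err := lb.UpsertServer(b, roundrobin.Weight(1)); err != nil {
		panic(err)
	}
	panic(http.ListenAndServe(":8080", lb))
}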
// Next returns the next handler
func (r *RoundRobin) Next() http.Handler {
return r.next
}
func (r *RoundRobin) ServeHTTP(w http.ResponseWriter, req *http.Request) {
if log.GetLevel() >= log.DebugLevel {
logEntry := log.WithField("Request", utils.DumpHttpRequest(req))
if r.log.Level >= log.DebugLevel {
logEntry := r.log.WithField("Request", utils.DumpHttpRequest(req))
logEntry.Debug("vulcand/oxy/roundrobin/rr: begin ServeHttp on request")
defer logEntry.Debug("vulcand/oxy/roundrobin/rr: completed ServeHttp on request")
}
@@ -116,12 +134,12 @@ func (r *RoundRobin) ServeHTTP(w http.ResponseWriter, req *http.Request) {
newReq.URL = url
}
if log.GetLevel() >= log.DebugLevel {
//log which backend URL we're sending this request to
log.WithFields(log.Fields{"Request": utils.DumpHttpRequest(req), "ForwardURL": newReq.URL}).Debugf("vulcand/oxy/roundrobin/rr: Forwarding this request to URL")
if r.log.Level >= log.DebugLevel {
// log which backend URL we're sending this request to
r.log.WithFields(log.Fields{"Request": utils.DumpHttpRequest(req), "ForwardURL": newReq.URL}).Debugf("vulcand/oxy/roundrobin/rr: Forwarding this request to URL")
}
//Emit event to a listener if one exists
// Emit event to a listener if one exists
if r.requestRewriteListener != nil {
r.requestRewriteListener(req, &newReq)
}
@@ -129,6 +147,7 @@ func (r *RoundRobin) ServeHTTP(w http.ResponseWriter, req *http.Request) {
r.next.ServeHTTP(w, &newReq)
}
// NextServer gets the next server
func (r *RoundRobin) NextServer() (*url.URL, error) {
srv, err := r.nextServer()
if err != nil {
@@ -172,6 +191,7 @@ func (r *RoundRobin) nextServer() (*server, error) {
}
}
// RemoveServer removes a server
func (r *RoundRobin) RemoveServer(u *url.URL) error {
r.mutex.Lock()
defer r.mutex.Unlock()
@@ -185,6 +205,7 @@ func (r *RoundRobin) RemoveServer(u *url.URL) error {
return nil
}
// Servers gets the servers' URLs
func (r *RoundRobin) Servers() []*url.URL {
r.mutex.Lock()
defer r.mutex.Unlock()
@@ -196,6 +217,7 @@ func (r *RoundRobin) Servers() []*url.URL {
return out
}
// ServerWeight gets the server weight
func (r *RoundRobin) ServerWeight(u *url.URL) (int, bool) {
r.mutex.Lock()
defer r.mutex.Unlock()
@@ -206,7 +228,7 @@ func (r *RoundRobin) ServerWeight(u *url.URL) (int, bool) {
return -1, false
}
// In case if server is already present in the load balancer, returns error
// UpsertServer returns an error in case the server is already present in the load balancer
func (r *RoundRobin) UpsertServer(u *url.URL, options ...ServerOption) error {
r.mutex.Lock()
defer r.mutex.Unlock()
@@ -306,6 +328,7 @@ type server struct {
var defaultWeight = 1
// SetDefaultWeight sets the default server weight
func SetDefaultWeight(weight int) error {
if weight < 0 {
return fmt.Errorf("default weight should be >= 0")

View File

@@ -1,4 +1,3 @@
// package stickysession is a mixin for load balancers that implements layer 7 (http cookie) session affinity
package roundrobin
import (
@@ -6,12 +5,14 @@ import (
"net/url"
)
// StickySession is a mixin for load balancers that implements layer 7 (http cookie) session affinity
type StickySession struct {
cookieName string
}
// NewStickySession creates a new StickySession
func NewStickySession(cookieName string) *StickySession {
return &StickySession{cookieName}
return &StickySession{cookieName: cookieName}
}
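
// An illustrative sketch (not vendored code) of cookie-based session affinity:
// EnableStickySession makes the round robin pin each client to the backend
// stored in the named cookie. Import paths are assumed as in the sketches
// above, including forward.New() as the forwarding handler.
package main

import (
	"net/http"
	"net/url"

	"github.com/vulcand/oxy/forward"
	"github.com/vulcand/oxy/roundrobin"
)

func main() {
	fwd, err := forward.New()
	if err != nil {
		panic(err)
	}
	sticky := roundrobin.NewStickySession("oxy-backend")
	lb, err := roundrobin.New(fwd, roundrobin.EnableStickySession(sticky))
	if err != nil {
		panic(err)
	}
	srv, _ := url.Parse("http://127.0.0.1:8081")
	if err := lb.UpsertServer(srv); err != nil {
		panic(err)
	}
	panic(http.ListenAndServe(":8080", lb))
}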
// GetBackend returns the backend URL stored in the sticky cookie, iff the backend is still in the valid list of servers.
@@ -32,11 +33,11 @@ func (s *StickySession) GetBackend(req *http.Request, servers []*url.URL) (*url.
if s.isBackendAlive(serverURL, servers) {
return serverURL, true, nil
} else {
return nil, false, nil
}
return nil, false, nil
}
// StickBackend creates and sets the cookie
func (s *StickySession) StickBackend(backend *url.URL, w *http.ResponseWriter) {
cookie := &http.Cookie{Name: s.cookieName, Value: backend.String(), Path: "/"}
http.SetCookie(*w, cookie)

View File

@@ -6,6 +6,7 @@ import (
"strings"
)
// BasicAuth basic auth information
type BasicAuth struct {
Username string
Password string
@@ -16,6 +17,7 @@ func (ba *BasicAuth) String() string {
return fmt.Sprintf("Basic %s", encoded)
}
// ParseAuthHeader creates a new BasicAuth from header values
func ParseAuthHeader(header string) (*BasicAuth, error) {
values := strings.Fields(header)
if len(values) != 2 {

View File

@@ -9,6 +9,7 @@ import (
"net/url"
)
// SerializableHttpRequest serializable HTTP request
type SerializableHttpRequest struct {
Method string
URL *url.URL
@@ -28,6 +29,7 @@ type SerializableHttpRequest struct {
TLS *tls.ConnectionState
}
// Clone clones a request
func Clone(r *http.Request) *SerializableHttpRequest {
if r == nil {
return nil
@@ -47,14 +49,16 @@ func Clone(r *http.Request) *SerializableHttpRequest {
return rc
}
// ToJson serializes to JSON
func (s *SerializableHttpRequest) ToJson() string {
if jsonVal, err := json.Marshal(s); err != nil || jsonVal == nil {
return fmt.Sprintf("Error marshalling SerializableHttpRequest to json: %s", err.Error())
} else {
return string(jsonVal)
jsonVal, err := json.Marshal(s)
if err != nil || jsonVal == nil {
return fmt.Sprintf("Error marshalling SerializableHttpRequest to json: %s", err)
}
return string(jsonVal)
}
// DumpHttpRequest dumps an HTTP request to JSON
func DumpHttpRequest(req *http.Request) string {
return fmt.Sprintf("%v", Clone(req).ToJson())
return Clone(req).ToJson()
}
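
// An illustrative sketch (not vendored code) of the request serialization
// helpers above, assuming the import path "github.com/vulcand/oxy/utils".
package main

import (
	"fmt"
	"net/http"

	"github.com/vulcand/oxy/utils"
)

func main() {
	req, err := http.NewRequest(http.MethodGet, "http://example.com/ping", nil)
	if err != nil {
		panic(err)
	}
	req.Header.Set("X-Request-Id", "42")

	// Clone captures the request fields into a serializable struct;
	// DumpHttpRequest is the one-call JSON dump used in the debug logs above.
	fmt.Println(utils.Clone(req).ToJson())
	fmt.Println(utils.DumpHttpRequest(req))
}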

View File

@@ -1,22 +1,34 @@
package utils
import (
"context"
"io"
"net"
"net/http"
log "github.com/sirupsen/logrus"
)
// StatusClientClosedRequest non-standard HTTP status code for client disconnection
const StatusClientClosedRequest = 499
// StatusClientClosedRequestText non-standard HTTP status for client disconnection
const StatusClientClosedRequestText = "Client Closed Request"
// ErrorHandler error handler
type ErrorHandler interface {
ServeHTTP(w http.ResponseWriter, req *http.Request, err error)
}
// DefaultHandler default error handler
var DefaultHandler ErrorHandler = &StdHandler{}
type StdHandler struct {
}
// StdHandler is the standard error handler
type StdHandler struct{}
func (e *StdHandler) ServeHTTP(w http.ResponseWriter, req *http.Request, err error) {
statusCode := http.StatusInternalServerError
if e, ok := err.(net.Error); ok {
if e.Timeout() {
statusCode = http.StatusGatewayTimeout
@@ -25,11 +37,23 @@ func (e *StdHandler) ServeHTTP(w http.ResponseWriter, req *http.Request, err err
}
} else if err == io.EOF {
statusCode = http.StatusBadGateway
} else if err == context.Canceled {
statusCode = StatusClientClosedRequest
}
w.WriteHeader(statusCode)
w.Write([]byte(http.StatusText(statusCode)))
w.Write([]byte(statusText(statusCode)))
log.Debugf("'%d %s' caused by: %v", statusCode, statusText(statusCode), err)
}
func statusText(statusCode int) string {
if statusCode == StatusClientClosedRequest {
return StatusClientClosedRequestText
}
return http.StatusText(statusCode)
}
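
// An illustrative sketch (not vendored code) of how the default error handler
// maps proxy errors to status codes, including the new 499 mapping for
// context.Canceled. It assumes the import path "github.com/vulcand/oxy/utils".
package main

import (
	"context"
	"fmt"
	"io"
	"net/http"
	"net/http/httptest"

	"github.com/vulcand/oxy/utils"
)

func main() {
	req, _ := http.NewRequest(http.MethodGet, "http://example.com/", nil)

	for _, err := range []error{io.EOF, context.Canceled} {
		rec := httptest.NewRecorder()
		utils.DefaultHandler.ServeHTTP(rec, req, err)
		fmt.Println(err, "->", rec.Code) // io.EOF -> 502, context.Canceled -> 499
	}
}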
// ErrorHandlerFunc error handler function type
type ErrorHandlerFunc func(http.ResponseWriter, *http.Request, error)
// ServeHTTP calls f(w, r).

View File

@@ -12,91 +12,86 @@ import (
log "github.com/sirupsen/logrus"
)
type ProxyWriter interface {
http.ResponseWriter
GetLength() int64
StatusCode() int
GetWriter() http.ResponseWriter
// ProxyWriter is a response writer recorder that captures the status code and content length, used for debug logs
type ProxyWriter struct {
w http.ResponseWriter
code int
length int64
log *log.Logger
}
// ProxyWriterWithoutCloseNotify helps to capture response headers and status code
// from the ServeHTTP. It can be safely passed to ServeHTTP handler,
// wrapping the real response writer.
type ProxyWriterWithoutCloseNotify struct {
W http.ResponseWriter
Code int
Length int64
// NewProxyWriter creates a new ProxyWriter
func NewProxyWriter(w http.ResponseWriter) *ProxyWriter {
return NewProxyWriterWithLogger(w, log.StandardLogger())
}
func NewProxyWriterWithoutCloseNotify(writer http.ResponseWriter) *ProxyWriterWithoutCloseNotify {
return &ProxyWriterWithoutCloseNotify{
W: writer,
// NewProxyWriterWithLogger creates a new ProxyWriter
func NewProxyWriterWithLogger(w http.ResponseWriter, l *log.Logger) *ProxyWriter {
return &ProxyWriter{
w: w,
log: l,
}
}
func NewSimpleProxyWriter(writer http.ResponseWriter) *SimpleProxyWriter {
return &SimpleProxyWriter{
ProxyWriterWithoutCloseNotify: NewProxyWriterWithoutCloseNotify(writer),
}
}
type SimpleProxyWriter struct {
*ProxyWriterWithoutCloseNotify
}
func (p *ProxyWriterWithoutCloseNotify) GetWriter() http.ResponseWriter {
return p.W
}
func (p *ProxyWriterWithoutCloseNotify) StatusCode() int {
if p.Code == 0 {
// StatusCode gets status code
func (p *ProxyWriter) StatusCode() int {
if p.code == 0 {
// per contract standard lib will set this to http.StatusOK if not set
// by user, here we avoid the confusion by mirroring this logic
return http.StatusOK
}
return p.Code
return p.code
}
func (p *ProxyWriterWithoutCloseNotify) Header() http.Header {
return p.W.Header()
// GetLength gets content length
func (p *ProxyWriter) GetLength() int64 {
return p.length
}
func (p *ProxyWriterWithoutCloseNotify) Write(buf []byte) (int, error) {
p.Length = p.Length + int64(len(buf))
return p.W.Write(buf)
// Header gets response header
func (p *ProxyWriter) Header() http.Header {
return p.w.Header()
}
func (p *ProxyWriterWithoutCloseNotify) WriteHeader(code int) {
p.Code = code
p.W.WriteHeader(code)
func (p *ProxyWriter) Write(buf []byte) (int, error) {
p.length = p.length + int64(len(buf))
return p.w.Write(buf)
}
func (p *ProxyWriterWithoutCloseNotify) Flush() {
if f, ok := p.W.(http.Flusher); ok {
// WriteHeader writes status code
func (p *ProxyWriter) WriteHeader(code int) {
p.code = code
p.w.WriteHeader(code)
}
// Flush flushes the writer
func (p *ProxyWriter) Flush() {
if f, ok := p.w.(http.Flusher); ok {
f.Flush()
}
}
func (p *ProxyWriterWithoutCloseNotify) GetLength() int64 {
return p.Length
}
func (p *SimpleProxyWriter) CloseNotify() <-chan bool {
if cn, ok := p.GetWriter().(http.CloseNotifier); ok {
// CloseNotify returns a channel that receives at most a single value (true)
// when the client connection has gone away.
func (p *ProxyWriter) CloseNotify() <-chan bool {
if cn, ok := p.w.(http.CloseNotifier); ok {
return cn.CloseNotify()
}
log.Warningf("Upstream ResponseWriter of type %v does not implement http.CloseNotifier. Returning dummy channel.", reflect.TypeOf(p.GetWriter()))
p.log.Debugf("Upstream ResponseWriter of type %v does not implement http.CloseNotifier. Returning dummy channel.", reflect.TypeOf(p.w))
return make(<-chan bool)
}
func (p *ProxyWriterWithoutCloseNotify) Hijack() (net.Conn, *bufio.ReadWriter, error) {
if hi, ok := p.W.(http.Hijacker); ok {
// Hijack lets the caller take over the connection.
func (p *ProxyWriter) Hijack() (net.Conn, *bufio.ReadWriter, error) {
if hi, ok := p.w.(http.Hijacker); ok {
return hi.Hijack()
}
log.Warningf("Upstream ResponseWriter of type %v does not implement http.Hijacker. Returning dummy channel.", reflect.TypeOf(p.W))
return nil, nil, fmt.Errorf("The response writer that was wrapped in this proxy, does not implement http.Hijacker. It is of type: %v", reflect.TypeOf(p.W))
p.log.Debugf("Upstream ResponseWriter of type %v does not implement http.Hijacker. Returning dummy channel.", reflect.TypeOf(p.w))
return nil, nil, fmt.Errorf("the response writer that was wrapped in this proxy, does not implement http.Hijacker. It is of type: %v", reflect.TypeOf(p.w))
}
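
// An illustrative sketch (not vendored code) of the ProxyWriter recorder: it
// wraps a ResponseWriter and records the status code and body length written
// by the wrapped handler. It assumes the import path "github.com/vulcand/oxy/utils".
package main

import (
	"fmt"
	"net/http"
	"net/http/httptest"

	"github.com/vulcand/oxy/utils"
)

func main() {
	handler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
		w.WriteHeader(http.StatusTeapot)
		w.Write([]byte("short and stout"))
	})

	req := httptest.NewRequest(http.MethodGet, "http://example.com/", nil)
	pw := utils.NewProxyWriter(httptest.NewRecorder())
	handler.ServeHTTP(pw, req)

	fmt.Println(pw.StatusCode(), pw.GetLength()) // 418 15
}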
// NewBufferWriter creates a new BufferWriter
func NewBufferWriter(w io.WriteCloser) *BufferWriter {
return &BufferWriter{
W: w,
@@ -104,16 +99,19 @@ func NewBufferWriter(w io.WriteCloser) *BufferWriter {
}
}
// BufferWriter buffer writer
type BufferWriter struct {
H http.Header
Code int
W io.WriteCloser
}
// Close closes the writer
func (b *BufferWriter) Close() error {
return b.W.Close()
}
// Header gets response header
func (b *BufferWriter) Header() http.Header {
return b.H
}
@@ -122,11 +120,13 @@ func (b *BufferWriter) Write(buf []byte) (int, error) {
return b.W.Write(buf)
}
// WriteHeader sets rw.Code.
// WriteHeader writes status code
func (b *BufferWriter) WriteHeader(code int) {
b.Code = code
}
// CloseNotify returns a channel that receives at most a single value (true)
// when the client connection has gone away.
func (b *BufferWriter) CloseNotify() <-chan bool {
if cn, ok := b.W.(http.CloseNotifier); ok {
return cn.CloseNotify()
@@ -135,12 +135,13 @@ func (b *BufferWriter) CloseNotify() <-chan bool {
return make(<-chan bool)
}
// Hijack lets the caller take over the connection.
func (b *BufferWriter) Hijack() (net.Conn, *bufio.ReadWriter, error) {
if hi, ok := b.W.(http.Hijacker); ok {
return hi.Hijack()
}
log.Warningf("Upstream ResponseWriter of type %v does not implement http.Hijacker. Returning dummy channel.", reflect.TypeOf(b.W))
return nil, nil, fmt.Errorf("The response writer that was wrapped in this proxy, does not implement http.Hijacker. It is of type: %v", reflect.TypeOf(b.W))
log.Debugf("Upstream ResponseWriter of type %v does not implement http.Hijacker. Returning dummy channel.", reflect.TypeOf(b.W))
return nil, nil, fmt.Errorf("the response writer that was wrapped in this proxy, does not implement http.Hijacker. It is of type: %v", reflect.TypeOf(b.W))
}
type nopWriteCloser struct {
@@ -149,10 +150,10 @@ type nopWriteCloser struct {
func (*nopWriteCloser) Close() error { return nil }
// NopCloser returns a WriteCloser with a no-op Close method wrapping
// NopWriteCloser returns a WriteCloser with a no-op Close method wrapping
// the provided Writer w.
func NopWriteCloser(w io.Writer) io.WriteCloser {
return &nopWriteCloser{w}
return &nopWriteCloser{Writer: w}
}
// CopyURL provides update safe copy by avoiding shallow copying User field

View File

@@ -6,21 +6,25 @@ import (
"strings"
)
// ExtractSource extracts the source from the request, e.g. that may be client ip, or particular header that
// SourceExtractor extracts the source from the request, e.g. the client IP or a particular header that
// identifies the source. amount stands for amount of connections the source consumes, usually 1 for connection limiters
// error should be returned when source can not be identified
type SourceExtractor interface {
Extract(req *http.Request) (token string, amount int64, err error)
}
// ExtractorFunc extractor function type
type ExtractorFunc func(req *http.Request) (token string, amount int64, err error)
// Extract extracts from the request
func (f ExtractorFunc) Extract(req *http.Request) (string, int64, error) {
return f(req)
}
// ExtractSource extract source function type
type ExtractSource func(req *http.Request)
// NewExtractor creates a new SourceExtractor
func NewExtractor(variable string) (SourceExtractor, error) {
if variable == "client.ip" {
return ExtractorFunc(extractClientIP), nil
@@ -31,17 +35,17 @@ func NewExtractor(variable string) (SourceExtractor, error) {
if strings.HasPrefix(variable, "request.header.") {
header := strings.TrimPrefix(variable, "request.header.")
if len(header) == 0 {
return nil, fmt.Errorf("Wrong header: %s", header)
return nil, fmt.Errorf("wrong header: %s", header)
}
return makeHeaderExtractor(header), nil
}
return nil, fmt.Errorf("Unsupported limiting variable: '%s'", variable)
return nil, fmt.Errorf("unsupported limiting variable: '%s'", variable)
}
func extractClientIP(req *http.Request) (string, int64, error) {
vals := strings.SplitN(req.RemoteAddr, ":", 2)
if len(vals[0]) == 0 {
return "", 0, fmt.Errorf("Failed to parse client IP: %v", req.RemoteAddr)
return "", 0, fmt.Errorf("failed to parse client IP: %v", req.RemoteAddr)
}
return vals[0], 1, nil
}
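
// An illustrative sketch (not vendored code) of the source extractors above,
// assuming the import path "github.com/vulcand/oxy/utils": "client.ip" keys on
// the request's RemoteAddr, "request.header.<Name>" on a header value.
package main

import (
	"fmt"
	"net/http"
	"net/http/httptest"

	"github.com/vulcand/oxy/utils"
)

func main() {
	req := httptest.NewRequest(http.MethodGet, "http://example.com/", nil)
	req.RemoteAddr = "203.0.113.7:51234"
	req.Header.Set("Authorization", "token-abc")

	byIP, err := utils.NewExtractor("client.ip")
	if err != nil {
		panic(err)
	}
	byHeader, err := utils.NewExtractor("request.header.Authorization")
	if err != nil {
		panic(err)
	}

	token, amount, _ := byIP.Extract(req)
	fmt.Println(token, amount) // 203.0.113.7 1

	token, amount, _ = byHeader.Extract(req)
	fmt.Println(token, amount) // keyed on the Authorization header value
}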