forked from Ivasoft/opds-proxy
feat: request deduplication / debouncing
Kobo eReaders have a buggy browser that makes 2 requests for the same HTTP resource when you click a link. This change ensures that requests within a certain time frame from the same IP, for the same path / query params will only be executed a single time. We record the http request response and replay it for the second request. If we get 2 simultaneous requests, we use the sync/singleflight library to ensure only the first request is actually processed. The second waits for the shared result of the first. This probably adds latency since some requests are blocked while we determine if we already have a cache entry, but for a simple service like this I don't think it matters.
This commit is contained in:
1
go.mod
1
go.mod
@@ -30,6 +30,7 @@ require (
|
||||
github.com/mitchellh/copystructure v1.2.0 // indirect
|
||||
github.com/mitchellh/reflectwalk v1.0.2 // indirect
|
||||
github.com/tidwall/sjson v1.2.5
|
||||
golang.org/x/sync v0.8.0
|
||||
golang.org/x/sys v0.22.0 // indirect
|
||||
gopkg.in/yaml.v3 v3.0.1 // indirect
|
||||
)
|
||||
|
||||
2
go.sum
2
go.sum
@@ -42,6 +42,8 @@ github.com/tidwall/pretty v1.2.0 h1:RWIZEg2iJ8/g6fDDYzMpobmaoGh5OLl4AXtGUGPcqCs=
|
||||
github.com/tidwall/pretty v1.2.0/go.mod h1:ITEVvHYasfjBbM0u2Pg8T2nJnzm8xPwvNhhsoaGGjNU=
|
||||
github.com/tidwall/sjson v1.2.5 h1:kLy8mja+1c9jlljvWTlSazM7cKDRfJuR/bOJhcY5NcY=
|
||||
github.com/tidwall/sjson v1.2.5/go.mod h1:Fvgq9kS/6ociJEDnK0Fk1cpYF4FIW6ZF7LAe+6jwd28=
|
||||
golang.org/x/sync v0.8.0 h1:3NFvSEYkUoMifnESzZl15y791HH1qU2xm6eCJU5ZPXQ=
|
||||
golang.org/x/sync v0.8.0/go.mod h1:Czt+wKu1gCyEFDUtn0jG5QVvpJ6rzVqr5aXyt9drQfk=
|
||||
golang.org/x/sys v0.22.0 h1:RI27ohtqKCnwULzJLqkv897zojh5/DwS/ENaMzUOaWI=
|
||||
golang.org/x/sys v0.22.0/go.mod h1:/VUhepiaJMQUp4+oa/7Zr1D23ma6VTLIYjOOTFZPUcA=
|
||||
gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405 h1:yhCVgyC4o1eVCa2tZl7eS0r+SDo693bJlVdllGtEeKM=
|
||||
|
||||
71
internal/cache/cache.go
vendored
Normal file
71
internal/cache/cache.go
vendored
Normal file
@@ -0,0 +1,71 @@
|
||||
package cache
|
||||
|
||||
import (
|
||||
"sync"
|
||||
"time"
|
||||
)
|
||||
|
||||
type CacheEntry[T any] struct {
|
||||
timestamp time.Time
|
||||
Value *T
|
||||
}
|
||||
|
||||
type Cache[T any] struct {
|
||||
entries map[string]*CacheEntry[T]
|
||||
config CacheConfig
|
||||
mutex sync.Mutex
|
||||
}
|
||||
|
||||
type CacheConfig struct {
|
||||
TTL time.Duration
|
||||
CleanupInterval time.Duration
|
||||
}
|
||||
|
||||
func NewCache[T any](config CacheConfig) *Cache[T] {
|
||||
cache := &Cache[T]{
|
||||
entries: make(map[string]*CacheEntry[T]),
|
||||
config: config,
|
||||
}
|
||||
go cache.cleanupLoop()
|
||||
return cache
|
||||
}
|
||||
|
||||
func (c *Cache[T]) Set(key string, entry *T) {
|
||||
c.mutex.Lock()
|
||||
defer c.mutex.Unlock()
|
||||
c.entries[key] = &CacheEntry[T]{timestamp: time.Now(), Value: entry}
|
||||
}
|
||||
|
||||
func (c *Cache[T]) Get(key string) (*T, bool) {
|
||||
c.mutex.Lock()
|
||||
defer c.mutex.Unlock()
|
||||
entry, exists := c.entries[key]
|
||||
if !exists {
|
||||
return nil, false
|
||||
}
|
||||
|
||||
if exists && time.Since(entry.timestamp) > c.config.TTL {
|
||||
delete(c.entries, key)
|
||||
return nil, false
|
||||
}
|
||||
return entry.Value, exists
|
||||
}
|
||||
|
||||
func (c *Cache[T]) cleanupLoop() {
|
||||
ticker := time.NewTicker(c.config.CleanupInterval)
|
||||
defer ticker.Stop()
|
||||
|
||||
for range ticker.C {
|
||||
c.cleanEntries()
|
||||
}
|
||||
}
|
||||
|
||||
func (c *Cache[T]) cleanEntries() {
|
||||
c.mutex.Lock()
|
||||
defer c.mutex.Unlock()
|
||||
for key, entry := range c.entries {
|
||||
if time.Since(entry.timestamp) > c.config.TTL {
|
||||
delete(c.entries, key)
|
||||
}
|
||||
}
|
||||
}
|
||||
53
internal/debounce/debounce.go
Normal file
53
internal/debounce/debounce.go
Normal file
@@ -0,0 +1,53 @@
|
||||
package debounce
|
||||
|
||||
import (
|
||||
"crypto/md5"
|
||||
"encoding/hex"
|
||||
"net"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"strconv"
|
||||
"time"
|
||||
|
||||
"github.com/evan-buss/opds-proxy/internal/cache"
|
||||
"golang.org/x/sync/singleflight"
|
||||
)
|
||||
|
||||
func NewDebounceMiddleware(debounce time.Duration) func(next http.HandlerFunc) http.HandlerFunc {
|
||||
responseCache := cache.NewCache[httptest.ResponseRecorder](cache.CacheConfig{CleanupInterval: time.Second, TTL: debounce})
|
||||
singleflight := singleflight.Group{}
|
||||
|
||||
return func(next http.HandlerFunc) http.HandlerFunc {
|
||||
return func(w http.ResponseWriter, r *http.Request) {
|
||||
ip, _, _ := net.SplitHostPort(r.RemoteAddr)
|
||||
hash := md5.Sum([]byte(ip + r.URL.Path + r.URL.RawQuery))
|
||||
key := string(hex.EncodeToString(hash[:]))
|
||||
|
||||
if entry, exists := responseCache.Get(key); exists {
|
||||
w.Header().Set("X-Debounce", "true")
|
||||
writeResponse(entry, w)
|
||||
return
|
||||
}
|
||||
|
||||
rw, _, shared := singleflight.Do(key, func() (interface{}, error) {
|
||||
rw := httptest.NewRecorder()
|
||||
next(rw, r)
|
||||
return rw, nil
|
||||
})
|
||||
|
||||
recorder := rw.(*httptest.ResponseRecorder)
|
||||
responseCache.Set(key, recorder)
|
||||
|
||||
w.Header().Set("X-Shared", strconv.FormatBool(shared))
|
||||
writeResponse(recorder, w)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func writeResponse(rec *httptest.ResponseRecorder, w http.ResponseWriter) {
|
||||
for k, v := range rec.Header() {
|
||||
w.Header()[k] = v
|
||||
}
|
||||
w.WriteHeader(rec.Code)
|
||||
w.Write(rec.Body.Bytes())
|
||||
}
|
||||
106
internal/debounce/debounce_test.go
Normal file
106
internal/debounce/debounce_test.go
Normal file
@@ -0,0 +1,106 @@
|
||||
package debounce
|
||||
|
||||
import (
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"sync"
|
||||
"testing"
|
||||
"time"
|
||||
)
|
||||
|
||||
func TestDebounceMiddleware(t *testing.T) {
|
||||
setup := func() (http.Handler, *int) {
|
||||
// Mock handler that simulates a slow response
|
||||
handlerCallCount := 0
|
||||
mockHandler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
handlerCallCount++
|
||||
time.Sleep(100 * time.Millisecond) // Simulate some work
|
||||
w.WriteHeader(http.StatusOK)
|
||||
w.Write([]byte("OK"))
|
||||
})
|
||||
|
||||
middleware := NewDebounceMiddleware(500 * time.Millisecond)
|
||||
wrappedHandler := middleware(mockHandler)
|
||||
|
||||
return wrappedHandler, &handlerCallCount
|
||||
}
|
||||
|
||||
t.Run("Caching Behavior", func(t *testing.T) {
|
||||
wrappedHandler, handlerCallCount := setup()
|
||||
|
||||
// First request
|
||||
req1 := httptest.NewRequest("GET", "/test", nil)
|
||||
rec1 := httptest.NewRecorder()
|
||||
wrappedHandler.ServeHTTP(rec1, req1)
|
||||
|
||||
if *handlerCallCount != 1 {
|
||||
t.Errorf("Expected handler to be called once, got %d", handlerCallCount)
|
||||
}
|
||||
|
||||
// Second request within debounce period
|
||||
req2 := httptest.NewRequest("GET", "/test", nil)
|
||||
rec2 := httptest.NewRecorder()
|
||||
wrappedHandler.ServeHTTP(rec2, req2)
|
||||
|
||||
if *handlerCallCount != 1 {
|
||||
t.Errorf("Expected handler to still be called once, got %d", handlerCallCount)
|
||||
}
|
||||
|
||||
if rec2.Header().Get("X-Debounce") != "true" {
|
||||
t.Error("Expected second response to be debounced")
|
||||
}
|
||||
|
||||
// Wait for debounce period to expire
|
||||
time.Sleep(600 * time.Millisecond)
|
||||
|
||||
// Third request after debounce period
|
||||
req3 := httptest.NewRequest("GET", "/test", nil)
|
||||
rec3 := httptest.NewRecorder()
|
||||
wrappedHandler.ServeHTTP(rec3, req3)
|
||||
|
||||
if *handlerCallCount != 2 {
|
||||
t.Errorf("Expected handler to be called twice, got %d", handlerCallCount)
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("Singleflight Behavior", func(t *testing.T) {
|
||||
wrappedHandler, handlerCallCount := setup()
|
||||
|
||||
var wg sync.WaitGroup
|
||||
requestCount := 10
|
||||
|
||||
for i := 0; i < requestCount; i++ {
|
||||
wg.Add(1)
|
||||
go func() {
|
||||
defer wg.Done()
|
||||
req := httptest.NewRequest("GET", "/test", nil)
|
||||
rec := httptest.NewRecorder()
|
||||
wrappedHandler.ServeHTTP(rec, req)
|
||||
}()
|
||||
}
|
||||
|
||||
wg.Wait()
|
||||
|
||||
if *handlerCallCount != 1 {
|
||||
t.Errorf("Expected handler to be called once for concurrent requests, got %d", handlerCallCount)
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("Different Paths", func(t *testing.T) {
|
||||
wrappedHandler, handlerCallCount := setup()
|
||||
|
||||
// Request to path A
|
||||
reqA := httptest.NewRequest("GET", "/testA", nil)
|
||||
recA := httptest.NewRecorder()
|
||||
wrappedHandler.ServeHTTP(recA, reqA)
|
||||
|
||||
// Request to path B
|
||||
reqB := httptest.NewRequest("GET", "/testB", nil)
|
||||
recB := httptest.NewRecorder()
|
||||
wrappedHandler.ServeHTTP(recB, reqB)
|
||||
|
||||
if *handlerCallCount != 2 {
|
||||
t.Errorf("Expected handler to be called twice for different paths, got %d", handlerCallCount)
|
||||
}
|
||||
})
|
||||
}
|
||||
17
server.go
17
server.go
@@ -19,6 +19,7 @@ import (
|
||||
|
||||
"github.com/evan-buss/opds-proxy/convert"
|
||||
"github.com/evan-buss/opds-proxy/html"
|
||||
"github.com/evan-buss/opds-proxy/internal/debounce"
|
||||
"github.com/evan-buss/opds-proxy/opds"
|
||||
"github.com/google/uuid"
|
||||
"github.com/gorilla/securecookie"
|
||||
@@ -72,9 +73,15 @@ func NewServer(config *ProxyConfig) (*Server, error) {
|
||||
|
||||
s := securecookie.New(hashKey, blockKey)
|
||||
|
||||
// Kobo issues 2 requests for each clicked link. This middleware ensures
|
||||
// we only process the first request and provide the same response for the second.
|
||||
// This becomes more important when the requests aren't idempotent, such as triggering
|
||||
// a download.
|
||||
debounceMiddleware := debounce.NewDebounceMiddleware(time.Millisecond * 100)
|
||||
|
||||
router := http.NewServeMux()
|
||||
router.Handle("GET /{$}", requestMiddleware(handleHome(config.Feeds)))
|
||||
router.Handle("GET /feed", requestMiddleware(handleFeed("tmp/", config.Feeds, s)))
|
||||
router.Handle("GET /feed", requestMiddleware(debounceMiddleware(handleFeed("tmp/", config.Feeds, s))))
|
||||
router.Handle("/auth", requestMiddleware(handleAuth(s)))
|
||||
router.Handle("GET /static/", http.FileServer(http.FS(html.StaticFiles())))
|
||||
|
||||
@@ -120,7 +127,12 @@ func requestMiddleware(next http.Handler) http.Handler {
|
||||
r = r.WithContext(ctx)
|
||||
|
||||
next.ServeHTTP(w, r)
|
||||
log.Info("Request Completed", slog.String("duration", time.Since(start).String()))
|
||||
|
||||
log.Info("Request Completed",
|
||||
slog.String("duration", time.Since(start).String()),
|
||||
slog.Bool("debounce", w.Header().Get("X-Debounce") == "true"),
|
||||
slog.Bool("shared", w.Header().Get("X-Shared") == "true"),
|
||||
)
|
||||
})
|
||||
}
|
||||
|
||||
@@ -208,6 +220,7 @@ func handleFeed(outputDir string, feeds []FeedConfig, s *securecookie.SecureCook
|
||||
handleError(r, w, "Failed to render feed", err)
|
||||
return
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
mutex.Lock()
|
||||
|
||||
Reference in New Issue
Block a user