diff --git a/go.mod b/go.mod index b5ad254..006d517 100644 --- a/go.mod +++ b/go.mod @@ -8,8 +8,11 @@ require ( github.com/knadh/koanf/v2 v2.1.1 ) +require github.com/gorilla/securecookie v1.1.2 // indirect + require ( github.com/fsnotify/fsnotify v1.7.0 // indirect + github.com/go-task/slim-sprig/v3 v3.0.0 github.com/go-viper/mapstructure/v2 v2.0.0 // indirect github.com/knadh/koanf/maps v0.1.1 // indirect github.com/knadh/koanf/providers/basicflag v1.0.0 diff --git a/go.sum b/go.sum index 962c684..6a19494 100644 --- a/go.sum +++ b/go.sum @@ -2,8 +2,12 @@ github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= github.com/fsnotify/fsnotify v1.7.0 h1:8JEhPFa5W2WU7YfeZzPNqzMP6Lwt7L2715Ggo0nosvA= github.com/fsnotify/fsnotify v1.7.0/go.mod h1:40Bi/Hjc2AVfZrqy+aj+yEI+/bRxZnMJyTJwOpGvigM= +github.com/go-task/slim-sprig/v3 v3.0.0 h1:sUs3vkvUymDpBKi3qH1YSqBQk9+9D/8M2mN1vB6EwHI= +github.com/go-task/slim-sprig/v3 v3.0.0/go.mod h1:W848ghGpv3Qj3dhTPRyJypKRiqCdHZiAzKg9hl15HA8= github.com/go-viper/mapstructure/v2 v2.0.0 h1:dhn8MZ1gZ0mzeodTG3jt5Vj/o87xZKuNAprG2mQfMfc= github.com/go-viper/mapstructure/v2 v2.0.0/go.mod h1:oJDH3BJKyqBA2TXFhDsKDGDTlndYOZ6rGS0BRZIxGhM= +github.com/gorilla/securecookie v1.1.2 h1:YCIWL56dvtr73r6715mJs5ZvhtnY73hBvEF8kXD8ePA= +github.com/gorilla/securecookie v1.1.2/go.mod h1:NfCASbcHqRSY+3a8tlWJwsQap2VX5pwzwo4h3eOamfo= github.com/knadh/koanf/maps v0.1.1 h1:G5TjmUh2D7G2YWf5SQQqSiHRJEjaicvU0KpypqB3NIs= github.com/knadh/koanf/maps v0.1.1/go.mod h1:npD/QZY3V6ghQDdcQzl1W4ICNVTkohC8E73eI2xW4yI= github.com/knadh/koanf/parsers/yaml v0.1.0 h1:ZZ8/iGfRLvKSaMEECEBPM1HQslrZADk8fP1XFUxVI5w= diff --git a/html/html.go b/html/html.go index 717981a..6c1540a 100644 --- a/html/html.go +++ b/html/html.go @@ -6,19 +6,36 @@ import ( "io" "github.com/evan-buss/opds-proxy/opds" + sprig "github.com/go-task/slim-sprig/v3" ) //go:embed * var files embed.FS var ( - home = parse("home.html") - feed = parse("feed.html", "partials/search.html") + home = parse("home.html") + feed = parse("feed.html", "partials/search.html") + login = parse("login.html") ) func parse(file ...string) *template.Template { file = append(file, "layout.html") - return template.Must(template.New("layout.html").ParseFS(files, file...)) + return template.Must( + template.New("layout.html"). + Funcs(sprig.FuncMap()). + ParseFS(files, file...), + ) +} + +type LoginParams struct { + ReturnURL string +} + +func Login(w io.Writer, p LoginParams, partial string) error { + if partial == "" { + partial = "layout.html" + } + return login.ExecuteTemplate(w, partial, p) } type FeedParams struct { diff --git a/html/login.html b/html/login.html new file mode 100644 index 0000000..993c462 --- /dev/null +++ b/html/login.html @@ -0,0 +1,48 @@ +{{ define "content" }} +
+

{{index (urlParse .ReturnURL) "query" | trimPrefix "q=" }}

+

Log in to access this feed

+
+
+ + + +
+
+
+ + +{{ end }} \ No newline at end of file diff --git a/html/static/style.css b/html/static/style.css index e2337e9..7733f09 100644 --- a/html/static/style.css +++ b/html/static/style.css @@ -101,3 +101,24 @@ a { .nav-controls:last-child { padding-right: 1rem; } + +*, *::before, *::after { + box-sizing: border-box; +} +* { + margin: 0; +} +body { + line-height: 1.5; + -webkit-font-smoothing: antialiased; +} +img, picture, video, canvas, svg { + display: block; + max-width: 100%; +} +input, button, textarea, select { + font: inherit; +} +p, h1, h2, h3, h4, h5, h6 { + overflow-wrap: break-word; +} \ No newline at end of file diff --git a/server.go b/server.go index c53391c..83eb29e 100644 --- a/server.go +++ b/server.go @@ -16,13 +16,9 @@ import ( "github.com/evan-buss/opds-proxy/convert" "github.com/evan-buss/opds-proxy/html" "github.com/evan-buss/opds-proxy/opds" + "github.com/gorilla/securecookie" ) -type Server struct { - addr string - router *http.ServeMux -} - const ( MOBI_MIME = "application/x-mobipocket-ebook" EPUB_MIME = "application/epub+zip" @@ -35,10 +31,23 @@ var ( _ = mime.AddExtensionType(".mobi", MOBI_MIME) ) +type Server struct { + addr string + router *http.ServeMux +} + +type Credentials struct { + Username string + Password string +} + +var s = securecookie.New(securecookie.GenerateRandomKey(32), securecookie.GenerateRandomKey(32)) + func NewServer(config *config) *Server { router := http.NewServeMux() router.HandleFunc("GET /{$}", handleHome(config.Feeds)) router.HandleFunc("GET /feed", handleFeed("tmp/")) + router.HandleFunc("/auth", handleAuth()) router.Handle("GET /static/", http.FileServer(http.FS(html.StaticFiles()))) return &Server{ @@ -90,13 +99,18 @@ func handleFeed(outputDir string) http.HandlerFunc { queryURL = replaceSearchPlaceHolder(queryURL, searchTerm) } - resp, err := fetchFromUrl(queryURL) + resp, err := fetchFromUrl(queryURL, getCredentials(r)) if err != nil { handleError(r, w, "Failed to fetch", err) return } defer resp.Body.Close() + if resp.StatusCode == http.StatusUnauthorized { + http.Redirect(w, r, "/auth?return="+r.URL.String(), http.StatusFound) + return + } + contentType := resp.Header.Get("Content-Type") mimeType, _, err := mime.ParseMediaType(contentType) if err != nil { @@ -153,7 +167,82 @@ func handleFeed(outputDir string) http.HandlerFunc { } } -func fetchFromUrl(url string) (*http.Response, error) { +func handleAuth() http.HandlerFunc { + return func(w http.ResponseWriter, r *http.Request) { + returnUrl := r.URL.Query().Get("return") + if returnUrl == "" { + http.Error(w, "No return URL specified", http.StatusBadRequest) + return + } + + if r.Method == "GET" { + html.Login(w, html.LoginParams{ReturnURL: returnUrl}, partial(r)) + return + } + + if r.Method == "POST" { + username := r.FormValue("username") + password := r.FormValue("password") + + rUrl, err := url.Parse(returnUrl) + if err != nil { + http.Error(w, "Invalid return URL", http.StatusBadRequest) + } + domain, err := url.Parse(rUrl.Query().Get("q")) + if err != nil { + http.Error(w, "Invalid site", http.StatusBadRequest) + } + + value := map[string]Credentials{ + domain.Hostname(): {Username: username, Password: password}, + } + + encoded, err := s.Encode("auth-creds", value) + if err != nil { + handleError(r, w, "Failed to encode credentials", err) + return + } + cookie := &http.Cookie{ + Name: "auth-creds", + Value: encoded, + Path: "/", + // Kobo fails to set cookies with HttpOnly or Secure flags + Secure: false, + HttpOnly: false, + } + + http.SetCookie(w, cookie) + http.Redirect(w, r, returnUrl, http.StatusFound) + } + + http.Error(w, "Method not allowed", http.StatusMethodNotAllowed) + } +} + +func getCredentials(r *http.Request) *Credentials { + cookie, err := r.Cookie("auth-creds") + if err != nil { + return nil + } + + value := make(map[string]*Credentials) + if err = s.Decode("auth-creds", cookie.Value, &value); err != nil { + return nil + } + + if !r.URL.Query().Has("q") { + return nil + } + + feedUrl, err := url.Parse(r.URL.Query().Get("q")) + if err != nil { + return nil + } + + return value[feedUrl.Hostname()] +} + +func fetchFromUrl(url string, credentials *Credentials) (*http.Response, error) { client := &http.Client{ Timeout: 2 * time.Second, } @@ -162,7 +251,10 @@ func fetchFromUrl(url string) (*http.Response, error) { if err != nil { return nil, err } - req.SetBasicAuth("public", "evanbuss") + + if credentials != nil { + req.SetBasicAuth(credentials.Username, credentials.Password) + } return client.Do(req) }