feat: secure cookie keys config

If keys aren't specified they will be generated. The downside
of this is that the keys will change every server restart
which will invalidate previously generated cookies.

If the keys are specified in the config.yml, they are
used between restarts and all previously generated
cookies remain valid.
This commit is contained in:
Evan Buss
2024-07-13 18:00:09 +00:00
parent 9094e780d0
commit 28d3a7d761
2 changed files with 61 additions and 10 deletions

41
main.go
View File

@@ -1,10 +1,13 @@
package main package main
import ( import (
"encoding/hex"
"flag" "flag"
"fmt"
"log" "log"
"os" "os"
"github.com/gorilla/securecookie"
"github.com/knadh/koanf/parsers/yaml" "github.com/knadh/koanf/parsers/yaml"
"github.com/knadh/koanf/providers/basicflag" "github.com/knadh/koanf/providers/basicflag"
"github.com/knadh/koanf/providers/file" "github.com/knadh/koanf/providers/file"
@@ -13,9 +16,15 @@ import (
type config struct { type config struct {
Port string `koanf:"port"` Port string `koanf:"port"`
Auth auth `koanf:"auth"`
Feeds []feedConfig `koanf:"feeds" ` Feeds []feedConfig `koanf:"feeds" `
} }
type auth struct {
HashKey string `koanf:"hash_key"`
BlockKey string `koanf:"block_key"`
}
type feedConfig struct { type feedConfig struct {
Name string `koanf:"name"` Name string `koanf:"name"`
Url string `koanf:"url"` Url string `koanf:"url"`
@@ -27,7 +36,13 @@ func main() {
fs := flag.NewFlagSet("", flag.ContinueOnError) fs := flag.NewFlagSet("", flag.ContinueOnError)
fs.String("port", "8080", "port to listen on") fs.String("port", "8080", "port to listen on")
configPath := fs.String("config", "config.yml", "config file to load") configPath := fs.String("config", "config.yml", "config file to load")
generateKeys := fs.Bool("generate-keys", false, "generate cookie signing keys and exit")
if err := fs.Parse(os.Args[1:]); err != nil { if err := fs.Parse(os.Args[1:]); err != nil {
log.Fatal(err)
}
if *generateKeys {
displayKeys()
os.Exit(0) os.Exit(0)
} }
@@ -49,6 +64,30 @@ func main() {
log.Fatal("No feeds defined in config") log.Fatal("No feeds defined in config")
} }
server := NewServer(&config) if config.Auth.HashKey == "" || config.Auth.BlockKey == "" {
log.Println("Generating new cookie signing credentials")
hashKey, blockKey := displayKeys()
config.Auth.HashKey = hashKey
config.Auth.BlockKey = blockKey
}
server, err := NewServer(&config)
if err != nil {
log.Fatal(err)
}
server.Serve() server.Serve()
} }
func displayKeys() (string, string) {
hashKey := hex.EncodeToString(securecookie.GenerateRandomKey(32))
blockKey := hex.EncodeToString(securecookie.GenerateRandomKey(32))
log.Println("Set these values in your config file to persist authentication between server restarts.")
fmt.Println("auth:")
fmt.Printf(" hash_key: %s\n", hashKey)
fmt.Printf(" block_key: %s\n", blockKey)
return hashKey, blockKey
}

View File

@@ -1,6 +1,7 @@
package main package main
import ( import (
"encoding/hex"
"fmt" "fmt"
"io" "io"
"log" "log"
@@ -34,6 +35,7 @@ var (
type Server struct { type Server struct {
addr string addr string
router *http.ServeMux router *http.ServeMux
s *securecookie.SecureCookie
} }
type Credentials struct { type Credentials struct {
@@ -41,19 +43,29 @@ type Credentials struct {
Password string Password string
} }
var s = securecookie.New(securecookie.GenerateRandomKey(32), securecookie.GenerateRandomKey(32)) func NewServer(config *config) (*Server, error) {
hashKey, err := hex.DecodeString(config.Auth.HashKey)
if err != nil {
return nil, err
}
blockKey, err := hex.DecodeString(config.Auth.BlockKey)
if err != nil {
return nil, err
}
s := securecookie.New(hashKey, blockKey)
func NewServer(config *config) *Server {
router := http.NewServeMux() router := http.NewServeMux()
router.HandleFunc("GET /{$}", handleHome(config.Feeds)) router.HandleFunc("GET /{$}", handleHome(config.Feeds))
router.HandleFunc("GET /feed", handleFeed("tmp/")) router.HandleFunc("GET /feed", handleFeed("tmp/", s))
router.HandleFunc("/auth", handleAuth()) router.HandleFunc("/auth", handleAuth(s))
router.Handle("GET /static/", http.FileServer(http.FS(html.StaticFiles()))) router.Handle("GET /static/", http.FileServer(http.FS(html.StaticFiles())))
return &Server{ return &Server{
addr: ":" + config.Port, addr: ":" + config.Port,
router: router, router: router,
} s: s,
}, nil
} }
func (s *Server) Serve() { func (s *Server) Serve() {
@@ -75,7 +87,7 @@ func handleHome(feeds []feedConfig) http.HandlerFunc {
} }
} }
func handleFeed(outputDir string) http.HandlerFunc { func handleFeed(outputDir string, s *securecookie.SecureCookie) http.HandlerFunc {
kepubConverter := &convert.KepubConverter{} kepubConverter := &convert.KepubConverter{}
mobiConverter := &convert.MobiConverter{} mobiConverter := &convert.MobiConverter{}
@@ -98,7 +110,7 @@ func handleFeed(outputDir string) http.HandlerFunc {
queryURL = replaceSearchPlaceHolder(queryURL, searchTerm) queryURL = replaceSearchPlaceHolder(queryURL, searchTerm)
} }
resp, err := fetchFromUrl(queryURL, getCredentials(r)) resp, err := fetchFromUrl(queryURL, getCredentials(r, s))
if err != nil { if err != nil {
handleError(r, w, "Failed to fetch", err) handleError(r, w, "Failed to fetch", err)
return return
@@ -166,7 +178,7 @@ func handleFeed(outputDir string) http.HandlerFunc {
} }
} }
func handleAuth() http.HandlerFunc { func handleAuth(s *securecookie.SecureCookie) http.HandlerFunc {
return func(w http.ResponseWriter, r *http.Request) { return func(w http.ResponseWriter, r *http.Request) {
returnUrl := r.URL.Query().Get("return") returnUrl := r.URL.Query().Get("return")
if returnUrl == "" { if returnUrl == "" {
@@ -219,7 +231,7 @@ func handleAuth() http.HandlerFunc {
} }
} }
func getCredentials(r *http.Request) *Credentials { func getCredentials(r *http.Request, s *securecookie.SecureCookie) *Credentials {
cookie, err := r.Cookie("auth-creds") cookie, err := r.Cookie("auth-creds")
if err != nil { if err != nil {
return nil return nil