fix csrf and simplify global utilities

This commit is contained in:
Sky Johnson 2025-08-09 13:58:48 -05:00
parent 56dca44815
commit 0534da09a1
7 changed files with 237 additions and 176 deletions

View File

@ -6,6 +6,9 @@ import (
"dk/internal/users" "dk/internal/users"
) )
// Manager is the global singleton instance
var Manager *AuthManager
type User struct { type User struct {
ID int ID int
Username string Username string
@ -24,6 +27,11 @@ func NewAuthManager(db *database.DB, sessionsFilePath string) *AuthManager {
} }
} }
// InitializeManager initializes the global Manager singleton
func InitializeManager(db *database.DB, sessionsFilePath string) {
Manager = NewAuthManager(db, sessionsFilePath)
}
func (am *AuthManager) Authenticate(usernameOrEmail, plainPassword string) (*User, error) { func (am *AuthManager) Authenticate(usernameOrEmail, plainPassword string) (*User, error) {
var user *users.User var user *users.User
var err error var err error

View File

@ -17,7 +17,7 @@ const (
) )
type Session struct { type Session struct {
ID string `json:"id"` ID string `json:"-"` // Exclude from JSON since it's stored as the map key
UserID int `json:"user_id"` UserID int `json:"user_id"`
Username string `json:"username"` Username string `json:"username"`
Email string `json:"email"` Email string `json:"email"`

View File

@ -73,9 +73,6 @@ func CSRF(authManager *auth.AuthManager, config ...CSRFConfig) router.Middleware
return return
} }
// CSRF validation passed, rotate token for security
csrf.RotateToken(ctx, authManager)
next(ctx, params) next(ctx, params)
} }
} }
@ -100,9 +97,6 @@ func RequireCSRF(authManager *auth.AuthManager, failureHandler ...func(router.Ct
return return
} }
// Rotate token after successful validation
csrf.RotateToken(ctx, authManager)
next(ctx, params) next(ctx, params)
} }
} }

View File

@ -10,40 +10,34 @@ import (
"dk/internal/password" "dk/internal/password"
"dk/internal/router" "dk/internal/router"
"dk/internal/template" "dk/internal/template"
"dk/internal/template/components"
"dk/internal/users" "dk/internal/users"
"github.com/valyala/fasthttp" "github.com/valyala/fasthttp"
) )
// RegisterAuthRoutes sets up authentication routes // RegisterAuthRoutes sets up authentication routes
func RegisterAuthRoutes(r *router.Router, authManager *auth.AuthManager, templateCache *template.Cache) { func RegisterAuthRoutes(r *router.Router) {
// Guest routes (redirect to dashboard if already authenticated) // Guest routes
guestGroup := r.Group("") guestGroup := r.Group("")
guestGroup.Use(middleware.RequireGuest("/")) guestGroup.Use(middleware.RequireGuest("/"))
guestGroup.Get("/login", showLogin(authManager, templateCache)) guestGroup.Get("/login", showLogin())
guestGroup.Post("/login", processLogin(authManager, templateCache)) guestGroup.Post("/login", processLogin())
guestGroup.Get("/register", showRegister(authManager, templateCache)) guestGroup.Get("/register", showRegister())
guestGroup.Post("/register", processRegister(authManager, templateCache)) guestGroup.Post("/register", processRegister())
// Authenticated routes // Authenticated routes
authGroup := r.Group("") authGroup := r.Group("")
authGroup.Use(middleware.RequireAuth("/login")) authGroup.Use(middleware.RequireAuth("/login"))
authGroup.Post("/logout", processLogout(authManager)) authGroup.Post("/logout", processLogout())
} }
// showLogin displays the login form // showLogin displays the login form
func showLogin(authManager *auth.AuthManager, templateCache *template.Cache) router.Handler { func showLogin() router.Handler {
return func(ctx router.Ctx, params []string) { return func(ctx router.Ctx, params []string) {
layoutTmpl, err := templateCache.Load("layout.html") loginTmpl, err := template.Cache.Load("auth/login.html")
if err != nil {
ctx.SetStatusCode(fasthttp.StatusInternalServerError)
fmt.Fprintf(ctx, "Template error: %v", err)
return
}
loginTmpl, err := templateCache.Load("auth/login.html")
if err != nil { if err != nil {
ctx.SetStatusCode(fasthttp.StatusInternalServerError) ctx.SetStatusCode(fasthttp.StatusInternalServerError)
fmt.Fprintf(ctx, "Template error: %v", err) fmt.Fprintf(ctx, "Template error: %v", err)
@ -51,34 +45,27 @@ func showLogin(authManager *auth.AuthManager, templateCache *template.Cache) rou
} }
loginFormData := map[string]any{ loginFormData := map[string]any{
"csrf_token": csrf.GetToken(ctx, authManager), "csrf_token": csrf.GetToken(ctx, auth.Manager),
"csrf_field": csrf.HiddenField(ctx, authManager), "csrf_field": csrf.HiddenField(ctx, auth.Manager),
"error_message": "", "error_message": "",
} }
loginContent := loginTmpl.RenderNamed(loginFormData) loginContent := loginTmpl.RenderNamed(loginFormData)
data := map[string]any{ pageData := components.NewPageData("Login - Dragon Knight", loginContent)
"title": "Login - Dragon Knight", if err := components.RenderPage(ctx, pageData, nil); err != nil {
"content": loginContent, ctx.SetStatusCode(fasthttp.StatusInternalServerError)
"topnav": "", fmt.Fprintf(ctx, "Template error: %v", err)
"leftside": "", return
"rightside": "",
"totaltime": middleware.GetRequestTime(ctx),
"numqueries": "0",
"version": "1.0.0",
"build": "dev",
} }
layoutTmpl.WriteTo(ctx, data)
} }
} }
// processLogin handles login form submission // processLogin handles login form submission
func processLogin(authManager *auth.AuthManager, templateCache *template.Cache) router.Handler { func processLogin() router.Handler {
return func(ctx router.Ctx, params []string) { return func(ctx router.Ctx, params []string) {
// Validate CSRF token // Validate CSRF token
if !csrf.ValidateFormToken(ctx, authManager) { if !csrf.ValidateFormToken(ctx, auth.Manager) {
ctx.SetStatusCode(fasthttp.StatusForbidden) ctx.SetStatusCode(fasthttp.StatusForbidden)
ctx.WriteString("CSRF validation failed") ctx.WriteString("CSRF validation failed")
return return
@ -90,19 +77,26 @@ func processLogin(authManager *auth.AuthManager, templateCache *template.Cache)
// Validate input // Validate input
if email == "" || userPassword == "" { if email == "" || userPassword == "" {
showLoginError(ctx, authManager, templateCache, "Email and password are required") showLoginError(ctx, "Email and password are required")
return return
} }
// Authenticate user // Authenticate user
user, err := authManager.Authenticate(email, userPassword) user, err := auth.Manager.Authenticate(email, userPassword)
if err != nil { if err != nil {
showLoginError(ctx, authManager, templateCache, "Invalid email or password") showLoginError(ctx, "Invalid email or password")
return return
} }
// Create session and login // Create session and login
middleware.Login(ctx, authManager, user) middleware.Login(ctx, auth.Manager, user)
// Transfer CSRF token from cookie to session for authenticated user
if cookieToken := csrf.GetTokenFromCookie(ctx); cookieToken != "" {
if session := csrf.GetCurrentSession(ctx); session != nil {
csrf.StoreToken(session, cookieToken)
}
}
// Redirect to dashboard // Redirect to dashboard
ctx.Redirect("/dashboard", fasthttp.StatusFound) ctx.Redirect("/dashboard", fasthttp.StatusFound)
@ -110,16 +104,9 @@ func processLogin(authManager *auth.AuthManager, templateCache *template.Cache)
} }
// showRegister displays the registration form // showRegister displays the registration form
func showRegister(authManager *auth.AuthManager, templateCache *template.Cache) router.Handler { func showRegister() router.Handler {
return func(ctx router.Ctx, params []string) { return func(ctx router.Ctx, params []string) {
layoutTmpl, err := templateCache.Load("layout.html") registerTmpl, err := template.Cache.Load("auth/register.html")
if err != nil {
ctx.SetStatusCode(fasthttp.StatusInternalServerError)
fmt.Fprintf(ctx, "Template error: %v", err)
return
}
registerTmpl, err := templateCache.Load("auth/register.html")
if err != nil { if err != nil {
ctx.SetStatusCode(fasthttp.StatusInternalServerError) ctx.SetStatusCode(fasthttp.StatusInternalServerError)
fmt.Fprintf(ctx, "Template error: %v", err) fmt.Fprintf(ctx, "Template error: %v", err)
@ -127,8 +114,8 @@ func showRegister(authManager *auth.AuthManager, templateCache *template.Cache)
} }
registerFormData := map[string]any{ registerFormData := map[string]any{
"csrf_token": csrf.GetToken(ctx, authManager), "csrf_token": csrf.GetToken(ctx, auth.Manager),
"csrf_field": csrf.HiddenField(ctx, authManager), "csrf_field": csrf.HiddenField(ctx, auth.Manager),
"error_message": "", "error_message": "",
"username": "", "username": "",
"email": "", "email": "",
@ -136,27 +123,20 @@ func showRegister(authManager *auth.AuthManager, templateCache *template.Cache)
registerContent := registerTmpl.RenderNamed(registerFormData) registerContent := registerTmpl.RenderNamed(registerFormData)
data := map[string]any{ pageData := components.NewPageData("Register - Dragon Knight", registerContent)
"title": "Register - Dragon Knight", if err := components.RenderPage(ctx, pageData, nil); err != nil {
"content": registerContent, ctx.SetStatusCode(fasthttp.StatusInternalServerError)
"topnav": "", fmt.Fprintf(ctx, "Template error: %v", err)
"leftside": "", return
"rightside": "",
"totaltime": middleware.GetRequestTime(ctx),
"numqueries": "0",
"version": "1.0.0",
"build": "dev",
} }
layoutTmpl.WriteTo(ctx, data)
} }
} }
// processRegister handles registration form submission // processRegister handles registration form submission
func processRegister(authManager *auth.AuthManager, templateCache *template.Cache) router.Handler { func processRegister() router.Handler {
return func(ctx router.Ctx, params []string) { return func(ctx router.Ctx, params []string) {
// Validate CSRF token // Validate CSRF token
if !csrf.ValidateFormToken(ctx, authManager) { if !csrf.ValidateFormToken(ctx, auth.Manager) {
ctx.SetStatusCode(fasthttp.StatusForbidden) ctx.SetStatusCode(fasthttp.StatusForbidden)
ctx.WriteString("CSRF validation failed") ctx.WriteString("CSRF validation failed")
return return
@ -170,26 +150,26 @@ func processRegister(authManager *auth.AuthManager, templateCache *template.Cach
// Validate input // Validate input
if err := validateRegistration(username, email, userPassword, confirmPassword); err != nil { if err := validateRegistration(username, email, userPassword, confirmPassword); err != nil {
showRegisterError(ctx, authManager, templateCache, err.Error(), username, email) showRegisterError(ctx, err.Error(), username, email)
return return
} }
// Check if username already exists // Check if username already exists
if _, err := users.GetByUsername(authManager.DB(), username); err == nil { if _, err := users.GetByUsername(auth.Manager.DB(), username); err == nil {
showRegisterError(ctx, authManager, templateCache, "Username already exists", username, email) showRegisterError(ctx, "Username already exists", username, email)
return return
} }
// Check if email already exists // Check if email already exists
if _, err := users.GetByEmail(authManager.DB(), email); err == nil { if _, err := users.GetByEmail(auth.Manager.DB(), email); err == nil {
showRegisterError(ctx, authManager, templateCache, "Email already registered", username, email) showRegisterError(ctx, "Email already registered", username, email)
return return
} }
// Hash password // Hash password
hashedPassword, err := password.Hash(userPassword) hashedPassword, err := password.Hash(userPassword)
if err != nil { if err != nil {
showRegisterError(ctx, authManager, templateCache, "Failed to process password", username, email) showRegisterError(ctx, "Failed to process password", username, email)
return return
} }
@ -203,8 +183,8 @@ func processRegister(authManager *auth.AuthManager, templateCache *template.Cach
} }
// Insert into database // Insert into database
if err := createUser(authManager, user); err != nil { if err := createUser(user); err != nil {
showRegisterError(ctx, authManager, templateCache, "Failed to create account", username, email) showRegisterError(ctx, "Failed to create account", username, email)
return return
} }
@ -215,38 +195,38 @@ func processRegister(authManager *auth.AuthManager, templateCache *template.Cach
Email: user.Email, Email: user.Email,
} }
middleware.Login(ctx, authManager, authUser) middleware.Login(ctx, auth.Manager, authUser)
// Transfer CSRF token from cookie to session for authenticated user
if cookieToken := csrf.GetTokenFromCookie(ctx); cookieToken != "" {
if session := csrf.GetCurrentSession(ctx); session != nil {
csrf.StoreToken(session, cookieToken)
}
}
ctx.Redirect("/", fasthttp.StatusFound) ctx.Redirect("/", fasthttp.StatusFound)
} }
} }
// processLogout handles logout // processLogout handles logout
func processLogout(authManager *auth.AuthManager) router.Handler { func processLogout() router.Handler {
return func(ctx router.Ctx, params []string) { return func(ctx router.Ctx, params []string) {
// Validate CSRF token // Validate CSRF token
if !csrf.ValidateFormToken(ctx, authManager) { if !csrf.ValidateFormToken(ctx, auth.Manager) {
ctx.SetStatusCode(fasthttp.StatusForbidden) ctx.SetStatusCode(fasthttp.StatusForbidden)
ctx.WriteString("CSRF validation failed") ctx.WriteString("CSRF validation failed")
return return
} }
middleware.Logout(ctx, authManager) middleware.Logout(ctx, auth.Manager)
ctx.Redirect("/", fasthttp.StatusFound) ctx.Redirect("/", fasthttp.StatusFound)
} }
} }
// Helper functions // Helper functions
func showLoginError(ctx router.Ctx, authManager *auth.AuthManager, templateCache *template.Cache, errorMsg string) { func showLoginError(ctx router.Ctx, errorMsg string) {
layoutTmpl, err := templateCache.Load("layout.html") loginTmpl, err := template.Cache.Load("auth/login.html")
if err != nil {
ctx.SetStatusCode(fasthttp.StatusInternalServerError)
fmt.Fprintf(ctx, "Template error: %v", err)
return
}
loginTmpl, err := templateCache.Load("auth/login.html")
if err != nil { if err != nil {
ctx.SetStatusCode(fasthttp.StatusInternalServerError) ctx.SetStatusCode(fasthttp.StatusInternalServerError)
fmt.Fprintf(ctx, "Template error: %v", err) fmt.Fprintf(ctx, "Template error: %v", err)
@ -259,38 +239,24 @@ func showLoginError(ctx router.Ctx, authManager *auth.AuthManager, templateCache
} }
loginFormData := map[string]any{ loginFormData := map[string]any{
"csrf_token": csrf.GetToken(ctx, authManager), "csrf_token": csrf.GetToken(ctx, auth.Manager),
"csrf_field": csrf.HiddenField(ctx, authManager), "csrf_field": csrf.HiddenField(ctx, auth.Manager),
"error_message": errorHTML, "error_message": errorHTML,
} }
loginContent := loginTmpl.RenderNamed(loginFormData) loginContent := loginTmpl.RenderNamed(loginFormData)
data := map[string]any{
"title": "Login - Dragon Knight",
"content": loginContent,
"topnav": "",
"leftside": "",
"rightside": "",
"totaltime": middleware.GetRequestTime(ctx),
"numqueries": "0",
"version": "1.0.0",
"build": "dev",
}
ctx.SetStatusCode(fasthttp.StatusBadRequest) ctx.SetStatusCode(fasthttp.StatusBadRequest)
layoutTmpl.WriteTo(ctx, data) pageData := components.NewPageData("Login - Dragon Knight", loginContent)
} if err := components.RenderPage(ctx, pageData, nil); err != nil {
func showRegisterError(ctx router.Ctx, authManager *auth.AuthManager, templateCache *template.Cache, errorMsg, username, email string) {
layoutTmpl, err := templateCache.Load("layout.html")
if err != nil {
ctx.SetStatusCode(fasthttp.StatusInternalServerError) ctx.SetStatusCode(fasthttp.StatusInternalServerError)
fmt.Fprintf(ctx, "Template error: %v", err) fmt.Fprintf(ctx, "Template error: %v", err)
return return
} }
}
registerTmpl, err := templateCache.Load("auth/register.html") func showRegisterError(ctx router.Ctx, errorMsg, username, email string) {
registerTmpl, err := template.Cache.Load("auth/register.html")
if err != nil { if err != nil {
ctx.SetStatusCode(fasthttp.StatusInternalServerError) ctx.SetStatusCode(fasthttp.StatusInternalServerError)
fmt.Fprintf(ctx, "Template error: %v", err) fmt.Fprintf(ctx, "Template error: %v", err)
@ -303,8 +269,8 @@ func showRegisterError(ctx router.Ctx, authManager *auth.AuthManager, templateCa
} }
registerFormData := map[string]any{ registerFormData := map[string]any{
"csrf_token": csrf.GetToken(ctx, authManager), "csrf_token": csrf.GetToken(ctx, auth.Manager),
"csrf_field": csrf.HiddenField(ctx, authManager), "csrf_field": csrf.HiddenField(ctx, auth.Manager),
"error_message": errorHTML, "error_message": errorHTML,
"username": username, "username": username,
"email": email, "email": email,
@ -312,20 +278,13 @@ func showRegisterError(ctx router.Ctx, authManager *auth.AuthManager, templateCa
registerContent := registerTmpl.RenderNamed(registerFormData) registerContent := registerTmpl.RenderNamed(registerFormData)
data := map[string]any{
"title": "Register - Dragon Knight",
"content": registerContent,
"topnav": "",
"leftside": "",
"rightside": "",
"totaltime": middleware.GetRequestTime(ctx),
"numqueries": "0",
"version": "1.0.0",
"build": "dev",
}
ctx.SetStatusCode(fasthttp.StatusBadRequest) ctx.SetStatusCode(fasthttp.StatusBadRequest)
layoutTmpl.WriteTo(ctx, data) pageData := components.NewPageData("Register - Dragon Knight", registerContent)
if err := components.RenderPage(ctx, pageData, nil); err != nil {
ctx.SetStatusCode(fasthttp.StatusInternalServerError)
fmt.Fprintf(ctx, "Template error: %v", err)
return
}
} }
func validateRegistration(username, email, password, confirmPassword string) error { func validateRegistration(username, email, password, confirmPassword string) error {
@ -355,8 +314,8 @@ func validateRegistration(username, email, password, confirmPassword string) err
// createUser inserts a new user into the database // createUser inserts a new user into the database
// This is a simplified version - in a real app you'd have a proper users.Create function // This is a simplified version - in a real app you'd have a proper users.Create function
func createUser(authManager *auth.AuthManager, user *users.User) error { func createUser(user *users.User) error {
db := authManager.DB() db := auth.Manager.DB()
query := `INSERT INTO users (username, password, email, verified, auth) VALUES (?, ?, ?, ?, ?)` query := `INSERT INTO users (username, password, email, verified, auth) VALUES (?, ?, ?, ?, ?)`

View File

@ -14,77 +14,64 @@ import (
"dk/internal/router" "dk/internal/router"
"dk/internal/routes" "dk/internal/routes"
"dk/internal/template" "dk/internal/template"
"dk/internal/template/components"
"github.com/valyala/fasthttp" "github.com/valyala/fasthttp"
) )
func Start(port string) error { func Start(port string) error {
// Initialize template cache - use current working directory for development
cwd, err := os.Getwd() cwd, err := os.Getwd()
if err != nil { if err != nil {
return fmt.Errorf("failed to get current working directory: %w", err) return fmt.Errorf("failed to get current working directory: %w", err)
} }
templateCache := template.NewCache(cwd) // Initialize template singleton
template.InitializeCache(cwd)
// Initialize database
db, err := database.Open("dk.db") db, err := database.Open("dk.db")
if err != nil { if err != nil {
return fmt.Errorf("failed to open database: %w", err) return fmt.Errorf("failed to open database: %w", err)
} }
defer db.Close() defer db.Close()
// Initialize authentication manager // Initialize auth singleton
authManager := auth.NewAuthManager(db, "sessions.json") auth.InitializeManager(db, "sessions.json")
// Don't defer Close() here - we'll handle it in shutdown
// Initialize router // Initialize router
r := router.New() r := router.New()
// Add middleware // Add middleware
r.Use(middleware.Timing()) r.Use(middleware.Timing())
r.Use(middleware.Auth(authManager)) r.Use(middleware.Auth(auth.Manager))
r.Use(middleware.CSRF(authManager)) r.Use(middleware.CSRF(auth.Manager))
// Setup route handlers // Setup route handlers
routes.RegisterAuthRoutes(r, authManager, templateCache) routes.RegisterAuthRoutes(r)
// Dashboard (protected route) // Dashboard (protected route)
r.WithMiddleware(middleware.RequireAuth("/login")).Get("/dashboard", func(ctx router.Ctx, params []string) { r.WithMiddleware(middleware.RequireAuth("/login")).Get("/dashboard", func(ctx router.Ctx, params []string) {
tmpl, err := templateCache.Load("layout.html")
if err != nil {
ctx.SetStatusCode(fasthttp.StatusInternalServerError)
fmt.Fprintf(ctx, "Template error: %v", err)
return
}
currentUser := middleware.GetCurrentUser(ctx) currentUser := middleware.GetCurrentUser(ctx)
totalSessions, activeSessions := authManager.SessionStats() totalSessions, activeSessions := auth.Manager.SessionStats()
data := map[string]any{ pageData := components.NewPageData(
"title": "Dashboard - Dragon Knight", "Dashboard - Dragon Knight",
"content": fmt.Sprintf("Welcome back, %s!", currentUser.Username), fmt.Sprintf("Welcome back, %s!", currentUser.Username),
"totaltime": middleware.GetRequestTime(ctx), )
"numqueries": "0",
"version": "1.0.0", additionalData := map[string]any{
"build": "dev",
"total_sessions": totalSessions, "total_sessions": totalSessions,
"active_sessions": activeSessions, "active_sessions": activeSessions,
"authenticated": true, "authenticated": true,
"username": currentUser.Username, "username": currentUser.Username,
} }
tmpl.WriteTo(ctx, data) if err := components.RenderPage(ctx, pageData, additionalData); err != nil {
ctx.SetStatusCode(fasthttp.StatusInternalServerError)
fmt.Fprintf(ctx, "Template error: %v", err)
}
}) })
// Hello world endpoint (public) // Hello world endpoint (public)
r.Get("/", func(ctx router.Ctx, params []string) { r.Get("/", func(ctx router.Ctx, params []string) {
tmpl, err := templateCache.Load("layout.html")
if err != nil {
ctx.SetStatusCode(fasthttp.StatusInternalServerError)
fmt.Fprintf(ctx, "Template error: %v", err)
return
}
// Get current user if authenticated // Get current user if authenticated
currentUser := middleware.GetCurrentUser(ctx) currentUser := middleware.GetCurrentUser(ctx)
var username string var username string
@ -94,22 +81,25 @@ func Start(port string) error {
username = "Guest" username = "Guest"
} }
totalSessions, activeSessions := authManager.SessionStats() totalSessions, activeSessions := auth.Manager.SessionStats()
data := map[string]any{ pageData := components.NewPageData(
"title": "Dragon Knight", "Dragon Knight",
"content": fmt.Sprintf("Hello %s!", username), fmt.Sprintf("Hello %s!", username),
"totaltime": middleware.GetRequestTime(ctx), )
"numqueries": "0", // Placeholder for now
"version": "1.0.0", additionalData := map[string]any{
"build": "dev",
"total_sessions": totalSessions, "total_sessions": totalSessions,
"active_sessions": activeSessions, "active_sessions": activeSessions,
"authenticated": currentUser != nil, "authenticated": currentUser != nil,
"username": username, "username": username,
} }
tmpl.WriteTo(ctx, data) if err := components.RenderPage(ctx, pageData, additionalData); err != nil {
ctx.SetStatusCode(fasthttp.StatusInternalServerError)
fmt.Fprintf(ctx, "Template error: %v", err)
return
}
}) })
// Use current working directory for static files // Use current working directory for static files
@ -165,7 +155,7 @@ func Start(port string) error {
// Save sessions before shutdown // Save sessions before shutdown
log.Println("Saving sessions...") log.Println("Saving sessions...")
if err := authManager.Close(); err != nil { if err := auth.Manager.Close(); err != nil {
log.Printf("Error saving sessions: %v", err) log.Printf("Error saving sessions: %v", err)
} }

View File

@ -0,0 +1,102 @@
package components
import (
"fmt"
"maps"
"dk/internal/auth"
"dk/internal/csrf"
"dk/internal/middleware"
"dk/internal/router"
"dk/internal/template"
)
// GenerateTopNav generates the top navigation HTML based on authentication status
func GenerateTopNav(ctx router.Ctx) string {
if middleware.IsAuthenticated(ctx) {
csrfField := csrf.HiddenField(ctx, auth.Manager)
return fmt.Sprintf(`<form action="/logout" method="post" class="logout">
%s
<button class="img-button" type="submit"><img src="/assets/images/button_logout.gif" alt="Log Out" title="Log Out"></button>
</form>
<a href="/help"><img src="/assets/images/button_help.gif" alt="Help" title="Help"></a>`, csrfField)
} else {
return `<a href="/login"><img src="/assets/images/button_login.gif" alt="Log In" title="Log In"></a>
<a href="/register"><img src="/assets/images/button_register.gif" alt="Register" title="Register"></a>
<a href="/help"><img src="/assets/images/button_help.gif" alt="Help" title="Help"></a>`
}
}
// PageData holds common page template data
type PageData struct {
Title string
Content string
TopNav string
LeftSide string
RightSide string
TotalTime string
NumQueries string
Version string
Build string
}
// RenderPage renders a page using the layout template with common data and additional custom data
func RenderPage(ctx router.Ctx, pageData PageData, additionalData map[string]any) error {
if template.Cache == nil || auth.Manager == nil {
return fmt.Errorf("singleton template.Cache or auth.Manager not initialized")
}
layoutTmpl, err := template.Cache.Load("layout.html")
if err != nil {
return fmt.Errorf("failed to load layout template: %w", err)
}
// Build the base template data with common fields
data := map[string]any{
"title": pageData.Title,
"content": pageData.Content,
"topnav": GenerateTopNav(ctx),
"leftside": pageData.LeftSide,
"rightside": pageData.RightSide,
"totaltime": middleware.GetRequestTime(ctx),
"numqueries": pageData.NumQueries,
"version": pageData.Version,
"build": pageData.Build,
}
// Merge in additional data (overwrites common data if keys conflict)
maps.Copy(data, additionalData)
// Set defaults for empty fields
if data["leftside"] == "" {
data["leftside"] = ""
}
if data["rightside"] == "" {
data["rightside"] = ""
}
if data["numqueries"] == "" {
data["numqueries"] = "0"
}
if data["version"] == "" {
data["version"] = "1.0.0"
}
if data["build"] == "" {
data["build"] = "dev"
}
layoutTmpl.WriteTo(ctx, data)
return nil
}
// NewPageData creates a new PageData with sensible defaults
func NewPageData(title, content string) PageData {
return PageData{
Title: title,
Content: content,
LeftSide: "",
RightSide: "",
NumQueries: "0",
Version: "1.0.0",
Build: "dev",
}
}

View File

@ -12,7 +12,10 @@ import (
"github.com/valyala/fasthttp" "github.com/valyala/fasthttp"
) )
type Cache struct { // Cache is the global singleton instance
var Cache *TemplateCache
type TemplateCache struct {
mu sync.RWMutex mu sync.RWMutex
templates map[string]*Template templates map[string]*Template
basePath string basePath string
@ -28,10 +31,10 @@ type Template struct {
content string content string
modTime time.Time modTime time.Time
filePath string filePath string
cache *Cache cache *TemplateCache
} }
func NewCache(basePath string) *Cache { func NewCache(basePath string) *TemplateCache {
if basePath == "" { if basePath == "" {
exe, err := os.Executable() exe, err := os.Executable()
if err != nil { if err != nil {
@ -41,13 +44,18 @@ func NewCache(basePath string) *Cache {
} }
} }
return &Cache{ return &TemplateCache{
templates: make(map[string]*Template), templates: make(map[string]*Template),
basePath: basePath, basePath: basePath,
} }
} }
func (c *Cache) Load(name string) (*Template, error) { // InitializeCache initializes the global Cache singleton
func InitializeCache(basePath string) {
Cache = NewCache(basePath)
}
func (c *TemplateCache) Load(name string) (*Template, error) {
c.mu.RLock() c.mu.RLock()
tmpl, exists := c.templates[name] tmpl, exists := c.templates[name]
c.mu.RUnlock() c.mu.RUnlock()
@ -62,7 +70,7 @@ func (c *Cache) Load(name string) (*Template, error) {
return c.loadFromFile(name) return c.loadFromFile(name)
} }
func (c *Cache) loadFromFile(name string) (*Template, error) { func (c *TemplateCache) loadFromFile(name string) (*Template, error) {
filePath := filepath.Join(c.basePath, "templates", name) filePath := filepath.Join(c.basePath, "templates", name)
info, err := os.Stat(filePath) info, err := os.Stat(filePath)
@ -90,7 +98,7 @@ func (c *Cache) loadFromFile(name string) (*Template, error) {
return tmpl, nil return tmpl, nil
} }
func (c *Cache) checkAndReload(tmpl *Template) error { func (c *TemplateCache) checkAndReload(tmpl *Template) error {
info, err := os.Stat(tmpl.filePath) info, err := os.Stat(tmpl.filePath)
if err != nil { if err != nil {
return err return err