diff --git a/internal/auth/auth.go b/internal/auth/auth.go index aa2c23a..3df84c9 100644 --- a/internal/auth/auth.go +++ b/internal/auth/auth.go @@ -6,6 +6,9 @@ import ( "dk/internal/users" ) +// Manager is the global singleton instance +var Manager *AuthManager + type User struct { ID int Username string @@ -24,6 +27,11 @@ func NewAuthManager(db *database.DB, sessionsFilePath string) *AuthManager { } } +// InitializeManager initializes the global Manager singleton +func InitializeManager(db *database.DB, sessionsFilePath string) { + Manager = NewAuthManager(db, sessionsFilePath) +} + func (am *AuthManager) Authenticate(usernameOrEmail, plainPassword string) (*User, error) { var user *users.User var err error diff --git a/internal/auth/session.go b/internal/auth/session.go index e196858..f141bf1 100644 --- a/internal/auth/session.go +++ b/internal/auth/session.go @@ -17,7 +17,7 @@ const ( ) type Session struct { - ID string `json:"id"` + ID string `json:"-"` // Exclude from JSON since it's stored as the map key UserID int `json:"user_id"` Username string `json:"username"` Email string `json:"email"` diff --git a/internal/middleware/csrf.go b/internal/middleware/csrf.go index 84fcaae..f80c1ef 100644 --- a/internal/middleware/csrf.go +++ b/internal/middleware/csrf.go @@ -73,9 +73,6 @@ func CSRF(authManager *auth.AuthManager, config ...CSRFConfig) router.Middleware return } - // CSRF validation passed, rotate token for security - csrf.RotateToken(ctx, authManager) - next(ctx, params) } } @@ -100,9 +97,6 @@ func RequireCSRF(authManager *auth.AuthManager, failureHandler ...func(router.Ct return } - // Rotate token after successful validation - csrf.RotateToken(ctx, authManager) - next(ctx, params) } } diff --git a/internal/routes/auth.go b/internal/routes/auth.go index 8d48219..5b10116 100644 --- a/internal/routes/auth.go +++ b/internal/routes/auth.go @@ -10,40 +10,34 @@ import ( "dk/internal/password" "dk/internal/router" "dk/internal/template" + "dk/internal/template/components" "dk/internal/users" "github.com/valyala/fasthttp" ) // RegisterAuthRoutes sets up authentication routes -func RegisterAuthRoutes(r *router.Router, authManager *auth.AuthManager, templateCache *template.Cache) { - // Guest routes (redirect to dashboard if already authenticated) +func RegisterAuthRoutes(r *router.Router) { + // Guest routes guestGroup := r.Group("") guestGroup.Use(middleware.RequireGuest("/")) - guestGroup.Get("/login", showLogin(authManager, templateCache)) - guestGroup.Post("/login", processLogin(authManager, templateCache)) - guestGroup.Get("/register", showRegister(authManager, templateCache)) - guestGroup.Post("/register", processRegister(authManager, templateCache)) + guestGroup.Get("/login", showLogin()) + guestGroup.Post("/login", processLogin()) + guestGroup.Get("/register", showRegister()) + guestGroup.Post("/register", processRegister()) // Authenticated routes authGroup := r.Group("") authGroup.Use(middleware.RequireAuth("/login")) - authGroup.Post("/logout", processLogout(authManager)) + authGroup.Post("/logout", processLogout()) } // showLogin displays the login form -func showLogin(authManager *auth.AuthManager, templateCache *template.Cache) router.Handler { +func showLogin() router.Handler { return func(ctx router.Ctx, params []string) { - layoutTmpl, err := templateCache.Load("layout.html") - if err != nil { - ctx.SetStatusCode(fasthttp.StatusInternalServerError) - fmt.Fprintf(ctx, "Template error: %v", err) - return - } - - loginTmpl, err := templateCache.Load("auth/login.html") + loginTmpl, err := template.Cache.Load("auth/login.html") if err != nil { ctx.SetStatusCode(fasthttp.StatusInternalServerError) fmt.Fprintf(ctx, "Template error: %v", err) @@ -51,34 +45,27 @@ func showLogin(authManager *auth.AuthManager, templateCache *template.Cache) rou } loginFormData := map[string]any{ - "csrf_token": csrf.GetToken(ctx, authManager), - "csrf_field": csrf.HiddenField(ctx, authManager), + "csrf_token": csrf.GetToken(ctx, auth.Manager), + "csrf_field": csrf.HiddenField(ctx, auth.Manager), "error_message": "", } loginContent := loginTmpl.RenderNamed(loginFormData) - data := map[string]any{ - "title": "Login - Dragon Knight", - "content": loginContent, - "topnav": "", - "leftside": "", - "rightside": "", - "totaltime": middleware.GetRequestTime(ctx), - "numqueries": "0", - "version": "1.0.0", - "build": "dev", + pageData := components.NewPageData("Login - Dragon Knight", loginContent) + if err := components.RenderPage(ctx, pageData, nil); err != nil { + ctx.SetStatusCode(fasthttp.StatusInternalServerError) + fmt.Fprintf(ctx, "Template error: %v", err) + return } - - layoutTmpl.WriteTo(ctx, data) } } // processLogin handles login form submission -func processLogin(authManager *auth.AuthManager, templateCache *template.Cache) router.Handler { +func processLogin() router.Handler { return func(ctx router.Ctx, params []string) { // Validate CSRF token - if !csrf.ValidateFormToken(ctx, authManager) { + if !csrf.ValidateFormToken(ctx, auth.Manager) { ctx.SetStatusCode(fasthttp.StatusForbidden) ctx.WriteString("CSRF validation failed") return @@ -90,19 +77,26 @@ func processLogin(authManager *auth.AuthManager, templateCache *template.Cache) // Validate input if email == "" || userPassword == "" { - showLoginError(ctx, authManager, templateCache, "Email and password are required") + showLoginError(ctx, "Email and password are required") return } // Authenticate user - user, err := authManager.Authenticate(email, userPassword) + user, err := auth.Manager.Authenticate(email, userPassword) if err != nil { - showLoginError(ctx, authManager, templateCache, "Invalid email or password") + showLoginError(ctx, "Invalid email or password") return } // Create session and login - middleware.Login(ctx, authManager, user) + middleware.Login(ctx, auth.Manager, user) + + // Transfer CSRF token from cookie to session for authenticated user + if cookieToken := csrf.GetTokenFromCookie(ctx); cookieToken != "" { + if session := csrf.GetCurrentSession(ctx); session != nil { + csrf.StoreToken(session, cookieToken) + } + } // Redirect to dashboard ctx.Redirect("/dashboard", fasthttp.StatusFound) @@ -110,16 +104,9 @@ func processLogin(authManager *auth.AuthManager, templateCache *template.Cache) } // showRegister displays the registration form -func showRegister(authManager *auth.AuthManager, templateCache *template.Cache) router.Handler { +func showRegister() router.Handler { return func(ctx router.Ctx, params []string) { - layoutTmpl, err := templateCache.Load("layout.html") - if err != nil { - ctx.SetStatusCode(fasthttp.StatusInternalServerError) - fmt.Fprintf(ctx, "Template error: %v", err) - return - } - - registerTmpl, err := templateCache.Load("auth/register.html") + registerTmpl, err := template.Cache.Load("auth/register.html") if err != nil { ctx.SetStatusCode(fasthttp.StatusInternalServerError) fmt.Fprintf(ctx, "Template error: %v", err) @@ -127,8 +114,8 @@ func showRegister(authManager *auth.AuthManager, templateCache *template.Cache) } registerFormData := map[string]any{ - "csrf_token": csrf.GetToken(ctx, authManager), - "csrf_field": csrf.HiddenField(ctx, authManager), + "csrf_token": csrf.GetToken(ctx, auth.Manager), + "csrf_field": csrf.HiddenField(ctx, auth.Manager), "error_message": "", "username": "", "email": "", @@ -136,27 +123,20 @@ func showRegister(authManager *auth.AuthManager, templateCache *template.Cache) registerContent := registerTmpl.RenderNamed(registerFormData) - data := map[string]any{ - "title": "Register - Dragon Knight", - "content": registerContent, - "topnav": "", - "leftside": "", - "rightside": "", - "totaltime": middleware.GetRequestTime(ctx), - "numqueries": "0", - "version": "1.0.0", - "build": "dev", + pageData := components.NewPageData("Register - Dragon Knight", registerContent) + if err := components.RenderPage(ctx, pageData, nil); err != nil { + ctx.SetStatusCode(fasthttp.StatusInternalServerError) + fmt.Fprintf(ctx, "Template error: %v", err) + return } - - layoutTmpl.WriteTo(ctx, data) } } // processRegister handles registration form submission -func processRegister(authManager *auth.AuthManager, templateCache *template.Cache) router.Handler { +func processRegister() router.Handler { return func(ctx router.Ctx, params []string) { // Validate CSRF token - if !csrf.ValidateFormToken(ctx, authManager) { + if !csrf.ValidateFormToken(ctx, auth.Manager) { ctx.SetStatusCode(fasthttp.StatusForbidden) ctx.WriteString("CSRF validation failed") return @@ -170,26 +150,26 @@ func processRegister(authManager *auth.AuthManager, templateCache *template.Cach // Validate input if err := validateRegistration(username, email, userPassword, confirmPassword); err != nil { - showRegisterError(ctx, authManager, templateCache, err.Error(), username, email) + showRegisterError(ctx, err.Error(), username, email) return } // Check if username already exists - if _, err := users.GetByUsername(authManager.DB(), username); err == nil { - showRegisterError(ctx, authManager, templateCache, "Username already exists", username, email) + if _, err := users.GetByUsername(auth.Manager.DB(), username); err == nil { + showRegisterError(ctx, "Username already exists", username, email) return } // Check if email already exists - if _, err := users.GetByEmail(authManager.DB(), email); err == nil { - showRegisterError(ctx, authManager, templateCache, "Email already registered", username, email) + if _, err := users.GetByEmail(auth.Manager.DB(), email); err == nil { + showRegisterError(ctx, "Email already registered", username, email) return } // Hash password hashedPassword, err := password.Hash(userPassword) if err != nil { - showRegisterError(ctx, authManager, templateCache, "Failed to process password", username, email) + showRegisterError(ctx, "Failed to process password", username, email) return } @@ -203,8 +183,8 @@ func processRegister(authManager *auth.AuthManager, templateCache *template.Cach } // Insert into database - if err := createUser(authManager, user); err != nil { - showRegisterError(ctx, authManager, templateCache, "Failed to create account", username, email) + if err := createUser(user); err != nil { + showRegisterError(ctx, "Failed to create account", username, email) return } @@ -215,38 +195,38 @@ func processRegister(authManager *auth.AuthManager, templateCache *template.Cach Email: user.Email, } - middleware.Login(ctx, authManager, authUser) + middleware.Login(ctx, auth.Manager, authUser) + + // Transfer CSRF token from cookie to session for authenticated user + if cookieToken := csrf.GetTokenFromCookie(ctx); cookieToken != "" { + if session := csrf.GetCurrentSession(ctx); session != nil { + csrf.StoreToken(session, cookieToken) + } + } ctx.Redirect("/", fasthttp.StatusFound) } } // processLogout handles logout -func processLogout(authManager *auth.AuthManager) router.Handler { +func processLogout() router.Handler { return func(ctx router.Ctx, params []string) { // Validate CSRF token - if !csrf.ValidateFormToken(ctx, authManager) { + if !csrf.ValidateFormToken(ctx, auth.Manager) { ctx.SetStatusCode(fasthttp.StatusForbidden) ctx.WriteString("CSRF validation failed") return } - middleware.Logout(ctx, authManager) + middleware.Logout(ctx, auth.Manager) ctx.Redirect("/", fasthttp.StatusFound) } } // Helper functions -func showLoginError(ctx router.Ctx, authManager *auth.AuthManager, templateCache *template.Cache, errorMsg string) { - layoutTmpl, err := templateCache.Load("layout.html") - if err != nil { - ctx.SetStatusCode(fasthttp.StatusInternalServerError) - fmt.Fprintf(ctx, "Template error: %v", err) - return - } - - loginTmpl, err := templateCache.Load("auth/login.html") +func showLoginError(ctx router.Ctx, errorMsg string) { + loginTmpl, err := template.Cache.Load("auth/login.html") if err != nil { ctx.SetStatusCode(fasthttp.StatusInternalServerError) fmt.Fprintf(ctx, "Template error: %v", err) @@ -259,38 +239,24 @@ func showLoginError(ctx router.Ctx, authManager *auth.AuthManager, templateCache } loginFormData := map[string]any{ - "csrf_token": csrf.GetToken(ctx, authManager), - "csrf_field": csrf.HiddenField(ctx, authManager), + "csrf_token": csrf.GetToken(ctx, auth.Manager), + "csrf_field": csrf.HiddenField(ctx, auth.Manager), "error_message": errorHTML, } loginContent := loginTmpl.RenderNamed(loginFormData) - data := map[string]any{ - "title": "Login - Dragon Knight", - "content": loginContent, - "topnav": "", - "leftside": "", - "rightside": "", - "totaltime": middleware.GetRequestTime(ctx), - "numqueries": "0", - "version": "1.0.0", - "build": "dev", - } - ctx.SetStatusCode(fasthttp.StatusBadRequest) - layoutTmpl.WriteTo(ctx, data) -} - -func showRegisterError(ctx router.Ctx, authManager *auth.AuthManager, templateCache *template.Cache, errorMsg, username, email string) { - layoutTmpl, err := templateCache.Load("layout.html") - if err != nil { + pageData := components.NewPageData("Login - Dragon Knight", loginContent) + if err := components.RenderPage(ctx, pageData, nil); err != nil { ctx.SetStatusCode(fasthttp.StatusInternalServerError) fmt.Fprintf(ctx, "Template error: %v", err) return } +} - registerTmpl, err := templateCache.Load("auth/register.html") +func showRegisterError(ctx router.Ctx, errorMsg, username, email string) { + registerTmpl, err := template.Cache.Load("auth/register.html") if err != nil { ctx.SetStatusCode(fasthttp.StatusInternalServerError) fmt.Fprintf(ctx, "Template error: %v", err) @@ -303,8 +269,8 @@ func showRegisterError(ctx router.Ctx, authManager *auth.AuthManager, templateCa } registerFormData := map[string]any{ - "csrf_token": csrf.GetToken(ctx, authManager), - "csrf_field": csrf.HiddenField(ctx, authManager), + "csrf_token": csrf.GetToken(ctx, auth.Manager), + "csrf_field": csrf.HiddenField(ctx, auth.Manager), "error_message": errorHTML, "username": username, "email": email, @@ -312,20 +278,13 @@ func showRegisterError(ctx router.Ctx, authManager *auth.AuthManager, templateCa registerContent := registerTmpl.RenderNamed(registerFormData) - data := map[string]any{ - "title": "Register - Dragon Knight", - "content": registerContent, - "topnav": "", - "leftside": "", - "rightside": "", - "totaltime": middleware.GetRequestTime(ctx), - "numqueries": "0", - "version": "1.0.0", - "build": "dev", - } - ctx.SetStatusCode(fasthttp.StatusBadRequest) - layoutTmpl.WriteTo(ctx, data) + pageData := components.NewPageData("Register - Dragon Knight", registerContent) + if err := components.RenderPage(ctx, pageData, nil); err != nil { + ctx.SetStatusCode(fasthttp.StatusInternalServerError) + fmt.Fprintf(ctx, "Template error: %v", err) + return + } } func validateRegistration(username, email, password, confirmPassword string) error { @@ -355,8 +314,8 @@ func validateRegistration(username, email, password, confirmPassword string) err // createUser inserts a new user into the database // This is a simplified version - in a real app you'd have a proper users.Create function -func createUser(authManager *auth.AuthManager, user *users.User) error { - db := authManager.DB() +func createUser(user *users.User) error { + db := auth.Manager.DB() query := `INSERT INTO users (username, password, email, verified, auth) VALUES (?, ?, ?, ?, ?)` diff --git a/internal/server/server.go b/internal/server/server.go index 7d7fa35..beaba3a 100644 --- a/internal/server/server.go +++ b/internal/server/server.go @@ -14,77 +14,64 @@ import ( "dk/internal/router" "dk/internal/routes" "dk/internal/template" + "dk/internal/template/components" "github.com/valyala/fasthttp" ) func Start(port string) error { - // Initialize template cache - use current working directory for development cwd, err := os.Getwd() if err != nil { return fmt.Errorf("failed to get current working directory: %w", err) } - templateCache := template.NewCache(cwd) + // Initialize template singleton + template.InitializeCache(cwd) - // Initialize database db, err := database.Open("dk.db") if err != nil { return fmt.Errorf("failed to open database: %w", err) } defer db.Close() - // Initialize authentication manager - authManager := auth.NewAuthManager(db, "sessions.json") - // Don't defer Close() here - we'll handle it in shutdown + // Initialize auth singleton + auth.InitializeManager(db, "sessions.json") // Initialize router r := router.New() // Add middleware r.Use(middleware.Timing()) - r.Use(middleware.Auth(authManager)) - r.Use(middleware.CSRF(authManager)) + r.Use(middleware.Auth(auth.Manager)) + r.Use(middleware.CSRF(auth.Manager)) // Setup route handlers - routes.RegisterAuthRoutes(r, authManager, templateCache) + routes.RegisterAuthRoutes(r) // Dashboard (protected route) r.WithMiddleware(middleware.RequireAuth("/login")).Get("/dashboard", func(ctx router.Ctx, params []string) { - tmpl, err := templateCache.Load("layout.html") - if err != nil { - ctx.SetStatusCode(fasthttp.StatusInternalServerError) - fmt.Fprintf(ctx, "Template error: %v", err) - return - } - currentUser := middleware.GetCurrentUser(ctx) - totalSessions, activeSessions := authManager.SessionStats() + totalSessions, activeSessions := auth.Manager.SessionStats() - data := map[string]any{ - "title": "Dashboard - Dragon Knight", - "content": fmt.Sprintf("Welcome back, %s!", currentUser.Username), - "totaltime": middleware.GetRequestTime(ctx), - "numqueries": "0", - "version": "1.0.0", - "build": "dev", + pageData := components.NewPageData( + "Dashboard - Dragon Knight", + fmt.Sprintf("Welcome back, %s!", currentUser.Username), + ) + + additionalData := map[string]any{ "total_sessions": totalSessions, "active_sessions": activeSessions, "authenticated": true, "username": currentUser.Username, } - tmpl.WriteTo(ctx, data) + if err := components.RenderPage(ctx, pageData, additionalData); err != nil { + ctx.SetStatusCode(fasthttp.StatusInternalServerError) + fmt.Fprintf(ctx, "Template error: %v", err) + } }) // Hello world endpoint (public) r.Get("/", func(ctx router.Ctx, params []string) { - tmpl, err := templateCache.Load("layout.html") - if err != nil { - ctx.SetStatusCode(fasthttp.StatusInternalServerError) - fmt.Fprintf(ctx, "Template error: %v", err) - return - } - // Get current user if authenticated currentUser := middleware.GetCurrentUser(ctx) var username string @@ -94,22 +81,25 @@ func Start(port string) error { username = "Guest" } - totalSessions, activeSessions := authManager.SessionStats() + totalSessions, activeSessions := auth.Manager.SessionStats() - data := map[string]any{ - "title": "Dragon Knight", - "content": fmt.Sprintf("Hello %s!", username), - "totaltime": middleware.GetRequestTime(ctx), - "numqueries": "0", // Placeholder for now - "version": "1.0.0", - "build": "dev", + pageData := components.NewPageData( + "Dragon Knight", + fmt.Sprintf("Hello %s!", username), + ) + + additionalData := map[string]any{ "total_sessions": totalSessions, "active_sessions": activeSessions, "authenticated": currentUser != nil, "username": username, } - tmpl.WriteTo(ctx, data) + if err := components.RenderPage(ctx, pageData, additionalData); err != nil { + ctx.SetStatusCode(fasthttp.StatusInternalServerError) + fmt.Fprintf(ctx, "Template error: %v", err) + return + } }) // Use current working directory for static files @@ -165,7 +155,7 @@ func Start(port string) error { // Save sessions before shutdown log.Println("Saving sessions...") - if err := authManager.Close(); err != nil { + if err := auth.Manager.Close(); err != nil { log.Printf("Error saving sessions: %v", err) } diff --git a/internal/template/components/components.go b/internal/template/components/components.go new file mode 100644 index 0000000..08be0e6 --- /dev/null +++ b/internal/template/components/components.go @@ -0,0 +1,102 @@ +package components + +import ( + "fmt" + "maps" + + "dk/internal/auth" + "dk/internal/csrf" + "dk/internal/middleware" + "dk/internal/router" + "dk/internal/template" +) + +// GenerateTopNav generates the top navigation HTML based on authentication status +func GenerateTopNav(ctx router.Ctx) string { + if middleware.IsAuthenticated(ctx) { + csrfField := csrf.HiddenField(ctx, auth.Manager) + return fmt.Sprintf(`
+ %s + +
+ Help`, csrfField) + } else { + return `Log In + Register + Help` + } +} + +// PageData holds common page template data +type PageData struct { + Title string + Content string + TopNav string + LeftSide string + RightSide string + TotalTime string + NumQueries string + Version string + Build string +} + +// RenderPage renders a page using the layout template with common data and additional custom data +func RenderPage(ctx router.Ctx, pageData PageData, additionalData map[string]any) error { + if template.Cache == nil || auth.Manager == nil { + return fmt.Errorf("singleton template.Cache or auth.Manager not initialized") + } + + layoutTmpl, err := template.Cache.Load("layout.html") + if err != nil { + return fmt.Errorf("failed to load layout template: %w", err) + } + + // Build the base template data with common fields + data := map[string]any{ + "title": pageData.Title, + "content": pageData.Content, + "topnav": GenerateTopNav(ctx), + "leftside": pageData.LeftSide, + "rightside": pageData.RightSide, + "totaltime": middleware.GetRequestTime(ctx), + "numqueries": pageData.NumQueries, + "version": pageData.Version, + "build": pageData.Build, + } + + // Merge in additional data (overwrites common data if keys conflict) + maps.Copy(data, additionalData) + + // Set defaults for empty fields + if data["leftside"] == "" { + data["leftside"] = "" + } + if data["rightside"] == "" { + data["rightside"] = "" + } + if data["numqueries"] == "" { + data["numqueries"] = "0" + } + if data["version"] == "" { + data["version"] = "1.0.0" + } + if data["build"] == "" { + data["build"] = "dev" + } + + layoutTmpl.WriteTo(ctx, data) + return nil +} + +// NewPageData creates a new PageData with sensible defaults +func NewPageData(title, content string) PageData { + return PageData{ + Title: title, + Content: content, + LeftSide: "", + RightSide: "", + NumQueries: "0", + Version: "1.0.0", + Build: "dev", + } +} diff --git a/internal/template/template.go b/internal/template/template.go index 086aa47..367c683 100644 --- a/internal/template/template.go +++ b/internal/template/template.go @@ -12,7 +12,10 @@ import ( "github.com/valyala/fasthttp" ) -type Cache struct { +// Cache is the global singleton instance +var Cache *TemplateCache + +type TemplateCache struct { mu sync.RWMutex templates map[string]*Template basePath string @@ -28,10 +31,10 @@ type Template struct { content string modTime time.Time filePath string - cache *Cache + cache *TemplateCache } -func NewCache(basePath string) *Cache { +func NewCache(basePath string) *TemplateCache { if basePath == "" { exe, err := os.Executable() if err != nil { @@ -41,13 +44,18 @@ func NewCache(basePath string) *Cache { } } - return &Cache{ + return &TemplateCache{ templates: make(map[string]*Template), basePath: basePath, } } -func (c *Cache) Load(name string) (*Template, error) { +// InitializeCache initializes the global Cache singleton +func InitializeCache(basePath string) { + Cache = NewCache(basePath) +} + +func (c *TemplateCache) Load(name string) (*Template, error) { c.mu.RLock() tmpl, exists := c.templates[name] c.mu.RUnlock() @@ -62,7 +70,7 @@ func (c *Cache) Load(name string) (*Template, error) { return c.loadFromFile(name) } -func (c *Cache) loadFromFile(name string) (*Template, error) { +func (c *TemplateCache) loadFromFile(name string) (*Template, error) { filePath := filepath.Join(c.basePath, "templates", name) info, err := os.Stat(filePath) @@ -90,7 +98,7 @@ func (c *Cache) loadFromFile(name string) (*Template, error) { return tmpl, nil } -func (c *Cache) checkAndReload(tmpl *Template) error { +func (c *TemplateCache) checkAndReload(tmpl *Template) error { info, err := os.Stat(tmpl.filePath) if err != nil { return err