fix csrf and simplify global utilities
This commit is contained in:
parent
56dca44815
commit
0534da09a1
@ -6,6 +6,9 @@ import (
|
|||||||
"dk/internal/users"
|
"dk/internal/users"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
// Manager is the global singleton instance
|
||||||
|
var Manager *AuthManager
|
||||||
|
|
||||||
type User struct {
|
type User struct {
|
||||||
ID int
|
ID int
|
||||||
Username string
|
Username string
|
||||||
@ -24,6 +27,11 @@ func NewAuthManager(db *database.DB, sessionsFilePath string) *AuthManager {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// InitializeManager initializes the global Manager singleton
|
||||||
|
func InitializeManager(db *database.DB, sessionsFilePath string) {
|
||||||
|
Manager = NewAuthManager(db, sessionsFilePath)
|
||||||
|
}
|
||||||
|
|
||||||
func (am *AuthManager) Authenticate(usernameOrEmail, plainPassword string) (*User, error) {
|
func (am *AuthManager) Authenticate(usernameOrEmail, plainPassword string) (*User, error) {
|
||||||
var user *users.User
|
var user *users.User
|
||||||
var err error
|
var err error
|
||||||
|
@ -17,7 +17,7 @@ const (
|
|||||||
)
|
)
|
||||||
|
|
||||||
type Session struct {
|
type Session struct {
|
||||||
ID string `json:"id"`
|
ID string `json:"-"` // Exclude from JSON since it's stored as the map key
|
||||||
UserID int `json:"user_id"`
|
UserID int `json:"user_id"`
|
||||||
Username string `json:"username"`
|
Username string `json:"username"`
|
||||||
Email string `json:"email"`
|
Email string `json:"email"`
|
||||||
|
@ -73,9 +73,6 @@ func CSRF(authManager *auth.AuthManager, config ...CSRFConfig) router.Middleware
|
|||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
// CSRF validation passed, rotate token for security
|
|
||||||
csrf.RotateToken(ctx, authManager)
|
|
||||||
|
|
||||||
next(ctx, params)
|
next(ctx, params)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@ -100,9 +97,6 @@ func RequireCSRF(authManager *auth.AuthManager, failureHandler ...func(router.Ct
|
|||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
// Rotate token after successful validation
|
|
||||||
csrf.RotateToken(ctx, authManager)
|
|
||||||
|
|
||||||
next(ctx, params)
|
next(ctx, params)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
@ -10,40 +10,34 @@ import (
|
|||||||
"dk/internal/password"
|
"dk/internal/password"
|
||||||
"dk/internal/router"
|
"dk/internal/router"
|
||||||
"dk/internal/template"
|
"dk/internal/template"
|
||||||
|
"dk/internal/template/components"
|
||||||
"dk/internal/users"
|
"dk/internal/users"
|
||||||
|
|
||||||
"github.com/valyala/fasthttp"
|
"github.com/valyala/fasthttp"
|
||||||
)
|
)
|
||||||
|
|
||||||
// RegisterAuthRoutes sets up authentication routes
|
// RegisterAuthRoutes sets up authentication routes
|
||||||
func RegisterAuthRoutes(r *router.Router, authManager *auth.AuthManager, templateCache *template.Cache) {
|
func RegisterAuthRoutes(r *router.Router) {
|
||||||
// Guest routes (redirect to dashboard if already authenticated)
|
// Guest routes
|
||||||
guestGroup := r.Group("")
|
guestGroup := r.Group("")
|
||||||
guestGroup.Use(middleware.RequireGuest("/"))
|
guestGroup.Use(middleware.RequireGuest("/"))
|
||||||
|
|
||||||
guestGroup.Get("/login", showLogin(authManager, templateCache))
|
guestGroup.Get("/login", showLogin())
|
||||||
guestGroup.Post("/login", processLogin(authManager, templateCache))
|
guestGroup.Post("/login", processLogin())
|
||||||
guestGroup.Get("/register", showRegister(authManager, templateCache))
|
guestGroup.Get("/register", showRegister())
|
||||||
guestGroup.Post("/register", processRegister(authManager, templateCache))
|
guestGroup.Post("/register", processRegister())
|
||||||
|
|
||||||
// Authenticated routes
|
// Authenticated routes
|
||||||
authGroup := r.Group("")
|
authGroup := r.Group("")
|
||||||
authGroup.Use(middleware.RequireAuth("/login"))
|
authGroup.Use(middleware.RequireAuth("/login"))
|
||||||
|
|
||||||
authGroup.Post("/logout", processLogout(authManager))
|
authGroup.Post("/logout", processLogout())
|
||||||
}
|
}
|
||||||
|
|
||||||
// showLogin displays the login form
|
// showLogin displays the login form
|
||||||
func showLogin(authManager *auth.AuthManager, templateCache *template.Cache) router.Handler {
|
func showLogin() router.Handler {
|
||||||
return func(ctx router.Ctx, params []string) {
|
return func(ctx router.Ctx, params []string) {
|
||||||
layoutTmpl, err := templateCache.Load("layout.html")
|
loginTmpl, err := template.Cache.Load("auth/login.html")
|
||||||
if err != nil {
|
|
||||||
ctx.SetStatusCode(fasthttp.StatusInternalServerError)
|
|
||||||
fmt.Fprintf(ctx, "Template error: %v", err)
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
loginTmpl, err := templateCache.Load("auth/login.html")
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
ctx.SetStatusCode(fasthttp.StatusInternalServerError)
|
ctx.SetStatusCode(fasthttp.StatusInternalServerError)
|
||||||
fmt.Fprintf(ctx, "Template error: %v", err)
|
fmt.Fprintf(ctx, "Template error: %v", err)
|
||||||
@ -51,34 +45,27 @@ func showLogin(authManager *auth.AuthManager, templateCache *template.Cache) rou
|
|||||||
}
|
}
|
||||||
|
|
||||||
loginFormData := map[string]any{
|
loginFormData := map[string]any{
|
||||||
"csrf_token": csrf.GetToken(ctx, authManager),
|
"csrf_token": csrf.GetToken(ctx, auth.Manager),
|
||||||
"csrf_field": csrf.HiddenField(ctx, authManager),
|
"csrf_field": csrf.HiddenField(ctx, auth.Manager),
|
||||||
"error_message": "",
|
"error_message": "",
|
||||||
}
|
}
|
||||||
|
|
||||||
loginContent := loginTmpl.RenderNamed(loginFormData)
|
loginContent := loginTmpl.RenderNamed(loginFormData)
|
||||||
|
|
||||||
data := map[string]any{
|
pageData := components.NewPageData("Login - Dragon Knight", loginContent)
|
||||||
"title": "Login - Dragon Knight",
|
if err := components.RenderPage(ctx, pageData, nil); err != nil {
|
||||||
"content": loginContent,
|
ctx.SetStatusCode(fasthttp.StatusInternalServerError)
|
||||||
"topnav": "",
|
fmt.Fprintf(ctx, "Template error: %v", err)
|
||||||
"leftside": "",
|
return
|
||||||
"rightside": "",
|
|
||||||
"totaltime": middleware.GetRequestTime(ctx),
|
|
||||||
"numqueries": "0",
|
|
||||||
"version": "1.0.0",
|
|
||||||
"build": "dev",
|
|
||||||
}
|
}
|
||||||
|
|
||||||
layoutTmpl.WriteTo(ctx, data)
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// processLogin handles login form submission
|
// processLogin handles login form submission
|
||||||
func processLogin(authManager *auth.AuthManager, templateCache *template.Cache) router.Handler {
|
func processLogin() router.Handler {
|
||||||
return func(ctx router.Ctx, params []string) {
|
return func(ctx router.Ctx, params []string) {
|
||||||
// Validate CSRF token
|
// Validate CSRF token
|
||||||
if !csrf.ValidateFormToken(ctx, authManager) {
|
if !csrf.ValidateFormToken(ctx, auth.Manager) {
|
||||||
ctx.SetStatusCode(fasthttp.StatusForbidden)
|
ctx.SetStatusCode(fasthttp.StatusForbidden)
|
||||||
ctx.WriteString("CSRF validation failed")
|
ctx.WriteString("CSRF validation failed")
|
||||||
return
|
return
|
||||||
@ -90,19 +77,26 @@ func processLogin(authManager *auth.AuthManager, templateCache *template.Cache)
|
|||||||
|
|
||||||
// Validate input
|
// Validate input
|
||||||
if email == "" || userPassword == "" {
|
if email == "" || userPassword == "" {
|
||||||
showLoginError(ctx, authManager, templateCache, "Email and password are required")
|
showLoginError(ctx, "Email and password are required")
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
// Authenticate user
|
// Authenticate user
|
||||||
user, err := authManager.Authenticate(email, userPassword)
|
user, err := auth.Manager.Authenticate(email, userPassword)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
showLoginError(ctx, authManager, templateCache, "Invalid email or password")
|
showLoginError(ctx, "Invalid email or password")
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
// Create session and login
|
// Create session and login
|
||||||
middleware.Login(ctx, authManager, user)
|
middleware.Login(ctx, auth.Manager, user)
|
||||||
|
|
||||||
|
// Transfer CSRF token from cookie to session for authenticated user
|
||||||
|
if cookieToken := csrf.GetTokenFromCookie(ctx); cookieToken != "" {
|
||||||
|
if session := csrf.GetCurrentSession(ctx); session != nil {
|
||||||
|
csrf.StoreToken(session, cookieToken)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
// Redirect to dashboard
|
// Redirect to dashboard
|
||||||
ctx.Redirect("/dashboard", fasthttp.StatusFound)
|
ctx.Redirect("/dashboard", fasthttp.StatusFound)
|
||||||
@ -110,16 +104,9 @@ func processLogin(authManager *auth.AuthManager, templateCache *template.Cache)
|
|||||||
}
|
}
|
||||||
|
|
||||||
// showRegister displays the registration form
|
// showRegister displays the registration form
|
||||||
func showRegister(authManager *auth.AuthManager, templateCache *template.Cache) router.Handler {
|
func showRegister() router.Handler {
|
||||||
return func(ctx router.Ctx, params []string) {
|
return func(ctx router.Ctx, params []string) {
|
||||||
layoutTmpl, err := templateCache.Load("layout.html")
|
registerTmpl, err := template.Cache.Load("auth/register.html")
|
||||||
if err != nil {
|
|
||||||
ctx.SetStatusCode(fasthttp.StatusInternalServerError)
|
|
||||||
fmt.Fprintf(ctx, "Template error: %v", err)
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
registerTmpl, err := templateCache.Load("auth/register.html")
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
ctx.SetStatusCode(fasthttp.StatusInternalServerError)
|
ctx.SetStatusCode(fasthttp.StatusInternalServerError)
|
||||||
fmt.Fprintf(ctx, "Template error: %v", err)
|
fmt.Fprintf(ctx, "Template error: %v", err)
|
||||||
@ -127,8 +114,8 @@ func showRegister(authManager *auth.AuthManager, templateCache *template.Cache)
|
|||||||
}
|
}
|
||||||
|
|
||||||
registerFormData := map[string]any{
|
registerFormData := map[string]any{
|
||||||
"csrf_token": csrf.GetToken(ctx, authManager),
|
"csrf_token": csrf.GetToken(ctx, auth.Manager),
|
||||||
"csrf_field": csrf.HiddenField(ctx, authManager),
|
"csrf_field": csrf.HiddenField(ctx, auth.Manager),
|
||||||
"error_message": "",
|
"error_message": "",
|
||||||
"username": "",
|
"username": "",
|
||||||
"email": "",
|
"email": "",
|
||||||
@ -136,27 +123,20 @@ func showRegister(authManager *auth.AuthManager, templateCache *template.Cache)
|
|||||||
|
|
||||||
registerContent := registerTmpl.RenderNamed(registerFormData)
|
registerContent := registerTmpl.RenderNamed(registerFormData)
|
||||||
|
|
||||||
data := map[string]any{
|
pageData := components.NewPageData("Register - Dragon Knight", registerContent)
|
||||||
"title": "Register - Dragon Knight",
|
if err := components.RenderPage(ctx, pageData, nil); err != nil {
|
||||||
"content": registerContent,
|
ctx.SetStatusCode(fasthttp.StatusInternalServerError)
|
||||||
"topnav": "",
|
fmt.Fprintf(ctx, "Template error: %v", err)
|
||||||
"leftside": "",
|
return
|
||||||
"rightside": "",
|
|
||||||
"totaltime": middleware.GetRequestTime(ctx),
|
|
||||||
"numqueries": "0",
|
|
||||||
"version": "1.0.0",
|
|
||||||
"build": "dev",
|
|
||||||
}
|
}
|
||||||
|
|
||||||
layoutTmpl.WriteTo(ctx, data)
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// processRegister handles registration form submission
|
// processRegister handles registration form submission
|
||||||
func processRegister(authManager *auth.AuthManager, templateCache *template.Cache) router.Handler {
|
func processRegister() router.Handler {
|
||||||
return func(ctx router.Ctx, params []string) {
|
return func(ctx router.Ctx, params []string) {
|
||||||
// Validate CSRF token
|
// Validate CSRF token
|
||||||
if !csrf.ValidateFormToken(ctx, authManager) {
|
if !csrf.ValidateFormToken(ctx, auth.Manager) {
|
||||||
ctx.SetStatusCode(fasthttp.StatusForbidden)
|
ctx.SetStatusCode(fasthttp.StatusForbidden)
|
||||||
ctx.WriteString("CSRF validation failed")
|
ctx.WriteString("CSRF validation failed")
|
||||||
return
|
return
|
||||||
@ -170,26 +150,26 @@ func processRegister(authManager *auth.AuthManager, templateCache *template.Cach
|
|||||||
|
|
||||||
// Validate input
|
// Validate input
|
||||||
if err := validateRegistration(username, email, userPassword, confirmPassword); err != nil {
|
if err := validateRegistration(username, email, userPassword, confirmPassword); err != nil {
|
||||||
showRegisterError(ctx, authManager, templateCache, err.Error(), username, email)
|
showRegisterError(ctx, err.Error(), username, email)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
// Check if username already exists
|
// Check if username already exists
|
||||||
if _, err := users.GetByUsername(authManager.DB(), username); err == nil {
|
if _, err := users.GetByUsername(auth.Manager.DB(), username); err == nil {
|
||||||
showRegisterError(ctx, authManager, templateCache, "Username already exists", username, email)
|
showRegisterError(ctx, "Username already exists", username, email)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
// Check if email already exists
|
// Check if email already exists
|
||||||
if _, err := users.GetByEmail(authManager.DB(), email); err == nil {
|
if _, err := users.GetByEmail(auth.Manager.DB(), email); err == nil {
|
||||||
showRegisterError(ctx, authManager, templateCache, "Email already registered", username, email)
|
showRegisterError(ctx, "Email already registered", username, email)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
// Hash password
|
// Hash password
|
||||||
hashedPassword, err := password.Hash(userPassword)
|
hashedPassword, err := password.Hash(userPassword)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
showRegisterError(ctx, authManager, templateCache, "Failed to process password", username, email)
|
showRegisterError(ctx, "Failed to process password", username, email)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -203,8 +183,8 @@ func processRegister(authManager *auth.AuthManager, templateCache *template.Cach
|
|||||||
}
|
}
|
||||||
|
|
||||||
// Insert into database
|
// Insert into database
|
||||||
if err := createUser(authManager, user); err != nil {
|
if err := createUser(user); err != nil {
|
||||||
showRegisterError(ctx, authManager, templateCache, "Failed to create account", username, email)
|
showRegisterError(ctx, "Failed to create account", username, email)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -215,38 +195,38 @@ func processRegister(authManager *auth.AuthManager, templateCache *template.Cach
|
|||||||
Email: user.Email,
|
Email: user.Email,
|
||||||
}
|
}
|
||||||
|
|
||||||
middleware.Login(ctx, authManager, authUser)
|
middleware.Login(ctx, auth.Manager, authUser)
|
||||||
|
|
||||||
|
// Transfer CSRF token from cookie to session for authenticated user
|
||||||
|
if cookieToken := csrf.GetTokenFromCookie(ctx); cookieToken != "" {
|
||||||
|
if session := csrf.GetCurrentSession(ctx); session != nil {
|
||||||
|
csrf.StoreToken(session, cookieToken)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
ctx.Redirect("/", fasthttp.StatusFound)
|
ctx.Redirect("/", fasthttp.StatusFound)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// processLogout handles logout
|
// processLogout handles logout
|
||||||
func processLogout(authManager *auth.AuthManager) router.Handler {
|
func processLogout() router.Handler {
|
||||||
return func(ctx router.Ctx, params []string) {
|
return func(ctx router.Ctx, params []string) {
|
||||||
// Validate CSRF token
|
// Validate CSRF token
|
||||||
if !csrf.ValidateFormToken(ctx, authManager) {
|
if !csrf.ValidateFormToken(ctx, auth.Manager) {
|
||||||
ctx.SetStatusCode(fasthttp.StatusForbidden)
|
ctx.SetStatusCode(fasthttp.StatusForbidden)
|
||||||
ctx.WriteString("CSRF validation failed")
|
ctx.WriteString("CSRF validation failed")
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
middleware.Logout(ctx, authManager)
|
middleware.Logout(ctx, auth.Manager)
|
||||||
ctx.Redirect("/", fasthttp.StatusFound)
|
ctx.Redirect("/", fasthttp.StatusFound)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// Helper functions
|
// Helper functions
|
||||||
|
|
||||||
func showLoginError(ctx router.Ctx, authManager *auth.AuthManager, templateCache *template.Cache, errorMsg string) {
|
func showLoginError(ctx router.Ctx, errorMsg string) {
|
||||||
layoutTmpl, err := templateCache.Load("layout.html")
|
loginTmpl, err := template.Cache.Load("auth/login.html")
|
||||||
if err != nil {
|
|
||||||
ctx.SetStatusCode(fasthttp.StatusInternalServerError)
|
|
||||||
fmt.Fprintf(ctx, "Template error: %v", err)
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
loginTmpl, err := templateCache.Load("auth/login.html")
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
ctx.SetStatusCode(fasthttp.StatusInternalServerError)
|
ctx.SetStatusCode(fasthttp.StatusInternalServerError)
|
||||||
fmt.Fprintf(ctx, "Template error: %v", err)
|
fmt.Fprintf(ctx, "Template error: %v", err)
|
||||||
@ -259,38 +239,24 @@ func showLoginError(ctx router.Ctx, authManager *auth.AuthManager, templateCache
|
|||||||
}
|
}
|
||||||
|
|
||||||
loginFormData := map[string]any{
|
loginFormData := map[string]any{
|
||||||
"csrf_token": csrf.GetToken(ctx, authManager),
|
"csrf_token": csrf.GetToken(ctx, auth.Manager),
|
||||||
"csrf_field": csrf.HiddenField(ctx, authManager),
|
"csrf_field": csrf.HiddenField(ctx, auth.Manager),
|
||||||
"error_message": errorHTML,
|
"error_message": errorHTML,
|
||||||
}
|
}
|
||||||
|
|
||||||
loginContent := loginTmpl.RenderNamed(loginFormData)
|
loginContent := loginTmpl.RenderNamed(loginFormData)
|
||||||
|
|
||||||
data := map[string]any{
|
|
||||||
"title": "Login - Dragon Knight",
|
|
||||||
"content": loginContent,
|
|
||||||
"topnav": "",
|
|
||||||
"leftside": "",
|
|
||||||
"rightside": "",
|
|
||||||
"totaltime": middleware.GetRequestTime(ctx),
|
|
||||||
"numqueries": "0",
|
|
||||||
"version": "1.0.0",
|
|
||||||
"build": "dev",
|
|
||||||
}
|
|
||||||
|
|
||||||
ctx.SetStatusCode(fasthttp.StatusBadRequest)
|
ctx.SetStatusCode(fasthttp.StatusBadRequest)
|
||||||
layoutTmpl.WriteTo(ctx, data)
|
pageData := components.NewPageData("Login - Dragon Knight", loginContent)
|
||||||
}
|
if err := components.RenderPage(ctx, pageData, nil); err != nil {
|
||||||
|
|
||||||
func showRegisterError(ctx router.Ctx, authManager *auth.AuthManager, templateCache *template.Cache, errorMsg, username, email string) {
|
|
||||||
layoutTmpl, err := templateCache.Load("layout.html")
|
|
||||||
if err != nil {
|
|
||||||
ctx.SetStatusCode(fasthttp.StatusInternalServerError)
|
ctx.SetStatusCode(fasthttp.StatusInternalServerError)
|
||||||
fmt.Fprintf(ctx, "Template error: %v", err)
|
fmt.Fprintf(ctx, "Template error: %v", err)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
}
|
||||||
|
|
||||||
registerTmpl, err := templateCache.Load("auth/register.html")
|
func showRegisterError(ctx router.Ctx, errorMsg, username, email string) {
|
||||||
|
registerTmpl, err := template.Cache.Load("auth/register.html")
|
||||||
if err != nil {
|
if err != nil {
|
||||||
ctx.SetStatusCode(fasthttp.StatusInternalServerError)
|
ctx.SetStatusCode(fasthttp.StatusInternalServerError)
|
||||||
fmt.Fprintf(ctx, "Template error: %v", err)
|
fmt.Fprintf(ctx, "Template error: %v", err)
|
||||||
@ -303,8 +269,8 @@ func showRegisterError(ctx router.Ctx, authManager *auth.AuthManager, templateCa
|
|||||||
}
|
}
|
||||||
|
|
||||||
registerFormData := map[string]any{
|
registerFormData := map[string]any{
|
||||||
"csrf_token": csrf.GetToken(ctx, authManager),
|
"csrf_token": csrf.GetToken(ctx, auth.Manager),
|
||||||
"csrf_field": csrf.HiddenField(ctx, authManager),
|
"csrf_field": csrf.HiddenField(ctx, auth.Manager),
|
||||||
"error_message": errorHTML,
|
"error_message": errorHTML,
|
||||||
"username": username,
|
"username": username,
|
||||||
"email": email,
|
"email": email,
|
||||||
@ -312,20 +278,13 @@ func showRegisterError(ctx router.Ctx, authManager *auth.AuthManager, templateCa
|
|||||||
|
|
||||||
registerContent := registerTmpl.RenderNamed(registerFormData)
|
registerContent := registerTmpl.RenderNamed(registerFormData)
|
||||||
|
|
||||||
data := map[string]any{
|
|
||||||
"title": "Register - Dragon Knight",
|
|
||||||
"content": registerContent,
|
|
||||||
"topnav": "",
|
|
||||||
"leftside": "",
|
|
||||||
"rightside": "",
|
|
||||||
"totaltime": middleware.GetRequestTime(ctx),
|
|
||||||
"numqueries": "0",
|
|
||||||
"version": "1.0.0",
|
|
||||||
"build": "dev",
|
|
||||||
}
|
|
||||||
|
|
||||||
ctx.SetStatusCode(fasthttp.StatusBadRequest)
|
ctx.SetStatusCode(fasthttp.StatusBadRequest)
|
||||||
layoutTmpl.WriteTo(ctx, data)
|
pageData := components.NewPageData("Register - Dragon Knight", registerContent)
|
||||||
|
if err := components.RenderPage(ctx, pageData, nil); err != nil {
|
||||||
|
ctx.SetStatusCode(fasthttp.StatusInternalServerError)
|
||||||
|
fmt.Fprintf(ctx, "Template error: %v", err)
|
||||||
|
return
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func validateRegistration(username, email, password, confirmPassword string) error {
|
func validateRegistration(username, email, password, confirmPassword string) error {
|
||||||
@ -355,8 +314,8 @@ func validateRegistration(username, email, password, confirmPassword string) err
|
|||||||
|
|
||||||
// createUser inserts a new user into the database
|
// createUser inserts a new user into the database
|
||||||
// This is a simplified version - in a real app you'd have a proper users.Create function
|
// This is a simplified version - in a real app you'd have a proper users.Create function
|
||||||
func createUser(authManager *auth.AuthManager, user *users.User) error {
|
func createUser(user *users.User) error {
|
||||||
db := authManager.DB()
|
db := auth.Manager.DB()
|
||||||
|
|
||||||
query := `INSERT INTO users (username, password, email, verified, auth) VALUES (?, ?, ?, ?, ?)`
|
query := `INSERT INTO users (username, password, email, verified, auth) VALUES (?, ?, ?, ?, ?)`
|
||||||
|
|
||||||
|
@ -14,77 +14,64 @@ import (
|
|||||||
"dk/internal/router"
|
"dk/internal/router"
|
||||||
"dk/internal/routes"
|
"dk/internal/routes"
|
||||||
"dk/internal/template"
|
"dk/internal/template"
|
||||||
|
"dk/internal/template/components"
|
||||||
|
|
||||||
"github.com/valyala/fasthttp"
|
"github.com/valyala/fasthttp"
|
||||||
)
|
)
|
||||||
|
|
||||||
func Start(port string) error {
|
func Start(port string) error {
|
||||||
// Initialize template cache - use current working directory for development
|
|
||||||
cwd, err := os.Getwd()
|
cwd, err := os.Getwd()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return fmt.Errorf("failed to get current working directory: %w", err)
|
return fmt.Errorf("failed to get current working directory: %w", err)
|
||||||
}
|
}
|
||||||
templateCache := template.NewCache(cwd)
|
// Initialize template singleton
|
||||||
|
template.InitializeCache(cwd)
|
||||||
|
|
||||||
// Initialize database
|
|
||||||
db, err := database.Open("dk.db")
|
db, err := database.Open("dk.db")
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return fmt.Errorf("failed to open database: %w", err)
|
return fmt.Errorf("failed to open database: %w", err)
|
||||||
}
|
}
|
||||||
defer db.Close()
|
defer db.Close()
|
||||||
|
|
||||||
// Initialize authentication manager
|
// Initialize auth singleton
|
||||||
authManager := auth.NewAuthManager(db, "sessions.json")
|
auth.InitializeManager(db, "sessions.json")
|
||||||
// Don't defer Close() here - we'll handle it in shutdown
|
|
||||||
|
|
||||||
// Initialize router
|
// Initialize router
|
||||||
r := router.New()
|
r := router.New()
|
||||||
|
|
||||||
// Add middleware
|
// Add middleware
|
||||||
r.Use(middleware.Timing())
|
r.Use(middleware.Timing())
|
||||||
r.Use(middleware.Auth(authManager))
|
r.Use(middleware.Auth(auth.Manager))
|
||||||
r.Use(middleware.CSRF(authManager))
|
r.Use(middleware.CSRF(auth.Manager))
|
||||||
|
|
||||||
// Setup route handlers
|
// Setup route handlers
|
||||||
routes.RegisterAuthRoutes(r, authManager, templateCache)
|
routes.RegisterAuthRoutes(r)
|
||||||
|
|
||||||
// Dashboard (protected route)
|
// Dashboard (protected route)
|
||||||
r.WithMiddleware(middleware.RequireAuth("/login")).Get("/dashboard", func(ctx router.Ctx, params []string) {
|
r.WithMiddleware(middleware.RequireAuth("/login")).Get("/dashboard", func(ctx router.Ctx, params []string) {
|
||||||
tmpl, err := templateCache.Load("layout.html")
|
|
||||||
if err != nil {
|
|
||||||
ctx.SetStatusCode(fasthttp.StatusInternalServerError)
|
|
||||||
fmt.Fprintf(ctx, "Template error: %v", err)
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
currentUser := middleware.GetCurrentUser(ctx)
|
currentUser := middleware.GetCurrentUser(ctx)
|
||||||
totalSessions, activeSessions := authManager.SessionStats()
|
totalSessions, activeSessions := auth.Manager.SessionStats()
|
||||||
|
|
||||||
data := map[string]any{
|
pageData := components.NewPageData(
|
||||||
"title": "Dashboard - Dragon Knight",
|
"Dashboard - Dragon Knight",
|
||||||
"content": fmt.Sprintf("Welcome back, %s!", currentUser.Username),
|
fmt.Sprintf("Welcome back, %s!", currentUser.Username),
|
||||||
"totaltime": middleware.GetRequestTime(ctx),
|
)
|
||||||
"numqueries": "0",
|
|
||||||
"version": "1.0.0",
|
additionalData := map[string]any{
|
||||||
"build": "dev",
|
|
||||||
"total_sessions": totalSessions,
|
"total_sessions": totalSessions,
|
||||||
"active_sessions": activeSessions,
|
"active_sessions": activeSessions,
|
||||||
"authenticated": true,
|
"authenticated": true,
|
||||||
"username": currentUser.Username,
|
"username": currentUser.Username,
|
||||||
}
|
}
|
||||||
|
|
||||||
tmpl.WriteTo(ctx, data)
|
if err := components.RenderPage(ctx, pageData, additionalData); err != nil {
|
||||||
|
ctx.SetStatusCode(fasthttp.StatusInternalServerError)
|
||||||
|
fmt.Fprintf(ctx, "Template error: %v", err)
|
||||||
|
}
|
||||||
})
|
})
|
||||||
|
|
||||||
// Hello world endpoint (public)
|
// Hello world endpoint (public)
|
||||||
r.Get("/", func(ctx router.Ctx, params []string) {
|
r.Get("/", func(ctx router.Ctx, params []string) {
|
||||||
tmpl, err := templateCache.Load("layout.html")
|
|
||||||
if err != nil {
|
|
||||||
ctx.SetStatusCode(fasthttp.StatusInternalServerError)
|
|
||||||
fmt.Fprintf(ctx, "Template error: %v", err)
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
// Get current user if authenticated
|
// Get current user if authenticated
|
||||||
currentUser := middleware.GetCurrentUser(ctx)
|
currentUser := middleware.GetCurrentUser(ctx)
|
||||||
var username string
|
var username string
|
||||||
@ -94,22 +81,25 @@ func Start(port string) error {
|
|||||||
username = "Guest"
|
username = "Guest"
|
||||||
}
|
}
|
||||||
|
|
||||||
totalSessions, activeSessions := authManager.SessionStats()
|
totalSessions, activeSessions := auth.Manager.SessionStats()
|
||||||
|
|
||||||
data := map[string]any{
|
pageData := components.NewPageData(
|
||||||
"title": "Dragon Knight",
|
"Dragon Knight",
|
||||||
"content": fmt.Sprintf("Hello %s!", username),
|
fmt.Sprintf("Hello %s!", username),
|
||||||
"totaltime": middleware.GetRequestTime(ctx),
|
)
|
||||||
"numqueries": "0", // Placeholder for now
|
|
||||||
"version": "1.0.0",
|
additionalData := map[string]any{
|
||||||
"build": "dev",
|
|
||||||
"total_sessions": totalSessions,
|
"total_sessions": totalSessions,
|
||||||
"active_sessions": activeSessions,
|
"active_sessions": activeSessions,
|
||||||
"authenticated": currentUser != nil,
|
"authenticated": currentUser != nil,
|
||||||
"username": username,
|
"username": username,
|
||||||
}
|
}
|
||||||
|
|
||||||
tmpl.WriteTo(ctx, data)
|
if err := components.RenderPage(ctx, pageData, additionalData); err != nil {
|
||||||
|
ctx.SetStatusCode(fasthttp.StatusInternalServerError)
|
||||||
|
fmt.Fprintf(ctx, "Template error: %v", err)
|
||||||
|
return
|
||||||
|
}
|
||||||
})
|
})
|
||||||
|
|
||||||
// Use current working directory for static files
|
// Use current working directory for static files
|
||||||
@ -165,7 +155,7 @@ func Start(port string) error {
|
|||||||
|
|
||||||
// Save sessions before shutdown
|
// Save sessions before shutdown
|
||||||
log.Println("Saving sessions...")
|
log.Println("Saving sessions...")
|
||||||
if err := authManager.Close(); err != nil {
|
if err := auth.Manager.Close(); err != nil {
|
||||||
log.Printf("Error saving sessions: %v", err)
|
log.Printf("Error saving sessions: %v", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
102
internal/template/components/components.go
Normal file
102
internal/template/components/components.go
Normal file
@ -0,0 +1,102 @@
|
|||||||
|
package components
|
||||||
|
|
||||||
|
import (
|
||||||
|
"fmt"
|
||||||
|
"maps"
|
||||||
|
|
||||||
|
"dk/internal/auth"
|
||||||
|
"dk/internal/csrf"
|
||||||
|
"dk/internal/middleware"
|
||||||
|
"dk/internal/router"
|
||||||
|
"dk/internal/template"
|
||||||
|
)
|
||||||
|
|
||||||
|
// GenerateTopNav generates the top navigation HTML based on authentication status
|
||||||
|
func GenerateTopNav(ctx router.Ctx) string {
|
||||||
|
if middleware.IsAuthenticated(ctx) {
|
||||||
|
csrfField := csrf.HiddenField(ctx, auth.Manager)
|
||||||
|
return fmt.Sprintf(`<form action="/logout" method="post" class="logout">
|
||||||
|
%s
|
||||||
|
<button class="img-button" type="submit"><img src="/assets/images/button_logout.gif" alt="Log Out" title="Log Out"></button>
|
||||||
|
</form>
|
||||||
|
<a href="/help"><img src="/assets/images/button_help.gif" alt="Help" title="Help"></a>`, csrfField)
|
||||||
|
} else {
|
||||||
|
return `<a href="/login"><img src="/assets/images/button_login.gif" alt="Log In" title="Log In"></a>
|
||||||
|
<a href="/register"><img src="/assets/images/button_register.gif" alt="Register" title="Register"></a>
|
||||||
|
<a href="/help"><img src="/assets/images/button_help.gif" alt="Help" title="Help"></a>`
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// PageData holds common page template data
|
||||||
|
type PageData struct {
|
||||||
|
Title string
|
||||||
|
Content string
|
||||||
|
TopNav string
|
||||||
|
LeftSide string
|
||||||
|
RightSide string
|
||||||
|
TotalTime string
|
||||||
|
NumQueries string
|
||||||
|
Version string
|
||||||
|
Build string
|
||||||
|
}
|
||||||
|
|
||||||
|
// RenderPage renders a page using the layout template with common data and additional custom data
|
||||||
|
func RenderPage(ctx router.Ctx, pageData PageData, additionalData map[string]any) error {
|
||||||
|
if template.Cache == nil || auth.Manager == nil {
|
||||||
|
return fmt.Errorf("singleton template.Cache or auth.Manager not initialized")
|
||||||
|
}
|
||||||
|
|
||||||
|
layoutTmpl, err := template.Cache.Load("layout.html")
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("failed to load layout template: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Build the base template data with common fields
|
||||||
|
data := map[string]any{
|
||||||
|
"title": pageData.Title,
|
||||||
|
"content": pageData.Content,
|
||||||
|
"topnav": GenerateTopNav(ctx),
|
||||||
|
"leftside": pageData.LeftSide,
|
||||||
|
"rightside": pageData.RightSide,
|
||||||
|
"totaltime": middleware.GetRequestTime(ctx),
|
||||||
|
"numqueries": pageData.NumQueries,
|
||||||
|
"version": pageData.Version,
|
||||||
|
"build": pageData.Build,
|
||||||
|
}
|
||||||
|
|
||||||
|
// Merge in additional data (overwrites common data if keys conflict)
|
||||||
|
maps.Copy(data, additionalData)
|
||||||
|
|
||||||
|
// Set defaults for empty fields
|
||||||
|
if data["leftside"] == "" {
|
||||||
|
data["leftside"] = ""
|
||||||
|
}
|
||||||
|
if data["rightside"] == "" {
|
||||||
|
data["rightside"] = ""
|
||||||
|
}
|
||||||
|
if data["numqueries"] == "" {
|
||||||
|
data["numqueries"] = "0"
|
||||||
|
}
|
||||||
|
if data["version"] == "" {
|
||||||
|
data["version"] = "1.0.0"
|
||||||
|
}
|
||||||
|
if data["build"] == "" {
|
||||||
|
data["build"] = "dev"
|
||||||
|
}
|
||||||
|
|
||||||
|
layoutTmpl.WriteTo(ctx, data)
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// NewPageData creates a new PageData with sensible defaults
|
||||||
|
func NewPageData(title, content string) PageData {
|
||||||
|
return PageData{
|
||||||
|
Title: title,
|
||||||
|
Content: content,
|
||||||
|
LeftSide: "",
|
||||||
|
RightSide: "",
|
||||||
|
NumQueries: "0",
|
||||||
|
Version: "1.0.0",
|
||||||
|
Build: "dev",
|
||||||
|
}
|
||||||
|
}
|
@ -12,7 +12,10 @@ import (
|
|||||||
"github.com/valyala/fasthttp"
|
"github.com/valyala/fasthttp"
|
||||||
)
|
)
|
||||||
|
|
||||||
type Cache struct {
|
// Cache is the global singleton instance
|
||||||
|
var Cache *TemplateCache
|
||||||
|
|
||||||
|
type TemplateCache struct {
|
||||||
mu sync.RWMutex
|
mu sync.RWMutex
|
||||||
templates map[string]*Template
|
templates map[string]*Template
|
||||||
basePath string
|
basePath string
|
||||||
@ -28,10 +31,10 @@ type Template struct {
|
|||||||
content string
|
content string
|
||||||
modTime time.Time
|
modTime time.Time
|
||||||
filePath string
|
filePath string
|
||||||
cache *Cache
|
cache *TemplateCache
|
||||||
}
|
}
|
||||||
|
|
||||||
func NewCache(basePath string) *Cache {
|
func NewCache(basePath string) *TemplateCache {
|
||||||
if basePath == "" {
|
if basePath == "" {
|
||||||
exe, err := os.Executable()
|
exe, err := os.Executable()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
@ -41,13 +44,18 @@ func NewCache(basePath string) *Cache {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
return &Cache{
|
return &TemplateCache{
|
||||||
templates: make(map[string]*Template),
|
templates: make(map[string]*Template),
|
||||||
basePath: basePath,
|
basePath: basePath,
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func (c *Cache) Load(name string) (*Template, error) {
|
// InitializeCache initializes the global Cache singleton
|
||||||
|
func InitializeCache(basePath string) {
|
||||||
|
Cache = NewCache(basePath)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *TemplateCache) Load(name string) (*Template, error) {
|
||||||
c.mu.RLock()
|
c.mu.RLock()
|
||||||
tmpl, exists := c.templates[name]
|
tmpl, exists := c.templates[name]
|
||||||
c.mu.RUnlock()
|
c.mu.RUnlock()
|
||||||
@ -62,7 +70,7 @@ func (c *Cache) Load(name string) (*Template, error) {
|
|||||||
return c.loadFromFile(name)
|
return c.loadFromFile(name)
|
||||||
}
|
}
|
||||||
|
|
||||||
func (c *Cache) loadFromFile(name string) (*Template, error) {
|
func (c *TemplateCache) loadFromFile(name string) (*Template, error) {
|
||||||
filePath := filepath.Join(c.basePath, "templates", name)
|
filePath := filepath.Join(c.basePath, "templates", name)
|
||||||
|
|
||||||
info, err := os.Stat(filePath)
|
info, err := os.Stat(filePath)
|
||||||
@ -90,7 +98,7 @@ func (c *Cache) loadFromFile(name string) (*Template, error) {
|
|||||||
return tmpl, nil
|
return tmpl, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (c *Cache) checkAndReload(tmpl *Template) error {
|
func (c *TemplateCache) checkAndReload(tmpl *Template) error {
|
||||||
info, err := os.Stat(tmpl.filePath)
|
info, err := os.Stat(tmpl.filePath)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
|
Loading…
x
Reference in New Issue
Block a user