From b8b77351d0b86edfe03a362306fd0f19abf9638f Mon Sep 17 00:00:00 2001 From: Sky Johnson Date: Mon, 11 Aug 2025 12:24:16 -0500 Subject: [PATCH] add flash messages, preserve usernames/emails in forms --- internal/auth/auth.go | 118 ++++++++++++++++++++++++++++++++ internal/auth/flash.go | 98 +++++++++++++++++++++++++++ internal/routes/auth.go | 138 ++++++++++++++++++-------------------- templates/auth/login.html | 2 +- templates/town/town.html | 10 +-- 5 files changed, 286 insertions(+), 80 deletions(-) create mode 100644 internal/auth/flash.go diff --git a/internal/auth/auth.go b/internal/auth/auth.go index 53662fd..e850a6c 100644 --- a/internal/auth/auth.go +++ b/internal/auth/auth.go @@ -71,6 +71,124 @@ func (am *AuthManager) Close() error { return am.store.Close() } +// SetFlash stores a flash message in the session that will be removed after retrieval +func (am *AuthManager) SetFlash(sessionID, key string, value any) bool { + session, exists := am.store.Get(sessionID) + if !exists { + return false + } + + am.store.mu.Lock() + defer am.store.mu.Unlock() + + if session.Data == nil { + session.Data = make(map[string]any) + } + + // Store flash messages under a special key + flashData, ok := session.Data["_flash"].(map[string]any) + if !ok { + flashData = make(map[string]any) + } + flashData[key] = value + session.Data["_flash"] = flashData + + return true +} + +// GetFlash retrieves and removes a flash message from the session +func (am *AuthManager) GetFlash(sessionID, key string) (any, bool) { + session, exists := am.store.Get(sessionID) + if !exists { + return nil, false + } + + am.store.mu.Lock() + defer am.store.mu.Unlock() + + if session.Data == nil { + return nil, false + } + + flashData, ok := session.Data["_flash"].(map[string]any) + if !ok { + return nil, false + } + + value, exists := flashData[key] + if exists { + delete(flashData, key) + if len(flashData) == 0 { + delete(session.Data, "_flash") + } else { + session.Data["_flash"] = flashData + } + } + + return value, exists +} + +// GetAllFlash retrieves and removes all flash messages from the session +func (am *AuthManager) GetAllFlash(sessionID string) map[string]any { + session, exists := am.store.Get(sessionID) + if !exists { + return nil + } + + am.store.mu.Lock() + defer am.store.mu.Unlock() + + if session.Data == nil { + return nil + } + + flashData, ok := session.Data["_flash"].(map[string]any) + if !ok { + return nil + } + + // Remove flash data from session + delete(session.Data, "_flash") + + return flashData +} + +// SetSessionData stores arbitrary data in the session +func (am *AuthManager) SetSessionData(sessionID, key string, value any) bool { + session, exists := am.store.Get(sessionID) + if !exists { + return false + } + + am.store.mu.Lock() + defer am.store.mu.Unlock() + + if session.Data == nil { + session.Data = make(map[string]any) + } + + session.Data[key] = value + return true +} + +// GetSessionData retrieves data from the session +func (am *AuthManager) GetSessionData(sessionID, key string) (any, bool) { + session, exists := am.store.Get(sessionID) + if !exists { + return nil, false + } + + am.store.mu.RLock() + defer am.store.mu.RUnlock() + + if session.Data == nil { + return nil, false + } + + value, exists := session.Data[key] + return value, exists +} + var ( ErrInvalidCredentials = &AuthError{"invalid username/email or password"} ErrSessionNotFound = &AuthError{"session not found"} diff --git a/internal/auth/flash.go b/internal/auth/flash.go new file mode 100644 index 0000000..b5e2ce9 --- /dev/null +++ b/internal/auth/flash.go @@ -0,0 +1,98 @@ +package auth + +import ( + "dk/internal/router" +) + +// FlashMessage represents a flash message with type and content +type FlashMessage struct { + Type string `json:"type"` // "error", "success", "warning", "info" + Message string `json:"message"` +} + +// SetFlashMessage sets a flash message for the current session +func SetFlashMessage(ctx router.Ctx, msgType, message string) bool { + sessionID := GetSessionCookie(ctx) + if sessionID == "" { + return false + } + + return Manager.SetFlash(sessionID, "message", FlashMessage{ + Type: msgType, + Message: message, + }) +} + +// GetFlashMessage retrieves and removes the flash message from the current session +func GetFlashMessage(ctx router.Ctx) *FlashMessage { + sessionID := GetSessionCookie(ctx) + if sessionID == "" { + return nil + } + + value, exists := Manager.GetFlash(sessionID, "message") + if !exists { + return nil + } + + if msg, ok := value.(FlashMessage); ok { + return &msg + } + + // Handle map[string]interface{} from JSON deserialization + if msgMap, ok := value.(map[string]interface{}); ok { + msg := &FlashMessage{} + if t, ok := msgMap["type"].(string); ok { + msg.Type = t + } + if m, ok := msgMap["message"].(string); ok { + msg.Message = m + } + return msg + } + + return nil +} + +// SetFormData stores form data temporarily in the session (for repopulating forms after errors) +func SetFormData(ctx router.Ctx, data map[string]string) bool { + sessionID := GetSessionCookie(ctx) + if sessionID == "" { + return false + } + + return Manager.SetSessionData(sessionID, "form_data", data) +} + +// GetFormData retrieves and removes form data from the session +func GetFormData(ctx router.Ctx) map[string]string { + sessionID := GetSessionCookie(ctx) + if sessionID == "" { + return nil + } + + value, exists := Manager.GetSessionData(sessionID, "form_data") + if !exists { + return nil + } + + // Clear form data after retrieval + Manager.SetSessionData(sessionID, "form_data", nil) + + if formData, ok := value.(map[string]string); ok { + return formData + } + + // Handle map[string]interface{} from JSON deserialization + if formMap, ok := value.(map[string]interface{}); ok { + result := make(map[string]string) + for k, v := range formMap { + if str, ok := v.(string); ok { + result[k] = str + } + } + return result + } + + return nil +} \ No newline at end of file diff --git a/internal/routes/auth.go b/internal/routes/auth.go index 774830b..aca1fb1 100644 --- a/internal/routes/auth.go +++ b/internal/routes/auth.go @@ -9,7 +9,6 @@ import ( "dk/internal/middleware" "dk/internal/password" "dk/internal/router" - "dk/internal/template" "dk/internal/template/components" "dk/internal/users" @@ -36,10 +35,24 @@ func RegisterAuthRoutes(r *router.Router) { // showLogin displays the login form func showLogin(ctx router.Ctx, _ []string) { + // Get flash message if any + var errorHTML string + if flash := auth.GetFlashMessage(ctx); flash != nil { + errorHTML = fmt.Sprintf(`
%s
`, flash.Message) + } + + // Get form data if any (for preserving email/username on error) + formData := auth.GetFormData(ctx) + id := "" + if formData != nil { + id = formData["id"] + } + components.RenderPageTemplate(ctx, "Log In", "auth/login.html", map[string]any{ "csrf_token": csrf.GetToken(ctx, auth.Manager), "csrf_field": csrf.HiddenField(ctx, auth.Manager), - "error_message": "", + "error_message": errorHTML, + "id": id, }) } @@ -51,17 +64,21 @@ func processLogin(ctx router.Ctx, _ []string) { return } - email := strings.TrimSpace(string(ctx.PostArgs().Peek("email"))) + email := strings.TrimSpace(string(ctx.PostArgs().Peek("id"))) userPassword := string(ctx.PostArgs().Peek("password")) if email == "" || userPassword == "" { - showLoginError(ctx, "Email and password are required") + auth.SetFlashMessage(ctx, "error", "Email and password are required") + auth.SetFormData(ctx, map[string]string{"id": email}) + ctx.Redirect("/login", fasthttp.StatusFound) return } user, err := auth.Manager.Authenticate(email, userPassword) if err != nil { - showLoginError(ctx, "Invalid email or password") + auth.SetFlashMessage(ctx, "error", "Invalid email or password") + auth.SetFormData(ctx, map[string]string{"id": email}) + ctx.Redirect("/login", fasthttp.StatusFound) return } @@ -79,12 +96,27 @@ func processLogin(ctx router.Ctx, _ []string) { // showRegister displays the registration form func showRegister(ctx router.Ctx, _ []string) { + // Get flash message if any + var errorHTML string + if flash := auth.GetFlashMessage(ctx); flash != nil { + errorHTML = fmt.Sprintf(`
%s
`, flash.Message) + } + + // Get form data if any (for preserving values on error) + formData := auth.GetFormData(ctx) + username := "" + email := "" + if formData != nil { + username = formData["username"] + email = formData["email"] + } + components.RenderPageTemplate(ctx, "Register", "auth/register.html", map[string]any{ "csrf_token": csrf.GetToken(ctx, auth.Manager), "csrf_field": csrf.HiddenField(ctx, auth.Manager), - "error_message": "", - "username": "", - "email": "", + "error_message": errorHTML, + "username": username, + "email": email, }) } @@ -102,17 +134,32 @@ func processRegister(ctx router.Ctx, _ []string) { confirmPassword := string(ctx.PostArgs().Peek("confirm_password")) if err := validateRegistration(username, email, userPassword, confirmPassword); err != nil { - showRegisterError(ctx, err.Error(), username, email) + auth.SetFlashMessage(ctx, "error", err.Error()) + auth.SetFormData(ctx, map[string]string{ + "username": username, + "email": email, + }) + ctx.Redirect("/register", fasthttp.StatusFound) return } if _, err := users.GetByUsername(username); err == nil { - showRegisterError(ctx, "Username already exists", username, email) + auth.SetFlashMessage(ctx, "error", "Username already exists") + auth.SetFormData(ctx, map[string]string{ + "username": username, + "email": email, + }) + ctx.Redirect("/register", fasthttp.StatusFound) return } if _, err := users.GetByEmail(email); err == nil { - showRegisterError(ctx, "Email already registered", username, email) + auth.SetFlashMessage(ctx, "error", "Email already registered") + auth.SetFormData(ctx, map[string]string{ + "username": username, + "email": email, + }) + ctx.Redirect("/register", fasthttp.StatusFound) return } @@ -124,7 +171,12 @@ func processRegister(ctx router.Ctx, _ []string) { user.Auth = 1 if err := user.Insert(); err != nil { - showRegisterError(ctx, "Failed to create account", username, email) + auth.SetFlashMessage(ctx, "error", "Failed to create account") + auth.SetFormData(ctx, map[string]string{ + "username": username, + "email": email, + }) + ctx.Redirect("/register", fasthttp.StatusFound) return } @@ -156,68 +208,6 @@ func processLogout(ctx router.Ctx, params []string) { // Helper functions -func showLoginError(ctx router.Ctx, errorMsg string) { - loginTmpl, err := template.Cache.Load("auth/login.html") - if err != nil { - ctx.SetStatusCode(fasthttp.StatusInternalServerError) - fmt.Fprintf(ctx, "Template error: %v", err) - return - } - - var errorHTML string - if errorMsg != "" { - errorHTML = fmt.Sprintf(`
%s
`, errorMsg) - } - - loginFormData := map[string]any{ - "csrf_token": csrf.GetToken(ctx, auth.Manager), - "csrf_field": csrf.HiddenField(ctx, auth.Manager), - "error_message": errorHTML, - } - - loginContent := loginTmpl.RenderNamed(loginFormData) - - ctx.SetStatusCode(fasthttp.StatusBadRequest) - pageData := components.NewPageData("Login - Dragon Knight", loginContent) - if err := components.RenderPage(ctx, pageData, nil); err != nil { - ctx.SetStatusCode(fasthttp.StatusInternalServerError) - fmt.Fprintf(ctx, "Template error: %v", err) - return - } -} - -func showRegisterError(ctx router.Ctx, errorMsg, username, email string) { - registerTmpl, err := template.Cache.Load("auth/register.html") - if err != nil { - ctx.SetStatusCode(fasthttp.StatusInternalServerError) - fmt.Fprintf(ctx, "Template error: %v", err) - return - } - - var errorHTML string - if errorMsg != "" { - errorHTML = fmt.Sprintf(`
%s
`, errorMsg) - } - - registerFormData := map[string]any{ - "csrf_token": csrf.GetToken(ctx, auth.Manager), - "csrf_field": csrf.HiddenField(ctx, auth.Manager), - "error_message": errorHTML, - "username": username, - "email": email, - } - - registerContent := registerTmpl.RenderNamed(registerFormData) - - ctx.SetStatusCode(fasthttp.StatusBadRequest) - pageData := components.NewPageData("Register - Dragon Knight", registerContent) - if err := components.RenderPage(ctx, pageData, nil); err != nil { - ctx.SetStatusCode(fasthttp.StatusInternalServerError) - fmt.Fprintf(ctx, "Template error: %v", err) - return - } -} - func validateRegistration(username, email, password, confirmPassword string) error { if username == "" { return fmt.Errorf("username is required") diff --git a/templates/auth/login.html b/templates/auth/login.html index 6910e7e..4915284 100644 --- a/templates/auth/login.html +++ b/templates/auth/login.html @@ -7,7 +7,7 @@
- +
diff --git a/templates/town/town.html b/templates/town/town.html index 4081349..ce93aee 100644 --- a/templates/town/town.html +++ b/templates/town/town.html @@ -1,11 +1,11 @@
Welcome to {town.Name}
- Town Options:
-