diff --git a/internal/auth/auth.go b/internal/auth/auth.go
index 53662fd..e850a6c 100644
--- a/internal/auth/auth.go
+++ b/internal/auth/auth.go
@@ -71,6 +71,124 @@ func (am *AuthManager) Close() error {
return am.store.Close()
}
+// SetFlash stores a flash message in the session that will be removed after retrieval
+func (am *AuthManager) SetFlash(sessionID, key string, value any) bool {
+ session, exists := am.store.Get(sessionID)
+ if !exists {
+ return false
+ }
+
+ am.store.mu.Lock()
+ defer am.store.mu.Unlock()
+
+ if session.Data == nil {
+ session.Data = make(map[string]any)
+ }
+
+ // Store flash messages under a special key
+ flashData, ok := session.Data["_flash"].(map[string]any)
+ if !ok {
+ flashData = make(map[string]any)
+ }
+ flashData[key] = value
+ session.Data["_flash"] = flashData
+
+ return true
+}
+
+// GetFlash retrieves and removes a flash message from the session
+func (am *AuthManager) GetFlash(sessionID, key string) (any, bool) {
+ session, exists := am.store.Get(sessionID)
+ if !exists {
+ return nil, false
+ }
+
+ am.store.mu.Lock()
+ defer am.store.mu.Unlock()
+
+ if session.Data == nil {
+ return nil, false
+ }
+
+ flashData, ok := session.Data["_flash"].(map[string]any)
+ if !ok {
+ return nil, false
+ }
+
+ value, exists := flashData[key]
+ if exists {
+ delete(flashData, key)
+ if len(flashData) == 0 {
+ delete(session.Data, "_flash")
+ } else {
+ session.Data["_flash"] = flashData
+ }
+ }
+
+ return value, exists
+}
+
+// GetAllFlash retrieves and removes all flash messages from the session
+func (am *AuthManager) GetAllFlash(sessionID string) map[string]any {
+ session, exists := am.store.Get(sessionID)
+ if !exists {
+ return nil
+ }
+
+ am.store.mu.Lock()
+ defer am.store.mu.Unlock()
+
+ if session.Data == nil {
+ return nil
+ }
+
+ flashData, ok := session.Data["_flash"].(map[string]any)
+ if !ok {
+ return nil
+ }
+
+ // Remove flash data from session
+ delete(session.Data, "_flash")
+
+ return flashData
+}
+
+// SetSessionData stores arbitrary data in the session
+func (am *AuthManager) SetSessionData(sessionID, key string, value any) bool {
+ session, exists := am.store.Get(sessionID)
+ if !exists {
+ return false
+ }
+
+ am.store.mu.Lock()
+ defer am.store.mu.Unlock()
+
+ if session.Data == nil {
+ session.Data = make(map[string]any)
+ }
+
+ session.Data[key] = value
+ return true
+}
+
+// GetSessionData retrieves data from the session
+func (am *AuthManager) GetSessionData(sessionID, key string) (any, bool) {
+ session, exists := am.store.Get(sessionID)
+ if !exists {
+ return nil, false
+ }
+
+ am.store.mu.RLock()
+ defer am.store.mu.RUnlock()
+
+ if session.Data == nil {
+ return nil, false
+ }
+
+ value, exists := session.Data[key]
+ return value, exists
+}
+
var (
ErrInvalidCredentials = &AuthError{"invalid username/email or password"}
ErrSessionNotFound = &AuthError{"session not found"}
diff --git a/internal/auth/flash.go b/internal/auth/flash.go
new file mode 100644
index 0000000..b5e2ce9
--- /dev/null
+++ b/internal/auth/flash.go
@@ -0,0 +1,98 @@
+package auth
+
+import (
+ "dk/internal/router"
+)
+
+// FlashMessage represents a flash message with type and content
+type FlashMessage struct {
+ Type string `json:"type"` // "error", "success", "warning", "info"
+ Message string `json:"message"`
+}
+
+// SetFlashMessage sets a flash message for the current session
+func SetFlashMessage(ctx router.Ctx, msgType, message string) bool {
+ sessionID := GetSessionCookie(ctx)
+ if sessionID == "" {
+ return false
+ }
+
+ return Manager.SetFlash(sessionID, "message", FlashMessage{
+ Type: msgType,
+ Message: message,
+ })
+}
+
+// GetFlashMessage retrieves and removes the flash message from the current session
+func GetFlashMessage(ctx router.Ctx) *FlashMessage {
+ sessionID := GetSessionCookie(ctx)
+ if sessionID == "" {
+ return nil
+ }
+
+ value, exists := Manager.GetFlash(sessionID, "message")
+ if !exists {
+ return nil
+ }
+
+ if msg, ok := value.(FlashMessage); ok {
+ return &msg
+ }
+
+ // Handle map[string]interface{} from JSON deserialization
+ if msgMap, ok := value.(map[string]interface{}); ok {
+ msg := &FlashMessage{}
+ if t, ok := msgMap["type"].(string); ok {
+ msg.Type = t
+ }
+ if m, ok := msgMap["message"].(string); ok {
+ msg.Message = m
+ }
+ return msg
+ }
+
+ return nil
+}
+
+// SetFormData stores form data temporarily in the session (for repopulating forms after errors)
+func SetFormData(ctx router.Ctx, data map[string]string) bool {
+ sessionID := GetSessionCookie(ctx)
+ if sessionID == "" {
+ return false
+ }
+
+ return Manager.SetSessionData(sessionID, "form_data", data)
+}
+
+// GetFormData retrieves and removes form data from the session
+func GetFormData(ctx router.Ctx) map[string]string {
+ sessionID := GetSessionCookie(ctx)
+ if sessionID == "" {
+ return nil
+ }
+
+ value, exists := Manager.GetSessionData(sessionID, "form_data")
+ if !exists {
+ return nil
+ }
+
+ // Clear form data after retrieval
+ Manager.SetSessionData(sessionID, "form_data", nil)
+
+ if formData, ok := value.(map[string]string); ok {
+ return formData
+ }
+
+ // Handle map[string]interface{} from JSON deserialization
+ if formMap, ok := value.(map[string]interface{}); ok {
+ result := make(map[string]string)
+ for k, v := range formMap {
+ if str, ok := v.(string); ok {
+ result[k] = str
+ }
+ }
+ return result
+ }
+
+ return nil
+}
\ No newline at end of file
diff --git a/internal/routes/auth.go b/internal/routes/auth.go
index 774830b..aca1fb1 100644
--- a/internal/routes/auth.go
+++ b/internal/routes/auth.go
@@ -9,7 +9,6 @@ import (
"dk/internal/middleware"
"dk/internal/password"
"dk/internal/router"
- "dk/internal/template"
"dk/internal/template/components"
"dk/internal/users"
@@ -36,10 +35,24 @@ func RegisterAuthRoutes(r *router.Router) {
// showLogin displays the login form
func showLogin(ctx router.Ctx, _ []string) {
+ // Get flash message if any
+ var errorHTML string
+ if flash := auth.GetFlashMessage(ctx); flash != nil {
+ errorHTML = fmt.Sprintf(`
%s
`, flash.Message)
+ }
+
+ // Get form data if any (for preserving email/username on error)
+ formData := auth.GetFormData(ctx)
+ id := ""
+ if formData != nil {
+ id = formData["id"]
+ }
+
components.RenderPageTemplate(ctx, "Log In", "auth/login.html", map[string]any{
"csrf_token": csrf.GetToken(ctx, auth.Manager),
"csrf_field": csrf.HiddenField(ctx, auth.Manager),
- "error_message": "",
+ "error_message": errorHTML,
+ "id": id,
})
}
@@ -51,17 +64,21 @@ func processLogin(ctx router.Ctx, _ []string) {
return
}
- email := strings.TrimSpace(string(ctx.PostArgs().Peek("email")))
+ email := strings.TrimSpace(string(ctx.PostArgs().Peek("id")))
userPassword := string(ctx.PostArgs().Peek("password"))
if email == "" || userPassword == "" {
- showLoginError(ctx, "Email and password are required")
+ auth.SetFlashMessage(ctx, "error", "Email and password are required")
+ auth.SetFormData(ctx, map[string]string{"id": email})
+ ctx.Redirect("/login", fasthttp.StatusFound)
return
}
user, err := auth.Manager.Authenticate(email, userPassword)
if err != nil {
- showLoginError(ctx, "Invalid email or password")
+ auth.SetFlashMessage(ctx, "error", "Invalid email or password")
+ auth.SetFormData(ctx, map[string]string{"id": email})
+ ctx.Redirect("/login", fasthttp.StatusFound)
return
}
@@ -79,12 +96,27 @@ func processLogin(ctx router.Ctx, _ []string) {
// showRegister displays the registration form
func showRegister(ctx router.Ctx, _ []string) {
+ // Get flash message if any
+ var errorHTML string
+ if flash := auth.GetFlashMessage(ctx); flash != nil {
+ errorHTML = fmt.Sprintf(`%s
`, flash.Message)
+ }
+
+ // Get form data if any (for preserving values on error)
+ formData := auth.GetFormData(ctx)
+ username := ""
+ email := ""
+ if formData != nil {
+ username = formData["username"]
+ email = formData["email"]
+ }
+
components.RenderPageTemplate(ctx, "Register", "auth/register.html", map[string]any{
"csrf_token": csrf.GetToken(ctx, auth.Manager),
"csrf_field": csrf.HiddenField(ctx, auth.Manager),
- "error_message": "",
- "username": "",
- "email": "",
+ "error_message": errorHTML,
+ "username": username,
+ "email": email,
})
}
@@ -102,17 +134,32 @@ func processRegister(ctx router.Ctx, _ []string) {
confirmPassword := string(ctx.PostArgs().Peek("confirm_password"))
if err := validateRegistration(username, email, userPassword, confirmPassword); err != nil {
- showRegisterError(ctx, err.Error(), username, email)
+ auth.SetFlashMessage(ctx, "error", err.Error())
+ auth.SetFormData(ctx, map[string]string{
+ "username": username,
+ "email": email,
+ })
+ ctx.Redirect("/register", fasthttp.StatusFound)
return
}
if _, err := users.GetByUsername(username); err == nil {
- showRegisterError(ctx, "Username already exists", username, email)
+ auth.SetFlashMessage(ctx, "error", "Username already exists")
+ auth.SetFormData(ctx, map[string]string{
+ "username": username,
+ "email": email,
+ })
+ ctx.Redirect("/register", fasthttp.StatusFound)
return
}
if _, err := users.GetByEmail(email); err == nil {
- showRegisterError(ctx, "Email already registered", username, email)
+ auth.SetFlashMessage(ctx, "error", "Email already registered")
+ auth.SetFormData(ctx, map[string]string{
+ "username": username,
+ "email": email,
+ })
+ ctx.Redirect("/register", fasthttp.StatusFound)
return
}
@@ -124,7 +171,12 @@ func processRegister(ctx router.Ctx, _ []string) {
user.Auth = 1
if err := user.Insert(); err != nil {
- showRegisterError(ctx, "Failed to create account", username, email)
+ auth.SetFlashMessage(ctx, "error", "Failed to create account")
+ auth.SetFormData(ctx, map[string]string{
+ "username": username,
+ "email": email,
+ })
+ ctx.Redirect("/register", fasthttp.StatusFound)
return
}
@@ -156,68 +208,6 @@ func processLogout(ctx router.Ctx, params []string) {
// Helper functions
-func showLoginError(ctx router.Ctx, errorMsg string) {
- loginTmpl, err := template.Cache.Load("auth/login.html")
- if err != nil {
- ctx.SetStatusCode(fasthttp.StatusInternalServerError)
- fmt.Fprintf(ctx, "Template error: %v", err)
- return
- }
-
- var errorHTML string
- if errorMsg != "" {
- errorHTML = fmt.Sprintf(`%s
`, errorMsg)
- }
-
- loginFormData := map[string]any{
- "csrf_token": csrf.GetToken(ctx, auth.Manager),
- "csrf_field": csrf.HiddenField(ctx, auth.Manager),
- "error_message": errorHTML,
- }
-
- loginContent := loginTmpl.RenderNamed(loginFormData)
-
- ctx.SetStatusCode(fasthttp.StatusBadRequest)
- pageData := components.NewPageData("Login - Dragon Knight", loginContent)
- if err := components.RenderPage(ctx, pageData, nil); err != nil {
- ctx.SetStatusCode(fasthttp.StatusInternalServerError)
- fmt.Fprintf(ctx, "Template error: %v", err)
- return
- }
-}
-
-func showRegisterError(ctx router.Ctx, errorMsg, username, email string) {
- registerTmpl, err := template.Cache.Load("auth/register.html")
- if err != nil {
- ctx.SetStatusCode(fasthttp.StatusInternalServerError)
- fmt.Fprintf(ctx, "Template error: %v", err)
- return
- }
-
- var errorHTML string
- if errorMsg != "" {
- errorHTML = fmt.Sprintf(`%s
`, errorMsg)
- }
-
- registerFormData := map[string]any{
- "csrf_token": csrf.GetToken(ctx, auth.Manager),
- "csrf_field": csrf.HiddenField(ctx, auth.Manager),
- "error_message": errorHTML,
- "username": username,
- "email": email,
- }
-
- registerContent := registerTmpl.RenderNamed(registerFormData)
-
- ctx.SetStatusCode(fasthttp.StatusBadRequest)
- pageData := components.NewPageData("Register - Dragon Knight", registerContent)
- if err := components.RenderPage(ctx, pageData, nil); err != nil {
- ctx.SetStatusCode(fasthttp.StatusInternalServerError)
- fmt.Fprintf(ctx, "Template error: %v", err)
- return
- }
-}
-
func validateRegistration(username, email, password, confirmPassword string) error {
if username == "" {
return fmt.Errorf("username is required")
diff --git a/templates/auth/login.html b/templates/auth/login.html
index 6910e7e..4915284 100644
--- a/templates/auth/login.html
+++ b/templates/auth/login.html
@@ -7,7 +7,7 @@
-
+
diff --git a/templates/town/town.html b/templates/town/town.html
index 4081349..ce93aee 100644
--- a/templates/town/town.html
+++ b/templates/town/town.html
@@ -1,11 +1,11 @@