From b77846936578ee403be23fc09eed763b28488660 Mon Sep 17 00:00:00 2001 From: Sky Johnson Date: Thu, 14 Aug 2025 17:15:55 -0500 Subject: [PATCH] fix session management and storage --- internal/auth/auth.go | 40 ++++++++++++----- internal/csrf/csrf.go | 86 +++++++++++++------------------------ internal/routes/auth.go | 18 ++------ internal/session/manager.go | 7 ++- 4 files changed, 68 insertions(+), 83 deletions(-) diff --git a/internal/auth/auth.go b/internal/auth/auth.go index 50941c3..84d6630 100644 --- a/internal/auth/auth.go +++ b/internal/auth/auth.go @@ -24,15 +24,18 @@ func Middleware() router.Middleware { if existingSess, exists := session.Get(sessionID); exists { sess = existingSess sess.Touch() - session.Store(sess) if sess.UserID > 0 { // User session user, err := users.Find(sess.UserID) if err == nil && user != nil { ctx.SetUserValue("user", user) - setSessionCookie(ctx, sessionID) + } else { + // User not found, reset to guest session + sess.SetUserID(0) } } + session.Store(sess) + setSessionCookie(ctx, sessionID) } } @@ -89,8 +92,10 @@ func RequireGuest(paths ...string) router.Middleware { } func IsAuthenticated(ctx router.Ctx) bool { - _, exists := ctx.UserValue("user").(*users.User) - return exists + if user, ok := ctx.UserValue("user").(*users.User); ok && user != nil { + return true + } + return false } func GetCurrentUser(ctx router.Ctx) *users.User { @@ -109,28 +114,43 @@ func GetCurrentSession(ctx router.Ctx) *session.Session { func Login(ctx router.Ctx, user *users.User) { sess := ctx.UserValue("session").(*session.Session) - sess.RegenerateID() - sess.SetUserID(user.ID) // Update the session's UserID field + + // Update the session to be authenticated + sess.SetUserID(user.ID) // This updates the struct field + sess.RegenerateID() // Generate new ID for security sess.SetFlash("success", fmt.Sprintf("Welcome back, %s!", user.Username)) + + // Remove any old user_id from session data if it exists + sess.Delete("user_id") + session.Store(sess) + // Update context values ctx.SetUserValue("session", sess) ctx.SetUserValue("user", user) + + // Update cookie with new session ID + setSessionCookie(ctx, sess.ID) } func Logout(ctx router.Ctx) { sess := ctx.UserValue("session").(*session.Session) if sess != nil { - sess.SetUserID(0) // Reset to guest session + // Convert back to guest session + sess.SetUserID(0) // Reset to guest sess.RegenerateID() // Generate new ID for security + + // Clean up any user-related session data + sess.Delete("user_id") + session.Store(sess) ctx.SetUserValue("session", sess) + + // Update cookie with new session ID + setSessionCookie(ctx, sess.ID) } ctx.SetUserValue("user", nil) - - // Update the cookie with the new session ID - setSessionCookie(ctx, sess.ID) } // Helper functions for session cookies diff --git a/internal/csrf/csrf.go b/internal/csrf/csrf.go index 9545137..28955d5 100644 --- a/internal/csrf/csrf.go +++ b/internal/csrf/csrf.go @@ -4,20 +4,20 @@ // # Basic Usage // // // Generate token and store in session -// token := csrf.GenerateToken(ctx, sessionManager) +// token := csrf.GenerateToken(ctx) // // // In templates - generate hidden input field -// hiddenField := csrf.HiddenField(ctx, sessionManager) +// hiddenField := csrf.HiddenField(ctx) // // // Verify form submission -// if !csrf.ValidateToken(ctx, sessionManager, formToken) { +// if !csrf.ValidateToken(ctx, formToken) { // // Handle CSRF validation failure // } // // # Middleware Integration // // // Add CSRF middleware to protected routes -// r.Use(middleware.CSRF(authManager)) +// r.Use(csrf.Middleware()) package csrf import ( @@ -25,7 +25,6 @@ import ( "crypto/subtle" "encoding/base64" "fmt" - "time" "dk/internal/router" "dk/internal/session" @@ -37,11 +36,10 @@ const ( TokenLength = 32 TokenFieldName = "_csrf_token" SessionKey = "csrf_token" - SessionCtxKey = "session" // Same as middleware.SessionKey - CookieName = "_csrf" + SessionCtxKey = "session" ) -// GetCurrentSession retrieves the session from context (mirrors middleware function) +// GetCurrentSession retrieves the session from context func GetCurrentSession(ctx router.Ctx) *session.Session { if sess, ok := ctx.UserValue(SessionCtxKey).(*session.Session); ok { return sess @@ -49,7 +47,7 @@ func GetCurrentSession(ctx router.Ctx) *session.Session { return nil } -// GenerateToken creates a new CSRF token and stores it in the session or cookie +// GenerateToken creates a new CSRF token and stores it in the session func GenerateToken(ctx router.Ctx) string { // Generate cryptographically secure random bytes tokenBytes := make([]byte, TokenLength) @@ -60,54 +58,43 @@ func GenerateToken(ctx router.Ctx) string { token := base64.URLEncoding.EncodeToString(tokenBytes) - // Store token in session if user is authenticated, otherwise use cookie - if session := GetCurrentSession(ctx); session != nil { - StoreToken(session, token) - } else { - // Store in cookie for guest users - StoreTokenInCookie(ctx, token) + // Store token in session (both guests and authenticated users have sessions) + if sess := GetCurrentSession(ctx); sess != nil { + StoreToken(sess, token) + session.Store(sess) } return token } -// GetToken retrieves the current CSRF token from session or cookie, generating one if needed +// GetToken retrieves the current CSRF token from session, generating one if needed func GetToken(ctx router.Ctx) string { - session := GetCurrentSession(ctx) + sess := GetCurrentSession(ctx) + if sess == nil { + return "" // No session available + } - if session != nil { - // Authenticated user - check session first - if existingToken := GetStoredToken(session); existingToken != "" { - return existingToken - } - } else { - // Guest user - check cookie first - if existingToken := GetTokenFromCookie(ctx); existingToken != "" { - return existingToken - } + // Check for existing token + if existingToken := GetStoredToken(sess); existingToken != "" { + return existingToken } // Generate new token if none exists return GenerateToken(ctx) } -// ValidateToken verifies a CSRF token against the stored session or cookie token +// ValidateToken verifies a CSRF token against the stored session token func ValidateToken(ctx router.Ctx, submittedToken string) bool { if submittedToken == "" { return false } - var storedToken string - session := GetCurrentSession(ctx) - - if session != nil { - // Authenticated user - get token from session - storedToken = GetStoredToken(session) - } else { - // Guest user - get token from cookie - storedToken = GetTokenFromCookie(ctx) + sess := GetCurrentSession(ctx) + if sess == nil { + return false // No session } + storedToken := GetStoredToken(sess) if storedToken == "" { return false // No stored token } @@ -133,14 +120,13 @@ func GetStoredToken(sess *session.Session) string { // RotateToken generates a new token and replaces the old one in the session func RotateToken(ctx router.Ctx) string { - session := GetCurrentSession(ctx) - if session == nil { + sess := GetCurrentSession(ctx) + if sess == nil { return "" } - // Generate new token + // Generate new token (this will automatically store it) newToken := GenerateToken(ctx) - return newToken } @@ -181,23 +167,9 @@ func ValidateFormToken(ctx router.Ctx) bool { return ValidateToken(ctx, string(tokenBytes)) } -// StoreTokenInCookie stores a CSRF token in a cookie for guest users -func StoreTokenInCookie(ctx router.Ctx, token string) { - cookie := &fasthttp.Cookie{} - cookie.SetKey(CookieName) - cookie.SetValue(token) - cookie.SetHTTPOnly(true) - cookie.SetSameSite(fasthttp.CookieSameSiteStrictMode) - cookie.SetSecure(false) // Set to true in production with HTTPS - cookie.SetExpire(time.Now().Add(24 * time.Hour)) // Expire in 24 hours - cookie.SetPath("/") - - ctx.Response.Header.SetCookie(cookie) -} - -// GetTokenFromCookie retrieves a CSRF token from cookie for guest users +// GetTokenFromCookie retrieves a CSRF token from cookie (legacy support) func GetTokenFromCookie(ctx router.Ctx) string { - return string(ctx.Request.Header.Cookie(CookieName)) + return string(ctx.Request.Header.Cookie("_csrf")) } // Middleware returns a middleware function that automatically validates CSRF tokens diff --git a/internal/routes/auth.go b/internal/routes/auth.go index 189de97..42bde28 100644 --- a/internal/routes/auth.go +++ b/internal/routes/auth.go @@ -5,7 +5,6 @@ import ( "strings" "dk/internal/auth" - "dk/internal/csrf" "dk/internal/models/users" "dk/internal/password" "dk/internal/router" @@ -77,12 +76,7 @@ func processLogin(ctx router.Ctx, _ []string) { auth.Login(ctx, user) - // Transfer CSRF token from cookie to session for authenticated user - if cookieToken := csrf.GetTokenFromCookie(ctx); cookieToken != "" { - if sess := ctx.UserValue("session").(*session.Session); sess != nil { - csrf.StoreToken(sess, cookieToken) - } - } + // CSRF token is already in session, no need to transfer from cookie ctx.Redirect("/", fasthttp.StatusFound) } @@ -158,20 +152,16 @@ func processRegister(ctx router.Ctx, _ []string) { return } + // Auto-login after registration (this will update the current session) auth.Login(ctx, user) - // Set success message + // Update success message (Login already sets a message, so override it) if sess := ctx.UserValue("session").(*session.Session); sess != nil { sess.SetFlash("success", fmt.Sprintf("Greetings, %s!", user.Username)) session.Store(sess) } - // Transfer CSRF token from cookie to session for authenticated user - if cookieToken := csrf.GetTokenFromCookie(ctx); cookieToken != "" { - if sess := ctx.UserValue("session").(*session.Session); sess != nil { - csrf.StoreToken(sess, cookieToken) - } - } + // CSRF token is already in session, no need to transfer from cookie ctx.Redirect("/", fasthttp.StatusFound) } diff --git a/internal/session/manager.go b/internal/session/manager.go index 872d9bc..0735a02 100644 --- a/internal/session/manager.go +++ b/internal/session/manager.go @@ -131,10 +131,13 @@ func (sm *SessionManager) load() { if data != nil && data.ExpiresAt > now { sess := &Session{ ID: id, - UserID: data.UserID, + UserID: data.UserID, // Make sure we restore the UserID properly ExpiresAt: data.ExpiresAt, Data: data.Data, } + if sess.Data == nil { + sess.Data = make(map[string]any) + } sm.sessions[id] = sess } } @@ -155,7 +158,7 @@ func (sm *SessionManager) Save() error { sessionsData := make(map[string]*sessionData, len(sm.sessions)) for id, sess := range sm.sessions { sessionsData[id] = &sessionData{ - UserID: sess.UserID, + UserID: sess.UserID, // Save the actual UserID from the struct ExpiresAt: sess.ExpiresAt, Data: sess.Data, }