fix session management and storage

This commit is contained in:
Sky Johnson 2025-08-14 17:15:55 -05:00
parent 2bbff01c0d
commit b778469365
4 changed files with 68 additions and 83 deletions

View File

@ -24,17 +24,20 @@ func Middleware() router.Middleware {
if existingSess, exists := session.Get(sessionID); exists { if existingSess, exists := session.Get(sessionID); exists {
sess = existingSess sess = existingSess
sess.Touch() sess.Touch()
session.Store(sess)
if sess.UserID > 0 { // User session if sess.UserID > 0 { // User session
user, err := users.Find(sess.UserID) user, err := users.Find(sess.UserID)
if err == nil && user != nil { if err == nil && user != nil {
ctx.SetUserValue("user", user) ctx.SetUserValue("user", user)
} else {
// User not found, reset to guest session
sess.SetUserID(0)
}
}
session.Store(sess)
setSessionCookie(ctx, sessionID) setSessionCookie(ctx, sessionID)
} }
} }
}
}
// Create guest session if none exists // Create guest session if none exists
if sess == nil { if sess == nil {
@ -89,8 +92,10 @@ func RequireGuest(paths ...string) router.Middleware {
} }
func IsAuthenticated(ctx router.Ctx) bool { func IsAuthenticated(ctx router.Ctx) bool {
_, exists := ctx.UserValue("user").(*users.User) if user, ok := ctx.UserValue("user").(*users.User); ok && user != nil {
return exists return true
}
return false
} }
func GetCurrentUser(ctx router.Ctx) *users.User { func GetCurrentUser(ctx router.Ctx) *users.User {
@ -109,28 +114,43 @@ func GetCurrentSession(ctx router.Ctx) *session.Session {
func Login(ctx router.Ctx, user *users.User) { func Login(ctx router.Ctx, user *users.User) {
sess := ctx.UserValue("session").(*session.Session) sess := ctx.UserValue("session").(*session.Session)
sess.RegenerateID()
sess.SetUserID(user.ID) // Update the session's UserID field // Update the session to be authenticated
sess.SetUserID(user.ID) // This updates the struct field
sess.RegenerateID() // Generate new ID for security
sess.SetFlash("success", fmt.Sprintf("Welcome back, %s!", user.Username)) sess.SetFlash("success", fmt.Sprintf("Welcome back, %s!", user.Username))
// Remove any old user_id from session data if it exists
sess.Delete("user_id")
session.Store(sess) session.Store(sess)
// Update context values
ctx.SetUserValue("session", sess) ctx.SetUserValue("session", sess)
ctx.SetUserValue("user", user) ctx.SetUserValue("user", user)
// Update cookie with new session ID
setSessionCookie(ctx, sess.ID)
} }
func Logout(ctx router.Ctx) { func Logout(ctx router.Ctx) {
sess := ctx.UserValue("session").(*session.Session) sess := ctx.UserValue("session").(*session.Session)
if sess != nil { if sess != nil {
sess.SetUserID(0) // Reset to guest session // Convert back to guest session
sess.SetUserID(0) // Reset to guest
sess.RegenerateID() // Generate new ID for security sess.RegenerateID() // Generate new ID for security
// Clean up any user-related session data
sess.Delete("user_id")
session.Store(sess) session.Store(sess)
ctx.SetUserValue("session", sess) ctx.SetUserValue("session", sess)
// Update cookie with new session ID
setSessionCookie(ctx, sess.ID)
} }
ctx.SetUserValue("user", nil) ctx.SetUserValue("user", nil)
// Update the cookie with the new session ID
setSessionCookie(ctx, sess.ID)
} }
// Helper functions for session cookies // Helper functions for session cookies

View File

@ -4,20 +4,20 @@
// # Basic Usage // # Basic Usage
// //
// // Generate token and store in session // // Generate token and store in session
// token := csrf.GenerateToken(ctx, sessionManager) // token := csrf.GenerateToken(ctx)
// //
// // In templates - generate hidden input field // // In templates - generate hidden input field
// hiddenField := csrf.HiddenField(ctx, sessionManager) // hiddenField := csrf.HiddenField(ctx)
// //
// // Verify form submission // // Verify form submission
// if !csrf.ValidateToken(ctx, sessionManager, formToken) { // if !csrf.ValidateToken(ctx, formToken) {
// // Handle CSRF validation failure // // Handle CSRF validation failure
// } // }
// //
// # Middleware Integration // # Middleware Integration
// //
// // Add CSRF middleware to protected routes // // Add CSRF middleware to protected routes
// r.Use(middleware.CSRF(authManager)) // r.Use(csrf.Middleware())
package csrf package csrf
import ( import (
@ -25,7 +25,6 @@ import (
"crypto/subtle" "crypto/subtle"
"encoding/base64" "encoding/base64"
"fmt" "fmt"
"time"
"dk/internal/router" "dk/internal/router"
"dk/internal/session" "dk/internal/session"
@ -37,11 +36,10 @@ const (
TokenLength = 32 TokenLength = 32
TokenFieldName = "_csrf_token" TokenFieldName = "_csrf_token"
SessionKey = "csrf_token" SessionKey = "csrf_token"
SessionCtxKey = "session" // Same as middleware.SessionKey SessionCtxKey = "session"
CookieName = "_csrf"
) )
// GetCurrentSession retrieves the session from context (mirrors middleware function) // GetCurrentSession retrieves the session from context
func GetCurrentSession(ctx router.Ctx) *session.Session { func GetCurrentSession(ctx router.Ctx) *session.Session {
if sess, ok := ctx.UserValue(SessionCtxKey).(*session.Session); ok { if sess, ok := ctx.UserValue(SessionCtxKey).(*session.Session); ok {
return sess return sess
@ -49,7 +47,7 @@ func GetCurrentSession(ctx router.Ctx) *session.Session {
return nil return nil
} }
// GenerateToken creates a new CSRF token and stores it in the session or cookie // GenerateToken creates a new CSRF token and stores it in the session
func GenerateToken(ctx router.Ctx) string { func GenerateToken(ctx router.Ctx) string {
// Generate cryptographically secure random bytes // Generate cryptographically secure random bytes
tokenBytes := make([]byte, TokenLength) tokenBytes := make([]byte, TokenLength)
@ -60,54 +58,43 @@ func GenerateToken(ctx router.Ctx) string {
token := base64.URLEncoding.EncodeToString(tokenBytes) token := base64.URLEncoding.EncodeToString(tokenBytes)
// Store token in session if user is authenticated, otherwise use cookie // Store token in session (both guests and authenticated users have sessions)
if session := GetCurrentSession(ctx); session != nil { if sess := GetCurrentSession(ctx); sess != nil {
StoreToken(session, token) StoreToken(sess, token)
} else { session.Store(sess)
// Store in cookie for guest users
StoreTokenInCookie(ctx, token)
} }
return token return token
} }
// GetToken retrieves the current CSRF token from session or cookie, generating one if needed // GetToken retrieves the current CSRF token from session, generating one if needed
func GetToken(ctx router.Ctx) string { func GetToken(ctx router.Ctx) string {
session := GetCurrentSession(ctx) sess := GetCurrentSession(ctx)
if sess == nil {
return "" // No session available
}
if session != nil { // Check for existing token
// Authenticated user - check session first if existingToken := GetStoredToken(sess); existingToken != "" {
if existingToken := GetStoredToken(session); existingToken != "" {
return existingToken return existingToken
} }
} else {
// Guest user - check cookie first
if existingToken := GetTokenFromCookie(ctx); existingToken != "" {
return existingToken
}
}
// Generate new token if none exists // Generate new token if none exists
return GenerateToken(ctx) return GenerateToken(ctx)
} }
// ValidateToken verifies a CSRF token against the stored session or cookie token // ValidateToken verifies a CSRF token against the stored session token
func ValidateToken(ctx router.Ctx, submittedToken string) bool { func ValidateToken(ctx router.Ctx, submittedToken string) bool {
if submittedToken == "" { if submittedToken == "" {
return false return false
} }
var storedToken string sess := GetCurrentSession(ctx)
session := GetCurrentSession(ctx) if sess == nil {
return false // No session
if session != nil {
// Authenticated user - get token from session
storedToken = GetStoredToken(session)
} else {
// Guest user - get token from cookie
storedToken = GetTokenFromCookie(ctx)
} }
storedToken := GetStoredToken(sess)
if storedToken == "" { if storedToken == "" {
return false // No stored token return false // No stored token
} }
@ -133,14 +120,13 @@ func GetStoredToken(sess *session.Session) string {
// RotateToken generates a new token and replaces the old one in the session // RotateToken generates a new token and replaces the old one in the session
func RotateToken(ctx router.Ctx) string { func RotateToken(ctx router.Ctx) string {
session := GetCurrentSession(ctx) sess := GetCurrentSession(ctx)
if session == nil { if sess == nil {
return "" return ""
} }
// Generate new token // Generate new token (this will automatically store it)
newToken := GenerateToken(ctx) newToken := GenerateToken(ctx)
return newToken return newToken
} }
@ -181,23 +167,9 @@ func ValidateFormToken(ctx router.Ctx) bool {
return ValidateToken(ctx, string(tokenBytes)) return ValidateToken(ctx, string(tokenBytes))
} }
// StoreTokenInCookie stores a CSRF token in a cookie for guest users // GetTokenFromCookie retrieves a CSRF token from cookie (legacy support)
func StoreTokenInCookie(ctx router.Ctx, token string) {
cookie := &fasthttp.Cookie{}
cookie.SetKey(CookieName)
cookie.SetValue(token)
cookie.SetHTTPOnly(true)
cookie.SetSameSite(fasthttp.CookieSameSiteStrictMode)
cookie.SetSecure(false) // Set to true in production with HTTPS
cookie.SetExpire(time.Now().Add(24 * time.Hour)) // Expire in 24 hours
cookie.SetPath("/")
ctx.Response.Header.SetCookie(cookie)
}
// GetTokenFromCookie retrieves a CSRF token from cookie for guest users
func GetTokenFromCookie(ctx router.Ctx) string { func GetTokenFromCookie(ctx router.Ctx) string {
return string(ctx.Request.Header.Cookie(CookieName)) return string(ctx.Request.Header.Cookie("_csrf"))
} }
// Middleware returns a middleware function that automatically validates CSRF tokens // Middleware returns a middleware function that automatically validates CSRF tokens

View File

@ -5,7 +5,6 @@ import (
"strings" "strings"
"dk/internal/auth" "dk/internal/auth"
"dk/internal/csrf"
"dk/internal/models/users" "dk/internal/models/users"
"dk/internal/password" "dk/internal/password"
"dk/internal/router" "dk/internal/router"
@ -77,12 +76,7 @@ func processLogin(ctx router.Ctx, _ []string) {
auth.Login(ctx, user) auth.Login(ctx, user)
// Transfer CSRF token from cookie to session for authenticated user // CSRF token is already in session, no need to transfer from cookie
if cookieToken := csrf.GetTokenFromCookie(ctx); cookieToken != "" {
if sess := ctx.UserValue("session").(*session.Session); sess != nil {
csrf.StoreToken(sess, cookieToken)
}
}
ctx.Redirect("/", fasthttp.StatusFound) ctx.Redirect("/", fasthttp.StatusFound)
} }
@ -158,20 +152,16 @@ func processRegister(ctx router.Ctx, _ []string) {
return return
} }
// Auto-login after registration (this will update the current session)
auth.Login(ctx, user) auth.Login(ctx, user)
// Set success message // Update success message (Login already sets a message, so override it)
if sess := ctx.UserValue("session").(*session.Session); sess != nil { if sess := ctx.UserValue("session").(*session.Session); sess != nil {
sess.SetFlash("success", fmt.Sprintf("Greetings, %s!", user.Username)) sess.SetFlash("success", fmt.Sprintf("Greetings, %s!", user.Username))
session.Store(sess) session.Store(sess)
} }
// Transfer CSRF token from cookie to session for authenticated user // CSRF token is already in session, no need to transfer from cookie
if cookieToken := csrf.GetTokenFromCookie(ctx); cookieToken != "" {
if sess := ctx.UserValue("session").(*session.Session); sess != nil {
csrf.StoreToken(sess, cookieToken)
}
}
ctx.Redirect("/", fasthttp.StatusFound) ctx.Redirect("/", fasthttp.StatusFound)
} }

View File

@ -131,10 +131,13 @@ func (sm *SessionManager) load() {
if data != nil && data.ExpiresAt > now { if data != nil && data.ExpiresAt > now {
sess := &Session{ sess := &Session{
ID: id, ID: id,
UserID: data.UserID, UserID: data.UserID, // Make sure we restore the UserID properly
ExpiresAt: data.ExpiresAt, ExpiresAt: data.ExpiresAt,
Data: data.Data, Data: data.Data,
} }
if sess.Data == nil {
sess.Data = make(map[string]any)
}
sm.sessions[id] = sess sm.sessions[id] = sess
} }
} }
@ -155,7 +158,7 @@ func (sm *SessionManager) Save() error {
sessionsData := make(map[string]*sessionData, len(sm.sessions)) sessionsData := make(map[string]*sessionData, len(sm.sessions))
for id, sess := range sm.sessions { for id, sess := range sm.sessions {
sessionsData[id] = &sessionData{ sessionsData[id] = &sessionData{
UserID: sess.UserID, UserID: sess.UserID, // Save the actual UserID from the struct
ExpiresAt: sess.ExpiresAt, ExpiresAt: sess.ExpiresAt,
Data: sess.Data, Data: sess.Data,
} }