fix session management and storage

This commit is contained in:
Sky Johnson 2025-08-14 17:15:55 -05:00
parent 2bbff01c0d
commit b778469365
4 changed files with 68 additions and 83 deletions

View File

@ -24,15 +24,18 @@ func Middleware() router.Middleware {
if existingSess, exists := session.Get(sessionID); exists {
sess = existingSess
sess.Touch()
session.Store(sess)
if sess.UserID > 0 { // User session
user, err := users.Find(sess.UserID)
if err == nil && user != nil {
ctx.SetUserValue("user", user)
setSessionCookie(ctx, sessionID)
} else {
// User not found, reset to guest session
sess.SetUserID(0)
}
}
session.Store(sess)
setSessionCookie(ctx, sessionID)
}
}
@ -89,8 +92,10 @@ func RequireGuest(paths ...string) router.Middleware {
}
func IsAuthenticated(ctx router.Ctx) bool {
_, exists := ctx.UserValue("user").(*users.User)
return exists
if user, ok := ctx.UserValue("user").(*users.User); ok && user != nil {
return true
}
return false
}
func GetCurrentUser(ctx router.Ctx) *users.User {
@ -109,28 +114,43 @@ func GetCurrentSession(ctx router.Ctx) *session.Session {
func Login(ctx router.Ctx, user *users.User) {
sess := ctx.UserValue("session").(*session.Session)
sess.RegenerateID()
sess.SetUserID(user.ID) // Update the session's UserID field
// Update the session to be authenticated
sess.SetUserID(user.ID) // This updates the struct field
sess.RegenerateID() // Generate new ID for security
sess.SetFlash("success", fmt.Sprintf("Welcome back, %s!", user.Username))
// Remove any old user_id from session data if it exists
sess.Delete("user_id")
session.Store(sess)
// Update context values
ctx.SetUserValue("session", sess)
ctx.SetUserValue("user", user)
// Update cookie with new session ID
setSessionCookie(ctx, sess.ID)
}
func Logout(ctx router.Ctx) {
sess := ctx.UserValue("session").(*session.Session)
if sess != nil {
sess.SetUserID(0) // Reset to guest session
// Convert back to guest session
sess.SetUserID(0) // Reset to guest
sess.RegenerateID() // Generate new ID for security
// Clean up any user-related session data
sess.Delete("user_id")
session.Store(sess)
ctx.SetUserValue("session", sess)
// Update cookie with new session ID
setSessionCookie(ctx, sess.ID)
}
ctx.SetUserValue("user", nil)
// Update the cookie with the new session ID
setSessionCookie(ctx, sess.ID)
}
// Helper functions for session cookies

View File

@ -4,20 +4,20 @@
// # Basic Usage
//
// // Generate token and store in session
// token := csrf.GenerateToken(ctx, sessionManager)
// token := csrf.GenerateToken(ctx)
//
// // In templates - generate hidden input field
// hiddenField := csrf.HiddenField(ctx, sessionManager)
// hiddenField := csrf.HiddenField(ctx)
//
// // Verify form submission
// if !csrf.ValidateToken(ctx, sessionManager, formToken) {
// if !csrf.ValidateToken(ctx, formToken) {
// // Handle CSRF validation failure
// }
//
// # Middleware Integration
//
// // Add CSRF middleware to protected routes
// r.Use(middleware.CSRF(authManager))
// r.Use(csrf.Middleware())
package csrf
import (
@ -25,7 +25,6 @@ import (
"crypto/subtle"
"encoding/base64"
"fmt"
"time"
"dk/internal/router"
"dk/internal/session"
@ -37,11 +36,10 @@ const (
TokenLength = 32
TokenFieldName = "_csrf_token"
SessionKey = "csrf_token"
SessionCtxKey = "session" // Same as middleware.SessionKey
CookieName = "_csrf"
SessionCtxKey = "session"
)
// GetCurrentSession retrieves the session from context (mirrors middleware function)
// GetCurrentSession retrieves the session from context
func GetCurrentSession(ctx router.Ctx) *session.Session {
if sess, ok := ctx.UserValue(SessionCtxKey).(*session.Session); ok {
return sess
@ -49,7 +47,7 @@ func GetCurrentSession(ctx router.Ctx) *session.Session {
return nil
}
// GenerateToken creates a new CSRF token and stores it in the session or cookie
// GenerateToken creates a new CSRF token and stores it in the session
func GenerateToken(ctx router.Ctx) string {
// Generate cryptographically secure random bytes
tokenBytes := make([]byte, TokenLength)
@ -60,54 +58,43 @@ func GenerateToken(ctx router.Ctx) string {
token := base64.URLEncoding.EncodeToString(tokenBytes)
// Store token in session if user is authenticated, otherwise use cookie
if session := GetCurrentSession(ctx); session != nil {
StoreToken(session, token)
} else {
// Store in cookie for guest users
StoreTokenInCookie(ctx, token)
// Store token in session (both guests and authenticated users have sessions)
if sess := GetCurrentSession(ctx); sess != nil {
StoreToken(sess, token)
session.Store(sess)
}
return token
}
// GetToken retrieves the current CSRF token from session or cookie, generating one if needed
// GetToken retrieves the current CSRF token from session, generating one if needed
func GetToken(ctx router.Ctx) string {
session := GetCurrentSession(ctx)
sess := GetCurrentSession(ctx)
if sess == nil {
return "" // No session available
}
if session != nil {
// Authenticated user - check session first
if existingToken := GetStoredToken(session); existingToken != "" {
return existingToken
}
} else {
// Guest user - check cookie first
if existingToken := GetTokenFromCookie(ctx); existingToken != "" {
return existingToken
}
// Check for existing token
if existingToken := GetStoredToken(sess); existingToken != "" {
return existingToken
}
// Generate new token if none exists
return GenerateToken(ctx)
}
// ValidateToken verifies a CSRF token against the stored session or cookie token
// ValidateToken verifies a CSRF token against the stored session token
func ValidateToken(ctx router.Ctx, submittedToken string) bool {
if submittedToken == "" {
return false
}
var storedToken string
session := GetCurrentSession(ctx)
if session != nil {
// Authenticated user - get token from session
storedToken = GetStoredToken(session)
} else {
// Guest user - get token from cookie
storedToken = GetTokenFromCookie(ctx)
sess := GetCurrentSession(ctx)
if sess == nil {
return false // No session
}
storedToken := GetStoredToken(sess)
if storedToken == "" {
return false // No stored token
}
@ -133,14 +120,13 @@ func GetStoredToken(sess *session.Session) string {
// RotateToken generates a new token and replaces the old one in the session
func RotateToken(ctx router.Ctx) string {
session := GetCurrentSession(ctx)
if session == nil {
sess := GetCurrentSession(ctx)
if sess == nil {
return ""
}
// Generate new token
// Generate new token (this will automatically store it)
newToken := GenerateToken(ctx)
return newToken
}
@ -181,23 +167,9 @@ func ValidateFormToken(ctx router.Ctx) bool {
return ValidateToken(ctx, string(tokenBytes))
}
// StoreTokenInCookie stores a CSRF token in a cookie for guest users
func StoreTokenInCookie(ctx router.Ctx, token string) {
cookie := &fasthttp.Cookie{}
cookie.SetKey(CookieName)
cookie.SetValue(token)
cookie.SetHTTPOnly(true)
cookie.SetSameSite(fasthttp.CookieSameSiteStrictMode)
cookie.SetSecure(false) // Set to true in production with HTTPS
cookie.SetExpire(time.Now().Add(24 * time.Hour)) // Expire in 24 hours
cookie.SetPath("/")
ctx.Response.Header.SetCookie(cookie)
}
// GetTokenFromCookie retrieves a CSRF token from cookie for guest users
// GetTokenFromCookie retrieves a CSRF token from cookie (legacy support)
func GetTokenFromCookie(ctx router.Ctx) string {
return string(ctx.Request.Header.Cookie(CookieName))
return string(ctx.Request.Header.Cookie("_csrf"))
}
// Middleware returns a middleware function that automatically validates CSRF tokens

View File

@ -5,7 +5,6 @@ import (
"strings"
"dk/internal/auth"
"dk/internal/csrf"
"dk/internal/models/users"
"dk/internal/password"
"dk/internal/router"
@ -77,12 +76,7 @@ func processLogin(ctx router.Ctx, _ []string) {
auth.Login(ctx, user)
// Transfer CSRF token from cookie to session for authenticated user
if cookieToken := csrf.GetTokenFromCookie(ctx); cookieToken != "" {
if sess := ctx.UserValue("session").(*session.Session); sess != nil {
csrf.StoreToken(sess, cookieToken)
}
}
// CSRF token is already in session, no need to transfer from cookie
ctx.Redirect("/", fasthttp.StatusFound)
}
@ -158,20 +152,16 @@ func processRegister(ctx router.Ctx, _ []string) {
return
}
// Auto-login after registration (this will update the current session)
auth.Login(ctx, user)
// Set success message
// Update success message (Login already sets a message, so override it)
if sess := ctx.UserValue("session").(*session.Session); sess != nil {
sess.SetFlash("success", fmt.Sprintf("Greetings, %s!", user.Username))
session.Store(sess)
}
// Transfer CSRF token from cookie to session for authenticated user
if cookieToken := csrf.GetTokenFromCookie(ctx); cookieToken != "" {
if sess := ctx.UserValue("session").(*session.Session); sess != nil {
csrf.StoreToken(sess, cookieToken)
}
}
// CSRF token is already in session, no need to transfer from cookie
ctx.Redirect("/", fasthttp.StatusFound)
}

View File

@ -131,10 +131,13 @@ func (sm *SessionManager) load() {
if data != nil && data.ExpiresAt > now {
sess := &Session{
ID: id,
UserID: data.UserID,
UserID: data.UserID, // Make sure we restore the UserID properly
ExpiresAt: data.ExpiresAt,
Data: data.Data,
}
if sess.Data == nil {
sess.Data = make(map[string]any)
}
sm.sessions[id] = sess
}
}
@ -155,7 +158,7 @@ func (sm *SessionManager) Save() error {
sessionsData := make(map[string]*sessionData, len(sm.sessions))
for id, sess := range sm.sessions {
sessionsData[id] = &sessionData{
UserID: sess.UserID,
UserID: sess.UserID, // Save the actual UserID from the struct
ExpiresAt: sess.ExpiresAt,
Data: sess.Data,
}