fix session management and storage
This commit is contained in:
parent
2bbff01c0d
commit
b778469365
@ -24,15 +24,18 @@ func Middleware() router.Middleware {
|
||||
if existingSess, exists := session.Get(sessionID); exists {
|
||||
sess = existingSess
|
||||
sess.Touch()
|
||||
session.Store(sess)
|
||||
|
||||
if sess.UserID > 0 { // User session
|
||||
user, err := users.Find(sess.UserID)
|
||||
if err == nil && user != nil {
|
||||
ctx.SetUserValue("user", user)
|
||||
setSessionCookie(ctx, sessionID)
|
||||
} else {
|
||||
// User not found, reset to guest session
|
||||
sess.SetUserID(0)
|
||||
}
|
||||
}
|
||||
session.Store(sess)
|
||||
setSessionCookie(ctx, sessionID)
|
||||
}
|
||||
}
|
||||
|
||||
@ -89,8 +92,10 @@ func RequireGuest(paths ...string) router.Middleware {
|
||||
}
|
||||
|
||||
func IsAuthenticated(ctx router.Ctx) bool {
|
||||
_, exists := ctx.UserValue("user").(*users.User)
|
||||
return exists
|
||||
if user, ok := ctx.UserValue("user").(*users.User); ok && user != nil {
|
||||
return true
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
func GetCurrentUser(ctx router.Ctx) *users.User {
|
||||
@ -109,28 +114,43 @@ func GetCurrentSession(ctx router.Ctx) *session.Session {
|
||||
|
||||
func Login(ctx router.Ctx, user *users.User) {
|
||||
sess := ctx.UserValue("session").(*session.Session)
|
||||
sess.RegenerateID()
|
||||
sess.SetUserID(user.ID) // Update the session's UserID field
|
||||
|
||||
// Update the session to be authenticated
|
||||
sess.SetUserID(user.ID) // This updates the struct field
|
||||
sess.RegenerateID() // Generate new ID for security
|
||||
sess.SetFlash("success", fmt.Sprintf("Welcome back, %s!", user.Username))
|
||||
|
||||
// Remove any old user_id from session data if it exists
|
||||
sess.Delete("user_id")
|
||||
|
||||
session.Store(sess)
|
||||
|
||||
// Update context values
|
||||
ctx.SetUserValue("session", sess)
|
||||
ctx.SetUserValue("user", user)
|
||||
|
||||
// Update cookie with new session ID
|
||||
setSessionCookie(ctx, sess.ID)
|
||||
}
|
||||
|
||||
func Logout(ctx router.Ctx) {
|
||||
sess := ctx.UserValue("session").(*session.Session)
|
||||
if sess != nil {
|
||||
sess.SetUserID(0) // Reset to guest session
|
||||
// Convert back to guest session
|
||||
sess.SetUserID(0) // Reset to guest
|
||||
sess.RegenerateID() // Generate new ID for security
|
||||
|
||||
// Clean up any user-related session data
|
||||
sess.Delete("user_id")
|
||||
|
||||
session.Store(sess)
|
||||
ctx.SetUserValue("session", sess)
|
||||
|
||||
// Update cookie with new session ID
|
||||
setSessionCookie(ctx, sess.ID)
|
||||
}
|
||||
|
||||
ctx.SetUserValue("user", nil)
|
||||
|
||||
// Update the cookie with the new session ID
|
||||
setSessionCookie(ctx, sess.ID)
|
||||
}
|
||||
|
||||
// Helper functions for session cookies
|
||||
|
@ -4,20 +4,20 @@
|
||||
// # Basic Usage
|
||||
//
|
||||
// // Generate token and store in session
|
||||
// token := csrf.GenerateToken(ctx, sessionManager)
|
||||
// token := csrf.GenerateToken(ctx)
|
||||
//
|
||||
// // In templates - generate hidden input field
|
||||
// hiddenField := csrf.HiddenField(ctx, sessionManager)
|
||||
// hiddenField := csrf.HiddenField(ctx)
|
||||
//
|
||||
// // Verify form submission
|
||||
// if !csrf.ValidateToken(ctx, sessionManager, formToken) {
|
||||
// if !csrf.ValidateToken(ctx, formToken) {
|
||||
// // Handle CSRF validation failure
|
||||
// }
|
||||
//
|
||||
// # Middleware Integration
|
||||
//
|
||||
// // Add CSRF middleware to protected routes
|
||||
// r.Use(middleware.CSRF(authManager))
|
||||
// r.Use(csrf.Middleware())
|
||||
package csrf
|
||||
|
||||
import (
|
||||
@ -25,7 +25,6 @@ import (
|
||||
"crypto/subtle"
|
||||
"encoding/base64"
|
||||
"fmt"
|
||||
"time"
|
||||
|
||||
"dk/internal/router"
|
||||
"dk/internal/session"
|
||||
@ -37,11 +36,10 @@ const (
|
||||
TokenLength = 32
|
||||
TokenFieldName = "_csrf_token"
|
||||
SessionKey = "csrf_token"
|
||||
SessionCtxKey = "session" // Same as middleware.SessionKey
|
||||
CookieName = "_csrf"
|
||||
SessionCtxKey = "session"
|
||||
)
|
||||
|
||||
// GetCurrentSession retrieves the session from context (mirrors middleware function)
|
||||
// GetCurrentSession retrieves the session from context
|
||||
func GetCurrentSession(ctx router.Ctx) *session.Session {
|
||||
if sess, ok := ctx.UserValue(SessionCtxKey).(*session.Session); ok {
|
||||
return sess
|
||||
@ -49,7 +47,7 @@ func GetCurrentSession(ctx router.Ctx) *session.Session {
|
||||
return nil
|
||||
}
|
||||
|
||||
// GenerateToken creates a new CSRF token and stores it in the session or cookie
|
||||
// GenerateToken creates a new CSRF token and stores it in the session
|
||||
func GenerateToken(ctx router.Ctx) string {
|
||||
// Generate cryptographically secure random bytes
|
||||
tokenBytes := make([]byte, TokenLength)
|
||||
@ -60,54 +58,43 @@ func GenerateToken(ctx router.Ctx) string {
|
||||
|
||||
token := base64.URLEncoding.EncodeToString(tokenBytes)
|
||||
|
||||
// Store token in session if user is authenticated, otherwise use cookie
|
||||
if session := GetCurrentSession(ctx); session != nil {
|
||||
StoreToken(session, token)
|
||||
} else {
|
||||
// Store in cookie for guest users
|
||||
StoreTokenInCookie(ctx, token)
|
||||
// Store token in session (both guests and authenticated users have sessions)
|
||||
if sess := GetCurrentSession(ctx); sess != nil {
|
||||
StoreToken(sess, token)
|
||||
session.Store(sess)
|
||||
}
|
||||
|
||||
return token
|
||||
}
|
||||
|
||||
// GetToken retrieves the current CSRF token from session or cookie, generating one if needed
|
||||
// GetToken retrieves the current CSRF token from session, generating one if needed
|
||||
func GetToken(ctx router.Ctx) string {
|
||||
session := GetCurrentSession(ctx)
|
||||
sess := GetCurrentSession(ctx)
|
||||
if sess == nil {
|
||||
return "" // No session available
|
||||
}
|
||||
|
||||
if session != nil {
|
||||
// Authenticated user - check session first
|
||||
if existingToken := GetStoredToken(session); existingToken != "" {
|
||||
return existingToken
|
||||
}
|
||||
} else {
|
||||
// Guest user - check cookie first
|
||||
if existingToken := GetTokenFromCookie(ctx); existingToken != "" {
|
||||
return existingToken
|
||||
}
|
||||
// Check for existing token
|
||||
if existingToken := GetStoredToken(sess); existingToken != "" {
|
||||
return existingToken
|
||||
}
|
||||
|
||||
// Generate new token if none exists
|
||||
return GenerateToken(ctx)
|
||||
}
|
||||
|
||||
// ValidateToken verifies a CSRF token against the stored session or cookie token
|
||||
// ValidateToken verifies a CSRF token against the stored session token
|
||||
func ValidateToken(ctx router.Ctx, submittedToken string) bool {
|
||||
if submittedToken == "" {
|
||||
return false
|
||||
}
|
||||
|
||||
var storedToken string
|
||||
session := GetCurrentSession(ctx)
|
||||
|
||||
if session != nil {
|
||||
// Authenticated user - get token from session
|
||||
storedToken = GetStoredToken(session)
|
||||
} else {
|
||||
// Guest user - get token from cookie
|
||||
storedToken = GetTokenFromCookie(ctx)
|
||||
sess := GetCurrentSession(ctx)
|
||||
if sess == nil {
|
||||
return false // No session
|
||||
}
|
||||
|
||||
storedToken := GetStoredToken(sess)
|
||||
if storedToken == "" {
|
||||
return false // No stored token
|
||||
}
|
||||
@ -133,14 +120,13 @@ func GetStoredToken(sess *session.Session) string {
|
||||
|
||||
// RotateToken generates a new token and replaces the old one in the session
|
||||
func RotateToken(ctx router.Ctx) string {
|
||||
session := GetCurrentSession(ctx)
|
||||
if session == nil {
|
||||
sess := GetCurrentSession(ctx)
|
||||
if sess == nil {
|
||||
return ""
|
||||
}
|
||||
|
||||
// Generate new token
|
||||
// Generate new token (this will automatically store it)
|
||||
newToken := GenerateToken(ctx)
|
||||
|
||||
return newToken
|
||||
}
|
||||
|
||||
@ -181,23 +167,9 @@ func ValidateFormToken(ctx router.Ctx) bool {
|
||||
return ValidateToken(ctx, string(tokenBytes))
|
||||
}
|
||||
|
||||
// StoreTokenInCookie stores a CSRF token in a cookie for guest users
|
||||
func StoreTokenInCookie(ctx router.Ctx, token string) {
|
||||
cookie := &fasthttp.Cookie{}
|
||||
cookie.SetKey(CookieName)
|
||||
cookie.SetValue(token)
|
||||
cookie.SetHTTPOnly(true)
|
||||
cookie.SetSameSite(fasthttp.CookieSameSiteStrictMode)
|
||||
cookie.SetSecure(false) // Set to true in production with HTTPS
|
||||
cookie.SetExpire(time.Now().Add(24 * time.Hour)) // Expire in 24 hours
|
||||
cookie.SetPath("/")
|
||||
|
||||
ctx.Response.Header.SetCookie(cookie)
|
||||
}
|
||||
|
||||
// GetTokenFromCookie retrieves a CSRF token from cookie for guest users
|
||||
// GetTokenFromCookie retrieves a CSRF token from cookie (legacy support)
|
||||
func GetTokenFromCookie(ctx router.Ctx) string {
|
||||
return string(ctx.Request.Header.Cookie(CookieName))
|
||||
return string(ctx.Request.Header.Cookie("_csrf"))
|
||||
}
|
||||
|
||||
// Middleware returns a middleware function that automatically validates CSRF tokens
|
||||
|
@ -5,7 +5,6 @@ import (
|
||||
"strings"
|
||||
|
||||
"dk/internal/auth"
|
||||
"dk/internal/csrf"
|
||||
"dk/internal/models/users"
|
||||
"dk/internal/password"
|
||||
"dk/internal/router"
|
||||
@ -77,12 +76,7 @@ func processLogin(ctx router.Ctx, _ []string) {
|
||||
|
||||
auth.Login(ctx, user)
|
||||
|
||||
// Transfer CSRF token from cookie to session for authenticated user
|
||||
if cookieToken := csrf.GetTokenFromCookie(ctx); cookieToken != "" {
|
||||
if sess := ctx.UserValue("session").(*session.Session); sess != nil {
|
||||
csrf.StoreToken(sess, cookieToken)
|
||||
}
|
||||
}
|
||||
// CSRF token is already in session, no need to transfer from cookie
|
||||
|
||||
ctx.Redirect("/", fasthttp.StatusFound)
|
||||
}
|
||||
@ -158,20 +152,16 @@ func processRegister(ctx router.Ctx, _ []string) {
|
||||
return
|
||||
}
|
||||
|
||||
// Auto-login after registration (this will update the current session)
|
||||
auth.Login(ctx, user)
|
||||
|
||||
// Set success message
|
||||
// Update success message (Login already sets a message, so override it)
|
||||
if sess := ctx.UserValue("session").(*session.Session); sess != nil {
|
||||
sess.SetFlash("success", fmt.Sprintf("Greetings, %s!", user.Username))
|
||||
session.Store(sess)
|
||||
}
|
||||
|
||||
// Transfer CSRF token from cookie to session for authenticated user
|
||||
if cookieToken := csrf.GetTokenFromCookie(ctx); cookieToken != "" {
|
||||
if sess := ctx.UserValue("session").(*session.Session); sess != nil {
|
||||
csrf.StoreToken(sess, cookieToken)
|
||||
}
|
||||
}
|
||||
// CSRF token is already in session, no need to transfer from cookie
|
||||
|
||||
ctx.Redirect("/", fasthttp.StatusFound)
|
||||
}
|
||||
|
@ -131,10 +131,13 @@ func (sm *SessionManager) load() {
|
||||
if data != nil && data.ExpiresAt > now {
|
||||
sess := &Session{
|
||||
ID: id,
|
||||
UserID: data.UserID,
|
||||
UserID: data.UserID, // Make sure we restore the UserID properly
|
||||
ExpiresAt: data.ExpiresAt,
|
||||
Data: data.Data,
|
||||
}
|
||||
if sess.Data == nil {
|
||||
sess.Data = make(map[string]any)
|
||||
}
|
||||
sm.sessions[id] = sess
|
||||
}
|
||||
}
|
||||
@ -155,7 +158,7 @@ func (sm *SessionManager) Save() error {
|
||||
sessionsData := make(map[string]*sessionData, len(sm.sessions))
|
||||
for id, sess := range sm.sessions {
|
||||
sessionsData[id] = &sessionData{
|
||||
UserID: sess.UserID,
|
||||
UserID: sess.UserID, // Save the actual UserID from the struct
|
||||
ExpiresAt: sess.ExpiresAt,
|
||||
Data: sess.Data,
|
||||
}
|
||||
|
Loading…
x
Reference in New Issue
Block a user