diff --git a/internal/auth/auth.go b/internal/auth/auth.go index e850a6c..b277f57 100644 --- a/internal/auth/auth.go +++ b/internal/auth/auth.go @@ -1,33 +1,30 @@ +// Package auth provides authentication and session management functionality. +// It includes secure session storage with in-memory caching and JSON persistence, +// user authentication against the database, and secure cookie handling. package auth import ( "dk/internal/password" + "dk/internal/session" "dk/internal/users" ) -// Manager is the global singleton instance var Manager *AuthManager -// AuthManager is a wrapper for the session store to add -// authentication tools over the store itself type AuthManager struct { - store *SessionStore + store *session.Store } -// Init initializes the global auth manager (auth.Manager) func Init(sessionsFilePath string) { Manager = &AuthManager{ - store: NewSessionStore(sessionsFilePath), + store: session.NewStore(sessionsFilePath), } } -// Authenticate checks for the usernaname or email, then verifies the plain password -// against the stored hash. func (am *AuthManager) Authenticate(usernameOrEmail, plainPassword string) (*users.User, error) { var user *users.User var err error - // Try to find user by username first user, err = users.GetByUsername(usernameOrEmail) if err != nil { user, err = users.GetByEmail(usernameOrEmail) @@ -47,16 +44,25 @@ func (am *AuthManager) Authenticate(usernameOrEmail, plainPassword string) (*use return user, nil } -func (am *AuthManager) CreateSession(user *users.User) *Session { - return am.store.Create(user.ID, user.Username, user.Email) +func (am *AuthManager) CreateSession(user *users.User) *session.Session { + sess := session.New(user.ID, user.Username, user.Email) + am.store.Save(sess) + return sess } -func (am *AuthManager) GetSession(sessionID string) (*Session, bool) { +func (am *AuthManager) GetSession(sessionID string) (*session.Session, bool) { return am.store.Get(sessionID) } func (am *AuthManager) UpdateSession(sessionID string) bool { - return am.store.Update(sessionID) + sess, exists := am.store.Get(sessionID) + if !exists { + return false + } + + sess.Touch() + am.store.Save(sess) + return true } func (am *AuthManager) DeleteSession(sessionID string) { @@ -71,124 +77,6 @@ func (am *AuthManager) Close() error { return am.store.Close() } -// SetFlash stores a flash message in the session that will be removed after retrieval -func (am *AuthManager) SetFlash(sessionID, key string, value any) bool { - session, exists := am.store.Get(sessionID) - if !exists { - return false - } - - am.store.mu.Lock() - defer am.store.mu.Unlock() - - if session.Data == nil { - session.Data = make(map[string]any) - } - - // Store flash messages under a special key - flashData, ok := session.Data["_flash"].(map[string]any) - if !ok { - flashData = make(map[string]any) - } - flashData[key] = value - session.Data["_flash"] = flashData - - return true -} - -// GetFlash retrieves and removes a flash message from the session -func (am *AuthManager) GetFlash(sessionID, key string) (any, bool) { - session, exists := am.store.Get(sessionID) - if !exists { - return nil, false - } - - am.store.mu.Lock() - defer am.store.mu.Unlock() - - if session.Data == nil { - return nil, false - } - - flashData, ok := session.Data["_flash"].(map[string]any) - if !ok { - return nil, false - } - - value, exists := flashData[key] - if exists { - delete(flashData, key) - if len(flashData) == 0 { - delete(session.Data, "_flash") - } else { - session.Data["_flash"] = flashData - } - } - - return value, exists -} - -// GetAllFlash retrieves and removes all flash messages from the session -func (am *AuthManager) GetAllFlash(sessionID string) map[string]any { - session, exists := am.store.Get(sessionID) - if !exists { - return nil - } - - am.store.mu.Lock() - defer am.store.mu.Unlock() - - if session.Data == nil { - return nil - } - - flashData, ok := session.Data["_flash"].(map[string]any) - if !ok { - return nil - } - - // Remove flash data from session - delete(session.Data, "_flash") - - return flashData -} - -// SetSessionData stores arbitrary data in the session -func (am *AuthManager) SetSessionData(sessionID, key string, value any) bool { - session, exists := am.store.Get(sessionID) - if !exists { - return false - } - - am.store.mu.Lock() - defer am.store.mu.Unlock() - - if session.Data == nil { - session.Data = make(map[string]any) - } - - session.Data[key] = value - return true -} - -// GetSessionData retrieves data from the session -func (am *AuthManager) GetSessionData(sessionID, key string) (any, bool) { - session, exists := am.store.Get(sessionID) - if !exists { - return nil, false - } - - am.store.mu.RLock() - defer am.store.mu.RUnlock() - - if session.Data == nil { - return nil, false - } - - value, exists := session.Data[key] - return value, exists -} - var ( ErrInvalidCredentials = &AuthError{"invalid username/email or password"} ErrSessionNotFound = &AuthError{"session not found"} diff --git a/internal/auth/cookies.go b/internal/auth/cookies.go index 519f410..5904a56 100644 --- a/internal/auth/cookies.go +++ b/internal/auth/cookies.go @@ -2,18 +2,21 @@ package auth import ( "dk/internal/cookies" + "dk/internal/session" "dk/internal/utils" "time" "github.com/valyala/fasthttp" ) +const SessionCookieName = "dk_session" + func SetSessionCookie(ctx *fasthttp.RequestCtx, sessionID string) { cookies.SetSecureCookie(ctx, cookies.CookieOptions{ Name: SessionCookieName, Value: sessionID, Path: "/", - Expires: time.Now().Add(DefaultExpiration), + Expires: time.Now().Add(session.DefaultExpiration), HTTPOnly: true, Secure: utils.IsHTTPS(ctx), SameSite: "lax", @@ -26,4 +29,4 @@ func GetSessionCookie(ctx *fasthttp.RequestCtx) string { func DeleteSessionCookie(ctx *fasthttp.RequestCtx) { cookies.DeleteCookie(ctx, SessionCookieName) -} +} \ No newline at end of file diff --git a/internal/auth/doc.go b/internal/auth/doc.go deleted file mode 100644 index 48cd94f..0000000 --- a/internal/auth/doc.go +++ /dev/null @@ -1,4 +0,0 @@ -// Package auth provides authentication and session management functionality. -// It includes secure session storage with in-memory caching and JSON persistence, -// user authentication against the database, and secure cookie handling. -package auth \ No newline at end of file diff --git a/internal/auth/flash.go b/internal/auth/flash.go index b5e2ce9..d62bd4b 100644 --- a/internal/auth/flash.go +++ b/internal/auth/flash.go @@ -2,46 +2,52 @@ package auth import ( "dk/internal/router" + "dk/internal/session" ) -// FlashMessage represents a flash message with type and content -type FlashMessage struct { - Type string `json:"type"` // "error", "success", "warning", "info" - Message string `json:"message"` -} - -// SetFlashMessage sets a flash message for the current session func SetFlashMessage(ctx router.Ctx, msgType, message string) bool { sessionID := GetSessionCookie(ctx) if sessionID == "" { return false } - return Manager.SetFlash(sessionID, "message", FlashMessage{ + sess, exists := Manager.GetSession(sessionID) + if !exists { + return false + } + + sess.SetFlash("message", session.FlashMessage{ Type: msgType, Message: message, }) + Manager.store.Save(sess) + return true } -// GetFlashMessage retrieves and removes the flash message from the current session -func GetFlashMessage(ctx router.Ctx) *FlashMessage { +func GetFlashMessage(ctx router.Ctx) *session.FlashMessage { sessionID := GetSessionCookie(ctx) if sessionID == "" { return nil } - value, exists := Manager.GetFlash(sessionID, "message") + sess, exists := Manager.GetSession(sessionID) if !exists { return nil } - if msg, ok := value.(FlashMessage); ok { + value, exists := sess.GetFlash("message") + if !exists { + return nil + } + + Manager.store.Save(sess) + + if msg, ok := value.(session.FlashMessage); ok { return &msg } - // Handle map[string]interface{} from JSON deserialization if msgMap, ok := value.(map[string]interface{}); ok { - msg := &FlashMessage{} + msg := &session.FlashMessage{} if t, ok := msgMap["type"].(string); ok { msg.Type = t } @@ -54,36 +60,45 @@ func GetFlashMessage(ctx router.Ctx) *FlashMessage { return nil } -// SetFormData stores form data temporarily in the session (for repopulating forms after errors) func SetFormData(ctx router.Ctx, data map[string]string) bool { sessionID := GetSessionCookie(ctx) if sessionID == "" { return false } - return Manager.SetSessionData(sessionID, "form_data", data) + sess, exists := Manager.GetSession(sessionID) + if !exists { + return false + } + + sess.Set("form_data", data) + Manager.store.Save(sess) + return true } -// GetFormData retrieves and removes form data from the session func GetFormData(ctx router.Ctx) map[string]string { sessionID := GetSessionCookie(ctx) if sessionID == "" { return nil } - value, exists := Manager.GetSessionData(sessionID, "form_data") + sess, exists := Manager.GetSession(sessionID) if !exists { return nil } - // Clear form data after retrieval - Manager.SetSessionData(sessionID, "form_data", nil) + value, exists := sess.Get("form_data") + if !exists { + return nil + } + + sess.Delete("form_data") + Manager.store.Save(sess) if formData, ok := value.(map[string]string); ok { return formData } - // Handle map[string]interface{} from JSON deserialization if formMap, ok := value.(map[string]interface{}); ok { result := make(map[string]string) for k, v := range formMap { diff --git a/internal/auth/session.go b/internal/auth/session.go deleted file mode 100644 index f141bf1..0000000 --- a/internal/auth/session.go +++ /dev/null @@ -1,222 +0,0 @@ -package auth - -import ( - "crypto/rand" - "encoding/hex" - "encoding/json" - "maps" - "os" - "sync" - "time" -) - -const ( - SessionCookieName = "dk_session" - DefaultExpiration = 24 * time.Hour - SessionIDLength = 32 -) - -type Session struct { - ID string `json:"-"` // Exclude from JSON since it's stored as the map key - UserID int `json:"user_id"` - Username string `json:"username"` - Email string `json:"email"` - CreatedAt time.Time `json:"created_at"` - ExpiresAt time.Time `json:"expires_at"` - LastSeen time.Time `json:"last_seen"` - Data map[string]any `json:"data,omitempty"` // For storing additional session data -} - -type SessionStore struct { - mu sync.RWMutex - sessions map[string]*Session - filePath string - saveInterval time.Duration - stopChan chan struct{} -} - -type persistedData struct { - Sessions map[string]*Session `json:"sessions"` - SavedAt time.Time `json:"saved_at"` -} - -func NewSessionStore(filePath string) *SessionStore { - store := &SessionStore{ - sessions: make(map[string]*Session), - filePath: filePath, - saveInterval: 5 * time.Minute, - stopChan: make(chan struct{}), - } - - store.loadFromFile() - store.startPeriodicSave() - - return store -} - -func (s *SessionStore) generateSessionID() string { - bytes := make([]byte, SessionIDLength) - rand.Read(bytes) - return hex.EncodeToString(bytes) -} - -func (s *SessionStore) Create(userID int, username, email string) *Session { - s.mu.Lock() - defer s.mu.Unlock() - - session := &Session{ - ID: s.generateSessionID(), - UserID: userID, - Username: username, - Email: email, - CreatedAt: time.Now(), - ExpiresAt: time.Now().Add(DefaultExpiration), - LastSeen: time.Now(), - } - - s.sessions[session.ID] = session - return session -} - -func (s *SessionStore) Get(sessionID string) (*Session, bool) { - s.mu.RLock() - defer s.mu.RUnlock() - - session, exists := s.sessions[sessionID] - if !exists { - return nil, false - } - - if time.Now().After(session.ExpiresAt) { - delete(s.sessions, sessionID) - return nil, false - } - - return session, true -} - -func (s *SessionStore) Update(sessionID string) bool { - s.mu.Lock() - defer s.mu.Unlock() - - session, exists := s.sessions[sessionID] - if !exists { - return false - } - - if time.Now().After(session.ExpiresAt) { - delete(s.sessions, sessionID) - return false - } - - session.LastSeen = time.Now() - session.ExpiresAt = time.Now().Add(DefaultExpiration) - return true -} - -func (s *SessionStore) Delete(sessionID string) { - s.mu.Lock() - defer s.mu.Unlock() - - delete(s.sessions, sessionID) -} - -func (s *SessionStore) Cleanup() { - s.mu.Lock() - defer s.mu.Unlock() - - now := time.Now() - for id, session := range s.sessions { - if now.After(session.ExpiresAt) { - delete(s.sessions, id) - } - } -} - -func (s *SessionStore) loadFromFile() { - if s.filePath == "" { - return - } - - data, err := os.ReadFile(s.filePath) - if err != nil { - return // File might not exist yet - } - - var persisted persistedData - if err := json.Unmarshal(data, &persisted); err != nil { - return - } - - s.mu.Lock() - defer s.mu.Unlock() - - now := time.Now() - for id, session := range persisted.Sessions { - if now.Before(session.ExpiresAt) { - s.sessions[id] = session - } - } -} - -func (s *SessionStore) saveToFile() error { - if s.filePath == "" { - return nil - } - - s.mu.RLock() - sessionsCopy := make(map[string]*Session) - maps.Copy(sessionsCopy, s.sessions) - s.mu.RUnlock() - - data := persistedData{ - Sessions: sessionsCopy, - SavedAt: time.Now(), - } - - jsonData, err := json.MarshalIndent(data, "", " ") - if err != nil { - return err - } - - return os.WriteFile(s.filePath, jsonData, 0600) -} - -func (s *SessionStore) startPeriodicSave() { - go func() { - ticker := time.NewTicker(s.saveInterval) - defer ticker.Stop() - - for { - select { - case <-ticker.C: - s.Cleanup() - s.saveToFile() - case <-s.stopChan: - s.saveToFile() - return - } - } - }() -} - -func (s *SessionStore) Close() error { - close(s.stopChan) - return s.saveToFile() -} - -func (s *SessionStore) Stats() (total, active int) { - s.mu.RLock() - defer s.mu.RUnlock() - - now := time.Now() - total = len(s.sessions) - - for _, session := range s.sessions { - if now.Before(session.ExpiresAt) { - active++ - } - } - - return -} diff --git a/internal/csrf/csrf.go b/internal/csrf/csrf.go index c24dfe7..79ac51b 100644 --- a/internal/csrf/csrf.go +++ b/internal/csrf/csrf.go @@ -1,3 +1,23 @@ +// Package csrf provides Cross-Site Request Forgery (CSRF) protection +// with session-based token storage and form helpers. +// +// # Basic Usage +// +// // Generate token and store in session +// token := csrf.GenerateToken(ctx, authManager) +// +// // In templates - generate hidden input field +// hiddenField := csrf.HiddenField(ctx, authManager) +// +// // Verify form submission +// if !csrf.ValidateToken(ctx, authManager, formToken) { +// // Handle CSRF validation failure +// } +// +// # Middleware Integration +// +// // Add CSRF middleware to protected routes +// r.Use(middleware.CSRF(authManager)) package csrf import ( @@ -9,6 +29,7 @@ import ( "dk/internal/auth" "dk/internal/router" + "dk/internal/session" "github.com/valyala/fasthttp" ) @@ -22,9 +43,9 @@ const ( ) // GetCurrentSession retrieves the session from context (mirrors middleware function) -func GetCurrentSession(ctx router.Ctx) *auth.Session { - if session, ok := ctx.UserValue(SessionCtxKey).(*auth.Session); ok { - return session +func GetCurrentSession(ctx router.Ctx) *session.Session { + if sess, ok := ctx.UserValue(SessionCtxKey).(*session.Session); ok { + return sess } return nil } @@ -97,23 +118,17 @@ func ValidateToken(ctx router.Ctx, authManager *auth.AuthManager, submittedToken } // StoreToken saves a CSRF token in the session -func StoreToken(session *auth.Session, token string) { - if session.Data == nil { - session.Data = make(map[string]any) - } - session.Data[SessionKey] = token +func StoreToken(sess *session.Session, token string) { + sess.Set(SessionKey, token) } // GetStoredToken retrieves the CSRF token from session -func GetStoredToken(session *auth.Session) string { - if session.Data == nil { - return "" +func GetStoredToken(sess *session.Session) string { + if token, ok := sess.Get(SessionKey); ok { + if tokenStr, ok := token.(string); ok { + return tokenStr + } } - - if token, ok := session.Data[SessionKey].(string); ok { - return token - } - return "" } diff --git a/internal/csrf/csrf_test.go b/internal/csrf/csrf_test.go index 2513ce5..c4d04e8 100644 --- a/internal/csrf/csrf_test.go +++ b/internal/csrf/csrf_test.go @@ -4,14 +4,13 @@ import ( "testing" "time" - "dk/internal/auth" + "dk/internal/session" "github.com/valyala/fasthttp" ) func TestGenerateToken(t *testing.T) { - // Create a mock session - session := &auth.Session{ + sess := &session.Session{ ID: "test-session", UserID: 1, Username: "testuser", @@ -22,27 +21,23 @@ func TestGenerateToken(t *testing.T) { Data: make(map[string]any), } - // Create mock context ctx := &fasthttp.RequestCtx{} - ctx.SetUserValue(SessionCtxKey, session) + ctx.SetUserValue(SessionCtxKey, sess) - // Generate token token := GenerateToken(ctx, nil) if token == "" { t.Error("Expected non-empty token") } - // Check that token was stored in session - storedToken := GetStoredToken(session) + storedToken := GetStoredToken(sess) if storedToken != token { t.Errorf("Expected stored token %s, got %s", token, storedToken) } } func TestValidateToken(t *testing.T) { - // Create session with token - session := &auth.Session{ + sess := &session.Session{ ID: "test-session", UserID: 1, Username: "testuser", @@ -51,19 +46,16 @@ func TestValidateToken(t *testing.T) { } ctx := &fasthttp.RequestCtx{} - ctx.SetUserValue(SessionCtxKey, session) + ctx.SetUserValue(SessionCtxKey, sess) - // Valid token should pass if !ValidateToken(ctx, nil, "test-token") { t.Error("Expected valid token to pass validation") } - // Invalid token should fail if ValidateToken(ctx, nil, "wrong-token") { t.Error("Expected invalid token to fail validation") } - // Empty token should fail if ValidateToken(ctx, nil, "") { t.Error("Expected empty token to fail validation") } @@ -72,14 +64,13 @@ func TestValidateToken(t *testing.T) { func TestValidateTokenNoSession(t *testing.T) { ctx := &fasthttp.RequestCtx{} - // No session should fail validation if ValidateToken(ctx, nil, "any-token") { t.Error("Expected validation to fail with no session") } } func TestHiddenField(t *testing.T) { - session := &auth.Session{ + sess := &session.Session{ ID: "test-session", UserID: 1, Username: "testuser", @@ -88,7 +79,7 @@ func TestHiddenField(t *testing.T) { } ctx := &fasthttp.RequestCtx{} - ctx.SetUserValue(SessionCtxKey, session) + ctx.SetUserValue(SessionCtxKey, sess) field := HiddenField(ctx, nil) expected := `` @@ -102,13 +93,13 @@ func TestHiddenFieldNoSession(t *testing.T) { ctx := &fasthttp.RequestCtx{} field := HiddenField(ctx, nil) - if field != "" { - t.Errorf("Expected empty field with no session, got %s", field) + if field == "" { + t.Error("Expected non-empty field for guest user with cookie-based token") } } func TestTokenMeta(t *testing.T) { - session := &auth.Session{ + sess := &session.Session{ ID: "test-session", UserID: 1, Username: "testuser", @@ -117,7 +108,7 @@ func TestTokenMeta(t *testing.T) { } ctx := &fasthttp.RequestCtx{} - ctx.SetUserValue(SessionCtxKey, session) + ctx.SetUserValue(SessionCtxKey, sess) meta := TokenMeta(ctx, nil) expected := `` @@ -128,30 +119,30 @@ func TestTokenMeta(t *testing.T) { } func TestStoreAndGetToken(t *testing.T) { - session := &auth.Session{ + sess := &session.Session{ Data: make(map[string]any), } token := "test-token" - StoreToken(session, token) + StoreToken(sess, token) - retrieved := GetStoredToken(session) + retrieved := GetStoredToken(sess) if retrieved != token { t.Errorf("Expected %s, got %s", token, retrieved) } } func TestGetStoredTokenNoData(t *testing.T) { - session := &auth.Session{} + sess := &session.Session{} - token := GetStoredToken(session) + token := GetStoredToken(sess) if token != "" { t.Errorf("Expected empty token, got %s", token) } } func TestValidateFormToken(t *testing.T) { - session := &auth.Session{ + sess := &session.Session{ ID: "test-session", UserID: 1, Username: "testuser", @@ -160,16 +151,14 @@ func TestValidateFormToken(t *testing.T) { } ctx := &fasthttp.RequestCtx{} - ctx.SetUserValue(SessionCtxKey, session) + ctx.SetUserValue(SessionCtxKey, sess) - // Set form data ctx.PostArgs().Set(TokenFieldName, "test-token") if !ValidateFormToken(ctx, nil) { t.Error("Expected form token validation to pass") } - // Test with wrong token ctx.PostArgs().Set(TokenFieldName, "wrong-token") if ValidateFormToken(ctx, nil) { diff --git a/internal/csrf/doc.go b/internal/csrf/doc.go deleted file mode 100644 index 6bc6cab..0000000 --- a/internal/csrf/doc.go +++ /dev/null @@ -1,29 +0,0 @@ -// Package csrf provides Cross-Site Request Forgery (CSRF) protection -// with session-based token storage and form helpers. -// -// # Basic Usage -// -// // Generate token and store in session -// token := csrf.GenerateToken(ctx, authManager) -// -// // In templates - generate hidden input field -// hiddenField := csrf.HiddenField(ctx, authManager) -// -// // Verify form submission -// if !csrf.ValidateToken(ctx, authManager, formToken) { -// // Handle CSRF validation failure -// } -// -// # Middleware Integration -// -// // Add CSRF middleware to protected routes -// r.Use(middleware.CSRF(authManager)) -// -// # Security Features -// -// - Cryptographically secure token generation -// - Session-based token storage and validation -// - Automatic token rotation on successful validation -// - Protection against timing attacks with constant-time comparison -// - Integration with existing authentication system -package csrf \ No newline at end of file diff --git a/internal/middleware/auth.go b/internal/middleware/auth.go index 9e401f8..b3907c2 100644 --- a/internal/middleware/auth.go +++ b/internal/middleware/auth.go @@ -3,30 +3,26 @@ package middleware import ( "dk/internal/auth" "dk/internal/router" + "dk/internal/session" "dk/internal/users" "github.com/valyala/fasthttp" ) -// Auth creates an authentication middleware func Auth(authManager *auth.AuthManager) router.Middleware { return func(next router.Handler) router.Handler { return func(ctx router.Ctx, params []string) { sessionID := auth.GetSessionCookie(ctx) if sessionID != "" { - if session, exists := authManager.GetSession(sessionID); exists { - // Update session activity + if sess, exists := authManager.GetSession(sessionID); exists { authManager.UpdateSession(sessionID) - // Get the full user object - user, err := users.Find(session.UserID) + user, err := users.Find(sess.UserID) if err == nil && user != nil { - // Store session and user info in context - ctx.SetUserValue("session", session) + ctx.SetUserValue("session", sess) ctx.SetUserValue("user", user) - // Refresh the cookie auth.SetSessionCookie(ctx, sessionID) } } @@ -37,7 +33,6 @@ func Auth(authManager *auth.AuthManager) router.Middleware { } } -// RequireAuth enforces authentication - redirect defaults to "/login" func RequireAuth(paths ...string) router.Middleware { redirect := "/login" if len(paths) > 0 && paths[0] != "" { @@ -56,7 +51,6 @@ func RequireAuth(paths ...string) router.Middleware { } } -// RequireGuest enforces no authentication - redirect defaults to "/" func RequireGuest(paths ...string) router.Middleware { redirect := "/" if len(paths) > 0 && paths[0] != "" { @@ -74,13 +68,11 @@ func RequireGuest(paths ...string) router.Middleware { } } -// IsAuthenticated checks if the current request has a valid session func IsAuthenticated(ctx router.Ctx) bool { _, exists := ctx.UserValue("user").(*users.User) return exists } -// GetCurrentUser returns the current authenticated user, or nil if not authenticated func GetCurrentUser(ctx router.Ctx) *users.User { if user, ok := ctx.UserValue("user").(*users.User); ok { return user @@ -88,25 +80,21 @@ func GetCurrentUser(ctx router.Ctx) *users.User { return nil } -// GetCurrentSession returns the current session, or nil if not authenticated -func GetCurrentSession(ctx router.Ctx) *auth.Session { - if session, ok := ctx.UserValue("session").(*auth.Session); ok { - return session +func GetCurrentSession(ctx router.Ctx) *session.Session { + if sess, ok := ctx.UserValue("session").(*session.Session); ok { + return sess } return nil } -// Login creates a session and sets the cookie func Login(ctx router.Ctx, authManager *auth.AuthManager, user *users.User) { - session := authManager.CreateSession(user) - auth.SetSessionCookie(ctx, session.ID) + sess := authManager.CreateSession(user) + auth.SetSessionCookie(ctx, sess.ID) - // Set in context for immediate use - ctx.SetUserValue("session", session) + ctx.SetUserValue("session", sess) ctx.SetUserValue("user", user) } -// Logout destroys the session and clears the cookie func Logout(ctx router.Ctx, authManager *auth.AuthManager) { sessionID := auth.GetSessionCookie(ctx) if sessionID != "" { @@ -115,7 +103,6 @@ func Logout(ctx router.Ctx, authManager *auth.AuthManager) { auth.DeleteSessionCookie(ctx) - // Clear from context ctx.SetUserValue("session", nil) ctx.SetUserValue("user", nil) -} +} \ No newline at end of file diff --git a/internal/session/flash.go b/internal/session/flash.go new file mode 100644 index 0000000..9c96036 --- /dev/null +++ b/internal/session/flash.go @@ -0,0 +1,56 @@ +package session + +type FlashMessage struct { + Type string `json:"type"` + Message string `json:"message"` +} + +func (s *Session) SetFlash(key string, value any) { + if s.Data == nil { + s.Data = make(map[string]any) + } + + flashData, ok := s.Data["_flash"].(map[string]any) + if !ok { + flashData = make(map[string]any) + } + flashData[key] = value + s.Data["_flash"] = flashData +} + +func (s *Session) GetFlash(key string) (any, bool) { + if s.Data == nil { + return nil, false + } + + flashData, ok := s.Data["_flash"].(map[string]any) + if !ok { + return nil, false + } + + value, exists := flashData[key] + if exists { + delete(flashData, key) + if len(flashData) == 0 { + delete(s.Data, "_flash") + } else { + s.Data["_flash"] = flashData + } + } + + return value, exists +} + +func (s *Session) GetAllFlash() map[string]any { + if s.Data == nil { + return nil + } + + flashData, ok := s.Data["_flash"].(map[string]any) + if !ok { + return nil + } + + delete(s.Data, "_flash") + return flashData +} \ No newline at end of file diff --git a/internal/session/session.go b/internal/session/session.go new file mode 100644 index 0000000..96fbc2a --- /dev/null +++ b/internal/session/session.go @@ -0,0 +1,74 @@ +// Package session provides session management functionality. +// It includes session storage, flash messages, and data persistence. +package session + +import ( + "crypto/rand" + "encoding/hex" + "time" +) + +const ( + DefaultExpiration = 24 * time.Hour + IDLength = 32 +) + +type Session struct { + ID string `json:"-"` + UserID int `json:"user_id"` + Username string `json:"username"` + Email string `json:"email"` + CreatedAt time.Time `json:"created_at"` + ExpiresAt time.Time `json:"expires_at"` + LastSeen time.Time `json:"last_seen"` + Data map[string]any `json:"data,omitempty"` +} + +func New(userID int, username, email string) *Session { + return &Session{ + ID: generateID(), + UserID: userID, + Username: username, + Email: email, + CreatedAt: time.Now(), + ExpiresAt: time.Now().Add(DefaultExpiration), + LastSeen: time.Now(), + Data: make(map[string]any), + } +} + +func (s *Session) IsExpired() bool { + return time.Now().After(s.ExpiresAt) +} + +func (s *Session) Touch() { + s.LastSeen = time.Now() + s.ExpiresAt = time.Now().Add(DefaultExpiration) +} + +func (s *Session) Set(key string, value any) { + if s.Data == nil { + s.Data = make(map[string]any) + } + s.Data[key] = value +} + +func (s *Session) Get(key string) (any, bool) { + if s.Data == nil { + return nil, false + } + value, exists := s.Data[key] + return value, exists +} + +func (s *Session) Delete(key string) { + if s.Data != nil { + delete(s.Data, key) + } +} + +func generateID() string { + bytes := make([]byte, IDLength) + rand.Read(bytes) + return hex.EncodeToString(bytes) +} diff --git a/internal/session/store.go b/internal/session/store.go new file mode 100644 index 0000000..6a5a410 --- /dev/null +++ b/internal/session/store.go @@ -0,0 +1,161 @@ +package session + +import ( + "encoding/json" + "maps" + "os" + "sync" + "time" +) + +type Store struct { + mu sync.RWMutex + sessions map[string]*Session + filePath string + saveInterval time.Duration + stopChan chan struct{} +} + +type persistedData struct { + Sessions map[string]*Session `json:"sessions"` + SavedAt time.Time `json:"saved_at"` +} + +func NewStore(filePath string) *Store { + store := &Store{ + sessions: make(map[string]*Session), + filePath: filePath, + saveInterval: 5 * time.Minute, + stopChan: make(chan struct{}), + } + + store.loadFromFile() + store.startPeriodicSave() + + return store +} + +func (s *Store) Save(session *Session) { + s.mu.Lock() + defer s.mu.Unlock() + s.sessions[session.ID] = session +} + +func (s *Store) Get(sessionID string) (*Session, bool) { + s.mu.RLock() + defer s.mu.RUnlock() + + session, exists := s.sessions[sessionID] + if !exists { + return nil, false + } + + if session.IsExpired() { + return nil, false + } + + return session, true +} + +func (s *Store) Delete(sessionID string) { + s.mu.Lock() + defer s.mu.Unlock() + delete(s.sessions, sessionID) +} + +func (s *Store) Cleanup() { + s.mu.Lock() + defer s.mu.Unlock() + + for id, session := range s.sessions { + if session.IsExpired() { + delete(s.sessions, id) + } + } +} + +func (s *Store) Stats() (total, active int) { + s.mu.RLock() + defer s.mu.RUnlock() + + total = len(s.sessions) + for _, session := range s.sessions { + if !session.IsExpired() { + active++ + } + } + + return +} + +func (s *Store) loadFromFile() { + if s.filePath == "" { + return + } + + data, err := os.ReadFile(s.filePath) + if err != nil { + return + } + + var persisted persistedData + if err := json.Unmarshal(data, &persisted); err != nil { + return + } + + s.mu.Lock() + defer s.mu.Unlock() + + for id, session := range persisted.Sessions { + if !session.IsExpired() { + session.ID = id + s.sessions[id] = session + } + } +} + +func (s *Store) saveToFile() error { + if s.filePath == "" { + return nil + } + + s.mu.RLock() + sessionsCopy := make(map[string]*Session, len(s.sessions)) + maps.Copy(sessionsCopy, s.sessions) + s.mu.RUnlock() + + data := persistedData{ + Sessions: sessionsCopy, + SavedAt: time.Now(), + } + + jsonData, err := json.MarshalIndent(data, "", " ") + if err != nil { + return err + } + + return os.WriteFile(s.filePath, jsonData, 0600) +} + +func (s *Store) startPeriodicSave() { + go func() { + ticker := time.NewTicker(s.saveInterval) + defer ticker.Stop() + + for { + select { + case <-ticker.C: + s.Cleanup() + s.saveToFile() + case <-s.stopChan: + s.saveToFile() + return + } + } + }() +} + +func (s *Store) Close() error { + close(s.stopChan) + return s.saveToFile() +}