diff --git a/internal/auth/auth.go b/internal/auth/auth.go
index e850a6c..b277f57 100644
--- a/internal/auth/auth.go
+++ b/internal/auth/auth.go
@@ -1,33 +1,30 @@
+// Package auth provides authentication and session management functionality.
+// It includes secure session storage with in-memory caching and JSON persistence,
+// user authentication against the database, and secure cookie handling.
package auth
import (
"dk/internal/password"
+ "dk/internal/session"
"dk/internal/users"
)
-// Manager is the global singleton instance
var Manager *AuthManager
-// AuthManager is a wrapper for the session store to add
-// authentication tools over the store itself
type AuthManager struct {
- store *SessionStore
+ store *session.Store
}
-// Init initializes the global auth manager (auth.Manager)
func Init(sessionsFilePath string) {
Manager = &AuthManager{
- store: NewSessionStore(sessionsFilePath),
+ store: session.NewStore(sessionsFilePath),
}
}
-// Authenticate checks for the usernaname or email, then verifies the plain password
-// against the stored hash.
func (am *AuthManager) Authenticate(usernameOrEmail, plainPassword string) (*users.User, error) {
var user *users.User
var err error
- // Try to find user by username first
user, err = users.GetByUsername(usernameOrEmail)
if err != nil {
user, err = users.GetByEmail(usernameOrEmail)
@@ -47,16 +44,25 @@ func (am *AuthManager) Authenticate(usernameOrEmail, plainPassword string) (*use
return user, nil
}
-func (am *AuthManager) CreateSession(user *users.User) *Session {
- return am.store.Create(user.ID, user.Username, user.Email)
+func (am *AuthManager) CreateSession(user *users.User) *session.Session {
+ sess := session.New(user.ID, user.Username, user.Email)
+ am.store.Save(sess)
+ return sess
}
-func (am *AuthManager) GetSession(sessionID string) (*Session, bool) {
+func (am *AuthManager) GetSession(sessionID string) (*session.Session, bool) {
return am.store.Get(sessionID)
}
func (am *AuthManager) UpdateSession(sessionID string) bool {
- return am.store.Update(sessionID)
+ sess, exists := am.store.Get(sessionID)
+ if !exists {
+ return false
+ }
+
+ sess.Touch()
+ am.store.Save(sess)
+ return true
}
func (am *AuthManager) DeleteSession(sessionID string) {
@@ -71,124 +77,6 @@ func (am *AuthManager) Close() error {
return am.store.Close()
}
-// SetFlash stores a flash message in the session that will be removed after retrieval
-func (am *AuthManager) SetFlash(sessionID, key string, value any) bool {
- session, exists := am.store.Get(sessionID)
- if !exists {
- return false
- }
-
- am.store.mu.Lock()
- defer am.store.mu.Unlock()
-
- if session.Data == nil {
- session.Data = make(map[string]any)
- }
-
- // Store flash messages under a special key
- flashData, ok := session.Data["_flash"].(map[string]any)
- if !ok {
- flashData = make(map[string]any)
- }
- flashData[key] = value
- session.Data["_flash"] = flashData
-
- return true
-}
-
-// GetFlash retrieves and removes a flash message from the session
-func (am *AuthManager) GetFlash(sessionID, key string) (any, bool) {
- session, exists := am.store.Get(sessionID)
- if !exists {
- return nil, false
- }
-
- am.store.mu.Lock()
- defer am.store.mu.Unlock()
-
- if session.Data == nil {
- return nil, false
- }
-
- flashData, ok := session.Data["_flash"].(map[string]any)
- if !ok {
- return nil, false
- }
-
- value, exists := flashData[key]
- if exists {
- delete(flashData, key)
- if len(flashData) == 0 {
- delete(session.Data, "_flash")
- } else {
- session.Data["_flash"] = flashData
- }
- }
-
- return value, exists
-}
-
-// GetAllFlash retrieves and removes all flash messages from the session
-func (am *AuthManager) GetAllFlash(sessionID string) map[string]any {
- session, exists := am.store.Get(sessionID)
- if !exists {
- return nil
- }
-
- am.store.mu.Lock()
- defer am.store.mu.Unlock()
-
- if session.Data == nil {
- return nil
- }
-
- flashData, ok := session.Data["_flash"].(map[string]any)
- if !ok {
- return nil
- }
-
- // Remove flash data from session
- delete(session.Data, "_flash")
-
- return flashData
-}
-
-// SetSessionData stores arbitrary data in the session
-func (am *AuthManager) SetSessionData(sessionID, key string, value any) bool {
- session, exists := am.store.Get(sessionID)
- if !exists {
- return false
- }
-
- am.store.mu.Lock()
- defer am.store.mu.Unlock()
-
- if session.Data == nil {
- session.Data = make(map[string]any)
- }
-
- session.Data[key] = value
- return true
-}
-
-// GetSessionData retrieves data from the session
-func (am *AuthManager) GetSessionData(sessionID, key string) (any, bool) {
- session, exists := am.store.Get(sessionID)
- if !exists {
- return nil, false
- }
-
- am.store.mu.RLock()
- defer am.store.mu.RUnlock()
-
- if session.Data == nil {
- return nil, false
- }
-
- value, exists := session.Data[key]
- return value, exists
-}
-
var (
ErrInvalidCredentials = &AuthError{"invalid username/email or password"}
ErrSessionNotFound = &AuthError{"session not found"}
diff --git a/internal/auth/cookies.go b/internal/auth/cookies.go
index 519f410..5904a56 100644
--- a/internal/auth/cookies.go
+++ b/internal/auth/cookies.go
@@ -2,18 +2,21 @@ package auth
import (
"dk/internal/cookies"
+ "dk/internal/session"
"dk/internal/utils"
"time"
"github.com/valyala/fasthttp"
)
+const SessionCookieName = "dk_session"
+
func SetSessionCookie(ctx *fasthttp.RequestCtx, sessionID string) {
cookies.SetSecureCookie(ctx, cookies.CookieOptions{
Name: SessionCookieName,
Value: sessionID,
Path: "/",
- Expires: time.Now().Add(DefaultExpiration),
+ Expires: time.Now().Add(session.DefaultExpiration),
HTTPOnly: true,
Secure: utils.IsHTTPS(ctx),
SameSite: "lax",
@@ -26,4 +29,4 @@ func GetSessionCookie(ctx *fasthttp.RequestCtx) string {
func DeleteSessionCookie(ctx *fasthttp.RequestCtx) {
cookies.DeleteCookie(ctx, SessionCookieName)
-}
+}
\ No newline at end of file
diff --git a/internal/auth/doc.go b/internal/auth/doc.go
deleted file mode 100644
index 48cd94f..0000000
--- a/internal/auth/doc.go
+++ /dev/null
@@ -1,4 +0,0 @@
-// Package auth provides authentication and session management functionality.
-// It includes secure session storage with in-memory caching and JSON persistence,
-// user authentication against the database, and secure cookie handling.
-package auth
\ No newline at end of file
diff --git a/internal/auth/flash.go b/internal/auth/flash.go
index b5e2ce9..d62bd4b 100644
--- a/internal/auth/flash.go
+++ b/internal/auth/flash.go
@@ -2,46 +2,52 @@ package auth
import (
"dk/internal/router"
+ "dk/internal/session"
)
-// FlashMessage represents a flash message with type and content
-type FlashMessage struct {
- Type string `json:"type"` // "error", "success", "warning", "info"
- Message string `json:"message"`
-}
-
-// SetFlashMessage sets a flash message for the current session
func SetFlashMessage(ctx router.Ctx, msgType, message string) bool {
sessionID := GetSessionCookie(ctx)
if sessionID == "" {
return false
}
- return Manager.SetFlash(sessionID, "message", FlashMessage{
+ sess, exists := Manager.GetSession(sessionID)
+ if !exists {
+ return false
+ }
+
+ sess.SetFlash("message", session.FlashMessage{
Type: msgType,
Message: message,
})
+ Manager.store.Save(sess)
+ return true
}
-// GetFlashMessage retrieves and removes the flash message from the current session
-func GetFlashMessage(ctx router.Ctx) *FlashMessage {
+func GetFlashMessage(ctx router.Ctx) *session.FlashMessage {
sessionID := GetSessionCookie(ctx)
if sessionID == "" {
return nil
}
- value, exists := Manager.GetFlash(sessionID, "message")
+ sess, exists := Manager.GetSession(sessionID)
if !exists {
return nil
}
- if msg, ok := value.(FlashMessage); ok {
+ value, exists := sess.GetFlash("message")
+ if !exists {
+ return nil
+ }
+
+ Manager.store.Save(sess)
+
+ if msg, ok := value.(session.FlashMessage); ok {
return &msg
}
- // Handle map[string]interface{} from JSON deserialization
if msgMap, ok := value.(map[string]interface{}); ok {
- msg := &FlashMessage{}
+ msg := &session.FlashMessage{}
if t, ok := msgMap["type"].(string); ok {
msg.Type = t
}
@@ -54,36 +60,45 @@ func GetFlashMessage(ctx router.Ctx) *FlashMessage {
return nil
}
-// SetFormData stores form data temporarily in the session (for repopulating forms after errors)
func SetFormData(ctx router.Ctx, data map[string]string) bool {
sessionID := GetSessionCookie(ctx)
if sessionID == "" {
return false
}
- return Manager.SetSessionData(sessionID, "form_data", data)
+ sess, exists := Manager.GetSession(sessionID)
+ if !exists {
+ return false
+ }
+
+ sess.Set("form_data", data)
+ Manager.store.Save(sess)
+ return true
}
-// GetFormData retrieves and removes form data from the session
func GetFormData(ctx router.Ctx) map[string]string {
sessionID := GetSessionCookie(ctx)
if sessionID == "" {
return nil
}
- value, exists := Manager.GetSessionData(sessionID, "form_data")
+ sess, exists := Manager.GetSession(sessionID)
if !exists {
return nil
}
- // Clear form data after retrieval
- Manager.SetSessionData(sessionID, "form_data", nil)
+ value, exists := sess.Get("form_data")
+ if !exists {
+ return nil
+ }
+
+ sess.Delete("form_data")
+ Manager.store.Save(sess)
if formData, ok := value.(map[string]string); ok {
return formData
}
- // Handle map[string]interface{} from JSON deserialization
if formMap, ok := value.(map[string]interface{}); ok {
result := make(map[string]string)
for k, v := range formMap {
diff --git a/internal/auth/session.go b/internal/auth/session.go
deleted file mode 100644
index f141bf1..0000000
--- a/internal/auth/session.go
+++ /dev/null
@@ -1,222 +0,0 @@
-package auth
-
-import (
- "crypto/rand"
- "encoding/hex"
- "encoding/json"
- "maps"
- "os"
- "sync"
- "time"
-)
-
-const (
- SessionCookieName = "dk_session"
- DefaultExpiration = 24 * time.Hour
- SessionIDLength = 32
-)
-
-type Session struct {
- ID string `json:"-"` // Exclude from JSON since it's stored as the map key
- UserID int `json:"user_id"`
- Username string `json:"username"`
- Email string `json:"email"`
- CreatedAt time.Time `json:"created_at"`
- ExpiresAt time.Time `json:"expires_at"`
- LastSeen time.Time `json:"last_seen"`
- Data map[string]any `json:"data,omitempty"` // For storing additional session data
-}
-
-type SessionStore struct {
- mu sync.RWMutex
- sessions map[string]*Session
- filePath string
- saveInterval time.Duration
- stopChan chan struct{}
-}
-
-type persistedData struct {
- Sessions map[string]*Session `json:"sessions"`
- SavedAt time.Time `json:"saved_at"`
-}
-
-func NewSessionStore(filePath string) *SessionStore {
- store := &SessionStore{
- sessions: make(map[string]*Session),
- filePath: filePath,
- saveInterval: 5 * time.Minute,
- stopChan: make(chan struct{}),
- }
-
- store.loadFromFile()
- store.startPeriodicSave()
-
- return store
-}
-
-func (s *SessionStore) generateSessionID() string {
- bytes := make([]byte, SessionIDLength)
- rand.Read(bytes)
- return hex.EncodeToString(bytes)
-}
-
-func (s *SessionStore) Create(userID int, username, email string) *Session {
- s.mu.Lock()
- defer s.mu.Unlock()
-
- session := &Session{
- ID: s.generateSessionID(),
- UserID: userID,
- Username: username,
- Email: email,
- CreatedAt: time.Now(),
- ExpiresAt: time.Now().Add(DefaultExpiration),
- LastSeen: time.Now(),
- }
-
- s.sessions[session.ID] = session
- return session
-}
-
-func (s *SessionStore) Get(sessionID string) (*Session, bool) {
- s.mu.RLock()
- defer s.mu.RUnlock()
-
- session, exists := s.sessions[sessionID]
- if !exists {
- return nil, false
- }
-
- if time.Now().After(session.ExpiresAt) {
- delete(s.sessions, sessionID)
- return nil, false
- }
-
- return session, true
-}
-
-func (s *SessionStore) Update(sessionID string) bool {
- s.mu.Lock()
- defer s.mu.Unlock()
-
- session, exists := s.sessions[sessionID]
- if !exists {
- return false
- }
-
- if time.Now().After(session.ExpiresAt) {
- delete(s.sessions, sessionID)
- return false
- }
-
- session.LastSeen = time.Now()
- session.ExpiresAt = time.Now().Add(DefaultExpiration)
- return true
-}
-
-func (s *SessionStore) Delete(sessionID string) {
- s.mu.Lock()
- defer s.mu.Unlock()
-
- delete(s.sessions, sessionID)
-}
-
-func (s *SessionStore) Cleanup() {
- s.mu.Lock()
- defer s.mu.Unlock()
-
- now := time.Now()
- for id, session := range s.sessions {
- if now.After(session.ExpiresAt) {
- delete(s.sessions, id)
- }
- }
-}
-
-func (s *SessionStore) loadFromFile() {
- if s.filePath == "" {
- return
- }
-
- data, err := os.ReadFile(s.filePath)
- if err != nil {
- return // File might not exist yet
- }
-
- var persisted persistedData
- if err := json.Unmarshal(data, &persisted); err != nil {
- return
- }
-
- s.mu.Lock()
- defer s.mu.Unlock()
-
- now := time.Now()
- for id, session := range persisted.Sessions {
- if now.Before(session.ExpiresAt) {
- s.sessions[id] = session
- }
- }
-}
-
-func (s *SessionStore) saveToFile() error {
- if s.filePath == "" {
- return nil
- }
-
- s.mu.RLock()
- sessionsCopy := make(map[string]*Session)
- maps.Copy(sessionsCopy, s.sessions)
- s.mu.RUnlock()
-
- data := persistedData{
- Sessions: sessionsCopy,
- SavedAt: time.Now(),
- }
-
- jsonData, err := json.MarshalIndent(data, "", " ")
- if err != nil {
- return err
- }
-
- return os.WriteFile(s.filePath, jsonData, 0600)
-}
-
-func (s *SessionStore) startPeriodicSave() {
- go func() {
- ticker := time.NewTicker(s.saveInterval)
- defer ticker.Stop()
-
- for {
- select {
- case <-ticker.C:
- s.Cleanup()
- s.saveToFile()
- case <-s.stopChan:
- s.saveToFile()
- return
- }
- }
- }()
-}
-
-func (s *SessionStore) Close() error {
- close(s.stopChan)
- return s.saveToFile()
-}
-
-func (s *SessionStore) Stats() (total, active int) {
- s.mu.RLock()
- defer s.mu.RUnlock()
-
- now := time.Now()
- total = len(s.sessions)
-
- for _, session := range s.sessions {
- if now.Before(session.ExpiresAt) {
- active++
- }
- }
-
- return
-}
diff --git a/internal/csrf/csrf.go b/internal/csrf/csrf.go
index c24dfe7..79ac51b 100644
--- a/internal/csrf/csrf.go
+++ b/internal/csrf/csrf.go
@@ -1,3 +1,23 @@
+// Package csrf provides Cross-Site Request Forgery (CSRF) protection
+// with session-based token storage and form helpers.
+//
+// # Basic Usage
+//
+// // Generate token and store in session
+// token := csrf.GenerateToken(ctx, authManager)
+//
+// // In templates - generate hidden input field
+// hiddenField := csrf.HiddenField(ctx, authManager)
+//
+// // Verify form submission
+// if !csrf.ValidateToken(ctx, authManager, formToken) {
+// // Handle CSRF validation failure
+// }
+//
+// # Middleware Integration
+//
+// // Add CSRF middleware to protected routes
+// r.Use(middleware.CSRF(authManager))
package csrf
import (
@@ -9,6 +29,7 @@ import (
"dk/internal/auth"
"dk/internal/router"
+ "dk/internal/session"
"github.com/valyala/fasthttp"
)
@@ -22,9 +43,9 @@ const (
)
// GetCurrentSession retrieves the session from context (mirrors middleware function)
-func GetCurrentSession(ctx router.Ctx) *auth.Session {
- if session, ok := ctx.UserValue(SessionCtxKey).(*auth.Session); ok {
- return session
+func GetCurrentSession(ctx router.Ctx) *session.Session {
+ if sess, ok := ctx.UserValue(SessionCtxKey).(*session.Session); ok {
+ return sess
}
return nil
}
@@ -97,23 +118,17 @@ func ValidateToken(ctx router.Ctx, authManager *auth.AuthManager, submittedToken
}
// StoreToken saves a CSRF token in the session
-func StoreToken(session *auth.Session, token string) {
- if session.Data == nil {
- session.Data = make(map[string]any)
- }
- session.Data[SessionKey] = token
+func StoreToken(sess *session.Session, token string) {
+ sess.Set(SessionKey, token)
}
// GetStoredToken retrieves the CSRF token from session
-func GetStoredToken(session *auth.Session) string {
- if session.Data == nil {
- return ""
+func GetStoredToken(sess *session.Session) string {
+ if token, ok := sess.Get(SessionKey); ok {
+ if tokenStr, ok := token.(string); ok {
+ return tokenStr
+ }
}
-
- if token, ok := session.Data[SessionKey].(string); ok {
- return token
- }
-
return ""
}
diff --git a/internal/csrf/csrf_test.go b/internal/csrf/csrf_test.go
index 2513ce5..c4d04e8 100644
--- a/internal/csrf/csrf_test.go
+++ b/internal/csrf/csrf_test.go
@@ -4,14 +4,13 @@ import (
"testing"
"time"
- "dk/internal/auth"
+ "dk/internal/session"
"github.com/valyala/fasthttp"
)
func TestGenerateToken(t *testing.T) {
- // Create a mock session
- session := &auth.Session{
+ sess := &session.Session{
ID: "test-session",
UserID: 1,
Username: "testuser",
@@ -22,27 +21,23 @@ func TestGenerateToken(t *testing.T) {
Data: make(map[string]any),
}
- // Create mock context
ctx := &fasthttp.RequestCtx{}
- ctx.SetUserValue(SessionCtxKey, session)
+ ctx.SetUserValue(SessionCtxKey, sess)
- // Generate token
token := GenerateToken(ctx, nil)
if token == "" {
t.Error("Expected non-empty token")
}
- // Check that token was stored in session
- storedToken := GetStoredToken(session)
+ storedToken := GetStoredToken(sess)
if storedToken != token {
t.Errorf("Expected stored token %s, got %s", token, storedToken)
}
}
func TestValidateToken(t *testing.T) {
- // Create session with token
- session := &auth.Session{
+ sess := &session.Session{
ID: "test-session",
UserID: 1,
Username: "testuser",
@@ -51,19 +46,16 @@ func TestValidateToken(t *testing.T) {
}
ctx := &fasthttp.RequestCtx{}
- ctx.SetUserValue(SessionCtxKey, session)
+ ctx.SetUserValue(SessionCtxKey, sess)
- // Valid token should pass
if !ValidateToken(ctx, nil, "test-token") {
t.Error("Expected valid token to pass validation")
}
- // Invalid token should fail
if ValidateToken(ctx, nil, "wrong-token") {
t.Error("Expected invalid token to fail validation")
}
- // Empty token should fail
if ValidateToken(ctx, nil, "") {
t.Error("Expected empty token to fail validation")
}
@@ -72,14 +64,13 @@ func TestValidateToken(t *testing.T) {
func TestValidateTokenNoSession(t *testing.T) {
ctx := &fasthttp.RequestCtx{}
- // No session should fail validation
if ValidateToken(ctx, nil, "any-token") {
t.Error("Expected validation to fail with no session")
}
}
func TestHiddenField(t *testing.T) {
- session := &auth.Session{
+ sess := &session.Session{
ID: "test-session",
UserID: 1,
Username: "testuser",
@@ -88,7 +79,7 @@ func TestHiddenField(t *testing.T) {
}
ctx := &fasthttp.RequestCtx{}
- ctx.SetUserValue(SessionCtxKey, session)
+ ctx.SetUserValue(SessionCtxKey, sess)
field := HiddenField(ctx, nil)
expected := ``
@@ -102,13 +93,13 @@ func TestHiddenFieldNoSession(t *testing.T) {
ctx := &fasthttp.RequestCtx{}
field := HiddenField(ctx, nil)
- if field != "" {
- t.Errorf("Expected empty field with no session, got %s", field)
+ if field == "" {
+ t.Error("Expected non-empty field for guest user with cookie-based token")
}
}
func TestTokenMeta(t *testing.T) {
- session := &auth.Session{
+ sess := &session.Session{
ID: "test-session",
UserID: 1,
Username: "testuser",
@@ -117,7 +108,7 @@ func TestTokenMeta(t *testing.T) {
}
ctx := &fasthttp.RequestCtx{}
- ctx.SetUserValue(SessionCtxKey, session)
+ ctx.SetUserValue(SessionCtxKey, sess)
meta := TokenMeta(ctx, nil)
expected := ``
@@ -128,30 +119,30 @@ func TestTokenMeta(t *testing.T) {
}
func TestStoreAndGetToken(t *testing.T) {
- session := &auth.Session{
+ sess := &session.Session{
Data: make(map[string]any),
}
token := "test-token"
- StoreToken(session, token)
+ StoreToken(sess, token)
- retrieved := GetStoredToken(session)
+ retrieved := GetStoredToken(sess)
if retrieved != token {
t.Errorf("Expected %s, got %s", token, retrieved)
}
}
func TestGetStoredTokenNoData(t *testing.T) {
- session := &auth.Session{}
+ sess := &session.Session{}
- token := GetStoredToken(session)
+ token := GetStoredToken(sess)
if token != "" {
t.Errorf("Expected empty token, got %s", token)
}
}
func TestValidateFormToken(t *testing.T) {
- session := &auth.Session{
+ sess := &session.Session{
ID: "test-session",
UserID: 1,
Username: "testuser",
@@ -160,16 +151,14 @@ func TestValidateFormToken(t *testing.T) {
}
ctx := &fasthttp.RequestCtx{}
- ctx.SetUserValue(SessionCtxKey, session)
+ ctx.SetUserValue(SessionCtxKey, sess)
- // Set form data
ctx.PostArgs().Set(TokenFieldName, "test-token")
if !ValidateFormToken(ctx, nil) {
t.Error("Expected form token validation to pass")
}
- // Test with wrong token
ctx.PostArgs().Set(TokenFieldName, "wrong-token")
if ValidateFormToken(ctx, nil) {
diff --git a/internal/csrf/doc.go b/internal/csrf/doc.go
deleted file mode 100644
index 6bc6cab..0000000
--- a/internal/csrf/doc.go
+++ /dev/null
@@ -1,29 +0,0 @@
-// Package csrf provides Cross-Site Request Forgery (CSRF) protection
-// with session-based token storage and form helpers.
-//
-// # Basic Usage
-//
-// // Generate token and store in session
-// token := csrf.GenerateToken(ctx, authManager)
-//
-// // In templates - generate hidden input field
-// hiddenField := csrf.HiddenField(ctx, authManager)
-//
-// // Verify form submission
-// if !csrf.ValidateToken(ctx, authManager, formToken) {
-// // Handle CSRF validation failure
-// }
-//
-// # Middleware Integration
-//
-// // Add CSRF middleware to protected routes
-// r.Use(middleware.CSRF(authManager))
-//
-// # Security Features
-//
-// - Cryptographically secure token generation
-// - Session-based token storage and validation
-// - Automatic token rotation on successful validation
-// - Protection against timing attacks with constant-time comparison
-// - Integration with existing authentication system
-package csrf
\ No newline at end of file
diff --git a/internal/middleware/auth.go b/internal/middleware/auth.go
index 9e401f8..b3907c2 100644
--- a/internal/middleware/auth.go
+++ b/internal/middleware/auth.go
@@ -3,30 +3,26 @@ package middleware
import (
"dk/internal/auth"
"dk/internal/router"
+ "dk/internal/session"
"dk/internal/users"
"github.com/valyala/fasthttp"
)
-// Auth creates an authentication middleware
func Auth(authManager *auth.AuthManager) router.Middleware {
return func(next router.Handler) router.Handler {
return func(ctx router.Ctx, params []string) {
sessionID := auth.GetSessionCookie(ctx)
if sessionID != "" {
- if session, exists := authManager.GetSession(sessionID); exists {
- // Update session activity
+ if sess, exists := authManager.GetSession(sessionID); exists {
authManager.UpdateSession(sessionID)
- // Get the full user object
- user, err := users.Find(session.UserID)
+ user, err := users.Find(sess.UserID)
if err == nil && user != nil {
- // Store session and user info in context
- ctx.SetUserValue("session", session)
+ ctx.SetUserValue("session", sess)
ctx.SetUserValue("user", user)
- // Refresh the cookie
auth.SetSessionCookie(ctx, sessionID)
}
}
@@ -37,7 +33,6 @@ func Auth(authManager *auth.AuthManager) router.Middleware {
}
}
-// RequireAuth enforces authentication - redirect defaults to "/login"
func RequireAuth(paths ...string) router.Middleware {
redirect := "/login"
if len(paths) > 0 && paths[0] != "" {
@@ -56,7 +51,6 @@ func RequireAuth(paths ...string) router.Middleware {
}
}
-// RequireGuest enforces no authentication - redirect defaults to "/"
func RequireGuest(paths ...string) router.Middleware {
redirect := "/"
if len(paths) > 0 && paths[0] != "" {
@@ -74,13 +68,11 @@ func RequireGuest(paths ...string) router.Middleware {
}
}
-// IsAuthenticated checks if the current request has a valid session
func IsAuthenticated(ctx router.Ctx) bool {
_, exists := ctx.UserValue("user").(*users.User)
return exists
}
-// GetCurrentUser returns the current authenticated user, or nil if not authenticated
func GetCurrentUser(ctx router.Ctx) *users.User {
if user, ok := ctx.UserValue("user").(*users.User); ok {
return user
@@ -88,25 +80,21 @@ func GetCurrentUser(ctx router.Ctx) *users.User {
return nil
}
-// GetCurrentSession returns the current session, or nil if not authenticated
-func GetCurrentSession(ctx router.Ctx) *auth.Session {
- if session, ok := ctx.UserValue("session").(*auth.Session); ok {
- return session
+func GetCurrentSession(ctx router.Ctx) *session.Session {
+ if sess, ok := ctx.UserValue("session").(*session.Session); ok {
+ return sess
}
return nil
}
-// Login creates a session and sets the cookie
func Login(ctx router.Ctx, authManager *auth.AuthManager, user *users.User) {
- session := authManager.CreateSession(user)
- auth.SetSessionCookie(ctx, session.ID)
+ sess := authManager.CreateSession(user)
+ auth.SetSessionCookie(ctx, sess.ID)
- // Set in context for immediate use
- ctx.SetUserValue("session", session)
+ ctx.SetUserValue("session", sess)
ctx.SetUserValue("user", user)
}
-// Logout destroys the session and clears the cookie
func Logout(ctx router.Ctx, authManager *auth.AuthManager) {
sessionID := auth.GetSessionCookie(ctx)
if sessionID != "" {
@@ -115,7 +103,6 @@ func Logout(ctx router.Ctx, authManager *auth.AuthManager) {
auth.DeleteSessionCookie(ctx)
- // Clear from context
ctx.SetUserValue("session", nil)
ctx.SetUserValue("user", nil)
-}
+}
\ No newline at end of file
diff --git a/internal/session/flash.go b/internal/session/flash.go
new file mode 100644
index 0000000..9c96036
--- /dev/null
+++ b/internal/session/flash.go
@@ -0,0 +1,56 @@
+package session
+
+type FlashMessage struct {
+ Type string `json:"type"`
+ Message string `json:"message"`
+}
+
+func (s *Session) SetFlash(key string, value any) {
+ if s.Data == nil {
+ s.Data = make(map[string]any)
+ }
+
+ flashData, ok := s.Data["_flash"].(map[string]any)
+ if !ok {
+ flashData = make(map[string]any)
+ }
+ flashData[key] = value
+ s.Data["_flash"] = flashData
+}
+
+func (s *Session) GetFlash(key string) (any, bool) {
+ if s.Data == nil {
+ return nil, false
+ }
+
+ flashData, ok := s.Data["_flash"].(map[string]any)
+ if !ok {
+ return nil, false
+ }
+
+ value, exists := flashData[key]
+ if exists {
+ delete(flashData, key)
+ if len(flashData) == 0 {
+ delete(s.Data, "_flash")
+ } else {
+ s.Data["_flash"] = flashData
+ }
+ }
+
+ return value, exists
+}
+
+func (s *Session) GetAllFlash() map[string]any {
+ if s.Data == nil {
+ return nil
+ }
+
+ flashData, ok := s.Data["_flash"].(map[string]any)
+ if !ok {
+ return nil
+ }
+
+ delete(s.Data, "_flash")
+ return flashData
+}
\ No newline at end of file
diff --git a/internal/session/session.go b/internal/session/session.go
new file mode 100644
index 0000000..96fbc2a
--- /dev/null
+++ b/internal/session/session.go
@@ -0,0 +1,74 @@
+// Package session provides session management functionality.
+// It includes session storage, flash messages, and data persistence.
+package session
+
+import (
+ "crypto/rand"
+ "encoding/hex"
+ "time"
+)
+
+const (
+ DefaultExpiration = 24 * time.Hour
+ IDLength = 32
+)
+
+type Session struct {
+ ID string `json:"-"`
+ UserID int `json:"user_id"`
+ Username string `json:"username"`
+ Email string `json:"email"`
+ CreatedAt time.Time `json:"created_at"`
+ ExpiresAt time.Time `json:"expires_at"`
+ LastSeen time.Time `json:"last_seen"`
+ Data map[string]any `json:"data,omitempty"`
+}
+
+func New(userID int, username, email string) *Session {
+ return &Session{
+ ID: generateID(),
+ UserID: userID,
+ Username: username,
+ Email: email,
+ CreatedAt: time.Now(),
+ ExpiresAt: time.Now().Add(DefaultExpiration),
+ LastSeen: time.Now(),
+ Data: make(map[string]any),
+ }
+}
+
+func (s *Session) IsExpired() bool {
+ return time.Now().After(s.ExpiresAt)
+}
+
+func (s *Session) Touch() {
+ s.LastSeen = time.Now()
+ s.ExpiresAt = time.Now().Add(DefaultExpiration)
+}
+
+func (s *Session) Set(key string, value any) {
+ if s.Data == nil {
+ s.Data = make(map[string]any)
+ }
+ s.Data[key] = value
+}
+
+func (s *Session) Get(key string) (any, bool) {
+ if s.Data == nil {
+ return nil, false
+ }
+ value, exists := s.Data[key]
+ return value, exists
+}
+
+func (s *Session) Delete(key string) {
+ if s.Data != nil {
+ delete(s.Data, key)
+ }
+}
+
+func generateID() string {
+ bytes := make([]byte, IDLength)
+ rand.Read(bytes)
+ return hex.EncodeToString(bytes)
+}
diff --git a/internal/session/store.go b/internal/session/store.go
new file mode 100644
index 0000000..6a5a410
--- /dev/null
+++ b/internal/session/store.go
@@ -0,0 +1,161 @@
+package session
+
+import (
+ "encoding/json"
+ "maps"
+ "os"
+ "sync"
+ "time"
+)
+
+type Store struct {
+ mu sync.RWMutex
+ sessions map[string]*Session
+ filePath string
+ saveInterval time.Duration
+ stopChan chan struct{}
+}
+
+type persistedData struct {
+ Sessions map[string]*Session `json:"sessions"`
+ SavedAt time.Time `json:"saved_at"`
+}
+
+func NewStore(filePath string) *Store {
+ store := &Store{
+ sessions: make(map[string]*Session),
+ filePath: filePath,
+ saveInterval: 5 * time.Minute,
+ stopChan: make(chan struct{}),
+ }
+
+ store.loadFromFile()
+ store.startPeriodicSave()
+
+ return store
+}
+
+func (s *Store) Save(session *Session) {
+ s.mu.Lock()
+ defer s.mu.Unlock()
+ s.sessions[session.ID] = session
+}
+
+func (s *Store) Get(sessionID string) (*Session, bool) {
+ s.mu.RLock()
+ defer s.mu.RUnlock()
+
+ session, exists := s.sessions[sessionID]
+ if !exists {
+ return nil, false
+ }
+
+ if session.IsExpired() {
+ return nil, false
+ }
+
+ return session, true
+}
+
+func (s *Store) Delete(sessionID string) {
+ s.mu.Lock()
+ defer s.mu.Unlock()
+ delete(s.sessions, sessionID)
+}
+
+func (s *Store) Cleanup() {
+ s.mu.Lock()
+ defer s.mu.Unlock()
+
+ for id, session := range s.sessions {
+ if session.IsExpired() {
+ delete(s.sessions, id)
+ }
+ }
+}
+
+func (s *Store) Stats() (total, active int) {
+ s.mu.RLock()
+ defer s.mu.RUnlock()
+
+ total = len(s.sessions)
+ for _, session := range s.sessions {
+ if !session.IsExpired() {
+ active++
+ }
+ }
+
+ return
+}
+
+func (s *Store) loadFromFile() {
+ if s.filePath == "" {
+ return
+ }
+
+ data, err := os.ReadFile(s.filePath)
+ if err != nil {
+ return
+ }
+
+ var persisted persistedData
+ if err := json.Unmarshal(data, &persisted); err != nil {
+ return
+ }
+
+ s.mu.Lock()
+ defer s.mu.Unlock()
+
+ for id, session := range persisted.Sessions {
+ if !session.IsExpired() {
+ session.ID = id
+ s.sessions[id] = session
+ }
+ }
+}
+
+func (s *Store) saveToFile() error {
+ if s.filePath == "" {
+ return nil
+ }
+
+ s.mu.RLock()
+ sessionsCopy := make(map[string]*Session, len(s.sessions))
+ maps.Copy(sessionsCopy, s.sessions)
+ s.mu.RUnlock()
+
+ data := persistedData{
+ Sessions: sessionsCopy,
+ SavedAt: time.Now(),
+ }
+
+ jsonData, err := json.MarshalIndent(data, "", " ")
+ if err != nil {
+ return err
+ }
+
+ return os.WriteFile(s.filePath, jsonData, 0600)
+}
+
+func (s *Store) startPeriodicSave() {
+ go func() {
+ ticker := time.NewTicker(s.saveInterval)
+ defer ticker.Stop()
+
+ for {
+ select {
+ case <-ticker.C:
+ s.Cleanup()
+ s.saveToFile()
+ case <-s.stopChan:
+ s.saveToFile()
+ return
+ }
+ }
+ }()
+}
+
+func (s *Store) Close() error {
+ close(s.stopChan)
+ return s.saveToFile()
+}