separate session into its own package, clean up some docs

This commit is contained in:
Sky Johnson 2025-08-11 13:21:07 -05:00
parent b8b77351d0
commit 4a73b7cc0d
12 changed files with 412 additions and 479 deletions

View File

@ -1,33 +1,30 @@
// Package auth provides authentication and session management functionality.
// It includes secure session storage with in-memory caching and JSON persistence,
// user authentication against the database, and secure cookie handling.
package auth package auth
import ( import (
"dk/internal/password" "dk/internal/password"
"dk/internal/session"
"dk/internal/users" "dk/internal/users"
) )
// Manager is the global singleton instance
var Manager *AuthManager var Manager *AuthManager
// AuthManager is a wrapper for the session store to add
// authentication tools over the store itself
type AuthManager struct { type AuthManager struct {
store *SessionStore store *session.Store
} }
// Init initializes the global auth manager (auth.Manager)
func Init(sessionsFilePath string) { func Init(sessionsFilePath string) {
Manager = &AuthManager{ Manager = &AuthManager{
store: NewSessionStore(sessionsFilePath), store: session.NewStore(sessionsFilePath),
} }
} }
// Authenticate checks for the usernaname or email, then verifies the plain password
// against the stored hash.
func (am *AuthManager) Authenticate(usernameOrEmail, plainPassword string) (*users.User, error) { func (am *AuthManager) Authenticate(usernameOrEmail, plainPassword string) (*users.User, error) {
var user *users.User var user *users.User
var err error var err error
// Try to find user by username first
user, err = users.GetByUsername(usernameOrEmail) user, err = users.GetByUsername(usernameOrEmail)
if err != nil { if err != nil {
user, err = users.GetByEmail(usernameOrEmail) user, err = users.GetByEmail(usernameOrEmail)
@ -47,16 +44,25 @@ func (am *AuthManager) Authenticate(usernameOrEmail, plainPassword string) (*use
return user, nil return user, nil
} }
func (am *AuthManager) CreateSession(user *users.User) *Session { func (am *AuthManager) CreateSession(user *users.User) *session.Session {
return am.store.Create(user.ID, user.Username, user.Email) sess := session.New(user.ID, user.Username, user.Email)
am.store.Save(sess)
return sess
} }
func (am *AuthManager) GetSession(sessionID string) (*Session, bool) { func (am *AuthManager) GetSession(sessionID string) (*session.Session, bool) {
return am.store.Get(sessionID) return am.store.Get(sessionID)
} }
func (am *AuthManager) UpdateSession(sessionID string) bool { func (am *AuthManager) UpdateSession(sessionID string) bool {
return am.store.Update(sessionID) sess, exists := am.store.Get(sessionID)
if !exists {
return false
}
sess.Touch()
am.store.Save(sess)
return true
} }
func (am *AuthManager) DeleteSession(sessionID string) { func (am *AuthManager) DeleteSession(sessionID string) {
@ -71,124 +77,6 @@ func (am *AuthManager) Close() error {
return am.store.Close() return am.store.Close()
} }
// SetFlash stores a flash message in the session that will be removed after retrieval
func (am *AuthManager) SetFlash(sessionID, key string, value any) bool {
session, exists := am.store.Get(sessionID)
if !exists {
return false
}
am.store.mu.Lock()
defer am.store.mu.Unlock()
if session.Data == nil {
session.Data = make(map[string]any)
}
// Store flash messages under a special key
flashData, ok := session.Data["_flash"].(map[string]any)
if !ok {
flashData = make(map[string]any)
}
flashData[key] = value
session.Data["_flash"] = flashData
return true
}
// GetFlash retrieves and removes a flash message from the session
func (am *AuthManager) GetFlash(sessionID, key string) (any, bool) {
session, exists := am.store.Get(sessionID)
if !exists {
return nil, false
}
am.store.mu.Lock()
defer am.store.mu.Unlock()
if session.Data == nil {
return nil, false
}
flashData, ok := session.Data["_flash"].(map[string]any)
if !ok {
return nil, false
}
value, exists := flashData[key]
if exists {
delete(flashData, key)
if len(flashData) == 0 {
delete(session.Data, "_flash")
} else {
session.Data["_flash"] = flashData
}
}
return value, exists
}
// GetAllFlash retrieves and removes all flash messages from the session
func (am *AuthManager) GetAllFlash(sessionID string) map[string]any {
session, exists := am.store.Get(sessionID)
if !exists {
return nil
}
am.store.mu.Lock()
defer am.store.mu.Unlock()
if session.Data == nil {
return nil
}
flashData, ok := session.Data["_flash"].(map[string]any)
if !ok {
return nil
}
// Remove flash data from session
delete(session.Data, "_flash")
return flashData
}
// SetSessionData stores arbitrary data in the session
func (am *AuthManager) SetSessionData(sessionID, key string, value any) bool {
session, exists := am.store.Get(sessionID)
if !exists {
return false
}
am.store.mu.Lock()
defer am.store.mu.Unlock()
if session.Data == nil {
session.Data = make(map[string]any)
}
session.Data[key] = value
return true
}
// GetSessionData retrieves data from the session
func (am *AuthManager) GetSessionData(sessionID, key string) (any, bool) {
session, exists := am.store.Get(sessionID)
if !exists {
return nil, false
}
am.store.mu.RLock()
defer am.store.mu.RUnlock()
if session.Data == nil {
return nil, false
}
value, exists := session.Data[key]
return value, exists
}
var ( var (
ErrInvalidCredentials = &AuthError{"invalid username/email or password"} ErrInvalidCredentials = &AuthError{"invalid username/email or password"}
ErrSessionNotFound = &AuthError{"session not found"} ErrSessionNotFound = &AuthError{"session not found"}

View File

@ -2,18 +2,21 @@ package auth
import ( import (
"dk/internal/cookies" "dk/internal/cookies"
"dk/internal/session"
"dk/internal/utils" "dk/internal/utils"
"time" "time"
"github.com/valyala/fasthttp" "github.com/valyala/fasthttp"
) )
const SessionCookieName = "dk_session"
func SetSessionCookie(ctx *fasthttp.RequestCtx, sessionID string) { func SetSessionCookie(ctx *fasthttp.RequestCtx, sessionID string) {
cookies.SetSecureCookie(ctx, cookies.CookieOptions{ cookies.SetSecureCookie(ctx, cookies.CookieOptions{
Name: SessionCookieName, Name: SessionCookieName,
Value: sessionID, Value: sessionID,
Path: "/", Path: "/",
Expires: time.Now().Add(DefaultExpiration), Expires: time.Now().Add(session.DefaultExpiration),
HTTPOnly: true, HTTPOnly: true,
Secure: utils.IsHTTPS(ctx), Secure: utils.IsHTTPS(ctx),
SameSite: "lax", SameSite: "lax",

View File

@ -1,4 +0,0 @@
// Package auth provides authentication and session management functionality.
// It includes secure session storage with in-memory caching and JSON persistence,
// user authentication against the database, and secure cookie handling.
package auth

View File

@ -2,46 +2,52 @@ package auth
import ( import (
"dk/internal/router" "dk/internal/router"
"dk/internal/session"
) )
// FlashMessage represents a flash message with type and content
type FlashMessage struct {
Type string `json:"type"` // "error", "success", "warning", "info"
Message string `json:"message"`
}
// SetFlashMessage sets a flash message for the current session
func SetFlashMessage(ctx router.Ctx, msgType, message string) bool { func SetFlashMessage(ctx router.Ctx, msgType, message string) bool {
sessionID := GetSessionCookie(ctx) sessionID := GetSessionCookie(ctx)
if sessionID == "" { if sessionID == "" {
return false return false
} }
return Manager.SetFlash(sessionID, "message", FlashMessage{ sess, exists := Manager.GetSession(sessionID)
if !exists {
return false
}
sess.SetFlash("message", session.FlashMessage{
Type: msgType, Type: msgType,
Message: message, Message: message,
}) })
Manager.store.Save(sess)
return true
} }
// GetFlashMessage retrieves and removes the flash message from the current session func GetFlashMessage(ctx router.Ctx) *session.FlashMessage {
func GetFlashMessage(ctx router.Ctx) *FlashMessage {
sessionID := GetSessionCookie(ctx) sessionID := GetSessionCookie(ctx)
if sessionID == "" { if sessionID == "" {
return nil return nil
} }
value, exists := Manager.GetFlash(sessionID, "message") sess, exists := Manager.GetSession(sessionID)
if !exists { if !exists {
return nil return nil
} }
if msg, ok := value.(FlashMessage); ok { value, exists := sess.GetFlash("message")
if !exists {
return nil
}
Manager.store.Save(sess)
if msg, ok := value.(session.FlashMessage); ok {
return &msg return &msg
} }
// Handle map[string]interface{} from JSON deserialization
if msgMap, ok := value.(map[string]interface{}); ok { if msgMap, ok := value.(map[string]interface{}); ok {
msg := &FlashMessage{} msg := &session.FlashMessage{}
if t, ok := msgMap["type"].(string); ok { if t, ok := msgMap["type"].(string); ok {
msg.Type = t msg.Type = t
} }
@ -54,36 +60,45 @@ func GetFlashMessage(ctx router.Ctx) *FlashMessage {
return nil return nil
} }
// SetFormData stores form data temporarily in the session (for repopulating forms after errors)
func SetFormData(ctx router.Ctx, data map[string]string) bool { func SetFormData(ctx router.Ctx, data map[string]string) bool {
sessionID := GetSessionCookie(ctx) sessionID := GetSessionCookie(ctx)
if sessionID == "" { if sessionID == "" {
return false return false
} }
return Manager.SetSessionData(sessionID, "form_data", data) sess, exists := Manager.GetSession(sessionID)
if !exists {
return false
}
sess.Set("form_data", data)
Manager.store.Save(sess)
return true
} }
// GetFormData retrieves and removes form data from the session
func GetFormData(ctx router.Ctx) map[string]string { func GetFormData(ctx router.Ctx) map[string]string {
sessionID := GetSessionCookie(ctx) sessionID := GetSessionCookie(ctx)
if sessionID == "" { if sessionID == "" {
return nil return nil
} }
value, exists := Manager.GetSessionData(sessionID, "form_data") sess, exists := Manager.GetSession(sessionID)
if !exists { if !exists {
return nil return nil
} }
// Clear form data after retrieval value, exists := sess.Get("form_data")
Manager.SetSessionData(sessionID, "form_data", nil) if !exists {
return nil
}
sess.Delete("form_data")
Manager.store.Save(sess)
if formData, ok := value.(map[string]string); ok { if formData, ok := value.(map[string]string); ok {
return formData return formData
} }
// Handle map[string]interface{} from JSON deserialization
if formMap, ok := value.(map[string]interface{}); ok { if formMap, ok := value.(map[string]interface{}); ok {
result := make(map[string]string) result := make(map[string]string)
for k, v := range formMap { for k, v := range formMap {

View File

@ -1,222 +0,0 @@
package auth
import (
"crypto/rand"
"encoding/hex"
"encoding/json"
"maps"
"os"
"sync"
"time"
)
const (
SessionCookieName = "dk_session"
DefaultExpiration = 24 * time.Hour
SessionIDLength = 32
)
type Session struct {
ID string `json:"-"` // Exclude from JSON since it's stored as the map key
UserID int `json:"user_id"`
Username string `json:"username"`
Email string `json:"email"`
CreatedAt time.Time `json:"created_at"`
ExpiresAt time.Time `json:"expires_at"`
LastSeen time.Time `json:"last_seen"`
Data map[string]any `json:"data,omitempty"` // For storing additional session data
}
type SessionStore struct {
mu sync.RWMutex
sessions map[string]*Session
filePath string
saveInterval time.Duration
stopChan chan struct{}
}
type persistedData struct {
Sessions map[string]*Session `json:"sessions"`
SavedAt time.Time `json:"saved_at"`
}
func NewSessionStore(filePath string) *SessionStore {
store := &SessionStore{
sessions: make(map[string]*Session),
filePath: filePath,
saveInterval: 5 * time.Minute,
stopChan: make(chan struct{}),
}
store.loadFromFile()
store.startPeriodicSave()
return store
}
func (s *SessionStore) generateSessionID() string {
bytes := make([]byte, SessionIDLength)
rand.Read(bytes)
return hex.EncodeToString(bytes)
}
func (s *SessionStore) Create(userID int, username, email string) *Session {
s.mu.Lock()
defer s.mu.Unlock()
session := &Session{
ID: s.generateSessionID(),
UserID: userID,
Username: username,
Email: email,
CreatedAt: time.Now(),
ExpiresAt: time.Now().Add(DefaultExpiration),
LastSeen: time.Now(),
}
s.sessions[session.ID] = session
return session
}
func (s *SessionStore) Get(sessionID string) (*Session, bool) {
s.mu.RLock()
defer s.mu.RUnlock()
session, exists := s.sessions[sessionID]
if !exists {
return nil, false
}
if time.Now().After(session.ExpiresAt) {
delete(s.sessions, sessionID)
return nil, false
}
return session, true
}
func (s *SessionStore) Update(sessionID string) bool {
s.mu.Lock()
defer s.mu.Unlock()
session, exists := s.sessions[sessionID]
if !exists {
return false
}
if time.Now().After(session.ExpiresAt) {
delete(s.sessions, sessionID)
return false
}
session.LastSeen = time.Now()
session.ExpiresAt = time.Now().Add(DefaultExpiration)
return true
}
func (s *SessionStore) Delete(sessionID string) {
s.mu.Lock()
defer s.mu.Unlock()
delete(s.sessions, sessionID)
}
func (s *SessionStore) Cleanup() {
s.mu.Lock()
defer s.mu.Unlock()
now := time.Now()
for id, session := range s.sessions {
if now.After(session.ExpiresAt) {
delete(s.sessions, id)
}
}
}
func (s *SessionStore) loadFromFile() {
if s.filePath == "" {
return
}
data, err := os.ReadFile(s.filePath)
if err != nil {
return // File might not exist yet
}
var persisted persistedData
if err := json.Unmarshal(data, &persisted); err != nil {
return
}
s.mu.Lock()
defer s.mu.Unlock()
now := time.Now()
for id, session := range persisted.Sessions {
if now.Before(session.ExpiresAt) {
s.sessions[id] = session
}
}
}
func (s *SessionStore) saveToFile() error {
if s.filePath == "" {
return nil
}
s.mu.RLock()
sessionsCopy := make(map[string]*Session)
maps.Copy(sessionsCopy, s.sessions)
s.mu.RUnlock()
data := persistedData{
Sessions: sessionsCopy,
SavedAt: time.Now(),
}
jsonData, err := json.MarshalIndent(data, "", " ")
if err != nil {
return err
}
return os.WriteFile(s.filePath, jsonData, 0600)
}
func (s *SessionStore) startPeriodicSave() {
go func() {
ticker := time.NewTicker(s.saveInterval)
defer ticker.Stop()
for {
select {
case <-ticker.C:
s.Cleanup()
s.saveToFile()
case <-s.stopChan:
s.saveToFile()
return
}
}
}()
}
func (s *SessionStore) Close() error {
close(s.stopChan)
return s.saveToFile()
}
func (s *SessionStore) Stats() (total, active int) {
s.mu.RLock()
defer s.mu.RUnlock()
now := time.Now()
total = len(s.sessions)
for _, session := range s.sessions {
if now.Before(session.ExpiresAt) {
active++
}
}
return
}

View File

@ -1,3 +1,23 @@
// Package csrf provides Cross-Site Request Forgery (CSRF) protection
// with session-based token storage and form helpers.
//
// # Basic Usage
//
// // Generate token and store in session
// token := csrf.GenerateToken(ctx, authManager)
//
// // In templates - generate hidden input field
// hiddenField := csrf.HiddenField(ctx, authManager)
//
// // Verify form submission
// if !csrf.ValidateToken(ctx, authManager, formToken) {
// // Handle CSRF validation failure
// }
//
// # Middleware Integration
//
// // Add CSRF middleware to protected routes
// r.Use(middleware.CSRF(authManager))
package csrf package csrf
import ( import (
@ -9,6 +29,7 @@ import (
"dk/internal/auth" "dk/internal/auth"
"dk/internal/router" "dk/internal/router"
"dk/internal/session"
"github.com/valyala/fasthttp" "github.com/valyala/fasthttp"
) )
@ -22,9 +43,9 @@ const (
) )
// GetCurrentSession retrieves the session from context (mirrors middleware function) // GetCurrentSession retrieves the session from context (mirrors middleware function)
func GetCurrentSession(ctx router.Ctx) *auth.Session { func GetCurrentSession(ctx router.Ctx) *session.Session {
if session, ok := ctx.UserValue(SessionCtxKey).(*auth.Session); ok { if sess, ok := ctx.UserValue(SessionCtxKey).(*session.Session); ok {
return session return sess
} }
return nil return nil
} }
@ -97,23 +118,17 @@ func ValidateToken(ctx router.Ctx, authManager *auth.AuthManager, submittedToken
} }
// StoreToken saves a CSRF token in the session // StoreToken saves a CSRF token in the session
func StoreToken(session *auth.Session, token string) { func StoreToken(sess *session.Session, token string) {
if session.Data == nil { sess.Set(SessionKey, token)
session.Data = make(map[string]any)
}
session.Data[SessionKey] = token
} }
// GetStoredToken retrieves the CSRF token from session // GetStoredToken retrieves the CSRF token from session
func GetStoredToken(session *auth.Session) string { func GetStoredToken(sess *session.Session) string {
if session.Data == nil { if token, ok := sess.Get(SessionKey); ok {
return "" if tokenStr, ok := token.(string); ok {
return tokenStr
} }
if token, ok := session.Data[SessionKey].(string); ok {
return token
} }
return "" return ""
} }

View File

@ -4,14 +4,13 @@ import (
"testing" "testing"
"time" "time"
"dk/internal/auth" "dk/internal/session"
"github.com/valyala/fasthttp" "github.com/valyala/fasthttp"
) )
func TestGenerateToken(t *testing.T) { func TestGenerateToken(t *testing.T) {
// Create a mock session sess := &session.Session{
session := &auth.Session{
ID: "test-session", ID: "test-session",
UserID: 1, UserID: 1,
Username: "testuser", Username: "testuser",
@ -22,27 +21,23 @@ func TestGenerateToken(t *testing.T) {
Data: make(map[string]any), Data: make(map[string]any),
} }
// Create mock context
ctx := &fasthttp.RequestCtx{} ctx := &fasthttp.RequestCtx{}
ctx.SetUserValue(SessionCtxKey, session) ctx.SetUserValue(SessionCtxKey, sess)
// Generate token
token := GenerateToken(ctx, nil) token := GenerateToken(ctx, nil)
if token == "" { if token == "" {
t.Error("Expected non-empty token") t.Error("Expected non-empty token")
} }
// Check that token was stored in session storedToken := GetStoredToken(sess)
storedToken := GetStoredToken(session)
if storedToken != token { if storedToken != token {
t.Errorf("Expected stored token %s, got %s", token, storedToken) t.Errorf("Expected stored token %s, got %s", token, storedToken)
} }
} }
func TestValidateToken(t *testing.T) { func TestValidateToken(t *testing.T) {
// Create session with token sess := &session.Session{
session := &auth.Session{
ID: "test-session", ID: "test-session",
UserID: 1, UserID: 1,
Username: "testuser", Username: "testuser",
@ -51,19 +46,16 @@ func TestValidateToken(t *testing.T) {
} }
ctx := &fasthttp.RequestCtx{} ctx := &fasthttp.RequestCtx{}
ctx.SetUserValue(SessionCtxKey, session) ctx.SetUserValue(SessionCtxKey, sess)
// Valid token should pass
if !ValidateToken(ctx, nil, "test-token") { if !ValidateToken(ctx, nil, "test-token") {
t.Error("Expected valid token to pass validation") t.Error("Expected valid token to pass validation")
} }
// Invalid token should fail
if ValidateToken(ctx, nil, "wrong-token") { if ValidateToken(ctx, nil, "wrong-token") {
t.Error("Expected invalid token to fail validation") t.Error("Expected invalid token to fail validation")
} }
// Empty token should fail
if ValidateToken(ctx, nil, "") { if ValidateToken(ctx, nil, "") {
t.Error("Expected empty token to fail validation") t.Error("Expected empty token to fail validation")
} }
@ -72,14 +64,13 @@ func TestValidateToken(t *testing.T) {
func TestValidateTokenNoSession(t *testing.T) { func TestValidateTokenNoSession(t *testing.T) {
ctx := &fasthttp.RequestCtx{} ctx := &fasthttp.RequestCtx{}
// No session should fail validation
if ValidateToken(ctx, nil, "any-token") { if ValidateToken(ctx, nil, "any-token") {
t.Error("Expected validation to fail with no session") t.Error("Expected validation to fail with no session")
} }
} }
func TestHiddenField(t *testing.T) { func TestHiddenField(t *testing.T) {
session := &auth.Session{ sess := &session.Session{
ID: "test-session", ID: "test-session",
UserID: 1, UserID: 1,
Username: "testuser", Username: "testuser",
@ -88,7 +79,7 @@ func TestHiddenField(t *testing.T) {
} }
ctx := &fasthttp.RequestCtx{} ctx := &fasthttp.RequestCtx{}
ctx.SetUserValue(SessionCtxKey, session) ctx.SetUserValue(SessionCtxKey, sess)
field := HiddenField(ctx, nil) field := HiddenField(ctx, nil)
expected := `<input type="hidden" name="_csrf_token" value="test-token">` expected := `<input type="hidden" name="_csrf_token" value="test-token">`
@ -102,13 +93,13 @@ func TestHiddenFieldNoSession(t *testing.T) {
ctx := &fasthttp.RequestCtx{} ctx := &fasthttp.RequestCtx{}
field := HiddenField(ctx, nil) field := HiddenField(ctx, nil)
if field != "" { if field == "" {
t.Errorf("Expected empty field with no session, got %s", field) t.Error("Expected non-empty field for guest user with cookie-based token")
} }
} }
func TestTokenMeta(t *testing.T) { func TestTokenMeta(t *testing.T) {
session := &auth.Session{ sess := &session.Session{
ID: "test-session", ID: "test-session",
UserID: 1, UserID: 1,
Username: "testuser", Username: "testuser",
@ -117,7 +108,7 @@ func TestTokenMeta(t *testing.T) {
} }
ctx := &fasthttp.RequestCtx{} ctx := &fasthttp.RequestCtx{}
ctx.SetUserValue(SessionCtxKey, session) ctx.SetUserValue(SessionCtxKey, sess)
meta := TokenMeta(ctx, nil) meta := TokenMeta(ctx, nil)
expected := `<meta name="csrf-token" content="test-token">` expected := `<meta name="csrf-token" content="test-token">`
@ -128,30 +119,30 @@ func TestTokenMeta(t *testing.T) {
} }
func TestStoreAndGetToken(t *testing.T) { func TestStoreAndGetToken(t *testing.T) {
session := &auth.Session{ sess := &session.Session{
Data: make(map[string]any), Data: make(map[string]any),
} }
token := "test-token" token := "test-token"
StoreToken(session, token) StoreToken(sess, token)
retrieved := GetStoredToken(session) retrieved := GetStoredToken(sess)
if retrieved != token { if retrieved != token {
t.Errorf("Expected %s, got %s", token, retrieved) t.Errorf("Expected %s, got %s", token, retrieved)
} }
} }
func TestGetStoredTokenNoData(t *testing.T) { func TestGetStoredTokenNoData(t *testing.T) {
session := &auth.Session{} sess := &session.Session{}
token := GetStoredToken(session) token := GetStoredToken(sess)
if token != "" { if token != "" {
t.Errorf("Expected empty token, got %s", token) t.Errorf("Expected empty token, got %s", token)
} }
} }
func TestValidateFormToken(t *testing.T) { func TestValidateFormToken(t *testing.T) {
session := &auth.Session{ sess := &session.Session{
ID: "test-session", ID: "test-session",
UserID: 1, UserID: 1,
Username: "testuser", Username: "testuser",
@ -160,16 +151,14 @@ func TestValidateFormToken(t *testing.T) {
} }
ctx := &fasthttp.RequestCtx{} ctx := &fasthttp.RequestCtx{}
ctx.SetUserValue(SessionCtxKey, session) ctx.SetUserValue(SessionCtxKey, sess)
// Set form data
ctx.PostArgs().Set(TokenFieldName, "test-token") ctx.PostArgs().Set(TokenFieldName, "test-token")
if !ValidateFormToken(ctx, nil) { if !ValidateFormToken(ctx, nil) {
t.Error("Expected form token validation to pass") t.Error("Expected form token validation to pass")
} }
// Test with wrong token
ctx.PostArgs().Set(TokenFieldName, "wrong-token") ctx.PostArgs().Set(TokenFieldName, "wrong-token")
if ValidateFormToken(ctx, nil) { if ValidateFormToken(ctx, nil) {

View File

@ -1,29 +0,0 @@
// Package csrf provides Cross-Site Request Forgery (CSRF) protection
// with session-based token storage and form helpers.
//
// # Basic Usage
//
// // Generate token and store in session
// token := csrf.GenerateToken(ctx, authManager)
//
// // In templates - generate hidden input field
// hiddenField := csrf.HiddenField(ctx, authManager)
//
// // Verify form submission
// if !csrf.ValidateToken(ctx, authManager, formToken) {
// // Handle CSRF validation failure
// }
//
// # Middleware Integration
//
// // Add CSRF middleware to protected routes
// r.Use(middleware.CSRF(authManager))
//
// # Security Features
//
// - Cryptographically secure token generation
// - Session-based token storage and validation
// - Automatic token rotation on successful validation
// - Protection against timing attacks with constant-time comparison
// - Integration with existing authentication system
package csrf

View File

@ -3,30 +3,26 @@ package middleware
import ( import (
"dk/internal/auth" "dk/internal/auth"
"dk/internal/router" "dk/internal/router"
"dk/internal/session"
"dk/internal/users" "dk/internal/users"
"github.com/valyala/fasthttp" "github.com/valyala/fasthttp"
) )
// Auth creates an authentication middleware
func Auth(authManager *auth.AuthManager) router.Middleware { func Auth(authManager *auth.AuthManager) router.Middleware {
return func(next router.Handler) router.Handler { return func(next router.Handler) router.Handler {
return func(ctx router.Ctx, params []string) { return func(ctx router.Ctx, params []string) {
sessionID := auth.GetSessionCookie(ctx) sessionID := auth.GetSessionCookie(ctx)
if sessionID != "" { if sessionID != "" {
if session, exists := authManager.GetSession(sessionID); exists { if sess, exists := authManager.GetSession(sessionID); exists {
// Update session activity
authManager.UpdateSession(sessionID) authManager.UpdateSession(sessionID)
// Get the full user object user, err := users.Find(sess.UserID)
user, err := users.Find(session.UserID)
if err == nil && user != nil { if err == nil && user != nil {
// Store session and user info in context ctx.SetUserValue("session", sess)
ctx.SetUserValue("session", session)
ctx.SetUserValue("user", user) ctx.SetUserValue("user", user)
// Refresh the cookie
auth.SetSessionCookie(ctx, sessionID) auth.SetSessionCookie(ctx, sessionID)
} }
} }
@ -37,7 +33,6 @@ func Auth(authManager *auth.AuthManager) router.Middleware {
} }
} }
// RequireAuth enforces authentication - redirect defaults to "/login"
func RequireAuth(paths ...string) router.Middleware { func RequireAuth(paths ...string) router.Middleware {
redirect := "/login" redirect := "/login"
if len(paths) > 0 && paths[0] != "" { if len(paths) > 0 && paths[0] != "" {
@ -56,7 +51,6 @@ func RequireAuth(paths ...string) router.Middleware {
} }
} }
// RequireGuest enforces no authentication - redirect defaults to "/"
func RequireGuest(paths ...string) router.Middleware { func RequireGuest(paths ...string) router.Middleware {
redirect := "/" redirect := "/"
if len(paths) > 0 && paths[0] != "" { if len(paths) > 0 && paths[0] != "" {
@ -74,13 +68,11 @@ func RequireGuest(paths ...string) router.Middleware {
} }
} }
// IsAuthenticated checks if the current request has a valid session
func IsAuthenticated(ctx router.Ctx) bool { func IsAuthenticated(ctx router.Ctx) bool {
_, exists := ctx.UserValue("user").(*users.User) _, exists := ctx.UserValue("user").(*users.User)
return exists return exists
} }
// GetCurrentUser returns the current authenticated user, or nil if not authenticated
func GetCurrentUser(ctx router.Ctx) *users.User { func GetCurrentUser(ctx router.Ctx) *users.User {
if user, ok := ctx.UserValue("user").(*users.User); ok { if user, ok := ctx.UserValue("user").(*users.User); ok {
return user return user
@ -88,25 +80,21 @@ func GetCurrentUser(ctx router.Ctx) *users.User {
return nil return nil
} }
// GetCurrentSession returns the current session, or nil if not authenticated func GetCurrentSession(ctx router.Ctx) *session.Session {
func GetCurrentSession(ctx router.Ctx) *auth.Session { if sess, ok := ctx.UserValue("session").(*session.Session); ok {
if session, ok := ctx.UserValue("session").(*auth.Session); ok { return sess
return session
} }
return nil return nil
} }
// Login creates a session and sets the cookie
func Login(ctx router.Ctx, authManager *auth.AuthManager, user *users.User) { func Login(ctx router.Ctx, authManager *auth.AuthManager, user *users.User) {
session := authManager.CreateSession(user) sess := authManager.CreateSession(user)
auth.SetSessionCookie(ctx, session.ID) auth.SetSessionCookie(ctx, sess.ID)
// Set in context for immediate use ctx.SetUserValue("session", sess)
ctx.SetUserValue("session", session)
ctx.SetUserValue("user", user) ctx.SetUserValue("user", user)
} }
// Logout destroys the session and clears the cookie
func Logout(ctx router.Ctx, authManager *auth.AuthManager) { func Logout(ctx router.Ctx, authManager *auth.AuthManager) {
sessionID := auth.GetSessionCookie(ctx) sessionID := auth.GetSessionCookie(ctx)
if sessionID != "" { if sessionID != "" {
@ -115,7 +103,6 @@ func Logout(ctx router.Ctx, authManager *auth.AuthManager) {
auth.DeleteSessionCookie(ctx) auth.DeleteSessionCookie(ctx)
// Clear from context
ctx.SetUserValue("session", nil) ctx.SetUserValue("session", nil)
ctx.SetUserValue("user", nil) ctx.SetUserValue("user", nil)
} }

56
internal/session/flash.go Normal file
View File

@ -0,0 +1,56 @@
package session
type FlashMessage struct {
Type string `json:"type"`
Message string `json:"message"`
}
func (s *Session) SetFlash(key string, value any) {
if s.Data == nil {
s.Data = make(map[string]any)
}
flashData, ok := s.Data["_flash"].(map[string]any)
if !ok {
flashData = make(map[string]any)
}
flashData[key] = value
s.Data["_flash"] = flashData
}
func (s *Session) GetFlash(key string) (any, bool) {
if s.Data == nil {
return nil, false
}
flashData, ok := s.Data["_flash"].(map[string]any)
if !ok {
return nil, false
}
value, exists := flashData[key]
if exists {
delete(flashData, key)
if len(flashData) == 0 {
delete(s.Data, "_flash")
} else {
s.Data["_flash"] = flashData
}
}
return value, exists
}
func (s *Session) GetAllFlash() map[string]any {
if s.Data == nil {
return nil
}
flashData, ok := s.Data["_flash"].(map[string]any)
if !ok {
return nil
}
delete(s.Data, "_flash")
return flashData
}

View File

@ -0,0 +1,74 @@
// Package session provides session management functionality.
// It includes session storage, flash messages, and data persistence.
package session
import (
"crypto/rand"
"encoding/hex"
"time"
)
const (
DefaultExpiration = 24 * time.Hour
IDLength = 32
)
type Session struct {
ID string `json:"-"`
UserID int `json:"user_id"`
Username string `json:"username"`
Email string `json:"email"`
CreatedAt time.Time `json:"created_at"`
ExpiresAt time.Time `json:"expires_at"`
LastSeen time.Time `json:"last_seen"`
Data map[string]any `json:"data,omitempty"`
}
func New(userID int, username, email string) *Session {
return &Session{
ID: generateID(),
UserID: userID,
Username: username,
Email: email,
CreatedAt: time.Now(),
ExpiresAt: time.Now().Add(DefaultExpiration),
LastSeen: time.Now(),
Data: make(map[string]any),
}
}
func (s *Session) IsExpired() bool {
return time.Now().After(s.ExpiresAt)
}
func (s *Session) Touch() {
s.LastSeen = time.Now()
s.ExpiresAt = time.Now().Add(DefaultExpiration)
}
func (s *Session) Set(key string, value any) {
if s.Data == nil {
s.Data = make(map[string]any)
}
s.Data[key] = value
}
func (s *Session) Get(key string) (any, bool) {
if s.Data == nil {
return nil, false
}
value, exists := s.Data[key]
return value, exists
}
func (s *Session) Delete(key string) {
if s.Data != nil {
delete(s.Data, key)
}
}
func generateID() string {
bytes := make([]byte, IDLength)
rand.Read(bytes)
return hex.EncodeToString(bytes)
}

161
internal/session/store.go Normal file
View File

@ -0,0 +1,161 @@
package session
import (
"encoding/json"
"maps"
"os"
"sync"
"time"
)
type Store struct {
mu sync.RWMutex
sessions map[string]*Session
filePath string
saveInterval time.Duration
stopChan chan struct{}
}
type persistedData struct {
Sessions map[string]*Session `json:"sessions"`
SavedAt time.Time `json:"saved_at"`
}
func NewStore(filePath string) *Store {
store := &Store{
sessions: make(map[string]*Session),
filePath: filePath,
saveInterval: 5 * time.Minute,
stopChan: make(chan struct{}),
}
store.loadFromFile()
store.startPeriodicSave()
return store
}
func (s *Store) Save(session *Session) {
s.mu.Lock()
defer s.mu.Unlock()
s.sessions[session.ID] = session
}
func (s *Store) Get(sessionID string) (*Session, bool) {
s.mu.RLock()
defer s.mu.RUnlock()
session, exists := s.sessions[sessionID]
if !exists {
return nil, false
}
if session.IsExpired() {
return nil, false
}
return session, true
}
func (s *Store) Delete(sessionID string) {
s.mu.Lock()
defer s.mu.Unlock()
delete(s.sessions, sessionID)
}
func (s *Store) Cleanup() {
s.mu.Lock()
defer s.mu.Unlock()
for id, session := range s.sessions {
if session.IsExpired() {
delete(s.sessions, id)
}
}
}
func (s *Store) Stats() (total, active int) {
s.mu.RLock()
defer s.mu.RUnlock()
total = len(s.sessions)
for _, session := range s.sessions {
if !session.IsExpired() {
active++
}
}
return
}
func (s *Store) loadFromFile() {
if s.filePath == "" {
return
}
data, err := os.ReadFile(s.filePath)
if err != nil {
return
}
var persisted persistedData
if err := json.Unmarshal(data, &persisted); err != nil {
return
}
s.mu.Lock()
defer s.mu.Unlock()
for id, session := range persisted.Sessions {
if !session.IsExpired() {
session.ID = id
s.sessions[id] = session
}
}
}
func (s *Store) saveToFile() error {
if s.filePath == "" {
return nil
}
s.mu.RLock()
sessionsCopy := make(map[string]*Session, len(s.sessions))
maps.Copy(sessionsCopy, s.sessions)
s.mu.RUnlock()
data := persistedData{
Sessions: sessionsCopy,
SavedAt: time.Now(),
}
jsonData, err := json.MarshalIndent(data, "", " ")
if err != nil {
return err
}
return os.WriteFile(s.filePath, jsonData, 0600)
}
func (s *Store) startPeriodicSave() {
go func() {
ticker := time.NewTicker(s.saveInterval)
defer ticker.Stop()
for {
select {
case <-ticker.C:
s.Cleanup()
s.saveToFile()
case <-s.stopChan:
s.saveToFile()
return
}
}
}()
}
func (s *Store) Close() error {
close(s.stopChan)
return s.saveToFile()
}