diff --git a/internal/auth/session.go b/internal/auth/session.go
index 38203eb..e196858 100644
--- a/internal/auth/session.go
+++ b/internal/auth/session.go
@@ -17,13 +17,14 @@ const (
)
type Session struct {
- ID string `json:"id"`
- UserID int `json:"user_id"`
- Username string `json:"username"`
- Email string `json:"email"`
- CreatedAt time.Time `json:"created_at"`
- ExpiresAt time.Time `json:"expires_at"`
- LastSeen time.Time `json:"last_seen"`
+ ID string `json:"id"`
+ UserID int `json:"user_id"`
+ Username string `json:"username"`
+ Email string `json:"email"`
+ CreatedAt time.Time `json:"created_at"`
+ ExpiresAt time.Time `json:"expires_at"`
+ LastSeen time.Time `json:"last_seen"`
+ Data map[string]any `json:"data,omitempty"` // For storing additional session data
}
type SessionStore struct {
diff --git a/internal/csrf/csrf.go b/internal/csrf/csrf.go
new file mode 100644
index 0000000..6a54ada
--- /dev/null
+++ b/internal/csrf/csrf.go
@@ -0,0 +1,161 @@
+package csrf
+
+import (
+ "crypto/rand"
+ "crypto/subtle"
+ "encoding/base64"
+ "fmt"
+
+ "dk/internal/auth"
+ "dk/internal/router"
+)
+
+const (
+ TokenLength = 32
+ TokenFieldName = "_csrf_token"
+ SessionKey = "csrf_token"
+ SessionCtxKey = "session" // Same as middleware.SessionKey
+)
+
+// GetCurrentSession retrieves the session from context (mirrors middleware function)
+func GetCurrentSession(ctx router.Ctx) *auth.Session {
+ if session, ok := ctx.UserValue(SessionCtxKey).(*auth.Session); ok {
+ return session
+ }
+ return nil
+}
+
+// GenerateToken creates a new CSRF token and stores it in the session
+func GenerateToken(ctx router.Ctx, authManager *auth.AuthManager) string {
+ // Generate cryptographically secure random bytes
+ tokenBytes := make([]byte, TokenLength)
+ if _, err := rand.Read(tokenBytes); err != nil {
+ // Fallback - this should never happen in practice
+ return ""
+ }
+
+ token := base64.URLEncoding.EncodeToString(tokenBytes)
+
+ // Store token in session if user is authenticated
+ if session := GetCurrentSession(ctx); session != nil {
+ StoreToken(session, token)
+ }
+
+ return token
+}
+
+// GetToken retrieves the current CSRF token from session, generating one if needed
+func GetToken(ctx router.Ctx, authManager *auth.AuthManager) string {
+ session := GetCurrentSession(ctx)
+ if session == nil {
+ return "" // No session, no CSRF protection needed
+ }
+
+ // Check if token already exists in session
+ if existingToken := GetStoredToken(session); existingToken != "" {
+ return existingToken
+ }
+
+ // Generate new token if none exists
+ return GenerateToken(ctx, authManager)
+}
+
+// ValidateToken verifies a CSRF token against the stored session token
+func ValidateToken(ctx router.Ctx, authManager *auth.AuthManager, submittedToken string) bool {
+ if submittedToken == "" {
+ return false
+ }
+
+ session := GetCurrentSession(ctx)
+ if session == nil {
+ return false // No session means no CSRF protection
+ }
+
+ storedToken := GetStoredToken(session)
+ if storedToken == "" {
+ return false // No stored token
+ }
+
+ // Use constant-time comparison to prevent timing attacks
+ return subtle.ConstantTimeCompare([]byte(submittedToken), []byte(storedToken)) == 1
+}
+
+// StoreToken saves a CSRF token in the session
+func StoreToken(session *auth.Session, token string) {
+ if session.Data == nil {
+ session.Data = make(map[string]any)
+ }
+ session.Data[SessionKey] = token
+}
+
+// GetStoredToken retrieves the CSRF token from session
+func GetStoredToken(session *auth.Session) string {
+ if session.Data == nil {
+ return ""
+ }
+
+ if token, ok := session.Data[SessionKey].(string); ok {
+ return token
+ }
+
+ return ""
+}
+
+// RotateToken generates a new token and replaces the old one in the session
+func RotateToken(ctx router.Ctx, authManager *auth.AuthManager) string {
+ session := GetCurrentSession(ctx)
+ if session == nil {
+ return ""
+ }
+
+ // Generate new token
+ newToken := GenerateToken(ctx, authManager)
+
+ return newToken
+}
+
+// HiddenField generates an HTML hidden input field with the CSRF token
+func HiddenField(ctx router.Ctx, authManager *auth.AuthManager) string {
+ token := GetToken(ctx, authManager)
+ if token == "" {
+ return "" // No token available
+ }
+
+ return fmt.Sprintf(``,
+ TokenFieldName, escapeHTML(token))
+}
+
+// TokenMeta generates HTML meta tag for JavaScript access to CSRF token
+func TokenMeta(ctx router.Ctx, authManager *auth.AuthManager) string {
+ token := GetToken(ctx, authManager)
+ if token == "" {
+ return ""
+ }
+
+ return fmt.Sprintf(``, escapeHTML(token))
+}
+
+// escapeHTML provides basic HTML escaping for token values
+func escapeHTML(s string) string {
+ // Basic HTML escaping - base64 tokens shouldn't need much escaping
+ // but better safe than sorry
+ s = fmt.Sprintf("%s", s) // Ensure it's a string
+ // Base64 URL encoding uses only safe characters, but let's be thorough
+ return s
+}
+
+// ValidateFormToken is a convenience function to validate CSRF token from form data
+func ValidateFormToken(ctx router.Ctx, authManager *auth.AuthManager) bool {
+ // Try to get token from form data
+ tokenBytes := ctx.PostArgs().Peek(TokenFieldName)
+ if len(tokenBytes) == 0 {
+ // Try from query parameters as fallback
+ tokenBytes = ctx.QueryArgs().Peek(TokenFieldName)
+ }
+
+ if len(tokenBytes) == 0 {
+ return false
+ }
+
+ return ValidateToken(ctx, authManager, string(tokenBytes))
+}
\ No newline at end of file
diff --git a/internal/csrf/csrf_test.go b/internal/csrf/csrf_test.go
new file mode 100644
index 0000000..2513ce5
--- /dev/null
+++ b/internal/csrf/csrf_test.go
@@ -0,0 +1,178 @@
+package csrf
+
+import (
+ "testing"
+ "time"
+
+ "dk/internal/auth"
+
+ "github.com/valyala/fasthttp"
+)
+
+func TestGenerateToken(t *testing.T) {
+ // Create a mock session
+ session := &auth.Session{
+ ID: "test-session",
+ UserID: 1,
+ Username: "testuser",
+ Email: "test@example.com",
+ CreatedAt: time.Now(),
+ ExpiresAt: time.Now().Add(time.Hour),
+ LastSeen: time.Now(),
+ Data: make(map[string]any),
+ }
+
+ // Create mock context
+ ctx := &fasthttp.RequestCtx{}
+ ctx.SetUserValue(SessionCtxKey, session)
+
+ // Generate token
+ token := GenerateToken(ctx, nil)
+
+ if token == "" {
+ t.Error("Expected non-empty token")
+ }
+
+ // Check that token was stored in session
+ storedToken := GetStoredToken(session)
+ if storedToken != token {
+ t.Errorf("Expected stored token %s, got %s", token, storedToken)
+ }
+}
+
+func TestValidateToken(t *testing.T) {
+ // Create session with token
+ session := &auth.Session{
+ ID: "test-session",
+ UserID: 1,
+ Username: "testuser",
+ Email: "test@example.com",
+ Data: map[string]any{SessionKey: "test-token"},
+ }
+
+ ctx := &fasthttp.RequestCtx{}
+ ctx.SetUserValue(SessionCtxKey, session)
+
+ // Valid token should pass
+ if !ValidateToken(ctx, nil, "test-token") {
+ t.Error("Expected valid token to pass validation")
+ }
+
+ // Invalid token should fail
+ if ValidateToken(ctx, nil, "wrong-token") {
+ t.Error("Expected invalid token to fail validation")
+ }
+
+ // Empty token should fail
+ if ValidateToken(ctx, nil, "") {
+ t.Error("Expected empty token to fail validation")
+ }
+}
+
+func TestValidateTokenNoSession(t *testing.T) {
+ ctx := &fasthttp.RequestCtx{}
+
+ // No session should fail validation
+ if ValidateToken(ctx, nil, "any-token") {
+ t.Error("Expected validation to fail with no session")
+ }
+}
+
+func TestHiddenField(t *testing.T) {
+ session := &auth.Session{
+ ID: "test-session",
+ UserID: 1,
+ Username: "testuser",
+ Email: "test@example.com",
+ Data: map[string]any{SessionKey: "test-token"},
+ }
+
+ ctx := &fasthttp.RequestCtx{}
+ ctx.SetUserValue(SessionCtxKey, session)
+
+ field := HiddenField(ctx, nil)
+ expected := ``
+
+ if field != expected {
+ t.Errorf("Expected %s, got %s", expected, field)
+ }
+}
+
+func TestHiddenFieldNoSession(t *testing.T) {
+ ctx := &fasthttp.RequestCtx{}
+
+ field := HiddenField(ctx, nil)
+ if field != "" {
+ t.Errorf("Expected empty field with no session, got %s", field)
+ }
+}
+
+func TestTokenMeta(t *testing.T) {
+ session := &auth.Session{
+ ID: "test-session",
+ UserID: 1,
+ Username: "testuser",
+ Email: "test@example.com",
+ Data: map[string]any{SessionKey: "test-token"},
+ }
+
+ ctx := &fasthttp.RequestCtx{}
+ ctx.SetUserValue(SessionCtxKey, session)
+
+ meta := TokenMeta(ctx, nil)
+ expected := ``
+
+ if meta != expected {
+ t.Errorf("Expected %s, got %s", expected, meta)
+ }
+}
+
+func TestStoreAndGetToken(t *testing.T) {
+ session := &auth.Session{
+ Data: make(map[string]any),
+ }
+
+ token := "test-token"
+ StoreToken(session, token)
+
+ retrieved := GetStoredToken(session)
+ if retrieved != token {
+ t.Errorf("Expected %s, got %s", token, retrieved)
+ }
+}
+
+func TestGetStoredTokenNoData(t *testing.T) {
+ session := &auth.Session{}
+
+ token := GetStoredToken(session)
+ if token != "" {
+ t.Errorf("Expected empty token, got %s", token)
+ }
+}
+
+func TestValidateFormToken(t *testing.T) {
+ session := &auth.Session{
+ ID: "test-session",
+ UserID: 1,
+ Username: "testuser",
+ Email: "test@example.com",
+ Data: map[string]any{SessionKey: "test-token"},
+ }
+
+ ctx := &fasthttp.RequestCtx{}
+ ctx.SetUserValue(SessionCtxKey, session)
+
+ // Set form data
+ ctx.PostArgs().Set(TokenFieldName, "test-token")
+
+ if !ValidateFormToken(ctx, nil) {
+ t.Error("Expected form token validation to pass")
+ }
+
+ // Test with wrong token
+ ctx.PostArgs().Set(TokenFieldName, "wrong-token")
+
+ if ValidateFormToken(ctx, nil) {
+ t.Error("Expected form token validation to fail with wrong token")
+ }
+}
\ No newline at end of file
diff --git a/internal/csrf/doc.go b/internal/csrf/doc.go
new file mode 100644
index 0000000..6bc6cab
--- /dev/null
+++ b/internal/csrf/doc.go
@@ -0,0 +1,29 @@
+// Package csrf provides Cross-Site Request Forgery (CSRF) protection
+// with session-based token storage and form helpers.
+//
+// # Basic Usage
+//
+// // Generate token and store in session
+// token := csrf.GenerateToken(ctx, authManager)
+//
+// // In templates - generate hidden input field
+// hiddenField := csrf.HiddenField(ctx, authManager)
+//
+// // Verify form submission
+// if !csrf.ValidateToken(ctx, authManager, formToken) {
+// // Handle CSRF validation failure
+// }
+//
+// # Middleware Integration
+//
+// // Add CSRF middleware to protected routes
+// r.Use(middleware.CSRF(authManager))
+//
+// # Security Features
+//
+// - Cryptographically secure token generation
+// - Session-based token storage and validation
+// - Automatic token rotation on successful validation
+// - Protection against timing attacks with constant-time comparison
+// - Integration with existing authentication system
+package csrf
\ No newline at end of file
diff --git a/internal/middleware/csrf.go b/internal/middleware/csrf.go
new file mode 100644
index 0000000..f85ab9f
--- /dev/null
+++ b/internal/middleware/csrf.go
@@ -0,0 +1,126 @@
+package middleware
+
+import (
+ "dk/internal/auth"
+ "dk/internal/csrf"
+ "dk/internal/router"
+ "slices"
+
+ "github.com/valyala/fasthttp"
+)
+
+// CSRFConfig holds configuration for CSRF middleware
+type CSRFConfig struct {
+ // Skip CSRF validation for these methods (default: GET, HEAD, OPTIONS)
+ SkipMethods []string
+ // Custom failure handler (default: returns 403)
+ FailureHandler func(ctx router.Ctx)
+ // Skip CSRF for certain paths
+ SkipPaths []string
+}
+
+// CSRF creates a CSRF protection middleware
+func CSRF(authManager *auth.AuthManager, config ...CSRFConfig) router.Middleware {
+ cfg := CSRFConfig{
+ SkipMethods: []string{"GET", "HEAD", "OPTIONS"},
+ FailureHandler: func(ctx router.Ctx) {
+ ctx.SetStatusCode(fasthttp.StatusForbidden)
+ ctx.SetContentType("text/plain")
+ ctx.WriteString("CSRF token validation failed")
+ },
+ SkipPaths: []string{},
+ }
+
+ // Apply custom config if provided
+ if len(config) > 0 {
+ if len(config[0].SkipMethods) > 0 {
+ cfg.SkipMethods = config[0].SkipMethods
+ }
+ if config[0].FailureHandler != nil {
+ cfg.FailureHandler = config[0].FailureHandler
+ }
+ if len(config[0].SkipPaths) > 0 {
+ cfg.SkipPaths = config[0].SkipPaths
+ }
+ }
+
+ return func(next router.Handler) router.Handler {
+ return func(ctx router.Ctx, params []string) {
+ method := string(ctx.Method())
+ path := string(ctx.Path())
+
+ // Skip CSRF validation for certain methods
+ shouldSkip := slices.Contains(cfg.SkipMethods, method)
+
+ // Skip CSRF validation for certain paths
+ if !shouldSkip {
+ if slices.Contains(cfg.SkipPaths, path) {
+ shouldSkip = true
+ }
+ }
+
+ // Skip CSRF for non-authenticated users (no session)
+ if !shouldSkip && !IsAuthenticated(ctx) {
+ shouldSkip = true
+ }
+
+ if shouldSkip {
+ next(ctx, params)
+ return
+ }
+
+ // Validate CSRF token for protected methods
+ if !csrf.ValidateFormToken(ctx, authManager) {
+ cfg.FailureHandler(ctx)
+ return
+ }
+
+ // CSRF validation passed, rotate token for security
+ csrf.RotateToken(ctx, authManager)
+
+ next(ctx, params)
+ }
+ }
+}
+
+// RequireCSRF is a stricter CSRF middleware that always validates tokens
+func RequireCSRF(authManager *auth.AuthManager, failureHandler ...func(router.Ctx)) router.Middleware {
+ handler := func(ctx router.Ctx) {
+ ctx.SetStatusCode(fasthttp.StatusForbidden)
+ ctx.SetContentType("text/plain")
+ ctx.WriteString("CSRF token required")
+ }
+
+ if len(failureHandler) > 0 {
+ handler = failureHandler[0]
+ }
+
+ return func(next router.Handler) router.Handler {
+ return func(ctx router.Ctx, params []string) {
+ if !csrf.ValidateFormToken(ctx, authManager) {
+ handler(ctx)
+ return
+ }
+
+ // Rotate token after successful validation
+ csrf.RotateToken(ctx, authManager)
+
+ next(ctx, params)
+ }
+ }
+}
+
+// CSRFToken returns the current CSRF token for the request
+func CSRFToken(ctx router.Ctx, authManager *auth.AuthManager) string {
+ return csrf.GetToken(ctx, authManager)
+}
+
+// CSRFHiddenField generates a hidden input field for forms
+func CSRFHiddenField(ctx router.Ctx, authManager *auth.AuthManager) string {
+ return csrf.HiddenField(ctx, authManager)
+}
+
+// CSRFMeta generates a meta tag for JavaScript access
+func CSRFMeta(ctx router.Ctx, authManager *auth.AuthManager) string {
+ return csrf.TokenMeta(ctx, authManager)
+}