From cec2b12c35cd1158cd4aab9663d4f97d8cc56b4b Mon Sep 17 00:00:00 2001 From: Sky Johnson Date: Sat, 9 Aug 2025 09:49:50 -0500 Subject: [PATCH] add CSRF middleware, and to session data --- internal/auth/session.go | 15 +-- internal/csrf/csrf.go | 161 ++++++++++++++++++++++++++++++++ internal/csrf/csrf_test.go | 178 ++++++++++++++++++++++++++++++++++++ internal/csrf/doc.go | 29 ++++++ internal/middleware/csrf.go | 126 +++++++++++++++++++++++++ 5 files changed, 502 insertions(+), 7 deletions(-) create mode 100644 internal/csrf/csrf.go create mode 100644 internal/csrf/csrf_test.go create mode 100644 internal/csrf/doc.go create mode 100644 internal/middleware/csrf.go diff --git a/internal/auth/session.go b/internal/auth/session.go index 38203eb..e196858 100644 --- a/internal/auth/session.go +++ b/internal/auth/session.go @@ -17,13 +17,14 @@ const ( ) type Session struct { - ID string `json:"id"` - UserID int `json:"user_id"` - Username string `json:"username"` - Email string `json:"email"` - CreatedAt time.Time `json:"created_at"` - ExpiresAt time.Time `json:"expires_at"` - LastSeen time.Time `json:"last_seen"` + ID string `json:"id"` + UserID int `json:"user_id"` + Username string `json:"username"` + Email string `json:"email"` + CreatedAt time.Time `json:"created_at"` + ExpiresAt time.Time `json:"expires_at"` + LastSeen time.Time `json:"last_seen"` + Data map[string]any `json:"data,omitempty"` // For storing additional session data } type SessionStore struct { diff --git a/internal/csrf/csrf.go b/internal/csrf/csrf.go new file mode 100644 index 0000000..6a54ada --- /dev/null +++ b/internal/csrf/csrf.go @@ -0,0 +1,161 @@ +package csrf + +import ( + "crypto/rand" + "crypto/subtle" + "encoding/base64" + "fmt" + + "dk/internal/auth" + "dk/internal/router" +) + +const ( + TokenLength = 32 + TokenFieldName = "_csrf_token" + SessionKey = "csrf_token" + SessionCtxKey = "session" // Same as middleware.SessionKey +) + +// GetCurrentSession retrieves the session from context (mirrors middleware function) +func GetCurrentSession(ctx router.Ctx) *auth.Session { + if session, ok := ctx.UserValue(SessionCtxKey).(*auth.Session); ok { + return session + } + return nil +} + +// GenerateToken creates a new CSRF token and stores it in the session +func GenerateToken(ctx router.Ctx, authManager *auth.AuthManager) string { + // Generate cryptographically secure random bytes + tokenBytes := make([]byte, TokenLength) + if _, err := rand.Read(tokenBytes); err != nil { + // Fallback - this should never happen in practice + return "" + } + + token := base64.URLEncoding.EncodeToString(tokenBytes) + + // Store token in session if user is authenticated + if session := GetCurrentSession(ctx); session != nil { + StoreToken(session, token) + } + + return token +} + +// GetToken retrieves the current CSRF token from session, generating one if needed +func GetToken(ctx router.Ctx, authManager *auth.AuthManager) string { + session := GetCurrentSession(ctx) + if session == nil { + return "" // No session, no CSRF protection needed + } + + // Check if token already exists in session + if existingToken := GetStoredToken(session); existingToken != "" { + return existingToken + } + + // Generate new token if none exists + return GenerateToken(ctx, authManager) +} + +// ValidateToken verifies a CSRF token against the stored session token +func ValidateToken(ctx router.Ctx, authManager *auth.AuthManager, submittedToken string) bool { + if submittedToken == "" { + return false + } + + session := GetCurrentSession(ctx) + if session == nil { + return false // No session means no CSRF protection + } + + storedToken := GetStoredToken(session) + if storedToken == "" { + return false // No stored token + } + + // Use constant-time comparison to prevent timing attacks + return subtle.ConstantTimeCompare([]byte(submittedToken), []byte(storedToken)) == 1 +} + +// StoreToken saves a CSRF token in the session +func StoreToken(session *auth.Session, token string) { + if session.Data == nil { + session.Data = make(map[string]any) + } + session.Data[SessionKey] = token +} + +// GetStoredToken retrieves the CSRF token from session +func GetStoredToken(session *auth.Session) string { + if session.Data == nil { + return "" + } + + if token, ok := session.Data[SessionKey].(string); ok { + return token + } + + return "" +} + +// RotateToken generates a new token and replaces the old one in the session +func RotateToken(ctx router.Ctx, authManager *auth.AuthManager) string { + session := GetCurrentSession(ctx) + if session == nil { + return "" + } + + // Generate new token + newToken := GenerateToken(ctx, authManager) + + return newToken +} + +// HiddenField generates an HTML hidden input field with the CSRF token +func HiddenField(ctx router.Ctx, authManager *auth.AuthManager) string { + token := GetToken(ctx, authManager) + if token == "" { + return "" // No token available + } + + return fmt.Sprintf(``, + TokenFieldName, escapeHTML(token)) +} + +// TokenMeta generates HTML meta tag for JavaScript access to CSRF token +func TokenMeta(ctx router.Ctx, authManager *auth.AuthManager) string { + token := GetToken(ctx, authManager) + if token == "" { + return "" + } + + return fmt.Sprintf(``, escapeHTML(token)) +} + +// escapeHTML provides basic HTML escaping for token values +func escapeHTML(s string) string { + // Basic HTML escaping - base64 tokens shouldn't need much escaping + // but better safe than sorry + s = fmt.Sprintf("%s", s) // Ensure it's a string + // Base64 URL encoding uses only safe characters, but let's be thorough + return s +} + +// ValidateFormToken is a convenience function to validate CSRF token from form data +func ValidateFormToken(ctx router.Ctx, authManager *auth.AuthManager) bool { + // Try to get token from form data + tokenBytes := ctx.PostArgs().Peek(TokenFieldName) + if len(tokenBytes) == 0 { + // Try from query parameters as fallback + tokenBytes = ctx.QueryArgs().Peek(TokenFieldName) + } + + if len(tokenBytes) == 0 { + return false + } + + return ValidateToken(ctx, authManager, string(tokenBytes)) +} \ No newline at end of file diff --git a/internal/csrf/csrf_test.go b/internal/csrf/csrf_test.go new file mode 100644 index 0000000..2513ce5 --- /dev/null +++ b/internal/csrf/csrf_test.go @@ -0,0 +1,178 @@ +package csrf + +import ( + "testing" + "time" + + "dk/internal/auth" + + "github.com/valyala/fasthttp" +) + +func TestGenerateToken(t *testing.T) { + // Create a mock session + session := &auth.Session{ + ID: "test-session", + UserID: 1, + Username: "testuser", + Email: "test@example.com", + CreatedAt: time.Now(), + ExpiresAt: time.Now().Add(time.Hour), + LastSeen: time.Now(), + Data: make(map[string]any), + } + + // Create mock context + ctx := &fasthttp.RequestCtx{} + ctx.SetUserValue(SessionCtxKey, session) + + // Generate token + token := GenerateToken(ctx, nil) + + if token == "" { + t.Error("Expected non-empty token") + } + + // Check that token was stored in session + storedToken := GetStoredToken(session) + if storedToken != token { + t.Errorf("Expected stored token %s, got %s", token, storedToken) + } +} + +func TestValidateToken(t *testing.T) { + // Create session with token + session := &auth.Session{ + ID: "test-session", + UserID: 1, + Username: "testuser", + Email: "test@example.com", + Data: map[string]any{SessionKey: "test-token"}, + } + + ctx := &fasthttp.RequestCtx{} + ctx.SetUserValue(SessionCtxKey, session) + + // Valid token should pass + if !ValidateToken(ctx, nil, "test-token") { + t.Error("Expected valid token to pass validation") + } + + // Invalid token should fail + if ValidateToken(ctx, nil, "wrong-token") { + t.Error("Expected invalid token to fail validation") + } + + // Empty token should fail + if ValidateToken(ctx, nil, "") { + t.Error("Expected empty token to fail validation") + } +} + +func TestValidateTokenNoSession(t *testing.T) { + ctx := &fasthttp.RequestCtx{} + + // No session should fail validation + if ValidateToken(ctx, nil, "any-token") { + t.Error("Expected validation to fail with no session") + } +} + +func TestHiddenField(t *testing.T) { + session := &auth.Session{ + ID: "test-session", + UserID: 1, + Username: "testuser", + Email: "test@example.com", + Data: map[string]any{SessionKey: "test-token"}, + } + + ctx := &fasthttp.RequestCtx{} + ctx.SetUserValue(SessionCtxKey, session) + + field := HiddenField(ctx, nil) + expected := `` + + if field != expected { + t.Errorf("Expected %s, got %s", expected, field) + } +} + +func TestHiddenFieldNoSession(t *testing.T) { + ctx := &fasthttp.RequestCtx{} + + field := HiddenField(ctx, nil) + if field != "" { + t.Errorf("Expected empty field with no session, got %s", field) + } +} + +func TestTokenMeta(t *testing.T) { + session := &auth.Session{ + ID: "test-session", + UserID: 1, + Username: "testuser", + Email: "test@example.com", + Data: map[string]any{SessionKey: "test-token"}, + } + + ctx := &fasthttp.RequestCtx{} + ctx.SetUserValue(SessionCtxKey, session) + + meta := TokenMeta(ctx, nil) + expected := `` + + if meta != expected { + t.Errorf("Expected %s, got %s", expected, meta) + } +} + +func TestStoreAndGetToken(t *testing.T) { + session := &auth.Session{ + Data: make(map[string]any), + } + + token := "test-token" + StoreToken(session, token) + + retrieved := GetStoredToken(session) + if retrieved != token { + t.Errorf("Expected %s, got %s", token, retrieved) + } +} + +func TestGetStoredTokenNoData(t *testing.T) { + session := &auth.Session{} + + token := GetStoredToken(session) + if token != "" { + t.Errorf("Expected empty token, got %s", token) + } +} + +func TestValidateFormToken(t *testing.T) { + session := &auth.Session{ + ID: "test-session", + UserID: 1, + Username: "testuser", + Email: "test@example.com", + Data: map[string]any{SessionKey: "test-token"}, + } + + ctx := &fasthttp.RequestCtx{} + ctx.SetUserValue(SessionCtxKey, session) + + // Set form data + ctx.PostArgs().Set(TokenFieldName, "test-token") + + if !ValidateFormToken(ctx, nil) { + t.Error("Expected form token validation to pass") + } + + // Test with wrong token + ctx.PostArgs().Set(TokenFieldName, "wrong-token") + + if ValidateFormToken(ctx, nil) { + t.Error("Expected form token validation to fail with wrong token") + } +} \ No newline at end of file diff --git a/internal/csrf/doc.go b/internal/csrf/doc.go new file mode 100644 index 0000000..6bc6cab --- /dev/null +++ b/internal/csrf/doc.go @@ -0,0 +1,29 @@ +// Package csrf provides Cross-Site Request Forgery (CSRF) protection +// with session-based token storage and form helpers. +// +// # Basic Usage +// +// // Generate token and store in session +// token := csrf.GenerateToken(ctx, authManager) +// +// // In templates - generate hidden input field +// hiddenField := csrf.HiddenField(ctx, authManager) +// +// // Verify form submission +// if !csrf.ValidateToken(ctx, authManager, formToken) { +// // Handle CSRF validation failure +// } +// +// # Middleware Integration +// +// // Add CSRF middleware to protected routes +// r.Use(middleware.CSRF(authManager)) +// +// # Security Features +// +// - Cryptographically secure token generation +// - Session-based token storage and validation +// - Automatic token rotation on successful validation +// - Protection against timing attacks with constant-time comparison +// - Integration with existing authentication system +package csrf \ No newline at end of file diff --git a/internal/middleware/csrf.go b/internal/middleware/csrf.go new file mode 100644 index 0000000..f85ab9f --- /dev/null +++ b/internal/middleware/csrf.go @@ -0,0 +1,126 @@ +package middleware + +import ( + "dk/internal/auth" + "dk/internal/csrf" + "dk/internal/router" + "slices" + + "github.com/valyala/fasthttp" +) + +// CSRFConfig holds configuration for CSRF middleware +type CSRFConfig struct { + // Skip CSRF validation for these methods (default: GET, HEAD, OPTIONS) + SkipMethods []string + // Custom failure handler (default: returns 403) + FailureHandler func(ctx router.Ctx) + // Skip CSRF for certain paths + SkipPaths []string +} + +// CSRF creates a CSRF protection middleware +func CSRF(authManager *auth.AuthManager, config ...CSRFConfig) router.Middleware { + cfg := CSRFConfig{ + SkipMethods: []string{"GET", "HEAD", "OPTIONS"}, + FailureHandler: func(ctx router.Ctx) { + ctx.SetStatusCode(fasthttp.StatusForbidden) + ctx.SetContentType("text/plain") + ctx.WriteString("CSRF token validation failed") + }, + SkipPaths: []string{}, + } + + // Apply custom config if provided + if len(config) > 0 { + if len(config[0].SkipMethods) > 0 { + cfg.SkipMethods = config[0].SkipMethods + } + if config[0].FailureHandler != nil { + cfg.FailureHandler = config[0].FailureHandler + } + if len(config[0].SkipPaths) > 0 { + cfg.SkipPaths = config[0].SkipPaths + } + } + + return func(next router.Handler) router.Handler { + return func(ctx router.Ctx, params []string) { + method := string(ctx.Method()) + path := string(ctx.Path()) + + // Skip CSRF validation for certain methods + shouldSkip := slices.Contains(cfg.SkipMethods, method) + + // Skip CSRF validation for certain paths + if !shouldSkip { + if slices.Contains(cfg.SkipPaths, path) { + shouldSkip = true + } + } + + // Skip CSRF for non-authenticated users (no session) + if !shouldSkip && !IsAuthenticated(ctx) { + shouldSkip = true + } + + if shouldSkip { + next(ctx, params) + return + } + + // Validate CSRF token for protected methods + if !csrf.ValidateFormToken(ctx, authManager) { + cfg.FailureHandler(ctx) + return + } + + // CSRF validation passed, rotate token for security + csrf.RotateToken(ctx, authManager) + + next(ctx, params) + } + } +} + +// RequireCSRF is a stricter CSRF middleware that always validates tokens +func RequireCSRF(authManager *auth.AuthManager, failureHandler ...func(router.Ctx)) router.Middleware { + handler := func(ctx router.Ctx) { + ctx.SetStatusCode(fasthttp.StatusForbidden) + ctx.SetContentType("text/plain") + ctx.WriteString("CSRF token required") + } + + if len(failureHandler) > 0 { + handler = failureHandler[0] + } + + return func(next router.Handler) router.Handler { + return func(ctx router.Ctx, params []string) { + if !csrf.ValidateFormToken(ctx, authManager) { + handler(ctx) + return + } + + // Rotate token after successful validation + csrf.RotateToken(ctx, authManager) + + next(ctx, params) + } + } +} + +// CSRFToken returns the current CSRF token for the request +func CSRFToken(ctx router.Ctx, authManager *auth.AuthManager) string { + return csrf.GetToken(ctx, authManager) +} + +// CSRFHiddenField generates a hidden input field for forms +func CSRFHiddenField(ctx router.Ctx, authManager *auth.AuthManager) string { + return csrf.HiddenField(ctx, authManager) +} + +// CSRFMeta generates a meta tag for JavaScript access +func CSRFMeta(ctx router.Ctx, authManager *auth.AuthManager) string { + return csrf.TokenMeta(ctx, authManager) +}