package middleware import ( "dk/internal/csrf" "dk/internal/router" "slices" "github.com/valyala/fasthttp" ) // CSRFConfig holds configuration for CSRF middleware type CSRFConfig struct { // Skip CSRF validation for these methods (default: GET, HEAD, OPTIONS) SkipMethods []string // Custom failure handler (default: returns 403) FailureHandler func(ctx router.Ctx) // Skip CSRF for certain paths SkipPaths []string } // CSRF creates a CSRF protection middleware func CSRF(config ...CSRFConfig) router.Middleware { cfg := CSRFConfig{ SkipMethods: []string{"GET", "HEAD", "OPTIONS"}, FailureHandler: func(ctx router.Ctx) { ctx.SetStatusCode(fasthttp.StatusForbidden) ctx.SetContentType("text/plain") ctx.WriteString("CSRF token validation failed") }, SkipPaths: []string{}, } // Apply custom config if provided if len(config) > 0 { if len(config[0].SkipMethods) > 0 { cfg.SkipMethods = config[0].SkipMethods } if config[0].FailureHandler != nil { cfg.FailureHandler = config[0].FailureHandler } if len(config[0].SkipPaths) > 0 { cfg.SkipPaths = config[0].SkipPaths } } return func(next router.Handler) router.Handler { return func(ctx router.Ctx, params []string) { method := string(ctx.Method()) path := string(ctx.Path()) // Skip CSRF validation for certain methods shouldSkip := slices.Contains(cfg.SkipMethods, method) // Skip CSRF validation for certain paths if !shouldSkip { if slices.Contains(cfg.SkipPaths, path) { shouldSkip = true } } // CSRF protection now works for both authenticated and guest users // Remove the skip for non-authenticated users if shouldSkip { next(ctx, params) return } // Validate CSRF token for protected methods if !csrf.ValidateFormToken(ctx) { cfg.FailureHandler(ctx) return } next(ctx, params) } } } // RequireCSRF is a stricter CSRF middleware that always validates tokens func RequireCSRF(failureHandler ...func(router.Ctx)) router.Middleware { handler := func(ctx router.Ctx) { ctx.SetStatusCode(fasthttp.StatusForbidden) ctx.SetContentType("text/plain") ctx.WriteString("CSRF token required") } if len(failureHandler) > 0 { handler = failureHandler[0] } return func(next router.Handler) router.Handler { return func(ctx router.Ctx, params []string) { if !csrf.ValidateFormToken(ctx) { handler(ctx) return } next(ctx, params) } } } // CSRFToken returns the current CSRF token for the request func CSRFToken(ctx router.Ctx) string { return csrf.GetToken(ctx) } // CSRFHiddenField generates a hidden input field for forms func CSRFHiddenField(ctx router.Ctx) string { return csrf.HiddenField(ctx) } // CSRFMeta generates a meta tag for JavaScript access func CSRFMeta(ctx router.Ctx) string { return csrf.TokenMeta(ctx) }