127 lines
3.2 KiB
Go

package middleware
import (
"dk/internal/auth"
"dk/internal/csrf"
"dk/internal/router"
"slices"
"github.com/valyala/fasthttp"
)
// CSRFConfig customizes the middleware returned by CSRF. Fields left at their
// zero value (nil/empty slice, nil func) fall back to CSRF's defaults.
type CSRFConfig struct {
	// SkipMethods lists HTTP methods that bypass token validation.
	// When empty, CSRF uses GET, HEAD and OPTIONS.
	SkipMethods []string

	// FailureHandler runs when token validation fails. When nil, CSRF
	// writes a 403 Forbidden plain-text response.
	FailureHandler func(ctx router.Ctx)

	// SkipPaths lists request paths that bypass token validation.
	// Entries are compared against the request path by exact match.
	SkipPaths []string
}
// CSRF returns a middleware that validates the form CSRF token on requests
// from authenticated users.
//
// A request bypasses validation (next is called directly) when any of these
// holds, checked in order:
//   - its method is listed in SkipMethods (default: GET, HEAD, OPTIONS);
//   - its path exactly matches an entry in SkipPaths;
//   - IsAuthenticated reports false for it.
//
// Otherwise csrf.ValidateFormToken decides. On failure FailureHandler runs
// and next is not called. On success the token is rotated with
// csrf.RotateToken before next runs.
//
// Only the first element of config is consulted. Fields left empty/nil keep
// their defaults. The caller's SkipMethods and SkipPaths slices are copied,
// so mutating them after CSRF returns does not affect the middleware.
func CSRF(authManager *auth.AuthManager, config ...CSRFConfig) router.Middleware {
	cfg := CSRFConfig{
		SkipMethods: []string{"GET", "HEAD", "OPTIONS"},
		FailureHandler: func(ctx router.Ctx) {
			ctx.SetStatusCode(fasthttp.StatusForbidden)
			ctx.SetContentType("text/plain")
			ctx.WriteString("CSRF token validation failed")
		},
	}

	if len(config) > 0 {
		custom := config[0]
		// Clone rather than alias: the handler below reads these slices on
		// every request, so a later in-place mutation by the caller would
		// otherwise change behavior at runtime and race with live requests.
		if len(custom.SkipMethods) > 0 {
			cfg.SkipMethods = slices.Clone(custom.SkipMethods)
		}
		if custom.FailureHandler != nil {
			cfg.FailureHandler = custom.FailureHandler
		}
		if len(custom.SkipPaths) > 0 {
			cfg.SkipPaths = slices.Clone(custom.SkipPaths)
		}
	}

	return func(next router.Handler) router.Handler {
		return func(ctx router.Ctx, params []string) {
			switch {
			case slices.Contains(cfg.SkipMethods, string(ctx.Method())),
				slices.Contains(cfg.SkipPaths, string(ctx.Path())),
				// Requests without a session are not checked at all.
				// NOTE(review): confirm endpoints reachable without a session
				// (e.g. login) are not exposed to login-CSRF by this bypass.
				!IsAuthenticated(ctx):
				next(ctx, params)
				return
			}

			if !csrf.ValidateFormToken(ctx, authManager) {
				cfg.FailureHandler(ctx)
				return
			}

			// Validation passed: issue a fresh token before continuing.
			csrf.RotateToken(ctx, authManager)
			next(ctx, params)
		}
	}
}
// RequireCSRF returns a stricter CSRF middleware that validates the form
// token on every request, regardless of method, path or authentication state.
//
// On validation failure, the first failureHandler runs (when given and
// non-nil) and next is not called. Otherwise a 403 "CSRF token required"
// plain-text response is written. On success the token is rotated with
// csrf.RotateToken before next runs.
func RequireCSRF(authManager *auth.AuthManager, failureHandler ...func(router.Ctx)) router.Middleware {
	onFailure := func(ctx router.Ctx) {
		ctx.SetStatusCode(fasthttp.StatusForbidden)
		ctx.SetContentType("text/plain")
		ctx.WriteString("CSRF token required")
	}
	// A nil handler would panic on the first rejected request. Fall back to
	// the default instead, matching how CSRF ignores a nil FailureHandler.
	if len(failureHandler) > 0 && failureHandler[0] != nil {
		onFailure = failureHandler[0]
	}

	return func(next router.Handler) router.Handler {
		return func(ctx router.Ctx, params []string) {
			if !csrf.ValidateFormToken(ctx, authManager) {
				onFailure(ctx)
				return
			}

			// Validation passed: issue a fresh token before continuing.
			csrf.RotateToken(ctx, authManager)
			next(ctx, params)
		}
	}
}
// CSRFToken reports the CSRF token that csrf.GetToken resolves for the
// current request, using authManager.
func CSRFToken(ctx router.Ctx, authManager *auth.AuthManager) string {
	token := csrf.GetToken(ctx, authManager)
	return token
}
// CSRFHiddenField returns the hidden form-input markup produced by
// csrf.HiddenField for the current request, for embedding in HTML forms.
func CSRFHiddenField(ctx router.Ctx, authManager *auth.AuthManager) string {
	field := csrf.HiddenField(ctx, authManager)
	return field
}
// CSRFMeta returns the meta-tag markup produced by csrf.TokenMeta for the
// current request, so client-side JavaScript can read the token.
func CSRFMeta(ctx router.Ctx, authManager *auth.AuthManager) string {
	meta := csrf.TokenMeta(ctx, authManager)
	return meta
}