125 lines
3.2 KiB
Go
125 lines
3.2 KiB
Go
package middleware
|
|
|
|
import (
|
|
"dk/internal/auth"
|
|
"dk/internal/csrf"
|
|
"dk/internal/router"
|
|
"slices"
|
|
|
|
"github.com/valyala/fasthttp"
|
|
)
|
|
|
|
// CSRFConfig holds configuration for CSRF middleware
|
|
type CSRFConfig struct {
|
|
// Skip CSRF validation for these methods (default: GET, HEAD, OPTIONS)
|
|
SkipMethods []string
|
|
// Custom failure handler (default: returns 403)
|
|
FailureHandler func(ctx router.Ctx)
|
|
// Skip CSRF for certain paths
|
|
SkipPaths []string
|
|
}
|
|
|
|
// CSRF creates a CSRF protection middleware
|
|
func CSRF(authManager *auth.AuthManager, config ...CSRFConfig) router.Middleware {
|
|
cfg := CSRFConfig{
|
|
SkipMethods: []string{"GET", "HEAD", "OPTIONS"},
|
|
FailureHandler: func(ctx router.Ctx) {
|
|
ctx.SetStatusCode(fasthttp.StatusForbidden)
|
|
ctx.SetContentType("text/plain")
|
|
ctx.WriteString("CSRF token validation failed")
|
|
},
|
|
SkipPaths: []string{},
|
|
}
|
|
|
|
// Apply custom config if provided
|
|
if len(config) > 0 {
|
|
if len(config[0].SkipMethods) > 0 {
|
|
cfg.SkipMethods = config[0].SkipMethods
|
|
}
|
|
if config[0].FailureHandler != nil {
|
|
cfg.FailureHandler = config[0].FailureHandler
|
|
}
|
|
if len(config[0].SkipPaths) > 0 {
|
|
cfg.SkipPaths = config[0].SkipPaths
|
|
}
|
|
}
|
|
|
|
return func(next router.Handler) router.Handler {
|
|
return func(ctx router.Ctx, params []string) {
|
|
method := string(ctx.Method())
|
|
path := string(ctx.Path())
|
|
|
|
// Skip CSRF validation for certain methods
|
|
shouldSkip := slices.Contains(cfg.SkipMethods, method)
|
|
|
|
// Skip CSRF validation for certain paths
|
|
if !shouldSkip {
|
|
if slices.Contains(cfg.SkipPaths, path) {
|
|
shouldSkip = true
|
|
}
|
|
}
|
|
|
|
// CSRF protection now works for both authenticated and guest users
|
|
// Remove the skip for non-authenticated users
|
|
|
|
if shouldSkip {
|
|
next(ctx, params)
|
|
return
|
|
}
|
|
|
|
// Validate CSRF token for protected methods
|
|
if !csrf.ValidateFormToken(ctx, authManager) {
|
|
cfg.FailureHandler(ctx)
|
|
return
|
|
}
|
|
|
|
// CSRF validation passed, rotate token for security
|
|
csrf.RotateToken(ctx, authManager)
|
|
|
|
next(ctx, params)
|
|
}
|
|
}
|
|
}
|
|
|
|
// RequireCSRF is a stricter CSRF middleware that always validates tokens
|
|
func RequireCSRF(authManager *auth.AuthManager, failureHandler ...func(router.Ctx)) router.Middleware {
|
|
handler := func(ctx router.Ctx) {
|
|
ctx.SetStatusCode(fasthttp.StatusForbidden)
|
|
ctx.SetContentType("text/plain")
|
|
ctx.WriteString("CSRF token required")
|
|
}
|
|
|
|
if len(failureHandler) > 0 {
|
|
handler = failureHandler[0]
|
|
}
|
|
|
|
return func(next router.Handler) router.Handler {
|
|
return func(ctx router.Ctx, params []string) {
|
|
if !csrf.ValidateFormToken(ctx, authManager) {
|
|
handler(ctx)
|
|
return
|
|
}
|
|
|
|
// Rotate token after successful validation
|
|
csrf.RotateToken(ctx, authManager)
|
|
|
|
next(ctx, params)
|
|
}
|
|
}
|
|
}
|
|
|
|
// CSRFToken returns the current CSRF token for the request
|
|
func CSRFToken(ctx router.Ctx, authManager *auth.AuthManager) string {
|
|
return csrf.GetToken(ctx, authManager)
|
|
}
|
|
|
|
// CSRFHiddenField generates a hidden input field for forms
|
|
func CSRFHiddenField(ctx router.Ctx, authManager *auth.AuthManager) string {
|
|
return csrf.HiddenField(ctx, authManager)
|
|
}
|
|
|
|
// CSRFMeta generates a meta tag for JavaScript access
|
|
func CSRFMeta(ctx router.Ctx, authManager *auth.AuthManager) string {
|
|
return csrf.TokenMeta(ctx, authManager)
|
|
}
|