196 lines
4.9 KiB
Go

// Package csrf provides Cross-Site Request Forgery (CSRF) protection
// with session-based token storage and form helpers.
//
// # Basic Usage
//
//	// Generate a token and store it in the session
//	token := csrf.GenerateToken(ctx)
//
//	// In templates, render a hidden input field carrying the token
//	hiddenField := csrf.HiddenField(ctx)
//
//	// Verify a form submission
//	if !csrf.ValidateToken(ctx, formToken) {
//		// Handle CSRF validation failure
//	}
//
// GenerateToken always creates (and stores) a new token; use GetToken to
// reuse the token already held in the session.
//
// # Middleware Integration
//
//	// Add the CSRF middleware to protected routes
//	r.Use(csrf.Middleware())
package csrf
import (
	"crypto/rand"
	"crypto/subtle"
	"encoding/base64"
	"fmt"
	"html"

	"dk/internal/router"
	"dk/internal/session"
)
// Configuration shared by the token, session, and form helpers.
const (
	// TokenLength is the number of random bytes in a token, before base64
	// encoding.
	TokenLength = 32
	// TokenFieldName is the form/query parameter name that carries the
	// submitted token.
	TokenFieldName = "_csrf_token"

	// SessionKey is the session entry under which the current token is kept.
	SessionKey = "csrf_token"
	// SessionCtxKey is the request user-value key where the
	// *session.Session is looked up.
	SessionCtxKey = "session"
)
// GetCurrentSession returns the *session.Session stored in the request's
// user values under SessionCtxKey. It returns nil when the key is absent or
// holds a value of a different type.
func GetCurrentSession(ctx router.Ctx) *session.Session {
	// A failed comma-ok assertion yields the zero value, i.e. nil.
	current, _ := ctx.UserValue(SessionCtxKey).(*session.Session)
	return current
}
// GenerateToken creates a fresh token of TokenLength random bytes, encoded
// with base64.URLEncoding.
//
// If the request has a session, the token is saved into it with StoreToken
// and the session is persisted with session.Store. Without a session the
// token is still returned, but nothing is stored.
// It returns "" if reading from crypto/rand fails.
func GenerateToken(ctx router.Ctx) string {
	raw := make([]byte, TokenLength)
	if _, err := rand.Read(raw); err != nil {
		// crypto/rand failing is not expected in practice.
		return ""
	}
	encoded := base64.URLEncoding.EncodeToString(raw)

	sess := GetCurrentSession(ctx)
	if sess == nil {
		return encoded
	}
	StoreToken(sess, encoded)
	session.Store(sess)
	return encoded
}
// GetToken returns the token already held in the request's session. If the
// session has none, a new one is created (and stored) via GenerateToken.
// It returns "" when the request has no session.
func GetToken(ctx router.Ctx) string {
	sess := GetCurrentSession(ctx)
	if sess == nil {
		return ""
	}
	token := GetStoredToken(sess)
	if token == "" {
		token = GenerateToken(ctx)
	}
	return token
}
// ValidateToken reports whether submittedToken equals the token stored in
// the request's session.
//
// It returns false if submittedToken is empty, if there is no session, or if
// the session holds no token. The comparison uses subtle.ConstantTimeCompare
// so that match timing does not depend on how many leading bytes agree.
func ValidateToken(ctx router.Ctx, submittedToken string) bool {
	if len(submittedToken) == 0 {
		return false
	}
	sess := GetCurrentSession(ctx)
	if sess == nil {
		return false
	}
	expected := GetStoredToken(sess)
	if len(expected) == 0 {
		return false
	}
	same := subtle.ConstantTimeCompare([]byte(submittedToken), []byte(expected))
	return same == 1
}
// StoreToken saves token in sess under SessionKey.
//
// A nil session is ignored rather than dereferenced. This matches how the
// other helpers here treat a missing session (they return ""/false instead
// of panicking). StoreToken only updates the in-memory session; persisting
// it is up to the caller (GenerateToken calls session.Store afterwards).
func StoreToken(sess *session.Session, token string) {
	if sess == nil {
		return
	}
	sess.Set(SessionKey, token)
}
// GetStoredToken returns the token stored in sess under SessionKey.
//
// It returns "" when sess is nil, when the key is absent, or when the stored
// value is not a string. The nil check keeps a missing session a soft
// failure instead of a panic.
func GetStoredToken(sess *session.Session) string {
	if sess == nil {
		return ""
	}
	value, ok := sess.Get(SessionKey)
	if !ok {
		return ""
	}
	token, _ := value.(string) // non-string values yield ""
	return token
}
// RotateToken replaces the session's token with a newly generated one and
// returns it. GenerateToken both creates and stores the new token. It
// returns "" without generating anything when the request has no session.
func RotateToken(ctx router.Ctx) string {
	if GetCurrentSession(ctx) == nil {
		return ""
	}
	return GenerateToken(ctx)
}
// HiddenField returns an HTML hidden <input> named TokenFieldName whose value
// is the current token (obtained via GetToken). It returns "" when no token
// is available, e.g. when the request has no session.
//
// The token is read back from session storage, so it is HTML-escaped before
// interpolation. That way the attribute cannot be broken out of even if the
// stored value was not produced by GenerateToken. For GenerateToken's
// base64url tokens the escaping is a no-op.
func HiddenField(ctx router.Ctx) string {
	token := GetToken(ctx)
	if token == "" {
		return "" // No token available
	}
	return fmt.Sprintf(`<input type="hidden" name="%s" value="%s">`,
		TokenFieldName, html.EscapeString(token))
}
// TokenMeta returns an HTML <meta name="csrf-token"> tag whose content is the
// current token (obtained via GetToken), for client-side JavaScript to read.
// It returns "" when no token is available, e.g. when there is no session.
//
// As in HiddenField, the token is HTML-escaped because it comes from session
// storage. The escaping is a no-op for GenerateToken's base64url tokens.
func TokenMeta(ctx router.Ctx) string {
	token := GetToken(ctx)
	if token == "" {
		return ""
	}
	return fmt.Sprintf(`<meta name="csrf-token" content="%s">`, html.EscapeString(token))
}
// tokenHeaderName is the request header checked for a CSRF token. It is
// meant for JavaScript clients that read the token from the TokenMeta tag.
const tokenHeaderName = "X-CSRF-Token"

// ValidateFormToken extracts the submitted CSRF token from the request and
// checks it with ValidateToken.
//
// The token is looked up in this order, and the first non-empty value wins:
//  1. the url-encoded POST body field TokenFieldName
//  2. the multipart/form-data field TokenFieldName (PostArgs does not
//     include multipart fields, so forms with file uploads need this)
//  3. the X-CSRF-Token request header
//  4. the query parameter TokenFieldName (legacy fallback)
//
// It returns false when no token is found.
func ValidateFormToken(ctx router.Ctx) bool {
	token := tokenFromRequest(ctx)
	if token == "" {
		return false
	}
	return ValidateToken(ctx, token)
}

// tokenFromRequest returns the first non-empty CSRF token found in the
// request, in ValidateFormToken's documented order, or "" if none is present.
func tokenFromRequest(ctx router.Ctx) string {
	if v := ctx.PostArgs().Peek(TokenFieldName); len(v) > 0 {
		return string(v)
	}
	// An error here just means the body is not multipart/form-data.
	if form, err := ctx.MultipartForm(); err == nil && form != nil {
		if vals := form.Value[TokenFieldName]; len(vals) > 0 && vals[0] != "" {
			return vals[0]
		}
	}
	if v := ctx.Request.Header.Peek(tokenHeaderName); len(v) > 0 {
		return string(v)
	}
	return string(ctx.QueryArgs().Peek(TokenFieldName))
}
// GetTokenFromCookie returns the raw value of the request's "_csrf" cookie,
// or "" if it is not set. It is kept for legacy callers; the other helpers
// here use the session-stored token instead.
func GetTokenFromCookie(ctx router.Ctx) string {
	cookie := ctx.Request.Header.Cookie("_csrf")
	return string(cookie)
}
// Middleware returns a middleware that enforces CSRF validation, via
// ValidateFormToken, on state-changing requests (POST, PUT, PATCH, DELETE).
// Other methods pass straight through to the next handler.
//
// On failure it prints a message, rotates the session token, and redirects
// back to the request path with 303 See Other. 303 is used instead of 302
// because clients only turn a 302 into a GET when the original method was
// POST. A failed PUT/PATCH/DELETE answered with 302 is replayed against the
// same path by fetch/XHR, fails again (the token was just rotated), and
// loops until the client's redirect limit. The redirect target is
// ctx.Path(), so any query string is not carried over.
func Middleware() router.Middleware {
	return func(next router.Handler) router.Handler {
		return func(ctx router.Ctx, params []string) {
			switch string(ctx.Method()) {
			case "POST", "PUT", "PATCH", "DELETE":
				if !ValidateFormToken(ctx) {
					fmt.Println("Failed CSRF validation.")
					RotateToken(ctx)
					ctx.Redirect(string(ctx.Path()), 303) // 303 See Other
					return
				}
			}
			next(ctx, params)
		}
	}
}