// Package csrf provides Cross-Site Request Forgery (CSRF) protection and
// HTML form helpers. Tokens are stored in the session when one is attached
// to the request, and in a cookie otherwise (guest users).
//
// # Basic Usage
//
//	// Generate a token and store it in the session (or a guest cookie)
//	token := csrf.GenerateToken(ctx)
//
//	// In templates - generate hidden input field
//	hiddenField := csrf.HiddenField(ctx)
//
//	// Verify form submission
//	if !csrf.ValidateToken(ctx, formToken) {
//		// Handle CSRF validation failure
//	}
//
// # Middleware Integration
//
// // Add CSRF middleware to protected routes
// r.Use(middleware.CSRF(authManager))
package csrf
import (
	"crypto/rand"
	"crypto/subtle"
	"encoding/base64"
	"fmt"
	"html"
	"time"

	"dk/internal/router"
	"dk/internal/session"

	"github.com/valyala/fasthttp"
)
const (
	// TokenLength is the number of random bytes in a token before base64 encoding.
	TokenLength = 32
	// TokenFieldName is the form/query parameter name that carries the token.
	TokenFieldName = "_csrf_token"
	// SessionKey is the session key under which an authenticated user's token is kept.
	SessionKey = "csrf_token"
	// SessionCtxKey is the request user-value key holding the session.
	// It must match middleware.SessionKey.
	SessionCtxKey = "session"
	// CookieName is the cookie that stores a guest user's token.
	CookieName = "_csrf"
)
// GetCurrentSession returns the *session.Session stored in ctx's user values
// under SessionCtxKey. It returns nil when nothing is stored there or the
// stored value has another type. This mirrors the middleware's lookup.
func GetCurrentSession(ctx router.Ctx) *session.Session {
	// A failed type assertion yields the zero value, i.e. a nil pointer.
	sess, _ := ctx.UserValue(SessionCtxKey).(*session.Session)
	return sess
}
// GenerateToken creates a fresh random CSRF token and persists it. If a
// session is attached to ctx, the token goes into the session. Otherwise it
// is written to the guest cookie on the response.
//
// It returns "" if the system random source fails.
func GenerateToken(ctx router.Ctx) string {
	raw := make([]byte, TokenLength)
	if _, err := rand.Read(raw); err != nil {
		// Not expected in practice; the package's helpers treat "" as
		// "no token available".
		return ""
	}
	token := base64.URLEncoding.EncodeToString(raw)

	sess := GetCurrentSession(ctx)
	if sess == nil {
		// Guest user: no session, so fall back to a cookie.
		StoreTokenInCookie(ctx, token)
		return token
	}
	StoreToken(sess, token)
	return token
}
// GetToken returns the CSRF token for the current request, generating one if
// none exists yet.
//
// When a session is attached to ctx, the token stored in the session is used.
//
// Otherwise (guest user) the token lives in the CSRF cookie. The response
// cookie is checked before the request cookie. If a token was already issued
// while handling this same request, it is reused. Without this, a second call
// (for example HiddenField followed by TokenMeta) would generate a new token,
// and its Set-Cookie would overwrite the first one. That would invalidate the
// token already rendered into the page.
func GetToken(ctx router.Ctx) string {
	if sess := GetCurrentSession(ctx); sess != nil {
		if existing := GetStoredToken(sess); existing != "" {
			return existing
		}
		return GenerateToken(ctx)
	}

	if issued := tokenIssuedThisRequest(ctx); issued != "" {
		return issued
	}
	if existing := GetTokenFromCookie(ctx); existing != "" {
		return existing
	}
	return GenerateToken(ctx)
}

// tokenIssuedThisRequest returns the CSRF cookie value already queued on the
// response (e.g. by StoreTokenInCookie earlier in this request), or "" if none.
func tokenIssuedThisRequest(ctx router.Ctx) string {
	c := fasthttp.AcquireCookie()
	defer fasthttp.ReleaseCookie(c)
	c.SetKey(CookieName)
	if !ctx.Response.Header.Cookie(c) {
		return ""
	}
	// string() copies, so the value stays valid after the cookie is released.
	return string(c.Value())
}
// ValidateToken reports whether submittedToken matches the token stored for
// this client. The stored token is the session token when a session is
// attached to ctx, otherwise the guest cookie sent with the request.
//
// An empty submitted token, or no stored token, never validates.
func ValidateToken(ctx router.Ctx, submittedToken string) bool {
	if len(submittedToken) == 0 {
		return false
	}

	var expected string
	if sess := GetCurrentSession(ctx); sess != nil {
		expected = GetStoredToken(sess)
	} else {
		expected = GetTokenFromCookie(ctx)
	}
	if len(expected) == 0 {
		// No token is stored for this client, so nothing can match.
		return false
	}

	// Compare in constant time so timing does not reveal how many leading
	// bytes matched.
	return subtle.ConstantTimeCompare([]byte(submittedToken), []byte(expected)) == 1
}
// StoreToken writes token into sess under SessionKey via sess.Set.
func StoreToken(sess *session.Session, token string) {
	sess.Set(SessionKey, token)
}
// GetStoredToken returns the token saved in sess under SessionKey. It returns
// "" when the key is absent or the stored value is not a string.
func GetStoredToken(sess *session.Session) string {
	raw, found := sess.Get(SessionKey)
	if !found {
		return ""
	}
	// A non-string value yields the zero value "" from the assertion.
	tokenStr, _ := raw.(string)
	return tokenStr
}
// RotateToken replaces the session's CSRF token with a freshly generated one
// and returns it.
//
// It only applies to requests with a session attached. For guests it returns
// "" and leaves the cookie token unchanged.
func RotateToken(ctx router.Ctx) string {
	if GetCurrentSession(ctx) == nil {
		return ""
	}
	// GenerateToken sees the same session and overwrites its stored token.
	return GenerateToken(ctx)
}
// HiddenField returns an HTML hidden <input> that carries the current CSRF
// token under the TokenFieldName parameter, for embedding in a <form>. It
// returns "" if no token is available.
//
// The previous implementation passed an empty format string to Sprintf, so it
// produced "%!(EXTRA ...)" instead of markup. Both attribute values are now
// HTML-escaped. For guests the token is read from a request cookie, which the
// client controls, so it cannot be assumed to be plain base64.
func HiddenField(ctx router.Ctx) string {
	token := GetToken(ctx)
	if token == "" {
		return "" // No token available
	}
	return fmt.Sprintf(`<input type="hidden" name="%s" value="%s">`,
		html.EscapeString(TokenFieldName), html.EscapeString(token))
}
// TokenMeta returns an HTML <meta> tag exposing the current CSRF token to
// JavaScript. It returns "" if no token is available.
//
// The previous implementation passed an empty format string to Sprintf and
// emitted "%!(EXTRA ...)" instead of a tag. The token is HTML-escaped because
// for guests it comes from a client-supplied cookie.
//
// NOTE(review): the original tag name was lost. "csrf-token" follows the
// common convention; confirm it matches what the front-end script reads.
func TokenMeta(ctx router.Ctx) string {
	token := GetToken(ctx)
	if token == "" {
		return ""
	}
	return fmt.Sprintf(`<meta name="csrf-token" content="%s">`, html.EscapeString(token))
}
// ValidateFormToken reads the submitted token from the TokenFieldName POST
// argument, falling back to the query string. It then checks the token with
// ValidateToken. A missing or empty token fails validation.
//
// NOTE(review): accepting the token from the query string lets it leak via
// logs and Referer headers. Also, fasthttp PostArgs typically covers
// urlencoded bodies only, so multipart forms may not be found here. Confirm
// both are intended.
func ValidateFormToken(ctx router.Ctx) bool {
	submitted := ctx.PostArgs().Peek(TokenFieldName)
	if len(submitted) == 0 {
		submitted = ctx.QueryArgs().Peek(TokenFieldName)
	}
	return len(submitted) != 0 && ValidateToken(ctx, string(submitted))
}
// StoreTokenInCookie queues a Set-Cookie header on the response carrying a
// guest user's CSRF token under CookieName. The cookie has these attributes:
// HttpOnly, SameSite=Strict, path "/", and an expiry 24 hours from now.
//
// The Secure flag follows the connection: it is set when the request arrived
// over TLS, instead of being hard-coded to false.
// NOTE(review): behind a TLS-terminating proxy IsTLS reports false, so the flag
// stays off there, as before. Set it unconditionally if that is the deployment.
func StoreTokenInCookie(ctx router.Ctx, token string) {
	cookie := fasthttp.AcquireCookie()
	// SetCookie copies the cookie, so it is safe to return it to the pool.
	defer fasthttp.ReleaseCookie(cookie)

	cookie.SetKey(CookieName)
	cookie.SetValue(token)
	cookie.SetPath("/")
	cookie.SetHTTPOnly(true)
	cookie.SetSameSite(fasthttp.CookieSameSiteStrictMode)
	cookie.SetSecure(ctx.IsTLS())
	cookie.SetExpire(time.Now().Add(24 * time.Hour))
	ctx.Response.Header.SetCookie(cookie)
}
// GetTokenFromCookie returns the guest CSRF token sent by the client in the
// CookieName request cookie, or "" if the cookie is absent. The value comes
// straight from the client and is not validated here.
func GetTokenFromCookie(ctx router.Ctx) string {
	raw := ctx.Request.Header.Cookie(CookieName)
	return string(raw)
}