161 lines
4.3 KiB
Go

package csrf
import (
	"crypto/rand"
	"crypto/subtle"
	"encoding/base64"
	"fmt"
	"html"

	"dk/internal/auth"
	"dk/internal/router"
)
// Settings shared by every CSRF helper in this package.
const (
	// TokenLength is the number of random bytes drawn per token, before
	// base64 encoding.
	TokenLength = 32

	// TokenFieldName is the form / query parameter that carries a submitted token.
	TokenFieldName = "_csrf_token"

	// SessionKey is the session.Data key under which the current token is kept.
	SessionKey = "csrf_token"

	// SessionCtxKey is the request user-value key holding the *auth.Session.
	// It must stay equal to middleware.SessionKey.
	SessionCtxKey = "session"
)
// GetCurrentSession returns the *auth.Session stored in the request context
// under SessionCtxKey. It returns nil when no value is stored there or when
// the stored value has a different type. It mirrors the middleware's own
// lookup.
func GetCurrentSession(ctx router.Ctx) *auth.Session {
	// A failed comma-ok assertion yields the zero value, i.e. a nil pointer.
	session, _ := ctx.UserValue(SessionCtxKey).(*auth.Session)
	return session
}
// GenerateToken builds a fresh CSRF token from TokenLength bytes of
// crypto/rand output, encoded with padded URL-safe base64.
//
// If ctx carries a session, the token is written into it with StoreToken,
// replacing any previous token. Without a session, the token is returned but
// not persisted anywhere.
//
// It returns "" if the random source reports an error. authManager is
// accepted for API symmetry but is not used here.
func GenerateToken(ctx router.Ctx, authManager *auth.AuthManager) string {
	raw := make([]byte, TokenLength)
	if _, err := rand.Read(raw); err != nil {
		// Entropy failure: hand back "no token" rather than a weak one.
		return ""
	}
	token := base64.URLEncoding.EncodeToString(raw)

	session := GetCurrentSession(ctx)
	if session == nil {
		return token
	}
	StoreToken(session, token)
	return token
}
// GetToken returns the CSRF token held by the current session.
//
// If the session has no token yet, a new one is created through
// GenerateToken, which also stores it. It returns "" when ctx carries no
// session.
func GetToken(ctx router.Ctx, authManager *auth.AuthManager) string {
	session := GetCurrentSession(ctx)
	if session == nil {
		return ""
	}

	token := GetStoredToken(session)
	if token == "" {
		token = GenerateToken(ctx, authManager)
	}
	return token
}
// ValidateToken reports whether submittedToken equals the token stored in
// the current session.
//
// It returns false in each of these cases:
//   - the submitted token is empty;
//   - ctx carries no session;
//   - the session holds no token.
//
// The byte comparison is constant-time, so the check does not leak how much
// of the token matched.
func ValidateToken(ctx router.Ctx, authManager *auth.AuthManager, submittedToken string) bool {
	if submittedToken == "" {
		return false
	}

	session := GetCurrentSession(ctx)
	if session == nil {
		return false
	}

	expected := GetStoredToken(session)
	return expected != "" &&
		subtle.ConstantTimeCompare([]byte(submittedToken), []byte(expected)) == 1
}
// StoreToken writes token into session.Data under SessionKey.
//
// The Data map is created on first use. session must be non-nil.
func StoreToken(session *auth.Session, token string) {
	if session.Data == nil {
		session.Data = map[string]any{SessionKey: token}
		return
	}
	session.Data[SessionKey] = token
}
// GetStoredToken returns the string stored in session.Data under SessionKey.
//
// It returns "" when Data is nil, the key is missing, or the value is not a
// string. session must be non-nil.
func GetStoredToken(session *auth.Session) string {
	// Reading from a nil map yields a nil interface, and the comma-ok
	// assertion then produces "", so no explicit nil check is needed.
	token, _ := session.Data[SessionKey].(string)
	return token
}
// RotateToken replaces the session's CSRF token with a newly generated one
// and returns it.
//
// It returns "" when ctx carries no session. It also returns "" when token
// generation fails; in that case the previously stored token is left in
// place.
func RotateToken(ctx router.Ctx, authManager *auth.AuthManager) string {
	if GetCurrentSession(ctx) == nil {
		return ""
	}
	return GenerateToken(ctx, authManager)
}
// HiddenField renders an HTML hidden <input> named TokenFieldName whose value
// is the session's CSRF token.
//
// It returns "" when no token is available, for example when there is no
// session.
func HiddenField(ctx router.Ctx, authManager *auth.AuthManager) string {
	if token := GetToken(ctx, authManager); token != "" {
		return fmt.Sprintf(`<input type="hidden" name="%s" value="%s">`,
			TokenFieldName, escapeHTML(token))
	}
	return ""
}
// TokenMeta renders a <meta name="csrf-token"> tag carrying the session's
// CSRF token, so client-side scripts can read it.
//
// It returns "" when no token is available.
func TokenMeta(ctx router.Ctx, authManager *auth.AuthManager) string {
	if token := GetToken(ctx, authManager); token != "" {
		return fmt.Sprintf(`<meta name="csrf-token" content="%s">`, escapeHTML(token))
	}
	return ""
}
// escapeHTML escapes s so it can be safely embedded in HTML text or in a
// quoted attribute value.
//
// It escapes <, >, &, ' and ". Tokens produced by GenerateToken are URL-safe
// base64 and pass through unchanged. The value is still read from session
// data, though, so anything else stored under SessionKey must not be able to
// break out of the surrounding attribute.
func escapeHTML(s string) string {
	return html.EscapeString(s)
}
// ValidateFormToken reads the submitted CSRF token from the request and
// validates it with ValidateToken.
//
// The token is read from the POST body argument TokenFieldName. If that is
// empty, the query argument of the same name is used instead. It returns
// false when neither source carries a token.
//
// NOTE(review): accepting the token from the query string can expose it via
// URLs, server logs and Referer headers. Consider whether the fallback is
// required.
func ValidateFormToken(ctx router.Ctx, authManager *auth.AuthManager) bool {
	submitted := ctx.PostArgs().Peek(TokenFieldName)
	if len(submitted) == 0 {
		submitted = ctx.QueryArgs().Peek(TokenFieldName)
		if len(submitted) == 0 {
			return false
		}
	}
	return ValidateToken(ctx, authManager, string(submitted))
}