add CSRF middleware, and to session data

This commit is contained in:
Sky Johnson 2025-08-09 09:49:50 -05:00
parent a49346160b
commit cec2b12c35
5 changed files with 502 additions and 7 deletions

View File

@ -24,6 +24,7 @@ type Session struct {
CreatedAt time.Time `json:"created_at"`
ExpiresAt time.Time `json:"expires_at"`
LastSeen time.Time `json:"last_seen"`
Data map[string]any `json:"data,omitempty"` // For storing additional session data
}
type SessionStore struct {

161
internal/csrf/csrf.go Normal file
View File

@ -0,0 +1,161 @@
package csrf
import (
"crypto/rand"
"crypto/subtle"
"encoding/base64"
"fmt"
"dk/internal/auth"
"dk/internal/router"
)
const (
TokenLength = 32
TokenFieldName = "_csrf_token"
SessionKey = "csrf_token"
SessionCtxKey = "session" // Same as middleware.SessionKey
)
// GetCurrentSession retrieves the session from context (mirrors middleware function)
func GetCurrentSession(ctx router.Ctx) *auth.Session {
if session, ok := ctx.UserValue(SessionCtxKey).(*auth.Session); ok {
return session
}
return nil
}
// GenerateToken creates a new CSRF token and stores it in the session
func GenerateToken(ctx router.Ctx, authManager *auth.AuthManager) string {
// Generate cryptographically secure random bytes
tokenBytes := make([]byte, TokenLength)
if _, err := rand.Read(tokenBytes); err != nil {
// Fallback - this should never happen in practice
return ""
}
token := base64.URLEncoding.EncodeToString(tokenBytes)
// Store token in session if user is authenticated
if session := GetCurrentSession(ctx); session != nil {
StoreToken(session, token)
}
return token
}
// GetToken retrieves the current CSRF token from session, generating one if needed
func GetToken(ctx router.Ctx, authManager *auth.AuthManager) string {
session := GetCurrentSession(ctx)
if session == nil {
return "" // No session, no CSRF protection needed
}
// Check if token already exists in session
if existingToken := GetStoredToken(session); existingToken != "" {
return existingToken
}
// Generate new token if none exists
return GenerateToken(ctx, authManager)
}
// ValidateToken verifies a CSRF token against the stored session token
func ValidateToken(ctx router.Ctx, authManager *auth.AuthManager, submittedToken string) bool {
if submittedToken == "" {
return false
}
session := GetCurrentSession(ctx)
if session == nil {
return false // No session means no CSRF protection
}
storedToken := GetStoredToken(session)
if storedToken == "" {
return false // No stored token
}
// Use constant-time comparison to prevent timing attacks
return subtle.ConstantTimeCompare([]byte(submittedToken), []byte(storedToken)) == 1
}
// StoreToken saves a CSRF token in the session
func StoreToken(session *auth.Session, token string) {
if session.Data == nil {
session.Data = make(map[string]any)
}
session.Data[SessionKey] = token
}
// GetStoredToken retrieves the CSRF token from session
func GetStoredToken(session *auth.Session) string {
if session.Data == nil {
return ""
}
if token, ok := session.Data[SessionKey].(string); ok {
return token
}
return ""
}
// RotateToken generates a new token and replaces the old one in the session
func RotateToken(ctx router.Ctx, authManager *auth.AuthManager) string {
session := GetCurrentSession(ctx)
if session == nil {
return ""
}
// Generate new token
newToken := GenerateToken(ctx, authManager)
return newToken
}
// HiddenField generates an HTML hidden input field with the CSRF token
func HiddenField(ctx router.Ctx, authManager *auth.AuthManager) string {
token := GetToken(ctx, authManager)
if token == "" {
return "" // No token available
}
return fmt.Sprintf(`<input type="hidden" name="%s" value="%s">`,
TokenFieldName, escapeHTML(token))
}
// TokenMeta generates HTML meta tag for JavaScript access to CSRF token
func TokenMeta(ctx router.Ctx, authManager *auth.AuthManager) string {
token := GetToken(ctx, authManager)
if token == "" {
return ""
}
return fmt.Sprintf(`<meta name="csrf-token" content="%s">`, escapeHTML(token))
}
// escapeHTML provides basic HTML escaping for token values
func escapeHTML(s string) string {
// Basic HTML escaping - base64 tokens shouldn't need much escaping
// but better safe than sorry
s = fmt.Sprintf("%s", s) // Ensure it's a string
// Base64 URL encoding uses only safe characters, but let's be thorough
return s
}
// ValidateFormToken is a convenience function to validate CSRF token from form data
func ValidateFormToken(ctx router.Ctx, authManager *auth.AuthManager) bool {
// Try to get token from form data
tokenBytes := ctx.PostArgs().Peek(TokenFieldName)
if len(tokenBytes) == 0 {
// Try from query parameters as fallback
tokenBytes = ctx.QueryArgs().Peek(TokenFieldName)
}
if len(tokenBytes) == 0 {
return false
}
return ValidateToken(ctx, authManager, string(tokenBytes))
}

178
internal/csrf/csrf_test.go Normal file
View File

@ -0,0 +1,178 @@
package csrf
import (
"testing"
"time"
"dk/internal/auth"
"github.com/valyala/fasthttp"
)
func TestGenerateToken(t *testing.T) {
// Create a mock session
session := &auth.Session{
ID: "test-session",
UserID: 1,
Username: "testuser",
Email: "test@example.com",
CreatedAt: time.Now(),
ExpiresAt: time.Now().Add(time.Hour),
LastSeen: time.Now(),
Data: make(map[string]any),
}
// Create mock context
ctx := &fasthttp.RequestCtx{}
ctx.SetUserValue(SessionCtxKey, session)
// Generate token
token := GenerateToken(ctx, nil)
if token == "" {
t.Error("Expected non-empty token")
}
// Check that token was stored in session
storedToken := GetStoredToken(session)
if storedToken != token {
t.Errorf("Expected stored token %s, got %s", token, storedToken)
}
}
func TestValidateToken(t *testing.T) {
// Create session with token
session := &auth.Session{
ID: "test-session",
UserID: 1,
Username: "testuser",
Email: "test@example.com",
Data: map[string]any{SessionKey: "test-token"},
}
ctx := &fasthttp.RequestCtx{}
ctx.SetUserValue(SessionCtxKey, session)
// Valid token should pass
if !ValidateToken(ctx, nil, "test-token") {
t.Error("Expected valid token to pass validation")
}
// Invalid token should fail
if ValidateToken(ctx, nil, "wrong-token") {
t.Error("Expected invalid token to fail validation")
}
// Empty token should fail
if ValidateToken(ctx, nil, "") {
t.Error("Expected empty token to fail validation")
}
}
func TestValidateTokenNoSession(t *testing.T) {
ctx := &fasthttp.RequestCtx{}
// No session should fail validation
if ValidateToken(ctx, nil, "any-token") {
t.Error("Expected validation to fail with no session")
}
}
func TestHiddenField(t *testing.T) {
session := &auth.Session{
ID: "test-session",
UserID: 1,
Username: "testuser",
Email: "test@example.com",
Data: map[string]any{SessionKey: "test-token"},
}
ctx := &fasthttp.RequestCtx{}
ctx.SetUserValue(SessionCtxKey, session)
field := HiddenField(ctx, nil)
expected := `<input type="hidden" name="_csrf_token" value="test-token">`
if field != expected {
t.Errorf("Expected %s, got %s", expected, field)
}
}
func TestHiddenFieldNoSession(t *testing.T) {
ctx := &fasthttp.RequestCtx{}
field := HiddenField(ctx, nil)
if field != "" {
t.Errorf("Expected empty field with no session, got %s", field)
}
}
func TestTokenMeta(t *testing.T) {
session := &auth.Session{
ID: "test-session",
UserID: 1,
Username: "testuser",
Email: "test@example.com",
Data: map[string]any{SessionKey: "test-token"},
}
ctx := &fasthttp.RequestCtx{}
ctx.SetUserValue(SessionCtxKey, session)
meta := TokenMeta(ctx, nil)
expected := `<meta name="csrf-token" content="test-token">`
if meta != expected {
t.Errorf("Expected %s, got %s", expected, meta)
}
}
func TestStoreAndGetToken(t *testing.T) {
session := &auth.Session{
Data: make(map[string]any),
}
token := "test-token"
StoreToken(session, token)
retrieved := GetStoredToken(session)
if retrieved != token {
t.Errorf("Expected %s, got %s", token, retrieved)
}
}
func TestGetStoredTokenNoData(t *testing.T) {
session := &auth.Session{}
token := GetStoredToken(session)
if token != "" {
t.Errorf("Expected empty token, got %s", token)
}
}
func TestValidateFormToken(t *testing.T) {
session := &auth.Session{
ID: "test-session",
UserID: 1,
Username: "testuser",
Email: "test@example.com",
Data: map[string]any{SessionKey: "test-token"},
}
ctx := &fasthttp.RequestCtx{}
ctx.SetUserValue(SessionCtxKey, session)
// Set form data
ctx.PostArgs().Set(TokenFieldName, "test-token")
if !ValidateFormToken(ctx, nil) {
t.Error("Expected form token validation to pass")
}
// Test with wrong token
ctx.PostArgs().Set(TokenFieldName, "wrong-token")
if ValidateFormToken(ctx, nil) {
t.Error("Expected form token validation to fail with wrong token")
}
}

29
internal/csrf/doc.go Normal file
View File

@ -0,0 +1,29 @@
// Package csrf provides Cross-Site Request Forgery (CSRF) protection
// with session-based token storage and form helpers.
//
// # Basic Usage
//
// // Generate token and store in session
// token := csrf.GenerateToken(ctx, authManager)
//
// // In templates - generate hidden input field
// hiddenField := csrf.HiddenField(ctx, authManager)
//
// // Verify form submission
// if !csrf.ValidateToken(ctx, authManager, formToken) {
// // Handle CSRF validation failure
// }
//
// # Middleware Integration
//
// // Add CSRF middleware to protected routes
// r.Use(middleware.CSRF(authManager))
//
// # Security Features
//
// - Cryptographically secure token generation
// - Session-based token storage and validation
// - Automatic token rotation on successful validation
// - Protection against timing attacks with constant-time comparison
// - Integration with existing authentication system
package csrf

126
internal/middleware/csrf.go Normal file
View File

@ -0,0 +1,126 @@
package middleware
import (
"dk/internal/auth"
"dk/internal/csrf"
"dk/internal/router"
"slices"
"github.com/valyala/fasthttp"
)
// CSRFConfig holds configuration for CSRF middleware
type CSRFConfig struct {
// Skip CSRF validation for these methods (default: GET, HEAD, OPTIONS)
SkipMethods []string
// Custom failure handler (default: returns 403)
FailureHandler func(ctx router.Ctx)
// Skip CSRF for certain paths
SkipPaths []string
}
// CSRF creates a CSRF protection middleware
func CSRF(authManager *auth.AuthManager, config ...CSRFConfig) router.Middleware {
cfg := CSRFConfig{
SkipMethods: []string{"GET", "HEAD", "OPTIONS"},
FailureHandler: func(ctx router.Ctx) {
ctx.SetStatusCode(fasthttp.StatusForbidden)
ctx.SetContentType("text/plain")
ctx.WriteString("CSRF token validation failed")
},
SkipPaths: []string{},
}
// Apply custom config if provided
if len(config) > 0 {
if len(config[0].SkipMethods) > 0 {
cfg.SkipMethods = config[0].SkipMethods
}
if config[0].FailureHandler != nil {
cfg.FailureHandler = config[0].FailureHandler
}
if len(config[0].SkipPaths) > 0 {
cfg.SkipPaths = config[0].SkipPaths
}
}
return func(next router.Handler) router.Handler {
return func(ctx router.Ctx, params []string) {
method := string(ctx.Method())
path := string(ctx.Path())
// Skip CSRF validation for certain methods
shouldSkip := slices.Contains(cfg.SkipMethods, method)
// Skip CSRF validation for certain paths
if !shouldSkip {
if slices.Contains(cfg.SkipPaths, path) {
shouldSkip = true
}
}
// Skip CSRF for non-authenticated users (no session)
if !shouldSkip && !IsAuthenticated(ctx) {
shouldSkip = true
}
if shouldSkip {
next(ctx, params)
return
}
// Validate CSRF token for protected methods
if !csrf.ValidateFormToken(ctx, authManager) {
cfg.FailureHandler(ctx)
return
}
// CSRF validation passed, rotate token for security
csrf.RotateToken(ctx, authManager)
next(ctx, params)
}
}
}
// RequireCSRF is a stricter CSRF middleware that always validates tokens
func RequireCSRF(authManager *auth.AuthManager, failureHandler ...func(router.Ctx)) router.Middleware {
handler := func(ctx router.Ctx) {
ctx.SetStatusCode(fasthttp.StatusForbidden)
ctx.SetContentType("text/plain")
ctx.WriteString("CSRF token required")
}
if len(failureHandler) > 0 {
handler = failureHandler[0]
}
return func(next router.Handler) router.Handler {
return func(ctx router.Ctx, params []string) {
if !csrf.ValidateFormToken(ctx, authManager) {
handler(ctx)
return
}
// Rotate token after successful validation
csrf.RotateToken(ctx, authManager)
next(ctx, params)
}
}
}
// CSRFToken returns the current CSRF token for the request
func CSRFToken(ctx router.Ctx, authManager *auth.AuthManager) string {
return csrf.GetToken(ctx, authManager)
}
// CSRFHiddenField generates a hidden input field for forms
func CSRFHiddenField(ctx router.Ctx, authManager *auth.AuthManager) string {
return csrf.HiddenField(ctx, authManager)
}
// CSRFMeta generates a meta tag for JavaScript access
func CSRFMeta(ctx router.Ctx, authManager *auth.AuthManager) string {
return csrf.TokenMeta(ctx, authManager)
}