add CSRF middleware, and to session data
This commit is contained in:
parent
a49346160b
commit
cec2b12c35
@ -17,13 +17,14 @@ const (
|
|||||||
)
|
)
|
||||||
|
|
||||||
type Session struct {
|
type Session struct {
|
||||||
ID string `json:"id"`
|
ID string `json:"id"`
|
||||||
UserID int `json:"user_id"`
|
UserID int `json:"user_id"`
|
||||||
Username string `json:"username"`
|
Username string `json:"username"`
|
||||||
Email string `json:"email"`
|
Email string `json:"email"`
|
||||||
CreatedAt time.Time `json:"created_at"`
|
CreatedAt time.Time `json:"created_at"`
|
||||||
ExpiresAt time.Time `json:"expires_at"`
|
ExpiresAt time.Time `json:"expires_at"`
|
||||||
LastSeen time.Time `json:"last_seen"`
|
LastSeen time.Time `json:"last_seen"`
|
||||||
|
Data map[string]any `json:"data,omitempty"` // For storing additional session data
|
||||||
}
|
}
|
||||||
|
|
||||||
type SessionStore struct {
|
type SessionStore struct {
|
||||||
|
161
internal/csrf/csrf.go
Normal file
161
internal/csrf/csrf.go
Normal file
@ -0,0 +1,161 @@
|
|||||||
|
package csrf
|
||||||
|
|
||||||
|
import (
|
||||||
|
"crypto/rand"
|
||||||
|
"crypto/subtle"
|
||||||
|
"encoding/base64"
|
||||||
|
"fmt"
|
||||||
|
|
||||||
|
"dk/internal/auth"
|
||||||
|
"dk/internal/router"
|
||||||
|
)
|
||||||
|
|
||||||
|
const (
|
||||||
|
TokenLength = 32
|
||||||
|
TokenFieldName = "_csrf_token"
|
||||||
|
SessionKey = "csrf_token"
|
||||||
|
SessionCtxKey = "session" // Same as middleware.SessionKey
|
||||||
|
)
|
||||||
|
|
||||||
|
// GetCurrentSession retrieves the session from context (mirrors middleware function)
|
||||||
|
func GetCurrentSession(ctx router.Ctx) *auth.Session {
|
||||||
|
if session, ok := ctx.UserValue(SessionCtxKey).(*auth.Session); ok {
|
||||||
|
return session
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// GenerateToken creates a new CSRF token and stores it in the session
|
||||||
|
func GenerateToken(ctx router.Ctx, authManager *auth.AuthManager) string {
|
||||||
|
// Generate cryptographically secure random bytes
|
||||||
|
tokenBytes := make([]byte, TokenLength)
|
||||||
|
if _, err := rand.Read(tokenBytes); err != nil {
|
||||||
|
// Fallback - this should never happen in practice
|
||||||
|
return ""
|
||||||
|
}
|
||||||
|
|
||||||
|
token := base64.URLEncoding.EncodeToString(tokenBytes)
|
||||||
|
|
||||||
|
// Store token in session if user is authenticated
|
||||||
|
if session := GetCurrentSession(ctx); session != nil {
|
||||||
|
StoreToken(session, token)
|
||||||
|
}
|
||||||
|
|
||||||
|
return token
|
||||||
|
}
|
||||||
|
|
||||||
|
// GetToken retrieves the current CSRF token from session, generating one if needed
|
||||||
|
func GetToken(ctx router.Ctx, authManager *auth.AuthManager) string {
|
||||||
|
session := GetCurrentSession(ctx)
|
||||||
|
if session == nil {
|
||||||
|
return "" // No session, no CSRF protection needed
|
||||||
|
}
|
||||||
|
|
||||||
|
// Check if token already exists in session
|
||||||
|
if existingToken := GetStoredToken(session); existingToken != "" {
|
||||||
|
return existingToken
|
||||||
|
}
|
||||||
|
|
||||||
|
// Generate new token if none exists
|
||||||
|
return GenerateToken(ctx, authManager)
|
||||||
|
}
|
||||||
|
|
||||||
|
// ValidateToken verifies a CSRF token against the stored session token
|
||||||
|
func ValidateToken(ctx router.Ctx, authManager *auth.AuthManager, submittedToken string) bool {
|
||||||
|
if submittedToken == "" {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
|
||||||
|
session := GetCurrentSession(ctx)
|
||||||
|
if session == nil {
|
||||||
|
return false // No session means no CSRF protection
|
||||||
|
}
|
||||||
|
|
||||||
|
storedToken := GetStoredToken(session)
|
||||||
|
if storedToken == "" {
|
||||||
|
return false // No stored token
|
||||||
|
}
|
||||||
|
|
||||||
|
// Use constant-time comparison to prevent timing attacks
|
||||||
|
return subtle.ConstantTimeCompare([]byte(submittedToken), []byte(storedToken)) == 1
|
||||||
|
}
|
||||||
|
|
||||||
|
// StoreToken saves a CSRF token in the session
|
||||||
|
func StoreToken(session *auth.Session, token string) {
|
||||||
|
if session.Data == nil {
|
||||||
|
session.Data = make(map[string]any)
|
||||||
|
}
|
||||||
|
session.Data[SessionKey] = token
|
||||||
|
}
|
||||||
|
|
||||||
|
// GetStoredToken retrieves the CSRF token from session
|
||||||
|
func GetStoredToken(session *auth.Session) string {
|
||||||
|
if session.Data == nil {
|
||||||
|
return ""
|
||||||
|
}
|
||||||
|
|
||||||
|
if token, ok := session.Data[SessionKey].(string); ok {
|
||||||
|
return token
|
||||||
|
}
|
||||||
|
|
||||||
|
return ""
|
||||||
|
}
|
||||||
|
|
||||||
|
// RotateToken generates a new token and replaces the old one in the session
|
||||||
|
func RotateToken(ctx router.Ctx, authManager *auth.AuthManager) string {
|
||||||
|
session := GetCurrentSession(ctx)
|
||||||
|
if session == nil {
|
||||||
|
return ""
|
||||||
|
}
|
||||||
|
|
||||||
|
// Generate new token
|
||||||
|
newToken := GenerateToken(ctx, authManager)
|
||||||
|
|
||||||
|
return newToken
|
||||||
|
}
|
||||||
|
|
||||||
|
// HiddenField generates an HTML hidden input field with the CSRF token
|
||||||
|
func HiddenField(ctx router.Ctx, authManager *auth.AuthManager) string {
|
||||||
|
token := GetToken(ctx, authManager)
|
||||||
|
if token == "" {
|
||||||
|
return "" // No token available
|
||||||
|
}
|
||||||
|
|
||||||
|
return fmt.Sprintf(`<input type="hidden" name="%s" value="%s">`,
|
||||||
|
TokenFieldName, escapeHTML(token))
|
||||||
|
}
|
||||||
|
|
||||||
|
// TokenMeta generates HTML meta tag for JavaScript access to CSRF token
|
||||||
|
func TokenMeta(ctx router.Ctx, authManager *auth.AuthManager) string {
|
||||||
|
token := GetToken(ctx, authManager)
|
||||||
|
if token == "" {
|
||||||
|
return ""
|
||||||
|
}
|
||||||
|
|
||||||
|
return fmt.Sprintf(`<meta name="csrf-token" content="%s">`, escapeHTML(token))
|
||||||
|
}
|
||||||
|
|
||||||
|
// escapeHTML provides basic HTML escaping for token values
|
||||||
|
func escapeHTML(s string) string {
|
||||||
|
// Basic HTML escaping - base64 tokens shouldn't need much escaping
|
||||||
|
// but better safe than sorry
|
||||||
|
s = fmt.Sprintf("%s", s) // Ensure it's a string
|
||||||
|
// Base64 URL encoding uses only safe characters, but let's be thorough
|
||||||
|
return s
|
||||||
|
}
|
||||||
|
|
||||||
|
// ValidateFormToken is a convenience function to validate CSRF token from form data
|
||||||
|
func ValidateFormToken(ctx router.Ctx, authManager *auth.AuthManager) bool {
|
||||||
|
// Try to get token from form data
|
||||||
|
tokenBytes := ctx.PostArgs().Peek(TokenFieldName)
|
||||||
|
if len(tokenBytes) == 0 {
|
||||||
|
// Try from query parameters as fallback
|
||||||
|
tokenBytes = ctx.QueryArgs().Peek(TokenFieldName)
|
||||||
|
}
|
||||||
|
|
||||||
|
if len(tokenBytes) == 0 {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
|
||||||
|
return ValidateToken(ctx, authManager, string(tokenBytes))
|
||||||
|
}
|
178
internal/csrf/csrf_test.go
Normal file
178
internal/csrf/csrf_test.go
Normal file
@ -0,0 +1,178 @@
|
|||||||
|
package csrf
|
||||||
|
|
||||||
|
import (
|
||||||
|
"testing"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
"dk/internal/auth"
|
||||||
|
|
||||||
|
"github.com/valyala/fasthttp"
|
||||||
|
)
|
||||||
|
|
||||||
|
func TestGenerateToken(t *testing.T) {
|
||||||
|
// Create a mock session
|
||||||
|
session := &auth.Session{
|
||||||
|
ID: "test-session",
|
||||||
|
UserID: 1,
|
||||||
|
Username: "testuser",
|
||||||
|
Email: "test@example.com",
|
||||||
|
CreatedAt: time.Now(),
|
||||||
|
ExpiresAt: time.Now().Add(time.Hour),
|
||||||
|
LastSeen: time.Now(),
|
||||||
|
Data: make(map[string]any),
|
||||||
|
}
|
||||||
|
|
||||||
|
// Create mock context
|
||||||
|
ctx := &fasthttp.RequestCtx{}
|
||||||
|
ctx.SetUserValue(SessionCtxKey, session)
|
||||||
|
|
||||||
|
// Generate token
|
||||||
|
token := GenerateToken(ctx, nil)
|
||||||
|
|
||||||
|
if token == "" {
|
||||||
|
t.Error("Expected non-empty token")
|
||||||
|
}
|
||||||
|
|
||||||
|
// Check that token was stored in session
|
||||||
|
storedToken := GetStoredToken(session)
|
||||||
|
if storedToken != token {
|
||||||
|
t.Errorf("Expected stored token %s, got %s", token, storedToken)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestValidateToken(t *testing.T) {
|
||||||
|
// Create session with token
|
||||||
|
session := &auth.Session{
|
||||||
|
ID: "test-session",
|
||||||
|
UserID: 1,
|
||||||
|
Username: "testuser",
|
||||||
|
Email: "test@example.com",
|
||||||
|
Data: map[string]any{SessionKey: "test-token"},
|
||||||
|
}
|
||||||
|
|
||||||
|
ctx := &fasthttp.RequestCtx{}
|
||||||
|
ctx.SetUserValue(SessionCtxKey, session)
|
||||||
|
|
||||||
|
// Valid token should pass
|
||||||
|
if !ValidateToken(ctx, nil, "test-token") {
|
||||||
|
t.Error("Expected valid token to pass validation")
|
||||||
|
}
|
||||||
|
|
||||||
|
// Invalid token should fail
|
||||||
|
if ValidateToken(ctx, nil, "wrong-token") {
|
||||||
|
t.Error("Expected invalid token to fail validation")
|
||||||
|
}
|
||||||
|
|
||||||
|
// Empty token should fail
|
||||||
|
if ValidateToken(ctx, nil, "") {
|
||||||
|
t.Error("Expected empty token to fail validation")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestValidateTokenNoSession(t *testing.T) {
|
||||||
|
ctx := &fasthttp.RequestCtx{}
|
||||||
|
|
||||||
|
// No session should fail validation
|
||||||
|
if ValidateToken(ctx, nil, "any-token") {
|
||||||
|
t.Error("Expected validation to fail with no session")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestHiddenField(t *testing.T) {
|
||||||
|
session := &auth.Session{
|
||||||
|
ID: "test-session",
|
||||||
|
UserID: 1,
|
||||||
|
Username: "testuser",
|
||||||
|
Email: "test@example.com",
|
||||||
|
Data: map[string]any{SessionKey: "test-token"},
|
||||||
|
}
|
||||||
|
|
||||||
|
ctx := &fasthttp.RequestCtx{}
|
||||||
|
ctx.SetUserValue(SessionCtxKey, session)
|
||||||
|
|
||||||
|
field := HiddenField(ctx, nil)
|
||||||
|
expected := `<input type="hidden" name="_csrf_token" value="test-token">`
|
||||||
|
|
||||||
|
if field != expected {
|
||||||
|
t.Errorf("Expected %s, got %s", expected, field)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestHiddenFieldNoSession(t *testing.T) {
|
||||||
|
ctx := &fasthttp.RequestCtx{}
|
||||||
|
|
||||||
|
field := HiddenField(ctx, nil)
|
||||||
|
if field != "" {
|
||||||
|
t.Errorf("Expected empty field with no session, got %s", field)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestTokenMeta(t *testing.T) {
|
||||||
|
session := &auth.Session{
|
||||||
|
ID: "test-session",
|
||||||
|
UserID: 1,
|
||||||
|
Username: "testuser",
|
||||||
|
Email: "test@example.com",
|
||||||
|
Data: map[string]any{SessionKey: "test-token"},
|
||||||
|
}
|
||||||
|
|
||||||
|
ctx := &fasthttp.RequestCtx{}
|
||||||
|
ctx.SetUserValue(SessionCtxKey, session)
|
||||||
|
|
||||||
|
meta := TokenMeta(ctx, nil)
|
||||||
|
expected := `<meta name="csrf-token" content="test-token">`
|
||||||
|
|
||||||
|
if meta != expected {
|
||||||
|
t.Errorf("Expected %s, got %s", expected, meta)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestStoreAndGetToken(t *testing.T) {
|
||||||
|
session := &auth.Session{
|
||||||
|
Data: make(map[string]any),
|
||||||
|
}
|
||||||
|
|
||||||
|
token := "test-token"
|
||||||
|
StoreToken(session, token)
|
||||||
|
|
||||||
|
retrieved := GetStoredToken(session)
|
||||||
|
if retrieved != token {
|
||||||
|
t.Errorf("Expected %s, got %s", token, retrieved)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestGetStoredTokenNoData(t *testing.T) {
|
||||||
|
session := &auth.Session{}
|
||||||
|
|
||||||
|
token := GetStoredToken(session)
|
||||||
|
if token != "" {
|
||||||
|
t.Errorf("Expected empty token, got %s", token)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestValidateFormToken(t *testing.T) {
|
||||||
|
session := &auth.Session{
|
||||||
|
ID: "test-session",
|
||||||
|
UserID: 1,
|
||||||
|
Username: "testuser",
|
||||||
|
Email: "test@example.com",
|
||||||
|
Data: map[string]any{SessionKey: "test-token"},
|
||||||
|
}
|
||||||
|
|
||||||
|
ctx := &fasthttp.RequestCtx{}
|
||||||
|
ctx.SetUserValue(SessionCtxKey, session)
|
||||||
|
|
||||||
|
// Set form data
|
||||||
|
ctx.PostArgs().Set(TokenFieldName, "test-token")
|
||||||
|
|
||||||
|
if !ValidateFormToken(ctx, nil) {
|
||||||
|
t.Error("Expected form token validation to pass")
|
||||||
|
}
|
||||||
|
|
||||||
|
// Test with wrong token
|
||||||
|
ctx.PostArgs().Set(TokenFieldName, "wrong-token")
|
||||||
|
|
||||||
|
if ValidateFormToken(ctx, nil) {
|
||||||
|
t.Error("Expected form token validation to fail with wrong token")
|
||||||
|
}
|
||||||
|
}
|
29
internal/csrf/doc.go
Normal file
29
internal/csrf/doc.go
Normal file
@ -0,0 +1,29 @@
|
|||||||
|
// Package csrf provides Cross-Site Request Forgery (CSRF) protection
|
||||||
|
// with session-based token storage and form helpers.
|
||||||
|
//
|
||||||
|
// # Basic Usage
|
||||||
|
//
|
||||||
|
// // Generate token and store in session
|
||||||
|
// token := csrf.GenerateToken(ctx, authManager)
|
||||||
|
//
|
||||||
|
// // In templates - generate hidden input field
|
||||||
|
// hiddenField := csrf.HiddenField(ctx, authManager)
|
||||||
|
//
|
||||||
|
// // Verify form submission
|
||||||
|
// if !csrf.ValidateToken(ctx, authManager, formToken) {
|
||||||
|
// // Handle CSRF validation failure
|
||||||
|
// }
|
||||||
|
//
|
||||||
|
// # Middleware Integration
|
||||||
|
//
|
||||||
|
// // Add CSRF middleware to protected routes
|
||||||
|
// r.Use(middleware.CSRF(authManager))
|
||||||
|
//
|
||||||
|
// # Security Features
|
||||||
|
//
|
||||||
|
// - Cryptographically secure token generation
|
||||||
|
// - Session-based token storage and validation
|
||||||
|
// - Automatic token rotation on successful validation
|
||||||
|
// - Protection against timing attacks with constant-time comparison
|
||||||
|
// - Integration with existing authentication system
|
||||||
|
package csrf
|
126
internal/middleware/csrf.go
Normal file
126
internal/middleware/csrf.go
Normal file
@ -0,0 +1,126 @@
|
|||||||
|
package middleware
|
||||||
|
|
||||||
|
import (
|
||||||
|
"dk/internal/auth"
|
||||||
|
"dk/internal/csrf"
|
||||||
|
"dk/internal/router"
|
||||||
|
"slices"
|
||||||
|
|
||||||
|
"github.com/valyala/fasthttp"
|
||||||
|
)
|
||||||
|
|
||||||
|
// CSRFConfig holds configuration for CSRF middleware
|
||||||
|
type CSRFConfig struct {
|
||||||
|
// Skip CSRF validation for these methods (default: GET, HEAD, OPTIONS)
|
||||||
|
SkipMethods []string
|
||||||
|
// Custom failure handler (default: returns 403)
|
||||||
|
FailureHandler func(ctx router.Ctx)
|
||||||
|
// Skip CSRF for certain paths
|
||||||
|
SkipPaths []string
|
||||||
|
}
|
||||||
|
|
||||||
|
// CSRF creates a CSRF protection middleware
|
||||||
|
func CSRF(authManager *auth.AuthManager, config ...CSRFConfig) router.Middleware {
|
||||||
|
cfg := CSRFConfig{
|
||||||
|
SkipMethods: []string{"GET", "HEAD", "OPTIONS"},
|
||||||
|
FailureHandler: func(ctx router.Ctx) {
|
||||||
|
ctx.SetStatusCode(fasthttp.StatusForbidden)
|
||||||
|
ctx.SetContentType("text/plain")
|
||||||
|
ctx.WriteString("CSRF token validation failed")
|
||||||
|
},
|
||||||
|
SkipPaths: []string{},
|
||||||
|
}
|
||||||
|
|
||||||
|
// Apply custom config if provided
|
||||||
|
if len(config) > 0 {
|
||||||
|
if len(config[0].SkipMethods) > 0 {
|
||||||
|
cfg.SkipMethods = config[0].SkipMethods
|
||||||
|
}
|
||||||
|
if config[0].FailureHandler != nil {
|
||||||
|
cfg.FailureHandler = config[0].FailureHandler
|
||||||
|
}
|
||||||
|
if len(config[0].SkipPaths) > 0 {
|
||||||
|
cfg.SkipPaths = config[0].SkipPaths
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return func(next router.Handler) router.Handler {
|
||||||
|
return func(ctx router.Ctx, params []string) {
|
||||||
|
method := string(ctx.Method())
|
||||||
|
path := string(ctx.Path())
|
||||||
|
|
||||||
|
// Skip CSRF validation for certain methods
|
||||||
|
shouldSkip := slices.Contains(cfg.SkipMethods, method)
|
||||||
|
|
||||||
|
// Skip CSRF validation for certain paths
|
||||||
|
if !shouldSkip {
|
||||||
|
if slices.Contains(cfg.SkipPaths, path) {
|
||||||
|
shouldSkip = true
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Skip CSRF for non-authenticated users (no session)
|
||||||
|
if !shouldSkip && !IsAuthenticated(ctx) {
|
||||||
|
shouldSkip = true
|
||||||
|
}
|
||||||
|
|
||||||
|
if shouldSkip {
|
||||||
|
next(ctx, params)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
// Validate CSRF token for protected methods
|
||||||
|
if !csrf.ValidateFormToken(ctx, authManager) {
|
||||||
|
cfg.FailureHandler(ctx)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
// CSRF validation passed, rotate token for security
|
||||||
|
csrf.RotateToken(ctx, authManager)
|
||||||
|
|
||||||
|
next(ctx, params)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// RequireCSRF is a stricter CSRF middleware that always validates tokens
|
||||||
|
func RequireCSRF(authManager *auth.AuthManager, failureHandler ...func(router.Ctx)) router.Middleware {
|
||||||
|
handler := func(ctx router.Ctx) {
|
||||||
|
ctx.SetStatusCode(fasthttp.StatusForbidden)
|
||||||
|
ctx.SetContentType("text/plain")
|
||||||
|
ctx.WriteString("CSRF token required")
|
||||||
|
}
|
||||||
|
|
||||||
|
if len(failureHandler) > 0 {
|
||||||
|
handler = failureHandler[0]
|
||||||
|
}
|
||||||
|
|
||||||
|
return func(next router.Handler) router.Handler {
|
||||||
|
return func(ctx router.Ctx, params []string) {
|
||||||
|
if !csrf.ValidateFormToken(ctx, authManager) {
|
||||||
|
handler(ctx)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
// Rotate token after successful validation
|
||||||
|
csrf.RotateToken(ctx, authManager)
|
||||||
|
|
||||||
|
next(ctx, params)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// CSRFToken returns the current CSRF token for the request
|
||||||
|
func CSRFToken(ctx router.Ctx, authManager *auth.AuthManager) string {
|
||||||
|
return csrf.GetToken(ctx, authManager)
|
||||||
|
}
|
||||||
|
|
||||||
|
// CSRFHiddenField generates a hidden input field for forms
|
||||||
|
func CSRFHiddenField(ctx router.Ctx, authManager *auth.AuthManager) string {
|
||||||
|
return csrf.HiddenField(ctx, authManager)
|
||||||
|
}
|
||||||
|
|
||||||
|
// CSRFMeta generates a meta tag for JavaScript access
|
||||||
|
func CSRFMeta(ctx router.Ctx, authManager *auth.AuthManager) string {
|
||||||
|
return csrf.TokenMeta(ctx, authManager)
|
||||||
|
}
|
Loading…
x
Reference in New Issue
Block a user