Sushi/csrf/csrf.go
2025-08-15 14:23:09 -05:00

139 lines
3.1 KiB
Go

package csrf
import (
"crypto/rand"
"crypto/subtle"
"encoding/base64"
"fmt"
sushi "git.sharkk.net/Sharkk/Sushi"
"git.sharkk.net/Sharkk/Sushi/session"
)
// CSRFTokenLength is the number of random bytes read for each new CSRF token
// before it is base64-encoded.
const CSRFTokenLength = 32

// Names used to locate CSRF data in requests, sessions and the request context.
const (
	// CSRFTokenFieldName is the form/query parameter that carries a submitted
	// token, and the name given to the generated hidden input.
	CSRFTokenFieldName = "_csrf_token"
	// CSRFSessionKey is the session key under which the expected token is stored.
	CSRFSessionKey = "csrf_token"
	// SessionCtxKey is the context user-value key expected to hold the current
	// *session.Session.
	SessionCtxKey = "session"
)
// GetCurrentSession returns the *session.Session stored on ctx under
// SessionCtxKey. It returns nil when nothing is stored there or when the stored
// value is not a *session.Session.
func GetCurrentSession(ctx sushi.Ctx) *session.Session {
	// A failed comma-ok assertion yields the zero value (nil), which is exactly
	// the "no session" result we want.
	current, _ := ctx.UserValue(SessionCtxKey).(*session.Session)
	return current
}
// GenerateCSRFToken creates a fresh random CSRF token and stores it in the
// current session under CSRFSessionKey. It then passes the session to
// session.StoreSession and returns the token.
//
// It returns "" when ctx carries no session. A token that is not stored in a
// session can never pass ValidateCSRFToken, so handing one out would only
// produce forms that always fail validation. It also returns "" if reading from
// the random source fails.
func GenerateCSRFToken(ctx sushi.Ctx) string {
	sess := GetCurrentSession(ctx)
	if sess == nil {
		return ""
	}

	tokenBytes := make([]byte, CSRFTokenLength)
	if _, err := rand.Read(tokenBytes); err != nil {
		return ""
	}
	// URL-safe base64 alphabet ('-' and '_' instead of '+' and '/').
	token := base64.URLEncoding.EncodeToString(tokenBytes)

	sess.Set(CSRFSessionKey, token)
	session.StoreSession(sess)
	return token
}
// GetCSRFToken returns the CSRF token already stored in the current session.
// If the session has no string value under CSRFSessionKey, a new token is
// created via GenerateCSRFToken and returned. It returns "" when ctx carries
// no session.
func GetCSRFToken(ctx sushi.Ctx) string {
	sess := GetCurrentSession(ctx)
	if sess == nil {
		return ""
	}
	stored, found := sess.Get(CSRFSessionKey)
	if existing, isString := stored.(string); found && isString {
		return existing
	}
	return GenerateCSRFToken(ctx)
}
// ValidateCSRFToken reports whether submittedToken equals the token stored in
// the current session under CSRFSessionKey.
//
// It returns false when submittedToken is empty, when ctx carries no session,
// or when the session holds no string value under that key. The comparison
// uses subtle.ConstantTimeCompare, whose timing does not depend on the
// contents of the tokens.
func ValidateCSRFToken(ctx sushi.Ctx, submittedToken string) bool {
	if submittedToken == "" {
		return false
	}
	sess := GetCurrentSession(ctx)
	if sess == nil {
		return false
	}
	raw, present := sess.Get(CSRFSessionKey)
	expected, isString := raw.(string)
	if !present || !isString {
		return false
	}
	return subtle.ConstantTimeCompare([]byte(submittedToken), []byte(expected)) == 1
}
// CSRFHiddenField returns an HTML hidden <input> element named
// CSRFTokenFieldName whose value is the session's CSRF token (as returned by
// GetCSRFToken). It returns "" when no token is available.
func CSRFHiddenField(ctx sushi.Ctx) string {
	if token := GetCSRFToken(ctx); token != "" {
		return fmt.Sprintf(`<input type="hidden" name="%s" value="%s">`, CSRFTokenFieldName, token)
	}
	return ""
}
// CSRFTokenMeta returns an HTML <meta name="csrf-token"> element whose content
// is the session's CSRF token (as returned by GetCSRFToken), so client-side
// scripts can read it. It returns "" when no token is available.
func CSRFTokenMeta(ctx sushi.Ctx) string {
	if token := GetCSRFToken(ctx); token != "" {
		return fmt.Sprintf(`<meta name="csrf-token" content="%s">`, token)
	}
	return ""
}
// ValidateFormCSRFToken looks up a submitted token under CSRFTokenFieldName,
// first in the POST arguments and then, if that value is missing or empty, in
// the query-string arguments. It checks the value with ValidateCSRFToken and
// returns false when neither source supplies a non-empty value.
//
// NOTE(review): accepting the token from the query string lets it end up in
// access logs and Referer headers; consider restricting it to the request body
// (or a header) if no caller relies on the query fallback.
func ValidateFormCSRFToken(ctx sushi.Ctx) bool {
	submitted := ctx.PostArgs().Peek(CSRFTokenFieldName)
	if len(submitted) == 0 {
		submitted = ctx.QueryArgs().Peek(CSRFTokenFieldName)
	}
	return len(submitted) > 0 && ValidateCSRFToken(ctx, string(submitted))
}
// Middleware returns a sushi.Middleware that enforces CSRF validation on POST,
// PUT, PATCH and DELETE requests using ValidateFormCSRFToken.
//
// When validation fails, a new token is generated for the session via
// GenerateCSRFToken. The client is then redirected with status 302 to the
// request path (ctx.Path(), without the query string), and next is not called.
// Requests with any other method are passed straight to next.
func Middleware() sushi.Middleware {
	return func(next sushi.Handler) sushi.Handler {
		return func(ctx sushi.Ctx, params []string) {
			switch string(ctx.Method()) {
			case "POST", "PUT", "PATCH", "DELETE":
				if !ValidateFormCSRFToken(ctx) {
					// Presumably rotated so the page reached through the
					// redirect renders with a fresh token.
					GenerateCSRFToken(ctx)
					ctx.Redirect(string(ctx.Path()), 302)
					return
				}
			}
			next(ctx, params)
		}
	}
}