Sushi/csrf/csrf.go

128 lines
2.8 KiB
Go

package csrf
import (
	"crypto/rand"
	"crypto/subtle"
	"encoding/base64"
	"fmt"
	"html"

	sushi "git.sharkk.net/Sharkk/Sushi"
)
// Names and sizes shared by the CSRF helpers in this package.
const (
	// CSRFTokenLength is the number of random bytes drawn for each new
	// token, before base64 encoding.
	CSRFTokenLength = 32

	// CSRFTokenFieldName is the form/query argument that carries the
	// submitted token, and the name of the generated hidden input.
	CSRFTokenFieldName = "_csrf_token"

	// CSRFSessionKey is the session key under which the expected token
	// is stored.
	CSRFSessionKey = "csrf_token"

	// SessionCtxKey is not referenced by the code in this file.
	// NOTE(review): presumably the context key under which the session is
	// stored — confirm against callers.
	SessionCtxKey = "session"
)
// GenerateToken creates a new random CSRF token and returns it. The token is
// CSRFTokenLength bytes from crypto/rand, encoded with base64.URLEncoding. It
// is stored in the current session under CSRFSessionKey, and the session is
// then persisted with sushi.StoreSession.
//
// It returns "" when there is no current session, because a token that cannot
// be stored would always be rejected by ValidateToken. This matches GetToken.
// It also returns "" when the system random source fails.
func GenerateToken(ctx sushi.Ctx) string {
	sess := ctx.GetCurrentSession()
	if sess == nil {
		// Without a session there is nowhere to keep the expected value, so
		// handing out a token here would only produce forms that fail
		// validation.
		return ""
	}

	tokenBytes := make([]byte, CSRFTokenLength)
	if _, err := rand.Read(tokenBytes); err != nil {
		// Never fall back to a weaker or predictable token.
		return ""
	}
	token := base64.URLEncoding.EncodeToString(tokenBytes)

	sess.Set(CSRFSessionKey, token)
	sushi.StoreSession(sess)
	return token
}
// GetToken returns the CSRF token held in the current session. If the session
// has no value under CSRFSessionKey, or the value is not a string, it creates
// and stores a new token via GenerateToken. It returns "" when there is no
// current session.
func GetToken(ctx sushi.Ctx) string {
	session := ctx.GetCurrentSession()
	if session == nil {
		return ""
	}

	raw, found := session.Get(CSRFSessionKey)
	if !found {
		return GenerateToken(ctx)
	}

	stored, isString := raw.(string)
	if !isString {
		// Unexpected type in the session slot: replace it with a fresh token.
		return GenerateToken(ctx)
	}
	return stored
}
// ValidateToken reports whether submittedToken equals the CSRF token stored in
// the current session under CSRFSessionKey. It returns false in any of these
// cases:
//   - the submission is empty;
//   - there is no session;
//   - the stored value is missing or is not a string.
//
// The comparison uses subtle.ConstantTimeCompare, so its running time does not
// depend on how many leading bytes match.
func ValidateToken(ctx sushi.Ctx, submittedToken string) bool {
	if len(submittedToken) == 0 {
		return false
	}

	session := ctx.GetCurrentSession()
	if session == nil {
		return false
	}

	// A two-value type assertion never panics, so it is safe even when the
	// lookup missed; both failure modes are rejected together below.
	raw, found := session.Get(CSRFSessionKey)
	expected, isString := raw.(string)
	if !found || !isString {
		return false
	}

	submitted := []byte(submittedToken)
	return subtle.ConstantTimeCompare(submitted, []byte(expected)) == 1
}
// HiddenField returns an HTML hidden <input> named CSRFTokenFieldName whose
// value is the current CSRF token (obtained via GetToken). It returns "" when
// no token is available.
//
// The token is HTML-escaped before it goes into the value attribute. Tokens
// produced by GenerateToken are base64url and pass through unchanged. However,
// the session slot can hold any string, and it must not be able to break out
// of the attribute.
func HiddenField(ctx sushi.Ctx) string {
	token := GetToken(ctx)
	if token == "" {
		return ""
	}
	return fmt.Sprintf(`<input type="hidden" name="%s" value="%s">`,
		CSRFTokenFieldName, html.EscapeString(token))
}
// CSRFTokenMeta returns an HTML <meta name="csrf-token"> tag whose content is
// the current CSRF token (obtained via GetToken), so page scripts can read it.
// It returns "" when no token is available.
//
// The token is HTML-escaped before it goes into the content attribute. Tokens
// produced by GenerateToken are base64url and pass through unchanged. However,
// the session slot can hold any string, and it must not be able to break out
// of the attribute.
//
// NOTE(review): ValidateFormCSRFToken in this file reads the token only from
// form or query arguments named CSRFTokenFieldName. Scripts therefore have to
// submit it that way, not in a request header.
func CSRFTokenMeta(ctx sushi.Ctx) string {
	token := GetToken(ctx)
	if token == "" {
		return ""
	}
	return fmt.Sprintf(`<meta name="csrf-token" content="%s">`, html.EscapeString(token))
}
// ValidateFormCSRFToken reads the submitted CSRF token from the request's form
// arguments (PostArgs). If the body carries no non-empty value, it falls back
// to the query string (QueryArgs). The token is then checked with
// ValidateToken. It returns false when neither source supplies a token.
//
// NOTE(review): accepting the token from the query string puts it in URLs,
// where it can leak via server logs or Referer headers. Confirm this fallback
// is actually needed.
func ValidateFormCSRFToken(ctx sushi.Ctx) bool {
	token := ctx.PostArgs().Peek(CSRFTokenFieldName)
	if len(token) == 0 {
		// Only consult the query string when the body had nothing.
		token = ctx.QueryArgs().Peek(CSRFTokenFieldName)
	}
	return len(token) != 0 && ValidateToken(ctx, string(token))
}
// Middleware returns a sushi.Middleware that enforces CSRF validation on POST,
// PUT, PATCH and DELETE requests.
//
// For those methods it calls ValidateFormCSRFToken. On failure it does three
// things:
//   - replaces the session token via GenerateToken;
//   - redirects to ctx.Path() with status 302;
//   - skips the rest of the chain.
//
// Requests with any other method, and requests that pass validation, continue
// to next.
//
// NOTE(review): ctx.Path() presumably excludes the query string, so the
// redirect target may drop it — confirm against sushi.
func Middleware() sushi.Middleware {
	return func(ctx sushi.Ctx, next func()) {
		stateChanging := false
		switch string(ctx.Method()) {
		case "POST", "PUT", "PATCH", "DELETE":
			stateChanging = true
		}

		if stateChanging && !ValidateFormCSRFToken(ctx) {
			GenerateToken(ctx)
			target := string(ctx.Path())
			ctx.Redirect(target, 302)
			return
		}
		next()
	}
}