128 lines
2.8 KiB
Go
128 lines
2.8 KiB
Go
package csrf
|
|
|
|
import (
|
|
"crypto/rand"
|
|
"crypto/subtle"
|
|
"encoding/base64"
|
|
"fmt"
|
|
|
|
sushi "git.sharkk.net/Sharkk/Sushi"
|
|
)
|
|
|
|
// Sizes and key names shared by the CSRF helpers in this package.
const (
	// CSRFTokenLength is the number of random bytes read for each new token,
	// before base64 encoding.
	CSRFTokenLength = 32

	// CSRFTokenFieldName is the form/query parameter that carries the
	// submitted token and the name used for the rendered hidden input.
	CSRFTokenFieldName = "_csrf_token"

	// CSRFSessionKey is the session key under which the current token is stored.
	CSRFSessionKey = "csrf_token"

	// SessionCtxKey is the key string "session". It is not referenced by the
	// functions in this file; presumably it names where the session lives in
	// the request context — TODO confirm against callers.
	SessionCtxKey = "session"
)
|
|
|
|
// GenerateToken creates a new CSRF token and stores it in the session
|
|
func GenerateToken(ctx sushi.Ctx) string {
|
|
tokenBytes := make([]byte, CSRFTokenLength)
|
|
if _, err := rand.Read(tokenBytes); err != nil {
|
|
return ""
|
|
}
|
|
|
|
token := base64.URLEncoding.EncodeToString(tokenBytes)
|
|
|
|
if sess := ctx.GetCurrentSession(); sess != nil {
|
|
sess.Set(CSRFSessionKey, token)
|
|
sushi.StoreSession(sess)
|
|
}
|
|
|
|
return token
|
|
}
|
|
|
|
// GetToken retrieves the current CSRF token from session, generating one if needed
|
|
func GetToken(ctx sushi.Ctx) string {
|
|
sess := ctx.GetCurrentSession()
|
|
if sess == nil {
|
|
return ""
|
|
}
|
|
|
|
if existingToken, ok := sess.Get(CSRFSessionKey); ok {
|
|
if tokenStr, ok := existingToken.(string); ok {
|
|
return tokenStr
|
|
}
|
|
}
|
|
|
|
return GenerateToken(ctx)
|
|
}
|
|
|
|
// ValidateToken verifies a CSRF token against the stored session token
|
|
func ValidateToken(ctx sushi.Ctx, submittedToken string) bool {
|
|
if submittedToken == "" {
|
|
return false
|
|
}
|
|
|
|
sess := ctx.GetCurrentSession()
|
|
if sess == nil {
|
|
return false
|
|
}
|
|
|
|
storedToken, ok := sess.Get(CSRFSessionKey)
|
|
if !ok {
|
|
return false
|
|
}
|
|
|
|
storedTokenStr, ok := storedToken.(string)
|
|
if !ok {
|
|
return false
|
|
}
|
|
|
|
return subtle.ConstantTimeCompare([]byte(submittedToken), []byte(storedTokenStr)) == 1
|
|
}
|
|
|
|
// HiddenField generates an HTML hidden input field with the CSRF token
|
|
func HiddenField(ctx sushi.Ctx) string {
|
|
token := GetToken(ctx)
|
|
if token == "" {
|
|
return ""
|
|
}
|
|
|
|
return fmt.Sprintf(`<input type="hidden" name="%s" value="%s">`,
|
|
CSRFTokenFieldName, token)
|
|
}
|
|
|
|
// CSRFTokenMeta generates HTML meta tag for JavaScript access to CSRF token
|
|
func CSRFTokenMeta(ctx sushi.Ctx) string {
|
|
token := GetToken(ctx)
|
|
if token == "" {
|
|
return ""
|
|
}
|
|
|
|
return fmt.Sprintf(`<meta name="csrf-token" content="%s">`, token)
|
|
}
|
|
|
|
// ValidateFormCSRFToken validates CSRF token from form data
|
|
func ValidateFormCSRFToken(ctx sushi.Ctx) bool {
|
|
tokenBytes := ctx.PostArgs().Peek(CSRFTokenFieldName)
|
|
if len(tokenBytes) == 0 {
|
|
tokenBytes = ctx.QueryArgs().Peek(CSRFTokenFieldName)
|
|
}
|
|
|
|
if len(tokenBytes) == 0 {
|
|
return false
|
|
}
|
|
|
|
return ValidateToken(ctx, string(tokenBytes))
|
|
}
|
|
|
|
// Middleware returns middleware that automatically validates CSRF tokens
|
|
func Middleware() sushi.Middleware {
|
|
return func(ctx sushi.Ctx, next func()) {
|
|
method := string(ctx.Method())
|
|
|
|
if method == "POST" || method == "PUT" || method == "PATCH" || method == "DELETE" {
|
|
if !ValidateFormCSRFToken(ctx) {
|
|
GenerateToken(ctx)
|
|
currentPath := string(ctx.Path())
|
|
ctx.Redirect(currentPath, 302)
|
|
return
|
|
}
|
|
}
|
|
|
|
next()
|
|
}
|
|
}
|