139 lines
3.1 KiB
Go
139 lines
3.1 KiB
Go
package csrf
|
|
|
|
import (
|
|
"crypto/rand"
|
|
"crypto/subtle"
|
|
"encoding/base64"
|
|
"fmt"
|
|
|
|
sushi "git.sharkk.net/Sharkk/Sushi"
|
|
"git.sharkk.net/Sharkk/Sushi/session"
|
|
)
|
|
|
|
// Parameters of the CSRF token itself: how much randomness it carries, the
// request parameter it is submitted under, and the session key it is kept in.
const (
	// CSRFTokenLength is the number of random bytes drawn for each token,
	// before base64 encoding.
	CSRFTokenLength = 32

	// CSRFTokenFieldName is the form/query parameter name used both when
	// rendering the hidden input and when reading a submitted token.
	CSRFTokenFieldName = "_csrf_token"

	// CSRFSessionKey is the session key under which the current token is stored.
	CSRFSessionKey = "csrf_token"
)

// SessionCtxKey is the request user-value key from which the current session
// is read. NOTE(review): presumably populated by session middleware elsewhere.
const SessionCtxKey = "session"
|
|
|
|
// GetCurrentSession retrieves the session from context
|
|
func GetCurrentSession(ctx sushi.Ctx) *session.Session {
|
|
if sess, ok := ctx.UserValue(SessionCtxKey).(*session.Session); ok {
|
|
return sess
|
|
}
|
|
return nil
|
|
}
|
|
|
|
// GenerateCSRFToken creates a new CSRF token and stores it in the session
|
|
func GenerateCSRFToken(ctx sushi.Ctx) string {
|
|
tokenBytes := make([]byte, CSRFTokenLength)
|
|
if _, err := rand.Read(tokenBytes); err != nil {
|
|
return ""
|
|
}
|
|
|
|
token := base64.URLEncoding.EncodeToString(tokenBytes)
|
|
|
|
if sess := GetCurrentSession(ctx); sess != nil {
|
|
sess.Set(CSRFSessionKey, token)
|
|
session.StoreSession(sess)
|
|
}
|
|
|
|
return token
|
|
}
|
|
|
|
// GetCSRFToken retrieves the current CSRF token from session, generating one if needed
|
|
func GetCSRFToken(ctx sushi.Ctx) string {
|
|
sess := GetCurrentSession(ctx)
|
|
if sess == nil {
|
|
return ""
|
|
}
|
|
|
|
if existingToken, ok := sess.Get(CSRFSessionKey); ok {
|
|
if tokenStr, ok := existingToken.(string); ok {
|
|
return tokenStr
|
|
}
|
|
}
|
|
|
|
return GenerateCSRFToken(ctx)
|
|
}
|
|
|
|
// ValidateCSRFToken verifies a CSRF token against the stored session token
|
|
func ValidateCSRFToken(ctx sushi.Ctx, submittedToken string) bool {
|
|
if submittedToken == "" {
|
|
return false
|
|
}
|
|
|
|
sess := GetCurrentSession(ctx)
|
|
if sess == nil {
|
|
return false
|
|
}
|
|
|
|
storedToken, ok := sess.Get(CSRFSessionKey)
|
|
if !ok {
|
|
return false
|
|
}
|
|
|
|
storedTokenStr, ok := storedToken.(string)
|
|
if !ok {
|
|
return false
|
|
}
|
|
|
|
return subtle.ConstantTimeCompare([]byte(submittedToken), []byte(storedTokenStr)) == 1
|
|
}
|
|
|
|
// CSRFHiddenField generates an HTML hidden input field with the CSRF token
|
|
func CSRFHiddenField(ctx sushi.Ctx) string {
|
|
token := GetCSRFToken(ctx)
|
|
if token == "" {
|
|
return ""
|
|
}
|
|
|
|
return fmt.Sprintf(`<input type="hidden" name="%s" value="%s">`,
|
|
CSRFTokenFieldName, token)
|
|
}
|
|
|
|
// CSRFTokenMeta generates HTML meta tag for JavaScript access to CSRF token
|
|
func CSRFTokenMeta(ctx sushi.Ctx) string {
|
|
token := GetCSRFToken(ctx)
|
|
if token == "" {
|
|
return ""
|
|
}
|
|
|
|
return fmt.Sprintf(`<meta name="csrf-token" content="%s">`, token)
|
|
}
|
|
|
|
// ValidateFormCSRFToken validates CSRF token from form data
|
|
func ValidateFormCSRFToken(ctx sushi.Ctx) bool {
|
|
tokenBytes := ctx.PostArgs().Peek(CSRFTokenFieldName)
|
|
if len(tokenBytes) == 0 {
|
|
tokenBytes = ctx.QueryArgs().Peek(CSRFTokenFieldName)
|
|
}
|
|
|
|
if len(tokenBytes) == 0 {
|
|
return false
|
|
}
|
|
|
|
return ValidateCSRFToken(ctx, string(tokenBytes))
|
|
}
|
|
|
|
// Middleware returns middleware that automatically validates CSRF tokens
|
|
func Middleware() sushi.Middleware {
|
|
return func(next sushi.Handler) sushi.Handler {
|
|
return func(ctx sushi.Ctx, params []string) {
|
|
method := string(ctx.Method())
|
|
|
|
if method == "POST" || method == "PUT" || method == "PATCH" || method == "DELETE" {
|
|
if !ValidateFormCSRFToken(ctx) {
|
|
GenerateCSRFToken(ctx)
|
|
currentPath := string(ctx.Path())
|
|
ctx.Redirect(currentPath, 302)
|
|
return
|
|
}
|
|
}
|
|
|
|
next(ctx, params)
|
|
}
|
|
}
|
|
}
|