// Package csrf provides per-session CSRF token generation, HTML helpers for
// embedding the token in rendered pages, and middleware that validates the
// token on state-changing requests.
package csrf

import (
	"crypto/rand"
	"crypto/subtle"
	"encoding/base64"
	"fmt"
	"html"

	sushi "git.sharkk.net/Sharkk/Sushi"
)

const (
	// CSRFTokenLength is the number of random bytes in a token, before base64 encoding.
	CSRFTokenLength = 32
	// CSRFTokenFieldName is the form/query parameter name the token is read from.
	CSRFTokenFieldName = "_csrf_token"
	// CSRFSessionKey is the session key under which the current token is stored.
	CSRFSessionKey = "csrf_token"
	// SessionCtxKey names the session context key. It is not referenced by the
	// functions in this file.
	SessionCtxKey = "session"
)

// GenerateToken creates a new random CSRF token. When the request has a
// current session, it stores the token under CSRFSessionKey and persists the
// session via sushi.StoreSession.
//
// It returns "" if the system random source fails. If there is no current
// session, the token is still returned but not stored. ValidateToken will
// therefore never accept it.
func GenerateToken(ctx sushi.Ctx) string {
	tokenBytes := make([]byte, CSRFTokenLength)
	if _, err := rand.Read(tokenBytes); err != nil {
		return ""
	}
	token := base64.URLEncoding.EncodeToString(tokenBytes)

	if sess := ctx.GetCurrentSession(); sess != nil {
		sess.Set(CSRFSessionKey, token)
		sushi.StoreSession(sess)
	}
	return token
}

// GetToken returns the CSRF token stored in the current session. If none is
// stored, it generates and stores a new one. It returns "" when there is no
// current session.
func GetToken(ctx sushi.Ctx) string {
	sess := ctx.GetCurrentSession()
	if sess == nil {
		return ""
	}
	if existing, ok := sess.Get(CSRFSessionKey); ok {
		// A non-string or empty stored value is treated as missing and replaced.
		if token, ok := existing.(string); ok && token != "" {
			return token
		}
	}
	return GenerateToken(ctx)
}

// ValidateToken reports whether submittedToken matches the token stored in
// the current session. It returns false for an empty submission, a missing
// session, or a missing or non-string stored token.
func ValidateToken(ctx sushi.Ctx, submittedToken string) bool {
	if submittedToken == "" {
		return false
	}
	sess := ctx.GetCurrentSession()
	if sess == nil {
		return false
	}
	storedToken, ok := sess.Get(CSRFSessionKey)
	if !ok {
		return false
	}
	storedTokenStr, ok := storedToken.(string)
	if !ok {
		return false
	}
	// Constant-time comparison avoids leaking how many leading bytes matched.
	return subtle.ConstantTimeCompare([]byte(submittedToken), []byte(storedTokenStr)) == 1
}

// HiddenField returns an HTML hidden <input> named CSRFTokenFieldName that
// carries the session's CSRF token. It returns "" if no token is available.
func HiddenField(ctx sushi.Ctx) string {
	token := GetToken(ctx)
	if token == "" {
		return ""
	}
	return fmt.Sprintf(`<input type="hidden" name="%s" value="%s">`,
		html.EscapeString(CSRFTokenFieldName), html.EscapeString(token))
}

// CSRFTokenMeta returns an HTML <meta name="csrf-token"> tag carrying the
// session's CSRF token, so JavaScript can read it. It returns "" if no token
// is available.
func CSRFTokenMeta(ctx sushi.Ctx) string {
	token := GetToken(ctx)
	if token == "" {
		return ""
	}
	return fmt.Sprintf(`<meta name="csrf-token" content="%s">`, html.EscapeString(token))
}

// ValidateFormCSRFToken reads the token from the POST body parameter
// CSRFTokenFieldName. If that is absent, it falls back to the query-string
// parameter of the same name. It then checks the token with ValidateToken.
//
// NOTE(review): accepting the token from the query string can expose it via
// logs and Referer headers; consider dropping that fallback if no caller
// relies on it.
func ValidateFormCSRFToken(ctx sushi.Ctx) bool {
	tokenBytes := ctx.PostArgs().Peek(CSRFTokenFieldName)
	if len(tokenBytes) == 0 {
		tokenBytes = ctx.QueryArgs().Peek(CSRFTokenFieldName)
	}
	if len(tokenBytes) == 0 {
		return false
	}
	return ValidateToken(ctx, string(tokenBytes))
}

// Middleware returns middleware that requires a valid CSRF token on POST,
// PUT, PATCH and DELETE requests. On failure it stores a fresh token in the
// session and responds with a 302 redirect to the request path instead of
// calling next. All other methods pass straight through.
func Middleware() sushi.Middleware {
	return func(ctx sushi.Ctx, next func()) {
		switch string(ctx.Method()) {
		case "POST", "PUT", "PATCH", "DELETE":
			if !ValidateFormCSRFToken(ctx) {
				// Replace the stored token so the rejected value cannot be reused.
				GenerateToken(ctx)
				// NOTE(review): only the path is used, so any query string is
				// presumably dropped from the redirect target — confirm that is intended.
				ctx.Redirect(string(ctx.Path()), 302)
				return
			}
		}
		next()
	}
}