// Package csrf provides session-backed CSRF tokens: generation and storage in
// the request session, HTML helpers that embed the token in pages, and a
// middleware that rejects POST/PUT/PATCH/DELETE requests without a valid token.
package csrf

import (
	"crypto/rand"
	"crypto/subtle"
	"encoding/base64"
	"fmt"
	"html"

	sushi "git.sharkk.net/Sharkk/Sushi"
	"git.sharkk.net/Sharkk/Sushi/session"
)

const (
	// CSRFTokenLength is the number of random bytes in a token, before
	// base64 (URL alphabet, padded) encoding.
	CSRFTokenLength = 32
	// CSRFTokenFieldName is the form / query parameter that carries the token.
	CSRFTokenFieldName = "_csrf_token"
	// CSRFSessionKey is the session key under which the token is stored.
	CSRFSessionKey = "csrf_token"
	// SessionCtxKey is the ctx user-value key holding the *session.Session.
	SessionCtxKey = "session"
)

// GetCurrentSession returns the *session.Session stored in ctx under
// SessionCtxKey, or nil when the value is absent or of another type.
func GetCurrentSession(ctx sushi.Ctx) *session.Session {
	if sess, ok := ctx.UserValue(SessionCtxKey).(*session.Session); ok {
		return sess
	}
	return nil
}

// GenerateCSRFToken creates a fresh random token and, when ctx carries a
// session, stores it under CSRFSessionKey (replacing any previous token) and
// persists the session via session.StoreSession.
//
// It returns "" if the system random source fails. When ctx has no session
// the token is still returned but is not stored anywhere, so it can never
// pass ValidateCSRFToken.
func GenerateCSRFToken(ctx sushi.Ctx) string {
	tokenBytes := make([]byte, CSRFTokenLength)
	if _, err := rand.Read(tokenBytes); err != nil {
		return ""
	}
	token := base64.URLEncoding.EncodeToString(tokenBytes)

	if sess := GetCurrentSession(ctx); sess != nil {
		sess.Set(CSRFSessionKey, token)
		// NOTE(review): any result of StoreSession is ignored here; confirm
		// whether persisting the session can fail and should be surfaced.
		session.StoreSession(sess)
	}
	return token
}

// GetCSRFToken returns the token stored in the current session, generating
// and storing a new one when none (or an empty/non-string one) is present.
// It returns "" when ctx has no session.
func GetCSRFToken(ctx sushi.Ctx) string {
	sess := GetCurrentSession(ctx)
	if sess == nil {
		return ""
	}
	if existing, ok := sess.Get(CSRFSessionKey); ok {
		// An empty stored token could never be submitted successfully
		// (ValidateCSRFToken rejects ""), so treat it like a missing one.
		if tokenStr, ok := existing.(string); ok && tokenStr != "" {
			return tokenStr
		}
	}
	return GenerateCSRFToken(ctx)
}

// ValidateCSRFToken reports whether submittedToken equals the token stored in
// the current session. It returns false for an empty submission, a missing
// session, or a missing/non-string stored token. The comparison is done with
// subtle.ConstantTimeCompare to avoid leaking matching prefixes via timing.
func ValidateCSRFToken(ctx sushi.Ctx, submittedToken string) bool {
	if submittedToken == "" {
		return false
	}
	sess := GetCurrentSession(ctx)
	if sess == nil {
		return false
	}
	storedToken, ok := sess.Get(CSRFSessionKey)
	if !ok {
		return false
	}
	storedTokenStr, ok := storedToken.(string)
	if !ok {
		return false
	}
	return subtle.ConstantTimeCompare([]byte(submittedToken), []byte(storedTokenStr)) == 1
}

// CSRFHiddenField returns an HTML hidden <input> named CSRFTokenFieldName
// whose value is the session's CSRF token, for embedding inside a <form>.
// It returns "" when no token is available (e.g. no session).
func CSRFHiddenField(ctx sushi.Ctx) string {
	token := GetCSRFToken(ctx)
	if token == "" {
		return ""
	}
	// Values are escaped defensively so the attribute can never be broken
	// out of, even if a stored token were ever not plain base64.
	return fmt.Sprintf(`<input type="hidden" name="%s" value="%s">`,
		html.EscapeString(CSRFTokenFieldName), html.EscapeString(token))
}

// CSRFTokenMeta returns an HTML <meta> tag carrying the session's CSRF token
// so client-side JavaScript can read it. It returns "" when no token is
// available (e.g. no session).
func CSRFTokenMeta(ctx sushi.Ctx) string {
	token := GetCSRFToken(ctx)
	if token == "" {
		return ""
	}
	// NOTE(review): "csrf-token" is the conventional meta name; confirm it
	// matches what the front-end scripts look up.
	return fmt.Sprintf(`<meta name="csrf-token" content="%s">`, html.EscapeString(token))
}

// ValidateFormCSRFToken reads the token from the POST body argument
// CSRFTokenFieldName, falling back to the query string argument of the same
// name, and validates it with ValidateCSRFToken. It returns false when
// neither source supplies a non-empty token.
//
// Note that tokens passed in the query string can end up in access logs and
// Referer headers; prefer submitting them in the request body.
func ValidateFormCSRFToken(ctx sushi.Ctx) bool {
	tokenBytes := ctx.PostArgs().Peek(CSRFTokenFieldName)
	if len(tokenBytes) == 0 {
		tokenBytes = ctx.QueryArgs().Peek(CSRFTokenFieldName)
	}
	if len(tokenBytes) == 0 {
		return false
	}
	return ValidateCSRFToken(ctx, string(tokenBytes))
}

// Middleware returns middleware that validates the CSRF token on POST, PUT,
// PATCH and DELETE requests via ValidateFormCSRFToken. On failure it replaces
// the session token with a new one and responds with a 302 redirect to the
// request path (query string not included) instead of calling next. All other
// methods pass straight through.
func Middleware() sushi.Middleware {
	return func(next sushi.Handler) sushi.Handler {
		return func(ctx sushi.Ctx, params []string) {
			switch string(ctx.Method()) {
			case "POST", "PUT", "PATCH", "DELETE":
				if !ValidateFormCSRFToken(ctx) {
					GenerateCSRFToken(ctx)
					ctx.Redirect(string(ctx.Path()), 302)
					return
				}
			}
			next(ctx, params)
		}
	}
}