simplified middleware interface

This commit is contained in:
Sky Johnson 2025-08-15 14:34:28 -05:00
parent 5bcaa4c89f
commit b7822c1b50
6 changed files with 72 additions and 67 deletions

View File

@ -10,20 +10,18 @@ const UserCtxKey = "user"
// Middleware adds authentication handling // Middleware adds authentication handling
func Middleware(userLookup func(int) any) sushi.Middleware { func Middleware(userLookup func(int) any) sushi.Middleware {
return func(next sushi.Handler) sushi.Handler { return func(ctx sushi.Ctx, params []string, next func()) {
return func(ctx sushi.Ctx, params []string) { sess := session.GetCurrentSession(ctx)
sess := session.GetCurrentSession(ctx) if sess != nil && sess.UserID > 0 && userLookup != nil {
if sess != nil && sess.UserID > 0 && userLookup != nil { user := userLookup(sess.UserID)
user := userLookup(sess.UserID) if user != nil {
if user != nil { ctx.SetUserValue(UserCtxKey, user)
ctx.SetUserValue(UserCtxKey, user) } else {
} else { sess.SetUserID(0)
sess.SetUserID(0) session.StoreSession(sess)
session.StoreSession(sess)
}
} }
next(ctx, params)
} }
next()
} }
} }
@ -34,14 +32,12 @@ func RequireAuth(redirectPath ...string) sushi.Middleware {
redirect = redirectPath[0] redirect = redirectPath[0]
} }
return func(next sushi.Handler) sushi.Handler { return func(ctx sushi.Ctx, params []string, next func()) {
return func(ctx sushi.Ctx, params []string) { if !IsAuthenticated(ctx) {
if !IsAuthenticated(ctx) { ctx.Redirect(redirect, fasthttp.StatusFound)
ctx.Redirect(redirect, fasthttp.StatusFound) return
return
}
next(ctx, params)
} }
next()
} }
} }
@ -52,14 +48,12 @@ func RequireGuest(redirectPath ...string) sushi.Middleware {
redirect = redirectPath[0] redirect = redirectPath[0]
} }
return func(next sushi.Handler) sushi.Handler { return func(ctx sushi.Ctx, params []string, next func()) {
return func(ctx sushi.Ctx, params []string) { if IsAuthenticated(ctx) {
if IsAuthenticated(ctx) { ctx.Redirect(redirect, fasthttp.StatusFound)
ctx.Redirect(redirect, fasthttp.StatusFound) return
return
}
next(ctx, params)
} }
next()
} }
} }

View File

@ -119,20 +119,18 @@ func ValidateFormCSRFToken(ctx sushi.Ctx) bool {
// Middleware returns middleware that automatically validates CSRF tokens // Middleware returns middleware that automatically validates CSRF tokens
func Middleware() sushi.Middleware { func Middleware() sushi.Middleware {
return func(next sushi.Handler) sushi.Handler { return func(ctx sushi.Ctx, params []string, next func()) {
return func(ctx sushi.Ctx, params []string) { method := string(ctx.Method())
method := string(ctx.Method())
if method == "POST" || method == "PUT" || method == "PATCH" || method == "DELETE" { if method == "POST" || method == "PUT" || method == "PATCH" || method == "DELETE" {
if !ValidateFormCSRFToken(ctx) { if !ValidateFormCSRFToken(ctx) {
GenerateCSRFToken(ctx) GenerateCSRFToken(ctx)
currentPath := string(ctx.Path()) currentPath := string(ctx.Path())
ctx.Redirect(currentPath, 302) ctx.Redirect(currentPath, 302)
return return
}
} }
next(ctx, params)
} }
next()
} }
} }

View File

@ -141,10 +141,27 @@ func (r *Router) methodNode(method string) *node {
} }
func applyMiddleware(h Handler, mw []Middleware) Handler { func applyMiddleware(h Handler, mw []Middleware) Handler {
for i := len(mw) - 1; i >= 0; i-- { if len(mw) == 0 {
h = mw[i](h) return h
}
return func(ctx Ctx, params []string) {
var index int
var next func()
next = func() {
if index >= len(mw) {
h(ctx, params)
return
}
currentMW := mw[index]
index++
currentMW(ctx, params, next)
}
next()
} }
return h
} }
func readSegment(path string, start int) (segment string, end int, hasMore bool) { func readSegment(path string, start int) (segment string, end int, hasMore bool) {

View File

@ -4,27 +4,25 @@ import sushi "git.sharkk.net/Sharkk/Sushi"
// Middleware provides session handling // Middleware provides session handling
func Middleware() sushi.Middleware { func Middleware() sushi.Middleware {
return func(next sushi.Handler) sushi.Handler { return func(ctx sushi.Ctx, params []string, next func()) {
return func(ctx sushi.Ctx, params []string) { sessionID := sushi.GetCookie(ctx, SessionCookieName)
sessionID := sushi.GetCookie(ctx, SessionCookieName) var sess *Session
var sess *Session
if sessionID != "" { if sessionID != "" {
if existingSess, exists := GetSession(sessionID); exists { if existingSess, exists := GetSession(sessionID); exists {
sess = existingSess sess = existingSess
sess.Touch() sess.Touch()
StoreSession(sess) StoreSession(sess)
SetSessionCookie(ctx, sessionID) SetSessionCookie(ctx, sessionID)
}
} }
if sess == nil {
sess = CreateSession(0) // Guest session
SetSessionCookie(ctx, sess.ID)
}
ctx.SetUserValue(SessionCtxKey, sess)
next(ctx, params)
} }
if sess == nil {
sess = CreateSession(0) // Guest session
SetSessionCookie(ctx, sess.ID)
}
ctx.SetUserValue(SessionCtxKey, sess)
next()
} }
} }

View File

@ -11,12 +11,10 @@ const RequestTimerKey = "request_start_time"
// Middleware adds request timing functionality // Middleware adds request timing functionality
func Middleware() sushi.Middleware { func Middleware() sushi.Middleware {
return func(next sushi.Handler) sushi.Handler { return func(ctx sushi.Ctx, params []string, next func()) {
return func(ctx sushi.Ctx, params []string) { startTime := time.Now()
startTime := time.Now() ctx.SetUserValue(RequestTimerKey, startTime)
ctx.SetUserValue(RequestTimerKey, startTime) next()
next(ctx, params)
}
} }
} }

View File

@ -4,4 +4,4 @@ import "github.com/valyala/fasthttp"
type Ctx = *fasthttp.RequestCtx type Ctx = *fasthttp.RequestCtx
type Handler func(ctx Ctx, params []string) type Handler func(ctx Ctx, params []string)
type Middleware func(Handler) Handler type Middleware func(ctx Ctx, params []string, next func())