diff --git a/auth/auth.go b/auth/auth.go index da5af57..5084724 100644 --- a/auth/auth.go +++ b/auth/auth.go @@ -10,20 +10,18 @@ const UserCtxKey = "user" // Middleware adds authentication handling func Middleware(userLookup func(int) any) sushi.Middleware { - return func(next sushi.Handler) sushi.Handler { - return func(ctx sushi.Ctx, params []string) { - sess := session.GetCurrentSession(ctx) - if sess != nil && sess.UserID > 0 && userLookup != nil { - user := userLookup(sess.UserID) - if user != nil { - ctx.SetUserValue(UserCtxKey, user) - } else { - sess.SetUserID(0) - session.StoreSession(sess) - } + return func(ctx sushi.Ctx, params []string, next func()) { + sess := session.GetCurrentSession(ctx) + if sess != nil && sess.UserID > 0 && userLookup != nil { + user := userLookup(sess.UserID) + if user != nil { + ctx.SetUserValue(UserCtxKey, user) + } else { + sess.SetUserID(0) + session.StoreSession(sess) } - next(ctx, params) } + next() } } @@ -34,14 +32,12 @@ func RequireAuth(redirectPath ...string) sushi.Middleware { redirect = redirectPath[0] } - return func(next sushi.Handler) sushi.Handler { - return func(ctx sushi.Ctx, params []string) { - if !IsAuthenticated(ctx) { - ctx.Redirect(redirect, fasthttp.StatusFound) - return - } - next(ctx, params) + return func(ctx sushi.Ctx, params []string, next func()) { + if !IsAuthenticated(ctx) { + ctx.Redirect(redirect, fasthttp.StatusFound) + return } + next() } } @@ -52,14 +48,12 @@ func RequireGuest(redirectPath ...string) sushi.Middleware { redirect = redirectPath[0] } - return func(next sushi.Handler) sushi.Handler { - return func(ctx sushi.Ctx, params []string) { - if IsAuthenticated(ctx) { - ctx.Redirect(redirect, fasthttp.StatusFound) - return - } - next(ctx, params) + return func(ctx sushi.Ctx, params []string, next func()) { + if IsAuthenticated(ctx) { + ctx.Redirect(redirect, fasthttp.StatusFound) + return } + next() } } diff --git a/csrf/csrf.go b/csrf/csrf.go index 6367969..d56c6e5 100644 --- a/csrf/csrf.go +++ b/csrf/csrf.go @@ -119,20 +119,18 @@ func ValidateFormCSRFToken(ctx sushi.Ctx) bool { // Middleware returns middleware that automatically validates CSRF tokens func Middleware() sushi.Middleware { - return func(next sushi.Handler) sushi.Handler { - return func(ctx sushi.Ctx, params []string) { - method := string(ctx.Method()) + return func(ctx sushi.Ctx, params []string, next func()) { + method := string(ctx.Method()) - if method == "POST" || method == "PUT" || method == "PATCH" || method == "DELETE" { - if !ValidateFormCSRFToken(ctx) { - GenerateCSRFToken(ctx) - currentPath := string(ctx.Path()) - ctx.Redirect(currentPath, 302) - return - } + if method == "POST" || method == "PUT" || method == "PATCH" || method == "DELETE" { + if !ValidateFormCSRFToken(ctx) { + GenerateCSRFToken(ctx) + currentPath := string(ctx.Path()) + ctx.Redirect(currentPath, 302) + return } - - next(ctx, params) } + + next() } } diff --git a/router.go b/router.go index 2e28862..d188abb 100644 --- a/router.go +++ b/router.go @@ -141,10 +141,27 @@ func (r *Router) methodNode(method string) *node { } func applyMiddleware(h Handler, mw []Middleware) Handler { - for i := len(mw) - 1; i >= 0; i-- { - h = mw[i](h) + if len(mw) == 0 { + return h + } + + return func(ctx Ctx, params []string) { + var index int + var next func() + + next = func() { + if index >= len(mw) { + h(ctx, params) + return + } + + currentMW := mw[index] + index++ + currentMW(ctx, params, next) + } + + next() } - return h } func readSegment(path string, start int) (segment string, end int, hasMore bool) { diff --git a/session/middleware.go b/session/middleware.go index a6013c3..52b9b37 100644 --- a/session/middleware.go +++ b/session/middleware.go @@ -4,27 +4,25 @@ import sushi "git.sharkk.net/Sharkk/Sushi" // Middleware provides session handling func Middleware() sushi.Middleware { - return func(next sushi.Handler) sushi.Handler { - return func(ctx sushi.Ctx, params []string) { - sessionID := sushi.GetCookie(ctx, SessionCookieName) - var sess *Session + return func(ctx sushi.Ctx, params []string, next func()) { + sessionID := sushi.GetCookie(ctx, SessionCookieName) + var sess *Session - if sessionID != "" { - if existingSess, exists := GetSession(sessionID); exists { - sess = existingSess - sess.Touch() - StoreSession(sess) - SetSessionCookie(ctx, sessionID) - } + if sessionID != "" { + if existingSess, exists := GetSession(sessionID); exists { + sess = existingSess + sess.Touch() + StoreSession(sess) + SetSessionCookie(ctx, sessionID) } - - if sess == nil { - sess = CreateSession(0) // Guest session - SetSessionCookie(ctx, sess.ID) - } - - ctx.SetUserValue(SessionCtxKey, sess) - next(ctx, params) } + + if sess == nil { + sess = CreateSession(0) // Guest session + SetSessionCookie(ctx, sess.ID) + } + + ctx.SetUserValue(SessionCtxKey, sess) + next() } } diff --git a/timing/timing.go b/timing/timing.go index eda682d..1c668d7 100644 --- a/timing/timing.go +++ b/timing/timing.go @@ -11,12 +11,10 @@ const RequestTimerKey = "request_start_time" // Middleware adds request timing functionality func Middleware() sushi.Middleware { - return func(next sushi.Handler) sushi.Handler { - return func(ctx sushi.Ctx, params []string) { - startTime := time.Now() - ctx.SetUserValue(RequestTimerKey, startTime) - next(ctx, params) - } + return func(ctx sushi.Ctx, params []string, next func()) { + startTime := time.Now() + ctx.SetUserValue(RequestTimerKey, startTime) + next() } } diff --git a/types.go b/types.go index b4b20fa..34eed04 100644 --- a/types.go +++ b/types.go @@ -4,4 +4,4 @@ import "github.com/valyala/fasthttp" type Ctx = *fasthttp.RequestCtx type Handler func(ctx Ctx, params []string) -type Middleware func(Handler) Handler +type Middleware func(ctx Ctx, params []string, next func())