77 lines
1.8 KiB
Go
77 lines
1.8 KiB
Go
package auth
|
|
|
|
import (
	"reflect"

	sushi "git.sharkk.net/Sharkk/Sushi"
	"github.com/valyala/fasthttp"
)
|
|
|
|
// UserCtxKey is the request user-value key under which the auth middleware
// stores the user returned by the lookup function.
const (
	UserCtxKey = "user"
)
|
|
|
|
// Auth resolves the user behind the current session and provides the
// middleware that attaches that user to each request.
type Auth struct {
	userLookup func(id int) any // maps a session user ID to a user value; nil means no such user
}
|
|
|
|
// New creates a new Auth instance
|
|
func New(userLookup func(int) any) *Auth {
|
|
return &Auth{userLookup: userLookup}
|
|
}
|
|
|
|
// Middleware returns the authentication middleware function
|
|
func (a *Auth) Middleware() sushi.Middleware {
|
|
return func(ctx sushi.Ctx, next func()) {
|
|
sess := sushi.GetCurrentSession(ctx)
|
|
if sess != nil && sess.UserID > 0 && a.userLookup != nil {
|
|
user := a.userLookup(sess.UserID)
|
|
if user != nil {
|
|
ctx.SetUserValue(UserCtxKey, user)
|
|
} else {
|
|
sess.SetUserID(0)
|
|
sushi.StoreSession(sess)
|
|
}
|
|
}
|
|
next()
|
|
}
|
|
}
|
|
|
|
// Update refreshes the current user data in the context
|
|
func (a *Auth) Update(ctx sushi.Ctx) {
|
|
sess := sushi.GetCurrentSession(ctx)
|
|
if sess != nil && sess.UserID > 0 && a.userLookup != nil {
|
|
user := a.userLookup(sess.UserID)
|
|
ctx.SetUserValue(UserCtxKey, user)
|
|
}
|
|
}
|
|
|
|
// RequireAuth middleware that redirects unauthenticated users
|
|
func RequireAuth(redirectPath ...string) sushi.Middleware {
|
|
redirect := "/login"
|
|
if len(redirectPath) > 0 && redirectPath[0] != "" {
|
|
redirect = redirectPath[0]
|
|
}
|
|
|
|
return func(ctx sushi.Ctx, next func()) {
|
|
if !ctx.IsAuthenticated() {
|
|
ctx.Redirect(redirect, fasthttp.StatusFound)
|
|
return
|
|
}
|
|
next()
|
|
}
|
|
}
|
|
|
|
// RequireGuest middleware that redirects authenticated users
|
|
func RequireGuest(redirectPath ...string) sushi.Middleware {
|
|
redirect := "/"
|
|
if len(redirectPath) > 0 && redirectPath[0] != "" {
|
|
redirect = redirectPath[0]
|
|
}
|
|
|
|
return func(ctx sushi.Ctx, next func()) {
|
|
if ctx.IsAuthenticated() {
|
|
ctx.Redirect(redirect, fasthttp.StatusFound)
|
|
return
|
|
}
|
|
next()
|
|
}
|
|
}
|