move auth and session to core package, move helpers to ctx

This commit is contained in:
Sky Johnson 2025-08-16 11:37:20 -05:00
parent 5370d14152
commit 0f59ba225a
7 changed files with 99 additions and 102 deletions

View File

@ -74,22 +74,22 @@ app.Get("/users/:id/posts/:slug", func(ctx sushi.Ctx, params []any) {
```go ```go
func myHandler(ctx sushi.Ctx, params []any) { func myHandler(ctx sushi.Ctx, params []any) {
// JSON responses // JSON responses
sushi.SendJSON(ctx, map[string]string{"message": "success"}) ctx.SendJSON(map[string]string{"message": "success"})
// HTML responses // HTML responses
sushi.SendHTML(ctx, "<h1>Welcome</h1>") ctx.SendHTML("<h1>Welcome</h1>")
// Text responses // Text responses
sushi.SendText(ctx, "Plain text") ctx.SendText("Plain text")
// Error responses // Error responses
sushi.SendError(ctx, 404, "Not Found") ctx.SendError(404, "Not Found")
// Redirects // Redirects
sushi.SendRedirect(ctx, "/login") ctx.Redirect("/login")
// Status only // Status only
sushi.SendStatus(ctx, 204) ctx.SendStatus(204)
} }
``` ```
@ -186,21 +186,21 @@ func loginHandler(ctx sushi.Ctx, params []string) {
// Find user by email/username // Find user by email/username
user := findUserByEmail(email) user := findUserByEmail(email)
if user == nil { if user == nil {
sushi.SendError(ctx, 401, "Invalid credentials") ctx.SendError(401, "Invalid credentials")
return return
} }
// Verify password // Verify password
isValid, err := password.VerifyPassword(password, user.Password) isValid, err := password.VerifyPassword(password, user.Password)
if err != nil || !isValid { if err != nil || !isValid {
sushi.SendError(ctx, 401, "Invalid credentials") ctx.SendError(401, "Invalid credentials")
return return
} }
// Log the user in // Log the user in
auth.Login(ctx, user.ID, user) auth.Login(ctx, user.ID, user)
sushi.SendRedirect(ctx, "/dashboard") ctx.Redirect("/dashboard")
} }
``` ```
@ -209,7 +209,7 @@ func loginHandler(ctx sushi.Ctx, params []string) {
```go ```go
func logoutHandler(ctx sushi.Ctx, params []string) { func logoutHandler(ctx sushi.Ctx, params []string) {
auth.Logout(ctx) auth.Logout(ctx)
sushi.SendRedirect(ctx, "/") ctx.SendRedirect("/")
} }
``` ```
@ -220,7 +220,7 @@ func dashboardHandler(ctx sushi.Ctx, params []string) {
user := auth.GetCurrentUser(ctx).(*User) user := auth.GetCurrentUser(ctx).(*User)
html := fmt.Sprintf("<h1>Welcome, %s!</h1>", user.Username) html := fmt.Sprintf("<h1>Welcome, %s!</h1>", user.Username)
sushi.SendHTML(ctx, html) ctx.SendHTML(html)
} }
``` ```
@ -245,7 +245,7 @@ func loginPageHandler(ctx sushi.Ctx, params []string) {
</form> </form>
`, csrfField) `, csrfField)
sushi.SendHTML(ctx, html) ctx.SendHTML(html)
} }
``` ```
@ -362,10 +362,10 @@ func main() {
func homeHandler(ctx sushi.Ctx, params []string) { func homeHandler(ctx sushi.Ctx, params []string) {
if auth.IsAuthenticated(ctx) { if auth.IsAuthenticated(ctx) {
sushi.SendRedirect(ctx, "/dashboard") ctx.SendRedirect("/dashboard")
return return
} }
sushi.SendHTML(ctx, `<a href="/login">Login</a>`) ctx.SendHTML(c`<a href="/login">Login</a>`)
} }
func loginPageHandler(ctx sushi.Ctx, params []string) { func loginPageHandler(ctx sushi.Ctx, params []string) {

42
auth.go Normal file
View File

@ -0,0 +1,42 @@
package sushi
// IsAuthenticated checks if the current request is from an authenticated user
func (ctx Ctx) IsAuthenticated() bool {
user := ctx.UserValue("user")
return user != nil
}
// GetCurrentUser returns the current authenticated user
func (ctx Ctx) GetCurrentUser() any {
return ctx.UserValue("user")
}
// Login authenticates a user session
func (ctx Ctx) Login(userID int, user any) {
sess := GetCurrentSession(ctx)
if sess != nil {
sess.SetUserID(userID)
sess.RegenerateID()
StoreSession(sess)
ctx.SetUserValue(SessionCtxKey, sess)
ctx.SetUserValue("user", user)
SetSessionCookie(ctx, sess.ID)
}
}
// Logout clears the user session
func (ctx Ctx) Logout() {
sess := GetCurrentSession(ctx)
if sess != nil {
sess.SetUserID(0)
sess.RegenerateID()
StoreSession(sess)
ctx.SetUserValue(SessionCtxKey, sess)
SetSessionCookie(ctx, sess.ID)
}
ctx.SetUserValue("user", nil)
}

View File

@ -2,7 +2,6 @@ package auth
import ( import (
sushi "git.sharkk.net/Sharkk/Sushi" sushi "git.sharkk.net/Sharkk/Sushi"
"git.sharkk.net/Sharkk/Sushi/session"
"github.com/valyala/fasthttp" "github.com/valyala/fasthttp"
) )
@ -11,14 +10,14 @@ const UserCtxKey = "user"
// Middleware adds authentication handling // Middleware adds authentication handling
func Middleware(userLookup func(int) any) sushi.Middleware { func Middleware(userLookup func(int) any) sushi.Middleware {
return func(ctx sushi.Ctx, params []any, next func()) { return func(ctx sushi.Ctx, params []any, next func()) {
sess := session.GetCurrentSession(ctx) sess := sushi.GetCurrentSession(ctx)
if sess != nil && sess.UserID > 0 && userLookup != nil { if sess != nil && sess.UserID > 0 && userLookup != nil {
user := userLookup(sess.UserID) user := userLookup(sess.UserID)
if user != nil { if user != nil {
ctx.SetUserValue(UserCtxKey, user) ctx.SetUserValue(UserCtxKey, user)
} else { } else {
sess.SetUserID(0) sess.SetUserID(0)
session.StoreSession(sess) sushi.StoreSession(sess)
} }
} }
next() next()
@ -33,7 +32,7 @@ func RequireAuth(redirectPath ...string) sushi.Middleware {
} }
return func(ctx sushi.Ctx, params []any, next func()) { return func(ctx sushi.Ctx, params []any, next func()) {
if !IsAuthenticated(ctx) { if !ctx.IsAuthenticated() {
ctx.Redirect(redirect, fasthttp.StatusFound) ctx.Redirect(redirect, fasthttp.StatusFound)
return return
} }
@ -49,51 +48,10 @@ func RequireGuest(redirectPath ...string) sushi.Middleware {
} }
return func(ctx sushi.Ctx, params []any, next func()) { return func(ctx sushi.Ctx, params []any, next func()) {
if IsAuthenticated(ctx) { if ctx.IsAuthenticated() {
ctx.Redirect(redirect, fasthttp.StatusFound) ctx.Redirect(redirect, fasthttp.StatusFound)
return return
} }
next() next()
} }
} }
// IsAuthenticated checks if the current request is from an authenticated user
func IsAuthenticated(ctx sushi.Ctx) bool {
user := ctx.UserValue(UserCtxKey)
return user != nil
}
// GetCurrentUser returns the current authenticated user
func GetCurrentUser(ctx sushi.Ctx) any {
return ctx.UserValue(UserCtxKey)
}
// Login authenticates a user session
func Login(ctx sushi.Ctx, userID int, user any) {
sess := session.GetCurrentSession(ctx)
if sess != nil {
sess.SetUserID(userID)
sess.RegenerateID()
session.StoreSession(sess)
ctx.SetUserValue(session.SessionCtxKey, sess)
ctx.SetUserValue(UserCtxKey, user)
session.SetSessionCookie(ctx, sess.ID)
}
}
// Logout clears the user session
func Logout(ctx sushi.Ctx) {
sess := session.GetCurrentSession(ctx)
if sess != nil {
sess.SetUserID(0)
sess.RegenerateID()
session.StoreSession(sess)
ctx.SetUserValue(session.SessionCtxKey, sess)
session.SetSessionCookie(ctx, sess.ID)
}
ctx.SetUserValue(UserCtxKey, nil)
}

View File

@ -7,7 +7,6 @@ import (
"fmt" "fmt"
sushi "git.sharkk.net/Sharkk/Sushi" sushi "git.sharkk.net/Sharkk/Sushi"
"git.sharkk.net/Sharkk/Sushi/session"
) )
const ( const (
@ -17,16 +16,8 @@ const (
SessionCtxKey = "session" SessionCtxKey = "session"
) )
// GetCurrentSession retrieves the session from context // GenerateToken creates a new CSRF token and stores it in the session
func GetCurrentSession(ctx sushi.Ctx) *session.Session { func GenerateToken(ctx sushi.Ctx) string {
if sess, ok := ctx.UserValue(SessionCtxKey).(*session.Session); ok {
return sess
}
return nil
}
// GenerateCSRFToken creates a new CSRF token and stores it in the session
func GenerateCSRFToken(ctx sushi.Ctx) string {
tokenBytes := make([]byte, CSRFTokenLength) tokenBytes := make([]byte, CSRFTokenLength)
if _, err := rand.Read(tokenBytes); err != nil { if _, err := rand.Read(tokenBytes); err != nil {
return "" return ""
@ -34,17 +25,17 @@ func GenerateCSRFToken(ctx sushi.Ctx) string {
token := base64.URLEncoding.EncodeToString(tokenBytes) token := base64.URLEncoding.EncodeToString(tokenBytes)
if sess := GetCurrentSession(ctx); sess != nil { if sess := ctx.GetCurrentSession(); sess != nil {
sess.Set(CSRFSessionKey, token) sess.Set(CSRFSessionKey, token)
session.StoreSession(sess) sushi.StoreSession(sess)
} }
return token return token
} }
// GetCSRFToken retrieves the current CSRF token from session, generating one if needed // GetToken retrieves the current CSRF token from session, generating one if needed
func GetCSRFToken(ctx sushi.Ctx) string { func GetToken(ctx sushi.Ctx) string {
sess := GetCurrentSession(ctx) sess := ctx.GetCurrentSession()
if sess == nil { if sess == nil {
return "" return ""
} }
@ -55,16 +46,16 @@ func GetCSRFToken(ctx sushi.Ctx) string {
} }
} }
return GenerateCSRFToken(ctx) return GenerateToken(ctx)
} }
// ValidateCSRFToken verifies a CSRF token against the stored session token // ValidateToken verifies a CSRF token against the stored session token
func ValidateCSRFToken(ctx sushi.Ctx, submittedToken string) bool { func ValidateToken(ctx sushi.Ctx, submittedToken string) bool {
if submittedToken == "" { if submittedToken == "" {
return false return false
} }
sess := GetCurrentSession(ctx) sess := ctx.GetCurrentSession()
if sess == nil { if sess == nil {
return false return false
} }
@ -82,9 +73,9 @@ func ValidateCSRFToken(ctx sushi.Ctx, submittedToken string) bool {
return subtle.ConstantTimeCompare([]byte(submittedToken), []byte(storedTokenStr)) == 1 return subtle.ConstantTimeCompare([]byte(submittedToken), []byte(storedTokenStr)) == 1
} }
// CSRFHiddenField generates an HTML hidden input field with the CSRF token // HiddenField generates an HTML hidden input field with the CSRF token
func CSRFHiddenField(ctx sushi.Ctx) string { func HiddenField(ctx sushi.Ctx) string {
token := GetCSRFToken(ctx) token := GetToken(ctx)
if token == "" { if token == "" {
return "" return ""
} }
@ -95,7 +86,7 @@ func CSRFHiddenField(ctx sushi.Ctx) string {
// CSRFTokenMeta generates HTML meta tag for JavaScript access to CSRF token // CSRFTokenMeta generates HTML meta tag for JavaScript access to CSRF token
func CSRFTokenMeta(ctx sushi.Ctx) string { func CSRFTokenMeta(ctx sushi.Ctx) string {
token := GetCSRFToken(ctx) token := GetToken(ctx)
if token == "" { if token == "" {
return "" return ""
} }
@ -114,7 +105,7 @@ func ValidateFormCSRFToken(ctx sushi.Ctx) bool {
return false return false
} }
return ValidateCSRFToken(ctx, string(tokenBytes)) return ValidateToken(ctx, string(tokenBytes))
} }
// Middleware returns middleware that automatically validates CSRF tokens // Middleware returns middleware that automatically validates CSRF tokens
@ -124,7 +115,7 @@ func Middleware() sushi.Middleware {
if method == "POST" || method == "PUT" || method == "PATCH" || method == "DELETE" { if method == "POST" || method == "PUT" || method == "PATCH" || method == "DELETE" {
if !ValidateFormCSRFToken(ctx) { if !ValidateFormCSRFToken(ctx) {
GenerateCSRFToken(ctx) GenerateToken(ctx)
currentPath := string(ctx.Path()) currentPath := string(ctx.Path())
ctx.Redirect(currentPath, 302) ctx.Redirect(currentPath, 302)
return return

View File

@ -1,4 +1,4 @@
package session package sushi
import ( import (
"crypto/rand" "crypto/rand"
@ -7,8 +7,6 @@ import (
"os" "os"
"sync" "sync"
"time" "time"
sushi "git.sharkk.net/Sharkk/Sushi"
) )
const ( const (
@ -132,7 +130,7 @@ func (s *Session) SetUserID(userID int) {
} }
// GetCurrentSession retrieves the session from context // GetCurrentSession retrieves the session from context
func GetCurrentSession(ctx sushi.Ctx) *Session { func GetCurrentSession(ctx Ctx) *Session {
if sess, ok := ctx.UserValue(SessionCtxKey).(*Session); ok { if sess, ok := ctx.UserValue(SessionCtxKey).(*Session); ok {
return sess return sess
} }
@ -267,14 +265,22 @@ func SaveSessions() error {
return sessionManager.Save() return sessionManager.Save()
} }
func SetSessionCookie(ctx sushi.Ctx, sessionID string) { func SetSessionCookie(ctx Ctx, sessionID string) {
sushi.SetSecureCookie(ctx, sushi.CookieOptions{ SetSecureCookie(ctx, CookieOptions{
Name: SessionCookieName, Name: SessionCookieName,
Value: sessionID, Value: sessionID,
Path: "/", Path: "/",
Expires: time.Now().Add(24 * time.Hour), Expires: time.Now().Add(24 * time.Hour),
HTTPOnly: true, HTTPOnly: true,
Secure: sushi.IsHTTPS(ctx), Secure: IsHTTPS(ctx),
SameSite: "lax", SameSite: "lax",
}) })
} }
// GetCurrentSession retrieves the session from context
func (ctx Ctx) GetCurrentSession() *Session {
if sess, ok := ctx.UserValue(SessionCtxKey).(*Session); ok {
return sess
}
return nil
}

View File

@ -5,24 +5,24 @@ import sushi "git.sharkk.net/Sharkk/Sushi"
// Middleware provides session handling // Middleware provides session handling
func Middleware() sushi.Middleware { func Middleware() sushi.Middleware {
return func(ctx sushi.Ctx, params []any, next func()) { return func(ctx sushi.Ctx, params []any, next func()) {
sessionID := sushi.GetCookie(ctx, SessionCookieName) sessionID := sushi.GetCookie(ctx, sushi.SessionCookieName)
var sess *Session var sess *sushi.Session
if sessionID != "" { if sessionID != "" {
if existingSess, exists := GetSession(sessionID); exists { if existingSess, exists := sushi.GetSession(sessionID); exists {
sess = existingSess sess = existingSess
sess.Touch() sess.Touch()
StoreSession(sess) sushi.StoreSession(sess)
SetSessionCookie(ctx, sessionID) sushi.SetSessionCookie(ctx, sessionID)
} }
} }
if sess == nil { if sess == nil {
sess = CreateSession(0) // Guest session sess = sushi.CreateSession(0) // Guest session
SetSessionCookie(ctx, sess.ID) sushi.SetSessionCookie(ctx, sess.ID)
} }
ctx.SetUserValue(SessionCtxKey, sess) ctx.SetUserValue(sushi.SessionCtxKey, sess)
next() next()
} }
} }

View File

@ -57,7 +57,7 @@ func (ctx Ctx) Redirect(url string, statusCode ...int) {
if len(statusCode) > 0 { if len(statusCode) > 0 {
code = statusCode[0] code = statusCode[0]
} }
ctx.Redirect(url, code) ctx.RequestCtx.Redirect(url, code)
} }
// SendFile serves a file // SendFile serves a file