diff --git a/README.md b/README.md index 023dc1d..09f94a7 100644 --- a/README.md +++ b/README.md @@ -74,22 +74,22 @@ app.Get("/users/:id/posts/:slug", func(ctx sushi.Ctx, params []any) { ```go func myHandler(ctx sushi.Ctx, params []any) { // JSON responses - sushi.SendJSON(ctx, map[string]string{"message": "success"}) + ctx.SendJSON(map[string]string{"message": "success"}) // HTML responses - sushi.SendHTML(ctx, "

Welcome

") + ctx.SendHTML("

Welcome

") // Text responses - sushi.SendText(ctx, "Plain text") + ctx.SendText("Plain text") // Error responses - sushi.SendError(ctx, 404, "Not Found") + ctx.SendError(404, "Not Found") // Redirects - sushi.SendRedirect(ctx, "/login") + ctx.Redirect("/login") // Status only - sushi.SendStatus(ctx, 204) + ctx.SendStatus(204) } ``` @@ -186,21 +186,21 @@ func loginHandler(ctx sushi.Ctx, params []string) { // Find user by email/username user := findUserByEmail(email) if user == nil { - sushi.SendError(ctx, 401, "Invalid credentials") + ctx.SendError(401, "Invalid credentials") return } // Verify password isValid, err := password.VerifyPassword(password, user.Password) if err != nil || !isValid { - sushi.SendError(ctx, 401, "Invalid credentials") + ctx.SendError(401, "Invalid credentials") return } // Log the user in auth.Login(ctx, user.ID, user) - sushi.SendRedirect(ctx, "/dashboard") + ctx.Redirect("/dashboard") } ``` @@ -209,7 +209,7 @@ func loginHandler(ctx sushi.Ctx, params []string) { ```go func logoutHandler(ctx sushi.Ctx, params []string) { auth.Logout(ctx) - sushi.SendRedirect(ctx, "/") + ctx.SendRedirect("/") } ``` @@ -220,7 +220,7 @@ func dashboardHandler(ctx sushi.Ctx, params []string) { user := auth.GetCurrentUser(ctx).(*User) html := fmt.Sprintf("

Welcome, %s!

", user.Username) - sushi.SendHTML(ctx, html) + ctx.SendHTML(html) } ``` @@ -245,7 +245,7 @@ func loginPageHandler(ctx sushi.Ctx, params []string) { `, csrfField) - sushi.SendHTML(ctx, html) + ctx.SendHTML(html) } ``` @@ -362,10 +362,10 @@ func main() { func homeHandler(ctx sushi.Ctx, params []string) { if auth.IsAuthenticated(ctx) { - sushi.SendRedirect(ctx, "/dashboard") + ctx.SendRedirect("/dashboard") return } - sushi.SendHTML(ctx, `Login`) + ctx.SendHTML(c`Login`) } func loginPageHandler(ctx sushi.Ctx, params []string) { diff --git a/auth.go b/auth.go new file mode 100644 index 0000000..05c2a69 --- /dev/null +++ b/auth.go @@ -0,0 +1,42 @@ +package sushi + +// IsAuthenticated checks if the current request is from an authenticated user +func (ctx Ctx) IsAuthenticated() bool { + user := ctx.UserValue("user") + return user != nil +} + +// GetCurrentUser returns the current authenticated user +func (ctx Ctx) GetCurrentUser() any { + return ctx.UserValue("user") +} + +// Login authenticates a user session +func (ctx Ctx) Login(userID int, user any) { + sess := GetCurrentSession(ctx) + if sess != nil { + sess.SetUserID(userID) + sess.RegenerateID() + StoreSession(sess) + + ctx.SetUserValue(SessionCtxKey, sess) + ctx.SetUserValue("user", user) + + SetSessionCookie(ctx, sess.ID) + } +} + +// Logout clears the user session +func (ctx Ctx) Logout() { + sess := GetCurrentSession(ctx) + if sess != nil { + sess.SetUserID(0) + sess.RegenerateID() + StoreSession(sess) + + ctx.SetUserValue(SessionCtxKey, sess) + SetSessionCookie(ctx, sess.ID) + } + + ctx.SetUserValue("user", nil) +} diff --git a/auth/auth.go b/auth/middleware.go similarity index 51% rename from auth/auth.go rename to auth/middleware.go index 63fe9cd..20f903a 100644 --- a/auth/auth.go +++ b/auth/middleware.go @@ -2,7 +2,6 @@ package auth import ( sushi "git.sharkk.net/Sharkk/Sushi" - "git.sharkk.net/Sharkk/Sushi/session" "github.com/valyala/fasthttp" ) @@ -11,14 +10,14 @@ const UserCtxKey = "user" // Middleware adds authentication handling func Middleware(userLookup func(int) any) sushi.Middleware { return func(ctx sushi.Ctx, params []any, next func()) { - sess := session.GetCurrentSession(ctx) + sess := sushi.GetCurrentSession(ctx) if sess != nil && sess.UserID > 0 && userLookup != nil { user := userLookup(sess.UserID) if user != nil { ctx.SetUserValue(UserCtxKey, user) } else { sess.SetUserID(0) - session.StoreSession(sess) + sushi.StoreSession(sess) } } next() @@ -33,7 +32,7 @@ func RequireAuth(redirectPath ...string) sushi.Middleware { } return func(ctx sushi.Ctx, params []any, next func()) { - if !IsAuthenticated(ctx) { + if !ctx.IsAuthenticated() { ctx.Redirect(redirect, fasthttp.StatusFound) return } @@ -49,51 +48,10 @@ func RequireGuest(redirectPath ...string) sushi.Middleware { } return func(ctx sushi.Ctx, params []any, next func()) { - if IsAuthenticated(ctx) { + if ctx.IsAuthenticated() { ctx.Redirect(redirect, fasthttp.StatusFound) return } next() } } - -// IsAuthenticated checks if the current request is from an authenticated user -func IsAuthenticated(ctx sushi.Ctx) bool { - user := ctx.UserValue(UserCtxKey) - return user != nil -} - -// GetCurrentUser returns the current authenticated user -func GetCurrentUser(ctx sushi.Ctx) any { - return ctx.UserValue(UserCtxKey) -} - -// Login authenticates a user session -func Login(ctx sushi.Ctx, userID int, user any) { - sess := session.GetCurrentSession(ctx) - if sess != nil { - sess.SetUserID(userID) - sess.RegenerateID() - session.StoreSession(sess) - - ctx.SetUserValue(session.SessionCtxKey, sess) - ctx.SetUserValue(UserCtxKey, user) - - session.SetSessionCookie(ctx, sess.ID) - } -} - -// Logout clears the user session -func Logout(ctx sushi.Ctx) { - sess := session.GetCurrentSession(ctx) - if sess != nil { - sess.SetUserID(0) - sess.RegenerateID() - session.StoreSession(sess) - - ctx.SetUserValue(session.SessionCtxKey, sess) - session.SetSessionCookie(ctx, sess.ID) - } - - ctx.SetUserValue(UserCtxKey, nil) -} diff --git a/csrf/csrf.go b/csrf/csrf.go index 44293f9..7f15fac 100644 --- a/csrf/csrf.go +++ b/csrf/csrf.go @@ -7,7 +7,6 @@ import ( "fmt" sushi "git.sharkk.net/Sharkk/Sushi" - "git.sharkk.net/Sharkk/Sushi/session" ) const ( @@ -17,16 +16,8 @@ const ( SessionCtxKey = "session" ) -// GetCurrentSession retrieves the session from context -func GetCurrentSession(ctx sushi.Ctx) *session.Session { - if sess, ok := ctx.UserValue(SessionCtxKey).(*session.Session); ok { - return sess - } - return nil -} - -// GenerateCSRFToken creates a new CSRF token and stores it in the session -func GenerateCSRFToken(ctx sushi.Ctx) string { +// GenerateToken creates a new CSRF token and stores it in the session +func GenerateToken(ctx sushi.Ctx) string { tokenBytes := make([]byte, CSRFTokenLength) if _, err := rand.Read(tokenBytes); err != nil { return "" @@ -34,17 +25,17 @@ func GenerateCSRFToken(ctx sushi.Ctx) string { token := base64.URLEncoding.EncodeToString(tokenBytes) - if sess := GetCurrentSession(ctx); sess != nil { + if sess := ctx.GetCurrentSession(); sess != nil { sess.Set(CSRFSessionKey, token) - session.StoreSession(sess) + sushi.StoreSession(sess) } return token } -// GetCSRFToken retrieves the current CSRF token from session, generating one if needed -func GetCSRFToken(ctx sushi.Ctx) string { - sess := GetCurrentSession(ctx) +// GetToken retrieves the current CSRF token from session, generating one if needed +func GetToken(ctx sushi.Ctx) string { + sess := ctx.GetCurrentSession() if sess == nil { return "" } @@ -55,16 +46,16 @@ func GetCSRFToken(ctx sushi.Ctx) string { } } - return GenerateCSRFToken(ctx) + return GenerateToken(ctx) } -// ValidateCSRFToken verifies a CSRF token against the stored session token -func ValidateCSRFToken(ctx sushi.Ctx, submittedToken string) bool { +// ValidateToken verifies a CSRF token against the stored session token +func ValidateToken(ctx sushi.Ctx, submittedToken string) bool { if submittedToken == "" { return false } - sess := GetCurrentSession(ctx) + sess := ctx.GetCurrentSession() if sess == nil { return false } @@ -82,9 +73,9 @@ func ValidateCSRFToken(ctx sushi.Ctx, submittedToken string) bool { return subtle.ConstantTimeCompare([]byte(submittedToken), []byte(storedTokenStr)) == 1 } -// CSRFHiddenField generates an HTML hidden input field with the CSRF token -func CSRFHiddenField(ctx sushi.Ctx) string { - token := GetCSRFToken(ctx) +// HiddenField generates an HTML hidden input field with the CSRF token +func HiddenField(ctx sushi.Ctx) string { + token := GetToken(ctx) if token == "" { return "" } @@ -95,7 +86,7 @@ func CSRFHiddenField(ctx sushi.Ctx) string { // CSRFTokenMeta generates HTML meta tag for JavaScript access to CSRF token func CSRFTokenMeta(ctx sushi.Ctx) string { - token := GetCSRFToken(ctx) + token := GetToken(ctx) if token == "" { return "" } @@ -114,7 +105,7 @@ func ValidateFormCSRFToken(ctx sushi.Ctx) bool { return false } - return ValidateCSRFToken(ctx, string(tokenBytes)) + return ValidateToken(ctx, string(tokenBytes)) } // Middleware returns middleware that automatically validates CSRF tokens @@ -124,7 +115,7 @@ func Middleware() sushi.Middleware { if method == "POST" || method == "PUT" || method == "PATCH" || method == "DELETE" { if !ValidateFormCSRFToken(ctx) { - GenerateCSRFToken(ctx) + GenerateToken(ctx) currentPath := string(ctx.Path()) ctx.Redirect(currentPath, 302) return diff --git a/session/session.go b/session.go similarity index 93% rename from session/session.go rename to session.go index 36c9774..0fa2890 100644 --- a/session/session.go +++ b/session.go @@ -1,4 +1,4 @@ -package session +package sushi import ( "crypto/rand" @@ -7,8 +7,6 @@ import ( "os" "sync" "time" - - sushi "git.sharkk.net/Sharkk/Sushi" ) const ( @@ -132,7 +130,7 @@ func (s *Session) SetUserID(userID int) { } // GetCurrentSession retrieves the session from context -func GetCurrentSession(ctx sushi.Ctx) *Session { +func GetCurrentSession(ctx Ctx) *Session { if sess, ok := ctx.UserValue(SessionCtxKey).(*Session); ok { return sess } @@ -267,14 +265,22 @@ func SaveSessions() error { return sessionManager.Save() } -func SetSessionCookie(ctx sushi.Ctx, sessionID string) { - sushi.SetSecureCookie(ctx, sushi.CookieOptions{ +func SetSessionCookie(ctx Ctx, sessionID string) { + SetSecureCookie(ctx, CookieOptions{ Name: SessionCookieName, Value: sessionID, Path: "/", Expires: time.Now().Add(24 * time.Hour), HTTPOnly: true, - Secure: sushi.IsHTTPS(ctx), + Secure: IsHTTPS(ctx), SameSite: "lax", }) } + +// GetCurrentSession retrieves the session from context +func (ctx Ctx) GetCurrentSession() *Session { + if sess, ok := ctx.UserValue(SessionCtxKey).(*Session); ok { + return sess + } + return nil +} diff --git a/session/middleware.go b/session/middleware.go index 2272e70..675cbac 100644 --- a/session/middleware.go +++ b/session/middleware.go @@ -5,24 +5,24 @@ import sushi "git.sharkk.net/Sharkk/Sushi" // Middleware provides session handling func Middleware() sushi.Middleware { return func(ctx sushi.Ctx, params []any, next func()) { - sessionID := sushi.GetCookie(ctx, SessionCookieName) - var sess *Session + sessionID := sushi.GetCookie(ctx, sushi.SessionCookieName) + var sess *sushi.Session if sessionID != "" { - if existingSess, exists := GetSession(sessionID); exists { + if existingSess, exists := sushi.GetSession(sessionID); exists { sess = existingSess sess.Touch() - StoreSession(sess) - SetSessionCookie(ctx, sessionID) + sushi.StoreSession(sess) + sushi.SetSessionCookie(ctx, sessionID) } } if sess == nil { - sess = CreateSession(0) // Guest session - SetSessionCookie(ctx, sess.ID) + sess = sushi.CreateSession(0) // Guest session + sushi.SetSessionCookie(ctx, sess.ID) } - ctx.SetUserValue(SessionCtxKey, sess) + ctx.SetUserValue(sushi.SessionCtxKey, sess) next() } } diff --git a/types.go b/types.go index cae1e0b..7510328 100644 --- a/types.go +++ b/types.go @@ -57,7 +57,7 @@ func (ctx Ctx) Redirect(url string, statusCode ...int) { if len(statusCode) > 0 { code = statusCode[0] } - ctx.Redirect(url, code) + ctx.RequestCtx.Redirect(url, code) } // SendFile serves a file