move auth and session to core package, move helpers to ctx
This commit is contained in:
parent
5370d14152
commit
0f59ba225a
28
README.md
28
README.md
@ -74,22 +74,22 @@ app.Get("/users/:id/posts/:slug", func(ctx sushi.Ctx, params []any) {
|
|||||||
```go
|
```go
|
||||||
func myHandler(ctx sushi.Ctx, params []any) {
|
func myHandler(ctx sushi.Ctx, params []any) {
|
||||||
// JSON responses
|
// JSON responses
|
||||||
sushi.SendJSON(ctx, map[string]string{"message": "success"})
|
ctx.SendJSON(map[string]string{"message": "success"})
|
||||||
|
|
||||||
// HTML responses
|
// HTML responses
|
||||||
sushi.SendHTML(ctx, "<h1>Welcome</h1>")
|
ctx.SendHTML("<h1>Welcome</h1>")
|
||||||
|
|
||||||
// Text responses
|
// Text responses
|
||||||
sushi.SendText(ctx, "Plain text")
|
ctx.SendText("Plain text")
|
||||||
|
|
||||||
// Error responses
|
// Error responses
|
||||||
sushi.SendError(ctx, 404, "Not Found")
|
ctx.SendError(404, "Not Found")
|
||||||
|
|
||||||
// Redirects
|
// Redirects
|
||||||
sushi.SendRedirect(ctx, "/login")
|
ctx.Redirect("/login")
|
||||||
|
|
||||||
// Status only
|
// Status only
|
||||||
sushi.SendStatus(ctx, 204)
|
ctx.SendStatus(204)
|
||||||
}
|
}
|
||||||
```
|
```
|
||||||
|
|
||||||
@ -186,21 +186,21 @@ func loginHandler(ctx sushi.Ctx, params []string) {
|
|||||||
// Find user by email/username
|
// Find user by email/username
|
||||||
user := findUserByEmail(email)
|
user := findUserByEmail(email)
|
||||||
if user == nil {
|
if user == nil {
|
||||||
sushi.SendError(ctx, 401, "Invalid credentials")
|
ctx.SendError(401, "Invalid credentials")
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
// Verify password
|
// Verify password
|
||||||
isValid, err := password.VerifyPassword(password, user.Password)
|
isValid, err := password.VerifyPassword(password, user.Password)
|
||||||
if err != nil || !isValid {
|
if err != nil || !isValid {
|
||||||
sushi.SendError(ctx, 401, "Invalid credentials")
|
ctx.SendError(401, "Invalid credentials")
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
// Log the user in
|
// Log the user in
|
||||||
auth.Login(ctx, user.ID, user)
|
auth.Login(ctx, user.ID, user)
|
||||||
|
|
||||||
sushi.SendRedirect(ctx, "/dashboard")
|
ctx.Redirect("/dashboard")
|
||||||
}
|
}
|
||||||
```
|
```
|
||||||
|
|
||||||
@ -209,7 +209,7 @@ func loginHandler(ctx sushi.Ctx, params []string) {
|
|||||||
```go
|
```go
|
||||||
func logoutHandler(ctx sushi.Ctx, params []string) {
|
func logoutHandler(ctx sushi.Ctx, params []string) {
|
||||||
auth.Logout(ctx)
|
auth.Logout(ctx)
|
||||||
sushi.SendRedirect(ctx, "/")
|
ctx.SendRedirect("/")
|
||||||
}
|
}
|
||||||
```
|
```
|
||||||
|
|
||||||
@ -220,7 +220,7 @@ func dashboardHandler(ctx sushi.Ctx, params []string) {
|
|||||||
user := auth.GetCurrentUser(ctx).(*User)
|
user := auth.GetCurrentUser(ctx).(*User)
|
||||||
|
|
||||||
html := fmt.Sprintf("<h1>Welcome, %s!</h1>", user.Username)
|
html := fmt.Sprintf("<h1>Welcome, %s!</h1>", user.Username)
|
||||||
sushi.SendHTML(ctx, html)
|
ctx.SendHTML(html)
|
||||||
}
|
}
|
||||||
```
|
```
|
||||||
|
|
||||||
@ -245,7 +245,7 @@ func loginPageHandler(ctx sushi.Ctx, params []string) {
|
|||||||
</form>
|
</form>
|
||||||
`, csrfField)
|
`, csrfField)
|
||||||
|
|
||||||
sushi.SendHTML(ctx, html)
|
ctx.SendHTML(html)
|
||||||
}
|
}
|
||||||
```
|
```
|
||||||
|
|
||||||
@ -362,10 +362,10 @@ func main() {
|
|||||||
|
|
||||||
func homeHandler(ctx sushi.Ctx, params []string) {
|
func homeHandler(ctx sushi.Ctx, params []string) {
|
||||||
if auth.IsAuthenticated(ctx) {
|
if auth.IsAuthenticated(ctx) {
|
||||||
sushi.SendRedirect(ctx, "/dashboard")
|
ctx.SendRedirect("/dashboard")
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
sushi.SendHTML(ctx, `<a href="/login">Login</a>`)
|
ctx.SendHTML(c`<a href="/login">Login</a>`)
|
||||||
}
|
}
|
||||||
|
|
||||||
func loginPageHandler(ctx sushi.Ctx, params []string) {
|
func loginPageHandler(ctx sushi.Ctx, params []string) {
|
||||||
|
42
auth.go
Normal file
42
auth.go
Normal file
@ -0,0 +1,42 @@
|
|||||||
|
package sushi
|
||||||
|
|
||||||
|
// IsAuthenticated checks if the current request is from an authenticated user
|
||||||
|
func (ctx Ctx) IsAuthenticated() bool {
|
||||||
|
user := ctx.UserValue("user")
|
||||||
|
return user != nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// GetCurrentUser returns the current authenticated user
|
||||||
|
func (ctx Ctx) GetCurrentUser() any {
|
||||||
|
return ctx.UserValue("user")
|
||||||
|
}
|
||||||
|
|
||||||
|
// Login authenticates a user session
|
||||||
|
func (ctx Ctx) Login(userID int, user any) {
|
||||||
|
sess := GetCurrentSession(ctx)
|
||||||
|
if sess != nil {
|
||||||
|
sess.SetUserID(userID)
|
||||||
|
sess.RegenerateID()
|
||||||
|
StoreSession(sess)
|
||||||
|
|
||||||
|
ctx.SetUserValue(SessionCtxKey, sess)
|
||||||
|
ctx.SetUserValue("user", user)
|
||||||
|
|
||||||
|
SetSessionCookie(ctx, sess.ID)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Logout clears the user session
|
||||||
|
func (ctx Ctx) Logout() {
|
||||||
|
sess := GetCurrentSession(ctx)
|
||||||
|
if sess != nil {
|
||||||
|
sess.SetUserID(0)
|
||||||
|
sess.RegenerateID()
|
||||||
|
StoreSession(sess)
|
||||||
|
|
||||||
|
ctx.SetUserValue(SessionCtxKey, sess)
|
||||||
|
SetSessionCookie(ctx, sess.ID)
|
||||||
|
}
|
||||||
|
|
||||||
|
ctx.SetUserValue("user", nil)
|
||||||
|
}
|
@ -2,7 +2,6 @@ package auth
|
|||||||
|
|
||||||
import (
|
import (
|
||||||
sushi "git.sharkk.net/Sharkk/Sushi"
|
sushi "git.sharkk.net/Sharkk/Sushi"
|
||||||
"git.sharkk.net/Sharkk/Sushi/session"
|
|
||||||
"github.com/valyala/fasthttp"
|
"github.com/valyala/fasthttp"
|
||||||
)
|
)
|
||||||
|
|
||||||
@ -11,14 +10,14 @@ const UserCtxKey = "user"
|
|||||||
// Middleware adds authentication handling
|
// Middleware adds authentication handling
|
||||||
func Middleware(userLookup func(int) any) sushi.Middleware {
|
func Middleware(userLookup func(int) any) sushi.Middleware {
|
||||||
return func(ctx sushi.Ctx, params []any, next func()) {
|
return func(ctx sushi.Ctx, params []any, next func()) {
|
||||||
sess := session.GetCurrentSession(ctx)
|
sess := sushi.GetCurrentSession(ctx)
|
||||||
if sess != nil && sess.UserID > 0 && userLookup != nil {
|
if sess != nil && sess.UserID > 0 && userLookup != nil {
|
||||||
user := userLookup(sess.UserID)
|
user := userLookup(sess.UserID)
|
||||||
if user != nil {
|
if user != nil {
|
||||||
ctx.SetUserValue(UserCtxKey, user)
|
ctx.SetUserValue(UserCtxKey, user)
|
||||||
} else {
|
} else {
|
||||||
sess.SetUserID(0)
|
sess.SetUserID(0)
|
||||||
session.StoreSession(sess)
|
sushi.StoreSession(sess)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
next()
|
next()
|
||||||
@ -33,7 +32,7 @@ func RequireAuth(redirectPath ...string) sushi.Middleware {
|
|||||||
}
|
}
|
||||||
|
|
||||||
return func(ctx sushi.Ctx, params []any, next func()) {
|
return func(ctx sushi.Ctx, params []any, next func()) {
|
||||||
if !IsAuthenticated(ctx) {
|
if !ctx.IsAuthenticated() {
|
||||||
ctx.Redirect(redirect, fasthttp.StatusFound)
|
ctx.Redirect(redirect, fasthttp.StatusFound)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
@ -49,51 +48,10 @@ func RequireGuest(redirectPath ...string) sushi.Middleware {
|
|||||||
}
|
}
|
||||||
|
|
||||||
return func(ctx sushi.Ctx, params []any, next func()) {
|
return func(ctx sushi.Ctx, params []any, next func()) {
|
||||||
if IsAuthenticated(ctx) {
|
if ctx.IsAuthenticated() {
|
||||||
ctx.Redirect(redirect, fasthttp.StatusFound)
|
ctx.Redirect(redirect, fasthttp.StatusFound)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
next()
|
next()
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// IsAuthenticated checks if the current request is from an authenticated user
|
|
||||||
func IsAuthenticated(ctx sushi.Ctx) bool {
|
|
||||||
user := ctx.UserValue(UserCtxKey)
|
|
||||||
return user != nil
|
|
||||||
}
|
|
||||||
|
|
||||||
// GetCurrentUser returns the current authenticated user
|
|
||||||
func GetCurrentUser(ctx sushi.Ctx) any {
|
|
||||||
return ctx.UserValue(UserCtxKey)
|
|
||||||
}
|
|
||||||
|
|
||||||
// Login authenticates a user session
|
|
||||||
func Login(ctx sushi.Ctx, userID int, user any) {
|
|
||||||
sess := session.GetCurrentSession(ctx)
|
|
||||||
if sess != nil {
|
|
||||||
sess.SetUserID(userID)
|
|
||||||
sess.RegenerateID()
|
|
||||||
session.StoreSession(sess)
|
|
||||||
|
|
||||||
ctx.SetUserValue(session.SessionCtxKey, sess)
|
|
||||||
ctx.SetUserValue(UserCtxKey, user)
|
|
||||||
|
|
||||||
session.SetSessionCookie(ctx, sess.ID)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// Logout clears the user session
|
|
||||||
func Logout(ctx sushi.Ctx) {
|
|
||||||
sess := session.GetCurrentSession(ctx)
|
|
||||||
if sess != nil {
|
|
||||||
sess.SetUserID(0)
|
|
||||||
sess.RegenerateID()
|
|
||||||
session.StoreSession(sess)
|
|
||||||
|
|
||||||
ctx.SetUserValue(session.SessionCtxKey, sess)
|
|
||||||
session.SetSessionCookie(ctx, sess.ID)
|
|
||||||
}
|
|
||||||
|
|
||||||
ctx.SetUserValue(UserCtxKey, nil)
|
|
||||||
}
|
|
43
csrf/csrf.go
43
csrf/csrf.go
@ -7,7 +7,6 @@ import (
|
|||||||
"fmt"
|
"fmt"
|
||||||
|
|
||||||
sushi "git.sharkk.net/Sharkk/Sushi"
|
sushi "git.sharkk.net/Sharkk/Sushi"
|
||||||
"git.sharkk.net/Sharkk/Sushi/session"
|
|
||||||
)
|
)
|
||||||
|
|
||||||
const (
|
const (
|
||||||
@ -17,16 +16,8 @@ const (
|
|||||||
SessionCtxKey = "session"
|
SessionCtxKey = "session"
|
||||||
)
|
)
|
||||||
|
|
||||||
// GetCurrentSession retrieves the session from context
|
// GenerateToken creates a new CSRF token and stores it in the session
|
||||||
func GetCurrentSession(ctx sushi.Ctx) *session.Session {
|
func GenerateToken(ctx sushi.Ctx) string {
|
||||||
if sess, ok := ctx.UserValue(SessionCtxKey).(*session.Session); ok {
|
|
||||||
return sess
|
|
||||||
}
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
// GenerateCSRFToken creates a new CSRF token and stores it in the session
|
|
||||||
func GenerateCSRFToken(ctx sushi.Ctx) string {
|
|
||||||
tokenBytes := make([]byte, CSRFTokenLength)
|
tokenBytes := make([]byte, CSRFTokenLength)
|
||||||
if _, err := rand.Read(tokenBytes); err != nil {
|
if _, err := rand.Read(tokenBytes); err != nil {
|
||||||
return ""
|
return ""
|
||||||
@ -34,17 +25,17 @@ func GenerateCSRFToken(ctx sushi.Ctx) string {
|
|||||||
|
|
||||||
token := base64.URLEncoding.EncodeToString(tokenBytes)
|
token := base64.URLEncoding.EncodeToString(tokenBytes)
|
||||||
|
|
||||||
if sess := GetCurrentSession(ctx); sess != nil {
|
if sess := ctx.GetCurrentSession(); sess != nil {
|
||||||
sess.Set(CSRFSessionKey, token)
|
sess.Set(CSRFSessionKey, token)
|
||||||
session.StoreSession(sess)
|
sushi.StoreSession(sess)
|
||||||
}
|
}
|
||||||
|
|
||||||
return token
|
return token
|
||||||
}
|
}
|
||||||
|
|
||||||
// GetCSRFToken retrieves the current CSRF token from session, generating one if needed
|
// GetToken retrieves the current CSRF token from session, generating one if needed
|
||||||
func GetCSRFToken(ctx sushi.Ctx) string {
|
func GetToken(ctx sushi.Ctx) string {
|
||||||
sess := GetCurrentSession(ctx)
|
sess := ctx.GetCurrentSession()
|
||||||
if sess == nil {
|
if sess == nil {
|
||||||
return ""
|
return ""
|
||||||
}
|
}
|
||||||
@ -55,16 +46,16 @@ func GetCSRFToken(ctx sushi.Ctx) string {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
return GenerateCSRFToken(ctx)
|
return GenerateToken(ctx)
|
||||||
}
|
}
|
||||||
|
|
||||||
// ValidateCSRFToken verifies a CSRF token against the stored session token
|
// ValidateToken verifies a CSRF token against the stored session token
|
||||||
func ValidateCSRFToken(ctx sushi.Ctx, submittedToken string) bool {
|
func ValidateToken(ctx sushi.Ctx, submittedToken string) bool {
|
||||||
if submittedToken == "" {
|
if submittedToken == "" {
|
||||||
return false
|
return false
|
||||||
}
|
}
|
||||||
|
|
||||||
sess := GetCurrentSession(ctx)
|
sess := ctx.GetCurrentSession()
|
||||||
if sess == nil {
|
if sess == nil {
|
||||||
return false
|
return false
|
||||||
}
|
}
|
||||||
@ -82,9 +73,9 @@ func ValidateCSRFToken(ctx sushi.Ctx, submittedToken string) bool {
|
|||||||
return subtle.ConstantTimeCompare([]byte(submittedToken), []byte(storedTokenStr)) == 1
|
return subtle.ConstantTimeCompare([]byte(submittedToken), []byte(storedTokenStr)) == 1
|
||||||
}
|
}
|
||||||
|
|
||||||
// CSRFHiddenField generates an HTML hidden input field with the CSRF token
|
// HiddenField generates an HTML hidden input field with the CSRF token
|
||||||
func CSRFHiddenField(ctx sushi.Ctx) string {
|
func HiddenField(ctx sushi.Ctx) string {
|
||||||
token := GetCSRFToken(ctx)
|
token := GetToken(ctx)
|
||||||
if token == "" {
|
if token == "" {
|
||||||
return ""
|
return ""
|
||||||
}
|
}
|
||||||
@ -95,7 +86,7 @@ func CSRFHiddenField(ctx sushi.Ctx) string {
|
|||||||
|
|
||||||
// CSRFTokenMeta generates HTML meta tag for JavaScript access to CSRF token
|
// CSRFTokenMeta generates HTML meta tag for JavaScript access to CSRF token
|
||||||
func CSRFTokenMeta(ctx sushi.Ctx) string {
|
func CSRFTokenMeta(ctx sushi.Ctx) string {
|
||||||
token := GetCSRFToken(ctx)
|
token := GetToken(ctx)
|
||||||
if token == "" {
|
if token == "" {
|
||||||
return ""
|
return ""
|
||||||
}
|
}
|
||||||
@ -114,7 +105,7 @@ func ValidateFormCSRFToken(ctx sushi.Ctx) bool {
|
|||||||
return false
|
return false
|
||||||
}
|
}
|
||||||
|
|
||||||
return ValidateCSRFToken(ctx, string(tokenBytes))
|
return ValidateToken(ctx, string(tokenBytes))
|
||||||
}
|
}
|
||||||
|
|
||||||
// Middleware returns middleware that automatically validates CSRF tokens
|
// Middleware returns middleware that automatically validates CSRF tokens
|
||||||
@ -124,7 +115,7 @@ func Middleware() sushi.Middleware {
|
|||||||
|
|
||||||
if method == "POST" || method == "PUT" || method == "PATCH" || method == "DELETE" {
|
if method == "POST" || method == "PUT" || method == "PATCH" || method == "DELETE" {
|
||||||
if !ValidateFormCSRFToken(ctx) {
|
if !ValidateFormCSRFToken(ctx) {
|
||||||
GenerateCSRFToken(ctx)
|
GenerateToken(ctx)
|
||||||
currentPath := string(ctx.Path())
|
currentPath := string(ctx.Path())
|
||||||
ctx.Redirect(currentPath, 302)
|
ctx.Redirect(currentPath, 302)
|
||||||
return
|
return
|
||||||
|
@ -1,4 +1,4 @@
|
|||||||
package session
|
package sushi
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"crypto/rand"
|
"crypto/rand"
|
||||||
@ -7,8 +7,6 @@ import (
|
|||||||
"os"
|
"os"
|
||||||
"sync"
|
"sync"
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
sushi "git.sharkk.net/Sharkk/Sushi"
|
|
||||||
)
|
)
|
||||||
|
|
||||||
const (
|
const (
|
||||||
@ -132,7 +130,7 @@ func (s *Session) SetUserID(userID int) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
// GetCurrentSession retrieves the session from context
|
// GetCurrentSession retrieves the session from context
|
||||||
func GetCurrentSession(ctx sushi.Ctx) *Session {
|
func GetCurrentSession(ctx Ctx) *Session {
|
||||||
if sess, ok := ctx.UserValue(SessionCtxKey).(*Session); ok {
|
if sess, ok := ctx.UserValue(SessionCtxKey).(*Session); ok {
|
||||||
return sess
|
return sess
|
||||||
}
|
}
|
||||||
@ -267,14 +265,22 @@ func SaveSessions() error {
|
|||||||
return sessionManager.Save()
|
return sessionManager.Save()
|
||||||
}
|
}
|
||||||
|
|
||||||
func SetSessionCookie(ctx sushi.Ctx, sessionID string) {
|
func SetSessionCookie(ctx Ctx, sessionID string) {
|
||||||
sushi.SetSecureCookie(ctx, sushi.CookieOptions{
|
SetSecureCookie(ctx, CookieOptions{
|
||||||
Name: SessionCookieName,
|
Name: SessionCookieName,
|
||||||
Value: sessionID,
|
Value: sessionID,
|
||||||
Path: "/",
|
Path: "/",
|
||||||
Expires: time.Now().Add(24 * time.Hour),
|
Expires: time.Now().Add(24 * time.Hour),
|
||||||
HTTPOnly: true,
|
HTTPOnly: true,
|
||||||
Secure: sushi.IsHTTPS(ctx),
|
Secure: IsHTTPS(ctx),
|
||||||
SameSite: "lax",
|
SameSite: "lax",
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// GetCurrentSession retrieves the session from context
|
||||||
|
func (ctx Ctx) GetCurrentSession() *Session {
|
||||||
|
if sess, ok := ctx.UserValue(SessionCtxKey).(*Session); ok {
|
||||||
|
return sess
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}
|
@ -5,24 +5,24 @@ import sushi "git.sharkk.net/Sharkk/Sushi"
|
|||||||
// Middleware provides session handling
|
// Middleware provides session handling
|
||||||
func Middleware() sushi.Middleware {
|
func Middleware() sushi.Middleware {
|
||||||
return func(ctx sushi.Ctx, params []any, next func()) {
|
return func(ctx sushi.Ctx, params []any, next func()) {
|
||||||
sessionID := sushi.GetCookie(ctx, SessionCookieName)
|
sessionID := sushi.GetCookie(ctx, sushi.SessionCookieName)
|
||||||
var sess *Session
|
var sess *sushi.Session
|
||||||
|
|
||||||
if sessionID != "" {
|
if sessionID != "" {
|
||||||
if existingSess, exists := GetSession(sessionID); exists {
|
if existingSess, exists := sushi.GetSession(sessionID); exists {
|
||||||
sess = existingSess
|
sess = existingSess
|
||||||
sess.Touch()
|
sess.Touch()
|
||||||
StoreSession(sess)
|
sushi.StoreSession(sess)
|
||||||
SetSessionCookie(ctx, sessionID)
|
sushi.SetSessionCookie(ctx, sessionID)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
if sess == nil {
|
if sess == nil {
|
||||||
sess = CreateSession(0) // Guest session
|
sess = sushi.CreateSession(0) // Guest session
|
||||||
SetSessionCookie(ctx, sess.ID)
|
sushi.SetSessionCookie(ctx, sess.ID)
|
||||||
}
|
}
|
||||||
|
|
||||||
ctx.SetUserValue(SessionCtxKey, sess)
|
ctx.SetUserValue(sushi.SessionCtxKey, sess)
|
||||||
next()
|
next()
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
2
types.go
2
types.go
@ -57,7 +57,7 @@ func (ctx Ctx) Redirect(url string, statusCode ...int) {
|
|||||||
if len(statusCode) > 0 {
|
if len(statusCode) > 0 {
|
||||||
code = statusCode[0]
|
code = statusCode[0]
|
||||||
}
|
}
|
||||||
ctx.Redirect(url, code)
|
ctx.RequestCtx.Redirect(url, code)
|
||||||
}
|
}
|
||||||
|
|
||||||
// SendFile serves a file
|
// SendFile serves a file
|
||||||
|
Loading…
x
Reference in New Issue
Block a user