Compare commits
No commits in common. "4a5f0debf68d615806ef1d6e146e785b4086632c" and "82ef4b31d486780f0cada81b5d83c0eb0814a9b0" have entirely different histories.
4a5f0debf6
...
82ef4b31d4
167
internal/csrf/csrf_test.go
Normal file
167
internal/csrf/csrf_test.go
Normal file
@ -0,0 +1,167 @@
|
|||||||
|
package csrf
|
||||||
|
|
||||||
|
import (
|
||||||
|
"testing"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
"dk/internal/session"
|
||||||
|
|
||||||
|
"github.com/valyala/fasthttp"
|
||||||
|
)
|
||||||
|
|
||||||
|
func TestGenerateToken(t *testing.T) {
|
||||||
|
sess := &session.Session{
|
||||||
|
ID: "test-session",
|
||||||
|
UserID: 1,
|
||||||
|
Username: "testuser",
|
||||||
|
Email: "test@example.com",
|
||||||
|
CreatedAt: time.Now(),
|
||||||
|
ExpiresAt: time.Now().Add(time.Hour),
|
||||||
|
LastSeen: time.Now(),
|
||||||
|
Data: make(map[string]any),
|
||||||
|
}
|
||||||
|
|
||||||
|
ctx := &fasthttp.RequestCtx{}
|
||||||
|
ctx.SetUserValue(SessionCtxKey, sess)
|
||||||
|
|
||||||
|
token := GenerateToken(ctx)
|
||||||
|
|
||||||
|
if token == "" {
|
||||||
|
t.Error("Expected non-empty token")
|
||||||
|
}
|
||||||
|
|
||||||
|
storedToken := GetStoredToken(sess)
|
||||||
|
if storedToken != token {
|
||||||
|
t.Errorf("Expected stored token %s, got %s", token, storedToken)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestValidateToken(t *testing.T) {
|
||||||
|
sess := &session.Session{
|
||||||
|
ID: "test-session",
|
||||||
|
UserID: 1,
|
||||||
|
Username: "testuser",
|
||||||
|
Email: "test@example.com",
|
||||||
|
Data: map[string]any{SessionKey: "test-token"},
|
||||||
|
}
|
||||||
|
|
||||||
|
ctx := &fasthttp.RequestCtx{}
|
||||||
|
ctx.SetUserValue(SessionCtxKey, sess)
|
||||||
|
|
||||||
|
if !ValidateToken(ctx, "test-token") {
|
||||||
|
t.Error("Expected valid token to pass validation")
|
||||||
|
}
|
||||||
|
|
||||||
|
if ValidateToken(ctx, "wrong-token") {
|
||||||
|
t.Error("Expected invalid token to fail validation")
|
||||||
|
}
|
||||||
|
|
||||||
|
if ValidateToken(ctx, "") {
|
||||||
|
t.Error("Expected empty token to fail validation")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestValidateTokenNoSession(t *testing.T) {
|
||||||
|
ctx := &fasthttp.RequestCtx{}
|
||||||
|
|
||||||
|
if ValidateToken(ctx, "any-token") {
|
||||||
|
t.Error("Expected validation to fail with no session")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestHiddenField(t *testing.T) {
|
||||||
|
sess := &session.Session{
|
||||||
|
ID: "test-session",
|
||||||
|
UserID: 1,
|
||||||
|
Username: "testuser",
|
||||||
|
Email: "test@example.com",
|
||||||
|
Data: map[string]any{SessionKey: "test-token"},
|
||||||
|
}
|
||||||
|
|
||||||
|
ctx := &fasthttp.RequestCtx{}
|
||||||
|
ctx.SetUserValue(SessionCtxKey, sess)
|
||||||
|
|
||||||
|
field := HiddenField(ctx)
|
||||||
|
expected := `<input type="hidden" name="_csrf_token" value="test-token">`
|
||||||
|
|
||||||
|
if field != expected {
|
||||||
|
t.Errorf("Expected %s, got %s", expected, field)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestHiddenFieldNoSession(t *testing.T) {
|
||||||
|
ctx := &fasthttp.RequestCtx{}
|
||||||
|
|
||||||
|
field := HiddenField(ctx)
|
||||||
|
if field == "" {
|
||||||
|
t.Error("Expected non-empty field for guest user with cookie-based token")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestTokenMeta(t *testing.T) {
|
||||||
|
sess := &session.Session{
|
||||||
|
ID: "test-session",
|
||||||
|
UserID: 1,
|
||||||
|
Username: "testuser",
|
||||||
|
Email: "test@example.com",
|
||||||
|
Data: map[string]any{SessionKey: "test-token"},
|
||||||
|
}
|
||||||
|
|
||||||
|
ctx := &fasthttp.RequestCtx{}
|
||||||
|
ctx.SetUserValue(SessionCtxKey, sess)
|
||||||
|
|
||||||
|
meta := TokenMeta(ctx)
|
||||||
|
expected := `<meta name="csrf-token" content="test-token">`
|
||||||
|
|
||||||
|
if meta != expected {
|
||||||
|
t.Errorf("Expected %s, got %s", expected, meta)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestStoreAndGetToken(t *testing.T) {
|
||||||
|
sess := &session.Session{
|
||||||
|
Data: make(map[string]any),
|
||||||
|
}
|
||||||
|
|
||||||
|
token := "test-token"
|
||||||
|
StoreToken(sess, token)
|
||||||
|
|
||||||
|
retrieved := GetStoredToken(sess)
|
||||||
|
if retrieved != token {
|
||||||
|
t.Errorf("Expected %s, got %s", token, retrieved)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestGetStoredTokenNoData(t *testing.T) {
|
||||||
|
sess := &session.Session{}
|
||||||
|
|
||||||
|
token := GetStoredToken(sess)
|
||||||
|
if token != "" {
|
||||||
|
t.Errorf("Expected empty token, got %s", token)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestValidateFormToken(t *testing.T) {
|
||||||
|
sess := &session.Session{
|
||||||
|
ID: "test-session",
|
||||||
|
UserID: 1,
|
||||||
|
Username: "testuser",
|
||||||
|
Email: "test@example.com",
|
||||||
|
Data: map[string]any{SessionKey: "test-token"},
|
||||||
|
}
|
||||||
|
|
||||||
|
ctx := &fasthttp.RequestCtx{}
|
||||||
|
ctx.SetUserValue(SessionCtxKey, sess)
|
||||||
|
|
||||||
|
ctx.PostArgs().Set(TokenFieldName, "test-token")
|
||||||
|
|
||||||
|
if !ValidateFormToken(ctx) {
|
||||||
|
t.Error("Expected form token validation to pass")
|
||||||
|
}
|
||||||
|
|
||||||
|
ctx.PostArgs().Set(TokenFieldName, "wrong-token")
|
||||||
|
|
||||||
|
if ValidateFormToken(ctx) {
|
||||||
|
t.Error("Expected form token validation to fail with wrong token")
|
||||||
|
}
|
||||||
|
}
|
@ -5,43 +5,29 @@ import (
|
|||||||
"dk/internal/models/users"
|
"dk/internal/models/users"
|
||||||
"dk/internal/router"
|
"dk/internal/router"
|
||||||
"dk/internal/session"
|
"dk/internal/session"
|
||||||
"fmt"
|
|
||||||
"time"
|
|
||||||
|
|
||||||
"github.com/valyala/fasthttp"
|
"github.com/valyala/fasthttp"
|
||||||
)
|
)
|
||||||
|
|
||||||
const SessionCookieName = "dk_session"
|
|
||||||
|
|
||||||
func Auth() router.Middleware {
|
func Auth() router.Middleware {
|
||||||
return func(next router.Handler) router.Handler {
|
return func(next router.Handler) router.Handler {
|
||||||
return func(ctx router.Ctx, params []string) {
|
return func(ctx router.Ctx, params []string) {
|
||||||
sessionID := cookies.GetCookie(ctx, SessionCookieName)
|
sessionID := cookies.GetCookie(ctx, session.SessionCookieName)
|
||||||
var sess *session.Session
|
|
||||||
|
|
||||||
if sessionID != "" {
|
if sessionID != "" {
|
||||||
if existingSess, exists := session.Get(sessionID); exists {
|
if sess, exists := session.Get(sessionID); exists {
|
||||||
sess = existingSess
|
session.Update(sessionID)
|
||||||
sess.Touch()
|
|
||||||
session.Store(sess)
|
|
||||||
|
|
||||||
if sess.UserID > 0 { // User session
|
user, err := users.Find(sess.UserID)
|
||||||
user, err := users.Find(sess.UserID)
|
if err == nil && user != nil {
|
||||||
if err == nil && user != nil {
|
ctx.SetUserValue("session", sess)
|
||||||
ctx.SetUserValue("user", user)
|
ctx.SetUserValue("user", user)
|
||||||
setSessionCookie(ctx, sessionID)
|
|
||||||
}
|
session.SetSessionCookie(ctx, sessionID)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// Create guest session if none exists
|
|
||||||
if sess == nil {
|
|
||||||
sess = session.Create(0) // Guest session
|
|
||||||
setSessionCookie(ctx, sess.ID)
|
|
||||||
}
|
|
||||||
|
|
||||||
ctx.SetUserValue("session", sess)
|
|
||||||
next(ctx, params)
|
next(ctx, params)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@ -78,7 +64,6 @@ func RequireGuest(paths ...string) router.Middleware {
|
|||||||
return func(next router.Handler) router.Handler {
|
return func(next router.Handler) router.Handler {
|
||||||
return func(ctx router.Ctx, params []string) {
|
return func(ctx router.Ctx, params []string) {
|
||||||
if IsAuthenticated(ctx) {
|
if IsAuthenticated(ctx) {
|
||||||
fmt.Println("RequireGuest: user is authenticated")
|
|
||||||
ctx.Redirect(redirect, fasthttp.StatusFound)
|
ctx.Redirect(redirect, fasthttp.StatusFound)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
@ -107,38 +92,21 @@ func GetCurrentSession(ctx router.Ctx) *session.Session {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func Login(ctx router.Ctx, user *users.User) {
|
func Login(ctx router.Ctx, user *users.User) {
|
||||||
sess := session.Create(user.ID)
|
sess := session.Create(user.ID, user.Username, user.Email)
|
||||||
setSessionCookie(ctx, sess.ID)
|
session.SetSessionCookie(ctx, sess.ID)
|
||||||
|
|
||||||
ctx.SetUserValue("session", sess)
|
ctx.SetUserValue("session", sess)
|
||||||
ctx.SetUserValue("user", user)
|
ctx.SetUserValue("user", user)
|
||||||
}
|
}
|
||||||
|
|
||||||
func Logout(ctx router.Ctx) {
|
func Logout(ctx router.Ctx) {
|
||||||
sessionID := cookies.GetCookie(ctx, SessionCookieName)
|
sessionID := cookies.GetCookie(ctx, session.SessionCookieName)
|
||||||
if sessionID != "" {
|
if sessionID != "" {
|
||||||
session.Delete(sessionID)
|
session.Delete(sessionID)
|
||||||
}
|
}
|
||||||
|
|
||||||
deleteSessionCookie(ctx)
|
session.DeleteSessionCookie(ctx)
|
||||||
|
|
||||||
ctx.SetUserValue("session", nil)
|
ctx.SetUserValue("session", nil)
|
||||||
ctx.SetUserValue("user", nil)
|
ctx.SetUserValue("user", nil)
|
||||||
}
|
}
|
||||||
|
|
||||||
// Helper functions for session cookies
|
|
||||||
func setSessionCookie(ctx router.Ctx, sessionID string) {
|
|
||||||
cookies.SetSecureCookie(ctx, cookies.CookieOptions{
|
|
||||||
Name: SessionCookieName,
|
|
||||||
Value: sessionID,
|
|
||||||
Path: "/",
|
|
||||||
Expires: time.Now().Add(24 * time.Hour),
|
|
||||||
HTTPOnly: true,
|
|
||||||
Secure: cookies.IsHTTPS(ctx),
|
|
||||||
SameSite: "lax",
|
|
||||||
})
|
|
||||||
}
|
|
||||||
|
|
||||||
func deleteSessionCookie(ctx router.Ctx) {
|
|
||||||
cookies.DeleteCookie(ctx, SessionCookieName)
|
|
||||||
}
|
|
||||||
|
@ -87,19 +87,19 @@ func New() *User {
|
|||||||
ClassID: 1,
|
ClassID: 1,
|
||||||
Currently: "In Town",
|
Currently: "In Town",
|
||||||
Fighting: 0,
|
Fighting: 0,
|
||||||
HP: 10,
|
HP: 15,
|
||||||
MP: 10,
|
MP: 0,
|
||||||
TP: 10,
|
TP: 10,
|
||||||
MaxHP: 10,
|
MaxHP: 15,
|
||||||
MaxMP: 10,
|
MaxMP: 0,
|
||||||
MaxTP: 10,
|
MaxTP: 10,
|
||||||
Level: 1,
|
Level: 1,
|
||||||
Gold: 100,
|
Gold: 100,
|
||||||
Exp: 0,
|
Exp: 0,
|
||||||
Strength: 0,
|
Strength: 5,
|
||||||
Dexterity: 0,
|
Dexterity: 5,
|
||||||
Attack: 0,
|
Attack: 5,
|
||||||
Defense: 0,
|
Defense: 5,
|
||||||
Spells: "",
|
Spells: "",
|
||||||
Towns: "",
|
Towns: "",
|
||||||
}
|
}
|
||||||
|
@ -18,39 +18,36 @@ import (
|
|||||||
|
|
||||||
// RegisterAuthRoutes sets up authentication routes
|
// RegisterAuthRoutes sets up authentication routes
|
||||||
func RegisterAuthRoutes(r *router.Router) {
|
func RegisterAuthRoutes(r *router.Router) {
|
||||||
guests := r.Group("")
|
// Guest routes
|
||||||
guests.Use(middleware.RequireGuest())
|
guestGroup := r.Group("")
|
||||||
|
guestGroup.Use(middleware.RequireGuest())
|
||||||
|
|
||||||
guests.Get("/login", showLogin)
|
guestGroup.Get("/login", showLogin)
|
||||||
guests.Post("/login", processLogin)
|
guestGroup.Post("/login", processLogin)
|
||||||
guests.Get("/register", showRegister)
|
guestGroup.Get("/register", showRegister)
|
||||||
guests.Post("/register", processRegister)
|
guestGroup.Post("/register", processRegister)
|
||||||
|
|
||||||
authed := r.Group("")
|
// Authenticated routes
|
||||||
authed.Use(middleware.RequireAuth())
|
authGroup := r.Group("")
|
||||||
|
authGroup.Use(middleware.RequireAuth())
|
||||||
|
|
||||||
authed.Post("/logout", processLogout)
|
authGroup.Post("/logout", processLogout)
|
||||||
}
|
}
|
||||||
|
|
||||||
// showLogin displays the login form
|
// showLogin displays the login form
|
||||||
func showLogin(ctx router.Ctx, _ []string) {
|
func showLogin(ctx router.Ctx, _ []string) {
|
||||||
sess := ctx.UserValue("session").(*session.Session)
|
// Get flash message if any
|
||||||
var errorHTML string
|
var errorHTML string
|
||||||
var id string
|
if flash := session.GetFlashMessage(ctx); flash != nil {
|
||||||
|
errorHTML = fmt.Sprintf(`<div style="color: red; margin-bottom: 1rem;">%s</div>`, flash.Message)
|
||||||
if flash, exists := sess.GetFlash("error"); exists {
|
|
||||||
if msg, ok := flash.(string); ok {
|
|
||||||
errorHTML = fmt.Sprintf(`<div style="color: red; margin-bottom: 1rem;">%s</div>`, msg)
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
|
||||||
if formData, exists := sess.Get("form_data"); exists {
|
// Get form data if any (for preserving email/username on error)
|
||||||
if data, ok := formData.(map[string]string); ok {
|
formData := session.GetFormData(ctx)
|
||||||
id = data["id"]
|
id := ""
|
||||||
}
|
if formData != nil {
|
||||||
|
id = formData["id"]
|
||||||
}
|
}
|
||||||
sess.Delete("form_data")
|
|
||||||
session.Store(sess)
|
|
||||||
|
|
||||||
components.RenderPage(ctx, "Log In", "auth/login.html", map[string]any{
|
components.RenderPage(ctx, "Log In", "auth/login.html", map[string]any{
|
||||||
"error_message": errorHTML,
|
"error_message": errorHTML,
|
||||||
@ -70,30 +67,26 @@ func processLogin(ctx router.Ctx, _ []string) {
|
|||||||
userPassword := string(ctx.PostArgs().Peek("password"))
|
userPassword := string(ctx.PostArgs().Peek("password"))
|
||||||
|
|
||||||
if email == "" || userPassword == "" {
|
if email == "" || userPassword == "" {
|
||||||
setFlashAndFormData(ctx, "Email and password are required", map[string]string{"id": email})
|
session.SetFlashMessage(ctx, "error", "Email and password are required")
|
||||||
|
session.SetFormData(ctx, map[string]string{"id": email})
|
||||||
ctx.Redirect("/login", fasthttp.StatusFound)
|
ctx.Redirect("/login", fasthttp.StatusFound)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
user, err := auth.Authenticate(email, userPassword)
|
user, err := auth.Authenticate(email, userPassword)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
setFlashAndFormData(ctx, "Invalid email or password", map[string]string{"id": email})
|
session.SetFlashMessage(ctx, "error", "Invalid email or password")
|
||||||
|
session.SetFormData(ctx, map[string]string{"id": email})
|
||||||
ctx.Redirect("/login", fasthttp.StatusFound)
|
ctx.Redirect("/login", fasthttp.StatusFound)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
middleware.Login(ctx, user)
|
middleware.Login(ctx, user)
|
||||||
|
|
||||||
// Set success message
|
|
||||||
if sess := ctx.UserValue("session").(*session.Session); sess != nil {
|
|
||||||
sess.SetFlash("success", fmt.Sprintf("Welcome back, %s!", user.Username))
|
|
||||||
session.Store(sess)
|
|
||||||
}
|
|
||||||
|
|
||||||
// Transfer CSRF token from cookie to session for authenticated user
|
// Transfer CSRF token from cookie to session for authenticated user
|
||||||
if cookieToken := csrf.GetTokenFromCookie(ctx); cookieToken != "" {
|
if cookieToken := csrf.GetTokenFromCookie(ctx); cookieToken != "" {
|
||||||
if sess := ctx.UserValue("session").(*session.Session); sess != nil {
|
if session := csrf.GetCurrentSession(ctx); session != nil {
|
||||||
csrf.StoreToken(sess, cookieToken)
|
csrf.StoreToken(session, cookieToken)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -102,24 +95,20 @@ func processLogin(ctx router.Ctx, _ []string) {
|
|||||||
|
|
||||||
// showRegister displays the registration form
|
// showRegister displays the registration form
|
||||||
func showRegister(ctx router.Ctx, _ []string) {
|
func showRegister(ctx router.Ctx, _ []string) {
|
||||||
sess := ctx.UserValue("session").(*session.Session)
|
// Get flash message if any
|
||||||
var errorHTML string
|
var errorHTML string
|
||||||
var username, email string
|
if flash := session.GetFlashMessage(ctx); flash != nil {
|
||||||
|
errorHTML = fmt.Sprintf(`<div style="color: red; margin-bottom: 1rem;">%s</div>`, flash.Message)
|
||||||
if flash, exists := sess.GetFlash("error"); exists {
|
|
||||||
if msg, ok := flash.(string); ok {
|
|
||||||
errorHTML = fmt.Sprintf(`<div style="color: red; margin-bottom: 1rem;">%s</div>`, msg)
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
|
||||||
if formData, exists := sess.Get("form_data"); exists {
|
// Get form data if any (for preserving values on error)
|
||||||
if data, ok := formData.(map[string]string); ok {
|
formData := session.GetFormData(ctx)
|
||||||
username = data["username"]
|
username := ""
|
||||||
email = data["email"]
|
email := ""
|
||||||
}
|
if formData != nil {
|
||||||
|
username = formData["username"]
|
||||||
|
email = formData["email"]
|
||||||
}
|
}
|
||||||
sess.Delete("form_data")
|
|
||||||
session.Store(sess)
|
|
||||||
|
|
||||||
components.RenderPage(ctx, "Register", "auth/register.html", map[string]any{
|
components.RenderPage(ctx, "Register", "auth/register.html", map[string]any{
|
||||||
"error_message": errorHTML,
|
"error_message": errorHTML,
|
||||||
@ -141,25 +130,32 @@ func processRegister(ctx router.Ctx, _ []string) {
|
|||||||
userPassword := string(ctx.PostArgs().Peek("password"))
|
userPassword := string(ctx.PostArgs().Peek("password"))
|
||||||
confirmPassword := string(ctx.PostArgs().Peek("confirm_password"))
|
confirmPassword := string(ctx.PostArgs().Peek("confirm_password"))
|
||||||
|
|
||||||
formData := map[string]string{
|
|
||||||
"username": username,
|
|
||||||
"email": email,
|
|
||||||
}
|
|
||||||
|
|
||||||
if err := validateRegistration(username, email, userPassword, confirmPassword); err != nil {
|
if err := validateRegistration(username, email, userPassword, confirmPassword); err != nil {
|
||||||
setFlashAndFormData(ctx, err.Error(), formData)
|
session.SetFlashMessage(ctx, "error", err.Error())
|
||||||
|
session.SetFormData(ctx, map[string]string{
|
||||||
|
"username": username,
|
||||||
|
"email": email,
|
||||||
|
})
|
||||||
ctx.Redirect("/register", fasthttp.StatusFound)
|
ctx.Redirect("/register", fasthttp.StatusFound)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
if _, err := users.ByUsername(username); err == nil {
|
if _, err := users.ByUsername(username); err == nil {
|
||||||
setFlashAndFormData(ctx, "Username already exists", formData)
|
session.SetFlashMessage(ctx, "error", "Username already exists")
|
||||||
|
session.SetFormData(ctx, map[string]string{
|
||||||
|
"username": username,
|
||||||
|
"email": email,
|
||||||
|
})
|
||||||
ctx.Redirect("/register", fasthttp.StatusFound)
|
ctx.Redirect("/register", fasthttp.StatusFound)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
if _, err := users.ByEmail(email); err == nil {
|
if _, err := users.ByEmail(email); err == nil {
|
||||||
setFlashAndFormData(ctx, "Email already registered", formData)
|
session.SetFlashMessage(ctx, "error", "Email already registered")
|
||||||
|
session.SetFormData(ctx, map[string]string{
|
||||||
|
"username": username,
|
||||||
|
"email": email,
|
||||||
|
})
|
||||||
ctx.Redirect("/register", fasthttp.StatusFound)
|
ctx.Redirect("/register", fasthttp.StatusFound)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
@ -172,7 +168,11 @@ func processRegister(ctx router.Ctx, _ []string) {
|
|||||||
user.Auth = 1
|
user.Auth = 1
|
||||||
|
|
||||||
if err := user.Insert(); err != nil {
|
if err := user.Insert(); err != nil {
|
||||||
setFlashAndFormData(ctx, "Failed to create account", formData)
|
session.SetFlashMessage(ctx, "error", "Failed to create account")
|
||||||
|
session.SetFormData(ctx, map[string]string{
|
||||||
|
"username": username,
|
||||||
|
"email": email,
|
||||||
|
})
|
||||||
ctx.Redirect("/register", fasthttp.StatusFound)
|
ctx.Redirect("/register", fasthttp.StatusFound)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
@ -180,16 +180,10 @@ func processRegister(ctx router.Ctx, _ []string) {
|
|||||||
// Auto-login after registration
|
// Auto-login after registration
|
||||||
middleware.Login(ctx, user)
|
middleware.Login(ctx, user)
|
||||||
|
|
||||||
// Set success message
|
|
||||||
if sess := ctx.UserValue("session").(*session.Session); sess != nil {
|
|
||||||
sess.SetFlash("success", fmt.Sprintf("Greetings, %s!", user.Username))
|
|
||||||
session.Store(sess)
|
|
||||||
}
|
|
||||||
|
|
||||||
// Transfer CSRF token from cookie to session for authenticated user
|
// Transfer CSRF token from cookie to session for authenticated user
|
||||||
if cookieToken := csrf.GetTokenFromCookie(ctx); cookieToken != "" {
|
if cookieToken := csrf.GetTokenFromCookie(ctx); cookieToken != "" {
|
||||||
if sess := ctx.UserValue("session").(*session.Session); sess != nil {
|
if session := csrf.GetCurrentSession(ctx); session != nil {
|
||||||
csrf.StoreToken(sess, cookieToken)
|
csrf.StoreToken(session, cookieToken)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -235,10 +229,3 @@ func validateRegistration(username, email, password, confirmPassword string) err
|
|||||||
}
|
}
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func setFlashAndFormData(ctx router.Ctx, message string, formData map[string]string) {
|
|
||||||
sess := ctx.UserValue("session").(*session.Session)
|
|
||||||
sess.SetFlash("error", message)
|
|
||||||
sess.Set("form_data", formData)
|
|
||||||
session.Store(sess)
|
|
||||||
}
|
|
||||||
|
@ -49,13 +49,9 @@ func showTown(ctx router.Ctx, _ []string) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func showInn(ctx router.Ctx, _ []string) {
|
func showInn(ctx router.Ctx, _ []string) {
|
||||||
sess := ctx.UserValue("session").(*session.Session)
|
|
||||||
var errorHTML string
|
var errorHTML string
|
||||||
|
if flash := session.GetFlashMessage(ctx); flash != nil {
|
||||||
if flash, exists := sess.GetFlash("error"); exists {
|
errorHTML = `<div style="color: red; margin-bottom: 1rem;">` + flash.Message + "</div>"
|
||||||
if msg, ok := flash.(string); ok {
|
|
||||||
errorHTML = `<div style="color: red; margin-bottom: 1rem;">` + msg + "</div>"
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
|
||||||
town := ctx.UserValue("town").(*towns.Town)
|
town := ctx.UserValue("town").(*towns.Town)
|
||||||
@ -68,12 +64,11 @@ func showInn(ctx router.Ctx, _ []string) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func rest(ctx router.Ctx, _ []string) {
|
func rest(ctx router.Ctx, _ []string) {
|
||||||
sess := ctx.UserValue("session").(*session.Session)
|
|
||||||
town := ctx.UserValue("town").(*towns.Town)
|
town := ctx.UserValue("town").(*towns.Town)
|
||||||
user := ctx.UserValue("user").(*users.User)
|
user := ctx.UserValue("user").(*users.User)
|
||||||
|
|
||||||
if user.Gold < town.InnCost {
|
if user.Gold < town.InnCost {
|
||||||
sess.SetFlash("error", "You can't afford to stay here tonight.")
|
session.SetFlashMessage(ctx, "error", "You can't afford to stay here tonight.")
|
||||||
ctx.Redirect("/town/inn", 303)
|
ctx.Redirect("/town/inn", 303)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
@ -88,13 +83,9 @@ func rest(ctx router.Ctx, _ []string) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func showShop(ctx router.Ctx, _ []string) {
|
func showShop(ctx router.Ctx, _ []string) {
|
||||||
sess := ctx.UserValue("session").(*session.Session)
|
|
||||||
var errorHTML string
|
var errorHTML string
|
||||||
|
if flash := session.GetFlashMessage(ctx); flash != nil {
|
||||||
if flash, exists := sess.GetFlash("error"); exists {
|
errorHTML = `<div style="color: red; margin-bottom: 1rem;">` + flash.Message + "</div>"
|
||||||
if msg, ok := flash.(string); ok {
|
|
||||||
errorHTML = `<div style="color: red; margin-bottom: 1rem;">` + msg + "</div>"
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
|
||||||
town := ctx.UserValue("town").(*towns.Town)
|
town := ctx.UserValue("town").(*towns.Town)
|
||||||
@ -118,32 +109,30 @@ func showShop(ctx router.Ctx, _ []string) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func buyItem(ctx router.Ctx, params []string) {
|
func buyItem(ctx router.Ctx, params []string) {
|
||||||
sess := ctx.UserValue("session").(*session.Session)
|
|
||||||
|
|
||||||
id, err := strconv.Atoi(params[0])
|
id, err := strconv.Atoi(params[0])
|
||||||
if err != nil {
|
if err != nil {
|
||||||
sess.SetFlash("error", "Error purchasing item; "+err.Error())
|
session.SetFlashMessage(ctx, "error", "Error purchasing item; "+err.Error())
|
||||||
ctx.Redirect("/town/shop", 302)
|
ctx.Redirect("/town/shop", 302)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
town := ctx.UserValue("town").(*towns.Town)
|
town := ctx.UserValue("town").(*towns.Town)
|
||||||
if !slices.Contains(town.GetShopItems(), id) {
|
if !slices.Contains(town.GetShopItems(), id) {
|
||||||
sess.SetFlash("error", "The item doesn't exist in this shop.")
|
session.SetFlashMessage(ctx, "error", "The item doesn't exist in this shop.")
|
||||||
ctx.Redirect("/town/shop", 302)
|
ctx.Redirect("/town/shop", 302)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
item, err := items.Find(id)
|
item, err := items.Find(id)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
sess.SetFlash("error", "Error purchasing item; "+err.Error())
|
session.SetFlashMessage(ctx, "error", "Error purchasing item; "+err.Error())
|
||||||
ctx.Redirect("/town/shop", 302)
|
ctx.Redirect("/town/shop", 302)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
user := ctx.UserValue("user").(*users.User)
|
user := ctx.UserValue("user").(*users.User)
|
||||||
if user.Gold < item.Value {
|
if user.Gold < item.Value {
|
||||||
sess.SetFlash("error", "You don't have enough gold to buy "+item.Name)
|
session.SetFlashMessage(ctx, "error", "You don't have enough gold to buy "+item.Name)
|
||||||
ctx.Redirect("/town/shop", 302)
|
ctx.Redirect("/town/shop", 302)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
@ -156,13 +145,9 @@ func buyItem(ctx router.Ctx, params []string) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func showMaps(ctx router.Ctx, _ []string) {
|
func showMaps(ctx router.Ctx, _ []string) {
|
||||||
sess := ctx.UserValue("session").(*session.Session)
|
|
||||||
var errorHTML string
|
var errorHTML string
|
||||||
|
if flash := session.GetFlashMessage(ctx); flash != nil {
|
||||||
if flash, exists := sess.GetFlash("error"); exists {
|
errorHTML = `<div style="color: red; margin-bottom: 1rem;">` + flash.Message + "</div>"
|
||||||
if msg, ok := flash.(string); ok {
|
|
||||||
errorHTML = `<div style="color: red; margin-bottom: 1rem;">` + msg + "</div>"
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
|
||||||
town := ctx.UserValue("town").(*towns.Town)
|
town := ctx.UserValue("town").(*towns.Town)
|
||||||
@ -201,25 +186,23 @@ func showMaps(ctx router.Ctx, _ []string) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func buyMap(ctx router.Ctx, params []string) {
|
func buyMap(ctx router.Ctx, params []string) {
|
||||||
sess := ctx.UserValue("session").(*session.Session)
|
|
||||||
|
|
||||||
id, err := strconv.Atoi(params[0])
|
id, err := strconv.Atoi(params[0])
|
||||||
if err != nil {
|
if err != nil {
|
||||||
sess.SetFlash("error", "Error purchasing map; "+err.Error())
|
session.SetFlashMessage(ctx, "error", "Error purchasing map; "+err.Error())
|
||||||
ctx.Redirect("/town/maps", 302)
|
ctx.Redirect("/town/maps", 302)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
mapped, err := towns.Find(id)
|
mapped, err := towns.Find(id)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
sess.SetFlash("error", "Error purchasing map; "+err.Error())
|
session.SetFlashMessage(ctx, "error", "Error purchasing map; "+err.Error())
|
||||||
ctx.Redirect("/town/maps", 302)
|
ctx.Redirect("/town/maps", 302)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
user := ctx.UserValue("user").(*users.User)
|
user := ctx.UserValue("user").(*users.User)
|
||||||
if user.Gold < mapped.MapCost {
|
if user.Gold < mapped.MapCost {
|
||||||
sess.SetFlash("error", "You don't have enough gold to buy the map to "+mapped.Name)
|
session.SetFlashMessage(ctx, "error", "You don't have enough gold to buy the map to "+mapped.Name)
|
||||||
ctx.Redirect("/town/maps", 302)
|
ctx.Redirect("/town/maps", 302)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
56
internal/session/flash.go
Normal file
56
internal/session/flash.go
Normal file
@ -0,0 +1,56 @@
|
|||||||
|
package session
|
||||||
|
|
||||||
|
type FlashMessage struct {
|
||||||
|
Type string `json:"type"`
|
||||||
|
Message string `json:"message"`
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *Session) SetFlash(key string, value any) {
|
||||||
|
if s.Data == nil {
|
||||||
|
s.Data = make(map[string]any)
|
||||||
|
}
|
||||||
|
|
||||||
|
flashData, ok := s.Data["_flash"].(map[string]any)
|
||||||
|
if !ok {
|
||||||
|
flashData = make(map[string]any)
|
||||||
|
}
|
||||||
|
flashData[key] = value
|
||||||
|
s.Data["_flash"] = flashData
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *Session) GetFlash(key string) (any, bool) {
|
||||||
|
if s.Data == nil {
|
||||||
|
return nil, false
|
||||||
|
}
|
||||||
|
|
||||||
|
flashData, ok := s.Data["_flash"].(map[string]any)
|
||||||
|
if !ok {
|
||||||
|
return nil, false
|
||||||
|
}
|
||||||
|
|
||||||
|
value, exists := flashData[key]
|
||||||
|
if exists {
|
||||||
|
delete(flashData, key)
|
||||||
|
if len(flashData) == 0 {
|
||||||
|
delete(s.Data, "_flash")
|
||||||
|
} else {
|
||||||
|
s.Data["_flash"] = flashData
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return value, exists
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *Session) GetAllFlash() map[string]any {
|
||||||
|
if s.Data == nil {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
flashData, ok := s.Data["_flash"].(map[string]any)
|
||||||
|
if !ok {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
delete(s.Data, "_flash")
|
||||||
|
return flashData
|
||||||
|
}
|
@ -1,35 +1,28 @@
|
|||||||
package session
|
package session
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"encoding/json"
|
"dk/internal/cookies"
|
||||||
"os"
|
"dk/internal/router"
|
||||||
"sync"
|
"time"
|
||||||
)
|
)
|
||||||
|
|
||||||
// SessionManager handles session storage and persistence
|
const SessionCookieName = "dk_session"
|
||||||
type SessionManager struct {
|
|
||||||
mu sync.RWMutex
|
|
||||||
sessions map[string]*Session
|
|
||||||
filePath string
|
|
||||||
}
|
|
||||||
|
|
||||||
var Manager *SessionManager
|
var Manager *SessionManager
|
||||||
|
|
||||||
// Init initializes the global session manager
|
type SessionManager struct {
|
||||||
func Init(filePath string) {
|
store *Store
|
||||||
|
}
|
||||||
|
|
||||||
|
func Init(sessionsFilePath string) {
|
||||||
if Manager != nil {
|
if Manager != nil {
|
||||||
panic("session manager already initialized")
|
panic("session manager already initialized")
|
||||||
}
|
}
|
||||||
|
|
||||||
Manager = &SessionManager{
|
Manager = &SessionManager{
|
||||||
sessions: make(map[string]*Session),
|
store: NewStore(sessionsFilePath),
|
||||||
filePath: filePath,
|
|
||||||
}
|
}
|
||||||
|
|
||||||
Manager.load()
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// GetManager returns the global session manager
|
|
||||||
func GetManager() *SessionManager {
|
func GetManager() *SessionManager {
|
||||||
if Manager == nil {
|
if Manager == nil {
|
||||||
panic("session manager not initialized")
|
panic("session manager not initialized")
|
||||||
@ -37,116 +30,200 @@ func GetManager() *SessionManager {
|
|||||||
return Manager
|
return Manager
|
||||||
}
|
}
|
||||||
|
|
||||||
// Create creates and stores a new session
|
func (sm *SessionManager) Create(userID int, username, email string) *Session {
|
||||||
func (sm *SessionManager) Create(userID int) *Session {
|
sess := New(userID, username, email)
|
||||||
sess := New(userID)
|
sm.store.Save(sess)
|
||||||
sm.mu.Lock()
|
|
||||||
sm.sessions[sess.ID] = sess
|
|
||||||
sm.mu.Unlock()
|
|
||||||
return sess
|
return sess
|
||||||
}
|
}
|
||||||
|
|
||||||
// Get retrieves a session by ID
|
|
||||||
func (sm *SessionManager) Get(sessionID string) (*Session, bool) {
|
func (sm *SessionManager) Get(sessionID string) (*Session, bool) {
|
||||||
sm.mu.RLock()
|
return sm.store.Get(sessionID)
|
||||||
sess, exists := sm.sessions[sessionID]
|
}
|
||||||
sm.mu.RUnlock()
|
|
||||||
|
|
||||||
if !exists || sess.IsExpired() {
|
func (sm *SessionManager) GetFromContext(ctx router.Ctx) (*Session, bool) {
|
||||||
if exists {
|
sessionID := cookies.GetCookie(ctx, SessionCookieName)
|
||||||
sm.Delete(sessionID)
|
if sessionID == "" {
|
||||||
}
|
|
||||||
return nil, false
|
return nil, false
|
||||||
}
|
}
|
||||||
|
return sm.Get(sessionID)
|
||||||
return sess, true
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// Store saves a session in memory (updates existing or creates new)
|
func (sm *SessionManager) Update(sessionID string) bool {
|
||||||
func (sm *SessionManager) Store(sess *Session) {
|
sess, exists := sm.store.Get(sessionID)
|
||||||
sm.mu.Lock()
|
if !exists {
|
||||||
sm.sessions[sess.ID] = sess
|
return false
|
||||||
sm.mu.Unlock()
|
}
|
||||||
|
|
||||||
|
sess.Touch()
|
||||||
|
sm.store.Save(sess)
|
||||||
|
return true
|
||||||
}
|
}
|
||||||
|
|
||||||
// Delete removes a session
|
|
||||||
func (sm *SessionManager) Delete(sessionID string) {
|
func (sm *SessionManager) Delete(sessionID string) {
|
||||||
sm.mu.Lock()
|
sm.store.Delete(sessionID)
|
||||||
delete(sm.sessions, sessionID)
|
|
||||||
sm.mu.Unlock()
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// Cleanup removes expired sessions
|
func (sm *SessionManager) SetSessionCookie(ctx router.Ctx, sessionID string) {
|
||||||
func (sm *SessionManager) Cleanup() {
|
cookies.SetSecureCookie(ctx, cookies.CookieOptions{
|
||||||
sm.mu.Lock()
|
Name: SessionCookieName,
|
||||||
for id, sess := range sm.sessions {
|
Value: sessionID,
|
||||||
if sess.IsExpired() {
|
Path: "/",
|
||||||
delete(sm.sessions, id)
|
Expires: time.Now().Add(DefaultExpiration),
|
||||||
}
|
HTTPOnly: true,
|
||||||
}
|
Secure: cookies.IsHTTPS(ctx),
|
||||||
sm.mu.Unlock()
|
SameSite: "lax",
|
||||||
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
// Stats returns session statistics
|
func (sm *SessionManager) DeleteSessionCookie(ctx router.Ctx) {
|
||||||
func (sm *SessionManager) Stats() (total, active int) {
|
cookies.DeleteCookie(ctx, SessionCookieName)
|
||||||
sm.mu.RLock()
|
|
||||||
defer sm.mu.RUnlock()
|
|
||||||
|
|
||||||
total = len(sm.sessions)
|
|
||||||
for _, sess := range sm.sessions {
|
|
||||||
if !sess.IsExpired() {
|
|
||||||
active++
|
|
||||||
}
|
|
||||||
}
|
|
||||||
return
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// load reads sessions from the JSON file
|
func (sm *SessionManager) SetFlashMessage(ctx router.Ctx, msgType, message string) bool {
|
||||||
func (sm *SessionManager) load() {
|
sess, exists := sm.GetFromContext(ctx)
|
||||||
if sm.filePath == "" {
|
if !exists {
|
||||||
return
|
return false
|
||||||
}
|
}
|
||||||
|
|
||||||
data, err := os.ReadFile(sm.filePath)
|
sess.SetFlash("message", FlashMessage{
|
||||||
if err != nil {
|
Type: msgType,
|
||||||
return // File doesn't exist or can't be read
|
Message: message,
|
||||||
}
|
})
|
||||||
|
sm.store.Save(sess)
|
||||||
var sessions map[string]*Session
|
return true
|
||||||
if err := json.Unmarshal(data, &sessions); err != nil {
|
|
||||||
return // Invalid JSON
|
|
||||||
}
|
|
||||||
|
|
||||||
sm.mu.Lock()
|
|
||||||
for id, sess := range sessions {
|
|
||||||
if sess != nil && !sess.IsExpired() {
|
|
||||||
sess.ID = id // Ensure ID consistency
|
|
||||||
sm.sessions[id] = sess
|
|
||||||
}
|
|
||||||
}
|
|
||||||
sm.mu.Unlock()
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// Save writes sessions to the JSON file
|
func (sm *SessionManager) GetFlashMessage(ctx router.Ctx) *FlashMessage {
|
||||||
func (sm *SessionManager) Save() error {
|
sess, exists := sm.GetFromContext(ctx)
|
||||||
if sm.filePath == "" {
|
if !exists {
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
sm.Cleanup() // Remove expired sessions before saving
|
value, exists := sess.GetFlash("message")
|
||||||
|
if !exists {
|
||||||
sm.mu.RLock()
|
return nil
|
||||||
data, err := json.MarshalIndent(sm.sessions, "", "\t")
|
|
||||||
sm.mu.RUnlock()
|
|
||||||
|
|
||||||
if err != nil {
|
|
||||||
return err
|
|
||||||
}
|
}
|
||||||
|
|
||||||
return os.WriteFile(sm.filePath, data, 0600)
|
sm.store.Save(sess)
|
||||||
|
|
||||||
|
if msg, ok := value.(FlashMessage); ok {
|
||||||
|
return &msg
|
||||||
|
}
|
||||||
|
|
||||||
|
if msgMap, ok := value.(map[string]interface{}); ok {
|
||||||
|
msg := &FlashMessage{}
|
||||||
|
if t, ok := msgMap["type"].(string); ok {
|
||||||
|
msg.Type = t
|
||||||
|
}
|
||||||
|
if m, ok := msgMap["message"].(string); ok {
|
||||||
|
msg.Message = m
|
||||||
|
}
|
||||||
|
return msg
|
||||||
|
}
|
||||||
|
|
||||||
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
// Close saves sessions and cleans up
|
func (sm *SessionManager) SetFormData(ctx router.Ctx, data map[string]string) bool {
|
||||||
func (sm *SessionManager) Close() error {
|
sess, exists := sm.GetFromContext(ctx)
|
||||||
return sm.Save()
|
if !exists {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
|
||||||
|
sess.Set("form_data", data)
|
||||||
|
sm.store.Save(sess)
|
||||||
|
return true
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (sm *SessionManager) GetFormData(ctx router.Ctx) map[string]string {
|
||||||
|
sess, exists := sm.GetFromContext(ctx)
|
||||||
|
if !exists {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
value, exists := sess.Get("form_data")
|
||||||
|
if !exists {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
sess.Delete("form_data")
|
||||||
|
sm.store.Save(sess)
|
||||||
|
|
||||||
|
if formData, ok := value.(map[string]string); ok {
|
||||||
|
return formData
|
||||||
|
}
|
||||||
|
|
||||||
|
if formMap, ok := value.(map[string]interface{}); ok {
|
||||||
|
result := make(map[string]string)
|
||||||
|
for k, v := range formMap {
|
||||||
|
if str, ok := v.(string); ok {
|
||||||
|
result[k] = str
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return result
|
||||||
|
}
|
||||||
|
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (sm *SessionManager) Stats() (total, active int) {
|
||||||
|
return sm.store.Stats()
|
||||||
|
}
|
||||||
|
|
||||||
|
func (sm *SessionManager) Close() error {
|
||||||
|
return sm.store.Close()
|
||||||
|
}
|
||||||
|
|
||||||
|
// Package-level convenience functions that use the global Manager
|
||||||
|
|
||||||
|
func Create(userID int, username, email string) *Session {
|
||||||
|
return Manager.Create(userID, username, email)
|
||||||
|
}
|
||||||
|
|
||||||
|
func Get(sessionID string) (*Session, bool) {
|
||||||
|
return Manager.Get(sessionID)
|
||||||
|
}
|
||||||
|
|
||||||
|
func GetFromContext(ctx router.Ctx) (*Session, bool) {
|
||||||
|
return Manager.GetFromContext(ctx)
|
||||||
|
}
|
||||||
|
|
||||||
|
func Update(sessionID string) bool {
|
||||||
|
return Manager.Update(sessionID)
|
||||||
|
}
|
||||||
|
|
||||||
|
func Delete(sessionID string) {
|
||||||
|
Manager.Delete(sessionID)
|
||||||
|
}
|
||||||
|
|
||||||
|
func SetSessionCookie(ctx router.Ctx, sessionID string) {
|
||||||
|
Manager.SetSessionCookie(ctx, sessionID)
|
||||||
|
}
|
||||||
|
|
||||||
|
func DeleteSessionCookie(ctx router.Ctx) {
|
||||||
|
Manager.DeleteSessionCookie(ctx)
|
||||||
|
}
|
||||||
|
|
||||||
|
func SetFlashMessage(ctx router.Ctx, msgType, message string) bool {
|
||||||
|
return Manager.SetFlashMessage(ctx, msgType, message)
|
||||||
|
}
|
||||||
|
|
||||||
|
func GetFlashMessage(ctx router.Ctx) *FlashMessage {
|
||||||
|
return Manager.GetFlashMessage(ctx)
|
||||||
|
}
|
||||||
|
|
||||||
|
func SetFormData(ctx router.Ctx, data map[string]string) bool {
|
||||||
|
return Manager.SetFormData(ctx, data)
|
||||||
|
}
|
||||||
|
|
||||||
|
func GetFormData(ctx router.Ctx) map[string]string {
|
||||||
|
return Manager.GetFormData(ctx)
|
||||||
|
}
|
||||||
|
|
||||||
|
func Stats() (total, active int) {
|
||||||
|
return Manager.Stats()
|
||||||
|
}
|
||||||
|
|
||||||
|
func Close() error {
|
||||||
|
return Manager.Close()
|
||||||
|
}
|
@ -1,4 +1,5 @@
|
|||||||
// session.go
|
// Package session provides session management functionality.
|
||||||
|
// It includes session storage, flash messages, and data persistence.
|
||||||
package session
|
package session
|
||||||
|
|
||||||
import (
|
import (
|
||||||
@ -12,97 +13,62 @@ const (
|
|||||||
IDLength = 32
|
IDLength = 32
|
||||||
)
|
)
|
||||||
|
|
||||||
// Session represents a user session
|
|
||||||
type Session struct {
|
type Session struct {
|
||||||
ID string `json:"id"`
|
ID string `json:"-"`
|
||||||
UserID int `json:"user_id"` // 0 for guest sessions
|
UserID int `json:"user_id"`
|
||||||
|
Username string `json:"username"`
|
||||||
|
Email string `json:"email"`
|
||||||
|
CreatedAt time.Time `json:"created_at"`
|
||||||
ExpiresAt time.Time `json:"expires_at"`
|
ExpiresAt time.Time `json:"expires_at"`
|
||||||
Data map[string]any `json:"data"`
|
LastSeen time.Time `json:"last_seen"`
|
||||||
|
Data map[string]any `json:"data,omitempty"`
|
||||||
}
|
}
|
||||||
|
|
||||||
// New creates a new session
|
func New(userID int, username, email string) *Session {
|
||||||
func New(userID int) *Session {
|
|
||||||
return &Session{
|
return &Session{
|
||||||
ID: generateID(),
|
ID: generateID(),
|
||||||
UserID: userID,
|
UserID: userID,
|
||||||
|
Username: username,
|
||||||
|
Email: email,
|
||||||
|
CreatedAt: time.Now(),
|
||||||
ExpiresAt: time.Now().Add(DefaultExpiration),
|
ExpiresAt: time.Now().Add(DefaultExpiration),
|
||||||
|
LastSeen: time.Now(),
|
||||||
Data: make(map[string]any),
|
Data: make(map[string]any),
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// IsExpired checks if the session has expired
|
|
||||||
func (s *Session) IsExpired() bool {
|
func (s *Session) IsExpired() bool {
|
||||||
return time.Now().After(s.ExpiresAt)
|
return time.Now().After(s.ExpiresAt)
|
||||||
}
|
}
|
||||||
|
|
||||||
// Touch extends the session expiration
|
|
||||||
func (s *Session) Touch() {
|
func (s *Session) Touch() {
|
||||||
|
s.LastSeen = time.Now()
|
||||||
s.ExpiresAt = time.Now().Add(DefaultExpiration)
|
s.ExpiresAt = time.Now().Add(DefaultExpiration)
|
||||||
}
|
}
|
||||||
|
|
||||||
// Set stores a value in the session
|
|
||||||
func (s *Session) Set(key string, value any) {
|
func (s *Session) Set(key string, value any) {
|
||||||
|
if s.Data == nil {
|
||||||
|
s.Data = make(map[string]any)
|
||||||
|
}
|
||||||
s.Data[key] = value
|
s.Data[key] = value
|
||||||
}
|
}
|
||||||
|
|
||||||
// Get retrieves a value from the session
|
|
||||||
func (s *Session) Get(key string) (any, bool) {
|
func (s *Session) Get(key string) (any, bool) {
|
||||||
|
if s.Data == nil {
|
||||||
|
return nil, false
|
||||||
|
}
|
||||||
value, exists := s.Data[key]
|
value, exists := s.Data[key]
|
||||||
return value, exists
|
return value, exists
|
||||||
}
|
}
|
||||||
|
|
||||||
// Delete removes a value from the session
|
|
||||||
func (s *Session) Delete(key string) {
|
func (s *Session) Delete(key string) {
|
||||||
delete(s.Data, key)
|
if s.Data != nil {
|
||||||
}
|
delete(s.Data, key)
|
||||||
|
|
||||||
// SetFlash stores a flash message (consumed on next Get)
|
|
||||||
func (s *Session) SetFlash(key string, value any) {
|
|
||||||
s.Set("flash_"+key, value)
|
|
||||||
}
|
|
||||||
|
|
||||||
// GetFlash retrieves and removes a flash message
|
|
||||||
func (s *Session) GetFlash(key string) (any, bool) {
|
|
||||||
flashKey := "flash_" + key
|
|
||||||
value, exists := s.Get(flashKey)
|
|
||||||
if exists {
|
|
||||||
s.Delete(flashKey)
|
|
||||||
}
|
}
|
||||||
return value, exists
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// generateID creates a random session ID
|
|
||||||
func generateID() string {
|
func generateID() string {
|
||||||
bytes := make([]byte, IDLength)
|
bytes := make([]byte, IDLength)
|
||||||
rand.Read(bytes)
|
rand.Read(bytes)
|
||||||
return hex.EncodeToString(bytes)
|
return hex.EncodeToString(bytes)
|
||||||
}
|
}
|
||||||
|
|
||||||
// Package-level convenience functions
|
|
||||||
func Create(userID int) *Session {
|
|
||||||
return Manager.Create(userID)
|
|
||||||
}
|
|
||||||
|
|
||||||
func Get(sessionID string) (*Session, bool) {
|
|
||||||
return Manager.Get(sessionID)
|
|
||||||
}
|
|
||||||
|
|
||||||
func Store(sess *Session) {
|
|
||||||
Manager.Store(sess)
|
|
||||||
}
|
|
||||||
|
|
||||||
func Delete(sessionID string) {
|
|
||||||
Manager.Delete(sessionID)
|
|
||||||
}
|
|
||||||
|
|
||||||
func Cleanup() {
|
|
||||||
Manager.Cleanup()
|
|
||||||
}
|
|
||||||
|
|
||||||
func Stats() (total, active int) {
|
|
||||||
return Manager.Stats()
|
|
||||||
}
|
|
||||||
|
|
||||||
func Close() error {
|
|
||||||
return Manager.Close()
|
|
||||||
}
|
|
||||||
|
161
internal/session/store.go
Normal file
161
internal/session/store.go
Normal file
@ -0,0 +1,161 @@
|
|||||||
|
package session
|
||||||
|
|
||||||
|
import (
|
||||||
|
"encoding/json"
|
||||||
|
"maps"
|
||||||
|
"os"
|
||||||
|
"sync"
|
||||||
|
"time"
|
||||||
|
)
|
||||||
|
|
||||||
|
type Store struct {
|
||||||
|
mu sync.RWMutex
|
||||||
|
sessions map[string]*Session
|
||||||
|
filePath string
|
||||||
|
saveInterval time.Duration
|
||||||
|
stopChan chan struct{}
|
||||||
|
}
|
||||||
|
|
||||||
|
type persistedData struct {
|
||||||
|
Sessions map[string]*Session `json:"sessions"`
|
||||||
|
SavedAt time.Time `json:"saved_at"`
|
||||||
|
}
|
||||||
|
|
||||||
|
func NewStore(filePath string) *Store {
|
||||||
|
store := &Store{
|
||||||
|
sessions: make(map[string]*Session),
|
||||||
|
filePath: filePath,
|
||||||
|
saveInterval: 5 * time.Minute,
|
||||||
|
stopChan: make(chan struct{}),
|
||||||
|
}
|
||||||
|
|
||||||
|
store.loadFromFile()
|
||||||
|
store.startPeriodicSave()
|
||||||
|
|
||||||
|
return store
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *Store) Save(session *Session) {
|
||||||
|
s.mu.Lock()
|
||||||
|
defer s.mu.Unlock()
|
||||||
|
s.sessions[session.ID] = session
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *Store) Get(sessionID string) (*Session, bool) {
|
||||||
|
s.mu.RLock()
|
||||||
|
defer s.mu.RUnlock()
|
||||||
|
|
||||||
|
session, exists := s.sessions[sessionID]
|
||||||
|
if !exists {
|
||||||
|
return nil, false
|
||||||
|
}
|
||||||
|
|
||||||
|
if session.IsExpired() {
|
||||||
|
return nil, false
|
||||||
|
}
|
||||||
|
|
||||||
|
return session, true
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *Store) Delete(sessionID string) {
|
||||||
|
s.mu.Lock()
|
||||||
|
defer s.mu.Unlock()
|
||||||
|
delete(s.sessions, sessionID)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *Store) Cleanup() {
|
||||||
|
s.mu.Lock()
|
||||||
|
defer s.mu.Unlock()
|
||||||
|
|
||||||
|
for id, session := range s.sessions {
|
||||||
|
if session.IsExpired() {
|
||||||
|
delete(s.sessions, id)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *Store) Stats() (total, active int) {
|
||||||
|
s.mu.RLock()
|
||||||
|
defer s.mu.RUnlock()
|
||||||
|
|
||||||
|
total = len(s.sessions)
|
||||||
|
for _, session := range s.sessions {
|
||||||
|
if !session.IsExpired() {
|
||||||
|
active++
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *Store) loadFromFile() {
|
||||||
|
if s.filePath == "" {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
data, err := os.ReadFile(s.filePath)
|
||||||
|
if err != nil {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
var persisted persistedData
|
||||||
|
if err := json.Unmarshal(data, &persisted); err != nil {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
s.mu.Lock()
|
||||||
|
defer s.mu.Unlock()
|
||||||
|
|
||||||
|
for id, session := range persisted.Sessions {
|
||||||
|
if !session.IsExpired() {
|
||||||
|
session.ID = id
|
||||||
|
s.sessions[id] = session
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *Store) saveToFile() error {
|
||||||
|
if s.filePath == "" {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
s.mu.RLock()
|
||||||
|
sessionsCopy := make(map[string]*Session, len(s.sessions))
|
||||||
|
maps.Copy(sessionsCopy, s.sessions)
|
||||||
|
s.mu.RUnlock()
|
||||||
|
|
||||||
|
data := persistedData{
|
||||||
|
Sessions: sessionsCopy,
|
||||||
|
SavedAt: time.Now(),
|
||||||
|
}
|
||||||
|
|
||||||
|
jsonData, err := json.MarshalIndent(data, "", " ")
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
return os.WriteFile(s.filePath, jsonData, 0600)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *Store) startPeriodicSave() {
|
||||||
|
go func() {
|
||||||
|
ticker := time.NewTicker(s.saveInterval)
|
||||||
|
defer ticker.Stop()
|
||||||
|
|
||||||
|
for {
|
||||||
|
select {
|
||||||
|
case <-ticker.C:
|
||||||
|
s.Cleanup()
|
||||||
|
s.saveToFile()
|
||||||
|
case <-s.stopChan:
|
||||||
|
s.saveToFile()
|
||||||
|
return
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}()
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *Store) Close() error {
|
||||||
|
close(s.stopChan)
|
||||||
|
return s.saveToFile()
|
||||||
|
}
|
@ -3,7 +3,6 @@ package components
|
|||||||
import (
|
import (
|
||||||
"fmt"
|
"fmt"
|
||||||
"maps"
|
"maps"
|
||||||
"runtime"
|
|
||||||
"strings"
|
"strings"
|
||||||
|
|
||||||
"dk/internal/csrf"
|
"dk/internal/csrf"
|
||||||
@ -23,18 +22,15 @@ func RenderPage(ctx router.Ctx, title, tmplPath string, additionalData map[strin
|
|||||||
return fmt.Errorf("failed to load layout template: %w", err)
|
return fmt.Errorf("failed to load layout template: %w", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
var m runtime.MemStats
|
|
||||||
runtime.ReadMemStats(&m)
|
|
||||||
|
|
||||||
data := map[string]any{
|
data := map[string]any{
|
||||||
"_title": PageTitle(title),
|
"_title": PageTitle(title),
|
||||||
"authenticated": middleware.IsAuthenticated(ctx),
|
"authenticated": middleware.IsAuthenticated(ctx),
|
||||||
"csrf": csrf.HiddenField(ctx),
|
"csrf": csrf.HiddenField(ctx),
|
||||||
"_totaltime": middleware.GetRequestTime(ctx),
|
"_totaltime": middleware.GetRequestTime(ctx),
|
||||||
|
"_numqueries": 0,
|
||||||
"_version": "1.0.0",
|
"_version": "1.0.0",
|
||||||
"_build": "dev",
|
"_build": "dev",
|
||||||
"user": middleware.GetCurrentUser(ctx),
|
"user": middleware.GetCurrentUser(ctx),
|
||||||
"_memalloc": m.Alloc / 1024 / 1024,
|
|
||||||
}
|
}
|
||||||
|
|
||||||
maps.Copy(data, LeftAside(ctx))
|
maps.Copy(data, LeftAside(ctx))
|
||||||
|
9
main.go
9
main.go
@ -158,14 +158,15 @@ func start(port string) error {
|
|||||||
if err != nil {
|
if err != nil {
|
||||||
return fmt.Errorf("failed to get current working directory: %w", err)
|
return fmt.Errorf("failed to get current working directory: %w", err)
|
||||||
}
|
}
|
||||||
|
// Initialize template singleton
|
||||||
template.InitializeCache(cwd)
|
template.InitializeCache(cwd)
|
||||||
|
|
||||||
|
// Load all model data into memory
|
||||||
if err := loadModels(); err != nil {
|
if err := loadModels(); err != nil {
|
||||||
return fmt.Errorf("failed to load models: %w", err)
|
return fmt.Errorf("failed to load models: %w", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
session.Init("sessions.json")
|
session.Init("sessions.json") // Initialize session.Manager
|
||||||
|
|
||||||
r := router.New()
|
r := router.New()
|
||||||
r.Use(middleware.Timing())
|
r.Use(middleware.Timing())
|
||||||
@ -173,8 +174,8 @@ func start(port string) error {
|
|||||||
r.Use(middleware.CSRF())
|
r.Use(middleware.CSRF())
|
||||||
|
|
||||||
r.Get("/", routes.Index)
|
r.Get("/", routes.Index)
|
||||||
r.WithMiddleware(middleware.RequireAuth()).Get("/explore", routes.Explore)
|
r.Use(middleware.RequireAuth()).Get("/explore", routes.Explore)
|
||||||
r.WithMiddleware(middleware.RequireAuth()).Post("/move", routes.Move)
|
r.Use(middleware.RequireAuth()).Post("/move", routes.Move)
|
||||||
routes.RegisterAuthRoutes(r)
|
routes.RegisterAuthRoutes(r)
|
||||||
routes.RegisterTownRoutes(r)
|
routes.RegisterTownRoutes(r)
|
||||||
|
|
||||||
|
@ -44,7 +44,7 @@
|
|||||||
<footer>
|
<footer>
|
||||||
<div>Powered by <a href="/">Dragon Knight</a></div>
|
<div>Powered by <a href="/">Dragon Knight</a></div>
|
||||||
<div>© 2025 Sharkk</div>
|
<div>© 2025 Sharkk</div>
|
||||||
<div>{_totaltime} Seconds, {_memalloc} MiB</div>
|
<div>{_totaltime} Seconds, {_numqueries} Queries</div>
|
||||||
<div>Version {_version} {_build}</div>
|
<div>Version {_version} {_build}</div>
|
||||||
</footer>
|
</footer>
|
||||||
</div>
|
</div>
|
||||||
|
Loading…
x
Reference in New Issue
Block a user