diff --git a/internal/csrf/csrf_test.go b/internal/csrf/csrf_test.go
deleted file mode 100644
index d2d8f0c..0000000
--- a/internal/csrf/csrf_test.go
+++ /dev/null
@@ -1,167 +0,0 @@
-package csrf
-
-import (
- "testing"
- "time"
-
- "dk/internal/session"
-
- "github.com/valyala/fasthttp"
-)
-
-func TestGenerateToken(t *testing.T) {
- sess := &session.Session{
- ID: "test-session",
- UserID: 1,
- Username: "testuser",
- Email: "test@example.com",
- CreatedAt: time.Now(),
- ExpiresAt: time.Now().Add(time.Hour),
- LastSeen: time.Now(),
- Data: make(map[string]any),
- }
-
- ctx := &fasthttp.RequestCtx{}
- ctx.SetUserValue(SessionCtxKey, sess)
-
- token := GenerateToken(ctx)
-
- if token == "" {
- t.Error("Expected non-empty token")
- }
-
- storedToken := GetStoredToken(sess)
- if storedToken != token {
- t.Errorf("Expected stored token %s, got %s", token, storedToken)
- }
-}
-
-func TestValidateToken(t *testing.T) {
- sess := &session.Session{
- ID: "test-session",
- UserID: 1,
- Username: "testuser",
- Email: "test@example.com",
- Data: map[string]any{SessionKey: "test-token"},
- }
-
- ctx := &fasthttp.RequestCtx{}
- ctx.SetUserValue(SessionCtxKey, sess)
-
- if !ValidateToken(ctx, "test-token") {
- t.Error("Expected valid token to pass validation")
- }
-
- if ValidateToken(ctx, "wrong-token") {
- t.Error("Expected invalid token to fail validation")
- }
-
- if ValidateToken(ctx, "") {
- t.Error("Expected empty token to fail validation")
- }
-}
-
-func TestValidateTokenNoSession(t *testing.T) {
- ctx := &fasthttp.RequestCtx{}
-
- if ValidateToken(ctx, "any-token") {
- t.Error("Expected validation to fail with no session")
- }
-}
-
-func TestHiddenField(t *testing.T) {
- sess := &session.Session{
- ID: "test-session",
- UserID: 1,
- Username: "testuser",
- Email: "test@example.com",
- Data: map[string]any{SessionKey: "test-token"},
- }
-
- ctx := &fasthttp.RequestCtx{}
- ctx.SetUserValue(SessionCtxKey, sess)
-
- field := HiddenField(ctx)
- expected := ` `
-
- if field != expected {
- t.Errorf("Expected %s, got %s", expected, field)
- }
-}
-
-func TestHiddenFieldNoSession(t *testing.T) {
- ctx := &fasthttp.RequestCtx{}
-
- field := HiddenField(ctx)
- if field == "" {
- t.Error("Expected non-empty field for guest user with cookie-based token")
- }
-}
-
-func TestTokenMeta(t *testing.T) {
- sess := &session.Session{
- ID: "test-session",
- UserID: 1,
- Username: "testuser",
- Email: "test@example.com",
- Data: map[string]any{SessionKey: "test-token"},
- }
-
- ctx := &fasthttp.RequestCtx{}
- ctx.SetUserValue(SessionCtxKey, sess)
-
- meta := TokenMeta(ctx)
- expected := ` `
-
- if meta != expected {
- t.Errorf("Expected %s, got %s", expected, meta)
- }
-}
-
-func TestStoreAndGetToken(t *testing.T) {
- sess := &session.Session{
- Data: make(map[string]any),
- }
-
- token := "test-token"
- StoreToken(sess, token)
-
- retrieved := GetStoredToken(sess)
- if retrieved != token {
- t.Errorf("Expected %s, got %s", token, retrieved)
- }
-}
-
-func TestGetStoredTokenNoData(t *testing.T) {
- sess := &session.Session{}
-
- token := GetStoredToken(sess)
- if token != "" {
- t.Errorf("Expected empty token, got %s", token)
- }
-}
-
-func TestValidateFormToken(t *testing.T) {
- sess := &session.Session{
- ID: "test-session",
- UserID: 1,
- Username: "testuser",
- Email: "test@example.com",
- Data: map[string]any{SessionKey: "test-token"},
- }
-
- ctx := &fasthttp.RequestCtx{}
- ctx.SetUserValue(SessionCtxKey, sess)
-
- ctx.PostArgs().Set(TokenFieldName, "test-token")
-
- if !ValidateFormToken(ctx) {
- t.Error("Expected form token validation to pass")
- }
-
- ctx.PostArgs().Set(TokenFieldName, "wrong-token")
-
- if ValidateFormToken(ctx) {
- t.Error("Expected form token validation to fail with wrong token")
- }
-}
\ No newline at end of file
diff --git a/internal/middleware/auth.go b/internal/middleware/auth.go
index 13b4b2f..809b38d 100644
--- a/internal/middleware/auth.go
+++ b/internal/middleware/auth.go
@@ -5,29 +5,43 @@ import (
"dk/internal/models/users"
"dk/internal/router"
"dk/internal/session"
+ "fmt"
+ "time"
"github.com/valyala/fasthttp"
)
+const SessionCookieName = "dk_session"
+
func Auth() router.Middleware {
return func(next router.Handler) router.Handler {
return func(ctx router.Ctx, params []string) {
- sessionID := cookies.GetCookie(ctx, session.SessionCookieName)
+ sessionID := cookies.GetCookie(ctx, SessionCookieName)
+ var sess *session.Session
if sessionID != "" {
- if sess, exists := session.Get(sessionID); exists {
- session.Update(sessionID)
+ if existingSess, exists := session.Get(sessionID); exists {
+ sess = existingSess
+ sess.Touch()
+ session.Store(sess)
- user, err := users.Find(sess.UserID)
- if err == nil && user != nil {
- ctx.SetUserValue("session", sess)
- ctx.SetUserValue("user", user)
-
- session.SetSessionCookie(ctx, sessionID)
+ if sess.UserID > 0 { // User session
+ user, err := users.Find(sess.UserID)
+ if err == nil && user != nil {
+ ctx.SetUserValue("user", user)
+ setSessionCookie(ctx, sessionID)
+ }
}
}
}
+ // Create guest session if none exists
+ if sess == nil {
+ sess = session.Create(0) // Guest session
+ setSessionCookie(ctx, sess.ID)
+ }
+
+ ctx.SetUserValue("session", sess)
next(ctx, params)
}
}
@@ -64,6 +78,7 @@ func RequireGuest(paths ...string) router.Middleware {
return func(next router.Handler) router.Handler {
return func(ctx router.Ctx, params []string) {
if IsAuthenticated(ctx) {
+ fmt.Println("RequireGuest: user is authenticated")
ctx.Redirect(redirect, fasthttp.StatusFound)
return
}
@@ -92,21 +107,38 @@ func GetCurrentSession(ctx router.Ctx) *session.Session {
}
func Login(ctx router.Ctx, user *users.User) {
- sess := session.Create(user.ID, user.Username, user.Email)
- session.SetSessionCookie(ctx, sess.ID)
+ sess := session.Create(user.ID)
+ setSessionCookie(ctx, sess.ID)
ctx.SetUserValue("session", sess)
ctx.SetUserValue("user", user)
}
func Logout(ctx router.Ctx) {
- sessionID := cookies.GetCookie(ctx, session.SessionCookieName)
+ sessionID := cookies.GetCookie(ctx, SessionCookieName)
if sessionID != "" {
session.Delete(sessionID)
}
- session.DeleteSessionCookie(ctx)
+ deleteSessionCookie(ctx)
ctx.SetUserValue("session", nil)
ctx.SetUserValue("user", nil)
}
+
+// Helper functions for session cookies
+func setSessionCookie(ctx router.Ctx, sessionID string) {
+ cookies.SetSecureCookie(ctx, cookies.CookieOptions{
+ Name: SessionCookieName,
+ Value: sessionID,
+ Path: "/",
+ Expires: time.Now().Add(24 * time.Hour),
+ HTTPOnly: true,
+ Secure: cookies.IsHTTPS(ctx),
+ SameSite: "lax",
+ })
+}
+
+func deleteSessionCookie(ctx router.Ctx) {
+ cookies.DeleteCookie(ctx, SessionCookieName)
+}
diff --git a/internal/routes/auth.go b/internal/routes/auth.go
index 5442ffe..3320af3 100644
--- a/internal/routes/auth.go
+++ b/internal/routes/auth.go
@@ -18,36 +18,39 @@ import (
// RegisterAuthRoutes sets up authentication routes
func RegisterAuthRoutes(r *router.Router) {
- // Guest routes
- guestGroup := r.Group("")
- guestGroup.Use(middleware.RequireGuest())
+ guests := r.Group("")
+ guests.Use(middleware.RequireGuest())
- guestGroup.Get("/login", showLogin)
- guestGroup.Post("/login", processLogin)
- guestGroup.Get("/register", showRegister)
- guestGroup.Post("/register", processRegister)
+ guests.Get("/login", showLogin)
+ guests.Post("/login", processLogin)
+ guests.Get("/register", showRegister)
+ guests.Post("/register", processRegister)
- // Authenticated routes
- authGroup := r.Group("")
- authGroup.Use(middleware.RequireAuth())
+ authed := r.Group("")
+ authed.Use(middleware.RequireAuth())
- authGroup.Post("/logout", processLogout)
+ authed.Post("/logout", processLogout)
}
// showLogin displays the login form
func showLogin(ctx router.Ctx, _ []string) {
- // Get flash message if any
+ sess := ctx.UserValue("session").(*session.Session)
var errorHTML string
- if flash := session.GetFlashMessage(ctx); flash != nil {
- errorHTML = fmt.Sprintf(`
%s
`, flash.Message)
+ var id string
+
+ if flash, exists := sess.GetFlash("error"); exists {
+ if msg, ok := flash.(string); ok {
+ errorHTML = fmt.Sprintf(`%s
`, msg)
+ }
}
- // Get form data if any (for preserving email/username on error)
- formData := session.GetFormData(ctx)
- id := ""
- if formData != nil {
- id = formData["id"]
+ if formData, exists := sess.Get("form_data"); exists {
+ if data, ok := formData.(map[string]string); ok {
+ id = data["id"]
+ }
}
+ sess.Delete("form_data")
+ session.Store(sess)
components.RenderPage(ctx, "Log In", "auth/login.html", map[string]any{
"error_message": errorHTML,
@@ -67,26 +70,30 @@ func processLogin(ctx router.Ctx, _ []string) {
userPassword := string(ctx.PostArgs().Peek("password"))
if email == "" || userPassword == "" {
- session.SetFlashMessage(ctx, "error", "Email and password are required")
- session.SetFormData(ctx, map[string]string{"id": email})
+ setFlashAndFormData(ctx, "Email and password are required", map[string]string{"id": email})
ctx.Redirect("/login", fasthttp.StatusFound)
return
}
user, err := auth.Authenticate(email, userPassword)
if err != nil {
- session.SetFlashMessage(ctx, "error", "Invalid email or password")
- session.SetFormData(ctx, map[string]string{"id": email})
+ setFlashAndFormData(ctx, "Invalid email or password", map[string]string{"id": email})
ctx.Redirect("/login", fasthttp.StatusFound)
return
}
middleware.Login(ctx, user)
+ // Set success message
+ if sess := ctx.UserValue("session").(*session.Session); sess != nil {
+ sess.SetFlash("success", fmt.Sprintf("Welcome back, %s!", user.Username))
+ session.Store(sess)
+ }
+
// Transfer CSRF token from cookie to session for authenticated user
if cookieToken := csrf.GetTokenFromCookie(ctx); cookieToken != "" {
- if session := csrf.GetCurrentSession(ctx); session != nil {
- csrf.StoreToken(session, cookieToken)
+ if sess := ctx.UserValue("session").(*session.Session); sess != nil {
+ csrf.StoreToken(sess, cookieToken)
}
}
@@ -95,20 +102,24 @@ func processLogin(ctx router.Ctx, _ []string) {
// showRegister displays the registration form
func showRegister(ctx router.Ctx, _ []string) {
- // Get flash message if any
+ sess := ctx.UserValue("session").(*session.Session)
var errorHTML string
- if flash := session.GetFlashMessage(ctx); flash != nil {
- errorHTML = fmt.Sprintf(`%s
`, flash.Message)
+ var username, email string
+
+ if flash, exists := sess.GetFlash("error"); exists {
+ if msg, ok := flash.(string); ok {
+ errorHTML = fmt.Sprintf(`%s
`, msg)
+ }
}
- // Get form data if any (for preserving values on error)
- formData := session.GetFormData(ctx)
- username := ""
- email := ""
- if formData != nil {
- username = formData["username"]
- email = formData["email"]
+ if formData, exists := sess.Get("form_data"); exists {
+ if data, ok := formData.(map[string]string); ok {
+ username = data["username"]
+ email = data["email"]
+ }
}
+ sess.Delete("form_data")
+ session.Store(sess)
components.RenderPage(ctx, "Register", "auth/register.html", map[string]any{
"error_message": errorHTML,
@@ -130,32 +141,25 @@ func processRegister(ctx router.Ctx, _ []string) {
userPassword := string(ctx.PostArgs().Peek("password"))
confirmPassword := string(ctx.PostArgs().Peek("confirm_password"))
+ formData := map[string]string{
+ "username": username,
+ "email": email,
+ }
+
if err := validateRegistration(username, email, userPassword, confirmPassword); err != nil {
- session.SetFlashMessage(ctx, "error", err.Error())
- session.SetFormData(ctx, map[string]string{
- "username": username,
- "email": email,
- })
+ setFlashAndFormData(ctx, err.Error(), formData)
ctx.Redirect("/register", fasthttp.StatusFound)
return
}
if _, err := users.ByUsername(username); err == nil {
- session.SetFlashMessage(ctx, "error", "Username already exists")
- session.SetFormData(ctx, map[string]string{
- "username": username,
- "email": email,
- })
+ setFlashAndFormData(ctx, "Username already exists", formData)
ctx.Redirect("/register", fasthttp.StatusFound)
return
}
if _, err := users.ByEmail(email); err == nil {
- session.SetFlashMessage(ctx, "error", "Email already registered")
- session.SetFormData(ctx, map[string]string{
- "username": username,
- "email": email,
- })
+ setFlashAndFormData(ctx, "Email already registered", formData)
ctx.Redirect("/register", fasthttp.StatusFound)
return
}
@@ -168,11 +172,7 @@ func processRegister(ctx router.Ctx, _ []string) {
user.Auth = 1
if err := user.Insert(); err != nil {
- session.SetFlashMessage(ctx, "error", "Failed to create account")
- session.SetFormData(ctx, map[string]string{
- "username": username,
- "email": email,
- })
+ setFlashAndFormData(ctx, "Failed to create account", formData)
ctx.Redirect("/register", fasthttp.StatusFound)
return
}
@@ -180,10 +180,16 @@ func processRegister(ctx router.Ctx, _ []string) {
// Auto-login after registration
middleware.Login(ctx, user)
+ // Set success message
+ if sess := ctx.UserValue("session").(*session.Session); sess != nil {
+ sess.SetFlash("success", fmt.Sprintf("Greetings, %s!", user.Username))
+ session.Store(sess)
+ }
+
// Transfer CSRF token from cookie to session for authenticated user
if cookieToken := csrf.GetTokenFromCookie(ctx); cookieToken != "" {
- if session := csrf.GetCurrentSession(ctx); session != nil {
- csrf.StoreToken(session, cookieToken)
+ if sess := ctx.UserValue("session").(*session.Session); sess != nil {
+ csrf.StoreToken(sess, cookieToken)
}
}
@@ -229,3 +235,10 @@ func validateRegistration(username, email, password, confirmPassword string) err
}
return nil
}
+
+func setFlashAndFormData(ctx router.Ctx, message string, formData map[string]string) {
+ sess := ctx.UserValue("session").(*session.Session)
+ sess.SetFlash("error", message)
+ sess.Set("form_data", formData)
+ session.Store(sess)
+}
diff --git a/internal/routes/town.go b/internal/routes/town.go
index cbd713b..2ed670d 100644
--- a/internal/routes/town.go
+++ b/internal/routes/town.go
@@ -49,9 +49,13 @@ func showTown(ctx router.Ctx, _ []string) {
}
func showInn(ctx router.Ctx, _ []string) {
+ sess := ctx.UserValue("session").(*session.Session)
var errorHTML string
- if flash := session.GetFlashMessage(ctx); flash != nil {
- errorHTML = `` + flash.Message + "
"
+
+ if flash, exists := sess.GetFlash("error"); exists {
+ if msg, ok := flash.(string); ok {
+ errorHTML = `` + msg + "
"
+ }
}
town := ctx.UserValue("town").(*towns.Town)
@@ -64,11 +68,12 @@ func showInn(ctx router.Ctx, _ []string) {
}
func rest(ctx router.Ctx, _ []string) {
+ sess := ctx.UserValue("session").(*session.Session)
town := ctx.UserValue("town").(*towns.Town)
user := ctx.UserValue("user").(*users.User)
if user.Gold < town.InnCost {
- session.SetFlashMessage(ctx, "error", "You can't afford to stay here tonight.")
+ sess.SetFlash("error", "You can't afford to stay here tonight.")
ctx.Redirect("/town/inn", 303)
return
}
@@ -83,9 +88,13 @@ func rest(ctx router.Ctx, _ []string) {
}
func showShop(ctx router.Ctx, _ []string) {
+ sess := ctx.UserValue("session").(*session.Session)
var errorHTML string
- if flash := session.GetFlashMessage(ctx); flash != nil {
- errorHTML = `` + flash.Message + "
"
+
+ if flash, exists := sess.GetFlash("error"); exists {
+ if msg, ok := flash.(string); ok {
+ errorHTML = `` + msg + "
"
+ }
}
town := ctx.UserValue("town").(*towns.Town)
@@ -109,30 +118,32 @@ func showShop(ctx router.Ctx, _ []string) {
}
func buyItem(ctx router.Ctx, params []string) {
+ sess := ctx.UserValue("session").(*session.Session)
+
id, err := strconv.Atoi(params[0])
if err != nil {
- session.SetFlashMessage(ctx, "error", "Error purchasing item; "+err.Error())
+ sess.SetFlash("error", "Error purchasing item; "+err.Error())
ctx.Redirect("/town/shop", 302)
return
}
town := ctx.UserValue("town").(*towns.Town)
if !slices.Contains(town.GetShopItems(), id) {
- session.SetFlashMessage(ctx, "error", "The item doesn't exist in this shop.")
+ sess.SetFlash("error", "The item doesn't exist in this shop.")
ctx.Redirect("/town/shop", 302)
return
}
item, err := items.Find(id)
if err != nil {
- session.SetFlashMessage(ctx, "error", "Error purchasing item; "+err.Error())
+ sess.SetFlash("error", "Error purchasing item; "+err.Error())
ctx.Redirect("/town/shop", 302)
return
}
user := ctx.UserValue("user").(*users.User)
if user.Gold < item.Value {
- session.SetFlashMessage(ctx, "error", "You don't have enough gold to buy "+item.Name)
+ sess.SetFlash("error", "You don't have enough gold to buy "+item.Name)
ctx.Redirect("/town/shop", 302)
return
}
@@ -145,9 +156,13 @@ func buyItem(ctx router.Ctx, params []string) {
}
func showMaps(ctx router.Ctx, _ []string) {
+ sess := ctx.UserValue("session").(*session.Session)
var errorHTML string
- if flash := session.GetFlashMessage(ctx); flash != nil {
- errorHTML = `` + flash.Message + "
"
+
+ if flash, exists := sess.GetFlash("error"); exists {
+ if msg, ok := flash.(string); ok {
+ errorHTML = `` + msg + "
"
+ }
}
town := ctx.UserValue("town").(*towns.Town)
@@ -186,23 +201,25 @@ func showMaps(ctx router.Ctx, _ []string) {
}
func buyMap(ctx router.Ctx, params []string) {
+ sess := ctx.UserValue("session").(*session.Session)
+
id, err := strconv.Atoi(params[0])
if err != nil {
- session.SetFlashMessage(ctx, "error", "Error purchasing map; "+err.Error())
+ sess.SetFlash("error", "Error purchasing map; "+err.Error())
ctx.Redirect("/town/maps", 302)
return
}
mapped, err := towns.Find(id)
if err != nil {
- session.SetFlashMessage(ctx, "error", "Error purchasing map; "+err.Error())
+ sess.SetFlash("error", "Error purchasing map; "+err.Error())
ctx.Redirect("/town/maps", 302)
return
}
user := ctx.UserValue("user").(*users.User)
if user.Gold < mapped.MapCost {
- session.SetFlashMessage(ctx, "error", "You don't have enough gold to buy the map to "+mapped.Name)
+ sess.SetFlash("error", "You don't have enough gold to buy the map to "+mapped.Name)
ctx.Redirect("/town/maps", 302)
return
}
diff --git a/internal/session/flash.go b/internal/session/flash.go
deleted file mode 100644
index 9c96036..0000000
--- a/internal/session/flash.go
+++ /dev/null
@@ -1,56 +0,0 @@
-package session
-
-type FlashMessage struct {
- Type string `json:"type"`
- Message string `json:"message"`
-}
-
-func (s *Session) SetFlash(key string, value any) {
- if s.Data == nil {
- s.Data = make(map[string]any)
- }
-
- flashData, ok := s.Data["_flash"].(map[string]any)
- if !ok {
- flashData = make(map[string]any)
- }
- flashData[key] = value
- s.Data["_flash"] = flashData
-}
-
-func (s *Session) GetFlash(key string) (any, bool) {
- if s.Data == nil {
- return nil, false
- }
-
- flashData, ok := s.Data["_flash"].(map[string]any)
- if !ok {
- return nil, false
- }
-
- value, exists := flashData[key]
- if exists {
- delete(flashData, key)
- if len(flashData) == 0 {
- delete(s.Data, "_flash")
- } else {
- s.Data["_flash"] = flashData
- }
- }
-
- return value, exists
-}
-
-func (s *Session) GetAllFlash() map[string]any {
- if s.Data == nil {
- return nil
- }
-
- flashData, ok := s.Data["_flash"].(map[string]any)
- if !ok {
- return nil
- }
-
- delete(s.Data, "_flash")
- return flashData
-}
\ No newline at end of file
diff --git a/internal/session/manager.go b/internal/session/manager.go
index 6a47590..69ab098 100644
--- a/internal/session/manager.go
+++ b/internal/session/manager.go
@@ -1,28 +1,35 @@
package session
import (
- "dk/internal/cookies"
- "dk/internal/router"
- "time"
+ "encoding/json"
+ "os"
+ "sync"
)
-const SessionCookieName = "dk_session"
+// SessionManager handles session storage and persistence
+type SessionManager struct {
+ mu sync.RWMutex
+ sessions map[string]*Session
+ filePath string
+}
var Manager *SessionManager
-type SessionManager struct {
- store *Store
-}
-
-func Init(sessionsFilePath string) {
+// Init initializes the global session manager
+func Init(filePath string) {
if Manager != nil {
panic("session manager already initialized")
}
+
Manager = &SessionManager{
- store: NewStore(sessionsFilePath),
+ sessions: make(map[string]*Session),
+ filePath: filePath,
}
+
+ Manager.load()
}
+// GetManager returns the global session manager
func GetManager() *SessionManager {
if Manager == nil {
panic("session manager not initialized")
@@ -30,200 +37,116 @@ func GetManager() *SessionManager {
return Manager
}
-func (sm *SessionManager) Create(userID int, username, email string) *Session {
- sess := New(userID, username, email)
- sm.store.Save(sess)
+// Create creates and stores a new session
+func (sm *SessionManager) Create(userID int) *Session {
+ sess := New(userID)
+ sm.mu.Lock()
+ sm.sessions[sess.ID] = sess
+ sm.mu.Unlock()
return sess
}
+// Get retrieves a session by ID
func (sm *SessionManager) Get(sessionID string) (*Session, bool) {
- return sm.store.Get(sessionID)
-}
+ sm.mu.RLock()
+ sess, exists := sm.sessions[sessionID]
+ sm.mu.RUnlock()
-func (sm *SessionManager) GetFromContext(ctx router.Ctx) (*Session, bool) {
- sessionID := cookies.GetCookie(ctx, SessionCookieName)
- if sessionID == "" {
+ if !exists || sess.IsExpired() {
+ if exists {
+ sm.Delete(sessionID)
+ }
return nil, false
}
- return sm.Get(sessionID)
+
+ return sess, true
}
-func (sm *SessionManager) Update(sessionID string) bool {
- sess, exists := sm.store.Get(sessionID)
- if !exists {
- return false
- }
-
- sess.Touch()
- sm.store.Save(sess)
- return true
+// Store saves a session in memory (updates existing or creates new)
+func (sm *SessionManager) Store(sess *Session) {
+ sm.mu.Lock()
+ sm.sessions[sess.ID] = sess
+ sm.mu.Unlock()
}
+// Delete removes a session
func (sm *SessionManager) Delete(sessionID string) {
- sm.store.Delete(sessionID)
+ sm.mu.Lock()
+ delete(sm.sessions, sessionID)
+ sm.mu.Unlock()
}
-func (sm *SessionManager) SetSessionCookie(ctx router.Ctx, sessionID string) {
- cookies.SetSecureCookie(ctx, cookies.CookieOptions{
- Name: SessionCookieName,
- Value: sessionID,
- Path: "/",
- Expires: time.Now().Add(DefaultExpiration),
- HTTPOnly: true,
- Secure: cookies.IsHTTPS(ctx),
- SameSite: "lax",
- })
-}
-
-func (sm *SessionManager) DeleteSessionCookie(ctx router.Ctx) {
- cookies.DeleteCookie(ctx, SessionCookieName)
-}
-
-func (sm *SessionManager) SetFlashMessage(ctx router.Ctx, msgType, message string) bool {
- sess, exists := sm.GetFromContext(ctx)
- if !exists {
- return false
- }
-
- sess.SetFlash("message", FlashMessage{
- Type: msgType,
- Message: message,
- })
- sm.store.Save(sess)
- return true
-}
-
-func (sm *SessionManager) GetFlashMessage(ctx router.Ctx) *FlashMessage {
- sess, exists := sm.GetFromContext(ctx)
- if !exists {
- return nil
- }
-
- value, exists := sess.GetFlash("message")
- if !exists {
- return nil
- }
-
- sm.store.Save(sess)
-
- if msg, ok := value.(FlashMessage); ok {
- return &msg
- }
-
- if msgMap, ok := value.(map[string]interface{}); ok {
- msg := &FlashMessage{}
- if t, ok := msgMap["type"].(string); ok {
- msg.Type = t
+// Cleanup removes expired sessions
+func (sm *SessionManager) Cleanup() {
+ sm.mu.Lock()
+ for id, sess := range sm.sessions {
+ if sess.IsExpired() {
+ delete(sm.sessions, id)
}
- if m, ok := msgMap["message"].(string); ok {
- msg.Message = m
- }
- return msg
}
-
- return nil
-}
-
-func (sm *SessionManager) SetFormData(ctx router.Ctx, data map[string]string) bool {
- sess, exists := sm.GetFromContext(ctx)
- if !exists {
- return false
- }
-
- sess.Set("form_data", data)
- sm.store.Save(sess)
- return true
-}
-
-func (sm *SessionManager) GetFormData(ctx router.Ctx) map[string]string {
- sess, exists := sm.GetFromContext(ctx)
- if !exists {
- return nil
- }
-
- value, exists := sess.Get("form_data")
- if !exists {
- return nil
- }
-
- sess.Delete("form_data")
- sm.store.Save(sess)
-
- if formData, ok := value.(map[string]string); ok {
- return formData
- }
-
- if formMap, ok := value.(map[string]interface{}); ok {
- result := make(map[string]string)
- for k, v := range formMap {
- if str, ok := v.(string); ok {
- result[k] = str
- }
- }
- return result
- }
-
- return nil
+ sm.mu.Unlock()
}
+// Stats returns session statistics
func (sm *SessionManager) Stats() (total, active int) {
- return sm.store.Stats()
+ sm.mu.RLock()
+ defer sm.mu.RUnlock()
+
+ total = len(sm.sessions)
+ for _, sess := range sm.sessions {
+ if !sess.IsExpired() {
+ active++
+ }
+ }
+ return
}
+// load reads sessions from the JSON file
+func (sm *SessionManager) load() {
+ if sm.filePath == "" {
+ return
+ }
+
+ data, err := os.ReadFile(sm.filePath)
+ if err != nil {
+ return // File doesn't exist or can't be read
+ }
+
+ var sessions map[string]*Session
+ if err := json.Unmarshal(data, &sessions); err != nil {
+ return // Invalid JSON
+ }
+
+ sm.mu.Lock()
+ for id, sess := range sessions {
+ if sess != nil && !sess.IsExpired() {
+ sess.ID = id // Ensure ID consistency
+ sm.sessions[id] = sess
+ }
+ }
+ sm.mu.Unlock()
+}
+
+// Save writes sessions to the JSON file
+func (sm *SessionManager) Save() error {
+ if sm.filePath == "" {
+ return nil
+ }
+
+ sm.Cleanup() // Remove expired sessions before saving
+
+ sm.mu.RLock()
+ data, err := json.MarshalIndent(sm.sessions, "", "\t")
+ sm.mu.RUnlock()
+
+ if err != nil {
+ return err
+ }
+
+ return os.WriteFile(sm.filePath, data, 0600)
+}
+
+// Close saves sessions and cleans up
func (sm *SessionManager) Close() error {
- return sm.store.Close()
+ return sm.Save()
}
-
-// Package-level convenience functions that use the global Manager
-
-func Create(userID int, username, email string) *Session {
- return Manager.Create(userID, username, email)
-}
-
-func Get(sessionID string) (*Session, bool) {
- return Manager.Get(sessionID)
-}
-
-func GetFromContext(ctx router.Ctx) (*Session, bool) {
- return Manager.GetFromContext(ctx)
-}
-
-func Update(sessionID string) bool {
- return Manager.Update(sessionID)
-}
-
-func Delete(sessionID string) {
- Manager.Delete(sessionID)
-}
-
-func SetSessionCookie(ctx router.Ctx, sessionID string) {
- Manager.SetSessionCookie(ctx, sessionID)
-}
-
-func DeleteSessionCookie(ctx router.Ctx) {
- Manager.DeleteSessionCookie(ctx)
-}
-
-func SetFlashMessage(ctx router.Ctx, msgType, message string) bool {
- return Manager.SetFlashMessage(ctx, msgType, message)
-}
-
-func GetFlashMessage(ctx router.Ctx) *FlashMessage {
- return Manager.GetFlashMessage(ctx)
-}
-
-func SetFormData(ctx router.Ctx, data map[string]string) bool {
- return Manager.SetFormData(ctx, data)
-}
-
-func GetFormData(ctx router.Ctx) map[string]string {
- return Manager.GetFormData(ctx)
-}
-
-func Stats() (total, active int) {
- return Manager.Stats()
-}
-
-func Close() error {
- return Manager.Close()
-}
\ No newline at end of file
diff --git a/internal/session/session.go b/internal/session/session.go
index 96fbc2a..65c3d99 100644
--- a/internal/session/session.go
+++ b/internal/session/session.go
@@ -1,5 +1,4 @@
-// Package session provides session management functionality.
-// It includes session storage, flash messages, and data persistence.
+// session.go
package session
import (
@@ -13,62 +12,97 @@ const (
IDLength = 32
)
+// Session represents a user session
type Session struct {
- ID string `json:"-"`
- UserID int `json:"user_id"`
- Username string `json:"username"`
- Email string `json:"email"`
- CreatedAt time.Time `json:"created_at"`
+ ID string `json:"id"`
+ UserID int `json:"user_id"` // 0 for guest sessions
ExpiresAt time.Time `json:"expires_at"`
- LastSeen time.Time `json:"last_seen"`
- Data map[string]any `json:"data,omitempty"`
+ Data map[string]any `json:"data"`
}
-func New(userID int, username, email string) *Session {
+// New creates a new session
+func New(userID int) *Session {
return &Session{
ID: generateID(),
UserID: userID,
- Username: username,
- Email: email,
- CreatedAt: time.Now(),
ExpiresAt: time.Now().Add(DefaultExpiration),
- LastSeen: time.Now(),
Data: make(map[string]any),
}
}
+// IsExpired checks if the session has expired
func (s *Session) IsExpired() bool {
return time.Now().After(s.ExpiresAt)
}
+// Touch extends the session expiration
func (s *Session) Touch() {
- s.LastSeen = time.Now()
s.ExpiresAt = time.Now().Add(DefaultExpiration)
}
+// Set stores a value in the session
func (s *Session) Set(key string, value any) {
- if s.Data == nil {
- s.Data = make(map[string]any)
- }
s.Data[key] = value
}
+// Get retrieves a value from the session
func (s *Session) Get(key string) (any, bool) {
- if s.Data == nil {
- return nil, false
- }
value, exists := s.Data[key]
return value, exists
}
+// Delete removes a value from the session
func (s *Session) Delete(key string) {
- if s.Data != nil {
- delete(s.Data, key)
- }
+ delete(s.Data, key)
}
+// SetFlash stores a flash message (consumed on next Get)
+func (s *Session) SetFlash(key string, value any) {
+ s.Set("flash_"+key, value)
+}
+
+// GetFlash retrieves and removes a flash message
+func (s *Session) GetFlash(key string) (any, bool) {
+ flashKey := "flash_" + key
+ value, exists := s.Get(flashKey)
+ if exists {
+ s.Delete(flashKey)
+ }
+ return value, exists
+}
+
+// generateID creates a random session ID
func generateID() string {
bytes := make([]byte, IDLength)
rand.Read(bytes)
return hex.EncodeToString(bytes)
}
+
+// Package-level convenience functions
+func Create(userID int) *Session {
+ return Manager.Create(userID)
+}
+
+func Get(sessionID string) (*Session, bool) {
+ return Manager.Get(sessionID)
+}
+
+func Store(sess *Session) {
+ Manager.Store(sess)
+}
+
+func Delete(sessionID string) {
+ Manager.Delete(sessionID)
+}
+
+func Cleanup() {
+ Manager.Cleanup()
+}
+
+func Stats() (total, active int) {
+ return Manager.Stats()
+}
+
+func Close() error {
+ return Manager.Close()
+}
diff --git a/internal/session/store.go b/internal/session/store.go
deleted file mode 100644
index 6a5a410..0000000
--- a/internal/session/store.go
+++ /dev/null
@@ -1,161 +0,0 @@
-package session
-
-import (
- "encoding/json"
- "maps"
- "os"
- "sync"
- "time"
-)
-
-type Store struct {
- mu sync.RWMutex
- sessions map[string]*Session
- filePath string
- saveInterval time.Duration
- stopChan chan struct{}
-}
-
-type persistedData struct {
- Sessions map[string]*Session `json:"sessions"`
- SavedAt time.Time `json:"saved_at"`
-}
-
-func NewStore(filePath string) *Store {
- store := &Store{
- sessions: make(map[string]*Session),
- filePath: filePath,
- saveInterval: 5 * time.Minute,
- stopChan: make(chan struct{}),
- }
-
- store.loadFromFile()
- store.startPeriodicSave()
-
- return store
-}
-
-func (s *Store) Save(session *Session) {
- s.mu.Lock()
- defer s.mu.Unlock()
- s.sessions[session.ID] = session
-}
-
-func (s *Store) Get(sessionID string) (*Session, bool) {
- s.mu.RLock()
- defer s.mu.RUnlock()
-
- session, exists := s.sessions[sessionID]
- if !exists {
- return nil, false
- }
-
- if session.IsExpired() {
- return nil, false
- }
-
- return session, true
-}
-
-func (s *Store) Delete(sessionID string) {
- s.mu.Lock()
- defer s.mu.Unlock()
- delete(s.sessions, sessionID)
-}
-
-func (s *Store) Cleanup() {
- s.mu.Lock()
- defer s.mu.Unlock()
-
- for id, session := range s.sessions {
- if session.IsExpired() {
- delete(s.sessions, id)
- }
- }
-}
-
-func (s *Store) Stats() (total, active int) {
- s.mu.RLock()
- defer s.mu.RUnlock()
-
- total = len(s.sessions)
- for _, session := range s.sessions {
- if !session.IsExpired() {
- active++
- }
- }
-
- return
-}
-
-func (s *Store) loadFromFile() {
- if s.filePath == "" {
- return
- }
-
- data, err := os.ReadFile(s.filePath)
- if err != nil {
- return
- }
-
- var persisted persistedData
- if err := json.Unmarshal(data, &persisted); err != nil {
- return
- }
-
- s.mu.Lock()
- defer s.mu.Unlock()
-
- for id, session := range persisted.Sessions {
- if !session.IsExpired() {
- session.ID = id
- s.sessions[id] = session
- }
- }
-}
-
-func (s *Store) saveToFile() error {
- if s.filePath == "" {
- return nil
- }
-
- s.mu.RLock()
- sessionsCopy := make(map[string]*Session, len(s.sessions))
- maps.Copy(sessionsCopy, s.sessions)
- s.mu.RUnlock()
-
- data := persistedData{
- Sessions: sessionsCopy,
- SavedAt: time.Now(),
- }
-
- jsonData, err := json.MarshalIndent(data, "", " ")
- if err != nil {
- return err
- }
-
- return os.WriteFile(s.filePath, jsonData, 0600)
-}
-
-func (s *Store) startPeriodicSave() {
- go func() {
- ticker := time.NewTicker(s.saveInterval)
- defer ticker.Stop()
-
- for {
- select {
- case <-ticker.C:
- s.Cleanup()
- s.saveToFile()
- case <-s.stopChan:
- s.saveToFile()
- return
- }
- }
- }()
-}
-
-func (s *Store) Close() error {
- close(s.stopChan)
- return s.saveToFile()
-}
diff --git a/internal/template/components/page.go b/internal/template/components/page.go
index 3274f3b..72de21a 100644
--- a/internal/template/components/page.go
+++ b/internal/template/components/page.go
@@ -3,6 +3,7 @@ package components
import (
"fmt"
"maps"
+ "runtime"
"strings"
"dk/internal/csrf"
@@ -22,15 +23,18 @@ func RenderPage(ctx router.Ctx, title, tmplPath string, additionalData map[strin
return fmt.Errorf("failed to load layout template: %w", err)
}
+ var m runtime.MemStats
+ runtime.ReadMemStats(&m)
+
data := map[string]any{
"_title": PageTitle(title),
"authenticated": middleware.IsAuthenticated(ctx),
"csrf": csrf.HiddenField(ctx),
"_totaltime": middleware.GetRequestTime(ctx),
- "_numqueries": 0,
"_version": "1.0.0",
"_build": "dev",
"user": middleware.GetCurrentUser(ctx),
+ "_memalloc": m.Alloc / 1024 / 1024,
}
maps.Copy(data, LeftAside(ctx))
diff --git a/main.go b/main.go
index c689b86..ac44bcc 100644
--- a/main.go
+++ b/main.go
@@ -158,15 +158,14 @@ func start(port string) error {
if err != nil {
return fmt.Errorf("failed to get current working directory: %w", err)
}
- // Initialize template singleton
+
template.InitializeCache(cwd)
- // Load all model data into memory
if err := loadModels(); err != nil {
return fmt.Errorf("failed to load models: %w", err)
}
- session.Init("sessions.json") // Initialize session.Manager
+ session.Init("sessions.json")
r := router.New()
r.Use(middleware.Timing())
@@ -174,8 +173,8 @@ func start(port string) error {
r.Use(middleware.CSRF())
r.Get("/", routes.Index)
- r.Use(middleware.RequireAuth()).Get("/explore", routes.Explore)
- r.Use(middleware.RequireAuth()).Post("/move", routes.Move)
+ r.WithMiddleware(middleware.RequireAuth()).Get("/explore", routes.Explore)
+ r.WithMiddleware(middleware.RequireAuth()).Post("/move", routes.Move)
routes.RegisterAuthRoutes(r)
routes.RegisterTownRoutes(r)
diff --git a/templates/layout.html b/templates/layout.html
index c9b4465..5b1996f 100644
--- a/templates/layout.html
+++ b/templates/layout.html
@@ -44,7 +44,7 @@
© 2025 Sharkk
- {_totaltime} Seconds, {_numqueries} Queries
+ {_totaltime} Seconds, {_memalloc} MiB
Version {_version} {_build}