diff --git a/internal/csrf/csrf_test.go b/internal/csrf/csrf_test.go deleted file mode 100644 index d2d8f0c..0000000 --- a/internal/csrf/csrf_test.go +++ /dev/null @@ -1,167 +0,0 @@ -package csrf - -import ( - "testing" - "time" - - "dk/internal/session" - - "github.com/valyala/fasthttp" -) - -func TestGenerateToken(t *testing.T) { - sess := &session.Session{ - ID: "test-session", - UserID: 1, - Username: "testuser", - Email: "test@example.com", - CreatedAt: time.Now(), - ExpiresAt: time.Now().Add(time.Hour), - LastSeen: time.Now(), - Data: make(map[string]any), - } - - ctx := &fasthttp.RequestCtx{} - ctx.SetUserValue(SessionCtxKey, sess) - - token := GenerateToken(ctx) - - if token == "" { - t.Error("Expected non-empty token") - } - - storedToken := GetStoredToken(sess) - if storedToken != token { - t.Errorf("Expected stored token %s, got %s", token, storedToken) - } -} - -func TestValidateToken(t *testing.T) { - sess := &session.Session{ - ID: "test-session", - UserID: 1, - Username: "testuser", - Email: "test@example.com", - Data: map[string]any{SessionKey: "test-token"}, - } - - ctx := &fasthttp.RequestCtx{} - ctx.SetUserValue(SessionCtxKey, sess) - - if !ValidateToken(ctx, "test-token") { - t.Error("Expected valid token to pass validation") - } - - if ValidateToken(ctx, "wrong-token") { - t.Error("Expected invalid token to fail validation") - } - - if ValidateToken(ctx, "") { - t.Error("Expected empty token to fail validation") - } -} - -func TestValidateTokenNoSession(t *testing.T) { - ctx := &fasthttp.RequestCtx{} - - if ValidateToken(ctx, "any-token") { - t.Error("Expected validation to fail with no session") - } -} - -func TestHiddenField(t *testing.T) { - sess := &session.Session{ - ID: "test-session", - UserID: 1, - Username: "testuser", - Email: "test@example.com", - Data: map[string]any{SessionKey: "test-token"}, - } - - ctx := &fasthttp.RequestCtx{} - ctx.SetUserValue(SessionCtxKey, sess) - - field := HiddenField(ctx) - expected := `` - - if field != expected { - t.Errorf("Expected %s, got %s", expected, field) - } -} - -func TestHiddenFieldNoSession(t *testing.T) { - ctx := &fasthttp.RequestCtx{} - - field := HiddenField(ctx) - if field == "" { - t.Error("Expected non-empty field for guest user with cookie-based token") - } -} - -func TestTokenMeta(t *testing.T) { - sess := &session.Session{ - ID: "test-session", - UserID: 1, - Username: "testuser", - Email: "test@example.com", - Data: map[string]any{SessionKey: "test-token"}, - } - - ctx := &fasthttp.RequestCtx{} - ctx.SetUserValue(SessionCtxKey, sess) - - meta := TokenMeta(ctx) - expected := `` - - if meta != expected { - t.Errorf("Expected %s, got %s", expected, meta) - } -} - -func TestStoreAndGetToken(t *testing.T) { - sess := &session.Session{ - Data: make(map[string]any), - } - - token := "test-token" - StoreToken(sess, token) - - retrieved := GetStoredToken(sess) - if retrieved != token { - t.Errorf("Expected %s, got %s", token, retrieved) - } -} - -func TestGetStoredTokenNoData(t *testing.T) { - sess := &session.Session{} - - token := GetStoredToken(sess) - if token != "" { - t.Errorf("Expected empty token, got %s", token) - } -} - -func TestValidateFormToken(t *testing.T) { - sess := &session.Session{ - ID: "test-session", - UserID: 1, - Username: "testuser", - Email: "test@example.com", - Data: map[string]any{SessionKey: "test-token"}, - } - - ctx := &fasthttp.RequestCtx{} - ctx.SetUserValue(SessionCtxKey, sess) - - ctx.PostArgs().Set(TokenFieldName, "test-token") - - if !ValidateFormToken(ctx) { - t.Error("Expected form token validation to pass") - } - - ctx.PostArgs().Set(TokenFieldName, "wrong-token") - - if ValidateFormToken(ctx) { - t.Error("Expected form token validation to fail with wrong token") - } -} \ No newline at end of file diff --git a/internal/middleware/auth.go b/internal/middleware/auth.go index 13b4b2f..809b38d 100644 --- a/internal/middleware/auth.go +++ b/internal/middleware/auth.go @@ -5,29 +5,43 @@ import ( "dk/internal/models/users" "dk/internal/router" "dk/internal/session" + "fmt" + "time" "github.com/valyala/fasthttp" ) +const SessionCookieName = "dk_session" + func Auth() router.Middleware { return func(next router.Handler) router.Handler { return func(ctx router.Ctx, params []string) { - sessionID := cookies.GetCookie(ctx, session.SessionCookieName) + sessionID := cookies.GetCookie(ctx, SessionCookieName) + var sess *session.Session if sessionID != "" { - if sess, exists := session.Get(sessionID); exists { - session.Update(sessionID) + if existingSess, exists := session.Get(sessionID); exists { + sess = existingSess + sess.Touch() + session.Store(sess) - user, err := users.Find(sess.UserID) - if err == nil && user != nil { - ctx.SetUserValue("session", sess) - ctx.SetUserValue("user", user) - - session.SetSessionCookie(ctx, sessionID) + if sess.UserID > 0 { // User session + user, err := users.Find(sess.UserID) + if err == nil && user != nil { + ctx.SetUserValue("user", user) + setSessionCookie(ctx, sessionID) + } } } } + // Create guest session if none exists + if sess == nil { + sess = session.Create(0) // Guest session + setSessionCookie(ctx, sess.ID) + } + + ctx.SetUserValue("session", sess) next(ctx, params) } } @@ -64,6 +78,7 @@ func RequireGuest(paths ...string) router.Middleware { return func(next router.Handler) router.Handler { return func(ctx router.Ctx, params []string) { if IsAuthenticated(ctx) { + fmt.Println("RequireGuest: user is authenticated") ctx.Redirect(redirect, fasthttp.StatusFound) return } @@ -92,21 +107,38 @@ func GetCurrentSession(ctx router.Ctx) *session.Session { } func Login(ctx router.Ctx, user *users.User) { - sess := session.Create(user.ID, user.Username, user.Email) - session.SetSessionCookie(ctx, sess.ID) + sess := session.Create(user.ID) + setSessionCookie(ctx, sess.ID) ctx.SetUserValue("session", sess) ctx.SetUserValue("user", user) } func Logout(ctx router.Ctx) { - sessionID := cookies.GetCookie(ctx, session.SessionCookieName) + sessionID := cookies.GetCookie(ctx, SessionCookieName) if sessionID != "" { session.Delete(sessionID) } - session.DeleteSessionCookie(ctx) + deleteSessionCookie(ctx) ctx.SetUserValue("session", nil) ctx.SetUserValue("user", nil) } + +// Helper functions for session cookies +func setSessionCookie(ctx router.Ctx, sessionID string) { + cookies.SetSecureCookie(ctx, cookies.CookieOptions{ + Name: SessionCookieName, + Value: sessionID, + Path: "/", + Expires: time.Now().Add(24 * time.Hour), + HTTPOnly: true, + Secure: cookies.IsHTTPS(ctx), + SameSite: "lax", + }) +} + +func deleteSessionCookie(ctx router.Ctx) { + cookies.DeleteCookie(ctx, SessionCookieName) +} diff --git a/internal/routes/auth.go b/internal/routes/auth.go index 5442ffe..3320af3 100644 --- a/internal/routes/auth.go +++ b/internal/routes/auth.go @@ -18,36 +18,39 @@ import ( // RegisterAuthRoutes sets up authentication routes func RegisterAuthRoutes(r *router.Router) { - // Guest routes - guestGroup := r.Group("") - guestGroup.Use(middleware.RequireGuest()) + guests := r.Group("") + guests.Use(middleware.RequireGuest()) - guestGroup.Get("/login", showLogin) - guestGroup.Post("/login", processLogin) - guestGroup.Get("/register", showRegister) - guestGroup.Post("/register", processRegister) + guests.Get("/login", showLogin) + guests.Post("/login", processLogin) + guests.Get("/register", showRegister) + guests.Post("/register", processRegister) - // Authenticated routes - authGroup := r.Group("") - authGroup.Use(middleware.RequireAuth()) + authed := r.Group("") + authed.Use(middleware.RequireAuth()) - authGroup.Post("/logout", processLogout) + authed.Post("/logout", processLogout) } // showLogin displays the login form func showLogin(ctx router.Ctx, _ []string) { - // Get flash message if any + sess := ctx.UserValue("session").(*session.Session) var errorHTML string - if flash := session.GetFlashMessage(ctx); flash != nil { - errorHTML = fmt.Sprintf(`
%s
`, flash.Message) + var id string + + if flash, exists := sess.GetFlash("error"); exists { + if msg, ok := flash.(string); ok { + errorHTML = fmt.Sprintf(`
%s
`, msg) + } } - // Get form data if any (for preserving email/username on error) - formData := session.GetFormData(ctx) - id := "" - if formData != nil { - id = formData["id"] + if formData, exists := sess.Get("form_data"); exists { + if data, ok := formData.(map[string]string); ok { + id = data["id"] + } } + sess.Delete("form_data") + session.Store(sess) components.RenderPage(ctx, "Log In", "auth/login.html", map[string]any{ "error_message": errorHTML, @@ -67,26 +70,30 @@ func processLogin(ctx router.Ctx, _ []string) { userPassword := string(ctx.PostArgs().Peek("password")) if email == "" || userPassword == "" { - session.SetFlashMessage(ctx, "error", "Email and password are required") - session.SetFormData(ctx, map[string]string{"id": email}) + setFlashAndFormData(ctx, "Email and password are required", map[string]string{"id": email}) ctx.Redirect("/login", fasthttp.StatusFound) return } user, err := auth.Authenticate(email, userPassword) if err != nil { - session.SetFlashMessage(ctx, "error", "Invalid email or password") - session.SetFormData(ctx, map[string]string{"id": email}) + setFlashAndFormData(ctx, "Invalid email or password", map[string]string{"id": email}) ctx.Redirect("/login", fasthttp.StatusFound) return } middleware.Login(ctx, user) + // Set success message + if sess := ctx.UserValue("session").(*session.Session); sess != nil { + sess.SetFlash("success", fmt.Sprintf("Welcome back, %s!", user.Username)) + session.Store(sess) + } + // Transfer CSRF token from cookie to session for authenticated user if cookieToken := csrf.GetTokenFromCookie(ctx); cookieToken != "" { - if session := csrf.GetCurrentSession(ctx); session != nil { - csrf.StoreToken(session, cookieToken) + if sess := ctx.UserValue("session").(*session.Session); sess != nil { + csrf.StoreToken(sess, cookieToken) } } @@ -95,20 +102,24 @@ func processLogin(ctx router.Ctx, _ []string) { // showRegister displays the registration form func showRegister(ctx router.Ctx, _ []string) { - // Get flash message if any + sess := ctx.UserValue("session").(*session.Session) var errorHTML string - if flash := session.GetFlashMessage(ctx); flash != nil { - errorHTML = fmt.Sprintf(`
%s
`, flash.Message) + var username, email string + + if flash, exists := sess.GetFlash("error"); exists { + if msg, ok := flash.(string); ok { + errorHTML = fmt.Sprintf(`
%s
`, msg) + } } - // Get form data if any (for preserving values on error) - formData := session.GetFormData(ctx) - username := "" - email := "" - if formData != nil { - username = formData["username"] - email = formData["email"] + if formData, exists := sess.Get("form_data"); exists { + if data, ok := formData.(map[string]string); ok { + username = data["username"] + email = data["email"] + } } + sess.Delete("form_data") + session.Store(sess) components.RenderPage(ctx, "Register", "auth/register.html", map[string]any{ "error_message": errorHTML, @@ -130,32 +141,25 @@ func processRegister(ctx router.Ctx, _ []string) { userPassword := string(ctx.PostArgs().Peek("password")) confirmPassword := string(ctx.PostArgs().Peek("confirm_password")) + formData := map[string]string{ + "username": username, + "email": email, + } + if err := validateRegistration(username, email, userPassword, confirmPassword); err != nil { - session.SetFlashMessage(ctx, "error", err.Error()) - session.SetFormData(ctx, map[string]string{ - "username": username, - "email": email, - }) + setFlashAndFormData(ctx, err.Error(), formData) ctx.Redirect("/register", fasthttp.StatusFound) return } if _, err := users.ByUsername(username); err == nil { - session.SetFlashMessage(ctx, "error", "Username already exists") - session.SetFormData(ctx, map[string]string{ - "username": username, - "email": email, - }) + setFlashAndFormData(ctx, "Username already exists", formData) ctx.Redirect("/register", fasthttp.StatusFound) return } if _, err := users.ByEmail(email); err == nil { - session.SetFlashMessage(ctx, "error", "Email already registered") - session.SetFormData(ctx, map[string]string{ - "username": username, - "email": email, - }) + setFlashAndFormData(ctx, "Email already registered", formData) ctx.Redirect("/register", fasthttp.StatusFound) return } @@ -168,11 +172,7 @@ func processRegister(ctx router.Ctx, _ []string) { user.Auth = 1 if err := user.Insert(); err != nil { - session.SetFlashMessage(ctx, "error", "Failed to create account") - session.SetFormData(ctx, map[string]string{ - "username": username, - "email": email, - }) + setFlashAndFormData(ctx, "Failed to create account", formData) ctx.Redirect("/register", fasthttp.StatusFound) return } @@ -180,10 +180,16 @@ func processRegister(ctx router.Ctx, _ []string) { // Auto-login after registration middleware.Login(ctx, user) + // Set success message + if sess := ctx.UserValue("session").(*session.Session); sess != nil { + sess.SetFlash("success", fmt.Sprintf("Greetings, %s!", user.Username)) + session.Store(sess) + } + // Transfer CSRF token from cookie to session for authenticated user if cookieToken := csrf.GetTokenFromCookie(ctx); cookieToken != "" { - if session := csrf.GetCurrentSession(ctx); session != nil { - csrf.StoreToken(session, cookieToken) + if sess := ctx.UserValue("session").(*session.Session); sess != nil { + csrf.StoreToken(sess, cookieToken) } } @@ -229,3 +235,10 @@ func validateRegistration(username, email, password, confirmPassword string) err } return nil } + +func setFlashAndFormData(ctx router.Ctx, message string, formData map[string]string) { + sess := ctx.UserValue("session").(*session.Session) + sess.SetFlash("error", message) + sess.Set("form_data", formData) + session.Store(sess) +} diff --git a/internal/routes/town.go b/internal/routes/town.go index cbd713b..2ed670d 100644 --- a/internal/routes/town.go +++ b/internal/routes/town.go @@ -49,9 +49,13 @@ func showTown(ctx router.Ctx, _ []string) { } func showInn(ctx router.Ctx, _ []string) { + sess := ctx.UserValue("session").(*session.Session) var errorHTML string - if flash := session.GetFlashMessage(ctx); flash != nil { - errorHTML = `
` + flash.Message + "
" + + if flash, exists := sess.GetFlash("error"); exists { + if msg, ok := flash.(string); ok { + errorHTML = `
` + msg + "
" + } } town := ctx.UserValue("town").(*towns.Town) @@ -64,11 +68,12 @@ func showInn(ctx router.Ctx, _ []string) { } func rest(ctx router.Ctx, _ []string) { + sess := ctx.UserValue("session").(*session.Session) town := ctx.UserValue("town").(*towns.Town) user := ctx.UserValue("user").(*users.User) if user.Gold < town.InnCost { - session.SetFlashMessage(ctx, "error", "You can't afford to stay here tonight.") + sess.SetFlash("error", "You can't afford to stay here tonight.") ctx.Redirect("/town/inn", 303) return } @@ -83,9 +88,13 @@ func rest(ctx router.Ctx, _ []string) { } func showShop(ctx router.Ctx, _ []string) { + sess := ctx.UserValue("session").(*session.Session) var errorHTML string - if flash := session.GetFlashMessage(ctx); flash != nil { - errorHTML = `
` + flash.Message + "
" + + if flash, exists := sess.GetFlash("error"); exists { + if msg, ok := flash.(string); ok { + errorHTML = `
` + msg + "
" + } } town := ctx.UserValue("town").(*towns.Town) @@ -109,30 +118,32 @@ func showShop(ctx router.Ctx, _ []string) { } func buyItem(ctx router.Ctx, params []string) { + sess := ctx.UserValue("session").(*session.Session) + id, err := strconv.Atoi(params[0]) if err != nil { - session.SetFlashMessage(ctx, "error", "Error purchasing item; "+err.Error()) + sess.SetFlash("error", "Error purchasing item; "+err.Error()) ctx.Redirect("/town/shop", 302) return } town := ctx.UserValue("town").(*towns.Town) if !slices.Contains(town.GetShopItems(), id) { - session.SetFlashMessage(ctx, "error", "The item doesn't exist in this shop.") + sess.SetFlash("error", "The item doesn't exist in this shop.") ctx.Redirect("/town/shop", 302) return } item, err := items.Find(id) if err != nil { - session.SetFlashMessage(ctx, "error", "Error purchasing item; "+err.Error()) + sess.SetFlash("error", "Error purchasing item; "+err.Error()) ctx.Redirect("/town/shop", 302) return } user := ctx.UserValue("user").(*users.User) if user.Gold < item.Value { - session.SetFlashMessage(ctx, "error", "You don't have enough gold to buy "+item.Name) + sess.SetFlash("error", "You don't have enough gold to buy "+item.Name) ctx.Redirect("/town/shop", 302) return } @@ -145,9 +156,13 @@ func buyItem(ctx router.Ctx, params []string) { } func showMaps(ctx router.Ctx, _ []string) { + sess := ctx.UserValue("session").(*session.Session) var errorHTML string - if flash := session.GetFlashMessage(ctx); flash != nil { - errorHTML = `
` + flash.Message + "
" + + if flash, exists := sess.GetFlash("error"); exists { + if msg, ok := flash.(string); ok { + errorHTML = `
` + msg + "
" + } } town := ctx.UserValue("town").(*towns.Town) @@ -186,23 +201,25 @@ func showMaps(ctx router.Ctx, _ []string) { } func buyMap(ctx router.Ctx, params []string) { + sess := ctx.UserValue("session").(*session.Session) + id, err := strconv.Atoi(params[0]) if err != nil { - session.SetFlashMessage(ctx, "error", "Error purchasing map; "+err.Error()) + sess.SetFlash("error", "Error purchasing map; "+err.Error()) ctx.Redirect("/town/maps", 302) return } mapped, err := towns.Find(id) if err != nil { - session.SetFlashMessage(ctx, "error", "Error purchasing map; "+err.Error()) + sess.SetFlash("error", "Error purchasing map; "+err.Error()) ctx.Redirect("/town/maps", 302) return } user := ctx.UserValue("user").(*users.User) if user.Gold < mapped.MapCost { - session.SetFlashMessage(ctx, "error", "You don't have enough gold to buy the map to "+mapped.Name) + sess.SetFlash("error", "You don't have enough gold to buy the map to "+mapped.Name) ctx.Redirect("/town/maps", 302) return } diff --git a/internal/session/flash.go b/internal/session/flash.go deleted file mode 100644 index 9c96036..0000000 --- a/internal/session/flash.go +++ /dev/null @@ -1,56 +0,0 @@ -package session - -type FlashMessage struct { - Type string `json:"type"` - Message string `json:"message"` -} - -func (s *Session) SetFlash(key string, value any) { - if s.Data == nil { - s.Data = make(map[string]any) - } - - flashData, ok := s.Data["_flash"].(map[string]any) - if !ok { - flashData = make(map[string]any) - } - flashData[key] = value - s.Data["_flash"] = flashData -} - -func (s *Session) GetFlash(key string) (any, bool) { - if s.Data == nil { - return nil, false - } - - flashData, ok := s.Data["_flash"].(map[string]any) - if !ok { - return nil, false - } - - value, exists := flashData[key] - if exists { - delete(flashData, key) - if len(flashData) == 0 { - delete(s.Data, "_flash") - } else { - s.Data["_flash"] = flashData - } - } - - return value, exists -} - -func (s *Session) GetAllFlash() map[string]any { - if s.Data == nil { - return nil - } - - flashData, ok := s.Data["_flash"].(map[string]any) - if !ok { - return nil - } - - delete(s.Data, "_flash") - return flashData -} \ No newline at end of file diff --git a/internal/session/manager.go b/internal/session/manager.go index 6a47590..69ab098 100644 --- a/internal/session/manager.go +++ b/internal/session/manager.go @@ -1,28 +1,35 @@ package session import ( - "dk/internal/cookies" - "dk/internal/router" - "time" + "encoding/json" + "os" + "sync" ) -const SessionCookieName = "dk_session" +// SessionManager handles session storage and persistence +type SessionManager struct { + mu sync.RWMutex + sessions map[string]*Session + filePath string +} var Manager *SessionManager -type SessionManager struct { - store *Store -} - -func Init(sessionsFilePath string) { +// Init initializes the global session manager +func Init(filePath string) { if Manager != nil { panic("session manager already initialized") } + Manager = &SessionManager{ - store: NewStore(sessionsFilePath), + sessions: make(map[string]*Session), + filePath: filePath, } + + Manager.load() } +// GetManager returns the global session manager func GetManager() *SessionManager { if Manager == nil { panic("session manager not initialized") @@ -30,200 +37,116 @@ func GetManager() *SessionManager { return Manager } -func (sm *SessionManager) Create(userID int, username, email string) *Session { - sess := New(userID, username, email) - sm.store.Save(sess) +// Create creates and stores a new session +func (sm *SessionManager) Create(userID int) *Session { + sess := New(userID) + sm.mu.Lock() + sm.sessions[sess.ID] = sess + sm.mu.Unlock() return sess } +// Get retrieves a session by ID func (sm *SessionManager) Get(sessionID string) (*Session, bool) { - return sm.store.Get(sessionID) -} + sm.mu.RLock() + sess, exists := sm.sessions[sessionID] + sm.mu.RUnlock() -func (sm *SessionManager) GetFromContext(ctx router.Ctx) (*Session, bool) { - sessionID := cookies.GetCookie(ctx, SessionCookieName) - if sessionID == "" { + if !exists || sess.IsExpired() { + if exists { + sm.Delete(sessionID) + } return nil, false } - return sm.Get(sessionID) + + return sess, true } -func (sm *SessionManager) Update(sessionID string) bool { - sess, exists := sm.store.Get(sessionID) - if !exists { - return false - } - - sess.Touch() - sm.store.Save(sess) - return true +// Store saves a session in memory (updates existing or creates new) +func (sm *SessionManager) Store(sess *Session) { + sm.mu.Lock() + sm.sessions[sess.ID] = sess + sm.mu.Unlock() } +// Delete removes a session func (sm *SessionManager) Delete(sessionID string) { - sm.store.Delete(sessionID) + sm.mu.Lock() + delete(sm.sessions, sessionID) + sm.mu.Unlock() } -func (sm *SessionManager) SetSessionCookie(ctx router.Ctx, sessionID string) { - cookies.SetSecureCookie(ctx, cookies.CookieOptions{ - Name: SessionCookieName, - Value: sessionID, - Path: "/", - Expires: time.Now().Add(DefaultExpiration), - HTTPOnly: true, - Secure: cookies.IsHTTPS(ctx), - SameSite: "lax", - }) -} - -func (sm *SessionManager) DeleteSessionCookie(ctx router.Ctx) { - cookies.DeleteCookie(ctx, SessionCookieName) -} - -func (sm *SessionManager) SetFlashMessage(ctx router.Ctx, msgType, message string) bool { - sess, exists := sm.GetFromContext(ctx) - if !exists { - return false - } - - sess.SetFlash("message", FlashMessage{ - Type: msgType, - Message: message, - }) - sm.store.Save(sess) - return true -} - -func (sm *SessionManager) GetFlashMessage(ctx router.Ctx) *FlashMessage { - sess, exists := sm.GetFromContext(ctx) - if !exists { - return nil - } - - value, exists := sess.GetFlash("message") - if !exists { - return nil - } - - sm.store.Save(sess) - - if msg, ok := value.(FlashMessage); ok { - return &msg - } - - if msgMap, ok := value.(map[string]interface{}); ok { - msg := &FlashMessage{} - if t, ok := msgMap["type"].(string); ok { - msg.Type = t +// Cleanup removes expired sessions +func (sm *SessionManager) Cleanup() { + sm.mu.Lock() + for id, sess := range sm.sessions { + if sess.IsExpired() { + delete(sm.sessions, id) } - if m, ok := msgMap["message"].(string); ok { - msg.Message = m - } - return msg } - - return nil -} - -func (sm *SessionManager) SetFormData(ctx router.Ctx, data map[string]string) bool { - sess, exists := sm.GetFromContext(ctx) - if !exists { - return false - } - - sess.Set("form_data", data) - sm.store.Save(sess) - return true -} - -func (sm *SessionManager) GetFormData(ctx router.Ctx) map[string]string { - sess, exists := sm.GetFromContext(ctx) - if !exists { - return nil - } - - value, exists := sess.Get("form_data") - if !exists { - return nil - } - - sess.Delete("form_data") - sm.store.Save(sess) - - if formData, ok := value.(map[string]string); ok { - return formData - } - - if formMap, ok := value.(map[string]interface{}); ok { - result := make(map[string]string) - for k, v := range formMap { - if str, ok := v.(string); ok { - result[k] = str - } - } - return result - } - - return nil + sm.mu.Unlock() } +// Stats returns session statistics func (sm *SessionManager) Stats() (total, active int) { - return sm.store.Stats() + sm.mu.RLock() + defer sm.mu.RUnlock() + + total = len(sm.sessions) + for _, sess := range sm.sessions { + if !sess.IsExpired() { + active++ + } + } + return } +// load reads sessions from the JSON file +func (sm *SessionManager) load() { + if sm.filePath == "" { + return + } + + data, err := os.ReadFile(sm.filePath) + if err != nil { + return // File doesn't exist or can't be read + } + + var sessions map[string]*Session + if err := json.Unmarshal(data, &sessions); err != nil { + return // Invalid JSON + } + + sm.mu.Lock() + for id, sess := range sessions { + if sess != nil && !sess.IsExpired() { + sess.ID = id // Ensure ID consistency + sm.sessions[id] = sess + } + } + sm.mu.Unlock() +} + +// Save writes sessions to the JSON file +func (sm *SessionManager) Save() error { + if sm.filePath == "" { + return nil + } + + sm.Cleanup() // Remove expired sessions before saving + + sm.mu.RLock() + data, err := json.MarshalIndent(sm.sessions, "", "\t") + sm.mu.RUnlock() + + if err != nil { + return err + } + + return os.WriteFile(sm.filePath, data, 0600) +} + +// Close saves sessions and cleans up func (sm *SessionManager) Close() error { - return sm.store.Close() + return sm.Save() } - -// Package-level convenience functions that use the global Manager - -func Create(userID int, username, email string) *Session { - return Manager.Create(userID, username, email) -} - -func Get(sessionID string) (*Session, bool) { - return Manager.Get(sessionID) -} - -func GetFromContext(ctx router.Ctx) (*Session, bool) { - return Manager.GetFromContext(ctx) -} - -func Update(sessionID string) bool { - return Manager.Update(sessionID) -} - -func Delete(sessionID string) { - Manager.Delete(sessionID) -} - -func SetSessionCookie(ctx router.Ctx, sessionID string) { - Manager.SetSessionCookie(ctx, sessionID) -} - -func DeleteSessionCookie(ctx router.Ctx) { - Manager.DeleteSessionCookie(ctx) -} - -func SetFlashMessage(ctx router.Ctx, msgType, message string) bool { - return Manager.SetFlashMessage(ctx, msgType, message) -} - -func GetFlashMessage(ctx router.Ctx) *FlashMessage { - return Manager.GetFlashMessage(ctx) -} - -func SetFormData(ctx router.Ctx, data map[string]string) bool { - return Manager.SetFormData(ctx, data) -} - -func GetFormData(ctx router.Ctx) map[string]string { - return Manager.GetFormData(ctx) -} - -func Stats() (total, active int) { - return Manager.Stats() -} - -func Close() error { - return Manager.Close() -} \ No newline at end of file diff --git a/internal/session/session.go b/internal/session/session.go index 96fbc2a..65c3d99 100644 --- a/internal/session/session.go +++ b/internal/session/session.go @@ -1,5 +1,4 @@ -// Package session provides session management functionality. -// It includes session storage, flash messages, and data persistence. +// session.go package session import ( @@ -13,62 +12,97 @@ const ( IDLength = 32 ) +// Session represents a user session type Session struct { - ID string `json:"-"` - UserID int `json:"user_id"` - Username string `json:"username"` - Email string `json:"email"` - CreatedAt time.Time `json:"created_at"` + ID string `json:"id"` + UserID int `json:"user_id"` // 0 for guest sessions ExpiresAt time.Time `json:"expires_at"` - LastSeen time.Time `json:"last_seen"` - Data map[string]any `json:"data,omitempty"` + Data map[string]any `json:"data"` } -func New(userID int, username, email string) *Session { +// New creates a new session +func New(userID int) *Session { return &Session{ ID: generateID(), UserID: userID, - Username: username, - Email: email, - CreatedAt: time.Now(), ExpiresAt: time.Now().Add(DefaultExpiration), - LastSeen: time.Now(), Data: make(map[string]any), } } +// IsExpired checks if the session has expired func (s *Session) IsExpired() bool { return time.Now().After(s.ExpiresAt) } +// Touch extends the session expiration func (s *Session) Touch() { - s.LastSeen = time.Now() s.ExpiresAt = time.Now().Add(DefaultExpiration) } +// Set stores a value in the session func (s *Session) Set(key string, value any) { - if s.Data == nil { - s.Data = make(map[string]any) - } s.Data[key] = value } +// Get retrieves a value from the session func (s *Session) Get(key string) (any, bool) { - if s.Data == nil { - return nil, false - } value, exists := s.Data[key] return value, exists } +// Delete removes a value from the session func (s *Session) Delete(key string) { - if s.Data != nil { - delete(s.Data, key) - } + delete(s.Data, key) } +// SetFlash stores a flash message (consumed on next Get) +func (s *Session) SetFlash(key string, value any) { + s.Set("flash_"+key, value) +} + +// GetFlash retrieves and removes a flash message +func (s *Session) GetFlash(key string) (any, bool) { + flashKey := "flash_" + key + value, exists := s.Get(flashKey) + if exists { + s.Delete(flashKey) + } + return value, exists +} + +// generateID creates a random session ID func generateID() string { bytes := make([]byte, IDLength) rand.Read(bytes) return hex.EncodeToString(bytes) } + +// Package-level convenience functions +func Create(userID int) *Session { + return Manager.Create(userID) +} + +func Get(sessionID string) (*Session, bool) { + return Manager.Get(sessionID) +} + +func Store(sess *Session) { + Manager.Store(sess) +} + +func Delete(sessionID string) { + Manager.Delete(sessionID) +} + +func Cleanup() { + Manager.Cleanup() +} + +func Stats() (total, active int) { + return Manager.Stats() +} + +func Close() error { + return Manager.Close() +} diff --git a/internal/session/store.go b/internal/session/store.go deleted file mode 100644 index 6a5a410..0000000 --- a/internal/session/store.go +++ /dev/null @@ -1,161 +0,0 @@ -package session - -import ( - "encoding/json" - "maps" - "os" - "sync" - "time" -) - -type Store struct { - mu sync.RWMutex - sessions map[string]*Session - filePath string - saveInterval time.Duration - stopChan chan struct{} -} - -type persistedData struct { - Sessions map[string]*Session `json:"sessions"` - SavedAt time.Time `json:"saved_at"` -} - -func NewStore(filePath string) *Store { - store := &Store{ - sessions: make(map[string]*Session), - filePath: filePath, - saveInterval: 5 * time.Minute, - stopChan: make(chan struct{}), - } - - store.loadFromFile() - store.startPeriodicSave() - - return store -} - -func (s *Store) Save(session *Session) { - s.mu.Lock() - defer s.mu.Unlock() - s.sessions[session.ID] = session -} - -func (s *Store) Get(sessionID string) (*Session, bool) { - s.mu.RLock() - defer s.mu.RUnlock() - - session, exists := s.sessions[sessionID] - if !exists { - return nil, false - } - - if session.IsExpired() { - return nil, false - } - - return session, true -} - -func (s *Store) Delete(sessionID string) { - s.mu.Lock() - defer s.mu.Unlock() - delete(s.sessions, sessionID) -} - -func (s *Store) Cleanup() { - s.mu.Lock() - defer s.mu.Unlock() - - for id, session := range s.sessions { - if session.IsExpired() { - delete(s.sessions, id) - } - } -} - -func (s *Store) Stats() (total, active int) { - s.mu.RLock() - defer s.mu.RUnlock() - - total = len(s.sessions) - for _, session := range s.sessions { - if !session.IsExpired() { - active++ - } - } - - return -} - -func (s *Store) loadFromFile() { - if s.filePath == "" { - return - } - - data, err := os.ReadFile(s.filePath) - if err != nil { - return - } - - var persisted persistedData - if err := json.Unmarshal(data, &persisted); err != nil { - return - } - - s.mu.Lock() - defer s.mu.Unlock() - - for id, session := range persisted.Sessions { - if !session.IsExpired() { - session.ID = id - s.sessions[id] = session - } - } -} - -func (s *Store) saveToFile() error { - if s.filePath == "" { - return nil - } - - s.mu.RLock() - sessionsCopy := make(map[string]*Session, len(s.sessions)) - maps.Copy(sessionsCopy, s.sessions) - s.mu.RUnlock() - - data := persistedData{ - Sessions: sessionsCopy, - SavedAt: time.Now(), - } - - jsonData, err := json.MarshalIndent(data, "", " ") - if err != nil { - return err - } - - return os.WriteFile(s.filePath, jsonData, 0600) -} - -func (s *Store) startPeriodicSave() { - go func() { - ticker := time.NewTicker(s.saveInterval) - defer ticker.Stop() - - for { - select { - case <-ticker.C: - s.Cleanup() - s.saveToFile() - case <-s.stopChan: - s.saveToFile() - return - } - } - }() -} - -func (s *Store) Close() error { - close(s.stopChan) - return s.saveToFile() -} diff --git a/internal/template/components/page.go b/internal/template/components/page.go index 3274f3b..72de21a 100644 --- a/internal/template/components/page.go +++ b/internal/template/components/page.go @@ -3,6 +3,7 @@ package components import ( "fmt" "maps" + "runtime" "strings" "dk/internal/csrf" @@ -22,15 +23,18 @@ func RenderPage(ctx router.Ctx, title, tmplPath string, additionalData map[strin return fmt.Errorf("failed to load layout template: %w", err) } + var m runtime.MemStats + runtime.ReadMemStats(&m) + data := map[string]any{ "_title": PageTitle(title), "authenticated": middleware.IsAuthenticated(ctx), "csrf": csrf.HiddenField(ctx), "_totaltime": middleware.GetRequestTime(ctx), - "_numqueries": 0, "_version": "1.0.0", "_build": "dev", "user": middleware.GetCurrentUser(ctx), + "_memalloc": m.Alloc / 1024 / 1024, } maps.Copy(data, LeftAside(ctx)) diff --git a/main.go b/main.go index c689b86..ac44bcc 100644 --- a/main.go +++ b/main.go @@ -158,15 +158,14 @@ func start(port string) error { if err != nil { return fmt.Errorf("failed to get current working directory: %w", err) } - // Initialize template singleton + template.InitializeCache(cwd) - // Load all model data into memory if err := loadModels(); err != nil { return fmt.Errorf("failed to load models: %w", err) } - session.Init("sessions.json") // Initialize session.Manager + session.Init("sessions.json") r := router.New() r.Use(middleware.Timing()) @@ -174,8 +173,8 @@ func start(port string) error { r.Use(middleware.CSRF()) r.Get("/", routes.Index) - r.Use(middleware.RequireAuth()).Get("/explore", routes.Explore) - r.Use(middleware.RequireAuth()).Post("/move", routes.Move) + r.WithMiddleware(middleware.RequireAuth()).Get("/explore", routes.Explore) + r.WithMiddleware(middleware.RequireAuth()).Post("/move", routes.Move) routes.RegisterAuthRoutes(r) routes.RegisterTownRoutes(r) diff --git a/templates/layout.html b/templates/layout.html index c9b4465..5b1996f 100644 --- a/templates/layout.html +++ b/templates/layout.html @@ -44,7 +44,7 @@