Compare commits

..

No commits in common. "4a5f0debf68d615806ef1d6e146e785b4086632c" and "82ef4b31d486780f0cada81b5d83c0eb0814a9b0" have entirely different histories.

12 changed files with 683 additions and 321 deletions

167
internal/csrf/csrf_test.go Normal file
View File

@ -0,0 +1,167 @@
package csrf
import (
"testing"
"time"
"dk/internal/session"
"github.com/valyala/fasthttp"
)
func TestGenerateToken(t *testing.T) {
sess := &session.Session{
ID: "test-session",
UserID: 1,
Username: "testuser",
Email: "test@example.com",
CreatedAt: time.Now(),
ExpiresAt: time.Now().Add(time.Hour),
LastSeen: time.Now(),
Data: make(map[string]any),
}
ctx := &fasthttp.RequestCtx{}
ctx.SetUserValue(SessionCtxKey, sess)
token := GenerateToken(ctx)
if token == "" {
t.Error("Expected non-empty token")
}
storedToken := GetStoredToken(sess)
if storedToken != token {
t.Errorf("Expected stored token %s, got %s", token, storedToken)
}
}
func TestValidateToken(t *testing.T) {
sess := &session.Session{
ID: "test-session",
UserID: 1,
Username: "testuser",
Email: "test@example.com",
Data: map[string]any{SessionKey: "test-token"},
}
ctx := &fasthttp.RequestCtx{}
ctx.SetUserValue(SessionCtxKey, sess)
if !ValidateToken(ctx, "test-token") {
t.Error("Expected valid token to pass validation")
}
if ValidateToken(ctx, "wrong-token") {
t.Error("Expected invalid token to fail validation")
}
if ValidateToken(ctx, "") {
t.Error("Expected empty token to fail validation")
}
}
func TestValidateTokenNoSession(t *testing.T) {
ctx := &fasthttp.RequestCtx{}
if ValidateToken(ctx, "any-token") {
t.Error("Expected validation to fail with no session")
}
}
func TestHiddenField(t *testing.T) {
sess := &session.Session{
ID: "test-session",
UserID: 1,
Username: "testuser",
Email: "test@example.com",
Data: map[string]any{SessionKey: "test-token"},
}
ctx := &fasthttp.RequestCtx{}
ctx.SetUserValue(SessionCtxKey, sess)
field := HiddenField(ctx)
expected := `<input type="hidden" name="_csrf_token" value="test-token">`
if field != expected {
t.Errorf("Expected %s, got %s", expected, field)
}
}
func TestHiddenFieldNoSession(t *testing.T) {
ctx := &fasthttp.RequestCtx{}
field := HiddenField(ctx)
if field == "" {
t.Error("Expected non-empty field for guest user with cookie-based token")
}
}
func TestTokenMeta(t *testing.T) {
sess := &session.Session{
ID: "test-session",
UserID: 1,
Username: "testuser",
Email: "test@example.com",
Data: map[string]any{SessionKey: "test-token"},
}
ctx := &fasthttp.RequestCtx{}
ctx.SetUserValue(SessionCtxKey, sess)
meta := TokenMeta(ctx)
expected := `<meta name="csrf-token" content="test-token">`
if meta != expected {
t.Errorf("Expected %s, got %s", expected, meta)
}
}
func TestStoreAndGetToken(t *testing.T) {
sess := &session.Session{
Data: make(map[string]any),
}
token := "test-token"
StoreToken(sess, token)
retrieved := GetStoredToken(sess)
if retrieved != token {
t.Errorf("Expected %s, got %s", token, retrieved)
}
}
func TestGetStoredTokenNoData(t *testing.T) {
sess := &session.Session{}
token := GetStoredToken(sess)
if token != "" {
t.Errorf("Expected empty token, got %s", token)
}
}
func TestValidateFormToken(t *testing.T) {
sess := &session.Session{
ID: "test-session",
UserID: 1,
Username: "testuser",
Email: "test@example.com",
Data: map[string]any{SessionKey: "test-token"},
}
ctx := &fasthttp.RequestCtx{}
ctx.SetUserValue(SessionCtxKey, sess)
ctx.PostArgs().Set(TokenFieldName, "test-token")
if !ValidateFormToken(ctx) {
t.Error("Expected form token validation to pass")
}
ctx.PostArgs().Set(TokenFieldName, "wrong-token")
if ValidateFormToken(ctx) {
t.Error("Expected form token validation to fail with wrong token")
}
}

View File

@ -5,43 +5,29 @@ import (
"dk/internal/models/users" "dk/internal/models/users"
"dk/internal/router" "dk/internal/router"
"dk/internal/session" "dk/internal/session"
"fmt"
"time"
"github.com/valyala/fasthttp" "github.com/valyala/fasthttp"
) )
const SessionCookieName = "dk_session"
func Auth() router.Middleware { func Auth() router.Middleware {
return func(next router.Handler) router.Handler { return func(next router.Handler) router.Handler {
return func(ctx router.Ctx, params []string) { return func(ctx router.Ctx, params []string) {
sessionID := cookies.GetCookie(ctx, SessionCookieName) sessionID := cookies.GetCookie(ctx, session.SessionCookieName)
var sess *session.Session
if sessionID != "" { if sessionID != "" {
if existingSess, exists := session.Get(sessionID); exists { if sess, exists := session.Get(sessionID); exists {
sess = existingSess session.Update(sessionID)
sess.Touch()
session.Store(sess)
if sess.UserID > 0 { // User session user, err := users.Find(sess.UserID)
user, err := users.Find(sess.UserID) if err == nil && user != nil {
if err == nil && user != nil { ctx.SetUserValue("session", sess)
ctx.SetUserValue("user", user) ctx.SetUserValue("user", user)
setSessionCookie(ctx, sessionID)
} session.SetSessionCookie(ctx, sessionID)
} }
} }
} }
// Create guest session if none exists
if sess == nil {
sess = session.Create(0) // Guest session
setSessionCookie(ctx, sess.ID)
}
ctx.SetUserValue("session", sess)
next(ctx, params) next(ctx, params)
} }
} }
@ -78,7 +64,6 @@ func RequireGuest(paths ...string) router.Middleware {
return func(next router.Handler) router.Handler { return func(next router.Handler) router.Handler {
return func(ctx router.Ctx, params []string) { return func(ctx router.Ctx, params []string) {
if IsAuthenticated(ctx) { if IsAuthenticated(ctx) {
fmt.Println("RequireGuest: user is authenticated")
ctx.Redirect(redirect, fasthttp.StatusFound) ctx.Redirect(redirect, fasthttp.StatusFound)
return return
} }
@ -107,38 +92,21 @@ func GetCurrentSession(ctx router.Ctx) *session.Session {
} }
func Login(ctx router.Ctx, user *users.User) { func Login(ctx router.Ctx, user *users.User) {
sess := session.Create(user.ID) sess := session.Create(user.ID, user.Username, user.Email)
setSessionCookie(ctx, sess.ID) session.SetSessionCookie(ctx, sess.ID)
ctx.SetUserValue("session", sess) ctx.SetUserValue("session", sess)
ctx.SetUserValue("user", user) ctx.SetUserValue("user", user)
} }
func Logout(ctx router.Ctx) { func Logout(ctx router.Ctx) {
sessionID := cookies.GetCookie(ctx, SessionCookieName) sessionID := cookies.GetCookie(ctx, session.SessionCookieName)
if sessionID != "" { if sessionID != "" {
session.Delete(sessionID) session.Delete(sessionID)
} }
deleteSessionCookie(ctx) session.DeleteSessionCookie(ctx)
ctx.SetUserValue("session", nil) ctx.SetUserValue("session", nil)
ctx.SetUserValue("user", nil) ctx.SetUserValue("user", nil)
} }
// Helper functions for session cookies
func setSessionCookie(ctx router.Ctx, sessionID string) {
cookies.SetSecureCookie(ctx, cookies.CookieOptions{
Name: SessionCookieName,
Value: sessionID,
Path: "/",
Expires: time.Now().Add(24 * time.Hour),
HTTPOnly: true,
Secure: cookies.IsHTTPS(ctx),
SameSite: "lax",
})
}
func deleteSessionCookie(ctx router.Ctx) {
cookies.DeleteCookie(ctx, SessionCookieName)
}

View File

@ -87,19 +87,19 @@ func New() *User {
ClassID: 1, ClassID: 1,
Currently: "In Town", Currently: "In Town",
Fighting: 0, Fighting: 0,
HP: 10, HP: 15,
MP: 10, MP: 0,
TP: 10, TP: 10,
MaxHP: 10, MaxHP: 15,
MaxMP: 10, MaxMP: 0,
MaxTP: 10, MaxTP: 10,
Level: 1, Level: 1,
Gold: 100, Gold: 100,
Exp: 0, Exp: 0,
Strength: 0, Strength: 5,
Dexterity: 0, Dexterity: 5,
Attack: 0, Attack: 5,
Defense: 0, Defense: 5,
Spells: "", Spells: "",
Towns: "", Towns: "",
} }

View File

@ -18,39 +18,36 @@ import (
// RegisterAuthRoutes sets up authentication routes // RegisterAuthRoutes sets up authentication routes
func RegisterAuthRoutes(r *router.Router) { func RegisterAuthRoutes(r *router.Router) {
guests := r.Group("") // Guest routes
guests.Use(middleware.RequireGuest()) guestGroup := r.Group("")
guestGroup.Use(middleware.RequireGuest())
guests.Get("/login", showLogin) guestGroup.Get("/login", showLogin)
guests.Post("/login", processLogin) guestGroup.Post("/login", processLogin)
guests.Get("/register", showRegister) guestGroup.Get("/register", showRegister)
guests.Post("/register", processRegister) guestGroup.Post("/register", processRegister)
authed := r.Group("") // Authenticated routes
authed.Use(middleware.RequireAuth()) authGroup := r.Group("")
authGroup.Use(middleware.RequireAuth())
authed.Post("/logout", processLogout) authGroup.Post("/logout", processLogout)
} }
// showLogin displays the login form // showLogin displays the login form
func showLogin(ctx router.Ctx, _ []string) { func showLogin(ctx router.Ctx, _ []string) {
sess := ctx.UserValue("session").(*session.Session) // Get flash message if any
var errorHTML string var errorHTML string
var id string if flash := session.GetFlashMessage(ctx); flash != nil {
errorHTML = fmt.Sprintf(`<div style="color: red; margin-bottom: 1rem;">%s</div>`, flash.Message)
if flash, exists := sess.GetFlash("error"); exists {
if msg, ok := flash.(string); ok {
errorHTML = fmt.Sprintf(`<div style="color: red; margin-bottom: 1rem;">%s</div>`, msg)
}
} }
if formData, exists := sess.Get("form_data"); exists { // Get form data if any (for preserving email/username on error)
if data, ok := formData.(map[string]string); ok { formData := session.GetFormData(ctx)
id = data["id"] id := ""
} if formData != nil {
id = formData["id"]
} }
sess.Delete("form_data")
session.Store(sess)
components.RenderPage(ctx, "Log In", "auth/login.html", map[string]any{ components.RenderPage(ctx, "Log In", "auth/login.html", map[string]any{
"error_message": errorHTML, "error_message": errorHTML,
@ -70,30 +67,26 @@ func processLogin(ctx router.Ctx, _ []string) {
userPassword := string(ctx.PostArgs().Peek("password")) userPassword := string(ctx.PostArgs().Peek("password"))
if email == "" || userPassword == "" { if email == "" || userPassword == "" {
setFlashAndFormData(ctx, "Email and password are required", map[string]string{"id": email}) session.SetFlashMessage(ctx, "error", "Email and password are required")
session.SetFormData(ctx, map[string]string{"id": email})
ctx.Redirect("/login", fasthttp.StatusFound) ctx.Redirect("/login", fasthttp.StatusFound)
return return
} }
user, err := auth.Authenticate(email, userPassword) user, err := auth.Authenticate(email, userPassword)
if err != nil { if err != nil {
setFlashAndFormData(ctx, "Invalid email or password", map[string]string{"id": email}) session.SetFlashMessage(ctx, "error", "Invalid email or password")
session.SetFormData(ctx, map[string]string{"id": email})
ctx.Redirect("/login", fasthttp.StatusFound) ctx.Redirect("/login", fasthttp.StatusFound)
return return
} }
middleware.Login(ctx, user) middleware.Login(ctx, user)
// Set success message
if sess := ctx.UserValue("session").(*session.Session); sess != nil {
sess.SetFlash("success", fmt.Sprintf("Welcome back, %s!", user.Username))
session.Store(sess)
}
// Transfer CSRF token from cookie to session for authenticated user // Transfer CSRF token from cookie to session for authenticated user
if cookieToken := csrf.GetTokenFromCookie(ctx); cookieToken != "" { if cookieToken := csrf.GetTokenFromCookie(ctx); cookieToken != "" {
if sess := ctx.UserValue("session").(*session.Session); sess != nil { if session := csrf.GetCurrentSession(ctx); session != nil {
csrf.StoreToken(sess, cookieToken) csrf.StoreToken(session, cookieToken)
} }
} }
@ -102,24 +95,20 @@ func processLogin(ctx router.Ctx, _ []string) {
// showRegister displays the registration form // showRegister displays the registration form
func showRegister(ctx router.Ctx, _ []string) { func showRegister(ctx router.Ctx, _ []string) {
sess := ctx.UserValue("session").(*session.Session) // Get flash message if any
var errorHTML string var errorHTML string
var username, email string if flash := session.GetFlashMessage(ctx); flash != nil {
errorHTML = fmt.Sprintf(`<div style="color: red; margin-bottom: 1rem;">%s</div>`, flash.Message)
if flash, exists := sess.GetFlash("error"); exists {
if msg, ok := flash.(string); ok {
errorHTML = fmt.Sprintf(`<div style="color: red; margin-bottom: 1rem;">%s</div>`, msg)
}
} }
if formData, exists := sess.Get("form_data"); exists { // Get form data if any (for preserving values on error)
if data, ok := formData.(map[string]string); ok { formData := session.GetFormData(ctx)
username = data["username"] username := ""
email = data["email"] email := ""
} if formData != nil {
username = formData["username"]
email = formData["email"]
} }
sess.Delete("form_data")
session.Store(sess)
components.RenderPage(ctx, "Register", "auth/register.html", map[string]any{ components.RenderPage(ctx, "Register", "auth/register.html", map[string]any{
"error_message": errorHTML, "error_message": errorHTML,
@ -141,25 +130,32 @@ func processRegister(ctx router.Ctx, _ []string) {
userPassword := string(ctx.PostArgs().Peek("password")) userPassword := string(ctx.PostArgs().Peek("password"))
confirmPassword := string(ctx.PostArgs().Peek("confirm_password")) confirmPassword := string(ctx.PostArgs().Peek("confirm_password"))
formData := map[string]string{
"username": username,
"email": email,
}
if err := validateRegistration(username, email, userPassword, confirmPassword); err != nil { if err := validateRegistration(username, email, userPassword, confirmPassword); err != nil {
setFlashAndFormData(ctx, err.Error(), formData) session.SetFlashMessage(ctx, "error", err.Error())
session.SetFormData(ctx, map[string]string{
"username": username,
"email": email,
})
ctx.Redirect("/register", fasthttp.StatusFound) ctx.Redirect("/register", fasthttp.StatusFound)
return return
} }
if _, err := users.ByUsername(username); err == nil { if _, err := users.ByUsername(username); err == nil {
setFlashAndFormData(ctx, "Username already exists", formData) session.SetFlashMessage(ctx, "error", "Username already exists")
session.SetFormData(ctx, map[string]string{
"username": username,
"email": email,
})
ctx.Redirect("/register", fasthttp.StatusFound) ctx.Redirect("/register", fasthttp.StatusFound)
return return
} }
if _, err := users.ByEmail(email); err == nil { if _, err := users.ByEmail(email); err == nil {
setFlashAndFormData(ctx, "Email already registered", formData) session.SetFlashMessage(ctx, "error", "Email already registered")
session.SetFormData(ctx, map[string]string{
"username": username,
"email": email,
})
ctx.Redirect("/register", fasthttp.StatusFound) ctx.Redirect("/register", fasthttp.StatusFound)
return return
} }
@ -172,7 +168,11 @@ func processRegister(ctx router.Ctx, _ []string) {
user.Auth = 1 user.Auth = 1
if err := user.Insert(); err != nil { if err := user.Insert(); err != nil {
setFlashAndFormData(ctx, "Failed to create account", formData) session.SetFlashMessage(ctx, "error", "Failed to create account")
session.SetFormData(ctx, map[string]string{
"username": username,
"email": email,
})
ctx.Redirect("/register", fasthttp.StatusFound) ctx.Redirect("/register", fasthttp.StatusFound)
return return
} }
@ -180,16 +180,10 @@ func processRegister(ctx router.Ctx, _ []string) {
// Auto-login after registration // Auto-login after registration
middleware.Login(ctx, user) middleware.Login(ctx, user)
// Set success message
if sess := ctx.UserValue("session").(*session.Session); sess != nil {
sess.SetFlash("success", fmt.Sprintf("Greetings, %s!", user.Username))
session.Store(sess)
}
// Transfer CSRF token from cookie to session for authenticated user // Transfer CSRF token from cookie to session for authenticated user
if cookieToken := csrf.GetTokenFromCookie(ctx); cookieToken != "" { if cookieToken := csrf.GetTokenFromCookie(ctx); cookieToken != "" {
if sess := ctx.UserValue("session").(*session.Session); sess != nil { if session := csrf.GetCurrentSession(ctx); session != nil {
csrf.StoreToken(sess, cookieToken) csrf.StoreToken(session, cookieToken)
} }
} }
@ -235,10 +229,3 @@ func validateRegistration(username, email, password, confirmPassword string) err
} }
return nil return nil
} }
func setFlashAndFormData(ctx router.Ctx, message string, formData map[string]string) {
sess := ctx.UserValue("session").(*session.Session)
sess.SetFlash("error", message)
sess.Set("form_data", formData)
session.Store(sess)
}

View File

@ -49,13 +49,9 @@ func showTown(ctx router.Ctx, _ []string) {
} }
func showInn(ctx router.Ctx, _ []string) { func showInn(ctx router.Ctx, _ []string) {
sess := ctx.UserValue("session").(*session.Session)
var errorHTML string var errorHTML string
if flash := session.GetFlashMessage(ctx); flash != nil {
if flash, exists := sess.GetFlash("error"); exists { errorHTML = `<div style="color: red; margin-bottom: 1rem;">` + flash.Message + "</div>"
if msg, ok := flash.(string); ok {
errorHTML = `<div style="color: red; margin-bottom: 1rem;">` + msg + "</div>"
}
} }
town := ctx.UserValue("town").(*towns.Town) town := ctx.UserValue("town").(*towns.Town)
@ -68,12 +64,11 @@ func showInn(ctx router.Ctx, _ []string) {
} }
func rest(ctx router.Ctx, _ []string) { func rest(ctx router.Ctx, _ []string) {
sess := ctx.UserValue("session").(*session.Session)
town := ctx.UserValue("town").(*towns.Town) town := ctx.UserValue("town").(*towns.Town)
user := ctx.UserValue("user").(*users.User) user := ctx.UserValue("user").(*users.User)
if user.Gold < town.InnCost { if user.Gold < town.InnCost {
sess.SetFlash("error", "You can't afford to stay here tonight.") session.SetFlashMessage(ctx, "error", "You can't afford to stay here tonight.")
ctx.Redirect("/town/inn", 303) ctx.Redirect("/town/inn", 303)
return return
} }
@ -88,13 +83,9 @@ func rest(ctx router.Ctx, _ []string) {
} }
func showShop(ctx router.Ctx, _ []string) { func showShop(ctx router.Ctx, _ []string) {
sess := ctx.UserValue("session").(*session.Session)
var errorHTML string var errorHTML string
if flash := session.GetFlashMessage(ctx); flash != nil {
if flash, exists := sess.GetFlash("error"); exists { errorHTML = `<div style="color: red; margin-bottom: 1rem;">` + flash.Message + "</div>"
if msg, ok := flash.(string); ok {
errorHTML = `<div style="color: red; margin-bottom: 1rem;">` + msg + "</div>"
}
} }
town := ctx.UserValue("town").(*towns.Town) town := ctx.UserValue("town").(*towns.Town)
@ -118,32 +109,30 @@ func showShop(ctx router.Ctx, _ []string) {
} }
func buyItem(ctx router.Ctx, params []string) { func buyItem(ctx router.Ctx, params []string) {
sess := ctx.UserValue("session").(*session.Session)
id, err := strconv.Atoi(params[0]) id, err := strconv.Atoi(params[0])
if err != nil { if err != nil {
sess.SetFlash("error", "Error purchasing item; "+err.Error()) session.SetFlashMessage(ctx, "error", "Error purchasing item; "+err.Error())
ctx.Redirect("/town/shop", 302) ctx.Redirect("/town/shop", 302)
return return
} }
town := ctx.UserValue("town").(*towns.Town) town := ctx.UserValue("town").(*towns.Town)
if !slices.Contains(town.GetShopItems(), id) { if !slices.Contains(town.GetShopItems(), id) {
sess.SetFlash("error", "The item doesn't exist in this shop.") session.SetFlashMessage(ctx, "error", "The item doesn't exist in this shop.")
ctx.Redirect("/town/shop", 302) ctx.Redirect("/town/shop", 302)
return return
} }
item, err := items.Find(id) item, err := items.Find(id)
if err != nil { if err != nil {
sess.SetFlash("error", "Error purchasing item; "+err.Error()) session.SetFlashMessage(ctx, "error", "Error purchasing item; "+err.Error())
ctx.Redirect("/town/shop", 302) ctx.Redirect("/town/shop", 302)
return return
} }
user := ctx.UserValue("user").(*users.User) user := ctx.UserValue("user").(*users.User)
if user.Gold < item.Value { if user.Gold < item.Value {
sess.SetFlash("error", "You don't have enough gold to buy "+item.Name) session.SetFlashMessage(ctx, "error", "You don't have enough gold to buy "+item.Name)
ctx.Redirect("/town/shop", 302) ctx.Redirect("/town/shop", 302)
return return
} }
@ -156,13 +145,9 @@ func buyItem(ctx router.Ctx, params []string) {
} }
func showMaps(ctx router.Ctx, _ []string) { func showMaps(ctx router.Ctx, _ []string) {
sess := ctx.UserValue("session").(*session.Session)
var errorHTML string var errorHTML string
if flash := session.GetFlashMessage(ctx); flash != nil {
if flash, exists := sess.GetFlash("error"); exists { errorHTML = `<div style="color: red; margin-bottom: 1rem;">` + flash.Message + "</div>"
if msg, ok := flash.(string); ok {
errorHTML = `<div style="color: red; margin-bottom: 1rem;">` + msg + "</div>"
}
} }
town := ctx.UserValue("town").(*towns.Town) town := ctx.UserValue("town").(*towns.Town)
@ -201,25 +186,23 @@ func showMaps(ctx router.Ctx, _ []string) {
} }
func buyMap(ctx router.Ctx, params []string) { func buyMap(ctx router.Ctx, params []string) {
sess := ctx.UserValue("session").(*session.Session)
id, err := strconv.Atoi(params[0]) id, err := strconv.Atoi(params[0])
if err != nil { if err != nil {
sess.SetFlash("error", "Error purchasing map; "+err.Error()) session.SetFlashMessage(ctx, "error", "Error purchasing map; "+err.Error())
ctx.Redirect("/town/maps", 302) ctx.Redirect("/town/maps", 302)
return return
} }
mapped, err := towns.Find(id) mapped, err := towns.Find(id)
if err != nil { if err != nil {
sess.SetFlash("error", "Error purchasing map; "+err.Error()) session.SetFlashMessage(ctx, "error", "Error purchasing map; "+err.Error())
ctx.Redirect("/town/maps", 302) ctx.Redirect("/town/maps", 302)
return return
} }
user := ctx.UserValue("user").(*users.User) user := ctx.UserValue("user").(*users.User)
if user.Gold < mapped.MapCost { if user.Gold < mapped.MapCost {
sess.SetFlash("error", "You don't have enough gold to buy the map to "+mapped.Name) session.SetFlashMessage(ctx, "error", "You don't have enough gold to buy the map to "+mapped.Name)
ctx.Redirect("/town/maps", 302) ctx.Redirect("/town/maps", 302)
return return
} }

56
internal/session/flash.go Normal file
View File

@ -0,0 +1,56 @@
package session
type FlashMessage struct {
Type string `json:"type"`
Message string `json:"message"`
}
func (s *Session) SetFlash(key string, value any) {
if s.Data == nil {
s.Data = make(map[string]any)
}
flashData, ok := s.Data["_flash"].(map[string]any)
if !ok {
flashData = make(map[string]any)
}
flashData[key] = value
s.Data["_flash"] = flashData
}
func (s *Session) GetFlash(key string) (any, bool) {
if s.Data == nil {
return nil, false
}
flashData, ok := s.Data["_flash"].(map[string]any)
if !ok {
return nil, false
}
value, exists := flashData[key]
if exists {
delete(flashData, key)
if len(flashData) == 0 {
delete(s.Data, "_flash")
} else {
s.Data["_flash"] = flashData
}
}
return value, exists
}
func (s *Session) GetAllFlash() map[string]any {
if s.Data == nil {
return nil
}
flashData, ok := s.Data["_flash"].(map[string]any)
if !ok {
return nil
}
delete(s.Data, "_flash")
return flashData
}

View File

@ -1,35 +1,28 @@
package session package session
import ( import (
"encoding/json" "dk/internal/cookies"
"os" "dk/internal/router"
"sync" "time"
) )
// SessionManager handles session storage and persistence const SessionCookieName = "dk_session"
type SessionManager struct {
mu sync.RWMutex
sessions map[string]*Session
filePath string
}
var Manager *SessionManager var Manager *SessionManager
// Init initializes the global session manager type SessionManager struct {
func Init(filePath string) { store *Store
}
func Init(sessionsFilePath string) {
if Manager != nil { if Manager != nil {
panic("session manager already initialized") panic("session manager already initialized")
} }
Manager = &SessionManager{ Manager = &SessionManager{
sessions: make(map[string]*Session), store: NewStore(sessionsFilePath),
filePath: filePath,
} }
Manager.load()
} }
// GetManager returns the global session manager
func GetManager() *SessionManager { func GetManager() *SessionManager {
if Manager == nil { if Manager == nil {
panic("session manager not initialized") panic("session manager not initialized")
@ -37,116 +30,200 @@ func GetManager() *SessionManager {
return Manager return Manager
} }
// Create creates and stores a new session func (sm *SessionManager) Create(userID int, username, email string) *Session {
func (sm *SessionManager) Create(userID int) *Session { sess := New(userID, username, email)
sess := New(userID) sm.store.Save(sess)
sm.mu.Lock()
sm.sessions[sess.ID] = sess
sm.mu.Unlock()
return sess return sess
} }
// Get retrieves a session by ID
func (sm *SessionManager) Get(sessionID string) (*Session, bool) { func (sm *SessionManager) Get(sessionID string) (*Session, bool) {
sm.mu.RLock() return sm.store.Get(sessionID)
sess, exists := sm.sessions[sessionID] }
sm.mu.RUnlock()
if !exists || sess.IsExpired() { func (sm *SessionManager) GetFromContext(ctx router.Ctx) (*Session, bool) {
if exists { sessionID := cookies.GetCookie(ctx, SessionCookieName)
sm.Delete(sessionID) if sessionID == "" {
}
return nil, false return nil, false
} }
return sm.Get(sessionID)
return sess, true
} }
// Store saves a session in memory (updates existing or creates new) func (sm *SessionManager) Update(sessionID string) bool {
func (sm *SessionManager) Store(sess *Session) { sess, exists := sm.store.Get(sessionID)
sm.mu.Lock() if !exists {
sm.sessions[sess.ID] = sess return false
sm.mu.Unlock() }
sess.Touch()
sm.store.Save(sess)
return true
} }
// Delete removes a session
func (sm *SessionManager) Delete(sessionID string) { func (sm *SessionManager) Delete(sessionID string) {
sm.mu.Lock() sm.store.Delete(sessionID)
delete(sm.sessions, sessionID)
sm.mu.Unlock()
} }
// Cleanup removes expired sessions func (sm *SessionManager) SetSessionCookie(ctx router.Ctx, sessionID string) {
func (sm *SessionManager) Cleanup() { cookies.SetSecureCookie(ctx, cookies.CookieOptions{
sm.mu.Lock() Name: SessionCookieName,
for id, sess := range sm.sessions { Value: sessionID,
if sess.IsExpired() { Path: "/",
delete(sm.sessions, id) Expires: time.Now().Add(DefaultExpiration),
} HTTPOnly: true,
} Secure: cookies.IsHTTPS(ctx),
sm.mu.Unlock() SameSite: "lax",
})
} }
// Stats returns session statistics func (sm *SessionManager) DeleteSessionCookie(ctx router.Ctx) {
func (sm *SessionManager) Stats() (total, active int) { cookies.DeleteCookie(ctx, SessionCookieName)
sm.mu.RLock()
defer sm.mu.RUnlock()
total = len(sm.sessions)
for _, sess := range sm.sessions {
if !sess.IsExpired() {
active++
}
}
return
} }
// load reads sessions from the JSON file func (sm *SessionManager) SetFlashMessage(ctx router.Ctx, msgType, message string) bool {
func (sm *SessionManager) load() { sess, exists := sm.GetFromContext(ctx)
if sm.filePath == "" { if !exists {
return return false
} }
data, err := os.ReadFile(sm.filePath) sess.SetFlash("message", FlashMessage{
if err != nil { Type: msgType,
return // File doesn't exist or can't be read Message: message,
} })
sm.store.Save(sess)
var sessions map[string]*Session return true
if err := json.Unmarshal(data, &sessions); err != nil {
return // Invalid JSON
}
sm.mu.Lock()
for id, sess := range sessions {
if sess != nil && !sess.IsExpired() {
sess.ID = id // Ensure ID consistency
sm.sessions[id] = sess
}
}
sm.mu.Unlock()
} }
// Save writes sessions to the JSON file func (sm *SessionManager) GetFlashMessage(ctx router.Ctx) *FlashMessage {
func (sm *SessionManager) Save() error { sess, exists := sm.GetFromContext(ctx)
if sm.filePath == "" { if !exists {
return nil return nil
} }
sm.Cleanup() // Remove expired sessions before saving value, exists := sess.GetFlash("message")
if !exists {
sm.mu.RLock() return nil
data, err := json.MarshalIndent(sm.sessions, "", "\t")
sm.mu.RUnlock()
if err != nil {
return err
} }
return os.WriteFile(sm.filePath, data, 0600) sm.store.Save(sess)
if msg, ok := value.(FlashMessage); ok {
return &msg
}
if msgMap, ok := value.(map[string]interface{}); ok {
msg := &FlashMessage{}
if t, ok := msgMap["type"].(string); ok {
msg.Type = t
}
if m, ok := msgMap["message"].(string); ok {
msg.Message = m
}
return msg
}
return nil
} }
// Close saves sessions and cleans up func (sm *SessionManager) SetFormData(ctx router.Ctx, data map[string]string) bool {
func (sm *SessionManager) Close() error { sess, exists := sm.GetFromContext(ctx)
return sm.Save() if !exists {
return false
}
sess.Set("form_data", data)
sm.store.Save(sess)
return true
} }
func (sm *SessionManager) GetFormData(ctx router.Ctx) map[string]string {
sess, exists := sm.GetFromContext(ctx)
if !exists {
return nil
}
value, exists := sess.Get("form_data")
if !exists {
return nil
}
sess.Delete("form_data")
sm.store.Save(sess)
if formData, ok := value.(map[string]string); ok {
return formData
}
if formMap, ok := value.(map[string]interface{}); ok {
result := make(map[string]string)
for k, v := range formMap {
if str, ok := v.(string); ok {
result[k] = str
}
}
return result
}
return nil
}
func (sm *SessionManager) Stats() (total, active int) {
return sm.store.Stats()
}
func (sm *SessionManager) Close() error {
return sm.store.Close()
}
// Package-level convenience functions that use the global Manager
func Create(userID int, username, email string) *Session {
return Manager.Create(userID, username, email)
}
func Get(sessionID string) (*Session, bool) {
return Manager.Get(sessionID)
}
func GetFromContext(ctx router.Ctx) (*Session, bool) {
return Manager.GetFromContext(ctx)
}
func Update(sessionID string) bool {
return Manager.Update(sessionID)
}
func Delete(sessionID string) {
Manager.Delete(sessionID)
}
func SetSessionCookie(ctx router.Ctx, sessionID string) {
Manager.SetSessionCookie(ctx, sessionID)
}
func DeleteSessionCookie(ctx router.Ctx) {
Manager.DeleteSessionCookie(ctx)
}
func SetFlashMessage(ctx router.Ctx, msgType, message string) bool {
return Manager.SetFlashMessage(ctx, msgType, message)
}
func GetFlashMessage(ctx router.Ctx) *FlashMessage {
return Manager.GetFlashMessage(ctx)
}
func SetFormData(ctx router.Ctx, data map[string]string) bool {
return Manager.SetFormData(ctx, data)
}
func GetFormData(ctx router.Ctx) map[string]string {
return Manager.GetFormData(ctx)
}
func Stats() (total, active int) {
return Manager.Stats()
}
func Close() error {
return Manager.Close()
}

View File

@ -1,4 +1,5 @@
// session.go // Package session provides session management functionality.
// It includes session storage, flash messages, and data persistence.
package session package session
import ( import (
@ -12,97 +13,62 @@ const (
IDLength = 32 IDLength = 32
) )
// Session represents a user session
type Session struct { type Session struct {
ID string `json:"id"` ID string `json:"-"`
UserID int `json:"user_id"` // 0 for guest sessions UserID int `json:"user_id"`
Username string `json:"username"`
Email string `json:"email"`
CreatedAt time.Time `json:"created_at"`
ExpiresAt time.Time `json:"expires_at"` ExpiresAt time.Time `json:"expires_at"`
Data map[string]any `json:"data"` LastSeen time.Time `json:"last_seen"`
Data map[string]any `json:"data,omitempty"`
} }
// New creates a new session func New(userID int, username, email string) *Session {
func New(userID int) *Session {
return &Session{ return &Session{
ID: generateID(), ID: generateID(),
UserID: userID, UserID: userID,
Username: username,
Email: email,
CreatedAt: time.Now(),
ExpiresAt: time.Now().Add(DefaultExpiration), ExpiresAt: time.Now().Add(DefaultExpiration),
LastSeen: time.Now(),
Data: make(map[string]any), Data: make(map[string]any),
} }
} }
// IsExpired checks if the session has expired
func (s *Session) IsExpired() bool { func (s *Session) IsExpired() bool {
return time.Now().After(s.ExpiresAt) return time.Now().After(s.ExpiresAt)
} }
// Touch extends the session expiration
func (s *Session) Touch() { func (s *Session) Touch() {
s.LastSeen = time.Now()
s.ExpiresAt = time.Now().Add(DefaultExpiration) s.ExpiresAt = time.Now().Add(DefaultExpiration)
} }
// Set stores a value in the session
func (s *Session) Set(key string, value any) { func (s *Session) Set(key string, value any) {
if s.Data == nil {
s.Data = make(map[string]any)
}
s.Data[key] = value s.Data[key] = value
} }
// Get retrieves a value from the session
func (s *Session) Get(key string) (any, bool) { func (s *Session) Get(key string) (any, bool) {
if s.Data == nil {
return nil, false
}
value, exists := s.Data[key] value, exists := s.Data[key]
return value, exists return value, exists
} }
// Delete removes a value from the session
func (s *Session) Delete(key string) { func (s *Session) Delete(key string) {
delete(s.Data, key) if s.Data != nil {
} delete(s.Data, key)
// SetFlash stores a flash message (consumed on next Get)
func (s *Session) SetFlash(key string, value any) {
s.Set("flash_"+key, value)
}
// GetFlash retrieves and removes a flash message
func (s *Session) GetFlash(key string) (any, bool) {
flashKey := "flash_" + key
value, exists := s.Get(flashKey)
if exists {
s.Delete(flashKey)
} }
return value, exists
} }
// generateID creates a random session ID
func generateID() string { func generateID() string {
bytes := make([]byte, IDLength) bytes := make([]byte, IDLength)
rand.Read(bytes) rand.Read(bytes)
return hex.EncodeToString(bytes) return hex.EncodeToString(bytes)
} }
// Package-level convenience functions
func Create(userID int) *Session {
return Manager.Create(userID)
}
func Get(sessionID string) (*Session, bool) {
return Manager.Get(sessionID)
}
func Store(sess *Session) {
Manager.Store(sess)
}
func Delete(sessionID string) {
Manager.Delete(sessionID)
}
func Cleanup() {
Manager.Cleanup()
}
func Stats() (total, active int) {
return Manager.Stats()
}
func Close() error {
return Manager.Close()
}

161
internal/session/store.go Normal file
View File

@ -0,0 +1,161 @@
package session
import (
"encoding/json"
"maps"
"os"
"sync"
"time"
)
type Store struct {
mu sync.RWMutex
sessions map[string]*Session
filePath string
saveInterval time.Duration
stopChan chan struct{}
}
type persistedData struct {
Sessions map[string]*Session `json:"sessions"`
SavedAt time.Time `json:"saved_at"`
}
func NewStore(filePath string) *Store {
store := &Store{
sessions: make(map[string]*Session),
filePath: filePath,
saveInterval: 5 * time.Minute,
stopChan: make(chan struct{}),
}
store.loadFromFile()
store.startPeriodicSave()
return store
}
func (s *Store) Save(session *Session) {
s.mu.Lock()
defer s.mu.Unlock()
s.sessions[session.ID] = session
}
func (s *Store) Get(sessionID string) (*Session, bool) {
s.mu.RLock()
defer s.mu.RUnlock()
session, exists := s.sessions[sessionID]
if !exists {
return nil, false
}
if session.IsExpired() {
return nil, false
}
return session, true
}
func (s *Store) Delete(sessionID string) {
s.mu.Lock()
defer s.mu.Unlock()
delete(s.sessions, sessionID)
}
func (s *Store) Cleanup() {
s.mu.Lock()
defer s.mu.Unlock()
for id, session := range s.sessions {
if session.IsExpired() {
delete(s.sessions, id)
}
}
}
func (s *Store) Stats() (total, active int) {
s.mu.RLock()
defer s.mu.RUnlock()
total = len(s.sessions)
for _, session := range s.sessions {
if !session.IsExpired() {
active++
}
}
return
}
func (s *Store) loadFromFile() {
if s.filePath == "" {
return
}
data, err := os.ReadFile(s.filePath)
if err != nil {
return
}
var persisted persistedData
if err := json.Unmarshal(data, &persisted); err != nil {
return
}
s.mu.Lock()
defer s.mu.Unlock()
for id, session := range persisted.Sessions {
if !session.IsExpired() {
session.ID = id
s.sessions[id] = session
}
}
}
func (s *Store) saveToFile() error {
if s.filePath == "" {
return nil
}
s.mu.RLock()
sessionsCopy := make(map[string]*Session, len(s.sessions))
maps.Copy(sessionsCopy, s.sessions)
s.mu.RUnlock()
data := persistedData{
Sessions: sessionsCopy,
SavedAt: time.Now(),
}
jsonData, err := json.MarshalIndent(data, "", " ")
if err != nil {
return err
}
return os.WriteFile(s.filePath, jsonData, 0600)
}
func (s *Store) startPeriodicSave() {
go func() {
ticker := time.NewTicker(s.saveInterval)
defer ticker.Stop()
for {
select {
case <-ticker.C:
s.Cleanup()
s.saveToFile()
case <-s.stopChan:
s.saveToFile()
return
}
}
}()
}
func (s *Store) Close() error {
close(s.stopChan)
return s.saveToFile()
}

View File

@ -3,7 +3,6 @@ package components
import ( import (
"fmt" "fmt"
"maps" "maps"
"runtime"
"strings" "strings"
"dk/internal/csrf" "dk/internal/csrf"
@ -23,18 +22,15 @@ func RenderPage(ctx router.Ctx, title, tmplPath string, additionalData map[strin
return fmt.Errorf("failed to load layout template: %w", err) return fmt.Errorf("failed to load layout template: %w", err)
} }
var m runtime.MemStats
runtime.ReadMemStats(&m)
data := map[string]any{ data := map[string]any{
"_title": PageTitle(title), "_title": PageTitle(title),
"authenticated": middleware.IsAuthenticated(ctx), "authenticated": middleware.IsAuthenticated(ctx),
"csrf": csrf.HiddenField(ctx), "csrf": csrf.HiddenField(ctx),
"_totaltime": middleware.GetRequestTime(ctx), "_totaltime": middleware.GetRequestTime(ctx),
"_numqueries": 0,
"_version": "1.0.0", "_version": "1.0.0",
"_build": "dev", "_build": "dev",
"user": middleware.GetCurrentUser(ctx), "user": middleware.GetCurrentUser(ctx),
"_memalloc": m.Alloc / 1024 / 1024,
} }
maps.Copy(data, LeftAside(ctx)) maps.Copy(data, LeftAside(ctx))

View File

@ -158,14 +158,15 @@ func start(port string) error {
if err != nil { if err != nil {
return fmt.Errorf("failed to get current working directory: %w", err) return fmt.Errorf("failed to get current working directory: %w", err)
} }
// Initialize template singleton
template.InitializeCache(cwd) template.InitializeCache(cwd)
// Load all model data into memory
if err := loadModels(); err != nil { if err := loadModels(); err != nil {
return fmt.Errorf("failed to load models: %w", err) return fmt.Errorf("failed to load models: %w", err)
} }
session.Init("sessions.json") session.Init("sessions.json") // Initialize session.Manager
r := router.New() r := router.New()
r.Use(middleware.Timing()) r.Use(middleware.Timing())
@ -173,8 +174,8 @@ func start(port string) error {
r.Use(middleware.CSRF()) r.Use(middleware.CSRF())
r.Get("/", routes.Index) r.Get("/", routes.Index)
r.WithMiddleware(middleware.RequireAuth()).Get("/explore", routes.Explore) r.Use(middleware.RequireAuth()).Get("/explore", routes.Explore)
r.WithMiddleware(middleware.RequireAuth()).Post("/move", routes.Move) r.Use(middleware.RequireAuth()).Post("/move", routes.Move)
routes.RegisterAuthRoutes(r) routes.RegisterAuthRoutes(r)
routes.RegisterTownRoutes(r) routes.RegisterTownRoutes(r)

View File

@ -44,7 +44,7 @@
<footer> <footer>
<div>Powered by <a href="/">Dragon Knight</a></div> <div>Powered by <a href="/">Dragon Knight</a></div>
<div>&copy; 2025 Sharkk</div> <div>&copy; 2025 Sharkk</div>
<div>{_totaltime} Seconds, {_memalloc} MiB</div> <div>{_totaltime} Seconds, {_numqueries} Queries</div>
<div>Version {_version} {_build}</div> <div>Version {_version} {_build}</div>
</footer> </footer>
</div> </div>