diff --git a/.gitignore b/.gitignore index 0c15d5d..4ae2e3c 100644 --- a/.gitignore +++ b/.gitignore @@ -1,5 +1,5 @@ # Dragon Knight test/build files /dk -/sessions.json -/data/users.json +_sessions.json +users.json /tmp diff --git a/internal/middleware/auth.go b/internal/auth/auth.go similarity index 92% rename from internal/middleware/auth.go rename to internal/auth/auth.go index 7c4bbc5..cfce668 100644 --- a/internal/middleware/auth.go +++ b/internal/auth/auth.go @@ -1,4 +1,4 @@ -package middleware +package auth import ( "dk/internal/cookies" @@ -14,7 +14,7 @@ import ( const SessionCookieName = "dk_session" -func Auth() router.Middleware { +func Middleware() router.Middleware { return func(next router.Handler) router.Handler { return func(ctx router.Ctx, params []string) { sessionID := cookies.GetCookie(ctx, SessionCookieName) @@ -108,8 +108,11 @@ func GetCurrentSession(ctx router.Ctx) *session.Session { } func Login(ctx router.Ctx, user *users.User) { - sess := session.Create(user.ID) - setSessionCookie(ctx, sess.ID) + sess := ctx.UserValue("session").(*session.Session) + sess.RegenerateID() + sess.Set("user_id", user.ID) + sess.SetFlash("success", fmt.Sprintf("Welcome back, %s!", user.Username)) + session.Store(sess) ctx.SetUserValue("session", sess) ctx.SetUserValue("user", user) diff --git a/internal/csrf/csrf.go b/internal/csrf/csrf.go index dfe61f0..9545137 100644 --- a/internal/csrf/csrf.go +++ b/internal/csrf/csrf.go @@ -199,3 +199,25 @@ func StoreTokenInCookie(ctx router.Ctx, token string) { func GetTokenFromCookie(ctx router.Ctx) string { return string(ctx.Request.Header.Cookie(CookieName)) } + +// Middleware returns a middleware function that automatically validates CSRF tokens +// for state-changing HTTP methods (POST, PUT, PATCH, DELETE) +func Middleware() router.Middleware { + return func(next router.Handler) router.Handler { + return func(ctx router.Ctx, params []string) { + method := string(ctx.Method()) + + // Only validate CSRF for state-changing methods + if method == "POST" || method == "PUT" || method == "PATCH" || method == "DELETE" { + if !ValidateFormToken(ctx) { + ctx.SetStatusCode(fasthttp.StatusForbidden) + ctx.WriteString("CSRF validation failed") + return + } + } + + // Continue to next handler + next(ctx, params) + } + } +} diff --git a/internal/middleware/csrf.go b/internal/middleware/csrf.go deleted file mode 100644 index aff7f32..0000000 --- a/internal/middleware/csrf.go +++ /dev/null @@ -1,117 +0,0 @@ -package middleware - -import ( - "dk/internal/csrf" - "dk/internal/router" - "slices" - - "github.com/valyala/fasthttp" -) - -// CSRFConfig holds configuration for CSRF middleware -type CSRFConfig struct { - // Skip CSRF validation for these methods (default: GET, HEAD, OPTIONS) - SkipMethods []string - // Custom failure handler (default: returns 403) - FailureHandler func(ctx router.Ctx) - // Skip CSRF for certain paths - SkipPaths []string -} - -// CSRF creates a CSRF protection middleware -func CSRF(config ...CSRFConfig) router.Middleware { - cfg := CSRFConfig{ - SkipMethods: []string{"GET", "HEAD", "OPTIONS"}, - FailureHandler: func(ctx router.Ctx) { - ctx.SetStatusCode(fasthttp.StatusForbidden) - ctx.SetContentType("text/plain") - ctx.WriteString("CSRF token validation failed") - }, - SkipPaths: []string{}, - } - - // Apply custom config if provided - if len(config) > 0 { - if len(config[0].SkipMethods) > 0 { - cfg.SkipMethods = config[0].SkipMethods - } - if config[0].FailureHandler != nil { - cfg.FailureHandler = config[0].FailureHandler - } - if len(config[0].SkipPaths) > 0 { - cfg.SkipPaths = config[0].SkipPaths - } - } - - return func(next router.Handler) router.Handler { - return func(ctx router.Ctx, params []string) { - method := string(ctx.Method()) - path := string(ctx.Path()) - - // Skip CSRF validation for certain methods - shouldSkip := slices.Contains(cfg.SkipMethods, method) - - // Skip CSRF validation for certain paths - if !shouldSkip { - if slices.Contains(cfg.SkipPaths, path) { - shouldSkip = true - } - } - - // CSRF protection now works for both authenticated and guest users - // Remove the skip for non-authenticated users - - if shouldSkip { - next(ctx, params) - return - } - - // Validate CSRF token for protected methods - if !csrf.ValidateFormToken(ctx) { - cfg.FailureHandler(ctx) - return - } - - next(ctx, params) - } - } -} - -// RequireCSRF is a stricter CSRF middleware that always validates tokens -func RequireCSRF(failureHandler ...func(router.Ctx)) router.Middleware { - handler := func(ctx router.Ctx) { - ctx.SetStatusCode(fasthttp.StatusForbidden) - ctx.SetContentType("text/plain") - ctx.WriteString("CSRF token required") - } - - if len(failureHandler) > 0 { - handler = failureHandler[0] - } - - return func(next router.Handler) router.Handler { - return func(ctx router.Ctx, params []string) { - if !csrf.ValidateFormToken(ctx) { - handler(ctx) - return - } - - next(ctx, params) - } - } -} - -// CSRFToken returns the current CSRF token for the request -func CSRFToken(ctx router.Ctx) string { - return csrf.GetToken(ctx) -} - -// CSRFHiddenField generates a hidden input field for forms -func CSRFHiddenField(ctx router.Ctx) string { - return csrf.HiddenField(ctx) -} - -// CSRFMeta generates a meta tag for JavaScript access -func CSRFMeta(ctx router.Ctx) string { - return csrf.TokenMeta(ctx) -} diff --git a/internal/middleware/doc.go b/internal/middleware/doc.go deleted file mode 100644 index 74047e9..0000000 --- a/internal/middleware/doc.go +++ /dev/null @@ -1,4 +0,0 @@ -// Package middleware provides reusable HTTP middleware for the Dragon Knight server. -// Middleware functions wrap request handlers to add cross-cutting functionality -// like timing, logging, authentication, and request processing. -package middleware \ No newline at end of file diff --git a/internal/routes/auth.go b/internal/routes/auth.go index 81e4ff0..3beeca2 100644 --- a/internal/routes/auth.go +++ b/internal/routes/auth.go @@ -4,8 +4,8 @@ import ( "fmt" "strings" + "dk/internal/auth" "dk/internal/csrf" - "dk/internal/middleware" "dk/internal/models/users" "dk/internal/password" "dk/internal/router" @@ -18,7 +18,7 @@ import ( // RegisterAuthRoutes sets up authentication routes func RegisterAuthRoutes(r *router.Router) { guests := r.Group("") - guests.Use(middleware.RequireGuest()) + guests.Use(auth.RequireGuest()) guests.Get("/login", showLogin) guests.Post("/login", processLogin) @@ -26,7 +26,7 @@ func RegisterAuthRoutes(r *router.Router) { guests.Post("/register", processRegister) authed := r.Group("") - authed.Use(middleware.RequireAuth()) + authed.Use(auth.RequireAuth()) authed.Post("/logout", processLogout) } @@ -59,12 +59,6 @@ func showLogin(ctx router.Ctx, _ []string) { // processLogin handles login form submission func processLogin(ctx router.Ctx, _ []string) { - if !csrf.ValidateFormToken(ctx) { - ctx.SetStatusCode(fasthttp.StatusForbidden) - ctx.WriteString("CSRF validation failed") - return - } - email := strings.TrimSpace(string(ctx.PostArgs().Peek("id"))) userPassword := string(ctx.PostArgs().Peek("password")) @@ -81,13 +75,7 @@ func processLogin(ctx router.Ctx, _ []string) { return } - middleware.Login(ctx, user) - - // Set success message - if sess := ctx.UserValue("session").(*session.Session); sess != nil { - sess.SetFlash("success", fmt.Sprintf("Welcome back, %s!", user.Username)) - session.Store(sess) - } + auth.Login(ctx, user) // Transfer CSRF token from cookie to session for authenticated user if cookieToken := csrf.GetTokenFromCookie(ctx); cookieToken != "" { @@ -129,12 +117,6 @@ func showRegister(ctx router.Ctx, _ []string) { // processRegister handles registration form submission func processRegister(ctx router.Ctx, _ []string) { - if !csrf.ValidateFormToken(ctx) { - ctx.SetStatusCode(fasthttp.StatusForbidden) - ctx.WriteString("CSRF validation failed") - return - } - username := strings.TrimSpace(string(ctx.PostArgs().Peek("username"))) email := strings.TrimSpace(string(ctx.PostArgs().Peek("email"))) userPassword := string(ctx.PostArgs().Peek("password")) @@ -176,8 +158,15 @@ func processRegister(ctx router.Ctx, _ []string) { return } + // Store old session ID before creating new one + oldSess := ctx.UserValue("session").(*session.Session) + oldSessionID := oldSess.ID + // Auto-login after registration - middleware.Login(ctx, user) + auth.Login(ctx, user) + + // Clean up old guest session + session.Delete(oldSessionID) // Set success message if sess := ctx.UserValue("session").(*session.Session); sess != nil { @@ -197,14 +186,7 @@ func processRegister(ctx router.Ctx, _ []string) { // processLogout handles logout func processLogout(ctx router.Ctx, params []string) { - // Validate CSRF token - if !csrf.ValidateFormToken(ctx) { - ctx.SetStatusCode(fasthttp.StatusForbidden) - ctx.WriteString("CSRF validation failed") - return - } - - middleware.Logout(ctx) + auth.Logout(ctx) ctx.Redirect("/", fasthttp.StatusFound) } diff --git a/internal/routes/town.go b/internal/routes/town.go index 2ed670d..57ab8b0 100644 --- a/internal/routes/town.go +++ b/internal/routes/town.go @@ -2,6 +2,7 @@ package routes import ( "dk/internal/actions" + "dk/internal/auth" "dk/internal/helpers" "dk/internal/middleware" "dk/internal/models/items" @@ -27,12 +28,12 @@ type Map struct { func RegisterTownRoutes(r *router.Router) { group := r.Group("/town") - group.Use(middleware.RequireAuth()) + group.Use(auth.RequireAuth()) group.Use(middleware.RequireTown()) group.Get("/", showTown) group.Get("/inn", showInn) - group.WithMiddleware(middleware.CSRF()).Post("/inn", rest) + group.Post("/inn", rest) group.Get("/shop", showShop) group.Get("/shop/buy/:id", buyItem) group.Get("/maps", showMaps) diff --git a/internal/session/manager.go b/internal/session/manager.go index 69ab098..872d9bc 100644 --- a/internal/session/manager.go +++ b/internal/session/manager.go @@ -4,6 +4,7 @@ import ( "encoding/json" "os" "sync" + "time" ) // SessionManager handles session storage and persistence @@ -15,6 +16,13 @@ type SessionManager struct { var Manager *SessionManager +// sessionData represents session data for JSON serialization (excludes ID) +type sessionData struct { + UserID int `json:"user_id"` + ExpiresAt int64 `json:"expires_at"` + Data map[string]any `json:"data"` +} + // Init initializes the global session manager func Init(filePath string) { if Manager != nil { @@ -112,15 +120,21 @@ func (sm *SessionManager) load() { return // File doesn't exist or can't be read } - var sessions map[string]*Session - if err := json.Unmarshal(data, &sessions); err != nil { + var sessionsData map[string]*sessionData + if err := json.Unmarshal(data, &sessionsData); err != nil { return // Invalid JSON } + now := time.Now().Unix() sm.mu.Lock() - for id, sess := range sessions { - if sess != nil && !sess.IsExpired() { - sess.ID = id // Ensure ID consistency + for id, data := range sessionsData { + if data != nil && data.ExpiresAt > now { + sess := &Session{ + ID: id, + UserID: data.UserID, + ExpiresAt: data.ExpiresAt, + Data: data.Data, + } sm.sessions[id] = sess } } @@ -136,7 +150,18 @@ func (sm *SessionManager) Save() error { sm.Cleanup() // Remove expired sessions before saving sm.mu.RLock() - data, err := json.MarshalIndent(sm.sessions, "", "\t") + + // Convert sessions to sessionData (without ID field) + sessionsData := make(map[string]*sessionData, len(sm.sessions)) + for id, sess := range sm.sessions { + sessionsData[id] = &sessionData{ + UserID: sess.UserID, + ExpiresAt: sess.ExpiresAt, + Data: sess.Data, + } + } + + data, err := json.MarshalIndent(sessionsData, "", "\t") sm.mu.RUnlock() if err != nil { diff --git a/internal/session/session.go b/internal/session/session.go index 65c3d99..3fbdf18 100644 --- a/internal/session/session.go +++ b/internal/session/session.go @@ -16,7 +16,7 @@ const ( type Session struct { ID string `json:"id"` UserID int `json:"user_id"` // 0 for guest sessions - ExpiresAt time.Time `json:"expires_at"` + ExpiresAt int64 `json:"expires_at"` Data map[string]any `json:"data"` } @@ -25,19 +25,19 @@ func New(userID int) *Session { return &Session{ ID: generateID(), UserID: userID, - ExpiresAt: time.Now().Add(DefaultExpiration), + ExpiresAt: time.Now().Add(DefaultExpiration).Unix(), Data: make(map[string]any), } } // IsExpired checks if the session has expired func (s *Session) IsExpired() bool { - return time.Now().After(s.ExpiresAt) + return time.Now().Unix() > s.ExpiresAt } // Touch extends the session expiration func (s *Session) Touch() { - s.ExpiresAt = time.Now().Add(DefaultExpiration) + s.ExpiresAt = time.Now().Add(DefaultExpiration).Unix() } // Set stores a value in the session @@ -71,6 +71,19 @@ func (s *Session) GetFlash(key string) (any, bool) { return value, exists } +// RegenerateID creates a new session ID and updates storage +func (s *Session) RegenerateID() { + oldID := s.ID + s.ID = generateID() + + if Manager != nil { + Manager.mu.Lock() + delete(Manager.sessions, oldID) + Manager.sessions[s.ID] = s + Manager.mu.Unlock() + } +} + // generateID creates a random session ID func generateID() string { bytes := make([]byte, IDLength) @@ -106,3 +119,8 @@ func Stats() (total, active int) { func Close() error { return Manager.Close() } + +// RegenerateID regenerates the session ID for security (package-level convenience) +func RegenerateID(sess *Session) { + sess.RegenerateID() +} diff --git a/internal/template/components/asides.go b/internal/template/components/asides.go index c305cb7..4a988e0 100644 --- a/internal/template/components/asides.go +++ b/internal/template/components/asides.go @@ -1,10 +1,11 @@ package components import ( + "dk/internal/auth" "dk/internal/helpers" - "dk/internal/middleware" "dk/internal/models/spells" "dk/internal/models/towns" + "dk/internal/models/users" "dk/internal/router" ) @@ -13,11 +14,12 @@ import ( func LeftAside(ctx router.Ctx) map[string]any { data := map[string]any{} - user := middleware.GetCurrentUser(ctx) - if user == nil { + if !auth.IsAuthenticated(ctx) { return data } + user := ctx.UserValue("user").(*users.User) + // Build owned town maps list if user.Towns != "" { townMap := helpers.NewOrderedMap[int, string]() @@ -37,11 +39,12 @@ func LeftAside(ctx router.Ctx) map[string]any { func RightAside(ctx router.Ctx) map[string]any { data := map[string]any{} - user := middleware.GetCurrentUser(ctx) - if user == nil { + if !auth.IsAuthenticated(ctx) { return data } + user := ctx.UserValue("user").(*users.User) + hpPct := helpers.ClampPct(float64(user.HP), float64(user.MaxHP), 0, 100) data["hppct"] = hpPct data["mppct"] = helpers.ClampPct(float64(user.MP), float64(user.MaxMP), 0, 100) diff --git a/internal/template/components/page.go b/internal/template/components/page.go index 72de21a..87644bb 100644 --- a/internal/template/components/page.go +++ b/internal/template/components/page.go @@ -6,6 +6,7 @@ import ( "runtime" "strings" + "dk/internal/auth" "dk/internal/csrf" "dk/internal/middleware" "dk/internal/router" @@ -28,12 +29,12 @@ func RenderPage(ctx router.Ctx, title, tmplPath string, additionalData map[strin data := map[string]any{ "_title": PageTitle(title), - "authenticated": middleware.IsAuthenticated(ctx), + "authenticated": auth.IsAuthenticated(ctx), "csrf": csrf.HiddenField(ctx), "_totaltime": middleware.GetRequestTime(ctx), "_version": "1.0.0", "_build": "dev", - "user": middleware.GetCurrentUser(ctx), + "user": auth.GetCurrentUser(ctx), "_memalloc": m.Alloc / 1024 / 1024, } diff --git a/internal/template/components/topnav.go b/internal/template/components/topnav.go index 27a3a21..37a7dcf 100644 --- a/internal/template/components/topnav.go +++ b/internal/template/components/topnav.go @@ -1,15 +1,15 @@ package components import ( + "dk/internal/auth" "dk/internal/csrf" - "dk/internal/middleware" "dk/internal/router" "fmt" ) // GenerateTopNav generates the top navigation HTML based on authentication status func GenerateTopNav(ctx router.Ctx) string { - if middleware.IsAuthenticated(ctx) { + if auth.IsAuthenticated(ctx) { return fmt.Sprintf(`
%s diff --git a/main.go b/main.go index ac44bcc..ed29c4c 100644 --- a/main.go +++ b/main.go @@ -9,6 +9,8 @@ import ( "path/filepath" "syscall" + "dk/internal/auth" + "dk/internal/csrf" "dk/internal/middleware" "dk/internal/models/babble" "dk/internal/models/control" @@ -165,16 +167,16 @@ func start(port string) error { return fmt.Errorf("failed to load models: %w", err) } - session.Init("sessions.json") + session.Init("data/_sessions.json") r := router.New() r.Use(middleware.Timing()) - r.Use(middleware.Auth()) - r.Use(middleware.CSRF()) + r.Use(auth.Middleware()) + r.Use(csrf.Middleware()) r.Get("/", routes.Index) - r.WithMiddleware(middleware.RequireAuth()).Get("/explore", routes.Explore) - r.WithMiddleware(middleware.RequireAuth()).Post("/move", routes.Move) + r.WithMiddleware(auth.RequireAuth()).Get("/explore", routes.Explore) + r.WithMiddleware(auth.RequireAuth()).Post("/move", routes.Move) routes.RegisterAuthRoutes(r) routes.RegisterTownRoutes(r)