move middleware/auth to its own package, more work on session management

This commit is contained in:
Sky Johnson 2025-08-14 16:00:07 -05:00
parent c5218c6061
commit bfe6c12a7a
13 changed files with 120 additions and 184 deletions

4
.gitignore vendored
View File

@ -1,5 +1,5 @@
# Dragon Knight test/build files
/dk
/sessions.json
/data/users.json
_sessions.json
users.json
/tmp

View File

@ -1,4 +1,4 @@
package middleware
package auth
import (
"dk/internal/cookies"
@ -14,7 +14,7 @@ import (
const SessionCookieName = "dk_session"
func Auth() router.Middleware {
func Middleware() router.Middleware {
return func(next router.Handler) router.Handler {
return func(ctx router.Ctx, params []string) {
sessionID := cookies.GetCookie(ctx, SessionCookieName)
@ -108,8 +108,11 @@ func GetCurrentSession(ctx router.Ctx) *session.Session {
}
func Login(ctx router.Ctx, user *users.User) {
sess := session.Create(user.ID)
setSessionCookie(ctx, sess.ID)
sess := ctx.UserValue("session").(*session.Session)
sess.RegenerateID()
sess.Set("user_id", user.ID)
sess.SetFlash("success", fmt.Sprintf("Welcome back, %s!", user.Username))
session.Store(sess)
ctx.SetUserValue("session", sess)
ctx.SetUserValue("user", user)

View File

@ -199,3 +199,25 @@ func StoreTokenInCookie(ctx router.Ctx, token string) {
func GetTokenFromCookie(ctx router.Ctx) string {
return string(ctx.Request.Header.Cookie(CookieName))
}
// Middleware returns a middleware function that automatically validates CSRF tokens
// for state-changing HTTP methods (POST, PUT, PATCH, DELETE)
func Middleware() router.Middleware {
return func(next router.Handler) router.Handler {
return func(ctx router.Ctx, params []string) {
method := string(ctx.Method())
// Only validate CSRF for state-changing methods
if method == "POST" || method == "PUT" || method == "PATCH" || method == "DELETE" {
if !ValidateFormToken(ctx) {
ctx.SetStatusCode(fasthttp.StatusForbidden)
ctx.WriteString("CSRF validation failed")
return
}
}
// Continue to next handler
next(ctx, params)
}
}
}

View File

@ -1,117 +0,0 @@
package middleware
import (
"dk/internal/csrf"
"dk/internal/router"
"slices"
"github.com/valyala/fasthttp"
)
// CSRFConfig holds configuration for CSRF middleware
type CSRFConfig struct {
// Skip CSRF validation for these methods (default: GET, HEAD, OPTIONS)
SkipMethods []string
// Custom failure handler (default: returns 403)
FailureHandler func(ctx router.Ctx)
// Skip CSRF for certain paths
SkipPaths []string
}
// CSRF creates a CSRF protection middleware
func CSRF(config ...CSRFConfig) router.Middleware {
cfg := CSRFConfig{
SkipMethods: []string{"GET", "HEAD", "OPTIONS"},
FailureHandler: func(ctx router.Ctx) {
ctx.SetStatusCode(fasthttp.StatusForbidden)
ctx.SetContentType("text/plain")
ctx.WriteString("CSRF token validation failed")
},
SkipPaths: []string{},
}
// Apply custom config if provided
if len(config) > 0 {
if len(config[0].SkipMethods) > 0 {
cfg.SkipMethods = config[0].SkipMethods
}
if config[0].FailureHandler != nil {
cfg.FailureHandler = config[0].FailureHandler
}
if len(config[0].SkipPaths) > 0 {
cfg.SkipPaths = config[0].SkipPaths
}
}
return func(next router.Handler) router.Handler {
return func(ctx router.Ctx, params []string) {
method := string(ctx.Method())
path := string(ctx.Path())
// Skip CSRF validation for certain methods
shouldSkip := slices.Contains(cfg.SkipMethods, method)
// Skip CSRF validation for certain paths
if !shouldSkip {
if slices.Contains(cfg.SkipPaths, path) {
shouldSkip = true
}
}
// CSRF protection now works for both authenticated and guest users
// Remove the skip for non-authenticated users
if shouldSkip {
next(ctx, params)
return
}
// Validate CSRF token for protected methods
if !csrf.ValidateFormToken(ctx) {
cfg.FailureHandler(ctx)
return
}
next(ctx, params)
}
}
}
// RequireCSRF is a stricter CSRF middleware that always validates tokens
func RequireCSRF(failureHandler ...func(router.Ctx)) router.Middleware {
handler := func(ctx router.Ctx) {
ctx.SetStatusCode(fasthttp.StatusForbidden)
ctx.SetContentType("text/plain")
ctx.WriteString("CSRF token required")
}
if len(failureHandler) > 0 {
handler = failureHandler[0]
}
return func(next router.Handler) router.Handler {
return func(ctx router.Ctx, params []string) {
if !csrf.ValidateFormToken(ctx) {
handler(ctx)
return
}
next(ctx, params)
}
}
}
// CSRFToken returns the current CSRF token for the request
func CSRFToken(ctx router.Ctx) string {
return csrf.GetToken(ctx)
}
// CSRFHiddenField generates a hidden input field for forms
func CSRFHiddenField(ctx router.Ctx) string {
return csrf.HiddenField(ctx)
}
// CSRFMeta generates a meta tag for JavaScript access
func CSRFMeta(ctx router.Ctx) string {
return csrf.TokenMeta(ctx)
}

View File

@ -1,4 +0,0 @@
// Package middleware provides reusable HTTP middleware for the Dragon Knight server.
// Middleware functions wrap request handlers to add cross-cutting functionality
// like timing, logging, authentication, and request processing.
package middleware

View File

@ -4,8 +4,8 @@ import (
"fmt"
"strings"
"dk/internal/auth"
"dk/internal/csrf"
"dk/internal/middleware"
"dk/internal/models/users"
"dk/internal/password"
"dk/internal/router"
@ -18,7 +18,7 @@ import (
// RegisterAuthRoutes sets up authentication routes
func RegisterAuthRoutes(r *router.Router) {
guests := r.Group("")
guests.Use(middleware.RequireGuest())
guests.Use(auth.RequireGuest())
guests.Get("/login", showLogin)
guests.Post("/login", processLogin)
@ -26,7 +26,7 @@ func RegisterAuthRoutes(r *router.Router) {
guests.Post("/register", processRegister)
authed := r.Group("")
authed.Use(middleware.RequireAuth())
authed.Use(auth.RequireAuth())
authed.Post("/logout", processLogout)
}
@ -59,12 +59,6 @@ func showLogin(ctx router.Ctx, _ []string) {
// processLogin handles login form submission
func processLogin(ctx router.Ctx, _ []string) {
if !csrf.ValidateFormToken(ctx) {
ctx.SetStatusCode(fasthttp.StatusForbidden)
ctx.WriteString("CSRF validation failed")
return
}
email := strings.TrimSpace(string(ctx.PostArgs().Peek("id")))
userPassword := string(ctx.PostArgs().Peek("password"))
@ -81,13 +75,7 @@ func processLogin(ctx router.Ctx, _ []string) {
return
}
middleware.Login(ctx, user)
// Set success message
if sess := ctx.UserValue("session").(*session.Session); sess != nil {
sess.SetFlash("success", fmt.Sprintf("Welcome back, %s!", user.Username))
session.Store(sess)
}
auth.Login(ctx, user)
// Transfer CSRF token from cookie to session for authenticated user
if cookieToken := csrf.GetTokenFromCookie(ctx); cookieToken != "" {
@ -129,12 +117,6 @@ func showRegister(ctx router.Ctx, _ []string) {
// processRegister handles registration form submission
func processRegister(ctx router.Ctx, _ []string) {
if !csrf.ValidateFormToken(ctx) {
ctx.SetStatusCode(fasthttp.StatusForbidden)
ctx.WriteString("CSRF validation failed")
return
}
username := strings.TrimSpace(string(ctx.PostArgs().Peek("username")))
email := strings.TrimSpace(string(ctx.PostArgs().Peek("email")))
userPassword := string(ctx.PostArgs().Peek("password"))
@ -176,8 +158,15 @@ func processRegister(ctx router.Ctx, _ []string) {
return
}
// Store old session ID before creating new one
oldSess := ctx.UserValue("session").(*session.Session)
oldSessionID := oldSess.ID
// Auto-login after registration
middleware.Login(ctx, user)
auth.Login(ctx, user)
// Clean up old guest session
session.Delete(oldSessionID)
// Set success message
if sess := ctx.UserValue("session").(*session.Session); sess != nil {
@ -197,14 +186,7 @@ func processRegister(ctx router.Ctx, _ []string) {
// processLogout handles logout
func processLogout(ctx router.Ctx, params []string) {
// Validate CSRF token
if !csrf.ValidateFormToken(ctx) {
ctx.SetStatusCode(fasthttp.StatusForbidden)
ctx.WriteString("CSRF validation failed")
return
}
middleware.Logout(ctx)
auth.Logout(ctx)
ctx.Redirect("/", fasthttp.StatusFound)
}

View File

@ -2,6 +2,7 @@ package routes
import (
"dk/internal/actions"
"dk/internal/auth"
"dk/internal/helpers"
"dk/internal/middleware"
"dk/internal/models/items"
@ -27,12 +28,12 @@ type Map struct {
func RegisterTownRoutes(r *router.Router) {
group := r.Group("/town")
group.Use(middleware.RequireAuth())
group.Use(auth.RequireAuth())
group.Use(middleware.RequireTown())
group.Get("/", showTown)
group.Get("/inn", showInn)
group.WithMiddleware(middleware.CSRF()).Post("/inn", rest)
group.Post("/inn", rest)
group.Get("/shop", showShop)
group.Get("/shop/buy/:id", buyItem)
group.Get("/maps", showMaps)

View File

@ -4,6 +4,7 @@ import (
"encoding/json"
"os"
"sync"
"time"
)
// SessionManager handles session storage and persistence
@ -15,6 +16,13 @@ type SessionManager struct {
var Manager *SessionManager
// sessionData represents session data for JSON serialization (excludes ID)
type sessionData struct {
UserID int `json:"user_id"`
ExpiresAt int64 `json:"expires_at"`
Data map[string]any `json:"data"`
}
// Init initializes the global session manager
func Init(filePath string) {
if Manager != nil {
@ -112,15 +120,21 @@ func (sm *SessionManager) load() {
return // File doesn't exist or can't be read
}
var sessions map[string]*Session
if err := json.Unmarshal(data, &sessions); err != nil {
var sessionsData map[string]*sessionData
if err := json.Unmarshal(data, &sessionsData); err != nil {
return // Invalid JSON
}
now := time.Now().Unix()
sm.mu.Lock()
for id, sess := range sessions {
if sess != nil && !sess.IsExpired() {
sess.ID = id // Ensure ID consistency
for id, data := range sessionsData {
if data != nil && data.ExpiresAt > now {
sess := &Session{
ID: id,
UserID: data.UserID,
ExpiresAt: data.ExpiresAt,
Data: data.Data,
}
sm.sessions[id] = sess
}
}
@ -136,7 +150,18 @@ func (sm *SessionManager) Save() error {
sm.Cleanup() // Remove expired sessions before saving
sm.mu.RLock()
data, err := json.MarshalIndent(sm.sessions, "", "\t")
// Convert sessions to sessionData (without ID field)
sessionsData := make(map[string]*sessionData, len(sm.sessions))
for id, sess := range sm.sessions {
sessionsData[id] = &sessionData{
UserID: sess.UserID,
ExpiresAt: sess.ExpiresAt,
Data: sess.Data,
}
}
data, err := json.MarshalIndent(sessionsData, "", "\t")
sm.mu.RUnlock()
if err != nil {

View File

@ -16,7 +16,7 @@ const (
type Session struct {
ID string `json:"id"`
UserID int `json:"user_id"` // 0 for guest sessions
ExpiresAt time.Time `json:"expires_at"`
ExpiresAt int64 `json:"expires_at"`
Data map[string]any `json:"data"`
}
@ -25,19 +25,19 @@ func New(userID int) *Session {
return &Session{
ID: generateID(),
UserID: userID,
ExpiresAt: time.Now().Add(DefaultExpiration),
ExpiresAt: time.Now().Add(DefaultExpiration).Unix(),
Data: make(map[string]any),
}
}
// IsExpired checks if the session has expired
func (s *Session) IsExpired() bool {
return time.Now().After(s.ExpiresAt)
return time.Now().Unix() > s.ExpiresAt
}
// Touch extends the session expiration
func (s *Session) Touch() {
s.ExpiresAt = time.Now().Add(DefaultExpiration)
s.ExpiresAt = time.Now().Add(DefaultExpiration).Unix()
}
// Set stores a value in the session
@ -71,6 +71,19 @@ func (s *Session) GetFlash(key string) (any, bool) {
return value, exists
}
// RegenerateID creates a new session ID and updates storage
func (s *Session) RegenerateID() {
oldID := s.ID
s.ID = generateID()
if Manager != nil {
Manager.mu.Lock()
delete(Manager.sessions, oldID)
Manager.sessions[s.ID] = s
Manager.mu.Unlock()
}
}
// generateID creates a random session ID
func generateID() string {
bytes := make([]byte, IDLength)
@ -106,3 +119,8 @@ func Stats() (total, active int) {
func Close() error {
return Manager.Close()
}
// RegenerateID regenerates the session ID for security (package-level convenience)
func RegenerateID(sess *Session) {
sess.RegenerateID()
}

View File

@ -1,10 +1,11 @@
package components
import (
"dk/internal/auth"
"dk/internal/helpers"
"dk/internal/middleware"
"dk/internal/models/spells"
"dk/internal/models/towns"
"dk/internal/models/users"
"dk/internal/router"
)
@ -13,11 +14,12 @@ import (
func LeftAside(ctx router.Ctx) map[string]any {
data := map[string]any{}
user := middleware.GetCurrentUser(ctx)
if user == nil {
if !auth.IsAuthenticated(ctx) {
return data
}
user := ctx.UserValue("user").(*users.User)
// Build owned town maps list
if user.Towns != "" {
townMap := helpers.NewOrderedMap[int, string]()
@ -37,11 +39,12 @@ func LeftAside(ctx router.Ctx) map[string]any {
func RightAside(ctx router.Ctx) map[string]any {
data := map[string]any{}
user := middleware.GetCurrentUser(ctx)
if user == nil {
if !auth.IsAuthenticated(ctx) {
return data
}
user := ctx.UserValue("user").(*users.User)
hpPct := helpers.ClampPct(float64(user.HP), float64(user.MaxHP), 0, 100)
data["hppct"] = hpPct
data["mppct"] = helpers.ClampPct(float64(user.MP), float64(user.MaxMP), 0, 100)

View File

@ -6,6 +6,7 @@ import (
"runtime"
"strings"
"dk/internal/auth"
"dk/internal/csrf"
"dk/internal/middleware"
"dk/internal/router"
@ -28,12 +29,12 @@ func RenderPage(ctx router.Ctx, title, tmplPath string, additionalData map[strin
data := map[string]any{
"_title": PageTitle(title),
"authenticated": middleware.IsAuthenticated(ctx),
"authenticated": auth.IsAuthenticated(ctx),
"csrf": csrf.HiddenField(ctx),
"_totaltime": middleware.GetRequestTime(ctx),
"_version": "1.0.0",
"_build": "dev",
"user": middleware.GetCurrentUser(ctx),
"user": auth.GetCurrentUser(ctx),
"_memalloc": m.Alloc / 1024 / 1024,
}

View File

@ -1,15 +1,15 @@
package components
import (
"dk/internal/auth"
"dk/internal/csrf"
"dk/internal/middleware"
"dk/internal/router"
"fmt"
)
// GenerateTopNav generates the top navigation HTML based on authentication status
func GenerateTopNav(ctx router.Ctx) string {
if middleware.IsAuthenticated(ctx) {
if auth.IsAuthenticated(ctx) {
return fmt.Sprintf(`<form action="/logout" method="post" class="logout">
%s
<button class="img-button" type="submit"><img src="/assets/images/button_logout.gif" alt="Log Out" title="Log Out"></button>

12
main.go
View File

@ -9,6 +9,8 @@ import (
"path/filepath"
"syscall"
"dk/internal/auth"
"dk/internal/csrf"
"dk/internal/middleware"
"dk/internal/models/babble"
"dk/internal/models/control"
@ -165,16 +167,16 @@ func start(port string) error {
return fmt.Errorf("failed to load models: %w", err)
}
session.Init("sessions.json")
session.Init("data/_sessions.json")
r := router.New()
r.Use(middleware.Timing())
r.Use(middleware.Auth())
r.Use(middleware.CSRF())
r.Use(auth.Middleware())
r.Use(csrf.Middleware())
r.Get("/", routes.Index)
r.WithMiddleware(middleware.RequireAuth()).Get("/explore", routes.Explore)
r.WithMiddleware(middleware.RequireAuth()).Post("/move", routes.Move)
r.WithMiddleware(auth.RequireAuth()).Get("/explore", routes.Explore)
r.WithMiddleware(auth.RequireAuth()).Post("/move", routes.Move)
routes.RegisterAuthRoutes(r)
routes.RegisterTownRoutes(r)