move middleware/auth to its own package, more work on session management
This commit is contained in:
parent
c5218c6061
commit
bfe6c12a7a
4
.gitignore
vendored
4
.gitignore
vendored
@ -1,5 +1,5 @@
|
||||
# Dragon Knight test/build files
|
||||
/dk
|
||||
/sessions.json
|
||||
/data/users.json
|
||||
_sessions.json
|
||||
users.json
|
||||
/tmp
|
||||
|
@ -1,4 +1,4 @@
|
||||
package middleware
|
||||
package auth
|
||||
|
||||
import (
|
||||
"dk/internal/cookies"
|
||||
@ -14,7 +14,7 @@ import (
|
||||
|
||||
const SessionCookieName = "dk_session"
|
||||
|
||||
func Auth() router.Middleware {
|
||||
func Middleware() router.Middleware {
|
||||
return func(next router.Handler) router.Handler {
|
||||
return func(ctx router.Ctx, params []string) {
|
||||
sessionID := cookies.GetCookie(ctx, SessionCookieName)
|
||||
@ -108,8 +108,11 @@ func GetCurrentSession(ctx router.Ctx) *session.Session {
|
||||
}
|
||||
|
||||
func Login(ctx router.Ctx, user *users.User) {
|
||||
sess := session.Create(user.ID)
|
||||
setSessionCookie(ctx, sess.ID)
|
||||
sess := ctx.UserValue("session").(*session.Session)
|
||||
sess.RegenerateID()
|
||||
sess.Set("user_id", user.ID)
|
||||
sess.SetFlash("success", fmt.Sprintf("Welcome back, %s!", user.Username))
|
||||
session.Store(sess)
|
||||
|
||||
ctx.SetUserValue("session", sess)
|
||||
ctx.SetUserValue("user", user)
|
@ -199,3 +199,25 @@ func StoreTokenInCookie(ctx router.Ctx, token string) {
|
||||
func GetTokenFromCookie(ctx router.Ctx) string {
|
||||
return string(ctx.Request.Header.Cookie(CookieName))
|
||||
}
|
||||
|
||||
// Middleware returns a middleware function that automatically validates CSRF tokens
|
||||
// for state-changing HTTP methods (POST, PUT, PATCH, DELETE)
|
||||
func Middleware() router.Middleware {
|
||||
return func(next router.Handler) router.Handler {
|
||||
return func(ctx router.Ctx, params []string) {
|
||||
method := string(ctx.Method())
|
||||
|
||||
// Only validate CSRF for state-changing methods
|
||||
if method == "POST" || method == "PUT" || method == "PATCH" || method == "DELETE" {
|
||||
if !ValidateFormToken(ctx) {
|
||||
ctx.SetStatusCode(fasthttp.StatusForbidden)
|
||||
ctx.WriteString("CSRF validation failed")
|
||||
return
|
||||
}
|
||||
}
|
||||
|
||||
// Continue to next handler
|
||||
next(ctx, params)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
@ -1,117 +0,0 @@
|
||||
package middleware
|
||||
|
||||
import (
|
||||
"dk/internal/csrf"
|
||||
"dk/internal/router"
|
||||
"slices"
|
||||
|
||||
"github.com/valyala/fasthttp"
|
||||
)
|
||||
|
||||
// CSRFConfig holds configuration for CSRF middleware
|
||||
type CSRFConfig struct {
|
||||
// Skip CSRF validation for these methods (default: GET, HEAD, OPTIONS)
|
||||
SkipMethods []string
|
||||
// Custom failure handler (default: returns 403)
|
||||
FailureHandler func(ctx router.Ctx)
|
||||
// Skip CSRF for certain paths
|
||||
SkipPaths []string
|
||||
}
|
||||
|
||||
// CSRF creates a CSRF protection middleware
|
||||
func CSRF(config ...CSRFConfig) router.Middleware {
|
||||
cfg := CSRFConfig{
|
||||
SkipMethods: []string{"GET", "HEAD", "OPTIONS"},
|
||||
FailureHandler: func(ctx router.Ctx) {
|
||||
ctx.SetStatusCode(fasthttp.StatusForbidden)
|
||||
ctx.SetContentType("text/plain")
|
||||
ctx.WriteString("CSRF token validation failed")
|
||||
},
|
||||
SkipPaths: []string{},
|
||||
}
|
||||
|
||||
// Apply custom config if provided
|
||||
if len(config) > 0 {
|
||||
if len(config[0].SkipMethods) > 0 {
|
||||
cfg.SkipMethods = config[0].SkipMethods
|
||||
}
|
||||
if config[0].FailureHandler != nil {
|
||||
cfg.FailureHandler = config[0].FailureHandler
|
||||
}
|
||||
if len(config[0].SkipPaths) > 0 {
|
||||
cfg.SkipPaths = config[0].SkipPaths
|
||||
}
|
||||
}
|
||||
|
||||
return func(next router.Handler) router.Handler {
|
||||
return func(ctx router.Ctx, params []string) {
|
||||
method := string(ctx.Method())
|
||||
path := string(ctx.Path())
|
||||
|
||||
// Skip CSRF validation for certain methods
|
||||
shouldSkip := slices.Contains(cfg.SkipMethods, method)
|
||||
|
||||
// Skip CSRF validation for certain paths
|
||||
if !shouldSkip {
|
||||
if slices.Contains(cfg.SkipPaths, path) {
|
||||
shouldSkip = true
|
||||
}
|
||||
}
|
||||
|
||||
// CSRF protection now works for both authenticated and guest users
|
||||
// Remove the skip for non-authenticated users
|
||||
|
||||
if shouldSkip {
|
||||
next(ctx, params)
|
||||
return
|
||||
}
|
||||
|
||||
// Validate CSRF token for protected methods
|
||||
if !csrf.ValidateFormToken(ctx) {
|
||||
cfg.FailureHandler(ctx)
|
||||
return
|
||||
}
|
||||
|
||||
next(ctx, params)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// RequireCSRF is a stricter CSRF middleware that always validates tokens
|
||||
func RequireCSRF(failureHandler ...func(router.Ctx)) router.Middleware {
|
||||
handler := func(ctx router.Ctx) {
|
||||
ctx.SetStatusCode(fasthttp.StatusForbidden)
|
||||
ctx.SetContentType("text/plain")
|
||||
ctx.WriteString("CSRF token required")
|
||||
}
|
||||
|
||||
if len(failureHandler) > 0 {
|
||||
handler = failureHandler[0]
|
||||
}
|
||||
|
||||
return func(next router.Handler) router.Handler {
|
||||
return func(ctx router.Ctx, params []string) {
|
||||
if !csrf.ValidateFormToken(ctx) {
|
||||
handler(ctx)
|
||||
return
|
||||
}
|
||||
|
||||
next(ctx, params)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// CSRFToken returns the current CSRF token for the request
|
||||
func CSRFToken(ctx router.Ctx) string {
|
||||
return csrf.GetToken(ctx)
|
||||
}
|
||||
|
||||
// CSRFHiddenField generates a hidden input field for forms
|
||||
func CSRFHiddenField(ctx router.Ctx) string {
|
||||
return csrf.HiddenField(ctx)
|
||||
}
|
||||
|
||||
// CSRFMeta generates a meta tag for JavaScript access
|
||||
func CSRFMeta(ctx router.Ctx) string {
|
||||
return csrf.TokenMeta(ctx)
|
||||
}
|
@ -1,4 +0,0 @@
|
||||
// Package middleware provides reusable HTTP middleware for the Dragon Knight server.
|
||||
// Middleware functions wrap request handlers to add cross-cutting functionality
|
||||
// like timing, logging, authentication, and request processing.
|
||||
package middleware
|
@ -4,8 +4,8 @@ import (
|
||||
"fmt"
|
||||
"strings"
|
||||
|
||||
"dk/internal/auth"
|
||||
"dk/internal/csrf"
|
||||
"dk/internal/middleware"
|
||||
"dk/internal/models/users"
|
||||
"dk/internal/password"
|
||||
"dk/internal/router"
|
||||
@ -18,7 +18,7 @@ import (
|
||||
// RegisterAuthRoutes sets up authentication routes
|
||||
func RegisterAuthRoutes(r *router.Router) {
|
||||
guests := r.Group("")
|
||||
guests.Use(middleware.RequireGuest())
|
||||
guests.Use(auth.RequireGuest())
|
||||
|
||||
guests.Get("/login", showLogin)
|
||||
guests.Post("/login", processLogin)
|
||||
@ -26,7 +26,7 @@ func RegisterAuthRoutes(r *router.Router) {
|
||||
guests.Post("/register", processRegister)
|
||||
|
||||
authed := r.Group("")
|
||||
authed.Use(middleware.RequireAuth())
|
||||
authed.Use(auth.RequireAuth())
|
||||
|
||||
authed.Post("/logout", processLogout)
|
||||
}
|
||||
@ -59,12 +59,6 @@ func showLogin(ctx router.Ctx, _ []string) {
|
||||
|
||||
// processLogin handles login form submission
|
||||
func processLogin(ctx router.Ctx, _ []string) {
|
||||
if !csrf.ValidateFormToken(ctx) {
|
||||
ctx.SetStatusCode(fasthttp.StatusForbidden)
|
||||
ctx.WriteString("CSRF validation failed")
|
||||
return
|
||||
}
|
||||
|
||||
email := strings.TrimSpace(string(ctx.PostArgs().Peek("id")))
|
||||
userPassword := string(ctx.PostArgs().Peek("password"))
|
||||
|
||||
@ -81,13 +75,7 @@ func processLogin(ctx router.Ctx, _ []string) {
|
||||
return
|
||||
}
|
||||
|
||||
middleware.Login(ctx, user)
|
||||
|
||||
// Set success message
|
||||
if sess := ctx.UserValue("session").(*session.Session); sess != nil {
|
||||
sess.SetFlash("success", fmt.Sprintf("Welcome back, %s!", user.Username))
|
||||
session.Store(sess)
|
||||
}
|
||||
auth.Login(ctx, user)
|
||||
|
||||
// Transfer CSRF token from cookie to session for authenticated user
|
||||
if cookieToken := csrf.GetTokenFromCookie(ctx); cookieToken != "" {
|
||||
@ -129,12 +117,6 @@ func showRegister(ctx router.Ctx, _ []string) {
|
||||
|
||||
// processRegister handles registration form submission
|
||||
func processRegister(ctx router.Ctx, _ []string) {
|
||||
if !csrf.ValidateFormToken(ctx) {
|
||||
ctx.SetStatusCode(fasthttp.StatusForbidden)
|
||||
ctx.WriteString("CSRF validation failed")
|
||||
return
|
||||
}
|
||||
|
||||
username := strings.TrimSpace(string(ctx.PostArgs().Peek("username")))
|
||||
email := strings.TrimSpace(string(ctx.PostArgs().Peek("email")))
|
||||
userPassword := string(ctx.PostArgs().Peek("password"))
|
||||
@ -176,8 +158,15 @@ func processRegister(ctx router.Ctx, _ []string) {
|
||||
return
|
||||
}
|
||||
|
||||
// Store old session ID before creating new one
|
||||
oldSess := ctx.UserValue("session").(*session.Session)
|
||||
oldSessionID := oldSess.ID
|
||||
|
||||
// Auto-login after registration
|
||||
middleware.Login(ctx, user)
|
||||
auth.Login(ctx, user)
|
||||
|
||||
// Clean up old guest session
|
||||
session.Delete(oldSessionID)
|
||||
|
||||
// Set success message
|
||||
if sess := ctx.UserValue("session").(*session.Session); sess != nil {
|
||||
@ -197,14 +186,7 @@ func processRegister(ctx router.Ctx, _ []string) {
|
||||
|
||||
// processLogout handles logout
|
||||
func processLogout(ctx router.Ctx, params []string) {
|
||||
// Validate CSRF token
|
||||
if !csrf.ValidateFormToken(ctx) {
|
||||
ctx.SetStatusCode(fasthttp.StatusForbidden)
|
||||
ctx.WriteString("CSRF validation failed")
|
||||
return
|
||||
}
|
||||
|
||||
middleware.Logout(ctx)
|
||||
auth.Logout(ctx)
|
||||
ctx.Redirect("/", fasthttp.StatusFound)
|
||||
}
|
||||
|
||||
|
@ -2,6 +2,7 @@ package routes
|
||||
|
||||
import (
|
||||
"dk/internal/actions"
|
||||
"dk/internal/auth"
|
||||
"dk/internal/helpers"
|
||||
"dk/internal/middleware"
|
||||
"dk/internal/models/items"
|
||||
@ -27,12 +28,12 @@ type Map struct {
|
||||
|
||||
func RegisterTownRoutes(r *router.Router) {
|
||||
group := r.Group("/town")
|
||||
group.Use(middleware.RequireAuth())
|
||||
group.Use(auth.RequireAuth())
|
||||
group.Use(middleware.RequireTown())
|
||||
|
||||
group.Get("/", showTown)
|
||||
group.Get("/inn", showInn)
|
||||
group.WithMiddleware(middleware.CSRF()).Post("/inn", rest)
|
||||
group.Post("/inn", rest)
|
||||
group.Get("/shop", showShop)
|
||||
group.Get("/shop/buy/:id", buyItem)
|
||||
group.Get("/maps", showMaps)
|
||||
|
@ -4,6 +4,7 @@ import (
|
||||
"encoding/json"
|
||||
"os"
|
||||
"sync"
|
||||
"time"
|
||||
)
|
||||
|
||||
// SessionManager handles session storage and persistence
|
||||
@ -15,6 +16,13 @@ type SessionManager struct {
|
||||
|
||||
var Manager *SessionManager
|
||||
|
||||
// sessionData represents session data for JSON serialization (excludes ID)
|
||||
type sessionData struct {
|
||||
UserID int `json:"user_id"`
|
||||
ExpiresAt int64 `json:"expires_at"`
|
||||
Data map[string]any `json:"data"`
|
||||
}
|
||||
|
||||
// Init initializes the global session manager
|
||||
func Init(filePath string) {
|
||||
if Manager != nil {
|
||||
@ -112,15 +120,21 @@ func (sm *SessionManager) load() {
|
||||
return // File doesn't exist or can't be read
|
||||
}
|
||||
|
||||
var sessions map[string]*Session
|
||||
if err := json.Unmarshal(data, &sessions); err != nil {
|
||||
var sessionsData map[string]*sessionData
|
||||
if err := json.Unmarshal(data, &sessionsData); err != nil {
|
||||
return // Invalid JSON
|
||||
}
|
||||
|
||||
now := time.Now().Unix()
|
||||
sm.mu.Lock()
|
||||
for id, sess := range sessions {
|
||||
if sess != nil && !sess.IsExpired() {
|
||||
sess.ID = id // Ensure ID consistency
|
||||
for id, data := range sessionsData {
|
||||
if data != nil && data.ExpiresAt > now {
|
||||
sess := &Session{
|
||||
ID: id,
|
||||
UserID: data.UserID,
|
||||
ExpiresAt: data.ExpiresAt,
|
||||
Data: data.Data,
|
||||
}
|
||||
sm.sessions[id] = sess
|
||||
}
|
||||
}
|
||||
@ -136,7 +150,18 @@ func (sm *SessionManager) Save() error {
|
||||
sm.Cleanup() // Remove expired sessions before saving
|
||||
|
||||
sm.mu.RLock()
|
||||
data, err := json.MarshalIndent(sm.sessions, "", "\t")
|
||||
|
||||
// Convert sessions to sessionData (without ID field)
|
||||
sessionsData := make(map[string]*sessionData, len(sm.sessions))
|
||||
for id, sess := range sm.sessions {
|
||||
sessionsData[id] = &sessionData{
|
||||
UserID: sess.UserID,
|
||||
ExpiresAt: sess.ExpiresAt,
|
||||
Data: sess.Data,
|
||||
}
|
||||
}
|
||||
|
||||
data, err := json.MarshalIndent(sessionsData, "", "\t")
|
||||
sm.mu.RUnlock()
|
||||
|
||||
if err != nil {
|
||||
|
@ -16,7 +16,7 @@ const (
|
||||
type Session struct {
|
||||
ID string `json:"id"`
|
||||
UserID int `json:"user_id"` // 0 for guest sessions
|
||||
ExpiresAt time.Time `json:"expires_at"`
|
||||
ExpiresAt int64 `json:"expires_at"`
|
||||
Data map[string]any `json:"data"`
|
||||
}
|
||||
|
||||
@ -25,19 +25,19 @@ func New(userID int) *Session {
|
||||
return &Session{
|
||||
ID: generateID(),
|
||||
UserID: userID,
|
||||
ExpiresAt: time.Now().Add(DefaultExpiration),
|
||||
ExpiresAt: time.Now().Add(DefaultExpiration).Unix(),
|
||||
Data: make(map[string]any),
|
||||
}
|
||||
}
|
||||
|
||||
// IsExpired checks if the session has expired
|
||||
func (s *Session) IsExpired() bool {
|
||||
return time.Now().After(s.ExpiresAt)
|
||||
return time.Now().Unix() > s.ExpiresAt
|
||||
}
|
||||
|
||||
// Touch extends the session expiration
|
||||
func (s *Session) Touch() {
|
||||
s.ExpiresAt = time.Now().Add(DefaultExpiration)
|
||||
s.ExpiresAt = time.Now().Add(DefaultExpiration).Unix()
|
||||
}
|
||||
|
||||
// Set stores a value in the session
|
||||
@ -71,6 +71,19 @@ func (s *Session) GetFlash(key string) (any, bool) {
|
||||
return value, exists
|
||||
}
|
||||
|
||||
// RegenerateID creates a new session ID and updates storage
|
||||
func (s *Session) RegenerateID() {
|
||||
oldID := s.ID
|
||||
s.ID = generateID()
|
||||
|
||||
if Manager != nil {
|
||||
Manager.mu.Lock()
|
||||
delete(Manager.sessions, oldID)
|
||||
Manager.sessions[s.ID] = s
|
||||
Manager.mu.Unlock()
|
||||
}
|
||||
}
|
||||
|
||||
// generateID creates a random session ID
|
||||
func generateID() string {
|
||||
bytes := make([]byte, IDLength)
|
||||
@ -106,3 +119,8 @@ func Stats() (total, active int) {
|
||||
func Close() error {
|
||||
return Manager.Close()
|
||||
}
|
||||
|
||||
// RegenerateID regenerates the session ID for security (package-level convenience)
|
||||
func RegenerateID(sess *Session) {
|
||||
sess.RegenerateID()
|
||||
}
|
||||
|
@ -1,10 +1,11 @@
|
||||
package components
|
||||
|
||||
import (
|
||||
"dk/internal/auth"
|
||||
"dk/internal/helpers"
|
||||
"dk/internal/middleware"
|
||||
"dk/internal/models/spells"
|
||||
"dk/internal/models/towns"
|
||||
"dk/internal/models/users"
|
||||
"dk/internal/router"
|
||||
)
|
||||
|
||||
@ -13,11 +14,12 @@ import (
|
||||
func LeftAside(ctx router.Ctx) map[string]any {
|
||||
data := map[string]any{}
|
||||
|
||||
user := middleware.GetCurrentUser(ctx)
|
||||
if user == nil {
|
||||
if !auth.IsAuthenticated(ctx) {
|
||||
return data
|
||||
}
|
||||
|
||||
user := ctx.UserValue("user").(*users.User)
|
||||
|
||||
// Build owned town maps list
|
||||
if user.Towns != "" {
|
||||
townMap := helpers.NewOrderedMap[int, string]()
|
||||
@ -37,11 +39,12 @@ func LeftAside(ctx router.Ctx) map[string]any {
|
||||
func RightAside(ctx router.Ctx) map[string]any {
|
||||
data := map[string]any{}
|
||||
|
||||
user := middleware.GetCurrentUser(ctx)
|
||||
if user == nil {
|
||||
if !auth.IsAuthenticated(ctx) {
|
||||
return data
|
||||
}
|
||||
|
||||
user := ctx.UserValue("user").(*users.User)
|
||||
|
||||
hpPct := helpers.ClampPct(float64(user.HP), float64(user.MaxHP), 0, 100)
|
||||
data["hppct"] = hpPct
|
||||
data["mppct"] = helpers.ClampPct(float64(user.MP), float64(user.MaxMP), 0, 100)
|
||||
|
@ -6,6 +6,7 @@ import (
|
||||
"runtime"
|
||||
"strings"
|
||||
|
||||
"dk/internal/auth"
|
||||
"dk/internal/csrf"
|
||||
"dk/internal/middleware"
|
||||
"dk/internal/router"
|
||||
@ -28,12 +29,12 @@ func RenderPage(ctx router.Ctx, title, tmplPath string, additionalData map[strin
|
||||
|
||||
data := map[string]any{
|
||||
"_title": PageTitle(title),
|
||||
"authenticated": middleware.IsAuthenticated(ctx),
|
||||
"authenticated": auth.IsAuthenticated(ctx),
|
||||
"csrf": csrf.HiddenField(ctx),
|
||||
"_totaltime": middleware.GetRequestTime(ctx),
|
||||
"_version": "1.0.0",
|
||||
"_build": "dev",
|
||||
"user": middleware.GetCurrentUser(ctx),
|
||||
"user": auth.GetCurrentUser(ctx),
|
||||
"_memalloc": m.Alloc / 1024 / 1024,
|
||||
}
|
||||
|
||||
|
@ -1,15 +1,15 @@
|
||||
package components
|
||||
|
||||
import (
|
||||
"dk/internal/auth"
|
||||
"dk/internal/csrf"
|
||||
"dk/internal/middleware"
|
||||
"dk/internal/router"
|
||||
"fmt"
|
||||
)
|
||||
|
||||
// GenerateTopNav generates the top navigation HTML based on authentication status
|
||||
func GenerateTopNav(ctx router.Ctx) string {
|
||||
if middleware.IsAuthenticated(ctx) {
|
||||
if auth.IsAuthenticated(ctx) {
|
||||
return fmt.Sprintf(`<form action="/logout" method="post" class="logout">
|
||||
%s
|
||||
<button class="img-button" type="submit"><img src="/assets/images/button_logout.gif" alt="Log Out" title="Log Out"></button>
|
||||
|
12
main.go
12
main.go
@ -9,6 +9,8 @@ import (
|
||||
"path/filepath"
|
||||
"syscall"
|
||||
|
||||
"dk/internal/auth"
|
||||
"dk/internal/csrf"
|
||||
"dk/internal/middleware"
|
||||
"dk/internal/models/babble"
|
||||
"dk/internal/models/control"
|
||||
@ -165,16 +167,16 @@ func start(port string) error {
|
||||
return fmt.Errorf("failed to load models: %w", err)
|
||||
}
|
||||
|
||||
session.Init("sessions.json")
|
||||
session.Init("data/_sessions.json")
|
||||
|
||||
r := router.New()
|
||||
r.Use(middleware.Timing())
|
||||
r.Use(middleware.Auth())
|
||||
r.Use(middleware.CSRF())
|
||||
r.Use(auth.Middleware())
|
||||
r.Use(csrf.Middleware())
|
||||
|
||||
r.Get("/", routes.Index)
|
||||
r.WithMiddleware(middleware.RequireAuth()).Get("/explore", routes.Explore)
|
||||
r.WithMiddleware(middleware.RequireAuth()).Post("/move", routes.Move)
|
||||
r.WithMiddleware(auth.RequireAuth()).Get("/explore", routes.Explore)
|
||||
r.WithMiddleware(auth.RequireAuth()).Post("/move", routes.Move)
|
||||
routes.RegisterAuthRoutes(r)
|
||||
routes.RegisterTownRoutes(r)
|
||||
|
||||
|
Loading…
x
Reference in New Issue
Block a user