move middleware/auth to its own package, more work on session management
This commit is contained in:
parent
c5218c6061
commit
bfe6c12a7a
4
.gitignore
vendored
4
.gitignore
vendored
@ -1,5 +1,5 @@
|
|||||||
# Dragon Knight test/build files
|
# Dragon Knight test/build files
|
||||||
/dk
|
/dk
|
||||||
/sessions.json
|
_sessions.json
|
||||||
/data/users.json
|
users.json
|
||||||
/tmp
|
/tmp
|
||||||
|
@ -1,4 +1,4 @@
|
|||||||
package middleware
|
package auth
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"dk/internal/cookies"
|
"dk/internal/cookies"
|
||||||
@ -14,7 +14,7 @@ import (
|
|||||||
|
|
||||||
const SessionCookieName = "dk_session"
|
const SessionCookieName = "dk_session"
|
||||||
|
|
||||||
func Auth() router.Middleware {
|
func Middleware() router.Middleware {
|
||||||
return func(next router.Handler) router.Handler {
|
return func(next router.Handler) router.Handler {
|
||||||
return func(ctx router.Ctx, params []string) {
|
return func(ctx router.Ctx, params []string) {
|
||||||
sessionID := cookies.GetCookie(ctx, SessionCookieName)
|
sessionID := cookies.GetCookie(ctx, SessionCookieName)
|
||||||
@ -108,8 +108,11 @@ func GetCurrentSession(ctx router.Ctx) *session.Session {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func Login(ctx router.Ctx, user *users.User) {
|
func Login(ctx router.Ctx, user *users.User) {
|
||||||
sess := session.Create(user.ID)
|
sess := ctx.UserValue("session").(*session.Session)
|
||||||
setSessionCookie(ctx, sess.ID)
|
sess.RegenerateID()
|
||||||
|
sess.Set("user_id", user.ID)
|
||||||
|
sess.SetFlash("success", fmt.Sprintf("Welcome back, %s!", user.Username))
|
||||||
|
session.Store(sess)
|
||||||
|
|
||||||
ctx.SetUserValue("session", sess)
|
ctx.SetUserValue("session", sess)
|
||||||
ctx.SetUserValue("user", user)
|
ctx.SetUserValue("user", user)
|
@ -199,3 +199,25 @@ func StoreTokenInCookie(ctx router.Ctx, token string) {
|
|||||||
func GetTokenFromCookie(ctx router.Ctx) string {
|
func GetTokenFromCookie(ctx router.Ctx) string {
|
||||||
return string(ctx.Request.Header.Cookie(CookieName))
|
return string(ctx.Request.Header.Cookie(CookieName))
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Middleware returns a middleware function that automatically validates CSRF tokens
|
||||||
|
// for state-changing HTTP methods (POST, PUT, PATCH, DELETE)
|
||||||
|
func Middleware() router.Middleware {
|
||||||
|
return func(next router.Handler) router.Handler {
|
||||||
|
return func(ctx router.Ctx, params []string) {
|
||||||
|
method := string(ctx.Method())
|
||||||
|
|
||||||
|
// Only validate CSRF for state-changing methods
|
||||||
|
if method == "POST" || method == "PUT" || method == "PATCH" || method == "DELETE" {
|
||||||
|
if !ValidateFormToken(ctx) {
|
||||||
|
ctx.SetStatusCode(fasthttp.StatusForbidden)
|
||||||
|
ctx.WriteString("CSRF validation failed")
|
||||||
|
return
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Continue to next handler
|
||||||
|
next(ctx, params)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
@ -1,117 +0,0 @@
|
|||||||
package middleware
|
|
||||||
|
|
||||||
import (
|
|
||||||
"dk/internal/csrf"
|
|
||||||
"dk/internal/router"
|
|
||||||
"slices"
|
|
||||||
|
|
||||||
"github.com/valyala/fasthttp"
|
|
||||||
)
|
|
||||||
|
|
||||||
// CSRFConfig holds configuration for CSRF middleware
|
|
||||||
type CSRFConfig struct {
|
|
||||||
// Skip CSRF validation for these methods (default: GET, HEAD, OPTIONS)
|
|
||||||
SkipMethods []string
|
|
||||||
// Custom failure handler (default: returns 403)
|
|
||||||
FailureHandler func(ctx router.Ctx)
|
|
||||||
// Skip CSRF for certain paths
|
|
||||||
SkipPaths []string
|
|
||||||
}
|
|
||||||
|
|
||||||
// CSRF creates a CSRF protection middleware
|
|
||||||
func CSRF(config ...CSRFConfig) router.Middleware {
|
|
||||||
cfg := CSRFConfig{
|
|
||||||
SkipMethods: []string{"GET", "HEAD", "OPTIONS"},
|
|
||||||
FailureHandler: func(ctx router.Ctx) {
|
|
||||||
ctx.SetStatusCode(fasthttp.StatusForbidden)
|
|
||||||
ctx.SetContentType("text/plain")
|
|
||||||
ctx.WriteString("CSRF token validation failed")
|
|
||||||
},
|
|
||||||
SkipPaths: []string{},
|
|
||||||
}
|
|
||||||
|
|
||||||
// Apply custom config if provided
|
|
||||||
if len(config) > 0 {
|
|
||||||
if len(config[0].SkipMethods) > 0 {
|
|
||||||
cfg.SkipMethods = config[0].SkipMethods
|
|
||||||
}
|
|
||||||
if config[0].FailureHandler != nil {
|
|
||||||
cfg.FailureHandler = config[0].FailureHandler
|
|
||||||
}
|
|
||||||
if len(config[0].SkipPaths) > 0 {
|
|
||||||
cfg.SkipPaths = config[0].SkipPaths
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
return func(next router.Handler) router.Handler {
|
|
||||||
return func(ctx router.Ctx, params []string) {
|
|
||||||
method := string(ctx.Method())
|
|
||||||
path := string(ctx.Path())
|
|
||||||
|
|
||||||
// Skip CSRF validation for certain methods
|
|
||||||
shouldSkip := slices.Contains(cfg.SkipMethods, method)
|
|
||||||
|
|
||||||
// Skip CSRF validation for certain paths
|
|
||||||
if !shouldSkip {
|
|
||||||
if slices.Contains(cfg.SkipPaths, path) {
|
|
||||||
shouldSkip = true
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// CSRF protection now works for both authenticated and guest users
|
|
||||||
// Remove the skip for non-authenticated users
|
|
||||||
|
|
||||||
if shouldSkip {
|
|
||||||
next(ctx, params)
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
// Validate CSRF token for protected methods
|
|
||||||
if !csrf.ValidateFormToken(ctx) {
|
|
||||||
cfg.FailureHandler(ctx)
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
next(ctx, params)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// RequireCSRF is a stricter CSRF middleware that always validates tokens
|
|
||||||
func RequireCSRF(failureHandler ...func(router.Ctx)) router.Middleware {
|
|
||||||
handler := func(ctx router.Ctx) {
|
|
||||||
ctx.SetStatusCode(fasthttp.StatusForbidden)
|
|
||||||
ctx.SetContentType("text/plain")
|
|
||||||
ctx.WriteString("CSRF token required")
|
|
||||||
}
|
|
||||||
|
|
||||||
if len(failureHandler) > 0 {
|
|
||||||
handler = failureHandler[0]
|
|
||||||
}
|
|
||||||
|
|
||||||
return func(next router.Handler) router.Handler {
|
|
||||||
return func(ctx router.Ctx, params []string) {
|
|
||||||
if !csrf.ValidateFormToken(ctx) {
|
|
||||||
handler(ctx)
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
next(ctx, params)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// CSRFToken returns the current CSRF token for the request
|
|
||||||
func CSRFToken(ctx router.Ctx) string {
|
|
||||||
return csrf.GetToken(ctx)
|
|
||||||
}
|
|
||||||
|
|
||||||
// CSRFHiddenField generates a hidden input field for forms
|
|
||||||
func CSRFHiddenField(ctx router.Ctx) string {
|
|
||||||
return csrf.HiddenField(ctx)
|
|
||||||
}
|
|
||||||
|
|
||||||
// CSRFMeta generates a meta tag for JavaScript access
|
|
||||||
func CSRFMeta(ctx router.Ctx) string {
|
|
||||||
return csrf.TokenMeta(ctx)
|
|
||||||
}
|
|
@ -1,4 +0,0 @@
|
|||||||
// Package middleware provides reusable HTTP middleware for the Dragon Knight server.
|
|
||||||
// Middleware functions wrap request handlers to add cross-cutting functionality
|
|
||||||
// like timing, logging, authentication, and request processing.
|
|
||||||
package middleware
|
|
@ -4,8 +4,8 @@ import (
|
|||||||
"fmt"
|
"fmt"
|
||||||
"strings"
|
"strings"
|
||||||
|
|
||||||
|
"dk/internal/auth"
|
||||||
"dk/internal/csrf"
|
"dk/internal/csrf"
|
||||||
"dk/internal/middleware"
|
|
||||||
"dk/internal/models/users"
|
"dk/internal/models/users"
|
||||||
"dk/internal/password"
|
"dk/internal/password"
|
||||||
"dk/internal/router"
|
"dk/internal/router"
|
||||||
@ -18,7 +18,7 @@ import (
|
|||||||
// RegisterAuthRoutes sets up authentication routes
|
// RegisterAuthRoutes sets up authentication routes
|
||||||
func RegisterAuthRoutes(r *router.Router) {
|
func RegisterAuthRoutes(r *router.Router) {
|
||||||
guests := r.Group("")
|
guests := r.Group("")
|
||||||
guests.Use(middleware.RequireGuest())
|
guests.Use(auth.RequireGuest())
|
||||||
|
|
||||||
guests.Get("/login", showLogin)
|
guests.Get("/login", showLogin)
|
||||||
guests.Post("/login", processLogin)
|
guests.Post("/login", processLogin)
|
||||||
@ -26,7 +26,7 @@ func RegisterAuthRoutes(r *router.Router) {
|
|||||||
guests.Post("/register", processRegister)
|
guests.Post("/register", processRegister)
|
||||||
|
|
||||||
authed := r.Group("")
|
authed := r.Group("")
|
||||||
authed.Use(middleware.RequireAuth())
|
authed.Use(auth.RequireAuth())
|
||||||
|
|
||||||
authed.Post("/logout", processLogout)
|
authed.Post("/logout", processLogout)
|
||||||
}
|
}
|
||||||
@ -59,12 +59,6 @@ func showLogin(ctx router.Ctx, _ []string) {
|
|||||||
|
|
||||||
// processLogin handles login form submission
|
// processLogin handles login form submission
|
||||||
func processLogin(ctx router.Ctx, _ []string) {
|
func processLogin(ctx router.Ctx, _ []string) {
|
||||||
if !csrf.ValidateFormToken(ctx) {
|
|
||||||
ctx.SetStatusCode(fasthttp.StatusForbidden)
|
|
||||||
ctx.WriteString("CSRF validation failed")
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
email := strings.TrimSpace(string(ctx.PostArgs().Peek("id")))
|
email := strings.TrimSpace(string(ctx.PostArgs().Peek("id")))
|
||||||
userPassword := string(ctx.PostArgs().Peek("password"))
|
userPassword := string(ctx.PostArgs().Peek("password"))
|
||||||
|
|
||||||
@ -81,13 +75,7 @@ func processLogin(ctx router.Ctx, _ []string) {
|
|||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
middleware.Login(ctx, user)
|
auth.Login(ctx, user)
|
||||||
|
|
||||||
// Set success message
|
|
||||||
if sess := ctx.UserValue("session").(*session.Session); sess != nil {
|
|
||||||
sess.SetFlash("success", fmt.Sprintf("Welcome back, %s!", user.Username))
|
|
||||||
session.Store(sess)
|
|
||||||
}
|
|
||||||
|
|
||||||
// Transfer CSRF token from cookie to session for authenticated user
|
// Transfer CSRF token from cookie to session for authenticated user
|
||||||
if cookieToken := csrf.GetTokenFromCookie(ctx); cookieToken != "" {
|
if cookieToken := csrf.GetTokenFromCookie(ctx); cookieToken != "" {
|
||||||
@ -129,12 +117,6 @@ func showRegister(ctx router.Ctx, _ []string) {
|
|||||||
|
|
||||||
// processRegister handles registration form submission
|
// processRegister handles registration form submission
|
||||||
func processRegister(ctx router.Ctx, _ []string) {
|
func processRegister(ctx router.Ctx, _ []string) {
|
||||||
if !csrf.ValidateFormToken(ctx) {
|
|
||||||
ctx.SetStatusCode(fasthttp.StatusForbidden)
|
|
||||||
ctx.WriteString("CSRF validation failed")
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
username := strings.TrimSpace(string(ctx.PostArgs().Peek("username")))
|
username := strings.TrimSpace(string(ctx.PostArgs().Peek("username")))
|
||||||
email := strings.TrimSpace(string(ctx.PostArgs().Peek("email")))
|
email := strings.TrimSpace(string(ctx.PostArgs().Peek("email")))
|
||||||
userPassword := string(ctx.PostArgs().Peek("password"))
|
userPassword := string(ctx.PostArgs().Peek("password"))
|
||||||
@ -176,8 +158,15 @@ func processRegister(ctx router.Ctx, _ []string) {
|
|||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Store old session ID before creating new one
|
||||||
|
oldSess := ctx.UserValue("session").(*session.Session)
|
||||||
|
oldSessionID := oldSess.ID
|
||||||
|
|
||||||
// Auto-login after registration
|
// Auto-login after registration
|
||||||
middleware.Login(ctx, user)
|
auth.Login(ctx, user)
|
||||||
|
|
||||||
|
// Clean up old guest session
|
||||||
|
session.Delete(oldSessionID)
|
||||||
|
|
||||||
// Set success message
|
// Set success message
|
||||||
if sess := ctx.UserValue("session").(*session.Session); sess != nil {
|
if sess := ctx.UserValue("session").(*session.Session); sess != nil {
|
||||||
@ -197,14 +186,7 @@ func processRegister(ctx router.Ctx, _ []string) {
|
|||||||
|
|
||||||
// processLogout handles logout
|
// processLogout handles logout
|
||||||
func processLogout(ctx router.Ctx, params []string) {
|
func processLogout(ctx router.Ctx, params []string) {
|
||||||
// Validate CSRF token
|
auth.Logout(ctx)
|
||||||
if !csrf.ValidateFormToken(ctx) {
|
|
||||||
ctx.SetStatusCode(fasthttp.StatusForbidden)
|
|
||||||
ctx.WriteString("CSRF validation failed")
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
middleware.Logout(ctx)
|
|
||||||
ctx.Redirect("/", fasthttp.StatusFound)
|
ctx.Redirect("/", fasthttp.StatusFound)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -2,6 +2,7 @@ package routes
|
|||||||
|
|
||||||
import (
|
import (
|
||||||
"dk/internal/actions"
|
"dk/internal/actions"
|
||||||
|
"dk/internal/auth"
|
||||||
"dk/internal/helpers"
|
"dk/internal/helpers"
|
||||||
"dk/internal/middleware"
|
"dk/internal/middleware"
|
||||||
"dk/internal/models/items"
|
"dk/internal/models/items"
|
||||||
@ -27,12 +28,12 @@ type Map struct {
|
|||||||
|
|
||||||
func RegisterTownRoutes(r *router.Router) {
|
func RegisterTownRoutes(r *router.Router) {
|
||||||
group := r.Group("/town")
|
group := r.Group("/town")
|
||||||
group.Use(middleware.RequireAuth())
|
group.Use(auth.RequireAuth())
|
||||||
group.Use(middleware.RequireTown())
|
group.Use(middleware.RequireTown())
|
||||||
|
|
||||||
group.Get("/", showTown)
|
group.Get("/", showTown)
|
||||||
group.Get("/inn", showInn)
|
group.Get("/inn", showInn)
|
||||||
group.WithMiddleware(middleware.CSRF()).Post("/inn", rest)
|
group.Post("/inn", rest)
|
||||||
group.Get("/shop", showShop)
|
group.Get("/shop", showShop)
|
||||||
group.Get("/shop/buy/:id", buyItem)
|
group.Get("/shop/buy/:id", buyItem)
|
||||||
group.Get("/maps", showMaps)
|
group.Get("/maps", showMaps)
|
||||||
|
@ -4,6 +4,7 @@ import (
|
|||||||
"encoding/json"
|
"encoding/json"
|
||||||
"os"
|
"os"
|
||||||
"sync"
|
"sync"
|
||||||
|
"time"
|
||||||
)
|
)
|
||||||
|
|
||||||
// SessionManager handles session storage and persistence
|
// SessionManager handles session storage and persistence
|
||||||
@ -15,6 +16,13 @@ type SessionManager struct {
|
|||||||
|
|
||||||
var Manager *SessionManager
|
var Manager *SessionManager
|
||||||
|
|
||||||
|
// sessionData represents session data for JSON serialization (excludes ID)
|
||||||
|
type sessionData struct {
|
||||||
|
UserID int `json:"user_id"`
|
||||||
|
ExpiresAt int64 `json:"expires_at"`
|
||||||
|
Data map[string]any `json:"data"`
|
||||||
|
}
|
||||||
|
|
||||||
// Init initializes the global session manager
|
// Init initializes the global session manager
|
||||||
func Init(filePath string) {
|
func Init(filePath string) {
|
||||||
if Manager != nil {
|
if Manager != nil {
|
||||||
@ -112,15 +120,21 @@ func (sm *SessionManager) load() {
|
|||||||
return // File doesn't exist or can't be read
|
return // File doesn't exist or can't be read
|
||||||
}
|
}
|
||||||
|
|
||||||
var sessions map[string]*Session
|
var sessionsData map[string]*sessionData
|
||||||
if err := json.Unmarshal(data, &sessions); err != nil {
|
if err := json.Unmarshal(data, &sessionsData); err != nil {
|
||||||
return // Invalid JSON
|
return // Invalid JSON
|
||||||
}
|
}
|
||||||
|
|
||||||
|
now := time.Now().Unix()
|
||||||
sm.mu.Lock()
|
sm.mu.Lock()
|
||||||
for id, sess := range sessions {
|
for id, data := range sessionsData {
|
||||||
if sess != nil && !sess.IsExpired() {
|
if data != nil && data.ExpiresAt > now {
|
||||||
sess.ID = id // Ensure ID consistency
|
sess := &Session{
|
||||||
|
ID: id,
|
||||||
|
UserID: data.UserID,
|
||||||
|
ExpiresAt: data.ExpiresAt,
|
||||||
|
Data: data.Data,
|
||||||
|
}
|
||||||
sm.sessions[id] = sess
|
sm.sessions[id] = sess
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@ -136,7 +150,18 @@ func (sm *SessionManager) Save() error {
|
|||||||
sm.Cleanup() // Remove expired sessions before saving
|
sm.Cleanup() // Remove expired sessions before saving
|
||||||
|
|
||||||
sm.mu.RLock()
|
sm.mu.RLock()
|
||||||
data, err := json.MarshalIndent(sm.sessions, "", "\t")
|
|
||||||
|
// Convert sessions to sessionData (without ID field)
|
||||||
|
sessionsData := make(map[string]*sessionData, len(sm.sessions))
|
||||||
|
for id, sess := range sm.sessions {
|
||||||
|
sessionsData[id] = &sessionData{
|
||||||
|
UserID: sess.UserID,
|
||||||
|
ExpiresAt: sess.ExpiresAt,
|
||||||
|
Data: sess.Data,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
data, err := json.MarshalIndent(sessionsData, "", "\t")
|
||||||
sm.mu.RUnlock()
|
sm.mu.RUnlock()
|
||||||
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
|
@ -16,7 +16,7 @@ const (
|
|||||||
type Session struct {
|
type Session struct {
|
||||||
ID string `json:"id"`
|
ID string `json:"id"`
|
||||||
UserID int `json:"user_id"` // 0 for guest sessions
|
UserID int `json:"user_id"` // 0 for guest sessions
|
||||||
ExpiresAt time.Time `json:"expires_at"`
|
ExpiresAt int64 `json:"expires_at"`
|
||||||
Data map[string]any `json:"data"`
|
Data map[string]any `json:"data"`
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -25,19 +25,19 @@ func New(userID int) *Session {
|
|||||||
return &Session{
|
return &Session{
|
||||||
ID: generateID(),
|
ID: generateID(),
|
||||||
UserID: userID,
|
UserID: userID,
|
||||||
ExpiresAt: time.Now().Add(DefaultExpiration),
|
ExpiresAt: time.Now().Add(DefaultExpiration).Unix(),
|
||||||
Data: make(map[string]any),
|
Data: make(map[string]any),
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// IsExpired checks if the session has expired
|
// IsExpired checks if the session has expired
|
||||||
func (s *Session) IsExpired() bool {
|
func (s *Session) IsExpired() bool {
|
||||||
return time.Now().After(s.ExpiresAt)
|
return time.Now().Unix() > s.ExpiresAt
|
||||||
}
|
}
|
||||||
|
|
||||||
// Touch extends the session expiration
|
// Touch extends the session expiration
|
||||||
func (s *Session) Touch() {
|
func (s *Session) Touch() {
|
||||||
s.ExpiresAt = time.Now().Add(DefaultExpiration)
|
s.ExpiresAt = time.Now().Add(DefaultExpiration).Unix()
|
||||||
}
|
}
|
||||||
|
|
||||||
// Set stores a value in the session
|
// Set stores a value in the session
|
||||||
@ -71,6 +71,19 @@ func (s *Session) GetFlash(key string) (any, bool) {
|
|||||||
return value, exists
|
return value, exists
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// RegenerateID creates a new session ID and updates storage
|
||||||
|
func (s *Session) RegenerateID() {
|
||||||
|
oldID := s.ID
|
||||||
|
s.ID = generateID()
|
||||||
|
|
||||||
|
if Manager != nil {
|
||||||
|
Manager.mu.Lock()
|
||||||
|
delete(Manager.sessions, oldID)
|
||||||
|
Manager.sessions[s.ID] = s
|
||||||
|
Manager.mu.Unlock()
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
// generateID creates a random session ID
|
// generateID creates a random session ID
|
||||||
func generateID() string {
|
func generateID() string {
|
||||||
bytes := make([]byte, IDLength)
|
bytes := make([]byte, IDLength)
|
||||||
@ -106,3 +119,8 @@ func Stats() (total, active int) {
|
|||||||
func Close() error {
|
func Close() error {
|
||||||
return Manager.Close()
|
return Manager.Close()
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// RegenerateID regenerates the session ID for security (package-level convenience)
|
||||||
|
func RegenerateID(sess *Session) {
|
||||||
|
sess.RegenerateID()
|
||||||
|
}
|
||||||
|
@ -1,10 +1,11 @@
|
|||||||
package components
|
package components
|
||||||
|
|
||||||
import (
|
import (
|
||||||
|
"dk/internal/auth"
|
||||||
"dk/internal/helpers"
|
"dk/internal/helpers"
|
||||||
"dk/internal/middleware"
|
|
||||||
"dk/internal/models/spells"
|
"dk/internal/models/spells"
|
||||||
"dk/internal/models/towns"
|
"dk/internal/models/towns"
|
||||||
|
"dk/internal/models/users"
|
||||||
"dk/internal/router"
|
"dk/internal/router"
|
||||||
)
|
)
|
||||||
|
|
||||||
@ -13,11 +14,12 @@ import (
|
|||||||
func LeftAside(ctx router.Ctx) map[string]any {
|
func LeftAside(ctx router.Ctx) map[string]any {
|
||||||
data := map[string]any{}
|
data := map[string]any{}
|
||||||
|
|
||||||
user := middleware.GetCurrentUser(ctx)
|
if !auth.IsAuthenticated(ctx) {
|
||||||
if user == nil {
|
|
||||||
return data
|
return data
|
||||||
}
|
}
|
||||||
|
|
||||||
|
user := ctx.UserValue("user").(*users.User)
|
||||||
|
|
||||||
// Build owned town maps list
|
// Build owned town maps list
|
||||||
if user.Towns != "" {
|
if user.Towns != "" {
|
||||||
townMap := helpers.NewOrderedMap[int, string]()
|
townMap := helpers.NewOrderedMap[int, string]()
|
||||||
@ -37,11 +39,12 @@ func LeftAside(ctx router.Ctx) map[string]any {
|
|||||||
func RightAside(ctx router.Ctx) map[string]any {
|
func RightAside(ctx router.Ctx) map[string]any {
|
||||||
data := map[string]any{}
|
data := map[string]any{}
|
||||||
|
|
||||||
user := middleware.GetCurrentUser(ctx)
|
if !auth.IsAuthenticated(ctx) {
|
||||||
if user == nil {
|
|
||||||
return data
|
return data
|
||||||
}
|
}
|
||||||
|
|
||||||
|
user := ctx.UserValue("user").(*users.User)
|
||||||
|
|
||||||
hpPct := helpers.ClampPct(float64(user.HP), float64(user.MaxHP), 0, 100)
|
hpPct := helpers.ClampPct(float64(user.HP), float64(user.MaxHP), 0, 100)
|
||||||
data["hppct"] = hpPct
|
data["hppct"] = hpPct
|
||||||
data["mppct"] = helpers.ClampPct(float64(user.MP), float64(user.MaxMP), 0, 100)
|
data["mppct"] = helpers.ClampPct(float64(user.MP), float64(user.MaxMP), 0, 100)
|
||||||
|
@ -6,6 +6,7 @@ import (
|
|||||||
"runtime"
|
"runtime"
|
||||||
"strings"
|
"strings"
|
||||||
|
|
||||||
|
"dk/internal/auth"
|
||||||
"dk/internal/csrf"
|
"dk/internal/csrf"
|
||||||
"dk/internal/middleware"
|
"dk/internal/middleware"
|
||||||
"dk/internal/router"
|
"dk/internal/router"
|
||||||
@ -28,12 +29,12 @@ func RenderPage(ctx router.Ctx, title, tmplPath string, additionalData map[strin
|
|||||||
|
|
||||||
data := map[string]any{
|
data := map[string]any{
|
||||||
"_title": PageTitle(title),
|
"_title": PageTitle(title),
|
||||||
"authenticated": middleware.IsAuthenticated(ctx),
|
"authenticated": auth.IsAuthenticated(ctx),
|
||||||
"csrf": csrf.HiddenField(ctx),
|
"csrf": csrf.HiddenField(ctx),
|
||||||
"_totaltime": middleware.GetRequestTime(ctx),
|
"_totaltime": middleware.GetRequestTime(ctx),
|
||||||
"_version": "1.0.0",
|
"_version": "1.0.0",
|
||||||
"_build": "dev",
|
"_build": "dev",
|
||||||
"user": middleware.GetCurrentUser(ctx),
|
"user": auth.GetCurrentUser(ctx),
|
||||||
"_memalloc": m.Alloc / 1024 / 1024,
|
"_memalloc": m.Alloc / 1024 / 1024,
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -1,15 +1,15 @@
|
|||||||
package components
|
package components
|
||||||
|
|
||||||
import (
|
import (
|
||||||
|
"dk/internal/auth"
|
||||||
"dk/internal/csrf"
|
"dk/internal/csrf"
|
||||||
"dk/internal/middleware"
|
|
||||||
"dk/internal/router"
|
"dk/internal/router"
|
||||||
"fmt"
|
"fmt"
|
||||||
)
|
)
|
||||||
|
|
||||||
// GenerateTopNav generates the top navigation HTML based on authentication status
|
// GenerateTopNav generates the top navigation HTML based on authentication status
|
||||||
func GenerateTopNav(ctx router.Ctx) string {
|
func GenerateTopNav(ctx router.Ctx) string {
|
||||||
if middleware.IsAuthenticated(ctx) {
|
if auth.IsAuthenticated(ctx) {
|
||||||
return fmt.Sprintf(`<form action="/logout" method="post" class="logout">
|
return fmt.Sprintf(`<form action="/logout" method="post" class="logout">
|
||||||
%s
|
%s
|
||||||
<button class="img-button" type="submit"><img src="/assets/images/button_logout.gif" alt="Log Out" title="Log Out"></button>
|
<button class="img-button" type="submit"><img src="/assets/images/button_logout.gif" alt="Log Out" title="Log Out"></button>
|
||||||
|
12
main.go
12
main.go
@ -9,6 +9,8 @@ import (
|
|||||||
"path/filepath"
|
"path/filepath"
|
||||||
"syscall"
|
"syscall"
|
||||||
|
|
||||||
|
"dk/internal/auth"
|
||||||
|
"dk/internal/csrf"
|
||||||
"dk/internal/middleware"
|
"dk/internal/middleware"
|
||||||
"dk/internal/models/babble"
|
"dk/internal/models/babble"
|
||||||
"dk/internal/models/control"
|
"dk/internal/models/control"
|
||||||
@ -165,16 +167,16 @@ func start(port string) error {
|
|||||||
return fmt.Errorf("failed to load models: %w", err)
|
return fmt.Errorf("failed to load models: %w", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
session.Init("sessions.json")
|
session.Init("data/_sessions.json")
|
||||||
|
|
||||||
r := router.New()
|
r := router.New()
|
||||||
r.Use(middleware.Timing())
|
r.Use(middleware.Timing())
|
||||||
r.Use(middleware.Auth())
|
r.Use(auth.Middleware())
|
||||||
r.Use(middleware.CSRF())
|
r.Use(csrf.Middleware())
|
||||||
|
|
||||||
r.Get("/", routes.Index)
|
r.Get("/", routes.Index)
|
||||||
r.WithMiddleware(middleware.RequireAuth()).Get("/explore", routes.Explore)
|
r.WithMiddleware(auth.RequireAuth()).Get("/explore", routes.Explore)
|
||||||
r.WithMiddleware(middleware.RequireAuth()).Post("/move", routes.Move)
|
r.WithMiddleware(auth.RequireAuth()).Post("/move", routes.Move)
|
||||||
routes.RegisterAuthRoutes(r)
|
routes.RegisterAuthRoutes(r)
|
||||||
routes.RegisterTownRoutes(r)
|
routes.RegisterTownRoutes(r)
|
||||||
|
|
||||||
|
Loading…
x
Reference in New Issue
Block a user