diff --git a/.gitignore b/.gitignore index 4ae2e3c..f2560a6 100644 --- a/.gitignore +++ b/.gitignore @@ -3,3 +3,4 @@ _sessions.json users.json /tmp +wal.log \ No newline at end of file diff --git a/data/control.json b/data/control.json index 853319b..c71d452 100644 --- a/data/control.json +++ b/data/control.json @@ -1,8 +1,11 @@ -{ - "world_size": 200, - "open": 1, - "admin_email": "", - "class_1_name": "Mage", - "class_2_name": "Warrior", - "class_3_name": "Paladin" -} \ No newline at end of file +[ + { + "id": 1, + "world_size": 200, + "open": 1, + "admin_email": "", + "class_1_name": "Mage", + "class_2_name": "Warrior", + "class_3_name": "Paladin" + } +] \ No newline at end of file diff --git a/data/fights.json b/data/fights.json index b322ad9..f077021 100644 --- a/data/fights.json +++ b/data/fights.json @@ -824,165 +824,29 @@ "won": false, "reward_gold": 0, "reward_exp": 0, - "actions": [ - { - "t": 1, - "d": 1 - }, - { - "t": 8, - "d": 1, - "n": "Drakelor" - }, - { - "t": 1, - "d": 1 - }, - { - "t": 8, - "d": 1, - "n": "Drakelor" - }, - { - "t": 1, - "d": 1 - }, - { - "t": 8, - "d": 1, - "n": "Drakelor" - }, - { - "t": 1, - "d": 1 - }, - { - "t": 8, - "d": 1, - "n": "Drakelor" - }, - { - "t": 1, - "d": 1 - }, - { - "t": 8, - "d": 1, - "n": "Drakelor" - }, - { - "t": 1, - "d": 1 - }, - { - "t": 8, - "d": 1, - "n": "Drakelor" - }, - { - "t": 1, - "d": 1 - }, - { - "t": 8, - "d": 1, - "n": "Drakelor" - }, - { - "t": 1, - "d": 1 - }, - { - "t": 8, - "d": 1, - "n": "Drakelor" - }, - { - "t": 1, - "d": 1 - }, - { - "t": 8, - "d": 1, - "n": "Drakelor" - }, - { - "t": 1, - "d": 1 - }, - { - "t": 8, - "d": 1, - "n": "Drakelor" - } - ], - "created": 1755274841, - "updated": 1755275436 + "actions": [], + "created": 1755222893, + "updated": 1755222893 }, { - "id": 14, + "id": 5, "user_id": 1, - "monster_id": 4, - "monster_hp": 0, + "monster_id": 5, + "monster_hp": 10, "monster_max_hp": 10, "monster_sleep": 0, - "monster_immune": 0, + "monster_immune": 1, "uber_damage": 0, "uber_defense": 0, "first_strike": true, - "turn": 5, + "turn": 1, "ran_away": false, - "victory": true, - "won": true, - "reward_gold": 1, - "reward_exp": 3, - "actions": [ - { - "t": 1, - "d": 2 - }, - { - "t": 8, - "d": 1, - "n": "Creature" - }, - { - "t": 1, - "d": 2 - }, - { - "t": 8, - "d": 1, - "n": "Creature" - }, - { - "t": 1, - "d": 2 - }, - { - "t": 8, - "d": 1, - "n": "Creature" - }, - { - "t": 1, - "d": 2 - }, - { - "t": 8, - "d": 1, - "n": "Creature" - }, - { - "t": 1, - "d": 2 - }, - { - "t": 11, - "n": "Creature" - } - ], - "created": 1755275442, - "updated": 1755275447 + "victory": false, + "won": false, + "reward_gold": 0, + "reward_exp": 0, + "actions": [], + "created": 1755608716, + "updated": 1755608716 } ] \ No newline at end of file diff --git a/go.mod b/go.mod index 0d37e7c..fb69a86 100644 --- a/go.mod +++ b/go.mod @@ -1,15 +1,17 @@ module dk -go 1.24.6 +go 1.25.0 require ( - github.com/valyala/fasthttp v1.64.0 - golang.org/x/crypto v0.41.0 + git.sharkk.net/Sharkk/Nigiri v1.0.0 + git.sharkk.net/Sharkk/Sushi v1.1.0 + github.com/valyala/fasthttp v1.65.0 ) require ( github.com/andybalholm/brotli v1.2.0 // indirect github.com/klauspost/compress v1.18.0 // indirect github.com/valyala/bytebufferpool v1.0.0 // indirect + golang.org/x/crypto v0.41.0 // indirect golang.org/x/sys v0.35.0 // indirect ) diff --git a/go.sum b/go.sum index eec0dd9..b49a838 100644 --- a/go.sum +++ b/go.sum @@ -1,11 +1,15 @@ +git.sharkk.net/Sharkk/Nigiri v1.0.0 h1:N0MvWOoX54iXjR8D1LqGIFrtMAPdaoj/32n13Ou/p90= +git.sharkk.net/Sharkk/Nigiri v1.0.0/go.mod h1:HWpMtXaodPXE7dZXQ6tbZNL0DRV9PT65D0DOV0NAwsM= +git.sharkk.net/Sharkk/Sushi v1.1.0 h1:mOcQlcLEl941ozjbOzHOnBAmsOcZ7Q5BkFowILwxNow= +git.sharkk.net/Sharkk/Sushi v1.1.0/go.mod h1:S84ACGkuZ+BKzBO4lb5WQnm5aw9+l7VSO2T1bjzxL3o= github.com/andybalholm/brotli v1.2.0 h1:ukwgCxwYrmACq68yiUqwIWnGY0cTPox/M94sVwToPjQ= github.com/andybalholm/brotli v1.2.0/go.mod h1:rzTDkvFWvIrjDXZHkuS16NPggd91W3kUSvPlQ1pLaKY= github.com/klauspost/compress v1.18.0 h1:c/Cqfb0r+Yi+JtIEq73FWXVkRonBlf0CRNYc8Zttxdo= github.com/klauspost/compress v1.18.0/go.mod h1:2Pp+KzxcywXVXMr50+X0Q/Lsb43OQHYWRCY2AiWywWQ= github.com/valyala/bytebufferpool v1.0.0 h1:GqA5TC/0021Y/b9FG4Oi9Mr3q7XYx6KllzawFIhcdPw= github.com/valyala/bytebufferpool v1.0.0/go.mod h1:6bBcMArwyJ5K/AmCkWv1jt77kVWyCJ6HpOuEn7z0Csc= -github.com/valyala/fasthttp v1.64.0 h1:QBygLLQmiAyiXuRhthf0tuRkqAFcrC42dckN2S+N3og= -github.com/valyala/fasthttp v1.64.0/go.mod h1:dGmFxwkWXSK0NbOSJuF7AMVzU+lkHz0wQVvVITv2UQA= +github.com/valyala/fasthttp v1.65.0 h1:j/u3uzFEGFfRxw79iYzJN+TteTJwbYkru9uDp3d0Yf8= +github.com/valyala/fasthttp v1.65.0/go.mod h1:P/93/YkKPMsKSnATEeELUCkG8a7Y+k99uxNHVbKINr4= github.com/xyproto/randomstring v1.0.5 h1:YtlWPoRdgMu3NZtP45drfy1GKoojuR7hmRcnhZqKjWU= github.com/xyproto/randomstring v1.0.5/go.mod h1:rgmS5DeNXLivK7YprL0pY+lTuhNQW3iGxZ18UQApw/E= golang.org/x/crypto v0.41.0 h1:WKYxWedPGCTVVl5+WHSSrOBT0O8lx32+zxmHxijgXp4= diff --git a/internal/auth/auth.go b/internal/auth/auth.go deleted file mode 100644 index 84d6630..0000000 --- a/internal/auth/auth.go +++ /dev/null @@ -1,171 +0,0 @@ -package auth - -import ( - "dk/internal/cookies" - "dk/internal/helpers" - "dk/internal/models/users" - "dk/internal/router" - "dk/internal/session" - "fmt" - "time" - - "github.com/valyala/fasthttp" -) - -const SessionCookieName = "dk_session" - -func Middleware() router.Middleware { - return func(next router.Handler) router.Handler { - return func(ctx router.Ctx, params []string) { - sessionID := cookies.GetCookie(ctx, SessionCookieName) - var sess *session.Session - - if sessionID != "" { - if existingSess, exists := session.Get(sessionID); exists { - sess = existingSess - sess.Touch() - - if sess.UserID > 0 { // User session - user, err := users.Find(sess.UserID) - if err == nil && user != nil { - ctx.SetUserValue("user", user) - } else { - // User not found, reset to guest session - sess.SetUserID(0) - } - } - session.Store(sess) - setSessionCookie(ctx, sessionID) - } - } - - // Create guest session if none exists - if sess == nil { - sess = session.Create(0) // Guest session - setSessionCookie(ctx, sess.ID) - } - - ctx.SetUserValue("session", sess) - next(ctx, params) - } - } -} - -func RequireAuth(paths ...string) router.Middleware { - redirect := "/login" - if len(paths) > 0 && paths[0] != "" { - redirect = paths[0] - } - - return func(next router.Handler) router.Handler { - return func(ctx router.Ctx, params []string) { - if !IsAuthenticated(ctx) { - ctx.Redirect(redirect, fasthttp.StatusFound) - return - } - - user := ctx.UserValue("user").(*users.User) - user.UpdateLastOnline() - user.Save() - - next(ctx, params) - } - } -} - -func RequireGuest(paths ...string) router.Middleware { - redirect := "/" - if len(paths) > 0 && paths[0] != "" { - redirect = paths[0] - } - - return func(next router.Handler) router.Handler { - return func(ctx router.Ctx, params []string) { - if IsAuthenticated(ctx) { - fmt.Println("RequireGuest: user is authenticated") - ctx.Redirect(redirect, fasthttp.StatusFound) - return - } - next(ctx, params) - } - } -} - -func IsAuthenticated(ctx router.Ctx) bool { - if user, ok := ctx.UserValue("user").(*users.User); ok && user != nil { - return true - } - return false -} - -func GetCurrentUser(ctx router.Ctx) *users.User { - if user, ok := ctx.UserValue("user").(*users.User); ok { - return user - } - return nil -} - -func GetCurrentSession(ctx router.Ctx) *session.Session { - if sess, ok := ctx.UserValue("session").(*session.Session); ok { - return sess - } - return nil -} - -func Login(ctx router.Ctx, user *users.User) { - sess := ctx.UserValue("session").(*session.Session) - - // Update the session to be authenticated - sess.SetUserID(user.ID) // This updates the struct field - sess.RegenerateID() // Generate new ID for security - sess.SetFlash("success", fmt.Sprintf("Welcome back, %s!", user.Username)) - - // Remove any old user_id from session data if it exists - sess.Delete("user_id") - - session.Store(sess) - - // Update context values - ctx.SetUserValue("session", sess) - ctx.SetUserValue("user", user) - - // Update cookie with new session ID - setSessionCookie(ctx, sess.ID) -} - -func Logout(ctx router.Ctx) { - sess := ctx.UserValue("session").(*session.Session) - if sess != nil { - // Convert back to guest session - sess.SetUserID(0) // Reset to guest - sess.RegenerateID() // Generate new ID for security - - // Clean up any user-related session data - sess.Delete("user_id") - - session.Store(sess) - ctx.SetUserValue("session", sess) - - // Update cookie with new session ID - setSessionCookie(ctx, sess.ID) - } - - ctx.SetUserValue("user", nil) -} - -// Helper functions for session cookies -func setSessionCookie(ctx router.Ctx, sessionID string) { - cookies.SetSecureCookie(ctx, cookies.CookieOptions{ - Name: SessionCookieName, - Value: sessionID, - Path: "/", - Expires: time.Now().Add(24 * time.Hour), - HTTPOnly: true, - Secure: helpers.IsHTTPS(ctx), - SameSite: "lax", - }) -} - -func deleteSessionCookie(ctx router.Ctx) { - cookies.DeleteCookie(ctx, SessionCookieName) -} diff --git a/internal/components/asides.go b/internal/components/asides.go index fb2878a..e41905f 100644 --- a/internal/components/asides.go +++ b/internal/components/asides.go @@ -1,21 +1,21 @@ package components import ( - "dk/internal/auth" "dk/internal/helpers" "dk/internal/models/spells" "dk/internal/models/towns" "dk/internal/models/users" - "dk/internal/router" "fmt" + + sushi "git.sharkk.net/Sharkk/Sushi" ) // LeftAside generates the data map for the left sidebar. // Returns an empty map when not auth'd. -func LeftAside(ctx router.Ctx) map[string]any { +func LeftAside(ctx sushi.Ctx) map[string]any { data := map[string]any{} - if !auth.IsAuthenticated(ctx) { + if !ctx.IsAuthenticated() { return data } @@ -37,10 +37,10 @@ func LeftAside(ctx router.Ctx) map[string]any { // RightAside generates the data map for the right sidebar. // Returns an empty map when not auth'd. -func RightAside(ctx router.Ctx) map[string]any { +func RightAside(ctx sushi.Ctx) map[string]any { data := map[string]any{} - if !auth.IsAuthenticated(ctx) { + if !ctx.IsAuthenticated() { return data } diff --git a/internal/components/page.go b/internal/components/page.go index 6620c66..dd58458 100644 --- a/internal/components/page.go +++ b/internal/components/page.go @@ -6,16 +6,14 @@ import ( "runtime" "strings" - "dk/internal/auth" - "dk/internal/csrf" - "dk/internal/middleware" - "dk/internal/router" - "dk/internal/session" "dk/internal/template" + + sushi "git.sharkk.net/Sharkk/Sushi" + "git.sharkk.net/Sharkk/Sushi/csrf" ) // RenderPage renders a page using the layout template with common data and additional custom data -func RenderPage(ctx router.Ctx, title, tmplPath string, additionalData map[string]any) error { +func RenderPage(ctx sushi.Ctx, title, tmplPath string, additionalData map[string]any) error { if template.Cache == nil { return fmt.Errorf("template.Cache not initialized") } @@ -25,19 +23,19 @@ func RenderPage(ctx router.Ctx, title, tmplPath string, additionalData map[strin return fmt.Errorf("failed to load layout template: %w", err) } - sess := ctx.UserValue("session").(*session.Session) - var m runtime.MemStats runtime.ReadMemStats(&m) + sess := ctx.GetCurrentSession() + data := map[string]any{ "_title": PageTitle(title), - "authenticated": auth.IsAuthenticated(ctx), + "authenticated": ctx.IsAuthenticated(), "csrf": csrf.HiddenField(ctx), - "_totaltime": middleware.GetRequestTime(ctx), + "_totaltime": ctx.UserValue("request_time"), "_version": "1.0.0", "_build": "dev", - "user": auth.GetCurrentUser(ctx), + "user": ctx.GetCurrentUser(), "_memalloc": m.Alloc / 1024 / 1024, "_errormsg": sess.GetFlashMessage("error"), "_successmsg": sess.GetFlashMessage("success"), @@ -47,8 +45,7 @@ func RenderPage(ctx router.Ctx, title, tmplPath string, additionalData map[strin maps.Copy(data, RightAside(ctx)) maps.Copy(data, additionalData) - tmpl.WriteTo(ctx, data) - return nil + return tmpl.WriteTo(ctx, data) } // PageTitle returns a proper title for a rendered page. If an empty string diff --git a/internal/cookies/cookies.go b/internal/cookies/cookies.go deleted file mode 100644 index d3866b4..0000000 --- a/internal/cookies/cookies.go +++ /dev/null @@ -1,77 +0,0 @@ -package cookies - -import ( - "time" - - "github.com/valyala/fasthttp" -) - -type CookieOptions struct { - Name string - Value string - Path string - Domain string - Expires time.Time - MaxAge int - Secure bool - HTTPOnly bool - SameSite string -} - -func SetSecureCookie(ctx *fasthttp.RequestCtx, opts CookieOptions) { - cookie := &fasthttp.Cookie{} - - cookie.SetKey(opts.Name) - cookie.SetValue(opts.Value) - - if opts.Path != "" { - cookie.SetPath(opts.Path) - } else { - cookie.SetPath("/") - } - - if opts.Domain != "" { - cookie.SetDomain(opts.Domain) - } - - if !opts.Expires.IsZero() { - cookie.SetExpire(opts.Expires) - } - - if opts.MaxAge > 0 { - cookie.SetMaxAge(opts.MaxAge) - } - - cookie.SetSecure(opts.Secure) - cookie.SetHTTPOnly(opts.HTTPOnly) - - switch opts.SameSite { - case "strict": - cookie.SetSameSite(fasthttp.CookieSameSiteStrictMode) - case "lax": - cookie.SetSameSite(fasthttp.CookieSameSiteLaxMode) - case "none": - cookie.SetSameSite(fasthttp.CookieSameSiteNoneMode) - default: - cookie.SetSameSite(fasthttp.CookieSameSiteLaxMode) - } - - ctx.Response.Header.SetCookie(cookie) -} - -func GetCookie(ctx *fasthttp.RequestCtx, name string) string { - return string(ctx.Request.Header.Cookie(name)) -} - -func DeleteCookie(ctx *fasthttp.RequestCtx, name string) { - SetSecureCookie(ctx, CookieOptions{ - Name: name, - Value: "", - Path: "/", - Expires: time.Unix(0, 0), - MaxAge: -1, - HTTPOnly: true, - Secure: true, - SameSite: "lax", - }) -} diff --git a/internal/csrf/csrf.go b/internal/csrf/csrf.go deleted file mode 100644 index df21267..0000000 --- a/internal/csrf/csrf.go +++ /dev/null @@ -1,195 +0,0 @@ -// Package csrf provides Cross-Site Request Forgery (CSRF) protection -// with session-based token storage and form helpers. -// -// # Basic Usage -// -// // Generate token and store in session -// token := csrf.GenerateToken(ctx) -// -// // In templates - generate hidden input field -// hiddenField := csrf.HiddenField(ctx) -// -// // Verify form submission -// if !csrf.ValidateToken(ctx, formToken) { -// // Handle CSRF validation failure -// } -// -// # Middleware Integration -// -// // Add CSRF middleware to protected routes -// r.Use(csrf.Middleware()) -package csrf - -import ( - "crypto/rand" - "crypto/subtle" - "encoding/base64" - "fmt" - - "dk/internal/router" - "dk/internal/session" -) - -const ( - TokenLength = 32 - TokenFieldName = "_csrf_token" - SessionKey = "csrf_token" - SessionCtxKey = "session" -) - -// GetCurrentSession retrieves the session from context -func GetCurrentSession(ctx router.Ctx) *session.Session { - if sess, ok := ctx.UserValue(SessionCtxKey).(*session.Session); ok { - return sess - } - return nil -} - -// GenerateToken creates a new CSRF token and stores it in the session -func GenerateToken(ctx router.Ctx) string { - // Generate cryptographically secure random bytes - tokenBytes := make([]byte, TokenLength) - if _, err := rand.Read(tokenBytes); err != nil { - // Fallback - this should never happen in practice - return "" - } - - token := base64.URLEncoding.EncodeToString(tokenBytes) - - // Store token in session (both guests and authenticated users have sessions) - if sess := GetCurrentSession(ctx); sess != nil { - StoreToken(sess, token) - session.Store(sess) - } - - return token -} - -// GetToken retrieves the current CSRF token from session, generating one if needed -func GetToken(ctx router.Ctx) string { - sess := GetCurrentSession(ctx) - if sess == nil { - return "" // No session available - } - - // Check for existing token - if existingToken := GetStoredToken(sess); existingToken != "" { - return existingToken - } - - // Generate new token if none exists - return GenerateToken(ctx) -} - -// ValidateToken verifies a CSRF token against the stored session token -func ValidateToken(ctx router.Ctx, submittedToken string) bool { - if submittedToken == "" { - return false - } - - sess := GetCurrentSession(ctx) - if sess == nil { - return false // No session - } - - storedToken := GetStoredToken(sess) - if storedToken == "" { - return false // No stored token - } - - // Use constant-time comparison to prevent timing attacks - return subtle.ConstantTimeCompare([]byte(submittedToken), []byte(storedToken)) == 1 -} - -// StoreToken saves a CSRF token in the session -func StoreToken(sess *session.Session, token string) { - sess.Set(SessionKey, token) -} - -// GetStoredToken retrieves the CSRF token from session -func GetStoredToken(sess *session.Session) string { - if token, ok := sess.Get(SessionKey); ok { - if tokenStr, ok := token.(string); ok { - return tokenStr - } - } - return "" -} - -// RotateToken generates a new token and replaces the old one in the session -func RotateToken(ctx router.Ctx) string { - sess := GetCurrentSession(ctx) - if sess == nil { - return "" - } - - // Generate new token (this will automatically store it) - newToken := GenerateToken(ctx) - return newToken -} - -// HiddenField generates an HTML hidden input field with the CSRF token -func HiddenField(ctx router.Ctx) string { - token := GetToken(ctx) - if token == "" { - return "" // No token available - } - - return fmt.Sprintf(``, - TokenFieldName, token) -} - -// TokenMeta generates HTML meta tag for JavaScript access to CSRF token -func TokenMeta(ctx router.Ctx) string { - token := GetToken(ctx) - if token == "" { - return "" - } - - return fmt.Sprintf(``, token) -} - -// ValidateFormToken is a convenience function to validate CSRF token from form data -func ValidateFormToken(ctx router.Ctx) bool { - // Try to get token from form data - tokenBytes := ctx.PostArgs().Peek(TokenFieldName) - if len(tokenBytes) == 0 { - // Try from query parameters as fallback - tokenBytes = ctx.QueryArgs().Peek(TokenFieldName) - } - - if len(tokenBytes) == 0 { - return false - } - - return ValidateToken(ctx, string(tokenBytes)) -} - -// GetTokenFromCookie retrieves a CSRF token from cookie (legacy support) -func GetTokenFromCookie(ctx router.Ctx) string { - return string(ctx.Request.Header.Cookie("_csrf")) -} - -// Middleware returns a middleware function that automatically validates CSRF tokens -// for state-changing HTTP methods (POST, PUT, PATCH, DELETE) -func Middleware() router.Middleware { - return func(next router.Handler) router.Handler { - return func(ctx router.Ctx, params []string) { - method := string(ctx.Method()) - - // Only validate CSRF for state-changing methods - if method == "POST" || method == "PUT" || method == "PATCH" || method == "DELETE" { - if !ValidateFormToken(ctx) { - fmt.Println("Failed CSRF validation.") - RotateToken(ctx) - currentPath := string(ctx.Path()) - ctx.Redirect(currentPath, 302) - return - } - } - - // Continue to next handler - next(ctx, params) - } - } -} diff --git a/internal/middleware/fights.go b/internal/middleware/fights.go deleted file mode 100644 index be4db95..0000000 --- a/internal/middleware/fights.go +++ /dev/null @@ -1,59 +0,0 @@ -package middleware - -import ( - "dk/internal/models/users" - "dk/internal/router" - "strings" - - "github.com/valyala/fasthttp" -) - -// RequireFighting ensures the user is in a fight when accessing fight routes -func RequireFighting() router.Middleware { - return func(next router.Handler) router.Handler { - return func(ctx router.Ctx, params []string) { - user, ok := ctx.UserValue("user").(*users.User) - if !ok || user == nil { - ctx.SetStatusCode(fasthttp.StatusUnauthorized) - ctx.SetBodyString("Not authenticated") - return - } - - if !user.IsFighting() { - ctx.Redirect("/", 303) - return - } - - next(ctx, params) - } - } -} - -// HandleFightRedirect redirects users to appropriate locations based on fight status -func HandleFightRedirect() router.Middleware { - return func(next router.Handler) router.Handler { - return func(ctx router.Ctx, params []string) { - user, ok := ctx.UserValue("user").(*users.User) - if !ok || user == nil { - next(ctx, params) - return - } - - currentPath := string(ctx.URI().Path()) - - // If user is fighting and not on fight page, redirect to fight - if user.IsFighting() && !strings.HasPrefix(currentPath, "/fight") { - ctx.Redirect("/fight", 303) - return - } - - // If user is not fighting and on fight page, redirect to home - if !user.IsFighting() && strings.HasPrefix(currentPath, "/fight") { - ctx.Redirect("/", 303) - return - } - - next(ctx, params) - } - } -} diff --git a/internal/middleware/timing.go b/internal/middleware/timing.go deleted file mode 100644 index 92ea972..0000000 --- a/internal/middleware/timing.go +++ /dev/null @@ -1,49 +0,0 @@ -package middleware - -import ( - "fmt" - "time" - - "dk/internal/router" -) - -const RequestTimerKey = "request_start_time" - -// Timing adds request timing functionality -func Timing() router.Middleware { - return func(next router.Handler) router.Handler { - return func(ctx router.Ctx, params []string) { - startTime := time.Now() - ctx.SetUserValue(RequestTimerKey, startTime) - - next(ctx, params) - } - } -} - -// GetRequestTime returns the total request processing time in seconds (formatted) -func GetRequestTime(ctx router.Ctx) string { - startTime, ok := ctx.UserValue(RequestTimerKey).(time.Time) - if !ok { - return "0" - } - - duration := time.Since(startTime) - seconds := duration.Seconds() - - if seconds < 0.001 { - return "0" - } - - return fmt.Sprintf("%.3f", seconds) -} - -// GetRequestDuration returns the raw duration -func GetRequestDuration(ctx router.Ctx) time.Duration { - startTime, ok := ctx.UserValue(RequestTimerKey).(time.Time) - if !ok { - return 0 - } - - return time.Since(startTime) -} \ No newline at end of file diff --git a/internal/middleware/town.go b/internal/middleware/town.go deleted file mode 100644 index 8a3bb26..0000000 --- a/internal/middleware/town.go +++ /dev/null @@ -1,39 +0,0 @@ -package middleware - -import ( - "dk/internal/models/towns" - "dk/internal/models/users" - "dk/internal/router" - - "github.com/valyala/fasthttp" -) - -// RequireTown ensures the user is in town at valid coordinates -func RequireTown() router.Middleware { - return func(next router.Handler) router.Handler { - return func(ctx router.Ctx, params []string) { - user, ok := ctx.UserValue("user").(*users.User) - if !ok || user == nil { - ctx.SetStatusCode(fasthttp.StatusUnauthorized) - ctx.SetBodyString("Not authenticated") - return - } - - if user.Currently != "In Town" { - ctx.SetStatusCode(fasthttp.StatusForbidden) - ctx.SetBodyString("You must be in town") - return - } - - town, err := towns.ByCoords(user.X, user.Y) - if err != nil || town == nil || town.ID == 0 { - ctx.SetStatusCode(fasthttp.StatusForbidden) - ctx.SetBodyString("Invalid town location") - return - } - - ctx.SetUserValue("town", town) - next(ctx, params) - } - } -} diff --git a/internal/models/babble/babble.go b/internal/models/babble/babble.go index e6be6f4..c29e418 100644 --- a/internal/models/babble/babble.go +++ b/internal/models/babble/babble.go @@ -1,39 +1,22 @@ package babble import ( - "dk/internal/store" "fmt" "sort" "strings" "time" + + nigiri "git.sharkk.net/Sharkk/Nigiri" ) // Babble represents a global chat message in the game type Babble struct { ID int `json:"id"` Posted int64 `json:"posted"` - Author string `json:"author"` + Author string `json:"author" db:"index"` Babble string `json:"babble"` } -func (b *Babble) Save() error { - return GetStore().UpdateWithRebuild(b.ID, b) -} - -func (b *Babble) Delete() error { - GetStore().RemoveWithRebuild(b.ID) - return nil -} - -// Creates a new Babble with sensible defaults -func New() *Babble { - return &Babble{ - Posted: time.Now().Unix(), - Author: "", - Babble: "", - } -} - // Validate checks if babble has valid values func (b *Babble) Validate() error { if b.Posted <= 0 { @@ -48,58 +31,78 @@ func (b *Babble) Validate() error { return nil } -// BabbleStore with enhanced BaseStore -type BabbleStore struct { - *store.BaseStore[Babble] -} - // Global store with singleton pattern -var GetStore = store.NewSingleton(func() *BabbleStore { - bs := &BabbleStore{BaseStore: store.NewBaseStore[Babble]()} +var store *nigiri.BaseStore[Babble] +var db *nigiri.Collection - // Register indices - bs.RegisterIndex("byAuthor", store.BuildStringGroupIndex(func(b *Babble) string { +// Init sets up the Nigiri store and indices +func Init(collection *nigiri.Collection) { + db = collection + store = nigiri.NewBaseStore[Babble]() + + // Register custom indices + store.RegisterIndex("byAuthor", nigiri.BuildStringGroupIndex(func(b *Babble) string { return strings.ToLower(b.Author) })) - bs.RegisterIndex("allByPosted", store.BuildSortedListIndex(func(a, b *Babble) bool { + store.RegisterIndex("allByPosted", nigiri.BuildSortedListIndex(func(a, b *Babble) bool { if a.Posted != b.Posted { return a.Posted > b.Posted // DESC } return a.ID > b.ID // DESC })) - return bs -}) - -// Enhanced CRUD operations -func (bs *BabbleStore) AddBabble(babble *Babble) error { - return bs.AddWithRebuild(babble.ID, babble) + store.RebuildIndices() } -func (bs *BabbleStore) RemoveBabble(id int) { - bs.RemoveWithRebuild(id) +// GetStore returns the babble store +func GetStore() *nigiri.BaseStore[Babble] { + if store == nil { + panic("babble store not initialized - call Initialize first") + } + return store } -func (bs *BabbleStore) UpdateBabble(babble *Babble) error { - return bs.UpdateWithRebuild(babble.ID, babble) +// Creates a new Babble with sensible defaults +func New() *Babble { + return &Babble{ + Posted: time.Now().Unix(), + Author: "", + Babble: "", + } } -// Data persistence -func LoadData(dataPath string) error { - bs := GetStore() - return bs.BaseStore.LoadData(dataPath) +// CRUD operations +func (b *Babble) Save() error { + if b.ID == 0 { + id, err := store.Create(b) + if err != nil { + return err + } + b.ID = id + return nil + } + return store.Update(b.ID, b) } -func SaveData(dataPath string) error { - bs := GetStore() - return bs.BaseStore.SaveData(dataPath) +func (b *Babble) Delete() error { + store.Remove(b.ID) + return nil } -// Query functions using enhanced store +// Insert with ID assignment +func (b *Babble) Insert() error { + id, err := store.Create(b) + if err != nil { + return err + } + b.ID = id + return nil +} + +// Query functions func Find(id int) (*Babble, error) { - bs := GetStore() - babble, exists := bs.Find(id) + babble, exists := store.Find(id) if !exists { return nil, fmt.Errorf("babble with ID %d not found", id) } @@ -107,13 +110,11 @@ func Find(id int) (*Babble, error) { } func All() ([]*Babble, error) { - bs := GetStore() - return bs.AllSorted("allByPosted"), nil + return store.AllSorted("allByPosted"), nil } func ByAuthor(author string) ([]*Babble, error) { - bs := GetStore() - messages := bs.GroupByIndex("byAuthor", strings.ToLower(author)) + messages := store.GroupByIndex("byAuthor", strings.ToLower(author)) // Sort by posted DESC, then ID DESC sort.Slice(messages, func(i, j int) bool { @@ -127,8 +128,7 @@ func ByAuthor(author string) ([]*Babble, error) { } func Recent(limit int) ([]*Babble, error) { - bs := GetStore() - all := bs.AllSorted("allByPosted") + all := store.AllSorted("allByPosted") if limit > len(all) { limit = len(all) } @@ -136,23 +136,20 @@ func Recent(limit int) ([]*Babble, error) { } func Since(since int64) ([]*Babble, error) { - bs := GetStore() - return bs.FilterByIndex("allByPosted", func(b *Babble) bool { + return store.FilterByIndex("allByPosted", func(b *Babble) bool { return b.Posted >= since }), nil } func Between(start, end int64) ([]*Babble, error) { - bs := GetStore() - return bs.FilterByIndex("allByPosted", func(b *Babble) bool { + return store.FilterByIndex("allByPosted", func(b *Babble) bool { return b.Posted >= start && b.Posted <= end }), nil } func Search(term string) ([]*Babble, error) { - bs := GetStore() lowerTerm := strings.ToLower(term) - return bs.FilterByIndex("allByPosted", func(b *Babble) bool { + return store.FilterByIndex("allByPosted", func(b *Babble) bool { return strings.Contains(strings.ToLower(b.Babble), lowerTerm) }), nil } @@ -168,15 +165,6 @@ func RecentByAuthor(author string, limit int) ([]*Babble, error) { return messages[:limit], nil } -// Insert with ID assignment -func (b *Babble) Insert() error { - bs := GetStore() - if b.ID == 0 { - b.ID = bs.GetNextID() - } - return bs.AddBabble(b) -} - // Helper methods func (b *Babble) PostedTime() time.Time { return time.Unix(b.Posted, 0) @@ -279,3 +267,14 @@ func (b *Babble) HasMention(username string) bool { } return false } + +// Legacy compatibility functions (will be removed later) +func LoadData(dataPath string) error { + // No longer needed - Nigiri handles this + return nil +} + +func SaveData(dataPath string) error { + // No longer needed - Nigiri handles this + return nil +} diff --git a/internal/models/control/control.go b/internal/models/control/control.go index b8abea1..4dd1afb 100644 --- a/internal/models/control/control.go +++ b/internal/models/control/control.go @@ -1,24 +1,22 @@ package control import ( - "encoding/json" "fmt" - "os" "sync" + + nigiri "git.sharkk.net/Sharkk/Nigiri" ) var ( - global *Control - configPath string - mu sync.RWMutex + store *nigiri.BaseStore[Control] + db *nigiri.Collection + global *Control + mu sync.RWMutex ) -func init() { - global = New() -} - // Control represents the game control settings type Control struct { + ID int `json:"id"` WorldSize int `json:"world_size"` Open int `json:"open"` AdminEmail string `json:"admin_email"` @@ -27,9 +25,46 @@ type Control struct { Class3Name string `json:"class_3_name"` } +// Init sets up the Nigiri store for control settings +func Init(collection *nigiri.Collection) { + db = collection + store = nigiri.NewBaseStore[Control]() + + // Load or create the singleton control instance + all := store.GetAll() + if len(all) == 0 { + // Create default control settings + global = New() + global.ID = 1 + store.Add(1, global) + } else { + // Use the first (and only) control entry + for _, ctrl := range all { + global = ctrl + break + } + // Apply defaults for any missing fields + defaults := New() + if global.WorldSize == 0 { + global.WorldSize = defaults.WorldSize + } + if global.Class1Name == "" { + global.Class1Name = defaults.Class1Name + } + if global.Class2Name == "" { + global.Class2Name = defaults.Class2Name + } + if global.Class3Name == "" { + global.Class3Name = defaults.Class3Name + } + store.Update(global.ID, global) + } +} + // New creates a new Control with sensible defaults func New() *Control { return &Control{ + ID: 1, // Singleton WorldSize: 200, Open: 1, AdminEmail: "", @@ -39,84 +74,56 @@ func New() *Control { } } -// Load loads control settings from JSON file into global instance -func Load(filename string) error { - mu.Lock() - configPath = filename - mu.Unlock() - - data, err := os.ReadFile(filename) - if err != nil { - if os.IsNotExist(err) { - return nil // Keep defaults - } - return fmt.Errorf("failed to read config file: %w", err) - } - - if len(data) == 0 { - return nil - } - - control := &Control{} - if err := json.Unmarshal(data, control); err != nil { - return fmt.Errorf("failed to parse config: %w", err) - } - - // Apply defaults for any missing fields - defaults := New() - if control.WorldSize == 0 { - control.WorldSize = defaults.WorldSize - } - if control.Class1Name == "" { - control.Class1Name = defaults.Class1Name - } - if control.Class2Name == "" { - control.Class2Name = defaults.Class2Name - } - if control.Class3Name == "" { - control.Class3Name = defaults.Class3Name - } - - mu.Lock() - global = control - mu.Unlock() - return nil -} - -// Save saves global control settings to the loaded path -func Save() error { - mu.RLock() - defer mu.RUnlock() - - if configPath == "" { - return fmt.Errorf("no config path set, call Load() first") - } - - data, err := json.MarshalIndent(global, "", "\t") - if err != nil { - return fmt.Errorf("failed to marshal config: %w", err) - } - - if err := os.WriteFile(configPath, data, 0644); err != nil { - return fmt.Errorf("failed to write config file: %w", err) - } - - return nil -} - // Get returns the global control instance (thread-safe) func Get() *Control { mu.RLock() defer mu.RUnlock() + if global == nil { + panic("control not initialized - call Initialize first") + } return global } // Set updates the global control instance (thread-safe) -func Set(control *Control) { +func Set(control *Control) error { mu.Lock() defer mu.Unlock() + + control.ID = 1 // Ensure it's always ID 1 (singleton) + if err := control.Validate(); err != nil { + return err + } + + if err := store.Update(1, control); err != nil { + return err + } + global = control + return nil } + +// Update updates specific fields of the control settings +func Update(updater func(*Control)) error { + mu.Lock() + defer mu.Unlock() + + // Create a copy to work with + updated := *global + updater(&updated) + + if err := updated.Validate(); err != nil { + return err + } + + if err := store.Update(1, &updated); err != nil { + return err + } + + global = &updated + return nil +} + +// Validate checks if control settings have valid values func (c *Control) Validate() error { if c.WorldSize <= 0 || c.WorldSize > 10000 { return fmt.Errorf("WorldSize must be between 1 and 10000") @@ -124,6 +131,15 @@ func (c *Control) Validate() error { if c.Open != 0 && c.Open != 1 { return fmt.Errorf("Open must be 0 or 1") } + if c.Class1Name == "" { + return fmt.Errorf("Class1Name cannot be empty") + } + if c.Class2Name == "" { + return fmt.Errorf("Class2Name cannot be empty") + } + if c.Class3Name == "" { + return fmt.Errorf("Class3Name cannot be empty") + } return nil } @@ -132,6 +148,17 @@ func (c *Control) IsOpen() bool { return c.Open == 1 } +// SetOpen sets whether the game world is open for new players +func SetOpen(open bool) error { + return Update(func(c *Control) { + if open { + c.Open = 1 + } else { + c.Open = 0 + } + }) +} + // GetClassNames returns all class names as a slice func (c *Control) GetClassNames() []string { classes := make([]string, 0, 3) @@ -209,3 +236,14 @@ func (c *Control) GetWorldBounds() (minX, minY, maxX, maxY int) { radius := c.GetWorldRadius() return -radius, -radius, radius, radius } + +// Legacy compatibility functions (will be removed later) +func Load(filename string) error { + // No longer needed - Nigiri handles this + return nil +} + +func Save() error { + // No longer needed - Nigiri handles this + return nil +} diff --git a/internal/models/drops/drops.go b/internal/models/drops/drops.go index 4897815..16d6ec7 100644 --- a/internal/models/drops/drops.go +++ b/internal/models/drops/drops.go @@ -1,26 +1,56 @@ package drops import ( - "dk/internal/store" "fmt" + + nigiri "git.sharkk.net/Sharkk/Nigiri" ) // Drop represents a drop item in the game type Drop struct { ID int `json:"id"` - Name string `json:"name"` - Level int `json:"level"` - Type int `json:"type"` + Name string `json:"name" db:"required"` + Level int `json:"level" db:"index"` + Type int `json:"type" db:"index"` Att string `json:"att"` } -func (d *Drop) Save() error { - return GetStore().UpdateWithRebuild(d.ID, d) +// DropType constants for drop types +const ( + TypeConsumable = 1 +) + +// Global store +var store *nigiri.BaseStore[Drop] +var db *nigiri.Collection + +// Init sets up the Nigiri store and indices +func Init(collection *nigiri.Collection) { + db = collection + store = nigiri.NewBaseStore[Drop]() + + // Register custom indices + store.RegisterIndex("byLevel", nigiri.BuildIntGroupIndex(func(d *Drop) int { + return d.Level + })) + + store.RegisterIndex("byType", nigiri.BuildIntGroupIndex(func(d *Drop) int { + return d.Type + })) + + store.RegisterIndex("allByID", nigiri.BuildSortedListIndex(func(a, b *Drop) bool { + return a.ID < b.ID + })) + + store.RebuildIndices() } -func (d *Drop) Delete() error { - GetStore().RemoveWithRebuild(d.ID) - return nil +// GetStore returns the drops store +func GetStore() *nigiri.BaseStore[Drop] { + if store == nil { + panic("drops store not initialized - call Initialize first") + } + return store } // Creates a new Drop with sensible defaults @@ -47,64 +77,37 @@ func (d *Drop) Validate() error { return nil } -// DropType constants for drop types -const ( - TypeConsumable = 1 -) - -// DropStore with enhanced BaseStore -type DropStore struct { - *store.BaseStore[Drop] +// CRUD operations +func (d *Drop) Save() error { + if d.ID == 0 { + id, err := store.Create(d) + if err != nil { + return err + } + d.ID = id + return nil + } + return store.Update(d.ID, d) } -// Global store with singleton pattern -var GetStore = store.NewSingleton(func() *DropStore { - ds := &DropStore{BaseStore: store.NewBaseStore[Drop]()} - - // Register indices - ds.RegisterIndex("byLevel", store.BuildIntGroupIndex(func(d *Drop) int { - return d.Level - })) - - ds.RegisterIndex("byType", store.BuildIntGroupIndex(func(d *Drop) int { - return d.Type - })) - - ds.RegisterIndex("allByID", store.BuildSortedListIndex(func(a, b *Drop) bool { - return a.ID < b.ID - })) - - return ds -}) - -// Enhanced CRUD operations -func (ds *DropStore) AddDrop(drop *Drop) error { - return ds.AddWithRebuild(drop.ID, drop) +func (d *Drop) Delete() error { + store.Remove(d.ID) + return nil } -func (ds *DropStore) RemoveDrop(id int) { - ds.RemoveWithRebuild(id) +// Insert with ID assignment +func (d *Drop) Insert() error { + id, err := store.Create(d) + if err != nil { + return err + } + d.ID = id + return nil } -func (ds *DropStore) UpdateDrop(drop *Drop) error { - return ds.UpdateWithRebuild(drop.ID, drop) -} - -// Data persistence -func LoadData(dataPath string) error { - ds := GetStore() - return ds.BaseStore.LoadData(dataPath) -} - -func SaveData(dataPath string) error { - ds := GetStore() - return ds.BaseStore.SaveData(dataPath) -} - -// Query functions using enhanced store +// Query functions func Find(id int) (*Drop, error) { - ds := GetStore() - drop, exists := ds.Find(id) + drop, exists := store.Find(id) if !exists { return nil, fmt.Errorf("drop with ID %d not found", id) } @@ -112,29 +115,17 @@ func Find(id int) (*Drop, error) { } func All() ([]*Drop, error) { - ds := GetStore() - return ds.AllSorted("allByID"), nil + return store.AllSorted("allByID"), nil } func ByLevel(minLevel int) ([]*Drop, error) { - ds := GetStore() - return ds.FilterByIndex("allByID", func(d *Drop) bool { + return store.FilterByIndex("allByID", func(d *Drop) bool { return d.Level <= minLevel }), nil } func ByType(dropType int) ([]*Drop, error) { - ds := GetStore() - return ds.GroupByIndex("byType", dropType), nil -} - -// Insert with ID assignment -func (d *Drop) Insert() error { - ds := GetStore() - if d.ID == 0 { - d.ID = ds.GetNextID() - } - return ds.AddDrop(d) + return store.GroupByIndex("byType", dropType), nil } // Helper methods @@ -150,3 +141,14 @@ func (d *Drop) TypeName() string { return "Unknown" } } + +// Legacy compatibility functions (will be removed later) +func LoadData(dataPath string) error { + // No longer needed - Nigiri handles this + return nil +} + +func SaveData(dataPath string) error { + // No longer needed - Nigiri handles this + return nil +} diff --git a/internal/models/fights/fights.go b/internal/models/fights/fights.go index 106ee5a..013b950 100644 --- a/internal/models/fights/fights.go +++ b/internal/models/fights/fights.go @@ -1,16 +1,17 @@ package fights import ( - "dk/internal/store" "fmt" "time" + + nigiri "git.sharkk.net/Sharkk/Nigiri" ) // Fight represents a fight, past or present type Fight struct { ID int `json:"id"` - UserID int `json:"user_id"` - MonsterID int `json:"monster_id"` + UserID int `json:"user_id" db:"index"` + MonsterID int `json:"monster_id" db:"index"` MonsterHP int `json:"monster_hp"` MonsterMaxHP int `json:"monster_max_hp"` MonsterSleep int `json:"monster_sleep"` @@ -29,16 +30,59 @@ type Fight struct { Updated int64 `json:"updated"` } -func (f *Fight) Save() error { - f.Updated = time.Now().Unix() - return GetStore().UpdateWithRebuild(f.ID, f) +// Global store +var store *nigiri.BaseStore[Fight] +var db *nigiri.Collection + +// Init sets up the Nigiri store and indices +func Init(collection *nigiri.Collection) { + db = collection + store = nigiri.NewBaseStore[Fight]() + + // Register custom indices + store.RegisterIndex("byUserID", nigiri.BuildIntGroupIndex(func(f *Fight) int { + return f.UserID + })) + + store.RegisterIndex("byMonsterID", nigiri.BuildIntGroupIndex(func(f *Fight) int { + return f.MonsterID + })) + + store.RegisterIndex("activeFights", nigiri.BuildFilteredIntGroupIndex( + func(f *Fight) bool { + return !f.RanAway && !f.Victory + }, + func(f *Fight) int { + return f.UserID + }, + )) + + store.RegisterIndex("allByCreated", nigiri.BuildSortedListIndex(func(a, b *Fight) bool { + if a.Created != b.Created { + return a.Created > b.Created // DESC + } + return a.ID > b.ID // DESC + })) + + store.RegisterIndex("allByUpdated", nigiri.BuildSortedListIndex(func(a, b *Fight) bool { + if a.Updated != b.Updated { + return a.Updated > b.Updated // DESC + } + return a.ID > b.ID // DESC + })) + + store.RebuildIndices() } -func (f *Fight) Delete() error { - GetStore().RemoveWithRebuild(f.ID) - return nil +// GetStore returns the fights store +func GetStore() *nigiri.BaseStore[Fight] { + if store == nil { + panic("fights store not initialized - call Initialize first") + } + return store } +// New creates a new Fight with sensible defaults func New(userID, monsterID int) *Fight { now := time.Now().Unix() return &Fight{ @@ -86,78 +130,39 @@ func (f *Fight) Validate() error { return nil } -// FightStore with enhanced BaseStore -type FightStore struct { - *store.BaseStore[Fight] -} - -// Global store with singleton pattern -var GetStore = store.NewSingleton(func() *FightStore { - fs := &FightStore{BaseStore: store.NewBaseStore[Fight]()} - - // Register indices - fs.RegisterIndex("byUserID", store.BuildIntGroupIndex(func(f *Fight) int { - return f.UserID - })) - - fs.RegisterIndex("byMonsterID", store.BuildIntGroupIndex(func(f *Fight) int { - return f.MonsterID - })) - - fs.RegisterIndex("activeFights", store.BuildFilteredIntGroupIndex( - func(f *Fight) bool { - return !f.RanAway && !f.Victory - }, - func(f *Fight) int { - return f.UserID - }, - )) - - fs.RegisterIndex("allByCreated", store.BuildSortedListIndex(func(a, b *Fight) bool { - if a.Created != b.Created { - return a.Created > b.Created // DESC +// CRUD operations +func (f *Fight) Save() error { + f.Updated = time.Now().Unix() + if f.ID == 0 { + id, err := store.Create(f) + if err != nil { + return err } - return a.ID > b.ID // DESC - })) - - fs.RegisterIndex("allByUpdated", store.BuildSortedListIndex(func(a, b *Fight) bool { - if a.Updated != b.Updated { - return a.Updated > b.Updated // DESC - } - return a.ID > b.ID // DESC - })) - - return fs -}) - -// Enhanced CRUD operations -func (fs *FightStore) AddFight(fight *Fight) error { - return fs.AddWithRebuild(fight.ID, fight) + f.ID = id + return nil + } + return store.Update(f.ID, f) } -func (fs *FightStore) RemoveFight(id int) { - fs.RemoveWithRebuild(id) +func (f *Fight) Delete() error { + store.Remove(f.ID) + return nil } -func (fs *FightStore) UpdateFight(fight *Fight) error { - return fs.UpdateWithRebuild(fight.ID, fight) +// Insert with ID assignment +func (f *Fight) Insert() error { + f.Updated = time.Now().Unix() + id, err := store.Create(f) + if err != nil { + return err + } + f.ID = id + return nil } -// Data persistence -func LoadData(dataPath string) error { - fs := GetStore() - return fs.BaseStore.LoadData(dataPath) -} - -func SaveData(dataPath string) error { - fs := GetStore() - return fs.BaseStore.SaveData(dataPath) -} - -// Query functions using enhanced store +// Query functions func Find(id int) (*Fight, error) { - fs := GetStore() - fight, exists := fs.Find(id) + fight, exists := store.Find(id) if !exists { return nil, fmt.Errorf("fight with ID %d not found", id) } @@ -165,54 +170,38 @@ func Find(id int) (*Fight, error) { } func All() ([]*Fight, error) { - fs := GetStore() - return fs.AllSorted("allByCreated"), nil + return store.AllSorted("allByCreated"), nil } func ByUserID(userID int) ([]*Fight, error) { - fs := GetStore() - return fs.GroupByIndex("byUserID", userID), nil + return store.GroupByIndex("byUserID", userID), nil } func ByMonsterID(monsterID int) ([]*Fight, error) { - fs := GetStore() - return fs.GroupByIndex("byMonsterID", monsterID), nil + return store.GroupByIndex("byMonsterID", monsterID), nil } func ActiveByUserID(userID int) ([]*Fight, error) { - fs := GetStore() - return fs.GroupByIndex("activeFights", userID), nil + return store.GroupByIndex("activeFights", userID), nil } func Active() ([]*Fight, error) { - fs := GetStore() - result := fs.FilterByIndex("allByCreated", func(f *Fight) bool { + result := store.FilterByIndex("allByCreated", func(f *Fight) bool { return !f.RanAway && !f.Victory }) return result, nil } func Recent(within time.Duration) ([]*Fight, error) { - fs := GetStore() cutoff := time.Now().Add(-within).Unix() - result := fs.FilterByIndex("allByCreated", func(f *Fight) bool { + result := store.FilterByIndex("allByCreated", func(f *Fight) bool { return f.Created >= cutoff }) return result, nil } -// Insert with ID assignment -func (f *Fight) Insert() error { - fs := GetStore() - if f.ID == 0 { - f.ID = fs.GetNextID() - } - f.Updated = time.Now().Unix() - return fs.AddFight(f) -} - // Helper methods func (f *Fight) CreatedTime() time.Time { return time.Unix(f.Created, 0) diff --git a/internal/models/forum/forum.go b/internal/models/forum/forum.go index 39aaffc..32ede75 100644 --- a/internal/models/forum/forum.go +++ b/internal/models/forum/forum.go @@ -1,11 +1,12 @@ package forum import ( - "dk/internal/store" "fmt" "sort" "strings" "time" + + nigiri "git.sharkk.net/Sharkk/Nigiri" ) // Forum represents a forum post or thread in the game @@ -13,20 +14,47 @@ type Forum struct { ID int `json:"id"` Posted int64 `json:"posted"` LastPost int64 `json:"last_post"` - Author int `json:"author"` - Parent int `json:"parent"` + Author int `json:"author" db:"index"` + Parent int `json:"parent" db:"index"` Replies int `json:"replies"` - Title string `json:"title"` - Content string `json:"content"` + Title string `json:"title" db:"required"` + Content string `json:"content" db:"required"` } -func (f *Forum) Save() error { - return GetStore().UpdateWithRebuild(f.ID, f) +// Global store +var store *nigiri.BaseStore[Forum] +var db *nigiri.Collection + +// Init sets up the Nigiri store and indices +func Init(collection *nigiri.Collection) { + db = collection + store = nigiri.NewBaseStore[Forum]() + + // Register custom indices + store.RegisterIndex("byParent", nigiri.BuildIntGroupIndex(func(f *Forum) int { + return f.Parent + })) + + store.RegisterIndex("byAuthor", nigiri.BuildIntGroupIndex(func(f *Forum) int { + return f.Author + })) + + store.RegisterIndex("allByLastPost", nigiri.BuildSortedListIndex(func(a, b *Forum) bool { + if a.LastPost != b.LastPost { + return a.LastPost > b.LastPost // DESC + } + return a.ID > b.ID // DESC + })) + + store.RebuildIndices() } -func (f *Forum) Delete() error { - GetStore().RemoveWithRebuild(f.ID) - return nil +// GetStore returns the forum store +func GetStore() *nigiri.BaseStore[Forum] { + if store == nil { + panic("forum store not initialized - call Initialize first") + } + return store } // Creates a new Forum with sensible defaults @@ -66,62 +94,37 @@ func (f *Forum) Validate() error { return nil } -// ForumStore with enhanced BaseStore -type ForumStore struct { - *store.BaseStore[Forum] -} - -// Global store with singleton pattern -var GetStore = store.NewSingleton(func() *ForumStore { - fs := &ForumStore{BaseStore: store.NewBaseStore[Forum]()} - - // Register indices - fs.RegisterIndex("byParent", store.BuildIntGroupIndex(func(f *Forum) int { - return f.Parent - })) - - fs.RegisterIndex("byAuthor", store.BuildIntGroupIndex(func(f *Forum) int { - return f.Author - })) - - fs.RegisterIndex("allByLastPost", store.BuildSortedListIndex(func(a, b *Forum) bool { - if a.LastPost != b.LastPost { - return a.LastPost > b.LastPost // DESC +// CRUD operations +func (f *Forum) Save() error { + if f.ID == 0 { + id, err := store.Create(f) + if err != nil { + return err } - return a.ID > b.ID // DESC - })) - - return fs -}) - -// Enhanced CRUD operations -func (fs *ForumStore) AddForum(forum *Forum) error { - return fs.AddWithRebuild(forum.ID, forum) + f.ID = id + return nil + } + return store.Update(f.ID, f) } -func (fs *ForumStore) RemoveForum(id int) { - fs.RemoveWithRebuild(id) +func (f *Forum) Delete() error { + store.Remove(f.ID) + return nil } -func (fs *ForumStore) UpdateForum(forum *Forum) error { - return fs.UpdateWithRebuild(forum.ID, forum) +// Insert with ID assignment +func (f *Forum) Insert() error { + id, err := store.Create(f) + if err != nil { + return err + } + f.ID = id + return nil } -// Data persistence -func LoadData(dataPath string) error { - fs := GetStore() - return fs.BaseStore.LoadData(dataPath) -} - -func SaveData(dataPath string) error { - fs := GetStore() - return fs.BaseStore.SaveData(dataPath) -} - -// Query functions using enhanced store +// Query functions func Find(id int) (*Forum, error) { - fs := GetStore() - forum, exists := fs.Find(id) + forum, exists := store.Find(id) if !exists { return nil, fmt.Errorf("forum post with ID %d not found", id) } @@ -129,20 +132,17 @@ func Find(id int) (*Forum, error) { } func All() ([]*Forum, error) { - fs := GetStore() - return fs.AllSorted("allByLastPost"), nil + return store.AllSorted("allByLastPost"), nil } func Threads() ([]*Forum, error) { - fs := GetStore() - return fs.FilterByIndex("allByLastPost", func(f *Forum) bool { + return store.FilterByIndex("allByLastPost", func(f *Forum) bool { return f.Parent == 0 }), nil } func ByParent(parentID int) ([]*Forum, error) { - fs := GetStore() - replies := fs.GroupByIndex("byParent", parentID) + replies := store.GroupByIndex("byParent", parentID) // Sort replies chronologically (posted ASC, then ID ASC) if parentID > 0 && len(replies) > 1 { @@ -158,8 +158,7 @@ func ByParent(parentID int) ([]*Forum, error) { } func ByAuthor(authorID int) ([]*Forum, error) { - fs := GetStore() - posts := fs.GroupByIndex("byAuthor", authorID) + posts := store.GroupByIndex("byAuthor", authorID) // Sort by posted DESC, then ID DESC sort.Slice(posts, func(i, j int) bool { @@ -173,8 +172,7 @@ func ByAuthor(authorID int) ([]*Forum, error) { } func Recent(limit int) ([]*Forum, error) { - fs := GetStore() - all := fs.AllSorted("allByLastPost") + all := store.AllSorted("allByLastPost") if limit > len(all) { limit = len(all) } @@ -182,30 +180,19 @@ func Recent(limit int) ([]*Forum, error) { } func Search(term string) ([]*Forum, error) { - fs := GetStore() lowerTerm := strings.ToLower(term) - return fs.FilterByIndex("allByLastPost", func(f *Forum) bool { + return store.FilterByIndex("allByLastPost", func(f *Forum) bool { return strings.Contains(strings.ToLower(f.Title), lowerTerm) || strings.Contains(strings.ToLower(f.Content), lowerTerm) }), nil } func Since(since int64) ([]*Forum, error) { - fs := GetStore() - return fs.FilterByIndex("allByLastPost", func(f *Forum) bool { + return store.FilterByIndex("allByLastPost", func(f *Forum) bool { return f.LastPost >= since }), nil } -// Insert with ID assignment -func (f *Forum) Insert() error { - fs := GetStore() - if f.ID == 0 { - f.ID = fs.GetNextID() - } - return fs.AddForum(f) -} - // Helper methods func (f *Forum) PostedTime() time.Time { return time.Unix(f.Posted, 0) @@ -324,3 +311,14 @@ func (f *Forum) GetThread() (*Forum, error) { } return Find(f.Parent) } + +// Legacy compatibility functions (will be removed later) +func LoadData(dataPath string) error { + // No longer needed - Nigiri handles this + return nil +} + +func SaveData(dataPath string) error { + // No longer needed - Nigiri handles this + return nil +} diff --git a/internal/models/items/items.go b/internal/models/items/items.go index 9486bed..77b11d3 100644 --- a/internal/models/items/items.go +++ b/internal/models/items/items.go @@ -1,27 +1,55 @@ package items import ( - "dk/internal/store" "fmt" + + nigiri "git.sharkk.net/Sharkk/Nigiri" ) // Item represents an item in the game type Item struct { ID int `json:"id"` - Type int `json:"type"` - Name string `json:"name"` + Type int `json:"type" db:"index"` + Name string `json:"name" db:"required"` Value int `json:"value"` Att int `json:"att"` Special string `json:"special"` } -func (i *Item) Save() error { - return GetStore().UpdateWithRebuild(i.ID, i) +// ItemType constants for item types +const ( + TypeWeapon = 1 + TypeArmor = 2 + TypeShield = 3 +) + +// Global store +var store *nigiri.BaseStore[Item] +var db *nigiri.Collection + +// Init sets up the Nigiri store and indices +func Init(collection *nigiri.Collection) { + db = collection + store = nigiri.NewBaseStore[Item]() + + // Register custom indices + store.RegisterIndex("byType", nigiri.BuildIntGroupIndex(func(i *Item) int { + return i.Type + })) + + store.RegisterIndex("allByID", nigiri.BuildSortedListIndex(func(a, b *Item) bool { + return a.ID < b.ID + })) + + store.RebuildIndices() } -func (i *Item) Delete() error { - GetStore().RemoveWithRebuild(i.ID) - return nil +// GetStore returns the items store +func GetStore() *nigiri.BaseStore[Item] { + if store == nil { + panic("items store not initialized - call Initialize first") + } + return store } // Creates a new Item with sensible defaults @@ -52,62 +80,37 @@ func (i *Item) Validate() error { return nil } -// ItemType constants for item types -const ( - TypeWeapon = 1 - TypeArmor = 2 - TypeShield = 3 -) - -// ItemStore with enhanced BaseStore -type ItemStore struct { - *store.BaseStore[Item] +// CRUD operations +func (i *Item) Save() error { + if i.ID == 0 { + id, err := store.Create(i) + if err != nil { + return err + } + i.ID = id + return nil + } + return store.Update(i.ID, i) } -// Global store with singleton pattern -var GetStore = store.NewSingleton(func() *ItemStore { - is := &ItemStore{BaseStore: store.NewBaseStore[Item]()} - - // Register indices - is.RegisterIndex("byType", store.BuildIntGroupIndex(func(i *Item) int { - return i.Type - })) - - is.RegisterIndex("allByID", store.BuildSortedListIndex(func(a, b *Item) bool { - return a.ID < b.ID - })) - - return is -}) - -// Enhanced CRUD operations -func (is *ItemStore) AddItem(item *Item) error { - return is.AddWithRebuild(item.ID, item) +func (i *Item) Delete() error { + store.Remove(i.ID) + return nil } -func (is *ItemStore) RemoveItem(id int) { - is.RemoveWithRebuild(id) +// Insert with ID assignment +func (i *Item) Insert() error { + id, err := store.Create(i) + if err != nil { + return err + } + i.ID = id + return nil } -func (is *ItemStore) UpdateItem(item *Item) error { - return is.UpdateWithRebuild(item.ID, item) -} - -// Data persistence -func LoadData(dataPath string) error { - is := GetStore() - return is.BaseStore.LoadData(dataPath) -} - -func SaveData(dataPath string) error { - is := GetStore() - return is.BaseStore.SaveData(dataPath) -} - -// Query functions using enhanced store +// Query functions func Find(id int) (*Item, error) { - is := GetStore() - item, exists := is.Find(id) + item, exists := store.Find(id) if !exists { return nil, fmt.Errorf("item with ID %d not found", id) } @@ -115,22 +118,11 @@ func Find(id int) (*Item, error) { } func All() ([]*Item, error) { - is := GetStore() - return is.AllSorted("allByID"), nil + return store.AllSorted("allByID"), nil } func ByType(itemType int) ([]*Item, error) { - is := GetStore() - return is.GroupByIndex("byType", itemType), nil -} - -// Insert with ID assignment -func (i *Item) Insert() error { - is := GetStore() - if i.ID == 0 { - i.ID = is.GetNextID() - } - return is.AddItem(i) + return store.GroupByIndex("byType", itemType), nil } // Helper methods @@ -166,3 +158,14 @@ func (i *Item) HasSpecial() bool { func (i *Item) IsEquippable() bool { return i.Type == TypeWeapon || i.Type == TypeArmor || i.Type == TypeShield } + +// Legacy compatibility functions (will be removed later) +func LoadData(dataPath string) error { + // No longer needed - Nigiri handles this + return nil +} + +func SaveData(dataPath string) error { + // No longer needed - Nigiri handles this + return nil +} diff --git a/internal/models/monsters/monsters.go b/internal/models/monsters/monsters.go index eba2b5c..655a340 100644 --- a/internal/models/monsters/monsters.go +++ b/internal/models/monsters/monsters.go @@ -1,30 +1,65 @@ package monsters import ( - "dk/internal/store" "fmt" + + nigiri "git.sharkk.net/Sharkk/Nigiri" ) // Monster represents a monster in the game type Monster struct { ID int `json:"id"` - Name string `json:"name"` + Name string `json:"name" db:"required"` MaxHP int `json:"max_hp"` MaxDmg int `json:"max_dmg"` Armor int `json:"armor"` - Level int `json:"level"` + Level int `json:"level" db:"index"` MaxExp int `json:"max_exp"` MaxGold int `json:"max_gold"` - Immune int `json:"immune"` + Immune int `json:"immune" db:"index"` } -func (m *Monster) Save() error { - return GetStore().UpdateWithRebuild(m.ID, m) +// Immunity constants +const ( + ImmuneNone = 0 + ImmuneHurt = 1 + ImmuneSleep = 2 +) + +// Global store +var store *nigiri.BaseStore[Monster] +var db *nigiri.Collection + +// Init sets up the Nigiri store and indices +func Init(collection *nigiri.Collection) { + db = collection + store = nigiri.NewBaseStore[Monster]() + + // Register custom indices + store.RegisterIndex("byLevel", nigiri.BuildIntGroupIndex(func(m *Monster) int { + return m.Level + })) + + store.RegisterIndex("byImmunity", nigiri.BuildIntGroupIndex(func(m *Monster) int { + return m.Immune + })) + + store.RegisterIndex("allByLevel", nigiri.BuildSortedListIndex(func(a, b *Monster) bool { + if a.Level == b.Level { + return a.ID < b.ID + } + return a.Level < b.Level + })) + + store.RebuildIndices() } -func (m *Monster) Delete() error { - GetStore().RemoveWithRebuild(m.ID) - return nil +// GetStore returns the monsters store +func GetStore() *nigiri.BaseStore[Monster] { + if store == nil { + panic("monsters store not initialized - call Initialize first") + } + return store } // Creates a new Monster with sensible defaults @@ -58,69 +93,37 @@ func (m *Monster) Validate() error { return nil } -// Immunity constants -const ( - ImmuneNone = 0 - ImmuneHurt = 1 - ImmuneSleep = 2 -) - -// MonsterStore with enhanced BaseStore -type MonsterStore struct { - *store.BaseStore[Monster] -} - -// Global store with singleton pattern -var GetStore = store.NewSingleton(func() *MonsterStore { - ms := &MonsterStore{BaseStore: store.NewBaseStore[Monster]()} - - // Register indices - ms.RegisterIndex("byLevel", store.BuildIntGroupIndex(func(m *Monster) int { - return m.Level - })) - - ms.RegisterIndex("byImmunity", store.BuildIntGroupIndex(func(m *Monster) int { - return m.Immune - })) - - ms.RegisterIndex("allByLevel", store.BuildSortedListIndex(func(a, b *Monster) bool { - if a.Level == b.Level { - return a.ID < b.ID +// CRUD operations +func (m *Monster) Save() error { + if m.ID == 0 { + id, err := store.Create(m) + if err != nil { + return err } - return a.Level < b.Level - })) - - return ms -}) - -// Enhanced CRUD operations -func (ms *MonsterStore) AddMonster(monster *Monster) error { - return ms.AddWithRebuild(monster.ID, monster) + m.ID = id + return nil + } + return store.Update(m.ID, m) } -func (ms *MonsterStore) RemoveMonster(id int) { - ms.RemoveWithRebuild(id) +func (m *Monster) Delete() error { + store.Remove(m.ID) + return nil } -func (ms *MonsterStore) UpdateMonster(monster *Monster) error { - return ms.UpdateWithRebuild(monster.ID, monster) +// Insert with ID assignment +func (m *Monster) Insert() error { + id, err := store.Create(m) + if err != nil { + return err + } + m.ID = id + return nil } -// Data persistence -func LoadData(dataPath string) error { - ms := GetStore() - return ms.BaseStore.LoadData(dataPath) -} - -func SaveData(dataPath string) error { - ms := GetStore() - return ms.BaseStore.SaveData(dataPath) -} - -// Query functions using enhanced store +// Query functions func Find(id int) (*Monster, error) { - ms := GetStore() - monster, exists := ms.Find(id) + monster, exists := store.Find(id) if !exists { return nil, fmt.Errorf("monster with ID %d not found", id) } @@ -128,37 +131,24 @@ func Find(id int) (*Monster, error) { } func All() ([]*Monster, error) { - ms := GetStore() - return ms.AllSorted("allByLevel"), nil + return store.AllSorted("allByLevel"), nil } func ByLevel(level int) ([]*Monster, error) { - ms := GetStore() - return ms.GroupByIndex("byLevel", level), nil + return store.GroupByIndex("byLevel", level), nil } func ByLevelRange(minLevel, maxLevel int) ([]*Monster, error) { - ms := GetStore() var result []*Monster for level := minLevel; level <= maxLevel; level++ { - monsters := ms.GroupByIndex("byLevel", level) + monsters := store.GroupByIndex("byLevel", level) result = append(result, monsters...) } return result, nil } func ByImmunity(immunityType int) ([]*Monster, error) { - ms := GetStore() - return ms.GroupByIndex("byImmunity", immunityType), nil -} - -// Insert with ID assignment -func (m *Monster) Insert() error { - ms := GetStore() - if m.ID == 0 { - m.ID = ms.GetNextID() - } - return ms.AddMonster(m) + return store.GroupByIndex("byImmunity", immunityType), nil } // Helper methods @@ -207,3 +197,14 @@ func (m *Monster) GoldPerHP() float64 { } return float64(m.MaxGold) / float64(m.MaxHP) } + +// Legacy compatibility functions (will be removed later) +func LoadData(dataPath string) error { + // No longer needed - Nigiri handles this + return nil +} + +func SaveData(dataPath string) error { + // No longer needed - Nigiri handles this + return nil +} diff --git a/internal/models/news/news.go b/internal/models/news/news.go index 1ba3987..d259512 100644 --- a/internal/models/news/news.go +++ b/internal/models/news/news.go @@ -1,27 +1,51 @@ package news import ( - "dk/internal/store" "fmt" "strings" "time" + + nigiri "git.sharkk.net/Sharkk/Nigiri" ) // News represents a news post in the game type News struct { ID int `json:"id"` - Author int `json:"author"` + Author int `json:"author" db:"index"` Posted int64 `json:"posted"` - Content string `json:"content"` + Content string `json:"content" db:"required"` } -func (n *News) Save() error { - return GetStore().UpdateWithRebuild(n.ID, n) +// Global store +var store *nigiri.BaseStore[News] +var db *nigiri.Collection + +// Init sets up the Nigiri store and indices +func Init(collection *nigiri.Collection) { + db = collection + store = nigiri.NewBaseStore[News]() + + // Register custom indices + store.RegisterIndex("byAuthor", nigiri.BuildIntGroupIndex(func(n *News) int { + return n.Author + })) + + store.RegisterIndex("allByPosted", nigiri.BuildSortedListIndex(func(a, b *News) bool { + if a.Posted != b.Posted { + return a.Posted > b.Posted // DESC + } + return a.ID > b.ID // DESC + })) + + store.RebuildIndices() } -func (n *News) Delete() error { - GetStore().RemoveWithRebuild(n.ID) - return nil +// GetStore returns the news store +func GetStore() *nigiri.BaseStore[News] { + if store == nil { + panic("news store not initialized - call Init first") + } + return store } // Creates a new News with sensible defaults @@ -44,58 +68,37 @@ func (n *News) Validate() error { return nil } -// NewsStore with enhanced BaseStore -type NewsStore struct { - *store.BaseStore[News] -} - -// Global store with singleton pattern -var GetStore = store.NewSingleton(func() *NewsStore { - ns := &NewsStore{BaseStore: store.NewBaseStore[News]()} - - // Register indices - ns.RegisterIndex("byAuthor", store.BuildIntGroupIndex(func(n *News) int { - return n.Author - })) - - ns.RegisterIndex("allByPosted", store.BuildSortedListIndex(func(a, b *News) bool { - if a.Posted != b.Posted { - return a.Posted > b.Posted // DESC +// CRUD operations +func (n *News) Save() error { + if n.ID == 0 { + id, err := store.Create(n) + if err != nil { + return err } - return a.ID > b.ID // DESC - })) - - return ns -}) - -// Enhanced CRUD operations -func (ns *NewsStore) AddNews(news *News) error { - return ns.AddWithRebuild(news.ID, news) + n.ID = id + return nil + } + return store.Update(n.ID, n) } -func (ns *NewsStore) RemoveNews(id int) { - ns.RemoveWithRebuild(id) +func (n *News) Delete() error { + store.Remove(n.ID) + return nil } -func (ns *NewsStore) UpdateNews(news *News) error { - return ns.UpdateWithRebuild(news.ID, news) +// Insert with ID assignment +func (n *News) Insert() error { + id, err := store.Create(n) + if err != nil { + return err + } + n.ID = id + return nil } -// Data persistence -func LoadData(dataPath string) error { - ns := GetStore() - return ns.BaseStore.LoadData(dataPath) -} - -func SaveData(dataPath string) error { - ns := GetStore() - return ns.BaseStore.SaveData(dataPath) -} - -// Query functions using enhanced store +// Query functions func Find(id int) (*News, error) { - ns := GetStore() - news, exists := ns.Find(id) + news, exists := store.Find(id) if !exists { return nil, fmt.Errorf("news with ID %d not found", id) } @@ -103,18 +106,15 @@ func Find(id int) (*News, error) { } func All() ([]*News, error) { - ns := GetStore() - return ns.AllSorted("allByPosted"), nil + return store.AllSorted("allByPosted"), nil } func ByAuthor(authorID int) ([]*News, error) { - ns := GetStore() - return ns.GroupByIndex("byAuthor", authorID), nil + return store.GroupByIndex("byAuthor", authorID), nil } func Recent(limit int) ([]*News, error) { - ns := GetStore() - all := ns.AllSorted("allByPosted") + all := store.AllSorted("allByPosted") if limit > len(all) { limit = len(all) } @@ -122,36 +122,24 @@ func Recent(limit int) ([]*News, error) { } func Since(since int64) ([]*News, error) { - ns := GetStore() - return ns.FilterByIndex("allByPosted", func(n *News) bool { + return store.FilterByIndex("allByPosted", func(n *News) bool { return n.Posted >= since }), nil } func Between(start, end int64) ([]*News, error) { - ns := GetStore() - return ns.FilterByIndex("allByPosted", func(n *News) bool { + return store.FilterByIndex("allByPosted", func(n *News) bool { return n.Posted >= start && n.Posted <= end }), nil } func Search(term string) ([]*News, error) { - ns := GetStore() lowerTerm := strings.ToLower(term) - return ns.FilterByIndex("allByPosted", func(n *News) bool { + return store.FilterByIndex("allByPosted", func(n *News) bool { return strings.Contains(strings.ToLower(n.Content), lowerTerm) }), nil } -// Insert with ID assignment -func (n *News) Insert() error { - ns := GetStore() - if n.ID == 0 { - n.ID = ns.GetNextID() - } - return ns.AddNews(n) -} - // Helper methods func (n *News) PostedTime() time.Time { return time.Unix(n.Posted, 0) @@ -227,3 +215,14 @@ func (n *News) Contains(term string) bool { func (n *News) IsEmpty() bool { return strings.TrimSpace(n.Content) == "" } + +// Legacy compatibility functions (will be removed later) +func LoadData(dataPath string) error { + // No longer needed - Nigiri handles this + return nil +} + +func SaveData(dataPath string) error { + // No longer needed - Nigiri handles this + return nil +} diff --git a/internal/models/spells/spells.go b/internal/models/spells/spells.go index f72c6b8..77eb282 100644 --- a/internal/models/spells/spells.go +++ b/internal/models/spells/spells.go @@ -1,27 +1,71 @@ package spells import ( - "dk/internal/store" "fmt" "strings" + + nigiri "git.sharkk.net/Sharkk/Nigiri" ) // Spell represents a spell in the game type Spell struct { ID int `json:"id"` - Name string `json:"name"` - MP int `json:"mp"` + Name string `json:"name" db:"required,unique"` + MP int `json:"mp" db:"index"` Attribute int `json:"attribute"` - Type int `json:"type"` + Type int `json:"type" db:"index"` } -func (s *Spell) Save() error { - return GetStore().UpdateWithRebuild(s.ID, s) +// SpellType constants for spell types +const ( + TypeHealing = 1 + TypeHurt = 2 + TypeSleep = 3 + TypeAttackBoost = 4 + TypeDefenseBoost = 5 +) + +// Global store +var store *nigiri.BaseStore[Spell] +var db *nigiri.Collection + +// Init sets up the Nigiri store and indices +func Init(collection *nigiri.Collection) { + db = collection + store = nigiri.NewBaseStore[Spell]() + + // Register custom indices + store.RegisterIndex("byType", nigiri.BuildIntGroupIndex(func(s *Spell) int { + return s.Type + })) + + store.RegisterIndex("byName", nigiri.BuildCaseInsensitiveLookupIndex(func(s *Spell) string { + return s.Name + })) + + store.RegisterIndex("byMP", nigiri.BuildIntGroupIndex(func(s *Spell) int { + return s.MP + })) + + store.RegisterIndex("allByTypeMP", nigiri.BuildSortedListIndex(func(a, b *Spell) bool { + if a.Type != b.Type { + return a.Type < b.Type + } + if a.MP != b.MP { + return a.MP < b.MP + } + return a.ID < b.ID + })) + + store.RebuildIndices() } -func (s *Spell) Delete() error { - GetStore().RemoveWithRebuild(s.ID) - return nil +// GetStore returns the spells store +func GetStore() *nigiri.BaseStore[Spell] { + if store == nil { + panic("spells store not initialized - call Initialize first") + } + return store } // Creates a new Spell with sensible defaults @@ -51,78 +95,37 @@ func (s *Spell) Validate() error { return nil } -// SpellType constants for spell types -const ( - TypeHealing = 1 - TypeHurt = 2 - TypeSleep = 3 - TypeAttackBoost = 4 - TypeDefenseBoost = 5 -) - -// SpellStore with enhanced BaseStore -type SpellStore struct { - *store.BaseStore[Spell] -} - -// Global store with singleton pattern -var GetStore = store.NewSingleton(func() *SpellStore { - ss := &SpellStore{BaseStore: store.NewBaseStore[Spell]()} - - // Register indices - ss.RegisterIndex("byType", store.BuildIntGroupIndex(func(s *Spell) int { - return s.Type - })) - - ss.RegisterIndex("byName", store.BuildCaseInsensitiveLookupIndex(func(s *Spell) string { - return s.Name - })) - - ss.RegisterIndex("byMP", store.BuildIntGroupIndex(func(s *Spell) int { - return s.MP - })) - - ss.RegisterIndex("allByTypeMP", store.BuildSortedListIndex(func(a, b *Spell) bool { - if a.Type != b.Type { - return a.Type < b.Type +// CRUD operations +func (s *Spell) Save() error { + if s.ID == 0 { + id, err := store.Create(s) + if err != nil { + return err } - if a.MP != b.MP { - return a.MP < b.MP - } - return a.ID < b.ID - })) - - return ss -}) - -// Enhanced CRUD operations -func (ss *SpellStore) AddSpell(spell *Spell) error { - return ss.AddWithRebuild(spell.ID, spell) + s.ID = id + return nil + } + return store.Update(s.ID, s) } -func (ss *SpellStore) RemoveSpell(id int) { - ss.RemoveWithRebuild(id) +func (s *Spell) Delete() error { + store.Remove(s.ID) + return nil } -func (ss *SpellStore) UpdateSpell(spell *Spell) error { - return ss.UpdateWithRebuild(spell.ID, spell) +// Insert with ID assignment +func (s *Spell) Insert() error { + id, err := store.Create(s) + if err != nil { + return err + } + s.ID = id + return nil } -// Data persistence -func LoadData(dataPath string) error { - ss := GetStore() - return ss.BaseStore.LoadData(dataPath) -} - -func SaveData(dataPath string) error { - ss := GetStore() - return ss.BaseStore.SaveData(dataPath) -} - -// Query functions using enhanced store +// Query functions func Find(id int) (*Spell, error) { - ss := GetStore() - spell, exists := ss.Find(id) + spell, exists := store.Find(id) if !exists { return nil, fmt.Errorf("spell with ID %d not found", id) } @@ -130,47 +133,33 @@ func Find(id int) (*Spell, error) { } func All() ([]*Spell, error) { - ss := GetStore() - return ss.AllSorted("allByTypeMP"), nil + return store.AllSorted("allByTypeMP"), nil } func ByType(spellType int) ([]*Spell, error) { - ss := GetStore() - return ss.GroupByIndex("byType", spellType), nil + return store.GroupByIndex("byType", spellType), nil } func ByMaxMP(maxMP int) ([]*Spell, error) { - ss := GetStore() - return ss.FilterByIndex("allByTypeMP", func(s *Spell) bool { + return store.FilterByIndex("allByTypeMP", func(s *Spell) bool { return s.MP <= maxMP }), nil } func ByTypeAndMaxMP(spellType, maxMP int) ([]*Spell, error) { - ss := GetStore() - return ss.FilterByIndex("allByTypeMP", func(s *Spell) bool { + return store.FilterByIndex("allByTypeMP", func(s *Spell) bool { return s.Type == spellType && s.MP <= maxMP }), nil } func ByName(name string) (*Spell, error) { - ss := GetStore() - spell, exists := ss.LookupByIndex("byName", strings.ToLower(name)) + spell, exists := store.LookupByIndex("byName", strings.ToLower(name)) if !exists { return nil, fmt.Errorf("spell with name '%s' not found", name) } return spell, nil } -// Insert with ID assignment -func (s *Spell) Insert() error { - ss := GetStore() - if s.ID == 0 { - s.ID = ss.GetNextID() - } - return ss.AddSpell(s) -} - // Helper methods func (s *Spell) IsHealing() bool { return s.Type == TypeHealing @@ -227,3 +216,14 @@ func (s *Spell) IsOffensive() bool { func (s *Spell) IsSupport() bool { return s.Type == TypeHealing || s.Type == TypeAttackBoost || s.Type == TypeDefenseBoost } + +// Legacy compatibility functions (will be removed later) +func LoadData(dataPath string) error { + // No longer needed - Nigiri handles this + return nil +} + +func SaveData(dataPath string) error { + // No longer needed - Nigiri handles this + return nil +} diff --git a/internal/models/towns/towns.go b/internal/models/towns/towns.go index e15ba14..b4886d4 100644 --- a/internal/models/towns/towns.go +++ b/internal/models/towns/towns.go @@ -1,7 +1,6 @@ package towns import ( - "dk/internal/store" "fmt" "math" "slices" @@ -10,12 +9,14 @@ import ( "strings" "dk/internal/helpers" + + nigiri "git.sharkk.net/Sharkk/Nigiri" ) // Town represents a town in the game type Town struct { ID int `json:"id"` - Name string `json:"name"` + Name string `json:"name" db:"required,unique"` X int `json:"x"` Y int `json:"y"` InnCost int `json:"inn_cost"` @@ -24,13 +25,50 @@ type Town struct { ShopList string `json:"shop_list"` } -func (t *Town) Save() error { - return GetStore().UpdateWithRebuild(t.ID, t) +// Global store +var store *nigiri.BaseStore[Town] +var db *nigiri.Collection + +// coordsKey creates a key for coordinate-based lookup +func coordsKey(x, y int) string { + return strconv.Itoa(x) + "," + strconv.Itoa(y) } -func (t *Town) Delete() error { - GetStore().RemoveWithRebuild(t.ID) - return nil +// Init sets up the Nigiri store and indices +func Init(collection *nigiri.Collection) { + db = collection + store = nigiri.NewBaseStore[Town]() + + // Register custom indices + store.RegisterIndex("byName", nigiri.BuildCaseInsensitiveLookupIndex(func(t *Town) string { + return t.Name + })) + + store.RegisterIndex("byCoords", nigiri.BuildStringLookupIndex(func(t *Town) string { + return coordsKey(t.X, t.Y) + })) + + store.RegisterIndex("byInnCost", nigiri.BuildIntGroupIndex(func(t *Town) int { + return t.InnCost + })) + + store.RegisterIndex("byTPCost", nigiri.BuildIntGroupIndex(func(t *Town) int { + return t.TPCost + })) + + store.RegisterIndex("allByID", nigiri.BuildSortedListIndex(func(a, b *Town) bool { + return a.ID < b.ID + })) + + store.RebuildIndices() +} + +// GetStore returns the towns store +func GetStore() *nigiri.BaseStore[Town] { + if store == nil { + panic("towns store not initialized - call Initialize first") + } + return store } // Creates a new Town with sensible defaults @@ -63,72 +101,37 @@ func (t *Town) Validate() error { return nil } -// coordsKey creates a key for coordinate-based lookup -func coordsKey(x, y int) string { - return strconv.Itoa(x) + "," + strconv.Itoa(y) +// CRUD operations +func (t *Town) Save() error { + if t.ID == 0 { + id, err := store.Create(t) + if err != nil { + return err + } + t.ID = id + return nil + } + return store.Update(t.ID, t) } -// TownStore with enhanced BaseStore -type TownStore struct { - *store.BaseStore[Town] +func (t *Town) Delete() error { + store.Remove(t.ID) + return nil } -// Global store with singleton pattern -var GetStore = store.NewSingleton(func() *TownStore { - ts := &TownStore{BaseStore: store.NewBaseStore[Town]()} - - // Register indices - ts.RegisterIndex("byName", store.BuildCaseInsensitiveLookupIndex(func(t *Town) string { - return t.Name - })) - - ts.RegisterIndex("byCoords", store.BuildStringLookupIndex(func(t *Town) string { - return coordsKey(t.X, t.Y) - })) - - ts.RegisterIndex("byInnCost", store.BuildIntGroupIndex(func(t *Town) int { - return t.InnCost - })) - - ts.RegisterIndex("byTPCost", store.BuildIntGroupIndex(func(t *Town) int { - return t.TPCost - })) - - ts.RegisterIndex("allByID", store.BuildSortedListIndex(func(a, b *Town) bool { - return a.ID < b.ID - })) - - return ts -}) - -// Enhanced CRUD operations -func (ts *TownStore) AddTown(town *Town) error { - return ts.AddWithRebuild(town.ID, town) +// Insert with ID assignment +func (t *Town) Insert() error { + id, err := store.Create(t) + if err != nil { + return err + } + t.ID = id + return nil } -func (ts *TownStore) RemoveTown(id int) { - ts.RemoveWithRebuild(id) -} - -func (ts *TownStore) UpdateTown(town *Town) error { - return ts.UpdateWithRebuild(town.ID, town) -} - -// Data persistence -func LoadData(dataPath string) error { - ts := GetStore() - return ts.BaseStore.LoadData(dataPath) -} - -func SaveData(dataPath string) error { - ts := GetStore() - return ts.BaseStore.SaveData(dataPath) -} - -// Query functions using enhanced store +// Query functions func Find(id int) (*Town, error) { - ts := GetStore() - town, exists := ts.Find(id) + town, exists := store.Find(id) if !exists { return nil, fmt.Errorf("town with ID %d not found", id) } @@ -136,13 +139,11 @@ func Find(id int) (*Town, error) { } func All() ([]*Town, error) { - ts := GetStore() - return ts.AllSorted("allByID"), nil + return store.AllSorted("allByID"), nil } func ByName(name string) (*Town, error) { - ts := GetStore() - town, exists := ts.LookupByIndex("byName", strings.ToLower(name)) + town, exists := store.LookupByIndex("byName", strings.ToLower(name)) if !exists { return nil, fmt.Errorf("town with name '%s' not found", name) } @@ -150,22 +151,19 @@ func ByName(name string) (*Town, error) { } func ByMaxInnCost(maxCost int) ([]*Town, error) { - ts := GetStore() - return ts.FilterByIndex("allByID", func(t *Town) bool { + return store.FilterByIndex("allByID", func(t *Town) bool { return t.InnCost <= maxCost }), nil } func ByMaxTPCost(maxCost int) ([]*Town, error) { - ts := GetStore() - return ts.FilterByIndex("allByID", func(t *Town) bool { + return store.FilterByIndex("allByID", func(t *Town) bool { return t.TPCost <= maxCost }), nil } func ByCoords(x, y int) (*Town, error) { - ts := GetStore() - town, exists := ts.LookupByIndex("byCoords", coordsKey(x, y)) + town, exists := store.LookupByIndex("byCoords", coordsKey(x, y)) if !exists { return nil, nil // Return nil if not found (like original) } @@ -173,16 +171,14 @@ func ByCoords(x, y int) (*Town, error) { } func ExistsAt(x, y int) bool { - ts := GetStore() - _, exists := ts.LookupByIndex("byCoords", coordsKey(x, y)) + _, exists := store.LookupByIndex("byCoords", coordsKey(x, y)) return exists } func ByDistance(fromX, fromY, maxDistance int) ([]*Town, error) { - ts := GetStore() maxDistance2 := float64(maxDistance * maxDistance) - result := ts.FilterByIndex("allByID", func(t *Town) bool { + result := store.FilterByIndex("allByID", func(t *Town) bool { return t.DistanceFromSquared(fromX, fromY) <= maxDistance2 }) @@ -199,15 +195,6 @@ func ByDistance(fromX, fromY, maxDistance int) ([]*Town, error) { return result, nil } -// Insert with ID assignment -func (t *Town) Insert() error { - ts := GetStore() - if t.ID == 0 { - t.ID = ts.GetNextID() - } - return ts.AddTown(t) -} - // Helper methods func (t *Town) GetShopItems() []int { return helpers.StringToInts(t.ShopList) @@ -259,3 +246,14 @@ func (t *Town) SetPosition(x, y int) { t.X = x t.Y = y } + +// Legacy compatibility functions (will be removed later) +func LoadData(dataPath string) error { + // No longer needed - Nigiri handles this + return nil +} + +func SaveData(dataPath string) error { + // No longer needed - Nigiri handles this + return nil +} diff --git a/internal/models/users/users.go b/internal/models/users/users.go index 95df143..fccad54 100644 --- a/internal/models/users/users.go +++ b/internal/models/users/users.go @@ -1,8 +1,6 @@ package users import ( - "dk/internal/helpers/exp" - "dk/internal/store" "fmt" "slices" "sort" @@ -10,14 +8,16 @@ import ( "time" "dk/internal/helpers" + + nigiri "git.sharkk.net/Sharkk/Nigiri" ) // User represents a user in the game type User struct { ID int `json:"id"` - Username string `json:"username"` - Password string `json:"password"` - Email string `json:"email"` + Username string `json:"username" db:"required,unique"` + Password string `json:"password" db:"required"` + Email string `json:"email" db:"required,unique"` Verified int `json:"verified"` Token string `json:"token"` Registered int64 `json:"registered"` @@ -34,7 +34,7 @@ type User struct { MaxHP int `json:"max_hp"` MaxMP int `json:"max_mp"` MaxTP int `json:"max_tp"` - Level int `json:"level"` + Level int `json:"level" db:"index"` Gold int `json:"gold"` Exp int `json:"exp"` GoldBonus int `json:"gold_bonus"` @@ -59,15 +59,57 @@ type User struct { Towns string `json:"towns"` } -func (u *User) Save() error { - return GetStore().UpdateWithRebuild(u.ID, u) +// Global store +var store *nigiri.BaseStore[User] +var db *nigiri.Collection + +// Init sets up the Nigiri store and indices +func Init(collection *nigiri.Collection) { + db = collection + store = nigiri.NewBaseStore[User]() + + // Register custom indices + store.RegisterIndex("byUsername", nigiri.BuildCaseInsensitiveLookupIndex(func(u *User) string { + return u.Username + })) + + store.RegisterIndex("byEmail", nigiri.BuildStringLookupIndex(func(u *User) string { + return u.Email + })) + + store.RegisterIndex("byLevel", nigiri.BuildIntGroupIndex(func(u *User) int { + return u.Level + })) + + store.RegisterIndex("allByRegistered", nigiri.BuildSortedListIndex(func(a, b *User) bool { + if a.Registered != b.Registered { + return a.Registered > b.Registered // DESC + } + return a.ID > b.ID // DESC + })) + + store.RegisterIndex("allByLevelExp", nigiri.BuildSortedListIndex(func(a, b *User) bool { + if a.Level != b.Level { + return a.Level > b.Level // Level DESC + } + if a.Exp != b.Exp { + return a.Exp > b.Exp // Exp DESC + } + return a.ID < b.ID // ID ASC + })) + + store.RebuildIndices() } -func (u *User) Delete() error { - GetStore().RemoveWithRebuild(u.ID) - return nil +// GetStore returns the users store +func GetStore() *nigiri.BaseStore[User] { + if store == nil { + panic("users store not initialized - call Initialize first") + } + return store } +// New creates a new User with sensible defaults func New() *User { now := time.Now().Unix() return &User{ @@ -125,90 +167,57 @@ func (u *User) Validate() error { return nil } -// UserStore with enhanced BaseStore -type UserStore struct { - *store.BaseStore[User] -} - -// Global store with singleton pattern -var GetStore = store.NewSingleton(func() *UserStore { - us := &UserStore{BaseStore: store.NewBaseStore[User]()} - - // Register indices - us.RegisterIndex("byUsername", store.BuildCaseInsensitiveLookupIndex(func(u *User) string { - return u.Username - })) - - us.RegisterIndex("byEmail", store.BuildStringLookupIndex(func(u *User) string { - return u.Email - })) - - us.RegisterIndex("byLevel", store.BuildIntGroupIndex(func(u *User) int { - return u.Level - })) - - us.RegisterIndex("allByRegistered", store.BuildSortedListIndex(func(a, b *User) bool { - if a.Registered != b.Registered { - return a.Registered > b.Registered // DESC +// CRUD operations +func (u *User) Save() error { + if u.ID == 0 { + id, err := store.Create(u) + if err != nil { + return err } - return a.ID > b.ID // DESC - })) - - us.RegisterIndex("allByLevelExp", store.BuildSortedListIndex(func(a, b *User) bool { - if a.Level != b.Level { - return a.Level > b.Level // Level DESC - } - if a.Exp != b.Exp { - return a.Exp > b.Exp // Exp DESC - } - return a.ID < b.ID // ID ASC - })) - - return us -}) - -// Enhanced CRUD operations -func (us *UserStore) AddUser(user *User) error { - return us.AddWithRebuild(user.ID, user) + u.ID = id + return nil + } + return store.Update(u.ID, u) } -func (us *UserStore) RemoveUser(id int) { - us.RemoveWithRebuild(id) +func (u *User) Delete() error { + store.Remove(u.ID) + return nil } -func (us *UserStore) UpdateUser(user *User) error { - return us.UpdateWithRebuild(user.ID, user) +// Insert with ID assignment +func (u *User) Insert() error { + id, err := store.Create(u) + if err != nil { + return err + } + u.ID = id + return nil } -// Data persistence -func LoadData(dataPath string) error { - us := GetStore() - return us.BaseStore.LoadData(dataPath) -} - -func SaveData(dataPath string) error { - us := GetStore() - return us.BaseStore.SaveData(dataPath) -} - -// Query functions using enhanced store +// Query functions func Find(id int) (*User, error) { - us := GetStore() - user, exists := us.Find(id) + user, exists := store.Find(id) if !exists { return nil, fmt.Errorf("user with ID %d not found", id) } return user, nil } +func GetByID(id int) *User { + user, exists := store.Find(id) + if !exists { + return nil + } + return user +} + func All() ([]*User, error) { - us := GetStore() - return us.AllSorted("allByRegistered"), nil + return store.AllSorted("allByRegistered"), nil } func ByUsername(username string) (*User, error) { - us := GetStore() - user, exists := us.LookupByIndex("byUsername", strings.ToLower(username)) + user, exists := store.LookupByIndex("byUsername", strings.ToLower(username)) if !exists { return nil, fmt.Errorf("user with username '%s' not found", username) } @@ -216,8 +225,7 @@ func ByUsername(username string) (*User, error) { } func ByEmail(email string) (*User, error) { - us := GetStore() - user, exists := us.LookupByIndex("byEmail", email) + user, exists := store.LookupByIndex("Email_idx", email) if !exists { return nil, fmt.Errorf("user with email '%s' not found", email) } @@ -225,15 +233,13 @@ func ByEmail(email string) (*User, error) { } func ByLevel(level int) ([]*User, error) { - us := GetStore() - return us.GroupByIndex("byLevel", level), nil + return store.GroupByIndex("level_idx", level), nil } func Online(within time.Duration) ([]*User, error) { - us := GetStore() cutoff := time.Now().Add(-within).Unix() - result := us.FilterByIndex("allByRegistered", func(u *User) bool { + result := store.FilterByIndex("allByRegistered", func(u *User) bool { return u.LastOnline >= cutoff }) @@ -248,15 +254,6 @@ func Online(within time.Duration) ([]*User, error) { return result, nil } -// Insert with ID assignment -func (u *User) Insert() error { - us := GetStore() - if u.ID == 0 { - u.ID = us.GetNextID() - } - return us.AddUser(u) -} - // Helper methods func (u *User) RegisteredTime() time.Time { return time.Unix(u.Registered, 0) @@ -351,7 +348,7 @@ func (u *User) SetPosition(x, y int) { } func (u *User) ExpNeededForNextLevel() int { - return exp.Calc(u.Level + 1) + return u.Level * u.Level * u.Level } func (u *User) GrantExp(expAmount int) { @@ -384,7 +381,7 @@ func (u *User) ExpProgress() float64 { return float64(u.Exp) / float64(u.ExpNeededForNextLevel()) * 100 } - currentLevelExp := exp.Calc(u.Level) + currentLevelExp := u.Level * u.Level * u.Level nextLevelExp := u.ExpNeededForNextLevel() progressExp := u.Exp diff --git a/internal/password/password.go b/internal/password/password.go deleted file mode 100644 index 7e2ebb9..0000000 --- a/internal/password/password.go +++ /dev/null @@ -1,80 +0,0 @@ -package password - -import ( - "crypto/rand" - "crypto/subtle" - "encoding/base64" - "fmt" - "strings" - - "golang.org/x/crypto/argon2" -) - -const ( - time = 1 - memory = 64 * 1024 - threads = 4 - keyLen = 32 -) - -// Hash creates an argon2id hash of the password -func Hash(password string) string { - salt := make([]byte, 16) - rand.Read(salt) - - hash := argon2.IDKey([]byte(password), salt, time, memory, threads, keyLen) - - b64Salt := base64.RawStdEncoding.EncodeToString(salt) - b64Hash := base64.RawStdEncoding.EncodeToString(hash) - - encoded := fmt.Sprintf("$argon2id$v=%d$m=%d,t=%d,p=%d$%s$%s", - argon2.Version, memory, time, threads, b64Salt, b64Hash) - - return encoded -} - -// Verify checks if a password matches the hash -func Verify(password, encodedHash string) (bool, error) { - parts := strings.Split(encodedHash, "$") - if len(parts) != 6 { - return false, fmt.Errorf("invalid hash format") - } - - if parts[1] != "argon2id" { - return false, fmt.Errorf("invalid hash variant") - } - - var version int - _, err := fmt.Sscanf(parts[2], "v=%d", &version) - if err != nil { - return false, err - } - if version != argon2.Version { - return false, fmt.Errorf("incompatible argon2 version") - } - - var m, t, p uint32 - _, err = fmt.Sscanf(parts[3], "m=%d,t=%d,p=%d", &m, &t, &p) - if err != nil { - return false, err - } - - salt, err := base64.RawStdEncoding.DecodeString(parts[4]) - if err != nil { - return false, err - } - - expectedHash, err := base64.RawStdEncoding.DecodeString(parts[5]) - if err != nil { - return false, err - } - - hash := argon2.IDKey([]byte(password), salt, t, m, uint8(p), uint32(len(expectedHash))) - - // Use constant-time comparison to prevent timing attacks - if subtle.ConstantTimeCompare(hash, expectedHash) == 1 { - return true, nil - } - - return false, nil -} diff --git a/internal/router/router.go b/internal/router/router.go deleted file mode 100644 index 105d96d..0000000 --- a/internal/router/router.go +++ /dev/null @@ -1,435 +0,0 @@ -package router - -import ( - "fmt" - - "github.com/valyala/fasthttp" -) - -type Ctx = *fasthttp.RequestCtx - -// Handler is a request handler with parameters. -type Handler func(ctx Ctx, params []string) - -func (h Handler) Serve(ctx Ctx, params []string) { - h(ctx, params) -} - -type Middleware func(Handler) Handler - -type node struct { - segment string - handler Handler - children []*node - isDynamic bool - isWildcard bool - maxParams uint8 -} - -type Router struct { - get *node - post *node - put *node - patch *node - delete *node - middleware []Middleware - paramsBuffer []string // Pre-allocated buffer for parameters -} - -type Group struct { - router *Router - prefix string - middleware []Middleware -} - -// Creates a new Router instance. -func New() *Router { - return &Router{ - get: &node{}, - post: &node{}, - put: &node{}, - patch: &node{}, - delete: &node{}, - middleware: []Middleware{}, - paramsBuffer: make([]string, 64), - } -} - -// Implements the Handler interface for fasthttp -func (r *Router) ServeHTTP(ctx *fasthttp.RequestCtx) { - path := string(ctx.Path()) - method := string(ctx.Method()) - - h, params, found := r.Lookup(method, path) - if !found { - ctx.SetStatusCode(fasthttp.StatusNotFound) - return - } - - h(ctx, params) -} - -// Returns a fasthttp request handler -func (r *Router) Handler() fasthttp.RequestHandler { - return r.ServeHTTP -} - -// Adds middleware to the router. -func (r *Router) Use(mw ...Middleware) *Router { - r.middleware = append(r.middleware, mw...) - return r -} - -// Creates a new route group. -func (r *Router) Group(prefix string) *Group { - return &Group{router: r, prefix: prefix, middleware: []Middleware{}} -} - -// Adds middleware to the group. -func (g *Group) Use(mw ...Middleware) *Group { - g.middleware = append(g.middleware, mw...) - return g -} - -// Creates a nested group. -func (g *Group) Group(prefix string) *Group { - return &Group{ - router: g.router, - prefix: g.prefix + prefix, - middleware: append([]Middleware{}, g.middleware...), - } -} - -// Applies middleware in reverse order. -func applyMiddleware(h Handler, mw []Middleware) Handler { - for i := len(mw) - 1; i >= 0; i-- { - h = mw[i](h) - } - return h -} - -// Registers a handler for the given method and path. -func (r *Router) Handle(method, path string, h Handler) error { - root := r.methodNode(method) - if root == nil { - return fmt.Errorf("unsupported method: %s", method) - } - return r.addRoute(root, path, h, r.middleware) -} - -func (r *Router) methodNode(method string) *node { - switch method { - case "GET": - return r.get - case "POST": - return r.post - case "PUT": - return r.put - case "PATCH": - return r.patch - case "DELETE": - return r.delete - default: - return nil - } -} - -// Registers a GET handler. -func (r *Router) Get(path string, h Handler) error { - return r.Handle("GET", path, h) -} - -// Registers a POST handler. -func (r *Router) Post(path string, h Handler) error { - return r.Handle("POST", path, h) -} - -// Registers a PUT handler. -func (r *Router) Put(path string, h Handler) error { - return r.Handle("PUT", path, h) -} - -// Registers a PATCH handler. -func (r *Router) Patch(path string, h Handler) error { - return r.Handle("PATCH", path, h) -} - -// Registers a DELETE handler. -func (r *Router) Delete(path string, h Handler) error { - return r.Handle("DELETE", path, h) -} - -func (g *Group) buildGroupMiddleware() []Middleware { - mw := append([]Middleware{}, g.router.middleware...) - return append(mw, g.middleware...) -} - -// Registers a handler in the group. -func (g *Group) Handle(method, path string, h Handler) error { - root := g.router.methodNode(method) - if root == nil { - return fmt.Errorf("unsupported method: %s", method) - } - return g.router.addRoute(root, g.prefix+path, h, g.buildGroupMiddleware()) -} - -// Registers a GET handler in the group. -func (g *Group) Get(path string, h Handler) error { - return g.Handle("GET", path, h) -} - -// Registers a POST handler in the group. -func (g *Group) Post(path string, h Handler) error { - return g.Handle("POST", path, h) -} - -// Registers a PUT handler in the group. -func (g *Group) Put(path string, h Handler) error { - return g.Handle("PUT", path, h) -} - -// Registers a PATCH handler in the group. -func (g *Group) Patch(path string, h Handler) error { - return g.Handle("PATCH", path, h) -} - -// Registers a DELETE handler in the group. -func (g *Group) Delete(path string, h Handler) error { - return g.Handle("DELETE", path, h) -} - -// Applies specific middleware for next registration. -func (r *Router) WithMiddleware(mw ...Middleware) *MiddlewareRouter { - return &MiddlewareRouter{router: r, middleware: mw} -} - -// Applies specific middleware for next group route. -func (g *Group) WithMiddleware(mw ...Middleware) *MiddlewareGroup { - return &MiddlewareGroup{group: g, middleware: mw} -} - -type MiddlewareRouter struct { - router *Router - middleware []Middleware -} - -type MiddlewareGroup struct { - group *Group - middleware []Middleware -} - -func (mr *MiddlewareRouter) buildMiddleware() []Middleware { - mw := append([]Middleware{}, mr.router.middleware...) - return append(mw, mr.middleware...) -} - -// Registers a handler with middleware router. -func (mr *MiddlewareRouter) Handle(method, path string, h Handler) error { - root := mr.router.methodNode(method) - if root == nil { - return fmt.Errorf("unsupported method: %s", method) - } - return mr.router.addRoute(root, path, h, mr.buildMiddleware()) -} - -// Registers a GET handler with middleware router. -func (mr *MiddlewareRouter) Get(path string, h Handler) error { - return mr.Handle("GET", path, h) -} - -// Registers a POST handler with middleware router. -func (mr *MiddlewareRouter) Post(path string, h Handler) error { - return mr.Handle("POST", path, h) -} - -// Registers a PUT handler with middleware router. -func (mr *MiddlewareRouter) Put(path string, h Handler) error { - return mr.Handle("PUT", path, h) -} - -// Registers a PATCH handler with middleware router. -func (mr *MiddlewareRouter) Patch(path string, h Handler) error { - return mr.Handle("PATCH", path, h) -} - -// Registers a DELETE handler with middleware router. -func (mr *MiddlewareRouter) Delete(path string, h Handler) error { - return mr.Handle("DELETE", path, h) -} - -func (mg *MiddlewareGroup) buildMiddleware() []Middleware { - mw := append([]Middleware{}, mg.group.router.middleware...) - mw = append(mw, mg.group.middleware...) - return append(mw, mg.middleware...) -} - -// Registers a handler with middleware group. -func (mg *MiddlewareGroup) Handle(method, path string, h Handler) error { - root := mg.group.router.methodNode(method) - if root == nil { - return fmt.Errorf("unsupported method: %s", method) - } - return mg.group.router.addRoute(root, mg.group.prefix+path, h, mg.buildMiddleware()) -} - -// Registers a GET handler with middleware group. -func (mg *MiddlewareGroup) Get(path string, h Handler) error { - return mg.Handle("GET", path, h) -} - -// Registers a POST handler with middleware group. -func (mg *MiddlewareGroup) Post(path string, h Handler) error { - return mg.Handle("POST", path, h) -} - -// Registers a PUT handler with middleware group. -func (mg *MiddlewareGroup) Put(path string, h Handler) error { - return mg.Handle("PUT", path, h) -} - -// Registers a PATCH handler with middleware group. -func (mg *MiddlewareGroup) Patch(path string, h Handler) error { - return mg.Handle("PATCH", path, h) -} - -// Registers a DELETE handler with middleware group. -func (mg *MiddlewareGroup) Delete(path string, h Handler) error { - return mg.Handle("DELETE", path, h) -} - -// Adapts a standard fasthttp.RequestHandler to the router's Handler -func StandardHandler(handler fasthttp.RequestHandler) Handler { - return func(ctx Ctx, _ []string) { - handler(ctx) - } -} - -// Extracts the next path segment. -func readSegment(path string, start int) (segment string, end int, hasMore bool) { - if start >= len(path) { - return "", start, false - } - if path[start] == '/' { - start++ - } - if start >= len(path) { - return "", start, false - } - end = start - for end < len(path) && path[end] != '/' { - end++ - } - return path[start:end], end, end < len(path) -} - -// Adds a new route to the trie. -func (r *Router) addRoute(root *node, path string, h Handler, mw []Middleware) error { - h = applyMiddleware(h, mw) - if path == "/" { - root.handler = h - return nil - } - current := root - pos := 0 - lastWC := false - count := uint8(0) - for { - seg, newPos, more := readSegment(path, pos) - if seg == "" { - break - } - isDyn := len(seg) > 1 && seg[0] == ':' - isWC := len(seg) > 0 && seg[0] == '*' - if isWC { - if lastWC || more { - return fmt.Errorf("wildcard must be the last segment in the path") - } - lastWC = true - } - if isDyn || isWC { - count++ - } - var child *node - for _, c := range current.children { - if c.segment == seg { - child = c - break - } - } - if child == nil { - child = &node{segment: seg, isDynamic: isDyn, isWildcard: isWC} - current.children = append(current.children, child) - } - if child.maxParams < count { - child.maxParams = count - } - current = child - pos = newPos - } - current.handler = h - return nil -} - -// Finds a handler matching method and path. -func (r *Router) Lookup(method, path string) (Handler, []string, bool) { - root := r.methodNode(method) - if root == nil { - return nil, nil, false - } - if path == "/" { - return root.handler, nil, root.handler != nil - } - - buffer := r.paramsBuffer - if cap(buffer) < int(root.maxParams) { - buffer = make([]string, root.maxParams) - r.paramsBuffer = buffer - } - buffer = buffer[:0] - - h, paramCount, found := match(root, path, 0, &buffer) - if !found { - return nil, nil, false - } - - return h, buffer[:paramCount], true -} - -// Traverses the trie to find a handler. -func match(current *node, path string, start int, params *[]string) (Handler, int, bool) { - paramCount := 0 - - for _, c := range current.children { - if c.isWildcard { - rem := path[start:] - if len(rem) > 0 && rem[0] == '/' { - rem = rem[1:] - } - *params = append(*params, rem) - return c.handler, 1, c.handler != nil - } - } - - seg, pos, more := readSegment(path, start) - if seg == "" { - return current.handler, 0, current.handler != nil - } - - for _, c := range current.children { - if c.segment == seg || c.isDynamic { - if c.isDynamic { - *params = append(*params, seg) - paramCount++ - } - if !more { - return c.handler, paramCount, c.handler != nil - } - h, nestedCount, ok := match(c, path, pos, params) - if ok { - return h, paramCount + nestedCount, true - } - } - } - - return nil, 0, false -} diff --git a/internal/routes/auth.go b/internal/routes/auth.go index e217164..2555b8b 100644 --- a/internal/routes/auth.go +++ b/internal/routes/auth.go @@ -4,35 +4,31 @@ import ( "fmt" "strings" - "dk/internal/auth" "dk/internal/components" "dk/internal/models/users" - "dk/internal/password" - "dk/internal/router" - "dk/internal/session" - "github.com/valyala/fasthttp" + sushi "git.sharkk.net/Sharkk/Sushi" + "git.sharkk.net/Sharkk/Sushi/auth" + "git.sharkk.net/Sharkk/Sushi/password" ) // RegisterAuthRoutes sets up authentication routes -func RegisterAuthRoutes(r *router.Router) { - guests := r.Group("") - guests.Use(auth.RequireGuest()) - - guests.Get("/login", showLogin) - guests.Post("/login", processLogin) - guests.Get("/register", showRegister) - guests.Post("/register", processRegister) - - authed := r.Group("") - authed.Use(auth.RequireAuth()) +func RegisterAuthRoutes(app *sushi.App) { + // Public routes (no auth required) + app.Get("/login", showLogin) + app.Post("/login", processLogin) + app.Get("/register", showRegister) + app.Post("/register", processRegister) + // Protected routes + authed := app.Group("") + authed.Use(auth.RequireAuth("/login")) authed.Post("/logout", processLogout) } // showLogin displays the login form -func showLogin(ctx router.Ctx, _ []string) { - sess := ctx.UserValue("session").(*session.Session) +func showLogin(ctx sushi.Ctx) { + sess := ctx.GetCurrentSession() var id string if formData, exists := sess.Get("form_data"); exists { @@ -41,7 +37,6 @@ func showLogin(ctx router.Ctx, _ []string) { } } sess.Delete("form_data") - session.Store(sess) components.RenderPage(ctx, "Log In", "auth/login.html", map[string]any{ "id": id, @@ -49,33 +44,35 @@ func showLogin(ctx router.Ctx, _ []string) { } // processLogin handles login form submission -func processLogin(ctx router.Ctx, _ []string) { +func processLogin(ctx sushi.Ctx) { email := strings.TrimSpace(string(ctx.PostArgs().Peek("id"))) userPassword := string(ctx.PostArgs().Peek("password")) if email == "" || userPassword == "" { setFlashAndFormData(ctx, "Email and password are required", map[string]string{"id": email}) - ctx.Redirect("/login", fasthttp.StatusFound) + ctx.Redirect("/login") return } user, err := authenticate(email, userPassword) if err != nil { setFlashAndFormData(ctx, "Invalid email or password", map[string]string{"id": email}) - ctx.Redirect("/login", fasthttp.StatusFound) + ctx.Redirect("/login") return } - auth.Login(ctx, user) + ctx.Login(user.ID, user) - // CSRF token is already in session, no need to transfer from cookie + // Set success message + sess := ctx.GetCurrentSession() + sess.SetFlash("success", fmt.Sprintf("Welcome back, %s!", user.Username)) - ctx.Redirect("/", fasthttp.StatusFound) + ctx.Redirect("/") } // showRegister displays the registration form -func showRegister(ctx router.Ctx, _ []string) { - sess := ctx.UserValue("session").(*session.Session) +func showRegister(ctx sushi.Ctx) { + sess := ctx.GetCurrentSession() var username, email string if formData, exists := sess.Get("form_data"); exists { @@ -85,16 +82,16 @@ func showRegister(ctx router.Ctx, _ []string) { } } sess.Delete("form_data") - session.Store(sess) components.RenderPage(ctx, "Register", "auth/register.html", map[string]any{ "username": username, "email": email, + "error_message": sess.GetFlashMessage("error"), }) } // processRegister handles registration form submission -func processRegister(ctx router.Ctx, _ []string) { +func processRegister(ctx sushi.Ctx) { username := strings.TrimSpace(string(ctx.PostArgs().Peek("username"))) email := strings.TrimSpace(string(ctx.PostArgs().Peek("email"))) userPassword := string(ctx.PostArgs().Peek("password")) @@ -107,53 +104,49 @@ func processRegister(ctx router.Ctx, _ []string) { if err := validateRegistration(username, email, userPassword, confirmPassword); err != nil { setFlashAndFormData(ctx, err.Error(), formData) - ctx.Redirect("/register", fasthttp.StatusFound) + ctx.Redirect("/register") return } if _, err := users.ByUsername(username); err == nil { setFlashAndFormData(ctx, "Username already exists", formData) - ctx.Redirect("/register", fasthttp.StatusFound) + ctx.Redirect("/register") return } if _, err := users.ByEmail(email); err == nil { setFlashAndFormData(ctx, "Email already registered", formData) - ctx.Redirect("/register", fasthttp.StatusFound) + ctx.Redirect("/register") return } user := users.New() user.Username = username user.Email = email - user.Password = password.Hash(userPassword) + user.Password = password.HashPassword(userPassword) user.ClassID = 1 user.Auth = 1 if err := user.Insert(); err != nil { setFlashAndFormData(ctx, "Failed to create account", formData) - ctx.Redirect("/register", fasthttp.StatusFound) + ctx.Redirect("/register") return } - // Auto-login after registration (this will update the current session) - auth.Login(ctx, user) + // Auto-login after registration + ctx.Login(user.ID, user) - // Update success message (Login already sets a message, so override it) - if sess := ctx.UserValue("session").(*session.Session); sess != nil { - sess.SetFlash("success", fmt.Sprintf("Greetings, %s!", user.Username)) - session.Store(sess) - } + // Set success message + sess := ctx.GetCurrentSession() + sess.SetFlash("success", fmt.Sprintf("Greetings, %s!", user.Username)) - // CSRF token is already in session, no need to transfer from cookie - - ctx.Redirect("/", fasthttp.StatusFound) + ctx.Redirect("/") } // processLogout handles logout -func processLogout(ctx router.Ctx, params []string) { - auth.Logout(ctx) - ctx.Redirect("/", fasthttp.StatusFound) +func processLogout(ctx sushi.Ctx) { + ctx.Logout() + ctx.Redirect("/") } // Helper functions @@ -183,11 +176,10 @@ func validateRegistration(username, email, password, confirmPassword string) err return nil } -func setFlashAndFormData(ctx router.Ctx, message string, formData map[string]string) { - sess := ctx.UserValue("session").(*session.Session) +func setFlashAndFormData(ctx sushi.Ctx, message string, formData map[string]string) { + sess := ctx.GetCurrentSession() sess.SetFlash("error", message) sess.Set("form_data", formData) - session.Store(sess) } func authenticate(usernameOrEmail, plainPassword string) (*users.User, error) { @@ -196,13 +188,15 @@ func authenticate(usernameOrEmail, plainPassword string) (*users.User, error) { user, err = users.ByUsername(usernameOrEmail) if err != nil { + fmt.Println(err.Error()) user, err = users.ByEmail(usernameOrEmail) if err != nil { + fmt.Println(err.Error()) return nil, err } } - isValid, err := password.Verify(plainPassword, user.Password) + isValid, err := password.VerifyPassword(plainPassword, user.Password) if err != nil { return nil, err } diff --git a/internal/routes/fight.go b/internal/routes/fight.go index 0d63163..861db89 100644 --- a/internal/routes/fight.go +++ b/internal/routes/fight.go @@ -2,45 +2,61 @@ package routes import ( "dk/internal/actions" - "dk/internal/auth" "dk/internal/components" "dk/internal/helpers" - "dk/internal/middleware" "dk/internal/models/fights" "dk/internal/models/monsters" "dk/internal/models/spells" "dk/internal/models/users" - "dk/internal/router" - "dk/internal/session" "fmt" "math/rand" "strconv" + + sushi "git.sharkk.net/Sharkk/Sushi" + "git.sharkk.net/Sharkk/Sushi/auth" ) -func RegisterFightRoutes(r *router.Router) { - group := r.Group("/fight") - group.Use(auth.RequireAuth()) - group.Use(middleware.RequireFighting()) +func RegisterFightRoutes(app *sushi.App) { + group := app.Group("/fight") + group.Use(auth.RequireAuth("/login")) + group.Use(requireFighting()) group.Get("/", showFight) group.Post("/", handleFightAction) } -func showFight(ctx router.Ctx, _ []string) { - sess := ctx.UserValue("session").(*session.Session) - user := ctx.UserValue("user").(*users.User) +// requireFighting middleware ensures the user is in a fight +func requireFighting() sushi.Middleware { + return func(ctx sushi.Ctx, next func()) { + user := ctx.GetCurrentUser() + if user == nil { + ctx.SendError(401, "Not authenticated") + return + } + + userModel := user.(*users.User) + if !userModel.IsFighting() { + ctx.Redirect("/") + return + } + + next() + } +} + +func showFight(ctx sushi.Ctx) { + sess := ctx.GetCurrentSession() + user := ctx.GetCurrentUser().(*users.User) fight, err := fights.Find(user.FightID) if err != nil { - ctx.SetContentType("text/plain") - ctx.SetBodyString("Fight not found") + ctx.SendError(404, "Fight not found") return } monster, err := monsters.Find(fight.MonsterID) if err != nil { - ctx.SetContentType("text/plain") - ctx.SetBodyString("Monster not found for fight") + ctx.SendError(404, "Monster not found for fight") return } @@ -82,9 +98,9 @@ func showFight(ctx router.Ctx, _ []string) { }) } -func handleFightAction(ctx router.Ctx, _ []string) { - user := ctx.UserValue("user").(*users.User) - sess := ctx.UserValue("session").(*session.Session) +func handleFightAction(ctx sushi.Ctx) { + sess := ctx.GetCurrentSession() + user := ctx.GetCurrentUser().(*users.User) fight, err := fights.Find(user.FightID) if err != nil { diff --git a/internal/routes/index.go b/internal/routes/index.go index 6c64a0a..c187bc7 100644 --- a/internal/routes/index.go +++ b/internal/routes/index.go @@ -5,52 +5,51 @@ import ( "dk/internal/components" "dk/internal/models/towns" "dk/internal/models/users" - "dk/internal/router" - "dk/internal/session" "slices" "strconv" + + sushi "git.sharkk.net/Sharkk/Sushi" ) -func Index(ctx router.Ctx, _ []string) { - user, ok := ctx.UserValue("user").(*users.User) - if !ok || user == nil { +func Index(ctx sushi.Ctx) { + if !ctx.IsAuthenticated() { components.RenderPage(ctx, "", "intro.html", nil) return } + user := ctx.GetCurrentUser().(*users.User) + switch user.Currently { case "In Town": - ctx.Redirect("/town", 303) + ctx.Redirect("/town") case "Exploring": - ctx.Redirect("/explore", 303) + ctx.Redirect("/explore") case "Fighting": - ctx.Redirect("/fight", 303) + ctx.Redirect("/fight") default: - ctx.Redirect("/explore", 303) + ctx.Redirect("/explore") } } -func Move(ctx router.Ctx, _ []string) { - sess := ctx.UserValue("session").(*session.Session) - user := ctx.UserValue("user").(*users.User) +func Move(ctx sushi.Ctx) { + sess := ctx.GetCurrentSession() + user := ctx.GetCurrentUser().(*users.User) if user.Currently == "Fighting" { sess.SetFlash("error", "You can't just run from a fight!") - ctx.Redirect("/fight", 303) + ctx.Redirect("/fight") return } dir, err := strconv.Atoi(string(ctx.PostArgs().Peek("direction"))) if err != nil { - ctx.SetContentType("text/plain") - ctx.SetBodyString("move form parsing error") + ctx.SendError(400, "move form parsing error") return } currently, newX, newY, err := actions.Move(user, actions.Direction(dir)) if err != nil { - ctx.SetContentType("text/plain") - ctx.SetBodyString("move error: " + err.Error()) + ctx.SendError(400, "move error: "+err.Error()) return } @@ -60,50 +59,45 @@ func Move(ctx router.Ctx, _ []string) { switch currently { case "In Town": - ctx.Redirect("/town", 303) + ctx.Redirect("/town") case "Fighting": - ctx.Redirect("/fight", 303) + ctx.Redirect("/fight") default: - ctx.Redirect("/explore", 303) + ctx.Redirect("/explore") } } -func Explore(ctx router.Ctx, _ []string) { - user := ctx.UserValue("user").(*users.User) +func Explore(ctx sushi.Ctx) { + user := ctx.GetCurrentUser().(*users.User) if user.Currently != "Exploring" { - ctx.Redirect("/", 303) + ctx.Redirect("/") return } components.RenderPage(ctx, "", "explore.html", nil) } -func Teleport(ctx router.Ctx, params []string) { - sess := ctx.UserValue("session").(*session.Session) +func Teleport(ctx sushi.Ctx) { + sess := ctx.GetCurrentSession() - id, err := strconv.Atoi(params[0]) - if err != nil { - sess.SetFlash("error", "Error teleporting; "+err.Error()) - ctx.Redirect("/", 302) - return - } + id := ctx.Param("id").Int() town, err := towns.Find(id) if err != nil { sess.SetFlash("error", "Failed to teleport. Unknown town.") - ctx.Redirect("/", 302) + ctx.Redirect("/") return } - user := ctx.UserValue("user").(*users.User) + user := ctx.GetCurrentUser().(*users.User) if !slices.Contains(user.GetTownIDs(), id) { sess.SetFlash("error", "You don't have a map to "+town.Name+".") - ctx.Redirect("/", 302) + ctx.Redirect("/") return } if user.TP < town.TPCost { sess.SetFlash("error", "You don't have enough TP to teleport to "+town.Name+".") - ctx.Redirect("/", 302) + ctx.Redirect("/") return } @@ -113,5 +107,5 @@ func Teleport(ctx router.Ctx, params []string) { user.Save() sess.SetFlash("success", "You teleported to "+town.Name+" successfully!") - ctx.Redirect("/town", 302) + ctx.Redirect("/town") } diff --git a/internal/routes/town.go b/internal/routes/town.go index f87478d..405cac9 100644 --- a/internal/routes/town.go +++ b/internal/routes/town.go @@ -2,17 +2,16 @@ package routes import ( "dk/internal/actions" - "dk/internal/auth" "dk/internal/components" "dk/internal/helpers" - "dk/internal/middleware" "dk/internal/models/items" "dk/internal/models/towns" "dk/internal/models/users" - "dk/internal/router" - "dk/internal/session" + "fmt" "slices" - "strconv" + + sushi "git.sharkk.net/Sharkk/Sushi" + "git.sharkk.net/Sharkk/Sushi/auth" ) // Map acts as a representation of owned/unowned maps in the town stores. @@ -26,10 +25,10 @@ type Map struct { TP int } -func RegisterTownRoutes(r *router.Router) { - group := r.Group("/town") - group.Use(auth.RequireAuth()) - group.Use(middleware.RequireTown()) +func RegisterTownRoutes(app *sushi.App) { + group := app.Group("/town") + group.Use(auth.RequireAuth("/login")) + group.Use(requireTown()) group.Get("/", showTown) group.Get("/inn", showInn) @@ -40,7 +39,33 @@ func RegisterTownRoutes(r *router.Router) { group.Get("/maps/buy/:id", buyMap) } -func showTown(ctx router.Ctx, _ []string) { +// requireTown middleware ensures the user is in a town +func requireTown() sushi.Middleware { + return func(ctx sushi.Ctx, next func()) { + user := ctx.GetCurrentUser() + if user == nil { + ctx.SendError(401, "Not authenticated") + return + } + + userModel := user.(*users.User) + if userModel.Currently != "In Town" { + ctx.SendError(403, "You must be in town") + return + } + + town, err := towns.ByCoords(userModel.X, userModel.Y) + if err != nil || town == nil || town.ID == 0 { + ctx.SendError(403, fmt.Sprintf("Invalid town location (%d, %d)", userModel.X, userModel.Y)) + return + } + + ctx.SetUserValue("town", town) + next() + } +} + +func showTown(ctx sushi.Ctx) { town := ctx.UserValue("town").(*towns.Town) components.RenderPage(ctx, town.Name, "town/town.html", map[string]any{ "town": town, @@ -49,7 +74,7 @@ func showTown(ctx router.Ctx, _ []string) { }) } -func showInn(ctx router.Ctx, _ []string) { +func showInn(ctx sushi.Ctx) { town := ctx.UserValue("town").(*towns.Town) components.RenderPage(ctx, town.Name+" Inn", "town/inn.html", map[string]any{ "town": town, @@ -57,19 +82,20 @@ func showInn(ctx router.Ctx, _ []string) { }) } -func rest(ctx router.Ctx, _ []string) { - sess := ctx.UserValue("session").(*session.Session) +func rest(ctx sushi.Ctx) { + sess := ctx.GetCurrentSession() town := ctx.UserValue("town").(*towns.Town) - user := ctx.UserValue("user").(*users.User) + user := ctx.GetCurrentUser().(*users.User) if user.Gold < town.InnCost { sess.SetFlash("error", "You can't afford to stay here tonight.") - ctx.Redirect("/town/inn", 303) + ctx.Redirect("/town/inn") return } user.Gold -= town.InnCost user.HP, user.MP, user.TP = user.MaxHP, user.MaxMP, user.MaxTP + user.Save() components.RenderPage(ctx, town.Name+" Inn", "town/inn.html", map[string]any{ "town": town, @@ -77,14 +103,13 @@ func rest(ctx router.Ctx, _ []string) { }) } -func showShop(ctx router.Ctx, _ []string) { - sess := ctx.UserValue("session").(*session.Session) +func showShop(ctx sushi.Ctx) { + sess := ctx.GetCurrentSession() var errorHTML string - if flash, exists := sess.GetFlash("error"); exists { - if msg, ok := flash.(string); ok { - errorHTML = `
` + msg + "
" - } + errorMsg := sess.GetFlashMessage("error") + if errorMsg != "" { + errorHTML = `
` + errorMsg + "
" } town := ctx.UserValue("town").(*towns.Town) @@ -107,34 +132,29 @@ func showShop(ctx router.Ctx, _ []string) { }) } -func buyItem(ctx router.Ctx, params []string) { - sess := ctx.UserValue("session").(*session.Session) +func buyItem(ctx sushi.Ctx) { + sess := ctx.GetCurrentSession() - id, err := strconv.Atoi(params[0]) - if err != nil { - sess.SetFlash("error", "Error purchasing item; "+err.Error()) - ctx.Redirect("/town/shop", 302) - return - } + id := ctx.Param("id").Int() town := ctx.UserValue("town").(*towns.Town) if !slices.Contains(town.GetShopItems(), id) { sess.SetFlash("error", "The item doesn't exist in this shop.") - ctx.Redirect("/town/shop", 302) + ctx.Redirect("/town/shop") return } item, err := items.Find(id) if err != nil { sess.SetFlash("error", "Error purchasing item; "+err.Error()) - ctx.Redirect("/town/shop", 302) + ctx.Redirect("/town/shop") return } - user := ctx.UserValue("user").(*users.User) + user := ctx.GetCurrentUser().(*users.User) if user.Gold < item.Value { sess.SetFlash("error", "You don't have enough gold to buy "+item.Name) - ctx.Redirect("/town/shop", 302) + ctx.Redirect("/town/shop") return } @@ -142,21 +162,20 @@ func buyItem(ctx router.Ctx, params []string) { actions.UserEquipItem(user, item) user.Save() - ctx.Redirect("/town/shop", 302) + ctx.Redirect("/town/shop") } -func showMaps(ctx router.Ctx, _ []string) { - sess := ctx.UserValue("session").(*session.Session) +func showMaps(ctx sushi.Ctx) { + sess := ctx.GetCurrentSession() var errorHTML string - if flash, exists := sess.GetFlash("error"); exists { - if msg, ok := flash.(string); ok { - errorHTML = `
` + msg + "
" - } + errorMsg := sess.GetFlashMessage("error") + if errorMsg != "" { + errorHTML = `
` + errorMsg + "
" } town := ctx.UserValue("town").(*towns.Town) - user := ctx.UserValue("user").(*users.User) + user := ctx.GetCurrentUser().(*users.User) maplist := helpers.NewOrderedMap[int, Map]() towns, _ := towns.All() @@ -190,37 +209,30 @@ func showMaps(ctx router.Ctx, _ []string) { }) } -func buyMap(ctx router.Ctx, params []string) { - sess := ctx.UserValue("session").(*session.Session) +func buyMap(ctx sushi.Ctx) { + sess := ctx.GetCurrentSession() - id, err := strconv.Atoi(params[0]) - if err != nil { - sess.SetFlash("error", "Error purchasing map; "+err.Error()) - ctx.Redirect("/town/maps", 302) - return - } + id := ctx.Param("id").Int() mapped, err := towns.Find(id) if err != nil { sess.SetFlash("error", "Error purchasing map; "+err.Error()) - ctx.Redirect("/town/maps", 302) + ctx.Redirect("/town/maps") return } - user := ctx.UserValue("user").(*users.User) + user := ctx.GetCurrentUser().(*users.User) if user.Gold < mapped.MapCost { sess.SetFlash("error", "You don't have enough gold to buy the map to "+mapped.Name) - ctx.Redirect("/town/maps", 302) + ctx.Redirect("/town/maps") return } user.Gold -= mapped.MapCost - if user.Towns == "" { - user.Towns = params[0] - } else { - user.Towns += "," + params[0] - } + townIDs := user.GetTownIDs() + townIDs = append(townIDs, id) + user.SetTownIDs(townIDs) user.Save() - ctx.Redirect("/town/maps", 302) + ctx.Redirect("/town/maps") } diff --git a/internal/session/manager.go b/internal/session/manager.go deleted file mode 100644 index 0735a02..0000000 --- a/internal/session/manager.go +++ /dev/null @@ -1,180 +0,0 @@ -package session - -import ( - "encoding/json" - "os" - "sync" - "time" -) - -// SessionManager handles session storage and persistence -type SessionManager struct { - mu sync.RWMutex - sessions map[string]*Session - filePath string -} - -var Manager *SessionManager - -// sessionData represents session data for JSON serialization (excludes ID) -type sessionData struct { - UserID int `json:"user_id"` - ExpiresAt int64 `json:"expires_at"` - Data map[string]any `json:"data"` -} - -// Init initializes the global session manager -func Init(filePath string) { - if Manager != nil { - panic("session manager already initialized") - } - - Manager = &SessionManager{ - sessions: make(map[string]*Session), - filePath: filePath, - } - - Manager.load() -} - -// GetManager returns the global session manager -func GetManager() *SessionManager { - if Manager == nil { - panic("session manager not initialized") - } - return Manager -} - -// Create creates and stores a new session -func (sm *SessionManager) Create(userID int) *Session { - sess := New(userID) - sm.mu.Lock() - sm.sessions[sess.ID] = sess - sm.mu.Unlock() - return sess -} - -// Get retrieves a session by ID -func (sm *SessionManager) Get(sessionID string) (*Session, bool) { - sm.mu.RLock() - sess, exists := sm.sessions[sessionID] - sm.mu.RUnlock() - - if !exists || sess.IsExpired() { - if exists { - sm.Delete(sessionID) - } - return nil, false - } - - return sess, true -} - -// Store saves a session in memory (updates existing or creates new) -func (sm *SessionManager) Store(sess *Session) { - sm.mu.Lock() - sm.sessions[sess.ID] = sess - sm.mu.Unlock() -} - -// Delete removes a session -func (sm *SessionManager) Delete(sessionID string) { - sm.mu.Lock() - delete(sm.sessions, sessionID) - sm.mu.Unlock() -} - -// Cleanup removes expired sessions -func (sm *SessionManager) Cleanup() { - sm.mu.Lock() - for id, sess := range sm.sessions { - if sess.IsExpired() { - delete(sm.sessions, id) - } - } - sm.mu.Unlock() -} - -// Stats returns session statistics -func (sm *SessionManager) Stats() (total, active int) { - sm.mu.RLock() - defer sm.mu.RUnlock() - - total = len(sm.sessions) - for _, sess := range sm.sessions { - if !sess.IsExpired() { - active++ - } - } - return -} - -// load reads sessions from the JSON file -func (sm *SessionManager) load() { - if sm.filePath == "" { - return - } - - data, err := os.ReadFile(sm.filePath) - if err != nil { - return // File doesn't exist or can't be read - } - - var sessionsData map[string]*sessionData - if err := json.Unmarshal(data, &sessionsData); err != nil { - return // Invalid JSON - } - - now := time.Now().Unix() - sm.mu.Lock() - for id, data := range sessionsData { - if data != nil && data.ExpiresAt > now { - sess := &Session{ - ID: id, - UserID: data.UserID, // Make sure we restore the UserID properly - ExpiresAt: data.ExpiresAt, - Data: data.Data, - } - if sess.Data == nil { - sess.Data = make(map[string]any) - } - sm.sessions[id] = sess - } - } - sm.mu.Unlock() -} - -// Save writes sessions to the JSON file -func (sm *SessionManager) Save() error { - if sm.filePath == "" { - return nil - } - - sm.Cleanup() // Remove expired sessions before saving - - sm.mu.RLock() - - // Convert sessions to sessionData (without ID field) - sessionsData := make(map[string]*sessionData, len(sm.sessions)) - for id, sess := range sm.sessions { - sessionsData[id] = &sessionData{ - UserID: sess.UserID, // Save the actual UserID from the struct - ExpiresAt: sess.ExpiresAt, - Data: sess.Data, - } - } - - data, err := json.MarshalIndent(sessionsData, "", "\t") - sm.mu.RUnlock() - - if err != nil { - return err - } - - return os.WriteFile(sm.filePath, data, 0600) -} - -// Close saves sessions and cleans up -func (sm *SessionManager) Close() error { - return sm.Save() -} diff --git a/internal/session/session.go b/internal/session/session.go deleted file mode 100644 index 9915d2e..0000000 --- a/internal/session/session.go +++ /dev/null @@ -1,146 +0,0 @@ -// session.go -package session - -import ( - "crypto/rand" - "encoding/hex" - "time" -) - -const ( - DefaultExpiration = 24 * time.Hour - IDLength = 32 -) - -// Session represents a user session -type Session struct { - ID string `json:"id"` - UserID int `json:"user_id"` // 0 for guest sessions - ExpiresAt int64 `json:"expires_at"` - Data map[string]any `json:"data"` -} - -// New creates a new session -func New(userID int) *Session { - return &Session{ - ID: generateID(), - UserID: userID, - ExpiresAt: time.Now().Add(DefaultExpiration).Unix(), - Data: make(map[string]any), - } -} - -// IsExpired checks if the session has expired -func (s *Session) IsExpired() bool { - return time.Now().Unix() > s.ExpiresAt -} - -// Touch extends the session expiration -func (s *Session) Touch() { - s.ExpiresAt = time.Now().Add(DefaultExpiration).Unix() -} - -// Set stores a value in the session -func (s *Session) Set(key string, value any) { - s.Data[key] = value -} - -// Get retrieves a value from the session -func (s *Session) Get(key string) (any, bool) { - value, exists := s.Data[key] - return value, exists -} - -// Delete removes a value from the session -func (s *Session) Delete(key string) { - delete(s.Data, key) -} - -// SetFlash stores a flash message (consumed on next Get) -func (s *Session) SetFlash(key string, value any) { - s.Set("flash_"+key, value) -} - -// GetFlash retrieves and removes a flash message -func (s *Session) GetFlash(key string) (any, bool) { - flashKey := "flash_" + key - value, exists := s.Get(flashKey) - if exists { - s.Delete(flashKey) - } - return value, exists -} - -// GetFlashMessage retrieves and removes a flash message as string or empty string -func (s *Session) GetFlashMessage(key string) string { - if flash, exists := s.GetFlash(key); exists { - if msg, ok := flash.(string); ok { - return msg - } - } - return "" -} - -// DeleteFlash removes a flash from the session. -func (s *Session) DeleteFlash(key string) { - s.GetFlash(key) -} - -// RegenerateID creates a new session ID and updates storage -func (s *Session) RegenerateID() { - oldID := s.ID - s.ID = generateID() - - if Manager != nil { - Manager.mu.Lock() - delete(Manager.sessions, oldID) - Manager.sessions[s.ID] = s - Manager.mu.Unlock() - } -} - -// SetUserID updates the session's user ID (for login/logout) -func (s *Session) SetUserID(userID int) { - s.UserID = userID -} - -// generateID creates a random session ID -func generateID() string { - bytes := make([]byte, IDLength) - rand.Read(bytes) - return hex.EncodeToString(bytes) -} - -// Package-level convenience functions -func Create(userID int) *Session { - return Manager.Create(userID) -} - -func Get(sessionID string) (*Session, bool) { - return Manager.Get(sessionID) -} - -func Store(sess *Session) { - Manager.Store(sess) -} - -func Delete(sessionID string) { - Manager.Delete(sessionID) -} - -func Cleanup() { - Manager.Cleanup() -} - -func Stats() (total, active int) { - return Manager.Stats() -} - -func Close() error { - return Manager.Close() -} - -// RegenerateID regenerates the session ID for security (package-level convenience) -func RegenerateID(sess *Session) { - sess.RegenerateID() -} diff --git a/internal/store/store.go b/internal/store/store.go deleted file mode 100644 index a3c4394..0000000 --- a/internal/store/store.go +++ /dev/null @@ -1,526 +0,0 @@ -package store - -import ( - "encoding/json" - "fmt" - "maps" - "os" - "path/filepath" - "reflect" - "sort" - "strings" - "sync" -) - -// Validatable interface for entities that can validate themselves -type Validatable interface { - Validate() error -} - -// IndexBuilder function type for building custom indices -type IndexBuilder[T any] func(allItems map[int]*T) any - -// BaseStore provides generic storage with index management -type BaseStore[T any] struct { - items map[int]*T - maxID int - mu sync.RWMutex - itemType reflect.Type - indices map[string]any - indexBuilders map[string]IndexBuilder[T] -} - -// NewBaseStore creates a new base store for type T -func NewBaseStore[T any]() *BaseStore[T] { - var zero T - return &BaseStore[T]{ - items: make(map[int]*T), - maxID: 0, - itemType: reflect.TypeOf(zero), - indices: make(map[string]any), - indexBuilders: make(map[string]IndexBuilder[T]), - } -} - -// RegisterIndex registers an index builder function -func (bs *BaseStore[T]) RegisterIndex(name string, builder IndexBuilder[T]) { - bs.mu.Lock() - defer bs.mu.Unlock() - bs.indexBuilders[name] = builder -} - -// GetIndex retrieves a named index -func (bs *BaseStore[T]) GetIndex(name string) (any, bool) { - bs.mu.RLock() - defer bs.mu.RUnlock() - index, exists := bs.indices[name] - return index, exists -} - -// RebuildIndices rebuilds all registered indices -func (bs *BaseStore[T]) RebuildIndices() { - bs.mu.Lock() - defer bs.mu.Unlock() - bs.rebuildIndicesUnsafe() -} - -func (bs *BaseStore[T]) rebuildIndicesUnsafe() { - allItems := make(map[int]*T, len(bs.items)) - maps.Copy(allItems, bs.items) - - for name, builder := range bs.indexBuilders { - bs.indices[name] = builder(allItems) - } -} - -// AddWithRebuild adds item with validation and index rebuild -func (bs *BaseStore[T]) AddWithRebuild(id int, item *T) error { - bs.mu.Lock() - defer bs.mu.Unlock() - - if validatable, ok := any(item).(Validatable); ok { - if err := validatable.Validate(); err != nil { - return err - } - } - - bs.items[id] = item - if id > bs.maxID { - bs.maxID = id - } - - bs.rebuildIndicesUnsafe() - return nil -} - -// RemoveWithRebuild removes item and rebuilds indices -func (bs *BaseStore[T]) RemoveWithRebuild(id int) { - bs.mu.Lock() - defer bs.mu.Unlock() - delete(bs.items, id) - bs.rebuildIndicesUnsafe() -} - -// UpdateWithRebuild updates item with validation and index rebuild -func (bs *BaseStore[T]) UpdateWithRebuild(id int, item *T) error { - return bs.AddWithRebuild(id, item) -} - -// Common Query Methods - -// Find retrieves an item by ID -func (bs *BaseStore[T]) Find(id int) (*T, bool) { - bs.mu.RLock() - defer bs.mu.RUnlock() - item, exists := bs.items[id] - return item, exists -} - -// AllSorted returns all items using named sorted index -func (bs *BaseStore[T]) AllSorted(indexName string) []*T { - bs.mu.RLock() - defer bs.mu.RUnlock() - - if index, exists := bs.indices[indexName]; exists { - if sortedIDs, ok := index.([]int); ok { - result := make([]*T, 0, len(sortedIDs)) - for _, id := range sortedIDs { - if item, exists := bs.items[id]; exists { - result = append(result, item) - } - } - return result - } - } - - // Fallback: return all items by ID order - ids := make([]int, 0, len(bs.items)) - for id := range bs.items { - ids = append(ids, id) - } - sort.Ints(ids) - - result := make([]*T, 0, len(ids)) - for _, id := range ids { - result = append(result, bs.items[id]) - } - return result -} - -// LookupByIndex finds single item using string lookup index -func (bs *BaseStore[T]) LookupByIndex(indexName, key string) (*T, bool) { - bs.mu.RLock() - defer bs.mu.RUnlock() - - if index, exists := bs.indices[indexName]; exists { - if lookupMap, ok := index.(map[string]int); ok { - if id, found := lookupMap[key]; found { - if item, exists := bs.items[id]; exists { - return item, true - } - } - } - } - return nil, false -} - -// GroupByIndex returns items grouped by key -func (bs *BaseStore[T]) GroupByIndex(indexName string, key any) []*T { - bs.mu.RLock() - defer bs.mu.RUnlock() - - if index, exists := bs.indices[indexName]; exists { - switch groupMap := index.(type) { - case map[int][]int: - if intKey, ok := key.(int); ok { - if ids, found := groupMap[intKey]; found { - result := make([]*T, 0, len(ids)) - for _, id := range ids { - if item, exists := bs.items[id]; exists { - result = append(result, item) - } - } - return result - } - } - case map[string][]int: - if strKey, ok := key.(string); ok { - if ids, found := groupMap[strKey]; found { - result := make([]*T, 0, len(ids)) - for _, id := range ids { - if item, exists := bs.items[id]; exists { - result = append(result, item) - } - } - return result - } - } - } - } - return []*T{} -} - -// FilterByIndex returns items matching filter criteria -func (bs *BaseStore[T]) FilterByIndex(indexName string, filterFunc func(*T) bool) []*T { - bs.mu.RLock() - defer bs.mu.RUnlock() - - var sourceIDs []int - - if index, exists := bs.indices[indexName]; exists { - if sortedIDs, ok := index.([]int); ok { - sourceIDs = sortedIDs - } - } - - if sourceIDs == nil { - for id := range bs.items { - sourceIDs = append(sourceIDs, id) - } - sort.Ints(sourceIDs) - } - - var result []*T - for _, id := range sourceIDs { - if item, exists := bs.items[id]; exists && filterFunc(item) { - result = append(result, item) - } - } - return result -} - -// BuildStringLookupIndex creates string-to-ID mapping -func BuildStringLookupIndex[T any](keyFunc func(*T) string) IndexBuilder[T] { - return func(allItems map[int]*T) any { - index := make(map[string]int) - for id, item := range allItems { - key := keyFunc(item) - index[key] = id - } - return index - } -} - -// BuildCaseInsensitiveLookupIndex creates lowercase string-to-ID mapping -func BuildCaseInsensitiveLookupIndex[T any](keyFunc func(*T) string) IndexBuilder[T] { - return func(allItems map[int]*T) any { - index := make(map[string]int) - for id, item := range allItems { - key := strings.ToLower(keyFunc(item)) - index[key] = id - } - return index - } -} - -// BuildIntGroupIndex creates int-to-[]ID mapping -func BuildIntGroupIndex[T any](keyFunc func(*T) int) IndexBuilder[T] { - return func(allItems map[int]*T) any { - index := make(map[int][]int) - for id, item := range allItems { - key := keyFunc(item) - index[key] = append(index[key], id) - } - - // Sort each group by ID - for key := range index { - sort.Ints(index[key]) - } - - return index - } -} - -// BuildStringGroupIndex creates string-to-[]ID mapping -func BuildStringGroupIndex[T any](keyFunc func(*T) string) IndexBuilder[T] { - return func(allItems map[int]*T) any { - index := make(map[string][]int) - for id, item := range allItems { - key := keyFunc(item) - index[key] = append(index[key], id) - } - - // Sort each group by ID - for key := range index { - sort.Ints(index[key]) - } - - return index - } -} - -// BuildSortedListIndex creates sorted []ID list -func BuildSortedListIndex[T any](sortFunc func(*T, *T) bool) IndexBuilder[T] { - return func(allItems map[int]*T) any { - ids := make([]int, 0, len(allItems)) - for id := range allItems { - ids = append(ids, id) - } - - sort.Slice(ids, func(i, j int) bool { - return sortFunc(allItems[ids[i]], allItems[ids[j]]) - }) - - return ids - } -} - -// NewSingleton creates singleton store pattern with sync.Once -func NewSingleton[S any](initFunc func() *S) func() *S { - var store *S - var once sync.Once - - return func() *S { - once.Do(func() { - store = initFunc() - }) - return store - } -} - -// GetNextID returns the next available ID atomically -func (bs *BaseStore[T]) GetNextID() int { - bs.mu.Lock() - defer bs.mu.Unlock() - bs.maxID++ - return bs.maxID -} - -// GetByID retrieves an item by ID -func (bs *BaseStore[T]) GetByID(id int) (*T, bool) { - return bs.Find(id) -} - -// Add adds an item to the store -func (bs *BaseStore[T]) Add(id int, item *T) { - bs.mu.Lock() - defer bs.mu.Unlock() - bs.items[id] = item - if id > bs.maxID { - bs.maxID = id - } -} - -// Remove removes an item from the store -func (bs *BaseStore[T]) Remove(id int) { - bs.mu.Lock() - defer bs.mu.Unlock() - delete(bs.items, id) -} - -// GetAll returns all items -func (bs *BaseStore[T]) GetAll() map[int]*T { - bs.mu.RLock() - defer bs.mu.RUnlock() - result := make(map[int]*T, len(bs.items)) - maps.Copy(result, bs.items) - return result -} - -// Clear removes all items -func (bs *BaseStore[T]) Clear() { - bs.mu.Lock() - defer bs.mu.Unlock() - bs.items = make(map[int]*T) - bs.maxID = 0 - bs.rebuildIndicesUnsafe() -} - -// LoadFromJSON loads items from JSON using reflection -func (bs *BaseStore[T]) LoadFromJSON(filename string) error { - bs.mu.Lock() - defer bs.mu.Unlock() - - data, err := os.ReadFile(filename) - if err != nil { - if os.IsNotExist(err) { - return nil - } - return fmt.Errorf("failed to read JSON: %w", err) - } - - if len(data) == 0 { - return nil - } - - // Create slice of pointers to T - sliceType := reflect.SliceOf(reflect.PointerTo(bs.itemType)) - slicePtr := reflect.New(sliceType) - - if err := json.Unmarshal(data, slicePtr.Interface()); err != nil { - return fmt.Errorf("failed to unmarshal JSON: %w", err) - } - - // Clear existing data - bs.items = make(map[int]*T) - bs.maxID = 0 - - // Extract items using reflection - slice := slicePtr.Elem() - for i := 0; i < slice.Len(); i++ { - item := slice.Index(i).Interface().(*T) - - // Get ID using reflection - itemValue := reflect.ValueOf(item).Elem() - idField := itemValue.FieldByName("ID") - if !idField.IsValid() { - return fmt.Errorf("item type must have an ID field") - } - - id := int(idField.Int()) - bs.items[id] = item - if id > bs.maxID { - bs.maxID = id - } - } - - return nil -} - -// SaveToJSON saves items to JSON atomically with consistent ID ordering -func (bs *BaseStore[T]) SaveToJSON(filename string) error { - bs.mu.RLock() - defer bs.mu.RUnlock() - - // Get sorted IDs for consistent ordering - ids := make([]int, 0, len(bs.items)) - for id := range bs.items { - ids = append(ids, id) - } - sort.Ints(ids) - - // Build items slice in ID order - items := make([]*T, 0, len(bs.items)) - for _, id := range ids { - items = append(items, bs.items[id]) - } - - data, err := json.MarshalIndent(items, "", "\t") - if err != nil { - return fmt.Errorf("failed to marshal to JSON: %w", err) - } - - // Atomic write - tempFile := filename + ".tmp" - if err := os.WriteFile(tempFile, data, 0644); err != nil { - return fmt.Errorf("failed to write temp JSON: %w", err) - } - - if err := os.Rename(tempFile, filename); err != nil { - os.Remove(tempFile) - return fmt.Errorf("failed to rename temp JSON: %w", err) - } - - return nil -} - -// LoadData loads from JSON file or starts empty -func (bs *BaseStore[T]) LoadData(dataPath string) error { - if err := bs.LoadFromJSON(dataPath); err != nil { - if os.IsNotExist(err) { - fmt.Println("No existing data found, starting with empty store") - return nil - } - return fmt.Errorf("failed to load from JSON: %w", err) - } - - fmt.Printf("Loaded %d items from %s\n", len(bs.items), dataPath) - bs.RebuildIndices() // Rebuild indices after loading - return nil -} - -// SaveData saves to JSON file -func (bs *BaseStore[T]) SaveData(dataPath string) error { - // Ensure directory exists - dataDir := filepath.Dir(dataPath) - if err := os.MkdirAll(dataDir, 0755); err != nil { - return fmt.Errorf("failed to create data directory: %w", err) - } - - if err := bs.SaveToJSON(dataPath); err != nil { - return fmt.Errorf("failed to save to JSON: %w", err) - } - - fmt.Printf("Saved %d items to %s\n", len(bs.items), dataPath) - return nil -} - -// BuildFilteredIntGroupIndex creates int-to-[]ID mapping for items passing filter -func BuildFilteredIntGroupIndex[T any](filterFunc func(*T) bool, keyFunc func(*T) int) IndexBuilder[T] { - return func(allItems map[int]*T) any { - index := make(map[int][]int) - for id, item := range allItems { - if filterFunc(item) { - key := keyFunc(item) - index[key] = append(index[key], id) - } - } - - // Sort each group by ID - for key := range index { - sort.Ints(index[key]) - } - - return index - } -} - -// BuildFilteredStringGroupIndex creates string-to-[]ID mapping for items passing filter -func BuildFilteredStringGroupIndex[T any](filterFunc func(*T) bool, keyFunc func(*T) string) IndexBuilder[T] { - return func(allItems map[int]*T) any { - index := make(map[string][]int) - for id, item := range allItems { - if filterFunc(item) { - key := keyFunc(item) - index[key] = append(index[key], id) - } - } - - // Sort each group by ID - for key := range index { - sort.Ints(index[key]) - } - - return index - } -} diff --git a/internal/template/template.go b/internal/template/template.go index 47a0120..384f386 100644 --- a/internal/template/template.go +++ b/internal/template/template.go @@ -8,7 +8,7 @@ import ( "strings" "time" - "github.com/valyala/fasthttp" + sushi "git.sharkk.net/Sharkk/Sushi" ) type Template struct { @@ -43,14 +43,12 @@ func (t *Template) RenderNamed(data map[string]any) string { return result } -func (t *Template) WriteTo(ctx *fasthttp.RequestCtx, data any) { - var result string - +func (t *Template) Render(data any) string { switch v := data.(type) { case map[string]any: - result = t.RenderNamed(v) + return t.RenderNamed(v) case []any: - result = t.RenderPositional(v...) + return t.RenderPositional(v...) default: rv := reflect.ValueOf(data) if rv.Kind() == reflect.Slice { @@ -58,14 +56,17 @@ func (t *Template) WriteTo(ctx *fasthttp.RequestCtx, data any) { for i := 0; i < rv.Len(); i++ { args[i] = rv.Index(i).Interface() } - result = t.RenderPositional(args...) + return t.RenderPositional(args...) } else { - result = t.RenderPositional(data) + return t.RenderPositional(data) } } +} - ctx.SetContentType("text/html; charset=utf-8") - ctx.WriteString(result) +func (t *Template) WriteTo(ctx sushi.Ctx, data any) error { + result := t.Render(data) + ctx.SendHTML(result) + return nil } func (t *Template) processBlocks(content string, blocks map[string]string) string { diff --git a/main.go b/main.go index 5b06c64..dec7d5c 100644 --- a/main.go +++ b/main.go @@ -9,9 +9,6 @@ import ( "path/filepath" "syscall" - "dk/internal/auth" - "dk/internal/csrf" - "dk/internal/middleware" "dk/internal/models/babble" "dk/internal/models/control" "dk/internal/models/drops" @@ -23,12 +20,15 @@ import ( "dk/internal/models/spells" "dk/internal/models/towns" "dk/internal/models/users" - "dk/internal/router" "dk/internal/routes" - "dk/internal/session" "dk/internal/template" - "github.com/valyala/fasthttp" + nigiri "git.sharkk.net/Sharkk/Nigiri" + sushi "git.sharkk.net/Sharkk/Sushi" + "git.sharkk.net/Sharkk/Sushi/auth" + "git.sharkk.net/Sharkk/Sushi/csrf" + "git.sharkk.net/Sharkk/Sushi/session" + "git.sharkk.net/Sharkk/Sushi/timing" ) func main() { @@ -53,106 +53,6 @@ func main() { } } -func loadModels(dataDir string) error { - if err := os.MkdirAll(dataDir, 0755); err != nil { - return fmt.Errorf("failed to create data directory: %w", err) - } - - if err := users.LoadData(filepath.Join(dataDir, "users.json")); err != nil { - return fmt.Errorf("failed to load users data: %w", err) - } - - if err := towns.LoadData(filepath.Join(dataDir, "towns.json")); err != nil { - return fmt.Errorf("failed to load towns data: %w", err) - } - - if err := spells.LoadData(filepath.Join(dataDir, "spells.json")); err != nil { - return fmt.Errorf("failed to load spells data: %w", err) - } - - if err := news.LoadData(filepath.Join(dataDir, "news.json")); err != nil { - return fmt.Errorf("failed to load news data: %w", err) - } - - if err := monsters.LoadData(filepath.Join(dataDir, "monsters.json")); err != nil { - return fmt.Errorf("failed to load monsters data: %w", err) - } - - if err := items.LoadData(filepath.Join(dataDir, "items.json")); err != nil { - return fmt.Errorf("failed to load items data: %w", err) - } - - if err := forum.LoadData(filepath.Join(dataDir, "forum.json")); err != nil { - return fmt.Errorf("failed to load forum data: %w", err) - } - - if err := drops.LoadData(filepath.Join(dataDir, "drops.json")); err != nil { - return fmt.Errorf("failed to load drops data: %w", err) - } - - if err := babble.LoadData(filepath.Join(dataDir, "babble.json")); err != nil { - return fmt.Errorf("failed to load babble data: %w", err) - } - - if err := control.Load(filepath.Join(dataDir, "control.json")); err != nil { - return fmt.Errorf("failed to load control data: %w", err) - } - - if err := fights.LoadData(filepath.Join(dataDir, "fights.json")); err != nil { - return fmt.Errorf("failed to load fights data: %w", err) - } - - return nil -} - -func saveModels(dataDir string) error { - if err := users.SaveData(filepath.Join(dataDir, "users.json")); err != nil { - return fmt.Errorf("failed to save users data: %w", err) - } - - if err := towns.SaveData(filepath.Join(dataDir, "towns.json")); err != nil { - return fmt.Errorf("failed to save towns data: %w", err) - } - - if err := spells.SaveData(filepath.Join(dataDir, "spells.json")); err != nil { - return fmt.Errorf("failed to save spells data: %w", err) - } - - if err := news.SaveData(filepath.Join(dataDir, "news.json")); err != nil { - return fmt.Errorf("failed to save news data: %w", err) - } - - if err := monsters.SaveData(filepath.Join(dataDir, "monsters.json")); err != nil { - return fmt.Errorf("failed to save monsters data: %w", err) - } - - if err := items.SaveData(filepath.Join(dataDir, "items.json")); err != nil { - return fmt.Errorf("failed to save items data: %w", err) - } - - if err := forum.SaveData(filepath.Join(dataDir, "forum.json")); err != nil { - return fmt.Errorf("failed to save forum data: %w", err) - } - - if err := drops.SaveData(filepath.Join(dataDir, "drops.json")); err != nil { - return fmt.Errorf("failed to save drops data: %w", err) - } - - if err := babble.SaveData(filepath.Join(dataDir, "babble.json")); err != nil { - return fmt.Errorf("failed to save babble data: %w", err) - } - - if err := control.Save(); err != nil { - return fmt.Errorf("failed to save control data: %w", err) - } - - if err := fights.SaveData(filepath.Join(dataDir, "fights.json")); err != nil { - return fmt.Errorf("failed to save fights data: %w", err) - } - - return nil -} - func startServer(port string) { fmt.Println("Starting Dragon Knight server...") if err := start(port); err != nil { @@ -168,94 +68,85 @@ func start(port string) error { template.InitializeCache(cwd) - if err := loadModels(filepath.Join(cwd, "data")); err != nil { - return fmt.Errorf("failed to load models: %w", err) + db := nigiri.NewCollection(filepath.Join(cwd, "data")) + if err := setupStores(db); err != nil { + return fmt.Errorf("failed to setup Nigiri stores: %w", err) } - session.Init(filepath.Join(cwd, "data/_sessions.json")) + app := sushi.New() + sushi.InitSessions(filepath.Join(cwd, "data/_sessions.json")) - r := router.New() - r.Use(middleware.Timing()) - r.Use(auth.Middleware()) - r.Use(csrf.Middleware()) + app.Use(session.Middleware()) + app.Use(auth.Middleware(getUserByID)) + app.Use(csrf.Middleware()) + app.Use(timing.Middleware()) - r.Get("/", routes.Index) + app.Get("/", routes.Index) - actions := r.Group("") - actions.Use(auth.RequireAuth()) - actions.Get("/explore", routes.Explore) - actions.Post("/move", routes.Move) - actions.Get("/teleport/:to", routes.Teleport) + protected := app.Group("") + protected.Use(auth.RequireAuth("/login")) + protected.Get("/explore", routes.Explore) + protected.Post("/move", routes.Move) + protected.Get("/teleport/:to", routes.Teleport) - routes.RegisterAuthRoutes(r) - routes.RegisterTownRoutes(r) - routes.RegisterFightRoutes(r) + routes.RegisterAuthRoutes(app) + routes.RegisterTownRoutes(app) + routes.RegisterFightRoutes(app) - // Use current working directory for static files - assetsDir := filepath.Join(cwd, "assets") - - // Static file server for /assets - fs := &fasthttp.FS{ - Root: assetsDir, - Compress: false, - } - assetsHandler := fs.NewRequestHandler() - - // Combined handler - requestHandler := func(ctx *fasthttp.RequestCtx) { - path := string(ctx.Path()) - - // Handle static assets - strip /assets prefix - if len(path) >= 7 && path[:7] == "/assets" { - // Strip the /assets prefix for the file system handler - originalPath := ctx.Path() - ctx.Request.URI().SetPath(path[7:]) // Remove "/assets" prefix - assetsHandler(ctx) - ctx.Request.URI().SetPathBytes(originalPath) // Restore original path - return - } - - // Handle routes - r.ServeHTTP(ctx) - } + app.Get("/assets/*path", sushi.Static(cwd)) addr := ":" + port log.Printf("Server starting on %s", addr) - // Setup graceful shutdown - server := &fasthttp.Server{ - Handler: requestHandler, - } - - // Channel to listen for interrupt signal c := make(chan os.Signal, 1) signal.Notify(c, os.Interrupt, syscall.SIGTERM) - // Start server in a goroutine go func() { - if err := server.ListenAndServe(addr); err != nil { - log.Printf("Server error: %v", err) - } + app.Listen(addr) }() - // Wait for interrupt signal <-c - log.Println("Received shutdown signal, shutting down gracefully...") + log.Println("\nReceived shutdown signal, shutting down gracefully...") - // Save all model data before shutdown - log.Println("Saving model data...") - if err := saveModels(filepath.Join(cwd, "data")); err != nil { - log.Printf("Error saving model data: %v", err) + log.Println("Saving database...") + if err := db.Save(); err != nil { + log.Printf("Error saving database: %v", err) } - // Save sessions before shutdown log.Println("Saving sessions...") - if err := session.Close(); err != nil { - log.Printf("Error saving sessions: %v", err) - } + sushi.SaveSessions() - // FastHTTP doesn't have a graceful Shutdown method like net/http - // We just let the server stop naturally when the main function exits log.Println("Server stopped") return nil } + +func setupStores(db *nigiri.Collection) error { + users.Init(db) + towns.Init(db) + spells.Init(db) + news.Init(db) + monsters.Init(db) + items.Init(db) + forum.Init(db) + drops.Init(db) + babble.Init(db) + fights.Init(db) + control.Init(db) + + db.Add("users", users.GetStore()) + db.Add("towns", towns.GetStore()) + db.Add("spells", spells.GetStore()) + db.Add("news", news.GetStore()) + db.Add("monsters", monsters.GetStore()) + db.Add("items", items.GetStore()) + db.Add("forum", forum.GetStore()) + db.Add("drops", drops.GetStore()) + db.Add("babble", babble.GetStore()) + db.Add("fights", fights.GetStore()) + + return nil +} + +func getUserByID(userID int) any { + return users.GetByID(userID) +}