Compare commits

...

2 Commits

Author SHA1 Message Date
820bc87418 fully migrate db ops to singleton pattern 2025-08-09 19:04:08 -05:00
b1f436585e Simplify auth package 2025-08-09 18:12:23 -05:00
25 changed files with 407 additions and 546 deletions

View File

@ -1,7 +1,6 @@
package auth package auth
import ( import (
"dk/internal/database"
"dk/internal/password" "dk/internal/password"
"dk/internal/users" "dk/internal/users"
) )
@ -9,29 +8,28 @@ import (
// Manager is the global singleton instance // Manager is the global singleton instance
var Manager *AuthManager var Manager *AuthManager
// User is a simplified User struct for auth purposes
type User struct { type User struct {
ID int ID int
Username string Username string
Email string Email string
} }
// AuthManager is a wrapper for the session store to add
// authentication tools over the store itself
type AuthManager struct { type AuthManager struct {
sessionStore *SessionStore store *SessionStore
db *database.DB
} }
func NewAuthManager(db *database.DB, sessionsFilePath string) *AuthManager { // Init initializes the global auth manager (auth.Manager)
return &AuthManager{ func Init(sessionsFilePath string) {
sessionStore: NewSessionStore(sessionsFilePath), Manager = &AuthManager{
db: db, store: NewSessionStore(sessionsFilePath),
} }
} }
// InitializeManager initializes the global Manager singleton // Authenticate checks for the usernaname or email, then verifies the plain password
func InitializeManager(db *database.DB, sessionsFilePath string) { // against the stored hash.
Manager = NewAuthManager(db, sessionsFilePath)
}
func (am *AuthManager) Authenticate(usernameOrEmail, plainPassword string) (*User, error) { func (am *AuthManager) Authenticate(usernameOrEmail, plainPassword string) (*User, error) {
var user *users.User var user *users.User
var err error var err error
@ -39,14 +37,12 @@ func (am *AuthManager) Authenticate(usernameOrEmail, plainPassword string) (*Use
// Try to find user by username first // Try to find user by username first
user, err = users.GetByUsername(usernameOrEmail) user, err = users.GetByUsername(usernameOrEmail)
if err != nil { if err != nil {
// Try by email if username lookup failed
user, err = users.GetByEmail(usernameOrEmail) user, err = users.GetByEmail(usernameOrEmail)
if err != nil { if err != nil {
return nil, err return nil, err
} }
} }
// Verify password
isValid, err := password.Verify(plainPassword, user.Password) isValid, err := password.Verify(plainPassword, user.Password)
if err != nil { if err != nil {
return nil, err return nil, err
@ -63,31 +59,27 @@ func (am *AuthManager) Authenticate(usernameOrEmail, plainPassword string) (*Use
} }
func (am *AuthManager) CreateSession(user *User) *Session { func (am *AuthManager) CreateSession(user *User) *Session {
return am.sessionStore.Create(user.ID, user.Username, user.Email) return am.store.Create(user.ID, user.Username, user.Email)
} }
func (am *AuthManager) GetSession(sessionID string) (*Session, bool) { func (am *AuthManager) GetSession(sessionID string) (*Session, bool) {
return am.sessionStore.Get(sessionID) return am.store.Get(sessionID)
} }
func (am *AuthManager) UpdateSession(sessionID string) bool { func (am *AuthManager) UpdateSession(sessionID string) bool {
return am.sessionStore.Update(sessionID) return am.store.Update(sessionID)
} }
func (am *AuthManager) DeleteSession(sessionID string) { func (am *AuthManager) DeleteSession(sessionID string) {
am.sessionStore.Delete(sessionID) am.store.Delete(sessionID)
} }
func (am *AuthManager) SessionStats() (total, active int) { func (am *AuthManager) SessionStats() (total, active int) {
return am.sessionStore.Stats() return am.store.Stats()
}
func (am *AuthManager) DB() *database.DB {
return am.db
} }
func (am *AuthManager) Close() error { func (am *AuthManager) Close() error {
return am.sessionStore.Close() return am.store.Close()
} }
var ( var (

View File

@ -1,103 +1,29 @@
package auth package auth
import ( import (
"dk/internal/cookies"
"dk/internal/utils"
"time" "time"
"github.com/valyala/fasthttp" "github.com/valyala/fasthttp"
) )
type CookieOptions struct {
Name string
Value string
Path string
Domain string
Expires time.Time
MaxAge int
Secure bool
HTTPOnly bool
SameSite string
}
func SetSecureCookie(ctx *fasthttp.RequestCtx, opts CookieOptions) {
cookie := &fasthttp.Cookie{}
cookie.SetKey(opts.Name)
cookie.SetValue(opts.Value)
if opts.Path != "" {
cookie.SetPath(opts.Path)
} else {
cookie.SetPath("/")
}
if opts.Domain != "" {
cookie.SetDomain(opts.Domain)
}
if !opts.Expires.IsZero() {
cookie.SetExpire(opts.Expires)
}
if opts.MaxAge > 0 {
cookie.SetMaxAge(opts.MaxAge)
}
cookie.SetSecure(opts.Secure)
cookie.SetHTTPOnly(opts.HTTPOnly)
switch opts.SameSite {
case "strict":
cookie.SetSameSite(fasthttp.CookieSameSiteStrictMode)
case "lax":
cookie.SetSameSite(fasthttp.CookieSameSiteLaxMode)
case "none":
cookie.SetSameSite(fasthttp.CookieSameSiteNoneMode)
default:
cookie.SetSameSite(fasthttp.CookieSameSiteLaxMode)
}
ctx.Response.Header.SetCookie(cookie)
}
func GetCookie(ctx *fasthttp.RequestCtx, name string) string {
return string(ctx.Request.Header.Cookie(name))
}
func DeleteCookie(ctx *fasthttp.RequestCtx, name string) {
SetSecureCookie(ctx, CookieOptions{
Name: name,
Value: "",
Path: "/",
Expires: time.Unix(0, 0),
MaxAge: -1,
HTTPOnly: true,
Secure: true,
SameSite: "lax",
})
}
func SetSessionCookie(ctx *fasthttp.RequestCtx, sessionID string) { func SetSessionCookie(ctx *fasthttp.RequestCtx, sessionID string) {
SetSecureCookie(ctx, CookieOptions{ cookies.SetSecureCookie(ctx, cookies.CookieOptions{
Name: SessionCookieName, Name: SessionCookieName,
Value: sessionID, Value: sessionID,
Path: "/", Path: "/",
Expires: time.Now().Add(DefaultExpiration), Expires: time.Now().Add(DefaultExpiration),
HTTPOnly: true, HTTPOnly: true,
Secure: isHTTPS(ctx), Secure: utils.IsHTTPS(ctx),
SameSite: "lax", SameSite: "lax",
}) })
} }
func GetSessionCookie(ctx *fasthttp.RequestCtx) string { func GetSessionCookie(ctx *fasthttp.RequestCtx) string {
return GetCookie(ctx, SessionCookieName) return cookies.GetCookie(ctx, SessionCookieName)
} }
func DeleteSessionCookie(ctx *fasthttp.RequestCtx) { func DeleteSessionCookie(ctx *fasthttp.RequestCtx) {
DeleteCookie(ctx, SessionCookieName) cookies.DeleteCookie(ctx, SessionCookieName)
} }
func isHTTPS(ctx *fasthttp.RequestCtx) bool {
return ctx.IsTLS() ||
string(ctx.Request.Header.Peek("X-Forwarded-Proto")) == "https" ||
string(ctx.Request.Header.Peek("X-Forwarded-Scheme")) == "https"
}

View File

@ -16,16 +16,14 @@ type Babble struct {
Posted int64 `json:"posted"` Posted int64 `json:"posted"`
Author string `json:"author"` Author string `json:"author"`
Babble string `json:"babble"` Babble string `json:"babble"`
db *database.DB
} }
// Find retrieves a babble message by ID // Find retrieves a babble message by ID
func Find(db *database.DB, id int) (*Babble, error) { func Find(id int) (*Babble, error) {
babble := &Babble{db: db} babble := &Babble{}
query := "SELECT id, posted, author, babble FROM babble WHERE id = ?" query := "SELECT id, posted, author, babble FROM babble WHERE id = ?"
err := db.Query(query, func(stmt *sqlite.Stmt) error { err := database.Query(query, func(stmt *sqlite.Stmt) error {
babble.ID = stmt.ColumnInt(0) babble.ID = stmt.ColumnInt(0)
babble.Posted = stmt.ColumnInt64(1) babble.Posted = stmt.ColumnInt64(1)
babble.Author = stmt.ColumnText(2) babble.Author = stmt.ColumnText(2)
@ -45,17 +43,16 @@ func Find(db *database.DB, id int) (*Babble, error) {
} }
// All retrieves all babble messages ordered by posted time (newest first) // All retrieves all babble messages ordered by posted time (newest first)
func All(db *database.DB) ([]*Babble, error) { func All() ([]*Babble, error) {
var babbles []*Babble var babbles []*Babble
query := "SELECT id, posted, author, babble FROM babble ORDER BY posted DESC, id DESC" query := "SELECT id, posted, author, babble FROM babble ORDER BY posted DESC, id DESC"
err := db.Query(query, func(stmt *sqlite.Stmt) error { err := database.Query(query, func(stmt *sqlite.Stmt) error {
babble := &Babble{ babble := &Babble{
ID: stmt.ColumnInt(0), ID: stmt.ColumnInt(0),
Posted: stmt.ColumnInt64(1), Posted: stmt.ColumnInt64(1),
Author: stmt.ColumnText(2), Author: stmt.ColumnText(2),
Babble: stmt.ColumnText(3), Babble: stmt.ColumnText(3),
db: db,
} }
babbles = append(babbles, babble) babbles = append(babbles, babble)
return nil return nil
@ -69,17 +66,16 @@ func All(db *database.DB) ([]*Babble, error) {
} }
// ByAuthor retrieves babble messages by a specific author // ByAuthor retrieves babble messages by a specific author
func ByAuthor(db *database.DB, author string) ([]*Babble, error) { func ByAuthor(author string) ([]*Babble, error) {
var babbles []*Babble var babbles []*Babble
query := "SELECT id, posted, author, babble FROM babble WHERE LOWER(author) = LOWER(?) ORDER BY posted DESC, id DESC" query := "SELECT id, posted, author, babble FROM babble WHERE LOWER(author) = LOWER(?) ORDER BY posted DESC, id DESC"
err := db.Query(query, func(stmt *sqlite.Stmt) error { err := database.Query(query, func(stmt *sqlite.Stmt) error {
babble := &Babble{ babble := &Babble{
ID: stmt.ColumnInt(0), ID: stmt.ColumnInt(0),
Posted: stmt.ColumnInt64(1), Posted: stmt.ColumnInt64(1),
Author: stmt.ColumnText(2), Author: stmt.ColumnText(2),
Babble: stmt.ColumnText(3), Babble: stmt.ColumnText(3),
db: db,
} }
babbles = append(babbles, babble) babbles = append(babbles, babble)
return nil return nil
@ -93,17 +89,16 @@ func ByAuthor(db *database.DB, author string) ([]*Babble, error) {
} }
// Recent retrieves the most recent babble messages (limited by count) // Recent retrieves the most recent babble messages (limited by count)
func Recent(db *database.DB, limit int) ([]*Babble, error) { func Recent(limit int) ([]*Babble, error) {
var babbles []*Babble var babbles []*Babble
query := "SELECT id, posted, author, babble FROM babble ORDER BY posted DESC, id DESC LIMIT ?" query := "SELECT id, posted, author, babble FROM babble ORDER BY posted DESC, id DESC LIMIT ?"
err := db.Query(query, func(stmt *sqlite.Stmt) error { err := database.Query(query, func(stmt *sqlite.Stmt) error {
babble := &Babble{ babble := &Babble{
ID: stmt.ColumnInt(0), ID: stmt.ColumnInt(0),
Posted: stmt.ColumnInt64(1), Posted: stmt.ColumnInt64(1),
Author: stmt.ColumnText(2), Author: stmt.ColumnText(2),
Babble: stmt.ColumnText(3), Babble: stmt.ColumnText(3),
db: db,
} }
babbles = append(babbles, babble) babbles = append(babbles, babble)
return nil return nil
@ -117,17 +112,16 @@ func Recent(db *database.DB, limit int) ([]*Babble, error) {
} }
// Since retrieves babble messages since a specific timestamp // Since retrieves babble messages since a specific timestamp
func Since(db *database.DB, since int64) ([]*Babble, error) { func Since(since int64) ([]*Babble, error) {
var babbles []*Babble var babbles []*Babble
query := "SELECT id, posted, author, babble FROM babble WHERE posted >= ? ORDER BY posted DESC, id DESC" query := "SELECT id, posted, author, babble FROM babble WHERE posted >= ? ORDER BY posted DESC, id DESC"
err := db.Query(query, func(stmt *sqlite.Stmt) error { err := database.Query(query, func(stmt *sqlite.Stmt) error {
babble := &Babble{ babble := &Babble{
ID: stmt.ColumnInt(0), ID: stmt.ColumnInt(0),
Posted: stmt.ColumnInt64(1), Posted: stmt.ColumnInt64(1),
Author: stmt.ColumnText(2), Author: stmt.ColumnText(2),
Babble: stmt.ColumnText(3), Babble: stmt.ColumnText(3),
db: db,
} }
babbles = append(babbles, babble) babbles = append(babbles, babble)
return nil return nil
@ -141,17 +135,16 @@ func Since(db *database.DB, since int64) ([]*Babble, error) {
} }
// Between retrieves babble messages between two timestamps (inclusive) // Between retrieves babble messages between two timestamps (inclusive)
func Between(db *database.DB, start, end int64) ([]*Babble, error) { func Between(start, end int64) ([]*Babble, error) {
var babbles []*Babble var babbles []*Babble
query := "SELECT id, posted, author, babble FROM babble WHERE posted >= ? AND posted <= ? ORDER BY posted DESC, id DESC" query := "SELECT id, posted, author, babble FROM babble WHERE posted >= ? AND posted <= ? ORDER BY posted DESC, id DESC"
err := db.Query(query, func(stmt *sqlite.Stmt) error { err := database.Query(query, func(stmt *sqlite.Stmt) error {
babble := &Babble{ babble := &Babble{
ID: stmt.ColumnInt(0), ID: stmt.ColumnInt(0),
Posted: stmt.ColumnInt64(1), Posted: stmt.ColumnInt64(1),
Author: stmt.ColumnText(2), Author: stmt.ColumnText(2),
Babble: stmt.ColumnText(3), Babble: stmt.ColumnText(3),
db: db,
} }
babbles = append(babbles, babble) babbles = append(babbles, babble)
return nil return nil
@ -165,19 +158,18 @@ func Between(db *database.DB, start, end int64) ([]*Babble, error) {
} }
// Search retrieves babble messages containing the search term (case-insensitive) // Search retrieves babble messages containing the search term (case-insensitive)
func Search(db *database.DB, term string) ([]*Babble, error) { func Search(term string) ([]*Babble, error) {
var babbles []*Babble var babbles []*Babble
query := "SELECT id, posted, author, babble FROM babble WHERE LOWER(babble) LIKE LOWER(?) ORDER BY posted DESC, id DESC" query := "SELECT id, posted, author, babble FROM babble WHERE LOWER(babble) LIKE LOWER(?) ORDER BY posted DESC, id DESC"
searchTerm := "%" + term + "%" searchTerm := "%" + term + "%"
err := db.Query(query, func(stmt *sqlite.Stmt) error { err := database.Query(query, func(stmt *sqlite.Stmt) error {
babble := &Babble{ babble := &Babble{
ID: stmt.ColumnInt(0), ID: stmt.ColumnInt(0),
Posted: stmt.ColumnInt64(1), Posted: stmt.ColumnInt64(1),
Author: stmt.ColumnText(2), Author: stmt.ColumnText(2),
Babble: stmt.ColumnText(3), Babble: stmt.ColumnText(3),
db: db,
} }
babbles = append(babbles, babble) babbles = append(babbles, babble)
return nil return nil
@ -191,17 +183,16 @@ func Search(db *database.DB, term string) ([]*Babble, error) {
} }
// RecentByAuthor retrieves recent messages from a specific author // RecentByAuthor retrieves recent messages from a specific author
func RecentByAuthor(db *database.DB, author string, limit int) ([]*Babble, error) { func RecentByAuthor(author string, limit int) ([]*Babble, error) {
var babbles []*Babble var babbles []*Babble
query := "SELECT id, posted, author, babble FROM babble WHERE LOWER(author) = LOWER(?) ORDER BY posted DESC, id DESC LIMIT ?" query := "SELECT id, posted, author, babble FROM babble WHERE LOWER(author) = LOWER(?) ORDER BY posted DESC, id DESC LIMIT ?"
err := db.Query(query, func(stmt *sqlite.Stmt) error { err := database.Query(query, func(stmt *sqlite.Stmt) error {
babble := &Babble{ babble := &Babble{
ID: stmt.ColumnInt(0), ID: stmt.ColumnInt(0),
Posted: stmt.ColumnInt64(1), Posted: stmt.ColumnInt64(1),
Author: stmt.ColumnText(2), Author: stmt.ColumnText(2),
Babble: stmt.ColumnText(3), Babble: stmt.ColumnText(3),
db: db,
} }
babbles = append(babbles, babble) babbles = append(babbles, babble)
return nil return nil
@ -221,7 +212,7 @@ func (b *Babble) Save() error {
} }
query := `UPDATE babble SET posted = ?, author = ?, babble = ? WHERE id = ?` query := `UPDATE babble SET posted = ?, author = ?, babble = ? WHERE id = ?`
return b.db.Exec(query, b.Posted, b.Author, b.Babble, b.ID) return database.Exec(query, b.Posted, b.Author, b.Babble, b.ID)
} }
// Delete removes the babble message from the database // Delete removes the babble message from the database
@ -230,8 +221,7 @@ func (b *Babble) Delete() error {
return fmt.Errorf("cannot delete babble without ID") return fmt.Errorf("cannot delete babble without ID")
} }
query := "DELETE FROM babble WHERE id = ?" return database.Exec("DELETE FROM babble WHERE id = ?", b.ID)
return b.db.Exec(query, b.ID)
} }
// PostedTime returns the posted timestamp as a time.Time // PostedTime returns the posted timestamp as a time.Time
@ -264,11 +254,11 @@ func (b *Babble) Preview(maxLength int) string {
if len(b.Babble) <= maxLength { if len(b.Babble) <= maxLength {
return b.Babble return b.Babble
} }
if maxLength < 3 { if maxLength < 3 {
return b.Babble[:maxLength] return b.Babble[:maxLength]
} }
return b.Babble[:maxLength-3] + "..." return b.Babble[:maxLength-3] + "..."
} }
@ -277,11 +267,11 @@ func (b *Babble) WordCount() int {
if b.Babble == "" { if b.Babble == "" {
return 0 return 0
} }
// Simple word count by splitting on whitespace // Simple word count by splitting on whitespace
words := 0 words := 0
inWord := false inWord := false
for _, char := range b.Babble { for _, char := range b.Babble {
if char == ' ' || char == '\t' || char == '\n' || char == '\r' { if char == ' ' || char == '\t' || char == '\n' || char == '\r' {
if inWord { if inWord {
@ -292,11 +282,11 @@ func (b *Babble) WordCount() int {
inWord = true inWord = true
} }
} }
if inWord { if inWord {
words++ words++
} }
return words return words
} }
@ -324,7 +314,7 @@ func (b *Babble) IsLongMessage(threshold int) bool {
func (b *Babble) GetMentions() []string { func (b *Babble) GetMentions() []string {
words := strings.Fields(b.Babble) words := strings.Fields(b.Babble)
var mentions []string var mentions []string
for _, word := range words { for _, word := range words {
if strings.HasPrefix(word, "@") && len(word) > 1 { if strings.HasPrefix(word, "@") && len(word) > 1 {
// Clean up punctuation from the end // Clean up punctuation from the end
@ -334,7 +324,7 @@ func (b *Babble) GetMentions() []string {
} }
} }
} }
return mentions return mentions
} }
@ -347,4 +337,4 @@ func (b *Babble) HasMention(username string) bool {
} }
} }
return false return false
} }

View File

@ -12,17 +12,14 @@ import (
// Builder provides a fluent interface for creating babble messages // Builder provides a fluent interface for creating babble messages
type Builder struct { type Builder struct {
babble *Babble babble *Babble
db *database.DB
} }
// NewBuilder creates a new babble builder // NewBuilder creates a new babble builder
func NewBuilder(db *database.DB) *Builder { func NewBuilder() *Builder {
return &Builder{ return &Builder{
babble: &Babble{ babble: &Babble{
db: db,
Posted: time.Now().Unix(), // Default to current time Posted: time.Now().Unix(), // Default to current time
}, },
db: db,
} }
} }
@ -59,10 +56,10 @@ func (b *Builder) WithPostedTime(t time.Time) *Builder {
func (b *Builder) Create() (*Babble, error) { func (b *Builder) Create() (*Babble, error) {
// Use a transaction to ensure we can get the ID // Use a transaction to ensure we can get the ID
var babble *Babble var babble *Babble
err := b.db.Transaction(func(tx *database.Tx) error { err := database.Transaction(func(tx *database.Tx) error {
query := `INSERT INTO babble (posted, author, babble) query := `INSERT INTO babble (posted, author, babble)
VALUES (?, ?, ?)` VALUES (?, ?, ?)`
if err := tx.Exec(query, b.babble.Posted, b.babble.Author, b.babble.Babble); err != nil { if err := tx.Exec(query, b.babble.Posted, b.babble.Author, b.babble.Babble); err != nil {
return fmt.Errorf("failed to insert babble: %w", err) return fmt.Errorf("failed to insert babble: %w", err)
} }
@ -81,10 +78,10 @@ func (b *Builder) Create() (*Babble, error) {
babble = b.babble babble = b.babble
return nil return nil
}) })
if err != nil { if err != nil {
return nil, err return nil, err
} }
return babble, nil return babble, nil
} }

View File

@ -9,25 +9,22 @@ import (
) )
// Control represents the game control settings in the database // Control represents the game control settings in the database
// There is only ever one control record with ID 1
type Control struct { type Control struct {
ID int `json:"id"` ID int `json:"id"`
WorldSize int `json:"world_size"` WorldSize int `json:"world_size"`
Open int `json:"open"` Open int `json:"open"`
AdminEmail string `json:"admin_email"` AdminEmail string `json:"admin_email"`
Class1Name string `json:"class_1_name"` Class1Name string `json:"class_1_name"`
Class2Name string `json:"class_2_name"` Class2Name string `json:"class_2_name"`
Class3Name string `json:"class_3_name"` Class3Name string `json:"class_3_name"`
db *database.DB
} }
// Find retrieves the control record by ID (typically only ID 1 exists) // Find retrieves the control record by ID (typically only ID 1 exists)
func Find(db *database.DB, id int) (*Control, error) { func Find(id int) (*Control, error) {
control := &Control{db: db} control := &Control{}
query := "SELECT id, world_size, open, admin_email, class_1_name, class_2_name, class_3_name FROM control WHERE id = ?" query := "SELECT id, world_size, open, admin_email, class_1_name, class_2_name, class_3_name FROM control WHERE id = ?"
err := db.Query(query, func(stmt *sqlite.Stmt) error { err := database.Query(query, func(stmt *sqlite.Stmt) error {
control.ID = stmt.ColumnInt(0) control.ID = stmt.ColumnInt(0)
control.WorldSize = stmt.ColumnInt(1) control.WorldSize = stmt.ColumnInt(1)
control.Open = stmt.ColumnInt(2) control.Open = stmt.ColumnInt(2)
@ -50,8 +47,8 @@ func Find(db *database.DB, id int) (*Control, error) {
} }
// Get retrieves the main control record (ID 1) // Get retrieves the main control record (ID 1)
func Get(db *database.DB) (*Control, error) { func Get() (*Control, error) {
return Find(db, 1) return Find(1)
} }
// Save updates the control record in the database // Save updates the control record in the database
@ -61,7 +58,7 @@ func (c *Control) Save() error {
} }
query := `UPDATE control SET world_size = ?, open = ?, admin_email = ?, class_1_name = ?, class_2_name = ?, class_3_name = ? WHERE id = ?` query := `UPDATE control SET world_size = ?, open = ?, admin_email = ?, class_1_name = ?, class_2_name = ?, class_3_name = ? WHERE id = ?`
return c.db.Exec(query, c.WorldSize, c.Open, c.AdminEmail, c.Class1Name, c.Class2Name, c.Class3Name, c.ID) return database.Exec(query, c.WorldSize, c.Open, c.AdminEmail, c.Class1Name, c.Class2Name, c.Class3Name, c.ID)
} }
// IsOpen returns true if the game world is open for new players // IsOpen returns true if the game world is open for new players
@ -109,7 +106,7 @@ func (c *Control) SetClassNames(classes []string) {
c.Class1Name = "" c.Class1Name = ""
c.Class2Name = "" c.Class2Name = ""
c.Class3Name = "" c.Class3Name = ""
// Set provided class names // Set provided class names
if len(classes) > 0 { if len(classes) > 0 {
c.Class1Name = classes[0] c.Class1Name = classes[0]
@ -200,4 +197,4 @@ func (c *Control) IsWithinWorldBounds(x, y int) bool {
func (c *Control) GetWorldBounds() (minX, minY, maxX, maxY int) { func (c *Control) GetWorldBounds() (minX, minY, maxX, maxY int) {
radius := c.GetWorldRadius() radius := c.GetWorldRadius()
return -radius, -radius, radius, radius return -radius, -radius, radius, radius
} }

View File

@ -0,0 +1,77 @@
package cookies
import (
"time"
"github.com/valyala/fasthttp"
)
type CookieOptions struct {
Name string
Value string
Path string
Domain string
Expires time.Time
MaxAge int
Secure bool
HTTPOnly bool
SameSite string
}
func SetSecureCookie(ctx *fasthttp.RequestCtx, opts CookieOptions) {
cookie := &fasthttp.Cookie{}
cookie.SetKey(opts.Name)
cookie.SetValue(opts.Value)
if opts.Path != "" {
cookie.SetPath(opts.Path)
} else {
cookie.SetPath("/")
}
if opts.Domain != "" {
cookie.SetDomain(opts.Domain)
}
if !opts.Expires.IsZero() {
cookie.SetExpire(opts.Expires)
}
if opts.MaxAge > 0 {
cookie.SetMaxAge(opts.MaxAge)
}
cookie.SetSecure(opts.Secure)
cookie.SetHTTPOnly(opts.HTTPOnly)
switch opts.SameSite {
case "strict":
cookie.SetSameSite(fasthttp.CookieSameSiteStrictMode)
case "lax":
cookie.SetSameSite(fasthttp.CookieSameSiteLaxMode)
case "none":
cookie.SetSameSite(fasthttp.CookieSameSiteNoneMode)
default:
cookie.SetSameSite(fasthttp.CookieSameSiteLaxMode)
}
ctx.Response.Header.SetCookie(cookie)
}
func GetCookie(ctx *fasthttp.RequestCtx, name string) string {
return string(ctx.Request.Header.Cookie(name))
}
func DeleteCookie(ctx *fasthttp.RequestCtx, name string) {
SetSecureCookie(ctx, CookieOptions{
Name: name,
Value: "",
Path: "/",
Expires: time.Unix(0, 0),
MaxAge: -1,
HTTPOnly: true,
Secure: true,
SameSite: "lax",
})
}

View File

@ -11,75 +11,80 @@ import (
const DefaultPath = "dk.db" const DefaultPath = "dk.db"
// database wraps a SQLite connection pool with simplified methods // Global singleton instance
type database struct { var pool *sqlitex.Pool
pool *sqlitex.Pool
}
// DB is a backward-compatible type alias // Init initializes the global database connection pool
type DB = database func Init(path string) error {
// instance is the global singleton instance
var instance *database
// Open creates a new database connection pool
func Open(path string) (*database, error) {
if path == "" { if path == "" {
path = DefaultPath path = DefaultPath
} }
poolSize := max(runtime.GOMAXPROCS(0), 2) poolSize := max(runtime.GOMAXPROCS(0), 2)
pool, err := sqlitex.NewPool(path, sqlitex.PoolOptions{ var err error
pool, err = sqlitex.NewPool(path, sqlitex.PoolOptions{
PoolSize: poolSize, PoolSize: poolSize,
Flags: sqlite.OpenCreate | sqlite.OpenReadWrite | sqlite.OpenWAL, Flags: sqlite.OpenCreate | sqlite.OpenReadWrite | sqlite.OpenWAL,
}) })
if err != nil { if err != nil {
return nil, fmt.Errorf("failed to open database pool: %w", err) return fmt.Errorf("failed to open database pool: %w", err)
} }
conn, err := pool.Take(context.Background()) conn, err := pool.Take(context.Background())
if err != nil { if err != nil {
pool.Close() pool.Close()
return nil, fmt.Errorf("failed to get connection from pool: %w", err) return fmt.Errorf("failed to get connection from pool: %w", err)
} }
defer pool.Put(conn) defer pool.Put(conn)
if err := sqlitex.ExecuteTransient(conn, "PRAGMA journal_mode = WAL", nil); err != nil { if err := sqlitex.ExecuteTransient(conn, "PRAGMA journal_mode = WAL", nil); err != nil {
pool.Close() pool.Close()
return nil, fmt.Errorf("failed to set WAL mode: %w", err) return fmt.Errorf("failed to set WAL mode: %w", err)
} }
if err := sqlitex.ExecuteTransient(conn, "PRAGMA synchronous = NORMAL", nil); err != nil { if err := sqlitex.ExecuteTransient(conn, "PRAGMA synchronous = NORMAL", nil); err != nil {
pool.Close() pool.Close()
return nil, fmt.Errorf("failed to set synchronous mode: %w", err) return fmt.Errorf("failed to set synchronous mode: %w", err)
} }
return &database{pool: pool}, nil return nil
} }
// Close closes the database connection pool // Close closes the global database connection pool
func (db *database) Close() error { func Close() error {
return db.pool.Close() if pool == nil {
return nil
}
return pool.Close()
} }
// GetConn gets a connection from the pool - caller must call Put when done // GetConn gets a connection from the pool - caller must call PutConn when done
func (db *database) GetConn(ctx context.Context) (*sqlite.Conn, error) { func GetConn(ctx context.Context) (*sqlite.Conn, error) {
return db.pool.Take(ctx) if pool == nil {
return nil, fmt.Errorf("database not initialized")
}
return pool.Take(ctx)
} }
// PutConn returns a connection to the pool // PutConn returns a connection to the pool
func (db *database) PutConn(conn *sqlite.Conn) { func PutConn(conn *sqlite.Conn) {
db.pool.Put(conn) if pool != nil {
pool.Put(conn)
}
} }
// Exec executes a SQL statement without returning results // Exec executes a SQL statement without returning results
func (db *database) Exec(query string, args ...any) error { func Exec(query string, args ...any) error {
conn, err := db.pool.Take(context.Background()) if pool == nil {
return fmt.Errorf("database not initialized")
}
conn, err := pool.Take(context.Background())
if err != nil { if err != nil {
return fmt.Errorf("failed to get connection from pool: %w", err) return fmt.Errorf("failed to get connection from pool: %w", err)
} }
defer db.pool.Put(conn) defer pool.Put(conn)
if len(args) == 0 { if len(args) == 0 {
return sqlitex.ExecuteTransient(conn, query, nil) return sqlitex.ExecuteTransient(conn, query, nil)
@ -91,12 +96,16 @@ func (db *database) Exec(query string, args ...any) error {
} }
// Query executes a SQL query and calls fn for each row // Query executes a SQL query and calls fn for each row
func (db *database) Query(query string, fn func(*sqlite.Stmt) error, args ...any) error { func Query(query string, fn func(*sqlite.Stmt) error, args ...any) error {
conn, err := db.pool.Take(context.Background()) if pool == nil {
return fmt.Errorf("database not initialized")
}
conn, err := pool.Take(context.Background())
if err != nil { if err != nil {
return fmt.Errorf("failed to get connection from pool: %w", err) return fmt.Errorf("failed to get connection from pool: %w", err)
} }
defer db.pool.Put(conn) defer pool.Put(conn)
if len(args) == 0 { if len(args) == 0 {
return sqlitex.ExecuteTransient(conn, query, &sqlitex.ExecOptions{ return sqlitex.ExecuteTransient(conn, query, &sqlitex.ExecOptions{
@ -111,23 +120,31 @@ func (db *database) Query(query string, fn func(*sqlite.Stmt) error, args ...any
} }
// Begin starts a new transaction // Begin starts a new transaction
func (db *database) Begin() (*Tx, error) { func Begin() (*Tx, error) {
conn, err := db.pool.Take(context.Background()) if pool == nil {
return nil, fmt.Errorf("database not initialized")
}
conn, err := pool.Take(context.Background())
if err != nil { if err != nil {
return nil, fmt.Errorf("failed to get connection from pool: %w", err) return nil, fmt.Errorf("failed to get connection from pool: %w", err)
} }
if err := sqlitex.ExecuteTransient(conn, "BEGIN", nil); err != nil { if err := sqlitex.ExecuteTransient(conn, "BEGIN", nil); err != nil {
db.pool.Put(conn) pool.Put(conn)
return nil, fmt.Errorf("failed to begin transaction: %w", err) return nil, fmt.Errorf("failed to begin transaction: %w", err)
} }
return &Tx{conn: conn, pool: db.pool}, nil return &Tx{conn: conn, pool: pool}, nil
} }
// Transaction runs a function within a transaction // Transaction runs a function within a transaction
func (db *database) Transaction(fn func(*Tx) error) error { func Transaction(fn func(*Tx) error) error {
tx, err := db.Begin() if pool == nil {
return fmt.Errorf("database not initialized")
}
tx, err := Begin()
if err != nil { if err != nil {
return err return err
} }
@ -182,75 +199,3 @@ func (tx *Tx) Rollback() error {
defer tx.pool.Put(tx.conn) defer tx.pool.Put(tx.conn)
return sqlitex.ExecuteTransient(tx.conn, "ROLLBACK", nil) return sqlitex.ExecuteTransient(tx.conn, "ROLLBACK", nil)
} }
// InitializeDB initializes the global DB singleton
func InitializeDB(path string) error {
db, err := Open(path)
if err != nil {
return err
}
instance = db
return nil
}
// GetDB returns the global database instance
func GetDB() *DB {
return instance
}
// Global convenience functions that use the singleton
// Exec executes a SQL statement without returning results using the global DB
func Exec(query string, args ...any) error {
if instance == nil {
return fmt.Errorf("database not initialized")
}
return instance.Exec(query, args...)
}
// Query executes a SQL query and calls fn for each row using the global DB
func Query(query string, fn func(*sqlite.Stmt) error, args ...any) error {
if instance == nil {
return fmt.Errorf("database not initialized")
}
return instance.Query(query, fn, args...)
}
// Begin starts a new transaction using the global DB
func Begin() (*Tx, error) {
if instance == nil {
return nil, fmt.Errorf("database not initialized")
}
return instance.Begin()
}
// Transaction runs a function within a transaction using the global DB
func Transaction(fn func(*Tx) error) error {
if instance == nil {
return fmt.Errorf("database not initialized")
}
return instance.Transaction(fn)
}
// GetConn gets a connection from the pool using the global DB
func GetConn(ctx context.Context) (*sqlite.Conn, error) {
if instance == nil {
return nil, fmt.Errorf("database not initialized")
}
return instance.GetConn(ctx)
}
// PutConn returns a connection to the pool using the global DB
func PutConn(conn *sqlite.Conn) {
if instance != nil {
instance.PutConn(conn)
}
}
// Close closes the global database connection pool
func Close() error {
if instance == nil {
return nil
}
return instance.Close()
}

View File

@ -11,62 +11,62 @@ func TestDatabaseOperations(t *testing.T) {
// Use a temporary database file // Use a temporary database file
testDB := "test.db" testDB := "test.db"
defer os.Remove(testDB) defer os.Remove(testDB)
// Test opening database // Initialize the singleton database
db, err := Open(testDB) err := Init(testDB)
if err != nil { if err != nil {
t.Fatalf("Failed to open database: %v", err) t.Fatalf("Failed to initialize database: %v", err)
} }
defer db.Close() defer Close()
// Test creating a simple table // Test creating a simple table
err = db.Exec("CREATE TABLE test_users (id INTEGER PRIMARY KEY, name TEXT)") err = Exec("CREATE TABLE test_users (id INTEGER PRIMARY KEY, name TEXT)")
if err != nil { if err != nil {
t.Fatalf("Failed to create table: %v", err) t.Fatalf("Failed to create table: %v", err)
} }
// Test inserting data // Test inserting data
err = db.Exec("INSERT INTO test_users (name) VALUES (?)", "Alice") err = Exec("INSERT INTO test_users (name) VALUES (?)", "Alice")
if err != nil { if err != nil {
t.Fatalf("Failed to insert data: %v", err) t.Fatalf("Failed to insert data: %v", err)
} }
// Test querying data // Test querying data
var foundName string var foundName string
err = db.Query("SELECT name FROM test_users WHERE name = ?", func(stmt *sqlite.Stmt) error { err = Query("SELECT name FROM test_users WHERE name = ?", func(stmt *sqlite.Stmt) error {
foundName = stmt.ColumnText(0) foundName = stmt.ColumnText(0)
return nil return nil
}, "Alice") }, "Alice")
if err != nil { if err != nil {
t.Fatalf("Failed to query data: %v", err) t.Fatalf("Failed to query data: %v", err)
} }
if foundName != "Alice" { if foundName != "Alice" {
t.Errorf("Expected 'Alice', got '%s'", foundName) t.Errorf("Expected 'Alice', got '%s'", foundName)
} }
// Test transaction // Test transaction
err = db.Transaction(func(tx *Tx) error { err = Transaction(func(tx *Tx) error {
return tx.Exec("INSERT INTO test_users (name) VALUES (?)", "Bob") return tx.Exec("INSERT INTO test_users (name) VALUES (?)", "Bob")
}) })
if err != nil { if err != nil {
t.Fatalf("Transaction failed: %v", err) t.Fatalf("Transaction failed: %v", err)
} }
// Verify transaction worked // Verify transaction worked
var count int var count int
err = db.Query("SELECT COUNT(*) FROM test_users", func(stmt *sqlite.Stmt) error { err = Query("SELECT COUNT(*) FROM test_users", func(stmt *sqlite.Stmt) error {
count = stmt.ColumnInt(0) count = stmt.ColumnInt(0)
return nil return nil
}) })
if err != nil { if err != nil {
t.Fatalf("Failed to count users: %v", err) t.Fatalf("Failed to count users: %v", err)
} }
if count != 2 { if count != 2 {
t.Errorf("Expected 2 users, got %d", count) t.Errorf("Expected 2 users, got %d", count)
} }
} }

View File

@ -10,14 +10,12 @@ import (
// Builder provides a fluent interface for creating drops // Builder provides a fluent interface for creating drops
type Builder struct { type Builder struct {
drop *Drop drop *Drop
db *database.DB
} }
// NewBuilder creates a new drop builder // NewBuilder creates a new drop builder
func NewBuilder(db *database.DB) *Builder { func NewBuilder() *Builder {
return &Builder{ return &Builder{
drop: &Drop{db: db}, drop: &Drop{},
db: db,
} }
} }
@ -49,8 +47,8 @@ func (b *Builder) WithAtt(att string) *Builder {
func (b *Builder) Create() (*Drop, error) { func (b *Builder) Create() (*Drop, error) {
// Use a transaction to ensure we can get the ID // Use a transaction to ensure we can get the ID
var drop *Drop var drop *Drop
err := b.db.Transaction(func(tx *database.Tx) error { err := database.Transaction(func(tx *database.Tx) error {
query := `INSERT INTO drops (name, level, type, att) query := `INSERT INTO drops (name, level, type, att)
VALUES (?, ?, ?, ?)` VALUES (?, ?, ?, ?)`
if err := tx.Exec(query, b.drop.Name, b.drop.Level, b.drop.Type, b.drop.Att); err != nil { if err := tx.Exec(query, b.drop.Name, b.drop.Level, b.drop.Type, b.drop.Att); err != nil {
@ -74,7 +72,6 @@ func (b *Builder) Create() (*Drop, error) {
Level: b.drop.Level, Level: b.drop.Level,
Type: b.drop.Type, Type: b.drop.Type,
Att: b.drop.Att, Att: b.drop.Att,
db: b.db,
} }
return nil return nil

View File

@ -15,8 +15,6 @@ type Drop struct {
Level int `json:"level"` Level int `json:"level"`
Type int `json:"type"` Type int `json:"type"`
Att string `json:"att"` Att string `json:"att"`
db *database.DB
} }
// DropType constants for drop types // DropType constants for drop types
@ -25,11 +23,11 @@ const (
) )
// Find retrieves a drop by ID // Find retrieves a drop by ID
func Find(db *database.DB, id int) (*Drop, error) { func Find(id int) (*Drop, error) {
drop := &Drop{db: db} drop := &Drop{}
query := "SELECT id, name, level, type, att FROM drops WHERE id = ?" query := "SELECT id, name, level, type, att FROM drops WHERE id = ?"
err := db.Query(query, func(stmt *sqlite.Stmt) error { err := database.Query(query, func(stmt *sqlite.Stmt) error {
drop.ID = stmt.ColumnInt(0) drop.ID = stmt.ColumnInt(0)
drop.Name = stmt.ColumnText(1) drop.Name = stmt.ColumnText(1)
drop.Level = stmt.ColumnInt(2) drop.Level = stmt.ColumnInt(2)
@ -50,18 +48,17 @@ func Find(db *database.DB, id int) (*Drop, error) {
} }
// All retrieves all drops // All retrieves all drops
func All(db *database.DB) ([]*Drop, error) { func All() ([]*Drop, error) {
var drops []*Drop var drops []*Drop
query := "SELECT id, name, level, type, att FROM drops ORDER BY id" query := "SELECT id, name, level, type, att FROM drops ORDER BY id"
err := db.Query(query, func(stmt *sqlite.Stmt) error { err := database.Query(query, func(stmt *sqlite.Stmt) error {
drop := &Drop{ drop := &Drop{
ID: stmt.ColumnInt(0), ID: stmt.ColumnInt(0),
Name: stmt.ColumnText(1), Name: stmt.ColumnText(1),
Level: stmt.ColumnInt(2), Level: stmt.ColumnInt(2),
Type: stmt.ColumnInt(3), Type: stmt.ColumnInt(3),
Att: stmt.ColumnText(4), Att: stmt.ColumnText(4),
db: db,
} }
drops = append(drops, drop) drops = append(drops, drop)
return nil return nil
@ -75,18 +72,17 @@ func All(db *database.DB) ([]*Drop, error) {
} }
// ByLevel retrieves drops by minimum level requirement // ByLevel retrieves drops by minimum level requirement
func ByLevel(db *database.DB, minLevel int) ([]*Drop, error) { func ByLevel(minLevel int) ([]*Drop, error) {
var drops []*Drop var drops []*Drop
query := "SELECT id, name, level, type, att FROM drops WHERE level <= ? ORDER BY level, id" query := "SELECT id, name, level, type, att FROM drops WHERE level <= ? ORDER BY level, id"
err := db.Query(query, func(stmt *sqlite.Stmt) error { err := database.Query(query, func(stmt *sqlite.Stmt) error {
drop := &Drop{ drop := &Drop{
ID: stmt.ColumnInt(0), ID: stmt.ColumnInt(0),
Name: stmt.ColumnText(1), Name: stmt.ColumnText(1),
Level: stmt.ColumnInt(2), Level: stmt.ColumnInt(2),
Type: stmt.ColumnInt(3), Type: stmt.ColumnInt(3),
Att: stmt.ColumnText(4), Att: stmt.ColumnText(4),
db: db,
} }
drops = append(drops, drop) drops = append(drops, drop)
return nil return nil
@ -100,18 +96,17 @@ func ByLevel(db *database.DB, minLevel int) ([]*Drop, error) {
} }
// ByType retrieves drops by type // ByType retrieves drops by type
func ByType(db *database.DB, dropType int) ([]*Drop, error) { func ByType(dropType int) ([]*Drop, error) {
var drops []*Drop var drops []*Drop
query := "SELECT id, name, level, type, att FROM drops WHERE type = ? ORDER BY level, id" query := "SELECT id, name, level, type, att FROM drops WHERE type = ? ORDER BY level, id"
err := db.Query(query, func(stmt *sqlite.Stmt) error { err := database.Query(query, func(stmt *sqlite.Stmt) error {
drop := &Drop{ drop := &Drop{
ID: stmt.ColumnInt(0), ID: stmt.ColumnInt(0),
Name: stmt.ColumnText(1), Name: stmt.ColumnText(1),
Level: stmt.ColumnInt(2), Level: stmt.ColumnInt(2),
Type: stmt.ColumnInt(3), Type: stmt.ColumnInt(3),
Att: stmt.ColumnText(4), Att: stmt.ColumnText(4),
db: db,
} }
drops = append(drops, drop) drops = append(drops, drop)
return nil return nil
@ -131,7 +126,7 @@ func (d *Drop) Save() error {
} }
query := `UPDATE drops SET name = ?, level = ?, type = ?, att = ? WHERE id = ?` query := `UPDATE drops SET name = ?, level = ?, type = ?, att = ? WHERE id = ?`
return d.db.Exec(query, d.Name, d.Level, d.Type, d.Att, d.ID) return database.Exec(query, d.Name, d.Level, d.Type, d.Att, d.ID)
} }
// Delete removes the drop from the database // Delete removes the drop from the database
@ -140,8 +135,7 @@ func (d *Drop) Delete() error {
return fmt.Errorf("cannot delete drop without ID") return fmt.Errorf("cannot delete drop without ID")
} }
query := "DELETE FROM drops WHERE id = ?" return database.Exec("DELETE FROM drops WHERE id = ?", d.ID)
return d.db.Exec(query, d.ID)
} }
// IsConsumable returns true if the drop is a consumable item // IsConsumable returns true if the drop is a consumable item

View File

@ -12,21 +12,18 @@ import (
// Builder provides a fluent interface for creating forum posts // Builder provides a fluent interface for creating forum posts
type Builder struct { type Builder struct {
forum *Forum forum *Forum
db *database.DB
} }
// NewBuilder creates a new forum post builder // NewBuilder creates a new forum post builder
func NewBuilder(db *database.DB) *Builder { func NewBuilder() *Builder {
now := time.Now().Unix() now := time.Now().Unix()
return &Builder{ return &Builder{
forum: &Forum{ forum: &Forum{
db: db,
Posted: now, Posted: now,
LastPost: now, // Default to same as posted time LastPost: now, // Default to same as posted time
Parent: 0, // Default to thread (no parent) Parent: 0, // Default to thread (no parent)
Replies: 0, // Default to no replies Replies: 0, // Default to no replies
}, },
db: db,
} }
} }
@ -99,8 +96,8 @@ func (b *Builder) WithReplies(replies int) *Builder {
func (b *Builder) Create() (*Forum, error) { func (b *Builder) Create() (*Forum, error) {
// Use a transaction to ensure we can get the ID // Use a transaction to ensure we can get the ID
var forum *Forum var forum *Forum
err := b.db.Transaction(func(tx *database.Tx) error { err := database.Transaction(func(tx *database.Tx) error {
query := `INSERT INTO forum (posted, last_post, author, parent, replies, title, content) query := `INSERT INTO forum (posted, last_post, author, parent, replies, title, content)
VALUES (?, ?, ?, ?, ?, ?, ?)` VALUES (?, ?, ?, ?, ?, ?, ?)`
if err := tx.Exec(query, b.forum.Posted, b.forum.LastPost, b.forum.Author, if err := tx.Exec(query, b.forum.Posted, b.forum.LastPost, b.forum.Author,
@ -128,4 +125,4 @@ func (b *Builder) Create() (*Forum, error) {
} }
return forum, nil return forum, nil
} }

View File

@ -20,16 +20,14 @@ type Forum struct {
Replies int `json:"replies"` Replies int `json:"replies"`
Title string `json:"title"` Title string `json:"title"`
Content string `json:"content"` Content string `json:"content"`
db *database.DB
} }
// Find retrieves a forum post by ID // Find retrieves a forum post by ID
func Find(db *database.DB, id int) (*Forum, error) { func Find(id int) (*Forum, error) {
forum := &Forum{db: db} forum := &Forum{}
query := "SELECT id, posted, last_post, author, parent, replies, title, content FROM forum WHERE id = ?" query := "SELECT id, posted, last_post, author, parent, replies, title, content FROM forum WHERE id = ?"
err := db.Query(query, func(stmt *sqlite.Stmt) error { err := database.Query(query, func(stmt *sqlite.Stmt) error {
forum.ID = stmt.ColumnInt(0) forum.ID = stmt.ColumnInt(0)
forum.Posted = stmt.ColumnInt64(1) forum.Posted = stmt.ColumnInt64(1)
forum.LastPost = stmt.ColumnInt64(2) forum.LastPost = stmt.ColumnInt64(2)
@ -53,11 +51,11 @@ func Find(db *database.DB, id int) (*Forum, error) {
} }
// All retrieves all forum posts ordered by last post time (most recent first) // All retrieves all forum posts ordered by last post time (most recent first)
func All(db *database.DB) ([]*Forum, error) { func All() ([]*Forum, error) {
var forums []*Forum var forums []*Forum
query := "SELECT id, posted, last_post, author, parent, replies, title, content FROM forum ORDER BY last_post DESC, id DESC" query := "SELECT id, posted, last_post, author, parent, replies, title, content FROM forum ORDER BY last_post DESC, id DESC"
err := db.Query(query, func(stmt *sqlite.Stmt) error { err := database.Query(query, func(stmt *sqlite.Stmt) error {
forum := &Forum{ forum := &Forum{
ID: stmt.ColumnInt(0), ID: stmt.ColumnInt(0),
Posted: stmt.ColumnInt64(1), Posted: stmt.ColumnInt64(1),
@ -67,7 +65,6 @@ func All(db *database.DB) ([]*Forum, error) {
Replies: stmt.ColumnInt(5), Replies: stmt.ColumnInt(5),
Title: stmt.ColumnText(6), Title: stmt.ColumnText(6),
Content: stmt.ColumnText(7), Content: stmt.ColumnText(7),
db: db,
} }
forums = append(forums, forum) forums = append(forums, forum)
return nil return nil
@ -81,11 +78,11 @@ func All(db *database.DB) ([]*Forum, error) {
} }
// Threads retrieves all top-level forum threads (parent = 0) // Threads retrieves all top-level forum threads (parent = 0)
func Threads(db *database.DB) ([]*Forum, error) { func Threads() ([]*Forum, error) {
var forums []*Forum var forums []*Forum
query := "SELECT id, posted, last_post, author, parent, replies, title, content FROM forum WHERE parent = 0 ORDER BY last_post DESC, id DESC" query := "SELECT id, posted, last_post, author, parent, replies, title, content FROM forum WHERE parent = 0 ORDER BY last_post DESC, id DESC"
err := db.Query(query, func(stmt *sqlite.Stmt) error { err := database.Query(query, func(stmt *sqlite.Stmt) error {
forum := &Forum{ forum := &Forum{
ID: stmt.ColumnInt(0), ID: stmt.ColumnInt(0),
Posted: stmt.ColumnInt64(1), Posted: stmt.ColumnInt64(1),
@ -95,7 +92,6 @@ func Threads(db *database.DB) ([]*Forum, error) {
Replies: stmt.ColumnInt(5), Replies: stmt.ColumnInt(5),
Title: stmt.ColumnText(6), Title: stmt.ColumnText(6),
Content: stmt.ColumnText(7), Content: stmt.ColumnText(7),
db: db,
} }
forums = append(forums, forum) forums = append(forums, forum)
return nil return nil
@ -109,11 +105,11 @@ func Threads(db *database.DB) ([]*Forum, error) {
} }
// ByParent retrieves all replies to a specific thread/post // ByParent retrieves all replies to a specific thread/post
func ByParent(db *database.DB, parentID int) ([]*Forum, error) { func ByParent(parentID int) ([]*Forum, error) {
var forums []*Forum var forums []*Forum
query := "SELECT id, posted, last_post, author, parent, replies, title, content FROM forum WHERE parent = ? ORDER BY posted ASC, id ASC" query := "SELECT id, posted, last_post, author, parent, replies, title, content FROM forum WHERE parent = ? ORDER BY posted ASC, id ASC"
err := db.Query(query, func(stmt *sqlite.Stmt) error { err := database.Query(query, func(stmt *sqlite.Stmt) error {
forum := &Forum{ forum := &Forum{
ID: stmt.ColumnInt(0), ID: stmt.ColumnInt(0),
Posted: stmt.ColumnInt64(1), Posted: stmt.ColumnInt64(1),
@ -123,7 +119,6 @@ func ByParent(db *database.DB, parentID int) ([]*Forum, error) {
Replies: stmt.ColumnInt(5), Replies: stmt.ColumnInt(5),
Title: stmt.ColumnText(6), Title: stmt.ColumnText(6),
Content: stmt.ColumnText(7), Content: stmt.ColumnText(7),
db: db,
} }
forums = append(forums, forum) forums = append(forums, forum)
return nil return nil
@ -137,11 +132,11 @@ func ByParent(db *database.DB, parentID int) ([]*Forum, error) {
} }
// ByAuthor retrieves forum posts by a specific author // ByAuthor retrieves forum posts by a specific author
func ByAuthor(db *database.DB, authorID int) ([]*Forum, error) { func ByAuthor(authorID int) ([]*Forum, error) {
var forums []*Forum var forums []*Forum
query := "SELECT id, posted, last_post, author, parent, replies, title, content FROM forum WHERE author = ? ORDER BY posted DESC, id DESC" query := "SELECT id, posted, last_post, author, parent, replies, title, content FROM forum WHERE author = ? ORDER BY posted DESC, id DESC"
err := db.Query(query, func(stmt *sqlite.Stmt) error { err := database.Query(query, func(stmt *sqlite.Stmt) error {
forum := &Forum{ forum := &Forum{
ID: stmt.ColumnInt(0), ID: stmt.ColumnInt(0),
Posted: stmt.ColumnInt64(1), Posted: stmt.ColumnInt64(1),
@ -151,7 +146,6 @@ func ByAuthor(db *database.DB, authorID int) ([]*Forum, error) {
Replies: stmt.ColumnInt(5), Replies: stmt.ColumnInt(5),
Title: stmt.ColumnText(6), Title: stmt.ColumnText(6),
Content: stmt.ColumnText(7), Content: stmt.ColumnText(7),
db: db,
} }
forums = append(forums, forum) forums = append(forums, forum)
return nil return nil
@ -165,11 +159,11 @@ func ByAuthor(db *database.DB, authorID int) ([]*Forum, error) {
} }
// Recent retrieves the most recent forum activity (limited by count) // Recent retrieves the most recent forum activity (limited by count)
func Recent(db *database.DB, limit int) ([]*Forum, error) { func Recent(limit int) ([]*Forum, error) {
var forums []*Forum var forums []*Forum
query := "SELECT id, posted, last_post, author, parent, replies, title, content FROM forum ORDER BY last_post DESC, id DESC LIMIT ?" query := "SELECT id, posted, last_post, author, parent, replies, title, content FROM forum ORDER BY last_post DESC, id DESC LIMIT ?"
err := db.Query(query, func(stmt *sqlite.Stmt) error { err := database.Query(query, func(stmt *sqlite.Stmt) error {
forum := &Forum{ forum := &Forum{
ID: stmt.ColumnInt(0), ID: stmt.ColumnInt(0),
Posted: stmt.ColumnInt64(1), Posted: stmt.ColumnInt64(1),
@ -179,7 +173,6 @@ func Recent(db *database.DB, limit int) ([]*Forum, error) {
Replies: stmt.ColumnInt(5), Replies: stmt.ColumnInt(5),
Title: stmt.ColumnText(6), Title: stmt.ColumnText(6),
Content: stmt.ColumnText(7), Content: stmt.ColumnText(7),
db: db,
} }
forums = append(forums, forum) forums = append(forums, forum)
return nil return nil
@ -193,13 +186,13 @@ func Recent(db *database.DB, limit int) ([]*Forum, error) {
} }
// Search retrieves forum posts containing the search term in title or content // Search retrieves forum posts containing the search term in title or content
func Search(db *database.DB, term string) ([]*Forum, error) { func Search(term string) ([]*Forum, error) {
var forums []*Forum var forums []*Forum
query := "SELECT id, posted, last_post, author, parent, replies, title, content FROM forum WHERE LOWER(title) LIKE LOWER(?) OR LOWER(content) LIKE LOWER(?) ORDER BY last_post DESC, id DESC" query := "SELECT id, posted, last_post, author, parent, replies, title, content FROM forum WHERE LOWER(title) LIKE LOWER(?) OR LOWER(content) LIKE LOWER(?) ORDER BY last_post DESC, id DESC"
searchTerm := "%" + term + "%" searchTerm := "%" + term + "%"
err := db.Query(query, func(stmt *sqlite.Stmt) error { err := database.Query(query, func(stmt *sqlite.Stmt) error {
forum := &Forum{ forum := &Forum{
ID: stmt.ColumnInt(0), ID: stmt.ColumnInt(0),
Posted: stmt.ColumnInt64(1), Posted: stmt.ColumnInt64(1),
@ -209,7 +202,6 @@ func Search(db *database.DB, term string) ([]*Forum, error) {
Replies: stmt.ColumnInt(5), Replies: stmt.ColumnInt(5),
Title: stmt.ColumnText(6), Title: stmt.ColumnText(6),
Content: stmt.ColumnText(7), Content: stmt.ColumnText(7),
db: db,
} }
forums = append(forums, forum) forums = append(forums, forum)
return nil return nil
@ -223,11 +215,11 @@ func Search(db *database.DB, term string) ([]*Forum, error) {
} }
// Since retrieves forum posts with activity since a specific timestamp // Since retrieves forum posts with activity since a specific timestamp
func Since(db *database.DB, since int64) ([]*Forum, error) { func Since(since int64) ([]*Forum, error) {
var forums []*Forum var forums []*Forum
query := "SELECT id, posted, last_post, author, parent, replies, title, content FROM forum WHERE last_post >= ? ORDER BY last_post DESC, id DESC" query := "SELECT id, posted, last_post, author, parent, replies, title, content FROM forum WHERE last_post >= ? ORDER BY last_post DESC, id DESC"
err := db.Query(query, func(stmt *sqlite.Stmt) error { err := database.Query(query, func(stmt *sqlite.Stmt) error {
forum := &Forum{ forum := &Forum{
ID: stmt.ColumnInt(0), ID: stmt.ColumnInt(0),
Posted: stmt.ColumnInt64(1), Posted: stmt.ColumnInt64(1),
@ -237,7 +229,6 @@ func Since(db *database.DB, since int64) ([]*Forum, error) {
Replies: stmt.ColumnInt(5), Replies: stmt.ColumnInt(5),
Title: stmt.ColumnText(6), Title: stmt.ColumnText(6),
Content: stmt.ColumnText(7), Content: stmt.ColumnText(7),
db: db,
} }
forums = append(forums, forum) forums = append(forums, forum)
return nil return nil
@ -257,7 +248,7 @@ func (f *Forum) Save() error {
} }
query := `UPDATE forum SET posted = ?, last_post = ?, author = ?, parent = ?, replies = ?, title = ?, content = ? WHERE id = ?` query := `UPDATE forum SET posted = ?, last_post = ?, author = ?, parent = ?, replies = ?, title = ?, content = ? WHERE id = ?`
return f.db.Exec(query, f.Posted, f.LastPost, f.Author, f.Parent, f.Replies, f.Title, f.Content, f.ID) return database.Exec(query, f.Posted, f.LastPost, f.Author, f.Parent, f.Replies, f.Title, f.Content, f.ID)
} }
// Delete removes the forum post from the database // Delete removes the forum post from the database
@ -266,8 +257,7 @@ func (f *Forum) Delete() error {
return fmt.Errorf("cannot delete forum post without ID") return fmt.Errorf("cannot delete forum post without ID")
} }
query := "DELETE FROM forum WHERE id = ?" return database.Exec("DELETE FROM forum WHERE id = ?", f.ID)
return f.db.Exec(query, f.ID)
} }
// PostedTime returns the posted timestamp as a time.Time // PostedTime returns the posted timestamp as a time.Time
@ -397,7 +387,7 @@ func (f *Forum) DecrementReplies() {
// GetReplies retrieves all direct replies to this post // GetReplies retrieves all direct replies to this post
func (f *Forum) GetReplies() ([]*Forum, error) { func (f *Forum) GetReplies() ([]*Forum, error) {
return ByParent(f.db, f.ID) return ByParent(f.ID)
} }
// GetThread retrieves the parent thread (if this is a reply) or returns self (if this is a thread) // GetThread retrieves the parent thread (if this is a reply) or returns self (if this is a thread)
@ -405,5 +395,5 @@ func (f *Forum) GetThread() (*Forum, error) {
if f.IsThread() { if f.IsThread() {
return f, nil return f, nil
} }
return Find(f.db, f.Parent) return Find(f.Parent)
} }

View File

@ -24,21 +24,20 @@ func Run() error {
start := time.Now() start := time.Now()
db, err := database.Open(dbPath) if err := database.Init("dk.db"); err != nil {
if err != nil { return fmt.Errorf("failed to initialize database: %w", err)
return err
} }
defer db.Close() defer database.Close()
if err := createTables(db); err != nil { if err := createTables(); err != nil {
return fmt.Errorf("failed to create tables: %w", err) return fmt.Errorf("failed to create tables: %w", err)
} }
if err := populateData(db); err != nil { if err := populateData(); err != nil {
return fmt.Errorf("failed to populate data: %w", err) return fmt.Errorf("failed to populate data: %w", err)
} }
if err := createDemoUser(db); err != nil { if err := createDemoUser(); err != nil {
return fmt.Errorf("failed to create demo user: %w", err) return fmt.Errorf("failed to create demo user: %w", err)
} }
@ -53,7 +52,7 @@ func Run() error {
return nil return nil
} }
func createTables(db *database.DB) error { func createTables() error {
tables := []struct { tables := []struct {
name string name string
sql string sql string
@ -187,7 +186,7 @@ func createTables(db *database.DB) error {
} }
for _, table := range tables { for _, table := range tables {
if err := db.Exec(table.sql); err != nil { if err := database.Exec(table.sql); err != nil {
return fmt.Errorf("failed to create %s table: %w", table.name, err) return fmt.Errorf("failed to create %s table: %w", table.name, err)
} }
fmt.Printf("✓ %s table created\n", table.name) fmt.Printf("✓ %s table created\n", table.name)
@ -196,8 +195,8 @@ func createTables(db *database.DB) error {
return nil return nil
} }
func populateData(db *database.DB) error { func populateData() error {
if err := db.Exec("INSERT INTO control VALUES (1, 250, 1, '', 'Mage', 'Warrior', 'Paladin')"); err != nil { if err := database.Exec("INSERT INTO control VALUES (1, 250, 1, '', 'Mage', 'Warrior', 'Paladin')"); err != nil {
return fmt.Errorf("failed to populate control table: %w", err) return fmt.Errorf("failed to populate control table: %w", err)
} }
fmt.Println("✓ control table populated") fmt.Println("✓ control table populated")
@ -235,7 +234,7 @@ func populateData(db *database.DB) error {
(30, 'Diamond', 50, 1, 'defensepower,150'), (30, 'Diamond', 50, 1, 'defensepower,150'),
(31, 'Memory Drop', 5, 1, 'expbonus,10'), (31, 'Memory Drop', 5, 1, 'expbonus,10'),
(32, 'Fortune Drop', 5, 1, 'goldbonus,10')` (32, 'Fortune Drop', 5, 1, 'goldbonus,10')`
if err := db.Exec(dropsSQL); err != nil { if err := database.Exec(dropsSQL); err != nil {
return fmt.Errorf("failed to populate drops table: %w", err) return fmt.Errorf("failed to populate drops table: %w", err)
} }
fmt.Println("✓ drops table populated") fmt.Println("✓ drops table populated")
@ -274,7 +273,7 @@ func populateData(db *database.DB) error {
(31, 3, 'Large Shield', 2500, 30, ''), (31, 3, 'Large Shield', 2500, 30, ''),
(32, 3, 'Silver Shield', 10000, 60, ''), (32, 3, 'Silver Shield', 10000, 60, ''),
(33, 3, 'Destiny Aegis', 25000, 100, 'maxhp,50')` (33, 3, 'Destiny Aegis', 25000, 100, 'maxhp,50')`
if err := db.Exec(itemsSQL); err != nil { if err := database.Exec(itemsSQL); err != nil {
return fmt.Errorf("failed to populate items table: %w", err) return fmt.Errorf("failed to populate items table: %w", err)
} }
fmt.Println("✓ items table populated") fmt.Println("✓ items table populated")
@ -431,12 +430,12 @@ func populateData(db *database.DB) error {
(149, 'Titan', 360, 340, 270, 50, 2400, 800, 0), (149, 'Titan', 360, 340, 270, 50, 2400, 800, 0),
(150, 'Black Daemon', 400, 400, 280, 50, 3000, 1000, 1), (150, 'Black Daemon', 400, 400, 280, 50, 3000, 1000, 1),
(151, 'Lucifuge', 600, 600, 400, 50, 10000, 10000, 2)` (151, 'Lucifuge', 600, 600, 400, 50, 10000, 10000, 2)`
if err := db.Exec(monstersSQL); err != nil { if err := database.Exec(monstersSQL); err != nil {
return fmt.Errorf("failed to populate monsters table: %w", err) return fmt.Errorf("failed to populate monsters table: %w", err)
} }
fmt.Println("✓ monsters table populated") fmt.Println("✓ monsters table populated")
if err := db.Exec("INSERT INTO news (author, content) VALUES (1, 'Welcome to Dragon Knight! This is your first news post.')"); err != nil { if err := database.Exec("INSERT INTO news (author, content) VALUES (1, 'Welcome to Dragon Knight! This is your first news post.')"); err != nil {
return fmt.Errorf("failed to populate news table: %w", err) return fmt.Errorf("failed to populate news table: %w", err)
} }
fmt.Println("✓ news table populated") fmt.Println("✓ news table populated")
@ -461,7 +460,7 @@ func populateData(db *database.DB) error {
(17, 'Ward', 10, 10, 5), (17, 'Ward', 10, 10, 5),
(18, 'Fend', 20, 25, 5), (18, 'Fend', 20, 25, 5),
(19, 'Barrier', 30, 50, 5)` (19, 'Barrier', 30, 50, 5)`
if err := db.Exec(spellsSQL); err != nil { if err := database.Exec(spellsSQL); err != nil {
return fmt.Errorf("failed to populate spells table: %w", err) return fmt.Errorf("failed to populate spells table: %w", err)
} }
fmt.Println("✓ spells table populated") fmt.Println("✓ spells table populated")
@ -475,7 +474,7 @@ func populateData(db *database.DB) error {
(6, 'Hambry', 170, 170, 90, 1000, 80, '10,11,12,13,14,23,24,30,31'), (6, 'Hambry', 170, 170, 90, 1000, 80, '10,11,12,13,14,23,24,30,31'),
(7, 'Gilead', 200, -200, 100, 3000, 110, '12,13,14,15,24,25,26,32'), (7, 'Gilead', 200, -200, 100, 3000, 110, '12,13,14,15,24,25,26,32'),
(8, 'Endworld', -250, -250, 125, 9000, 160, '16,27,33')` (8, 'Endworld', -250, -250, 125, 9000, 160, '16,27,33')`
if err := db.Exec(townsSQL); err != nil { if err := database.Exec(townsSQL); err != nil {
return fmt.Errorf("failed to populate towns table: %w", err) return fmt.Errorf("failed to populate towns table: %w", err)
} }
fmt.Println("✓ towns table populated") fmt.Println("✓ towns table populated")
@ -483,16 +482,16 @@ func populateData(db *database.DB) error {
return nil return nil
} }
func createDemoUser(db *database.DB) error { func createDemoUser() error {
hashedPassword, err := password.Hash("Demo123!") hashedPassword, err := password.Hash("Demo123!")
if err != nil { if err != nil {
return fmt.Errorf("failed to hash password: %w", err) return fmt.Errorf("failed to hash password: %w", err)
} }
stmt := `INSERT INTO users (username, password, email, verified, class_id, auth) stmt := `INSERT INTO users (username, password, email, verified, class_id, auth)
VALUES (?, ?, ?, 1, 1, 4)` VALUES (?, ?, ?, 1, 1, 4)`
if err := db.Exec(stmt, "demo", hashedPassword, "demo@demo.com"); err != nil { if err := database.Exec(stmt, "demo", hashedPassword, "demo@demo.com"); err != nil {
return fmt.Errorf("failed to create demo user: %w", err) return fmt.Errorf("failed to create demo user: %w", err)
} }

View File

@ -10,14 +10,12 @@ import (
// Builder provides a fluent interface for creating items // Builder provides a fluent interface for creating items
type Builder struct { type Builder struct {
item *Item item *Item
db *database.DB
} }
// NewBuilder creates a new item builder // NewBuilder creates a new item builder
func NewBuilder(db *database.DB) *Builder { func NewBuilder() *Builder {
return &Builder{ return &Builder{
item: &Item{db: db}, item: &Item{},
db: db,
} }
} }
@ -55,8 +53,8 @@ func (b *Builder) WithSpecial(special string) *Builder {
func (b *Builder) Create() (*Item, error) { func (b *Builder) Create() (*Item, error) {
// Use a transaction to ensure we can get the ID // Use a transaction to ensure we can get the ID
var item *Item var item *Item
err := b.db.Transaction(func(tx *database.Tx) error { err := database.Transaction(func(tx *database.Tx) error {
query := `INSERT INTO items (type, name, value, att, special) query := `INSERT INTO items (type, name, value, att, special)
VALUES (?, ?, ?, ?, ?)` VALUES (?, ?, ?, ?, ?)`
if err := tx.Exec(query, b.item.Type, b.item.Name, b.item.Value, b.item.Att, b.item.Special); err != nil { if err := tx.Exec(query, b.item.Type, b.item.Name, b.item.Value, b.item.Att, b.item.Special); err != nil {
@ -81,7 +79,6 @@ func (b *Builder) Create() (*Item, error) {
Value: b.item.Value, Value: b.item.Value,
Att: b.item.Att, Att: b.item.Att,
Special: b.item.Special, Special: b.item.Special,
db: b.db,
} }
return nil return nil

View File

@ -16,8 +16,6 @@ type Item struct {
Value int `json:"value"` Value int `json:"value"`
Att int `json:"att"` Att int `json:"att"`
Special string `json:"special"` Special string `json:"special"`
db *database.DB
} }
// ItemType constants for item types // ItemType constants for item types
@ -28,11 +26,11 @@ const (
) )
// Find retrieves an item by ID // Find retrieves an item by ID
func Find(db *database.DB, id int) (*Item, error) { func Find(id int) (*Item, error) {
item := &Item{db: db} item := &Item{}
query := "SELECT id, type, name, value, att, special FROM items WHERE id = ?" query := "SELECT id, type, name, value, att, special FROM items WHERE id = ?"
err := db.Query(query, func(stmt *sqlite.Stmt) error { err := database.Query(query, func(stmt *sqlite.Stmt) error {
item.ID = stmt.ColumnInt(0) item.ID = stmt.ColumnInt(0)
item.Type = stmt.ColumnInt(1) item.Type = stmt.ColumnInt(1)
item.Name = stmt.ColumnText(2) item.Name = stmt.ColumnText(2)
@ -54,11 +52,11 @@ func Find(db *database.DB, id int) (*Item, error) {
} }
// All retrieves all items // All retrieves all items
func All(db *database.DB) ([]*Item, error) { func All() ([]*Item, error) {
var items []*Item var items []*Item
query := "SELECT id, type, name, value, att, special FROM items ORDER BY id" query := "SELECT id, type, name, value, att, special FROM items ORDER BY id"
err := db.Query(query, func(stmt *sqlite.Stmt) error { err := database.Query(query, func(stmt *sqlite.Stmt) error {
item := &Item{ item := &Item{
ID: stmt.ColumnInt(0), ID: stmt.ColumnInt(0),
Type: stmt.ColumnInt(1), Type: stmt.ColumnInt(1),
@ -66,7 +64,6 @@ func All(db *database.DB) ([]*Item, error) {
Value: stmt.ColumnInt(3), Value: stmt.ColumnInt(3),
Att: stmt.ColumnInt(4), Att: stmt.ColumnInt(4),
Special: stmt.ColumnText(5), Special: stmt.ColumnText(5),
db: db,
} }
items = append(items, item) items = append(items, item)
return nil return nil
@ -80,11 +77,11 @@ func All(db *database.DB) ([]*Item, error) {
} }
// ByType retrieves items by type // ByType retrieves items by type
func ByType(db *database.DB, itemType int) ([]*Item, error) { func ByType(itemType int) ([]*Item, error) {
var items []*Item var items []*Item
query := "SELECT id, type, name, value, att, special FROM items WHERE type = ? ORDER BY id" query := "SELECT id, type, name, value, att, special FROM items WHERE type = ? ORDER BY id"
err := db.Query(query, func(stmt *sqlite.Stmt) error { err := database.Query(query, func(stmt *sqlite.Stmt) error {
item := &Item{ item := &Item{
ID: stmt.ColumnInt(0), ID: stmt.ColumnInt(0),
Type: stmt.ColumnInt(1), Type: stmt.ColumnInt(1),
@ -92,7 +89,6 @@ func ByType(db *database.DB, itemType int) ([]*Item, error) {
Value: stmt.ColumnInt(3), Value: stmt.ColumnInt(3),
Att: stmt.ColumnInt(4), Att: stmt.ColumnInt(4),
Special: stmt.ColumnText(5), Special: stmt.ColumnText(5),
db: db,
} }
items = append(items, item) items = append(items, item)
return nil return nil
@ -112,7 +108,7 @@ func (i *Item) Save() error {
} }
query := `UPDATE items SET type = ?, name = ?, value = ?, att = ?, special = ? WHERE id = ?` query := `UPDATE items SET type = ?, name = ?, value = ?, att = ?, special = ? WHERE id = ?`
return i.db.Exec(query, i.Type, i.Name, i.Value, i.Att, i.Special, i.ID) return database.Exec(query, i.Type, i.Name, i.Value, i.Att, i.Special, i.ID)
} }
// Delete removes the item from the database // Delete removes the item from the database
@ -122,7 +118,7 @@ func (i *Item) Delete() error {
} }
query := "DELETE FROM items WHERE id = ?" query := "DELETE FROM items WHERE id = ?"
return i.db.Exec(query, i.ID) return database.Exec(query, i.ID)
} }
// IsWeapon returns true if the item is a weapon // IsWeapon returns true if the item is a weapon

View File

@ -10,14 +10,12 @@ import (
// Builder provides a fluent interface for creating monsters // Builder provides a fluent interface for creating monsters
type Builder struct { type Builder struct {
monster *Monster monster *Monster
db *database.DB
} }
// NewBuilder creates a new monster builder // NewBuilder creates a new monster builder
func NewBuilder(db *database.DB) *Builder { func NewBuilder() *Builder {
return &Builder{ return &Builder{
monster: &Monster{db: db}, monster: &Monster{},
db: db,
} }
} }
@ -73,8 +71,8 @@ func (b *Builder) WithImmunity(immunity int) *Builder {
func (b *Builder) Create() (*Monster, error) { func (b *Builder) Create() (*Monster, error) {
// Use a transaction to ensure we can get the ID // Use a transaction to ensure we can get the ID
var monster *Monster var monster *Monster
err := b.db.Transaction(func(tx *database.Tx) error { err := database.Transaction(func(tx *database.Tx) error {
query := `INSERT INTO monsters (name, max_hp, max_dmg, armor, level, max_exp, max_gold, immune) query := `INSERT INTO monsters (name, max_hp, max_dmg, armor, level, max_exp, max_gold, immune)
VALUES (?, ?, ?, ?, ?, ?, ?, ?)` VALUES (?, ?, ?, ?, ?, ?, ?, ?)`
if err := tx.Exec(query, b.monster.Name, b.monster.MaxHP, b.monster.MaxDmg, b.monster.Armor, if err := tx.Exec(query, b.monster.Name, b.monster.MaxHP, b.monster.MaxDmg, b.monster.Armor,
@ -103,7 +101,6 @@ func (b *Builder) Create() (*Monster, error) {
MaxExp: b.monster.MaxExp, MaxExp: b.monster.MaxExp,
MaxGold: b.monster.MaxGold, MaxGold: b.monster.MaxGold,
Immune: b.monster.Immune, Immune: b.monster.Immune,
db: b.db,
} }
return nil return nil

View File

@ -19,8 +19,6 @@ type Monster struct {
MaxExp int `json:"max_exp"` MaxExp int `json:"max_exp"`
MaxGold int `json:"max_gold"` MaxGold int `json:"max_gold"`
Immune int `json:"immune"` Immune int `json:"immune"`
db *database.DB
} }
// Immunity constants for monster immunity types // Immunity constants for monster immunity types
@ -31,11 +29,11 @@ const (
) )
// Find retrieves a monster by ID // Find retrieves a monster by ID
func Find(db *database.DB, id int) (*Monster, error) { func Find(id int) (*Monster, error) {
monster := &Monster{db: db} monster := &Monster{}
query := "SELECT id, name, max_hp, max_dmg, armor, level, max_exp, max_gold, immune FROM monsters WHERE id = ?" query := "SELECT id, name, max_hp, max_dmg, armor, level, max_exp, max_gold, immune FROM monsters WHERE id = ?"
err := db.Query(query, func(stmt *sqlite.Stmt) error { err := database.Query(query, func(stmt *sqlite.Stmt) error {
monster.ID = stmt.ColumnInt(0) monster.ID = stmt.ColumnInt(0)
monster.Name = stmt.ColumnText(1) monster.Name = stmt.ColumnText(1)
monster.MaxHP = stmt.ColumnInt(2) monster.MaxHP = stmt.ColumnInt(2)
@ -60,11 +58,11 @@ func Find(db *database.DB, id int) (*Monster, error) {
} }
// All retrieves all monsters // All retrieves all monsters
func All(db *database.DB) ([]*Monster, error) { func All() ([]*Monster, error) {
var monsters []*Monster var monsters []*Monster
query := "SELECT id, name, max_hp, max_dmg, armor, level, max_exp, max_gold, immune FROM monsters ORDER BY level, id" query := "SELECT id, name, max_hp, max_dmg, armor, level, max_exp, max_gold, immune FROM monsters ORDER BY level, id"
err := db.Query(query, func(stmt *sqlite.Stmt) error { err := database.Query(query, func(stmt *sqlite.Stmt) error {
monster := &Monster{ monster := &Monster{
ID: stmt.ColumnInt(0), ID: stmt.ColumnInt(0),
Name: stmt.ColumnText(1), Name: stmt.ColumnText(1),
@ -75,7 +73,6 @@ func All(db *database.DB) ([]*Monster, error) {
MaxExp: stmt.ColumnInt(6), MaxExp: stmt.ColumnInt(6),
MaxGold: stmt.ColumnInt(7), MaxGold: stmt.ColumnInt(7),
Immune: stmt.ColumnInt(8), Immune: stmt.ColumnInt(8),
db: db,
} }
monsters = append(monsters, monster) monsters = append(monsters, monster)
return nil return nil
@ -89,11 +86,11 @@ func All(db *database.DB) ([]*Monster, error) {
} }
// ByLevel retrieves monsters by level // ByLevel retrieves monsters by level
func ByLevel(db *database.DB, level int) ([]*Monster, error) { func ByLevel(level int) ([]*Monster, error) {
var monsters []*Monster var monsters []*Monster
query := "SELECT id, name, max_hp, max_dmg, armor, level, max_exp, max_gold, immune FROM monsters WHERE level = ? ORDER BY id" query := "SELECT id, name, max_hp, max_dmg, armor, level, max_exp, max_gold, immune FROM monsters WHERE level = ? ORDER BY id"
err := db.Query(query, func(stmt *sqlite.Stmt) error { err := database.Query(query, func(stmt *sqlite.Stmt) error {
monster := &Monster{ monster := &Monster{
ID: stmt.ColumnInt(0), ID: stmt.ColumnInt(0),
Name: stmt.ColumnText(1), Name: stmt.ColumnText(1),
@ -104,7 +101,6 @@ func ByLevel(db *database.DB, level int) ([]*Monster, error) {
MaxExp: stmt.ColumnInt(6), MaxExp: stmt.ColumnInt(6),
MaxGold: stmt.ColumnInt(7), MaxGold: stmt.ColumnInt(7),
Immune: stmt.ColumnInt(8), Immune: stmt.ColumnInt(8),
db: db,
} }
monsters = append(monsters, monster) monsters = append(monsters, monster)
return nil return nil
@ -118,11 +114,11 @@ func ByLevel(db *database.DB, level int) ([]*Monster, error) {
} }
// ByLevelRange retrieves monsters within a level range (inclusive) // ByLevelRange retrieves monsters within a level range (inclusive)
func ByLevelRange(db *database.DB, minLevel, maxLevel int) ([]*Monster, error) { func ByLevelRange(minLevel, maxLevel int) ([]*Monster, error) {
var monsters []*Monster var monsters []*Monster
query := "SELECT id, name, max_hp, max_dmg, armor, level, max_exp, max_gold, immune FROM monsters WHERE level BETWEEN ? AND ? ORDER BY level, id" query := "SELECT id, name, max_hp, max_dmg, armor, level, max_exp, max_gold, immune FROM monsters WHERE level BETWEEN ? AND ? ORDER BY level, id"
err := db.Query(query, func(stmt *sqlite.Stmt) error { err := database.Query(query, func(stmt *sqlite.Stmt) error {
monster := &Monster{ monster := &Monster{
ID: stmt.ColumnInt(0), ID: stmt.ColumnInt(0),
Name: stmt.ColumnText(1), Name: stmt.ColumnText(1),
@ -133,7 +129,6 @@ func ByLevelRange(db *database.DB, minLevel, maxLevel int) ([]*Monster, error) {
MaxExp: stmt.ColumnInt(6), MaxExp: stmt.ColumnInt(6),
MaxGold: stmt.ColumnInt(7), MaxGold: stmt.ColumnInt(7),
Immune: stmt.ColumnInt(8), Immune: stmt.ColumnInt(8),
db: db,
} }
monsters = append(monsters, monster) monsters = append(monsters, monster)
return nil return nil
@ -147,11 +142,11 @@ func ByLevelRange(db *database.DB, minLevel, maxLevel int) ([]*Monster, error) {
} }
// ByImmunity retrieves monsters by immunity type // ByImmunity retrieves monsters by immunity type
func ByImmunity(db *database.DB, immunityType int) ([]*Monster, error) { func ByImmunity(immunityType int) ([]*Monster, error) {
var monsters []*Monster var monsters []*Monster
query := "SELECT id, name, max_hp, max_dmg, armor, level, max_exp, max_gold, immune FROM monsters WHERE immune = ? ORDER BY level, id" query := "SELECT id, name, max_hp, max_dmg, armor, level, max_exp, max_gold, immune FROM monsters WHERE immune = ? ORDER BY level, id"
err := db.Query(query, func(stmt *sqlite.Stmt) error { err := database.Query(query, func(stmt *sqlite.Stmt) error {
monster := &Monster{ monster := &Monster{
ID: stmt.ColumnInt(0), ID: stmt.ColumnInt(0),
Name: stmt.ColumnText(1), Name: stmt.ColumnText(1),
@ -162,7 +157,6 @@ func ByImmunity(db *database.DB, immunityType int) ([]*Monster, error) {
MaxExp: stmt.ColumnInt(6), MaxExp: stmt.ColumnInt(6),
MaxGold: stmt.ColumnInt(7), MaxGold: stmt.ColumnInt(7),
Immune: stmt.ColumnInt(8), Immune: stmt.ColumnInt(8),
db: db,
} }
monsters = append(monsters, monster) monsters = append(monsters, monster)
return nil return nil
@ -182,7 +176,7 @@ func (m *Monster) Save() error {
} }
query := `UPDATE monsters SET name = ?, max_hp = ?, max_dmg = ?, armor = ?, level = ?, max_exp = ?, max_gold = ?, immune = ? WHERE id = ?` query := `UPDATE monsters SET name = ?, max_hp = ?, max_dmg = ?, armor = ?, level = ?, max_exp = ?, max_gold = ?, immune = ? WHERE id = ?`
return m.db.Exec(query, m.Name, m.MaxHP, m.MaxDmg, m.Armor, m.Level, m.MaxExp, m.MaxGold, m.Immune, m.ID) return database.Exec(query, m.Name, m.MaxHP, m.MaxDmg, m.Armor, m.Level, m.MaxExp, m.MaxGold, m.Immune, m.ID)
} }
// Delete removes the monster from the database // Delete removes the monster from the database
@ -192,7 +186,7 @@ func (m *Monster) Delete() error {
} }
query := "DELETE FROM monsters WHERE id = ?" query := "DELETE FROM monsters WHERE id = ?"
return m.db.Exec(query, m.ID) return database.Exec(query, m.ID)
} }
// IsHurtImmune returns true if the monster is immune to Hurt spells // IsHurtImmune returns true if the monster is immune to Hurt spells

View File

@ -12,17 +12,14 @@ import (
// Builder provides a fluent interface for creating news posts // Builder provides a fluent interface for creating news posts
type Builder struct { type Builder struct {
news *News news *News
db *database.DB
} }
// NewBuilder creates a new news builder // NewBuilder creates a new news builder
func NewBuilder(db *database.DB) *Builder { func NewBuilder() *Builder {
return &Builder{ return &Builder{
news: &News{ news: &News{
db: db,
Posted: time.Now().Unix(), // Default to current time Posted: time.Now().Unix(), // Default to current time
}, },
db: db,
} }
} }
@ -54,10 +51,10 @@ func (b *Builder) WithPostedTime(t time.Time) *Builder {
func (b *Builder) Create() (*News, error) { func (b *Builder) Create() (*News, error) {
// Use a transaction to ensure we can get the ID // Use a transaction to ensure we can get the ID
var news *News var news *News
err := b.db.Transaction(func(tx *database.Tx) error { err := database.Transaction(func(tx *database.Tx) error {
query := `INSERT INTO news (author, posted, content) query := `INSERT INTO news (author, posted, content)
VALUES (?, ?, ?)` VALUES (?, ?, ?)`
if err := tx.Exec(query, b.news.Author, b.news.Posted, b.news.Content); err != nil { if err := tx.Exec(query, b.news.Author, b.news.Posted, b.news.Content); err != nil {
return fmt.Errorf("failed to insert news: %w", err) return fmt.Errorf("failed to insert news: %w", err)
} }
@ -76,10 +73,10 @@ func (b *Builder) Create() (*News, error) {
news = b.news news = b.news
return nil return nil
}) })
if err != nil { if err != nil {
return nil, err return nil, err
} }
return news, nil return news, nil
} }

View File

@ -1,11 +1,10 @@
package news package news
import ( import (
"dk/internal/database"
"fmt" "fmt"
"time" "time"
"dk/internal/database"
"zombiezen.com/go/sqlite" "zombiezen.com/go/sqlite"
) )
@ -15,16 +14,14 @@ type News struct {
Author int `json:"author"` Author int `json:"author"`
Posted int64 `json:"posted"` Posted int64 `json:"posted"`
Content string `json:"content"` Content string `json:"content"`
db *database.DB
} }
// Find retrieves a news post by ID // Find retrieves a news post by ID
func Find(db *database.DB, id int) (*News, error) { func Find(id int) (*News, error) {
news := &News{db: db} news := &News{}
query := "SELECT id, author, posted, content FROM news WHERE id = ?" query := "SELECT id, author, posted, content FROM news WHERE id = ?"
err := db.Query(query, func(stmt *sqlite.Stmt) error { err := database.Query(query, func(stmt *sqlite.Stmt) error {
news.ID = stmt.ColumnInt(0) news.ID = stmt.ColumnInt(0)
news.Author = stmt.ColumnInt(1) news.Author = stmt.ColumnInt(1)
news.Posted = stmt.ColumnInt64(2) news.Posted = stmt.ColumnInt64(2)
@ -44,17 +41,16 @@ func Find(db *database.DB, id int) (*News, error) {
} }
// All retrieves all news posts ordered by posted date (newest first) // All retrieves all news posts ordered by posted date (newest first)
func All(db *database.DB) ([]*News, error) { func All() ([]*News, error) {
var newsPosts []*News var newsPosts []*News
query := "SELECT id, author, posted, content FROM news ORDER BY posted DESC, id DESC" query := "SELECT id, author, posted, content FROM news ORDER BY posted DESC, id DESC"
err := db.Query(query, func(stmt *sqlite.Stmt) error { err := database.Query(query, func(stmt *sqlite.Stmt) error {
news := &News{ news := &News{
ID: stmt.ColumnInt(0), ID: stmt.ColumnInt(0),
Author: stmt.ColumnInt(1), Author: stmt.ColumnInt(1),
Posted: stmt.ColumnInt64(2), Posted: stmt.ColumnInt64(2),
Content: stmt.ColumnText(3), Content: stmt.ColumnText(3),
db: db,
} }
newsPosts = append(newsPosts, news) newsPosts = append(newsPosts, news)
return nil return nil
@ -68,17 +64,16 @@ func All(db *database.DB) ([]*News, error) {
} }
// ByAuthor retrieves news posts by a specific author // ByAuthor retrieves news posts by a specific author
func ByAuthor(db *database.DB, authorID int) ([]*News, error) { func ByAuthor(authorID int) ([]*News, error) {
var newsPosts []*News var newsPosts []*News
query := "SELECT id, author, posted, content FROM news WHERE author = ? ORDER BY posted DESC, id DESC" query := "SELECT id, author, posted, content FROM news WHERE author = ? ORDER BY posted DESC, id DESC"
err := db.Query(query, func(stmt *sqlite.Stmt) error { err := database.Query(query, func(stmt *sqlite.Stmt) error {
news := &News{ news := &News{
ID: stmt.ColumnInt(0), ID: stmt.ColumnInt(0),
Author: stmt.ColumnInt(1), Author: stmt.ColumnInt(1),
Posted: stmt.ColumnInt64(2), Posted: stmt.ColumnInt64(2),
Content: stmt.ColumnText(3), Content: stmt.ColumnText(3),
db: db,
} }
newsPosts = append(newsPosts, news) newsPosts = append(newsPosts, news)
return nil return nil
@ -92,17 +87,16 @@ func ByAuthor(db *database.DB, authorID int) ([]*News, error) {
} }
// Recent retrieves the most recent news posts (limited by count) // Recent retrieves the most recent news posts (limited by count)
func Recent(db *database.DB, limit int) ([]*News, error) { func Recent(limit int) ([]*News, error) {
var newsPosts []*News var newsPosts []*News
query := "SELECT id, author, posted, content FROM news ORDER BY posted DESC, id DESC LIMIT ?" query := "SELECT id, author, posted, content FROM news ORDER BY posted DESC, id DESC LIMIT ?"
err := db.Query(query, func(stmt *sqlite.Stmt) error { err := database.Query(query, func(stmt *sqlite.Stmt) error {
news := &News{ news := &News{
ID: stmt.ColumnInt(0), ID: stmt.ColumnInt(0),
Author: stmt.ColumnInt(1), Author: stmt.ColumnInt(1),
Posted: stmt.ColumnInt64(2), Posted: stmt.ColumnInt64(2),
Content: stmt.ColumnText(3), Content: stmt.ColumnText(3),
db: db,
} }
newsPosts = append(newsPosts, news) newsPosts = append(newsPosts, news)
return nil return nil
@ -116,17 +110,16 @@ func Recent(db *database.DB, limit int) ([]*News, error) {
} }
// Since retrieves news posts since a specific timestamp // Since retrieves news posts since a specific timestamp
func Since(db *database.DB, since int64) ([]*News, error) { func Since(since int64) ([]*News, error) {
var newsPosts []*News var newsPosts []*News
query := "SELECT id, author, posted, content FROM news WHERE posted >= ? ORDER BY posted DESC, id DESC" query := "SELECT id, author, posted, content FROM news WHERE posted >= ? ORDER BY posted DESC, id DESC"
err := db.Query(query, func(stmt *sqlite.Stmt) error { err := database.Query(query, func(stmt *sqlite.Stmt) error {
news := &News{ news := &News{
ID: stmt.ColumnInt(0), ID: stmt.ColumnInt(0),
Author: stmt.ColumnInt(1), Author: stmt.ColumnInt(1),
Posted: stmt.ColumnInt64(2), Posted: stmt.ColumnInt64(2),
Content: stmt.ColumnText(3), Content: stmt.ColumnText(3),
db: db,
} }
newsPosts = append(newsPosts, news) newsPosts = append(newsPosts, news)
return nil return nil
@ -140,17 +133,16 @@ func Since(db *database.DB, since int64) ([]*News, error) {
} }
// Between retrieves news posts between two timestamps (inclusive) // Between retrieves news posts between two timestamps (inclusive)
func Between(db *database.DB, start, end int64) ([]*News, error) { func Between(start, end int64) ([]*News, error) {
var newsPosts []*News var newsPosts []*News
query := "SELECT id, author, posted, content FROM news WHERE posted >= ? AND posted <= ? ORDER BY posted DESC, id DESC" query := "SELECT id, author, posted, content FROM news WHERE posted >= ? AND posted <= ? ORDER BY posted DESC, id DESC"
err := db.Query(query, func(stmt *sqlite.Stmt) error { err := database.Query(query, func(stmt *sqlite.Stmt) error {
news := &News{ news := &News{
ID: stmt.ColumnInt(0), ID: stmt.ColumnInt(0),
Author: stmt.ColumnInt(1), Author: stmt.ColumnInt(1),
Posted: stmt.ColumnInt64(2), Posted: stmt.ColumnInt64(2),
Content: stmt.ColumnText(3), Content: stmt.ColumnText(3),
db: db,
} }
newsPosts = append(newsPosts, news) newsPosts = append(newsPosts, news)
return nil return nil
@ -170,7 +162,7 @@ func (n *News) Save() error {
} }
query := `UPDATE news SET author = ?, posted = ?, content = ? WHERE id = ?` query := `UPDATE news SET author = ?, posted = ?, content = ? WHERE id = ?`
return n.db.Exec(query, n.Author, n.Posted, n.Content, n.ID) return database.Exec(query, n.Author, n.Posted, n.Content, n.ID)
} }
// Delete removes the news post from the database // Delete removes the news post from the database
@ -180,7 +172,7 @@ func (n *News) Delete() error {
} }
query := "DELETE FROM news WHERE id = ?" query := "DELETE FROM news WHERE id = ?"
return n.db.Exec(query, n.ID) return database.Exec(query, n.ID)
} }
// PostedTime returns the posted timestamp as a time.Time // PostedTime returns the posted timestamp as a time.Time
@ -213,11 +205,11 @@ func (n *News) Preview(maxLength int) string {
if len(n.Content) <= maxLength { if len(n.Content) <= maxLength {
return n.Content return n.Content
} }
if maxLength < 3 { if maxLength < 3 {
return n.Content[:maxLength] return n.Content[:maxLength]
} }
return n.Content[:maxLength-3] + "..." return n.Content[:maxLength-3] + "..."
} }
@ -226,11 +218,11 @@ func (n *News) WordCount() int {
if n.Content == "" { if n.Content == "" {
return 0 return 0
} }
// Simple word count by splitting on spaces // Simple word count by splitting on spaces
words := 0 words := 0
inWord := false inWord := false
for _, char := range n.Content { for _, char := range n.Content {
if char == ' ' || char == '\t' || char == '\n' || char == '\r' { if char == ' ' || char == '\t' || char == '\n' || char == '\r' {
if inWord { if inWord {
@ -241,10 +233,10 @@ func (n *News) WordCount() int {
inWord = true inWord = true
} }
} }
if inWord { if inWord {
words++ words++
} }
return words return words
} }

View File

@ -28,18 +28,14 @@ func Start(port string) error {
template.InitializeCache(cwd) template.InitializeCache(cwd)
// Initialize database singleton // Initialize database singleton
if err := database.InitializeDB("dk.db"); err != nil { if err := database.Init("dk.db"); err != nil {
return fmt.Errorf("failed to initialize database: %w", err) return fmt.Errorf("failed to initialize database: %w", err)
} }
defer database.Close() defer database.Close()
// Initialize auth singleton auth.Init("sessions.json") // Initialize auth.Manager
auth.InitializeManager(database.GetDB(), "sessions.json")
// Initialize router
r := router.New() r := router.New()
// Add middleware
r.Use(middleware.Timing()) r.Use(middleware.Timing())
r.Use(middleware.Auth(auth.Manager)) r.Use(middleware.Auth(auth.Manager))
r.Use(middleware.CSRF(auth.Manager)) r.Use(middleware.CSRF(auth.Manager))

View File

@ -10,14 +10,12 @@ import (
// Builder provides a fluent interface for creating spells // Builder provides a fluent interface for creating spells
type Builder struct { type Builder struct {
spell *Spell spell *Spell
db *database.DB
} }
// NewBuilder creates a new spell builder // NewBuilder creates a new spell builder
func NewBuilder(db *database.DB) *Builder { func NewBuilder() *Builder {
return &Builder{ return &Builder{
spell: &Spell{db: db}, spell: &Spell{},
db: db,
} }
} }
@ -49,8 +47,8 @@ func (b *Builder) WithType(spellType int) *Builder {
func (b *Builder) Create() (*Spell, error) { func (b *Builder) Create() (*Spell, error) {
// Use a transaction to ensure we can get the ID // Use a transaction to ensure we can get the ID
var spell *Spell var spell *Spell
err := b.db.Transaction(func(tx *database.Tx) error { err := database.Transaction(func(tx *database.Tx) error {
query := `INSERT INTO spells (name, mp, attribute, type) query := `INSERT INTO spells (name, mp, attribute, type)
VALUES (?, ?, ?, ?)` VALUES (?, ?, ?, ?)`
if err := tx.Exec(query, b.spell.Name, b.spell.MP, b.spell.Attribute, b.spell.Type); err != nil { if err := tx.Exec(query, b.spell.Name, b.spell.MP, b.spell.Attribute, b.spell.Type); err != nil {
@ -74,7 +72,6 @@ func (b *Builder) Create() (*Spell, error) {
MP: b.spell.MP, MP: b.spell.MP,
Attribute: b.spell.Attribute, Attribute: b.spell.Attribute,
Type: b.spell.Type, Type: b.spell.Type,
db: b.db,
} }
return nil return nil

View File

@ -15,8 +15,6 @@ type Spell struct {
MP int `json:"mp"` MP int `json:"mp"`
Attribute int `json:"attribute"` Attribute int `json:"attribute"`
Type int `json:"type"` Type int `json:"type"`
db *database.DB
} }
// SpellType constants for spell types // SpellType constants for spell types
@ -29,11 +27,11 @@ const (
) )
// Find retrieves a spell by ID // Find retrieves a spell by ID
func Find(db *database.DB, id int) (*Spell, error) { func Find(id int) (*Spell, error) {
spell := &Spell{db: db} spell := &Spell{}
query := "SELECT id, name, mp, attribute, type FROM spells WHERE id = ?" query := "SELECT id, name, mp, attribute, type FROM spells WHERE id = ?"
err := db.Query(query, func(stmt *sqlite.Stmt) error { err := database.Query(query, func(stmt *sqlite.Stmt) error {
spell.ID = stmt.ColumnInt(0) spell.ID = stmt.ColumnInt(0)
spell.Name = stmt.ColumnText(1) spell.Name = stmt.ColumnText(1)
spell.MP = stmt.ColumnInt(2) spell.MP = stmt.ColumnInt(2)
@ -54,18 +52,17 @@ func Find(db *database.DB, id int) (*Spell, error) {
} }
// All retrieves all spells // All retrieves all spells
func All(db *database.DB) ([]*Spell, error) { func All() ([]*Spell, error) {
var spells []*Spell var spells []*Spell
query := "SELECT id, name, mp, attribute, type FROM spells ORDER BY type, mp, id" query := "SELECT id, name, mp, attribute, type FROM spells ORDER BY type, mp, id"
err := db.Query(query, func(stmt *sqlite.Stmt) error { err := database.Query(query, func(stmt *sqlite.Stmt) error {
spell := &Spell{ spell := &Spell{
ID: stmt.ColumnInt(0), ID: stmt.ColumnInt(0),
Name: stmt.ColumnText(1), Name: stmt.ColumnText(1),
MP: stmt.ColumnInt(2), MP: stmt.ColumnInt(2),
Attribute: stmt.ColumnInt(3), Attribute: stmt.ColumnInt(3),
Type: stmt.ColumnInt(4), Type: stmt.ColumnInt(4),
db: db,
} }
spells = append(spells, spell) spells = append(spells, spell)
return nil return nil
@ -79,18 +76,17 @@ func All(db *database.DB) ([]*Spell, error) {
} }
// ByType retrieves spells by type // ByType retrieves spells by type
func ByType(db *database.DB, spellType int) ([]*Spell, error) { func ByType(spellType int) ([]*Spell, error) {
var spells []*Spell var spells []*Spell
query := "SELECT id, name, mp, attribute, type FROM spells WHERE type = ? ORDER BY mp, id" query := "SELECT id, name, mp, attribute, type FROM spells WHERE type = ? ORDER BY mp, id"
err := db.Query(query, func(stmt *sqlite.Stmt) error { err := database.Query(query, func(stmt *sqlite.Stmt) error {
spell := &Spell{ spell := &Spell{
ID: stmt.ColumnInt(0), ID: stmt.ColumnInt(0),
Name: stmt.ColumnText(1), Name: stmt.ColumnText(1),
MP: stmt.ColumnInt(2), MP: stmt.ColumnInt(2),
Attribute: stmt.ColumnInt(3), Attribute: stmt.ColumnInt(3),
Type: stmt.ColumnInt(4), Type: stmt.ColumnInt(4),
db: db,
} }
spells = append(spells, spell) spells = append(spells, spell)
return nil return nil
@ -104,18 +100,17 @@ func ByType(db *database.DB, spellType int) ([]*Spell, error) {
} }
// ByMaxMP retrieves spells that cost at most the specified MP // ByMaxMP retrieves spells that cost at most the specified MP
func ByMaxMP(db *database.DB, maxMP int) ([]*Spell, error) { func ByMaxMP(maxMP int) ([]*Spell, error) {
var spells []*Spell var spells []*Spell
query := "SELECT id, name, mp, attribute, type FROM spells WHERE mp <= ? ORDER BY type, mp, id" query := "SELECT id, name, mp, attribute, type FROM spells WHERE mp <= ? ORDER BY type, mp, id"
err := db.Query(query, func(stmt *sqlite.Stmt) error { err := database.Query(query, func(stmt *sqlite.Stmt) error {
spell := &Spell{ spell := &Spell{
ID: stmt.ColumnInt(0), ID: stmt.ColumnInt(0),
Name: stmt.ColumnText(1), Name: stmt.ColumnText(1),
MP: stmt.ColumnInt(2), MP: stmt.ColumnInt(2),
Attribute: stmt.ColumnInt(3), Attribute: stmt.ColumnInt(3),
Type: stmt.ColumnInt(4), Type: stmt.ColumnInt(4),
db: db,
} }
spells = append(spells, spell) spells = append(spells, spell)
return nil return nil
@ -129,18 +124,17 @@ func ByMaxMP(db *database.DB, maxMP int) ([]*Spell, error) {
} }
// ByTypeAndMaxMP retrieves spells of a specific type that cost at most the specified MP // ByTypeAndMaxMP retrieves spells of a specific type that cost at most the specified MP
func ByTypeAndMaxMP(db *database.DB, spellType, maxMP int) ([]*Spell, error) { func ByTypeAndMaxMP(spellType, maxMP int) ([]*Spell, error) {
var spells []*Spell var spells []*Spell
query := "SELECT id, name, mp, attribute, type FROM spells WHERE type = ? AND mp <= ? ORDER BY mp, id" query := "SELECT id, name, mp, attribute, type FROM spells WHERE type = ? AND mp <= ? ORDER BY mp, id"
err := db.Query(query, func(stmt *sqlite.Stmt) error { err := database.Query(query, func(stmt *sqlite.Stmt) error {
spell := &Spell{ spell := &Spell{
ID: stmt.ColumnInt(0), ID: stmt.ColumnInt(0),
Name: stmt.ColumnText(1), Name: stmt.ColumnText(1),
MP: stmt.ColumnInt(2), MP: stmt.ColumnInt(2),
Attribute: stmt.ColumnInt(3), Attribute: stmt.ColumnInt(3),
Type: stmt.ColumnInt(4), Type: stmt.ColumnInt(4),
db: db,
} }
spells = append(spells, spell) spells = append(spells, spell)
return nil return nil
@ -154,11 +148,11 @@ func ByTypeAndMaxMP(db *database.DB, spellType, maxMP int) ([]*Spell, error) {
} }
// ByName retrieves a spell by name (case-insensitive) // ByName retrieves a spell by name (case-insensitive)
func ByName(db *database.DB, name string) (*Spell, error) { func ByName(name string) (*Spell, error) {
spell := &Spell{db: db} spell := &Spell{}
query := "SELECT id, name, mp, attribute, type FROM spells WHERE LOWER(name) = LOWER(?) LIMIT 1" query := "SELECT id, name, mp, attribute, type FROM spells WHERE LOWER(name) = LOWER(?) LIMIT 1"
err := db.Query(query, func(stmt *sqlite.Stmt) error { err := database.Query(query, func(stmt *sqlite.Stmt) error {
spell.ID = stmt.ColumnInt(0) spell.ID = stmt.ColumnInt(0)
spell.Name = stmt.ColumnText(1) spell.Name = stmt.ColumnText(1)
spell.MP = stmt.ColumnInt(2) spell.MP = stmt.ColumnInt(2)
@ -185,7 +179,7 @@ func (s *Spell) Save() error {
} }
query := `UPDATE spells SET name = ?, mp = ?, attribute = ?, type = ? WHERE id = ?` query := `UPDATE spells SET name = ?, mp = ?, attribute = ?, type = ? WHERE id = ?`
return s.db.Exec(query, s.Name, s.MP, s.Attribute, s.Type, s.ID) return database.Exec(query, s.Name, s.MP, s.Attribute, s.Type, s.ID)
} }
// Delete removes the spell from the database // Delete removes the spell from the database
@ -195,7 +189,7 @@ func (s *Spell) Delete() error {
} }
query := "DELETE FROM spells WHERE id = ?" query := "DELETE FROM spells WHERE id = ?"
return s.db.Exec(query, s.ID) return database.Exec(query, s.ID)
} }
// IsHealing returns true if the spell is a healing spell // IsHealing returns true if the spell is a healing spell

View File

@ -11,16 +11,12 @@ import (
// Builder provides a fluent interface for creating towns // Builder provides a fluent interface for creating towns
type Builder struct { type Builder struct {
town *Town town *Town
db *database.DB
} }
// NewBuilder creates a new town builder // NewBuilder creates a new town builder
func NewBuilder(db *database.DB) *Builder { func NewBuilder() *Builder {
return &Builder{ return &Builder{
town: &Town{ town: &Town{},
db: db,
},
db: db,
} }
} }
@ -83,11 +79,11 @@ func (b *Builder) WithShopItems(items []string) *Builder {
func (b *Builder) Create() (*Town, error) { func (b *Builder) Create() (*Town, error) {
// Use a transaction to ensure we can get the ID // Use a transaction to ensure we can get the ID
var town *Town var town *Town
err := b.db.Transaction(func(tx *database.Tx) error { err := database.Transaction(func(tx *database.Tx) error {
query := `INSERT INTO towns (name, x, y, inn_cost, map_cost, tp_cost, shop_list) query := `INSERT INTO towns (name, x, y, inn_cost, map_cost, tp_cost, shop_list)
VALUES (?, ?, ?, ?, ?, ?, ?)` VALUES (?, ?, ?, ?, ?, ?, ?)`
if err := tx.Exec(query, b.town.Name, b.town.X, b.town.Y, if err := tx.Exec(query, b.town.Name, b.town.X, b.town.Y,
b.town.InnCost, b.town.MapCost, b.town.TPCost, b.town.ShopList); err != nil { b.town.InnCost, b.town.MapCost, b.town.TPCost, b.town.ShopList); err != nil {
return fmt.Errorf("failed to insert town: %w", err) return fmt.Errorf("failed to insert town: %w", err)
} }
@ -106,10 +102,10 @@ func (b *Builder) Create() (*Town, error) {
town = b.town town = b.town
return nil return nil
}) })
if err != nil { if err != nil {
return nil, err return nil, err
} }
return town, nil return town, nil
} }

View File

@ -19,16 +19,14 @@ type Town struct {
MapCost int `json:"map_cost"` MapCost int `json:"map_cost"`
TPCost int `json:"tp_cost"` TPCost int `json:"tp_cost"`
ShopList string `json:"shop_list"` ShopList string `json:"shop_list"`
db *database.DB
} }
// Find retrieves a town by ID // Find retrieves a town by ID
func Find(db *database.DB, id int) (*Town, error) { func Find(id int) (*Town, error) {
town := &Town{db: db} town := &Town{}
query := "SELECT id, name, x, y, inn_cost, map_cost, tp_cost, shop_list FROM towns WHERE id = ?" query := "SELECT id, name, x, y, inn_cost, map_cost, tp_cost, shop_list FROM towns WHERE id = ?"
err := db.Query(query, func(stmt *sqlite.Stmt) error { err := database.Query(query, func(stmt *sqlite.Stmt) error {
town.ID = stmt.ColumnInt(0) town.ID = stmt.ColumnInt(0)
town.Name = stmt.ColumnText(1) town.Name = stmt.ColumnText(1)
town.X = stmt.ColumnInt(2) town.X = stmt.ColumnInt(2)
@ -52,11 +50,11 @@ func Find(db *database.DB, id int) (*Town, error) {
} }
// All retrieves all towns // All retrieves all towns
func All(db *database.DB) ([]*Town, error) { func All() ([]*Town, error) {
var towns []*Town var towns []*Town
query := "SELECT id, name, x, y, inn_cost, map_cost, tp_cost, shop_list FROM towns ORDER BY id" query := "SELECT id, name, x, y, inn_cost, map_cost, tp_cost, shop_list FROM towns ORDER BY id"
err := db.Query(query, func(stmt *sqlite.Stmt) error { err := database.Query(query, func(stmt *sqlite.Stmt) error {
town := &Town{ town := &Town{
ID: stmt.ColumnInt(0), ID: stmt.ColumnInt(0),
Name: stmt.ColumnText(1), Name: stmt.ColumnText(1),
@ -66,7 +64,6 @@ func All(db *database.DB) ([]*Town, error) {
MapCost: stmt.ColumnInt(5), MapCost: stmt.ColumnInt(5),
TPCost: stmt.ColumnInt(6), TPCost: stmt.ColumnInt(6),
ShopList: stmt.ColumnText(7), ShopList: stmt.ColumnText(7),
db: db,
} }
towns = append(towns, town) towns = append(towns, town)
return nil return nil
@ -80,11 +77,11 @@ func All(db *database.DB) ([]*Town, error) {
} }
// ByName retrieves a town by name (case-insensitive) // ByName retrieves a town by name (case-insensitive)
func ByName(db *database.DB, name string) (*Town, error) { func ByName(name string) (*Town, error) {
town := &Town{db: db} town := &Town{}
query := "SELECT id, name, x, y, inn_cost, map_cost, tp_cost, shop_list FROM towns WHERE LOWER(name) = LOWER(?) LIMIT 1" query := "SELECT id, name, x, y, inn_cost, map_cost, tp_cost, shop_list FROM towns WHERE LOWER(name) = LOWER(?) LIMIT 1"
err := db.Query(query, func(stmt *sqlite.Stmt) error { err := database.Query(query, func(stmt *sqlite.Stmt) error {
town.ID = stmt.ColumnInt(0) town.ID = stmt.ColumnInt(0)
town.Name = stmt.ColumnText(1) town.Name = stmt.ColumnText(1)
town.X = stmt.ColumnInt(2) town.X = stmt.ColumnInt(2)
@ -108,11 +105,11 @@ func ByName(db *database.DB, name string) (*Town, error) {
} }
// ByMaxInnCost retrieves towns with inn cost at most the specified amount // ByMaxInnCost retrieves towns with inn cost at most the specified amount
func ByMaxInnCost(db *database.DB, maxCost int) ([]*Town, error) { func ByMaxInnCost(maxCost int) ([]*Town, error) {
var towns []*Town var towns []*Town
query := "SELECT id, name, x, y, inn_cost, map_cost, tp_cost, shop_list FROM towns WHERE inn_cost <= ? ORDER BY inn_cost, id" query := "SELECT id, name, x, y, inn_cost, map_cost, tp_cost, shop_list FROM towns WHERE inn_cost <= ? ORDER BY inn_cost, id"
err := db.Query(query, func(stmt *sqlite.Stmt) error { err := database.Query(query, func(stmt *sqlite.Stmt) error {
town := &Town{ town := &Town{
ID: stmt.ColumnInt(0), ID: stmt.ColumnInt(0),
Name: stmt.ColumnText(1), Name: stmt.ColumnText(1),
@ -122,7 +119,6 @@ func ByMaxInnCost(db *database.DB, maxCost int) ([]*Town, error) {
MapCost: stmt.ColumnInt(5), MapCost: stmt.ColumnInt(5),
TPCost: stmt.ColumnInt(6), TPCost: stmt.ColumnInt(6),
ShopList: stmt.ColumnText(7), ShopList: stmt.ColumnText(7),
db: db,
} }
towns = append(towns, town) towns = append(towns, town)
return nil return nil
@ -136,11 +132,11 @@ func ByMaxInnCost(db *database.DB, maxCost int) ([]*Town, error) {
} }
// ByMaxTPCost retrieves towns with teleport cost at most the specified amount // ByMaxTPCost retrieves towns with teleport cost at most the specified amount
func ByMaxTPCost(db *database.DB, maxCost int) ([]*Town, error) { func ByMaxTPCost(maxCost int) ([]*Town, error) {
var towns []*Town var towns []*Town
query := "SELECT id, name, x, y, inn_cost, map_cost, tp_cost, shop_list FROM towns WHERE tp_cost <= ? ORDER BY tp_cost, id" query := "SELECT id, name, x, y, inn_cost, map_cost, tp_cost, shop_list FROM towns WHERE tp_cost <= ? ORDER BY tp_cost, id"
err := db.Query(query, func(stmt *sqlite.Stmt) error { err := database.Query(query, func(stmt *sqlite.Stmt) error {
town := &Town{ town := &Town{
ID: stmt.ColumnInt(0), ID: stmt.ColumnInt(0),
Name: stmt.ColumnText(1), Name: stmt.ColumnText(1),
@ -150,7 +146,6 @@ func ByMaxTPCost(db *database.DB, maxCost int) ([]*Town, error) {
MapCost: stmt.ColumnInt(5), MapCost: stmt.ColumnInt(5),
TPCost: stmt.ColumnInt(6), TPCost: stmt.ColumnInt(6),
ShopList: stmt.ColumnText(7), ShopList: stmt.ColumnText(7),
db: db,
} }
towns = append(towns, town) towns = append(towns, town)
return nil return nil
@ -164,16 +159,16 @@ func ByMaxTPCost(db *database.DB, maxCost int) ([]*Town, error) {
} }
// ByDistance retrieves towns within a certain distance from a point // ByDistance retrieves towns within a certain distance from a point
func ByDistance(db *database.DB, fromX, fromY, maxDistance int) ([]*Town, error) { func ByDistance(fromX, fromY, maxDistance int) ([]*Town, error) {
var towns []*Town var towns []*Town
query := `SELECT id, name, x, y, inn_cost, map_cost, tp_cost, shop_list query := `SELECT id, name, x, y, inn_cost, map_cost, tp_cost, shop_list
FROM towns FROM towns
WHERE ((x - ?) * (x - ?) + (y - ?) * (y - ?)) <= ? WHERE ((x - ?) * (x - ?) + (y - ?) * (y - ?)) <= ?
ORDER BY ((x - ?) * (x - ?) + (y - ?) * (y - ?)), id` ORDER BY ((x - ?) * (x - ?) + (y - ?) * (y - ?)), id`
maxDistance2 := maxDistance * maxDistance maxDistance2 := maxDistance * maxDistance
err := db.Query(query, func(stmt *sqlite.Stmt) error { err := database.Query(query, func(stmt *sqlite.Stmt) error {
town := &Town{ town := &Town{
ID: stmt.ColumnInt(0), ID: stmt.ColumnInt(0),
Name: stmt.ColumnText(1), Name: stmt.ColumnText(1),
@ -183,7 +178,6 @@ func ByDistance(db *database.DB, fromX, fromY, maxDistance int) ([]*Town, error)
MapCost: stmt.ColumnInt(5), MapCost: stmt.ColumnInt(5),
TPCost: stmt.ColumnInt(6), TPCost: stmt.ColumnInt(6),
ShopList: stmt.ColumnText(7), ShopList: stmt.ColumnText(7),
db: db,
} }
towns = append(towns, town) towns = append(towns, town)
return nil return nil
@ -203,7 +197,7 @@ func (t *Town) Save() error {
} }
query := `UPDATE towns SET name = ?, x = ?, y = ?, inn_cost = ?, map_cost = ?, tp_cost = ?, shop_list = ? WHERE id = ?` query := `UPDATE towns SET name = ?, x = ?, y = ?, inn_cost = ?, map_cost = ?, tp_cost = ?, shop_list = ? WHERE id = ?`
return t.db.Exec(query, t.Name, t.X, t.Y, t.InnCost, t.MapCost, t.TPCost, t.ShopList, t.ID) return database.Exec(query, t.Name, t.X, t.Y, t.InnCost, t.MapCost, t.TPCost, t.ShopList, t.ID)
} }
// Delete removes the town from the database // Delete removes the town from the database
@ -213,7 +207,7 @@ func (t *Town) Delete() error {
} }
query := "DELETE FROM towns WHERE id = ?" query := "DELETE FROM towns WHERE id = ?"
return t.db.Exec(query, t.ID) return database.Exec(query, t.ID)
} }
// GetShopItems returns the shop items as a slice of item IDs // GetShopItems returns the shop items as a slice of item IDs
@ -265,4 +259,4 @@ func (t *Town) CanAffordMap(gold int) bool {
// CanAffordTeleport returns true if the player can afford to teleport here // CanAffordTeleport returns true if the player can afford to teleport here
func (t *Town) CanAffordTeleport(gold int) bool { func (t *Town) CanAffordTeleport(gold int) bool {
return gold >= t.TPCost return gold >= t.TPCost
} }

10
internal/utils/http.go Normal file
View File

@ -0,0 +1,10 @@
package utils
import "github.com/valyala/fasthttp"
// IsHTTPS tries to determine if the current request context is over HTTPS
func IsHTTPS(ctx *fasthttp.RequestCtx) bool {
return ctx.IsTLS() ||
string(ctx.Request.Header.Peek("X-Forwarded-Proto")) == "https" ||
string(ctx.Request.Header.Peek("X-Forwarded-Scheme")) == "https"
}