Compare commits

..

No commits in common. "820bc874186bfefd61fec9059b011fc8bac5e184" and "80700149f80ccf685cee4e19a628a164c0a195ff" have entirely different histories.

25 changed files with 546 additions and 407 deletions

View File

@ -1,6 +1,7 @@
package auth
import (
"dk/internal/database"
"dk/internal/password"
"dk/internal/users"
)
@ -8,28 +9,29 @@ import (
// Manager is the global singleton instance
var Manager *AuthManager
// User is a simplified User struct for auth purposes
type User struct {
ID int
Username string
Email string
}
// AuthManager is a wrapper for the session store to add
// authentication tools over the store itself
type AuthManager struct {
store *SessionStore
sessionStore *SessionStore
db *database.DB
}
// Init initializes the global auth manager (auth.Manager)
func Init(sessionsFilePath string) {
Manager = &AuthManager{
store: NewSessionStore(sessionsFilePath),
func NewAuthManager(db *database.DB, sessionsFilePath string) *AuthManager {
return &AuthManager{
sessionStore: NewSessionStore(sessionsFilePath),
db: db,
}
}
// Authenticate checks for the usernaname or email, then verifies the plain password
// against the stored hash.
// InitializeManager initializes the global Manager singleton
func InitializeManager(db *database.DB, sessionsFilePath string) {
Manager = NewAuthManager(db, sessionsFilePath)
}
func (am *AuthManager) Authenticate(usernameOrEmail, plainPassword string) (*User, error) {
var user *users.User
var err error
@ -37,12 +39,14 @@ func (am *AuthManager) Authenticate(usernameOrEmail, plainPassword string) (*Use
// Try to find user by username first
user, err = users.GetByUsername(usernameOrEmail)
if err != nil {
// Try by email if username lookup failed
user, err = users.GetByEmail(usernameOrEmail)
if err != nil {
return nil, err
}
}
// Verify password
isValid, err := password.Verify(plainPassword, user.Password)
if err != nil {
return nil, err
@ -59,27 +63,31 @@ func (am *AuthManager) Authenticate(usernameOrEmail, plainPassword string) (*Use
}
func (am *AuthManager) CreateSession(user *User) *Session {
return am.store.Create(user.ID, user.Username, user.Email)
return am.sessionStore.Create(user.ID, user.Username, user.Email)
}
func (am *AuthManager) GetSession(sessionID string) (*Session, bool) {
return am.store.Get(sessionID)
return am.sessionStore.Get(sessionID)
}
func (am *AuthManager) UpdateSession(sessionID string) bool {
return am.store.Update(sessionID)
return am.sessionStore.Update(sessionID)
}
func (am *AuthManager) DeleteSession(sessionID string) {
am.store.Delete(sessionID)
am.sessionStore.Delete(sessionID)
}
func (am *AuthManager) SessionStats() (total, active int) {
return am.store.Stats()
return am.sessionStore.Stats()
}
func (am *AuthManager) DB() *database.DB {
return am.db
}
func (am *AuthManager) Close() error {
return am.store.Close()
return am.sessionStore.Close()
}
var (

View File

@ -1,29 +1,103 @@
package auth
import (
"dk/internal/cookies"
"dk/internal/utils"
"time"
"github.com/valyala/fasthttp"
)
type CookieOptions struct {
Name string
Value string
Path string
Domain string
Expires time.Time
MaxAge int
Secure bool
HTTPOnly bool
SameSite string
}
func SetSecureCookie(ctx *fasthttp.RequestCtx, opts CookieOptions) {
cookie := &fasthttp.Cookie{}
cookie.SetKey(opts.Name)
cookie.SetValue(opts.Value)
if opts.Path != "" {
cookie.SetPath(opts.Path)
} else {
cookie.SetPath("/")
}
if opts.Domain != "" {
cookie.SetDomain(opts.Domain)
}
if !opts.Expires.IsZero() {
cookie.SetExpire(opts.Expires)
}
if opts.MaxAge > 0 {
cookie.SetMaxAge(opts.MaxAge)
}
cookie.SetSecure(opts.Secure)
cookie.SetHTTPOnly(opts.HTTPOnly)
switch opts.SameSite {
case "strict":
cookie.SetSameSite(fasthttp.CookieSameSiteStrictMode)
case "lax":
cookie.SetSameSite(fasthttp.CookieSameSiteLaxMode)
case "none":
cookie.SetSameSite(fasthttp.CookieSameSiteNoneMode)
default:
cookie.SetSameSite(fasthttp.CookieSameSiteLaxMode)
}
ctx.Response.Header.SetCookie(cookie)
}
func GetCookie(ctx *fasthttp.RequestCtx, name string) string {
return string(ctx.Request.Header.Cookie(name))
}
func DeleteCookie(ctx *fasthttp.RequestCtx, name string) {
SetSecureCookie(ctx, CookieOptions{
Name: name,
Value: "",
Path: "/",
Expires: time.Unix(0, 0),
MaxAge: -1,
HTTPOnly: true,
Secure: true,
SameSite: "lax",
})
}
func SetSessionCookie(ctx *fasthttp.RequestCtx, sessionID string) {
cookies.SetSecureCookie(ctx, cookies.CookieOptions{
SetSecureCookie(ctx, CookieOptions{
Name: SessionCookieName,
Value: sessionID,
Path: "/",
Expires: time.Now().Add(DefaultExpiration),
HTTPOnly: true,
Secure: utils.IsHTTPS(ctx),
Secure: isHTTPS(ctx),
SameSite: "lax",
})
}
func GetSessionCookie(ctx *fasthttp.RequestCtx) string {
return cookies.GetCookie(ctx, SessionCookieName)
return GetCookie(ctx, SessionCookieName)
}
func DeleteSessionCookie(ctx *fasthttp.RequestCtx) {
cookies.DeleteCookie(ctx, SessionCookieName)
DeleteCookie(ctx, SessionCookieName)
}
func isHTTPS(ctx *fasthttp.RequestCtx) bool {
return ctx.IsTLS() ||
string(ctx.Request.Header.Peek("X-Forwarded-Proto")) == "https" ||
string(ctx.Request.Header.Peek("X-Forwarded-Scheme")) == "https"
}

View File

@ -16,14 +16,16 @@ type Babble struct {
Posted int64 `json:"posted"`
Author string `json:"author"`
Babble string `json:"babble"`
db *database.DB
}
// Find retrieves a babble message by ID
func Find(id int) (*Babble, error) {
babble := &Babble{}
func Find(db *database.DB, id int) (*Babble, error) {
babble := &Babble{db: db}
query := "SELECT id, posted, author, babble FROM babble WHERE id = ?"
err := database.Query(query, func(stmt *sqlite.Stmt) error {
err := db.Query(query, func(stmt *sqlite.Stmt) error {
babble.ID = stmt.ColumnInt(0)
babble.Posted = stmt.ColumnInt64(1)
babble.Author = stmt.ColumnText(2)
@ -43,16 +45,17 @@ func Find(id int) (*Babble, error) {
}
// All retrieves all babble messages ordered by posted time (newest first)
func All() ([]*Babble, error) {
func All(db *database.DB) ([]*Babble, error) {
var babbles []*Babble
query := "SELECT id, posted, author, babble FROM babble ORDER BY posted DESC, id DESC"
err := database.Query(query, func(stmt *sqlite.Stmt) error {
err := db.Query(query, func(stmt *sqlite.Stmt) error {
babble := &Babble{
ID: stmt.ColumnInt(0),
Posted: stmt.ColumnInt64(1),
Author: stmt.ColumnText(2),
Babble: stmt.ColumnText(3),
db: db,
}
babbles = append(babbles, babble)
return nil
@ -66,16 +69,17 @@ func All() ([]*Babble, error) {
}
// ByAuthor retrieves babble messages by a specific author
func ByAuthor(author string) ([]*Babble, error) {
func ByAuthor(db *database.DB, author string) ([]*Babble, error) {
var babbles []*Babble
query := "SELECT id, posted, author, babble FROM babble WHERE LOWER(author) = LOWER(?) ORDER BY posted DESC, id DESC"
err := database.Query(query, func(stmt *sqlite.Stmt) error {
err := db.Query(query, func(stmt *sqlite.Stmt) error {
babble := &Babble{
ID: stmt.ColumnInt(0),
Posted: stmt.ColumnInt64(1),
Author: stmt.ColumnText(2),
Babble: stmt.ColumnText(3),
db: db,
}
babbles = append(babbles, babble)
return nil
@ -89,16 +93,17 @@ func ByAuthor(author string) ([]*Babble, error) {
}
// Recent retrieves the most recent babble messages (limited by count)
func Recent(limit int) ([]*Babble, error) {
func Recent(db *database.DB, limit int) ([]*Babble, error) {
var babbles []*Babble
query := "SELECT id, posted, author, babble FROM babble ORDER BY posted DESC, id DESC LIMIT ?"
err := database.Query(query, func(stmt *sqlite.Stmt) error {
err := db.Query(query, func(stmt *sqlite.Stmt) error {
babble := &Babble{
ID: stmt.ColumnInt(0),
Posted: stmt.ColumnInt64(1),
Author: stmt.ColumnText(2),
Babble: stmt.ColumnText(3),
db: db,
}
babbles = append(babbles, babble)
return nil
@ -112,16 +117,17 @@ func Recent(limit int) ([]*Babble, error) {
}
// Since retrieves babble messages since a specific timestamp
func Since(since int64) ([]*Babble, error) {
func Since(db *database.DB, since int64) ([]*Babble, error) {
var babbles []*Babble
query := "SELECT id, posted, author, babble FROM babble WHERE posted >= ? ORDER BY posted DESC, id DESC"
err := database.Query(query, func(stmt *sqlite.Stmt) error {
err := db.Query(query, func(stmt *sqlite.Stmt) error {
babble := &Babble{
ID: stmt.ColumnInt(0),
Posted: stmt.ColumnInt64(1),
Author: stmt.ColumnText(2),
Babble: stmt.ColumnText(3),
db: db,
}
babbles = append(babbles, babble)
return nil
@ -135,16 +141,17 @@ func Since(since int64) ([]*Babble, error) {
}
// Between retrieves babble messages between two timestamps (inclusive)
func Between(start, end int64) ([]*Babble, error) {
func Between(db *database.DB, start, end int64) ([]*Babble, error) {
var babbles []*Babble
query := "SELECT id, posted, author, babble FROM babble WHERE posted >= ? AND posted <= ? ORDER BY posted DESC, id DESC"
err := database.Query(query, func(stmt *sqlite.Stmt) error {
err := db.Query(query, func(stmt *sqlite.Stmt) error {
babble := &Babble{
ID: stmt.ColumnInt(0),
Posted: stmt.ColumnInt64(1),
Author: stmt.ColumnText(2),
Babble: stmt.ColumnText(3),
db: db,
}
babbles = append(babbles, babble)
return nil
@ -158,18 +165,19 @@ func Between(start, end int64) ([]*Babble, error) {
}
// Search retrieves babble messages containing the search term (case-insensitive)
func Search(term string) ([]*Babble, error) {
func Search(db *database.DB, term string) ([]*Babble, error) {
var babbles []*Babble
query := "SELECT id, posted, author, babble FROM babble WHERE LOWER(babble) LIKE LOWER(?) ORDER BY posted DESC, id DESC"
searchTerm := "%" + term + "%"
err := database.Query(query, func(stmt *sqlite.Stmt) error {
err := db.Query(query, func(stmt *sqlite.Stmt) error {
babble := &Babble{
ID: stmt.ColumnInt(0),
Posted: stmt.ColumnInt64(1),
Author: stmt.ColumnText(2),
Babble: stmt.ColumnText(3),
db: db,
}
babbles = append(babbles, babble)
return nil
@ -183,16 +191,17 @@ func Search(term string) ([]*Babble, error) {
}
// RecentByAuthor retrieves recent messages from a specific author
func RecentByAuthor(author string, limit int) ([]*Babble, error) {
func RecentByAuthor(db *database.DB, author string, limit int) ([]*Babble, error) {
var babbles []*Babble
query := "SELECT id, posted, author, babble FROM babble WHERE LOWER(author) = LOWER(?) ORDER BY posted DESC, id DESC LIMIT ?"
err := database.Query(query, func(stmt *sqlite.Stmt) error {
err := db.Query(query, func(stmt *sqlite.Stmt) error {
babble := &Babble{
ID: stmt.ColumnInt(0),
Posted: stmt.ColumnInt64(1),
Author: stmt.ColumnText(2),
Babble: stmt.ColumnText(3),
db: db,
}
babbles = append(babbles, babble)
return nil
@ -212,7 +221,7 @@ func (b *Babble) Save() error {
}
query := `UPDATE babble SET posted = ?, author = ?, babble = ? WHERE id = ?`
return database.Exec(query, b.Posted, b.Author, b.Babble, b.ID)
return b.db.Exec(query, b.Posted, b.Author, b.Babble, b.ID)
}
// Delete removes the babble message from the database
@ -221,7 +230,8 @@ func (b *Babble) Delete() error {
return fmt.Errorf("cannot delete babble without ID")
}
return database.Exec("DELETE FROM babble WHERE id = ?", b.ID)
query := "DELETE FROM babble WHERE id = ?"
return b.db.Exec(query, b.ID)
}
// PostedTime returns the posted timestamp as a time.Time
@ -254,11 +264,11 @@ func (b *Babble) Preview(maxLength int) string {
if len(b.Babble) <= maxLength {
return b.Babble
}
if maxLength < 3 {
return b.Babble[:maxLength]
}
return b.Babble[:maxLength-3] + "..."
}
@ -267,11 +277,11 @@ func (b *Babble) WordCount() int {
if b.Babble == "" {
return 0
}
// Simple word count by splitting on whitespace
words := 0
inWord := false
for _, char := range b.Babble {
if char == ' ' || char == '\t' || char == '\n' || char == '\r' {
if inWord {
@ -282,11 +292,11 @@ func (b *Babble) WordCount() int {
inWord = true
}
}
if inWord {
words++
}
return words
}
@ -314,7 +324,7 @@ func (b *Babble) IsLongMessage(threshold int) bool {
func (b *Babble) GetMentions() []string {
words := strings.Fields(b.Babble)
var mentions []string
for _, word := range words {
if strings.HasPrefix(word, "@") && len(word) > 1 {
// Clean up punctuation from the end
@ -324,7 +334,7 @@ func (b *Babble) GetMentions() []string {
}
}
}
return mentions
}
@ -337,4 +347,4 @@ func (b *Babble) HasMention(username string) bool {
}
}
return false
}
}

View File

@ -12,14 +12,17 @@ import (
// Builder provides a fluent interface for creating babble messages
type Builder struct {
babble *Babble
db *database.DB
}
// NewBuilder creates a new babble builder
func NewBuilder() *Builder {
func NewBuilder(db *database.DB) *Builder {
return &Builder{
babble: &Babble{
db: db,
Posted: time.Now().Unix(), // Default to current time
},
db: db,
}
}
@ -56,10 +59,10 @@ func (b *Builder) WithPostedTime(t time.Time) *Builder {
func (b *Builder) Create() (*Babble, error) {
// Use a transaction to ensure we can get the ID
var babble *Babble
err := database.Transaction(func(tx *database.Tx) error {
query := `INSERT INTO babble (posted, author, babble)
err := b.db.Transaction(func(tx *database.Tx) error {
query := `INSERT INTO babble (posted, author, babble)
VALUES (?, ?, ?)`
if err := tx.Exec(query, b.babble.Posted, b.babble.Author, b.babble.Babble); err != nil {
return fmt.Errorf("failed to insert babble: %w", err)
}
@ -78,10 +81,10 @@ func (b *Builder) Create() (*Babble, error) {
babble = b.babble
return nil
})
if err != nil {
return nil, err
}
return babble, nil
}
}

View File

@ -9,22 +9,25 @@ import (
)
// Control represents the game control settings in the database
// There is only ever one control record with ID 1
type Control struct {
ID int `json:"id"`
WorldSize int `json:"world_size"`
Open int `json:"open"`
AdminEmail string `json:"admin_email"`
Class1Name string `json:"class_1_name"`
Class2Name string `json:"class_2_name"`
Class3Name string `json:"class_3_name"`
ID int `json:"id"`
WorldSize int `json:"world_size"`
Open int `json:"open"`
AdminEmail string `json:"admin_email"`
Class1Name string `json:"class_1_name"`
Class2Name string `json:"class_2_name"`
Class3Name string `json:"class_3_name"`
db *database.DB
}
// Find retrieves the control record by ID (typically only ID 1 exists)
func Find(id int) (*Control, error) {
control := &Control{}
func Find(db *database.DB, id int) (*Control, error) {
control := &Control{db: db}
query := "SELECT id, world_size, open, admin_email, class_1_name, class_2_name, class_3_name FROM control WHERE id = ?"
err := database.Query(query, func(stmt *sqlite.Stmt) error {
err := db.Query(query, func(stmt *sqlite.Stmt) error {
control.ID = stmt.ColumnInt(0)
control.WorldSize = stmt.ColumnInt(1)
control.Open = stmt.ColumnInt(2)
@ -47,8 +50,8 @@ func Find(id int) (*Control, error) {
}
// Get retrieves the main control record (ID 1)
func Get() (*Control, error) {
return Find(1)
func Get(db *database.DB) (*Control, error) {
return Find(db, 1)
}
// Save updates the control record in the database
@ -58,7 +61,7 @@ func (c *Control) Save() error {
}
query := `UPDATE control SET world_size = ?, open = ?, admin_email = ?, class_1_name = ?, class_2_name = ?, class_3_name = ? WHERE id = ?`
return database.Exec(query, c.WorldSize, c.Open, c.AdminEmail, c.Class1Name, c.Class2Name, c.Class3Name, c.ID)
return c.db.Exec(query, c.WorldSize, c.Open, c.AdminEmail, c.Class1Name, c.Class2Name, c.Class3Name, c.ID)
}
// IsOpen returns true if the game world is open for new players
@ -106,7 +109,7 @@ func (c *Control) SetClassNames(classes []string) {
c.Class1Name = ""
c.Class2Name = ""
c.Class3Name = ""
// Set provided class names
if len(classes) > 0 {
c.Class1Name = classes[0]
@ -197,4 +200,4 @@ func (c *Control) IsWithinWorldBounds(x, y int) bool {
func (c *Control) GetWorldBounds() (minX, minY, maxX, maxY int) {
radius := c.GetWorldRadius()
return -radius, -radius, radius, radius
}
}

View File

@ -1,77 +0,0 @@
package cookies
import (
"time"
"github.com/valyala/fasthttp"
)
type CookieOptions struct {
Name string
Value string
Path string
Domain string
Expires time.Time
MaxAge int
Secure bool
HTTPOnly bool
SameSite string
}
func SetSecureCookie(ctx *fasthttp.RequestCtx, opts CookieOptions) {
cookie := &fasthttp.Cookie{}
cookie.SetKey(opts.Name)
cookie.SetValue(opts.Value)
if opts.Path != "" {
cookie.SetPath(opts.Path)
} else {
cookie.SetPath("/")
}
if opts.Domain != "" {
cookie.SetDomain(opts.Domain)
}
if !opts.Expires.IsZero() {
cookie.SetExpire(opts.Expires)
}
if opts.MaxAge > 0 {
cookie.SetMaxAge(opts.MaxAge)
}
cookie.SetSecure(opts.Secure)
cookie.SetHTTPOnly(opts.HTTPOnly)
switch opts.SameSite {
case "strict":
cookie.SetSameSite(fasthttp.CookieSameSiteStrictMode)
case "lax":
cookie.SetSameSite(fasthttp.CookieSameSiteLaxMode)
case "none":
cookie.SetSameSite(fasthttp.CookieSameSiteNoneMode)
default:
cookie.SetSameSite(fasthttp.CookieSameSiteLaxMode)
}
ctx.Response.Header.SetCookie(cookie)
}
func GetCookie(ctx *fasthttp.RequestCtx, name string) string {
return string(ctx.Request.Header.Cookie(name))
}
func DeleteCookie(ctx *fasthttp.RequestCtx, name string) {
SetSecureCookie(ctx, CookieOptions{
Name: name,
Value: "",
Path: "/",
Expires: time.Unix(0, 0),
MaxAge: -1,
HTTPOnly: true,
Secure: true,
SameSite: "lax",
})
}

View File

@ -11,80 +11,75 @@ import (
const DefaultPath = "dk.db"
// Global singleton instance
var pool *sqlitex.Pool
// database wraps a SQLite connection pool with simplified methods
type database struct {
pool *sqlitex.Pool
}
// Init initializes the global database connection pool
func Init(path string) error {
// DB is a backward-compatible type alias
type DB = database
// instance is the global singleton instance
var instance *database
// Open creates a new database connection pool
func Open(path string) (*database, error) {
if path == "" {
path = DefaultPath
}
poolSize := max(runtime.GOMAXPROCS(0), 2)
var err error
pool, err = sqlitex.NewPool(path, sqlitex.PoolOptions{
pool, err := sqlitex.NewPool(path, sqlitex.PoolOptions{
PoolSize: poolSize,
Flags: sqlite.OpenCreate | sqlite.OpenReadWrite | sqlite.OpenWAL,
})
if err != nil {
return fmt.Errorf("failed to open database pool: %w", err)
return nil, fmt.Errorf("failed to open database pool: %w", err)
}
conn, err := pool.Take(context.Background())
if err != nil {
pool.Close()
return fmt.Errorf("failed to get connection from pool: %w", err)
return nil, fmt.Errorf("failed to get connection from pool: %w", err)
}
defer pool.Put(conn)
if err := sqlitex.ExecuteTransient(conn, "PRAGMA journal_mode = WAL", nil); err != nil {
pool.Close()
return fmt.Errorf("failed to set WAL mode: %w", err)
return nil, fmt.Errorf("failed to set WAL mode: %w", err)
}
if err := sqlitex.ExecuteTransient(conn, "PRAGMA synchronous = NORMAL", nil); err != nil {
pool.Close()
return fmt.Errorf("failed to set synchronous mode: %w", err)
return nil, fmt.Errorf("failed to set synchronous mode: %w", err)
}
return nil
return &database{pool: pool}, nil
}
// Close closes the global database connection pool
func Close() error {
if pool == nil {
return nil
}
return pool.Close()
// Close closes the database connection pool
func (db *database) Close() error {
return db.pool.Close()
}
// GetConn gets a connection from the pool - caller must call PutConn when done
func GetConn(ctx context.Context) (*sqlite.Conn, error) {
if pool == nil {
return nil, fmt.Errorf("database not initialized")
}
return pool.Take(ctx)
// GetConn gets a connection from the pool - caller must call Put when done
func (db *database) GetConn(ctx context.Context) (*sqlite.Conn, error) {
return db.pool.Take(ctx)
}
// PutConn returns a connection to the pool
func PutConn(conn *sqlite.Conn) {
if pool != nil {
pool.Put(conn)
}
func (db *database) PutConn(conn *sqlite.Conn) {
db.pool.Put(conn)
}
// Exec executes a SQL statement without returning results
func Exec(query string, args ...any) error {
if pool == nil {
return fmt.Errorf("database not initialized")
}
conn, err := pool.Take(context.Background())
func (db *database) Exec(query string, args ...any) error {
conn, err := db.pool.Take(context.Background())
if err != nil {
return fmt.Errorf("failed to get connection from pool: %w", err)
}
defer pool.Put(conn)
defer db.pool.Put(conn)
if len(args) == 0 {
return sqlitex.ExecuteTransient(conn, query, nil)
@ -96,16 +91,12 @@ func Exec(query string, args ...any) error {
}
// Query executes a SQL query and calls fn for each row
func Query(query string, fn func(*sqlite.Stmt) error, args ...any) error {
if pool == nil {
return fmt.Errorf("database not initialized")
}
conn, err := pool.Take(context.Background())
func (db *database) Query(query string, fn func(*sqlite.Stmt) error, args ...any) error {
conn, err := db.pool.Take(context.Background())
if err != nil {
return fmt.Errorf("failed to get connection from pool: %w", err)
}
defer pool.Put(conn)
defer db.pool.Put(conn)
if len(args) == 0 {
return sqlitex.ExecuteTransient(conn, query, &sqlitex.ExecOptions{
@ -120,31 +111,23 @@ func Query(query string, fn func(*sqlite.Stmt) error, args ...any) error {
}
// Begin starts a new transaction
func Begin() (*Tx, error) {
if pool == nil {
return nil, fmt.Errorf("database not initialized")
}
conn, err := pool.Take(context.Background())
func (db *database) Begin() (*Tx, error) {
conn, err := db.pool.Take(context.Background())
if err != nil {
return nil, fmt.Errorf("failed to get connection from pool: %w", err)
}
if err := sqlitex.ExecuteTransient(conn, "BEGIN", nil); err != nil {
pool.Put(conn)
db.pool.Put(conn)
return nil, fmt.Errorf("failed to begin transaction: %w", err)
}
return &Tx{conn: conn, pool: pool}, nil
return &Tx{conn: conn, pool: db.pool}, nil
}
// Transaction runs a function within a transaction
func Transaction(fn func(*Tx) error) error {
if pool == nil {
return fmt.Errorf("database not initialized")
}
tx, err := Begin()
func (db *database) Transaction(fn func(*Tx) error) error {
tx, err := db.Begin()
if err != nil {
return err
}
@ -199,3 +182,75 @@ func (tx *Tx) Rollback() error {
defer tx.pool.Put(tx.conn)
return sqlitex.ExecuteTransient(tx.conn, "ROLLBACK", nil)
}
// InitializeDB initializes the global DB singleton
func InitializeDB(path string) error {
db, err := Open(path)
if err != nil {
return err
}
instance = db
return nil
}
// GetDB returns the global database instance
func GetDB() *DB {
return instance
}
// Global convenience functions that use the singleton
// Exec executes a SQL statement without returning results using the global DB
func Exec(query string, args ...any) error {
if instance == nil {
return fmt.Errorf("database not initialized")
}
return instance.Exec(query, args...)
}
// Query executes a SQL query and calls fn for each row using the global DB
func Query(query string, fn func(*sqlite.Stmt) error, args ...any) error {
if instance == nil {
return fmt.Errorf("database not initialized")
}
return instance.Query(query, fn, args...)
}
// Begin starts a new transaction using the global DB
func Begin() (*Tx, error) {
if instance == nil {
return nil, fmt.Errorf("database not initialized")
}
return instance.Begin()
}
// Transaction runs a function within a transaction using the global DB
func Transaction(fn func(*Tx) error) error {
if instance == nil {
return fmt.Errorf("database not initialized")
}
return instance.Transaction(fn)
}
// GetConn gets a connection from the pool using the global DB
func GetConn(ctx context.Context) (*sqlite.Conn, error) {
if instance == nil {
return nil, fmt.Errorf("database not initialized")
}
return instance.GetConn(ctx)
}
// PutConn returns a connection to the pool using the global DB
func PutConn(conn *sqlite.Conn) {
if instance != nil {
instance.PutConn(conn)
}
}
// Close closes the global database connection pool
func Close() error {
if instance == nil {
return nil
}
return instance.Close()
}

View File

@ -11,62 +11,62 @@ func TestDatabaseOperations(t *testing.T) {
// Use a temporary database file
testDB := "test.db"
defer os.Remove(testDB)
// Initialize the singleton database
err := Init(testDB)
// Test opening database
db, err := Open(testDB)
if err != nil {
t.Fatalf("Failed to initialize database: %v", err)
t.Fatalf("Failed to open database: %v", err)
}
defer Close()
defer db.Close()
// Test creating a simple table
err = Exec("CREATE TABLE test_users (id INTEGER PRIMARY KEY, name TEXT)")
err = db.Exec("CREATE TABLE test_users (id INTEGER PRIMARY KEY, name TEXT)")
if err != nil {
t.Fatalf("Failed to create table: %v", err)
}
// Test inserting data
err = Exec("INSERT INTO test_users (name) VALUES (?)", "Alice")
err = db.Exec("INSERT INTO test_users (name) VALUES (?)", "Alice")
if err != nil {
t.Fatalf("Failed to insert data: %v", err)
}
// Test querying data
var foundName string
err = Query("SELECT name FROM test_users WHERE name = ?", func(stmt *sqlite.Stmt) error {
err = db.Query("SELECT name FROM test_users WHERE name = ?", func(stmt *sqlite.Stmt) error {
foundName = stmt.ColumnText(0)
return nil
}, "Alice")
if err != nil {
t.Fatalf("Failed to query data: %v", err)
}
if foundName != "Alice" {
t.Errorf("Expected 'Alice', got '%s'", foundName)
}
// Test transaction
err = Transaction(func(tx *Tx) error {
err = db.Transaction(func(tx *Tx) error {
return tx.Exec("INSERT INTO test_users (name) VALUES (?)", "Bob")
})
if err != nil {
t.Fatalf("Transaction failed: %v", err)
}
// Verify transaction worked
var count int
err = Query("SELECT COUNT(*) FROM test_users", func(stmt *sqlite.Stmt) error {
err = db.Query("SELECT COUNT(*) FROM test_users", func(stmt *sqlite.Stmt) error {
count = stmt.ColumnInt(0)
return nil
})
if err != nil {
t.Fatalf("Failed to count users: %v", err)
}
if count != 2 {
t.Errorf("Expected 2 users, got %d", count)
}
}
}

View File

@ -10,12 +10,14 @@ import (
// Builder provides a fluent interface for creating drops
type Builder struct {
drop *Drop
db *database.DB
}
// NewBuilder creates a new drop builder
func NewBuilder() *Builder {
func NewBuilder(db *database.DB) *Builder {
return &Builder{
drop: &Drop{},
drop: &Drop{db: db},
db: db,
}
}
@ -47,8 +49,8 @@ func (b *Builder) WithAtt(att string) *Builder {
func (b *Builder) Create() (*Drop, error) {
// Use a transaction to ensure we can get the ID
var drop *Drop
err := database.Transaction(func(tx *database.Tx) error {
query := `INSERT INTO drops (name, level, type, att)
err := b.db.Transaction(func(tx *database.Tx) error {
query := `INSERT INTO drops (name, level, type, att)
VALUES (?, ?, ?, ?)`
if err := tx.Exec(query, b.drop.Name, b.drop.Level, b.drop.Type, b.drop.Att); err != nil {
@ -72,6 +74,7 @@ func (b *Builder) Create() (*Drop, error) {
Level: b.drop.Level,
Type: b.drop.Type,
Att: b.drop.Att,
db: b.db,
}
return nil

View File

@ -15,6 +15,8 @@ type Drop struct {
Level int `json:"level"`
Type int `json:"type"`
Att string `json:"att"`
db *database.DB
}
// DropType constants for drop types
@ -23,11 +25,11 @@ const (
)
// Find retrieves a drop by ID
func Find(id int) (*Drop, error) {
drop := &Drop{}
func Find(db *database.DB, id int) (*Drop, error) {
drop := &Drop{db: db}
query := "SELECT id, name, level, type, att FROM drops WHERE id = ?"
err := database.Query(query, func(stmt *sqlite.Stmt) error {
err := db.Query(query, func(stmt *sqlite.Stmt) error {
drop.ID = stmt.ColumnInt(0)
drop.Name = stmt.ColumnText(1)
drop.Level = stmt.ColumnInt(2)
@ -48,17 +50,18 @@ func Find(id int) (*Drop, error) {
}
// All retrieves all drops
func All() ([]*Drop, error) {
func All(db *database.DB) ([]*Drop, error) {
var drops []*Drop
query := "SELECT id, name, level, type, att FROM drops ORDER BY id"
err := database.Query(query, func(stmt *sqlite.Stmt) error {
err := db.Query(query, func(stmt *sqlite.Stmt) error {
drop := &Drop{
ID: stmt.ColumnInt(0),
Name: stmt.ColumnText(1),
Level: stmt.ColumnInt(2),
Type: stmt.ColumnInt(3),
Att: stmt.ColumnText(4),
db: db,
}
drops = append(drops, drop)
return nil
@ -72,17 +75,18 @@ func All() ([]*Drop, error) {
}
// ByLevel retrieves drops by minimum level requirement
func ByLevel(minLevel int) ([]*Drop, error) {
func ByLevel(db *database.DB, minLevel int) ([]*Drop, error) {
var drops []*Drop
query := "SELECT id, name, level, type, att FROM drops WHERE level <= ? ORDER BY level, id"
err := database.Query(query, func(stmt *sqlite.Stmt) error {
err := db.Query(query, func(stmt *sqlite.Stmt) error {
drop := &Drop{
ID: stmt.ColumnInt(0),
Name: stmt.ColumnText(1),
Level: stmt.ColumnInt(2),
Type: stmt.ColumnInt(3),
Att: stmt.ColumnText(4),
db: db,
}
drops = append(drops, drop)
return nil
@ -96,17 +100,18 @@ func ByLevel(minLevel int) ([]*Drop, error) {
}
// ByType retrieves drops by type
func ByType(dropType int) ([]*Drop, error) {
func ByType(db *database.DB, dropType int) ([]*Drop, error) {
var drops []*Drop
query := "SELECT id, name, level, type, att FROM drops WHERE type = ? ORDER BY level, id"
err := database.Query(query, func(stmt *sqlite.Stmt) error {
err := db.Query(query, func(stmt *sqlite.Stmt) error {
drop := &Drop{
ID: stmt.ColumnInt(0),
Name: stmt.ColumnText(1),
Level: stmt.ColumnInt(2),
Type: stmt.ColumnInt(3),
Att: stmt.ColumnText(4),
db: db,
}
drops = append(drops, drop)
return nil
@ -126,7 +131,7 @@ func (d *Drop) Save() error {
}
query := `UPDATE drops SET name = ?, level = ?, type = ?, att = ? WHERE id = ?`
return database.Exec(query, d.Name, d.Level, d.Type, d.Att, d.ID)
return d.db.Exec(query, d.Name, d.Level, d.Type, d.Att, d.ID)
}
// Delete removes the drop from the database
@ -135,7 +140,8 @@ func (d *Drop) Delete() error {
return fmt.Errorf("cannot delete drop without ID")
}
return database.Exec("DELETE FROM drops WHERE id = ?", d.ID)
query := "DELETE FROM drops WHERE id = ?"
return d.db.Exec(query, d.ID)
}
// IsConsumable returns true if the drop is a consumable item

View File

@ -12,18 +12,21 @@ import (
// Builder provides a fluent interface for creating forum posts
type Builder struct {
forum *Forum
db *database.DB
}
// NewBuilder creates a new forum post builder
func NewBuilder() *Builder {
func NewBuilder(db *database.DB) *Builder {
now := time.Now().Unix()
return &Builder{
forum: &Forum{
db: db,
Posted: now,
LastPost: now, // Default to same as posted time
Parent: 0, // Default to thread (no parent)
Replies: 0, // Default to no replies
},
db: db,
}
}
@ -96,8 +99,8 @@ func (b *Builder) WithReplies(replies int) *Builder {
func (b *Builder) Create() (*Forum, error) {
// Use a transaction to ensure we can get the ID
var forum *Forum
err := database.Transaction(func(tx *database.Tx) error {
query := `INSERT INTO forum (posted, last_post, author, parent, replies, title, content)
err := b.db.Transaction(func(tx *database.Tx) error {
query := `INSERT INTO forum (posted, last_post, author, parent, replies, title, content)
VALUES (?, ?, ?, ?, ?, ?, ?)`
if err := tx.Exec(query, b.forum.Posted, b.forum.LastPost, b.forum.Author,
@ -125,4 +128,4 @@ func (b *Builder) Create() (*Forum, error) {
}
return forum, nil
}
}

View File

@ -20,14 +20,16 @@ type Forum struct {
Replies int `json:"replies"`
Title string `json:"title"`
Content string `json:"content"`
db *database.DB
}
// Find retrieves a forum post by ID
func Find(id int) (*Forum, error) {
forum := &Forum{}
func Find(db *database.DB, id int) (*Forum, error) {
forum := &Forum{db: db}
query := "SELECT id, posted, last_post, author, parent, replies, title, content FROM forum WHERE id = ?"
err := database.Query(query, func(stmt *sqlite.Stmt) error {
err := db.Query(query, func(stmt *sqlite.Stmt) error {
forum.ID = stmt.ColumnInt(0)
forum.Posted = stmt.ColumnInt64(1)
forum.LastPost = stmt.ColumnInt64(2)
@ -51,11 +53,11 @@ func Find(id int) (*Forum, error) {
}
// All retrieves all forum posts ordered by last post time (most recent first)
func All() ([]*Forum, error) {
func All(db *database.DB) ([]*Forum, error) {
var forums []*Forum
query := "SELECT id, posted, last_post, author, parent, replies, title, content FROM forum ORDER BY last_post DESC, id DESC"
err := database.Query(query, func(stmt *sqlite.Stmt) error {
err := db.Query(query, func(stmt *sqlite.Stmt) error {
forum := &Forum{
ID: stmt.ColumnInt(0),
Posted: stmt.ColumnInt64(1),
@ -65,6 +67,7 @@ func All() ([]*Forum, error) {
Replies: stmt.ColumnInt(5),
Title: stmt.ColumnText(6),
Content: stmt.ColumnText(7),
db: db,
}
forums = append(forums, forum)
return nil
@ -78,11 +81,11 @@ func All() ([]*Forum, error) {
}
// Threads retrieves all top-level forum threads (parent = 0)
func Threads() ([]*Forum, error) {
func Threads(db *database.DB) ([]*Forum, error) {
var forums []*Forum
query := "SELECT id, posted, last_post, author, parent, replies, title, content FROM forum WHERE parent = 0 ORDER BY last_post DESC, id DESC"
err := database.Query(query, func(stmt *sqlite.Stmt) error {
err := db.Query(query, func(stmt *sqlite.Stmt) error {
forum := &Forum{
ID: stmt.ColumnInt(0),
Posted: stmt.ColumnInt64(1),
@ -92,6 +95,7 @@ func Threads() ([]*Forum, error) {
Replies: stmt.ColumnInt(5),
Title: stmt.ColumnText(6),
Content: stmt.ColumnText(7),
db: db,
}
forums = append(forums, forum)
return nil
@ -105,11 +109,11 @@ func Threads() ([]*Forum, error) {
}
// ByParent retrieves all replies to a specific thread/post
func ByParent(parentID int) ([]*Forum, error) {
func ByParent(db *database.DB, parentID int) ([]*Forum, error) {
var forums []*Forum
query := "SELECT id, posted, last_post, author, parent, replies, title, content FROM forum WHERE parent = ? ORDER BY posted ASC, id ASC"
err := database.Query(query, func(stmt *sqlite.Stmt) error {
err := db.Query(query, func(stmt *sqlite.Stmt) error {
forum := &Forum{
ID: stmt.ColumnInt(0),
Posted: stmt.ColumnInt64(1),
@ -119,6 +123,7 @@ func ByParent(parentID int) ([]*Forum, error) {
Replies: stmt.ColumnInt(5),
Title: stmt.ColumnText(6),
Content: stmt.ColumnText(7),
db: db,
}
forums = append(forums, forum)
return nil
@ -132,11 +137,11 @@ func ByParent(parentID int) ([]*Forum, error) {
}
// ByAuthor retrieves forum posts by a specific author
func ByAuthor(authorID int) ([]*Forum, error) {
func ByAuthor(db *database.DB, authorID int) ([]*Forum, error) {
var forums []*Forum
query := "SELECT id, posted, last_post, author, parent, replies, title, content FROM forum WHERE author = ? ORDER BY posted DESC, id DESC"
err := database.Query(query, func(stmt *sqlite.Stmt) error {
err := db.Query(query, func(stmt *sqlite.Stmt) error {
forum := &Forum{
ID: stmt.ColumnInt(0),
Posted: stmt.ColumnInt64(1),
@ -146,6 +151,7 @@ func ByAuthor(authorID int) ([]*Forum, error) {
Replies: stmt.ColumnInt(5),
Title: stmt.ColumnText(6),
Content: stmt.ColumnText(7),
db: db,
}
forums = append(forums, forum)
return nil
@ -159,11 +165,11 @@ func ByAuthor(authorID int) ([]*Forum, error) {
}
// Recent retrieves the most recent forum activity (limited by count)
func Recent(limit int) ([]*Forum, error) {
func Recent(db *database.DB, limit int) ([]*Forum, error) {
var forums []*Forum
query := "SELECT id, posted, last_post, author, parent, replies, title, content FROM forum ORDER BY last_post DESC, id DESC LIMIT ?"
err := database.Query(query, func(stmt *sqlite.Stmt) error {
err := db.Query(query, func(stmt *sqlite.Stmt) error {
forum := &Forum{
ID: stmt.ColumnInt(0),
Posted: stmt.ColumnInt64(1),
@ -173,6 +179,7 @@ func Recent(limit int) ([]*Forum, error) {
Replies: stmt.ColumnInt(5),
Title: stmt.ColumnText(6),
Content: stmt.ColumnText(7),
db: db,
}
forums = append(forums, forum)
return nil
@ -186,13 +193,13 @@ func Recent(limit int) ([]*Forum, error) {
}
// Search retrieves forum posts containing the search term in title or content
func Search(term string) ([]*Forum, error) {
func Search(db *database.DB, term string) ([]*Forum, error) {
var forums []*Forum
query := "SELECT id, posted, last_post, author, parent, replies, title, content FROM forum WHERE LOWER(title) LIKE LOWER(?) OR LOWER(content) LIKE LOWER(?) ORDER BY last_post DESC, id DESC"
searchTerm := "%" + term + "%"
err := database.Query(query, func(stmt *sqlite.Stmt) error {
err := db.Query(query, func(stmt *sqlite.Stmt) error {
forum := &Forum{
ID: stmt.ColumnInt(0),
Posted: stmt.ColumnInt64(1),
@ -202,6 +209,7 @@ func Search(term string) ([]*Forum, error) {
Replies: stmt.ColumnInt(5),
Title: stmt.ColumnText(6),
Content: stmt.ColumnText(7),
db: db,
}
forums = append(forums, forum)
return nil
@ -215,11 +223,11 @@ func Search(term string) ([]*Forum, error) {
}
// Since retrieves forum posts with activity since a specific timestamp
func Since(since int64) ([]*Forum, error) {
func Since(db *database.DB, since int64) ([]*Forum, error) {
var forums []*Forum
query := "SELECT id, posted, last_post, author, parent, replies, title, content FROM forum WHERE last_post >= ? ORDER BY last_post DESC, id DESC"
err := database.Query(query, func(stmt *sqlite.Stmt) error {
err := db.Query(query, func(stmt *sqlite.Stmt) error {
forum := &Forum{
ID: stmt.ColumnInt(0),
Posted: stmt.ColumnInt64(1),
@ -229,6 +237,7 @@ func Since(since int64) ([]*Forum, error) {
Replies: stmt.ColumnInt(5),
Title: stmt.ColumnText(6),
Content: stmt.ColumnText(7),
db: db,
}
forums = append(forums, forum)
return nil
@ -248,7 +257,7 @@ func (f *Forum) Save() error {
}
query := `UPDATE forum SET posted = ?, last_post = ?, author = ?, parent = ?, replies = ?, title = ?, content = ? WHERE id = ?`
return database.Exec(query, f.Posted, f.LastPost, f.Author, f.Parent, f.Replies, f.Title, f.Content, f.ID)
return f.db.Exec(query, f.Posted, f.LastPost, f.Author, f.Parent, f.Replies, f.Title, f.Content, f.ID)
}
// Delete removes the forum post from the database
@ -257,7 +266,8 @@ func (f *Forum) Delete() error {
return fmt.Errorf("cannot delete forum post without ID")
}
return database.Exec("DELETE FROM forum WHERE id = ?", f.ID)
query := "DELETE FROM forum WHERE id = ?"
return f.db.Exec(query, f.ID)
}
// PostedTime returns the posted timestamp as a time.Time
@ -387,7 +397,7 @@ func (f *Forum) DecrementReplies() {
// GetReplies retrieves all direct replies to this post
func (f *Forum) GetReplies() ([]*Forum, error) {
return ByParent(f.ID)
return ByParent(f.db, f.ID)
}
// GetThread retrieves the parent thread (if this is a reply) or returns self (if this is a thread)
@ -395,5 +405,5 @@ func (f *Forum) GetThread() (*Forum, error) {
if f.IsThread() {
return f, nil
}
return Find(f.Parent)
}
return Find(f.db, f.Parent)
}

View File

@ -24,20 +24,21 @@ func Run() error {
start := time.Now()
if err := database.Init("dk.db"); err != nil {
return fmt.Errorf("failed to initialize database: %w", err)
db, err := database.Open(dbPath)
if err != nil {
return err
}
defer database.Close()
defer db.Close()
if err := createTables(); err != nil {
if err := createTables(db); err != nil {
return fmt.Errorf("failed to create tables: %w", err)
}
if err := populateData(); err != nil {
if err := populateData(db); err != nil {
return fmt.Errorf("failed to populate data: %w", err)
}
if err := createDemoUser(); err != nil {
if err := createDemoUser(db); err != nil {
return fmt.Errorf("failed to create demo user: %w", err)
}
@ -52,7 +53,7 @@ func Run() error {
return nil
}
func createTables() error {
func createTables(db *database.DB) error {
tables := []struct {
name string
sql string
@ -186,7 +187,7 @@ func createTables() error {
}
for _, table := range tables {
if err := database.Exec(table.sql); err != nil {
if err := db.Exec(table.sql); err != nil {
return fmt.Errorf("failed to create %s table: %w", table.name, err)
}
fmt.Printf("✓ %s table created\n", table.name)
@ -195,8 +196,8 @@ func createTables() error {
return nil
}
func populateData() error {
if err := database.Exec("INSERT INTO control VALUES (1, 250, 1, '', 'Mage', 'Warrior', 'Paladin')"); err != nil {
func populateData(db *database.DB) error {
if err := db.Exec("INSERT INTO control VALUES (1, 250, 1, '', 'Mage', 'Warrior', 'Paladin')"); err != nil {
return fmt.Errorf("failed to populate control table: %w", err)
}
fmt.Println("✓ control table populated")
@ -234,7 +235,7 @@ func populateData() error {
(30, 'Diamond', 50, 1, 'defensepower,150'),
(31, 'Memory Drop', 5, 1, 'expbonus,10'),
(32, 'Fortune Drop', 5, 1, 'goldbonus,10')`
if err := database.Exec(dropsSQL); err != nil {
if err := db.Exec(dropsSQL); err != nil {
return fmt.Errorf("failed to populate drops table: %w", err)
}
fmt.Println("✓ drops table populated")
@ -273,7 +274,7 @@ func populateData() error {
(31, 3, 'Large Shield', 2500, 30, ''),
(32, 3, 'Silver Shield', 10000, 60, ''),
(33, 3, 'Destiny Aegis', 25000, 100, 'maxhp,50')`
if err := database.Exec(itemsSQL); err != nil {
if err := db.Exec(itemsSQL); err != nil {
return fmt.Errorf("failed to populate items table: %w", err)
}
fmt.Println("✓ items table populated")
@ -430,12 +431,12 @@ func populateData() error {
(149, 'Titan', 360, 340, 270, 50, 2400, 800, 0),
(150, 'Black Daemon', 400, 400, 280, 50, 3000, 1000, 1),
(151, 'Lucifuge', 600, 600, 400, 50, 10000, 10000, 2)`
if err := database.Exec(monstersSQL); err != nil {
if err := db.Exec(monstersSQL); err != nil {
return fmt.Errorf("failed to populate monsters table: %w", err)
}
fmt.Println("✓ monsters table populated")
if err := database.Exec("INSERT INTO news (author, content) VALUES (1, 'Welcome to Dragon Knight! This is your first news post.')"); err != nil {
if err := db.Exec("INSERT INTO news (author, content) VALUES (1, 'Welcome to Dragon Knight! This is your first news post.')"); err != nil {
return fmt.Errorf("failed to populate news table: %w", err)
}
fmt.Println("✓ news table populated")
@ -460,7 +461,7 @@ func populateData() error {
(17, 'Ward', 10, 10, 5),
(18, 'Fend', 20, 25, 5),
(19, 'Barrier', 30, 50, 5)`
if err := database.Exec(spellsSQL); err != nil {
if err := db.Exec(spellsSQL); err != nil {
return fmt.Errorf("failed to populate spells table: %w", err)
}
fmt.Println("✓ spells table populated")
@ -474,7 +475,7 @@ func populateData() error {
(6, 'Hambry', 170, 170, 90, 1000, 80, '10,11,12,13,14,23,24,30,31'),
(7, 'Gilead', 200, -200, 100, 3000, 110, '12,13,14,15,24,25,26,32'),
(8, 'Endworld', -250, -250, 125, 9000, 160, '16,27,33')`
if err := database.Exec(townsSQL); err != nil {
if err := db.Exec(townsSQL); err != nil {
return fmt.Errorf("failed to populate towns table: %w", err)
}
fmt.Println("✓ towns table populated")
@ -482,16 +483,16 @@ func populateData() error {
return nil
}
func createDemoUser() error {
func createDemoUser(db *database.DB) error {
hashedPassword, err := password.Hash("Demo123!")
if err != nil {
return fmt.Errorf("failed to hash password: %w", err)
}
stmt := `INSERT INTO users (username, password, email, verified, class_id, auth)
stmt := `INSERT INTO users (username, password, email, verified, class_id, auth)
VALUES (?, ?, ?, 1, 1, 4)`
if err := database.Exec(stmt, "demo", hashedPassword, "demo@demo.com"); err != nil {
if err := db.Exec(stmt, "demo", hashedPassword, "demo@demo.com"); err != nil {
return fmt.Errorf("failed to create demo user: %w", err)
}

View File

@ -10,12 +10,14 @@ import (
// Builder provides a fluent interface for creating items
type Builder struct {
item *Item
db *database.DB
}
// NewBuilder creates a new item builder
func NewBuilder() *Builder {
func NewBuilder(db *database.DB) *Builder {
return &Builder{
item: &Item{},
item: &Item{db: db},
db: db,
}
}
@ -53,8 +55,8 @@ func (b *Builder) WithSpecial(special string) *Builder {
func (b *Builder) Create() (*Item, error) {
// Use a transaction to ensure we can get the ID
var item *Item
err := database.Transaction(func(tx *database.Tx) error {
query := `INSERT INTO items (type, name, value, att, special)
err := b.db.Transaction(func(tx *database.Tx) error {
query := `INSERT INTO items (type, name, value, att, special)
VALUES (?, ?, ?, ?, ?)`
if err := tx.Exec(query, b.item.Type, b.item.Name, b.item.Value, b.item.Att, b.item.Special); err != nil {
@ -79,6 +81,7 @@ func (b *Builder) Create() (*Item, error) {
Value: b.item.Value,
Att: b.item.Att,
Special: b.item.Special,
db: b.db,
}
return nil

View File

@ -16,6 +16,8 @@ type Item struct {
Value int `json:"value"`
Att int `json:"att"`
Special string `json:"special"`
db *database.DB
}
// ItemType constants for item types
@ -26,11 +28,11 @@ const (
)
// Find retrieves an item by ID
func Find(id int) (*Item, error) {
item := &Item{}
func Find(db *database.DB, id int) (*Item, error) {
item := &Item{db: db}
query := "SELECT id, type, name, value, att, special FROM items WHERE id = ?"
err := database.Query(query, func(stmt *sqlite.Stmt) error {
err := db.Query(query, func(stmt *sqlite.Stmt) error {
item.ID = stmt.ColumnInt(0)
item.Type = stmt.ColumnInt(1)
item.Name = stmt.ColumnText(2)
@ -52,11 +54,11 @@ func Find(id int) (*Item, error) {
}
// All retrieves all items
func All() ([]*Item, error) {
func All(db *database.DB) ([]*Item, error) {
var items []*Item
query := "SELECT id, type, name, value, att, special FROM items ORDER BY id"
err := database.Query(query, func(stmt *sqlite.Stmt) error {
err := db.Query(query, func(stmt *sqlite.Stmt) error {
item := &Item{
ID: stmt.ColumnInt(0),
Type: stmt.ColumnInt(1),
@ -64,6 +66,7 @@ func All() ([]*Item, error) {
Value: stmt.ColumnInt(3),
Att: stmt.ColumnInt(4),
Special: stmt.ColumnText(5),
db: db,
}
items = append(items, item)
return nil
@ -77,11 +80,11 @@ func All() ([]*Item, error) {
}
// ByType retrieves items by type
func ByType(itemType int) ([]*Item, error) {
func ByType(db *database.DB, itemType int) ([]*Item, error) {
var items []*Item
query := "SELECT id, type, name, value, att, special FROM items WHERE type = ? ORDER BY id"
err := database.Query(query, func(stmt *sqlite.Stmt) error {
err := db.Query(query, func(stmt *sqlite.Stmt) error {
item := &Item{
ID: stmt.ColumnInt(0),
Type: stmt.ColumnInt(1),
@ -89,6 +92,7 @@ func ByType(itemType int) ([]*Item, error) {
Value: stmt.ColumnInt(3),
Att: stmt.ColumnInt(4),
Special: stmt.ColumnText(5),
db: db,
}
items = append(items, item)
return nil
@ -108,7 +112,7 @@ func (i *Item) Save() error {
}
query := `UPDATE items SET type = ?, name = ?, value = ?, att = ?, special = ? WHERE id = ?`
return database.Exec(query, i.Type, i.Name, i.Value, i.Att, i.Special, i.ID)
return i.db.Exec(query, i.Type, i.Name, i.Value, i.Att, i.Special, i.ID)
}
// Delete removes the item from the database
@ -118,7 +122,7 @@ func (i *Item) Delete() error {
}
query := "DELETE FROM items WHERE id = ?"
return database.Exec(query, i.ID)
return i.db.Exec(query, i.ID)
}
// IsWeapon returns true if the item is a weapon

View File

@ -10,12 +10,14 @@ import (
// Builder provides a fluent interface for creating monsters
type Builder struct {
monster *Monster
db *database.DB
}
// NewBuilder creates a new monster builder
func NewBuilder() *Builder {
func NewBuilder(db *database.DB) *Builder {
return &Builder{
monster: &Monster{},
monster: &Monster{db: db},
db: db,
}
}
@ -71,8 +73,8 @@ func (b *Builder) WithImmunity(immunity int) *Builder {
func (b *Builder) Create() (*Monster, error) {
// Use a transaction to ensure we can get the ID
var monster *Monster
err := database.Transaction(func(tx *database.Tx) error {
query := `INSERT INTO monsters (name, max_hp, max_dmg, armor, level, max_exp, max_gold, immune)
err := b.db.Transaction(func(tx *database.Tx) error {
query := `INSERT INTO monsters (name, max_hp, max_dmg, armor, level, max_exp, max_gold, immune)
VALUES (?, ?, ?, ?, ?, ?, ?, ?)`
if err := tx.Exec(query, b.monster.Name, b.monster.MaxHP, b.monster.MaxDmg, b.monster.Armor,
@ -101,6 +103,7 @@ func (b *Builder) Create() (*Monster, error) {
MaxExp: b.monster.MaxExp,
MaxGold: b.monster.MaxGold,
Immune: b.monster.Immune,
db: b.db,
}
return nil

View File

@ -19,6 +19,8 @@ type Monster struct {
MaxExp int `json:"max_exp"`
MaxGold int `json:"max_gold"`
Immune int `json:"immune"`
db *database.DB
}
// Immunity constants for monster immunity types
@ -29,11 +31,11 @@ const (
)
// Find retrieves a monster by ID
func Find(id int) (*Monster, error) {
monster := &Monster{}
func Find(db *database.DB, id int) (*Monster, error) {
monster := &Monster{db: db}
query := "SELECT id, name, max_hp, max_dmg, armor, level, max_exp, max_gold, immune FROM monsters WHERE id = ?"
err := database.Query(query, func(stmt *sqlite.Stmt) error {
err := db.Query(query, func(stmt *sqlite.Stmt) error {
monster.ID = stmt.ColumnInt(0)
monster.Name = stmt.ColumnText(1)
monster.MaxHP = stmt.ColumnInt(2)
@ -58,11 +60,11 @@ func Find(id int) (*Monster, error) {
}
// All retrieves all monsters
func All() ([]*Monster, error) {
func All(db *database.DB) ([]*Monster, error) {
var monsters []*Monster
query := "SELECT id, name, max_hp, max_dmg, armor, level, max_exp, max_gold, immune FROM monsters ORDER BY level, id"
err := database.Query(query, func(stmt *sqlite.Stmt) error {
err := db.Query(query, func(stmt *sqlite.Stmt) error {
monster := &Monster{
ID: stmt.ColumnInt(0),
Name: stmt.ColumnText(1),
@ -73,6 +75,7 @@ func All() ([]*Monster, error) {
MaxExp: stmt.ColumnInt(6),
MaxGold: stmt.ColumnInt(7),
Immune: stmt.ColumnInt(8),
db: db,
}
monsters = append(monsters, monster)
return nil
@ -86,11 +89,11 @@ func All() ([]*Monster, error) {
}
// ByLevel retrieves monsters by level
func ByLevel(level int) ([]*Monster, error) {
func ByLevel(db *database.DB, level int) ([]*Monster, error) {
var monsters []*Monster
query := "SELECT id, name, max_hp, max_dmg, armor, level, max_exp, max_gold, immune FROM monsters WHERE level = ? ORDER BY id"
err := database.Query(query, func(stmt *sqlite.Stmt) error {
err := db.Query(query, func(stmt *sqlite.Stmt) error {
monster := &Monster{
ID: stmt.ColumnInt(0),
Name: stmt.ColumnText(1),
@ -101,6 +104,7 @@ func ByLevel(level int) ([]*Monster, error) {
MaxExp: stmt.ColumnInt(6),
MaxGold: stmt.ColumnInt(7),
Immune: stmt.ColumnInt(8),
db: db,
}
monsters = append(monsters, monster)
return nil
@ -114,11 +118,11 @@ func ByLevel(level int) ([]*Monster, error) {
}
// ByLevelRange retrieves monsters within a level range (inclusive)
func ByLevelRange(minLevel, maxLevel int) ([]*Monster, error) {
func ByLevelRange(db *database.DB, minLevel, maxLevel int) ([]*Monster, error) {
var monsters []*Monster
query := "SELECT id, name, max_hp, max_dmg, armor, level, max_exp, max_gold, immune FROM monsters WHERE level BETWEEN ? AND ? ORDER BY level, id"
err := database.Query(query, func(stmt *sqlite.Stmt) error {
err := db.Query(query, func(stmt *sqlite.Stmt) error {
monster := &Monster{
ID: stmt.ColumnInt(0),
Name: stmt.ColumnText(1),
@ -129,6 +133,7 @@ func ByLevelRange(minLevel, maxLevel int) ([]*Monster, error) {
MaxExp: stmt.ColumnInt(6),
MaxGold: stmt.ColumnInt(7),
Immune: stmt.ColumnInt(8),
db: db,
}
monsters = append(monsters, monster)
return nil
@ -142,11 +147,11 @@ func ByLevelRange(minLevel, maxLevel int) ([]*Monster, error) {
}
// ByImmunity retrieves monsters by immunity type
func ByImmunity(immunityType int) ([]*Monster, error) {
func ByImmunity(db *database.DB, immunityType int) ([]*Monster, error) {
var monsters []*Monster
query := "SELECT id, name, max_hp, max_dmg, armor, level, max_exp, max_gold, immune FROM monsters WHERE immune = ? ORDER BY level, id"
err := database.Query(query, func(stmt *sqlite.Stmt) error {
err := db.Query(query, func(stmt *sqlite.Stmt) error {
monster := &Monster{
ID: stmt.ColumnInt(0),
Name: stmt.ColumnText(1),
@ -157,6 +162,7 @@ func ByImmunity(immunityType int) ([]*Monster, error) {
MaxExp: stmt.ColumnInt(6),
MaxGold: stmt.ColumnInt(7),
Immune: stmt.ColumnInt(8),
db: db,
}
monsters = append(monsters, monster)
return nil
@ -176,7 +182,7 @@ func (m *Monster) Save() error {
}
query := `UPDATE monsters SET name = ?, max_hp = ?, max_dmg = ?, armor = ?, level = ?, max_exp = ?, max_gold = ?, immune = ? WHERE id = ?`
return database.Exec(query, m.Name, m.MaxHP, m.MaxDmg, m.Armor, m.Level, m.MaxExp, m.MaxGold, m.Immune, m.ID)
return m.db.Exec(query, m.Name, m.MaxHP, m.MaxDmg, m.Armor, m.Level, m.MaxExp, m.MaxGold, m.Immune, m.ID)
}
// Delete removes the monster from the database
@ -186,7 +192,7 @@ func (m *Monster) Delete() error {
}
query := "DELETE FROM monsters WHERE id = ?"
return database.Exec(query, m.ID)
return m.db.Exec(query, m.ID)
}
// IsHurtImmune returns true if the monster is immune to Hurt spells

View File

@ -12,14 +12,17 @@ import (
// Builder provides a fluent interface for creating news posts
type Builder struct {
news *News
db *database.DB
}
// NewBuilder creates a new news builder
func NewBuilder() *Builder {
func NewBuilder(db *database.DB) *Builder {
return &Builder{
news: &News{
db: db,
Posted: time.Now().Unix(), // Default to current time
},
db: db,
}
}
@ -51,10 +54,10 @@ func (b *Builder) WithPostedTime(t time.Time) *Builder {
func (b *Builder) Create() (*News, error) {
// Use a transaction to ensure we can get the ID
var news *News
err := database.Transaction(func(tx *database.Tx) error {
query := `INSERT INTO news (author, posted, content)
err := b.db.Transaction(func(tx *database.Tx) error {
query := `INSERT INTO news (author, posted, content)
VALUES (?, ?, ?)`
if err := tx.Exec(query, b.news.Author, b.news.Posted, b.news.Content); err != nil {
return fmt.Errorf("failed to insert news: %w", err)
}
@ -73,10 +76,10 @@ func (b *Builder) Create() (*News, error) {
news = b.news
return nil
})
if err != nil {
return nil, err
}
return news, nil
}
}

View File

@ -1,10 +1,11 @@
package news
import (
"dk/internal/database"
"fmt"
"time"
"dk/internal/database"
"zombiezen.com/go/sqlite"
)
@ -14,14 +15,16 @@ type News struct {
Author int `json:"author"`
Posted int64 `json:"posted"`
Content string `json:"content"`
db *database.DB
}
// Find retrieves a news post by ID
func Find(id int) (*News, error) {
news := &News{}
func Find(db *database.DB, id int) (*News, error) {
news := &News{db: db}
query := "SELECT id, author, posted, content FROM news WHERE id = ?"
err := database.Query(query, func(stmt *sqlite.Stmt) error {
err := db.Query(query, func(stmt *sqlite.Stmt) error {
news.ID = stmt.ColumnInt(0)
news.Author = stmt.ColumnInt(1)
news.Posted = stmt.ColumnInt64(2)
@ -41,16 +44,17 @@ func Find(id int) (*News, error) {
}
// All retrieves all news posts ordered by posted date (newest first)
func All() ([]*News, error) {
func All(db *database.DB) ([]*News, error) {
var newsPosts []*News
query := "SELECT id, author, posted, content FROM news ORDER BY posted DESC, id DESC"
err := database.Query(query, func(stmt *sqlite.Stmt) error {
err := db.Query(query, func(stmt *sqlite.Stmt) error {
news := &News{
ID: stmt.ColumnInt(0),
Author: stmt.ColumnInt(1),
Posted: stmt.ColumnInt64(2),
Content: stmt.ColumnText(3),
db: db,
}
newsPosts = append(newsPosts, news)
return nil
@ -64,16 +68,17 @@ func All() ([]*News, error) {
}
// ByAuthor retrieves news posts by a specific author
func ByAuthor(authorID int) ([]*News, error) {
func ByAuthor(db *database.DB, authorID int) ([]*News, error) {
var newsPosts []*News
query := "SELECT id, author, posted, content FROM news WHERE author = ? ORDER BY posted DESC, id DESC"
err := database.Query(query, func(stmt *sqlite.Stmt) error {
err := db.Query(query, func(stmt *sqlite.Stmt) error {
news := &News{
ID: stmt.ColumnInt(0),
Author: stmt.ColumnInt(1),
Posted: stmt.ColumnInt64(2),
Content: stmt.ColumnText(3),
db: db,
}
newsPosts = append(newsPosts, news)
return nil
@ -87,16 +92,17 @@ func ByAuthor(authorID int) ([]*News, error) {
}
// Recent retrieves the most recent news posts (limited by count)
func Recent(limit int) ([]*News, error) {
func Recent(db *database.DB, limit int) ([]*News, error) {
var newsPosts []*News
query := "SELECT id, author, posted, content FROM news ORDER BY posted DESC, id DESC LIMIT ?"
err := database.Query(query, func(stmt *sqlite.Stmt) error {
err := db.Query(query, func(stmt *sqlite.Stmt) error {
news := &News{
ID: stmt.ColumnInt(0),
Author: stmt.ColumnInt(1),
Posted: stmt.ColumnInt64(2),
Content: stmt.ColumnText(3),
db: db,
}
newsPosts = append(newsPosts, news)
return nil
@ -110,16 +116,17 @@ func Recent(limit int) ([]*News, error) {
}
// Since retrieves news posts since a specific timestamp
func Since(since int64) ([]*News, error) {
func Since(db *database.DB, since int64) ([]*News, error) {
var newsPosts []*News
query := "SELECT id, author, posted, content FROM news WHERE posted >= ? ORDER BY posted DESC, id DESC"
err := database.Query(query, func(stmt *sqlite.Stmt) error {
err := db.Query(query, func(stmt *sqlite.Stmt) error {
news := &News{
ID: stmt.ColumnInt(0),
Author: stmt.ColumnInt(1),
Posted: stmt.ColumnInt64(2),
Content: stmt.ColumnText(3),
db: db,
}
newsPosts = append(newsPosts, news)
return nil
@ -133,16 +140,17 @@ func Since(since int64) ([]*News, error) {
}
// Between retrieves news posts between two timestamps (inclusive)
func Between(start, end int64) ([]*News, error) {
func Between(db *database.DB, start, end int64) ([]*News, error) {
var newsPosts []*News
query := "SELECT id, author, posted, content FROM news WHERE posted >= ? AND posted <= ? ORDER BY posted DESC, id DESC"
err := database.Query(query, func(stmt *sqlite.Stmt) error {
err := db.Query(query, func(stmt *sqlite.Stmt) error {
news := &News{
ID: stmt.ColumnInt(0),
Author: stmt.ColumnInt(1),
Posted: stmt.ColumnInt64(2),
Content: stmt.ColumnText(3),
db: db,
}
newsPosts = append(newsPosts, news)
return nil
@ -162,7 +170,7 @@ func (n *News) Save() error {
}
query := `UPDATE news SET author = ?, posted = ?, content = ? WHERE id = ?`
return database.Exec(query, n.Author, n.Posted, n.Content, n.ID)
return n.db.Exec(query, n.Author, n.Posted, n.Content, n.ID)
}
// Delete removes the news post from the database
@ -172,7 +180,7 @@ func (n *News) Delete() error {
}
query := "DELETE FROM news WHERE id = ?"
return database.Exec(query, n.ID)
return n.db.Exec(query, n.ID)
}
// PostedTime returns the posted timestamp as a time.Time
@ -205,11 +213,11 @@ func (n *News) Preview(maxLength int) string {
if len(n.Content) <= maxLength {
return n.Content
}
if maxLength < 3 {
return n.Content[:maxLength]
}
return n.Content[:maxLength-3] + "..."
}
@ -218,11 +226,11 @@ func (n *News) WordCount() int {
if n.Content == "" {
return 0
}
// Simple word count by splitting on spaces
words := 0
inWord := false
for _, char := range n.Content {
if char == ' ' || char == '\t' || char == '\n' || char == '\r' {
if inWord {
@ -233,10 +241,10 @@ func (n *News) WordCount() int {
inWord = true
}
}
if inWord {
words++
}
return words
}
}

View File

@ -28,14 +28,18 @@ func Start(port string) error {
template.InitializeCache(cwd)
// Initialize database singleton
if err := database.Init("dk.db"); err != nil {
if err := database.InitializeDB("dk.db"); err != nil {
return fmt.Errorf("failed to initialize database: %w", err)
}
defer database.Close()
auth.Init("sessions.json") // Initialize auth.Manager
// Initialize auth singleton
auth.InitializeManager(database.GetDB(), "sessions.json")
// Initialize router
r := router.New()
// Add middleware
r.Use(middleware.Timing())
r.Use(middleware.Auth(auth.Manager))
r.Use(middleware.CSRF(auth.Manager))

View File

@ -10,12 +10,14 @@ import (
// Builder provides a fluent interface for creating spells
type Builder struct {
spell *Spell
db *database.DB
}
// NewBuilder creates a new spell builder
func NewBuilder() *Builder {
func NewBuilder(db *database.DB) *Builder {
return &Builder{
spell: &Spell{},
spell: &Spell{db: db},
db: db,
}
}
@ -47,8 +49,8 @@ func (b *Builder) WithType(spellType int) *Builder {
func (b *Builder) Create() (*Spell, error) {
// Use a transaction to ensure we can get the ID
var spell *Spell
err := database.Transaction(func(tx *database.Tx) error {
query := `INSERT INTO spells (name, mp, attribute, type)
err := b.db.Transaction(func(tx *database.Tx) error {
query := `INSERT INTO spells (name, mp, attribute, type)
VALUES (?, ?, ?, ?)`
if err := tx.Exec(query, b.spell.Name, b.spell.MP, b.spell.Attribute, b.spell.Type); err != nil {
@ -72,6 +74,7 @@ func (b *Builder) Create() (*Spell, error) {
MP: b.spell.MP,
Attribute: b.spell.Attribute,
Type: b.spell.Type,
db: b.db,
}
return nil

View File

@ -15,6 +15,8 @@ type Spell struct {
MP int `json:"mp"`
Attribute int `json:"attribute"`
Type int `json:"type"`
db *database.DB
}
// SpellType constants for spell types
@ -27,11 +29,11 @@ const (
)
// Find retrieves a spell by ID
func Find(id int) (*Spell, error) {
spell := &Spell{}
func Find(db *database.DB, id int) (*Spell, error) {
spell := &Spell{db: db}
query := "SELECT id, name, mp, attribute, type FROM spells WHERE id = ?"
err := database.Query(query, func(stmt *sqlite.Stmt) error {
err := db.Query(query, func(stmt *sqlite.Stmt) error {
spell.ID = stmt.ColumnInt(0)
spell.Name = stmt.ColumnText(1)
spell.MP = stmt.ColumnInt(2)
@ -52,17 +54,18 @@ func Find(id int) (*Spell, error) {
}
// All retrieves all spells
func All() ([]*Spell, error) {
func All(db *database.DB) ([]*Spell, error) {
var spells []*Spell
query := "SELECT id, name, mp, attribute, type FROM spells ORDER BY type, mp, id"
err := database.Query(query, func(stmt *sqlite.Stmt) error {
err := db.Query(query, func(stmt *sqlite.Stmt) error {
spell := &Spell{
ID: stmt.ColumnInt(0),
Name: stmt.ColumnText(1),
MP: stmt.ColumnInt(2),
Attribute: stmt.ColumnInt(3),
Type: stmt.ColumnInt(4),
db: db,
}
spells = append(spells, spell)
return nil
@ -76,17 +79,18 @@ func All() ([]*Spell, error) {
}
// ByType retrieves spells by type
func ByType(spellType int) ([]*Spell, error) {
func ByType(db *database.DB, spellType int) ([]*Spell, error) {
var spells []*Spell
query := "SELECT id, name, mp, attribute, type FROM spells WHERE type = ? ORDER BY mp, id"
err := database.Query(query, func(stmt *sqlite.Stmt) error {
err := db.Query(query, func(stmt *sqlite.Stmt) error {
spell := &Spell{
ID: stmt.ColumnInt(0),
Name: stmt.ColumnText(1),
MP: stmt.ColumnInt(2),
Attribute: stmt.ColumnInt(3),
Type: stmt.ColumnInt(4),
db: db,
}
spells = append(spells, spell)
return nil
@ -100,17 +104,18 @@ func ByType(spellType int) ([]*Spell, error) {
}
// ByMaxMP retrieves spells that cost at most the specified MP
func ByMaxMP(maxMP int) ([]*Spell, error) {
func ByMaxMP(db *database.DB, maxMP int) ([]*Spell, error) {
var spells []*Spell
query := "SELECT id, name, mp, attribute, type FROM spells WHERE mp <= ? ORDER BY type, mp, id"
err := database.Query(query, func(stmt *sqlite.Stmt) error {
err := db.Query(query, func(stmt *sqlite.Stmt) error {
spell := &Spell{
ID: stmt.ColumnInt(0),
Name: stmt.ColumnText(1),
MP: stmt.ColumnInt(2),
Attribute: stmt.ColumnInt(3),
Type: stmt.ColumnInt(4),
db: db,
}
spells = append(spells, spell)
return nil
@ -124,17 +129,18 @@ func ByMaxMP(maxMP int) ([]*Spell, error) {
}
// ByTypeAndMaxMP retrieves spells of a specific type that cost at most the specified MP
func ByTypeAndMaxMP(spellType, maxMP int) ([]*Spell, error) {
func ByTypeAndMaxMP(db *database.DB, spellType, maxMP int) ([]*Spell, error) {
var spells []*Spell
query := "SELECT id, name, mp, attribute, type FROM spells WHERE type = ? AND mp <= ? ORDER BY mp, id"
err := database.Query(query, func(stmt *sqlite.Stmt) error {
err := db.Query(query, func(stmt *sqlite.Stmt) error {
spell := &Spell{
ID: stmt.ColumnInt(0),
Name: stmt.ColumnText(1),
MP: stmt.ColumnInt(2),
Attribute: stmt.ColumnInt(3),
Type: stmt.ColumnInt(4),
db: db,
}
spells = append(spells, spell)
return nil
@ -148,11 +154,11 @@ func ByTypeAndMaxMP(spellType, maxMP int) ([]*Spell, error) {
}
// ByName retrieves a spell by name (case-insensitive)
func ByName(name string) (*Spell, error) {
spell := &Spell{}
func ByName(db *database.DB, name string) (*Spell, error) {
spell := &Spell{db: db}
query := "SELECT id, name, mp, attribute, type FROM spells WHERE LOWER(name) = LOWER(?) LIMIT 1"
err := database.Query(query, func(stmt *sqlite.Stmt) error {
err := db.Query(query, func(stmt *sqlite.Stmt) error {
spell.ID = stmt.ColumnInt(0)
spell.Name = stmt.ColumnText(1)
spell.MP = stmt.ColumnInt(2)
@ -179,7 +185,7 @@ func (s *Spell) Save() error {
}
query := `UPDATE spells SET name = ?, mp = ?, attribute = ?, type = ? WHERE id = ?`
return database.Exec(query, s.Name, s.MP, s.Attribute, s.Type, s.ID)
return s.db.Exec(query, s.Name, s.MP, s.Attribute, s.Type, s.ID)
}
// Delete removes the spell from the database
@ -189,7 +195,7 @@ func (s *Spell) Delete() error {
}
query := "DELETE FROM spells WHERE id = ?"
return database.Exec(query, s.ID)
return s.db.Exec(query, s.ID)
}
// IsHealing returns true if the spell is a healing spell

View File

@ -11,12 +11,16 @@ import (
// Builder provides a fluent interface for creating towns
type Builder struct {
town *Town
db *database.DB
}
// NewBuilder creates a new town builder
func NewBuilder() *Builder {
func NewBuilder(db *database.DB) *Builder {
return &Builder{
town: &Town{},
town: &Town{
db: db,
},
db: db,
}
}
@ -79,11 +83,11 @@ func (b *Builder) WithShopItems(items []string) *Builder {
func (b *Builder) Create() (*Town, error) {
// Use a transaction to ensure we can get the ID
var town *Town
err := database.Transaction(func(tx *database.Tx) error {
query := `INSERT INTO towns (name, x, y, inn_cost, map_cost, tp_cost, shop_list)
err := b.db.Transaction(func(tx *database.Tx) error {
query := `INSERT INTO towns (name, x, y, inn_cost, map_cost, tp_cost, shop_list)
VALUES (?, ?, ?, ?, ?, ?, ?)`
if err := tx.Exec(query, b.town.Name, b.town.X, b.town.Y,
if err := tx.Exec(query, b.town.Name, b.town.X, b.town.Y,
b.town.InnCost, b.town.MapCost, b.town.TPCost, b.town.ShopList); err != nil {
return fmt.Errorf("failed to insert town: %w", err)
}
@ -102,10 +106,10 @@ func (b *Builder) Create() (*Town, error) {
town = b.town
return nil
})
if err != nil {
return nil, err
}
return town, nil
}
}

View File

@ -19,14 +19,16 @@ type Town struct {
MapCost int `json:"map_cost"`
TPCost int `json:"tp_cost"`
ShopList string `json:"shop_list"`
db *database.DB
}
// Find retrieves a town by ID
func Find(id int) (*Town, error) {
town := &Town{}
func Find(db *database.DB, id int) (*Town, error) {
town := &Town{db: db}
query := "SELECT id, name, x, y, inn_cost, map_cost, tp_cost, shop_list FROM towns WHERE id = ?"
err := database.Query(query, func(stmt *sqlite.Stmt) error {
err := db.Query(query, func(stmt *sqlite.Stmt) error {
town.ID = stmt.ColumnInt(0)
town.Name = stmt.ColumnText(1)
town.X = stmt.ColumnInt(2)
@ -50,11 +52,11 @@ func Find(id int) (*Town, error) {
}
// All retrieves all towns
func All() ([]*Town, error) {
func All(db *database.DB) ([]*Town, error) {
var towns []*Town
query := "SELECT id, name, x, y, inn_cost, map_cost, tp_cost, shop_list FROM towns ORDER BY id"
err := database.Query(query, func(stmt *sqlite.Stmt) error {
err := db.Query(query, func(stmt *sqlite.Stmt) error {
town := &Town{
ID: stmt.ColumnInt(0),
Name: stmt.ColumnText(1),
@ -64,6 +66,7 @@ func All() ([]*Town, error) {
MapCost: stmt.ColumnInt(5),
TPCost: stmt.ColumnInt(6),
ShopList: stmt.ColumnText(7),
db: db,
}
towns = append(towns, town)
return nil
@ -77,11 +80,11 @@ func All() ([]*Town, error) {
}
// ByName retrieves a town by name (case-insensitive)
func ByName(name string) (*Town, error) {
town := &Town{}
func ByName(db *database.DB, name string) (*Town, error) {
town := &Town{db: db}
query := "SELECT id, name, x, y, inn_cost, map_cost, tp_cost, shop_list FROM towns WHERE LOWER(name) = LOWER(?) LIMIT 1"
err := database.Query(query, func(stmt *sqlite.Stmt) error {
err := db.Query(query, func(stmt *sqlite.Stmt) error {
town.ID = stmt.ColumnInt(0)
town.Name = stmt.ColumnText(1)
town.X = stmt.ColumnInt(2)
@ -105,11 +108,11 @@ func ByName(name string) (*Town, error) {
}
// ByMaxInnCost retrieves towns with inn cost at most the specified amount
func ByMaxInnCost(maxCost int) ([]*Town, error) {
func ByMaxInnCost(db *database.DB, maxCost int) ([]*Town, error) {
var towns []*Town
query := "SELECT id, name, x, y, inn_cost, map_cost, tp_cost, shop_list FROM towns WHERE inn_cost <= ? ORDER BY inn_cost, id"
err := database.Query(query, func(stmt *sqlite.Stmt) error {
err := db.Query(query, func(stmt *sqlite.Stmt) error {
town := &Town{
ID: stmt.ColumnInt(0),
Name: stmt.ColumnText(1),
@ -119,6 +122,7 @@ func ByMaxInnCost(maxCost int) ([]*Town, error) {
MapCost: stmt.ColumnInt(5),
TPCost: stmt.ColumnInt(6),
ShopList: stmt.ColumnText(7),
db: db,
}
towns = append(towns, town)
return nil
@ -132,11 +136,11 @@ func ByMaxInnCost(maxCost int) ([]*Town, error) {
}
// ByMaxTPCost retrieves towns with teleport cost at most the specified amount
func ByMaxTPCost(maxCost int) ([]*Town, error) {
func ByMaxTPCost(db *database.DB, maxCost int) ([]*Town, error) {
var towns []*Town
query := "SELECT id, name, x, y, inn_cost, map_cost, tp_cost, shop_list FROM towns WHERE tp_cost <= ? ORDER BY tp_cost, id"
err := database.Query(query, func(stmt *sqlite.Stmt) error {
err := db.Query(query, func(stmt *sqlite.Stmt) error {
town := &Town{
ID: stmt.ColumnInt(0),
Name: stmt.ColumnText(1),
@ -146,6 +150,7 @@ func ByMaxTPCost(maxCost int) ([]*Town, error) {
MapCost: stmt.ColumnInt(5),
TPCost: stmt.ColumnInt(6),
ShopList: stmt.ColumnText(7),
db: db,
}
towns = append(towns, town)
return nil
@ -159,16 +164,16 @@ func ByMaxTPCost(maxCost int) ([]*Town, error) {
}
// ByDistance retrieves towns within a certain distance from a point
func ByDistance(fromX, fromY, maxDistance int) ([]*Town, error) {
func ByDistance(db *database.DB, fromX, fromY, maxDistance int) ([]*Town, error) {
var towns []*Town
query := `SELECT id, name, x, y, inn_cost, map_cost, tp_cost, shop_list
FROM towns
query := `SELECT id, name, x, y, inn_cost, map_cost, tp_cost, shop_list
FROM towns
WHERE ((x - ?) * (x - ?) + (y - ?) * (y - ?)) <= ?
ORDER BY ((x - ?) * (x - ?) + (y - ?) * (y - ?)), id`
maxDistance2 := maxDistance * maxDistance
err := database.Query(query, func(stmt *sqlite.Stmt) error {
err := db.Query(query, func(stmt *sqlite.Stmt) error {
town := &Town{
ID: stmt.ColumnInt(0),
Name: stmt.ColumnText(1),
@ -178,6 +183,7 @@ func ByDistance(fromX, fromY, maxDistance int) ([]*Town, error) {
MapCost: stmt.ColumnInt(5),
TPCost: stmt.ColumnInt(6),
ShopList: stmt.ColumnText(7),
db: db,
}
towns = append(towns, town)
return nil
@ -197,7 +203,7 @@ func (t *Town) Save() error {
}
query := `UPDATE towns SET name = ?, x = ?, y = ?, inn_cost = ?, map_cost = ?, tp_cost = ?, shop_list = ? WHERE id = ?`
return database.Exec(query, t.Name, t.X, t.Y, t.InnCost, t.MapCost, t.TPCost, t.ShopList, t.ID)
return t.db.Exec(query, t.Name, t.X, t.Y, t.InnCost, t.MapCost, t.TPCost, t.ShopList, t.ID)
}
// Delete removes the town from the database
@ -207,7 +213,7 @@ func (t *Town) Delete() error {
}
query := "DELETE FROM towns WHERE id = ?"
return database.Exec(query, t.ID)
return t.db.Exec(query, t.ID)
}
// GetShopItems returns the shop items as a slice of item IDs
@ -259,4 +265,4 @@ func (t *Town) CanAffordMap(gold int) bool {
// CanAffordTeleport returns true if the player can afford to teleport here
func (t *Town) CanAffordTeleport(gold int) bool {
return gold >= t.TPCost
}
}

View File

@ -1,10 +0,0 @@
package utils
import "github.com/valyala/fasthttp"
// IsHTTPS tries to determine if the current request context is over HTTPS
func IsHTTPS(ctx *fasthttp.RequestCtx) bool {
return ctx.IsTLS() ||
string(ctx.Request.Header.Peek("X-Forwarded-Proto")) == "https" ||
string(ctx.Request.Header.Peek("X-Forwarded-Scheme")) == "https"
}