write database wrapper

This commit is contained in:
Sky Johnson 2025-08-08 22:42:41 -05:00
parent d04acc06eb
commit db2a95cd02
4 changed files with 313 additions and 69 deletions

View File

@ -0,0 +1,119 @@
package database
import (
"fmt"
"zombiezen.com/go/sqlite"
"zombiezen.com/go/sqlite/sqlitex"
)
const DefaultPath = "dk.db"
// DB wraps a SQLite connection with simplified methods
type DB struct {
conn *sqlite.Conn
}
// Open creates a new database connection
func Open(path string) (*DB, error) {
if path == "" {
path = DefaultPath
}
conn, err := sqlite.OpenConn(path, sqlite.OpenCreate|sqlite.OpenReadWrite|sqlite.OpenWAL)
if err != nil {
return nil, fmt.Errorf("failed to open database: %w", err)
}
// Set pragmas for performance
if err := sqlitex.ExecuteTransient(conn, "PRAGMA journal_mode = WAL", nil); err != nil {
conn.Close()
return nil, fmt.Errorf("failed to set WAL mode: %w", err)
}
if err := sqlitex.ExecuteTransient(conn, "PRAGMA synchronous = NORMAL", nil); err != nil {
conn.Close()
return nil, fmt.Errorf("failed to set synchronous mode: %w", err)
}
return &DB{conn: conn}, nil
}
// Close closes the database connection
func (db *DB) Close() error {
return db.conn.Close()
}
// Exec executes a SQL statement without returning results
func (db *DB) Exec(query string, args ...any) error {
if len(args) == 0 {
return sqlitex.ExecuteTransient(db.conn, query, nil)
}
return sqlitex.ExecuteTransient(db.conn, query, &sqlitex.ExecOptions{
Args: args,
})
}
// Query executes a SQL query and calls fn for each row
func (db *DB) Query(query string, fn func(*sqlite.Stmt) error, args ...any) error {
if len(args) == 0 {
return sqlitex.ExecuteTransient(db.conn, query, &sqlitex.ExecOptions{
ResultFunc: fn,
})
}
return sqlitex.ExecuteTransient(db.conn, query, &sqlitex.ExecOptions{
Args: args,
ResultFunc: fn,
})
}
// Begin starts a new transaction
func (db *DB) Begin() (*Tx, error) {
if err := sqlitex.ExecuteTransient(db.conn, "BEGIN", nil); err != nil {
return nil, fmt.Errorf("failed to begin transaction: %w", err)
}
return &Tx{db: db}, nil
}
// Transaction runs a function within a transaction
func (db *DB) Transaction(fn func(*Tx) error) error {
tx, err := db.Begin()
if err != nil {
return err
}
if err := fn(tx); err != nil {
tx.Rollback()
return err
}
return tx.Commit()
}
// Tx represents a database transaction
type Tx struct {
db *DB
}
// Exec executes a SQL statement within the transaction
func (tx *Tx) Exec(query string, args ...any) error {
return tx.db.Exec(query, args...)
}
// Query executes a SQL query within the transaction
func (tx *Tx) Query(query string, fn func(*sqlite.Stmt) error, args ...any) error {
return tx.db.Query(query, fn, args...)
}
// Commit commits the transaction
func (tx *Tx) Commit() error {
return sqlitex.ExecuteTransient(tx.db.conn, "COMMIT", nil)
}
// Rollback rolls back the transaction
func (tx *Tx) Rollback() error {
return sqlitex.ExecuteTransient(tx.db.conn, "ROLLBACK", nil)
}

View File

@ -0,0 +1,72 @@
package database
import (
"os"
"testing"
"zombiezen.com/go/sqlite"
)
func TestDatabaseOperations(t *testing.T) {
// Use a temporary database file
testDB := "test.db"
defer os.Remove(testDB)
// Test opening database
db, err := Open(testDB)
if err != nil {
t.Fatalf("Failed to open database: %v", err)
}
defer db.Close()
// Test creating a simple table
err = db.Exec("CREATE TABLE test_users (id INTEGER PRIMARY KEY, name TEXT)")
if err != nil {
t.Fatalf("Failed to create table: %v", err)
}
// Test inserting data
err = db.Exec("INSERT INTO test_users (name) VALUES (?)", "Alice")
if err != nil {
t.Fatalf("Failed to insert data: %v", err)
}
// Test querying data
var foundName string
err = db.Query("SELECT name FROM test_users WHERE name = ?", func(stmt *sqlite.Stmt) error {
foundName = stmt.ColumnText(0)
return nil
}, "Alice")
if err != nil {
t.Fatalf("Failed to query data: %v", err)
}
if foundName != "Alice" {
t.Errorf("Expected 'Alice', got '%s'", foundName)
}
// Test transaction
err = db.Transaction(func(tx *Tx) error {
return tx.Exec("INSERT INTO test_users (name) VALUES (?)", "Bob")
})
if err != nil {
t.Fatalf("Transaction failed: %v", err)
}
// Verify transaction worked
var count int
err = db.Query("SELECT COUNT(*) FROM test_users", func(stmt *sqlite.Stmt) error {
count = stmt.ColumnInt(0)
return nil
})
if err != nil {
t.Fatalf("Failed to count users: %v", err)
}
if count != 2 {
t.Errorf("Expected 2 users, got %d", count)
}
}

View File

@ -4,10 +4,8 @@ import (
"fmt" "fmt"
"time" "time"
"dk/internal/database"
"dk/internal/password" "dk/internal/password"
"zombiezen.com/go/sqlite"
"zombiezen.com/go/sqlite/sqlitex"
) )
const dbPath = "dk.db" const dbPath = "dk.db"
@ -19,24 +17,24 @@ func Run() error {
start := time.Now() start := time.Now()
// Open database connection // Open database connection
conn, err := sqlite.OpenConn(dbPath, sqlite.OpenCreate|sqlite.OpenReadWrite|sqlite.OpenWAL) db, err := database.Open(dbPath)
if err != nil { if err != nil {
return fmt.Errorf("failed to open database: %w", err) return err
} }
defer conn.Close() defer db.Close()
// Create tables // Create tables
if err := createTables(conn); err != nil { if err := createTables(db); err != nil {
return fmt.Errorf("failed to create tables: %w", err) return fmt.Errorf("failed to create tables: %w", err)
} }
// Populate initial data // Populate initial data
if err := populateData(conn); err != nil { if err := populateData(db); err != nil {
return fmt.Errorf("failed to populate data: %w", err) return fmt.Errorf("failed to populate data: %w", err)
} }
// Create demo user // Create demo user
if err := createDemoUser(conn); err != nil { if err := createDemoUser(db); err != nil {
return fmt.Errorf("failed to create demo user: %w", err) return fmt.Errorf("failed to create demo user: %w", err)
} }
@ -51,7 +49,7 @@ func Run() error {
return nil return nil
} }
func createTables(conn *sqlite.Conn) error { func createTables(db *database.DB) error {
tables := []struct { tables := []struct {
name string name string
sql string sql string
@ -186,7 +184,7 @@ func createTables(conn *sqlite.Conn) error {
} }
for _, table := range tables { for _, table := range tables {
if err := sqlitex.ExecuteTransient(conn, table.sql, nil); err != nil { if err := db.Exec(table.sql); err != nil {
return fmt.Errorf("failed to create %s table: %w", table.name, err) return fmt.Errorf("failed to create %s table: %w", table.name, err)
} }
fmt.Printf("✓ %s table created\n", table.name) fmt.Printf("✓ %s table created\n", table.name)
@ -195,27 +193,25 @@ func createTables(conn *sqlite.Conn) error {
return nil return nil
} }
func populateData(conn *sqlite.Conn) error { func populateData(db *database.DB) error {
if err := sqlitex.ExecuteTransient(conn, if err := db.Exec("INSERT INTO control VALUES (1, 250, 1, '', 'Mage', 'Warrior', 'Paladin')"); err != nil {
"INSERT INTO control VALUES (1, 250, 1, '', 'Mage', 'Warrior', 'Paladin')",
nil); err != nil {
return fmt.Errorf("failed to populate control table: %w", err) return fmt.Errorf("failed to populate control table: %w", err)
} }
fmt.Println("✓ Control table populated") fmt.Println("✓ Control table populated")
dropsSQL := `INSERT INTO drops VALUES dropsSQL := `INSERT INTO drops VALUES
(1, 'Life Pebble', 1, 1, 'maxhp,10', 'X'), (1, 'Life Pebble', 1, 1, 'maxhp,10', ''),
(2, 'Life Stone', 10, 1, 'maxhp,25', 'X'), (2, 'Life Stone', 10, 1, 'maxhp,25', ''),
(3, 'Life Rock', 25, 1, 'maxhp,50', 'X'), (3, 'Life Rock', 25, 1, 'maxhp,50', ''),
(4, 'Magic Pebble', 1, 1, 'maxmp,10', 'X'), (4, 'Magic Pebble', 1, 1, 'maxmp,10', ''),
(5, 'Magic Stone', 10, 1, 'maxmp,25', 'X'), (5, 'Magic Stone', 10, 1, 'maxmp,25', ''),
(6, 'Magic Rock', 25, 1, 'maxmp,50', 'X'), (6, 'Magic Rock', 25, 1, 'maxmp,50', ''),
(7, 'Dragon''s Scale', 10, 1, 'defensepower,25', 'X'), (7, 'Dragon''s Scale', 10, 1, 'defensepower,25', ''),
(8, 'Dragon''s Plate', 30, 1, 'defensepower,50', 'X'), (8, 'Dragon''s Plate', 30, 1, 'defensepower,50', ''),
(9, 'Dragon''s Claw', 10, 1, 'attackpower,25', 'X'), (9, 'Dragon''s Claw', 10, 1, 'attackpower,25', ''),
(10, 'Dragon''s Tooth', 30, 1, 'attackpower,50', 'X'), (10, 'Dragon''s Tooth', 30, 1, 'attackpower,50', ''),
(11, 'Dragon''s Tear', 35, 1, 'strength,50', 'X'), (11, 'Dragon''s Tear', 35, 1, 'strength,50', ''),
(12, 'Dragon''s Wing', 35, 1, 'dexterity,50', 'X'), (12, 'Dragon''s Wing', 35, 1, 'dexterity,50', ''),
(13, 'Demon''s Sin', 35, 1, 'maxhp,-50', 'strength,50'), (13, 'Demon''s Sin', 35, 1, 'maxhp,-50', 'strength,50'),
(14, 'Demon''s Fall', 35, 1, 'maxmp,-50', 'strength,50'), (14, 'Demon''s Fall', 35, 1, 'maxmp,-50', 'strength,50'),
(15, 'Demon''s Lie', 45, 1, 'maxhp,-100', 'strength,100'), (15, 'Demon''s Lie', 45, 1, 'maxhp,-100', 'strength,100'),
@ -228,54 +224,54 @@ func populateData(conn *sqlite.Conn) error {
(22, 'Seraph''s Rise', 30, 1, 'maxmp,50', 'dexterity,50'), (22, 'Seraph''s Rise', 30, 1, 'maxmp,50', 'dexterity,50'),
(23, 'Seraph''s Truth', 35, 1, 'maxmp,75', 'dexterity,75'), (23, 'Seraph''s Truth', 35, 1, 'maxmp,75', 'dexterity,75'),
(24, 'Seraph''s Love', 40, 1, 'maxmp,100', 'dexterity,100'), (24, 'Seraph''s Love', 40, 1, 'maxmp,100', 'dexterity,100'),
(25, 'Ruby', 50, 1, 'maxhp,150', 'X'), (25, 'Ruby', 50, 1, 'maxhp,150', ''),
(26, 'Pearl', 50, 1, 'maxmp,150', 'X'), (26, 'Pearl', 50, 1, 'maxmp,150', ''),
(27, 'Emerald', 50, 1, 'strength,150', 'X'), (27, 'Emerald', 50, 1, 'strength,150', ''),
(28, 'Topaz', 50, 1, 'dexterity,150', 'X'), (28, 'Topaz', 50, 1, 'dexterity,150', ''),
(29, 'Obsidian', 50, 1, 'attackpower,150', 'X'), (29, 'Obsidian', 50, 1, 'attackpower,150', ''),
(30, 'Diamond', 50, 1, 'defensepower,150', 'X'), (30, 'Diamond', 50, 1, 'defensepower,150', ''),
(31, 'Memory Drop', 5, 1, 'expbonus,10', 'X'), (31, 'Memory Drop', 5, 1, 'expbonus,10', ''),
(32, 'Fortune Drop', 5, 1, 'goldbonus,10', 'X')` (32, 'Fortune Drop', 5, 1, 'goldbonus,10', '')`
if err := sqlitex.ExecuteTransient(conn, dropsSQL, nil); err != nil { if err := db.Exec(dropsSQL); err != nil {
return fmt.Errorf("failed to populate drops table: %w", err) return fmt.Errorf("failed to populate drops table: %w", err)
} }
fmt.Println("✓ Drops table populated") fmt.Println("✓ Drops table populated")
itemsSQL := `INSERT INTO items VALUES itemsSQL := `INSERT INTO items VALUES
(1, 1, 'Stick', 10, 2, 'X'), (1, 1, 'Stick', 10, 2, ''),
(2, 1, 'Branch', 30, 4, 'X'), (2, 1, 'Branch', 30, 4, ''),
(3, 1, 'Club', 40, 5, 'X'), (3, 1, 'Club', 40, 5, ''),
(4, 1, 'Dagger', 90, 8, 'X'), (4, 1, 'Dagger', 90, 8, ''),
(5, 1, 'Hatchet', 150, 12, 'X'), (5, 1, 'Hatchet', 150, 12, ''),
(6, 1, 'Axe', 200, 16, 'X'), (6, 1, 'Axe', 200, 16, ''),
(7, 1, 'Brand', 300, 25, 'X'), (7, 1, 'Brand', 300, 25, ''),
(8, 1, 'Poleaxe', 500, 35, 'X'), (8, 1, 'Poleaxe', 500, 35, ''),
(9, 1, 'Broadsword', 800, 45, 'X'), (9, 1, 'Broadsword', 800, 45, ''),
(10, 1, 'Battle Axe', 1200, 50, 'X'), (10, 1, 'Battle Axe', 1200, 50, ''),
(11, 1, 'Claymore', 2000, 60, 'X'), (11, 1, 'Claymore', 2000, 60, ''),
(12, 1, 'Dark Axe', 3000, 100, 'expbonus,-5'), (12, 1, 'Dark Axe', 3000, 100, 'expbonus,-5'),
(13, 1, 'Dark Sword', 4500, 125, 'expbonus,-10'), (13, 1, 'Dark Sword', 4500, 125, 'expbonus,-10'),
(14, 1, 'Bright Sword', 6000, 100, 'expbonus,10'), (14, 1, 'Bright Sword', 6000, 100, 'expbonus,10'),
(15, 1, 'Magic Sword', 10000, 150, 'maxmp,50'), (15, 1, 'Magic Sword', 10000, 150, 'maxmp,50'),
(16, 1, 'Destiny Blade', 50000, 250, 'strength,50'), (16, 1, 'Destiny Blade', 50000, 250, 'strength,50'),
(17, 2, 'Skivvies', 25, 2, 'goldbonus,10'), (17, 2, 'Skivvies', 25, 2, 'goldbonus,10'),
(18, 2, 'Clothes', 50, 5, 'X'), (18, 2, 'Clothes', 50, 5, ''),
(19, 2, 'Leather Armor', 75, 10, 'X'), (19, 2, 'Leather Armor', 75, 10, ''),
(20, 2, 'Hard Leather Armor', 150, 25, 'X'), (20, 2, 'Hard Leather Armor', 150, 25, ''),
(21, 2, 'Chain Mail', 300, 30, 'X'), (21, 2, 'Chain Mail', 300, 30, ''),
(22, 2, 'Bronze Plate', 900, 50, 'X'), (22, 2, 'Bronze Plate', 900, 50, ''),
(23, 2, 'Iron Plate', 2000, 100, 'X'), (23, 2, 'Iron Plate', 2000, 100, ''),
(24, 2, 'Magic Armor', 4000, 125, 'maxmp,50'), (24, 2, 'Magic Armor', 4000, 125, 'maxmp,50'),
(25, 2, 'Dark Armor', 5000, 150, 'expbonus,-10'), (25, 2, 'Dark Armor', 5000, 150, 'expbonus,-10'),
(26, 2, 'Bright Armor', 10000, 175, 'expbonus,10'), (26, 2, 'Bright Armor', 10000, 175, 'expbonus,10'),
(27, 2, 'Destiny Raiment', 50000, 200, 'dexterity,50'), (27, 2, 'Destiny Raiment', 50000, 200, 'dexterity,50'),
(28, 3, 'Reed Shield', 50, 2, 'X'), (28, 3, 'Reed Shield', 50, 2, ''),
(29, 3, 'Buckler', 100, 4, 'X'), (29, 3, 'Buckler', 100, 4, ''),
(30, 3, 'Small Shield', 500, 10, 'X'), (30, 3, 'Small Shield', 500, 10, ''),
(31, 3, 'Large Shield', 2500, 30, 'X'), (31, 3, 'Large Shield', 2500, 30, ''),
(32, 3, 'Silver Shield', 10000, 60, 'X'), (32, 3, 'Silver Shield', 10000, 60, ''),
(33, 3, 'Destiny Aegis', 25000, 100, 'maxhp,50')` (33, 3, 'Destiny Aegis', 25000, 100, 'maxhp,50')`
if err := sqlitex.ExecuteTransient(conn, itemsSQL, nil); err != nil { if err := db.Exec(itemsSQL); err != nil {
return fmt.Errorf("failed to populate items table: %w", err) return fmt.Errorf("failed to populate items table: %w", err)
} }
fmt.Println("✓ Items table populated") fmt.Println("✓ Items table populated")
@ -432,15 +428,13 @@ func populateData(conn *sqlite.Conn) error {
(149, 'Titan', 360, 340, 270, 50, 2400, 800, 0), (149, 'Titan', 360, 340, 270, 50, 2400, 800, 0),
(150, 'Black Daemon', 400, 400, 280, 50, 3000, 1000, 1), (150, 'Black Daemon', 400, 400, 280, 50, 3000, 1000, 1),
(151, 'Lucifuge', 600, 600, 400, 50, 10000, 10000, 2)` (151, 'Lucifuge', 600, 600, 400, 50, 10000, 10000, 2)`
if err := sqlitex.ExecuteTransient(conn, monstersSQL, nil); err != nil { if err := db.Exec(monstersSQL); err != nil {
return fmt.Errorf("failed to populate monsters table: %w", err) return fmt.Errorf("failed to populate monsters table: %w", err)
} }
fmt.Println("✓ Monsters table populated (sample data)") fmt.Println("✓ Monsters table populated (sample data)")
// News table // News table
if err := sqlitex.ExecuteTransient(conn, if err := db.Exec("INSERT INTO news (author, content) VALUES (1, 'Welcome to Dragon Knight! This is your first news post.')"); err != nil {
"INSERT INTO news (author, content) VALUES (1, 'Welcome to Dragon Knight! This is your first news post.')",
nil); err != nil {
return fmt.Errorf("failed to populate news table: %w", err) return fmt.Errorf("failed to populate news table: %w", err)
} }
fmt.Println("✓ News table populated") fmt.Println("✓ News table populated")
@ -466,7 +460,7 @@ func populateData(conn *sqlite.Conn) error {
(17, 'Ward', 10, 10, 5), (17, 'Ward', 10, 10, 5),
(18, 'Fend', 20, 25, 5), (18, 'Fend', 20, 25, 5),
(19, 'Barrier', 30, 50, 5)` (19, 'Barrier', 30, 50, 5)`
if err := sqlitex.ExecuteTransient(conn, spellsSQL, nil); err != nil { if err := db.Exec(spellsSQL); err != nil {
return fmt.Errorf("failed to populate spells table: %w", err) return fmt.Errorf("failed to populate spells table: %w", err)
} }
fmt.Println("✓ Spells table populated") fmt.Println("✓ Spells table populated")
@ -481,7 +475,7 @@ func populateData(conn *sqlite.Conn) error {
(6, 'Hambry', 170, 170, 90, 1000, 80, '10,11,12,13,14,23,24,30,31'), (6, 'Hambry', 170, 170, 90, 1000, 80, '10,11,12,13,14,23,24,30,31'),
(7, 'Gilead', 200, -200, 100, 3000, 110, '12,13,14,15,24,25,26,32'), (7, 'Gilead', 200, -200, 100, 3000, 110, '12,13,14,15,24,25,26,32'),
(8, 'Endworld', -250, -250, 125, 9000, 160, '16,27,33')` (8, 'Endworld', -250, -250, 125, 9000, 160, '16,27,33')`
if err := sqlitex.ExecuteTransient(conn, townsSQL, nil); err != nil { if err := db.Exec(townsSQL); err != nil {
return fmt.Errorf("failed to populate towns table: %w", err) return fmt.Errorf("failed to populate towns table: %w", err)
} }
fmt.Println("✓ Towns table populated") fmt.Println("✓ Towns table populated")
@ -489,7 +483,7 @@ func populateData(conn *sqlite.Conn) error {
return nil return nil
} }
func createDemoUser(conn *sqlite.Conn) error { func createDemoUser(db *database.DB) error {
// Hash the password using argon2id // Hash the password using argon2id
hashedPassword, err := password.Hash("Demo123!") hashedPassword, err := password.Hash("Demo123!")
if err != nil { if err != nil {
@ -499,9 +493,7 @@ func createDemoUser(conn *sqlite.Conn) error {
stmt := `INSERT INTO users (username, password, email, verified, class_id, auth) stmt := `INSERT INTO users (username, password, email, verified, class_id, auth)
VALUES (?, ?, ?, 1, 1, 1)` VALUES (?, ?, ?, 1, 1, 1)`
if err := sqlitex.ExecuteTransient(conn, stmt, &sqlitex.ExecOptions{ if err := db.Exec(stmt, "demo", hashedPassword, "demo@demo.com"); err != nil {
Args: []any{"demo", hashedPassword, "demo@demo.com"},
}); err != nil {
return fmt.Errorf("failed to create demo user: %w", err) return fmt.Errorf("failed to create demo user: %w", err)
} }

View File

@ -0,0 +1,61 @@
package password
import (
"testing"
)
func TestHashAndVerify(t *testing.T) {
password := "Demo123!"
// Test hashing
hash, err := Hash(password)
if err != nil {
t.Fatalf("Failed to hash password: %v", err)
}
// Test that hash is not empty
if hash == "" {
t.Fatal("Hash should not be empty")
}
// Test that hash contains expected format
if hash[:9] != "$argon2id" {
t.Fatal("Hash should start with $argon2id")
}
// Test verification with correct password
valid, err := Verify(password, hash)
if err != nil {
t.Fatalf("Failed to verify password: %v", err)
}
if !valid {
t.Fatal("Password should be valid")
}
// Test verification with incorrect password
valid, err = Verify("wrongpassword", hash)
if err != nil {
t.Fatalf("Failed to verify wrong password: %v", err)
}
if valid {
t.Fatal("Wrong password should not be valid")
}
// Test that hashing same password twice produces different hashes (due to salt)
hash2, err := Hash(password)
if err != nil {
t.Fatalf("Failed to hash password second time: %v", err)
}
if hash == hash2 {
t.Fatal("Hashing same password twice should produce different hashes due to salt")
}
// But both hashes should verify correctly
valid, err = Verify(password, hash2)
if err != nil {
t.Fatalf("Failed to verify second hash: %v", err)
}
if !valid {
t.Fatal("Second hash should also be valid")
}
}