write database wrapper

This commit is contained in:
Sky Johnson 2025-08-08 22:42:41 -05:00
parent d04acc06eb
commit db2a95cd02
4 changed files with 313 additions and 69 deletions

View File

@ -0,0 +1,119 @@
package database
import (
"fmt"
"zombiezen.com/go/sqlite"
"zombiezen.com/go/sqlite/sqlitex"
)
const DefaultPath = "dk.db"
// DB wraps a SQLite connection with simplified methods
type DB struct {
conn *sqlite.Conn
}
// Open creates a new database connection
func Open(path string) (*DB, error) {
if path == "" {
path = DefaultPath
}
conn, err := sqlite.OpenConn(path, sqlite.OpenCreate|sqlite.OpenReadWrite|sqlite.OpenWAL)
if err != nil {
return nil, fmt.Errorf("failed to open database: %w", err)
}
// Set pragmas for performance
if err := sqlitex.ExecuteTransient(conn, "PRAGMA journal_mode = WAL", nil); err != nil {
conn.Close()
return nil, fmt.Errorf("failed to set WAL mode: %w", err)
}
if err := sqlitex.ExecuteTransient(conn, "PRAGMA synchronous = NORMAL", nil); err != nil {
conn.Close()
return nil, fmt.Errorf("failed to set synchronous mode: %w", err)
}
return &DB{conn: conn}, nil
}
// Close closes the database connection
func (db *DB) Close() error {
return db.conn.Close()
}
// Exec executes a SQL statement without returning results
func (db *DB) Exec(query string, args ...any) error {
if len(args) == 0 {
return sqlitex.ExecuteTransient(db.conn, query, nil)
}
return sqlitex.ExecuteTransient(db.conn, query, &sqlitex.ExecOptions{
Args: args,
})
}
// Query executes a SQL query and calls fn for each row
func (db *DB) Query(query string, fn func(*sqlite.Stmt) error, args ...any) error {
if len(args) == 0 {
return sqlitex.ExecuteTransient(db.conn, query, &sqlitex.ExecOptions{
ResultFunc: fn,
})
}
return sqlitex.ExecuteTransient(db.conn, query, &sqlitex.ExecOptions{
Args: args,
ResultFunc: fn,
})
}
// Begin starts a new transaction
func (db *DB) Begin() (*Tx, error) {
if err := sqlitex.ExecuteTransient(db.conn, "BEGIN", nil); err != nil {
return nil, fmt.Errorf("failed to begin transaction: %w", err)
}
return &Tx{db: db}, nil
}
// Transaction runs a function within a transaction
func (db *DB) Transaction(fn func(*Tx) error) error {
tx, err := db.Begin()
if err != nil {
return err
}
if err := fn(tx); err != nil {
tx.Rollback()
return err
}
return tx.Commit()
}
// Tx represents a database transaction
type Tx struct {
db *DB
}
// Exec executes a SQL statement within the transaction
func (tx *Tx) Exec(query string, args ...any) error {
return tx.db.Exec(query, args...)
}
// Query executes a SQL query within the transaction
func (tx *Tx) Query(query string, fn func(*sqlite.Stmt) error, args ...any) error {
return tx.db.Query(query, fn, args...)
}
// Commit commits the transaction
func (tx *Tx) Commit() error {
return sqlitex.ExecuteTransient(tx.db.conn, "COMMIT", nil)
}
// Rollback rolls back the transaction
func (tx *Tx) Rollback() error {
return sqlitex.ExecuteTransient(tx.db.conn, "ROLLBACK", nil)
}

View File

@ -0,0 +1,72 @@
package database
import (
"os"
"testing"
"zombiezen.com/go/sqlite"
)
func TestDatabaseOperations(t *testing.T) {
// Use a temporary database file
testDB := "test.db"
defer os.Remove(testDB)
// Test opening database
db, err := Open(testDB)
if err != nil {
t.Fatalf("Failed to open database: %v", err)
}
defer db.Close()
// Test creating a simple table
err = db.Exec("CREATE TABLE test_users (id INTEGER PRIMARY KEY, name TEXT)")
if err != nil {
t.Fatalf("Failed to create table: %v", err)
}
// Test inserting data
err = db.Exec("INSERT INTO test_users (name) VALUES (?)", "Alice")
if err != nil {
t.Fatalf("Failed to insert data: %v", err)
}
// Test querying data
var foundName string
err = db.Query("SELECT name FROM test_users WHERE name = ?", func(stmt *sqlite.Stmt) error {
foundName = stmt.ColumnText(0)
return nil
}, "Alice")
if err != nil {
t.Fatalf("Failed to query data: %v", err)
}
if foundName != "Alice" {
t.Errorf("Expected 'Alice', got '%s'", foundName)
}
// Test transaction
err = db.Transaction(func(tx *Tx) error {
return tx.Exec("INSERT INTO test_users (name) VALUES (?)", "Bob")
})
if err != nil {
t.Fatalf("Transaction failed: %v", err)
}
// Verify transaction worked
var count int
err = db.Query("SELECT COUNT(*) FROM test_users", func(stmt *sqlite.Stmt) error {
count = stmt.ColumnInt(0)
return nil
})
if err != nil {
t.Fatalf("Failed to count users: %v", err)
}
if count != 2 {
t.Errorf("Expected 2 users, got %d", count)
}
}

View File

@ -4,10 +4,8 @@ import (
"fmt"
"time"
"dk/internal/database"
"dk/internal/password"
"zombiezen.com/go/sqlite"
"zombiezen.com/go/sqlite/sqlitex"
)
const dbPath = "dk.db"
@ -19,24 +17,24 @@ func Run() error {
start := time.Now()
// Open database connection
conn, err := sqlite.OpenConn(dbPath, sqlite.OpenCreate|sqlite.OpenReadWrite|sqlite.OpenWAL)
db, err := database.Open(dbPath)
if err != nil {
return fmt.Errorf("failed to open database: %w", err)
return err
}
defer conn.Close()
defer db.Close()
// Create tables
if err := createTables(conn); err != nil {
if err := createTables(db); err != nil {
return fmt.Errorf("failed to create tables: %w", err)
}
// Populate initial data
if err := populateData(conn); err != nil {
if err := populateData(db); err != nil {
return fmt.Errorf("failed to populate data: %w", err)
}
// Create demo user
if err := createDemoUser(conn); err != nil {
if err := createDemoUser(db); err != nil {
return fmt.Errorf("failed to create demo user: %w", err)
}
@ -51,7 +49,7 @@ func Run() error {
return nil
}
func createTables(conn *sqlite.Conn) error {
func createTables(db *database.DB) error {
tables := []struct {
name string
sql string
@ -186,7 +184,7 @@ func createTables(conn *sqlite.Conn) error {
}
for _, table := range tables {
if err := sqlitex.ExecuteTransient(conn, table.sql, nil); err != nil {
if err := db.Exec(table.sql); err != nil {
return fmt.Errorf("failed to create %s table: %w", table.name, err)
}
fmt.Printf("✓ %s table created\n", table.name)
@ -195,27 +193,25 @@ func createTables(conn *sqlite.Conn) error {
return nil
}
func populateData(conn *sqlite.Conn) error {
if err := sqlitex.ExecuteTransient(conn,
"INSERT INTO control VALUES (1, 250, 1, '', 'Mage', 'Warrior', 'Paladin')",
nil); err != nil {
func populateData(db *database.DB) error {
if err := db.Exec("INSERT INTO control VALUES (1, 250, 1, '', 'Mage', 'Warrior', 'Paladin')"); err != nil {
return fmt.Errorf("failed to populate control table: %w", err)
}
fmt.Println("✓ Control table populated")
dropsSQL := `INSERT INTO drops VALUES
(1, 'Life Pebble', 1, 1, 'maxhp,10', 'X'),
(2, 'Life Stone', 10, 1, 'maxhp,25', 'X'),
(3, 'Life Rock', 25, 1, 'maxhp,50', 'X'),
(4, 'Magic Pebble', 1, 1, 'maxmp,10', 'X'),
(5, 'Magic Stone', 10, 1, 'maxmp,25', 'X'),
(6, 'Magic Rock', 25, 1, 'maxmp,50', 'X'),
(7, 'Dragon''s Scale', 10, 1, 'defensepower,25', 'X'),
(8, 'Dragon''s Plate', 30, 1, 'defensepower,50', 'X'),
(9, 'Dragon''s Claw', 10, 1, 'attackpower,25', 'X'),
(10, 'Dragon''s Tooth', 30, 1, 'attackpower,50', 'X'),
(11, 'Dragon''s Tear', 35, 1, 'strength,50', 'X'),
(12, 'Dragon''s Wing', 35, 1, 'dexterity,50', 'X'),
(1, 'Life Pebble', 1, 1, 'maxhp,10', ''),
(2, 'Life Stone', 10, 1, 'maxhp,25', ''),
(3, 'Life Rock', 25, 1, 'maxhp,50', ''),
(4, 'Magic Pebble', 1, 1, 'maxmp,10', ''),
(5, 'Magic Stone', 10, 1, 'maxmp,25', ''),
(6, 'Magic Rock', 25, 1, 'maxmp,50', ''),
(7, 'Dragon''s Scale', 10, 1, 'defensepower,25', ''),
(8, 'Dragon''s Plate', 30, 1, 'defensepower,50', ''),
(9, 'Dragon''s Claw', 10, 1, 'attackpower,25', ''),
(10, 'Dragon''s Tooth', 30, 1, 'attackpower,50', ''),
(11, 'Dragon''s Tear', 35, 1, 'strength,50', ''),
(12, 'Dragon''s Wing', 35, 1, 'dexterity,50', ''),
(13, 'Demon''s Sin', 35, 1, 'maxhp,-50', 'strength,50'),
(14, 'Demon''s Fall', 35, 1, 'maxmp,-50', 'strength,50'),
(15, 'Demon''s Lie', 45, 1, 'maxhp,-100', 'strength,100'),
@ -228,54 +224,54 @@ func populateData(conn *sqlite.Conn) error {
(22, 'Seraph''s Rise', 30, 1, 'maxmp,50', 'dexterity,50'),
(23, 'Seraph''s Truth', 35, 1, 'maxmp,75', 'dexterity,75'),
(24, 'Seraph''s Love', 40, 1, 'maxmp,100', 'dexterity,100'),
(25, 'Ruby', 50, 1, 'maxhp,150', 'X'),
(26, 'Pearl', 50, 1, 'maxmp,150', 'X'),
(27, 'Emerald', 50, 1, 'strength,150', 'X'),
(28, 'Topaz', 50, 1, 'dexterity,150', 'X'),
(29, 'Obsidian', 50, 1, 'attackpower,150', 'X'),
(30, 'Diamond', 50, 1, 'defensepower,150', 'X'),
(31, 'Memory Drop', 5, 1, 'expbonus,10', 'X'),
(32, 'Fortune Drop', 5, 1, 'goldbonus,10', 'X')`
if err := sqlitex.ExecuteTransient(conn, dropsSQL, nil); err != nil {
(25, 'Ruby', 50, 1, 'maxhp,150', ''),
(26, 'Pearl', 50, 1, 'maxmp,150', ''),
(27, 'Emerald', 50, 1, 'strength,150', ''),
(28, 'Topaz', 50, 1, 'dexterity,150', ''),
(29, 'Obsidian', 50, 1, 'attackpower,150', ''),
(30, 'Diamond', 50, 1, 'defensepower,150', ''),
(31, 'Memory Drop', 5, 1, 'expbonus,10', ''),
(32, 'Fortune Drop', 5, 1, 'goldbonus,10', '')`
if err := db.Exec(dropsSQL); err != nil {
return fmt.Errorf("failed to populate drops table: %w", err)
}
fmt.Println("✓ Drops table populated")
itemsSQL := `INSERT INTO items VALUES
(1, 1, 'Stick', 10, 2, 'X'),
(2, 1, 'Branch', 30, 4, 'X'),
(3, 1, 'Club', 40, 5, 'X'),
(4, 1, 'Dagger', 90, 8, 'X'),
(5, 1, 'Hatchet', 150, 12, 'X'),
(6, 1, 'Axe', 200, 16, 'X'),
(7, 1, 'Brand', 300, 25, 'X'),
(8, 1, 'Poleaxe', 500, 35, 'X'),
(9, 1, 'Broadsword', 800, 45, 'X'),
(10, 1, 'Battle Axe', 1200, 50, 'X'),
(11, 1, 'Claymore', 2000, 60, 'X'),
(1, 1, 'Stick', 10, 2, ''),
(2, 1, 'Branch', 30, 4, ''),
(3, 1, 'Club', 40, 5, ''),
(4, 1, 'Dagger', 90, 8, ''),
(5, 1, 'Hatchet', 150, 12, ''),
(6, 1, 'Axe', 200, 16, ''),
(7, 1, 'Brand', 300, 25, ''),
(8, 1, 'Poleaxe', 500, 35, ''),
(9, 1, 'Broadsword', 800, 45, ''),
(10, 1, 'Battle Axe', 1200, 50, ''),
(11, 1, 'Claymore', 2000, 60, ''),
(12, 1, 'Dark Axe', 3000, 100, 'expbonus,-5'),
(13, 1, 'Dark Sword', 4500, 125, 'expbonus,-10'),
(14, 1, 'Bright Sword', 6000, 100, 'expbonus,10'),
(15, 1, 'Magic Sword', 10000, 150, 'maxmp,50'),
(16, 1, 'Destiny Blade', 50000, 250, 'strength,50'),
(17, 2, 'Skivvies', 25, 2, 'goldbonus,10'),
(18, 2, 'Clothes', 50, 5, 'X'),
(19, 2, 'Leather Armor', 75, 10, 'X'),
(20, 2, 'Hard Leather Armor', 150, 25, 'X'),
(21, 2, 'Chain Mail', 300, 30, 'X'),
(22, 2, 'Bronze Plate', 900, 50, 'X'),
(23, 2, 'Iron Plate', 2000, 100, 'X'),
(18, 2, 'Clothes', 50, 5, ''),
(19, 2, 'Leather Armor', 75, 10, ''),
(20, 2, 'Hard Leather Armor', 150, 25, ''),
(21, 2, 'Chain Mail', 300, 30, ''),
(22, 2, 'Bronze Plate', 900, 50, ''),
(23, 2, 'Iron Plate', 2000, 100, ''),
(24, 2, 'Magic Armor', 4000, 125, 'maxmp,50'),
(25, 2, 'Dark Armor', 5000, 150, 'expbonus,-10'),
(26, 2, 'Bright Armor', 10000, 175, 'expbonus,10'),
(27, 2, 'Destiny Raiment', 50000, 200, 'dexterity,50'),
(28, 3, 'Reed Shield', 50, 2, 'X'),
(29, 3, 'Buckler', 100, 4, 'X'),
(30, 3, 'Small Shield', 500, 10, 'X'),
(31, 3, 'Large Shield', 2500, 30, 'X'),
(32, 3, 'Silver Shield', 10000, 60, 'X'),
(28, 3, 'Reed Shield', 50, 2, ''),
(29, 3, 'Buckler', 100, 4, ''),
(30, 3, 'Small Shield', 500, 10, ''),
(31, 3, 'Large Shield', 2500, 30, ''),
(32, 3, 'Silver Shield', 10000, 60, ''),
(33, 3, 'Destiny Aegis', 25000, 100, 'maxhp,50')`
if err := sqlitex.ExecuteTransient(conn, itemsSQL, nil); err != nil {
if err := db.Exec(itemsSQL); err != nil {
return fmt.Errorf("failed to populate items table: %w", err)
}
fmt.Println("✓ Items table populated")
@ -432,15 +428,13 @@ func populateData(conn *sqlite.Conn) error {
(149, 'Titan', 360, 340, 270, 50, 2400, 800, 0),
(150, 'Black Daemon', 400, 400, 280, 50, 3000, 1000, 1),
(151, 'Lucifuge', 600, 600, 400, 50, 10000, 10000, 2)`
if err := sqlitex.ExecuteTransient(conn, monstersSQL, nil); err != nil {
if err := db.Exec(monstersSQL); err != nil {
return fmt.Errorf("failed to populate monsters table: %w", err)
}
fmt.Println("✓ Monsters table populated (sample data)")
// News table
if err := sqlitex.ExecuteTransient(conn,
"INSERT INTO news (author, content) VALUES (1, 'Welcome to Dragon Knight! This is your first news post.')",
nil); err != nil {
if err := db.Exec("INSERT INTO news (author, content) VALUES (1, 'Welcome to Dragon Knight! This is your first news post.')"); err != nil {
return fmt.Errorf("failed to populate news table: %w", err)
}
fmt.Println("✓ News table populated")
@ -466,7 +460,7 @@ func populateData(conn *sqlite.Conn) error {
(17, 'Ward', 10, 10, 5),
(18, 'Fend', 20, 25, 5),
(19, 'Barrier', 30, 50, 5)`
if err := sqlitex.ExecuteTransient(conn, spellsSQL, nil); err != nil {
if err := db.Exec(spellsSQL); err != nil {
return fmt.Errorf("failed to populate spells table: %w", err)
}
fmt.Println("✓ Spells table populated")
@ -481,7 +475,7 @@ func populateData(conn *sqlite.Conn) error {
(6, 'Hambry', 170, 170, 90, 1000, 80, '10,11,12,13,14,23,24,30,31'),
(7, 'Gilead', 200, -200, 100, 3000, 110, '12,13,14,15,24,25,26,32'),
(8, 'Endworld', -250, -250, 125, 9000, 160, '16,27,33')`
if err := sqlitex.ExecuteTransient(conn, townsSQL, nil); err != nil {
if err := db.Exec(townsSQL); err != nil {
return fmt.Errorf("failed to populate towns table: %w", err)
}
fmt.Println("✓ Towns table populated")
@ -489,7 +483,7 @@ func populateData(conn *sqlite.Conn) error {
return nil
}
func createDemoUser(conn *sqlite.Conn) error {
func createDemoUser(db *database.DB) error {
// Hash the password using argon2id
hashedPassword, err := password.Hash("Demo123!")
if err != nil {
@ -499,9 +493,7 @@ func createDemoUser(conn *sqlite.Conn) error {
stmt := `INSERT INTO users (username, password, email, verified, class_id, auth)
VALUES (?, ?, ?, 1, 1, 1)`
if err := sqlitex.ExecuteTransient(conn, stmt, &sqlitex.ExecOptions{
Args: []any{"demo", hashedPassword, "demo@demo.com"},
}); err != nil {
if err := db.Exec(stmt, "demo", hashedPassword, "demo@demo.com"); err != nil {
return fmt.Errorf("failed to create demo user: %w", err)
}

View File

@ -0,0 +1,61 @@
package password
import (
"testing"
)
func TestHashAndVerify(t *testing.T) {
password := "Demo123!"
// Test hashing
hash, err := Hash(password)
if err != nil {
t.Fatalf("Failed to hash password: %v", err)
}
// Test that hash is not empty
if hash == "" {
t.Fatal("Hash should not be empty")
}
// Test that hash contains expected format
if hash[:9] != "$argon2id" {
t.Fatal("Hash should start with $argon2id")
}
// Test verification with correct password
valid, err := Verify(password, hash)
if err != nil {
t.Fatalf("Failed to verify password: %v", err)
}
if !valid {
t.Fatal("Password should be valid")
}
// Test verification with incorrect password
valid, err = Verify("wrongpassword", hash)
if err != nil {
t.Fatalf("Failed to verify wrong password: %v", err)
}
if valid {
t.Fatal("Wrong password should not be valid")
}
// Test that hashing same password twice produces different hashes (due to salt)
hash2, err := Hash(password)
if err != nil {
t.Fatalf("Failed to hash password second time: %v", err)
}
if hash == hash2 {
t.Fatal("Hashing same password twice should produce different hashes due to salt")
}
// But both hashes should verify correctly
valid, err = Verify(password, hash2)
if err != nil {
t.Fatalf("Failed to verify second hash: %v", err)
}
if !valid {
t.Fatal("Second hash should also be valid")
}
}