From 9c04d9a67e09327af2abcdbd17254ba4f6729fe1 Mon Sep 17 00:00:00 2001 From: Sky Johnson Date: Thu, 7 Aug 2025 20:08:24 -0500 Subject: [PATCH] fix database - zombiezen only for sqlite --- go.mod | 10 +- go.sum | 2 - internal/database/database.go | 284 ++++++++++++++++++++--------- internal/database/database_test.go | 23 ++- 4 files changed, 208 insertions(+), 111 deletions(-) diff --git a/go.mod b/go.mod index 6a8d16d..cd1a840 100644 --- a/go.mod +++ b/go.mod @@ -4,14 +4,12 @@ go 1.24.5 require zombiezen.com/go/sqlite v1.4.2 -require ( - filippo.io/edwards25519 v1.1.0 // indirect - golang.org/x/text v0.27.0 // indirect -) +require golang.org/x/text v0.27.0 // indirect require ( + filippo.io/edwards25519 v1.1.0 // indirect github.com/dustin/go-humanize v1.0.1 // indirect - github.com/go-sql-driver/mysql v1.9.3 + github.com/go-sql-driver/mysql v1.9.3 // indirect github.com/google/uuid v1.6.0 // indirect github.com/mattn/go-isatty v0.0.20 // indirect github.com/ncruces/go-strftime v0.1.9 // indirect @@ -21,5 +19,5 @@ require ( modernc.org/libc v1.65.7 // indirect modernc.org/mathutil v1.7.1 // indirect modernc.org/memory v1.11.0 // indirect - modernc.org/sqlite v1.37.1 + modernc.org/sqlite v1.37.1 // indirect ) diff --git a/go.sum b/go.sum index 601e680..9e93d69 100644 --- a/go.sum +++ b/go.sum @@ -6,8 +6,6 @@ github.com/go-sql-driver/mysql v1.9.3 h1:U/N249h2WzJ3Ukj8SowVFjdtZKfu9vlLZxjPXV1 github.com/go-sql-driver/mysql v1.9.3/go.mod h1:qn46aNg1333BRMNU69Lq93t8du/dwxI64Gl8i5p1WMU= github.com/google/go-cmp v0.6.0 h1:ofyhxvXcZhMsU5ulbFiLKl/XBFqE1GSq7atu8tAmTRI= github.com/google/go-cmp v0.6.0/go.mod h1:17dUlkBOakJ0+DkrSSNjCkIjxS6bF9zb3elmeNGIjoY= -github.com/google/pprof v0.0.0-20250317173921-a4b03ec1a45e h1:ijClszYn+mADRFY17kjQEVQ1XRhq2/JR1M3sGqeJoxs= -github.com/google/pprof v0.0.0-20250317173921-a4b03ec1a45e/go.mod h1:boTsfXsheKC2y+lKOCMpSfarhxDeIzfZG1jqGcPl3cA= github.com/google/uuid v1.6.0 h1:NIvaJDMOsjHA8n1jAhLSgzrAzy1Hgr+hNrb57e+94F0= github.com/google/uuid v1.6.0/go.mod h1:TIyPZe4MgqvfeYDBFedMoGGpEw/LqOeaOT+nhxU+yHo= github.com/mattn/go-isatty v0.0.20 h1:xfD0iDuEKnDkl03q4limB+vH+GxLEtL/jb4xVJSWWEY= diff --git a/internal/database/database.go b/internal/database/database.go index 3e7e611..6d82535 100644 --- a/internal/database/database.go +++ b/internal/database/database.go @@ -1,12 +1,13 @@ package database import ( + "context" "database/sql" "fmt" "sync" _ "github.com/go-sql-driver/mysql" - _ "modernc.org/sqlite" + "zombiezen.com/go/sqlite" "zombiezen.com/go/sqlite/sqlitex" ) @@ -25,55 +26,54 @@ type Config struct { PoolSize int // Connection pool size } -// Database wraps the SQL database connection +// Database wraps database connections for both SQLite (zombiezen) and MySQL type Database struct { - db *sql.DB - pool *sqlitex.Pool // For achievements system compatibility (SQLite only) + db *sql.DB // For MySQL + pool *sqlitex.Pool // For SQLite (zombiezen) config Config mutex sync.RWMutex } // New creates a new database connection with the provided configuration func New(config Config) (*Database, error) { - var driverName string - var pool *sqlitex.Pool - // Set default pool size if config.PoolSize == 0 { config.PoolSize = 25 } + var db *sql.DB + var pool *sqlitex.Pool + switch config.Type { case SQLite: - driverName = "sqlite" - // Create sqlitex pool for achievements system compatibility + // Use zombiezen sqlite pool var err error pool, err = sqlitex.NewPool(config.DSN, sqlitex.PoolOptions{ - PoolSize: 5, + PoolSize: config.PoolSize, }) if err != nil { return nil, fmt.Errorf("failed to create sqlite pool: %w", err) } case MySQL: - driverName = "mysql" + // Use standard database/sql for MySQL + var err error + db, err = sql.Open("mysql", config.DSN) + if err != nil { + return nil, fmt.Errorf("failed to open mysql database: %w", err) + } + + // Test connection + if err := db.Ping(); err != nil { + return nil, fmt.Errorf("failed to ping mysql database: %w", err) + } + + // Set connection pool settings + db.SetMaxOpenConns(config.PoolSize) + db.SetMaxIdleConns(config.PoolSize / 5) default: return nil, fmt.Errorf("unsupported database type: %d", config.Type) } - db, err := sql.Open(driverName, config.DSN) - if err != nil { - return nil, fmt.Errorf("failed to open database: %w", err) - } - - // Test connection - if err := db.Ping(); err != nil { - return nil, fmt.Errorf("failed to ping database: %w", err) - } - - // Set connection pool settings - db.SetMaxOpenConns(config.PoolSize) - db.SetMaxIdleConns(config.PoolSize / 5) - d := &Database{ db: db, pool: pool, @@ -88,7 +88,10 @@ func (d *Database) Close() error { if d.pool != nil { d.pool.Close() } - return d.db.Close() + if d.db != nil { + return d.db.Close() + } + return nil } // GetType returns the database type @@ -96,63 +99,133 @@ func (d *Database) GetType() DatabaseType { return d.config.Type } -// GetPool returns the sqlitex pool for achievements system compatibility +// GetPool returns the sqlitex pool func (d *Database) GetPool() *sqlitex.Pool { return d.pool } -// Query executes a query that returns rows +// Query executes a query that returns rows (database/sql compatibility) func (d *Database) Query(query string, args ...any) (*sql.Rows, error) { - return d.db.Query(query, args...) + if d.config.Type == MySQL { + return d.db.Query(query, args...) + } + return nil, fmt.Errorf("Query method only supported for MySQL; use ExecTransient for SQLite") } -// QueryRow executes a query that returns a single row +// QueryRow executes a query that returns a single row (database/sql compatibility) func (d *Database) QueryRow(query string, args ...any) *sql.Row { - return d.db.QueryRow(query, args...) + if d.config.Type == MySQL { + return d.db.QueryRow(query, args...) + } + return nil // This will result in an error when scanned } -// Exec executes a query that doesn't return rows +// Exec executes a query that doesn't return rows (database/sql compatibility) func (d *Database) Exec(query string, args ...any) (sql.Result, error) { - return d.db.Exec(query, args...) + if d.config.Type == MySQL { + return d.db.Exec(query, args...) + } + return nil, fmt.Errorf("Exec method only supported for MySQL; use Execute for SQLite") } -// Begin starts a transaction +// Begin starts a transaction (database/sql compatibility) func (d *Database) Begin() (*sql.Tx, error) { - return d.db.Begin() + if d.config.Type == MySQL { + return d.db.Begin() + } + return nil, fmt.Errorf("Begin method only supported for MySQL; use zombiezen transaction helpers for SQLite") +} + +// Execute executes a query using the zombiezen sqlite approach (SQLite only) +func (d *Database) Execute(query string, opts *sqlitex.ExecOptions) error { + if d.config.Type != SQLite { + return fmt.Errorf("Execute method only supported for SQLite") + } + + conn, err := d.pool.Take(context.Background()) + if err != nil { + return err + } + defer d.pool.Put(conn) + + return sqlitex.Execute(conn, query, opts) +} + +// ExecTransient executes a transient query and calls resultFn for each row (SQLite only) +func (d *Database) ExecTransient(query string, resultFn func(stmt *sqlite.Stmt) error, args ...any) error { + if d.config.Type != SQLite { + return fmt.Errorf("ExecTransient method only supported for SQLite") + } + + conn, err := d.pool.Take(context.Background()) + if err != nil { + return err + } + defer d.pool.Put(conn) + + return sqlitex.ExecTransient(conn, query, resultFn, args...) } // LoadRules loads all rules from the database func (d *Database) LoadRules() (map[string]map[string]string, error) { - rows, err := d.Query("SELECT category, name, value FROM rules") - if err != nil { - return nil, err - } - defer rows.Close() - rules := make(map[string]map[string]string) - - for rows.Next() { - var category, name, value string - if err := rows.Scan(&category, &name, &value); err != nil { + + if d.config.Type == SQLite { + err := d.ExecTransient("SELECT category, name, value FROM rules", func(stmt *sqlite.Stmt) error { + category := stmt.ColumnText(0) + name := stmt.ColumnText(1) + value := stmt.ColumnText(2) + + if rules[category] == nil { + rules[category] = make(map[string]string) + } + rules[category][name] = value + + return nil + }) + return rules, err + } else { + // MySQL using database/sql + rows, err := d.Query("SELECT category, name, value FROM rules") + if err != nil { return nil, err } + defer rows.Close() - if rules[category] == nil { - rules[category] = make(map[string]string) + for rows.Next() { + var category, name, value string + if err := rows.Scan(&category, &name, &value); err != nil { + return nil, err + } + + if rules[category] == nil { + rules[category] = make(map[string]string) + } + rules[category][name] = value } - rules[category][name] = value - } - return rules, rows.Err() + return rules, rows.Err() + } } // SaveRule saves a rule to the database func (d *Database) SaveRule(category, name, value, description string) error { - _, err := d.Exec(` - INSERT OR REPLACE INTO rules (category, name, value, description) - VALUES (?, ?, ?, ?) - `, category, name, value, description) - return err + if d.config.Type == SQLite { + return d.Execute(` + INSERT OR REPLACE INTO rules (category, name, value, description) + VALUES (?, ?, ?, ?) + `, &sqlitex.ExecOptions{ + Args: []any{category, name, value, description}, + }) + } else { + // MySQL using database/sql + _, err := d.Exec(` + INSERT INTO rules (category, name, value, description) + VALUES (?, ?, ?, ?) + ON DUPLICATE KEY UPDATE value = VALUES(value), description = VALUES(description) + `, category, name, value, description) + return err + } } // NewSQLite creates a new SQLite database connection @@ -174,53 +247,82 @@ func NewMySQL(dsn string) (*Database, error) { // GetZones retrieves all zones from the database func (d *Database) GetZones() ([]map[string]any, error) { - rows, err := d.Query(` + var zones []map[string]any + + query := ` SELECT id, name, file, description, motd, min_level, max_level, min_version, xp_modifier, city_zone, weather_allowed, safe_x, safe_y, safe_z, safe_heading FROM zones ORDER BY name - `) - if err != nil { - return nil, err - } - defer rows.Close() + ` - var zones []map[string]any + if d.config.Type == SQLite { + err := d.ExecTransient(query, func(stmt *sqlite.Stmt) error { + zone := make(map[string]any) + + zone["id"] = stmt.ColumnInt(0) + zone["name"] = stmt.ColumnText(1) + zone["file"] = stmt.ColumnText(2) + zone["description"] = stmt.ColumnText(3) + zone["motd"] = stmt.ColumnText(4) + zone["min_level"] = stmt.ColumnInt(5) + zone["max_level"] = stmt.ColumnInt(6) + zone["min_version"] = stmt.ColumnInt(7) + zone["xp_modifier"] = stmt.ColumnFloat(8) + zone["city_zone"] = stmt.ColumnBool(9) + zone["weather_allowed"] = stmt.ColumnBool(10) + zone["safe_x"] = stmt.ColumnFloat(11) + zone["safe_y"] = stmt.ColumnFloat(12) + zone["safe_z"] = stmt.ColumnFloat(13) + zone["safe_heading"] = stmt.ColumnFloat(14) - for rows.Next() { - zone := make(map[string]any) - var id, minLevel, maxLevel, minVersion int - var name, file, description, motd string - var xpModifier, safeX, safeY, safeZ, safeHeading float64 - var cityZone, weatherAllowed bool - - err := rows.Scan(&id, &name, &file, &description, &motd, - &minLevel, &maxLevel, &minVersion, &xpModifier, - &cityZone, &weatherAllowed, - &safeX, &safeY, &safeZ, &safeHeading) + zones = append(zones, zone) + return nil + }) + return zones, err + } else { + // MySQL using database/sql + rows, err := d.Query(query) if err != nil { return nil, err } + defer rows.Close() - zone["id"] = id - zone["name"] = name - zone["file"] = file - zone["description"] = description - zone["motd"] = motd - zone["min_level"] = minLevel - zone["max_level"] = maxLevel - zone["min_version"] = minVersion - zone["xp_modifier"] = xpModifier - zone["city_zone"] = cityZone - zone["weather_allowed"] = weatherAllowed - zone["safe_x"] = safeX - zone["safe_y"] = safeY - zone["safe_z"] = safeZ - zone["safe_heading"] = safeHeading + for rows.Next() { + zone := make(map[string]any) + var id, minLevel, maxLevel, minVersion int + var name, file, description, motd string + var xpModifier, safeX, safeY, safeZ, safeHeading float64 + var cityZone, weatherAllowed bool - zones = append(zones, zone) + err := rows.Scan(&id, &name, &file, &description, &motd, + &minLevel, &maxLevel, &minVersion, &xpModifier, + &cityZone, &weatherAllowed, + &safeX, &safeY, &safeZ, &safeHeading) + if err != nil { + return nil, err + } + + zone["id"] = id + zone["name"] = name + zone["file"] = file + zone["description"] = description + zone["motd"] = motd + zone["min_level"] = minLevel + zone["max_level"] = maxLevel + zone["min_version"] = minVersion + zone["xp_modifier"] = xpModifier + zone["city_zone"] = cityZone + zone["weather_allowed"] = weatherAllowed + zone["safe_x"] = safeX + zone["safe_y"] = safeY + zone["safe_z"] = safeZ + zone["safe_heading"] = safeHeading + + zones = append(zones, zone) + } + + return zones, rows.Err() } - - return zones, rows.Err() } diff --git a/internal/database/database_test.go b/internal/database/database_test.go index cbfe129..950afbf 100644 --- a/internal/database/database_test.go +++ b/internal/database/database_test.go @@ -2,6 +2,9 @@ package database import ( "testing" + + "zombiezen.com/go/sqlite" + "zombiezen.com/go/sqlite/sqlitex" ) func TestNewSQLite(t *testing.T) { @@ -18,29 +21,25 @@ func TestNewSQLite(t *testing.T) { } // Test basic query - result, err := db.Exec("CREATE TABLE test (id INTEGER PRIMARY KEY, name TEXT)") + err = db.Execute("CREATE TABLE test (id INTEGER PRIMARY KEY, name TEXT)", nil) if err != nil { t.Fatalf("Failed to create test table: %v", err) } - affected, err := result.RowsAffected() - if err != nil { - t.Fatalf("Failed to get rows affected: %v", err) - } - - if affected != 0 { - t.Errorf("Expected 0 rows affected for CREATE TABLE, got %d", affected) - } - // Test insert - _, err = db.Exec("INSERT INTO test (name) VALUES (?)", "test_value") + err = db.Execute("INSERT INTO test (name) VALUES (?)", &sqlitex.ExecOptions{ + Args: []any{"test_value"}, + }) if err != nil { t.Fatalf("Failed to insert test data: %v", err) } // Test query var name string - err = db.QueryRow("SELECT name FROM test WHERE id = 1").Scan(&name) + err = db.ExecTransient("SELECT name FROM test WHERE id = 1", func(stmt *sqlite.Stmt) error { + name = stmt.ColumnText(0) + return nil + }) if err != nil { t.Fatalf("Failed to query test data: %v", err) }