From fc82f97cb6bdb315920e3511bfe12e7c978e7f3f Mon Sep 17 00:00:00 2001 From: Sky Johnson Date: Wed, 30 Jul 2025 09:29:20 -0500 Subject: [PATCH] create thin database wrapper --- cmd/login_server/database.go | 245 +++++++++++--------------------- internal/database/wrapper.go | 262 +++++++++++++++++++++++++++++++++++ 2 files changed, 342 insertions(+), 165 deletions(-) create mode 100644 internal/database/wrapper.go diff --git a/cmd/login_server/database.go b/cmd/login_server/database.go index d9f21a3..0fabbcf 100644 --- a/cmd/login_server/database.go +++ b/cmd/login_server/database.go @@ -1,18 +1,17 @@ package main import ( + "eq2emu/internal/database" "fmt" "log" "time" "golang.org/x/crypto/bcrypt" - "zombiezen.com/go/sqlite" - "zombiezen.com/go/sqlite/sqlitex" ) // Database handles all database operations for the login server type Database struct { - conn *sqlite.Conn + db *database.DB } // DatabaseConfig holds database connection settings @@ -24,32 +23,26 @@ type DatabaseConfig struct { // NewDatabase creates a new database connection func NewDatabase(config DatabaseConfig) (*Database, error) { - // Open SQLite database - db, err := sqlite.OpenConn(config.FilePath, sqlite.OpenReadWrite|sqlite.OpenCreate) + db, err := database.Open(config.FilePath) if err != nil { return nil, fmt.Errorf("failed to open database: %w", err) } - // Enable foreign keys - if err := sqlitex.ExecuteTransient(db, "PRAGMA foreign_keys = ON", nil); err != nil { - return nil, fmt.Errorf("failed to enable foreign keys: %w", err) - } - - // Set busy timeout + // Set busy timeout if specified if config.BusyTimeout > 0 { query := fmt.Sprintf("PRAGMA busy_timeout = %d", config.BusyTimeout) - if err := sqlitex.ExecuteTransient(db, query, nil); err != nil { + if err := db.Exec(query); err != nil { return nil, fmt.Errorf("failed to set busy timeout: %w", err) } } log.Println("SQLite database connection established") - return &Database{conn: db}, nil + return &Database{db: db}, nil } // Close closes the database connection func (d *Database) Close() error { - return d.conn.Close() + return d.db.Close() } // AuthenticateAccount verifies user credentials and returns account info @@ -60,49 +53,37 @@ func (d *Database) AuthenticateAccount(username, password string) (*Account, err FROM accounts WHERE username = ? AND active = 1` - stmt, err := d.conn.Prepare(query) - if err != nil { - return nil, fmt.Errorf("prepare statement failed: %w", err) - } - defer stmt.Finalize() - - stmt.BindText(1, username) - - hasRow, err := stmt.Step() + row, err := d.db.QueryRow(query, username) if err != nil { return nil, fmt.Errorf("query failed: %w", err) } - if !hasRow { + if row == nil { return nil, nil // Account not found } + defer row.Close() var account Account - var passwordHash string - var lastLogin string - var clientVersion int64 - - account.ID = int32(stmt.ColumnInt64(0)) - account.Username = stmt.ColumnText(1) - passwordHash = stmt.ColumnText(2) - account.LSAdmin = stmt.ColumnInt(3) != 0 - account.WorldAdmin = stmt.ColumnInt(4) != 0 + account.ID = int32(row.Int64(0)) + account.Username = row.Text(1) + passwordHash := row.Text(2) + account.LSAdmin = row.Bool(3) + account.WorldAdmin = row.Bool(4) // Skip created_date at index 5 - not needed for authentication - lastLogin = stmt.ColumnText(6) - clientVersion = stmt.ColumnInt64(7) + lastLogin := row.Text(6) + account.ClientVersion = uint16(row.Int64(7)) // Verify password if err := bcrypt.CompareHashAndPassword([]byte(passwordHash), []byte(password)); err != nil { return nil, nil // Invalid password } - // Parse timestamps + // Parse timestamp if lastLogin != "" { if t, err := time.Parse("2006-01-02 15:04:05", lastLogin); err == nil { account.LastLogin = t } } - account.ClientVersion = uint16(clientVersion) return &account, nil } @@ -113,14 +94,12 @@ func (d *Database) UpdateAccountLogin(account *Account) error { SET last_login = ?, last_ip = ?, client_version = ? WHERE id = ?` - return sqlitex.Execute(d.conn, query, &sqlitex.ExecOptions{ - Args: []any{ - account.LastLogin.Format("2006-01-02 15:04:05"), - account.IPAddress, - account.ClientVersion, - account.ID, - }, - }) + return d.db.Exec(query, + account.LastLogin.Format("2006-01-02 15:04:05"), + account.IPAddress, + account.ClientVersion, + account.ID, + ) } // LoadCharacters loads all characters for an account @@ -132,45 +111,29 @@ func (d *Database) LoadCharacters(accountID int32, version uint16) ([]*Character WHERE account_id = ? ORDER BY created_date ASC` - stmt, err := d.conn.Prepare(query) - if err != nil { - return nil, fmt.Errorf("prepare failed: %w", err) - } - defer stmt.Finalize() - - stmt.BindInt64(1, int64(accountID)) - var characters []*Character - for { - hasRow, err := stmt.Step() - if err != nil { - return nil, fmt.Errorf("query failed: %w", err) - } - if !hasRow { - break - } - + err := d.db.Query(query, func(row *database.Row) error { char := &Character{AccountID: accountID} + char.ID = int32(row.Int64(0)) + char.ServerID = int32(row.Int64(1)) + char.Name = row.Text(2) + char.Level = int8(row.Int(3)) + char.Race = int8(row.Int(4)) + char.Gender = int8(row.Int(5)) + char.Class = int8(row.Int(6)) - char.ID = int32(stmt.ColumnInt64(0)) - char.ServerID = int32(stmt.ColumnInt64(1)) - char.Name = stmt.ColumnText(2) - char.Level = int8(stmt.ColumnInt(3)) - char.Race = int8(stmt.ColumnInt(4)) - char.Gender = int8(stmt.ColumnInt(5)) - char.Class = int8(stmt.ColumnInt(6)) - - if dateStr := stmt.ColumnText(7); dateStr != "" { + if dateStr := row.Text(7); dateStr != "" { if t, err := time.Parse("2006-01-02 15:04:05", dateStr); err == nil { char.CreatedDate = t } } - char.Deleted = stmt.ColumnInt(8) != 0 + char.Deleted = row.Bool(8) characters = append(characters, char) - } + return nil + }, accountID) - return characters, nil + return characters, err } // CharacterNameExists checks if a character name is already taken @@ -180,24 +143,16 @@ func (d *Database) CharacterNameExists(name string, serverID int32) (bool, error FROM characters WHERE name = ? AND server_id = ? AND deleted = 0` - stmt, err := d.conn.Prepare(query) + row, err := d.db.QueryRow(query, name, serverID) if err != nil { return false, err } - defer stmt.Finalize() - - stmt.BindText(1, name) - stmt.BindInt64(2, int64(serverID)) - - hasRow, err := stmt.Step() - if err != nil { - return false, err - } - if hasRow { - return stmt.ColumnInt(0) > 0, nil + if row == nil { + return false, nil } + defer row.Close() - return false, nil + return row.Int(0) > 0, nil } // CreateCharacter creates a new character in the database @@ -207,18 +162,16 @@ func (d *Database) CreateCharacter(char *Character) (int32, error) { gender, class, created_date) VALUES (?, ?, ?, ?, ?, ?, ?, ?)` - err := sqlitex.Execute(d.conn, query, &sqlitex.ExecOptions{ - Args: []any{ - char.AccountID, char.ServerID, char.Name, char.Level, - char.Race, char.Gender, char.Class, - char.CreatedDate.Format("2006-01-02 15:04:05"), - }, - }) + err := d.db.Exec(query, + char.AccountID, char.ServerID, char.Name, char.Level, + char.Race, char.Gender, char.Class, + char.CreatedDate.Format("2006-01-02 15:04:05"), + ) if err != nil { return 0, fmt.Errorf("failed to create character: %w", err) } - id := int32(d.conn.LastInsertRowID()) + id := int32(d.db.LastInsertID()) log.Printf("Created character %s (ID: %d) for account %d", char.Name, id, char.AccountID) @@ -232,18 +185,13 @@ func (d *Database) DeleteCharacter(charID, accountID int32) error { SET deleted = 1, deleted_date = CURRENT_TIMESTAMP WHERE id = ? AND account_id = ?` - err := sqlitex.Execute(d.conn, query, &sqlitex.ExecOptions{ - Args: []any{charID, accountID}, - }) + err := d.db.Exec(query, charID, accountID) if err != nil { return fmt.Errorf("failed to delete character: %w", err) } // Check if any rows were affected - stmt, _ := d.conn.Prepare("SELECT changes()") - defer stmt.Finalize() - hasRow, _ := stmt.Step() - if hasRow && stmt.ColumnInt(0) == 0 { + if d.db.Changes() == 0 { return fmt.Errorf("character not found or not owned by account") } @@ -260,35 +208,20 @@ func (d *Database) GetWorldServers() ([]*WorldServer, error) { WHERE active = 1 ORDER BY sort_order, name` - stmt, err := d.conn.Prepare(query) - if err != nil { - return nil, fmt.Errorf("prepare failed: %w", err) - } - defer stmt.Finalize() - var servers []*WorldServer - for { - hasRow, err := stmt.Step() - if err != nil { - return nil, fmt.Errorf("query failed: %w", err) - } - if !hasRow { - break - } - + err := d.db.Query(query, func(row *database.Row) error { server := &WorldServer{} + server.ID = int32(row.Int64(0)) + server.Name = row.Text(1) + server.Description = row.Text(2) + server.IPAddress = row.Text(3) + server.Port = row.Int(4) + server.Status = row.Text(5) + server.Population = int32(row.Int64(6)) + server.Locked = row.Bool(7) + server.Hidden = row.Bool(8) - server.ID = int32(stmt.ColumnInt64(0)) - server.Name = stmt.ColumnText(1) - server.Description = stmt.ColumnText(2) - server.IPAddress = stmt.ColumnText(3) - server.Port = stmt.ColumnInt(4) - server.Status = stmt.ColumnText(5) - server.Population = int32(stmt.ColumnInt64(6)) - server.Locked = stmt.ColumnInt(7) != 0 - server.Hidden = stmt.ColumnInt(8) != 0 - - if dateStr := stmt.ColumnText(9); dateStr != "" { + if dateStr := row.Text(9); dateStr != "" { if t, err := time.Parse("2006-01-02 15:04:05", dateStr); err == nil { server.CreatedDate = t } @@ -296,11 +229,11 @@ func (d *Database) GetWorldServers() ([]*WorldServer, error) { server.Online = server.Status == "online" server.PopulationLevel = calculatePopulationLevel(server.Population) - servers = append(servers, server) - } + return nil + }) - return servers, nil + return servers, err } // UpdateWorldServerStats updates world server statistics @@ -310,12 +243,10 @@ func (d *Database) UpdateWorldServerStats(serverID int32, stats *WorldServerStat (server_id, timestamp, population, zones_active, players_online, uptime_seconds) VALUES (?, CURRENT_TIMESTAMP, ?, ?, ?, ?)` - return sqlitex.Execute(d.conn, query, &sqlitex.ExecOptions{ - Args: []any{ - serverID, stats.Population, - stats.ZonesActive, stats.PlayersOnline, stats.UptimeSeconds, - }, - }) + return d.db.Exec(query, + serverID, stats.Population, + stats.ZonesActive, stats.PlayersOnline, stats.UptimeSeconds, + ) } // CleanupOldEntries removes old log entries and statistics @@ -327,7 +258,7 @@ func (d *Database) CleanupOldEntries() error { } for _, query := range queries { - if err := sqlitex.ExecuteTransient(d.conn, query, nil); err != nil { + if err := d.db.Exec(query); err != nil { log.Printf("Cleanup query failed: %v", err) } } @@ -341,9 +272,7 @@ func (d *Database) LogLoginAttempt(username, ipAddress string, success bool) err INSERT INTO login_attempts (username, ip_address, success, timestamp) VALUES (?, ?, ?, CURRENT_TIMESTAMP)` - return sqlitex.Execute(d.conn, query, &sqlitex.ExecOptions{ - Args: []any{username, ipAddress, success}, - }) + return d.db.Exec(query, username, ipAddress, success) } // GetMaxCharsSetting returns the maximum characters per account @@ -351,21 +280,15 @@ func (d *Database) GetMaxCharsSetting() int32 { var maxChars int32 = 7 // Default query := "SELECT value FROM server_settings WHERE name = 'max_characters_per_account'" - stmt, err := d.conn.Prepare(query) - if err != nil { + row, err := d.db.QueryRow(query) + if err != nil || row == nil { return maxChars } - defer stmt.Finalize() + defer row.Close() - hasRow, err := stmt.Step() - if err != nil { - return maxChars - } - if hasRow { - if val := stmt.ColumnText(0); val != "" { - if parsed := stmt.ColumnInt64(0); parsed > 0 { - maxChars = int32(parsed) - } + if !row.IsNull(0) { + if val := row.Int64(0); val > 0 { + maxChars = int32(val) } } @@ -377,21 +300,13 @@ func (d *Database) GetAccountBonus(accountID int32) uint8 { var bonus uint8 = 0 query := "SELECT veteran_bonus FROM accounts WHERE id = ?" - stmt, err := d.conn.Prepare(query) - if err != nil { + row, err := d.db.QueryRow(query, accountID) + if err != nil || row == nil { return bonus } - defer stmt.Finalize() - - stmt.BindInt64(1, int64(accountID)) - hasRow, err := stmt.Step() - if err != nil { - return bonus - } - if hasRow { - bonus = uint8(stmt.ColumnInt(0)) - } + defer row.Close() + bonus = uint8(row.Int(0)) return bonus } diff --git a/internal/database/wrapper.go b/internal/database/wrapper.go new file mode 100644 index 0000000..caa5e98 --- /dev/null +++ b/internal/database/wrapper.go @@ -0,0 +1,262 @@ +package database + +import ( + "fmt" + + "zombiezen.com/go/sqlite" + "zombiezen.com/go/sqlite/sqlitex" +) + +// DB wraps sqlite.Conn with simplified query methods +type DB struct { + conn *sqlite.Conn +} + +// Row represents a single database row with easy column access +type Row struct { + stmt *sqlite.Stmt +} + +// QueryFunc processes each row in a result set +type QueryFunc func(*Row) error + +// Open creates a new database connection with common settings +func Open(path string) (*DB, error) { + conn, err := sqlite.OpenConn(path, sqlite.OpenReadWrite|sqlite.OpenCreate) + if err != nil { + return nil, fmt.Errorf("failed to open database: %w", err) + } + + // Enable foreign keys and WAL mode for better performance + if err := sqlitex.ExecuteTransient(conn, "PRAGMA foreign_keys = ON", nil); err != nil { + conn.Close() + return nil, fmt.Errorf("failed to enable foreign keys: %w", err) + } + + if err := sqlitex.ExecuteTransient(conn, "PRAGMA journal_mode = WAL", nil); err != nil { + conn.Close() + return nil, fmt.Errorf("failed to enable WAL mode: %w", err) + } + + return &DB{conn: conn}, nil +} + +// Close closes the database connection +func (db *DB) Close() error { + return db.conn.Close() +} + +// Exec executes a statement with parameters +func (db *DB) Exec(query string, args ...any) error { + return sqlitex.Execute(db.conn, query, &sqlitex.ExecOptions{ + Args: args, + }) +} + +// QueryRow executes a query expecting a single row result +func (db *DB) QueryRow(query string, args ...any) (*Row, error) { + stmt, err := db.conn.Prepare(query) + if err != nil { + return nil, fmt.Errorf("prepare failed: %w", err) + } + + // Bind parameters + for i, arg := range args { + if err := bindParam(stmt, i+1, arg); err != nil { + stmt.Finalize() + return nil, err + } + } + + hasRow, err := stmt.Step() + if err != nil { + stmt.Finalize() + return nil, fmt.Errorf("query failed: %w", err) + } + if !hasRow { + stmt.Finalize() + return nil, nil // No row found + } + + return &Row{stmt: stmt}, nil +} + +// Query executes a query and calls fn for each row +func (db *DB) Query(query string, fn QueryFunc, args ...any) error { + stmt, err := db.conn.Prepare(query) + if err != nil { + return fmt.Errorf("prepare failed: %w", err) + } + defer stmt.Finalize() + + // Bind parameters + for i, arg := range args { + if err := bindParam(stmt, i+1, arg); err != nil { + return err + } + } + + row := &Row{stmt: stmt} + for { + hasRow, err := stmt.Step() + if err != nil { + return fmt.Errorf("query failed: %w", err) + } + if !hasRow { + break + } + + if err := fn(row); err != nil { + return err + } + } + + return nil +} + +// QuerySlice executes a query and returns all rows in a slice +func (db *DB) QuerySlice(query string, args ...any) ([]*Row, error) { + var rows []*Row + + stmt, err := db.conn.Prepare(query) + if err != nil { + return nil, fmt.Errorf("prepare failed: %w", err) + } + defer stmt.Finalize() + + // Bind parameters + for i, arg := range args { + if err := bindParam(stmt, i+1, arg); err != nil { + return nil, err + } + } + + for { + hasRow, err := stmt.Step() + if err != nil { + return nil, fmt.Errorf("query failed: %w", err) + } + if !hasRow { + break + } + + // Create a snapshot of the current row + rowData := &Row{stmt: stmt} + rows = append(rows, rowData) + } + + return rows, nil +} + +// LastInsertID returns the last inserted row ID +func (db *DB) LastInsertID() int64 { + return db.conn.LastInsertRowID() +} + +// Changes returns the number of rows affected by the last statement +func (db *DB) Changes() int { + return db.conn.Changes() +} + +// Transaction executes fn within a database transaction +func (db *DB) Transaction(fn func(*DB) error) error { + if err := sqlitex.ExecuteTransient(db.conn, "BEGIN", nil); err != nil { + return fmt.Errorf("begin transaction failed: %w", err) + } + + if err := fn(db); err != nil { + sqlitex.ExecuteTransient(db.conn, "ROLLBACK", nil) + return err + } + + if err := sqlitex.ExecuteTransient(db.conn, "COMMIT", nil); err != nil { + return fmt.Errorf("commit transaction failed: %w", err) + } + + return nil +} + +// Row column access methods + +// Close releases the row's statement +func (r *Row) Close() { + if r.stmt != nil { + r.stmt.Finalize() + r.stmt = nil + } +} + +// Int returns column as int +func (r *Row) Int(col int) int { + return r.stmt.ColumnInt(col) +} + +// Int64 returns column as int64 +func (r *Row) Int64(col int) int64 { + return r.stmt.ColumnInt64(col) +} + +// Text returns column as string +func (r *Row) Text(col int) string { + return r.stmt.ColumnText(col) +} + +// Bool returns column as bool (0 = false, non-zero = true) +func (r *Row) Bool(col int) bool { + return r.stmt.ColumnInt(col) != 0 +} + +// Float returns column as float64 +func (r *Row) Float(col int) float64 { + return r.stmt.ColumnFloat(col) +} + +// IsNull checks if column is NULL +func (r *Row) IsNull(col int) bool { + return r.stmt.ColumnType(col) == sqlite.TypeNull +} + +// bindParam binds a parameter to a statement at the given index +func bindParam(stmt *sqlite.Stmt, index int, value any) error { + switch v := value.(type) { + case nil: + stmt.BindNull(index) + case int: + stmt.BindInt64(index, int64(v)) + case int8: + stmt.BindInt64(index, int64(v)) + case int16: + stmt.BindInt64(index, int64(v)) + case int32: + stmt.BindInt64(index, int64(v)) + case int64: + stmt.BindInt64(index, v) + case uint: + stmt.BindInt64(index, int64(v)) + case uint8: + stmt.BindInt64(index, int64(v)) + case uint16: + stmt.BindInt64(index, int64(v)) + case uint32: + stmt.BindInt64(index, int64(v)) + case uint64: + stmt.BindInt64(index, int64(v)) + case float32: + stmt.BindFloat(index, float64(v)) + case float64: + stmt.BindFloat(index, v) + case bool: + if v { + stmt.BindInt64(index, 1) + } else { + stmt.BindInt64(index, 0) + } + case string: + stmt.BindText(index, v) + case []byte: + stmt.BindBytes(index, v) + default: + return fmt.Errorf("unsupported parameter type: %T", value) + } + return nil +}