create thin database wrapper
This commit is contained in:
parent
f6334b97c9
commit
fc82f97cb6
@ -1,18 +1,17 @@
|
|||||||
package main
|
package main
|
||||||
|
|
||||||
import (
|
import (
|
||||||
|
"eq2emu/internal/database"
|
||||||
"fmt"
|
"fmt"
|
||||||
"log"
|
"log"
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
"golang.org/x/crypto/bcrypt"
|
"golang.org/x/crypto/bcrypt"
|
||||||
"zombiezen.com/go/sqlite"
|
|
||||||
"zombiezen.com/go/sqlite/sqlitex"
|
|
||||||
)
|
)
|
||||||
|
|
||||||
// Database handles all database operations for the login server
|
// Database handles all database operations for the login server
|
||||||
type Database struct {
|
type Database struct {
|
||||||
conn *sqlite.Conn
|
db *database.DB
|
||||||
}
|
}
|
||||||
|
|
||||||
// DatabaseConfig holds database connection settings
|
// DatabaseConfig holds database connection settings
|
||||||
@ -24,32 +23,26 @@ type DatabaseConfig struct {
|
|||||||
|
|
||||||
// NewDatabase creates a new database connection
|
// NewDatabase creates a new database connection
|
||||||
func NewDatabase(config DatabaseConfig) (*Database, error) {
|
func NewDatabase(config DatabaseConfig) (*Database, error) {
|
||||||
// Open SQLite database
|
db, err := database.Open(config.FilePath)
|
||||||
db, err := sqlite.OpenConn(config.FilePath, sqlite.OpenReadWrite|sqlite.OpenCreate)
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, fmt.Errorf("failed to open database: %w", err)
|
return nil, fmt.Errorf("failed to open database: %w", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
// Enable foreign keys
|
// Set busy timeout if specified
|
||||||
if err := sqlitex.ExecuteTransient(db, "PRAGMA foreign_keys = ON", nil); err != nil {
|
|
||||||
return nil, fmt.Errorf("failed to enable foreign keys: %w", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
// Set busy timeout
|
|
||||||
if config.BusyTimeout > 0 {
|
if config.BusyTimeout > 0 {
|
||||||
query := fmt.Sprintf("PRAGMA busy_timeout = %d", config.BusyTimeout)
|
query := fmt.Sprintf("PRAGMA busy_timeout = %d", config.BusyTimeout)
|
||||||
if err := sqlitex.ExecuteTransient(db, query, nil); err != nil {
|
if err := db.Exec(query); err != nil {
|
||||||
return nil, fmt.Errorf("failed to set busy timeout: %w", err)
|
return nil, fmt.Errorf("failed to set busy timeout: %w", err)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
log.Println("SQLite database connection established")
|
log.Println("SQLite database connection established")
|
||||||
return &Database{conn: db}, nil
|
return &Database{db: db}, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
// Close closes the database connection
|
// Close closes the database connection
|
||||||
func (d *Database) Close() error {
|
func (d *Database) Close() error {
|
||||||
return d.conn.Close()
|
return d.db.Close()
|
||||||
}
|
}
|
||||||
|
|
||||||
// AuthenticateAccount verifies user credentials and returns account info
|
// AuthenticateAccount verifies user credentials and returns account info
|
||||||
@ -60,49 +53,37 @@ func (d *Database) AuthenticateAccount(username, password string) (*Account, err
|
|||||||
FROM accounts
|
FROM accounts
|
||||||
WHERE username = ? AND active = 1`
|
WHERE username = ? AND active = 1`
|
||||||
|
|
||||||
stmt, err := d.conn.Prepare(query)
|
row, err := d.db.QueryRow(query, username)
|
||||||
if err != nil {
|
|
||||||
return nil, fmt.Errorf("prepare statement failed: %w", err)
|
|
||||||
}
|
|
||||||
defer stmt.Finalize()
|
|
||||||
|
|
||||||
stmt.BindText(1, username)
|
|
||||||
|
|
||||||
hasRow, err := stmt.Step()
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, fmt.Errorf("query failed: %w", err)
|
return nil, fmt.Errorf("query failed: %w", err)
|
||||||
}
|
}
|
||||||
if !hasRow {
|
if row == nil {
|
||||||
return nil, nil // Account not found
|
return nil, nil // Account not found
|
||||||
}
|
}
|
||||||
|
defer row.Close()
|
||||||
|
|
||||||
var account Account
|
var account Account
|
||||||
var passwordHash string
|
account.ID = int32(row.Int64(0))
|
||||||
var lastLogin string
|
account.Username = row.Text(1)
|
||||||
var clientVersion int64
|
passwordHash := row.Text(2)
|
||||||
|
account.LSAdmin = row.Bool(3)
|
||||||
account.ID = int32(stmt.ColumnInt64(0))
|
account.WorldAdmin = row.Bool(4)
|
||||||
account.Username = stmt.ColumnText(1)
|
|
||||||
passwordHash = stmt.ColumnText(2)
|
|
||||||
account.LSAdmin = stmt.ColumnInt(3) != 0
|
|
||||||
account.WorldAdmin = stmt.ColumnInt(4) != 0
|
|
||||||
// Skip created_date at index 5 - not needed for authentication
|
// Skip created_date at index 5 - not needed for authentication
|
||||||
lastLogin = stmt.ColumnText(6)
|
lastLogin := row.Text(6)
|
||||||
clientVersion = stmt.ColumnInt64(7)
|
account.ClientVersion = uint16(row.Int64(7))
|
||||||
|
|
||||||
// Verify password
|
// Verify password
|
||||||
if err := bcrypt.CompareHashAndPassword([]byte(passwordHash), []byte(password)); err != nil {
|
if err := bcrypt.CompareHashAndPassword([]byte(passwordHash), []byte(password)); err != nil {
|
||||||
return nil, nil // Invalid password
|
return nil, nil // Invalid password
|
||||||
}
|
}
|
||||||
|
|
||||||
// Parse timestamps
|
// Parse timestamp
|
||||||
if lastLogin != "" {
|
if lastLogin != "" {
|
||||||
if t, err := time.Parse("2006-01-02 15:04:05", lastLogin); err == nil {
|
if t, err := time.Parse("2006-01-02 15:04:05", lastLogin); err == nil {
|
||||||
account.LastLogin = t
|
account.LastLogin = t
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
account.ClientVersion = uint16(clientVersion)
|
|
||||||
return &account, nil
|
return &account, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -113,14 +94,12 @@ func (d *Database) UpdateAccountLogin(account *Account) error {
|
|||||||
SET last_login = ?, last_ip = ?, client_version = ?
|
SET last_login = ?, last_ip = ?, client_version = ?
|
||||||
WHERE id = ?`
|
WHERE id = ?`
|
||||||
|
|
||||||
return sqlitex.Execute(d.conn, query, &sqlitex.ExecOptions{
|
return d.db.Exec(query,
|
||||||
Args: []any{
|
account.LastLogin.Format("2006-01-02 15:04:05"),
|
||||||
account.LastLogin.Format("2006-01-02 15:04:05"),
|
account.IPAddress,
|
||||||
account.IPAddress,
|
account.ClientVersion,
|
||||||
account.ClientVersion,
|
account.ID,
|
||||||
account.ID,
|
)
|
||||||
},
|
|
||||||
})
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// LoadCharacters loads all characters for an account
|
// LoadCharacters loads all characters for an account
|
||||||
@ -132,45 +111,29 @@ func (d *Database) LoadCharacters(accountID int32, version uint16) ([]*Character
|
|||||||
WHERE account_id = ?
|
WHERE account_id = ?
|
||||||
ORDER BY created_date ASC`
|
ORDER BY created_date ASC`
|
||||||
|
|
||||||
stmt, err := d.conn.Prepare(query)
|
|
||||||
if err != nil {
|
|
||||||
return nil, fmt.Errorf("prepare failed: %w", err)
|
|
||||||
}
|
|
||||||
defer stmt.Finalize()
|
|
||||||
|
|
||||||
stmt.BindInt64(1, int64(accountID))
|
|
||||||
|
|
||||||
var characters []*Character
|
var characters []*Character
|
||||||
for {
|
err := d.db.Query(query, func(row *database.Row) error {
|
||||||
hasRow, err := stmt.Step()
|
|
||||||
if err != nil {
|
|
||||||
return nil, fmt.Errorf("query failed: %w", err)
|
|
||||||
}
|
|
||||||
if !hasRow {
|
|
||||||
break
|
|
||||||
}
|
|
||||||
|
|
||||||
char := &Character{AccountID: accountID}
|
char := &Character{AccountID: accountID}
|
||||||
|
char.ID = int32(row.Int64(0))
|
||||||
|
char.ServerID = int32(row.Int64(1))
|
||||||
|
char.Name = row.Text(2)
|
||||||
|
char.Level = int8(row.Int(3))
|
||||||
|
char.Race = int8(row.Int(4))
|
||||||
|
char.Gender = int8(row.Int(5))
|
||||||
|
char.Class = int8(row.Int(6))
|
||||||
|
|
||||||
char.ID = int32(stmt.ColumnInt64(0))
|
if dateStr := row.Text(7); dateStr != "" {
|
||||||
char.ServerID = int32(stmt.ColumnInt64(1))
|
|
||||||
char.Name = stmt.ColumnText(2)
|
|
||||||
char.Level = int8(stmt.ColumnInt(3))
|
|
||||||
char.Race = int8(stmt.ColumnInt(4))
|
|
||||||
char.Gender = int8(stmt.ColumnInt(5))
|
|
||||||
char.Class = int8(stmt.ColumnInt(6))
|
|
||||||
|
|
||||||
if dateStr := stmt.ColumnText(7); dateStr != "" {
|
|
||||||
if t, err := time.Parse("2006-01-02 15:04:05", dateStr); err == nil {
|
if t, err := time.Parse("2006-01-02 15:04:05", dateStr); err == nil {
|
||||||
char.CreatedDate = t
|
char.CreatedDate = t
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
char.Deleted = stmt.ColumnInt(8) != 0
|
char.Deleted = row.Bool(8)
|
||||||
characters = append(characters, char)
|
characters = append(characters, char)
|
||||||
}
|
return nil
|
||||||
|
}, accountID)
|
||||||
|
|
||||||
return characters, nil
|
return characters, err
|
||||||
}
|
}
|
||||||
|
|
||||||
// CharacterNameExists checks if a character name is already taken
|
// CharacterNameExists checks if a character name is already taken
|
||||||
@ -180,24 +143,16 @@ func (d *Database) CharacterNameExists(name string, serverID int32) (bool, error
|
|||||||
FROM characters
|
FROM characters
|
||||||
WHERE name = ? AND server_id = ? AND deleted = 0`
|
WHERE name = ? AND server_id = ? AND deleted = 0`
|
||||||
|
|
||||||
stmt, err := d.conn.Prepare(query)
|
row, err := d.db.QueryRow(query, name, serverID)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return false, err
|
return false, err
|
||||||
}
|
}
|
||||||
defer stmt.Finalize()
|
if row == nil {
|
||||||
|
return false, nil
|
||||||
stmt.BindText(1, name)
|
|
||||||
stmt.BindInt64(2, int64(serverID))
|
|
||||||
|
|
||||||
hasRow, err := stmt.Step()
|
|
||||||
if err != nil {
|
|
||||||
return false, err
|
|
||||||
}
|
|
||||||
if hasRow {
|
|
||||||
return stmt.ColumnInt(0) > 0, nil
|
|
||||||
}
|
}
|
||||||
|
defer row.Close()
|
||||||
|
|
||||||
return false, nil
|
return row.Int(0) > 0, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
// CreateCharacter creates a new character in the database
|
// CreateCharacter creates a new character in the database
|
||||||
@ -207,18 +162,16 @@ func (d *Database) CreateCharacter(char *Character) (int32, error) {
|
|||||||
gender, class, created_date)
|
gender, class, created_date)
|
||||||
VALUES (?, ?, ?, ?, ?, ?, ?, ?)`
|
VALUES (?, ?, ?, ?, ?, ?, ?, ?)`
|
||||||
|
|
||||||
err := sqlitex.Execute(d.conn, query, &sqlitex.ExecOptions{
|
err := d.db.Exec(query,
|
||||||
Args: []any{
|
char.AccountID, char.ServerID, char.Name, char.Level,
|
||||||
char.AccountID, char.ServerID, char.Name, char.Level,
|
char.Race, char.Gender, char.Class,
|
||||||
char.Race, char.Gender, char.Class,
|
char.CreatedDate.Format("2006-01-02 15:04:05"),
|
||||||
char.CreatedDate.Format("2006-01-02 15:04:05"),
|
)
|
||||||
},
|
|
||||||
})
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return 0, fmt.Errorf("failed to create character: %w", err)
|
return 0, fmt.Errorf("failed to create character: %w", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
id := int32(d.conn.LastInsertRowID())
|
id := int32(d.db.LastInsertID())
|
||||||
log.Printf("Created character %s (ID: %d) for account %d",
|
log.Printf("Created character %s (ID: %d) for account %d",
|
||||||
char.Name, id, char.AccountID)
|
char.Name, id, char.AccountID)
|
||||||
|
|
||||||
@ -232,18 +185,13 @@ func (d *Database) DeleteCharacter(charID, accountID int32) error {
|
|||||||
SET deleted = 1, deleted_date = CURRENT_TIMESTAMP
|
SET deleted = 1, deleted_date = CURRENT_TIMESTAMP
|
||||||
WHERE id = ? AND account_id = ?`
|
WHERE id = ? AND account_id = ?`
|
||||||
|
|
||||||
err := sqlitex.Execute(d.conn, query, &sqlitex.ExecOptions{
|
err := d.db.Exec(query, charID, accountID)
|
||||||
Args: []any{charID, accountID},
|
|
||||||
})
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return fmt.Errorf("failed to delete character: %w", err)
|
return fmt.Errorf("failed to delete character: %w", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
// Check if any rows were affected
|
// Check if any rows were affected
|
||||||
stmt, _ := d.conn.Prepare("SELECT changes()")
|
if d.db.Changes() == 0 {
|
||||||
defer stmt.Finalize()
|
|
||||||
hasRow, _ := stmt.Step()
|
|
||||||
if hasRow && stmt.ColumnInt(0) == 0 {
|
|
||||||
return fmt.Errorf("character not found or not owned by account")
|
return fmt.Errorf("character not found or not owned by account")
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -260,35 +208,20 @@ func (d *Database) GetWorldServers() ([]*WorldServer, error) {
|
|||||||
WHERE active = 1
|
WHERE active = 1
|
||||||
ORDER BY sort_order, name`
|
ORDER BY sort_order, name`
|
||||||
|
|
||||||
stmt, err := d.conn.Prepare(query)
|
|
||||||
if err != nil {
|
|
||||||
return nil, fmt.Errorf("prepare failed: %w", err)
|
|
||||||
}
|
|
||||||
defer stmt.Finalize()
|
|
||||||
|
|
||||||
var servers []*WorldServer
|
var servers []*WorldServer
|
||||||
for {
|
err := d.db.Query(query, func(row *database.Row) error {
|
||||||
hasRow, err := stmt.Step()
|
|
||||||
if err != nil {
|
|
||||||
return nil, fmt.Errorf("query failed: %w", err)
|
|
||||||
}
|
|
||||||
if !hasRow {
|
|
||||||
break
|
|
||||||
}
|
|
||||||
|
|
||||||
server := &WorldServer{}
|
server := &WorldServer{}
|
||||||
|
server.ID = int32(row.Int64(0))
|
||||||
|
server.Name = row.Text(1)
|
||||||
|
server.Description = row.Text(2)
|
||||||
|
server.IPAddress = row.Text(3)
|
||||||
|
server.Port = row.Int(4)
|
||||||
|
server.Status = row.Text(5)
|
||||||
|
server.Population = int32(row.Int64(6))
|
||||||
|
server.Locked = row.Bool(7)
|
||||||
|
server.Hidden = row.Bool(8)
|
||||||
|
|
||||||
server.ID = int32(stmt.ColumnInt64(0))
|
if dateStr := row.Text(9); dateStr != "" {
|
||||||
server.Name = stmt.ColumnText(1)
|
|
||||||
server.Description = stmt.ColumnText(2)
|
|
||||||
server.IPAddress = stmt.ColumnText(3)
|
|
||||||
server.Port = stmt.ColumnInt(4)
|
|
||||||
server.Status = stmt.ColumnText(5)
|
|
||||||
server.Population = int32(stmt.ColumnInt64(6))
|
|
||||||
server.Locked = stmt.ColumnInt(7) != 0
|
|
||||||
server.Hidden = stmt.ColumnInt(8) != 0
|
|
||||||
|
|
||||||
if dateStr := stmt.ColumnText(9); dateStr != "" {
|
|
||||||
if t, err := time.Parse("2006-01-02 15:04:05", dateStr); err == nil {
|
if t, err := time.Parse("2006-01-02 15:04:05", dateStr); err == nil {
|
||||||
server.CreatedDate = t
|
server.CreatedDate = t
|
||||||
}
|
}
|
||||||
@ -296,11 +229,11 @@ func (d *Database) GetWorldServers() ([]*WorldServer, error) {
|
|||||||
|
|
||||||
server.Online = server.Status == "online"
|
server.Online = server.Status == "online"
|
||||||
server.PopulationLevel = calculatePopulationLevel(server.Population)
|
server.PopulationLevel = calculatePopulationLevel(server.Population)
|
||||||
|
|
||||||
servers = append(servers, server)
|
servers = append(servers, server)
|
||||||
}
|
return nil
|
||||||
|
})
|
||||||
|
|
||||||
return servers, nil
|
return servers, err
|
||||||
}
|
}
|
||||||
|
|
||||||
// UpdateWorldServerStats updates world server statistics
|
// UpdateWorldServerStats updates world server statistics
|
||||||
@ -310,12 +243,10 @@ func (d *Database) UpdateWorldServerStats(serverID int32, stats *WorldServerStat
|
|||||||
(server_id, timestamp, population, zones_active, players_online, uptime_seconds)
|
(server_id, timestamp, population, zones_active, players_online, uptime_seconds)
|
||||||
VALUES (?, CURRENT_TIMESTAMP, ?, ?, ?, ?)`
|
VALUES (?, CURRENT_TIMESTAMP, ?, ?, ?, ?)`
|
||||||
|
|
||||||
return sqlitex.Execute(d.conn, query, &sqlitex.ExecOptions{
|
return d.db.Exec(query,
|
||||||
Args: []any{
|
serverID, stats.Population,
|
||||||
serverID, stats.Population,
|
stats.ZonesActive, stats.PlayersOnline, stats.UptimeSeconds,
|
||||||
stats.ZonesActive, stats.PlayersOnline, stats.UptimeSeconds,
|
)
|
||||||
},
|
|
||||||
})
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// CleanupOldEntries removes old log entries and statistics
|
// CleanupOldEntries removes old log entries and statistics
|
||||||
@ -327,7 +258,7 @@ func (d *Database) CleanupOldEntries() error {
|
|||||||
}
|
}
|
||||||
|
|
||||||
for _, query := range queries {
|
for _, query := range queries {
|
||||||
if err := sqlitex.ExecuteTransient(d.conn, query, nil); err != nil {
|
if err := d.db.Exec(query); err != nil {
|
||||||
log.Printf("Cleanup query failed: %v", err)
|
log.Printf("Cleanup query failed: %v", err)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@ -341,9 +272,7 @@ func (d *Database) LogLoginAttempt(username, ipAddress string, success bool) err
|
|||||||
INSERT INTO login_attempts (username, ip_address, success, timestamp)
|
INSERT INTO login_attempts (username, ip_address, success, timestamp)
|
||||||
VALUES (?, ?, ?, CURRENT_TIMESTAMP)`
|
VALUES (?, ?, ?, CURRENT_TIMESTAMP)`
|
||||||
|
|
||||||
return sqlitex.Execute(d.conn, query, &sqlitex.ExecOptions{
|
return d.db.Exec(query, username, ipAddress, success)
|
||||||
Args: []any{username, ipAddress, success},
|
|
||||||
})
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// GetMaxCharsSetting returns the maximum characters per account
|
// GetMaxCharsSetting returns the maximum characters per account
|
||||||
@ -351,21 +280,15 @@ func (d *Database) GetMaxCharsSetting() int32 {
|
|||||||
var maxChars int32 = 7 // Default
|
var maxChars int32 = 7 // Default
|
||||||
|
|
||||||
query := "SELECT value FROM server_settings WHERE name = 'max_characters_per_account'"
|
query := "SELECT value FROM server_settings WHERE name = 'max_characters_per_account'"
|
||||||
stmt, err := d.conn.Prepare(query)
|
row, err := d.db.QueryRow(query)
|
||||||
if err != nil {
|
if err != nil || row == nil {
|
||||||
return maxChars
|
return maxChars
|
||||||
}
|
}
|
||||||
defer stmt.Finalize()
|
defer row.Close()
|
||||||
|
|
||||||
hasRow, err := stmt.Step()
|
if !row.IsNull(0) {
|
||||||
if err != nil {
|
if val := row.Int64(0); val > 0 {
|
||||||
return maxChars
|
maxChars = int32(val)
|
||||||
}
|
|
||||||
if hasRow {
|
|
||||||
if val := stmt.ColumnText(0); val != "" {
|
|
||||||
if parsed := stmt.ColumnInt64(0); parsed > 0 {
|
|
||||||
maxChars = int32(parsed)
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -377,21 +300,13 @@ func (d *Database) GetAccountBonus(accountID int32) uint8 {
|
|||||||
var bonus uint8 = 0
|
var bonus uint8 = 0
|
||||||
|
|
||||||
query := "SELECT veteran_bonus FROM accounts WHERE id = ?"
|
query := "SELECT veteran_bonus FROM accounts WHERE id = ?"
|
||||||
stmt, err := d.conn.Prepare(query)
|
row, err := d.db.QueryRow(query, accountID)
|
||||||
if err != nil {
|
if err != nil || row == nil {
|
||||||
return bonus
|
return bonus
|
||||||
}
|
}
|
||||||
defer stmt.Finalize()
|
defer row.Close()
|
||||||
|
|
||||||
stmt.BindInt64(1, int64(accountID))
|
|
||||||
hasRow, err := stmt.Step()
|
|
||||||
if err != nil {
|
|
||||||
return bonus
|
|
||||||
}
|
|
||||||
if hasRow {
|
|
||||||
bonus = uint8(stmt.ColumnInt(0))
|
|
||||||
}
|
|
||||||
|
|
||||||
|
bonus = uint8(row.Int(0))
|
||||||
return bonus
|
return bonus
|
||||||
}
|
}
|
||||||
|
|
||||||
|
262
internal/database/wrapper.go
Normal file
262
internal/database/wrapper.go
Normal file
@ -0,0 +1,262 @@
|
|||||||
|
package database
|
||||||
|
|
||||||
|
import (
|
||||||
|
"fmt"
|
||||||
|
|
||||||
|
"zombiezen.com/go/sqlite"
|
||||||
|
"zombiezen.com/go/sqlite/sqlitex"
|
||||||
|
)
|
||||||
|
|
||||||
|
// DB wraps sqlite.Conn with simplified query methods
|
||||||
|
type DB struct {
|
||||||
|
conn *sqlite.Conn
|
||||||
|
}
|
||||||
|
|
||||||
|
// Row represents a single database row with easy column access
|
||||||
|
type Row struct {
|
||||||
|
stmt *sqlite.Stmt
|
||||||
|
}
|
||||||
|
|
||||||
|
// QueryFunc processes each row in a result set
|
||||||
|
type QueryFunc func(*Row) error
|
||||||
|
|
||||||
|
// Open creates a new database connection with common settings
|
||||||
|
func Open(path string) (*DB, error) {
|
||||||
|
conn, err := sqlite.OpenConn(path, sqlite.OpenReadWrite|sqlite.OpenCreate)
|
||||||
|
if err != nil {
|
||||||
|
return nil, fmt.Errorf("failed to open database: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Enable foreign keys and WAL mode for better performance
|
||||||
|
if err := sqlitex.ExecuteTransient(conn, "PRAGMA foreign_keys = ON", nil); err != nil {
|
||||||
|
conn.Close()
|
||||||
|
return nil, fmt.Errorf("failed to enable foreign keys: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
if err := sqlitex.ExecuteTransient(conn, "PRAGMA journal_mode = WAL", nil); err != nil {
|
||||||
|
conn.Close()
|
||||||
|
return nil, fmt.Errorf("failed to enable WAL mode: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
return &DB{conn: conn}, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// Close closes the database connection
|
||||||
|
func (db *DB) Close() error {
|
||||||
|
return db.conn.Close()
|
||||||
|
}
|
||||||
|
|
||||||
|
// Exec executes a statement with parameters
|
||||||
|
func (db *DB) Exec(query string, args ...any) error {
|
||||||
|
return sqlitex.Execute(db.conn, query, &sqlitex.ExecOptions{
|
||||||
|
Args: args,
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
// QueryRow executes a query expecting a single row result
|
||||||
|
func (db *DB) QueryRow(query string, args ...any) (*Row, error) {
|
||||||
|
stmt, err := db.conn.Prepare(query)
|
||||||
|
if err != nil {
|
||||||
|
return nil, fmt.Errorf("prepare failed: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Bind parameters
|
||||||
|
for i, arg := range args {
|
||||||
|
if err := bindParam(stmt, i+1, arg); err != nil {
|
||||||
|
stmt.Finalize()
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
hasRow, err := stmt.Step()
|
||||||
|
if err != nil {
|
||||||
|
stmt.Finalize()
|
||||||
|
return nil, fmt.Errorf("query failed: %w", err)
|
||||||
|
}
|
||||||
|
if !hasRow {
|
||||||
|
stmt.Finalize()
|
||||||
|
return nil, nil // No row found
|
||||||
|
}
|
||||||
|
|
||||||
|
return &Row{stmt: stmt}, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// Query executes a query and calls fn for each row
|
||||||
|
func (db *DB) Query(query string, fn QueryFunc, args ...any) error {
|
||||||
|
stmt, err := db.conn.Prepare(query)
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("prepare failed: %w", err)
|
||||||
|
}
|
||||||
|
defer stmt.Finalize()
|
||||||
|
|
||||||
|
// Bind parameters
|
||||||
|
for i, arg := range args {
|
||||||
|
if err := bindParam(stmt, i+1, arg); err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
row := &Row{stmt: stmt}
|
||||||
|
for {
|
||||||
|
hasRow, err := stmt.Step()
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("query failed: %w", err)
|
||||||
|
}
|
||||||
|
if !hasRow {
|
||||||
|
break
|
||||||
|
}
|
||||||
|
|
||||||
|
if err := fn(row); err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// QuerySlice executes a query and returns all rows in a slice
|
||||||
|
func (db *DB) QuerySlice(query string, args ...any) ([]*Row, error) {
|
||||||
|
var rows []*Row
|
||||||
|
|
||||||
|
stmt, err := db.conn.Prepare(query)
|
||||||
|
if err != nil {
|
||||||
|
return nil, fmt.Errorf("prepare failed: %w", err)
|
||||||
|
}
|
||||||
|
defer stmt.Finalize()
|
||||||
|
|
||||||
|
// Bind parameters
|
||||||
|
for i, arg := range args {
|
||||||
|
if err := bindParam(stmt, i+1, arg); err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
for {
|
||||||
|
hasRow, err := stmt.Step()
|
||||||
|
if err != nil {
|
||||||
|
return nil, fmt.Errorf("query failed: %w", err)
|
||||||
|
}
|
||||||
|
if !hasRow {
|
||||||
|
break
|
||||||
|
}
|
||||||
|
|
||||||
|
// Create a snapshot of the current row
|
||||||
|
rowData := &Row{stmt: stmt}
|
||||||
|
rows = append(rows, rowData)
|
||||||
|
}
|
||||||
|
|
||||||
|
return rows, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// LastInsertID returns the last inserted row ID
|
||||||
|
func (db *DB) LastInsertID() int64 {
|
||||||
|
return db.conn.LastInsertRowID()
|
||||||
|
}
|
||||||
|
|
||||||
|
// Changes returns the number of rows affected by the last statement
|
||||||
|
func (db *DB) Changes() int {
|
||||||
|
return db.conn.Changes()
|
||||||
|
}
|
||||||
|
|
||||||
|
// Transaction executes fn within a database transaction
|
||||||
|
func (db *DB) Transaction(fn func(*DB) error) error {
|
||||||
|
if err := sqlitex.ExecuteTransient(db.conn, "BEGIN", nil); err != nil {
|
||||||
|
return fmt.Errorf("begin transaction failed: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
if err := fn(db); err != nil {
|
||||||
|
sqlitex.ExecuteTransient(db.conn, "ROLLBACK", nil)
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
if err := sqlitex.ExecuteTransient(db.conn, "COMMIT", nil); err != nil {
|
||||||
|
return fmt.Errorf("commit transaction failed: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// Row column access methods
|
||||||
|
|
||||||
|
// Close releases the row's statement
|
||||||
|
func (r *Row) Close() {
|
||||||
|
if r.stmt != nil {
|
||||||
|
r.stmt.Finalize()
|
||||||
|
r.stmt = nil
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Int returns column as int
|
||||||
|
func (r *Row) Int(col int) int {
|
||||||
|
return r.stmt.ColumnInt(col)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Int64 returns column as int64
|
||||||
|
func (r *Row) Int64(col int) int64 {
|
||||||
|
return r.stmt.ColumnInt64(col)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Text returns column as string
|
||||||
|
func (r *Row) Text(col int) string {
|
||||||
|
return r.stmt.ColumnText(col)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Bool returns column as bool (0 = false, non-zero = true)
|
||||||
|
func (r *Row) Bool(col int) bool {
|
||||||
|
return r.stmt.ColumnInt(col) != 0
|
||||||
|
}
|
||||||
|
|
||||||
|
// Float returns column as float64
|
||||||
|
func (r *Row) Float(col int) float64 {
|
||||||
|
return r.stmt.ColumnFloat(col)
|
||||||
|
}
|
||||||
|
|
||||||
|
// IsNull checks if column is NULL
|
||||||
|
func (r *Row) IsNull(col int) bool {
|
||||||
|
return r.stmt.ColumnType(col) == sqlite.TypeNull
|
||||||
|
}
|
||||||
|
|
||||||
|
// bindParam binds a parameter to a statement at the given index
|
||||||
|
func bindParam(stmt *sqlite.Stmt, index int, value any) error {
|
||||||
|
switch v := value.(type) {
|
||||||
|
case nil:
|
||||||
|
stmt.BindNull(index)
|
||||||
|
case int:
|
||||||
|
stmt.BindInt64(index, int64(v))
|
||||||
|
case int8:
|
||||||
|
stmt.BindInt64(index, int64(v))
|
||||||
|
case int16:
|
||||||
|
stmt.BindInt64(index, int64(v))
|
||||||
|
case int32:
|
||||||
|
stmt.BindInt64(index, int64(v))
|
||||||
|
case int64:
|
||||||
|
stmt.BindInt64(index, v)
|
||||||
|
case uint:
|
||||||
|
stmt.BindInt64(index, int64(v))
|
||||||
|
case uint8:
|
||||||
|
stmt.BindInt64(index, int64(v))
|
||||||
|
case uint16:
|
||||||
|
stmt.BindInt64(index, int64(v))
|
||||||
|
case uint32:
|
||||||
|
stmt.BindInt64(index, int64(v))
|
||||||
|
case uint64:
|
||||||
|
stmt.BindInt64(index, int64(v))
|
||||||
|
case float32:
|
||||||
|
stmt.BindFloat(index, float64(v))
|
||||||
|
case float64:
|
||||||
|
stmt.BindFloat(index, v)
|
||||||
|
case bool:
|
||||||
|
if v {
|
||||||
|
stmt.BindInt64(index, 1)
|
||||||
|
} else {
|
||||||
|
stmt.BindInt64(index, 0)
|
||||||
|
}
|
||||||
|
case string:
|
||||||
|
stmt.BindText(index, v)
|
||||||
|
case []byte:
|
||||||
|
stmt.BindBytes(index, v)
|
||||||
|
default:
|
||||||
|
return fmt.Errorf("unsupported parameter type: %T", value)
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}
|
Loading…
x
Reference in New Issue
Block a user