263 lines
5.7 KiB
Go
263 lines
5.7 KiB
Go
package database
|
|
|
|
import (
|
|
"fmt"
|
|
|
|
"zombiezen.com/go/sqlite"
|
|
"zombiezen.com/go/sqlite/sqlitex"
|
|
)
|
|
|
|
// DB wraps sqlite.Conn with simplified query methods
|
|
type DB struct {
|
|
conn *sqlite.Conn
|
|
}
|
|
|
|
// Row represents a single database row with easy column access
|
|
type Row struct {
|
|
stmt *sqlite.Stmt
|
|
}
|
|
|
|
// QueryFunc processes each row in a result set
|
|
type QueryFunc func(*Row) error
|
|
|
|
// Open opens the SQLite database at path for reading and writing, creating
// it if it does not exist. It then applies the connection settings this
// package relies on: foreign-key enforcement and the WAL journal mode.
//
// If opening fails, or either setting cannot be applied, any connection
// that was opened is closed again and a wrapped error is returned.
//
// NOTE(review): the journal_mode pragma reports the mode actually in effect
// as a result row, which is ignored here. Open therefore does not verify that
// WAL was accepted (e.g. for in-memory databases) — confirm this is intended.
func Open(path string) (*DB, error) {
	conn, err := sqlite.OpenConn(path, sqlite.OpenReadWrite|sqlite.OpenCreate)
	if err != nil {
		return nil, fmt.Errorf("failed to open database: %w", err)
	}

	// Applied strictly in this order; the first failure aborts Open.
	setup := []struct {
		pragma string
		errFmt string
	}{
		{pragma: "PRAGMA foreign_keys = ON", errFmt: "failed to enable foreign keys: %w"},
		{pragma: "PRAGMA journal_mode = WAL", errFmt: "failed to enable WAL mode: %w"},
	}
	for _, step := range setup {
		if pragmaErr := sqlitex.ExecuteTransient(conn, step.pragma, nil); pragmaErr != nil {
			conn.Close()
			return nil, fmt.Errorf(step.errFmt, pragmaErr)
		}
	}

	return &DB{conn: conn}, nil
}
|
|
|
|
// Close closes the database connection
|
|
func (db *DB) Close() error {
|
|
return db.conn.Close()
|
|
}
|
|
|
|
// Exec executes a statement with parameters
|
|
func (db *DB) Exec(query string, args ...any) error {
|
|
return sqlitex.Execute(db.conn, query, &sqlitex.ExecOptions{
|
|
Args: args,
|
|
})
|
|
}
|
|
|
|
// QueryRow executes query with args bound to its positional parameters and
// returns the first result row.
//
// It returns (nil, nil) when the query yields no rows. Otherwise the returned
// Row owns a privately prepared statement positioned on the first row. The
// caller must call Close on it when done. Any further rows are ignored.
//
// Like the cached Prepare it replaces, a query with trailing bytes after the
// first statement is rejected.
func (db *DB) QueryRow(query string, args ...any) (*Row, error) {
	// PrepareTransient rather than Prepare: a cached statement is shared with
	// every other caller that prepares the same SQL. Any of them could reset
	// or finalize it while the returned Row is still in use.
	stmt, trailing, err := db.conn.PrepareTransient(query)
	if err != nil {
		return nil, fmt.Errorf("prepare failed: %w", err)
	}
	if trailing != 0 {
		stmt.Finalize()
		return nil, fmt.Errorf("prepare failed: %d trailing bytes after first statement", trailing)
	}

	// SQLite parameter indices are 1-based.
	for i, arg := range args {
		if err := bindParam(stmt, i+1, arg); err != nil {
			stmt.Finalize()
			return nil, err
		}
	}

	hasRow, err := stmt.Step()
	if err != nil {
		stmt.Finalize()
		return nil, fmt.Errorf("query failed: %w", err)
	}
	if !hasRow {
		stmt.Finalize()
		return nil, nil // No row found
	}

	return &Row{stmt: stmt}, nil
}
|
|
|
|
// Query executes query with args bound to its positional parameters and
// calls fn once per result row, in order.
//
// The *Row passed to fn is a view of the live cursor. It is only valid for
// the duration of that call and must not be retained or closed by fn.
// Iteration stops at the first error returned by fn, and Query returns that
// error unchanged.
//
// The statement is prepared privately for this call and never enters the
// connection's statement cache. As a result, fn may run other queries,
// including the same SQL (e.g. a recursive walk), without resetting or
// finalizing this call's cursor.
func (db *DB) Query(query string, fn QueryFunc, args ...any) error {
	stmt, trailing, err := db.conn.PrepareTransient(query)
	if err != nil {
		return fmt.Errorf("prepare failed: %w", err)
	}
	defer stmt.Finalize()
	if trailing != 0 {
		// Mirrors the cached Prepare, which rejects trailing bytes.
		return fmt.Errorf("prepare failed: %d trailing bytes after first statement", trailing)
	}

	// SQLite parameter indices are 1-based.
	for i, arg := range args {
		if err := bindParam(stmt, i+1, arg); err != nil {
			return err
		}
	}

	row := &Row{stmt: stmt}
	for {
		hasRow, err := stmt.Step()
		if err != nil {
			return fmt.Errorf("query failed: %w", err)
		}
		if !hasRow {
			return nil
		}
		if err := fn(row); err != nil {
			return err
		}
	}
}
|
|
|
|
// QuerySlice executes query with args bound to its positional parameters and
// returns every result row as an independent *Row.
//
// Each returned Row is a detached snapshot. Its column values, with their
// SQLite storage types, are copied out of the query's cursor into a private
// one-row statement. Rows therefore remain valid after QuerySlice returns and
// do not alias one another.
//
// The caller must call Close on every returned Row to release its statement.
// Unclosed rows keep statements open on the connection, which may also
// prevent the connection from closing cleanly. Every row costs one prepared
// statement, so prefer Query for streaming large result sets.
//
// A query with no rows yields (nil, nil). On error, any rows already built
// are closed and (nil, err) is returned.
func (db *DB) QuerySlice(query string, args ...any) ([]*Row, error) {
	// Private statement: nobody else can reset or finalize it mid-iteration.
	stmt, trailing, err := db.conn.PrepareTransient(query)
	if err != nil {
		return nil, fmt.Errorf("prepare failed: %w", err)
	}
	defer stmt.Finalize()
	if trailing != 0 {
		// Mirrors the cached Prepare, which rejects trailing bytes.
		return nil, fmt.Errorf("prepare failed: %d trailing bytes after first statement", trailing)
	}

	// SQLite parameter indices are 1-based.
	for i, arg := range args {
		if err := bindParam(stmt, i+1, arg); err != nil {
			return nil, err
		}
	}

	var rows []*Row
	// fail releases every snapshot collected so far before reporting err.
	fail := func(err error) ([]*Row, error) {
		for _, r := range rows {
			r.Close()
		}
		return nil, err
	}

	cols := stmt.ColumnCount()
	selectSQL := rowSnapshotSQL(cols)
	for {
		hasRow, err := stmt.Step()
		if err != nil {
			return fail(fmt.Errorf("query failed: %w", err))
		}
		if !hasRow {
			break
		}

		snap, err := snapshotRow(db.conn, selectSQL, stmt, cols)
		if err != nil {
			return fail(err)
		}
		rows = append(rows, &Row{stmt: snap})
	}

	return rows, nil
}

// rowSnapshotSQL returns "SELECT ?, ?, ..." with one positional parameter
// per column, used to rebuild a single row from bound values.
func rowSnapshotSQL(cols int) string {
	buf := make([]byte, 0, len("SELECT ")+3*cols)
	buf = append(buf, "SELECT "...)
	for i := 0; i < cols; i++ {
		if i > 0 {
			buf = append(buf, ", "...)
		}
		buf = append(buf, '?')
	}
	return string(buf)
}

// snapshotRow copies the current row of src (cols columns) into a new
// statement compiled from selectSQL and steps it onto its single row. The
// caller owns the returned statement and must finalize it.
func snapshotRow(conn *sqlite.Conn, selectSQL string, src *sqlite.Stmt, cols int) (*sqlite.Stmt, error) {
	snap, _, err := conn.PrepareTransient(selectSQL)
	if err != nil {
		return nil, fmt.Errorf("snapshot prepare failed: %w", err)
	}

	for i := 0; i < cols; i++ {
		param := i + 1
		// Read the storage type first. The typed accessors below would
		// otherwise be free to coerce the value before we see its type.
		switch src.ColumnType(i) {
		case sqlite.TypeInteger:
			snap.BindInt64(param, src.ColumnInt64(i))
		case sqlite.TypeFloat:
			snap.BindFloat(param, src.ColumnFloat(i))
		case sqlite.TypeText:
			snap.BindText(param, src.ColumnText(i))
		case sqlite.TypeBlob:
			buf := make([]byte, src.ColumnLen(i))
			src.ColumnBytes(i, buf)
			snap.BindBytes(param, buf)
		default:
			snap.BindNull(param)
		}
	}

	hasRow, err := snap.Step()
	if err != nil {
		snap.Finalize()
		return nil, fmt.Errorf("snapshot failed: %w", err)
	}
	if !hasRow {
		snap.Finalize()
		return nil, fmt.Errorf("snapshot failed: no row produced")
	}
	return snap, nil
}
|
|
|
|
// LastInsertID returns the last inserted row ID
|
|
func (db *DB) LastInsertID() int64 {
|
|
return db.conn.LastInsertRowID()
|
|
}
|
|
|
|
// Changes returns the number of rows affected by the last statement
|
|
func (db *DB) Changes() int {
|
|
return db.conn.Changes()
|
|
}
|
|
|
|
// Transaction runs fn between BEGIN and COMMIT on this connection.
//
// Outcomes:
//   - fn returns nil and COMMIT succeeds: Transaction returns nil.
//   - fn returns an error: the transaction is rolled back and that error is
//     returned unchanged.
//   - COMMIT fails: the transaction is rolled back and a wrapped
//     "commit transaction failed" error is returned. A failed COMMIT
//     (e.g. a busy database or a deferred foreign-key violation) can leave
//     the transaction active.
//   - fn panics: the transaction is rolled back and the panic continues to
//     propagate. Otherwise the connection would stay inside an open
//     transaction, and every later BEGIN would fail.
//
// fn receives db itself. Because SQLite does not nest BEGIN, fn must not
// call Transaction again.
func (db *DB) Transaction(fn func(*DB) error) error {
	if err := sqlitex.ExecuteTransient(db.conn, "BEGIN", nil); err != nil {
		return fmt.Errorf("begin transaction failed: %w", err)
	}

	committed := false
	defer func() {
		if committed {
			return
		}
		// Best effort: the caller cares about fn's error (or panic) or the
		// commit error, not this one. ROLLBACK can also fail harmlessly when
		// SQLite has already rolled back on its own after certain errors.
		_ = sqlitex.ExecuteTransient(db.conn, "ROLLBACK", nil)
	}()

	if err := fn(db); err != nil {
		return err
	}

	if err := sqlitex.ExecuteTransient(db.conn, "COMMIT", nil); err != nil {
		return fmt.Errorf("commit transaction failed: %w", err)
	}
	committed = true
	return nil
}
|
|
|
|
// Row column access methods
|
|
|
|
// Close releases the row's statement
|
|
func (r *Row) Close() {
|
|
if r.stmt != nil {
|
|
r.stmt.Finalize()
|
|
r.stmt = nil
|
|
}
|
|
}
|
|
|
|
// Int returns column as int
|
|
func (r *Row) Int(col int) int {
|
|
return r.stmt.ColumnInt(col)
|
|
}
|
|
|
|
// Int64 returns column as int64
|
|
func (r *Row) Int64(col int) int64 {
|
|
return r.stmt.ColumnInt64(col)
|
|
}
|
|
|
|
// Text returns column as string
|
|
func (r *Row) Text(col int) string {
|
|
return r.stmt.ColumnText(col)
|
|
}
|
|
|
|
// Bool returns column as bool (0 = false, non-zero = true)
|
|
func (r *Row) Bool(col int) bool {
|
|
return r.stmt.ColumnInt(col) != 0
|
|
}
|
|
|
|
// Float returns column as float64
|
|
func (r *Row) Float(col int) float64 {
|
|
return r.stmt.ColumnFloat(col)
|
|
}
|
|
|
|
// IsNull checks if column is NULL
|
|
func (r *Row) IsNull(col int) bool {
|
|
return r.stmt.ColumnType(col) == sqlite.TypeNull
|
|
}
|
|
|
|
// bindParam binds value to parameter index (1-based) of stmt.
//
// Type mapping:
//   - nil binds as NULL.
//   - All Go signed and unsigned integer types bind as INTEGER.
//   - float32 and float64 bind as REAL.
//   - bool binds as the INTEGER 1 or 0.
//   - string binds as TEXT.
//   - []byte binds as BLOB.
//
// Any other type yields an error. So does an unsigned value larger than
// math.MaxInt64, which SQLite's signed 64-bit INTEGER cannot represent.
// Such values would otherwise wrap around to a negative number silently.
func bindParam(stmt *sqlite.Stmt, index int, value any) error {
	const maxInt64 = 1<<63 - 1

	switch v := value.(type) {
	case nil:
		stmt.BindNull(index)
	case int:
		stmt.BindInt64(index, int64(v))
	case int8:
		stmt.BindInt64(index, int64(v))
	case int16:
		stmt.BindInt64(index, int64(v))
	case int32:
		stmt.BindInt64(index, int64(v))
	case int64:
		stmt.BindInt64(index, v)
	case uint:
		if uint64(v) > maxInt64 {
			return fmt.Errorf("parameter %d: unsigned value %d overflows int64", index, v)
		}
		stmt.BindInt64(index, int64(v))
	case uint8:
		stmt.BindInt64(index, int64(v))
	case uint16:
		stmt.BindInt64(index, int64(v))
	case uint32:
		stmt.BindInt64(index, int64(v))
	case uint64:
		if v > maxInt64 {
			return fmt.Errorf("parameter %d: unsigned value %d overflows int64", index, v)
		}
		stmt.BindInt64(index, int64(v))
	case float32:
		stmt.BindFloat(index, float64(v))
	case float64:
		stmt.BindFloat(index, v)
	case bool:
		var n int64
		if v {
			n = 1
		}
		stmt.BindInt64(index, n)
	case string:
		stmt.BindText(index, v)
	case []byte:
		stmt.BindBytes(index, v)
	default:
		return fmt.Errorf("unsupported parameter type: %T", value)
	}
	return nil
}
|