119 lines
2.8 KiB
Go
119 lines
2.8 KiB
Go
package database
|
|
|
|
import (
|
|
"fmt"
|
|
|
|
"zombiezen.com/go/sqlite"
|
|
"zombiezen.com/go/sqlite/sqlitex"
|
|
)
|
|
|
|
// DefaultPath is the database file that Open uses when it is given an empty path.
const (
	DefaultPath = "dk.db"
)
|
|
|
|
// DB wraps a SQLite connection with simplified methods
|
|
type DB struct {
|
|
conn *sqlite.Conn
|
|
}
|
|
|
|
// Open creates a new database connection
|
|
func Open(path string) (*DB, error) {
|
|
if path == "" {
|
|
path = DefaultPath
|
|
}
|
|
|
|
conn, err := sqlite.OpenConn(path, sqlite.OpenCreate|sqlite.OpenReadWrite|sqlite.OpenWAL)
|
|
if err != nil {
|
|
return nil, fmt.Errorf("failed to open database: %w", err)
|
|
}
|
|
|
|
// Set pragmas for performance
|
|
if err := sqlitex.ExecuteTransient(conn, "PRAGMA journal_mode = WAL", nil); err != nil {
|
|
conn.Close()
|
|
return nil, fmt.Errorf("failed to set WAL mode: %w", err)
|
|
}
|
|
|
|
if err := sqlitex.ExecuteTransient(conn, "PRAGMA synchronous = NORMAL", nil); err != nil {
|
|
conn.Close()
|
|
return nil, fmt.Errorf("failed to set synchronous mode: %w", err)
|
|
}
|
|
|
|
return &DB{conn: conn}, nil
|
|
}
|
|
|
|
// Close closes the database connection
|
|
func (db *DB) Close() error {
|
|
return db.conn.Close()
|
|
}
|
|
|
|
// Exec executes a SQL statement without returning results
|
|
func (db *DB) Exec(query string, args ...any) error {
|
|
if len(args) == 0 {
|
|
return sqlitex.ExecuteTransient(db.conn, query, nil)
|
|
}
|
|
|
|
return sqlitex.ExecuteTransient(db.conn, query, &sqlitex.ExecOptions{
|
|
Args: args,
|
|
})
|
|
}
|
|
|
|
// Query executes a SQL query and calls fn for each row
|
|
func (db *DB) Query(query string, fn func(*sqlite.Stmt) error, args ...any) error {
|
|
if len(args) == 0 {
|
|
return sqlitex.ExecuteTransient(db.conn, query, &sqlitex.ExecOptions{
|
|
ResultFunc: fn,
|
|
})
|
|
}
|
|
|
|
return sqlitex.ExecuteTransient(db.conn, query, &sqlitex.ExecOptions{
|
|
Args: args,
|
|
ResultFunc: fn,
|
|
})
|
|
}
|
|
|
|
// Begin starts a new transaction
|
|
func (db *DB) Begin() (*Tx, error) {
|
|
if err := sqlitex.ExecuteTransient(db.conn, "BEGIN", nil); err != nil {
|
|
return nil, fmt.Errorf("failed to begin transaction: %w", err)
|
|
}
|
|
|
|
return &Tx{db: db}, nil
|
|
}
|
|
|
|
// Transaction runs a function within a transaction
|
|
func (db *DB) Transaction(fn func(*Tx) error) error {
|
|
tx, err := db.Begin()
|
|
if err != nil {
|
|
return err
|
|
}
|
|
|
|
if err := fn(tx); err != nil {
|
|
tx.Rollback()
|
|
return err
|
|
}
|
|
|
|
return tx.Commit()
|
|
}
|
|
|
|
// Tx represents a database transaction
|
|
type Tx struct {
|
|
db *DB
|
|
}
|
|
|
|
// Exec executes a SQL statement within the transaction
|
|
func (tx *Tx) Exec(query string, args ...any) error {
|
|
return tx.db.Exec(query, args...)
|
|
}
|
|
|
|
// Query executes a SQL query within the transaction
|
|
func (tx *Tx) Query(query string, fn func(*sqlite.Stmt) error, args ...any) error {
|
|
return tx.db.Query(query, fn, args...)
|
|
}
|
|
|
|
// Commit commits the transaction
|
|
func (tx *Tx) Commit() error {
|
|
return sqlitex.ExecuteTransient(tx.db.conn, "COMMIT", nil)
|
|
}
|
|
|
|
// Rollback rolls back the transaction
|
|
func (tx *Tx) Rollback() error {
|
|
return sqlitex.ExecuteTransient(tx.db.conn, "ROLLBACK", nil)
|
|
} |