169 lines
3.9 KiB
Go

package database
import (
"context"
"fmt"
"runtime"
"zombiezen.com/go/sqlite"
"zombiezen.com/go/sqlite/sqlitex"
)
const (
	// DefaultPath is the SQLite database file that Open falls back to when it
	// is called with an empty path.
	DefaultPath = "dk.db"
)
// DB is a handle to a SQLite database backed by a sqlitex connection pool.
// Exec and Query take a connection from the pool for the duration of a single
// call; a transaction from Begin keeps its connection until Commit or Rollback.
type DB struct {
	pool *sqlitex.Pool // source of connections for every operation
}
// Open opens the SQLite database at path, creating the file if needed, and
// returns a DB backed by a connection pool. An empty path selects DefaultPath.
//
// The pool holds max(GOMAXPROCS, 2) connections. Every pooled connection is
// configured by configureConn (journal_mode=WAL, synchronous=NORMAL) through
// the pool's PrepareConn hook; synchronous is a per-connection setting, so
// configuring only one connection would leave the others at SQLite's default.
//
// One connection is taken and immediately returned before Open returns, so
// that open or configuration failures are reported here rather than on the
// first query.
func Open(path string) (*DB, error) {
	if path == "" {
		path = DefaultPath
	}
	pool, err := sqlitex.NewPool(path, sqlitex.PoolOptions{
		PoolSize:    max(runtime.GOMAXPROCS(0), 2),
		Flags:       sqlite.OpenCreate | sqlite.OpenReadWrite | sqlite.OpenWAL,
		PrepareConn: configureConn,
	})
	if err != nil {
		return nil, fmt.Errorf("failed to open database pool: %w", err)
	}
	conn, err := pool.Take(context.Background())
	if err != nil {
		pool.Close()
		return nil, fmt.Errorf("failed to get connection from pool: %w", err)
	}
	// Release the probe connection right away: pool.Close must never run
	// while this function still holds a connection taken from the pool.
	pool.Put(conn)
	return &DB{pool: pool}, nil
}

// configureConn applies the per-connection pragmas that Open relies on; it is
// installed as the pool's PrepareConn hook so every pooled connection gets it.
func configureConn(conn *sqlite.Conn) error {
	if err := sqlitex.ExecuteTransient(conn, "PRAGMA journal_mode = WAL", nil); err != nil {
		return fmt.Errorf("failed to set WAL mode: %w", err)
	}
	if err := sqlitex.ExecuteTransient(conn, "PRAGMA synchronous = NORMAL", nil); err != nil {
		return fmt.Errorf("failed to set synchronous mode: %w", err)
	}
	return nil
}
// Close shuts down the connection pool behind db and reports any error the
// pool returns while doing so.
// NOTE(review): sqlitex documents Pool.Close as waiting for taken connections
// to come back, so an unfinished Tx may block this call — confirm.
func (db *DB) Close() error {
	if err := db.pool.Close(); err != nil {
		return err
	}
	return nil
}
// Exec runs query on a connection borrowed from the pool, with no result
// callback. Any args are bound to the statement in order; with no args the
// statement is executed without ExecOptions. The connection is returned to the
// pool before Exec returns.
func (db *DB) Exec(query string, args ...any) error {
	c, err := db.pool.Take(context.Background())
	if err != nil {
		return fmt.Errorf("failed to get connection from pool: %w", err)
	}
	defer db.pool.Put(c)

	var opts *sqlitex.ExecOptions
	if len(args) > 0 {
		opts = &sqlitex.ExecOptions{Args: args}
	}
	return sqlitex.ExecuteTransient(c, query, opts)
}
// Query runs query on a connection borrowed from the pool and invokes fn via
// sqlitex's ResultFunc for each result row. Any args are bound to the
// statement in order. An error from fn, or from SQLite, is returned by
// sqlitex.ExecuteTransient. The connection goes back to the pool on return.
func (db *DB) Query(query string, fn func(*sqlite.Stmt) error, args ...any) error {
	c, err := db.pool.Take(context.Background())
	if err != nil {
		return fmt.Errorf("failed to get connection from pool: %w", err)
	}
	defer db.pool.Put(c)

	opts := sqlitex.ExecOptions{ResultFunc: fn}
	if len(args) > 0 {
		opts.Args = args
	}
	return sqlitex.ExecuteTransient(c, query, &opts)
}
// Begin takes a connection from the pool and issues a plain BEGIN on it
// (SQLite's default DEFERRED mode). The returned Tx owns that connection until
// Commit or Rollback. If BEGIN fails, the connection is handed straight back
// to the pool.
// NOTE(review): write-heavy transactions may want BEGIN IMMEDIATE to avoid
// SQLITE_BUSY when a deferred transaction upgrades to a write lock.
func (db *DB) Begin() (*Tx, error) {
	c, err := db.pool.Take(context.Background())
	if err != nil {
		return nil, fmt.Errorf("failed to get connection from pool: %w", err)
	}
	if beginErr := sqlitex.ExecuteTransient(c, "BEGIN", nil); beginErr != nil {
		// The connection never became part of a Tx, so release it here.
		db.pool.Put(c)
		return nil, fmt.Errorf("failed to begin transaction: %w", beginErr)
	}
	return &Tx{conn: c, pool: db.pool}, nil
}
// Transaction runs fn inside a transaction obtained from Begin.
//
// If fn returns nil, the transaction is committed and Commit's result is
// returned. Otherwise the transaction is rolled back and fn's error is
// returned unchanged. The rollback error is intentionally dropped so the
// caller sees the error that caused it. If fn panics or exits the goroutine,
// the transaction is still rolled back before the panic propagates, so its
// pooled connection is not leaked with an open transaction.
func (db *DB) Transaction(fn func(*Tx) error) error {
	tx, err := db.Begin()
	if err != nil {
		return err
	}
	// committing is set only once Commit is about to take over the
	// connection; every other way out of this function must roll back.
	committing := false
	defer func() {
		if !committing {
			_ = tx.Rollback()
		}
	}()
	if err := fn(tx); err != nil {
		return err
	}
	committing = true
	return tx.Commit()
}
// Tx is a transaction started by DB.Begin. It holds one connection taken from
// pool until Commit or Rollback puts it back.
type Tx struct {
	conn *sqlite.Conn  // connection the transaction's statements run on
	pool *sqlitex.Pool // pool that conn is returned to when the Tx finishes
}
// Exec runs query on the transaction's connection, with no result callback.
// Any args are bound to the statement in order; with no args the statement
// is executed without ExecOptions.
func (tx *Tx) Exec(query string, args ...any) error {
	var opts *sqlitex.ExecOptions
	if len(args) > 0 {
		opts = &sqlitex.ExecOptions{Args: args}
	}
	return sqlitex.ExecuteTransient(tx.conn, query, opts)
}
// Query runs query on the transaction's connection and invokes fn via
// sqlitex's ResultFunc for each result row. Any args are bound to the
// statement in order.
func (tx *Tx) Query(query string, fn func(*sqlite.Stmt) error, args ...any) error {
	opts := sqlitex.ExecOptions{ResultFunc: fn}
	if len(args) > 0 {
		opts.Args = args
	}
	return sqlitex.ExecuteTransient(tx.conn, query, &opts)
}
// Commit commits the transaction and returns its connection to the pool.
//
// SQLite can leave a transaction active when COMMIT fails (for example
// SQLITE_BUSY, or a violated deferred foreign-key constraint). In that case
// Commit issues a best-effort ROLLBACK before releasing the connection, so the
// next user of that pooled connection does not inherit an open transaction.
// The COMMIT error is returned.
//
// The connection is released exactly once. Calling Commit or Rollback again
// after the transaction has finished returns an error instead of putting the
// same connection into the pool twice.
func (tx *Tx) Commit() error {
	conn := tx.conn
	if conn == nil {
		return fmt.Errorf("transaction already committed or rolled back")
	}
	tx.conn = nil
	defer tx.pool.Put(conn)
	if err := sqlitex.ExecuteTransient(conn, "COMMIT", nil); err != nil {
		// Best effort: ROLLBACK itself fails if the failed COMMIT already
		// ended the transaction, and the COMMIT error is the one to report.
		_ = sqlitex.ExecuteTransient(conn, "ROLLBACK", nil)
		return err
	}
	return nil
}
// Rollback rolls back the transaction and returns its connection to the pool.
//
// The connection is released exactly once. Calling Rollback (or Commit) after
// the transaction has already finished returns an error instead of putting the
// same connection into the pool twice. This makes the `defer tx.Rollback()`
// after a successful Commit pattern safe.
func (tx *Tx) Rollback() error {
	conn := tx.conn
	if conn == nil {
		return fmt.Errorf("transaction already committed or rolled back")
	}
	tx.conn = nil
	defer tx.pool.Put(conn)
	return sqlitex.ExecuteTransient(conn, "ROLLBACK", nil)
}