169 lines
3.9 KiB
Go
169 lines
3.9 KiB
Go
package database
|
|
|
|
import (
|
|
"context"
|
|
"fmt"
|
|
"runtime"
|
|
|
|
"zombiezen.com/go/sqlite"
|
|
"zombiezen.com/go/sqlite/sqlitex"
|
|
)
|
|
|
|
// DefaultPath is the database file that Open uses when it is given an empty path.
const DefaultPath = "dk.db"
|
|
|
|
// DB wraps a SQLite connection pool with simplified methods
|
|
type DB struct {
|
|
pool *sqlitex.Pool
|
|
}
|
|
|
|
// Open creates a new database connection pool
|
|
func Open(path string) (*DB, error) {
|
|
if path == "" {
|
|
path = DefaultPath
|
|
}
|
|
|
|
poolSize := max(runtime.GOMAXPROCS(0), 2)
|
|
|
|
pool, err := sqlitex.NewPool(path, sqlitex.PoolOptions{
|
|
PoolSize: poolSize,
|
|
Flags: sqlite.OpenCreate | sqlite.OpenReadWrite | sqlite.OpenWAL,
|
|
})
|
|
if err != nil {
|
|
return nil, fmt.Errorf("failed to open database pool: %w", err)
|
|
}
|
|
|
|
conn, err := pool.Take(context.Background())
|
|
if err != nil {
|
|
pool.Close()
|
|
return nil, fmt.Errorf("failed to get connection from pool: %w", err)
|
|
}
|
|
defer pool.Put(conn)
|
|
|
|
if err := sqlitex.ExecuteTransient(conn, "PRAGMA journal_mode = WAL", nil); err != nil {
|
|
pool.Close()
|
|
return nil, fmt.Errorf("failed to set WAL mode: %w", err)
|
|
}
|
|
|
|
if err := sqlitex.ExecuteTransient(conn, "PRAGMA synchronous = NORMAL", nil); err != nil {
|
|
pool.Close()
|
|
return nil, fmt.Errorf("failed to set synchronous mode: %w", err)
|
|
}
|
|
|
|
return &DB{pool: pool}, nil
|
|
}
|
|
|
|
// Close closes the database connection pool
|
|
func (db *DB) Close() error {
|
|
return db.pool.Close()
|
|
}
|
|
|
|
// Exec executes a SQL statement without returning results
|
|
func (db *DB) Exec(query string, args ...any) error {
|
|
conn, err := db.pool.Take(context.Background())
|
|
if err != nil {
|
|
return fmt.Errorf("failed to get connection from pool: %w", err)
|
|
}
|
|
defer db.pool.Put(conn)
|
|
|
|
if len(args) == 0 {
|
|
return sqlitex.ExecuteTransient(conn, query, nil)
|
|
}
|
|
|
|
return sqlitex.ExecuteTransient(conn, query, &sqlitex.ExecOptions{
|
|
Args: args,
|
|
})
|
|
}
|
|
|
|
// Query executes a SQL query and calls fn for each row
|
|
func (db *DB) Query(query string, fn func(*sqlite.Stmt) error, args ...any) error {
|
|
conn, err := db.pool.Take(context.Background())
|
|
if err != nil {
|
|
return fmt.Errorf("failed to get connection from pool: %w", err)
|
|
}
|
|
defer db.pool.Put(conn)
|
|
|
|
if len(args) == 0 {
|
|
return sqlitex.ExecuteTransient(conn, query, &sqlitex.ExecOptions{
|
|
ResultFunc: fn,
|
|
})
|
|
}
|
|
|
|
return sqlitex.ExecuteTransient(conn, query, &sqlitex.ExecOptions{
|
|
Args: args,
|
|
ResultFunc: fn,
|
|
})
|
|
}
|
|
|
|
// Begin starts a new transaction
|
|
func (db *DB) Begin() (*Tx, error) {
|
|
conn, err := db.pool.Take(context.Background())
|
|
if err != nil {
|
|
return nil, fmt.Errorf("failed to get connection from pool: %w", err)
|
|
}
|
|
|
|
if err := sqlitex.ExecuteTransient(conn, "BEGIN", nil); err != nil {
|
|
db.pool.Put(conn)
|
|
return nil, fmt.Errorf("failed to begin transaction: %w", err)
|
|
}
|
|
|
|
return &Tx{conn: conn, pool: db.pool}, nil
|
|
}
|
|
|
|
// Transaction runs a function within a transaction
|
|
func (db *DB) Transaction(fn func(*Tx) error) error {
|
|
tx, err := db.Begin()
|
|
if err != nil {
|
|
return err
|
|
}
|
|
|
|
if err := fn(tx); err != nil {
|
|
tx.Rollback()
|
|
return err
|
|
}
|
|
|
|
return tx.Commit()
|
|
}
|
|
|
|
// Tx represents a database transaction
|
|
type Tx struct {
|
|
conn *sqlite.Conn
|
|
pool *sqlitex.Pool
|
|
}
|
|
|
|
// Exec executes a SQL statement within the transaction
|
|
func (tx *Tx) Exec(query string, args ...any) error {
|
|
if len(args) == 0 {
|
|
return sqlitex.ExecuteTransient(tx.conn, query, nil)
|
|
}
|
|
|
|
return sqlitex.ExecuteTransient(tx.conn, query, &sqlitex.ExecOptions{
|
|
Args: args,
|
|
})
|
|
}
|
|
|
|
// Query executes a SQL query within the transaction
|
|
func (tx *Tx) Query(query string, fn func(*sqlite.Stmt) error, args ...any) error {
|
|
if len(args) == 0 {
|
|
return sqlitex.ExecuteTransient(tx.conn, query, &sqlitex.ExecOptions{
|
|
ResultFunc: fn,
|
|
})
|
|
}
|
|
|
|
return sqlitex.ExecuteTransient(tx.conn, query, &sqlitex.ExecOptions{
|
|
Args: args,
|
|
ResultFunc: fn,
|
|
})
|
|
}
|
|
|
|
// Commit commits the transaction
|
|
func (tx *Tx) Commit() error {
|
|
defer tx.pool.Put(tx.conn)
|
|
return sqlitex.ExecuteTransient(tx.conn, "COMMIT", nil)
|
|
}
|
|
|
|
// Rollback rolls back the transaction
|
|
func (tx *Tx) Rollback() error {
|
|
defer tx.pool.Put(tx.conn)
|
|
return sqlitex.ExecuteTransient(tx.conn, "ROLLBACK", nil)
|
|
}
|