// Package database provides a small wrapper around a zombiezen SQLite
// connection pool with simplified Exec, Query, and transaction helpers.
package database

import (
	"context"
	"errors"
	"fmt"
	"runtime"

	"zombiezen.com/go/sqlite"
	"zombiezen.com/go/sqlite/sqlitex"
)

// DefaultPath is the database file used when Open is given an empty path.
const DefaultPath = "dk.db"

// ErrTxDone is returned by Tx methods that are called after the transaction
// has already been committed or rolled back.
var ErrTxDone = errors.New("transaction has already been committed or rolled back")

// DB wraps a SQLite connection pool with simplified methods.
type DB struct {
	pool *sqlitex.Pool
}

// Open creates a new database connection pool for the file at path.
// An empty path selects DefaultPath. The pool holds
// max(runtime.GOMAXPROCS(0), 2) connections, and every connection is
// configured with WAL journaling and synchronous = NORMAL.
//
// Open takes one connection before returning so that open or configuration
// failures are reported here rather than on first use.
func Open(path string) (*DB, error) {
	if path == "" {
		path = DefaultPath
	}

	pool, err := sqlitex.NewPool(path, sqlitex.PoolOptions{
		PoolSize: max(runtime.GOMAXPROCS(0), 2),
		Flags:    sqlite.OpenCreate | sqlite.OpenReadWrite | sqlite.OpenWAL,
		// PRAGMA synchronous is a per-connection setting, so it has to be
		// applied to every connection the pool opens, not just the first.
		PrepareConn: func(conn *sqlite.Conn) error {
			if err := sqlitex.ExecuteTransient(conn, "PRAGMA journal_mode = WAL", nil); err != nil {
				return fmt.Errorf("failed to set WAL mode: %w", err)
			}
			if err := sqlitex.ExecuteTransient(conn, "PRAGMA synchronous = NORMAL", nil); err != nil {
				return fmt.Errorf("failed to set synchronous mode: %w", err)
			}
			return nil
		},
	})
	if err != nil {
		return nil, fmt.Errorf("failed to open database pool: %w", err)
	}

	conn, err := pool.Take(context.Background())
	if err != nil {
		// No connection is outstanding here, so Close cannot block.
		// Its error is secondary to the one being returned.
		_ = pool.Close()
		return nil, fmt.Errorf("failed to get connection from pool: %w", err)
	}
	// Return the connection immediately: Pool.Close waits for every taken
	// connection, so nothing may be held across a later Close call.
	pool.Put(conn)

	return &DB{pool: pool}, nil
}

// Close closes the database connection pool.
func (db *DB) Close() error {
	return db.pool.Close()
}

// Exec executes a SQL statement on a pooled connection without returning
// results. args are bound to the statement's positional parameters.
func (db *DB) Exec(query string, args ...any) error {
	conn, err := db.pool.Take(context.Background())
	if err != nil {
		return fmt.Errorf("failed to get connection from pool: %w", err)
	}
	defer db.pool.Put(conn)

	return sqlitex.ExecuteTransient(conn, query, &sqlitex.ExecOptions{Args: args})
}

// Query executes a SQL query on a pooled connection and calls fn for each
// result row. args are bound to the statement's positional parameters.
func (db *DB) Query(query string, fn func(*sqlite.Stmt) error, args ...any) error {
	conn, err := db.pool.Take(context.Background())
	if err != nil {
		return fmt.Errorf("failed to get connection from pool: %w", err)
	}
	defer db.pool.Put(conn)

	return sqlitex.ExecuteTransient(conn, query, &sqlitex.ExecOptions{
		Args:       args,
		ResultFunc: fn,
	})
}

// Begin starts a new transaction on a connection taken from the pool.
// The connection stays checked out until the returned Tx is finished with
// Commit or Rollback, so callers must always call one of them.
func (db *DB) Begin() (*Tx, error) {
	conn, err := db.pool.Take(context.Background())
	if err != nil {
		return nil, fmt.Errorf("failed to get connection from pool: %w", err)
	}

	if err := sqlitex.ExecuteTransient(conn, "BEGIN", nil); err != nil {
		db.pool.Put(conn)
		return nil, fmt.Errorf("failed to begin transaction: %w", err)
	}

	return &Tx{conn: conn, pool: db.pool}, nil
}

// Transaction runs fn within a transaction. The transaction is committed
// when fn returns nil and rolled back when fn returns an error or panics.
// If the rollback itself fails, its error is joined with fn's error.
func (db *DB) Transaction(fn func(*Tx) error) error {
	tx, err := db.Begin()
	if err != nil {
		return err
	}
	// Safety net for a panic in fn: without it the connection would never
	// be returned to the pool. On the normal paths the transaction is
	// already finished and this call just reports ErrTxDone.
	defer func() { _ = tx.Rollback() }()

	if err := fn(tx); err != nil {
		if rbErr := tx.Rollback(); rbErr != nil && !errors.Is(rbErr, ErrTxDone) {
			return errors.Join(err, fmt.Errorf("failed to roll back transaction: %w", rbErr))
		}
		return err
	}

	return tx.Commit()
}

// Tx represents a database transaction bound to a single pooled connection.
// A Tx is finished by exactly one successful call to Commit or Rollback;
// afterwards every method returns ErrTxDone.
type Tx struct {
	conn *sqlite.Conn // nil once the transaction has been finished
	pool *sqlitex.Pool
}

// Exec executes a SQL statement within the transaction.
// It returns ErrTxDone if the transaction has already been finished.
func (tx *Tx) Exec(query string, args ...any) error {
	if tx.conn == nil {
		return ErrTxDone
	}
	return sqlitex.ExecuteTransient(tx.conn, query, &sqlitex.ExecOptions{Args: args})
}

// Query executes a SQL query within the transaction and calls fn for each
// result row. It returns ErrTxDone if the transaction has already been
// finished.
func (tx *Tx) Query(query string, fn func(*sqlite.Stmt) error, args ...any) error {
	if tx.conn == nil {
		return ErrTxDone
	}
	return sqlitex.ExecuteTransient(tx.conn, query, &sqlitex.ExecOptions{
		Args:       args,
		ResultFunc: fn,
	})
}

// Commit commits the transaction and returns its connection to the pool.
// It returns ErrTxDone if the transaction has already been finished.
func (tx *Tx) Commit() error {
	conn, err := tx.detach()
	if err != nil {
		return err
	}
	defer tx.pool.Put(conn)

	if err := sqlitex.ExecuteTransient(conn, "COMMIT", nil); err != nil {
		// A failed COMMIT can leave the transaction open. Roll back on a
		// best-effort basis so the pool does not hand out a connection
		// that is still inside this transaction.
		_ = sqlitex.ExecuteTransient(conn, "ROLLBACK", nil)
		return err
	}
	return nil
}

// Rollback rolls back the transaction and returns its connection to the
// pool. It returns ErrTxDone if the transaction has already been finished,
// which makes `defer tx.Rollback()` safe after a Commit.
func (tx *Tx) Rollback() error {
	conn, err := tx.detach()
	if err != nil {
		return err
	}
	defer tx.pool.Put(conn)

	return sqlitex.ExecuteTransient(conn, "ROLLBACK", nil)
}

// detach marks the transaction as finished and hands its connection to the
// caller, who becomes responsible for returning it to the pool.
func (tx *Tx) detach() (*sqlite.Conn, error) {
	if tx.conn == nil {
		return nil, ErrTxDone
	}
	conn := tx.conn
	tx.conn = nil
	return conn, nil
}