package database import ( "context" "fmt" "runtime" "zombiezen.com/go/sqlite" "zombiezen.com/go/sqlite/sqlitex" ) const DefaultPath = "dk.db" // database wraps a SQLite connection pool with simplified methods type database struct { pool *sqlitex.Pool } // DB is a backward-compatible type alias type DB = database // instance is the global singleton instance var instance *database // Open creates a new database connection pool func Open(path string) (*database, error) { if path == "" { path = DefaultPath } poolSize := max(runtime.GOMAXPROCS(0), 2) pool, err := sqlitex.NewPool(path, sqlitex.PoolOptions{ PoolSize: poolSize, Flags: sqlite.OpenCreate | sqlite.OpenReadWrite | sqlite.OpenWAL, }) if err != nil { return nil, fmt.Errorf("failed to open database pool: %w", err) } conn, err := pool.Take(context.Background()) if err != nil { pool.Close() return nil, fmt.Errorf("failed to get connection from pool: %w", err) } defer pool.Put(conn) if err := sqlitex.ExecuteTransient(conn, "PRAGMA journal_mode = WAL", nil); err != nil { pool.Close() return nil, fmt.Errorf("failed to set WAL mode: %w", err) } if err := sqlitex.ExecuteTransient(conn, "PRAGMA synchronous = NORMAL", nil); err != nil { pool.Close() return nil, fmt.Errorf("failed to set synchronous mode: %w", err) } return &database{pool: pool}, nil } // Close closes the database connection pool func (db *database) Close() error { return db.pool.Close() } // GetConn gets a connection from the pool - caller must call Put when done func (db *database) GetConn(ctx context.Context) (*sqlite.Conn, error) { return db.pool.Take(ctx) } // PutConn returns a connection to the pool func (db *database) PutConn(conn *sqlite.Conn) { db.pool.Put(conn) } // Exec executes a SQL statement without returning results func (db *database) Exec(query string, args ...any) error { conn, err := db.pool.Take(context.Background()) if err != nil { return fmt.Errorf("failed to get connection from pool: %w", err) } defer db.pool.Put(conn) if len(args) == 0 { return sqlitex.ExecuteTransient(conn, query, nil) } return sqlitex.ExecuteTransient(conn, query, &sqlitex.ExecOptions{ Args: args, }) } // Query executes a SQL query and calls fn for each row func (db *database) Query(query string, fn func(*sqlite.Stmt) error, args ...any) error { conn, err := db.pool.Take(context.Background()) if err != nil { return fmt.Errorf("failed to get connection from pool: %w", err) } defer db.pool.Put(conn) if len(args) == 0 { return sqlitex.ExecuteTransient(conn, query, &sqlitex.ExecOptions{ ResultFunc: fn, }) } return sqlitex.ExecuteTransient(conn, query, &sqlitex.ExecOptions{ Args: args, ResultFunc: fn, }) } // Begin starts a new transaction func (db *database) Begin() (*Tx, error) { conn, err := db.pool.Take(context.Background()) if err != nil { return nil, fmt.Errorf("failed to get connection from pool: %w", err) } if err := sqlitex.ExecuteTransient(conn, "BEGIN", nil); err != nil { db.pool.Put(conn) return nil, fmt.Errorf("failed to begin transaction: %w", err) } return &Tx{conn: conn, pool: db.pool}, nil } // Transaction runs a function within a transaction func (db *database) Transaction(fn func(*Tx) error) error { tx, err := db.Begin() if err != nil { return err } if err := fn(tx); err != nil { tx.Rollback() return err } return tx.Commit() } // Tx represents a database transaction type Tx struct { conn *sqlite.Conn pool *sqlitex.Pool } // Exec executes a SQL statement within the transaction func (tx *Tx) Exec(query string, args ...any) error { if len(args) == 0 { return sqlitex.ExecuteTransient(tx.conn, query, nil) } return sqlitex.ExecuteTransient(tx.conn, query, &sqlitex.ExecOptions{ Args: args, }) } // Query executes a SQL query within the transaction func (tx *Tx) Query(query string, fn func(*sqlite.Stmt) error, args ...any) error { if len(args) == 0 { return sqlitex.ExecuteTransient(tx.conn, query, &sqlitex.ExecOptions{ ResultFunc: fn, }) } return sqlitex.ExecuteTransient(tx.conn, query, &sqlitex.ExecOptions{ Args: args, ResultFunc: fn, }) } // Commit commits the transaction func (tx *Tx) Commit() error { defer tx.pool.Put(tx.conn) return sqlitex.ExecuteTransient(tx.conn, "COMMIT", nil) } // Rollback rolls back the transaction func (tx *Tx) Rollback() error { defer tx.pool.Put(tx.conn) return sqlitex.ExecuteTransient(tx.conn, "ROLLBACK", nil) } // InitializeDB initializes the global DB singleton func InitializeDB(path string) error { db, err := Open(path) if err != nil { return err } instance = db return nil } // GetDB returns the global database instance func GetDB() *DB { return instance } // Global convenience functions that use the singleton // Exec executes a SQL statement without returning results using the global DB func Exec(query string, args ...any) error { if instance == nil { return fmt.Errorf("database not initialized") } return instance.Exec(query, args...) } // Query executes a SQL query and calls fn for each row using the global DB func Query(query string, fn func(*sqlite.Stmt) error, args ...any) error { if instance == nil { return fmt.Errorf("database not initialized") } return instance.Query(query, fn, args...) } // Begin starts a new transaction using the global DB func Begin() (*Tx, error) { if instance == nil { return nil, fmt.Errorf("database not initialized") } return instance.Begin() } // Transaction runs a function within a transaction using the global DB func Transaction(fn func(*Tx) error) error { if instance == nil { return fmt.Errorf("database not initialized") } return instance.Transaction(fn) } // GetConn gets a connection from the pool using the global DB func GetConn(ctx context.Context) (*sqlite.Conn, error) { if instance == nil { return nil, fmt.Errorf("database not initialized") } return instance.GetConn(ctx) } // PutConn returns a connection to the pool using the global DB func PutConn(conn *sqlite.Conn) { if instance != nil { instance.PutConn(conn) } } // Close closes the global database connection pool func Close() error { if instance == nil { return nil } return instance.Close() }