// Package database provides a process-wide SQLite connection pool and thin
// helpers for executing statements, running queries, and managing
// transactions on top of zombiezen.com/go/sqlite.
package database

import (
	"context"
	"errors"
	"fmt"
	"runtime"

	"zombiezen.com/go/sqlite"
	"zombiezen.com/go/sqlite/sqlitex"
)

// DefaultPath is the database file used by Init when it is given an empty path.
const DefaultPath = "dk.db"

var (
	// ErrNotInitialized is returned by the package-level helpers when Init has
	// not completed successfully, or Close has been called since.
	ErrNotInitialized = errors.New("database not initialized")

	// ErrTxDone is returned by Tx methods once the transaction has already been
	// committed or rolled back and its connection returned to the pool.
	ErrTxDone = errors.New("transaction has already been committed or rolled back")
)

// pool is the global singleton connection pool; it is nil until Init succeeds.
// Access is not synchronized, so Init and Close should not run concurrently
// with other calls into this package.
var pool *sqlitex.Pool

// Init opens the global connection pool at path (DefaultPath if path is empty).
//
// Every connection in the pool is configured by prepareConn. One connection is
// checked out and immediately returned so that open/configuration errors are
// reported here rather than on first use. On failure the global pool is left
// untouched. Calling Init again replaces the global pool without closing the
// previous one; call Close first.
func Init(path string) error {
	if path == "" {
		path = DefaultPath
	}

	// Size the pool to the available parallelism, with a floor of 2.
	p, err := sqlitex.NewPool(path, sqlitex.PoolOptions{
		PoolSize:    max(runtime.GOMAXPROCS(0), 2),
		Flags:       sqlite.OpenCreate | sqlite.OpenReadWrite | sqlite.OpenWAL,
		PrepareConn: prepareConn,
	})
	if err != nil {
		return fmt.Errorf("failed to open database pool: %w", err)
	}

	conn, err := p.Take(context.Background())
	if err != nil {
		// Best-effort cleanup: no connection is checked out, so Close cannot
		// block; the Take error is the one worth reporting.
		_ = p.Close()
		return fmt.Errorf("failed to get connection from pool: %w", err)
	}
	// Return the connection explicitly (not via defer): Pool.Close blocks until
	// every checked-out connection has been put back, so holding it across a
	// Close would deadlock.
	p.Put(conn)

	pool = p
	return nil
}

// prepareConn applies per-connection settings. It is installed as the pool's
// PrepareConn hook so it runs on every connection, not just the first one.
//
// SQLite's synchronous setting is per-connection and must be applied to each
// connection; journal_mode=WAL is persisted in the database file, so issuing
// it again on later connections leaves the mode unchanged.
func prepareConn(conn *sqlite.Conn) error {
	if err := sqlitex.ExecuteTransient(conn, "PRAGMA journal_mode = WAL", nil); err != nil {
		return fmt.Errorf("failed to set WAL mode: %w", err)
	}
	if err := sqlitex.ExecuteTransient(conn, "PRAGMA synchronous = NORMAL", nil); err != nil {
		return fmt.Errorf("failed to set synchronous mode: %w", err)
	}
	return nil
}

// Close closes the global connection pool and resets the package to the
// uninitialized state so Init may be called again. It is a no-op returning nil
// when no pool is open.
func Close() error {
	if pool == nil {
		return nil
	}
	err := pool.Close()
	pool = nil
	return err
}

// GetConn checks out a connection from the global pool; ctx is passed to
// Pool.Take. The caller must return the connection with PutConn when done.
func GetConn(ctx context.Context) (*sqlite.Conn, error) {
	if pool == nil {
		return nil, ErrNotInitialized
	}
	return pool.Take(ctx)
}

// PutConn returns a connection obtained from GetConn to the global pool.
// It does nothing if the pool is not initialized.
func PutConn(conn *sqlite.Conn) {
	if pool != nil {
		pool.Put(conn)
	}
}

// take checks out a connection from the global pool, returning the pool it
// came from so the caller can Put it back to the same place. It waits for a
// free connection without a deadline (context.Background).
func take() (*sqlitex.Pool, *sqlite.Conn, error) {
	p := pool
	if p == nil {
		return nil, nil, ErrNotInitialized
	}
	conn, err := p.Take(context.Background())
	if err != nil {
		return nil, nil, fmt.Errorf("failed to get connection from pool: %w", err)
	}
	return p, conn, nil
}

// Exec executes a SQL statement on a pooled connection, binding args to its
// positional parameters, and discards any result rows.
func Exec(query string, args ...any) error {
	p, conn, err := take()
	if err != nil {
		return err
	}
	defer p.Put(conn)

	return sqlitex.ExecuteTransient(conn, query, &sqlitex.ExecOptions{Args: args})
}

// Query executes a SQL query on a pooled connection, binding args to its
// positional parameters, and calls fn once per result row. An error returned
// by fn is returned from Query.
func Query(query string, fn func(*sqlite.Stmt) error, args ...any) error {
	p, conn, err := take()
	if err != nil {
		return err
	}
	defer p.Put(conn)

	return sqlitex.ExecuteTransient(conn, query, &sqlitex.ExecOptions{
		Args:       args,
		ResultFunc: fn,
	})
}

// Begin checks out a connection and starts a transaction on it with a plain
// BEGIN (SQLite's default DEFERRED mode). The caller must finish it with
// Commit or Rollback, which return the connection to the pool.
func Begin() (*Tx, error) {
	p, conn, err := take()
	if err != nil {
		return nil, err
	}
	if err := sqlitex.ExecuteTransient(conn, "BEGIN", nil); err != nil {
		p.Put(conn)
		return nil, fmt.Errorf("failed to begin transaction: %w", err)
	}
	return &Tx{conn: conn, pool: p}, nil
}

// Transaction runs fn inside a transaction. The transaction is committed when
// fn returns nil; if fn returns an error or panics it is rolled back and the
// error (or panic) propagates. The connection is always returned to the pool.
func Transaction(fn func(*Tx) error) error {
	tx, err := Begin()
	if err != nil {
		return err
	}
	// Covers both an error return and a panic in fn. Once Commit has run this
	// is a no-op returning ErrTxDone, so its error is deliberately ignored.
	defer func() { _ = tx.Rollback() }()

	if err := fn(tx); err != nil {
		return err
	}
	return tx.Commit()
}

// Tx is a transaction bound to a single pooled connection. It is not safe for
// concurrent use. After Commit or Rollback the connection has been returned to
// the pool and every method returns ErrTxDone.
type Tx struct {
	conn *sqlite.Conn  // nil once the transaction has finished
	pool *sqlitex.Pool // pool that conn must be returned to
}

// Exec executes a SQL statement within the transaction, binding args to its
// positional parameters and discarding any result rows.
func (tx *Tx) Exec(query string, args ...any) error {
	if tx.conn == nil {
		return ErrTxDone
	}
	return sqlitex.ExecuteTransient(tx.conn, query, &sqlitex.ExecOptions{Args: args})
}

// Query executes a SQL query within the transaction, binding args to its
// positional parameters and calling fn once per result row.
func (tx *Tx) Query(query string, fn func(*sqlite.Stmt) error, args ...any) error {
	if tx.conn == nil {
		return ErrTxDone
	}
	return sqlitex.ExecuteTransient(tx.conn, query, &sqlitex.ExecOptions{
		Args:       args,
		ResultFunc: fn,
	})
}

// Commit commits the transaction and returns the connection to the pool.
// If COMMIT fails, the transaction is rolled back before the connection is
// released, so the pool never receives a connection with an open transaction.
func (tx *Tx) Commit() error {
	if tx.conn == nil {
		return ErrTxDone
	}
	defer tx.release()

	if err := sqlitex.ExecuteTransient(tx.conn, "COMMIT", nil); err != nil {
		// A failed COMMIT (e.g. SQLITE_BUSY) can leave the transaction active.
		// Best-effort ROLLBACK; it may itself fail if SQLite already ended the
		// transaction, and the COMMIT error is the one to report.
		_ = sqlitex.ExecuteTransient(tx.conn, "ROLLBACK", nil)
		return err
	}
	return nil
}

// Rollback rolls back the transaction and returns the connection to the pool.
// Calling it after Commit or a previous Rollback returns ErrTxDone without
// touching the (already released) connection, so `defer tx.Rollback()` is safe.
func (tx *Tx) Rollback() error {
	if tx.conn == nil {
		return ErrTxDone
	}
	defer tx.release()

	return sqlitex.ExecuteTransient(tx.conn, "ROLLBACK", nil)
}

// release returns the transaction's connection to its pool and marks the
// transaction as finished.
func (tx *Tx) release() {
	tx.pool.Put(tx.conn)
	tx.conn = nil
}