package database import ( "fmt" "zombiezen.com/go/sqlite" "zombiezen.com/go/sqlite/sqlitex" ) // DB wraps sqlite.Conn with simplified query methods type DB struct { conn *sqlite.Conn } // Row represents a single database row with easy column access type Row struct { stmt *sqlite.Stmt } // QueryFunc processes each row in a result set type QueryFunc func(*Row) error // Open creates a new database connection with common settings func Open(path string) (*DB, error) { conn, err := sqlite.OpenConn(path, sqlite.OpenReadWrite|sqlite.OpenCreate) if err != nil { return nil, fmt.Errorf("failed to open database: %w", err) } // Enable foreign keys and WAL mode for better performance if err := sqlitex.ExecuteTransient(conn, "PRAGMA foreign_keys = ON", nil); err != nil { conn.Close() return nil, fmt.Errorf("failed to enable foreign keys: %w", err) } if err := sqlitex.ExecuteTransient(conn, "PRAGMA journal_mode = WAL", nil); err != nil { conn.Close() return nil, fmt.Errorf("failed to enable WAL mode: %w", err) } return &DB{conn: conn}, nil } // Close closes the database connection func (db *DB) Close() error { return db.conn.Close() } // Exec executes a statement with parameters func (db *DB) Exec(query string, args ...any) error { return sqlitex.Execute(db.conn, query, &sqlitex.ExecOptions{ Args: args, }) } // QueryRow executes a query expecting a single row result func (db *DB) QueryRow(query string, args ...any) (*Row, error) { stmt, err := db.conn.Prepare(query) if err != nil { return nil, fmt.Errorf("prepare failed: %w", err) } // Bind parameters for i, arg := range args { if err := bindParam(stmt, i+1, arg); err != nil { stmt.Finalize() return nil, err } } hasRow, err := stmt.Step() if err != nil { stmt.Finalize() return nil, fmt.Errorf("query failed: %w", err) } if !hasRow { stmt.Finalize() return nil, nil // No row found } return &Row{stmt: stmt}, nil } // Query executes a query and calls fn for each row func (db *DB) Query(query string, fn QueryFunc, args ...any) error { stmt, err := db.conn.Prepare(query) if err != nil { return fmt.Errorf("prepare failed: %w", err) } defer stmt.Finalize() // Bind parameters for i, arg := range args { if err := bindParam(stmt, i+1, arg); err != nil { return err } } row := &Row{stmt: stmt} for { hasRow, err := stmt.Step() if err != nil { return fmt.Errorf("query failed: %w", err) } if !hasRow { break } if err := fn(row); err != nil { return err } } return nil } // QuerySlice executes a query and returns all rows in a slice func (db *DB) QuerySlice(query string, args ...any) ([]*Row, error) { var rows []*Row stmt, err := db.conn.Prepare(query) if err != nil { return nil, fmt.Errorf("prepare failed: %w", err) } defer stmt.Finalize() // Bind parameters for i, arg := range args { if err := bindParam(stmt, i+1, arg); err != nil { return nil, err } } for { hasRow, err := stmt.Step() if err != nil { return nil, fmt.Errorf("query failed: %w", err) } if !hasRow { break } // Create a snapshot of the current row rowData := &Row{stmt: stmt} rows = append(rows, rowData) } return rows, nil } // LastInsertID returns the last inserted row ID func (db *DB) LastInsertID() int64 { return db.conn.LastInsertRowID() } // Changes returns the number of rows affected by the last statement func (db *DB) Changes() int { return db.conn.Changes() } // Transaction executes fn within a database transaction func (db *DB) Transaction(fn func(*DB) error) error { if err := sqlitex.ExecuteTransient(db.conn, "BEGIN", nil); err != nil { return fmt.Errorf("begin transaction failed: %w", err) } if err := fn(db); err != nil { sqlitex.ExecuteTransient(db.conn, "ROLLBACK", nil) return err } if err := sqlitex.ExecuteTransient(db.conn, "COMMIT", nil); err != nil { return fmt.Errorf("commit transaction failed: %w", err) } return nil } // Row column access methods // Close releases the row's statement func (r *Row) Close() { if r.stmt != nil { r.stmt.Finalize() r.stmt = nil } } // Int returns column as int func (r *Row) Int(col int) int { return r.stmt.ColumnInt(col) } // Int64 returns column as int64 func (r *Row) Int64(col int) int64 { return r.stmt.ColumnInt64(col) } // Text returns column as string func (r *Row) Text(col int) string { return r.stmt.ColumnText(col) } // Bool returns column as bool (0 = false, non-zero = true) func (r *Row) Bool(col int) bool { return r.stmt.ColumnInt(col) != 0 } // Float returns column as float64 func (r *Row) Float(col int) float64 { return r.stmt.ColumnFloat(col) } // IsNull checks if column is NULL func (r *Row) IsNull(col int) bool { return r.stmt.ColumnType(col) == sqlite.TypeNull } // bindParam binds a parameter to a statement at the given index func bindParam(stmt *sqlite.Stmt, index int, value any) error { switch v := value.(type) { case nil: stmt.BindNull(index) case int: stmt.BindInt64(index, int64(v)) case int8: stmt.BindInt64(index, int64(v)) case int16: stmt.BindInt64(index, int64(v)) case int32: stmt.BindInt64(index, int64(v)) case int64: stmt.BindInt64(index, v) case uint: stmt.BindInt64(index, int64(v)) case uint8: stmt.BindInt64(index, int64(v)) case uint16: stmt.BindInt64(index, int64(v)) case uint32: stmt.BindInt64(index, int64(v)) case uint64: stmt.BindInt64(index, int64(v)) case float32: stmt.BindFloat(index, float64(v)) case float64: stmt.BindFloat(index, v) case bool: if v { stmt.BindInt64(index, 1) } else { stmt.BindInt64(index, 0) } case string: stmt.BindText(index, v) case []byte: stmt.BindBytes(index, v) default: return fmt.Errorf("unsupported parameter type: %T", value) } return nil }