package sql import ( "context" "fmt" "zombiezen.com/go/sqlite" "zombiezen.com/go/sqlite/sqlitex" ) // SQLiteDriver implements the Driver interface for SQLite type SQLiteDriver struct{} func (d *SQLiteDriver) Name() string { return "sqlite" } func (d *SQLiteDriver) Open(dsn string) (Connection, error) { conn, err := sqlite.OpenConn(dsn, sqlite.OpenReadWrite|sqlite.OpenCreate) if err != nil { return nil, fmt.Errorf("sqlite: failed to open database: %w", err) } return &SQLiteConnection{conn: conn}, nil } // SQLiteConnection implements the Connection interface type SQLiteConnection struct { conn *sqlite.Conn } func (c *SQLiteConnection) Close() error { return c.conn.Close() } func (c *SQLiteConnection) Ping(ctx context.Context) error { return sqlitex.ExecuteTransient(c.conn, "SELECT 1", nil) } func (c *SQLiteConnection) Begin(ctx context.Context) (Transaction, error) { if err := sqlitex.ExecuteTransient(c.conn, "BEGIN", nil); err != nil { return nil, fmt.Errorf("sqlite: failed to begin transaction: %w", err) } return &SQLiteTransaction{conn: c.conn}, nil } func (c *SQLiteConnection) Query(ctx context.Context, query string, args ...any) (Rows, error) { stmt, err := c.conn.Prepare(query) if err != nil { return nil, fmt.Errorf("sqlite: failed to prepare query: %w", err) } if err := c.bindArgs(stmt, args...); err != nil { stmt.Finalize() return nil, err } return &SQLiteRows{stmt: stmt, hasNext: true}, nil } func (c *SQLiteConnection) QueryRow(ctx context.Context, query string, args ...any) Row { rows, err := c.Query(ctx, query, args...) if err != nil { return &SQLiteRow{err: err} } return &SQLiteRow{rows: rows.(*SQLiteRows)} } func (c *SQLiteConnection) Exec(ctx context.Context, query string, args ...any) (Result, error) { stmt, err := c.conn.Prepare(query) if err != nil { return nil, fmt.Errorf("sqlite: failed to prepare statement: %w", err) } defer stmt.Finalize() if err := c.bindArgs(stmt, args...); err != nil { return nil, err } hasRow, err := stmt.Step() if err != nil { return nil, fmt.Errorf("sqlite: failed to execute statement: %w", err) } // Consume all rows if any for hasRow { hasRow, err = stmt.Step() if err != nil { return nil, fmt.Errorf("sqlite: error stepping through results: %w", err) } } return &SQLiteResult{ lastInsertID: c.conn.LastInsertRowID(), rowsAffected: c.conn.Changes(), }, nil } func (c *SQLiteConnection) Prepare(ctx context.Context, query string) (Statement, error) { stmt, err := c.conn.Prepare(query) if err != nil { return nil, fmt.Errorf("sqlite: failed to prepare statement: %w", err) } return &SQLiteStatement{stmt: stmt, conn: c.conn}, nil } func (c *SQLiteConnection) bindArgs(stmt *sqlite.Stmt, args ...any) error { for i, arg := range args { paramIndex := i + 1 if arg == nil { stmt.BindNull(paramIndex) continue } switch v := arg.(type) { case int: stmt.BindInt64(paramIndex, int64(v)) case int64: stmt.BindInt64(paramIndex, v) case float64: stmt.BindFloat(paramIndex, v) case string: stmt.BindText(paramIndex, v) case bool: if v { stmt.BindInt64(paramIndex, 1) } else { stmt.BindInt64(paramIndex, 0) } case []byte: stmt.BindBytes(paramIndex, v) default: return fmt.Errorf("sqlite: unsupported parameter type: %T", arg) } } return nil } // SQLiteTransaction implements the Transaction interface type SQLiteTransaction struct { conn *sqlite.Conn } func (t *SQLiteTransaction) Commit() error { return sqlitex.ExecuteTransient(t.conn, "COMMIT", nil) } func (t *SQLiteTransaction) Rollback() error { return sqlitex.ExecuteTransient(t.conn, "ROLLBACK", nil) } func (t *SQLiteTransaction) Query(ctx context.Context, query string, args ...any) (Rows, error) { conn := &SQLiteConnection{conn: t.conn} return conn.Query(ctx, query, args...) } func (t *SQLiteTransaction) QueryRow(ctx context.Context, query string, args ...any) Row { conn := &SQLiteConnection{conn: t.conn} return conn.QueryRow(ctx, query, args...) } func (t *SQLiteTransaction) Exec(ctx context.Context, query string, args ...any) (Result, error) { conn := &SQLiteConnection{conn: t.conn} return conn.Exec(ctx, query, args...) } func (t *SQLiteTransaction) Prepare(ctx context.Context, query string) (Statement, error) { conn := &SQLiteConnection{conn: t.conn} return conn.Prepare(ctx, query) } // SQLiteRows implements the Rows interface type SQLiteRows struct { stmt *sqlite.Stmt hasNext bool err error } func (r *SQLiteRows) Next() bool { if r.err != nil { return false } if !r.hasNext { return false } var err error r.hasNext, err = r.stmt.Step() if err != nil { r.err = err return false } return r.hasNext } func (r *SQLiteRows) Scan(dest ...any) error { if r.err != nil { return r.err } for i, d := range dest { if i >= r.stmt.ColumnCount() { break } switch ptr := d.(type) { case *any: *ptr = r.getValue(i) case *string: *ptr = r.stmt.ColumnText(i) case *int: *ptr = int(r.stmt.ColumnInt64(i)) case *int64: *ptr = r.stmt.ColumnInt64(i) case *float64: *ptr = r.stmt.ColumnFloat(i) case *bool: *ptr = r.stmt.ColumnInt64(i) != 0 case *[]byte: if r.stmt.ColumnType(i) == sqlite.TypeBlob { // Get blob size first size := r.stmt.ColumnBytes(i, nil) if size == 0 { *ptr = []byte{} } else { buf := make([]byte, size) r.stmt.ColumnBytes(i, buf) *ptr = buf } } else { // Convert text to bytes *ptr = []byte(r.stmt.ColumnText(i)) } default: return fmt.Errorf("sqlite: unsupported scan destination type: %T", d) } } return nil } func (r *SQLiteRows) getValue(index int) any { switch r.stmt.ColumnType(index) { case sqlite.TypeInteger: return r.stmt.ColumnInt64(index) case sqlite.TypeFloat: return r.stmt.ColumnFloat(index) case sqlite.TypeText: return r.stmt.ColumnText(index) case sqlite.TypeBlob: // For blob columns, we need to handle this differently // First, get the size by calling with nil buffer size := r.stmt.ColumnBytes(index, nil) if size == 0 { return []byte{} } // Now allocate buffer and get the actual data buf := make([]byte, size) r.stmt.ColumnBytes(index, buf) return buf case sqlite.TypeNull: return nil default: return r.stmt.ColumnText(index) } } func (r *SQLiteRows) Columns() ([]string, error) { if r.err != nil { return nil, r.err } columns := make([]string, r.stmt.ColumnCount()) for i := range columns { columns[i] = r.stmt.ColumnName(i) } return columns, nil } func (r *SQLiteRows) Close() error { if r.stmt != nil { return r.stmt.Finalize() } return nil } func (r *SQLiteRows) Err() error { return r.err } // SQLiteRow implements the Row interface type SQLiteRow struct { rows *SQLiteRows err error } func (r *SQLiteRow) Scan(dest ...any) error { if r.err != nil { return r.err } if r.rows == nil { return fmt.Errorf("sqlite: no rows available") } if !r.rows.Next() { if r.rows.Err() != nil { return r.rows.Err() } return fmt.Errorf("sqlite: no rows in result set") } return r.rows.Scan(dest...) } // SQLiteResult implements the Result interface type SQLiteResult struct { lastInsertID int64 rowsAffected int } func (r *SQLiteResult) LastInsertId() (int64, error) { return r.lastInsertID, nil } func (r *SQLiteResult) RowsAffected() (int64, error) { return int64(r.rowsAffected), nil } // SQLiteStatement implements the Statement interface type SQLiteStatement struct { stmt *sqlite.Stmt conn *sqlite.Conn } func (s *SQLiteStatement) Close() error { return s.stmt.Finalize() } func (s *SQLiteStatement) Query(ctx context.Context, args ...any) (Rows, error) { conn := &SQLiteConnection{conn: s.conn} if err := conn.bindArgs(s.stmt, args...); err != nil { return nil, err } return &SQLiteRows{stmt: s.stmt, hasNext: true}, nil } func (s *SQLiteStatement) QueryRow(ctx context.Context, args ...any) Row { rows, err := s.Query(ctx, args...) if err != nil { return &SQLiteRow{err: err} } return &SQLiteRow{rows: rows.(*SQLiteRows)} } func (s *SQLiteStatement) Exec(ctx context.Context, args ...any) (Result, error) { conn := &SQLiteConnection{conn: s.conn} if err := conn.bindArgs(s.stmt, args...); err != nil { return nil, err } hasRow, err := s.stmt.Step() if err != nil { return nil, fmt.Errorf("sqlite: failed to execute statement: %w", err) } // Consume all rows if any for hasRow { hasRow, err = s.stmt.Step() if err != nil { return nil, fmt.Errorf("sqlite: error stepping through results: %w", err) } } return &SQLiteResult{ lastInsertID: s.conn.LastInsertRowID(), rowsAffected: s.conn.Changes(), }, nil }