385 lines
8.5 KiB
Go
385 lines
8.5 KiB
Go
package sql
|
|
|
|
import (
|
|
"context"
|
|
"fmt"
|
|
|
|
"zombiezen.com/go/sqlite"
|
|
"zombiezen.com/go/sqlite/sqlitex"
|
|
)
|
|
|
|
// SQLiteDriver implements the Driver interface for SQLite. It holds no
// state, so its zero value is ready to use.
type SQLiteDriver struct{}

// Name reports the identifier this driver is registered under: "sqlite".
func (d *SQLiteDriver) Name() string {
	const name = "sqlite"
	return name
}
|
|
|
|
func (d *SQLiteDriver) Open(dsn string) (Connection, error) {
|
|
conn, err := sqlite.OpenConn(dsn, sqlite.OpenReadWrite|sqlite.OpenCreate)
|
|
if err != nil {
|
|
return nil, fmt.Errorf("sqlite: failed to open database: %w", err)
|
|
}
|
|
|
|
return &SQLiteConnection{conn: conn}, nil
|
|
}
|
|
|
|
// SQLiteConnection implements the Connection interface on top of a single
// zombiezen *sqlite.Conn.
//
// NOTE(review): nothing here serializes access to conn; presumably one
// SQLiteConnection must not be used from several goroutines at once — confirm.
type SQLiteConnection struct {
	// conn is the underlying SQLite connection. Transactions (Begin) and
	// prepared statements (Prepare) created from this value share it.
	conn *sqlite.Conn
}
|
|
|
|
func (c *SQLiteConnection) Close() error {
|
|
return c.conn.Close()
|
|
}
|
|
|
|
func (c *SQLiteConnection) Ping(ctx context.Context) error {
|
|
return sqlitex.ExecuteTransient(c.conn, "SELECT 1", nil)
|
|
}
|
|
|
|
func (c *SQLiteConnection) Begin(ctx context.Context) (Transaction, error) {
|
|
if err := sqlitex.ExecuteTransient(c.conn, "BEGIN", nil); err != nil {
|
|
return nil, fmt.Errorf("sqlite: failed to begin transaction: %w", err)
|
|
}
|
|
return &SQLiteTransaction{conn: c.conn}, nil
|
|
}
|
|
|
|
func (c *SQLiteConnection) Query(ctx context.Context, query string, args ...any) (Rows, error) {
|
|
stmt, err := c.conn.Prepare(query)
|
|
if err != nil {
|
|
return nil, fmt.Errorf("sqlite: failed to prepare query: %w", err)
|
|
}
|
|
|
|
if err := c.bindArgs(stmt, args...); err != nil {
|
|
stmt.Finalize()
|
|
return nil, err
|
|
}
|
|
|
|
return &SQLiteRows{stmt: stmt, hasNext: true}, nil
|
|
}
|
|
|
|
func (c *SQLiteConnection) QueryRow(ctx context.Context, query string, args ...any) Row {
|
|
rows, err := c.Query(ctx, query, args...)
|
|
if err != nil {
|
|
return &SQLiteRow{err: err}
|
|
}
|
|
return &SQLiteRow{rows: rows.(*SQLiteRows)}
|
|
}
|
|
|
|
func (c *SQLiteConnection) Exec(ctx context.Context, query string, args ...any) (Result, error) {
|
|
stmt, err := c.conn.Prepare(query)
|
|
if err != nil {
|
|
return nil, fmt.Errorf("sqlite: failed to prepare statement: %w", err)
|
|
}
|
|
defer stmt.Finalize()
|
|
|
|
if err := c.bindArgs(stmt, args...); err != nil {
|
|
return nil, err
|
|
}
|
|
|
|
hasRow, err := stmt.Step()
|
|
if err != nil {
|
|
return nil, fmt.Errorf("sqlite: failed to execute statement: %w", err)
|
|
}
|
|
|
|
// Consume all rows if any
|
|
for hasRow {
|
|
hasRow, err = stmt.Step()
|
|
if err != nil {
|
|
return nil, fmt.Errorf("sqlite: error stepping through results: %w", err)
|
|
}
|
|
}
|
|
|
|
return &SQLiteResult{
|
|
lastInsertID: c.conn.LastInsertRowID(),
|
|
rowsAffected: c.conn.Changes(),
|
|
}, nil
|
|
}
|
|
|
|
func (c *SQLiteConnection) Prepare(ctx context.Context, query string) (Statement, error) {
|
|
stmt, err := c.conn.Prepare(query)
|
|
if err != nil {
|
|
return nil, fmt.Errorf("sqlite: failed to prepare statement: %w", err)
|
|
}
|
|
return &SQLiteStatement{stmt: stmt, conn: c.conn}, nil
|
|
}
|
|
|
|
func (c *SQLiteConnection) bindArgs(stmt *sqlite.Stmt, args ...any) error {
|
|
for i, arg := range args {
|
|
paramIndex := i + 1
|
|
|
|
if arg == nil {
|
|
stmt.BindNull(paramIndex)
|
|
continue
|
|
}
|
|
|
|
switch v := arg.(type) {
|
|
case int:
|
|
stmt.BindInt64(paramIndex, int64(v))
|
|
case int64:
|
|
stmt.BindInt64(paramIndex, v)
|
|
case float64:
|
|
stmt.BindFloat(paramIndex, v)
|
|
case string:
|
|
stmt.BindText(paramIndex, v)
|
|
case bool:
|
|
if v {
|
|
stmt.BindInt64(paramIndex, 1)
|
|
} else {
|
|
stmt.BindInt64(paramIndex, 0)
|
|
}
|
|
case []byte:
|
|
stmt.BindBytes(paramIndex, v)
|
|
default:
|
|
return fmt.Errorf("sqlite: unsupported parameter type: %T", arg)
|
|
}
|
|
}
|
|
return nil
|
|
}
|
|
|
|
// SQLiteTransaction implements the Transaction interface. It shares the
// connection that issued BEGIN; queries, COMMIT and ROLLBACK all run on it.
type SQLiteTransaction struct {
	// conn is the same *sqlite.Conn held by the SQLiteConnection whose Begin
	// created this transaction.
	conn *sqlite.Conn
}
|
|
|
|
func (t *SQLiteTransaction) Commit() error {
|
|
return sqlitex.ExecuteTransient(t.conn, "COMMIT", nil)
|
|
}
|
|
|
|
func (t *SQLiteTransaction) Rollback() error {
|
|
return sqlitex.ExecuteTransient(t.conn, "ROLLBACK", nil)
|
|
}
|
|
|
|
func (t *SQLiteTransaction) Query(ctx context.Context, query string, args ...any) (Rows, error) {
|
|
conn := &SQLiteConnection{conn: t.conn}
|
|
return conn.Query(ctx, query, args...)
|
|
}
|
|
|
|
func (t *SQLiteTransaction) QueryRow(ctx context.Context, query string, args ...any) Row {
|
|
conn := &SQLiteConnection{conn: t.conn}
|
|
return conn.QueryRow(ctx, query, args...)
|
|
}
|
|
|
|
func (t *SQLiteTransaction) Exec(ctx context.Context, query string, args ...any) (Result, error) {
|
|
conn := &SQLiteConnection{conn: t.conn}
|
|
return conn.Exec(ctx, query, args...)
|
|
}
|
|
|
|
func (t *SQLiteTransaction) Prepare(ctx context.Context, query string) (Statement, error) {
|
|
conn := &SQLiteConnection{conn: t.conn}
|
|
return conn.Prepare(ctx, query)
|
|
}
|
|
|
|
// SQLiteRows implements the Rows interface as a forward-only cursor over a
// prepared statement's results.
type SQLiteRows struct {
	// stmt produces the rows; Close finalizes it.
	stmt *sqlite.Stmt

	// hasNext is false once Step reports that no row was produced. It is
	// created as true so that the first Next performs the first Step.
	hasNext bool

	// err is the Step error that stopped iteration; Err, Scan and Columns
	// report it.
	err error
}
|
|
|
|
func (r *SQLiteRows) Next() bool {
|
|
if r.err != nil {
|
|
return false
|
|
}
|
|
|
|
if !r.hasNext {
|
|
return false
|
|
}
|
|
|
|
var err error
|
|
r.hasNext, err = r.stmt.Step()
|
|
if err != nil {
|
|
r.err = err
|
|
return false
|
|
}
|
|
|
|
return r.hasNext
|
|
}
|
|
|
|
func (r *SQLiteRows) Scan(dest ...any) error {
|
|
if r.err != nil {
|
|
return r.err
|
|
}
|
|
|
|
for i, d := range dest {
|
|
if i >= r.stmt.ColumnCount() {
|
|
break
|
|
}
|
|
|
|
switch ptr := d.(type) {
|
|
case *any:
|
|
*ptr = r.getValue(i)
|
|
case *string:
|
|
*ptr = r.stmt.ColumnText(i)
|
|
case *int:
|
|
*ptr = int(r.stmt.ColumnInt64(i))
|
|
case *int64:
|
|
*ptr = r.stmt.ColumnInt64(i)
|
|
case *float64:
|
|
*ptr = r.stmt.ColumnFloat(i)
|
|
case *bool:
|
|
*ptr = r.stmt.ColumnInt64(i) != 0
|
|
case *[]byte:
|
|
if r.stmt.ColumnType(i) == sqlite.TypeBlob {
|
|
// Get blob size first
|
|
size := r.stmt.ColumnBytes(i, nil)
|
|
if size == 0 {
|
|
*ptr = []byte{}
|
|
} else {
|
|
buf := make([]byte, size)
|
|
r.stmt.ColumnBytes(i, buf)
|
|
*ptr = buf
|
|
}
|
|
} else {
|
|
// Convert text to bytes
|
|
*ptr = []byte(r.stmt.ColumnText(i))
|
|
}
|
|
default:
|
|
return fmt.Errorf("sqlite: unsupported scan destination type: %T", d)
|
|
}
|
|
}
|
|
|
|
return nil
|
|
}
|
|
|
|
func (r *SQLiteRows) getValue(index int) any {
|
|
switch r.stmt.ColumnType(index) {
|
|
case sqlite.TypeInteger:
|
|
return r.stmt.ColumnInt64(index)
|
|
case sqlite.TypeFloat:
|
|
return r.stmt.ColumnFloat(index)
|
|
case sqlite.TypeText:
|
|
return r.stmt.ColumnText(index)
|
|
case sqlite.TypeBlob:
|
|
// For blob columns, we need to handle this differently
|
|
// First, get the size by calling with nil buffer
|
|
size := r.stmt.ColumnBytes(index, nil)
|
|
if size == 0 {
|
|
return []byte{}
|
|
}
|
|
// Now allocate buffer and get the actual data
|
|
buf := make([]byte, size)
|
|
r.stmt.ColumnBytes(index, buf)
|
|
return buf
|
|
case sqlite.TypeNull:
|
|
return nil
|
|
default:
|
|
return r.stmt.ColumnText(index)
|
|
}
|
|
}
|
|
|
|
func (r *SQLiteRows) Columns() ([]string, error) {
|
|
if r.err != nil {
|
|
return nil, r.err
|
|
}
|
|
|
|
columns := make([]string, r.stmt.ColumnCount())
|
|
for i := range columns {
|
|
columns[i] = r.stmt.ColumnName(i)
|
|
}
|
|
|
|
return columns, nil
|
|
}
|
|
|
|
func (r *SQLiteRows) Close() error {
|
|
if r.stmt != nil {
|
|
return r.stmt.Finalize()
|
|
}
|
|
return nil
|
|
}
|
|
|
|
func (r *SQLiteRows) Err() error {
|
|
return r.err
|
|
}
|
|
|
|
// SQLiteRow implements the Row interface: a view of only the first row of a
// query's results.
type SQLiteRow struct {
	// rows is the cursor Scan reads from; the constructors in this file
	// leave it nil when err is set.
	rows *SQLiteRows

	// err is an error captured while starting the query; Scan returns it.
	err error
}
|
|
|
|
func (r *SQLiteRow) Scan(dest ...any) error {
|
|
if r.err != nil {
|
|
return r.err
|
|
}
|
|
|
|
if r.rows == nil {
|
|
return fmt.Errorf("sqlite: no rows available")
|
|
}
|
|
|
|
if !r.rows.Next() {
|
|
if r.rows.Err() != nil {
|
|
return r.rows.Err()
|
|
}
|
|
return fmt.Errorf("sqlite: no rows in result set")
|
|
}
|
|
|
|
return r.rows.Scan(dest...)
|
|
}
|
|
|
|
// SQLiteResult implements the Result interface. It is a snapshot of the
// connection's last-insert rowid and change count, taken right after a
// statement finished executing. Both accessors always succeed.
type SQLiteResult struct {
	lastInsertID int64
	rowsAffected int
}

// LastInsertId returns the rowid recorded when the result was created.
func (r *SQLiteResult) LastInsertId() (int64, error) {
	id := r.lastInsertID
	return id, nil
}

// RowsAffected returns the change count recorded when the result was
// created, widened to int64.
func (r *SQLiteResult) RowsAffected() (int64, error) {
	n := int64(r.rowsAffected)
	return n, nil
}
|
|
|
|
// SQLiteStatement implements the Statement interface for a statement
// compiled by SQLiteConnection.Prepare.
type SQLiteStatement struct {
	// stmt is the compiled statement. Query, QueryRow and Exec bind
	// arguments to it, and Close finalizes it.
	stmt *sqlite.Stmt

	// conn is the connection stmt was prepared on. Exec reads the
	// last-insert rowid and change count from it.
	conn *sqlite.Conn
}
|
|
|
|
func (s *SQLiteStatement) Close() error {
|
|
return s.stmt.Finalize()
|
|
}
|
|
|
|
func (s *SQLiteStatement) Query(ctx context.Context, args ...any) (Rows, error) {
|
|
conn := &SQLiteConnection{conn: s.conn}
|
|
if err := conn.bindArgs(s.stmt, args...); err != nil {
|
|
return nil, err
|
|
}
|
|
|
|
return &SQLiteRows{stmt: s.stmt, hasNext: true}, nil
|
|
}
|
|
|
|
func (s *SQLiteStatement) QueryRow(ctx context.Context, args ...any) Row {
|
|
rows, err := s.Query(ctx, args...)
|
|
if err != nil {
|
|
return &SQLiteRow{err: err}
|
|
}
|
|
return &SQLiteRow{rows: rows.(*SQLiteRows)}
|
|
}
|
|
|
|
func (s *SQLiteStatement) Exec(ctx context.Context, args ...any) (Result, error) {
|
|
conn := &SQLiteConnection{conn: s.conn}
|
|
if err := conn.bindArgs(s.stmt, args...); err != nil {
|
|
return nil, err
|
|
}
|
|
|
|
hasRow, err := s.stmt.Step()
|
|
if err != nil {
|
|
return nil, fmt.Errorf("sqlite: failed to execute statement: %w", err)
|
|
}
|
|
|
|
// Consume all rows if any
|
|
for hasRow {
|
|
hasRow, err = s.stmt.Step()
|
|
if err != nil {
|
|
return nil, fmt.Errorf("sqlite: error stepping through results: %w", err)
|
|
}
|
|
}
|
|
|
|
return &SQLiteResult{
|
|
lastInsertID: s.conn.LastInsertRowID(),
|
|
rowsAffected: s.conn.Changes(),
|
|
}, nil
|
|
}
|