Moonshark/modules/sql/sqlite.go

385 lines
8.5 KiB
Go

package sql
import (
"context"
"fmt"
"zombiezen.com/go/sqlite"
"zombiezen.com/go/sqlite/sqlitex"
)
// SQLiteDriver is the Driver implementation backed by zombiezen.com/go/sqlite.
// It carries no state, so its zero value is ready to use.
type SQLiteDriver struct{}

// Name returns the driver's identifier, "sqlite".
func (d *SQLiteDriver) Name() string {
	const driverName = "sqlite"
	return driverName
}
// Open opens the SQLite database named by dsn for reading and writing,
// creating it if it does not exist, and wraps the handle as a Connection.
func (d *SQLiteDriver) Open(dsn string) (Connection, error) {
	mode := sqlite.OpenReadWrite | sqlite.OpenCreate
	raw, openErr := sqlite.OpenConn(dsn, mode)
	if openErr != nil {
		return nil, fmt.Errorf("sqlite: failed to open database: %w", openErr)
	}
	wrapped := &SQLiteConnection{conn: raw}
	return wrapped, nil
}
// SQLiteConnection adapts a single *sqlite.Conn to the Connection interface.
// NOTE(review): nothing here serializes access to conn — confirm callers never
// share one SQLiteConnection between goroutines.
type SQLiteConnection struct {
	conn *sqlite.Conn
}

// Close closes the underlying SQLite connection.
func (c *SQLiteConnection) Close() error {
	underlying := c.conn
	return underlying.Close()
}

// Ping verifies the connection answers a trivial query ("SELECT 1").
// ctx is accepted for interface compatibility but is not consulted.
func (c *SQLiteConnection) Ping(ctx context.Context) error {
	const probe = "SELECT 1"
	return sqlitex.ExecuteTransient(c.conn, probe, nil)
}
// Begin opens a transaction by executing a plain BEGIN on this connection and
// returns a Transaction that runs on the same connection. ctx is not consulted.
func (c *SQLiteConnection) Begin(ctx context.Context) (Transaction, error) {
	conn := c.conn
	if err := sqlitex.ExecuteTransient(conn, "BEGIN", nil); err != nil {
		return nil, fmt.Errorf("sqlite: failed to begin transaction: %w", err)
	}
	tx := &SQLiteTransaction{conn: conn}
	return tx, nil
}
// Query prepares query, binds args positionally (args[0] is parameter 1) and
// returns a Rows cursor over the results; closing the cursor finalizes the
// statement. If binding fails the statement is finalized immediately and the
// bind error is returned. ctx is not consulted.
//
// NOTE(review): per the zombiezen docs, Conn.Prepare caches statements by SQL
// text, yet Rows.Close finalizes them — two live cursors (or a cursor and a
// Statement) for identical SQL would share one *sqlite.Stmt. Confirm, and
// consider PrepareTransient paired with a finalizing Row.Scan.
func (c *SQLiteConnection) Query(ctx context.Context, query string, args ...any) (Rows, error) {
	stmt, prepErr := c.conn.Prepare(query)
	if prepErr != nil {
		return nil, fmt.Errorf("sqlite: failed to prepare query: %w", prepErr)
	}
	if bindErr := c.bindArgs(stmt, args...); bindErr != nil {
		stmt.Finalize()
		return nil, bindErr
	}
	cursor := &SQLiteRows{stmt: stmt, hasNext: true}
	return cursor, nil
}
// QueryRow runs query through Query and wraps the resulting cursor in a Row
// whose Scan reads the first result. A Query failure is stored in the Row and
// reported by its Scan.
func (c *SQLiteConnection) QueryRow(ctx context.Context, query string, args ...any) Row {
	row := &SQLiteRow{}
	rows, err := c.Query(ctx, query, args...)
	if err != nil {
		row.err = err
		return row
	}
	row.rows = rows.(*SQLiteRows)
	return row
}
// Exec runs query once and reports the connection's last insert rowid and
// change count, read after execution.
//
// args are bound positionally (args[0] is parameter 1). Any result rows are
// stepped through and discarded. ctx is not consulted.
//
// The statement comes from PrepareTransient, so Exec owns it outright and
// finalizes it before returning. A statement from Conn.Prepare's
// per-connection cache could be shared with a Statement or an open Rows using
// identical SQL text, and finalizing it here would break them.
func (c *SQLiteConnection) Exec(ctx context.Context, query string, args ...any) (Result, error) {
	stmt, trailingBytes, err := c.conn.PrepareTransient(query)
	if err != nil {
		return nil, fmt.Errorf("sqlite: failed to prepare statement: %w", err)
	}
	// Finalize only repeats the error of a failed Step, which is already
	// returned below, so its result is intentionally not checked.
	defer stmt.Finalize()

	// Conn.Prepare rejects SQL with text left over after the first statement.
	// Keep that contract so a multi-statement string is refused rather than
	// having everything after the first statement silently skipped.
	if trailingBytes != 0 {
		return nil, fmt.Errorf("sqlite: failed to prepare statement: %d trailing bytes after the first statement in %q", trailingBytes, query)
	}

	if err := c.bindArgs(stmt, args...); err != nil {
		return nil, err
	}

	hasRow, err := stmt.Step()
	if err != nil {
		return nil, fmt.Errorf("sqlite: failed to execute statement: %w", err)
	}
	// Drain any result rows so the statement runs to completion before the
	// counters are read.
	for hasRow {
		if hasRow, err = stmt.Step(); err != nil {
			return nil, fmt.Errorf("sqlite: error stepping through results: %w", err)
		}
	}

	return &SQLiteResult{
		lastInsertID: c.conn.LastInsertRowID(),
		rowsAffected: c.conn.Changes(),
	}, nil
}
// Prepare compiles query into a reusable Statement tied to this connection.
// ctx is not consulted.
//
// NOTE(review): per the zombiezen docs, Conn.Prepare returns statements from a
// cache keyed by SQL text, so two Prepare calls with the same text may yield
// Statements sharing one *sqlite.Stmt (closing either finalizes both) — confirm.
func (c *SQLiteConnection) Prepare(ctx context.Context, query string) (Statement, error) {
	compiled, err := c.conn.Prepare(query)
	if err != nil {
		return nil, fmt.Errorf("sqlite: failed to prepare statement: %w", err)
	}
	prepared := &SQLiteStatement{stmt: compiled, conn: c.conn}
	return prepared, nil
}
// bindArgs binds args to stmt's positional parameters, args[0] to parameter 1.
//
// Go values map to SQLite storage classes as follows:
//   - nil: NULL
//   - int, int8, int16, int32, int64, uint8, uint16, uint32: INTEGER
//   - uint, uint64: INTEGER; values above the int64 range are rejected
//   - float32, float64: REAL
//   - string: TEXT
//   - bool: INTEGER 1 or 0
//   - []byte: BLOB
//
// Any other type is rejected with an error naming the type.
func (c *SQLiteConnection) bindArgs(stmt *sqlite.Stmt, args ...any) error {
	const maxInt64 = 1<<63 - 1
	for i, arg := range args {
		param := i + 1 // SQLite parameter indices start at 1
		switch v := arg.(type) {
		case nil:
			stmt.BindNull(param)
		case int:
			stmt.BindInt64(param, int64(v))
		case int8:
			stmt.BindInt64(param, int64(v))
		case int16:
			stmt.BindInt64(param, int64(v))
		case int32:
			stmt.BindInt64(param, int64(v))
		case int64:
			stmt.BindInt64(param, v)
		case uint8:
			stmt.BindInt64(param, int64(v))
		case uint16:
			stmt.BindInt64(param, int64(v))
		case uint32:
			stmt.BindInt64(param, int64(v))
		case uint:
			// SQLite integers are signed 64-bit; refuse rather than wrap around.
			if uint64(v) > maxInt64 {
				return fmt.Errorf("sqlite: parameter %d: unsigned value %d overflows int64", param, v)
			}
			stmt.BindInt64(param, int64(v))
		case uint64:
			if v > maxInt64 {
				return fmt.Errorf("sqlite: parameter %d: unsigned value %d overflows int64", param, v)
			}
			stmt.BindInt64(param, int64(v))
		case float32:
			stmt.BindFloat(param, float64(v))
		case float64:
			stmt.BindFloat(param, v)
		case string:
			stmt.BindText(param, v)
		case bool:
			var n int64
			if v {
				n = 1
			}
			stmt.BindInt64(param, n)
		case []byte:
			stmt.BindBytes(param, v)
		default:
			return fmt.Errorf("sqlite: unsupported parameter type: %T", arg)
		}
	}
	return nil
}
// SQLiteTransaction is the Transaction produced by SQLiteConnection.Begin.
//
// It holds the same *sqlite.Conn the BEGIN ran on. Commit and Rollback execute
// the matching SQL, and the query methods run directly on that connection.
type SQLiteTransaction struct {
	conn *sqlite.Conn
}

// end executes the statement that closes the transaction (COMMIT or ROLLBACK).
func (t *SQLiteTransaction) end(sql string) error {
	return sqlitex.ExecuteTransient(t.conn, sql, nil)
}

// asConnection wraps the transaction's connection so the query methods can
// reuse SQLiteConnection's implementations.
func (t *SQLiteTransaction) asConnection() *SQLiteConnection {
	return &SQLiteConnection{conn: t.conn}
}

// Commit finishes the transaction by executing COMMIT.
func (t *SQLiteTransaction) Commit() error {
	return t.end("COMMIT")
}

// Rollback abandons the transaction by executing ROLLBACK.
func (t *SQLiteTransaction) Rollback() error {
	return t.end("ROLLBACK")
}

// Query runs query on the transaction's connection via SQLiteConnection.Query.
func (t *SQLiteTransaction) Query(ctx context.Context, query string, args ...any) (Rows, error) {
	return t.asConnection().Query(ctx, query, args...)
}

// QueryRow runs query on the transaction's connection via SQLiteConnection.QueryRow.
func (t *SQLiteTransaction) QueryRow(ctx context.Context, query string, args ...any) Row {
	return t.asConnection().QueryRow(ctx, query, args...)
}

// Exec runs query on the transaction's connection via SQLiteConnection.Exec.
func (t *SQLiteTransaction) Exec(ctx context.Context, query string, args ...any) (Result, error) {
	return t.asConnection().Exec(ctx, query, args...)
}

// Prepare compiles query on the transaction's connection via SQLiteConnection.Prepare.
func (t *SQLiteTransaction) Prepare(ctx context.Context, query string) (Statement, error) {
	return t.asConnection().Prepare(ctx, query)
}
// SQLiteRows is a forward-only cursor over the results of a prepared statement.
type SQLiteRows struct {
	stmt    *sqlite.Stmt // statement being stepped; Close finalizes it
	hasNext bool         // result of the last Step; false stops iteration
	err     error        // error from a failed Step; ends iteration
}

// Next steps to the following result row and reports whether one is available.
//
// It returns false once the rows are exhausted or a step fails. On failure the
// error is kept and returned by Err, Scan and Columns.
func (r *SQLiteRows) Next() bool {
	if r.err != nil || !r.hasNext {
		return false
	}
	more, err := r.stmt.Step()
	r.hasNext = more
	if err != nil {
		r.err = err
		return false
	}
	return more
}
// Scan copies the current row's columns into dest, in column order.
//
// Supported destinations:
//   - *any: the column's native value (see getValue)
//   - *string, *int, *int64, *float64
//   - *bool: true for any non-zero integer
//   - *[]byte: a copy of a BLOB column, or the text of any other column
//
// Destinations beyond the result's column count are left untouched. If an
// earlier Next failed, that error is returned and nothing is written.
func (r *SQLiteRows) Scan(dest ...any) error {
	if r.err != nil {
		return r.err
	}
	columns := r.stmt.ColumnCount()
	for i, d := range dest {
		if i >= columns {
			break
		}
		switch ptr := d.(type) {
		case *any:
			*ptr = r.getValue(i)
		case *string:
			*ptr = r.stmt.ColumnText(i)
		case *int:
			*ptr = int(r.stmt.ColumnInt64(i))
		case *int64:
			*ptr = r.stmt.ColumnInt64(i)
		case *float64:
			*ptr = r.stmt.ColumnFloat(i)
		case *bool:
			*ptr = r.stmt.ColumnInt64(i) != 0
		case *[]byte:
			if r.stmt.ColumnType(i) == sqlite.TypeBlob {
				// ColumnBytes copies at most len(buf) bytes and returns that
				// count, so the buffer must be sized with ColumnLen first. A nil
				// buffer would always produce an empty slice.
				buf := make([]byte, r.stmt.ColumnLen(i))
				r.stmt.ColumnBytes(i, buf)
				*ptr = buf
			} else {
				// Convert text to bytes
				*ptr = []byte(r.stmt.ColumnText(i))
			}
		default:
			return fmt.Errorf("sqlite: unsupported scan destination type: %T", d)
		}
	}
	return nil
}
// getValue returns column index of the current row as the Go value matching
// its SQLite storage class:
//   - INTEGER: int64
//   - REAL: float64
//   - TEXT: string
//   - BLOB: a freshly allocated []byte
//   - NULL: nil
//
// Any other reported type falls back to the column's text.
func (r *SQLiteRows) getValue(index int) any {
	switch r.stmt.ColumnType(index) {
	case sqlite.TypeInteger:
		return r.stmt.ColumnInt64(index)
	case sqlite.TypeFloat:
		return r.stmt.ColumnFloat(index)
	case sqlite.TypeText:
		return r.stmt.ColumnText(index)
	case sqlite.TypeBlob:
		// ColumnBytes copies at most len(buf) bytes, so the buffer is sized with
		// ColumnLen. Calling ColumnBytes with a nil buffer reports 0, not the
		// blob length.
		buf := make([]byte, r.stmt.ColumnLen(index))
		r.stmt.ColumnBytes(index, buf)
		return buf
	case sqlite.TypeNull:
		return nil
	default:
		return r.stmt.ColumnText(index)
	}
}
// Columns returns the names of the result columns in order, or the stored
// iteration error if a previous Next failed.
func (r *SQLiteRows) Columns() ([]string, error) {
	if err := r.err; err != nil {
		return nil, err
	}
	n := r.stmt.ColumnCount()
	names := make([]string, 0, n)
	for i := 0; i < n; i++ {
		names = append(names, r.stmt.ColumnName(i))
	}
	return names, nil
}
// Close finalizes the cursor's statement. It is safe to call more than once.
//
// The first call releases the statement and marks the cursor exhausted, so a
// later Next returns false instead of stepping a finalized statement. Later
// calls return nil instead of finalizing the same statement twice.
func (r *SQLiteRows) Close() error {
	stmt := r.stmt
	if stmt == nil {
		return nil
	}
	r.stmt = nil
	r.hasNext = false
	return stmt.Finalize()
}
// Err returns the error, if any, that ended iteration in Next.
func (r *SQLiteRows) Err() error {
	iterErr := r.err
	return iterErr
}
// SQLiteRow is the single-row result of QueryRow. Either err holds the failure
// from running the query, or rows holds the cursor to read from.
type SQLiteRow struct {
	rows *SQLiteRows
	err  error
}

// Scan reads the first result row into dest (see SQLiteRows.Scan for the
// supported destination types).
//
// It returns the stored query error if there was one, an error if the result
// set is empty, or any error from stepping.
//
// Before returning, Scan resets the underlying statement. Otherwise the
// statement would stay positioned after its first row and keep its read open
// on the database. Scan resets rather than closes (Close finalizes), because
// the statement may still be owned elsewhere: SQLiteStatement.QueryRow shares
// the Statement's reusable handle.
func (r *SQLiteRow) Scan(dest ...any) error {
	if r.err != nil {
		return r.err
	}
	if r.rows == nil {
		return fmt.Errorf("sqlite: no rows available")
	}
	// Best effort: Reset's error can only repeat the error of the last Step,
	// which is already reported through r.rows.Err() below.
	defer r.rows.stmt.Reset()
	if !r.rows.Next() {
		if err := r.rows.Err(); err != nil {
			return err
		}
		return fmt.Errorf("sqlite: no rows in result set")
	}
	return r.rows.Scan(dest...)
}
// SQLiteResult holds the last insert rowid and change count captured from the
// connection once a statement has finished executing.
type SQLiteResult struct {
	lastInsertID int64
	rowsAffected int
}

// LastInsertId returns the captured rowid. The error is always nil.
func (r *SQLiteResult) LastInsertId() (int64, error) {
	id := r.lastInsertID
	return id, nil
}

// RowsAffected returns the captured change count widened to int64. The error
// is always nil.
func (r *SQLiteResult) RowsAffected() (int64, error) {
	affected := int64(r.rowsAffected)
	return affected, nil
}
// SQLiteStatement pairs a prepared *sqlite.Stmt with the connection that
// prepared it. The connection is used for parameter binding and for the
// rowid and change counters read after Exec.
type SQLiteStatement struct {
	stmt *sqlite.Stmt
	conn *sqlite.Conn
}

// Close finalizes the prepared statement.
func (s *SQLiteStatement) Close() error {
	prepared := s.stmt
	return prepared.Finalize()
}
// Query binds args positionally to the prepared statement and returns a
// cursor over its results. ctx is not consulted.
//
// The statement is reset and its bindings cleared first. SQLite refuses new
// bindings (SQLITE_MISUSE) on a statement stepped since its last reset, so
// without this a Statement could not be queried again after an earlier
// Query, QueryRow or Exec. Clearing also makes any parameter not supplied this
// time NULL instead of a stale value.
//
// NOTE(review): the returned Rows shares s.stmt and SQLiteRows.Close
// finalizes it, which leaves this Statement unusable after the cursor is
// closed.
func (s *SQLiteStatement) Query(ctx context.Context, args ...any) (Rows, error) {
	// An error from Reset describes the previous execution, which already
	// surfaced it through its own Step, so it is deliberately ignored.
	_ = s.stmt.Reset()
	if err := s.stmt.ClearBindings(); err != nil {
		return nil, fmt.Errorf("sqlite: failed to clear statement bindings: %w", err)
	}
	conn := &SQLiteConnection{conn: s.conn}
	if err := conn.bindArgs(s.stmt, args...); err != nil {
		return nil, err
	}
	return &SQLiteRows{stmt: s.stmt, hasNext: true}, nil
}
// QueryRow executes the prepared statement through Query and returns a Row for
// its first result. A Query error is stored in the Row and reported by its Scan.
func (s *SQLiteStatement) QueryRow(ctx context.Context, args ...any) Row {
	var row SQLiteRow
	if rows, err := s.Query(ctx, args...); err != nil {
		row.err = err
	} else {
		row.rows = rows.(*SQLiteRows)
	}
	return &row
}
// Exec runs the prepared statement once and returns the connection's last
// insert rowid and change count.
//
// args are bound positionally. Any result rows are stepped through and
// discarded. ctx is not consulted.
//
// The statement is reset and its bindings cleared before binding. SQLite
// refuses new bindings (SQLITE_MISUSE) on a statement stepped since its last
// reset, so without this a Statement could be executed only once. Clearing
// also makes any parameter not supplied now NULL rather than the previous
// run's value.
func (s *SQLiteStatement) Exec(ctx context.Context, args ...any) (Result, error) {
	// An error from Reset describes the previous execution, which already
	// surfaced it through its own Step, so it is deliberately ignored.
	_ = s.stmt.Reset()
	if err := s.stmt.ClearBindings(); err != nil {
		return nil, fmt.Errorf("sqlite: failed to clear statement bindings: %w", err)
	}

	conn := &SQLiteConnection{conn: s.conn}
	if err := conn.bindArgs(s.stmt, args...); err != nil {
		return nil, err
	}

	hasRow, err := s.stmt.Step()
	if err != nil {
		return nil, fmt.Errorf("sqlite: failed to execute statement: %w", err)
	}
	// Drain any result rows so the statement runs to completion before the
	// counters are read.
	for hasRow {
		if hasRow, err = s.stmt.Step(); err != nil {
			return nil, fmt.Errorf("sqlite: error stepping through results: %w", err)
		}
	}

	return &SQLiteResult{
		lastInsertID: s.conn.LastInsertRowID(),
		rowsAffected: s.conn.Changes(),
	}, nil
}