diff --git a/db.go b/db.go index f273adf..d11d7f2 100644 --- a/db.go +++ b/db.go @@ -1,6 +1,7 @@ package sashimi import ( + "context" "fmt" "reflect" "regexp" @@ -18,21 +19,25 @@ type Stmt = sqlite.Stmt var placeholderRegex = regexp.MustCompile(`%[sd]`) type DB struct { - conn *sqlite.Conn + pool *sqlitex.Pool } -// New creates a new database wrapper instance +// New creates a new database wrapper instance with connection pooling func New(dbPath string) (*DB, error) { - conn, err := sqlite.OpenConn(dbPath, sqlite.OpenReadWrite|sqlite.OpenCreate) + // Create connection pool with 10 connections max + pool, err := sqlitex.NewPool(dbPath, sqlitex.PoolOptions{ + Flags: sqlite.OpenReadWrite | sqlite.OpenCreate, + PoolSize: 10, + }) if err != nil { - return nil, fmt.Errorf("failed to open database: %w", err) + return nil, fmt.Errorf("failed to create connection pool: %w", err) } - db := &DB{conn: conn} + db := &DB{pool: pool} - // Configure database + // Configure database using one connection from pool if err := db.configure(); err != nil { - conn.Close() + pool.Close() return nil, err } @@ -41,14 +46,21 @@ func New(dbPath string) (*DB, error) { // configure sets up database pragmas func (db *DB) configure() error { + conn, err := db.pool.Take(context.Background()) + if err != nil { + return fmt.Errorf("failed to get connection: %w", err) + } + defer db.pool.Put(conn) + configs := []string{ "PRAGMA journal_mode=WAL", "PRAGMA cache_size=-65536", // 64MB cache "PRAGMA foreign_keys=ON", + "PRAGMA busy_timeout=5000", // 5 second timeout } for _, config := range configs { - if err := sqlitex.Execute(db.conn, config, nil); err != nil { + if err := sqlitex.Execute(conn, config, nil); err != nil { return fmt.Errorf("failed to configure database: %w", err) } } @@ -56,21 +68,21 @@ func (db *DB) configure() error { return nil } -// Close closes the database connection +// Close closes the database connection pool func (db *DB) Close() error { - if db.conn != nil { - return db.conn.Close() + if db.pool != nil { + return db.pool.Close() } return nil } -// Conn returns the underlying sqlite connection -func (db *DB) Conn() *sqlite.Conn { - return db.conn +// Pool returns the underlying connection pool +func (db *DB) Pool() *sqlitex.Pool { + return db.pool } // Scan scans a SQLite statement result into a struct using field names -func (db *DB) Scan(stmt *sqlite.Stmt, dest any) error { +func (db *DB) Scan(stmt *pooledStmt, dest any) error { v := reflect.ValueOf(dest) if v.Kind() != reflect.Pointer || v.Elem().Kind() != reflect.Struct { return fmt.Errorf("dest must be a pointer to struct") @@ -117,11 +129,17 @@ func (db *DB) Scan(stmt *sqlite.Stmt, dest any) error { } // Query executes a query with fmt-style placeholders and automatically binds parameters -func (db *DB) Query(query string, args ...any) (*sqlite.Stmt, error) { +func (db *DB) Query(query string, args ...any) (*pooledStmt, error) { + conn, err := db.pool.Take(context.Background()) + if err != nil { + return nil, fmt.Errorf("failed to get connection: %w", err) + } + convertedQuery, paramTypes := convertPlaceholders(query) - stmt, err := db.conn.Prepare(convertedQuery) + stmt, err := conn.Prepare(convertedQuery) if err != nil { + db.pool.Put(conn) return nil, err } @@ -158,7 +176,26 @@ func (db *DB) Query(query string, args ...any) (*sqlite.Stmt, error) { } } - return stmt, nil + // Create a wrapped statement that releases the connection when finalized + return &pooledStmt{Stmt: stmt, pool: db.pool, conn: conn}, nil +} + +// pooledStmt wraps a statement to automatically release pool connections +type pooledStmt struct { + *sqlite.Stmt + pool *sqlitex.Pool + conn *sqlite.Conn + finalized bool +} + +func (ps *pooledStmt) Finalize() error { + if !ps.finalized { + err := ps.Stmt.Finalize() + ps.pool.Put(ps.conn) + ps.finalized = true + return err + } + return nil } // Get executes a query and returns the first row @@ -224,6 +261,12 @@ func (db *DB) Select(dest any, query string, args ...any) error { // Exec executes a statement with fmt-style placeholders func (db *DB) Exec(query string, args ...any) error { + conn, err := db.pool.Take(context.Background()) + if err != nil { + return fmt.Errorf("failed to get connection: %w", err) + } + defer db.pool.Put(conn) + convertedQuery, paramTypes := convertPlaceholders(query) sqlArgs := make([]any, len(args)) @@ -245,7 +288,7 @@ func (db *DB) Exec(query string, args ...any) error { } } - return sqlitex.Execute(db.conn, convertedQuery, &sqlitex.ExecOptions{ + return sqlitex.Execute(conn, convertedQuery, &sqlitex.ExecOptions{ Args: sqlArgs, }) } @@ -256,6 +299,12 @@ func (db *DB) Update(tableName string, fields map[string]any, whereField string, return nil // No changes } + conn, err := db.pool.Take(context.Background()) + if err != nil { + return fmt.Errorf("failed to get connection: %w", err) + } + defer db.pool.Put(conn) + // Build UPDATE query setParts := make([]string, 0, len(fields)) args := make([]any, 0, len(fields)+1) @@ -270,13 +319,19 @@ func (db *DB) Update(tableName string, fields map[string]any, whereField string, query := fmt.Sprintf("UPDATE %s SET %s WHERE %s = ?", tableName, strings.Join(setParts, ", "), whereField) - return sqlitex.Execute(db.conn, query, &sqlitex.ExecOptions{ + return sqlitex.Execute(conn, query, &sqlitex.ExecOptions{ Args: args, }) } // Insert inserts a struct into the database func (db *DB) Insert(tableName string, obj any, excludeFields ...string) (int64, error) { + conn, err := db.pool.Take(context.Background()) + if err != nil { + return 0, fmt.Errorf("failed to get connection: %w", err) + } + defer db.pool.Put(conn) + v := reflect.ValueOf(obj) if v.Kind() == reflect.Pointer { v = v.Elem() @@ -307,7 +362,7 @@ func (db *DB) Insert(tableName string, obj any, excludeFields ...string) (int64, query := fmt.Sprintf("INSERT INTO %s (%s) VALUES (%s)", tableName, strings.Join(columns, ", "), strings.Join(placeholders, ", ")) - stmt, err := db.conn.Prepare(query) + stmt, err := conn.Prepare(query) if err != nil { return 0, err } @@ -332,27 +387,33 @@ func (db *DB) Insert(tableName string, obj any, excludeFields ...string) (int64, return 0, err } - return db.conn.LastInsertRowID(), nil + return conn.LastInsertRowID(), nil } // Transaction executes multiple operations atomically func (db *DB) Transaction(fn func() error) error { + conn, err := db.pool.Take(context.Background()) + if err != nil { + return fmt.Errorf("failed to get connection: %w", err) + } + defer db.pool.Put(conn) + // Begin transaction - if err := sqlitex.Execute(db.conn, "BEGIN", nil); err != nil { + if err := sqlitex.Execute(conn, "BEGIN", nil); err != nil { return err } // Execute operations - err := fn() + err = fn() if err != nil { // Rollback on error - sqlitex.Execute(db.conn, "ROLLBACK", nil) + sqlitex.Execute(conn, "ROLLBACK", nil) return err } // Commit on success - return sqlitex.Execute(db.conn, "COMMIT", nil) + return sqlitex.Execute(conn, "COMMIT", nil) } func convertPlaceholders(query string) (string, []string) { diff --git a/migrate.go b/migrate.go index 606e6f0..e810700 100644 --- a/migrate.go +++ b/migrate.go @@ -1,6 +1,7 @@ package sashimi import ( + "context" "fmt" "io/fs" "os" @@ -39,6 +40,12 @@ func NewMigrator(db *DB, dataDir string) *Migrator { // ensureMigrationsTable creates the migrations tracking table if it doesn't exist func (m *Migrator) ensureMigrationsTable() error { + conn, err := m.db.pool.Take(context.Background()) + if err != nil { + return fmt.Errorf("failed to get connection: %w", err) + } + defer m.db.pool.Put(conn) + query := ` CREATE TABLE IF NOT EXISTS migrations ( number INTEGER PRIMARY KEY, @@ -47,7 +54,7 @@ func (m *Migrator) ensureMigrationsTable() error { executed_at INTEGER NOT NULL ) ` - return sqlitex.Execute(m.db.conn, query, nil) + return sqlitex.Execute(conn, query, nil) } // getExecutedMigrations returns a map of migration numbers that have been executed @@ -155,11 +162,17 @@ func (m *Migrator) Run() error { fmt.Printf("Running %d pending migrations...\n", len(pendingMigrations)) return m.db.Transaction(func() error { + conn, err := m.db.pool.Take(context.Background()) + if err != nil { + return fmt.Errorf("failed to get connection: %w", err) + } + defer m.db.pool.Put(conn) + for _, migration := range pendingMigrations { fmt.Printf("Running migration %d: %s\n", migration.Number, migration.Name) // Execute the migration SQL - if err := sqlitex.Execute(m.db.conn, migration.Content, nil); err != nil { + if err := sqlitex.Execute(conn, migration.Content, nil); err != nil { return fmt.Errorf("failed to execute migration %d (%s): %w", migration.Number, migration.Name, err) }