add invisible connection pooling

This commit is contained in:
Sky Johnson 2025-08-22 15:48:27 -05:00
parent 344da424a0
commit 29546a2066
2 changed files with 102 additions and 28 deletions

113
db.go
View File

@ -1,6 +1,7 @@
package sashimi package sashimi
import ( import (
"context"
"fmt" "fmt"
"reflect" "reflect"
"regexp" "regexp"
@ -18,21 +19,25 @@ type Stmt = sqlite.Stmt
var placeholderRegex = regexp.MustCompile(`%[sd]`) var placeholderRegex = regexp.MustCompile(`%[sd]`)
type DB struct { type DB struct {
conn *sqlite.Conn pool *sqlitex.Pool
} }
// New creates a new database wrapper instance // New creates a new database wrapper instance with connection pooling
func New(dbPath string) (*DB, error) { func New(dbPath string) (*DB, error) {
conn, err := sqlite.OpenConn(dbPath, sqlite.OpenReadWrite|sqlite.OpenCreate) // Create connection pool with 10 connections max
pool, err := sqlitex.NewPool(dbPath, sqlitex.PoolOptions{
Flags: sqlite.OpenReadWrite | sqlite.OpenCreate,
PoolSize: 10,
})
if err != nil { if err != nil {
return nil, fmt.Errorf("failed to open database: %w", err) return nil, fmt.Errorf("failed to create connection pool: %w", err)
} }
db := &DB{conn: conn} db := &DB{pool: pool}
// Configure database // Configure database using one connection from pool
if err := db.configure(); err != nil { if err := db.configure(); err != nil {
conn.Close() pool.Close()
return nil, err return nil, err
} }
@ -41,14 +46,21 @@ func New(dbPath string) (*DB, error) {
// configure sets up database pragmas // configure sets up database pragmas
func (db *DB) configure() error { func (db *DB) configure() error {
conn, err := db.pool.Take(context.Background())
if err != nil {
return fmt.Errorf("failed to get connection: %w", err)
}
defer db.pool.Put(conn)
configs := []string{ configs := []string{
"PRAGMA journal_mode=WAL", "PRAGMA journal_mode=WAL",
"PRAGMA cache_size=-65536", // 64MB cache "PRAGMA cache_size=-65536", // 64MB cache
"PRAGMA foreign_keys=ON", "PRAGMA foreign_keys=ON",
"PRAGMA busy_timeout=5000", // 5 second timeout
} }
for _, config := range configs { for _, config := range configs {
if err := sqlitex.Execute(db.conn, config, nil); err != nil { if err := sqlitex.Execute(conn, config, nil); err != nil {
return fmt.Errorf("failed to configure database: %w", err) return fmt.Errorf("failed to configure database: %w", err)
} }
} }
@ -56,21 +68,21 @@ func (db *DB) configure() error {
return nil return nil
} }
// Close closes the database connection // Close closes the database connection pool
func (db *DB) Close() error { func (db *DB) Close() error {
if db.conn != nil { if db.pool != nil {
return db.conn.Close() return db.pool.Close()
} }
return nil return nil
} }
// Conn returns the underlying sqlite connection // Pool returns the underlying connection pool
func (db *DB) Conn() *sqlite.Conn { func (db *DB) Pool() *sqlitex.Pool {
return db.conn return db.pool
} }
// Scan scans a SQLite statement result into a struct using field names // Scan scans a SQLite statement result into a struct using field names
func (db *DB) Scan(stmt *sqlite.Stmt, dest any) error { func (db *DB) Scan(stmt *pooledStmt, dest any) error {
v := reflect.ValueOf(dest) v := reflect.ValueOf(dest)
if v.Kind() != reflect.Pointer || v.Elem().Kind() != reflect.Struct { if v.Kind() != reflect.Pointer || v.Elem().Kind() != reflect.Struct {
return fmt.Errorf("dest must be a pointer to struct") return fmt.Errorf("dest must be a pointer to struct")
@ -117,11 +129,17 @@ func (db *DB) Scan(stmt *sqlite.Stmt, dest any) error {
} }
// Query executes a query with fmt-style placeholders and automatically binds parameters // Query executes a query with fmt-style placeholders and automatically binds parameters
func (db *DB) Query(query string, args ...any) (*sqlite.Stmt, error) { func (db *DB) Query(query string, args ...any) (*pooledStmt, error) {
conn, err := db.pool.Take(context.Background())
if err != nil {
return nil, fmt.Errorf("failed to get connection: %w", err)
}
convertedQuery, paramTypes := convertPlaceholders(query) convertedQuery, paramTypes := convertPlaceholders(query)
stmt, err := db.conn.Prepare(convertedQuery) stmt, err := conn.Prepare(convertedQuery)
if err != nil { if err != nil {
db.pool.Put(conn)
return nil, err return nil, err
} }
@ -158,7 +176,26 @@ func (db *DB) Query(query string, args ...any) (*sqlite.Stmt, error) {
} }
} }
return stmt, nil // Create a wrapped statement that releases the connection when finalized
return &pooledStmt{Stmt: stmt, pool: db.pool, conn: conn}, nil
}
// pooledStmt wraps a statement to automatically release pool connections
type pooledStmt struct {
*sqlite.Stmt
pool *sqlitex.Pool
conn *sqlite.Conn
finalized bool
}
func (ps *pooledStmt) Finalize() error {
if !ps.finalized {
err := ps.Stmt.Finalize()
ps.pool.Put(ps.conn)
ps.finalized = true
return err
}
return nil
} }
// Get executes a query and returns the first row // Get executes a query and returns the first row
@ -224,6 +261,12 @@ func (db *DB) Select(dest any, query string, args ...any) error {
// Exec executes a statement with fmt-style placeholders // Exec executes a statement with fmt-style placeholders
func (db *DB) Exec(query string, args ...any) error { func (db *DB) Exec(query string, args ...any) error {
conn, err := db.pool.Take(context.Background())
if err != nil {
return fmt.Errorf("failed to get connection: %w", err)
}
defer db.pool.Put(conn)
convertedQuery, paramTypes := convertPlaceholders(query) convertedQuery, paramTypes := convertPlaceholders(query)
sqlArgs := make([]any, len(args)) sqlArgs := make([]any, len(args))
@ -245,7 +288,7 @@ func (db *DB) Exec(query string, args ...any) error {
} }
} }
return sqlitex.Execute(db.conn, convertedQuery, &sqlitex.ExecOptions{ return sqlitex.Execute(conn, convertedQuery, &sqlitex.ExecOptions{
Args: sqlArgs, Args: sqlArgs,
}) })
} }
@ -256,6 +299,12 @@ func (db *DB) Update(tableName string, fields map[string]any, whereField string,
return nil // No changes return nil // No changes
} }
conn, err := db.pool.Take(context.Background())
if err != nil {
return fmt.Errorf("failed to get connection: %w", err)
}
defer db.pool.Put(conn)
// Build UPDATE query // Build UPDATE query
setParts := make([]string, 0, len(fields)) setParts := make([]string, 0, len(fields))
args := make([]any, 0, len(fields)+1) args := make([]any, 0, len(fields)+1)
@ -270,13 +319,19 @@ func (db *DB) Update(tableName string, fields map[string]any, whereField string,
query := fmt.Sprintf("UPDATE %s SET %s WHERE %s = ?", query := fmt.Sprintf("UPDATE %s SET %s WHERE %s = ?",
tableName, strings.Join(setParts, ", "), whereField) tableName, strings.Join(setParts, ", "), whereField)
return sqlitex.Execute(db.conn, query, &sqlitex.ExecOptions{ return sqlitex.Execute(conn, query, &sqlitex.ExecOptions{
Args: args, Args: args,
}) })
} }
// Insert inserts a struct into the database // Insert inserts a struct into the database
func (db *DB) Insert(tableName string, obj any, excludeFields ...string) (int64, error) { func (db *DB) Insert(tableName string, obj any, excludeFields ...string) (int64, error) {
conn, err := db.pool.Take(context.Background())
if err != nil {
return 0, fmt.Errorf("failed to get connection: %w", err)
}
defer db.pool.Put(conn)
v := reflect.ValueOf(obj) v := reflect.ValueOf(obj)
if v.Kind() == reflect.Pointer { if v.Kind() == reflect.Pointer {
v = v.Elem() v = v.Elem()
@ -307,7 +362,7 @@ func (db *DB) Insert(tableName string, obj any, excludeFields ...string) (int64,
query := fmt.Sprintf("INSERT INTO %s (%s) VALUES (%s)", query := fmt.Sprintf("INSERT INTO %s (%s) VALUES (%s)",
tableName, strings.Join(columns, ", "), strings.Join(placeholders, ", ")) tableName, strings.Join(columns, ", "), strings.Join(placeholders, ", "))
stmt, err := db.conn.Prepare(query) stmt, err := conn.Prepare(query)
if err != nil { if err != nil {
return 0, err return 0, err
} }
@ -332,27 +387,33 @@ func (db *DB) Insert(tableName string, obj any, excludeFields ...string) (int64,
return 0, err return 0, err
} }
return db.conn.LastInsertRowID(), nil return conn.LastInsertRowID(), nil
} }
// Transaction executes multiple operations atomically // Transaction executes multiple operations atomically
func (db *DB) Transaction(fn func() error) error { func (db *DB) Transaction(fn func() error) error {
conn, err := db.pool.Take(context.Background())
if err != nil {
return fmt.Errorf("failed to get connection: %w", err)
}
defer db.pool.Put(conn)
// Begin transaction // Begin transaction
if err := sqlitex.Execute(db.conn, "BEGIN", nil); err != nil { if err := sqlitex.Execute(conn, "BEGIN", nil); err != nil {
return err return err
} }
// Execute operations // Execute operations
err := fn() err = fn()
if err != nil { if err != nil {
// Rollback on error // Rollback on error
sqlitex.Execute(db.conn, "ROLLBACK", nil) sqlitex.Execute(conn, "ROLLBACK", nil)
return err return err
} }
// Commit on success // Commit on success
return sqlitex.Execute(db.conn, "COMMIT", nil) return sqlitex.Execute(conn, "COMMIT", nil)
} }
func convertPlaceholders(query string) (string, []string) { func convertPlaceholders(query string) (string, []string) {

View File

@ -1,6 +1,7 @@
package sashimi package sashimi
import ( import (
"context"
"fmt" "fmt"
"io/fs" "io/fs"
"os" "os"
@ -39,6 +40,12 @@ func NewMigrator(db *DB, dataDir string) *Migrator {
// ensureMigrationsTable creates the migrations tracking table if it doesn't exist // ensureMigrationsTable creates the migrations tracking table if it doesn't exist
func (m *Migrator) ensureMigrationsTable() error { func (m *Migrator) ensureMigrationsTable() error {
conn, err := m.db.pool.Take(context.Background())
if err != nil {
return fmt.Errorf("failed to get connection: %w", err)
}
defer m.db.pool.Put(conn)
query := ` query := `
CREATE TABLE IF NOT EXISTS migrations ( CREATE TABLE IF NOT EXISTS migrations (
number INTEGER PRIMARY KEY, number INTEGER PRIMARY KEY,
@ -47,7 +54,7 @@ func (m *Migrator) ensureMigrationsTable() error {
executed_at INTEGER NOT NULL executed_at INTEGER NOT NULL
) )
` `
return sqlitex.Execute(m.db.conn, query, nil) return sqlitex.Execute(conn, query, nil)
} }
// getExecutedMigrations returns a map of migration numbers that have been executed // getExecutedMigrations returns a map of migration numbers that have been executed
@ -155,11 +162,17 @@ func (m *Migrator) Run() error {
fmt.Printf("Running %d pending migrations...\n", len(pendingMigrations)) fmt.Printf("Running %d pending migrations...\n", len(pendingMigrations))
return m.db.Transaction(func() error { return m.db.Transaction(func() error {
conn, err := m.db.pool.Take(context.Background())
if err != nil {
return fmt.Errorf("failed to get connection: %w", err)
}
defer m.db.pool.Put(conn)
for _, migration := range pendingMigrations { for _, migration := range pendingMigrations {
fmt.Printf("Running migration %d: %s\n", migration.Number, migration.Name) fmt.Printf("Running migration %d: %s\n", migration.Number, migration.Name)
// Execute the migration SQL // Execute the migration SQL
if err := sqlitex.Execute(m.db.conn, migration.Content, nil); err != nil { if err := sqlitex.Execute(conn, migration.Content, nil); err != nil {
return fmt.Errorf("failed to execute migration %d (%s): %w", return fmt.Errorf("failed to execute migration %d (%s): %w",
migration.Number, migration.Name, err) migration.Number, migration.Name, err)
} }