add invisible connection pooling

This commit is contained in:
Sky Johnson 2025-08-22 15:48:27 -05:00
parent 344da424a0
commit 29546a2066
2 changed files with 102 additions and 28 deletions

113
db.go
View File

@ -1,6 +1,7 @@
package sashimi
import (
"context"
"fmt"
"reflect"
"regexp"
@ -18,21 +19,25 @@ type Stmt = sqlite.Stmt
var placeholderRegex = regexp.MustCompile(`%[sd]`)
type DB struct {
conn *sqlite.Conn
pool *sqlitex.Pool
}
// New creates a new database wrapper instance
// New creates a new database wrapper instance with connection pooling
func New(dbPath string) (*DB, error) {
conn, err := sqlite.OpenConn(dbPath, sqlite.OpenReadWrite|sqlite.OpenCreate)
// Create connection pool with 10 connections max
pool, err := sqlitex.NewPool(dbPath, sqlitex.PoolOptions{
Flags: sqlite.OpenReadWrite | sqlite.OpenCreate,
PoolSize: 10,
})
if err != nil {
return nil, fmt.Errorf("failed to open database: %w", err)
return nil, fmt.Errorf("failed to create connection pool: %w", err)
}
db := &DB{conn: conn}
db := &DB{pool: pool}
// Configure database
// Configure database using one connection from pool
if err := db.configure(); err != nil {
conn.Close()
pool.Close()
return nil, err
}
@ -41,14 +46,21 @@ func New(dbPath string) (*DB, error) {
// configure sets up database pragmas
func (db *DB) configure() error {
conn, err := db.pool.Take(context.Background())
if err != nil {
return fmt.Errorf("failed to get connection: %w", err)
}
defer db.pool.Put(conn)
configs := []string{
"PRAGMA journal_mode=WAL",
"PRAGMA cache_size=-65536", // 64MB cache
"PRAGMA foreign_keys=ON",
"PRAGMA busy_timeout=5000", // 5 second timeout
}
for _, config := range configs {
if err := sqlitex.Execute(db.conn, config, nil); err != nil {
if err := sqlitex.Execute(conn, config, nil); err != nil {
return fmt.Errorf("failed to configure database: %w", err)
}
}
@ -56,21 +68,21 @@ func (db *DB) configure() error {
return nil
}
// Close closes the database connection
// Close closes the database connection pool
func (db *DB) Close() error {
if db.conn != nil {
return db.conn.Close()
if db.pool != nil {
return db.pool.Close()
}
return nil
}
// Conn returns the underlying sqlite connection
func (db *DB) Conn() *sqlite.Conn {
return db.conn
// Pool returns the underlying connection pool
func (db *DB) Pool() *sqlitex.Pool {
return db.pool
}
// Scan scans a SQLite statement result into a struct using field names
func (db *DB) Scan(stmt *sqlite.Stmt, dest any) error {
func (db *DB) Scan(stmt *pooledStmt, dest any) error {
v := reflect.ValueOf(dest)
if v.Kind() != reflect.Pointer || v.Elem().Kind() != reflect.Struct {
return fmt.Errorf("dest must be a pointer to struct")
@ -117,11 +129,17 @@ func (db *DB) Scan(stmt *sqlite.Stmt, dest any) error {
}
// Query executes a query with fmt-style placeholders and automatically binds parameters
func (db *DB) Query(query string, args ...any) (*sqlite.Stmt, error) {
func (db *DB) Query(query string, args ...any) (*pooledStmt, error) {
conn, err := db.pool.Take(context.Background())
if err != nil {
return nil, fmt.Errorf("failed to get connection: %w", err)
}
convertedQuery, paramTypes := convertPlaceholders(query)
stmt, err := db.conn.Prepare(convertedQuery)
stmt, err := conn.Prepare(convertedQuery)
if err != nil {
db.pool.Put(conn)
return nil, err
}
@ -158,7 +176,26 @@ func (db *DB) Query(query string, args ...any) (*sqlite.Stmt, error) {
}
}
return stmt, nil
// Create a wrapped statement that releases the connection when finalized
return &pooledStmt{Stmt: stmt, pool: db.pool, conn: conn}, nil
}
// pooledStmt wraps a statement to automatically release pool connections
type pooledStmt struct {
*sqlite.Stmt
pool *sqlitex.Pool
conn *sqlite.Conn
finalized bool
}
func (ps *pooledStmt) Finalize() error {
if !ps.finalized {
err := ps.Stmt.Finalize()
ps.pool.Put(ps.conn)
ps.finalized = true
return err
}
return nil
}
// Get executes a query and returns the first row
@ -224,6 +261,12 @@ func (db *DB) Select(dest any, query string, args ...any) error {
// Exec executes a statement with fmt-style placeholders
func (db *DB) Exec(query string, args ...any) error {
conn, err := db.pool.Take(context.Background())
if err != nil {
return fmt.Errorf("failed to get connection: %w", err)
}
defer db.pool.Put(conn)
convertedQuery, paramTypes := convertPlaceholders(query)
sqlArgs := make([]any, len(args))
@ -245,7 +288,7 @@ func (db *DB) Exec(query string, args ...any) error {
}
}
return sqlitex.Execute(db.conn, convertedQuery, &sqlitex.ExecOptions{
return sqlitex.Execute(conn, convertedQuery, &sqlitex.ExecOptions{
Args: sqlArgs,
})
}
@ -256,6 +299,12 @@ func (db *DB) Update(tableName string, fields map[string]any, whereField string,
return nil // No changes
}
conn, err := db.pool.Take(context.Background())
if err != nil {
return fmt.Errorf("failed to get connection: %w", err)
}
defer db.pool.Put(conn)
// Build UPDATE query
setParts := make([]string, 0, len(fields))
args := make([]any, 0, len(fields)+1)
@ -270,13 +319,19 @@ func (db *DB) Update(tableName string, fields map[string]any, whereField string,
query := fmt.Sprintf("UPDATE %s SET %s WHERE %s = ?",
tableName, strings.Join(setParts, ", "), whereField)
return sqlitex.Execute(db.conn, query, &sqlitex.ExecOptions{
return sqlitex.Execute(conn, query, &sqlitex.ExecOptions{
Args: args,
})
}
// Insert inserts a struct into the database
func (db *DB) Insert(tableName string, obj any, excludeFields ...string) (int64, error) {
conn, err := db.pool.Take(context.Background())
if err != nil {
return 0, fmt.Errorf("failed to get connection: %w", err)
}
defer db.pool.Put(conn)
v := reflect.ValueOf(obj)
if v.Kind() == reflect.Pointer {
v = v.Elem()
@ -307,7 +362,7 @@ func (db *DB) Insert(tableName string, obj any, excludeFields ...string) (int64,
query := fmt.Sprintf("INSERT INTO %s (%s) VALUES (%s)",
tableName, strings.Join(columns, ", "), strings.Join(placeholders, ", "))
stmt, err := db.conn.Prepare(query)
stmt, err := conn.Prepare(query)
if err != nil {
return 0, err
}
@ -332,27 +387,33 @@ func (db *DB) Insert(tableName string, obj any, excludeFields ...string) (int64,
return 0, err
}
return db.conn.LastInsertRowID(), nil
return conn.LastInsertRowID(), nil
}
// Transaction executes multiple operations atomically
func (db *DB) Transaction(fn func() error) error {
conn, err := db.pool.Take(context.Background())
if err != nil {
return fmt.Errorf("failed to get connection: %w", err)
}
defer db.pool.Put(conn)
// Begin transaction
if err := sqlitex.Execute(db.conn, "BEGIN", nil); err != nil {
if err := sqlitex.Execute(conn, "BEGIN", nil); err != nil {
return err
}
// Execute operations
err := fn()
err = fn()
if err != nil {
// Rollback on error
sqlitex.Execute(db.conn, "ROLLBACK", nil)
sqlitex.Execute(conn, "ROLLBACK", nil)
return err
}
// Commit on success
return sqlitex.Execute(db.conn, "COMMIT", nil)
return sqlitex.Execute(conn, "COMMIT", nil)
}
func convertPlaceholders(query string) (string, []string) {

View File

@ -1,6 +1,7 @@
package sashimi
import (
"context"
"fmt"
"io/fs"
"os"
@ -39,6 +40,12 @@ func NewMigrator(db *DB, dataDir string) *Migrator {
// ensureMigrationsTable creates the migrations tracking table if it doesn't exist
func (m *Migrator) ensureMigrationsTable() error {
conn, err := m.db.pool.Take(context.Background())
if err != nil {
return fmt.Errorf("failed to get connection: %w", err)
}
defer m.db.pool.Put(conn)
query := `
CREATE TABLE IF NOT EXISTS migrations (
number INTEGER PRIMARY KEY,
@ -47,7 +54,7 @@ func (m *Migrator) ensureMigrationsTable() error {
executed_at INTEGER NOT NULL
)
`
return sqlitex.Execute(m.db.conn, query, nil)
return sqlitex.Execute(conn, query, nil)
}
// getExecutedMigrations returns a map of migration numbers that have been executed
@ -155,11 +162,17 @@ func (m *Migrator) Run() error {
fmt.Printf("Running %d pending migrations...\n", len(pendingMigrations))
return m.db.Transaction(func() error {
conn, err := m.db.pool.Take(context.Background())
if err != nil {
return fmt.Errorf("failed to get connection: %w", err)
}
defer m.db.pool.Put(conn)
for _, migration := range pendingMigrations {
fmt.Printf("Running migration %d: %s\n", migration.Number, migration.Name)
// Execute the migration SQL
if err := sqlitex.Execute(m.db.conn, migration.Content, nil); err != nil {
if err := sqlitex.Execute(conn, migration.Content, nil); err != nil {
return fmt.Errorf("failed to execute migration %d (%s): %w",
migration.Number, migration.Name, err)
}