add invisible connection pooling
This commit is contained in:
parent
344da424a0
commit
29546a2066
113
db.go
113
db.go
@ -1,6 +1,7 @@
|
||||
package sashimi
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"reflect"
|
||||
"regexp"
|
||||
@ -18,21 +19,25 @@ type Stmt = sqlite.Stmt
|
||||
var placeholderRegex = regexp.MustCompile(`%[sd]`)
|
||||
|
||||
type DB struct {
|
||||
conn *sqlite.Conn
|
||||
pool *sqlitex.Pool
|
||||
}
|
||||
|
||||
// New creates a new database wrapper instance
|
||||
// New creates a new database wrapper instance with connection pooling
|
||||
func New(dbPath string) (*DB, error) {
|
||||
conn, err := sqlite.OpenConn(dbPath, sqlite.OpenReadWrite|sqlite.OpenCreate)
|
||||
// Create connection pool with 10 connections max
|
||||
pool, err := sqlitex.NewPool(dbPath, sqlitex.PoolOptions{
|
||||
Flags: sqlite.OpenReadWrite | sqlite.OpenCreate,
|
||||
PoolSize: 10,
|
||||
})
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to open database: %w", err)
|
||||
return nil, fmt.Errorf("failed to create connection pool: %w", err)
|
||||
}
|
||||
|
||||
db := &DB{conn: conn}
|
||||
db := &DB{pool: pool}
|
||||
|
||||
// Configure database
|
||||
// Configure database using one connection from pool
|
||||
if err := db.configure(); err != nil {
|
||||
conn.Close()
|
||||
pool.Close()
|
||||
return nil, err
|
||||
}
|
||||
|
||||
@ -41,14 +46,21 @@ func New(dbPath string) (*DB, error) {
|
||||
|
||||
// configure sets up database pragmas
|
||||
func (db *DB) configure() error {
|
||||
conn, err := db.pool.Take(context.Background())
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to get connection: %w", err)
|
||||
}
|
||||
defer db.pool.Put(conn)
|
||||
|
||||
configs := []string{
|
||||
"PRAGMA journal_mode=WAL",
|
||||
"PRAGMA cache_size=-65536", // 64MB cache
|
||||
"PRAGMA foreign_keys=ON",
|
||||
"PRAGMA busy_timeout=5000", // 5 second timeout
|
||||
}
|
||||
|
||||
for _, config := range configs {
|
||||
if err := sqlitex.Execute(db.conn, config, nil); err != nil {
|
||||
if err := sqlitex.Execute(conn, config, nil); err != nil {
|
||||
return fmt.Errorf("failed to configure database: %w", err)
|
||||
}
|
||||
}
|
||||
@ -56,21 +68,21 @@ func (db *DB) configure() error {
|
||||
return nil
|
||||
}
|
||||
|
||||
// Close closes the database connection
|
||||
// Close closes the database connection pool
|
||||
func (db *DB) Close() error {
|
||||
if db.conn != nil {
|
||||
return db.conn.Close()
|
||||
if db.pool != nil {
|
||||
return db.pool.Close()
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// Conn returns the underlying sqlite connection
|
||||
func (db *DB) Conn() *sqlite.Conn {
|
||||
return db.conn
|
||||
// Pool returns the underlying connection pool
|
||||
func (db *DB) Pool() *sqlitex.Pool {
|
||||
return db.pool
|
||||
}
|
||||
|
||||
// Scan scans a SQLite statement result into a struct using field names
|
||||
func (db *DB) Scan(stmt *sqlite.Stmt, dest any) error {
|
||||
func (db *DB) Scan(stmt *pooledStmt, dest any) error {
|
||||
v := reflect.ValueOf(dest)
|
||||
if v.Kind() != reflect.Pointer || v.Elem().Kind() != reflect.Struct {
|
||||
return fmt.Errorf("dest must be a pointer to struct")
|
||||
@ -117,11 +129,17 @@ func (db *DB) Scan(stmt *sqlite.Stmt, dest any) error {
|
||||
}
|
||||
|
||||
// Query executes a query with fmt-style placeholders and automatically binds parameters
|
||||
func (db *DB) Query(query string, args ...any) (*sqlite.Stmt, error) {
|
||||
func (db *DB) Query(query string, args ...any) (*pooledStmt, error) {
|
||||
conn, err := db.pool.Take(context.Background())
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to get connection: %w", err)
|
||||
}
|
||||
|
||||
convertedQuery, paramTypes := convertPlaceholders(query)
|
||||
|
||||
stmt, err := db.conn.Prepare(convertedQuery)
|
||||
stmt, err := conn.Prepare(convertedQuery)
|
||||
if err != nil {
|
||||
db.pool.Put(conn)
|
||||
return nil, err
|
||||
}
|
||||
|
||||
@ -158,7 +176,26 @@ func (db *DB) Query(query string, args ...any) (*sqlite.Stmt, error) {
|
||||
}
|
||||
}
|
||||
|
||||
return stmt, nil
|
||||
// Create a wrapped statement that releases the connection when finalized
|
||||
return &pooledStmt{Stmt: stmt, pool: db.pool, conn: conn}, nil
|
||||
}
|
||||
|
||||
// pooledStmt wraps a statement to automatically release pool connections
|
||||
type pooledStmt struct {
|
||||
*sqlite.Stmt
|
||||
pool *sqlitex.Pool
|
||||
conn *sqlite.Conn
|
||||
finalized bool
|
||||
}
|
||||
|
||||
func (ps *pooledStmt) Finalize() error {
|
||||
if !ps.finalized {
|
||||
err := ps.Stmt.Finalize()
|
||||
ps.pool.Put(ps.conn)
|
||||
ps.finalized = true
|
||||
return err
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// Get executes a query and returns the first row
|
||||
@ -224,6 +261,12 @@ func (db *DB) Select(dest any, query string, args ...any) error {
|
||||
|
||||
// Exec executes a statement with fmt-style placeholders
|
||||
func (db *DB) Exec(query string, args ...any) error {
|
||||
conn, err := db.pool.Take(context.Background())
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to get connection: %w", err)
|
||||
}
|
||||
defer db.pool.Put(conn)
|
||||
|
||||
convertedQuery, paramTypes := convertPlaceholders(query)
|
||||
|
||||
sqlArgs := make([]any, len(args))
|
||||
@ -245,7 +288,7 @@ func (db *DB) Exec(query string, args ...any) error {
|
||||
}
|
||||
}
|
||||
|
||||
return sqlitex.Execute(db.conn, convertedQuery, &sqlitex.ExecOptions{
|
||||
return sqlitex.Execute(conn, convertedQuery, &sqlitex.ExecOptions{
|
||||
Args: sqlArgs,
|
||||
})
|
||||
}
|
||||
@ -256,6 +299,12 @@ func (db *DB) Update(tableName string, fields map[string]any, whereField string,
|
||||
return nil // No changes
|
||||
}
|
||||
|
||||
conn, err := db.pool.Take(context.Background())
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to get connection: %w", err)
|
||||
}
|
||||
defer db.pool.Put(conn)
|
||||
|
||||
// Build UPDATE query
|
||||
setParts := make([]string, 0, len(fields))
|
||||
args := make([]any, 0, len(fields)+1)
|
||||
@ -270,13 +319,19 @@ func (db *DB) Update(tableName string, fields map[string]any, whereField string,
|
||||
query := fmt.Sprintf("UPDATE %s SET %s WHERE %s = ?",
|
||||
tableName, strings.Join(setParts, ", "), whereField)
|
||||
|
||||
return sqlitex.Execute(db.conn, query, &sqlitex.ExecOptions{
|
||||
return sqlitex.Execute(conn, query, &sqlitex.ExecOptions{
|
||||
Args: args,
|
||||
})
|
||||
}
|
||||
|
||||
// Insert inserts a struct into the database
|
||||
func (db *DB) Insert(tableName string, obj any, excludeFields ...string) (int64, error) {
|
||||
conn, err := db.pool.Take(context.Background())
|
||||
if err != nil {
|
||||
return 0, fmt.Errorf("failed to get connection: %w", err)
|
||||
}
|
||||
defer db.pool.Put(conn)
|
||||
|
||||
v := reflect.ValueOf(obj)
|
||||
if v.Kind() == reflect.Pointer {
|
||||
v = v.Elem()
|
||||
@ -307,7 +362,7 @@ func (db *DB) Insert(tableName string, obj any, excludeFields ...string) (int64,
|
||||
query := fmt.Sprintf("INSERT INTO %s (%s) VALUES (%s)",
|
||||
tableName, strings.Join(columns, ", "), strings.Join(placeholders, ", "))
|
||||
|
||||
stmt, err := db.conn.Prepare(query)
|
||||
stmt, err := conn.Prepare(query)
|
||||
if err != nil {
|
||||
return 0, err
|
||||
}
|
||||
@ -332,27 +387,33 @@ func (db *DB) Insert(tableName string, obj any, excludeFields ...string) (int64,
|
||||
return 0, err
|
||||
}
|
||||
|
||||
return db.conn.LastInsertRowID(), nil
|
||||
return conn.LastInsertRowID(), nil
|
||||
}
|
||||
|
||||
// Transaction executes multiple operations atomically
|
||||
func (db *DB) Transaction(fn func() error) error {
|
||||
conn, err := db.pool.Take(context.Background())
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to get connection: %w", err)
|
||||
}
|
||||
defer db.pool.Put(conn)
|
||||
|
||||
// Begin transaction
|
||||
if err := sqlitex.Execute(db.conn, "BEGIN", nil); err != nil {
|
||||
if err := sqlitex.Execute(conn, "BEGIN", nil); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// Execute operations
|
||||
err := fn()
|
||||
err = fn()
|
||||
|
||||
if err != nil {
|
||||
// Rollback on error
|
||||
sqlitex.Execute(db.conn, "ROLLBACK", nil)
|
||||
sqlitex.Execute(conn, "ROLLBACK", nil)
|
||||
return err
|
||||
}
|
||||
|
||||
// Commit on success
|
||||
return sqlitex.Execute(db.conn, "COMMIT", nil)
|
||||
return sqlitex.Execute(conn, "COMMIT", nil)
|
||||
}
|
||||
|
||||
func convertPlaceholders(query string) (string, []string) {
|
||||
|
17
migrate.go
17
migrate.go
@ -1,6 +1,7 @@
|
||||
package sashimi
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"io/fs"
|
||||
"os"
|
||||
@ -39,6 +40,12 @@ func NewMigrator(db *DB, dataDir string) *Migrator {
|
||||
|
||||
// ensureMigrationsTable creates the migrations tracking table if it doesn't exist
|
||||
func (m *Migrator) ensureMigrationsTable() error {
|
||||
conn, err := m.db.pool.Take(context.Background())
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to get connection: %w", err)
|
||||
}
|
||||
defer m.db.pool.Put(conn)
|
||||
|
||||
query := `
|
||||
CREATE TABLE IF NOT EXISTS migrations (
|
||||
number INTEGER PRIMARY KEY,
|
||||
@ -47,7 +54,7 @@ func (m *Migrator) ensureMigrationsTable() error {
|
||||
executed_at INTEGER NOT NULL
|
||||
)
|
||||
`
|
||||
return sqlitex.Execute(m.db.conn, query, nil)
|
||||
return sqlitex.Execute(conn, query, nil)
|
||||
}
|
||||
|
||||
// getExecutedMigrations returns a map of migration numbers that have been executed
|
||||
@ -155,11 +162,17 @@ func (m *Migrator) Run() error {
|
||||
fmt.Printf("Running %d pending migrations...\n", len(pendingMigrations))
|
||||
|
||||
return m.db.Transaction(func() error {
|
||||
conn, err := m.db.pool.Take(context.Background())
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to get connection: %w", err)
|
||||
}
|
||||
defer m.db.pool.Put(conn)
|
||||
|
||||
for _, migration := range pendingMigrations {
|
||||
fmt.Printf("Running migration %d: %s\n", migration.Number, migration.Name)
|
||||
|
||||
// Execute the migration SQL
|
||||
if err := sqlitex.Execute(m.db.conn, migration.Content, nil); err != nil {
|
||||
if err := sqlitex.Execute(conn, migration.Content, nil); err != nil {
|
||||
return fmt.Errorf("failed to execute migration %d (%s): %w",
|
||||
migration.Number, migration.Name, err)
|
||||
}
|
||||
|
Loading…
x
Reference in New Issue
Block a user