add invisible connection pooling
This commit is contained in:
parent
344da424a0
commit
29546a2066
113
db.go
113
db.go
@ -1,6 +1,7 @@
|
|||||||
package sashimi
|
package sashimi
|
||||||
|
|
||||||
import (
|
import (
|
||||||
|
"context"
|
||||||
"fmt"
|
"fmt"
|
||||||
"reflect"
|
"reflect"
|
||||||
"regexp"
|
"regexp"
|
||||||
@ -18,21 +19,25 @@ type Stmt = sqlite.Stmt
|
|||||||
var placeholderRegex = regexp.MustCompile(`%[sd]`)
|
var placeholderRegex = regexp.MustCompile(`%[sd]`)
|
||||||
|
|
||||||
type DB struct {
|
type DB struct {
|
||||||
conn *sqlite.Conn
|
pool *sqlitex.Pool
|
||||||
}
|
}
|
||||||
|
|
||||||
// New creates a new database wrapper instance
|
// New creates a new database wrapper instance with connection pooling
|
||||||
func New(dbPath string) (*DB, error) {
|
func New(dbPath string) (*DB, error) {
|
||||||
conn, err := sqlite.OpenConn(dbPath, sqlite.OpenReadWrite|sqlite.OpenCreate)
|
// Create connection pool with 10 connections max
|
||||||
|
pool, err := sqlitex.NewPool(dbPath, sqlitex.PoolOptions{
|
||||||
|
Flags: sqlite.OpenReadWrite | sqlite.OpenCreate,
|
||||||
|
PoolSize: 10,
|
||||||
|
})
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, fmt.Errorf("failed to open database: %w", err)
|
return nil, fmt.Errorf("failed to create connection pool: %w", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
db := &DB{conn: conn}
|
db := &DB{pool: pool}
|
||||||
|
|
||||||
// Configure database
|
// Configure database using one connection from pool
|
||||||
if err := db.configure(); err != nil {
|
if err := db.configure(); err != nil {
|
||||||
conn.Close()
|
pool.Close()
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -41,14 +46,21 @@ func New(dbPath string) (*DB, error) {
|
|||||||
|
|
||||||
// configure sets up database pragmas
|
// configure sets up database pragmas
|
||||||
func (db *DB) configure() error {
|
func (db *DB) configure() error {
|
||||||
|
conn, err := db.pool.Take(context.Background())
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("failed to get connection: %w", err)
|
||||||
|
}
|
||||||
|
defer db.pool.Put(conn)
|
||||||
|
|
||||||
configs := []string{
|
configs := []string{
|
||||||
"PRAGMA journal_mode=WAL",
|
"PRAGMA journal_mode=WAL",
|
||||||
"PRAGMA cache_size=-65536", // 64MB cache
|
"PRAGMA cache_size=-65536", // 64MB cache
|
||||||
"PRAGMA foreign_keys=ON",
|
"PRAGMA foreign_keys=ON",
|
||||||
|
"PRAGMA busy_timeout=5000", // 5 second timeout
|
||||||
}
|
}
|
||||||
|
|
||||||
for _, config := range configs {
|
for _, config := range configs {
|
||||||
if err := sqlitex.Execute(db.conn, config, nil); err != nil {
|
if err := sqlitex.Execute(conn, config, nil); err != nil {
|
||||||
return fmt.Errorf("failed to configure database: %w", err)
|
return fmt.Errorf("failed to configure database: %w", err)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@ -56,21 +68,21 @@ func (db *DB) configure() error {
|
|||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
// Close closes the database connection
|
// Close closes the database connection pool
|
||||||
func (db *DB) Close() error {
|
func (db *DB) Close() error {
|
||||||
if db.conn != nil {
|
if db.pool != nil {
|
||||||
return db.conn.Close()
|
return db.pool.Close()
|
||||||
}
|
}
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
// Conn returns the underlying sqlite connection
|
// Pool returns the underlying connection pool
|
||||||
func (db *DB) Conn() *sqlite.Conn {
|
func (db *DB) Pool() *sqlitex.Pool {
|
||||||
return db.conn
|
return db.pool
|
||||||
}
|
}
|
||||||
|
|
||||||
// Scan scans a SQLite statement result into a struct using field names
|
// Scan scans a SQLite statement result into a struct using field names
|
||||||
func (db *DB) Scan(stmt *sqlite.Stmt, dest any) error {
|
func (db *DB) Scan(stmt *pooledStmt, dest any) error {
|
||||||
v := reflect.ValueOf(dest)
|
v := reflect.ValueOf(dest)
|
||||||
if v.Kind() != reflect.Pointer || v.Elem().Kind() != reflect.Struct {
|
if v.Kind() != reflect.Pointer || v.Elem().Kind() != reflect.Struct {
|
||||||
return fmt.Errorf("dest must be a pointer to struct")
|
return fmt.Errorf("dest must be a pointer to struct")
|
||||||
@ -117,11 +129,17 @@ func (db *DB) Scan(stmt *sqlite.Stmt, dest any) error {
|
|||||||
}
|
}
|
||||||
|
|
||||||
// Query executes a query with fmt-style placeholders and automatically binds parameters
|
// Query executes a query with fmt-style placeholders and automatically binds parameters
|
||||||
func (db *DB) Query(query string, args ...any) (*sqlite.Stmt, error) {
|
func (db *DB) Query(query string, args ...any) (*pooledStmt, error) {
|
||||||
|
conn, err := db.pool.Take(context.Background())
|
||||||
|
if err != nil {
|
||||||
|
return nil, fmt.Errorf("failed to get connection: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
convertedQuery, paramTypes := convertPlaceholders(query)
|
convertedQuery, paramTypes := convertPlaceholders(query)
|
||||||
|
|
||||||
stmt, err := db.conn.Prepare(convertedQuery)
|
stmt, err := conn.Prepare(convertedQuery)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
|
db.pool.Put(conn)
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -158,7 +176,26 @@ func (db *DB) Query(query string, args ...any) (*sqlite.Stmt, error) {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
return stmt, nil
|
// Create a wrapped statement that releases the connection when finalized
|
||||||
|
return &pooledStmt{Stmt: stmt, pool: db.pool, conn: conn}, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// pooledStmt wraps a statement to automatically release pool connections
|
||||||
|
type pooledStmt struct {
|
||||||
|
*sqlite.Stmt
|
||||||
|
pool *sqlitex.Pool
|
||||||
|
conn *sqlite.Conn
|
||||||
|
finalized bool
|
||||||
|
}
|
||||||
|
|
||||||
|
func (ps *pooledStmt) Finalize() error {
|
||||||
|
if !ps.finalized {
|
||||||
|
err := ps.Stmt.Finalize()
|
||||||
|
ps.pool.Put(ps.conn)
|
||||||
|
ps.finalized = true
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
// Get executes a query and returns the first row
|
// Get executes a query and returns the first row
|
||||||
@ -224,6 +261,12 @@ func (db *DB) Select(dest any, query string, args ...any) error {
|
|||||||
|
|
||||||
// Exec executes a statement with fmt-style placeholders
|
// Exec executes a statement with fmt-style placeholders
|
||||||
func (db *DB) Exec(query string, args ...any) error {
|
func (db *DB) Exec(query string, args ...any) error {
|
||||||
|
conn, err := db.pool.Take(context.Background())
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("failed to get connection: %w", err)
|
||||||
|
}
|
||||||
|
defer db.pool.Put(conn)
|
||||||
|
|
||||||
convertedQuery, paramTypes := convertPlaceholders(query)
|
convertedQuery, paramTypes := convertPlaceholders(query)
|
||||||
|
|
||||||
sqlArgs := make([]any, len(args))
|
sqlArgs := make([]any, len(args))
|
||||||
@ -245,7 +288,7 @@ func (db *DB) Exec(query string, args ...any) error {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
return sqlitex.Execute(db.conn, convertedQuery, &sqlitex.ExecOptions{
|
return sqlitex.Execute(conn, convertedQuery, &sqlitex.ExecOptions{
|
||||||
Args: sqlArgs,
|
Args: sqlArgs,
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
@ -256,6 +299,12 @@ func (db *DB) Update(tableName string, fields map[string]any, whereField string,
|
|||||||
return nil // No changes
|
return nil // No changes
|
||||||
}
|
}
|
||||||
|
|
||||||
|
conn, err := db.pool.Take(context.Background())
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("failed to get connection: %w", err)
|
||||||
|
}
|
||||||
|
defer db.pool.Put(conn)
|
||||||
|
|
||||||
// Build UPDATE query
|
// Build UPDATE query
|
||||||
setParts := make([]string, 0, len(fields))
|
setParts := make([]string, 0, len(fields))
|
||||||
args := make([]any, 0, len(fields)+1)
|
args := make([]any, 0, len(fields)+1)
|
||||||
@ -270,13 +319,19 @@ func (db *DB) Update(tableName string, fields map[string]any, whereField string,
|
|||||||
query := fmt.Sprintf("UPDATE %s SET %s WHERE %s = ?",
|
query := fmt.Sprintf("UPDATE %s SET %s WHERE %s = ?",
|
||||||
tableName, strings.Join(setParts, ", "), whereField)
|
tableName, strings.Join(setParts, ", "), whereField)
|
||||||
|
|
||||||
return sqlitex.Execute(db.conn, query, &sqlitex.ExecOptions{
|
return sqlitex.Execute(conn, query, &sqlitex.ExecOptions{
|
||||||
Args: args,
|
Args: args,
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
// Insert inserts a struct into the database
|
// Insert inserts a struct into the database
|
||||||
func (db *DB) Insert(tableName string, obj any, excludeFields ...string) (int64, error) {
|
func (db *DB) Insert(tableName string, obj any, excludeFields ...string) (int64, error) {
|
||||||
|
conn, err := db.pool.Take(context.Background())
|
||||||
|
if err != nil {
|
||||||
|
return 0, fmt.Errorf("failed to get connection: %w", err)
|
||||||
|
}
|
||||||
|
defer db.pool.Put(conn)
|
||||||
|
|
||||||
v := reflect.ValueOf(obj)
|
v := reflect.ValueOf(obj)
|
||||||
if v.Kind() == reflect.Pointer {
|
if v.Kind() == reflect.Pointer {
|
||||||
v = v.Elem()
|
v = v.Elem()
|
||||||
@ -307,7 +362,7 @@ func (db *DB) Insert(tableName string, obj any, excludeFields ...string) (int64,
|
|||||||
query := fmt.Sprintf("INSERT INTO %s (%s) VALUES (%s)",
|
query := fmt.Sprintf("INSERT INTO %s (%s) VALUES (%s)",
|
||||||
tableName, strings.Join(columns, ", "), strings.Join(placeholders, ", "))
|
tableName, strings.Join(columns, ", "), strings.Join(placeholders, ", "))
|
||||||
|
|
||||||
stmt, err := db.conn.Prepare(query)
|
stmt, err := conn.Prepare(query)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return 0, err
|
return 0, err
|
||||||
}
|
}
|
||||||
@ -332,27 +387,33 @@ func (db *DB) Insert(tableName string, obj any, excludeFields ...string) (int64,
|
|||||||
return 0, err
|
return 0, err
|
||||||
}
|
}
|
||||||
|
|
||||||
return db.conn.LastInsertRowID(), nil
|
return conn.LastInsertRowID(), nil
|
||||||
}
|
}
|
||||||
|
|
||||||
// Transaction executes multiple operations atomically
|
// Transaction executes multiple operations atomically
|
||||||
func (db *DB) Transaction(fn func() error) error {
|
func (db *DB) Transaction(fn func() error) error {
|
||||||
|
conn, err := db.pool.Take(context.Background())
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("failed to get connection: %w", err)
|
||||||
|
}
|
||||||
|
defer db.pool.Put(conn)
|
||||||
|
|
||||||
// Begin transaction
|
// Begin transaction
|
||||||
if err := sqlitex.Execute(db.conn, "BEGIN", nil); err != nil {
|
if err := sqlitex.Execute(conn, "BEGIN", nil); err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
// Execute operations
|
// Execute operations
|
||||||
err := fn()
|
err = fn()
|
||||||
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
// Rollback on error
|
// Rollback on error
|
||||||
sqlitex.Execute(db.conn, "ROLLBACK", nil)
|
sqlitex.Execute(conn, "ROLLBACK", nil)
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
// Commit on success
|
// Commit on success
|
||||||
return sqlitex.Execute(db.conn, "COMMIT", nil)
|
return sqlitex.Execute(conn, "COMMIT", nil)
|
||||||
}
|
}
|
||||||
|
|
||||||
func convertPlaceholders(query string) (string, []string) {
|
func convertPlaceholders(query string) (string, []string) {
|
||||||
|
17
migrate.go
17
migrate.go
@ -1,6 +1,7 @@
|
|||||||
package sashimi
|
package sashimi
|
||||||
|
|
||||||
import (
|
import (
|
||||||
|
"context"
|
||||||
"fmt"
|
"fmt"
|
||||||
"io/fs"
|
"io/fs"
|
||||||
"os"
|
"os"
|
||||||
@ -39,6 +40,12 @@ func NewMigrator(db *DB, dataDir string) *Migrator {
|
|||||||
|
|
||||||
// ensureMigrationsTable creates the migrations tracking table if it doesn't exist
|
// ensureMigrationsTable creates the migrations tracking table if it doesn't exist
|
||||||
func (m *Migrator) ensureMigrationsTable() error {
|
func (m *Migrator) ensureMigrationsTable() error {
|
||||||
|
conn, err := m.db.pool.Take(context.Background())
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("failed to get connection: %w", err)
|
||||||
|
}
|
||||||
|
defer m.db.pool.Put(conn)
|
||||||
|
|
||||||
query := `
|
query := `
|
||||||
CREATE TABLE IF NOT EXISTS migrations (
|
CREATE TABLE IF NOT EXISTS migrations (
|
||||||
number INTEGER PRIMARY KEY,
|
number INTEGER PRIMARY KEY,
|
||||||
@ -47,7 +54,7 @@ func (m *Migrator) ensureMigrationsTable() error {
|
|||||||
executed_at INTEGER NOT NULL
|
executed_at INTEGER NOT NULL
|
||||||
)
|
)
|
||||||
`
|
`
|
||||||
return sqlitex.Execute(m.db.conn, query, nil)
|
return sqlitex.Execute(conn, query, nil)
|
||||||
}
|
}
|
||||||
|
|
||||||
// getExecutedMigrations returns a map of migration numbers that have been executed
|
// getExecutedMigrations returns a map of migration numbers that have been executed
|
||||||
@ -155,11 +162,17 @@ func (m *Migrator) Run() error {
|
|||||||
fmt.Printf("Running %d pending migrations...\n", len(pendingMigrations))
|
fmt.Printf("Running %d pending migrations...\n", len(pendingMigrations))
|
||||||
|
|
||||||
return m.db.Transaction(func() error {
|
return m.db.Transaction(func() error {
|
||||||
|
conn, err := m.db.pool.Take(context.Background())
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("failed to get connection: %w", err)
|
||||||
|
}
|
||||||
|
defer m.db.pool.Put(conn)
|
||||||
|
|
||||||
for _, migration := range pendingMigrations {
|
for _, migration := range pendingMigrations {
|
||||||
fmt.Printf("Running migration %d: %s\n", migration.Number, migration.Name)
|
fmt.Printf("Running migration %d: %s\n", migration.Number, migration.Name)
|
||||||
|
|
||||||
// Execute the migration SQL
|
// Execute the migration SQL
|
||||||
if err := sqlitex.Execute(m.db.conn, migration.Content, nil); err != nil {
|
if err := sqlitex.Execute(conn, migration.Content, nil); err != nil {
|
||||||
return fmt.Errorf("failed to execute migration %d (%s): %w",
|
return fmt.Errorf("failed to execute migration %d (%s): %w",
|
||||||
migration.Number, migration.Name, err)
|
migration.Number, migration.Name, err)
|
||||||
}
|
}
|
||||||
|
Loading…
x
Reference in New Issue
Block a user