Compare commits

..

5 Commits

3 changed files with 214 additions and 42 deletions

199
db.go
View File

@ -1,6 +1,7 @@
package sashimi package sashimi
import ( import (
"context"
"fmt" "fmt"
"reflect" "reflect"
"regexp" "regexp"
@ -18,21 +19,25 @@ type Stmt = sqlite.Stmt
var placeholderRegex = regexp.MustCompile(`%[sd]`) var placeholderRegex = regexp.MustCompile(`%[sd]`)
type DB struct { type DB struct {
conn *sqlite.Conn pool *sqlitex.Pool
} }
// New creates a new database wrapper instance // New creates a new database wrapper instance with connection pooling
func New(dbPath string) (*DB, error) { func New(dbPath string) (*DB, error) {
conn, err := sqlite.OpenConn(dbPath, sqlite.OpenReadWrite|sqlite.OpenCreate) // Create connection pool with 10 connections max
pool, err := sqlitex.NewPool(dbPath, sqlitex.PoolOptions{
Flags: sqlite.OpenReadWrite | sqlite.OpenCreate,
PoolSize: 10,
})
if err != nil { if err != nil {
return nil, fmt.Errorf("failed to open database: %w", err) return nil, fmt.Errorf("failed to create connection pool: %w", err)
} }
db := &DB{conn: conn} db := &DB{pool: pool}
// Configure database // Configure database using one connection from pool
if err := db.configure(); err != nil { if err := db.configure(); err != nil {
conn.Close() pool.Close()
return nil, err return nil, err
} }
@ -41,14 +46,21 @@ func New(dbPath string) (*DB, error) {
// configure sets up database pragmas // configure sets up database pragmas
func (db *DB) configure() error { func (db *DB) configure() error {
conn, err := db.pool.Take(context.Background())
if err != nil {
return fmt.Errorf("failed to get connection: %w", err)
}
defer db.pool.Put(conn)
configs := []string{ configs := []string{
"PRAGMA journal_mode=WAL", "PRAGMA journal_mode=WAL",
"PRAGMA cache_size=-65536", // 64MB cache "PRAGMA cache_size=-65536", // 64MB cache
"PRAGMA foreign_keys=ON", "PRAGMA foreign_keys=ON",
"PRAGMA busy_timeout=5000", // 5 second timeout
} }
for _, config := range configs { for _, config := range configs {
if err := sqlitex.Execute(db.conn, config, nil); err != nil { if err := sqlitex.Execute(conn, config, nil); err != nil {
return fmt.Errorf("failed to configure database: %w", err) return fmt.Errorf("failed to configure database: %w", err)
} }
} }
@ -56,21 +68,21 @@ func (db *DB) configure() error {
return nil return nil
} }
// Close closes the database connection // Close closes the database connection pool
func (db *DB) Close() error { func (db *DB) Close() error {
if db.conn != nil { if db.pool != nil {
return db.conn.Close() return db.pool.Close()
} }
return nil return nil
} }
// Conn returns the underlying sqlite connection // Pool returns the underlying connection pool
func (db *DB) Conn() *sqlite.Conn { func (db *DB) Pool() *sqlitex.Pool {
return db.conn return db.pool
} }
// Scan scans a SQLite statement result into a struct using field names // Scan scans a SQLite statement result into a struct using field names
func (db *DB) Scan(stmt *sqlite.Stmt, dest any) error { func (db *DB) Scan(stmt *PooledStmt, dest any) error {
v := reflect.ValueOf(dest) v := reflect.ValueOf(dest)
if v.Kind() != reflect.Pointer || v.Elem().Kind() != reflect.Struct { if v.Kind() != reflect.Pointer || v.Elem().Kind() != reflect.Struct {
return fmt.Errorf("dest must be a pointer to struct") return fmt.Errorf("dest must be a pointer to struct")
@ -117,11 +129,17 @@ func (db *DB) Scan(stmt *sqlite.Stmt, dest any) error {
} }
// Query executes a query with fmt-style placeholders and automatically binds parameters // Query executes a query with fmt-style placeholders and automatically binds parameters
func (db *DB) Query(query string, args ...any) (*sqlite.Stmt, error) { func (db *DB) Query(query string, args ...any) (*PooledStmt, error) {
conn, err := db.pool.Take(context.Background())
if err != nil {
return nil, fmt.Errorf("failed to get connection: %w", err)
}
convertedQuery, paramTypes := convertPlaceholders(query) convertedQuery, paramTypes := convertPlaceholders(query)
stmt, err := db.conn.Prepare(convertedQuery) stmt, err := conn.Prepare(convertedQuery)
if err != nil { if err != nil {
db.pool.Put(conn)
return nil, err return nil, err
} }
@ -158,7 +176,26 @@ func (db *DB) Query(query string, args ...any) (*sqlite.Stmt, error) {
} }
} }
return stmt, nil // Create a wrapped statement that releases the connection when finalized
return &PooledStmt{Stmt: stmt, pool: db.pool, conn: conn}, nil
}
// PooledStmt wraps a statement to automatically release pool connections
type PooledStmt struct {
*sqlite.Stmt
pool *sqlitex.Pool
conn *sqlite.Conn
finalized bool
}
func (ps *PooledStmt) Finalize() error {
if !ps.finalized {
err := ps.Stmt.Finalize()
ps.pool.Put(ps.conn)
ps.finalized = true
return err
}
return nil
} }
// Get executes a query and returns the first row // Get executes a query and returns the first row
@ -177,7 +214,7 @@ func (db *DB) Get(dest any, query string, args ...any) error {
return fmt.Errorf("no rows found") return fmt.Errorf("no rows found")
} }
return db.Scan(stmt, dest) return db.scanValue(stmt, dest)
} }
// Select executes a query and scans all rows into a slice // Select executes a query and scans all rows into a slice
@ -224,6 +261,12 @@ func (db *DB) Select(dest any, query string, args ...any) error {
// Exec executes a statement with fmt-style placeholders // Exec executes a statement with fmt-style placeholders
func (db *DB) Exec(query string, args ...any) error { func (db *DB) Exec(query string, args ...any) error {
conn, err := db.pool.Take(context.Background())
if err != nil {
return fmt.Errorf("failed to get connection: %w", err)
}
defer db.pool.Put(conn)
convertedQuery, paramTypes := convertPlaceholders(query) convertedQuery, paramTypes := convertPlaceholders(query)
sqlArgs := make([]any, len(args)) sqlArgs := make([]any, len(args))
@ -245,7 +288,7 @@ func (db *DB) Exec(query string, args ...any) error {
} }
} }
return sqlitex.Execute(db.conn, convertedQuery, &sqlitex.ExecOptions{ return sqlitex.Execute(conn, convertedQuery, &sqlitex.ExecOptions{
Args: sqlArgs, Args: sqlArgs,
}) })
} }
@ -256,6 +299,12 @@ func (db *DB) Update(tableName string, fields map[string]any, whereField string,
return nil // No changes return nil // No changes
} }
conn, err := db.pool.Take(context.Background())
if err != nil {
return fmt.Errorf("failed to get connection: %w", err)
}
defer db.pool.Put(conn)
// Build UPDATE query // Build UPDATE query
setParts := make([]string, 0, len(fields)) setParts := make([]string, 0, len(fields))
args := make([]any, 0, len(fields)+1) args := make([]any, 0, len(fields)+1)
@ -270,18 +319,18 @@ func (db *DB) Update(tableName string, fields map[string]any, whereField string,
query := fmt.Sprintf("UPDATE %s SET %s WHERE %s = ?", query := fmt.Sprintf("UPDATE %s SET %s WHERE %s = ?",
tableName, strings.Join(setParts, ", "), whereField) tableName, strings.Join(setParts, ", "), whereField)
return sqlitex.Execute(db.conn, query, &sqlitex.ExecOptions{ return sqlitex.Execute(conn, query, &sqlitex.ExecOptions{
Args: args, Args: args,
}) })
} }
// Insert inserts a struct into the database // Insert inserts a struct or map into the database
func (db *DB) Insert(tableName string, obj any, excludeFields ...string) (int64, error) { func (db *DB) Insert(tableName string, obj any, excludeFields ...string) (int64, error) {
v := reflect.ValueOf(obj) conn, err := db.pool.Take(context.Background())
if v.Kind() == reflect.Pointer { if err != nil {
v = v.Elem() return 0, fmt.Errorf("failed to get connection: %w", err)
} }
t := v.Type() defer db.pool.Put(conn)
exclude := make(map[string]bool) exclude := make(map[string]bool)
for _, field := range excludeFields { for _, field := range excludeFields {
@ -292,22 +341,47 @@ func (db *DB) Insert(tableName string, obj any, excludeFields ...string) (int64,
var placeholders []string var placeholders []string
var args []any var args []any
v := reflect.ValueOf(obj)
if v.Kind() == reflect.Pointer {
v = v.Elem()
}
switch v.Kind() {
case reflect.Map:
// Handle map[string]any
m := obj.(map[string]any)
for key, value := range m {
columnName := toSnakeCase(key)
if exclude[columnName] {
continue
}
columns = append(columns, columnName)
placeholders = append(placeholders, "?")
args = append(args, value)
}
case reflect.Struct:
// Handle struct
t := v.Type()
for i := 0; i < t.NumField(); i++ { for i := 0; i < t.NumField(); i++ {
field := t.Field(i) field := t.Field(i)
columnName := toSnakeCase(field.Name) columnName := toSnakeCase(field.Name)
if exclude[columnName] { if exclude[columnName] {
continue continue
} }
columns = append(columns, columnName) columns = append(columns, columnName)
placeholders = append(placeholders, "?") placeholders = append(placeholders, "?")
args = append(args, v.Field(i).Interface()) args = append(args, v.Field(i).Interface())
} }
default:
return 0, fmt.Errorf("obj must be a struct, pointer to struct, or map[string]any")
}
query := fmt.Sprintf("INSERT INTO %s (%s) VALUES (%s)", query := fmt.Sprintf("INSERT INTO %s (%s) VALUES (%s)",
tableName, strings.Join(columns, ", "), strings.Join(placeholders, ", ")) tableName, strings.Join(columns, ", "), strings.Join(placeholders, ", "))
stmt, err := db.conn.Prepare(query) stmt, err := conn.Prepare(query)
if err != nil { if err != nil {
return 0, err return 0, err
} }
@ -332,27 +406,86 @@ func (db *DB) Insert(tableName string, obj any, excludeFields ...string) (int64,
return 0, err return 0, err
} }
return db.conn.LastInsertRowID(), nil return conn.LastInsertRowID(), nil
} }
// Transaction executes multiple operations atomically // Transaction executes multiple operations atomically
func (db *DB) Transaction(fn func() error) error { func (db *DB) Transaction(fn func() error) error {
conn, err := db.pool.Take(context.Background())
if err != nil {
return fmt.Errorf("failed to get connection: %w", err)
}
defer db.pool.Put(conn)
// Begin transaction // Begin transaction
if err := sqlitex.Execute(db.conn, "BEGIN", nil); err != nil { if err := sqlitex.Execute(conn, "BEGIN", nil); err != nil {
return err return err
} }
// Execute operations // Execute operations
err := fn() err = fn()
if err != nil { if err != nil {
// Rollback on error // Rollback on error
sqlitex.Execute(db.conn, "ROLLBACK", nil) sqlitex.Execute(conn, "ROLLBACK", nil)
return err return err
} }
// Commit on success // Commit on success
return sqlitex.Execute(db.conn, "COMMIT", nil) return sqlitex.Execute(conn, "COMMIT", nil)
}
// scanValue scans a statement result into either a struct or primitive type
func (db *DB) scanValue(stmt *PooledStmt, dest any) error {
v := reflect.ValueOf(dest)
if v.Kind() != reflect.Pointer {
return fmt.Errorf("dest must be a pointer")
}
elem := v.Elem()
// Handle primitive types
if isPrimitiveType(elem.Kind()) {
if stmt.ColumnCount() == 0 {
return fmt.Errorf("no columns in result")
}
return scanPrimitive(stmt, elem, 0)
}
// Handle struct types
if elem.Kind() != reflect.Struct {
return fmt.Errorf("dest must be a pointer to struct or primitive type")
}
return db.Scan(stmt, dest)
}
// isPrimitiveType checks if a reflect.Kind represents a primitive type
func isPrimitiveType(k reflect.Kind) bool {
switch k {
case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64,
reflect.String, reflect.Float32, reflect.Float64, reflect.Bool:
return true
}
return false
}
// scanPrimitive scans a column value into a primitive type
func scanPrimitive(stmt *PooledStmt, fieldValue reflect.Value, colIndex int) error {
switch fieldValue.Kind() {
case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64:
fieldValue.SetInt(stmt.ColumnInt64(colIndex))
case reflect.String:
fieldValue.SetString(stmt.ColumnText(colIndex))
case reflect.Float32, reflect.Float64:
fieldValue.SetFloat(stmt.ColumnFloat(colIndex))
case reflect.Bool:
fieldValue.SetBool(stmt.ColumnInt(colIndex) != 0)
default:
return fmt.Errorf("unsupported type: %v", fieldValue.Kind())
}
return nil
} }
func convertPlaceholders(query string) (string, []string) { func convertPlaceholders(query string) (string, []string) {

26
go.sum
View File

@ -1,5 +1,7 @@
github.com/dustin/go-humanize v1.0.1 h1:GzkhY7T5VNhEkwH0PVJgjz+fX1rhBrR7pRT3mDkpeCY= github.com/dustin/go-humanize v1.0.1 h1:GzkhY7T5VNhEkwH0PVJgjz+fX1rhBrR7pRT3mDkpeCY=
github.com/dustin/go-humanize v1.0.1/go.mod h1:Mu1zIs6XwVuF/gI1OepvI0qD18qycQx+mFykh5fBlto= github.com/dustin/go-humanize v1.0.1/go.mod h1:Mu1zIs6XwVuF/gI1OepvI0qD18qycQx+mFykh5fBlto=
github.com/google/go-cmp v0.6.0 h1:ofyhxvXcZhMsU5ulbFiLKl/XBFqE1GSq7atu8tAmTRI=
github.com/google/go-cmp v0.6.0/go.mod h1:17dUlkBOakJ0+DkrSSNjCkIjxS6bF9zb3elmeNGIjoY=
github.com/google/uuid v1.6.0 h1:NIvaJDMOsjHA8n1jAhLSgzrAzy1Hgr+hNrb57e+94F0= github.com/google/uuid v1.6.0 h1:NIvaJDMOsjHA8n1jAhLSgzrAzy1Hgr+hNrb57e+94F0=
github.com/google/uuid v1.6.0/go.mod h1:TIyPZe4MgqvfeYDBFedMoGGpEw/LqOeaOT+nhxU+yHo= github.com/google/uuid v1.6.0/go.mod h1:TIyPZe4MgqvfeYDBFedMoGGpEw/LqOeaOT+nhxU+yHo=
github.com/mattn/go-isatty v0.0.20 h1:xfD0iDuEKnDkl03q4limB+vH+GxLEtL/jb4xVJSWWEY= github.com/mattn/go-isatty v0.0.20 h1:xfD0iDuEKnDkl03q4limB+vH+GxLEtL/jb4xVJSWWEY=
@ -10,16 +12,40 @@ github.com/remyoudompheng/bigfft v0.0.0-20230129092748-24d4a6f8daec h1:W09IVJc94
github.com/remyoudompheng/bigfft v0.0.0-20230129092748-24d4a6f8daec/go.mod h1:qqbHyh8v60DhA7CoWK5oRCqLrMHRGoxYCSS9EjAz6Eo= github.com/remyoudompheng/bigfft v0.0.0-20230129092748-24d4a6f8daec/go.mod h1:qqbHyh8v60DhA7CoWK5oRCqLrMHRGoxYCSS9EjAz6Eo=
golang.org/x/exp v0.0.0-20250408133849-7e4ce0ab07d0 h1:R84qjqJb5nVJMxqWYb3np9L5ZsaDtB+a39EqjV0JSUM= golang.org/x/exp v0.0.0-20250408133849-7e4ce0ab07d0 h1:R84qjqJb5nVJMxqWYb3np9L5ZsaDtB+a39EqjV0JSUM=
golang.org/x/exp v0.0.0-20250408133849-7e4ce0ab07d0/go.mod h1:S9Xr4PYopiDyqSyp5NjCrhFrqg6A5zA2E/iPHPhqnS8= golang.org/x/exp v0.0.0-20250408133849-7e4ce0ab07d0/go.mod h1:S9Xr4PYopiDyqSyp5NjCrhFrqg6A5zA2E/iPHPhqnS8=
golang.org/x/mod v0.24.0 h1:ZfthKaKaT4NrhGVZHO1/WDTwGES4De8KtWO0SIbNJMU=
golang.org/x/mod v0.24.0/go.mod h1:IXM97Txy2VM4PJ3gI61r1YEk/gAj6zAHN3AdZt6S9Ww=
golang.org/x/sync v0.14.0 h1:woo0S4Yywslg6hp4eUFjTVOyKt0RookbpAHG4c1HmhQ=
golang.org/x/sync v0.14.0/go.mod h1:1dzgHSNfp02xaA81J2MS99Qcpr2w7fw1gpm99rleRqA=
golang.org/x/sys v0.6.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/sys v0.6.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
golang.org/x/sys v0.33.0 h1:q3i8TbbEz+JRD9ywIRlyRAQbM0qF7hu24q3teo2hbuw= golang.org/x/sys v0.33.0 h1:q3i8TbbEz+JRD9ywIRlyRAQbM0qF7hu24q3teo2hbuw=
golang.org/x/sys v0.33.0/go.mod h1:BJP2sWEmIv4KK5OTEluFJCKSidICx8ciO85XgH3Ak8k= golang.org/x/sys v0.33.0/go.mod h1:BJP2sWEmIv4KK5OTEluFJCKSidICx8ciO85XgH3Ak8k=
golang.org/x/text v0.14.0 h1:ScX5w1eTa3QqT8oi6+ziP7dTV1S2+ALU0bI+0zXKWiQ=
golang.org/x/text v0.14.0/go.mod h1:18ZOQIKpY8NJVqYksKHtTdi31H5itFRjB5/qKTNYzSU=
golang.org/x/tools v0.33.0 h1:4qz2S3zmRxbGIhDIAgjxvFutSvH5EfnsYrRBj0UI0bc=
golang.org/x/tools v0.33.0/go.mod h1:CIJMaWEY88juyUfo7UbgPqbC8rU2OqfAV1h2Qp0oMYI=
modernc.org/cc/v4 v4.26.1 h1:+X5NtzVBn0KgsBCBe+xkDC7twLb/jNVj9FPgiwSQO3s=
modernc.org/cc/v4 v4.26.1/go.mod h1:uVtb5OGqUKpoLWhqwNQo/8LwvoiEBLvZXIQ/SmO6mL0=
modernc.org/ccgo/v4 v4.28.0 h1:rjznn6WWehKq7dG4JtLRKxb52Ecv8OUGah8+Z/SfpNU=
modernc.org/ccgo/v4 v4.28.0/go.mod h1:JygV3+9AV6SmPhDasu4JgquwU81XAKLd3OKTUDNOiKE=
modernc.org/fileutil v1.3.1 h1:8vq5fe7jdtEvoCf3Zf9Nm0Q05sH6kGx0Op2CPx1wTC8=
modernc.org/fileutil v1.3.1/go.mod h1:HxmghZSZVAz/LXcMNwZPA/DRrQZEVP9VX0V4LQGQFOc=
modernc.org/gc/v2 v2.6.5 h1:nyqdV8q46KvTpZlsw66kWqwXRHdjIlJOhG6kxiV/9xI=
modernc.org/gc/v2 v2.6.5/go.mod h1:YgIahr1ypgfe7chRuJi2gD7DBQiKSLMPgBQe9oIiito=
modernc.org/libc v1.65.7 h1:Ia9Z4yzZtWNtUIuiPuQ7Qf7kxYrxP1/jeHZzG8bFu00= modernc.org/libc v1.65.7 h1:Ia9Z4yzZtWNtUIuiPuQ7Qf7kxYrxP1/jeHZzG8bFu00=
modernc.org/libc v1.65.7/go.mod h1:011EQibzzio/VX3ygj1qGFt5kMjP0lHb0qCW5/D/pQU= modernc.org/libc v1.65.7/go.mod h1:011EQibzzio/VX3ygj1qGFt5kMjP0lHb0qCW5/D/pQU=
modernc.org/mathutil v1.7.1 h1:GCZVGXdaN8gTqB1Mf/usp1Y/hSqgI2vAGGP4jZMCxOU= modernc.org/mathutil v1.7.1 h1:GCZVGXdaN8gTqB1Mf/usp1Y/hSqgI2vAGGP4jZMCxOU=
modernc.org/mathutil v1.7.1/go.mod h1:4p5IwJITfppl0G4sUEDtCr4DthTaT47/N3aT6MhfgJg= modernc.org/mathutil v1.7.1/go.mod h1:4p5IwJITfppl0G4sUEDtCr4DthTaT47/N3aT6MhfgJg=
modernc.org/memory v1.11.0 h1:o4QC8aMQzmcwCK3t3Ux/ZHmwFPzE6hf2Y5LbkRs+hbI= modernc.org/memory v1.11.0 h1:o4QC8aMQzmcwCK3t3Ux/ZHmwFPzE6hf2Y5LbkRs+hbI=
modernc.org/memory v1.11.0/go.mod h1:/JP4VbVC+K5sU2wZi9bHoq2MAkCnrt2r98UGeSK7Mjw= modernc.org/memory v1.11.0/go.mod h1:/JP4VbVC+K5sU2wZi9bHoq2MAkCnrt2r98UGeSK7Mjw=
modernc.org/opt v0.1.4 h1:2kNGMRiUjrp4LcaPuLY2PzUfqM/w9N23quVwhKt5Qm8=
modernc.org/opt v0.1.4/go.mod h1:03fq9lsNfvkYSfxrfUhZCWPk1lm4cq4N+Bh//bEtgns=
modernc.org/sortutil v1.2.1 h1:+xyoGf15mM3NMlPDnFqrteY07klSFxLElE2PVuWIJ7w=
modernc.org/sortutil v1.2.1/go.mod h1:7ZI3a3REbai7gzCLcotuw9AC4VZVpYMjDzETGsSMqJE=
modernc.org/sqlite v1.37.1 h1:EgHJK/FPoqC+q2YBXg7fUmES37pCHFc97sI7zSayBEs= modernc.org/sqlite v1.37.1 h1:EgHJK/FPoqC+q2YBXg7fUmES37pCHFc97sI7zSayBEs=
modernc.org/sqlite v1.37.1/go.mod h1:XwdRtsE1MpiBcL54+MbKcaDvcuej+IYSMfLN6gSKV8g= modernc.org/sqlite v1.37.1/go.mod h1:XwdRtsE1MpiBcL54+MbKcaDvcuej+IYSMfLN6gSKV8g=
modernc.org/strutil v1.2.1 h1:UneZBkQA+DX2Rp35KcM69cSsNES9ly8mQWD71HKlOA0=
modernc.org/strutil v1.2.1/go.mod h1:EHkiggD70koQxjVdSBM3JKM7k6L0FbGE5eymy9i3B9A=
modernc.org/token v1.1.0 h1:Xl7Ap9dKaEs5kLoOQeQmPWevfnk/DM5qcLcYlA8ys6Y=
modernc.org/token v1.1.0/go.mod h1:UGzOrNV1mAFSEB63lOFHIpNRUVMvYTc6yu1SMY/XTDM=
zombiezen.com/go/sqlite v1.4.2 h1:KZXLrBuJ7tKNEm+VJcApLMeQbhmAUOKA5VWS93DfFRo= zombiezen.com/go/sqlite v1.4.2 h1:KZXLrBuJ7tKNEm+VJcApLMeQbhmAUOKA5VWS93DfFRo=
zombiezen.com/go/sqlite v1.4.2/go.mod h1:5Kd4taTAD4MkBzT25mQ9uaAlLjyR0rFhsR6iINO70jc= zombiezen.com/go/sqlite v1.4.2/go.mod h1:5Kd4taTAD4MkBzT25mQ9uaAlLjyR0rFhsR6iINO70jc=

View File

@ -1,6 +1,7 @@
package sashimi package sashimi
import ( import (
"context"
"fmt" "fmt"
"io/fs" "io/fs"
"os" "os"
@ -39,6 +40,12 @@ func NewMigrator(db *DB, dataDir string) *Migrator {
// ensureMigrationsTable creates the migrations tracking table if it doesn't exist // ensureMigrationsTable creates the migrations tracking table if it doesn't exist
func (m *Migrator) ensureMigrationsTable() error { func (m *Migrator) ensureMigrationsTable() error {
conn, err := m.db.pool.Take(context.Background())
if err != nil {
return fmt.Errorf("failed to get connection: %w", err)
}
defer m.db.pool.Put(conn)
query := ` query := `
CREATE TABLE IF NOT EXISTS migrations ( CREATE TABLE IF NOT EXISTS migrations (
number INTEGER PRIMARY KEY, number INTEGER PRIMARY KEY,
@ -47,7 +54,7 @@ func (m *Migrator) ensureMigrationsTable() error {
executed_at INTEGER NOT NULL executed_at INTEGER NOT NULL
) )
` `
return sqlitex.Execute(m.db.conn, query, nil) return sqlitex.Execute(conn, query, nil)
} }
// getExecutedMigrations returns a map of migration numbers that have been executed // getExecutedMigrations returns a map of migration numbers that have been executed
@ -155,11 +162,17 @@ func (m *Migrator) Run() error {
fmt.Printf("Running %d pending migrations...\n", len(pendingMigrations)) fmt.Printf("Running %d pending migrations...\n", len(pendingMigrations))
return m.db.Transaction(func() error { return m.db.Transaction(func() error {
conn, err := m.db.pool.Take(context.Background())
if err != nil {
return fmt.Errorf("failed to get connection: %w", err)
}
defer m.db.pool.Put(conn)
for _, migration := range pendingMigrations { for _, migration := range pendingMigrations {
fmt.Printf("Running migration %d: %s\n", migration.Number, migration.Name) fmt.Printf("Running migration %d: %s\n", migration.Number, migration.Name)
// Execute the migration SQL // Execute the migration SQL
if err := sqlitex.Execute(m.db.conn, migration.Content, nil); err != nil { if err := sqlitex.ExecuteScript(conn, migration.Content, nil); err != nil {
return fmt.Errorf("failed to execute migration %d (%s): %w", return fmt.Errorf("failed to execute migration %d (%s): %w",
migration.Number, migration.Name, err) migration.Number, migration.Name, err)
} }