Compare commits

...

4 Commits

3 changed files with 121 additions and 23 deletions

116
db.go
View File

@ -82,7 +82,7 @@ func (db *DB) Pool() *sqlitex.Pool {
}
// Scan scans a SQLite statement result into a struct using field names
func (db *DB) Scan(stmt *pooledStmt, dest any) error {
func (db *DB) Scan(stmt *PooledStmt, dest any) error {
v := reflect.ValueOf(dest)
if v.Kind() != reflect.Pointer || v.Elem().Kind() != reflect.Struct {
return fmt.Errorf("dest must be a pointer to struct")
@ -129,7 +129,7 @@ func (db *DB) Scan(stmt *pooledStmt, dest any) error {
}
// Query executes a query with fmt-style placeholders and automatically binds parameters
func (db *DB) Query(query string, args ...any) (*pooledStmt, error) {
func (db *DB) Query(query string, args ...any) (*PooledStmt, error) {
conn, err := db.pool.Take(context.Background())
if err != nil {
return nil, fmt.Errorf("failed to get connection: %w", err)
@ -177,18 +177,18 @@ func (db *DB) Query(query string, args ...any) (*pooledStmt, error) {
}
// Create a wrapped statement that releases the connection when finalized
return &pooledStmt{Stmt: stmt, pool: db.pool, conn: conn}, nil
return &PooledStmt{Stmt: stmt, pool: db.pool, conn: conn}, nil
}
// pooledStmt wraps a statement to automatically release pool connections
type pooledStmt struct {
// PooledStmt wraps a statement to automatically release pool connections
type PooledStmt struct {
*sqlite.Stmt
pool *sqlitex.Pool
conn *sqlite.Conn
finalized bool
}
func (ps *pooledStmt) Finalize() error {
func (ps *PooledStmt) Finalize() error {
if !ps.finalized {
err := ps.Stmt.Finalize()
ps.pool.Put(ps.conn)
@ -214,7 +214,7 @@ func (db *DB) Get(dest any, query string, args ...any) error {
return fmt.Errorf("no rows found")
}
return db.Scan(stmt, dest)
return db.scanValue(stmt, dest)
}
// Select executes a query and scans all rows into a slice
@ -324,7 +324,7 @@ func (db *DB) Update(tableName string, fields map[string]any, whereField string,
})
}
// Insert inserts a struct into the database
// Insert inserts a struct or map into the database
func (db *DB) Insert(tableName string, obj any, excludeFields ...string) (int64, error) {
conn, err := db.pool.Take(context.Background())
if err != nil {
@ -332,12 +332,6 @@ func (db *DB) Insert(tableName string, obj any, excludeFields ...string) (int64,
}
defer db.pool.Put(conn)
v := reflect.ValueOf(obj)
if v.Kind() == reflect.Pointer {
v = v.Elem()
}
t := v.Type()
exclude := make(map[string]bool)
for _, field := range excludeFields {
exclude[toSnakeCase(field)] = true
@ -347,16 +341,41 @@ func (db *DB) Insert(tableName string, obj any, excludeFields ...string) (int64,
var placeholders []string
var args []any
for i := 0; i < t.NumField(); i++ {
field := t.Field(i)
columnName := toSnakeCase(field.Name)
if exclude[columnName] {
continue
v := reflect.ValueOf(obj)
if v.Kind() == reflect.Pointer {
v = v.Elem()
}
switch v.Kind() {
case reflect.Map:
// Handle map[string]any
m := obj.(map[string]any)
for key, value := range m {
columnName := toSnakeCase(key)
if exclude[columnName] {
continue
}
columns = append(columns, columnName)
placeholders = append(placeholders, "?")
args = append(args, value)
}
columns = append(columns, columnName)
placeholders = append(placeholders, "?")
args = append(args, v.Field(i).Interface())
case reflect.Struct:
// Handle struct
t := v.Type()
for i := 0; i < t.NumField(); i++ {
field := t.Field(i)
columnName := toSnakeCase(field.Name)
if exclude[columnName] {
continue
}
columns = append(columns, columnName)
placeholders = append(placeholders, "?")
args = append(args, v.Field(i).Interface())
}
default:
return 0, fmt.Errorf("obj must be a struct, pointer to struct, or map[string]any")
}
query := fmt.Sprintf("INSERT INTO %s (%s) VALUES (%s)",
@ -416,6 +435,59 @@ func (db *DB) Transaction(fn func() error) error {
return sqlitex.Execute(conn, "COMMIT", nil)
}
// scanValue scans a statement result into either a struct or primitive type
func (db *DB) scanValue(stmt *PooledStmt, dest any) error {
v := reflect.ValueOf(dest)
if v.Kind() != reflect.Pointer {
return fmt.Errorf("dest must be a pointer")
}
elem := v.Elem()
// Handle primitive types
if isPrimitiveType(elem.Kind()) {
if stmt.ColumnCount() == 0 {
return fmt.Errorf("no columns in result")
}
return scanPrimitive(stmt, elem, 0)
}
// Handle struct types
if elem.Kind() != reflect.Struct {
return fmt.Errorf("dest must be a pointer to struct or primitive type")
}
return db.Scan(stmt, dest)
}
// isPrimitiveType checks if a reflect.Kind represents a primitive type
func isPrimitiveType(k reflect.Kind) bool {
switch k {
case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64,
reflect.String, reflect.Float32, reflect.Float64, reflect.Bool:
return true
}
return false
}
// scanPrimitive scans a column value into a primitive type
func scanPrimitive(stmt *PooledStmt, fieldValue reflect.Value, colIndex int) error {
switch fieldValue.Kind() {
case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64:
fieldValue.SetInt(stmt.ColumnInt64(colIndex))
case reflect.String:
fieldValue.SetString(stmt.ColumnText(colIndex))
case reflect.Float32, reflect.Float64:
fieldValue.SetFloat(stmt.ColumnFloat(colIndex))
case reflect.Bool:
fieldValue.SetBool(stmt.ColumnInt(colIndex) != 0)
default:
return fmt.Errorf("unsupported type: %v", fieldValue.Kind())
}
return nil
}
func convertPlaceholders(query string) (string, []string) {
var paramTypes []string

26
go.sum
View File

@ -1,5 +1,7 @@
github.com/dustin/go-humanize v1.0.1 h1:GzkhY7T5VNhEkwH0PVJgjz+fX1rhBrR7pRT3mDkpeCY=
github.com/dustin/go-humanize v1.0.1/go.mod h1:Mu1zIs6XwVuF/gI1OepvI0qD18qycQx+mFykh5fBlto=
github.com/google/go-cmp v0.6.0 h1:ofyhxvXcZhMsU5ulbFiLKl/XBFqE1GSq7atu8tAmTRI=
github.com/google/go-cmp v0.6.0/go.mod h1:17dUlkBOakJ0+DkrSSNjCkIjxS6bF9zb3elmeNGIjoY=
github.com/google/uuid v1.6.0 h1:NIvaJDMOsjHA8n1jAhLSgzrAzy1Hgr+hNrb57e+94F0=
github.com/google/uuid v1.6.0/go.mod h1:TIyPZe4MgqvfeYDBFedMoGGpEw/LqOeaOT+nhxU+yHo=
github.com/mattn/go-isatty v0.0.20 h1:xfD0iDuEKnDkl03q4limB+vH+GxLEtL/jb4xVJSWWEY=
@ -10,16 +12,40 @@ github.com/remyoudompheng/bigfft v0.0.0-20230129092748-24d4a6f8daec h1:W09IVJc94
github.com/remyoudompheng/bigfft v0.0.0-20230129092748-24d4a6f8daec/go.mod h1:qqbHyh8v60DhA7CoWK5oRCqLrMHRGoxYCSS9EjAz6Eo=
golang.org/x/exp v0.0.0-20250408133849-7e4ce0ab07d0 h1:R84qjqJb5nVJMxqWYb3np9L5ZsaDtB+a39EqjV0JSUM=
golang.org/x/exp v0.0.0-20250408133849-7e4ce0ab07d0/go.mod h1:S9Xr4PYopiDyqSyp5NjCrhFrqg6A5zA2E/iPHPhqnS8=
golang.org/x/mod v0.24.0 h1:ZfthKaKaT4NrhGVZHO1/WDTwGES4De8KtWO0SIbNJMU=
golang.org/x/mod v0.24.0/go.mod h1:IXM97Txy2VM4PJ3gI61r1YEk/gAj6zAHN3AdZt6S9Ww=
golang.org/x/sync v0.14.0 h1:woo0S4Yywslg6hp4eUFjTVOyKt0RookbpAHG4c1HmhQ=
golang.org/x/sync v0.14.0/go.mod h1:1dzgHSNfp02xaA81J2MS99Qcpr2w7fw1gpm99rleRqA=
golang.org/x/sys v0.6.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
golang.org/x/sys v0.33.0 h1:q3i8TbbEz+JRD9ywIRlyRAQbM0qF7hu24q3teo2hbuw=
golang.org/x/sys v0.33.0/go.mod h1:BJP2sWEmIv4KK5OTEluFJCKSidICx8ciO85XgH3Ak8k=
golang.org/x/text v0.14.0 h1:ScX5w1eTa3QqT8oi6+ziP7dTV1S2+ALU0bI+0zXKWiQ=
golang.org/x/text v0.14.0/go.mod h1:18ZOQIKpY8NJVqYksKHtTdi31H5itFRjB5/qKTNYzSU=
golang.org/x/tools v0.33.0 h1:4qz2S3zmRxbGIhDIAgjxvFutSvH5EfnsYrRBj0UI0bc=
golang.org/x/tools v0.33.0/go.mod h1:CIJMaWEY88juyUfo7UbgPqbC8rU2OqfAV1h2Qp0oMYI=
modernc.org/cc/v4 v4.26.1 h1:+X5NtzVBn0KgsBCBe+xkDC7twLb/jNVj9FPgiwSQO3s=
modernc.org/cc/v4 v4.26.1/go.mod h1:uVtb5OGqUKpoLWhqwNQo/8LwvoiEBLvZXIQ/SmO6mL0=
modernc.org/ccgo/v4 v4.28.0 h1:rjznn6WWehKq7dG4JtLRKxb52Ecv8OUGah8+Z/SfpNU=
modernc.org/ccgo/v4 v4.28.0/go.mod h1:JygV3+9AV6SmPhDasu4JgquwU81XAKLd3OKTUDNOiKE=
modernc.org/fileutil v1.3.1 h1:8vq5fe7jdtEvoCf3Zf9Nm0Q05sH6kGx0Op2CPx1wTC8=
modernc.org/fileutil v1.3.1/go.mod h1:HxmghZSZVAz/LXcMNwZPA/DRrQZEVP9VX0V4LQGQFOc=
modernc.org/gc/v2 v2.6.5 h1:nyqdV8q46KvTpZlsw66kWqwXRHdjIlJOhG6kxiV/9xI=
modernc.org/gc/v2 v2.6.5/go.mod h1:YgIahr1ypgfe7chRuJi2gD7DBQiKSLMPgBQe9oIiito=
modernc.org/libc v1.65.7 h1:Ia9Z4yzZtWNtUIuiPuQ7Qf7kxYrxP1/jeHZzG8bFu00=
modernc.org/libc v1.65.7/go.mod h1:011EQibzzio/VX3ygj1qGFt5kMjP0lHb0qCW5/D/pQU=
modernc.org/mathutil v1.7.1 h1:GCZVGXdaN8gTqB1Mf/usp1Y/hSqgI2vAGGP4jZMCxOU=
modernc.org/mathutil v1.7.1/go.mod h1:4p5IwJITfppl0G4sUEDtCr4DthTaT47/N3aT6MhfgJg=
modernc.org/memory v1.11.0 h1:o4QC8aMQzmcwCK3t3Ux/ZHmwFPzE6hf2Y5LbkRs+hbI=
modernc.org/memory v1.11.0/go.mod h1:/JP4VbVC+K5sU2wZi9bHoq2MAkCnrt2r98UGeSK7Mjw=
modernc.org/opt v0.1.4 h1:2kNGMRiUjrp4LcaPuLY2PzUfqM/w9N23quVwhKt5Qm8=
modernc.org/opt v0.1.4/go.mod h1:03fq9lsNfvkYSfxrfUhZCWPk1lm4cq4N+Bh//bEtgns=
modernc.org/sortutil v1.2.1 h1:+xyoGf15mM3NMlPDnFqrteY07klSFxLElE2PVuWIJ7w=
modernc.org/sortutil v1.2.1/go.mod h1:7ZI3a3REbai7gzCLcotuw9AC4VZVpYMjDzETGsSMqJE=
modernc.org/sqlite v1.37.1 h1:EgHJK/FPoqC+q2YBXg7fUmES37pCHFc97sI7zSayBEs=
modernc.org/sqlite v1.37.1/go.mod h1:XwdRtsE1MpiBcL54+MbKcaDvcuej+IYSMfLN6gSKV8g=
modernc.org/strutil v1.2.1 h1:UneZBkQA+DX2Rp35KcM69cSsNES9ly8mQWD71HKlOA0=
modernc.org/strutil v1.2.1/go.mod h1:EHkiggD70koQxjVdSBM3JKM7k6L0FbGE5eymy9i3B9A=
modernc.org/token v1.1.0 h1:Xl7Ap9dKaEs5kLoOQeQmPWevfnk/DM5qcLcYlA8ys6Y=
modernc.org/token v1.1.0/go.mod h1:UGzOrNV1mAFSEB63lOFHIpNRUVMvYTc6yu1SMY/XTDM=
zombiezen.com/go/sqlite v1.4.2 h1:KZXLrBuJ7tKNEm+VJcApLMeQbhmAUOKA5VWS93DfFRo=
zombiezen.com/go/sqlite v1.4.2/go.mod h1:5Kd4taTAD4MkBzT25mQ9uaAlLjyR0rFhsR6iINO70jc=

View File

@ -172,7 +172,7 @@ func (m *Migrator) Run() error {
fmt.Printf("Running migration %d: %s\n", migration.Number, migration.Name)
// Execute the migration SQL
if err := sqlitex.Execute(conn, migration.Content, nil); err != nil {
if err := sqlitex.ExecuteScript(conn, migration.Content, nil); err != nil {
return fmt.Errorf("failed to execute migration %d (%s): %w",
migration.Number, migration.Name, err)
}