Compare commits
4 Commits
Author | SHA1 | Date | |
---|---|---|---|
66860a7126 | |||
8d62c3b589 | |||
44c823a1ae | |||
08a4a2c99f |
116
db.go
116
db.go
@ -82,7 +82,7 @@ func (db *DB) Pool() *sqlitex.Pool {
|
||||
}
|
||||
|
||||
// Scan scans a SQLite statement result into a struct using field names
|
||||
func (db *DB) Scan(stmt *pooledStmt, dest any) error {
|
||||
func (db *DB) Scan(stmt *PooledStmt, dest any) error {
|
||||
v := reflect.ValueOf(dest)
|
||||
if v.Kind() != reflect.Pointer || v.Elem().Kind() != reflect.Struct {
|
||||
return fmt.Errorf("dest must be a pointer to struct")
|
||||
@ -129,7 +129,7 @@ func (db *DB) Scan(stmt *pooledStmt, dest any) error {
|
||||
}
|
||||
|
||||
// Query executes a query with fmt-style placeholders and automatically binds parameters
|
||||
func (db *DB) Query(query string, args ...any) (*pooledStmt, error) {
|
||||
func (db *DB) Query(query string, args ...any) (*PooledStmt, error) {
|
||||
conn, err := db.pool.Take(context.Background())
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to get connection: %w", err)
|
||||
@ -177,18 +177,18 @@ func (db *DB) Query(query string, args ...any) (*pooledStmt, error) {
|
||||
}
|
||||
|
||||
// Create a wrapped statement that releases the connection when finalized
|
||||
return &pooledStmt{Stmt: stmt, pool: db.pool, conn: conn}, nil
|
||||
return &PooledStmt{Stmt: stmt, pool: db.pool, conn: conn}, nil
|
||||
}
|
||||
|
||||
// pooledStmt wraps a statement to automatically release pool connections
|
||||
type pooledStmt struct {
|
||||
// PooledStmt wraps a statement to automatically release pool connections
|
||||
type PooledStmt struct {
|
||||
*sqlite.Stmt
|
||||
pool *sqlitex.Pool
|
||||
conn *sqlite.Conn
|
||||
finalized bool
|
||||
}
|
||||
|
||||
func (ps *pooledStmt) Finalize() error {
|
||||
func (ps *PooledStmt) Finalize() error {
|
||||
if !ps.finalized {
|
||||
err := ps.Stmt.Finalize()
|
||||
ps.pool.Put(ps.conn)
|
||||
@ -214,7 +214,7 @@ func (db *DB) Get(dest any, query string, args ...any) error {
|
||||
return fmt.Errorf("no rows found")
|
||||
}
|
||||
|
||||
return db.Scan(stmt, dest)
|
||||
return db.scanValue(stmt, dest)
|
||||
}
|
||||
|
||||
// Select executes a query and scans all rows into a slice
|
||||
@ -324,7 +324,7 @@ func (db *DB) Update(tableName string, fields map[string]any, whereField string,
|
||||
})
|
||||
}
|
||||
|
||||
// Insert inserts a struct into the database
|
||||
// Insert inserts a struct or map into the database
|
||||
func (db *DB) Insert(tableName string, obj any, excludeFields ...string) (int64, error) {
|
||||
conn, err := db.pool.Take(context.Background())
|
||||
if err != nil {
|
||||
@ -332,12 +332,6 @@ func (db *DB) Insert(tableName string, obj any, excludeFields ...string) (int64,
|
||||
}
|
||||
defer db.pool.Put(conn)
|
||||
|
||||
v := reflect.ValueOf(obj)
|
||||
if v.Kind() == reflect.Pointer {
|
||||
v = v.Elem()
|
||||
}
|
||||
t := v.Type()
|
||||
|
||||
exclude := make(map[string]bool)
|
||||
for _, field := range excludeFields {
|
||||
exclude[toSnakeCase(field)] = true
|
||||
@ -347,16 +341,41 @@ func (db *DB) Insert(tableName string, obj any, excludeFields ...string) (int64,
|
||||
var placeholders []string
|
||||
var args []any
|
||||
|
||||
for i := 0; i < t.NumField(); i++ {
|
||||
field := t.Field(i)
|
||||
columnName := toSnakeCase(field.Name)
|
||||
if exclude[columnName] {
|
||||
continue
|
||||
v := reflect.ValueOf(obj)
|
||||
if v.Kind() == reflect.Pointer {
|
||||
v = v.Elem()
|
||||
}
|
||||
|
||||
switch v.Kind() {
|
||||
case reflect.Map:
|
||||
// Handle map[string]any
|
||||
m := obj.(map[string]any)
|
||||
for key, value := range m {
|
||||
columnName := toSnakeCase(key)
|
||||
if exclude[columnName] {
|
||||
continue
|
||||
}
|
||||
columns = append(columns, columnName)
|
||||
placeholders = append(placeholders, "?")
|
||||
args = append(args, value)
|
||||
}
|
||||
|
||||
columns = append(columns, columnName)
|
||||
placeholders = append(placeholders, "?")
|
||||
args = append(args, v.Field(i).Interface())
|
||||
case reflect.Struct:
|
||||
// Handle struct
|
||||
t := v.Type()
|
||||
for i := 0; i < t.NumField(); i++ {
|
||||
field := t.Field(i)
|
||||
columnName := toSnakeCase(field.Name)
|
||||
if exclude[columnName] {
|
||||
continue
|
||||
}
|
||||
columns = append(columns, columnName)
|
||||
placeholders = append(placeholders, "?")
|
||||
args = append(args, v.Field(i).Interface())
|
||||
}
|
||||
|
||||
default:
|
||||
return 0, fmt.Errorf("obj must be a struct, pointer to struct, or map[string]any")
|
||||
}
|
||||
|
||||
query := fmt.Sprintf("INSERT INTO %s (%s) VALUES (%s)",
|
||||
@ -416,6 +435,59 @@ func (db *DB) Transaction(fn func() error) error {
|
||||
return sqlitex.Execute(conn, "COMMIT", nil)
|
||||
}
|
||||
|
||||
// scanValue scans a statement result into either a struct or primitive type
|
||||
func (db *DB) scanValue(stmt *PooledStmt, dest any) error {
|
||||
v := reflect.ValueOf(dest)
|
||||
if v.Kind() != reflect.Pointer {
|
||||
return fmt.Errorf("dest must be a pointer")
|
||||
}
|
||||
|
||||
elem := v.Elem()
|
||||
|
||||
// Handle primitive types
|
||||
if isPrimitiveType(elem.Kind()) {
|
||||
if stmt.ColumnCount() == 0 {
|
||||
return fmt.Errorf("no columns in result")
|
||||
}
|
||||
|
||||
return scanPrimitive(stmt, elem, 0)
|
||||
}
|
||||
|
||||
// Handle struct types
|
||||
if elem.Kind() != reflect.Struct {
|
||||
return fmt.Errorf("dest must be a pointer to struct or primitive type")
|
||||
}
|
||||
|
||||
return db.Scan(stmt, dest)
|
||||
}
|
||||
|
||||
// isPrimitiveType checks if a reflect.Kind represents a primitive type
|
||||
func isPrimitiveType(k reflect.Kind) bool {
|
||||
switch k {
|
||||
case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64,
|
||||
reflect.String, reflect.Float32, reflect.Float64, reflect.Bool:
|
||||
return true
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
// scanPrimitive scans a column value into a primitive type
|
||||
func scanPrimitive(stmt *PooledStmt, fieldValue reflect.Value, colIndex int) error {
|
||||
switch fieldValue.Kind() {
|
||||
case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64:
|
||||
fieldValue.SetInt(stmt.ColumnInt64(colIndex))
|
||||
case reflect.String:
|
||||
fieldValue.SetString(stmt.ColumnText(colIndex))
|
||||
case reflect.Float32, reflect.Float64:
|
||||
fieldValue.SetFloat(stmt.ColumnFloat(colIndex))
|
||||
case reflect.Bool:
|
||||
fieldValue.SetBool(stmt.ColumnInt(colIndex) != 0)
|
||||
default:
|
||||
return fmt.Errorf("unsupported type: %v", fieldValue.Kind())
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func convertPlaceholders(query string) (string, []string) {
|
||||
var paramTypes []string
|
||||
|
||||
|
26
go.sum
26
go.sum
@ -1,5 +1,7 @@
|
||||
github.com/dustin/go-humanize v1.0.1 h1:GzkhY7T5VNhEkwH0PVJgjz+fX1rhBrR7pRT3mDkpeCY=
|
||||
github.com/dustin/go-humanize v1.0.1/go.mod h1:Mu1zIs6XwVuF/gI1OepvI0qD18qycQx+mFykh5fBlto=
|
||||
github.com/google/go-cmp v0.6.0 h1:ofyhxvXcZhMsU5ulbFiLKl/XBFqE1GSq7atu8tAmTRI=
|
||||
github.com/google/go-cmp v0.6.0/go.mod h1:17dUlkBOakJ0+DkrSSNjCkIjxS6bF9zb3elmeNGIjoY=
|
||||
github.com/google/uuid v1.6.0 h1:NIvaJDMOsjHA8n1jAhLSgzrAzy1Hgr+hNrb57e+94F0=
|
||||
github.com/google/uuid v1.6.0/go.mod h1:TIyPZe4MgqvfeYDBFedMoGGpEw/LqOeaOT+nhxU+yHo=
|
||||
github.com/mattn/go-isatty v0.0.20 h1:xfD0iDuEKnDkl03q4limB+vH+GxLEtL/jb4xVJSWWEY=
|
||||
@ -10,16 +12,40 @@ github.com/remyoudompheng/bigfft v0.0.0-20230129092748-24d4a6f8daec h1:W09IVJc94
|
||||
github.com/remyoudompheng/bigfft v0.0.0-20230129092748-24d4a6f8daec/go.mod h1:qqbHyh8v60DhA7CoWK5oRCqLrMHRGoxYCSS9EjAz6Eo=
|
||||
golang.org/x/exp v0.0.0-20250408133849-7e4ce0ab07d0 h1:R84qjqJb5nVJMxqWYb3np9L5ZsaDtB+a39EqjV0JSUM=
|
||||
golang.org/x/exp v0.0.0-20250408133849-7e4ce0ab07d0/go.mod h1:S9Xr4PYopiDyqSyp5NjCrhFrqg6A5zA2E/iPHPhqnS8=
|
||||
golang.org/x/mod v0.24.0 h1:ZfthKaKaT4NrhGVZHO1/WDTwGES4De8KtWO0SIbNJMU=
|
||||
golang.org/x/mod v0.24.0/go.mod h1:IXM97Txy2VM4PJ3gI61r1YEk/gAj6zAHN3AdZt6S9Ww=
|
||||
golang.org/x/sync v0.14.0 h1:woo0S4Yywslg6hp4eUFjTVOyKt0RookbpAHG4c1HmhQ=
|
||||
golang.org/x/sync v0.14.0/go.mod h1:1dzgHSNfp02xaA81J2MS99Qcpr2w7fw1gpm99rleRqA=
|
||||
golang.org/x/sys v0.6.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
|
||||
golang.org/x/sys v0.33.0 h1:q3i8TbbEz+JRD9ywIRlyRAQbM0qF7hu24q3teo2hbuw=
|
||||
golang.org/x/sys v0.33.0/go.mod h1:BJP2sWEmIv4KK5OTEluFJCKSidICx8ciO85XgH3Ak8k=
|
||||
golang.org/x/text v0.14.0 h1:ScX5w1eTa3QqT8oi6+ziP7dTV1S2+ALU0bI+0zXKWiQ=
|
||||
golang.org/x/text v0.14.0/go.mod h1:18ZOQIKpY8NJVqYksKHtTdi31H5itFRjB5/qKTNYzSU=
|
||||
golang.org/x/tools v0.33.0 h1:4qz2S3zmRxbGIhDIAgjxvFutSvH5EfnsYrRBj0UI0bc=
|
||||
golang.org/x/tools v0.33.0/go.mod h1:CIJMaWEY88juyUfo7UbgPqbC8rU2OqfAV1h2Qp0oMYI=
|
||||
modernc.org/cc/v4 v4.26.1 h1:+X5NtzVBn0KgsBCBe+xkDC7twLb/jNVj9FPgiwSQO3s=
|
||||
modernc.org/cc/v4 v4.26.1/go.mod h1:uVtb5OGqUKpoLWhqwNQo/8LwvoiEBLvZXIQ/SmO6mL0=
|
||||
modernc.org/ccgo/v4 v4.28.0 h1:rjznn6WWehKq7dG4JtLRKxb52Ecv8OUGah8+Z/SfpNU=
|
||||
modernc.org/ccgo/v4 v4.28.0/go.mod h1:JygV3+9AV6SmPhDasu4JgquwU81XAKLd3OKTUDNOiKE=
|
||||
modernc.org/fileutil v1.3.1 h1:8vq5fe7jdtEvoCf3Zf9Nm0Q05sH6kGx0Op2CPx1wTC8=
|
||||
modernc.org/fileutil v1.3.1/go.mod h1:HxmghZSZVAz/LXcMNwZPA/DRrQZEVP9VX0V4LQGQFOc=
|
||||
modernc.org/gc/v2 v2.6.5 h1:nyqdV8q46KvTpZlsw66kWqwXRHdjIlJOhG6kxiV/9xI=
|
||||
modernc.org/gc/v2 v2.6.5/go.mod h1:YgIahr1ypgfe7chRuJi2gD7DBQiKSLMPgBQe9oIiito=
|
||||
modernc.org/libc v1.65.7 h1:Ia9Z4yzZtWNtUIuiPuQ7Qf7kxYrxP1/jeHZzG8bFu00=
|
||||
modernc.org/libc v1.65.7/go.mod h1:011EQibzzio/VX3ygj1qGFt5kMjP0lHb0qCW5/D/pQU=
|
||||
modernc.org/mathutil v1.7.1 h1:GCZVGXdaN8gTqB1Mf/usp1Y/hSqgI2vAGGP4jZMCxOU=
|
||||
modernc.org/mathutil v1.7.1/go.mod h1:4p5IwJITfppl0G4sUEDtCr4DthTaT47/N3aT6MhfgJg=
|
||||
modernc.org/memory v1.11.0 h1:o4QC8aMQzmcwCK3t3Ux/ZHmwFPzE6hf2Y5LbkRs+hbI=
|
||||
modernc.org/memory v1.11.0/go.mod h1:/JP4VbVC+K5sU2wZi9bHoq2MAkCnrt2r98UGeSK7Mjw=
|
||||
modernc.org/opt v0.1.4 h1:2kNGMRiUjrp4LcaPuLY2PzUfqM/w9N23quVwhKt5Qm8=
|
||||
modernc.org/opt v0.1.4/go.mod h1:03fq9lsNfvkYSfxrfUhZCWPk1lm4cq4N+Bh//bEtgns=
|
||||
modernc.org/sortutil v1.2.1 h1:+xyoGf15mM3NMlPDnFqrteY07klSFxLElE2PVuWIJ7w=
|
||||
modernc.org/sortutil v1.2.1/go.mod h1:7ZI3a3REbai7gzCLcotuw9AC4VZVpYMjDzETGsSMqJE=
|
||||
modernc.org/sqlite v1.37.1 h1:EgHJK/FPoqC+q2YBXg7fUmES37pCHFc97sI7zSayBEs=
|
||||
modernc.org/sqlite v1.37.1/go.mod h1:XwdRtsE1MpiBcL54+MbKcaDvcuej+IYSMfLN6gSKV8g=
|
||||
modernc.org/strutil v1.2.1 h1:UneZBkQA+DX2Rp35KcM69cSsNES9ly8mQWD71HKlOA0=
|
||||
modernc.org/strutil v1.2.1/go.mod h1:EHkiggD70koQxjVdSBM3JKM7k6L0FbGE5eymy9i3B9A=
|
||||
modernc.org/token v1.1.0 h1:Xl7Ap9dKaEs5kLoOQeQmPWevfnk/DM5qcLcYlA8ys6Y=
|
||||
modernc.org/token v1.1.0/go.mod h1:UGzOrNV1mAFSEB63lOFHIpNRUVMvYTc6yu1SMY/XTDM=
|
||||
zombiezen.com/go/sqlite v1.4.2 h1:KZXLrBuJ7tKNEm+VJcApLMeQbhmAUOKA5VWS93DfFRo=
|
||||
zombiezen.com/go/sqlite v1.4.2/go.mod h1:5Kd4taTAD4MkBzT25mQ9uaAlLjyR0rFhsR6iINO70jc=
|
||||
|
@ -172,7 +172,7 @@ func (m *Migrator) Run() error {
|
||||
fmt.Printf("Running migration %d: %s\n", migration.Number, migration.Name)
|
||||
|
||||
// Execute the migration SQL
|
||||
if err := sqlitex.Execute(conn, migration.Content, nil); err != nil {
|
||||
if err := sqlitex.ExecuteScript(conn, migration.Content, nil); err != nil {
|
||||
return fmt.Errorf("failed to execute migration %d (%s): %w",
|
||||
migration.Number, migration.Name, err)
|
||||
}
|
||||
|
Loading…
x
Reference in New Issue
Block a user