add scanning into primitives
This commit is contained in:
parent
8d62c3b589
commit
66860a7126
55
db.go
55
db.go
@ -214,7 +214,7 @@ func (db *DB) Get(dest any, query string, args ...any) error {
|
|||||||
return fmt.Errorf("no rows found")
|
return fmt.Errorf("no rows found")
|
||||||
}
|
}
|
||||||
|
|
||||||
return db.Scan(stmt, dest)
|
return db.scanValue(stmt, dest)
|
||||||
}
|
}
|
||||||
|
|
||||||
// Select executes a query and scans all rows into a slice
|
// Select executes a query and scans all rows into a slice
|
||||||
@ -435,6 +435,59 @@ func (db *DB) Transaction(fn func() error) error {
|
|||||||
return sqlitex.Execute(conn, "COMMIT", nil)
|
return sqlitex.Execute(conn, "COMMIT", nil)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// scanValue scans a statement result into either a struct or primitive type
|
||||||
|
func (db *DB) scanValue(stmt *PooledStmt, dest any) error {
|
||||||
|
v := reflect.ValueOf(dest)
|
||||||
|
if v.Kind() != reflect.Pointer {
|
||||||
|
return fmt.Errorf("dest must be a pointer")
|
||||||
|
}
|
||||||
|
|
||||||
|
elem := v.Elem()
|
||||||
|
|
||||||
|
// Handle primitive types
|
||||||
|
if isPrimitiveType(elem.Kind()) {
|
||||||
|
if stmt.ColumnCount() == 0 {
|
||||||
|
return fmt.Errorf("no columns in result")
|
||||||
|
}
|
||||||
|
|
||||||
|
return scanPrimitive(stmt, elem, 0)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Handle struct types
|
||||||
|
if elem.Kind() != reflect.Struct {
|
||||||
|
return fmt.Errorf("dest must be a pointer to struct or primitive type")
|
||||||
|
}
|
||||||
|
|
||||||
|
return db.Scan(stmt, dest)
|
||||||
|
}
|
||||||
|
|
||||||
|
// isPrimitiveType checks if a reflect.Kind represents a primitive type
|
||||||
|
func isPrimitiveType(k reflect.Kind) bool {
|
||||||
|
switch k {
|
||||||
|
case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64,
|
||||||
|
reflect.String, reflect.Float32, reflect.Float64, reflect.Bool:
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
|
||||||
|
// scanPrimitive scans a column value into a primitive type
|
||||||
|
func scanPrimitive(stmt *PooledStmt, fieldValue reflect.Value, colIndex int) error {
|
||||||
|
switch fieldValue.Kind() {
|
||||||
|
case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64:
|
||||||
|
fieldValue.SetInt(stmt.ColumnInt64(colIndex))
|
||||||
|
case reflect.String:
|
||||||
|
fieldValue.SetString(stmt.ColumnText(colIndex))
|
||||||
|
case reflect.Float32, reflect.Float64:
|
||||||
|
fieldValue.SetFloat(stmt.ColumnFloat(colIndex))
|
||||||
|
case reflect.Bool:
|
||||||
|
fieldValue.SetBool(stmt.ColumnInt(colIndex) != 0)
|
||||||
|
default:
|
||||||
|
return fmt.Errorf("unsupported type: %v", fieldValue.Kind())
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
func convertPlaceholders(query string) (string, []string) {
|
func convertPlaceholders(query string) (string, []string) {
|
||||||
var paramTypes []string
|
var paramTypes []string
|
||||||
|
|
||||||
|
Loading…
x
Reference in New Issue
Block a user