diff --git a/db.go b/db.go index 90f4788..623fc2a 100644 --- a/db.go +++ b/db.go @@ -214,7 +214,7 @@ func (db *DB) Get(dest any, query string, args ...any) error { return fmt.Errorf("no rows found") } - return db.Scan(stmt, dest) + return db.scanValue(stmt, dest) } // Select executes a query and scans all rows into a slice @@ -435,6 +435,59 @@ func (db *DB) Transaction(fn func() error) error { return sqlitex.Execute(conn, "COMMIT", nil) } +// scanValue scans a statement result into either a struct or primitive type +func (db *DB) scanValue(stmt *PooledStmt, dest any) error { + v := reflect.ValueOf(dest) + if v.Kind() != reflect.Pointer { + return fmt.Errorf("dest must be a pointer") + } + + elem := v.Elem() + + // Handle primitive types + if isPrimitiveType(elem.Kind()) { + if stmt.ColumnCount() == 0 { + return fmt.Errorf("no columns in result") + } + + return scanPrimitive(stmt, elem, 0) + } + + // Handle struct types + if elem.Kind() != reflect.Struct { + return fmt.Errorf("dest must be a pointer to struct or primitive type") + } + + return db.Scan(stmt, dest) +} + +// isPrimitiveType checks if a reflect.Kind represents a primitive type +func isPrimitiveType(k reflect.Kind) bool { + switch k { + case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64, + reflect.String, reflect.Float32, reflect.Float64, reflect.Bool: + return true + } + return false +} + +// scanPrimitive scans a column value into a primitive type +func scanPrimitive(stmt *PooledStmt, fieldValue reflect.Value, colIndex int) error { + switch fieldValue.Kind() { + case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64: + fieldValue.SetInt(stmt.ColumnInt64(colIndex)) + case reflect.String: + fieldValue.SetString(stmt.ColumnText(colIndex)) + case reflect.Float32, reflect.Float64: + fieldValue.SetFloat(stmt.ColumnFloat(colIndex)) + case reflect.Bool: + fieldValue.SetBool(stmt.ColumnInt(colIndex) != 0) + default: + return fmt.Errorf("unsupported type: %v", fieldValue.Kind()) + } + return nil +} + func convertPlaceholders(query string) (string, []string) { var paramTypes []string