Compare commits
3 Commits
Author | SHA1 | Date | |
---|---|---|---|
66860a7126 | |||
8d62c3b589 | |||
44c823a1ae |
104
db.go
104
db.go
@ -214,7 +214,7 @@ func (db *DB) Get(dest any, query string, args ...any) error {
|
||||
return fmt.Errorf("no rows found")
|
||||
}
|
||||
|
||||
return db.Scan(stmt, dest)
|
||||
return db.scanValue(stmt, dest)
|
||||
}
|
||||
|
||||
// Select executes a query and scans all rows into a slice
|
||||
@ -324,7 +324,7 @@ func (db *DB) Update(tableName string, fields map[string]any, whereField string,
|
||||
})
|
||||
}
|
||||
|
||||
// Insert inserts a struct into the database
|
||||
// Insert inserts a struct or map into the database
|
||||
func (db *DB) Insert(tableName string, obj any, excludeFields ...string) (int64, error) {
|
||||
conn, err := db.pool.Take(context.Background())
|
||||
if err != nil {
|
||||
@ -332,12 +332,6 @@ func (db *DB) Insert(tableName string, obj any, excludeFields ...string) (int64,
|
||||
}
|
||||
defer db.pool.Put(conn)
|
||||
|
||||
v := reflect.ValueOf(obj)
|
||||
if v.Kind() == reflect.Pointer {
|
||||
v = v.Elem()
|
||||
}
|
||||
t := v.Type()
|
||||
|
||||
exclude := make(map[string]bool)
|
||||
for _, field := range excludeFields {
|
||||
exclude[toSnakeCase(field)] = true
|
||||
@ -347,16 +341,41 @@ func (db *DB) Insert(tableName string, obj any, excludeFields ...string) (int64,
|
||||
var placeholders []string
|
||||
var args []any
|
||||
|
||||
for i := 0; i < t.NumField(); i++ {
|
||||
field := t.Field(i)
|
||||
columnName := toSnakeCase(field.Name)
|
||||
if exclude[columnName] {
|
||||
continue
|
||||
v := reflect.ValueOf(obj)
|
||||
if v.Kind() == reflect.Pointer {
|
||||
v = v.Elem()
|
||||
}
|
||||
|
||||
switch v.Kind() {
|
||||
case reflect.Map:
|
||||
// Handle map[string]any
|
||||
m := obj.(map[string]any)
|
||||
for key, value := range m {
|
||||
columnName := toSnakeCase(key)
|
||||
if exclude[columnName] {
|
||||
continue
|
||||
}
|
||||
columns = append(columns, columnName)
|
||||
placeholders = append(placeholders, "?")
|
||||
args = append(args, value)
|
||||
}
|
||||
|
||||
columns = append(columns, columnName)
|
||||
placeholders = append(placeholders, "?")
|
||||
args = append(args, v.Field(i).Interface())
|
||||
case reflect.Struct:
|
||||
// Handle struct
|
||||
t := v.Type()
|
||||
for i := 0; i < t.NumField(); i++ {
|
||||
field := t.Field(i)
|
||||
columnName := toSnakeCase(field.Name)
|
||||
if exclude[columnName] {
|
||||
continue
|
||||
}
|
||||
columns = append(columns, columnName)
|
||||
placeholders = append(placeholders, "?")
|
||||
args = append(args, v.Field(i).Interface())
|
||||
}
|
||||
|
||||
default:
|
||||
return 0, fmt.Errorf("obj must be a struct, pointer to struct, or map[string]any")
|
||||
}
|
||||
|
||||
query := fmt.Sprintf("INSERT INTO %s (%s) VALUES (%s)",
|
||||
@ -416,6 +435,59 @@ func (db *DB) Transaction(fn func() error) error {
|
||||
return sqlitex.Execute(conn, "COMMIT", nil)
|
||||
}
|
||||
|
||||
// scanValue scans a statement result into either a struct or primitive type
|
||||
func (db *DB) scanValue(stmt *PooledStmt, dest any) error {
|
||||
v := reflect.ValueOf(dest)
|
||||
if v.Kind() != reflect.Pointer {
|
||||
return fmt.Errorf("dest must be a pointer")
|
||||
}
|
||||
|
||||
elem := v.Elem()
|
||||
|
||||
// Handle primitive types
|
||||
if isPrimitiveType(elem.Kind()) {
|
||||
if stmt.ColumnCount() == 0 {
|
||||
return fmt.Errorf("no columns in result")
|
||||
}
|
||||
|
||||
return scanPrimitive(stmt, elem, 0)
|
||||
}
|
||||
|
||||
// Handle struct types
|
||||
if elem.Kind() != reflect.Struct {
|
||||
return fmt.Errorf("dest must be a pointer to struct or primitive type")
|
||||
}
|
||||
|
||||
return db.Scan(stmt, dest)
|
||||
}
|
||||
|
||||
// isPrimitiveType checks if a reflect.Kind represents a primitive type
|
||||
func isPrimitiveType(k reflect.Kind) bool {
|
||||
switch k {
|
||||
case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64,
|
||||
reflect.String, reflect.Float32, reflect.Float64, reflect.Bool:
|
||||
return true
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
// scanPrimitive scans a column value into a primitive type
|
||||
func scanPrimitive(stmt *PooledStmt, fieldValue reflect.Value, colIndex int) error {
|
||||
switch fieldValue.Kind() {
|
||||
case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64:
|
||||
fieldValue.SetInt(stmt.ColumnInt64(colIndex))
|
||||
case reflect.String:
|
||||
fieldValue.SetString(stmt.ColumnText(colIndex))
|
||||
case reflect.Float32, reflect.Float64:
|
||||
fieldValue.SetFloat(stmt.ColumnFloat(colIndex))
|
||||
case reflect.Bool:
|
||||
fieldValue.SetBool(stmt.ColumnInt(colIndex) != 0)
|
||||
default:
|
||||
return fmt.Errorf("unsupported type: %v", fieldValue.Kind())
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func convertPlaceholders(query string) (string, []string) {
|
||||
var paramTypes []string
|
||||
|
||||
|
@ -172,7 +172,7 @@ func (m *Migrator) Run() error {
|
||||
fmt.Printf("Running migration %d: %s\n", migration.Number, migration.Name)
|
||||
|
||||
// Execute the migration SQL
|
||||
if err := sqlitex.Execute(conn, migration.Content, nil); err != nil {
|
||||
if err := sqlitex.ExecuteScript(conn, migration.Content, nil); err != nil {
|
||||
return fmt.Errorf("failed to execute migration %d (%s): %w",
|
||||
migration.Number, migration.Name, err)
|
||||
}
|
||||
|
Loading…
x
Reference in New Issue
Block a user