allow inserting maps as well as structs

This commit is contained in:
Sky Johnson 2025-08-25 14:27:36 -05:00
parent 44c823a1ae
commit 8d62c3b589

49
db.go
View File

@ -324,7 +324,7 @@ func (db *DB) Update(tableName string, fields map[string]any, whereField string,
}) })
} }
// Insert inserts a struct into the database // Insert inserts a struct or map into the database
func (db *DB) Insert(tableName string, obj any, excludeFields ...string) (int64, error) { func (db *DB) Insert(tableName string, obj any, excludeFields ...string) (int64, error) {
conn, err := db.pool.Take(context.Background()) conn, err := db.pool.Take(context.Background())
if err != nil { if err != nil {
@ -332,12 +332,6 @@ func (db *DB) Insert(tableName string, obj any, excludeFields ...string) (int64,
} }
defer db.pool.Put(conn) defer db.pool.Put(conn)
v := reflect.ValueOf(obj)
if v.Kind() == reflect.Pointer {
v = v.Elem()
}
t := v.Type()
exclude := make(map[string]bool) exclude := make(map[string]bool)
for _, field := range excludeFields { for _, field := range excludeFields {
exclude[toSnakeCase(field)] = true exclude[toSnakeCase(field)] = true
@ -347,16 +341,41 @@ func (db *DB) Insert(tableName string, obj any, excludeFields ...string) (int64,
var placeholders []string var placeholders []string
var args []any var args []any
for i := 0; i < t.NumField(); i++ { v := reflect.ValueOf(obj)
field := t.Field(i) if v.Kind() == reflect.Pointer {
columnName := toSnakeCase(field.Name) v = v.Elem()
if exclude[columnName] { }
continue
switch v.Kind() {
case reflect.Map:
// Handle map[string]any
m := obj.(map[string]any)
for key, value := range m {
columnName := toSnakeCase(key)
if exclude[columnName] {
continue
}
columns = append(columns, columnName)
placeholders = append(placeholders, "?")
args = append(args, value)
} }
columns = append(columns, columnName) case reflect.Struct:
placeholders = append(placeholders, "?") // Handle struct
args = append(args, v.Field(i).Interface()) t := v.Type()
for i := 0; i < t.NumField(); i++ {
field := t.Field(i)
columnName := toSnakeCase(field.Name)
if exclude[columnName] {
continue
}
columns = append(columns, columnName)
placeholders = append(placeholders, "?")
args = append(args, v.Field(i).Interface())
}
default:
return 0, fmt.Errorf("obj must be a struct, pointer to struct, or map[string]any")
} }
query := fmt.Sprintf("INSERT INTO %s (%s) VALUES (%s)", query := fmt.Sprintf("INSERT INTO %s (%s) VALUES (%s)",