diff --git a/db.go b/db.go index 5a97b96..90f4788 100644 --- a/db.go +++ b/db.go @@ -324,7 +324,7 @@ func (db *DB) Update(tableName string, fields map[string]any, whereField string, }) } -// Insert inserts a struct into the database +// Insert inserts a struct or map into the database func (db *DB) Insert(tableName string, obj any, excludeFields ...string) (int64, error) { conn, err := db.pool.Take(context.Background()) if err != nil { @@ -332,12 +332,6 @@ func (db *DB) Insert(tableName string, obj any, excludeFields ...string) (int64, } defer db.pool.Put(conn) - v := reflect.ValueOf(obj) - if v.Kind() == reflect.Pointer { - v = v.Elem() - } - t := v.Type() - exclude := make(map[string]bool) for _, field := range excludeFields { exclude[toSnakeCase(field)] = true @@ -347,16 +341,41 @@ func (db *DB) Insert(tableName string, obj any, excludeFields ...string) (int64, var placeholders []string var args []any - for i := 0; i < t.NumField(); i++ { - field := t.Field(i) - columnName := toSnakeCase(field.Name) - if exclude[columnName] { - continue + v := reflect.ValueOf(obj) + if v.Kind() == reflect.Pointer { + v = v.Elem() + } + + switch v.Kind() { + case reflect.Map: + // Handle map[string]any + m := obj.(map[string]any) + for key, value := range m { + columnName := toSnakeCase(key) + if exclude[columnName] { + continue + } + columns = append(columns, columnName) + placeholders = append(placeholders, "?") + args = append(args, value) } - columns = append(columns, columnName) - placeholders = append(placeholders, "?") - args = append(args, v.Field(i).Interface()) + case reflect.Struct: + // Handle struct + t := v.Type() + for i := 0; i < t.NumField(); i++ { + field := t.Field(i) + columnName := toSnakeCase(field.Name) + if exclude[columnName] { + continue + } + columns = append(columns, columnName) + placeholders = append(placeholders, "?") + args = append(args, v.Field(i).Interface()) + } + + default: + return 0, fmt.Errorf("obj must be a struct, pointer to struct, or map[string]any") } query := fmt.Sprintf("INSERT INTO %s (%s) VALUES (%s)",