123 lines
2.8 KiB
Go
123 lines
2.8 KiB
Go
package database
|
|
|
|
import (
|
|
"fmt"
|
|
"reflect"
|
|
"strings"
|
|
|
|
"zombiezen.com/go/sqlite"
|
|
)
|
|
|
|
// Model is the contract for a database record whose field changes are
// tracked so they can be persisted selectively.
//
// Within this package, Set reports each change through SetDirty using the
// snake_case form of the Go field name as the key. Save, Insert and Delete
// inspect GetID, treating zero as "no row yet".
type Model interface {
	// GetTableName returns the name of the table the record lives in.
	GetTableName() string
	// GetID returns the record's ID (zero when it has not been inserted).
	GetID() int
	// SetID assigns the record's ID, e.g. after an insert.
	SetID(id int)

	// SetDirty records a pending change of field to value.
	SetDirty(field string, value any)
	// GetDirtyFields returns the pending changes, keyed by field name.
	GetDirtyFields() map[string]any
	// IsDirty reports whether any changes are pending (presumably
	// equivalent to a non-empty GetDirtyFields; verify in implementations).
	IsDirty() bool
	// ClearDirty discards the pending changes.
	ClearDirty()
}
|
|
|
|
// BaseModel provides common model functionality
|
|
type BaseModel struct {
|
|
FieldTracker
|
|
}
|
|
|
|
// Set uses reflection to set a field and track changes
|
|
func Set(model Model, field string, value any) error {
|
|
v := reflect.ValueOf(model).Elem()
|
|
fieldVal := v.FieldByName(field)
|
|
|
|
if !fieldVal.IsValid() {
|
|
return fmt.Errorf("field %s does not exist", field)
|
|
}
|
|
|
|
if !fieldVal.CanSet() {
|
|
return fmt.Errorf("field %s cannot be set", field)
|
|
}
|
|
|
|
// Get current value for comparison
|
|
currentVal := fieldVal.Interface()
|
|
|
|
// Only set if value has changed
|
|
if !reflect.DeepEqual(currentVal, value) {
|
|
// Convert value to correct type
|
|
newVal := reflect.ValueOf(value)
|
|
if newVal.Type().ConvertibleTo(fieldVal.Type()) {
|
|
fieldVal.Set(newVal.Convert(fieldVal.Type()))
|
|
|
|
// Convert field name to snake_case for database
|
|
dbField := toSnakeCase(field)
|
|
model.SetDirty(dbField, value)
|
|
} else {
|
|
return fmt.Errorf("cannot convert %T to %s", value, fieldVal.Type())
|
|
}
|
|
}
|
|
|
|
return nil
|
|
}
|
|
|
|
// toSnakeCase converts a Go CamelCase identifier to snake_case for use as a
// database column name, e.g. "CreatedAt" -> "created_at".
//
// A run of capitals is treated as one initialism. "UserID" becomes "user_id"
// and "HTTPServer" becomes "http_server", not "user_i_d" and "h_t_t_p_server".
//
// An underscore is inserted before an upper-case letter only when it starts a
// new word:
//   - it follows a lower-case letter or a digit, or
//   - it is the last capital of an initialism and is followed by a lower-case
//     letter.
//
// No underscore is added right after an existing one. Only ASCII letters are
// lower-cased; every other rune is copied unchanged.
func toSnakeCase(s string) string {
	isUpper := func(r rune) bool { return 'A' <= r && r <= 'Z' }
	isLower := func(r rune) bool { return 'a' <= r && r <= 'z' }
	isDigit := func(r rune) bool { return '0' <= r && r <= '9' }

	runes := []rune(s)
	var b strings.Builder
	b.Grow(len(s) + len(s)/2)
	for i, r := range runes {
		if !isUpper(r) {
			b.WriteRune(r)
			continue
		}
		if i > 0 {
			prev := runes[i-1]
			nextIsLower := i+1 < len(runes) && isLower(runes[i+1])
			if isLower(prev) || isDigit(prev) || (isUpper(prev) && nextIsLower) {
				b.WriteByte('_')
			}
		}
		b.WriteRune(r - 'A' + 'a')
	}
	return b.String()
}
|
|
|
|
// Save writes model's pending (dirty) field changes to its existing row by
// delegating to UpdateDirty. A model whose ID is still zero has no row yet
// (create it with Insert first), so Save rejects it with an error.
func Save(model Model) error {
	if model.GetID() != 0 {
		return UpdateDirty(model)
	}
	return fmt.Errorf("cannot save model without ID")
}
|
|
|
|
// Insert creates a new row for model in its table. Once Transaction has
// returned without error, Insert stores the generated row ID on model via
// SetID.
//
// columns is a comma-separated column list that is formatted verbatim into
// the SQL text, as is the table name. It must come from trusted code, never
// from user input. values are bound as parameters, one "?" placeholder each,
// in the same order as columns.
//
// Insert returns an error if:
//   - model already has a non-zero ID,
//   - no values are given, or
//   - the INSERT or the follow-up last_insert_rowid() query fails.
//
// On error the model's ID is not modified (assuming Transaction propagates
// the callback's error).
func Insert(model Model, columns string, values ...any) error {
	if model.GetID() != 0 {
		return fmt.Errorf("model already has ID %d, use Save() to update", model.GetID())
	}
	if len(values) == 0 {
		// An INSERT needs at least one value; "VALUES ()" is not valid SQL.
		return fmt.Errorf("cannot insert into %s: no values given", model.GetTableName())
	}

	var id int
	err := Transaction(func(tx *Tx) error {
		placeholders := strings.TrimSuffix(strings.Repeat("?,", len(values)), ",")
		query := fmt.Sprintf("INSERT INTO %s (%s) VALUES (%s)",
			model.GetTableName(), columns, placeholders)

		if err := tx.Exec(query, values...); err != nil {
			return fmt.Errorf("failed to insert: %w", err)
		}

		// Read the new row's ID inside the same transaction as the INSERT.
		err := tx.Query("SELECT last_insert_rowid()", func(stmt *sqlite.Stmt) error {
			id = stmt.ColumnInt(0)
			return nil
		})
		if err != nil {
			return fmt.Errorf("failed to get insert ID: %w", err)
		}
		return nil
	})
	if err != nil {
		return err
	}

	// Publish the ID only after the transaction succeeded, so a failed commit
	// does not leave the model looking as if it had been persisted.
	model.SetID(id)
	return nil
}
|
|
|
|
// Delete removes model's row from its table, matching on the id column. It
// returns an error, without running any SQL, if model has no ID yet.
//
// SQLite only allows bound parameters in place of literal values, never
// identifiers. The table name is therefore formatted into the SQL text, and
// GetTableName must return a trusted, code-defined name. The ID is still
// bound as a parameter.
func Delete(model Model) error {
	if model.GetID() == 0 {
		return fmt.Errorf("cannot delete model without ID")
	}
	query := fmt.Sprintf("DELETE FROM %s WHERE id = ?", model.GetTableName())
	return Exec(query, model.GetID())
}
|