130 lines
3.0 KiB
Go

package database
import (
"fmt"
"reflect"
"strings"
"zombiezen.com/go/sqlite"
)
// Model is the contract for a database record whose column changes are
// tracked, so that only modified columns need to be written back.
// Set, Save, Insert and Delete in this package all operate on a Model.
type Model interface {
	// GetTableName reports the SQL table backing the model.
	GetTableName() string

	// GetID returns the row ID; Save and Delete refuse a zero ID and Insert
	// requires one, so zero stands for "not stored yet".
	GetID() int
	// SetID assigns the row ID (Insert calls it after a successful insert).
	SetID(id int)

	// GetDirtyFields returns the pending changes, presumably keyed by the
	// column names passed to SetDirty.
	GetDirtyFields() map[string]any
	// SetDirty records that the column named field now holds value.
	SetDirty(field string, value any)
	// ClearDirty presumably discards all pending changes.
	ClearDirty()
	// IsDirty presumably reports whether any change is pending.
	IsDirty() bool
}
// BaseModel provides common model functionality by embedding FieldTracker;
// it is presumably meant to be embedded in concrete model structs so that
// FieldTracker's methods are promoted onto them.
//
// NOTE(review): FieldTracker is declared elsewhere. Presumably it supplies the
// change-tracking part of Model (GetDirtyFields, SetDirty, ClearDirty,
// IsDirty) — confirm. BaseModel declares nothing else, so GetTableName, GetID
// and SetID must come from FieldTracker or from the embedding type.
type BaseModel struct {
// Embedded (not a named field) so that its methods are promoted.
FieldTracker
}
// Set assigns value to the exported struct field named field on model and,
// if the stored value actually changed, records the change with
// model.SetDirty.
//
// model must be a non-nil pointer to a struct; anything else yields an error
// instead of a reflect panic. value is converted to the field's type using Go
// conversion rules (reflect.Value.Convert), so e.g. an int can be stored in an
// int64 field. A nil value stores the zero value of a pointer, interface, map,
// slice, channel or func field and is rejected for any other field kind.
//
// The unchanged-check runs after conversion, so re-setting an equal value of a
// different but convertible type (5 into an int64 field holding 5) does not
// mark the model dirty.
//
// The dirty entry is keyed by the field's `db` struct tag (any ",option" suffix
// is dropped), falling back to toSnakeCase(field) when the tag is absent or
// empty. The value recorded is value exactly as the caller passed it.
func Set(model Model, field string, value any) error {
	ptr := reflect.ValueOf(model)
	if ptr.Kind() != reflect.Pointer || ptr.IsNil() || ptr.Elem().Kind() != reflect.Struct {
		return fmt.Errorf("model must be a non-nil pointer to a struct, got %T", model)
	}
	v := ptr.Elem()

	fieldVal := v.FieldByName(field)
	if !fieldVal.IsValid() {
		return fmt.Errorf("field %s does not exist", field)
	}
	if !fieldVal.CanSet() {
		return fmt.Errorf("field %s cannot be set", field)
	}
	fieldType := fieldVal.Type()

	// Build the value exactly as it would be stored in the field.
	var newVal reflect.Value
	if value == nil {
		// reflect.ValueOf(nil) is the invalid Value; calling Type on it panics.
		switch fieldType.Kind() {
		case reflect.Pointer, reflect.Interface, reflect.Map, reflect.Slice,
			reflect.Chan, reflect.Func, reflect.UnsafePointer:
			newVal = reflect.Zero(fieldType)
		default:
			return fmt.Errorf("cannot convert %T to %s", value, fieldType)
		}
	} else {
		src := reflect.ValueOf(value)
		if !src.Type().ConvertibleTo(fieldType) {
			return fmt.Errorf("cannot convert %T to %s", value, fieldType)
		}
		newVal = src.Convert(fieldType)
	}

	// Compare in the field's own type so only real changes become dirty.
	if reflect.DeepEqual(fieldVal.Interface(), newVal.Interface()) {
		return nil
	}
	fieldVal.Set(newVal)

	// Always found: v.FieldByName succeeded above for the same name.
	structField, _ := v.Type().FieldByName(field)
	dbField, _, _ := strings.Cut(structField.Tag.Get("db"), ",")
	if dbField == "" {
		dbField = toSnakeCase(field) // fallback
	}
	model.SetDirty(dbField, value)
	return nil
}
// toSnakeCase turns a CamelCase identifier into snake_case.
//
// Only the ASCII letters 'A'..'Z' count as uppercase. Each one is lowered, and
// an underscore is inserted before it unless it is the first character or the
// character before it is also ASCII uppercase. Runs of capitals therefore stay
// together ("UserID" -> "user_id", "HTTPServer" -> "httpserver"). Every other
// rune is copied through unchanged.
func toSnakeCase(s string) string {
	var out strings.Builder
	// Pretend the position before the string is uppercase so that no
	// separator is ever emitted ahead of the first character.
	prevUpper := true
	for _, c := range s {
		upper := 'A' <= c && c <= 'Z'
		if upper {
			if !prevUpper {
				out.WriteByte('_')
			}
			c += 'a' - 'A'
		}
		out.WriteRune(c)
		prevUpper = upper
	}
	return out.String()
}
// Save writes only the model's dirty fields back to the database by handing
// the model to UpdateDirty. A model whose ID is zero has no row to update yet
// (use Insert for it), so Save returns an error in that case without calling
// UpdateDirty.
func Save(model Model) error {
	switch {
	case model.GetID() == 0:
		return fmt.Errorf("cannot save model without ID")
	default:
		return UpdateDirty(model)
	}
}
// Insert adds model as a new row of its table. Once the surrounding
// Transaction has returned successfully, it stores the new row ID on the
// model via SetID.
//
// columns is a comma-separated column list. Like the table name, it is spliced
// verbatim into the SQL text, so it must never come from untrusted input.
// values are bound as positional parameters (one "?" each) and must line up
// with columns.
//
// Insert returns an error if the model already has a non-zero ID, if no values
// are given, or if the INSERT or the row-ID lookup fails.
func Insert(model Model, columns string, values ...any) error {
	if model.GetID() != 0 {
		return fmt.Errorf("model already has ID %d, use Save() to update", model.GetID())
	}
	if len(values) == 0 {
		// Without this guard the placeholder list below would be empty and the
		// statement malformed (the previous trimming code panicked here).
		return fmt.Errorf("cannot insert into %s: no values given", model.GetTableName())
	}

	placeholders := strings.TrimSuffix(strings.Repeat("?,", len(values)), ",")
	query := fmt.Sprintf("INSERT INTO %s (%s) VALUES (%s)",
		model.GetTableName(), columns, placeholders)

	var id int
	err := Transaction(func(tx *Tx) error {
		if err := tx.Exec(query, values...); err != nil {
			return fmt.Errorf("failed to insert: %w", err)
		}
		// last_insert_rowid() is scoped to the database connection, so it is
		// read through the same tx that ran the INSERT.
		if err := tx.Query("SELECT last_insert_rowid()", func(stmt *sqlite.Stmt) error {
			id = stmt.ColumnInt(0)
			return nil
		}); err != nil {
			return fmt.Errorf("failed to get insert ID: %w", err)
		}
		return nil
	})
	if err != nil {
		return err
	}
	// Assign the ID only after Transaction reports success. This way a failed
	// commit cannot leave the model claiming a row that was never stored.
	// (Assumes Transaction returns non-nil whenever it does not commit.)
	model.SetID(id)
	return nil
}
// Delete removes the model's row from its table, matching on the id column.
// It returns an error for a model whose ID is zero.
//
// SQLite only accepts bound parameters where a literal value may appear, never
// as a table or column name. The table name is therefore formatted into the
// SQL text, as Insert does, and only the ID is bound. The table name comes from
// model.GetTableName() unescaped, so it must be a trusted identifier.
func Delete(model Model) error {
	if model.GetID() == 0 {
		return fmt.Errorf("cannot delete model without ID")
	}
	query := fmt.Sprintf("DELETE FROM %s WHERE id = ?", model.GetTableName())
	return Exec(query, model.GetID())
}