package database

import (
	"fmt"
	"reflect"
	"strings"

	"zombiezen.com/go/sqlite"
)

// Model is the interface for trackable database models: it exposes the
// table name, the integer primary key, and dirty-field bookkeeping.
type Model interface {
	GetTableName() string
	GetID() int
	SetID(id int)
	GetDirtyFields() map[string]any
	SetDirty(field string, value any)
	ClearDirty()
	IsDirty() bool
}

// BaseModel provides common model functionality by embedding FieldTracker
// (declared elsewhere in this package).
type BaseModel struct {
	FieldTracker
}

// Set uses reflection to assign value to the exported struct field named
// field on model and, when the stored value actually changes, records the
// change via model.SetDirty under the snake_case form of the field name.
//
// model must be a non-nil pointer to a struct. A nil value is accepted only
// for fields whose kind can hold nil (pointer, map, slice, interface, chan,
// func); the field is then reset to its zero value.
//
// It returns an error if model is not a pointer to a struct, if the field
// does not exist or cannot be set, or if value is not convertible to the
// field's type.
func Set(model Model, field string, value any) error {
	// Elem and FieldByName panic on anything that is not a non-nil pointer
	// to a struct, so validate the shape first.
	ptr := reflect.ValueOf(model)
	if ptr.Kind() != reflect.Pointer || ptr.IsNil() || ptr.Elem().Kind() != reflect.Struct {
		return fmt.Errorf("model must be a non-nil pointer to a struct, got %T", model)
	}

	fieldVal := ptr.Elem().FieldByName(field)
	if !fieldVal.IsValid() {
		return fmt.Errorf("field %s does not exist", field)
	}
	if !fieldVal.CanSet() {
		return fmt.Errorf("field %s cannot be set", field)
	}

	// reflect.ValueOf(nil) is the zero Value and calling Type on it panics,
	// so untyped nil is handled separately from the conversion path below.
	if value == nil {
		switch fieldVal.Kind() {
		case reflect.Chan, reflect.Func, reflect.Interface, reflect.Map, reflect.Pointer, reflect.Slice:
			if !fieldVal.IsNil() {
				fieldVal.Set(reflect.Zero(fieldVal.Type()))
				model.SetDirty(toSnakeCase(field), nil)
			}
			return nil
		default:
			return fmt.Errorf("cannot convert %T to %s", value, fieldVal.Type())
		}
	}

	// Only write and mark dirty when the value has changed.
	if reflect.DeepEqual(fieldVal.Interface(), value) {
		return nil
	}

	newVal := reflect.ValueOf(value)
	if !newVal.Type().ConvertibleTo(fieldVal.Type()) {
		return fmt.Errorf("cannot convert %T to %s", value, fieldVal.Type())
	}
	fieldVal.Set(newVal.Convert(fieldVal.Type()))

	// The dirty map is keyed by the database column name and stores the
	// caller-supplied value, not the converted one.
	model.SetDirty(toSnakeCase(field), value)
	return nil
}

// toSnakeCase converts CamelCase to snake_case. Every ASCII uppercase
// letter is lowercased and, unless it is the first character, preceded by
// an underscore; all other runes are copied unchanged. Consecutive capitals
// are therefore split individually (for example "UserID" becomes
// "user_i_d").
func toSnakeCase(s string) string {
	var result strings.Builder
	for i, r := range s {
		if r < 'A' || r > 'Z' {
			result.WriteRune(r)
			continue
		}
		if i > 0 {
			result.WriteByte('_')
		}
		result.WriteRune(r - 'A' + 'a')
	}
	return result.String()
}

// Save updates only dirty fields by delegating to UpdateDirty. It returns
// an error if the model has no ID (GetID() == 0).
func Save(model Model) error {
	if model.GetID() == 0 {
		return fmt.Errorf("cannot save model without ID")
	}
	return UpdateDirty(model)
}

// Insert creates a new record and sets the ID on model.
//
// columns is a comma-separated column list and values are the matching
// bound parameters, one "?" placeholder per value. The table name and
// columns are interpolated directly into the SQL text, so they must come
// from trusted code, never from user input.
//
// It returns an error if the model already has an ID, if no values are
// supplied, or if the insert or the last_insert_rowid() lookup fails. Both
// statements run inside a single Transaction callback.
func Insert(model Model, columns string, values ...any) error {
	if model.GetID() != 0 {
		return fmt.Errorf("model already has ID %d, use Save() to update", model.GetID())
	}
	// Without this guard the placeholder trimming below would slice an
	// empty string with a negative bound and panic.
	if len(values) == 0 {
		return fmt.Errorf("cannot insert without values")
	}

	return Transaction(func(tx *Tx) error {
		placeholders := strings.TrimSuffix(strings.Repeat("?,", len(values)), ",")
		query := fmt.Sprintf("INSERT INTO %s (%s) VALUES (%s)",
			model.GetTableName(), columns, placeholders)

		if err := tx.Exec(query, values...); err != nil {
			return fmt.Errorf("failed to insert: %w", err)
		}

		var id int
		err := tx.Query("SELECT last_insert_rowid()", func(stmt *sqlite.Stmt) error {
			id = stmt.ColumnInt(0)
			return nil
		})
		if err != nil {
			return fmt.Errorf("failed to get insert ID: %w", err)
		}

		model.SetID(id)
		return nil
	})
}

// Delete removes the record whose id equals model.GetID() from the model's
// table. It returns an error if the model has no ID (GetID() == 0) or if
// the statement fails.
func Delete(model Model) error {
	if model.GetID() == 0 {
		return fmt.Errorf("cannot delete model without ID")
	}
	// SQL parameters can only stand in for values, not identifiers, so the
	// table name is interpolated (as Insert does) and only the id is bound.
	query := fmt.Sprintf("DELETE FROM %s WHERE id = ?", model.GetTableName())
	return Exec(query, model.GetID())
}