fix rules

This commit is contained in:
Sky Johnson 2025-08-05 21:45:03 -05:00
parent 37574a7db2
commit c67a7bf6c6
2 changed files with 1036 additions and 209 deletions

View File

@ -5,17 +5,18 @@ import (
"log" "log"
"strconv" "strconv"
"eq2emu/internal/database" "zombiezen.com/go/sqlite"
"zombiezen.com/go/sqlite/sqlitex"
) )
// DatabaseService handles rule database operations // DatabaseService handles rule database operations
// Converted from C++ WorldDatabase rule functions // Converted from C++ WorldDatabase rule functions
type DatabaseService struct { type DatabaseService struct {
db *database.DB db *sqlite.Conn
} }
// NewDatabaseService creates a new database service instance // NewDatabaseService creates a new database service instance
func NewDatabaseService(db *database.DB) *DatabaseService { func NewDatabaseService(db *sqlite.Conn) *DatabaseService {
return &DatabaseService{ return &DatabaseService{
db: db, db: db,
} }
@ -32,18 +33,20 @@ func (ds *DatabaseService) LoadGlobalRuleSet(ruleManager *RuleManager) error {
// Get the default ruleset ID from variables table // Get the default ruleset ID from variables table
query := "SELECT variable_value FROM variables WHERE variable_name = ?" query := "SELECT variable_value FROM variables WHERE variable_name = ?"
row, err := ds.db.QueryRow(query, DefaultRuleSetIDVar) stmt := ds.db.Prep(query)
stmt.BindText(1, DefaultRuleSetIDVar)
hasRow, err := stmt.Step()
if err != nil { if err != nil {
return fmt.Errorf("error querying default ruleset ID: %v", err) return fmt.Errorf("error querying default ruleset ID: %v", err)
} }
if row == nil { if !hasRow {
log.Printf("[Rules] Variables table is missing %s variable name, using code-default rules", DefaultRuleSetIDVar) log.Printf("[Rules] Variables table is missing %s variable name, using code-default rules", DefaultRuleSetIDVar)
return nil return nil
} }
defer row.Close()
variableValue := row.Text(0) variableValue := stmt.ColumnText(0)
if id, err := strconv.ParseInt(variableValue, 10, 32); err == nil { if id, err := strconv.ParseInt(variableValue, 10, 32); err == nil {
ruleSetID = int32(id) ruleSetID = int32(id)
log.Printf("[Rules] Loading Global Ruleset id %d", ruleSetID) log.Printf("[Rules] Loading Global Ruleset id %d", ruleSetID)
@ -79,9 +82,20 @@ func (ds *DatabaseService) LoadRuleSets(ruleManager *RuleManager, reload bool) e
query := "SELECT ruleset_id, ruleset_name FROM rulesets WHERE ruleset_active > 0" query := "SELECT ruleset_id, ruleset_name FROM rulesets WHERE ruleset_active > 0"
loadedCount := 0 loadedCount := 0
err := ds.db.Query(query, func(row *database.Row) error { stmt := ds.db.Prep(query)
ruleSetID := int32(row.Int64(0)) defer stmt.Finalize()
ruleSetName := row.Text(1)
for {
hasRow, err := stmt.Step()
if err != nil {
return fmt.Errorf("error querying rule sets: %v", err)
}
if !hasRow {
break
}
ruleSetID := int32(stmt.ColumnInt64(0))
ruleSetName := stmt.ColumnText(1)
ruleSet := NewRuleSet() ruleSet := NewRuleSet()
ruleSet.SetID(ruleSetID) ruleSet.SetID(ruleSetID)
@ -93,23 +107,18 @@ func (ds *DatabaseService) LoadRuleSets(ruleManager *RuleManager, reload bool) e
err := ds.LoadRuleSetDetails(ruleManager, ruleSet) err := ds.LoadRuleSetDetails(ruleManager, ruleSet)
if err != nil { if err != nil {
log.Printf("[Rules] Error loading rule set details for '%s': %v", ruleSetName, err) log.Printf("[Rules] Error loading rule set details for '%s': %v", ruleSetName, err)
return nil // Continue with other rule sets continue // Continue with other rule sets
} }
loadedCount++ loadedCount++
} else { } else {
log.Printf("[Rules] Unable to add rule set '%s' - ID %d already exists", ruleSetName, ruleSetID) log.Printf("[Rules] Unable to add rule set '%s' - ID %d already exists", ruleSetName, ruleSetID)
} }
return nil
})
if err != nil {
return fmt.Errorf("error querying rule sets: %v", err)
} }
log.Printf("[Rules] Loaded %d Rule Sets", loadedCount) log.Printf("[Rules] Loaded %d Rule Sets", loadedCount)
// Load global rule set // Load global rule set
err = ds.LoadGlobalRuleSet(ruleManager) err := ds.LoadGlobalRuleSet(ruleManager)
if err != nil { if err != nil {
return fmt.Errorf("error loading global rule set: %v", err) return fmt.Errorf("error loading global rule set: %v", err)
} }
@ -136,26 +145,33 @@ func (ds *DatabaseService) LoadRuleSetDetails(ruleManager *RuleManager, ruleSet
query := "SELECT rule_category, rule_type, rule_value FROM ruleset_details WHERE ruleset_id = ?" query := "SELECT rule_category, rule_type, rule_value FROM ruleset_details WHERE ruleset_id = ?"
loadedRules := 0 loadedRules := 0
err := ds.db.Query(query, func(row *database.Row) error { stmt := ds.db.Prep(query)
categoryName := row.Text(0) stmt.BindInt64(1, int64(ruleSet.GetID()))
typeName := row.Text(1) defer stmt.Finalize()
ruleValue := row.Text(2)
for {
hasRow, err := stmt.Step()
if err != nil {
return fmt.Errorf("error querying rule set details: %v", err)
}
if !hasRow {
break
}
categoryName := stmt.ColumnText(0)
typeName := stmt.ColumnText(1)
ruleValue := stmt.ColumnText(2)
// Find the rule by name // Find the rule by name
rule := ruleSet.GetRuleByName(categoryName, typeName) rule := ruleSet.GetRuleByName(categoryName, typeName)
if rule == nil { if rule == nil {
log.Printf("[Rules] Unknown rule with category '%s' and type '%s'", categoryName, typeName) log.Printf("[Rules] Unknown rule with category '%s' and type '%s'", categoryName, typeName)
return nil // Continue with other rules continue // Continue with other rules
} }
log.Printf("[Rules] Setting rule category '%s', type '%s' to value: %s", categoryName, typeName, ruleValue) log.Printf("[Rules] Setting rule category '%s', type '%s' to value: %s", categoryName, typeName, ruleValue)
rule.SetValue(ruleValue) rule.SetValue(ruleValue)
loadedRules++ loadedRules++
return nil
}, ruleSet.GetID())
if err != nil {
return fmt.Errorf("error querying rule set details: %v", err)
} }
log.Printf("[Rules] Loaded %d rule overrides for rule set '%s'", loadedRules, ruleSet.GetName()) log.Printf("[Rules] Loaded %d rule overrides for rule set '%s'", loadedRules, ruleSet.GetName())
@ -175,45 +191,62 @@ func (ds *DatabaseService) SaveRuleSet(ruleSet *RuleSet) error {
} }
// Use transaction for atomicity // Use transaction for atomicity
return ds.db.Transaction(func(tx *database.DB) error { var err error
// Insert or update rule set defer sqlitex.Save(ds.db)(&err)
query := `INSERT INTO rulesets (ruleset_id, ruleset_name, ruleset_active)
VALUES (?, ?, 1)
ON CONFLICT(ruleset_id) DO UPDATE SET
ruleset_name = excluded.ruleset_name,
ruleset_active = excluded.ruleset_active`
err := tx.Exec(query, ruleSet.GetID(), ruleSet.GetName()) // Insert or update rule set
if err != nil { query := `INSERT INTO rulesets (ruleset_id, ruleset_name, ruleset_active)
return fmt.Errorf("error saving rule set: %v", err) VALUES (?, ?, 1)
} ON CONFLICT(ruleset_id) DO UPDATE SET
ruleset_name = excluded.ruleset_name,
ruleset_active = excluded.ruleset_active`
// Delete existing rule details stmt := ds.db.Prep(query)
err = tx.Exec("DELETE FROM ruleset_details WHERE ruleset_id = ?", ruleSet.GetID()) stmt.BindInt64(1, int64(ruleSet.GetID()))
if err != nil { stmt.BindText(2, ruleSet.GetName())
return fmt.Errorf("error deleting existing rule details: %v", err)
} _, err = stmt.Step()
if err != nil {
return fmt.Errorf("error saving rule set: %v", err)
}
stmt.Finalize()
// Insert rule details // Delete existing rule details
rules := ruleSet.GetRules() deleteQuery := "DELETE FROM ruleset_details WHERE ruleset_id = ?"
for _, categoryMap := range rules { deleteStmt := ds.db.Prep(deleteQuery)
for _, rule := range categoryMap { deleteStmt.BindInt64(1, int64(ruleSet.GetID()))
if rule.IsValid() { _, err = deleteStmt.Step()
combined := rule.GetCombined() if err != nil {
parts := splitCombined(combined) return fmt.Errorf("error deleting existing rule details: %v", err)
if len(parts) == 2 { }
query := "INSERT INTO ruleset_details (ruleset_id, rule_category, rule_type, rule_value) VALUES (?, ?, ?, ?)" deleteStmt.Finalize()
err = tx.Exec(query, ruleSet.GetID(), parts[0], parts[1], rule.GetValue())
if err != nil { // Insert rule details
return fmt.Errorf("error saving rule detail: %v", err) insertQuery := "INSERT INTO ruleset_details (ruleset_id, rule_category, rule_type, rule_value) VALUES (?, ?, ?, ?)"
} rules := ruleSet.GetRules()
for _, categoryMap := range rules {
for _, rule := range categoryMap {
if rule.IsValid() {
combined := rule.GetCombined()
parts := splitCombined(combined)
if len(parts) == 2 {
insertStmt := ds.db.Prep(insertQuery)
insertStmt.BindInt64(1, int64(ruleSet.GetID()))
insertStmt.BindText(2, parts[0])
insertStmt.BindText(3, parts[1])
insertStmt.BindText(4, rule.GetValue())
_, err = insertStmt.Step()
insertStmt.Finalize()
if err != nil {
return fmt.Errorf("error saving rule detail: %v", err)
} }
} }
} }
} }
}
return nil return nil
})
} }
// DeleteRuleSet deletes a rule set from the database // DeleteRuleSet deletes a rule set from the database
@ -223,21 +256,28 @@ func (ds *DatabaseService) DeleteRuleSet(ruleSetID int32) error {
} }
// Use transaction for atomicity // Use transaction for atomicity
return ds.db.Transaction(func(tx *database.DB) error { var err error
// Delete rule details first (foreign key constraint) defer sqlitex.Save(ds.db)(&err)
err := tx.Exec("DELETE FROM ruleset_details WHERE ruleset_id = ?", ruleSetID)
if err != nil {
return fmt.Errorf("error deleting rule details: %v", err)
}
// Delete rule set // Delete rule details first (foreign key constraint)
err = tx.Exec("DELETE FROM rulesets WHERE ruleset_id = ?", ruleSetID) detailsStmt := ds.db.Prep("DELETE FROM ruleset_details WHERE ruleset_id = ?")
if err != nil { detailsStmt.BindInt64(1, int64(ruleSetID))
return fmt.Errorf("error deleting rule set: %v", err) _, err = detailsStmt.Step()
} detailsStmt.Finalize()
if err != nil {
return fmt.Errorf("error deleting rule details: %v", err)
}
return nil // Delete rule set
}) rulesetStmt := ds.db.Prep("DELETE FROM rulesets WHERE ruleset_id = ?")
rulesetStmt.BindInt64(1, int64(ruleSetID))
_, err = rulesetStmt.Step()
rulesetStmt.Finalize()
if err != nil {
return fmt.Errorf("error deleting rule set: %v", err)
}
return nil
} }
// SetDefaultRuleSet sets the default rule set ID in the variables table // SetDefaultRuleSet sets the default rule set ID in the variables table
@ -251,7 +291,12 @@ func (ds *DatabaseService) SetDefaultRuleSet(ruleSetID int32) error {
ON CONFLICT(variable_name) DO UPDATE SET ON CONFLICT(variable_name) DO UPDATE SET
variable_value = excluded.variable_value` variable_value = excluded.variable_value`
err := ds.db.Exec(query, DefaultRuleSetIDVar, strconv.Itoa(int(ruleSetID))) stmt := ds.db.Prep(query)
stmt.BindText(1, DefaultRuleSetIDVar)
stmt.BindText(2, strconv.Itoa(int(ruleSetID)))
_, err := stmt.Step()
stmt.Finalize()
if err != nil { if err != nil {
return fmt.Errorf("error setting default rule set: %v", err) return fmt.Errorf("error setting default rule set: %v", err)
} }
@ -266,17 +311,19 @@ func (ds *DatabaseService) GetDefaultRuleSetID() (int32, error) {
} }
query := "SELECT variable_value FROM variables WHERE variable_name = ?" query := "SELECT variable_value FROM variables WHERE variable_name = ?"
row, err := ds.db.QueryRow(query, DefaultRuleSetIDVar) stmt := ds.db.Prep(query)
stmt.BindText(1, DefaultRuleSetIDVar)
hasRow, err := stmt.Step()
if err != nil { if err != nil {
return 0, fmt.Errorf("error querying default ruleset ID: %v", err) return 0, fmt.Errorf("error querying default ruleset ID: %v", err)
} }
if row == nil { if !hasRow {
return 0, fmt.Errorf("default ruleset ID not found in variables table") return 0, fmt.Errorf("default ruleset ID not found in variables table")
} }
defer row.Close()
variableValue := row.Text(0) variableValue := stmt.ColumnText(0)
if id, err := strconv.ParseInt(variableValue, 10, 32); err == nil { if id, err := strconv.ParseInt(variableValue, 10, 32); err == nil {
return int32(id), nil return int32(id), nil
} }
@ -293,18 +340,24 @@ func (ds *DatabaseService) GetRuleSetList() ([]RuleSetInfo, error) {
query := "SELECT ruleset_id, ruleset_name, ruleset_active FROM rulesets ORDER BY ruleset_id" query := "SELECT ruleset_id, ruleset_name, ruleset_active FROM rulesets ORDER BY ruleset_id"
var ruleSets []RuleSetInfo var ruleSets []RuleSetInfo
err := ds.db.Query(query, func(row *database.Row) error { stmt := ds.db.Prep(query)
defer stmt.Finalize()
for {
hasRow, err := stmt.Step()
if err != nil {
return nil, fmt.Errorf("error querying rule sets: %v", err)
}
if !hasRow {
break
}
info := RuleSetInfo{ info := RuleSetInfo{
ID: int32(row.Int64(0)), ID: int32(stmt.ColumnInt64(0)),
Name: row.Text(1), Name: stmt.ColumnText(1),
Active: row.Bool(2), Active: stmt.ColumnInt64(2) > 0, // Convert int to bool
} }
ruleSets = append(ruleSets, info) ruleSets = append(ruleSets, info)
return nil
})
if err != nil {
return nil, fmt.Errorf("error querying rule sets: %v", err)
} }
return ruleSets, nil return ruleSets, nil
@ -316,38 +369,26 @@ func (ds *DatabaseService) ValidateDatabase() error {
return fmt.Errorf("database not initialized") return fmt.Errorf("database not initialized")
} }
// Check if rulesets table exists tables := []string{"rulesets", "ruleset_details", "variables"}
query := "SELECT COUNT(*) FROM sqlite_master WHERE type='table' AND name='rulesets'" query := "SELECT COUNT(*) FROM sqlite_master WHERE type='table' AND name=?"
row, err := ds.db.QueryRow(query)
if err != nil {
return fmt.Errorf("error checking rulesets table: %v", err)
}
if row == nil || row.Int(0) == 0 {
return fmt.Errorf("rulesets table does not exist")
}
row.Close()
// Check if ruleset_details table exists for _, table := range tables {
query = "SELECT COUNT(*) FROM sqlite_master WHERE type='table' AND name='ruleset_details'" stmt := ds.db.Prep(query)
row, err = ds.db.QueryRow(query) stmt.BindText(1, table)
if err != nil {
return fmt.Errorf("error checking ruleset_details table: %v", err) hasRow, err := stmt.Step()
if err != nil {
stmt.Finalize()
return fmt.Errorf("error checking %s table: %v", table, err)
}
count := stmt.ColumnInt64(0)
stmt.Finalize()
if !hasRow || count == 0 {
return fmt.Errorf("%s table does not exist", table)
}
} }
if row == nil || row.Int(0) == 0 {
return fmt.Errorf("ruleset_details table does not exist")
}
row.Close()
// Check if variables table exists
query = "SELECT COUNT(*) FROM sqlite_master WHERE type='table' AND name='variables'"
row, err = ds.db.QueryRow(query)
if err != nil {
return fmt.Errorf("error checking variables table: %v", err)
}
if row == nil || row.Int(0) == 0 {
return fmt.Errorf("variables table does not exist")
}
row.Close()
return nil return nil
} }
@ -383,7 +424,9 @@ func (ds *DatabaseService) CreateRulesTables() error {
ruleset_active INTEGER NOT NULL DEFAULT 0 ruleset_active INTEGER NOT NULL DEFAULT 0
)` )`
err := ds.db.Exec(createRuleSets) stmt := ds.db.Prep(createRuleSets)
_, err := stmt.Step()
stmt.Finalize()
if err != nil { if err != nil {
return fmt.Errorf("error creating rulesets table: %v", err) return fmt.Errorf("error creating rulesets table: %v", err)
} }
@ -400,7 +443,9 @@ func (ds *DatabaseService) CreateRulesTables() error {
FOREIGN KEY (ruleset_id) REFERENCES rulesets(ruleset_id) ON DELETE CASCADE FOREIGN KEY (ruleset_id) REFERENCES rulesets(ruleset_id) ON DELETE CASCADE
)` )`
err = ds.db.Exec(createRuleSetDetails) stmt = ds.db.Prep(createRuleSetDetails)
_, err = stmt.Step()
stmt.Finalize()
if err != nil { if err != nil {
return fmt.Errorf("error creating ruleset_details table: %v", err) return fmt.Errorf("error creating ruleset_details table: %v", err)
} }
@ -413,7 +458,9 @@ func (ds *DatabaseService) CreateRulesTables() error {
comment TEXT comment TEXT
)` )`
err = ds.db.Exec(createVariables) stmt = ds.db.Prep(createVariables)
_, err = stmt.Step()
stmt.Finalize()
if err != nil { if err != nil {
return fmt.Errorf("error creating variables table: %v", err) return fmt.Errorf("error creating variables table: %v", err)
} }
@ -426,7 +473,9 @@ func (ds *DatabaseService) CreateRulesTables() error {
} }
for _, indexSQL := range indexes { for _, indexSQL := range indexes {
err = ds.db.Exec(indexSQL) stmt = ds.db.Prep(indexSQL)
_, err = stmt.Step()
stmt.Finalize()
if err != nil { if err != nil {
return fmt.Errorf("error creating index: %v", err) return fmt.Errorf("error creating index: %v", err)
} }

File diff suppressed because it is too large Load Diff