From c2eeaa2f429fa729d2d338c5af129bdbc237555d Mon Sep 17 00:00:00 2001 From: Sky Johnson Date: Wed, 13 Aug 2025 22:24:40 -0500 Subject: [PATCH] migrate all models to in-memory --- internal/babble/babble.go | 414 +++++++++++++++--------- internal/control/control.go | 250 ++++++++++----- internal/database/database.go | 201 ------------ internal/database/database_test.go | 72 ----- internal/database/model.go | 129 -------- internal/database/trackable.go | 82 ----- internal/drops/drops.go | 325 +++++++++++++------ internal/forum/forum.go | 473 +++++++++++++++++++--------- internal/items/items.go | 286 +++++++++++------ internal/monsters/monsters.go | 481 ++++++++++------------------ internal/news/news.go | 394 ++++++++++++++--------- internal/server/server.go | 11 +- internal/spells/spells.go | 396 +++++++++++++++-------- internal/store/store.go | 198 ++++++++++++ internal/towns/towns.go | 434 ++++++++++++++++--------- internal/users/users.go | 488 +++++++++++++++++++---------- main.go | 9 +- test_load_data.go | 34 -- 18 files changed, 2682 insertions(+), 1995 deletions(-) delete mode 100644 internal/database/database.go delete mode 100644 internal/database/database_test.go delete mode 100644 internal/database/model.go delete mode 100644 internal/database/trackable.go create mode 100644 internal/store/store.go delete mode 100644 test_load_data.go diff --git a/internal/babble/babble.go b/internal/babble/babble.go index efdd25b..498a33c 100644 --- a/internal/babble/babble.go +++ b/internal/babble/babble.go @@ -1,48 +1,32 @@ package babble import ( + "dk/internal/store" "fmt" + "sort" "strings" + "sync" "time" - - "dk/internal/database" - "dk/internal/helpers/scanner" - - "zombiezen.com/go/sqlite" ) -// Babble represents a global chat message in the database +// Babble represents a global chat message in the game type Babble struct { - database.BaseModel - - ID int `db:"id" json:"id"` - Posted int64 `db:"posted" json:"posted"` - Author string `db:"author" json:"author"` - Babble string `db:"babble" json:"babble"` -} - -func (b *Babble) GetTableName() string { - return "babble" -} - -func (b *Babble) GetID() int { - return b.ID -} - -func (b *Babble) SetID(id int) { - b.ID = id -} - -func (b *Babble) Set(field string, value any) error { - return database.Set(b, field, value) + ID int `json:"id"` + Posted int64 `json:"posted"` + Author string `json:"author"` + Babble string `json:"babble"` } func (b *Babble) Save() error { - return database.Save(b) + babbleStore := GetStore() + babbleStore.UpdateBabble(b) + return nil } func (b *Babble) Delete() error { - return database.Delete(b) + babbleStore := GetStore() + babbleStore.RemoveBabble(b.ID) + return nil } // Creates a new Babble with sensible defaults @@ -54,181 +38,315 @@ func New() *Babble { } } -var babbleScanner = scanner.New[Babble]() - -// Returns the column list for babble queries -func babbleColumns() string { - return babbleScanner.Columns() +// Validate checks if babble has valid values +func (b *Babble) Validate() error { + if b.Posted <= 0 { + return fmt.Errorf("babble Posted timestamp must be positive") + } + if strings.TrimSpace(b.Author) == "" { + return fmt.Errorf("babble Author cannot be empty") + } + if strings.TrimSpace(b.Babble) == "" { + return fmt.Errorf("babble message cannot be empty") + } + return nil } -// Populates a Babble struct using the fast scanner -func scanBabble(stmt *sqlite.Stmt) *Babble { - babble := &Babble{} - babbleScanner.Scan(stmt, babble) - return babble +// BabbleStore provides in-memory storage with O(1) lookups and babble-specific indices +type BabbleStore struct { + *store.BaseStore[Babble] // Embedded generic store + byAuthor map[string][]int // Author (lowercase) -> []ID + allByPosted []int // All IDs sorted by posted DESC, id DESC + mu sync.RWMutex // Protects indices +} + +// Global in-memory store +var babbleStore *BabbleStore +var storeOnce sync.Once + +// Initialize the in-memory store +func initStore() { + babbleStore = &BabbleStore{ + BaseStore: store.NewBaseStore[Babble](), + byAuthor: make(map[string][]int), + allByPosted: make([]int, 0), + } +} + +// GetStore returns the global babble store +func GetStore() *BabbleStore { + storeOnce.Do(initStore) + return babbleStore +} + +// AddBabble adds a babble message to the in-memory store and updates all indices +func (bs *BabbleStore) AddBabble(babble *Babble) { + bs.mu.Lock() + defer bs.mu.Unlock() + + // Validate babble + if err := babble.Validate(); err != nil { + return + } + + // Add to base store + bs.Add(babble.ID, babble) + + // Rebuild indices + bs.rebuildIndicesUnsafe() +} + +// RemoveBabble removes a babble message from the store and updates indices +func (bs *BabbleStore) RemoveBabble(id int) { + bs.mu.Lock() + defer bs.mu.Unlock() + + // Remove from base store + bs.Remove(id) + + // Rebuild indices + bs.rebuildIndicesUnsafe() +} + +// UpdateBabble updates a babble message efficiently +func (bs *BabbleStore) UpdateBabble(babble *Babble) { + bs.mu.Lock() + defer bs.mu.Unlock() + + // Validate babble + if err := babble.Validate(); err != nil { + return + } + + // Update base store + bs.Add(babble.ID, babble) + + // Rebuild indices + bs.rebuildIndicesUnsafe() +} + +// LoadData loads babble data from JSON file, or starts with empty store +func LoadData(dataPath string) error { + bs := GetStore() + + // Load from base store, which handles JSON loading + if err := bs.BaseStore.LoadData(dataPath); err != nil { + return err + } + + // Rebuild indices from loaded data + bs.rebuildIndices() + return nil +} + +// SaveData saves babble data to JSON file +func SaveData(dataPath string) error { + bs := GetStore() + return bs.BaseStore.SaveData(dataPath) +} + +// rebuildIndicesUnsafe rebuilds all indices from base store data (caller must hold lock) +func (bs *BabbleStore) rebuildIndicesUnsafe() { + // Clear indices + bs.byAuthor = make(map[string][]int) + bs.allByPosted = make([]int, 0) + + // Collect all babbles and build indices + allBabbles := bs.GetAll() + + for id, babble := range allBabbles { + // Author index (case-insensitive) + authorKey := strings.ToLower(babble.Author) + bs.byAuthor[authorKey] = append(bs.byAuthor[authorKey], id) + + // All IDs + bs.allByPosted = append(bs.allByPosted, id) + } + + // Sort allByPosted by posted DESC, then ID DESC + sort.Slice(bs.allByPosted, func(i, j int) bool { + babbleI, _ := bs.GetByID(bs.allByPosted[i]) + babbleJ, _ := bs.GetByID(bs.allByPosted[j]) + if babbleI.Posted != babbleJ.Posted { + return babbleI.Posted > babbleJ.Posted // DESC + } + return bs.allByPosted[i] > bs.allByPosted[j] // DESC + }) + + // Sort author indices by posted DESC, then ID DESC + for author := range bs.byAuthor { + sort.Slice(bs.byAuthor[author], func(i, j int) bool { + babbleI, _ := bs.GetByID(bs.byAuthor[author][i]) + babbleJ, _ := bs.GetByID(bs.byAuthor[author][j]) + if babbleI.Posted != babbleJ.Posted { + return babbleI.Posted > babbleJ.Posted // DESC + } + return bs.byAuthor[author][i] > bs.byAuthor[author][j] // DESC + }) + } +} + +// rebuildIndices rebuilds all babble-specific indices from base store data +func (bs *BabbleStore) rebuildIndices() { + bs.mu.Lock() + defer bs.mu.Unlock() + bs.rebuildIndicesUnsafe() } // Retrieves a babble message by ID func Find(id int) (*Babble, error) { - var babble *Babble - - query := `SELECT ` + babbleColumns() + ` FROM babble WHERE id = ?` - - err := database.Query(query, func(stmt *sqlite.Stmt) error { - babble = scanBabble(stmt) - return nil - }, id) - - if err != nil { - return nil, fmt.Errorf("failed to find babble: %w", err) - } - - if babble == nil { + bs := GetStore() + babble, exists := bs.GetByID(id) + if !exists { return nil, fmt.Errorf("babble with ID %d not found", id) } - return babble, nil } // Retrieves all babble messages ordered by posted time (newest first) func All() ([]*Babble, error) { - var babbles []*Babble + bs := GetStore() + bs.mu.RLock() + defer bs.mu.RUnlock() - query := `SELECT ` + babbleColumns() + ` FROM babble ORDER BY posted DESC, id DESC` - - err := database.Query(query, func(stmt *sqlite.Stmt) error { - babble := scanBabble(stmt) - babbles = append(babbles, babble) - return nil - }) - - if err != nil { - return nil, fmt.Errorf("failed to retrieve all babble: %w", err) + result := make([]*Babble, 0, len(bs.allByPosted)) + for _, id := range bs.allByPosted { + if babble, exists := bs.GetByID(id); exists { + result = append(result, babble) + } } - - return babbles, nil + return result, nil } // Retrieves babble messages by a specific author func ByAuthor(author string) ([]*Babble, error) { - var babbles []*Babble + bs := GetStore() + bs.mu.RLock() + defer bs.mu.RUnlock() - query := `SELECT ` + babbleColumns() + ` FROM babble WHERE LOWER(author) = LOWER(?) ORDER BY posted DESC, id DESC` - - err := database.Query(query, func(stmt *sqlite.Stmt) error { - babble := scanBabble(stmt) - babbles = append(babbles, babble) - return nil - }, author) - - if err != nil { - return nil, fmt.Errorf("failed to retrieve babble by author: %w", err) + ids, exists := bs.byAuthor[strings.ToLower(author)] + if !exists { + return []*Babble{}, nil } - return babbles, nil + result := make([]*Babble, 0, len(ids)) + for _, id := range ids { + if babble, exists := bs.GetByID(id); exists { + result = append(result, babble) + } + } + return result, nil } // Retrieves the most recent babble messages (limited by count) func Recent(limit int) ([]*Babble, error) { - var babbles []*Babble + bs := GetStore() + bs.mu.RLock() + defer bs.mu.RUnlock() - query := `SELECT ` + babbleColumns() + ` FROM babble ORDER BY posted DESC, id DESC LIMIT ?` - - err := database.Query(query, func(stmt *sqlite.Stmt) error { - babble := scanBabble(stmt) - babbles = append(babbles, babble) - return nil - }, limit) - - if err != nil { - return nil, fmt.Errorf("failed to retrieve recent babble: %w", err) + if limit > len(bs.allByPosted) { + limit = len(bs.allByPosted) } - return babbles, nil + result := make([]*Babble, 0, limit) + for i := 0; i < limit; i++ { + if babble, exists := bs.GetByID(bs.allByPosted[i]); exists { + result = append(result, babble) + } + } + return result, nil } // Retrieves babble messages since a specific timestamp func Since(since int64) ([]*Babble, error) { - var babbles []*Babble + bs := GetStore() + bs.mu.RLock() + defer bs.mu.RUnlock() - query := `SELECT ` + babbleColumns() + ` FROM babble WHERE posted >= ? ORDER BY posted DESC, id DESC` - - err := database.Query(query, func(stmt *sqlite.Stmt) error { - babble := scanBabble(stmt) - babbles = append(babbles, babble) - return nil - }, since) - - if err != nil { - return nil, fmt.Errorf("failed to retrieve babble since timestamp: %w", err) + var result []*Babble + for _, id := range bs.allByPosted { + if babble, exists := bs.GetByID(id); exists && babble.Posted >= since { + result = append(result, babble) + } } - - return babbles, nil + return result, nil } // Retrieves babble messages between two timestamps (inclusive) func Between(start, end int64) ([]*Babble, error) { - var babbles []*Babble + bs := GetStore() + bs.mu.RLock() + defer bs.mu.RUnlock() - query := `SELECT ` + babbleColumns() + ` FROM babble WHERE posted >= ? AND posted <= ? ORDER BY posted DESC, id DESC` - - err := database.Query(query, func(stmt *sqlite.Stmt) error { - babble := scanBabble(stmt) - babbles = append(babbles, babble) - return nil - }, start, end) - - if err != nil { - return nil, fmt.Errorf("failed to retrieve babble between timestamps: %w", err) + var result []*Babble + for _, id := range bs.allByPosted { + if babble, exists := bs.GetByID(id); exists && babble.Posted >= start && babble.Posted <= end { + result = append(result, babble) + } } - - return babbles, nil + return result, nil } // Retrieves babble messages containing the search term (case-insensitive) func Search(term string) ([]*Babble, error) { - var babbles []*Babble + bs := GetStore() + bs.mu.RLock() + defer bs.mu.RUnlock() - query := `SELECT ` + babbleColumns() + ` FROM babble WHERE LOWER(babble) LIKE LOWER(?) ORDER BY posted DESC, id DESC` - searchTerm := "%" + term + "%" + var result []*Babble + lowerTerm := strings.ToLower(term) - err := database.Query(query, func(stmt *sqlite.Stmt) error { - babble := scanBabble(stmt) - babbles = append(babbles, babble) - return nil - }, searchTerm) - - if err != nil { - return nil, fmt.Errorf("failed to search babble: %w", err) + for _, id := range bs.allByPosted { + if babble, exists := bs.GetByID(id); exists { + if strings.Contains(strings.ToLower(babble.Babble), lowerTerm) { + result = append(result, babble) + } + } } - - return babbles, nil + return result, nil } // Retrieves recent messages from a specific author func RecentByAuthor(author string, limit int) ([]*Babble, error) { - var babbles []*Babble + bs := GetStore() + bs.mu.RLock() + defer bs.mu.RUnlock() - query := `SELECT ` + babbleColumns() + ` FROM babble WHERE LOWER(author) = LOWER(?) ORDER BY posted DESC, id DESC LIMIT ?` - - err := database.Query(query, func(stmt *sqlite.Stmt) error { - babble := scanBabble(stmt) - babbles = append(babbles, babble) - return nil - }, author, limit) - - if err != nil { - return nil, fmt.Errorf("failed to retrieve recent babble by author: %w", err) + ids, exists := bs.byAuthor[strings.ToLower(author)] + if !exists { + return []*Babble{}, nil } - return babbles, nil + if limit > len(ids) { + limit = len(ids) + } + + result := make([]*Babble, 0, limit) + for i := 0; i < limit; i++ { + if babble, exists := bs.GetByID(ids[i]); exists { + result = append(result, babble) + } + } + return result, nil } -// Saves a new babble to the database and sets the ID +// Saves a new babble to the in-memory store and sets the ID func (b *Babble) Insert() error { - columns := `posted, author, babble` - values := []any{b.Posted, b.Author, b.Babble} - return database.Insert(b, columns, values...) + bs := GetStore() + + // Validate before insertion + if err := b.Validate(); err != nil { + return fmt.Errorf("validation failed: %w", err) + } + + // Assign new ID if not set + if b.ID == 0 { + b.ID = bs.GetNextID() + } + + // Add to store + bs.AddBabble(b) + return nil } // Returns the posted timestamp as a time.Time @@ -238,7 +356,7 @@ func (b *Babble) PostedTime() time.Time { // Sets the posted timestamp from a time.Time func (b *Babble) SetPostedTime(t time.Time) { - b.Set("Posted", t.Unix()) + b.Posted = t.Unix() } // Returns true if the babble message was posted within the last hour diff --git a/internal/control/control.go b/internal/control/control.go index dfd1272..d2cc27a 100644 --- a/internal/control/control.go +++ b/internal/control/control.go @@ -1,49 +1,32 @@ package control import ( + "dk/internal/store" "fmt" - - "dk/internal/database" - "dk/internal/helpers/scanner" - - "zombiezen.com/go/sqlite" + "sync" ) -// Control represents the game control settings in the database +// Control represents the game control settings type Control struct { - database.BaseModel - - ID int `db:"id" json:"id"` - WorldSize int `db:"world_size" json:"world_size"` - Open int `db:"open" json:"open"` - AdminEmail string `db:"admin_email" json:"admin_email"` - Class1Name string `db:"class_1_name" json:"class_1_name"` - Class2Name string `db:"class_2_name" json:"class_2_name"` - Class3Name string `db:"class_3_name" json:"class_3_name"` -} - -func (c *Control) GetTableName() string { - return "control" -} - -func (c *Control) GetID() int { - return c.ID -} - -func (c *Control) SetID(id int) { - c.ID = id -} - -func (c *Control) Set(field string, value any) error { - return database.Set(c, field, value) + ID int `json:"id"` + WorldSize int `json:"world_size"` + Open int `json:"open"` + AdminEmail string `json:"admin_email"` + Class1Name string `json:"class_1_name"` + Class2Name string `json:"class_2_name"` + Class3Name string `json:"class_3_name"` } func (c *Control) Save() error { - return database.Save(c) + controlStore := GetStore() + controlStore.UpdateControl(c) + return nil } func (c *Control) Delete() error { - return database.Delete(c) + controlStore := GetStore() + controlStore.RemoveControl(c.ID) + return nil } // Creates a new Control with sensible defaults @@ -58,39 +41,144 @@ func New() *Control { } } -var controlScanner = scanner.New[Control]() - -// Returns the column list for control queries -func controlColumns() string { - return controlScanner.Columns() +// Validate checks if control has valid values +func (c *Control) Validate() error { + if c.WorldSize <= 0 || c.WorldSize > 10000 { + return fmt.Errorf("control WorldSize must be between 1 and 10000") + } + if c.Open != 0 && c.Open != 1 { + return fmt.Errorf("control Open must be 0 or 1") + } + return nil } -// Populates a Control struct using the fast scanner -func scanControl(stmt *sqlite.Stmt) *Control { - control := &Control{} - controlScanner.Scan(stmt, control) - return control +// ControlStore provides in-memory storage for control settings +type ControlStore struct { + *store.BaseStore[Control] // Embedded generic store + allByID []int // All IDs sorted by ID + mu sync.RWMutex // Protects indices +} + +// Global in-memory store +var controlStore *ControlStore +var storeOnce sync.Once + +// Initialize the in-memory store +func initStore() { + controlStore = &ControlStore{ + BaseStore: store.NewBaseStore[Control](), + allByID: make([]int, 0), + } +} + +// GetStore returns the global control store +func GetStore() *ControlStore { + storeOnce.Do(initStore) + return controlStore +} + +// AddControl adds a control to the in-memory store and updates all indices +func (cs *ControlStore) AddControl(control *Control) { + cs.mu.Lock() + defer cs.mu.Unlock() + + // Validate control + if err := control.Validate(); err != nil { + return + } + + // Add to base store + cs.Add(control.ID, control) + + // Rebuild indices + cs.rebuildIndicesUnsafe() +} + +// RemoveControl removes a control from the store and updates indices +func (cs *ControlStore) RemoveControl(id int) { + cs.mu.Lock() + defer cs.mu.Unlock() + + // Remove from base store + cs.Remove(id) + + // Rebuild indices + cs.rebuildIndicesUnsafe() +} + +// UpdateControl updates a control efficiently +func (cs *ControlStore) UpdateControl(control *Control) { + cs.mu.Lock() + defer cs.mu.Unlock() + + // Validate control + if err := control.Validate(); err != nil { + return + } + + // Update base store + cs.Add(control.ID, control) + + // Rebuild indices + cs.rebuildIndicesUnsafe() +} + +// LoadData loads control data from JSON file, or starts with empty store +func LoadData(dataPath string) error { + cs := GetStore() + + // Load from base store, which handles JSON loading + if err := cs.BaseStore.LoadData(dataPath); err != nil { + return err + } + + // Rebuild indices from loaded data + cs.rebuildIndices() + return nil +} + +// SaveData saves control data to JSON file +func SaveData(dataPath string) error { + cs := GetStore() + return cs.BaseStore.SaveData(dataPath) +} + +// rebuildIndicesUnsafe rebuilds all indices from base store data (caller must hold lock) +func (cs *ControlStore) rebuildIndicesUnsafe() { + // Clear indices + cs.allByID = make([]int, 0) + + // Collect all controls + allControls := cs.GetAll() + + for id := range allControls { + cs.allByID = append(cs.allByID, id) + } + + // Sort by ID (though typically only one control record exists) + for i := 0; i < len(cs.allByID); i++ { + for j := i + 1; j < len(cs.allByID); j++ { + if cs.allByID[i] > cs.allByID[j] { + cs.allByID[i], cs.allByID[j] = cs.allByID[j], cs.allByID[i] + } + } + } +} + +// rebuildIndices rebuilds all control-specific indices from base store data +func (cs *ControlStore) rebuildIndices() { + cs.mu.Lock() + defer cs.mu.Unlock() + cs.rebuildIndicesUnsafe() } // Retrieves the control record by ID (typically only ID 1 exists) func Find(id int) (*Control, error) { - var control *Control - - query := `SELECT ` + controlColumns() + ` FROM control WHERE id = ?` - - err := database.Query(query, func(stmt *sqlite.Stmt) error { - control = scanControl(stmt) - return nil - }, id) - - if err != nil { - return nil, fmt.Errorf("failed to find control: %w", err) - } - - if control == nil { + cs := GetStore() + control, exists := cs.GetByID(id) + if !exists { return nil, fmt.Errorf("control with ID %d not found", id) } - return control, nil } @@ -99,11 +187,23 @@ func Get() (*Control, error) { return Find(1) } -// Saves a new control to the database and sets the ID +// Saves a new control to the in-memory store and sets the ID func (c *Control) Insert() error { - columns := `world_size, open, admin_email, class_1_name, class_2_name, class_3_name` - values := []any{c.WorldSize, c.Open, c.AdminEmail, c.Class1Name, c.Class2Name, c.Class3Name} - return database.Insert(c, columns, values...) + cs := GetStore() + + // Validate before insertion + if err := c.Validate(); err != nil { + return fmt.Errorf("validation failed: %w", err) + } + + // Assign new ID if not set + if c.ID == 0 { + c.ID = cs.GetNextID() + } + + // Add to store + cs.AddControl(c) + return nil } // Returns true if the game world is open for new players @@ -114,20 +214,20 @@ func (c *Control) IsOpen() bool { // Sets whether the game world is open for new players func (c *Control) SetOpen(open bool) { if open { - c.Set("Open", 1) + c.Open = 1 } else { - c.Set("Open", 0) + c.Open = 0 } } // Closes the game world to new players func (c *Control) Close() { - c.Set("Open", 0) + c.Open = 0 } // Opens the game world to new players func (c *Control) OpenWorld() { - c.Set("Open", 1) + c.Open = 1 } // Returns all class names as a slice @@ -148,19 +248,19 @@ func (c *Control) GetClassNames() []string { // Sets all class names from a slice func (c *Control) SetClassNames(classes []string) { // Reset all class names - c.Set("Class1Name", "") - c.Set("Class2Name", "") - c.Set("Class3Name", "") + c.Class1Name = "" + c.Class2Name = "" + c.Class3Name = "" // Set provided class names if len(classes) > 0 { - c.Set("Class1Name", classes[0]) + c.Class1Name = classes[0] } if len(classes) > 1 { - c.Set("Class2Name", classes[1]) + c.Class2Name = classes[1] } if len(classes) > 2 { - c.Set("Class3Name", classes[2]) + c.Class3Name = classes[2] } } @@ -182,13 +282,13 @@ func (c *Control) GetClassName(classNum int) string { func (c *Control) SetClassName(classNum int, name string) bool { switch classNum { case 1: - c.Set("Class1Name", name) + c.Class1Name = name return true case 2: - c.Set("Class2Name", name) + c.Class2Name = name return true case 3: - c.Set("Class3Name", name) + c.Class3Name = name return true default: return false diff --git a/internal/database/database.go b/internal/database/database.go deleted file mode 100644 index f973342..0000000 --- a/internal/database/database.go +++ /dev/null @@ -1,201 +0,0 @@ -package database - -import ( - "context" - "fmt" - "runtime" - - "zombiezen.com/go/sqlite" - "zombiezen.com/go/sqlite/sqlitex" -) - -const DefaultPath = "dk.db" - -// Global singleton instance -var pool *sqlitex.Pool - -// Init initializes the global database connection pool -func Init(path string) error { - if path == "" { - path = DefaultPath - } - - poolSize := max(runtime.GOMAXPROCS(0), 2) - - var err error - pool, err = sqlitex.NewPool(path, sqlitex.PoolOptions{ - PoolSize: poolSize, - Flags: sqlite.OpenCreate | sqlite.OpenReadWrite | sqlite.OpenWAL, - }) - if err != nil { - return fmt.Errorf("failed to open database pool: %w", err) - } - - conn, err := pool.Take(context.Background()) - if err != nil { - pool.Close() - return fmt.Errorf("failed to get connection from pool: %w", err) - } - defer pool.Put(conn) - - if err := sqlitex.ExecuteTransient(conn, "PRAGMA journal_mode = WAL", nil); err != nil { - pool.Close() - return fmt.Errorf("failed to set WAL mode: %w", err) - } - - if err := sqlitex.ExecuteTransient(conn, "PRAGMA synchronous = NORMAL", nil); err != nil { - pool.Close() - return fmt.Errorf("failed to set synchronous mode: %w", err) - } - - return nil -} - -// Close closes the global database connection pool -func Close() error { - if pool == nil { - return nil - } - return pool.Close() -} - -// GetConn gets a connection from the pool - caller must call PutConn when done -func GetConn(ctx context.Context) (*sqlite.Conn, error) { - if pool == nil { - return nil, fmt.Errorf("database not initialized") - } - return pool.Take(ctx) -} - -// PutConn returns a connection to the pool -func PutConn(conn *sqlite.Conn) { - if pool != nil { - pool.Put(conn) - } -} - -// Exec executes a SQL statement without returning results -func Exec(query string, args ...any) error { - if pool == nil { - return fmt.Errorf("database not initialized") - } - - conn, err := pool.Take(context.Background()) - if err != nil { - return fmt.Errorf("failed to get connection from pool: %w", err) - } - defer pool.Put(conn) - - if len(args) == 0 { - return sqlitex.ExecuteTransient(conn, query, nil) - } - - return sqlitex.ExecuteTransient(conn, query, &sqlitex.ExecOptions{ - Args: args, - }) -} - -// Query executes a SQL query and calls fn for each row -func Query(query string, fn func(*sqlite.Stmt) error, args ...any) error { - if pool == nil { - return fmt.Errorf("database not initialized") - } - - conn, err := pool.Take(context.Background()) - if err != nil { - return fmt.Errorf("failed to get connection from pool: %w", err) - } - defer pool.Put(conn) - - if len(args) == 0 { - return sqlitex.ExecuteTransient(conn, query, &sqlitex.ExecOptions{ - ResultFunc: fn, - }) - } - - return sqlitex.ExecuteTransient(conn, query, &sqlitex.ExecOptions{ - Args: args, - ResultFunc: fn, - }) -} - -// Begin starts a new transaction -func Begin() (*Tx, error) { - if pool == nil { - return nil, fmt.Errorf("database not initialized") - } - - conn, err := pool.Take(context.Background()) - if err != nil { - return nil, fmt.Errorf("failed to get connection from pool: %w", err) - } - - if err := sqlitex.ExecuteTransient(conn, "BEGIN", nil); err != nil { - pool.Put(conn) - return nil, fmt.Errorf("failed to begin transaction: %w", err) - } - - return &Tx{conn: conn, pool: pool}, nil -} - -// Transaction runs a function within a transaction -func Transaction(fn func(*Tx) error) error { - if pool == nil { - return fmt.Errorf("database not initialized") - } - - tx, err := Begin() - if err != nil { - return err - } - - if err := fn(tx); err != nil { - tx.Rollback() - return err - } - - return tx.Commit() -} - -// Tx represents a database transaction -type Tx struct { - conn *sqlite.Conn - pool *sqlitex.Pool -} - -// Exec executes a SQL statement within the transaction -func (tx *Tx) Exec(query string, args ...any) error { - if len(args) == 0 { - return sqlitex.ExecuteTransient(tx.conn, query, nil) - } - - return sqlitex.ExecuteTransient(tx.conn, query, &sqlitex.ExecOptions{ - Args: args, - }) -} - -// Query executes a SQL query within the transaction -func (tx *Tx) Query(query string, fn func(*sqlite.Stmt) error, args ...any) error { - if len(args) == 0 { - return sqlitex.ExecuteTransient(tx.conn, query, &sqlitex.ExecOptions{ - ResultFunc: fn, - }) - } - - return sqlitex.ExecuteTransient(tx.conn, query, &sqlitex.ExecOptions{ - Args: args, - ResultFunc: fn, - }) -} - -// Commit commits the transaction -func (tx *Tx) Commit() error { - defer tx.pool.Put(tx.conn) - return sqlitex.ExecuteTransient(tx.conn, "COMMIT", nil) -} - -// Rollback rolls back the transaction -func (tx *Tx) Rollback() error { - defer tx.pool.Put(tx.conn) - return sqlitex.ExecuteTransient(tx.conn, "ROLLBACK", nil) -} diff --git a/internal/database/database_test.go b/internal/database/database_test.go deleted file mode 100644 index 969d059..0000000 --- a/internal/database/database_test.go +++ /dev/null @@ -1,72 +0,0 @@ -package database - -import ( - "os" - "testing" - - "zombiezen.com/go/sqlite" -) - -func TestDatabaseOperations(t *testing.T) { - // Use a temporary database file - testDB := "test.db" - defer os.Remove(testDB) - - // Initialize the singleton database - err := Init(testDB) - if err != nil { - t.Fatalf("Failed to initialize database: %v", err) - } - defer Close() - - // Test creating a simple table - err = Exec("CREATE TABLE test_users (id INTEGER PRIMARY KEY, name TEXT)") - if err != nil { - t.Fatalf("Failed to create table: %v", err) - } - - // Test inserting data - err = Exec("INSERT INTO test_users (name) VALUES (?)", "Alice") - if err != nil { - t.Fatalf("Failed to insert data: %v", err) - } - - // Test querying data - var foundName string - err = Query("SELECT name FROM test_users WHERE name = ?", func(stmt *sqlite.Stmt) error { - foundName = stmt.ColumnText(0) - return nil - }, "Alice") - - if err != nil { - t.Fatalf("Failed to query data: %v", err) - } - - if foundName != "Alice" { - t.Errorf("Expected 'Alice', got '%s'", foundName) - } - - // Test transaction - err = Transaction(func(tx *Tx) error { - return tx.Exec("INSERT INTO test_users (name) VALUES (?)", "Bob") - }) - - if err != nil { - t.Fatalf("Transaction failed: %v", err) - } - - // Verify transaction worked - var count int - err = Query("SELECT COUNT(*) FROM test_users", func(stmt *sqlite.Stmt) error { - count = stmt.ColumnInt(0) - return nil - }) - - if err != nil { - t.Fatalf("Failed to count users: %v", err) - } - - if count != 2 { - t.Errorf("Expected 2 users, got %d", count) - } -} diff --git a/internal/database/model.go b/internal/database/model.go deleted file mode 100644 index c1c483d..0000000 --- a/internal/database/model.go +++ /dev/null @@ -1,129 +0,0 @@ -package database - -import ( - "fmt" - "reflect" - "strings" - - "zombiezen.com/go/sqlite" -) - -// Model interface for trackable database models -type Model interface { - GetTableName() string - GetID() int - SetID(id int) - GetDirtyFields() map[string]any - SetDirty(field string, value any) - ClearDirty() - IsDirty() bool -} - -// BaseModel provides common model functionality -type BaseModel struct { - FieldTracker -} - -// Set uses reflection to set a field and track changes -func Set(model Model, field string, value any) error { - v := reflect.ValueOf(model).Elem() - t := v.Type() - - fieldVal := v.FieldByName(field) - if !fieldVal.IsValid() { - return fmt.Errorf("field %s does not exist", field) - } - if !fieldVal.CanSet() { - return fmt.Errorf("field %s cannot be set", field) - } - - // Get current value for comparison - currentVal := fieldVal.Interface() - - // Only set if value has changed - if !reflect.DeepEqual(currentVal, value) { - newVal := reflect.ValueOf(value) - if newVal.Type().ConvertibleTo(fieldVal.Type()) { - fieldVal.Set(newVal.Convert(fieldVal.Type())) - - // Get db column name from struct tag - structField, _ := t.FieldByName(field) - dbField := structField.Tag.Get("db") - if dbField == "" { - dbField = toSnakeCase(field) // fallback - } - - model.SetDirty(dbField, value) - } else { - return fmt.Errorf("cannot convert %T to %s", value, fieldVal.Type()) - } - } - - return nil -} - -// toSnakeCase converts CamelCase to snake_case -func toSnakeCase(s string) string { - var result strings.Builder - for i, r := range s { - if i > 0 && r >= 'A' && r <= 'Z' { - prev := rune(s[i-1]) - if prev < 'A' || prev > 'Z' { - result.WriteByte('_') - } - } - if r >= 'A' && r <= 'Z' { - result.WriteRune(r - 'A' + 'a') - } else { - result.WriteRune(r) - } - } - return result.String() -} - -// Save updates only dirty fields -func Save(model Model) error { - if model.GetID() == 0 { - return fmt.Errorf("cannot save model without ID") - } - return UpdateDirty(model) -} - -// Insert creates a new record and sets the ID -func Insert(model Model, columns string, values ...any) error { - if model.GetID() != 0 { - return fmt.Errorf("model already has ID %d, use Save() to update", model.GetID()) - } - - return Transaction(func(tx *Tx) error { - placeholders := strings.Repeat("?,", len(values)) - placeholders = placeholders[:len(placeholders)-1] // Remove trailing comma - - query := fmt.Sprintf("INSERT INTO %s (%s) VALUES (%s)", - model.GetTableName(), columns, placeholders) - - if err := tx.Exec(query, values...); err != nil { - return fmt.Errorf("failed to insert: %w", err) - } - - var id int - err := tx.Query("SELECT last_insert_rowid()", func(stmt *sqlite.Stmt) error { - id = stmt.ColumnInt(0) - return nil - }) - if err != nil { - return fmt.Errorf("failed to get insert ID: %w", err) - } - - model.SetID(id) - return nil - }) -} - -// Delete removes the record -func Delete(model Model) error { - if model.GetID() == 0 { - return fmt.Errorf("cannot delete model without ID") - } - return Exec("DELETE FROM ? WHERE id = ?", model.GetTableName(), model.GetID()) -} diff --git a/internal/database/trackable.go b/internal/database/trackable.go deleted file mode 100644 index 9c2f2e9..0000000 --- a/internal/database/trackable.go +++ /dev/null @@ -1,82 +0,0 @@ -package database - -import ( - "fmt" - "strings" -) - -// Trackable interface for models that can track field changes -type Trackable interface { - GetTableName() string - GetID() int - GetDirtyFields() map[string]any - SetDirty(field string, value any) - ClearDirty() - IsDirty() bool -} - -// FieldTracker provides dirty field tracking functionality -type FieldTracker struct { - dirty map[string]any -} - -// SetDirty marks a field as dirty with its new value -func (ft *FieldTracker) SetDirty(field string, value any) { - if ft.dirty == nil { - ft.dirty = make(map[string]any) - } - ft.dirty[field] = value -} - -// GetDirtyFields returns map of dirty fields and their values -func (ft *FieldTracker) GetDirtyFields() map[string]any { - if ft.dirty == nil { - return make(map[string]any) - } - return ft.dirty -} - -// ClearDirty clears all dirty field tracking -func (ft *FieldTracker) ClearDirty() { - ft.dirty = nil -} - -// IsDirty returns true if any fields have been modified -func (ft *FieldTracker) IsDirty() bool { - return len(ft.dirty) > 0 -} - -// UpdateDirty updates only dirty fields in the database -func UpdateDirty(model Trackable) error { - if !model.IsDirty() { - return nil // No changes to save - } - - dirty := model.GetDirtyFields() - if len(dirty) == 0 { - return nil - } - - // Build dynamic UPDATE query - var setParts []string - var args []any - - for field, value := range dirty { - setParts = append(setParts, field+" = ?") - args = append(args, value) - } - - args = append(args, model.GetID()) // Add ID for WHERE clause - - query := fmt.Sprintf("UPDATE %s SET %s WHERE id = ?", - model.GetTableName(), - strings.Join(setParts, ", ")) - - err := Exec(query, args...) - if err != nil { - return fmt.Errorf("failed to update %s: %w", model.GetTableName(), err) - } - - model.ClearDirty() - return nil -} diff --git a/internal/drops/drops.go b/internal/drops/drops.go index 8dc13b3..085ade3 100644 --- a/internal/drops/drops.go +++ b/internal/drops/drops.go @@ -1,47 +1,31 @@ package drops import ( + "dk/internal/store" "fmt" - - "dk/internal/database" - "dk/internal/helpers/scanner" - - "zombiezen.com/go/sqlite" + "sort" + "sync" ) -// Drop represents a drop item in the database +// Drop represents a drop item in the game type Drop struct { - database.BaseModel - - ID int `db:"id" json:"id"` - Name string `db:"name" json:"name"` - Level int `db:"level" json:"level"` - Type int `db:"type" json:"type"` - Att string `db:"att" json:"att"` -} - -func (d *Drop) GetTableName() string { - return "drops" -} - -func (d *Drop) GetID() int { - return d.ID -} - -func (d *Drop) SetID(id int) { - d.ID = id -} - -func (d *Drop) Set(field string, value any) error { - return database.Set(d, field, value) + ID int `json:"id"` + Name string `json:"name"` + Level int `json:"level"` + Type int `json:"type"` + Att string `json:"att"` } func (d *Drop) Save() error { - return database.Save(d) + dropStore := GetStore() + dropStore.UpdateDrop(d) + return nil } func (d *Drop) Delete() error { - return database.Delete(d) + dropStore := GetStore() + dropStore.RemoveDrop(d.ID) + return nil } // Creates a new Drop with sensible defaults @@ -54,18 +38,18 @@ func New() *Drop { } } -var dropScanner = scanner.New[Drop]() - -// Returns the column list for drop queries -func dropColumns() string { - return dropScanner.Columns() -} - -// Populates a Drop struct using the fast scanner -func scanDrop(stmt *sqlite.Stmt) *Drop { - drop := &Drop{} - dropScanner.Scan(stmt, drop) - return drop +// Validate checks if drop has valid values +func (d *Drop) Validate() error { + if d.Name == "" { + return fmt.Errorf("drop name cannot be empty") + } + if d.Level < 1 { + return fmt.Errorf("drop Level must be at least 1") + } + if d.Type < TypeConsumable { + return fmt.Errorf("invalid drop type: %d", d.Type) + } + return nil } // DropType constants for drop types @@ -73,90 +57,231 @@ const ( TypeConsumable = 1 ) +// DropStore provides in-memory storage with O(1) lookups and drop-specific indices +type DropStore struct { + *store.BaseStore[Drop] // Embedded generic store + byLevel map[int][]int // Level -> []ID + byType map[int][]int // Type -> []ID + allByID []int // All IDs sorted by ID + mu sync.RWMutex // Protects indices +} + +// Global in-memory store +var dropStore *DropStore +var storeOnce sync.Once + +// Initialize the in-memory store +func initStore() { + dropStore = &DropStore{ + BaseStore: store.NewBaseStore[Drop](), + byLevel: make(map[int][]int), + byType: make(map[int][]int), + allByID: make([]int, 0), + } +} + +// GetStore returns the global drop store +func GetStore() *DropStore { + storeOnce.Do(initStore) + return dropStore +} + +// AddDrop adds a drop to the in-memory store and updates all indices +func (ds *DropStore) AddDrop(drop *Drop) { + ds.mu.Lock() + defer ds.mu.Unlock() + + // Validate drop + if err := drop.Validate(); err != nil { + return + } + + // Add to base store + ds.Add(drop.ID, drop) + + // Rebuild indices + ds.rebuildIndicesUnsafe() +} + +// RemoveDrop removes a drop from the store and updates indices +func (ds *DropStore) RemoveDrop(id int) { + ds.mu.Lock() + defer ds.mu.Unlock() + + // Remove from base store + ds.Remove(id) + + // Rebuild indices + ds.rebuildIndicesUnsafe() +} + +// UpdateDrop updates a drop efficiently +func (ds *DropStore) UpdateDrop(drop *Drop) { + ds.mu.Lock() + defer ds.mu.Unlock() + + // Validate drop + if err := drop.Validate(); err != nil { + return + } + + // Update base store + ds.Add(drop.ID, drop) + + // Rebuild indices + ds.rebuildIndicesUnsafe() +} + +// LoadData loads drop data from JSON file, or starts with empty store +func LoadData(dataPath string) error { + ds := GetStore() + + // Load from base store, which handles JSON loading + if err := ds.BaseStore.LoadData(dataPath); err != nil { + return err + } + + // Rebuild indices from loaded data + ds.rebuildIndices() + return nil +} + +// SaveData saves drop data to JSON file +func SaveData(dataPath string) error { + ds := GetStore() + return ds.BaseStore.SaveData(dataPath) +} + +// rebuildIndicesUnsafe rebuilds all indices from base store data (caller must hold lock) +func (ds *DropStore) rebuildIndicesUnsafe() { + // Clear indices + ds.byLevel = make(map[int][]int) + ds.byType = make(map[int][]int) + ds.allByID = make([]int, 0) + + // Collect all drops and build indices + allDrops := ds.GetAll() + + for id, drop := range allDrops { + // Level index + ds.byLevel[drop.Level] = append(ds.byLevel[drop.Level], id) + + // Type index + ds.byType[drop.Type] = append(ds.byType[drop.Type], id) + + // All IDs + ds.allByID = append(ds.allByID, id) + } + + // Sort allByID by ID + sort.Ints(ds.allByID) + + // Sort level indices by ID + for level := range ds.byLevel { + sort.Ints(ds.byLevel[level]) + } + + // Sort type indices by level, then ID + for dropType := range ds.byType { + sort.Slice(ds.byType[dropType], func(i, j int) bool { + dropI, _ := ds.GetByID(ds.byType[dropType][i]) + dropJ, _ := ds.GetByID(ds.byType[dropType][j]) + if dropI.Level != dropJ.Level { + return dropI.Level < dropJ.Level + } + return ds.byType[dropType][i] < ds.byType[dropType][j] + }) + } +} + +// rebuildIndices rebuilds all drop-specific indices from base store data +func (ds *DropStore) rebuildIndices() { + ds.mu.Lock() + defer ds.mu.Unlock() + ds.rebuildIndicesUnsafe() +} + // Retrieves a drop by ID func Find(id int) (*Drop, error) { - var drop *Drop - - query := `SELECT ` + dropColumns() + ` FROM drops WHERE id = ?` - - err := database.Query(query, func(stmt *sqlite.Stmt) error { - drop = scanDrop(stmt) - return nil - }, id) - - if err != nil { - return nil, fmt.Errorf("failed to find drop: %w", err) - } - - if drop == nil { + ds := GetStore() + drop, exists := ds.GetByID(id) + if !exists { return nil, fmt.Errorf("drop with ID %d not found", id) } - return drop, nil } // Retrieves all drops func All() ([]*Drop, error) { - var drops []*Drop + ds := GetStore() + ds.mu.RLock() + defer ds.mu.RUnlock() - query := `SELECT ` + dropColumns() + ` FROM drops ORDER BY id` - - err := database.Query(query, func(stmt *sqlite.Stmt) error { - drop := scanDrop(stmt) - drops = append(drops, drop) - return nil - }) - - if err != nil { - return nil, fmt.Errorf("failed to retrieve all drops: %w", err) + result := make([]*Drop, 0, len(ds.allByID)) + for _, id := range ds.allByID { + if drop, exists := ds.GetByID(id); exists { + result = append(result, drop) + } } - - return drops, nil + return result, nil } // Retrieves drops by minimum level requirement func ByLevel(minLevel int) ([]*Drop, error) { - var drops []*Drop + ds := GetStore() + ds.mu.RLock() + defer ds.mu.RUnlock() - query := `SELECT ` + dropColumns() + ` FROM drops WHERE level <= ? ORDER BY level, id` - - err := database.Query(query, func(stmt *sqlite.Stmt) error { - drop := scanDrop(stmt) - drops = append(drops, drop) - return nil - }, minLevel) - - if err != nil { - return nil, fmt.Errorf("failed to retrieve drops by level: %w", err) + var result []*Drop + for level := 1; level <= minLevel; level++ { + if ids, exists := ds.byLevel[level]; exists { + for _, id := range ids { + if drop, exists := ds.GetByID(id); exists { + result = append(result, drop) + } + } + } } - - return drops, nil + return result, nil } // Retrieves drops by type func ByType(dropType int) ([]*Drop, error) { - var drops []*Drop + ds := GetStore() + ds.mu.RLock() + defer ds.mu.RUnlock() - query := `SELECT ` + dropColumns() + ` FROM drops WHERE type = ? ORDER BY level, id` - - err := database.Query(query, func(stmt *sqlite.Stmt) error { - drop := scanDrop(stmt) - drops = append(drops, drop) - return nil - }, dropType) - - if err != nil { - return nil, fmt.Errorf("failed to retrieve drops by type: %w", err) + ids, exists := ds.byType[dropType] + if !exists { + return []*Drop{}, nil } - return drops, nil + result := make([]*Drop, 0, len(ids)) + for _, id := range ids { + if drop, exists := ds.GetByID(id); exists { + result = append(result, drop) + } + } + return result, nil } -// Saves a new drop to the database and sets the ID +// Saves a new drop to the in-memory store and sets the ID func (d *Drop) Insert() error { - columns := `name, level, type, att` - values := []any{d.Name, d.Level, d.Type, d.Att} - return database.Insert(d, columns, values...) + ds := GetStore() + + // Validate before insertion + if err := d.Validate(); err != nil { + return fmt.Errorf("validation failed: %w", err) + } + + // Assign new ID if not set + if d.ID == 0 { + d.ID = ds.GetNextID() + } + + // Add to store + ds.AddDrop(d) + return nil } // Returns true if the drop is a consumable item diff --git a/internal/forum/forum.go b/internal/forum/forum.go index b5d9713..3d58262 100644 --- a/internal/forum/forum.go +++ b/internal/forum/forum.go @@ -1,52 +1,36 @@ package forum import ( + "dk/internal/store" "fmt" + "sort" "strings" + "sync" "time" - - "dk/internal/database" - "dk/internal/helpers/scanner" - - "zombiezen.com/go/sqlite" ) -// Forum represents a forum post or thread in the database +// Forum represents a forum post or thread in the game type Forum struct { - database.BaseModel - - ID int `db:"id" json:"id"` - Posted int64 `db:"posted" json:"posted"` - LastPost int64 `db:"last_post" json:"last_post"` - Author int `db:"author" json:"author"` - Parent int `db:"parent" json:"parent"` - Replies int `db:"replies" json:"replies"` - Title string `db:"title" json:"title"` - Content string `db:"content" json:"content"` -} - -func (f *Forum) GetTableName() string { - return "forum" -} - -func (f *Forum) GetID() int { - return f.ID -} - -func (f *Forum) SetID(id int) { - f.ID = id -} - -func (f *Forum) Set(field string, value any) error { - return database.Set(f, field, value) + ID int `json:"id"` + Posted int64 `json:"posted"` + LastPost int64 `json:"last_post"` + Author int `json:"author"` + Parent int `json:"parent"` + Replies int `json:"replies"` + Title string `json:"title"` + Content string `json:"content"` } func (f *Forum) Save() error { - return database.Save(f) + forumStore := GetStore() + forumStore.UpdateForum(f) + return nil } func (f *Forum) Delete() error { - return database.Delete(f) + forumStore := GetStore() + forumStore.RemoveForum(f.ID) + return nil } // Creates a new Forum with sensible defaults @@ -63,181 +47,358 @@ func New() *Forum { } } -var forumScanner = scanner.New[Forum]() - -// Returns the column list for forum queries -func forumColumns() string { - return forumScanner.Columns() +// Validate checks if forum has valid values +func (f *Forum) Validate() error { + if strings.TrimSpace(f.Title) == "" { + return fmt.Errorf("forum title cannot be empty") + } + if strings.TrimSpace(f.Content) == "" { + return fmt.Errorf("forum content cannot be empty") + } + if f.Posted <= 0 { + return fmt.Errorf("forum Posted timestamp must be positive") + } + if f.LastPost <= 0 { + return fmt.Errorf("forum LastPost timestamp must be positive") + } + if f.Parent < 0 { + return fmt.Errorf("forum Parent cannot be negative") + } + if f.Replies < 0 { + return fmt.Errorf("forum Replies cannot be negative") + } + return nil } -// Populates a Forum struct using the fast scanner -func scanForum(stmt *sqlite.Stmt) *Forum { - forum := &Forum{} - forumScanner.Scan(stmt, forum) - return forum +// ForumStore provides in-memory storage with O(1) lookups and forum-specific indices +type ForumStore struct { + *store.BaseStore[Forum] // Embedded generic store + byParent map[int][]int // Parent -> []ID + byAuthor map[int][]int // Author -> []ID + threadsOnly []int // Parent=0 IDs sorted by last_post DESC, id DESC + allByLastPost []int // All IDs sorted by last_post DESC, id DESC + mu sync.RWMutex // Protects indices +} + +// Global in-memory store +var forumStore *ForumStore +var storeOnce sync.Once + +// Initialize the in-memory store +func initStore() { + forumStore = &ForumStore{ + BaseStore: store.NewBaseStore[Forum](), + byParent: make(map[int][]int), + byAuthor: make(map[int][]int), + threadsOnly: make([]int, 0), + allByLastPost: make([]int, 0), + } +} + +// GetStore returns the global forum store +func GetStore() *ForumStore { + storeOnce.Do(initStore) + return forumStore +} + +// AddForum adds a forum post to the in-memory store and updates all indices +func (fs *ForumStore) AddForum(forum *Forum) { + fs.mu.Lock() + defer fs.mu.Unlock() + + // Validate forum + if err := forum.Validate(); err != nil { + return + } + + // Add to base store + fs.Add(forum.ID, forum) + + // Rebuild indices + fs.rebuildIndicesUnsafe() +} + +// RemoveForum removes a forum post from the store and updates indices +func (fs *ForumStore) RemoveForum(id int) { + fs.mu.Lock() + defer fs.mu.Unlock() + + // Remove from base store + fs.Remove(id) + + // Rebuild indices + fs.rebuildIndicesUnsafe() +} + +// UpdateForum updates a forum post efficiently +func (fs *ForumStore) UpdateForum(forum *Forum) { + fs.mu.Lock() + defer fs.mu.Unlock() + + // Validate forum + if err := forum.Validate(); err != nil { + return + } + + // Update base store + fs.Add(forum.ID, forum) + + // Rebuild indices + fs.rebuildIndicesUnsafe() +} + +// LoadData loads forum data from JSON file, or starts with empty store +func LoadData(dataPath string) error { + fs := GetStore() + + // Load from base store, which handles JSON loading + if err := fs.BaseStore.LoadData(dataPath); err != nil { + return err + } + + // Rebuild indices from loaded data + fs.rebuildIndices() + return nil +} + +// SaveData saves forum data to JSON file +func SaveData(dataPath string) error { + fs := GetStore() + return fs.BaseStore.SaveData(dataPath) +} + +// rebuildIndicesUnsafe rebuilds all indices from base store data (caller must hold lock) +func (fs *ForumStore) rebuildIndicesUnsafe() { + // Clear indices + fs.byParent = make(map[int][]int) + fs.byAuthor = make(map[int][]int) + fs.threadsOnly = make([]int, 0) + fs.allByLastPost = make([]int, 0) + + // Collect all forum posts and build indices + allForums := fs.GetAll() + + for id, forum := range allForums { + // Parent index + fs.byParent[forum.Parent] = append(fs.byParent[forum.Parent], id) + + // Author index + fs.byAuthor[forum.Author] = append(fs.byAuthor[forum.Author], id) + + // Threads only (parent = 0) + if forum.Parent == 0 { + fs.threadsOnly = append(fs.threadsOnly, id) + } + + // All posts + fs.allByLastPost = append(fs.allByLastPost, id) + } + + // Sort allByLastPost by last_post DESC, then ID DESC + sort.Slice(fs.allByLastPost, func(i, j int) bool { + forumI, _ := fs.GetByID(fs.allByLastPost[i]) + forumJ, _ := fs.GetByID(fs.allByLastPost[j]) + if forumI.LastPost != forumJ.LastPost { + return forumI.LastPost > forumJ.LastPost // DESC + } + return fs.allByLastPost[i] > fs.allByLastPost[j] // DESC + }) + + // Sort threadsOnly by last_post DESC, then ID DESC + sort.Slice(fs.threadsOnly, func(i, j int) bool { + forumI, _ := fs.GetByID(fs.threadsOnly[i]) + forumJ, _ := fs.GetByID(fs.threadsOnly[j]) + if forumI.LastPost != forumJ.LastPost { + return forumI.LastPost > forumJ.LastPost // DESC + } + return fs.threadsOnly[i] > fs.threadsOnly[j] // DESC + }) + + // Sort byParent replies by posted ASC, then ID ASC + for parent := range fs.byParent { + if parent > 0 { // Only sort replies, not threads + sort.Slice(fs.byParent[parent], func(i, j int) bool { + forumI, _ := fs.GetByID(fs.byParent[parent][i]) + forumJ, _ := fs.GetByID(fs.byParent[parent][j]) + if forumI.Posted != forumJ.Posted { + return forumI.Posted < forumJ.Posted // ASC + } + return fs.byParent[parent][i] < fs.byParent[parent][j] // ASC + }) + } + } + + // Sort byAuthor by posted DESC, then ID DESC + for author := range fs.byAuthor { + sort.Slice(fs.byAuthor[author], func(i, j int) bool { + forumI, _ := fs.GetByID(fs.byAuthor[author][i]) + forumJ, _ := fs.GetByID(fs.byAuthor[author][j]) + if forumI.Posted != forumJ.Posted { + return forumI.Posted > forumJ.Posted // DESC + } + return fs.byAuthor[author][i] > fs.byAuthor[author][j] // DESC + }) + } +} + +// rebuildIndices rebuilds all forum-specific indices from base store data +func (fs *ForumStore) rebuildIndices() { + fs.mu.Lock() + defer fs.mu.Unlock() + fs.rebuildIndicesUnsafe() } // Retrieves a forum post by ID func Find(id int) (*Forum, error) { - var forum *Forum - - query := `SELECT ` + forumColumns() + ` FROM forum WHERE id = ?` - - err := database.Query(query, func(stmt *sqlite.Stmt) error { - forum = scanForum(stmt) - return nil - }, id) - - if err != nil { - return nil, fmt.Errorf("failed to find forum post: %w", err) - } - - if forum == nil { + fs := GetStore() + forum, exists := fs.GetByID(id) + if !exists { return nil, fmt.Errorf("forum post with ID %d not found", id) } - return forum, nil } // Retrieves all forum posts ordered by last post time (most recent first) func All() ([]*Forum, error) { - var forums []*Forum + fs := GetStore() + fs.mu.RLock() + defer fs.mu.RUnlock() - query := `SELECT ` + forumColumns() + ` FROM forum ORDER BY last_post DESC, id DESC` - - err := database.Query(query, func(stmt *sqlite.Stmt) error { - forum := scanForum(stmt) - forums = append(forums, forum) - return nil - }) - - if err != nil { - return nil, fmt.Errorf("failed to retrieve all forum posts: %w", err) + result := make([]*Forum, 0, len(fs.allByLastPost)) + for _, id := range fs.allByLastPost { + if forum, exists := fs.GetByID(id); exists { + result = append(result, forum) + } } - - return forums, nil + return result, nil } // Retrieves all top-level forum threads (parent = 0) func Threads() ([]*Forum, error) { - var forums []*Forum + fs := GetStore() + fs.mu.RLock() + defer fs.mu.RUnlock() - query := `SELECT ` + forumColumns() + ` FROM forum WHERE parent = 0 ORDER BY last_post DESC, id DESC` - - err := database.Query(query, func(stmt *sqlite.Stmt) error { - forum := scanForum(stmt) - forums = append(forums, forum) - return nil - }) - - if err != nil { - return nil, fmt.Errorf("failed to retrieve forum threads: %w", err) + result := make([]*Forum, 0, len(fs.threadsOnly)) + for _, id := range fs.threadsOnly { + if forum, exists := fs.GetByID(id); exists { + result = append(result, forum) + } } - - return forums, nil + return result, nil } // Retrieves all replies to a specific thread/post func ByParent(parentID int) ([]*Forum, error) { - var forums []*Forum + fs := GetStore() + fs.mu.RLock() + defer fs.mu.RUnlock() - query := `SELECT ` + forumColumns() + ` FROM forum WHERE parent = ? ORDER BY posted ASC, id ASC` - - err := database.Query(query, func(stmt *sqlite.Stmt) error { - forum := scanForum(stmt) - forums = append(forums, forum) - return nil - }, parentID) - - if err != nil { - return nil, fmt.Errorf("failed to retrieve forum replies: %w", err) + ids, exists := fs.byParent[parentID] + if !exists { + return []*Forum{}, nil } - return forums, nil + result := make([]*Forum, 0, len(ids)) + for _, id := range ids { + if forum, exists := fs.GetByID(id); exists { + result = append(result, forum) + } + } + return result, nil } // Retrieves forum posts by a specific author func ByAuthor(authorID int) ([]*Forum, error) { - var forums []*Forum + fs := GetStore() + fs.mu.RLock() + defer fs.mu.RUnlock() - query := `SELECT ` + forumColumns() + ` FROM forum WHERE author = ? ORDER BY posted DESC, id DESC` - - err := database.Query(query, func(stmt *sqlite.Stmt) error { - forum := scanForum(stmt) - forums = append(forums, forum) - return nil - }, authorID) - - if err != nil { - return nil, fmt.Errorf("failed to retrieve forum posts by author: %w", err) + ids, exists := fs.byAuthor[authorID] + if !exists { + return []*Forum{}, nil } - return forums, nil + result := make([]*Forum, 0, len(ids)) + for _, id := range ids { + if forum, exists := fs.GetByID(id); exists { + result = append(result, forum) + } + } + return result, nil } // Retrieves the most recent forum activity (limited by count) func Recent(limit int) ([]*Forum, error) { - var forums []*Forum + fs := GetStore() + fs.mu.RLock() + defer fs.mu.RUnlock() - query := `SELECT ` + forumColumns() + ` FROM forum ORDER BY last_post DESC, id DESC LIMIT ?` - - err := database.Query(query, func(stmt *sqlite.Stmt) error { - forum := scanForum(stmt) - forums = append(forums, forum) - return nil - }, limit) - - if err != nil { - return nil, fmt.Errorf("failed to retrieve recent forum posts: %w", err) + if limit > len(fs.allByLastPost) { + limit = len(fs.allByLastPost) } - return forums, nil + result := make([]*Forum, 0, limit) + for i := 0; i < limit; i++ { + if forum, exists := fs.GetByID(fs.allByLastPost[i]); exists { + result = append(result, forum) + } + } + return result, nil } // Retrieves forum posts containing the search term in title or content func Search(term string) ([]*Forum, error) { - var forums []*Forum + fs := GetStore() + fs.mu.RLock() + defer fs.mu.RUnlock() - query := `SELECT ` + forumColumns() + ` FROM forum WHERE LOWER(title) LIKE LOWER(?) OR LOWER(content) LIKE LOWER(?) ORDER BY last_post DESC, id DESC` - searchTerm := "%" + term + "%" + var result []*Forum + lowerTerm := strings.ToLower(term) - err := database.Query(query, func(stmt *sqlite.Stmt) error { - forum := scanForum(stmt) - forums = append(forums, forum) - return nil - }, searchTerm, searchTerm) - - if err != nil { - return nil, fmt.Errorf("failed to search forum posts: %w", err) + for _, id := range fs.allByLastPost { + if forum, exists := fs.GetByID(id); exists { + if strings.Contains(strings.ToLower(forum.Title), lowerTerm) || + strings.Contains(strings.ToLower(forum.Content), lowerTerm) { + result = append(result, forum) + } + } } - - return forums, nil + return result, nil } // Retrieves forum posts with activity since a specific timestamp func Since(since int64) ([]*Forum, error) { - var forums []*Forum + fs := GetStore() + fs.mu.RLock() + defer fs.mu.RUnlock() - query := `SELECT ` + forumColumns() + ` FROM forum WHERE last_post >= ? ORDER BY last_post DESC, id DESC` - - err := database.Query(query, func(stmt *sqlite.Stmt) error { - forum := scanForum(stmt) - forums = append(forums, forum) - return nil - }, since) - - if err != nil { - return nil, fmt.Errorf("failed to retrieve forum posts since timestamp: %w", err) + var result []*Forum + for _, id := range fs.allByLastPost { + if forum, exists := fs.GetByID(id); exists && forum.LastPost >= since { + result = append(result, forum) + } } - - return forums, nil + return result, nil } -// Saves a new forum post to the database and sets the ID +// Saves a new forum post to the in-memory store and sets the ID func (f *Forum) Insert() error { - columns := `posted, last_post, author, parent, replies, title, content` - values := []any{f.Posted, f.LastPost, f.Author, f.Parent, f.Replies, f.Title, f.Content} - return database.Insert(f, columns, values...) + fs := GetStore() + + // Validate before insertion + if err := f.Validate(); err != nil { + return fmt.Errorf("validation failed: %w", err) + } + + // Assign new ID if not set + if f.ID == 0 { + f.ID = fs.GetNextID() + } + + // Add to store + fs.AddForum(f) + return nil } // Returns the posted timestamp as a time.Time @@ -252,12 +413,12 @@ func (f *Forum) LastPostTime() time.Time { // Sets the posted timestamp from a time.Time func (f *Forum) SetPostedTime(t time.Time) { - f.Set("Posted", t.Unix()) + f.Posted = t.Unix() } // Sets the last post timestamp from a time.Time func (f *Forum) SetLastPostTime(t time.Time) { - f.Set("LastPost", t.Unix()) + f.LastPost = t.Unix() } // Returns true if this is a top-level thread (parent = 0) @@ -350,18 +511,18 @@ func (f *Forum) Contains(term string) bool { // Updates the last_post timestamp to current time func (f *Forum) UpdateLastPost() { - f.Set("LastPost", time.Now().Unix()) + f.LastPost = time.Now().Unix() } // Increments the reply count func (f *Forum) IncrementReplies() { - f.Set("Replies", f.Replies+1) + f.Replies++ } // Decrements the reply count (minimum 0) func (f *Forum) DecrementReplies() { if f.Replies > 0 { - f.Set("Replies", f.Replies-1) + f.Replies-- } } diff --git a/internal/items/items.go b/internal/items/items.go index bd34fda..3291d3b 100644 --- a/internal/items/items.go +++ b/internal/items/items.go @@ -1,48 +1,32 @@ package items import ( + "dk/internal/store" "fmt" - - "dk/internal/database" - "dk/internal/helpers/scanner" - - "zombiezen.com/go/sqlite" + "sort" + "sync" ) -// Item represents an item in the database +// Item represents an item in the game type Item struct { - database.BaseModel - - ID int `db:"id" json:"id"` - Type int `db:"type" json:"type"` - Name string `db:"name" json:"name"` - Value int `db:"value" json:"value"` - Att int `db:"att" json:"att"` - Special string `db:"special" json:"special"` -} - -func (i *Item) GetTableName() string { - return "items" -} - -func (i *Item) GetID() int { - return i.ID -} - -func (i *Item) SetID(id int) { - i.ID = id -} - -func (i *Item) Set(field string, value any) error { - return database.Set(i, field, value) + ID int `json:"id"` + Type int `json:"type"` + Name string `json:"name"` + Value int `json:"value"` + Att int `json:"att"` + Special string `json:"special"` } func (i *Item) Save() error { - return database.Save(i) + itemStore := GetStore() + itemStore.UpdateItem(i) + return nil } func (i *Item) Delete() error { - return database.Delete(i) + itemStore := GetStore() + itemStore.RemoveItem(i.ID) + return nil } // Creates a new Item with sensible defaults @@ -56,18 +40,21 @@ func New() *Item { } } -var itemScanner = scanner.New[Item]() - -// Returns the column list for item queries -func itemColumns() string { - return itemScanner.Columns() -} - -// Populates an Item struct using the fast scanner -func scanItem(stmt *sqlite.Stmt) *Item { - item := &Item{} - itemScanner.Scan(stmt, item) - return item +// Validate checks if item has valid values +func (i *Item) Validate() error { + if i.Name == "" { + return fmt.Errorf("item name cannot be empty") + } + if i.Type < TypeWeapon || i.Type > TypeShield { + return fmt.Errorf("invalid item type: %d", i.Type) + } + if i.Value < 0 { + return fmt.Errorf("item Value cannot be negative") + } + if i.Att < 0 { + return fmt.Errorf("item Att cannot be negative") + } + return nil } // ItemType constants for item types @@ -77,71 +64,194 @@ const ( TypeShield = 3 ) +// ItemStore provides in-memory storage with O(1) lookups and item-specific indices +type ItemStore struct { + *store.BaseStore[Item] // Embedded generic store + byType map[int][]int // Type -> []ID + allByID []int // All IDs sorted by ID + mu sync.RWMutex // Protects indices +} + +// Global in-memory store +var itemStore *ItemStore +var storeOnce sync.Once + +// Initialize the in-memory store +func initStore() { + itemStore = &ItemStore{ + BaseStore: store.NewBaseStore[Item](), + byType: make(map[int][]int), + allByID: make([]int, 0), + } +} + +// GetStore returns the global item store +func GetStore() *ItemStore { + storeOnce.Do(initStore) + return itemStore +} + +// AddItem adds an item to the in-memory store and updates all indices +func (is *ItemStore) AddItem(item *Item) { + is.mu.Lock() + defer is.mu.Unlock() + + // Validate item + if err := item.Validate(); err != nil { + return + } + + // Add to base store + is.Add(item.ID, item) + + // Rebuild indices + is.rebuildIndicesUnsafe() +} + +// RemoveItem removes an item from the store and updates indices +func (is *ItemStore) RemoveItem(id int) { + is.mu.Lock() + defer is.mu.Unlock() + + // Remove from base store + is.Remove(id) + + // Rebuild indices + is.rebuildIndicesUnsafe() +} + +// UpdateItem updates an item efficiently +func (is *ItemStore) UpdateItem(item *Item) { + is.mu.Lock() + defer is.mu.Unlock() + + // Validate item + if err := item.Validate(); err != nil { + return + } + + // Update base store + is.Add(item.ID, item) + + // Rebuild indices + is.rebuildIndicesUnsafe() +} + +// LoadData loads item data from JSON file, or starts with empty store +func LoadData(dataPath string) error { + is := GetStore() + + // Load from base store, which handles JSON loading + if err := is.BaseStore.LoadData(dataPath); err != nil { + return err + } + + // Rebuild indices from loaded data + is.rebuildIndices() + return nil +} + +// SaveData saves item data to JSON file +func SaveData(dataPath string) error { + is := GetStore() + return is.BaseStore.SaveData(dataPath) +} + +// rebuildIndicesUnsafe rebuilds all indices from base store data (caller must hold lock) +func (is *ItemStore) rebuildIndicesUnsafe() { + // Clear indices + is.byType = make(map[int][]int) + is.allByID = make([]int, 0) + + // Collect all items and build indices + allItems := is.GetAll() + + for id, item := range allItems { + // Type index + is.byType[item.Type] = append(is.byType[item.Type], id) + + // All IDs + is.allByID = append(is.allByID, id) + } + + // Sort allByID by ID + sort.Ints(is.allByID) + + // Sort type indices by ID + for itemType := range is.byType { + sort.Ints(is.byType[itemType]) + } +} + +// rebuildIndices rebuilds all item-specific indices from base store data +func (is *ItemStore) rebuildIndices() { + is.mu.Lock() + defer is.mu.Unlock() + is.rebuildIndicesUnsafe() +} + // Retrieves an item by ID func Find(id int) (*Item, error) { - var item *Item - - query := `SELECT ` + itemColumns() + ` FROM items WHERE id = ?` - - err := database.Query(query, func(stmt *sqlite.Stmt) error { - item = scanItem(stmt) - return nil - }, id) - - if err != nil { - return nil, fmt.Errorf("failed to find item: %w", err) - } - - if item == nil { + is := GetStore() + item, exists := is.GetByID(id) + if !exists { return nil, fmt.Errorf("item with ID %d not found", id) } - return item, nil } // Retrieves all items func All() ([]*Item, error) { - var items []*Item + is := GetStore() + is.mu.RLock() + defer is.mu.RUnlock() - query := `SELECT ` + itemColumns() + ` FROM items ORDER BY id` - - err := database.Query(query, func(stmt *sqlite.Stmt) error { - item := scanItem(stmt) - items = append(items, item) - return nil - }) - - if err != nil { - return nil, fmt.Errorf("failed to retrieve all items: %w", err) + result := make([]*Item, 0, len(is.allByID)) + for _, id := range is.allByID { + if item, exists := is.GetByID(id); exists { + result = append(result, item) + } } - - return items, nil + return result, nil } // Retrieves items by type func ByType(itemType int) ([]*Item, error) { - var items []*Item + is := GetStore() + is.mu.RLock() + defer is.mu.RUnlock() - query := `SELECT ` + itemColumns() + ` FROM items WHERE type = ? ORDER BY id` - - err := database.Query(query, func(stmt *sqlite.Stmt) error { - item := scanItem(stmt) - items = append(items, item) - return nil - }, itemType) - - if err != nil { - return nil, fmt.Errorf("failed to retrieve items by type: %w", err) + ids, exists := is.byType[itemType] + if !exists { + return []*Item{}, nil } - return items, nil + result := make([]*Item, 0, len(ids)) + for _, id := range ids { + if item, exists := is.GetByID(id); exists { + result = append(result, item) + } + } + return result, nil } -// Saves a new item to the database and sets the ID +// Saves a new item to the in-memory store and sets the ID func (i *Item) Insert() error { - columns := `type, name, value, att, special` - values := []any{i.Type, i.Name, i.Value, i.Att, i.Special} - return database.Insert(i, columns, values...) + is := GetStore() + + // Validate before insertion + if err := i.Validate(); err != nil { + return fmt.Errorf("validation failed: %w", err) + } + + // Assign new ID if not set + if i.ID == 0 { + i.ID = is.GetNextID() + } + + // Add to store + is.AddItem(i) + return nil } // Returns true if the item is a weapon diff --git a/internal/monsters/monsters.go b/internal/monsters/monsters.go index ec72301..c0cd7e8 100644 --- a/internal/monsters/monsters.go +++ b/internal/monsters/monsters.go @@ -1,10 +1,8 @@ package monsters import ( - "encoding/json" + "dk/internal/store" "fmt" - "os" - "path/filepath" "sort" "sync" ) @@ -23,14 +21,14 @@ type Monster struct { } func (m *Monster) Save() error { - store := GetStore() - store.UpdateMonster(m) + monsterStore := GetStore() + monsterStore.UpdateMonster(m) return nil } func (m *Monster) Delete() error { - store := GetStore() - store.RemoveMonster(m.ID) + monsterStore := GetStore() + monsterStore.RemoveMonster(m.ID) return nil } @@ -48,6 +46,22 @@ func New() *Monster { } } +// Validate checks if monster has valid values +func (m *Monster) Validate() error { + if m.Name == "" { + return fmt.Errorf("monster name cannot be empty") + } + if m.MaxHP < 1 { + return fmt.Errorf("monster MaxHP must be at least 1") + } + if m.Level < 1 { + return fmt.Errorf("monster Level must be at least 1") + } + if m.Immune < ImmuneNone || m.Immune > ImmuneSleep { + return fmt.Errorf("invalid immunity type: %d", m.Immune) + } + return nil +} // Immunity constants for monster immunity types const ( @@ -56,35 +70,33 @@ const ( ImmuneSleep = 2 // Immune to Sleep spells ) -// MonsterStore provides in-memory storage with O(1) lookups +// MonsterStore provides in-memory storage with O(1) lookups and monster-specific indices type MonsterStore struct { - monsters map[int]*Monster // ID -> Monster (O(1)) - byLevel map[int][]*Monster // Level -> []*Monster (O(1) to get slice) - byImmunity map[int][]*Monster // Immunity -> []*Monster (O(1) to get slice) - allByLevel []*Monster // Pre-sorted by level, id - maxID int - mu sync.RWMutex + *store.BaseStore[Monster] // Embedded generic store + byLevel map[int][]int // Level -> []ID + byImmunity map[int][]int // Immunity -> []ID + allByLevel []int // All IDs sorted by level, then ID + mu sync.RWMutex // Protects indices } // Global in-memory store -var store *MonsterStore +var monsterStore *MonsterStore var storeOnce sync.Once // Initialize the in-memory store func initStore() { - store = &MonsterStore{ - monsters: make(map[int]*Monster), - byLevel: make(map[int][]*Monster), - byImmunity: make(map[int][]*Monster), - allByLevel: make([]*Monster, 0), - maxID: 0, + monsterStore = &MonsterStore{ + BaseStore: store.NewBaseStore[Monster](), + byLevel: make(map[int][]int), + byImmunity: make(map[int][]int), + allByLevel: make([]int, 0), } } // GetStore returns the global monster store func GetStore() *MonsterStore { storeOnce.Do(initStore) - return store + return monsterStore } // AddMonster adds a monster to the in-memory store and updates all indices @@ -92,41 +104,16 @@ func (ms *MonsterStore) AddMonster(monster *Monster) { ms.mu.Lock() defer ms.mu.Unlock() - // Add to primary store - ms.monsters[monster.ID] = monster - - // Update max ID - if monster.ID > ms.maxID { - ms.maxID = monster.ID + // Validate monster + if err := monster.Validate(); err != nil { + return } - // Add to level index - ms.byLevel[monster.Level] = append(ms.byLevel[monster.Level], monster) - - // Add to immunity index - ms.byImmunity[monster.Immune] = append(ms.byImmunity[monster.Immune], monster) - - // Add to sorted list and re-sort - ms.allByLevel = append(ms.allByLevel, monster) - sort.Slice(ms.allByLevel, func(i, j int) bool { - if ms.allByLevel[i].Level == ms.allByLevel[j].Level { - return ms.allByLevel[i].ID < ms.allByLevel[j].ID - } - return ms.allByLevel[i].Level < ms.allByLevel[j].Level - }) + // Add to base store + ms.Add(monster.ID, monster) - // Sort level index - sort.Slice(ms.byLevel[monster.Level], func(i, j int) bool { - return ms.byLevel[monster.Level][i].ID < ms.byLevel[monster.Level][j].ID - }) - - // Sort immunity index - sort.Slice(ms.byImmunity[monster.Immune], func(i, j int) bool { - if ms.byImmunity[monster.Immune][i].Level == ms.byImmunity[monster.Immune][j].Level { - return ms.byImmunity[monster.Immune][i].ID < ms.byImmunity[monster.Immune][j].ID - } - return ms.byImmunity[monster.Immune][i].Level < ms.byImmunity[monster.Immune][j].Level - }) + // Rebuild indices + ms.rebuildIndicesUnsafe() } // RemoveMonster removes a monster from the store and updates indices @@ -134,344 +121,202 @@ func (ms *MonsterStore) RemoveMonster(id int) { ms.mu.Lock() defer ms.mu.Unlock() - monster, exists := ms.monsters[id] - if !exists { - return - } + // Remove from base store + ms.Remove(id) - // Remove from primary store - delete(ms.monsters, id) - - // Remove from level index - levelMonsters := ms.byLevel[monster.Level] - for i, m := range levelMonsters { - if m.ID == id { - ms.byLevel[monster.Level] = append(levelMonsters[:i], levelMonsters[i+1:]...) - break - } - } - - // Remove from immunity index - immunityMonsters := ms.byImmunity[monster.Immune] - for i, m := range immunityMonsters { - if m.ID == id { - ms.byImmunity[monster.Immune] = append(immunityMonsters[:i], immunityMonsters[i+1:]...) - break - } - } - - // Remove from sorted list - for i, m := range ms.allByLevel { - if m.ID == id { - ms.allByLevel = append(ms.allByLevel[:i], ms.allByLevel[i+1:]...) - break - } - } + // Rebuild indices + ms.rebuildIndicesUnsafe() } -// UpdateMonster updates a monster and rebuilds indices +// UpdateMonster updates a monster efficiently func (ms *MonsterStore) UpdateMonster(monster *Monster) { - ms.RemoveMonster(monster.ID) - ms.AddMonster(monster) -} - -// GetNextID returns the next available ID -func (ms *MonsterStore) GetNextID() int { - ms.mu.RLock() - defer ms.mu.RUnlock() - return ms.maxID + 1 -} - -// LoadFromJSON loads monster data from a JSON file -func (ms *MonsterStore) LoadFromJSON(filename string) error { ms.mu.Lock() defer ms.mu.Unlock() - data, err := os.ReadFile(filename) - if err != nil { - if os.IsNotExist(err) { - return nil // File doesn't exist, start with empty store - } - return fmt.Errorf("failed to read monsters JSON: %w", err) + // Validate monster + if err := monster.Validate(); err != nil { + return } - // Handle empty file - if len(data) == 0 { - return nil // Empty file, start with empty store - } + // Update base store + ms.Add(monster.ID, monster) - var monsters []*Monster - if err := json.Unmarshal(data, &monsters); err != nil { - return fmt.Errorf("failed to unmarshal monsters JSON: %w", err) - } - - // Clear existing data - ms.monsters = make(map[int]*Monster) - ms.byLevel = make(map[int][]*Monster) - ms.byImmunity = make(map[int][]*Monster) - ms.allByLevel = make([]*Monster, 0) - ms.maxID = 0 - - // Add all monsters - for _, monster := range monsters { - ms.monsters[monster.ID] = monster - if monster.ID > ms.maxID { - ms.maxID = monster.ID - } - ms.byLevel[monster.Level] = append(ms.byLevel[monster.Level], monster) - ms.byImmunity[monster.Immune] = append(ms.byImmunity[monster.Immune], monster) - ms.allByLevel = append(ms.allByLevel, monster) - } - - // Sort all indices - sort.Slice(ms.allByLevel, func(i, j int) bool { - if ms.allByLevel[i].Level == ms.allByLevel[j].Level { - return ms.allByLevel[i].ID < ms.allByLevel[j].ID - } - return ms.allByLevel[i].Level < ms.allByLevel[j].Level - }) - - for level := range ms.byLevel { - sort.Slice(ms.byLevel[level], func(i, j int) bool { - return ms.byLevel[level][i].ID < ms.byLevel[level][j].ID - }) - } - - for immunity := range ms.byImmunity { - sort.Slice(ms.byImmunity[immunity], func(i, j int) bool { - if ms.byImmunity[immunity][i].Level == ms.byImmunity[immunity][j].Level { - return ms.byImmunity[immunity][i].ID < ms.byImmunity[immunity][j].ID - } - return ms.byImmunity[immunity][i].Level < ms.byImmunity[immunity][j].Level - }) - } - - return nil + // Rebuild indices + ms.rebuildIndicesUnsafe() } -// SaveToJSON saves monster data to a JSON file -func (ms *MonsterStore) SaveToJSON(filename string) error { - ms.mu.RLock() - defer ms.mu.RUnlock() +// LoadData loads monster data from JSON file, or starts with empty store +func LoadData(dataPath string) error { + ms := GetStore() - monsters := make([]*Monster, 0, len(ms.monsters)) - for _, monster := range ms.monsters { - monsters = append(monsters, monster) - } - - // Sort by ID for consistent output - sort.Slice(monsters, func(i, j int) bool { - return monsters[i].ID < monsters[j].ID - }) - - data, err := json.MarshalIndent(monsters, "", " ") - if err != nil { - return fmt.Errorf("failed to marshal monsters to JSON: %w", err) - } - - if err := os.WriteFile(filename, data, 0644); err != nil { - return fmt.Errorf("failed to write monsters JSON: %w", err) - } - - return nil -} - -// findMonstersDataPath finds the monsters.json file relative to the current working directory -func findMonstersDataPath() (string, error) { - // Try current directory first (cwd/data/monsters.json) - if _, err := os.Stat("data/monsters.json"); err == nil { - return "data/monsters.json", nil - } - - // Walk up directories to find the data folder - dir, err := os.Getwd() - if err != nil { - return "", err - } - - for { - dataPath := filepath.Join(dir, "data", "monsters.json") - if _, err := os.Stat(dataPath); err == nil { - return dataPath, nil - } - - parent := filepath.Dir(dir) - if parent == dir { - break // reached root - } - dir = parent - } - - // Default to current directory if not found - return "data/monsters.json", nil -} - -// LoadData loads monster data from JSON file, or initializes with default data -func LoadData() error { - store := GetStore() - - dataPath, err := findMonstersDataPath() - if err != nil { - return fmt.Errorf("failed to find monsters data path: %w", err) - } - - if err := store.LoadFromJSON(dataPath); err != nil { - // If JSON doesn't exist, initialize with default monsters - if os.IsNotExist(err) { - fmt.Println("No existing monster data found, initializing with defaults...") - if err := initializeDefaultMonsters(); err != nil { - return fmt.Errorf("failed to initialize default monsters: %w", err) - } - // Save the default data - if err := SaveData(); err != nil { - return fmt.Errorf("failed to save default monster data: %w", err) - } - fmt.Printf("Initialized %d default monsters\n", len(store.monsters)) - } else { - return fmt.Errorf("failed to load from JSON: %w", err) - } - } else { - fmt.Printf("Loaded %d monsters from JSON\n", len(store.monsters)) - } - - return nil -} - -// initializeDefaultMonsters creates the default monster set -func initializeDefaultMonsters() error { - store := GetStore() - - // Default monsters from the original SQL data - defaultMonsters := []*Monster{ - {ID: 1, Name: "Blue Slime", MaxHP: 4, MaxDmg: 3, Armor: 1, Level: 1, MaxExp: 1, MaxGold: 1, Immune: ImmuneNone}, - {ID: 2, Name: "Red Slime", MaxHP: 6, MaxDmg: 5, Armor: 1, Level: 1, MaxExp: 2, MaxGold: 1, Immune: ImmuneNone}, - {ID: 3, Name: "Critter", MaxHP: 6, MaxDmg: 5, Armor: 2, Level: 1, MaxExp: 4, MaxGold: 2, Immune: ImmuneNone}, - {ID: 4, Name: "Creature", MaxHP: 10, MaxDmg: 8, Armor: 2, Level: 2, MaxExp: 4, MaxGold: 2, Immune: ImmuneNone}, - {ID: 5, Name: "Shadow", MaxHP: 10, MaxDmg: 9, Armor: 3, Level: 2, MaxExp: 6, MaxGold: 2, Immune: ImmuneHurt}, - {ID: 6, Name: "Drake", MaxHP: 11, MaxDmg: 10, Armor: 3, Level: 2, MaxExp: 8, MaxGold: 3, Immune: ImmuneNone}, - {ID: 7, Name: "Shade", MaxHP: 12, MaxDmg: 10, Armor: 3, Level: 3, MaxExp: 10, MaxGold: 3, Immune: ImmuneHurt}, - {ID: 8, Name: "Drakelor", MaxHP: 14, MaxDmg: 12, Armor: 4, Level: 3, MaxExp: 10, MaxGold: 3, Immune: ImmuneNone}, - {ID: 9, Name: "Silver Slime", MaxHP: 15, MaxDmg: 100, Armor: 200, Level: 30, MaxExp: 15, MaxGold: 1000, Immune: ImmuneSleep}, - {ID: 10, Name: "Scamp", MaxHP: 16, MaxDmg: 13, Armor: 5, Level: 4, MaxExp: 15, MaxGold: 5, Immune: ImmuneNone}, - } - - for _, monster := range defaultMonsters { - store.AddMonster(monster) + // Load from base store, which handles JSON loading + if err := ms.BaseStore.LoadData(dataPath); err != nil { + return err } + // Rebuild indices from loaded data + ms.rebuildIndices() return nil } // SaveData saves monster data to JSON file -func SaveData() error { - store := GetStore() - - dataPath, err := findMonstersDataPath() - if err != nil { - return fmt.Errorf("failed to find monsters data path: %w", err) +func SaveData(dataPath string) error { + ms := GetStore() + return ms.BaseStore.SaveData(dataPath) +} + +// rebuildIndicesUnsafe rebuilds all indices from base store data (caller must hold lock) +func (ms *MonsterStore) rebuildIndicesUnsafe() { + // Clear indices + ms.byLevel = make(map[int][]int) + ms.byImmunity = make(map[int][]int) + ms.allByLevel = make([]int, 0) + + // Collect all monsters and build indices + allMonsters := ms.GetAll() + + // Build level and immunity indices + for id, monster := range allMonsters { + ms.byLevel[monster.Level] = append(ms.byLevel[monster.Level], id) + ms.byImmunity[monster.Immune] = append(ms.byImmunity[monster.Immune], id) + ms.allByLevel = append(ms.allByLevel, id) } - - // Ensure data directory exists - dataDir := filepath.Dir(dataPath) - if err := os.MkdirAll(dataDir, 0755); err != nil { - return fmt.Errorf("failed to create data directory: %w", err) + + // Sort allByLevel by level first, then by ID + sort.Slice(ms.allByLevel, func(i, j int) bool { + monsterI, _ := ms.GetByID(ms.allByLevel[i]) + monsterJ, _ := ms.GetByID(ms.allByLevel[j]) + if monsterI.Level == monsterJ.Level { + return ms.allByLevel[i] < ms.allByLevel[j] + } + return monsterI.Level < monsterJ.Level + }) + + // Sort level indices by ID + for level := range ms.byLevel { + sort.Ints(ms.byLevel[level]) } - - if err := store.SaveToJSON(dataPath); err != nil { - return fmt.Errorf("failed to save monsters to JSON: %w", err) + + // Sort immunity indices by level, then ID + for immunity := range ms.byImmunity { + sort.Slice(ms.byImmunity[immunity], func(i, j int) bool { + monsterI, _ := ms.GetByID(ms.byImmunity[immunity][i]) + monsterJ, _ := ms.GetByID(ms.byImmunity[immunity][j]) + if monsterI.Level == monsterJ.Level { + return ms.byImmunity[immunity][i] < ms.byImmunity[immunity][j] + } + return monsterI.Level < monsterJ.Level + }) } - - fmt.Printf("Saved %d monsters to JSON\n", len(store.monsters)) - return nil +} + +// rebuildIndices rebuilds all monster-specific indices from base store data +func (ms *MonsterStore) rebuildIndices() { + ms.mu.Lock() + defer ms.mu.Unlock() + ms.rebuildIndicesUnsafe() } // Retrieves a monster by ID - O(1) lookup func Find(id int) (*Monster, error) { - store := GetStore() - store.mu.RLock() - defer store.mu.RUnlock() - - monster, exists := store.monsters[id] + ms := GetStore() + monster, exists := ms.GetByID(id) if !exists { return nil, fmt.Errorf("monster with ID %d not found", id) } - return monster, nil } // Retrieves all monsters - O(1) lookup (returns pre-sorted slice) func All() ([]*Monster, error) { - store := GetStore() - store.mu.RLock() - defer store.mu.RUnlock() + ms := GetStore() + ms.mu.RLock() + defer ms.mu.RUnlock() - // Return a copy of the slice to prevent external modifications - result := make([]*Monster, len(store.allByLevel)) - copy(result, store.allByLevel) + result := make([]*Monster, 0, len(ms.allByLevel)) + for _, id := range ms.allByLevel { + if monster, exists := ms.GetByID(id); exists { + result = append(result, monster) + } + } return result, nil } // Retrieves monsters by level - O(1) lookup func ByLevel(level int) ([]*Monster, error) { - store := GetStore() - store.mu.RLock() - defer store.mu.RUnlock() + ms := GetStore() + ms.mu.RLock() + defer ms.mu.RUnlock() - monsters, exists := store.byLevel[level] + ids, exists := ms.byLevel[level] if !exists { return []*Monster{}, nil } - // Return a copy of the slice to prevent external modifications - result := make([]*Monster, len(monsters)) - copy(result, monsters) + result := make([]*Monster, 0, len(ids)) + for _, id := range ids { + if monster, exists := ms.GetByID(id); exists { + result = append(result, monster) + } + } return result, nil } // Retrieves monsters within a level range (inclusive) - O(k) where k is result size func ByLevelRange(minLevel, maxLevel int) ([]*Monster, error) { - store := GetStore() - store.mu.RLock() - defer store.mu.RUnlock() + ms := GetStore() + ms.mu.RLock() + defer ms.mu.RUnlock() var result []*Monster for level := minLevel; level <= maxLevel; level++ { - if monsters, exists := store.byLevel[level]; exists { - result = append(result, monsters...) + if ids, exists := ms.byLevel[level]; exists { + for _, id := range ids { + if monster, exists := ms.GetByID(id); exists { + result = append(result, monster) + } + } } } - return result, nil } // Retrieves monsters by immunity type - O(1) lookup func ByImmunity(immunityType int) ([]*Monster, error) { - store := GetStore() - store.mu.RLock() - defer store.mu.RUnlock() + ms := GetStore() + ms.mu.RLock() + defer ms.mu.RUnlock() - monsters, exists := store.byImmunity[immunityType] + ids, exists := ms.byImmunity[immunityType] if !exists { return []*Monster{}, nil } - // Return a copy of the slice to prevent external modifications - result := make([]*Monster, len(monsters)) - copy(result, monsters) + result := make([]*Monster, 0, len(ids)) + for _, id := range ids { + if monster, exists := ms.GetByID(id); exists { + result = append(result, monster) + } + } return result, nil } // Saves a new monster to the in-memory store and sets the ID func (m *Monster) Insert() error { - store := GetStore() - + ms := GetStore() + + // Validate before insertion + if err := m.Validate(); err != nil { + return fmt.Errorf("validation failed: %w", err) + } + // Assign new ID if not set if m.ID == 0 { - m.ID = store.GetNextID() + m.ID = ms.GetNextID() } - + // Add to store - store.AddMonster(m) + ms.AddMonster(m) return nil } diff --git a/internal/news/news.go b/internal/news/news.go index 70cc8c5..4a998b5 100644 --- a/internal/news/news.go +++ b/internal/news/news.go @@ -1,47 +1,32 @@ package news import ( - "dk/internal/database" - "dk/internal/helpers/scanner" + "dk/internal/store" "fmt" + "sort" "strings" + "sync" "time" - - "zombiezen.com/go/sqlite" ) -// News represents a news post in the database +// News represents a news post in the game type News struct { - database.BaseModel - - ID int `db:"id" json:"id"` - Author int `db:"author" json:"author"` - Posted int64 `db:"posted" json:"posted"` - Content string `db:"content" json:"content"` -} - -func (n *News) GetTableName() string { - return "news" -} - -func (n *News) GetID() int { - return n.ID -} - -func (n *News) SetID(id int) { - n.ID = id -} - -func (n *News) Set(field string, value any) error { - return database.Set(n, field, value) + ID int `json:"id"` + Author int `json:"author"` + Posted int64 `json:"posted"` + Content string `json:"content"` } func (n *News) Save() error { - return database.Save(n) + newsStore := GetStore() + newsStore.UpdateNews(n) + return nil } func (n *News) Delete() error { - return database.Delete(n) + newsStore := GetStore() + newsStore.RemoveNews(n.ID) + return nil } // Creates a new News with sensible defaults @@ -53,142 +38,287 @@ func New() *News { } } -var newsScanner = scanner.New[News]() - -// Returns the column list for news queries -func newsColumns() string { - return newsScanner.Columns() +// Validate checks if news has valid values +func (n *News) Validate() error { + if n.Posted < 0 { + return fmt.Errorf("news Posted timestamp cannot be negative") + } + if strings.TrimSpace(n.Content) == "" { + return fmt.Errorf("news Content cannot be empty") + } + return nil } -// Populates a News struct using the fast scanner -func scanNews(stmt *sqlite.Stmt) *News { - news := &News{} - newsScanner.Scan(stmt, news) - return news +// NewsStore provides in-memory storage with O(1) lookups and news-specific indices +type NewsStore struct { + *store.BaseStore[News] // Embedded generic store + byAuthor map[int][]int // Author -> []ID + allByPosted []int // All IDs sorted by posted DESC, id DESC + mu sync.RWMutex // Protects indices +} + +// Global in-memory store +var newsStore *NewsStore +var storeOnce sync.Once + +// Initialize the in-memory store +func initStore() { + newsStore = &NewsStore{ + BaseStore: store.NewBaseStore[News](), + byAuthor: make(map[int][]int), + allByPosted: make([]int, 0), + } +} + +// GetStore returns the global news store +func GetStore() *NewsStore { + storeOnce.Do(initStore) + return newsStore +} + +// AddNews adds a news post to the in-memory store and updates all indices +func (ns *NewsStore) AddNews(news *News) { + ns.mu.Lock() + defer ns.mu.Unlock() + + // Validate news + if err := news.Validate(); err != nil { + return + } + + // Add to base store + ns.Add(news.ID, news) + + // Rebuild indices + ns.rebuildIndicesUnsafe() +} + +// RemoveNews removes a news post from the store and updates indices +func (ns *NewsStore) RemoveNews(id int) { + ns.mu.Lock() + defer ns.mu.Unlock() + + // Remove from base store + ns.Remove(id) + + // Rebuild indices + ns.rebuildIndicesUnsafe() +} + +// UpdateNews updates a news post efficiently +func (ns *NewsStore) UpdateNews(news *News) { + ns.mu.Lock() + defer ns.mu.Unlock() + + // Validate news + if err := news.Validate(); err != nil { + return + } + + // Update base store + ns.Add(news.ID, news) + + // Rebuild indices + ns.rebuildIndicesUnsafe() +} + +// LoadData loads news data from JSON file, or starts with empty store +func LoadData(dataPath string) error { + ns := GetStore() + + // Load from base store, which handles JSON loading + if err := ns.BaseStore.LoadData(dataPath); err != nil { + return err + } + + // Rebuild indices from loaded data + ns.rebuildIndices() + return nil +} + +// SaveData saves news data to JSON file +func SaveData(dataPath string) error { + ns := GetStore() + return ns.BaseStore.SaveData(dataPath) +} + +// rebuildIndicesUnsafe rebuilds all indices from base store data (caller must hold lock) +func (ns *NewsStore) rebuildIndicesUnsafe() { + // Clear indices + ns.byAuthor = make(map[int][]int) + ns.allByPosted = make([]int, 0) + + // Collect all news and build indices + allNews := ns.GetAll() + + for id, news := range allNews { + // Author index + ns.byAuthor[news.Author] = append(ns.byAuthor[news.Author], id) + + // All IDs + ns.allByPosted = append(ns.allByPosted, id) + } + + // Sort allByPosted by posted DESC, then ID DESC + sort.Slice(ns.allByPosted, func(i, j int) bool { + newsI, _ := ns.GetByID(ns.allByPosted[i]) + newsJ, _ := ns.GetByID(ns.allByPosted[j]) + if newsI.Posted != newsJ.Posted { + return newsI.Posted > newsJ.Posted // DESC + } + return ns.allByPosted[i] > ns.allByPosted[j] // DESC + }) + + // Sort author indices by posted DESC, then ID DESC + for author := range ns.byAuthor { + sort.Slice(ns.byAuthor[author], func(i, j int) bool { + newsI, _ := ns.GetByID(ns.byAuthor[author][i]) + newsJ, _ := ns.GetByID(ns.byAuthor[author][j]) + if newsI.Posted != newsJ.Posted { + return newsI.Posted > newsJ.Posted // DESC + } + return ns.byAuthor[author][i] > ns.byAuthor[author][j] // DESC + }) + } +} + +// rebuildIndices rebuilds all news-specific indices from base store data +func (ns *NewsStore) rebuildIndices() { + ns.mu.Lock() + defer ns.mu.Unlock() + ns.rebuildIndicesUnsafe() } // Retrieves a news post by ID func Find(id int) (*News, error) { - var news *News - - query := `SELECT ` + newsColumns() + ` FROM news WHERE id = ?` - - err := database.Query(query, func(stmt *sqlite.Stmt) error { - news = scanNews(stmt) - return nil - }, id) - - if err != nil { - return nil, fmt.Errorf("failed to find news: %w", err) - } - - if news == nil { + ns := GetStore() + news, exists := ns.GetByID(id) + if !exists { return nil, fmt.Errorf("news with ID %d not found", id) } - return news, nil } // Retrieves all news posts ordered by posted date (newest first) func All() ([]*News, error) { - var newsPosts []*News + ns := GetStore() + ns.mu.RLock() + defer ns.mu.RUnlock() - query := `SELECT ` + newsColumns() + ` FROM news ORDER BY posted DESC, id DESC` - - err := database.Query(query, func(stmt *sqlite.Stmt) error { - news := scanNews(stmt) - newsPosts = append(newsPosts, news) - return nil - }) - - if err != nil { - return nil, fmt.Errorf("failed to retrieve all news: %w", err) + result := make([]*News, 0, len(ns.allByPosted)) + for _, id := range ns.allByPosted { + if news, exists := ns.GetByID(id); exists { + result = append(result, news) + } } - - return newsPosts, nil + return result, nil } // Retrieves news posts by a specific author func ByAuthor(authorID int) ([]*News, error) { - var newsPosts []*News + ns := GetStore() + ns.mu.RLock() + defer ns.mu.RUnlock() - query := `SELECT ` + newsColumns() + ` FROM news WHERE author = ? ORDER BY posted DESC, id DESC` - - err := database.Query(query, func(stmt *sqlite.Stmt) error { - news := scanNews(stmt) - newsPosts = append(newsPosts, news) - return nil - }, authorID) - - if err != nil { - return nil, fmt.Errorf("failed to retrieve news by author: %w", err) + ids, exists := ns.byAuthor[authorID] + if !exists { + return []*News{}, nil } - return newsPosts, nil + result := make([]*News, 0, len(ids)) + for _, id := range ids { + if news, exists := ns.GetByID(id); exists { + result = append(result, news) + } + } + return result, nil } // Retrieves the most recent news posts (limited by count) func Recent(limit int) ([]*News, error) { - var newsPosts []*News + ns := GetStore() + ns.mu.RLock() + defer ns.mu.RUnlock() - query := `SELECT ` + newsColumns() + ` FROM news ORDER BY posted DESC, id DESC LIMIT ?` - - err := database.Query(query, func(stmt *sqlite.Stmt) error { - news := scanNews(stmt) - newsPosts = append(newsPosts, news) - return nil - }, limit) - - if err != nil { - return nil, fmt.Errorf("failed to retrieve recent news: %w", err) + if limit > len(ns.allByPosted) { + limit = len(ns.allByPosted) } - return newsPosts, nil + result := make([]*News, 0, limit) + for i := 0; i < limit; i++ { + if news, exists := ns.GetByID(ns.allByPosted[i]); exists { + result = append(result, news) + } + } + return result, nil } // Retrieves news posts since a specific timestamp func Since(since int64) ([]*News, error) { - var newsPosts []*News + ns := GetStore() + ns.mu.RLock() + defer ns.mu.RUnlock() - query := `SELECT ` + newsColumns() + ` FROM news WHERE posted >= ? ORDER BY posted DESC, id DESC` - - err := database.Query(query, func(stmt *sqlite.Stmt) error { - news := scanNews(stmt) - newsPosts = append(newsPosts, news) - return nil - }, since) - - if err != nil { - return nil, fmt.Errorf("failed to retrieve news since timestamp: %w", err) + var result []*News + for _, id := range ns.allByPosted { + if news, exists := ns.GetByID(id); exists && news.Posted >= since { + result = append(result, news) + } } - - return newsPosts, nil + return result, nil } // Retrieves news posts between two timestamps (inclusive) func Between(start, end int64) ([]*News, error) { - var newsPosts []*News + ns := GetStore() + ns.mu.RLock() + defer ns.mu.RUnlock() - query := `SELECT ` + newsColumns() + ` FROM news WHERE posted >= ? AND posted <= ? ORDER BY posted DESC, id DESC` - - err := database.Query(query, func(stmt *sqlite.Stmt) error { - news := scanNews(stmt) - newsPosts = append(newsPosts, news) - return nil - }, start, end) - - if err != nil { - return nil, fmt.Errorf("failed to retrieve news between timestamps: %w", err) + var result []*News + for _, id := range ns.allByPosted { + if news, exists := ns.GetByID(id); exists && news.Posted >= start && news.Posted <= end { + result = append(result, news) + } } - - return newsPosts, nil + return result, nil } -// Saves a new news post to the database and sets the ID +// Retrieves news posts containing the search term in content +func Search(term string) ([]*News, error) { + ns := GetStore() + ns.mu.RLock() + defer ns.mu.RUnlock() + + var result []*News + lowerTerm := strings.ToLower(term) + + for _, id := range ns.allByPosted { + if news, exists := ns.GetByID(id); exists { + if strings.Contains(strings.ToLower(news.Content), lowerTerm) { + result = append(result, news) + } + } + } + return result, nil +} + +// Saves a new news post to the in-memory store and sets the ID func (n *News) Insert() error { - columns := `author, posted, content` - values := []any{n.Author, n.Posted, n.Content} - return database.Insert(n, columns, values...) + ns := GetStore() + + // Validate before insertion + if err := n.Validate(); err != nil { + return fmt.Errorf("validation failed: %w", err) + } + + // Assign new ID if not set + if n.ID == 0 { + n.ID = ns.GetNextID() + } + + // Add to store + ns.AddNews(n) + return nil } // Returns the posted timestamp as a time.Time @@ -198,7 +328,7 @@ func (n *News) PostedTime() time.Time { // Sets the posted timestamp from a time.Time func (n *News) SetPostedTime(t time.Time) { - n.Set("Posted", t.Unix()) + n.Posted = t.Unix() } // Returns true if the news post was made within the last 24 hours @@ -276,23 +406,3 @@ func (n *News) Contains(term string) bool { func (n *News) IsEmpty() bool { return strings.TrimSpace(n.Content) == "" } - -// Retrieves news posts containing the search term in content -func Search(term string) ([]*News, error) { - var newsPosts []*News - - query := `SELECT ` + newsColumns() + ` FROM news WHERE LOWER(content) LIKE LOWER(?) ORDER BY posted DESC, id DESC` - searchTerm := "%" + term + "%" - - err := database.Query(query, func(stmt *sqlite.Stmt) error { - news := scanNews(stmt) - newsPosts = append(newsPosts, news) - return nil - }, searchTerm) - - if err != nil { - return nil, fmt.Errorf("failed to search news: %w", err) - } - - return newsPosts, nil -} diff --git a/internal/server/server.go b/internal/server/server.go index fd516e0..dda6940 100644 --- a/internal/server/server.go +++ b/internal/server/server.go @@ -9,7 +9,6 @@ import ( "syscall" "dk/internal/auth" - "dk/internal/database" "dk/internal/middleware" "dk/internal/monsters" "dk/internal/router" @@ -27,14 +26,8 @@ func Start(port string) error { // Initialize template singleton template.InitializeCache(cwd) - // Initialize database singleton - if err := database.Init("dk.db"); err != nil { - return fmt.Errorf("failed to initialize database: %w", err) - } - defer database.Close() - // Load monster data into memory - if err := monsters.LoadData(); err != nil { + if err := monsters.LoadData("data/monsters.json"); err != nil { return fmt.Errorf("failed to load monster data: %w", err) } @@ -104,7 +97,7 @@ func Start(port string) error { // Save monster data before shutdown log.Println("Saving monster data...") - if err := monsters.SaveData(); err != nil { + if err := monsters.SaveData("data/monsters.json"); err != nil { log.Printf("Error saving monster data: %v", err) } diff --git a/internal/spells/spells.go b/internal/spells/spells.go index 28f0b12..40e5a32 100644 --- a/internal/spells/spells.go +++ b/internal/spells/spells.go @@ -1,47 +1,32 @@ package spells import ( + "dk/internal/store" "fmt" - - "dk/internal/database" - "dk/internal/helpers/scanner" - - "zombiezen.com/go/sqlite" + "sort" + "strings" + "sync" ) -// Spell represents a spell in the database +// Spell represents a spell in the game type Spell struct { - database.BaseModel - - ID int `db:"id" json:"id"` - Name string `db:"name" json:"name"` - MP int `db:"mp" json:"mp"` - Attribute int `db:"attribute" json:"attribute"` - Type int `db:"type" json:"type"` -} - -func (s *Spell) GetTableName() string { - return "spells" -} - -func (s *Spell) GetID() int { - return s.ID -} - -func (s *Spell) SetID(id int) { - s.ID = id -} - -func (s *Spell) Set(field string, value any) error { - return database.Set(s, field, value) + ID int `json:"id"` + Name string `json:"name"` + MP int `json:"mp"` + Attribute int `json:"attribute"` + Type int `json:"type"` } func (s *Spell) Save() error { - return database.Save(s) + spellStore := GetStore() + spellStore.UpdateSpell(s) + return nil } func (s *Spell) Delete() error { - return database.Delete(s) + spellStore := GetStore() + spellStore.RemoveSpell(s.ID) + return nil } // Creates a new Spell with sensible defaults @@ -54,18 +39,21 @@ func New() *Spell { } } -var spellScanner = scanner.New[Spell]() - -// Returns the column list for spell queries -func spellColumns() string { - return spellScanner.Columns() -} - -// Populates a Spell struct using the fast scanner -func scanSpell(stmt *sqlite.Stmt) *Spell { - spell := &Spell{} - spellScanner.Scan(stmt, spell) - return spell +// Validate checks if spell has valid values +func (s *Spell) Validate() error { + if s.Name == "" { + return fmt.Errorf("spell name cannot be empty") + } + if s.MP < 0 { + return fmt.Errorf("spell MP cannot be negative") + } + if s.Attribute < 0 { + return fmt.Errorf("spell Attribute cannot be negative") + } + if s.Type < TypeHealing || s.Type > TypeDefenseBoost { + return fmt.Errorf("invalid spell type: %d", s.Type) + } + return nil } // SpellType constants for spell types @@ -77,131 +65,293 @@ const ( TypeDefenseBoost = 5 ) +// SpellStore provides in-memory storage with O(1) lookups and spell-specific indices +type SpellStore struct { + *store.BaseStore[Spell] // Embedded generic store + byType map[int][]int // Type -> []ID + byName map[string]int // Name (lowercase) -> ID + byMP map[int][]int // MP -> []ID + allByTypeMP []int // All IDs sorted by type, MP, ID + mu sync.RWMutex // Protects indices +} + +// Global in-memory store +var spellStore *SpellStore +var storeOnce sync.Once + +// Initialize the in-memory store +func initStore() { + spellStore = &SpellStore{ + BaseStore: store.NewBaseStore[Spell](), + byType: make(map[int][]int), + byName: make(map[string]int), + byMP: make(map[int][]int), + allByTypeMP: make([]int, 0), + } +} + +// GetStore returns the global spell store +func GetStore() *SpellStore { + storeOnce.Do(initStore) + return spellStore +} + +// AddSpell adds a spell to the in-memory store and updates all indices +func (ss *SpellStore) AddSpell(spell *Spell) { + ss.mu.Lock() + defer ss.mu.Unlock() + + // Validate spell + if err := spell.Validate(); err != nil { + return + } + + // Add to base store + ss.Add(spell.ID, spell) + + // Rebuild indices + ss.rebuildIndicesUnsafe() +} + +// RemoveSpell removes a spell from the store and updates indices +func (ss *SpellStore) RemoveSpell(id int) { + ss.mu.Lock() + defer ss.mu.Unlock() + + // Remove from base store + ss.Remove(id) + + // Rebuild indices + ss.rebuildIndicesUnsafe() +} + +// UpdateSpell updates a spell efficiently +func (ss *SpellStore) UpdateSpell(spell *Spell) { + ss.mu.Lock() + defer ss.mu.Unlock() + + // Validate spell + if err := spell.Validate(); err != nil { + return + } + + // Update base store + ss.Add(spell.ID, spell) + + // Rebuild indices + ss.rebuildIndicesUnsafe() +} + +// LoadData loads spell data from JSON file, or starts with empty store +func LoadData(dataPath string) error { + ss := GetStore() + + // Load from base store, which handles JSON loading + if err := ss.BaseStore.LoadData(dataPath); err != nil { + return err + } + + // Rebuild indices from loaded data + ss.rebuildIndices() + return nil +} + +// SaveData saves spell data to JSON file +func SaveData(dataPath string) error { + ss := GetStore() + return ss.BaseStore.SaveData(dataPath) +} + +// rebuildIndicesUnsafe rebuilds all indices from base store data (caller must hold lock) +func (ss *SpellStore) rebuildIndicesUnsafe() { + // Clear indices + ss.byType = make(map[int][]int) + ss.byName = make(map[string]int) + ss.byMP = make(map[int][]int) + ss.allByTypeMP = make([]int, 0) + + // Collect all spells and build indices + allSpells := ss.GetAll() + + for id, spell := range allSpells { + // Type index + ss.byType[spell.Type] = append(ss.byType[spell.Type], id) + + // Name index (case-insensitive) + ss.byName[strings.ToLower(spell.Name)] = id + + // MP index + ss.byMP[spell.MP] = append(ss.byMP[spell.MP], id) + + // All IDs + ss.allByTypeMP = append(ss.allByTypeMP, id) + } + + // Sort allByTypeMP by type, then MP, then ID + sort.Slice(ss.allByTypeMP, func(i, j int) bool { + spellI, _ := ss.GetByID(ss.allByTypeMP[i]) + spellJ, _ := ss.GetByID(ss.allByTypeMP[j]) + if spellI.Type != spellJ.Type { + return spellI.Type < spellJ.Type + } + if spellI.MP != spellJ.MP { + return spellI.MP < spellJ.MP + } + return ss.allByTypeMP[i] < ss.allByTypeMP[j] + }) + + // Sort type indices by MP, then ID + for spellType := range ss.byType { + sort.Slice(ss.byType[spellType], func(i, j int) bool { + spellI, _ := ss.GetByID(ss.byType[spellType][i]) + spellJ, _ := ss.GetByID(ss.byType[spellType][j]) + if spellI.MP != spellJ.MP { + return spellI.MP < spellJ.MP + } + return ss.byType[spellType][i] < ss.byType[spellType][j] + }) + } + + // Sort MP indices by type, then ID + for mp := range ss.byMP { + sort.Slice(ss.byMP[mp], func(i, j int) bool { + spellI, _ := ss.GetByID(ss.byMP[mp][i]) + spellJ, _ := ss.GetByID(ss.byMP[mp][j]) + if spellI.Type != spellJ.Type { + return spellI.Type < spellJ.Type + } + return ss.byMP[mp][i] < ss.byMP[mp][j] + }) + } +} + +// rebuildIndices rebuilds all spell-specific indices from base store data +func (ss *SpellStore) rebuildIndices() { + ss.mu.Lock() + defer ss.mu.Unlock() + ss.rebuildIndicesUnsafe() +} + // Retrieves a spell by ID func Find(id int) (*Spell, error) { - var spell *Spell - - query := `SELECT ` + spellColumns() + ` FROM spells WHERE id = ?` - - err := database.Query(query, func(stmt *sqlite.Stmt) error { - spell = scanSpell(stmt) - return nil - }, id) - - if err != nil { - return nil, fmt.Errorf("failed to find spell: %w", err) - } - - if spell == nil { + ss := GetStore() + spell, exists := ss.GetByID(id) + if !exists { return nil, fmt.Errorf("spell with ID %d not found", id) } - return spell, nil } // Retrieves all spells func All() ([]*Spell, error) { - var spells []*Spell + ss := GetStore() + ss.mu.RLock() + defer ss.mu.RUnlock() - query := `SELECT ` + spellColumns() + ` FROM spells ORDER BY type, mp, id` - - err := database.Query(query, func(stmt *sqlite.Stmt) error { - spell := scanSpell(stmt) - spells = append(spells, spell) - return nil - }) - - if err != nil { - return nil, fmt.Errorf("failed to retrieve all spells: %w", err) + result := make([]*Spell, 0, len(ss.allByTypeMP)) + for _, id := range ss.allByTypeMP { + if spell, exists := ss.GetByID(id); exists { + result = append(result, spell) + } } - - return spells, nil + return result, nil } // Retrieves spells by type func ByType(spellType int) ([]*Spell, error) { - var spells []*Spell + ss := GetStore() + ss.mu.RLock() + defer ss.mu.RUnlock() - query := `SELECT ` + spellColumns() + ` FROM spells WHERE type = ? ORDER BY mp, id` - - err := database.Query(query, func(stmt *sqlite.Stmt) error { - spell := scanSpell(stmt) - spells = append(spells, spell) - return nil - }, spellType) - - if err != nil { - return nil, fmt.Errorf("failed to retrieve spells by type: %w", err) + ids, exists := ss.byType[spellType] + if !exists { + return []*Spell{}, nil } - return spells, nil + result := make([]*Spell, 0, len(ids)) + for _, id := range ids { + if spell, exists := ss.GetByID(id); exists { + result = append(result, spell) + } + } + return result, nil } // Retrieves spells that cost at most the specified MP func ByMaxMP(maxMP int) ([]*Spell, error) { - var spells []*Spell + ss := GetStore() + ss.mu.RLock() + defer ss.mu.RUnlock() - query := `SELECT ` + spellColumns() + ` FROM spells WHERE mp <= ? ORDER BY type, mp, id` - - err := database.Query(query, func(stmt *sqlite.Stmt) error { - spell := scanSpell(stmt) - spells = append(spells, spell) - return nil - }, maxMP) - - if err != nil { - return nil, fmt.Errorf("failed to retrieve spells by max MP: %w", err) + var result []*Spell + for mp := 0; mp <= maxMP; mp++ { + if ids, exists := ss.byMP[mp]; exists { + for _, id := range ids { + if spell, exists := ss.GetByID(id); exists { + result = append(result, spell) + } + } + } } - - return spells, nil + return result, nil } // Retrieves spells of a specific type that cost at most the specified MP func ByTypeAndMaxMP(spellType, maxMP int) ([]*Spell, error) { - var spells []*Spell + ss := GetStore() + ss.mu.RLock() + defer ss.mu.RUnlock() - query := `SELECT ` + spellColumns() + ` FROM spells WHERE type = ? AND mp <= ? ORDER BY mp, id` - - err := database.Query(query, func(stmt *sqlite.Stmt) error { - spell := scanSpell(stmt) - spells = append(spells, spell) - return nil - }, spellType, maxMP) - - if err != nil { - return nil, fmt.Errorf("failed to retrieve spells by type and max MP: %w", err) + ids, exists := ss.byType[spellType] + if !exists { + return []*Spell{}, nil } - return spells, nil + var result []*Spell + for _, id := range ids { + if spell, exists := ss.GetByID(id); exists && spell.MP <= maxMP { + result = append(result, spell) + } + } + return result, nil } // Retrieves a spell by name (case-insensitive) func ByName(name string) (*Spell, error) { - var spell *Spell + ss := GetStore() + ss.mu.RLock() + defer ss.mu.RUnlock() - query := `SELECT ` + spellColumns() + ` FROM spells WHERE LOWER(name) = LOWER(?) LIMIT 1` - - err := database.Query(query, func(stmt *sqlite.Stmt) error { - spell = scanSpell(stmt) - return nil - }, name) - - if err != nil { - return nil, fmt.Errorf("failed to find spell by name: %w", err) + id, exists := ss.byName[strings.ToLower(name)] + if !exists { + return nil, fmt.Errorf("spell with name '%s' not found", name) } - if spell == nil { + spell, exists := ss.GetByID(id) + if !exists { return nil, fmt.Errorf("spell with name '%s' not found", name) } return spell, nil } -// Saves a new spell to the database and sets the ID +// Saves a new spell to the in-memory store and sets the ID func (s *Spell) Insert() error { - columns := `name, mp, attribute, type` - values := []any{s.Name, s.MP, s.Attribute, s.Type} - return database.Insert(s, columns, values...) + ss := GetStore() + + // Validate before insertion + if err := s.Validate(); err != nil { + return fmt.Errorf("validation failed: %w", err) + } + + // Assign new ID if not set + if s.ID == 0 { + s.ID = ss.GetNextID() + } + + // Add to store + ss.AddSpell(s) + return nil } // Returns true if the spell is a healing spell diff --git a/internal/store/store.go b/internal/store/store.go new file mode 100644 index 0000000..1120a0a --- /dev/null +++ b/internal/store/store.go @@ -0,0 +1,198 @@ +package store + +import ( + "encoding/json" + "fmt" + "os" + "path/filepath" + "reflect" + "sync" +) + +// Store provides generic storage operations +type Store[T any] interface { + LoadFromJSON(filename string) error + SaveToJSON(filename string) error + LoadData(dataPath string) error + SaveData(dataPath string) error +} + +// BaseStore provides generic JSON persistence +type BaseStore[T any] struct { + items map[int]*T + maxID int + mu sync.RWMutex + itemType reflect.Type +} + +// NewBaseStore creates a new base store for type T +func NewBaseStore[T any]() *BaseStore[T] { + var zero T + return &BaseStore[T]{ + items: make(map[int]*T), + maxID: 0, + itemType: reflect.TypeOf(zero), + } +} + +// GetNextID returns the next available ID atomically +func (bs *BaseStore[T]) GetNextID() int { + bs.mu.Lock() + defer bs.mu.Unlock() + bs.maxID++ + return bs.maxID +} + +// GetByID retrieves an item by ID +func (bs *BaseStore[T]) GetByID(id int) (*T, bool) { + bs.mu.RLock() + defer bs.mu.RUnlock() + item, exists := bs.items[id] + return item, exists +} + +// Add adds an item to the store +func (bs *BaseStore[T]) Add(id int, item *T) { + bs.mu.Lock() + defer bs.mu.Unlock() + bs.items[id] = item + if id > bs.maxID { + bs.maxID = id + } +} + +// Remove removes an item from the store +func (bs *BaseStore[T]) Remove(id int) { + bs.mu.Lock() + defer bs.mu.Unlock() + delete(bs.items, id) +} + +// GetAll returns all items +func (bs *BaseStore[T]) GetAll() map[int]*T { + bs.mu.RLock() + defer bs.mu.RUnlock() + result := make(map[int]*T, len(bs.items)) + for k, v := range bs.items { + result[k] = v + } + return result +} + +// Clear removes all items +func (bs *BaseStore[T]) Clear() { + bs.mu.Lock() + defer bs.mu.Unlock() + bs.items = make(map[int]*T) + bs.maxID = 0 +} + +// LoadFromJSON loads items from JSON using reflection +func (bs *BaseStore[T]) LoadFromJSON(filename string) error { + bs.mu.Lock() + defer bs.mu.Unlock() + + data, err := os.ReadFile(filename) + if err != nil { + if os.IsNotExist(err) { + return nil + } + return fmt.Errorf("failed to read JSON: %w", err) + } + + if len(data) == 0 { + return nil + } + + // Create slice of pointers to T + sliceType := reflect.SliceOf(reflect.PointerTo(bs.itemType)) + slicePtr := reflect.New(sliceType) + + if err := json.Unmarshal(data, slicePtr.Interface()); err != nil { + return fmt.Errorf("failed to unmarshal JSON: %w", err) + } + + // Clear existing data + bs.items = make(map[int]*T) + bs.maxID = 0 + + // Extract items using reflection + slice := slicePtr.Elem() + for i := 0; i < slice.Len(); i++ { + item := slice.Index(i).Interface().(*T) + + // Get ID using reflection + itemValue := reflect.ValueOf(item).Elem() + idField := itemValue.FieldByName("ID") + if !idField.IsValid() { + return fmt.Errorf("item type must have an ID field") + } + + id := int(idField.Int()) + bs.items[id] = item + if id > bs.maxID { + bs.maxID = id + } + } + + return nil +} + +// SaveToJSON saves items to JSON atomically +func (bs *BaseStore[T]) SaveToJSON(filename string) error { + bs.mu.RLock() + defer bs.mu.RUnlock() + + items := make([]*T, 0, len(bs.items)) + for _, item := range bs.items { + items = append(items, item) + } + + data, err := json.MarshalIndent(items, "", " ") + if err != nil { + return fmt.Errorf("failed to marshal to JSON: %w", err) + } + + // Atomic write + tempFile := filename + ".tmp" + if err := os.WriteFile(tempFile, data, 0644); err != nil { + return fmt.Errorf("failed to write temp JSON: %w", err) + } + + if err := os.Rename(tempFile, filename); err != nil { + os.Remove(tempFile) + return fmt.Errorf("failed to rename temp JSON: %w", err) + } + + return nil +} + +// LoadData loads from JSON file or starts empty +func (bs *BaseStore[T]) LoadData(dataPath string) error { + if err := bs.LoadFromJSON(dataPath); err != nil { + if os.IsNotExist(err) { + fmt.Println("No existing data found, starting with empty store") + return nil + } + return fmt.Errorf("failed to load from JSON: %w", err) + } + + fmt.Printf("Loaded %d items from JSON\n", len(bs.items)) + return nil +} + +// SaveData saves to JSON file +func (bs *BaseStore[T]) SaveData(dataPath string) error { + // Ensure directory exists + dataDir := filepath.Dir(dataPath) + if err := os.MkdirAll(dataDir, 0755); err != nil { + return fmt.Errorf("failed to create data directory: %w", err) + } + + if err := bs.SaveToJSON(dataPath); err != nil { + return fmt.Errorf("failed to save to JSON: %w", err) + } + + fmt.Printf("Saved %d items to JSON\n", len(bs.items)) + return nil +} diff --git a/internal/towns/towns.go b/internal/towns/towns.go index 97d76cb..3996b2c 100644 --- a/internal/towns/towns.go +++ b/internal/towns/towns.go @@ -1,53 +1,40 @@ package towns import ( + "dk/internal/store" "fmt" "math" "slices" + "sort" + "strconv" + "strings" + "sync" - "dk/internal/database" "dk/internal/helpers" - "dk/internal/helpers/scanner" - - "zombiezen.com/go/sqlite" ) -// Town represents a town in the database +// Town represents a town in the game type Town struct { - database.BaseModel - - ID int `db:"id" json:"id"` - Name string `db:"name" json:"name"` - X int `db:"x" json:"x"` - Y int `db:"y" json:"y"` - InnCost int `db:"inn_cost" json:"inn_cost"` - MapCost int `db:"map_cost" json:"map_cost"` - TPCost int `db:"tp_cost" json:"tp_cost"` - ShopList string `db:"shop_list" json:"shop_list"` -} - -func (t *Town) GetTableName() string { - return "towns" -} - -func (t *Town) GetID() int { - return t.ID -} - -func (t *Town) SetID(id int) { - t.ID = id -} - -func (t *Town) Set(field string, value any) error { - return database.Set(t, field, value) + ID int `json:"id"` + Name string `json:"name"` + X int `json:"x"` + Y int `json:"y"` + InnCost int `json:"inn_cost"` + MapCost int `json:"map_cost"` + TPCost int `json:"tp_cost"` + ShopList string `json:"shop_list"` } func (t *Town) Save() error { - return database.Save(t) + townStore := GetStore() + townStore.UpdateTown(t) + return nil } func (t *Town) Delete() error { - return database.Delete(t) + townStore := GetStore() + townStore.RemoveTown(t.ID) + return nil } // Creates a new Town with sensible defaults @@ -63,77 +50,212 @@ func New() *Town { } } -var townScanner = scanner.New[Town]() - -// Returns the column list for town queries -func townColumns() string { - return townScanner.Columns() +// Validate checks if town has valid values +func (t *Town) Validate() error { + if t.Name == "" { + return fmt.Errorf("town name cannot be empty") + } + if t.InnCost < 0 { + return fmt.Errorf("town InnCost cannot be negative") + } + if t.MapCost < 0 { + return fmt.Errorf("town MapCost cannot be negative") + } + if t.TPCost < 0 { + return fmt.Errorf("town TPCost cannot be negative") + } + return nil } -// Populates a Town struct using the fast scanner -func scanTown(stmt *sqlite.Stmt) *Town { - town := &Town{} - townScanner.Scan(stmt, town) - return town +// TownStore provides in-memory storage with O(1) lookups and town-specific indices +type TownStore struct { + *store.BaseStore[Town] // Embedded generic store + byName map[string]int // Name (lowercase) -> ID + byCoords map[string]int // "x,y" -> ID + byInnCost map[int][]int // InnCost -> []ID + byTPCost map[int][]int // TPCost -> []ID + allByID []int // All IDs sorted by ID + mu sync.RWMutex // Protects indices +} + +// Global in-memory store +var townStore *TownStore +var storeOnce sync.Once + +// Initialize the in-memory store +func initStore() { + townStore = &TownStore{ + BaseStore: store.NewBaseStore[Town](), + byName: make(map[string]int), + byCoords: make(map[string]int), + byInnCost: make(map[int][]int), + byTPCost: make(map[int][]int), + allByID: make([]int, 0), + } +} + +// GetStore returns the global town store +func GetStore() *TownStore { + storeOnce.Do(initStore) + return townStore +} + +// AddTown adds a town to the in-memory store and updates all indices +func (ts *TownStore) AddTown(town *Town) { + ts.mu.Lock() + defer ts.mu.Unlock() + + // Validate town + if err := town.Validate(); err != nil { + return + } + + // Add to base store + ts.Add(town.ID, town) + + // Rebuild indices + ts.rebuildIndicesUnsafe() +} + +// RemoveTown removes a town from the store and updates indices +func (ts *TownStore) RemoveTown(id int) { + ts.mu.Lock() + defer ts.mu.Unlock() + + // Remove from base store + ts.Remove(id) + + // Rebuild indices + ts.rebuildIndicesUnsafe() +} + +// UpdateTown updates a town efficiently +func (ts *TownStore) UpdateTown(town *Town) { + ts.mu.Lock() + defer ts.mu.Unlock() + + // Validate town + if err := town.Validate(); err != nil { + return + } + + // Update base store + ts.Add(town.ID, town) + + // Rebuild indices + ts.rebuildIndicesUnsafe() +} + +// LoadData loads town data from JSON file, or starts with empty store +func LoadData(dataPath string) error { + ts := GetStore() + + // Load from base store, which handles JSON loading + if err := ts.BaseStore.LoadData(dataPath); err != nil { + return err + } + + // Rebuild indices from loaded data + ts.rebuildIndices() + return nil +} + +// SaveData saves town data to JSON file +func SaveData(dataPath string) error { + ts := GetStore() + return ts.BaseStore.SaveData(dataPath) +} + +// coordsKey creates a key for coordinate-based lookup +func coordsKey(x, y int) string { + return strconv.Itoa(x) + "," + strconv.Itoa(y) +} + +// rebuildIndicesUnsafe rebuilds all indices from base store data (caller must hold lock) +func (ts *TownStore) rebuildIndicesUnsafe() { + // Clear indices + ts.byName = make(map[string]int) + ts.byCoords = make(map[string]int) + ts.byInnCost = make(map[int][]int) + ts.byTPCost = make(map[int][]int) + ts.allByID = make([]int, 0) + + // Collect all towns and build indices + allTowns := ts.GetAll() + + for id, town := range allTowns { + // Name index (case-insensitive) + ts.byName[strings.ToLower(town.Name)] = id + + // Coordinates index + ts.byCoords[coordsKey(town.X, town.Y)] = id + + // Cost indices + ts.byInnCost[town.InnCost] = append(ts.byInnCost[town.InnCost], id) + ts.byTPCost[town.TPCost] = append(ts.byTPCost[town.TPCost], id) + + // All IDs + ts.allByID = append(ts.allByID, id) + } + + // Sort all by ID + sort.Ints(ts.allByID) + + // Sort cost indices by ID + for innCost := range ts.byInnCost { + sort.Ints(ts.byInnCost[innCost]) + } + + for tpCost := range ts.byTPCost { + sort.Ints(ts.byTPCost[tpCost]) + } +} + +// rebuildIndices rebuilds all town-specific indices from base store data +func (ts *TownStore) rebuildIndices() { + ts.mu.Lock() + defer ts.mu.Unlock() + ts.rebuildIndicesUnsafe() } // Retrieves a town by ID func Find(id int) (*Town, error) { - var town *Town - - query := `SELECT ` + townColumns() + ` FROM towns WHERE id = ?` - - err := database.Query(query, func(stmt *sqlite.Stmt) error { - town = scanTown(stmt) - return nil - }, id) - - if err != nil { - return nil, fmt.Errorf("failed to find town: %w", err) - } - - if town == nil { + ts := GetStore() + town, exists := ts.GetByID(id) + if !exists { return nil, fmt.Errorf("town with ID %d not found", id) } - return town, nil } // Retrieves all towns func All() ([]*Town, error) { - var towns []*Town + ts := GetStore() + ts.mu.RLock() + defer ts.mu.RUnlock() - query := `SELECT ` + townColumns() + ` FROM towns ORDER BY id` - - err := database.Query(query, func(stmt *sqlite.Stmt) error { - town := scanTown(stmt) - towns = append(towns, town) - return nil - }) - - if err != nil { - return nil, fmt.Errorf("failed to retrieve all towns: %w", err) + result := make([]*Town, 0, len(ts.allByID)) + for _, id := range ts.allByID { + if town, exists := ts.GetByID(id); exists { + result = append(result, town) + } } - - return towns, nil + return result, nil } // Retrieves a town by name (case-insensitive) func ByName(name string) (*Town, error) { - var town *Town + ts := GetStore() + ts.mu.RLock() + defer ts.mu.RUnlock() - query := `SELECT ` + townColumns() + ` FROM towns WHERE LOWER(name) = LOWER(?) LIMIT 1` - - err := database.Query(query, func(stmt *sqlite.Stmt) error { - town = scanTown(stmt) - return nil - }, name) - - if err != nil { - return nil, fmt.Errorf("failed to find town by name: %w", err) + id, exists := ts.byName[strings.ToLower(name)] + if !exists { + return nil, fmt.Errorf("town with name '%s' not found", name) } - if town == nil { + town, exists := ts.GetByID(id) + if !exists { return nil, fmt.Errorf("town with name '%s' not found", name) } @@ -142,55 +264,56 @@ func ByName(name string) (*Town, error) { // Retrieves towns with inn cost at most the specified amount func ByMaxInnCost(maxCost int) ([]*Town, error) { - var towns []*Town + ts := GetStore() + ts.mu.RLock() + defer ts.mu.RUnlock() - query := `SELECT ` + townColumns() + ` FROM towns WHERE inn_cost <= ? ORDER BY inn_cost, id` - - err := database.Query(query, func(stmt *sqlite.Stmt) error { - town := scanTown(stmt) - towns = append(towns, town) - return nil - }, maxCost) - - if err != nil { - return nil, fmt.Errorf("failed to retrieve towns by max inn cost: %w", err) + var result []*Town + for cost := 0; cost <= maxCost; cost++ { + if ids, exists := ts.byInnCost[cost]; exists { + for _, id := range ids { + if town, exists := ts.GetByID(id); exists { + result = append(result, town) + } + } + } } - - return towns, nil + return result, nil } // Retrieves towns with teleport cost at most the specified amount func ByMaxTPCost(maxCost int) ([]*Town, error) { - var towns []*Town + ts := GetStore() + ts.mu.RLock() + defer ts.mu.RUnlock() - query := `SELECT ` + townColumns() + ` FROM towns WHERE tp_cost <= ? ORDER BY tp_cost, id` - - err := database.Query(query, func(stmt *sqlite.Stmt) error { - town := scanTown(stmt) - towns = append(towns, town) - return nil - }, maxCost) - - if err != nil { - return nil, fmt.Errorf("failed to retrieve towns by max TP cost: %w", err) + var result []*Town + for cost := 0; cost <= maxCost; cost++ { + if ids, exists := ts.byTPCost[cost]; exists { + for _, id := range ids { + if town, exists := ts.GetByID(id); exists { + result = append(result, town) + } + } + } } - - return towns, nil + return result, nil } // Retrieves a town by its x, y coordinates func ByCoords(x, y int) (*Town, error) { - var town *Town + ts := GetStore() + ts.mu.RLock() + defer ts.mu.RUnlock() - query := `SELECT ` + townColumns() + ` FROM towns WHERE x = ? AND y = ? LIMIT 1` + id, exists := ts.byCoords[coordsKey(x, y)] + if !exists { + return nil, nil // Return nil if not found (like original) + } - err := database.Query(query, func(stmt *sqlite.Stmt) error { - town = scanTown(stmt) - return nil - }, x, y) - - if err != nil { - return nil, fmt.Errorf("failed to retrieve town by coordinates: %w", err) + town, exists := ts.GetByID(id) + if !exists { + return nil, nil } return town, nil @@ -198,46 +321,61 @@ func ByCoords(x, y int) (*Town, error) { // ExistsAt checks for a town at the given coordinates, returning true/false func ExistsAt(x, y int) bool { - var exists bool + ts := GetStore() + ts.mu.RLock() + defer ts.mu.RUnlock() - query := `SELECT COUNT(*) > 0 FROM towns WHERE x = ? AND y = ? LIMIT 1` - - err := database.Query(query, func(stmt *sqlite.Stmt) error { - exists = stmt.ColumnInt(0) > 0 - return nil - }, x, y) - - return err == nil && exists + _, exists := ts.byCoords[coordsKey(x, y)] + return exists } // Retrieves towns within a certain distance from a point func ByDistance(fromX, fromY, maxDistance int) ([]*Town, error) { - var towns []*Town + ts := GetStore() + ts.mu.RLock() + defer ts.mu.RUnlock() - query := `SELECT ` + townColumns() + ` - FROM towns - WHERE ((x - ?) * (x - ?) + (y - ?) * (y - ?)) <= ? - ORDER BY ((x - ?) * (x - ?) + (y - ?) * (y - ?)), id` + var result []*Town + maxDistance2 := float64(maxDistance * maxDistance) - maxDistance2 := maxDistance * maxDistance - err := database.Query(query, func(stmt *sqlite.Stmt) error { - town := scanTown(stmt) - towns = append(towns, town) - return nil - }, fromX, fromX, fromY, fromY, maxDistance2, fromX, fromX, fromY, fromY) - - if err != nil { - return nil, fmt.Errorf("failed to retrieve towns by distance: %w", err) + for _, id := range ts.allByID { + if town, exists := ts.GetByID(id); exists { + if town.DistanceFromSquared(fromX, fromY) <= maxDistance2 { + result = append(result, town) + } + } } - return towns, nil + // Sort by distance, then by ID + sort.Slice(result, func(i, j int) bool { + distI := result[i].DistanceFromSquared(fromX, fromY) + distJ := result[j].DistanceFromSquared(fromX, fromY) + if distI == distJ { + return result[i].ID < result[j].ID + } + return distI < distJ + }) + + return result, nil } -// Saves a new town to the database and sets the ID +// Saves a new town to the in-memory store and sets the ID func (t *Town) Insert() error { - columns := `name, x, y, inn_cost, map_cost, tp_cost, shop_list` - values := []any{t.Name, t.X, t.Y, t.InnCost, t.MapCost, t.TPCost, t.ShopList} - return database.Insert(t, columns, values...) + ts := GetStore() + + // Validate before insertion + if err := t.Validate(); err != nil { + return fmt.Errorf("validation failed: %w", err) + } + + // Assign new ID if not set + if t.ID == 0 { + t.ID = ts.GetNextID() + } + + // Add to store + ts.AddTown(t) + return nil } // Returns the shop items as a slice of item IDs @@ -247,7 +385,7 @@ func (t *Town) GetShopItems() []int { // Sets the shop items from a slice of item IDs func (t *Town) SetShopItems(items []int) { - t.Set("ShopList", helpers.IntsToString(items)) + t.ShopList = helpers.IntsToString(items) } // Checks if the town's shop sells a specific item ID @@ -299,6 +437,6 @@ func (t *Town) GetPosition() (int, int) { // Sets the town's coordinates func (t *Town) SetPosition(x, y int) { - t.Set("X", x) - t.Set("Y", y) + t.X = x + t.Y = y } diff --git a/internal/users/users.go b/internal/users/users.go index 14d64fb..fdb6b04 100644 --- a/internal/users/users.go +++ b/internal/users/users.go @@ -1,95 +1,81 @@ package users import ( + "dk/internal/store" "fmt" "slices" + "sort" + "strings" + "sync" "time" - "dk/internal/database" "dk/internal/helpers" - "dk/internal/helpers/scanner" - - "zombiezen.com/go/sqlite" ) -// User represents a user in the database +// User represents a user in the game type User struct { - database.BaseModel - - ID int `db:"id" json:"id"` - Username string `db:"username" json:"username"` - Password string `db:"password" json:"password"` - Email string `db:"email" json:"email"` - Verified int `db:"verified" json:"verified"` - Token string `db:"token" json:"token"` - Registered int64 `db:"registered" json:"registered"` - LastOnline int64 `db:"last_online" json:"last_online"` - Auth int `db:"auth" json:"auth"` - X int `db:"x" json:"x"` - Y int `db:"y" json:"y"` - ClassID int `db:"class_id" json:"class_id"` - Currently string `db:"currently" json:"currently"` - Fighting int `db:"fighting" json:"fighting"` - MonsterID int `db:"monster_id" json:"monster_id"` - MonsterHP int `db:"monster_hp" json:"monster_hp"` - MonsterSleep int `db:"monster_sleep" json:"monster_sleep"` - MonsterImmune int `db:"monster_immune" json:"monster_immune"` - UberDamage int `db:"uber_damage" json:"uber_damage"` - UberDefense int `db:"uber_defense" json:"uber_defense"` - HP int `db:"hp" json:"hp"` - MP int `db:"mp" json:"mp"` - TP int `db:"tp" json:"tp"` - MaxHP int `db:"max_hp" json:"max_hp"` - MaxMP int `db:"max_mp" json:"max_mp"` - MaxTP int `db:"max_tp" json:"max_tp"` - Level int `db:"level" json:"level"` - Gold int `db:"gold" json:"gold"` - Exp int `db:"exp" json:"exp"` - GoldBonus int `db:"gold_bonus" json:"gold_bonus"` - ExpBonus int `db:"exp_bonus" json:"exp_bonus"` - Strength int `db:"strength" json:"strength"` - Dexterity int `db:"dexterity" json:"dexterity"` - Attack int `db:"attack" json:"attack"` - Defense int `db:"defense" json:"defense"` - WeaponID int `db:"weapon_id" json:"weapon_id"` - ArmorID int `db:"armor_id" json:"armor_id"` - ShieldID int `db:"shield_id" json:"shield_id"` - Slot1ID int `db:"slot_1_id" json:"slot_1_id"` - Slot2ID int `db:"slot_2_id" json:"slot_2_id"` - Slot3ID int `db:"slot_3_id" json:"slot_3_id"` - WeaponName string `db:"weapon_name" json:"weapon_name"` - ArmorName string `db:"armor_name" json:"armor_name"` - ShieldName string `db:"shield_name" json:"shield_name"` - Slot1Name string `db:"slot_1_name" json:"slot_1_name"` - Slot2Name string `db:"slot_2_name" json:"slot_2_name"` - Slot3Name string `db:"slot_3_name" json:"slot_3_name"` - DropCode int `db:"drop_code" json:"drop_code"` - Spells string `db:"spells" json:"spells"` - Towns string `db:"towns" json:"towns"` -} - -func (u *User) GetTableName() string { - return "users" -} - -func (u *User) GetID() int { - return u.ID -} - -func (u *User) SetID(id int) { - u.ID = id -} - -func (u *User) Set(field string, value any) error { - return database.Set(u, field, value) + ID int `json:"id"` + Username string `json:"username"` + Password string `json:"password"` + Email string `json:"email"` + Verified int `json:"verified"` + Token string `json:"token"` + Registered int64 `json:"registered"` + LastOnline int64 `json:"last_online"` + Auth int `json:"auth"` + X int `json:"x"` + Y int `json:"y"` + ClassID int `json:"class_id"` + Currently string `json:"currently"` + Fighting int `json:"fighting"` + MonsterID int `json:"monster_id"` + MonsterHP int `json:"monster_hp"` + MonsterSleep int `json:"monster_sleep"` + MonsterImmune int `json:"monster_immune"` + UberDamage int `json:"uber_damage"` + UberDefense int `json:"uber_defense"` + HP int `json:"hp"` + MP int `json:"mp"` + TP int `json:"tp"` + MaxHP int `json:"max_hp"` + MaxMP int `json:"max_mp"` + MaxTP int `json:"max_tp"` + Level int `json:"level"` + Gold int `json:"gold"` + Exp int `json:"exp"` + GoldBonus int `json:"gold_bonus"` + ExpBonus int `json:"exp_bonus"` + Strength int `json:"strength"` + Dexterity int `json:"dexterity"` + Attack int `json:"attack"` + Defense int `json:"defense"` + WeaponID int `json:"weapon_id"` + ArmorID int `json:"armor_id"` + ShieldID int `json:"shield_id"` + Slot1ID int `json:"slot_1_id"` + Slot2ID int `json:"slot_2_id"` + Slot3ID int `json:"slot_3_id"` + WeaponName string `json:"weapon_name"` + ArmorName string `json:"armor_name"` + ShieldName string `json:"shield_name"` + Slot1Name string `json:"slot_1_name"` + Slot2Name string `json:"slot_2_name"` + Slot3Name string `json:"slot_3_name"` + DropCode int `json:"drop_code"` + Spells string `json:"spells"` + Towns string `json:"towns"` } func (u *User) Save() error { - return database.Save(u) + userStore := GetStore() + userStore.UpdateUser(u) + return nil } func (u *User) Delete() error { - return database.Delete(u) + userStore := GetStore() + userStore.RemoveUser(u.ID) + return nil } func New() *User { @@ -123,127 +109,303 @@ func New() *User { } } -var userScanner = scanner.New[User]() - -func userColumns() string { - return userScanner.Columns() +// Validate checks if user has valid values +func (u *User) Validate() error { + if strings.TrimSpace(u.Username) == "" { + return fmt.Errorf("user username cannot be empty") + } + if strings.TrimSpace(u.Email) == "" { + return fmt.Errorf("user email cannot be empty") + } + if u.Registered <= 0 { + return fmt.Errorf("user Registered timestamp must be positive") + } + if u.LastOnline <= 0 { + return fmt.Errorf("user LastOnline timestamp must be positive") + } + if u.Level < 1 { + return fmt.Errorf("user Level must be at least 1") + } + if u.HP < 0 { + return fmt.Errorf("user HP cannot be negative") + } + if u.MaxHP < 1 { + return fmt.Errorf("user MaxHP must be at least 1") + } + return nil } -func scanUser(stmt *sqlite.Stmt) *User { - user := &User{} - userScanner.Scan(stmt, user) - return user +// UserStore provides in-memory storage with O(1) lookups and user-specific indices +type UserStore struct { + *store.BaseStore[User] // Embedded generic store + byUsername map[string]int // Username (lowercase) -> ID + byEmail map[string]int // Email -> ID + byLevel map[int][]int // Level -> []ID + allByRegistered []int // All IDs sorted by registered DESC, id DESC + mu sync.RWMutex // Protects indices +} + +// Global in-memory store +var userStore *UserStore +var storeOnce sync.Once + +// Initialize the in-memory store +func initStore() { + userStore = &UserStore{ + BaseStore: store.NewBaseStore[User](), + byUsername: make(map[string]int), + byEmail: make(map[string]int), + byLevel: make(map[int][]int), + allByRegistered: make([]int, 0), + } +} + +// GetStore returns the global user store +func GetStore() *UserStore { + storeOnce.Do(initStore) + return userStore +} + +// AddUser adds a user to the in-memory store and updates all indices +func (us *UserStore) AddUser(user *User) { + us.mu.Lock() + defer us.mu.Unlock() + + // Validate user + if err := user.Validate(); err != nil { + return + } + + // Add to base store + us.Add(user.ID, user) + + // Rebuild indices + us.rebuildIndicesUnsafe() +} + +// RemoveUser removes a user from the store and updates indices +func (us *UserStore) RemoveUser(id int) { + us.mu.Lock() + defer us.mu.Unlock() + + // Remove from base store + us.Remove(id) + + // Rebuild indices + us.rebuildIndicesUnsafe() +} + +// UpdateUser updates a user efficiently +func (us *UserStore) UpdateUser(user *User) { + us.mu.Lock() + defer us.mu.Unlock() + + // Validate user + if err := user.Validate(); err != nil { + return + } + + // Update base store + us.Add(user.ID, user) + + // Rebuild indices + us.rebuildIndicesUnsafe() +} + +// LoadData loads user data from JSON file, or starts with empty store +func LoadData(dataPath string) error { + us := GetStore() + + // Load from base store, which handles JSON loading + if err := us.BaseStore.LoadData(dataPath); err != nil { + return err + } + + // Rebuild indices from loaded data + us.rebuildIndices() + return nil +} + +// SaveData saves user data to JSON file +func SaveData(dataPath string) error { + us := GetStore() + return us.BaseStore.SaveData(dataPath) +} + +// rebuildIndicesUnsafe rebuilds all indices from base store data (caller must hold lock) +func (us *UserStore) rebuildIndicesUnsafe() { + // Clear indices + us.byUsername = make(map[string]int) + us.byEmail = make(map[string]int) + us.byLevel = make(map[int][]int) + us.allByRegistered = make([]int, 0) + + // Collect all users and build indices + allUsers := us.GetAll() + + for id, user := range allUsers { + // Username index (case-insensitive) + us.byUsername[strings.ToLower(user.Username)] = id + + // Email index + us.byEmail[user.Email] = id + + // Level index + us.byLevel[user.Level] = append(us.byLevel[user.Level], id) + + // All IDs + us.allByRegistered = append(us.allByRegistered, id) + } + + // Sort allByRegistered by registered DESC, then ID DESC + sort.Slice(us.allByRegistered, func(i, j int) bool { + userI, _ := us.GetByID(us.allByRegistered[i]) + userJ, _ := us.GetByID(us.allByRegistered[j]) + if userI.Registered != userJ.Registered { + return userI.Registered > userJ.Registered // DESC + } + return us.allByRegistered[i] > us.allByRegistered[j] // DESC + }) + + // Sort level indices by exp DESC, then ID ASC + for level := range us.byLevel { + sort.Slice(us.byLevel[level], func(i, j int) bool { + userI, _ := us.GetByID(us.byLevel[level][i]) + userJ, _ := us.GetByID(us.byLevel[level][j]) + if userI.Exp != userJ.Exp { + return userI.Exp > userJ.Exp // DESC + } + return us.byLevel[level][i] < us.byLevel[level][j] // ASC + }) + } +} + +// rebuildIndices rebuilds all user-specific indices from base store data +func (us *UserStore) rebuildIndices() { + us.mu.Lock() + defer us.mu.Unlock() + us.rebuildIndicesUnsafe() } func Find(id int) (*User, error) { - var user *User - query := `SELECT ` + userColumns() + ` FROM users WHERE id = ?` - err := database.Query(query, func(stmt *sqlite.Stmt) error { - user = scanUser(stmt) - return nil - }, id) - if err != nil { - return nil, fmt.Errorf("failed to find user: %w", err) - } - if user == nil { + us := GetStore() + user, exists := us.GetByID(id) + if !exists { return nil, fmt.Errorf("user with ID %d not found", id) } return user, nil } func All() ([]*User, error) { - var users []*User - query := `SELECT ` + userColumns() + ` FROM users ORDER BY registered DESC, id DESC` - err := database.Query(query, func(stmt *sqlite.Stmt) error { - user := scanUser(stmt) - users = append(users, user) - return nil - }) - if err != nil { - return nil, fmt.Errorf("failed to retrieve all users: %w", err) + us := GetStore() + us.mu.RLock() + defer us.mu.RUnlock() + + result := make([]*User, 0, len(us.allByRegistered)) + for _, id := range us.allByRegistered { + if user, exists := us.GetByID(id); exists { + result = append(result, user) + } } - return users, nil + return result, nil } func ByUsername(username string) (*User, error) { - var user *User - query := `SELECT ` + userColumns() + ` FROM users WHERE LOWER(username) = LOWER(?) LIMIT 1` - err := database.Query(query, func(stmt *sqlite.Stmt) error { - user = scanUser(stmt) - return nil - }, username) - if err != nil { - return nil, fmt.Errorf("failed to find user by username: %w", err) - } - if user == nil { + us := GetStore() + us.mu.RLock() + defer us.mu.RUnlock() + + id, exists := us.byUsername[strings.ToLower(username)] + if !exists { return nil, fmt.Errorf("user with username '%s' not found", username) } + + user, exists := us.GetByID(id) + if !exists { + return nil, fmt.Errorf("user with username '%s' not found", username) + } + return user, nil } func ByEmail(email string) (*User, error) { - var user *User - query := `SELECT ` + userColumns() + ` FROM users WHERE email = ? LIMIT 1` - err := database.Query(query, func(stmt *sqlite.Stmt) error { - user = scanUser(stmt) - return nil - }, email) - if err != nil { - return nil, fmt.Errorf("failed to find user by email: %w", err) - } - if user == nil { + us := GetStore() + us.mu.RLock() + defer us.mu.RUnlock() + + id, exists := us.byEmail[email] + if !exists { return nil, fmt.Errorf("user with email '%s' not found", email) } + + user, exists := us.GetByID(id) + if !exists { + return nil, fmt.Errorf("user with email '%s' not found", email) + } + return user, nil } func ByLevel(level int) ([]*User, error) { - var users []*User - query := `SELECT ` + userColumns() + ` FROM users WHERE level = ? ORDER BY exp DESC, id ASC` - err := database.Query(query, func(stmt *sqlite.Stmt) error { - user := scanUser(stmt) - users = append(users, user) - return nil - }, level) - if err != nil { - return nil, fmt.Errorf("failed to retrieve users by level: %w", err) + us := GetStore() + us.mu.RLock() + defer us.mu.RUnlock() + + ids, exists := us.byLevel[level] + if !exists { + return []*User{}, nil } - return users, nil + + result := make([]*User, 0, len(ids)) + for _, id := range ids { + if user, exists := us.GetByID(id); exists { + result = append(result, user) + } + } + return result, nil } func Online(within time.Duration) ([]*User, error) { - var users []*User + us := GetStore() + us.mu.RLock() + defer us.mu.RUnlock() + cutoff := time.Now().Add(-within).Unix() - query := `SELECT ` + userColumns() + ` FROM users WHERE last_online >= ? ORDER BY last_online DESC, id ASC` - err := database.Query(query, func(stmt *sqlite.Stmt) error { - user := scanUser(stmt) - users = append(users, user) - return nil - }, cutoff) - if err != nil { - return nil, fmt.Errorf("failed to retrieve online users: %w", err) + var result []*User + + for _, id := range us.allByRegistered { + if user, exists := us.GetByID(id); exists && user.LastOnline >= cutoff { + result = append(result, user) + } } - return users, nil + + // Sort by last_online DESC, then ID ASC + sort.Slice(result, func(i, j int) bool { + if result[i].LastOnline != result[j].LastOnline { + return result[i].LastOnline > result[j].LastOnline // DESC + } + return result[i].ID < result[j].ID // ASC + }) + + return result, nil } func (u *User) Insert() error { - columns := `username, password, email, verified, token, registered, last_online, auth, - x, y, class_id, currently, fighting, monster_id, monster_hp, monster_sleep, monster_immune, - uber_damage, uber_defense, hp, mp, tp, max_hp, max_mp, max_tp, level, gold, exp, - gold_bonus, exp_bonus, strength, dexterity, attack, defense, weapon_id, armor_id, shield_id, - slot_1_id, slot_2_id, slot_3_id, weapon_name, armor_name, shield_name, - slot_1_name, slot_2_name, slot_3_name, drop_code, spells, towns` + us := GetStore() - values := []any{u.Username, u.Password, u.Email, u.Verified, u.Token, - u.Registered, u.LastOnline, u.Auth, u.X, u.Y, u.ClassID, u.Currently, - u.Fighting, u.MonsterID, u.MonsterHP, u.MonsterSleep, u.MonsterImmune, - u.UberDamage, u.UberDefense, u.HP, u.MP, u.TP, u.MaxHP, u.MaxMP, u.MaxTP, - u.Level, u.Gold, u.Exp, u.GoldBonus, u.ExpBonus, u.Strength, u.Dexterity, - u.Attack, u.Defense, u.WeaponID, u.ArmorID, u.ShieldID, u.Slot1ID, - u.Slot2ID, u.Slot3ID, u.WeaponName, u.ArmorName, u.ShieldName, - u.Slot1Name, u.Slot2Name, u.Slot3Name, u.DropCode, u.Spells, u.Towns} + // Validate before insertion + if err := u.Validate(); err != nil { + return fmt.Errorf("validation failed: %w", err) + } - return database.Insert(u, columns, values...) + // Assign new ID if not set + if u.ID == 0 { + u.ID = us.GetNextID() + } + + // Add to store + us.AddUser(u) + return nil } func (u *User) RegisteredTime() time.Time { @@ -255,7 +417,7 @@ func (u *User) LastOnlineTime() time.Time { } func (u *User) UpdateLastOnline() { - u.Set("LastOnline", time.Now().Unix()) + u.LastOnline = time.Now().Unix() } func (u *User) IsVerified() bool { @@ -283,7 +445,7 @@ func (u *User) GetSpellIDs() []int { } func (u *User) SetSpellIDs(spells []int) { - u.Set("Spells", helpers.IntsToString(spells)) + u.Spells = helpers.IntsToString(spells) } func (u *User) HasSpell(spellID int) bool { @@ -295,7 +457,7 @@ func (u *User) GetTownIDs() []int { } func (u *User) SetTownIDs(towns []int) { - u.Set("Towns", helpers.IntsToString(towns)) + u.Towns = helpers.IntsToString(towns) } func (u *User) HasTownMap(townID int) bool { @@ -336,6 +498,6 @@ func (u *User) GetPosition() (int, int) { } func (u *User) SetPosition(x, y int) { - u.Set("X", x) - u.Set("Y", y) + u.X = x + u.Y = y } diff --git a/main.go b/main.go index d6def0f..91ac617 100644 --- a/main.go +++ b/main.go @@ -6,24 +6,19 @@ import ( "log" "os" - "dk/internal/install" "dk/internal/server" ) func main() { var port string flag.StringVar(&port, "p", "3000", "Port to run server on") - + if len(os.Args) < 2 { startServer(port) return } switch os.Args[1] { - case "install": - if err := install.Run(); err != nil { - log.Fatalf("Installation failed: %v", err) - } case "serve": flag.CommandLine.Parse(os.Args[2:]) startServer(port) @@ -42,4 +37,4 @@ func startServer(port string) { if err := server.Start(port); err != nil { log.Fatalf("Server failed: %v", err) } -} \ No newline at end of file +} diff --git a/test_load_data.go b/test_load_data.go deleted file mode 100644 index ddc1355..0000000 --- a/test_load_data.go +++ /dev/null @@ -1,34 +0,0 @@ -package main - -import ( - "fmt" - "dk/internal/monsters" -) - -func main() { - fmt.Println("Testing LoadData() function...") - - err := monsters.LoadData() - if err != nil { - fmt.Printf("LoadData() failed: %v\n", err) - return - } - - // Test that we can find a monster - monster, err := monsters.Find(1) - if err != nil { - fmt.Printf("Find(1) failed: %v\n", err) - return - } - - fmt.Printf("Successfully loaded data! Found monster: %s (Level %d)\n", monster.Name, monster.Level) - - // Test getting all monsters - all, err := monsters.All() - if err != nil { - fmt.Printf("All() failed: %v\n", err) - return - } - - fmt.Printf("Total monsters loaded: %d\n", len(all)) -} \ No newline at end of file