migrate all models to in-memory
This commit is contained in:
parent
958a7098a2
commit
c2eeaa2f42
@ -1,48 +1,32 @@
|
||||
package babble
|
||||
|
||||
import (
|
||||
"dk/internal/store"
|
||||
"fmt"
|
||||
"sort"
|
||||
"strings"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"dk/internal/database"
|
||||
"dk/internal/helpers/scanner"
|
||||
|
||||
"zombiezen.com/go/sqlite"
|
||||
)
|
||||
|
||||
// Babble represents a global chat message in the database
|
||||
// Babble represents a global chat message in the game
|
||||
type Babble struct {
|
||||
database.BaseModel
|
||||
|
||||
ID int `db:"id" json:"id"`
|
||||
Posted int64 `db:"posted" json:"posted"`
|
||||
Author string `db:"author" json:"author"`
|
||||
Babble string `db:"babble" json:"babble"`
|
||||
}
|
||||
|
||||
func (b *Babble) GetTableName() string {
|
||||
return "babble"
|
||||
}
|
||||
|
||||
func (b *Babble) GetID() int {
|
||||
return b.ID
|
||||
}
|
||||
|
||||
func (b *Babble) SetID(id int) {
|
||||
b.ID = id
|
||||
}
|
||||
|
||||
func (b *Babble) Set(field string, value any) error {
|
||||
return database.Set(b, field, value)
|
||||
ID int `json:"id"`
|
||||
Posted int64 `json:"posted"`
|
||||
Author string `json:"author"`
|
||||
Babble string `json:"babble"`
|
||||
}
|
||||
|
||||
func (b *Babble) Save() error {
|
||||
return database.Save(b)
|
||||
babbleStore := GetStore()
|
||||
babbleStore.UpdateBabble(b)
|
||||
return nil
|
||||
}
|
||||
|
||||
func (b *Babble) Delete() error {
|
||||
return database.Delete(b)
|
||||
babbleStore := GetStore()
|
||||
babbleStore.RemoveBabble(b.ID)
|
||||
return nil
|
||||
}
|
||||
|
||||
// Creates a new Babble with sensible defaults
|
||||
@ -54,181 +38,315 @@ func New() *Babble {
|
||||
}
|
||||
}
|
||||
|
||||
var babbleScanner = scanner.New[Babble]()
|
||||
|
||||
// Returns the column list for babble queries
|
||||
func babbleColumns() string {
|
||||
return babbleScanner.Columns()
|
||||
// Validate checks if babble has valid values
|
||||
func (b *Babble) Validate() error {
|
||||
if b.Posted <= 0 {
|
||||
return fmt.Errorf("babble Posted timestamp must be positive")
|
||||
}
|
||||
if strings.TrimSpace(b.Author) == "" {
|
||||
return fmt.Errorf("babble Author cannot be empty")
|
||||
}
|
||||
if strings.TrimSpace(b.Babble) == "" {
|
||||
return fmt.Errorf("babble message cannot be empty")
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// Populates a Babble struct using the fast scanner
|
||||
func scanBabble(stmt *sqlite.Stmt) *Babble {
|
||||
babble := &Babble{}
|
||||
babbleScanner.Scan(stmt, babble)
|
||||
return babble
|
||||
// BabbleStore provides in-memory storage with O(1) lookups and babble-specific indices
|
||||
type BabbleStore struct {
|
||||
*store.BaseStore[Babble] // Embedded generic store
|
||||
byAuthor map[string][]int // Author (lowercase) -> []ID
|
||||
allByPosted []int // All IDs sorted by posted DESC, id DESC
|
||||
mu sync.RWMutex // Protects indices
|
||||
}
|
||||
|
||||
// Global in-memory store
|
||||
var babbleStore *BabbleStore
|
||||
var storeOnce sync.Once
|
||||
|
||||
// Initialize the in-memory store
|
||||
func initStore() {
|
||||
babbleStore = &BabbleStore{
|
||||
BaseStore: store.NewBaseStore[Babble](),
|
||||
byAuthor: make(map[string][]int),
|
||||
allByPosted: make([]int, 0),
|
||||
}
|
||||
}
|
||||
|
||||
// GetStore returns the global babble store
|
||||
func GetStore() *BabbleStore {
|
||||
storeOnce.Do(initStore)
|
||||
return babbleStore
|
||||
}
|
||||
|
||||
// AddBabble adds a babble message to the in-memory store and updates all indices
|
||||
func (bs *BabbleStore) AddBabble(babble *Babble) {
|
||||
bs.mu.Lock()
|
||||
defer bs.mu.Unlock()
|
||||
|
||||
// Validate babble
|
||||
if err := babble.Validate(); err != nil {
|
||||
return
|
||||
}
|
||||
|
||||
// Add to base store
|
||||
bs.Add(babble.ID, babble)
|
||||
|
||||
// Rebuild indices
|
||||
bs.rebuildIndicesUnsafe()
|
||||
}
|
||||
|
||||
// RemoveBabble removes a babble message from the store and updates indices
|
||||
func (bs *BabbleStore) RemoveBabble(id int) {
|
||||
bs.mu.Lock()
|
||||
defer bs.mu.Unlock()
|
||||
|
||||
// Remove from base store
|
||||
bs.Remove(id)
|
||||
|
||||
// Rebuild indices
|
||||
bs.rebuildIndicesUnsafe()
|
||||
}
|
||||
|
||||
// UpdateBabble updates a babble message efficiently
|
||||
func (bs *BabbleStore) UpdateBabble(babble *Babble) {
|
||||
bs.mu.Lock()
|
||||
defer bs.mu.Unlock()
|
||||
|
||||
// Validate babble
|
||||
if err := babble.Validate(); err != nil {
|
||||
return
|
||||
}
|
||||
|
||||
// Update base store
|
||||
bs.Add(babble.ID, babble)
|
||||
|
||||
// Rebuild indices
|
||||
bs.rebuildIndicesUnsafe()
|
||||
}
|
||||
|
||||
// LoadData loads babble data from JSON file, or starts with empty store
|
||||
func LoadData(dataPath string) error {
|
||||
bs := GetStore()
|
||||
|
||||
// Load from base store, which handles JSON loading
|
||||
if err := bs.BaseStore.LoadData(dataPath); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// Rebuild indices from loaded data
|
||||
bs.rebuildIndices()
|
||||
return nil
|
||||
}
|
||||
|
||||
// SaveData saves babble data to JSON file
|
||||
func SaveData(dataPath string) error {
|
||||
bs := GetStore()
|
||||
return bs.BaseStore.SaveData(dataPath)
|
||||
}
|
||||
|
||||
// rebuildIndicesUnsafe rebuilds all indices from base store data (caller must hold lock)
|
||||
func (bs *BabbleStore) rebuildIndicesUnsafe() {
|
||||
// Clear indices
|
||||
bs.byAuthor = make(map[string][]int)
|
||||
bs.allByPosted = make([]int, 0)
|
||||
|
||||
// Collect all babbles and build indices
|
||||
allBabbles := bs.GetAll()
|
||||
|
||||
for id, babble := range allBabbles {
|
||||
// Author index (case-insensitive)
|
||||
authorKey := strings.ToLower(babble.Author)
|
||||
bs.byAuthor[authorKey] = append(bs.byAuthor[authorKey], id)
|
||||
|
||||
// All IDs
|
||||
bs.allByPosted = append(bs.allByPosted, id)
|
||||
}
|
||||
|
||||
// Sort allByPosted by posted DESC, then ID DESC
|
||||
sort.Slice(bs.allByPosted, func(i, j int) bool {
|
||||
babbleI, _ := bs.GetByID(bs.allByPosted[i])
|
||||
babbleJ, _ := bs.GetByID(bs.allByPosted[j])
|
||||
if babbleI.Posted != babbleJ.Posted {
|
||||
return babbleI.Posted > babbleJ.Posted // DESC
|
||||
}
|
||||
return bs.allByPosted[i] > bs.allByPosted[j] // DESC
|
||||
})
|
||||
|
||||
// Sort author indices by posted DESC, then ID DESC
|
||||
for author := range bs.byAuthor {
|
||||
sort.Slice(bs.byAuthor[author], func(i, j int) bool {
|
||||
babbleI, _ := bs.GetByID(bs.byAuthor[author][i])
|
||||
babbleJ, _ := bs.GetByID(bs.byAuthor[author][j])
|
||||
if babbleI.Posted != babbleJ.Posted {
|
||||
return babbleI.Posted > babbleJ.Posted // DESC
|
||||
}
|
||||
return bs.byAuthor[author][i] > bs.byAuthor[author][j] // DESC
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// rebuildIndices rebuilds all babble-specific indices from base store data
|
||||
func (bs *BabbleStore) rebuildIndices() {
|
||||
bs.mu.Lock()
|
||||
defer bs.mu.Unlock()
|
||||
bs.rebuildIndicesUnsafe()
|
||||
}
|
||||
|
||||
// Retrieves a babble message by ID
|
||||
func Find(id int) (*Babble, error) {
|
||||
var babble *Babble
|
||||
|
||||
query := `SELECT ` + babbleColumns() + ` FROM babble WHERE id = ?`
|
||||
|
||||
err := database.Query(query, func(stmt *sqlite.Stmt) error {
|
||||
babble = scanBabble(stmt)
|
||||
return nil
|
||||
}, id)
|
||||
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to find babble: %w", err)
|
||||
}
|
||||
|
||||
if babble == nil {
|
||||
bs := GetStore()
|
||||
babble, exists := bs.GetByID(id)
|
||||
if !exists {
|
||||
return nil, fmt.Errorf("babble with ID %d not found", id)
|
||||
}
|
||||
|
||||
return babble, nil
|
||||
}
|
||||
|
||||
// Retrieves all babble messages ordered by posted time (newest first)
|
||||
func All() ([]*Babble, error) {
|
||||
var babbles []*Babble
|
||||
bs := GetStore()
|
||||
bs.mu.RLock()
|
||||
defer bs.mu.RUnlock()
|
||||
|
||||
query := `SELECT ` + babbleColumns() + ` FROM babble ORDER BY posted DESC, id DESC`
|
||||
|
||||
err := database.Query(query, func(stmt *sqlite.Stmt) error {
|
||||
babble := scanBabble(stmt)
|
||||
babbles = append(babbles, babble)
|
||||
return nil
|
||||
})
|
||||
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to retrieve all babble: %w", err)
|
||||
result := make([]*Babble, 0, len(bs.allByPosted))
|
||||
for _, id := range bs.allByPosted {
|
||||
if babble, exists := bs.GetByID(id); exists {
|
||||
result = append(result, babble)
|
||||
}
|
||||
}
|
||||
|
||||
return babbles, nil
|
||||
return result, nil
|
||||
}
|
||||
|
||||
// Retrieves babble messages by a specific author
|
||||
func ByAuthor(author string) ([]*Babble, error) {
|
||||
var babbles []*Babble
|
||||
bs := GetStore()
|
||||
bs.mu.RLock()
|
||||
defer bs.mu.RUnlock()
|
||||
|
||||
query := `SELECT ` + babbleColumns() + ` FROM babble WHERE LOWER(author) = LOWER(?) ORDER BY posted DESC, id DESC`
|
||||
|
||||
err := database.Query(query, func(stmt *sqlite.Stmt) error {
|
||||
babble := scanBabble(stmt)
|
||||
babbles = append(babbles, babble)
|
||||
return nil
|
||||
}, author)
|
||||
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to retrieve babble by author: %w", err)
|
||||
ids, exists := bs.byAuthor[strings.ToLower(author)]
|
||||
if !exists {
|
||||
return []*Babble{}, nil
|
||||
}
|
||||
|
||||
return babbles, nil
|
||||
result := make([]*Babble, 0, len(ids))
|
||||
for _, id := range ids {
|
||||
if babble, exists := bs.GetByID(id); exists {
|
||||
result = append(result, babble)
|
||||
}
|
||||
}
|
||||
return result, nil
|
||||
}
|
||||
|
||||
// Retrieves the most recent babble messages (limited by count)
|
||||
func Recent(limit int) ([]*Babble, error) {
|
||||
var babbles []*Babble
|
||||
bs := GetStore()
|
||||
bs.mu.RLock()
|
||||
defer bs.mu.RUnlock()
|
||||
|
||||
query := `SELECT ` + babbleColumns() + ` FROM babble ORDER BY posted DESC, id DESC LIMIT ?`
|
||||
|
||||
err := database.Query(query, func(stmt *sqlite.Stmt) error {
|
||||
babble := scanBabble(stmt)
|
||||
babbles = append(babbles, babble)
|
||||
return nil
|
||||
}, limit)
|
||||
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to retrieve recent babble: %w", err)
|
||||
if limit > len(bs.allByPosted) {
|
||||
limit = len(bs.allByPosted)
|
||||
}
|
||||
|
||||
return babbles, nil
|
||||
result := make([]*Babble, 0, limit)
|
||||
for i := 0; i < limit; i++ {
|
||||
if babble, exists := bs.GetByID(bs.allByPosted[i]); exists {
|
||||
result = append(result, babble)
|
||||
}
|
||||
}
|
||||
return result, nil
|
||||
}
|
||||
|
||||
// Retrieves babble messages since a specific timestamp
|
||||
func Since(since int64) ([]*Babble, error) {
|
||||
var babbles []*Babble
|
||||
bs := GetStore()
|
||||
bs.mu.RLock()
|
||||
defer bs.mu.RUnlock()
|
||||
|
||||
query := `SELECT ` + babbleColumns() + ` FROM babble WHERE posted >= ? ORDER BY posted DESC, id DESC`
|
||||
|
||||
err := database.Query(query, func(stmt *sqlite.Stmt) error {
|
||||
babble := scanBabble(stmt)
|
||||
babbles = append(babbles, babble)
|
||||
return nil
|
||||
}, since)
|
||||
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to retrieve babble since timestamp: %w", err)
|
||||
var result []*Babble
|
||||
for _, id := range bs.allByPosted {
|
||||
if babble, exists := bs.GetByID(id); exists && babble.Posted >= since {
|
||||
result = append(result, babble)
|
||||
}
|
||||
}
|
||||
|
||||
return babbles, nil
|
||||
return result, nil
|
||||
}
|
||||
|
||||
// Retrieves babble messages between two timestamps (inclusive)
|
||||
func Between(start, end int64) ([]*Babble, error) {
|
||||
var babbles []*Babble
|
||||
bs := GetStore()
|
||||
bs.mu.RLock()
|
||||
defer bs.mu.RUnlock()
|
||||
|
||||
query := `SELECT ` + babbleColumns() + ` FROM babble WHERE posted >= ? AND posted <= ? ORDER BY posted DESC, id DESC`
|
||||
|
||||
err := database.Query(query, func(stmt *sqlite.Stmt) error {
|
||||
babble := scanBabble(stmt)
|
||||
babbles = append(babbles, babble)
|
||||
return nil
|
||||
}, start, end)
|
||||
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to retrieve babble between timestamps: %w", err)
|
||||
var result []*Babble
|
||||
for _, id := range bs.allByPosted {
|
||||
if babble, exists := bs.GetByID(id); exists && babble.Posted >= start && babble.Posted <= end {
|
||||
result = append(result, babble)
|
||||
}
|
||||
}
|
||||
|
||||
return babbles, nil
|
||||
return result, nil
|
||||
}
|
||||
|
||||
// Retrieves babble messages containing the search term (case-insensitive)
|
||||
func Search(term string) ([]*Babble, error) {
|
||||
var babbles []*Babble
|
||||
bs := GetStore()
|
||||
bs.mu.RLock()
|
||||
defer bs.mu.RUnlock()
|
||||
|
||||
query := `SELECT ` + babbleColumns() + ` FROM babble WHERE LOWER(babble) LIKE LOWER(?) ORDER BY posted DESC, id DESC`
|
||||
searchTerm := "%" + term + "%"
|
||||
var result []*Babble
|
||||
lowerTerm := strings.ToLower(term)
|
||||
|
||||
err := database.Query(query, func(stmt *sqlite.Stmt) error {
|
||||
babble := scanBabble(stmt)
|
||||
babbles = append(babbles, babble)
|
||||
return nil
|
||||
}, searchTerm)
|
||||
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to search babble: %w", err)
|
||||
for _, id := range bs.allByPosted {
|
||||
if babble, exists := bs.GetByID(id); exists {
|
||||
if strings.Contains(strings.ToLower(babble.Babble), lowerTerm) {
|
||||
result = append(result, babble)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return babbles, nil
|
||||
return result, nil
|
||||
}
|
||||
|
||||
// Retrieves recent messages from a specific author
|
||||
func RecentByAuthor(author string, limit int) ([]*Babble, error) {
|
||||
var babbles []*Babble
|
||||
bs := GetStore()
|
||||
bs.mu.RLock()
|
||||
defer bs.mu.RUnlock()
|
||||
|
||||
query := `SELECT ` + babbleColumns() + ` FROM babble WHERE LOWER(author) = LOWER(?) ORDER BY posted DESC, id DESC LIMIT ?`
|
||||
|
||||
err := database.Query(query, func(stmt *sqlite.Stmt) error {
|
||||
babble := scanBabble(stmt)
|
||||
babbles = append(babbles, babble)
|
||||
return nil
|
||||
}, author, limit)
|
||||
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to retrieve recent babble by author: %w", err)
|
||||
ids, exists := bs.byAuthor[strings.ToLower(author)]
|
||||
if !exists {
|
||||
return []*Babble{}, nil
|
||||
}
|
||||
|
||||
return babbles, nil
|
||||
if limit > len(ids) {
|
||||
limit = len(ids)
|
||||
}
|
||||
|
||||
result := make([]*Babble, 0, limit)
|
||||
for i := 0; i < limit; i++ {
|
||||
if babble, exists := bs.GetByID(ids[i]); exists {
|
||||
result = append(result, babble)
|
||||
}
|
||||
}
|
||||
return result, nil
|
||||
}
|
||||
|
||||
// Saves a new babble to the database and sets the ID
|
||||
// Saves a new babble to the in-memory store and sets the ID
|
||||
func (b *Babble) Insert() error {
|
||||
columns := `posted, author, babble`
|
||||
values := []any{b.Posted, b.Author, b.Babble}
|
||||
return database.Insert(b, columns, values...)
|
||||
bs := GetStore()
|
||||
|
||||
// Validate before insertion
|
||||
if err := b.Validate(); err != nil {
|
||||
return fmt.Errorf("validation failed: %w", err)
|
||||
}
|
||||
|
||||
// Assign new ID if not set
|
||||
if b.ID == 0 {
|
||||
b.ID = bs.GetNextID()
|
||||
}
|
||||
|
||||
// Add to store
|
||||
bs.AddBabble(b)
|
||||
return nil
|
||||
}
|
||||
|
||||
// Returns the posted timestamp as a time.Time
|
||||
@ -238,7 +356,7 @@ func (b *Babble) PostedTime() time.Time {
|
||||
|
||||
// Sets the posted timestamp from a time.Time
|
||||
func (b *Babble) SetPostedTime(t time.Time) {
|
||||
b.Set("Posted", t.Unix())
|
||||
b.Posted = t.Unix()
|
||||
}
|
||||
|
||||
// Returns true if the babble message was posted within the last hour
|
||||
|
@ -1,49 +1,32 @@
|
||||
package control
|
||||
|
||||
import (
|
||||
"dk/internal/store"
|
||||
"fmt"
|
||||
|
||||
"dk/internal/database"
|
||||
"dk/internal/helpers/scanner"
|
||||
|
||||
"zombiezen.com/go/sqlite"
|
||||
"sync"
|
||||
)
|
||||
|
||||
// Control represents the game control settings in the database
|
||||
// Control represents the game control settings
|
||||
type Control struct {
|
||||
database.BaseModel
|
||||
|
||||
ID int `db:"id" json:"id"`
|
||||
WorldSize int `db:"world_size" json:"world_size"`
|
||||
Open int `db:"open" json:"open"`
|
||||
AdminEmail string `db:"admin_email" json:"admin_email"`
|
||||
Class1Name string `db:"class_1_name" json:"class_1_name"`
|
||||
Class2Name string `db:"class_2_name" json:"class_2_name"`
|
||||
Class3Name string `db:"class_3_name" json:"class_3_name"`
|
||||
}
|
||||
|
||||
func (c *Control) GetTableName() string {
|
||||
return "control"
|
||||
}
|
||||
|
||||
func (c *Control) GetID() int {
|
||||
return c.ID
|
||||
}
|
||||
|
||||
func (c *Control) SetID(id int) {
|
||||
c.ID = id
|
||||
}
|
||||
|
||||
func (c *Control) Set(field string, value any) error {
|
||||
return database.Set(c, field, value)
|
||||
ID int `json:"id"`
|
||||
WorldSize int `json:"world_size"`
|
||||
Open int `json:"open"`
|
||||
AdminEmail string `json:"admin_email"`
|
||||
Class1Name string `json:"class_1_name"`
|
||||
Class2Name string `json:"class_2_name"`
|
||||
Class3Name string `json:"class_3_name"`
|
||||
}
|
||||
|
||||
func (c *Control) Save() error {
|
||||
return database.Save(c)
|
||||
controlStore := GetStore()
|
||||
controlStore.UpdateControl(c)
|
||||
return nil
|
||||
}
|
||||
|
||||
func (c *Control) Delete() error {
|
||||
return database.Delete(c)
|
||||
controlStore := GetStore()
|
||||
controlStore.RemoveControl(c.ID)
|
||||
return nil
|
||||
}
|
||||
|
||||
// Creates a new Control with sensible defaults
|
||||
@ -58,39 +41,144 @@ func New() *Control {
|
||||
}
|
||||
}
|
||||
|
||||
var controlScanner = scanner.New[Control]()
|
||||
|
||||
// Returns the column list for control queries
|
||||
func controlColumns() string {
|
||||
return controlScanner.Columns()
|
||||
// Validate checks if control has valid values
|
||||
func (c *Control) Validate() error {
|
||||
if c.WorldSize <= 0 || c.WorldSize > 10000 {
|
||||
return fmt.Errorf("control WorldSize must be between 1 and 10000")
|
||||
}
|
||||
if c.Open != 0 && c.Open != 1 {
|
||||
return fmt.Errorf("control Open must be 0 or 1")
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// Populates a Control struct using the fast scanner
|
||||
func scanControl(stmt *sqlite.Stmt) *Control {
|
||||
control := &Control{}
|
||||
controlScanner.Scan(stmt, control)
|
||||
return control
|
||||
// ControlStore provides in-memory storage for control settings
|
||||
type ControlStore struct {
|
||||
*store.BaseStore[Control] // Embedded generic store
|
||||
allByID []int // All IDs sorted by ID
|
||||
mu sync.RWMutex // Protects indices
|
||||
}
|
||||
|
||||
// Global in-memory store
|
||||
var controlStore *ControlStore
|
||||
var storeOnce sync.Once
|
||||
|
||||
// Initialize the in-memory store
|
||||
func initStore() {
|
||||
controlStore = &ControlStore{
|
||||
BaseStore: store.NewBaseStore[Control](),
|
||||
allByID: make([]int, 0),
|
||||
}
|
||||
}
|
||||
|
||||
// GetStore returns the global control store
|
||||
func GetStore() *ControlStore {
|
||||
storeOnce.Do(initStore)
|
||||
return controlStore
|
||||
}
|
||||
|
||||
// AddControl adds a control to the in-memory store and updates all indices
|
||||
func (cs *ControlStore) AddControl(control *Control) {
|
||||
cs.mu.Lock()
|
||||
defer cs.mu.Unlock()
|
||||
|
||||
// Validate control
|
||||
if err := control.Validate(); err != nil {
|
||||
return
|
||||
}
|
||||
|
||||
// Add to base store
|
||||
cs.Add(control.ID, control)
|
||||
|
||||
// Rebuild indices
|
||||
cs.rebuildIndicesUnsafe()
|
||||
}
|
||||
|
||||
// RemoveControl removes a control from the store and updates indices
|
||||
func (cs *ControlStore) RemoveControl(id int) {
|
||||
cs.mu.Lock()
|
||||
defer cs.mu.Unlock()
|
||||
|
||||
// Remove from base store
|
||||
cs.Remove(id)
|
||||
|
||||
// Rebuild indices
|
||||
cs.rebuildIndicesUnsafe()
|
||||
}
|
||||
|
||||
// UpdateControl updates a control efficiently
|
||||
func (cs *ControlStore) UpdateControl(control *Control) {
|
||||
cs.mu.Lock()
|
||||
defer cs.mu.Unlock()
|
||||
|
||||
// Validate control
|
||||
if err := control.Validate(); err != nil {
|
||||
return
|
||||
}
|
||||
|
||||
// Update base store
|
||||
cs.Add(control.ID, control)
|
||||
|
||||
// Rebuild indices
|
||||
cs.rebuildIndicesUnsafe()
|
||||
}
|
||||
|
||||
// LoadData loads control data from JSON file, or starts with empty store
|
||||
func LoadData(dataPath string) error {
|
||||
cs := GetStore()
|
||||
|
||||
// Load from base store, which handles JSON loading
|
||||
if err := cs.BaseStore.LoadData(dataPath); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// Rebuild indices from loaded data
|
||||
cs.rebuildIndices()
|
||||
return nil
|
||||
}
|
||||
|
||||
// SaveData saves control data to JSON file
|
||||
func SaveData(dataPath string) error {
|
||||
cs := GetStore()
|
||||
return cs.BaseStore.SaveData(dataPath)
|
||||
}
|
||||
|
||||
// rebuildIndicesUnsafe rebuilds all indices from base store data (caller must hold lock)
|
||||
func (cs *ControlStore) rebuildIndicesUnsafe() {
|
||||
// Clear indices
|
||||
cs.allByID = make([]int, 0)
|
||||
|
||||
// Collect all controls
|
||||
allControls := cs.GetAll()
|
||||
|
||||
for id := range allControls {
|
||||
cs.allByID = append(cs.allByID, id)
|
||||
}
|
||||
|
||||
// Sort by ID (though typically only one control record exists)
|
||||
for i := 0; i < len(cs.allByID); i++ {
|
||||
for j := i + 1; j < len(cs.allByID); j++ {
|
||||
if cs.allByID[i] > cs.allByID[j] {
|
||||
cs.allByID[i], cs.allByID[j] = cs.allByID[j], cs.allByID[i]
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// rebuildIndices rebuilds all control-specific indices from base store data
|
||||
func (cs *ControlStore) rebuildIndices() {
|
||||
cs.mu.Lock()
|
||||
defer cs.mu.Unlock()
|
||||
cs.rebuildIndicesUnsafe()
|
||||
}
|
||||
|
||||
// Retrieves the control record by ID (typically only ID 1 exists)
|
||||
func Find(id int) (*Control, error) {
|
||||
var control *Control
|
||||
|
||||
query := `SELECT ` + controlColumns() + ` FROM control WHERE id = ?`
|
||||
|
||||
err := database.Query(query, func(stmt *sqlite.Stmt) error {
|
||||
control = scanControl(stmt)
|
||||
return nil
|
||||
}, id)
|
||||
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to find control: %w", err)
|
||||
}
|
||||
|
||||
if control == nil {
|
||||
cs := GetStore()
|
||||
control, exists := cs.GetByID(id)
|
||||
if !exists {
|
||||
return nil, fmt.Errorf("control with ID %d not found", id)
|
||||
}
|
||||
|
||||
return control, nil
|
||||
}
|
||||
|
||||
@ -99,11 +187,23 @@ func Get() (*Control, error) {
|
||||
return Find(1)
|
||||
}
|
||||
|
||||
// Saves a new control to the database and sets the ID
|
||||
// Saves a new control to the in-memory store and sets the ID
|
||||
func (c *Control) Insert() error {
|
||||
columns := `world_size, open, admin_email, class_1_name, class_2_name, class_3_name`
|
||||
values := []any{c.WorldSize, c.Open, c.AdminEmail, c.Class1Name, c.Class2Name, c.Class3Name}
|
||||
return database.Insert(c, columns, values...)
|
||||
cs := GetStore()
|
||||
|
||||
// Validate before insertion
|
||||
if err := c.Validate(); err != nil {
|
||||
return fmt.Errorf("validation failed: %w", err)
|
||||
}
|
||||
|
||||
// Assign new ID if not set
|
||||
if c.ID == 0 {
|
||||
c.ID = cs.GetNextID()
|
||||
}
|
||||
|
||||
// Add to store
|
||||
cs.AddControl(c)
|
||||
return nil
|
||||
}
|
||||
|
||||
// Returns true if the game world is open for new players
|
||||
@ -114,20 +214,20 @@ func (c *Control) IsOpen() bool {
|
||||
// Sets whether the game world is open for new players
|
||||
func (c *Control) SetOpen(open bool) {
|
||||
if open {
|
||||
c.Set("Open", 1)
|
||||
c.Open = 1
|
||||
} else {
|
||||
c.Set("Open", 0)
|
||||
c.Open = 0
|
||||
}
|
||||
}
|
||||
|
||||
// Closes the game world to new players
|
||||
func (c *Control) Close() {
|
||||
c.Set("Open", 0)
|
||||
c.Open = 0
|
||||
}
|
||||
|
||||
// Opens the game world to new players
|
||||
func (c *Control) OpenWorld() {
|
||||
c.Set("Open", 1)
|
||||
c.Open = 1
|
||||
}
|
||||
|
||||
// Returns all class names as a slice
|
||||
@ -148,19 +248,19 @@ func (c *Control) GetClassNames() []string {
|
||||
// Sets all class names from a slice
|
||||
func (c *Control) SetClassNames(classes []string) {
|
||||
// Reset all class names
|
||||
c.Set("Class1Name", "")
|
||||
c.Set("Class2Name", "")
|
||||
c.Set("Class3Name", "")
|
||||
c.Class1Name = ""
|
||||
c.Class2Name = ""
|
||||
c.Class3Name = ""
|
||||
|
||||
// Set provided class names
|
||||
if len(classes) > 0 {
|
||||
c.Set("Class1Name", classes[0])
|
||||
c.Class1Name = classes[0]
|
||||
}
|
||||
if len(classes) > 1 {
|
||||
c.Set("Class2Name", classes[1])
|
||||
c.Class2Name = classes[1]
|
||||
}
|
||||
if len(classes) > 2 {
|
||||
c.Set("Class3Name", classes[2])
|
||||
c.Class3Name = classes[2]
|
||||
}
|
||||
}
|
||||
|
||||
@ -182,13 +282,13 @@ func (c *Control) GetClassName(classNum int) string {
|
||||
func (c *Control) SetClassName(classNum int, name string) bool {
|
||||
switch classNum {
|
||||
case 1:
|
||||
c.Set("Class1Name", name)
|
||||
c.Class1Name = name
|
||||
return true
|
||||
case 2:
|
||||
c.Set("Class2Name", name)
|
||||
c.Class2Name = name
|
||||
return true
|
||||
case 3:
|
||||
c.Set("Class3Name", name)
|
||||
c.Class3Name = name
|
||||
return true
|
||||
default:
|
||||
return false
|
||||
|
@ -1,201 +0,0 @@
|
||||
package database
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"runtime"
|
||||
|
||||
"zombiezen.com/go/sqlite"
|
||||
"zombiezen.com/go/sqlite/sqlitex"
|
||||
)
|
||||
|
||||
const DefaultPath = "dk.db"
|
||||
|
||||
// Global singleton instance
|
||||
var pool *sqlitex.Pool
|
||||
|
||||
// Init initializes the global database connection pool
|
||||
func Init(path string) error {
|
||||
if path == "" {
|
||||
path = DefaultPath
|
||||
}
|
||||
|
||||
poolSize := max(runtime.GOMAXPROCS(0), 2)
|
||||
|
||||
var err error
|
||||
pool, err = sqlitex.NewPool(path, sqlitex.PoolOptions{
|
||||
PoolSize: poolSize,
|
||||
Flags: sqlite.OpenCreate | sqlite.OpenReadWrite | sqlite.OpenWAL,
|
||||
})
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to open database pool: %w", err)
|
||||
}
|
||||
|
||||
conn, err := pool.Take(context.Background())
|
||||
if err != nil {
|
||||
pool.Close()
|
||||
return fmt.Errorf("failed to get connection from pool: %w", err)
|
||||
}
|
||||
defer pool.Put(conn)
|
||||
|
||||
if err := sqlitex.ExecuteTransient(conn, "PRAGMA journal_mode = WAL", nil); err != nil {
|
||||
pool.Close()
|
||||
return fmt.Errorf("failed to set WAL mode: %w", err)
|
||||
}
|
||||
|
||||
if err := sqlitex.ExecuteTransient(conn, "PRAGMA synchronous = NORMAL", nil); err != nil {
|
||||
pool.Close()
|
||||
return fmt.Errorf("failed to set synchronous mode: %w", err)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// Close closes the global database connection pool
|
||||
func Close() error {
|
||||
if pool == nil {
|
||||
return nil
|
||||
}
|
||||
return pool.Close()
|
||||
}
|
||||
|
||||
// GetConn gets a connection from the pool - caller must call PutConn when done
|
||||
func GetConn(ctx context.Context) (*sqlite.Conn, error) {
|
||||
if pool == nil {
|
||||
return nil, fmt.Errorf("database not initialized")
|
||||
}
|
||||
return pool.Take(ctx)
|
||||
}
|
||||
|
||||
// PutConn returns a connection to the pool
|
||||
func PutConn(conn *sqlite.Conn) {
|
||||
if pool != nil {
|
||||
pool.Put(conn)
|
||||
}
|
||||
}
|
||||
|
||||
// Exec executes a SQL statement without returning results
|
||||
func Exec(query string, args ...any) error {
|
||||
if pool == nil {
|
||||
return fmt.Errorf("database not initialized")
|
||||
}
|
||||
|
||||
conn, err := pool.Take(context.Background())
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to get connection from pool: %w", err)
|
||||
}
|
||||
defer pool.Put(conn)
|
||||
|
||||
if len(args) == 0 {
|
||||
return sqlitex.ExecuteTransient(conn, query, nil)
|
||||
}
|
||||
|
||||
return sqlitex.ExecuteTransient(conn, query, &sqlitex.ExecOptions{
|
||||
Args: args,
|
||||
})
|
||||
}
|
||||
|
||||
// Query executes a SQL query and calls fn for each row
|
||||
func Query(query string, fn func(*sqlite.Stmt) error, args ...any) error {
|
||||
if pool == nil {
|
||||
return fmt.Errorf("database not initialized")
|
||||
}
|
||||
|
||||
conn, err := pool.Take(context.Background())
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to get connection from pool: %w", err)
|
||||
}
|
||||
defer pool.Put(conn)
|
||||
|
||||
if len(args) == 0 {
|
||||
return sqlitex.ExecuteTransient(conn, query, &sqlitex.ExecOptions{
|
||||
ResultFunc: fn,
|
||||
})
|
||||
}
|
||||
|
||||
return sqlitex.ExecuteTransient(conn, query, &sqlitex.ExecOptions{
|
||||
Args: args,
|
||||
ResultFunc: fn,
|
||||
})
|
||||
}
|
||||
|
||||
// Begin starts a new transaction
|
||||
func Begin() (*Tx, error) {
|
||||
if pool == nil {
|
||||
return nil, fmt.Errorf("database not initialized")
|
||||
}
|
||||
|
||||
conn, err := pool.Take(context.Background())
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to get connection from pool: %w", err)
|
||||
}
|
||||
|
||||
if err := sqlitex.ExecuteTransient(conn, "BEGIN", nil); err != nil {
|
||||
pool.Put(conn)
|
||||
return nil, fmt.Errorf("failed to begin transaction: %w", err)
|
||||
}
|
||||
|
||||
return &Tx{conn: conn, pool: pool}, nil
|
||||
}
|
||||
|
||||
// Transaction runs a function within a transaction
|
||||
func Transaction(fn func(*Tx) error) error {
|
||||
if pool == nil {
|
||||
return fmt.Errorf("database not initialized")
|
||||
}
|
||||
|
||||
tx, err := Begin()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if err := fn(tx); err != nil {
|
||||
tx.Rollback()
|
||||
return err
|
||||
}
|
||||
|
||||
return tx.Commit()
|
||||
}
|
||||
|
||||
// Tx represents a database transaction
|
||||
type Tx struct {
|
||||
conn *sqlite.Conn
|
||||
pool *sqlitex.Pool
|
||||
}
|
||||
|
||||
// Exec executes a SQL statement within the transaction
|
||||
func (tx *Tx) Exec(query string, args ...any) error {
|
||||
if len(args) == 0 {
|
||||
return sqlitex.ExecuteTransient(tx.conn, query, nil)
|
||||
}
|
||||
|
||||
return sqlitex.ExecuteTransient(tx.conn, query, &sqlitex.ExecOptions{
|
||||
Args: args,
|
||||
})
|
||||
}
|
||||
|
||||
// Query executes a SQL query within the transaction
|
||||
func (tx *Tx) Query(query string, fn func(*sqlite.Stmt) error, args ...any) error {
|
||||
if len(args) == 0 {
|
||||
return sqlitex.ExecuteTransient(tx.conn, query, &sqlitex.ExecOptions{
|
||||
ResultFunc: fn,
|
||||
})
|
||||
}
|
||||
|
||||
return sqlitex.ExecuteTransient(tx.conn, query, &sqlitex.ExecOptions{
|
||||
Args: args,
|
||||
ResultFunc: fn,
|
||||
})
|
||||
}
|
||||
|
||||
// Commit commits the transaction
|
||||
func (tx *Tx) Commit() error {
|
||||
defer tx.pool.Put(tx.conn)
|
||||
return sqlitex.ExecuteTransient(tx.conn, "COMMIT", nil)
|
||||
}
|
||||
|
||||
// Rollback rolls back the transaction
|
||||
func (tx *Tx) Rollback() error {
|
||||
defer tx.pool.Put(tx.conn)
|
||||
return sqlitex.ExecuteTransient(tx.conn, "ROLLBACK", nil)
|
||||
}
|
@ -1,72 +0,0 @@
|
||||
package database
|
||||
|
||||
import (
|
||||
"os"
|
||||
"testing"
|
||||
|
||||
"zombiezen.com/go/sqlite"
|
||||
)
|
||||
|
||||
func TestDatabaseOperations(t *testing.T) {
|
||||
// Use a temporary database file
|
||||
testDB := "test.db"
|
||||
defer os.Remove(testDB)
|
||||
|
||||
// Initialize the singleton database
|
||||
err := Init(testDB)
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to initialize database: %v", err)
|
||||
}
|
||||
defer Close()
|
||||
|
||||
// Test creating a simple table
|
||||
err = Exec("CREATE TABLE test_users (id INTEGER PRIMARY KEY, name TEXT)")
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to create table: %v", err)
|
||||
}
|
||||
|
||||
// Test inserting data
|
||||
err = Exec("INSERT INTO test_users (name) VALUES (?)", "Alice")
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to insert data: %v", err)
|
||||
}
|
||||
|
||||
// Test querying data
|
||||
var foundName string
|
||||
err = Query("SELECT name FROM test_users WHERE name = ?", func(stmt *sqlite.Stmt) error {
|
||||
foundName = stmt.ColumnText(0)
|
||||
return nil
|
||||
}, "Alice")
|
||||
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to query data: %v", err)
|
||||
}
|
||||
|
||||
if foundName != "Alice" {
|
||||
t.Errorf("Expected 'Alice', got '%s'", foundName)
|
||||
}
|
||||
|
||||
// Test transaction
|
||||
err = Transaction(func(tx *Tx) error {
|
||||
return tx.Exec("INSERT INTO test_users (name) VALUES (?)", "Bob")
|
||||
})
|
||||
|
||||
if err != nil {
|
||||
t.Fatalf("Transaction failed: %v", err)
|
||||
}
|
||||
|
||||
// Verify transaction worked
|
||||
var count int
|
||||
err = Query("SELECT COUNT(*) FROM test_users", func(stmt *sqlite.Stmt) error {
|
||||
count = stmt.ColumnInt(0)
|
||||
return nil
|
||||
})
|
||||
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to count users: %v", err)
|
||||
}
|
||||
|
||||
if count != 2 {
|
||||
t.Errorf("Expected 2 users, got %d", count)
|
||||
}
|
||||
}
|
@ -1,129 +0,0 @@
|
||||
package database
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"reflect"
|
||||
"strings"
|
||||
|
||||
"zombiezen.com/go/sqlite"
|
||||
)
|
||||
|
||||
// Model interface for trackable database models
|
||||
type Model interface {
|
||||
GetTableName() string
|
||||
GetID() int
|
||||
SetID(id int)
|
||||
GetDirtyFields() map[string]any
|
||||
SetDirty(field string, value any)
|
||||
ClearDirty()
|
||||
IsDirty() bool
|
||||
}
|
||||
|
||||
// BaseModel provides common model functionality
|
||||
type BaseModel struct {
|
||||
FieldTracker
|
||||
}
|
||||
|
||||
// Set uses reflection to set a field and track changes
|
||||
func Set(model Model, field string, value any) error {
|
||||
v := reflect.ValueOf(model).Elem()
|
||||
t := v.Type()
|
||||
|
||||
fieldVal := v.FieldByName(field)
|
||||
if !fieldVal.IsValid() {
|
||||
return fmt.Errorf("field %s does not exist", field)
|
||||
}
|
||||
if !fieldVal.CanSet() {
|
||||
return fmt.Errorf("field %s cannot be set", field)
|
||||
}
|
||||
|
||||
// Get current value for comparison
|
||||
currentVal := fieldVal.Interface()
|
||||
|
||||
// Only set if value has changed
|
||||
if !reflect.DeepEqual(currentVal, value) {
|
||||
newVal := reflect.ValueOf(value)
|
||||
if newVal.Type().ConvertibleTo(fieldVal.Type()) {
|
||||
fieldVal.Set(newVal.Convert(fieldVal.Type()))
|
||||
|
||||
// Get db column name from struct tag
|
||||
structField, _ := t.FieldByName(field)
|
||||
dbField := structField.Tag.Get("db")
|
||||
if dbField == "" {
|
||||
dbField = toSnakeCase(field) // fallback
|
||||
}
|
||||
|
||||
model.SetDirty(dbField, value)
|
||||
} else {
|
||||
return fmt.Errorf("cannot convert %T to %s", value, fieldVal.Type())
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// toSnakeCase converts CamelCase to snake_case
|
||||
func toSnakeCase(s string) string {
|
||||
var result strings.Builder
|
||||
for i, r := range s {
|
||||
if i > 0 && r >= 'A' && r <= 'Z' {
|
||||
prev := rune(s[i-1])
|
||||
if prev < 'A' || prev > 'Z' {
|
||||
result.WriteByte('_')
|
||||
}
|
||||
}
|
||||
if r >= 'A' && r <= 'Z' {
|
||||
result.WriteRune(r - 'A' + 'a')
|
||||
} else {
|
||||
result.WriteRune(r)
|
||||
}
|
||||
}
|
||||
return result.String()
|
||||
}
|
||||
|
||||
// Save updates only dirty fields
|
||||
func Save(model Model) error {
|
||||
if model.GetID() == 0 {
|
||||
return fmt.Errorf("cannot save model without ID")
|
||||
}
|
||||
return UpdateDirty(model)
|
||||
}
|
||||
|
||||
// Insert creates a new record and sets the ID
|
||||
func Insert(model Model, columns string, values ...any) error {
|
||||
if model.GetID() != 0 {
|
||||
return fmt.Errorf("model already has ID %d, use Save() to update", model.GetID())
|
||||
}
|
||||
|
||||
return Transaction(func(tx *Tx) error {
|
||||
placeholders := strings.Repeat("?,", len(values))
|
||||
placeholders = placeholders[:len(placeholders)-1] // Remove trailing comma
|
||||
|
||||
query := fmt.Sprintf("INSERT INTO %s (%s) VALUES (%s)",
|
||||
model.GetTableName(), columns, placeholders)
|
||||
|
||||
if err := tx.Exec(query, values...); err != nil {
|
||||
return fmt.Errorf("failed to insert: %w", err)
|
||||
}
|
||||
|
||||
var id int
|
||||
err := tx.Query("SELECT last_insert_rowid()", func(stmt *sqlite.Stmt) error {
|
||||
id = stmt.ColumnInt(0)
|
||||
return nil
|
||||
})
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to get insert ID: %w", err)
|
||||
}
|
||||
|
||||
model.SetID(id)
|
||||
return nil
|
||||
})
|
||||
}
|
||||
|
||||
// Delete removes the record
|
||||
func Delete(model Model) error {
|
||||
if model.GetID() == 0 {
|
||||
return fmt.Errorf("cannot delete model without ID")
|
||||
}
|
||||
return Exec("DELETE FROM ? WHERE id = ?", model.GetTableName(), model.GetID())
|
||||
}
|
@ -1,82 +0,0 @@
|
||||
package database
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"strings"
|
||||
)
|
||||
|
||||
// Trackable interface for models that can track field changes
|
||||
type Trackable interface {
|
||||
GetTableName() string
|
||||
GetID() int
|
||||
GetDirtyFields() map[string]any
|
||||
SetDirty(field string, value any)
|
||||
ClearDirty()
|
||||
IsDirty() bool
|
||||
}
|
||||
|
||||
// FieldTracker provides dirty field tracking functionality
|
||||
type FieldTracker struct {
|
||||
dirty map[string]any
|
||||
}
|
||||
|
||||
// SetDirty marks a field as dirty with its new value
|
||||
func (ft *FieldTracker) SetDirty(field string, value any) {
|
||||
if ft.dirty == nil {
|
||||
ft.dirty = make(map[string]any)
|
||||
}
|
||||
ft.dirty[field] = value
|
||||
}
|
||||
|
||||
// GetDirtyFields returns map of dirty fields and their values
|
||||
func (ft *FieldTracker) GetDirtyFields() map[string]any {
|
||||
if ft.dirty == nil {
|
||||
return make(map[string]any)
|
||||
}
|
||||
return ft.dirty
|
||||
}
|
||||
|
||||
// ClearDirty clears all dirty field tracking
|
||||
func (ft *FieldTracker) ClearDirty() {
|
||||
ft.dirty = nil
|
||||
}
|
||||
|
||||
// IsDirty returns true if any fields have been modified
|
||||
func (ft *FieldTracker) IsDirty() bool {
|
||||
return len(ft.dirty) > 0
|
||||
}
|
||||
|
||||
// UpdateDirty updates only dirty fields in the database
|
||||
func UpdateDirty(model Trackable) error {
|
||||
if !model.IsDirty() {
|
||||
return nil // No changes to save
|
||||
}
|
||||
|
||||
dirty := model.GetDirtyFields()
|
||||
if len(dirty) == 0 {
|
||||
return nil
|
||||
}
|
||||
|
||||
// Build dynamic UPDATE query
|
||||
var setParts []string
|
||||
var args []any
|
||||
|
||||
for field, value := range dirty {
|
||||
setParts = append(setParts, field+" = ?")
|
||||
args = append(args, value)
|
||||
}
|
||||
|
||||
args = append(args, model.GetID()) // Add ID for WHERE clause
|
||||
|
||||
query := fmt.Sprintf("UPDATE %s SET %s WHERE id = ?",
|
||||
model.GetTableName(),
|
||||
strings.Join(setParts, ", "))
|
||||
|
||||
err := Exec(query, args...)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to update %s: %w", model.GetTableName(), err)
|
||||
}
|
||||
|
||||
model.ClearDirty()
|
||||
return nil
|
||||
}
|
@ -1,47 +1,31 @@
|
||||
package drops
|
||||
|
||||
import (
|
||||
"dk/internal/store"
|
||||
"fmt"
|
||||
|
||||
"dk/internal/database"
|
||||
"dk/internal/helpers/scanner"
|
||||
|
||||
"zombiezen.com/go/sqlite"
|
||||
"sort"
|
||||
"sync"
|
||||
)
|
||||
|
||||
// Drop represents a drop item in the database
|
||||
// Drop represents a drop item in the game
|
||||
type Drop struct {
|
||||
database.BaseModel
|
||||
|
||||
ID int `db:"id" json:"id"`
|
||||
Name string `db:"name" json:"name"`
|
||||
Level int `db:"level" json:"level"`
|
||||
Type int `db:"type" json:"type"`
|
||||
Att string `db:"att" json:"att"`
|
||||
}
|
||||
|
||||
func (d *Drop) GetTableName() string {
|
||||
return "drops"
|
||||
}
|
||||
|
||||
func (d *Drop) GetID() int {
|
||||
return d.ID
|
||||
}
|
||||
|
||||
func (d *Drop) SetID(id int) {
|
||||
d.ID = id
|
||||
}
|
||||
|
||||
func (d *Drop) Set(field string, value any) error {
|
||||
return database.Set(d, field, value)
|
||||
ID int `json:"id"`
|
||||
Name string `json:"name"`
|
||||
Level int `json:"level"`
|
||||
Type int `json:"type"`
|
||||
Att string `json:"att"`
|
||||
}
|
||||
|
||||
func (d *Drop) Save() error {
|
||||
return database.Save(d)
|
||||
dropStore := GetStore()
|
||||
dropStore.UpdateDrop(d)
|
||||
return nil
|
||||
}
|
||||
|
||||
func (d *Drop) Delete() error {
|
||||
return database.Delete(d)
|
||||
dropStore := GetStore()
|
||||
dropStore.RemoveDrop(d.ID)
|
||||
return nil
|
||||
}
|
||||
|
||||
// Creates a new Drop with sensible defaults
|
||||
@ -54,18 +38,18 @@ func New() *Drop {
|
||||
}
|
||||
}
|
||||
|
||||
var dropScanner = scanner.New[Drop]()
|
||||
|
||||
// Returns the column list for drop queries
|
||||
func dropColumns() string {
|
||||
return dropScanner.Columns()
|
||||
}
|
||||
|
||||
// Populates a Drop struct using the fast scanner
|
||||
func scanDrop(stmt *sqlite.Stmt) *Drop {
|
||||
drop := &Drop{}
|
||||
dropScanner.Scan(stmt, drop)
|
||||
return drop
|
||||
// Validate checks if drop has valid values
|
||||
func (d *Drop) Validate() error {
|
||||
if d.Name == "" {
|
||||
return fmt.Errorf("drop name cannot be empty")
|
||||
}
|
||||
if d.Level < 1 {
|
||||
return fmt.Errorf("drop Level must be at least 1")
|
||||
}
|
||||
if d.Type < TypeConsumable {
|
||||
return fmt.Errorf("invalid drop type: %d", d.Type)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// DropType constants for drop types
|
||||
@ -73,90 +57,231 @@ const (
|
||||
TypeConsumable = 1
|
||||
)
|
||||
|
||||
// DropStore provides in-memory storage with O(1) lookups and drop-specific indices
|
||||
type DropStore struct {
|
||||
*store.BaseStore[Drop] // Embedded generic store
|
||||
byLevel map[int][]int // Level -> []ID
|
||||
byType map[int][]int // Type -> []ID
|
||||
allByID []int // All IDs sorted by ID
|
||||
mu sync.RWMutex // Protects indices
|
||||
}
|
||||
|
||||
// Global in-memory store
|
||||
var dropStore *DropStore
|
||||
var storeOnce sync.Once
|
||||
|
||||
// Initialize the in-memory store
|
||||
func initStore() {
|
||||
dropStore = &DropStore{
|
||||
BaseStore: store.NewBaseStore[Drop](),
|
||||
byLevel: make(map[int][]int),
|
||||
byType: make(map[int][]int),
|
||||
allByID: make([]int, 0),
|
||||
}
|
||||
}
|
||||
|
||||
// GetStore returns the global drop store
|
||||
func GetStore() *DropStore {
|
||||
storeOnce.Do(initStore)
|
||||
return dropStore
|
||||
}
|
||||
|
||||
// AddDrop adds a drop to the in-memory store and updates all indices
|
||||
func (ds *DropStore) AddDrop(drop *Drop) {
|
||||
ds.mu.Lock()
|
||||
defer ds.mu.Unlock()
|
||||
|
||||
// Validate drop
|
||||
if err := drop.Validate(); err != nil {
|
||||
return
|
||||
}
|
||||
|
||||
// Add to base store
|
||||
ds.Add(drop.ID, drop)
|
||||
|
||||
// Rebuild indices
|
||||
ds.rebuildIndicesUnsafe()
|
||||
}
|
||||
|
||||
// RemoveDrop removes a drop from the store and updates indices
|
||||
func (ds *DropStore) RemoveDrop(id int) {
|
||||
ds.mu.Lock()
|
||||
defer ds.mu.Unlock()
|
||||
|
||||
// Remove from base store
|
||||
ds.Remove(id)
|
||||
|
||||
// Rebuild indices
|
||||
ds.rebuildIndicesUnsafe()
|
||||
}
|
||||
|
||||
// UpdateDrop updates a drop efficiently
|
||||
func (ds *DropStore) UpdateDrop(drop *Drop) {
|
||||
ds.mu.Lock()
|
||||
defer ds.mu.Unlock()
|
||||
|
||||
// Validate drop
|
||||
if err := drop.Validate(); err != nil {
|
||||
return
|
||||
}
|
||||
|
||||
// Update base store
|
||||
ds.Add(drop.ID, drop)
|
||||
|
||||
// Rebuild indices
|
||||
ds.rebuildIndicesUnsafe()
|
||||
}
|
||||
|
||||
// LoadData loads drop data from JSON file, or starts with empty store
|
||||
func LoadData(dataPath string) error {
|
||||
ds := GetStore()
|
||||
|
||||
// Load from base store, which handles JSON loading
|
||||
if err := ds.BaseStore.LoadData(dataPath); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// Rebuild indices from loaded data
|
||||
ds.rebuildIndices()
|
||||
return nil
|
||||
}
|
||||
|
||||
// SaveData saves drop data to JSON file
|
||||
func SaveData(dataPath string) error {
|
||||
ds := GetStore()
|
||||
return ds.BaseStore.SaveData(dataPath)
|
||||
}
|
||||
|
||||
// rebuildIndicesUnsafe rebuilds all indices from base store data (caller must hold lock)
|
||||
func (ds *DropStore) rebuildIndicesUnsafe() {
|
||||
// Clear indices
|
||||
ds.byLevel = make(map[int][]int)
|
||||
ds.byType = make(map[int][]int)
|
||||
ds.allByID = make([]int, 0)
|
||||
|
||||
// Collect all drops and build indices
|
||||
allDrops := ds.GetAll()
|
||||
|
||||
for id, drop := range allDrops {
|
||||
// Level index
|
||||
ds.byLevel[drop.Level] = append(ds.byLevel[drop.Level], id)
|
||||
|
||||
// Type index
|
||||
ds.byType[drop.Type] = append(ds.byType[drop.Type], id)
|
||||
|
||||
// All IDs
|
||||
ds.allByID = append(ds.allByID, id)
|
||||
}
|
||||
|
||||
// Sort allByID by ID
|
||||
sort.Ints(ds.allByID)
|
||||
|
||||
// Sort level indices by ID
|
||||
for level := range ds.byLevel {
|
||||
sort.Ints(ds.byLevel[level])
|
||||
}
|
||||
|
||||
// Sort type indices by level, then ID
|
||||
for dropType := range ds.byType {
|
||||
sort.Slice(ds.byType[dropType], func(i, j int) bool {
|
||||
dropI, _ := ds.GetByID(ds.byType[dropType][i])
|
||||
dropJ, _ := ds.GetByID(ds.byType[dropType][j])
|
||||
if dropI.Level != dropJ.Level {
|
||||
return dropI.Level < dropJ.Level
|
||||
}
|
||||
return ds.byType[dropType][i] < ds.byType[dropType][j]
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// rebuildIndices rebuilds all drop-specific indices from base store data
|
||||
func (ds *DropStore) rebuildIndices() {
|
||||
ds.mu.Lock()
|
||||
defer ds.mu.Unlock()
|
||||
ds.rebuildIndicesUnsafe()
|
||||
}
|
||||
|
||||
// Retrieves a drop by ID
|
||||
func Find(id int) (*Drop, error) {
|
||||
var drop *Drop
|
||||
|
||||
query := `SELECT ` + dropColumns() + ` FROM drops WHERE id = ?`
|
||||
|
||||
err := database.Query(query, func(stmt *sqlite.Stmt) error {
|
||||
drop = scanDrop(stmt)
|
||||
return nil
|
||||
}, id)
|
||||
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to find drop: %w", err)
|
||||
}
|
||||
|
||||
if drop == nil {
|
||||
ds := GetStore()
|
||||
drop, exists := ds.GetByID(id)
|
||||
if !exists {
|
||||
return nil, fmt.Errorf("drop with ID %d not found", id)
|
||||
}
|
||||
|
||||
return drop, nil
|
||||
}
|
||||
|
||||
// Retrieves all drops
|
||||
func All() ([]*Drop, error) {
|
||||
var drops []*Drop
|
||||
ds := GetStore()
|
||||
ds.mu.RLock()
|
||||
defer ds.mu.RUnlock()
|
||||
|
||||
query := `SELECT ` + dropColumns() + ` FROM drops ORDER BY id`
|
||||
|
||||
err := database.Query(query, func(stmt *sqlite.Stmt) error {
|
||||
drop := scanDrop(stmt)
|
||||
drops = append(drops, drop)
|
||||
return nil
|
||||
})
|
||||
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to retrieve all drops: %w", err)
|
||||
result := make([]*Drop, 0, len(ds.allByID))
|
||||
for _, id := range ds.allByID {
|
||||
if drop, exists := ds.GetByID(id); exists {
|
||||
result = append(result, drop)
|
||||
}
|
||||
}
|
||||
|
||||
return drops, nil
|
||||
return result, nil
|
||||
}
|
||||
|
||||
// Retrieves drops by minimum level requirement
|
||||
func ByLevel(minLevel int) ([]*Drop, error) {
|
||||
var drops []*Drop
|
||||
ds := GetStore()
|
||||
ds.mu.RLock()
|
||||
defer ds.mu.RUnlock()
|
||||
|
||||
query := `SELECT ` + dropColumns() + ` FROM drops WHERE level <= ? ORDER BY level, id`
|
||||
|
||||
err := database.Query(query, func(stmt *sqlite.Stmt) error {
|
||||
drop := scanDrop(stmt)
|
||||
drops = append(drops, drop)
|
||||
return nil
|
||||
}, minLevel)
|
||||
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to retrieve drops by level: %w", err)
|
||||
var result []*Drop
|
||||
for level := 1; level <= minLevel; level++ {
|
||||
if ids, exists := ds.byLevel[level]; exists {
|
||||
for _, id := range ids {
|
||||
if drop, exists := ds.GetByID(id); exists {
|
||||
result = append(result, drop)
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return drops, nil
|
||||
return result, nil
|
||||
}
|
||||
|
||||
// Retrieves drops by type
|
||||
func ByType(dropType int) ([]*Drop, error) {
|
||||
var drops []*Drop
|
||||
ds := GetStore()
|
||||
ds.mu.RLock()
|
||||
defer ds.mu.RUnlock()
|
||||
|
||||
query := `SELECT ` + dropColumns() + ` FROM drops WHERE type = ? ORDER BY level, id`
|
||||
|
||||
err := database.Query(query, func(stmt *sqlite.Stmt) error {
|
||||
drop := scanDrop(stmt)
|
||||
drops = append(drops, drop)
|
||||
return nil
|
||||
}, dropType)
|
||||
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to retrieve drops by type: %w", err)
|
||||
ids, exists := ds.byType[dropType]
|
||||
if !exists {
|
||||
return []*Drop{}, nil
|
||||
}
|
||||
|
||||
return drops, nil
|
||||
result := make([]*Drop, 0, len(ids))
|
||||
for _, id := range ids {
|
||||
if drop, exists := ds.GetByID(id); exists {
|
||||
result = append(result, drop)
|
||||
}
|
||||
}
|
||||
return result, nil
|
||||
}
|
||||
|
||||
// Saves a new drop to the database and sets the ID
|
||||
// Saves a new drop to the in-memory store and sets the ID
|
||||
func (d *Drop) Insert() error {
|
||||
columns := `name, level, type, att`
|
||||
values := []any{d.Name, d.Level, d.Type, d.Att}
|
||||
return database.Insert(d, columns, values...)
|
||||
ds := GetStore()
|
||||
|
||||
// Validate before insertion
|
||||
if err := d.Validate(); err != nil {
|
||||
return fmt.Errorf("validation failed: %w", err)
|
||||
}
|
||||
|
||||
// Assign new ID if not set
|
||||
if d.ID == 0 {
|
||||
d.ID = ds.GetNextID()
|
||||
}
|
||||
|
||||
// Add to store
|
||||
ds.AddDrop(d)
|
||||
return nil
|
||||
}
|
||||
|
||||
// Returns true if the drop is a consumable item
|
||||
|
@ -1,52 +1,36 @@
|
||||
package forum
|
||||
|
||||
import (
|
||||
"dk/internal/store"
|
||||
"fmt"
|
||||
"sort"
|
||||
"strings"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"dk/internal/database"
|
||||
"dk/internal/helpers/scanner"
|
||||
|
||||
"zombiezen.com/go/sqlite"
|
||||
)
|
||||
|
||||
// Forum represents a forum post or thread in the database
|
||||
// Forum represents a forum post or thread in the game
|
||||
type Forum struct {
|
||||
database.BaseModel
|
||||
|
||||
ID int `db:"id" json:"id"`
|
||||
Posted int64 `db:"posted" json:"posted"`
|
||||
LastPost int64 `db:"last_post" json:"last_post"`
|
||||
Author int `db:"author" json:"author"`
|
||||
Parent int `db:"parent" json:"parent"`
|
||||
Replies int `db:"replies" json:"replies"`
|
||||
Title string `db:"title" json:"title"`
|
||||
Content string `db:"content" json:"content"`
|
||||
}
|
||||
|
||||
func (f *Forum) GetTableName() string {
|
||||
return "forum"
|
||||
}
|
||||
|
||||
func (f *Forum) GetID() int {
|
||||
return f.ID
|
||||
}
|
||||
|
||||
func (f *Forum) SetID(id int) {
|
||||
f.ID = id
|
||||
}
|
||||
|
||||
func (f *Forum) Set(field string, value any) error {
|
||||
return database.Set(f, field, value)
|
||||
ID int `json:"id"`
|
||||
Posted int64 `json:"posted"`
|
||||
LastPost int64 `json:"last_post"`
|
||||
Author int `json:"author"`
|
||||
Parent int `json:"parent"`
|
||||
Replies int `json:"replies"`
|
||||
Title string `json:"title"`
|
||||
Content string `json:"content"`
|
||||
}
|
||||
|
||||
func (f *Forum) Save() error {
|
||||
return database.Save(f)
|
||||
forumStore := GetStore()
|
||||
forumStore.UpdateForum(f)
|
||||
return nil
|
||||
}
|
||||
|
||||
func (f *Forum) Delete() error {
|
||||
return database.Delete(f)
|
||||
forumStore := GetStore()
|
||||
forumStore.RemoveForum(f.ID)
|
||||
return nil
|
||||
}
|
||||
|
||||
// Creates a new Forum with sensible defaults
|
||||
@ -63,181 +47,358 @@ func New() *Forum {
|
||||
}
|
||||
}
|
||||
|
||||
var forumScanner = scanner.New[Forum]()
|
||||
|
||||
// Returns the column list for forum queries
|
||||
func forumColumns() string {
|
||||
return forumScanner.Columns()
|
||||
// Validate checks if forum has valid values
|
||||
func (f *Forum) Validate() error {
|
||||
if strings.TrimSpace(f.Title) == "" {
|
||||
return fmt.Errorf("forum title cannot be empty")
|
||||
}
|
||||
if strings.TrimSpace(f.Content) == "" {
|
||||
return fmt.Errorf("forum content cannot be empty")
|
||||
}
|
||||
if f.Posted <= 0 {
|
||||
return fmt.Errorf("forum Posted timestamp must be positive")
|
||||
}
|
||||
if f.LastPost <= 0 {
|
||||
return fmt.Errorf("forum LastPost timestamp must be positive")
|
||||
}
|
||||
if f.Parent < 0 {
|
||||
return fmt.Errorf("forum Parent cannot be negative")
|
||||
}
|
||||
if f.Replies < 0 {
|
||||
return fmt.Errorf("forum Replies cannot be negative")
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// Populates a Forum struct using the fast scanner
|
||||
func scanForum(stmt *sqlite.Stmt) *Forum {
|
||||
forum := &Forum{}
|
||||
forumScanner.Scan(stmt, forum)
|
||||
return forum
|
||||
// ForumStore provides in-memory storage with O(1) lookups and forum-specific indices
|
||||
type ForumStore struct {
|
||||
*store.BaseStore[Forum] // Embedded generic store
|
||||
byParent map[int][]int // Parent -> []ID
|
||||
byAuthor map[int][]int // Author -> []ID
|
||||
threadsOnly []int // Parent=0 IDs sorted by last_post DESC, id DESC
|
||||
allByLastPost []int // All IDs sorted by last_post DESC, id DESC
|
||||
mu sync.RWMutex // Protects indices
|
||||
}
|
||||
|
||||
// Global in-memory store
|
||||
var forumStore *ForumStore
|
||||
var storeOnce sync.Once
|
||||
|
||||
// Initialize the in-memory store
|
||||
func initStore() {
|
||||
forumStore = &ForumStore{
|
||||
BaseStore: store.NewBaseStore[Forum](),
|
||||
byParent: make(map[int][]int),
|
||||
byAuthor: make(map[int][]int),
|
||||
threadsOnly: make([]int, 0),
|
||||
allByLastPost: make([]int, 0),
|
||||
}
|
||||
}
|
||||
|
||||
// GetStore returns the global forum store
|
||||
func GetStore() *ForumStore {
|
||||
storeOnce.Do(initStore)
|
||||
return forumStore
|
||||
}
|
||||
|
||||
// AddForum adds a forum post to the in-memory store and updates all indices
|
||||
func (fs *ForumStore) AddForum(forum *Forum) {
|
||||
fs.mu.Lock()
|
||||
defer fs.mu.Unlock()
|
||||
|
||||
// Validate forum
|
||||
if err := forum.Validate(); err != nil {
|
||||
return
|
||||
}
|
||||
|
||||
// Add to base store
|
||||
fs.Add(forum.ID, forum)
|
||||
|
||||
// Rebuild indices
|
||||
fs.rebuildIndicesUnsafe()
|
||||
}
|
||||
|
||||
// RemoveForum removes a forum post from the store and updates indices
|
||||
func (fs *ForumStore) RemoveForum(id int) {
|
||||
fs.mu.Lock()
|
||||
defer fs.mu.Unlock()
|
||||
|
||||
// Remove from base store
|
||||
fs.Remove(id)
|
||||
|
||||
// Rebuild indices
|
||||
fs.rebuildIndicesUnsafe()
|
||||
}
|
||||
|
||||
// UpdateForum updates a forum post efficiently
|
||||
func (fs *ForumStore) UpdateForum(forum *Forum) {
|
||||
fs.mu.Lock()
|
||||
defer fs.mu.Unlock()
|
||||
|
||||
// Validate forum
|
||||
if err := forum.Validate(); err != nil {
|
||||
return
|
||||
}
|
||||
|
||||
// Update base store
|
||||
fs.Add(forum.ID, forum)
|
||||
|
||||
// Rebuild indices
|
||||
fs.rebuildIndicesUnsafe()
|
||||
}
|
||||
|
||||
// LoadData loads forum data from JSON file, or starts with empty store
|
||||
func LoadData(dataPath string) error {
|
||||
fs := GetStore()
|
||||
|
||||
// Load from base store, which handles JSON loading
|
||||
if err := fs.BaseStore.LoadData(dataPath); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// Rebuild indices from loaded data
|
||||
fs.rebuildIndices()
|
||||
return nil
|
||||
}
|
||||
|
||||
// SaveData saves forum data to JSON file
|
||||
func SaveData(dataPath string) error {
|
||||
fs := GetStore()
|
||||
return fs.BaseStore.SaveData(dataPath)
|
||||
}
|
||||
|
||||
// rebuildIndicesUnsafe rebuilds all indices from base store data (caller must hold lock)
|
||||
func (fs *ForumStore) rebuildIndicesUnsafe() {
|
||||
// Clear indices
|
||||
fs.byParent = make(map[int][]int)
|
||||
fs.byAuthor = make(map[int][]int)
|
||||
fs.threadsOnly = make([]int, 0)
|
||||
fs.allByLastPost = make([]int, 0)
|
||||
|
||||
// Collect all forum posts and build indices
|
||||
allForums := fs.GetAll()
|
||||
|
||||
for id, forum := range allForums {
|
||||
// Parent index
|
||||
fs.byParent[forum.Parent] = append(fs.byParent[forum.Parent], id)
|
||||
|
||||
// Author index
|
||||
fs.byAuthor[forum.Author] = append(fs.byAuthor[forum.Author], id)
|
||||
|
||||
// Threads only (parent = 0)
|
||||
if forum.Parent == 0 {
|
||||
fs.threadsOnly = append(fs.threadsOnly, id)
|
||||
}
|
||||
|
||||
// All posts
|
||||
fs.allByLastPost = append(fs.allByLastPost, id)
|
||||
}
|
||||
|
||||
// Sort allByLastPost by last_post DESC, then ID DESC
|
||||
sort.Slice(fs.allByLastPost, func(i, j int) bool {
|
||||
forumI, _ := fs.GetByID(fs.allByLastPost[i])
|
||||
forumJ, _ := fs.GetByID(fs.allByLastPost[j])
|
||||
if forumI.LastPost != forumJ.LastPost {
|
||||
return forumI.LastPost > forumJ.LastPost // DESC
|
||||
}
|
||||
return fs.allByLastPost[i] > fs.allByLastPost[j] // DESC
|
||||
})
|
||||
|
||||
// Sort threadsOnly by last_post DESC, then ID DESC
|
||||
sort.Slice(fs.threadsOnly, func(i, j int) bool {
|
||||
forumI, _ := fs.GetByID(fs.threadsOnly[i])
|
||||
forumJ, _ := fs.GetByID(fs.threadsOnly[j])
|
||||
if forumI.LastPost != forumJ.LastPost {
|
||||
return forumI.LastPost > forumJ.LastPost // DESC
|
||||
}
|
||||
return fs.threadsOnly[i] > fs.threadsOnly[j] // DESC
|
||||
})
|
||||
|
||||
// Sort byParent replies by posted ASC, then ID ASC
|
||||
for parent := range fs.byParent {
|
||||
if parent > 0 { // Only sort replies, not threads
|
||||
sort.Slice(fs.byParent[parent], func(i, j int) bool {
|
||||
forumI, _ := fs.GetByID(fs.byParent[parent][i])
|
||||
forumJ, _ := fs.GetByID(fs.byParent[parent][j])
|
||||
if forumI.Posted != forumJ.Posted {
|
||||
return forumI.Posted < forumJ.Posted // ASC
|
||||
}
|
||||
return fs.byParent[parent][i] < fs.byParent[parent][j] // ASC
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// Sort byAuthor by posted DESC, then ID DESC
|
||||
for author := range fs.byAuthor {
|
||||
sort.Slice(fs.byAuthor[author], func(i, j int) bool {
|
||||
forumI, _ := fs.GetByID(fs.byAuthor[author][i])
|
||||
forumJ, _ := fs.GetByID(fs.byAuthor[author][j])
|
||||
if forumI.Posted != forumJ.Posted {
|
||||
return forumI.Posted > forumJ.Posted // DESC
|
||||
}
|
||||
return fs.byAuthor[author][i] > fs.byAuthor[author][j] // DESC
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// rebuildIndices rebuilds all forum-specific indices from base store data
|
||||
func (fs *ForumStore) rebuildIndices() {
|
||||
fs.mu.Lock()
|
||||
defer fs.mu.Unlock()
|
||||
fs.rebuildIndicesUnsafe()
|
||||
}
|
||||
|
||||
// Retrieves a forum post by ID
|
||||
func Find(id int) (*Forum, error) {
|
||||
var forum *Forum
|
||||
|
||||
query := `SELECT ` + forumColumns() + ` FROM forum WHERE id = ?`
|
||||
|
||||
err := database.Query(query, func(stmt *sqlite.Stmt) error {
|
||||
forum = scanForum(stmt)
|
||||
return nil
|
||||
}, id)
|
||||
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to find forum post: %w", err)
|
||||
}
|
||||
|
||||
if forum == nil {
|
||||
fs := GetStore()
|
||||
forum, exists := fs.GetByID(id)
|
||||
if !exists {
|
||||
return nil, fmt.Errorf("forum post with ID %d not found", id)
|
||||
}
|
||||
|
||||
return forum, nil
|
||||
}
|
||||
|
||||
// Retrieves all forum posts ordered by last post time (most recent first)
|
||||
func All() ([]*Forum, error) {
|
||||
var forums []*Forum
|
||||
fs := GetStore()
|
||||
fs.mu.RLock()
|
||||
defer fs.mu.RUnlock()
|
||||
|
||||
query := `SELECT ` + forumColumns() + ` FROM forum ORDER BY last_post DESC, id DESC`
|
||||
|
||||
err := database.Query(query, func(stmt *sqlite.Stmt) error {
|
||||
forum := scanForum(stmt)
|
||||
forums = append(forums, forum)
|
||||
return nil
|
||||
})
|
||||
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to retrieve all forum posts: %w", err)
|
||||
result := make([]*Forum, 0, len(fs.allByLastPost))
|
||||
for _, id := range fs.allByLastPost {
|
||||
if forum, exists := fs.GetByID(id); exists {
|
||||
result = append(result, forum)
|
||||
}
|
||||
}
|
||||
|
||||
return forums, nil
|
||||
return result, nil
|
||||
}
|
||||
|
||||
// Retrieves all top-level forum threads (parent = 0)
|
||||
func Threads() ([]*Forum, error) {
|
||||
var forums []*Forum
|
||||
fs := GetStore()
|
||||
fs.mu.RLock()
|
||||
defer fs.mu.RUnlock()
|
||||
|
||||
query := `SELECT ` + forumColumns() + ` FROM forum WHERE parent = 0 ORDER BY last_post DESC, id DESC`
|
||||
|
||||
err := database.Query(query, func(stmt *sqlite.Stmt) error {
|
||||
forum := scanForum(stmt)
|
||||
forums = append(forums, forum)
|
||||
return nil
|
||||
})
|
||||
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to retrieve forum threads: %w", err)
|
||||
result := make([]*Forum, 0, len(fs.threadsOnly))
|
||||
for _, id := range fs.threadsOnly {
|
||||
if forum, exists := fs.GetByID(id); exists {
|
||||
result = append(result, forum)
|
||||
}
|
||||
}
|
||||
|
||||
return forums, nil
|
||||
return result, nil
|
||||
}
|
||||
|
||||
// Retrieves all replies to a specific thread/post
|
||||
func ByParent(parentID int) ([]*Forum, error) {
|
||||
var forums []*Forum
|
||||
fs := GetStore()
|
||||
fs.mu.RLock()
|
||||
defer fs.mu.RUnlock()
|
||||
|
||||
query := `SELECT ` + forumColumns() + ` FROM forum WHERE parent = ? ORDER BY posted ASC, id ASC`
|
||||
|
||||
err := database.Query(query, func(stmt *sqlite.Stmt) error {
|
||||
forum := scanForum(stmt)
|
||||
forums = append(forums, forum)
|
||||
return nil
|
||||
}, parentID)
|
||||
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to retrieve forum replies: %w", err)
|
||||
ids, exists := fs.byParent[parentID]
|
||||
if !exists {
|
||||
return []*Forum{}, nil
|
||||
}
|
||||
|
||||
return forums, nil
|
||||
result := make([]*Forum, 0, len(ids))
|
||||
for _, id := range ids {
|
||||
if forum, exists := fs.GetByID(id); exists {
|
||||
result = append(result, forum)
|
||||
}
|
||||
}
|
||||
return result, nil
|
||||
}
|
||||
|
||||
// Retrieves forum posts by a specific author
|
||||
func ByAuthor(authorID int) ([]*Forum, error) {
|
||||
var forums []*Forum
|
||||
fs := GetStore()
|
||||
fs.mu.RLock()
|
||||
defer fs.mu.RUnlock()
|
||||
|
||||
query := `SELECT ` + forumColumns() + ` FROM forum WHERE author = ? ORDER BY posted DESC, id DESC`
|
||||
|
||||
err := database.Query(query, func(stmt *sqlite.Stmt) error {
|
||||
forum := scanForum(stmt)
|
||||
forums = append(forums, forum)
|
||||
return nil
|
||||
}, authorID)
|
||||
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to retrieve forum posts by author: %w", err)
|
||||
ids, exists := fs.byAuthor[authorID]
|
||||
if !exists {
|
||||
return []*Forum{}, nil
|
||||
}
|
||||
|
||||
return forums, nil
|
||||
result := make([]*Forum, 0, len(ids))
|
||||
for _, id := range ids {
|
||||
if forum, exists := fs.GetByID(id); exists {
|
||||
result = append(result, forum)
|
||||
}
|
||||
}
|
||||
return result, nil
|
||||
}
|
||||
|
||||
// Retrieves the most recent forum activity (limited by count)
|
||||
func Recent(limit int) ([]*Forum, error) {
|
||||
var forums []*Forum
|
||||
fs := GetStore()
|
||||
fs.mu.RLock()
|
||||
defer fs.mu.RUnlock()
|
||||
|
||||
query := `SELECT ` + forumColumns() + ` FROM forum ORDER BY last_post DESC, id DESC LIMIT ?`
|
||||
|
||||
err := database.Query(query, func(stmt *sqlite.Stmt) error {
|
||||
forum := scanForum(stmt)
|
||||
forums = append(forums, forum)
|
||||
return nil
|
||||
}, limit)
|
||||
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to retrieve recent forum posts: %w", err)
|
||||
if limit > len(fs.allByLastPost) {
|
||||
limit = len(fs.allByLastPost)
|
||||
}
|
||||
|
||||
return forums, nil
|
||||
result := make([]*Forum, 0, limit)
|
||||
for i := 0; i < limit; i++ {
|
||||
if forum, exists := fs.GetByID(fs.allByLastPost[i]); exists {
|
||||
result = append(result, forum)
|
||||
}
|
||||
}
|
||||
return result, nil
|
||||
}
|
||||
|
||||
// Retrieves forum posts containing the search term in title or content
|
||||
func Search(term string) ([]*Forum, error) {
|
||||
var forums []*Forum
|
||||
fs := GetStore()
|
||||
fs.mu.RLock()
|
||||
defer fs.mu.RUnlock()
|
||||
|
||||
query := `SELECT ` + forumColumns() + ` FROM forum WHERE LOWER(title) LIKE LOWER(?) OR LOWER(content) LIKE LOWER(?) ORDER BY last_post DESC, id DESC`
|
||||
searchTerm := "%" + term + "%"
|
||||
var result []*Forum
|
||||
lowerTerm := strings.ToLower(term)
|
||||
|
||||
err := database.Query(query, func(stmt *sqlite.Stmt) error {
|
||||
forum := scanForum(stmt)
|
||||
forums = append(forums, forum)
|
||||
return nil
|
||||
}, searchTerm, searchTerm)
|
||||
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to search forum posts: %w", err)
|
||||
for _, id := range fs.allByLastPost {
|
||||
if forum, exists := fs.GetByID(id); exists {
|
||||
if strings.Contains(strings.ToLower(forum.Title), lowerTerm) ||
|
||||
strings.Contains(strings.ToLower(forum.Content), lowerTerm) {
|
||||
result = append(result, forum)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return forums, nil
|
||||
return result, nil
|
||||
}
|
||||
|
||||
// Retrieves forum posts with activity since a specific timestamp
|
||||
func Since(since int64) ([]*Forum, error) {
|
||||
var forums []*Forum
|
||||
fs := GetStore()
|
||||
fs.mu.RLock()
|
||||
defer fs.mu.RUnlock()
|
||||
|
||||
query := `SELECT ` + forumColumns() + ` FROM forum WHERE last_post >= ? ORDER BY last_post DESC, id DESC`
|
||||
|
||||
err := database.Query(query, func(stmt *sqlite.Stmt) error {
|
||||
forum := scanForum(stmt)
|
||||
forums = append(forums, forum)
|
||||
return nil
|
||||
}, since)
|
||||
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to retrieve forum posts since timestamp: %w", err)
|
||||
var result []*Forum
|
||||
for _, id := range fs.allByLastPost {
|
||||
if forum, exists := fs.GetByID(id); exists && forum.LastPost >= since {
|
||||
result = append(result, forum)
|
||||
}
|
||||
}
|
||||
|
||||
return forums, nil
|
||||
return result, nil
|
||||
}
|
||||
|
||||
// Saves a new forum post to the database and sets the ID
|
||||
// Saves a new forum post to the in-memory store and sets the ID
|
||||
func (f *Forum) Insert() error {
|
||||
columns := `posted, last_post, author, parent, replies, title, content`
|
||||
values := []any{f.Posted, f.LastPost, f.Author, f.Parent, f.Replies, f.Title, f.Content}
|
||||
return database.Insert(f, columns, values...)
|
||||
fs := GetStore()
|
||||
|
||||
// Validate before insertion
|
||||
if err := f.Validate(); err != nil {
|
||||
return fmt.Errorf("validation failed: %w", err)
|
||||
}
|
||||
|
||||
// Assign new ID if not set
|
||||
if f.ID == 0 {
|
||||
f.ID = fs.GetNextID()
|
||||
}
|
||||
|
||||
// Add to store
|
||||
fs.AddForum(f)
|
||||
return nil
|
||||
}
|
||||
|
||||
// Returns the posted timestamp as a time.Time
|
||||
@ -252,12 +413,12 @@ func (f *Forum) LastPostTime() time.Time {
|
||||
|
||||
// Sets the posted timestamp from a time.Time
|
||||
func (f *Forum) SetPostedTime(t time.Time) {
|
||||
f.Set("Posted", t.Unix())
|
||||
f.Posted = t.Unix()
|
||||
}
|
||||
|
||||
// Sets the last post timestamp from a time.Time
|
||||
func (f *Forum) SetLastPostTime(t time.Time) {
|
||||
f.Set("LastPost", t.Unix())
|
||||
f.LastPost = t.Unix()
|
||||
}
|
||||
|
||||
// Returns true if this is a top-level thread (parent = 0)
|
||||
@ -350,18 +511,18 @@ func (f *Forum) Contains(term string) bool {
|
||||
|
||||
// Updates the last_post timestamp to current time
|
||||
func (f *Forum) UpdateLastPost() {
|
||||
f.Set("LastPost", time.Now().Unix())
|
||||
f.LastPost = time.Now().Unix()
|
||||
}
|
||||
|
||||
// Increments the reply count
|
||||
func (f *Forum) IncrementReplies() {
|
||||
f.Set("Replies", f.Replies+1)
|
||||
f.Replies++
|
||||
}
|
||||
|
||||
// Decrements the reply count (minimum 0)
|
||||
func (f *Forum) DecrementReplies() {
|
||||
if f.Replies > 0 {
|
||||
f.Set("Replies", f.Replies-1)
|
||||
f.Replies--
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -1,48 +1,32 @@
|
||||
package items
|
||||
|
||||
import (
|
||||
"dk/internal/store"
|
||||
"fmt"
|
||||
|
||||
"dk/internal/database"
|
||||
"dk/internal/helpers/scanner"
|
||||
|
||||
"zombiezen.com/go/sqlite"
|
||||
"sort"
|
||||
"sync"
|
||||
)
|
||||
|
||||
// Item represents an item in the database
|
||||
// Item represents an item in the game
|
||||
type Item struct {
|
||||
database.BaseModel
|
||||
|
||||
ID int `db:"id" json:"id"`
|
||||
Type int `db:"type" json:"type"`
|
||||
Name string `db:"name" json:"name"`
|
||||
Value int `db:"value" json:"value"`
|
||||
Att int `db:"att" json:"att"`
|
||||
Special string `db:"special" json:"special"`
|
||||
}
|
||||
|
||||
func (i *Item) GetTableName() string {
|
||||
return "items"
|
||||
}
|
||||
|
||||
func (i *Item) GetID() int {
|
||||
return i.ID
|
||||
}
|
||||
|
||||
func (i *Item) SetID(id int) {
|
||||
i.ID = id
|
||||
}
|
||||
|
||||
func (i *Item) Set(field string, value any) error {
|
||||
return database.Set(i, field, value)
|
||||
ID int `json:"id"`
|
||||
Type int `json:"type"`
|
||||
Name string `json:"name"`
|
||||
Value int `json:"value"`
|
||||
Att int `json:"att"`
|
||||
Special string `json:"special"`
|
||||
}
|
||||
|
||||
func (i *Item) Save() error {
|
||||
return database.Save(i)
|
||||
itemStore := GetStore()
|
||||
itemStore.UpdateItem(i)
|
||||
return nil
|
||||
}
|
||||
|
||||
func (i *Item) Delete() error {
|
||||
return database.Delete(i)
|
||||
itemStore := GetStore()
|
||||
itemStore.RemoveItem(i.ID)
|
||||
return nil
|
||||
}
|
||||
|
||||
// Creates a new Item with sensible defaults
|
||||
@ -56,18 +40,21 @@ func New() *Item {
|
||||
}
|
||||
}
|
||||
|
||||
var itemScanner = scanner.New[Item]()
|
||||
|
||||
// Returns the column list for item queries
|
||||
func itemColumns() string {
|
||||
return itemScanner.Columns()
|
||||
}
|
||||
|
||||
// Populates an Item struct using the fast scanner
|
||||
func scanItem(stmt *sqlite.Stmt) *Item {
|
||||
item := &Item{}
|
||||
itemScanner.Scan(stmt, item)
|
||||
return item
|
||||
// Validate checks if item has valid values
|
||||
func (i *Item) Validate() error {
|
||||
if i.Name == "" {
|
||||
return fmt.Errorf("item name cannot be empty")
|
||||
}
|
||||
if i.Type < TypeWeapon || i.Type > TypeShield {
|
||||
return fmt.Errorf("invalid item type: %d", i.Type)
|
||||
}
|
||||
if i.Value < 0 {
|
||||
return fmt.Errorf("item Value cannot be negative")
|
||||
}
|
||||
if i.Att < 0 {
|
||||
return fmt.Errorf("item Att cannot be negative")
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// ItemType constants for item types
|
||||
@ -77,71 +64,194 @@ const (
|
||||
TypeShield = 3
|
||||
)
|
||||
|
||||
// ItemStore provides in-memory storage with O(1) lookups and item-specific indices
|
||||
type ItemStore struct {
|
||||
*store.BaseStore[Item] // Embedded generic store
|
||||
byType map[int][]int // Type -> []ID
|
||||
allByID []int // All IDs sorted by ID
|
||||
mu sync.RWMutex // Protects indices
|
||||
}
|
||||
|
||||
// Global in-memory store
|
||||
var itemStore *ItemStore
|
||||
var storeOnce sync.Once
|
||||
|
||||
// Initialize the in-memory store
|
||||
func initStore() {
|
||||
itemStore = &ItemStore{
|
||||
BaseStore: store.NewBaseStore[Item](),
|
||||
byType: make(map[int][]int),
|
||||
allByID: make([]int, 0),
|
||||
}
|
||||
}
|
||||
|
||||
// GetStore returns the global item store
|
||||
func GetStore() *ItemStore {
|
||||
storeOnce.Do(initStore)
|
||||
return itemStore
|
||||
}
|
||||
|
||||
// AddItem adds an item to the in-memory store and updates all indices
|
||||
func (is *ItemStore) AddItem(item *Item) {
|
||||
is.mu.Lock()
|
||||
defer is.mu.Unlock()
|
||||
|
||||
// Validate item
|
||||
if err := item.Validate(); err != nil {
|
||||
return
|
||||
}
|
||||
|
||||
// Add to base store
|
||||
is.Add(item.ID, item)
|
||||
|
||||
// Rebuild indices
|
||||
is.rebuildIndicesUnsafe()
|
||||
}
|
||||
|
||||
// RemoveItem removes an item from the store and updates indices
|
||||
func (is *ItemStore) RemoveItem(id int) {
|
||||
is.mu.Lock()
|
||||
defer is.mu.Unlock()
|
||||
|
||||
// Remove from base store
|
||||
is.Remove(id)
|
||||
|
||||
// Rebuild indices
|
||||
is.rebuildIndicesUnsafe()
|
||||
}
|
||||
|
||||
// UpdateItem updates an item efficiently
|
||||
func (is *ItemStore) UpdateItem(item *Item) {
|
||||
is.mu.Lock()
|
||||
defer is.mu.Unlock()
|
||||
|
||||
// Validate item
|
||||
if err := item.Validate(); err != nil {
|
||||
return
|
||||
}
|
||||
|
||||
// Update base store
|
||||
is.Add(item.ID, item)
|
||||
|
||||
// Rebuild indices
|
||||
is.rebuildIndicesUnsafe()
|
||||
}
|
||||
|
||||
// LoadData loads item data from JSON file, or starts with empty store
|
||||
func LoadData(dataPath string) error {
|
||||
is := GetStore()
|
||||
|
||||
// Load from base store, which handles JSON loading
|
||||
if err := is.BaseStore.LoadData(dataPath); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// Rebuild indices from loaded data
|
||||
is.rebuildIndices()
|
||||
return nil
|
||||
}
|
||||
|
||||
// SaveData saves item data to JSON file
|
||||
func SaveData(dataPath string) error {
|
||||
is := GetStore()
|
||||
return is.BaseStore.SaveData(dataPath)
|
||||
}
|
||||
|
||||
// rebuildIndicesUnsafe rebuilds all indices from base store data (caller must hold lock)
|
||||
func (is *ItemStore) rebuildIndicesUnsafe() {
|
||||
// Clear indices
|
||||
is.byType = make(map[int][]int)
|
||||
is.allByID = make([]int, 0)
|
||||
|
||||
// Collect all items and build indices
|
||||
allItems := is.GetAll()
|
||||
|
||||
for id, item := range allItems {
|
||||
// Type index
|
||||
is.byType[item.Type] = append(is.byType[item.Type], id)
|
||||
|
||||
// All IDs
|
||||
is.allByID = append(is.allByID, id)
|
||||
}
|
||||
|
||||
// Sort allByID by ID
|
||||
sort.Ints(is.allByID)
|
||||
|
||||
// Sort type indices by ID
|
||||
for itemType := range is.byType {
|
||||
sort.Ints(is.byType[itemType])
|
||||
}
|
||||
}
|
||||
|
||||
// rebuildIndices rebuilds all item-specific indices from base store data
|
||||
func (is *ItemStore) rebuildIndices() {
|
||||
is.mu.Lock()
|
||||
defer is.mu.Unlock()
|
||||
is.rebuildIndicesUnsafe()
|
||||
}
|
||||
|
||||
// Retrieves an item by ID
|
||||
func Find(id int) (*Item, error) {
|
||||
var item *Item
|
||||
|
||||
query := `SELECT ` + itemColumns() + ` FROM items WHERE id = ?`
|
||||
|
||||
err := database.Query(query, func(stmt *sqlite.Stmt) error {
|
||||
item = scanItem(stmt)
|
||||
return nil
|
||||
}, id)
|
||||
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to find item: %w", err)
|
||||
}
|
||||
|
||||
if item == nil {
|
||||
is := GetStore()
|
||||
item, exists := is.GetByID(id)
|
||||
if !exists {
|
||||
return nil, fmt.Errorf("item with ID %d not found", id)
|
||||
}
|
||||
|
||||
return item, nil
|
||||
}
|
||||
|
||||
// Retrieves all items
|
||||
func All() ([]*Item, error) {
|
||||
var items []*Item
|
||||
is := GetStore()
|
||||
is.mu.RLock()
|
||||
defer is.mu.RUnlock()
|
||||
|
||||
query := `SELECT ` + itemColumns() + ` FROM items ORDER BY id`
|
||||
|
||||
err := database.Query(query, func(stmt *sqlite.Stmt) error {
|
||||
item := scanItem(stmt)
|
||||
items = append(items, item)
|
||||
return nil
|
||||
})
|
||||
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to retrieve all items: %w", err)
|
||||
result := make([]*Item, 0, len(is.allByID))
|
||||
for _, id := range is.allByID {
|
||||
if item, exists := is.GetByID(id); exists {
|
||||
result = append(result, item)
|
||||
}
|
||||
}
|
||||
|
||||
return items, nil
|
||||
return result, nil
|
||||
}
|
||||
|
||||
// Retrieves items by type
|
||||
func ByType(itemType int) ([]*Item, error) {
|
||||
var items []*Item
|
||||
is := GetStore()
|
||||
is.mu.RLock()
|
||||
defer is.mu.RUnlock()
|
||||
|
||||
query := `SELECT ` + itemColumns() + ` FROM items WHERE type = ? ORDER BY id`
|
||||
|
||||
err := database.Query(query, func(stmt *sqlite.Stmt) error {
|
||||
item := scanItem(stmt)
|
||||
items = append(items, item)
|
||||
return nil
|
||||
}, itemType)
|
||||
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to retrieve items by type: %w", err)
|
||||
ids, exists := is.byType[itemType]
|
||||
if !exists {
|
||||
return []*Item{}, nil
|
||||
}
|
||||
|
||||
return items, nil
|
||||
result := make([]*Item, 0, len(ids))
|
||||
for _, id := range ids {
|
||||
if item, exists := is.GetByID(id); exists {
|
||||
result = append(result, item)
|
||||
}
|
||||
}
|
||||
return result, nil
|
||||
}
|
||||
|
||||
// Saves a new item to the database and sets the ID
|
||||
// Saves a new item to the in-memory store and sets the ID
|
||||
func (i *Item) Insert() error {
|
||||
columns := `type, name, value, att, special`
|
||||
values := []any{i.Type, i.Name, i.Value, i.Att, i.Special}
|
||||
return database.Insert(i, columns, values...)
|
||||
is := GetStore()
|
||||
|
||||
// Validate before insertion
|
||||
if err := i.Validate(); err != nil {
|
||||
return fmt.Errorf("validation failed: %w", err)
|
||||
}
|
||||
|
||||
// Assign new ID if not set
|
||||
if i.ID == 0 {
|
||||
i.ID = is.GetNextID()
|
||||
}
|
||||
|
||||
// Add to store
|
||||
is.AddItem(i)
|
||||
return nil
|
||||
}
|
||||
|
||||
// Returns true if the item is a weapon
|
||||
|
@ -1,10 +1,8 @@
|
||||
package monsters
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"dk/internal/store"
|
||||
"fmt"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"sort"
|
||||
"sync"
|
||||
)
|
||||
@ -23,14 +21,14 @@ type Monster struct {
|
||||
}
|
||||
|
||||
func (m *Monster) Save() error {
|
||||
store := GetStore()
|
||||
store.UpdateMonster(m)
|
||||
monsterStore := GetStore()
|
||||
monsterStore.UpdateMonster(m)
|
||||
return nil
|
||||
}
|
||||
|
||||
func (m *Monster) Delete() error {
|
||||
store := GetStore()
|
||||
store.RemoveMonster(m.ID)
|
||||
monsterStore := GetStore()
|
||||
monsterStore.RemoveMonster(m.ID)
|
||||
return nil
|
||||
}
|
||||
|
||||
@ -48,6 +46,22 @@ func New() *Monster {
|
||||
}
|
||||
}
|
||||
|
||||
// Validate checks if monster has valid values
|
||||
func (m *Monster) Validate() error {
|
||||
if m.Name == "" {
|
||||
return fmt.Errorf("monster name cannot be empty")
|
||||
}
|
||||
if m.MaxHP < 1 {
|
||||
return fmt.Errorf("monster MaxHP must be at least 1")
|
||||
}
|
||||
if m.Level < 1 {
|
||||
return fmt.Errorf("monster Level must be at least 1")
|
||||
}
|
||||
if m.Immune < ImmuneNone || m.Immune > ImmuneSleep {
|
||||
return fmt.Errorf("invalid immunity type: %d", m.Immune)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// Immunity constants for monster immunity types
|
||||
const (
|
||||
@ -56,35 +70,33 @@ const (
|
||||
ImmuneSleep = 2 // Immune to Sleep spells
|
||||
)
|
||||
|
||||
// MonsterStore provides in-memory storage with O(1) lookups
|
||||
// MonsterStore provides in-memory storage with O(1) lookups and monster-specific indices
|
||||
type MonsterStore struct {
|
||||
monsters map[int]*Monster // ID -> Monster (O(1))
|
||||
byLevel map[int][]*Monster // Level -> []*Monster (O(1) to get slice)
|
||||
byImmunity map[int][]*Monster // Immunity -> []*Monster (O(1) to get slice)
|
||||
allByLevel []*Monster // Pre-sorted by level, id
|
||||
maxID int
|
||||
mu sync.RWMutex
|
||||
*store.BaseStore[Monster] // Embedded generic store
|
||||
byLevel map[int][]int // Level -> []ID
|
||||
byImmunity map[int][]int // Immunity -> []ID
|
||||
allByLevel []int // All IDs sorted by level, then ID
|
||||
mu sync.RWMutex // Protects indices
|
||||
}
|
||||
|
||||
// Global in-memory store
|
||||
var store *MonsterStore
|
||||
var monsterStore *MonsterStore
|
||||
var storeOnce sync.Once
|
||||
|
||||
// Initialize the in-memory store
|
||||
func initStore() {
|
||||
store = &MonsterStore{
|
||||
monsters: make(map[int]*Monster),
|
||||
byLevel: make(map[int][]*Monster),
|
||||
byImmunity: make(map[int][]*Monster),
|
||||
allByLevel: make([]*Monster, 0),
|
||||
maxID: 0,
|
||||
monsterStore = &MonsterStore{
|
||||
BaseStore: store.NewBaseStore[Monster](),
|
||||
byLevel: make(map[int][]int),
|
||||
byImmunity: make(map[int][]int),
|
||||
allByLevel: make([]int, 0),
|
||||
}
|
||||
}
|
||||
|
||||
// GetStore returns the global monster store
|
||||
func GetStore() *MonsterStore {
|
||||
storeOnce.Do(initStore)
|
||||
return store
|
||||
return monsterStore
|
||||
}
|
||||
|
||||
// AddMonster adds a monster to the in-memory store and updates all indices
|
||||
@ -92,41 +104,16 @@ func (ms *MonsterStore) AddMonster(monster *Monster) {
|
||||
ms.mu.Lock()
|
||||
defer ms.mu.Unlock()
|
||||
|
||||
// Add to primary store
|
||||
ms.monsters[monster.ID] = monster
|
||||
|
||||
// Update max ID
|
||||
if monster.ID > ms.maxID {
|
||||
ms.maxID = monster.ID
|
||||
// Validate monster
|
||||
if err := monster.Validate(); err != nil {
|
||||
return
|
||||
}
|
||||
|
||||
// Add to level index
|
||||
ms.byLevel[monster.Level] = append(ms.byLevel[monster.Level], monster)
|
||||
|
||||
// Add to immunity index
|
||||
ms.byImmunity[monster.Immune] = append(ms.byImmunity[monster.Immune], monster)
|
||||
|
||||
// Add to sorted list and re-sort
|
||||
ms.allByLevel = append(ms.allByLevel, monster)
|
||||
sort.Slice(ms.allByLevel, func(i, j int) bool {
|
||||
if ms.allByLevel[i].Level == ms.allByLevel[j].Level {
|
||||
return ms.allByLevel[i].ID < ms.allByLevel[j].ID
|
||||
}
|
||||
return ms.allByLevel[i].Level < ms.allByLevel[j].Level
|
||||
})
|
||||
// Add to base store
|
||||
ms.Add(monster.ID, monster)
|
||||
|
||||
// Sort level index
|
||||
sort.Slice(ms.byLevel[monster.Level], func(i, j int) bool {
|
||||
return ms.byLevel[monster.Level][i].ID < ms.byLevel[monster.Level][j].ID
|
||||
})
|
||||
|
||||
// Sort immunity index
|
||||
sort.Slice(ms.byImmunity[monster.Immune], func(i, j int) bool {
|
||||
if ms.byImmunity[monster.Immune][i].Level == ms.byImmunity[monster.Immune][j].Level {
|
||||
return ms.byImmunity[monster.Immune][i].ID < ms.byImmunity[monster.Immune][j].ID
|
||||
}
|
||||
return ms.byImmunity[monster.Immune][i].Level < ms.byImmunity[monster.Immune][j].Level
|
||||
})
|
||||
// Rebuild indices
|
||||
ms.rebuildIndicesUnsafe()
|
||||
}
|
||||
|
||||
// RemoveMonster removes a monster from the store and updates indices
|
||||
@ -134,344 +121,202 @@ func (ms *MonsterStore) RemoveMonster(id int) {
|
||||
ms.mu.Lock()
|
||||
defer ms.mu.Unlock()
|
||||
|
||||
monster, exists := ms.monsters[id]
|
||||
if !exists {
|
||||
return
|
||||
}
|
||||
// Remove from base store
|
||||
ms.Remove(id)
|
||||
|
||||
// Remove from primary store
|
||||
delete(ms.monsters, id)
|
||||
|
||||
// Remove from level index
|
||||
levelMonsters := ms.byLevel[monster.Level]
|
||||
for i, m := range levelMonsters {
|
||||
if m.ID == id {
|
||||
ms.byLevel[monster.Level] = append(levelMonsters[:i], levelMonsters[i+1:]...)
|
||||
break
|
||||
}
|
||||
}
|
||||
|
||||
// Remove from immunity index
|
||||
immunityMonsters := ms.byImmunity[monster.Immune]
|
||||
for i, m := range immunityMonsters {
|
||||
if m.ID == id {
|
||||
ms.byImmunity[monster.Immune] = append(immunityMonsters[:i], immunityMonsters[i+1:]...)
|
||||
break
|
||||
}
|
||||
}
|
||||
|
||||
// Remove from sorted list
|
||||
for i, m := range ms.allByLevel {
|
||||
if m.ID == id {
|
||||
ms.allByLevel = append(ms.allByLevel[:i], ms.allByLevel[i+1:]...)
|
||||
break
|
||||
}
|
||||
}
|
||||
// Rebuild indices
|
||||
ms.rebuildIndicesUnsafe()
|
||||
}
|
||||
|
||||
// UpdateMonster updates a monster and rebuilds indices
|
||||
// UpdateMonster updates a monster efficiently
|
||||
func (ms *MonsterStore) UpdateMonster(monster *Monster) {
|
||||
ms.RemoveMonster(monster.ID)
|
||||
ms.AddMonster(monster)
|
||||
}
|
||||
|
||||
// GetNextID returns the next available ID
|
||||
func (ms *MonsterStore) GetNextID() int {
|
||||
ms.mu.RLock()
|
||||
defer ms.mu.RUnlock()
|
||||
return ms.maxID + 1
|
||||
}
|
||||
|
||||
// LoadFromJSON loads monster data from a JSON file
|
||||
func (ms *MonsterStore) LoadFromJSON(filename string) error {
|
||||
ms.mu.Lock()
|
||||
defer ms.mu.Unlock()
|
||||
|
||||
data, err := os.ReadFile(filename)
|
||||
if err != nil {
|
||||
if os.IsNotExist(err) {
|
||||
return nil // File doesn't exist, start with empty store
|
||||
}
|
||||
return fmt.Errorf("failed to read monsters JSON: %w", err)
|
||||
// Validate monster
|
||||
if err := monster.Validate(); err != nil {
|
||||
return
|
||||
}
|
||||
|
||||
// Handle empty file
|
||||
if len(data) == 0 {
|
||||
return nil // Empty file, start with empty store
|
||||
}
|
||||
// Update base store
|
||||
ms.Add(monster.ID, monster)
|
||||
|
||||
var monsters []*Monster
|
||||
if err := json.Unmarshal(data, &monsters); err != nil {
|
||||
return fmt.Errorf("failed to unmarshal monsters JSON: %w", err)
|
||||
}
|
||||
|
||||
// Clear existing data
|
||||
ms.monsters = make(map[int]*Monster)
|
||||
ms.byLevel = make(map[int][]*Monster)
|
||||
ms.byImmunity = make(map[int][]*Monster)
|
||||
ms.allByLevel = make([]*Monster, 0)
|
||||
ms.maxID = 0
|
||||
|
||||
// Add all monsters
|
||||
for _, monster := range monsters {
|
||||
ms.monsters[monster.ID] = monster
|
||||
if monster.ID > ms.maxID {
|
||||
ms.maxID = monster.ID
|
||||
}
|
||||
ms.byLevel[monster.Level] = append(ms.byLevel[monster.Level], monster)
|
||||
ms.byImmunity[monster.Immune] = append(ms.byImmunity[monster.Immune], monster)
|
||||
ms.allByLevel = append(ms.allByLevel, monster)
|
||||
}
|
||||
|
||||
// Sort all indices
|
||||
sort.Slice(ms.allByLevel, func(i, j int) bool {
|
||||
if ms.allByLevel[i].Level == ms.allByLevel[j].Level {
|
||||
return ms.allByLevel[i].ID < ms.allByLevel[j].ID
|
||||
}
|
||||
return ms.allByLevel[i].Level < ms.allByLevel[j].Level
|
||||
})
|
||||
|
||||
for level := range ms.byLevel {
|
||||
sort.Slice(ms.byLevel[level], func(i, j int) bool {
|
||||
return ms.byLevel[level][i].ID < ms.byLevel[level][j].ID
|
||||
})
|
||||
}
|
||||
|
||||
for immunity := range ms.byImmunity {
|
||||
sort.Slice(ms.byImmunity[immunity], func(i, j int) bool {
|
||||
if ms.byImmunity[immunity][i].Level == ms.byImmunity[immunity][j].Level {
|
||||
return ms.byImmunity[immunity][i].ID < ms.byImmunity[immunity][j].ID
|
||||
}
|
||||
return ms.byImmunity[immunity][i].Level < ms.byImmunity[immunity][j].Level
|
||||
})
|
||||
}
|
||||
|
||||
return nil
|
||||
// Rebuild indices
|
||||
ms.rebuildIndicesUnsafe()
|
||||
}
|
||||
|
||||
// SaveToJSON saves monster data to a JSON file
|
||||
func (ms *MonsterStore) SaveToJSON(filename string) error {
|
||||
ms.mu.RLock()
|
||||
defer ms.mu.RUnlock()
|
||||
// LoadData loads monster data from JSON file, or starts with empty store
|
||||
func LoadData(dataPath string) error {
|
||||
ms := GetStore()
|
||||
|
||||
monsters := make([]*Monster, 0, len(ms.monsters))
|
||||
for _, monster := range ms.monsters {
|
||||
monsters = append(monsters, monster)
|
||||
}
|
||||
|
||||
// Sort by ID for consistent output
|
||||
sort.Slice(monsters, func(i, j int) bool {
|
||||
return monsters[i].ID < monsters[j].ID
|
||||
})
|
||||
|
||||
data, err := json.MarshalIndent(monsters, "", " ")
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to marshal monsters to JSON: %w", err)
|
||||
}
|
||||
|
||||
if err := os.WriteFile(filename, data, 0644); err != nil {
|
||||
return fmt.Errorf("failed to write monsters JSON: %w", err)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// findMonstersDataPath finds the monsters.json file relative to the current working directory
|
||||
func findMonstersDataPath() (string, error) {
|
||||
// Try current directory first (cwd/data/monsters.json)
|
||||
if _, err := os.Stat("data/monsters.json"); err == nil {
|
||||
return "data/monsters.json", nil
|
||||
}
|
||||
|
||||
// Walk up directories to find the data folder
|
||||
dir, err := os.Getwd()
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
|
||||
for {
|
||||
dataPath := filepath.Join(dir, "data", "monsters.json")
|
||||
if _, err := os.Stat(dataPath); err == nil {
|
||||
return dataPath, nil
|
||||
}
|
||||
|
||||
parent := filepath.Dir(dir)
|
||||
if parent == dir {
|
||||
break // reached root
|
||||
}
|
||||
dir = parent
|
||||
}
|
||||
|
||||
// Default to current directory if not found
|
||||
return "data/monsters.json", nil
|
||||
}
|
||||
|
||||
// LoadData loads monster data from JSON file, or initializes with default data
|
||||
func LoadData() error {
|
||||
store := GetStore()
|
||||
|
||||
dataPath, err := findMonstersDataPath()
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to find monsters data path: %w", err)
|
||||
}
|
||||
|
||||
if err := store.LoadFromJSON(dataPath); err != nil {
|
||||
// If JSON doesn't exist, initialize with default monsters
|
||||
if os.IsNotExist(err) {
|
||||
fmt.Println("No existing monster data found, initializing with defaults...")
|
||||
if err := initializeDefaultMonsters(); err != nil {
|
||||
return fmt.Errorf("failed to initialize default monsters: %w", err)
|
||||
}
|
||||
// Save the default data
|
||||
if err := SaveData(); err != nil {
|
||||
return fmt.Errorf("failed to save default monster data: %w", err)
|
||||
}
|
||||
fmt.Printf("Initialized %d default monsters\n", len(store.monsters))
|
||||
} else {
|
||||
return fmt.Errorf("failed to load from JSON: %w", err)
|
||||
}
|
||||
} else {
|
||||
fmt.Printf("Loaded %d monsters from JSON\n", len(store.monsters))
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// initializeDefaultMonsters creates the default monster set
|
||||
func initializeDefaultMonsters() error {
|
||||
store := GetStore()
|
||||
|
||||
// Default monsters from the original SQL data
|
||||
defaultMonsters := []*Monster{
|
||||
{ID: 1, Name: "Blue Slime", MaxHP: 4, MaxDmg: 3, Armor: 1, Level: 1, MaxExp: 1, MaxGold: 1, Immune: ImmuneNone},
|
||||
{ID: 2, Name: "Red Slime", MaxHP: 6, MaxDmg: 5, Armor: 1, Level: 1, MaxExp: 2, MaxGold: 1, Immune: ImmuneNone},
|
||||
{ID: 3, Name: "Critter", MaxHP: 6, MaxDmg: 5, Armor: 2, Level: 1, MaxExp: 4, MaxGold: 2, Immune: ImmuneNone},
|
||||
{ID: 4, Name: "Creature", MaxHP: 10, MaxDmg: 8, Armor: 2, Level: 2, MaxExp: 4, MaxGold: 2, Immune: ImmuneNone},
|
||||
{ID: 5, Name: "Shadow", MaxHP: 10, MaxDmg: 9, Armor: 3, Level: 2, MaxExp: 6, MaxGold: 2, Immune: ImmuneHurt},
|
||||
{ID: 6, Name: "Drake", MaxHP: 11, MaxDmg: 10, Armor: 3, Level: 2, MaxExp: 8, MaxGold: 3, Immune: ImmuneNone},
|
||||
{ID: 7, Name: "Shade", MaxHP: 12, MaxDmg: 10, Armor: 3, Level: 3, MaxExp: 10, MaxGold: 3, Immune: ImmuneHurt},
|
||||
{ID: 8, Name: "Drakelor", MaxHP: 14, MaxDmg: 12, Armor: 4, Level: 3, MaxExp: 10, MaxGold: 3, Immune: ImmuneNone},
|
||||
{ID: 9, Name: "Silver Slime", MaxHP: 15, MaxDmg: 100, Armor: 200, Level: 30, MaxExp: 15, MaxGold: 1000, Immune: ImmuneSleep},
|
||||
{ID: 10, Name: "Scamp", MaxHP: 16, MaxDmg: 13, Armor: 5, Level: 4, MaxExp: 15, MaxGold: 5, Immune: ImmuneNone},
|
||||
}
|
||||
|
||||
for _, monster := range defaultMonsters {
|
||||
store.AddMonster(monster)
|
||||
// Load from base store, which handles JSON loading
|
||||
if err := ms.BaseStore.LoadData(dataPath); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// Rebuild indices from loaded data
|
||||
ms.rebuildIndices()
|
||||
return nil
|
||||
}
|
||||
|
||||
// SaveData saves monster data to JSON file
|
||||
func SaveData() error {
|
||||
store := GetStore()
|
||||
|
||||
dataPath, err := findMonstersDataPath()
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to find monsters data path: %w", err)
|
||||
func SaveData(dataPath string) error {
|
||||
ms := GetStore()
|
||||
return ms.BaseStore.SaveData(dataPath)
|
||||
}
|
||||
|
||||
// rebuildIndicesUnsafe rebuilds all indices from base store data (caller must hold lock)
|
||||
func (ms *MonsterStore) rebuildIndicesUnsafe() {
|
||||
// Clear indices
|
||||
ms.byLevel = make(map[int][]int)
|
||||
ms.byImmunity = make(map[int][]int)
|
||||
ms.allByLevel = make([]int, 0)
|
||||
|
||||
// Collect all monsters and build indices
|
||||
allMonsters := ms.GetAll()
|
||||
|
||||
// Build level and immunity indices
|
||||
for id, monster := range allMonsters {
|
||||
ms.byLevel[monster.Level] = append(ms.byLevel[monster.Level], id)
|
||||
ms.byImmunity[monster.Immune] = append(ms.byImmunity[monster.Immune], id)
|
||||
ms.allByLevel = append(ms.allByLevel, id)
|
||||
}
|
||||
|
||||
// Ensure data directory exists
|
||||
dataDir := filepath.Dir(dataPath)
|
||||
if err := os.MkdirAll(dataDir, 0755); err != nil {
|
||||
return fmt.Errorf("failed to create data directory: %w", err)
|
||||
|
||||
// Sort allByLevel by level first, then by ID
|
||||
sort.Slice(ms.allByLevel, func(i, j int) bool {
|
||||
monsterI, _ := ms.GetByID(ms.allByLevel[i])
|
||||
monsterJ, _ := ms.GetByID(ms.allByLevel[j])
|
||||
if monsterI.Level == monsterJ.Level {
|
||||
return ms.allByLevel[i] < ms.allByLevel[j]
|
||||
}
|
||||
return monsterI.Level < monsterJ.Level
|
||||
})
|
||||
|
||||
// Sort level indices by ID
|
||||
for level := range ms.byLevel {
|
||||
sort.Ints(ms.byLevel[level])
|
||||
}
|
||||
|
||||
if err := store.SaveToJSON(dataPath); err != nil {
|
||||
return fmt.Errorf("failed to save monsters to JSON: %w", err)
|
||||
|
||||
// Sort immunity indices by level, then ID
|
||||
for immunity := range ms.byImmunity {
|
||||
sort.Slice(ms.byImmunity[immunity], func(i, j int) bool {
|
||||
monsterI, _ := ms.GetByID(ms.byImmunity[immunity][i])
|
||||
monsterJ, _ := ms.GetByID(ms.byImmunity[immunity][j])
|
||||
if monsterI.Level == monsterJ.Level {
|
||||
return ms.byImmunity[immunity][i] < ms.byImmunity[immunity][j]
|
||||
}
|
||||
return monsterI.Level < monsterJ.Level
|
||||
})
|
||||
}
|
||||
|
||||
fmt.Printf("Saved %d monsters to JSON\n", len(store.monsters))
|
||||
return nil
|
||||
}
|
||||
|
||||
// rebuildIndices rebuilds all monster-specific indices from base store data
|
||||
func (ms *MonsterStore) rebuildIndices() {
|
||||
ms.mu.Lock()
|
||||
defer ms.mu.Unlock()
|
||||
ms.rebuildIndicesUnsafe()
|
||||
}
|
||||
|
||||
// Retrieves a monster by ID - O(1) lookup
|
||||
func Find(id int) (*Monster, error) {
|
||||
store := GetStore()
|
||||
store.mu.RLock()
|
||||
defer store.mu.RUnlock()
|
||||
|
||||
monster, exists := store.monsters[id]
|
||||
ms := GetStore()
|
||||
monster, exists := ms.GetByID(id)
|
||||
if !exists {
|
||||
return nil, fmt.Errorf("monster with ID %d not found", id)
|
||||
}
|
||||
|
||||
return monster, nil
|
||||
}
|
||||
|
||||
// Retrieves all monsters - O(1) lookup (returns pre-sorted slice)
|
||||
func All() ([]*Monster, error) {
|
||||
store := GetStore()
|
||||
store.mu.RLock()
|
||||
defer store.mu.RUnlock()
|
||||
ms := GetStore()
|
||||
ms.mu.RLock()
|
||||
defer ms.mu.RUnlock()
|
||||
|
||||
// Return a copy of the slice to prevent external modifications
|
||||
result := make([]*Monster, len(store.allByLevel))
|
||||
copy(result, store.allByLevel)
|
||||
result := make([]*Monster, 0, len(ms.allByLevel))
|
||||
for _, id := range ms.allByLevel {
|
||||
if monster, exists := ms.GetByID(id); exists {
|
||||
result = append(result, monster)
|
||||
}
|
||||
}
|
||||
return result, nil
|
||||
}
|
||||
|
||||
// Retrieves monsters by level - O(1) lookup
|
||||
func ByLevel(level int) ([]*Monster, error) {
|
||||
store := GetStore()
|
||||
store.mu.RLock()
|
||||
defer store.mu.RUnlock()
|
||||
ms := GetStore()
|
||||
ms.mu.RLock()
|
||||
defer ms.mu.RUnlock()
|
||||
|
||||
monsters, exists := store.byLevel[level]
|
||||
ids, exists := ms.byLevel[level]
|
||||
if !exists {
|
||||
return []*Monster{}, nil
|
||||
}
|
||||
|
||||
// Return a copy of the slice to prevent external modifications
|
||||
result := make([]*Monster, len(monsters))
|
||||
copy(result, monsters)
|
||||
result := make([]*Monster, 0, len(ids))
|
||||
for _, id := range ids {
|
||||
if monster, exists := ms.GetByID(id); exists {
|
||||
result = append(result, monster)
|
||||
}
|
||||
}
|
||||
return result, nil
|
||||
}
|
||||
|
||||
// Retrieves monsters within a level range (inclusive) - O(k) where k is result size
|
||||
func ByLevelRange(minLevel, maxLevel int) ([]*Monster, error) {
|
||||
store := GetStore()
|
||||
store.mu.RLock()
|
||||
defer store.mu.RUnlock()
|
||||
ms := GetStore()
|
||||
ms.mu.RLock()
|
||||
defer ms.mu.RUnlock()
|
||||
|
||||
var result []*Monster
|
||||
for level := minLevel; level <= maxLevel; level++ {
|
||||
if monsters, exists := store.byLevel[level]; exists {
|
||||
result = append(result, monsters...)
|
||||
if ids, exists := ms.byLevel[level]; exists {
|
||||
for _, id := range ids {
|
||||
if monster, exists := ms.GetByID(id); exists {
|
||||
result = append(result, monster)
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return result, nil
|
||||
}
|
||||
|
||||
// Retrieves monsters by immunity type - O(1) lookup
|
||||
func ByImmunity(immunityType int) ([]*Monster, error) {
|
||||
store := GetStore()
|
||||
store.mu.RLock()
|
||||
defer store.mu.RUnlock()
|
||||
ms := GetStore()
|
||||
ms.mu.RLock()
|
||||
defer ms.mu.RUnlock()
|
||||
|
||||
monsters, exists := store.byImmunity[immunityType]
|
||||
ids, exists := ms.byImmunity[immunityType]
|
||||
if !exists {
|
||||
return []*Monster{}, nil
|
||||
}
|
||||
|
||||
// Return a copy of the slice to prevent external modifications
|
||||
result := make([]*Monster, len(monsters))
|
||||
copy(result, monsters)
|
||||
result := make([]*Monster, 0, len(ids))
|
||||
for _, id := range ids {
|
||||
if monster, exists := ms.GetByID(id); exists {
|
||||
result = append(result, monster)
|
||||
}
|
||||
}
|
||||
return result, nil
|
||||
}
|
||||
|
||||
// Saves a new monster to the in-memory store and sets the ID
|
||||
func (m *Monster) Insert() error {
|
||||
store := GetStore()
|
||||
|
||||
ms := GetStore()
|
||||
|
||||
// Validate before insertion
|
||||
if err := m.Validate(); err != nil {
|
||||
return fmt.Errorf("validation failed: %w", err)
|
||||
}
|
||||
|
||||
// Assign new ID if not set
|
||||
if m.ID == 0 {
|
||||
m.ID = store.GetNextID()
|
||||
m.ID = ms.GetNextID()
|
||||
}
|
||||
|
||||
|
||||
// Add to store
|
||||
store.AddMonster(m)
|
||||
ms.AddMonster(m)
|
||||
return nil
|
||||
}
|
||||
|
||||
|
@ -1,47 +1,32 @@
|
||||
package news
|
||||
|
||||
import (
|
||||
"dk/internal/database"
|
||||
"dk/internal/helpers/scanner"
|
||||
"dk/internal/store"
|
||||
"fmt"
|
||||
"sort"
|
||||
"strings"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"zombiezen.com/go/sqlite"
|
||||
)
|
||||
|
||||
// News represents a news post in the database
|
||||
// News represents a news post in the game
|
||||
type News struct {
|
||||
database.BaseModel
|
||||
|
||||
ID int `db:"id" json:"id"`
|
||||
Author int `db:"author" json:"author"`
|
||||
Posted int64 `db:"posted" json:"posted"`
|
||||
Content string `db:"content" json:"content"`
|
||||
}
|
||||
|
||||
func (n *News) GetTableName() string {
|
||||
return "news"
|
||||
}
|
||||
|
||||
func (n *News) GetID() int {
|
||||
return n.ID
|
||||
}
|
||||
|
||||
func (n *News) SetID(id int) {
|
||||
n.ID = id
|
||||
}
|
||||
|
||||
func (n *News) Set(field string, value any) error {
|
||||
return database.Set(n, field, value)
|
||||
ID int `json:"id"`
|
||||
Author int `json:"author"`
|
||||
Posted int64 `json:"posted"`
|
||||
Content string `json:"content"`
|
||||
}
|
||||
|
||||
func (n *News) Save() error {
|
||||
return database.Save(n)
|
||||
newsStore := GetStore()
|
||||
newsStore.UpdateNews(n)
|
||||
return nil
|
||||
}
|
||||
|
||||
func (n *News) Delete() error {
|
||||
return database.Delete(n)
|
||||
newsStore := GetStore()
|
||||
newsStore.RemoveNews(n.ID)
|
||||
return nil
|
||||
}
|
||||
|
||||
// Creates a new News with sensible defaults
|
||||
@ -53,142 +38,287 @@ func New() *News {
|
||||
}
|
||||
}
|
||||
|
||||
var newsScanner = scanner.New[News]()
|
||||
|
||||
// Returns the column list for news queries
|
||||
func newsColumns() string {
|
||||
return newsScanner.Columns()
|
||||
// Validate checks if news has valid values
|
||||
func (n *News) Validate() error {
|
||||
if n.Posted < 0 {
|
||||
return fmt.Errorf("news Posted timestamp cannot be negative")
|
||||
}
|
||||
if strings.TrimSpace(n.Content) == "" {
|
||||
return fmt.Errorf("news Content cannot be empty")
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// Populates a News struct using the fast scanner
|
||||
func scanNews(stmt *sqlite.Stmt) *News {
|
||||
news := &News{}
|
||||
newsScanner.Scan(stmt, news)
|
||||
return news
|
||||
// NewsStore provides in-memory storage with O(1) lookups and news-specific indices
|
||||
type NewsStore struct {
|
||||
*store.BaseStore[News] // Embedded generic store
|
||||
byAuthor map[int][]int // Author -> []ID
|
||||
allByPosted []int // All IDs sorted by posted DESC, id DESC
|
||||
mu sync.RWMutex // Protects indices
|
||||
}
|
||||
|
||||
// Global in-memory store
|
||||
var newsStore *NewsStore
|
||||
var storeOnce sync.Once
|
||||
|
||||
// Initialize the in-memory store
|
||||
func initStore() {
|
||||
newsStore = &NewsStore{
|
||||
BaseStore: store.NewBaseStore[News](),
|
||||
byAuthor: make(map[int][]int),
|
||||
allByPosted: make([]int, 0),
|
||||
}
|
||||
}
|
||||
|
||||
// GetStore returns the global news store
|
||||
func GetStore() *NewsStore {
|
||||
storeOnce.Do(initStore)
|
||||
return newsStore
|
||||
}
|
||||
|
||||
// AddNews adds a news post to the in-memory store and updates all indices
|
||||
func (ns *NewsStore) AddNews(news *News) {
|
||||
ns.mu.Lock()
|
||||
defer ns.mu.Unlock()
|
||||
|
||||
// Validate news
|
||||
if err := news.Validate(); err != nil {
|
||||
return
|
||||
}
|
||||
|
||||
// Add to base store
|
||||
ns.Add(news.ID, news)
|
||||
|
||||
// Rebuild indices
|
||||
ns.rebuildIndicesUnsafe()
|
||||
}
|
||||
|
||||
// RemoveNews removes a news post from the store and updates indices
|
||||
func (ns *NewsStore) RemoveNews(id int) {
|
||||
ns.mu.Lock()
|
||||
defer ns.mu.Unlock()
|
||||
|
||||
// Remove from base store
|
||||
ns.Remove(id)
|
||||
|
||||
// Rebuild indices
|
||||
ns.rebuildIndicesUnsafe()
|
||||
}
|
||||
|
||||
// UpdateNews updates a news post efficiently
|
||||
func (ns *NewsStore) UpdateNews(news *News) {
|
||||
ns.mu.Lock()
|
||||
defer ns.mu.Unlock()
|
||||
|
||||
// Validate news
|
||||
if err := news.Validate(); err != nil {
|
||||
return
|
||||
}
|
||||
|
||||
// Update base store
|
||||
ns.Add(news.ID, news)
|
||||
|
||||
// Rebuild indices
|
||||
ns.rebuildIndicesUnsafe()
|
||||
}
|
||||
|
||||
// LoadData loads news data from JSON file, or starts with empty store
|
||||
func LoadData(dataPath string) error {
|
||||
ns := GetStore()
|
||||
|
||||
// Load from base store, which handles JSON loading
|
||||
if err := ns.BaseStore.LoadData(dataPath); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// Rebuild indices from loaded data
|
||||
ns.rebuildIndices()
|
||||
return nil
|
||||
}
|
||||
|
||||
// SaveData saves news data to JSON file
|
||||
func SaveData(dataPath string) error {
|
||||
ns := GetStore()
|
||||
return ns.BaseStore.SaveData(dataPath)
|
||||
}
|
||||
|
||||
// rebuildIndicesUnsafe rebuilds all indices from base store data (caller must hold lock)
|
||||
func (ns *NewsStore) rebuildIndicesUnsafe() {
|
||||
// Clear indices
|
||||
ns.byAuthor = make(map[int][]int)
|
||||
ns.allByPosted = make([]int, 0)
|
||||
|
||||
// Collect all news and build indices
|
||||
allNews := ns.GetAll()
|
||||
|
||||
for id, news := range allNews {
|
||||
// Author index
|
||||
ns.byAuthor[news.Author] = append(ns.byAuthor[news.Author], id)
|
||||
|
||||
// All IDs
|
||||
ns.allByPosted = append(ns.allByPosted, id)
|
||||
}
|
||||
|
||||
// Sort allByPosted by posted DESC, then ID DESC
|
||||
sort.Slice(ns.allByPosted, func(i, j int) bool {
|
||||
newsI, _ := ns.GetByID(ns.allByPosted[i])
|
||||
newsJ, _ := ns.GetByID(ns.allByPosted[j])
|
||||
if newsI.Posted != newsJ.Posted {
|
||||
return newsI.Posted > newsJ.Posted // DESC
|
||||
}
|
||||
return ns.allByPosted[i] > ns.allByPosted[j] // DESC
|
||||
})
|
||||
|
||||
// Sort author indices by posted DESC, then ID DESC
|
||||
for author := range ns.byAuthor {
|
||||
sort.Slice(ns.byAuthor[author], func(i, j int) bool {
|
||||
newsI, _ := ns.GetByID(ns.byAuthor[author][i])
|
||||
newsJ, _ := ns.GetByID(ns.byAuthor[author][j])
|
||||
if newsI.Posted != newsJ.Posted {
|
||||
return newsI.Posted > newsJ.Posted // DESC
|
||||
}
|
||||
return ns.byAuthor[author][i] > ns.byAuthor[author][j] // DESC
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// rebuildIndices rebuilds all news-specific indices from base store data
|
||||
func (ns *NewsStore) rebuildIndices() {
|
||||
ns.mu.Lock()
|
||||
defer ns.mu.Unlock()
|
||||
ns.rebuildIndicesUnsafe()
|
||||
}
|
||||
|
||||
// Retrieves a news post by ID
|
||||
func Find(id int) (*News, error) {
|
||||
var news *News
|
||||
|
||||
query := `SELECT ` + newsColumns() + ` FROM news WHERE id = ?`
|
||||
|
||||
err := database.Query(query, func(stmt *sqlite.Stmt) error {
|
||||
news = scanNews(stmt)
|
||||
return nil
|
||||
}, id)
|
||||
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to find news: %w", err)
|
||||
}
|
||||
|
||||
if news == nil {
|
||||
ns := GetStore()
|
||||
news, exists := ns.GetByID(id)
|
||||
if !exists {
|
||||
return nil, fmt.Errorf("news with ID %d not found", id)
|
||||
}
|
||||
|
||||
return news, nil
|
||||
}
|
||||
|
||||
// Retrieves all news posts ordered by posted date (newest first)
|
||||
func All() ([]*News, error) {
|
||||
var newsPosts []*News
|
||||
ns := GetStore()
|
||||
ns.mu.RLock()
|
||||
defer ns.mu.RUnlock()
|
||||
|
||||
query := `SELECT ` + newsColumns() + ` FROM news ORDER BY posted DESC, id DESC`
|
||||
|
||||
err := database.Query(query, func(stmt *sqlite.Stmt) error {
|
||||
news := scanNews(stmt)
|
||||
newsPosts = append(newsPosts, news)
|
||||
return nil
|
||||
})
|
||||
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to retrieve all news: %w", err)
|
||||
result := make([]*News, 0, len(ns.allByPosted))
|
||||
for _, id := range ns.allByPosted {
|
||||
if news, exists := ns.GetByID(id); exists {
|
||||
result = append(result, news)
|
||||
}
|
||||
}
|
||||
|
||||
return newsPosts, nil
|
||||
return result, nil
|
||||
}
|
||||
|
||||
// Retrieves news posts by a specific author
|
||||
func ByAuthor(authorID int) ([]*News, error) {
|
||||
var newsPosts []*News
|
||||
ns := GetStore()
|
||||
ns.mu.RLock()
|
||||
defer ns.mu.RUnlock()
|
||||
|
||||
query := `SELECT ` + newsColumns() + ` FROM news WHERE author = ? ORDER BY posted DESC, id DESC`
|
||||
|
||||
err := database.Query(query, func(stmt *sqlite.Stmt) error {
|
||||
news := scanNews(stmt)
|
||||
newsPosts = append(newsPosts, news)
|
||||
return nil
|
||||
}, authorID)
|
||||
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to retrieve news by author: %w", err)
|
||||
ids, exists := ns.byAuthor[authorID]
|
||||
if !exists {
|
||||
return []*News{}, nil
|
||||
}
|
||||
|
||||
return newsPosts, nil
|
||||
result := make([]*News, 0, len(ids))
|
||||
for _, id := range ids {
|
||||
if news, exists := ns.GetByID(id); exists {
|
||||
result = append(result, news)
|
||||
}
|
||||
}
|
||||
return result, nil
|
||||
}
|
||||
|
||||
// Retrieves the most recent news posts (limited by count)
|
||||
func Recent(limit int) ([]*News, error) {
|
||||
var newsPosts []*News
|
||||
ns := GetStore()
|
||||
ns.mu.RLock()
|
||||
defer ns.mu.RUnlock()
|
||||
|
||||
query := `SELECT ` + newsColumns() + ` FROM news ORDER BY posted DESC, id DESC LIMIT ?`
|
||||
|
||||
err := database.Query(query, func(stmt *sqlite.Stmt) error {
|
||||
news := scanNews(stmt)
|
||||
newsPosts = append(newsPosts, news)
|
||||
return nil
|
||||
}, limit)
|
||||
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to retrieve recent news: %w", err)
|
||||
if limit > len(ns.allByPosted) {
|
||||
limit = len(ns.allByPosted)
|
||||
}
|
||||
|
||||
return newsPosts, nil
|
||||
result := make([]*News, 0, limit)
|
||||
for i := 0; i < limit; i++ {
|
||||
if news, exists := ns.GetByID(ns.allByPosted[i]); exists {
|
||||
result = append(result, news)
|
||||
}
|
||||
}
|
||||
return result, nil
|
||||
}
|
||||
|
||||
// Retrieves news posts since a specific timestamp
|
||||
func Since(since int64) ([]*News, error) {
|
||||
var newsPosts []*News
|
||||
ns := GetStore()
|
||||
ns.mu.RLock()
|
||||
defer ns.mu.RUnlock()
|
||||
|
||||
query := `SELECT ` + newsColumns() + ` FROM news WHERE posted >= ? ORDER BY posted DESC, id DESC`
|
||||
|
||||
err := database.Query(query, func(stmt *sqlite.Stmt) error {
|
||||
news := scanNews(stmt)
|
||||
newsPosts = append(newsPosts, news)
|
||||
return nil
|
||||
}, since)
|
||||
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to retrieve news since timestamp: %w", err)
|
||||
var result []*News
|
||||
for _, id := range ns.allByPosted {
|
||||
if news, exists := ns.GetByID(id); exists && news.Posted >= since {
|
||||
result = append(result, news)
|
||||
}
|
||||
}
|
||||
|
||||
return newsPosts, nil
|
||||
return result, nil
|
||||
}
|
||||
|
||||
// Retrieves news posts between two timestamps (inclusive)
|
||||
func Between(start, end int64) ([]*News, error) {
|
||||
var newsPosts []*News
|
||||
ns := GetStore()
|
||||
ns.mu.RLock()
|
||||
defer ns.mu.RUnlock()
|
||||
|
||||
query := `SELECT ` + newsColumns() + ` FROM news WHERE posted >= ? AND posted <= ? ORDER BY posted DESC, id DESC`
|
||||
|
||||
err := database.Query(query, func(stmt *sqlite.Stmt) error {
|
||||
news := scanNews(stmt)
|
||||
newsPosts = append(newsPosts, news)
|
||||
return nil
|
||||
}, start, end)
|
||||
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to retrieve news between timestamps: %w", err)
|
||||
var result []*News
|
||||
for _, id := range ns.allByPosted {
|
||||
if news, exists := ns.GetByID(id); exists && news.Posted >= start && news.Posted <= end {
|
||||
result = append(result, news)
|
||||
}
|
||||
}
|
||||
|
||||
return newsPosts, nil
|
||||
return result, nil
|
||||
}
|
||||
|
||||
// Saves a new news post to the database and sets the ID
|
||||
// Retrieves news posts containing the search term in content
|
||||
func Search(term string) ([]*News, error) {
|
||||
ns := GetStore()
|
||||
ns.mu.RLock()
|
||||
defer ns.mu.RUnlock()
|
||||
|
||||
var result []*News
|
||||
lowerTerm := strings.ToLower(term)
|
||||
|
||||
for _, id := range ns.allByPosted {
|
||||
if news, exists := ns.GetByID(id); exists {
|
||||
if strings.Contains(strings.ToLower(news.Content), lowerTerm) {
|
||||
result = append(result, news)
|
||||
}
|
||||
}
|
||||
}
|
||||
return result, nil
|
||||
}
|
||||
|
||||
// Saves a new news post to the in-memory store and sets the ID
|
||||
func (n *News) Insert() error {
|
||||
columns := `author, posted, content`
|
||||
values := []any{n.Author, n.Posted, n.Content}
|
||||
return database.Insert(n, columns, values...)
|
||||
ns := GetStore()
|
||||
|
||||
// Validate before insertion
|
||||
if err := n.Validate(); err != nil {
|
||||
return fmt.Errorf("validation failed: %w", err)
|
||||
}
|
||||
|
||||
// Assign new ID if not set
|
||||
if n.ID == 0 {
|
||||
n.ID = ns.GetNextID()
|
||||
}
|
||||
|
||||
// Add to store
|
||||
ns.AddNews(n)
|
||||
return nil
|
||||
}
|
||||
|
||||
// Returns the posted timestamp as a time.Time
|
||||
@ -198,7 +328,7 @@ func (n *News) PostedTime() time.Time {
|
||||
|
||||
// Sets the posted timestamp from a time.Time
|
||||
func (n *News) SetPostedTime(t time.Time) {
|
||||
n.Set("Posted", t.Unix())
|
||||
n.Posted = t.Unix()
|
||||
}
|
||||
|
||||
// Returns true if the news post was made within the last 24 hours
|
||||
@ -276,23 +406,3 @@ func (n *News) Contains(term string) bool {
|
||||
func (n *News) IsEmpty() bool {
|
||||
return strings.TrimSpace(n.Content) == ""
|
||||
}
|
||||
|
||||
// Retrieves news posts containing the search term in content
|
||||
func Search(term string) ([]*News, error) {
|
||||
var newsPosts []*News
|
||||
|
||||
query := `SELECT ` + newsColumns() + ` FROM news WHERE LOWER(content) LIKE LOWER(?) ORDER BY posted DESC, id DESC`
|
||||
searchTerm := "%" + term + "%"
|
||||
|
||||
err := database.Query(query, func(stmt *sqlite.Stmt) error {
|
||||
news := scanNews(stmt)
|
||||
newsPosts = append(newsPosts, news)
|
||||
return nil
|
||||
}, searchTerm)
|
||||
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to search news: %w", err)
|
||||
}
|
||||
|
||||
return newsPosts, nil
|
||||
}
|
||||
|
@ -9,7 +9,6 @@ import (
|
||||
"syscall"
|
||||
|
||||
"dk/internal/auth"
|
||||
"dk/internal/database"
|
||||
"dk/internal/middleware"
|
||||
"dk/internal/monsters"
|
||||
"dk/internal/router"
|
||||
@ -27,14 +26,8 @@ func Start(port string) error {
|
||||
// Initialize template singleton
|
||||
template.InitializeCache(cwd)
|
||||
|
||||
// Initialize database singleton
|
||||
if err := database.Init("dk.db"); err != nil {
|
||||
return fmt.Errorf("failed to initialize database: %w", err)
|
||||
}
|
||||
defer database.Close()
|
||||
|
||||
// Load monster data into memory
|
||||
if err := monsters.LoadData(); err != nil {
|
||||
if err := monsters.LoadData("data/monsters.json"); err != nil {
|
||||
return fmt.Errorf("failed to load monster data: %w", err)
|
||||
}
|
||||
|
||||
@ -104,7 +97,7 @@ func Start(port string) error {
|
||||
|
||||
// Save monster data before shutdown
|
||||
log.Println("Saving monster data...")
|
||||
if err := monsters.SaveData(); err != nil {
|
||||
if err := monsters.SaveData("data/monsters.json"); err != nil {
|
||||
log.Printf("Error saving monster data: %v", err)
|
||||
}
|
||||
|
||||
|
@ -1,47 +1,32 @@
|
||||
package spells
|
||||
|
||||
import (
|
||||
"dk/internal/store"
|
||||
"fmt"
|
||||
|
||||
"dk/internal/database"
|
||||
"dk/internal/helpers/scanner"
|
||||
|
||||
"zombiezen.com/go/sqlite"
|
||||
"sort"
|
||||
"strings"
|
||||
"sync"
|
||||
)
|
||||
|
||||
// Spell represents a spell in the database
|
||||
// Spell represents a spell in the game
|
||||
type Spell struct {
|
||||
database.BaseModel
|
||||
|
||||
ID int `db:"id" json:"id"`
|
||||
Name string `db:"name" json:"name"`
|
||||
MP int `db:"mp" json:"mp"`
|
||||
Attribute int `db:"attribute" json:"attribute"`
|
||||
Type int `db:"type" json:"type"`
|
||||
}
|
||||
|
||||
func (s *Spell) GetTableName() string {
|
||||
return "spells"
|
||||
}
|
||||
|
||||
func (s *Spell) GetID() int {
|
||||
return s.ID
|
||||
}
|
||||
|
||||
func (s *Spell) SetID(id int) {
|
||||
s.ID = id
|
||||
}
|
||||
|
||||
func (s *Spell) Set(field string, value any) error {
|
||||
return database.Set(s, field, value)
|
||||
ID int `json:"id"`
|
||||
Name string `json:"name"`
|
||||
MP int `json:"mp"`
|
||||
Attribute int `json:"attribute"`
|
||||
Type int `json:"type"`
|
||||
}
|
||||
|
||||
func (s *Spell) Save() error {
|
||||
return database.Save(s)
|
||||
spellStore := GetStore()
|
||||
spellStore.UpdateSpell(s)
|
||||
return nil
|
||||
}
|
||||
|
||||
func (s *Spell) Delete() error {
|
||||
return database.Delete(s)
|
||||
spellStore := GetStore()
|
||||
spellStore.RemoveSpell(s.ID)
|
||||
return nil
|
||||
}
|
||||
|
||||
// Creates a new Spell with sensible defaults
|
||||
@ -54,18 +39,21 @@ func New() *Spell {
|
||||
}
|
||||
}
|
||||
|
||||
var spellScanner = scanner.New[Spell]()
|
||||
|
||||
// Returns the column list for spell queries
|
||||
func spellColumns() string {
|
||||
return spellScanner.Columns()
|
||||
}
|
||||
|
||||
// Populates a Spell struct using the fast scanner
|
||||
func scanSpell(stmt *sqlite.Stmt) *Spell {
|
||||
spell := &Spell{}
|
||||
spellScanner.Scan(stmt, spell)
|
||||
return spell
|
||||
// Validate checks if spell has valid values
|
||||
func (s *Spell) Validate() error {
|
||||
if s.Name == "" {
|
||||
return fmt.Errorf("spell name cannot be empty")
|
||||
}
|
||||
if s.MP < 0 {
|
||||
return fmt.Errorf("spell MP cannot be negative")
|
||||
}
|
||||
if s.Attribute < 0 {
|
||||
return fmt.Errorf("spell Attribute cannot be negative")
|
||||
}
|
||||
if s.Type < TypeHealing || s.Type > TypeDefenseBoost {
|
||||
return fmt.Errorf("invalid spell type: %d", s.Type)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// SpellType constants for spell types
|
||||
@ -77,131 +65,293 @@ const (
|
||||
TypeDefenseBoost = 5
|
||||
)
|
||||
|
||||
// SpellStore provides in-memory storage with O(1) lookups and spell-specific indices
|
||||
type SpellStore struct {
|
||||
*store.BaseStore[Spell] // Embedded generic store
|
||||
byType map[int][]int // Type -> []ID
|
||||
byName map[string]int // Name (lowercase) -> ID
|
||||
byMP map[int][]int // MP -> []ID
|
||||
allByTypeMP []int // All IDs sorted by type, MP, ID
|
||||
mu sync.RWMutex // Protects indices
|
||||
}
|
||||
|
||||
// Global in-memory store
|
||||
var spellStore *SpellStore
|
||||
var storeOnce sync.Once
|
||||
|
||||
// Initialize the in-memory store
|
||||
func initStore() {
|
||||
spellStore = &SpellStore{
|
||||
BaseStore: store.NewBaseStore[Spell](),
|
||||
byType: make(map[int][]int),
|
||||
byName: make(map[string]int),
|
||||
byMP: make(map[int][]int),
|
||||
allByTypeMP: make([]int, 0),
|
||||
}
|
||||
}
|
||||
|
||||
// GetStore returns the global spell store
|
||||
func GetStore() *SpellStore {
|
||||
storeOnce.Do(initStore)
|
||||
return spellStore
|
||||
}
|
||||
|
||||
// AddSpell adds a spell to the in-memory store and updates all indices
|
||||
func (ss *SpellStore) AddSpell(spell *Spell) {
|
||||
ss.mu.Lock()
|
||||
defer ss.mu.Unlock()
|
||||
|
||||
// Validate spell
|
||||
if err := spell.Validate(); err != nil {
|
||||
return
|
||||
}
|
||||
|
||||
// Add to base store
|
||||
ss.Add(spell.ID, spell)
|
||||
|
||||
// Rebuild indices
|
||||
ss.rebuildIndicesUnsafe()
|
||||
}
|
||||
|
||||
// RemoveSpell removes a spell from the store and updates indices
|
||||
func (ss *SpellStore) RemoveSpell(id int) {
|
||||
ss.mu.Lock()
|
||||
defer ss.mu.Unlock()
|
||||
|
||||
// Remove from base store
|
||||
ss.Remove(id)
|
||||
|
||||
// Rebuild indices
|
||||
ss.rebuildIndicesUnsafe()
|
||||
}
|
||||
|
||||
// UpdateSpell updates a spell efficiently
|
||||
func (ss *SpellStore) UpdateSpell(spell *Spell) {
|
||||
ss.mu.Lock()
|
||||
defer ss.mu.Unlock()
|
||||
|
||||
// Validate spell
|
||||
if err := spell.Validate(); err != nil {
|
||||
return
|
||||
}
|
||||
|
||||
// Update base store
|
||||
ss.Add(spell.ID, spell)
|
||||
|
||||
// Rebuild indices
|
||||
ss.rebuildIndicesUnsafe()
|
||||
}
|
||||
|
||||
// LoadData loads spell data from JSON file, or starts with empty store
|
||||
func LoadData(dataPath string) error {
|
||||
ss := GetStore()
|
||||
|
||||
// Load from base store, which handles JSON loading
|
||||
if err := ss.BaseStore.LoadData(dataPath); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// Rebuild indices from loaded data
|
||||
ss.rebuildIndices()
|
||||
return nil
|
||||
}
|
||||
|
||||
// SaveData saves spell data to JSON file
|
||||
func SaveData(dataPath string) error {
|
||||
ss := GetStore()
|
||||
return ss.BaseStore.SaveData(dataPath)
|
||||
}
|
||||
|
||||
// rebuildIndicesUnsafe rebuilds all indices from base store data (caller must hold lock)
|
||||
func (ss *SpellStore) rebuildIndicesUnsafe() {
|
||||
// Clear indices
|
||||
ss.byType = make(map[int][]int)
|
||||
ss.byName = make(map[string]int)
|
||||
ss.byMP = make(map[int][]int)
|
||||
ss.allByTypeMP = make([]int, 0)
|
||||
|
||||
// Collect all spells and build indices
|
||||
allSpells := ss.GetAll()
|
||||
|
||||
for id, spell := range allSpells {
|
||||
// Type index
|
||||
ss.byType[spell.Type] = append(ss.byType[spell.Type], id)
|
||||
|
||||
// Name index (case-insensitive)
|
||||
ss.byName[strings.ToLower(spell.Name)] = id
|
||||
|
||||
// MP index
|
||||
ss.byMP[spell.MP] = append(ss.byMP[spell.MP], id)
|
||||
|
||||
// All IDs
|
||||
ss.allByTypeMP = append(ss.allByTypeMP, id)
|
||||
}
|
||||
|
||||
// Sort allByTypeMP by type, then MP, then ID
|
||||
sort.Slice(ss.allByTypeMP, func(i, j int) bool {
|
||||
spellI, _ := ss.GetByID(ss.allByTypeMP[i])
|
||||
spellJ, _ := ss.GetByID(ss.allByTypeMP[j])
|
||||
if spellI.Type != spellJ.Type {
|
||||
return spellI.Type < spellJ.Type
|
||||
}
|
||||
if spellI.MP != spellJ.MP {
|
||||
return spellI.MP < spellJ.MP
|
||||
}
|
||||
return ss.allByTypeMP[i] < ss.allByTypeMP[j]
|
||||
})
|
||||
|
||||
// Sort type indices by MP, then ID
|
||||
for spellType := range ss.byType {
|
||||
sort.Slice(ss.byType[spellType], func(i, j int) bool {
|
||||
spellI, _ := ss.GetByID(ss.byType[spellType][i])
|
||||
spellJ, _ := ss.GetByID(ss.byType[spellType][j])
|
||||
if spellI.MP != spellJ.MP {
|
||||
return spellI.MP < spellJ.MP
|
||||
}
|
||||
return ss.byType[spellType][i] < ss.byType[spellType][j]
|
||||
})
|
||||
}
|
||||
|
||||
// Sort MP indices by type, then ID
|
||||
for mp := range ss.byMP {
|
||||
sort.Slice(ss.byMP[mp], func(i, j int) bool {
|
||||
spellI, _ := ss.GetByID(ss.byMP[mp][i])
|
||||
spellJ, _ := ss.GetByID(ss.byMP[mp][j])
|
||||
if spellI.Type != spellJ.Type {
|
||||
return spellI.Type < spellJ.Type
|
||||
}
|
||||
return ss.byMP[mp][i] < ss.byMP[mp][j]
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// rebuildIndices rebuilds all spell-specific indices from base store data
|
||||
func (ss *SpellStore) rebuildIndices() {
|
||||
ss.mu.Lock()
|
||||
defer ss.mu.Unlock()
|
||||
ss.rebuildIndicesUnsafe()
|
||||
}
|
||||
|
||||
// Retrieves a spell by ID
|
||||
func Find(id int) (*Spell, error) {
|
||||
var spell *Spell
|
||||
|
||||
query := `SELECT ` + spellColumns() + ` FROM spells WHERE id = ?`
|
||||
|
||||
err := database.Query(query, func(stmt *sqlite.Stmt) error {
|
||||
spell = scanSpell(stmt)
|
||||
return nil
|
||||
}, id)
|
||||
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to find spell: %w", err)
|
||||
}
|
||||
|
||||
if spell == nil {
|
||||
ss := GetStore()
|
||||
spell, exists := ss.GetByID(id)
|
||||
if !exists {
|
||||
return nil, fmt.Errorf("spell with ID %d not found", id)
|
||||
}
|
||||
|
||||
return spell, nil
|
||||
}
|
||||
|
||||
// Retrieves all spells
|
||||
func All() ([]*Spell, error) {
|
||||
var spells []*Spell
|
||||
ss := GetStore()
|
||||
ss.mu.RLock()
|
||||
defer ss.mu.RUnlock()
|
||||
|
||||
query := `SELECT ` + spellColumns() + ` FROM spells ORDER BY type, mp, id`
|
||||
|
||||
err := database.Query(query, func(stmt *sqlite.Stmt) error {
|
||||
spell := scanSpell(stmt)
|
||||
spells = append(spells, spell)
|
||||
return nil
|
||||
})
|
||||
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to retrieve all spells: %w", err)
|
||||
result := make([]*Spell, 0, len(ss.allByTypeMP))
|
||||
for _, id := range ss.allByTypeMP {
|
||||
if spell, exists := ss.GetByID(id); exists {
|
||||
result = append(result, spell)
|
||||
}
|
||||
}
|
||||
|
||||
return spells, nil
|
||||
return result, nil
|
||||
}
|
||||
|
||||
// Retrieves spells by type
|
||||
func ByType(spellType int) ([]*Spell, error) {
|
||||
var spells []*Spell
|
||||
ss := GetStore()
|
||||
ss.mu.RLock()
|
||||
defer ss.mu.RUnlock()
|
||||
|
||||
query := `SELECT ` + spellColumns() + ` FROM spells WHERE type = ? ORDER BY mp, id`
|
||||
|
||||
err := database.Query(query, func(stmt *sqlite.Stmt) error {
|
||||
spell := scanSpell(stmt)
|
||||
spells = append(spells, spell)
|
||||
return nil
|
||||
}, spellType)
|
||||
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to retrieve spells by type: %w", err)
|
||||
ids, exists := ss.byType[spellType]
|
||||
if !exists {
|
||||
return []*Spell{}, nil
|
||||
}
|
||||
|
||||
return spells, nil
|
||||
result := make([]*Spell, 0, len(ids))
|
||||
for _, id := range ids {
|
||||
if spell, exists := ss.GetByID(id); exists {
|
||||
result = append(result, spell)
|
||||
}
|
||||
}
|
||||
return result, nil
|
||||
}
|
||||
|
||||
// Retrieves spells that cost at most the specified MP
|
||||
func ByMaxMP(maxMP int) ([]*Spell, error) {
|
||||
var spells []*Spell
|
||||
ss := GetStore()
|
||||
ss.mu.RLock()
|
||||
defer ss.mu.RUnlock()
|
||||
|
||||
query := `SELECT ` + spellColumns() + ` FROM spells WHERE mp <= ? ORDER BY type, mp, id`
|
||||
|
||||
err := database.Query(query, func(stmt *sqlite.Stmt) error {
|
||||
spell := scanSpell(stmt)
|
||||
spells = append(spells, spell)
|
||||
return nil
|
||||
}, maxMP)
|
||||
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to retrieve spells by max MP: %w", err)
|
||||
var result []*Spell
|
||||
for mp := 0; mp <= maxMP; mp++ {
|
||||
if ids, exists := ss.byMP[mp]; exists {
|
||||
for _, id := range ids {
|
||||
if spell, exists := ss.GetByID(id); exists {
|
||||
result = append(result, spell)
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return spells, nil
|
||||
return result, nil
|
||||
}
|
||||
|
||||
// Retrieves spells of a specific type that cost at most the specified MP
|
||||
func ByTypeAndMaxMP(spellType, maxMP int) ([]*Spell, error) {
|
||||
var spells []*Spell
|
||||
ss := GetStore()
|
||||
ss.mu.RLock()
|
||||
defer ss.mu.RUnlock()
|
||||
|
||||
query := `SELECT ` + spellColumns() + ` FROM spells WHERE type = ? AND mp <= ? ORDER BY mp, id`
|
||||
|
||||
err := database.Query(query, func(stmt *sqlite.Stmt) error {
|
||||
spell := scanSpell(stmt)
|
||||
spells = append(spells, spell)
|
||||
return nil
|
||||
}, spellType, maxMP)
|
||||
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to retrieve spells by type and max MP: %w", err)
|
||||
ids, exists := ss.byType[spellType]
|
||||
if !exists {
|
||||
return []*Spell{}, nil
|
||||
}
|
||||
|
||||
return spells, nil
|
||||
var result []*Spell
|
||||
for _, id := range ids {
|
||||
if spell, exists := ss.GetByID(id); exists && spell.MP <= maxMP {
|
||||
result = append(result, spell)
|
||||
}
|
||||
}
|
||||
return result, nil
|
||||
}
|
||||
|
||||
// Retrieves a spell by name (case-insensitive)
|
||||
func ByName(name string) (*Spell, error) {
|
||||
var spell *Spell
|
||||
ss := GetStore()
|
||||
ss.mu.RLock()
|
||||
defer ss.mu.RUnlock()
|
||||
|
||||
query := `SELECT ` + spellColumns() + ` FROM spells WHERE LOWER(name) = LOWER(?) LIMIT 1`
|
||||
|
||||
err := database.Query(query, func(stmt *sqlite.Stmt) error {
|
||||
spell = scanSpell(stmt)
|
||||
return nil
|
||||
}, name)
|
||||
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to find spell by name: %w", err)
|
||||
id, exists := ss.byName[strings.ToLower(name)]
|
||||
if !exists {
|
||||
return nil, fmt.Errorf("spell with name '%s' not found", name)
|
||||
}
|
||||
|
||||
if spell == nil {
|
||||
spell, exists := ss.GetByID(id)
|
||||
if !exists {
|
||||
return nil, fmt.Errorf("spell with name '%s' not found", name)
|
||||
}
|
||||
|
||||
return spell, nil
|
||||
}
|
||||
|
||||
// Saves a new spell to the database and sets the ID
|
||||
// Saves a new spell to the in-memory store and sets the ID
|
||||
func (s *Spell) Insert() error {
|
||||
columns := `name, mp, attribute, type`
|
||||
values := []any{s.Name, s.MP, s.Attribute, s.Type}
|
||||
return database.Insert(s, columns, values...)
|
||||
ss := GetStore()
|
||||
|
||||
// Validate before insertion
|
||||
if err := s.Validate(); err != nil {
|
||||
return fmt.Errorf("validation failed: %w", err)
|
||||
}
|
||||
|
||||
// Assign new ID if not set
|
||||
if s.ID == 0 {
|
||||
s.ID = ss.GetNextID()
|
||||
}
|
||||
|
||||
// Add to store
|
||||
ss.AddSpell(s)
|
||||
return nil
|
||||
}
|
||||
|
||||
// Returns true if the spell is a healing spell
|
||||
|
198
internal/store/store.go
Normal file
198
internal/store/store.go
Normal file
@ -0,0 +1,198 @@
|
||||
package store
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"reflect"
|
||||
"sync"
|
||||
)
|
||||
|
||||
// Store provides generic storage operations
|
||||
type Store[T any] interface {
|
||||
LoadFromJSON(filename string) error
|
||||
SaveToJSON(filename string) error
|
||||
LoadData(dataPath string) error
|
||||
SaveData(dataPath string) error
|
||||
}
|
||||
|
||||
// BaseStore provides generic JSON persistence
|
||||
type BaseStore[T any] struct {
|
||||
items map[int]*T
|
||||
maxID int
|
||||
mu sync.RWMutex
|
||||
itemType reflect.Type
|
||||
}
|
||||
|
||||
// NewBaseStore creates a new base store for type T
|
||||
func NewBaseStore[T any]() *BaseStore[T] {
|
||||
var zero T
|
||||
return &BaseStore[T]{
|
||||
items: make(map[int]*T),
|
||||
maxID: 0,
|
||||
itemType: reflect.TypeOf(zero),
|
||||
}
|
||||
}
|
||||
|
||||
// GetNextID returns the next available ID atomically
|
||||
func (bs *BaseStore[T]) GetNextID() int {
|
||||
bs.mu.Lock()
|
||||
defer bs.mu.Unlock()
|
||||
bs.maxID++
|
||||
return bs.maxID
|
||||
}
|
||||
|
||||
// GetByID retrieves an item by ID
|
||||
func (bs *BaseStore[T]) GetByID(id int) (*T, bool) {
|
||||
bs.mu.RLock()
|
||||
defer bs.mu.RUnlock()
|
||||
item, exists := bs.items[id]
|
||||
return item, exists
|
||||
}
|
||||
|
||||
// Add adds an item to the store
|
||||
func (bs *BaseStore[T]) Add(id int, item *T) {
|
||||
bs.mu.Lock()
|
||||
defer bs.mu.Unlock()
|
||||
bs.items[id] = item
|
||||
if id > bs.maxID {
|
||||
bs.maxID = id
|
||||
}
|
||||
}
|
||||
|
||||
// Remove removes an item from the store
|
||||
func (bs *BaseStore[T]) Remove(id int) {
|
||||
bs.mu.Lock()
|
||||
defer bs.mu.Unlock()
|
||||
delete(bs.items, id)
|
||||
}
|
||||
|
||||
// GetAll returns all items
|
||||
func (bs *BaseStore[T]) GetAll() map[int]*T {
|
||||
bs.mu.RLock()
|
||||
defer bs.mu.RUnlock()
|
||||
result := make(map[int]*T, len(bs.items))
|
||||
for k, v := range bs.items {
|
||||
result[k] = v
|
||||
}
|
||||
return result
|
||||
}
|
||||
|
||||
// Clear removes all items
|
||||
func (bs *BaseStore[T]) Clear() {
|
||||
bs.mu.Lock()
|
||||
defer bs.mu.Unlock()
|
||||
bs.items = make(map[int]*T)
|
||||
bs.maxID = 0
|
||||
}
|
||||
|
||||
// LoadFromJSON loads items from JSON using reflection
|
||||
func (bs *BaseStore[T]) LoadFromJSON(filename string) error {
|
||||
bs.mu.Lock()
|
||||
defer bs.mu.Unlock()
|
||||
|
||||
data, err := os.ReadFile(filename)
|
||||
if err != nil {
|
||||
if os.IsNotExist(err) {
|
||||
return nil
|
||||
}
|
||||
return fmt.Errorf("failed to read JSON: %w", err)
|
||||
}
|
||||
|
||||
if len(data) == 0 {
|
||||
return nil
|
||||
}
|
||||
|
||||
// Create slice of pointers to T
|
||||
sliceType := reflect.SliceOf(reflect.PointerTo(bs.itemType))
|
||||
slicePtr := reflect.New(sliceType)
|
||||
|
||||
if err := json.Unmarshal(data, slicePtr.Interface()); err != nil {
|
||||
return fmt.Errorf("failed to unmarshal JSON: %w", err)
|
||||
}
|
||||
|
||||
// Clear existing data
|
||||
bs.items = make(map[int]*T)
|
||||
bs.maxID = 0
|
||||
|
||||
// Extract items using reflection
|
||||
slice := slicePtr.Elem()
|
||||
for i := 0; i < slice.Len(); i++ {
|
||||
item := slice.Index(i).Interface().(*T)
|
||||
|
||||
// Get ID using reflection
|
||||
itemValue := reflect.ValueOf(item).Elem()
|
||||
idField := itemValue.FieldByName("ID")
|
||||
if !idField.IsValid() {
|
||||
return fmt.Errorf("item type must have an ID field")
|
||||
}
|
||||
|
||||
id := int(idField.Int())
|
||||
bs.items[id] = item
|
||||
if id > bs.maxID {
|
||||
bs.maxID = id
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// SaveToJSON saves items to JSON atomically
|
||||
func (bs *BaseStore[T]) SaveToJSON(filename string) error {
|
||||
bs.mu.RLock()
|
||||
defer bs.mu.RUnlock()
|
||||
|
||||
items := make([]*T, 0, len(bs.items))
|
||||
for _, item := range bs.items {
|
||||
items = append(items, item)
|
||||
}
|
||||
|
||||
data, err := json.MarshalIndent(items, "", " ")
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to marshal to JSON: %w", err)
|
||||
}
|
||||
|
||||
// Atomic write
|
||||
tempFile := filename + ".tmp"
|
||||
if err := os.WriteFile(tempFile, data, 0644); err != nil {
|
||||
return fmt.Errorf("failed to write temp JSON: %w", err)
|
||||
}
|
||||
|
||||
if err := os.Rename(tempFile, filename); err != nil {
|
||||
os.Remove(tempFile)
|
||||
return fmt.Errorf("failed to rename temp JSON: %w", err)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// LoadData loads from JSON file or starts empty
|
||||
func (bs *BaseStore[T]) LoadData(dataPath string) error {
|
||||
if err := bs.LoadFromJSON(dataPath); err != nil {
|
||||
if os.IsNotExist(err) {
|
||||
fmt.Println("No existing data found, starting with empty store")
|
||||
return nil
|
||||
}
|
||||
return fmt.Errorf("failed to load from JSON: %w", err)
|
||||
}
|
||||
|
||||
fmt.Printf("Loaded %d items from JSON\n", len(bs.items))
|
||||
return nil
|
||||
}
|
||||
|
||||
// SaveData saves to JSON file
|
||||
func (bs *BaseStore[T]) SaveData(dataPath string) error {
|
||||
// Ensure directory exists
|
||||
dataDir := filepath.Dir(dataPath)
|
||||
if err := os.MkdirAll(dataDir, 0755); err != nil {
|
||||
return fmt.Errorf("failed to create data directory: %w", err)
|
||||
}
|
||||
|
||||
if err := bs.SaveToJSON(dataPath); err != nil {
|
||||
return fmt.Errorf("failed to save to JSON: %w", err)
|
||||
}
|
||||
|
||||
fmt.Printf("Saved %d items to JSON\n", len(bs.items))
|
||||
return nil
|
||||
}
|
@ -1,53 +1,40 @@
|
||||
package towns
|
||||
|
||||
import (
|
||||
"dk/internal/store"
|
||||
"fmt"
|
||||
"math"
|
||||
"slices"
|
||||
"sort"
|
||||
"strconv"
|
||||
"strings"
|
||||
"sync"
|
||||
|
||||
"dk/internal/database"
|
||||
"dk/internal/helpers"
|
||||
"dk/internal/helpers/scanner"
|
||||
|
||||
"zombiezen.com/go/sqlite"
|
||||
)
|
||||
|
||||
// Town represents a town in the database
|
||||
// Town represents a town in the game
|
||||
type Town struct {
|
||||
database.BaseModel
|
||||
|
||||
ID int `db:"id" json:"id"`
|
||||
Name string `db:"name" json:"name"`
|
||||
X int `db:"x" json:"x"`
|
||||
Y int `db:"y" json:"y"`
|
||||
InnCost int `db:"inn_cost" json:"inn_cost"`
|
||||
MapCost int `db:"map_cost" json:"map_cost"`
|
||||
TPCost int `db:"tp_cost" json:"tp_cost"`
|
||||
ShopList string `db:"shop_list" json:"shop_list"`
|
||||
}
|
||||
|
||||
func (t *Town) GetTableName() string {
|
||||
return "towns"
|
||||
}
|
||||
|
||||
func (t *Town) GetID() int {
|
||||
return t.ID
|
||||
}
|
||||
|
||||
func (t *Town) SetID(id int) {
|
||||
t.ID = id
|
||||
}
|
||||
|
||||
func (t *Town) Set(field string, value any) error {
|
||||
return database.Set(t, field, value)
|
||||
ID int `json:"id"`
|
||||
Name string `json:"name"`
|
||||
X int `json:"x"`
|
||||
Y int `json:"y"`
|
||||
InnCost int `json:"inn_cost"`
|
||||
MapCost int `json:"map_cost"`
|
||||
TPCost int `json:"tp_cost"`
|
||||
ShopList string `json:"shop_list"`
|
||||
}
|
||||
|
||||
func (t *Town) Save() error {
|
||||
return database.Save(t)
|
||||
townStore := GetStore()
|
||||
townStore.UpdateTown(t)
|
||||
return nil
|
||||
}
|
||||
|
||||
func (t *Town) Delete() error {
|
||||
return database.Delete(t)
|
||||
townStore := GetStore()
|
||||
townStore.RemoveTown(t.ID)
|
||||
return nil
|
||||
}
|
||||
|
||||
// Creates a new Town with sensible defaults
|
||||
@ -63,77 +50,212 @@ func New() *Town {
|
||||
}
|
||||
}
|
||||
|
||||
var townScanner = scanner.New[Town]()
|
||||
|
||||
// Returns the column list for town queries
|
||||
func townColumns() string {
|
||||
return townScanner.Columns()
|
||||
// Validate checks if town has valid values
|
||||
func (t *Town) Validate() error {
|
||||
if t.Name == "" {
|
||||
return fmt.Errorf("town name cannot be empty")
|
||||
}
|
||||
if t.InnCost < 0 {
|
||||
return fmt.Errorf("town InnCost cannot be negative")
|
||||
}
|
||||
if t.MapCost < 0 {
|
||||
return fmt.Errorf("town MapCost cannot be negative")
|
||||
}
|
||||
if t.TPCost < 0 {
|
||||
return fmt.Errorf("town TPCost cannot be negative")
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// Populates a Town struct using the fast scanner
|
||||
func scanTown(stmt *sqlite.Stmt) *Town {
|
||||
town := &Town{}
|
||||
townScanner.Scan(stmt, town)
|
||||
return town
|
||||
// TownStore provides in-memory storage with O(1) lookups and town-specific indices
|
||||
type TownStore struct {
|
||||
*store.BaseStore[Town] // Embedded generic store
|
||||
byName map[string]int // Name (lowercase) -> ID
|
||||
byCoords map[string]int // "x,y" -> ID
|
||||
byInnCost map[int][]int // InnCost -> []ID
|
||||
byTPCost map[int][]int // TPCost -> []ID
|
||||
allByID []int // All IDs sorted by ID
|
||||
mu sync.RWMutex // Protects indices
|
||||
}
|
||||
|
||||
// Global in-memory store
|
||||
var townStore *TownStore
|
||||
var storeOnce sync.Once
|
||||
|
||||
// Initialize the in-memory store
|
||||
func initStore() {
|
||||
townStore = &TownStore{
|
||||
BaseStore: store.NewBaseStore[Town](),
|
||||
byName: make(map[string]int),
|
||||
byCoords: make(map[string]int),
|
||||
byInnCost: make(map[int][]int),
|
||||
byTPCost: make(map[int][]int),
|
||||
allByID: make([]int, 0),
|
||||
}
|
||||
}
|
||||
|
||||
// GetStore returns the global town store
|
||||
func GetStore() *TownStore {
|
||||
storeOnce.Do(initStore)
|
||||
return townStore
|
||||
}
|
||||
|
||||
// AddTown adds a town to the in-memory store and updates all indices
|
||||
func (ts *TownStore) AddTown(town *Town) {
|
||||
ts.mu.Lock()
|
||||
defer ts.mu.Unlock()
|
||||
|
||||
// Validate town
|
||||
if err := town.Validate(); err != nil {
|
||||
return
|
||||
}
|
||||
|
||||
// Add to base store
|
||||
ts.Add(town.ID, town)
|
||||
|
||||
// Rebuild indices
|
||||
ts.rebuildIndicesUnsafe()
|
||||
}
|
||||
|
||||
// RemoveTown removes a town from the store and updates indices
|
||||
func (ts *TownStore) RemoveTown(id int) {
|
||||
ts.mu.Lock()
|
||||
defer ts.mu.Unlock()
|
||||
|
||||
// Remove from base store
|
||||
ts.Remove(id)
|
||||
|
||||
// Rebuild indices
|
||||
ts.rebuildIndicesUnsafe()
|
||||
}
|
||||
|
||||
// UpdateTown updates a town efficiently
|
||||
func (ts *TownStore) UpdateTown(town *Town) {
|
||||
ts.mu.Lock()
|
||||
defer ts.mu.Unlock()
|
||||
|
||||
// Validate town
|
||||
if err := town.Validate(); err != nil {
|
||||
return
|
||||
}
|
||||
|
||||
// Update base store
|
||||
ts.Add(town.ID, town)
|
||||
|
||||
// Rebuild indices
|
||||
ts.rebuildIndicesUnsafe()
|
||||
}
|
||||
|
||||
// LoadData loads town data from JSON file, or starts with empty store
|
||||
func LoadData(dataPath string) error {
|
||||
ts := GetStore()
|
||||
|
||||
// Load from base store, which handles JSON loading
|
||||
if err := ts.BaseStore.LoadData(dataPath); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// Rebuild indices from loaded data
|
||||
ts.rebuildIndices()
|
||||
return nil
|
||||
}
|
||||
|
||||
// SaveData saves town data to JSON file
|
||||
func SaveData(dataPath string) error {
|
||||
ts := GetStore()
|
||||
return ts.BaseStore.SaveData(dataPath)
|
||||
}
|
||||
|
||||
// coordsKey creates a key for coordinate-based lookup
|
||||
func coordsKey(x, y int) string {
|
||||
return strconv.Itoa(x) + "," + strconv.Itoa(y)
|
||||
}
|
||||
|
||||
// rebuildIndicesUnsafe rebuilds all indices from base store data (caller must hold lock)
|
||||
func (ts *TownStore) rebuildIndicesUnsafe() {
|
||||
// Clear indices
|
||||
ts.byName = make(map[string]int)
|
||||
ts.byCoords = make(map[string]int)
|
||||
ts.byInnCost = make(map[int][]int)
|
||||
ts.byTPCost = make(map[int][]int)
|
||||
ts.allByID = make([]int, 0)
|
||||
|
||||
// Collect all towns and build indices
|
||||
allTowns := ts.GetAll()
|
||||
|
||||
for id, town := range allTowns {
|
||||
// Name index (case-insensitive)
|
||||
ts.byName[strings.ToLower(town.Name)] = id
|
||||
|
||||
// Coordinates index
|
||||
ts.byCoords[coordsKey(town.X, town.Y)] = id
|
||||
|
||||
// Cost indices
|
||||
ts.byInnCost[town.InnCost] = append(ts.byInnCost[town.InnCost], id)
|
||||
ts.byTPCost[town.TPCost] = append(ts.byTPCost[town.TPCost], id)
|
||||
|
||||
// All IDs
|
||||
ts.allByID = append(ts.allByID, id)
|
||||
}
|
||||
|
||||
// Sort all by ID
|
||||
sort.Ints(ts.allByID)
|
||||
|
||||
// Sort cost indices by ID
|
||||
for innCost := range ts.byInnCost {
|
||||
sort.Ints(ts.byInnCost[innCost])
|
||||
}
|
||||
|
||||
for tpCost := range ts.byTPCost {
|
||||
sort.Ints(ts.byTPCost[tpCost])
|
||||
}
|
||||
}
|
||||
|
||||
// rebuildIndices rebuilds all town-specific indices from base store data
|
||||
func (ts *TownStore) rebuildIndices() {
|
||||
ts.mu.Lock()
|
||||
defer ts.mu.Unlock()
|
||||
ts.rebuildIndicesUnsafe()
|
||||
}
|
||||
|
||||
// Retrieves a town by ID
|
||||
func Find(id int) (*Town, error) {
|
||||
var town *Town
|
||||
|
||||
query := `SELECT ` + townColumns() + ` FROM towns WHERE id = ?`
|
||||
|
||||
err := database.Query(query, func(stmt *sqlite.Stmt) error {
|
||||
town = scanTown(stmt)
|
||||
return nil
|
||||
}, id)
|
||||
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to find town: %w", err)
|
||||
}
|
||||
|
||||
if town == nil {
|
||||
ts := GetStore()
|
||||
town, exists := ts.GetByID(id)
|
||||
if !exists {
|
||||
return nil, fmt.Errorf("town with ID %d not found", id)
|
||||
}
|
||||
|
||||
return town, nil
|
||||
}
|
||||
|
||||
// Retrieves all towns
|
||||
func All() ([]*Town, error) {
|
||||
var towns []*Town
|
||||
ts := GetStore()
|
||||
ts.mu.RLock()
|
||||
defer ts.mu.RUnlock()
|
||||
|
||||
query := `SELECT ` + townColumns() + ` FROM towns ORDER BY id`
|
||||
|
||||
err := database.Query(query, func(stmt *sqlite.Stmt) error {
|
||||
town := scanTown(stmt)
|
||||
towns = append(towns, town)
|
||||
return nil
|
||||
})
|
||||
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to retrieve all towns: %w", err)
|
||||
result := make([]*Town, 0, len(ts.allByID))
|
||||
for _, id := range ts.allByID {
|
||||
if town, exists := ts.GetByID(id); exists {
|
||||
result = append(result, town)
|
||||
}
|
||||
}
|
||||
|
||||
return towns, nil
|
||||
return result, nil
|
||||
}
|
||||
|
||||
// Retrieves a town by name (case-insensitive)
|
||||
func ByName(name string) (*Town, error) {
|
||||
var town *Town
|
||||
ts := GetStore()
|
||||
ts.mu.RLock()
|
||||
defer ts.mu.RUnlock()
|
||||
|
||||
query := `SELECT ` + townColumns() + ` FROM towns WHERE LOWER(name) = LOWER(?) LIMIT 1`
|
||||
|
||||
err := database.Query(query, func(stmt *sqlite.Stmt) error {
|
||||
town = scanTown(stmt)
|
||||
return nil
|
||||
}, name)
|
||||
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to find town by name: %w", err)
|
||||
id, exists := ts.byName[strings.ToLower(name)]
|
||||
if !exists {
|
||||
return nil, fmt.Errorf("town with name '%s' not found", name)
|
||||
}
|
||||
|
||||
if town == nil {
|
||||
town, exists := ts.GetByID(id)
|
||||
if !exists {
|
||||
return nil, fmt.Errorf("town with name '%s' not found", name)
|
||||
}
|
||||
|
||||
@ -142,55 +264,56 @@ func ByName(name string) (*Town, error) {
|
||||
|
||||
// Retrieves towns with inn cost at most the specified amount
|
||||
func ByMaxInnCost(maxCost int) ([]*Town, error) {
|
||||
var towns []*Town
|
||||
ts := GetStore()
|
||||
ts.mu.RLock()
|
||||
defer ts.mu.RUnlock()
|
||||
|
||||
query := `SELECT ` + townColumns() + ` FROM towns WHERE inn_cost <= ? ORDER BY inn_cost, id`
|
||||
|
||||
err := database.Query(query, func(stmt *sqlite.Stmt) error {
|
||||
town := scanTown(stmt)
|
||||
towns = append(towns, town)
|
||||
return nil
|
||||
}, maxCost)
|
||||
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to retrieve towns by max inn cost: %w", err)
|
||||
var result []*Town
|
||||
for cost := 0; cost <= maxCost; cost++ {
|
||||
if ids, exists := ts.byInnCost[cost]; exists {
|
||||
for _, id := range ids {
|
||||
if town, exists := ts.GetByID(id); exists {
|
||||
result = append(result, town)
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return towns, nil
|
||||
return result, nil
|
||||
}
|
||||
|
||||
// Retrieves towns with teleport cost at most the specified amount
|
||||
func ByMaxTPCost(maxCost int) ([]*Town, error) {
|
||||
var towns []*Town
|
||||
ts := GetStore()
|
||||
ts.mu.RLock()
|
||||
defer ts.mu.RUnlock()
|
||||
|
||||
query := `SELECT ` + townColumns() + ` FROM towns WHERE tp_cost <= ? ORDER BY tp_cost, id`
|
||||
|
||||
err := database.Query(query, func(stmt *sqlite.Stmt) error {
|
||||
town := scanTown(stmt)
|
||||
towns = append(towns, town)
|
||||
return nil
|
||||
}, maxCost)
|
||||
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to retrieve towns by max TP cost: %w", err)
|
||||
var result []*Town
|
||||
for cost := 0; cost <= maxCost; cost++ {
|
||||
if ids, exists := ts.byTPCost[cost]; exists {
|
||||
for _, id := range ids {
|
||||
if town, exists := ts.GetByID(id); exists {
|
||||
result = append(result, town)
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return towns, nil
|
||||
return result, nil
|
||||
}
|
||||
|
||||
// Retrieves a town by its x, y coordinates
|
||||
func ByCoords(x, y int) (*Town, error) {
|
||||
var town *Town
|
||||
ts := GetStore()
|
||||
ts.mu.RLock()
|
||||
defer ts.mu.RUnlock()
|
||||
|
||||
query := `SELECT ` + townColumns() + ` FROM towns WHERE x = ? AND y = ? LIMIT 1`
|
||||
id, exists := ts.byCoords[coordsKey(x, y)]
|
||||
if !exists {
|
||||
return nil, nil // Return nil if not found (like original)
|
||||
}
|
||||
|
||||
err := database.Query(query, func(stmt *sqlite.Stmt) error {
|
||||
town = scanTown(stmt)
|
||||
return nil
|
||||
}, x, y)
|
||||
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to retrieve town by coordinates: %w", err)
|
||||
town, exists := ts.GetByID(id)
|
||||
if !exists {
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
return town, nil
|
||||
@ -198,46 +321,61 @@ func ByCoords(x, y int) (*Town, error) {
|
||||
|
||||
// ExistsAt checks for a town at the given coordinates, returning true/false
|
||||
func ExistsAt(x, y int) bool {
|
||||
var exists bool
|
||||
ts := GetStore()
|
||||
ts.mu.RLock()
|
||||
defer ts.mu.RUnlock()
|
||||
|
||||
query := `SELECT COUNT(*) > 0 FROM towns WHERE x = ? AND y = ? LIMIT 1`
|
||||
|
||||
err := database.Query(query, func(stmt *sqlite.Stmt) error {
|
||||
exists = stmt.ColumnInt(0) > 0
|
||||
return nil
|
||||
}, x, y)
|
||||
|
||||
return err == nil && exists
|
||||
_, exists := ts.byCoords[coordsKey(x, y)]
|
||||
return exists
|
||||
}
|
||||
|
||||
// Retrieves towns within a certain distance from a point
|
||||
func ByDistance(fromX, fromY, maxDistance int) ([]*Town, error) {
|
||||
var towns []*Town
|
||||
ts := GetStore()
|
||||
ts.mu.RLock()
|
||||
defer ts.mu.RUnlock()
|
||||
|
||||
query := `SELECT ` + townColumns() + `
|
||||
FROM towns
|
||||
WHERE ((x - ?) * (x - ?) + (y - ?) * (y - ?)) <= ?
|
||||
ORDER BY ((x - ?) * (x - ?) + (y - ?) * (y - ?)), id`
|
||||
var result []*Town
|
||||
maxDistance2 := float64(maxDistance * maxDistance)
|
||||
|
||||
maxDistance2 := maxDistance * maxDistance
|
||||
err := database.Query(query, func(stmt *sqlite.Stmt) error {
|
||||
town := scanTown(stmt)
|
||||
towns = append(towns, town)
|
||||
return nil
|
||||
}, fromX, fromX, fromY, fromY, maxDistance2, fromX, fromX, fromY, fromY)
|
||||
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to retrieve towns by distance: %w", err)
|
||||
for _, id := range ts.allByID {
|
||||
if town, exists := ts.GetByID(id); exists {
|
||||
if town.DistanceFromSquared(fromX, fromY) <= maxDistance2 {
|
||||
result = append(result, town)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return towns, nil
|
||||
// Sort by distance, then by ID
|
||||
sort.Slice(result, func(i, j int) bool {
|
||||
distI := result[i].DistanceFromSquared(fromX, fromY)
|
||||
distJ := result[j].DistanceFromSquared(fromX, fromY)
|
||||
if distI == distJ {
|
||||
return result[i].ID < result[j].ID
|
||||
}
|
||||
return distI < distJ
|
||||
})
|
||||
|
||||
return result, nil
|
||||
}
|
||||
|
||||
// Saves a new town to the database and sets the ID
|
||||
// Saves a new town to the in-memory store and sets the ID
|
||||
func (t *Town) Insert() error {
|
||||
columns := `name, x, y, inn_cost, map_cost, tp_cost, shop_list`
|
||||
values := []any{t.Name, t.X, t.Y, t.InnCost, t.MapCost, t.TPCost, t.ShopList}
|
||||
return database.Insert(t, columns, values...)
|
||||
ts := GetStore()
|
||||
|
||||
// Validate before insertion
|
||||
if err := t.Validate(); err != nil {
|
||||
return fmt.Errorf("validation failed: %w", err)
|
||||
}
|
||||
|
||||
// Assign new ID if not set
|
||||
if t.ID == 0 {
|
||||
t.ID = ts.GetNextID()
|
||||
}
|
||||
|
||||
// Add to store
|
||||
ts.AddTown(t)
|
||||
return nil
|
||||
}
|
||||
|
||||
// Returns the shop items as a slice of item IDs
|
||||
@ -247,7 +385,7 @@ func (t *Town) GetShopItems() []int {
|
||||
|
||||
// Sets the shop items from a slice of item IDs
|
||||
func (t *Town) SetShopItems(items []int) {
|
||||
t.Set("ShopList", helpers.IntsToString(items))
|
||||
t.ShopList = helpers.IntsToString(items)
|
||||
}
|
||||
|
||||
// Checks if the town's shop sells a specific item ID
|
||||
@ -299,6 +437,6 @@ func (t *Town) GetPosition() (int, int) {
|
||||
|
||||
// Sets the town's coordinates
|
||||
func (t *Town) SetPosition(x, y int) {
|
||||
t.Set("X", x)
|
||||
t.Set("Y", y)
|
||||
t.X = x
|
||||
t.Y = y
|
||||
}
|
||||
|
@ -1,95 +1,81 @@
|
||||
package users
|
||||
|
||||
import (
|
||||
"dk/internal/store"
|
||||
"fmt"
|
||||
"slices"
|
||||
"sort"
|
||||
"strings"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"dk/internal/database"
|
||||
"dk/internal/helpers"
|
||||
"dk/internal/helpers/scanner"
|
||||
|
||||
"zombiezen.com/go/sqlite"
|
||||
)
|
||||
|
||||
// User represents a user in the database
|
||||
// User represents a user in the game
|
||||
type User struct {
|
||||
database.BaseModel
|
||||
|
||||
ID int `db:"id" json:"id"`
|
||||
Username string `db:"username" json:"username"`
|
||||
Password string `db:"password" json:"password"`
|
||||
Email string `db:"email" json:"email"`
|
||||
Verified int `db:"verified" json:"verified"`
|
||||
Token string `db:"token" json:"token"`
|
||||
Registered int64 `db:"registered" json:"registered"`
|
||||
LastOnline int64 `db:"last_online" json:"last_online"`
|
||||
Auth int `db:"auth" json:"auth"`
|
||||
X int `db:"x" json:"x"`
|
||||
Y int `db:"y" json:"y"`
|
||||
ClassID int `db:"class_id" json:"class_id"`
|
||||
Currently string `db:"currently" json:"currently"`
|
||||
Fighting int `db:"fighting" json:"fighting"`
|
||||
MonsterID int `db:"monster_id" json:"monster_id"`
|
||||
MonsterHP int `db:"monster_hp" json:"monster_hp"`
|
||||
MonsterSleep int `db:"monster_sleep" json:"monster_sleep"`
|
||||
MonsterImmune int `db:"monster_immune" json:"monster_immune"`
|
||||
UberDamage int `db:"uber_damage" json:"uber_damage"`
|
||||
UberDefense int `db:"uber_defense" json:"uber_defense"`
|
||||
HP int `db:"hp" json:"hp"`
|
||||
MP int `db:"mp" json:"mp"`
|
||||
TP int `db:"tp" json:"tp"`
|
||||
MaxHP int `db:"max_hp" json:"max_hp"`
|
||||
MaxMP int `db:"max_mp" json:"max_mp"`
|
||||
MaxTP int `db:"max_tp" json:"max_tp"`
|
||||
Level int `db:"level" json:"level"`
|
||||
Gold int `db:"gold" json:"gold"`
|
||||
Exp int `db:"exp" json:"exp"`
|
||||
GoldBonus int `db:"gold_bonus" json:"gold_bonus"`
|
||||
ExpBonus int `db:"exp_bonus" json:"exp_bonus"`
|
||||
Strength int `db:"strength" json:"strength"`
|
||||
Dexterity int `db:"dexterity" json:"dexterity"`
|
||||
Attack int `db:"attack" json:"attack"`
|
||||
Defense int `db:"defense" json:"defense"`
|
||||
WeaponID int `db:"weapon_id" json:"weapon_id"`
|
||||
ArmorID int `db:"armor_id" json:"armor_id"`
|
||||
ShieldID int `db:"shield_id" json:"shield_id"`
|
||||
Slot1ID int `db:"slot_1_id" json:"slot_1_id"`
|
||||
Slot2ID int `db:"slot_2_id" json:"slot_2_id"`
|
||||
Slot3ID int `db:"slot_3_id" json:"slot_3_id"`
|
||||
WeaponName string `db:"weapon_name" json:"weapon_name"`
|
||||
ArmorName string `db:"armor_name" json:"armor_name"`
|
||||
ShieldName string `db:"shield_name" json:"shield_name"`
|
||||
Slot1Name string `db:"slot_1_name" json:"slot_1_name"`
|
||||
Slot2Name string `db:"slot_2_name" json:"slot_2_name"`
|
||||
Slot3Name string `db:"slot_3_name" json:"slot_3_name"`
|
||||
DropCode int `db:"drop_code" json:"drop_code"`
|
||||
Spells string `db:"spells" json:"spells"`
|
||||
Towns string `db:"towns" json:"towns"`
|
||||
}
|
||||
|
||||
func (u *User) GetTableName() string {
|
||||
return "users"
|
||||
}
|
||||
|
||||
func (u *User) GetID() int {
|
||||
return u.ID
|
||||
}
|
||||
|
||||
func (u *User) SetID(id int) {
|
||||
u.ID = id
|
||||
}
|
||||
|
||||
func (u *User) Set(field string, value any) error {
|
||||
return database.Set(u, field, value)
|
||||
ID int `json:"id"`
|
||||
Username string `json:"username"`
|
||||
Password string `json:"password"`
|
||||
Email string `json:"email"`
|
||||
Verified int `json:"verified"`
|
||||
Token string `json:"token"`
|
||||
Registered int64 `json:"registered"`
|
||||
LastOnline int64 `json:"last_online"`
|
||||
Auth int `json:"auth"`
|
||||
X int `json:"x"`
|
||||
Y int `json:"y"`
|
||||
ClassID int `json:"class_id"`
|
||||
Currently string `json:"currently"`
|
||||
Fighting int `json:"fighting"`
|
||||
MonsterID int `json:"monster_id"`
|
||||
MonsterHP int `json:"monster_hp"`
|
||||
MonsterSleep int `json:"monster_sleep"`
|
||||
MonsterImmune int `json:"monster_immune"`
|
||||
UberDamage int `json:"uber_damage"`
|
||||
UberDefense int `json:"uber_defense"`
|
||||
HP int `json:"hp"`
|
||||
MP int `json:"mp"`
|
||||
TP int `json:"tp"`
|
||||
MaxHP int `json:"max_hp"`
|
||||
MaxMP int `json:"max_mp"`
|
||||
MaxTP int `json:"max_tp"`
|
||||
Level int `json:"level"`
|
||||
Gold int `json:"gold"`
|
||||
Exp int `json:"exp"`
|
||||
GoldBonus int `json:"gold_bonus"`
|
||||
ExpBonus int `json:"exp_bonus"`
|
||||
Strength int `json:"strength"`
|
||||
Dexterity int `json:"dexterity"`
|
||||
Attack int `json:"attack"`
|
||||
Defense int `json:"defense"`
|
||||
WeaponID int `json:"weapon_id"`
|
||||
ArmorID int `json:"armor_id"`
|
||||
ShieldID int `json:"shield_id"`
|
||||
Slot1ID int `json:"slot_1_id"`
|
||||
Slot2ID int `json:"slot_2_id"`
|
||||
Slot3ID int `json:"slot_3_id"`
|
||||
WeaponName string `json:"weapon_name"`
|
||||
ArmorName string `json:"armor_name"`
|
||||
ShieldName string `json:"shield_name"`
|
||||
Slot1Name string `json:"slot_1_name"`
|
||||
Slot2Name string `json:"slot_2_name"`
|
||||
Slot3Name string `json:"slot_3_name"`
|
||||
DropCode int `json:"drop_code"`
|
||||
Spells string `json:"spells"`
|
||||
Towns string `json:"towns"`
|
||||
}
|
||||
|
||||
func (u *User) Save() error {
|
||||
return database.Save(u)
|
||||
userStore := GetStore()
|
||||
userStore.UpdateUser(u)
|
||||
return nil
|
||||
}
|
||||
|
||||
func (u *User) Delete() error {
|
||||
return database.Delete(u)
|
||||
userStore := GetStore()
|
||||
userStore.RemoveUser(u.ID)
|
||||
return nil
|
||||
}
|
||||
|
||||
func New() *User {
|
||||
@ -123,127 +109,303 @@ func New() *User {
|
||||
}
|
||||
}
|
||||
|
||||
var userScanner = scanner.New[User]()
|
||||
|
||||
func userColumns() string {
|
||||
return userScanner.Columns()
|
||||
// Validate checks if user has valid values
|
||||
func (u *User) Validate() error {
|
||||
if strings.TrimSpace(u.Username) == "" {
|
||||
return fmt.Errorf("user username cannot be empty")
|
||||
}
|
||||
if strings.TrimSpace(u.Email) == "" {
|
||||
return fmt.Errorf("user email cannot be empty")
|
||||
}
|
||||
if u.Registered <= 0 {
|
||||
return fmt.Errorf("user Registered timestamp must be positive")
|
||||
}
|
||||
if u.LastOnline <= 0 {
|
||||
return fmt.Errorf("user LastOnline timestamp must be positive")
|
||||
}
|
||||
if u.Level < 1 {
|
||||
return fmt.Errorf("user Level must be at least 1")
|
||||
}
|
||||
if u.HP < 0 {
|
||||
return fmt.Errorf("user HP cannot be negative")
|
||||
}
|
||||
if u.MaxHP < 1 {
|
||||
return fmt.Errorf("user MaxHP must be at least 1")
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func scanUser(stmt *sqlite.Stmt) *User {
|
||||
user := &User{}
|
||||
userScanner.Scan(stmt, user)
|
||||
return user
|
||||
// UserStore provides in-memory storage with O(1) lookups and user-specific indices
|
||||
type UserStore struct {
|
||||
*store.BaseStore[User] // Embedded generic store
|
||||
byUsername map[string]int // Username (lowercase) -> ID
|
||||
byEmail map[string]int // Email -> ID
|
||||
byLevel map[int][]int // Level -> []ID
|
||||
allByRegistered []int // All IDs sorted by registered DESC, id DESC
|
||||
mu sync.RWMutex // Protects indices
|
||||
}
|
||||
|
||||
// Global in-memory store
|
||||
var userStore *UserStore
|
||||
var storeOnce sync.Once
|
||||
|
||||
// Initialize the in-memory store
|
||||
func initStore() {
|
||||
userStore = &UserStore{
|
||||
BaseStore: store.NewBaseStore[User](),
|
||||
byUsername: make(map[string]int),
|
||||
byEmail: make(map[string]int),
|
||||
byLevel: make(map[int][]int),
|
||||
allByRegistered: make([]int, 0),
|
||||
}
|
||||
}
|
||||
|
||||
// GetStore returns the global user store
|
||||
func GetStore() *UserStore {
|
||||
storeOnce.Do(initStore)
|
||||
return userStore
|
||||
}
|
||||
|
||||
// AddUser adds a user to the in-memory store and updates all indices
|
||||
func (us *UserStore) AddUser(user *User) {
|
||||
us.mu.Lock()
|
||||
defer us.mu.Unlock()
|
||||
|
||||
// Validate user
|
||||
if err := user.Validate(); err != nil {
|
||||
return
|
||||
}
|
||||
|
||||
// Add to base store
|
||||
us.Add(user.ID, user)
|
||||
|
||||
// Rebuild indices
|
||||
us.rebuildIndicesUnsafe()
|
||||
}
|
||||
|
||||
// RemoveUser removes a user from the store and updates indices
|
||||
func (us *UserStore) RemoveUser(id int) {
|
||||
us.mu.Lock()
|
||||
defer us.mu.Unlock()
|
||||
|
||||
// Remove from base store
|
||||
us.Remove(id)
|
||||
|
||||
// Rebuild indices
|
||||
us.rebuildIndicesUnsafe()
|
||||
}
|
||||
|
||||
// UpdateUser updates a user efficiently
|
||||
func (us *UserStore) UpdateUser(user *User) {
|
||||
us.mu.Lock()
|
||||
defer us.mu.Unlock()
|
||||
|
||||
// Validate user
|
||||
if err := user.Validate(); err != nil {
|
||||
return
|
||||
}
|
||||
|
||||
// Update base store
|
||||
us.Add(user.ID, user)
|
||||
|
||||
// Rebuild indices
|
||||
us.rebuildIndicesUnsafe()
|
||||
}
|
||||
|
||||
// LoadData loads user data from JSON file, or starts with empty store
|
||||
func LoadData(dataPath string) error {
|
||||
us := GetStore()
|
||||
|
||||
// Load from base store, which handles JSON loading
|
||||
if err := us.BaseStore.LoadData(dataPath); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// Rebuild indices from loaded data
|
||||
us.rebuildIndices()
|
||||
return nil
|
||||
}
|
||||
|
||||
// SaveData saves user data to JSON file
|
||||
func SaveData(dataPath string) error {
|
||||
us := GetStore()
|
||||
return us.BaseStore.SaveData(dataPath)
|
||||
}
|
||||
|
||||
// rebuildIndicesUnsafe rebuilds all indices from base store data (caller must hold lock)
|
||||
func (us *UserStore) rebuildIndicesUnsafe() {
|
||||
// Clear indices
|
||||
us.byUsername = make(map[string]int)
|
||||
us.byEmail = make(map[string]int)
|
||||
us.byLevel = make(map[int][]int)
|
||||
us.allByRegistered = make([]int, 0)
|
||||
|
||||
// Collect all users and build indices
|
||||
allUsers := us.GetAll()
|
||||
|
||||
for id, user := range allUsers {
|
||||
// Username index (case-insensitive)
|
||||
us.byUsername[strings.ToLower(user.Username)] = id
|
||||
|
||||
// Email index
|
||||
us.byEmail[user.Email] = id
|
||||
|
||||
// Level index
|
||||
us.byLevel[user.Level] = append(us.byLevel[user.Level], id)
|
||||
|
||||
// All IDs
|
||||
us.allByRegistered = append(us.allByRegistered, id)
|
||||
}
|
||||
|
||||
// Sort allByRegistered by registered DESC, then ID DESC
|
||||
sort.Slice(us.allByRegistered, func(i, j int) bool {
|
||||
userI, _ := us.GetByID(us.allByRegistered[i])
|
||||
userJ, _ := us.GetByID(us.allByRegistered[j])
|
||||
if userI.Registered != userJ.Registered {
|
||||
return userI.Registered > userJ.Registered // DESC
|
||||
}
|
||||
return us.allByRegistered[i] > us.allByRegistered[j] // DESC
|
||||
})
|
||||
|
||||
// Sort level indices by exp DESC, then ID ASC
|
||||
for level := range us.byLevel {
|
||||
sort.Slice(us.byLevel[level], func(i, j int) bool {
|
||||
userI, _ := us.GetByID(us.byLevel[level][i])
|
||||
userJ, _ := us.GetByID(us.byLevel[level][j])
|
||||
if userI.Exp != userJ.Exp {
|
||||
return userI.Exp > userJ.Exp // DESC
|
||||
}
|
||||
return us.byLevel[level][i] < us.byLevel[level][j] // ASC
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// rebuildIndices rebuilds all user-specific indices from base store data
|
||||
func (us *UserStore) rebuildIndices() {
|
||||
us.mu.Lock()
|
||||
defer us.mu.Unlock()
|
||||
us.rebuildIndicesUnsafe()
|
||||
}
|
||||
|
||||
func Find(id int) (*User, error) {
|
||||
var user *User
|
||||
query := `SELECT ` + userColumns() + ` FROM users WHERE id = ?`
|
||||
err := database.Query(query, func(stmt *sqlite.Stmt) error {
|
||||
user = scanUser(stmt)
|
||||
return nil
|
||||
}, id)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to find user: %w", err)
|
||||
}
|
||||
if user == nil {
|
||||
us := GetStore()
|
||||
user, exists := us.GetByID(id)
|
||||
if !exists {
|
||||
return nil, fmt.Errorf("user with ID %d not found", id)
|
||||
}
|
||||
return user, nil
|
||||
}
|
||||
|
||||
func All() ([]*User, error) {
|
||||
var users []*User
|
||||
query := `SELECT ` + userColumns() + ` FROM users ORDER BY registered DESC, id DESC`
|
||||
err := database.Query(query, func(stmt *sqlite.Stmt) error {
|
||||
user := scanUser(stmt)
|
||||
users = append(users, user)
|
||||
return nil
|
||||
})
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to retrieve all users: %w", err)
|
||||
us := GetStore()
|
||||
us.mu.RLock()
|
||||
defer us.mu.RUnlock()
|
||||
|
||||
result := make([]*User, 0, len(us.allByRegistered))
|
||||
for _, id := range us.allByRegistered {
|
||||
if user, exists := us.GetByID(id); exists {
|
||||
result = append(result, user)
|
||||
}
|
||||
}
|
||||
return users, nil
|
||||
return result, nil
|
||||
}
|
||||
|
||||
func ByUsername(username string) (*User, error) {
|
||||
var user *User
|
||||
query := `SELECT ` + userColumns() + ` FROM users WHERE LOWER(username) = LOWER(?) LIMIT 1`
|
||||
err := database.Query(query, func(stmt *sqlite.Stmt) error {
|
||||
user = scanUser(stmt)
|
||||
return nil
|
||||
}, username)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to find user by username: %w", err)
|
||||
}
|
||||
if user == nil {
|
||||
us := GetStore()
|
||||
us.mu.RLock()
|
||||
defer us.mu.RUnlock()
|
||||
|
||||
id, exists := us.byUsername[strings.ToLower(username)]
|
||||
if !exists {
|
||||
return nil, fmt.Errorf("user with username '%s' not found", username)
|
||||
}
|
||||
|
||||
user, exists := us.GetByID(id)
|
||||
if !exists {
|
||||
return nil, fmt.Errorf("user with username '%s' not found", username)
|
||||
}
|
||||
|
||||
return user, nil
|
||||
}
|
||||
|
||||
func ByEmail(email string) (*User, error) {
|
||||
var user *User
|
||||
query := `SELECT ` + userColumns() + ` FROM users WHERE email = ? LIMIT 1`
|
||||
err := database.Query(query, func(stmt *sqlite.Stmt) error {
|
||||
user = scanUser(stmt)
|
||||
return nil
|
||||
}, email)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to find user by email: %w", err)
|
||||
}
|
||||
if user == nil {
|
||||
us := GetStore()
|
||||
us.mu.RLock()
|
||||
defer us.mu.RUnlock()
|
||||
|
||||
id, exists := us.byEmail[email]
|
||||
if !exists {
|
||||
return nil, fmt.Errorf("user with email '%s' not found", email)
|
||||
}
|
||||
|
||||
user, exists := us.GetByID(id)
|
||||
if !exists {
|
||||
return nil, fmt.Errorf("user with email '%s' not found", email)
|
||||
}
|
||||
|
||||
return user, nil
|
||||
}
|
||||
|
||||
func ByLevel(level int) ([]*User, error) {
|
||||
var users []*User
|
||||
query := `SELECT ` + userColumns() + ` FROM users WHERE level = ? ORDER BY exp DESC, id ASC`
|
||||
err := database.Query(query, func(stmt *sqlite.Stmt) error {
|
||||
user := scanUser(stmt)
|
||||
users = append(users, user)
|
||||
return nil
|
||||
}, level)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to retrieve users by level: %w", err)
|
||||
us := GetStore()
|
||||
us.mu.RLock()
|
||||
defer us.mu.RUnlock()
|
||||
|
||||
ids, exists := us.byLevel[level]
|
||||
if !exists {
|
||||
return []*User{}, nil
|
||||
}
|
||||
return users, nil
|
||||
|
||||
result := make([]*User, 0, len(ids))
|
||||
for _, id := range ids {
|
||||
if user, exists := us.GetByID(id); exists {
|
||||
result = append(result, user)
|
||||
}
|
||||
}
|
||||
return result, nil
|
||||
}
|
||||
|
||||
func Online(within time.Duration) ([]*User, error) {
|
||||
var users []*User
|
||||
us := GetStore()
|
||||
us.mu.RLock()
|
||||
defer us.mu.RUnlock()
|
||||
|
||||
cutoff := time.Now().Add(-within).Unix()
|
||||
query := `SELECT ` + userColumns() + ` FROM users WHERE last_online >= ? ORDER BY last_online DESC, id ASC`
|
||||
err := database.Query(query, func(stmt *sqlite.Stmt) error {
|
||||
user := scanUser(stmt)
|
||||
users = append(users, user)
|
||||
return nil
|
||||
}, cutoff)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to retrieve online users: %w", err)
|
||||
var result []*User
|
||||
|
||||
for _, id := range us.allByRegistered {
|
||||
if user, exists := us.GetByID(id); exists && user.LastOnline >= cutoff {
|
||||
result = append(result, user)
|
||||
}
|
||||
}
|
||||
return users, nil
|
||||
|
||||
// Sort by last_online DESC, then ID ASC
|
||||
sort.Slice(result, func(i, j int) bool {
|
||||
if result[i].LastOnline != result[j].LastOnline {
|
||||
return result[i].LastOnline > result[j].LastOnline // DESC
|
||||
}
|
||||
return result[i].ID < result[j].ID // ASC
|
||||
})
|
||||
|
||||
return result, nil
|
||||
}
|
||||
|
||||
func (u *User) Insert() error {
|
||||
columns := `username, password, email, verified, token, registered, last_online, auth,
|
||||
x, y, class_id, currently, fighting, monster_id, monster_hp, monster_sleep, monster_immune,
|
||||
uber_damage, uber_defense, hp, mp, tp, max_hp, max_mp, max_tp, level, gold, exp,
|
||||
gold_bonus, exp_bonus, strength, dexterity, attack, defense, weapon_id, armor_id, shield_id,
|
||||
slot_1_id, slot_2_id, slot_3_id, weapon_name, armor_name, shield_name,
|
||||
slot_1_name, slot_2_name, slot_3_name, drop_code, spells, towns`
|
||||
us := GetStore()
|
||||
|
||||
values := []any{u.Username, u.Password, u.Email, u.Verified, u.Token,
|
||||
u.Registered, u.LastOnline, u.Auth, u.X, u.Y, u.ClassID, u.Currently,
|
||||
u.Fighting, u.MonsterID, u.MonsterHP, u.MonsterSleep, u.MonsterImmune,
|
||||
u.UberDamage, u.UberDefense, u.HP, u.MP, u.TP, u.MaxHP, u.MaxMP, u.MaxTP,
|
||||
u.Level, u.Gold, u.Exp, u.GoldBonus, u.ExpBonus, u.Strength, u.Dexterity,
|
||||
u.Attack, u.Defense, u.WeaponID, u.ArmorID, u.ShieldID, u.Slot1ID,
|
||||
u.Slot2ID, u.Slot3ID, u.WeaponName, u.ArmorName, u.ShieldName,
|
||||
u.Slot1Name, u.Slot2Name, u.Slot3Name, u.DropCode, u.Spells, u.Towns}
|
||||
// Validate before insertion
|
||||
if err := u.Validate(); err != nil {
|
||||
return fmt.Errorf("validation failed: %w", err)
|
||||
}
|
||||
|
||||
return database.Insert(u, columns, values...)
|
||||
// Assign new ID if not set
|
||||
if u.ID == 0 {
|
||||
u.ID = us.GetNextID()
|
||||
}
|
||||
|
||||
// Add to store
|
||||
us.AddUser(u)
|
||||
return nil
|
||||
}
|
||||
|
||||
func (u *User) RegisteredTime() time.Time {
|
||||
@ -255,7 +417,7 @@ func (u *User) LastOnlineTime() time.Time {
|
||||
}
|
||||
|
||||
func (u *User) UpdateLastOnline() {
|
||||
u.Set("LastOnline", time.Now().Unix())
|
||||
u.LastOnline = time.Now().Unix()
|
||||
}
|
||||
|
||||
func (u *User) IsVerified() bool {
|
||||
@ -283,7 +445,7 @@ func (u *User) GetSpellIDs() []int {
|
||||
}
|
||||
|
||||
func (u *User) SetSpellIDs(spells []int) {
|
||||
u.Set("Spells", helpers.IntsToString(spells))
|
||||
u.Spells = helpers.IntsToString(spells)
|
||||
}
|
||||
|
||||
func (u *User) HasSpell(spellID int) bool {
|
||||
@ -295,7 +457,7 @@ func (u *User) GetTownIDs() []int {
|
||||
}
|
||||
|
||||
func (u *User) SetTownIDs(towns []int) {
|
||||
u.Set("Towns", helpers.IntsToString(towns))
|
||||
u.Towns = helpers.IntsToString(towns)
|
||||
}
|
||||
|
||||
func (u *User) HasTownMap(townID int) bool {
|
||||
@ -336,6 +498,6 @@ func (u *User) GetPosition() (int, int) {
|
||||
}
|
||||
|
||||
func (u *User) SetPosition(x, y int) {
|
||||
u.Set("X", x)
|
||||
u.Set("Y", y)
|
||||
u.X = x
|
||||
u.Y = y
|
||||
}
|
||||
|
9
main.go
9
main.go
@ -6,24 +6,19 @@ import (
|
||||
"log"
|
||||
"os"
|
||||
|
||||
"dk/internal/install"
|
||||
"dk/internal/server"
|
||||
)
|
||||
|
||||
func main() {
|
||||
var port string
|
||||
flag.StringVar(&port, "p", "3000", "Port to run server on")
|
||||
|
||||
|
||||
if len(os.Args) < 2 {
|
||||
startServer(port)
|
||||
return
|
||||
}
|
||||
|
||||
switch os.Args[1] {
|
||||
case "install":
|
||||
if err := install.Run(); err != nil {
|
||||
log.Fatalf("Installation failed: %v", err)
|
||||
}
|
||||
case "serve":
|
||||
flag.CommandLine.Parse(os.Args[2:])
|
||||
startServer(port)
|
||||
@ -42,4 +37,4 @@ func startServer(port string) {
|
||||
if err := server.Start(port); err != nil {
|
||||
log.Fatalf("Server failed: %v", err)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
@ -1,34 +0,0 @@
|
||||
package main
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"dk/internal/monsters"
|
||||
)
|
||||
|
||||
func main() {
|
||||
fmt.Println("Testing LoadData() function...")
|
||||
|
||||
err := monsters.LoadData()
|
||||
if err != nil {
|
||||
fmt.Printf("LoadData() failed: %v\n", err)
|
||||
return
|
||||
}
|
||||
|
||||
// Test that we can find a monster
|
||||
monster, err := monsters.Find(1)
|
||||
if err != nil {
|
||||
fmt.Printf("Find(1) failed: %v\n", err)
|
||||
return
|
||||
}
|
||||
|
||||
fmt.Printf("Successfully loaded data! Found monster: %s (Level %d)\n", monster.Name, monster.Level)
|
||||
|
||||
// Test getting all monsters
|
||||
all, err := monsters.All()
|
||||
if err != nil {
|
||||
fmt.Printf("All() failed: %v\n", err)
|
||||
return
|
||||
}
|
||||
|
||||
fmt.Printf("Total monsters loaded: %d\n", len(all))
|
||||
}
|
Loading…
x
Reference in New Issue
Block a user