Compare commits

...

3 Commits

Author SHA1 Message Date
fd75638fc6 world server skeleton 2025-07-30 10:01:32 -05:00
4bae02bec0 add achievements package 2025-07-30 09:38:58 -05:00
fc82f97cb6 create thin database wrapper 2025-07-30 09:29:20 -05:00
10 changed files with 2532 additions and 165 deletions

View File

@ -1,18 +1,17 @@
package main
import (
"eq2emu/internal/database"
"fmt"
"log"
"time"
"golang.org/x/crypto/bcrypt"
"zombiezen.com/go/sqlite"
"zombiezen.com/go/sqlite/sqlitex"
)
// Database handles all database operations for the login server
type Database struct {
conn *sqlite.Conn
db *database.DB
}
// DatabaseConfig holds database connection settings
@ -24,32 +23,26 @@ type DatabaseConfig struct {
// NewDatabase creates a new database connection
func NewDatabase(config DatabaseConfig) (*Database, error) {
// Open SQLite database
db, err := sqlite.OpenConn(config.FilePath, sqlite.OpenReadWrite|sqlite.OpenCreate)
db, err := database.Open(config.FilePath)
if err != nil {
return nil, fmt.Errorf("failed to open database: %w", err)
}
// Enable foreign keys
if err := sqlitex.ExecuteTransient(db, "PRAGMA foreign_keys = ON", nil); err != nil {
return nil, fmt.Errorf("failed to enable foreign keys: %w", err)
}
// Set busy timeout
// Set busy timeout if specified
if config.BusyTimeout > 0 {
query := fmt.Sprintf("PRAGMA busy_timeout = %d", config.BusyTimeout)
if err := sqlitex.ExecuteTransient(db, query, nil); err != nil {
if err := db.Exec(query); err != nil {
return nil, fmt.Errorf("failed to set busy timeout: %w", err)
}
}
log.Println("SQLite database connection established")
return &Database{conn: db}, nil
return &Database{db: db}, nil
}
// Close closes the database connection
func (d *Database) Close() error {
return d.conn.Close()
return d.db.Close()
}
// AuthenticateAccount verifies user credentials and returns account info
@ -60,49 +53,37 @@ func (d *Database) AuthenticateAccount(username, password string) (*Account, err
FROM accounts
WHERE username = ? AND active = 1`
stmt, err := d.conn.Prepare(query)
if err != nil {
return nil, fmt.Errorf("prepare statement failed: %w", err)
}
defer stmt.Finalize()
stmt.BindText(1, username)
hasRow, err := stmt.Step()
row, err := d.db.QueryRow(query, username)
if err != nil {
return nil, fmt.Errorf("query failed: %w", err)
}
if !hasRow {
if row == nil {
return nil, nil // Account not found
}
defer row.Close()
var account Account
var passwordHash string
var lastLogin string
var clientVersion int64
account.ID = int32(stmt.ColumnInt64(0))
account.Username = stmt.ColumnText(1)
passwordHash = stmt.ColumnText(2)
account.LSAdmin = stmt.ColumnInt(3) != 0
account.WorldAdmin = stmt.ColumnInt(4) != 0
account.ID = int32(row.Int64(0))
account.Username = row.Text(1)
passwordHash := row.Text(2)
account.LSAdmin = row.Bool(3)
account.WorldAdmin = row.Bool(4)
// Skip created_date at index 5 - not needed for authentication
lastLogin = stmt.ColumnText(6)
clientVersion = stmt.ColumnInt64(7)
lastLogin := row.Text(6)
account.ClientVersion = uint16(row.Int64(7))
// Verify password
if err := bcrypt.CompareHashAndPassword([]byte(passwordHash), []byte(password)); err != nil {
return nil, nil // Invalid password
}
// Parse timestamps
// Parse timestamp
if lastLogin != "" {
if t, err := time.Parse("2006-01-02 15:04:05", lastLogin); err == nil {
account.LastLogin = t
}
}
account.ClientVersion = uint16(clientVersion)
return &account, nil
}
@ -113,14 +94,12 @@ func (d *Database) UpdateAccountLogin(account *Account) error {
SET last_login = ?, last_ip = ?, client_version = ?
WHERE id = ?`
return sqlitex.Execute(d.conn, query, &sqlitex.ExecOptions{
Args: []any{
account.LastLogin.Format("2006-01-02 15:04:05"),
account.IPAddress,
account.ClientVersion,
account.ID,
},
})
return d.db.Exec(query,
account.LastLogin.Format("2006-01-02 15:04:05"),
account.IPAddress,
account.ClientVersion,
account.ID,
)
}
// LoadCharacters loads all characters for an account
@ -132,45 +111,29 @@ func (d *Database) LoadCharacters(accountID int32, version uint16) ([]*Character
WHERE account_id = ?
ORDER BY created_date ASC`
stmt, err := d.conn.Prepare(query)
if err != nil {
return nil, fmt.Errorf("prepare failed: %w", err)
}
defer stmt.Finalize()
stmt.BindInt64(1, int64(accountID))
var characters []*Character
for {
hasRow, err := stmt.Step()
if err != nil {
return nil, fmt.Errorf("query failed: %w", err)
}
if !hasRow {
break
}
err := d.db.Query(query, func(row *database.Row) error {
char := &Character{AccountID: accountID}
char.ID = int32(row.Int64(0))
char.ServerID = int32(row.Int64(1))
char.Name = row.Text(2)
char.Level = int8(row.Int(3))
char.Race = int8(row.Int(4))
char.Gender = int8(row.Int(5))
char.Class = int8(row.Int(6))
char.ID = int32(stmt.ColumnInt64(0))
char.ServerID = int32(stmt.ColumnInt64(1))
char.Name = stmt.ColumnText(2)
char.Level = int8(stmt.ColumnInt(3))
char.Race = int8(stmt.ColumnInt(4))
char.Gender = int8(stmt.ColumnInt(5))
char.Class = int8(stmt.ColumnInt(6))
if dateStr := stmt.ColumnText(7); dateStr != "" {
if dateStr := row.Text(7); dateStr != "" {
if t, err := time.Parse("2006-01-02 15:04:05", dateStr); err == nil {
char.CreatedDate = t
}
}
char.Deleted = stmt.ColumnInt(8) != 0
char.Deleted = row.Bool(8)
characters = append(characters, char)
}
return nil
}, accountID)
return characters, nil
return characters, err
}
// CharacterNameExists checks if a character name is already taken
@ -180,24 +143,16 @@ func (d *Database) CharacterNameExists(name string, serverID int32) (bool, error
FROM characters
WHERE name = ? AND server_id = ? AND deleted = 0`
stmt, err := d.conn.Prepare(query)
row, err := d.db.QueryRow(query, name, serverID)
if err != nil {
return false, err
}
defer stmt.Finalize()
stmt.BindText(1, name)
stmt.BindInt64(2, int64(serverID))
hasRow, err := stmt.Step()
if err != nil {
return false, err
}
if hasRow {
return stmt.ColumnInt(0) > 0, nil
if row == nil {
return false, nil
}
defer row.Close()
return false, nil
return row.Int(0) > 0, nil
}
// CreateCharacter creates a new character in the database
@ -207,18 +162,16 @@ func (d *Database) CreateCharacter(char *Character) (int32, error) {
gender, class, created_date)
VALUES (?, ?, ?, ?, ?, ?, ?, ?)`
err := sqlitex.Execute(d.conn, query, &sqlitex.ExecOptions{
Args: []any{
char.AccountID, char.ServerID, char.Name, char.Level,
char.Race, char.Gender, char.Class,
char.CreatedDate.Format("2006-01-02 15:04:05"),
},
})
err := d.db.Exec(query,
char.AccountID, char.ServerID, char.Name, char.Level,
char.Race, char.Gender, char.Class,
char.CreatedDate.Format("2006-01-02 15:04:05"),
)
if err != nil {
return 0, fmt.Errorf("failed to create character: %w", err)
}
id := int32(d.conn.LastInsertRowID())
id := int32(d.db.LastInsertID())
log.Printf("Created character %s (ID: %d) for account %d",
char.Name, id, char.AccountID)
@ -232,18 +185,13 @@ func (d *Database) DeleteCharacter(charID, accountID int32) error {
SET deleted = 1, deleted_date = CURRENT_TIMESTAMP
WHERE id = ? AND account_id = ?`
err := sqlitex.Execute(d.conn, query, &sqlitex.ExecOptions{
Args: []any{charID, accountID},
})
err := d.db.Exec(query, charID, accountID)
if err != nil {
return fmt.Errorf("failed to delete character: %w", err)
}
// Check if any rows were affected
stmt, _ := d.conn.Prepare("SELECT changes()")
defer stmt.Finalize()
hasRow, _ := stmt.Step()
if hasRow && stmt.ColumnInt(0) == 0 {
if d.db.Changes() == 0 {
return fmt.Errorf("character not found or not owned by account")
}
@ -260,35 +208,20 @@ func (d *Database) GetWorldServers() ([]*WorldServer, error) {
WHERE active = 1
ORDER BY sort_order, name`
stmt, err := d.conn.Prepare(query)
if err != nil {
return nil, fmt.Errorf("prepare failed: %w", err)
}
defer stmt.Finalize()
var servers []*WorldServer
for {
hasRow, err := stmt.Step()
if err != nil {
return nil, fmt.Errorf("query failed: %w", err)
}
if !hasRow {
break
}
err := d.db.Query(query, func(row *database.Row) error {
server := &WorldServer{}
server.ID = int32(row.Int64(0))
server.Name = row.Text(1)
server.Description = row.Text(2)
server.IPAddress = row.Text(3)
server.Port = row.Int(4)
server.Status = row.Text(5)
server.Population = int32(row.Int64(6))
server.Locked = row.Bool(7)
server.Hidden = row.Bool(8)
server.ID = int32(stmt.ColumnInt64(0))
server.Name = stmt.ColumnText(1)
server.Description = stmt.ColumnText(2)
server.IPAddress = stmt.ColumnText(3)
server.Port = stmt.ColumnInt(4)
server.Status = stmt.ColumnText(5)
server.Population = int32(stmt.ColumnInt64(6))
server.Locked = stmt.ColumnInt(7) != 0
server.Hidden = stmt.ColumnInt(8) != 0
if dateStr := stmt.ColumnText(9); dateStr != "" {
if dateStr := row.Text(9); dateStr != "" {
if t, err := time.Parse("2006-01-02 15:04:05", dateStr); err == nil {
server.CreatedDate = t
}
@ -296,11 +229,11 @@ func (d *Database) GetWorldServers() ([]*WorldServer, error) {
server.Online = server.Status == "online"
server.PopulationLevel = calculatePopulationLevel(server.Population)
servers = append(servers, server)
}
return nil
})
return servers, nil
return servers, err
}
// UpdateWorldServerStats updates world server statistics
@ -310,12 +243,10 @@ func (d *Database) UpdateWorldServerStats(serverID int32, stats *WorldServerStat
(server_id, timestamp, population, zones_active, players_online, uptime_seconds)
VALUES (?, CURRENT_TIMESTAMP, ?, ?, ?, ?)`
return sqlitex.Execute(d.conn, query, &sqlitex.ExecOptions{
Args: []any{
serverID, stats.Population,
stats.ZonesActive, stats.PlayersOnline, stats.UptimeSeconds,
},
})
return d.db.Exec(query,
serverID, stats.Population,
stats.ZonesActive, stats.PlayersOnline, stats.UptimeSeconds,
)
}
// CleanupOldEntries removes old log entries and statistics
@ -327,7 +258,7 @@ func (d *Database) CleanupOldEntries() error {
}
for _, query := range queries {
if err := sqlitex.ExecuteTransient(d.conn, query, nil); err != nil {
if err := d.db.Exec(query); err != nil {
log.Printf("Cleanup query failed: %v", err)
}
}
@ -341,9 +272,7 @@ func (d *Database) LogLoginAttempt(username, ipAddress string, success bool) err
INSERT INTO login_attempts (username, ip_address, success, timestamp)
VALUES (?, ?, ?, CURRENT_TIMESTAMP)`
return sqlitex.Execute(d.conn, query, &sqlitex.ExecOptions{
Args: []any{username, ipAddress, success},
})
return d.db.Exec(query, username, ipAddress, success)
}
// GetMaxCharsSetting returns the maximum characters per account
@ -351,21 +280,15 @@ func (d *Database) GetMaxCharsSetting() int32 {
var maxChars int32 = 7 // Default
query := "SELECT value FROM server_settings WHERE name = 'max_characters_per_account'"
stmt, err := d.conn.Prepare(query)
if err != nil {
row, err := d.db.QueryRow(query)
if err != nil || row == nil {
return maxChars
}
defer stmt.Finalize()
defer row.Close()
hasRow, err := stmt.Step()
if err != nil {
return maxChars
}
if hasRow {
if val := stmt.ColumnText(0); val != "" {
if parsed := stmt.ColumnInt64(0); parsed > 0 {
maxChars = int32(parsed)
}
if !row.IsNull(0) {
if val := row.Int64(0); val > 0 {
maxChars = int32(val)
}
}
@ -377,21 +300,13 @@ func (d *Database) GetAccountBonus(accountID int32) uint8 {
var bonus uint8 = 0
query := "SELECT veteran_bonus FROM accounts WHERE id = ?"
stmt, err := d.conn.Prepare(query)
if err != nil {
row, err := d.db.QueryRow(query, accountID)
if err != nil || row == nil {
return bonus
}
defer stmt.Finalize()
stmt.BindInt64(1, int64(accountID))
hasRow, err := stmt.Step()
if err != nil {
return bonus
}
if hasRow {
bonus = uint8(stmt.ColumnInt(0))
}
defer row.Close()
bonus = uint8(row.Int(0))
return bonus
}

137
cmd/world_server/main.go Normal file
View File

@ -0,0 +1,137 @@
package main
import (
"encoding/json"
"flag"
"fmt"
"log"
"os"
"os/signal"
"syscall"
)
const (
ConfigFile = "world_config.json"
Version = "0.1.0"
)
// printHeader displays the EQ2Emu banner and copyright info
func printHeader() {
fmt.Println("EQ2Emulator World Server")
fmt.Printf("Version: %s\n", Version)
fmt.Println()
fmt.Println("Copyright (C) 2007-2026 EQ2Emulator Development Team")
fmt.Println("https://www.eq2emu.com")
fmt.Println()
fmt.Println("EQ2Emulator is free software licensed under the GNU GPL v3")
fmt.Println("See LICENSE file for details")
fmt.Println()
}
// loadConfig loads configuration from JSON file with command line overrides
func loadConfig() (*WorldConfig, error) {
// Default configuration
config := &WorldConfig{
ListenAddr: "0.0.0.0",
ListenPort: 9000,
MaxClients: 1000,
BufferSize: 8192,
WebAddr: "0.0.0.0",
WebPort: 8080,
DatabasePath: "world.db",
XPRate: 1.0,
TSXPRate: 1.0,
VitalityRate: 1.0,
LogLevel: "info",
ThreadedLoad: true,
}
// Load from config file if it exists
if data, err := os.ReadFile(ConfigFile); err == nil {
if err := json.Unmarshal(data, config); err != nil {
return nil, fmt.Errorf("failed to parse config file: %w", err)
}
log.Printf("Loaded configuration from %s", ConfigFile)
} else {
log.Printf("Config file %s not found, using defaults", ConfigFile)
}
// Command line overrides
flag.StringVar(&config.ListenAddr, "listen-addr", config.ListenAddr, "UDP listen address")
flag.IntVar(&config.ListenPort, "listen-port", config.ListenPort, "UDP listen port")
flag.IntVar(&config.MaxClients, "max-clients", config.MaxClients, "Maximum client connections")
flag.StringVar(&config.WebAddr, "web-addr", config.WebAddr, "Web server address")
flag.IntVar(&config.WebPort, "web-port", config.WebPort, "Web server port")
flag.StringVar(&config.DatabasePath, "db-path", config.DatabasePath, "Database file path")
flag.StringVar(&config.LogLevel, "log-level", config.LogLevel, "Log level (debug, info, warn, error)")
flag.BoolVar(&config.ThreadedLoad, "threaded-load", config.ThreadedLoad, "Use threaded loading")
flag.Parse()
return config, nil
}
// saveConfig saves the current configuration to file
func saveConfig(config *WorldConfig) error {
data, err := json.MarshalIndent(config, "", " ")
if err != nil {
return fmt.Errorf("failed to marshal config: %w", err)
}
if err := os.WriteFile(ConfigFile, data, 0644); err != nil {
return fmt.Errorf("failed to write config file: %w", err)
}
return nil
}
// setupSignalHandlers sets up graceful shutdown on SIGINT/SIGTERM
func setupSignalHandlers(world *World) <-chan os.Signal {
sigChan := make(chan os.Signal, 1)
signal.Notify(sigChan, syscall.SIGINT, syscall.SIGTERM)
go func() {
sig := <-sigChan
log.Printf("Received signal %v, initiating graceful shutdown...", sig)
world.Shutdown()
}()
return sigChan
}
func main() {
printHeader()
// Load configuration
config, err := loadConfig()
if err != nil {
log.Fatalf("Configuration error: %v", err)
}
// Save config file with any command line overrides
if err := saveConfig(config); err != nil {
log.Printf("Warning: failed to save config: %v", err)
}
// Create world server instance
world, err := NewWorld(config)
if err != nil {
log.Fatalf("Failed to create world server: %v", err)
}
// Initialize all components
log.Println("Initializing EQ2Emulator World Server...")
if err := world.Initialize(); err != nil {
log.Fatalf("Failed to initialize world server: %v", err)
}
// Setup signal handlers for graceful shutdown
setupSignalHandlers(world)
// Run the server
log.Println("Starting World Server...")
if err := world.Run(); err != nil {
log.Fatalf("World server error: %v", err)
}
log.Println("World Server stopped gracefully")
}

298
cmd/world_server/web.go Normal file
View File

@ -0,0 +1,298 @@
package main
import (
"context"
"encoding/json"
"fmt"
"net/http"
"time"
)
// setupWebServer initializes the HTTP server for admin interface
func (w *World) setupWebServer() error {
if w.config.WebPort == 0 {
return nil // Web server disabled
}
mux := http.NewServeMux()
// API endpoints
mux.HandleFunc("/api/status", w.handleStatus)
mux.HandleFunc("/api/clients", w.handleClients)
mux.HandleFunc("/api/zones", w.handleZones)
mux.HandleFunc("/api/stats", w.handleStats)
mux.HandleFunc("/api/time", w.handleWorldTime)
mux.HandleFunc("/api/shutdown", w.handleShutdown)
// Administrative endpoints
mux.HandleFunc("/api/admin/reload", w.handleReload)
mux.HandleFunc("/api/admin/broadcast", w.handleBroadcast)
mux.HandleFunc("/api/admin/kick", w.handleKickClient)
// Peer management endpoints
mux.HandleFunc("/api/peers", w.handlePeers)
mux.HandleFunc("/api/peers/sync", w.handlePeerSync)
// Console command interface
mux.HandleFunc("/api/console", w.handleConsoleCommand)
// Static health check
mux.HandleFunc("/health", w.handleHealth)
// @TODO: Add authentication middleware
// @TODO: Add rate limiting middleware
// @TODO: Add CORS middleware for browser access
// @TODO: Add TLS support with cert/key files
addr := fmt.Sprintf("%s:%d", w.config.WebAddr, w.config.WebPort)
w.webServer = &http.Server{
Addr: addr,
Handler: mux,
ReadTimeout: 10 * time.Second,
WriteTimeout: 10 * time.Second,
}
return nil
}
// Core API handlers
// handleHealth provides a simple health check endpoint
func (w *World) handleHealth(rw http.ResponseWriter, r *http.Request) {
rw.Header().Set("Content-Type", "application/json")
json.NewEncoder(rw).Encode(map[string]string{"status": "ok"})
}
// handleStatus returns comprehensive server status information
func (w *World) handleStatus(rw http.ResponseWriter, r *http.Request) {
rw.Header().Set("Content-Type", "application/json")
status := map[string]any{
"status": "running",
"uptime": time.Since(w.stats.StartTime).Seconds(),
"version": Version,
"locked": w.config.WorldLocked,
"primary": w.config.IsPrimary,
"threaded": w.config.ThreadedLoad,
"data_loaded": w.isDataLoaded(),
"world_time": w.getWorldTime(),
}
json.NewEncoder(rw).Encode(status)
}
// handleClients returns list of connected clients
func (w *World) handleClients(rw http.ResponseWriter, r *http.Request) {
rw.Header().Set("Content-Type", "application/json")
w.clientsMux.RLock()
clients := make([]*ClientInfo, 0, len(w.clients))
for _, client := range w.clients {
clients = append(clients, client)
}
w.clientsMux.RUnlock()
json.NewEncoder(rw).Encode(map[string]any{
"count": len(clients),
"clients": clients,
})
}
// handleZones returns list of zone servers
func (w *World) handleZones(rw http.ResponseWriter, r *http.Request) {
rw.Header().Set("Content-Type", "application/json")
w.zonesMux.RLock()
zones := make([]*ZoneInfo, 0, len(w.zones))
for _, zone := range w.zones {
zones = append(zones, zone)
}
w.zonesMux.RUnlock()
json.NewEncoder(rw).Encode(map[string]any{
"count": len(zones),
"zones": zones,
})
}
// handleStats returns detailed server statistics
func (w *World) handleStats(rw http.ResponseWriter, r *http.Request) {
rw.Header().Set("Content-Type", "application/json")
w.statsMux.RLock()
stats := w.stats
w.statsMux.RUnlock()
// Add UDP server stats if available
if w.udpServer != nil {
serverStats := w.udpServer.GetStats()
stats.TotalConnections = int64(serverStats.ConnectionCount)
}
json.NewEncoder(rw).Encode(stats)
}
// handleWorldTime returns current game world time
func (w *World) handleWorldTime(rw http.ResponseWriter, r *http.Request) {
rw.Header().Set("Content-Type", "application/json")
json.NewEncoder(rw).Encode(w.getWorldTime())
}
// Administrative handlers
// handleShutdown initiates graceful server shutdown
func (w *World) handleShutdown(rw http.ResponseWriter, r *http.Request) {
if r.Method != http.MethodPost {
http.Error(rw, "Method not allowed", http.StatusMethodNotAllowed)
return
}
// @TODO: Add authentication check
// @TODO: Add confirmation parameter
// @TODO: Add delay parameter
rw.Header().Set("Content-Type", "application/json")
json.NewEncoder(rw).Encode(map[string]string{"status": "shutdown initiated"})
go func() {
time.Sleep(time.Second) // Allow response to be sent
w.Shutdown()
}()
}
// handleReload reloads game data
func (w *World) handleReload(rw http.ResponseWriter, r *http.Request) {
if r.Method != http.MethodPost {
http.Error(rw, "Method not allowed", http.StatusMethodNotAllowed)
return
}
// @TODO: Add authentication check
// @TODO: Implement selective reloading (items, spells, quests, etc.)
// @TODO: Add progress reporting
rw.Header().Set("Content-Type", "application/json")
json.NewEncoder(rw).Encode(map[string]string{"status": "reload not implemented"})
}
// handleBroadcast sends server-wide message
func (w *World) handleBroadcast(rw http.ResponseWriter, r *http.Request) {
if r.Method != http.MethodPost {
http.Error(rw, "Method not allowed", http.StatusMethodNotAllowed)
return
}
// @TODO: Add authentication check
// @TODO: Parse message from request body
// @TODO: Validate message content
// @TODO: Send to all connected clients
rw.Header().Set("Content-Type", "application/json")
json.NewEncoder(rw).Encode(map[string]string{"status": "broadcast not implemented"})
}
// handleKickClient disconnects a specific client
func (w *World) handleKickClient(rw http.ResponseWriter, r *http.Request) {
if r.Method != http.MethodPost {
http.Error(rw, "Method not allowed", http.StatusMethodNotAllowed)
return
}
// @TODO: Add authentication check
// @TODO: Parse client ID from request
// @TODO: Find and disconnect client
// @TODO: Log kick action
rw.Header().Set("Content-Type", "application/json")
json.NewEncoder(rw).Encode(map[string]string{"status": "kick not implemented"})
}
// Peer management handlers
// handlePeers returns list of peer servers
func (w *World) handlePeers(rw http.ResponseWriter, r *http.Request) {
rw.Header().Set("Content-Type", "application/json")
peers := make([]map[string]any, 0)
for _, peer := range w.config.PeerServers {
peerInfo := map[string]any{
"address": peer.Address,
"port": peer.Port,
"status": "unknown", // @TODO: Implement peer status checking
}
peers = append(peers, peerInfo)
}
json.NewEncoder(rw).Encode(map[string]any{
"count": len(peers),
"peers": peers,
})
}
// handlePeerSync synchronizes data with peer servers
func (w *World) handlePeerSync(rw http.ResponseWriter, r *http.Request) {
if r.Method != http.MethodPost {
http.Error(rw, "Method not allowed", http.StatusMethodNotAllowed)
return
}
// @TODO: Add authentication check
// @TODO: Implement peer synchronization
// @TODO: Return sync status and results
rw.Header().Set("Content-Type", "application/json")
json.NewEncoder(rw).Encode(map[string]string{"status": "peer sync not implemented"})
}
// Console command handler
// handleConsoleCommand executes administrative commands
func (w *World) handleConsoleCommand(rw http.ResponseWriter, r *http.Request) {
if r.Method != http.MethodPost {
http.Error(rw, "Method not allowed", http.StatusMethodNotAllowed)
return
}
// @TODO: Add authentication check
// @TODO: Parse command from request body
// @TODO: Validate command permissions
// @TODO: Execute command and return results
// @TODO: Log command execution
rw.Header().Set("Content-Type", "application/json")
json.NewEncoder(rw).Encode(map[string]string{"status": "console commands not implemented"})
}
// Helper methods for web handlers
// getWorldTime returns thread-safe copy of world time
func (w *World) getWorldTime() WorldTime {
w.worldTimeMux.RLock()
defer w.worldTimeMux.RUnlock()
return w.worldTime
}
// startWebServer starts the web server in a goroutine
func (w *World) startWebServer() {
if w.webServer == nil {
return
}
go func() {
if err := w.webServer.ListenAndServe(); err != nil && err != http.ErrServerClosed {
fmt.Printf("Web server error: %v\n", err)
}
}()
}
// stopWebServer gracefully stops the web server
func (w *World) stopWebServer() error {
if w.webServer == nil {
return nil
}
ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
defer cancel()
return w.webServer.Shutdown(ctx)
}

830
cmd/world_server/world.go Normal file
View File

@ -0,0 +1,830 @@
package main
import (
"context"
"fmt"
"log"
"net/http"
"sync"
"time"
"eq2emu/internal/database"
"eq2emu/internal/udp"
)
// WorldTime represents the in-game time
type WorldTime struct {
Year int32 `json:"year"`
Month int32 `json:"month"`
Day int32 `json:"day"`
Hour int32 `json:"hour"`
Minute int32 `json:"minute"`
}
// WorldConfig holds all world server configuration
type WorldConfig struct {
// Network settings
ListenAddr string `json:"listen_addr"`
ListenPort int `json:"listen_port"`
MaxClients int `json:"max_clients"`
BufferSize int `json:"buffer_size"`
// Web server settings
WebAddr string `json:"web_addr"`
WebPort int `json:"web_port"`
CertFile string `json:"cert_file"`
KeyFile string `json:"key_file"`
KeyPassword string `json:"key_password"`
WebUser string `json:"web_user"`
WebPassword string `json:"web_password"`
// Database settings
DatabasePath string `json:"database_path"`
// Game settings
XPRate float64 `json:"xp_rate"`
TSXPRate float64 `json:"ts_xp_rate"`
VitalityRate float64 `json:"vitality_rate"`
// Server settings
LogLevel string `json:"log_level"`
ThreadedLoad bool `json:"threaded_load"`
WorldLocked bool `json:"world_locked"`
IsPrimary bool `json:"is_primary"`
// Login server settings
LoginServers []LoginServerInfo `json:"login_servers"`
// Peer server settings
PeerServers []PeerServerInfo `json:"peer_servers"`
PeerPriority int `json:"peer_priority"`
}
// LoginServerInfo represents login server connection details
type LoginServerInfo struct {
Address string `json:"address"`
Port int `json:"port"`
Account string `json:"account"`
Password string `json:"password"`
}
// PeerServerInfo represents peer server connection details
type PeerServerInfo struct {
Address string `json:"address"`
Port int `json:"port"`
}
// ClientInfo represents a connected client
type ClientInfo struct {
ID int32 `json:"id"`
AccountID int32 `json:"account_id"`
CharacterID int32 `json:"character_id"`
Name string `json:"name"`
ZoneID int32 `json:"zone_id"`
ConnectedAt time.Time `json:"connected_at"`
LastActive time.Time `json:"last_active"`
IPAddress string `json:"ip_address"`
}
// ZoneInfo represents zone server information
type ZoneInfo struct {
ID int32 `json:"id"`
Name string `json:"name"`
Description string `json:"description"`
PlayerCount int32 `json:"player_count"`
MaxPlayers int32 `json:"max_players"`
IsShutdown bool `json:"is_shutdown"`
Address string `json:"address"`
Port int `json:"port"`
}
// ServerStats holds server statistics
type ServerStats struct {
StartTime time.Time `json:"start_time"`
ClientCount int32 `json:"client_count"`
ZoneCount int32 `json:"zone_count"`
TotalConnections int64 `json:"total_connections"`
PacketsProcessed int64 `json:"packets_processed"`
DataLoaded bool `json:"data_loaded"`
ItemsLoaded bool `json:"items_loaded"`
SpellsLoaded bool `json:"spells_loaded"`
QuestsLoaded bool `json:"quests_loaded"`
}
// World represents the main world server
type World struct {
config *WorldConfig
db *database.DB
// Network components
udpServer *udp.Server
webServer *http.Server
// Game state
worldTime WorldTime
worldTimeMux sync.RWMutex
// Client management
clients map[int32]*ClientInfo
clientsMux sync.RWMutex
// Zone management
zones map[int32]*ZoneInfo
zonesMux sync.RWMutex
// Statistics
stats ServerStats
statsMux sync.RWMutex
// Control
ctx context.Context
cancel context.CancelFunc
shutdownWg *sync.WaitGroup
// Timers
timeTickTimer *time.Ticker
saveTimer *time.Ticker
vitalityTimer *time.Ticker
statsTimer *time.Ticker
watchdogTimer *time.Ticker
loginCheckTimer *time.Ticker
// Loading state
loadingMux sync.RWMutex
itemsLoaded bool
spellsLoaded bool
questsLoaded bool
traitsLoaded bool
dataLoaded bool
}
// NewWorld creates a new world server instance
func NewWorld(config *WorldConfig) (*World, error) {
ctx, cancel := context.WithCancel(context.Background())
db, err := database.Open(config.DatabasePath)
if err != nil {
cancel()
return nil, fmt.Errorf("failed to open database: %w", err)
}
w := &World{
config: config,
db: db,
ctx: ctx,
cancel: cancel,
shutdownWg: &sync.WaitGroup{},
clients: make(map[int32]*ClientInfo),
zones: make(map[int32]*ZoneInfo),
stats: ServerStats{
StartTime: time.Now(),
},
}
// Initialize world time from database
if err := w.loadWorldTime(); err != nil {
log.Printf("Warning: failed to load world time: %v", err)
w.setDefaultWorldTime()
}
return w, nil
}
// Initialize sets up all world server components
func (w *World) Initialize() error {
log.Println("Loading System Data...")
// Initialize database schema
if err := w.initializeDatabase(); err != nil {
return fmt.Errorf("database initialization failed: %w", err)
}
// Load game data (threaded or sequential)
if w.config.ThreadedLoad {
log.Println("Using threaded loading of static data...")
if err := w.loadGameDataThreaded(); err != nil {
return fmt.Errorf("threaded game data loading failed: %w", err)
}
} else {
if err := w.loadGameData(); err != nil {
return fmt.Errorf("game data loading failed: %w", err)
}
}
// Setup UDP server for game connections
if err := w.setupUDPServer(); err != nil {
return fmt.Errorf("UDP server setup failed: %w", err)
}
// Setup web server for admin/API
if err := w.setupWebServer(); err != nil {
return fmt.Errorf("web server setup failed: %w", err)
}
// Initialize timers
w.initializeTimers()
log.Println("World Server initialization complete")
return nil
}
// Run starts the world server main loop
func (w *World) Run() error {
// Start background processes
w.shutdownWg.Add(6)
go w.processTimeUpdates()
go w.processSaveOperations()
go w.processVitalityUpdates()
go w.processStatsUpdates()
go w.processWatchdog()
go w.processLoginCheck()
// Start network servers
if w.udpServer != nil {
go func() {
if err := w.udpServer.Start(); err != nil {
log.Printf("UDP server error: %v", err)
}
}()
}
// Start web server
w.startWebServer()
log.Printf("World Server running on UDP %s:%d, Web %s:%d",
w.config.ListenAddr, w.config.ListenPort,
w.config.WebAddr, w.config.WebPort)
// Wait for shutdown signal
<-w.ctx.Done()
return w.shutdown()
}
// Shutdown gracefully stops the world server
func (w *World) Shutdown() {
log.Println("Initiating World Server shutdown...")
w.cancel()
}
// setupUDPServer initializes the UDP server for game client connections
func (w *World) setupUDPServer() error {
handler := func(conn *udp.Connection, packet *udp.ApplicationPacket) {
w.handleGamePacket(conn, packet)
}
config := udp.DefaultConfig()
config.MaxConnections = w.config.MaxClients
config.BufferSize = w.config.BufferSize
config.EnableCompression = true
config.EnableEncryption = true
addr := fmt.Sprintf("%s:%d", w.config.ListenAddr, w.config.ListenPort)
server, err := udp.NewServer(addr, handler, config)
if err != nil {
return err
}
w.udpServer = server
return nil
}
// initializeTimers sets up all periodic timers
func (w *World) initializeTimers() {
w.timeTickTimer = time.NewTicker(5 * time.Second) // Game time updates
w.saveTimer = time.NewTicker(5 * time.Minute) // Save operations
w.vitalityTimer = time.NewTicker(1 * time.Hour) // Vitality updates
w.statsTimer = time.NewTicker(1 * time.Minute) // Statistics updates
w.watchdogTimer = time.NewTicker(30 * time.Second) // Watchdog checks
w.loginCheckTimer = time.NewTicker(30 * time.Second) // Login server check
}
// Background processes
// processTimeUpdates handles game world time progression
func (w *World) processTimeUpdates() {
defer w.shutdownWg.Done()
for {
select {
case <-w.ctx.Done():
return
case <-w.timeTickTimer.C:
w.updateWorldTime()
}
}
}
// processSaveOperations handles periodic save operations
func (w *World) processSaveOperations() {
defer w.shutdownWg.Done()
for {
select {
case <-w.ctx.Done():
return
case <-w.saveTimer.C:
w.saveWorldState()
}
}
}
// processVitalityUpdates handles vitality system updates
func (w *World) processVitalityUpdates() {
defer w.shutdownWg.Done()
for {
select {
case <-w.ctx.Done():
return
case <-w.vitalityTimer.C:
w.updateVitality()
}
}
}
// processStatsUpdates handles statistics collection
func (w *World) processStatsUpdates() {
defer w.shutdownWg.Done()
for {
select {
case <-w.ctx.Done():
return
case <-w.statsTimer.C:
w.updateStatistics()
}
}
}
// processWatchdog handles connection timeouts and cleanup
func (w *World) processWatchdog() {
defer w.shutdownWg.Done()
for {
select {
case <-w.ctx.Done():
return
case <-w.watchdogTimer.C:
w.cleanupInactiveClients()
w.cleanupTimeoutConnections()
}
}
}
// processLoginCheck handles login server connectivity
func (w *World) processLoginCheck() {
defer w.shutdownWg.Done()
for {
select {
case <-w.ctx.Done():
return
case <-w.loginCheckTimer.C:
w.checkLoginServers()
}
}
}
// Game packet handling
func (w *World) handleGamePacket(conn *udp.Connection, packet *udp.ApplicationPacket) {
// Update connection activity
w.updateConnectionActivity(conn)
// Route packet based on opcode
switch packet.Opcode {
case 0x2000: // Login request
w.handleLoginRequest(conn, packet)
case 0x0020: // Zone change request
w.handleZoneChange(conn, packet)
case 0x0080: // Client command
w.handleClientCommand(conn, packet)
case 0x01F0: // Chat message
w.handleChatMessage(conn, packet)
default:
// @TODO: Implement comprehensive packet routing
log.Printf("Unhandled packet opcode: 0x%04X, size: %d", packet.Opcode, len(packet.Data))
}
// Update packet statistics
w.statsMux.Lock()
w.stats.PacketsProcessed++
w.statsMux.Unlock()
}
// Game packet handlers
func (w *World) handleLoginRequest(conn *udp.Connection, packet *udp.ApplicationPacket) {
// @TODO: Parse login request packet
// @TODO: Validate credentials with login server
// @TODO: Create client session
// @TODO: Send login response
log.Printf("Login request from connection %d", conn.GetSessionID())
}
func (w *World) handleZoneChange(conn *udp.Connection, packet *udp.ApplicationPacket) {
// @TODO: Parse zone change request
// @TODO: Validate zone transfer
// @TODO: Coordinate with zone servers
// @TODO: Send zone change response
log.Printf("Zone change request from connection %d", conn.GetSessionID())
}
func (w *World) handleClientCommand(conn *udp.Connection, packet *udp.ApplicationPacket) {
// @TODO: Parse client command packet
// @TODO: Process administrative commands
// @TODO: Route to appropriate handlers
log.Printf("Client command from connection %d", conn.GetSessionID())
}
func (w *World) handleChatMessage(conn *udp.Connection, packet *udp.ApplicationPacket) {
// @TODO: Parse chat message packet
// @TODO: Handle channel routing
// @TODO: Apply filters and permissions
// @TODO: Broadcast to appropriate recipients
log.Printf("Chat message from connection %d", conn.GetSessionID())
}
// Game state management
func (w *World) updateWorldTime() {
w.worldTimeMux.Lock()
defer w.worldTimeMux.Unlock()
w.worldTime.Minute++
if w.worldTime.Minute >= 60 {
w.worldTime.Minute = 0
w.worldTime.Hour++
if w.worldTime.Hour >= 24 {
w.worldTime.Hour = 0
w.worldTime.Day++
if w.worldTime.Day >= 30 {
w.worldTime.Day = 0
w.worldTime.Month++
if w.worldTime.Month >= 12 {
w.worldTime.Month = 0
w.worldTime.Year++
}
}
}
}
// @TODO: Broadcast time update to all zones/clients
// @TODO: Save time to database periodically
}
func (w *World) saveWorldState() {
// @TODO: Save world time to database
// @TODO: Save player data
// @TODO: Save guild data
// @TODO: Save zone states
// @TODO: Save server statistics
log.Println("Saving world state...")
}
func (w *World) updateVitality() {
// @TODO: Update player vitality for offline/resting players
// @TODO: Broadcast vitality updates to zones
// @TODO: Apply vitality bonuses
log.Println("Updating vitality...")
}
func (w *World) updateStatistics() {
w.statsMux.Lock()
defer w.statsMux.Unlock()
// Update client count
w.clientsMux.RLock()
w.stats.ClientCount = int32(len(w.clients))
w.clientsMux.RUnlock()
// Update zone count
w.zonesMux.RLock()
w.stats.ZoneCount = int32(len(w.zones))
w.zonesMux.RUnlock()
// Update loading status
w.loadingMux.RLock()
w.stats.DataLoaded = w.dataLoaded
w.stats.ItemsLoaded = w.itemsLoaded
w.stats.SpellsLoaded = w.spellsLoaded
w.stats.QuestsLoaded = w.questsLoaded
w.loadingMux.RUnlock()
}
func (w *World) cleanupInactiveClients() {
w.clientsMux.Lock()
defer w.clientsMux.Unlock()
timeout := time.Now().Add(-5 * time.Minute)
for id, client := range w.clients {
if client.LastActive.Before(timeout) {
log.Printf("Removing inactive client %d (%s)", id, client.Name)
delete(w.clients, id)
}
}
}
func (w *World) cleanupTimeoutConnections() {
// @TODO: Clean up timed out UDP connections
// @TODO: Update connection statistics
}
func (w *World) checkLoginServers() {
// @TODO: Check connectivity to login servers
// @TODO: Attempt reconnection if disconnected
// @TODO: Update server status
}
func (w *World) updateConnectionActivity(conn *udp.Connection) {
sessionID := conn.GetSessionID()
w.clientsMux.Lock()
if client, exists := w.clients[int32(sessionID)]; exists {
client.LastActive = time.Now()
}
w.clientsMux.Unlock()
}
// Database operations
func (w *World) initializeDatabase() error {
// @TODO: Create/update database schema tables
// @TODO: Initialize character tables
// @TODO: Initialize guild tables
// @TODO: Initialize item tables
// @TODO: Initialize zone tables
log.Println("Database schema initialized")
return nil
}
func (w *World) loadGameData() error {
log.Println("Loading game data sequentially...")
// Load items
log.Println("Loading items...")
if err := w.loadItems(); err != nil {
return fmt.Errorf("failed to load items: %w", err)
}
// Load spells
log.Println("Loading spells...")
if err := w.loadSpells(); err != nil {
return fmt.Errorf("failed to load spells: %w", err)
}
// Load quests
log.Println("Loading quests...")
if err := w.loadQuests(); err != nil {
return fmt.Errorf("failed to load quests: %w", err)
}
// Load additional data
if err := w.loadTraits(); err != nil {
return fmt.Errorf("failed to load traits: %w", err)
}
if err := w.loadNPCs(); err != nil {
return fmt.Errorf("failed to load NPCs: %w", err)
}
if err := w.loadZones(); err != nil {
return fmt.Errorf("failed to load zones: %w", err)
}
w.loadingMux.Lock()
w.dataLoaded = true
w.loadingMux.Unlock()
log.Println("Game data loading complete")
return nil
}
func (w *World) loadGameDataThreaded() error {
log.Println("Loading game data with threads...")
var wg sync.WaitGroup
errChan := make(chan error, 10)
// Load items in thread
wg.Add(1)
go func() {
defer wg.Done()
log.Println("Loading items...")
if err := w.loadItems(); err != nil {
errChan <- fmt.Errorf("failed to load items: %w", err)
return
}
w.loadingMux.Lock()
w.itemsLoaded = true
w.loadingMux.Unlock()
log.Println("Items loaded")
}()
// Load spells in thread
wg.Add(1)
go func() {
defer wg.Done()
log.Println("Loading spells...")
if err := w.loadSpells(); err != nil {
errChan <- fmt.Errorf("failed to load spells: %w", err)
return
}
w.loadingMux.Lock()
w.spellsLoaded = true
w.loadingMux.Unlock()
log.Println("Spells loaded")
}()
// Load quests in thread
wg.Add(1)
go func() {
defer wg.Done()
log.Println("Loading quests...")
if err := w.loadQuests(); err != nil {
errChan <- fmt.Errorf("failed to load quests: %w", err)
return
}
w.loadingMux.Lock()
w.questsLoaded = true
w.loadingMux.Unlock()
log.Println("Quests loaded")
}()
// Wait for completion
go func() {
wg.Wait()
close(errChan)
}()
// Check for errors
for err := range errChan {
if err != nil {
return err
}
}
// Load additional data sequentially
if err := w.loadTraits(); err != nil {
return fmt.Errorf("failed to load traits: %w", err)
}
if err := w.loadNPCs(); err != nil {
return fmt.Errorf("failed to load NPCs: %w", err)
}
if err := w.loadZones(); err != nil {
return fmt.Errorf("failed to load zones: %w", err)
}
// Wait for threaded loads to complete
for !w.isDataLoaded() {
time.Sleep(100 * time.Millisecond)
}
w.loadingMux.Lock()
w.dataLoaded = true
w.loadingMux.Unlock()
log.Println("Threaded game data loading complete")
return nil
}
// Data loading functions
func (w *World) loadItems() error {
// @TODO: Load items from database
// @TODO: Build item lookup tables
// @TODO: Load item templates
// @TODO: Initialize item factories
return nil
}
func (w *World) loadSpells() error {
// @TODO: Load spells from database
// @TODO: Build spell lookup tables
// @TODO: Load spell effects
// @TODO: Initialize spell system
return nil
}
func (w *World) loadQuests() error {
// @TODO: Load quests from database
// @TODO: Build quest lookup tables
// @TODO: Load quest rewards
// @TODO: Initialize quest system
return nil
}
func (w *World) loadTraits() error {
// @TODO: Load traits from database
// @TODO: Build trait trees
// @TODO: Initialize trait system
return nil
}
func (w *World) loadNPCs() error {
// @TODO: Load NPCs from database
// @TODO: Load NPC templates
// @TODO: Load NPC spawn data
return nil
}
func (w *World) loadZones() error {
// @TODO: Load zone definitions
// @TODO: Load zone spawn points
// @TODO: Initialize zone management
return nil
}
func (w *World) loadWorldTime() error {
// @TODO: Load world time from database
w.worldTime = WorldTime{
Year: 3800,
Month: 0,
Day: 0,
Hour: 8,
Minute: 30,
}
return nil
}
func (w *World) setDefaultWorldTime() {
w.worldTimeMux.Lock()
defer w.worldTimeMux.Unlock()
w.worldTime = WorldTime{
Year: 3800,
Month: 0,
Day: 0,
Hour: 8,
Minute: 30,
}
}
func (w *World) isDataLoaded() bool {
w.loadingMux.RLock()
defer w.loadingMux.RUnlock()
if w.config.ThreadedLoad {
return w.itemsLoaded && w.spellsLoaded && w.questsLoaded && w.traitsLoaded
}
return w.dataLoaded
}
// Cleanup and shutdown
func (w *World) shutdown() error {
log.Println("Shutting down World Server...")
// Stop timers
if w.timeTickTimer != nil {
w.timeTickTimer.Stop()
}
if w.saveTimer != nil {
w.saveTimer.Stop()
}
if w.vitalityTimer != nil {
w.vitalityTimer.Stop()
}
if w.statsTimer != nil {
w.statsTimer.Stop()
}
if w.watchdogTimer != nil {
w.watchdogTimer.Stop()
}
if w.loginCheckTimer != nil {
w.loginCheckTimer.Stop()
}
// Stop network servers
if err := w.stopWebServer(); err != nil {
log.Printf("Error stopping web server: %v", err)
}
if w.udpServer != nil {
w.udpServer.Stop()
}
// Wait for background processes
w.shutdownWg.Wait()
// Save final state
w.saveWorldState()
// Close database
if w.db != nil {
w.db.Close()
}
log.Println("World Server shutdown complete")
return nil
}

View File

@ -0,0 +1,301 @@
package achievements
import (
"eq2emu/internal/database"
"fmt"
"time"
)
// LoadAllAchievements loads all achievements from database into master list
func LoadAllAchievements(db *database.DB, masterList *MasterList) error {
query := `SELECT achievement_id, title, uncompleted_text, completed_text,
category, expansion, icon, point_value, qty_req, hide_achievement,
unknown3a, unknown3b FROM achievements`
err := db.Query(query, func(row *database.Row) error {
achievement := NewAchievement()
achievement.ID = uint32(row.Int(0))
achievement.Title = row.Text(1)
achievement.UncompletedText = row.Text(2)
achievement.CompletedText = row.Text(3)
achievement.Category = row.Text(4)
achievement.Expansion = row.Text(5)
achievement.Icon = uint16(row.Int(6))
achievement.PointValue = uint32(row.Int(7))
achievement.QtyRequired = uint32(row.Int(8))
achievement.Hide = row.Bool(9)
achievement.Unknown3A = uint32(row.Int(10))
achievement.Unknown3B = uint32(row.Int(11))
// Load requirements and rewards
if err := loadAchievementRequirements(db, achievement); err != nil {
return fmt.Errorf("failed to load requirements for achievement %d: %w", achievement.ID, err)
}
if err := loadAchievementRewards(db, achievement); err != nil {
return fmt.Errorf("failed to load rewards for achievement %d: %w", achievement.ID, err)
}
if !masterList.AddAchievement(achievement) {
return fmt.Errorf("duplicate achievement ID: %d", achievement.ID)
}
return nil
})
return err
}
// loadAchievementRequirements loads requirements for a specific achievement
func loadAchievementRequirements(db *database.DB, achievement *Achievement) error {
query := `SELECT achievement_id, name, qty_req
FROM achievements_requirements
WHERE achievement_id = ?`
return db.Query(query, func(row *database.Row) error {
req := Requirement{
AchievementID: uint32(row.Int(0)),
Name: row.Text(1),
QtyRequired: uint32(row.Int(2)),
}
achievement.AddRequirement(req)
return nil
}, achievement.ID)
}
// loadAchievementRewards loads rewards for a specific achievement
func loadAchievementRewards(db *database.DB, achievement *Achievement) error {
query := `SELECT achievement_id, reward
FROM achievements_rewards
WHERE achievement_id = ?`
return db.Query(query, func(row *database.Row) error {
reward := Reward{
AchievementID: uint32(row.Int(0)),
Reward: row.Text(1),
}
achievement.AddReward(reward)
return nil
}, achievement.ID)
}
// LoadPlayerAchievements loads player achievements from database
func LoadPlayerAchievements(db *database.DB, playerID uint32, playerList *PlayerList) error {
query := `SELECT achievement_id, title, uncompleted_text, completed_text,
category, expansion, icon, point_value, qty_req, hide_achievement,
unknown3a, unknown3b FROM achievements`
err := db.Query(query, func(row *database.Row) error {
achievement := NewAchievement()
achievement.ID = uint32(row.Int(0))
achievement.Title = row.Text(1)
achievement.UncompletedText = row.Text(2)
achievement.CompletedText = row.Text(3)
achievement.Category = row.Text(4)
achievement.Expansion = row.Text(5)
achievement.Icon = uint16(row.Int(6))
achievement.PointValue = uint32(row.Int(7))
achievement.QtyRequired = uint32(row.Int(8))
achievement.Hide = row.Bool(9)
achievement.Unknown3A = uint32(row.Int(10))
achievement.Unknown3B = uint32(row.Int(11))
// Load requirements and rewards
if err := loadAchievementRequirements(db, achievement); err != nil {
return fmt.Errorf("failed to load requirements: %w", err)
}
if err := loadAchievementRewards(db, achievement); err != nil {
return fmt.Errorf("failed to load rewards: %w", err)
}
if !playerList.AddAchievement(achievement) {
return fmt.Errorf("duplicate achievement ID: %d", achievement.ID)
}
return nil
})
return err
}
// LoadPlayerAchievementUpdates loads player achievement progress from database
func LoadPlayerAchievementUpdates(db *database.DB, playerID uint32, updateList *PlayerUpdateList) error {
query := `SELECT char_id, achievement_id, completed_date
FROM character_achievements
WHERE char_id = ?`
return db.Query(query, func(row *database.Row) error {
update := NewUpdate()
update.ID = uint32(row.Int(1))
// Convert completed_date from Unix timestamp
if !row.IsNull(2) {
timestamp := row.Int64(2)
update.CompletedDate = time.Unix(timestamp, 0)
}
// Load update items
if err := loadPlayerAchievementUpdateItems(db, playerID, update); err != nil {
return fmt.Errorf("failed to load update items: %w", err)
}
if !updateList.AddUpdate(update) {
return fmt.Errorf("duplicate achievement update ID: %d", update.ID)
}
return nil
}, playerID)
}
// loadPlayerAchievementUpdateItems loads progress items for an achievement update
func loadPlayerAchievementUpdateItems(db *database.DB, playerID uint32, update *Update) error {
query := `SELECT achievement_id, items
FROM character_achievements_items
WHERE char_id = ? AND achievement_id = ?`
return db.Query(query, func(row *database.Row) error {
item := UpdateItem{
AchievementID: uint32(row.Int(0)),
ItemUpdate: uint32(row.Int(1)),
}
update.AddUpdateItem(item)
return nil
}, playerID, update.ID)
}
// SavePlayerAchievementUpdate saves or updates player achievement progress
func SavePlayerAchievementUpdate(db *database.DB, playerID uint32, update *Update) error {
return db.Transaction(func(tx *database.DB) error {
// Save or update main achievement record
query := `INSERT OR REPLACE INTO character_achievements
(char_id, achievement_id, completed_date) VALUES (?, ?, ?)`
var completedDate *int64
if !update.CompletedDate.IsZero() {
timestamp := update.CompletedDate.Unix()
completedDate = &timestamp
}
if err := tx.Exec(query, playerID, update.ID, completedDate); err != nil {
return fmt.Errorf("failed to save achievement update: %w", err)
}
// Delete existing update items
deleteQuery := `DELETE FROM character_achievements_items
WHERE char_id = ? AND achievement_id = ?`
if err := tx.Exec(deleteQuery, playerID, update.ID); err != nil {
return fmt.Errorf("failed to delete old update items: %w", err)
}
// Insert new update items
itemQuery := `INSERT INTO character_achievements_items
(char_id, achievement_id, items) VALUES (?, ?, ?)`
for _, item := range update.UpdateItems {
if err := tx.Exec(itemQuery, playerID, item.AchievementID, item.ItemUpdate); err != nil {
return fmt.Errorf("failed to save update item: %w", err)
}
}
return nil
})
}
// DeletePlayerAchievementUpdate removes player achievement progress from database
func DeletePlayerAchievementUpdate(db *database.DB, playerID uint32, achievementID uint32) error {
return db.Transaction(func(tx *database.DB) error {
// Delete main achievement record
query := `DELETE FROM character_achievements
WHERE char_id = ? AND achievement_id = ?`
if err := tx.Exec(query, playerID, achievementID); err != nil {
return fmt.Errorf("failed to delete achievement update: %w", err)
}
// Delete update items
itemQuery := `DELETE FROM character_achievements_items
WHERE char_id = ? AND achievement_id = ?`
if err := tx.Exec(itemQuery, playerID, achievementID); err != nil {
return fmt.Errorf("failed to delete update items: %w", err)
}
return nil
})
}
// SaveAchievement saves or updates an achievement in the database
func SaveAchievement(db *database.DB, achievement *Achievement) error {
return db.Transaction(func(tx *database.DB) error {
// Save main achievement record
query := `INSERT OR REPLACE INTO achievements
(achievement_id, title, uncompleted_text, completed_text,
category, expansion, icon, point_value, qty_req,
hide_achievement, unknown3a, unknown3b)
VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?)`
if err := tx.Exec(query, achievement.ID, achievement.Title,
achievement.UncompletedText, achievement.CompletedText,
achievement.Category, achievement.Expansion, achievement.Icon,
achievement.PointValue, achievement.QtyRequired, achievement.Hide,
achievement.Unknown3A, achievement.Unknown3B); err != nil {
return fmt.Errorf("failed to save achievement: %w", err)
}
// Delete existing requirements and rewards
if err := tx.Exec("DELETE FROM achievements_requirements WHERE achievement_id = ?", achievement.ID); err != nil {
return fmt.Errorf("failed to delete old requirements: %w", err)
}
if err := tx.Exec("DELETE FROM achievements_rewards WHERE achievement_id = ?", achievement.ID); err != nil {
return fmt.Errorf("failed to delete old rewards: %w", err)
}
// Insert requirements
reqQuery := `INSERT INTO achievements_requirements
(achievement_id, name, qty_req) VALUES (?, ?, ?)`
for _, req := range achievement.Requirements {
if err := tx.Exec(reqQuery, req.AchievementID, req.Name, req.QtyRequired); err != nil {
return fmt.Errorf("failed to save requirement: %w", err)
}
}
// Insert rewards
rewardQuery := `INSERT INTO achievements_rewards
(achievement_id, reward) VALUES (?, ?)`
for _, reward := range achievement.Rewards {
if err := tx.Exec(rewardQuery, reward.AchievementID, reward.Reward); err != nil {
return fmt.Errorf("failed to save reward: %w", err)
}
}
return nil
})
}
// DeleteAchievement removes an achievement and all related records from database
func DeleteAchievement(db *database.DB, achievementID uint32) error {
return db.Transaction(func(tx *database.DB) error {
// Delete main achievement
if err := tx.Exec("DELETE FROM achievements WHERE achievement_id = ?", achievementID); err != nil {
return fmt.Errorf("failed to delete achievement: %w", err)
}
// Delete requirements
if err := tx.Exec("DELETE FROM achievements_requirements WHERE achievement_id = ?", achievementID); err != nil {
return fmt.Errorf("failed to delete requirements: %w", err)
}
// Delete rewards
if err := tx.Exec("DELETE FROM achievements_rewards WHERE achievement_id = ?", achievementID); err != nil {
return fmt.Errorf("failed to delete rewards: %w", err)
}
// Delete player progress (optional - might want to preserve history)
if err := tx.Exec("DELETE FROM character_achievements WHERE achievement_id = ?", achievementID); err != nil {
return fmt.Errorf("failed to delete player achievements: %w", err)
}
if err := tx.Exec("DELETE FROM character_achievements_items WHERE achievement_id = ?", achievementID); err != nil {
return fmt.Errorf("failed to delete player achievement items: %w", err)
}
return nil
})
}

View File

@ -0,0 +1,32 @@
// Package achievements provides a complete achievement system for EQ2Emulator servers.
//
// The package includes:
// - Achievement definitions with requirements and rewards
// - Master achievement list for server-wide management
// - Player-specific achievement tracking and progress
// - Database operations for persistence
//
// Basic usage:
//
// // Create master list and load from database
// masterList := achievements.NewMasterList()
// db, _ := database.Open("world.db")
// achievements.LoadAllAchievements(db, masterList)
//
// // Create player manager
// playerMgr := achievements.NewPlayerManager()
// achievements.LoadPlayerAchievements(db, playerID, playerMgr.Achievements)
// achievements.LoadPlayerAchievementUpdates(db, playerID, playerMgr.Updates)
//
// // Update player progress
// playerMgr.Updates.UpdateProgress(achievementID, newProgress)
//
// // Check completion
// if playerMgr.Updates.IsCompleted(achievementID) {
// // Handle completed achievement
// }
//
// // Save progress
// update := playerMgr.Updates.GetUpdate(achievementID)
// achievements.SavePlayerAchievementUpdate(db, playerID, update)
package achievements

View File

@ -0,0 +1,197 @@
package achievements
import (
"fmt"
"sync"
)
// MasterList manages the global list of all achievements
type MasterList struct {
achievements map[uint32]*Achievement
mutex sync.RWMutex
}
// NewMasterList creates a new master achievement list
func NewMasterList() *MasterList {
return &MasterList{
achievements: make(map[uint32]*Achievement),
}
}
// AddAchievement adds an achievement to the master list
// Returns false if achievement with same ID already exists
func (m *MasterList) AddAchievement(achievement *Achievement) bool {
if achievement == nil {
return false
}
m.mutex.Lock()
defer m.mutex.Unlock()
if _, exists := m.achievements[achievement.ID]; exists {
return false
}
m.achievements[achievement.ID] = achievement
return true
}
// GetAchievement retrieves an achievement by ID
// Returns nil if not found
func (m *MasterList) GetAchievement(id uint32) *Achievement {
m.mutex.RLock()
defer m.mutex.RUnlock()
return m.achievements[id]
}
// GetAchievementClone retrieves a cloned copy of an achievement by ID
// Returns nil if not found. Safe for modification without affecting master list
func (m *MasterList) GetAchievementClone(id uint32) *Achievement {
m.mutex.RLock()
achievement := m.achievements[id]
m.mutex.RUnlock()
if achievement == nil {
return nil
}
return achievement.Clone()
}
// GetAllAchievements returns a map of all achievements (read-only access)
// The returned map should not be modified
func (m *MasterList) GetAllAchievements() map[uint32]*Achievement {
m.mutex.RLock()
defer m.mutex.RUnlock()
// Return copy of map to prevent external modification
result := make(map[uint32]*Achievement, len(m.achievements))
for id, achievement := range m.achievements {
result[id] = achievement
}
return result
}
// GetAchievementsByCategory returns achievements filtered by category
func (m *MasterList) GetAchievementsByCategory(category string) []*Achievement {
m.mutex.RLock()
defer m.mutex.RUnlock()
var result []*Achievement
for _, achievement := range m.achievements {
if achievement.Category == category {
result = append(result, achievement)
}
}
return result
}
// GetAchievementsByExpansion returns achievements filtered by expansion
func (m *MasterList) GetAchievementsByExpansion(expansion string) []*Achievement {
m.mutex.RLock()
defer m.mutex.RUnlock()
var result []*Achievement
for _, achievement := range m.achievements {
if achievement.Expansion == expansion {
result = append(result, achievement)
}
}
return result
}
// RemoveAchievement removes an achievement from the master list
// Returns true if achievement was found and removed
func (m *MasterList) RemoveAchievement(id uint32) bool {
m.mutex.Lock()
defer m.mutex.Unlock()
if _, exists := m.achievements[id]; !exists {
return false
}
delete(m.achievements, id)
return true
}
// UpdateAchievement updates an existing achievement
// Returns error if achievement doesn't exist
func (m *MasterList) UpdateAchievement(achievement *Achievement) error {
if achievement == nil {
return fmt.Errorf("achievement cannot be nil")
}
m.mutex.Lock()
defer m.mutex.Unlock()
if _, exists := m.achievements[achievement.ID]; !exists {
return fmt.Errorf("achievement with ID %d does not exist", achievement.ID)
}
m.achievements[achievement.ID] = achievement
return nil
}
// Clear removes all achievements from the master list
func (m *MasterList) Clear() {
m.mutex.Lock()
defer m.mutex.Unlock()
m.achievements = make(map[uint32]*Achievement)
}
// Size returns the number of achievements in the master list
func (m *MasterList) Size() int {
m.mutex.RLock()
defer m.mutex.RUnlock()
return len(m.achievements)
}
// Exists checks if an achievement with given ID exists
func (m *MasterList) Exists(id uint32) bool {
m.mutex.RLock()
defer m.mutex.RUnlock()
_, exists := m.achievements[id]
return exists
}
// GetCategories returns all unique categories
func (m *MasterList) GetCategories() []string {
m.mutex.RLock()
defer m.mutex.RUnlock()
categories := make(map[string]bool)
for _, achievement := range m.achievements {
if achievement.Category != "" {
categories[achievement.Category] = true
}
}
result := make([]string, 0, len(categories))
for category := range categories {
result = append(result, category)
}
return result
}
// GetExpansions returns all unique expansions
func (m *MasterList) GetExpansions() []string {
m.mutex.RLock()
defer m.mutex.RUnlock()
expansions := make(map[string]bool)
for _, achievement := range m.achievements {
if achievement.Expansion != "" {
expansions[achievement.Expansion] = true
}
}
result := make([]string, 0, len(expansions))
for expansion := range expansions {
result = append(result, expansion)
}
return result
}

View File

@ -0,0 +1,282 @@
package achievements
import (
"fmt"
"time"
)
// PlayerList manages achievements for a specific player
type PlayerList struct {
achievements map[uint32]*Achievement
}
// PlayerUpdateList manages achievement updates/progress for a specific player
type PlayerUpdateList struct {
updates map[uint32]*Update
}
// NewPlayerList creates a new player achievement list
func NewPlayerList() *PlayerList {
return &PlayerList{
achievements: make(map[uint32]*Achievement),
}
}
// NewPlayerUpdateList creates a new player achievement update list
func NewPlayerUpdateList() *PlayerUpdateList {
return &PlayerUpdateList{
updates: make(map[uint32]*Update),
}
}
// AddAchievement adds an achievement to the player's list
// Returns false if achievement with same ID already exists
func (p *PlayerList) AddAchievement(achievement *Achievement) bool {
if achievement == nil {
return false
}
if _, exists := p.achievements[achievement.ID]; exists {
return false
}
p.achievements[achievement.ID] = achievement
return true
}
// GetAchievement retrieves an achievement by ID
// Returns nil if not found
func (p *PlayerList) GetAchievement(id uint32) *Achievement {
return p.achievements[id]
}
// GetAllAchievements returns all player achievements
func (p *PlayerList) GetAllAchievements() map[uint32]*Achievement {
result := make(map[uint32]*Achievement, len(p.achievements))
for id, achievement := range p.achievements {
result[id] = achievement
}
return result
}
// RemoveAchievement removes an achievement from the player's list
// Returns true if achievement was found and removed
func (p *PlayerList) RemoveAchievement(id uint32) bool {
if _, exists := p.achievements[id]; !exists {
return false
}
delete(p.achievements, id)
return true
}
// HasAchievement checks if player has a specific achievement
func (p *PlayerList) HasAchievement(id uint32) bool {
_, exists := p.achievements[id]
return exists
}
// Clear removes all achievements from the player's list
func (p *PlayerList) Clear() {
p.achievements = make(map[uint32]*Achievement)
}
// Size returns the number of achievements in the player's list
func (p *PlayerList) Size() int {
return len(p.achievements)
}
// GetAchievementsByCategory returns player achievements filtered by category
func (p *PlayerList) GetAchievementsByCategory(category string) []*Achievement {
var result []*Achievement
for _, achievement := range p.achievements {
if achievement.Category == category {
result = append(result, achievement)
}
}
return result
}
// AddUpdate adds an achievement update to the player's list
// Returns false if update with same ID already exists
func (p *PlayerUpdateList) AddUpdate(update *Update) bool {
if update == nil {
return false
}
if _, exists := p.updates[update.ID]; exists {
return false
}
p.updates[update.ID] = update
return true
}
// GetUpdate retrieves an achievement update by ID
// Returns nil if not found
func (p *PlayerUpdateList) GetUpdate(id uint32) *Update {
return p.updates[id]
}
// GetAllUpdates returns all player achievement updates
func (p *PlayerUpdateList) GetAllUpdates() map[uint32]*Update {
result := make(map[uint32]*Update, len(p.updates))
for id, update := range p.updates {
result[id] = update
}
return result
}
// UpdateProgress updates or creates achievement progress
func (p *PlayerUpdateList) UpdateProgress(achievementID uint32, itemUpdate uint32) {
update := p.updates[achievementID]
if update == nil {
update = NewUpdate()
update.ID = achievementID
p.updates[achievementID] = update
}
// Add or update the progress item
found := false
for i := range update.UpdateItems {
if update.UpdateItems[i].AchievementID == achievementID {
update.UpdateItems[i].ItemUpdate = itemUpdate
found = true
break
}
}
if !found {
update.AddUpdateItem(UpdateItem{
AchievementID: achievementID,
ItemUpdate: itemUpdate,
})
}
}
// CompleteAchievement marks an achievement as completed
func (p *PlayerUpdateList) CompleteAchievement(achievementID uint32) {
update := p.updates[achievementID]
if update == nil {
update = NewUpdate()
update.ID = achievementID
p.updates[achievementID] = update
}
update.CompletedDate = time.Now()
}
// IsCompleted checks if an achievement is completed
func (p *PlayerUpdateList) IsCompleted(achievementID uint32) bool {
update := p.updates[achievementID]
return update != nil && !update.CompletedDate.IsZero()
}
// GetCompletedDate returns the completion date for an achievement
// Returns zero time if not completed
func (p *PlayerUpdateList) GetCompletedDate(achievementID uint32) time.Time {
update := p.updates[achievementID]
if update == nil {
return time.Time{}
}
return update.CompletedDate
}
// GetProgress returns the current progress for an achievement
// Returns 0 if no progress found
func (p *PlayerUpdateList) GetProgress(achievementID uint32) uint32 {
update := p.updates[achievementID]
if update == nil || len(update.UpdateItems) == 0 {
return 0
}
// Return the first matching update item's progress
for _, item := range update.UpdateItems {
if item.AchievementID == achievementID {
return item.ItemUpdate
}
}
return 0
}
// RemoveUpdate removes an achievement update from the player's list
// Returns true if update was found and removed
func (p *PlayerUpdateList) RemoveUpdate(id uint32) bool {
if _, exists := p.updates[id]; !exists {
return false
}
delete(p.updates, id)
return true
}
// Clear removes all updates from the player's list
func (p *PlayerUpdateList) Clear() {
p.updates = make(map[uint32]*Update)
}
// Size returns the number of updates in the player's list
func (p *PlayerUpdateList) Size() int {
return len(p.updates)
}
// GetCompletedAchievements returns all completed achievement IDs
func (p *PlayerUpdateList) GetCompletedAchievements() []uint32 {
var completed []uint32
for id, update := range p.updates {
if !update.CompletedDate.IsZero() {
completed = append(completed, id)
}
}
return completed
}
// GetInProgressAchievements returns all in-progress achievement IDs
func (p *PlayerUpdateList) GetInProgressAchievements() []uint32 {
var inProgress []uint32
for id, update := range p.updates {
if update.CompletedDate.IsZero() && len(update.UpdateItems) > 0 {
inProgress = append(inProgress, id)
}
}
return inProgress
}
// PlayerManager combines achievement list and update list for a player
type PlayerManager struct {
Achievements *PlayerList
Updates *PlayerUpdateList
}
// NewPlayerManager creates a new player manager
func NewPlayerManager() *PlayerManager {
return &PlayerManager{
Achievements: NewPlayerList(),
Updates: NewPlayerUpdateList(),
}
}
// CheckRequirements validates if player meets achievement requirements
// This is a basic implementation - extend as needed for specific game logic
func (pm *PlayerManager) CheckRequirements(achievement *Achievement) (bool, error) {
if achievement == nil {
return false, fmt.Errorf("achievement cannot be nil")
}
// Basic implementation - check if we have progress >= required quantity
progress := pm.Updates.GetProgress(achievement.ID)
return progress >= achievement.QtyRequired, nil
}
// GetCompletionStatus returns completion percentage for an achievement
func (pm *PlayerManager) GetCompletionStatus(achievement *Achievement) float64 {
if achievement == nil || achievement.QtyRequired == 0 {
return 0.0
}
progress := pm.Updates.GetProgress(achievement.ID)
if progress >= achievement.QtyRequired {
return 100.0
}
return (float64(progress) / float64(achievement.QtyRequired)) * 100.0
}

View File

@ -0,0 +1,113 @@
package achievements
import "time"
// Requirement represents a single achievement requirement
type Requirement struct {
AchievementID uint32 `json:"achievement_id"`
Name string `json:"name"`
QtyRequired uint32 `json:"qty_required"`
}
// Reward represents a single achievement reward
type Reward struct {
AchievementID uint32 `json:"achievement_id"`
Reward string `json:"reward"`
}
// Achievement represents a complete achievement definition
type Achievement struct {
ID uint32 `json:"id"`
Title string `json:"title"`
UncompletedText string `json:"uncompleted_text"`
CompletedText string `json:"completed_text"`
Category string `json:"category"`
Expansion string `json:"expansion"`
Icon uint16 `json:"icon"`
PointValue uint32 `json:"point_value"`
QtyRequired uint32 `json:"qty_required"`
Hide bool `json:"hide"`
Unknown3A uint32 `json:"unknown3a"`
Unknown3B uint32 `json:"unknown3b"`
Requirements []Requirement `json:"requirements"`
Rewards []Reward `json:"rewards"`
}
// UpdateItem represents a single achievement progress update
type UpdateItem struct {
AchievementID uint32 `json:"achievement_id"`
ItemUpdate uint32 `json:"item_update"`
}
// Update represents achievement completion/progress data
type Update struct {
ID uint32 `json:"id"`
CompletedDate time.Time `json:"completed_date"`
UpdateItems []UpdateItem `json:"update_items"`
}
// NewAchievement creates a new achievement with empty slices
func NewAchievement() *Achievement {
return &Achievement{
Requirements: make([]Requirement, 0),
Rewards: make([]Reward, 0),
}
}
// NewUpdate creates a new achievement update with empty slices
func NewUpdate() *Update {
return &Update{
UpdateItems: make([]UpdateItem, 0),
}
}
// AddRequirement adds a requirement to the achievement
func (a *Achievement) AddRequirement(req Requirement) {
a.Requirements = append(a.Requirements, req)
}
// AddReward adds a reward to the achievement
func (a *Achievement) AddReward(reward Reward) {
a.Rewards = append(a.Rewards, reward)
}
// AddUpdateItem adds an update item to the achievement update
func (u *Update) AddUpdateItem(item UpdateItem) {
u.UpdateItems = append(u.UpdateItems, item)
}
// Clone creates a deep copy of the achievement
func (a *Achievement) Clone() *Achievement {
clone := &Achievement{
ID: a.ID,
Title: a.Title,
UncompletedText: a.UncompletedText,
CompletedText: a.CompletedText,
Category: a.Category,
Expansion: a.Expansion,
Icon: a.Icon,
PointValue: a.PointValue,
QtyRequired: a.QtyRequired,
Hide: a.Hide,
Unknown3A: a.Unknown3A,
Unknown3B: a.Unknown3B,
Requirements: make([]Requirement, len(a.Requirements)),
Rewards: make([]Reward, len(a.Rewards)),
}
copy(clone.Requirements, a.Requirements)
copy(clone.Rewards, a.Rewards)
return clone
}
// Clone creates a deep copy of the achievement update
func (u *Update) Clone() *Update {
clone := &Update{
ID: u.ID,
CompletedDate: u.CompletedDate,
UpdateItems: make([]UpdateItem, len(u.UpdateItems)),
}
copy(clone.UpdateItems, u.UpdateItems)
return clone
}

View File

@ -0,0 +1,262 @@
package database
import (
"fmt"
"zombiezen.com/go/sqlite"
"zombiezen.com/go/sqlite/sqlitex"
)
// DB wraps sqlite.Conn with simplified query methods
type DB struct {
conn *sqlite.Conn
}
// Row represents a single database row with easy column access
type Row struct {
stmt *sqlite.Stmt
}
// QueryFunc processes each row in a result set
type QueryFunc func(*Row) error
// Open creates a new database connection with common settings
func Open(path string) (*DB, error) {
conn, err := sqlite.OpenConn(path, sqlite.OpenReadWrite|sqlite.OpenCreate)
if err != nil {
return nil, fmt.Errorf("failed to open database: %w", err)
}
// Enable foreign keys and WAL mode for better performance
if err := sqlitex.ExecuteTransient(conn, "PRAGMA foreign_keys = ON", nil); err != nil {
conn.Close()
return nil, fmt.Errorf("failed to enable foreign keys: %w", err)
}
if err := sqlitex.ExecuteTransient(conn, "PRAGMA journal_mode = WAL", nil); err != nil {
conn.Close()
return nil, fmt.Errorf("failed to enable WAL mode: %w", err)
}
return &DB{conn: conn}, nil
}
// Close closes the database connection
func (db *DB) Close() error {
return db.conn.Close()
}
// Exec executes a statement with parameters
func (db *DB) Exec(query string, args ...any) error {
return sqlitex.Execute(db.conn, query, &sqlitex.ExecOptions{
Args: args,
})
}
// QueryRow executes a query expecting a single row result
func (db *DB) QueryRow(query string, args ...any) (*Row, error) {
stmt, err := db.conn.Prepare(query)
if err != nil {
return nil, fmt.Errorf("prepare failed: %w", err)
}
// Bind parameters
for i, arg := range args {
if err := bindParam(stmt, i+1, arg); err != nil {
stmt.Finalize()
return nil, err
}
}
hasRow, err := stmt.Step()
if err != nil {
stmt.Finalize()
return nil, fmt.Errorf("query failed: %w", err)
}
if !hasRow {
stmt.Finalize()
return nil, nil // No row found
}
return &Row{stmt: stmt}, nil
}
// Query executes a query and calls fn for each row
func (db *DB) Query(query string, fn QueryFunc, args ...any) error {
stmt, err := db.conn.Prepare(query)
if err != nil {
return fmt.Errorf("prepare failed: %w", err)
}
defer stmt.Finalize()
// Bind parameters
for i, arg := range args {
if err := bindParam(stmt, i+1, arg); err != nil {
return err
}
}
row := &Row{stmt: stmt}
for {
hasRow, err := stmt.Step()
if err != nil {
return fmt.Errorf("query failed: %w", err)
}
if !hasRow {
break
}
if err := fn(row); err != nil {
return err
}
}
return nil
}
// QuerySlice executes a query and returns all rows in a slice
func (db *DB) QuerySlice(query string, args ...any) ([]*Row, error) {
var rows []*Row
stmt, err := db.conn.Prepare(query)
if err != nil {
return nil, fmt.Errorf("prepare failed: %w", err)
}
defer stmt.Finalize()
// Bind parameters
for i, arg := range args {
if err := bindParam(stmt, i+1, arg); err != nil {
return nil, err
}
}
for {
hasRow, err := stmt.Step()
if err != nil {
return nil, fmt.Errorf("query failed: %w", err)
}
if !hasRow {
break
}
// Create a snapshot of the current row
rowData := &Row{stmt: stmt}
rows = append(rows, rowData)
}
return rows, nil
}
// LastInsertID returns the last inserted row ID
func (db *DB) LastInsertID() int64 {
return db.conn.LastInsertRowID()
}
// Changes returns the number of rows affected by the last statement
func (db *DB) Changes() int {
return db.conn.Changes()
}
// Transaction executes fn within a database transaction
func (db *DB) Transaction(fn func(*DB) error) error {
if err := sqlitex.ExecuteTransient(db.conn, "BEGIN", nil); err != nil {
return fmt.Errorf("begin transaction failed: %w", err)
}
if err := fn(db); err != nil {
sqlitex.ExecuteTransient(db.conn, "ROLLBACK", nil)
return err
}
if err := sqlitex.ExecuteTransient(db.conn, "COMMIT", nil); err != nil {
return fmt.Errorf("commit transaction failed: %w", err)
}
return nil
}
// Row column access methods
// Close releases the row's statement
func (r *Row) Close() {
if r.stmt != nil {
r.stmt.Finalize()
r.stmt = nil
}
}
// Int returns column as int
func (r *Row) Int(col int) int {
return r.stmt.ColumnInt(col)
}
// Int64 returns column as int64
func (r *Row) Int64(col int) int64 {
return r.stmt.ColumnInt64(col)
}
// Text returns column as string
func (r *Row) Text(col int) string {
return r.stmt.ColumnText(col)
}
// Bool returns column as bool (0 = false, non-zero = true)
func (r *Row) Bool(col int) bool {
return r.stmt.ColumnInt(col) != 0
}
// Float returns column as float64
func (r *Row) Float(col int) float64 {
return r.stmt.ColumnFloat(col)
}
// IsNull checks if column is NULL
func (r *Row) IsNull(col int) bool {
return r.stmt.ColumnType(col) == sqlite.TypeNull
}
// bindParam binds a parameter to a statement at the given index
func bindParam(stmt *sqlite.Stmt, index int, value any) error {
switch v := value.(type) {
case nil:
stmt.BindNull(index)
case int:
stmt.BindInt64(index, int64(v))
case int8:
stmt.BindInt64(index, int64(v))
case int16:
stmt.BindInt64(index, int64(v))
case int32:
stmt.BindInt64(index, int64(v))
case int64:
stmt.BindInt64(index, v)
case uint:
stmt.BindInt64(index, int64(v))
case uint8:
stmt.BindInt64(index, int64(v))
case uint16:
stmt.BindInt64(index, int64(v))
case uint32:
stmt.BindInt64(index, int64(v))
case uint64:
stmt.BindInt64(index, int64(v))
case float32:
stmt.BindFloat(index, float64(v))
case float64:
stmt.BindFloat(index, v)
case bool:
if v {
stmt.BindInt64(index, 1)
} else {
stmt.BindInt64(index, 0)
}
case string:
stmt.BindText(index, v)
case []byte:
stmt.BindBytes(index, v)
default:
return fmt.Errorf("unsupported parameter type: %T", value)
}
return nil
}