eq2go/internal/chat/database.go

270 lines
7.6 KiB
Go

package chat
import (
"context"
"fmt"
"zombiezen.com/go/sqlite"
"zombiezen.com/go/sqlite/sqlitex"
)
// DatabaseChannelManager is the sqlitex.Pool-backed implementation of the
// ChannelDatabase interface. It carries no state other than the pool itself.
type DatabaseChannelManager struct {
	// pool is the source of the SQLite connections used by every query.
	pool *sqlitex.Pool
}
// NewDatabaseChannelManager returns a DatabaseChannelManager that issues its
// queries against pool. The pool is stored as given (no validation, no copy)
// and nothing in this type closes it; lifetime management stays with the caller.
func NewDatabaseChannelManager(pool *sqlitex.Pool) *DatabaseChannelManager {
	manager := new(DatabaseChannelManager)
	manager.pool = pool
	return manager
}
// LoadWorldChannels returns every row of the channels table as ChatChannelData.
//
// ctx is handed to pool.Take, so it bounds both the wait for a free pooled
// connection and the lifetime of the query run on that connection. A NULL
// password is reported as the empty string. The result is nil (not an empty
// slice) when the table has no rows.
func (dcm *DatabaseChannelManager) LoadWorldChannels(ctx context.Context) ([]ChatChannelData, error) {
	// Use the caller's ctx (previously context.Background()) so cancellation
	// and deadlines are honored instead of blocking forever on a busy pool.
	conn, err := dcm.pool.Take(ctx)
	if err != nil {
		return nil, fmt.Errorf("failed to get connection: %w", err)
	}
	defer dcm.pool.Put(conn)

	query := "SELECT `name`, `password`, `level_restriction`, `classes`, `races` FROM `channels`"
	var channels []ChatChannelData
	err = sqlitex.Execute(conn, query, &sqlitex.ExecOptions{
		ResultFunc: func(stmt *sqlite.Stmt) error {
			var channel ChatChannelData
			channel.Name = stmt.ColumnText(0)
			if stmt.ColumnType(1) != sqlite.TypeNull {
				channel.Password = stmt.ColumnText(1)
			}
			channel.LevelRestriction = int32(stmt.ColumnInt64(2))
			channel.ClassRestriction = int32(stmt.ColumnInt64(3))
			channel.RaceRestriction = int32(stmt.ColumnInt64(4))
			channels = append(channels, channel)
			return nil
		},
	})
	if err != nil {
		return nil, fmt.Errorf("failed to query channels: %w", err)
	}
	return channels, nil
}
// SaveChannel inserts channel into the channels table, or updates the existing
// row with the same name (world channels only).
//
// An empty Password is stored as NULL. On update, the password and the three
// restriction columns are overwritten and updated_at is set to the current
// time, while created_at is left untouched. The previous INSERT OR REPLACE
// deleted and re-inserted the row, which reset created_at on every save.
// The upsert relies on a PRIMARY KEY/UNIQUE constraint on name, as declared
// by EnsureChannelsTable.
//
// ctx is handed to pool.Take, so it bounds both the wait for a connection and
// the execution of the statement.
func (dcm *DatabaseChannelManager) SaveChannel(ctx context.Context, channel ChatChannelData) error {
	conn, err := dcm.pool.Take(ctx)
	if err != nil {
		return fmt.Errorf("failed to get connection: %w", err)
	}
	defer dcm.pool.Put(conn)

	// Insert, or update in place on a name collision (UPSERT), preserving created_at.
	query := `
INSERT INTO channels
	(name, password, level_restriction, classes, races)
VALUES (?, ?, ?, ?, ?)
ON CONFLICT(name) DO UPDATE SET
	password = excluded.password,
	level_restriction = excluded.level_restriction,
	classes = excluded.classes,
	races = excluded.races,
	updated_at = CURRENT_TIMESTAMP`

	// A nil interface value is bound as SQL NULL.
	var password any
	if channel.Password != "" {
		password = channel.Password
	}
	err = sqlitex.Execute(conn, query, &sqlitex.ExecOptions{
		Args: []any{
			channel.Name,
			password,
			channel.LevelRestriction,
			channel.ClassRestriction,
			channel.RaceRestriction,
		},
	})
	if err != nil {
		return fmt.Errorf("failed to save channel %s: %w", channel.Name, err)
	}
	return nil
}
// DeleteChannel removes the row whose name equals channelName from the
// channels table. Deleting a name that does not exist is not an error.
//
// ctx is handed to pool.Take, so it bounds both the wait for a connection and
// the execution of the statement.
func (dcm *DatabaseChannelManager) DeleteChannel(ctx context.Context, channelName string) error {
	// Use the caller's ctx rather than context.Background().
	conn, err := dcm.pool.Take(ctx)
	if err != nil {
		return fmt.Errorf("failed to get connection: %w", err)
	}
	defer dcm.pool.Put(conn)

	query := "DELETE FROM channels WHERE name = ?"
	err = sqlitex.Execute(conn, query, &sqlitex.ExecOptions{
		Args: []any{channelName},
	})
	if err != nil {
		return fmt.Errorf("failed to delete channel %s: %w", channelName, err)
	}
	return nil
}
// EnsureChannelsTable creates the channels table if it does not already exist.
// An existing table is left as-is; its schema is not checked or migrated.
//
// name is the PRIMARY KEY; the restriction columns default to 0, and
// created_at / updated_at default to the insertion time.
//
// ctx is handed to pool.Take, so it bounds both the wait for a connection and
// the execution of the statement.
func (dcm *DatabaseChannelManager) EnsureChannelsTable(ctx context.Context) error {
	// Use the caller's ctx rather than context.Background().
	conn, err := dcm.pool.Take(ctx)
	if err != nil {
		return fmt.Errorf("failed to get connection: %w", err)
	}
	defer dcm.pool.Put(conn)

	query := `
CREATE TABLE IF NOT EXISTS channels (
	name TEXT PRIMARY KEY,
	password TEXT,
	level_restriction INTEGER NOT NULL DEFAULT 0,
	classes INTEGER NOT NULL DEFAULT 0,
	races INTEGER NOT NULL DEFAULT 0,
	created_at DATETIME DEFAULT CURRENT_TIMESTAMP,
	updated_at DATETIME DEFAULT CURRENT_TIMESTAMP
)`
	err = sqlitex.Execute(conn, query, nil)
	if err != nil {
		return fmt.Errorf("failed to create channels table: %w", err)
	}
	return nil
}
// GetChannelCount returns the number of rows in the channels table.
//
// ctx is handed to pool.Take, so it bounds both the wait for a connection and
// the execution of the query.
func (dcm *DatabaseChannelManager) GetChannelCount(ctx context.Context) (int, error) {
	// Use the caller's ctx rather than context.Background().
	conn, err := dcm.pool.Take(ctx)
	if err != nil {
		return 0, fmt.Errorf("failed to get connection: %w", err)
	}
	defer dcm.pool.Put(conn)

	query := "SELECT COUNT(*) FROM channels"
	var count int
	err = sqlitex.Execute(conn, query, &sqlitex.ExecOptions{
		ResultFunc: func(stmt *sqlite.Stmt) error {
			// COUNT(*) always yields exactly one row.
			count = int(stmt.ColumnInt64(0))
			return nil
		},
	})
	if err != nil {
		return 0, fmt.Errorf("failed to query channel count: %w", err)
	}
	return count, nil
}
// GetChannelByName returns the channel whose name equals channelName.
//
// A NULL password is reported as the empty string. If no row matches, it
// returns a nil channel and a "channel <name> not found" error. If several rows
// match (impossible with the PRIMARY KEY schema from EnsureChannelsTable), the
// last row read wins.
//
// ctx is handed to pool.Take, so it bounds both the wait for a connection and
// the execution of the query.
func (dcm *DatabaseChannelManager) GetChannelByName(ctx context.Context, channelName string) (*ChatChannelData, error) {
	// Use the caller's ctx rather than context.Background().
	conn, err := dcm.pool.Take(ctx)
	if err != nil {
		return nil, fmt.Errorf("failed to get connection: %w", err)
	}
	defer dcm.pool.Put(conn)

	query := "SELECT `name`, `password`, `level_restriction`, `classes`, `races` FROM `channels` WHERE `name` = ?"
	var channel *ChatChannelData
	err = sqlitex.Execute(conn, query, &sqlitex.ExecOptions{
		Args: []any{channelName},
		ResultFunc: func(stmt *sqlite.Stmt) error {
			channel = &ChatChannelData{}
			channel.Name = stmt.ColumnText(0)
			if stmt.ColumnType(1) != sqlite.TypeNull {
				channel.Password = stmt.ColumnText(1)
			}
			channel.LevelRestriction = int32(stmt.ColumnInt64(2))
			channel.ClassRestriction = int32(stmt.ColumnInt64(3))
			channel.RaceRestriction = int32(stmt.ColumnInt64(4))
			return nil
		},
	})
	if err != nil {
		return nil, fmt.Errorf("failed to query channel %s: %w", channelName, err)
	}
	if channel == nil {
		return nil, fmt.Errorf("channel %s not found", channelName)
	}
	return channel, nil
}
// ListChannelNames returns the name of every row in the channels table, sorted
// by SQLite's ORDER BY on name. The result is nil when the table is empty.
//
// ctx is handed to pool.Take, so it bounds both the wait for a connection and
// the execution of the query.
func (dcm *DatabaseChannelManager) ListChannelNames(ctx context.Context) ([]string, error) {
	// Use the caller's ctx rather than context.Background().
	conn, err := dcm.pool.Take(ctx)
	if err != nil {
		return nil, fmt.Errorf("failed to get connection: %w", err)
	}
	defer dcm.pool.Put(conn)

	query := "SELECT name FROM channels ORDER BY name"
	var names []string
	err = sqlitex.Execute(conn, query, &sqlitex.ExecOptions{
		ResultFunc: func(stmt *sqlite.Stmt) error {
			names = append(names, stmt.ColumnText(0))
			return nil
		},
	})
	if err != nil {
		return nil, fmt.Errorf("failed to query channel names: %w", err)
	}
	return names, nil
}
// UpdateChannelPassword sets the password of the channel named channelName and
// bumps its updated_at timestamp. An empty password is stored as NULL.
//
// If no row has that name, the UPDATE affects nothing and nil is returned; no
// "not found" error is reported.
//
// ctx is handed to pool.Take, so it bounds both the wait for a connection and
// the execution of the statement.
func (dcm *DatabaseChannelManager) UpdateChannelPassword(ctx context.Context, channelName, password string) error {
	// Use the caller's ctx rather than context.Background().
	conn, err := dcm.pool.Take(ctx)
	if err != nil {
		return fmt.Errorf("failed to get connection: %w", err)
	}
	defer dcm.pool.Put(conn)

	query := "UPDATE channels SET password = ?, updated_at = CURRENT_TIMESTAMP WHERE name = ?"
	// A nil interface value is bound as SQL NULL.
	var passwordParam any
	if password != "" {
		passwordParam = password
	}
	err = sqlitex.Execute(conn, query, &sqlitex.ExecOptions{
		Args: []any{passwordParam, channelName},
	})
	if err != nil {
		return fmt.Errorf("failed to update password for channel %s: %w", channelName, err)
	}
	return nil
}
// UpdateChannelRestrictions overwrites the level_restriction, classes and races
// columns of the channel named channelName and bumps its updated_at timestamp.
//
// If no row has that name, the UPDATE affects nothing and nil is returned; no
// "not found" error is reported.
//
// ctx is handed to pool.Take, so it bounds both the wait for a connection and
// the execution of the statement.
func (dcm *DatabaseChannelManager) UpdateChannelRestrictions(ctx context.Context, channelName string, levelRestriction, classRestriction, raceRestriction int32) error {
	// Use the caller's ctx rather than context.Background().
	conn, err := dcm.pool.Take(ctx)
	if err != nil {
		return fmt.Errorf("failed to get connection: %w", err)
	}
	defer dcm.pool.Put(conn)

	query := "UPDATE channels SET level_restriction = ?, classes = ?, races = ?, updated_at = CURRENT_TIMESTAMP WHERE name = ?"
	err = sqlitex.Execute(conn, query, &sqlitex.ExecOptions{
		Args: []any{levelRestriction, classRestriction, raceRestriction, channelName},
	})
	if err != nil {
		return fmt.Errorf("failed to update restrictions for channel %s: %w", channelName, err)
	}
	return nil
}