fix chat package with raw sqlite usage
This commit is contained in:
parent
1288bc086f
commit
674b14f278
@ -876,7 +876,7 @@ func TestMockLanguageProcessor(t *testing.T) {
|
|||||||
func BenchmarkChannelJoin(b *testing.B) {
|
func BenchmarkChannelJoin(b *testing.B) {
|
||||||
channel := NewChannel("benchmark")
|
channel := NewChannel("benchmark")
|
||||||
|
|
||||||
for i := 0; b.Loop(); i++ {
|
for i := 0; i < b.N; i++ {
|
||||||
channel.JoinChannel(int32(i))
|
channel.JoinChannel(int32(i))
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@ -888,7 +888,7 @@ func BenchmarkChannelIsInChannel(b *testing.B) {
|
|||||||
channel.JoinChannel(int32(i))
|
channel.JoinChannel(int32(i))
|
||||||
}
|
}
|
||||||
|
|
||||||
for i := 0; b.Loop(); i++ {
|
for i := 0; i < b.N; i++ {
|
||||||
channel.IsInChannel(int32(i % 1000))
|
channel.IsInChannel(int32(i % 1000))
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@ -900,7 +900,7 @@ func BenchmarkChannelGetMembers(b *testing.B) {
|
|||||||
channel.JoinChannel(int32(i))
|
channel.JoinChannel(int32(i))
|
||||||
}
|
}
|
||||||
|
|
||||||
for b.Loop() {
|
for i := 0; i < b.N; i++ {
|
||||||
channel.GetMembers()
|
channel.GetMembers()
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
@ -4,40 +4,49 @@ import (
|
|||||||
"context"
|
"context"
|
||||||
"fmt"
|
"fmt"
|
||||||
|
|
||||||
"eq2emu/internal/database"
|
"zombiezen.com/go/sqlite"
|
||||||
|
"zombiezen.com/go/sqlite/sqlitex"
|
||||||
)
|
)
|
||||||
|
|
||||||
// DatabaseChannelManager implements ChannelDatabase interface using the correct database wrapper
|
// DatabaseChannelManager implements ChannelDatabase interface using sqlitex.Pool
|
||||||
type DatabaseChannelManager struct {
|
type DatabaseChannelManager struct {
|
||||||
db *database.DB
|
pool *sqlitex.Pool
|
||||||
}
|
}
|
||||||
|
|
||||||
// NewDatabaseChannelManager creates a new database channel manager using the correct wrapper
|
// NewDatabaseChannelManager creates a new database channel manager using sqlitex.Pool
|
||||||
func NewDatabaseChannelManager(db *database.DB) *DatabaseChannelManager {
|
func NewDatabaseChannelManager(pool *sqlitex.Pool) *DatabaseChannelManager {
|
||||||
return &DatabaseChannelManager{
|
return &DatabaseChannelManager{
|
||||||
db: db,
|
pool: pool,
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// LoadWorldChannels retrieves all persistent world channels from database
|
// LoadWorldChannels retrieves all persistent world channels from database
|
||||||
func (dcm *DatabaseChannelManager) LoadWorldChannels(ctx context.Context) ([]ChatChannelData, error) {
|
func (dcm *DatabaseChannelManager) LoadWorldChannels(ctx context.Context) ([]ChatChannelData, error) {
|
||||||
|
conn, err := dcm.pool.Take(context.Background())
|
||||||
|
if err != nil {
|
||||||
|
return nil, fmt.Errorf("failed to get connection: %w", err)
|
||||||
|
}
|
||||||
|
defer dcm.pool.Put(conn)
|
||||||
|
|
||||||
query := "SELECT `name`, `password`, `level_restriction`, `classes`, `races` FROM `channels`"
|
query := "SELECT `name`, `password`, `level_restriction`, `classes`, `races` FROM `channels`"
|
||||||
|
|
||||||
var channels []ChatChannelData
|
var channels []ChatChannelData
|
||||||
|
|
||||||
err := dcm.db.Query(query, func(row *database.Row) error {
|
err = sqlitex.Execute(conn, query, &sqlitex.ExecOptions{
|
||||||
var channel ChatChannelData
|
ResultFunc: func(stmt *sqlite.Stmt) error {
|
||||||
|
var channel ChatChannelData
|
||||||
|
|
||||||
channel.Name = row.Text(0)
|
channel.Name = stmt.ColumnText(0)
|
||||||
if !row.IsNull(1) {
|
if stmt.ColumnType(1) != sqlite.TypeNull {
|
||||||
channel.Password = row.Text(1)
|
channel.Password = stmt.ColumnText(1)
|
||||||
}
|
}
|
||||||
channel.LevelRestriction = int32(row.Int64(2))
|
channel.LevelRestriction = int32(stmt.ColumnInt64(2))
|
||||||
channel.ClassRestriction = int32(row.Int64(3))
|
channel.ClassRestriction = int32(stmt.ColumnInt64(3))
|
||||||
channel.RaceRestriction = int32(row.Int64(4))
|
channel.RaceRestriction = int32(stmt.ColumnInt64(4))
|
||||||
|
|
||||||
channels = append(channels, channel)
|
channels = append(channels, channel)
|
||||||
return nil
|
return nil
|
||||||
|
},
|
||||||
})
|
})
|
||||||
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
@ -49,24 +58,32 @@ func (dcm *DatabaseChannelManager) LoadWorldChannels(ctx context.Context) ([]Cha
|
|||||||
|
|
||||||
// SaveChannel persists a channel to database (world channels only)
|
// SaveChannel persists a channel to database (world channels only)
|
||||||
func (dcm *DatabaseChannelManager) SaveChannel(ctx context.Context, channel ChatChannelData) error {
|
func (dcm *DatabaseChannelManager) SaveChannel(ctx context.Context, channel ChatChannelData) error {
|
||||||
|
conn, err := dcm.pool.Take(context.Background())
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("failed to get connection: %w", err)
|
||||||
|
}
|
||||||
|
defer dcm.pool.Put(conn)
|
||||||
|
|
||||||
// Insert or update channel
|
// Insert or update channel
|
||||||
query := `
|
query := `
|
||||||
INSERT OR REPLACE INTO channels
|
INSERT OR REPLACE INTO channels
|
||||||
(name, password, level_restriction, classes, races)
|
(name, password, level_restriction, classes, races)
|
||||||
VALUES (?, ?, ?, ?, ?)`
|
VALUES (?, ?, ?, ?, ?)`
|
||||||
|
|
||||||
var password *string
|
var password any
|
||||||
if channel.Password != "" {
|
if channel.Password != "" {
|
||||||
password = &channel.Password
|
password = channel.Password
|
||||||
}
|
}
|
||||||
|
|
||||||
err := dcm.db.Exec(query,
|
err = sqlitex.Execute(conn, query, &sqlitex.ExecOptions{
|
||||||
channel.Name,
|
Args: []any{
|
||||||
password,
|
channel.Name,
|
||||||
channel.LevelRestriction,
|
password,
|
||||||
channel.ClassRestriction,
|
channel.LevelRestriction,
|
||||||
channel.RaceRestriction,
|
channel.ClassRestriction,
|
||||||
)
|
channel.RaceRestriction,
|
||||||
|
},
|
||||||
|
})
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return fmt.Errorf("failed to save channel %s: %w", channel.Name, err)
|
return fmt.Errorf("failed to save channel %s: %w", channel.Name, err)
|
||||||
}
|
}
|
||||||
@ -76,9 +93,17 @@ func (dcm *DatabaseChannelManager) SaveChannel(ctx context.Context, channel Chat
|
|||||||
|
|
||||||
// DeleteChannel removes a channel from database
|
// DeleteChannel removes a channel from database
|
||||||
func (dcm *DatabaseChannelManager) DeleteChannel(ctx context.Context, channelName string) error {
|
func (dcm *DatabaseChannelManager) DeleteChannel(ctx context.Context, channelName string) error {
|
||||||
|
conn, err := dcm.pool.Take(context.Background())
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("failed to get connection: %w", err)
|
||||||
|
}
|
||||||
|
defer dcm.pool.Put(conn)
|
||||||
|
|
||||||
query := "DELETE FROM channels WHERE name = ?"
|
query := "DELETE FROM channels WHERE name = ?"
|
||||||
|
|
||||||
err := dcm.db.Exec(query, channelName)
|
err = sqlitex.Execute(conn, query, &sqlitex.ExecOptions{
|
||||||
|
Args: []any{channelName},
|
||||||
|
})
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return fmt.Errorf("failed to delete channel %s: %w", channelName, err)
|
return fmt.Errorf("failed to delete channel %s: %w", channelName, err)
|
||||||
}
|
}
|
||||||
@ -88,6 +113,12 @@ func (dcm *DatabaseChannelManager) DeleteChannel(ctx context.Context, channelNam
|
|||||||
|
|
||||||
// EnsureChannelsTable creates the channels table if it doesn't exist
|
// EnsureChannelsTable creates the channels table if it doesn't exist
|
||||||
func (dcm *DatabaseChannelManager) EnsureChannelsTable(ctx context.Context) error {
|
func (dcm *DatabaseChannelManager) EnsureChannelsTable(ctx context.Context) error {
|
||||||
|
conn, err := dcm.pool.Take(context.Background())
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("failed to get connection: %w", err)
|
||||||
|
}
|
||||||
|
defer dcm.pool.Put(conn)
|
||||||
|
|
||||||
query := `
|
query := `
|
||||||
CREATE TABLE IF NOT EXISTS channels (
|
CREATE TABLE IF NOT EXISTS channels (
|
||||||
name TEXT PRIMARY KEY,
|
name TEXT PRIMARY KEY,
|
||||||
@ -99,7 +130,7 @@ func (dcm *DatabaseChannelManager) EnsureChannelsTable(ctx context.Context) erro
|
|||||||
updated_at DATETIME DEFAULT CURRENT_TIMESTAMP
|
updated_at DATETIME DEFAULT CURRENT_TIMESTAMP
|
||||||
)`
|
)`
|
||||||
|
|
||||||
err := dcm.db.Exec(query)
|
err = sqlitex.Execute(conn, query, nil)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return fmt.Errorf("failed to create channels table: %w", err)
|
return fmt.Errorf("failed to create channels table: %w", err)
|
||||||
}
|
}
|
||||||
@ -109,58 +140,81 @@ func (dcm *DatabaseChannelManager) EnsureChannelsTable(ctx context.Context) erro
|
|||||||
|
|
||||||
// GetChannelCount returns the total number of channels in the database
|
// GetChannelCount returns the total number of channels in the database
|
||||||
func (dcm *DatabaseChannelManager) GetChannelCount(ctx context.Context) (int, error) {
|
func (dcm *DatabaseChannelManager) GetChannelCount(ctx context.Context) (int, error) {
|
||||||
|
conn, err := dcm.pool.Take(context.Background())
|
||||||
|
if err != nil {
|
||||||
|
return 0, fmt.Errorf("failed to get connection: %w", err)
|
||||||
|
}
|
||||||
|
defer dcm.pool.Put(conn)
|
||||||
|
|
||||||
query := "SELECT COUNT(*) FROM channels"
|
query := "SELECT COUNT(*) FROM channels"
|
||||||
|
|
||||||
row, err := dcm.db.QueryRow(query)
|
var count int
|
||||||
|
err = sqlitex.Execute(conn, query, &sqlitex.ExecOptions{
|
||||||
|
ResultFunc: func(stmt *sqlite.Stmt) error {
|
||||||
|
count = int(stmt.ColumnInt64(0))
|
||||||
|
return nil
|
||||||
|
},
|
||||||
|
})
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return 0, fmt.Errorf("failed to query channel count: %w", err)
|
return 0, fmt.Errorf("failed to query channel count: %w", err)
|
||||||
}
|
}
|
||||||
if row == nil {
|
|
||||||
return 0, nil
|
|
||||||
}
|
|
||||||
defer row.Close()
|
|
||||||
|
|
||||||
count := row.Int(0)
|
|
||||||
|
|
||||||
return count, nil
|
return count, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
// GetChannelByName retrieves a specific channel by name
|
// GetChannelByName retrieves a specific channel by name
|
||||||
func (dcm *DatabaseChannelManager) GetChannelByName(ctx context.Context, channelName string) (*ChatChannelData, error) {
|
func (dcm *DatabaseChannelManager) GetChannelByName(ctx context.Context, channelName string) (*ChatChannelData, error) {
|
||||||
|
conn, err := dcm.pool.Take(context.Background())
|
||||||
|
if err != nil {
|
||||||
|
return nil, fmt.Errorf("failed to get connection: %w", err)
|
||||||
|
}
|
||||||
|
defer dcm.pool.Put(conn)
|
||||||
|
|
||||||
query := "SELECT `name`, `password`, `level_restriction`, `classes`, `races` FROM `channels` WHERE `name` = ?"
|
query := "SELECT `name`, `password`, `level_restriction`, `classes`, `races` FROM `channels` WHERE `name` = ?"
|
||||||
|
|
||||||
row, err := dcm.db.QueryRow(query, channelName)
|
var channel *ChatChannelData
|
||||||
|
err = sqlitex.Execute(conn, query, &sqlitex.ExecOptions{
|
||||||
|
Args: []any{channelName},
|
||||||
|
ResultFunc: func(stmt *sqlite.Stmt) error {
|
||||||
|
channel = &ChatChannelData{}
|
||||||
|
channel.Name = stmt.ColumnText(0)
|
||||||
|
if stmt.ColumnType(1) != sqlite.TypeNull {
|
||||||
|
channel.Password = stmt.ColumnText(1)
|
||||||
|
}
|
||||||
|
channel.LevelRestriction = int32(stmt.ColumnInt64(2))
|
||||||
|
channel.ClassRestriction = int32(stmt.ColumnInt64(3))
|
||||||
|
channel.RaceRestriction = int32(stmt.ColumnInt64(4))
|
||||||
|
return nil
|
||||||
|
},
|
||||||
|
})
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, fmt.Errorf("failed to query channel %s: %w", channelName, err)
|
return nil, fmt.Errorf("failed to query channel %s: %w", channelName, err)
|
||||||
}
|
}
|
||||||
if row == nil {
|
if channel == nil {
|
||||||
return nil, fmt.Errorf("channel %s not found", channelName)
|
return nil, fmt.Errorf("channel %s not found", channelName)
|
||||||
}
|
}
|
||||||
defer row.Close()
|
|
||||||
|
|
||||||
var channel ChatChannelData
|
return channel, nil
|
||||||
|
|
||||||
channel.Name = row.Text(0)
|
|
||||||
if !row.IsNull(1) {
|
|
||||||
channel.Password = row.Text(1)
|
|
||||||
}
|
|
||||||
channel.LevelRestriction = int32(row.Int64(2))
|
|
||||||
channel.ClassRestriction = int32(row.Int64(3))
|
|
||||||
channel.RaceRestriction = int32(row.Int64(4))
|
|
||||||
|
|
||||||
return &channel, nil
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// ListChannelNames returns a list of all channel names in the database
|
// ListChannelNames returns a list of all channel names in the database
|
||||||
func (dcm *DatabaseChannelManager) ListChannelNames(ctx context.Context) ([]string, error) {
|
func (dcm *DatabaseChannelManager) ListChannelNames(ctx context.Context) ([]string, error) {
|
||||||
|
conn, err := dcm.pool.Take(context.Background())
|
||||||
|
if err != nil {
|
||||||
|
return nil, fmt.Errorf("failed to get connection: %w", err)
|
||||||
|
}
|
||||||
|
defer dcm.pool.Put(conn)
|
||||||
|
|
||||||
query := "SELECT name FROM channels ORDER BY name"
|
query := "SELECT name FROM channels ORDER BY name"
|
||||||
|
|
||||||
var names []string
|
var names []string
|
||||||
|
|
||||||
err := dcm.db.Query(query, func(row *database.Row) error {
|
err = sqlitex.Execute(conn, query, &sqlitex.ExecOptions{
|
||||||
name := row.Text(0)
|
ResultFunc: func(stmt *sqlite.Stmt) error {
|
||||||
names = append(names, name)
|
name := stmt.ColumnText(0)
|
||||||
return nil
|
names = append(names, name)
|
||||||
|
return nil
|
||||||
|
},
|
||||||
})
|
})
|
||||||
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
@ -172,14 +226,22 @@ func (dcm *DatabaseChannelManager) ListChannelNames(ctx context.Context) ([]stri
|
|||||||
|
|
||||||
// UpdateChannelPassword updates just the password for a channel
|
// UpdateChannelPassword updates just the password for a channel
|
||||||
func (dcm *DatabaseChannelManager) UpdateChannelPassword(ctx context.Context, channelName, password string) error {
|
func (dcm *DatabaseChannelManager) UpdateChannelPassword(ctx context.Context, channelName, password string) error {
|
||||||
|
conn, err := dcm.pool.Take(context.Background())
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("failed to get connection: %w", err)
|
||||||
|
}
|
||||||
|
defer dcm.pool.Put(conn)
|
||||||
|
|
||||||
query := "UPDATE channels SET password = ?, updated_at = CURRENT_TIMESTAMP WHERE name = ?"
|
query := "UPDATE channels SET password = ?, updated_at = CURRENT_TIMESTAMP WHERE name = ?"
|
||||||
|
|
||||||
var passwordParam *string
|
var passwordParam any
|
||||||
if password != "" {
|
if password != "" {
|
||||||
passwordParam = &password
|
passwordParam = password
|
||||||
}
|
}
|
||||||
|
|
||||||
err := dcm.db.Exec(query, passwordParam, channelName)
|
err = sqlitex.Execute(conn, query, &sqlitex.ExecOptions{
|
||||||
|
Args: []any{passwordParam, channelName},
|
||||||
|
})
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return fmt.Errorf("failed to update password for channel %s: %w", channelName, err)
|
return fmt.Errorf("failed to update password for channel %s: %w", channelName, err)
|
||||||
}
|
}
|
||||||
@ -189,9 +251,17 @@ func (dcm *DatabaseChannelManager) UpdateChannelPassword(ctx context.Context, ch
|
|||||||
|
|
||||||
// UpdateChannelRestrictions updates the level, race, and class restrictions for a channel
|
// UpdateChannelRestrictions updates the level, race, and class restrictions for a channel
|
||||||
func (dcm *DatabaseChannelManager) UpdateChannelRestrictions(ctx context.Context, channelName string, levelRestriction, classRestriction, raceRestriction int32) error {
|
func (dcm *DatabaseChannelManager) UpdateChannelRestrictions(ctx context.Context, channelName string, levelRestriction, classRestriction, raceRestriction int32) error {
|
||||||
|
conn, err := dcm.pool.Take(context.Background())
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("failed to get connection: %w", err)
|
||||||
|
}
|
||||||
|
defer dcm.pool.Put(conn)
|
||||||
|
|
||||||
query := "UPDATE channels SET level_restriction = ?, classes = ?, races = ?, updated_at = CURRENT_TIMESTAMP WHERE name = ?"
|
query := "UPDATE channels SET level_restriction = ?, classes = ?, races = ?, updated_at = CURRENT_TIMESTAMP WHERE name = ?"
|
||||||
|
|
||||||
err := dcm.db.Exec(query, levelRestriction, classRestriction, raceRestriction, channelName)
|
err = sqlitex.Execute(conn, query, &sqlitex.ExecOptions{
|
||||||
|
Args: []any{levelRestriction, classRestriction, raceRestriction, channelName},
|
||||||
|
})
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return fmt.Errorf("failed to update restrictions for channel %s: %w", channelName, err)
|
return fmt.Errorf("failed to update restrictions for channel %s: %w", channelName, err)
|
||||||
}
|
}
|
||||||
|
Loading…
x
Reference in New Issue
Block a user