diff --git a/internal/chat/chat_test.go b/internal/chat/chat_test.go index 02f7f17..b2d6972 100644 --- a/internal/chat/chat_test.go +++ b/internal/chat/chat_test.go @@ -876,7 +876,7 @@ func TestMockLanguageProcessor(t *testing.T) { func BenchmarkChannelJoin(b *testing.B) { channel := NewChannel("benchmark") - for i := 0; b.Loop(); i++ { + for i := 0; i < b.N; i++ { channel.JoinChannel(int32(i)) } } @@ -888,7 +888,7 @@ func BenchmarkChannelIsInChannel(b *testing.B) { channel.JoinChannel(int32(i)) } - for i := 0; b.Loop(); i++ { + for i := 0; i < b.N; i++ { channel.IsInChannel(int32(i % 1000)) } } @@ -900,7 +900,7 @@ func BenchmarkChannelGetMembers(b *testing.B) { channel.JoinChannel(int32(i)) } - for b.Loop() { + for i := 0; i < b.N; i++ { channel.GetMembers() } } diff --git a/internal/chat/database.go b/internal/chat/database.go index 2b09c98..f9815e5 100644 --- a/internal/chat/database.go +++ b/internal/chat/database.go @@ -4,40 +4,49 @@ import ( "context" "fmt" - "eq2emu/internal/database" + "zombiezen.com/go/sqlite" + "zombiezen.com/go/sqlite/sqlitex" ) -// DatabaseChannelManager implements ChannelDatabase interface using the correct database wrapper +// DatabaseChannelManager implements ChannelDatabase interface using sqlitex.Pool type DatabaseChannelManager struct { - db *database.DB + pool *sqlitex.Pool } -// NewDatabaseChannelManager creates a new database channel manager using the correct wrapper -func NewDatabaseChannelManager(db *database.DB) *DatabaseChannelManager { +// NewDatabaseChannelManager creates a new database channel manager using sqlitex.Pool +func NewDatabaseChannelManager(pool *sqlitex.Pool) *DatabaseChannelManager { return &DatabaseChannelManager{ - db: db, + pool: pool, } } // LoadWorldChannels retrieves all persistent world channels from database func (dcm *DatabaseChannelManager) LoadWorldChannels(ctx context.Context) ([]ChatChannelData, error) { + conn, err := dcm.pool.Take(context.Background()) + if err != nil { + return nil, fmt.Errorf("failed to get connection: %w", err) + } + defer dcm.pool.Put(conn) + query := "SELECT `name`, `password`, `level_restriction`, `classes`, `races` FROM `channels`" var channels []ChatChannelData - err := dcm.db.Query(query, func(row *database.Row) error { - var channel ChatChannelData + err = sqlitex.Execute(conn, query, &sqlitex.ExecOptions{ + ResultFunc: func(stmt *sqlite.Stmt) error { + var channel ChatChannelData - channel.Name = row.Text(0) - if !row.IsNull(1) { - channel.Password = row.Text(1) - } - channel.LevelRestriction = int32(row.Int64(2)) - channel.ClassRestriction = int32(row.Int64(3)) - channel.RaceRestriction = int32(row.Int64(4)) + channel.Name = stmt.ColumnText(0) + if stmt.ColumnType(1) != sqlite.TypeNull { + channel.Password = stmt.ColumnText(1) + } + channel.LevelRestriction = int32(stmt.ColumnInt64(2)) + channel.ClassRestriction = int32(stmt.ColumnInt64(3)) + channel.RaceRestriction = int32(stmt.ColumnInt64(4)) - channels = append(channels, channel) - return nil + channels = append(channels, channel) + return nil + }, }) if err != nil { @@ -49,24 +58,32 @@ func (dcm *DatabaseChannelManager) LoadWorldChannels(ctx context.Context) ([]Cha // SaveChannel persists a channel to database (world channels only) func (dcm *DatabaseChannelManager) SaveChannel(ctx context.Context, channel ChatChannelData) error { + conn, err := dcm.pool.Take(context.Background()) + if err != nil { + return fmt.Errorf("failed to get connection: %w", err) + } + defer dcm.pool.Put(conn) + // Insert or update channel query := ` INSERT OR REPLACE INTO channels (name, password, level_restriction, classes, races) VALUES (?, ?, ?, ?, ?)` - var password *string + var password any if channel.Password != "" { - password = &channel.Password + password = channel.Password } - err := dcm.db.Exec(query, - channel.Name, - password, - channel.LevelRestriction, - channel.ClassRestriction, - channel.RaceRestriction, - ) + err = sqlitex.Execute(conn, query, &sqlitex.ExecOptions{ + Args: []any{ + channel.Name, + password, + channel.LevelRestriction, + channel.ClassRestriction, + channel.RaceRestriction, + }, + }) if err != nil { return fmt.Errorf("failed to save channel %s: %w", channel.Name, err) } @@ -76,9 +93,17 @@ func (dcm *DatabaseChannelManager) SaveChannel(ctx context.Context, channel Chat // DeleteChannel removes a channel from database func (dcm *DatabaseChannelManager) DeleteChannel(ctx context.Context, channelName string) error { + conn, err := dcm.pool.Take(context.Background()) + if err != nil { + return fmt.Errorf("failed to get connection: %w", err) + } + defer dcm.pool.Put(conn) + query := "DELETE FROM channels WHERE name = ?" - err := dcm.db.Exec(query, channelName) + err = sqlitex.Execute(conn, query, &sqlitex.ExecOptions{ + Args: []any{channelName}, + }) if err != nil { return fmt.Errorf("failed to delete channel %s: %w", channelName, err) } @@ -88,6 +113,12 @@ func (dcm *DatabaseChannelManager) DeleteChannel(ctx context.Context, channelNam // EnsureChannelsTable creates the channels table if it doesn't exist func (dcm *DatabaseChannelManager) EnsureChannelsTable(ctx context.Context) error { + conn, err := dcm.pool.Take(context.Background()) + if err != nil { + return fmt.Errorf("failed to get connection: %w", err) + } + defer dcm.pool.Put(conn) + query := ` CREATE TABLE IF NOT EXISTS channels ( name TEXT PRIMARY KEY, @@ -99,7 +130,7 @@ func (dcm *DatabaseChannelManager) EnsureChannelsTable(ctx context.Context) erro updated_at DATETIME DEFAULT CURRENT_TIMESTAMP )` - err := dcm.db.Exec(query) + err = sqlitex.Execute(conn, query, nil) if err != nil { return fmt.Errorf("failed to create channels table: %w", err) } @@ -109,58 +140,81 @@ func (dcm *DatabaseChannelManager) EnsureChannelsTable(ctx context.Context) erro // GetChannelCount returns the total number of channels in the database func (dcm *DatabaseChannelManager) GetChannelCount(ctx context.Context) (int, error) { + conn, err := dcm.pool.Take(context.Background()) + if err != nil { + return 0, fmt.Errorf("failed to get connection: %w", err) + } + defer dcm.pool.Put(conn) + query := "SELECT COUNT(*) FROM channels" - row, err := dcm.db.QueryRow(query) + var count int + err = sqlitex.Execute(conn, query, &sqlitex.ExecOptions{ + ResultFunc: func(stmt *sqlite.Stmt) error { + count = int(stmt.ColumnInt64(0)) + return nil + }, + }) if err != nil { return 0, fmt.Errorf("failed to query channel count: %w", err) } - if row == nil { - return 0, nil - } - defer row.Close() - - count := row.Int(0) return count, nil } // GetChannelByName retrieves a specific channel by name func (dcm *DatabaseChannelManager) GetChannelByName(ctx context.Context, channelName string) (*ChatChannelData, error) { + conn, err := dcm.pool.Take(context.Background()) + if err != nil { + return nil, fmt.Errorf("failed to get connection: %w", err) + } + defer dcm.pool.Put(conn) + query := "SELECT `name`, `password`, `level_restriction`, `classes`, `races` FROM `channels` WHERE `name` = ?" - row, err := dcm.db.QueryRow(query, channelName) + var channel *ChatChannelData + err = sqlitex.Execute(conn, query, &sqlitex.ExecOptions{ + Args: []any{channelName}, + ResultFunc: func(stmt *sqlite.Stmt) error { + channel = &ChatChannelData{} + channel.Name = stmt.ColumnText(0) + if stmt.ColumnType(1) != sqlite.TypeNull { + channel.Password = stmt.ColumnText(1) + } + channel.LevelRestriction = int32(stmt.ColumnInt64(2)) + channel.ClassRestriction = int32(stmt.ColumnInt64(3)) + channel.RaceRestriction = int32(stmt.ColumnInt64(4)) + return nil + }, + }) if err != nil { return nil, fmt.Errorf("failed to query channel %s: %w", channelName, err) } - if row == nil { + if channel == nil { return nil, fmt.Errorf("channel %s not found", channelName) } - defer row.Close() - var channel ChatChannelData - - channel.Name = row.Text(0) - if !row.IsNull(1) { - channel.Password = row.Text(1) - } - channel.LevelRestriction = int32(row.Int64(2)) - channel.ClassRestriction = int32(row.Int64(3)) - channel.RaceRestriction = int32(row.Int64(4)) - - return &channel, nil + return channel, nil } // ListChannelNames returns a list of all channel names in the database func (dcm *DatabaseChannelManager) ListChannelNames(ctx context.Context) ([]string, error) { + conn, err := dcm.pool.Take(context.Background()) + if err != nil { + return nil, fmt.Errorf("failed to get connection: %w", err) + } + defer dcm.pool.Put(conn) + query := "SELECT name FROM channels ORDER BY name" var names []string - err := dcm.db.Query(query, func(row *database.Row) error { - name := row.Text(0) - names = append(names, name) - return nil + err = sqlitex.Execute(conn, query, &sqlitex.ExecOptions{ + ResultFunc: func(stmt *sqlite.Stmt) error { + name := stmt.ColumnText(0) + names = append(names, name) + return nil + }, }) if err != nil { @@ -172,14 +226,22 @@ func (dcm *DatabaseChannelManager) ListChannelNames(ctx context.Context) ([]stri // UpdateChannelPassword updates just the password for a channel func (dcm *DatabaseChannelManager) UpdateChannelPassword(ctx context.Context, channelName, password string) error { + conn, err := dcm.pool.Take(context.Background()) + if err != nil { + return fmt.Errorf("failed to get connection: %w", err) + } + defer dcm.pool.Put(conn) + query := "UPDATE channels SET password = ?, updated_at = CURRENT_TIMESTAMP WHERE name = ?" - var passwordParam *string + var passwordParam any if password != "" { - passwordParam = &password + passwordParam = password } - err := dcm.db.Exec(query, passwordParam, channelName) + err = sqlitex.Execute(conn, query, &sqlitex.ExecOptions{ + Args: []any{passwordParam, channelName}, + }) if err != nil { return fmt.Errorf("failed to update password for channel %s: %w", channelName, err) } @@ -189,9 +251,17 @@ func (dcm *DatabaseChannelManager) UpdateChannelPassword(ctx context.Context, ch // UpdateChannelRestrictions updates the level, race, and class restrictions for a channel func (dcm *DatabaseChannelManager) UpdateChannelRestrictions(ctx context.Context, channelName string, levelRestriction, classRestriction, raceRestriction int32) error { + conn, err := dcm.pool.Take(context.Background()) + if err != nil { + return fmt.Errorf("failed to get connection: %w", err) + } + defer dcm.pool.Put(conn) + query := "UPDATE channels SET level_restriction = ?, classes = ?, races = ?, updated_at = CURRENT_TIMESTAMP WHERE name = ?" - err := dcm.db.Exec(query, levelRestriction, classRestriction, raceRestriction, channelName) + err = sqlitex.Execute(conn, query, &sqlitex.ExecOptions{ + Args: []any{levelRestriction, classRestriction, raceRestriction, channelName}, + }) if err != nil { return fmt.Errorf("failed to update restrictions for channel %s: %w", channelName, err) }