eq2go/internal/udp/udp_test.go
2025-07-21 23:18:39 -05:00

423 lines
10 KiB
Go

package udp
import (
"eq2emu/internal/opcodes"
"fmt"
"net"
"sync"
"testing"
"time"
)
// TestHandler implements PacketHandler for testing. It records every
// ApplicationPacket passed to HandlePacket so tests can inspect them later
// via GetReceivedPackets.
type TestHandler struct {
	// receivedPackets holds recorded packets in the order HandlePacket saw
	// them; only access it while holding mutex.
	receivedPackets []*ApplicationPacket
	// mutex guards receivedPackets.
	mutex sync.Mutex
}
// HandlePacket records packet for later verification, prints its opcode and
// payload length, and replies on conn with an OpLoginReplyMsg packet so
// interactive clients receive an echo.
//
// The mutex only protects the append. Printing and sending the reply happen
// outside the critical section, so slow network I/O (or a send path that
// re-enters the handler) cannot block GetReceivedPackets/Clear or deadlock.
func (h *TestHandler) HandlePacket(conn *Connection, packet *ApplicationPacket) {
	h.mutex.Lock()
	h.receivedPackets = append(h.receivedPackets, packet)
	h.mutex.Unlock()

	fmt.Printf("Received packet - Opcode: 0x%04X, Data length: %d\n",
		packet.Opcode, len(packet.Data))

	// Echo back a response for interactive testing.
	response := &ApplicationPacket{
		Opcode: opcodes.OpLoginReplyMsg,
		Data:   []byte("Hello from server"),
	}
	conn.SendPacket(response)
}
// GetReceivedPackets returns a snapshot of every packet recorded so far.
// The slice is freshly allocated (never nil), so callers may hold or modify
// it without racing HandlePacket; the *ApplicationPacket values themselves
// are shared, not deep-copied.
func (h *TestHandler) GetReceivedPackets() []*ApplicationPacket {
	h.mutex.Lock()
	defer h.mutex.Unlock()

	snapshot := make([]*ApplicationPacket, 0, len(h.receivedPackets))
	return append(snapshot, h.receivedPackets...)
}
// Clear discards every recorded packet, resetting the handler to its
// initial (nil slice) state.
func (h *TestHandler) Clear() {
	h.mutex.Lock()
	h.receivedPackets = nil
	h.mutex.Unlock()
}
// TestServer creates a server on an OS-assigned port, starts it, and checks
// that it reports zero connections before any client has connected.
func TestServer(t *testing.T) {
	handler := &TestHandler{}
	// ":0" lets the OS pick a free port, so the test cannot fail just because
	// a fixed port is already in use.
	server, err := NewServer(":0", handler)
	if err != nil {
		t.Fatalf("Failed to create server: %v", err)
	}
	defer server.Stop()

	// Start's result comes back over a buffered channel rather than being
	// reported from inside the goroutine: calling t.Errorf after the test has
	// returned (e.g. when Stop makes Start return an error) panics.
	startErr := make(chan error, 1)
	go func() {
		startErr <- server.Start()
	}()

	// Let it run briefly.
	time.Sleep(100 * time.Millisecond)

	// Report an early Start failure (such as a bind error) without blocking
	// if Start is still running.
	select {
	case err := <-startErr:
		if err != nil {
			t.Fatalf("Server error: %v", err)
		}
	default:
	}

	// Verify server is running with no clients yet.
	if n := server.GetConnectionCount(); n != 0 {
		t.Errorf("Expected 0 connections, got %d", n)
	}
}
// TestServerConfig checks that the MaxConnections and Timeout values passed
// via ServerConfig are reflected in the server's stats.
func TestServerConfig(t *testing.T) {
	handler := &TestHandler{}
	config := ServerConfig{
		MaxConnections: 10,
		Timeout:        30 * time.Second,
		BufferSize:     4096,
	}
	// ":0" avoids spurious failures when a fixed port is already bound.
	server, err := NewServerWithConfig(":0", handler, config)
	if err != nil {
		t.Fatalf("Failed to create server with config: %v", err)
	}
	defer server.Stop()

	stats := server.GetStats()
	if stats.MaxConnections != 10 {
		t.Errorf("Expected max connections 10, got %d", stats.MaxConnections)
	}
	if stats.Timeout != 30*time.Second {
		t.Errorf("Expected timeout 30s, got %v", stats.Timeout)
	}
}
// TestPacketParsing tests protocol packet parsing
func TestPacketParsing(t *testing.T) {
// Test 1-byte opcode with CRC
payload1 := []byte{0x01, 0x48, 0x65, 0x6C, 0x6C, 0x6F}
data1 := AppendCRC(payload1)
packet1, err := ParseProtocolPacket(data1)
if err != nil {
t.Fatalf("Failed to parse 1-byte opcode: %v", err)
}
if packet1.Opcode != 0x01 {
t.Errorf("Expected opcode 0x01, got 0x%02X", packet1.Opcode)
}
// Test 2-byte opcode with CRC
payload2 := []byte{0x00, 0xFF, 0x48, 0x65, 0x6C, 0x6C, 0x6F}
data2 := AppendCRC(payload2)
packet2, err := ParseProtocolPacket(data2)
if err != nil {
t.Fatalf("Failed to parse 2-byte opcode: %v", err)
}
if packet2.Opcode != 0xFF {
t.Errorf("Expected opcode 0xFF, got 0x%02X", packet2.Opcode)
}
// Test session packet (no CRC required)
sessionData := []byte{opcodes.OpSessionRequest, 0x48, 0x65, 0x6C, 0x6C, 0x6F}
sessionPacket, err := ParseProtocolPacket(sessionData)
if err != nil {
t.Fatalf("Failed to parse session packet: %v", err)
}
if sessionPacket.Opcode != opcodes.OpSessionRequest {
t.Errorf("Expected session request opcode, got 0x%02X", sessionPacket.Opcode)
}
}
// TestCRC checks that AppendCRC output validates and that a copy with a
// corrupted first byte does not. It also checks that ValidateAndStrip
// accepts the valid buffer and returns the original payload.
func TestCRC(t *testing.T) {
	plain := []byte("Hello, World!")

	signed := AppendCRC(plain)
	if ok := ValidateCRC(signed); !ok {
		t.Error("CRC validation failed for correct data")
	}

	// Invert the first byte of a private copy so the CRC no longer matches.
	tampered := append([]byte(nil), signed...)
	tampered[0] ^= 0xFF
	if ValidateCRC(tampered) {
		t.Error("CRC validation passed for corrupted data")
	}

	body, ok := ValidateAndStrip(signed)
	if !ok {
		t.Error("ValidateAndStrip failed for valid data")
	}
	if string(body) != string(plain) {
		t.Error("Stripped data doesn't match original")
	}
}
// TestCompression round-trips a repetitive sentence through Compress and
// Decompress. It then checks that compressing empty input yields exactly
// one byte equal to UncompressedMarker.
func TestCompression(t *testing.T) {
	original := []byte("This is a test string that should compress well because it has repetitive patterns.")

	packed, err := Compress(original)
	if err != nil {
		t.Fatalf("Compression failed: %v", err)
	}
	unpacked, err := Decompress(packed)
	if err != nil {
		t.Fatalf("Decompression failed: %v", err)
	}
	if string(unpacked) != string(original) {
		t.Error("Decompressed data doesn't match original")
	}

	// Empty input is expected to produce just the marker byte.
	empty, err := Compress([]byte{})
	if err != nil {
		t.Fatalf("Empty compression failed: %v", err)
	}
	if !(len(empty) == 1 && empty[0] == UncompressedMarker) {
		t.Error("Empty data compression incorrect")
	}
}
// TestCrypto sets a key on a fresh Crypto and checks the following:
//   - IsEncrypted reports true.
//   - Encrypt changes the data.
//   - Decrypt restores the original text.
func TestCrypto(t *testing.T) {
	crypto := NewCrypto()
	key := []byte{0x01, 0x02, 0x03, 0x04}
	if err := crypto.SetKey(key); err != nil {
		t.Fatalf("SetKey failed: %v", err)
	}
	if !crypto.IsEncrypted() {
		t.Error("Crypto should be encrypted after SetKey")
	}

	// The expected text lives in an immutable string constant. If Encrypt or
	// Decrypt transformed their argument in place, comparing against the
	// input slice would compare the buffer with itself and always pass.
	const plaintext = "Hello, World!"
	testData := []byte(plaintext)

	encrypted := crypto.Encrypt(testData)
	// Check the ciphertext before Decrypt runs, in case Decrypt reuses the
	// same buffer.
	if string(encrypted) == plaintext {
		t.Error("Encrypted data should differ from plaintext")
	}

	decrypted := crypto.Decrypt(encrypted)
	if string(decrypted) != plaintext {
		t.Error("Decrypted data doesn't match original")
	}
}
// TestRetransmitQueue adds one packet under ID 1 and checks the following:
//   - Size tracks the add.
//   - Acknowledging ID 1 removes the packet.
//   - Acknowledging an unknown ID (999) reports false.
func TestRetransmitQueue(t *testing.T) {
	queue := NewRetransmitQueue()
	pkt := &ProtocolPacket{Opcode: opcodes.OpPacket, Data: []byte("test")}

	queue.Add(pkt, 1)
	if size := queue.Size(); size != 1 {
		t.Errorf("Expected size 1, got %d", size)
	}

	if !queue.Acknowledge(1) {
		t.Error("Acknowledge should return true for existing packet")
	}
	if size := queue.Size(); size != 0 {
		t.Errorf("Expected size 0 after ack, got %d", size)
	}

	// An ID that was never queued must not be acknowledged.
	if queue.Acknowledge(999) {
		t.Error("Acknowledge should return false for non-existent packet")
	}
}
// TestFragmentation splits a 300-byte payload using a FragmentManager with
// a 100-byte maximum. It feeds the fragments back through ProcessFragment
// and checks that the reassembled bytes equal the original.
func TestFragmentation(t *testing.T) {
	fm := NewFragmentManager(100) // Small max length to force fragmentation

	// Create large test data: a repeating 0..255 byte ramp.
	largeData := make([]byte, 300)
	for i := range largeData {
		largeData[i] = byte(i % 256)
	}

	fragments := fm.FragmentPacket(largeData, 1)
	if fragments == nil {
		t.Fatal("Expected fragmentation for large data")
	}
	if len(fragments) < 2 {
		t.Error("Expected multiple fragments")
	}

	// Feed fragments until the manager reports a complete packet.
	var reassembled []byte
	complete := false
	for _, frag := range fragments {
		data, isComplete, err := fm.ProcessFragment(frag)
		if err != nil {
			t.Fatalf("Fragment processing failed: %v", err)
		}
		if isComplete {
			reassembled = data
			complete = true
			break
		}
	}

	// Without a completed packet the comparisons below are meaningless.
	if !complete {
		t.Fatal("Fragmentation did not complete")
	}
	// Stop on a length mismatch. The loop below indexes largeData with
	// reassembled's indices and would panic if reassembled were longer.
	if len(reassembled) != len(largeData) {
		t.Fatalf("Reassembled length %d != original %d", len(reassembled), len(largeData))
	}
	for i, b := range reassembled {
		if b != largeData[i] {
			t.Errorf("Reassembled data differs at position %d", i)
			break
		}
	}
}
// TestPacketCombining queues two small session-request packets and checks
// the following:
//   - FlushCombined emits a single OpCombined packet.
//   - ParseCombinedPacket recovers both sub-packets from that packet.
func TestPacketCombining(t *testing.T) {
	combiner := NewPacketCombiner()

	// Add small packets - use session opcodes that don't require CRC.
	packet1 := &ProtocolPacket{Opcode: opcodes.OpSessionRequest, Data: []byte("test1")}
	packet2 := &ProtocolPacket{Opcode: opcodes.OpSessionRequest, Data: []byte("test2")}
	combiner.AddPacket(packet1)
	combiner.AddPacket(packet2)
	if n := combiner.GetPendingCount(); n != 2 {
		t.Errorf("Expected 2 pending packets, got %d", n)
	}

	combined := combiner.FlushCombined()
	// Fatal, not Error: combined[0] below would panic on an empty result.
	if len(combined) != 1 {
		t.Fatalf("Expected 1 combined packet, got %d", len(combined))
	}
	if combined[0].Opcode != opcodes.OpCombined {
		t.Error("Combined packet should have OpCombined opcode")
	}

	parsed, err := ParseCombinedPacket(combined[0].Data)
	if err != nil {
		t.Fatalf("Failed to parse combined packet: %v", err)
	}
	if len(parsed) != 2 {
		t.Errorf("Expected 2 parsed packets, got %d", len(parsed))
	}
}
// TestConnection builds a Connection over a loopback UDP socket and checks
// its initial state. A new connection should be:
//   - in StateClosed,
//   - holding session ID 0,
//   - reported as timed out under a 1ns timeout.
func TestConnection(t *testing.T) {
	handler := &TestHandler{}

	addr, err := net.ResolveUDPAddr("udp", "127.0.0.1:0")
	if err != nil {
		t.Fatalf("Failed to resolve UDP address: %v", err)
	}
	conn, err := net.ListenUDP("udp", addr)
	if err != nil {
		t.Fatalf("Failed to listen on UDP: %v", err)
	}
	defer conn.Close()

	connection := NewConnection(addr, conn, handler)
	if connection.GetState() != StateClosed {
		t.Error("New connection should be in closed state")
	}

	// Test session ID.
	if connection.GetSessionID() != 0 {
		t.Error("New connection should have session ID 0")
	}

	// Let measurable time pass so the 1ns timeout is exceeded even where
	// the clock is coarse; checking immediately after construction is flaky.
	time.Sleep(time.Millisecond)
	if !connection.IsTimedOut(time.Nanosecond) {
		t.Error("New connection should be timed out with very short timeout")
	}
}
// BenchmarkCRC benchmarks CRC calculation performance
func BenchmarkCRC(b *testing.B) {
data := make([]byte, 1024)
for i := range data {
data[i] = byte(i)
}
b.ResetTimer()
for range b.N {
CalculateCRC32(data)
}
}
// BenchmarkCompression benchmarks compression performance
func BenchmarkCompression(b *testing.B) {
data := make([]byte, 1024)
for i := range data {
data[i] = byte(i % 256)
}
b.ResetTimer()
for range b.N {
compressed, _ := Compress(data)
Decompress(compressed)
}
}
// BenchmarkEncryption measures an Encrypt+Decrypt round trip of a 1 KiB
// buffer. A SetKey failure aborts the benchmark rather than silently timing
// an unkeyed Crypto.
func BenchmarkEncryption(b *testing.B) {
	crypto := NewCrypto()
	if err := crypto.SetKey([]byte{1, 2, 3, 4}); err != nil {
		b.Fatalf("SetKey failed: %v", err)
	}

	data := make([]byte, 1024)
	for i := range data {
		data[i] = byte(i)
	}

	b.ResetTimer()
	for range b.N {
		encrypted := crypto.Encrypt(data)
		crypto.Decrypt(encrypted)
	}
}
// TestIntegration starts a server on an OS-assigned port and checks that
// its stats report it as running with zero connections.
func TestIntegration(t *testing.T) {
	handler := &TestHandler{}
	server, err := NewServer(":0", handler) // Use any available port
	if err != nil {
		t.Fatalf("Failed to create server: %v", err)
	}
	defer server.Stop()

	// Capture Start's result instead of discarding it. A bind failure would
	// otherwise only surface as the less helpful "Server should be running".
	startErr := make(chan error, 1)
	go func() {
		startErr <- server.Start()
	}()

	// Wait for server to start.
	time.Sleep(50 * time.Millisecond)

	// Non-blocking: Start may legitimately still be running.
	select {
	case err := <-startErr:
		if err != nil {
			t.Fatalf("Server error: %v", err)
		}
	default:
	}

	// Basic integration test - server should be running with 0 connections.
	stats := server.GetStats()
	if !stats.Running {
		t.Error("Server should be running")
	}
	if stats.ConnectionCount != 0 {
		t.Errorf("Expected 0 connections, got %d", stats.ConnectionCount)
	}
}