package udp

import (
	"fmt"
	"net"
	"sync"
	"testing"
	"time"

	"eq2emu/internal/common/opcodes"
)

// TestHandler implements PacketHandler for testing. It records every packet it
// receives so tests can inspect them, and echoes a fixed reply to the sender.
type TestHandler struct {
	receivedPackets []*ApplicationPacket
	mutex           sync.Mutex
}

// HandlePacket stores packet for later verification and echoes back an
// OpLoginReplyMsg response on conn.
//
// The mutex only guards receivedPackets, so it is released before logging and
// before calling conn.SendPacket; nothing outside this handler runs while the
// lock is held.
func (h *TestHandler) HandlePacket(conn *Connection, packet *ApplicationPacket) {
	h.mutex.Lock()
	h.receivedPackets = append(h.receivedPackets, packet)
	h.mutex.Unlock()

	fmt.Printf("Received packet - Opcode: 0x%04X, Data length: %d\n", packet.Opcode, len(packet.Data))

	// Echo back a response for interactive testing.
	response := &ApplicationPacket{
		Opcode: opcodes.OpLoginReplyMsg,
		Data:   []byte("Hello from server"),
	}
	conn.SendPacket(response)
}

// GetReceivedPackets returns a copy of all received packets. The returned
// slice is never nil and may be used without holding the handler's lock.
func (h *TestHandler) GetReceivedPackets() []*ApplicationPacket {
	h.mutex.Lock()
	defer h.mutex.Unlock()

	packets := make([]*ApplicationPacket, len(h.receivedPackets))
	copy(packets, h.receivedPackets)
	return packets
}

// Clear removes all received packets.
func (h *TestHandler) Clear() {
	h.mutex.Lock()
	defer h.mutex.Unlock()
	h.receivedPackets = nil
}

// testHandler is a minimal handler function that only logs the opcode.
func testHandler(conn *Connection, packet *ApplicationPacket) {
	fmt.Printf("Test handler received packet opcode: 0x%04X\n", packet.Opcode)
}

// TestServer tests basic server creation and startup.
func TestServer(t *testing.T) {
	// ":0" lets the OS choose a free port, so the test cannot collide with
	// another process (or a parallel test run) already bound to a fixed port.
	server, err := NewServer(":0", testHandler)
	if err != nil {
		t.Fatalf("Failed to create server: %v", err)
	}
	defer server.Stop()

	// Start runs in its own goroutine and its result is handed back over a
	// buffered channel. Reporting from inside the goroutine is unsafe: if Start
	// returns after this test has finished, calling t.Errorf there panics.
	startErr := make(chan error, 1)
	go func() { startErr <- server.Start() }()

	// Let it run briefly.
	time.Sleep(100 * time.Millisecond)

	// An error already returned by Start means the server never came up.
	select {
	case err := <-startErr:
		if err != nil {
			t.Fatalf("Server error: %v", err)
		}
	default:
	}

	// A freshly started server should not have any connections yet.
	if n := server.GetConnectionCount(); n != 0 {
		t.Errorf("Expected 0 connections, got %d", n)
	}
}

// TestServerConfig tests server configuration options.
func TestServerConfig(t *testing.T) {
	config := DefaultConfig()
	config.MaxConnections = 10
	config.Timeout = 30 * time.Second
	config.BufferSize = 4096

	// ":0" avoids depending on a fixed, possibly occupied port.
	server, err := NewServer(":0", testHandler, config)
	if err != nil {
		t.Fatalf("Failed to create server with config: %v", err)
	}
	defer server.Stop()

	stats := server.GetStats()
	if stats.MaxConnections != 10 {
		t.Errorf("Expected max connections 10, got %d", stats.MaxConnections)
	}
	if stats.Timeout != 30*time.Second {
		t.Errorf("Expected timeout 30s, got %v", stats.Timeout)
	}
}

// TestPacketParsing tests protocol packet parsing.
func TestPacketParsing(t *testing.T) {
	// Test 1-byte opcode with CRC.
	payload1 := []byte{0x01, 0x48, 0x65, 0x6C, 0x6C, 0x6F}
	data1 := AppendCRC(payload1)

	packet1, err := ParseProtocolPacket(data1)
	if err != nil {
		t.Fatalf("Failed to parse 1-byte opcode: %v", err)
	}
	if packet1.Opcode != 0x01 {
		t.Errorf("Expected opcode 0x01, got 0x%02X", packet1.Opcode)
	}

	// Test 2-byte opcode with CRC.
	payload2 := []byte{0x00, 0xFF, 0x48, 0x65, 0x6C, 0x6C, 0x6F}
	data2 := AppendCRC(payload2)

	packet2, err := ParseProtocolPacket(data2)
	if err != nil {
		t.Fatalf("Failed to parse 2-byte opcode: %v", err)
	}
	if packet2.Opcode != 0xFF {
		t.Errorf("Expected opcode 0xFF, got 0x%02X", packet2.Opcode)
	}

	// Test session packet (no CRC required).
	sessionData := []byte{opcodes.OpSessionRequest, 0x48, 0x65, 0x6C, 0x6C, 0x6F}
	sessionPacket, err := ParseProtocolPacket(sessionData)
	if err != nil {
		t.Fatalf("Failed to parse session packet: %v", err)
	}
	if sessionPacket.Opcode != opcodes.OpSessionRequest {
		t.Errorf("Expected session request opcode, got 0x%02X", sessionPacket.Opcode)
	}
}

// TestCRC tests CRC calculation and validation.
func TestCRC(t *testing.T) {
	data := []byte("Hello, World!")

	// Append CRC and validate.
	withCRC := AppendCRC(data)
	if !ValidateCRC(withCRC) {
		t.Error("CRC validation failed for correct data")
	}

	// Test with corrupted data (work on a copy so withCRC stays intact).
	corrupted := make([]byte, len(withCRC))
	copy(corrupted, withCRC)
	corrupted[0] ^= 0xFF // Flip bits
	if ValidateCRC(corrupted) {
		t.Error("CRC validation passed for corrupted data")
	}

	// Test ValidateAndStrip.
	stripped, valid := ValidateAndStrip(withCRC)
	if !valid {
		t.Error("ValidateAndStrip failed for valid data")
	}
	if string(stripped) != string(data) {
		t.Error("Stripped data doesn't match original")
	}
}

// TestCompression tests data compression and decompression.
func TestCompression(t *testing.T) {
	testData := []byte("This is a test string that should compress well because it has repetitive patterns.")

	compressed, err := Compress(testData)
	if err != nil {
		t.Fatalf("Compression failed: %v", err)
	}

	decompressed, err := Decompress(compressed)
	if err != nil {
		t.Fatalf("Decompression failed: %v", err)
	}
	if string(decompressed) != string(testData) {
		t.Error("Decompressed data doesn't match original")
	}

	// Empty input is expected to encode as a lone UncompressedMarker byte.
	empty, err := Compress([]byte{})
	if err != nil {
		t.Fatalf("Empty compression failed: %v", err)
	}
	if len(empty) != 1 || empty[0] != UncompressedMarker {
		t.Error("Empty data compression incorrect")
	}
}

// TestCrypto tests RC4 encryption and decryption.
func TestCrypto(t *testing.T) {
	crypto := NewCrypto()
	key := []byte{0x01, 0x02, 0x03, 0x04}

	if err := crypto.SetKey(key); err != nil {
		t.Fatalf("SetKey failed: %v", err)
	}
	if !crypto.IsEncrypted() {
		t.Error("Crypto should be encrypted after SetKey")
	}

	testData := []byte("Hello, World!")
	encrypted := crypto.Encrypt(testData)
	decrypted := crypto.Decrypt(encrypted)
	if string(decrypted) != string(testData) {
		t.Error("Decrypted data doesn't match original")
	}
}

// TestRetransmitQueue tests packet retransmission logic.
func TestRetransmitQueue(t *testing.T) {
	config := DefaultConfig()
	rq := NewRetransmitQueue(config.RetransmitBase, config.RetransmitMax, config.RetransmitAttempts)

	packet := &ProtocolPacket{
		Opcode: opcodes.OpPacket,
		Data:   []byte("test"),
	}

	// Add packet.
	rq.Add(packet, 1)
	if rq.Size() != 1 {
		t.Errorf("Expected size 1, got %d", rq.Size())
	}

	// Acknowledge packet.
	if !rq.Acknowledge(1) {
		t.Error("Acknowledge should return true for existing packet")
	}
	if rq.Size() != 0 {
		t.Errorf("Expected size 0 after ack, got %d", rq.Size())
	}

	// Test non-existent acknowledgment.
	if rq.Acknowledge(999) {
		t.Error("Acknowledge should return false for non-existent packet")
	}
}

// TestFragmentation tests packet fragmentation and reassembly.
func TestFragmentation(t *testing.T) {
	fm := NewFragmentManager(100) // Small max length to force fragmentation

	// Create large test data.
	largeData := make([]byte, 300)
	for i := range largeData {
		largeData[i] = byte(i % 256)
	}

	// Fragment the data.
	fragments := fm.FragmentPacket(largeData, 1)
	if fragments == nil {
		t.Fatal("Expected fragmentation for large data")
	}
	if len(fragments) < 2 {
		t.Error("Expected multiple fragments")
	}

	// Reassemble fragments.
	var reassembled []byte
	complete := false
	for _, frag := range fragments {
		data, isComplete, err := fm.ProcessFragment(frag)
		if err != nil {
			t.Fatalf("Fragment processing failed: %v", err)
		}
		if isComplete {
			reassembled = data
			complete = true
			break
		}
	}

	if !complete {
		t.Fatal("Fragmentation did not complete")
	}
	// Fatal on a length mismatch: the byte-by-byte comparison below indexes
	// largeData with positions from reassembled and would panic if it were longer.
	if len(reassembled) != len(largeData) {
		t.Fatalf("Reassembled length %d != original %d", len(reassembled), len(largeData))
	}

	for i, b := range reassembled {
		if b != largeData[i] {
			t.Errorf("Reassembled data differs at position %d", i)
			break
		}
	}
}

// TestPacketCombining tests packet combination functionality.
func TestPacketCombining(t *testing.T) {
	combiner := NewPacketCombiner(256)

	// Add small packets - use session opcodes that don't require CRC.
	packet1 := &ProtocolPacket{Opcode: opcodes.OpSessionRequest, Data: []byte("test1")}
	packet2 := &ProtocolPacket{Opcode: opcodes.OpSessionRequest, Data: []byte("test2")}

	combiner.Add(packet1)
	combiner.Add(packet2)

	if len(combiner.PendingPackets) != 2 {
		t.Errorf("Expected 2 pending packets, got %d", len(combiner.PendingPackets))
	}

	// Flush combined. Fatal on a count mismatch because combined[0] is
	// dereferenced below and would panic on an empty result.
	combined := combiner.Flush()
	if len(combined) != 1 {
		t.Fatalf("Expected 1 combined packet, got %d", len(combined))
	}
	if combined[0].Opcode != opcodes.OpCombined {
		t.Error("Combined packet should have OpCombined opcode")
	}

	// Parse combined packet.
	parsed, err := ParseCombinedPacket(combined[0].Data)
	if err != nil {
		t.Fatalf("Failed to parse combined packet: %v", err)
	}
	if len(parsed) != 2 {
		t.Errorf("Expected 2 parsed packets, got %d", len(parsed))
	}
}

// TestConnection tests basic connection functionality.
func TestConnection(t *testing.T) {
	config := DefaultConfig()

	addr, err := net.ResolveUDPAddr("udp", "127.0.0.1:0")
	if err != nil {
		t.Fatalf("Failed to resolve UDP address: %v", err)
	}
	conn, err := net.ListenUDP("udp", addr)
	if err != nil {
		t.Fatalf("Failed to listen on UDP: %v", err)
	}
	defer conn.Close()

	connection := NewConnection(addr, conn, testHandler, config)

	if connection.GetState() != StateClosed {
		t.Error("New connection should be in closed state")
	}

	// Test session ID.
	if connection.GetSessionID() != 0 {
		t.Error("New connection should have session ID 0")
	}

	// Test timeout with a very short timeout config.
	shortConfig := DefaultConfig()
	shortConfig.Timeout = time.Nanosecond
	shortConnection := NewConnection(addr, conn, testHandler, shortConfig)

	// Wait a bit to ensure timeout.
	time.Sleep(time.Millisecond)

	if !shortConnection.IsTimedOut() {
		t.Error("Connection should be timed out with very short timeout")
	}
}

// TestDefaultConfig tests the default configuration.
func TestDefaultConfig(t *testing.T) {
	config := DefaultConfig()

	if config.MaxConnections != 1000 {
		t.Errorf("Expected MaxConnections 1000, got %d", config.MaxConnections)
	}
	if config.Timeout != 45*time.Second {
		t.Errorf("Expected Timeout 45s, got %v", config.Timeout)
	}
	if config.MaxPacketSize != 512 {
		t.Errorf("Expected MaxPacketSize 512, got %d", config.MaxPacketSize)
	}
	if !config.EnableCompression {
		t.Error("Expected compression to be enabled by default")
	}
	if !config.EnableEncryption {
		t.Error("Expected encryption to be enabled by default")
	}
}

// BenchmarkCRC benchmarks CRC calculation performance.
func BenchmarkCRC(b *testing.B) {
	data := make([]byte, 1024)
	for i := range data {
		data[i] = byte(i)
	}

	for b.Loop() {
		CalculateCRC32(data)
	}
}

// BenchmarkCompression benchmarks a compress/decompress round trip.
func BenchmarkCompression(b *testing.B) {
	data := make([]byte, 1024)
	for i := range data {
		data[i] = byte(i % 256)
	}

	for b.Loop() {
		compressed, err := Compress(data)
		if err != nil {
			b.Fatalf("Compression failed: %v", err)
		}
		if _, err := Decompress(compressed); err != nil {
			b.Fatalf("Decompression failed: %v", err)
		}
	}
}

// BenchmarkEncryption benchmarks an encrypt/decrypt round trip.
func BenchmarkEncryption(b *testing.B) {
	crypto := NewCrypto()
	if err := crypto.SetKey([]byte{1, 2, 3, 4}); err != nil {
		b.Fatalf("SetKey failed: %v", err)
	}

	data := make([]byte, 1024)
	for i := range data {
		data[i] = byte(i)
	}

	for b.Loop() {
		encrypted := crypto.Encrypt(data)
		crypto.Decrypt(encrypted)
	}
}

// TestIntegration performs a basic integration test.
func TestIntegration(t *testing.T) {
	server, err := NewServer(":0", testHandler) // Use any available port
	if err != nil {
		t.Fatalf("Failed to create server: %v", err)
	}
	defer server.Stop()

	// Start the server, collecting its result on a channel rather than
	// discarding it, so a startup failure is reported instead of surfacing
	// only as a confusing "not running" assertion.
	startErr := make(chan error, 1)
	go func() { startErr <- server.Start() }()

	// Wait for server to start.
	time.Sleep(50 * time.Millisecond)

	select {
	case err := <-startErr:
		if err != nil {
			t.Fatalf("Server error: %v", err)
		}
	default:
	}

	// Basic integration test - server should be running with 0 connections.
	stats := server.GetStats()
	if !stats.Running {
		t.Error("Server should be running")
	}
	if stats.ConnectionCount != 0 {
		t.Errorf("Expected 0 connections, got %d", stats.ConnectionCount)
	}
}

// TestDirectFieldAccess tests that exported fields can be read and written directly.
func TestDirectFieldAccess(t *testing.T) {
	// Test PacketCombiner direct access.
	combiner := NewPacketCombiner(256)
	combiner.MaxSize = 512 // Direct field modification

	if combiner.MaxSize != 512 {
		t.Errorf("Expected MaxSize 512, got %d", combiner.MaxSize)
	}

	// Test adding packets and accessing them directly. Fatal on a count
	// mismatch because PendingPackets[0] is indexed below.
	packet := &ProtocolPacket{Opcode: opcodes.OpKeepAlive, Data: []byte("test")}
	combiner.Add(packet)

	if len(combiner.PendingPackets) != 1 {
		t.Fatalf("Expected 1 pending packet, got %d", len(combiner.PendingPackets))
	}

	// Direct access to pending packets.
	firstPacket := combiner.PendingPackets[0]
	if firstPacket.Opcode != opcodes.OpKeepAlive {
		t.Errorf("Expected OpKeepAlive, got 0x%02X", firstPacket.Opcode)
	}
}