423 lines
10 KiB
Go
423 lines
10 KiB
Go
package udp
|
|
|
|
import (
|
|
"eq2emu/internal/opcodes"
|
|
"fmt"
|
|
"net"
|
|
"sync"
|
|
"testing"
|
|
"time"
|
|
)
|
|
|
|
// TestHandler implements PacketHandler for testing
|
|
type TestHandler struct {
|
|
receivedPackets []*ApplicationPacket
|
|
mutex sync.Mutex
|
|
}
|
|
|
|
// HandlePacket processes received packets and stores them for verification
|
|
func (h *TestHandler) HandlePacket(conn *Connection, packet *ApplicationPacket) {
|
|
h.mutex.Lock()
|
|
defer h.mutex.Unlock()
|
|
|
|
h.receivedPackets = append(h.receivedPackets, packet)
|
|
fmt.Printf("Received packet - Opcode: 0x%04X, Data length: %d\n",
|
|
packet.Opcode, len(packet.Data))
|
|
|
|
// Echo back a response for interactive testing
|
|
response := &ApplicationPacket{
|
|
Opcode: opcodes.OpLoginReplyMsg,
|
|
Data: []byte("Hello from server"),
|
|
}
|
|
conn.SendPacket(response)
|
|
}
|
|
|
|
// GetReceivedPackets returns a copy of all received packets
|
|
func (h *TestHandler) GetReceivedPackets() []*ApplicationPacket {
|
|
h.mutex.Lock()
|
|
defer h.mutex.Unlock()
|
|
|
|
packets := make([]*ApplicationPacket, len(h.receivedPackets))
|
|
copy(packets, h.receivedPackets)
|
|
return packets
|
|
}
|
|
|
|
// Clear removes all received packets
|
|
func (h *TestHandler) Clear() {
|
|
h.mutex.Lock()
|
|
defer h.mutex.Unlock()
|
|
h.receivedPackets = nil
|
|
}
|
|
|
|
// TestServer tests basic server creation and startup
|
|
func TestServer(t *testing.T) {
|
|
handler := &TestHandler{}
|
|
server, err := NewServer(":9999", handler)
|
|
if err != nil {
|
|
t.Fatalf("Failed to create server: %v", err)
|
|
}
|
|
|
|
// Start server in goroutine
|
|
go func() {
|
|
if err := server.Start(); err != nil {
|
|
t.Errorf("Server error: %v", err)
|
|
}
|
|
}()
|
|
|
|
// Let it run briefly
|
|
time.Sleep(100 * time.Millisecond)
|
|
|
|
// Verify server is running
|
|
if server.GetConnectionCount() != 0 {
|
|
t.Errorf("Expected 0 connections, got %d", server.GetConnectionCount())
|
|
}
|
|
|
|
server.Stop()
|
|
}
|
|
|
|
// TestServerConfig tests server configuration options
|
|
func TestServerConfig(t *testing.T) {
|
|
handler := &TestHandler{}
|
|
config := ServerConfig{
|
|
MaxConnections: 10,
|
|
Timeout: 30 * time.Second,
|
|
BufferSize: 4096,
|
|
}
|
|
|
|
server, err := NewServerWithConfig(":9998", handler, config)
|
|
if err != nil {
|
|
t.Fatalf("Failed to create server with config: %v", err)
|
|
}
|
|
|
|
stats := server.GetStats()
|
|
if stats.MaxConnections != 10 {
|
|
t.Errorf("Expected max connections 10, got %d", stats.MaxConnections)
|
|
}
|
|
if stats.Timeout != 30*time.Second {
|
|
t.Errorf("Expected timeout 30s, got %v", stats.Timeout)
|
|
}
|
|
|
|
server.Stop()
|
|
}
|
|
|
|
// TestPacketParsing tests protocol packet parsing
|
|
func TestPacketParsing(t *testing.T) {
|
|
// Test 1-byte opcode with CRC
|
|
payload1 := []byte{0x01, 0x48, 0x65, 0x6C, 0x6C, 0x6F}
|
|
data1 := AppendCRC(payload1)
|
|
packet1, err := ParseProtocolPacket(data1)
|
|
if err != nil {
|
|
t.Fatalf("Failed to parse 1-byte opcode: %v", err)
|
|
}
|
|
if packet1.Opcode != 0x01 {
|
|
t.Errorf("Expected opcode 0x01, got 0x%02X", packet1.Opcode)
|
|
}
|
|
|
|
// Test 2-byte opcode with CRC
|
|
payload2 := []byte{0x00, 0xFF, 0x48, 0x65, 0x6C, 0x6C, 0x6F}
|
|
data2 := AppendCRC(payload2)
|
|
packet2, err := ParseProtocolPacket(data2)
|
|
if err != nil {
|
|
t.Fatalf("Failed to parse 2-byte opcode: %v", err)
|
|
}
|
|
if packet2.Opcode != 0xFF {
|
|
t.Errorf("Expected opcode 0xFF, got 0x%02X", packet2.Opcode)
|
|
}
|
|
|
|
// Test session packet (no CRC required)
|
|
sessionData := []byte{opcodes.OpSessionRequest, 0x48, 0x65, 0x6C, 0x6C, 0x6F}
|
|
sessionPacket, err := ParseProtocolPacket(sessionData)
|
|
if err != nil {
|
|
t.Fatalf("Failed to parse session packet: %v", err)
|
|
}
|
|
if sessionPacket.Opcode != opcodes.OpSessionRequest {
|
|
t.Errorf("Expected session request opcode, got 0x%02X", sessionPacket.Opcode)
|
|
}
|
|
}
|
|
|
|
// TestCRC tests CRC calculation and validation
|
|
func TestCRC(t *testing.T) {
|
|
data := []byte("Hello, World!")
|
|
|
|
// Append CRC and validate
|
|
withCRC := AppendCRC(data)
|
|
if !ValidateCRC(withCRC) {
|
|
t.Error("CRC validation failed for correct data")
|
|
}
|
|
|
|
// Test with corrupted data
|
|
corrupted := make([]byte, len(withCRC))
|
|
copy(corrupted, withCRC)
|
|
corrupted[0] ^= 0xFF // Flip bits
|
|
|
|
if ValidateCRC(corrupted) {
|
|
t.Error("CRC validation passed for corrupted data")
|
|
}
|
|
|
|
// Test ValidateAndStrip
|
|
stripped, valid := ValidateAndStrip(withCRC)
|
|
if !valid {
|
|
t.Error("ValidateAndStrip failed for valid data")
|
|
}
|
|
if string(stripped) != string(data) {
|
|
t.Error("Stripped data doesn't match original")
|
|
}
|
|
}
|
|
|
|
// TestCompression tests data compression and decompression
|
|
func TestCompression(t *testing.T) {
|
|
testData := []byte("This is a test string that should compress well because it has repetitive patterns.")
|
|
|
|
compressed, err := Compress(testData)
|
|
if err != nil {
|
|
t.Fatalf("Compression failed: %v", err)
|
|
}
|
|
|
|
decompressed, err := Decompress(compressed)
|
|
if err != nil {
|
|
t.Fatalf("Decompression failed: %v", err)
|
|
}
|
|
|
|
if string(decompressed) != string(testData) {
|
|
t.Error("Decompressed data doesn't match original")
|
|
}
|
|
|
|
// Test empty data
|
|
empty, err := Compress([]byte{})
|
|
if err != nil {
|
|
t.Fatalf("Empty compression failed: %v", err)
|
|
}
|
|
if len(empty) != 1 || empty[0] != UncompressedMarker {
|
|
t.Error("Empty data compression incorrect")
|
|
}
|
|
}
|
|
|
|
// TestCrypto tests RC4 encryption and decryption
|
|
func TestCrypto(t *testing.T) {
|
|
crypto := NewCrypto()
|
|
key := []byte{0x01, 0x02, 0x03, 0x04}
|
|
|
|
err := crypto.SetKey(key)
|
|
if err != nil {
|
|
t.Fatalf("SetKey failed: %v", err)
|
|
}
|
|
|
|
if !crypto.IsEncrypted() {
|
|
t.Error("Crypto should be encrypted after SetKey")
|
|
}
|
|
|
|
testData := []byte("Hello, World!")
|
|
encrypted := crypto.Encrypt(testData)
|
|
decrypted := crypto.Decrypt(encrypted)
|
|
|
|
if string(decrypted) != string(testData) {
|
|
t.Error("Decrypted data doesn't match original")
|
|
}
|
|
}
|
|
|
|
// TestRetransmitQueue tests packet retransmission logic
|
|
func TestRetransmitQueue(t *testing.T) {
|
|
rq := NewRetransmitQueue()
|
|
|
|
packet := &ProtocolPacket{
|
|
Opcode: opcodes.OpPacket,
|
|
Data: []byte("test"),
|
|
}
|
|
|
|
// Add packet
|
|
rq.Add(packet, 1)
|
|
if rq.Size() != 1 {
|
|
t.Errorf("Expected size 1, got %d", rq.Size())
|
|
}
|
|
|
|
// Acknowledge packet
|
|
acked := rq.Acknowledge(1)
|
|
if !acked {
|
|
t.Error("Acknowledge should return true for existing packet")
|
|
}
|
|
if rq.Size() != 0 {
|
|
t.Errorf("Expected size 0 after ack, got %d", rq.Size())
|
|
}
|
|
|
|
// Test non-existent acknowledgment
|
|
acked = rq.Acknowledge(999)
|
|
if acked {
|
|
t.Error("Acknowledge should return false for non-existent packet")
|
|
}
|
|
}
|
|
|
|
// TestFragmentation tests packet fragmentation and reassembly
|
|
func TestFragmentation(t *testing.T) {
|
|
fm := NewFragmentManager(100) // Small max length to force fragmentation
|
|
|
|
// Create large test data
|
|
largeData := make([]byte, 300)
|
|
for i := range largeData {
|
|
largeData[i] = byte(i % 256)
|
|
}
|
|
|
|
// Fragment the data
|
|
fragments := fm.FragmentPacket(largeData, 1)
|
|
if fragments == nil {
|
|
t.Fatal("Expected fragmentation for large data")
|
|
}
|
|
if len(fragments) < 2 {
|
|
t.Error("Expected multiple fragments")
|
|
}
|
|
|
|
// Reassemble fragments
|
|
var reassembled []byte
|
|
complete := false
|
|
for _, frag := range fragments {
|
|
data, isComplete, err := fm.ProcessFragment(frag)
|
|
if err != nil {
|
|
t.Fatalf("Fragment processing failed: %v", err)
|
|
}
|
|
if isComplete {
|
|
reassembled = data
|
|
complete = true
|
|
break
|
|
}
|
|
}
|
|
|
|
if !complete {
|
|
t.Error("Fragmentation did not complete")
|
|
}
|
|
if len(reassembled) != len(largeData) {
|
|
t.Errorf("Reassembled length %d != original %d", len(reassembled), len(largeData))
|
|
}
|
|
for i, b := range reassembled {
|
|
if b != largeData[i] {
|
|
t.Errorf("Reassembled data differs at position %d", i)
|
|
break
|
|
}
|
|
}
|
|
}
|
|
|
|
// TestPacketCombining tests packet combination functionality
|
|
func TestPacketCombining(t *testing.T) {
|
|
combiner := NewPacketCombiner()
|
|
|
|
// Add small packets - use session opcodes that don't require CRC
|
|
packet1 := &ProtocolPacket{Opcode: opcodes.OpSessionRequest, Data: []byte("test1")}
|
|
packet2 := &ProtocolPacket{Opcode: opcodes.OpSessionRequest, Data: []byte("test2")}
|
|
|
|
combiner.AddPacket(packet1)
|
|
combiner.AddPacket(packet2)
|
|
|
|
if combiner.GetPendingCount() != 2 {
|
|
t.Errorf("Expected 2 pending packets, got %d", combiner.GetPendingCount())
|
|
}
|
|
|
|
// Flush combined
|
|
combined := combiner.FlushCombined()
|
|
if len(combined) != 1 {
|
|
t.Errorf("Expected 1 combined packet, got %d", len(combined))
|
|
}
|
|
if combined[0].Opcode != opcodes.OpCombined {
|
|
t.Error("Combined packet should have OpCombined opcode")
|
|
}
|
|
|
|
// Parse combined packet
|
|
parsed, err := ParseCombinedPacket(combined[0].Data)
|
|
if err != nil {
|
|
t.Fatalf("Failed to parse combined packet: %v", err)
|
|
}
|
|
if len(parsed) != 2 {
|
|
t.Errorf("Expected 2 parsed packets, got %d", len(parsed))
|
|
}
|
|
}
|
|
|
|
// TestConnection tests basic connection functionality
|
|
func TestConnection(t *testing.T) {
|
|
handler := &TestHandler{}
|
|
addr, _ := net.ResolveUDPAddr("udp", "127.0.0.1:0")
|
|
conn, _ := net.ListenUDP("udp", addr)
|
|
defer conn.Close()
|
|
|
|
connection := NewConnection(addr, conn, handler)
|
|
|
|
if connection.GetState() != StateClosed {
|
|
t.Error("New connection should be in closed state")
|
|
}
|
|
|
|
// Test session ID
|
|
if connection.GetSessionID() != 0 {
|
|
t.Error("New connection should have session ID 0")
|
|
}
|
|
|
|
// Test timeout
|
|
if !connection.IsTimedOut(time.Nanosecond) {
|
|
t.Error("New connection should be timed out with very short timeout")
|
|
}
|
|
}
|
|
|
|
// BenchmarkCRC benchmarks CRC calculation performance
|
|
func BenchmarkCRC(b *testing.B) {
|
|
data := make([]byte, 1024)
|
|
for i := range data {
|
|
data[i] = byte(i)
|
|
}
|
|
|
|
b.ResetTimer()
|
|
for range b.N {
|
|
CalculateCRC32(data)
|
|
}
|
|
}
|
|
|
|
// BenchmarkCompression benchmarks compression performance
|
|
func BenchmarkCompression(b *testing.B) {
|
|
data := make([]byte, 1024)
|
|
for i := range data {
|
|
data[i] = byte(i % 256)
|
|
}
|
|
|
|
b.ResetTimer()
|
|
for range b.N {
|
|
compressed, _ := Compress(data)
|
|
Decompress(compressed)
|
|
}
|
|
}
|
|
|
|
// BenchmarkEncryption benchmarks encryption performance
|
|
func BenchmarkEncryption(b *testing.B) {
|
|
crypto := NewCrypto()
|
|
crypto.SetKey([]byte{1, 2, 3, 4})
|
|
|
|
data := make([]byte, 1024)
|
|
for i := range data {
|
|
data[i] = byte(i)
|
|
}
|
|
|
|
b.ResetTimer()
|
|
for range b.N {
|
|
encrypted := crypto.Encrypt(data)
|
|
crypto.Decrypt(encrypted)
|
|
}
|
|
}
|
|
|
|
// TestIntegration performs a basic integration test
|
|
func TestIntegration(t *testing.T) {
|
|
handler := &TestHandler{}
|
|
server, err := NewServer(":0", handler) // Use any available port
|
|
if err != nil {
|
|
t.Fatalf("Failed to create server: %v", err)
|
|
}
|
|
|
|
// Start server
|
|
go server.Start()
|
|
defer server.Stop()
|
|
|
|
// Wait for server to start
|
|
time.Sleep(50 * time.Millisecond)
|
|
|
|
// Basic integration test - server should be running with 0 connections
|
|
stats := server.GetStats()
|
|
if !stats.Running {
|
|
t.Error("Server should be running")
|
|
}
|
|
if stats.ConnectionCount != 0 {
|
|
t.Errorf("Expected 0 connections, got %d", stats.ConnectionCount)
|
|
}
|
|
}
|