diff --git a/internal/udp/combine.go b/internal/udp/combine.go deleted file mode 100644 index ea8c7c2..0000000 --- a/internal/udp/combine.go +++ /dev/null @@ -1,248 +0,0 @@ -package udp - -import ( - "bytes" - "eq2emu/internal/opcodes" - "errors" -) - -// PacketCombiner groups small packets together to reduce UDP overhead -type PacketCombiner struct { - pendingPackets []*ProtocolPacket // Packets awaiting combination - maxSize int // Maximum combined packet size - timeout int // Combination timeout in milliseconds -} - -// NewPacketCombiner creates a combiner with default settings -func NewPacketCombiner() *PacketCombiner { - return &PacketCombiner{ - maxSize: 256, // Default size threshold for combining - timeout: 10, // Default timeout in ms - } -} - -// NewPacketCombinerWithConfig creates a combiner with custom settings -func NewPacketCombinerWithConfig(maxSize, timeout int) *PacketCombiner { - return &PacketCombiner{ - maxSize: maxSize, - timeout: timeout, - } -} - -// AddPacket queues a packet for potential combining -func (pc *PacketCombiner) AddPacket(packet *ProtocolPacket) { - pc.pendingPackets = append(pc.pendingPackets, packet) -} - -// FlushCombined returns combined packets and clears the queue -func (pc *PacketCombiner) FlushCombined() []*ProtocolPacket { - if len(pc.pendingPackets) == 0 { - return nil - } - - if len(pc.pendingPackets) == 1 { - // Single packet - no combining needed - packet := pc.pendingPackets[0] - pc.pendingPackets = nil - return []*ProtocolPacket{packet} - } - - // Combine multiple packets - combined := pc.combineProtocolPackets(pc.pendingPackets) - pc.pendingPackets = nil - return []*ProtocolPacket{combined} -} - -// combineProtocolPackets merges multiple packets into a single combined packet -func (pc *PacketCombiner) combineProtocolPackets(packets []*ProtocolPacket) *ProtocolPacket { - var buf bytes.Buffer - - for _, packet := range packets { - serialized := packet.Serialize() - pc.writeSizeHeader(&buf, len(serialized)) - buf.Write(serialized) - } - - return &ProtocolPacket{ - Opcode: opcodes.OpCombined, - Data: buf.Bytes(), - } -} - -// writeSizeHeader writes packet size using variable-length encoding -func (pc *PacketCombiner) writeSizeHeader(buf *bytes.Buffer, size int) { - if size >= 255 { - // Large packet - use 3-byte header [0xFF][low][high] - buf.WriteByte(0xFF) - buf.WriteByte(byte(size)) - buf.WriteByte(byte(size >> 8)) - } else { - // Small packet - use 1-byte header - buf.WriteByte(byte(size)) - } -} - -// ParseCombinedPacket splits combined packet into individual packets -func ParseCombinedPacket(data []byte) ([]*ProtocolPacket, error) { - var packets []*ProtocolPacket - offset := 0 - - for offset < len(data) { - size, headerSize, err := readSizeHeader(data, offset) - if err != nil { - break - } - - offset += headerSize - - if offset+size > len(data) { - break // Incomplete packet - } - - // Parse individual packet - packetData := data[offset : offset+size] - if packet, err := ParseProtocolPacket(packetData); err == nil { - packets = append(packets, packet) - } - - offset += size - } - - return packets, nil -} - -// readSizeHeader reads variable-length size header -func readSizeHeader(data []byte, offset int) (size, headerSize int, err error) { - if offset >= len(data) { - return 0, 0, errors.New("insufficient data for size header") - } - - if data[offset] == 0xFF { - // 3-byte size header - if offset+2 >= len(data) { - return 0, 0, errors.New("insufficient data for 3-byte size header") - } - size = int(data[offset+1]) | (int(data[offset+2]) << 8) - headerSize = 3 - } else { - // 1-byte size header - size = int(data[offset]) - headerSize = 1 - } - - return size, headerSize, nil -} - -// ShouldCombine determines if packets should be combined based on total size -func (pc *PacketCombiner) ShouldCombine() bool { - if len(pc.pendingPackets) < 2 { - return false - } - - totalSize := 0 - for _, packet := range pc.pendingPackets { - serialized := packet.Serialize() - totalSize += len(serialized) - - // Add size header overhead - if len(serialized) >= 255 { - totalSize += 3 - } else { - totalSize += 1 - } - } - - return totalSize <= pc.maxSize -} - -// HasPendingPackets returns true if packets are waiting to be combined -func (pc *PacketCombiner) HasPendingPackets() bool { - return len(pc.pendingPackets) > 0 -} - -// GetPendingCount returns the number of packets waiting to be combined -func (pc *PacketCombiner) GetPendingCount() int { - return len(pc.pendingPackets) -} - -// Clear removes all pending packets without combining -func (pc *PacketCombiner) Clear() { - pc.pendingPackets = nil -} - -// SetMaxSize updates the maximum combined packet size -func (pc *PacketCombiner) SetMaxSize(maxSize int) { - pc.maxSize = maxSize -} - -// SetTimeout updates the combination timeout -func (pc *PacketCombiner) SetTimeout(timeout int) { - pc.timeout = timeout -} - -// GetStats returns packet combination statistics -func (pc *PacketCombiner) GetStats() CombinerStats { - return CombinerStats{ - PendingCount: len(pc.pendingPackets), - MaxSize: pc.maxSize, - Timeout: pc.timeout, - } -} - -// CombinerStats contains packet combiner statistics -type CombinerStats struct { - PendingCount int // Number of packets waiting to be combined - MaxSize int // Maximum combined packet size - Timeout int // Combination timeout in milliseconds -} - -// EstimateCombinedSize calculates the size if current packets were combined -func (pc *PacketCombiner) EstimateCombinedSize() int { - if len(pc.pendingPackets) == 0 { - return 0 - } - - totalSize := 0 - for _, packet := range pc.pendingPackets { - serialized := packet.Serialize() - packetSize := len(serialized) - totalSize += packetSize - - // Add size header overhead - if packetSize >= 255 { - totalSize += 3 - } else { - totalSize += 1 - } - } - - return totalSize -} - -// ValidateCombinedPacket checks if combined packet data is well-formed -func ValidateCombinedPacket(data []byte) error { - offset := 0 - count := 0 - - for offset < len(data) { - size, headerSize, err := readSizeHeader(data, offset) - if err != nil { - return err - } - - offset += headerSize - - if offset+size > len(data) { - return errors.New("packet extends beyond data boundary") - } - - offset += size - count++ - - if count > 100 { // Sanity check - return errors.New("too many packets in combined packet") - } - } - - return nil -} diff --git a/internal/udp/connection.go b/internal/udp/connection.go index 892b14f..972489b 100644 --- a/internal/udp/connection.go +++ b/internal/udp/connection.go @@ -4,6 +4,7 @@ import ( "crypto/rand" "encoding/binary" "eq2emu/internal/opcodes" + "errors" "net" "sync" "time" @@ -19,11 +20,51 @@ const ( StateWaitClose // Waiting for close confirmation ) -const ( - DefaultWindowSize = 2048 // Default sliding window size for flow control - MaxPacketSize = 512 // Maximum packet size before fragmentation +// Common connection errors +var ( + ErrSessionClosed = errors.New("session closed") ) +// Config holds all UDP server and connection configuration +type Config struct { + // Server settings + MaxConnections int // Maximum concurrent connections + Timeout time.Duration // Connection timeout duration + BufferSize int // UDP socket buffer size + + // Protocol settings + MaxPacketSize uint32 // Maximum packet size before fragmentation + WindowSize uint16 // Sliding window size for flow control + RetransmitBase time.Duration // Base retransmission timeout + RetransmitMax time.Duration // Maximum retransmission timeout + RetransmitAttempts int // Maximum retransmission attempts + CombineThreshold int // Packet combining size threshold + + // Features + EnableCompression bool // Enable zlib compression + EnableEncryption bool // Enable RC4 encryption +} + +// DefaultConfig returns sensible defaults for EQ2EMu protocol +func DefaultConfig() Config { + return Config{ + MaxConnections: 1000, + Timeout: 45 * time.Second, + BufferSize: 8192, + MaxPacketSize: 512, + WindowSize: 2048, + RetransmitBase: 500 * time.Millisecond, + RetransmitMax: 5 * time.Second, + RetransmitAttempts: 5, + CombineThreshold: 256, + EnableCompression: true, + EnableEncryption: true, + } +} + +// PacketHandler processes application-level packets +type PacketHandler func(*Connection, *ApplicationPacket) + // Connection manages a single client connection over UDP with reliability features type Connection struct { // Network details @@ -43,7 +84,6 @@ type Connection struct { // Sequence tracking for reliable delivery nextInSeq uint16 // Next expected incoming sequence number nextOutSeq uint16 // Next outgoing sequence number - windowSize uint16 // Flow control window size // Protocol components retransmitQueue *RetransmitQueue // Handles packet retransmission @@ -53,25 +93,33 @@ type Connection struct { crypto *Crypto // Handles encryption/decryption // Connection timing - lastPacketTime time.Time // Last received packet timestamp - lastAckTime time.Time // Last acknowledgment timestamp + lastActivity time.Time // Last activity timestamp + + // Configuration (embedded from server) + config Config } -// NewConnection creates a new connection instance with default settings -func NewConnection(addr *net.UDPAddr, conn *net.UDPConn, handler PacketHandler) *Connection { +// NewConnection creates a new connection instance with server configuration +func NewConnection(addr *net.UDPAddr, conn *net.UDPConn, handler PacketHandler, config Config) *Connection { return &Connection{ - addr: addr, - conn: conn, - handler: handler, - state: StateClosed, - maxLength: MaxPacketSize, - windowSize: DefaultWindowSize, - lastPacketTime: time.Now(), - crypto: NewCrypto(), - retransmitQueue: NewRetransmitQueue(), - fragmentMgr: NewFragmentManager(MaxPacketSize), - combiner: NewPacketCombiner(), - outOfOrderMap: make(map[uint16]*ProtocolPacket), + addr: addr, + conn: conn, + handler: handler, + state: StateClosed, + maxLength: config.MaxPacketSize, + lastActivity: time.Now(), + config: config, + + // Initialize components with config values + retransmitQueue: NewRetransmitQueue( + config.RetransmitBase, + config.RetransmitMax, + config.RetransmitAttempts, + ), + fragmentMgr: NewFragmentManager(config.MaxPacketSize), + combiner: NewPacketCombiner(config.CombineThreshold), + crypto: NewCrypto(), + outOfOrderMap: make(map[uint16]*ProtocolPacket), } } @@ -80,7 +128,7 @@ func (c *Connection) ProcessPacket(data []byte) { c.mutex.Lock() defer c.mutex.Unlock() - c.lastPacketTime = time.Now() + c.lastActivity = time.Now() packet, err := ParseProtocolPacket(data) if err != nil { @@ -187,7 +235,7 @@ func (c *Connection) processInOrderPacket(seq uint16, payload []byte) { // Process application data if appPacket, err := c.processApplicationData(payload); err == nil { - c.handler.HandlePacket(c, appPacket) + c.handler(c, appPacket) } // Check for queued out-of-order packets that can now be processed @@ -211,7 +259,7 @@ func (c *Connection) processQueuedPackets() { c.sendAck(seq) if appPacket, err := c.processApplicationData(payload); err == nil { - c.handler.HandlePacket(c, appPacket) + c.handler(c, appPacket) } } } @@ -221,7 +269,7 @@ func (c *Connection) processQueuedPackets() { func (c *Connection) handleFragment(packet *ProtocolPacket) { if data, complete, err := c.fragmentMgr.ProcessFragment(packet); err == nil && complete { if appPacket, err := c.processApplicationData(data); err == nil { - c.handler.HandlePacket(c, appPacket) + c.handler(c, appPacket) } } } @@ -243,7 +291,6 @@ func (c *Connection) handleAck(packet *ProtocolPacket) { seq := binary.BigEndian.Uint16(packet.Data[0:2]) c.retransmitQueue.Acknowledge(seq) - c.lastAckTime = time.Now() } // handleOutOfOrderAck processes out-of-order acknowledgments @@ -290,14 +337,14 @@ func (c *Connection) SendPacket(packet *ApplicationPacket) { // processOutboundData applies compression and encryption to outgoing data func (c *Connection) processOutboundData(data []byte) []byte { // Compress large packets if compression is enabled - if c.compressed && len(data) > 128 { + if c.config.EnableCompression && c.compressed && len(data) > 128 { if compressed, err := Compress(data); err == nil { data = compressed } } // Encrypt data if encryption is enabled - if c.crypto.IsEncrypted() { + if c.config.EnableEncryption && c.crypto.IsEncrypted() { data = c.crypto.Encrypt(data) } @@ -307,12 +354,12 @@ func (c *Connection) processOutboundData(data []byte) []byte { // processApplicationData decrypts and decompresses incoming application data func (c *Connection) processApplicationData(data []byte) (*ApplicationPacket, error) { // Decrypt if encryption is enabled - if c.crypto.IsEncrypted() { + if c.config.EnableEncryption && c.crypto.IsEncrypted() { data = c.crypto.Decrypt(data) } // Decompress if compression is enabled - if c.compressed && len(data) > 0 { + if c.config.EnableCompression && c.compressed && len(data) > 0 { var err error data, err = Decompress(data) if err != nil { @@ -454,6 +501,36 @@ func (c *Connection) StartRetransmitLoop() { }() } +// Stats returns comprehensive connection statistics +type Stats struct { + // Connection info + State ConnectionState + SessionID uint32 + LastActivity time.Time + + // Queue stats + PendingRetransmits int + PendingFragments int + PendingCombined int + OutOfOrderCount int +} + +// GetStats returns unified statistics +func (c *Connection) GetStats() Stats { + c.mutex.RLock() + defer c.mutex.RUnlock() + + return Stats{ + State: c.state, + SessionID: c.sessionID, + LastActivity: c.lastActivity, + PendingRetransmits: c.retransmitQueue.Size(), + PendingFragments: len(c.fragmentMgr.fragments), + PendingCombined: len(c.combiner.PendingPackets), + OutOfOrderCount: len(c.outOfOrderMap), + } +} + // GetState returns the current connection state (thread-safe) func (c *Connection) GetState() ConnectionState { c.mutex.RLock() @@ -469,8 +546,8 @@ func (c *Connection) GetSessionID() uint32 { } // IsTimedOut checks if connection has timed out -func (c *Connection) IsTimedOut(timeout time.Duration) bool { +func (c *Connection) IsTimedOut() bool { c.mutex.RLock() defer c.mutex.RUnlock() - return time.Since(c.lastPacketTime) > timeout + return time.Since(c.lastActivity) > c.config.Timeout } diff --git a/internal/udp/crc.go b/internal/udp/crc.go deleted file mode 100644 index 456554f..0000000 --- a/internal/udp/crc.go +++ /dev/null @@ -1,73 +0,0 @@ -package udp - -// EQ2EMu uses a specific CRC32 polynomial (reversed) -const crcPolynomial = 0xEDB88320 - -// Pre-computed CRC32 lookup table for fast calculation -var crcTable [256]uint32 - -// init builds the CRC lookup table at package initialization -func init() { - for i := range crcTable { - crc := uint32(i) - for range 8 { - if crc&1 == 1 { - crc = (crc >> 1) ^ crcPolynomial - } else { - crc >>= 1 - } - } - crcTable[i] = crc - } -} - -// CalculateCRC32 computes CRC32 using EQ2EMu's algorithm -// Returns 16-bit value by truncating the upper bits -func CalculateCRC32(data []byte) uint16 { - crc := uint32(0xFFFFFFFF) - - // Use lookup table for efficient calculation - for _, b := range data { - crc = crcTable[byte(crc)^b] ^ (crc >> 8) - } - - // Return inverted result truncated to 16 bits - return uint16(^crc) -} - -// ValidateCRC checks if packet has valid CRC -// Expects CRC to be the last 2 bytes of data -func ValidateCRC(data []byte) bool { - if len(data) < 2 { - return false - } - - // Split payload and CRC - payload := data[:len(data)-2] - expectedCRC := uint16(data[len(data)-2]) | (uint16(data[len(data)-1]) << 8) - - // Calculate and compare - actualCRC := CalculateCRC32(payload) - return expectedCRC == actualCRC -} - -// AppendCRC adds 16-bit CRC to the end of data -func AppendCRC(data []byte) []byte { - crc := CalculateCRC32(data) - result := make([]byte, len(data)+2) - copy(result, data) - - // Append CRC in little-endian format - result[len(data)] = byte(crc) - result[len(data)+1] = byte(crc >> 8) - - return result -} - -// ValidateAndStrip validates CRC and returns data without CRC suffix -func ValidateAndStrip(data []byte) ([]byte, bool) { - if !ValidateCRC(data) { - return nil, false - } - return data[:len(data)-2], true -} diff --git a/internal/udp/packet.go b/internal/udp/packet.go deleted file mode 100644 index 9d5b560..0000000 --- a/internal/udp/packet.go +++ /dev/null @@ -1,136 +0,0 @@ -package udp - -import ( - "encoding/binary" - "eq2emu/internal/opcodes" - "errors" - "fmt" -) - -// ProtocolPacket represents a low-level UDP protocol packet with opcode and payload -type ProtocolPacket struct { - Opcode uint8 // Protocol operation code (1-2 bytes when serialized) - Data []byte // Packet payload data - Raw []byte // Original raw packet data for debugging -} - -// ApplicationPacket represents a higher-level game application packet -type ApplicationPacket struct { - Opcode uint16 // Application-level operation code - Data []byte // Application payload data -} - -// ParseProtocolPacket parses raw UDP data into a ProtocolPacket -// Handles variable opcode sizing and CRC validation based on EQ2 protocol -func ParseProtocolPacket(data []byte) (*ProtocolPacket, error) { - if len(data) < 2 { - return nil, errors.New("packet too small for valid protocol packet") - } - - var opcode uint8 - var dataStart int - - // EQ2 protocol uses 1-byte opcodes normally, 2-byte for opcodes >= 0xFF - // When opcode >= 0xFF, it's prefixed with 0x00 - if data[0] == 0x00 && len(data) > 2 { - opcode = data[1] - dataStart = 2 - } else { - opcode = data[0] - dataStart = 1 - } - - // Extract payload, handling CRC for non-session packets - var payload []byte - if requiresCRC(opcode) { - if len(data) < dataStart+2 { - return nil, errors.New("packet too small for CRC validation") - } - - // Payload excludes the 2-byte CRC suffix - payload = data[dataStart : len(data)-2] - - // Validate CRC on the entire packet from beginning - if !ValidateCRC(data) { - return nil, fmt.Errorf("CRC validation failed for opcode 0x%02X", opcode) - } - } else { - payload = data[dataStart:] - } - - return &ProtocolPacket{ - Opcode: opcode, - Data: payload, - Raw: data, - }, nil -} - -// Serialize converts ProtocolPacket back to wire format with proper opcode encoding and CRC -func (p *ProtocolPacket) Serialize() []byte { - var result []byte - - // Handle variable opcode encoding - if p.Opcode == 0xFF { - // 2-byte opcode format: [0x00][actual_opcode][data] - result = make([]byte, 2+len(p.Data)) - result[0] = 0x00 - result[1] = p.Opcode - copy(result[2:], p.Data) - } else { - // 1-byte opcode format: [opcode][data] - result = make([]byte, 1+len(p.Data)) - result[0] = p.Opcode - copy(result[1:], p.Data) - } - - // Add CRC for packets that require it - if requiresCRC(p.Opcode) { - result = AppendCRC(result) - } - - return result -} - -// ParseApplicationPacket parses application-level packet from decrypted/decompressed data -func ParseApplicationPacket(data []byte) (*ApplicationPacket, error) { - if len(data) < 2 { - return nil, errors.New("application packet requires at least 2 bytes for opcode") - } - - // Application opcodes are always little-endian 16-bit values - opcode := binary.LittleEndian.Uint16(data[0:2]) - - return &ApplicationPacket{ - Opcode: opcode, - Data: data[2:], - }, nil -} - -// Serialize converts ApplicationPacket to byte array for transmission -func (p *ApplicationPacket) Serialize() []byte { - result := make([]byte, 2+len(p.Data)) - binary.LittleEndian.PutUint16(result[0:2], p.Opcode) - copy(result[2:], p.Data) - return result -} - -// String provides human-readable representation for debugging -func (p *ProtocolPacket) String() string { - return fmt.Sprintf("ProtocolPacket{Opcode: 0x%02X, DataLen: %d}", p.Opcode, len(p.Data)) -} - -// String provides human-readable representation for debugging -func (p *ApplicationPacket) String() string { - return fmt.Sprintf("ApplicationPacket{Opcode: 0x%04X, DataLen: %d}", p.Opcode, len(p.Data)) -} - -// requiresCRC determines if a protocol opcode requires CRC validation -// Session control packets (SessionRequest, SessionResponse, OutOfSession) don't use CRC -func requiresCRC(opcode uint8) bool { - switch opcode { - case opcodes.OpSessionRequest, opcodes.OpSessionResponse, opcodes.OpOutOfSession: - return false - default: - return true - } -} diff --git a/internal/udp/protocol.go b/internal/udp/protocol.go new file mode 100644 index 0000000..f21ada7 --- /dev/null +++ b/internal/udp/protocol.go @@ -0,0 +1,317 @@ +package udp + +import ( + "bytes" + "encoding/binary" + "eq2emu/internal/opcodes" + "errors" + "fmt" +) + +// Common protocol errors +var ( + ErrPacketTooSmall = errors.New("packet too small") + ErrInvalidCRC = errors.New("invalid CRC") + ErrInvalidOpcode = errors.New("invalid opcode") +) + +// ProtocolPacket represents a low-level UDP protocol packet with opcode and payload +type ProtocolPacket struct { + Opcode uint8 // Protocol operation code (1-2 bytes when serialized) + Data []byte // Packet payload data + Raw []byte // Original raw packet data for debugging +} + +// ApplicationPacket represents a higher-level game application packet +type ApplicationPacket struct { + Opcode uint16 // Application-level operation code + Data []byte // Application payload data +} + +// ParseProtocolPacket parses raw UDP data into a ProtocolPacket +// Handles variable opcode sizing and CRC validation based on EQ2 protocol +func ParseProtocolPacket(data []byte) (*ProtocolPacket, error) { + if len(data) < 2 { + return nil, ErrPacketTooSmall + } + + var opcode uint8 + var dataStart int + + // EQ2 protocol uses 1-byte opcodes normally, 2-byte for opcodes >= 0xFF + // When opcode >= 0xFF, it's prefixed with 0x00 + if data[0] == 0x00 && len(data) > 2 { + opcode = data[1] + dataStart = 2 + } else { + opcode = data[0] + dataStart = 1 + } + + // Extract payload, handling CRC for non-session packets + var payload []byte + if requiresCRC(opcode) { + if len(data) < dataStart+2 { + return nil, ErrPacketTooSmall + } + + // Payload excludes the 2-byte CRC suffix + payload = data[dataStart : len(data)-2] + + // Validate CRC on the entire packet from beginning + if !ValidateCRC(data) { + return nil, fmt.Errorf("%w for opcode 0x%02X", ErrInvalidCRC, opcode) + } + } else { + payload = data[dataStart:] + } + + return &ProtocolPacket{ + Opcode: opcode, + Data: payload, + Raw: data, + }, nil +} + +// Serialize converts ProtocolPacket back to wire format with proper opcode encoding and CRC +func (p *ProtocolPacket) Serialize() []byte { + var result []byte + + // Handle variable opcode encoding + if p.Opcode == 0xFF { + // 2-byte opcode format: [0x00][actual_opcode][data] + result = make([]byte, 2+len(p.Data)) + result[0] = 0x00 + result[1] = p.Opcode + copy(result[2:], p.Data) + } else { + // 1-byte opcode format: [opcode][data] + result = make([]byte, 1+len(p.Data)) + result[0] = p.Opcode + copy(result[1:], p.Data) + } + + // Add CRC for packets that require it + if requiresCRC(p.Opcode) { + result = AppendCRC(result) + } + + return result +} + +// String provides human-readable representation for debugging +func (p *ProtocolPacket) String() string { + return fmt.Sprintf("ProtocolPacket{Opcode: 0x%02X, DataLen: %d}", p.Opcode, len(p.Data)) +} + +// ParseApplicationPacket parses application-level packet from decrypted/decompressed data +func ParseApplicationPacket(data []byte) (*ApplicationPacket, error) { + if len(data) < 2 { + return nil, errors.New("application packet requires at least 2 bytes for opcode") + } + + // Application opcodes are always little-endian 16-bit values + opcode := binary.LittleEndian.Uint16(data[0:2]) + + return &ApplicationPacket{ + Opcode: opcode, + Data: data[2:], + }, nil +} + +// Serialize converts ApplicationPacket to byte array for transmission +func (p *ApplicationPacket) Serialize() []byte { + result := make([]byte, 2+len(p.Data)) + binary.LittleEndian.PutUint16(result[0:2], p.Opcode) + copy(result[2:], p.Data) + return result +} + +// String provides human-readable representation for debugging +func (p *ApplicationPacket) String() string { + return fmt.Sprintf("ApplicationPacket{Opcode: 0x%04X, DataLen: %d}", p.Opcode, len(p.Data)) +} + +// requiresCRC determines if a protocol opcode requires CRC validation +// Session control packets (SessionRequest, SessionResponse, OutOfSession) don't use CRC +func requiresCRC(opcode uint8) bool { + switch opcode { + case opcodes.OpSessionRequest, opcodes.OpSessionResponse, opcodes.OpOutOfSession: + return false + default: + return true + } +} + +// PacketCombiner groups small packets together to reduce UDP overhead +type PacketCombiner struct { + PendingPackets []*ProtocolPacket // Direct access to pending packets + MaxSize int // Direct access to max size +} + +// NewPacketCombiner creates a combiner with specified max size +func NewPacketCombiner(maxSize int) *PacketCombiner { + return &PacketCombiner{ + MaxSize: maxSize, + } +} + +// Add queues a packet for potential combining +func (pc *PacketCombiner) Add(packet *ProtocolPacket) { + pc.PendingPackets = append(pc.PendingPackets, packet) +} + +// Flush returns combined packets and clears the queue +func (pc *PacketCombiner) Flush() []*ProtocolPacket { + count := len(pc.PendingPackets) + if count == 0 { + return nil + } + + if count == 1 { + // Single packet - no combining needed + packet := pc.PendingPackets[0] + pc.Clear() + return []*ProtocolPacket{packet} + } + + // Combine multiple packets + combined := pc.combine() + pc.Clear() + return []*ProtocolPacket{combined} +} + +// combine merges all pending packets into a single combined packet +func (pc *PacketCombiner) combine() *ProtocolPacket { + var buf bytes.Buffer + + for _, packet := range pc.PendingPackets { + serialized := packet.Serialize() + pc.writeSizeHeader(&buf, len(serialized)) + buf.Write(serialized) + } + + return &ProtocolPacket{ + Opcode: opcodes.OpCombined, + Data: buf.Bytes(), + } +} + +// writeSizeHeader writes packet size using variable-length encoding +func (pc *PacketCombiner) writeSizeHeader(buf *bytes.Buffer, size int) { + if size >= 255 { + // Large packet - use 3-byte header [0xFF][low][high] + buf.WriteByte(0xFF) + buf.WriteByte(byte(size)) + buf.WriteByte(byte(size >> 8)) + } else { + // Small packet - use 1-byte header + buf.WriteByte(byte(size)) + } +} + +// ShouldCombine determines if packets should be combined based on total size +func (pc *PacketCombiner) ShouldCombine() bool { + if len(pc.PendingPackets) < 2 { + return false + } + + totalSize := 0 + for _, packet := range pc.PendingPackets { + serialized := packet.Serialize() + totalSize += len(serialized) + + // Add size header overhead + if len(serialized) >= 255 { + totalSize += 3 + } else { + totalSize += 1 + } + } + + return totalSize <= pc.MaxSize +} + +// Clear removes all pending packets +func (pc *PacketCombiner) Clear() { + pc.PendingPackets = pc.PendingPackets[:0] // Reuse slice capacity +} + +// ParseCombinedPacket splits combined packet into individual packets +func ParseCombinedPacket(data []byte) ([]*ProtocolPacket, error) { + var packets []*ProtocolPacket + offset := 0 + + for offset < len(data) { + size, headerSize, err := readSizeHeader(data, offset) + if err != nil { + break + } + + offset += headerSize + + if offset+size > len(data) { + break // Incomplete packet + } + + // Parse individual packet + packetData := data[offset : offset+size] + if packet, err := ParseProtocolPacket(packetData); err == nil { + packets = append(packets, packet) + } + + offset += size + } + + return packets, nil +} + +// readSizeHeader reads variable-length size header +func readSizeHeader(data []byte, offset int) (size, headerSize int, err error) { + if offset >= len(data) { + return 0, 0, errors.New("insufficient data") + } + + if data[offset] == 0xFF { + // 3-byte size header + if offset+2 >= len(data) { + return 0, 0, errors.New("insufficient data for 3-byte header") + } + size = int(data[offset+1]) | (int(data[offset+2]) << 8) + headerSize = 3 + } else { + // 1-byte size header + size = int(data[offset]) + headerSize = 1 + } + + return size, headerSize, nil +} + +// ValidateCombinedPacket checks if combined packet data is well-formed +func ValidateCombinedPacket(data []byte) error { + offset := 0 + count := 0 + + for offset < len(data) { + size, headerSize, err := readSizeHeader(data, offset) + if err != nil { + return err + } + + offset += headerSize + + if offset+size > len(data) { + return errors.New("packet extends beyond data boundary") + } + + offset += size + count++ + + if count > 100 { // Sanity check + return errors.New("too many packets in combined packet") + } + } + + return nil +} diff --git a/internal/udp/fragment.go b/internal/udp/reliability.go similarity index 62% rename from internal/udp/fragment.go rename to internal/udp/reliability.go index d040e10..be2c604 100644 --- a/internal/udp/fragment.go +++ b/internal/udp/reliability.go @@ -6,12 +6,121 @@ import ( "errors" "fmt" "sort" + "sync" + "time" ) -// FragmentManager handles packet fragmentation and reassembly -type FragmentManager struct { - fragments map[uint16]*FragmentGroup // Active fragment groups by base sequence - maxLength uint32 // Maximum packet size before fragmentation +// Common errors for reliability layer +var ( + ErrFragmentTimeout = errors.New("fragment timeout") + ErrOrphanedFragment = errors.New("orphaned fragment") +) + +// RetransmitEntry tracks a packet awaiting acknowledgment +type RetransmitEntry struct { + Packet *ProtocolPacket // The packet to retransmit + Sequence uint16 // Packet sequence number + Timestamp time.Time // When packet was last sent + Attempts int // Number of transmission attempts +} + +// RetransmitQueue manages reliable packet delivery with exponential backoff +type RetransmitQueue struct { + entries map[uint16]*RetransmitEntry // Pending packets by sequence + mutex sync.RWMutex // Thread-safe access + baseTimeout time.Duration // Base retransmission timeout + maxAttempts int // Maximum retry attempts + maxTimeout time.Duration // Maximum timeout cap +} + +// NewRetransmitQueue creates a queue with specified settings +func NewRetransmitQueue(baseTimeout, maxTimeout time.Duration, maxAttempts int) *RetransmitQueue { + return &RetransmitQueue{ + entries: make(map[uint16]*RetransmitEntry), + baseTimeout: baseTimeout, + maxAttempts: maxAttempts, + maxTimeout: maxTimeout, + } +} + +// Add queues a packet for potential retransmission +func (rq *RetransmitQueue) Add(packet *ProtocolPacket, sequence uint16) { + rq.mutex.Lock() + defer rq.mutex.Unlock() + + rq.entries[sequence] = &RetransmitEntry{ + Packet: packet, + Sequence: sequence, + Timestamp: time.Now(), + Attempts: 1, + } +} + +// Acknowledge removes a packet from the retransmit queue +func (rq *RetransmitQueue) Acknowledge(sequence uint16) bool { + rq.mutex.Lock() + defer rq.mutex.Unlock() + + _, existed := rq.entries[sequence] + delete(rq.entries, sequence) + return existed +} + +// GetExpired returns packets that need retransmission +func (rq *RetransmitQueue) GetExpired() []*RetransmitEntry { + rq.mutex.Lock() + defer rq.mutex.Unlock() + + now := time.Now() + var expired []*RetransmitEntry + + for seq, entry := range rq.entries { + timeout := rq.calculateTimeout(entry.Attempts) + + if now.Sub(entry.Timestamp) > timeout { + if entry.Attempts >= rq.maxAttempts { + // Give up after max attempts + delete(rq.entries, seq) + } else { + // Schedule for retransmission + entry.Attempts++ + entry.Timestamp = now + expired = append(expired, entry) + } + } + } + + return expired +} + +// calculateTimeout computes timeout with exponential backoff +func (rq *RetransmitQueue) calculateTimeout(attempts int) time.Duration { + timeout := rq.baseTimeout * time.Duration(attempts*attempts) // Quadratic backoff + if timeout > rq.maxTimeout { + timeout = rq.maxTimeout + } + return timeout +} + +// Clear removes all pending packets +func (rq *RetransmitQueue) Clear() { + rq.mutex.Lock() + defer rq.mutex.Unlock() + rq.entries = make(map[uint16]*RetransmitEntry) +} + +// Size returns the number of pending packets +func (rq *RetransmitQueue) Size() int { + rq.mutex.RLock() + defer rq.mutex.RUnlock() + return len(rq.entries) +} + +// IsEmpty returns true if no packets are pending +func (rq *RetransmitQueue) IsEmpty() bool { + rq.mutex.RLock() + defer rq.mutex.RUnlock() + return len(rq.entries) == 0 } // FragmentGroup tracks fragments belonging to the same original packet @@ -29,6 +138,12 @@ type FragmentPiece struct { IsFirst bool // Whether this is the first fragment } +// FragmentManager handles packet fragmentation and reassembly +type FragmentManager struct { + fragments map[uint16]*FragmentGroup // Active fragment groups by base sequence + maxLength uint32 // Maximum packet size before fragmentation +} + // NewFragmentManager creates a manager with specified maximum packet length func NewFragmentManager(maxLength uint32) *FragmentManager { return &FragmentManager{ @@ -86,7 +201,7 @@ func (fm *FragmentManager) FragmentPacket(data []byte, startSeq uint16) []*Proto // ProcessFragment handles incoming fragments and returns complete packet when ready func (fm *FragmentManager) ProcessFragment(packet *ProtocolPacket) ([]byte, bool, error) { if len(packet.Data) < 2 { - return nil, false, errors.New("fragment too small") + return nil, false, ErrPacketTooSmall } seq := binary.BigEndian.Uint16(packet.Data[0:2]) @@ -118,7 +233,7 @@ func (fm *FragmentManager) ProcessFragment(packet *ProtocolPacket) ([]byte, bool fragment.Data = packet.Data[2:] group := fm.findFragmentGroup(seq) if group == nil { - return nil, false, errors.New("orphaned fragment") + return nil, false, ErrOrphanedFragment } group.Fragments = append(group.Fragments, fragment) @@ -186,26 +301,7 @@ func (fm *FragmentManager) CleanupStale(maxAge uint16) { } } -// GetStats returns fragmentation statistics -func (fm *FragmentManager) GetStats() FragmentStats { - return FragmentStats{ - ActiveGroups: len(fm.fragments), - MaxLength: fm.maxLength, - } -} - -// FragmentStats contains fragmentation statistics -type FragmentStats struct { - ActiveGroups int // Number of incomplete fragment groups - MaxLength uint32 // Maximum packet length setting -} - // Clear removes all pending fragments func (fm *FragmentManager) Clear() { fm.fragments = make(map[uint16]*FragmentGroup) } - -// SetMaxLength updates the maximum packet length -func (fm *FragmentManager) SetMaxLength(maxLength uint32) { - fm.maxLength = maxLength -} diff --git a/internal/udp/retransmit.go b/internal/udp/retransmit.go deleted file mode 100644 index d125f26..0000000 --- a/internal/udp/retransmit.go +++ /dev/null @@ -1,190 +0,0 @@ -package udp - -import ( - "sync" - "time" -) - -// RetransmitEntry tracks a packet awaiting acknowledgment -type RetransmitEntry struct { - Packet *ProtocolPacket // The packet to retransmit - Sequence uint16 // Packet sequence number - Timestamp time.Time // When packet was last sent - Attempts int // Number of transmission attempts -} - -// RetransmitQueue manages reliable packet delivery with exponential backoff -type RetransmitQueue struct { - entries map[uint16]*RetransmitEntry // Pending packets by sequence - mutex sync.RWMutex // Thread-safe access - baseTimeout time.Duration // Base retransmission timeout - maxAttempts int // Maximum retry attempts - maxTimeout time.Duration // Maximum timeout cap -} - -// NewRetransmitQueue creates a queue with default settings -func NewRetransmitQueue() *RetransmitQueue { - return &RetransmitQueue{ - entries: make(map[uint16]*RetransmitEntry), - baseTimeout: 500 * time.Millisecond, - maxAttempts: 5, - maxTimeout: 5 * time.Second, - } -} - -// NewRetransmitQueueWithConfig creates a queue with custom settings -func NewRetransmitQueueWithConfig(baseTimeout, maxTimeout time.Duration, maxAttempts int) *RetransmitQueue { - return &RetransmitQueue{ - entries: make(map[uint16]*RetransmitEntry), - baseTimeout: baseTimeout, - maxAttempts: maxAttempts, - maxTimeout: maxTimeout, - } -} - -// Add queues a packet for potential retransmission -func (rq *RetransmitQueue) Add(packet *ProtocolPacket, sequence uint16) { - rq.mutex.Lock() - defer rq.mutex.Unlock() - - rq.entries[sequence] = &RetransmitEntry{ - Packet: packet, - Sequence: sequence, - Timestamp: time.Now(), - Attempts: 1, - } -} - -// Acknowledge removes a packet from the retransmit queue -func (rq *RetransmitQueue) Acknowledge(sequence uint16) bool { - rq.mutex.Lock() - defer rq.mutex.Unlock() - - _, existed := rq.entries[sequence] - delete(rq.entries, sequence) - return existed -} - -// GetExpired returns packets that need retransmission -func (rq *RetransmitQueue) GetExpired() []*RetransmitEntry { - rq.mutex.Lock() - defer rq.mutex.Unlock() - - now := time.Now() - var expired []*RetransmitEntry - - for seq, entry := range rq.entries { - timeout := rq.calculateTimeout(entry.Attempts) - - if now.Sub(entry.Timestamp) > timeout { - if entry.Attempts >= rq.maxAttempts { - // Give up after max attempts - delete(rq.entries, seq) - } else { - // Schedule for retransmission - entry.Attempts++ - entry.Timestamp = now - expired = append(expired, entry) - } - } - } - - return expired -} - -// calculateTimeout computes timeout with exponential backoff -func (rq *RetransmitQueue) calculateTimeout(attempts int) time.Duration { - timeout := rq.baseTimeout * time.Duration(attempts*attempts) // Quadratic backoff - if timeout > rq.maxTimeout { - timeout = rq.maxTimeout - } - return timeout -} - -// Clear removes all pending packets -func (rq *RetransmitQueue) Clear() { - rq.mutex.Lock() - defer rq.mutex.Unlock() - rq.entries = make(map[uint16]*RetransmitEntry) -} - -// Size returns the number of pending packets -func (rq *RetransmitQueue) Size() int { - rq.mutex.RLock() - defer rq.mutex.RUnlock() - return len(rq.entries) -} - -// GetPendingSequences returns all sequence numbers awaiting acknowledgment -func (rq *RetransmitQueue) GetPendingSequences() []uint16 { - rq.mutex.RLock() - defer rq.mutex.RUnlock() - - sequences := make([]uint16, 0, len(rq.entries)) - for seq := range rq.entries { - sequences = append(sequences, seq) - } - return sequences -} - -// IsEmpty returns true if no packets are pending -func (rq *RetransmitQueue) IsEmpty() bool { - rq.mutex.RLock() - defer rq.mutex.RUnlock() - return len(rq.entries) == 0 -} - -// SetBaseTimeout updates the base retransmission timeout -func (rq *RetransmitQueue) SetBaseTimeout(timeout time.Duration) { - rq.mutex.Lock() - defer rq.mutex.Unlock() - rq.baseTimeout = timeout -} - -// SetMaxAttempts updates the maximum retry attempts -func (rq *RetransmitQueue) SetMaxAttempts(attempts int) { - rq.mutex.Lock() - defer rq.mutex.Unlock() - rq.maxAttempts = attempts -} - -// SetMaxTimeout updates the maximum timeout cap -func (rq *RetransmitQueue) SetMaxTimeout(timeout time.Duration) { - rq.mutex.Lock() - defer rq.mutex.Unlock() - rq.maxTimeout = timeout -} - -// GetStats returns retransmission statistics -func (rq *RetransmitQueue) GetStats() RetransmitStats { - rq.mutex.RLock() - defer rq.mutex.RUnlock() - - stats := RetransmitStats{ - PendingCount: len(rq.entries), - BaseTimeout: rq.baseTimeout, - MaxAttempts: rq.maxAttempts, - MaxTimeout: rq.maxTimeout, - } - - // Calculate attempt distribution - for _, entry := range rq.entries { - if entry.Attempts == 1 { - stats.FirstAttempts++ - } else { - stats.Retransmissions++ - } - } - - return stats -} - -// RetransmitStats contains retransmission queue statistics -type RetransmitStats struct { - PendingCount int // Total pending packets - FirstAttempts int // Packets on first attempt - Retransmissions int // Packets being retransmitted - BaseTimeout time.Duration // Base timeout setting - MaxAttempts int // Maximum attempts setting - MaxTimeout time.Duration // Maximum timeout setting -} diff --git a/internal/udp/crypto.go b/internal/udp/security.go similarity index 55% rename from internal/udp/crypto.go rename to internal/udp/security.go index 1e969cf..7ad7af5 100644 --- a/internal/udp/crypto.go +++ b/internal/udp/security.go @@ -2,9 +2,81 @@ package udp import ( "crypto/rc4" - "fmt" + "errors" ) +// EQ2EMu CRC32 polynomial (reversed) +const crcPolynomial = 0xEDB88320 + +// Pre-computed CRC32 lookup table for fast calculation +var crcTable [256]uint32 + +// init builds the CRC lookup table at package initialization +func init() { + for i := range crcTable { + crc := uint32(i) + for range 8 { + if crc&1 == 1 { + crc = (crc >> 1) ^ crcPolynomial + } else { + crc >>= 1 + } + } + crcTable[i] = crc + } +} + +// CalculateCRC32 computes CRC32 using EQ2EMu's algorithm +// Returns 16-bit value by truncating the upper bits +func CalculateCRC32(data []byte) uint16 { + crc := uint32(0xFFFFFFFF) + + // Use lookup table for efficient calculation + for _, b := range data { + crc = crcTable[byte(crc)^b] ^ (crc >> 8) + } + + // Return inverted result truncated to 16 bits + return uint16(^crc) +} + +// ValidateCRC checks if packet has valid CRC +// Expects CRC to be the last 2 bytes of data +func ValidateCRC(data []byte) bool { + if len(data) < 2 { + return false + } + + // Split payload and CRC + payload := data[:len(data)-2] + expectedCRC := uint16(data[len(data)-2]) | (uint16(data[len(data)-1]) << 8) + + // Calculate and compare + actualCRC := CalculateCRC32(payload) + return expectedCRC == actualCRC +} + +// AppendCRC adds 16-bit CRC to the end of data +func AppendCRC(data []byte) []byte { + crc := CalculateCRC32(data) + result := make([]byte, len(data)+2) + copy(result, data) + + // Append CRC in little-endian format + result[len(data)] = byte(crc) + result[len(data)+1] = byte(crc >> 8) + + return result +} + +// ValidateAndStrip validates CRC and returns data without CRC suffix +func ValidateAndStrip(data []byte) ([]byte, bool) { + if !ValidateCRC(data) { + return nil, false + } + return data[:len(data)-2], true +} + // Crypto handles RC4 encryption/decryption for EQ2EMu protocol type Crypto struct { clientCipher *rc4.Cipher // Cipher for decrypting client data @@ -15,27 +87,25 @@ type Crypto struct { // NewCrypto creates a new crypto instance with encryption disabled func NewCrypto() *Crypto { - return &Crypto{ - encrypted: false, - } + return &Crypto{} } // SetKey initializes RC4 encryption with the given key // Creates separate ciphers for client and server with 20-byte priming func (c *Crypto) SetKey(key []byte) error { if len(key) == 0 { - return fmt.Errorf("encryption key cannot be empty") + return errors.New("encryption key cannot be empty") } // Create separate RC4 ciphers for bidirectional communication clientCipher, err := rc4.NewCipher(key) if err != nil { - return fmt.Errorf("failed to create client cipher: %w", err) + return err } serverCipher, err := rc4.NewCipher(key) if err != nil { - return fmt.Errorf("failed to create server cipher: %w", err) + return err } // Prime both ciphers with 20 dummy bytes per EQ2EMu protocol @@ -98,12 +168,3 @@ func (c *Crypto) Reset() { c.key = nil c.encrypted = false } - -// Clone creates a copy of the crypto instance with the same key -func (c *Crypto) Clone() (*Crypto, error) { - newCrypto := NewCrypto() - if c.encrypted && c.key != nil { - return newCrypto, newCrypto.SetKey(c.key) - } - return newCrypto, nil -} diff --git a/internal/udp/server.go b/internal/udp/server.go index 3af785c..6b9bf94 100644 --- a/internal/udp/server.go +++ b/internal/udp/server.go @@ -14,40 +14,16 @@ type Server struct { mutex sync.RWMutex // Protects connections map handler PacketHandler // Application packet handler running bool // Server running state - - // Configuration - maxConnections int // Maximum concurrent connections - timeout time.Duration // Connection timeout duration + config Config // Server configuration } -// PacketHandler processes application-level packets for connections -type PacketHandler interface { - HandlePacket(conn *Connection, packet *ApplicationPacket) -} - -// ServerConfig holds server configuration options -type ServerConfig struct { - MaxConnections int // Maximum concurrent connections (default: 1000) - Timeout time.Duration // Connection timeout (default: 45s) - BufferSize int // UDP receive buffer size (default: 8192) -} - -// DefaultServerConfig returns sensible default configuration -func DefaultServerConfig() ServerConfig { - return ServerConfig{ - MaxConnections: 1000, - Timeout: 45 * time.Second, - BufferSize: 8192, +// NewServer creates a UDP server with simplified configuration +func NewServer(addr string, handler PacketHandler, config ...Config) (*Server, error) { + cfg := DefaultConfig() + if len(config) > 0 { + cfg = config[0] } -} -// NewServer creates a new UDP server instance -func NewServer(addr string, handler PacketHandler) (*Server, error) { - return NewServerWithConfig(addr, handler, DefaultServerConfig()) -} - -// NewServerWithConfig creates a server with custom configuration -func NewServerWithConfig(addr string, handler PacketHandler, config ServerConfig) (*Server, error) { udpAddr, err := net.ResolveUDPAddr("udp", addr) if err != nil { return nil, fmt.Errorf("invalid UDP address %s: %w", addr, err) @@ -59,17 +35,16 @@ func NewServerWithConfig(addr string, handler PacketHandler, config ServerConfig } // Set socket buffer size for better performance - if config.BufferSize > 0 { - conn.SetReadBuffer(config.BufferSize) - conn.SetWriteBuffer(config.BufferSize) + if cfg.BufferSize > 0 { + conn.SetReadBuffer(cfg.BufferSize) + conn.SetWriteBuffer(cfg.BufferSize) } return &Server{ - conn: conn, - connections: make(map[string]*Connection), - handler: handler, - maxConnections: config.MaxConnections, - timeout: config.Timeout, + conn: conn, + connections: make(map[string]*Connection), + handler: handler, + config: cfg, }, nil } @@ -127,12 +102,12 @@ func (s *Server) handleIncomingPacket(data []byte, addr *net.UDPAddr) { conn, exists := s.connections[connKey] if !exists { // Check connection limit - if len(s.connections) >= s.maxConnections { + if len(s.connections) >= s.config.MaxConnections { s.mutex.Unlock() return // Drop packet if at capacity } - conn = NewConnection(addr, s.conn, s.handler) + conn = NewConnection(addr, s.conn, s.handler, s.config) conn.StartRetransmitLoop() s.connections[connKey] = conn } @@ -163,7 +138,7 @@ func (s *Server) cleanupTimedOutConnections() { defer s.mutex.Unlock() for key, conn := range s.connections { - if conn.IsTimedOut(s.timeout) { + if conn.IsTimedOut() { conn.Close() delete(s.connections, key) } @@ -219,19 +194,6 @@ func (s *Server) DisconnectClient(addr string) bool { return false } -// GetStats returns server statistics -func (s *Server) GetStats() ServerStats { - s.mutex.RLock() - defer s.mutex.RUnlock() - - return ServerStats{ - ConnectionCount: len(s.connections), - MaxConnections: s.maxConnections, - Running: s.running, - Timeout: s.timeout, - } -} - // ServerStats contains server runtime statistics type ServerStats struct { ConnectionCount int // Current number of connections @@ -240,16 +202,29 @@ type ServerStats struct { Timeout time.Duration // Connection timeout setting } +// GetStats returns server statistics +func (s *Server) GetStats() ServerStats { + s.mutex.RLock() + defer s.mutex.RUnlock() + + return ServerStats{ + ConnectionCount: len(s.connections), + MaxConnections: s.config.MaxConnections, + Running: s.running, + Timeout: s.config.Timeout, + } +} + // SetConnectionLimit updates the maximum connection limit func (s *Server) SetConnectionLimit(limit int) { s.mutex.Lock() defer s.mutex.Unlock() - s.maxConnections = limit + s.config.MaxConnections = limit } // SetTimeout updates the connection timeout duration func (s *Server) SetTimeout(timeout time.Duration) { s.mutex.Lock() defer s.mutex.Unlock() - s.timeout = timeout + s.config.Timeout = timeout } diff --git a/internal/udp/udp_test.go b/internal/udp/udp_test.go index 7b9b574..fa742a2 100644 --- a/internal/udp/udp_test.go +++ b/internal/udp/udp_test.go @@ -49,10 +49,14 @@ func (h *TestHandler) Clear() { h.receivedPackets = nil } +// Simple handler function for testing +func testHandler(conn *Connection, packet *ApplicationPacket) { + fmt.Printf("Test handler received packet opcode: 0x%04X\n", packet.Opcode) +} + // TestServer tests basic server creation and startup func TestServer(t *testing.T) { - handler := &TestHandler{} - server, err := NewServer(":9999", handler) + server, err := NewServer(":9999", testHandler) if err != nil { t.Fatalf("Failed to create server: %v", err) } @@ -77,14 +81,12 @@ func TestServer(t *testing.T) { // TestServerConfig tests server configuration options func TestServerConfig(t *testing.T) { - handler := &TestHandler{} - config := ServerConfig{ - MaxConnections: 10, - Timeout: 30 * time.Second, - BufferSize: 4096, - } + config := DefaultConfig() + config.MaxConnections = 10 + config.Timeout = 30 * time.Second + config.BufferSize = 4096 - server, err := NewServerWithConfig(":9998", handler, config) + server, err := NewServer(":9998", testHandler, config) if err != nil { t.Fatalf("Failed to create server with config: %v", err) } @@ -217,7 +219,8 @@ func TestCrypto(t *testing.T) { // TestRetransmitQueue tests packet retransmission logic func TestRetransmitQueue(t *testing.T) { - rq := NewRetransmitQueue() + config := DefaultConfig() + rq := NewRetransmitQueue(config.RetransmitBase, config.RetransmitMax, config.RetransmitAttempts) packet := &ProtocolPacket{ Opcode: opcodes.OpPacket, @@ -296,21 +299,21 @@ func TestFragmentation(t *testing.T) { // TestPacketCombining tests packet combination functionality func TestPacketCombining(t *testing.T) { - combiner := NewPacketCombiner() + combiner := NewPacketCombiner(256) // Add small packets - use session opcodes that don't require CRC packet1 := &ProtocolPacket{Opcode: opcodes.OpSessionRequest, Data: []byte("test1")} packet2 := &ProtocolPacket{Opcode: opcodes.OpSessionRequest, Data: []byte("test2")} - combiner.AddPacket(packet1) - combiner.AddPacket(packet2) + combiner.Add(packet1) + combiner.Add(packet2) - if combiner.GetPendingCount() != 2 { - t.Errorf("Expected 2 pending packets, got %d", combiner.GetPendingCount()) + if len(combiner.PendingPackets) != 2 { + t.Errorf("Expected 2 pending packets, got %d", len(combiner.PendingPackets)) } // Flush combined - combined := combiner.FlushCombined() + combined := combiner.Flush() if len(combined) != 1 { t.Errorf("Expected 1 combined packet, got %d", len(combined)) } @@ -330,12 +333,12 @@ func TestPacketCombining(t *testing.T) { // TestConnection tests basic connection functionality func TestConnection(t *testing.T) { - handler := &TestHandler{} + config := DefaultConfig() addr, _ := net.ResolveUDPAddr("udp", "127.0.0.1:0") conn, _ := net.ListenUDP("udp", addr) defer conn.Close() - connection := NewConnection(addr, conn, handler) + connection := NewConnection(addr, conn, testHandler, config) if connection.GetState() != StateClosed { t.Error("New connection should be in closed state") @@ -346,9 +349,37 @@ func TestConnection(t *testing.T) { t.Error("New connection should have session ID 0") } - // Test timeout - if !connection.IsTimedOut(time.Nanosecond) { - t.Error("New connection should be timed out with very short timeout") + // Test timeout with very short timeout config + shortConfig := DefaultConfig() + shortConfig.Timeout = time.Nanosecond + shortConnection := NewConnection(addr, conn, testHandler, shortConfig) + + // Wait a bit to ensure timeout + time.Sleep(time.Millisecond) + + if !shortConnection.IsTimedOut() { + t.Error("Connection should be timed out with very short timeout") + } +} + +// TestDefaultConfig tests the default configuration +func TestDefaultConfig(t *testing.T) { + config := DefaultConfig() + + if config.MaxConnections != 1000 { + t.Errorf("Expected MaxConnections 1000, got %d", config.MaxConnections) + } + if config.Timeout != 45*time.Second { + t.Errorf("Expected Timeout 45s, got %v", config.Timeout) + } + if config.MaxPacketSize != 512 { + t.Errorf("Expected MaxPacketSize 512, got %d", config.MaxPacketSize) + } + if !config.EnableCompression { + t.Error("Expected compression to be enabled by default") + } + if !config.EnableEncryption { + t.Error("Expected encryption to be enabled by default") } } @@ -398,8 +429,7 @@ func BenchmarkEncryption(b *testing.B) { // TestIntegration performs a basic integration test func TestIntegration(t *testing.T) { - handler := &TestHandler{} - server, err := NewServer(":0", handler) // Use any available port + server, err := NewServer(":0", testHandler) // Use any available port if err != nil { t.Fatalf("Failed to create server: %v", err) } @@ -420,3 +450,28 @@ func TestIntegration(t *testing.T) { t.Errorf("Expected 0 connections, got %d", stats.ConnectionCount) } } + +// TestDirectFieldAccess tests that we can access fields directly +func TestDirectFieldAccess(t *testing.T) { + // Test PacketCombiner direct access + combiner := NewPacketCombiner(256) + combiner.MaxSize = 512 // Direct field modification + + if combiner.MaxSize != 512 { + t.Errorf("Expected MaxSize 512, got %d", combiner.MaxSize) + } + + // Test adding packets and accessing them directly + packet := &ProtocolPacket{Opcode: opcodes.OpKeepAlive, Data: []byte("test")} + combiner.Add(packet) + + if len(combiner.PendingPackets) != 1 { + t.Errorf("Expected 1 pending packet, got %d", len(combiner.PendingPackets)) + } + + // Direct access to pending packets + firstPacket := combiner.PendingPackets[0] + if firstPacket.Opcode != opcodes.OpKeepAlive { + t.Errorf("Expected OpKeepAlive, got 0x%02X", firstPacket.Opcode) + } +}