simplify udp server

This commit is contained in:
Sky Johnson 2025-07-21 23:51:35 -05:00
parent b4e3e3f4e7
commit 5b6a4ff114
10 changed files with 732 additions and 798 deletions

View File

@ -1,248 +0,0 @@
package udp
import (
"bytes"
"eq2emu/internal/opcodes"
"errors"
)
// PacketCombiner groups small packets together to reduce UDP overhead
type PacketCombiner struct {
pendingPackets []*ProtocolPacket // Packets awaiting combination
maxSize int // Maximum combined packet size
timeout int // Combination timeout in milliseconds
}
// NewPacketCombiner creates a combiner with default settings
func NewPacketCombiner() *PacketCombiner {
return &PacketCombiner{
maxSize: 256, // Default size threshold for combining
timeout: 10, // Default timeout in ms
}
}
// NewPacketCombinerWithConfig creates a combiner with custom settings
func NewPacketCombinerWithConfig(maxSize, timeout int) *PacketCombiner {
return &PacketCombiner{
maxSize: maxSize,
timeout: timeout,
}
}
// AddPacket queues a packet for potential combining
func (pc *PacketCombiner) AddPacket(packet *ProtocolPacket) {
pc.pendingPackets = append(pc.pendingPackets, packet)
}
// FlushCombined returns combined packets and clears the queue
func (pc *PacketCombiner) FlushCombined() []*ProtocolPacket {
if len(pc.pendingPackets) == 0 {
return nil
}
if len(pc.pendingPackets) == 1 {
// Single packet - no combining needed
packet := pc.pendingPackets[0]
pc.pendingPackets = nil
return []*ProtocolPacket{packet}
}
// Combine multiple packets
combined := pc.combineProtocolPackets(pc.pendingPackets)
pc.pendingPackets = nil
return []*ProtocolPacket{combined}
}
// combineProtocolPackets merges multiple packets into a single combined packet
func (pc *PacketCombiner) combineProtocolPackets(packets []*ProtocolPacket) *ProtocolPacket {
var buf bytes.Buffer
for _, packet := range packets {
serialized := packet.Serialize()
pc.writeSizeHeader(&buf, len(serialized))
buf.Write(serialized)
}
return &ProtocolPacket{
Opcode: opcodes.OpCombined,
Data: buf.Bytes(),
}
}
// writeSizeHeader writes packet size using variable-length encoding
func (pc *PacketCombiner) writeSizeHeader(buf *bytes.Buffer, size int) {
if size >= 255 {
// Large packet - use 3-byte header [0xFF][low][high]
buf.WriteByte(0xFF)
buf.WriteByte(byte(size))
buf.WriteByte(byte(size >> 8))
} else {
// Small packet - use 1-byte header
buf.WriteByte(byte(size))
}
}
// ParseCombinedPacket splits combined packet into individual packets
func ParseCombinedPacket(data []byte) ([]*ProtocolPacket, error) {
var packets []*ProtocolPacket
offset := 0
for offset < len(data) {
size, headerSize, err := readSizeHeader(data, offset)
if err != nil {
break
}
offset += headerSize
if offset+size > len(data) {
break // Incomplete packet
}
// Parse individual packet
packetData := data[offset : offset+size]
if packet, err := ParseProtocolPacket(packetData); err == nil {
packets = append(packets, packet)
}
offset += size
}
return packets, nil
}
// readSizeHeader reads variable-length size header
func readSizeHeader(data []byte, offset int) (size, headerSize int, err error) {
if offset >= len(data) {
return 0, 0, errors.New("insufficient data for size header")
}
if data[offset] == 0xFF {
// 3-byte size header
if offset+2 >= len(data) {
return 0, 0, errors.New("insufficient data for 3-byte size header")
}
size = int(data[offset+1]) | (int(data[offset+2]) << 8)
headerSize = 3
} else {
// 1-byte size header
size = int(data[offset])
headerSize = 1
}
return size, headerSize, nil
}
// ShouldCombine determines if packets should be combined based on total size
func (pc *PacketCombiner) ShouldCombine() bool {
if len(pc.pendingPackets) < 2 {
return false
}
totalSize := 0
for _, packet := range pc.pendingPackets {
serialized := packet.Serialize()
totalSize += len(serialized)
// Add size header overhead
if len(serialized) >= 255 {
totalSize += 3
} else {
totalSize += 1
}
}
return totalSize <= pc.maxSize
}
// HasPendingPackets returns true if packets are waiting to be combined
func (pc *PacketCombiner) HasPendingPackets() bool {
return len(pc.pendingPackets) > 0
}
// GetPendingCount returns the number of packets waiting to be combined
func (pc *PacketCombiner) GetPendingCount() int {
return len(pc.pendingPackets)
}
// Clear removes all pending packets without combining
func (pc *PacketCombiner) Clear() {
pc.pendingPackets = nil
}
// SetMaxSize updates the maximum combined packet size
func (pc *PacketCombiner) SetMaxSize(maxSize int) {
pc.maxSize = maxSize
}
// SetTimeout updates the combination timeout
func (pc *PacketCombiner) SetTimeout(timeout int) {
pc.timeout = timeout
}
// GetStats returns packet combination statistics
func (pc *PacketCombiner) GetStats() CombinerStats {
return CombinerStats{
PendingCount: len(pc.pendingPackets),
MaxSize: pc.maxSize,
Timeout: pc.timeout,
}
}
// CombinerStats contains packet combiner statistics
type CombinerStats struct {
PendingCount int // Number of packets waiting to be combined
MaxSize int // Maximum combined packet size
Timeout int // Combination timeout in milliseconds
}
// EstimateCombinedSize calculates the size if current packets were combined
func (pc *PacketCombiner) EstimateCombinedSize() int {
if len(pc.pendingPackets) == 0 {
return 0
}
totalSize := 0
for _, packet := range pc.pendingPackets {
serialized := packet.Serialize()
packetSize := len(serialized)
totalSize += packetSize
// Add size header overhead
if packetSize >= 255 {
totalSize += 3
} else {
totalSize += 1
}
}
return totalSize
}
// ValidateCombinedPacket checks if combined packet data is well-formed
func ValidateCombinedPacket(data []byte) error {
offset := 0
count := 0
for offset < len(data) {
size, headerSize, err := readSizeHeader(data, offset)
if err != nil {
return err
}
offset += headerSize
if offset+size > len(data) {
return errors.New("packet extends beyond data boundary")
}
offset += size
count++
if count > 100 { // Sanity check
return errors.New("too many packets in combined packet")
}
}
return nil
}

View File

@ -4,6 +4,7 @@ import (
"crypto/rand"
"encoding/binary"
"eq2emu/internal/opcodes"
+ "errors"
"net"
"sync"
"time"
@ -19,11 +20,51 @@ const (
StateWaitClose // Waiting for close confirmation
)
- const (
- DefaultWindowSize = 2048 // Default sliding window size for flow control
- MaxPacketSize = 512 // Maximum packet size before fragmentation
+ // Common connection errors
+ var (
+ ErrSessionClosed = errors.New("session closed")
)
+ // Config holds all UDP server and connection configuration
+ type Config struct {
+ // Server settings
+ MaxConnections int // Maximum concurrent connections
+ Timeout time.Duration // Connection timeout duration
+ BufferSize int // UDP socket buffer size
+ // Protocol settings
+ MaxPacketSize uint32 // Maximum packet size before fragmentation
+ WindowSize uint16 // Sliding window size for flow control
+ RetransmitBase time.Duration // Base retransmission timeout
+ RetransmitMax time.Duration // Maximum retransmission timeout
+ RetransmitAttempts int // Maximum retransmission attempts
+ CombineThreshold int // Packet combining size threshold
+ // Features
+ EnableCompression bool // Enable zlib compression
+ EnableEncryption bool // Enable RC4 encryption
+ }
+ // DefaultConfig returns sensible defaults for EQ2EMu protocol
+ func DefaultConfig() Config {
+ return Config{
+ MaxConnections: 1000,
+ Timeout: 45 * time.Second,
+ BufferSize: 8192,
+ MaxPacketSize: 512,
+ WindowSize: 2048,
+ RetransmitBase: 500 * time.Millisecond,
+ RetransmitMax: 5 * time.Second,
+ RetransmitAttempts: 5,
+ CombineThreshold: 256,
+ EnableCompression: true,
+ EnableEncryption: true,
+ }
+ }
+ // PacketHandler processes application-level packets
+ type PacketHandler func(*Connection, *ApplicationPacket)
// Connection manages a single client connection over UDP with reliability features
type Connection struct {
// Network details
@ -43,7 +84,6 @@ type Connection struct {
// Sequence tracking for reliable delivery
nextInSeq uint16 // Next expected incoming sequence number
nextOutSeq uint16 // Next outgoing sequence number
- windowSize uint16 // Flow control window size
// Protocol components
retransmitQueue *RetransmitQueue // Handles packet retransmission
@ -53,25 +93,33 @@ type Connection struct {
crypto *Crypto // Handles encryption/decryption
// Connection timing
- lastPacketTime time.Time // Last received packet timestamp
- lastAckTime time.Time // Last acknowledgment timestamp
+ lastActivity time.Time // Last activity timestamp
+ // Configuration (embedded from server)
+ config Config
}
- // NewConnection creates a new connection instance with default settings
- func NewConnection(addr *net.UDPAddr, conn *net.UDPConn, handler PacketHandler) *Connection {
+ // NewConnection creates a new connection instance with server configuration
+ func NewConnection(addr *net.UDPAddr, conn *net.UDPConn, handler PacketHandler, config Config) *Connection {
return &Connection{
addr: addr,
conn: conn,
handler: handler,
state: StateClosed,
- maxLength: MaxPacketSize,
- windowSize: DefaultWindowSize,
- lastPacketTime: time.Now(),
- crypto: NewCrypto(),
- retransmitQueue: NewRetransmitQueue(),
- fragmentMgr: NewFragmentManager(MaxPacketSize),
- combiner: NewPacketCombiner(),
- outOfOrderMap: make(map[uint16]*ProtocolPacket),
+ maxLength: config.MaxPacketSize,
+ lastActivity: time.Now(),
+ config: config,
+ // Initialize components with config values
+ retransmitQueue: NewRetransmitQueue(
+ config.RetransmitBase,
+ config.RetransmitMax,
+ config.RetransmitAttempts,
+ ),
+ fragmentMgr: NewFragmentManager(config.MaxPacketSize),
+ combiner: NewPacketCombiner(config.CombineThreshold),
+ crypto: NewCrypto(),
+ outOfOrderMap: make(map[uint16]*ProtocolPacket),
}
}
@ -80,7 +128,7 @@ func (c *Connection) ProcessPacket(data []byte) {
c.mutex.Lock()
defer c.mutex.Unlock()
- c.lastPacketTime = time.Now()
+ c.lastActivity = time.Now()
packet, err := ParseProtocolPacket(data)
if err != nil {
@ -187,7 +235,7 @@ func (c *Connection) processInOrderPacket(seq uint16, payload []byte) {
// Process application data
if appPacket, err := c.processApplicationData(payload); err == nil {
- c.handler.HandlePacket(c, appPacket)
+ c.handler(c, appPacket)
}
// Check for queued out-of-order packets that can now be processed
@ -211,7 +259,7 @@ func (c *Connection) processQueuedPackets() {
c.sendAck(seq)
if appPacket, err := c.processApplicationData(payload); err == nil {
- c.handler.HandlePacket(c, appPacket)
+ c.handler(c, appPacket)
}
}
}
@ -221,7 +269,7 @@ func (c *Connection) processQueuedPackets() {
func (c *Connection) handleFragment(packet *ProtocolPacket) {
if data, complete, err := c.fragmentMgr.ProcessFragment(packet); err == nil && complete {
if appPacket, err := c.processApplicationData(data); err == nil {
- c.handler.HandlePacket(c, appPacket)
+ c.handler(c, appPacket)
}
}
}
@ -243,7 +291,6 @@ func (c *Connection) handleAck(packet *ProtocolPacket) {
seq := binary.BigEndian.Uint16(packet.Data[0:2])
c.retransmitQueue.Acknowledge(seq)
- c.lastAckTime = time.Now()
}
// handleOutOfOrderAck processes out-of-order acknowledgments
@ -290,14 +337,14 @@ func (c *Connection) SendPacket(packet *ApplicationPacket) {
// processOutboundData applies compression and encryption to outgoing data
func (c *Connection) processOutboundData(data []byte) []byte {
// Compress large packets if compression is enabled
- if c.compressed && len(data) > 128 {
+ if c.config.EnableCompression && c.compressed && len(data) > 128 {
if compressed, err := Compress(data); err == nil {
data = compressed
}
}
// Encrypt data if encryption is enabled
- if c.crypto.IsEncrypted() {
+ if c.config.EnableEncryption && c.crypto.IsEncrypted() {
data = c.crypto.Encrypt(data)
}
@ -307,12 +354,12 @@ func (c *Connection) processOutboundData(data []byte) []byte {
// processApplicationData decrypts and decompresses incoming application data
func (c *Connection) processApplicationData(data []byte) (*ApplicationPacket, error) {
// Decrypt if encryption is enabled
- if c.crypto.IsEncrypted() {
+ if c.config.EnableEncryption && c.crypto.IsEncrypted() {
data = c.crypto.Decrypt(data)
}
// Decompress if compression is enabled
- if c.compressed && len(data) > 0 {
+ if c.config.EnableCompression && c.compressed && len(data) > 0 {
var err error
data, err = Decompress(data)
if err != nil {
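Aside, an illustrative sketch (not part of the commit): the Compress and Decompress helpers called from processOutboundData and processApplicationData above can be exercised on their own; their error-returning signatures are inferred from those call sites, and the import path eq2emu/internal/udp is taken from the file paths in this commit.

package main

import (
	"bytes"
	"fmt"

	"eq2emu/internal/udp"
)

func main() {
	payload := bytes.Repeat([]byte("eq2emu "), 64) // well over the 128-byte compression threshold

	compressed, err := udp.Compress(payload)
	if err != nil {
		panic(err)
	}
	restored, err := udp.Decompress(compressed)
	if err != nil {
		panic(err)
	}
	fmt.Println(len(payload), len(compressed), bytes.Equal(payload, restored))
}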
@ -454,6 +501,36 @@ func (c *Connection) StartRetransmitLoop() {
}()
}
+ // Stats returns comprehensive connection statistics
+ type Stats struct {
+ // Connection info
+ State ConnectionState
+ SessionID uint32
+ LastActivity time.Time
+ // Queue stats
+ PendingRetransmits int
+ PendingFragments int
+ PendingCombined int
+ OutOfOrderCount int
+ }
+ // GetStats returns unified statistics
+ func (c *Connection) GetStats() Stats {
+ c.mutex.RLock()
+ defer c.mutex.RUnlock()
+ return Stats{
+ State: c.state,
+ SessionID: c.sessionID,
+ LastActivity: c.lastActivity,
+ PendingRetransmits: c.retransmitQueue.Size(),
+ PendingFragments: len(c.fragmentMgr.fragments),
+ PendingCombined: len(c.combiner.PendingPackets),
+ OutOfOrderCount: len(c.outOfOrderMap),
+ }
+ }
// GetState returns the current connection state (thread-safe)
func (c *Connection) GetState() ConnectionState {
c.mutex.RLock()
@ -469,8 +546,8 @@ func (c *Connection) GetSessionID() uint32 {
}
// IsTimedOut checks if connection has timed out
- func (c *Connection) IsTimedOut(timeout time.Duration) bool {
+ func (c *Connection) IsTimedOut() bool {
c.mutex.RLock()
defer c.mutex.RUnlock()
- return time.Since(c.lastPacketTime) > timeout
+ return time.Since(c.lastActivity) > c.config.Timeout
}
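Aside, a minimal sketch (not part of the commit) of how the function-typed PacketHandler, the Config-driven constructor, and the unified Stats accessor fit together; all names used are exported identifiers shown in this diff.

package main

import (
	"fmt"
	"net"

	"eq2emu/internal/udp"
)

func main() {
	// A handler is now just a function value, not an interface implementation.
	handler := func(conn *udp.Connection, packet *udp.ApplicationPacket) {
		fmt.Printf("opcode 0x%04X, %d bytes\n", packet.Opcode, len(packet.Data))
	}

	addr, _ := net.ResolveUDPAddr("udp", "127.0.0.1:0")
	sock, _ := net.ListenUDP("udp", addr)
	defer sock.Close()

	conn := udp.NewConnection(addr, sock, handler, udp.DefaultConfig())
	stats := conn.GetStats()
	fmt.Println(stats.State, stats.PendingRetransmits, conn.IsTimedOut())
}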

View File

@ -1,73 +0,0 @@
package udp
// EQ2EMu uses a specific CRC32 polynomial (reversed)
const crcPolynomial = 0xEDB88320
// Pre-computed CRC32 lookup table for fast calculation
var crcTable [256]uint32
// init builds the CRC lookup table at package initialization
func init() {
for i := range crcTable {
crc := uint32(i)
for range 8 {
if crc&1 == 1 {
crc = (crc >> 1) ^ crcPolynomial
} else {
crc >>= 1
}
}
crcTable[i] = crc
}
}
// CalculateCRC32 computes CRC32 using EQ2EMu's algorithm
// Returns 16-bit value by truncating the upper bits
func CalculateCRC32(data []byte) uint16 {
crc := uint32(0xFFFFFFFF)
// Use lookup table for efficient calculation
for _, b := range data {
crc = crcTable[byte(crc)^b] ^ (crc >> 8)
}
// Return inverted result truncated to 16 bits
return uint16(^crc)
}
// ValidateCRC checks if packet has valid CRC
// Expects CRC to be the last 2 bytes of data
func ValidateCRC(data []byte) bool {
if len(data) < 2 {
return false
}
// Split payload and CRC
payload := data[:len(data)-2]
expectedCRC := uint16(data[len(data)-2]) | (uint16(data[len(data)-1]) << 8)
// Calculate and compare
actualCRC := CalculateCRC32(payload)
return expectedCRC == actualCRC
}
// AppendCRC adds 16-bit CRC to the end of data
func AppendCRC(data []byte) []byte {
crc := CalculateCRC32(data)
result := make([]byte, len(data)+2)
copy(result, data)
// Append CRC in little-endian format
result[len(data)] = byte(crc)
result[len(data)+1] = byte(crc >> 8)
return result
}
// ValidateAndStrip validates CRC and returns data without CRC suffix
func ValidateAndStrip(data []byte) ([]byte, bool) {
if !ValidateCRC(data) {
return nil, false
}
return data[:len(data)-2], true
}

View File

@ -1,136 +0,0 @@
package udp
import (
"encoding/binary"
"eq2emu/internal/opcodes"
"errors"
"fmt"
)
// ProtocolPacket represents a low-level UDP protocol packet with opcode and payload
type ProtocolPacket struct {
Opcode uint8 // Protocol operation code (1-2 bytes when serialized)
Data []byte // Packet payload data
Raw []byte // Original raw packet data for debugging
}
// ApplicationPacket represents a higher-level game application packet
type ApplicationPacket struct {
Opcode uint16 // Application-level operation code
Data []byte // Application payload data
}
// ParseProtocolPacket parses raw UDP data into a ProtocolPacket
// Handles variable opcode sizing and CRC validation based on EQ2 protocol
func ParseProtocolPacket(data []byte) (*ProtocolPacket, error) {
if len(data) < 2 {
return nil, errors.New("packet too small for valid protocol packet")
}
var opcode uint8
var dataStart int
// EQ2 protocol uses 1-byte opcodes normally, 2-byte for opcodes >= 0xFF
// When opcode >= 0xFF, it's prefixed with 0x00
if data[0] == 0x00 && len(data) > 2 {
opcode = data[1]
dataStart = 2
} else {
opcode = data[0]
dataStart = 1
}
// Extract payload, handling CRC for non-session packets
var payload []byte
if requiresCRC(opcode) {
if len(data) < dataStart+2 {
return nil, errors.New("packet too small for CRC validation")
}
// Payload excludes the 2-byte CRC suffix
payload = data[dataStart : len(data)-2]
// Validate CRC on the entire packet from beginning
if !ValidateCRC(data) {
return nil, fmt.Errorf("CRC validation failed for opcode 0x%02X", opcode)
}
} else {
payload = data[dataStart:]
}
return &ProtocolPacket{
Opcode: opcode,
Data: payload,
Raw: data,
}, nil
}
// Serialize converts ProtocolPacket back to wire format with proper opcode encoding and CRC
func (p *ProtocolPacket) Serialize() []byte {
var result []byte
// Handle variable opcode encoding
if p.Opcode == 0xFF {
// 2-byte opcode format: [0x00][actual_opcode][data]
result = make([]byte, 2+len(p.Data))
result[0] = 0x00
result[1] = p.Opcode
copy(result[2:], p.Data)
} else {
// 1-byte opcode format: [opcode][data]
result = make([]byte, 1+len(p.Data))
result[0] = p.Opcode
copy(result[1:], p.Data)
}
// Add CRC for packets that require it
if requiresCRC(p.Opcode) {
result = AppendCRC(result)
}
return result
}
// ParseApplicationPacket parses application-level packet from decrypted/decompressed data
func ParseApplicationPacket(data []byte) (*ApplicationPacket, error) {
if len(data) < 2 {
return nil, errors.New("application packet requires at least 2 bytes for opcode")
}
// Application opcodes are always little-endian 16-bit values
opcode := binary.LittleEndian.Uint16(data[0:2])
return &ApplicationPacket{
Opcode: opcode,
Data: data[2:],
}, nil
}
// Serialize converts ApplicationPacket to byte array for transmission
func (p *ApplicationPacket) Serialize() []byte {
result := make([]byte, 2+len(p.Data))
binary.LittleEndian.PutUint16(result[0:2], p.Opcode)
copy(result[2:], p.Data)
return result
}
// String provides human-readable representation for debugging
func (p *ProtocolPacket) String() string {
return fmt.Sprintf("ProtocolPacket{Opcode: 0x%02X, DataLen: %d}", p.Opcode, len(p.Data))
}
// String provides human-readable representation for debugging
func (p *ApplicationPacket) String() string {
return fmt.Sprintf("ApplicationPacket{Opcode: 0x%04X, DataLen: %d}", p.Opcode, len(p.Data))
}
// requiresCRC determines if a protocol opcode requires CRC validation
// Session control packets (SessionRequest, SessionResponse, OutOfSession) don't use CRC
func requiresCRC(opcode uint8) bool {
switch opcode {
case opcodes.OpSessionRequest, opcodes.OpSessionResponse, opcodes.OpOutOfSession:
return false
default:
return true
}
}

internal/udp/protocol.go (new file, 317 lines)
View File

@ -0,0 +1,317 @@
package udp
import (
"bytes"
"encoding/binary"
"eq2emu/internal/opcodes"
"errors"
"fmt"
)
// Common protocol errors
var (
ErrPacketTooSmall = errors.New("packet too small")
ErrInvalidCRC = errors.New("invalid CRC")
ErrInvalidOpcode = errors.New("invalid opcode")
)
// ProtocolPacket represents a low-level UDP protocol packet with opcode and payload
type ProtocolPacket struct {
Opcode uint8 // Protocol operation code (1-2 bytes when serialized)
Data []byte // Packet payload data
Raw []byte // Original raw packet data for debugging
}
// ApplicationPacket represents a higher-level game application packet
type ApplicationPacket struct {
Opcode uint16 // Application-level operation code
Data []byte // Application payload data
}
// ParseProtocolPacket parses raw UDP data into a ProtocolPacket
// Handles variable opcode sizing and CRC validation based on EQ2 protocol
func ParseProtocolPacket(data []byte) (*ProtocolPacket, error) {
if len(data) < 2 {
return nil, ErrPacketTooSmall
}
var opcode uint8
var dataStart int
// EQ2 protocol uses 1-byte opcodes normally, 2-byte for opcodes >= 0xFF
// When opcode >= 0xFF, it's prefixed with 0x00
if data[0] == 0x00 && len(data) > 2 {
opcode = data[1]
dataStart = 2
} else {
opcode = data[0]
dataStart = 1
}
// Extract payload, handling CRC for non-session packets
var payload []byte
if requiresCRC(opcode) {
if len(data) < dataStart+2 {
return nil, ErrPacketTooSmall
}
// Payload excludes the 2-byte CRC suffix
payload = data[dataStart : len(data)-2]
// Validate CRC on the entire packet from beginning
if !ValidateCRC(data) {
return nil, fmt.Errorf("%w for opcode 0x%02X", ErrInvalidCRC, opcode)
}
} else {
payload = data[dataStart:]
}
return &ProtocolPacket{
Opcode: opcode,
Data: payload,
Raw: data,
}, nil
}
// Serialize converts ProtocolPacket back to wire format with proper opcode encoding and CRC
func (p *ProtocolPacket) Serialize() []byte {
var result []byte
// Handle variable opcode encoding
if p.Opcode == 0xFF {
// 2-byte opcode format: [0x00][actual_opcode][data]
result = make([]byte, 2+len(p.Data))
result[0] = 0x00
result[1] = p.Opcode
copy(result[2:], p.Data)
} else {
// 1-byte opcode format: [opcode][data]
result = make([]byte, 1+len(p.Data))
result[0] = p.Opcode
copy(result[1:], p.Data)
}
// Add CRC for packets that require it
if requiresCRC(p.Opcode) {
result = AppendCRC(result)
}
return result
}
// String provides human-readable representation for debugging
func (p *ProtocolPacket) String() string {
return fmt.Sprintf("ProtocolPacket{Opcode: 0x%02X, DataLen: %d}", p.Opcode, len(p.Data))
}
// ParseApplicationPacket parses application-level packet from decrypted/decompressed data
func ParseApplicationPacket(data []byte) (*ApplicationPacket, error) {
if len(data) < 2 {
return nil, errors.New("application packet requires at least 2 bytes for opcode")
}
// Application opcodes are always little-endian 16-bit values
opcode := binary.LittleEndian.Uint16(data[0:2])
return &ApplicationPacket{
Opcode: opcode,
Data: data[2:],
}, nil
}
// Serialize converts ApplicationPacket to byte array for transmission
func (p *ApplicationPacket) Serialize() []byte {
result := make([]byte, 2+len(p.Data))
binary.LittleEndian.PutUint16(result[0:2], p.Opcode)
copy(result[2:], p.Data)
return result
}
// String provides human-readable representation for debugging
func (p *ApplicationPacket) String() string {
return fmt.Sprintf("ApplicationPacket{Opcode: 0x%04X, DataLen: %d}", p.Opcode, len(p.Data))
}
// requiresCRC determines if a protocol opcode requires CRC validation
// Session control packets (SessionRequest, SessionResponse, OutOfSession) don't use CRC
func requiresCRC(opcode uint8) bool {
switch opcode {
case opcodes.OpSessionRequest, opcodes.OpSessionResponse, opcodes.OpOutOfSession:
return false
default:
return true
}
}
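Aside, a round-trip sketch (illustrative, not part of the file) of the framing above; it assumes OpSessionRequest is a nonzero one-byte opcode, which is also why no CRC suffix is involved here.

package main

import (
	"fmt"

	"eq2emu/internal/opcodes"
	"eq2emu/internal/udp"
)

func main() {
	// Session-control packets skip the CRC, so the wire form is just [opcode][payload].
	p := &udp.ProtocolPacket{Opcode: opcodes.OpSessionRequest, Data: []byte{0x01, 0x02}}
	wire := p.Serialize()
	fmt.Printf("% X\n", wire)

	// ParseProtocolPacket reverses the framing (and validates the CRC for data opcodes).
	back, err := udp.ParseProtocolPacket(wire)
	fmt.Println(back, err)
}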
// PacketCombiner groups small packets together to reduce UDP overhead
type PacketCombiner struct {
PendingPackets []*ProtocolPacket // Direct access to pending packets
MaxSize int // Direct access to max size
}
// NewPacketCombiner creates a combiner with specified max size
func NewPacketCombiner(maxSize int) *PacketCombiner {
return &PacketCombiner{
MaxSize: maxSize,
}
}
// Add queues a packet for potential combining
func (pc *PacketCombiner) Add(packet *ProtocolPacket) {
pc.PendingPackets = append(pc.PendingPackets, packet)
}
// Flush returns combined packets and clears the queue
func (pc *PacketCombiner) Flush() []*ProtocolPacket {
count := len(pc.PendingPackets)
if count == 0 {
return nil
}
if count == 1 {
// Single packet - no combining needed
packet := pc.PendingPackets[0]
pc.Clear()
return []*ProtocolPacket{packet}
}
// Combine multiple packets
combined := pc.combine()
pc.Clear()
return []*ProtocolPacket{combined}
}
// combine merges all pending packets into a single combined packet
func (pc *PacketCombiner) combine() *ProtocolPacket {
var buf bytes.Buffer
for _, packet := range pc.PendingPackets {
serialized := packet.Serialize()
pc.writeSizeHeader(&buf, len(serialized))
buf.Write(serialized)
}
return &ProtocolPacket{
Opcode: opcodes.OpCombined,
Data: buf.Bytes(),
}
}
// writeSizeHeader writes packet size using variable-length encoding
func (pc *PacketCombiner) writeSizeHeader(buf *bytes.Buffer, size int) {
if size >= 255 {
// Large packet - use 3-byte header [0xFF][low][high]
buf.WriteByte(0xFF)
buf.WriteByte(byte(size))
buf.WriteByte(byte(size >> 8))
} else {
// Small packet - use 1-byte header
buf.WriteByte(byte(size))
}
}
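As a worked example (not part of the file): a 48-byte sub-packet is prefixed with the single byte 0x30, while a 300-byte sub-packet (0x012C) is prefixed with the three bytes 0xFF 0x2C 0x01; readSizeHeader further down reverses both encodings.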
// ShouldCombine determines if packets should be combined based on total size
func (pc *PacketCombiner) ShouldCombine() bool {
if len(pc.PendingPackets) < 2 {
return false
}
totalSize := 0
for _, packet := range pc.PendingPackets {
serialized := packet.Serialize()
totalSize += len(serialized)
// Add size header overhead
if len(serialized) >= 255 {
totalSize += 3
} else {
totalSize += 1
}
}
return totalSize <= pc.MaxSize
}
// Clear removes all pending packets
func (pc *PacketCombiner) Clear() {
pc.PendingPackets = pc.PendingPackets[:0] // Reuse slice capacity
}
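Aside, a usage sketch (not part of the file) of the combiner API above together with ParseCombinedPacket defined below; session-request opcodes are used because they avoid the CRC path.

package main

import (
	"fmt"

	"eq2emu/internal/opcodes"
	"eq2emu/internal/udp"
)

func main() {
	combiner := udp.NewPacketCombiner(256)
	combiner.Add(&udp.ProtocolPacket{Opcode: opcodes.OpSessionRequest, Data: []byte("aa")})
	combiner.Add(&udp.ProtocolPacket{Opcode: opcodes.OpSessionRequest, Data: []byte("bb")})

	if combiner.ShouldCombine() {
		out := combiner.Flush() // a single OpCombined packet wrapping both
		parts, _ := udp.ParseCombinedPacket(out[0].Data)
		fmt.Println(len(out), len(parts)) // 1 2
	}
}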
// ParseCombinedPacket splits combined packet into individual packets
func ParseCombinedPacket(data []byte) ([]*ProtocolPacket, error) {
var packets []*ProtocolPacket
offset := 0
for offset < len(data) {
size, headerSize, err := readSizeHeader(data, offset)
if err != nil {
break
}
offset += headerSize
if offset+size > len(data) {
break // Incomplete packet
}
// Parse individual packet
packetData := data[offset : offset+size]
if packet, err := ParseProtocolPacket(packetData); err == nil {
packets = append(packets, packet)
}
offset += size
}
return packets, nil
}
// readSizeHeader reads variable-length size header
func readSizeHeader(data []byte, offset int) (size, headerSize int, err error) {
if offset >= len(data) {
return 0, 0, errors.New("insufficient data")
}
if data[offset] == 0xFF {
// 3-byte size header
if offset+2 >= len(data) {
return 0, 0, errors.New("insufficient data for 3-byte header")
}
size = int(data[offset+1]) | (int(data[offset+2]) << 8)
headerSize = 3
} else {
// 1-byte size header
size = int(data[offset])
headerSize = 1
}
return size, headerSize, nil
}
// ValidateCombinedPacket checks if combined packet data is well-formed
func ValidateCombinedPacket(data []byte) error {
offset := 0
count := 0
for offset < len(data) {
size, headerSize, err := readSizeHeader(data, offset)
if err != nil {
return err
}
offset += headerSize
if offset+size > len(data) {
return errors.New("packet extends beyond data boundary")
}
offset += size
count++
if count > 100 { // Sanity check
return errors.New("too many packets in combined packet")
}
}
return nil
}

View File

@ -6,12 +6,121 @@ import (
"errors"
"fmt"
"sort"
+ "sync"
+ "time"
)
- // FragmentManager handles packet fragmentation and reassembly
- type FragmentManager struct {
- fragments map[uint16]*FragmentGroup // Active fragment groups by base sequence
- maxLength uint32 // Maximum packet size before fragmentation
+ // Common errors for reliability layer
+ var (
+ ErrFragmentTimeout = errors.New("fragment timeout")
+ ErrOrphanedFragment = errors.New("orphaned fragment")
+ )
+ // RetransmitEntry tracks a packet awaiting acknowledgment
+ type RetransmitEntry struct {
+ Packet *ProtocolPacket // The packet to retransmit
+ Sequence uint16 // Packet sequence number
+ Timestamp time.Time // When packet was last sent
+ Attempts int // Number of transmission attempts
+ }
+ // RetransmitQueue manages reliable packet delivery with exponential backoff
+ type RetransmitQueue struct {
+ entries map[uint16]*RetransmitEntry // Pending packets by sequence
+ mutex sync.RWMutex // Thread-safe access
+ baseTimeout time.Duration // Base retransmission timeout
+ maxAttempts int // Maximum retry attempts
+ maxTimeout time.Duration // Maximum timeout cap
+ }
+ // NewRetransmitQueue creates a queue with specified settings
+ func NewRetransmitQueue(baseTimeout, maxTimeout time.Duration, maxAttempts int) *RetransmitQueue {
+ return &RetransmitQueue{
+ entries: make(map[uint16]*RetransmitEntry),
+ baseTimeout: baseTimeout,
+ maxAttempts: maxAttempts,
+ maxTimeout: maxTimeout,
+ }
+ }
+ // Add queues a packet for potential retransmission
+ func (rq *RetransmitQueue) Add(packet *ProtocolPacket, sequence uint16) {
+ rq.mutex.Lock()
+ defer rq.mutex.Unlock()
+ rq.entries[sequence] = &RetransmitEntry{
+ Packet: packet,
+ Sequence: sequence,
+ Timestamp: time.Now(),
+ Attempts: 1,
+ }
+ }
+ // Acknowledge removes a packet from the retransmit queue
+ func (rq *RetransmitQueue) Acknowledge(sequence uint16) bool {
+ rq.mutex.Lock()
+ defer rq.mutex.Unlock()
+ _, existed := rq.entries[sequence]
+ delete(rq.entries, sequence)
+ return existed
+ }
+ // GetExpired returns packets that need retransmission
+ func (rq *RetransmitQueue) GetExpired() []*RetransmitEntry {
+ rq.mutex.Lock()
+ defer rq.mutex.Unlock()
+ now := time.Now()
+ var expired []*RetransmitEntry
+ for seq, entry := range rq.entries {
+ timeout := rq.calculateTimeout(entry.Attempts)
+ if now.Sub(entry.Timestamp) > timeout {
+ if entry.Attempts >= rq.maxAttempts {
+ // Give up after max attempts
+ delete(rq.entries, seq)
+ } else {
+ // Schedule for retransmission
+ entry.Attempts++
+ entry.Timestamp = now
+ expired = append(expired, entry)
+ }
+ }
+ }
+ return expired
+ }
+ // calculateTimeout computes timeout with exponential backoff
+ func (rq *RetransmitQueue) calculateTimeout(attempts int) time.Duration {
+ timeout := rq.baseTimeout * time.Duration(attempts*attempts) // Quadratic backoff
+ if timeout > rq.maxTimeout {
+ timeout = rq.maxTimeout
+ }
+ return timeout
+ }
+ // Clear removes all pending packets
+ func (rq *RetransmitQueue) Clear() {
+ rq.mutex.Lock()
+ defer rq.mutex.Unlock()
+ rq.entries = make(map[uint16]*RetransmitEntry)
+ }
+ // Size returns the number of pending packets
+ func (rq *RetransmitQueue) Size() int {
+ rq.mutex.RLock()
+ defer rq.mutex.RUnlock()
+ return len(rq.entries)
+ }
+ // IsEmpty returns true if no packets are pending
+ func (rq *RetransmitQueue) IsEmpty() bool {
+ rq.mutex.RLock()
+ defer rq.mutex.RUnlock()
+ return len(rq.entries) == 0
}
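Aside, an illustrative calculation (not code from the commit) of the schedule this quadratic backoff produces with the DefaultConfig values of a 500 ms base, 5 s cap, and 5 attempts.

package main

import (
	"fmt"
	"time"
)

func main() {
	base, maxTimeout := 500*time.Millisecond, 5*time.Second // DefaultConfig values
	for attempts := 1; attempts <= 5; attempts++ {
		timeout := base * time.Duration(attempts*attempts) // same quadratic rule as calculateTimeout
		if timeout > maxTimeout {
			timeout = maxTimeout
		}
		fmt.Printf("attempt %d: wait %v before retransmitting\n", attempts, timeout)
	}
	// Prints 500ms, 2s, 4.5s, then the 5s cap for attempts 4 and 5.
}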
// FragmentGroup tracks fragments belonging to the same original packet
@ -29,6 +138,12 @@ type FragmentPiece struct {
IsFirst bool // Whether this is the first fragment
}
+ // FragmentManager handles packet fragmentation and reassembly
+ type FragmentManager struct {
+ fragments map[uint16]*FragmentGroup // Active fragment groups by base sequence
+ maxLength uint32 // Maximum packet size before fragmentation
+ }
// NewFragmentManager creates a manager with specified maximum packet length
func NewFragmentManager(maxLength uint32) *FragmentManager {
return &FragmentManager{
@ -86,7 +201,7 @@ func (fm *FragmentManager) FragmentPacket(data []byte, startSeq uint16) []*Proto
// ProcessFragment handles incoming fragments and returns complete packet when ready
func (fm *FragmentManager) ProcessFragment(packet *ProtocolPacket) ([]byte, bool, error) {
if len(packet.Data) < 2 {
- return nil, false, errors.New("fragment too small")
+ return nil, false, ErrPacketTooSmall
}
seq := binary.BigEndian.Uint16(packet.Data[0:2])
@ -118,7 +233,7 @@ func (fm *FragmentManager) ProcessFragment(packet *ProtocolPacket) ([]byte, bool
fragment.Data = packet.Data[2:]
group := fm.findFragmentGroup(seq)
if group == nil {
- return nil, false, errors.New("orphaned fragment")
+ return nil, false, ErrOrphanedFragment
}
group.Fragments = append(group.Fragments, fragment)
@ -186,26 +301,7 @@ func (fm *FragmentManager) CleanupStale(maxAge uint16) {
}
}
- // GetStats returns fragmentation statistics
- func (fm *FragmentManager) GetStats() FragmentStats {
- return FragmentStats{
- ActiveGroups: len(fm.fragments),
- MaxLength: fm.maxLength,
- }
- }
- // FragmentStats contains fragmentation statistics
- type FragmentStats struct {
- ActiveGroups int // Number of incomplete fragment groups
- MaxLength uint32 // Maximum packet length setting
- }
// Clear removes all pending fragments
func (fm *FragmentManager) Clear() {
fm.fragments = make(map[uint16]*FragmentGroup)
}
- // SetMaxLength updates the maximum packet length
- func (fm *FragmentManager) SetMaxLength(maxLength uint32) {
- fm.maxLength = maxLength
- }
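Aside, an end-to-end sketch (illustrative only) of the fragmentation API above; it assumes a single FragmentManager can reassemble the fragments it produced, and the 4 KB payload and start sequence of 100 are arbitrary.

package main

import (
	"bytes"
	"fmt"

	"eq2emu/internal/udp"
)

func main() {
	fm := udp.NewFragmentManager(512)
	payload := bytes.Repeat([]byte{0xAB}, 4096)

	// Split a large payload into protocol-level fragments starting at sequence 100.
	fragments := fm.FragmentPacket(payload, 100)

	// Feed them back in; complete flips to true on the final piece.
	var rebuilt []byte
	for _, frag := range fragments {
		if data, complete, err := fm.ProcessFragment(frag); err == nil && complete {
			rebuilt = data
		}
	}
	fmt.Println(len(fragments), bytes.Equal(payload, rebuilt))
}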

View File

@ -1,190 +0,0 @@
package udp
import (
"sync"
"time"
)
// RetransmitEntry tracks a packet awaiting acknowledgment
type RetransmitEntry struct {
Packet *ProtocolPacket // The packet to retransmit
Sequence uint16 // Packet sequence number
Timestamp time.Time // When packet was last sent
Attempts int // Number of transmission attempts
}
// RetransmitQueue manages reliable packet delivery with exponential backoff
type RetransmitQueue struct {
entries map[uint16]*RetransmitEntry // Pending packets by sequence
mutex sync.RWMutex // Thread-safe access
baseTimeout time.Duration // Base retransmission timeout
maxAttempts int // Maximum retry attempts
maxTimeout time.Duration // Maximum timeout cap
}
// NewRetransmitQueue creates a queue with default settings
func NewRetransmitQueue() *RetransmitQueue {
return &RetransmitQueue{
entries: make(map[uint16]*RetransmitEntry),
baseTimeout: 500 * time.Millisecond,
maxAttempts: 5,
maxTimeout: 5 * time.Second,
}
}
// NewRetransmitQueueWithConfig creates a queue with custom settings
func NewRetransmitQueueWithConfig(baseTimeout, maxTimeout time.Duration, maxAttempts int) *RetransmitQueue {
return &RetransmitQueue{
entries: make(map[uint16]*RetransmitEntry),
baseTimeout: baseTimeout,
maxAttempts: maxAttempts,
maxTimeout: maxTimeout,
}
}
// Add queues a packet for potential retransmission
func (rq *RetransmitQueue) Add(packet *ProtocolPacket, sequence uint16) {
rq.mutex.Lock()
defer rq.mutex.Unlock()
rq.entries[sequence] = &RetransmitEntry{
Packet: packet,
Sequence: sequence,
Timestamp: time.Now(),
Attempts: 1,
}
}
// Acknowledge removes a packet from the retransmit queue
func (rq *RetransmitQueue) Acknowledge(sequence uint16) bool {
rq.mutex.Lock()
defer rq.mutex.Unlock()
_, existed := rq.entries[sequence]
delete(rq.entries, sequence)
return existed
}
// GetExpired returns packets that need retransmission
func (rq *RetransmitQueue) GetExpired() []*RetransmitEntry {
rq.mutex.Lock()
defer rq.mutex.Unlock()
now := time.Now()
var expired []*RetransmitEntry
for seq, entry := range rq.entries {
timeout := rq.calculateTimeout(entry.Attempts)
if now.Sub(entry.Timestamp) > timeout {
if entry.Attempts >= rq.maxAttempts {
// Give up after max attempts
delete(rq.entries, seq)
} else {
// Schedule for retransmission
entry.Attempts++
entry.Timestamp = now
expired = append(expired, entry)
}
}
}
return expired
}
// calculateTimeout computes timeout with exponential backoff
func (rq *RetransmitQueue) calculateTimeout(attempts int) time.Duration {
timeout := rq.baseTimeout * time.Duration(attempts*attempts) // Quadratic backoff
if timeout > rq.maxTimeout {
timeout = rq.maxTimeout
}
return timeout
}
// Clear removes all pending packets
func (rq *RetransmitQueue) Clear() {
rq.mutex.Lock()
defer rq.mutex.Unlock()
rq.entries = make(map[uint16]*RetransmitEntry)
}
// Size returns the number of pending packets
func (rq *RetransmitQueue) Size() int {
rq.mutex.RLock()
defer rq.mutex.RUnlock()
return len(rq.entries)
}
// GetPendingSequences returns all sequence numbers awaiting acknowledgment
func (rq *RetransmitQueue) GetPendingSequences() []uint16 {
rq.mutex.RLock()
defer rq.mutex.RUnlock()
sequences := make([]uint16, 0, len(rq.entries))
for seq := range rq.entries {
sequences = append(sequences, seq)
}
return sequences
}
// IsEmpty returns true if no packets are pending
func (rq *RetransmitQueue) IsEmpty() bool {
rq.mutex.RLock()
defer rq.mutex.RUnlock()
return len(rq.entries) == 0
}
// SetBaseTimeout updates the base retransmission timeout
func (rq *RetransmitQueue) SetBaseTimeout(timeout time.Duration) {
rq.mutex.Lock()
defer rq.mutex.Unlock()
rq.baseTimeout = timeout
}
// SetMaxAttempts updates the maximum retry attempts
func (rq *RetransmitQueue) SetMaxAttempts(attempts int) {
rq.mutex.Lock()
defer rq.mutex.Unlock()
rq.maxAttempts = attempts
}
// SetMaxTimeout updates the maximum timeout cap
func (rq *RetransmitQueue) SetMaxTimeout(timeout time.Duration) {
rq.mutex.Lock()
defer rq.mutex.Unlock()
rq.maxTimeout = timeout
}
// GetStats returns retransmission statistics
func (rq *RetransmitQueue) GetStats() RetransmitStats {
rq.mutex.RLock()
defer rq.mutex.RUnlock()
stats := RetransmitStats{
PendingCount: len(rq.entries),
BaseTimeout: rq.baseTimeout,
MaxAttempts: rq.maxAttempts,
MaxTimeout: rq.maxTimeout,
}
// Calculate attempt distribution
for _, entry := range rq.entries {
if entry.Attempts == 1 {
stats.FirstAttempts++
} else {
stats.Retransmissions++
}
}
return stats
}
// RetransmitStats contains retransmission queue statistics
type RetransmitStats struct {
PendingCount int // Total pending packets
FirstAttempts int // Packets on first attempt
Retransmissions int // Packets being retransmitted
BaseTimeout time.Duration // Base timeout setting
MaxAttempts int // Maximum attempts setting
MaxTimeout time.Duration // Maximum timeout setting
}

View File

@ -2,9 +2,81 @@ package udp
import (
"crypto/rc4"
- "fmt"
+ "errors"
)
+ // EQ2EMu CRC32 polynomial (reversed)
+ const crcPolynomial = 0xEDB88320
+ // Pre-computed CRC32 lookup table for fast calculation
+ var crcTable [256]uint32
+ // init builds the CRC lookup table at package initialization
+ func init() {
+ for i := range crcTable {
+ crc := uint32(i)
+ for range 8 {
+ if crc&1 == 1 {
+ crc = (crc >> 1) ^ crcPolynomial
+ } else {
+ crc >>= 1
+ }
+ }
+ crcTable[i] = crc
+ }
+ }
+ // CalculateCRC32 computes CRC32 using EQ2EMu's algorithm
+ // Returns 16-bit value by truncating the upper bits
+ func CalculateCRC32(data []byte) uint16 {
+ crc := uint32(0xFFFFFFFF)
+ // Use lookup table for efficient calculation
+ for _, b := range data {
+ crc = crcTable[byte(crc)^b] ^ (crc >> 8)
+ }
+ // Return inverted result truncated to 16 bits
+ return uint16(^crc)
+ }
+ // ValidateCRC checks if packet has valid CRC
+ // Expects CRC to be the last 2 bytes of data
+ func ValidateCRC(data []byte) bool {
+ if len(data) < 2 {
+ return false
+ }
+ // Split payload and CRC
+ payload := data[:len(data)-2]
+ expectedCRC := uint16(data[len(data)-2]) | (uint16(data[len(data)-1]) << 8)
+ // Calculate and compare
+ actualCRC := CalculateCRC32(payload)
+ return expectedCRC == actualCRC
+ }
+ // AppendCRC adds 16-bit CRC to the end of data
+ func AppendCRC(data []byte) []byte {
+ crc := CalculateCRC32(data)
+ result := make([]byte, len(data)+2)
+ copy(result, data)
+ // Append CRC in little-endian format
+ result[len(data)] = byte(crc)
+ result[len(data)+1] = byte(crc >> 8)
+ return result
+ }
+ // ValidateAndStrip validates CRC and returns data without CRC suffix
+ func ValidateAndStrip(data []byte) ([]byte, bool) {
+ if !ValidateCRC(data) {
+ return nil, false
+ }
+ return data[:len(data)-2], true
+ }
// Crypto handles RC4 encryption/decryption for EQ2EMu protocol
type Crypto struct {
clientCipher *rc4.Cipher // Cipher for decrypting client data
@ -15,27 +87,25 @@ type Crypto struct {
// NewCrypto creates a new crypto instance with encryption disabled
func NewCrypto() *Crypto {
- return &Crypto{
- encrypted: false,
- }
+ return &Crypto{}
}
// SetKey initializes RC4 encryption with the given key
// Creates separate ciphers for client and server with 20-byte priming
func (c *Crypto) SetKey(key []byte) error {
if len(key) == 0 {
- return fmt.Errorf("encryption key cannot be empty")
+ return errors.New("encryption key cannot be empty")
}
// Create separate RC4 ciphers for bidirectional communication
clientCipher, err := rc4.NewCipher(key)
if err != nil {
- return fmt.Errorf("failed to create client cipher: %w", err)
+ return err
}
serverCipher, err := rc4.NewCipher(key)
if err != nil {
- return fmt.Errorf("failed to create server cipher: %w", err)
+ return err
}
// Prime both ciphers with 20 dummy bytes per EQ2EMu protocol
@ -98,12 +168,3 @@ func (c *Crypto) Reset() {
c.key = nil
c.encrypted = false
}
- // Clone creates a copy of the crypto instance with the same key
- func (c *Crypto) Clone() (*Crypto, error) {
- newCrypto := NewCrypto()
- if c.encrypted && c.key != nil {
- return newCrypto, newCrypto.SetKey(c.key)
- }
- return newCrypto, nil
- }
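Aside, a small sketch (not from the commit) exercising the relocated CRC helpers and the Crypto type; the 16-byte all-zero key is a placeholder.

package main

import (
	"fmt"

	"eq2emu/internal/udp"
)

func main() {
	// CRC: append a 16-bit checksum, then validate and strip it again.
	framed := udp.AppendCRC([]byte("hello"))
	payload, ok := udp.ValidateAndStrip(framed)
	fmt.Println(ok, string(payload))

	// RC4: SetKey primes separate client and server ciphers (20 dummy bytes each).
	crypto := udp.NewCrypto()
	if err := crypto.SetKey(make([]byte, 16)); err == nil {
		fmt.Println("encryption enabled:", crypto.IsEncrypted())
	}
}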

View File

@ -14,40 +14,16 @@ type Server struct {
mutex sync.RWMutex // Protects connections map
handler PacketHandler // Application packet handler
running bool // Server running state
- // Configuration
- maxConnections int // Maximum concurrent connections
- timeout time.Duration // Connection timeout duration
+ config Config // Server configuration
}
- // PacketHandler processes application-level packets for connections
- type PacketHandler interface {
- HandlePacket(conn *Connection, packet *ApplicationPacket)
- }
- // ServerConfig holds server configuration options
- type ServerConfig struct {
- MaxConnections int // Maximum concurrent connections (default: 1000)
- Timeout time.Duration // Connection timeout (default: 45s)
- BufferSize int // UDP receive buffer size (default: 8192)
- }
- // DefaultServerConfig returns sensible default configuration
- func DefaultServerConfig() ServerConfig {
- return ServerConfig{
- MaxConnections: 1000,
- Timeout: 45 * time.Second,
- BufferSize: 8192,
- }
- }
- // NewServer creates a new UDP server instance
- func NewServer(addr string, handler PacketHandler) (*Server, error) {
- return NewServerWithConfig(addr, handler, DefaultServerConfig())
- }
- // NewServerWithConfig creates a server with custom configuration
- func NewServerWithConfig(addr string, handler PacketHandler, config ServerConfig) (*Server, error) {
+ // NewServer creates a UDP server with simplified configuration
+ func NewServer(addr string, handler PacketHandler, config ...Config) (*Server, error) {
+ cfg := DefaultConfig()
+ if len(config) > 0 {
+ cfg = config[0]
+ }
udpAddr, err := net.ResolveUDPAddr("udp", addr)
if err != nil {
return nil, fmt.Errorf("invalid UDP address %s: %w", addr, err)
@ -59,17 +35,16 @@ func NewServerWithConfig(addr string, handler PacketHandler, config ServerConfig
}
// Set socket buffer size for better performance
- if config.BufferSize > 0 {
- conn.SetReadBuffer(config.BufferSize)
- conn.SetWriteBuffer(config.BufferSize)
+ if cfg.BufferSize > 0 {
+ conn.SetReadBuffer(cfg.BufferSize)
+ conn.SetWriteBuffer(cfg.BufferSize)
}
return &Server{
conn: conn,
connections: make(map[string]*Connection),
handler: handler,
- maxConnections: config.MaxConnections,
- timeout: config.Timeout,
+ config: cfg,
}, nil
}
@ -127,12 +102,12 @@ func (s *Server) handleIncomingPacket(data []byte, addr *net.UDPAddr) {
conn, exists := s.connections[connKey]
if !exists {
// Check connection limit
- if len(s.connections) >= s.maxConnections {
+ if len(s.connections) >= s.config.MaxConnections {
s.mutex.Unlock()
return // Drop packet if at capacity
}
- conn = NewConnection(addr, s.conn, s.handler)
+ conn = NewConnection(addr, s.conn, s.handler, s.config)
conn.StartRetransmitLoop()
s.connections[connKey] = conn
}
@ -163,7 +138,7 @@ func (s *Server) cleanupTimedOutConnections() {
defer s.mutex.Unlock()
for key, conn := range s.connections {
- if conn.IsTimedOut(s.timeout) {
+ if conn.IsTimedOut() {
conn.Close()
delete(s.connections, key)
}
@ -219,19 +194,6 @@ func (s *Server) DisconnectClient(addr string) bool {
return false
}
- // GetStats returns server statistics
- func (s *Server) GetStats() ServerStats {
- s.mutex.RLock()
- defer s.mutex.RUnlock()
- return ServerStats{
- ConnectionCount: len(s.connections),
- MaxConnections: s.maxConnections,
- Running: s.running,
- Timeout: s.timeout,
- }
- }
// ServerStats contains server runtime statistics
type ServerStats struct {
ConnectionCount int // Current number of connections
@ -240,16 +202,29 @@ type ServerStats struct {
Timeout time.Duration // Connection timeout setting
}
+ // GetStats returns server statistics
+ func (s *Server) GetStats() ServerStats {
+ s.mutex.RLock()
+ defer s.mutex.RUnlock()
+ return ServerStats{
+ ConnectionCount: len(s.connections),
+ MaxConnections: s.config.MaxConnections,
+ Running: s.running,
+ Timeout: s.config.Timeout,
+ }
+ }
// SetConnectionLimit updates the maximum connection limit
func (s *Server) SetConnectionLimit(limit int) {
s.mutex.Lock()
defer s.mutex.Unlock()
- s.maxConnections = limit
+ s.config.MaxConnections = limit
}
// SetTimeout updates the connection timeout duration
func (s *Server) SetTimeout(timeout time.Duration) {
s.mutex.Lock()
defer s.mutex.Unlock()
- s.timeout = timeout
+ s.config.Timeout = timeout
}
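Aside, a minimal usage sketch of the simplified server API (not part of the commit); the ports and tuning values are arbitrary, and the server's start/stop lifecycle is outside this diff, so it is omitted.

package main

import (
	"log"

	"eq2emu/internal/udp"
)

func main() {
	handler := func(conn *udp.Connection, packet *udp.ApplicationPacket) {
		log.Printf("packet 0x%04X (%d bytes)", packet.Opcode, len(packet.Data))
	}

	// Defaults only.
	server, err := udp.NewServer(":9000", handler)
	if err != nil {
		log.Fatal(err)
	}

	// Override only the knobs you care about via the variadic Config.
	cfg := udp.DefaultConfig()
	cfg.MaxConnections = 64
	cfg.EnableCompression = false
	tuned, err := udp.NewServer(":9001", handler, cfg)
	if err != nil {
		log.Fatal(err)
	}

	log.Println(server.GetStats().MaxConnections, tuned.GetStats().MaxConnections) // 1000 64
}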

View File

@ -49,10 +49,14 @@ func (h *TestHandler) Clear() {
h.receivedPackets = nil
}
+ // Simple handler function for testing
+ func testHandler(conn *Connection, packet *ApplicationPacket) {
+ fmt.Printf("Test handler received packet opcode: 0x%04X\n", packet.Opcode)
+ }
// TestServer tests basic server creation and startup
func TestServer(t *testing.T) {
- handler := &TestHandler{}
- server, err := NewServer(":9999", handler)
+ server, err := NewServer(":9999", testHandler)
if err != nil {
t.Fatalf("Failed to create server: %v", err)
}
@ -77,14 +81,12 @@ func TestServer(t *testing.T) {
// TestServerConfig tests server configuration options
func TestServerConfig(t *testing.T) {
- handler := &TestHandler{}
- config := ServerConfig{
- MaxConnections: 10,
- Timeout: 30 * time.Second,
- BufferSize: 4096,
- }
+ config := DefaultConfig()
+ config.MaxConnections = 10
+ config.Timeout = 30 * time.Second
+ config.BufferSize = 4096
- server, err := NewServerWithConfig(":9998", handler, config)
+ server, err := NewServer(":9998", testHandler, config)
if err != nil {
t.Fatalf("Failed to create server with config: %v", err)
}
@ -217,7 +219,8 @@ func TestCrypto(t *testing.T) {
// TestRetransmitQueue tests packet retransmission logic
func TestRetransmitQueue(t *testing.T) {
- rq := NewRetransmitQueue()
+ config := DefaultConfig()
+ rq := NewRetransmitQueue(config.RetransmitBase, config.RetransmitMax, config.RetransmitAttempts)
packet := &ProtocolPacket{
Opcode: opcodes.OpPacket,
@ -296,21 +299,21 @@ func TestFragmentation(t *testing.T) {
// TestPacketCombining tests packet combination functionality
func TestPacketCombining(t *testing.T) {
- combiner := NewPacketCombiner()
+ combiner := NewPacketCombiner(256)
// Add small packets - use session opcodes that don't require CRC
packet1 := &ProtocolPacket{Opcode: opcodes.OpSessionRequest, Data: []byte("test1")}
packet2 := &ProtocolPacket{Opcode: opcodes.OpSessionRequest, Data: []byte("test2")}
- combiner.AddPacket(packet1)
- combiner.AddPacket(packet2)
+ combiner.Add(packet1)
+ combiner.Add(packet2)
- if combiner.GetPendingCount() != 2 {
- t.Errorf("Expected 2 pending packets, got %d", combiner.GetPendingCount())
+ if len(combiner.PendingPackets) != 2 {
+ t.Errorf("Expected 2 pending packets, got %d", len(combiner.PendingPackets))
}
// Flush combined
- combined := combiner.FlushCombined()
+ combined := combiner.Flush()
if len(combined) != 1 {
t.Errorf("Expected 1 combined packet, got %d", len(combined))
}
@ -330,12 +333,12 @@ func TestPacketCombining(t *testing.T) {
// TestConnection tests basic connection functionality
func TestConnection(t *testing.T) {
- handler := &TestHandler{}
+ config := DefaultConfig()
addr, _ := net.ResolveUDPAddr("udp", "127.0.0.1:0")
conn, _ := net.ListenUDP("udp", addr)
defer conn.Close()
- connection := NewConnection(addr, conn, handler)
+ connection := NewConnection(addr, conn, testHandler, config)
if connection.GetState() != StateClosed {
t.Error("New connection should be in closed state")
@ -346,9 +349,37 @@ func TestConnection(t *testing.T) {
t.Error("New connection should have session ID 0")
}
- // Test timeout
- if !connection.IsTimedOut(time.Nanosecond) {
- t.Error("New connection should be timed out with very short timeout")
+ // Test timeout with very short timeout config
+ shortConfig := DefaultConfig()
+ shortConfig.Timeout = time.Nanosecond
+ shortConnection := NewConnection(addr, conn, testHandler, shortConfig)
+ // Wait a bit to ensure timeout
+ time.Sleep(time.Millisecond)
+ if !shortConnection.IsTimedOut() {
+ t.Error("Connection should be timed out with very short timeout")
+ }
+ }
+ // TestDefaultConfig tests the default configuration
+ func TestDefaultConfig(t *testing.T) {
+ config := DefaultConfig()
+ if config.MaxConnections != 1000 {
+ t.Errorf("Expected MaxConnections 1000, got %d", config.MaxConnections)
+ }
+ if config.Timeout != 45*time.Second {
+ t.Errorf("Expected Timeout 45s, got %v", config.Timeout)
+ }
+ if config.MaxPacketSize != 512 {
+ t.Errorf("Expected MaxPacketSize 512, got %d", config.MaxPacketSize)
+ }
+ if !config.EnableCompression {
+ t.Error("Expected compression to be enabled by default")
+ }
+ if !config.EnableEncryption {
+ t.Error("Expected encryption to be enabled by default")
}
}
@ -398,8 +429,7 @@ func BenchmarkEncryption(b *testing.B) {
// TestIntegration performs a basic integration test
func TestIntegration(t *testing.T) {
- handler := &TestHandler{}
- server, err := NewServer(":0", handler) // Use any available port
+ server, err := NewServer(":0", testHandler) // Use any available port
if err != nil {
t.Fatalf("Failed to create server: %v", err)
}
@ -420,3 +450,28 @@ func TestIntegration(t *testing.T) {
t.Errorf("Expected 0 connections, got %d", stats.ConnectionCount)
}
}
// TestDirectFieldAccess tests that we can access fields directly
func TestDirectFieldAccess(t *testing.T) {
// Test PacketCombiner direct access
combiner := NewPacketCombiner(256)
combiner.MaxSize = 512 // Direct field modification
if combiner.MaxSize != 512 {
t.Errorf("Expected MaxSize 512, got %d", combiner.MaxSize)
}
// Test adding packets and accessing them directly
packet := &ProtocolPacket{Opcode: opcodes.OpKeepAlive, Data: []byte("test")}
combiner.Add(packet)
if len(combiner.PendingPackets) != 1 {
t.Errorf("Expected 1 pending packet, got %d", len(combiner.PendingPackets))
}
// Direct access to pending packets
firstPacket := combiner.PendingPackets[0]
if firstPacket.Opcode != opcodes.OpKeepAlive {
t.Errorf("Expected OpKeepAlive, got 0x%02X", firstPacket.Opcode)
}
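The suite above runs with the standard toolchain, e.g. go test -v ./internal/udp/ from the repository root (the package path is taken from the file names in this commit).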
}