// eq2go/internal/udp/reliability.go

package udp

import (
	"encoding/binary"
	"errors"
	"fmt"
	"sort"
	"sync"
	"time"

	"eq2emu/internal/opcodes"
)

// Common errors for the reliability layer
var (
	ErrFragmentTimeout  = errors.New("fragment timeout")
	ErrOrphanedFragment = errors.New("orphaned fragment")
)

// RetransmitEntry tracks a packet awaiting acknowledgment
type RetransmitEntry struct {
	Packet    *ProtocolPacket // The packet to retransmit
	Sequence  uint16          // Packet sequence number
	Timestamp time.Time       // When packet was last sent
	Attempts  int             // Number of transmission attempts
}

// RetransmitQueue manages reliable packet delivery with quadratic backoff
type RetransmitQueue struct {
	entries     map[uint16]*RetransmitEntry // Pending packets by sequence
	mutex       sync.RWMutex                // Thread-safe access
	baseTimeout time.Duration               // Base retransmission timeout
	maxAttempts int                         // Maximum retry attempts
	maxTimeout  time.Duration               // Maximum timeout cap
}

// NewRetransmitQueue creates a queue with specified settings
func NewRetransmitQueue(baseTimeout, maxTimeout time.Duration, maxAttempts int) *RetransmitQueue {
	return &RetransmitQueue{
		entries:     make(map[uint16]*RetransmitEntry),
		baseTimeout: baseTimeout,
		maxAttempts: maxAttempts,
		maxTimeout:  maxTimeout,
	}
}
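
// Typical usage (illustrative sketch; sendPacket, pkt, seq, and ackedSeq are
// hypothetical caller-side pieces, not part of this package):
//
//	rq := NewRetransmitQueue(500*time.Millisecond, 5*time.Second, 5)
//
//	// When sending a reliable packet:
//	rq.Add(pkt, seq)
//
//	// When an acknowledgment arrives:
//	rq.Acknowledge(ackedSeq)
//
//	// Periodically (e.g. from a ticker), resend anything that timed out:
//	for _, entry := range rq.GetExpired() {
//		sendPacket(entry.Packet)
//	}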

// Add queues a packet for potential retransmission
func (rq *RetransmitQueue) Add(packet *ProtocolPacket, sequence uint16) {
	rq.mutex.Lock()
	defer rq.mutex.Unlock()
	rq.entries[sequence] = &RetransmitEntry{
		Packet:    packet,
		Sequence:  sequence,
		Timestamp: time.Now(),
		Attempts:  1,
	}
}

// Acknowledge removes a packet from the retransmit queue
func (rq *RetransmitQueue) Acknowledge(sequence uint16) bool {
	rq.mutex.Lock()
	defer rq.mutex.Unlock()
	_, existed := rq.entries[sequence]
	delete(rq.entries, sequence)
	return existed
}

// GetExpired returns packets that need retransmission
func (rq *RetransmitQueue) GetExpired() []*RetransmitEntry {
	rq.mutex.Lock()
	defer rq.mutex.Unlock()
	now := time.Now()
	var expired []*RetransmitEntry
	for seq, entry := range rq.entries {
		timeout := rq.calculateTimeout(entry.Attempts)
		if now.Sub(entry.Timestamp) > timeout {
			if entry.Attempts >= rq.maxAttempts {
				// Give up after max attempts
				delete(rq.entries, seq)
			} else {
				// Schedule for retransmission
				entry.Attempts++
				entry.Timestamp = now
				expired = append(expired, entry)
			}
		}
	}
	return expired
}

// calculateTimeout computes the retransmission timeout with quadratic backoff
// (baseTimeout * attempts^2), capped at maxTimeout
func (rq *RetransmitQueue) calculateTimeout(attempts int) time.Duration {
	timeout := rq.baseTimeout * time.Duration(attempts*attempts) // Quadratic backoff
	if timeout > rq.maxTimeout {
		timeout = rq.maxTimeout
	}
	return timeout
}
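
// As a worked example (the values are illustrative, not defaults from this
// file): with a 500ms base and a 5s cap, attempt 1 waits 500ms, attempt 2
// waits 2s, attempt 3 waits 4.5s, and attempt 4 onward is capped at 5s.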

// Clear removes all pending packets
func (rq *RetransmitQueue) Clear() {
	rq.mutex.Lock()
	defer rq.mutex.Unlock()
	rq.entries = make(map[uint16]*RetransmitEntry)
}

// Size returns the number of pending packets
func (rq *RetransmitQueue) Size() int {
	rq.mutex.RLock()
	defer rq.mutex.RUnlock()
	return len(rq.entries)
}

// IsEmpty returns true if no packets are pending
func (rq *RetransmitQueue) IsEmpty() bool {
	rq.mutex.RLock()
	defer rq.mutex.RUnlock()
	return len(rq.entries) == 0
}

// FragmentGroup tracks fragments belonging to the same original packet
type FragmentGroup struct {
	BaseSequence uint16          // Base sequence number
	TotalLength  uint32          // Expected total reassembled length
	Fragments    []FragmentPiece // Individual fragment pieces
	FirstSeen    bool            // Whether we've seen the first fragment
}

// FragmentPiece represents a single fragment
type FragmentPiece struct {
	Sequence uint16 // Fragment sequence number
	Data     []byte // Fragment payload data
	IsFirst  bool   // Whether this is the first fragment
}

// FragmentManager handles packet fragmentation and reassembly.
// Unlike RetransmitQueue it holds no lock, so it is not safe for concurrent
// use by multiple goroutines.
type FragmentManager struct {
	fragments map[uint16]*FragmentGroup // Active fragment groups by base sequence
	maxLength uint32                    // Maximum packet size before fragmentation
}

// NewFragmentManager creates a manager with specified maximum packet length
func NewFragmentManager(maxLength uint32) *FragmentManager {
	return &FragmentManager{
		fragments: make(map[uint16]*FragmentGroup),
		maxLength: maxLength,
	}
}
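
// Send-side usage (illustrative sketch; the 512-byte limit, payload, nextSeq,
// and sendPacket are hypothetical caller-side values, not defined here):
//
//	fm := NewFragmentManager(512)
//	if frags := fm.FragmentPacket(payload, nextSeq); frags != nil {
//		for _, p := range frags {
//			sendPacket(p) // each fragment carries opcodes.OpFragment
//		}
//		// The caller should advance its sequence counter by len(frags).
//	}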

// FragmentPacket splits large packets into fragments.
// Returns nil if the packet doesn't need fragmentation.
func (fm *FragmentManager) FragmentPacket(data []byte, startSeq uint16) []*ProtocolPacket {
	if uint32(len(data)) <= fm.maxLength {
		return nil // No fragmentation needed
	}
	totalLength := uint32(len(data))
	chunkSize := int(fm.maxLength) - 6 // Reserve 6 bytes for the first-fragment header (2-byte sequence + 4-byte total length)
	if chunkSize <= 0 {
		chunkSize = 1
	}
	var packets []*ProtocolPacket
	seq := startSeq
	for offset := 0; offset < len(data); offset += chunkSize {
		end := offset + chunkSize
		if end > len(data) {
			end = len(data)
		}
		var fragmentData []byte
		if offset == 0 {
			// First fragment: sequence (big-endian) + total length (little-endian) + payload
			fragmentData = make([]byte, 6+end-offset)
			binary.BigEndian.PutUint16(fragmentData[0:2], seq)
			binary.LittleEndian.PutUint32(fragmentData[2:6], totalLength)
			copy(fragmentData[6:], data[offset:end])
		} else {
			// Subsequent fragments: sequence (big-endian) + payload
			fragmentData = make([]byte, 2+end-offset)
			binary.BigEndian.PutUint16(fragmentData[0:2], seq)
			copy(fragmentData[2:], data[offset:end])
		}
		packets = append(packets, &ProtocolPacket{
			Opcode: opcodes.OpFragment,
			Data:   fragmentData,
		})
		seq++
	}
	return packets
}
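
// Resulting wire layout of each fragment's Data field (as produced above):
//
//	first fragment:  [sequence uint16 BE][total length uint32 LE][payload]
//	later fragments: [sequence uint16 BE][payload]
//
// Every fragment carries up to maxLength-6 payload bytes, so the first
// fragment's Data is at most maxLength bytes and later ones maxLength-4.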

// ProcessFragment handles incoming fragments and returns the complete packet when ready
func (fm *FragmentManager) ProcessFragment(packet *ProtocolPacket) ([]byte, bool, error) {
	if len(packet.Data) < 2 {
		return nil, false, ErrPacketTooSmall
	}
	seq := binary.BigEndian.Uint16(packet.Data[0:2])
	// Parse fragment data
	fragment := FragmentPiece{Sequence: seq}
	if len(packet.Data) >= 6 {
		// Heuristic: treat this as the first fragment if bytes 2-5 look like a
		// plausible total length
		possibleLength := binary.LittleEndian.Uint32(packet.Data[2:6])
		if possibleLength > 0 && possibleLength < 10*1024*1024 { // Reasonable limit
			fragment.IsFirst = true
			fragment.Data = packet.Data[6:]
			// Create a new fragment group keyed by this base sequence
			group := &FragmentGroup{
				BaseSequence: seq,
				TotalLength:  possibleLength,
				Fragments:    []FragmentPiece{fragment},
				FirstSeen:    true,
			}
			fm.fragments[seq] = group
			return fm.tryAssemble(seq)
		}
	}
	// Not a first fragment - find the matching group
	fragment.Data = packet.Data[2:]
	group := fm.findFragmentGroup(seq)
	if group == nil {
		return nil, false, ErrOrphanedFragment
	}
	group.Fragments = append(group.Fragments, fragment)
	return fm.tryAssemble(group.BaseSequence)
}

// findFragmentGroup locates the fragment group for a sequence number
func (fm *FragmentManager) findFragmentGroup(seq uint16) *FragmentGroup {
	// Look for a group whose window this sequence falls into. The comparison
	// uses plain uint16 arithmetic, so a window that crosses the 65535 -> 0
	// wrap will not match.
	for baseSeq, group := range fm.fragments {
		if seq >= baseSeq && seq < baseSeq+100 { // Reasonable window
			return group
		}
	}
	return nil
}

// tryAssemble attempts to reassemble fragments into the complete packet
func (fm *FragmentManager) tryAssemble(baseSeq uint16) ([]byte, bool, error) {
	group, exists := fm.fragments[baseSeq]
	if !exists || !group.FirstSeen {
		return nil, false, nil
	}
	// Calculate the expected fragment count from the chunk size used by FragmentPacket
	chunkSize := int(fm.maxLength) - 6
	if chunkSize <= 0 {
		chunkSize = 1 // Mirror the guard in FragmentPacket
	}
	expectedCount := int(group.TotalLength) / chunkSize
	if int(group.TotalLength)%chunkSize != 0 {
		expectedCount++
	}
	if len(group.Fragments) < expectedCount {
		return nil, false, nil // Still waiting for fragments
	}
	// Sort fragments by sequence number
	sort.Slice(group.Fragments, func(i, j int) bool {
		return group.Fragments[i].Sequence < group.Fragments[j].Sequence
	})
	// Reassemble the packet
	result := make([]byte, 0, group.TotalLength)
	for _, frag := range group.Fragments[:expectedCount] {
		result = append(result, frag.Data...)
	}
	// Validate the reassembled length
	if uint32(len(result)) != group.TotalLength {
		delete(fm.fragments, baseSeq)
		return nil, false, fmt.Errorf("assembled length %d != expected %d", len(result), group.TotalLength)
	}
	// Clean up
	delete(fm.fragments, baseSeq)
	return result, true, nil
}
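
// Receive-side usage (illustrative sketch; pkt and handleComplete are
// hypothetical caller-side pieces):
//
//	data, complete, err := fm.ProcessFragment(pkt)
//	switch {
//	case err != nil:
//		// ErrPacketTooSmall, ErrOrphanedFragment, or a length mismatch
//	case complete:
//		handleComplete(data) // fully reassembled payload
//	default:
//		// more fragments still outstanding
//	}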

// CleanupStale removes incomplete fragment groups whose base sequence is below
// the given cutoff. Note the cutoff is a sequence number, not a time-based age.
func (fm *FragmentManager) CleanupStale(oldestValidSeq uint16) {
	// Simple cleanup - remove groups with very old base sequences
	for baseSeq := range fm.fragments {
		if baseSeq < oldestValidSeq {
			delete(fm.fragments, baseSeq)
		}
	}
}

// Clear removes all pending fragments
func (fm *FragmentManager) Clear() {
	fm.fragments = make(map[uint16]*FragmentGroup)
}