From 9b3f113716dd073c653fe16dd8a6bc9303a852be Mon Sep 17 00:00:00 2001 From: Sky Johnson Date: Mon, 21 Jul 2025 13:50:28 -0500 Subject: [PATCH] combiner middleware --- internal/udp/middleware/combiner.go | 244 ++++++++++++++++++++++++++++ 1 file changed, 244 insertions(+) create mode 100644 internal/udp/middleware/combiner.go diff --git a/internal/udp/middleware/combiner.go b/internal/udp/middleware/combiner.go new file mode 100644 index 0000000..d0cb395 --- /dev/null +++ b/internal/udp/middleware/combiner.go @@ -0,0 +1,244 @@ +package middleware + +import ( + "encoding/binary" + "net" + "sync" + "time" +) + +// CombinerConfig holds configuration for packet combination +type CombinerConfig struct { + MaxCombinedSize int + FlushInterval time.Duration + MaxQueuedPackets int +} + +// DefaultCombinerConfig returns default combiner configuration +func DefaultCombinerConfig() *CombinerConfig { + return &CombinerConfig{ + MaxCombinedSize: 1200, + FlushInterval: 250 * time.Millisecond, + MaxQueuedPackets: 16, + } +} + +type queuedPacket struct { + data []byte + timestamp time.Time + callback func([]byte) (int, error) + result chan combineResult +} + +type combineResult struct { + n int + err error +} + +// Combiner implements packet combination middleware +type Combiner struct { + config *CombinerConfig + queue []*queuedPacket + queueMux sync.Mutex + flushChan chan struct{} + done chan struct{} + closeOnce sync.Once +} + +// NewCombiner creates a new packet combining middleware +func NewCombiner(config *CombinerConfig) *Combiner { + if config == nil { + config = DefaultCombinerConfig() + } + + c := &Combiner{ + config: config, + flushChan: make(chan struct{}, 1), + done: make(chan struct{}), + } + + go c.flushLoop() + return c +} + +// ProcessOutbound implements Middleware.ProcessOutbound +func (c *Combiner) ProcessOutbound(data []byte, next func([]byte) (int, error)) (int, error) { + if len(data) == 0 { + return 0, nil + } + + // Large packets bypass combination + if len(data) > c.config.MaxCombinedSize/2 { + return next(data) + } + + c.queueMux.Lock() + defer c.queueMux.Unlock() + + result := make(chan combineResult, 1) + + c.queue = append(c.queue, &queuedPacket{ + data: append([]byte(nil), data...), + timestamp: time.Now(), + callback: next, + result: result, + }) + + shouldFlush := len(c.queue) >= c.config.MaxQueuedPackets + if !shouldFlush { + totalSize := c.calculateCombinedSize() + shouldFlush = totalSize > c.config.MaxCombinedSize + } + + if shouldFlush { + c.flushQueueLocked() + } else { + select { + case c.flushChan <- struct{}{}: + default: + } + } + + select { + case res := <-result: + return res.n, res.err + case <-c.done: + return 0, net.ErrClosed + } +} + +// ProcessInbound implements Middleware.ProcessInbound +func (c *Combiner) ProcessInbound(data []byte, next func([]byte) (int, error)) (int, error) { + if len(data) < 2 { + return next(data) + } + + packetCount := binary.BigEndian.Uint16(data[0:2]) + + // Single packet or invalid format + if packetCount == 1 { + if len(data) < 4 { + return next(data) + } + + firstLen := binary.BigEndian.Uint16(data[2:4]) + if int(firstLen)+4 == len(data) { + return next(data[4 : 4+firstLen]) + } + return next(data) + } + + // Multiple packets - return first one + if packetCount > 1 && len(data) >= 4 { + firstLen := binary.BigEndian.Uint16(data[2:4]) + if len(data) >= 4+int(firstLen) { + return next(data[4 : 4+firstLen]) + } + } + + return next(data) +} + +func (c *Combiner) calculateCombinedSize() int { + size := 2 // count field + for _, pkt := range c.queue { + size += 2 + len(pkt.data) + } + return size +} + +func (c *Combiner) flushLoop() { + ticker := time.NewTicker(c.config.FlushInterval) + defer ticker.Stop() + + for { + select { + case <-ticker.C: + c.flushIfNeeded() + case <-c.flushChan: + c.flushIfNeeded() + case <-c.done: + c.queueMux.Lock() + c.flushQueueLocked() + c.queueMux.Unlock() + return + } + } +} + +func (c *Combiner) flushIfNeeded() { + c.queueMux.Lock() + defer c.queueMux.Unlock() + + if len(c.queue) == 0 { + return + } + + now := time.Now() + oldestAge := now.Sub(c.queue[0].timestamp) + + if oldestAge >= c.config.FlushInterval || len(c.queue) >= c.config.MaxQueuedPackets { + c.flushQueueLocked() + } +} + +func (c *Combiner) flushQueueLocked() { + if len(c.queue) == 0 { + return + } + + queue := c.queue + c.queue = nil + + go c.processQueue(queue) +} + +func (c *Combiner) processQueue(queue []*queuedPacket) { + if len(queue) == 1 { + pkt := queue[0] + n, err := pkt.callback(pkt.data) + pkt.result <- combineResult{n, err} + return + } + + // Combine multiple packets + combined := c.combinePackets(queue) + _, err := queue[0].callback(combined) + + // Distribute result to all packets + for _, pkt := range queue { + pkt.result <- combineResult{len(pkt.data), err} + } +} + +func (c *Combiner) combinePackets(packets []*queuedPacket) []byte { + totalSize := 2 // count + for _, pkt := range packets { + totalSize += 2 + len(pkt.data) + } + + combined := make([]byte, totalSize) + offset := 0 + + // Write packet count + binary.BigEndian.PutUint16(combined[offset:], uint16(len(packets))) + offset += 2 + + // Write each packet + for _, pkt := range packets { + binary.BigEndian.PutUint16(combined[offset:], uint16(len(pkt.data))) + offset += 2 + copy(combined[offset:], pkt.data) + offset += len(pkt.data) + } + + return combined +} + +// Close implements Middleware.Close +func (c *Combiner) Close() error { + c.closeOnce.Do(func() { + close(c.done) + }) + return nil +}