package middleware import ( "encoding/binary" "net" "sync" "time" ) // CombinerConfig holds configuration for packet combination type CombinerConfig struct { MaxCombinedSize int FlushInterval time.Duration MaxQueuedPackets int } // DefaultCombinerConfig returns default combiner configuration func DefaultCombinerConfig() *CombinerConfig { return &CombinerConfig{ MaxCombinedSize: 1200, FlushInterval: 250 * time.Millisecond, MaxQueuedPackets: 16, } } type queuedPacket struct { data []byte timestamp time.Time callback func([]byte) (int, error) result chan combineResult } type combineResult struct { n int err error } // Combiner implements packet combination middleware type Combiner struct { config *CombinerConfig queue []*queuedPacket queueMux sync.Mutex flushChan chan struct{} done chan struct{} closeOnce sync.Once } // NewCombiner creates a new packet combining middleware func NewCombiner(config *CombinerConfig) *Combiner { if config == nil { config = DefaultCombinerConfig() } c := &Combiner{ config: config, flushChan: make(chan struct{}, 1), done: make(chan struct{}), } go c.flushLoop() return c } // ProcessOutbound implements Middleware.ProcessOutbound func (c *Combiner) ProcessOutbound(data []byte, next func([]byte) (int, error)) (int, error) { if len(data) == 0 { return 0, nil } // Large packets bypass combination if len(data) > c.config.MaxCombinedSize/2 { return next(data) } c.queueMux.Lock() defer c.queueMux.Unlock() result := make(chan combineResult, 1) c.queue = append(c.queue, &queuedPacket{ data: append([]byte(nil), data...), timestamp: time.Now(), callback: next, result: result, }) shouldFlush := len(c.queue) >= c.config.MaxQueuedPackets if !shouldFlush { totalSize := c.calculateCombinedSize() shouldFlush = totalSize > c.config.MaxCombinedSize } if shouldFlush { c.flushQueueLocked() } else { select { case c.flushChan <- struct{}{}: default: } } select { case res := <-result: return res.n, res.err case <-c.done: return 0, net.ErrClosed } } // ProcessInbound implements Middleware.ProcessInbound func (c *Combiner) ProcessInbound(data []byte, next func([]byte) (int, error)) (int, error) { if len(data) < 2 { return next(data) } packetCount := binary.BigEndian.Uint16(data[0:2]) // Single packet or invalid format if packetCount == 1 { if len(data) < 4 { return next(data) } firstLen := binary.BigEndian.Uint16(data[2:4]) if int(firstLen)+4 == len(data) { return next(data[4 : 4+firstLen]) } return next(data) } // Multiple packets - return first one if packetCount > 1 && len(data) >= 4 { firstLen := binary.BigEndian.Uint16(data[2:4]) if len(data) >= 4+int(firstLen) { return next(data[4 : 4+firstLen]) } } return next(data) } func (c *Combiner) calculateCombinedSize() int { size := 2 // count field for _, pkt := range c.queue { size += 2 + len(pkt.data) } return size } func (c *Combiner) flushLoop() { ticker := time.NewTicker(c.config.FlushInterval) defer ticker.Stop() for { select { case <-ticker.C: c.flushIfNeeded() case <-c.flushChan: c.flushIfNeeded() case <-c.done: c.queueMux.Lock() c.flushQueueLocked() c.queueMux.Unlock() return } } } func (c *Combiner) flushIfNeeded() { c.queueMux.Lock() defer c.queueMux.Unlock() if len(c.queue) == 0 { return } now := time.Now() oldestAge := now.Sub(c.queue[0].timestamp) if oldestAge >= c.config.FlushInterval || len(c.queue) >= c.config.MaxQueuedPackets { c.flushQueueLocked() } } func (c *Combiner) flushQueueLocked() { if len(c.queue) == 0 { return } queue := c.queue c.queue = nil go c.processQueue(queue) } func (c *Combiner) processQueue(queue []*queuedPacket) { if len(queue) == 1 { pkt := queue[0] n, err := pkt.callback(pkt.data) pkt.result <- combineResult{n, err} return } // Combine multiple packets combined := c.combinePackets(queue) _, err := queue[0].callback(combined) // Distribute result to all packets for _, pkt := range queue { pkt.result <- combineResult{len(pkt.data), err} } } func (c *Combiner) combinePackets(packets []*queuedPacket) []byte { totalSize := 2 // count for _, pkt := range packets { totalSize += 2 + len(pkt.data) } combined := make([]byte, totalSize) offset := 0 // Write packet count binary.BigEndian.PutUint16(combined[offset:], uint16(len(packets))) offset += 2 // Write each packet for _, pkt := range packets { binary.BigEndian.PutUint16(combined[offset:], uint16(len(pkt.data))) offset += 2 copy(combined[offset:], pkt.data) offset += len(pkt.data) } return combined } // Close implements Middleware.Close func (c *Combiner) Close() error { c.closeOnce.Do(func() { close(c.done) }) return nil }