2025-07-21 13:50:28 -05:00

245 lines
4.7 KiB
Go

package middleware
import (
"encoding/binary"
"net"
"sync"
"time"
)
// CombinerConfig holds configuration for packet combination.
//
// A Combiner keeps a pointer to the config it was built with (NewCombiner
// does not copy it), so the fields are read on every operation.
type CombinerConfig struct {
// MaxCombinedSize is the size budget, in bytes, for one combined frame,
// including its 2-byte packet count and each packet's 2-byte length prefix.
// The queue is flushed when the framed size goes over it, and outbound
// packets longer than MaxCombinedSize/2 bypass combining entirely.
MaxCombinedSize int
// FlushInterval is both the flush loop's tick period and the maximum age
// the oldest queued packet may reach before the queue is flushed.
// NOTE(review): a non-positive value is passed straight to time.NewTicker
// in flushLoop, which panics on d <= 0 — confirm callers never set it so.
FlushInterval time.Duration
// MaxQueuedPackets is the queue length at which a flush is forced.
MaxQueuedPackets int
}
// DefaultCombinerConfig returns a freshly allocated configuration with the
// package defaults: 1200-byte frames, at most 16 queued packets, and a
// 250ms flush interval. Each call returns a new value, so callers may
// modify the result without affecting other combiners.
func DefaultCombinerConfig() *CombinerConfig {
	cfg := CombinerConfig{
		MaxCombinedSize:  1200,
		MaxQueuedPackets: 16,
		FlushInterval:    250 * time.Millisecond,
	}
	return &cfg
}
// queuedPacket is one outbound packet waiting in the Combiner's queue.
type queuedPacket struct {
// data is a private copy of the caller's bytes, taken at enqueue time so
// the caller's buffer may be reused.
data []byte
// timestamp is when the packet was enqueued; flushIfNeeded compares the
// oldest packet's age against FlushInterval.
timestamp time.Time
// callback is the next function the packet arrived with; processQueue
// uses it to write the packet (or the combined frame).
callback func([]byte) (int, error)
// result receives the outcome of the write. ProcessOutbound allocates it
// with capacity 1, so the writer never blocks even if the sender has
// already given up.
result chan combineResult
}
// combineResult is the write outcome delivered to a waiting
// ProcessOutbound call. Field order matters: it is built with positional
// literals (combineResult{n, err}).
type combineResult struct {
// n is the byte count reported to the caller for its packet.
n int
// err is the error from the underlying write, if any.
err error
}
// Combiner implements packet combination middleware: small outbound packets
// are queued and written together as one length-prefixed frame, and inbound
// frames are unpacked. Construct it with NewCombiner (which starts the
// background flush goroutine) and release it with Close.
type Combiner struct {
// config is the configuration in use (never nil after NewCombiner).
config *CombinerConfig
// queue holds packets waiting to be flushed; guarded by queueMux.
queue []*queuedPacket
queueMux sync.Mutex
// flushChan (capacity 1) wakes flushLoop to re-check the queue; sends are
// non-blocking, so repeated signals coalesce.
flushChan chan struct{}
// done is closed by Close to stop flushLoop and release waiting senders.
done chan struct{}
// closeOnce makes Close idempotent.
closeOnce sync.Once
}
// NewCombiner builds a packet-combining middleware and starts its
// background flush goroutine. A nil config selects DefaultCombinerConfig();
// a non-nil config is used by reference, not copied. Call Close to stop the
// background goroutine.
func NewCombiner(config *CombinerConfig) *Combiner {
	cfg := config
	if cfg == nil {
		cfg = DefaultCombinerConfig()
	}

	c := &Combiner{
		config:    cfg,
		done:      make(chan struct{}),
		flushChan: make(chan struct{}, 1), // one pending wake-up is enough
	}
	go c.flushLoop()
	return c
}
// ProcessOutbound implements Middleware.ProcessOutbound.
//
// Empty packets are dropped (0, nil). Packets longer than
// MaxCombinedSize/2 are written immediately via next. Smaller packets are
// copied into the queue, and the call blocks until that packet's batch has
// been written, returning the result reported for it. If the combiner is
// closed first, the call returns net.ErrClosed.
//
// A batch is flushed when it reaches MaxQueuedPackets, when adding a packet
// would push the framed size past MaxCombinedSize, or when the flush loop
// finds the oldest packet older than FlushInterval.
func (c *Combiner) ProcessOutbound(data []byte, next func([]byte) (int, error)) (int, error) {
	if len(data) == 0 {
		return 0, nil
	}
	// Large packets bypass combination.
	if len(data) > c.config.MaxCombinedSize/2 {
		return next(data)
	}

	result := make(chan combineResult, 1)
	pkt := &queuedPacket{
		data:      append([]byte(nil), data...), // caller may reuse its buffer
		timestamp: time.Now(),
		callback:  next,
		result:    result,
	}

	c.queueMux.Lock()
	// If this packet (2-byte length prefix + payload) would not fit in the
	// current frame, send what is already queued first so no combined frame
	// exceeds MaxCombinedSize.
	// NOTE: each flush hands its batch to its own goroutine, so back-to-back
	// batches are not guaranteed to reach next in order.
	if len(c.queue) > 0 && c.calculateCombinedSize()+2+len(pkt.data) > c.config.MaxCombinedSize {
		c.flushQueueLocked()
	}
	c.queue = append(c.queue, pkt)

	if len(c.queue) >= c.config.MaxQueuedPackets || c.calculateCombinedSize() > c.config.MaxCombinedSize {
		c.flushQueueLocked()
	} else {
		// Wake the flush loop without blocking; a pending signal suffices.
		select {
		case c.flushChan <- struct{}{}:
		default:
		}
	}
	// Release the lock BEFORE waiting: the flush loop must take it to flush
	// this packet, and other senders must take it to enqueue theirs.
	// Holding it across the wait deadlocks until Close.
	c.queueMux.Unlock()

	select {
	case res := <-result:
		return res.n, res.err
	case <-c.done:
		return 0, net.ErrClosed
	}
}
// ProcessInbound implements Middleware.ProcessInbound.
//
// It recognizes frames produced by combinePackets (all fields big-endian):
//
//	[count uint16][len1 uint16][payload1] ... [lenN uint16][payloadN]
//
// When data parses exactly as such a frame, next is called with the first
// payload. Anything else, including packets ProcessOutbound wrote without
// framing (a lone queued packet or a large bypass packet), is passed to
// next unchanged.
//
// NOTE(review): only the first payload of a multi-packet frame is
// delivered; the remaining payloads are dropped here. Confirm this is
// intended or deliver each payload.
func (c *Combiner) ProcessInbound(data []byte, next func([]byte) (int, error)) (int, error) {
	if first, ok := firstCombinedPayload(data); ok {
		return next(first)
	}
	return next(data)
}

// firstCombinedPayload reports whether data is a well-formed combined frame
// whose length prefixes account for every byte, and returns its first
// payload if so. Lengths are handled as int so large inputs cannot wrap
// uint16 arithmetic, and the strict whole-frame check keeps raw packets
// that merely start with a plausible count from being truncated.
func firstCombinedPayload(data []byte) ([]byte, bool) {
	if len(data) < 2 {
		return nil, false
	}
	count := int(binary.BigEndian.Uint16(data[0:2]))
	if count == 0 {
		return nil, false
	}

	var first []byte
	offset := 2
	for i := 0; i < count; i++ {
		if len(data)-offset < 2 {
			return nil, false
		}
		size := int(binary.BigEndian.Uint16(data[offset : offset+2]))
		offset += 2
		if len(data)-offset < size {
			return nil, false
		}
		if i == 0 {
			first = data[offset : offset+size]
		}
		offset += size
	}
	if offset != len(data) {
		// Trailing bytes: not a frame we produced.
		return nil, false
	}
	return first, true
}
// calculateCombinedSize returns the number of bytes combinePackets would
// produce for the current queue: a 2-byte count plus, per packet, a 2-byte
// length prefix and the payload. The caller must hold queueMux.
func (c *Combiner) calculateCombinedSize() int {
	total := 2 // leading packet-count field
	for i := range c.queue {
		total += 2 + len(c.queue[i].data)
	}
	return total
}
// flushLoop is the combiner's background goroutine, started by NewCombiner.
// It re-checks the queue on every tick and whenever flushChan is signalled.
// Once done is closed, it flushes whatever is still queued and returns.
// Results for that final flush land in buffered result channels, so they
// never block even though waiting senders have already returned
// net.ErrClosed.
func (c *Combiner) flushLoop() {
	interval := c.config.FlushInterval
	if interval <= 0 {
		// time.NewTicker panics on a non-positive duration, which would
		// crash the whole process from this goroutine. With such a setting
		// every queued packet is already "old enough" for flushIfNeeded, so
		// the ticker is only a backstop; use the default period for it.
		interval = DefaultCombinerConfig().FlushInterval
	}
	ticker := time.NewTicker(interval)
	defer ticker.Stop()

	for {
		select {
		case <-ticker.C:
			c.flushIfNeeded()
		case <-c.flushChan:
			c.flushIfNeeded()
		case <-c.done:
			c.queueMux.Lock()
			c.flushQueueLocked()
			c.queueMux.Unlock()
			return
		}
	}
}
// flushIfNeeded flushes the queue when it is non-empty and either its oldest
// packet has waited at least FlushInterval or it holds MaxQueuedPackets or
// more packets. It acquires queueMux itself.
func (c *Combiner) flushIfNeeded() {
	c.queueMux.Lock()
	defer c.queueMux.Unlock()

	queued := len(c.queue)
	if queued == 0 {
		return
	}
	tooMany := queued >= c.config.MaxQueuedPackets
	tooOld := time.Since(c.queue[0].timestamp) >= c.config.FlushInterval
	if !tooMany && !tooOld {
		return
	}
	c.flushQueueLocked()
}
// flushQueueLocked detaches the current queue and writes it on a new
// goroutine, so the caller (which must hold queueMux) is not blocked on I/O.
// An empty queue is a no-op.
func (c *Combiner) flushQueueLocked() {
	batch := c.queue
	if len(batch) == 0 {
		return
	}
	c.queue = nil
	go c.processQueue(batch)
}
// processQueue writes one flushed batch and reports the outcome to every
// packet's result channel.
//
// A batch of one is written as-is, without framing, and gets exactly what its
// callback returned. Larger batches are framed by combinePackets and written
// with a single call. Each packet is then reported as fully sent
// (len(pkt.data)) on success, or as 0 bytes with the write error on failure.
func (c *Combiner) processQueue(queue []*queuedPacket) {
	if len(queue) == 1 {
		pkt := queue[0]
		n, err := pkt.callback(pkt.data)
		pkt.result <- combineResult{n, err}
		return
	}

	combined := c.combinePackets(queue)
	// NOTE(review): the whole frame goes through the first packet's callback;
	// this assumes every queued packet carried an equivalent next — confirm.
	_, err := queue[0].callback(combined)

	for _, pkt := range queue {
		n := len(pkt.data)
		if err != nil {
			// The combined write failed; don't claim this packet was sent.
			n = 0
		}
		pkt.result <- combineResult{n, err}
	}
}
// combinePackets serializes packets into one frame (all fields big-endian):
//
//	[count uint16][len1 uint16][payload1] ... [lenN uint16][payloadN]
//
// NOTE: the count and lengths are written as uint16, so values above 65535
// would be silently truncated; the combiner's size limits are expected to
// keep them smaller.
func (c *Combiner) combinePackets(packets []*queuedPacket) []byte {
	size := 2
	for _, p := range packets {
		size += 2 + len(p.data)
	}

	out := make([]byte, size)
	binary.BigEndian.PutUint16(out, uint16(len(packets)))

	// rest always points at the next unwritten byte of out.
	rest := out[2:]
	for _, p := range packets {
		binary.BigEndian.PutUint16(rest, uint16(len(p.data)))
		written := copy(rest[2:], p.data)
		rest = rest[2+written:]
	}
	return out
}
// Close implements Middleware.Close. It signals the flush loop to flush any
// remaining packets and exit, and releases senders blocked in
// ProcessOutbound. It returns immediately without waiting for the loop to
// finish. Calling it more than once is safe and always returns nil.
func (c *Combiner) Close() error {
	c.closeOnce.Do(func() { close(c.done) })
	return nil
}