245 lines
4.7 KiB
Go
245 lines
4.7 KiB
Go
package middleware
|
|
|
|
import (
	"encoding/binary"
	"io"
	"net"
	"sync"
	"time"
)
|
|
|
|
// CombinerConfig controls how a Combiner batches outbound packets.
type CombinerConfig struct {
	// MaxCombinedSize is the size budget, in bytes, for one combined frame.
	// Packets longer than half of it are never queued.
	MaxCombinedSize int

	// FlushInterval is both the flush loop's tick period and the age at
	// which the oldest queued packet forces a flush.
	FlushInterval time.Duration

	// MaxQueuedPackets is the queue length that triggers a flush.
	MaxQueuedPackets int
}
|
|
|
|
// DefaultCombinerConfig returns default combiner configuration
|
|
func DefaultCombinerConfig() *CombinerConfig {
|
|
return &CombinerConfig{
|
|
MaxCombinedSize: 1200,
|
|
FlushInterval: 250 * time.Millisecond,
|
|
MaxQueuedPackets: 16,
|
|
}
|
|
}
|
|
|
|
type queuedPacket struct {
|
|
data []byte
|
|
timestamp time.Time
|
|
callback func([]byte) (int, error)
|
|
result chan combineResult
|
|
}
|
|
|
|
// combineResult carries the (n, err) pair that a blocked ProcessOutbound
// call returns to its caller.
type combineResult struct {
	n   int   // byte count reported for the caller's packet
	err error // error from the send, if any
}
|
|
|
|
// Combiner implements packet combination middleware
|
|
type Combiner struct {
|
|
config *CombinerConfig
|
|
queue []*queuedPacket
|
|
queueMux sync.Mutex
|
|
flushChan chan struct{}
|
|
done chan struct{}
|
|
closeOnce sync.Once
|
|
}
|
|
|
|
// NewCombiner creates a new packet combining middleware
|
|
func NewCombiner(config *CombinerConfig) *Combiner {
|
|
if config == nil {
|
|
config = DefaultCombinerConfig()
|
|
}
|
|
|
|
c := &Combiner{
|
|
config: config,
|
|
flushChan: make(chan struct{}, 1),
|
|
done: make(chan struct{}),
|
|
}
|
|
|
|
go c.flushLoop()
|
|
return c
|
|
}
|
|
|
|
// ProcessOutbound implements Middleware.ProcessOutbound
|
|
func (c *Combiner) ProcessOutbound(data []byte, next func([]byte) (int, error)) (int, error) {
|
|
if len(data) == 0 {
|
|
return 0, nil
|
|
}
|
|
|
|
// Large packets bypass combination
|
|
if len(data) > c.config.MaxCombinedSize/2 {
|
|
return next(data)
|
|
}
|
|
|
|
c.queueMux.Lock()
|
|
defer c.queueMux.Unlock()
|
|
|
|
result := make(chan combineResult, 1)
|
|
|
|
c.queue = append(c.queue, &queuedPacket{
|
|
data: append([]byte(nil), data...),
|
|
timestamp: time.Now(),
|
|
callback: next,
|
|
result: result,
|
|
})
|
|
|
|
shouldFlush := len(c.queue) >= c.config.MaxQueuedPackets
|
|
if !shouldFlush {
|
|
totalSize := c.calculateCombinedSize()
|
|
shouldFlush = totalSize > c.config.MaxCombinedSize
|
|
}
|
|
|
|
if shouldFlush {
|
|
c.flushQueueLocked()
|
|
} else {
|
|
select {
|
|
case c.flushChan <- struct{}{}:
|
|
default:
|
|
}
|
|
}
|
|
|
|
select {
|
|
case res := <-result:
|
|
return res.n, res.err
|
|
case <-c.done:
|
|
return 0, net.ErrClosed
|
|
}
|
|
}
|
|
|
|
// ProcessInbound implements Middleware.ProcessInbound
|
|
func (c *Combiner) ProcessInbound(data []byte, next func([]byte) (int, error)) (int, error) {
|
|
if len(data) < 2 {
|
|
return next(data)
|
|
}
|
|
|
|
packetCount := binary.BigEndian.Uint16(data[0:2])
|
|
|
|
// Single packet or invalid format
|
|
if packetCount == 1 {
|
|
if len(data) < 4 {
|
|
return next(data)
|
|
}
|
|
|
|
firstLen := binary.BigEndian.Uint16(data[2:4])
|
|
if int(firstLen)+4 == len(data) {
|
|
return next(data[4 : 4+firstLen])
|
|
}
|
|
return next(data)
|
|
}
|
|
|
|
// Multiple packets - return first one
|
|
if packetCount > 1 && len(data) >= 4 {
|
|
firstLen := binary.BigEndian.Uint16(data[2:4])
|
|
if len(data) >= 4+int(firstLen) {
|
|
return next(data[4 : 4+firstLen])
|
|
}
|
|
}
|
|
|
|
return next(data)
|
|
}
|
|
|
|
func (c *Combiner) calculateCombinedSize() int {
|
|
size := 2 // count field
|
|
for _, pkt := range c.queue {
|
|
size += 2 + len(pkt.data)
|
|
}
|
|
return size
|
|
}
|
|
|
|
func (c *Combiner) flushLoop() {
|
|
ticker := time.NewTicker(c.config.FlushInterval)
|
|
defer ticker.Stop()
|
|
|
|
for {
|
|
select {
|
|
case <-ticker.C:
|
|
c.flushIfNeeded()
|
|
case <-c.flushChan:
|
|
c.flushIfNeeded()
|
|
case <-c.done:
|
|
c.queueMux.Lock()
|
|
c.flushQueueLocked()
|
|
c.queueMux.Unlock()
|
|
return
|
|
}
|
|
}
|
|
}
|
|
|
|
func (c *Combiner) flushIfNeeded() {
|
|
c.queueMux.Lock()
|
|
defer c.queueMux.Unlock()
|
|
|
|
if len(c.queue) == 0 {
|
|
return
|
|
}
|
|
|
|
now := time.Now()
|
|
oldestAge := now.Sub(c.queue[0].timestamp)
|
|
|
|
if oldestAge >= c.config.FlushInterval || len(c.queue) >= c.config.MaxQueuedPackets {
|
|
c.flushQueueLocked()
|
|
}
|
|
}
|
|
|
|
func (c *Combiner) flushQueueLocked() {
|
|
if len(c.queue) == 0 {
|
|
return
|
|
}
|
|
|
|
queue := c.queue
|
|
c.queue = nil
|
|
|
|
go c.processQueue(queue)
|
|
}
|
|
|
|
func (c *Combiner) processQueue(queue []*queuedPacket) {
|
|
if len(queue) == 1 {
|
|
pkt := queue[0]
|
|
n, err := pkt.callback(pkt.data)
|
|
pkt.result <- combineResult{n, err}
|
|
return
|
|
}
|
|
|
|
// Combine multiple packets
|
|
combined := c.combinePackets(queue)
|
|
_, err := queue[0].callback(combined)
|
|
|
|
// Distribute result to all packets
|
|
for _, pkt := range queue {
|
|
pkt.result <- combineResult{len(pkt.data), err}
|
|
}
|
|
}
|
|
|
|
func (c *Combiner) combinePackets(packets []*queuedPacket) []byte {
|
|
totalSize := 2 // count
|
|
for _, pkt := range packets {
|
|
totalSize += 2 + len(pkt.data)
|
|
}
|
|
|
|
combined := make([]byte, totalSize)
|
|
offset := 0
|
|
|
|
// Write packet count
|
|
binary.BigEndian.PutUint16(combined[offset:], uint16(len(packets)))
|
|
offset += 2
|
|
|
|
// Write each packet
|
|
for _, pkt := range packets {
|
|
binary.BigEndian.PutUint16(combined[offset:], uint16(len(pkt.data)))
|
|
offset += 2
|
|
copy(combined[offset:], pkt.data)
|
|
offset += len(pkt.data)
|
|
}
|
|
|
|
return combined
|
|
}
|
|
|
|
// Close implements Middleware.Close
|
|
func (c *Combiner) Close() error {
|
|
c.closeOnce.Do(func() {
|
|
close(c.done)
|
|
})
|
|
return nil
|
|
}
|