combiner middleware
This commit is contained in:
parent
28fd282f20
commit
9b3f113716
244
internal/udp/middleware/combiner.go
Normal file
244
internal/udp/middleware/combiner.go
Normal file
@ -0,0 +1,244 @@
|
||||
package middleware
|
||||
|
||||
import (
|
||||
"encoding/binary"
|
||||
"net"
|
||||
"sync"
|
||||
"time"
|
||||
)
|
||||
|
||||
// CombinerConfig holds configuration for packet combination
|
||||
type CombinerConfig struct {
|
||||
MaxCombinedSize int
|
||||
FlushInterval time.Duration
|
||||
MaxQueuedPackets int
|
||||
}
|
||||
|
||||
// DefaultCombinerConfig returns default combiner configuration
|
||||
func DefaultCombinerConfig() *CombinerConfig {
|
||||
return &CombinerConfig{
|
||||
MaxCombinedSize: 1200,
|
||||
FlushInterval: 250 * time.Millisecond,
|
||||
MaxQueuedPackets: 16,
|
||||
}
|
||||
}
|
||||
|
||||
type queuedPacket struct {
|
||||
data []byte
|
||||
timestamp time.Time
|
||||
callback func([]byte) (int, error)
|
||||
result chan combineResult
|
||||
}
|
||||
|
||||
type combineResult struct {
|
||||
n int
|
||||
err error
|
||||
}
|
||||
|
||||
// Combiner implements packet combination middleware
|
||||
type Combiner struct {
|
||||
config *CombinerConfig
|
||||
queue []*queuedPacket
|
||||
queueMux sync.Mutex
|
||||
flushChan chan struct{}
|
||||
done chan struct{}
|
||||
closeOnce sync.Once
|
||||
}
|
||||
|
||||
// NewCombiner creates a new packet combining middleware
|
||||
func NewCombiner(config *CombinerConfig) *Combiner {
|
||||
if config == nil {
|
||||
config = DefaultCombinerConfig()
|
||||
}
|
||||
|
||||
c := &Combiner{
|
||||
config: config,
|
||||
flushChan: make(chan struct{}, 1),
|
||||
done: make(chan struct{}),
|
||||
}
|
||||
|
||||
go c.flushLoop()
|
||||
return c
|
||||
}
|
||||
|
||||
// ProcessOutbound implements Middleware.ProcessOutbound
|
||||
func (c *Combiner) ProcessOutbound(data []byte, next func([]byte) (int, error)) (int, error) {
|
||||
if len(data) == 0 {
|
||||
return 0, nil
|
||||
}
|
||||
|
||||
// Large packets bypass combination
|
||||
if len(data) > c.config.MaxCombinedSize/2 {
|
||||
return next(data)
|
||||
}
|
||||
|
||||
c.queueMux.Lock()
|
||||
defer c.queueMux.Unlock()
|
||||
|
||||
result := make(chan combineResult, 1)
|
||||
|
||||
c.queue = append(c.queue, &queuedPacket{
|
||||
data: append([]byte(nil), data...),
|
||||
timestamp: time.Now(),
|
||||
callback: next,
|
||||
result: result,
|
||||
})
|
||||
|
||||
shouldFlush := len(c.queue) >= c.config.MaxQueuedPackets
|
||||
if !shouldFlush {
|
||||
totalSize := c.calculateCombinedSize()
|
||||
shouldFlush = totalSize > c.config.MaxCombinedSize
|
||||
}
|
||||
|
||||
if shouldFlush {
|
||||
c.flushQueueLocked()
|
||||
} else {
|
||||
select {
|
||||
case c.flushChan <- struct{}{}:
|
||||
default:
|
||||
}
|
||||
}
|
||||
|
||||
select {
|
||||
case res := <-result:
|
||||
return res.n, res.err
|
||||
case <-c.done:
|
||||
return 0, net.ErrClosed
|
||||
}
|
||||
}
|
||||
|
||||
// ProcessInbound implements Middleware.ProcessInbound
|
||||
func (c *Combiner) ProcessInbound(data []byte, next func([]byte) (int, error)) (int, error) {
|
||||
if len(data) < 2 {
|
||||
return next(data)
|
||||
}
|
||||
|
||||
packetCount := binary.BigEndian.Uint16(data[0:2])
|
||||
|
||||
// Single packet or invalid format
|
||||
if packetCount == 1 {
|
||||
if len(data) < 4 {
|
||||
return next(data)
|
||||
}
|
||||
|
||||
firstLen := binary.BigEndian.Uint16(data[2:4])
|
||||
if int(firstLen)+4 == len(data) {
|
||||
return next(data[4 : 4+firstLen])
|
||||
}
|
||||
return next(data)
|
||||
}
|
||||
|
||||
// Multiple packets - return first one
|
||||
if packetCount > 1 && len(data) >= 4 {
|
||||
firstLen := binary.BigEndian.Uint16(data[2:4])
|
||||
if len(data) >= 4+int(firstLen) {
|
||||
return next(data[4 : 4+firstLen])
|
||||
}
|
||||
}
|
||||
|
||||
return next(data)
|
||||
}
|
||||
|
||||
func (c *Combiner) calculateCombinedSize() int {
|
||||
size := 2 // count field
|
||||
for _, pkt := range c.queue {
|
||||
size += 2 + len(pkt.data)
|
||||
}
|
||||
return size
|
||||
}
|
||||
|
||||
func (c *Combiner) flushLoop() {
|
||||
ticker := time.NewTicker(c.config.FlushInterval)
|
||||
defer ticker.Stop()
|
||||
|
||||
for {
|
||||
select {
|
||||
case <-ticker.C:
|
||||
c.flushIfNeeded()
|
||||
case <-c.flushChan:
|
||||
c.flushIfNeeded()
|
||||
case <-c.done:
|
||||
c.queueMux.Lock()
|
||||
c.flushQueueLocked()
|
||||
c.queueMux.Unlock()
|
||||
return
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func (c *Combiner) flushIfNeeded() {
|
||||
c.queueMux.Lock()
|
||||
defer c.queueMux.Unlock()
|
||||
|
||||
if len(c.queue) == 0 {
|
||||
return
|
||||
}
|
||||
|
||||
now := time.Now()
|
||||
oldestAge := now.Sub(c.queue[0].timestamp)
|
||||
|
||||
if oldestAge >= c.config.FlushInterval || len(c.queue) >= c.config.MaxQueuedPackets {
|
||||
c.flushQueueLocked()
|
||||
}
|
||||
}
|
||||
|
||||
func (c *Combiner) flushQueueLocked() {
|
||||
if len(c.queue) == 0 {
|
||||
return
|
||||
}
|
||||
|
||||
queue := c.queue
|
||||
c.queue = nil
|
||||
|
||||
go c.processQueue(queue)
|
||||
}
|
||||
|
||||
func (c *Combiner) processQueue(queue []*queuedPacket) {
|
||||
if len(queue) == 1 {
|
||||
pkt := queue[0]
|
||||
n, err := pkt.callback(pkt.data)
|
||||
pkt.result <- combineResult{n, err}
|
||||
return
|
||||
}
|
||||
|
||||
// Combine multiple packets
|
||||
combined := c.combinePackets(queue)
|
||||
_, err := queue[0].callback(combined)
|
||||
|
||||
// Distribute result to all packets
|
||||
for _, pkt := range queue {
|
||||
pkt.result <- combineResult{len(pkt.data), err}
|
||||
}
|
||||
}
|
||||
|
||||
func (c *Combiner) combinePackets(packets []*queuedPacket) []byte {
|
||||
totalSize := 2 // count
|
||||
for _, pkt := range packets {
|
||||
totalSize += 2 + len(pkt.data)
|
||||
}
|
||||
|
||||
combined := make([]byte, totalSize)
|
||||
offset := 0
|
||||
|
||||
// Write packet count
|
||||
binary.BigEndian.PutUint16(combined[offset:], uint16(len(packets)))
|
||||
offset += 2
|
||||
|
||||
// Write each packet
|
||||
for _, pkt := range packets {
|
||||
binary.BigEndian.PutUint16(combined[offset:], uint16(len(pkt.data)))
|
||||
offset += 2
|
||||
copy(combined[offset:], pkt.data)
|
||||
offset += len(pkt.data)
|
||||
}
|
||||
|
||||
return combined
|
||||
}
|
||||
|
||||
// Close implements Middleware.Close
|
||||
func (c *Combiner) Close() error {
|
||||
c.closeOnce.Do(func() {
|
||||
close(c.done)
|
||||
})
|
||||
return nil
|
||||
}
|
Loading…
x
Reference in New Issue
Block a user