1
0
Protocol/stream_factory.go
2025-09-01 12:31:14 -05:00

471 lines
10 KiB
Go

package eq2net
import (
"crypto/rand"
"encoding/binary"
"fmt"
"net"
"sync"
"time"
)
// StreamFactory multiplexes one UDP socket across many Stream instances,
// indexing each stream by remote address and by session ID.
type StreamFactory struct {
	// mu is held around access to the stream maps, the running flag and the
	// callback fields.
	mu sync.RWMutex

	// Network
	conn       *net.UDPConn // socket opened by Start; nil until then
	listenAddr *net.UDPAddr // local address the socket binds to
	streamType StreamType   // passed to NewStream for every stream created here

	// Streams
	streams     map[string]*Stream // Key is "IP:Port"
	streamsByID map[uint32]*Stream // Key is session ID

	// Options
	maxStreams   int           // intended cap on tracked streams (default 1000)
	readTimeout  time.Duration // idle time after which an established stream is dropped
	writeTimeout time.Duration // deadline applied to each UDP write

	// State
	running  bool          // true between a successful Start and Stop
	stopChan chan struct{} // closed by Stop to end the worker goroutines

	// Callbacks
	onNewStream    func(*Stream) // invoked when a stream is created
	onStreamClosed func(*Stream) // invoked when maintenance removes a stream
}
// NewStreamFactory resolves listenAddr as a UDP address and returns a factory
// populated with default limits (1000 streams, 30s read timeout, 5s write
// timeout). No socket is opened here; that happens in Start.
//
// The resolver's error is returned unchanged when listenAddr is invalid.
func NewStreamFactory(listenAddr string, streamType StreamType) (*StreamFactory, error) {
	const (
		defaultMaxStreams   = 1000
		defaultReadTimeout  = 30 * time.Second
		defaultWriteTimeout = 5 * time.Second
	)

	udpAddr, resolveErr := net.ResolveUDPAddr("udp", listenAddr)
	if resolveErr != nil {
		return nil, resolveErr
	}

	return &StreamFactory{
		listenAddr:   udpAddr,
		streamType:   streamType,
		streams:      make(map[string]*Stream),
		streamsByID:  make(map[uint32]*Stream),
		maxStreams:   defaultMaxStreams,
		readTimeout:  defaultReadTimeout,
		writeTimeout: defaultWriteTimeout,
		stopChan:     make(chan struct{}),
	}, nil
}
// Start opens the UDP socket on the configured listen address and launches
// the read, write and maintenance goroutines.
//
// It returns an error when the factory is already running, when it has
// previously been stopped, or when the socket cannot be opened. A stopped
// factory cannot be restarted: Stop closes the stop channel permanently, so
// workers launched afterwards would exit immediately while the freshly opened
// socket stayed bound with nobody reading from it.
func (sf *StreamFactory) Start() error {
	sf.mu.Lock()
	defer sf.mu.Unlock()

	if sf.running {
		return fmt.Errorf("stream factory already running")
	}

	// Non-blocking probe: a closed stop channel means Stop has already run.
	select {
	case <-sf.stopChan:
		return fmt.Errorf("stream factory has been stopped and cannot be restarted")
	default:
	}

	// Create UDP connection
	conn, err := net.ListenUDP("udp", sf.listenAddr)
	if err != nil {
		return err
	}
	sf.conn = conn

	// Buffer sizing is best-effort: the socket still works with the OS
	// defaults, so a failure here is deliberately not treated as fatal.
	_ = conn.SetReadBuffer(65536)
	_ = conn.SetWriteBuffer(65536)

	sf.running = true

	// Start worker goroutines; each one exits when stopChan is closed.
	go sf.readLoop()
	go sf.writeLoop()
	go sf.maintenanceLoop()

	return nil
}
// Stop shuts the factory down: it signals the worker goroutines, closes the
// socket, marks every tracked stream Closed and forgets all streams. Calling
// Stop on a factory that is not running is a no-op.
//
// Each stream's onDisconnect callback is invoked after the factory lock has
// been released, so a callback may safely call back into the factory
// (GetStream, GetStreamCount, ...) without deadlocking on sf.mu.
func (sf *StreamFactory) Stop() {
	for _, stream := range sf.closeAndDetachStreams() {
		if stream.onDisconnect != nil {
			stream.onDisconnect()
		}
	}
}

// closeAndDetachStreams performs the locked part of Stop and returns the
// streams that were being tracked so their callbacks can run outside the lock.
// It returns nil when the factory was not running.
func (sf *StreamFactory) closeAndDetachStreams() []*Stream {
	sf.mu.Lock()
	defer sf.mu.Unlock()

	if !sf.running {
		return nil
	}
	sf.running = false

	// Signal the workers first, then close the socket they may be using.
	close(sf.stopChan)
	if sf.conn != nil {
		sf.conn.Close()
	}

	// Close all streams
	detached := make([]*Stream, 0, len(sf.streams))
	for _, stream := range sf.streams {
		stream.SetState(Closed)
		detached = append(detached, stream)
	}

	sf.streams = make(map[string]*Stream)
	sf.streamsByID = make(map[uint32]*Stream)
	return detached
}
// readLoop receives datagrams from the UDP socket until the stop channel is
// closed. Each datagram is copied out of the shared receive buffer and handed
// to processPacket on its own goroutine.
//
// NOTE(review): one goroutine per datagram means packets from the same peer
// may be processed concurrently and out of arrival order - confirm that
// Stream.Process tolerates this.
func (sf *StreamFactory) readLoop() {
	const pollInterval = 100 * time.Millisecond

	scratch := make([]byte, 65536)
	for {
		// Non-blocking stop check; the short read deadline below bounds how
		// long a stop request can go unnoticed.
		select {
		case <-sf.stopChan:
			return
		default:
		}

		sf.conn.SetReadDeadline(time.Now().Add(pollInterval))
		size, sender, err := sf.conn.ReadFromUDP(scratch)
		if err != nil {
			// A deadline expiry is the normal idle case; any other read
			// error is likewise skipped and the loop simply tries again.
			continue
		}

		// scratch is reused by the next read, so the payload must be copied
		// before another goroutine looks at it.
		payload := make([]byte, size)
		copy(payload, scratch[:size])
		go sf.processPacket(payload, sender)
	}
}
// writeLoop flushes the outgoing queues of all streams on a fixed 10ms tick
// and returns once the stop channel is closed.
func (sf *StreamFactory) writeLoop() {
	const flushInterval = 10 * time.Millisecond

	flushTick := time.NewTicker(flushInterval)
	defer flushTick.Stop()

	for {
		select {
		case <-flushTick.C:
			sf.processOutgoingPackets()
		case <-sf.stopChan:
			return
		}
	}
}
// maintenanceLoop runs performMaintenance once per second and returns once
// the stop channel is closed.
func (sf *StreamFactory) maintenanceLoop() {
	const sweepInterval = 1 * time.Second

	sweepTick := time.NewTicker(sweepInterval)
	defer sweepTick.Stop()

	for {
		select {
		case <-sweepTick.C:
			sf.performMaintenance()
		case <-sf.stopChan:
			return
		}
	}
}
// processPacket routes one received datagram to the stream that owns its
// remote address, creating the stream first when the datagram qualifies.
// Datagrams shorter than two bytes, and datagrams for which no stream exists
// or can be created, are dropped silently.
func (sf *StreamFactory) processPacket(data []byte, remoteAddr *net.UDPAddr) {
	// Two bytes are the minimum needed to read the opcode.
	if len(data) < 2 {
		return
	}

	if target := sf.getOrCreateStream(remoteAddr, data); target != nil {
		// The stream's processing error is intentionally discarded: there is
		// no logger or error channel available at this level.
		_ = target.Process(data)
	}
}
// getOrCreateStream returns the stream registered for remoteAddr, or creates
// and registers a new one when data is a session request.
//
// It returns nil when no stream exists and either the datagram is not a
// session request or the factory already tracks maxStreams streams. The cap
// is what stops a remote peer from exhausting memory by sending session
// requests from many source addresses; a non-positive maxStreams disables it.
//
// The onNewStream callback runs while sf.mu is held, before any other
// goroutine can obtain the new stream, so it must not call back into methods
// of the factory that take the lock.
func (sf *StreamFactory) getOrCreateStream(remoteAddr *net.UDPAddr, data []byte) *Stream {
	sf.mu.Lock()
	defer sf.mu.Unlock()

	key := fmt.Sprintf("%s:%d", remoteAddr.IP.String(), remoteAddr.Port)

	// Check if stream exists
	if stream, ok := sf.streams[key]; ok {
		return stream
	}

	// Only a session request may open a new stream. The opcode is read as a
	// big-endian uint16 and matched on its low byte.
	if len(data) < 2 {
		return nil
	}
	opcode := binary.BigEndian.Uint16(data)
	if uint8(opcode&0xFF) != OPSessionRequest {
		return nil
	}

	// Enforce the configured stream limit before allocating anything.
	if sf.maxStreams > 0 && len(sf.streams) >= sf.maxStreams {
		return nil
	}

	// Create new stream
	stream := NewStream(remoteAddr, sf.streamType)

	// Generate session ID and key
	sessionID := sf.generateSessionID()
	sessionKey := sf.generateKey()
	stream.SetSessionInfo(sessionID, sessionKey)

	// Add to maps
	sf.streams[key] = stream
	sf.streamsByID[sessionID] = stream

	// Call new stream callback
	if sf.onNewStream != nil {
		sf.onNewStream(stream)
	}

	return stream
}
// processOutgoingPackets flushes the outgoing queue of every tracked stream.
// The stream set is snapshotted under the read lock so that the per-stream
// work runs without holding the factory lock.
func (sf *StreamFactory) processOutgoingPackets() {
	sf.mu.RLock()
	snapshot := make([]*Stream, 0, len(sf.streams))
	for key := range sf.streams {
		snapshot = append(snapshot, sf.streams[key])
	}
	sf.mu.RUnlock()

	for i := range snapshot {
		sf.processStreamOutgoing(snapshot[i])
	}
}
// processStreamOutgoing sends every packet waiting in the stream's outgoing
// queue and then retransmits or abandons packets still awaiting
// acknowledgment. The stream lock is held for the whole call.
//
// Queued packets are given the next outgoing sequence number, sent, stamped
// with the send time and recorded in sentQueue. Each queue slot is cleared as
// it is consumed and the emptied backing array is reused, so the queue does
// not keep sent packets reachable after they leave sentQueue.
//
// Send errors are not acted on here: the packet stays in sentQueue, so the
// retransmission pass will try it again.
//
// NOTE(review): SentTime has whole-second resolution, so with a
// retransmitTimeout under one second a packet can look overdue sooner than it
// really is - confirm the intended timeout range.
func (sf *StreamFactory) processStreamOutgoing(stream *Stream) {
	stream.mu.Lock()
	defer stream.mu.Unlock()

	// Process outgoing queue
	queue := stream.outgoingQueue
	for i, packet := range queue {
		queue[i] = nil // drop the queue's reference to the packet

		// Assign sequence number
		packet.Sequence = stream.nextOutSeq
		stream.nextOutSeq++

		// Build and send
		sf.sendPacket(stream.remoteAddr, sf.buildPacketData(stream, packet))

		// Add to sent queue for acknowledgment tracking
		packet.SentTime = int32(time.Now().Unix())
		stream.sentQueue[packet.Sequence] = packet
	}
	stream.outgoingQueue = queue[:0]

	// Check for retransmissions
	now := time.Now()
	for seq, packet := range stream.sentQueue {
		sentAt := time.Unix(int64(packet.SentTime), 0)
		if now.Sub(sentAt) <= stream.retransmitTimeout {
			continue
		}

		if packet.AttemptCount >= 3 {
			// Give up on this packet
			delete(stream.sentQueue, seq)
			continue
		}

		// Retransmit
		packet.AttemptCount++
		packet.SentTime = int32(now.Unix())
		sf.sendPacket(stream.remoteAddr, sf.buildPacketData(stream, packet))
	}
}
// buildPacketData serialises packet for transmission on stream.
//
// A payload that fits within the stream's maximum length is framed as:
// 2-byte OPPacket opcode, 2-byte sequence number, the payload, and a 2-byte
// CRC computed with the stream key over everything before it (all big
// endian). A larger payload is delegated to buildFragmentData.
func (sf *StreamFactory) buildPacketData(stream *Stream, packet *ProtocolPacket) []byte {
	const (
		headerSize = 4 // opcode (2) + sequence (2)
		crcSize    = 2
	)

	payloadLen := len(packet.Buffer)
	if payloadLen > int(stream.maxLength)-(headerSize+crcSize) {
		// Too large for a single packet.
		return sf.buildFragmentData(stream, packet)
	}

	out := make([]byte, headerSize+payloadLen+crcSize)
	crcOffset := headerSize + payloadLen

	binary.BigEndian.PutUint16(out[0:2], uint16(OPPacket))
	binary.BigEndian.PutUint16(out[2:4], packet.Sequence)
	copy(out[headerSize:], packet.Buffer)

	// The checksum covers header and payload but not the CRC field itself.
	checksum := CalculateCRC16(out[:crcOffset], stream.key)
	binary.BigEndian.PutUint16(out[crcOffset:], checksum)

	return out
}
// buildFragmentData serialises an oversized packet as a single OPFragment
// packet: 2-byte opcode, 2-byte sequence number, 4-byte total payload length,
// as much payload as fits, and a 2-byte CRC computed with the stream key over
// everything before it (all big endian).
//
// This is a simplified implementation: payload beyond the first fragment is
// truncated rather than split across further packets.
//
// When the stream's maximum length is smaller than the 10 bytes of framing,
// the payload size is clamped to zero and a framing-only packet is returned;
// without the clamp the size went negative and writing the header panicked
// with an out-of-range slice.
func (sf *StreamFactory) buildFragmentData(stream *Stream, packet *ProtocolPacket) []byte {
	const (
		headerSize = 8 // opcode (2) + sequence (2) + total length (4)
		crcSize    = 2
	)

	payloadLen := len(packet.Buffer)
	if limit := int(stream.maxLength) - (headerSize + crcSize); payloadLen > limit {
		payloadLen = limit
	}
	if payloadLen < 0 {
		payloadLen = 0
	}

	data := make([]byte, headerSize+payloadLen+crcSize)

	// Header
	binary.BigEndian.PutUint16(data[0:2], uint16(OPFragment))
	binary.BigEndian.PutUint16(data[2:4], packet.Sequence)
	binary.BigEndian.PutUint32(data[4:8], uint32(len(packet.Buffer)))

	// Data
	copy(data[headerSize:], packet.Buffer[:payloadLen])

	// CRC covers header and payload but not the CRC field itself.
	crc := CalculateCRC16(data[:len(data)-crcSize], stream.key)
	binary.BigEndian.PutUint16(data[len(data)-crcSize:], crc)

	return data
}
// sendPacket writes data to addr over the factory's UDP socket, allowing at
// most writeTimeout for the write to complete. It returns the error reported
// by the write, if any.
func (sf *StreamFactory) sendPacket(addr *net.UDPAddr, data []byte) error {
	deadline := time.Now().Add(sf.writeTimeout)
	sf.conn.SetWriteDeadline(deadline)

	if _, writeErr := sf.conn.WriteToUDP(data, addr); writeErr != nil {
		return writeErr
	}
	return nil
}
// performMaintenance removes streams that are finished or have gone quiet.
//
// An Established stream that has received nothing for longer than readTimeout
// is moved to Disconnecting and removed; streams already Closed or
// Disconnecting are removed as well. The onStreamClosed callback is invoked
// once per removed stream after the factory lock has been released, so the
// callback may call back into the factory without deadlocking on sf.mu.
func (sf *StreamFactory) performMaintenance() {
	removed, notify := sf.reapDeadStreams(time.Now())
	if notify == nil {
		return
	}
	for _, stream := range removed {
		notify(stream)
	}
}

// reapDeadStreams performs the locked part of performMaintenance: it unlinks
// every dead or timed-out stream from both maps and returns those streams
// together with the onStreamClosed callback that was set at that moment.
func (sf *StreamFactory) reapDeadStreams(now time.Time) ([]*Stream, func(*Stream)) {
	sf.mu.Lock()
	defer sf.mu.Unlock()

	var removed []*Stream
	for key, stream := range sf.streams {
		stream.mu.RLock()
		lastReceive := stream.lastReceiveTime
		state := stream.state
		stream.mu.RUnlock()

		switch {
		case state == Established && now.Sub(lastReceive) > sf.readTimeout:
			// Timed out: flag it before dropping it.
			stream.SetState(Disconnecting)
		case state == Closed || state == Disconnecting:
			// Already finished: just drop it.
		default:
			continue
		}

		// Deleting the current key while ranging over a map is safe in Go.
		delete(sf.streams, key)
		delete(sf.streamsByID, stream.GetSessionID())
		removed = append(removed, stream)
	}

	return removed, sf.onStreamClosed
}
// generateSessionID draws random 32-bit values until it finds one that is
// non-zero and not already present in streamsByID, and returns it. It reads
// streamsByID without locking, so the caller must hold sf.mu.
//
// NOTE(review): if the random source kept failing, the candidate would never
// change and this loop would not terminate - confirm that is acceptable.
func (sf *StreamFactory) generateSessionID() uint32 {
	var (
		raw       [4]byte
		candidate uint32
	)
	for {
		// On a failed read the previous candidate is kept unchanged.
		if _, err := rand.Read(raw[:]); err == nil {
			candidate = binary.BigEndian.Uint32(raw[:])
		}

		if candidate == 0 {
			continue
		}
		if _, taken := sf.streamsByID[candidate]; taken {
			continue
		}
		return candidate
	}
}
// generateKey returns a random 32-bit session key read from the
// cryptographic random source. If the read fails, the key is zero.
func (sf *StreamFactory) generateKey() uint32 {
	var raw [4]byte
	if _, err := rand.Read(raw[:]); err != nil {
		// Same zero result a failed read produced before.
		return 0
	}
	return binary.BigEndian.Uint32(raw[:])
}
// GetStream returns the stream registered for remoteAddr, or nil when no
// stream is tracked for that address. The lookup key is "IP:Port".
func (sf *StreamFactory) GetStream(remoteAddr *net.UDPAddr) *Stream {
	lookupKey := fmt.Sprintf("%s:%d", remoteAddr.IP.String(), remoteAddr.Port)

	sf.mu.RLock()
	found := sf.streams[lookupKey]
	sf.mu.RUnlock()

	return found
}
// GetStreamByID returns the stream registered under sessionID, or nil when
// no stream has that session ID.
func (sf *StreamFactory) GetStreamByID(sessionID uint32) *Stream {
	sf.mu.RLock()
	found := sf.streamsByID[sessionID]
	sf.mu.RUnlock()

	return found
}
// GetStreamCount reports how many streams the factory is currently tracking.
func (sf *StreamFactory) GetStreamCount() int {
	sf.mu.RLock()
	count := len(sf.streams)
	sf.mu.RUnlock()

	return count
}
// SetOnNewStream installs the callback invoked whenever a stream is created.
// Passing nil removes the callback.
func (sf *StreamFactory) SetOnNewStream(callback func(*Stream)) {
	sf.mu.Lock()
	sf.onNewStream = callback
	sf.mu.Unlock()
}
// SetOnStreamClosed installs the callback invoked whenever maintenance
// removes a stream. Passing nil removes the callback.
func (sf *StreamFactory) SetOnStreamClosed(callback func(*Stream)) {
	sf.mu.Lock()
	sf.onStreamClosed = callback
	sf.mu.Unlock()
}
// Broadcast queues packet on every stream that is currently Established.
// The recipients are chosen under the read lock; the sends themselves happen
// after it is released.
func (sf *StreamFactory) Broadcast(packet *ApplicationPacket) {
	sf.mu.RLock()
	recipients := make([]*Stream, 0, len(sf.streams))
	for _, candidate := range sf.streams {
		if candidate.GetState() != Established {
			continue
		}
		recipients = append(recipients, candidate)
	}
	sf.mu.RUnlock()

	for i := range recipients {
		recipients[i].Send(packet)
	}
}