471 lines
10 KiB
Go
471 lines
10 KiB
Go
package eq2net
|
|
|
|
import (
|
|
"crypto/rand"
|
|
"encoding/binary"
|
|
"fmt"
|
|
"net"
|
|
"sync"
|
|
"time"
|
|
)
|
|
|
|
// StreamFactory manages multiple UDP streams
|
|
type StreamFactory struct {
|
|
mu sync.RWMutex
|
|
|
|
// Network
|
|
conn *net.UDPConn
|
|
listenAddr *net.UDPAddr
|
|
streamType StreamType
|
|
|
|
// Streams
|
|
streams map[string]*Stream // Key is "IP:Port"
|
|
streamsByID map[uint32]*Stream // Key is session ID
|
|
|
|
// Options
|
|
maxStreams int
|
|
readTimeout time.Duration
|
|
writeTimeout time.Duration
|
|
|
|
// State
|
|
running bool
|
|
stopChan chan struct{}
|
|
|
|
// Callbacks
|
|
onNewStream func(*Stream)
|
|
onStreamClosed func(*Stream)
|
|
}
|
|
|
|
// NewStreamFactory creates a new stream factory
|
|
func NewStreamFactory(listenAddr string, streamType StreamType) (*StreamFactory, error) {
|
|
addr, err := net.ResolveUDPAddr("udp", listenAddr)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
|
|
sf := &StreamFactory{
|
|
listenAddr: addr,
|
|
streamType: streamType,
|
|
streams: make(map[string]*Stream),
|
|
streamsByID: make(map[uint32]*Stream),
|
|
maxStreams: 1000,
|
|
readTimeout: 30 * time.Second,
|
|
writeTimeout: 5 * time.Second,
|
|
stopChan: make(chan struct{}),
|
|
}
|
|
|
|
return sf, nil
|
|
}
|
|
|
|
// Start starts the stream factory
|
|
func (sf *StreamFactory) Start() error {
|
|
sf.mu.Lock()
|
|
defer sf.mu.Unlock()
|
|
|
|
if sf.running {
|
|
return fmt.Errorf("stream factory already running")
|
|
}
|
|
|
|
// Create UDP connection
|
|
conn, err := net.ListenUDP("udp", sf.listenAddr)
|
|
if err != nil {
|
|
return err
|
|
}
|
|
sf.conn = conn
|
|
|
|
// Set buffer sizes
|
|
conn.SetReadBuffer(65536)
|
|
conn.SetWriteBuffer(65536)
|
|
|
|
sf.running = true
|
|
|
|
// Start worker goroutines
|
|
go sf.readLoop()
|
|
go sf.writeLoop()
|
|
go sf.maintenanceLoop()
|
|
|
|
return nil
|
|
}
|
|
|
|
// Stop stops the stream factory
|
|
func (sf *StreamFactory) Stop() {
|
|
sf.mu.Lock()
|
|
defer sf.mu.Unlock()
|
|
|
|
if !sf.running {
|
|
return
|
|
}
|
|
|
|
sf.running = false
|
|
close(sf.stopChan)
|
|
|
|
if sf.conn != nil {
|
|
sf.conn.Close()
|
|
}
|
|
|
|
// Close all streams
|
|
for _, stream := range sf.streams {
|
|
stream.SetState(Closed)
|
|
if stream.onDisconnect != nil {
|
|
stream.onDisconnect()
|
|
}
|
|
}
|
|
|
|
sf.streams = make(map[string]*Stream)
|
|
sf.streamsByID = make(map[uint32]*Stream)
|
|
}
|
|
|
|
// readLoop reads packets from the UDP connection
|
|
func (sf *StreamFactory) readLoop() {
|
|
buffer := make([]byte, 65536)
|
|
|
|
for {
|
|
select {
|
|
case <-sf.stopChan:
|
|
return
|
|
default:
|
|
// Set read deadline
|
|
sf.conn.SetReadDeadline(time.Now().Add(100 * time.Millisecond))
|
|
|
|
// Read packet
|
|
n, remoteAddr, err := sf.conn.ReadFromUDP(buffer)
|
|
if err != nil {
|
|
// Check if it's a timeout (expected)
|
|
if netErr, ok := err.(net.Error); ok && netErr.Timeout() {
|
|
continue
|
|
}
|
|
// Actual error
|
|
continue
|
|
}
|
|
|
|
// Process packet
|
|
data := make([]byte, n)
|
|
copy(data, buffer[:n])
|
|
go sf.processPacket(data, remoteAddr)
|
|
}
|
|
}
|
|
}
|
|
|
|
// writeLoop handles writing packets to streams
|
|
func (sf *StreamFactory) writeLoop() {
|
|
ticker := time.NewTicker(10 * time.Millisecond)
|
|
defer ticker.Stop()
|
|
|
|
for {
|
|
select {
|
|
case <-sf.stopChan:
|
|
return
|
|
case <-ticker.C:
|
|
sf.processOutgoingPackets()
|
|
}
|
|
}
|
|
}
|
|
|
|
// maintenanceLoop performs periodic maintenance
|
|
func (sf *StreamFactory) maintenanceLoop() {
|
|
ticker := time.NewTicker(1 * time.Second)
|
|
defer ticker.Stop()
|
|
|
|
for {
|
|
select {
|
|
case <-sf.stopChan:
|
|
return
|
|
case <-ticker.C:
|
|
sf.performMaintenance()
|
|
}
|
|
}
|
|
}
|
|
|
|
// processPacket processes an incoming packet
|
|
func (sf *StreamFactory) processPacket(data []byte, remoteAddr *net.UDPAddr) {
|
|
if len(data) < 2 {
|
|
return
|
|
}
|
|
|
|
// Get or create stream
|
|
stream := sf.getOrCreateStream(remoteAddr, data)
|
|
if stream == nil {
|
|
return
|
|
}
|
|
|
|
// Process packet in stream
|
|
err := stream.Process(data)
|
|
if err != nil {
|
|
// Handle error (log it, etc.)
|
|
}
|
|
}
|
|
|
|
// getOrCreateStream gets an existing stream or creates a new one
|
|
func (sf *StreamFactory) getOrCreateStream(remoteAddr *net.UDPAddr, data []byte) *Stream {
|
|
sf.mu.Lock()
|
|
defer sf.mu.Unlock()
|
|
|
|
key := fmt.Sprintf("%s:%d", remoteAddr.IP.String(), remoteAddr.Port)
|
|
|
|
// Check if stream exists
|
|
if stream, ok := sf.streams[key]; ok {
|
|
return stream
|
|
}
|
|
|
|
// Check if this is a session request
|
|
if len(data) >= 2 {
|
|
opcode := binary.BigEndian.Uint16(data)
|
|
if uint8(opcode&0xFF) == OPSessionRequest {
|
|
// Create new stream
|
|
stream := NewStream(remoteAddr, sf.streamType)
|
|
|
|
// Generate session ID and key
|
|
sessionID := sf.generateSessionID()
|
|
sessionKey := sf.generateKey()
|
|
stream.SetSessionInfo(sessionID, sessionKey)
|
|
|
|
// Add to maps
|
|
sf.streams[key] = stream
|
|
sf.streamsByID[sessionID] = stream
|
|
|
|
// Call new stream callback
|
|
if sf.onNewStream != nil {
|
|
sf.onNewStream(stream)
|
|
}
|
|
|
|
return stream
|
|
}
|
|
}
|
|
|
|
return nil
|
|
}
|
|
|
|
// processOutgoingPackets processes outgoing packets for all streams
|
|
func (sf *StreamFactory) processOutgoingPackets() {
|
|
sf.mu.RLock()
|
|
streams := make([]*Stream, 0, len(sf.streams))
|
|
for _, stream := range sf.streams {
|
|
streams = append(streams, stream)
|
|
}
|
|
sf.mu.RUnlock()
|
|
|
|
for _, stream := range streams {
|
|
sf.processStreamOutgoing(stream)
|
|
}
|
|
}
|
|
|
|
// processStreamOutgoing processes outgoing packets for a stream
|
|
func (sf *StreamFactory) processStreamOutgoing(stream *Stream) {
|
|
stream.mu.Lock()
|
|
defer stream.mu.Unlock()
|
|
|
|
// Process outgoing queue
|
|
for len(stream.outgoingQueue) > 0 {
|
|
packet := stream.outgoingQueue[0]
|
|
stream.outgoingQueue = stream.outgoingQueue[1:]
|
|
|
|
// Assign sequence number
|
|
packet.Sequence = stream.nextOutSeq
|
|
stream.nextOutSeq++
|
|
|
|
// Build packet data
|
|
data := sf.buildPacketData(stream, packet)
|
|
|
|
// Send packet
|
|
sf.sendPacket(stream.remoteAddr, data)
|
|
|
|
// Add to sent queue for acknowledgment tracking
|
|
packet.SentTime = int32(time.Now().Unix())
|
|
stream.sentQueue[packet.Sequence] = packet
|
|
}
|
|
|
|
// Check for retransmissions
|
|
now := time.Now()
|
|
for seq, packet := range stream.sentQueue {
|
|
sentTime := time.Unix(int64(packet.SentTime), 0)
|
|
if now.Sub(sentTime) > stream.retransmitTimeout {
|
|
if packet.AttemptCount < 3 {
|
|
// Retransmit
|
|
packet.AttemptCount++
|
|
packet.SentTime = int32(now.Unix())
|
|
|
|
data := sf.buildPacketData(stream, packet)
|
|
sf.sendPacket(stream.remoteAddr, data)
|
|
} else {
|
|
// Give up on this packet
|
|
delete(stream.sentQueue, seq)
|
|
}
|
|
}
|
|
}
|
|
}
|
|
|
|
// buildPacketData builds the complete packet data including headers and CRC
|
|
func (sf *StreamFactory) buildPacketData(stream *Stream, packet *ProtocolPacket) []byte {
|
|
// Check if we need to fragment
|
|
maxDataSize := int(stream.maxLength) - 6 // Header (4) + CRC (2)
|
|
|
|
if len(packet.Buffer) > maxDataSize {
|
|
// Need to fragment
|
|
return sf.buildFragmentData(stream, packet)
|
|
}
|
|
|
|
// Build regular packet
|
|
data := make([]byte, len(packet.Buffer)+6)
|
|
|
|
// Header
|
|
binary.BigEndian.PutUint16(data[0:2], uint16(OPPacket))
|
|
binary.BigEndian.PutUint16(data[2:4], packet.Sequence)
|
|
|
|
// Data
|
|
copy(data[4:], packet.Buffer)
|
|
|
|
// CRC
|
|
crc := CalculateCRC16(data[:len(data)-2], stream.key)
|
|
binary.BigEndian.PutUint16(data[len(data)-2:], crc)
|
|
|
|
return data
|
|
}
|
|
|
|
// buildFragmentData builds a fragment packet
|
|
func (sf *StreamFactory) buildFragmentData(stream *Stream, packet *ProtocolPacket) []byte {
|
|
// For now, just truncate (proper fragmentation would split into multiple packets)
|
|
// This is a simplified version - real implementation would handle multiple fragments
|
|
|
|
maxDataSize := int(stream.maxLength) - 10 // Header (8) + CRC (2)
|
|
dataSize := len(packet.Buffer)
|
|
if dataSize > maxDataSize {
|
|
dataSize = maxDataSize
|
|
}
|
|
|
|
data := make([]byte, dataSize+10)
|
|
|
|
// Header
|
|
binary.BigEndian.PutUint16(data[0:2], uint16(OPFragment))
|
|
binary.BigEndian.PutUint16(data[2:4], packet.Sequence)
|
|
binary.BigEndian.PutUint32(data[4:8], uint32(len(packet.Buffer)))
|
|
|
|
// Data
|
|
copy(data[8:], packet.Buffer[:dataSize])
|
|
|
|
// CRC
|
|
crc := CalculateCRC16(data[:len(data)-2], stream.key)
|
|
binary.BigEndian.PutUint16(data[len(data)-2:], crc)
|
|
|
|
return data
|
|
}
|
|
|
|
// sendPacket sends a packet
|
|
func (sf *StreamFactory) sendPacket(addr *net.UDPAddr, data []byte) error {
|
|
sf.conn.SetWriteDeadline(time.Now().Add(sf.writeTimeout))
|
|
_, err := sf.conn.WriteToUDP(data, addr)
|
|
return err
|
|
}
|
|
|
|
// performMaintenance performs periodic maintenance tasks
|
|
func (sf *StreamFactory) performMaintenance() {
|
|
sf.mu.Lock()
|
|
defer sf.mu.Unlock()
|
|
|
|
now := time.Now()
|
|
toRemove := []string{}
|
|
|
|
for key, stream := range sf.streams {
|
|
stream.mu.RLock()
|
|
lastReceive := stream.lastReceiveTime
|
|
state := stream.state
|
|
stream.mu.RUnlock()
|
|
|
|
// Check for timeout
|
|
if now.Sub(lastReceive) > sf.readTimeout && state == Established {
|
|
stream.SetState(Disconnecting)
|
|
toRemove = append(toRemove, key)
|
|
}
|
|
|
|
// Remove closed streams
|
|
if state == Closed || state == Disconnecting {
|
|
toRemove = append(toRemove, key)
|
|
}
|
|
}
|
|
|
|
// Remove dead streams
|
|
for _, key := range toRemove {
|
|
if stream, ok := sf.streams[key]; ok {
|
|
delete(sf.streams, key)
|
|
delete(sf.streamsByID, stream.GetSessionID())
|
|
|
|
// Call closed callback
|
|
if sf.onStreamClosed != nil {
|
|
sf.onStreamClosed(stream)
|
|
}
|
|
}
|
|
}
|
|
}
|
|
|
|
// generateSessionID generates a random session ID
|
|
func (sf *StreamFactory) generateSessionID() uint32 {
|
|
var id uint32
|
|
for {
|
|
binary.Read(rand.Reader, binary.BigEndian, &id)
|
|
// Make sure it's not already in use
|
|
if _, exists := sf.streamsByID[id]; !exists && id != 0 {
|
|
return id
|
|
}
|
|
}
|
|
}
|
|
|
|
// generateKey generates a random key
|
|
func (sf *StreamFactory) generateKey() uint32 {
|
|
var key uint32
|
|
binary.Read(rand.Reader, binary.BigEndian, &key)
|
|
return key
|
|
}
|
|
|
|
// GetStream gets a stream by remote address
|
|
func (sf *StreamFactory) GetStream(remoteAddr *net.UDPAddr) *Stream {
|
|
sf.mu.RLock()
|
|
defer sf.mu.RUnlock()
|
|
|
|
key := fmt.Sprintf("%s:%d", remoteAddr.IP.String(), remoteAddr.Port)
|
|
return sf.streams[key]
|
|
}
|
|
|
|
// GetStreamByID gets a stream by session ID
|
|
func (sf *StreamFactory) GetStreamByID(sessionID uint32) *Stream {
|
|
sf.mu.RLock()
|
|
defer sf.mu.RUnlock()
|
|
|
|
return sf.streamsByID[sessionID]
|
|
}
|
|
|
|
// GetStreamCount returns the number of active streams
|
|
func (sf *StreamFactory) GetStreamCount() int {
|
|
sf.mu.RLock()
|
|
defer sf.mu.RUnlock()
|
|
|
|
return len(sf.streams)
|
|
}
|
|
|
|
// SetOnNewStream sets the new stream callback
|
|
func (sf *StreamFactory) SetOnNewStream(callback func(*Stream)) {
|
|
sf.mu.Lock()
|
|
defer sf.mu.Unlock()
|
|
sf.onNewStream = callback
|
|
}
|
|
|
|
// SetOnStreamClosed sets the stream closed callback
|
|
func (sf *StreamFactory) SetOnStreamClosed(callback func(*Stream)) {
|
|
sf.mu.Lock()
|
|
defer sf.mu.Unlock()
|
|
sf.onStreamClosed = callback
|
|
}
|
|
|
|
// Broadcast sends a packet to all connected streams
|
|
func (sf *StreamFactory) Broadcast(packet *ApplicationPacket) {
|
|
sf.mu.RLock()
|
|
streams := make([]*Stream, 0, len(sf.streams))
|
|
for _, stream := range sf.streams {
|
|
if stream.GetState() == Established {
|
|
streams = append(streams, stream)
|
|
}
|
|
}
|
|
sf.mu.RUnlock()
|
|
|
|
for _, stream := range streams {
|
|
stream.Send(packet)
|
|
}
|
|
} |