package eq2net

import (
	"crypto/rand"
	"encoding/binary"
	"fmt"
	"net"
	"sync"
	"time"
)

// StreamFactory manages multiple UDP streams multiplexed over a single
// listening socket. Incoming packets are routed to per-remote-address
// Streams; outgoing packets are drained from each stream's queue on a
// fixed tick. All map access is guarded by mu.
type StreamFactory struct {
	mu sync.RWMutex

	// Network
	conn       *net.UDPConn
	listenAddr *net.UDPAddr
	streamType StreamType

	// Streams
	streams     map[string]*Stream // Key is "IP:Port"
	streamsByID map[uint32]*Stream // Key is session ID

	// Options
	maxStreams   int
	readTimeout  time.Duration
	writeTimeout time.Duration

	// State
	running  bool
	stopChan chan struct{}

	// Callbacks
	onNewStream    func(*Stream)
	onStreamClosed func(*Stream)
}

// NewStreamFactory creates a new stream factory listening on listenAddr
// (host:port). The factory does not open the socket until Start is called.
func NewStreamFactory(listenAddr string, streamType StreamType) (*StreamFactory, error) {
	addr, err := net.ResolveUDPAddr("udp", listenAddr)
	if err != nil {
		return nil, err
	}

	sf := &StreamFactory{
		listenAddr:   addr,
		streamType:   streamType,
		streams:      make(map[string]*Stream),
		streamsByID:  make(map[uint32]*Stream),
		maxStreams:   1000,
		readTimeout:  30 * time.Second,
		writeTimeout: 5 * time.Second,
		stopChan:     make(chan struct{}),
	}

	return sf, nil
}

// Start opens the UDP socket and launches the read, write, and
// maintenance goroutines. It is an error to call Start on a factory
// that is already running; a factory that was previously stopped may
// be started again.
func (sf *StreamFactory) Start() error {
	sf.mu.Lock()
	defer sf.mu.Unlock()

	if sf.running {
		return fmt.Errorf("stream factory already running")
	}

	// Create UDP connection
	conn, err := net.ListenUDP("udp", sf.listenAddr)
	if err != nil {
		return err
	}
	sf.conn = conn

	// Set buffer sizes. Best-effort: some platforms reject these sizes,
	// which is not fatal.
	_ = conn.SetReadBuffer(65536)
	_ = conn.SetWriteBuffer(65536)

	// Recreate the stop channel so a factory that was stopped earlier
	// can be restarted (Stop closes the previous channel; the old
	// goroutines observed that close and exited).
	sf.stopChan = make(chan struct{})

	sf.running = true

	// Start worker goroutines
	go sf.readLoop()
	go sf.writeLoop()
	go sf.maintenanceLoop()

	return nil
}

// Stop shuts down the factory: signals the worker goroutines, closes
// the socket, transitions every stream to Closed (firing its
// onDisconnect callback if set), and clears the stream maps.
func (sf *StreamFactory) Stop() {
	sf.mu.Lock()
	defer sf.mu.Unlock()

	if !sf.running {
		return
	}

	sf.running = false
	close(sf.stopChan)

	if sf.conn != nil {
		sf.conn.Close()
	}

	// Close all streams
	for _, stream := range sf.streams {
		stream.SetState(Closed)
		if stream.onDisconnect != nil {
			stream.onDisconnect()
		}
	}

	sf.streams = make(map[string]*Stream)
	sf.streamsByID = make(map[uint32]*Stream)
}

// readLoop reads packets from the UDP connection until stopChan is
// closed. A short read deadline keeps the loop responsive to shutdown.
func (sf *StreamFactory) readLoop() {
	buffer := make([]byte, 65536)

	for {
		select {
		case <-sf.stopChan:
			return
		default:
			// Set read deadline so the blocking read wakes up
			// periodically to observe stopChan.
			sf.conn.SetReadDeadline(time.Now().Add(100 * time.Millisecond))

			// Read packet
			n, remoteAddr, err := sf.conn.ReadFromUDP(buffer)
			if err != nil {
				// Check if it's a timeout (expected)
				if netErr, ok := err.(net.Error); ok && netErr.Timeout() {
					continue
				}
				// Actual error — keep looping; stopChan handles shutdown.
				continue
			}

			// Copy out of the shared read buffer before handing the
			// data to another goroutine.
			data := make([]byte, n)
			copy(data, buffer[:n])
			go sf.processPacket(data, remoteAddr)
		}
	}
}

// writeLoop drains outgoing packet queues on a fixed 10ms tick until
// stopChan is closed.
func (sf *StreamFactory) writeLoop() {
	ticker := time.NewTicker(10 * time.Millisecond)
	defer ticker.Stop()

	for {
		select {
		case <-sf.stopChan:
			return
		case <-ticker.C:
			sf.processOutgoingPackets()
		}
	}
}

// maintenanceLoop performs periodic maintenance (timeout detection and
// dead-stream reaping) once per second until stopChan is closed.
func (sf *StreamFactory) maintenanceLoop() {
	ticker := time.NewTicker(1 * time.Second)
	defer ticker.Stop()

	for {
		select {
		case <-sf.stopChan:
			return
		case <-ticker.C:
			sf.performMaintenance()
		}
	}
}

// processPacket routes one incoming datagram to its stream, creating
// the stream first if the packet is a session request. Packets shorter
// than an opcode (2 bytes) are dropped.
func (sf *StreamFactory) processPacket(data []byte, remoteAddr *net.UDPAddr) {
	if len(data) < 2 {
		return
	}

	// Get or create stream
	stream := sf.getOrCreateStream(remoteAddr, data)
	if stream == nil {
		return
	}

	// Process packet in stream. Errors are intentionally swallowed for
	// now — a malformed packet must not take down the read path.
	// TODO: surface via a logger once one is wired in.
	_ = stream.Process(data)
}

// getOrCreateStream returns the stream for remoteAddr, or creates one
// if the packet is a session request (low byte of the leading opcode
// equals OPSessionRequest). Returns nil for unknown senders whose first
// packet is not a session request.
func (sf *StreamFactory) getOrCreateStream(remoteAddr *net.UDPAddr, data []byte) *Stream {
	sf.mu.Lock()
	defer sf.mu.Unlock()

	key := fmt.Sprintf("%s:%d", remoteAddr.IP.String(), remoteAddr.Port)

	// Check if stream exists
	if stream, ok := sf.streams[key]; ok {
		return stream
	}

	// Check if this is a session request
	if len(data) >= 2 {
		opcode := binary.BigEndian.Uint16(data)
		if uint8(opcode&0xFF) == OPSessionRequest {
			// Create new stream
			stream := NewStream(remoteAddr, sf.streamType)

			// Generate session ID and key
			sessionID := sf.generateSessionID()
			sessionKey := sf.generateKey()
			stream.SetSessionInfo(sessionID, sessionKey)

			// Add to maps
			sf.streams[key] = stream
			sf.streamsByID[sessionID] = stream

			// Call new stream callback
			if sf.onNewStream != nil {
				sf.onNewStream(stream)
			}

			return stream
		}
	}

	return nil
}

// processOutgoingPackets processes outgoing packets for all streams.
// It snapshots the stream list under a read lock so per-stream work
// happens without holding the factory lock.
func (sf *StreamFactory) processOutgoingPackets() {
	sf.mu.RLock()
	streams := make([]*Stream, 0, len(sf.streams))
	for _, stream := range sf.streams {
		streams = append(streams, stream)
	}
	sf.mu.RUnlock()

	for _, stream := range streams {
		sf.processStreamOutgoing(stream)
	}
}

// processStreamOutgoing drains one stream's outgoing queue (assigning
// sequence numbers and tracking sent packets for acknowledgment) and
// retransmits unacknowledged packets, giving up after 3 attempts.
func (sf *StreamFactory) processStreamOutgoing(stream *Stream) {
	stream.mu.Lock()
	defer stream.mu.Unlock()

	// Process outgoing queue
	for len(stream.outgoingQueue) > 0 {
		packet := stream.outgoingQueue[0]
		stream.outgoingQueue = stream.outgoingQueue[1:]

		// Assign sequence number
		packet.Sequence = stream.nextOutSeq
		stream.nextOutSeq++

		// Build packet data
		data := sf.buildPacketData(stream, packet)

		// Send packet
		sf.sendPacket(stream.remoteAddr, data)

		// Add to sent queue for acknowledgment tracking
		packet.SentTime = int32(time.Now().Unix())
		stream.sentQueue[packet.Sequence] = packet
	}

	// Check for retransmissions
	now := time.Now()
	for seq, packet := range stream.sentQueue {
		sentTime := time.Unix(int64(packet.SentTime), 0)
		if now.Sub(sentTime) > stream.retransmitTimeout {
			if packet.AttemptCount < 3 {
				// Retransmit
				packet.AttemptCount++
				packet.SentTime = int32(now.Unix())
				data := sf.buildPacketData(stream, packet)
				sf.sendPacket(stream.remoteAddr, data)
			} else {
				// Give up on this packet
				delete(stream.sentQueue, seq)
			}
		}
	}
}

// buildPacketData builds the complete wire form of a packet:
// opcode (2) + sequence (2) + payload + CRC16 (2). Payloads larger
// than the stream's max length are handed to buildFragmentData.
func (sf *StreamFactory) buildPacketData(stream *Stream, packet *ProtocolPacket) []byte {
	// Check if we need to fragment
	maxDataSize := int(stream.maxLength) - 6 // Header (4) + CRC (2)

	if len(packet.Buffer) > maxDataSize {
		// Need to fragment
		return sf.buildFragmentData(stream, packet)
	}

	// Build regular packet
	data := make([]byte, len(packet.Buffer)+6)

	// Header
	binary.BigEndian.PutUint16(data[0:2], uint16(OPPacket))
	binary.BigEndian.PutUint16(data[2:4], packet.Sequence)

	// Data
	copy(data[4:], packet.Buffer)

	// CRC over everything except the CRC field itself
	crc := CalculateCRC16(data[:len(data)-2], stream.key)
	binary.BigEndian.PutUint16(data[len(data)-2:], crc)

	return data
}

// buildFragmentData builds a fragment packet:
// opcode (2) + sequence (2) + total length (4) + payload + CRC16 (2).
//
// NOTE(review): this is a simplified version that truncates oversized
// payloads to a single fragment — a real implementation would emit
// multiple fragments covering the whole buffer.
func (sf *StreamFactory) buildFragmentData(stream *Stream, packet *ProtocolPacket) []byte {
	maxDataSize := int(stream.maxLength) - 10 // Header (8) + CRC (2)

	dataSize := len(packet.Buffer)
	if dataSize > maxDataSize {
		dataSize = maxDataSize
	}

	data := make([]byte, dataSize+10)

	// Header
	binary.BigEndian.PutUint16(data[0:2], uint16(OPFragment))
	binary.BigEndian.PutUint16(data[2:4], packet.Sequence)
	binary.BigEndian.PutUint32(data[4:8], uint32(len(packet.Buffer)))

	// Data
	copy(data[8:], packet.Buffer[:dataSize])

	// CRC
	crc := CalculateCRC16(data[:len(data)-2], stream.key)
	binary.BigEndian.PutUint16(data[len(data)-2:], crc)

	return data
}

// sendPacket writes one datagram to addr, bounded by writeTimeout.
func (sf *StreamFactory) sendPacket(addr *net.UDPAddr, data []byte) error {
	sf.conn.SetWriteDeadline(time.Now().Add(sf.writeTimeout))
	_, err := sf.conn.WriteToUDP(data, addr)
	return err
}

// performMaintenance reaps streams that have timed out (no traffic for
// readTimeout while Established) or are already Closed/Disconnecting,
// firing the onStreamClosed callback for each removed stream.
func (sf *StreamFactory) performMaintenance() {
	sf.mu.Lock()
	defer sf.mu.Unlock()

	now := time.Now()
	toRemove := []string{}

	for key, stream := range sf.streams {
		stream.mu.RLock()
		lastReceive := stream.lastReceiveTime
		state := stream.state
		stream.mu.RUnlock()

		// Check for timeout
		if now.Sub(lastReceive) > sf.readTimeout && state == Established {
			stream.SetState(Disconnecting)
			toRemove = append(toRemove, key)
		}

		// Remove closed streams
		if state == Closed || state == Disconnecting {
			toRemove = append(toRemove, key)
		}
	}

	// Remove dead streams
	for _, key := range toRemove {
		if stream, ok := sf.streams[key]; ok {
			delete(sf.streams, key)
			delete(sf.streamsByID, stream.GetSessionID())

			// Call closed callback
			if sf.onStreamClosed != nil {
				sf.onStreamClosed(stream)
			}
		}
	}
}

// generateSessionID generates a random, non-zero session ID that is
// not already in use. Caller must hold sf.mu.
func (sf *StreamFactory) generateSessionID() uint32 {
	var id uint32
	for {
		if err := binary.Read(rand.Reader, binary.BigEndian, &id); err != nil {
			// crypto/rand should never fail; retry rather than hand
			// out a zero/stale ID.
			continue
		}
		// Make sure it's not already in use
		if _, exists := sf.streamsByID[id]; !exists && id != 0 {
			return id
		}
	}
}

// generateKey generates a random CRC key for a new session.
func (sf *StreamFactory) generateKey() uint32 {
	var key uint32
	for {
		if err := binary.Read(rand.Reader, binary.BigEndian, &key); err != nil {
			// crypto/rand should never fail; retry on the off chance.
			continue
		}
		return key
	}
}

// GetStream gets a stream by remote address, or nil if none exists.
func (sf *StreamFactory) GetStream(remoteAddr *net.UDPAddr) *Stream {
	sf.mu.RLock()
	defer sf.mu.RUnlock()

	key := fmt.Sprintf("%s:%d", remoteAddr.IP.String(), remoteAddr.Port)
	return sf.streams[key]
}

// GetStreamByID gets a stream by session ID, or nil if none exists.
func (sf *StreamFactory) GetStreamByID(sessionID uint32) *Stream {
	sf.mu.RLock()
	defer sf.mu.RUnlock()

	return sf.streamsByID[sessionID]
}

// GetStreamCount returns the number of active streams.
func (sf *StreamFactory) GetStreamCount() int {
	sf.mu.RLock()
	defer sf.mu.RUnlock()

	return len(sf.streams)
}

// SetOnNewStream sets the callback invoked when a new stream is created.
func (sf *StreamFactory) SetOnNewStream(callback func(*Stream)) {
	sf.mu.Lock()
	defer sf.mu.Unlock()
	sf.onNewStream = callback
}

// SetOnStreamClosed sets the callback invoked when a stream is removed.
func (sf *StreamFactory) SetOnStreamClosed(callback func(*Stream)) {
	sf.mu.Lock()
	defer sf.mu.Unlock()
	sf.onStreamClosed = callback
}

// Broadcast sends a packet to all streams currently in the Established
// state. The stream list is snapshotted under a read lock so Send runs
// without holding the factory lock.
func (sf *StreamFactory) Broadcast(packet *ApplicationPacket) {
	sf.mu.RLock()
	streams := make([]*Stream, 0, len(sf.streams))
	for _, stream := range sf.streams {
		if stream.GetState() == Established {
			streams = append(streams, stream)
		}
	}
	sf.mu.RUnlock()

	for _, stream := range streams {
		stream.Send(packet)
	}
}