package udp import ( "context" "fmt" "net" "sync" "sync/atomic" "time" ) // Conn represents a reliable UDP connection type Conn interface { Read(b []byte) (n int, err error) Write(b []byte) (n int, err error) Close() error LocalAddr() net.Addr RemoteAddr() net.Addr SetReadDeadline(t time.Time) error SetWriteDeadline(t time.Time) error } // Listener listens for incoming reliable UDP connections type Listener interface { Accept() (Conn, error) Close() error Addr() net.Addr } // stream implements a reliable UDP stream type stream struct { conn *net.UDPConn remoteAddr *net.UDPAddr localAddr *net.UDPAddr session uint32 config *Config // Sequence tracking sendSeq uint32 recvSeq uint16 lastAckSent uint16 // Channels for communication inbound chan []byte outbound chan []byte control chan *packet done chan struct{} closeOnce sync.Once // Reliability tracking pending map[uint16]*pendingPacket pendingMutex sync.RWMutex outOfOrder map[uint16][]byte oooMutex sync.RWMutex // Flow control windowSize uint16 // Read/Write deadlines readDeadline atomic.Value writeDeadline atomic.Value // Last activity for keep-alive lastActivity time.Time activityMutex sync.RWMutex } // newStream creates a new reliable UDP stream func newStream(conn *net.UDPConn, remoteAddr *net.UDPAddr, session uint32, config *Config) *stream { s := &stream{ conn: conn, remoteAddr: remoteAddr, localAddr: conn.LocalAddr().(*net.UDPAddr), session: session, config: config, windowSize: config.WindowSize, inbound: make(chan []byte, 256), outbound: make(chan []byte, 256), control: make(chan *packet, 64), done: make(chan struct{}), pending: make(map[uint16]*pendingPacket), outOfOrder: make(map[uint16][]byte), lastActivity: time.Now(), } // Start background goroutines go s.writeLoop() go s.retransmitLoop() go s.keepAliveLoop() return s } // Read implements Conn.Read func (s *stream) Read(b []byte) (n int, err error) { ctx, cancel := s.getReadDeadlineContext() defer cancel() select { case data := <-s.inbound: n = copy(b, data) if n < len(data) { return n, fmt.Errorf("buffer too small: need %d bytes, got %d", len(data), len(b)) } return n, nil case <-s.done: return 0, fmt.Errorf("connection closed") case <-ctx.Done(): return 0, fmt.Errorf("read timeout") } } // Write implements Conn.Write func (s *stream) Write(b []byte) (n int, err error) { if len(b) == 0 { return 0, nil } // Fragment large packets mtu := s.config.MTU - 15 // Account for packet header if len(b) <= mtu { return s.writePacket(b) } // Fragment the data sent := 0 for sent < len(b) { end := sent + mtu if end > len(b) { end = len(b) } n, err := s.writePacket(b[sent:end]) sent += n if err != nil { return sent, err } } return sent, nil } // writePacket writes a single packet func (s *stream) writePacket(data []byte) (int, error) { ctx, cancel := s.getWriteDeadlineContext() defer cancel() select { case s.outbound <- data: s.updateActivity() return len(data), nil case <-s.done: return 0, fmt.Errorf("connection closed") case <-ctx.Done(): return 0, fmt.Errorf("write timeout") } } // writeLoop handles outbound packet transmission func (s *stream) writeLoop() { defer close(s.outbound) for { select { case data := <-s.outbound: s.sendDataPacket(data) case ctrlPacket := <-s.control: s.sendControlPacket(ctrlPacket) case <-s.done: return } } } // sendDataPacket sends a data packet with reliability func (s *stream) sendDataPacket(data []byte) { seq := uint16(atomic.AddUint32(&s.sendSeq, 1) - 1) pkt := &packet{ Type: PacketTypeData, Sequence: seq, Ack: s.lastAckSent, Session: s.session, Data: data, } // Store for retransmission s.pendingMutex.Lock() s.pending[seq] = &pendingPacket{ packet: pkt, timestamp: time.Now(), attempts: 0, } s.pendingMutex.Unlock() s.sendRawPacket(pkt) } // sendControlPacket sends control packets (ACKs, etc.) func (s *stream) sendControlPacket(pkt *packet) { pkt.Session = s.session s.sendRawPacket(pkt) } // sendRawPacket sends a packet over UDP func (s *stream) sendRawPacket(pkt *packet) { data := pkt.Marshal() s.conn.WriteToUDP(data, s.remoteAddr) } // handlePacket processes an incoming packet func (s *stream) handlePacket(pkt *packet) { s.updateActivity() switch pkt.Type { case PacketTypeData: s.handleDataPacket(pkt) case PacketTypeAck: s.handleAckPacket(pkt) case PacketTypeKeepAlive: s.sendAck(pkt.Sequence) case PacketTypeDisconnect: s.Close() } } // handleDataPacket processes incoming data packets func (s *stream) handleDataPacket(pkt *packet) { // Send ACK s.sendAck(pkt.Sequence) // Check sequence order expectedSeq := s.recvSeq + 1 if pkt.Sequence == expectedSeq { // In order - deliver immediately s.deliverData(pkt.Data) s.recvSeq = pkt.Sequence // Check for buffered out-of-order packets s.processOutOfOrder() } else if pkt.Sequence > expectedSeq { // Future packet - buffer it s.oooMutex.Lock() s.outOfOrder[pkt.Sequence] = pkt.Data s.oooMutex.Unlock() } // Past packets are ignored (duplicate) } // processOutOfOrder delivers buffered in-order packets func (s *stream) processOutOfOrder() { s.oooMutex.Lock() defer s.oooMutex.Unlock() for { nextSeq := s.recvSeq + 1 if data, exists := s.outOfOrder[nextSeq]; exists { s.deliverData(data) s.recvSeq = nextSeq delete(s.outOfOrder, nextSeq) } else { break } } } // deliverData delivers data to the application func (s *stream) deliverData(data []byte) { select { case s.inbound <- data: case <-s.done: default: // Channel full - would block } } // handleAckPacket processes acknowledgment packets func (s *stream) handleAckPacket(pkt *packet) { s.pendingMutex.Lock() defer s.pendingMutex.Unlock() if pending, exists := s.pending[pkt.Sequence]; exists { delete(s.pending, pkt.Sequence) _ = pending // Packet acknowledged } } // sendAck sends an acknowledgment func (s *stream) sendAck(seq uint16) { s.lastAckSent = seq ackPkt := &packet{ Type: PacketTypeAck, Sequence: seq, Ack: seq, } select { case s.control <- ackPkt: case <-s.done: default: } } // retransmitLoop handles packet retransmission func (s *stream) retransmitLoop() { ticker := time.NewTicker(time.Second) defer ticker.Stop() for { select { case <-ticker.C: s.checkRetransmissions() case <-s.done: return } } } // checkRetransmissions checks for packets needing retransmission func (s *stream) checkRetransmissions() { now := time.Now() s.pendingMutex.Lock() defer s.pendingMutex.Unlock() for seq, pending := range s.pending { if now.Sub(pending.timestamp) > RetransmitTimeout { if pending.attempts >= s.config.RetryAttempts { // Too many attempts - close connection delete(s.pending, seq) go s.Close() return } // Retransmit pending.attempts++ pending.timestamp = now s.sendRawPacket(pending.packet) } } } // keepAliveLoop sends periodic keep-alive packets func (s *stream) keepAliveLoop() { ticker := time.NewTicker(KeepAliveInterval) defer ticker.Stop() for { select { case <-ticker.C: s.activityMutex.RLock() idle := time.Since(s.lastActivity) s.activityMutex.RUnlock() if idle > KeepAliveInterval { keepAlive := &packet{Type: PacketTypeKeepAlive} select { case s.control <- keepAlive: case <-s.done: return } } case <-s.done: return } } } // updateActivity updates the last activity timestamp func (s *stream) updateActivity() { s.activityMutex.Lock() s.lastActivity = time.Now() s.activityMutex.Unlock() } // Close implements Conn.Close func (s *stream) Close() error { s.closeOnce.Do(func() { // Send disconnect packet disconnect := &packet{Type: PacketTypeDisconnect} select { case s.control <- disconnect: default: } close(s.done) }) return nil } // Address methods func (s *stream) LocalAddr() net.Addr { return s.localAddr } func (s *stream) RemoteAddr() net.Addr { return s.remoteAddr } // Deadline methods func (s *stream) SetReadDeadline(t time.Time) error { s.readDeadline.Store(t) return nil } func (s *stream) SetWriteDeadline(t time.Time) error { s.writeDeadline.Store(t) return nil } func (s *stream) getReadDeadlineContext() (context.Context, context.CancelFunc) { if deadline, ok := s.readDeadline.Load().(time.Time); ok && !deadline.IsZero() { return context.WithDeadline(context.Background(), deadline) } return context.Background(), func() {} } func (s *stream) getWriteDeadlineContext() (context.Context, context.CancelFunc) { if deadline, ok := s.writeDeadline.Load().(time.Time); ok && !deadline.IsZero() { return context.WithDeadline(context.Background(), deadline) } return context.Background(), func() {} } // listener implements a reliable UDP listener type listener struct { conn *net.UDPConn config *Config streams map[string]*stream mutex sync.RWMutex incoming chan *stream done chan struct{} } // Listen creates a new reliable UDP listener func Listen(address string, config *Config) (Listener, error) { if config == nil { config = DefaultConfig() } addr, err := net.ResolveUDPAddr("udp", address) if err != nil { return nil, err } conn, err := net.ListenUDP("udp", addr) if err != nil { return nil, err } l := &listener{ conn: conn, config: config, streams: make(map[string]*stream), incoming: make(chan *stream, 16), done: make(chan struct{}), } go l.readLoop() return l, nil } // readLoop handles incoming UDP packets func (l *listener) readLoop() { buf := make([]byte, 2048) for { select { case <-l.done: return default: } n, addr, err := l.conn.ReadFromUDP(buf) if err != nil { continue } pkt := &packet{} if err := pkt.Unmarshal(buf[:n]); err != nil { continue } l.handlePacket(pkt, addr) } } // handlePacket routes packets to appropriate streams func (l *listener) handlePacket(pkt *packet, addr *net.UDPAddr) { streamKey := addr.String() l.mutex.RLock() stream, exists := l.streams[streamKey] l.mutex.RUnlock() if !exists && pkt.Type == PacketTypeSessionRequest { // New connection session := pkt.Session stream = newStream(l.conn, addr, session, l.config) l.mutex.Lock() l.streams[streamKey] = stream l.mutex.Unlock() // Send session response response := &packet{ Type: PacketTypeSessionResponse, Session: session, } stream.sendControlPacket(response) select { case l.incoming <- stream: case <-l.done: } } else if exists { stream.handlePacket(pkt) } } // Accept implements Listener.Accept func (l *listener) Accept() (Conn, error) { select { case stream := <-l.incoming: return stream, nil case <-l.done: return nil, fmt.Errorf("listener closed") } } // Close implements Listener.Close func (l *listener) Close() error { close(l.done) l.mutex.Lock() defer l.mutex.Unlock() for _, stream := range l.streams { stream.Close() } return l.conn.Close() } // Addr implements Listener.Addr func (l *listener) Addr() net.Addr { return l.conn.LocalAddr() } // Dial creates a client connection to a reliable UDP server func Dial(address string, config *Config) (Conn, error) { if config == nil { config = DefaultConfig() } addr, err := net.ResolveUDPAddr("udp", address) if err != nil { return nil, err } conn, err := net.DialUDP("udp", nil, addr) if err != nil { return nil, err } session := uint32(time.Now().Unix()) stream := newStream(conn, addr, session, config) // Send session request request := &packet{ Type: PacketTypeSessionRequest, Session: session, } stream.sendControlPacket(request) return stream, nil }