577 lines
12 KiB
Go
577 lines
12 KiB
Go
package udp
|
|
|
|
import (
|
|
"context"
|
|
"fmt"
|
|
"net"
|
|
"sync"
|
|
"sync/atomic"
|
|
"time"
|
|
)
|
|
|
|
// Conn represents a reliable UDP connection
|
|
type Conn interface {
|
|
Read(b []byte) (n int, err error)
|
|
Write(b []byte) (n int, err error)
|
|
Close() error
|
|
LocalAddr() net.Addr
|
|
RemoteAddr() net.Addr
|
|
SetReadDeadline(t time.Time) error
|
|
SetWriteDeadline(t time.Time) error
|
|
}
|
|
|
|
// Listener listens for incoming reliable UDP connections
|
|
type Listener interface {
|
|
Accept() (Conn, error)
|
|
Close() error
|
|
Addr() net.Addr
|
|
}
|
|
|
|
// stream implements a reliable UDP stream
|
|
type stream struct {
|
|
conn *net.UDPConn
|
|
remoteAddr *net.UDPAddr
|
|
localAddr *net.UDPAddr
|
|
session uint32
|
|
config *Config
|
|
|
|
// Sequence tracking
|
|
sendSeq uint32
|
|
recvSeq uint16
|
|
lastAckSent uint16
|
|
|
|
// Channels for communication
|
|
inbound chan []byte
|
|
outbound chan []byte
|
|
control chan *packet
|
|
done chan struct{}
|
|
closeOnce sync.Once
|
|
|
|
// Reliability tracking
|
|
pending map[uint16]*pendingPacket
|
|
pendingMutex sync.RWMutex
|
|
outOfOrder map[uint16][]byte
|
|
oooMutex sync.RWMutex
|
|
|
|
// Flow control
|
|
windowSize uint16
|
|
|
|
// Read/Write deadlines
|
|
readDeadline atomic.Value
|
|
writeDeadline atomic.Value
|
|
|
|
// Last activity for keep-alive
|
|
lastActivity time.Time
|
|
activityMutex sync.RWMutex
|
|
}
|
|
|
|
// newStream creates a new reliable UDP stream
|
|
func newStream(conn *net.UDPConn, remoteAddr *net.UDPAddr, session uint32, config *Config) *stream {
|
|
s := &stream{
|
|
conn: conn,
|
|
remoteAddr: remoteAddr,
|
|
localAddr: conn.LocalAddr().(*net.UDPAddr),
|
|
session: session,
|
|
config: config,
|
|
windowSize: config.WindowSize,
|
|
|
|
inbound: make(chan []byte, 256),
|
|
outbound: make(chan []byte, 256),
|
|
control: make(chan *packet, 64),
|
|
done: make(chan struct{}),
|
|
|
|
pending: make(map[uint16]*pendingPacket),
|
|
outOfOrder: make(map[uint16][]byte),
|
|
|
|
lastActivity: time.Now(),
|
|
}
|
|
|
|
// Start background goroutines
|
|
go s.writeLoop()
|
|
go s.retransmitLoop()
|
|
go s.keepAliveLoop()
|
|
|
|
return s
|
|
}
|
|
|
|
// Read implements Conn.Read
|
|
func (s *stream) Read(b []byte) (n int, err error) {
|
|
ctx, cancel := s.getReadDeadlineContext()
|
|
defer cancel()
|
|
|
|
select {
|
|
case data := <-s.inbound:
|
|
n = copy(b, data)
|
|
if n < len(data) {
|
|
return n, fmt.Errorf("buffer too small: need %d bytes, got %d", len(data), len(b))
|
|
}
|
|
return n, nil
|
|
case <-s.done:
|
|
return 0, fmt.Errorf("connection closed")
|
|
case <-ctx.Done():
|
|
return 0, fmt.Errorf("read timeout")
|
|
}
|
|
}
|
|
|
|
// Write implements Conn.Write
|
|
func (s *stream) Write(b []byte) (n int, err error) {
|
|
if len(b) == 0 {
|
|
return 0, nil
|
|
}
|
|
|
|
// Fragment large packets
|
|
mtu := s.config.MTU - 15 // Account for packet header
|
|
if len(b) <= mtu {
|
|
return s.writePacket(b)
|
|
}
|
|
|
|
// Fragment the data
|
|
sent := 0
|
|
for sent < len(b) {
|
|
end := sent + mtu
|
|
if end > len(b) {
|
|
end = len(b)
|
|
}
|
|
|
|
n, err := s.writePacket(b[sent:end])
|
|
sent += n
|
|
if err != nil {
|
|
return sent, err
|
|
}
|
|
}
|
|
|
|
return sent, nil
|
|
}
|
|
|
|
// writePacket writes a single packet
|
|
func (s *stream) writePacket(data []byte) (int, error) {
|
|
ctx, cancel := s.getWriteDeadlineContext()
|
|
defer cancel()
|
|
|
|
select {
|
|
case s.outbound <- data:
|
|
s.updateActivity()
|
|
return len(data), nil
|
|
case <-s.done:
|
|
return 0, fmt.Errorf("connection closed")
|
|
case <-ctx.Done():
|
|
return 0, fmt.Errorf("write timeout")
|
|
}
|
|
}
|
|
|
|
// writeLoop handles outbound packet transmission
|
|
func (s *stream) writeLoop() {
|
|
defer close(s.outbound)
|
|
|
|
for {
|
|
select {
|
|
case data := <-s.outbound:
|
|
s.sendDataPacket(data)
|
|
case ctrlPacket := <-s.control:
|
|
s.sendControlPacket(ctrlPacket)
|
|
case <-s.done:
|
|
return
|
|
}
|
|
}
|
|
}
|
|
|
|
// sendDataPacket sends a data packet with reliability
|
|
func (s *stream) sendDataPacket(data []byte) {
|
|
seq := uint16(atomic.AddUint32(&s.sendSeq, 1) - 1)
|
|
|
|
pkt := &packet{
|
|
Type: PacketTypeData,
|
|
Sequence: seq,
|
|
Ack: s.lastAckSent,
|
|
Session: s.session,
|
|
Data: data,
|
|
}
|
|
|
|
// Store for retransmission
|
|
s.pendingMutex.Lock()
|
|
s.pending[seq] = &pendingPacket{
|
|
packet: pkt,
|
|
timestamp: time.Now(),
|
|
attempts: 0,
|
|
}
|
|
s.pendingMutex.Unlock()
|
|
|
|
s.sendRawPacket(pkt)
|
|
}
|
|
|
|
// sendControlPacket sends control packets (ACKs, etc.)
|
|
func (s *stream) sendControlPacket(pkt *packet) {
|
|
pkt.Session = s.session
|
|
s.sendRawPacket(pkt)
|
|
}
|
|
|
|
// sendRawPacket sends a packet over UDP
|
|
func (s *stream) sendRawPacket(pkt *packet) {
|
|
data := pkt.Marshal()
|
|
s.conn.WriteToUDP(data, s.remoteAddr)
|
|
}
|
|
|
|
// handlePacket processes an incoming packet
|
|
func (s *stream) handlePacket(pkt *packet) {
|
|
s.updateActivity()
|
|
|
|
switch pkt.Type {
|
|
case PacketTypeData:
|
|
s.handleDataPacket(pkt)
|
|
case PacketTypeAck:
|
|
s.handleAckPacket(pkt)
|
|
case PacketTypeKeepAlive:
|
|
s.sendAck(pkt.Sequence)
|
|
case PacketTypeDisconnect:
|
|
s.Close()
|
|
}
|
|
}
|
|
|
|
// handleDataPacket processes incoming data packets
|
|
func (s *stream) handleDataPacket(pkt *packet) {
|
|
// Send ACK
|
|
s.sendAck(pkt.Sequence)
|
|
|
|
// Check sequence order
|
|
expectedSeq := s.recvSeq + 1
|
|
|
|
if pkt.Sequence == expectedSeq {
|
|
// In order - deliver immediately
|
|
s.deliverData(pkt.Data)
|
|
s.recvSeq = pkt.Sequence
|
|
|
|
// Check for buffered out-of-order packets
|
|
s.processOutOfOrder()
|
|
} else if pkt.Sequence > expectedSeq {
|
|
// Future packet - buffer it
|
|
s.oooMutex.Lock()
|
|
s.outOfOrder[pkt.Sequence] = pkt.Data
|
|
s.oooMutex.Unlock()
|
|
}
|
|
// Past packets are ignored (duplicate)
|
|
}
|
|
|
|
// processOutOfOrder delivers buffered in-order packets
|
|
func (s *stream) processOutOfOrder() {
|
|
s.oooMutex.Lock()
|
|
defer s.oooMutex.Unlock()
|
|
|
|
for {
|
|
nextSeq := s.recvSeq + 1
|
|
if data, exists := s.outOfOrder[nextSeq]; exists {
|
|
s.deliverData(data)
|
|
s.recvSeq = nextSeq
|
|
delete(s.outOfOrder, nextSeq)
|
|
} else {
|
|
break
|
|
}
|
|
}
|
|
}
|
|
|
|
// deliverData delivers data to the application
|
|
func (s *stream) deliverData(data []byte) {
|
|
select {
|
|
case s.inbound <- data:
|
|
case <-s.done:
|
|
default:
|
|
// Channel full - would block
|
|
}
|
|
}
|
|
|
|
// handleAckPacket processes acknowledgment packets
|
|
func (s *stream) handleAckPacket(pkt *packet) {
|
|
s.pendingMutex.Lock()
|
|
defer s.pendingMutex.Unlock()
|
|
|
|
if pending, exists := s.pending[pkt.Sequence]; exists {
|
|
delete(s.pending, pkt.Sequence)
|
|
_ = pending // Packet acknowledged
|
|
}
|
|
}
|
|
|
|
// sendAck sends an acknowledgment
|
|
func (s *stream) sendAck(seq uint16) {
|
|
s.lastAckSent = seq
|
|
ackPkt := &packet{
|
|
Type: PacketTypeAck,
|
|
Sequence: seq,
|
|
Ack: seq,
|
|
}
|
|
|
|
select {
|
|
case s.control <- ackPkt:
|
|
case <-s.done:
|
|
default:
|
|
}
|
|
}
|
|
|
|
// retransmitLoop handles packet retransmission
|
|
func (s *stream) retransmitLoop() {
|
|
ticker := time.NewTicker(time.Second)
|
|
defer ticker.Stop()
|
|
|
|
for {
|
|
select {
|
|
case <-ticker.C:
|
|
s.checkRetransmissions()
|
|
case <-s.done:
|
|
return
|
|
}
|
|
}
|
|
}
|
|
|
|
// checkRetransmissions checks for packets needing retransmission
|
|
func (s *stream) checkRetransmissions() {
|
|
now := time.Now()
|
|
|
|
s.pendingMutex.Lock()
|
|
defer s.pendingMutex.Unlock()
|
|
|
|
for seq, pending := range s.pending {
|
|
if now.Sub(pending.timestamp) > RetransmitTimeout {
|
|
if pending.attempts >= s.config.RetryAttempts {
|
|
// Too many attempts - close connection
|
|
delete(s.pending, seq)
|
|
go s.Close()
|
|
return
|
|
}
|
|
|
|
// Retransmit
|
|
pending.attempts++
|
|
pending.timestamp = now
|
|
s.sendRawPacket(pending.packet)
|
|
}
|
|
}
|
|
}
|
|
|
|
// keepAliveLoop sends periodic keep-alive packets
|
|
func (s *stream) keepAliveLoop() {
|
|
ticker := time.NewTicker(KeepAliveInterval)
|
|
defer ticker.Stop()
|
|
|
|
for {
|
|
select {
|
|
case <-ticker.C:
|
|
s.activityMutex.RLock()
|
|
idle := time.Since(s.lastActivity)
|
|
s.activityMutex.RUnlock()
|
|
|
|
if idle > KeepAliveInterval {
|
|
keepAlive := &packet{Type: PacketTypeKeepAlive}
|
|
select {
|
|
case s.control <- keepAlive:
|
|
case <-s.done:
|
|
return
|
|
}
|
|
}
|
|
case <-s.done:
|
|
return
|
|
}
|
|
}
|
|
}
|
|
|
|
// updateActivity updates the last activity timestamp
|
|
func (s *stream) updateActivity() {
|
|
s.activityMutex.Lock()
|
|
s.lastActivity = time.Now()
|
|
s.activityMutex.Unlock()
|
|
}
|
|
|
|
// Close implements Conn.Close
|
|
func (s *stream) Close() error {
|
|
s.closeOnce.Do(func() {
|
|
// Send disconnect packet
|
|
disconnect := &packet{Type: PacketTypeDisconnect}
|
|
select {
|
|
case s.control <- disconnect:
|
|
default:
|
|
}
|
|
|
|
close(s.done)
|
|
})
|
|
return nil
|
|
}
|
|
|
|
// Address methods
|
|
func (s *stream) LocalAddr() net.Addr { return s.localAddr }
|
|
func (s *stream) RemoteAddr() net.Addr { return s.remoteAddr }
|
|
|
|
// Deadline methods
|
|
func (s *stream) SetReadDeadline(t time.Time) error {
|
|
s.readDeadline.Store(t)
|
|
return nil
|
|
}
|
|
|
|
func (s *stream) SetWriteDeadline(t time.Time) error {
|
|
s.writeDeadline.Store(t)
|
|
return nil
|
|
}
|
|
|
|
func (s *stream) getReadDeadlineContext() (context.Context, context.CancelFunc) {
|
|
if deadline, ok := s.readDeadline.Load().(time.Time); ok && !deadline.IsZero() {
|
|
return context.WithDeadline(context.Background(), deadline)
|
|
}
|
|
return context.Background(), func() {}
|
|
}
|
|
|
|
func (s *stream) getWriteDeadlineContext() (context.Context, context.CancelFunc) {
|
|
if deadline, ok := s.writeDeadline.Load().(time.Time); ok && !deadline.IsZero() {
|
|
return context.WithDeadline(context.Background(), deadline)
|
|
}
|
|
return context.Background(), func() {}
|
|
}
|
|
|
|
// listener implements a reliable UDP listener
|
|
type listener struct {
|
|
conn *net.UDPConn
|
|
config *Config
|
|
streams map[string]*stream
|
|
mutex sync.RWMutex
|
|
incoming chan *stream
|
|
done chan struct{}
|
|
}
|
|
|
|
// Listen creates a new reliable UDP listener
|
|
func Listen(address string, config *Config) (Listener, error) {
|
|
if config == nil {
|
|
config = DefaultConfig()
|
|
}
|
|
|
|
addr, err := net.ResolveUDPAddr("udp", address)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
|
|
conn, err := net.ListenUDP("udp", addr)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
|
|
l := &listener{
|
|
conn: conn,
|
|
config: config,
|
|
streams: make(map[string]*stream),
|
|
incoming: make(chan *stream, 16),
|
|
done: make(chan struct{}),
|
|
}
|
|
|
|
go l.readLoop()
|
|
return l, nil
|
|
}
|
|
|
|
// readLoop handles incoming UDP packets
|
|
func (l *listener) readLoop() {
|
|
buf := make([]byte, 2048)
|
|
|
|
for {
|
|
select {
|
|
case <-l.done:
|
|
return
|
|
default:
|
|
}
|
|
|
|
n, addr, err := l.conn.ReadFromUDP(buf)
|
|
if err != nil {
|
|
continue
|
|
}
|
|
|
|
pkt := &packet{}
|
|
if err := pkt.Unmarshal(buf[:n]); err != nil {
|
|
continue
|
|
}
|
|
|
|
l.handlePacket(pkt, addr)
|
|
}
|
|
}
|
|
|
|
// handlePacket routes packets to appropriate streams
|
|
func (l *listener) handlePacket(pkt *packet, addr *net.UDPAddr) {
|
|
streamKey := addr.String()
|
|
|
|
l.mutex.RLock()
|
|
stream, exists := l.streams[streamKey]
|
|
l.mutex.RUnlock()
|
|
|
|
if !exists && pkt.Type == PacketTypeSessionRequest {
|
|
// New connection
|
|
session := pkt.Session
|
|
stream = newStream(l.conn, addr, session, l.config)
|
|
|
|
l.mutex.Lock()
|
|
l.streams[streamKey] = stream
|
|
l.mutex.Unlock()
|
|
|
|
// Send session response
|
|
response := &packet{
|
|
Type: PacketTypeSessionResponse,
|
|
Session: session,
|
|
}
|
|
stream.sendControlPacket(response)
|
|
|
|
select {
|
|
case l.incoming <- stream:
|
|
case <-l.done:
|
|
}
|
|
} else if exists {
|
|
stream.handlePacket(pkt)
|
|
}
|
|
}
|
|
|
|
// Accept implements Listener.Accept
|
|
func (l *listener) Accept() (Conn, error) {
|
|
select {
|
|
case stream := <-l.incoming:
|
|
return stream, nil
|
|
case <-l.done:
|
|
return nil, fmt.Errorf("listener closed")
|
|
}
|
|
}
|
|
|
|
// Close implements Listener.Close
|
|
func (l *listener) Close() error {
|
|
close(l.done)
|
|
|
|
l.mutex.Lock()
|
|
defer l.mutex.Unlock()
|
|
|
|
for _, stream := range l.streams {
|
|
stream.Close()
|
|
}
|
|
|
|
return l.conn.Close()
|
|
}
|
|
|
|
// Addr implements Listener.Addr
|
|
func (l *listener) Addr() net.Addr {
|
|
return l.conn.LocalAddr()
|
|
}
|
|
|
|
// Dial creates a client connection to a reliable UDP server
|
|
func Dial(address string, config *Config) (Conn, error) {
|
|
if config == nil {
|
|
config = DefaultConfig()
|
|
}
|
|
|
|
addr, err := net.ResolveUDPAddr("udp", address)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
|
|
conn, err := net.DialUDP("udp", nil, addr)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
|
|
session := uint32(time.Now().Unix())
|
|
stream := newStream(conn, addr, session, config)
|
|
|
|
// Send session request
|
|
request := &packet{
|
|
Type: PacketTypeSessionRequest,
|
|
Session: session,
|
|
}
|
|
stream.sendControlPacket(request)
|
|
|
|
return stream, nil
|
|
}
|