// eq2go/internal/udp/server.go (577 lines, 12 KiB, Go)

package udp
import (
"context"
"fmt"
"net"
"sync"
"sync/atomic"
"time"
)
type (
	// Conn is one reliable, ordered session carried over UDP. Its methods are
	// a subset of net.Conn's (everything except SetDeadline).
	Conn interface {
		Read(b []byte) (n int, err error)
		Write(b []byte) (n int, err error)
		Close() error
		LocalAddr() net.Addr
		RemoteAddr() net.Addr
		SetReadDeadline(t time.Time) error
		SetWriteDeadline(t time.Time) error
	}

	// Listener hands out incoming reliable UDP sessions as Conns.
	Listener interface {
		Accept() (Conn, error)
		Close() error
		Addr() net.Addr
	}
)
// stream is the Conn implementation. It is one reliable, ordered session with
// a single remote peer over a UDP socket, which a listener shares across all
// of its streams.
type stream struct {
	// Transport and identity.
	conn       *net.UDPConn
	remoteAddr *net.UDPAddr
	localAddr  *net.UDPAddr
	session    uint32
	config     *Config

	// Sequence state. sendDataPacket advances sendSeq atomically. recvSeq is
	// the last sequence number delivered in order. lastAckSent is the latest
	// sequence number passed to sendAck.
	sendSeq     uint32
	recvSeq     uint16
	lastAckSent uint16

	// Queues between the public API, the receive path and writeLoop.
	inbound, outbound chan []byte
	control           chan *packet
	done              chan struct{} // closed once, by Close
	closeOnce         sync.Once

	// Reliability: unacknowledged outgoing packets keyed by sequence number,
	// and received packets that arrived ahead of recvSeq+1.
	pending      map[uint16]*pendingPacket
	pendingMutex sync.RWMutex
	outOfOrder   map[uint16][]byte
	oooMutex     sync.RWMutex

	// windowSize is copied from config.WindowSize by newStream.
	windowSize uint16

	// Deadlines. Each holds a time.Time; a zero time means no deadline.
	readDeadline, writeDeadline atomic.Value

	// lastActivity is refreshed when a write is queued or a packet is
	// received, and is read by keepAliveLoop.
	lastActivity  time.Time
	activityMutex sync.RWMutex
}
// newStream builds the stream for one session with remoteAddr over conn. It
// then starts the writeLoop, retransmitLoop and keepAliveLoop goroutines, each
// of which returns once the stream is closed.
//
// conn.LocalAddr() must be a *net.UDPAddr; otherwise the type assertion
// panics.
func newStream(conn *net.UDPConn, remoteAddr *net.UDPAddr, session uint32, config *Config) *stream {
	s := new(stream)
	s.conn = conn
	s.remoteAddr = remoteAddr
	s.localAddr = conn.LocalAddr().(*net.UDPAddr)
	s.session = session
	s.config = config
	s.windowSize = config.WindowSize

	s.inbound = make(chan []byte, 256)
	s.outbound = make(chan []byte, 256)
	s.control = make(chan *packet, 64)
	s.done = make(chan struct{})

	s.pending = make(map[uint16]*pendingPacket)
	s.outOfOrder = make(map[uint16][]byte)
	s.lastActivity = time.Now()

	for _, loop := range []func(){s.writeLoop, s.retransmitLoop, s.keepAliveLoop} {
		go loop()
	}
	return s
}
// Read implements Conn.Read. It blocks until a delivered payload is
// available, the stream is closed, or the read deadline (sampled when Read
// starts) passes.
//
// Each call consumes exactly one payload. If b is shorter than that payload,
// the first len(b) bytes are copied and an error is returned.
// NOTE(review): the rest of that payload is discarded, not kept for the next
// Read.
func (s *stream) Read(b []byte) (n int, err error) {
	ctx, cancel := s.getReadDeadlineContext()
	defer cancel()

	var payload []byte
	select {
	case <-s.done:
		return 0, fmt.Errorf("connection closed")
	case <-ctx.Done():
		return 0, fmt.Errorf("read timeout")
	case payload = <-s.inbound:
	}

	n = copy(b, payload)
	if n >= len(payload) {
		return n, nil
	}
	return n, fmt.Errorf("buffer too small: need %d bytes, got %d", len(payload), len(b))
}
// Write implements Conn.Write.
//
// b is split into chunks of at most config.MTU-15 bytes (15 bytes are
// reserved for the packet header). Each chunk is queued as its own data
// packet through writePacket. Write returns the number of bytes queued; on
// error, that count covers only the chunks queued before the failure.
//
// An MTU that leaves no room for payload is rejected with an error. Before
// this check, it caused an endless loop of empty packets or a negative-slice
// panic.
func (s *stream) Write(b []byte) (n int, err error) {
	const headerOverhead = 15 // bytes of each datagram reserved for the packet header

	if len(b) == 0 {
		return 0, nil
	}
	chunk := s.config.MTU - headerOverhead
	if chunk <= 0 {
		return 0, fmt.Errorf("mtu %d too small: must exceed %d header bytes", s.config.MTU, headerOverhead)
	}

	for n < len(b) {
		end := n + chunk
		if end > len(b) {
			end = len(b)
		}
		written, werr := s.writePacket(b[n:end])
		n += written
		if werr != nil {
			return n, werr
		}
	}
	return n, nil
}
// writePacket queues one packet's worth of payload for writeLoop. It blocks
// until the payload is queued, the stream is closed, or the write deadline
// passes. On success it returns len(data) and records activity for the
// keep-alive logic.
//
// data is copied before it is queued. The queued slice is marshalled later
// by writeLoop and kept in s.pending for retransmission long after Write has
// returned, while io.Writer callers are free to reuse their buffer as soon as
// Write returns.
func (s *stream) writePacket(data []byte) (int, error) {
	ctx, cancel := s.getWriteDeadlineContext()
	defer cancel()

	payload := make([]byte, len(data))
	copy(payload, data)

	select {
	case s.outbound <- payload:
		s.updateActivity()
		return len(data), nil
	case <-s.done:
		return 0, fmt.Errorf("connection closed")
	case <-ctx.Done():
		return 0, fmt.Errorf("write timeout")
	}
}
// writeLoop is the goroutine (started once by newStream) that drains
// s.outbound and s.control until the stream is closed. Payloads from
// s.outbound are sent reliably via sendDataPacket; packets from s.control are
// sent as-is via sendControlPacket.
//
// writeLoop must not close s.outbound. It is a receiver, and writePacket may
// still be sending on the channel concurrently. A send on a closed channel
// panics even inside a select that also watches s.done.
func (s *stream) writeLoop() {
	for {
		select {
		case data := <-s.outbound:
			s.sendDataPacket(data)
		case ctrl := <-s.control:
			s.sendControlPacket(ctrl)
		case <-s.done:
			return
		}
	}
}
// sendDataPacket assigns the next outgoing sequence number to data and
// records the packet in s.pending, where checkRetransmissions resends it
// until handleAckPacket removes it. It then transmits the packet. It is
// called from writeLoop.
//
// Sequence numbers start at 1. The receiver's recvSeq starts at 0 and only
// recvSeq+1 is delivered next, so a first packet numbered 0 would be ACKed
// but then discarded as a duplicate.
//
// NOTE(review): s.lastAckSent is written by sendAck on the receive path
// (readLoop's goroutine) without synchronization, so this read races with it.
func (s *stream) sendDataPacket(data []byte) {
	seq := uint16(atomic.AddUint32(&s.sendSeq, 1))
	pkt := &packet{
		Type:     PacketTypeData,
		Sequence: seq,
		Ack:      s.lastAckSent,
		Session:  s.session,
		Data:     data,
	}

	s.pendingMutex.Lock()
	s.pending[seq] = &pendingPacket{
		packet:    pkt,
		timestamp: time.Now(),
		attempts:  0,
	}
	s.pendingMutex.Unlock()

	s.sendRawPacket(pkt)
}
// sendControlPacket tags pkt with this stream's session id and writes it
// right away. Callers use it for ACKs, keep-alives, the session handshake
// and disconnects. Unlike data packets, it is not recorded for retransmission.
func (s *stream) sendControlPacket(pkt *packet) {
	stamped := pkt
	stamped.Session = s.session
	s.sendRawPacket(stamped)
}
// sendRawPacket marshals pkt and writes it to the remote peer.
//
// If s.conn is a connected socket (for example, one from net.DialUDP),
// WriteToUDP would always fail with net.ErrWriteToConnected, so Write is used
// in that case. Send errors are ignored on purpose: data packets stay in
// s.pending and are retransmitted, and control packets are best-effort.
func (s *stream) sendRawPacket(pkt *packet) {
	data := pkt.Marshal()
	if s.conn.RemoteAddr() != nil {
		_, _ = s.conn.Write(data)
		return
	}
	_, _ = s.conn.WriteToUDP(data, s.remoteAddr)
}
// handlePacket dispatches one decoded packet addressed to this stream. Every
// packet first counts as activity for the keep-alive logic. Packet types
// without a case here are ignored.
func (s *stream) handlePacket(pkt *packet) {
	s.updateActivity()

	switch pkt.Type {
	case PacketTypeDisconnect:
		// The peer is leaving; tear down this side as well.
		s.Close()
	case PacketTypeKeepAlive:
		// Answer with an ACK so the peer also observes traffic.
		s.sendAck(pkt.Sequence)
	case PacketTypeAck:
		s.handleAckPacket(pkt)
	case PacketTypeData:
		s.handleDataPacket(pkt)
	}
}
// handleDataPacket processes an incoming data packet on the receive path.
//
// Sequence numbers are uint16 and wrap around. A packet is therefore
// classified by its signed 16-bit distance from the next expected number,
// recvSeq+1 (serial number arithmetic, as in RFC 1982). A plain `>`
// comparison would treat every packet after a wrap as a duplicate.
//
//   - distance 0: delivered now, followed by any buffered successors;
//   - distance > 0: buffered in outOfOrder until the gap is filled;
//   - distance < 0: a duplicate of something already delivered, re-ACKed in
//     case the earlier ACK was lost.
//
// An ACK is sent only for a packet that is accepted. A future packet outside
// the receive window is dropped without an ACK. The sender therefore keeps it
// in its pending set and retransmits it later, and the out-of-order buffer
// cannot be grown without bound by a misbehaving peer.
// NOTE(review): a zero windowSize is treated as "no limit"; confirm this
// against Config.WindowSize semantics.
func (s *stream) handleDataPacket(pkt *packet) {
	expectedSeq := s.recvSeq + 1

	switch dist := int16(pkt.Sequence - expectedSeq); {
	case dist == 0:
		s.sendAck(pkt.Sequence)
		s.deliverData(pkt.Data)
		s.recvSeq = pkt.Sequence
		s.processOutOfOrder()
	case dist > 0:
		if s.windowSize != 0 && int(dist) >= int(s.windowSize) {
			return // beyond the window: let the sender retransmit it later
		}
		s.oooMutex.Lock()
		s.outOfOrder[pkt.Sequence] = pkt.Data
		s.oooMutex.Unlock()
		s.sendAck(pkt.Sequence)
	default:
		s.sendAck(pkt.Sequence)
	}
}
// processOutOfOrder drains the out-of-order buffer. As long as the packet
// numbered recvSeq+1 is buffered, it is delivered, recvSeq advances to it, and
// it is removed from the buffer. The buffer lock is held for the whole drain.
func (s *stream) processOutOfOrder() {
	s.oooMutex.Lock()
	defer s.oooMutex.Unlock()

	for {
		next := s.recvSeq + 1
		data, buffered := s.outOfOrder[next]
		if !buffered {
			return
		}
		s.deliverData(data)
		s.recvSeq = next
		delete(s.outOfOrder, next)
	}
}
// deliverData hands a payload to Read through s.inbound without blocking.
// If the stream is closed, or the inbound queue is full, the payload is
// dropped.
//
// NOTE(review): by the time this runs, the packet has already been ACKed, so a
// payload dropped because the queue is full is never retransmitted and is lost.
func (s *stream) deliverData(data []byte) {
	select {
	case <-s.done:
	case s.inbound <- data:
	default:
		// Inbound queue full: the payload is dropped rather than blocking the receive path.
	}
}
// handleAckPacket stops retransmission of the data packet whose sequence
// number the peer acknowledged. ACKs for sequence numbers that are not
// pending are ignored.
func (s *stream) handleAckPacket(pkt *packet) {
	s.pendingMutex.Lock()
	delete(s.pending, pkt.Sequence) // no-op when the sequence is not pending
	s.pendingMutex.Unlock()
}
// sendAck records seq as the most recent ACK and queues an ACK packet for
// writeLoop. The send is non-blocking: if the stream is closed or the control
// queue is full, the ACK is dropped and the peer will retransmit.
//
// NOTE(review): lastAckSent is written here on the receive path and read by
// sendDataPacket on writeLoop's goroutine without synchronization.
func (s *stream) sendAck(seq uint16) {
	s.lastAckSent = seq
	ack := &packet{Type: PacketTypeAck, Sequence: seq, Ack: seq}

	select {
	case <-s.done:
	case s.control <- ack:
	default:
	}
}
// retransmitLoop runs checkRetransmissions once per second until the stream
// is closed.
func (s *stream) retransmitLoop() {
	tick := time.NewTicker(time.Second)
	defer tick.Stop()

	for {
		select {
		case <-s.done:
			return
		case <-tick.C:
		}
		s.checkRetransmissions()
	}
}
// checkRetransmissions resends every pending data packet whose last
// transmission is older than RetransmitTimeout and bumps its attempt count.
// If a timed-out packet has already been resent config.RetryAttempts times,
// it is dropped, the stream is closed asynchronously, and the scan stops.
//
// NOTE(review): the UDP writes happen while pendingMutex is held.
func (s *stream) checkRetransmissions() {
	now := time.Now()
	s.pendingMutex.Lock()
	defer s.pendingMutex.Unlock()

	for seq, p := range s.pending {
		if now.Sub(p.timestamp) <= RetransmitTimeout {
			continue
		}
		if p.attempts >= s.config.RetryAttempts {
			delete(s.pending, seq)
			go s.Close()
			return
		}
		p.attempts++
		p.timestamp = now
		s.sendRawPacket(p.packet)
	}
}
// keepAliveLoop wakes every KeepAliveInterval. If nothing has been sent or
// received for longer than KeepAliveInterval, it queues a keep-alive packet on
// the control channel, blocking until writeLoop accepts it or the stream
// closes. It returns once the stream is closed.
//
// NOTE(review): the keep-alive carries Sequence 0 and the peer answers with
// ACK 0, which handleAckPacket would treat as acknowledging a pending data
// packet numbered 0.
func (s *stream) keepAliveLoop() {
	tick := time.NewTicker(KeepAliveInterval)
	defer tick.Stop()

	for {
		select {
		case <-s.done:
			return
		case <-tick.C:
		}

		s.activityMutex.RLock()
		last := s.lastActivity
		s.activityMutex.RUnlock()
		if time.Since(last) <= KeepAliveInterval {
			continue
		}

		select {
		case s.control <- &packet{Type: PacketTypeKeepAlive}:
		case <-s.done:
			return
		}
	}
}
// updateActivity stamps the current time as the stream's last activity. It
// is safe to call from any goroutine.
func (s *stream) updateActivity() {
	s.activityMutex.Lock()
	defer s.activityMutex.Unlock()
	s.lastActivity = time.Now()
}
// Close implements Conn.Close. Only the first call has an effect, and every
// call returns nil. It does not close the underlying socket, which may be
// shared with other streams.
//
// The disconnect packet is written to the peer synchronously, before done is
// closed. Queuing it on s.control could lose it: writeLoop may see the closed
// done channel first and exit without draining control, or the queue may
// already be full.
func (s *stream) Close() error {
	s.closeOnce.Do(func() {
		s.sendControlPacket(&packet{Type: PacketTypeDisconnect})
		close(s.done)
	})
	return nil
}
// LocalAddr implements Conn.LocalAddr. It returns the local address of the
// socket that was captured when the stream was created.
func (s *stream) LocalAddr() net.Addr {
	return s.localAddr
}

// RemoteAddr implements Conn.RemoteAddr. It returns the peer's address.
func (s *stream) RemoteAddr() net.Addr {
	return s.remoteAddr
}
// SetReadDeadline implements Conn.SetReadDeadline. A zero t removes the
// deadline. The value is read when a Read starts, so changing it does not
// affect a Read already in progress. It always returns nil.
func (s *stream) SetReadDeadline(t time.Time) error {
	deadline := &s.readDeadline
	deadline.Store(t)
	return nil
}

// SetWriteDeadline implements Conn.SetWriteDeadline with the same semantics
// as SetReadDeadline, applied to each queued write.
func (s *stream) SetWriteDeadline(t time.Time) error {
	deadline := &s.writeDeadline
	deadline.Store(t)
	return nil
}
// getReadDeadlineContext returns a context that expires at the current read
// deadline. The caller must invoke the returned cancel function.
func (s *stream) getReadDeadlineContext() (context.Context, context.CancelFunc) {
	return deadlineContext(&s.readDeadline)
}

// getWriteDeadlineContext returns a context that expires at the current write
// deadline. The caller must invoke the returned cancel function.
func (s *stream) getWriteDeadlineContext() (context.Context, context.CancelFunc) {
	return deadlineContext(&s.writeDeadline)
}

// deadlineContext converts the time.Time stored in v into a context. If v is
// unset or holds the zero time, it returns a background context that never
// expires and a no-op cancel.
func deadlineContext(v *atomic.Value) (context.Context, context.CancelFunc) {
	deadline, ok := v.Load().(time.Time)
	if !ok || deadline.IsZero() {
		return context.Background(), func() {}
	}
	return context.WithDeadline(context.Background(), deadline)
}
// listener is the Listener implementation. All sessions share one UDP
// socket, and incoming datagrams are routed to streams by remote address.
type listener struct {
	conn   *net.UDPConn
	config *Config

	// streams maps addr.String() of each peer to its session. It is guarded
	// by mutex.
	streams map[string]*stream
	mutex   sync.RWMutex

	incoming chan *stream  // new sessions waiting for Accept
	done     chan struct{} // closed by Close
}
// Listen binds a UDP socket to address and returns a Listener for reliable
// UDP sessions on it. A nil config selects DefaultConfig(). A background
// goroutine (readLoop) receives datagrams until the listener is closed.
func Listen(address string, config *Config) (Listener, error) {
	cfg := config
	if cfg == nil {
		cfg = DefaultConfig()
	}

	laddr, err := net.ResolveUDPAddr("udp", address)
	if err != nil {
		return nil, err
	}
	sock, err := net.ListenUDP("udp", laddr)
	if err != nil {
		return nil, err
	}

	l := &listener{
		conn:     sock,
		config:   cfg,
		streams:  make(map[string]*stream),
		incoming: make(chan *stream, 16),
		done:     make(chan struct{}),
	}
	go l.readLoop()
	return l, nil
}
// readLoop receives datagrams on the shared socket, decodes them, and routes
// each one through handlePacket. Datagrams that fail to read or decode are
// skipped. It returns once the listener is closed; Close closes done before
// the socket, so the read error that follows is not retried.
func (l *listener) readLoop() {
	// A peer with the same config sends datagrams of at most MTU bytes (Write
	// reserves header space out of MTU). Size the buffer so such datagrams are
	// never truncated, keeping 2048 as the floor.
	size := 2048
	if l.config.MTU > size {
		size = l.config.MTU
	}
	buf := make([]byte, size)

	for {
		select {
		case <-l.done:
			return
		default:
		}

		n, addr, err := l.conn.ReadFromUDP(buf)
		if err != nil {
			continue
		}

		// buf is reused for the next datagram, but a decoded payload can
		// outlive this iteration (queued for Read or held in the out-of-order
		// buffer). Decode from a private copy so the payload cannot be
		// overwritten, even if Unmarshal slices its input instead of copying.
		datagram := make([]byte, n)
		copy(datagram, buf[:n])

		pkt := &packet{}
		if err := pkt.Unmarshal(datagram); err != nil {
			continue
		}
		l.handlePacket(pkt, addr)
	}
}
// handlePacket routes one decoded datagram from addr.
//
// Packets from an address with a live session go to that session's stream.
// From any other address, only a session request is honoured. It creates and
// registers a stream, answers with a session response carrying the requested
// session id, and hands the stream to Accept. Everything else is dropped.
//
// A stream is unregistered as soon as it closes (locally, on peer disconnect,
// or after exhausting retransmissions). Otherwise the map would grow forever,
// and that address could never start a new session because its packets would
// keep going to the dead stream.
//
// NOTE(review): handing the stream to Accept blocks readLoop, and so every
// other session, while the incoming queue is full and nobody calls Accept.
func (l *listener) handlePacket(pkt *packet, addr *net.UDPAddr) {
	key := addr.String()

	l.mutex.RLock()
	st, exists := l.streams[key]
	l.mutex.RUnlock()
	if exists {
		st.handlePacket(pkt)
		return
	}
	if pkt.Type != PacketTypeSessionRequest {
		return
	}

	session := pkt.Session

	l.mutex.Lock()
	// Checking done under the lock prevents registering a session after Close
	// has already closed every stream, which would leave it open forever.
	select {
	case <-l.done:
		l.mutex.Unlock()
		return
	default:
	}
	st = newStream(l.conn, addr, session, l.config)
	l.streams[key] = st
	l.mutex.Unlock()

	go func() {
		<-st.done
		l.mutex.Lock()
		if l.streams[key] == st {
			delete(l.streams, key)
		}
		l.mutex.Unlock()
	}()

	st.sendControlPacket(&packet{
		Type:    PacketTypeSessionResponse,
		Session: session,
	})

	select {
	case l.incoming <- st:
	case <-l.done:
	}
}
// Accept implements Listener.Accept. It blocks until readLoop has set up a
// new session or the listener is closed.
func (l *listener) Accept() (Conn, error) {
	select {
	case <-l.done:
		return nil, fmt.Errorf("listener closed")
	case s := <-l.incoming:
		return s, nil
	}
}
// Close implements Listener.Close. It stops readLoop and any blocked Accept,
// closes every registered session, and then closes the shared socket,
// returning that socket's Close error. Later calls do nothing and return nil.
//
// done is checked and closed while holding mutex, so repeated or concurrent
// Close calls cannot close it twice (which would panic).
func (l *listener) Close() error {
	l.mutex.Lock()
	defer l.mutex.Unlock()

	select {
	case <-l.done:
		return nil
	default:
	}
	close(l.done)

	for _, s := range l.streams {
		s.Close()
	}
	return l.conn.Close()
}
// Addr implements Listener.Addr. It reports the local address of the shared
// UDP socket.
func (l *listener) Addr() net.Addr {
	local := l.conn.LocalAddr()
	return local
}
// Dial opens a client session with the reliable UDP server at address. A nil
// config selects DefaultConfig().
//
// The client gets its own UDP socket on an ephemeral local port. A background
// goroutine (dialReadLoop) feeds datagrams from the server into the returned
// stream, and the socket is closed once the stream is closed. Dial sends a
// single session request and returns without waiting for the server's
// response.
func Dial(address string, config *Config) (Conn, error) {
	if config == nil {
		config = DefaultConfig()
	}
	raddr, err := net.ResolveUDPAddr("udp", address)
	if err != nil {
		return nil, err
	}

	// The socket is left unconnected so that stream writes through
	// (*net.UDPConn).WriteToUDP succeed. On a socket from net.DialUDP that
	// call fails with net.ErrWriteToConnected. Packets from other sources are
	// filtered out in dialReadLoop instead.
	conn, err := net.ListenUDP("udp", nil)
	if err != nil {
		return nil, err
	}

	// Nanosecond resolution keeps clients that dial within the same second
	// from being given the same session id.
	session := uint32(time.Now().UnixNano())
	s := newStream(conn, raddr, session, config)

	go dialReadLoop(conn, raddr, s)
	go func() {
		<-s.done
		_ = conn.Close()
	}()

	s.sendControlPacket(&packet{
		Type:    PacketTypeSessionRequest,
		Session: session,
	})
	return s, nil
}

// dialReadLoop decodes datagrams arriving on a Dial socket and passes those
// sent by remote to s. It returns when a read fails after s has been closed,
// which is when Dial closes conn.
func dialReadLoop(conn *net.UDPConn, remote *net.UDPAddr, s *stream) {
	size := 2048
	if s.config.MTU > size {
		size = s.config.MTU
	}
	buf := make([]byte, size)

	for {
		n, from, err := conn.ReadFromUDP(buf)
		if err != nil {
			select {
			case <-s.done:
				return
			default:
				continue
			}
		}
		// Only the server we dialed may drive this session.
		if from.Port != remote.Port || !from.IP.Equal(remote.IP) {
			continue
		}

		// Decode from a private copy: buf is reused, and the payload may
		// outlive this iteration if Unmarshal slices its input.
		datagram := make([]byte, n)
		copy(datagram, buf[:n])

		pkt := &packet{}
		if err := pkt.Unmarshal(datagram); err != nil {
			continue
		}
		s.handlePacket(pkt)
	}
}