eq2go/internal/udp/server.go
2025-07-21 23:51:35 -05:00

231 lines
5.4 KiB
Go

package udp
import (
"fmt"
"net"
"sync"
"time"
)
// Server manages multiple UDP connections and handles packet routing.
//
// A Server owns one listening UDP socket and keeps one Connection per remote
// address, keyed by the address string of the sender. Incoming datagrams are
// dispatched to the matching Connection, which is created on first contact.
type Server struct {
// conn is the listening socket; it is also passed to every Connection
// created by the server.
conn *net.UDPConn // Main UDP socket
// connections maps a remote address string (addr.String()) to its Connection.
connections map[string]*Connection // Active connections by address
// mutex guards the connections map. The accessor methods also read and
// write config while holding it.
mutex sync.RWMutex // Protects connections map
// handler is passed to each new Connection.
handler PacketHandler // Application packet handler
// running is set by Start and cleared by Stop; the receive loop and the
// connection manager exit once it is false.
running bool // Server running state
// config holds the settings chosen in NewServer; MaxConnections and Timeout
// can be changed later through the setter methods.
config Config // Server configuration
}
// NewServer resolves addr, binds a UDP socket to it, and returns a Server
// that is ready for Start.
//
// handler is stored and later passed to every Connection the server creates.
// config is optional: when omitted, DefaultConfig() is used; when several
// values are supplied, only the first is honored.
//
// An error is returned if addr cannot be resolved or the socket cannot be
// bound. Failing to apply the requested socket buffer size is not an error.
func NewServer(addr string, handler PacketHandler, config ...Config) (*Server, error) {
	settings := DefaultConfig()
	if len(config) != 0 {
		settings = config[0]
	}

	resolved, err := net.ResolveUDPAddr("udp", addr)
	if err != nil {
		return nil, fmt.Errorf("invalid UDP address %s: %w", addr, err)
	}

	socket, err := net.ListenUDP("udp", resolved)
	if err != nil {
		return nil, fmt.Errorf("failed to listen on %s: %w", addr, err)
	}

	// Buffer sizing is a best-effort performance tweak, so the results are
	// deliberately discarded and the server proceeds either way.
	if size := settings.BufferSize; size > 0 {
		_ = socket.SetReadBuffer(size)
		_ = socket.SetWriteBuffer(size)
	}

	srv := &Server{
		conn:        socket,
		connections: make(map[string]*Connection),
		handler:     handler,
		config:      settings,
	}
	return srv, nil
}
// Start marks the server as running, launches the background connection
// manager, and then blocks in the receive loop until the running flag is
// cleared (see Stop). It always returns nil.
//
// Each datagram is processed on its own goroutine. The payload is copied out
// of the shared 8192-byte read buffer first: the next ReadFromUDP call reuses
// that buffer, so handing a sub-slice of it to a goroutine would let later
// packets overwrite data that is still being processed.
//
// Read errors are printed while the server is running and otherwise ignored;
// the loop then re-checks the running flag.
func (s *Server) Start() error {
	// The running flag is accessed under s.mutex, matching how GetStats
	// reads it, so it is never touched concurrently without synchronization.
	s.mutex.Lock()
	s.running = true
	s.mutex.Unlock()

	isRunning := func() bool {
		s.mutex.RLock()
		defer s.mutex.RUnlock()
		return s.running
	}

	// Start background management tasks.
	go s.connectionManager()

	// Main packet receive loop.
	buffer := make([]byte, 8192)
	for isRunning() {
		n, addr, err := s.conn.ReadFromUDP(buffer)
		if err != nil {
			if isRunning() {
				fmt.Printf("UDP read error: %v\n", err)
			}
			continue
		}

		// Give the handler goroutine its own copy of the payload; buffer is
		// overwritten by the next read.
		packet := make([]byte, n)
		copy(packet, buffer[:n])

		// Handle the packet in a separate goroutine to avoid blocking reads.
		go s.handleIncomingPacket(packet, addr)
	}
	return nil
}
// Stop shuts the server down: it clears the running flag, closes every
// tracked connection, replaces the connection map with an empty one, and
// finally closes the UDP socket.
//
// The running flag is cleared while holding s.mutex so that the write is
// synchronized with readers that inspect it under the same lock (GetStats).
// Connections are closed before the socket, preserving the original order.
func (s *Server) Stop() {
	s.mutex.Lock()
	s.running = false
	for _, conn := range s.connections {
		conn.Close()
	}
	s.connections = make(map[string]*Connection)
	s.mutex.Unlock()

	// Closing the socket makes a blocked ReadFromUDP in Start return. Stop has
	// no error return, so a close failure is intentionally discarded.
	_ = s.conn.Close()
}
// handleIncomingPacket routes one received datagram to the Connection that
// belongs to its sender, creating that Connection on first contact.
//
// Empty datagrams are ignored. When the sender is unknown and the server
// already tracks MaxConnections connections, the packet is dropped silently.
// ProcessPacket is invoked after the map lock has been released.
func (s *Server) handleIncomingPacket(data []byte, addr *net.UDPAddr) {
	if len(data) == 0 {
		return
	}

	key := addr.String()

	s.mutex.Lock()
	peer, known := s.connections[key]
	if !known {
		// Refuse new peers once the configured capacity is reached.
		if len(s.connections) >= s.config.MaxConnections {
			s.mutex.Unlock()
			return
		}

		peer = NewConnection(addr, s.conn, s.handler, s.config)
		peer.StartRetransmitLoop()
		s.connections[key] = peer
	}
	s.mutex.Unlock()

	peer.ProcessPacket(data)
}
// connectionManager runs in its own goroutine (started by Start) and removes
// timed-out connections every five seconds.
//
// On each tick it first checks the running flag and returns once the server
// has been stopped, so the goroutine ends on the first tick after Stop. The
// flag is read under s.mutex, the same lock GetStats uses to read it, rather
// than being read unsynchronized from this background goroutine.
func (s *Server) connectionManager() {
	ticker := time.NewTicker(5 * time.Second)
	defer ticker.Stop()

	for range ticker.C {
		s.mutex.RLock()
		running := s.running
		s.mutex.RUnlock()
		if !running {
			return
		}

		// The read lock is released above because cleanup takes the write lock.
		s.cleanupTimedOutConnections()
	}
}
// cleanupTimedOutConnections walks the connection map under the write lock,
// closing and removing every connection whose IsTimedOut reports true.
func (s *Server) cleanupTimedOutConnections() {
	s.mutex.Lock()
	defer s.mutex.Unlock()

	for addr, c := range s.connections {
		if !c.IsTimedOut() {
			continue
		}
		c.Close()
		// Deleting the current key while ranging over a map is permitted in Go.
		delete(s.connections, addr)
	}
}
// GetConnectionCount reports how many connections the server currently
// tracks. The count is read under the read lock.
func (s *Server) GetConnectionCount() int {
	s.mutex.RLock()
	count := len(s.connections)
	s.mutex.RUnlock()
	return count
}
// GetConnection looks up the connection stored under addr, the remote
// address in string form. It returns nil when no entry exists for addr.
func (s *Server) GetConnection(addr string) *Connection {
	s.mutex.RLock()
	found := s.connections[addr]
	s.mutex.RUnlock()
	return found
}
// GetAllConnections returns a newly allocated slice holding every connection
// tracked at the time of the call. The slice is independent of the server's
// map, but its elements are the same *Connection values. Order is unspecified
// because it follows map iteration.
func (s *Server) GetAllConnections() []*Connection {
	s.mutex.RLock()
	defer s.mutex.RUnlock()

	snapshot := make([]*Connection, len(s.connections))
	i := 0
	for _, c := range s.connections {
		snapshot[i] = c
		i++
	}
	return snapshot
}
// BroadcastPacket sends packet to every connection whose state is
// StateEstablished. It works on a snapshot from GetAllConnections, so the
// server lock is not held while packets are being sent.
func (s *Server) BroadcastPacket(packet *ApplicationPacket) {
	for _, c := range s.GetAllConnections() {
		if c.GetState() != StateEstablished {
			continue
		}
		c.SendPacket(packet)
	}
}
// DisconnectClient closes and forgets the connection stored under addr.
// It returns true when such a connection existed and false otherwise.
func (s *Server) DisconnectClient(addr string) bool {
	s.mutex.Lock()
	defer s.mutex.Unlock()

	target, found := s.connections[addr]
	if !found {
		return false
	}

	target.Close()
	delete(s.connections, addr)
	return true
}
// ServerStats contains server runtime statistics.
//
// It is a plain value copied out of the Server by GetStats, so it reflects
// the state at the moment of that call and does not update afterwards.
type ServerStats struct {
ConnectionCount int // Current number of connections
MaxConnections int // Maximum allowed connections
Running bool // Whether server is running
Timeout time.Duration // Connection timeout setting
}
// GetStats captures the current connection count, the configured connection
// limit and timeout, and the running flag, all read under the read lock.
func (s *Server) GetStats() ServerStats {
	var stats ServerStats

	s.mutex.RLock()
	stats.ConnectionCount = len(s.connections)
	stats.MaxConnections = s.config.MaxConnections
	stats.Running = s.running
	stats.Timeout = s.config.Timeout
	s.mutex.RUnlock()

	return stats
}
// SetConnectionLimit replaces the maximum number of connections in the
// server's configuration. The new limit is checked when an unknown sender
// would create a connection; existing connections are not touched here.
func (s *Server) SetConnectionLimit(limit int) {
	s.mutex.Lock()
	s.config.MaxConnections = limit
	s.mutex.Unlock()
}
// SetTimeout replaces the connection timeout stored in the server's
// configuration. The write is performed under the write lock.
func (s *Server) SetTimeout(timeout time.Duration) {
	s.mutex.Lock()
	s.config.Timeout = timeout
	s.mutex.Unlock()
}