eq2go/internal/udp/server.go
2025-07-21 23:18:39 -05:00

256 lines
6.3 KiB
Go

package udp
import (
	"errors"
	"fmt"
	"net"
	"sync"
	"time"
)
// Server owns one UDP socket and demultiplexes incoming datagrams to
// per-peer Connection values, keyed by the peer address string.
type Server struct {
	// conn is the listening UDP socket; it is also handed to every
	// Connection the server creates.
	conn *net.UDPConn
	// connections maps a peer's addr.String() to its Connection.
	connections map[string]*Connection
	// mutex guards connections and the settings below in the methods that
	// take it.
	mutex sync.RWMutex
	// handler is passed to every new Connection for application packets.
	handler PacketHandler
	// running tells the receive loop and the connection manager to keep going.
	// NOTE(review): Start and connectionManager read this without holding
	// mutex, so these accesses race with Stop; an atomic flag would fix that.
	running bool

	// maxConnections caps how many peers may be tracked at once.
	maxConnections int
	// timeout is the duration handed to Connection.IsTimedOut on each sweep.
	timeout time.Duration
}
// PacketHandler is the application hook that receives decoded
// application-level packets, together with the Connection they arrived on.
type PacketHandler interface {
	// HandlePacket is invoked with the originating connection and the packet.
	HandlePacket(conn *Connection, packet *ApplicationPacket)
}
// ServerConfig bundles the tunables accepted by NewServerWithConfig.
// DefaultServerConfig returns a ready-to-use instance.
type ServerConfig struct {
	// MaxConnections is the maximum number of concurrent connections
	// (default: 1000).
	MaxConnections int
	// Timeout is the connection timeout (default: 45s).
	Timeout time.Duration
	// BufferSize is the UDP socket buffer size in bytes (default: 8192).
	BufferSize int
}
// DefaultServerConfig returns the configuration NewServer uses:
// 1000 maximum connections, a 45 second connection timeout, and an
// 8192-byte socket buffer size.
func DefaultServerConfig() ServerConfig {
	var cfg ServerConfig
	cfg.MaxConnections = 1000
	cfg.Timeout = 45 * time.Second
	cfg.BufferSize = 8192
	return cfg
}
// NewServer binds a UDP server to addr using DefaultServerConfig. It is
// shorthand for NewServerWithConfig(addr, handler, DefaultServerConfig()).
func NewServer(addr string, handler PacketHandler) (*Server, error) {
	cfg := DefaultServerConfig()
	return NewServerWithConfig(addr, handler, cfg)
}
// NewServerWithConfig resolves addr, binds a UDP socket to it, and returns a
// Server built from config. The socket is open on return, but no packets are
// read until Start is called.
//
// A non-positive config.MaxConnections or config.Timeout falls back to the
// corresponding DefaultServerConfig value. Otherwise a zero-valued limit would
// make the server drop every new peer, and a zero timeout would be passed
// as-is to Connection.IsTimedOut on every sweep. A non-positive BufferSize
// leaves the OS socket buffers untouched.
//
// It returns an error if addr cannot be resolved or the socket cannot be bound.
func NewServerWithConfig(addr string, handler PacketHandler, config ServerConfig) (*Server, error) {
	defaults := DefaultServerConfig()
	if config.MaxConnections <= 0 {
		config.MaxConnections = defaults.MaxConnections
	}
	if config.Timeout <= 0 {
		config.Timeout = defaults.Timeout
	}

	udpAddr, err := net.ResolveUDPAddr("udp", addr)
	if err != nil {
		return nil, fmt.Errorf("invalid UDP address %s: %w", addr, err)
	}
	conn, err := net.ListenUDP("udp", udpAddr)
	if err != nil {
		return nil, fmt.Errorf("failed to listen on %s: %w", addr, err)
	}

	// Socket buffer sizing is best-effort tuning (the OS may refuse or clamp
	// the request), so a failure is reported but does not abort startup.
	if config.BufferSize > 0 {
		if err := conn.SetReadBuffer(config.BufferSize); err != nil {
			fmt.Printf("UDP set read buffer error: %v\n", err)
		}
		if err := conn.SetWriteBuffer(config.BufferSize); err != nil {
			fmt.Printf("UDP set write buffer error: %v\n", err)
		}
	}

	return &Server{
		conn:           conn,
		connections:    make(map[string]*Connection),
		handler:        handler,
		maxConnections: config.MaxConnections,
		timeout:        config.Timeout,
	}, nil
}
// Start marks the server as running, launches the background connection
// manager, and then blocks in the receive loop. The loop exits when the
// running flag is cleared (see Stop) or when the socket has been closed.
// Each datagram is copied out of the receive buffer and handed to its own
// goroutine.
//
// It returns nil when the loop ends.
// NOTE(review): s.running is read here without synchronization; this races
// with Stop under the race detector.
func (s *Server) Start() error {
	s.running = true

	// Start background management tasks.
	go s.connectionManager()

	// buffer is reused for every read, so each dispatched packet needs its
	// own copy. Otherwise the next ReadFromUDP would overwrite bytes that a
	// handler goroutine is still processing.
	buffer := make([]byte, 8192)
	for s.running {
		n, addr, err := s.conn.ReadFromUDP(buffer)
		if err != nil {
			// A closed socket never yields more data; retrying would spin
			// (and log) until the running flag happens to flip.
			if errors.Is(err, net.ErrClosed) {
				return nil
			}
			if s.running {
				fmt.Printf("UDP read error: %v\n", err)
			}
			continue
		}

		packet := make([]byte, n)
		copy(packet, buffer[:n])
		// Handle packet in separate goroutine to avoid blocking the reader.
		go s.handleIncomingPacket(packet, addr)
	}
	return nil
}
// Stop shuts the server down. Under the server lock it clears the running
// flag, closes every tracked connection, and empties the connection table.
// It then closes the UDP socket, which makes the blocked read in Start return.
func (s *Server) Stop() {
	s.mutex.Lock()
	// Clear the flag under the lock so that GetStats, which reads it under
	// RLock, does not race with this write.
	s.running = false
	// Connections are closed before the socket because each one was created
	// with the shared socket; presumably Close may still use it, so keep this
	// order — TODO confirm against Connection.Close.
	for _, conn := range s.connections {
		conn.Close()
	}
	s.connections = make(map[string]*Connection)
	s.mutex.Unlock()

	// Close UDP socket; this unblocks the receive loop.
	s.conn.Close()
}
// handleIncomingPacket routes one datagram from addr to that peer's
// Connection, creating, registering, and starting a new Connection on first
// contact. Empty datagrams are ignored. A packet that would need a new
// connection is dropped when the server is at its connection limit or is no
// longer running.
func (s *Server) handleIncomingPacket(data []byte, addr *net.UDPAddr) {
	if len(data) == 0 {
		return
	}
	connKey := addr.String()

	// Find or create the connection while holding the lock.
	s.mutex.Lock()
	conn, exists := s.connections[connKey]
	if !exists {
		// Stop empties the table once. A connection registered afterwards
		// would never be closed, and its retransmit loop (started below)
		// would presumably keep running. So refuse new peers once stopped,
		// and drop the packet if at capacity.
		if !s.running || len(s.connections) >= s.maxConnections {
			s.mutex.Unlock()
			return
		}
		conn = NewConnection(addr, s.conn, s.handler)
		conn.StartRetransmitLoop()
		s.connections[connKey] = conn
	}
	s.mutex.Unlock()

	// Process the packet outside the lock so other peers are not blocked.
	conn.ProcessPacket(data)
}
// connectionManager sweeps for timed-out connections every five seconds. It
// returns on the first tick after the running flag has been cleared.
func (s *Server) connectionManager() {
	const sweepInterval = 5 * time.Second

	tick := time.NewTicker(sweepInterval)
	defer tick.Stop()

	for {
		<-tick.C
		if !s.running {
			return
		}
		s.cleanupTimedOutConnections()
	}
}
// cleanupTimedOutConnections closes and removes every connection for which
// IsTimedOut reports true under the current timeout. The server lock is held
// for the entire sweep, including the Close calls.
func (s *Server) cleanupTimedOutConnections() {
	s.mutex.Lock()
	defer s.mutex.Unlock()

	// Deleting the current key while ranging over a map is allowed in Go.
	for addrKey, c := range s.connections {
		if !c.IsTimedOut(s.timeout) {
			continue
		}
		c.Close()
		delete(s.connections, addrKey)
	}
}
// GetConnectionCount reports how many connections are currently tracked.
func (s *Server) GetConnectionCount() int {
	s.mutex.RLock()
	count := len(s.connections)
	s.mutex.RUnlock()
	return count
}
// GetConnection looks up the connection registered under addr (the peer's
// UDP address string) and returns nil if there is none.
func (s *Server) GetConnection(addr string) *Connection {
	s.mutex.RLock()
	found := s.connections[addr]
	s.mutex.RUnlock()
	return found
}
// GetAllConnections returns a freshly allocated slice holding every tracked
// connection at the moment of the call. Its order follows map iteration and
// is therefore unspecified.
func (s *Server) GetAllConnections() []*Connection {
	s.mutex.RLock()
	defer s.mutex.RUnlock()

	snapshot := make([]*Connection, len(s.connections))
	i := 0
	for _, c := range s.connections {
		snapshot[i] = c
		i++
	}
	return snapshot
}
// BroadcastPacket sends packet to every connection whose state is
// StateEstablished. It works from a snapshot, so the server lock is not held
// while sending.
func (s *Server) BroadcastPacket(packet *ApplicationPacket) {
	for _, c := range s.GetAllConnections() {
		if c.GetState() != StateEstablished {
			continue
		}
		c.SendPacket(packet)
	}
}
// DisconnectClient closes and removes the connection registered under addr.
// It reports whether such a connection existed.
func (s *Server) DisconnectClient(addr string) bool {
	s.mutex.Lock()
	defer s.mutex.Unlock()

	target, ok := s.connections[addr]
	if !ok {
		return false
	}
	target.Close()
	delete(s.connections, addr)
	return true
}
// GetStats captures the current connection count, the configured limit and
// timeout, and the running flag. All of these are read under the read lock.
func (s *Server) GetStats() ServerStats {
	s.mutex.RLock()
	defer s.mutex.RUnlock()

	var stats ServerStats
	stats.ConnectionCount = len(s.connections)
	stats.MaxConnections = s.maxConnections
	stats.Running = s.running
	stats.Timeout = s.timeout
	return stats
}
// ServerStats is a point-in-time snapshot of server state, as produced by
// Server.GetStats.
type ServerStats struct {
	// ConnectionCount is the number of connections tracked at snapshot time.
	ConnectionCount int
	// MaxConnections is the configured connection limit.
	MaxConnections int
	// Running mirrors the server's running flag.
	Running bool
	// Timeout is the configured connection timeout.
	Timeout time.Duration
}
// SetConnectionLimit replaces the maximum number of tracked connections.
// Existing connections are not affected; the limit applies to new peers.
func (s *Server) SetConnectionLimit(limit int) {
	s.mutex.Lock()
	s.maxConnections = limit
	s.mutex.Unlock()
}
// SetTimeout replaces the duration used when sweeping for timed-out
// connections.
func (s *Server) SetTimeout(timeout time.Duration) {
	s.mutex.Lock()
	s.timeout = timeout
	s.mutex.Unlock()
}