package udp

import (
	"io"
	"net"
	"sync"
	"time"
)

// Middleware processes packets flowing through a middleware-wrapped
// connection.
//
// ProcessOutbound is called for each packet passed to Write. It forwards the
// (possibly transformed) bytes to the next layer toward the network by
// calling next. ProcessInbound is called for each packet returned by Read. It
// forwards the (possibly transformed) bytes to the next layer toward the
// caller by calling next. Close releases whatever the middleware holds.
type Middleware interface {
	ProcessOutbound(data []byte, next func([]byte) (int, error)) (int, error)
	ProcessInbound(data []byte, next func([]byte) (int, error)) (int, error)
	Close() error
}

// Builder configures a UDP connection or listener together with a stack of
// middleware, using a fluent API.
type Builder struct {
	address     string
	config      *Config
	middlewares []Middleware
}

// NewBuilder returns a Builder initialised with DefaultConfig and no
// middleware.
func NewBuilder() *Builder {
	return &Builder{
		config: DefaultConfig(),
	}
}

// Address sets the address passed to Listen or Dial and returns b.
func (b *Builder) Address(addr string) *Builder {
	b.address = addr
	return b
}

// Config replaces the UDP configuration and returns b.
func (b *Builder) Config(config *Config) *Builder {
	b.config = config
	return b
}

// Use appends middleware to the stack and returns b. Middleware added first
// is outermost: it sees outbound packets first and inbound packets last.
func (b *Builder) Use(middleware Middleware) *Builder {
	b.middlewares = append(b.middlewares, middleware)
	return b
}

// Listen starts a listener on the configured address. Every accepted
// connection is wrapped with the middleware stack. All accepted connections
// share the same middleware instances, and the listener closes them when
// the listener itself is closed.
func (b *Builder) Listen() (Listener, error) {
	listener, err := Listen(b.address, b.config)
	if err != nil {
		return nil, err
	}
	return &middlewareListener{
		listener:    listener,
		middlewares: b.middlewares,
	}, nil
}

// Dial opens a client connection to the configured address, wrapped with
// the middleware stack. The returned connection owns the middleware and
// closes it when the connection is closed.
func (b *Builder) Dial() (Conn, error) {
	conn, err := Dial(b.address, b.config)
	if err != nil {
		return nil, err
	}
	return newMiddlewareConn(conn, b.middlewares), nil
}

// middlewareConn wraps a Conn. Every Write passes through the middleware
// stack from first to last, and every Read passes through it from last to
// first.
type middlewareConn struct {
	conn        Conn
	middlewares []Middleware
	// ownsMiddlewares reports whether Close should also close the
	// middleware. It is false for connections accepted from a listener,
	// because those connections all share one middleware stack.
	ownsMiddlewares bool
	closeOnce       sync.Once
}

// newMiddlewareConn wraps conn with middlewares. The returned connection
// owns the middleware and closes it on Close.
func newMiddlewareConn(conn Conn, middlewares []Middleware) *middlewareConn {
	return &middlewareConn{
		conn:            conn,
		middlewares:     middlewares,
		ownsMiddlewares: true,
	}
}

// Write sends data through the middleware stack, starting with the first
// middleware, and finally writes the result to the underlying connection.
// The returned count is whatever the first middleware reports. With no
// middleware, it is the underlying connection's count.
func (m *middlewareConn) Write(data []byte) (int, error) {
	return m.processOutbound(0, data)
}

// Read reads one packet from the underlying connection into data and runs it
// through the middleware stack from last to first. It then copies the final
// processed bytes back into data. If the processed packet does not fit in
// data, it is truncated and io.ErrShortBuffer is returned.
func (m *middlewareConn) Read(data []byte) (int, error) {
	n, err := m.conn.Read(data)
	if err != nil {
		return n, err
	}
	return m.processInbound(len(m.middlewares)-1, data, data[:n])
}

// processOutbound hands data to middlewares[index]. Its next callback
// continues with the following middleware. Once the stack is exhausted, the
// bytes are written to the underlying connection.
func (m *middlewareConn) processOutbound(index int, data []byte) (int, error) {
	if index >= len(m.middlewares) {
		return m.conn.Write(data)
	}
	return m.middlewares[index].ProcessOutbound(data, func(processed []byte) (int, error) {
		return m.processOutbound(index+1, processed)
	})
}

// processInbound hands data to middlewares[index]. Its next callback
// continues with the preceding middleware. Once the stack is exhausted, the
// processed bytes are copied into dst, the caller's read buffer. The copy is
// needed because a middleware may pass a different slice to next instead of
// transforming dst in place. copy tolerates dst and data overlapping.
func (m *middlewareConn) processInbound(index int, dst, data []byte) (int, error) {
	if index < 0 {
		n := copy(dst, data)
		if n < len(data) {
			return n, io.ErrShortBuffer
		}
		return n, nil
	}
	return m.middlewares[index].ProcessInbound(data, func(processed []byte) (int, error) {
		return m.processInbound(index-1, dst, processed)
	})
}

// Close closes the middleware on the first call, when this connection owns
// it, and then closes the underlying connection. The underlying connection's
// Close error takes precedence. Otherwise the first middleware Close error,
// if any, is returned. Later calls only delegate to the underlying
// connection's Close.
func (m *middlewareConn) Close() error {
	var middlewareErr error
	m.closeOnce.Do(func() {
		if m.ownsMiddlewares {
			middlewareErr = closeMiddlewares(m.middlewares)
		}
	})
	if err := m.conn.Close(); err != nil {
		return err
	}
	return middlewareErr
}

// closeMiddlewares closes every middleware in the stack, continuing past
// failures, and returns the first error encountered.
func closeMiddlewares(middlewares []Middleware) error {
	var firstErr error
	for _, middleware := range middlewares {
		if err := middleware.Close(); err != nil && firstErr == nil {
			firstErr = err
		}
	}
	return firstErr
}

// LocalAddr returns the underlying connection's local address.
func (m *middlewareConn) LocalAddr() net.Addr {
	return m.conn.LocalAddr()
}

// RemoteAddr returns the underlying connection's remote address.
func (m *middlewareConn) RemoteAddr() net.Addr {
	return m.conn.RemoteAddr()
}

// SetReadDeadline forwards to the underlying connection.
func (m *middlewareConn) SetReadDeadline(t time.Time) error {
	return m.conn.SetReadDeadline(t)
}

// SetWriteDeadline forwards to the underlying connection.
func (m *middlewareConn) SetWriteDeadline(t time.Time) error {
	return m.conn.SetWriteDeadline(t)
}

// middlewareListener wraps a Listener so that every accepted connection uses
// the same middleware stack. The listener, not the individual connections,
// is responsible for closing that stack.
type middlewareListener struct {
	listener    Listener
	middlewares []Middleware
	closeOnce   sync.Once
}

// Accept waits for the next connection and wraps it with the shared
// middleware stack. Closing the returned connection does not close the
// middleware.
func (l *middlewareListener) Accept() (Conn, error) {
	conn, err := l.listener.Accept()
	if err != nil {
		return nil, err
	}
	c := newMiddlewareConn(conn, l.middlewares)
	// The stack is shared with every other accepted connection, so closing
	// one connection must not tear it down for the rest.
	c.ownsMiddlewares = false
	return c, nil
}

// Close closes the underlying listener and then, on the first call only,
// the shared middleware stack. The listener's Close error takes precedence
// over a middleware Close error. Connections accepted earlier and still open
// will be using closed middleware afterwards, so close them first if the
// middleware cannot tolerate that.
func (l *middlewareListener) Close() error {
	listenerErr := l.listener.Close()
	var middlewareErr error
	l.closeOnce.Do(func() {
		middlewareErr = closeMiddlewares(l.middlewares)
	})
	if listenerErr != nil {
		return listenerErr
	}
	return middlewareErr
}

// Addr returns the underlying listener's address.
func (l *middlewareListener) Addr() net.Addr {
	return l.listener.Addr()
}