141 lines
3.4 KiB
Go
141 lines
3.4 KiB
Go
package udp
|
|
|
|
import (
|
|
"net"
|
|
"sync"
|
|
"time"
|
|
)
|
|
|
|
// Middleware transforms packets travelling between a caller and the
// underlying UDP connection.
//
// Middleware registered through Builder.Use runs in registration order on
// the outbound (write) path and in reverse registration order on the inbound
// (read) path. Each processing method receives the packet plus a next
// callback that continues the chain; a middleware forwards the (possibly
// transformed) packet by calling next, and its return values are what the
// previous stage sees.
type Middleware interface {
	// ProcessOutbound handles a packet on its way to the connection.
	ProcessOutbound(data []byte, next func([]byte) (int, error)) (int, error)

	// ProcessInbound handles a packet that was read from the connection.
	ProcessInbound(data []byte, next func([]byte) (int, error)) (int, error)

	// Close is invoked when a connection wrapping this middleware is closed.
	Close() error
}
|
|
|
|
// Builder for fluent middleware configuration
|
|
type Builder struct {
|
|
address string
|
|
config *Config
|
|
middlewares []Middleware
|
|
}
|
|
|
|
// NewBuilder creates a new connection builder
|
|
func NewBuilder() *Builder {
|
|
return &Builder{
|
|
config: DefaultConfig(),
|
|
}
|
|
}
|
|
|
|
// Address sets the connection address
|
|
func (b *Builder) Address(addr string) *Builder {
|
|
b.address = addr
|
|
return b
|
|
}
|
|
|
|
// Config sets the UDP configuration
|
|
func (b *Builder) Config(config *Config) *Builder {
|
|
b.config = config
|
|
return b
|
|
}
|
|
|
|
// Use adds middleware to the stack
|
|
func (b *Builder) Use(middleware Middleware) *Builder {
|
|
b.middlewares = append(b.middlewares, middleware)
|
|
return b
|
|
}
|
|
|
|
// Listen creates a listener with middleware
|
|
func (b *Builder) Listen() (Listener, error) {
|
|
listener, err := Listen(b.address, b.config)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
return &middlewareListener{listener, b.middlewares}, nil
|
|
}
|
|
|
|
// Dial creates a client connection with middleware
|
|
func (b *Builder) Dial() (Conn, error) {
|
|
conn, err := Dial(b.address, b.config)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
return newMiddlewareConn(conn, b.middlewares), nil
|
|
}
|
|
|
|
// middlewareConn wraps a connection with middleware stack
|
|
type middlewareConn struct {
|
|
conn Conn
|
|
middlewares []Middleware
|
|
closeOnce sync.Once
|
|
}
|
|
|
|
func newMiddlewareConn(conn Conn, middlewares []Middleware) *middlewareConn {
|
|
return &middlewareConn{
|
|
conn: conn,
|
|
middlewares: middlewares,
|
|
}
|
|
}
|
|
|
|
func (m *middlewareConn) Write(data []byte) (int, error) {
|
|
return m.processOutbound(0, data)
|
|
}
|
|
|
|
func (m *middlewareConn) Read(data []byte) (int, error) {
|
|
n, err := m.conn.Read(data)
|
|
if err != nil {
|
|
return n, err
|
|
}
|
|
return m.processInbound(len(m.middlewares)-1, data[:n])
|
|
}
|
|
|
|
func (m *middlewareConn) processOutbound(index int, data []byte) (int, error) {
|
|
if index >= len(m.middlewares) {
|
|
return m.conn.Write(data)
|
|
}
|
|
|
|
return m.middlewares[index].ProcessOutbound(data, func(processed []byte) (int, error) {
|
|
return m.processOutbound(index+1, processed)
|
|
})
|
|
}
|
|
|
|
func (m *middlewareConn) processInbound(index int, data []byte) (int, error) {
|
|
if index < 0 {
|
|
return len(data), nil
|
|
}
|
|
|
|
return m.middlewares[index].ProcessInbound(data, func(processed []byte) (int, error) {
|
|
return m.processInbound(index-1, processed)
|
|
})
|
|
}
|
|
|
|
func (m *middlewareConn) Close() error {
|
|
m.closeOnce.Do(func() {
|
|
for _, middleware := range m.middlewares {
|
|
middleware.Close()
|
|
}
|
|
})
|
|
return m.conn.Close()
|
|
}
|
|
|
|
func (m *middlewareConn) LocalAddr() net.Addr { return m.conn.LocalAddr() }
|
|
func (m *middlewareConn) RemoteAddr() net.Addr { return m.conn.RemoteAddr() }
|
|
func (m *middlewareConn) SetReadDeadline(t time.Time) error { return m.conn.SetReadDeadline(t) }
|
|
func (m *middlewareConn) SetWriteDeadline(t time.Time) error { return m.conn.SetWriteDeadline(t) }
|
|
|
|
type middlewareListener struct {
|
|
listener Listener
|
|
middlewares []Middleware
|
|
}
|
|
|
|
func (l *middlewareListener) Accept() (Conn, error) {
|
|
conn, err := l.listener.Accept()
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
return newMiddlewareConn(conn, l.middlewares), nil
|
|
}
|
|
|
|
func (l *middlewareListener) Close() error { return l.listener.Close() }
|
|
func (l *middlewareListener) Addr() net.Addr { return l.listener.Addr() }
|