// SPDX-FileCopyrightText: 2023 The Pion community <https://pion.ly>
// SPDX-License-Identifier: MIT

// Package udp implements DTLS-specific UDP networking primitives.
// NOTE: this package is an adaptation of pion/transport/udp that allows for
// routing datagrams based on identifiers other than the remote address. The
// primary use case for this functionality is routing based on DTLS connection
// IDs. In order to allow consumers of this package to treat connections as
// generic net.PacketConn, routing and identifier establishment are based on
// custom introspection of datagrams, rather than direct intervention by
// consumers. If possible, the updates made in this repository will be reflected
// back upstream. If not, it is likely that this will be moved to a public
// package in this repository.
//
// This package was migrated from pion/transport/udp at
// https://github.com/pion/transport/commit/6890c795c807a617c054149eee40a69d7fdfbfdb
package udp import ( idtlsnet dtlsnet ) const ( receiveMTU = 8192 defaultListenBacklog = 128 // same as Linux default ) // Typed errors. var ( ErrClosedListener = errors.New("udp: listener closed") ErrListenQueueExceeded = errors.New("udp: listen queue exceeded") ) // listener augments a connection-oriented Listener over a UDP PacketConn. type listener struct { pConn *net.UDPConn accepting atomic.Value // bool acceptCh chan *PacketConn doneCh chan struct{} doneOnce sync.Once acceptFilter func([]byte) bool datagramRouter func([]byte) (string, bool) connIdentifier func([]byte) (string, bool) connLock sync.Mutex conns map[string]*PacketConn connWG sync.WaitGroup readWG sync.WaitGroup errClose atomic.Value // error readDoneCh chan struct{} errRead atomic.Value // error } // Accept waits for and returns the next connection to the listener. func ( *listener) () (net.PacketConn, net.Addr, error) { select { case := <-.acceptCh: .connWG.Add(1) return , .raddr, nil case <-.readDoneCh: , := .errRead.Load().(error) return nil, nil, case <-.doneCh: return nil, nil, ErrClosedListener } } // Close closes the listener. // Any blocked Accept operations will be unblocked and return errors. func ( *listener) () error { var error .doneOnce.Do(func() { .accepting.Store(false) close(.doneCh) .connLock.Lock() // Close unaccepted connections : for { select { case := <-.acceptCh: close(.doneCh) // If we have an alternate identifier, remove it from the connection // map. if := .id.Load(); != nil { delete(.conns, .(string)) //nolint:forcetypeassert } // If we haven't already removed the remote address, remove it // from the connection map. if .rmraddr.Load() == nil { delete(.conns, .raddr.String()) .rmraddr.Store(true) } default: break } } := len(.conns) .connLock.Unlock() .connWG.Done() if == 0 { // Wait if this is the final connection. .readWG.Wait() if , := .errClose.Load().(error); { = } } else { = nil } }) return } // Addr returns the listener's network address. 
func ( *listener) () net.Addr { return .pConn.LocalAddr() } // ListenConfig stores options for listening to an address. type ListenConfig struct { // Backlog defines the maximum length of the queue of pending // connections. It is equivalent of the backlog argument of // POSIX listen function. // If a connection request arrives when the queue is full, // the request will be silently discarded, unlike TCP. // Set zero to use default value 128 which is same as Linux default. Backlog int // AcceptFilter determines whether the new conn should be made for // the incoming packet. If not set, any packet creates new conn. AcceptFilter func([]byte) bool // DatagramRouter routes an incoming datagram to a connection by extracting // an identifier from the its paylod DatagramRouter func([]byte) (string, bool) // ConnectionIdentifier extracts an identifier from an outgoing packet. If // the identifier is not already associated with the connection, it will be // added. ConnectionIdentifier func([]byte) (string, bool) } // Listen creates a new listener based on the ListenConfig. func ( *ListenConfig) ( string, *net.UDPAddr) (dtlsnet.PacketListener, error) { if .Backlog == 0 { .Backlog = defaultListenBacklog } , := net.ListenUDP(, ) if != nil { return nil, } := &listener{ pConn: , acceptCh: make(chan *PacketConn, .Backlog), conns: make(map[string]*PacketConn), doneCh: make(chan struct{}), acceptFilter: .AcceptFilter, datagramRouter: .DatagramRouter, connIdentifier: .ConnectionIdentifier, readDoneCh: make(chan struct{}), } .accepting.Store(true) .connWG.Add(1) .readWG.Add(2) // wait readLoop and Close execution routine go .readLoop() go func() { .connWG.Wait() if := .pConn.Close(); != nil { .errClose.Store() } .readWG.Done() }() return , nil } // Listen creates a new listener using default ListenConfig. 
func ( string, *net.UDPAddr) (dtlsnet.PacketListener, error) { return (&ListenConfig{}).Listen(, ) } // readLoop dispatches packets to the proper connection, creating a new one if // necessary, until all connections are closed. func ( *listener) () { defer .readWG.Done() defer close(.readDoneCh) := make([]byte, receiveMTU) for { , , := .pConn.ReadFrom() if != nil { .errRead.Store() return } , , := .getConn(, [:]) if != nil { continue } if { _, _ = .buffer.WriteTo([:], ) } } } // getConn gets an existing connection or creates a new one. func ( *listener) ( net.Addr, []byte) (*PacketConn, bool, error) { //nolint:cyclop .connLock.Lock() defer .connLock.Unlock() // If we have a custom resolver, use it. if .datagramRouter != nil { if , := .datagramRouter(); { if , := .conns[]; { return , true, nil } } } // If we don't have a custom resolver, or we were unable to find an // associated connection, fall back to remote address. , := .conns[.String()] if ! { if , := .accepting.Load().(bool); ! || ! { return nil, false, ErrClosedListener } if .acceptFilter != nil { if !.acceptFilter() { return nil, false, nil } } = .newPacketConn() select { case .acceptCh <- : .conns[.String()] = default: return nil, false, ErrListenQueueExceeded } } return , true, nil } // PacketConn is a net.PacketConn implementation that is able to dictate its // routing ID via an alternate identifier from its remote address. Internal // buffering is performed for reads, and writes are passed through to the // underlying net.PacketConn. type PacketConn struct { listener *listener raddr net.Addr rmraddr atomic.Value // bool id atomic.Value // string buffer *idtlsnet.PacketBuffer doneCh chan struct{} doneOnce sync.Once writeDeadline *deadline.Deadline } // newPacketConn constructs a new PacketConn. 
func ( *listener) ( net.Addr) *PacketConn { return &PacketConn{ listener: , raddr: , buffer: idtlsnet.NewPacketBuffer(), doneCh: make(chan struct{}), writeDeadline: deadline.New(), } } // ReadFrom reads a single packet payload and its associated remote address from // the underlying buffer. func ( *PacketConn) ( []byte) (int, net.Addr, error) { return .buffer.ReadFrom() } // WriteTo writes len(payload) bytes from payload to the specified address. func ( *PacketConn) ( []byte, net.Addr) ( int, error) { // If we have a connection identifier, check to see if the outgoing packet // sets it. if .listener.connIdentifier != nil { := .id.Load() // Only update establish identifier if we haven't already done so. if == nil { , := .listener.connIdentifier() // If we have an identifier, add entry to connection map. if { .listener.connLock.Lock() .listener.conns[] = .listener.connLock.Unlock() .id.Store() } } // If we are writing to a remote address that differs from the initial, // we have an alternate identifier established, and we haven't already // freed the remote address, free the remote address to be used by // another connection. // Note: this strategy results in holding onto a remote address after it // is potentially no longer in use by the client. However, releasing // earlier means that we could miss some packets that should have been // routed to this connection. Ideally, we would drop the connection // entry for the remote address as soon as the client starts sending // using an alternate identifier, but in practice this proves // challenging because any client could spoof a connection identifier, // resulting in the remote address entry being dropped prior to the // "real" client transitioning to sending using the alternate // identifier. 
if != nil && .rmraddr.Load() == nil && .String() != .raddr.String() { .listener.connLock.Lock() delete(.listener.conns, .raddr.String()) .rmraddr.Store(true) .listener.connLock.Unlock() } } select { case <-.writeDeadline.Done(): return 0, context.DeadlineExceeded default: } return .listener.pConn.WriteTo(, ) } // Close closes the conn and releases any Read calls. func ( *PacketConn) () error { var error .doneOnce.Do(func() { .listener.connWG.Done() close(.doneCh) .listener.connLock.Lock() // If we have an alternate identifier, remove it from the connection // map. if := .id.Load(); != nil { delete(.listener.conns, .(string)) //nolint:forcetypeassert } // If we haven't already removed the remote address, remove it from the // connection map. if .rmraddr.Load() == nil { delete(.listener.conns, .raddr.String()) .rmraddr.Store(true) } := len(.listener.conns) .listener.connLock.Unlock() if , := .listener.accepting.Load().(bool); == 0 && ! && { // Wait if this is the final connection .listener.readWG.Wait() if , := .listener.errClose.Load().(error); { = } } else { = nil } if := .buffer.Close(); != nil && == nil { = } }) return } // LocalAddr implements net.PacketConn.LocalAddr. func ( *PacketConn) () net.Addr { return .listener.pConn.LocalAddr() } // SetDeadline implements net.PacketConn.SetDeadline. func ( *PacketConn) ( time.Time) error { .writeDeadline.Set() return .SetReadDeadline() } // SetReadDeadline implements net.PacketConn.SetReadDeadline. func ( *PacketConn) ( time.Time) error { return .buffer.SetReadDeadline() } // SetWriteDeadline implements net.PacketConn.SetWriteDeadline. func ( *PacketConn) ( time.Time) error { .writeDeadline.Set() // Write deadline of underlying connection should not be changed // since the connection can be shared. return nil }