// SPDX-FileCopyrightText: 2023 The Pion community <https://pion.ly>
// SPDX-License-Identifier: MIT

// Package udp provides a connection-oriented listener over a UDP PacketConn.
package udp

// NOTE(review): the import list was reconstructed from the identifiers used in
// this file; confirm the pion/transport major version (v2 vs v3) against go.mod.
import (
	"context"
	"errors"
	"net"
	"sync"
	"sync/atomic"
	"time"

	"github.com/pion/logging"
	"github.com/pion/transport/v3/deadline"
	"github.com/pion/transport/v3/packetio"
	"golang.org/x/net/ipv4"
)

const (
	receiveMTU           = 8192
	sendMTU              = 1500
	defaultListenBacklog = 128 // same as Linux default
)

// Typed errors
var (
	ErrClosedListener      = errors.New("udp: listener closed")
	ErrListenQueueExceeded = errors.New("udp: listen queue exceeded")
	ErrInvalidBatchConfig  = errors.New("udp: invalid batch config")
)

// listener augments a connection-oriented Listener over a UDP PacketConn.
//
// Lifetime: connWG counts the listener itself plus every accepted Conn; a
// goroutine started by Listen closes pConn once that count drops to zero.
// readWG tracks the read loop and that closer goroutine.
type listener struct {
	pConn net.PacketConn

	// readBatchSize > 1 enables batched reads when pConn supports them.
	readBatchSize int

	accepting    atomic.Value // bool
	acceptCh     chan *Conn   // pending connections, capacity = backlog
	doneCh       chan struct{}
	doneOnce     sync.Once
	acceptFilter func([]byte) bool

	connLock sync.Mutex
	conns    map[string]*Conn // keyed by remote address string
	connWG   *sync.WaitGroup

	readWG   sync.WaitGroup
	errClose atomic.Value // error

	readDoneCh chan struct{}
	errRead    atomic.Value // error

	logger logging.LeveledLogger
}

// Accept waits for and returns the next connection to the listener.
// It returns the read loop's error once reading has stopped, and
// ErrClosedListener after Close.
func (l *listener) Accept() (net.Conn, error) {
	select {
	case c := <-l.acceptCh:
		// Every accepted Conn keeps the shared PacketConn open until it is closed.
		l.connWG.Add(1)
		return c, nil

	case <-l.readDoneCh:
		if err, ok := l.errRead.Load().(error); ok {
			return nil, err
		}
		// The read loop records its error before signalling, but never hand
		// back (nil, nil): that would violate the net.Listener contract.
		return nil, ErrClosedListener

	case <-l.doneCh:
		return nil, ErrClosedListener
	}
}

// Close closes the listener.
// Any blocked Accept operations will be unblocked and return errors.
// Connections that were queued but never accepted are discarded. If no
// accepted connections remain, Close waits for the underlying PacketConn to
// be closed and returns its close error, if any.
func (l *listener) Close() error {
	var err error
	l.doneOnce.Do(func() {
		l.accepting.Store(false)
		close(l.doneCh)

		l.connLock.Lock()
		// Close unaccepted connections. The label is required: a bare break
		// would only exit the select, leaving the for loop spinning forever.
	drain:
		for {
			select {
			case c := <-l.acceptCh:
				close(c.doneCh)
				delete(l.conns, c.rAddr.String())
			default:
				break drain
			}
		}
		nConns := len(l.conns)
		l.connLock.Unlock()

		// Release the reference the listener itself holds on connWG.
		l.connWG.Done()

		if nConns == 0 {
			// Wait if this is the final connection: the read loop and the
			// closer goroutine both finish once pConn is closed.
			l.readWG.Wait()
			if errClose, ok := l.errClose.Load().(error); ok {
				err = errClose
			}
		}
	})

	return err
}

// Addr returns the listener's network address.
func (l *listener) Addr() net.Addr {
	return l.pConn.LocalAddr()
}

// BatchIOConfig configures batched packet I/O. When enabled, the listener
// uses ReadBatch/WriteBatch to improve UDP throughput.
type BatchIOConfig struct { Enable bool // ReadBatchSize indicates the maximum number of packets to be read in one batch, a batch size less than 2 means // disable read batch. ReadBatchSize int // WriteBatchSize indicates the maximum number of packets to be written in one batch WriteBatchSize int // WriteBatchInterval indicates the maximum interval to wait before writing packets in one batch // small interval will reduce latency/jitter, but increase the io count. WriteBatchInterval time.Duration } // ListenConfig stores options for listening to an address. type ListenConfig struct { // Backlog defines the maximum length of the queue of pending // connections. It is equivalent of the backlog argument of // POSIX listen function. // If a connection request arrives when the queue is full, // the request will be silently discarded, unlike TCP. // Set zero to use default value 128 which is same as Linux default. Backlog int // AcceptFilter determines whether the new conn should be made for // the incoming packet. If not set, any packet creates new conn. AcceptFilter func([]byte) bool // ReadBufferSize sets the size of the operating system's // receive buffer associated with the listener. ReadBufferSize int // WriteBufferSize sets the size of the operating system's // send buffer associated with the connection. WriteBufferSize int Batch BatchIOConfig LoggerFactory logging.LoggerFactory } // Listen creates a new listener based on the ListenConfig. 
func ( *ListenConfig) ( string, *net.UDPAddr) (net.Listener, error) { if .Backlog == 0 { .Backlog = defaultListenBacklog } if .Batch.Enable && (.Batch.WriteBatchSize <= 0 || .Batch.WriteBatchInterval <= 0) { return nil, ErrInvalidBatchConfig } , := net.ListenUDP(, ) if != nil { return nil, } if .ReadBufferSize > 0 { _ = .SetReadBuffer(.ReadBufferSize) } if .WriteBufferSize > 0 { _ = .SetWriteBuffer(.WriteBufferSize) } := .LoggerFactory if == nil { = logging.NewDefaultLoggerFactory() } := .NewLogger("transport") := &listener{ pConn: , acceptCh: make(chan *Conn, .Backlog), conns: make(map[string]*Conn), doneCh: make(chan struct{}), acceptFilter: .AcceptFilter, connWG: &sync.WaitGroup{}, readDoneCh: make(chan struct{}), logger: , } if .Batch.Enable { .pConn = NewBatchConn(, .Batch.WriteBatchSize, .Batch.WriteBatchInterval) .readBatchSize = .Batch.ReadBatchSize } .accepting.Store(true) .connWG.Add(1) .readWG.Add(2) // wait readLoop and Close execution routine go .readLoop() go func() { .connWG.Wait() if := .pConn.Close(); != nil { .errClose.Store() } .readWG.Done() }() return , nil } // Listen creates a new listener using default ListenConfig. func ( string, *net.UDPAddr) (net.Listener, error) { return (&ListenConfig{}).Listen(, ) } // readLoop has to tasks: // 1. Dispatching incoming packets to the correct Conn. // It can therefore not be ended until all Conns are closed. // 2. Creating a new Conn when receiving from a new remote. 
func ( *listener) () { defer .readWG.Done() defer close(.readDoneCh) if , := .pConn.(BatchReader); && .readBatchSize > 1 { .readBatch() } else { .read() } } func ( *listener) ( BatchReader) { := make([]ipv4.Message, .readBatchSize) for := range { := &[] .Buffers = [][]byte{make([]byte, receiveMTU)} .OOB = make([]byte, 40) } for { , := .ReadBatch(, 0) if != nil { .errRead.Store() return } for := 0; < ; ++ { .dispatchMsg([].Addr, [].Buffers[0][:[].N]) } } } func ( *listener) () { := make([]byte, receiveMTU) for { , , := .pConn.ReadFrom() if != nil { .errRead.Store() .logger.Tracef("error reading from connection err=%v", ) return } .dispatchMsg(, [:]) } } func ( *listener) ( net.Addr, []byte) { , , := .getConn(, ) if != nil { return } if { , := .buffer.Write() if != nil { .logger.Tracef("error dispatching message addr=%v err=%v", , ) } } } func ( *listener) ( net.Addr, []byte) (*Conn, bool, error) { .connLock.Lock() defer .connLock.Unlock() , := .conns[.String()] if ! { if , := .accepting.Load().(bool); ! || ! 
{ return nil, false, ErrClosedListener } if .acceptFilter != nil { if !.acceptFilter() { return nil, false, nil } } = .newConn() select { case .acceptCh <- : .conns[.String()] = default: return nil, false, ErrListenQueueExceeded } } return , true, nil } // Conn augments a connection-oriented connection over a UDP PacketConn type Conn struct { listener *listener rAddr net.Addr buffer *packetio.Buffer doneCh chan struct{} doneOnce sync.Once writeDeadline *deadline.Deadline } func ( *listener) ( net.Addr) *Conn { return &Conn{ listener: , rAddr: , buffer: packetio.NewBuffer(), doneCh: make(chan struct{}), writeDeadline: deadline.New(), } } // Read reads from c into p func ( *Conn) ( []byte) (int, error) { return .buffer.Read() } // Write writes len(p) bytes from p to the DTLS connection func ( *Conn) ( []byte) ( int, error) { select { case <-.writeDeadline.Done(): return 0, context.DeadlineExceeded default: } return .listener.pConn.WriteTo(, .rAddr) } // Close closes the conn and releases any Read calls func ( *Conn) () error { var error .doneOnce.Do(func() { .listener.connWG.Done() close(.doneCh) .listener.connLock.Lock() delete(.listener.conns, .rAddr.String()) := len(.listener.conns) .listener.connLock.Unlock() if , := .listener.accepting.Load().(bool); == 0 && ! 
&& { // Wait if this is the final connection .listener.readWG.Wait() if , := .listener.errClose.Load().(error); { = } } else { = nil } if := .buffer.Close(); != nil && == nil { = } }) return } // LocalAddr implements net.Conn.LocalAddr func ( *Conn) () net.Addr { return .listener.pConn.LocalAddr() } // RemoteAddr implements net.Conn.RemoteAddr func ( *Conn) () net.Addr { return .rAddr } // SetDeadline implements net.Conn.SetDeadline func ( *Conn) ( time.Time) error { .writeDeadline.Set() return .SetReadDeadline() } // SetReadDeadline implements net.Conn.SetDeadline func ( *Conn) ( time.Time) error { return .buffer.SetReadDeadline() } // SetWriteDeadline implements net.Conn.SetDeadline func ( *Conn) ( time.Time) error { .writeDeadline.Set() // Write deadline of underlying connection should not be changed // since the connection can be shared. return nil }