// SPDX-FileCopyrightText: 2023 The Pion community <https://pion.ly>
// SPDX-License-Identifier: MIT

// Package mux multiplexes packets on a single socket (RFC 7983).
package mux import ( ) const ( // The maximum amount of data that can be buffered before returning errors. maxBufferSize = 1000 * 1000 // 1MB // How many total pending packets can be cached. maxPendingPackets = 15 ) // Config collects the arguments to mux.Mux construction into // a single structure. type Config struct { Conn net.Conn BufferSize int LoggerFactory logging.LoggerFactory } // Mux allows multiplexing. type Mux struct { nextConn net.Conn bufferSize int lock sync.Mutex endpoints map[*Endpoint]MatchFunc isClosed bool pendingPackets [][]byte closedCh chan struct{} log logging.LeveledLogger } // NewMux creates a new Mux. func ( Config) *Mux { := &Mux{ nextConn: .Conn, endpoints: make(map[*Endpoint]MatchFunc), bufferSize: .BufferSize, closedCh: make(chan struct{}), log: .LoggerFactory.NewLogger("mux"), } go .readLoop() return } // NewEndpoint creates a new Endpoint. func ( *Mux) ( MatchFunc) *Endpoint { := &Endpoint{ mux: , buffer: packetio.NewBuffer(), } // Set a maximum size of the buffer in bytes. .buffer.SetLimitSize(maxBufferSize) .lock.Lock() .endpoints[] = .lock.Unlock() go .handlePendingPackets(, ) return } // RemoveEndpoint removes an endpoint from the Mux. func ( *Mux) ( *Endpoint) { .lock.Lock() defer .lock.Unlock() delete(.endpoints, ) } // Close closes the Mux and all associated Endpoints. 
func ( *Mux) () error { .lock.Lock() for := range .endpoints { if := .close(); != nil { .lock.Unlock() return } delete(.endpoints, ) } .isClosed = true .lock.Unlock() := .nextConn.Close() if != nil { return } // Wait for readLoop to end <-.closedCh return nil } func ( *Mux) () { defer func() { close(.closedCh) }() := make([]byte, .bufferSize) for { , := .nextConn.Read() switch { case errors.Is(, io.EOF), errors.Is(, ice.ErrClosed): return case errors.Is(, io.ErrShortBuffer), errors.Is(, packetio.ErrTimeout): .log.Errorf("mux: failed to read from packetio.Buffer %s", .Error()) continue case != nil: .log.Errorf("mux: ending readLoop packetio.Buffer error %s", .Error()) return } if = .dispatch([:]); != nil { if errors.Is(, io.ErrClosedPipe) { // if the buffer was closed, that's not an error we care to report return } .log.Errorf("mux: ending readLoop dispatch error %s", .Error()) return } } } func ( *Mux) ( []byte) error { if len() == 0 { .log.Warnf("Warning: mux: unable to dispatch zero length packet") return nil } var *Endpoint .lock.Lock() for , := range .endpoints { if () { = break } } if == nil { defer .lock.Unlock() if !.isClosed { if len(.pendingPackets) >= maxPendingPackets { .log.Warnf( "Warning: mux: no endpoint for packet starting with %d, not adding to queue size(%d)", [0], //nolint:gosec // G602, false positive? len(.pendingPackets), ) } else { .log.Warnf( "Warning: mux: no endpoint for packet starting with %d, adding to queue size(%d)", [0], //nolint:gosec // G602, false positive? 
len(.pendingPackets), ) .pendingPackets = append(.pendingPackets, append([]byte{}, ...)) } } return nil } .lock.Unlock() , := .buffer.Write() // Expected when bytes are received faster than the endpoint can process them (#2152, #2180) if errors.Is(, packetio.ErrFull) { .log.Infof("mux: endpoint buffer is full, dropping packet") return nil } return } func ( *Mux) ( *Endpoint, MatchFunc) { .lock.Lock() defer .lock.Unlock() := make([][]byte, len(.pendingPackets)) for , := range .pendingPackets { if () { if , := .buffer.Write(); != nil { .log.Warnf("Warning: mux: error writing packet to endpoint from pending queue: %s", ) } } else { = append(, ) //nolint:makezero // todo fix } } .pendingPackets = }