// SPDX-FileCopyrightText: 2023 The Pion community <https://pion.ly>
// SPDX-License-Identifier: MIT

// Package flexfec implements FlexFEC-03 to recover missing RTP packets due to packet loss.
// https://datatracker.ietf.org/doc/html/draft-ietf-payload-flexible-fec-scheme-03
package flexfec

// NOTE(review): the import list was empty in the reviewed copy of this file.
// The standard-library imports follow from the identifiers used below; the two
// pion import paths are inferred from the `logging.` and `rtp.` qualifiers —
// confirm against go.mod.
import (
	"encoding/binary"
	"errors"
	"fmt"
	"sort"

	"github.com/pion/logging"
	"github.com/pion/rtp"
)

// Static errors for the flexfec package.
var (
	errPacketTruncated                = errors.New("packet truncated")
	errRetransmissionBitSet           = errors.New("packet with retransmission bit set not supported")
	errInflexibleGeneratorMatrix      = errors.New("packet with inflexible generator matrix not supported")
	errMultipleSSRCProtection         = errors.New("multiple ssrc protection not supported")
	errLastOptionalMaskKBitSetToFalse = errors.New("k-bit of last optional mask is set to false")
)

// fecDecoder is a WIP implementation decoder used for testing purposes.
//
// It buffers media packets (received or recovered) and FEC packets, both kept
// sorted from oldest to newest sequence number, and tries to rebuild a missing
// media packet whenever a FEC packet is missing exactly one of the packets it
// protects.
type fecDecoder struct {
	logger logging.LeveledLogger
	// ssrc is the SSRC carried by the FEC packets themselves.
	ssrc uint32
	// protectedStreamSSRC is the SSRC of the media stream the FEC protects.
	protectedStreamSSRC uint32
	maxMediaPackets     int
	maxFECPackets       int
	recoveredPackets    []rtp.Packet
	receivedFECPackets  []fecPacketState
}

// newFECDecoder returns a decoder for FEC packets sent on ssrc that protect
// the media stream identified by protectedStreamSSRC.
func newFECDecoder(ssrc uint32, protectedStreamSSRC uint32) *fecDecoder {
	return &fecDecoder{
		logger:              logging.NewDefaultLoggerFactory().NewLogger("fec_decoder"),
		ssrc:                ssrc,
		protectedStreamSSRC: protectedStreamSSRC,
		maxMediaPackets:     100,
		maxFECPackets:       100,
		recoveredPackets:    make([]rtp.Packet, 0),
		receivedFECPackets:  make([]fecPacketState, 0),
	}
}

// DecodeFec feeds one received packet (media or FEC) into the decoder and
// returns every media packet that could be recovered as a result.
//
// NOTE(review): the method name was missing in the reviewed copy of this file
// and nothing in the file calls it; DecodeFec is the assumed name — confirm
// against the callers.
func (d *fecDecoder) DecodeFec(receivedPacket rtp.Packet) []rtp.Packet {
	if len(d.recoveredPackets) == d.maxMediaPackets {
		backRecoveredPacket := d.recoveredPackets[len(d.recoveredPackets)-1]

		if backRecoveredPacket.SSRC == receivedPacket.SSRC {
			gap := seqDiff(receivedPacket.SequenceNumber, backRecoveredPacket.SequenceNumber)
			if gap > uint16(d.maxMediaPackets) { //nolint:gosec
				d.logger.Info("big gap in media sequence numbers - resetting buffers")
				d.recoveredPackets = nil
				d.receivedFECPackets = nil
			}
		}
	}

	d.insertPacket(receivedPacket)

	return d.attemptRecovery()
}

// insertPacket routes receivedPkt to the FEC or media buffer based on its SSRC
// and then trims the media buffer.
func (d *fecDecoder) insertPacket(receivedPkt rtp.Packet) {
	// Discard old FEC packets such that the sequence numbers in
	// receivedFECPackets span at most 1/2 of the sequence number space.
	// This is important for keeping receivedFECPackets sorted, and may
	// also reduce the possibility of incorrect decoding due to sequence number
	// wrap-around.
	if len(d.receivedFECPackets) > 0 && receivedPkt.SSRC == d.ssrc {
		stale := 0
		for _, fecPkt := range d.receivedFECPackets {
			if abs(int(receivedPkt.SequenceNumber)-int(fecPkt.packet.SequenceNumber)) <= 0x3fff {
				// No need to keep iterating, since receivedFECPackets is sorted.
				break
			}
			stale++
		}

		// Previously the count was computed but never applied, so stale FEC
		// packets were kept despite the comment above.
		if stale > 0 {
			d.receivedFECPackets = d.receivedFECPackets[stale:]
		}
	}

	switch receivedPkt.SSRC {
	case d.ssrc:
		d.insertFECPacket(receivedPkt)
	case d.protectedStreamSSRC:
		d.insertMediaPacket(receivedPkt)
	}

	d.discardOldRecoveredPackets()
}

// insertMediaPacket stores receivedPkt in the media buffer unless a packet
// with the same sequence number is already present, keeps the buffer sorted,
// and attaches the packet to every FEC packet that protects it.
func (d *fecDecoder) insertMediaPacket(receivedPkt rtp.Packet) {
	for _, known := range d.recoveredPackets {
		if known.SequenceNumber == receivedPkt.SequenceNumber {
			return
		}
	}

	d.recoveredPackets = append(d.recoveredPackets, receivedPkt)
	sort.Slice(d.recoveredPackets, func(i, j int) bool {
		return isNewerSeq(d.recoveredPackets[i].SequenceNumber, d.recoveredPackets[j].SequenceNumber)
	})
	d.updateCoveringFecPackets(receivedPkt)
}

// updateCoveringFecPackets points every protected-packet slot whose sequence
// number equals mediaPacket's at (a copy of) mediaPacket.
func (d *fecDecoder) updateCoveringFecPackets(mediaPacket rtp.Packet) {
	for _, fecPkt := range d.receivedFECPackets {
		for _, protected := range fecPkt.protectedPackets {
			if protected.seq == mediaPacket.SequenceNumber {
				protected.packet = &mediaPacket
			}
		}
	}
}

// insertFECPacket parses fecPkt as a FlexFEC-03 packet, works out which media
// sequence numbers it protects, links the ones already buffered, and stores
// the result in the sorted FEC buffer. Duplicate, unparsable, foreign-SSRC and
// empty-mask packets are dropped.
func (d *fecDecoder) insertFECPacket(fecPkt rtp.Packet) { //nolint:cyclop
	for _, existing := range d.receivedFECPackets {
		if existing.packet.SequenceNumber == fecPkt.SequenceNumber {
			return
		}
	}

	fecHeader, err := parseFlexFEC03Header(fecPkt.Payload)
	if err != nil {
		d.logger.Errorf("failed to parse flexfec03 header: %v", err)

		return
	}

	if fecHeader.protectedSSRC != d.protectedStreamSSRC {
		// "expected" is the configured stream, "got" is what the FEC header claims.
		d.logger.Errorf("fec is protecting unknown ssrc, expected %d, got %d",
			d.protectedStreamSSRC, fecHeader.protectedSSRC)

		return
	}

	// mask0 covers 15 sequence numbers, mask1 the next 31, mask2 the next 63.
	protectedSeqs := decodeMask(uint64(fecHeader.mask0), 15, fecHeader.seqNumBase)
	if fecHeader.mask1 != 0 {
		protectedSeqs = append(protectedSeqs, decodeMask(uint64(fecHeader.mask1), 31, fecHeader.seqNumBase+15)...)
	}
	if fecHeader.mask2 != 0 {
		protectedSeqs = append(protectedSeqs, decodeMask(fecHeader.mask2, 63, fecHeader.seqNumBase+46)...)
	}

	if len(protectedSeqs) == 0 {
		d.logger.Warn("empty fec packet mask")

		return
	}

	// Merge the two sorted sequences: protected sequence numbers and buffered
	// media packets.
	protectedPackets := make([]*protectedPacket, 0, len(protectedSeqs))
	seqIdx := 0
	mediaIdx := 0

	for seqIdx < len(protectedSeqs) && mediaIdx < len(d.recoveredPackets) {
		switch {
		case isNewerSeq(protectedSeqs[seqIdx], d.recoveredPackets[mediaIdx].SequenceNumber):
			// The buffered packet is already past this sequence number: it is missing.
			protectedPackets = append(protectedPackets, &protectedPacket{
				seq:    protectedSeqs[seqIdx],
				packet: nil,
			})
			seqIdx++
		case isNewerSeq(d.recoveredPackets[mediaIdx].SequenceNumber, protectedSeqs[seqIdx]):
			mediaIdx++
		default:
			// Take a copy: d.recoveredPackets is re-sorted in place on later
			// insertions, so a pointer into the slice could end up referring
			// to a different packet.
			media := d.recoveredPackets[mediaIdx]
			protectedPackets = append(protectedPackets, &protectedPacket{
				seq:    protectedSeqs[seqIdx],
				packet: &media,
			})
			seqIdx++
			mediaIdx++
		}
	}

	for seqIdx < len(protectedSeqs) {
		protectedPackets = append(protectedPackets, &protectedPacket{
			seq:    protectedSeqs[seqIdx],
			packet: nil,
		})
		seqIdx++
	}

	d.receivedFECPackets = append(d.receivedFECPackets, fecPacketState{
		packet:           fecPkt,
		flexFec:          fecHeader,
		protectedPackets: protectedPackets,
	})
	sort.Slice(d.receivedFECPackets, func(i, j int) bool {
		return isNewerSeq(d.receivedFECPackets[i].packet.SequenceNumber, d.receivedFECPackets[j].packet.SequenceNumber)
	})

	// Drop the oldest FEC packet once the buffer exceeds its limit.
	if len(d.receivedFECPackets) > d.maxFECPackets {
		d.receivedFECPackets = d.receivedFECPackets[1:]
	}
}

// attemptRecovery repeatedly scans the FEC buffer for packets that are missing
// exactly one protected media packet and recovers it. It keeps scanning until
// a full pass recovers nothing, because each recovered packet may make another
// FEC packet usable. It returns the packets recovered during this call.
func (d *fecDecoder) attemptRecovery() []rtp.Packet {
	recovered := make([]rtp.Packet, 0)

	for {
		recoveredThisPass := 0

		for _, fecPkt := range d.receivedFECPackets {
			missing := 0
			for _, protected := range fecPkt.protectedPackets {
				if protected.packet == nil {
					missing++
					if missing > 1 {
						break
					}
				}
			}

			if missing != 1 {
				continue
			}

			pkt, err := d.recoverPacket(&fecPkt) //nolint:gosec
			if err != nil {
				d.logger.Errorf("failed to recover packet: %v", err)

				// Skip the failed packet: storing the zero-valued result would
				// pollute the media buffer and count as progress, which could
				// keep the outer loop spinning forever.
				continue
			}

			recovered = append(recovered, pkt)
			d.recoveredPackets = append(d.recoveredPackets, pkt)
			sort.Slice(d.recoveredPackets, func(i, j int) bool {
				return isNewerSeq(d.recoveredPackets[i].SequenceNumber, d.recoveredPackets[j].SequenceNumber)
			})
			d.updateCoveringFecPackets(pkt)
			d.discardOldRecoveredPackets()
			recoveredThisPass++
		}

		if recoveredThisPass == 0 {
			break
		}
	}

	return recovered
}

// recoverPacket rebuilds the single missing media packet protected by fec by
// XOR-ing the FEC bit strings with every protected packet that is present.
// It returns an error if a present packet cannot be marshalled or the rebuilt
// bytes cannot be unmarshalled as an RTP packet.
func (d *fecDecoder) recoverPacket(fec *fecPacketState) (rtp.Packet, error) {
	// https://datatracker.ietf.org/doc/html/draft-ietf-payload-flexible-fec-scheme-03#section-6.3.2
	// 2. For the repair packet in T, extract the FEC bit string as the
	//    first 80 bits of the FEC header.
	headerRecovery := make([]byte, 12)
	copy(headerRecovery, fec.packet.Payload[:10])

	var missingSeq uint16
	for _, protected := range fec.protectedPackets {
		if protected.packet == nil {
			missingSeq = protected.seq

			continue
		}

		// 1. For each of the source packets that are successfully received in
		//    T, compute the 80-bit string by concatenating the first 64 bits
		//    of their RTP header and the unsigned network-ordered 16-bit
		//    representation of their length in bytes minus 12.
		receivedHeader, err := protected.packet.Header.Marshal()
		if err != nil {
			return rtp.Packet{}, fmt.Errorf("marshal received header: %w", err)
		}
		// Bytes 2-3 temporarily carry the length so it is XOR-ed alongside the header.
		binary.BigEndian.PutUint16(receivedHeader[2:4], uint16(protected.packet.MarshalSize()-12)) //nolint:gosec
		for i := 0; i < 8; i++ {
			headerRecovery[i] ^= receivedHeader[i]
		}
	}

	// set version to 2
	headerRecovery[0] |= 0x80
	headerRecovery[0] &= 0xbf
	// Bytes 2-3 now hold the recovered length; read it, then store the real
	// sequence number and SSRC in their RTP header positions.
	payloadLength := binary.BigEndian.Uint16(headerRecovery[2:4])
	binary.BigEndian.PutUint16(headerRecovery[2:4], missingSeq)
	binary.BigEndian.PutUint32(headerRecovery[8:12], d.protectedStreamSSRC)

	payloadRecovery := make([]byte, payloadLength)
	copy(payloadRecovery, fec.flexFec.payload)
	for _, protected := range fec.protectedPackets {
		if protected.packet == nil {
			continue
		}

		raw, err := protected.packet.Marshal()
		if err != nil {
			return rtp.Packet{}, fmt.Errorf("marshal protected packet: %w", err)
		}
		// XOR everything after the first 12 bytes, limited to the recovered length.
		limit := minInt(int(payloadLength), len(raw)-12)
		for i := 0; i < limit; i++ {
			payloadRecovery[i] ^= raw[12+i]
		}
	}

	headerRecovery = append(headerRecovery, payloadRecovery...) //nolint:makezero

	var recovered rtp.Packet
	if err := recovered.Unmarshal(headerRecovery); err != nil {
		return rtp.Packet{}, fmt.Errorf("unmarshal recovered: %w", err)
	}

	return recovered, nil
}

// discardOldRecoveredPackets keeps only the newest 192 buffered media packets.
func (d *fecDecoder) discardOldRecoveredPackets() {
	const limit = 192
	if len(d.recoveredPackets) > limit {
		d.recoveredPackets = d.recoveredPackets[len(d.recoveredPackets)-limit:]
	}
}

// decodeMask expands a bit mask into sequence numbers. Bit i, counted from the
// most significant of the bitCount low bits of mask, stands for seqNumBase+i;
// the result lists the sequence number of every set bit in ascending bit order.
func decodeMask(mask uint64, bitCount uint16, seqNumBase uint16) []uint16 {
	seqs := make([]uint16, 0)
	for i := uint16(0); i < bitCount; i++ {
		if (mask>>(bitCount-1-i))&1 == 1 {
			seqs = append(seqs, seqNumBase+i)
		}
	}

	return seqs
}

// fecPacketState is a buffered FEC packet together with its parsed header and
// one slot per media packet it protects.
type fecPacketState struct {
	packet           rtp.Packet
	flexFec          flexFec
	protectedPackets []*protectedPacket
}

// flexFec holds the fields extracted from a FlexFEC-03 header by
// parseFlexFEC03Header, plus the bytes that follow the masks.
type flexFec struct {
	protectedSSRC uint32
	seqNumBase    uint16
	mask0         uint16
	mask1         uint32
	mask2         uint64
	payload       []byte
}

// protectedPacket is one media packet covered by a FEC packet; packet is nil
// while that media packet is missing.
type protectedPacket struct {
	seq    uint16
	packet *rtp.Packet
}

// parseFlexFEC03Header parses data as a FlexFEC-03 header followed by the
// repair payload.
//
// It returns an error when data is too short for the masks it declares, when
// the retransmission or inflexible-generator-matrix bit is set, when the SSRC
// count is not 1, or when the k-bit of the last optional mask is not set.
func parseFlexFEC03Header(data []byte) (flexFec, error) {
	if len(data) < 20 {
		return flexFec{}, fmt.Errorf("%w: length %d", errPacketTruncated, len(data))
	}

	if data[0]&0x80 != 0 {
		return flexFec{}, errRetransmissionBitSet
	}

	if data[0]&0x40 != 0 {
		return flexFec{}, errInflexibleGeneratorMatrix
	}

	ssrcCount := data[8]
	if ssrcCount != 1 {
		return flexFec{}, fmt.Errorf("%w: count %d", errMultipleSSRCProtection, ssrcCount)
	}

	protectedSSRC := binary.BigEndian.Uint32(data[12:])
	seqNumBase := binary.BigEndian.Uint16(data[16:])

	// The masks start at byte 18. The top bit of each mask is its k-bit: when
	// set, that mask is the last one and the payload follows immediately.
	rawMask := data[18:]
	var payload []byte

	kBit0 := rawMask[0]&0x80 != 0
	mask0 := binary.BigEndian.Uint16(rawMask[0:2]) & 0x7FFF
	var mask1 uint32
	var mask2 uint64

	if kBit0 { //nolint:nestif
		payload = rawMask[2:]
	} else {
		if len(data) < 24 {
			return flexFec{}, fmt.Errorf("%w: length %d", errPacketTruncated, len(data))
		}

		kBit1 := rawMask[2]&0x80 != 0
		mask1 = binary.BigEndian.Uint32(rawMask[2:]) & 0x7FFFFFFF

		if kBit1 {
			payload = rawMask[6:]
		} else {
			if len(data) < 32 {
				return flexFec{}, fmt.Errorf("%w: length %d", errPacketTruncated, len(data))
			}

			kBit2 := rawMask[6]&0x80 != 0
			mask2 = binary.BigEndian.Uint64(rawMask[6:]) & 0x7FFFFFFFFFFFFFFF

			if !kBit2 {
				return flexFec{}, errLastOptionalMaskKBitSetToFalse
			}
			payload = rawMask[14:]
		}
	}

	return flexFec{
		protectedSSRC: protectedSSRC,
		seqNumBase:    seqNumBase,
		mask0:         mask0,
		mask1:         mask1,
		mask2:         mask2,
		payload:       payload,
	}, nil
}

// seqDiff returns the shorter of the two wrap-around distances between a and b.
func seqDiff(a, b uint16) uint16 {
	return minUInt16(a-b, b-a)
}

// minInt returns the smaller of a and b.
func minInt(a, b int) int {
	if a < b {
		return a
	}

	return b
}

// minUInt16 returns the smaller of a and b.
func minUInt16(a, b uint16) uint16 {
	if a < b {
		return a
	}

	return b
}

// abs returns the absolute value of x.
func abs(x int) int {
	if x >= 0 {
		return x
	}

	return -x
}

// isNewerSeq reports whether value is newer than prevValue, treating sequence
// numbers as wrapping around at 16 bits.
func isNewerSeq(prevValue, value uint16) bool {
	// half-way mark
	breakpoint := uint16(0x8000)
	if value-prevValue == breakpoint {
		// Exactly half the space apart: break the tie with a plain comparison.
		return value > prevValue
	}

	return value != prevValue && (value-prevValue) < breakpoint
}