// SPDX-FileCopyrightText: 2023 The Pion community <https://pion.ly>
// SPDX-License-Identifier: MIT

package twcc

import (
	"errors"
	"sync/atomic"

	"github.com/pion/interceptor"
	"github.com/pion/rtp"
)

// errHeaderIsNil is returned by the writer built in BindLocalStream when it is
// handed a nil *rtp.Header, because there is no header to attach the
// transport-wide sequence extension to.
var (
	errHeaderIsNil = errors.New("header is nil")
)

// HeaderExtensionInterceptorFactory is an interceptor.Factory that produces
// HeaderExtensionInterceptor instances. It holds no configuration, so its zero
// value is ready to use.
type HeaderExtensionInterceptorFactory struct{}

// NewInterceptor constructs a new HeaderExtensionInterceptor.
func ( *HeaderExtensionInterceptorFactory) ( string) (interceptor.Interceptor, error) {
	return &HeaderExtensionInterceptor{}, nil
}

// NewHeaderExtensionInterceptor returns a HeaderExtensionInterceptorFactory.
func () (*HeaderExtensionInterceptorFactory, error) {
	return &HeaderExtensionInterceptorFactory{}, nil
}

// HeaderExtensionInterceptor stamps each outgoing RTP packet with a
// transport-wide sequence number carried in an RTP header extension.
// Interceptor methods it does not define are provided by the embedded
// interceptor.NoOp.
type HeaderExtensionInterceptor struct {
	interceptor.NoOp

	// nextSequenceNr is the sequence number to hand out next. Every local
	// stream bound to this interceptor draws from this one counter, and it is
	// only modified through sync/atomic.
	nextSequenceNr uint32
}

// transportCCURI is the RTP header extension URI for transport-wide congestion
// control sequence numbers. BindLocalStream looks it up in the stream's
// header extension list to find the extension ID to write.
const (
	transportCCURI = "http://www.ietf.org/id/draft-holmer-rmcat-transport-wide-cc-extensions-01"
)

// BindLocalStream returns a writer that adds a rtp.TransportCCExtension
// header with increasing sequence numbers to each outgoing packet.
func ( *HeaderExtensionInterceptor) (
	 *interceptor.StreamInfo,
	 interceptor.RTPWriter,
) interceptor.RTPWriter {
	var  uint8
	for ,  := range .RTPHeaderExtensions {
		if .URI == transportCCURI {
			 = uint8(.ID) //nolint:gosec // G115

			break
		}
	}
	if  == 0 { // Don't add header extension if ID is 0, because 0 is an invalid extension ID
		return 
	}

	return interceptor.RTPWriterFunc(
		func( *rtp.Header,  []byte,  interceptor.Attributes) (int, error) {
			 := atomic.AddUint32(&.nextSequenceNr, 1) - 1
			//nolint:gosec // G115
			,  := (&rtp.TransportCCExtension{TransportSequence: uint16()}).Marshal()
			if  != nil {
				return 0, 
			}
			if  == nil {
				return 0, errHeaderIsNil
			}
			 = .SetExtension(, )
			if  != nil {
				return 0, 
			}

			return .Write(, , )
		},
	)
}