// SPDX-FileCopyrightText: 2023 The Pion community <https://pion.ly>
// SPDX-License-Identifier: MIT

package rtp

import (
	"encoding/binary"
	"fmt"
	"io"
)

// Extension RTP Header extension.
//
// An Extension is a single header extension element. For the RFC 8285
// one-byte and two-byte profiles each element carries its own id; for any
// other ExtensionProfile the whole extension block is stored as one element
// with id 0. The payload is not copied when parsed: after Header.Unmarshal it
// is a sub-slice of the buffer that was parsed.
type Extension struct {
	id      uint8
	payload []byte
}

// Header represents an RTP packet header.
//
// Version, Padding, Extension, Marker, PayloadType, SequenceNumber, Timestamp,
// SSRC and CSRC mirror the fields of the fixed RTP header. When Extension is
// set, ExtensionProfile is the 16-bit profile value written in front of the
// extension block and Extensions holds its parsed elements; they are managed
// through SetExtension, GetExtension and DelExtension.
type Header struct {
	Version          uint8
	Padding          bool
	Extension        bool
	Marker           bool
	PayloadType      uint8
	SequenceNumber   uint16
	Timestamp        uint32
	SSRC             uint32
	CSRC             []uint32
	ExtensionProfile uint16
	Extensions       []Extension

	// PaddingSize is the length of the padding in bytes. It is not part of the RTP header
	// (it is sent in the last byte of RTP packet padding), but logically it belongs here.
	PaddingSize byte

	// Deprecated: will be removed in a future version.
	PayloadOffset int
}

// Packet represents an RTP Packet.
//
// Payload holds the bytes between the header and any trailing padding; after
// Unmarshal it is a sub-slice of the input buffer, not a copy. When marshaling,
// a non-zero Header.PaddingSize takes precedence over the deprecated
// Packet.PaddingSize field.
type Packet struct {
	Header
	Payload []byte

	PaddingSize byte // Deprecated: will be removed in a future version. Use Header.PaddingSize instead.

	// Deprecated: will be removed in a future version.
	Raw []byte

	// Please do not add any new field directly to Packet struct unless you know that it is safe.
	// pion internally passes Header and Payload separately, what causes bugs like
	// https://github.com/pion/webrtc/issues/2403 .
}

// Header extension profiles: the 16-bit value written in front of an RTP
// header extension block.
const (
	// ExtensionProfileOneByte is the RTP One Byte Header Extension Profile, defined in RFC 8285.
	ExtensionProfileOneByte = 0xBEDE
	// ExtensionProfileTwoByte is the RTP Two Byte Header Extension Profile, defined in RFC 8285.
	ExtensionProfileTwoByte = 0x1000
	// CryptexProfileOneByte is the Cryptex One Byte Header Extension Profile, defined in RFC 9335.
	CryptexProfileOneByte = 0xC0DE
	// CryptexProfileTwoByte is the Cryptex Two Byte Header Extension Profile, defined in RFC 9335.
	CryptexProfileTwoByte = 0xC2DE
)

// Bit layout and byte positions of the fixed RTP header.
const (
	// headerLength is the size of the first 32-bit header word, the minimum
	// input accepted before any field is decoded.
	headerLength = 4

	// First byte: V(2) P(1) X(1) CC(4).
	versionShift   = 6
	versionMask    = 0x3
	paddingShift   = 5
	paddingMask    = 0x1
	extensionShift = 4
	extensionMask  = 0x1
	ccMask         = 0xF

	// Second byte: M(1) PT(7).
	markerShift = 7
	markerMask  = 0x1
	ptMask      = 0x7F

	// extensionIDReserved is the one-byte element ID that stops extension parsing.
	extensionIDReserved = 0xF

	// Each field starts right where the previous one ends.
	seqNumOffset    = 2
	seqNumLength    = 2
	timestampOffset = seqNumOffset + seqNumLength
	timestampLength = 4
	ssrcOffset      = timestampOffset + timestampLength
	ssrcLength      = 4
	csrcOffset      = ssrcOffset + ssrcLength
	csrcLength      = 4
)

// String helps with debugging by printing packet information in a readable way.
func ( Packet) () string {
	 := "RTP PACKET:\n"

	 += fmt.Sprintf("\tVersion: %v\n", .Version)
	 += fmt.Sprintf("\tMarker: %v\n", .Marker)
	 += fmt.Sprintf("\tPayload Type: %d\n", .PayloadType)
	 += fmt.Sprintf("\tSequence Number: %d\n", .SequenceNumber)
	 += fmt.Sprintf("\tTimestamp: %d\n", .Timestamp)
	 += fmt.Sprintf("\tSSRC: %d (%x)\n", .SSRC, .SSRC)
	 += fmt.Sprintf("\tPayload Length: %d\n", len(.Payload))

	return 
}

// Unmarshal parses the passed byte slice and stores the result in the Header.
// It returns the number of bytes read n and any error.
func ( *Header) ( []byte) ( int,  error) { //nolint:gocognit,cyclop
	if len() < headerLength {
		return 0, fmt.Errorf("%w: %d < %d", errHeaderSizeInsufficient, len(), headerLength)
	}

	/*
	 *  0                   1                   2                   3
	 *  0 1 2 3 4 5 6 7 8 9 0 1 2 3 4 5 6 7 8 9 0 1 2 3 4 5 6 7 8 9 0 1
	 * +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+
	 * |V=2|P|X|  CC   |M|     PT      |       sequence number         |
	 * +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+
	 * |                           timestamp                           |
	 * +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+
	 * |           synchronization source (SSRC) identifier            |
	 * +=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+
	 * |            contributing source (CSRC) identifiers             |
	 * |                             ....                              |
	 * +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+
	 */

	.Version = [0] >> versionShift & versionMask
	.Padding = ([0] >> paddingShift & paddingMask) > 0
	.Extension = ([0] >> extensionShift & extensionMask) > 0
	 := int([0] & ccMask)
	if cap(.CSRC) <  || .CSRC == nil {
		.CSRC = make([]uint32, )
	} else {
		.CSRC = .CSRC[:]
	}

	 = csrcOffset + ( * csrcLength)
	if len() <  {
		return , fmt.Errorf("size %d < %d: %w", len(), ,
			errHeaderSizeInsufficient)
	}

	.Marker = ([1] >> markerShift & markerMask) > 0
	.PayloadType = [1] & ptMask

	.SequenceNumber = binary.BigEndian.Uint16([seqNumOffset : seqNumOffset+seqNumLength])
	.Timestamp = binary.BigEndian.Uint32([timestampOffset : timestampOffset+timestampLength])
	.SSRC = binary.BigEndian.Uint32([ssrcOffset : ssrcOffset+ssrcLength])

	for  := range .CSRC {
		 := csrcOffset + ( * csrcLength)
		.CSRC[] = binary.BigEndian.Uint32([:])
	}

	if .Extensions != nil {
		.Extensions = .Extensions[:0]
	}

	if .Extension { // nolint: nestif
		if  :=  + 4; len() <  {
			return , fmt.Errorf("size %d < %d: %w",
				len(), ,
				errHeaderSizeInsufficientForExtension,
			)
		}

		.ExtensionProfile = binary.BigEndian.Uint16([:])
		 += 2
		 := int(binary.BigEndian.Uint16([:])) * 4
		 += 2
		 :=  + 

		if len() <  {
			return , fmt.Errorf("size %d < %d: %w", len(), , errHeaderSizeInsufficientForExtension)
		}

		if .ExtensionProfile == ExtensionProfileOneByte || .ExtensionProfile == ExtensionProfileTwoByte {
			var (
				      uint8
				 int
			)

			for  <  {
				if [] == 0x00 { // padding
					++

					continue
				}

				if .ExtensionProfile == ExtensionProfileOneByte {
					 = [] >> 4
					 = int([]&^0xF0 + 1)
					++

					if  == extensionIDReserved {
						break
					}
				} else {
					 = []
					++

					if len() <=  {
						return , fmt.Errorf("size %d < %d: %w", len(), , errHeaderSizeInsufficientForExtension)
					}

					 = int([])
					++
				}

				if  :=  + ; len() <=  {
					return , fmt.Errorf("size %d < %d: %w", len(), , errHeaderSizeInsufficientForExtension)
				}

				 := Extension{id: , payload: [ : +]}
				.Extensions = append(.Extensions, )
				 += 
			}
		} else {
			// RFC3550 Extension
			 := Extension{id: 0, payload: [:]}
			.Extensions = append(.Extensions, )
			 += len(.Extensions[0].payload)
		}
	}

	return , nil
}

// Unmarshal parses the passed byte slice and stores the result in the Packet.
func ( *Packet) ( []byte) error {
	,  := .Header.Unmarshal()
	if  != nil {
		return 
	}

	 := len()
	if .Header.Padding {
		if  <=  {
			return errTooSmall
		}
		.Header.PaddingSize = [-1]
		 -= int(.Header.PaddingSize)
	} else {
		.Header.PaddingSize = 0
	}
	.PaddingSize = .Header.PaddingSize
	if  <  {
		return errTooSmall
	}

	.Payload = [:]

	return nil
}

// Marshal serializes the header into bytes.
func ( Header) () ( []byte,  error) {
	 = make([]byte, .MarshalSize())

	,  := .MarshalTo()
	if  != nil {
		return nil, 
	}

	return [:], nil
}

// MarshalTo serializes the header and writes to the buffer.
func ( Header) ( []byte) ( int,  error) { //nolint:cyclop
	/*
	 *  0                   1                   2                   3
	 *  0 1 2 3 4 5 6 7 8 9 0 1 2 3 4 5 6 7 8 9 0 1 2 3 4 5 6 7 8 9 0 1
	 * +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+
	 * |V=2|P|X|  CC   |M|     PT      |       sequence number         |
	 * +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+
	 * |                           timestamp                           |
	 * +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+
	 * |           synchronization source (SSRC) identifier            |
	 * +=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+
	 * |            contributing source (CSRC) identifiers             |
	 * |                             ....                              |
	 * +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+
	 */

	 := .MarshalSize()
	if  > len() {
		return 0, io.ErrShortBuffer
	}

	// The first byte contains the version, padding bit, extension bit,
	// and csrc size.
	[0] = (.Version << versionShift) | uint8(len(.CSRC)) // nolint: gosec // G115
	if .Padding {
		[0] |= 1 << paddingShift
	}

	if .Extension {
		[0] |= 1 << extensionShift
	}

	// The second byte contains the marker bit and payload type.
	[1] = .PayloadType
	if .Marker {
		[1] |= 1 << markerShift
	}

	binary.BigEndian.PutUint16([2:4], .SequenceNumber)
	binary.BigEndian.PutUint32([4:8], .Timestamp)
	binary.BigEndian.PutUint32([8:12], .SSRC)

	 = 12
	for ,  := range .CSRC {
		binary.BigEndian.PutUint32([:+4], )
		 += 4
	}

	if .Extension {
		 := 
		binary.BigEndian.PutUint16([+0:+2], .ExtensionProfile)
		 += 4
		 := 

		switch .ExtensionProfile {
		// RFC 8285 RTP One Byte Header Extension
		case ExtensionProfileOneByte:
			for ,  := range .Extensions {
				[] = .id<<4 | (uint8(len(.payload)) - 1) // nolint: gosec // G115
				++
				 += copy([:], .payload)
			}
		// RFC 8285 RTP Two Byte Header Extension
		case ExtensionProfileTwoByte:
			for ,  := range .Extensions {
				[] = .id
				++
				[] = uint8(len(.payload)) // nolint: gosec // G115
				++
				 += copy([:], .payload)
			}
		default: // RFC3550 Extension
			 := len(.Extensions[0].payload)
			if %4 != 0 {
				// the payload must be in 32-bit words.
				return 0, io.ErrShortBuffer
			}
			 += copy([:], .Extensions[0].payload)
		}

		// calculate extensions size and round to 4 bytes boundaries
		 :=  - 
		 := (( + 3) / 4) * 4

		// nolint: gosec // G115 false positive
		binary.BigEndian.PutUint16([+2:+4], uint16(/4))

		// add padding to reach 4 bytes boundaries
		for  := 0;  < -; ++ {
			[] = 0
			++
		}
	}

	return , nil
}

// MarshalSize returns the size of the header once marshaled.
func ( Header) () int {
	// NOTE: Be careful to match the MarshalTo() method.
	 := 12 + (len(.CSRC) * csrcLength)

	if .Extension {
		 := 4

		switch .ExtensionProfile {
		// RFC 8285 RTP One Byte Header Extension
		case ExtensionProfileOneByte:
			for ,  := range .Extensions {
				 += 1 + len(.payload)
			}
		// RFC 8285 RTP Two Byte Header Extension
		case ExtensionProfileTwoByte:
			for ,  := range .Extensions {
				 += 2 + len(.payload)
			}
		default:
			 += len(.Extensions[0].payload)
		}

		// extensions size must have 4 bytes boundaries
		 += (( + 3) / 4) * 4
	}

	return 
}

// SetExtension sets an RTP header extension.
func ( *Header) ( uint8,  []byte) error { //nolint:gocognit, cyclop
	if .Extension { // nolint: nestif
		switch .ExtensionProfile {
		// RFC 8285 RTP One Byte Header Extension
		case ExtensionProfileOneByte:
			if  < 1 ||  > 14 {
				return fmt.Errorf("%w actual(%d)", errRFC8285OneByteHeaderIDRange, )
			}
			if len() > 16 {
				return fmt.Errorf("%w actual(%d)", errRFC8285OneByteHeaderSize, len())
			}
		// RFC 8285 RTP Two Byte Header Extension
		case ExtensionProfileTwoByte:
			if  < 1 {
				return fmt.Errorf("%w actual(%d)", errRFC8285TwoByteHeaderIDRange, )
			}
			if len() > 255 {
				return fmt.Errorf("%w actual(%d)", errRFC8285TwoByteHeaderSize, len())
			}
		default: // RFC3550 Extension
			if  != 0 {
				return fmt.Errorf("%w actual(%d)", errRFC3550HeaderIDRange, )
			}
		}

		// Update existing if it exists else add new extension
		for ,  := range .Extensions {
			if .id ==  {
				.Extensions[].payload = 

				return nil
			}
		}

		.Extensions = append(.Extensions, Extension{id: , payload: })

		return nil
	}

	// No existing header extensions
	.Extension = true

	switch  := len(); {
	case  <= 16:
		.ExtensionProfile = ExtensionProfileOneByte
	case  > 16 &&  < 256:
		.ExtensionProfile = ExtensionProfileTwoByte
	}

	.Extensions = append(.Extensions, Extension{id: , payload: })

	return nil
}

// GetExtensionIDs returns an extension id array.
func ( *Header) () []uint8 {
	if !.Extension {
		return nil
	}

	if len(.Extensions) == 0 {
		return nil
	}

	 := make([]uint8, 0, len(.Extensions))
	for ,  := range .Extensions {
		 = append(, .id)
	}

	return 
}

// GetExtension returns an RTP header extension.
func ( *Header) ( uint8) []byte {
	if !.Extension {
		return nil
	}
	for ,  := range .Extensions {
		if .id ==  {
			return .payload
		}
	}

	return nil
}

// DelExtension Removes an RTP Header extension.
func ( *Header) ( uint8) error {
	if !.Extension {
		return errHeaderExtensionsNotEnabled
	}
	for ,  := range .Extensions {
		if .id ==  {
			.Extensions = append(.Extensions[:], .Extensions[+1:]...)

			return nil
		}
	}

	return errHeaderExtensionNotFound
}

// Marshal serializes the packet into bytes.
func ( Packet) () ( []byte,  error) {
	 = make([]byte, .MarshalSize())

	,  := .MarshalTo()
	if  != nil {
		return nil, 
	}

	return [:], nil
}

// MarshalTo serializes the packet and writes to the buffer.
func ( *Packet) ( []byte) ( int,  error) {
	if .Header.Padding && .paddingSize() == 0 {
		return 0, errInvalidRTPPadding
	}

	,  = .Header.MarshalTo()
	if  != nil {
		return 0, 
	}

	return marshalPayloadAndPaddingTo(, , &.Header, .Payload, .paddingSize())
}

func marshalPayloadAndPaddingTo( []byte,  int,  *Header,  []byte,  byte,
) ( int,  error) {
	// Make sure the buffer is large enough to hold the packet.
	if +len()+int() > len() {
		return 0, io.ErrShortBuffer
	}

	 := copy([:], )

	if .Padding {
		[++int(-1)] = 
	}

	return  +  + int(), nil
}

// MarshalSize returns the size of the packet once marshaled.
func ( Packet) () int {
	return .Header.MarshalSize() + len(.Payload) + int(.paddingSize())
}

// Clone returns a deep copy of p.
func ( Packet) () *Packet {
	 := &Packet{}
	.Header = .Header.Clone()
	if .Payload != nil {
		.Payload = make([]byte, len(.Payload))
		copy(.Payload, .Payload)
	}
	.PaddingSize = .PaddingSize

	return 
}

// Clone returns a deep copy h.
func ( Header) () Header {
	 := 
	if .CSRC != nil {
		.CSRC = make([]uint32, len(.CSRC))
		copy(.CSRC, .CSRC)
	}
	if .Extensions != nil {
		 := make([]Extension, len(.Extensions))
		for ,  := range .Extensions {
			[] = 
			if .payload != nil {
				[].payload = make([]byte, len(.payload))
				copy([].payload, .payload)
			}
		}
		.Extensions = 
	}

	return 
}

func ( *Packet) () byte {
	if .Header.PaddingSize > 0 {
		return .Header.PaddingSize
	}

	return .PaddingSize
}

// MarshalPacketTo serializes the header and payload into bytes.
// Parts of pion code passes RTP header and payload separately, so this function
// is provided to help with that.
//
// Deprecated: this function is a temporary workaround and will be removed in pion/webrtc v5.
func ( []byte,  *Header,  []byte) (int, error) {
	,  := .MarshalTo()
	if  != nil {
		return 0, 
	}

	return marshalPayloadAndPaddingTo(, , , , .PaddingSize)
}

// PacketMarshalSize returns the size of the header and payload once marshaled.
// Parts of pion code passes RTP header and payload separately, so this function
// is provided to help with that.
//
// Deprecated: this function is a temporary workaround and will be removed in pion/webrtc v5.
func ( *Header,  []byte) int {
	return .MarshalSize() + len() + int(.PaddingSize)
}

// HeaderAndPacketMarshalSize returns the size of the header and full packet once marshaled.
// Parts of pion code passes RTP header and payload separately, so this function
// is provided to help with that.
//
// Deprecated: this function is a temporary workaround and will be removed in pion/webrtc v5.
func ( *Header,  []byte) ( int,  int) {
	 = .MarshalSize()

	return ,  + len() + int(.PaddingSize)
}