// SPDX-FileCopyrightText: 2023 The Pion community <https://pion.ly>
// SPDX-License-Identifier: MIT

package codecs

import (
	
)

// Bit positions and masks of the one-byte AV1 aggregation header:
//
//	 0 1 2 3 4 5 6 7
//	+-+-+-+-+-+-+-+-+
//	|Z|Y| W |N|-|-|-|
//	+-+-+-+-+-+-+-+-+
//
// Every mask is derived from its shift, so the two values cannot drift apart.
const (
	av1ZBitshift = 7
	av1ZMask     = byte(1) << av1ZBitshift

	av1YBitshift = 6
	av1YMask     = byte(1) << av1YBitshift

	// W is the only two-bit field in the header.
	av1WBitshift = 4
	av1WMask     = byte(0b11) << av1WBitshift

	av1NBitshift = 3
	av1NMask     = byte(1) << av1NBitshift
)

// AV1Payloader payloads AV1 packets.
// The type has no fields, so the zero value is ready to use.
type AV1Payloader struct{}

// Payload implements AV1 RTP payloader.
// Reads from an open_bitstream_unit (OBU) framing stream as defined in
// 5.3. https://aomediacodec.github.io/av1-spec/av1-spec.pdf#page=39
// Returns AV1 RTP packets https://aomediacodec.github.io/av1-rtp-spec/
// The payload is fragmented into multiple packets, each packet is a valid AV1 RTP payload.
//
// Temporal delimiter and tile list OBUs are skipped rather than emitted, and
// every emitted OBU is re-marshalled with its obu_has_size_field flag cleared.
//
// NOTE(review): in this copy of the file the receiver, function, parameter and
// local identifiers are blank, so the block does not compile as shown. Restore
// them before building; the comments below describe only what the remaining
// tokens establish.
//
//nolint:cyclop
func ( *AV1Payloader) ( uint16,  []byte) ( [][]byte) {
	// 2 is the minimum MTU for AV1 (aggregate header + 1 byte)
	if  <= 1 || len() == 0 {
		return 
	}

	// We maximize the use of the W field in the AV1 aggregation header
	// to minimize the need for explicit length fields for each OBU.
	// To achieve this, we temporarily hold the OBU payload before adding it to a packet.
	// Since we can't determine in advance whether the next OBU should be included in the same packet
	// or start a new one, we also can't know ahead of time if an OBU is the last in the current packet.
	var  []byte
	var  *obu.ExtensionHeader
	 := 0
	 := false
	 := false

	for  := 0;  < len(); {
		,  := obu.ParseOBUHeader([:])
		// A header that fails to parse ends the loop; the conditional flush
		// after the loop still runs.
		if  != nil {
			break
		}

		 += .Size()
		//  if ( obu_has_size_field ) {
		//    obu_size leb128()
		//  } else {
		//    obu_size = sz - 1 - obu_extension_flag
		//  }
		var  int
		if .HasSizeField {
			, ,  := obu.ReadLeb128([:])
			if  != nil {
				break
			}

			 += int()            //nolint:gosec // G115, leb128 size is a single digit
			 = int() //nolint:gosec // G115, Leb128 is capped at 4 bytes
		} else {
			 = len() - 
		}

		// Each RTP packet MUST NOT contain OBUs that belong to different temporal units.
		// If a sequence header OBU is present in an RTP packet, then it SHOULD be the first OBU in the packet.
		// https://aomediacodec.github.io/av1-rtp-spec/#5-packetization-rules
		 := .Type == obu.OBUTemporalDelimiter || .Type == obu.OBUSequenceHeader
		// If more than one OBU contained in an RTP packet has an OBU extension header,
		// then the values of the temporal_id and spatial_id MUST be the same in all such OBUs in the RTP packet.
		if ! && .ExtensionHeader != nil &&  != nil {
			 = .ExtensionHeader.SpatialID != .SpatialID ||
				.ExtensionHeader.TemporalID != .TemporalID
		}

		if .ExtensionHeader != nil {
			 = .ExtensionHeader
		}

		// NOTE(review): presumably rejects an OBU size that runs past the end
		// of the input — confirm once the identifiers are restored.
		if  > len()- {
			break
		}

		// NOTE(review): presumably flushes the OBU held from the previous
		// iteration now that it is known whether a new packet must start.
		if len() > 0 {
			,  = .appendOBUPayload(
				,
				,
				,
				,
				,
				int(),
				,
			)
			 = nil
			 = 

			if  {
				 = false
				 = nil
			}
		}

		// The temporal delimiter OBU, if present, SHOULD be removed when transmitting,
		// and MUST be ignored by receivers. Tile list OBUs are not supported.
		// They SHOULD be removed when transmitted, and MUST be ignored by receivers.
		// https://aomediacodec.github.io/av1-rtp-spec/#5-packetization-rules
		if .Type == obu.OBUTileList || .Type == obu.OBUTemporalDelimiter {
			 += 

			continue
		}

		 = make([]byte, +.Size())
		// The AV1 specification allows OBUs to have an optional size field called obu_size
		// (also leb128 encoded), signaled by the obu_has_size_field flag in the OBU header.
		// To minimize overhead, the obu_has_size_field flag SHOULD be set to zero in all OBUs.
		// https://aomediacodec.github.io/av1-rtp-spec/#45-payload-structure
		.HasSizeField = false
		copy(, .Marshal())
		//nolint:gosec // G115 we validate the size of the payload
		copy([.Size():], [:+])
		 += 
		 = .Type == obu.OBUSequenceHeader
	}

	// Final flush. The fourth argument is hard-coded to true here (presumably
	// the "last OBU" flag — confirm), and the second return value is discarded.
	if len() > 0 {
		, _ = .appendOBUPayload(
			,
			,
			,
			true,
			,
			int(),
			,
		)
	}

	return 
}

// appendOBUPayload (name taken from the call sites in this file) writes one OBU
// payload into a slice of RTP payloads, each of which starts with a one-byte
// AV1 aggregation header, and returns the updated slice together with an int.
//
// From the visible tokens and the comments below:
//   - a fresh payload is started with make([]byte, 1, ...), and its N bit may be set;
//   - the OBU is written either by setting the W field and appending the bytes
//     directly, or by prefixing it with a leb128 length from obu.WriteToLeb128;
//   - bytes that do not fit are carried into further payloads, with the Y bit
//     set on the previous payload and the Z bit on the continuation.
//
// NOTE(review): in this copy of the file the receiver, function, parameter and
// local identifiers are blank, so the block does not compile as shown and the
// meaning of each parameter and of the int result cannot be confirmed from
// here. Restore the identifiers before building.
//
//nolint:cyclop
func ( *AV1Payloader) (
	 [][]byte,
	 []byte,
	, ,  bool,
	,  int,
) ([][]byte, int) {
	 := len() - 1
	 := 0
	if  >= 0 {
		 =  - len([])
	}

	// Start a new payload: a single zeroed aggregation-header byte.
	if  < 0 ||  <= 0 ||  {
		 := make([]byte, 1, )
		if  {
			[0] |= 1 << av1NBitshift
		}

		 = append(, )
		 = len() - 1
		// MTU - aggregation header
		 =  - 1
		 = 0
	}

	 := len()
	// How much to write to the current packet.
	 := 
	if  >=  {
		 = 
	}

	// W: two bit field that describes the number of OBU elements in the packet.
	// This field MUST be set equal to 0 or equal to the number of OBU elements contained in the packet.
	// If set to 0, each OBU element MUST be preceded by a length field. If not set to 0 (i.e., W = 1, 2 or 3)
	// the last OBU element MUST NOT be preceded by a length field.
	// https://aomediacodec.github.io/av1-rtp-spec/#44-av1-aggregation-header
	 := ( ||  >= ) &&  < 3
	switch {
	case :
		[][0] |= byte((+1)<<av1WBitshift) & av1WMask
		[] = append([], [:]...)
		 = 0
	case  >= 2:
		// 2 bytes is the minimum size for OBUs with length field.
		// [1 byte for the length field] [1 byte for the OBU]
		//nolint:gosec // G115 false positive
		 = .computeWriteSize(, )
		 := obu.WriteToLeb128(uint()) //nolint:gosec // G115 false positive
		[] = append([], ...)
		[] = append([], [:]...)
		++
	default:
		// If we can't fit any more OBUs in the current packet (only 1 byte left and W=0)
		 = 0
	}

	 = [:]
	 -= 

	// Handle fragments.
	for  > 0 {
		// New packet with empty aggregation header.
		 := make([]byte, 1, )
		 = append(, )
		++

		// Append the Y bit to the previous packet. And Z bit to the current packet.
		// If we wrote some bytes to the previous packet.
		// Handles an edge case where the previous packet has only one byte remaining,
		// while the W field is not used. This results in insufficient space
		// for a one-byte length field and a one-byte OBU.
		// So we don't write anything to the initial packet.
		if  != 0 {
			[-1][0] |= av1YMask
			[][0] |= av1ZMask
		}

		 = 
		if  >= -1 { // MTU - aggregation header
			 =  - 1
		}

		// Last OBU in the current packet, Or this whole packet is a fragment.
		if  ||  >= -1 {
			// W = 1: a single OBU element with no length field.
			[][0] |= 1 << av1WBitshift
		} else {
			 = .computeWriteSize(, -1)
			 := obu.WriteToLeb128(uint()) //nolint:gosec // G115 false positive
			[] = append([], ...)
		}

		[] = append([], [:]...)
		 = [:]
		 -= 
		 = 1
	}

	return , 
}

// Measure the maximum write size for a payload with leb128 encoding added.
//
// The method takes two ints and returns one; at its call sites in this file
// it is invoked as computeWriteSize. It asks leb128Size for an encoded length
// and an "exactly on a size boundary" flag, then returns one of three values
// depending on how the two inputs compare once that length is added.
//
// NOTE(review): in this copy of the file the receiver, function, parameter and
// local identifiers are blank, so the block does not compile as shown and the
// role of each operand cannot be confirmed from here. Restore the identifiers
// before building.
func ( *AV1Payloader) (,  int) int {
	,  := .leb128Size()
	if  >= + {
		return 
	}

	// Handle edge case where subtracting one from the leb128 size
	// results in a smaller leb128 size that can fit in the remaining space.
	if  &&  >= +-1 {
		return  - 1
	}

	return  - 
}

// leb128Size (name taken from the call in this file) maps an int onto a byte
// count between 1 and 5 using the thresholds 2^7, 2^14, 2^21 and 2^28, which
// correspond to 7 payload bits per leb128 byte.
//
// The bool result is true only when the compared value equals the threshold of
// the case that matched, and is always false in the one-byte default case.
//
// NOTE(review): in this copy of the file the receiver, function, parameter and
// result identifiers are blank, so the block does not compile as shown.
// Presumably every comparison is against the single int parameter — confirm
// once the identifiers are restored.
func ( *AV1Payloader) ( int) ( int,  bool) {
	switch {
	case  >= 268435456: // 2^28
		return 5,  == 268435456
	case  >= 2097152: // 2^21
		return 4,  == 2097152
	case  >= 16384: // 2^14
		return 3,  == 16384
	case  >= 128: // 2^7
		return 2,  == 128
	default:
		return 1, false
	}
}

// AV1Packet represents a depacketized AV1 RTP Packet.
/*
*  0 1 2 3 4 5 6 7
* +-+-+-+-+-+-+-+-+
* |Z|Y| W |N|-|-|-|
* +-+-+-+-+-+-+-+-+
**/
// https://aomediacodec.github.io/av1-rtp-spec/#44-av1-aggregation-header
//
// Deprecated: Use AV1Depacketizer instead.
type AV1Packet struct {
	// Z: MUST be set to 1 if the first OBU element is an
	//    OBU fragment that is a continuation of an OBU fragment
	//    from the previous packet, and MUST be set to 0 otherwise.
	Z bool

	// Y: MUST be set to 1 if the last OBU element is an OBU fragment
	//    that will continue in the next packet, and MUST be set to 0 otherwise.
	Y bool

	// W: two bit field that describes the number of OBU elements in the packet.
	//    This field MUST be set equal to 0 or equal to the number of OBU elements
	//    contained in the packet. If set to 0, each OBU element MUST be preceded by
	//    a length field. If not set to 0 (i.e., W = 1, 2 or 3) the last OBU element
	//    MUST NOT be preceded by a length field. Instead, the length of the last OBU
	//    element contained in the packet can be calculated as follows:
	// Length of the last OBU element =
	//    length of the RTP payload
	//  - length of aggregation header
	//  - length of previous OBU elements including length fields
	W byte

	// N: MUST be set to 1 if the packet is the first packet of a coded video sequence, and MUST be set to 0 otherwise.
	N bool

	// Each AV1 RTP Packet is a collection of OBU Elements. Each OBU Element may be a full OBU, or just a fragment of one.
	// AV1Frame provides the tools to construct a collection of OBUs from a collection of OBU Elements
	OBUElements [][]byte

	// zeroAllocation prevents populating the OBUElements field:
	// when it is true, Unmarshal decodes only the aggregation header
	// and skips parsing the body.
	zeroAllocation bool
}

// Unmarshal parses the passed byte slice and stores the result in the AV1Packet this method is called upon.
//
// The first byte is decoded as the aggregation header into the Z, Y, W and N
// fields. Unless zeroAllocation is set, everything after that byte is split
// into OBUElements by parseBody.
//
// Errors returned:
//   - errNilPacket and errShortPacket from the two guards at the top
//     (the latter on a length below 2);
//   - errIsKeyframeAndFragment when both Z and N are set;
//   - any error reported by parseBody.
//
// On success the result is a [1:] reslice, presumably the input without its
// aggregation header.
//
// NOTE(review): in this copy of the file the receiver, function, parameter and
// local identifiers are blank, so the block does not compile as shown. Restore
// them before building.
func ( *AV1Packet) ( []byte) ([]byte, error) {
	if  == nil {
		return nil, errNilPacket
	} else if len() < 2 {
		return nil, errShortPacket
	}

	// Decode the one-byte aggregation header.
	.Z = (([0] & av1ZMask) >> av1ZBitshift) != 0
	.Y = (([0] & av1YMask) >> av1YBitshift) != 0
	.N = (([0] & av1NMask) >> av1NBitshift) != 0
	.W = ([0] & av1WMask) >> av1WBitshift

	if .Z && .N {
		return nil, errIsKeyframeAndFragment
	}

	if !.zeroAllocation {
		,  := .parseBody([1:])
		if  != nil {
			return nil, 
		}
		.OBUElements = 
	}

	return [1:], nil
}

// parseBody (name taken from the call in Unmarshal) splits a byte slice into
// OBU elements and returns them as sub-slices.
//
// If OBUElements is already non-nil it is returned unchanged and nothing is
// parsed. Otherwise elements are read in a loop whose counter starts at 1:
// when a value compared against W matches (presumably that counter), the
// element has no length header and takes all remaining bytes; in every other
// case its length is read with obu.ReadLeb128.
//
// Returns errShortPacket when an element would extend past the end of the
// slice, or the error from obu.ReadLeb128.
//
// NOTE(review): in this copy of the file the receiver, function, parameter and
// local identifiers are blank, so the block does not compile as shown. Restore
// them before building.
func ( *AV1Packet) ( []byte) ([][]byte, error) {
	if .OBUElements != nil {
		return .OBUElements, nil
	}

	 := [][]byte{}

	var ,  uint
	 := uint(0)
	for  := 1; ; ++ {
		if  == uint(len()) {
			break
		}

		// If W bit is set the last OBU Element will have no length header
		if byte() == .W {
			 = 0
			 = uint(len()) - 
		} else {
			var  error
			, ,  = obu.ReadLeb128([:])
			if  != nil {
				return nil, 
			}
		}

		 += 
		if uint(len()) < + {
			return nil, errShortPacket
		}
		 = append(, [:+])
		 += 
	}

	return , nil
}