// SPDX-FileCopyrightText: 2023 The Pion community <https://pion.ly>
// SPDX-License-Identifier: MIT

package rtp

import (
	"encoding/binary"
	"errors"
	"fmt"
	"strings"

	"github.com/pion/rtp/codecs/av1/obu"
)

// Sentinel errors reported by VLA marshaling and unmarshaling. Most call
// sites wrap them with fmt.Errorf and %w, so match them with errors.Is.
var (
	// ErrVLATooShort is returned when payload is too short.
	ErrVLATooShort = errors.New("VLA payload too short")
	// ErrVLAInvalidStreamCount is returned when RTP stream count is invalid.
	ErrVLAInvalidStreamCount = errors.New("invalid RTP stream count in VLA")
	// ErrVLAInvalidStreamID is returned when RTP stream ID is invalid.
	ErrVLAInvalidStreamID = errors.New("invalid RTP stream ID in VLA")
	// ErrVLAInvalidSpatialID is returned when spatial ID is invalid.
	ErrVLAInvalidSpatialID = errors.New("invalid spatial ID in VLA")
	// ErrVLADuplicateSpatialID is returned when the same spatial ID appears more than once for one RTP stream.
	ErrVLADuplicateSpatialID = errors.New("duplicate spatial ID in VLA")
	// ErrVLAInvalidTemporalLayer is returned when temporal layer is invalid.
	ErrVLAInvalidTemporalLayer = errors.New("invalid temporal layer in VLA")
)

// SpatialLayer is a spatial layer in VLA.
//
// A layer is identified by the pair (RTPStreamID, SpatialID). When marshaling,
// RTPStreamID must be in 0..RTPStreamCount-1, SpatialID must be in 0..3,
// TargetBitrates must hold 1 to 4 entries, and the pair must be unique.
type SpatialLayer struct {
	RTPStreamID    int
	SpatialID      int
	TargetBitrates []int // target bitrates per temporal layer

	// Following members are valid only when VLA.HasResolutionAndFramerate is true.
	// On the wire Width and Height are each stored minus one in 16 bits and
	// Framerate is stored in a single byte.
	Width     int
	Height    int
	Framerate int
}

// VLA is a Video Layer Allocation (VLA) extension.
// See https://webrtc.googlesource.com/src/+/refs/heads/main/docs/native-code/rtp-hdrext/video-layers-allocation00
//
// ActiveSpatialLayer lists the active layers of all RTP streams. When
// HasResolutionAndFramerate is true, the Width, Height and Framerate of every
// listed layer are part of the encoded extension as well.
type VLA struct {
	RTPStreamID               int // 0-origin RTP stream ID (RID) this allocation is sent on (0..3)
	RTPStreamCount            int // Number of RTP streams (1..4)
	ActiveSpatialLayer        []SpatialLayer
	HasResolutionAndFramerate bool
}

// vlaMarshalingContext carries the intermediate results computed while
// preparing a VLA for marshaling.
//
//   - slMBs holds, per RTP stream ID, a bitmask with bit N set when spatial
//     layer N of that stream is active.
//   - sls indexes the active layers as sls[RTPStreamID][SpatialID]; unused
//     slots are nil.
//   - commonSLBM is the value produced by commonSLBMValues for slMBs; zero
//     means the per-stream bitmasks are written out individually.
//   - encodedTargetBitrates holds the LEB128 encoding of every target bitrate
//     in (RTP stream ID, spatial ID, temporal layer) order.
//   - requiredLen is the size in bytes of the payload to allocate.
type vlaMarshalingContext struct {
	slMBs                 [4]uint8
	sls                   [4][4]*SpatialLayer
	commonSLBM            uint8
	encodedTargetBitrates [][]byte
	requiredLen           int
}

func ( VLA) ( *vlaMarshalingContext) error {
	for  := 0;  < len(.ActiveSpatialLayer); ++ {
		 := .ActiveSpatialLayer[]
		if .RTPStreamID < 0 || .RTPStreamID >= .RTPStreamCount {
			return fmt.Errorf("invalid RTP streamID %d:%w", .RTPStreamID, ErrVLAInvalidStreamID)
		}
		if .SpatialID < 0 || .SpatialID >= 4 {
			return fmt.Errorf("invalid spatial ID %d: %w", .SpatialID, ErrVLAInvalidSpatialID)
		}
		if len(.TargetBitrates) == 0 || len(.TargetBitrates) > 4 {
			return fmt.Errorf("invalid temporal layer count %d: %w", len(.TargetBitrates), ErrVLAInvalidTemporalLayer)
		}
		.slMBs[.RTPStreamID] |= 1 << .SpatialID
		if .sls[.RTPStreamID][.SpatialID] != nil {
			return fmt.Errorf("duplicate spatial layer: %w", ErrVLADuplicateSpatialID)
		}
		.sls[.RTPStreamID][.SpatialID] = &
	}

	return nil
}

func ( VLA) ( *vlaMarshalingContext) {
	for  := 0;  < .RTPStreamCount; ++ {
		for  := 0;  < 4; ++ {
			if  := .sls[][];  != nil {
				for ,  := range .TargetBitrates {
					 := obu.WriteToLeb128(uint()) // nolint: gosec
					.encodedTargetBitrates = append(.encodedTargetBitrates, )
					.requiredLen += len()
				}
			}
		}
	}
}

func ( VLA) () (*vlaMarshalingContext, error) {
	// Validate RTPStreamCount
	if .RTPStreamCount <= 0 || .RTPStreamCount > 4 {
		return nil, ErrVLAInvalidStreamCount
	}
	// Validate RTPStreamID
	if .RTPStreamID < 0 || .RTPStreamID >= .RTPStreamCount {
		return nil, ErrVLAInvalidStreamID
	}

	 := &vlaMarshalingContext{}
	 := .preprocessForMashaling()
	if  != nil {
		return nil, 
	}

	.commonSLBM = commonSLBMValues(.slMBs[:])

	// RID, NS, sl_bm fields
	if .commonSLBM != 0 {
		.requiredLen = 1
	} else {
		.requiredLen = 3
	}

	// #tl fields
	.requiredLen += (len(.ActiveSpatialLayer)-1)/4 + 1

	.encodeTargetBitrates()

	if .HasResolutionAndFramerate {
		.requiredLen += len(.ActiveSpatialLayer) * 5
	}

	return , nil
}

// Marshal encodes VLA into a byte slice.
func ( VLA) () ([]byte, error) { // nolint: cyclop
	,  := .analyzeVLAForMarshaling()
	if  != nil {
		return nil, 
	}

	 := make([]byte, .requiredLen)
	 := 0

	// RID, NS, sl_bm fields
	[] = byte(.RTPStreamID<<6) | byte(.RTPStreamCount-1)<<4 | .commonSLBM

	if .commonSLBM == 0 {
		++
		for  := 0;  < .RTPStreamCount; ++ {
			if %2 == 0 {
				[+/2] |= .slMBs[] << 4
			} else {
				[+/2] |= .slMBs[]
			}
		}
		 += (.RTPStreamCount - 1) / 2
	}

	// #tl fields
	++
	var  int
	for  := 0;  < .RTPStreamCount; ++ {
		for  := 0;  < 4; ++ {
			if  := .sls[][];  != nil {
				if  >= 4 {
					 = 0
					++
				}
				[] |= byte(len(.TargetBitrates)-1) << (2 * (3 - ))
				++
			}
		}
	}

	// Target bitrate fields
	++
	for ,  := range .encodedTargetBitrates {
		 := len()
		copy([:], )
		 += 
	}

	// Resolution & framerate fields
	if .HasResolutionAndFramerate {
		for ,  := range .ActiveSpatialLayer {
			binary.BigEndian.PutUint16([+0:], uint16(.Width-1))  // nolint: gosec
			binary.BigEndian.PutUint16([+2:], uint16(.Height-1)) // nolint: gosec
			[+4] = byte(.Framerate)
			 += 5
		}
	}

	return , nil
}

// commonSLBMValues returns the spatial layer bitmask shared by all entries of
// slMBs, or 0 when two non-zero entries differ. Zero entries are skipped, so
// they neither contribute a value nor prevent a common one; an input with no
// non-zero entry yields 0.
func commonSLBMValues(slMBs []uint8) uint8 {
	var common uint8
	for _, slBM := range slMBs {
		if slBM == 0 {
			continue
		}
		if common == 0 {
			common = slBM

			continue
		}
		if slBM != common {
			return 0
		}
	}

	return common
}

// vlaUnmarshalingContext tracks the decoding state of one VLA payload:
// the bytes being parsed, the read position, the sl_bm value taken from the
// first byte and the resulting spatial layer bitmask of each RTP stream.
type vlaUnmarshalingContext struct {
	payload   []byte
	offset    int
	slBMField uint8
	slBMs     [4]uint8
}

// checkRemainingLen reports whether at least requiredLen bytes are left in
// the payload starting at the current offset.
func (ctx *vlaUnmarshalingContext) checkRemainingLen(requiredLen int) bool {
	return len(ctx.payload)-ctx.offset >= requiredLen
}

func ( *VLA) ( *vlaUnmarshalingContext) error {
	if !.checkRemainingLen(1) {
		return fmt.Errorf("failed to unmarshal VLA (offset=%d): %w", .offset, ErrVLATooShort)
	}
	.RTPStreamID = int(.payload[.offset] >> 6 & 0b11)
	.RTPStreamCount = int(.payload[.offset]>>4&0b11) + 1

	// sl_bm fields
	.slBMField = .payload[.offset] & 0b1111
	.offset++

	if .slBMField != 0 {
		for  := 0;  < .RTPStreamCount; ++ {
			.slBMs[] = .slBMField
		}
	} else {
		if !.checkRemainingLen((.RTPStreamCount-1)/2 + 1) {
			return fmt.Errorf("failed to unmarshal VLA (offset=%d): %w", .offset, ErrVLATooShort)
		}
		// slX_bm fields
		for  := 0;  < .RTPStreamCount; ++ {
			var  uint8
			if %2 == 0 {
				 = .payload[.offset+/2] >> 4 & 0b1111
			} else {
				 = .payload[.offset+/2] & 0b1111
			}
			.slBMs[] = 
		}
		.offset += 1 + (.RTPStreamCount-1)/2
	}

	return nil
}

func ( *VLA) ( *vlaUnmarshalingContext) error { // nolint: cyclop
	if !.checkRemainingLen(1) {
		return fmt.Errorf("failed to unmarshal VLA (offset=%d): %w", .offset, ErrVLATooShort)
	}

	var  int
	for  := 0;  < .RTPStreamCount; ++ {
		for  := 0;  < 4; ++ {
			if .slBMs[]&(1<<) == 0 {
				continue
			}
			if  >= 4 {
				 = 0
				.offset++
				if !.checkRemainingLen(1) {
					return fmt.Errorf("failed to unmarshal VLA (offset=%d): %w", .offset, ErrVLATooShort)
				}
			}
			 := int(.payload[.offset]>>(2*(3-))&0b11) + 1
			++
			 := SpatialLayer{
				RTPStreamID:    ,
				SpatialID:      ,
				TargetBitrates: make([]int, ),
			}
			.ActiveSpatialLayer = append(.ActiveSpatialLayer, )
		}
	}
	.offset++

	// target bitrates
	for ,  := range .ActiveSpatialLayer {
		for  := range .TargetBitrates {
			, ,  := obu.ReadLeb128(.payload[.offset:])
			if  != nil {
				return 
			}

			 := int() // nolint: gosec

			if !.checkRemainingLen() {
				return fmt.Errorf("failed to unmarshal VLA (offset=%d): %w", .offset, ErrVLATooShort)
			}
			.ActiveSpatialLayer[].TargetBitrates[] = int() // nolint: gosec
			.offset += 
		}
	}

	return nil
}

func ( *VLA) ( *vlaUnmarshalingContext) error {
	if !.checkRemainingLen(len(.ActiveSpatialLayer) * 5) {
		return fmt.Errorf("failed to unmarshal VLA (offset=%d): %w", .offset, ErrVLATooShort)
	}

	.HasResolutionAndFramerate = true

	for  := range .ActiveSpatialLayer {
		.ActiveSpatialLayer[].Width = int(binary.BigEndian.Uint16(.payload[.offset+0:])) + 1
		.ActiveSpatialLayer[].Height = int(binary.BigEndian.Uint16(.payload[.offset+2:])) + 1
		.ActiveSpatialLayer[].Framerate = int(.payload[.offset+4])
		.offset += 5
	}

	return nil
}

// Unmarshal decodes VLA from a byte slice.
func ( *VLA) ( []byte) (int, error) {
	 := &vlaUnmarshalingContext{
		payload: ,
	}

	 := .unmarshalSpatialLayers()
	if  != nil {
		return .offset, 
	}

	// #tl fields (build the list ActiveSpatialLayer at the same time)
	 = .unmarshalTemporalLayers()
	if  != nil {
		return .offset, 
	}

	if len(.payload) == .offset {
		return .offset, nil
	}

	// resolution & framerate (optional)
	 = .unmarshalResolutionAndFramerate()
	if  != nil {
		return .offset, 
	}

	return .offset, nil
}

// String makes VLA printable.
func ( VLA) () string {
	 := fmt.Sprintf("RID:%d,RTPStreamCount:%d", .RTPStreamID, .RTPStreamCount)
	var  []string
	for ,  := range .ActiveSpatialLayer {
		 := fmt.Sprintf("RTPStreamID:%d", .RTPStreamID)
		 += fmt.Sprintf(",TargetBitrates:%v", .TargetBitrates)
		if .HasResolutionAndFramerate {
			 += fmt.Sprintf(",Resolution:(%d,%d)", .Width, .Height)
			 += fmt.Sprintf(",Framerate:%d", .Framerate)
		}
		 = append(, )
	}
	 += fmt.Sprintf(",ActiveSpatialLayers:{%s}", strings.Join(, ","))

	return 
}