// SPDX-FileCopyrightText: 2023 The Pion community <https://pion.ly>
// SPDX-License-Identifier: MIT

package srtp

import (
	"crypto/aes"
	"crypto/cipher"
	"encoding/binary"
	"fmt"

	"github.com/pion/rtp"
)

// srtpCipherAeadAesGcm protects SRTP and SRTCP packets with an AES-GCM AEAD
// (RFC 7714). The RTP and RTCP directions each get their own AEAD instance and
// session salt, derived from the master key/salt with distinct labels.
type srtpCipherAeadAesGcm struct {
	protectionProfileWithArgs

	// AEAD keyed with the SRTP session key; seals/opens RTP packets.
	srtpCipher cipher.AEAD
	// AEAD keyed with the SRTCP session key; seals/opens RTCP packets.
	srtcpCipher cipher.AEAD

	// Session salt XORed into every SRTP IV.
	srtpSessionSalt []byte
	// Session salt XORed into every SRTCP IV.
	srtcpSessionSalt []byte

	// Optional Master Key Identifier written after the auth tag of each
	// protected packet; empty when no MKI is in use.
	mki []byte

	// When false, SRTP payloads are only authenticated and travel in clear.
	srtpEncrypted bool
	// When false, SRTCP payloads are only authenticated and travel in clear.
	srtcpEncrypted bool
}

func newSrtpCipherAeadAesGcm(
	 protectionProfileWithArgs,
	, ,  []byte,
	,  bool,
) (*srtpCipherAeadAesGcm, error) {
	 := &srtpCipherAeadAesGcm{
		protectionProfileWithArgs: ,
		srtpEncrypted:             ,
		srtcpEncrypted:            ,
	}

	,  := aesCmKeyDerivation(labelSRTPEncryption, , , 0, len())
	if  != nil {
		return nil, 
	}

	,  := aes.NewCipher()
	if  != nil {
		return nil, 
	}

	.srtpCipher,  = cipher.NewGCM()
	if  != nil {
		return nil, 
	}

	,  := aesCmKeyDerivation(labelSRTCPEncryption, , , 0, len())
	if  != nil {
		return nil, 
	}

	,  := aes.NewCipher()
	if  != nil {
		return nil, 
	}

	.srtcpCipher,  = cipher.NewGCM()
	if  != nil {
		return nil, 
	}

	if .srtpSessionSalt,  = aesCmKeyDerivation(
		labelSRTPSalt, , , 0, len(),
	);  != nil {
		return nil, 
	} else if .srtcpSessionSalt,  = aesCmKeyDerivation(
		labelSRTCPSalt, , , 0, len(),
	);  != nil {
		return nil, 
	}

	 := len()
	if  > 0 {
		.mki = make([]byte, )
		copy(.mki, )
	}

	return , nil
}

func ( *srtpCipherAeadAesGcm) (
	 []byte,
	 *rtp.Header,
	 int,
	 []byte,
	 uint32,
	 bool,
) ( []byte,  error) {
	 := [:]
	 := len()

	// Grow the given buffer to fit the output.
	,  := .AEADAuthTagLen()
	if  != nil {
		return nil, 
	}
	 := .MarshalSize() + len() + 
	 :=  + len(.mki)
	if  {
		 += 4
	}
	 = growBufferSize(, )
	 := isSameBuffer(, )

	// Copy the header unencrypted.
	if ! {
		copy(, [:])
	}

	 := .rtpInitializationVector(, )
	if .srtpEncrypted {
		.srtpCipher.Seal([:], [:], , [:])
	} else {
		 :=  + 
		if ! {
			copy([:], )
		}
		.srtpCipher.Seal([:], [:], nil, [:])
	}

	// Add MKI after the encrypted payload
	if len(.mki) > 0 {
		copy([:], .mki)
	}

	if  {
		binary.BigEndian.PutUint32([len()-4:], )
	}

	return , nil
}

func ( *srtpCipherAeadAesGcm) (
	,  []byte,
	 *rtp.Header,
	 int,
	 uint32,
	 bool,
) ([]byte, error) {
	// Grow the given buffer to fit the output.
	,  := .AEADAuthTagLen()
	if  != nil {
		return nil, 
	}
	 := 0
	if  {
		 = 4
	}
	 := len() -  - len(.mki) - 
	if  <  {
		// Size of ciphertext is shorter than AEAD auth tag len.
		return nil, ErrFailedToVerifyAuthTag
	}
	 = growBufferSize(, )
	 := isSameBuffer(, )

	 := .rtpInitializationVector(, )

	 := len() - len(.mki) - 
	if .srtpEncrypted {
		if ,  := .srtpCipher.Open(
			[:], [:], [:], [:],
		);  != nil {
			return nil, fmt.Errorf("%w: %w", ErrFailedToVerifyAuthTag, )
		}
	} else {
		 :=  - 
		if ,  := .srtpCipher.Open(
			nil, [:], [:], [:],
		);  != nil {
			return nil, fmt.Errorf("%w: %w", ErrFailedToVerifyAuthTag, )
		}
		if ! {
			copy([:], [:])
		}
	}

	// Copy the header unencrypted.
	if ! {
		copy([:], [:])
	}

	return , nil
}

func ( *srtpCipherAeadAesGcm) (,  []byte,  uint32,  uint32) ([]byte, error) {
	,  := .AEADAuthTagLen()
	if  != nil {
		return nil, 
	}
	 := len() + 
	// Grow the given buffer to fit the output.
	 = growBufferSize(, +srtcpIndexSize+len(.mki))
	 := isSameBuffer(, )

	 := .rtcpInitializationVector(, )
	if .srtcpEncrypted {
		 := .rtcpAdditionalAuthenticatedData(, )
		if ! {
			// Copy the header unencrypted.
			copy([:srtcpHeaderSize], [:srtcpHeaderSize])
		}
		// Copy index to the proper place.
		copy([:+srtcpIndexSize], [8:12])
		.srtcpCipher.Seal([srtcpHeaderSize:srtcpHeaderSize], [:], [srtcpHeaderSize:], [:])
	} else {
		// Copy the packet unencrypted.
		if ! {
			copy(, )
		}
		// Append the SRTCP index to the end of the packet - this will form the AAD.
		binary.BigEndian.PutUint32([len():], )
		// Generate the authentication tag.
		 := make([]byte, )
		.srtcpCipher.Seal([0:0], [:], nil, [:len()+srtcpIndexSize])
		// Copy index to the proper place.
		copy([:], [len():len()+srtcpIndexSize])
		// Copy the auth tag after RTCP payload.
		copy([len():], )
	}

	copy([+srtcpIndexSize:], .mki)

	return , nil
}

func ( *srtpCipherAeadAesGcm) (,  []byte, ,  uint32) ([]byte, error) {
	 := len() - srtcpIndexSize - len(.mki)
	// Grow the given buffer to fit the output.
	,  := .AEADAuthTagLen()
	if  != nil {
		return nil, 
	}
	 :=  - 
	if  < 0 {
		// Size of ciphertext is shorter than AEAD auth tag len.
		return nil, ErrFailedToVerifyAuthTag
	}
	 = growBufferSize(, )
	 := isSameBuffer(, )

	 := []&srtcpEncryptionFlag != 0
	 := .rtcpInitializationVector(, )
	if  {
		 := .rtcpAdditionalAuthenticatedData(, )
		if ,  := .srtcpCipher.Open([srtcpHeaderSize:srtcpHeaderSize], [:], [srtcpHeaderSize:],
			[:]);  != nil {
			return nil, fmt.Errorf("%w: %w", ErrFailedToVerifyAuthTag, )
		}
	} else {
		// Prepare AAD for received packet.
		 :=  - 
		 := make([]byte, +4)
		copy(, [:])
		copy([:], [:+4])
		// Verify the auth tag.
		if ,  := .srtcpCipher.Open(nil, [:], [:], );  != nil {
			return nil, fmt.Errorf("%w: %w", ErrFailedToVerifyAuthTag, )
		}
		// Copy the unencrypted payload.
		if ! {
			copy([srtcpHeaderSize:], [srtcpHeaderSize:])
		}
	}

	// Copy the header unencrypted.
	if ! {
		copy([:srtcpHeaderSize], [:srtcpHeaderSize])
	}

	return , nil
}

// The 12-octet IV used by AES-GCM SRTP is formed by first concatenating
// 2 octets of zeroes, the 4-octet SSRC, the 4-octet rollover counter
// (ROC), and the 2-octet sequence number (SEQ).  The resulting 12-octet
// value is then XORed to the 12-octet salt to form the 12-octet IV.
//
// https://tools.ietf.org/html/rfc7714#section-8.1
func ( *srtpCipherAeadAesGcm) ( *rtp.Header,  uint32) [12]byte {
	var  [12]byte
	binary.BigEndian.PutUint32([2:], .SSRC)
	binary.BigEndian.PutUint32([6:], )
	binary.BigEndian.PutUint16([10:], .SequenceNumber)

	for  := range  {
		[] ^= .srtpSessionSalt[]
	}

	return 
}

// The 12-octet IV used by AES-GCM SRTCP is formed by first
// concatenating 2 octets of zeroes, the 4-octet SSRC identifier,
// 2 octets of zeroes, a single "0" bit, and the 31-bit SRTCP index.
// The resulting 12-octet value is then XORed to the 12-octet salt to
// form the 12-octet IV.
//
// https://tools.ietf.org/html/rfc7714#section-9.1
func ( *srtpCipherAeadAesGcm) ( uint32,  uint32) [12]byte {
	var  [12]byte

	binary.BigEndian.PutUint32([2:], )
	binary.BigEndian.PutUint32([8:], )

	for  := range  {
		[] ^= .srtcpSessionSalt[]
	}

	return 
}

// In an SRTCP packet, a 1-bit Encryption flag is prepended to the
// 31-bit SRTCP index to form a 32-bit value we shall call the
// "ESRTCP word"
//
// https://tools.ietf.org/html/rfc7714#section-17
func ( *srtpCipherAeadAesGcm) ( []byte,  uint32) [12]byte {
	var  [12]byte

	copy([:], [:8])
	binary.BigEndian.PutUint32([8:], )
	[8] |= srtcpEncryptionFlag

	return 
}

func ( *srtpCipherAeadAesGcm) ( []byte) uint32 {
	return binary.BigEndian.Uint32([len()-len(.mki)-srtcpIndexSize:]) &^ (srtcpEncryptionFlag << 24)
}