// SPDX-FileCopyrightText: 2023 The Pion community <https://pion.ly>
// SPDX-License-Identifier: MIT

package ciphersuite

import (
	"crypto/aes"
	"crypto/cipher"
	"crypto/rand"
	"encoding/binary"
	"fmt"

	"github.com/pion/dtls/v3/pkg/protocol"
	"github.com/pion/dtls/v3/pkg/protocol/recordlayer"
)

// Sizes, in bytes, used when framing AES-GCM protected records.
const (
	// gcmNonceLength is the full nonce handed to the AEAD: a 4-byte implicit
	// part taken from the write IV followed by an 8-byte explicit part that
	// travels in each record.
	gcmNonceLength = 12

	// gcmTagLength is the size of the authentication tag that trails the
	// ciphertext.
	gcmTagLength = 16
)

// GCM provides an API to encrypt and decrypt DTLS 1.2 packets. It keeps one
// AEAD and write IV for the local (outgoing) direction and one pair for the
// remote (incoming) direction.
type GCM struct {
	// AEAD instances: localGCM seals what we send, remoteGCM opens what we receive.
	localGCM  cipher.AEAD
	remoteGCM cipher.AEAD

	// Write IVs whose leading bytes seed the per-record nonce for each direction.
	localWriteIV  []byte
	remoteWriteIV []byte
}

// NewGCM creates a DTLS GCM Cipher.
func (, , ,  []byte) (*GCM, error) {
	,  := aes.NewCipher()
	if  != nil {
		return nil, 
	}
	,  := cipher.NewGCM()
	if  != nil {
		return nil, 
	}

	,  := aes.NewCipher()
	if  != nil {
		return nil, 
	}
	,  := cipher.NewGCM()
	if  != nil {
		return nil, 
	}

	return &GCM{
		localGCM:      ,
		localWriteIV:  ,
		remoteGCM:     ,
		remoteWriteIV: ,
	}, nil
}

// Encrypt encrypt a DTLS RecordLayer message.
func ( *GCM) ( *recordlayer.RecordLayer,  []byte) ([]byte, error) {
	 := [.Header.Size():]
	 = [:.Header.Size()]

	 := make([]byte, gcmNonceLength)
	copy(, .localWriteIV[:4])
	if ,  := rand.Read([4:]);  != nil {
		return nil, 
	}

	var  []byte
	if .Header.ContentType == protocol.ContentTypeConnectionID {
		 = generateAEADAdditionalDataCID(&.Header, len())
	} else {
		 = generateAEADAdditionalData(&.Header, len())
	}
	 := .localGCM.Seal(nil, , , )
	 := make([]byte, len()+len([4:])+len())
	copy(, )
	copy([len():], [4:])
	copy([len()+len([4:]):], )

	// Update recordLayer size to include explicit nonce
	binary.BigEndian.PutUint16([.Header.Size()-2:], uint16(len()-.Header.Size())) //nolint:gosec //G115

	return , nil
}

// Decrypt decrypts a DTLS RecordLayer message.
func ( *GCM) ( recordlayer.Header,  []byte) ([]byte, error) {
	 := .Unmarshal()
	switch {
	case  != nil:
		return nil, 
	case .ContentType == protocol.ContentTypeChangeCipherSpec:
		// Nothing to encrypt with ChangeCipherSpec
		return , nil
	case len() <= (8 + .Size()):
		return nil, errNotEnoughRoomForNonce
	}

	 := make([]byte, 0, gcmNonceLength)
	 = append(append(, .remoteWriteIV[:4]...), [.Size():.Size()+8]...)
	 := [.Size()+8:]

	var  []byte
	if .ContentType == protocol.ContentTypeConnectionID {
		 = generateAEADAdditionalDataCID(&, len()-gcmTagLength)
	} else {
		 = generateAEADAdditionalData(&, len()-gcmTagLength)
	}
	,  = .remoteGCM.Open([:0], , , )
	if  != nil {
		return nil, fmt.Errorf("%w: %v", errDecryptPacket, ) //nolint:errorlint
	}

	return append([:.Size()], ...), nil
}