// SPDX-FileCopyrightText: 2023 The Pion community <https://pion.ly>
// SPDX-License-Identifier: MIT

package ciphersuite

import (
	
	
	
	
	

	
	
)

// Fixed sizes, in bytes, used by the GCM record protection in this file.
const (
	// gcmNonceLength is the size of the nonce buffer that Encrypt and
	// Decrypt assemble before calling the AEAD.
	gcmNonceLength = 12

	// gcmTagLength is the amount Decrypt subtracts from the remaining
	// record length when building the AEAD additional data, i.e. the
	// size of the GCM authentication tag.
	gcmTagLength = 16
)

// GCM provides an API to Encrypt/Decrypt DTLS 1.2 packets.
//
// It keeps one AEAD and one write IV per direction: the local pair is used
// when sealing outgoing records, the remote pair when opening incoming ones.
type GCM struct {
	// localGCM seals records in Encrypt.
	localGCM cipher.AEAD
	// remoteGCM opens records in Decrypt.
	remoteGCM cipher.AEAD
	// localWriteIV supplies the first 4 bytes of the nonce used by Encrypt.
	localWriteIV []byte
	// remoteWriteIV supplies the first 4 bytes of the nonce used by Decrypt.
	remoteWriteIV []byte
}

// NewGCM creates a DTLS GCM Cipher.
//
// It constructs two AES block ciphers with aes.NewCipher, wraps each one in
// GCM mode with cipher.NewGCM, and returns a *GCM holding both AEADs together
// with the local and remote write IVs. If any construction step fails, a nil
// *GCM is returned.
//
// NOTE(review): the function name, parameter names and local identifiers are
// missing from this block as it appears here. Presumably the four []byte
// parameters are, in order, the local key, local write IV, remote key and
// remote write IV — confirm against the upstream source before relying on it.
func (, , ,  []byte) (*GCM, error) {
	// First AES cipher, then the GCM AEAD built on top of it.
	,  := aes.NewCipher()
	if  != nil {
		return nil, 
	}
	,  := cipher.NewGCM()
	if  != nil {
		return nil, 
	}

	// Second AES cipher / GCM AEAD pair, built the same way.
	,  := aes.NewCipher()
	if  != nil {
		return nil, 
	}
	,  := cipher.NewGCM()
	if  != nil {
		return nil, 
	}

	return &GCM{
		localGCM:      ,
		localWriteIV:  ,
		remoteGCM:     ,
		remoteWriteIV: ,
	}, nil
}

// Encrypt encrypts a DTLS RecordLayer message.
//
// The input is split at recordlayer.HeaderSize, a gcmNonceLength-byte nonce
// is built from the first 4 bytes of localWriteIV plus bytes filled in by
// rand.Read, and the data is sealed with localGCM. The returned buffer is
// assembled from three parts and its length field is patched afterwards.
// An error is returned only if rand.Read fails.
//
// NOTE(review): receiver, parameter and local names are missing from this
// block as it appears here, so the exact operand of each call cannot be
// confirmed; the inline comments describe what is presumably intended.
func ( *GCM) ( *recordlayer.RecordLayer,  []byte) ([]byte, error) {
	// Split at the header boundary: everything after HeaderSize, then the
	// first HeaderSize bytes.
	 := [recordlayer.HeaderSize:]
	 = [:recordlayer.HeaderSize]

	// Nonce: 4 bytes taken from localWriteIV, remaining bytes from rand.Read.
	 := make([]byte, gcmNonceLength)
	copy(, .localWriteIV[:4])
	if ,  := rand.Read([4:]);  != nil {
		return nil, 
	}

	// Additional authenticated data is derived from the record header and a
	// length, then passed to Seal, which allocates its own output (dst is nil).
	 := generateAEADAdditionalData(&.Header, len())
	 := .localGCM.Seal(nil, , , )
	// Output is three consecutive parts; the middle one is a [4:] tail,
	// presumably the nonce bytes that did not come from localWriteIV.
	 := make([]byte, len()+len([4:])+len())
	copy(, )
	copy([len():], [4:])
	copy([len()+len([4:]):], )

	// Update recordLayer size to include explicit nonce
	// The length is the last two bytes of the header, big-endian, and counts
	// everything after the header.
	binary.BigEndian.PutUint16([recordlayer.HeaderSize-2:], uint16(len()-recordlayer.HeaderSize))
	return , nil
}

// Decrypt decrypts a DTLS RecordLayer message.
//
// The record header is unmarshalled first. Unmarshal errors are returned
// as-is, ChangeCipherSpec records are returned without being decrypted, and
// inputs no longer than the header plus 8 bytes are rejected with
// errNotEnoughRoomForNonce. Otherwise the nonce is rebuilt from the first
// 4 bytes of remoteWriteIV plus the 8 bytes following the header, the rest
// is opened with remoteGCM, and the header bytes followed by the opened
// data are returned. Failures from Open are wrapped in errDecryptPacket.
//
// NOTE(review): receiver, parameter and local names are missing from this
// block as it appears here, so the exact operand of each call cannot be
// confirmed; the inline comments describe what is presumably intended.
func ( *GCM) ( []byte) ([]byte, error) {
	var  recordlayer.Header
	 := .Unmarshal()
	switch {
	case  != nil:
		return nil, 
	case .ContentType == protocol.ContentTypeChangeCipherSpec:
		// Nothing to decrypt with ChangeCipherSpec
		return , nil
	case len() <= (8 + recordlayer.HeaderSize):
		// Must be longer than the header plus the 8 bytes read below as
		// part of the nonce.
		return nil, errNotEnoughRoomForNonce
	}

	// Nonce: 4 bytes from remoteWriteIV followed by the 8 bytes that sit
	// directly after the header.
	 := make([]byte, 0, gcmNonceLength)
	 = append(append(, .remoteWriteIV[:4]...), [recordlayer.HeaderSize:recordlayer.HeaderSize+8]...)
	// Everything after those 8 bytes is passed on to Open.
	 := [recordlayer.HeaderSize+8:]

	// The length fed into the additional data excludes gcmTagLength bytes.
	 := generateAEADAdditionalData(&, len()-gcmTagLength)
	// Open is given a zero-length slice as dst, so it appends its output
	// into that slice's backing storage when capacity allows.
	,  = .remoteGCM.Open([:0], , , )
	if  != nil {
		// %w wraps the sentinel so errors.Is matches errDecryptPacket; the
		// underlying error is only formatted with %v.
		return nil, fmt.Errorf("%w: %v", errDecryptPacket, ) //nolint:errorlint
	}
	// Result is the first HeaderSize bytes followed by the opened data.
	return append([:recordlayer.HeaderSize], ...), nil
}