// SPDX-FileCopyrightText: 2023 The Pion community <https://pion.ly>
// SPDX-License-Identifier: MIT

package srtp

import (
	
	

	
)

// Labels for the SRTP encryption key, authentication tag key and salt.
const (
	labelSRTPEncryption        = 0x00
	labelSRTPAuthenticationTag = 0x01
	labelSRTPSalt              = 0x02
)

// Labels for the SRTCP encryption key, authentication tag key and salt.
const (
	labelSRTCPEncryption        = 0x03
	labelSRTCPAuthenticationTag = 0x04
	labelSRTCPSalt              = 0x05
)

// Largest values of the 16-bit sequence number and of the 32-bit rollover counter.
const (
	maxSequenceNumber = 1<<16 - 1
	maxROC            = 1<<32 - 1
)

// Size of the 16-bit sequence number space and its midpoint.
const (
	seqNumMax    = 1 << 16
	seqNumMedian = seqNumMax >> 1
)

// srtpSSRCState is the encrypt/decrypt state for a single SRTP SSRC.
//
// index packs the rollover counter into the bits above bit 15 (Context.ROC reports index >> 16)
// and keeps a sequence-number part in the low 16 bits. rolloverHasProcessed is cleared by
// Context.SetROC and becomes true the next time the index is updated. replayDetector is
// obtained from Context.newSRTPReplayDetector when the state is created.
type srtpSSRCState struct {
	ssrc                 uint32
	rolloverHasProcessed bool
	index                uint64
	replayDetector       replaydetector.ReplayDetector
}

// srtcpSSRCState is the encrypt/decrypt state for a single SRTCP SSRC.
//
// srtcpIndex is the value reported by Context.Index and assigned by Context.SetIndex.
// replayDetector is obtained from Context.newSRTCPReplayDetector when the state is created.
type srtcpSSRCState struct {
	srtcpIndex     uint32
	ssrc           uint32
	replayDetector replaydetector.ReplayDetector
}

// RCCMode is the mode of the Roll-over Counter Carrying Transform from RFC 4771.
type RCCMode int

const (
	// RCCModeNone is the default mode: RCC is not used.
	RCCModeNone RCCMode = iota
	// RCCMode1 is the RCCm1 mode from RFC 4771. In this mode the ROC and a truncated auth tag are sent
	// in every R-th packet, and no auth tag is sent in the other ones. This mode is not supported by pion/srtp.
	RCCMode1
	// RCCMode2 is the RCCm2 mode from RFC 4771. In this mode the ROC and a truncated auth tag are sent
	// in every R-th packet, and the full auth tag in the other ones.
	// This mode is supported for AES-CM and NULL profiles only.
	RCCMode2
	// RCCMode3 is the RCCm3 mode from RFC 4771. In this mode the ROC is sent in every R-th packet
	// (without a truncated auth tag), and no auth tag is sent in the other ones.
	// This mode is supported for AES-GCM profiles only.
	RCCMode3
)

// Context represents a SRTP cryptographic context.
// Context can only be used for one-way operations:
// it must be used either ONLY for encryption or ONLY for decryption.
// Note that Context does not provide any concurrency protection:
// access to a Context from multiple goroutines requires external
// synchronization.
type Context struct {
	// Cipher built by CreateContext; SetSendMKI replaces it with the cipher registered for another MKI.
	cipher srtpCipher

	// Per-SSRC state, keyed by SSRC and created on first use.
	srtpSSRCStates  map[uint32]*srtpSSRCState
	srtcpSSRCStates map[uint32]*srtcpSSRCState

	// Factories called once for every newly created per-SSRC state.
	newSRTCPReplayDetector func() replaydetector.ReplayDetector
	newSRTPReplayDetector  func() replaydetector.ReplayDetector

	// Protection profile the Context was created with; it selects the cipher implementation.
	profile ProtectionProfile

	// Master Key Identifier used for encrypting RTP/RTCP packets. Set to nil if MKI is not enabled.
	sendMKI []byte
	// Master Key Identifier to cipher mapping. Used for decrypting packets. Empty if MKI is not enabled.
	mkis map[string]srtpCipher

	// Flags forwarded to cipher creation.
	encryptSRTP  bool
	encryptSRTCP bool

	// RCC (RFC 4771) settings; validated by checkRCCMode while the Context is being created.
	rccMode         RCCMode
	rocTransmitRate uint16

	// Custom SRTP auth tag length; nil when no custom length was configured.
	authTagRTPLen *int
}

// CreateContext creates a new SRTP Context.
//
// CreateContext receives a variable number of ContextOption-s.
// When several options set the same parameter, the last one takes effect.
// The following example creates an SRTP Context with replay protection and a window size of 256.
//
//	decCtx, err := srtp.CreateContext(key, salt, profile, srtp.SRTPReplayProtection(256))
//
// A nil Context and an error are returned when an option fails, when the RCC configuration is
// rejected, when a custom SRTP auth tag length exceeds the profile's AuthKeyLen
// (errTooLongSRTPAuthTag), or when the cipher cannot be created.
func (
	,  []byte,
	 ProtectionProfile,
	 ...ContextOption,
) ( *Context,  error) {
	// Start with empty per-SSRC state maps and an empty MKI-to-cipher map.
	 = &Context{
		srtpSSRCStates:  map[uint32]*srtpSSRCState{},
		srtcpSSRCStates: map[uint32]*srtcpSSRCState{},
		profile:         ,
		mkis:            map[string]srtpCipher{},
	}

	// The defaults come first in the list, so user options applied after them override them.
	for ,  := range append(
		[]ContextOption{ // Default options
			SRTPNoReplayProtection(),
			SRTCPNoReplayProtection(),
			SRTPEncryption(),
			SRTCPEncryption(),
		},
		..., // User specified options
	) {
		if  := ();  != nil {
			return nil, 
		}
	}

	// Validate the RCC settings chosen by the options against the profile.
	if  = .checkRCCMode();  != nil {
		return nil, 
	}

	// A custom SRTP auth tag length must not exceed the profile's AuthKeyLen.
	if .authTagRTPLen != nil {
		var  int
		,  = .profile.AuthKeyLen()
		if  != nil {
			return nil, 
		}
		if *.authTagRTPLen >  {
			return nil, errTooLongSRTPAuthTag
		}
	}

	// Build the cipher and, when a send MKI is configured, register it so it can be found by MKI.
	.cipher,  = .createCipher(.sendMKI, , , .encryptSRTP, .encryptSRTCP)
	if  != nil {
		return nil, 
	}
	if len(.sendMKI) != 0 {
		.mkis[string(.sendMKI)] = .cipher
	}

	return , nil
}

// AddCipherForMKI adds a new MKI with its associated master key and salt.
// The Context must have been created with the MasterKeyIndicator option
// to enable MKI support. The MKI must be unique and have the same length as the one used for creating the Context.
// Operation is not thread-safe, you need to provide synchronization with decrypting packets.
//
// It returns errMKIIsNotEnabled when the Context has no MKIs, errInvalidMKILength when the MKI
// is empty or its length differs from the send MKI, errMKIAlreadyInUse when the MKI is already
// registered, or the error from creating the cipher.
func ( *Context) (, ,  []byte) error {
	// The MKI map is only populated at creation time when a send MKI was configured.
	if len(.mkis) == 0 {
		return errMKIIsNotEnabled
	}
	// Every MKI must be non-empty and as long as the send MKI.
	if len() == 0 || len() != len(.sendMKI) {
		return errInvalidMKILength
	}
	// Refuse to overwrite an existing entry.
	if ,  := .mkis[string()];  {
		return errMKIAlreadyInUse
	}

	// The new cipher uses the same encryption flags as the Context.
	,  := .createCipher(, , , .encryptSRTP, .encryptSRTCP)
	if  != nil {
		return 
	}
	.mkis[string()] = 

	return nil
}

// createCipher builds the srtpCipher matching the Context's profile.
//
// It takes three byte slices and two bools; the call sites pass an MKI, two further byte slices
// (presumably the master key and master salt - confirm) and the encryptSRTP/encryptSRTCP flags.
// The key and salt lengths must equal the profile's KeyLen and SaltLen exactly: any mismatch,
// shorter or longer, is reported by wrapping errShortSrtpMasterKey or errShortSrtpMasterSalt.
// A profile that is not listed in the switch yields errNoSuchSRTPProfile.
func ( *Context) (, ,  []byte, ,  bool) (srtpCipher, error) {
	,  := .profile.KeyLen()
	if  != nil {
		return nil, 
	}

	,  := .profile.SaltLen()
	if  != nil {
		return nil, 
	}

	// Exact-length check; the error messages carry both the expected and the actual length.
	if  := len();  !=  {
		return nil, fmt.Errorf("%w expected(%d) actual(%d)", errShortSrtpMasterKey, , )
	} else if  := len();  !=  {
		return nil, fmt.Errorf("%w expected(%d) actual(%d)", errShortSrtpMasterSalt, , )
	}

	// Bundle the profile with the optional custom SRTP auth tag length for the cipher constructors.
	 := protectionProfileWithArgs{
		ProtectionProfile: .profile,
		authTagRTPLen:     .authTagRTPLen,
	}

	switch .profile {
	case ProtectionProfileAeadAes128Gcm, ProtectionProfileAeadAes256Gcm:
		return newSrtpCipherAeadAesGcm(, , , , , )
	case ProtectionProfileAes128CmHmacSha1_32,
		ProtectionProfileAes128CmHmacSha1_80,
		ProtectionProfileAes256CmHmacSha1_32,
		ProtectionProfileAes256CmHmacSha1_80:
		return newSrtpCipherAesCmHmacSha1(, , , , , )
	case ProtectionProfileNullHmacSha1_32, ProtectionProfileNullHmacSha1_80:
		// NULL profiles reuse the AES-CM/HMAC-SHA1 constructor with its last two arguments forced
		// to false (presumably the SRTP and SRTCP encryption flags - confirm).
		return newSrtpCipherAesCmHmacSha1(, , , , false, false)
	default:
		return nil, fmt.Errorf("%w: %#v", errNoSuchSRTPProfile, .profile)
	}
}

// RemoveMKI removes one of the MKIs. You cannot remove the MKI that is used for encrypting RTP/RTCP
// packets, which also means the last remaining MKI cannot be removed.
// Operation is not thread-safe, you need to provide synchronization with decrypting packets.
//
// It returns ErrMKINotFound when the MKI is not registered and errMKIAlreadyInUse when it is
// the current send MKI.
func ( *Context) ( []byte) error {
	if ,  := .mkis[string()]; ! {
		return ErrMKINotFound
	}
	// The send MKI stays registered for as long as it is selected.
	if bytes.Equal(, .sendMKI) {
		return errMKIAlreadyInUse
	}
	delete(.mkis, string())

	return nil
}

// SetSendMKI switches the MKI and cipher used for encrypting RTP/RTCP packets.
// The MKI must already be registered, otherwise ErrMKINotFound is returned and nothing changes.
// Operation is not thread-safe, you need to provide synchronization with encrypting packets.
func ( *Context) ( []byte) error {
	,  := .mkis[string()]
	if ! {
		return ErrMKINotFound
	}
	// Switch the send MKI and the active cipher together.
	.sendMKI = 
	.cipher = 

	return nil
}

// Estimates the rollover counter (ROC) for a 16-bit sequence number, see
// https://tools.ietf.org/html/rfc3550#appendix-A.1
//
// The stored index is split into its upper bits (index >> 16, the locally known ROC) and its
// low 16 bits (presumably the most recent sequence number). The results are, in order: a
// uint32 ROC guess, an int64 signed distance from the stored sequence number, and a bool that
// is true only when one compared value is 0 while the other equals maxROC (presumably
// signalling that the 32-bit ROC wrapped around - confirm against callers).
// The state itself is not modified here.
func ( *srtpSSRCState) ( uint16) ( uint32,  int64,  bool) {
	 := int32()
	 := uint32(.index >> 16)            //nolint:gosec // G115
	 := int32(.index & (seqNumMax - 1)) //nolint:gosec // G115

	 := 
	var  int32

	// Until rolloverHasProcessed is set, the initial values declared above are returned as they are.
	if .rolloverHasProcessed { //nolint:nestif
		// When the local ROC is 0, taking the seq-localSeq > seqNumMedian branch below would
		// produce a wrong guessed ROC, so the wrap-around logic only runs for index > seqNumMedian.
		if .index > seqNumMedian {
			if  < seqNumMedian {
				// Lower half of the sequence space: a gap above seqNumMedian is treated as
				// belonging to the previous rollover period (ROC - 1, distance - seqNumMax).
				if - > seqNumMedian {
					 =  - 1
					 =  -  - seqNumMax
				} else {
					 = 
					 =  - 
				}
			} else {
				// Upper half of the sequence space: a value more than seqNumMedian behind is
				// treated as belonging to the next rollover period (ROC + 1, distance + seqNumMax).
				if -seqNumMedian >  {
					 =  + 1
					 =  -  + seqNumMax
				} else {
					 = 
					 =  - 
				}
			}
		} else {
			// localRoc is equal to 0
			 =  - 
		}
	}

	return , int64(), ( == 0 &&  == maxROC)
}

// Updates the stored index of the SSRC state.
//
// NOTE(review): judging by their types the parameters are presumably a uint16 sequence number,
// an int64 distance, a bool overflow flag and a uint32 ROC - confirm against the callers.
func ( *srtpSSRCState) ( uint16,  int64,  bool,
	 uint32,
) {
	switch {
	case :
		// Flag set: rebuild the index from scratch, one value shifted into the bits above
		// bit 15 and the other placed in the low 16 bits.
		.index = (uint64() << 16) | uint64()
		.rolloverHasProcessed = true
	case !.rolloverHasProcessed:
		// First update after creation or SetROC: the low 16 bits of index are still zero,
		// so OR-ing fills them in without touching the upper bits.
		.index |= uint64()
		.rolloverHasProcessed = true
	case  > 0:
		// Only a positive value advances the index; zero or negative values leave it unchanged.
		.index += uint64()
	}
}

// getSRTPSSRCState returns the SRTP state stored for the given SSRC.
// When none exists yet, a new state with a fresh replay detector is created,
// stored in srtpSSRCStates and returned, so the result is never nil.
func ( *Context) ( uint32) *srtpSSRCState {
	,  := .srtpSSRCStates[]
	if  {
		return 
	}

	// First use of this SSRC: index and rolloverHasProcessed start at their zero values.
	 = &srtpSSRCState{
		ssrc:           ,
		replayDetector: .newSRTPReplayDetector(),
	}
	.srtpSSRCStates[] = 

	return 
}

// getSRTCPSSRCState returns the SRTCP state stored for the given SSRC.
// When none exists yet, a new state with a fresh replay detector is created,
// stored in srtcpSSRCStates and returned, so the result is never nil.
func ( *Context) ( uint32) *srtcpSSRCState {
	,  := .srtcpSSRCStates[]
	if  {
		return 
	}

	// First use of this SSRC: srtcpIndex starts at its zero value.
	 = &srtcpSSRCState{
		ssrc:           ,
		replayDetector: .newSRTCPReplayDetector(),
	}
	.srtcpSSRCStates[] = 

	return 
}

// ROC returns the SRTP rollover counter value of the specified SSRC.
// The second result is false, and the counter 0, when no state exists for the SSRC.
// Unlike SetROC, this lookup never creates state.
func ( *Context) ( uint32) (uint32, bool) {
	,  := .srtpSSRCStates[]
	if ! {
		return 0, false
	}

	// The rollover counter is kept in the bits of index above the low 16 bits.
	return uint32(.index >> 16), true //nolint:gosec // G115
}

// SetROC sets the SRTP rollover counter value of the specified SSRC.
// The state for the SSRC is created if it does not exist yet. The low 16 bits of the index
// are reset to zero and rolloverHasProcessed is cleared.
func ( *Context) ( uint32,  uint32) {
	 := .getSRTPSSRCState()
	.index = uint64() << 16
	.rolloverHasProcessed = false
}

// Index returns the SRTCP index value of the specified SSRC.
// The second result is false, and the index 0, when no state exists for the SSRC.
// Unlike SetIndex, this lookup never creates state.
func ( *Context) ( uint32) (uint32, bool) {
	,  := .srtcpSSRCStates[]
	if ! {
		return 0, false
	}

	return .srtcpIndex, true
}

// SetIndex sets the SRTCP index value of the specified SSRC.
// The state for the SSRC is created if it does not exist yet. Values above maxSRTCPIndex
// wrap around, because the stored value is reduced modulo maxSRTCPIndex + 1.
func ( *Context) ( uint32,  uint32) {
	 := .getSRTCPSSRCState()
	.srtcpIndex =  % (maxSRTCPIndex + 1)
}

// checkRCCMode validates the RCC (RFC 4771) settings of the Context against its profile.
//
// It returns nil when RCC is disabled (RCCModeNone) or the combination is accepted,
// errZeroRocTransmitRate when rocTransmitRate is 0, errUnsupportedRccMode when the mode does not
// fit the profile family or the profile is unknown, and errTooShortSRTPAuthTag when a _32 profile
// has no custom auth tag length or when a custom length is below 4.
//
//nolint:cyclop
func ( *Context) () error {
	if .rccMode == RCCModeNone {
		return nil
	}

	// Checked for every enabled RCC mode, before the profile is looked at.
	if .rocTransmitRate == 0 {
		return errZeroRocTransmitRate
	}

	switch .profile {
	case ProtectionProfileAeadAes128Gcm, ProtectionProfileAeadAes256Gcm:
		// AEAD profiles support RCCMode3 only
		if .rccMode != RCCMode3 {
			return errUnsupportedRccMode
		}

	case ProtectionProfileAes128CmHmacSha1_32,
		ProtectionProfileAes256CmHmacSha1_32,
		ProtectionProfileNullHmacSha1_32:
		if .authTagRTPLen == nil {
			// ROC completely replaces auth tag for _32 profiles. If you really want to use 4-byte
			// SRTP auth tag with RCC, use SRTPAuthenticationTagLength(4) option.
			return errTooShortSRTPAuthTag
		}

		fallthrough // Checks below are common for _32 and _80 profiles.

	case ProtectionProfileAes128CmHmacSha1_80,
		ProtectionProfileAes256CmHmacSha1_80,
		ProtectionProfileNullHmacSha1_80:
		// AES-CM and NULL profiles support RCCMode2 only
		if .rccMode != RCCMode2 {
			return errUnsupportedRccMode
		}
		// A custom auth tag length, when set, must be at least 4.
		if .authTagRTPLen != nil && *.authTagRTPLen < 4 {
			return errTooShortSRTPAuthTag
		}

	default:
		return errUnsupportedRccMode
	}

	return nil
}