// SPDX-FileCopyrightText: 2023 The Pion community <https://pion.ly>
// SPDX-License-Identifier: MIT

//go:build !js
// +build !js

package webrtc

import (
	
	
	
	
	
	
	
	
	
	
	

	
	
	
	
	
	
	
	
	
)

// DTLSTransport allows an application access to information about the DTLS
// transport over which RTP and RTCP packets are sent and received by
// RTPSender and RTPReceiver, as well as other data such as SCTP packets sent
// and received by data channels.
type DTLSTransport struct {
	// lock guards the mutable state below. onStateChange expects its caller
	// to hold it, and the state-change handler therefore runs under it.
	lock sync.RWMutex

	// iceTransport supplies the mux endpoints (newEndpoint) used for DTLS,
	// SRTP and SRTCP traffic.
	iceTransport          *ICETransport
	// certificates are the local certificates; Start presents certificates[0].
	certificates          []Certificate
	// remoteParameters is recorded by Start and consulted by role() and
	// validateFingerPrint.
	remoteParameters      DTLSParameters
	// remoteCertificate is the raw (DER) first certificate the peer presented.
	remoteCertificate     []byte
	// state is the current transport state, updated through onStateChange.
	state                 DTLSTransportState
	// srtpProtectionProfile is mapped from the DTLS-negotiated profile in
	// Start and used by startSRTP.
	srtpProtectionProfile srtp.ProtectionProfile

	// onStateChangeHandler is the user callback set via OnStateChange.
	onStateChangeHandler   func(DTLSTransportState)
	// internalOnCloseHandler is installed as the DTLS mux endpoint's close
	// callback in Start. NOTE(review): it is not assigned in this file;
	// presumably set by the owning PeerConnection — confirm.
	internalOnCloseHandler func()

	// conn is the DTLS connection, stored only after the handshake and
	// certificate checks in Start succeed.
	conn *dtls.Conn

	// srtpSession / srtcpSession hold *srtp.SessionSRTP / *srtp.SessionSRTCP
	// once startSRTP succeeds; they are read without the lock.
	srtpSession, srtcpSession   atomic.Value
	// srtpEndpoint / srtcpEndpoint are the mux endpoints created in Start.
	srtpEndpoint, srtcpEndpoint *mux.Endpoint
	// simulcastStreams are read-stream pairs registered via
	// storeSimulcastStream; Stop closes them.
	simulcastStreams            []simulcastStreamPair
	// srtpReady is closed by startSRTP after both sessions are stored.
	srtpReady                   chan struct{}

	// dtlsMatcher is initialised to mux.MatchDTLS by NewDTLSTransport.
	dtlsMatcher mux.MatchFunc

	// api provides the SettingEngine and interceptor configuration.
	api *API
	log logging.LeveledLogger
}

// simulcastStreamPair holds an SRTP read stream and its companion SRTCP read
// stream, registered together via storeSimulcastStream. Stop closes both
// members of every stored pair. Field order matters: it is constructed with
// an unkeyed composite literal.
type simulcastStreamPair struct {
	srtp  *srtp.ReadStreamSRTP
	srtcp *srtp.ReadStreamSRTCP
}

// NewDTLSTransport creates a new DTLSTransport.
// This constructor is part of the ORTC API. It is not
// meant to be used together with the basic WebRTC API.
func ( *API) ( *ICETransport,  []Certificate) (*DTLSTransport, error) {
	 := &DTLSTransport{
		iceTransport: ,
		api:          ,
		state:        DTLSTransportStateNew,
		dtlsMatcher:  mux.MatchDTLS,
		srtpReady:    make(chan struct{}),
		log:          .settingEngine.LoggerFactory.NewLogger("DTLSTransport"),
	}

	if len() > 0 {
		 := time.Now()
		for ,  := range  {
			if !.Expires().IsZero() && .After(.Expires()) {
				return nil, &rtcerr.InvalidAccessError{Err: ErrCertificateExpired}
			}
			.certificates = append(.certificates, )
		}
	} else {
		,  := ecdsa.GenerateKey(elliptic.P256(), rand.Reader)
		if  != nil {
			return nil, &rtcerr.UnknownError{Err: }
		}
		,  := GenerateCertificate()
		if  != nil {
			return nil, 
		}
		.certificates = []Certificate{*}
	}

	return , nil
}

// ICETransport returns the currently-configured *ICETransport or nil
// if one has not been configured.
func ( *DTLSTransport) () *ICETransport {
	.lock.RLock()
	defer .lock.RUnlock()

	return .iceTransport
}

// onStateChange requires the caller holds the lock.
func ( *DTLSTransport) ( DTLSTransportState) {
	.state = 
	 := .onStateChangeHandler
	if  != nil {
		()
	}
}

// OnStateChange sets a handler that is fired when the DTLS
// connection state changes.
func ( *DTLSTransport) ( func(DTLSTransportState)) {
	.lock.Lock()
	defer .lock.Unlock()
	.onStateChangeHandler = 
}

// State returns the current dtls transport state.
func ( *DTLSTransport) () DTLSTransportState {
	.lock.RLock()
	defer .lock.RUnlock()

	return .state
}

// WriteRTCP sends a user provided RTCP packet to the connected peer. If no peer is connected the
// packet is discarded.
func ( *DTLSTransport) ( []rtcp.Packet) (int, error) {
	,  := rtcp.Marshal()
	if  != nil {
		return 0, 
	}

	,  := .getSRTCPSession()
	if  != nil {
		return 0, 
	}

	,  := .OpenWriteStream()
	if  != nil {
		// nolint
		return 0, fmt.Errorf("%w: %v", errPeerConnWriteRTCPOpenWriteStream, )
	}

	return .Write()
}

// GetLocalParameters returns the DTLS parameters of the local DTLSTransport upon construction.
func ( *DTLSTransport) () (DTLSParameters, error) {
	 := []DTLSFingerprint{}

	for ,  := range .certificates {
		,  := .GetFingerprints()
		if  != nil {
			return DTLSParameters{}, 
		}

		 = append(, ...)
	}

	return DTLSParameters{
		Role:         DTLSRoleAuto, // always returns the default role
		Fingerprints: ,
	}, nil
}

// GetRemoteCertificate returns the certificate chain in use by the remote side
// returns an empty list prior to selection of the remote certificate.
func ( *DTLSTransport) () []byte {
	.lock.RLock()
	defer .lock.RUnlock()

	return .remoteCertificate
}

func ( *DTLSTransport) () error {
	 := &srtp.Config{
		Profile:       .srtpProtectionProfile,
		BufferFactory: .api.settingEngine.BufferFactory,
		LoggerFactory: .api.settingEngine.LoggerFactory,
	}
	if .api.settingEngine.replayProtection.SRTP != nil {
		.RemoteOptions = append(
			.RemoteOptions,
			srtp.SRTPReplayProtection(*.api.settingEngine.replayProtection.SRTP),
		)
	}

	if .api.settingEngine.disableSRTPReplayProtection {
		.RemoteOptions = append(
			.RemoteOptions,
			srtp.SRTPNoReplayProtection(),
		)
	}

	if .api.settingEngine.replayProtection.SRTCP != nil {
		.RemoteOptions = append(
			.RemoteOptions,
			srtp.SRTCPReplayProtection(*.api.settingEngine.replayProtection.SRTCP),
		)
	}

	if .api.settingEngine.disableSRTCPReplayProtection {
		.RemoteOptions = append(
			.RemoteOptions,
			srtp.SRTCPNoReplayProtection(),
		)
	}

	,  := .conn.ConnectionState()
	if ! {
		// nolint
		return fmt.Errorf("%w: Failed to get DTLS ConnectionState", errDtlsKeyExtractionFailed)
	}

	 := .ExtractSessionKeysFromDTLS(&, .role() == DTLSRoleClient)
	if  != nil {
		// nolint
		return fmt.Errorf("%w: %v", errDtlsKeyExtractionFailed, )
	}

	,  := srtp.NewSessionSRTP(.srtpEndpoint, )
	if  != nil {
		// nolint
		return fmt.Errorf("%w: %v", errFailedToStartSRTP, )
	}

	,  := srtp.NewSessionSRTCP(.srtcpEndpoint, )
	if  != nil {
		// nolint
		return fmt.Errorf("%w: %v", errFailedToStartSRTCP, )
	}

	.srtpSession.Store()
	.srtcpSession.Store()
	close(.srtpReady)

	return nil
}

func ( *DTLSTransport) () (*srtp.SessionSRTP, error) {
	if ,  := .srtpSession.Load().(*srtp.SessionSRTP);  {
		return , nil
	}

	return nil, errDtlsTransportNotStarted
}

func ( *DTLSTransport) () (*srtp.SessionSRTCP, error) {
	if ,  := .srtcpSession.Load().(*srtp.SessionSRTCP);  {
		return , nil
	}

	return nil, errDtlsTransportNotStarted
}

func ( *DTLSTransport) () DTLSRole {
	// If remote has an explicit role use the inverse
	switch .remoteParameters.Role {
	case DTLSRoleClient:
		return DTLSRoleServer
	case DTLSRoleServer:
		return DTLSRoleClient
	default:
	}

	// If SettingEngine has an explicit role
	switch .api.settingEngine.answeringDTLSRole {
	case DTLSRoleServer:
		return DTLSRoleServer
	case DTLSRoleClient:
		return DTLSRoleClient
	default:
	}

	// Remote was auto and no explicit role was configured via SettingEngine
	if .iceTransport.Role() == ICERoleControlling {
		return DTLSRoleServer
	}

	return defaultDtlsRoleAnswer
}

// Start DTLS transport negotiation with the parameters of the remote DTLS transport.
func ( *DTLSTransport) ( DTLSParameters) error { //nolint:gocognit,cyclop
	// Take lock and prepare connection, we must not hold the lock
	// when connecting
	 := func() (DTLSRole, *dtls.Config, error) {
		.lock.Lock()
		defer .lock.Unlock()

		if  := .ensureICEConn();  != nil {
			return DTLSRole(0), nil, 
		}

		if .state != DTLSTransportStateNew {
			return DTLSRole(0), nil, &rtcerr.InvalidStateError{Err: fmt.Errorf("%w: %s", errInvalidDTLSStart, .state)}
		}

		.srtpEndpoint = .iceTransport.newEndpoint(mux.MatchSRTP)
		.srtcpEndpoint = .iceTransport.newEndpoint(mux.MatchSRTCP)
		.remoteParameters = 

		 := .certificates[0]
		.onStateChange(DTLSTransportStateConnecting)

		return .role(), &dtls.Config{
			Certificates: []tls.Certificate{
				{
					Certificate: [][]byte{.x509Cert.Raw},
					PrivateKey:  .privateKey,
				},
			},
			SRTPProtectionProfiles: func() []dtls.SRTPProtectionProfile {
				if len(.api.settingEngine.srtpProtectionProfiles) > 0 {
					return .api.settingEngine.srtpProtectionProfiles
				}

				return defaultSrtpProtectionProfiles()
			}(),
			ClientAuth:         dtls.RequireAnyClientCert,
			LoggerFactory:      .api.settingEngine.LoggerFactory,
			InsecureSkipVerify: !.api.settingEngine.dtls.disableInsecureSkipVerify,
			CustomCipherSuites: .api.settingEngine.dtls.customCipherSuites,
		}, nil
	}

	var  *dtls.Conn
	 := .iceTransport.newEndpoint(mux.MatchDTLS)
	.SetOnClose(.internalOnCloseHandler)
	, ,  := ()
	if  != nil {
		return 
	}

	if .api.settingEngine.replayProtection.DTLS != nil {
		.ReplayProtectionWindow = int(*.api.settingEngine.replayProtection.DTLS) //nolint:gosec // G115
	}

	if .api.settingEngine.dtls.clientAuth != nil {
		.ClientAuth = *.api.settingEngine.dtls.clientAuth
	}

	.FlightInterval = .api.settingEngine.dtls.retransmissionInterval
	.InsecureSkipVerifyHello = .api.settingEngine.dtls.insecureSkipHelloVerify
	.EllipticCurves = .api.settingEngine.dtls.ellipticCurves
	.ExtendedMasterSecret = .api.settingEngine.dtls.extendedMasterSecret
	.ClientCAs = .api.settingEngine.dtls.clientCAs
	.RootCAs = .api.settingEngine.dtls.rootCAs
	.KeyLogWriter = .api.settingEngine.dtls.keyLogWriter
	.ClientHelloMessageHook = .api.settingEngine.dtls.clientHelloMessageHook
	.ServerHelloMessageHook = .api.settingEngine.dtls.serverHelloMessageHook
	.CertificateRequestMessageHook = .api.settingEngine.dtls.certificateRequestMessageHook

	// Connect as DTLS Client/Server, function is blocking and we
	// must not hold the DTLSTransport lock
	if  == DTLSRoleClient {
		,  = dtls.Client(, .RemoteAddr(), )
	} else {
		,  = dtls.Server(, .RemoteAddr(), )
	}

	if  == nil {
		if .api.settingEngine.dtls.connectContextMaker != nil {
			,  := .api.settingEngine.dtls.connectContextMaker()
			 = .HandshakeContext()
		} else {
			 = .Handshake()
		}
	}

	// Re-take the lock, nothing beyond here is blocking
	.lock.Lock()
	defer .lock.Unlock()

	if  != nil {
		.onStateChange(DTLSTransportStateFailed)

		return 
	}

	,  := .SelectedSRTPProtectionProfile()
	if ! {
		.onStateChange(DTLSTransportStateFailed)

		return ErrNoSRTPProtectionProfile
	}

	switch  {
	case dtls.SRTP_AEAD_AES_128_GCM:
		.srtpProtectionProfile = srtp.ProtectionProfileAeadAes128Gcm
	case dtls.SRTP_AEAD_AES_256_GCM:
		.srtpProtectionProfile = srtp.ProtectionProfileAeadAes256Gcm
	case dtls.SRTP_AES128_CM_HMAC_SHA1_80:
		.srtpProtectionProfile = srtp.ProtectionProfileAes128CmHmacSha1_80
	case dtls.SRTP_NULL_HMAC_SHA1_80:
		.srtpProtectionProfile = srtp.ProtectionProfileNullHmacSha1_80
	default:
		.onStateChange(DTLSTransportStateFailed)

		return ErrNoSRTPProtectionProfile
	}

	// Check the fingerprint if a certificate was exchanged
	,  := .ConnectionState()
	if ! {
		.onStateChange(DTLSTransportStateFailed)

		return errNoRemoteCertificate
	}

	if len(.PeerCertificates) == 0 {
		.onStateChange(DTLSTransportStateFailed)

		return errNoRemoteCertificate
	}
	.remoteCertificate = .PeerCertificates[0]

	if !.api.settingEngine.disableCertificateFingerprintVerification { //nolint:nestif
		,  := x509.ParseCertificate(.remoteCertificate)
		if  != nil {
			if  := .Close();  != nil {
				.log.Error(.Error())
			}

			.onStateChange(DTLSTransportStateFailed)

			return 
		}

		if  = .validateFingerPrint();  != nil {
			if  := .Close();  != nil {
				.log.Error(.Error())
			}

			.onStateChange(DTLSTransportStateFailed)

			return 
		}
	}

	.conn = 
	.onStateChange(DTLSTransportStateConnected)

	return .startSRTP()
}

// Stop stops and closes the DTLSTransport object.
func ( *DTLSTransport) () error {
	.lock.Lock()
	defer .lock.Unlock()

	// Try closing everything and collect the errors
	var  []error

	if ,  := .getSRTPSession();  == nil &&  != nil {
		 = append(, .Close())
	}

	if ,  := .getSRTCPSession();  == nil &&  != nil {
		 = append(, .Close())
	}

	for  := range .simulcastStreams {
		 = append(, .simulcastStreams[].srtp.Close())
		 = append(, .simulcastStreams[].srtcp.Close())
	}

	if .conn != nil {
		// dtls connection may be closed on sctp close.
		if  := .conn.Close();  != nil && !errors.Is(, dtls.ErrConnClosed) {
			 = append(, )
		}
	}
	.onStateChange(DTLSTransportStateClosed)

	return util.FlattenErrs()
}

func ( *DTLSTransport) ( *x509.Certificate) error {
	for ,  := range .remoteParameters.Fingerprints {
		,  := fingerprint.HashFromString(.Algorithm)
		if  != nil {
			return 
		}

		,  := fingerprint.Fingerprint(, )
		if  != nil {
			return 
		}

		if strings.EqualFold(, .Value) {
			return nil
		}
	}

	return errNoMatchingCertificateFingerprint
}

func ( *DTLSTransport) () error {
	if .iceTransport == nil {
		return errICEConnectionNotStarted
	}

	return nil
}

func ( *DTLSTransport) (
	 *srtp.ReadStreamSRTP,
	 *srtp.ReadStreamSRTCP,
) {
	.lock.Lock()
	defer .lock.Unlock()

	.simulcastStreams = append(.simulcastStreams, simulcastStreamPair{, })
}

func ( *DTLSTransport) (
	 SSRC,
	 interceptor.StreamInfo,
) (*srtp.ReadStreamSRTP, interceptor.RTPReader, *srtp.ReadStreamSRTCP, interceptor.RTCPReader, error) {
	,  := .getSRTPSession()
	if  != nil {
		return nil, nil, nil, nil, 
	}

	,  := .OpenReadStream(uint32())
	if  != nil {
		return nil, nil, nil, nil, 
	}

	 := .api.interceptor.BindRemoteStream(
		&,
		interceptor.RTPReaderFunc(
			func( []byte,  interceptor.Attributes) ( int,  interceptor.Attributes,  error) {
				,  = .Read()

				return , , 
			},
		),
	)

	,  := .getSRTCPSession()
	if  != nil {
		return nil, nil, nil, nil, 
	}

	,  := .OpenReadStream(uint32())
	if  != nil {
		return nil, nil, nil, nil, 
	}

	 := .api.interceptor.BindRTCPReader(interceptor.RTCPReaderFunc(
		func( []byte,  interceptor.Attributes) ( int,  interceptor.Attributes,  error) {
			,  = .Read()

			return , , 
		}),
	)

	return , , , , nil
}