// SPDX-FileCopyrightText: 2023 The Pion community <https://pion.ly>
// SPDX-License-Identifier: MIT

package dtls

import (
	"bytes"
	"crypto/tls"
	"crypto/x509"
	"fmt"
	"strings"
)

// ClientHelloInfo carries details taken from a client's ClientHello message
// so that application logic, such as a GetCertificate callback, can decide
// which certificate to present.
type ClientHelloInfo struct {
	// ServerName is the host name the client requested through the Server
	// Name Indication extension (RFC 6066, Section 3; originally RFC 4366,
	// Section 3.1), which enables virtual hosting. It is left empty when the
	// client does not use SNI.
	ServerName string

	// CipherSuites enumerates the cipher suites offered by the client, for
	// example TLS_ECDHE_ECDSA_WITH_AES_128_GCM_SHA256.
	CipherSuites []CipherSuiteID
}

// CertificateRequestInfo holds the data carried by a server's
// CertificateRequest message, through which the server asks the client for a
// certificate together with proof that the client controls it.
type CertificateRequestInfo struct {
	// AcceptableCAs lists zero or more DER-encoded X.501 Distinguished Names
	// of root or intermediate CAs that the server would like the client's
	// certificate to be signed by. An empty slice means the server states no
	// preference.
	AcceptableCAs [][]byte
}

// SupportsCertificate returns nil if the provided certificate is supported by
// the server that sent the CertificateRequest. Otherwise, it returns an error
// describing the reason for the incompatibility.
// NOTE: original src: https://github.com/golang/go/blob/29b9a328d268d53833d2cc063d1d8b4bf6852675/src/crypto/tls/common.go#L1273
func ( *CertificateRequestInfo) ( *tls.Certificate) error {
	if len(.AcceptableCAs) == 0 {
		return nil
	}

	for ,  := range .Certificate {
		 := .Leaf
		// Parse the certificate if this isn't the leaf node, or if
		// chain.Leaf was nil.
		if  != 0 ||  == nil {
			var  error
			if ,  = x509.ParseCertificate();  != nil {
				return fmt.Errorf("failed to parse certificate #%d in the chain: %w", , )
			}
		}

		for ,  := range .AcceptableCAs {
			if bytes.Equal(.RawIssuer, ) {
				return nil
			}
		}
	}
	return errNotAcceptableCertificateChain
}

func ( *handshakeConfig) () {
	 := make(map[string]*tls.Certificate)
	for  := range .localCertificates {
		 := &.localCertificates[]
		 := .Leaf
		if  == nil {
			var  error
			,  = x509.ParseCertificate(.Certificate[0])
			if  != nil {
				continue
			}
		}
		if len(.Subject.CommonName) > 0 {
			[strings.ToLower(.Subject.CommonName)] = 
		}
		for ,  := range .DNSNames {
			[strings.ToLower()] = 
		}
	}
	.nameToCertificate = 
}

func ( *handshakeConfig) ( *ClientHelloInfo) (*tls.Certificate, error) {
	.mu.Lock()
	defer .mu.Unlock()

	if .localGetCertificate != nil &&
		(len(.localCertificates) == 0 || len(.ServerName) > 0) {
		,  := .localGetCertificate()
		if  != nil ||  != nil {
			return , 
		}
	}

	if .nameToCertificate == nil {
		.setNameToCertificateLocked()
	}

	if len(.localCertificates) == 0 {
		return nil, errNoCertificates
	}

	if len(.localCertificates) == 1 {
		// There's only one choice, so no point doing any work.
		return &.localCertificates[0], nil
	}

	if len(.ServerName) == 0 {
		return &.localCertificates[0], nil
	}

	 := strings.TrimRight(strings.ToLower(.ServerName), ".")

	if ,  := .nameToCertificate[];  {
		return , nil
	}

	// try replacing labels in the name with wildcards until we get a
	// match.
	 := strings.Split(, ".")
	for  := range  {
		[] = "*"
		 := strings.Join(, ".")
		if ,  := .nameToCertificate[];  {
			return , nil
		}
	}

	// If nothing matches, return the first certificate.
	return &.localCertificates[0], nil
}

// NOTE: original src: https://github.com/golang/go/blob/29b9a328d268d53833d2cc063d1d8b4bf6852675/src/crypto/tls/handshake_client.go#L974
func ( *handshakeConfig) ( *CertificateRequestInfo) (*tls.Certificate, error) {
	.mu.Lock()
	defer .mu.Unlock()
	if .localGetClientCertificate != nil {
		return .localGetClientCertificate()
	}

	for  := range .localCertificates {
		 := .localCertificates[]
		if  := .SupportsCertificate(&);  != nil {
			continue
		}
		return &, nil
	}

	// No acceptable certificate found. Don't send a certificate.
	return new(tls.Certificate), nil
}