package dtls
import (
"bytes"
"crypto/tls"
"crypto/x509"
"fmt"
"strings"
"github.com/pion/dtls/v3/pkg/protocol/handshake"
)
type (
	// ClientHelloInfo carries the ClientHello values used when choosing a
	// local certificate. getCertificate passes it to a user-supplied
	// certificate callback and uses its ServerName for name-based lookup.
	ClientHelloInfo struct {
		// ServerName is the name requested by the peer. When it is empty,
		// name-based selection is skipped and the first configured
		// certificate is used.
		ServerName string
		// CipherSuites holds cipher suite IDs, presumably those offered in
		// the ClientHello (TODO confirm against the parser that fills it).
		CipherSuites []CipherSuiteID
		// RandomBytes is a fixed-size buffer of handshake.RandomBytesLength
		// bytes, presumably the hello's random value.
		RandomBytes [handshake.RandomBytesLength]byte
	}

	// CertificateRequestInfo carries the certificate-request values used to
	// decide which client certificate chain to send.
	CertificateRequestInfo struct {
		// AcceptableCAs lists raw issuer names. They are compared
		// byte-for-byte against x509.Certificate.RawIssuer. An empty list
		// means any certificate is acceptable.
		AcceptableCAs [][]byte
	}
)
// SupportsCertificate reports whether the chain in c satisfies this request.
//
// It returns nil when AcceptableCAs is empty (no preference expressed). It
// also returns nil when any certificate in the chain has a raw issuer equal,
// byte-for-byte, to one of AcceptableCAs. Entry 0 reuses c.Leaf when that is
// set; every other entry is decoded from DER. A decode failure hit before a
// match is returned as a wrapped error. If nothing matches,
// errNotAcceptableCertificateChain is returned.
func (cri *CertificateRequestInfo) SupportsCertificate(c *tls.Certificate) error {
	if len(cri.AcceptableCAs) == 0 {
		// No CA preference, so every chain qualifies.
		return nil
	}

	for idx, der := range c.Certificate {
		var parsed *x509.Certificate
		if idx == 0 && c.Leaf != nil {
			// Reuse the pre-parsed leaf instead of decoding it again.
			parsed = c.Leaf
		} else {
			p, err := x509.ParseCertificate(der)
			if err != nil {
				return fmt.Errorf("failed to parse certificate #%d in the chain: %w", idx, err)
			}
			parsed = p
		}

		issuer := parsed.RawIssuer
		for _, ca := range cri.AcceptableCAs {
			if bytes.Equal(issuer, ca) {
				return nil
			}
		}
	}

	return errNotAcceptableCertificateChain
}
// setNameToCertificateLocked rebuilds c.nameToCertificate, the name index
// that getCertificate uses to pick a local certificate by server name.
//
// Each configured certificate is indexed under its lower-cased Subject
// CommonName (when non-empty) and under each of its lower-cased DNS names.
// When two certificates share a name, the one later in c.localCertificates
// wins.
//
// The leaf is taken from Leaf when set. Otherwise the first DER entry is
// parsed. Certificates with no Leaf and an empty chain, or with unparseable
// DER, are skipped rather than aborting the rebuild.
//
// The method does not lock. Callers are expected to hold c.mu, as
// getCertificate does.
func (c *handshakeConfig) setNameToCertificateLocked() {
	// Size hint only: a certificate may contribute several names.
	nameToCertificate := make(map[string]*tls.Certificate, len(c.localCertificates))
	for i := range c.localCertificates {
		cert := &c.localCertificates[i]
		x509Cert := cert.Leaf
		if x509Cert == nil {
			// With no leaf and no DER entries there is nothing to index, and
			// cert.Certificate[0] would panic.
			if len(cert.Certificate) == 0 {
				continue
			}
			parsed, err := x509.ParseCertificate(cert.Certificate[0])
			if err != nil {
				// Deliberately best-effort: the certificate just cannot be
				// matched by name. It may still be returned as getCertificate's
				// fallback.
				continue
			}
			x509Cert = parsed
		}

		if x509Cert.Subject.CommonName != "" {
			nameToCertificate[strings.ToLower(x509Cert.Subject.CommonName)] = cert
		}
		for _, san := range x509Cert.DNSNames {
			nameToCertificate[strings.ToLower(san)] = cert
		}
	}
	c.nameToCertificate = nameToCertificate
}
// getCertificate chooses the local certificate to present for a ClientHello.
// The whole selection runs with c.mu held, including any user callback.
//
// Selection order:
//  1. If localGetCertificate is set and either no static certificates are
//     configured or the peer sent a server name, the callback is consulted.
//     A non-nil certificate or a non-nil error from it is returned as-is.
//  2. The name index is built on first use and cached in c.nameToCertificate.
//  3. With no static certificates, errNoCertificates is returned. With
//     exactly one, or when the peer sent no server name, the first static
//     certificate is returned.
//  4. The server name is lower-cased and stripped of trailing dots, then
//     looked up exactly. After that, labels are replaced by "*" cumulatively
//     from the left and each form is looked up ("a.b.c" tries "*.b.c",
//     "*.*.c", "*.*.*").
//  5. If nothing matches, the first static certificate is returned.
func (c *handshakeConfig) getCertificate(clientHelloInfo *ClientHelloInfo) (*tls.Certificate, error) {
	c.mu.Lock()
	defer c.mu.Unlock()

	if c.localGetCertificate != nil &&
		(len(c.localCertificates) == 0 || len(clientHelloInfo.ServerName) > 0) {
		if picked, cbErr := c.localGetCertificate(clientHelloInfo); picked != nil || cbErr != nil {
			return picked, cbErr
		}
	}

	if c.nameToCertificate == nil {
		c.setNameToCertificateLocked()
	}

	switch {
	case len(c.localCertificates) == 0:
		return nil, errNoCertificates
	case len(c.localCertificates) == 1, clientHelloInfo.ServerName == "":
		// Either nothing to choose between, or no name to choose by.
		return &c.localCertificates[0], nil
	}

	host := strings.TrimRight(strings.ToLower(clientHelloInfo.ServerName), ".")
	if exact, found := c.nameToCertificate[host]; found {
		return exact, nil
	}

	// Widen the match one label at a time by turning labels into wildcards.
	parts := strings.Split(host, ".")
	for idx := 0; idx < len(parts); idx++ {
		parts[idx] = "*"
		if wild, found := c.nameToCertificate[strings.Join(parts, ".")]; found {
			return wild, nil
		}
	}

	// No name matched; fall back to the first configured certificate.
	return &c.localCertificates[0], nil
}
// getClientCertificate chooses the certificate to send in answer to a
// certificate request. It holds c.mu for the whole call, including while a
// user callback runs.
//
// A configured localGetClientCertificate callback takes full precedence, and
// its result is returned unchanged. Otherwise the static certificates are
// tried in order. The result is a pointer to a copy of the first one that
// cri.SupportsCertificate accepts. Chains it rejects, including ones that fail
// to parse, are skipped.
//
// When nothing qualifies, an empty *tls.Certificate is returned with a nil
// error. Presumably the caller treats that as "send no certificate"; confirm
// against the caller.
func (c *handshakeConfig) getClientCertificate(cri *CertificateRequestInfo) (*tls.Certificate, error) {
	c.mu.Lock()
	defer c.mu.Unlock()

	if callback := c.localGetClientCertificate; callback != nil {
		return callback(cri)
	}

	// candidate is a copy of each element, so the returned pointer does not
	// alias c.localCertificates.
	for _, candidate := range c.localCertificates {
		if cri.SupportsCertificate(&candidate) == nil {
			return &candidate, nil
		}
	}

	// Nothing acceptable: return an empty certificate rather than an error.
	return &tls.Certificate{}, nil
}
These pages were generated with Golds v0.8.2 (GOOS=linux GOARCH=amd64).
Golds is a Go 101 project developed by Tapir Liu.
Pull requests and bug reports are welcome and can be submitted to the issue list.
Please follow @zigo_101 (reachable via the QR code on the left) for the latest Golds news.