package dtls
import (
"bytes"
"context"
"errors"
"fmt"
"io"
"net"
"sync"
"sync/atomic"
"time"
"github.com/pion/dtls/v3/internal/closer"
"github.com/pion/dtls/v3/pkg/crypto/elliptic"
"github.com/pion/dtls/v3/pkg/crypto/signaturehash"
"github.com/pion/dtls/v3/pkg/protocol"
"github.com/pion/dtls/v3/pkg/protocol/alert"
"github.com/pion/dtls/v3/pkg/protocol/handshake"
"github.com/pion/dtls/v3/pkg/protocol/recordlayer"
"github.com/pion/logging"
"github.com/pion/transport/v3/deadline"
"github.com/pion/transport/v3/netctx"
"github.com/pion/transport/v3/replaydetector"
)
// Retransmission and handshake sizing.
const (
	// initialTickerInterval is the default initial retransmit interval used
	// when Config.FlightInterval is zero.
	initialTickerInterval = time.Second
	// cookieLength and sessionLength are byte lengths for generated values
	// (presumably the HelloVerifyRequest cookie and session ID — confirm at
	// their use sites, which are outside this file).
	cookieLength  = 20
	sessionLength = 32
)

// defaultNamedCurve is the default elliptic curve for key exchange.
const defaultNamedCurve = elliptic.X25519

// inboundBufferSize is the size in bytes of each pooled datagram read buffer.
const inboundBufferSize = 8192

// Record-level protection limits.
const (
	// defaultReplayProtectionWindow is used when Config.ReplayProtectionWindow
	// is not positive; RFC 6347 Section 4.1.2.6 recommends 64 as the default.
	defaultReplayProtectionWindow = 64
	// maxAppDataPacketQueueSize caps how many not-yet-decryptable records are
	// queued while waiting for the next epoch or the cipher suite.
	maxAppDataPacketQueueSize = 100
)
// invalidKeyingLabels returns a fresh set of keying-material exporter labels
// that callers must not be allowed to use. These labels are reserved because
// the protocol itself uses them internally.
func invalidKeyingLabels() map[string]bool {
	reserved := []string{
		"client finished",
		"server finished",
		"master secret",
		"key expansion",
	}
	labels := make(map[string]bool, len(reserved))
	for _, label := range reserved {
		labels[label] = true
	}
	return labels
}
// addrPkt is a raw record paired with the address it was received from. It is
// used to hold records that cannot be processed yet (see encryptedPackets on
// Conn) until they can be replayed.
//
// NOTE: built with a positional literal (addrPkt{rAddr, buf}), so field order
// is significant.
type addrPkt struct {
	rAddr net.Addr // sender of the record
	data  []byte   // raw, still-encrypted record bytes
}
// recvHandshakeState is handed from the read loop to the handshake state
// machine over Conn.handshakeRecv whenever a received datagram contained
// handshake records.
type recvHandshakeState struct {
	// done is waited on by the read loop after the send; the receiver is
	// expected to close it once it has processed the notification.
	done chan struct{}
	// isRetransmit is true if any handshake record in the datagram was
	// flagged as a retransmission by the fragment buffer.
	isRetransmit bool
}
// Conn represents a DTLS connection layered over a net.PacketConn.
type Conn struct {
	// lock guards packet writing (writePackets), updates of rAddr, and
	// snapshots taken by ConnectionState.
	lock sync.RWMutex
	// nextConn is the underlying packet connection, wrapped to support
	// context-based reads and writes.
	nextConn netctx.PacketConn
	// fragmentBuffer reassembles incoming handshake fragments.
	fragmentBuffer *fragmentBuffer
	// handshakeCache records sent and received handshake messages.
	handshakeCache *handshakeCache
	// decrypted carries application data ([]byte) or errors from the read
	// loop to Read. It is closed when the read loop exits after a successful
	// handshake.
	decrypted chan interface{}
	// rAddr is the peer address; it may be updated when a newer
	// connection-ID record arrives from a different address.
	rAddr net.Addr
	// state holds the negotiated session state (epochs, sequence numbers,
	// cipher suite, connection IDs, ...).
	state State
	// maximumTransmissionUnit bounds handshake fragment size and datagram
	// packing.
	maximumTransmissionUnit int
	// paddingLengthGenerator returns the number of zero padding bytes for a
	// connection-ID inner plaintext of the given content length.
	paddingLengthGenerator func(uint) uint
	// handshakeCompletedSuccessfully stores a struct{ bool } flag.
	handshakeCompletedSuccessfully atomic.Value
	// handshakeMutex serializes HandshakeContext calls.
	handshakeMutex sync.Mutex
	// handshakeDone is closed when a HandshakeContext call returns; Close
	// waits on it. Guarded by closeLock.
	handshakeDone chan struct{}
	// encryptedPackets queues records that arrived before they could be
	// processed (capped at maxAppDataPacketQueueSize).
	encryptedPackets []addrPkt
	// connectionClosedByUser is set by the first user Close. Guarded by
	// closeLock.
	connectionClosedByUser bool
	closeLock              sync.Mutex
	// closed is signalled once the connection is closed.
	closed        *closer.Closer
	readDeadline  *deadline.Deadline
	writeDeadline *deadline.Deadline
	log           logging.LeveledLogger
	// reading is created with capacity 1; it is not referenced elsewhere in
	// this view.
	reading chan struct{}
	// handshakeRecv is the unbuffered channel the read loop uses to notify
	// the handshake state machine about received handshake records.
	handshakeRecv chan recvHandshakeState
	// cancelHandshaker and cancelHandshakeReader stop the handshake state
	// machine and the read loop; no-ops until handshake() installs them.
	// Guarded by closeLock.
	cancelHandshaker      func()
	cancelHandshakeReader func()
	fsm                   *handshakeFSM
	// replayProtectionWindow is the per-epoch replay detector window size.
	replayProtectionWindow uint
	handshakeConfig        *handshakeConfig
}
// createConn validates config, fills in defaults for unset options and builds
// a Conn around nextConn. It does not start the handshake. The epochs start
// at 0. resumeState, when non-nil, is stored for use by HandshakeContext.
func createConn(
	nextConn net.PacketConn,
	rAddr net.Addr,
	config *Config,
	isClient bool,
	resumeState *State,
) (*Conn, error) {
	if err := validateConfig(config); err != nil {
		return nil, err
	}
	if nextConn == nil {
		return nil, errNilNextConn
	}

	factory := config.LoggerFactory
	if factory == nil {
		factory = logging.NewDefaultLoggerFactory()
	}
	logger := factory.NewLogger("dtls")

	// Non-positive values mean "use the default".
	mtu, window := config.MTU, config.ReplayProtectionWindow
	if mtu <= 0 {
		mtu = defaultMTU
	}
	if window <= 0 {
		window = defaultReplayProtectionWindow
	}

	padding := config.PaddingLengthGenerator
	if padding == nil {
		padding = func(uint) uint { return 0 } // no padding
	}

	suites, err := parseCipherSuites(
		config.CipherSuites,
		config.CustomCipherSuites,
		config.includeCertificateSuites(),
		config.PSK != nil,
	)
	if err != nil {
		return nil, err
	}
	schemes, err := signaturehash.ParseSignatureSchemes(config.SignatureSchemes, config.InsecureHashes)
	if err != nil {
		return nil, err
	}

	retransmitInterval := initialTickerInterval
	if config.FlightInterval != 0 {
		retransmitInterval = config.FlightInterval
	}

	// A literal IP address is not a valid SNI host name (RFC 6066 Section 3).
	sni := config.ServerName
	if ip := net.ParseIP(sni); ip != nil {
		sni = ""
	}

	curves := config.EllipticCurves
	if len(curves) == 0 {
		curves = defaultCurves
	}

	hsCfg := &handshakeConfig{
		localPSKCallback:              config.PSK,
		localPSKIdentityHint:          config.PSKIdentityHint,
		localCipherSuites:             suites,
		localSignatureSchemes:         schemes,
		extendedMasterSecret:          config.ExtendedMasterSecret,
		localSRTPProtectionProfiles:   config.SRTPProtectionProfiles,
		localSRTPMasterKeyIdentifier:  config.SRTPMasterKeyIdentifier,
		serverName:                    sni,
		supportedProtocols:            config.SupportedProtocols,
		clientAuth:                    config.ClientAuth,
		localCertificates:             config.Certificates,
		insecureSkipVerify:            config.InsecureSkipVerify,
		verifyPeerCertificate:         config.VerifyPeerCertificate,
		verifyConnection:              config.VerifyConnection,
		rootCAs:                       config.RootCAs,
		clientCAs:                     config.ClientCAs,
		customCipherSuites:            config.CustomCipherSuites,
		initialRetransmitInterval:     retransmitInterval,
		disableRetransmitBackoff:      config.DisableRetransmitBackoff,
		log:                           logger,
		initialEpoch:                  0,
		keyLogWriter:                  config.KeyLogWriter,
		sessionStore:                  config.SessionStore,
		ellipticCurves:                curves,
		localGetCertificate:           config.GetCertificate,
		localGetClientCertificate:     config.GetClientCertificate,
		insecureSkipHelloVerify:       config.InsecureSkipVerifyHello,
		connectionIDGenerator:         config.ConnectionIDGenerator,
		helloRandomBytesGenerator:     config.HelloRandomBytesGenerator,
		clientHelloMessageHook:        config.ClientHelloMessageHook,
		serverHelloMessageHook:        config.ServerHelloMessageHook,
		certificateRequestMessageHook: config.CertificateRequestMessageHook,
		resumeState:                   resumeState,
	}

	c := &Conn{
		rAddr:                  rAddr,
		nextConn:               netctx.NewPacketConn(nextConn),
		handshakeConfig:        hsCfg,
		fragmentBuffer:         newFragmentBuffer(),
		handshakeCache:         newHandshakeCache(),
		maximumTransmissionUnit: mtu,
		paddingLengthGenerator: padding,
		decrypted:              make(chan interface{}, 1),
		log:                    logger,
		readDeadline:           deadline.New(),
		writeDeadline:          deadline.New(),
		reading:                make(chan struct{}, 1),
		handshakeRecv:          make(chan recvHandshakeState),
		closed:                 closer.NewCloser(),
		cancelHandshaker:       func() {},
		cancelHandshakeReader:  func() {},
		replayProtectionWindow: uint(window),
		state: State{
			isClient: isClient,
		},
	}
	c.setRemoteEpoch(0)
	c.setLocalEpoch(0)
	return c, nil
}
// Handshake runs the DTLS handshake with no cancellation. It is equivalent to
// HandshakeContext(context.Background()) and returns nil immediately if the
// handshake has already completed.
func (c *Conn) Handshake() error {
	background := context.Background()
	return c.HandshakeContext(background)
}
// HandshakeContext runs the DTLS handshake unless it already completed.
// Concurrent calls are serialized.
//
// On the server side, the local cipher suites are first narrowed to those
// compatible with the configured certificate. A missing certificate is
// tolerated. With a resume state, the handshake starts from a finished state
// (flight5 for clients, flight6 for servers) after adopting that state.
// Otherwise it starts fresh (flight1 for clients, flight0 for servers).
func (c *Conn) HandshakeContext(ctx context.Context) error {
	c.handshakeMutex.Lock()
	defer c.handshakeMutex.Unlock()

	if c.isHandshakeCompletedSuccessfully() {
		return nil
	}

	// Close waits on this channel, so it must be published before any work
	// starts and closed when this call returns.
	finished := make(chan struct{})
	defer close(finished)
	c.closeLock.Lock()
	c.handshakeDone = finished
	c.closeLock.Unlock()

	if !c.state.isClient {
		cert, err := c.handshakeConfig.getCertificate(&ClientHelloInfo{})
		if err != nil && !errors.Is(err, errNoCertificates) {
			return err
		}
		c.handshakeConfig.localCipherSuites = filterCipherSuitesForCertificate(cert, c.handshakeConfig.localCipherSuites)
	}

	resuming := c.handshakeConfig.resumeState != nil

	// The role is read from the current state before it may be replaced by
	// the resume state below.
	var flight flightVal
	switch {
	case resuming && c.state.isClient:
		flight = flight5
	case resuming:
		flight = flight6
	case c.state.isClient:
		flight = flight1
	default:
		flight = flight0
	}

	var fsmState handshakeState
	if resuming {
		fsmState = handshakeFinished
		c.state = *c.handshakeConfig.resumeState
	} else {
		fsmState = handshakePreparing
	}

	if err := c.handshake(ctx, c.handshakeConfig, flight, fsmState); err != nil {
		return err
	}
	c.log.Trace("Handshake Completed")
	return nil
}
// Dial opens a new UDP socket on an ephemeral local address and establishes a
// DTLS client on it towards rAddr. The handshake is not run here; it happens on
// the first Read/Write or an explicit Handshake call.
//
// Dial owns the socket it creates. If the client cannot be constructed (for
// example a nil config or an invalid configuration), the socket is closed
// before the error is returned so no file descriptor is leaked.
func Dial(network string, rAddr *net.UDPAddr, config *Config) (*Conn, error) {
	pConn, err := net.ListenUDP(network, nil)
	if err != nil {
		return nil, err
	}
	conn, err := Client(pConn, rAddr, config)
	if err != nil {
		// Best-effort cleanup; the Client error is the one worth reporting.
		_ = pConn.Close()
		return nil, err
	}
	return conn, nil
}
// Client wraps conn in a client-side DTLS Conn talking to rAddr. A config is
// required, and when a PSK callback is configured, a PSK identity hint must
// also be set.
func Client(conn net.PacketConn, rAddr net.Addr, config *Config) (*Conn, error) {
	if config == nil {
		return nil, errNoConfigProvided
	}
	if config.PSK != nil && config.PSKIdentityHint == nil {
		return nil, errPSKAndIdentityMustBeSetForClient
	}
	return createConn(conn, rAddr, config, true, nil)
}
// Server wraps conn in a server-side DTLS Conn for the peer at rAddr. If
// config.OnConnectionAttempt is set, it is consulted first, and a non-nil
// error rejects the connection.
func Server(conn net.PacketConn, rAddr net.Addr, config *Config) (*Conn, error) {
	switch {
	case config == nil:
		return nil, errNoConfigProvided
	case config.OnConnectionAttempt != nil:
		if err := config.OnConnectionAttempt(rAddr); err != nil {
			return nil, err
		}
	}
	return createConn(conn, rAddr, config, false, nil)
}
// Read blocks until one decrypted application-data record is available and
// copies it into buff. It runs the handshake first if needed.
//
// It returns errDeadlineExceeded when the read deadline fires,
// errBufferTooSmall when buff cannot hold the whole record, io.EOF once the
// read loop has shut down, or an error forwarded by the read loop.
func (c *Conn) Read(buff []byte) (n int, err error) {
	if hsErr := c.Handshake(); hsErr != nil {
		return 0, hsErr
	}

	// Report an already-expired deadline even if data is queued.
	select {
	case <-c.readDeadline.Done():
		return 0, errDeadlineExceeded
	default:
	}

	for {
		var item interface{}
		var open bool
		select {
		case <-c.readDeadline.Done():
			return 0, errDeadlineExceeded
		case item, open = <-c.decrypted:
		}

		if !open {
			return 0, io.EOF
		}
		if payload, isData := item.([]byte); isData {
			if len(buff) < len(payload) {
				return 0, errBufferTooSmall
			}
			return copy(buff, payload), nil
		}
		if readErr, isErr := item.(error); isErr {
			return 0, readErr
		}
		// Anything else is ignored and we keep waiting.
	}
}
// Write sends payload as a single encrypted application-data record in the
// current local epoch, running the handshake first if needed. It fails fast
// when the connection is closed or the write deadline has already passed.
// On success it reports len(payload) bytes written.
func (c *Conn) Write(payload []byte) (int, error) {
	if c.isConnectionClosed() {
		return 0, ErrConnClosed
	}
	select {
	case <-c.writeDeadline.Done():
		return 0, errDeadlineExceeded
	default:
	}
	if err := c.Handshake(); err != nil {
		return 0, err
	}

	record := &recordlayer.RecordLayer{
		Header: recordlayer.Header{
			Epoch:   c.state.getLocalEpoch(),
			Version: protocol.Version1_2,
		},
		Content: &protocol.ApplicationData{Data: payload},
	}
	pkt := &packet{
		record:        record,
		shouldWrapCID: len(c.state.remoteConnectionID) > 0,
		shouldEncrypt: true,
	}
	// The write deadline doubles as the context bounding the send.
	writeErr := c.writePackets(c.writeDeadline, []*packet{pkt})
	return len(payload), writeErr
}
// Close closes the connection (sending close_notify if the handshake had
// completed). It then waits for any in-progress HandshakeContext call to
// return before reporting the result of the close.
func (c *Conn) Close() error {
	closeErr := c.close(true)

	c.closeLock.Lock()
	pending := c.handshakeDone
	c.closeLock.Unlock()

	if pending != nil {
		<-pending
	}
	return closeErr
}
// ConnectionState returns a copy of the connection state taken under the read
// lock. The boolean is false, with a zero State, if cloning failed.
func (c *Conn) ConnectionState() (State, bool) {
	c.lock.RLock()
	defer c.lock.RUnlock()

	if snapshot, err := c.state.clone(); err == nil {
		return *snapshot, true
	}
	return State{}, false
}
// SelectedSRTPProtectionProfile returns the negotiated SRTP protection
// profile. The boolean is false when none was negotiated (profile value 0).
func (c *Conn) SelectedSRTPProtectionProfile() (SRTPProtectionProfile, bool) {
	if selected := c.state.getSRTPProtectionProfile(); selected != 0 {
		return selected, true
	}
	return 0, false
}
// RemoteSRTPMasterKeyIdentifier returns the SRTP master key identifier sent by
// the peer. The boolean is false when no SRTP protection profile was
// negotiated.
func (c *Conn) RemoteSRTPMasterKeyIdentifier() ([]byte, bool) {
	negotiated := c.state.getSRTPProtectionProfile() != 0
	if !negotiated {
		return nil, false
	}
	return c.state.remoteSRTPMasterKeyIdentifier, true
}
// writePackets serializes pkts into wire records and sends them to the current
// remote address, holding c.lock for the whole operation.
//
// Handshake messages are recorded in the handshake cache (without the record
// header) and then fragmented. Other records are serialized one-to-one. The
// resulting records are packed into as few datagrams as the MTU allows. Send
// errors are wrapped with netError.
func (c *Conn) writePackets(ctx context.Context, pkts []*packet) error {
	c.lock.Lock()
	defer c.lock.Unlock()

	var wire [][]byte
	for _, p := range pkts {
		hs, isHandshake := p.record.Content.(*handshake.Handshake)
		if !isHandshake {
			raw, err := c.processPacket(p)
			if err != nil {
				return err
			}
			wire = append(wire, raw)
			continue
		}

		marshaled, err := p.record.Marshal()
		if err != nil {
			return err
		}
		c.log.Tracef("[handshake:%v] -> %s (epoch: %d, seq: %d)",
			srvCliStr(c.state.isClient), hs.Header.Type.String(),
			p.record.Header.Epoch, hs.Header.MessageSequence)
		// Cache the whole, unfragmented message (record header stripped).
		c.handshakeCache.push(
			marshaled[recordlayer.FixedHeaderSize:],
			p.record.Header.Epoch,
			hs.Header.MessageSequence,
			hs.Header.Type,
			c.state.isClient,
		)

		fragments, err := c.processHandshakePacket(p, hs)
		if err != nil {
			return err
		}
		wire = append(wire, fragments...)
	}

	if len(wire) == 0 {
		return nil
	}
	for _, datagram := range c.compactRawPackets(wire) {
		if _, err := c.nextConn.WriteToContext(ctx, datagram, c.rAddr); err != nil {
			return netError(err)
		}
	}
	return nil
}
// compactRawPackets concatenates consecutive records into datagrams, starting
// a new datagram whenever appending the next record would bring the current
// one to maumTransmissionUnit bytes or more. A single input is returned
// unchanged. A record larger than the MTU still gets a datagram of its own.
func (c *Conn) compactRawPackets(rawPackets [][]byte) [][]byte {
	if len(rawPackets) == 1 {
		return rawPackets // nothing to merge; avoid copying
	}

	var datagrams [][]byte
	current := make([]byte, 0)
	for _, record := range rawPackets {
		full := len(current)+len(record) >= c.maximumTransmissionUnit
		if full && len(current) > 0 {
			datagrams = append(datagrams, current)
			current = []byte{}
		}
		current = append(current, record...)
	}
	return append(datagrams, current)
}
// processPacket serializes one non-handshake record into its wire form.
//
// It allocates the next sequence number for the record's epoch, growing the
// per-epoch counter slice on demand. When pkt.shouldWrapCID is set, it wraps
// the record in a connection-ID header around an inner plaintext. That inner
// plaintext is padded with the configured paddingLengthGenerator, the same
// way processHandshakePacket pads handshake fragments. It encrypts the result
// when pkt.shouldEncrypt is set. pkt.record.Header is updated to the header
// that is actually emitted, so Encrypt sees the on-wire header.
//
// It returns errSequenceNumberOverflow once the epoch's sequence space is
// exhausted, because sequence numbers must never wrap.
func (c *Conn) processPacket(pkt *packet) ([]byte, error) {
	epoch := pkt.record.Header.Epoch
	for len(c.state.localSequenceNumber) <= int(epoch) {
		c.state.localSequenceNumber = append(c.state.localSequenceNumber, uint64(0))
	}
	seq := atomic.AddUint64(&c.state.localSequenceNumber[epoch], 1) - 1
	if seq > recordlayer.MaxSequenceNumber {
		return nil, errSequenceNumberOverflow
	}
	pkt.record.Header.SequenceNumber = seq

	var rawPacket []byte
	if pkt.shouldWrapCID {
		// Marshal is called for its side effect. Header.ContentType is read
		// below, but callers such as Write do not set it.
		if _, err := pkt.record.Marshal(); err != nil {
			return nil, err
		}
		content, err := pkt.record.Content.Marshal()
		if err != nil {
			return nil, err
		}
		inner := &recordlayer.InnerPlaintext{
			Content:  content,
			RealType: pkt.record.Header.ContentType,
			// Previously omitted here, which left application data and
			// alerts unpadded regardless of Config.PaddingLengthGenerator.
			Zeros: c.paddingLengthGenerator(uint(len(content))),
		}
		rawInner, err := inner.Marshal()
		if err != nil {
			return nil, err
		}
		cidHeader := &recordlayer.Header{
			Version:        pkt.record.Header.Version,
			ContentType:    protocol.ContentTypeConnectionID,
			Epoch:          pkt.record.Header.Epoch,
			ContentLen:     uint16(len(rawInner)),
			ConnectionID:   c.state.remoteConnectionID,
			SequenceNumber: pkt.record.Header.SequenceNumber,
		}
		rawPacket, err = cidHeader.Marshal()
		if err != nil {
			return nil, err
		}
		pkt.record.Header = *cidHeader
		rawPacket = append(rawPacket, rawInner...)
	} else {
		var err error
		rawPacket, err = pkt.record.Marshal()
		if err != nil {
			return nil, err
		}
	}

	if pkt.shouldEncrypt {
		var err error
		rawPacket, err = c.state.cipherSuite.Encrypt(pkt.record, rawPacket)
		if err != nil {
			return nil, err
		}
	}
	return rawPacket, nil
}
// processHandshakePacket fragments a handshake message to fit the MTU and
// serializes each fragment as its own record.
//
// Every fragment gets a fresh sequence number from the record's epoch. This
// applies both to the plain header and to the connection-ID header used when
// pkt.shouldWrapCID is set. Fragments are encrypted when pkt.shouldEncrypt is
// set, and pkt.record.Header is updated to each fragment's emitted header
// before encryption.
//
// It returns errSequenceNumberOverflow once the epoch's sequence space is
// exhausted.
func (c *Conn) processHandshakePacket(pkt *packet, dtlsHandshake *handshake.Handshake) ([][]byte, error) {
	handshakeFragments, err := c.fragmentHandshake(dtlsHandshake)
	if err != nil {
		return nil, err
	}
	rawPackets := make([][]byte, 0, len(handshakeFragments))

	epoch := pkt.record.Header.Epoch
	for len(c.state.localSequenceNumber) <= int(epoch) {
		c.state.localSequenceNumber = append(c.state.localSequenceNumber, uint64(0))
	}

	for _, handshakeFragment := range handshakeFragments {
		seq := atomic.AddUint64(&c.state.localSequenceNumber[epoch], 1) - 1
		if seq > recordlayer.MaxSequenceNumber {
			return nil, errSequenceNumberOverflow
		}

		var rawPacket []byte
		if pkt.shouldWrapCID {
			inner := &recordlayer.InnerPlaintext{
				Content:  handshakeFragment,
				RealType: protocol.ContentTypeHandshake,
				Zeros:    c.paddingLengthGenerator(uint(len(handshakeFragment))),
			}
			rawInner, err := inner.Marshal()
			if err != nil {
				return nil, err
			}
			cidHeader := &recordlayer.Header{
				Version:      pkt.record.Header.Version,
				ContentType:  protocol.ContentTypeConnectionID,
				Epoch:        pkt.record.Header.Epoch,
				ContentLen:   uint16(len(rawInner)),
				ConnectionID: c.state.remoteConnectionID,
				// Use the number just allocated for this fragment.
				// pkt.record.Header.SequenceNumber is not assigned per
				// fragment, so reusing it repeated sequence numbers across
				// records. The peer's replay detector drops such records,
				// and a cipher that derives its per-record nonce from the
				// header would reuse that nonce.
				SequenceNumber: seq,
			}
			rawPacket, err = cidHeader.Marshal()
			if err != nil {
				return nil, err
			}
			pkt.record.Header = *cidHeader
			rawPacket = append(rawPacket, rawInner...)
		} else {
			recordlayerHeader := &recordlayer.Header{
				Version:        pkt.record.Header.Version,
				ContentType:    pkt.record.Header.ContentType,
				ContentLen:     uint16(len(handshakeFragment)),
				Epoch:          pkt.record.Header.Epoch,
				SequenceNumber: seq,
			}
			rawPacket, err = recordlayerHeader.Marshal()
			if err != nil {
				return nil, err
			}
			pkt.record.Header = *recordlayerHeader
			rawPacket = append(rawPacket, handshakeFragment...)
		}

		if pkt.shouldEncrypt {
			var err error
			rawPacket, err = c.state.cipherSuite.Encrypt(pkt.record, rawPacket)
			if err != nil {
				return nil, err
			}
		}
		rawPackets = append(rawPackets, rawPacket)
	}
	return rawPackets, nil
}
// fragmentHandshake marshals the handshake message body, splits it into
// chunks of at most maximumTransmissionUnit bytes, and prefixes each chunk with
// a handshake header carrying the message's type, total length, sequence,
// and the chunk's offset and length. An empty body still produces a single
// zero-length fragment.
func (c *Conn) fragmentHandshake(dtlsHandshake *handshake.Handshake) ([][]byte, error) {
	body, err := dtlsHandshake.Message.Marshal()
	if err != nil {
		return nil, err
	}

	chunks := splitBytes(body, c.maximumTransmissionUnit)
	if len(chunks) == 0 {
		chunks = [][]byte{{}}
	}

	fragments := make([][]byte, 0, len(chunks))
	offset := 0
	for _, chunk := range chunks {
		hdr := &handshake.Header{
			Type:            dtlsHandshake.Header.Type,
			Length:          dtlsHandshake.Header.Length,
			MessageSequence: dtlsHandshake.Header.MessageSequence,
			FragmentOffset:  uint32(offset),
			FragmentLength:  uint32(len(chunk)),
		}
		offset += len(chunk)

		encoded, err := hdr.Marshal()
		if err != nil {
			return nil, err
		}
		fragments = append(fragments, append(encoded, chunk...))
	}
	return fragments, nil
}
// poolReadBuffer recycles inboundBufferSize-byte datagram buffers between
// reads. It stores *[]byte so that Put does not allocate.
var poolReadBuffer = sync.Pool{
	New: func() interface{} {
		scratch := make([]byte, inboundBufferSize)
		return &scratch
	},
}
// readAndBuffer reads one datagram into a pooled buffer, splits it into
// records and processes each record with handleIncomingPacket.
//
// If a record asks for an alert to be sent, it is sent. A failure to send
// only becomes the returned error if processing itself did not fail. Any
// record error stops processing of the datagram. A fatal or close_notify
// alert is returned as the bare *alertError. If any record was a handshake
// record, the handshake state machine is notified over handshakeRecv and this
// call waits until it acknowledges (or the state machine is done).
func (c *Conn) readAndBuffer(ctx context.Context) error {
	pooled, ok := poolReadBuffer.Get().(*[]byte)
	if !ok {
		return errFailedToAccessPoolReadBuffer
	}
	defer poolReadBuffer.Put(pooled)

	scratch := *pooled
	size, from, err := c.nextConn.ReadFromContext(ctx, scratch)
	if err != nil {
		return netError(err)
	}

	records, err := recordlayer.ContentAwareUnpackDatagram(scratch[:size], len(c.state.getLocalConnectionID()))
	if err != nil {
		return err
	}

	sawHandshake, sawRetransmit := false, false
	for _, record := range records {
		isHs, isRtx, reply, handleErr := c.handleIncomingPacket(ctx, record, from, true)
		if reply != nil {
			notifyErr := c.notify(ctx, reply.Level, reply.Description)
			if notifyErr != nil && handleErr == nil {
				handleErr = notifyErr
			}
		}

		var aErr *alertError
		if errors.As(handleErr, &aErr) && aErr.IsFatalOrCloseNotify() {
			return aErr
		}
		if handleErr != nil {
			return handleErr
		}

		sawHandshake = sawHandshake || isHs
		sawRetransmit = sawRetransmit || isRtx
	}

	if !sawHandshake {
		return nil
	}

	notice := recvHandshakeState{
		done:         make(chan struct{}),
		isRetransmit: sawRetransmit,
	}
	select {
	case c.handshakeRecv <- notice:
		// Even a retransmitted flight may need a response, so wait until the
		// state machine has handled it.
		<-notice.done
	case <-c.fsm.Done():
	}
	return nil
}
// handleQueuedPackets takes ownership of the queued records, clearing the
// queue, and replays them through handleIncomingPacket without re-queuing.
// Alerts and errors are handled the same way as in readAndBuffer, and the
// first error stops the replay.
func (c *Conn) handleQueuedPackets(ctx context.Context) error {
	queued := c.encryptedPackets
	c.encryptedPackets = nil

	for _, q := range queued {
		_, _, reply, handleErr := c.handleIncomingPacket(ctx, q.data, q.rAddr, false)
		if reply != nil {
			notifyErr := c.notify(ctx, reply.Level, reply.Description)
			if notifyErr != nil && handleErr == nil {
				handleErr = notifyErr
			}
		}

		var aErr *alertError
		if errors.As(handleErr, &aErr) && aErr.IsFatalOrCloseNotify() {
			return aErr
		}
		if handleErr != nil {
			return handleErr
		}
	}
	return nil
}
// enqueueEncryptedPackets stores a record that cannot be processed yet so that
// handleQueuedPackets can replay it later. It reports false, dropping the
// record, once maxAppDataPacketQueueSize records are already queued.
//
// The record bytes are copied before queuing. Callers pass slices that
// originate from the datagram readAndBuffer reads into a poolReadBuffer
// buffer, and that buffer goes back to the pool (to be overwritten by the next
// read) when readAndBuffer returns. Whether ContentAwareUnpackDatagram already
// copies is not visible here; handleIncomingPacket makes the same defensive
// copy before pushing into the fragment buffer.
func (c *Conn) enqueueEncryptedPackets(packet addrPkt) bool {
	if len(c.encryptedPackets) >= maxAppDataPacketQueueSize {
		return false
	}
	packet.data = append([]byte{}, packet.data...)
	c.encryptedPackets = append(c.encryptedPackets, packet)
	return true
}
// handleIncomingPacket processes one raw record split out of a received
// datagram.
//
// Records are silently dropped (false, false, nil, nil) when the header cannot
// be parsed, when the epoch is more than one ahead of the current remote
// epoch, when the per-epoch replay detector rejects the sequence number, when
// decryption or CID unwrapping fails, when the connection ID is unexpected, or
// when defragmentation fails. If enqueue is true, these records are queued via
// enqueueEncryptedPackets: records of the next epoch, encrypted records that
// arrive before the cipher suite is initialized, and ChangeCipherSpec records
// that arrive before the cipher suite is initialized. handleQueuedPackets
// replays them with enqueue=false so they are not queued again.
//
// Return values, in order:
//   - true if the record was handshake data pushed into the fragment buffer;
//   - the fragment buffer's retransmit flag for that record;
//   - an alert the caller should send to the peer, or nil;
//   - an error for the caller (a received alert is reported as *alertError).
//
// Application data is delivered to Read through c.decrypted. If the record
// arrived wrapped in a connection ID and carried the latest sequence number
// seen, the stored remote address is updated to rAddr.
func (c *Conn) handleIncomingPacket(
	ctx context.Context,
	buf []byte,
	rAddr net.Addr,
	enqueue bool,
) (bool, bool, *alert.Alert, error) {
	header := &recordlayer.Header{}
	// Size the connection ID so that tls12_cid records parse correctly.
	if len(c.state.getLocalConnectionID()) > 0 {
		header.ConnectionID = make([]byte, len(c.state.getLocalConnectionID()))
	}
	if err := header.Unmarshal(buf); err != nil {
		// Invalid records are silently discarded (RFC 6347 Section 4.1.2.7).
		c.log.Debugf("discarded broken packet: %v", err)
		return false, false, nil, nil
	}

	// Epoch validation: the next epoch may be queued, anything later is dropped.
	remoteEpoch := c.state.getRemoteEpoch()
	if header.Epoch > remoteEpoch {
		if header.Epoch > remoteEpoch+1 {
			c.log.Debugf("discarded future packet (epoch: %d, seq: %d)",
				header.Epoch, header.SequenceNumber,
			)
			return false, false, nil, nil
		}
		if enqueue {
			if ok := c.enqueueEncryptedPackets(addrPkt{rAddr, buf}); ok {
				c.log.Debug("received packet of next epoch, queuing packet")
			}
		}
		return false, false, nil, nil
	}

	// Anti-replay: one detector per epoch, created lazily. The returned
	// markPacketAsValid is only invoked once the record is accepted below.
	for len(c.state.replayDetector) <= int(header.Epoch) {
		c.state.replayDetector = append(c.state.replayDetector,
			replaydetector.New(c.replayProtectionWindow, recordlayer.MaxSequenceNumber),
		)
	}
	markPacketAsValid, ok := c.state.replayDetector[int(header.Epoch)].Check(header.SequenceNumber)
	if !ok {
		c.log.Debugf("discarded duplicated packet (epoch: %d, seq: %d)",
			header.Epoch, header.SequenceNumber,
		)
		return false, false, nil, nil
	}

	// originalCID records whether the record arrived as a tls12_cid record.
	originalCID := false
	if header.Epoch != 0 {
		// Epoch > 0 records are encrypted; they cannot be read until the
		// cipher suite is ready.
		if c.state.cipherSuite == nil || !c.state.cipherSuite.IsInitialized() {
			if enqueue {
				if ok := c.enqueueEncryptedPackets(addrPkt{rAddr, buf}); ok {
					c.log.Debug("handshake not finished, queuing packet")
				}
			}
			return false, false, nil, nil
		}
		// Once a connection ID has been negotiated, encrypted records must
		// carry it.
		if len(c.state.getLocalConnectionID()) > 0 && header.ContentType != protocol.ContentTypeConnectionID {
			c.log.Debug("discarded packet missing connection ID after value negotiated")
			return false, false, nil, nil
		}
		var err error
		var hdr recordlayer.Header
		if header.ContentType == protocol.ContentTypeConnectionID {
			hdr.ConnectionID = make([]byte, len(c.state.getLocalConnectionID()))
		}
		buf, err = c.state.cipherSuite.Decrypt(hdr, buf)
		if err != nil {
			c.log.Debugf("%s: decrypt failed: %s", srvCliStr(c.state.isClient), err)
			return false, false, nil, nil
		}
		// Rewrite a CID record as an ordinary record (real content type,
		// no CID) so the rest of this function handles it uniformly.
		if header.ContentType == protocol.ContentTypeConnectionID {
			originalCID = true
			ip := &recordlayer.InnerPlaintext{}
			if err := ip.Unmarshal(buf[header.Size():]); err != nil {
				c.log.Debugf("unpacking inner plaintext failed: %s", err)
				return false, false, nil, nil
			}
			unpacked := &recordlayer.Header{
				ContentType:    ip.RealType,
				ContentLen:     uint16(len(ip.Content)),
				Version:        header.Version,
				Epoch:          header.Epoch,
				SequenceNumber: header.SequenceNumber,
			}
			buf, err = unpacked.Marshal()
			if err != nil {
				c.log.Debugf("converting CID record to inner plaintext failed: %s", err)
				return false, false, nil, nil
			}
			buf = append(buf, ip.Content...)
		}
		if !bytes.Equal(c.state.getLocalConnectionID(), header.ConnectionID) {
			c.log.Debug("unexpected connection ID")
			return false, false, nil, nil
		}
	}

	// Copy before pushing: the fragment buffer keeps the data beyond this call.
	isHandshake, isRetransmit, err := c.fragmentBuffer.push(append([]byte{}, buf...))
	if err != nil {
		// Invalid records are silently discarded (RFC 6347 Section 4.1.2.7).
		c.log.Debugf("defragment failed: %s", err)
		return false, false, nil, nil
	} else if isHandshake {
		markPacketAsValid()
		// Move every fully reassembled handshake message into the cache.
		for out, epoch := c.fragmentBuffer.pop(); out != nil; out, epoch = c.fragmentBuffer.pop() {
			header := &handshake.Header{}
			if err := header.Unmarshal(out); err != nil {
				c.log.Debugf("%s: handshake parse failed: %s", srvCliStr(c.state.isClient), err)
				continue
			}
			c.handshakeCache.push(out, epoch, header.MessageSequence, header.Type, !c.state.isClient)
		}
		return true, isRetransmit, nil, nil
	}

	r := &recordlayer.RecordLayer{}
	if err := r.Unmarshal(buf); err != nil {
		return false, false, &alert.Alert{Level: alert.Fatal, Description: alert.DecodeError}, err
	}

	// isLatestSeqNum is what markPacketAsValid reported; it gates the
	// remote-address update below.
	isLatestSeqNum := false
	switch content := r.Content.(type) {
	case *alert.Alert:
		c.log.Tracef("%s: <- %s", srvCliStr(c.state.isClient), content.String())
		var a *alert.Alert
		if content.Description == alert.CloseNotify {
			// Answer close_notify with our own close_notify.
			a = &alert.Alert{Level: alert.Warning, Description: alert.CloseNotify}
		}
		_ = markPacketAsValid()
		return false, false, a, &alertError{content}
	case *protocol.ChangeCipherSpec:
		if c.state.cipherSuite == nil || !c.state.cipherSuite.IsInitialized() {
			if enqueue {
				if ok := c.enqueueEncryptedPackets(addrPkt{rAddr, buf}); ok {
					c.log.Debugf("CipherSuite not initialized, queuing packet")
				}
			}
			return false, false, nil, nil
		}
		newRemoteEpoch := header.Epoch + 1
		c.log.Tracef("%s: <- ChangeCipherSpec (epoch: %d)", srvCliStr(c.state.isClient), newRemoteEpoch)
		// Only advance by exactly one epoch; duplicates are ignored.
		if c.state.getRemoteEpoch()+1 == newRemoteEpoch {
			c.setRemoteEpoch(newRemoteEpoch)
			isLatestSeqNum = markPacketAsValid()
		}
	case *protocol.ApplicationData:
		// Application data is never valid before keys are in place.
		if header.Epoch == 0 {
			return false, false, &alert.Alert{
				Level: alert.Fatal, Description: alert.UnexpectedMessage,
			}, errApplicationDataEpochZero
		}
		isLatestSeqNum = markPacketAsValid()
		// Hand the data to Read; give up if the conn closes or ctx ends.
		select {
		case c.decrypted <- content.Data:
		case <-c.closed.Done():
		case <-ctx.Done():
		}
	default:
		return false, false, &alert.Alert{
			Level: alert.Fatal, Description: alert.UnexpectedMessage,
		}, fmt.Errorf("%w: %d", errUnhandledContextType, content.ContentType())
	}

	// A valid CID record with the newest sequence number may move the peer
	// to a new address (RFC 9146 peer address update).
	if originalCID && isLatestSeqNum {
		if rAddr != c.RemoteAddr() {
			c.lock.Lock()
			c.rAddr = rAddr
			c.lock.Unlock()
		}
	}
	return false, false, nil, nil
}
// recvHandshake exposes, receive-only, the channel on which the read loop
// announces received handshake records to the handshake state machine.
func (c *Conn) recvHandshake() <-chan recvHandshakeState {
	var notifications <-chan recvHandshakeState = c.handshakeRecv
	return notifications
}
// notify sends an alert with the given level and description to the peer in
// the current local epoch. The alert is encrypted only once the handshake has
// completed.
//
// For a fatal alert on a connection with a session ID, the stored session (if
// a session store is configured) is deleted first. A failure to delete aborts
// the send and is returned.
func (c *Conn) notify(ctx context.Context, level alert.Level, desc alert.Description) error {
	if level == alert.Fatal && len(c.state.SessionID) > 0 {
		// c.fsm is only dereferenced on this path.
		if store := c.fsm.cfg.sessionStore; store != nil {
			c.log.Tracef("clean invalid session: %s", c.state.SessionID)
			if err := store.Del(c.sessionKey()); err != nil {
				return err
			}
		}
	}

	pkt := &packet{
		record: &recordlayer.RecordLayer{
			Header: recordlayer.Header{
				Epoch:   c.state.getLocalEpoch(),
				Version: protocol.Version1_2,
			},
			Content: &alert.Alert{Level: level, Description: desc},
		},
		shouldWrapCID: len(c.state.remoteConnectionID) > 0,
		shouldEncrypt: c.isHandshakeCompletedSuccessfully(),
	}
	return c.writePackets(ctx, []*packet{pkt})
}
// setHandshakeCompletedSuccessfully atomically records that the handshake has
// finished. The flag is stored as a struct{ bool } value.
func (c *Conn) setHandshakeCompletedSuccessfully() {
	completed := struct{ bool }{bool: true}
	c.handshakeCompletedSuccessfully.Store(completed)
}
// isHandshakeCompletedSuccessfully reports whether
// setHandshakeCompletedSuccessfully has been called. It is false while
// nothing has been stored yet.
func (c *Conn) isHandshakeCompletedSuccessfully() bool {
	flag, stored := c.handshakeCompletedSuccessfully.Load().(struct{ bool })
	return stored && flag.bool
}
// handshake drives the handshake state machine until it finishes, fails, or
// ctx is done.
//
// It starts two goroutines:
//   - the state machine (c.fsm.Run), which reports any non-Canceled result on
//     firstErr;
//   - the read loop, which calls readAndBuffer repeatedly. After a successful
//     handshake it forwards non-fatal errors to Read via c.decrypted.
//     Otherwise it reports the error on firstErr and closes the connection
//     for fatal/close_notify alerts or Canceled errors.
//
// Both goroutines keep running after a successful handshake (the <-done case
// returns without cancelling them). They are stopped through the cancel
// functions stored in cancelHandshaker / cancelHandshakeReader, which close()
// invokes. On failure or ctx cancellation, both are cancelled and awaited
// before the translated error is returned.
func (c *Conn) handshake(
	ctx context.Context,
	cfg *handshakeConfig,
	initialFlight flightVal,
	initialState handshakeState,
) error {
	c.fsm = newHandshakeFSM(&c.state, c.handshakeCache, cfg, initialFlight)
	// done is closed exactly once, when the FSM first reaches handshakeFinished.
	done := make(chan struct{})
	ctxRead, cancelRead := context.WithCancel(context.Background())
	cfg.onFlightState = func(_ flightVal, s handshakeState) {
		if s == handshakeFinished && !c.isHandshakeCompletedSuccessfully() {
			c.setHandshakeCompletedSuccessfully()
			close(done)
		}
	}
	// The loops use background-derived contexts, not ctx, because they must
	// outlive this call once the handshake succeeds.
	ctxHs, cancel := context.WithCancel(context.Background())
	c.closeLock.Lock()
	c.cancelHandshaker = cancel
	c.cancelHandshakeReader = cancelRead
	c.closeLock.Unlock()
	// Buffered with non-blocking sends: only the first error is kept.
	firstErr := make(chan error, 1)
	var handshakeLoopsFinished sync.WaitGroup
	handshakeLoopsFinished.Add(2)
	// State machine loop. It keeps running so that it can answer
	// retransmission requests from the peer.
	go func() {
		defer handshakeLoopsFinished.Done()
		err := c.fsm.Run(ctxHs, c, initialState)
		if !errors.Is(err, context.Canceled) {
			select {
			case firstErr <- err:
			default:
			}
		}
	}()
	// Read loop.
	go func() {
		// Runs after handshakeLoopsFinished.Done (defers are LIFO). The read
		// loop is the only sender on c.decrypted, so it may close it here.
		// The state machine is then stopped as well.
		defer func() {
			if c.isHandshakeCompletedSuccessfully() {
				close(c.decrypted)
			}
			cancel()
		}()
		defer handshakeLoopsFinished.Done()
		for {
			if err := c.readAndBuffer(ctxRead); err != nil {
				var alertErr *alertError
				if errors.As(err, &alertErr) {
					if !alertErr.IsFatalOrCloseNotify() {
						if c.isHandshakeCompletedSuccessfully() {
							// Surface the warning alert to Read.
							select {
							case c.decrypted <- err:
							case <-c.closed.Done():
							case <-ctxRead.Done():
							}
						}
						// Non-fatal alerts never stop the read loop.
						continue
					}
				} else {
					switch {
					case errors.Is(err, context.DeadlineExceeded),
						errors.Is(err, context.Canceled),
						errors.Is(err, io.EOF),
						errors.Is(err, net.ErrClosed):
						// Terminal: fall through to reporting below.
					case errors.Is(err, recordlayer.ErrInvalidPacketLength):
						// Malformed datagrams are silently discarded
						// (RFC 6347 Section 4.1.2.7).
						continue
					default:
						if c.isHandshakeCompletedSuccessfully() {
							// Keep reading; pass the error to Read.
							select {
							case c.decrypted <- err:
							case <-c.closed.Done():
							case <-ctxRead.Done():
							}
							continue
						}
					}
				}
				select {
				case firstErr <- err:
				default:
				}
				if alertErr != nil {
					if alertErr.IsFatalOrCloseNotify() {
						_ = c.close(false)
					}
				}
				if !c.isConnectionClosed() && errors.Is(err, context.Canceled) {
					c.log.Trace("handshake timeouts - closing underline connection")
					_ = c.close(false)
				}
				return
			}
		}
	}()
	select {
	case err := <-firstErr:
		cancelRead()
		cancel()
		handshakeLoopsFinished.Wait()
		return c.translateHandshakeCtxError(err)
	case <-ctx.Done():
		cancelRead()
		cancel()
		handshakeLoopsFinished.Wait()
		return c.translateHandshakeCtxError(ctx.Err())
	case <-done:
		return nil
	}
}
// translateHandshakeCtxError wraps a handshake failure in *HandshakeError.
// It returns nil for a nil error, and also for context.Canceled when the
// handshake had in fact completed (a late cancellation is not a failure).
func (c *Conn) translateHandshakeCtxError(err error) error {
	switch {
	case err == nil:
		return nil
	case errors.Is(err, context.Canceled) && c.isHandshakeCompletedSuccessfully():
		return nil
	default:
		return &HandshakeError{Err: err}
	}
}
// close stops the handshake goroutines and closes the connection.
//
// byUser is true for user-initiated closes (Close). In that case a
// close_notify alert is sent if the handshake completed, and its send error is
// ignored. The first user-initiated close returns the underlying connection's
// Close error (or nil if the connection was already closed internally). Any
// later close after a user close returns ErrConnClosed. Internal closes
// (byUser == false) close nextConn only if the connection was not already
// closed.
func (c *Conn) close(byUser bool) error {
	// Snapshot the cancel funcs under the lock, but call them outside it.
	c.closeLock.Lock()
	cancelHandshaker := c.cancelHandshaker
	cancelHandshakeReader := c.cancelHandshakeReader
	c.closeLock.Unlock()
	cancelHandshaker()
	cancelHandshakeReader()
	if c.isHandshakeCompletedSuccessfully() && byUser {
		// The error is discarded so that the first user Close succeeds even
		// if the underlying connection is already gone.
		_ = c.notify(context.Background(), alert.Warning, alert.CloseNotify)
	}
	c.closeLock.Lock()
	// Read the previous value before setting it, so that only repeat user
	// closes report ErrConnClosed.
	closedByUser := c.connectionClosedByUser
	if byUser {
		c.connectionClosedByUser = true
	}
	isClosed := c.isConnectionClosed()
	c.closed.Close()
	c.closeLock.Unlock()
	if closedByUser {
		return ErrConnClosed
	}
	if isClosed {
		return nil
	}
	return c.nextConn.Close()
}
// isConnectionClosed reports, without blocking, whether the closer has been
// signalled.
func (c *Conn) isConnectionClosed() bool {
	select {
	case <-c.closed.Done():
	default:
		return false
	}
	return true
}
// setLocalEpoch stores the epoch used for outgoing records.
func (c *Conn) setLocalEpoch(epoch uint16) {
	local := &c.state.localEpoch
	local.Store(epoch)
}
// setRemoteEpoch stores the epoch expected on incoming records.
func (c *Conn) setRemoteEpoch(epoch uint16) {
	remote := &c.state.remoteEpoch
	remote.Store(epoch)
}
// LocalAddr returns the local address of the underlying packet connection.
func (c *Conn) LocalAddr() net.Addr {
	underlying := c.nextConn
	return underlying.LocalAddr()
}
// RemoteAddr returns the current peer address, read under the read lock
// because it may change when a connection-ID record arrives from a new
// address.
func (c *Conn) RemoteAddr() net.Addr {
	c.lock.RLock()
	addr := c.rAddr
	c.lock.RUnlock()
	return addr
}
// sessionKey returns the key used with the session store. Servers use the
// session ID. Clients use "<remote address>_<server name>". The '_'
// delimiter keeps distinct address/name pairs from producing the same key.
func (c *Conn) sessionKey() []byte {
	if !c.state.isClient {
		return c.state.SessionID
	}
	key := c.rAddr.String() + "_" + c.fsm.cfg.serverName
	return []byte(key)
}
// SetDeadline sets both the read and the write deadline to t.
func (c *Conn) SetDeadline(t time.Time) error {
	if err := c.SetReadDeadline(t); err != nil {
		return err
	}
	return c.SetWriteDeadline(t)
}
// SetReadDeadline sets the deadline for pending and future Read calls. It
// always returns nil.
func (c *Conn) SetReadDeadline(t time.Time) error {
	rd := c.readDeadline
	rd.Set(t)
	return nil
}
// SetWriteDeadline sets the deadline for future Write calls. It always
// returns nil.
func (c *Conn) SetWriteDeadline(t time.Time) error {
	wd := c.writeDeadline
	wd.Set(t)
	return nil
}
// The pages are generated with Golds v0.8.2 (GOOS=linux GOARCH=amd64).
// Golds is a Go 101 project developed by Tapir Liu.
// PRs and bug reports are welcome and can be submitted to the issue list.
// Please follow @zigo_101 (reachable from the left QR code) to get the latest news of Golds.