package dtls
import (
"context"
"errors"
"fmt"
"io"
"net"
"sync"
"sync/atomic"
"time"
"github.com/pion/dtls/v2/internal/closer"
"github.com/pion/dtls/v2/pkg/crypto/elliptic"
"github.com/pion/dtls/v2/pkg/crypto/signaturehash"
"github.com/pion/dtls/v2/pkg/protocol"
"github.com/pion/dtls/v2/pkg/protocol/alert"
"github.com/pion/dtls/v2/pkg/protocol/handshake"
"github.com/pion/dtls/v2/pkg/protocol/recordlayer"
"github.com/pion/logging"
"github.com/pion/transport/v2/connctx"
"github.com/pion/transport/v2/deadline"
"github.com/pion/transport/v2/replaydetector"
)
// Handshake-related defaults and sizes.
const (
	// initialTickerInterval is the flight retransmission interval used when
	// Config.FlightInterval is left at zero.
	initialTickerInterval = time.Second
	cookieLength          = 20
	sessionLength         = 32
	defaultNamedCurve     = elliptic.X25519
)

// Limits on inbound record processing.
const (
	// inboundBufferSize is the size of each pooled datagram read buffer.
	inboundBufferSize = 8192
	// defaultReplayProtectionWindow is used when Config.ReplayProtectionWindow <= 0.
	defaultReplayProtectionWindow = 64
	// maxAppDataPacketQueueSize caps how many not-yet-processable records are queued.
	maxAppDataPacketQueueSize = 100
)
// invalidKeyingLabels returns a fresh set of the exporter labels that callers
// must not use, as a map from label to true. A new map is built on every call,
// so callers may modify the result freely.
func invalidKeyingLabels() map[string]bool {
	reserved := []string{
		"client finished",
		"server finished",
		"master secret",
		"key expansion",
	}
	set := make(map[string]bool, len(reserved))
	for _, label := range reserved {
		set[label] = true
	}
	return set
}
// Conn represents a DTLS connection layered on top of an underlying
// net.Conn. It is built by createConn, and the goroutines started by
// handshake drive it.
type Conn struct {
	// lock is held exclusively by writePackets and shared by ConnectionState.
	lock sync.RWMutex
	// nextConn is the underlying transport, wrapped with connctx.New in createConn.
	nextConn connctx.ConnCtx
	// fragmentBuffer reassembles incoming handshake fragments.
	fragmentBuffer *fragmentBuffer
	// handshakeCache records both sent and received handshake messages.
	handshakeCache *handshakeCache
	// decrypted carries application data ([]byte) and errors to Read. The
	// reader goroutine started by handshake closes it when it exits.
	decrypted chan interface{}
	state     State
	// maximumTransmissionUnit bounds handshake fragment size and datagram
	// coalescing. It falls back to defaultMTU when Config.MTU <= 0.
	maximumTransmissionUnit int
	// handshakeCompletedSuccessfully holds a struct{ bool } once the handshake
	// FSM reports handshakeFinished.
	handshakeCompletedSuccessfully atomic.Value
	// encryptedPackets queues records that could not be processed yet (next
	// epoch, or cipher suite not initialized). Its length is capped at
	// maxAppDataPacketQueueSize.
	encryptedPackets [][]byte
	// connectionClosedByUser records that Close has been called. It is
	// guarded by closeLock.
	connectionClosedByUser bool
	closeLock              sync.Mutex
	// closed is closed by close(); isConnectionClosed polls it.
	closed *closer.Closer
	// handshakeLoopsFinished tracks the FSM and reader goroutines started by handshake.
	handshakeLoopsFinished sync.WaitGroup
	readDeadline           *deadline.Deadline
	writeDeadline          *deadline.Deadline
	log                    logging.LeveledLogger
	// NOTE(review): reading is created in createConn but not referenced by
	// the code visible in this file.
	reading chan struct{}
	// handshakeRecv: readAndBuffer sends a done channel here after buffering
	// handshake records and waits on it. It is presumably consumed by the FSM
	// via recvHandshake.
	handshakeRecv chan chan struct{}
	// cancelHandshaker cancels the FSM goroutine's context. It is a no-op
	// until handshake runs.
	cancelHandshaker func()
	// cancelHandshakeReader cancels the reader goroutine's context. It is only
	// assigned in handshake.
	cancelHandshakeReader func()
	fsm                   *handshakeFSM
	// replayProtectionWindow is the window passed to replaydetector.New for each epoch.
	replayProtectionWindow uint
}
// createConn validates config and wraps nextConn in a new Conn whose handshake
// has not started yet. isClient is recorded as the role in the connection
// state. Zero or negative Config.MTU and Config.ReplayProtectionWindow values
// fall back to defaultMTU and defaultReplayProtectionWindow respectively.
//
// It returns the validateConfig error if config is invalid, and
// errNilNextConn if nextConn is nil.
func createConn(nextConn net.Conn, config *Config, isClient bool) (*Conn, error) {
	err := validateConfig(config)
	if err != nil {
		return nil, err
	}

	if nextConn == nil {
		return nil, errNilNextConn
	}

	loggerFactory := config.LoggerFactory
	if loggerFactory == nil {
		loggerFactory = logging.NewDefaultLoggerFactory()
	}

	logger := loggerFactory.NewLogger("dtls")

	mtu := config.MTU
	if mtu <= 0 {
		mtu = defaultMTU
	}

	replayProtectionWindow := config.ReplayProtectionWindow
	if replayProtectionWindow <= 0 {
		replayProtectionWindow = defaultReplayProtectionWindow
	}

	c := &Conn{
		nextConn:                connctx.New(nextConn),
		fragmentBuffer:          newFragmentBuffer(),
		handshakeCache:          newHandshakeCache(),
		maximumTransmissionUnit: mtu,

		decrypted: make(chan interface{}, 1),
		log:       logger,

		readDeadline:  deadline.New(),
		writeDeadline: deadline.New(),

		reading:       make(chan struct{}, 1),
		handshakeRecv: make(chan chan struct{}),
		closed:        closer.NewCloser(),

		// Both cancel funcs start as no-ops. close() calls them
		// unconditionally, and handshake() replaces them with real cancel
		// funcs only once it starts.
		cancelHandshaker:      func() {},
		cancelHandshakeReader: func() {},

		replayProtectionWindow: uint(replayProtectionWindow),

		state: State{
			isClient: isClient,
		},
	}

	c.setRemoteEpoch(0)
	c.setLocalEpoch(0)
	return c, nil
}
// handshakeConn builds a handshakeConfig from config and runs the DTLS
// handshake on conn, blocking until it finishes or ctx is done.
//
// If initialState is non-nil, conn.state is replaced by *initialState and the
// FSM starts at handshakeFinished with flight5 (client) or flight6 (server).
// Otherwise it starts at handshakePreparing with flight1 (client) or flight0
// (server). The role used to pick the flight is read from conn.state.isClient
// before any replacement. The isClient parameter only decides whether the
// server-side cipher-suite filtering below runs.
//
// On success conn itself is returned. On any error the result is nil.
// errNilNextConn is returned when conn is nil.
func handshakeConn(ctx context.Context, conn *Conn, config *Config, isClient bool, initialState *State) (*Conn, error) {
	if conn == nil {
		return nil, errNilNextConn
	}

	cipherSuites, err := parseCipherSuites(config.CipherSuites, config.CustomCipherSuites, config.includeCertificateSuites(), config.PSK != nil)
	if err != nil {
		return nil, err
	}

	signatureSchemes, err := signaturehash.ParseSignatureSchemes(config.SignatureSchemes, config.InsecureHashes)
	if err != nil {
		return nil, err
	}

	// Flight retransmission interval: config.FlightInterval when non-zero,
	// otherwise initialTickerInterval.
	workerInterval := initialTickerInterval
	if config.FlightInterval != 0 {
		workerInterval = config.FlightInterval
	}

	// IP literals are not valid SNI host names (RFC 6066), so they are dropped.
	serverName := config.ServerName
	if net.ParseIP(serverName) != nil {
		serverName = ""
	}

	curves := config.EllipticCurves
	if len(curves) == 0 {
		curves = defaultCurves
	}

	hsCfg := &handshakeConfig{
		localPSKCallback:            config.PSK,
		localPSKIdentityHint:        config.PSKIdentityHint,
		localCipherSuites:           cipherSuites,
		localSignatureSchemes:       signatureSchemes,
		extendedMasterSecret:        config.ExtendedMasterSecret,
		localSRTPProtectionProfiles: config.SRTPProtectionProfiles,
		serverName:                  serverName,
		supportedProtocols:          config.SupportedProtocols,
		clientAuth:                  config.ClientAuth,
		localCertificates:           config.Certificates,
		insecureSkipVerify:          config.InsecureSkipVerify,
		verifyPeerCertificate:       config.VerifyPeerCertificate,
		verifyConnection:            config.VerifyConnection,
		rootCAs:                     config.RootCAs,
		clientCAs:                   config.ClientCAs,
		customCipherSuites:          config.CustomCipherSuites,
		retransmitInterval:          workerInterval,
		log:                         conn.log,
		initialEpoch:                0,
		keyLogWriter:                config.KeyLogWriter,
		sessionStore:                config.SessionStore,
		ellipticCurves:              curves,
		localGetCertificate:         config.GetCertificate,
		localGetClientCertificate:   config.GetClientCertificate,
		insecureSkipHelloVerify:     config.InsecureSkipVerifyHello,
	}

	// Server side: narrow localCipherSuites via filterCipherSuitesForCertificate,
	// using the certificate returned for an empty ClientHelloInfo.
	// errNoCertificates is tolerated, presumably for certificate-less setups
	// such as PSK.
	if !isClient {
		cert, err := hsCfg.getCertificate(&ClientHelloInfo{})
		if err != nil && !errors.Is(err, errNoCertificates) {
			return nil, err
		}
		hsCfg.localCipherSuites = filterCipherSuitesForCertificate(cert, cipherSuites)
	}

	var initialFlight flightVal
	var initialFSMState handshakeState

	if initialState != nil {
		// Resume from a supplied state: the FSM starts already finished.
		if conn.state.isClient {
			initialFlight = flight5
		} else {
			initialFlight = flight6
		}
		initialFSMState = handshakeFinished

		conn.state = *initialState
	} else {
		// Fresh handshake.
		if conn.state.isClient {
			initialFlight = flight1
		} else {
			initialFlight = flight0
		}
		initialFSMState = handshakePreparing
	}
	// Do handshake
	if err := conn.handshake(ctx, hsCfg, initialFlight, initialFSMState); err != nil {
		return nil, err
	}

	conn.log.Trace("Handshake Completed")

	return conn, nil
}
// Dial dials a UDP connection to raddr and establishes a DTLS client
// connection over it. The whole attempt is bounded by the context returned by
// config.connectContextMaker.
//
// A nil config yields errNoConfigProvided without touching the network,
// matching ClientWithContext.
func Dial(network string, raddr *net.UDPAddr, config *Config) (*Conn, error) {
	if config == nil {
		// Check before calling a method on config. ClientWithContext treats a
		// nil config as errNoConfigProvided, and this entry point should not
		// rely on connectContextMaker tolerating a nil receiver.
		return nil, errNoConfigProvided
	}

	ctx, cancel := config.connectContextMaker()
	defer cancel()

	return DialWithContext(ctx, network, raddr, config)
}
// Client establishes a DTLS client connection over conn. The handshake is
// bounded by the context returned by config.connectContextMaker.
//
// A nil config yields errNoConfigProvided, matching ClientWithContext.
func Client(conn net.Conn, config *Config) (*Conn, error) {
	if config == nil {
		// Check before calling a method on config, consistent with ClientWithContext.
		return nil, errNoConfigProvided
	}

	ctx, cancel := config.connectContextMaker()
	defer cancel()

	return ClientWithContext(ctx, conn, config)
}
// Server establishes a DTLS server connection over conn. The handshake is
// bounded by the context returned by config.connectContextMaker.
//
// A nil config yields errNoConfigProvided, matching ServerWithContext.
func Server(conn net.Conn, config *Config) (*Conn, error) {
	if config == nil {
		// Check before calling a method on config, consistent with ServerWithContext.
		return nil, errNoConfigProvided
	}

	ctx, cancel := config.connectContextMaker()
	defer cancel()

	return ServerWithContext(ctx, conn, config)
}
// DialWithContext dials a UDP connection to raddr on network and establishes a
// DTLS client connection over it, bounded by ctx.
//
// The UDP socket is created here, so this function owns it until the DTLS
// connection is returned. If the DTLS setup fails, the socket is closed
// rather than leaked.
func DialWithContext(ctx context.Context, network string, raddr *net.UDPAddr, config *Config) (*Conn, error) {
	pConn, err := net.DialUDP(network, nil, raddr)
	if err != nil {
		return nil, err
	}

	dconn, err := ClientWithContext(ctx, pConn, config)
	if err != nil {
		// Errors before the handshake starts (nil/invalid config, cipher-suite
		// parsing, ...) never close pConn. A failed handshake may already have
		// closed it, and closing it twice is harmless, so the result is ignored.
		_ = pConn.Close()
		return nil, err
	}
	return dconn, nil
}
// ClientWithContext establishes a DTLS client connection over conn, bounded by
// ctx.
//
// It returns errNoConfigProvided for a nil config, and
// errPSKAndIdentityMustBeSetForClient when a PSK callback is configured
// without a PSKIdentityHint.
func ClientWithContext(ctx context.Context, conn net.Conn, config *Config) (*Conn, error) {
	if config == nil {
		return nil, errNoConfigProvided
	}
	if config.PSK != nil && config.PSKIdentityHint == nil {
		return nil, errPSKAndIdentityMustBeSetForClient
	}

	clientConn, createErr := createConn(conn, config, true)
	if createErr != nil {
		return nil, createErr
	}
	return handshakeConn(ctx, clientConn, config, true, nil)
}
// ServerWithContext establishes a DTLS server connection over conn, bounded by
// ctx. It returns errNoConfigProvided for a nil config.
func ServerWithContext(ctx context.Context, conn net.Conn, config *Config) (*Conn, error) {
	if config == nil {
		return nil, errNoConfigProvided
	}

	serverConn, createErr := createConn(conn, config, false)
	if createErr != nil {
		return nil, createErr
	}
	return handshakeConn(ctx, serverConn, config, false, nil)
}
// Read copies the next application-data record into p and returns its length.
//
// It returns errHandshakeInProgress before the handshake has completed, and
// errDeadlineExceeded once the read deadline has passed. An already-expired
// deadline wins over data that is already waiting. If the record is longer
// than p, the record is discarded and errBufferTooSmall is returned. Errors
// delivered through c.decrypted are returned as-is. Once c.decrypted has been
// closed, Read returns io.EOF.
func (c *Conn) Read(p []byte) (n int, err error) {
	if !c.isHandshakeCompletedSuccessfully() {
		return 0, errHandshakeInProgress
	}

	// Non-blocking pre-check: select picks randomly among ready cases, so
	// without this an expired deadline could lose to queued data.
	select {
	case <-c.readDeadline.Done():
		return 0, errDeadlineExceeded
	default:
	}

	for {
		var (
			item   interface{}
			isOpen bool
		)
		select {
		case <-c.readDeadline.Done():
			return 0, errDeadlineExceeded
		case item, isOpen = <-c.decrypted:
		}

		if !isOpen {
			return 0, io.EOF
		}

		if data, isData := item.([]byte); isData {
			if len(data) > len(p) {
				return 0, errBufferTooSmall
			}
			return copy(p, data), nil
		}
		if readErr, isErr := item.(error); isErr {
			return 0, readErr
		}
		// Values of any other type are skipped.
	}
}
// Write sends p to the peer as one encrypted application-data record.
//
// It returns ErrConnClosed once the connection is closed, errDeadlineExceeded
// once the write deadline has passed, and errHandshakeInProgress before the
// handshake has completed. On success it returns len(p).
//
// On failure it returns 0 rather than len(p). The record is not known to have
// been sent, and io.Writer defines n as the number of bytes actually written.
func (c *Conn) Write(p []byte) (int, error) {
	if c.isConnectionClosed() {
		return 0, ErrConnClosed
	}

	select {
	case <-c.writeDeadline.Done():
		return 0, errDeadlineExceeded
	default:
	}

	if !c.isHandshakeCompletedSuccessfully() {
		return 0, errHandshakeInProgress
	}

	err := c.writePackets(c.writeDeadline, []*packet{
		{
			record: &recordlayer.RecordLayer{
				Header: recordlayer.Header{
					Epoch:   c.state.getLocalEpoch(),
					Version: protocol.Version1_2,
				},
				Content: &protocol.ApplicationData{
					Data: p,
				},
			},
			shouldEncrypt: true,
		},
	})
	if err != nil {
		return 0, err
	}
	return len(p), nil
}
// Close shuts the connection down. It sends close_notify if the handshake had
// completed and closes the underlying transport if it is still open. It then
// waits for the handshake goroutines to exit before returning. Calling Close
// again returns ErrConnClosed.
func (c *Conn) Close() error {
	// The deferred Wait runs after close(true) has produced the return value.
	defer c.handshakeLoopsFinished.Wait()
	return c.close(true)
}
// ConnectionState returns a copy of the connection's state, cloned while
// holding the read lock.
func (c *Conn) ConnectionState() State {
	c.lock.RLock()
	defer c.lock.RUnlock()

	snapshot := c.state.clone()
	return *snapshot
}
// SelectedSRTPProtectionProfile returns the negotiated SRTP protection
// profile. The boolean is false when none was selected, i.e. the stored
// profile is 0.
func (c *Conn) SelectedSRTPProtectionProfile() (SRTPProtectionProfile, bool) {
	selected := c.state.getSRTPProtectionProfile()
	return selected, selected != 0
}
func (c *Conn ) writePackets (ctx context .Context , pkts []*packet ) error {
c .lock .Lock ()
defer c .lock .Unlock ()
var rawPackets [][]byte
for _ , p := range pkts {
if h , ok := p .record .Content .(*handshake .Handshake ); ok {
handshakeRaw , err := p .record .Marshal ()
if err != nil {
return err
}
c .log .Tracef ("[handshake:%v] -> %s (epoch: %d, seq: %d)" ,
srvCliStr (c .state .isClient ), h .Header .Type .String (),
p .record .Header .Epoch , h .Header .MessageSequence )
c .handshakeCache .push (handshakeRaw [recordlayer .HeaderSize :], p .record .Header .Epoch , h .Header .MessageSequence , h .Header .Type , c .state .isClient )
rawHandshakePackets , err := c .processHandshakePacket (p , h )
if err != nil {
return err
}
rawPackets = append (rawPackets , rawHandshakePackets ...)
} else {
rawPacket , err := c .processPacket (p )
if err != nil {
return err
}
rawPackets = append (rawPackets , rawPacket )
}
}
if len (rawPackets ) == 0 {
return nil
}
compactedRawPackets := c .compactRawPackets (rawPackets )
for _ , compactedRawPackets := range compactedRawPackets {
if _ , err := c .nextConn .WriteContext (ctx , compactedRawPackets ); err != nil {
return netError (err )
}
}
return nil
}
// compactRawPackets concatenates consecutive serialized records into as few
// datagrams as possible and returns the datagrams in order.
//
// A new datagram is started whenever appending the next record to a non-empty
// datagram would make its length maximumTransmissionUnit or more. A record
// that is itself larger than the MTU is still emitted as its own datagram.
// A single-record input is returned unchanged.
//
// NOTE(review): because of the >= comparison, a datagram is never filled to
// exactly maximumTransmissionUnit bytes. Confirm whether that margin is intended.
func (c *Conn) compactRawPackets(rawPackets [][]byte) [][]byte {
	if len(rawPackets) == 1 {
		return rawPackets
	}

	datagrams := make([][]byte, 0)
	pending := []byte{}
	for i := range rawPackets {
		record := rawPackets[i]
		reachesMTU := len(pending)+len(record) >= c.maximumTransmissionUnit
		if len(pending) != 0 && reachesMTU {
			datagrams = append(datagrams, pending)
			pending = []byte{}
		}
		pending = append(pending, record...)
	}
	return append(datagrams, pending)
}
func (c *Conn ) processPacket (p *packet ) ([]byte , error ) {
epoch := p .record .Header .Epoch
for len (c .state .localSequenceNumber ) <= int (epoch ) {
c .state .localSequenceNumber = append (c .state .localSequenceNumber , uint64 (0 ))
}
seq := atomic .AddUint64 (&c .state .localSequenceNumber [epoch ], 1 ) - 1
if seq > recordlayer .MaxSequenceNumber {
return nil , errSequenceNumberOverflow
}
p .record .Header .SequenceNumber = seq
rawPacket , err := p .record .Marshal ()
if err != nil {
return nil , err
}
if p .shouldEncrypt {
var err error
rawPacket , err = c .state .cipherSuite .Encrypt (p .record , rawPacket )
if err != nil {
return nil , err
}
}
return rawPacket , nil
}
// processHandshakePacket splits handshake h into fragments (see
// fragmentHandshake) and wraps each fragment in its own record. Each record
// gets the next local sequence number for the record's epoch and is encrypted
// when p.shouldEncrypt is set.
//
// It returns errSequenceNumberOverflow once the epoch's counter passes
// recordlayer.MaxSequenceNumber. p.record.Header is overwritten with each
// fragment's header, so on return it describes the last fragment.
func (c *Conn) processHandshakePacket(p *packet, h *handshake.Handshake) ([][]byte, error) {
	handshakeFragments, err := c.fragmentHandshake(h)
	if err != nil {
		return nil, err
	}
	rawPackets := make([][]byte, 0, len(handshakeFragments))

	epoch := p.record.Header.Epoch
	for len(c.state.localSequenceNumber) <= int(epoch) {
		c.state.localSequenceNumber = append(c.state.localSequenceNumber, uint64(0))
	}

	for _, handshakeFragment := range handshakeFragments {
		seq := atomic.AddUint64(&c.state.localSequenceNumber[epoch], 1) - 1
		if seq > recordlayer.MaxSequenceNumber {
			return nil, errSequenceNumberOverflow
		}

		// NOTE(review): ContentLen is a uint16, so fragments longer than 65535
		// bytes (MTU > ~65k) would be truncated here.
		recordlayerHeader := &recordlayer.Header{
			Version:        p.record.Header.Version,
			ContentType:    p.record.Header.ContentType,
			ContentLen:     uint16(len(handshakeFragment)),
			Epoch:          p.record.Header.Epoch,
			SequenceNumber: seq,
		}

		var rawPacket []byte
		if rawPacket, err = recordlayerHeader.Marshal(); err != nil {
			return nil, err
		}

		// p.record is passed to Encrypt below, so it must carry this
		// fragment's header rather than the unfragmented message's.
		p.record.Header = *recordlayerHeader

		rawPacket = append(rawPacket, handshakeFragment...)
		if p.shouldEncrypt {
			if rawPacket, err = c.state.cipherSuite.Encrypt(p.record, rawPacket); err != nil {
				return nil, err
			}
		}

		rawPackets = append(rawPackets, rawPacket)
	}

	return rawPackets, nil
}
// fragmentHandshake marshals h's message body and splits it into chunks of at
// most maximumTransmissionUnit bytes. It returns each chunk prefixed with its
// own handshake header, which carries h's type, total length and message
// sequence plus the chunk's offset and length.
//
// A message with an empty body still produces one zero-length fragment.
// NOTE(review): only the body is bounded by the MTU. The handshake and record
// headers (and any encryption overhead) come on top.
func (c *Conn) fragmentHandshake(h *handshake.Handshake) ([][]byte, error) {
	body, err := h.Message.Marshal()
	if err != nil {
		return nil, err
	}

	chunks := splitBytes(body, c.maximumTransmissionUnit)
	if len(chunks) == 0 {
		chunks = [][]byte{{}}
	}

	fragments := make([][]byte, 0, len(chunks))
	offset := 0
	for i := range chunks {
		chunk := chunks[i]
		fragmentHeader := &handshake.Header{
			Type:            h.Header.Type,
			Length:          h.Header.Length,
			MessageSequence: h.Header.MessageSequence,
			FragmentOffset:  uint32(offset),
			FragmentLength:  uint32(len(chunk)),
		}
		offset += len(chunk)

		encoded, marshalErr := fragmentHeader.Marshal()
		if marshalErr != nil {
			return nil, marshalErr
		}
		fragments = append(fragments, append(encoded, chunk...))
	}
	return fragments, nil
}
// poolReadBuffer recycles inboundBufferSize-byte datagram read buffers. It
// stores *[]byte rather than []byte, the usual way to avoid an extra
// allocation when putting slices into a sync.Pool.
var poolReadBuffer = sync.Pool{
	New: func() interface{} {
		buf := make([]byte, inboundBufferSize)
		return &buf
	},
}
// readAndBuffer reads one datagram from nextConn, splits it into records, and
// passes each record to handleIncomingPacket, queuing records that arrive too
// early.
//
// Any alert produced while handling a record is sent to the peer. If sending
// fails and handling itself produced no error, the send error is returned.
// Processing stops at the first record that yields an error. If at least one
// record carried handshake data, readAndBuffer then sends a done channel on
// handshakeRecv and waits on it, unless c.fsm.Done() is closed first.
//
// NOTE(review): the datagram is read into a poolReadBuffer buffer that goes
// back to the pool when this function returns. The packets passed on are
// presumably sub-slices of it, so anything kept beyond this call must be copied.
func (c *Conn) readAndBuffer(ctx context.Context) error {
	bufptr, ok := poolReadBuffer.Get().(*[]byte)
	if !ok {
		return errFailedToAccessPoolReadBuffer
	}
	defer poolReadBuffer.Put(bufptr)

	b := *bufptr
	i, err := c.nextConn.ReadContext(ctx, b)
	if err != nil {
		return netError(err)
	}

	pkts, err := recordlayer.UnpackDatagram(b[:i])
	if err != nil {
		return err
	}

	var hasHandshake bool
	for _, p := range pkts {
		// a, not alert: do not shadow the imported alert package.
		hs, a, pktErr := c.handleIncomingPacket(ctx, p, true)
		if a != nil {
			if alertErr := c.notify(ctx, a.Level, a.Description); alertErr != nil {
				if pktErr == nil {
					pktErr = alertErr
				}
			}
		}
		if hs {
			hasHandshake = true
		}
		if pktErr != nil {
			return pktErr
		}
	}
	if hasHandshake {
		done := make(chan struct{})
		select {
		case c.handshakeRecv <- done:
			// Wait until the receiver has finished with the buffered handshake data.
			<-done
		case <-c.fsm.Done():
		}
	}
	return nil
}
// handleQueuedPackets replays the records previously queued by
// enqueueEncryptedPackets. The queue is cleared up front, and replayed
// records are never queued again.
//
// Alerts produced while handling a record are sent to the peer. A non-alert
// error, or an alert error that is fatal or close_notify, stops the replay and
// is returned; the remaining queued records are dropped. Other alert errors
// are ignored.
func (c *Conn) handleQueuedPackets(ctx context.Context) error {
	pkts := c.encryptedPackets
	c.encryptedPackets = nil

	for _, p := range pkts {
		// a, not alert: do not shadow the imported alert package.
		_, a, err := c.handleIncomingPacket(ctx, p, false)
		if a != nil {
			if alertErr := c.notify(ctx, a.Level, a.Description); alertErr != nil {
				if err == nil {
					err = alertErr
				}
			}
		}
		var e *alertError
		if errors.As(err, &e) {
			if e.IsFatalOrCloseNotify() {
				return e
			}
		} else if err != nil {
			return err
		}
	}
	return nil
}
// enqueueEncryptedPackets queues a record that cannot be processed yet so that
// handleQueuedPackets can replay it later. At most maxAppDataPacketQueueSize
// records are held. When the queue is full the record is dropped and false
// is returned.
//
// The record is copied before it is queued. Records arrive via readAndBuffer,
// which reads into a poolReadBuffer buffer and returns that buffer to the pool
// when it returns. Unless recordlayer.UnpackDatagram copies (not visible
// here), a stored slice would alias memory that the next read overwrites
// before the queue is replayed.
func (c *Conn) enqueueEncryptedPackets(packet []byte) bool {
	if len(c.encryptedPackets) >= maxAppDataPacketQueueSize {
		return false
	}
	c.encryptedPackets = append(c.encryptedPackets, append([]byte{}, packet...))
	return true
}
// handleIncomingPacket processes one serialized record, buf, taken from a
// received datagram.
//
// Some records cannot be processed yet:
//   - records from the epoch right after the current remote epoch;
//   - records from a non-zero epoch before the cipher suite is initialized;
//   - a ChangeCipherSpec before the cipher suite is initialized.
//
// These are queued via enqueueEncryptedPackets when enqueue is true, and
// dropped otherwise (handleQueuedPackets replays with enqueue=false).
//
// It returns:
//   - whether the record carried handshake data;
//   - an alert the caller should send to the peer, if any;
//   - an error.
//
// Records that fail to parse, are more than one epoch ahead, are rejected by
// the replay detector, or fail to decrypt or defragment are discarded and
// yield (false, nil, nil). A received alert is returned as an *alertError. A
// received close_notify additionally yields a close_notify warning to send
// back. The replay window is only advanced (markPacketAsValid) for records
// that are actually accepted.
func (c *Conn) handleIncomingPacket(ctx context.Context, buf []byte, enqueue bool) (bool, *alert.Alert, error) {
	h := &recordlayer.Header{}
	if err := h.Unmarshal(buf); err != nil {
		// Records with an unparsable header are dropped silently.
		c.log.Debugf("discarded broken packet: %v", err)
		return false, nil, nil
	}

	// Epoch check: drop anything more than one epoch ahead, and queue the next epoch.
	remoteEpoch := c.state.getRemoteEpoch()
	if h.Epoch > remoteEpoch {
		if h.Epoch > remoteEpoch+1 {
			c.log.Debugf("discarded future packet (epoch: %d, seq: %d)",
				h.Epoch, h.SequenceNumber,
			)
			return false, nil, nil
		}
		if enqueue {
			if ok := c.enqueueEncryptedPackets(buf); ok {
				c.log.Debug("received packet of next epoch, queuing packet")
			}
		}
		return false, nil, nil
	}

	// Anti-replay: one detector per epoch, created lazily.
	for len(c.state.replayDetector) <= int(h.Epoch) {
		c.state.replayDetector = append(c.state.replayDetector,
			replaydetector.New(c.replayProtectionWindow, recordlayer.MaxSequenceNumber),
		)
	}
	// markPacketAsValid is only called once the record has been accepted below.
	markPacketAsValid, ok := c.state.replayDetector[int(h.Epoch)].Check(h.SequenceNumber)
	if !ok {
		c.log.Debugf("discarded duplicated packet (epoch: %d, seq: %d)",
			h.Epoch, h.SequenceNumber,
		)
		return false, nil, nil
	}

	// Records of any non-zero epoch are encrypted.
	if h.Epoch != 0 {
		if c.state.cipherSuite == nil || !c.state.cipherSuite.IsInitialized() {
			if enqueue {
				if ok := c.enqueueEncryptedPackets(buf); ok {
					c.log.Debug("handshake not finished, queuing packet")
				}
			}
			return false, nil, nil
		}

		var err error
		buf, err = c.state.cipherSuite.Decrypt(buf)
		if err != nil {
			c.log.Debugf("%s: decrypt failed: %s", srvCliStr(c.state.isClient), err)
			return false, nil, nil
		}
	}

	// The fragment buffer gets its own copy of the record.
	isHandshake, err := c.fragmentBuffer.push(append([]byte{}, buf...))
	if err != nil {
		c.log.Debugf("defragment failed: %s", err)
		return false, nil, nil
	} else if isHandshake {
		markPacketAsValid()
		// Move every fully reassembled message into the handshake cache,
		// flagged with the peer's role (!isClient).
		for out, epoch := c.fragmentBuffer.pop(); out != nil; out, epoch = c.fragmentBuffer.pop() {
			header := &handshake.Header{}
			if err := header.Unmarshal(out); err != nil {
				c.log.Debugf("%s: handshake parse failed: %s", srvCliStr(c.state.isClient), err)
				continue
			}
			c.handshakeCache.push(out, epoch, header.MessageSequence, header.Type, !c.state.isClient)
		}

		return true, nil, nil
	}

	r := &recordlayer.RecordLayer{}
	if err := r.Unmarshal(buf); err != nil {
		return false, &alert.Alert{Level: alert.Fatal, Description: alert.DecodeError}, err
	}

	switch content := r.Content.(type) {
	case *alert.Alert:
		c.log.Tracef("%s: <- %s", srvCliStr(c.state.isClient), content.String())
		var a *alert.Alert
		if content.Description == alert.CloseNotify {
			// Answer close_notify with our own close_notify.
			a = &alert.Alert{Level: alert.Warning, Description: alert.CloseNotify}
		}
		markPacketAsValid()
		return false, a, &alertError{content}
	case *protocol.ChangeCipherSpec:
		if c.state.cipherSuite == nil || !c.state.cipherSuite.IsInitialized() {
			if enqueue {
				if ok := c.enqueueEncryptedPackets(buf); ok {
					c.log.Debugf("CipherSuite not initialized, queuing packet")
				}
			}
			return false, nil, nil
		}

		newRemoteEpoch := h.Epoch + 1
		c.log.Tracef("%s: <- ChangeCipherSpec (epoch: %d)", srvCliStr(c.state.isClient), newRemoteEpoch)

		// Only advance the remote epoch by exactly one.
		if c.state.getRemoteEpoch()+1 == newRemoteEpoch {
			c.setRemoteEpoch(newRemoteEpoch)
			markPacketAsValid()
		}
	case *protocol.ApplicationData:
		// Application data is never acceptable in the unencrypted epoch 0.
		if h.Epoch == 0 {
			return false, &alert.Alert{Level: alert.Fatal, Description: alert.UnexpectedMessage}, errApplicationDataEpochZero
		}

		markPacketAsValid()

		// Hand the data to Read. Give up if the connection closes or ctx ends.
		select {
		case c.decrypted <- content.Data:
		case <-c.closed.Done():
		case <-ctx.Done():
		}

	default:
		return false, &alert.Alert{Level: alert.Fatal, Description: alert.UnexpectedMessage}, fmt.Errorf("%w: %d", errUnhandledContextType, content.ContentType())
	}
	return false, nil, nil
}
// recvHandshake exposes handshakeRecv as a receive-only channel. readAndBuffer
// sends a done channel on it after buffering handshake records.
func (c *Conn) recvHandshake() <-chan chan struct{} {
	var incoming <-chan chan struct{} = c.handshakeRecv
	return incoming
}
// notify sends an alert of the given level and description to the peer. The
// alert is encrypted once the handshake has completed.
//
// For a fatal alert on a connection that has a session ID, the session is
// first deleted from the configured session store, if any. If that deletion
// fails, its error is returned and the alert is not sent.
func (c *Conn) notify(ctx context.Context, level alert.Level, desc alert.Description) error {
	if level == alert.Fatal && len(c.state.SessionID) > 0 {
		store := c.fsm.cfg.sessionStore
		if store != nil {
			c.log.Tracef("clean invalid session: %s", c.state.SessionID)
			if delErr := store.Del(c.sessionKey()); delErr != nil {
				return delErr
			}
		}
	}

	alertRecord := &recordlayer.RecordLayer{
		Header: recordlayer.Header{
			Epoch:   c.state.getLocalEpoch(),
			Version: protocol.Version1_2,
		},
		Content: &alert.Alert{
			Level:       level,
			Description: desc,
		},
	}
	outgoing := &packet{
		record:        alertRecord,
		shouldEncrypt: c.isHandshakeCompletedSuccessfully(),
	}
	return c.writePackets(ctx, []*packet{outgoing})
}
// setHandshakeCompletedSuccessfully records that the handshake has finished.
// isHandshakeCompletedSuccessfully reads the flag back as a struct{ bool }.
func (c *Conn) setHandshakeCompletedSuccessfully() {
	completed := struct{ bool }{bool: true}
	c.handshakeCompletedSuccessfully.Store(completed)
}
// isHandshakeCompletedSuccessfully reports whether
// setHandshakeCompletedSuccessfully has been called. It returns false while
// nothing has been stored yet.
func (c *Conn) isHandshakeCompletedSuccessfully() bool {
	flag, stored := c.handshakeCompletedSuccessfully.Load().(struct{ bool })
	return stored && flag.bool
}
// handshake runs the DTLS handshake and blocks until it completes, fails, or
// ctx is done.
//
// It starts two goroutines, both tracked by handshakeLoopsFinished:
//   - the FSM loop, which runs c.fsm from initialFlight/initialState;
//   - the reader loop, which calls readAndBuffer repeatedly. It keeps running
//     after a successful handshake, until readAndBuffer returns an error it
//     does not recover from. On exit it closes c.decrypted and cancels the
//     FSM loop.
//
// Both goroutines use contexts derived from context.Background(), so ctx only
// bounds how long this call waits. On an error or ctx expiry, both loops are
// cancelled and awaited. The result is passed through
// translateHandshakeCtxError.
func (c *Conn) handshake(ctx context.Context, cfg *handshakeConfig, initialFlight flightVal, initialState handshakeState) error {
	c.fsm = newHandshakeFSM(&c.state, c.handshakeCache, cfg, initialFlight)

	// done is closed once, the first time the FSM reports handshakeFinished.
	done := make(chan struct{})
	ctxRead, cancelRead := context.WithCancel(context.Background())
	c.cancelHandshakeReader = cancelRead
	cfg.onFlightState = func(f flightVal, s handshakeState) {
		if s == handshakeFinished && !c.isHandshakeCompletedSuccessfully() {
			c.setHandshakeCompletedSuccessfully()
			close(done)
		}
	}

	ctxHs, cancel := context.WithCancel(context.Background())
	c.cancelHandshaker = cancel

	// firstErr keeps only the first error reported by either loop; later sends are dropped.
	firstErr := make(chan error, 1)

	c.handshakeLoopsFinished.Add(2)

	// Handshake FSM loop. Cancellation is not reported as an error.
	go func() {
		defer c.handshakeLoopsFinished.Done()
		err := c.fsm.Run(ctxHs, c, initialState)
		if !errors.Is(err, context.Canceled) {
			select {
			case firstErr <- err:
			default:
			}
		}
	}()

	// Reader loop.
	go func() {
		defer func() {
			// Closing decrypted makes Read return io.EOF.
			close(c.decrypted)

			// Stop the FSM loop once reading has stopped.
			cancel()
		}()
		defer c.handshakeLoopsFinished.Done()
		for {
			if err := c.readAndBuffer(ctxRead); err != nil {
				var e *alertError
				if errors.As(err, &e) {
					if !e.IsFatalOrCloseNotify() {
						// Non-fatal alert: surface it to Read once the
						// handshake is done (otherwise drop it), then keep reading.
						if c.isHandshakeCompletedSuccessfully() {
							select {
							case c.decrypted <- err:
							case <-c.closed.Done():
							case <-ctxRead.Done():
							}
						}
						continue
					}
				} else {
					switch {
					case errors.Is(err, context.DeadlineExceeded), errors.Is(err, context.Canceled), errors.Is(err, io.EOF), errors.Is(err, net.ErrClosed):
						// Terminal: no action here; handled by the shutdown code below.
					case errors.Is(err, recordlayer.ErrInvalidPacketLength):
						// Malformed datagram: skip it and keep reading.
						continue
					default:
						// After the handshake, other errors go to Read and reading
						// continues; during the handshake they end the loop.
						if c.isHandshakeCompletedSuccessfully() {
							select {
							case c.decrypted <- err:
							case <-c.closed.Done():
							case <-ctxRead.Done():
							}
							continue
						}
					}
				}

				select {
				case firstErr <- err:
				default:
				}

				if e != nil {
					if e.IsFatalOrCloseNotify() {
						_ = c.close(false)
					}
				}
				if !c.isConnectionClosed() && errors.Is(err, context.Canceled) {
					c.log.Trace("handshake timeouts - closing underline connection")
					_ = c.close(false)
				}
				return
			}
		}
	}()

	select {
	case err := <-firstErr:
		cancelRead()
		cancel()
		c.handshakeLoopsFinished.Wait()
		return c.translateHandshakeCtxError(err)
	case <-ctx.Done():
		cancelRead()
		cancel()
		c.handshakeLoopsFinished.Wait()
		return c.translateHandshakeCtxError(ctx.Err())
	case <-done:
		// Success: both loops keep running for the life of the connection.
		return nil
	}
}
// translateHandshakeCtxError wraps a handshake failure in *HandshakeError.
// It returns nil for a nil err. It also returns nil for a cancellation that
// happened after the handshake had already completed.
func (c *Conn) translateHandshakeCtxError(err error) error {
	switch {
	case err == nil:
		return nil
	case errors.Is(err, context.Canceled) && c.isHandshakeCompletedSuccessfully():
		return nil
	default:
		return &HandshakeError{Err: err}
	}
}
// close tears the connection down. It:
//   - cancels both handshake loops;
//   - sends a best-effort close_notify when the user closes an established
//     connection;
//   - marks the connection closed;
//   - closes nextConn, the first time through only.
//
// byUser is true for Close and false for internal shutdown. It returns
// ErrConnClosed if the user had already called Close before this call, and
// nil if the connection was already closed by an earlier internal shutdown.
// Otherwise it returns the result of nextConn.Close.
//
// NOTE(review): cancelHandshakeReader is only assigned in handshake. Calling
// close on a Conn whose handshake never started would call a nil func.
func (c *Conn) close(byUser bool) error {
	c.cancelHandshaker()
	c.cancelHandshakeReader()

	if c.isHandshakeCompletedSuccessfully() && byUser {
		// Best effort: the alert's error is ignored because we are closing anyway.
		_ = c.notify(context.Background(), alert.Warning, alert.CloseNotify)
	}

	c.closeLock.Lock()
	// Read the previous value first so the user's first Close does not get ErrConnClosed.
	closedByUser := c.connectionClosedByUser
	if byUser {
		c.connectionClosedByUser = true
	}
	isClosed := c.isConnectionClosed()
	c.closed.Close()
	c.closeLock.Unlock()

	if closedByUser {
		return ErrConnClosed
	}

	if isClosed {
		return nil
	}

	return c.nextConn.Close()
}
// isConnectionClosed reports, without blocking, whether close has signaled c.closed.
func (c *Conn) isConnectionClosed() bool {
	select {
	case <-c.closed.Done():
		return true
	default:
	}
	return false
}
// setLocalEpoch stores the epoch used for records this side sends.
func (c *Conn) setLocalEpoch(epoch uint16) {
	st := &c.state
	st.localEpoch.Store(epoch)
}
// setRemoteEpoch stores the epoch expected on records received from the peer.
func (c *Conn) setRemoteEpoch(epoch uint16) {
	st := &c.state
	st.remoteEpoch.Store(epoch)
}
// LocalAddr returns the local address of the underlying transport.
func (c *Conn) LocalAddr() net.Addr {
	transport := c.nextConn
	return transport.LocalAddr()
}
// RemoteAddr returns the remote address of the underlying transport.
func (c *Conn) RemoteAddr() net.Addr {
	transport := c.nextConn
	return transport.RemoteAddr()
}
// sessionKey returns the key under which this connection's session is stored.
// A server uses its session ID. A client uses "<remote address>_<server name>".
func (c *Conn) sessionKey() []byte {
	if !c.state.isClient {
		return c.state.SessionID
	}
	remote := c.nextConn.RemoteAddr().String()
	return []byte(remote + "_" + c.fsm.cfg.serverName)
}
// SetDeadline sets both the read and the write deadline to t.
func (c *Conn) SetDeadline(t time.Time) error {
	if err := c.SetReadDeadline(t); err != nil {
		return err
	}
	return c.SetWriteDeadline(t)
}
// SetReadDeadline sets the deadline after which Read fails with
// errDeadlineExceeded. It always returns nil.
func (c *Conn) SetReadDeadline(t time.Time) error {
	rd := c.readDeadline
	rd.Set(t)
	return nil
}
// SetWriteDeadline sets the deadline after which Write fails with
// errDeadlineExceeded. It always returns nil.
func (c *Conn) SetWriteDeadline(t time.Time) error {
	wd := c.writeDeadline
	wd.Set(t)
	return nil
}
The pages are generated with Golds v0.8.2 (GOOS=linux GOARCH=amd64).
Golds is a Go 101 project developed by Tapir Liu.
PRs and bug reports are welcome and can be submitted to the issue list.
Please follow @zigo_101 (reachable from the left QR code) to get the latest news of Golds.