package quic
import (
"context"
"crypto/rand"
"crypto/tls"
"errors"
"fmt"
"net"
"sync"
"sync/atomic"
"time"
"github.com/quic-go/quic-go/internal/protocol"
"github.com/quic-go/quic-go/internal/utils"
"github.com/quic-go/quic-go/internal/wire"
"github.com/quic-go/quic-go/logging"
)
// ErrTransportClosed matches, via errors.Is, the error a Transport records
// when it is closed. Because every *errTransportClosed also unwraps to
// net.ErrClosed, errors.Is(err, net.ErrClosed) holds for it too.
var ErrTransportClosed = &errTransportClosed{}

// errTransportClosed wraps the (optional) cause that led to closing a Transport.
type errTransportClosed struct {
	err error
}

// Unwrap exposes both net.ErrClosed and the underlying cause (which may be nil)
// to errors.Is / errors.As.
func (e *errTransportClosed) Unwrap() []error {
	return []error{net.ErrClosed, e.err}
}

// Error describes the closure, appending the cause when there is one.
func (e *errTransportClosed) Error() string {
	if e.err != nil {
		return fmt.Sprintf("quic: transport closed: %s", e.err)
	}
	return "quic: transport closed"
}

// Is reports true for any *errTransportClosed target, regardless of its cause,
// so errors.Is(err, ErrTransportClosed) matches every transport-closed error.
func (e *errTransportClosed) Is(target error) bool {
	_, isTransportClosed := target.(*errTransportClosed)
	return isTransportClosed
}
// errListenerAlreadySet is returned by createServer (and thus Listen /
// ListenEarly) when the Transport already has a server.
var errListenerAlreadySet = errors.New("listener already set")
// closePacket is an entry on Transport.closeQueue; runSendQueue writes payload
// to addr, passing info's OOB data to WritePacket.
type closePacket struct {
	payload []byte     // bytes to send
	addr    net.Addr   // destination address
	info    packetInfo // source of the OOB data for the write
}
// Transport multiplexes QUIC connections over a single net.PacketConn.
// Incoming packets are dispatched by connection ID to the registered handlers,
// or, for unknown long-header packets, to the (at most one) server created by
// Listen / ListenEarly.
//
// Conn, ConnectionIDLength, ConnectionIDGenerator, TokenGeneratorKey and the
// resetter derived from StatelessResetKey are only consulted during lazy
// initialization (see init).
type Transport struct {
	// Conn carries all packets. It is used directly if it implements rawConn
	// and is wrapped with wrapConn otherwise.
	Conn net.PacketConn

	// ConnectionIDLength is the length of locally generated connection IDs
	// when ConnectionIDGenerator is nil. Zero selects
	// protocol.DefaultConnectionIDLength unless init was asked to allow
	// zero-length IDs (dial does so for single-use transports).
	ConnectionIDLength int
	// ConnectionIDGenerator, if non-nil, generates local connection IDs and
	// defines their length; ConnectionIDLength is then ignored.
	ConnectionIDGenerator ConnectionIDGenerator

	// StatelessResetKey is used to build the stateless resetter. If nil, this
	// transport never sends stateless resets.
	StatelessResetKey *StatelessResetKey
	// TokenGeneratorKey is handed to the server. If nil, a random key is
	// generated during initialization.
	TokenGeneratorKey *TokenGeneratorKey
	// MaxTokenAge is handed to the server; zero means 24 hours.
	MaxTokenAge time.Duration

	// The following three are passed through unchanged to the server created
	// by Listen / ListenEarly.
	DisableVersionNegotiationPackets bool
	VerifySourceAddress              func(net.Addr) bool
	ConnContext                      func(context.Context, *ClientInfo) (context.Context, error)

	// Tracer, if set, receives transport-level dropped-packet events, is
	// handed to the server, and is closed when the transport closes.
	Tracer *logging.Tracer

	// connMx guards handlers and resetTokens. closeServer, close and doDial
	// acquire mutex before connMx.
	connMx      sync.Mutex
	handlers    map[protocol.ConnectionID]packetHandler
	resetTokens map[protocol.StatelessResetToken]packetHandler

	// mutex guards server and closeErr.
	mutex sync.Mutex

	// initOnce runs the initialization in init once; initErr is its outcome.
	initOnce sync.Once
	initErr  error

	// Set up by init.
	connIDLen         int
	connIDGenerator   ConnectionIDGenerator
	statelessResetter *statelessResetter

	server *baseServer // current server, if Listen/ListenEarly succeeded
	conn   rawConn     // Conn, possibly wrapped

	closeQueue          chan closePacket    // close packets written by runSendQueue
	statelessResetQueue chan receivedPacket // packets answered by sendStatelessReset
	listening           chan struct{}       // closed when the listen goroutine exits

	closeErr    error // non-nil once closed (for single-use: also once the server closed)
	createdConn bool  // if true, Close also closes Conn
	isSingleUse bool  // single-use: closed when its dialed connection ends (doDial)

	// nonQUICPackets is created by the first ReadNonQUICPacket call, after
	// which readingNonQUICPackets is set.
	readingNonQUICPackets atomic.Bool
	nonQUICPackets        chan receivedPacket

	logger utils.Logger
}
func (t *Transport ) Listen (tlsConf *tls .Config , conf *Config ) (*Listener , error ) {
s , err := t .createServer (tlsConf , conf , false )
if err != nil {
return nil , err
}
return &Listener {baseServer : s }, nil
}
func (t *Transport ) ListenEarly (tlsConf *tls .Config , conf *Config ) (*EarlyListener , error ) {
s , err := t .createServer (tlsConf , conf , true )
if err != nil {
return nil , err
}
return &EarlyListener {baseServer : s }, nil
}
// createServer validates the configuration, initializes the transport if
// needed, and installs a new server as t.server.
//
// It fails if tlsConf is nil, if conf is invalid, if the transport is already
// closed (returning closeErr), or if a server already exists
// (errListenerAlreadySet). A zero MaxTokenAge defaults to 24 hours.
func (t *Transport) createServer(tlsConf *tls.Config, conf *Config, allow0RTT bool) (*baseServer, error) {
	if tlsConf == nil {
		return nil, errors.New("quic: tls.Config not set")
	}
	if err := validateConfig(conf); err != nil {
		return nil, err
	}

	t.mutex.Lock()
	defer t.mutex.Unlock()

	switch {
	case t.closeErr != nil:
		return nil, t.closeErr
	case t.server != nil:
		return nil, errListenerAlreadySet
	}

	conf = populateConfig(conf)
	if err := t.init(false); err != nil {
		return nil, err
	}

	tokenAge := t.MaxTokenAge
	if tokenAge == 0 {
		tokenAge = 24 * time.Hour
	}

	srv := newServer(
		t.conn,
		(*packetHandlerMap)(t),
		t.connIDGenerator,
		t.statelessResetter,
		t.ConnContext,
		tlsConf,
		conf,
		t.Tracer,
		t.closeServer,
		*t.TokenGeneratorKey, // non-nil after init
		tokenAge,
		t.VerifySourceAddress,
		t.DisableVersionNegotiationPackets,
		allow0RTT,
	)
	t.server = srv
	return srv, nil
}
func (t *Transport ) Dial (ctx context .Context , addr net .Addr , tlsConf *tls .Config , conf *Config ) (Connection , error ) {
return t .dial (ctx , addr , "" , tlsConf , conf , false )
}
func (t *Transport ) DialEarly (ctx context .Context , addr net .Addr , tlsConf *tls .Config , conf *Config ) (EarlyConnection , error ) {
return t .dial (ctx , addr , "" , tlsConf , conf , true )
}
// dial is the shared implementation behind Dial and DialEarly.
//
// It initializes the transport, validates and populates conf, clones tlsConf
// (so the caller's config is never mutated), fills in its ServerName from addr
// or host via setTLSConfigServerName, and hands off to doDial using the first
// configured version.
//
// A nil tlsConf is rejected with the same error createServer uses. Without this
// check, tls.Config.Clone returns nil for a nil receiver and
// setTLSConfigServerName would dereference it and panic.
func (t *Transport) dial(ctx context.Context, addr net.Addr, host string, tlsConf *tls.Config, conf *Config, use0RTT bool) (EarlyConnection, error) {
	if tlsConf == nil {
		return nil, errors.New("quic: tls.Config not set")
	}
	// Zero-length connection IDs are only allowed on single-use transports.
	if err := t.init(t.isSingleUse); err != nil {
		return nil, err
	}
	if err := validateConfig(conf); err != nil {
		return nil, err
	}
	conf = populateConfig(conf)
	tlsConf = tlsConf.Clone()
	setTLSConfigServerName(tlsConf, addr, host)
	return t.doDial(ctx,
		newSendConn(t.conn, addr, packetInfo{}, utils.DefaultLogger),
		tlsConf,
		conf,
		0,     // initialPacketNumber
		false, // hasNegotiatedVersion
		use0RTT,
		conf.Versions[0],
	)
}
// doDial starts a client connection over sendConn and blocks until it is
// usable or has failed.
//
// It generates a source connection ID (t.connIDGenerator) and a destination
// connection ID for the Initial packet, creates the connection while holding
// t.mutex, registers it in t.handlers under its source connection ID, and runs
// it in a new goroutine. It then returns on the first of:
//   - ctx done: the connection is destroyed, doDial waits for run to return,
//     and context.Cause(ctx) is returned;
//   - run reporting errCloseForRecreating: doDial recurses with the packet
//     number and version from that error, marking the version as negotiated;
//   - run returning any other error, which is returned;
//   - the connection being ready for early data (only when use0RTT is set) or
//     its handshake completing: the connection is returned.
//
// If the transport is already closed, t.closeErr is returned. On a single-use
// transport, the goroutine closes the transport once run returns with an
// error that is not a recreate request.
func (t *Transport ) doDial (
ctx context .Context ,
sendConn sendConn ,
tlsConf *tls .Config ,
config *Config ,
initialPacketNumber protocol .PacketNumber ,
hasNegotiatedVersion bool ,
use0RTT bool ,
version protocol .Version ,
) (quicConn , error ) {
srcConnID , err := t .connIDGenerator .GenerateConnectionID ()
if err != nil {
return nil , err
}
destConnID , err := generateConnectionIDForInitial ()
if err != nil {
return nil , err
}
// The tracing ID is attached to ctx, which is passed to config.Tracer below.
tracingID := nextConnTracingID ()
ctx = context .WithValue (ctx , ConnectionTracingKey , tracingID )
// t.mutex is held from the closeErr check until the connection is registered
// in t.handlers. close takes the same mutex before destroying handlers, so it
// either sees this connection or doDial sees closeErr.
t .mutex .Lock ()
if t .closeErr != nil {
t .mutex .Unlock ()
return nil , t .closeErr
}
var tracer *logging .ConnectionTracer
if config .Tracer != nil {
tracer = config .Tracer (ctx , protocol .PerspectiveClient , destConnID )
}
if tracer != nil && tracer .StartedConnection != nil {
tracer .StartedConnection (sendConn .LocalAddr (), sendConn .RemoteAddr (), srcConnID , destConnID )
}
logger := utils .DefaultLogger .WithPrefix ("client" )
logger .Infof ("Starting new connection to %s (%s -> %s), source connection ID %s, destination connection ID %s, version %s" , tlsConf .ServerName , sendConn .LocalAddr (), sendConn .RemoteAddr (), srcConnID , destConnID , version )
// The connection outlives the dial context (WithoutCancel); cancellation of
// ctx during dialing is handled explicitly by the select below.
conn := newClientConnection (
context .WithoutCancel (ctx ),
sendConn ,
(*packetHandlerMap )(t ),
destConnID ,
srcConnID ,
t .connIDGenerator ,
t .statelessResetter ,
config ,
tlsConf ,
initialPacketNumber ,
use0RTT ,
hasNegotiatedVersion ,
tracer ,
logger ,
version ,
)
t .connMx .Lock ()
t .handlers [srcConnID ] = conn
t .connMx .Unlock ()
t .mutex .Unlock ()
// errChan is buffered so the goroutine can deliver its result and exit even
// after doDial has returned. recreateChan is unbuffered.
// NOTE(review): if run reported a recreate after doDial already returned,
// the send below would block forever; presumably recreation cannot happen
// after the handshake/early-data signal — confirm.
errChan := make (chan error , 1 )
recreateChan := make (chan errCloseForRecreating )
go func () {
err := conn .run ()
var recreateErr *errCloseForRecreating
if errors .As (err , &recreateErr ) {
recreateChan <- *recreateErr
return
}
if t .isSingleUse {
t .Close ()
}
errChan <- err
}()
// A nil channel never becomes ready, which disables that select case when
// 0-RTT is not in use.
var earlyConnChan <-chan struct {}
if use0RTT {
earlyConnChan = conn .earlyConnReady ()
}
select {
case <- ctx .Done ():
conn .destroy (nil )
// Wait for run to return (via either channel) before reporting the cause.
select {
case <- errChan :
case <- recreateChan :
}
return nil , context .Cause (ctx )
case params := <- recreateChan :
return t .doDial (ctx ,
sendConn ,
tlsConf ,
config ,
params .nextPacketNumber ,
true ,
use0RTT ,
params .nextVersion ,
)
case err := <- errChan :
return nil , err
case <- earlyConnChan :
return conn , nil
case <- conn .HandshakeComplete ():
return conn , nil
}
}
// init lazily sets up the Transport exactly once and returns the outcome of
// that first setup on every call.
//
// Setup works as follows:
//   - Conn is wrapped with wrapConn unless it already implements rawConn.
//   - A random TokenGeneratorKey is generated if none is set.
//   - The handler maps and queues are allocated.
//   - The connection ID generator and length are chosen.
//   - The listen and send-queue goroutines are started.
//
// allowZeroLengthConnIDs matters only when ConnectionIDGenerator is nil and
// ConnectionIDLength is 0. If true, the length stays 0; otherwise it becomes
// protocol.DefaultConnectionIDLength. Because of sync.Once, only the value
// passed by the first caller has an effect.
//
// Both fallible steps (wrapping Conn, generating the token key) run before any
// state is published. Previously, a key-generation failure left t.listening
// allocated with no listen goroutine to close it, so Close blocked forever on
// <-t.listening.
func (t *Transport) init(allowZeroLengthConnIDs bool) error {
	t.initOnce.Do(func() {
		conn, ok := t.Conn.(rawConn)
		if !ok {
			wrapped, err := wrapConn(t.Conn)
			if err != nil {
				t.initErr = err
				return
			}
			conn = wrapped
		}
		if t.TokenGeneratorKey == nil {
			var key TokenGeneratorKey
			if _, err := rand.Read(key[:]); err != nil {
				t.initErr = err
				return
			}
			t.TokenGeneratorKey = &key
		}

		t.logger = utils.DefaultLogger
		t.conn = conn
		t.handlers = make(map[protocol.ConnectionID]packetHandler)
		t.resetTokens = make(map[protocol.StatelessResetToken]packetHandler)
		t.listening = make(chan struct{})
		t.closeQueue = make(chan closePacket, 4)
		t.statelessResetQueue = make(chan receivedPacket, 4)

		if t.ConnectionIDGenerator != nil {
			t.connIDGenerator = t.ConnectionIDGenerator
			t.connIDLen = t.ConnectionIDGenerator.ConnectionIDLen()
		} else {
			connIDLen := t.ConnectionIDLength
			if t.ConnectionIDLength == 0 && !allowZeroLengthConnIDs {
				connIDLen = protocol.DefaultConnectionIDLength
			}
			t.connIDLen = connIDLen
			t.connIDGenerator = &protocol.DefaultConnectionIDGenerator{ConnLen: t.connIDLen}
		}
		t.statelessResetter = newStatelessResetter(t.StatelessResetKey)

		go func() {
			// Closing listening signals Close, runSendQueue and
			// ReadNonQUICPacket that the read loop has ended.
			defer close(t.listening)
			t.listen(conn)
			if t.createdConn {
				conn.Close()
			}
		}()
		go t.runSendQueue()
	})
	return t.initErr
}
// WriteTo writes the raw bytes b to addr over the transport's connection,
// bypassing QUIC. The transport is initialized first if necessary.
func (t *Transport) WriteTo(b []byte, addr net.Addr) (int, error) {
	err := t.init(false)
	if err != nil {
		return 0, err
	}
	return t.conn.WritePacket(b, addr, nil, 0, protocol.ECNUnsupported)
}
// runSendQueue writes queued close packets and stateless resets until the
// listen goroutine exits (t.listening is closed). It is started by init.
func (t *Transport) runSendQueue() {
	for {
		select {
		case <-t.listening:
			return
		case cp := <-t.closeQueue:
			// Best effort: write errors for queued close packets are ignored.
			_, _ = t.conn.WritePacket(cp.payload, cp.addr, cp.info.OOB(), 0, protocol.ECNUnsupported)
		case pkt := <-t.statelessResetQueue:
			t.sendStatelessReset(pkt)
		}
	}
}
// Close closes the transport and all its connections and waits for the listen
// goroutine to exit.
//
// If the transport created Conn itself (createdConn), Conn is closed and any
// error from that is returned. Otherwise Conn is left open: a read deadline in
// the past interrupts the read loop, and the deadline is cleared again before
// Close returns.
func (t *Transport) Close() error {
	// Run (or wait for) initialization first: sync.Once.Do does not return
	// until the first init has finished, so teardown never races with it.
	_ = t.init(false)
	t.close(nil)

	switch {
	case t.createdConn:
		if err := t.Conn.Close(); err != nil {
			return err
		}
	case t.conn != nil:
		t.conn.SetReadDeadline(time.Now())
		defer func() { t.conn.SetReadDeadline(time.Time{}) }()
	}

	if t.listening != nil {
		<-t.listening
	}
	return nil
}
// closeServer is the callback given to newServer. It detaches the server and,
// for single-use transports, marks the transport closed with ErrServerClosed.
// If no handlers remain, maybeStopListening is consulted.
//
// Locks are taken in the order mutex, then connMx.
func (t *Transport) closeServer() {
	t.mutex.Lock()
	defer t.mutex.Unlock()

	t.server = nil
	if t.isSingleUse {
		t.closeErr = ErrServerClosed
	}

	t.connMx.Lock()
	if noHandlers := len(t.handlers) == 0; noHandlers {
		t.maybeStopListening()
	}
	t.connMx.Unlock()
}
// close shuts the transport down with cause e (which may be nil). It returns
// immediately if closeErr is already set.
//
// e is wrapped in an *errTransportClosed. Every registered handler is then
// destroyed with that error concurrently, and close waits for all of them.
// After that the server (if any) is closed, the Tracer is closed, and the
// wrapped error is recorded as closeErr.
func (t *Transport) close(e error) {
	t.mutex.Lock()
	defer t.mutex.Unlock()
	if t.closeErr != nil {
		return
	}

	wrapped := &errTransportClosed{err: e}

	var wg sync.WaitGroup
	t.connMx.Lock()
	wg.Add(len(t.handlers))
	for _, h := range t.handlers {
		go func(h packetHandler) {
			h.destroy(wrapped)
			wg.Done()
		}(h)
	}
	t.connMx.Unlock()
	wg.Wait()

	if srv := t.server; srv != nil {
		srv.close(wrapped, false)
	}
	if tr := t.Tracer; tr != nil && tr.Close != nil {
		tr.Close()
	}
	t.closeErr = wrapped
}
// setBufferWarningOnce is not referenced in this part of the file. Presumably
// it ensures a socket-buffer warning is emitted at most once per process —
// confirm at its use site.
var setBufferWarningOnce sync.Once
// listen is the transport's read loop; init runs it in its own goroutine.
// It reads packets from conn and dispatches them with handlePacket until:
//   - a temporary error is seen after the transport was closed, or
//   - a non-temporary error not matched by isRecvMsgSizeErr occurs; in that
//     case the transport is closed with that error.
func (t *Transport) listen(conn rawConn) {
	for {
		p, err := conn.ReadPacket()
		// net.Error.Temporary is deprecated since Go 1.18. It is still used here
		// to keep reading after transient read errors.
		if nerr, ok := err.(net.Error); ok && nerr.Temporary() {
			t.mutex.Lock()
			closed := t.closeErr != nil
			t.mutex.Unlock()
			if closed {
				// Presumably reached via the past read deadline set by Close or
				// maybeStopListening.
				return
			}
			// %s, not %w: %w is only understood by fmt.Errorf. In a Printf-style
			// formatter such as Debugf it renders as "%!w(...)".
			t.logger.Debugf("Temporary error reading from conn: %s", err)
			continue
		}
		if err != nil {
			if isRecvMsgSizeErr(err) {
				continue
			}
			t.close(err)
			return
		}
		t.handlePacket(p)
	}
}
// maybeStopListening interrupts the read loop of a single-use transport that
// has been closed, by setting a read deadline in the past (the same mechanism
// Close uses). Callers hold t.mutex, which guards closeErr.
func (t *Transport) maybeStopListening() {
	if !t.isSingleUse || t.closeErr == nil {
		return
	}
	t.conn.SetReadDeadline(time.Now())
}
// handlePacket dispatches a packet read by listen. Checks run in this order:
//  1. empty packets are ignored;
//  2. non-QUIC packets go to handleNonQUICPacket;
//  3. packets with a known connection ID go to that handler;
//  4. stateless resets for our connections are processed;
//  5. unknown short-header packets may trigger a stateless reset;
//  6. unknown long-header packets go to the server, if there is one.
//
// Dropped packets are reported to the Tracer and their buffers released.
func (t *Transport) handlePacket(p receivedPacket) {
	if len(p.data) == 0 {
		return
	}
	firstByte := p.data[0]
	if !(wire.IsPotentialQUICPacket(firstByte) || wire.IsLongHeaderPacket(firstByte)) {
		t.handleNonQUICPacket(p)
		return
	}

	connID, err := wire.ParseConnectionID(p.data, t.connIDLen)
	if err != nil {
		t.logger.Debugf("error parsing connection ID on packet from %s: %s", p.remoteAddr, err)
		if tr := t.Tracer; tr != nil && tr.DroppedPacket != nil {
			tr.DroppedPacket(p.remoteAddr, logging.PacketTypeNotDetermined, p.Size(), logging.PacketDropHeaderParseError)
		}
		p.buffer.MaybeRelease()
		return
	}

	if handler, found := (*packetHandlerMap)(t).Get(connID); found {
		handler.handlePacket(p)
		return
	}

	if t.maybeHandleStatelessReset(p.data) {
		return
	}

	if !wire.IsLongHeaderPacket(firstByte) {
		if queued := t.maybeSendStatelessReset(p); !queued {
			if tr := t.Tracer; tr != nil && tr.DroppedPacket != nil {
				tr.DroppedPacket(p.remoteAddr, logging.PacketTypeNotDetermined, p.Size(), logging.PacketDropUnknownConnectionID)
			}
			p.buffer.Release()
		}
		return
	}

	t.mutex.Lock()
	defer t.mutex.Unlock()
	if t.server != nil {
		t.server.handlePacket(p)
		return
	}
	t.logger.Debugf("received a packet with an unexpected connection ID %s", connID)
	if tr := t.Tracer; tr != nil && tr.DroppedPacket != nil {
		tr.DroppedPacket(p.remoteAddr, logging.PacketTypeNotDetermined, p.Size(), logging.PacketDropUnknownConnectionID)
	}
	p.buffer.MaybeRelease()
}
// maybeSendStatelessReset queues p for a stateless reset reply. Nothing is
// queued if:
//   - no StatelessResetKey is configured;
//   - p is not longer than protocol.MinStatelessResetSize (the size of the
//     reset that sendStatelessReset would send, so a reset is always smaller
//     than the packet that triggered it);
//   - the queue is full.
//
// It reports whether p was queued; if so, sendStatelessReset takes ownership
// of p's buffer.
func (t *Transport) maybeSendStatelessReset(p receivedPacket) (statelessResetQueued bool) {
	if t.StatelessResetKey == nil || len(p.data) <= protocol.MinStatelessResetSize {
		return false
	}
	select {
	case t.statelessResetQueue <- p:
		statelessResetQueued = true
	default:
	}
	return statelessResetQueued
}
// sendStatelessReset replies to p with a stateless reset of
// protocol.MinStatelessResetSize bytes and always releases p's buffer.
//
// The reset consists of random bytes, with the first byte's top bit cleared
// and 0x40 bit set, followed by the reset token derived from p's connection ID.
func (t *Transport) sendStatelessReset(p receivedPacket) {
	defer p.buffer.Release()

	connID, err := wire.ParseConnectionID(p.data, t.connIDLen)
	if err != nil {
		t.logger.Errorf("error parsing connection ID on packet from %s: %s", p.remoteAddr, err)
		return
	}
	token := t.statelessResetter.GetStatelessResetToken(connID)
	t.logger.Debugf("Sending stateless reset to %s (connection ID: %s). Token: %#x", p.remoteAddr, connID, token)

	reset := make([]byte, protocol.MinStatelessResetSize)
	tokenStart := len(reset) - len(token)
	// The error is ignored; since Go 1.24 crypto/rand.Read never returns one.
	_, _ = rand.Read(reset[:tokenStart])
	reset[0] = (reset[0] & 0x7f) | 0x40
	copy(reset[tokenStart:], token[:])

	if _, err := t.conn.WritePacket(reset, p.remoteAddr, p.info.OOB(), 0, protocol.ECNUnsupported); err != nil {
		t.logger.Debugf("Error sending Stateless Reset to %s: %s", p.remoteAddr, err)
	}
}
// maybeHandleStatelessReset reports whether data is a stateless reset for one
// of our connections. If it is, that connection is destroyed asynchronously
// with a StatelessResetError.
//
// Only short-header packets of at least 17 bytes (one leading byte plus the
// 16-byte token in the last 16 bytes) are considered. data must be non-empty.
func (t *Transport) maybeHandleStatelessReset(data []byte) bool {
	if wire.IsLongHeaderPacket(data[0]) || len(data) < 17 {
		return false
	}
	token := protocol.StatelessResetToken(data[len(data)-16:])

	t.connMx.Lock()
	conn, known := t.resetTokens[token]
	t.connMx.Unlock()
	if !known {
		return false
	}

	t.logger.Debugf("Received a stateless reset with token %#x. Closing connection.", token)
	go conn.destroy(&StatelessResetError{})
	return true
}
// handleNonQUICPacket hands a non-QUIC packet to ReadNonQUICPacket callers.
// Packets arriving before the first ReadNonQUICPacket call are discarded.
// Packets arriving while the queue is full are dropped (reported to the Tracer
// as DOS prevention), so the read loop never blocks.
func (t *Transport) handleNonQUICPacket(p receivedPacket) {
	// ReadNonQUICPacket assigns nonQUICPackets before setting this flag, so a
	// true flag implies the channel exists.
	if !t.readingNonQUICPackets.Load() {
		return
	}
	select {
	case t.nonQUICPackets <- p:
		return
	default:
	}
	if tr := t.Tracer; tr != nil && tr.DroppedPacket != nil {
		tr.DroppedPacket(p.remoteAddr, logging.PacketTypeNotDetermined, p.Size(), logging.PacketDropDOSPrevention)
	}
}
// maxQueuedNonQUICPackets is the capacity of the non-QUIC packet queue that
// ReadNonQUICPacket creates; packets beyond it are dropped.
const maxQueuedNonQUICPackets = 32
// ReadNonQUICPacket returns the next packet received on the transport's
// connection that is not a QUIC packet. The packet is copied into b and
// silently truncated if b is too short.
//
// The first call enables queueing of up to maxQueuedNonQUICPackets packets;
// packets received before that are discarded. It returns ctx.Err() when ctx
// is done, and an error once the transport's read loop has stopped.
//
// The queue is created under t.mutex. Previously, concurrent first callers
// could both observe the flag as false and race on assigning
// t.nonQUICPackets, a data race that could orphan one caller's channel.
// The flag is still set only after the channel is assigned, so
// handleNonQUICPacket never sees a true flag with a nil channel.
func (t *Transport) ReadNonQUICPacket(ctx context.Context, b []byte) (int, net.Addr, error) {
	if err := t.init(false); err != nil {
		return 0, nil, err
	}

	t.mutex.Lock()
	if !t.readingNonQUICPackets.Load() {
		t.nonQUICPackets = make(chan receivedPacket, maxQueuedNonQUICPackets)
		t.readingNonQUICPackets.Store(true)
	}
	queue := t.nonQUICPackets
	t.mutex.Unlock()

	select {
	case <-ctx.Done():
		return 0, nil, ctx.Err()
	case p := <-queue:
		n := copy(b, p.data)
		return n, p.remoteAddr, nil
	case <-t.listening:
		return 0, nil, errors.New("closed")
	}
}
// setTLSConfigServerName fills in tlsConf.ServerName if it is empty:
//   - with no host and a *net.UDPAddr, the address's IP is used;
//   - otherwise host is used with any ":port" suffix stripped (host is used
//     verbatim if it cannot be split into host and port).
//
// An already-set ServerName is never overwritten.
func setTLSConfigServerName(tlsConf *tls.Config, addr net.Addr, host string) {
	if tlsConf.ServerName != "" {
		return
	}
	if udpAddr, isUDP := addr.(*net.UDPAddr); host == "" && isUDP {
		tlsConf.ServerName = udpAddr.IP.String()
		return
	}
	name := host
	if h, _, err := net.SplitHostPort(host); err == nil {
		name = h
	}
	tlsConf.ServerName = name
}
// packetHandlerMap is a view of Transport exposing the connection-ID
// bookkeeping methods; it is what newServer and newClientConnection receive.
type packetHandlerMap Transport

// Compile-time check that packetHandlerMap implements connRunner.
var _ connRunner = (*packetHandlerMap)(nil)
// Add registers handler under connection ID id. It returns false, leaving the
// map unchanged, if id is already registered.
func (h *packetHandlerMap) Add(id protocol.ConnectionID, handler packetHandler) bool {
	h.connMx.Lock()
	defer h.connMx.Unlock()

	if _, exists := h.handlers[id]; exists {
		h.logger.Debugf("Not adding connection ID %s, as it already exists.", id)
		return false
	}
	h.handlers[id] = handler
	h.logger.Debugf("Adding connection ID %s.", id)
	return true
}
// Get returns the handler registered for connID, and whether one exists.
func (h *packetHandlerMap) Get(connID protocol.ConnectionID) (handler packetHandler, ok bool) {
	h.connMx.Lock()
	handler, ok = h.handlers[connID]
	h.connMx.Unlock()
	return handler, ok
}
// AddResetToken associates a stateless reset token with handler. A received
// stateless reset carrying this token destroys the handler
// (see maybeHandleStatelessReset).
func (h *packetHandlerMap) AddResetToken(token protocol.StatelessResetToken, handler packetHandler) {
	h.connMx.Lock()
	defer h.connMx.Unlock()
	h.resetTokens[token] = handler
}
// RemoveResetToken forgets token; removing an unknown token is a no-op.
func (h *packetHandlerMap) RemoveResetToken(token protocol.StatelessResetToken) {
	h.connMx.Lock()
	defer h.connMx.Unlock()
	delete(h.resetTokens, token)
}
// AddWithConnID registers handler under both clientDestConnID and newConnID.
// It fails without changes if clientDestConnID is already registered;
// newConnID is not checked and overwrites any existing entry.
func (h *packetHandlerMap) AddWithConnID(clientDestConnID, newConnID protocol.ConnectionID, handler packetHandler) bool {
	h.connMx.Lock()
	defer h.connMx.Unlock()

	if _, taken := h.handlers[clientDestConnID]; taken {
		h.logger.Debugf("Not adding connection ID %s for a new connection, as it already exists.", clientDestConnID)
		return false
	}
	for _, id := range [...]protocol.ConnectionID{clientDestConnID, newConnID} {
		h.handlers[id] = handler
	}
	h.logger.Debugf("Adding connection IDs %s and %s for a new connection.", clientDestConnID, newConnID)
	return true
}
// Remove unregisters the handler for id; removing an unknown id is a no-op.
func (h *packetHandlerMap) Remove(id protocol.ConnectionID) {
	func() {
		h.connMx.Lock()
		defer h.connMx.Unlock()
		delete(h.handlers, id)
	}()
	h.logger.Debugf("Removing connection ID %s.", id)
}
// ReplaceWithClosed points every id in ids at a single placeholder for a closed
// connection, then removes those entries after expiry.
//
// The placeholder depends on connClosePacket:
//   - non-nil: newClosedLocalConn, given a callback that queues
//     connClosePacket on closeQueue for runSendQueue to write. The callback
//     receives the destination address and packet info; presumably it is
//     invoked when packets keep arriving for the closed connection. Queueing
//     is non-blocking, and the close packet is dropped if the queue is full.
//   - nil: newClosedRemoteConn.
//
// After removal, if no handlers remain, maybeStopListening is consulted.
// Transport.mutex is acquired before connMx there, the same order used by
// closeServer, close and doDial. The previous code took mutex while still
// holding connMx, an inverted order that could deadlock against closeServer.
func (h *packetHandlerMap) ReplaceWithClosed(ids []protocol.ConnectionID, connClosePacket []byte, expiry time.Duration) {
	var handler packetHandler
	if connClosePacket != nil {
		handler = newClosedLocalConn(
			func(addr net.Addr, info packetInfo) {
				select {
				case h.closeQueue <- closePacket{payload: connClosePacket, addr: addr, info: info}:
				default:
				}
			},
			h.logger,
		)
	} else {
		handler = newClosedRemoteConn()
	}

	h.connMx.Lock()
	for _, id := range ids {
		h.handlers[id] = handler
	}
	h.connMx.Unlock()
	h.logger.Debugf("Replacing connection for connection IDs %s with a closed connection.", ids)

	time.AfterFunc(expiry, func() {
		h.connMx.Lock()
		for _, id := range ids {
			delete(h.handlers, id)
		}
		empty := len(h.handlers) == 0
		h.connMx.Unlock()

		if empty {
			t := (*Transport)(h)
			t.mutex.Lock()
			t.connMx.Lock()
			// Re-check: a handler may have been added while no lock was held.
			if len(t.handlers) == 0 {
				t.maybeStopListening()
			}
			t.connMx.Unlock()
			t.mutex.Unlock()
		}
		h.logger.Debugf("Removing connection IDs %s for a closed connection after it has been retired.", ids)
	})
}
The pages are generated with Golds v0.8.2 (GOOS=linux GOARCH=amd64).
Golds is a Go 101 project developed by Tapir Liu.
PRs and bug reports are welcome and can be submitted to the issue list.
Please follow @zigo_101 (reachable from the left QR code) to get the latest news about Golds.