package quic

import (
	
	
	
	
	
	
	
	
	

	
	
	
	
)

// ErrTransportClosed is returned by the [Transport]'s Listen or Dial method after it was closed.
var ErrTransportClosed = &errTransportClosed{}

// errTransportClosed is the concrete type behind ErrTransportClosed.
// err optionally carries the reason the transport was closed; it is nil when
// the Transport was closed explicitly (Close passes a nil reason).
type errTransportClosed struct {
	err error
}

// Unwrap exposes net.ErrClosed, so that errors.Is(err, net.ErrClosed) holds,
// followed by the close reason if there is one.
// A nil close reason is left out: code that walks the returned slice should
// never have to cope with nil entries.
func (e *errTransportClosed) Unwrap() []error {
	if e.err == nil {
		return []error{net.ErrClosed}
	}
	return []error{net.ErrClosed, e.err}
}

// Error describes the closed transport, including the close reason if known.
func (e *errTransportClosed) Error() string {
	if e.err == nil {
		return "quic: transport closed"
	}
	return fmt.Sprintf("quic: transport closed: %s", e.err)
}

// Is reports whether target is an *errTransportClosed. As a result, every
// transport-closed error matches ErrTransportClosed, whatever its close reason.
func (e *errTransportClosed) Is(target error) bool {
	_, ok := target.(*errTransportClosed)
	return ok
}

// errListenerAlreadySet is returned by createServer when the Transport already
// has a listener; a new one can only be created after the current one was closed.
var errListenerAlreadySet = errors.New("listener already set")

// closePacket is a copy of a packet carrying a CONNECTION_CLOSE frame that a
// locally closed connection wants (re)transmitted. Instances are queued on
// Transport.closeQueue by ReplaceWithClosed and written by runSendQueue.
type closePacket struct {
	// payload is the raw packet that is written to the connection.
	payload []byte
	// addr is the destination address.
	addr    net.Addr
	// info is passed to WritePacket as out-of-band data via info.OOB().
	info    packetInfo
}

// The Transport is the central point to manage incoming and outgoing QUIC connections.
// QUIC demultiplexes connections based on their QUIC Connection IDs, not based on the 4-tuple.
// This means that a single UDP socket can be used for listening for incoming connections, as well as
// for dialing an arbitrary number of outgoing connections.
// A Transport handles a single net.PacketConn, and offers a range of configuration options
// compared to the simple helper functions like [Listen] and [Dial] that this package provides.
type Transport struct {
	// A single net.PacketConn can only be handled by one Transport.
	// Bad things will happen if passed to multiple Transports.
	//
	// A number of optimizations will be enabled if the connection implements the OOBCapablePacketConn interface,
	// as a *net.UDPConn does.
	// 1. It enables the Don't Fragment (DF) bit on the IP header.
	//    This is required to run DPLPMTUD (Path MTU Discovery, RFC 8899).
	// 2. It enables reading of the ECN bits from the IP header.
	//    This allows the remote node to speed up its loss detection and recovery.
	// 3. It uses batched syscalls (recvmmsg) to more efficiently receive packets from the socket.
	// 4. It uses Generic Segmentation Offload (GSO) to efficiently send batches of packets (on Linux).
	//
	// After passing the connection to the Transport, it's invalid to call ReadFrom or WriteTo on the connection.
	Conn net.PacketConn

	// The length of the connection ID in bytes.
	// It can be any value between 1 and 20.
	// Due to the increased risk of collisions, it is not recommended to use connection IDs shorter than 4 bytes.
	// If unset, a 4 byte connection ID will be used.
	// It is ignored if a ConnectionIDGenerator is set.
	ConnectionIDLength int

	// Used for generating new connection IDs.
	// This allows the application to control the connection IDs used,
	// which allows routing / load balancing based on connection IDs.
	// All Connection IDs returned by the ConnectionIDGenerator MUST
	// have the same length.
	// If set, it takes precedence over ConnectionIDLength.
	ConnectionIDGenerator ConnectionIDGenerator

	// The StatelessResetKey is used to generate stateless reset tokens.
	// If no key is configured, sending of stateless resets is disabled.
	// It is highly recommended to configure a stateless reset key, as stateless resets
	// allow the peer to quickly recover from crashes and reboots of this node.
	// See section 10.3 of RFC 9000 for details.
	StatelessResetKey *StatelessResetKey

	// The TokenGeneratorKey is used to encrypt session resumption tokens.
	// If no key is configured, a random key will be generated.
	// If multiple servers are authoritative for the same domain, they should use the same key,
	// see section 8.1.3 of RFC 9000 for details.
	TokenGeneratorKey *TokenGeneratorKey

	// MaxTokenAge is the maximum age of the resumption token presented during the handshake.
	// These tokens allow skipping address validation when resuming a QUIC connection,
	// and are especially useful when using 0-RTT.
	// If not set, it defaults to 24 hours.
	// See section 8.1.3 of RFC 9000 for details.
	MaxTokenAge time.Duration

	// DisableVersionNegotiationPackets disables the sending of Version Negotiation packets.
	// This can be useful if version information is exchanged out-of-band.
	// It has no effect for clients.
	DisableVersionNegotiationPackets bool

	// VerifySourceAddress decides if a connection attempt originating from unvalidated source
	// addresses first needs to go through source address validation using QUIC's Retry mechanism,
	// as described in RFC 9000 section 8.1.2.
	// Note that the address passed to this callback is unvalidated, and might be spoofed in case
	// of an attack.
	// Validating the source address adds one additional network roundtrip to the handshake,
	// and should therefore only be used if a suspiciously high number of incoming connections is recorded.
	// For most use cases, wrapping the Allow function of a rate.Limiter will be a reasonable
	// implementation of this callback (negating its return value).
	VerifySourceAddress func(net.Addr) bool

	// ConnContext is called when the server accepts a new connection. To reject a connection return
	// a non-nil error.
	// The context is closed when the connection is closed, or when the handshake fails for any reason.
	// The context returned from the callback is used to derive every other context used during the
	// lifetime of the connection:
	// * the context passed to crypto/tls (and used on the tls.ClientHelloInfo)
	// * the context used in Config.Tracer
	// * the context returned from Connection.Context
	// * the context returned from SendStream.Context
	// It is not used for dialed connections.
	ConnContext func(context.Context, *ClientInfo) (context.Context, error)

	// A Tracer traces events that don't belong to a single QUIC connection.
	// Tracer.Close is called when the transport is closed.
	Tracer *logging.Tracer

	// connMx guards handlers (connection ID -> handler) and resetTokens
	// (stateless reset token -> handler).
	// When both locks are needed, mutex is acquired before connMx.
	connMx      sync.Mutex
	handlers    map[protocol.ConnectionID]packetHandler
	resetTokens map[protocol.StatelessResetToken]packetHandler

	// mutex guards server and closeErr.
	mutex    sync.Mutex
	initOnce sync.Once
	// initErr is the result of the one-time initialization in init.
	initErr  error

	// Set in init: the length of the connection IDs produced by connIDGenerator.
	// It is used to parse the connection ID of incoming short header packets.
	connIDLen int
	// Set in init.
	// If no ConnectionIDGenerator is set, this is set to a default.
	connIDGenerator   ConnectionIDGenerator
	statelessResetter *statelessResetter

	// server is the current listener, if any (see Listen / ListenEarly).
	server *baseServer

	// conn is Conn, wrapped in init unless it already is a rawConn.
	conn rawConn

	// closeQueue holds CONNECTION_CLOSE copies of locally closed connections;
	// statelessResetQueue holds packets to be answered with a stateless reset.
	// Both are drained by runSendQueue.
	closeQueue          chan closePacket
	statelessResetQueue chan receivedPacket

	// closeErr is set once the Transport has been closed; new listeners and
	// dials fail with it afterwards. createdConn means Close also closes Conn.
	listening   chan struct{} // is closed when listen returns
	closeErr    error
	createdConn bool
	isSingleUse bool // was created for a single server or client, i.e. by calling quic.Listen or quic.Dial

	// readingNonQUICPackets is set on the first call to ReadNonQUICPacket; until then
	// non-QUIC packets are dropped. nonQUICPackets is the queue feeding ReadNonQUICPacket.
	readingNonQUICPackets atomic.Bool
	nonQUICPackets        chan receivedPacket

	logger utils.Logger
}

// Listen starts listening for incoming QUIC connections.
// There can only be a single listener on any net.PacketConn.
// Listen may only be called again after the current listener was closed.
func ( *Transport) ( *tls.Config,  *Config) (*Listener, error) {
	,  := .createServer(, , false)
	if  != nil {
		return nil, 
	}
	return &Listener{baseServer: }, nil
}

// ListenEarly starts listening for incoming QUIC connections.
// There can only be a single listener on any net.PacketConn.
// ListenEarly may only be called again after the current listener was closed.
func ( *Transport) ( *tls.Config,  *Config) (*EarlyListener, error) {
	,  := .createServer(, , true)
	if  != nil {
		return nil, 
	}
	return &EarlyListener{baseServer: }, nil
}

func ( *Transport) ( *tls.Config,  *Config,  bool) (*baseServer, error) {
	if  == nil {
		return nil, errors.New("quic: tls.Config not set")
	}
	if  := validateConfig();  != nil {
		return nil, 
	}

	.mutex.Lock()
	defer .mutex.Unlock()

	if .closeErr != nil {
		return nil, .closeErr
	}
	if .server != nil {
		return nil, errListenerAlreadySet
	}
	 = populateConfig()
	if  := .init(false);  != nil {
		return nil, 
	}
	 := .MaxTokenAge
	if  == 0 {
		 = 24 * time.Hour
	}
	 := newServer(
		.conn,
		(*packetHandlerMap)(),
		.connIDGenerator,
		.statelessResetter,
		.ConnContext,
		,
		,
		.Tracer,
		.closeServer,
		*.TokenGeneratorKey,
		,
		.VerifySourceAddress,
		.DisableVersionNegotiationPackets,
		,
	)
	.server = 
	return , nil
}

// Dial dials a new connection to a remote host (not using 0-RTT).
func ( *Transport) ( context.Context,  net.Addr,  *tls.Config,  *Config) (Connection, error) {
	return .dial(, , "", , , false)
}

// DialEarly dials a new connection, attempting to use 0-RTT if possible.
func ( *Transport) ( context.Context,  net.Addr,  *tls.Config,  *Config) (EarlyConnection, error) {
	return .dial(, , "", , , true)
}

func ( *Transport) ( context.Context,  net.Addr,  string,  *tls.Config,  *Config,  bool) (EarlyConnection, error) {
	if  := .init(.isSingleUse);  != nil {
		return nil, 
	}
	if  := validateConfig();  != nil {
		return nil, 
	}
	 = populateConfig()
	 = .Clone()
	setTLSConfigServerName(, , )
	return .doDial(,
		newSendConn(.conn, , packetInfo{}, utils.DefaultLogger),
		,
		,
		0,
		false,
		,
		.Versions[0],
	)
}

func ( *Transport) (
	 context.Context,
	 sendConn,
	 *tls.Config,
	 *Config,
	 protocol.PacketNumber,
	 bool,
	 bool,
	 protocol.Version,
) (quicConn, error) {
	,  := .connIDGenerator.GenerateConnectionID()
	if  != nil {
		return nil, 
	}
	,  := generateConnectionIDForInitial()
	if  != nil {
		return nil, 
	}

	 := nextConnTracingID()
	 = context.WithValue(, ConnectionTracingKey, )

	.mutex.Lock()
	if .closeErr != nil {
		.mutex.Unlock()
		return nil, .closeErr
	}

	var  *logging.ConnectionTracer
	if .Tracer != nil {
		 = .Tracer(, protocol.PerspectiveClient, )
	}
	if  != nil && .StartedConnection != nil {
		.StartedConnection(.LocalAddr(), .RemoteAddr(), , )
	}

	 := utils.DefaultLogger.WithPrefix("client")
	.Infof("Starting new connection to %s (%s -> %s), source connection ID %s, destination connection ID %s, version %s", .ServerName, .LocalAddr(), .RemoteAddr(), , , )

	 := newClientConnection(
		context.WithoutCancel(),
		,
		(*packetHandlerMap)(),
		,
		,
		.connIDGenerator,
		.statelessResetter,
		,
		,
		,
		,
		,
		,
		,
		,
	)
	.connMx.Lock()
	.handlers[] = 
	.connMx.Unlock()
	.mutex.Unlock()

	// The error channel needs to be buffered, as the run loop will continue running
	// after doDial returns (if the handshake is successful).
	 := make(chan error, 1)
	 := make(chan errCloseForRecreating)
	go func() {
		 := .run()
		var  *errCloseForRecreating
		if errors.As(, &) {
			 <- *
			return
		}
		if .isSingleUse {
			.Close()
		}
		 <- 
	}()

	// Only set when we're using 0-RTT.
	// Otherwise, earlyConnChan will be nil. Receiving from a nil chan blocks forever.
	var  <-chan struct{}
	if  {
		 = .earlyConnReady()
	}

	select {
	case <-.Done():
		.destroy(nil)
		// wait until the Go routine that called Connection.run() returns
		select {
		case <-:
		case <-:
		}
		return nil, context.Cause()
	case  := <-:
		return .(,
			,
			,
			,
			.nextPacketNumber,
			true,
			,
			.nextVersion,
		)
	case  := <-:
		return nil, 
	case <-:
		// ready to send 0-RTT data
		return , nil
	case <-.HandshakeComplete():
		// handshake successfully completed
		return , nil
	}
}

func ( *Transport) ( bool) error {
	.initOnce.Do(func() {
		var  rawConn
		if ,  := .Conn.(rawConn);  {
			 = 
		} else {
			var  error
			,  = wrapConn(.Conn)
			if  != nil {
				.initErr = 
				return
			}
		}

		.logger = utils.DefaultLogger // TODO: make this configurable
		.conn = 
		.handlers = make(map[protocol.ConnectionID]packetHandler)
		.resetTokens = make(map[protocol.StatelessResetToken]packetHandler)
		.listening = make(chan struct{})

		.closeQueue = make(chan closePacket, 4)
		.statelessResetQueue = make(chan receivedPacket, 4)
		if .TokenGeneratorKey == nil {
			var  TokenGeneratorKey
			if ,  := rand.Read([:]);  != nil {
				.initErr = 
				return
			}
			.TokenGeneratorKey = &
		}

		if .ConnectionIDGenerator != nil {
			.connIDGenerator = .ConnectionIDGenerator
			.connIDLen = .ConnectionIDGenerator.ConnectionIDLen()
		} else {
			 := .ConnectionIDLength
			if .ConnectionIDLength == 0 && ! {
				 = protocol.DefaultConnectionIDLength
			}
			.connIDLen = 
			.connIDGenerator = &protocol.DefaultConnectionIDGenerator{ConnLen: .connIDLen}
		}
		.statelessResetter = newStatelessResetter(.StatelessResetKey)

		go func() {
			defer close(.listening)
			.listen()

			if .createdConn {
				.Close()
			}
		}()
		go .runSendQueue()
	})
	return .initErr
}

// WriteTo sends a packet on the underlying connection.
func ( *Transport) ( []byte,  net.Addr) (int, error) {
	if  := .init(false);  != nil {
		return 0, 
	}
	return .conn.WritePacket(, , nil, 0, protocol.ECNUnsupported)
}

func ( *Transport) () {
	for {
		select {
		case <-.listening:
			return
		case  := <-.closeQueue:
			.conn.WritePacket(.payload, .addr, .info.OOB(), 0, protocol.ECNUnsupported)
		case  := <-.statelessResetQueue:
			.sendStatelessReset()
		}
	}
}

// Close stops listening for UDP datagrams on the Transport.Conn.
// If any listener was started, it will be closed as well.
// It is invalid to start new listeners or connections after that.
func ( *Transport) () error {
	// avoid race condition if the transport is currently being initialized
	.init(false)

	.close(nil)
	if .createdConn {
		if  := .Conn.Close();  != nil {
			return 
		}
	} else if .conn != nil {
		.conn.SetReadDeadline(time.Now())
		defer func() { .conn.SetReadDeadline(time.Time{}) }()
	}
	if .listening != nil {
		<-.listening // wait until listening returns
	}
	return nil
}

func ( *Transport) () {
	.mutex.Lock()
	defer .mutex.Unlock()

	.server = nil
	if .isSingleUse {
		.closeErr = ErrServerClosed
	}

	.connMx.Lock()
	defer .connMx.Unlock()
	if len(.handlers) == 0 {
		.maybeStopListening()
	}
}

func ( *Transport) ( error) {
	.mutex.Lock()
	defer .mutex.Unlock()

	if .closeErr != nil {
		return
	}

	 = &errTransportClosed{err: }

	var  sync.WaitGroup
	.connMx.Lock()
	for ,  := range .handlers {
		.Add(1)
		go func( packetHandler) {
			.destroy()
			.Done()
		}()
	}
	.connMx.Unlock()
	.Wait()

	if .server != nil {
		.server.close(, false)
	}
	if .Tracer != nil && .Tracer.Close != nil {
		.Tracer.Close()
	}
	.closeErr = 
}

// only print warnings about the UDP receive buffer size once
var setBufferWarningOnce sync.Once

func ( *Transport) ( rawConn) {
	for {
		,  := .ReadPacket()
		//nolint:staticcheck // SA1019 ignore this!
		// TODO: This code is used to ignore wsa errors on Windows.
		// Since net.Error.Temporary is deprecated as of Go 1.18, we should find a better solution.
		// See https://github.com/quic-go/quic-go/issues/1737 for details.
		if ,  := .(net.Error);  && .Temporary() {
			.mutex.Lock()
			 := .closeErr != nil
			.mutex.Unlock()
			if  {
				return
			}
			.logger.Debugf("Temporary error reading from conn: %w", )
			continue
		}
		if  != nil {
			// Windows returns an error when receiving a UDP datagram that doesn't fit into the provided buffer.
			if isRecvMsgSizeErr() {
				continue
			}
			.close()
			return
		}
		.handlePacket()
	}
}

func ( *Transport) () {
	if .isSingleUse && .closeErr != nil {
		.conn.SetReadDeadline(time.Now())
	}
}

func ( *Transport) ( receivedPacket) {
	if len(.data) == 0 {
		return
	}
	if !wire.IsPotentialQUICPacket(.data[0]) && !wire.IsLongHeaderPacket(.data[0]) {
		.handleNonQUICPacket()
		return
	}
	,  := wire.ParseConnectionID(.data, .connIDLen)
	if  != nil {
		.logger.Debugf("error parsing connection ID on packet from %s: %s", .remoteAddr, )
		if .Tracer != nil && .Tracer.DroppedPacket != nil {
			.Tracer.DroppedPacket(.remoteAddr, logging.PacketTypeNotDetermined, .Size(), logging.PacketDropHeaderParseError)
		}
		.buffer.MaybeRelease()
		return
	}

	// If there's a connection associated with the connection ID, pass the packet there.
	if ,  := (*packetHandlerMap)().Get();  {
		.handlePacket()
		return
	}
	// RFC 9000 section 10.3.1 requires that the stateless reset detection logic is run for both
	// packets that cannot be associated with any connections, and for packets that can't be decrypted.
	// We deviate from the RFC and ignore the latter: If a packet's connection ID is associated with an
	// existing connection, it is dropped there if if it can't be decrypted.
	// Stateless resets use random connection IDs, and at reasonable connection ID lengths collisions are
	// exceedingly rare. In the unlikely event that a stateless reset is misrouted to an existing connection,
	// it is to be expected that the next stateless reset will be correctly detected.
	if  := .maybeHandleStatelessReset(.data);  {
		return
	}
	if !wire.IsLongHeaderPacket(.data[0]) {
		if  := .maybeSendStatelessReset(); ! {
			if .Tracer != nil && .Tracer.DroppedPacket != nil {
				.Tracer.DroppedPacket(.remoteAddr, logging.PacketTypeNotDetermined, .Size(), logging.PacketDropUnknownConnectionID)
			}
			.buffer.Release()
		}
		return
	}

	.mutex.Lock()
	defer .mutex.Unlock()
	if .server == nil { // no server set
		.logger.Debugf("received a packet with an unexpected connection ID %s", )
		if .Tracer != nil && .Tracer.DroppedPacket != nil {
			.Tracer.DroppedPacket(.remoteAddr, logging.PacketTypeNotDetermined, .Size(), logging.PacketDropUnknownConnectionID)
		}
		.buffer.MaybeRelease()
		return
	}
	.server.handlePacket()
}

func ( *Transport) ( receivedPacket) ( bool) {
	if .StatelessResetKey == nil {
		return false
	}

	// Don't send a stateless reset in response to very small packets.
	// This includes packets that could be stateless resets.
	if len(.data) <= protocol.MinStatelessResetSize {
		return false
	}

	select {
	case .statelessResetQueue <- :
		return true
	default:
		// it's fine to not send a stateless reset when we're busy
		return false
	}
}

func ( *Transport) ( receivedPacket) {
	defer .buffer.Release()

	,  := wire.ParseConnectionID(.data, .connIDLen)
	if  != nil {
		.logger.Errorf("error parsing connection ID on packet from %s: %s", .remoteAddr, )
		return
	}
	 := .statelessResetter.GetStatelessResetToken()
	.logger.Debugf("Sending stateless reset to %s (connection ID: %s). Token: %#x", .remoteAddr, , )
	 := make([]byte, protocol.MinStatelessResetSize-16, protocol.MinStatelessResetSize)
	rand.Read()
	[0] = ([0] & 0x7f) | 0x40
	 = append(, [:]...)
	if ,  := .conn.WritePacket(, .remoteAddr, .info.OOB(), 0, protocol.ECNUnsupported);  != nil {
		.logger.Debugf("Error sending Stateless Reset to %s: %s", .remoteAddr, )
	}
}

func ( *Transport) ( []byte) bool {
	// stateless resets are always short header packets
	if wire.IsLongHeaderPacket([0]) {
		return false
	}
	if len() < 17 /* type byte + 16 bytes for the reset token */ {
		return false
	}

	 := protocol.StatelessResetToken([len()-16:])
	.connMx.Lock()
	,  := .resetTokens[]
	.connMx.Unlock()

	if  {
		.logger.Debugf("Received a stateless reset with token %#x. Closing connection.", )
		go .destroy(&StatelessResetError{})
		return true
	}
	return false
}

func ( *Transport) ( receivedPacket) {
	// Strictly speaking, this is racy,
	// but we only care about receiving packets at some point after ReadNonQUICPacket has been called.
	if !.readingNonQUICPackets.Load() {
		return
	}
	select {
	case .nonQUICPackets <- :
	default:
		if .Tracer != nil && .Tracer.DroppedPacket != nil {
			.Tracer.DroppedPacket(.remoteAddr, logging.PacketTypeNotDetermined, .Size(), logging.PacketDropDOSPrevention)
		}
	}
}

const maxQueuedNonQUICPackets = 32

// ReadNonQUICPacket reads non-QUIC packets received on the underlying connection.
// The detection logic is very simple: Any packet that has the first and second bit of the packet set to 0.
// Note that this is stricter than the detection logic defined in RFC 9443.
func ( *Transport) ( context.Context,  []byte) (int, net.Addr, error) {
	if  := .init(false);  != nil {
		return 0, nil, 
	}
	if !.readingNonQUICPackets.Load() {
		.nonQUICPackets = make(chan receivedPacket, maxQueuedNonQUICPackets)
		.readingNonQUICPackets.Store(true)
	}
	select {
	case <-.Done():
		return 0, nil, .Err()
	case  := <-.nonQUICPackets:
		 := copy(, .data)
		return , .remoteAddr, nil
	case <-.listening:
		return 0, nil, errors.New("closed")
	}
}

// setTLSConfigServerName fills in tlsConf.ServerName if it is empty.
// If host is empty and addr is a *net.UDPAddr, the name is addr's IP address.
// Otherwise it is host, with the port stripped if host has one.
// A ServerName that is already set is never overwritten.
func setTLSConfigServerName(tlsConf *tls.Config, addr net.Addr, host string) {
	// If no ServerName is set, infer the ServerName from the host we're connecting to.
	if tlsConf.ServerName != "" {
		return
	}
	if udpAddr, isUDP := addr.(*net.UDPAddr); host == "" && isUDP {
		tlsConf.ServerName = udpAddr.IP.String()
		return
	}
	name, _, err := net.SplitHostPort(host)
	if err != nil { // This happens if the host doesn't contain a port number.
		name = host
	}
	tlsConf.ServerName = name
}

type packetHandlerMap Transport

var _ connRunner = &packetHandlerMap{}

func ( *packetHandlerMap) ( protocol.ConnectionID,  packetHandler) bool /* was added */ {
	.connMx.Lock()
	defer .connMx.Unlock()

	if ,  := .handlers[];  {
		.logger.Debugf("Not adding connection ID %s, as it already exists.", )
		return false
	}
	.handlers[] = 
	.logger.Debugf("Adding connection ID %s.", )
	return true
}

func ( *packetHandlerMap) ( protocol.ConnectionID) (packetHandler, bool) {
	.connMx.Lock()
	defer .connMx.Unlock()
	,  := .handlers[]
	return , 
}

func ( *packetHandlerMap) ( protocol.StatelessResetToken,  packetHandler) {
	.connMx.Lock()
	.resetTokens[] = 
	.connMx.Unlock()
}

func ( *packetHandlerMap) ( protocol.StatelessResetToken) {
	.connMx.Lock()
	delete(.resetTokens, )
	.connMx.Unlock()
}

func ( *packetHandlerMap) (,  protocol.ConnectionID,  packetHandler) bool {
	.connMx.Lock()
	defer .connMx.Unlock()

	if ,  := .handlers[];  {
		.logger.Debugf("Not adding connection ID %s for a new connection, as it already exists.", )
		return false
	}
	.handlers[] = 
	.handlers[] = 
	.logger.Debugf("Adding connection IDs %s and %s for a new connection.", , )
	return true
}

func ( *packetHandlerMap) ( protocol.ConnectionID) {
	.connMx.Lock()
	delete(.handlers, )
	.connMx.Unlock()
	.logger.Debugf("Removing connection ID %s.", )
}

// ReplaceWithClosed is called when a connection is closed.
// Depending on which side closed the connection, we need to:
// * remote close: absorb delayed packets
// * local close: retransmit the CONNECTION_CLOSE packet, in case it was lost
func ( *packetHandlerMap) ( []protocol.ConnectionID,  []byte,  time.Duration) {
	var  packetHandler
	if  != nil {
		 = newClosedLocalConn(
			func( net.Addr,  packetInfo) {
				select {
				case .closeQueue <- closePacket{payload: , addr: , info: }:
				default:
					// We're backlogged.
					// Just drop the packet, sending CONNECTION_CLOSE copies is best effort anyway.
				}
			},
			.logger,
		)
	} else {
		 = newClosedRemoteConn()
	}

	.connMx.Lock()
	for ,  := range  {
		.handlers[] = 
	}
	.connMx.Unlock()
	.logger.Debugf("Replacing connection for connection IDs %s with a closed connection.", )

	time.AfterFunc(, func() {
		.connMx.Lock()
		for ,  := range  {
			delete(.handlers, )
		}
		if len(.handlers) == 0 {
			 := (*Transport)()
			.mutex.Lock()
			.maybeStopListening()
			.mutex.Unlock()
		}
		.connMx.Unlock()
		.logger.Debugf("Removing connection IDs %s for a closed connection after it has been retired.", )
	})
}