package quicreuse

import (
	
	
	
	
	
	
	

	
	
)

// RefCountedQUICTransport is the set of operations callers need from a QUIC
// transport whose usage is tracked with a reference count.
type RefCountedQUICTransport interface {
	// Dial opens a QUIC connection to addr using the given TLS and QUIC
	// configuration.
	Dial(ctx context.Context, addr net.Addr, tlsConf *tls.Config, conf *quic.Config) (quic.Connection, error)
	// Listen starts a QUIC listener using the given TLS and QUIC
	// configuration.
	Listen(tlsConf *tls.Config, conf *quic.Config) (QUICListener, error)

	// LocalAddr reports the local address of the transport.
	LocalAddr() net.Addr
	// WriteTo sends a packet directly, going around QUIC. This is useful for
	// hole punching.
	WriteTo([]byte, net.Addr) (int, error)

	// IncreaseCount and DecreaseCount keep track of how many users currently
	// reference the transport.
	IncreaseCount()
	DecreaseCount()

	// Close shuts the transport down.
	Close() error
}

// singleOwnerTransport bundles a QUICTransport together with a
// net.PacketConn. The file asserts that *singleOwnerTransport implements both
// QUICTransport and RefCountedQUICTransport.
type singleOwnerTransport struct {
	// Transport is the wrapped QUIC transport that dial, listen, read and
	// write calls are forwarded to.
	Transport QUICTransport

	// Used to write packets directly around QUIC.
	// NOTE(review): in the methods visible in this file, packetConn is only
	// used for LocalAddr and for closing; writes are forwarded to Transport.
	packetConn net.PacketConn
}

// Compile-time checks that *singleOwnerTransport satisfies both transport
// interfaces. Both checks use the typed-nil pointer form so that they are
// written consistently and neither one constructs a value.
var _ QUICTransport = (*singleOwnerTransport)(nil)
var _ RefCountedQUICTransport = (*singleOwnerTransport)(nil)

// No-op: the body is empty.
// NOTE(review): the method and receiver names are missing in this copy of the
// file. Presumably this is one of the two reference-count methods required by
// RefCountedQUICTransport (IncreaseCount/DecreaseCount) — confirm which.
func ( *singleOwnerTransport) () {}
// Closes the wrapped Transport and discards whatever Close returns.
// NOTE(review): the method and receiver names are missing in this copy of the
// file. Presumably this is the other reference-count method required by
// RefCountedQUICTransport (likely DecreaseCount, i.e. the single owner going
// away closes the transport) — confirm.
func ( *singleOwnerTransport) () { .Transport.Close() }
// Returns the local address reported by the underlying packetConn.
// NOTE(review): identifiers are missing in this copy; the signature matches
// LocalAddr() net.Addr from RefCountedQUICTransport — confirm.
func ( *singleOwnerTransport) () net.Addr {
	return .packetConn.LocalAddr()
}

// Forwards a dial request to the wrapped Transport's Dial and returns its
// result unchanged.
// NOTE(review): identifiers are missing in this copy; the parameter and result
// types match Dial from RefCountedQUICTransport, and the arguments are
// presumably passed through in order — confirm.
func ( *singleOwnerTransport) ( context.Context,  net.Addr,  *tls.Config,  *quic.Config) (quic.Connection, error) {
	return .Transport.Dial(, , , )
}

// Forwards to the wrapped Transport's ReadNonQUICPacket and returns its
// (int, net.Addr, error) result unchanged.
// NOTE(review): identifiers are missing in this copy; presumably this method
// is itself named ReadNonQUICPacket and passes the context and buffer through
// in order — confirm.
func ( *singleOwnerTransport) ( context.Context,  []byte) (int, net.Addr, error) {
	return .Transport.ReadNonQUICPacket(, )
}

// Closes the wrapped Transport first and then the packetConn. Only the result
// of closing the packetConn is returned; whatever Transport.Close returns is
// dropped.
// NOTE(review): identifiers are missing in this copy; the signature matches
// Close() error from RefCountedQUICTransport — confirm.
func ( *singleOwnerTransport) () error {
	// TODO(when we drop support for go 1.19) use errors.Join
	// NOTE(review): errors.Join is already used by the refcountedTransport
	// close method in this file, so this TODO may be actionable now.
	.Transport.Close()
	return .packetConn.Close()
}

// Forwards a raw packet write to the wrapped Transport's WriteTo and returns
// its result unchanged.
// NOTE(review): identifiers are missing in this copy; the signature matches
// WriteTo from RefCountedQUICTransport — confirm.
func ( *singleOwnerTransport) ( []byte,  net.Addr) (int, error) {
	return .Transport.WriteTo(, )
}

// Forwards a listen request to the wrapped Transport's Listen and returns its
// result unchanged.
// NOTE(review): identifiers are missing in this copy; the signature matches
// Listen from RefCountedQUICTransport — confirm.
func ( *singleOwnerTransport) ( *tls.Config,  *quic.Config) (QUICListener, error) {
	return .Transport.Listen(, )
}

// The two values below are constant in practice. They are declared as
// variables to simplify testing.

// garbageCollectInterval is the period of the ticker that drives the
// garbage-collection loop.
var garbageCollectInterval = 30 * time.Second

// maxUnusedDuration is how long a transport must have been unused before the
// garbage-collection check reports it as collectable.
var maxUnusedDuration = 10 * time.Second

// refcountedTransport wraps a QUICTransport with a reference count, the time
// at which it last became unused, and a set of caller-supplied associations.
type refcountedTransport struct {
	// The embedded QUICTransport performs the actual QUIC operations.
	QUICTransport

	// Used to write packets directly around QUIC.
	// NOTE(review): in the methods visible in this file, packetConn is only
	// used for LocalAddr and for closing; writes go through QUICTransport.
	packetConn net.PacketConn

	// mutex is held by the methods in this file that touch refCount,
	// unusedSince and assocations.
	mutex sync.Mutex
	// refCount is the number of current users of the transport.
	refCount int
	// unusedSince is the time refCount last dropped to zero. It is reset to
	// the zero time whenever the count is increased.
	unusedSince time.Time

	// Only set for transports we are borrowing.
	// If set, we will _never_ close the underlying transport. We only close this
	// channel to signal to the owner that we are done with it.
	borrowDoneSignal chan struct{}

	// assocations is the set of values this transport has been tagged with.
	// It is created lazily on the first association. (The field name is
	// misspelled; it is kept as is because the methods refer to it by this
	// name.)
	assocations map[any]struct{}
}

// connContextFunc is an alias for the callback type stored in reuse.connContext
// and handed to newQUICTransport when a new transport is created. It takes a
// context and a *quic.ClientInfo and returns a context or an error.
type connContextFunc = func(context.Context, *quic.ClientInfo) (context.Context, error)

// associate an arbitrary value with this transport.
// This lets us "tag" the refcountedTransport when listening so we can use it
// later for dialing. Necessary for holepunching and learning about our own
// observed listening address.
//
// The association set is created lazily on first use and is modified while
// holding the transport's mutex.
// NOTE(review): identifiers are missing in this copy; the early return on nil
// presumably checks the value being associated (mirroring the documented nil
// handling of hasAssociation) — confirm.
func ( *refcountedTransport) ( any) {
	if  == nil {
		return
	}
	.mutex.Lock()
	defer .mutex.Unlock()
	if .assocations == nil {
		.assocations = make(map[any]struct{})
	}
	.assocations[] = struct{}{}
}

// hasAssociation returns true if the transport has the given association.
// If it is a nil association, it will always return true.
//
// The lookup is performed while holding the transport's mutex. Reading from
// the association map is safe even if it has never been created, because
// indexing a nil map yields the zero value.
func ( *refcountedTransport) ( any) bool {
	if  == nil {
		return true
	}
	.mutex.Lock()
	defer .mutex.Unlock()
	,  := .assocations[]
	return 
}

// Increments the reference count and clears unusedSince (sets it to the zero
// time), all while holding the mutex. Clearing unusedSince means the
// garbage-collection check no longer treats the transport as unused.
// NOTE(review): identifiers are missing in this copy; presumably this is
// IncreaseCount from RefCountedQUICTransport — confirm.
func ( *refcountedTransport) () {
	.mutex.Lock()
	.refCount++
	.unusedSince = time.Time{}
	.mutex.Unlock()
}

// Releases the transport.
//
// For a borrowed transport (borrowDoneSignal is set) the underlying transport
// is left open: only the signal channel is closed and nil is returned.
// Otherwise both the embedded QUICTransport and the packetConn are closed and
// their errors are combined with errors.Join.
// NOTE(review): identifiers are missing in this copy; the signature matches
// Close() error. Also note that invoking this twice on a borrowed transport
// would close borrowDoneSignal twice, which panics in Go — confirm callers
// only close once.
func ( *refcountedTransport) () error {
	if .borrowDoneSignal != nil {
		close(.borrowDoneSignal)
		return nil
	}

	return errors.Join(.QUICTransport.Close(), .packetConn.Close())
}

// Forwards a raw packet write to the embedded QUICTransport's WriteTo and
// returns its result unchanged.
// NOTE(review): identifiers are missing in this copy; the signature matches
// WriteTo from RefCountedQUICTransport — confirm.
func ( *refcountedTransport) ( []byte,  net.Addr) (int, error) {
	return .QUICTransport.WriteTo(, )
}

// Returns the local address reported by the underlying packetConn.
// NOTE(review): identifiers are missing in this copy; the signature matches
// LocalAddr() net.Addr from RefCountedQUICTransport — confirm.
func ( *refcountedTransport) () net.Addr {
	return .packetConn.LocalAddr()
}

// Forwards a listen request to the embedded QUICTransport's Listen and returns
// its result unchanged.
// NOTE(review): identifiers are missing in this copy; the signature matches
// Listen from RefCountedQUICTransport — confirm.
func ( *refcountedTransport) ( *tls.Config,  *quic.Config) (QUICListener, error) {
	return .QUICTransport.Listen(, )
}

// Decrements the reference count while holding the mutex. When the count
// reaches exactly zero, the current time is recorded in unusedSince so the
// garbage-collection check can later measure how long the transport has been
// idle.
// NOTE(review): identifiers are missing in this copy; presumably this is
// DecreaseCount from RefCountedQUICTransport — confirm. Nothing here guards
// against the count going negative.
func ( *refcountedTransport) () {
	.mutex.Lock()
	.refCount--
	if .refCount == 0 {
		.unusedSince = time.Now()
	}
	.mutex.Unlock()
}

// Reports whether the transport is eligible for garbage collection at the
// given time: unusedSince must be set (non-zero) and unusedSince plus
// maxUnusedDuration must lie strictly before the supplied time. The check is
// made while holding the mutex.
// NOTE(review): identifiers are missing in this copy; the gc loop in this file
// calls a method named ShouldGarbageCollect, which is presumably this one —
// confirm.
func ( *refcountedTransport) ( time.Time) bool {
	.mutex.Lock()
	defer .mutex.Unlock()
	return !.unusedSince.IsZero() && .unusedSince.Add(maxUnusedDuration).Before()
}

// reuse keeps track of the QUIC transports that have been created so that
// they can be shared between dialing and listening, and garbage-collects the
// ones that have been unused for too long.
type reuse struct {
	// mutex is held by the methods in this file that read or modify routes
	// and the transport maps below.
	mutex sync.Mutex

	// closeChan is closed to ask the garbage-collection goroutine to stop.
	closeChan chan struct{}
	// gcStopChan is closed by the garbage-collection goroutine once it has
	// finished shutting down.
	gcStopChan chan struct{}

	// listenUDP is called to open the packet conn for a newly created
	// transport.
	listenUDP listenUDP

	// sourceIPSelectorFn builds a fresh SourceIPSelector; it is called to
	// refresh routes when the set of unicast listeners changes.
	sourceIPSelectorFn func() (SourceIPSelector, error)

	// routes is the current source IP selector. It may be nil, in which case
	// dialing skips the source address lookup.
	routes SourceIPSelector
	// unicast contains transports bound to a specific IP, keyed by IP and then
	// by port.
	unicast map[string] /* IP.String() */ map[int] /* port */ *refcountedTransport
	// globalListeners contains transports that are listening on 0.0.0.0 / ::
	globalListeners map[int]*refcountedTransport
	// globalDialers contains transports that we've dialed out from. These transports are listening on 0.0.0.0 / ::
	// On Dial, transports are reused from this map if no transport is available in the globalListeners
	// On Listen, transports are reused from this map if the requested port is 0, and then moved to globalListeners
	globalDialers map[int]*refcountedTransport

	// The remaining fields are passed to newQUICTransport whenever a new
	// transport is created.
	statelessResetKey   *quic.StatelessResetKey
	tokenGeneratorKey   *quic.TokenGeneratorKey
	connContext         connContextFunc
	verifySourceAddress func(addr net.Addr) bool
}

// newReuse builds a reuse with empty transport maps and fresh close/stop
// channels, stores the supplied keys and callbacks, and starts the
// garbage-collection goroutine before returning.
// NOTE(review): parameter names are missing in this copy, so which argument
// feeds which field cannot be read off the code. The parameter types line up
// with the statelessResetKey, tokenGeneratorKey, listenUDP,
// sourceIPSelectorFn, connContext and verifySourceAddress fields — confirm.
func newReuse( *quic.StatelessResetKey,  *quic.TokenGeneratorKey,  listenUDP,  func() (SourceIPSelector, error),
	 connContextFunc,  func( net.Addr) bool) *reuse {
	 := &reuse{
		unicast:             make(map[string]map[int]*refcountedTransport),
		globalListeners:     make(map[int]*refcountedTransport),
		globalDialers:       make(map[int]*refcountedTransport),
		closeChan:           make(chan struct{}),
		gcStopChan:          make(chan struct{}),
		listenUDP:           ,
		sourceIPSelectorFn:  ,
		statelessResetKey:   ,
		tokenGeneratorKey:   ,
		connContext:         ,
		verifySourceAddress: ,
	}
	// The goroutine runs until closeChan is closed.
	go .gc()
	return 
}

// Background garbage-collection loop.
//
// On every tick of a garbageCollectInterval ticker it takes the mutex and
// walks globalListeners, globalDialers and unicast, closing and removing each
// transport whose ShouldGarbageCollect check passes for the current time.
// When closeChan is closed the loop returns; the deferred function then closes
// every remaining transport and finally closes gcStopChan to signal that
// shutdown is complete.
// NOTE(review): identifiers are missing in this copy; newReuse starts a method
// named gc as a goroutine, which is presumably this one — confirm.
func ( *reuse) () {
	defer func() {
		// Shutdown path: close everything that is still registered. The
		// results of Close are not checked.
		.mutex.Lock()
		for ,  := range .globalListeners {
			.Close()
		}
		for ,  := range .globalDialers {
			.Close()
		}
		for ,  := range .unicast {
			for ,  := range  {
				.Close()
			}
		}
		.mutex.Unlock()
		// Signal that the loop has fully stopped.
		close(.gcStopChan)
	}()
	 := time.NewTicker(garbageCollectInterval)
	defer .Stop()

	for {
		select {
		case <-.closeChan:
			return
		case <-.C:
			// Use a single timestamp for the whole sweep.
			 := time.Now()
			.mutex.Lock()
			// Deleting the current entry while ranging over a map is
			// permitted in Go.
			for ,  := range .globalListeners {
				if .ShouldGarbageCollect() {
					.Close()
					delete(.globalListeners, )
				}
			}
			for ,  := range .globalDialers {
				if .ShouldGarbageCollect() {
					.Close()
					delete(.globalDialers, )
				}
			}
			for ,  := range .unicast {
				for ,  := range  {
					if .ShouldGarbageCollect() {
						.Close()
						delete(, )
					}
				}
				// Drop the per-IP map once its last transport is gone.
				if len() == 0 {
					delete(.unicast, )
					// If we've dropped all transports with a unicast binding,
					// assume our routes may have changed.
					if len(.unicast) == 0 {
						.routes = nil
					} else {
						// Ignore the error, there's nothing we can do about
						// it.
						.routes, _ = .sourceIPSelectorFn()
					}
				}
			}
			.mutex.Unlock()
		}
	}
}

// Picks (or creates) a transport to dial from and takes a reference on it.
//
// The current route selector is read under the mutex, and the preferred-source
// lookup is then performed with the mutex released. The mutex is re-acquired
// for the actual selection, which is delegated to transportForDialLocked. On
// success the returned transport has had its count increased; on failure nil
// and the error are returned.
// NOTE(review): identifiers are missing in this copy; the parameters are
// presumably an association value, the network name and the remote address —
// confirm.
func ( *reuse) ( any,  string,  *net.UDPAddr) (*refcountedTransport, error) {
	var  *net.IP

	// Only bother looking up the source address if we actually _have_ non 0.0.0.0 listeners.
	// Otherwise, save some time.

	.mutex.Lock()
	 := .routes
	.mutex.Unlock()

	if  != nil {
		// The lookup result is only used when the lookup succeeded and
		// produced a specific (not unspecified) address.
		,  := .PreferredSourceIPForDestination()
		if  == nil && !.IsUnspecified() {
			 = &
		}
	}

	.mutex.Lock()
	defer .mutex.Unlock()

	,  := .transportForDialLocked(, , )
	if  != nil {
		return nil, 
	}
	.IncreaseCount()
	return , nil
}

// Selects a transport to dial from. The one visible caller invokes this with
// the reuse mutex already held, and this function does no locking of its own.
//
// Selection order:
//  1. If a source IP was supplied and unicast transports exist for it, one of
//     those, preferring a transport that has the given association.
//  2. A global listener, again preferring one with the association.
//  3. Any transport previously used for dialing.
//  4. Otherwise a new transport is created on a system-chosen port (port 0)
//     and registered in globalDialers.
//
// The reference count is not changed here; the caller does that.
// NOTE(review): identifiers are missing in this copy; the caller in this file
// refers to this method as transportForDialLocked — confirm.
func ( *reuse) ( any,  string,  *net.IP) (*refcountedTransport, error) {
	if  != nil {
		// We already have at least one suitable transport...
		if ,  := .unicast[.String()];  {
			// Prefer a transport that has the given association. We want to
			// reuse the transport the association used for listening.
			for ,  := range  {
				if .hasAssociation() {
					return , nil
				}
			}
			// We don't have a transport with the association, use any one
			for ,  := range  {
				return , nil
			}
		}
	}

	// Use a transport listening on 0.0.0.0 (or ::).
	// Again, prefer a transport that has the given association.
	for ,  := range .globalListeners {
		if .hasAssociation() {
			return , nil
		}
	}
	// We don't have a transport with the association, use any one
	for ,  := range .globalListeners {
		return , nil
	}

	// Use a transport we've previously dialed from
	for ,  := range .globalDialers {
		return , nil
	}

	// We don't have a transport that we can use for dialing.
	// Dial a new connection from a random port.
	// NOTE(review): the switch has no default case, so for any network other
	// than "udp4"/"udp6" the address declared here presumably stays nil when
	// it is handed to listenUDP — confirm that is intended.
	var  *net.UDPAddr
	switch  {
	case "udp4":
		 = &net.UDPAddr{IP: net.IPv4zero, Port: 0}
	case "udp6":
		 = &net.UDPAddr{IP: net.IPv6zero, Port: 0}
	}
	,  := .listenUDP(, )
	if  != nil {
		return nil, 
	}
	 := .newTransport()
	// Register under the port actually assigned to the new socket.
	.globalDialers[.LocalAddr().(*net.UDPAddr).Port] = 
	return , nil
}

// Registers an existing transport as a global dialer under the port of the
// given address, while holding the mutex.
//
// Returns an error if the address has a specific (not unspecified) IP, or if
// globalDialers already has an entry for that port. Only globalDialers is
// checked for a port conflict.
// NOTE(review): identifiers are missing in this copy; presumably the
// parameters are the transport to add and its local address — confirm.
func ( *reuse) ( *refcountedTransport,  *net.UDPAddr) error {
	.mutex.Lock()
	defer .mutex.Unlock()

	if !.IP.IsUnspecified() {
		return errors.New("adding transport for specific IP not supported")
	}
	if ,  := .globalDialers[.Port];  {
		return fmt.Errorf("already have global dialer for port %d", .Port)
	}
	.globalDialers[.Port] = 
	return nil
}

// Checks that the given transport is one this reuse is tracking as a
// listener.
//
// The argument must be a *refcountedTransport; any other implementation is
// rejected with an error. Its local address decides where to look: an
// unspecified IP is looked up by port in globalListeners, a specific IP is
// looked up by IP and port in unicast. nil is returned only when the entry
// found is the very same transport; a different transport at that address, or
// no entry at all, is an error. globalDialers is not consulted.
// NOTE(review): identifiers are missing in this copy. Unlike the other methods
// of reuse in this file that touch these maps, this one does not take the
// mutex — confirm whether callers are expected to hold it. The local address
// is type-asserted to *net.UDPAddr without a check, which panics for any other
// address type.
func ( *reuse) ( RefCountedQUICTransport) error {
	,  := .(*refcountedTransport)
	if ! {
		return fmt.Errorf("invalid transport type: expected: *refcountedTransport, got: %T", )
	}
	 := .LocalAddr().(*net.UDPAddr)
	if .IP.IsUnspecified() {
		if ,  := .globalListeners[.Port];  {
			if  ==  {
				return nil
			}
			return errors.New("two global listeners on the same port")
		}
		return errors.New("transport not found")
	}
	if ,  := .unicast[.IP.String()];  {
		if ,  := [.Port];  {
			if  ==  {
				return nil
			}
			return errors.New("two unicast listeners on same ip:port")
		}
		return errors.New("transport not found")
	}
	return errors.New("transport not found")
}

// Returns a transport to listen on for the given network and local address,
// with its reference count already increased. The whole operation runs under
// the mutex.
//
// For an unspecified IP, an existing global dialer is reused when possible (any
// one if the requested port is 0, otherwise the one on the requested port) and
// is moved from globalDialers to globalListeners. In all other cases a new UDP
// socket is opened via listenUDP and wrapped in a new transport, which is then
// registered in globalListeners or unicast depending on the address the socket
// actually bound to.
// NOTE(review): identifiers are missing in this copy; presumably the
// parameters are the network name and the requested local address — confirm.
func ( *reuse) ( string,  *net.UDPAddr) (*refcountedTransport, error) {
	.mutex.Lock()
	defer .mutex.Unlock()

	// Check if we can reuse a transport we have already dialed out from.
	// We reuse a transport from globalDialers when the requested port is 0 or the requested
	// port is already in the globalDialers.
	// If we are reusing a transport from globalDialers, we move the globalDialers entry to
	// globalListeners
	if .IP.IsUnspecified() {
		var  *refcountedTransport
		var  *net.UDPAddr

		if .Port == 0 {
			// the requested port is 0, we can reuse any transport
			for ,  := range .globalDialers {
				 = 
				 = .LocalAddr().(*net.UDPAddr)
				delete(.globalDialers, .Port)
				break
			}
		} else if ,  := .globalDialers[.Port];  {
			 = .globalDialers[.Port]
			 = .LocalAddr().(*net.UDPAddr)
			delete(.globalDialers, .Port)
		}
		// found a match
		if  != nil {
			.IncreaseCount()
			.globalListeners[.Port] = 
			return , nil
		}
	}

	// Nothing to reuse: open a new socket and build a transport on it.
	,  := .listenUDP(, )
	if  != nil {
		return nil, 
	}
	 := .newTransport()
	.IncreaseCount()

	// Use the address the socket actually bound to for registration.
	 := .LocalAddr().(*net.UDPAddr)
	// Deal with listen on a global address
	if .IP.IsUnspecified() {
		// The kernel already checked that the laddr is not already in use,
		// so we need not check here (when we create ListenUDP).
		.globalListeners[.Port] = 
		return , nil
	}

	// Deal with listen on a unicast address
	if ,  := .unicast[.IP.String()]; ! {
		.unicast[.IP.String()] = make(map[int]*refcountedTransport)
		// Assume the system's routes may have changed if we're adding a new listener.
		// Ignore the error, there's nothing we can do.
		.routes, _ = .sourceIPSelectorFn()
	}

	// The kernel already checked that the laddr is not already in use,
	// so we need not check here (when we create ListenUDP).
	.unicast[.IP.String()][.Port] = 
	return , nil
}

// Builds a refcountedTransport around the given packet conn. The QUIC
// transport is created with newQUICTransport using this reuse's token
// generator key, stateless reset key, connContext and verifySourceAddress, and
// is wrapped in a wrappedQUICTransport. The new transport starts with a zero
// reference count and is not registered in any map here.
// NOTE(review): identifiers are missing in this copy; other methods in this
// file call a method named newTransport, which is presumably this one —
// confirm.
func ( *reuse) ( net.PacketConn) *refcountedTransport {
	return &refcountedTransport{
		QUICTransport: &wrappedQUICTransport{
			Transport: newQUICTransport(
				,
				.tokenGeneratorKey,
				.statelessResetKey,
				.connContext,
				.verifySourceAddress,
			),
		},
		packetConn: ,
	}
}

// Shuts the reuse down: closes closeChan to ask the garbage-collection
// goroutine to stop, then blocks until that goroutine closes gcStopChan.
// Always returns nil.
// NOTE(review): identifiers are missing in this copy; presumably this is
// Close. Invoking it a second time would close closeChan twice, which panics
// in Go — confirm callers only close once.
func ( *reuse) () error {
	close(.closeChan)
	<-.gcStopChan
	return nil
}

// SourceIPSelector chooses the local source IP to use when sending to a
// destination.
type SourceIPSelector interface {
	// PreferredSourceIPForDestination returns the preferred source IP for
	// reaching dst, or an error if none could be determined.
	PreferredSourceIPForDestination(dst *net.UDPAddr) (net.IP, error)
}

// netrouteSourceIPSelector answers source IP queries by consulting a
// routing.Router.
// NOTE(review): presumably this is the SourceIPSelector implementation backed
// by the system routing table — confirm against the routing package in use.
type netrouteSourceIPSelector struct {
	// routes is the router that lookups are delegated to.
	routes routing.Router
}

// Looks up a route for the IP of the given UDP address via the router's Route
// method and returns a net.IP and an error taken from its four results.
// NOTE(review): identifiers are missing in this copy, so which two of the four
// Route results are returned cannot be read off the code; presumably the
// preferred source IP and the error. The signature matches
// PreferredSourceIPForDestination from SourceIPSelector — confirm.
func ( *netrouteSourceIPSelector) ( *net.UDPAddr) (net.IP, error) {
	, , ,  := .routes.Route(.IP)
	return , 
}