package quicreuse
import (
"context"
"crypto/tls"
"errors"
"fmt"
"net"
"sync"
"time"
"github.com/google/gopacket/routing"
"github.com/quic-go/quic-go"
)
// RefCountedQUICTransport is a QUIC transport that also exposes
// reference-counting hooks. What the hooks do is up to the implementation:
// refcountedTransport tracks a count and idle time, while singleOwnerTransport
// ignores IncreaseCount and closes its QUIC transport on DecreaseCount.
type RefCountedQUICTransport interface {
	// Reference-counting hooks.
	IncreaseCount()
	DecreaseCount()

	// Transport operations.
	LocalAddr() net.Addr
	WriteTo([]byte, net.Addr) (int, error)
	Dial(ctx context.Context, addr net.Addr, tlsConf *tls.Config, conf *quic.Config) (quic.Connection, error)
	Listen(tlsConf *tls.Config, conf *quic.Config) (QUICListener, error)

	// Close releases the transport.
	Close() error
}
// singleOwnerTransport pairs a QUICTransport with the packet conn it runs on.
// It does not reference count: IncreaseCount is a no-op and DecreaseCount
// closes the QUIC transport immediately.
type singleOwnerTransport struct {
// Transport receives the delegated Dial, Listen, WriteTo and ReadNonQUICPacket calls.
Transport QUICTransport
// packetConn is the socket reported by LocalAddr and closed by Close.
packetConn net .PacketConn
}
// Compile-time checks that *singleOwnerTransport satisfies both transport
// interfaces.
var (
	_ QUICTransport           = (*singleOwnerTransport)(nil)
	_ RefCountedQUICTransport = (*singleOwnerTransport)(nil)
)
// IncreaseCount does nothing: a singleOwnerTransport keeps no reference count.
func (*singleOwnerTransport) IncreaseCount() {}
// DecreaseCount closes the wrapped QUIC transport right away. Its error is
// discarded because this method has no way to return it, and packetConn is
// left open (only Close closes it).
func (c *singleOwnerTransport) DecreaseCount() {
	_ = c.Transport.Close()
}
// LocalAddr returns the local address of the underlying packet conn.
func (c *singleOwnerTransport) LocalAddr() net.Addr {
	local := c.packetConn.LocalAddr()
	return local
}
// Dial forwards unchanged to the wrapped transport's Dial.
func (c *singleOwnerTransport) Dial(ctx context.Context, addr net.Addr, tlsConf *tls.Config, conf *quic.Config) (quic.Connection, error) {
	conn, err := c.Transport.Dial(ctx, addr, tlsConf, conf)
	return conn, err
}
// ReadNonQUICPacket forwards unchanged to the wrapped transport's
// ReadNonQUICPacket.
func (c *singleOwnerTransport) ReadNonQUICPacket(ctx context.Context, b []byte) (int, net.Addr, error) {
	n, from, err := c.Transport.ReadNonQUICPacket(ctx, b)
	return n, from, err
}
// Close shuts down the wrapped QUIC transport and then the packet conn.
// Errors from both steps are combined with errors.Join, so a failure closing
// the QUIC transport is no longer silently dropped. The result is nil only if
// both closes succeed.
func (c *singleOwnerTransport) Close() error {
	return errors.Join(c.Transport.Close(), c.packetConn.Close())
}
// WriteTo forwards unchanged to the wrapped transport's WriteTo.
func (c *singleOwnerTransport) WriteTo(b []byte, addr net.Addr) (int, error) {
	written, err := c.Transport.WriteTo(b, addr)
	return written, err
}
// Listen forwards unchanged to the wrapped transport's Listen.
func (c *singleOwnerTransport) Listen(tlsConf *tls.Config, conf *quic.Config) (QUICListener, error) {
	ln, err := c.Transport.Listen(tlsConf, conf)
	return ln, err
}
// garbageCollectInterval is the period of the ticker that drives each
// reuse's idle-transport sweep. It is a variable rather than a constant,
// presumably so tests can shorten it.
var garbageCollectInterval = 30 * time.Second

// maxUnusedDuration is how long a refcountedTransport must have had a zero
// reference count before a sweep closes and forgets it.
var maxUnusedDuration = 10 * time.Second
// refcountedTransport is a QUIC transport shared by several users. It
// tracks how many references are outstanding and since when it has been
// unused, so the reuse garbage collector can close it once idle.
type refcountedTransport struct {
// QUICTransport is embedded; methods not overridden below (e.g. Dial) are
// promoted from it.
QUICTransport
// packetConn is the socket the transport runs on; LocalAddr reads it and
// Close closes it (unless the transport is borrowed).
packetConn net .PacketConn
// mutex guards refCount, unusedSince and assocations.
mutex sync .Mutex
// refCount is incremented by IncreaseCount and decremented by DecreaseCount.
refCount int
// unusedSince is the zero time while in use, and is set to time.Now() when
// refCount drops to 0. IncreaseCount resets it to zero.
unusedSince time .Time
// borrowDoneSignal, when non-nil, makes Close close this channel instead of
// closing the QUIC transport and packetConn. NOTE(review): not guarded by
// mutex; presumably set once at construction time (not visible here).
borrowDoneSignal chan struct {}
// assocations is the set of values recorded via associate (name misspelled
// in the original; kept for compatibility).
assocations map [any ]struct {}
}
// connContextFunc is the callback stored in reuse.connContext and passed to
// newQUICTransport for every transport it creates. Presumably it derives the
// context for an incoming connection from its client info; confirm against
// newQUICTransport.
type connContextFunc = func(ctx context.Context, info *quic.ClientInfo) (context.Context, error)
// associate records a in the transport's association set, allocating the set
// on first use. A nil a is ignored.
func (c *refcountedTransport) associate(a any) {
	if a == nil {
		return
	}
	c.mutex.Lock()
	defer c.mutex.Unlock()
	if c.assocations != nil {
		c.assocations[a] = struct{}{}
		return
	}
	c.assocations = map[any]struct{}{a: {}}
}
// hasAssociation reports whether a was previously recorded with associate.
// A nil a matches every transport.
func (c *refcountedTransport) hasAssociation(a any) bool {
	if a == nil {
		return true
	}
	c.mutex.Lock()
	defer c.mutex.Unlock()
	_, found := c.assocations[a]
	return found
}
// IncreaseCount takes one more reference and marks the transport as in use,
// which keeps ShouldGarbageCollect false until the count returns to zero.
func (c *refcountedTransport) IncreaseCount() {
	c.mutex.Lock()
	defer c.mutex.Unlock()
	c.refCount++
	// Zero time means "currently in use".
	c.unusedSince = time.Time{}
}
// Close releases the transport. For a borrowed transport (borrowDoneSignal
// set) it only closes borrowDoneSignal and leaves the QUIC transport and
// packet conn open. Otherwise it closes both and returns their joined errors.
// NOTE(review): calling Close twice on a borrowed transport would panic on the
// second channel close.
func (c *refcountedTransport) Close() error {
	if done := c.borrowDoneSignal; done != nil {
		close(done)
		return nil
	}
	quicErr := c.QUICTransport.Close()
	connErr := c.packetConn.Close()
	return errors.Join(quicErr, connErr)
}
// WriteTo forwards unchanged to the embedded QUIC transport's WriteTo.
func (c *refcountedTransport) WriteTo(b []byte, addr net.Addr) (int, error) {
	written, err := c.QUICTransport.WriteTo(b, addr)
	return written, err
}
// LocalAddr returns the local address of the underlying packet conn.
func (c *refcountedTransport) LocalAddr() net.Addr {
	local := c.packetConn.LocalAddr()
	return local
}
// Listen forwards unchanged to the embedded QUIC transport's Listen.
func (c *refcountedTransport) Listen(tlsConf *tls.Config, conf *quic.Config) (QUICListener, error) {
	ln, err := c.QUICTransport.Listen(tlsConf, conf)
	return ln, err
}
// DecreaseCount drops one reference. When the count reaches exactly zero it
// stamps unusedSince with the current time, which starts the idle clock that
// ShouldGarbageCollect checks.
func (c *refcountedTransport) DecreaseCount() {
	c.mutex.Lock()
	defer c.mutex.Unlock()
	c.refCount--
	if c.refCount != 0 {
		return
	}
	c.unusedSince = time.Now()
}
// ShouldGarbageCollect reports whether the transport is currently unused
// (unusedSince is non-zero) and has been for longer than maxUnusedDuration,
// as of now.
func (c *refcountedTransport) ShouldGarbageCollect(now time.Time) bool {
	c.mutex.Lock()
	defer c.mutex.Unlock()
	if c.unusedSince.IsZero() {
		// Still referenced.
		return false
	}
	deadline := c.unusedSince.Add(maxUnusedDuration)
	return now.After(deadline)
}
// reuse tracks the QUIC transports (each on its own UDP packet conn) that can
// be shared between dialing and listening, and runs a gc goroutine that
// closes the ones that fall idle.
type reuse struct {
// mutex guards routes, unicast, globalListeners and globalDialers.
mutex sync .Mutex
// closeChan is closed by Close to tell the gc goroutine to exit.
closeChan chan struct {}
// gcStopChan is closed by the gc goroutine after it has closed every
// remaining transport; Close waits on it.
gcStopChan chan struct {}
// listenUDP opens the packet conns new transports run on.
listenUDP listenUDP
// sourceIPSelectorFn rebuilds routes when the set of unicast IPs changes.
sourceIPSelectorFn func () (SourceIPSelector , error )
// routes picks a preferred source IP for dialing. It starts nil, holds
// whatever sourceIPSelectorFn last returned (its error is discarded), and is
// reset to nil when the last unicast transport is collected.
routes SourceIPSelector
// unicast holds transports bound to a specific IP, keyed by IP string, then port.
unicast map [string ] map [int ] *refcountedTransport
// globalListeners holds transports on an unspecified IP used for listening, keyed by port.
globalListeners map [int ]*refcountedTransport
// globalDialers holds transports on an unspecified IP opened for dialing (or
// added via AddTransport), keyed by port. TransportForListen may move an
// entry from here into globalListeners.
globalDialers map [int ]*refcountedTransport
// The fields below are passed to newQUICTransport for every new transport.
statelessResetKey *quic .StatelessResetKey
tokenGeneratorKey *quic .TokenGeneratorKey
connContext connContextFunc
verifySourceAddress func (addr net .Addr ) bool
}
// newReuse returns a reuse with empty transport tables and starts its gc
// goroutine, which keeps running until Close is called. The keys, connContext
// and verifySourceAddress are stored and passed to every transport it creates.
func newReuse(srk *quic.StatelessResetKey, tokenKey *quic.TokenGeneratorKey, listenUDP listenUDP, sourceIPSelectorFn func() (SourceIPSelector, error),
	connContext connContextFunc, verifySourceAddress func(addr net.Addr) bool) *reuse {
	r := &reuse{
		listenUDP:           listenUDP,
		sourceIPSelectorFn:  sourceIPSelectorFn,
		statelessResetKey:   srk,
		tokenGeneratorKey:   tokenKey,
		connContext:         connContext,
		verifySourceAddress: verifySourceAddress,

		closeChan:  make(chan struct{}),
		gcStopChan: make(chan struct{}),

		unicast:         make(map[string]map[int]*refcountedTransport),
		globalListeners: make(map[int]*refcountedTransport),
		globalDialers:   make(map[int]*refcountedTransport),
	}
	go r.gc()
	return r
}
// gc is the reuse's background loop. Every garbageCollectInterval it closes
// and forgets each transport whose ShouldGarbageCollect reports true. When an
// IP's unicast table empties it drops that IP and refreshes (or clears) the
// route selector. Once closeChan is closed it closes every remaining
// transport, then closes gcStopChan.
func (r *reuse) gc() {
	// Runs after ticker.Stop (defers run last-in, first-out).
	defer func() {
		closeAll := func(m map[int]*refcountedTransport) {
			for _, t := range m {
				t.Close()
			}
		}
		r.mutex.Lock()
		closeAll(r.globalListeners)
		closeAll(r.globalDialers)
		for _, perIP := range r.unicast {
			closeAll(perIP)
		}
		r.mutex.Unlock()
		close(r.gcStopChan)
	}()

	ticker := time.NewTicker(garbageCollectInterval)
	defer ticker.Stop()

	// sweep closes and removes every transport in m that has been idle too long.
	sweep := func(m map[int]*refcountedTransport, now time.Time) {
		for port, t := range m {
			if t.ShouldGarbageCollect(now) {
				t.Close()
				delete(m, port)
			}
		}
	}

	for {
		select {
		case <-r.closeChan:
			return
		case <-ticker.C:
		}

		now := time.Now()
		r.mutex.Lock()
		sweep(r.globalListeners, now)
		sweep(r.globalDialers, now)
		for ip, perIP := range r.unicast {
			sweep(perIP, now)
			if len(perIP) != 0 {
				continue
			}
			delete(r.unicast, ip)
			// The set of unicast IPs changed; the route selector must follow.
			if len(r.unicast) == 0 {
				r.routes = nil
			} else {
				r.routes, _ = r.sourceIPSelectorFn()
			}
		}
		r.mutex.Unlock()
	}
}
// TransportWithAssociationForDial returns a transport to dial raddr from,
// with its reference count already incremented. If a route selector is set,
// the source IP it suggests for raddr is preferred. Lookup errors and
// unspecified suggestions are ignored.
func (r *reuse) TransportWithAssociationForDial(association any, network string, raddr *net.UDPAddr) (*refcountedTransport, error) {
	// Snapshot the selector under the lock, then query it without holding r.mutex.
	r.mutex.Lock()
	selector := r.routes
	r.mutex.Unlock()

	var preferred *net.IP
	if selector != nil {
		if src, err := selector.PreferredSourceIPForDestination(raddr); err == nil && !src.IsUnspecified() {
			preferred = &src
		}
	}

	r.mutex.Lock()
	defer r.mutex.Unlock()
	tr, err := r.transportForDialLocked(association, network, preferred)
	if err != nil {
		return nil, err
	}
	tr.IncreaseCount()
	return tr, nil
}
// transportForDialLocked picks, or creates, the transport to dial from. The
// caller must hold r.mutex. Candidates are tried in order:
//  1. a transport bound to source (when non-nil), preferring one associated
//     with association;
//  2. a global listener, again preferring an associated one;
//  3. any global dialer;
//  4. a new socket on an ephemeral port (0.0.0.0 for "udp4", :: for "udp6",
//     a nil local address otherwise), registered as a global dialer.
func (r *reuse) transportForDialLocked(association any, network string, source *net.IP) (*refcountedTransport, error) {
	// pick returns a transport from m that has association if one exists,
	// otherwise an arbitrary one; found is false only when m is empty.
	pick := func(m map[int]*refcountedTransport) (tr *refcountedTransport, found bool) {
		for _, candidate := range m {
			if candidate.hasAssociation(association) {
				return candidate, true
			}
		}
		for _, candidate := range m {
			return candidate, true
		}
		return nil, false
	}

	if source != nil {
		// Indexing a missing IP yields a nil map, for which pick finds nothing.
		if tr, found := pick(r.unicast[source.String()]); found {
			return tr, nil
		}
	}
	if tr, found := pick(r.globalListeners); found {
		return tr, nil
	}
	for _, tr := range r.globalDialers {
		return tr, nil
	}

	var laddr *net.UDPAddr
	switch network {
	case "udp4":
		laddr = &net.UDPAddr{IP: net.IPv4zero}
	case "udp6":
		laddr = &net.UDPAddr{IP: net.IPv6zero}
	}
	conn, err := r.listenUDP(network, laddr)
	if err != nil {
		return nil, err
	}
	tr := r.newTransport(conn)
	port := conn.LocalAddr().(*net.UDPAddr).Port
	r.globalDialers[port] = tr
	return tr, nil
}
// AddTransport registers an externally created transport as a global dialer
// on laddr.Port. Only unspecified IPs are accepted. The call fails if a
// global dialer is already registered on that port.
func (r *reuse) AddTransport(tr *refcountedTransport, laddr *net.UDPAddr) error {
	r.mutex.Lock()
	defer r.mutex.Unlock()

	if laddr.IP.IsUnspecified() {
		if _, taken := r.globalDialers[laddr.Port]; taken {
			return fmt.Errorf("already have global dialer for port %d", laddr.Port)
		}
		r.globalDialers[laddr.Port] = tr
		return nil
	}
	return errors.New("adding transport for specific IP not supported")
}
// AssertTransportExists returns nil if tr is exactly the transport r tracks
// for tr's local address. That means globalListeners for an unspecified IP,
// and unicast[ip][port] otherwise. It returns an error if tr is not a
// *refcountedTransport, if its local address is not a *net.UDPAddr, if a
// different transport occupies that address, or if none is registered.
func (r *reuse) AssertTransportExists(tr RefCountedQUICTransport) error {
	t, ok := tr.(*refcountedTransport)
	if !ok {
		return fmt.Errorf("invalid transport type: expected: *refcountedTransport, got: %T", tr)
	}
	addr := t.LocalAddr()
	laddr, ok := addr.(*net.UDPAddr)
	if !ok {
		// Previously an unchecked assertion that would panic.
		return fmt.Errorf("invalid local address type: expected: *net.UDPAddr, got: %T", addr)
	}

	// The transport tables are mutated under r.mutex, including by the gc
	// goroutine, so reading them without the lock is a data race.
	r.mutex.Lock()
	defer r.mutex.Unlock()

	if laddr.IP.IsUnspecified() {
		lt, ok := r.globalListeners[laddr.Port]
		if !ok {
			return errors.New("transport not found")
		}
		if lt != t {
			return errors.New("two global listeners on the same port")
		}
		return nil
	}

	perIP, ok := r.unicast[laddr.IP.String()]
	if !ok {
		return errors.New("transport not found")
	}
	lt, ok := perIP[laddr.Port]
	if !ok {
		return errors.New("transport not found")
	}
	if lt != t {
		return errors.New("two unicast listeners on same ip:port")
	}
	return nil
}
// TransportForListen returns a transport to listen on laddr, with its
// reference count incremented.
//
// For an unspecified IP, an existing global dialer is promoted to a global
// listener instead of opening a new socket. Any dialer qualifies when
// laddr.Port is 0; otherwise only the dialer on that port does. If no dialer
// is promoted, a new socket is opened via listenUDP. It is registered as a
// global listener if bound to an unspecified IP, otherwise under unicast. A
// previously unseen IP triggers a refresh of the route selector.
func (r *reuse) TransportForListen(network string, laddr *net.UDPAddr) (*refcountedTransport, error) {
	r.mutex.Lock()
	defer r.mutex.Unlock()

	if laddr.IP.IsUnspecified() {
		var (
			reused *refcountedTransport
			found  bool
		)
		if laddr.Port == 0 {
			for _, candidate := range r.globalDialers {
				reused, found = candidate, true
				break
			}
		} else {
			reused, found = r.globalDialers[laddr.Port]
		}
		if found {
			port := reused.LocalAddr().(*net.UDPAddr).Port
			delete(r.globalDialers, port)
			reused.IncreaseCount()
			r.globalListeners[port] = reused
			return reused, nil
		}
	}

	conn, err := r.listenUDP(network, laddr)
	if err != nil {
		return nil, err
	}
	tr := r.newTransport(conn)
	tr.IncreaseCount()

	bound := conn.LocalAddr().(*net.UDPAddr)
	if bound.IP.IsUnspecified() {
		r.globalListeners[bound.Port] = tr
		return tr, nil
	}

	ipKey := bound.IP.String()
	perIP, known := r.unicast[ipKey]
	if !known {
		perIP = make(map[int]*refcountedTransport)
		r.unicast[ipKey] = perIP
		// A new unicast IP appeared, so rebuild the route selector.
		r.routes, _ = r.sourceIPSelectorFn()
	}
	perIP[bound.Port] = tr
	return tr, nil
}
// newTransport wraps conn in a QUIC transport built from r's keys and
// callbacks. The result starts with a zero reference count and no
// associations; it is not registered in any of r's tables.
func (r *reuse) newTransport(conn net.PacketConn) *refcountedTransport {
	inner := newQUICTransport(
		conn,
		r.tokenGeneratorKey,
		r.statelessResetKey,
		r.connContext,
		r.verifySourceAddress,
	)
	return &refcountedTransport{
		QUICTransport: &wrappedQUICTransport{Transport: inner},
		packetConn:    conn,
	}
}
// Close stops the gc goroutine and blocks until it has closed every transport
// still tracked by r. Close may be called more than once, including
// concurrently. Only the first call signals shutdown (previously a second call
// panicked by closing closeChan again), and every call waits for shutdown to
// finish. It always returns nil.
func (r *reuse) Close() error {
	// Check-and-close under r.mutex so concurrent callers cannot both close
	// closeChan. The lock is released before waiting, because the gc goroutine
	// needs it to close the remaining transports.
	r.mutex.Lock()
	select {
	case <-r.closeChan:
		// Already closed by an earlier call.
	default:
		close(r.closeChan)
	}
	r.mutex.Unlock()

	<-r.gcStopChan
	return nil
}
// SourceIPSelector chooses the local IP address to dial a destination from.
type SourceIPSelector interface {
	// PreferredSourceIPForDestination returns the preferred local source IP
	// for reaching dst.
	PreferredSourceIPForDestination(dst *net.UDPAddr) (net.IP, error)
}
// netrouteSourceIPSelector implements SourceIPSelector by querying a gopacket
// routing.Router.
type netrouteSourceIPSelector struct {
// routes is the routing table consulted for each destination.
routes routing .Router
}
// PreferredSourceIPForDestination looks up the route to dst.IP and returns
// the preferred source address from that lookup, along with its error. The
// interface and gateway results of the lookup are discarded.
func (s *netrouteSourceIPSelector) PreferredSourceIPForDestination(dst *net.UDPAddr) (net.IP, error) {
	_, _, preferredSrc, err := s.routes.Route(dst.IP)
	if err != nil {
		return preferredSrc, err
	}
	return preferredSrc, nil
}
The pages are generated with Golds v0.8.2 . (GOOS=linux GOARCH=amd64)
Golds is a Go 101 project developed by Tapir Liu .
PR and bug reports are welcome and can be submitted to the issue list .
Please follow @zigo_101 (reachable from the left QR code) to get the latest news of Golds .