package basichost
import (
"context"
"io"
"net"
"net/netip"
"strconv"
"sync"
"time"
"github.com/libp2p/go-libp2p/core/network"
inat "github.com/libp2p/go-libp2p/p2p/net/nat"
ma "github.com/multiformats/go-multiaddr"
manet "github.com/multiformats/go-multiaddr/net"
)
// NATManager is the public interface of the NAT port-mapping manager.
//
// GetMapping translates a local listen address into the externally mapped
// address, returning nil when no mapping is available. HasDiscoveredNAT
// reports whether a NAT device has been found. Close shuts the manager down.
type NATManager interface {
	io.Closer
	HasDiscoveredNAT() bool
	GetMapping(ma.Multiaddr) ma.Multiaddr
}
// NewNATManager creates a NATManager for net. NAT discovery and mapping
// maintenance happen on a background goroutine that Close stops.
func NewNATManager(net network.Network) NATManager {
	mgr := newNATManager(net)
	return mgr
}
// entry identifies a single port mapping by transport protocol and local
// port. It is comparable and serves as the key of natManager.tracked.
type entry struct {
	protocol string // "tcp" or "udp"
	port     int    // local port number
}
// nat is the set of NAT-device operations natManager depends on. Declaring it
// as an interface lets discoverNAT return any implementation (presumably so
// tests can substitute a fake — confirm against the test files).
type nat interface {
	io.Closer

	// AddMapping requests a mapping for the given protocol and local port.
	AddMapping(ctx context.Context, protocol string, port int) error
	// RemoveMapping drops a mapping previously requested with AddMapping.
	RemoveMapping(ctx context.Context, protocol string, port int) error
	// GetMapping returns the external address for protocol/port, if known.
	GetMapping(protocol string, port int) (netip.AddrPort, bool)
}
// discoverNAT locates a NAT device reachable from this host. It is a package
// variable rather than a direct call so the discovery step can be replaced
// (presumably by tests).
//
// On failure it returns a literal nil nat. Returning inat.DiscoverNAT's result
// directly would wrap a nil *inat.NAT in a non-nil nat interface.
var discoverNAT = func(ctx context.Context) (nat, error) {
	n, err := inat.DiscoverNAT(ctx)
	if err != nil {
		return nil, err
	}
	return n, nil
}
// natManager implements NATManager. A single background goroutine discovers
// the NAT device and then keeps its port mappings aligned with the network's
// listen addresses.
type natManager struct {
	// net is the network whose listen addresses get port-mapped.
	net network.Network

	// natMx guards the nat field below.
	natMx sync.RWMutex
	// nat is the discovered NAT device; nil until discovery succeeds.
	nat nat

	// syncFlag signals the background goroutine to run doSync.
	syncFlag chan struct{}
	// tracked holds the mappings requested from nat. It is only touched by
	// doSync, which runs on the background goroutine.
	tracked map[entry]bool

	// refCount tracks the background goroutine so Close can wait for it.
	refCount sync.WaitGroup
	// ctx bounds the manager's lifetime; ctxCancel (called by Close) ends it.
	ctx       context.Context
	ctxCancel context.CancelFunc
}
// newNATManager constructs a natManager bound to net and starts its
// background goroutine, which runs until Close cancels the manager context.
func newNATManager(net network.Network) *natManager {
	ctx, cancel := context.WithCancel(context.Background())
	m := &natManager{
		net:       net,
		ctx:       ctx,
		ctxCancel: cancel,
		syncFlag:  make(chan struct{}, 1), // one pending request is enough
		tracked:   map[entry]bool{},
	}
	// Register the goroutine before starting it so Close can always wait on it.
	m.refCount.Add(1)
	go m.background(ctx)
	return m
}
// Close cancels the manager's context and blocks until the background
// goroutine has exited, which includes closing the discovered NAT, if any.
// It always returns nil.
func (nmgr *natManager) Close() error {
	defer nmgr.refCount.Wait()
	nmgr.ctxCancel()
	return nil
}
// HasDiscoveredNAT reports whether NAT discovery has succeeded.
func (nmgr *natManager) HasDiscoveredNAT() bool {
	nmgr.natMx.RLock()
	found := nmgr.nat != nil
	nmgr.natMx.RUnlock()
	return found
}
// background is the manager's worker goroutine, started by newNATManager
// (which has already called refCount.Add(1)).
//
// It tries to discover a NAT device for at most 10 seconds. On success it
// stores the device, subscribes to network notifications, and runs doSync
// once. It then runs doSync again whenever syncFlag fires, until ctx is
// cancelled.
//
// On exit the deferred calls run in reverse order: StopNotify, the discovery
// cancel, closing the NAT, then refCount.Done. Close's Wait therefore returns
// only after the NAT has been closed.
func (nmgr *natManager) background(ctx context.Context) {
	defer nmgr.refCount.Done()
	// Close the discovered NAT (if any) under the write lock once the
	// goroutine is finishing.
	defer func() {
		nmgr.natMx.Lock()
		defer nmgr.natMx.Unlock()
		if nmgr.nat != nil {
			nmgr.nat.Close()
		}
	}()
	// Discovery is bounded to 10s and also aborts if ctx is cancelled.
	discoverCtx, cancel := context.WithTimeout(ctx, 10*time.Second)
	defer cancel()
	natInstance, err := discoverNAT(discoverCtx)
	if err != nil {
		log.Info("DiscoverNAT error:", err)
		return
	}
	nmgr.natMx.Lock()
	nmgr.nat = natInstance
	nmgr.natMx.Unlock()
	// Subscribe only once a NAT exists. Listen addresses that appeared
	// earlier are picked up by the initial doSync below.
	nmgr.net.Notify((*nmgrNetNotifiee)(nmgr))
	defer nmgr.net.StopNotify((*nmgrNetNotifiee)(nmgr))
	nmgr.doSync()
	for {
		select {
		case <-nmgr.syncFlag:
			// Listen addresses changed (signalled via nmgrNetNotifiee).
			nmgr.doSync()
		case <-ctx.Done():
			return
		}
	}
}
// sync asks the background goroutine to run doSync without ever blocking.
// If a request is already queued in syncFlag, this one is dropped; the
// pending pass will observe the latest listen addresses anyway.
func (nmgr *natManager) sync() {
	var signal struct{}
	select {
	case nmgr.syncFlag <- signal:
		// request queued
	default:
		// a request is already pending
	}
}
// doSync reconciles the NAT's port mappings with the network's current
// listen addresses.
//
// A listen address qualifies only if it has exactly this shape:
//   - it starts with an /ip4 or /ip6 component,
//   - that IP is global unicast or unspecified,
//   - the IP is immediately followed by a /tcp or /udp component.
//
// Each qualifying address yields an entry (protocol, port). Several addresses
// can yield the same entry, e.g. IPv4 and IPv6 wildcards on one port; each
// entry is requested only once. Mappings whose address disappeared are
// removed and new ones are added. Failures are logged and do not abort the
// pass.
//
// doSync is only called from the background goroutine. It reads nmgr.nat
// without holding natMx and mutates nmgr.tracked, which has no lock of its own.
func (nmgr *natManager) doSync() {
	// Mark every tracked mapping stale; those still backed by a listen
	// address are flipped back to true by the scan below.
	for e := range nmgr.tracked {
		nmgr.tracked[e] = false
	}

	var newAddresses []entry
	// Entries already queued in newAddresses during this pass, so one
	// protocol/port pair reached via several listen addresses is added once.
	queued := make(map[entry]struct{})
	for _, maddr := range nmgr.net.ListenAddresses() {
		maIP, rest := ma.SplitFirst(maddr)
		if maIP == nil || len(rest) == 0 {
			continue
		}
		switch maIP.Protocol().Code {
		case ma.P_IP6, ma.P_IP4:
		default:
			continue
		}
		ip := net.IP(maIP.RawValue())
		if !ip.IsGlobalUnicast() && !ip.IsUnspecified() {
			continue
		}
		proto, _ := ma.SplitFirst(rest)
		if proto == nil {
			continue
		}
		var protocol string
		switch proto.Protocol().Code {
		case ma.P_TCP:
			protocol = "tcp"
		case ma.P_UDP:
			protocol = "udp"
		default:
			continue
		}
		port, err := strconv.ParseUint(proto.Value(), 10, 16)
		if err != nil {
			// NOTE(review): presumably unreachable, since multiaddr should only
			// yield valid 16-bit tcp/udp ports. Treated as an invariant violation.
			panic(err)
		}
		e := entry{protocol: protocol, port: int(port)}
		if _, ok := nmgr.tracked[e]; ok {
			nmgr.tracked[e] = true
		} else if _, dup := queued[e]; !dup {
			queued[e] = struct{}{}
			newAddresses = append(newAddresses, e)
		}
	}

	// Remove mappings whose listen address went away. Deleting map entries
	// while ranging over the same map is permitted in Go.
	for e, live := range nmgr.tracked {
		if live {
			continue
		}
		if err := nmgr.nat.RemoveMapping(nmgr.ctx, e.protocol, e.port); err != nil {
			log.Errorf("failed to remove port mapping %s port %d: %s", e.protocol, e.port, err)
		}
		delete(nmgr.tracked, e)
	}

	// Request mappings for newly seen entries. An entry is tracked even when
	// AddMapping fails. It is then not re-requested on later passes, and
	// RemoveMapping is still called once its address goes away.
	for _, e := range newAddresses {
		if err := nmgr.nat.AddMapping(nmgr.ctx, e.protocol, e.port); err != nil {
			log.Errorf("failed to port-map %s port %d: %s", e.protocol, e.port, err)
		}
		nmgr.tracked[e] = false
	}
}
// GetMapping returns the external address the NAT maps addr to. Components
// of addr after its TCP/UDP port are appended unchanged to the mapped
// address.
//
// It returns nil when:
//   - no NAT has been discovered,
//   - addr is not an IP plus TCP/UDP address with a global-unicast or
//     unspecified IP,
//   - the NAT reports no mapping for that protocol and port.
func (nmgr *natManager) GetMapping(addr ma.Multiaddr) ma.Multiaddr {
	// A read lock is enough. natMx only guards the nmgr.nat field, and the
	// nat implementation is already invoked without natMx from doSync.
	nmgr.natMx.RLock()
	defer nmgr.natMx.RUnlock()
	if nmgr.nat == nil {
		return nil
	}

	// Split addr right after its first tcp/udp component: transport is the
	// prefix up to and including that component, rest is everything after.
	var found bool
	transport, rest := ma.SplitFunc(addr, func(c ma.Component) bool {
		if found {
			return true
		}
		code := c.Protocol().Code
		found = code == ma.P_TCP || code == ma.P_UDP
		return false
	})
	if !manet.IsThinWaist(transport) {
		return nil
	}
	naddr, err := manet.ToNetAddr(transport)
	if err != nil {
		log.Errorf("error parsing net multiaddr %q: %s", transport, err)
		return nil
	}

	var (
		ip       net.IP
		port     int
		protocol string
	)
	switch naddr := naddr.(type) {
	case *net.TCPAddr:
		ip = naddr.IP
		port = naddr.Port
		protocol = "tcp"
	case *net.UDPAddr:
		ip = naddr.IP
		port = naddr.Port
		protocol = "udp"
	default:
		return nil
	}
	// Same IP filter doSync applies before creating mappings.
	if !ip.IsGlobalUnicast() && !ip.IsUnspecified() {
		return nil
	}

	extAddr, ok := nmgr.nat.GetMapping(protocol, port)
	if !ok {
		return nil
	}
	var mappedAddr net.Addr
	switch naddr.(type) {
	case *net.TCPAddr:
		mappedAddr = net.TCPAddrFromAddrPort(extAddr)
	case *net.UDPAddr:
		mappedAddr = net.UDPAddrFromAddrPort(extAddr)
	}
	mappedMaddr, err := manet.FromNetAddr(mappedAddr)
	if err != nil {
		log.Errorf("mapped addr can't be turned into a multiaddr %q: %s", mappedAddr, err)
		return nil
	}
	extMaddr := mappedMaddr
	if rest != nil {
		extMaddr = ma.Join(extMaddr, rest)
	}
	return extMaddr
}
// nmgrNetNotifiee is natManager registered with the network for
// notifications. Listen and ListenClose request a mapping resync;
// connection events are ignored.
type nmgrNetNotifiee natManager

// natManager converts the notifiee back to the manager it wraps.
func (nn *nmgrNetNotifiee) natManager() *natManager {
	return (*natManager)(nn)
}

// Listen requests a resync because a listen address was added.
func (nn *nmgrNetNotifiee) Listen(_ network.Network, _ ma.Multiaddr) {
	nn.natManager().sync()
}

// ListenClose requests a resync because a listen address was removed.
func (nn *nmgrNetNotifiee) ListenClose(_ network.Network, _ ma.Multiaddr) {
	nn.natManager().sync()
}

// Connected is a no-op; connections do not affect port mappings.
func (nn *nmgrNetNotifiee) Connected(_ network.Network, _ network.Conn) {}

// Disconnected is a no-op; connections do not affect port mappings.
func (nn *nmgrNetNotifiee) Disconnected(_ network.Network, _ network.Conn) {}
The pages are generated with Golds v0.8.2 (GOOS=linux GOARCH=amd64).
Golds is a Go 101 project developed by Tapir Liu.
PRs and bug reports are welcome and can be submitted to the issue list.
Please follow @zigo_101 (reachable from the left QR code) to get the latest news of Golds.