package basichost

import (
	
	
	
	
	
	
	

	
	inat 

	ma 
	manet 
)

// NATManager manages port mappings on a NAT device. It watches the
// network.Network for Listen and ListenClose notifications and tries to
// obtain a port mapping for each listen address.
type NATManager interface {
	// GetMapping returns the external address the NAT maps the given listen
	// address to, or nil when no mapping is known.
	GetMapping(ma.Multiaddr) ma.Multiaddr

	// HasDiscoveredNAT reports whether a NAT device has been discovered.
	HasDiscoveredNAT() bool

	// Close stops the manager and releases the NAT.
	io.Closer
}

// NewNATManager creates a NAT manager.
func ( network.Network) NATManager {
	return newNATManager()
}

// entry identifies one port mapping: a transport protocol and a local port.
// It is used as a map key, so it must stay comparable.
type entry struct {
	protocol string // "tcp" or "udp"
	port     int    // local port number
}

// nat is the subset of NAT-device behavior natManager relies on. It is an
// interface so that tests can substitute a mock (see discoverNAT).
type nat interface {
	// AddMapping asks the NAT to map the given local protocol/port.
	AddMapping(ctx context.Context, protocol string, port int) error
	// RemoveMapping drops a mapping previously requested with AddMapping.
	RemoveMapping(ctx context.Context, protocol string, port int) error
	// GetMapping returns the external address for a local protocol/port,
	// and whether such a mapping exists.
	GetMapping(protocol string, port int) (netip.AddrPort, bool)

	io.Closer
}

// discoverNAT locates a NAT device for this host. It is a package variable
// so that tests can replace it with a mock.
//
// The error path returns a literal nil. Returning inat.DiscoverNAT's result
// directly would convert it to the nat interface even on failure. If that
// result is a concrete pointer (presumably *inat.NAT — confirm), the
// interface would then hold a typed nil and compare != nil.
var discoverNAT = func(ctx context.Context) (nat, error) {
	n, err := inat.DiscoverNAT(ctx)
	if err != nil {
		return nil, err
	}
	return n, nil
}

// natManager adds and removes port mappings on a NAT device.
// It is initialized with the host when the NATPortMap option is enabled.
// Once a NAT has been discovered, natManager:
//   - listens to the network and adds or removes port mappings as the
//     network signals Listen() or ListenClose();
//   - closes the NAT, and with it its mappings, when it is closed itself.
type natManager struct {
	// net is the network whose listen addresses are port-mapped.
	net network.Network

	// natMx guards nat, which stays nil until discovery succeeds.
	natMx sync.RWMutex
	nat   nat

	// syncFlag carries resync requests to the background goroutine.
	syncFlag chan struct{} // cap: 1

	// tracked holds the mappings currently requested from the NAT. The bool
	// value is scratch state for doSync and carries no meaning elsewhere.
	tracked map[entry]bool

	// refCount tracks the background goroutine; Close waits on it.
	refCount sync.WaitGroup
	// ctx is cancelled by ctxCancel (called from Close) to stop the
	// background goroutine.
	ctx       context.Context
	ctxCancel context.CancelFunc
}

func newNATManager( network.Network) *natManager {
	,  := context.WithCancel(context.Background())
	 := &natManager{
		net:       ,
		syncFlag:  make(chan struct{}, 1),
		ctx:       ,
		ctxCancel: ,
		tracked:   make(map[entry]bool),
	}
	.refCount.Add(1)
	go .background()
	return 
}

// Close closes the natManager, closing the underlying nat
// and unregistering from network events.
func ( *natManager) () error {
	.ctxCancel()
	.refCount.Wait()
	return nil
}

func ( *natManager) () bool {
	.natMx.RLock()
	defer .natMx.RUnlock()
	return .nat != nil
}

func ( *natManager) ( context.Context) {
	defer .refCount.Done()

	defer func() {
		.natMx.Lock()
		defer .natMx.Unlock()

		if .nat != nil {
			.nat.Close()
		}
	}()

	,  := context.WithTimeout(, 10*time.Second)
	defer ()
	,  := discoverNAT()
	if  != nil {
		log.Info("DiscoverNAT error:", )
		return
	}

	.natMx.Lock()
	.nat = 
	.natMx.Unlock()

	// sign natManager up for network notifications
	// we need to sign up here to avoid missing some notifs
	// before the NAT has been found.
	.net.Notify((*nmgrNetNotifiee)())
	defer .net.StopNotify((*nmgrNetNotifiee)())

	.doSync() // sync one first.
	for {
		select {
		case <-.syncFlag:
			.doSync() // sync when our listen addresses change.
		case <-.Done():
			return
		}
	}
}

func ( *natManager) () {
	select {
	case .syncFlag <- struct{}{}:
	default:
	}
}

// doSync syncs the current NAT mappings, removing any outdated mappings and adding any
// new mappings.
func ( *natManager) () {
	for  := range .tracked {
		.tracked[] = false
	}
	var  []entry
	for ,  := range .net.ListenAddresses() {
		// Strip the IP
		,  := ma.SplitFirst()
		if  == nil || len() == 0 {
			continue
		}

		switch .Protocol().Code {
		case ma.P_IP6, ma.P_IP4:
		default:
			continue
		}

		// Only bother if we're listening on an unicast / unspecified IP.
		 := net.IP(.RawValue())
		if !.IsGlobalUnicast() && !.IsUnspecified() {
			continue
		}

		// Extract the port/protocol
		,  := ma.SplitFirst()
		if  == nil {
			continue
		}

		var  string
		switch .Protocol().Code {
		case ma.P_TCP:
			 = "tcp"
		case ma.P_UDP:
			 = "udp"
		default:
			continue
		}
		,  := strconv.ParseUint(.Value(), 10, 16)
		if  != nil {
			// bug in multiaddr
			panic()
		}
		 := entry{protocol: , port: int()}
		if ,  := .tracked[];  {
			.tracked[] = true
		} else {
			 = append(, )
		}
	}

	var  sync.WaitGroup
	defer .Wait()

	// Close old mappings
	for ,  := range .tracked {
		if ! {
			.nat.RemoveMapping(.ctx, .protocol, .port)
			delete(.tracked, )
		}
	}

	// Create new mappings.
	for ,  := range  {
		if  := .nat.AddMapping(.ctx, .protocol, .port);  != nil {
			log.Errorf("failed to port-map %s port %d: %s", .protocol, .port, )
		}
		.tracked[] = false
	}
}

func ( *natManager) ( ma.Multiaddr) ma.Multiaddr {
	.natMx.Lock()
	defer .natMx.Unlock()

	if .nat == nil { // NAT not yet initialized
		return nil
	}

	var  bool
	var  int // ma.P_TCP or ma.P_UDP
	,  := ma.SplitFunc(, func( ma.Component) bool {
		if  {
			return true
		}
		 = .Protocol().Code
		 =  == ma.P_TCP ||  == ma.P_UDP
		return false
	})
	if !manet.IsThinWaist() {
		return nil
	}

	,  := manet.ToNetAddr()
	if  != nil {
		log.Error("error parsing net multiaddr %q: %s", , )
		return nil
	}

	var (
		       net.IP
		     int
		 string
	)
	switch naddr := .(type) {
	case *net.TCPAddr:
		 = .IP
		 = .Port
		 = "tcp"
	case *net.UDPAddr:
		 = .IP
		 = .Port
		 = "udp"
	default:
		return nil
	}

	if !.IsGlobalUnicast() && !.IsUnspecified() {
		// We only map global unicast & unspecified addresses ports, not broadcast, multicast, etc.
		return nil
	}

	,  := .nat.GetMapping(, )
	if ! {
		return nil
	}

	var  net.Addr
	switch .(type) {
	case *net.TCPAddr:
		 = net.TCPAddrFromAddrPort()
	case *net.UDPAddr:
		 = net.UDPAddrFromAddrPort()
	}
	,  := manet.FromNetAddr()
	if  != nil {
		log.Errorf("mapped addr can't be turned into a multiaddr %q: %s", , )
		return nil
	}
	 := 
	if  != nil {
		 = ma.Join(, )
	}
	return 
}

type nmgrNetNotifiee natManager

func ( *nmgrNetNotifiee) () *natManager                       { return (*natManager)() }
func ( *nmgrNetNotifiee) (network.Network, ma.Multiaddr)          { .natManager().sync() }
func ( *nmgrNetNotifiee) ( network.Network,  ma.Multiaddr) { .natManager().sync() }
func ( *nmgrNetNotifiee) (network.Network, network.Conn)       {}
func ( *nmgrNetNotifiee) (network.Network, network.Conn)    {}