package rcmgr

import (
	"math"
	"net/netip"
	"slices"
	"sync"
	"time"

	"github.com/libp2p/go-libp2p/x/rate"
)

// ConnLimitPerSubnet limits how many connections may come from each subnet of
// a given size. Every distinct subnet of that size is counted separately.
type ConnLimitPerSubnet struct {
	// PrefixLength defines how big the subnet is. For example, a /24 subnet
	// has a PrefixLength of 24: all IPs that share the same 24-bit prefix are
	// in the same subnet and are bound to the same limit.
	PrefixLength int
	// ConnCount is the maximum number of connections allowed for each subnet.
	ConnCount int
}

// NetworkPrefixLimit sets the connection limit for one specific network
// prefix. When an address matches a NetworkPrefixLimit, that limit is used
// instead of the per-subnet ConnLimitPerSubnet limits.
type NetworkPrefixLimit struct {
	// Network is the network prefix to which this limit applies.
	Network netip.Prefix

	// ConnCount is the maximum number of connections allowed for this network
	// prefix as a whole.
	ConnCount int
}

var (
	// defaultMaxConcurrentConns is the base per-subnet connection cap. It is
	// 8 for now so that it matches the number of concurrent dials we may do
	// in swarm_dial.go; with future smart dialing work we should bring this
	// down.
	defaultMaxConcurrentConns = 8

	// defaultIP4Limit caps every individual IPv4 address (a /32 subnet) at
	// defaultMaxConcurrentConns connections.
	defaultIP4Limit = ConnLimitPerSubnet{
		PrefixLength: 32,
		ConnCount:    defaultMaxConcurrentConns,
	}

	// defaultIP6Limits caps every /56 at defaultMaxConcurrentConns and every
	// /48 at eight times that.
	defaultIP6Limits = []ConnLimitPerSubnet{
		{PrefixLength: 56, ConnCount: defaultMaxConcurrentConns},
		{PrefixLength: 48, ConnCount: 8 * defaultMaxConcurrentConns},
	}
)

// DefaultNetworkPrefixLimitV4 is the list of IPv4 network prefix limits a new
// connection limiter starts with. It leaves the loopback range 127.0.0.0/8
// effectively unlimited (math.MaxInt connections). A non-nil IPv4 slice passed
// to WithNetworkPrefixLimit is used instead of this list.
var DefaultNetworkPrefixLimitV4 = sortNetworkPrefixes([]NetworkPrefixLimit{
	{
		// Loopback address for v4 https://datatracker.ietf.org/doc/html/rfc6890#section-2.2.2
		Network:   netip.MustParsePrefix("127.0.0.0/8"),
		ConnCount: math.MaxInt, // Unlimited
	},
})

// DefaultNetworkPrefixLimitV6 is the list of IPv6 network prefix limits a new
// connection limiter starts with. It leaves the loopback address ::1
// effectively unlimited (math.MaxInt connections). A non-nil IPv6 slice passed
// to WithNetworkPrefixLimit is used instead of this list.
var DefaultNetworkPrefixLimitV6 = sortNetworkPrefixes([]NetworkPrefixLimit{
	{
		// Loopback address for v6 https://datatracker.ietf.org/doc/html/rfc6890#section-2.2.3
		Network:   netip.MustParsePrefix("::1/128"),
		ConnCount: math.MaxInt, // Unlimited
	},
})

// Network prefixes limits must be sorted by most specific to least specific.  This lets us
// actually use the more specific limits, otherwise only the less specific ones
// would be matched. e.g. 1.2.3.0/24 must come before 1.2.0.0/16.
func sortNetworkPrefixes( []NetworkPrefixLimit) []NetworkPrefixLimit {
	slices.SortStableFunc(, func(,  NetworkPrefixLimit) int {
		return .Network.Bits() - .Network.Bits()
	})
	return 
}

// WithNetworkPrefixLimit sets the limits for the number of connections allowed
// for a specific Network Prefix. Use this when you want to set higher limits
// for a specific subnet than the default limit per subnet.
func ( []NetworkPrefixLimit,  []NetworkPrefixLimit) Option {
	return func( *resourceManager) error {
		if  != nil {
			.connLimiter.networkPrefixLimitV4 = sortNetworkPrefixes()
		}
		if  != nil {
			.connLimiter.networkPrefixLimitV6 = sortNetworkPrefixes()
		}
		return nil
	}
}

// WithLimitPerSubnet sets the limits for the number of connections allowed per
// subnet. This will limit the number of connections per subnet if that subnet
// is not defined in the NetworkPrefixLimit option. Think of this as a default
// limit for any given subnet.
func ( []ConnLimitPerSubnet,  []ConnLimitPerSubnet) Option {
	return func( *resourceManager) error {
		if  != nil {
			.connLimiter.connLimitPerSubnetV4 = 
		}
		if  != nil {
			.connLimiter.connLimitPerSubnetV6 = 
		}
		return nil
	}
}

// connLimiter tracks live connection counts per IP range and decides whether a
// new connection fits within the configured limits.
type connLimiter struct {
	// mu is held by the methods that read or update the fields below.
	mu sync.Mutex

	// Explicit network prefix limits. If one of these matches an address it
	// takes precedence over the subnet limits. The limits must be sorted from
	// most to least specific. connsPerNetworkPrefixV4/V6 hold the live count
	// for the limit at the same index and are allocated lazily.
	networkPrefixLimitV4, networkPrefixLimitV6       []NetworkPrefixLimit
	connsPerNetworkPrefixV4, connsPerNetworkPrefixV6 []int

	// Subnet limits. ip4connsPerLimit/ip6connsPerLimit hold, for the limit at
	// the same index, a map from subnet prefix to its live count; they are
	// allocated lazily.
	connLimitPerSubnetV4, connLimitPerSubnetV6 []ConnLimitPerSubnet
	ip4connsPerLimit, ip6connsPerLimit         []map[netip.Prefix]int
}

// newConnLimiter returns a connLimiter configured with the package defaults.
// Network prefix limits come from DefaultNetworkPrefixLimitV4/V6. IPv4 subnet
// limits come from defaultIP4Limit and IPv6 subnet limits from
// defaultIP6Limits. Connection counters are allocated lazily on first use.
//
// The default slices are copied. The limiter updates its own limit slices
// (addNetworkPrefixLimit appends and re-sorts), and those updates must never
// reach the package-level defaults or another limiter that shares them.
func newConnLimiter() *connLimiter {
	return &connLimiter{
		networkPrefixLimitV4: slices.Clone(DefaultNetworkPrefixLimitV4),
		networkPrefixLimitV6: slices.Clone(DefaultNetworkPrefixLimitV6),

		connLimitPerSubnetV4: []ConnLimitPerSubnet{defaultIP4Limit},
		connLimitPerSubnetV6: slices.Clone(defaultIP6Limits),
	}
}

func ( *connLimiter) ( bool,  NetworkPrefixLimit) {
	.mu.Lock()
	defer .mu.Unlock()
	if  {
		.networkPrefixLimitV6 = append(.networkPrefixLimitV6, )
		.networkPrefixLimitV6 = sortNetworkPrefixes(.networkPrefixLimitV6)
	} else {
		.networkPrefixLimitV4 = append(.networkPrefixLimitV4, )
		.networkPrefixLimitV4 = sortNetworkPrefixes(.networkPrefixLimitV4)
	}
}

// addConn adds a connection for the given IP address. It returns true if the connection is allowed.
func ( *connLimiter) ( netip.Addr) bool {
	.mu.Lock()
	defer .mu.Unlock()
	 := .networkPrefixLimitV4
	 := .connsPerNetworkPrefixV4
	 := .connLimitPerSubnetV4
	 := .ip4connsPerLimit
	 := .Is6()
	if  {
		 = .networkPrefixLimitV6
		 = .connsPerNetworkPrefixV6
		 = .connLimitPerSubnetV6
		 = .ip6connsPerLimit
	}

	// Check Network Prefix limits first
	if len() == 0 && len() > 0 {
		// Initialize the counts
		 = make([]int, len())
		if  {
			.connsPerNetworkPrefixV6 = 
		} else {
			.connsPerNetworkPrefixV4 = 
		}
	}

	for ,  := range  {
		if .Network.Contains() {
			if []+1 > .ConnCount {
				return false
			}
			[]++
			// Done. If we find a match in the network prefix limits, we use
			// that and don't use the general subnet limits.
			return true
		}
	}

	if len() == 0 && len() > 0 {
		 = make([]map[netip.Prefix]int, len())
		if  {
			.ip6connsPerLimit = 
		} else {
			.ip4connsPerLimit = 
		}
	}

	for ,  := range  {
		,  := .Prefix(.PrefixLength)
		if  != nil {
			return false
		}
		,  := [][]
		if ! {
			if [] == nil {
				[] = make(map[netip.Prefix]int)
			}
			[][] = 0
		}
		if +1 > .ConnCount {
			return false
		}
	}

	// All limit checks passed, now we update the counts
	for ,  := range  {
		,  := .Prefix(.PrefixLength)
		[][]++
	}

	return true
}

func ( *connLimiter) ( netip.Addr) {
	.mu.Lock()
	defer .mu.Unlock()
	 := .networkPrefixLimitV4
	 := .connsPerNetworkPrefixV4
	 := .connLimitPerSubnetV4
	 := .ip4connsPerLimit
	 := .Is6()
	if  {
		 = .networkPrefixLimitV6
		 = .connsPerNetworkPrefixV6
		 = .connLimitPerSubnetV6
		 = .ip6connsPerLimit
	}

	// Check NetworkPrefix limits first
	if len() == 0 && len() > 0 {
		// Initialize just in case. We should have already initialized in
		// addConn, but if the callers calls rmConn first we don't want to panic
		 = make([]int, len())
		if  {
			.connsPerNetworkPrefixV6 = 
		} else {
			.connsPerNetworkPrefixV4 = 
		}
	}
	for ,  := range  {
		if .Network.Contains() {
			 := []
			if  <= 0 {
				log.Errorf("unexpected conn count for ip %s. Was this not added with addConn first?", )
				return
			}
			[]--
			// Done. We updated the count in the defined network prefix limit.
			return
		}
	}

	if len() == 0 && len() > 0 {
		// Initialize just in case. We should have already initialized in
		// addConn, but if the callers calls rmConn first we don't want to panic
		 = make([]map[netip.Prefix]int, len())
		if  {
			.ip6connsPerLimit = 
		} else {
			.ip4connsPerLimit = 
		}
	}

	for ,  := range  {
		,  := .Prefix(.PrefixLength)
		if  != nil {
			// Unexpected since we should have seen this IP before in addConn
			log.Errorf("unexpected error getting prefix: %v", )
			continue
		}
		,  := [][]
		if ! ||  == 0 {
			// Unexpected, but don't panic
			log.Errorf("unexpected conn count for %s ok=%v count=%v", , , )
			continue
		}
		[][]--
		if [][] <= 0 {
			delete([], )
		}
	}
}

const (
	// handshakeDuration is a higher-end estimate of how long a QUIC handshake
	// takes.
	handshakeDuration = 5 * time.Second

	// sourceAddressRPS is the refill rate of the source address verification
	// rate limiter: one token per 2*handshakeDuration. A spoofed address that
	// is not verified holds a connLimiter token for about handshakeDuration.
	// The slow refill deliberately trades extra latency (from address
	// verification) for a lower chance that spoofing succeeds in causing a
	// DoS.
	sourceAddressRPS = float64(time.Second) / float64(2*handshakeDuration)
)

// newVerifySourceAddressRateLimiter returns a rate limiter for verifying source addresses.
// The returned limiter allows maxAllowedConns / 2 unverified addresses to begin handshake.
// This ensures that in the event someone is spoofing IPs, 1/2 the maximum allowed connections
// will be able to connect, although they will have increased latency because of address
// verification.
func newVerifySourceAddressRateLimiter( *connLimiter) *rate.Limiter {
	 := make([]rate.PrefixLimit, 0, len(.networkPrefixLimitV4)+len(.networkPrefixLimitV6))
	for ,  := range .networkPrefixLimitV4 {
		 = append(, rate.PrefixLimit{
			Prefix: .Network,
			Limit:  rate.Limit{RPS: sourceAddressRPS, Burst: .ConnCount / 2},
		})
	}
	for ,  := range .networkPrefixLimitV6 {
		 = append(, rate.PrefixLimit{
			Prefix: .Network,
			Limit:  rate.Limit{RPS: sourceAddressRPS, Burst: .ConnCount / 2},
		})
	}

	 := make([]rate.SubnetLimit, 0, len(.connLimitPerSubnetV4))
	for ,  := range .connLimitPerSubnetV4 {
		 = append(, rate.SubnetLimit{
			PrefixLength: .PrefixLength,
			Limit:        rate.Limit{RPS: sourceAddressRPS, Burst: .ConnCount / 2},
		})
	}

	 := make([]rate.SubnetLimit, 0, len(.connLimitPerSubnetV6))
	for ,  := range .connLimitPerSubnetV6 {
		 = append(, rate.SubnetLimit{
			PrefixLength: .PrefixLength,
			Limit:        rate.Limit{RPS: sourceAddressRPS, Burst: .ConnCount / 2},
		})
	}

	return &rate.Limiter{
		NetworkPrefixLimits: ,
		SubnetRateLimiter: rate.SubnetLimiter{
			IPv4SubnetLimits: ,
			IPv6SubnetLimits: ,
			GracePeriod:      1 * time.Minute,
		},
	}
}