package rate
import (
"container/heap"
"net/netip"
"slices"
"sync"
"time"
"github.com/libp2p/go-libp2p/core/network"
manet "github.com/multiformats/go-multiaddr/net"
"golang.org/x/time/rate"
)
// Limit describes a token-bucket rate: RPS is passed to rate.NewLimiter as the
// refill rate and Burst as the bucket size. Limiter treats an RPS of 0 in
// GlobalLimit or in a NetworkPrefixLimits entry as "no limit"; SubnetLimiter
// does not special-case a zero RPS.
type Limit struct {
	// RPS is the refill rate in tokens (requests) per second.
	RPS float64
	// Burst is the maximum number of tokens the bucket can hold.
	Burst int
}
// PrefixLimit applies one shared Limit bucket to every address inside Prefix.
// Limiter checks these from the most specific prefix (most bits) to the least
// specific.
type PrefixLimit struct {
	// Prefix is the network whose addresses share this limit's bucket.
	Prefix netip.Prefix
	Limit
}
// SubnetLimit applies a Limit per subnet of the given prefix length: each
// distinct subnet an address falls into gets its own bucket.
type SubnetLimit struct {
	// PrefixLength is the number of leading address bits that identify a
	// subnet (e.g. 24 groups IPv4 addresses by /24).
	PrefixLength int
	Limit
}
// Limiter rate-limits requests by remote IP address. Its zero value is usable:
// buckets are built lazily on first use, and with no configuration every
// request is allowed.
//
// An address matching one or more NetworkPrefixLimits must pass every matching
// prefix bucket and then skips the subnet and global limits. Any other address
// must pass SubnetRateLimiter and then GlobalLimit.
type Limiter struct {
	// NetworkPrefixLimits are per-network limits. init replaces this slice with
	// a copy sorted from the most specific prefix to the least specific.
	NetworkPrefixLimits []PrefixLimit
	// GlobalLimit is shared by all addresses that match no NetworkPrefixLimits
	// entry. An RPS of 0 disables it.
	GlobalLimit Limit
	// SubnetRateLimiter applies per-subnet limits to addresses that match no
	// NetworkPrefixLimits entry.
	SubnetRateLimiter SubnetLimiter

	// initOnce guards the lazy construction of the buckets below.
	initOnce sync.Once
	// globalBucket enforces GlobalLimit.
	globalBucket *rate.Limiter
	// networkPrefixBuckets[i] enforces NetworkPrefixLimits[i] (after sorting).
	// NOTE(review): the two slices are index-aligned, so the exported fields
	// should not be modified after the Limiter is first used.
	networkPrefixBuckets []*rate.Limiter
}
// init builds the limiter's buckets on first use; later calls are no-ops.
//
// A Limit whose RPS is 0 yields an infinite-rate bucket, i.e. no limit. The
// network prefix limits are replaced by a sorted copy (most specific prefix
// first) so Allow can check them in that order, and networkPrefixBuckets is
// built index-aligned with that copy.
func (r *Limiter) init() {
	r.initOnce.Do(func() {
		newBucket := func(l Limit) *rate.Limiter {
			if l.RPS == 0 {
				return rate.NewLimiter(rate.Inf, 0)
			}
			return rate.NewLimiter(rate.Limit(l.RPS), l.Burst)
		}

		r.globalBucket = newBucket(r.GlobalLimit)

		// Sort a private copy so the caller's slice is left untouched.
		sorted := slices.Clone(r.NetworkPrefixLimits)
		slices.SortFunc(sorted, func(a, b PrefixLimit) int { return b.Prefix.Bits() - a.Prefix.Bits() })
		r.NetworkPrefixLimits = sorted

		buckets := make([]*rate.Limiter, len(sorted))
		for idx, pl := range sorted {
			buckets[idx] = newBucket(pl.Limit)
		}
		r.networkPrefixBuckets = buckets
	})
}
// Limit wraps a stream handler so that each incoming stream is first checked
// with Allow using the remote peer's IP address. Denied streams are reset with
// network.StreamRateLimited and never reach f.
//
// If the remote multiaddr carries no usable IP, the zero netip.Addr is used.
// It matches no network prefix, so it is subject to the subnet and global
// limits.
func (r *Limiter) Limit(f func(s network.Stream)) func(s network.Stream) {
	r.init()
	return func(s network.Stream) {
		var ipAddr netip.Addr
		if ip, err := manet.ToIP(s.Conn().RemoteMultiaddr()); err == nil {
			if addr, ok := netip.AddrFromSlice(ip); ok {
				// Unmap so an IPv4-mapped IPv6 address (::ffff:a.b.c.d) is limited
				// as the IPv4 address it represents: netip prefixes of one family
				// never contain addresses of the other.
				ipAddr = addr.Unmap()
			}
		}
		if !r.Allow(ipAddr) {
			// Best effort: the stream is dropped whether or not the reset succeeds.
			_ = s.ResetWithError(network.StreamRateLimited)
			return
		}
		f(s)
	}
}
// Allow reports whether a request from ipAddr may proceed.
//
// NetworkPrefixLimits are consulted from the most specific prefix to the least.
// Every matching prefix bucket must allow the request. A request that matched
// at least one prefix then bypasses the subnet and global limits. Any other
// request must pass SubnetRateLimiter and then the global bucket.
//
// Buckets checked before a denial are not refunded.
func (r *Limiter) Allow(ipAddr netip.Addr) bool {
	r.init()

	matchedPrefix := false
	for idx := range r.NetworkPrefixLimits {
		if !r.NetworkPrefixLimits[idx].Prefix.Contains(ipAddr) {
			continue
		}
		if !r.networkPrefixBuckets[idx].Allow() {
			return false
		}
		matchedPrefix = true
	}
	if matchedPrefix {
		return true
	}

	// The global bucket is only charged once the subnet limiter has allowed.
	return r.SubnetRateLimiter.Allow(ipAddr, time.Now()) && r.globalBucket.Allow()
}
// SubnetLimiter enforces per-subnet token buckets, keyed by the subnet an
// address falls into for each configured prefix length. Its zero value is
// usable and allows everything. Allow serializes on mx, so concurrent calls are
// safe.
type SubnetLimiter struct {
	// IPv4SubnetLimits apply to IPv4 addresses. init stores them sorted from the
	// longest prefix length to the shortest.
	IPv4SubnetLimits []SubnetLimit
	// IPv6SubnetLimits apply to every address for which Is4 reports false,
	// including the zero netip.Addr. init sorts them like IPv4SubnetLimits.
	IPv6SubnetLimits []SubnetLimit
	// GracePeriod is how long a subnet's bucket is kept after it would be full
	// again. Once a bucket expires, the subnet gets a new bucket on its next
	// request.
	GracePeriod time.Duration

	// initOnce guards the lazy sorting and heap allocation in init.
	initOnce sync.Once
	// mx guards ipv4Heaps and ipv6Heaps.
	mx sync.Mutex
	// ipv4Heaps[i] holds the live buckets for IPv4SubnetLimits[i], ordered by
	// expiry.
	ipv4Heaps []*bucketHeap
	// ipv6Heaps[i] holds the live buckets for IPv6SubnetLimits[i], ordered by
	// expiry.
	ipv6Heaps []*bucketHeap
}
// init lazily sorts the subnet limits from the longest prefix length to the
// shortest and allocates one empty bucket heap per limit. It runs at most once.
//
// The limit slices are cloned before sorting so that a slice supplied by the
// caller (possibly shared elsewhere) is never reordered in place. This matches
// how Limiter.init treats NetworkPrefixLimits.
func (s *SubnetLimiter) init() {
	s.initOnce.Do(func() {
		byLongestPrefix := func(a, b SubnetLimit) int { return b.PrefixLength - a.PrefixLength }

		s.IPv4SubnetLimits = slices.Clone(s.IPv4SubnetLimits)
		slices.SortFunc(s.IPv4SubnetLimits, byLongestPrefix)
		s.IPv6SubnetLimits = slices.Clone(s.IPv6SubnetLimits)
		slices.SortFunc(s.IPv6SubnetLimits, byLongestPrefix)

		// newHeaps returns n empty, initialized bucket heaps.
		newHeaps := func(n int) []*bucketHeap {
			heaps := make([]*bucketHeap, n)
			for i := range heaps {
				heaps[i] = &bucketHeap{
					prefixBucket:  make([]prefixBucketWithExpiry, 0),
					prefixToIndex: make(map[netip.Prefix]int),
				}
				heap.Init(heaps[i])
			}
			return heaps
		}
		s.ipv4Heaps = newHeaps(len(s.IPv4SubnetLimits))
		s.ipv6Heaps = newHeaps(len(s.IPv6SubnetLimits))
	})
}
// Allow reports whether ipAddr may proceed at time now under every subnet limit
// for its address family, charging one token per limit.
//
// Expired buckets are dropped first, and a subnet with no live bucket gets a
// new one. Limits are checked from the longest prefix length to the shortest,
// and the first denial wins; tokens already taken for longer prefixes are not
// refunded. If ipAddr.Prefix rejects a configured prefix length (e.g. too long
// for the family), the request is denied.
func (s *SubnetLimiter) Allow(ipAddr netip.Addr, now time.Time) bool {
	s.init()
	s.mx.Lock()
	defer s.mx.Unlock()

	s.cleanUp(now)

	limits, heaps := s.IPv6SubnetLimits, s.ipv6Heaps
	if ipAddr.Is4() {
		limits, heaps = s.IPv4SubnetLimits, s.ipv4Heaps
	}

	for idx := range limits {
		subnetLimit := limits[idx]
		subnet, err := ipAddr.Prefix(subnetLimit.PrefixLength)
		if err != nil {
			return false
		}

		entry := heaps[idx].Get(subnet)
		if entry == (prefixBucketWithExpiry{}) {
			// No live bucket for this subnet: start a new one.
			entry = prefixBucketWithExpiry{
				tokenBucket: tokenBucket{rate.NewLimiter(rate.Limit(subnetLimit.RPS), subnetLimit.Burst)},
				Prefix:      subnet,
				Expiry:      now,
			}
		}
		if !entry.Allow() {
			return false
		}

		// Keep the bucket until it would be full again, plus the grace period.
		entry.Expiry = entry.FullAt(now).Add(s.GracePeriod)
		heaps[idx].Upsert(entry)
	}
	return true
}
// cleanUp drops, from every IPv4 heap and then every IPv6 heap, the buckets
// whose Expiry is at or before now. It mutates the heaps, so it is called with
// s.mx held (as Allow does).
func (s *SubnetLimiter) cleanUp(now time.Time) {
	for _, family := range [...][]*bucketHeap{s.ipv4Heaps, s.ipv6Heaps} {
		for _, bh := range family {
			bh.Expire(now)
		}
	}
}
// tokenBucket embeds a *rate.Limiter, so the limiter's methods are promoted,
// and adds FullAt.
type tokenBucket struct {
	*rate.Limiter
}
// FullAt returns the time at which the bucket, observed at now, will have
// refilled to its full Burst, assuming no further tokens are taken.
//
// The result is always well defined:
//   - a bucket that is already full returns now;
//   - if the bucket never refills (non-positive rate), the wait saturates at
//     the largest time.Duration;
//   - if the wait exceeds the largest time.Duration, it saturates there too.
//
// Without these guards, a zero rate divides by zero, and converting the
// resulting ±Inf/NaN (or any out-of-range float) to a Duration is
// implementation-dependent per the Go spec.
//
// NOTE(review): with a zero refill rate, a subnet bucket is therefore
// effectively never expired by SubnetLimiter; confirm this is the intended
// semantics for SubnetLimit.RPS == 0.
func (b *tokenBucket) FullAt(now time.Time) time.Time {
	// maxWait is the largest representable time.Duration (about 292 years).
	const maxWait = time.Duration(1<<63 - 1)

	tokensNeeded := float64(b.Burst()) - b.TokensAt(now)
	if tokensNeeded <= 0 {
		return now
	}
	refillRate := float64(b.Limit())
	if refillRate <= 0 {
		return now.Add(maxWait)
	}
	waitNanos := tokensNeeded / refillRate * float64(time.Second)
	if waitNanos >= float64(maxWait) {
		return now.Add(maxWait)
	}
	return now.Add(time.Duration(waitNanos))
}
// prefixBucketWithExpiry is the per-subnet state stored in a bucketHeap: the
// token bucket for one subnet plus the time after which it may be discarded.
// Its zero value is treated as "no bucket" by bucketHeap.Get and
// SubnetLimiter.Allow.
type prefixBucketWithExpiry struct {
	tokenBucket
	// Prefix is the subnet this bucket limits; it is the heap's lookup key.
	Prefix netip.Prefix
	// Expiry is when the bucket becomes eligible for removal by
	// bucketHeap.Expire.
	Expiry time.Time
}
// bucketHeap is a min-heap of subnet buckets ordered by Expiry (see Less).
// Swap, Push and Pop keep prefixToIndex in sync, so a bucket can also be found
// by its Prefix.
type bucketHeap struct {
	// prefixBucket is the heap storage; index 0 holds the earliest Expiry.
	prefixBucket []prefixBucketWithExpiry
	// prefixToIndex maps each tracked Prefix to its position in prefixBucket.
	prefixToIndex map[netip.Prefix]int
}

// Compile-time check that bucketHeap can be driven by container/heap.
var _ heap.Interface = (*bucketHeap)(nil)
// Upsert inserts b, or replaces the bucket already stored under b.Prefix, and
// restores heap order.
func (h *bucketHeap) Upsert(b prefixBucketWithExpiry) {
	idx, exists := h.prefixToIndex[b.Prefix]
	if !exists {
		heap.Push(h, b)
		return
	}
	h.prefixBucket[idx] = b
	// The Expiry may have changed, so the element may need to move.
	heap.Fix(h, idx)
}
// Get returns a copy of the bucket stored for prefix, or the zero
// prefixBucketWithExpiry when none is tracked.
func (h *bucketHeap) Get(prefix netip.Prefix) prefixBucketWithExpiry {
	idx, found := h.prefixToIndex[prefix]
	if !found {
		return prefixBucketWithExpiry{}
	}
	return h.prefixBucket[idx]
}
// Expire pops buckets, earliest first, while the root's Expiry is not after
// expiry (i.e. at or before it).
func (h *bucketHeap) Expire(expiry time.Time) {
	for h.Len() > 0 && !h.prefixBucket[0].Expiry.After(expiry) {
		heap.Pop(h)
	}
}
// Len reports how many buckets the heap holds (heap.Interface).
func (h *bucketHeap) Len() int {
	count := len(h.prefixBucket)
	return count
}
// Less orders buckets by Expiry, so the heap root is the one that expires
// first (heap.Interface).
func (h *bucketHeap) Less(i, j int) bool {
	return h.prefixBucket[j].Expiry.After(h.prefixBucket[i].Expiry)
}
// Swap exchanges two buckets and updates prefixToIndex for both
// (heap.Interface).
func (h *bucketHeap) Swap(i, j int) {
	buckets := h.prefixBucket
	buckets[i], buckets[j] = buckets[j], buckets[i]
	h.prefixToIndex[buckets[i].Prefix] = i
	h.prefixToIndex[buckets[j].Prefix] = j
}
// Push appends x, which must be a prefixBucketWithExpiry, and records its
// index (heap.Interface). Callers use heap.Push or Upsert rather than calling
// it directly.
func (h *bucketHeap) Push(x any) {
	bucket := x.(prefixBucketWithExpiry)
	h.prefixToIndex[bucket.Prefix] = len(h.prefixBucket)
	h.prefixBucket = append(h.prefixBucket, bucket)
}
// Pop removes and returns the last element of the backing slice and forgets
// its index (heap.Interface). It is meant to be called via heap.Pop, which
// first moves the root (earliest Expiry) to the end.
//
// The vacated slot is zeroed so that the backing array does not keep the
// popped bucket's *rate.Limiter reachable after the slice shrinks.
func (h *bucketHeap) Pop() any {
	n := len(h.prefixBucket)
	item := h.prefixBucket[n-1]
	h.prefixBucket[n-1] = prefixBucketWithExpiry{}
	h.prefixBucket = h.prefixBucket[:n-1]
	delete(h.prefixToIndex, item.Prefix)
	return item
}
// The pages are generated with Golds v0.8.2 (GOOS=linux GOARCH=amd64).
// Golds is a Go 101 project developed by Tapir Liu.
// PRs and bug reports are welcome and can be submitted to the issue list.
// Please follow @zigo_101 (reachable from the left QR code) to get the latest news of Golds.