package autonatv2
import (
"context"
"fmt"
"os"
"runtime/debug"
"sync"
"time"
"math/rand/v2"
"github.com/libp2p/go-libp2p/core/host"
"github.com/libp2p/go-libp2p/core/network"
"github.com/libp2p/go-libp2p/core/peer"
"github.com/libp2p/go-libp2p/p2p/protocol/autonatv2/pb"
"github.com/libp2p/go-msgio/pbio"
ma "github.com/multiformats/go-multiaddr"
)
// client is the requesting side of AutoNAT v2. It sends dial requests to
// servers over DialProtocol and accepts their dial-back streams on
// DialBackProtocol.
type client struct {
	// host is set by Start and used to open request streams.
	host host.Host
	// dialData is the padding buffer sent in DialDataResponse messages when a
	// server asks for dial data.
	dialData []byte
	// metricsTracer, when non-nil, is notified about every completed request.
	metricsTracer MetricsTracer

	// mu guards dialBackQueues.
	mu sync.Mutex
	// dialBackQueues maps the nonce of each in-flight request to the channel
	// on which handleDialBack delivers the local address of the dial-back
	// connection.
	dialBackQueues map[uint64]chan ma.Multiaddr
}
// newClient builds a client from the given settings. It allocates a
// 4000-byte dial-data buffer and an empty nonce table; the host is attached
// later by Start.
func newClient(s *autoNATSettings) *client {
	c := &client{metricsTracer: s.metricsTracer}
	c.dialData = make([]byte, 4000)
	c.dialBackQueues = make(map[uint64]chan ma.Multiaddr)
	return c
}
// Start binds the client to h and registers the handler for incoming
// dial-back streams on DialBackProtocol.
func (ac *client) Start(h host.Host) {
	ac.host = h
	h.SetStreamHandler(DialBackProtocol, ac.handleDialBack)
}
// Close unregisters the dial-back stream handler installed by Start.
func (ac *client) Close() {
	h := ac.host
	h.RemoveStreamHandler(DialBackProtocol)
}
// GetReachability asks peer p to dial back the addresses in reqs and returns
// the outcome. When a metrics tracer is configured, every completed request
// is reported to it, whether or not it failed.
func (ac *client) GetReachability(ctx context.Context, p peer.ID, reqs []Request) (Result, error) {
	res, err := ac.getReachability(ctx, p, reqs)
	if tracer := ac.metricsTracer; tracer != nil {
		tracer.ClientCompletedRequest(reqs, res, err)
	}
	return res, err
}
// getReachability runs one dial-request exchange with peer p over a new
// DialProtocol stream and turns the server's answer into a Result.
//
// The exchange:
//  1. write a DialRequest carrying every address in reqs and a random nonce;
//  2. if the server answers with a DialDataRequest, validate it, send the
//     requested amount of padding, and read the next message;
//  3. the final message must be a DialResponse.
//
// ctx and the stream deadline both bound the exchange by streamTimeout.
// E_DIAL_REFUSED yields Result{AllAddrsRefused: true} and a nil error. The
// call returns an error in these cases:
//   - any other non-OK response status;
//   - an UNUSED dial status;
//   - an out-of-range address index;
//   - an unexpected message type.
//
// The stream is reset on failures during the exchange and closed on return.
func (ac *client) getReachability(ctx context.Context, p peer.ID, reqs []Request) (Result, error) {
	ctx, cancel := context.WithTimeout(ctx, streamTimeout)
	defer cancel()

	s, err := ac.host.NewStream(ctx, p, DialProtocol)
	if err != nil {
		return Result{}, fmt.Errorf("open %s stream failed: %w", DialProtocol, err)
	}

	if err := s.Scope().SetService(ServiceName); err != nil {
		s.Reset()
		return Result{}, fmt.Errorf("attach stream %s to service %s failed: %w", DialProtocol, ServiceName, err)
	}

	if err := s.Scope().ReserveMemory(maxMsgSize, network.ReservationPriorityAlways); err != nil {
		s.Reset()
		return Result{}, fmt.Errorf("failed to reserve memory for stream %s: %w", DialProtocol, err)
	}
	defer s.Scope().ReleaseMemory(maxMsgSize)

	s.SetDeadline(time.Now().Add(streamTimeout))
	defer s.Close()

	// Register a queue for this request's nonce so handleDialBack can hand
	// over the local address of the dial-back connection. A buffer of one
	// means only the first dial-back is delivered.
	nonce := rand.Uint64()
	ch := make(chan ma.Multiaddr, 1)
	ac.mu.Lock()
	ac.dialBackQueues[nonce] = ch
	ac.mu.Unlock()
	defer func() {
		ac.mu.Lock()
		delete(ac.dialBackQueues, nonce)
		ac.mu.Unlock()
	}()

	msg := newDialRequest(reqs, nonce)
	w := pbio.NewDelimitedWriter(s)
	if err := w.WriteMsg(&msg); err != nil {
		s.Reset()
		return Result{}, fmt.Errorf("dial request write failed: %w", err)
	}

	r := pbio.NewDelimitedReader(s, maxMsgSize)
	if err := r.ReadMsg(&msg); err != nil {
		s.Reset()
		return Result{}, fmt.Errorf("dial msg read failed: %w", err)
	}

	switch {
	case msg.GetDialResponse() != nil:
		// Server answered directly; nothing more to exchange.
	case msg.GetDialDataRequest() != nil:
		if err := validateDialDataRequest(reqs, &msg); err != nil {
			s.Reset()
			return Result{}, fmt.Errorf("invalid dial data request: %s %w", s.Conn().RemoteMultiaddr(), err)
		}
		// msg is reused as the envelope for the outgoing DialDataResponses.
		if err := sendDialData(ac.dialData, int(msg.GetDialDataRequest().GetNumBytes()), w, &msg); err != nil {
			s.Reset()
			return Result{}, fmt.Errorf("dial data send failed: %w", err)
		}
		if err := r.ReadMsg(&msg); err != nil {
			s.Reset()
			return Result{}, fmt.Errorf("dial response read failed: %w", err)
		}
		if msg.GetDialResponse() == nil {
			s.Reset()
			return Result{}, fmt.Errorf("invalid response type: %T", msg.Msg)
		}
	default:
		s.Reset()
		return Result{}, fmt.Errorf("invalid msg type: %T", msg.Msg)
	}

	resp := msg.GetDialResponse()
	if resp.GetStatus() != pb.DialResponse_OK {
		if resp.GetStatus() == pb.DialResponse_E_DIAL_REFUSED {
			return Result{AllAddrsRefused: true}, nil
		}
		return Result{}, fmt.Errorf("dial request failed: response status %d %s", resp.GetStatus(),
			pb.DialResponse_ResponseStatus_name[int32(resp.GetStatus())])
	}
	if resp.GetDialStatus() == pb.DialStatus_UNUSED {
		return Result{}, fmt.Errorf("invalid response: invalid dial status UNUSED")
	}
	if int(resp.AddrIdx) >= len(reqs) {
		return Result{}, fmt.Errorf("invalid response: addr index out of range: %d [0-%d)", resp.AddrIdx, len(reqs))
	}

	// Collect the dial-back address for every status that newResult checks it
	// against. This is OK and also E_DIAL_BACK_ERROR.
	//
	// handleDialBack pushes the address into ch before it writes its own
	// response, so it can be present even when the server reports that the
	// dial-back stream errored. Without waiting here, dialBackAddr stays nil
	// for E_DIAL_BACK_ERROR and newResult can never accept that status.
	//
	// The wait is bounded by dialBackStreamTimeout and ctx.
	var dialBackAddr ma.Multiaddr
	if st := resp.GetDialStatus(); st == pb.DialStatus_OK || st == pb.DialStatus_E_DIAL_BACK_ERROR {
		timer := time.NewTimer(dialBackStreamTimeout)
		select {
		case at := <-ch:
			dialBackAddr = at
		case <-ctx.Done():
		case <-timer.C:
		}
		timer.Stop()
	}
	return ac.newResult(resp, reqs, dialBackAddr)
}
// validateDialDataRequest checks a server's DialDataRequest before any
// padding is sent. The request is rejected in three cases, tested in order:
//   - its address index is outside reqs;
//   - it asks for more than maxHandshakeSizeBytes;
//   - the referenced request did not opt in to sending dial data.
func validateDialDataRequest(reqs []Request, msg *pb.Message) error {
	ddReq := msg.GetDialDataRequest()
	idx := int(ddReq.AddrIdx)
	switch {
	case idx >= len(reqs):
		return fmt.Errorf("addr index out of range: %d [0-%d)", idx, len(reqs))
	case ddReq.NumBytes > maxHandshakeSizeBytes:
		return fmt.Errorf("requested data too high: %d", ddReq.NumBytes)
	case !reqs[idx].SendDialData:
		return fmt.Errorf("low priority addr: %s index %d", reqs[idx].Addr, idx)
	}
	return nil
}
// newResult maps a server's DialResponse to a Result for the address the
// server chose (resp.AddrIdx).
//
// The dial statuses map as follows:
//   - OK and E_DIAL_BACK_ERROR mean Public. This is accepted only if
//     dialBackAddr, the local address of the dial-back connection we received,
//     is consistent with the dialed address.
//   - E_DIAL_ERROR means Private.
//   - Any other status is logged and returned as an error.
func (ac *client) newResult(resp *pb.DialResponse, reqs []Request, dialBackAddr ma.Multiaddr) (Result, error) {
	idx := int(resp.AddrIdx)
	if idx >= len(reqs) {
		// Cheap re-check; callers are expected to have validated the index.
		return Result{}, fmt.Errorf("addrs index(%d) greater than len(reqs)(%d)", idx, len(reqs))
	}
	addr := reqs[idx].Addr

	var reachability network.Reachability
	switch status := resp.DialStatus; status {
	case pb.DialStatus_OK, pb.DialStatus_E_DIAL_BACK_ERROR:
		// Either no dial-back arrived or its address does not match what the
		// server says it dialed: the response cannot be trusted.
		if !ac.areAddrsConsistent(dialBackAddr, addr) {
			if status == pb.DialStatus_OK {
				return Result{}, fmt.Errorf("invalid response: dialBackAddr: %s, respAddr: %s", dialBackAddr, addr)
			}
			return Result{}, fmt.Errorf("dial-back stream error: dialBackAddr: %s, respAddr: %s", dialBackAddr, addr)
		}
		reachability = network.ReachabilityPublic
	case pb.DialStatus_E_DIAL_ERROR:
		reachability = network.ReachabilityPrivate
	default:
		log.Warn("invalid status code received in response",
			"address", addr,
			"dial_status", resp.DialStatus)
		return Result{}, fmt.Errorf("invalid response: invalid status code for addr %s: %d", addr, resp.DialStatus)
	}

	return Result{Addr: addr, Idx: idx, Reachability: reachability}, nil
}
// sendDialData writes numBytes bytes of padding to w as a sequence of
// DialDataResponse messages. Each message carries at most len(dialData) bytes
// taken from the front of dialData. The last message is truncated so that the
// total sent equals numBytes.
//
// *msg is overwritten and reused as the envelope for every write. A
// non-positive numBytes sends nothing.
//
// An empty dialData with a positive numBytes is rejected with an error,
// because no amount of writes could make progress toward numBytes.
func sendDialData(dialData []byte, numBytes int, w pbio.Writer, msg *pb.Message) error {
	if numBytes > 0 && len(dialData) == 0 {
		return fmt.Errorf("cannot send %d bytes of dial data: dial data buffer is empty", numBytes)
	}
	ddResp := &pb.DialDataResponse{Data: dialData}
	*msg = pb.Message{
		Msg: &pb.Message_DialDataResponse{
			DialDataResponse: ddResp,
		},
	}
	// Subtract what was actually put in the message, so the final truncated
	// chunk brings remain to exactly zero.
	for remain := numBytes; remain > 0; remain -= len(ddResp.Data) {
		if remain < len(ddResp.Data) {
			ddResp.Data = ddResp.Data[:remain]
		}
		if err := w.WriteMsg(msg); err != nil {
			return fmt.Errorf("write failed: %w", err)
		}
	}
	return nil
}
// newDialRequest builds the DialRequest message carrying the binary encoding
// of each requested address, in the order of reqs, plus the nonce the server
// must echo on the dial-back stream.
func newDialRequest(reqs []Request, nonce uint64) pb.Message {
	addrs := make([][]byte, 0, len(reqs))
	for _, req := range reqs {
		addrs = append(addrs, req.Addr.Bytes())
	}
	dialReq := &pb.DialRequest{Addrs: addrs, Nonce: nonce}
	return pb.Message{Msg: &pb.Message_DialRequest{DialRequest: dialReq}}
}
// handleDialBack is the DialBackProtocol stream handler registered by Start.
// It reads one pb.DialBack message and looks up its nonce in dialBackQueues.
// If a request is waiting on that nonce, it hands over the local multiaddr of
// the connection the server dialed. It then replies with an empty
// pb.DialBackResponse.
//
// Unknown nonces and repeated dial-backs for the same nonce are logged at
// debug level and the stream is reset.
func (ac *client ) handleDialBack (s network .Stream ) {
// Registered first, so it runs last: after the deferred s.Close() and
// ReleaseMemory below. It swallows panics (printing them to stderr) and
// always resets the stream on exit.
// NOTE(review): Reset runs even after a successful Close; confirm intended.
defer func () {
if rerr := recover (); rerr != nil {
fmt .Fprintf (os .Stderr , "caught panic: %s\n%s\n" , rerr , debug .Stack ())
}
s .Reset ()
}()
// The explicit s.Reset() calls before each early return below are
// redundant with the deferred Reset, but harmless.
if err := s .Scope ().SetService (ServiceName ); err != nil {
log .Debug ("failed to attach stream to service" ,
"service_name" , ServiceName ,
"error" , err )
s .Reset ()
return
}
if err := s .Scope ().ReserveMemory (dialBackMaxMsgSize , network .ReservationPriorityAlways ); err != nil {
log .Debug ("failed to reserve memory for stream" ,
"protocol" , DialBackProtocol ,
"error" , err )
s .Reset ()
return
}
defer s .Scope ().ReleaseMemory (dialBackMaxMsgSize )
s .SetDeadline (time .Now ().Add (dialBackStreamTimeout ))
defer s .Close ()
r := pbio .NewDelimitedReader (s , dialBackMaxMsgSize )
var msg pb .DialBack
if err := r .ReadMsg (&msg ); err != nil {
log .Debug ("failed to read dialback message" ,
"remote_peer" , s .Conn ().RemotePeer (),
"error" , err )
s .Reset ()
return
}
nonce := msg .GetNonce ()
// Only the map read is done under the lock. The requester may delete the
// entry right after, but sending on the buffered channel below never blocks
// either way.
ac .mu .Lock ()
ch := ac .dialBackQueues [nonce ]
ac .mu .Unlock ()
if ch == nil {
log .Debug ("dialback received with invalid nonce" ,
"local_multiaddr" , s .Conn ().LocalMultiaddr (),
"remote_peer" , s .Conn ().RemotePeer (),
"nonce" , nonce )
s .Reset ()
return
}
// Non-blocking delivery. getReachability creates ch with a buffer of one,
// so only the first dial-back for a nonce is accepted.
// Delivery happens before the response is written, so the requester sees
// the address even if the write below fails.
select {
case ch <- s .Conn ().LocalMultiaddr ():
default :
log .Debug ("multiple dialbacks received" ,
"local_multiaddr" , s .Conn ().LocalMultiaddr (),
"remote_peer" , s .Conn ().RemotePeer ())
s .Reset ()
return
}
w := pbio .NewDelimitedWriter (s )
res := pb .DialBackResponse {}
if err := w .WriteMsg (&res ); err != nil {
log .Debug ("failed to write dialback response" ,
"error" , err )
s .Reset ()
}
}
var tlsWSAddr = ma .StringCast ("/tls/ws" )
func normalizeMultiaddr(addr ma .Multiaddr ) ma .Multiaddr {
addr = removeTrailing (addr , ma .P_P2P )
addr = removeTrailing (addr , ma .P_CERTHASH )
for i , c := range addr {
if c .Code () == ma .P_WSS {
na := make (ma .Multiaddr , 0 , len (addr )+1 )
na = append (na , addr [:i ]...)
na = append (na , tlsWSAddr ...)
na = append (na , addr [i +1 :]...)
addr = na
break
}
}
for i , c := range addr {
if c .Code () == ma .P_SNI {
na := make (ma .Multiaddr , 0 , len (addr )-1 )
na = append (na , addr [:i ]...)
na = append (na , addr [i +1 :]...)
addr = na
break
}
}
return addr
}
// removeTrailing strips every trailing component whose protocol code equals
// protocolCode. It returns a prefix of addr, or nil if nothing remains.
func removeTrailing(addr ma.Multiaddr, protocolCode int) ma.Multiaddr {
	end := len(addr)
	for end > 0 && addr[end-1].Code() == protocolCode {
		end--
	}
	if end == 0 {
		return nil
	}
	return addr[:end]
}
// areAddrsConsistent reports whether connLocalAddr, the local address of a
// dial-back connection, has the same protocol stack as dialedAddr, the address
// the server claims it dialed.
//
// Both addresses are normalized first, and components are compared by
// protocol code only, never by value. The first component of dialedAddr may
// be a DNS name:
//   - /dns and /dnsaddr match /ip4 or /ip6;
//   - /dns4 matches /ip4;
//   - /dns6 matches /ip6.
//
// An empty input is never consistent.
func (ac *client) areAddrsConsistent(connLocalAddr, dialedAddr ma.Multiaddr) bool {
	if len(connLocalAddr) == 0 || len(dialedAddr) == 0 {
		return false
	}
	localProtos := normalizeMultiaddr(connLocalAddr).Protocols()
	dialedProtos := normalizeMultiaddr(dialedAddr).Protocols()
	if len(localProtos) != len(dialedProtos) {
		return false
	}

	for i := range localProtos {
		local, dialed := localProtos[i].Code, dialedProtos[i].Code
		if i > 0 {
			if local != dialed {
				return false
			}
			continue
		}

		// First component: allow a DNS name to resolve to an IP of the
		// matching family.
		var match bool
		switch dialed {
		case ma.P_DNS, ma.P_DNSADDR:
			match = local == ma.P_IP4 || local == ma.P_IP6
		case ma.P_DNS4:
			match = local == ma.P_IP4
		case ma.P_DNS6:
			match = local == ma.P_IP6
		default:
			match = local == dialed
		}
		if !match {
			return false
		}
	}
	return true
}
The pages are generated with Golds v0.8.2. (GOOS=linux GOARCH=amd64)
Golds is a Go 101 project developed by Tapir Liu.
PRs and bug reports are welcome and can be submitted to the issue list.
Please follow @zigo_101 (reachable from the left QR code) to get the latest news of Golds.