package basichost
import (
"context"
"errors"
"fmt"
"io"
"log/slog"
"sync"
"time"
"github.com/libp2p/go-libp2p/core/connmgr"
"github.com/libp2p/go-libp2p/core/event"
"github.com/libp2p/go-libp2p/core/host"
"github.com/libp2p/go-libp2p/core/network"
"github.com/libp2p/go-libp2p/core/peer"
"github.com/libp2p/go-libp2p/core/peerstore"
"github.com/libp2p/go-libp2p/core/protocol"
"github.com/libp2p/go-libp2p/p2p/host/autonat"
"github.com/libp2p/go-libp2p/p2p/host/eventbus"
"github.com/libp2p/go-libp2p/p2p/host/pstoremanager"
"github.com/libp2p/go-libp2p/p2p/host/relaysvc"
"github.com/libp2p/go-libp2p/p2p/protocol/autonatv2"
relayv2 "github.com/libp2p/go-libp2p/p2p/protocol/circuitv2/relay"
"github.com/libp2p/go-libp2p/p2p/protocol/holepunch"
"github.com/libp2p/go-libp2p/p2p/protocol/identify"
"github.com/libp2p/go-libp2p/p2p/protocol/ping"
"github.com/prometheus/client_golang/prometheus"
logging "github.com/libp2p/go-libp2p/gologshim"
ma "github.com/multiformats/go-multiaddr"
msmux "github.com/multiformats/go-multistream"
)
// log is the package-level logger for the basic host.
var log = logging.Logger("basichost")

var (
	// DefaultNegotiationTimeout is the negotiation timeout a new host starts
	// with when HostOpts.NegotiationTimeout is zero. It is also the read
	// deadline applied when closing an optimistically-negotiated stream.
	DefaultNegotiationTimeout = 10 * time.Second

	// DefaultAddrsFactory returns its input unchanged. NewHost uses it when
	// HostOpts.AddrsFactory is nil.
	DefaultAddrsFactory = func(addrs []ma.Multiaddr) []ma.Multiaddr {
		return addrs
	}
)

// AddrsFactory receives a list of multiaddrs and returns the list to use
// in its place.
type AddrsFactory func([]ma.Multiaddr) []ma.Multiaddr
// BasicHost is the basic implementation of host.Host. It wraps a
// network.Network and adds protocol multiplexing, identify, and the
// optional services configured through HostOpts.
type BasicHost struct {
	// ctx is the host's lifetime context; Close cancels it via ctxCancel.
	ctx       context.Context
	ctxCancel context.CancelFunc

	// closeSync makes Close run its shutdown sequence at most once.
	closeSync sync.Once

	// refCount is waited on by Close before the resource manager is closed.
	refCount sync.WaitGroup

	network      network.Network
	psManager    *pstoremanager.PeerstoreManager
	mux          *msmux.MultistreamMuxer[protocol.ID]
	ids          identify.IDService
	hps          *holepunch.Service
	pings        *ping.PingService
	cmgr         connmgr.ConnManager
	eventbus     event.Bus
	relayManager *relaysvc.RelayManager

	// negtimeout is the deadline for negotiating incoming streams and the
	// default timeout of NewStream when its ctx has no deadline. A value
	// <= 0 disables both.
	negtimeout time.Duration

	emitters struct {
		evtLocalProtocolsUpdated event.Emitter
	}

	// autoNATMx guards autoNat, which is installed via SetAutoNat.
	autoNATMx sync.RWMutex
	autoNat   autonat.AutoNAT

	autonatv2      *autonatv2.AutoNAT
	addressManager *addrsManager
}

// Compile-time check that *BasicHost satisfies host.Host.
var _ host.Host = (*BasicHost)(nil)
// HostOpts holds the optional configuration accepted by NewHost. The zero
// value is valid: every nil/zero field falls back to a default.
type HostOpts struct {
	// EventBus is the bus the host and its services use. A new bus is
	// created when nil.
	EventBus event.Bus

	// MultistreamMuxer, if non-nil, replaces the freshly created muxer.
	MultistreamMuxer *msmux.MultistreamMuxer[protocol.ID]

	// NegotiationTimeout overrides DefaultNegotiationTimeout when non-zero.
	NegotiationTimeout time.Duration

	// AddrsFactory is handed to the address manager; DefaultAddrsFactory is
	// used when nil.
	AddrsFactory AddrsFactory

	// NATManager, if non-nil, is called with the host's network to build
	// the NAT manager given to the address manager.
	NATManager func(network.Network) NATManager

	// ConnManager is registered as a network notifee and closed by Close.
	// A connmgr.NullConnMgr is used when nil.
	ConnManager connmgr.ConnManager

	// EnablePing starts the ping service.
	EnablePing bool

	// EnableRelayService starts a relay manager configured with
	// RelayServiceOpts.
	EnableRelayService bool
	RelayServiceOpts   []relayv2.Option

	// UserAgent and ProtocolVersion are passed to the identify service.
	UserAgent       string
	ProtocolVersion string

	// DisableSignedPeerRecord is forwarded to identify and the address
	// manager.
	DisableSignedPeerRecord bool

	// EnableHolePunching starts the hole punching service configured with
	// HolePunchingOptions.
	EnableHolePunching  bool
	HolePunchingOptions []holepunch.Option

	// EnableMetrics adds metrics tracers to identify, hole punching and the
	// relay service, registered with PrometheusRegisterer; both values are
	// also passed to the address manager.
	EnableMetrics        bool
	PrometheusRegisterer prometheus.Registerer

	// NOTE(review): AutoNATv2MetricsTracker is not read anywhere in this
	// file; confirm where it is consumed.
	AutoNATv2MetricsTracker MetricsTracker

	// ObservedAddrsManager is passed to the address manager.
	ObservedAddrsManager ObservedAddrsManager

	// AutoNATv2, if non-nil, is started by Start, closed by Close and used
	// by the address manager as its AutoNAT v2 client.
	AutoNATv2 *autonatv2.AutoNAT
}
// NewHost constructs a BasicHost on top of network n.
//
// opts may be nil, in which case every option takes its default. The
// caller's HostOpts is never modified: defaults (such as a fresh event bus)
// and prepended metrics options are applied to a private shallow copy.
//
// On success the host's stream handler is registered with n. Start must
// still be called to start the peerstore manager, AutoNAT v2, the address
// manager and identify. If construction fails part-way, the components
// NewHost created so far are closed before the error is returned.
func NewHost(n network.Network, opts *HostOpts) (*BasicHost, error) {
	// Work on a copy so filling in defaults never leaks back into the
	// caller's struct. Otherwise, reusing one HostOpts for two hosts would
	// make the second host share the first host's event bus and accumulate
	// duplicate metrics options.
	var o HostOpts
	if opts != nil {
		o = *opts
	}
	opts = &o

	if opts.EventBus == nil {
		opts.EventBus = eventbus.NewBus()
	}
	psManager, err := pstoremanager.NewPeerstoreManager(n.Peerstore(), opts.EventBus, n)
	if err != nil {
		return nil, err
	}

	hostCtx, cancel := context.WithCancel(context.Background())
	h := &BasicHost{
		network:    n,
		psManager:  psManager,
		mux:        msmux.NewMultistreamMuxer[protocol.ID](),
		negtimeout: DefaultNegotiationTimeout,
		eventbus:   opts.EventBus,
		ctx:        hostCtx,
		ctxCancel:  cancel,
	}

	// Tear down everything built so far if we return an error below. Each
	// component is assigned to h only after it was created successfully, so
	// a non-nil field here always holds a usable value.
	// NOTE(review): natmgr is not closed if newAddrsManager fails; confirm
	// who owns it in that case.
	succeeded := false
	defer func() {
		if succeeded {
			return
		}
		if h.hps != nil {
			h.hps.Close()
		}
		if h.addressManager != nil {
			h.addressManager.Close()
		}
		if h.ids != nil {
			h.ids.Close()
		}
		if h.emitters.evtLocalProtocolsUpdated != nil {
			_ = h.emitters.evtLocalProtocolsUpdated.Close()
		}
		h.psManager.Close()
		cancel()
	}()

	emitter, err := h.eventbus.Emitter(&event.EvtLocalProtocolsUpdated{}, eventbus.Stateful)
	if err != nil {
		return nil, err
	}
	h.emitters.evtLocalProtocolsUpdated = emitter

	if opts.MultistreamMuxer != nil {
		h.mux = opts.MultistreamMuxer
	}

	idOpts := []identify.Option{
		identify.UserAgent(opts.UserAgent),
		identify.ProtocolVersion(opts.ProtocolVersion),
	}
	if opts.DisableSignedPeerRecord {
		idOpts = append(idOpts, identify.DisableSignedPeerRecord())
	}
	if opts.EnableMetrics {
		idOpts = append(idOpts,
			identify.WithMetricsTracer(
				identify.NewMetricsTracer(identify.WithRegisterer(opts.PrometheusRegisterer))))
	}
	// Assign through a local so a nil concrete result is never stored in
	// the IDService interface field (which would compare non-nil).
	ids, err := identify.NewIDService(h, idOpts...)
	if err != nil {
		return nil, fmt.Errorf("failed to create Identify service: %w", err)
	}
	h.ids = ids

	addrFactory := DefaultAddrsFactory
	if opts.AddrsFactory != nil {
		addrFactory = opts.AddrsFactory
	}

	var natmgr NATManager
	if opts.NATManager != nil {
		natmgr = opts.NATManager(h.Network())
	}

	h.autonatv2 = opts.AutoNATv2
	// Only convert to the interface when non-nil: a nil *autonatv2.AutoNAT
	// wrapped in an interface would compare non-nil.
	var anv2Client autonatv2Client
	if h.autonatv2 != nil {
		anv2Client = h.autonatv2
	}

	// Use the network's certificate-hash augmentation when it provides one;
	// otherwise pass addresses through unchanged.
	addCertHashesFunc := func(addrs []ma.Multiaddr) []ma.Multiaddr { return addrs }
	if swarm, ok := h.Network().(interface {
		AddCertHashes(addrs []ma.Multiaddr) []ma.Multiaddr
	}); ok {
		addCertHashesFunc = swarm.AddCertHashes
	}

	addrMgr, err := newAddrsManager(
		h.eventbus,
		natmgr,
		addrFactory,
		h.Network().ListenAddresses,
		addCertHashesFunc,
		opts.ObservedAddrsManager,
		anv2Client,
		opts.EnableMetrics,
		opts.PrometheusRegisterer,
		opts.DisableSignedPeerRecord,
		h.Peerstore().PrivKey(h.ID()),
		h.Peerstore(),
		h.ID(),
	)
	if err != nil {
		return nil, fmt.Errorf("failed to create address service: %w", err)
	}
	h.addressManager = addrMgr

	if opts.EnableHolePunching {
		if opts.EnableMetrics {
			tracer := holepunch.WithMetricsTracer(
				holepunch.NewMetricsTracer(holepunch.WithRegisterer(opts.PrometheusRegisterer)))
			opts.HolePunchingOptions = append([]holepunch.Option{tracer}, opts.HolePunchingOptions...)
		}
		hps, hpErr := holepunch.NewService(h, h.ids, h.addressManager.HolePunchAddrs, opts.HolePunchingOptions...)
		if hpErr != nil {
			return nil, fmt.Errorf("failed to create hole punch service: %w", hpErr)
		}
		h.hps = hps
	}

	if opts.NegotiationTimeout != 0 {
		h.negtimeout = opts.NegotiationTimeout
	}

	if opts.ConnManager == nil {
		h.cmgr = &connmgr.NullConnMgr{}
	} else {
		h.cmgr = opts.ConnManager
		n.Notify(h.cmgr.Notifee())
	}

	if opts.EnableRelayService {
		if opts.EnableMetrics {
			tracer := relayv2.WithMetricsTracer(
				relayv2.NewMetricsTracer(relayv2.WithRegisterer(opts.PrometheusRegisterer)))
			opts.RelayServiceOpts = append([]relayv2.Option{tracer}, opts.RelayServiceOpts...)
		}
		h.relayManager = relaysvc.NewRelayManager(h, opts.RelayServiceOpts...)
	}

	if opts.EnablePing {
		h.pings = ping.NewPingService(h)
	}

	n.SetStreamHandler(h.newStreamHandler)
	succeeded = true
	return h, nil
}
// Start launches the host's background services: the peerstore manager,
// AutoNAT v2 (when configured), the address manager (after registering its
// network notifee) and identify. Start failures of AutoNAT v2 and the
// address manager are logged rather than returned.
func (h *BasicHost) Start() {
	h.psManager.Start()

	if anv2 := h.autonatv2; anv2 != nil {
		if err := anv2.Start(h); err != nil {
			log.Error("autonat v2 failed to start", "err", err)
		}
	}

	am := h.addressManager
	h.Network().Notify(am.NetNotifee())
	if err := am.Start(); err != nil {
		log.Error("address service failed to start", "err", err)
	}

	h.ids.Start()
}
// newStreamHandler is registered with the network (see NewHost) and runs
// for every incoming stream. It negotiates the protocol with the
// multistream muxer, bounding negotiation by h.negtimeout when positive. It
// then records the negotiated protocol on the stream and calls the matching
// handler. On any failure the stream is reset and the handler is not
// called.
func (h *BasicHost) newStreamHandler(s network.Stream) {
	start := time.Now()
	if h.negtimeout > 0 {
		if err := s.SetDeadline(start.Add(h.negtimeout)); err != nil {
			log.Debug("setting stream deadline", "err", err)
			s.Reset()
			return
		}
	}

	protoID, handle, err := h.Mux().Negotiate(s)
	took := time.Since(start)
	if err != nil {
		// errors.Is also matches an io.EOF the muxer may have wrapped.
		if errors.Is(err, io.EOF) {
			// The remote went away before negotiating. Log at debug level
			// unless negotiation took longer than 10s.
			lvl := slog.LevelDebug
			if took > 10*time.Second {
				lvl = slog.LevelWarn
			}
			log.Log(context.Background(), lvl, "protocol EOF", "remote_peer", s.Conn().RemotePeer(), "duration", took)
		} else {
			log.Debug("protocol mux failed", "err", err, "duration", took, "stream_id", s.ID(), "remote_peer", s.Conn().RemotePeer(), "remote_multiaddr", s.Conn().RemoteMultiaddr())
		}
		s.ResetWithError(network.StreamProtocolNegotiationFailed)
		return
	}

	// Clear the negotiation deadline so it does not constrain the handler.
	if h.negtimeout > 0 {
		if err := s.SetDeadline(time.Time{}); err != nil {
			log.Debug("resetting stream deadline", "err", err)
			s.Reset()
			return
		}
	}

	if err := s.SetProtocol(protoID); err != nil {
		log.Debug("error setting stream protocol", "err", err)
		s.ResetWithError(network.StreamResourceLimitExceeded)
		return
	}

	log.Debug("negotiated", "protocol", protoID, "duration", took)
	handle(protoID, s)
}
// ID returns the local peer ID reported by the underlying network.
func (h *BasicHost) ID() peer.ID {
	net := h.Network()
	return net.LocalPeer()
}

// Peerstore returns the underlying network's peerstore.
func (h *BasicHost) Peerstore() peerstore.Peerstore {
	net := h.Network()
	return net.Peerstore()
}

// Network returns the network the host was constructed with.
func (h *BasicHost) Network() network.Network {
	net := h.network
	return net
}

// Mux returns the multistream muxer used to negotiate protocols.
func (h *BasicHost) Mux() protocol.Switch {
	sw := h.mux
	return sw
}

// IDService returns the host's identify service.
func (h *BasicHost) IDService() identify.IDService {
	svc := h.ids
	return svc
}

// EventBus returns the event bus shared by the host and its services.
func (h *BasicHost) EventBus() event.Bus {
	bus := h.eventbus
	return bus
}
// SetStreamHandler registers handler for streams negotiated with exactly
// protocol pid. It then emits an EvtLocalProtocolsUpdated event listing pid
// as added; the emit result is ignored.
func (h *BasicHost) SetStreamHandler(pid protocol.ID, handler network.StreamHandler) {
	wrapped := func(_ protocol.ID, rwc io.ReadWriteCloser) error {
		// newStreamHandler hands the muxer a network.Stream, so the
		// assertion is expected to hold.
		handler(rwc.(network.Stream))
		return nil
	}
	h.Mux().AddHandler(pid, wrapped)

	evt := event.EvtLocalProtocolsUpdated{Added: []protocol.ID{pid}}
	h.emitters.evtLocalProtocolsUpdated.Emit(evt)
}
// SetStreamHandlerMatch registers handler under pid for every protocol
// accepted by the predicate m. It then emits an EvtLocalProtocolsUpdated
// event listing pid as added; the emit result is ignored.
func (h *BasicHost) SetStreamHandlerMatch(pid protocol.ID, m func(protocol.ID) bool, handler network.StreamHandler) {
	wrapped := func(_ protocol.ID, rwc io.ReadWriteCloser) error {
		// newStreamHandler hands the muxer a network.Stream, so the
		// assertion is expected to hold.
		handler(rwc.(network.Stream))
		return nil
	}
	h.Mux().AddHandlerWithFunc(pid, m, wrapped)

	evt := event.EvtLocalProtocolsUpdated{Added: []protocol.ID{pid}}
	h.emitters.evtLocalProtocolsUpdated.Emit(evt)
}
// RemoveStreamHandler unregisters the handler for pid. It then emits an
// EvtLocalProtocolsUpdated event listing pid as removed; the emit result is
// ignored.
func (h *BasicHost) RemoveStreamHandler(pid protocol.ID) {
	sw := h.Mux()
	sw.RemoveHandler(pid)

	evt := event.EvtLocalProtocolsUpdated{Removed: []protocol.ID{pid}}
	h.emitters.evtLocalProtocolsUpdated.Emit(evt)
}
// NewStream opens a stream to peer p and negotiates one of pids on it.
//
// If ctx has no deadline and h.negtimeout is positive, the whole call is
// bounded by h.negtimeout. Unless ctx carries network.WithNoDial, the peer
// is first connected via Connect. The call then waits for identify to
// complete on the stream's connection.
//
// If the peerstore already lists one of pids as supported by p, that
// protocol is selected optimistically and the returned stream wraps a lazy
// multistream selector. Otherwise the protocols are negotiated eagerly and
// the chosen protocol is added to the peerstore. Whenever an error is
// returned after the stream was opened, the stream is reset.
func (h *BasicHost) NewStream(ctx context.Context, p peer.ID, pids ...protocol.ID) (str network.Stream, strErr error) {
	if _, hasDeadline := ctx.Deadline(); !hasDeadline && h.negtimeout > 0 {
		var cancel context.CancelFunc
		ctx, cancel = context.WithTimeout(ctx, h.negtimeout)
		defer cancel()
	}

	if nodial, _ := network.GetNoDial(ctx); !nodial {
		if err := h.Connect(ctx, peer.AddrInfo{ID: p}); err != nil {
			return nil, err
		}
	}

	s, err := h.Network().NewStream(network.WithNoDial(ctx, "already dialed"), p)
	if err != nil {
		if errors.Is(err, network.ErrNoConn) {
			// Keep the cause in the chain so callers can still match
			// network.ErrNoConn with errors.Is.
			return nil, fmt.Errorf("connection failed: %w", err)
		}
		return nil, fmt.Errorf("failed to open stream: %w", err)
	}
	defer func() {
		if strErr != nil && s != nil {
			s.ResetWithError(network.StreamProtocolNegotiationFailed)
		}
	}()

	select {
	case <-h.ids.IdentifyWait(s.Conn()):
	case <-ctx.Done():
		return nil, fmt.Errorf("identify failed to complete: %w", ctx.Err())
	}

	pref, err := h.preferredProtocol(p, pids)
	if err != nil {
		return nil, err
	}
	if pref != "" {
		if err := s.SetProtocol(pref); err != nil {
			// Use the same reset code as the eager path below for this failure.
			s.ResetWithError(network.StreamResourceLimitExceeded)
			return nil, err
		}
		return &streamWrapper{
			Stream: s,
			rw:     msmux.NewMSSelect(s, pref),
		}, nil
	}

	// Negotiate in a goroutine so ctx can interrupt it. The result is sent
	// over the channel rather than written into this function's variables.
	type selectResult struct {
		proto protocol.ID
		err   error
	}
	resCh := make(chan selectResult, 1)
	go func() {
		proto, selErr := msmux.SelectOneOf(pids, s)
		resCh <- selectResult{proto: proto, err: selErr}
	}()

	var selected protocol.ID
	select {
	case res := <-resCh:
		if res.err != nil {
			return nil, fmt.Errorf("failed to negotiate protocol: %w", res.err)
		}
		selected = res.proto
	case <-ctx.Done():
		// Reset the stream so the blocked SelectOneOf returns, then wait for
		// the goroutine so it does not outlive this call.
		s.ResetWithError(network.StreamProtocolNegotiationFailed)
		<-resCh
		return nil, fmt.Errorf("failed to negotiate protocol: %w", ctx.Err())
	}

	if err := s.SetProtocol(selected); err != nil {
		s.ResetWithError(network.StreamResourceLimitExceeded)
		return nil, err
	}
	// Best effort: the peerstore update result is intentionally ignored.
	_ = h.Peerstore().AddProtocols(p, selected)
	return s, nil
}
// preferredProtocol returns the first entry of the peerstore's
// SupportsProtocols result for p over pids. It returns "" if none are
// supported, and the peerstore error if the lookup fails.
func (h *BasicHost) preferredProtocol(p peer.ID, pids []protocol.ID) (protocol.ID, error) {
	supported, err := h.Peerstore().SupportsProtocols(p, pids...)
	if err != nil {
		return "", err
	}
	if len(supported) == 0 {
		return "", nil
	}
	return supported[0], nil
}
// Connect ensures the host is connected to pi.ID, dialing if necessary.
// pi.Addrs are first added to the peerstore with peerstore.TempAddrTTL.
// An existing connection is reused unless ctx forces a direct dial. A
// limited connection only counts when ctx allows limited connections.
func (h *BasicHost) Connect(ctx context.Context, pi peer.AddrInfo) error {
	h.Peerstore().AddAddrs(pi.ID, pi.Addrs, peerstore.TempAddrTTL)

	if forceDirect, _ := network.GetForceDirectDial(ctx); forceDirect {
		return h.dialPeer(ctx, pi.ID)
	}

	allowLimited, _ := network.GetAllowLimitedConn(ctx)
	switch h.Network().Connectedness(pi.ID) {
	case network.Connected:
		return nil
	case network.Limited:
		if allowLimited {
			return nil
		}
	}
	return h.dialPeer(ctx, pi.ID)
}
// dialPeer dials p through the network and waits until identify has
// completed on the resulting connection, or until ctx is done.
func (h *BasicHost) dialPeer(ctx context.Context, p peer.ID) error {
	self := h.ID()
	log.Debug("host dialing peer", "source_peer", self, "destination_peer", p)

	conn, err := h.Network().DialPeer(ctx, p)
	if err != nil {
		return fmt.Errorf("failed to dial: %w", err)
	}

	identified := h.ids.IdentifyWait(conn)
	select {
	case <-ctx.Done():
		return fmt.Errorf("identify failed to complete: %w", ctx.Err())
	case <-identified:
	}

	log.Debug("host finished dialing peer", "source_peer", self, "destination_peer", p)
	return nil
}
// ConnManager returns the host's connection manager.
func (h *BasicHost) ConnManager() connmgr.ConnManager {
	cm := h.cmgr
	return cm
}

// Addrs returns the addresses reported by the address manager's Addrs.
func (h *BasicHost) Addrs() []ma.Multiaddr {
	am := h.addressManager
	return am.Addrs()
}

// AllAddrs returns the address manager's direct addresses.
func (h *BasicHost) AllAddrs() []ma.Multiaddr {
	am := h.addressManager
	return am.DirectAddrs()
}

// ConfirmedAddrs returns the address manager's addresses grouped as
// reachable, unreachable and unknown.
func (h *BasicHost) ConfirmedAddrs() (reachable []ma.Multiaddr, unreachable []ma.Multiaddr, unknown []ma.Multiaddr) {
	reachable, unreachable, unknown = h.addressManager.ConfirmedAddrs()
	return reachable, unreachable, unknown
}
// SetAutoNat installs a as the host's AutoNAT service. It only takes effect
// while no AutoNAT is set; once one is installed, later calls are ignored.
func (h *BasicHost) SetAutoNat(a autonat.AutoNAT) {
	h.autoNATMx.Lock()
	defer h.autoNATMx.Unlock()

	if h.autoNat != nil {
		return
	}
	h.autoNat = a
}
// GetAutoNat returns the AutoNAT service installed with SetAutoNat, or nil
// if none has been set.
func (h *BasicHost) GetAutoNat() autonat.AutoNAT {
	// This is a pure read, so a shared lock suffices and concurrent readers
	// do not serialize on each other.
	h.autoNATMx.RLock()
	defer h.autoNATMx.RUnlock()
	return h.autoNat
}
// Reachability returns the host reachability stored in the address
// manager. If no value has been stored yet, it reports
// network.ReachabilityUnknown instead of dereferencing a nil pointer.
func (h *BasicHost) Reachability() network.Reachability {
	r := h.addressManager.hostReachability.Load()
	if r == nil {
		return network.ReachabilityUnknown
	}
	return *r
}
// Close shuts the host down. The shutdown sequence runs at most once, and
// Close always returns nil.
//
// In order, it cancels the host context and closes:
//   - the connection manager, identify, AutoNAT, the relay manager,
//     hole punching and AutoNAT v2;
//   - the protocols-updated emitter;
//   - the network (logging any error);
//   - the address manager, the peerstore manager and the peerstore.
//
// It then waits for refCount and finally closes the network's resource
// manager.
func (h *BasicHost) Close() error {
	h.closeSync.Do(func() {
		h.ctxCancel()
		if h.cmgr != nil {
			h.cmgr.Close()
		}
		if h.ids != nil {
			h.ids.Close()
		}

		// SetAutoNat writes autoNat under autoNATMx, so read it under the
		// same lock to avoid a data race. The lock is not held while closing.
		h.autoNATMx.RLock()
		an := h.autoNat
		h.autoNATMx.RUnlock()
		if an != nil {
			an.Close()
		}

		if h.relayManager != nil {
			h.relayManager.Close()
		}
		if h.hps != nil {
			h.hps.Close()
		}
		if h.autonatv2 != nil {
			h.autonatv2.Close()
		}
		_ = h.emitters.evtLocalProtocolsUpdated.Close()
		if err := h.network.Close(); err != nil {
			log.Error("swarm close failed", "err", err)
		}
		h.addressManager.Close()
		h.psManager.Close()
		if h.Peerstore() != nil {
			h.Peerstore().Close()
		}
		h.refCount.Wait()
		if h.Network().ResourceManager() != nil {
			h.Network().ResourceManager().Close()
		}
	})
	return nil
}
// streamWrapper is the stream NewStream returns when it picks a protocol
// optimistically. Read, Write and Close go through rw, the lazy
// multistream selector built by NewStream. Every other network.Stream
// method is served by the embedded Stream.
type streamWrapper struct {
	network.Stream
	rw io.ReadWriteCloser
}

// Read reads through the lazy multistream selector.
func (s *streamWrapper) Read(b []byte) (int, error) {
	n, err := s.rw.Read(b)
	return n, err
}

// Write writes through the lazy multistream selector.
func (s *streamWrapper) Write(b []byte) (int, error) {
	n, err := s.rw.Write(b)
	return n, err
}

// Close first sets a read deadline of DefaultNegotiationTimeout on the
// underlying stream (error ignored), then closes rw. The deadline
// presumably keeps any handshake read done during close from blocking
// forever. NOTE(review): confirm against the msmux lazy-conn Close.
func (s *streamWrapper) Close() error {
	deadline := time.Now().Add(DefaultNegotiationTimeout)
	_ = s.Stream.SetReadDeadline(deadline)
	return s.rw.Close()
}

// CloseWrite flushes rw when it supports Flush and then closes the write
// side of the underlying stream. The flush is best effort: its error is
// ignored and the write side is closed regardless.
func (s *streamWrapper) CloseWrite() error {
	type flusher interface{ Flush() error }
	if f, ok := s.rw.(flusher); ok {
		_ = f.Flush()
	}
	return s.Stream.CloseWrite()
}
The pages are generated with Golds v0.8.2. (GOOS=linux GOARCH=amd64)
Golds is a Go 101 project developed by Tapir Liu.
PRs and bug reports are welcome and can be submitted to the issue list.
Please follow @zigo_101 (reachable from the left QR code) to get the latest news about Golds.