package libp2pwebtransport
import (
"bytes"
"context"
"crypto/tls"
"crypto/x509"
"errors"
"fmt"
"io"
"sync"
"sync/atomic"
"time"
"github.com/libp2p/go-libp2p/core/connmgr"
ic "github.com/libp2p/go-libp2p/core/crypto"
"github.com/libp2p/go-libp2p/core/network"
"github.com/libp2p/go-libp2p/core/peer"
"github.com/libp2p/go-libp2p/core/pnet"
tpt "github.com/libp2p/go-libp2p/core/transport"
"github.com/libp2p/go-libp2p/p2p/security/noise"
"github.com/libp2p/go-libp2p/p2p/security/noise/pb"
"github.com/libp2p/go-libp2p/p2p/transport/quicreuse"
"github.com/benbjohnson/clock"
logging "github.com/ipfs/go-log/v2"
ma "github.com/multiformats/go-multiaddr"
manet "github.com/multiformats/go-multiaddr/net"
"github.com/multiformats/go-multihash"
"github.com/quic-go/quic-go"
"github.com/quic-go/quic-go/http3"
"github.com/quic-go/webtransport-go"
)
// log is the package-wide logger, created under the name "webtransport".
var log = logging.Logger("webtransport")
const (
	// webtransportHTTPEndpoint is the HTTP path dialed for the WebTransport
	// CONNECT request (dialWithScope appends "?type=noise").
	webtransportHTTPEndpoint = "/.well-known/libp2p-webtransport"

	// errorCodeConnectionGating is the application error code used to close
	// sessions/connections rejected by the connection gater. The four bytes
	// spell "GATE" in ASCII.
	errorCodeConnectionGating = 0x47415445

	// certValidity is a two-week duration (14 days).
	certValidity = 14 * 24 * time.Hour
)
// Option customizes a transport inside New. Options are applied in order after
// the defaults have been set; the first non-nil error aborts New.
type Option func(*transport) error
// WithClock replaces the transport's default clock (clock.New()). The clock is
// what Listen hands to newCertManager.
func WithClock(cl clock.Clock) Option {
	setClock := func(tr *transport) error {
		tr.clock = cl
		return nil
	}
	return setClock
}
// WithTLSClientConfig sets the base TLS configuration for outbound dials.
// Each dial works on a clone of c, to which it adds the HTTP/3 ALPN, the SNI
// (when one is known) and cert-hash based verification.
func WithTLSClientConfig(c *tls.Config) Option {
	setTLSClientConf := func(tr *transport) error {
		tr.tlsClientConf = c
		return nil
	}
	return setTLSClientConf
}
// WithHandshakeTimeout overrides the transport's handshakeTimeout, which New
// otherwise initializes from the package-level handshakeTimeout default.
// NOTE(review): the value isn't read in this part of the file; presumably the
// listener uses it to bound inbound handshakes — confirm.
func WithHandshakeTimeout(d time.Duration) Option {
	setTimeout := func(tr *transport) error {
		tr.handshakeTimeout = d
		return nil
	}
	return setTimeout
}
// transport is the WebTransport implementation of tpt.Transport. Outbound
// sessions are authenticated with a Noise handshake run over a WebTransport
// stream; listeners obtain their TLS config from a lazily created certManager.
type transport struct {
	// privKey is the local private key; pid is the peer ID New derives from it.
	privKey ic.PrivKey
	pid     peer.ID
	// clock is passed to newCertManager; defaults to clock.New() (see WithClock).
	clock clock.Clock
	// connManager is used for DialQUIC and ListenQUICAndAssociate.
	connManager *quicreuse.ConnManager
	// rcmgr reserves connection scopes; New substitutes a NullResourceManager for nil.
	rcmgr network.ResourceManager
	// gater, when non-nil, is consulted via InterceptSecured on outbound dials.
	gater connmgr.ConnectionGater
	// listenOnce guards the one-time creation of certManager in Listen, and
	// listenOnceErr holds that single attempt's error (returned by every later
	// Listen). hasCertManager is set inside that one-time initialisation and is
	// checked by AddCertHashes before it touches certManager.
	listenOnce     sync.Once
	listenOnceErr  error
	certManager    *certManager
	hasCertManager atomic.Bool
	// staticTLSConf is not set by any option visible here; Listen rejects a
	// non-nil value.
	staticTLSConf *tls.Config
	// tlsClientConf, if non-nil, is cloned as the base TLS config for each dial.
	tlsClientConf *tls.Config
	// noise runs the security handshake in upgrade.
	noise *noise.Transport
	// connMx guards conns, which maps each QUIC connection to the libp2p conn
	// built on top of it (looked up by allowWindowIncrease).
	connMx sync.Mutex
	conns  map[quic.Connection]*conn
	// handshakeTimeout defaults to the package-level handshakeTimeout and can be
	// overridden with WithHandshakeTimeout. NOTE(review): not read in this part
	// of the file — presumably consumed by the listener; confirm.
	handshakeTimeout time.Duration
}
// Compile-time checks that *transport implements the interfaces it is used as.
var (
	_ tpt.Transport = (*transport)(nil)
	_ tpt.Resolver  = (*transport)(nil)
	_ io.Closer     = (*transport)(nil)
)
// New creates a WebTransport transport.
//
// key is the local private key: the transport's peer ID is derived from it and
// it is used to build the Noise security transport. Private networks are not
// supported, so a non-empty psk is rejected. A nil rcmgr is replaced by a
// network.NullResourceManager. The defaults (real clock, package-level
// handshakeTimeout) are set first, then opts are applied in order; the first
// option error aborts construction.
func New(key ic.PrivKey, psk pnet.PSK, connManager *quicreuse.ConnManager, gater connmgr.ConnectionGater, rcmgr network.ResourceManager, opts ...Option) (tpt.Transport, error) {
	if len(psk) > 0 {
		log.Error("WebTransport doesn't support private networks yet.")
		return nil, errors.New("WebTransport doesn't support private networks yet")
	}
	if rcmgr == nil {
		rcmgr = &network.NullResourceManager{}
	}

	localID, err := peer.IDFromPrivateKey(key)
	if err != nil {
		return nil, err
	}

	tr := &transport{
		pid:              localID,
		privKey:          key,
		rcmgr:            rcmgr,
		gater:            gater,
		clock:            clock.New(),
		connManager:      connManager,
		conns:            map[quic.Connection]*conn{},
		handshakeTimeout: handshakeTimeout,
	}
	for _, apply := range opts {
		if err := apply(tr); err != nil {
			return nil, err
		}
	}

	noiseTpt, err := noise.New(noise.ID, key, nil)
	if err != nil {
		return nil, err
	}
	tr.noise = noiseTpt
	return tr, nil
}
// Dial opens an outbound WebTransport connection to peer p at raddr.
//
// A connection scope is reserved with the resource manager before dialing.
// dialWithScope hands that scope to the new connection on success; on any
// failure it is released here with scope.Done().
func (t *transport) Dial(ctx context.Context, raddr ma.Multiaddr, p peer.ID) (tpt.CapableConn, error) {
	scope, err := t.rcmgr.OpenConnection(network.DirOutbound, false, raddr)
	if err != nil {
		log.Debugw("resource manager blocked outgoing connection", "peer", p, "addr", raddr, "error", err)
		return nil, err
	}

	capable, dialErr := t.dialWithScope(ctx, raddr, p, scope)
	if dialErr != nil {
		scope.Done()
		return nil, dialErr
	}
	return capable, nil
}
// dialWithScope performs an outbound dial inside an already-reserved scope.
//
// It derives the HTTPS URL, the required cert hashes and the SNI from raddr,
// attributes the scope to peer p, establishes the QUIC connection and
// WebTransport session, runs the Noise upgrade and finally asks the connection
// gater (if any). On success the scope is handed to newConn and the connection
// is registered with addConn. On failure the scope is NOT released here; the
// caller (Dial) does that.
func (t *transport) dialWithScope(ctx context.Context, raddr ma.Multiaddr, p peer.ID, scope network.ConnManagementScope) (tpt.CapableConn, error) {
	_, hostPort, err := manet.DialArgs(raddr)
	if err != nil {
		return nil, err
	}
	endpoint := fmt.Sprintf("https://%s%s?type=noise", hostPort, webtransportHTTPEndpoint)

	certHashes, err := extractCertHashes(raddr)
	if err != nil {
		return nil, err
	}
	if len(certHashes) == 0 {
		return nil, errors.New("can't dial webtransport without certhashes")
	}

	sni, _ := extractSNI(raddr)

	if err := scope.SetPeer(p); err != nil {
		log.Debugw("resource manager blocked outgoing connection for peer", "peer", p, "addr", raddr, "error", err)
		return nil, err
	}

	// The QUIC layer dials only the part of raddr before /webtransport.
	quicAddr, _ := ma.SplitFunc(raddr, func(c ma.Component) bool { return c.Protocol().Code == ma.P_WEBTRANSPORT })
	sess, qconn, err := t.dial(ctx, quicAddr, endpoint, sni, certHashes)
	if err != nil {
		return nil, err
	}

	secured, err := t.upgrade(ctx, sess, p, certHashes)
	if err != nil {
		sess.CloseWithError(1, "")
		qconn.CloseWithError(1, "")
		return nil, err
	}

	if t.gater != nil && !t.gater.InterceptSecured(network.DirOutbound, p, secured) {
		sess.CloseWithError(errorCodeConnectionGating, "")
		qconn.CloseWithError(errorCodeConnectionGating, "")
		return nil, errors.New("secured connection gated")
	}

	wc := newConn(t, sess, secured, scope, qconn)
	t.addConn(qconn, wc)
	return wc, nil
}
// dial establishes the QUIC connection to addr and opens a WebTransport
// session on it by requesting url.
//
// The TLS config starts from a clone of t.tlsClientConf (or an empty config)
// and gets the HTTP/3 ALPN, the SNI when sni is non-empty and, when certHashes
// is non-empty, cert-hash verification via verifyRawCerts in place of standard
// chain verification.
//
// On success it returns the session together with its underlying QUIC
// connection. On failure after the QUIC connection was established, that
// connection is closed before returning.
func (t *transport) dial(ctx context.Context, addr ma.Multiaddr, url, sni string, certHashes []multihash.DecodedMultihash) (*webtransport.Session, quic.Connection, error) {
	var tlsConf *tls.Config
	if t.tlsClientConf != nil {
		tlsConf = t.tlsClientConf.Clone()
	} else {
		tlsConf = &tls.Config{}
	}
	// tls.Config.Clone is a shallow copy: NextProtos still shares its backing
	// array with t.tlsClientConf. Capping the slice at its length forces append
	// to allocate, instead of writing into spare capacity that is shared with
	// the caller's config and with concurrent dials (a data race).
	protos := tlsConf.NextProtos
	tlsConf.NextProtos = append(protos[:len(protos):len(protos)], http3.NextProtoH3)
	if sni != "" {
		tlsConf.ServerName = sni
	}
	if len(certHashes) > 0 {
		// Chain verification is skipped; acceptance of the presented
		// certificate(s) is delegated to verifyRawCerts with the cert hashes.
		tlsConf.InsecureSkipVerify = true
		tlsConf.VerifyPeerCertificate = func(rawCerts [][]byte, _ [][]*x509.Certificate) error {
			return verifyRawCerts(rawCerts, certHashes)
		}
	}
	ctx = quicreuse.WithAssociation(ctx, t)
	conn, err := t.connManager.DialQUIC(ctx, addr, tlsConf, t.allowWindowIncrease)
	if err != nil {
		return nil, nil, err
	}
	// The webtransport dialer needs a quic.EarlyConnection. Check the dynamic
	// type once here and fail cleanly instead of panicking inside DialAddr.
	earlyConn, ok := conn.(quic.EarlyConnection)
	if !ok {
		conn.CloseWithError(1, "")
		return nil, nil, fmt.Errorf("unexpected QUIC connection type %T", conn)
	}
	dialer := webtransport.Dialer{
		// Hand the already-established connection to the dialer rather than
		// letting it open a new one.
		DialAddr: func(_ context.Context, _ string, _ *tls.Config, _ *quic.Config) (quic.EarlyConnection, error) {
			return earlyConn, nil
		},
		QUICConfig: t.connManager.ClientConfig().Clone(),
	}
	rsp, sess, err := dialer.Dial(ctx, url, nil)
	if err != nil {
		conn.CloseWithError(1, "")
		return nil, nil, err
	}
	if rsp.StatusCode < 200 || rsp.StatusCode > 299 {
		conn.CloseWithError(1, "")
		return nil, nil, fmt.Errorf("invalid response status code: %d", rsp.StatusCode)
	}
	return sess, conn, nil
}
// upgrade authenticates the remote end of sess as peer p by running an outbound
// Noise handshake over a newly opened WebTransport stream.
//
// In addition to the Noise identity check, the Noise extensions handed to the
// early-data callback must contain every hash in certHashes (the hashes this
// side dialed with). A missing or undecodable hash fails the handshake.
//
// The result bundles the Noise security state with the session's local and
// remote addresses, converted via toWebtransportMultiaddr.
func (t *transport) upgrade(ctx context.Context, sess *webtransport.Session, p peer.ID, certHashes []multihash.DecodedMultihash) (*connSecurityMultiaddrs, error) {
	local, err := toWebtransportMultiaddr(sess.LocalAddr())
	if err != nil {
		return nil, fmt.Errorf("error determining local addr: %w", err)
	}
	remote, err := toWebtransportMultiaddr(sess.RemoteAddr())
	if err != nil {
		return nil, fmt.Errorf("error determining remote addr: %w", err)
	}
	// This stream carries only the handshake; it is closed when upgrade returns.
	str, err := sess.OpenStreamSync(ctx)
	if err != nil {
		return nil, err
	}
	defer str.Close()
	// verified is flipped by the early-data callback below once every dialed
	// cert hash has been found in the received extensions.
	var verified bool
	n, err := t.noise.WithSessionOptions(noise.EarlyData(newEarlyDataReceiver(func(b *pb.NoiseExtensions) error {
		decodedCertHashes, err := decodeCertHashesFromProtobuf(b.WebtransportCerthashes)
		if err != nil {
			return err
		}
		// Every hash we dialed with must appear (same code and digest) in the
		// received list; extra received hashes are allowed.
		for _, sent := range certHashes {
			var found bool
			for _, rcvd := range decodedCertHashes {
				if sent.Code == rcvd.Code && bytes.Equal(sent.Digest, rcvd.Digest) {
					found = true
					break
				}
			}
			if !found {
				return fmt.Errorf("missing cert hash: %v", sent)
			}
		}
		verified = true
		return nil
	}), nil))
	if err != nil {
		return nil, fmt.Errorf("failed to create Noise transport: %w", err)
	}
	c, err := n.SecureOutbound(ctx, &webtransportStream{Stream: str, wsess: sess}, p)
	if err != nil {
		return nil, err
	}
	// c is closed on return: it is only kept for its ConnSecurity metadata in the
	// result. NOTE(review): relies on those accessors working after Close — confirm.
	defer c.Close()
	// Defensive double-check that the early-data callback actually ran and passed.
	if !verified {
		return nil, errors.New("didn't verify")
	}
	return &connSecurityMultiaddrs{
		ConnSecurity:   c,
		ConnMultiaddrs: &connMultiaddrs{local: local, remote: remote},
	}, nil
}
// decodeCertHashesFromProtobuf decodes each raw multihash in b into a
// DecodedMultihash, preserving order. It stops at the first entry that fails
// to decode. On success the returned slice is non-nil even when b is empty.
func decodeCertHashesFromProtobuf(b [][]byte) ([]multihash.DecodedMultihash, error) {
	decoded := make([]multihash.DecodedMultihash, len(b))
	for i := range b {
		dh, err := multihash.Decode(b[i])
		if err != nil {
			return nil, fmt.Errorf("failed to decode hash: %w", err)
		}
		decoded[i] = *dh
	}
	return decoded, nil
}
// CanDial reports whether addr is a WebTransport multiaddr. The cert hash count
// returned by IsWebtransportMultiaddr is not considered here.
func (t *transport) CanDial(addr ma.Multiaddr) bool {
	isWT, _ := IsWebtransportMultiaddr(addr)
	return isWT
}
// Listen starts accepting WebTransport connections on laddr.
//
// laddr must be a WebTransport multiaddr without /certhash components. The
// transport serves certificates from a certManager that is created on the
// first Listen call and shared by later ones. If that creation fails, its
// error is returned by this and every subsequent Listen. A configured static
// TLS config is rejected.
func (t *transport) Listen(laddr ma.Multiaddr) (tpt.Listener, error) {
	isWebTransport, certhashCount := IsWebtransportMultiaddr(laddr)
	if !isWebTransport {
		return nil, fmt.Errorf("cannot listen on non-WebTransport addr: %s", laddr)
	}
	if certhashCount > 0 {
		return nil, fmt.Errorf("cannot listen on a specific certhash non-WebTransport addr: %s", laddr)
	}
	if t.staticTLSConf != nil {
		return nil, errors.New("static TLS config not supported on WebTransport")
	}

	t.listenOnce.Do(func() {
		t.certManager, t.listenOnceErr = newCertManager(t.privKey, t.clock)
		// Only publish the manager when creation succeeded: AddCertHashes
		// dereferences t.certManager whenever this flag is set.
		if t.listenOnceErr == nil {
			t.hasCertManager.Store(true)
		}
	})
	if t.listenOnceErr != nil {
		return nil, t.listenOnceErr
	}
	// certManager is still nil if Close consumed listenOnce before the first
	// Listen. Refuse instead of installing a TLS callback that would
	// dereference nil on the first ClientHello.
	if t.certManager == nil {
		return nil, errors.New("cannot listen: no certificate manager (transport closed?)")
	}

	// Ask the cert manager for its current config on every ClientHello rather
	// than snapshotting it once.
	tlsConf := &tls.Config{
		GetConfigForClient: func(*tls.ClientHelloInfo) (*tls.Config, error) {
			return t.certManager.GetConfig(), nil
		},
		NextProtos: []string{http3.NextProtoH3},
	}
	ln, err := t.connManager.ListenQUICAndAssociate(t, laddr, tlsConf, t.allowWindowIncrease)
	if err != nil {
		return nil, err
	}
	// staticTLSConf is known to be nil at this point.
	return newListener(ln, t, false)
}
// Protocols lists the multiaddr protocol codes this transport handles:
// only /webtransport.
func (t *transport) Protocols() []int {
	codes := []int{ma.P_WEBTRANSPORT}
	return codes
}
// Proxy reports whether this transport proxies connections; it never does.
func (t *transport) Proxy() bool {
	const proxies = false
	return proxies
}
// Close shuts down the transport's certificate manager, if Listen created one.
func (t *transport) Close() error {
	// Either wait for an in-flight Listen initialisation to finish or consume
	// the Once ourselves, so t.certManager can no longer change after this.
	t.listenOnce.Do(func() {})

	cm := t.certManager
	if cm == nil {
		return nil
	}
	return cm.Close()
}
// allowWindowIncrease is the callback passed to DialQUIC and
// ListenQUICAndAssociate. It finds the libp2p conn registered for the QUIC
// connection and delegates the decision to it. Unknown connections are refused.
// NOTE(review): presumably consulted before growing a flow-control window.
func (t *transport) allowWindowIncrease(conn quic.Connection, size uint64) bool {
	t.connMx.Lock()
	defer t.connMx.Unlock()

	if c, ok := t.conns[conn]; ok {
		return c.allowWindowIncrease(size)
	}
	return false
}
// addConn registers c as the libp2p connection built on the QUIC connection conn.
func (t *transport) addConn(conn quic.Connection, c *conn) {
	t.connMx.Lock()
	defer t.connMx.Unlock()
	t.conns[conn] = c
}
// removeConn forgets the libp2p connection registered for conn (no-op if none).
func (t *transport) removeConn(conn quic.Connection) {
	t.connMx.Lock()
	defer t.connMx.Unlock()
	delete(t.conns, conn)
}
// extractSNI picks the TLS server name to use for maddr.
//
// An explicit /sni component wins: its value is returned with
// foundSniComponent=true and scanning stops there. Otherwise the value of the
// last /dns, /dns4, /dns6 or /dnsaddr component is returned with
// foundSniComponent=false. When neither is present, sni is "".
func extractSNI(maddr ma.Multiaddr) (sni string, foundSniComponent bool) {
	ma.ForEach(maddr, func(c ma.Component) bool {
		code := c.Protocol().Code
		if code == ma.P_SNI {
			sni, foundSniComponent = c.Value(), true
			return false // an explicit SNI ends the scan
		}
		if code == ma.P_DNS || code == ma.P_DNS4 || code == ma.P_DNS6 || code == ma.P_DNSADDR {
			sni = c.Value()
		}
		return true
	})
	return sni, foundSniComponent
}
// Resolve makes the SNI explicit in maddr.
//
// When maddr has no /sni component but does carry a DNS name (see extractSNI),
// the result is maddr with an /sni component holding that name inserted right
// after the /quic-v1 component. Otherwise maddr is returned unchanged. An
// address that needs an SNI but has no /quic-v1 component is an error. No
// network lookups are made here.
func (t *transport) Resolve(_ context.Context, maddr ma.Multiaddr) ([]ma.Multiaddr, error) {
	sni, hasExplicitSNI := extractSNI(maddr)
	if hasExplicitSNI || sni == "" {
		return []ma.Multiaddr{maddr}, nil
	}

	head, quicAndTail := ma.SplitFunc(maddr, func(c ma.Component) bool {
		return c.Protocol().Code == ma.P_QUIC_V1
	})
	if len(quicAndTail) == 0 {
		return nil, fmt.Errorf("no quic component found in %s", maddr)
	}
	quicComp, tail := ma.SplitFirst(quicAndTail)
	if quicComp == nil {
		return nil, fmt.Errorf("no quic component found in %s", maddr)
	}

	sniComp, err := ma.NewComponent(ma.ProtocolWithCode(ma.P_SNI).Name, sni)
	if err != nil {
		return nil, err
	}

	withSNI := head.AppendComponent(quicComp, sniComp)
	withSNI = append(withSNI, tail...)
	return []ma.Multiaddr{withSNI}, nil
}
// AddCertHashes encapsulates the certificate manager's address component
// (presumably the /certhash components of the current certificates) into m and
// returns true. When no certificate manager exists, it returns m unchanged and
// false.
func (t *transport) AddCertHashes(m ma.Multiaddr) (ma.Multiaddr, bool) {
	if !t.hasCertManager.Load() {
		return m, false
	}
	// The flag is stored after t.certManager is assigned, so this read is
	// ordered after that write. Still guard against a nil manager (e.g. a failed
	// initialisation that set the flag anyway) instead of panicking.
	cm := t.certManager
	if cm == nil {
		return m, false
	}
	return m.Encapsulate(cm.AddrComponent()), true
}
The pages are generated with Golds v0.8.2 (GOOS=linux GOARCH=amd64).
Golds is a Go 101 project developed by Tapir Liu.
PRs and bug reports are welcome and can be submitted to the issue list.
Please follow @zigo_101 (reachable from the left QR code) to get the latest news about Golds.