package libp2pwebtransport

import (
	
	
	
	
	
	
	
	
	
	

	
	ic 
	
	
	
	tpt 
	
	
	

	
	logging 
	ma 
	manet 
	
	
	
	
)

// log is this package's logger, created with the name "webtransport".
var log = logging.Logger("webtransport")

const (
	// webtransportHTTPEndpoint is the path component of the https URL that is
	// dialed to open a WebTransport session.
	webtransportHTTPEndpoint = "/.well-known/libp2p-webtransport"

	// errorCodeConnectionGating is the close code used when the connection
	// gater rejects a secured connection; 0x47415445 spells "GATE" in ASCII.
	errorCodeConnectionGating = 0x47415445

	// certValidity is a 14-day (two-week) duration.
	// NOTE(review): presumably the validity period of generated certificates;
	// it is not consumed anywhere in this view — confirm.
	certValidity = 14 * 24 * time.Hour
)

// Option customizes a transport during construction. Options are applied in
// the order given, after the constructor has filled in its defaults, and a
// non-nil error aborts construction.
type Option func(*transport) error

// NOTE(review): this option's declared name and identifiers are missing in
// this copy of the source.
// It returns an Option that replaces the transport's clock, overriding the
// clock.New() default set by the constructor. That clock is later handed to
// newCertManager when the first listener is created.
func ( clock.Clock) Option {
	return func( *transport) error {
		.clock = 
		return nil
	}
}

// WithTLSClientConfig sets a custom tls.Config used for dialing.
// This option is most useful for setting a custom tls.Config.RootCAs certificate pool.
// The config is cloned on every dial, and http3.NextProtoH3 is appended to the
// clone's NextProtos.
// When dialing a multiaddr that contains a /certhash component, this library will set InsecureSkipVerify and
// overwrite the VerifyPeerCertificate callback.
// NOTE(review): dials without a /certhash component are currently rejected, so
// these overrides always apply and a custom RootCAs pool appears to have no
// effect on verification — confirm before relying on it.
func ( *tls.Config) Option {
	return func( *transport) error {
		.tlsClientConf = 
		return nil
	}
}

// NOTE(review): this option's declared name and identifiers are missing in
// this copy of the source.
// It returns an Option that overrides the transport's handshakeTimeout. The
// constructor's default is the package-level handshakeTimeout, which is defined
// outside this view. Where the timeout is enforced is also not visible here.
func ( time.Duration) Option {
	return func( *transport) error {
		.handshakeTimeout = 
		return nil
	}
}

// transport is the WebTransport implementation of tpt.Transport, tpt.Resolver
// and io.Closer.
type transport struct {
	// privKey is the host's private key; pid is the peer ID derived from it.
	privKey ic.PrivKey
	pid     peer.ID
	// clock defaults to clock.New() and is used when creating the certManager.
	clock clock.Clock

	// connManager provides the underlying QUIC dialing and listening.
	connManager *quicreuse.ConnManager
	// rcmgr is replaced by a NullResourceManager in the constructor when nil.
	rcmgr network.ResourceManager
	// gater may be nil, in which case secured connections are not gated.
	gater connmgr.ConnectionGater

	// Listener-side certificate state, created lazily by the first Listen.
	listenOnce     sync.Once
	listenOnceErr  error
	certManager    *certManager
	hasCertManager atomic.Bool // set to true once certManager initialization has run (even if it failed)
	// staticTLSConf must be nil (Listen rejects static TLS configs);
	// tlsClientConf, when set, is cloned for every outbound dial.
	staticTLSConf, tlsClientConf *tls.Config

	// noise runs the Noise handshake over a WebTransport stream.
	noise *noise.Transport

	connMx           sync.Mutex                // guards conns
	conns            map[quic.Connection]*conn // quic connection -> *conn
	handshakeTimeout time.Duration
}

// Compile-time checks that *transport implements every interface it is used as.
var (
	_ tpt.Transport = (*transport)(nil)
	_ tpt.Resolver  = (*transport)(nil)
	_ io.Closer     = (*transport)(nil)
)

// NOTE(review): the constructor's declared name and identifiers are missing in
// this copy of the source.
// It constructs the WebTransport transport. The parameters, in order, are:
//   - the host's private key;
//   - a pre-shared key, which must be empty because private networks are unsupported;
//   - the QUIC connection manager;
//   - a connection gater, which may be nil;
//   - a resource manager, where nil is replaced by a NullResourceManager;
//   - Options.
// It returns an error if the PSK is non-empty, the peer ID cannot be derived
// from the key, an Option fails, or the Noise transport cannot be created.
func ( ic.PrivKey,  pnet.PSK,  *quicreuse.ConnManager,  connmgr.ConnectionGater,  network.ResourceManager,  ...Option) (tpt.Transport, error) {
	if len() > 0 {
		// NOTE(review): this both logs and returns the same error.
		log.Error("WebTransport doesn't support private networks yet.")
		return nil, errors.New("WebTransport doesn't support private networks yet")
	}
	if  == nil {
		 = &network.NullResourceManager{}
	}
	,  := peer.IDFromPrivateKey()
	if  != nil {
		return nil, 
	}
	// handshakeTimeout on the right-hand side is a package-level default
	// declared outside this view.
	 := &transport{
		pid:              ,
		privKey:          ,
		rcmgr:            ,
		gater:            ,
		clock:            clock.New(),
		connManager:      ,
		conns:            map[quic.Connection]*conn{},
		handshakeTimeout: handshakeTimeout,
	}
	// Options run after the defaults above so they can override them.
	for ,  := range  {
		if  := ();  != nil {
			return nil, 
		}
	}
	// The Noise transport used for the outbound handshake is built from the
	// same private key.
	,  := noise.New(noise.ID, , nil)
	if  != nil {
		return nil, 
	}
	.noise = 
	return , nil
}

// Dial opens an outbound WebTransport connection to a peer at the given
// multiaddr. It is the tpt.Transport method; its declared name and
// identifiers are missing in this copy of the source.
// It first opens an outbound connection scope with the resource manager. If
// the dial then fails, it releases that scope with Done. On success the scope
// is passed through dialWithScope, presumably to be owned by the returned
// connection.
func ( *transport) ( context.Context,  ma.Multiaddr,  peer.ID) (tpt.CapableConn, error) {
	,  := .rcmgr.OpenConnection(network.DirOutbound, false, )
	if  != nil {
		log.Debugw("resource manager blocked outgoing connection", "peer", , "addr", , "error", )
		return nil, 
	}

	,  := .dialWithScope(, , , )
	if  != nil {
		// dialWithScope does not release the scope itself on failure.
		.Done()
		return nil, 
	}

	return , nil
}

// dialWithScope performs an outbound dial under a resource scope the caller
// has already opened. It runs these steps in order:
//  1. Derive the dial address with manet.DialArgs.
//  2. Build the https URL for the WebTransport endpoint.
//  3. Require at least one /certhash.
//  4. Pick the SNI.
//  5. Attach the peer to the scope.
//  6. Dial the QUIC connection and WebTransport session.
//  7. Run the Noise upgrade, which checks the certhashes.
//  8. Consult the gater (if any).
//  9. Wrap and register the resulting connection.
//
// It never calls Done on the scope; releasing it on error is the caller's job.
// NOTE(review): the method's declared name and identifiers are missing in this
// copy of the source.
func ( *transport) ( context.Context,  ma.Multiaddr,  peer.ID,  network.ConnManagementScope) (tpt.CapableConn, error) {
	, ,  := manet.DialArgs()
	if  != nil {
		return nil, 
	}
	 := fmt.Sprintf("https://%s%s?type=noise", , webtransportHTTPEndpoint)
	,  := extractCertHashes()
	if  != nil {
		return nil, 
	}

	if len() == 0 {
		return nil, errors.New("can't dial webtransport without certhashes")
	}

	,  := extractSNI()

	if  := .SetPeer();  != nil {
		log.Debugw("resource manager blocked outgoing connection for peer", "peer", , "addr", , "error", )
		return nil, 
	}

	// Keep the part of the address before the /webtransport component;
	// presumably this is what gets dialed at the QUIC layer.
	,  := ma.SplitFunc(, func( ma.Component) bool { return .Protocol().Code == ma.P_WEBTRANSPORT })
	, ,  := .dial(, , , , )
	if  != nil {
		return nil, 
	}
	,  := .upgrade(, , , )
	if  != nil {
		// Tear down what are presumably the session and the underlying QUIC
		// connection returned by dial.
		.CloseWithError(1, "")
		.CloseWithError(1, "")
		return nil, 
	}
	if .gater != nil && !.gater.InterceptSecured(network.DirOutbound, , ) {
		.CloseWithError(errorCodeConnectionGating, "")
		.CloseWithError(errorCodeConnectionGating, "")
		return nil, fmt.Errorf("secured connection gated")
	}
	 := newConn(, , , , )
	.addConn(, )
	return , nil
}

// dial establishes a QUIC connection through the conn manager and opens a
// WebTransport session on top of it. The TLS config it uses is a clone of
// tlsClientConf (or an empty config), adjusted as follows:
//   - http3.NextProtoH3 is appended to NextProtos;
//   - ServerName is set when a non-empty SNI is given;
//   - when certhashes are given, chain verification is replaced by
//     verifyRawCerts against those hashes.
//
// The two string parameters are presumably the request URL and the SNI. On a
// failed or non-2xx WebTransport handshake, the QUIC connection is closed with
// code 1.
// NOTE(review): the method's declared name and identifiers are missing in this
// copy of the source.
func ( *transport) ( context.Context,  ma.Multiaddr, ,  string,  []multihash.DecodedMultihash) (*webtransport.Session, quic.Connection, error) {
	var  *tls.Config
	if .tlsClientConf != nil {
		 = .tlsClientConf.Clone()
	} else {
		 = &tls.Config{}
	}
	// NOTE(review): tls.Config.Clone copies only the NextProtos slice header,
	// so this append may write into spare capacity of the user's backing array.
	// That is racy across concurrent dials; cloning the slice first would avoid it.
	.NextProtos = append(.NextProtos, http3.NextProtoH3)

	if  != "" {
		.ServerName = 
	}

	if len() > 0 {
		// This is not insecure. We verify the certificate ourselves.
		// See https://www.w3.org/TR/webtransport/#certificate-hashes.
		.InsecureSkipVerify = true
		.VerifyPeerCertificate = func( [][]byte,  [][]*x509.Certificate) error {
			return verifyRawCerts(, )
		}
	}
	// quicreuse.WithAssociation tags the dial context; its effect is defined
	// inside quicreuse.
	 = quicreuse.WithAssociation(, )
	,  := .connManager.DialQUIC(, , , .allowWindowIncrease)
	if  != nil {
		return nil, nil, 
	}
	 := webtransport.Dialer{
		// The QUIC connection was already dialed above, so DialAddr ignores its
		// arguments and hands that connection back. The type assertion is
		// unchecked and panics if the value is not a quic.EarlyConnection.
		DialAddr: func( context.Context,  string,  *tls.Config,  *quic.Config) (quic.EarlyConnection, error) {
			return .(quic.EarlyConnection), nil
		},
		QUICConfig: .connManager.ClientConfig().Clone(),
	}
	, ,  := .Dial(, , nil)
	if  != nil {
		.CloseWithError(1, "")
		return nil, nil, 
	}
	if .StatusCode < 200 || .StatusCode > 299 {
		.CloseWithError(1, "")
		return nil, nil, fmt.Errorf("invalid response status code: %d", .StatusCode)
	}
	return , , 
}

// upgrade secures a freshly established WebTransport session. It opens one
// stream and runs an outbound Noise handshake with the expected peer over it.
// The server's early data must list every certhash that was used for dialing.
// It returns the resulting security state, together with the session's local
// and remote addresses converted to WebTransport multiaddrs.
// NOTE(review): the method's declared name and identifiers are missing in this
// copy of the source.
func ( *transport) ( context.Context,  *webtransport.Session,  peer.ID,  []multihash.DecodedMultihash) (*connSecurityMultiaddrs, error) {
	,  := toWebtransportMultiaddr(.LocalAddr())
	if  != nil {
		return nil, fmt.Errorf("error determining local addr: %w", )
	}
	,  := toWebtransportMultiaddr(.RemoteAddr())
	if  != nil {
		return nil, fmt.Errorf("error determining remote addr: %w", )
	}

	,  := .OpenStreamSync()
	if  != nil {
		return nil, 
	}
	// This stream is only used for the handshake below.
	defer .Close()

	// Now run a Noise handshake (using early data) and get all the certificate hashes from the server.
	// We will verify that the certhashes we used to dial are a subset of the certhashes we received from the server.
	var  bool
	,  := .noise.WithSessionOptions(noise.EarlyData(newEarlyDataReceiver(func( *pb.NoiseExtensions) error {
		,  := decodeCertHashesFromProtobuf(.WebtransportCerthashes)
		if  != nil {
			return 
		}
		// Every dialed hash must match (code and digest) one sent by the server.
		for ,  := range  {
			var  bool
			for ,  := range  {
				if .Code == .Code && bytes.Equal(.Digest, .Digest) {
					 = true
					break
				}
			}
			if ! {
				return fmt.Errorf("missing cert hash: %v", )
			}
		}
		 = true
		return nil
	}), nil))
	if  != nil {
		return nil, fmt.Errorf("failed to create Noise transport: %w", )
	}
	,  := .SecureOutbound(, &webtransportStream{Stream: , wsess: }, )
	if  != nil {
		return nil, 
	}
	// NOTE(review): the secured connection is closed when this function returns,
	// yet it is also returned as ConnSecurity. Presumably only its security and
	// identity metadata is used afterwards — confirm.
	defer .Close()
	// The Noise handshake _should_ guarantee that our verification callback is called.
	// Double-check just in case.
	if ! {
		return nil, errors.New("didn't verify")
	}
	return &connSecurityMultiaddrs{
		ConnSecurity:   ,
		ConnMultiaddrs: &connMultiaddrs{local: , remote: },
	}, nil
}

func decodeCertHashesFromProtobuf( [][]byte) ([]multihash.DecodedMultihash, error) {
	 := make([]multihash.DecodedMultihash, 0, len())
	for ,  := range  {
		,  := multihash.Decode()
		if  != nil {
			return nil, fmt.Errorf("failed to decode hash: %w", )
		}
		 = append(, *)
	}
	return , nil
}

// CanDial reports whether the address is a WebTransport multiaddr, using the
// first result of IsWebtransportMultiaddr. It is the tpt.Transport method; its
// declared name and identifiers are missing in this copy of the source.
func ( *transport) ( ma.Multiaddr) bool {
	,  := IsWebtransportMultiaddr()
	return 
}

// Listen starts a WebTransport listener on the given multiaddr. It is the
// tpt.Transport method; its declared name and identifiers are missing in this
// copy of the source.
// The address must be a WebTransport multiaddr. It must also not pin a
// certhash: the second result of IsWebtransportMultiaddr is presumably the
// certhash count.
// The first call creates the certManager via listenOnce. A creation error is
// stored and returned by every later Listen call. Static TLS configs are
// rejected.
func ( *transport) ( ma.Multiaddr) (tpt.Listener, error) {
	,  := IsWebtransportMultiaddr()
	if ! {
		return nil, fmt.Errorf("cannot listen on non-WebTransport addr: %s", )
	}
	if  > 0 {
		return nil, fmt.Errorf("cannot listen on a specific certhash non-WebTransport addr: %s", )
	}
	if .staticTLSConf == nil {
		.listenOnce.Do(func() {
			// NOTE(review): the flag is stored even when newCertManager fails, so
			// readers of hasCertManager must tolerate a nil certManager.
			.certManager, .listenOnceErr = newCertManager(.privKey, .clock)
			.hasCertManager.Store(true)
		})
		if .listenOnceErr != nil {
			return nil, .listenOnceErr
		}
	} else {
		return nil, errors.New("static TLS config not supported on WebTransport")
	}
	// NOTE(review): staticTLSConf is always nil here because the else branch
	// returned, and (*tls.Config).Clone returns nil for a nil receiver. As a
	// result, the certManager-backed config below is always used and
	// newListener always receives false.
	 := .staticTLSConf.Clone()
	if  == nil {
		// The certManager's current config is fetched on every ClientHello.
		 = &tls.Config{GetConfigForClient: func(*tls.ClientHelloInfo) (*tls.Config, error) {
			return .certManager.GetConfig(), nil
		}}
	}
	.NextProtos = append(.NextProtos, http3.NextProtoH3)

	,  := .connManager.ListenQUICAndAssociate(, , , .allowWindowIncrease)
	if  != nil {
		return nil, 
	}
	return newListener(, , .staticTLSConf != nil)
}

// Protocols returns the multiaddr protocol codes this transport handles, which
// is only /webtransport. It is the tpt.Transport method; its declared name is
// missing in this copy of the source.
func ( *transport) () []int {
	return []int{ma.P_WEBTRANSPORT}
}

// Proxy always reports false. It is the tpt.Transport method; its declared
// name is missing in this copy of the source.
func ( *transport) () bool {
	return false
}

// Close closes the certManager if one was created. It is the io.Closer method;
// its declared name is missing in this copy of the source.
func ( *transport) () error {
	// Passing a no-op through listenOnce waits for any in-flight initialization
	// in Listen to finish, so certManager is stable when read below. It also
	// marks the Once as done.
	// NOTE(review): a Listen call after Close therefore skips certManager
	// creation and would build a TLS config around a nil certManager — confirm
	// this path is unreachable.
	.listenOnce.Do(func() {})
	if .certManager != nil {
		return .certManager.Close()
	}
	return nil
}

// NOTE(review): this method's declared name is missing in this copy of the
// source. Its signature matches the .allowWindowIncrease callback that dial
// and Listen pass to the QUIC conn manager.
// It decides whether the given QUIC connection may increase its window by the
// given amount (presumably a size in bytes). Connections not tracked in conns
// are refused; tracked ones defer to (*conn).allowWindowIncrease. Both the
// lookup and the delegated call run with connMx held.
func ( *transport) ( quic.Connection,  uint64) bool {
	.connMx.Lock()
	defer .connMx.Unlock()

	,  := .conns[]
	if ! {
		return false
	}
	return .allowWindowIncrease()
}

// addConn registers a *conn under its QUIC connection in conns, holding connMx.
// NOTE(review): identifiers are missing in this copy; the name comes from the
// .addConn call in the dial path.
func ( *transport) ( quic.Connection,  *conn) {
	.connMx.Lock()
	.conns[] = 
	.connMx.Unlock()
}

// Removes the entry for the given QUIC connection from conns, holding connMx.
// NOTE(review): the declared name is missing in this copy of the source. It is
// presumably the counterpart of addConn and is called from outside this view.
func ( *transport) ( quic.Connection) {
	.connMx.Lock()
	delete(.conns, )
	.connMx.Unlock()
}

// extractSNI returns what the SNI should be for the given maddr. If there is an
// SNI component in the multiaddr, then it will be returned and
// foundSniComponent will be true. If there's no SNI component, but there is a
// DNS-like component, then that will be returned for the sni and
// foundSniComponent will be false (since we didn't find an actual sni component).
func extractSNI( ma.Multiaddr) ( string,  bool) {
	ma.ForEach(, func( ma.Component) bool {
		switch .Protocol().Code {
		case ma.P_SNI:
			 = .Value()
			 = true
			return false
		case ma.P_DNS, ma.P_DNS4, ma.P_DNS6, ma.P_DNSADDR:
			 = .Value()
			// Keep going in case we find an `sni` component
			return true
		}
		return true
	})
	return , 
}

// Resolve implements transport.Resolver.
// If the address has no /sni component but does have a DNS-like one, Resolve
// returns a single address with an /sni component carrying that DNS name,
// inserted directly after /quic-v1. Otherwise it returns early with a
// one-element slice — presumably the unchanged input, though the element is
// missing in this copy of the source.
// It fails when no /quic-v1 component exists or the SNI component cannot be
// built.
// NOTE(review): the method's declared name and identifiers are missing in this
// copy of the source.
func ( *transport) ( context.Context,  ma.Multiaddr) ([]ma.Multiaddr, error) {
	,  := extractSNI()

	if  ||  == "" {
		// Either the multiaddr already has an sni component, which we can keep
		// using, or it has nothing SNI-like to derive one from.
		return []ma.Multiaddr{}, nil
	}

	,  := ma.SplitFunc(, func( ma.Component) bool {
		return .Protocol().Code == ma.P_QUIC_V1
	})
	if len() == 0 {
		return nil, fmt.Errorf("no quic component found in %s", )
	}
	,  := ma.SplitFirst()
	if  == nil {
		// Should not happen since we split on P_QUIC_V1 already
		return nil, fmt.Errorf("no quic component found in %s", )
	}
	,  := ma.NewComponent(ma.ProtocolWithCode(ma.P_SNI).Name, )
	if  != nil {
		return nil, 
	}
	// NOTE(review): the prefix comes from SplitFunc. If it shares the input's
	// backing array, appending here could overwrite the caller's multiaddr.
	// Confirm that SplitFunc/AppendComponent copy.
	 := .AppendComponent(, )
	 = append(, ...)
	return []ma.Multiaddr{}, nil
}

// AddCertHashes adds the current certificate hashes to a multiaddress.
// If called before Listen, it's a no-op.
func ( *transport) ( ma.Multiaddr) (ma.Multiaddr, bool) {
	if !.hasCertManager.Load() {
		return , false
	}
	return .Encapsulate(.certManager.AddrComponent()), true
}