package http3
import (
"context"
"crypto/tls"
"errors"
"fmt"
"io"
"log/slog"
"net"
"net/http"
"net/http/httptrace"
"net/url"
"strings"
"sync"
"sync/atomic"
"golang.org/x/net/http/httpguts"
"github.com/quic-go/quic-go"
)
// Settings describes HTTP/3 connection settings.
//
// NOTE(review): Settings is not used anywhere in this chunk; presumably it
// carries the SETTINGS values of a peer — confirm at its users.
type Settings struct {
	// EnableDatagrams presumably reports HTTP Datagram support (RFC 9297).
	EnableDatagrams bool

	// EnableExtendedConnect presumably reports Extended CONNECT support
	// (RFC 9220).
	EnableExtendedConnect bool

	// Other presumably holds any remaining settings, keyed by setting
	// identifier.
	Other map[uint64]uint64
}
// RoundTripOpt holds per-request options for Transport.RoundTripOpt.
type RoundTripOpt struct {
	// OnlyCachedConn restricts the request to an already cached connection.
	// If no connection to the request's authority is cached,
	// ErrNoCachedConn is returned instead of dialing a new one.
	OnlyCachedConn bool
}
// clientConn is the Transport's internal view of a client-side HTTP/3
// connection. Transport.newClientConn builds one for every dialed QUIC
// connection.
type clientConn interface {
	// OpenRequestStream presumably opens a new request stream on the
	// connection (not called within this chunk).
	OpenRequestStream(context.Context) (*RequestStream, error)

	// RoundTrip sends a single request and returns its response.
	RoundTrip(*http.Request) (*http.Response, error)

	// handleUnidirectionalStream processes one unidirectional stream accepted
	// from the peer. Transport.dial calls it in a separate goroutine for each
	// accepted stream.
	handleUnidirectionalStream(*quic.ReceiveStream)
}
// roundTripperWithCount is one entry of Transport.clients. It holds a
// connection that may still be dialing, plus a count of the requests
// currently using it.
type roundTripperWithCount struct {
	// cancel cancels the context of the dial started by getClient.
	cancel context.CancelFunc

	// dialing is closed by the dial goroutine once it has finished.
	// dialErr, conn and clientConn must only be read after that.
	dialing chan struct{}

	// dialErr is set when the dial failed.
	dialErr error

	// conn is the established QUIC connection; set when the dial succeeded.
	conn *quic.Conn

	// clientConn is the HTTP/3 client wrapping conn; set when the dial
	// succeeded.
	clientConn clientConn

	// useCount is the number of requests holding this entry. getClient
	// increments it, and CloseIdleConnections only closes entries where it
	// is zero.
	useCount atomic.Int64
}
// Close aborts the dial if it is still running and waits for the dial
// goroutine to exit. If a QUIC connection was established, Close then closes
// it with error code 0 and an empty reason. Without a connection it returns
// nil.
func (r *roundTripperWithCount) Close() error {
	r.cancel()
	// Once dialing is closed, r.conn is safe to read.
	<-r.dialing
	if r.conn == nil {
		return nil
	}
	return r.conn.CloseWithError(0, "")
}
// Transport is an HTTP/3 client that implements http.RoundTripper. It caches
// one connection per request authority and reuses it for later requests.
// Its configuration is read and validated by init on the first round trip.
type Transport struct {
	// TLSClientConfig is cloned for every dial; nil means an empty tls.Config.
	// An empty ServerName is filled in from the request authority's host, and
	// NextProtos is always replaced with NextProtoH3.
	TLSClientConfig *tls.Config

	// QUICConfig configures dialed QUIC connections. nil means a copy of
	// defaultQuicConfig whose EnableDatagrams is taken from the Transport.
	// It may name at most one QUIC version. If it names none, the first of
	// quic.SupportedVersions() is used.
	QUICConfig *quic.Config

	// Dial, if non-nil, establishes QUIC connections instead of the built-in
	// dialer. The built-in dialer resolves the address and dials from a UDP
	// socket that init opens.
	Dial func(ctx context.Context, addr string, tlsCfg *tls.Config, cfg *quic.Config) (*quic.Conn, error)

	// EnableDatagrams enables HTTP datagrams. It requires
	// QUICConfig.EnableDatagrams; otherwise init fails.
	EnableDatagrams bool

	// AdditionalSettings is passed to every client connection; presumably
	// extra HTTP/3 SETTINGS to send to the server.
	AdditionalSettings map[uint64]uint64

	// MaxResponseHeaderBytes is passed to every client connection; presumably
	// a limit on response header size.
	MaxResponseHeaderBytes int

	// DisableCompression is passed to every client connection; presumably it
	// disables transparent response decompression, as in net/http.
	DisableCompression bool

	// Logger is passed to every client connection.
	Logger *slog.Logger

	// mutex guards clients, transport and closed.
	mutex sync.Mutex

	// initOnce ensures init runs once. initErr is its result, and every
	// round trip returns it when it is non-nil.
	initOnce sync.Once
	initErr  error

	// newClientConn builds the clientConn for a freshly dialed connection.
	// init installs a default when it is nil (presumably so tests can
	// override it).
	newClientConn func(*quic.Conn) clientConn

	// clients caches connections keyed by request authority.
	clients map[string]*roundTripperWithCount

	// transport wraps the UDP socket init opens when Dial is nil. It is nil
	// when Dial is set, and again after Close.
	transport *quic.Transport

	// closed is set by Close. After that, getClient fails with
	// ErrTransportClosed.
	closed bool
}
// Compile-time assertions that *Transport satisfies http.RoundTripper and
// io.Closer. Nil typed pointers avoid allocating a Transport at init time.
var (
	_ http.RoundTripper = (*Transport)(nil)
	_ io.Closer         = (*Transport)(nil)
)
var (
	// ErrNoCachedConn is returned when RoundTripOpt.OnlyCachedConn is set and
	// no connection to the request's authority is cached.
	ErrNoCachedConn = errors.New("http3: no cached connection was available")

	// ErrTransportClosed is returned when a request is attempted after
	// Transport.Close has been called.
	ErrTransportClosed = errors.New("http3: transport is closed")
)
// init validates the Transport's configuration and fills in defaults. It runs
// once, lazily, on the first round trip (via initOnce). It does the following:
//   - installs the default newClientConn factory unless one is already set;
//   - uses a copy of defaultQuicConfig when QUICConfig is nil, and otherwise
//     works on a private copy of the caller's QUICConfig, so the defaults
//     filled in below never modify a config the caller may share elsewhere;
//   - requires QUIC datagrams whenever HTTP datagrams are enabled;
//   - pins dialing to exactly one QUIC version;
//   - opens a UDP socket wrapped in a quic.Transport when Dial is nil.
func (t *Transport) init() error {
	if t.newClientConn == nil {
		t.newClientConn = func(conn *quic.Conn) clientConn {
			return newClientConn(
				conn,
				t.EnableDatagrams,
				t.AdditionalSettings,
				t.MaxResponseHeaderBytes,
				t.DisableCompression,
				t.Logger,
			)
		}
	}
	if t.QUICConfig == nil {
		t.QUICConfig = defaultQuicConfig.Clone()
		t.QUICConfig.EnableDatagrams = t.EnableDatagrams
	} else {
		// Previously the caller's config was only cloned when Versions was
		// empty, so MaxIncomingStreams could be written into a config the
		// caller still owns. Clone unconditionally before any defaulting.
		t.QUICConfig = t.QUICConfig.Clone()
	}
	if t.EnableDatagrams && !t.QUICConfig.EnableDatagrams {
		return errors.New("HTTP Datagrams enabled, but QUIC Datagrams disabled")
	}
	if len(t.QUICConfig.Versions) == 0 {
		// t.QUICConfig is already a private copy, so assigning is safe.
		t.QUICConfig.Versions = []quic.Version{quic.SupportedVersions()[0]}
	}
	if len(t.QUICConfig.Versions) != 1 {
		return errors.New("can only use a single QUIC version for dialing a HTTP/3 connection")
	}
	if t.QUICConfig.MaxIncomingStreams == 0 {
		// Per quic.Config, a negative value refuses all peer-initiated
		// bidirectional streams.
		t.QUICConfig.MaxIncomingStreams = -1
	}
	if t.Dial == nil {
		udpConn, err := net.ListenUDP("udp", nil)
		if err != nil {
			return err
		}
		t.transport = &quic.Transport{Conn: udpConn}
	}
	return nil
}
// RoundTripOpt sends req over HTTP/3 using the given per-request options.
// If an error is returned, req.Body (when non-nil) is closed before
// returning, as the http.RoundTripper contract requires.
func (t *Transport) RoundTripOpt(req *http.Request, opt RoundTripOpt) (*http.Response, error) {
	rsp, err := t.roundTripOpt(req, opt)
	if err == nil {
		return rsp, nil
	}
	if req.Body != nil {
		req.Body.Close()
	}
	return nil, err
}
// roundTripOpt runs the one-time initialization, validates req, and then
// hands it to doRoundTripOpt for the first (non-retried) attempt.
// Validation checks, in order: a non-nil URL with scheme "https" and a
// non-empty host, a non-nil Header, a valid method token (an empty method is
// accepted), and valid header field names and values.
func (t *Transport) roundTripOpt(req *http.Request, opt RoundTripOpt) (*http.Response, error) {
	t.initOnce.Do(func() { t.initErr = t.init() })
	if t.initErr != nil {
		return nil, t.initErr
	}

	// Cases are evaluated top to bottom, so req.URL is known non-nil before
	// its fields are accessed.
	switch {
	case req.URL == nil:
		return nil, errors.New("http3: nil Request.URL")
	case req.URL.Scheme != "https":
		return nil, fmt.Errorf("http3: unsupported protocol scheme: %s", req.URL.Scheme)
	case req.URL.Host == "":
		return nil, errors.New("http3: no Host in request URL")
	case req.Header == nil:
		return nil, errors.New("http3: nil Request.Header")
	case req.Method != "" && !validMethod(req.Method):
		return nil, fmt.Errorf("http3: invalid method %q", req.Method)
	}

	for name, values := range req.Header {
		if !httpguts.ValidHeaderFieldName(name) {
			return nil, fmt.Errorf("http3: invalid http header field name %q", name)
		}
		for _, value := range values {
			if httpguts.ValidHeaderFieldValue(value) {
				continue
			}
			return nil, fmt.Errorf("http3: invalid http header field value %q for key %v", value, name)
		}
	}
	return t.doRoundTripOpt(req, opt, false)
}
// doRoundTripOpt sends req over the connection cached for the request's
// authority, dialing one if needed. If the attempt fails, the request is not
// cancelled, this is not already a retry, and canRetryRequest allows it, the
// failed client is dropped from the cache. The request is then retried
// exactly once with a client obtained anew from getClient.
func (t *Transport) doRoundTripOpt(req *http.Request, opt RoundTripOpt, isRetried bool) (*http.Response, error) {
	hostname := authorityAddr(hostnameFromURL(req.URL))
	trace := httptrace.ContextClientTrace(req.Context())
	traceGetConn(trace, hostname)
	cl, isReused, err := t.getClient(req.Context(), hostname, opt.OnlyCachedConn)
	if err != nil {
		return nil, err
	}
	// getClient incremented cl.useCount. Release it on every return path,
	// including giving up while the dial is still in flight and a failed dial.
	// Otherwise the count never returns to zero and CloseIdleConnections can
	// never close this client.
	defer cl.useCount.Add(-1)

	select {
	case <-cl.dialing:
	case <-req.Context().Done():
		return nil, context.Cause(req.Context())
	}
	if cl.dialErr != nil {
		t.forgetClient(hostname, cl)
		return nil, cl.dialErr
	}

	traceGotConn(trace, cl.conn, isReused)
	rsp, err := cl.clientConn.RoundTrip(req)
	if err == nil {
		return rsp, nil
	}
	// Never retry a request whose context has been cancelled.
	select {
	case <-req.Context().Done():
		return nil, err
	default:
	}
	if isRetried {
		return nil, err
	}
	// Drop the failed client even if the request turns out not to be
	// retryable, so the next request dials a fresh connection.
	t.forgetClient(hostname, cl)
	req, err = canRetryRequest(err, req)
	if err != nil {
		return nil, err
	}
	return t.doRoundTripOpt(req, opt, true)
}

// forgetClient removes cl from the connection cache, but only if it is still
// the entry cached under hostname. A concurrent request may already have
// replaced a failed client with a freshly dialed one. Deleting that
// replacement would orphan it, because Close and CloseIdleConnections only
// reach cached clients.
func (t *Transport) forgetClient(hostname string, cl *roundTripperWithCount) {
	t.mutex.Lock()
	defer t.mutex.Unlock()
	// Indexing a nil map yields nil, which never equals the non-nil cl.
	if t.clients[hostname] == cl {
		delete(t.clients, hostname)
	}
}
// canRetryRequest decides whether req may be sent again after a round trip
// failed with err. If so, it returns the request to use for the retry.
//
// A retry is allowed when:
//   - err wraps *errConnUnusable, or
//   - err wraps *Error with code ErrCodeRequestRejected and the body can be
//     replayed: it is nil or http.NoBody, or req.GetBody can produce a fresh
//     copy.
//
// Otherwise it returns a nil request and a non-nil error. That error is err
// itself, the error from GetBody, or an error wrapping err that explains
// that GetBody is missing.
func canRetryRequest(err error, req *http.Request) (*http.Request, error) {
	// NOTE(review): errConnUnusable presumably means the request never left
	// the client (e.g. the stream could not be opened), so resending it is
	// always safe. Confirm at its definition.
	var connErr *errConnUnusable
	if errors.As(err, &connErr) {
		return req, nil
	}

	// For a reset stream, only a "request rejected" code (RFC 9114:
	// H3_REQUEST_REJECTED, i.e. not processed by the server) makes a retry safe.
	var e *Error
	if !errors.As(err, &e) || e.ErrorCode != ErrCodeRequestRejected {
		return nil, err
	}

	// No body was consumed, so the request can be reused as-is.
	if req.Body == nil || req.Body == http.NoBody {
		return req, nil
	}
	if req.GetBody == nil {
		return nil, fmt.Errorf("http3: Transport: cannot retry err [%w] after Request.Body was written; define Request.GetBody to avoid this error", err)
	}
	newBody, err := req.GetBody()
	if err != nil {
		return nil, err
	}
	// Shallow-copy the request so the caller's *http.Request is not modified.
	reqCopy := *req
	reqCopy.Body = newBody
	return &reqCopy, nil
}
// RoundTrip implements http.RoundTripper. It is RoundTripOpt with the zero
// RoundTripOpt, so a new connection is dialed when none is cached.
func (t *Transport) RoundTrip(req *http.Request) (*http.Response, error) {
	var defaults RoundTripOpt
	return t.RoundTripOpt(req, defaults)
}
// getClient returns the cached client for hostname. If none is cached, it
// creates one, caches it, and starts the dial in a background goroutine. The
// dial may still be in progress when getClient returns, so callers must wait
// on rtc.dialing before reading dialErr, conn or clientConn.
//
// On success, the returned client's useCount has been incremented, and the
// caller is responsible for decrementing it.
//
// isReused reports whether the client had already finished dialing, with a
// completed QUIC handshake, at the time of this call.
//
// It returns an error in these cases:
//   - ErrTransportClosed after Close;
//   - ErrNoCachedConn when onlyCached is set and nothing is cached;
//   - the dial error, if a cached client's dial has already failed. That
//     client is also evicted from the cache.
func (t *Transport) getClient(ctx context.Context, hostname string, onlyCached bool) (rtc *roundTripperWithCount, isReused bool, err error) {
	t.mutex.Lock()
	defer t.mutex.Unlock()
	if t.closed {
		return nil, false, ErrTransportClosed
	}
	if t.clients == nil {
		t.clients = make(map[string]*roundTripperWithCount)
	}
	cl, ok := t.clients[hostname]
	if !ok {
		if onlyCached {
			return nil, false, ErrNoCachedConn
		}
		// NOTE(review): the dial context derives from the ctx of the request
		// that triggered the dial. Cancelling that request therefore cancels
		// the dial context shared by every request waiting on this client.
		ctx, cancel := context.WithCancel(ctx)
		cl = &roundTripperWithCount{
			dialing: make(chan struct{}),
			cancel:  cancel,
		}
		go func() {
			// Closing dialing (after all fields below are set) publishes
			// dialErr / conn / clientConn to the waiters.
			defer close(cl.dialing)
			defer cancel()
			conn, rt, err := t.dial(ctx, hostname)
			if err != nil {
				cl.dialErr = err
				return
			}
			cl.conn = conn
			cl.clientConn = rt
		}()
		t.clients[hostname] = cl
	}
	// Non-blocking check: inspect the dial result only if it has already
	// finished.
	select {
	case <-cl.dialing:
		if cl.dialErr != nil {
			delete(t.clients, hostname)
			return nil, false, cl.dialErr
		}
		select {
		case <-cl.conn.HandshakeComplete():
			isReused = true
		default:
		}
	default:
	}
	cl.useCount.Add(1)
	return cl, isReused, nil
}
// dial establishes a new QUIC connection to hostname (a host:port authority)
// and wraps it in a clientConn built by t.newClientConn.
//
// The TLS config is a clone of TLSClientConfig, or an empty config if that
// is nil. Its ServerName defaults to the host part of hostname, and
// NextProtos is always set to NextProtoH3. When t.Dial is nil, the address is
// resolved with resolveUDPAddr and dialed over t.transport, and httptrace
// connect and TLS-handshake events are reported.
//
// On success, dial starts a goroutine that accepts the peer's unidirectional
// streams and passes each one to clientConn.handleUnidirectionalStream. That
// goroutine exits when AcceptUniStream returns an error (presumably when the
// connection closes).
func (t *Transport) dial(ctx context.Context, hostname string) (*quic.Conn, clientConn, error) {
	var tlsConf *tls.Config
	if t.TLSClientConfig == nil {
		tlsConf = &tls.Config{}
	} else {
		tlsConf = t.TLSClientConfig.Clone()
	}
	if tlsConf.ServerName == "" {
		// Use the host part of the authority for SNI. If hostname has no
		// port, SplitHostPort fails and hostname is used verbatim.
		sni, _, err := net.SplitHostPort(hostname)
		if err != nil {
			sni = hostname
		}
		tlsConf.ServerName = sni
	}
	// Always negotiate HTTP/3 via ALPN, overriding any configured NextProtos.
	tlsConf.NextProtos = []string{NextProtoH3}
	dial := t.Dial
	if dial == nil {
		// Built-in dialer: resolve addr, then dial over the quic.Transport
		// that init created, emitting httptrace events around the attempt.
		dial = func(ctx context.Context, addr string, tlsCfg *tls.Config, cfg *quic.Config) (*quic.Conn, error) {
			network := "udp"
			udpAddr, err := t.resolveUDPAddr(ctx, network, addr)
			if err != nil {
				return nil, err
			}
			trace := httptrace.ContextClientTrace(ctx)
			traceConnectStart(trace, network, udpAddr.String())
			traceTLSHandshakeStart(trace)
			conn, err := t.transport.DialEarly(ctx, udpAddr, tlsCfg, cfg)
			// state stays the zero value if DialEarly returned no connection.
			var state tls.ConnectionState
			if conn != nil {
				state = conn.ConnectionState().TLS
			}
			traceTLSHandshakeDone(trace, state, err)
			traceConnectDone(trace, network, udpAddr.String(), err)
			return conn, err
		}
	}
	conn, err := dial(ctx, hostname, tlsConf, t.QUICConfig)
	if err != nil {
		return nil, nil, err
	}
	clientConn := t.newClientConn(conn)
	go func() {
		for {
			str, err := conn.AcceptUniStream(context.Background())
			if err != nil {
				return
			}
			go clientConn.handleUnidirectionalStream(str)
		}
	}()
	return conn, clientConn, nil
}
// resolveUDPAddr resolves addr ("host:port") into a single UDP address.
// Both the port/service lookup and the host lookup use ctx, so cancelling the
// dial also aborts name resolution. When the host has several addresses,
// addrList.forResolve (defined elsewhere) selects the one to use.
func (t *Transport) resolveUDPAddr(ctx context.Context, network, addr string) (*net.UDPAddr, error) {
	host, portStr, err := net.SplitHostPort(addr)
	if err != nil {
		return nil, err
	}
	resolver := net.DefaultResolver
	// net.LookupPort is DefaultResolver.LookupPort with context.Background().
	// Passing ctx lets a cancelled dial abort a named-service lookup
	// (e.g. "https") as well.
	port, err := resolver.LookupPort(ctx, network, portStr)
	if err != nil {
		return nil, err
	}
	ipAddrs, err := resolver.LookupIPAddr(ctx, host)
	if err != nil {
		return nil, err
	}
	addrs := addrList(ipAddrs)
	ip := addrs.forResolve(network, addr)
	return &net.UDPAddr{IP: ip.IP, Port: port, Zone: ip.Zone}, nil
}
// removeClient evicts whatever client is currently cached under hostname.
// Deleting from a nil map is a no-op, so an uninitialized cache needs no
// special case.
func (t *Transport) removeClient(hostname string) {
	t.mutex.Lock()
	delete(t.clients, hostname)
	t.mutex.Unlock()
}
// NewClientConn wraps an already established QUIC connection in an HTTP/3
// ClientConn configured from t's EnableDatagrams, AdditionalSettings,
// MaxResponseHeaderBytes, DisableCompression and Logger fields. It also
// starts a goroutine that passes every unidirectional stream opened by the
// peer to the ClientConn. That goroutine exits once AcceptUniStream fails.
func (t *Transport) NewClientConn(conn *quic.Conn) *ClientConn {
	cc := newClientConn(
		conn,
		t.EnableDatagrams,
		t.AdditionalSettings,
		t.MaxResponseHeaderBytes,
		t.DisableCompression,
		t.Logger,
	)
	acceptUniStreams := func() {
		for {
			str, err := conn.AcceptUniStream(context.Background())
			if err != nil {
				return
			}
			go cc.handleUnidirectionalStream(str)
		}
	}
	go acceptUniStreams()
	return cc
}
// NewRawClientConn wraps an already established QUIC connection in a
// RawClientConn configured from t's EnableDatagrams, AdditionalSettings,
// MaxResponseHeaderBytes, DisableCompression and Logger fields. No goroutine
// is started to accept the peer's unidirectional streams.
// NOTE(review): presumably the caller is responsible for those; confirm.
func (t *Transport) NewRawClientConn(conn *quic.Conn) *RawClientConn {
	cc := newClientConn(
		conn, t.EnableDatagrams, t.AdditionalSettings,
		t.MaxResponseHeaderBytes, t.DisableCompression, t.Logger,
	)
	return &RawClientConn{ClientConn: cc}
}
// Close closes every cached client connection. It also closes the
// quic.Transport and the UDP socket that init opened when no custom Dial
// function was configured. After Close, new requests fail with
// ErrTransportClosed.
//
// Close does not stop at the first failure. Every step is attempted, so the
// Transport always ends up fully shut down, and all errors encountered are
// returned joined together (nil if there were none).
func (t *Transport) Close() error {
	t.mutex.Lock()
	defer t.mutex.Unlock()

	var errs []error
	for _, cl := range t.clients {
		if err := cl.Close(); err != nil {
			errs = append(errs, err)
		}
	}
	t.clients = nil

	if t.transport != nil {
		if err := t.transport.Close(); err != nil {
			errs = append(errs, err)
		}
		// The UDP socket was handed to quic.Transport by init, so it is
		// closed explicitly here.
		if err := t.transport.Conn.Close(); err != nil {
			errs = append(errs, err)
		}
		t.transport = nil
	}

	t.closed = true
	return errors.Join(errs...)
}
// hostnameFromURL returns u.Host (the authority, possibly including a port),
// or "" when u is nil.
func hostnameFromURL(u *url.URL) string {
	if u == nil {
		return ""
	}
	return u.Host
}
// validMethod reports whether method is a non-empty HTTP token, i.e. it
// contains no rune for which isNotToken is true.
func validMethod(method string) bool {
	if method == "" {
		return false
	}
	return !strings.ContainsFunc(method, isNotToken)
}
// isNotToken reports whether r falls outside the HTTP token character set,
// as defined by httpguts.IsTokenRune.
func isNotToken(r rune) bool {
	isToken := httpguts.IsTokenRune(r)
	return !isToken
}
// CloseIdleConnections closes and evicts every cached client whose useCount
// is zero, i.e. one that no request currently holds. Clients in use are
// left untouched.
func (t *Transport) CloseIdleConnections() {
	t.mutex.Lock()
	defer t.mutex.Unlock()
	for host, cl := range t.clients {
		if cl.useCount.Load() != 0 {
			continue
		}
		// Best effort: a close error is deliberately ignored, because the
		// entry is evicted either way.
		_ = cl.Close()
		delete(t.clients, host)
	}
}
The pages are generated with Golds v0.8.2 (GOOS=linux GOARCH=amd64).
Golds is a Go 101 project developed by Tapir Liu.
PRs and bug reports are welcome and can be submitted to the issue list.
Please follow @zigo_101 (reachable from the left QR code) to get the latest news about Golds.