package http3
import (
"context"
"errors"
"fmt"
"io"
"log/slog"
"net/http"
"net/http/httptrace"
"net/textproto"
"sync"
"time"
"github.com/quic-go/qpack"
"github.com/quic-go/quic-go"
"github.com/quic-go/quic-go/http3/qlog"
"github.com/quic-go/quic-go/qlogwriter"
)
// Pseudo request methods for sending a request without waiting for the QUIC
// handshake to complete, so the request may go out as 0-RTT data. roundTrip
// rewrites them to the real GET / HEAD method on a copy of the request.
//
// 0-RTT data is not protected against replay, so only use these for
// idempotent requests.
const (
	// MethodGet0RTT sends a GET request without waiting for the handshake.
	MethodGet0RTT = "GET_0RTT"
	// MethodHead0RTT sends a HEAD request without waiting for the handshake.
	MethodHead0RTT = "HEAD_0RTT"
)

const (
	// defaultUserAgent is the client's User-Agent string. It is not referenced
	// in this part of the file; presumably the request writer applies it when
	// a request sets none (confirm).
	defaultUserAgent = "quic-go HTTP/3"
	// defaultMaxResponseHeaderBytes (10 MiB) is the response header limit
	// newClientConn uses when it is given a non-positive limit.
	defaultMaxResponseHeaderBytes = 10 << 20
)
// errGoAway is returned when a request stream may not be used because the
// server's GOAWAY frame set a limit at or below the stream's ID.
var errGoAway = errors.New("connection in graceful shutdown")

// errConnUnusable wraps the reason a request stream could not be opened on
// this connection. errors.Is / errors.As see through it via Unwrap.
type errConnUnusable struct{ e error }

// Unwrap returns the wrapped cause.
func (e *errConnUnusable) Unwrap() error { return e.e }

// Error returns the cause's message prefixed with "http3: conn unusable: ".
func (e *errConnUnusable) Error() string {
	cause := e.e.Error()
	return fmt.Sprintf("http3: conn unusable: %s", cause)
}

// max1xxResponses is the number of non-terminal 1xx responses doRequest
// tolerates for one request; one more aborts the stream.
const max1xxResponses = 5
// defaultQuicConfig is the package's baseline QUIC configuration.
// NOTE(review): it is not referenced in this part of the file; presumably it
// is used when the caller supplies no quic.Config — confirm.
//
// A negative MaxIncomingStreams stops the peer from opening any bidirectional
// streams (an HTTP/3 server must not open them). The KeepAlivePeriod makes
// the client send keep-alive packets every 10 seconds.
var defaultQuicConfig = &quic.Config{
	KeepAlivePeriod:    10 * time.Second,
	MaxIncomingStreams: -1,
}
// ClientConn is an HTTP/3 client connection running on top of a QUIC
// connection. It satisfies http.RoundTripper (asserted right after the type).
type ClientConn struct {
// conn is the underlying QUIC connection.
conn *quic .Conn
// rawConn wraps conn: it tracks the request streams (TrackStream), hands the
// server's control stream to handleControlStream and exposes the server's
// SETTINGS.
rawConn *rawConn
// decoder is the QPACK decoder passed to every request stream and used to
// decode response trailers.
decoder *qpack .Decoder
// additionalSettings are extra SETTINGS entries sent to the server on the
// control stream.
additionalSettings map [uint64 ]uint64
// maxResponseHeaderBytes is the header size limit given to request streams;
// it is also advertised to the server as MaxFieldSectionSize.
maxResponseHeaderBytes int
// disableCompression is forwarded to every request stream opened by this
// connection.
disableCompression bool
// streamMx guards maxStreamID and lastStreamID.
streamMx sync .Mutex
// maxStreamID is the stream ID carried by the most recent GOAWAY frame.
// Request streams with an ID at or above it must not be used.
// It is invalidStreamID until a GOAWAY has been received.
maxStreamID quic .StreamID
// lastStreamID is the highest request stream ID opened so far, or
// invalidStreamID before the first one.
lastStreamID quic .StreamID
// qlogger records HTTP/3 qlog events. It stays nil if the connection has no
// qlog trace, or if the trace doesn't support the HTTP/3 event schema.
qlogger qlogwriter .Recorder
// logger is used for debug messages and may be nil.
logger *slog .Logger
// requestWriter is shared by all request streams opened through
// OpenRequestStream and roundTrip (presumably it encodes requests — confirm).
requestWriter *requestWriter
}
var _ http .RoundTripper = &ClientConn {}
// newClientConn wraps an established QUIC connection in a ClientConn.
//
// A non-positive maxResponseHeaderBytes selects defaultMaxResponseHeaderBytes.
// The resulting limit is given to request streams and also advertised to the
// server as MaxFieldSectionSize in the client's SETTINGS.
//
// The control stream carrying those SETTINGS (datagram support, the
// additional settings and the header limit) is opened in a background
// goroutine. If that fails, the error is logged at debug level and the
// connection is closed with ErrCodeInternalError.
func newClientConn(
	conn *quic.Conn,
	enableDatagrams bool,
	additionalSettings map[uint64]uint64,
	maxResponseHeaderBytes int,
	disableCompression bool,
	logger *slog.Logger,
) *ClientConn {
	// Emit HTTP/3 qlog events only if the connection's trace accepts them.
	var recorder qlogwriter.Recorder
	if trace := conn.QlogTrace(); trace != nil && trace.SupportsSchemas(qlog.EventSchema) {
		recorder = trace.AddProducer()
	}

	headerLimit := maxResponseHeaderBytes
	if headerLimit <= 0 {
		headerLimit = defaultMaxResponseHeaderBytes
	}

	c := &ClientConn{
		conn:                   conn,
		decoder:                qpack.NewDecoder(),
		additionalSettings:     additionalSettings,
		maxResponseHeaderBytes: headerLimit,
		disableCompression:     disableCompression,
		maxStreamID:            invalidStreamID,
		lastStreamID:           invalidStreamID,
		qlogger:                recorder,
		logger:                 logger,
		requestWriter:          newRequestWriter(),
	}
	c.rawConn = newRawConn(conn, enableDatagrams, c.onStreamsEmpty, c.handleControlStream, recorder, logger)

	settings := &settingsFrame{
		Datagram:            enableDatagrams,
		Other:               additionalSettings,
		MaxFieldSectionSize: int64(headerLimit),
	}
	go func() {
		if _, err := c.rawConn.openControlStream(settings); err != nil {
			if logger != nil {
				logger.Debug("setting up connection failed", "error", err)
			}
			conn.CloseWithError(quic.ApplicationErrorCode(ErrCodeInternalError), "")
		}
	}()
	return c
}
func (c *ClientConn ) OpenRequestStream (ctx context .Context ) (*RequestStream , error ) {
return c .openRequestStream (ctx , c .requestWriter , nil , c .disableCompression , c .maxResponseHeaderBytes )
}
func (c *ClientConn ) openRequestStream (
ctx context .Context ,
requestWriter *requestWriter ,
reqDone chan <- struct {},
disableCompression bool ,
maxHeaderBytes int ,
) (*RequestStream , error ) {
c .streamMx .Lock ()
maxStreamID := c .maxStreamID
var nextStreamID quic .StreamID
if c .lastStreamID == invalidStreamID {
nextStreamID = 0
} else {
nextStreamID = c .lastStreamID + 4
}
c .streamMx .Unlock ()
if maxStreamID != invalidStreamID && nextStreamID >= maxStreamID {
return nil , errGoAway
}
str , err := c .conn .OpenStreamSync (ctx )
if err != nil {
return nil , err
}
c .streamMx .Lock ()
if c .lastStreamID == invalidStreamID {
c .lastStreamID = str .StreamID ()
} else {
c .lastStreamID = max (c .lastStreamID , str .StreamID ())
}
maxStreamID = c .maxStreamID
c .streamMx .Unlock ()
if maxStreamID != invalidStreamID && str .StreamID () >= maxStreamID {
str .CancelRead (quic .StreamErrorCode (ErrCodeRequestCanceled ))
str .CancelWrite (quic .StreamErrorCode (ErrCodeRequestCanceled ))
return nil , errGoAway
}
hstr := c .rawConn .TrackStream (str )
rsp := &http .Response {}
trace := httptrace .ContextClientTrace (ctx )
return newRequestStream (
newStream (hstr , c .rawConn , trace , func (r io .Reader , hf *headersFrame ) error {
hdr , err := decodeTrailers (r , hf , maxHeaderBytes , c .decoder , c .qlogger , str .StreamID ())
if err != nil {
return err
}
rsp .Trailer = hdr
return nil
}, c .qlogger ),
requestWriter ,
reqDone ,
c .decoder ,
disableCompression ,
maxHeaderBytes ,
rsp ,
), nil
}
// handleUnidirectionalStream passes a unidirectional stream received from the
// server on to rawConn. The literal false is forwarded unchanged.
// NOTE(review): presumably it tells rawConn that this side is not the server —
// confirm against rawConn.handleUnidirectionalStream.
func (c *ClientConn) handleUnidirectionalStream(str *quic.ReceiveStream) {
	rc := c.rawConn
	rc.handleUnidirectionalStream(str, false)
}
// handleControlStream reads frames from the server's control stream and
// expects only GOAWAY frames. It is registered with rawConn in newClientConn;
// presumably rawConn has already consumed the SETTINGS frame (confirm).
//
// Each violation closes the whole connection:
//   - io.EOF or a *quic.StreamError (control stream ended/reset): ErrCodeClosedCriticalStream
//   - any other parse error: ErrCodeFrameError
//   - a frame other than GOAWAY: ErrCodeFrameUnexpected
//   - a GOAWAY ID that is not a multiple of 4: ErrCodeIDError
//   - a GOAWAY ID larger than a previously received one: ErrCodeIDError
//
// A valid GOAWAY lowers maxStreamID. If no request streams are active, the
// connection is then closed gracefully with ErrCodeNoError.
func (c *ClientConn) handleControlStream(str *quic.ReceiveStream, fp *frameParser) {
	closeConn := func(code quic.ApplicationErrorCode) {
		c.conn.CloseWithError(code, "")
	}
	for {
		f, err := fp.ParseNext(c.qlogger)
		if err != nil {
			var serr *quic.StreamError
			if err == io.EOF || errors.As(err, &serr) {
				closeConn(quic.ApplicationErrorCode(ErrCodeClosedCriticalStream))
			} else {
				closeConn(quic.ApplicationErrorCode(ErrCodeFrameError))
			}
			return
		}

		goaway, isGoAway := f.(*goAwayFrame)
		if !isGoAway {
			closeConn(quic.ApplicationErrorCode(ErrCodeFrameUnexpected))
			return
		}
		// The ID must name a client-initiated bidirectional stream.
		if goaway.StreamID%4 != 0 {
			closeConn(quic.ApplicationErrorCode(ErrCodeIDError))
			return
		}

		// A later GOAWAY may only keep or lower the limit, never raise it.
		c.streamMx.Lock()
		raised := c.maxStreamID != invalidStreamID && goaway.StreamID > c.maxStreamID
		if !raised {
			c.maxStreamID = goaway.StreamID
		}
		c.streamMx.Unlock()
		if raised {
			closeConn(quic.ApplicationErrorCode(ErrCodeIDError))
			return
		}

		if !c.rawConn.hasActiveStreams() {
			c.CloseWithError(quic.ApplicationErrorCode(ErrCodeNoError), "")
			return
		}
	}
}
// onStreamsEmpty is the rawConn callback for when no request streams remain
// active. If a GOAWAY has been received (maxStreamID is set), the connection
// is closed with ErrCodeNoError. Otherwise nothing happens.
func (c *ClientConn) onStreamsEmpty() {
	c.streamMx.Lock()
	defer c.streamMx.Unlock()

	if c.maxStreamID == invalidStreamID {
		return // no GOAWAY yet: keep the connection open
	}
	c.conn.CloseWithError(quic.ApplicationErrorCode(ErrCodeNoError), "")
}
// RoundTrip implements http.RoundTripper for a single HTTP/3 request on this
// connection. If an error occurs and the request's context is already done,
// the context's error replaces the original error.
func (c *ClientConn) RoundTrip(req *http.Request) (*http.Response, error) {
	rsp, err := c.roundTrip(req)
	if err == nil {
		return rsp, nil
	}
	if ctxErr := req.Context().Err(); ctxErr != nil {
		return rsp, ctxErr
	}
	return rsp, err
}
// roundTrip sends req on a new request stream of this connection and returns
// the final (non-1xx) response.
//
// MethodGet0RTT and MethodHead0RTT are rewritten to GET and HEAD on a shallow
// copy of req, so the caller's request is not modified, and are sent without
// waiting for the handshake. Any other method first waits for the handshake
// to complete or for the request context to be done.
//
// Extended CONNECT requests also wait for the server's SETTINGS, and fail if
// the server did not enable Extended CONNECT. The wait ends early if the
// connection is closed (returning its cause) or the request context is done
// (returning the context's error).
//
// Errors from opening the request stream are wrapped in errConnUnusable.
func (c *ClientConn) roundTrip(req *http.Request) (*http.Response, error) {
	switch req.Method {
	case MethodGet0RTT:
		// Work on a copy: the caller's request must not be modified.
		reqCopy := *req
		req = &reqCopy
		req.Method = http.MethodGet
	case MethodHead0RTT:
		reqCopy := *req
		req = &reqCopy
		req.Method = http.MethodHead
	default:
		select {
		case <-c.conn.HandshakeComplete():
		case <-req.Context().Done():
			return nil, req.Context().Err()
		}
	}

	if isExtendedConnectRequest(req) {
		connCtx := c.conn.Context()
		select {
		case <-c.rawConn.ReceivedSettings():
		case <-connCtx.Done():
			return nil, context.Cause(connCtx)
		case <-req.Context().Done():
			// Honor the request's cancellation/deadline even if the server is
			// slow to send (or never sends) its SETTINGS frame.
			return nil, req.Context().Err()
		}
		if !c.rawConn.Settings().EnableExtendedConnect {
			return nil, errors.New("http3: server didn't enable Extended CONNECT")
		}
	}

	reqDone := make(chan struct{})
	str, err := c.openRequestStream(
		req.Context(),
		c.requestWriter,
		reqDone,
		c.disableCompression,
		c.maxResponseHeaderBytes,
	)
	if err != nil {
		return nil, &errConnUnusable{e: err}
	}

	// Cancel both directions of the stream if the request context is done
	// before reqDone is closed. On success this goroutine outlives roundTrip.
	// reqDone is handed to the request stream, which presumably closes it
	// once the response has been consumed — TODO confirm.
	done := make(chan struct{})
	go func() {
		defer close(done)
		select {
		case <-req.Context().Done():
			str.CancelWrite(quic.StreamErrorCode(ErrCodeRequestCanceled))
			str.CancelRead(quic.StreamErrorCode(ErrCodeRequestCanceled))
		case <-reqDone:
		}
	}()

	rsp, err := c.doRequest(req, str)
	if err != nil {
		// Stop the cancellation goroutine before reporting the failure.
		close(reqDone)
		<-done
		return nil, maybeReplaceError(err)
	}
	return rsp, maybeReplaceError(err)
}
// ReceivedSettings returns a channel that becomes ready once the server's
// SETTINGS frame has been received. roundTrip waits on the same rawConn
// channel before inspecting Settings.
func (c *ClientConn) ReceivedSettings() <-chan struct{} {
	rc := c.rawConn
	return rc.ReceivedSettings()
}
// Settings returns the settings announced by the server, as tracked by the
// underlying rawConn. Only rely on them after ReceivedSettings is ready.
// NOTE(review): the value before that point is not visible here — confirm.
func (c *ClientConn) Settings() *Settings {
	rc := c.rawConn
	return rc.Settings()
}
// CloseWithError closes the underlying QUIC connection with the given
// application error code and reason, and returns the QUIC layer's result.
func (c *ClientConn) CloseWithError(code quic.ApplicationErrorCode, msg string) error {
	qc := c.conn
	return qc.CloseWithError(code, msg)
}
// Context returns the context of the underlying QUIC connection. Presumably
// it is done once the connection closes; roundTrip treats it that way and
// reports context.Cause when it fires.
func (c *ClientConn) Context() context.Context {
	qc := c.conn
	return qc.Context()
}
// cancelingReader wraps a request body that is being uploaded on str.
// If reading the body fails with anything other than io.EOF, it cancels the
// write side of str with ErrCodeRequestCanceled before returning the error.
type cancelingReader struct {
	r   io.Reader
	str *RequestStream
}

// Read reads from the wrapped body and returns its n and err unchanged.
func (r *cancelingReader) Read(b []byte) (int, error) {
	n, err := r.r.Read(b)
	switch err {
	case nil, io.EOF:
		// data or a clean end of the body: nothing to cancel
	default:
		r.str.CancelWrite(quic.StreamErrorCode(ErrCodeRequestCanceled))
	}
	return n, err
}
func (c *ClientConn ) sendRequestBody (str *RequestStream , body io .ReadCloser , contentLength int64 ) error {
defer body .Close ()
buf := make ([]byte , bodyCopyBufferSize )
sr := &cancelingReader {str : str , r : body }
if contentLength == -1 {
_ , err := io .CopyBuffer (str , sr , buf )
return err
}
n , err := io .CopyBuffer (str , io .LimitReader (sr , contentLength ), buf )
if err != nil {
return err
}
var extra int64
extra , err = io .CopyBuffer (io .Discard , sr , buf )
n += extra
if n > contentLength {
str .CancelWrite (quic .StreamErrorCode (ErrCodeRequestCanceled ))
return fmt .Errorf ("http: ContentLength=%d with Body length %d" , contentLength , n )
}
return err
}
// doRequest writes req to str and reads the response from it.
//
// The request header is written synchronously. If that fails, the error is
// reported to the httptrace hooks and logged, and nothing else is sent. The
// function still goes on to read from the stream, so the caller sees whatever
// ReadResponse then returns.
//
// If the header was written, a nil Body closes the stream at once. A non-nil
// Body is sent by a background goroutine, which then sends the trailers (if
// req.Trailer is non-empty and the body succeeded) and closes the stream.
//
// Non-terminal 1xx responses (100-199 except 101) are reported to the
// httptrace hooks and skipped. More than max1xxResponses of them cancel both
// directions with ErrCodeExcessiveLoad. The final response gets the
// connection's TLS state and req attached.
func (c *ClientConn ) doRequest (req *http .Request , str *RequestStream ) (*http .Response , error ) {
trace := httptrace .ContextClientTrace (req .Context ())
var sendingReqFailed bool
if err := str .sendRequestHeader (req ); err != nil {
traceWroteRequest (trace , err )
if c .logger != nil {
c .logger .Debug ("error writing request" , "error" , err )
}
sendingReqFailed = true
}
if !sendingReqFailed {
if req .Body == nil {
// No body: the request is complete once the header is written.
traceWroteRequest (trace , nil )
str .Close ()
} else {
// Send the body asynchronously so the response can be read
// concurrently. This goroutine owns closing the stream.
go func () {
defer str .Close ()
// A ContentLength of 0 with a non-nil Body is treated as unknown
// length (-1), just like a negative ContentLength.
contentLength := int64 (-1 )
if req .ContentLength > 0 {
contentLength = req .ContentLength
}
err := c .sendRequestBody (str , req .Body , contentLength )
traceWroteRequest (trace , err )
if err != nil {
if c .logger != nil {
c .logger .Debug ("error writing request" , "error" , err )
}
// Don't send trailers after a failed body.
return
}
if len (req .Trailer ) > 0 {
if err := str .sendRequestTrailer (req ); err != nil {
if c .logger != nil {
c .logger .Debug ("error writing trailers" , "error" , err )
}
}
}
}()
}
}
// Read responses until a final (non-1xx) one arrives.
var num1xx int
var res *http .Response
for {
var err error
res , err = str .ReadResponse ()
if err != nil {
return nil , err
}
resCode := res .StatusCode
is1xx := 100 <= resCode && resCode <= 199
// 101 Switching Protocols is treated as a final response.
is1xxNonTerminal := is1xx && resCode != http .StatusSwitchingProtocols
if is1xxNonTerminal {
num1xx ++
// Bound the number of informational responses a server may send.
if num1xx > max1xxResponses {
str .CancelRead (quic .StreamErrorCode (ErrCodeExcessiveLoad ))
str .CancelWrite (quic .StreamErrorCode (ErrCodeExcessiveLoad ))
return nil , errors .New ("http3: too many 1xx informational responses" )
}
traceGot1xxResponse (trace , resCode , textproto .MIMEHeader (res .Header ))
if resCode == http .StatusContinue {
traceGot100Continue (trace )
}
continue
}
break
}
// Attach the connection's TLS state and the originating request.
connState := c .conn .ConnectionState ().TLS
res .TLS = &connState
res .Request = req
return res , nil
}
// RawClientConn embeds *ClientConn and exports stream handling for callers
// that accept the QUIC connection's streams themselves (presumably — confirm
// with its constructor). All ClientConn methods are promoted onto it.
type RawClientConn struct {
	*ClientConn
}

// HandleUnidirectionalStream passes a unidirectional stream received from the
// server on to the embedded connection's rawConn. The literal false is
// forwarded unchanged, as in ClientConn's unexported counterpart.
func (c *RawClientConn) HandleUnidirectionalStream(str *quic.ReceiveStream) {
	c.ClientConn.rawConn.handleUnidirectionalStream(str, false)
}
// HandleBidirectionalStream is called with a bidirectional stream opened by
// the server. HTTP/3 servers must not open such streams, so the whole
// connection is closed with ErrCodeStreamCreationError. The reason names the
// offending stream ID.
func (c *ClientConn) HandleBidirectionalStream(str *quic.Stream) {
	reason := fmt.Sprintf("server opened bidirectional stream %d", str.StreamID())
	c.rawConn.CloseWithError(quic.ApplicationErrorCode(ErrCodeStreamCreationError), reason)
}
These pages were generated with Golds v0.8.2 (GOOS=linux GOARCH=amd64).
Golds is a Go 101 project developed by Tapir Liu.
PRs and bug reports are welcome and can be submitted to the issue list.
Please follow @zigo_101 (reachable from the QR code on the left) to get the latest news about Golds.