package http3
import (
"context"
"errors"
"fmt"
"io"
"log/slog"
"net"
"net/http"
"net/http/httptrace"
"sync"
"sync/atomic"
"time"
"github.com/quic-go/quic-go"
"github.com/quic-go/quic-go/internal/protocol"
"github.com/quic-go/quic-go/quicvarint"
"github.com/quic-go/qpack"
)
var (
	// errGoAway is returned by openRequestStream when a GOAWAY frame received
	// from the peer no longer permits opening another request stream.
	errGoAway = errors.New("connection in graceful shutdown")
)
// Connection is the HTTP/3 connection interface implemented by *connection.
// It exposes a subset of quic.Connection together with access to the HTTP/3
// SETTINGS received from the peer.
type Connection interface {
	// Opening streams (bidirectional and unidirectional).
	OpenStream() (quic.Stream, error)
	OpenStreamSync(context.Context) (quic.Stream, error)
	OpenUniStream() (quic.SendStream, error)
	OpenUniStreamSync(context.Context) (quic.SendStream, error)

	// Connection metadata.
	LocalAddr() net.Addr
	RemoteAddr() net.Addr
	ConnectionState() quic.ConnectionState
	Context() context.Context

	// CloseWithError closes the connection with the given application error
	// code and reason.
	CloseWithError(quic.ApplicationErrorCode, string) error

	// ReceivedSettings returns a channel that is closed once the peer's
	// SETTINGS frame has been received.
	ReceivedSettings() <-chan struct{}
	// Settings returns the settings received from the peer. The value is
	// only populated once the ReceivedSettings channel has been closed.
	Settings() *Settings
}
// connection wraps a quic.Connection with HTTP/3 connection-level state: the
// QPACK decoder, the set of tracked request streams and their datagram
// demultiplexers, GOAWAY bookkeeping, the peer's SETTINGS and an optional
// idle timer.
type connection struct {
	quic.Connection

	ctx         context.Context
	perspective protocol.Perspective
	logger      *slog.Logger

	enableDatagrams bool
	decoder         *qpack.Decoder

	// streamMx guards streams, lastStreamID and maxStreamID.
	streamMx sync.Mutex
	// streams maps tracked stream IDs to their datagram demultiplexer.
	streams map[protocol.StreamID]*datagrammer
	// lastStreamID is the ID of the most recently opened request stream.
	lastStreamID protocol.StreamID
	// maxStreamID is the stream ID carried by the last GOAWAY frame, or
	// protocol.InvalidStreamID if none has been received.
	maxStreamID protocol.StreamID

	// settings is written by handleControlStream right before
	// receivedSettings is closed.
	settings         *Settings
	receivedSettings chan struct{}

	idleTimeout time.Duration
	idleTimer   *time.Timer
}
// newConnection wraps quicConn in a connection. Both lastStreamID and
// maxStreamID start out as protocol.InvalidStreamID. When idleTimeout is
// positive, an idle timer is started that closes the connection via
// onIdleTimer once it fires.
func newConnection(
	ctx context.Context,
	quicConn quic.Connection,
	enableDatagrams bool,
	perspective protocol.Perspective,
	logger *slog.Logger,
	idleTimeout time.Duration,
) *connection {
	conn := &connection{
		Connection:       quicConn,
		ctx:              ctx,
		perspective:      perspective,
		logger:           logger,
		enableDatagrams:  enableDatagrams,
		decoder:          qpack.NewDecoder(func(qpack.HeaderField) {}),
		streams:          make(map[protocol.StreamID]*datagrammer),
		lastStreamID:     protocol.InvalidStreamID,
		maxStreamID:      protocol.InvalidStreamID,
		receivedSettings: make(chan struct{}),
		idleTimeout:      idleTimeout,
	}
	if idleTimeout <= 0 {
		return conn
	}
	conn.idleTimer = time.AfterFunc(idleTimeout, conn.onIdleTimer)
	return conn
}
// onIdleTimer is the idle timer's callback: it closes the connection with
// ErrCodeNoError and the reason "idle timeout".
func (c *connection) onIdleTimer() {
	code := quic.ApplicationErrorCode(ErrCodeNoError)
	c.CloseWithError(code, "idle timeout")
}
// clearStream stops tracking the stream with the given ID. Once no tracked
// streams remain, it re-arms the idle timer (when an idle timeout is
// configured). If a GOAWAY has been received (maxStreamID is set), it also
// closes the connection with ErrCodeNoError.
func (c *connection) clearStream(id quic.StreamID) {
	c.streamMx.Lock()
	defer c.streamMx.Unlock()

	delete(c.streams, id)
	if len(c.streams) != 0 {
		return
	}
	if c.idleTimeout > 0 {
		c.idleTimer.Reset(c.idleTimeout)
	}
	// Graceful shutdown is in progress and nothing is left in flight.
	if c.maxStreamID != protocol.InvalidStreamID {
		c.CloseWithError(quic.ApplicationErrorCode(ErrCodeNoError), "")
	}
}
// openRequestStream opens a new bidirectional stream and wraps it in a
// requestStream.
//
// On the client side, it first checks for a received GOAWAY. If the next
// stream ID (0 for the first stream, otherwise lastStreamID+4) is at or above
// the stream ID carried by that GOAWAY, it fails with errGoAway.
//
// The opened stream is registered in c.streams together with its datagrammer
// and recorded as lastStreamID. Trailers decoded on the stream (bounded by
// maxHeaderBytes) are stored in the Trailer field of the response handed to
// newRequestStream.
func (c *connection) openRequestStream(
	ctx context.Context,
	requestWriter *requestWriter,
	reqDone chan<- struct{},
	disableCompression bool,
	maxHeaderBytes uint64,
) (*requestStream, error) {
	if c.perspective == protocol.PerspectiveClient {
		c.streamMx.Lock()
		limit, last := c.maxStreamID, c.lastStreamID
		c.streamMx.Unlock()

		next := quic.StreamID(0)
		if last != protocol.InvalidStreamID {
			next = last + 4
		}
		// Streams at or above the GOAWAY's stream ID would be rejected by the peer.
		if limit != protocol.InvalidStreamID && next >= limit {
			return nil, errGoAway
		}
	}

	str, err := c.OpenStreamSync(ctx)
	if err != nil {
		return nil, err
	}
	id := str.StreamID()
	datagrams := newDatagrammer(func(b []byte) error { return c.sendDatagram(str.StreamID(), b) })

	c.streamMx.Lock()
	c.lastStreamID = id
	c.streams[id] = datagrams
	c.streamMx.Unlock()

	rsp := &http.Response{}
	onTrailers := func(r io.Reader, l uint64) error {
		trailers, err := c.decodeTrailers(r, l, maxHeaderBytes)
		if err != nil {
			return err
		}
		rsp.Trailer = trailers
		return nil
	}
	hstr := newStream(newStateTrackingStream(str, c, datagrams), c, datagrams, onTrailers)
	return newRequestStream(
		hstr,
		requestWriter,
		reqDone,
		c.decoder,
		disableCompression,
		maxHeaderBytes,
		rsp,
		httptrace.ContextClientTrace(ctx),
	), nil
}
// decodeTrailers reads an l-byte QPACK-encoded HEADERS payload from r and
// converts it into trailer headers via parseTrailers. Payloads larger than
// maxHeaderBytes are rejected before any buffer is allocated.
func (c *connection) decodeTrailers(r io.Reader, l, maxHeaderBytes uint64) (http.Header, error) {
	if maxHeaderBytes < l {
		return nil, fmt.Errorf("HEADERS frame too large: %d bytes (max: %d)", l, maxHeaderBytes)
	}
	payload := make([]byte, l)
	if _, readErr := io.ReadFull(r, payload); readErr != nil {
		return nil, readErr
	}
	fields, decodeErr := c.decoder.DecodeFull(payload)
	if decodeErr != nil {
		return nil, decodeErr
	}
	return parseTrailers(fields)
}
// acceptStream accepts the next bidirectional stream from the peer and creates
// a datagrammer for it. The stream is returned unchanged for a
// non-server perspective.
//
// On the server side, the stream is additionally registered in c.streams and
// wrapped in a state-tracking stream. When it becomes the only tracked stream,
// the idle timer (if configured) is stopped.
func (c *connection) acceptStream(ctx context.Context) (quic.Stream, *datagrammer, error) {
	str, err := c.AcceptStream(ctx)
	if err != nil {
		return nil, nil, err
	}
	datagrams := newDatagrammer(func(b []byte) error { return c.sendDatagram(str.StreamID(), b) })
	if c.perspective != protocol.PerspectiveServer {
		return str, datagrams, nil
	}

	c.streamMx.Lock()
	c.streams[str.StreamID()] = datagrams
	if c.idleTimeout > 0 && len(c.streams) == 1 {
		c.idleTimer.Stop()
	}
	c.streamMx.Unlock()

	return newStateTrackingStream(str, c, datagrams), datagrams, nil
}
// CloseWithError stops the idle timer, if one was started, and then closes the
// underlying QUIC connection with the given application error code and
// reason.
func (c *connection) CloseWithError(code quic.ApplicationErrorCode, msg string) error {
	if timer := c.idleTimer; timer != nil {
		timer.Stop()
	}
	return c.Connection.CloseWithError(code, msg)
}
// handleUnidirectionalStreams accepts unidirectional streams from the peer
// until AcceptUniStream returns an error. Each stream is dispatched on its own
// goroutine according to the stream type read from its first varint:
//
//   - control stream: at most one is allowed; it is passed to
//     handleControlStream.
//   - QPACK encoder/decoder streams: only checked for duplicates; their
//     contents are not read here.
//   - push streams: the connection is closed (ErrCodeIDError on the client,
//     ErrCodeStreamCreationError on the server).
//   - any other type: offered to hijack (if non-nil); if not hijacked,
//     reading is cancelled with ErrCodeStreamCreationError.
//
// If reading the stream type fails, hijack (if non-nil) is called with the
// read error. If it does not take the stream, the failure is logged at debug
// level.
func (c *connection) handleUnidirectionalStreams(hijack func(StreamType, quic.ConnectionTracingID, quic.ReceiveStream, error) (hijacked bool)) {
	var (
		rcvdControlStr      atomic.Bool
		rcvdQPACKEncoderStr atomic.Bool
		rcvdQPACKDecoderStr atomic.Bool
	)

	// The tracing ID is stored on the QUIC connection's own context.
	// c.Context() is shadowed by connection.Context and returns c.ctx, which
	// is not guaranteed to carry this key, so an unchecked type assertion on
	// it could panic. Read the ID from the embedded connection instead, using
	// the comma-ok form.
	tracingID, _ := c.Connection.Context().Value(quic.ConnectionTracingKey).(quic.ConnectionTracingID)

	for {
		str, err := c.AcceptUniStream(context.Background())
		if err != nil {
			if c.logger != nil {
				c.logger.Debug("accepting unidirectional stream failed", "error", err)
			}
			return
		}

		go func(str quic.ReceiveStream) {
			streamType, err := quicvarint.Read(quicvarint.NewReader(str))
			if err != nil {
				if hijack != nil && hijack(StreamType(streamType), tracingID, str, err) {
					return
				}
				if c.logger != nil {
					c.logger.Debug("reading stream type on stream failed", "stream ID", str.StreamID(), "error", err)
				}
				return
			}

			switch streamType {
			case streamTypeControlStream:
				// Only a single control stream is allowed. c.CloseWithError
				// (not the bare QUIC close) is used so the idle timer is stopped too.
				if isFirst := rcvdControlStr.CompareAndSwap(false, true); !isFirst {
					c.CloseWithError(quic.ApplicationErrorCode(ErrCodeStreamCreationError), "duplicate control stream")
					return
				}
				c.handleControlStream(str)
			case streamTypeQPACKEncoderStream:
				if isFirst := rcvdQPACKEncoderStr.CompareAndSwap(false, true); !isFirst {
					c.CloseWithError(quic.ApplicationErrorCode(ErrCodeStreamCreationError), "duplicate QPACK encoder stream")
				}
			case streamTypeQPACKDecoderStream:
				if isFirst := rcvdQPACKDecoderStr.CompareAndSwap(false, true); !isFirst {
					c.CloseWithError(quic.ApplicationErrorCode(ErrCodeStreamCreationError), "duplicate QPACK decoder stream")
				}
			case streamTypePushStream:
				// RFC 9114 §4.6 / §6.2.2: a client that never allowed pushes
				// treats a push stream as an ID error. A server must never
				// receive one.
				switch c.perspective {
				case protocol.PerspectiveClient:
					c.CloseWithError(quic.ApplicationErrorCode(ErrCodeIDError), "")
				case protocol.PerspectiveServer:
					c.CloseWithError(quic.ApplicationErrorCode(ErrCodeStreamCreationError), "")
				}
			default:
				if hijack != nil && hijack(StreamType(streamType), tracingID, str, nil) {
					return
				}
				str.CancelRead(quic.StreamErrorCode(ErrCodeStreamCreationError))
			}
		}(str)
	}
}
// handleControlStream processes the peer's control stream.
//
// The first frame must be SETTINGS. Its values are stored in c.settings, and
// then receivedSettings is closed. If the peer advertised datagram support, a
// goroutine running receiveDatagrams is started. If we enabled datagrams but
// the QUIC layer does not support them, the connection is closed instead.
//
// On the client side, the function then keeps reading GOAWAY frames and
// records each one's stream ID in maxStreamID. Any other frame, a GOAWAY ID
// that is not a client-initiated bidirectional stream ID, or a GOAWAY ID
// larger than a previous one closes the connection. When a GOAWAY arrives
// and no streams are active, the connection is closed with ErrCodeNoError.
//
// All closes go through c.CloseWithError so the idle timer is stopped as well.
func (c *connection) handleControlStream(str quic.ReceiveStream) {
	fp := &frameParser{conn: c.Connection, r: str}

	// closeOnParseErr maps a frame-parsing error to a connection error. The
	// control stream is critical, so its closure (EOF or a stream error) is
	// ErrCodeClosedCriticalStream; anything else is treated as a frame error.
	closeOnParseErr := func(err error) {
		var serr *quic.StreamError
		if errors.Is(err, io.EOF) || errors.As(err, &serr) {
			c.CloseWithError(quic.ApplicationErrorCode(ErrCodeClosedCriticalStream), "")
			return
		}
		c.CloseWithError(quic.ApplicationErrorCode(ErrCodeFrameError), "")
	}

	f, err := fp.ParseNext()
	if err != nil {
		closeOnParseErr(err)
		return
	}
	sf, ok := f.(*settingsFrame)
	if !ok {
		c.CloseWithError(quic.ApplicationErrorCode(ErrCodeMissingSettings), "")
		return
	}
	c.settings = &Settings{
		EnableDatagrams:       sf.Datagram,
		EnableExtendedConnect: sf.ExtendedConnect,
		Other:                 sf.Other,
	}
	close(c.receivedSettings)
	if sf.Datagram {
		// If both sides enabled HTTP/3 datagrams, the QUIC layer must support them too.
		if c.enableDatagrams && !c.ConnectionState().SupportsDatagrams {
			c.CloseWithError(quic.ApplicationErrorCode(ErrCodeSettingsError), "missing QUIC Datagram support")
			return
		}
		go func() {
			if err := c.receiveDatagrams(); err != nil {
				if c.logger != nil {
					c.logger.Debug("receiving datagrams failed", "error", err)
				}
			}
		}()
	}

	// Only the client side processes further control frames (GOAWAY) here.
	if c.perspective == protocol.PerspectiveServer {
		return
	}

	for {
		f, err := fp.ParseNext()
		if err != nil {
			closeOnParseErr(err)
			return
		}
		goaway, ok := f.(*goAwayFrame)
		if !ok {
			c.CloseWithError(quic.ApplicationErrorCode(ErrCodeFrameUnexpected), "")
			return
		}
		// A server's GOAWAY must carry a client-initiated bidirectional stream ID.
		if goaway.StreamID%4 != 0 {
			c.CloseWithError(quic.ApplicationErrorCode(ErrCodeIDError), "")
			return
		}
		c.streamMx.Lock()
		// The peer must not increase the stream ID across GOAWAY frames.
		if c.maxStreamID != protocol.InvalidStreamID && goaway.StreamID > c.maxStreamID {
			c.streamMx.Unlock()
			c.CloseWithError(quic.ApplicationErrorCode(ErrCodeIDError), "")
			return
		}
		c.maxStreamID = goaway.StreamID
		hasActiveStreams := len(c.streams) > 0
		c.streamMx.Unlock()

		// Nothing in flight: finish the graceful shutdown right away.
		if !hasActiveStreams {
			c.CloseWithError(quic.ApplicationErrorCode(ErrCodeNoError), "")
			return
		}
	}
}
// sendDatagram sends b as an HTTP/3 datagram associated with streamID. The
// payload is prefixed with the quarter stream ID (streamID/4), encoded as a
// QUIC variable-length integer.
func (c *connection) sendDatagram(streamID protocol.StreamID, b []byte) error {
	// A QUIC varint occupies at most 8 bytes, so this capacity fits the prefix.
	framed := quicvarint.Append(make([]byte, 0, len(b)+8), uint64(streamID/4))
	framed = append(framed, b...)
	return c.SendDatagram(framed)
}
// receiveDatagrams reads QUIC datagrams until ReceiveDatagram returns an error,
// and routes each payload to the datagrammer registered for its stream.
//
// Each datagram starts with the quarter stream ID (stream ID / 4), encoded as
// a QUIC varint. If that prefix cannot be parsed, or exceeds
// maxQuarterStreamID, the connection is closed with ErrCodeDatagramError and
// a descriptive error is returned. Datagrams for streams not present in
// c.streams are dropped.
func (c *connection) receiveDatagrams() error {
	for {
		b, err := c.ReceiveDatagram(context.Background())
		if err != nil {
			return err
		}
		quarterStreamID, n, err := quicvarint.Parse(b)
		if err != nil {
			c.CloseWithError(quic.ApplicationErrorCode(ErrCodeDatagramError), "")
			return fmt.Errorf("could not read quarter stream id: %w", err)
		}
		if quarterStreamID > maxQuarterStreamID {
			c.CloseWithError(quic.ApplicationErrorCode(ErrCodeDatagramError), "")
			// err is nil here; report the offending value rather than
			// wrapping a nil error (which would render as "%!w(<nil>)").
			return fmt.Errorf("invalid quarter stream id: %d", quarterStreamID)
		}
		streamID := protocol.StreamID(4 * quarterStreamID)
		c.streamMx.Lock()
		dg, ok := c.streams[streamID]
		c.streamMx.Unlock()
		if !ok {
			continue
		}
		dg.enqueue(b[n:])
	}
}
// ReceivedSettings returns a channel that handleControlStream closes once the
// peer's SETTINGS frame has been processed.
func (c *connection) ReceivedSettings() <-chan struct{} {
	return c.receivedSettings
}
// Settings returns the SETTINGS received from the peer. It is nil until the
// channel returned by ReceivedSettings has been closed.
func (c *connection) Settings() *Settings {
	return c.settings
}
// Context returns the context passed to newConnection. It shadows the Context
// method of the embedded quic.Connection.
func (c *connection) Context() context.Context {
	return c.ctx
}
The pages are generated with Golds v0.8.2. (GOOS=linux GOARCH=amd64)
Golds is a Go 101 project developed by Tapir Liu.
PRs and bug reports are welcome and can be submitted to the issue list.
Please follow @zigo_101 (reachable from the left QR code) to get the latest news of Golds.