package http3
import (
	"bytes"
	"errors"
	"fmt"
	"io"
	"net"
	"net/http"
	"net/http/httptrace"
	"sort"
	"strconv"
	"strings"
	"sync"

	"github.com/quic-go/qpack"
	"golang.org/x/net/http/httpguts"
	"golang.org/x/net/http2/hpack"
	"golang.org/x/net/idna"

	"github.com/quic-go/quic-go"
	"github.com/quic-go/quic-go/http3/qlog"
	"github.com/quic-go/quic-go/qlogwriter"
)
// bodyCopyBufferSize is 8 KiB (8 << 10 bytes). It is not referenced in this
// part of the file; presumably it sizes the scratch buffer used when copying
// request bodies — TODO confirm against its users.
const bodyCopyBufferSize = 8 << 10
// requestWriter encodes HTTP request header blocks with QPACK.
//
// encoder writes its output into headerBuf (see newRequestWriter). Both are
// shared scratch state: writeHeaders holds mutex while using them and resets
// them before releasing it.
type requestWriter struct {
	// mutex serializes use of encoder and headerBuf.
	mutex sync.Mutex

	encoder   *qpack.Encoder // QPACK encoder whose output lands in headerBuf
	headerBuf *bytes.Buffer  // encoded field section of the request in progress
}
// newRequestWriter returns a requestWriter whose QPACK encoder emits the
// encoded field section into the writer's own headerBuf.
func newRequestWriter() *requestWriter {
	rw := &requestWriter{headerBuf: new(bytes.Buffer)}
	rw.encoder = qpack.NewEncoder(rw.headerBuf)
	return rw
}
// WriteRequestHeader encodes req's header block as a HEADERS frame and writes
// it to wr with a single Write call.
//
// The frame is assembled in a local buffer first, so if encoding fails nothing
// is written to wr. After a successful write, the request's httptrace client
// trace (possibly nil) is handed to traceWroteHeaders.
func (w *requestWriter) WriteRequestHeader(wr io.Writer, req *http.Request, gzip bool, streamID quic.StreamID, qlogger qlogwriter.Recorder) error {
	var frame bytes.Buffer
	if err := w.writeHeaders(&frame, req, gzip, streamID, qlogger); err != nil {
		return err
	}
	if _, err := wr.Write(frame.Bytes()); err != nil {
		return err
	}
	traceWroteHeaders(httptrace.ContextClientTrace(req.Context()))
	return nil
}
// writeHeaders QPACK-encodes req's header block and writes a complete HEADERS
// frame to wr: first the frame header produced by headersFrame.Append, then
// the encoded field section accumulated in w.headerBuf.
//
// The shared encoder and headerBuf are used under w.mutex. Deferred calls close
// the encoder and reset headerBuf before the lock is released, on success and
// on error alike.
//
// Trailer names declared in req.Trailer are filtered with
// httpguts.ValidTrailerHeader and sorted before being joined into the
// "trailer" header value. Sorting makes the value independent of Go's random
// map iteration order, so identical requests encode identically.
//
// If qlogger is non-nil, the encoded header fields are collected and passed to
// qlogCreatedHeadersFrame together with the frame sizes.
func (w *requestWriter) writeHeaders(wr io.Writer, req *http.Request, gzip bool, streamID quic.StreamID, qlogger qlogwriter.Recorder) error {
	w.mutex.Lock()
	defer w.mutex.Unlock()
	defer w.encoder.Close()
	defer w.headerBuf.Reset()

	var trailers string
	if len(req.Trailer) > 0 {
		keys := make([]string, 0, len(req.Trailer))
		for k := range req.Trailer {
			if httpguts.ValidTrailerHeader(k) {
				keys = append(keys, k)
			}
		}
		// Map iteration order is randomized; sort for a deterministic header value.
		sort.Strings(keys)
		trailers = strings.Join(keys, ", ")
	}

	headerFields, err := w.encodeHeaders(req, gzip, trailers, actualContentLength(req), qlogger != nil)
	if err != nil {
		return err
	}

	// Frame header sized to the encoded field section now sitting in headerBuf.
	b := make([]byte, 0, 128)
	b = (&headersFrame{Length: uint64(w.headerBuf.Len())}).Append(b)
	if qlogger != nil {
		qlogCreatedHeadersFrame(qlogger, streamID, len(b)+w.headerBuf.Len(), w.headerBuf.Len(), headerFields)
	}
	if _, err := wr.Write(b); err != nil {
		return err
	}
	_, err = wr.Write(w.headerBuf.Bytes())
	return err
}
// isExtendedConnectRequest reports whether req is an extended CONNECT request:
// its method is CONNECT and its Proto field names a protocol, i.e. is neither
// empty nor "HTTP/1.1". encodeHeaders sends that Proto value as the :protocol
// pseudo-header.
func isExtendedConnectRequest(req *http.Request) bool {
	if req.Method != http.MethodConnect {
		return false
	}
	switch req.Proto {
	case "", "HTTP/1.1":
		return false
	default:
		return true
	}
}
// encodeHeaders validates req and writes its header fields, one at a time, to
// w.encoder (which newRequestWriter connects to w.headerBuf).
//
// All validation happens before anything is encoded, and any failure returns
// an error:
//   - the host (req.Host, falling back to req.URL.Host) is converted with
//     httpguts.PunycodeHostPort and must pass httpguts.ValidHostHeader;
//   - for every request except a plain (non-extended) CONNECT, the :path is
//     req.URL.RequestURI() and must satisfy validPseudoPath, after optionally
//     stripping a leading "scheme://host" prefix;
//   - every name and value in req.Header must be valid per httpguts.
//
// Fields are emitted in this order: :authority, :method, then :path and
// :scheme (skipped for plain CONNECT), :protocol (extended CONNECT only),
// "trailer" (when trailers is non-empty), the req.Header fields in map
// iteration order, content-length (when shouldSendReqContentLength allows it),
// "accept-encoding: gzip" (when addGzipHeader), and finally defaultUserAgent
// if the request had no User-Agent key. Names are lowercased before encoding.
//
// From req.Header, Host and Content-Length are dropped (they are emitted as
// :authority and the computed content-length instead), as are Connection,
// Proxy-Connection, Transfer-Encoding, Upgrade, and Keep-Alive. Only the first
// User-Agent value is sent; an empty or missing first value suppresses the
// User-Agent header entirely, including the default.
//
// Each encoded field is reported to the request's httptrace hook via
// traceWroteHeaderField when traceHasWroteHeaderField allows it. When doQlog is
// true the encoded fields are also returned; otherwise the result slice is nil.
func (w *requestWriter ) encodeHeaders (req *http .Request , addGzipHeader bool , trailers string , contentLength int64 , doQlog bool ) ([]qlog .HeaderField , error ) {
host := req .Host
if host == "" {
host = req .URL .Host
}
host , err := httpguts .PunycodeHostPort (host )
if err != nil {
return nil , err
}
if !httpguts .ValidHostHeader (host ) {
return nil , errors .New ("http3: invalid Host header" )
}
isExtendedConnect := isExtendedConnectRequest (req )
var path string
// A plain CONNECT carries no :path; every other request, including an
// extended CONNECT, needs one.
if req .Method != http .MethodConnect || isExtendedConnect {
path = req .URL .RequestURI ()
if !validPseudoPath (path ) {
// RequestURI is not usable as-is; retry after stripping a leading
// "scheme://host" prefix, and only then give up.
orig := path
path = strings .TrimPrefix (path , req .URL .Scheme +"://" +host )
if !validPseudoPath (path ) {
if req .URL .Opaque != "" {
return nil , fmt .Errorf ("invalid request :path %q from URL.Opaque = %q" , orig , req .URL .Opaque )
} else {
return nil , fmt .Errorf ("invalid request :path %q" , orig )
}
}
}
}
// Reject invalid header names/values up front so nothing is encoded for a
// request that will fail.
for k , vv := range req .Header {
if !httpguts .ValidHeaderFieldName (k ) {
return nil , fmt .Errorf ("invalid HTTP header name %q" , k )
}
for _ , v := range vv {
if !httpguts .ValidHeaderFieldValue (v ) {
return nil , fmt .Errorf ("invalid HTTP header value %q for header %q" , v , k )
}
}
}
// enumerateHeaders calls f once per field to be sent, in wire order. It is
// run twice below: once to size the header list, once to encode it.
enumerateHeaders := func (f func (name , value string )) {
f (":authority" , host )
f (":method" , req .Method )
if req .Method != http .MethodConnect || isExtendedConnect {
f (":path" , path )
f (":scheme" , req .URL .Scheme )
}
if isExtendedConnect {
f (":protocol" , req .Proto )
}
if trailers != "" {
f ("trailer" , trailers )
}
var didUA bool
for k , vv := range req .Header {
// Host and Content-Length are emitted separately (as :authority and
// the computed content-length), so skip the user-supplied copies.
if strings .EqualFold (k , "host" ) || strings .EqualFold (k , "content-length" ) {
continue
} else if strings .EqualFold (k , "connection" ) || strings .EqualFold (k , "proxy-connection" ) ||
strings .EqualFold (k , "transfer-encoding" ) || strings .EqualFold (k , "upgrade" ) ||
strings .EqualFold (k , "keep-alive" ) {
// Connection-specific headers are not forwarded.
// NOTE(review): "te" is not filtered here; RFC 9114 §4.2 only allows
// TE with the value "trailers" — confirm whether it should be checked.
continue
} else if strings .EqualFold (k , "user-agent" ) {
// Mark UA as handled even if it ends up empty, so an explicitly
// empty User-Agent suppresses defaultUserAgent below.
didUA = true
if len (vv ) < 1 {
continue
}
vv = vv [:1 ]
if vv [0 ] == "" {
continue
}
}
for _ , v := range vv {
f (k , v )
}
}
if shouldSendReqContentLength (req .Method , contentLength ) {
f ("content-length" , strconv .FormatInt (contentLength , 10 ))
}
if addGzipHeader {
f ("accept-encoding" , "gzip" )
}
if !didUA {
f ("user-agent" , defaultUserAgent )
}
}
// NOTE(review): hlSize accumulates the HPACK-style size of the header list
// but is never read afterwards in this function; it looks like a leftover of
// a max-header-list-size check — confirm whether a limit should be enforced.
hlSize := uint64 (0 )
enumerateHeaders (func (name , value string ) {
hf := hpack .HeaderField {Name : name , Value : value }
hlSize += uint64 (hf .Size ())
})
trace := httptrace .ContextClientTrace (req .Context ())
traceHeaders := traceHasWroteHeaderField (trace )
var headerFields []qlog .HeaderField
if doQlog {
// Capacity is only a hint; pseudo-headers and defaults add more fields.
headerFields = make ([]qlog .HeaderField , 0 , len (req .Header ))
}
// Second pass: encode each field (lowercased, as HTTP/3 field names must be),
// report it to httptrace, and record it for qlog when requested.
enumerateHeaders (func (name , value string ) {
name = strings .ToLower (name )
w .encoder .WriteField (qpack .HeaderField {Name : name , Value : value })
if traceHeaders {
traceWroteHeaderField (trace , name , value )
}
if doQlog {
headerFields = append (headerFields , qlog .HeaderField {Name : name , Value : value })
}
})
return headerFields , nil
}
// authorityAddr turns an authority ("host" or "host:port") into a
// "host:port" address.
//
// If authority cannot be split into host and port, the whole string is taken
// as the host and port "443" is used. The host is replaced by its
// idna.ToASCII form when that conversion succeeds and left unchanged
// otherwise. A host that is already bracketed (e.g. "[::1]") is joined by
// plain concatenation so that net.JoinHostPort does not bracket it twice.
func authorityAddr(authority string) (addr string) {
	host, port, err := net.SplitHostPort(authority)
	if err != nil {
		host, port = authority, "443"
	}
	if ascii, convErr := idna.ToASCII(host); convErr == nil {
		host = ascii
	}
	alreadyBracketed := strings.HasPrefix(host, "[") && strings.HasSuffix(host, "]")
	if !alreadyBracketed {
		return net.JoinHostPort(host, port)
	}
	return host + ":" + port
}
// validPseudoPath reports whether v can be sent as the :path pseudo-header:
// either the asterisk form "*" or a non-empty path beginning with '/'.
func validPseudoPath(v string) bool {
	return v == "*" || strings.HasPrefix(v, "/")
}
// actualContentLength returns a sanitized version of req.ContentLength in
// which 0 really means "empty body" and -1 means "length unknown".
//
// A nil Body and http.NoBody are both treated as an empty body (0). net/http
// documents NoBody as a body with zero bytes, and http.NewRequest leaves
// ContentLength at 0 for it, so without this check such a request would be
// misreported as unknown length. Otherwise a non-zero req.ContentLength is
// returned unchanged, and a zero ContentLength with a real body means
// "unknown" (-1).
func actualContentLength(req *http.Request) int64 {
	if req.Body == nil || req.Body == http.NoBody {
		return 0
	}
	if req.ContentLength != 0 {
		return req.ContentLength
	}
	return -1
}
// shouldSendReqContentLength reports whether a content-length header should be
// sent for a request with the given method and sanitized body length (as
// produced by actualContentLength: -1 unknown, 0 empty).
//
// A positive length is always sent and an unknown length never is. An explicit
// zero is sent only for "POST", "PUT", and "PATCH" (compared case-sensitively).
func shouldSendReqContentLength(method string, contentLength int64) bool {
	switch {
	case contentLength > 0:
		return true
	case contentLength < 0:
		return false
	}
	return method == "POST" || method == "PUT" || method == "PATCH"
}
// WriteRequestTrailer writes req.Trailer to wr by delegating to writeTrailers.
// It ignores writeTrailers' first result and returns only its error.
func (w *requestWriter) WriteRequestTrailer(wr io.Writer, req *http.Request, streamID quic.StreamID, qlogger qlogwriter.Recorder) error {
	if _, err := writeTrailers(wr, req.Trailer, streamID, qlogger); err != nil {
		return err
	}
	return nil
}
The pages are generated with Golds v0.8.2 (GOOS=linux GOARCH=amd64).
Golds is a Go 101 project developed by Tapir Liu.
PRs and bug reports are welcome and can be submitted to the issue list.
Please follow @zigo_101 (reachable from the left QR code) to get the latest news about Golds.