package http3
import (
"bytes"
"fmt"
"log/slog"
"net/http"
"net/textproto"
"strconv"
"strings"
"time"
"github.com/quic-go/qpack"
"github.com/quic-go/quic-go/http3/qlog"
"golang.org/x/net/http/httpguts"
)
// HTTPStreamer is implemented by response writers that can expose the
// underlying HTTP/3 *Stream. responseWriter's implementation also marks the
// stream as hijacked and flushes pending output before returning it.
type HTTPStreamer interface {
HTTPStream () *Stream
}
// maxSmallResponseSize bounds buffering in Write: while the header block has
// not been written yet, body bytes are accumulated in smallResponseBuf as long
// as the buffered total stays strictly below this many bytes.
const maxSmallResponseSize = 4096
// responseWriter implements http.ResponseWriter for a single HTTP/3 request
// stream. The final header block is written lazily by doWrite, small bodies
// are held in smallResponseBuf until then, and trailers are sent by
// writeTrailers.
type responseWriter struct {
str *Stream // stream the response is written to
conn *rawConn // connection; only used for Settings / ReceivedSettings
header http .Header // handler-visible header map; may also hold "Trailer:"-prefixed keys
trailers map [string ]struct {} // set of declared trailer names (see declareTrailer); nil until the first declaration
buf []byte // scratch space reused by doWrite to build DATA frame headers
status int // status from the most recent WriteHeader call (may be 1xx)
smallResponseBuf []byte // body bytes held back until the header block is written
contentLen int64 // parsed Content-Length; a value of 0 is treated as "not set" by Write
numWritten int64 // body bytes accepted by Write so far (counted for HEAD responses too)
headerComplete bool // a final (>= 200) status has been chosen
headerWritten bool // doWrite has written the final HEADERS frame
isHead bool // response to a HEAD request: Write counts but discards body bytes
trailerWritten bool // writeTrailers reported that trailers were written
hijacked bool // HTTPStream has handed the stream to the caller
logger *slog .Logger // diagnostics logger; WriteHeader and Flush tolerate a nil value
}
// Compile-time checks that *responseWriter implements the interfaces it is
// used through.
var (
_ http .ResponseWriter = &responseWriter {}
_ http .Flusher = &responseWriter {}
_ Settingser = &responseWriter {}
_ HTTPStreamer = &responseWriter {}
// Optional methods that net/http's ResponseController probes for via type
// assertions (deadline control and a flush that reports errors).
_ interface {
SetReadDeadline(time .Time ) error
SetWriteDeadline(time .Time ) error
Flush()
FlushError() error
} = &responseWriter {}
)
// newResponseWriter returns a responseWriter that writes one response to str.
// isHead marks a response to a HEAD request, whose body bytes are counted by
// Write but never sent; logger is used for diagnostics.
func newResponseWriter(str *Stream, conn *rawConn, isHead bool, logger *slog.Logger) *responseWriter {
	w := &responseWriter{
		str:    str,
		conn:   conn,
		isHead: isHead,
		logger: logger,
	}
	w.header = make(http.Header)
	// Scratch space sized for a frame header; doWrite reuses it for every
	// DATA frame header it builds.
	w.buf = make([]byte, frameHeaderLen)
	return w
}
// Header returns the response header map. Keys named in the "Trailer" header,
// and keys carrying the http.TrailerPrefix prefix, are not sent in the header
// block; their values are sent as trailers instead.
func (w *responseWriter) Header() http.Header {
	h := w.header
	return h
}
// WriteHeader records the response status code.
//
// Informational (1xx) statuses are written to the stream immediately together
// with the current header map, and WriteHeader may be called again afterwards.
// The first status >= 200 completes the header and later calls are ignored.
// At that point a Date header is added if the handler set none, and a
// Content-Length header is parsed so that Write can enforce it; a malformed
// Content-Length is logged and removed. The final header block itself is
// written lazily by doWrite.
//
// Like net/http, it panics if status is outside [100, 999].
func (w *responseWriter) WriteHeader(status int) {
	if w.headerComplete {
		return
	}
	if status < 100 || status > 999 {
		panic(fmt.Sprintf("invalid WriteHeader code %v", status))
	}
	w.status = status

	if status < 200 {
		// WriteHeader cannot return an error, so surface the failure in the
		// log instead of discarding it.
		if err := w.writeHeader(status); err != nil && w.logger != nil {
			w.logger.Debug("could not write informational header", "status", status, "error", err)
		}
		return
	}

	w.headerComplete = true
	if _, ok := w.header["Date"]; !ok {
		w.header.Set("Date", time.Now().UTC().Format(http.TimeFormat))
	}
	if clen := w.header.Get("Content-Length"); clen != "" {
		cl, err := strconv.ParseUint(clen, 10, 63)
		if err != nil {
			logger := w.logger
			if logger == nil {
				logger = slog.Default()
			}
			logger.Error("Malformed Content-Length", "value", clen)
			w.header.Del("Content-Length")
			return
		}
		w.contentLen = int64(cl)
	}
}
// sniffContentType sets Content-Type to http.DetectContentType(p). It does
// nothing when p is empty, when the header map already has a Content-Type key
// (even with no values), or when a non-empty Content-Encoding is set.
func (w *responseWriter) sniffContentType(p []byte) {
	if len(p) == 0 {
		return
	}
	if _, explicit := w.header["Content-Type"]; explicit {
		return
	}
	if w.header.Get("Content-Encoding") != "" {
		return
	}
	w.header.Set("Content-Type", http.DetectContentType(p))
}
// Write writes p as (part of) the response body.
//
// If no final status has been chosen yet, Write first sniffs a Content-Type
// from p and sets status 200. It returns http.ErrBodyNotAllowed for statuses
// that forbid a body, and http.ErrContentLength once more bytes than the
// declared Content-Length have been written (a declared length of 0 is not
// enforced). For HEAD requests the bytes are counted but discarded. While the
// header block is unsent, writes that keep the buffered total below
// maxSmallResponseSize are buffered instead of being written.
func (w *responseWriter) Write(p []byte) (int, error) {
	var allowed bool
	if w.headerComplete {
		allowed = bodyAllowedForStatus(w.status)
	} else {
		w.sniffContentType(p)
		w.WriteHeader(http.StatusOK)
		allowed = true
	}
	if !allowed {
		return 0, http.ErrBodyNotAllowed
	}

	w.numWritten += int64(len(p))
	if w.contentLen != 0 && w.numWritten > w.contentLen {
		return 0, http.ErrContentLength
	}

	switch {
	case w.isHead:
		return len(p), nil
	case !w.headerWritten && len(w.smallResponseBuf)+len(p) < maxSmallResponseSize:
		w.smallResponseBuf = append(w.smallResponseBuf, p...)
		return len(p), nil
	default:
		return w.doWrite(p)
	}
}
// doWrite writes the final header block if it has not been written yet, then
// writes one DATA frame whose payload is any buffered small-response bytes
// followed by p. No DATA frame is written if that payload is empty.
//
// It returns the number of bytes of p written. Bytes previously buffered by
// Write were already reported to its caller and are not counted again.
func (w *responseWriter) doWrite(p []byte) (int, error) {
	if !w.headerWritten {
		// Sniff from the first body bytes: the buffered prefix if there is one,
		// otherwise p. Falling back to p matters when the handler called
		// WriteHeader explicitly and its first Write was too large to buffer;
		// sniffing only the (then empty) buffer would never set a Content-Type.
		first := w.smallResponseBuf
		if len(first) == 0 {
			first = p
		}
		w.sniffContentType(first)
		if err := w.writeHeader(w.status); err != nil {
			return 0, maybeReplaceError(err)
		}
		w.headerWritten = true
	}

	l := uint64(len(w.smallResponseBuf) + len(p))
	if l == 0 {
		return 0, nil
	}
	df := &dataFrame{Length: l}
	w.buf = w.buf[:0]
	w.buf = df.Append(w.buf)
	if w.str.qlogger != nil {
		w.str.qlogger.RecordEvent(qlog.FrameCreated{
			StreamID: w.str.StreamID(),
			Raw:      qlog.RawInfo{Length: len(w.buf) + int(l), PayloadLength: int(l)},
			Frame:    qlog.Frame{Frame: qlog.DataFrame{}},
		})
	}
	// The frame header and its payload pieces are written back to back with
	// writeUnframed, so the buffered bytes and p form a single DATA frame.
	if _, err := w.str.writeUnframed(w.buf); err != nil {
		return 0, maybeReplaceError(err)
	}
	if len(w.smallResponseBuf) > 0 {
		if _, err := w.str.writeUnframed(w.smallResponseBuf); err != nil {
			return 0, maybeReplaceError(err)
		}
		w.smallResponseBuf = nil
	}
	if len(p) == 0 {
		return 0, nil
	}
	n, err := w.str.writeUnframed(p)
	if err != nil {
		return n, maybeReplaceError(err)
	}
	return n, nil
}
// writeHeader encodes the :status pseudo-header and the handler's header map
// with QPACK and writes them to the stream as a single HEADERS frame.
//
// Names listed in the "Trailer" header are declared as trailers first. Those
// names, and keys carrying http.TrailerPrefix, are left out of the header
// block because their values are sent later as trailers. Field names are
// lowercased, as HTTP/3 requires (RFC 9114, section 4.2).
func (w *responseWriter) writeHeader(status int) error {
	var headerFields []qlog.HeaderField // only collected when qlog is enabled
	var headers bytes.Buffer
	enc := qpack.NewEncoder(&headers)
	statusValue := strconv.Itoa(status)
	if err := enc.WriteField(qpack.HeaderField{Name: ":status", Value: statusValue}); err != nil {
		return err
	}
	if w.str.qlogger != nil {
		headerFields = append(headerFields, qlog.HeaderField{Name: ":status", Value: statusValue})
	}

	for _, val := range w.header["Trailer"] {
		for _, trailer := range strings.Split(val, ",") {
			// Canonicalize so the name matches keys set via Header().Set/Add.
			trailer = textproto.CanonicalMIMEHeaderKey(strings.TrimSpace(trailer))
			if trailer == "" {
				// Tolerate empty list elements ("a,,b", trailing comma) instead
				// of declaring a nameless trailer.
				continue
			}
			w.declareTrailer(trailer)
		}
	}

	for k, v := range w.header {
		if _, excluded := w.trailers[k]; excluded {
			continue
		}
		if strings.HasPrefix(k, http.TrailerPrefix) {
			continue
		}
		name := strings.ToLower(k) // identical for every value of k
		for _, value := range v {
			if err := enc.WriteField(qpack.HeaderField{Name: name, Value: value}); err != nil {
				return err
			}
			if w.str.qlogger != nil {
				headerFields = append(headerFields, qlog.HeaderField{Name: name, Value: value})
			}
		}
	}

	buf := make([]byte, 0, frameHeaderLen+headers.Len())
	buf = (&headersFrame{Length: uint64(headers.Len())}).Append(buf)
	buf = append(buf, headers.Bytes()...)
	if w.str.qlogger != nil {
		qlogCreatedHeadersFrame(w.str.qlogger, w.str.StreamID(), len(buf), headers.Len(), headerFields)
	}
	_, err := w.str.writeUnframed(buf)
	return err
}
// FlushError writes any pending header block and buffered body bytes to the
// stream. If the handler has not chosen a final status yet, 200 is used. It
// returns the error from the underlying write, if any.
func (w *responseWriter) FlushError() error {
	// Flushing commits the response, so a final status must exist first.
	if !w.headerComplete {
		w.WriteHeader(http.StatusOK)
	}
	if _, err := w.doWrite(nil); err != nil {
		return err
	}
	return nil
}
// flushTrailers writes the response trailers unless they were already
// written. It has no error result, so a failure is logged at debug level when
// a logger is configured.
func (w *responseWriter) flushTrailers() {
	if w.trailerWritten {
		return
	}
	// w.logger may be nil (Flush and WriteHeader guard against that too);
	// calling Debug on a nil *slog.Logger would panic.
	if err := w.writeTrailers(); err != nil && w.logger != nil {
		w.logger.Debug("could not write trailers", "error", err)
	}
}
// Flush implements http.Flusher. It does the same work as FlushError but,
// lacking an error result, only logs a failure at debug level when a logger
// is configured.
func (w *responseWriter) Flush() {
	err := w.FlushError()
	if err == nil || w.logger == nil {
		return
	}
	w.logger.Debug("could not flush to stream", "error", err)
}
// declareTrailer records k as a trailer name. writeHeader then leaves it out
// of the header block, and writeTrailers sends its values. Names rejected by
// httpguts.ValidTrailerHeader are ignored; when a logger is configured, this
// is logged at debug level.
func (w *responseWriter) declareTrailer(k string) {
	if !httpguts.ValidTrailerHeader(k) {
		// w.logger may be nil; calling Debug on a nil *slog.Logger panics.
		if w.logger != nil {
			w.logger.Debug("ignoring invalid trailer", slog.String("header", k))
		}
		return
	}
	if w.trailers == nil {
		w.trailers = make(map[string]struct{})
	}
	w.trailers[k] = struct{}{}
}
// writeTrailers sends the declared trailers as a trailing header section,
// together with any header keys carrying http.TrailerPrefix. Handlers may add
// such prefixed keys even after the header block has been written. It does
// nothing if no trailers are declared, and sets trailerWritten once the
// package-level writeTrailers helper reports that it wrote something.
func (w *responseWriter) writeTrailers() error {
	for k := range w.header {
		name, ok := strings.CutPrefix(k, http.TrailerPrefix)
		if !ok {
			continue
		}
		// Validate the name that actually goes on the wire. The prefixed key
		// itself always passes httpguts.ValidTrailerHeader (it contains ':'),
		// which would let forbidden trailers such as "Trailer:Content-Length"
		// through.
		if !httpguts.ValidTrailerHeader(name) {
			if w.logger != nil {
				w.logger.Debug("ignoring invalid trailer", slog.String("header", name))
			}
			continue
		}
		w.declareTrailer(k)
	}
	if len(w.trailers) == 0 {
		return nil
	}

	trailers := make(http.Header, len(w.trailers))
	// Copy declared trailers first and prefixed ones second. When both "Foo"
	// and "Trailer:Foo" are present, the prefixed value then wins
	// deterministically instead of depending on map iteration order.
	for trailer := range w.trailers {
		if strings.HasPrefix(trailer, http.TrailerPrefix) {
			continue
		}
		if vals, ok := w.header[trailer]; ok {
			trailers[trailer] = vals
		}
	}
	for trailer := range w.trailers {
		name, ok := strings.CutPrefix(trailer, http.TrailerPrefix)
		if !ok {
			continue
		}
		if vals, ok := w.header[trailer]; ok {
			trailers[name] = vals
		}
	}

	written, err := writeTrailers(w.str.datagramStream, trailers, w.str.StreamID(), w.str.qlogger)
	if written {
		w.trailerWritten = true
	}
	return err
}
// HTTPStream marks the response as hijacked and flushes the pending header
// block and buffered body bytes. It then returns the underlying stream so the
// caller can use it directly.
func (w *responseWriter) HTTPStream() *Stream {
	str := w.str
	w.hijacked = true
	w.Flush()
	return str
}

// wasStreamHijacked reports whether HTTPStream has been called.
func (w *responseWriter) wasStreamHijacked() bool {
	hijacked := w.hijacked
	return hijacked
}
// ReceivedSettings delegates to the connection's ReceivedSettings.
// Presumably the returned channel is closed once the peer's SETTINGS have
// arrived; confirm against rawConn.
func (w *responseWriter) ReceivedSettings() <-chan struct{} {
	c := w.conn
	return c.ReceivedSettings()
}

// Settings delegates to the connection's Settings.
func (w *responseWriter) Settings() *Settings {
	c := w.conn
	return c.Settings()
}

// SetReadDeadline sets the read deadline on the underlying stream.
func (w *responseWriter) SetReadDeadline(deadline time.Time) error {
	s := w.str
	return s.SetReadDeadline(deadline)
}

// SetWriteDeadline sets the write deadline on the underlying stream.
func (w *responseWriter) SetWriteDeadline(deadline time.Time) error {
	s := w.str
	return s.SetWriteDeadline(deadline)
}
// bodyAllowedForStatus reports whether a response with the given status code
// may carry a body. Informational (1xx), 204 No Content and 304 Not Modified
// responses may not (RFC 9110, sections 15.2, 15.3.5 and 15.4.5).
func bodyAllowedForStatus(status int) bool {
	if status >= 100 && status <= 199 {
		return false
	}
	return status != http.StatusNoContent && status != http.StatusNotModified
}
These pages were generated with Golds v0.8.2 (GOOS=linux GOARCH=amd64).
Golds is a Go 101 project developed by Tapir Liu.
PRs and bug reports are welcome and can be submitted to the issue list.
Please follow @zigo_101 (reachable from the QR code on the left) to get the latest news about Golds.