package webtransport
import (
"context"
"crypto/tls"
"errors"
"fmt"
"log"
"net"
"net/http"
"net/url"
"slices"
"sync"
"time"
"unicode/utf8"
"github.com/quic-go/quic-go"
"github.com/quic-go/quic-go/http3"
"github.com/quic-go/quic-go/quicvarint"
"github.com/dunglas/httpsfv"
)
// Wire identifiers used by the WebTransport-over-HTTP/3 server.
const (
	// wtAvailableProtocolsHeader is the request header from which Upgrade
	// reads the application protocols offered by the client.
	wtAvailableProtocolsHeader = "WT-Available-Protocols"
	// wtProtocolHeader is the response header in which Upgrade reports the
	// application protocol it selected.
	wtProtocolHeader = "WT-Protocol"

	// webTransportFrameType is the varint that ServeQUICConn expects at the
	// start of a WebTransport bidirectional stream.
	webTransportFrameType = 0x41
	// webTransportUniStreamType is the varint that ServeQUICConn expects at
	// the start of a WebTransport unidirectional stream.
	webTransportUniStreamType = 0x54
)
// quicConnKeyType is the unexported context key type under which
// ConfigureHTTP3Server stores the *quic.Conn of each HTTP/3 connection.
type quicConnKeyType struct{}

// quicConnKey is the context key used to look up the QUIC connection in Upgrade.
var quicConnKey quicConnKeyType
// ConfigureHTTP3Server prepares an http3.Server for use with WebTransport.
// It enables HTTP datagrams, advertises WebTransport support in the SETTINGS
// (settingsEnableWebtransport = 1), and wraps ConnContext so that every
// request context carries the underlying *quic.Conn, which Upgrade requires.
// An existing ConnContext is still invoked before the connection is stored.
func ConfigureHTTP3Server(s *http3.Server) {
	s.EnableDatagrams = true

	if s.AdditionalSettings == nil {
		s.AdditionalSettings = make(map[uint64]uint64, 1)
	}
	s.AdditionalSettings[settingsEnableWebtransport] = 1

	prevConnContext := s.ConnContext
	s.ConnContext = func(ctx context.Context, conn *quic.Conn) context.Context {
		if prevConnContext != nil {
			ctx = prevConnContext(ctx, conn)
		}
		return context.WithValue(ctx, quicConnKey, conn)
	}
}
// Server is a WebTransport server built on top of an HTTP/3 server. It
// accepts QUIC connections (via Serve, ListenAndServe, ListenAndServeTLS or
// ServeQUICConn), hands WebTransport streams to a per-connection session
// manager, passes all other streams to H3, and turns extended CONNECT
// requests into sessions via Upgrade.
type Server struct {
	// H3 is the underlying HTTP/3 server. Its QUICConfig, TLSConfig and Addr
	// are used when listening, and it handles non-WebTransport streams.
	// Upgrade needs the QUIC connection in the request context, which
	// ConfigureHTTP3Server installs through H3.ConnContext.
	H3 *http3.Server

	// ApplicationProtocols lists the application protocols this server
	// supports. Upgrade picks the first protocol offered by the client in the
	// WT-Available-Protocols header that is contained in this list.
	ApplicationProtocols []string

	// ReorderingTimeout is passed to each connection's session manager and
	// bounds how long Upgrade waits for the client's SETTINGS.
	// Zero means 5 seconds.
	ReorderingTimeout time.Duration

	// CheckOrigin decides whether Upgrade accepts a request's origin. If it is
	// nil when the server is initialized, a same-origin check is installed:
	// requests without an Origin header pass, otherwise the origin's host must
	// equal the request's Host, ignoring ASCII case.
	CheckOrigin func(r *http.Request) bool

	ctx       context.Context // canceled by Close; stops the accept loops and closes served connections
	ctxCancel context.CancelFunc
	refCount  sync.WaitGroup // connection goroutines started by Serve; Close waits for them

	initOnce sync.Once // guards init; Close also consumes it
	initErr  error

	connsMx sync.Mutex
	conns   map[*quic.Conn]*sessionManager // one session manager per served connection; nil after Close
}
// initialize runs init exactly once and returns its (cached) result on every
// call. If Close has already consumed initOnce, init never runs.
func (s *Server) initialize() error {
	s.initOnce.Do(func() { s.initErr = s.init() })
	return s.initErr
}
// timeout returns the configured ReorderingTimeout, falling back to five
// seconds when it is left at zero.
func (s *Server) timeout() time.Duration {
	if t := s.ReorderingTimeout; t != 0 {
		return t
	}
	return 5 * time.Second
}
// init installs the default origin check when none is configured and creates
// the server context and the connection map. It currently never fails.
func (s *Server) init() error {
	if s.CheckOrigin == nil {
		s.CheckOrigin = checkSameOrigin
	}
	s.conns = make(map[*quic.Conn]*sessionManager)
	s.ctx, s.ctxCancel = context.WithCancel(context.Background())
	return nil
}
// Serve accepts QUIC connections on conn and serves each of them with
// ServeQUICConn in its own goroutine. The QUIC configuration is a clone of
// H3.QUICConfig (or a fresh config), with datagrams and stream resets with
// partial delivery always enabled; the TLS configuration is H3.TLSConfig.
//
// Serve blocks until accepting a connection fails (for example because Close
// canceled the server context) and returns that error. Serve does not close
// conn. If the server was closed before it was ever initialized, Serve
// returns http.ErrServerClosed without listening.
func (s *Server) Serve(conn net.PacketConn) error {
	if err := s.initialize(); err != nil {
		return err
	}
	// A Close before first use consumes initOnce without creating s.ctx;
	// accepting with that nil context would panic.
	if s.ctx == nil {
		return http.ErrServerClosed
	}

	// Work on a copy so the caller's QUIC config is not modified.
	var quicConf *quic.Config
	if s.H3.QUICConfig != nil {
		quicConf = s.H3.QUICConfig.Clone()
	} else {
		quicConf = &quic.Config{}
	}
	// ServeQUICConn rejects connections that lack these features.
	quicConf.EnableDatagrams = true
	quicConf.EnableStreamResetPartialDelivery = true

	ln, err := quic.ListenEarly(conn, s.H3.TLSConfig, quicConf)
	if err != nil {
		return err
	}
	defer ln.Close()

	for {
		qconn, err := ln.Accept(s.ctx)
		if err != nil {
			return err
		}
		s.refCount.Add(1)
		go func() {
			defer s.refCount.Done()
			if err := s.ServeQUICConn(qconn); err != nil {
				log.Printf("webtransport: error serving QUIC connection: %v", err)
			}
		}()
	}
}
// ServeQUICConn serves an already established QUIC connection. The local side
// of the connection must support QUIC datagrams and stream resets with
// partial delivery.
//
// Bidirectional streams that start with webTransportFrameType, and
// unidirectional streams that start with webTransportUniStreamType, are
// followed by a session ID varint and are handed to the connection's session
// manager. All other streams go to the HTTP/3 server (H3). The connection is
// closed when the server is closed.
//
// ServeQUICConn returns once both accept loops have stopped and all
// per-stream goroutines have finished. It returns http.ErrServerClosed if the
// server has already been closed.
func (s *Server) ServeQUICConn(conn *quic.Conn) error {
	connState := conn.ConnectionState()
	if !connState.SupportsDatagrams.Local {
		return errors.New("webtransport: QUIC DATAGRAM support required, enable it via QUICConfig.EnableDatagrams")
	}
	if !connState.SupportsStreamResetPartialDelivery.Local {
		return errors.New("webtransport: QUIC Stream Resets with Partial Delivery required, enable it via QUICConfig.EnableStreamResetPartialDelivery")
	}
	if err := s.initialize(); err != nil {
		return err
	}

	s.connsMx.Lock()
	if s.conns == nil {
		// Close sets s.conns to nil, and a Close before first use never creates
		// it. Writing to the nil map below would panic (and s.ctx may be nil
		// as well), so refuse to serve.
		s.connsMx.Unlock()
		return http.ErrServerClosed
	}
	sessMgr, ok := s.conns[conn]
	if !ok {
		sessMgr = newSessionManager(s.timeout())
		s.conns[conn] = sessMgr
	}
	s.connsMx.Unlock()

	// Forget and close the session manager once the QUIC connection is gone.
	context.AfterFunc(conn.Context(), func() {
		s.connsMx.Lock()
		delete(s.conns, conn)
		s.connsMx.Unlock()
		sessMgr.Close()
	})

	http3Conn, err := s.H3.NewRawServerConn(conn)
	if err != nil {
		return err
	}

	// Close the connection when the server shuts down. The goroutine exits
	// as soon as either the server or the connection is done.
	go func() {
		select {
		case <-s.ctx.Done():
			conn.CloseWithError(0, "")
		case <-conn.Context().Done():
		}
	}()

	var wg sync.WaitGroup
	wg.Add(2)

	// Bidirectional streams: WebTransport streams start with the WebTransport
	// frame type followed by the session ID; everything else is an HTTP/3
	// request stream.
	go func() {
		defer wg.Done()
		for {
			str, err := conn.AcceptStream(s.ctx)
			if err != nil {
				return
			}
			wg.Add(1)
			go func() {
				defer wg.Done()
				typ, err := quicvarint.Peek(str)
				if err != nil {
					return
				}
				if typ != webTransportFrameType {
					http3Conn.HandleRequestStream(str)
					return
				}
				// Read both varints through a single reader, as the
				// unidirectional path does.
				r := quicvarint.NewReader(str)
				if _, err := quicvarint.Read(r); err != nil { // frame type (already peeked)
					return
				}
				id, err := quicvarint.Read(r)
				if err != nil {
					str.CancelRead(quic.StreamErrorCode(http3.ErrCodeGeneralProtocolError))
					str.CancelWrite(quic.StreamErrorCode(http3.ErrCodeGeneralProtocolError))
					return
				}
				sessMgr.AddStream(str, sessionID(id))
			}()
		}
	}()

	// Unidirectional streams: WebTransport streams start with the WebTransport
	// stream type followed by the session ID; everything else is handled by
	// the HTTP/3 connection.
	go func() {
		defer wg.Done()
		for {
			str, err := conn.AcceptUniStream(s.ctx)
			if err != nil {
				return
			}
			wg.Add(1)
			go func() {
				defer wg.Done()
				typ, err := quicvarint.Peek(str)
				if err != nil {
					return
				}
				if typ != webTransportUniStreamType {
					http3Conn.HandleUnidirectionalStream(str)
					return
				}
				r := quicvarint.NewReader(str)
				if _, err := quicvarint.Read(r); err != nil { // stream type (already peeked)
					return
				}
				id, err := quicvarint.Read(r)
				if err != nil {
					// Receive-only stream: there is no send side to reset.
					str.CancelRead(quic.StreamErrorCode(http3.ErrCodeGeneralProtocolError))
					return
				}
				sessMgr.AddUniStream(str, sessionID(id))
			}()
		}
	}()

	wg.Wait()
	return nil
}
// ListenAndServe opens a UDP socket on H3.Addr (":https" if empty) and serves
// on it via Serve. The socket is created here and is closed when Serve
// returns, so the port is released once the server stops.
func (s *Server) ListenAndServe() error {
	addr := s.H3.Addr
	if addr == "" {
		addr = ":https"
	}
	udpAddr, err := net.ResolveUDPAddr("udp", addr)
	if err != nil {
		return err
	}
	conn, err := net.ListenUDP("udp", udpAddr)
	if err != nil {
		return err
	}
	// Serve never closes conn itself. NOTE(review): quic-go is presumed not to
	// close a PacketConn it did not create; if it does, this second Close only
	// returns an error, which is ignored here.
	defer conn.Close()
	return s.Serve(conn)
}
// ListenAndServeTLS loads the certificate/key pair from certFile and keyFile,
// installs it as the only certificate of H3.TLSConfig (creating an empty
// config if none is set, and overwriting Certificates of an existing one in
// place), and then calls ListenAndServe.
func (s *Server) ListenAndServeTLS(certFile, keyFile string) error {
	cert, err := tls.LoadX509KeyPair(certFile, keyFile)
	if err != nil {
		return err
	}
	tlsConf := s.H3.TLSConfig
	if tlsConf == nil {
		tlsConf = &tls.Config{}
		s.H3.TLSConfig = tlsConf
	}
	tlsConf.Certificates = []tls.Certificate{cert}
	return s.ListenAndServe()
}
// Close shuts the server down. It cancels the server context (stopping the
// accept loops and closing connections served by ServeQUICConn), closes every
// connection's session manager, closes H3, and waits for the connection
// goroutines started by Serve. It returns the error from H3.Close.
func (s *Server) Close() error {
	// Consuming initOnce waits for an in-flight init and prevents a later one,
	// so s.ctxCancel is stable from here on (nil if init never ran).
	s.initOnce.Do(func() {})
	if cancel := s.ctxCancel; cancel != nil {
		cancel()
	}

	s.connsMx.Lock()
	for _, mgr := range s.conns { // ranging over a nil map is a no-op
		mgr.Close()
	}
	s.conns = nil
	s.connsMx.Unlock()

	h3Err := s.H3.Close()
	s.refCount.Wait()
	return h3Err
}
// Upgrade turns an extended CONNECT request into a WebTransport session.
//
// It checks the request method, the request protocol (r.Proto) and the
// origin (via CheckOrigin). It then waits up to the reordering timeout for
// the client's HTTP/3 SETTINGS and requires that they enable datagrams. If
// the client offered application protocols in WT-Available-Protocols and one
// of them is in ApplicationProtocols, that protocol is returned in the
// WT-Protocol response header. On success a 200 response is written and
// flushed, and the new session is registered with the connection's session
// manager.
//
// w must be the ResponseWriter supplied by the http3 server; it has to
// implement http3.Settingser, http.Flusher and http3.HTTPStreamer. The
// request context must carry the QUIC connection, which ConfigureHTTP3Server
// arranges. Every error except a missing session manager is returned before
// any response is written.
func (s *Server) Upgrade(w http.ResponseWriter, r *http.Request) (*Session, error) {
	if err := s.initialize(); err != nil {
		return nil, err
	}
	if r.Method != http.MethodConnect {
		return nil, fmt.Errorf("expected CONNECT request, got %s", r.Method)
	}
	if r.Proto != protocolHeader {
		return nil, fmt.Errorf("unexpected protocol: %s", r.Proto)
	}
	if !s.CheckOrigin(r) {
		return nil, errors.New("webtransport: request origin not allowed")
	}
	conn, ok := r.Context().Value(quicConnKey).(*quic.Conn)
	if !ok {
		return nil, errors.New("webtransport: missing QUIC connection")
	}

	// Verify the writer's capabilities up front. A wrapped writer that hides
	// them would otherwise cause a panic, possibly after the 200 was sent.
	settingser, ok := w.(http3.Settingser)
	if !ok {
		return nil, errors.New("webtransport: response writer does not implement http3.Settingser")
	}
	flusher, ok := w.(http.Flusher)
	if !ok {
		return nil, errors.New("webtransport: response writer does not implement http.Flusher")
	}
	streamer, ok := w.(http3.HTTPStreamer)
	if !ok {
		return nil, errors.New("webtransport: response writer does not implement http3.HTTPStreamer")
	}

	selectedProtocol := s.selectProtocol(r.Header.Values(wtAvailableProtocolsHeader))

	timer := time.NewTimer(s.timeout())
	defer timer.Stop()
	select {
	case <-settingser.ReceivedSettings():
	case <-timer.C:
		return nil, errors.New("webtransport: didn't receive the client's SETTINGS on time")
	}
	if !settingser.Settings().EnableDatagrams {
		return nil, errors.New("webtransport: missing datagram support")
	}

	if selectedProtocol != "" {
		v, err := httpsfv.Marshal(httpsfv.NewItem(selectedProtocol))
		if err != nil {
			return nil, fmt.Errorf("failed to marshal selected protocol: %w", err)
		}
		w.Header().Add(wtProtocolHeader, v)
	}
	w.WriteHeader(http.StatusOK)
	flusher.Flush()

	str := streamer.HTTPStream()
	sessID := sessionID(str.StreamID())

	s.connsMx.Lock()
	defer s.connsMx.Unlock()
	sessMgr, ok := s.conns[conn]
	if !ok {
		return nil, errors.New("webtransport: connection session manager not found")
	}
	// The session must outlive the handler, so detach it from the request's
	// cancellation while keeping its values.
	sess := newSession(context.WithoutCancel(r.Context()), sessID, conn, str, selectedProtocol)
	sessMgr.AddSession(sessID, sess)
	return sess, nil
}
// selectProtocol parses the client's WT-Available-Protocols header values as
// a structured-field List of Strings and returns the first offered protocol
// (in the client's order) that is also in s.ApplicationProtocols. It returns
// "" when nothing matches, or when the header does not parse or contains any
// member that is not a plain String item.
func (s *Server) selectProtocol(theirs []string) string {
	list, err := httpsfv.UnmarshalList(theirs)
	if err != nil {
		return ""
	}

	// Validate the whole offer first: one malformed member rejects it all.
	offered := make([]string, len(list))
	for idx, member := range list {
		item, isItem := member.(httpsfv.Item)
		if !isItem {
			return ""
		}
		name, isString := item.Value.(string)
		if !isString {
			return ""
		}
		offered[idx] = name
	}

	supported := func(p string) bool { return slices.Contains(s.ApplicationProtocols, p) }
	if idx := slices.IndexFunc(offered, supported); idx >= 0 {
		return offered[idx]
	}
	return ""
}
// checkSameOrigin is the default CheckOrigin. Requests without an Origin
// header are accepted. Otherwise the Origin must parse as a URL whose host
// equals the request's Host, compared with equalASCIIFold.
func checkSameOrigin(r *http.Request) bool {
	origin := r.Header.Get("Origin")
	if origin == "" {
		return true
	}
	originURL, err := url.Parse(origin)
	return err == nil && equalASCIIFold(originURL.Host, r.Host)
}
// equalASCIIFold reports whether s and t are equal under ASCII-only case
// folding: the letters A-Z match their lowercase counterparts, and every
// other character (non-ASCII or invalid UTF-8) must be byte-for-byte
// identical.
func equalASCIIFold(s, t string) bool {
	for s != "" && t != "" {
		sr, ssize := utf8.DecodeRuneInString(s)
		tr, tsize := utf8.DecodeRuneInString(t)
		sEnc, tEnc := s[:ssize], t[:tsize]
		s, t = s[ssize:], t[tsize:]
		if sEnc == tEnc {
			continue
		}
		// The encodings differ, so only an ASCII case difference can still
		// match. Every invalid byte decodes to utf8.RuneError, so a rune match
		// outside ASCII here would wrongly equate distinct bytes.
		if 'A' <= sr && sr <= 'Z' {
			sr += 'a' - 'A'
		}
		if 'A' <= tr && tr <= 'Z' {
			tr += 'a' - 'A'
		}
		if sr != tr || sr >= utf8.RuneSelf {
			return false
		}
	}
	return s == t
}
These pages were generated with Golds v0.8.2 (GOOS=linux GOARCH=amd64).
Golds is a Go 101 project developed by Tapir Liu.
PRs and bug reports are welcome and can be submitted to the issue list.
Please follow @zigo_101 (reachable via the QR code on the left) to get the latest news about Golds.