package webtransport
import (
"context"
"encoding/binary"
"errors"
"io"
"math/rand"
"net"
"sync"
"github.com/quic-go/quic-go"
"github.com/quic-go/quic-go/http3"
"github.com/quic-go/quic-go/quicvarint"
)
// sessionID identifies a WebTransport session; it is encoded as a QUIC
// varint in the headers of streams opened on behalf of the session.
type sessionID uint64

// closeWebtransportSessionCapsuleType is the capsule type of the
// CLOSE_WEBTRANSPORT_SESSION capsule. Its payload is a 32-bit big-endian
// application error code followed by a UTF-8 error message.
const closeWebtransportSessionCapsuleType = http3.CapsuleType(0x2843)
// acceptQueue is an unbounded, mutex-protected FIFO of items (streams)
// waiting to be accepted. Producers call Add, which never blocks; consumers
// wait on Chan for a wake-up and then drain items with Next.
type acceptQueue[T any] struct {
	mx sync.Mutex
	// c has a buffer of one: a pending value means "Add was called since the
	// last receive". Multiple Adds coalesce into a single wake-up.
	c     chan struct{}
	queue []T
}

// newAcceptQueue returns an empty queue ready for use.
func newAcceptQueue[T any]() *acceptQueue[T] {
	return &acceptQueue[T]{c: make(chan struct{}, 1)}
}

// Add appends str to the queue and signals Chan without blocking.
func (q *acceptQueue[T]) Add(str T) {
	q.mx.Lock()
	q.queue = append(q.queue, str)
	q.mx.Unlock()

	select {
	case q.c <- struct{}{}:
	default: // a wake-up is already pending
	}
}

// Next removes and returns the oldest queued item, or the zero value of T if
// the queue is empty.
func (q *acceptQueue[T]) Next() T {
	q.mx.Lock()
	defer q.mx.Unlock()

	var zero T
	if len(q.queue) == 0 {
		return zero
	}
	str := q.queue[0]
	// q.queue[1:] keeps sharing the backing array; clear the vacated slot so
	// the dequeued item is not kept reachable until the array is reallocated.
	q.queue[0] = zero
	q.queue = q.queue[1:]
	return str
}

// Chan returns the wake-up channel. A receive only means items may be
// available; consumers must still call Next and handle the zero value.
func (q *acceptQueue[T]) Chan() <-chan struct{} { return q.c }
// Session is a WebTransport session on top of an HTTP/3 connection. Its
// capsules (and datagrams) travel on requestStr; its streams are opened on,
// and accepted from, the underlying connection qconn.
type Session struct {
	sessionID  sessionID
	qconn      http3.Connection
	requestStr http3.Stream

	// Pre-encoded headers handed to locally opened bidirectional and
	// unidirectional streams, respectively.
	streamHdr    []byte
	uniStreamHdr []byte

	// ctx is canceled once the session has been torn down.
	ctx context.Context

	// closeMx serializes access to closeErr and streamCtxs, and orders stream
	// registration against session teardown.
	closeMx sync.Mutex
	// closeErr is nil until the session is closed and is set at most once.
	closeErr error
	// streamCtxs holds cancel funcs of pending Open*StreamSync calls so that
	// closing the session can abort them.
	streamCtxs map[int]context.CancelFunc

	// Incoming streams waiting for AcceptStream / AcceptUniStream.
	bidiAcceptQueue acceptQueue[Stream]
	uniAcceptQueue  acceptQueue[ReceiveStream]

	// streams tracks the session's open streams so they can be closed
	// together with the session.
	streams streamsMap
}
// newSession creates the Session with the given ID on qconn, bound to the
// request stream requestStr, and starts a goroutine running handleConn. The
// session context is canceled when that goroutine returns.
func newSession(sessionID sessionID, qconn http3.Connection, requestStr http3.Stream) *Session {
	// Carry the QUIC connection tracing ID over to the session context when the
	// connection provides one. The checked assertion avoids a panic if the
	// value is missing (or has an unexpected type).
	baseCtx := context.Background()
	if tracingID, ok := qconn.Context().Value(quic.ConnectionTracingKey).(quic.ConnectionTracingID); ok {
		baseCtx = context.WithValue(baseCtx, quic.ConnectionTracingKey, tracingID)
	}
	ctx, ctxCancel := context.WithCancel(baseCtx)

	s := &Session{
		sessionID:       sessionID,
		qconn:           qconn,
		requestStr:      requestStr,
		ctx:             ctx,
		streamCtxs:      make(map[int]context.CancelFunc),
		bidiAcceptQueue: *newAcceptQueue[Stream](),
		uniAcceptQueue:  *newAcceptQueue[ReceiveStream](),
		streams:         *newStreamsMap(),
	}

	// Pre-encode the headers for locally opened streams: a stream-type varint
	// (2 bytes reserved) followed by the session ID varint.
	hdrCap := 2 + quicvarint.Len(uint64(s.sessionID))
	s.uniStreamHdr = make([]byte, 0, hdrCap)
	s.uniStreamHdr = quicvarint.Append(s.uniStreamHdr, webTransportUniStreamType)
	s.uniStreamHdr = quicvarint.Append(s.uniStreamHdr, uint64(s.sessionID))
	s.streamHdr = make([]byte, 0, hdrCap)
	s.streamHdr = quicvarint.Append(s.streamHdr, webTransportFrameType)
	s.streamHdr = quicvarint.Append(s.streamHdr, uint64(s.sessionID))

	go func() {
		defer ctxCancel()
		s.handleConn()
	}()
	return s
}
// handleConn runs in the goroutine started by newSession. It blocks until
// capsule parsing stops, records why the session ended (unless it was already
// closed locally), cancels pending Open*StreamSync calls, and closes all of
// the session's streams.
func (s *Session) handleConn() {
	var sessErr *SessionError
	if err := s.parseNextCapsule(); !errors.As(err, &sessErr) {
		// Anything other than a close capsule (e.g. an I/O error on the request
		// stream) is reported as a remote close without an application code.
		sessErr = &SessionError{Remote: true}
	}

	s.closeMx.Lock()
	defer s.closeMx.Unlock()

	if s.closeErr == nil {
		s.closeErr = sessErr
	}
	for _, cancelOpen := range s.streamCtxs {
		cancelOpen()
	}
	s.streams.CloseSession()
}
// parseNextCapsule reads capsules from the request stream until the session
// ends; it never returns nil. It returns a *SessionError (Remote set) when the
// peer sends a CLOSE_WEBTRANSPORT_SESSION capsule, whose payload is a 4-byte
// big-endian error code followed by the message. Otherwise it returns the
// error that stopped reading. Capsules of any other type are skipped.
func (s *Session) parseNextCapsule() error {
	// Loop-invariant: wrap the request stream once rather than per capsule.
	br := quicvarint.NewReader(s.requestStr)
	for {
		typ, r, err := http3.ParseCapsule(br)
		if err != nil {
			return err
		}
		switch typ {
		case closeWebtransportSessionCapsuleType:
			b := make([]byte, 4)
			if _, err := io.ReadFull(r, b); err != nil {
				return err
			}
			appErrCode := binary.BigEndian.Uint32(b)
			appErrMsg, err := io.ReadAll(r)
			if err != nil {
				return err
			}
			return &SessionError{
				Remote:    true,
				ErrorCode: SessionErrorCode(appErrCode),
				Message:   string(appErrMsg),
			}
		default:
			// Drain the unknown capsule without buffering its payload in memory.
			if _, err := io.Copy(io.Discard, r); err != nil {
				return err
			}
		}
	}
}
// addStream wraps the QUIC stream qstr as a session Stream and registers it
// with the session's stream map. If addStreamHeader is set, the wrapper is
// given the session's bidirectional stream header (used for locally opened
// streams). The stream unregisters itself through the callback passed to
// newStream.
func (s *Session) addStream(qstr quic.Stream, addStreamHeader bool) Stream {
	id := qstr.StreamID()
	hdr := []byte(nil)
	if addStreamHeader {
		hdr = s.streamHdr
	}
	unregister := func() { s.streams.RemoveStream(id) }
	str := newStream(qstr, hdr, unregister)
	s.streams.AddStream(id, str.closeWithSession)
	return str
}
// addReceiveStream wraps the incoming unidirectional QUIC stream qstr as a
// ReceiveStream and registers it with the session's stream map, so closing
// the session also closes it.
func (s *Session) addReceiveStream(qstr quic.ReceiveStream) ReceiveStream {
	id := qstr.StreamID()
	unregister := func() { s.streams.RemoveStream(id) }
	str := newReceiveStream(qstr, unregister)
	s.streams.AddStream(id, str.closeWithSession)
	return str
}
// addSendStream wraps the locally opened unidirectional QUIC stream qstr as a
// SendStream carrying the session's unidirectional stream header, and
// registers it with the session's stream map.
func (s *Session) addSendStream(qstr quic.SendStream) SendStream {
	id := qstr.StreamID()
	unregister := func() { s.streams.RemoveStream(id) }
	str := newSendStream(qstr, s.uniStreamHdr, unregister)
	s.streams.AddStream(id, str.closeWithSession)
	return str
}
// addIncomingStream hands a peer-initiated bidirectional stream to the
// session. If the session is already closed, both directions of the stream
// are canceled with sessionCloseErrorCode. Otherwise the stream is registered
// and queued for AcceptStream. Registration happens under closeMx so it cannot
// interleave with session teardown.
func (s *Session) addIncomingStream(qstr quic.Stream) {
	s.closeMx.Lock()
	closed := s.closeErr != nil
	var str Stream
	if !closed {
		str = s.addStream(qstr, false)
	}
	s.closeMx.Unlock()

	if closed {
		qstr.CancelRead(sessionCloseErrorCode)
		qstr.CancelWrite(sessionCloseErrorCode)
		return
	}
	s.bidiAcceptQueue.Add(str)
}
// addIncomingUniStream hands a peer-initiated unidirectional stream to the
// session. If the session is already closed, the stream is canceled with
// sessionCloseErrorCode. Otherwise it is registered (under closeMx) and queued
// for AcceptUniStream.
func (s *Session) addIncomingUniStream(qstr quic.ReceiveStream) {
	s.closeMx.Lock()
	closed := s.closeErr != nil
	var str ReceiveStream
	if !closed {
		str = s.addReceiveStream(qstr)
	}
	s.closeMx.Unlock()

	if closed {
		qstr.CancelRead(sessionCloseErrorCode)
		return
	}
	s.uniAcceptQueue.Add(str)
}
// Context returns the session's context, which is canceled once the session
// has been torn down.
func (s *Session) Context() context.Context {
	sessCtx := s.ctx
	return sessCtx
}
// AcceptStream returns the next bidirectional stream opened by the peer,
// waiting for one if necessary. It returns the session's close error if the
// session is, or becomes, closed, and ctx.Err() if ctx is done first.
func (s *Session) AcceptStream(ctx context.Context) (Stream, error) {
	s.closeMx.Lock()
	closeErr := s.closeErr
	s.closeMx.Unlock()
	if closeErr != nil {
		return nil, closeErr
	}

	str := s.bidiAcceptQueue.Next()
	for str == nil {
		select {
		case <-ctx.Done():
			return nil, ctx.Err()
		case <-s.ctx.Done():
			// handleConn sets closeErr before the session context is canceled.
			return nil, s.closeErr
		case <-s.bidiAcceptQueue.Chan():
		}
		str = s.bidiAcceptQueue.Next()
	}
	return str, nil
}
// AcceptUniStream returns the next unidirectional stream opened by the peer,
// waiting for one if necessary. It returns the session's close error if the
// session is, or becomes, closed, and ctx.Err() if ctx is done first.
func (s *Session) AcceptUniStream(ctx context.Context) (ReceiveStream, error) {
	s.closeMx.Lock()
	closeErr := s.closeErr
	s.closeMx.Unlock()
	if closeErr != nil {
		// Return the snapshot taken under closeMx instead of re-reading the
		// field without the lock.
		return nil, closeErr
	}

	for {
		if str := s.uniAcceptQueue.Next(); str != nil {
			return str, nil
		}
		select {
		case <-s.ctx.Done():
			// handleConn sets closeErr before the session context is canceled,
			// so it is visible here after the Done signal.
			return nil, s.closeErr
		case <-ctx.Done():
			return nil, ctx.Err()
		case <-s.uniAcceptQueue.Chan():
		}
	}
}
// OpenStream opens a new bidirectional stream on the underlying connection
// using qconn.OpenStream (no context, unlike OpenStreamSync). The stream is
// registered with the session and given the session's stream header. It
// returns the session's close error if the session is already closed.
func (s *Session) OpenStream() (Stream, error) {
	s.closeMx.Lock()
	defer s.closeMx.Unlock()

	if closeErr := s.closeErr; closeErr != nil {
		return nil, closeErr
	}
	qstr, err := s.qconn.OpenStream()
	if err == nil {
		return s.addStream(qstr, true), nil
	}
	return nil, err
}
// addStreamCtxCancel stores cancel in s.streamCtxs under a random key not yet
// in use and returns that key so the caller can remove the entry later.
// Callers in this file hold s.closeMx while calling it.
func (s *Session) addStreamCtxCancel(cancel context.CancelFunc) (id int) {
	for {
		id = rand.Int()
		if _, taken := s.streamCtxs[id]; !taken {
			break
		}
	}
	s.streamCtxs[id] = cancel
	return id
}
// OpenStreamSync opens a new bidirectional stream via qconn.OpenStreamSync,
// which may block. The wait is aborted when ctx is done or when the session
// is closed. The new stream is registered with the session and given the
// session's stream header. If the session is closed before or during the
// call, the session's close error is returned and any stream that was opened
// is canceled.
func (s *Session) OpenStreamSync(ctx context.Context) (Stream, error) {
	s.closeMx.Lock()
	if s.closeErr != nil {
		s.closeMx.Unlock()
		return nil, s.closeErr
	}
	// Register a cancel func so that session teardown (handleConn) can abort
	// the blocking open below.
	ctx, cancel := context.WithCancel(ctx)
	id := s.addStreamCtxCancel(cancel)
	s.closeMx.Unlock()

	// Must not hold closeMx while blocking: handleConn needs it to tear down.
	qstr, err := s.qconn.OpenStreamSync(ctx)

	s.closeMx.Lock()
	defer s.closeMx.Unlock()
	// Unregister and release the derived context on every path.
	delete(s.streamCtxs, id)
	cancel()

	// closeErr is read under closeMx: the session may have been closed
	// concurrently while we were blocked.
	if s.closeErr != nil {
		if err == nil {
			qstr.CancelWrite(sessionCloseErrorCode)
			qstr.CancelRead(sessionCloseErrorCode)
		}
		return nil, s.closeErr
	}
	if err != nil {
		return nil, err
	}
	return s.addStream(qstr, true), nil
}
// OpenUniStream opens a new unidirectional stream on the underlying
// connection using qconn.OpenUniStream (no context, unlike OpenUniStreamSync).
// The stream is registered with the session and carries the session's
// unidirectional stream header. It returns the session's close error if the
// session is already closed.
func (s *Session) OpenUniStream() (SendStream, error) {
	s.closeMx.Lock()
	defer s.closeMx.Unlock()

	if closeErr := s.closeErr; closeErr != nil {
		return nil, closeErr
	}
	qstr, err := s.qconn.OpenUniStream()
	if err == nil {
		return s.addSendStream(qstr), nil
	}
	return nil, err
}
// OpenUniStreamSync opens a new unidirectional stream via
// qconn.OpenUniStreamSync, which may block. The wait is aborted when ctx is
// done or when the session is closed. The new stream is registered with the
// session and carries the session's unidirectional stream header. If the
// session is closed before or during the call, the session's close error is
// returned and any stream that was opened is canceled.
func (s *Session) OpenUniStreamSync(ctx context.Context) (str SendStream, err error) {
	s.closeMx.Lock()
	if s.closeErr != nil {
		s.closeMx.Unlock()
		return nil, s.closeErr
	}
	// Register a cancel func so that session teardown (handleConn) can abort
	// the blocking open below.
	ctx, cancel := context.WithCancel(ctx)
	id := s.addStreamCtxCancel(cancel)
	s.closeMx.Unlock()

	// Must not hold closeMx while blocking: handleConn needs it to tear down.
	qstr, err := s.qconn.OpenUniStreamSync(ctx)

	s.closeMx.Lock()
	defer s.closeMx.Unlock()
	// Unregister and release the derived context on every path.
	delete(s.streamCtxs, id)
	cancel()

	// closeErr is read under closeMx: the session may have been closed
	// concurrently while we were blocked.
	if s.closeErr != nil {
		if err == nil {
			qstr.CancelWrite(sessionCloseErrorCode)
		}
		return nil, s.closeErr
	}
	if err != nil {
		return nil, err
	}
	return s.addSendStream(qstr), nil
}
// LocalAddr returns the local network address of the underlying connection.
func (s *Session) LocalAddr() net.Addr {
	conn := s.qconn
	return conn.LocalAddr()
}
// RemoteAddr returns the peer's network address on the underlying connection.
func (s *Session) RemoteAddr() net.Addr {
	conn := s.qconn
	return conn.RemoteAddr()
}
// CloseWithError closes the session. It sends a CLOSE_WEBTRANSPORT_SESSION
// capsule carrying code and msg, cancels reading and closes the request
// stream, and then waits until the session has been torn down. Calling it on
// an already-closed session does nothing and returns nil.
//
// If the capsule cannot be written, the session is still torn down and the
// write error is returned. Otherwise the result of closing the request
// stream is returned.
func (s *Session) CloseWithError(code SessionErrorCode, msg string) error {
	first, writeErr := s.closeWithError(code, msg)
	if !first {
		return writeErr
	}
	// closeErr is already set at this point, so returning early on a write
	// error would leave handleConn reading the request stream and the
	// session's streams open. Always proceed with the teardown.

	// Abort the read side so handleConn's pending capsule read fails and it
	// returns; the wait below relies on this.
	// NOTE(review): 1337 has no meaning visible in this file; presumably arbitrary.
	const requestStreamCancelCode = 1337
	s.requestStr.CancelRead(requestStreamCancelCode)
	closeErr := s.requestStr.Close()
	<-s.ctx.Done()
	if writeErr != nil {
		return writeErr
	}
	return closeErr
}
// SendDatagram sends b as a datagram associated with the session's request
// stream.
func (s *Session) SendDatagram(b []byte) error {
	reqStr := s.requestStr
	return reqStr.SendDatagram(b)
}
// ReceiveDatagram returns the next datagram received on the session's
// request stream, delegating to requestStr.ReceiveDatagram with ctx.
func (s *Session) ReceiveDatagram(ctx context.Context) ([]byte, error) {
	reqStr := s.requestStr
	return reqStr.ReceiveDatagram(ctx)
}
// closeWithError records a local close (code, msg) as the session's close
// error and writes a CLOSE_WEBTRANSPORT_SESSION capsule on the request
// stream. It reports whether this call was the one that closed the session;
// if the session was already closed, it returns (false, nil) and writes
// nothing. The error is the result of writing the capsule.
func (s *Session) closeWithError(code SessionErrorCode, msg string) (bool, error) {
	s.closeMx.Lock()
	defer s.closeMx.Unlock()

	if s.closeErr != nil {
		return false, nil
	}
	s.closeErr = &SessionError{ErrorCode: code, Message: msg}

	// Payload: 32-bit big-endian error code followed by the message bytes.
	payload := make([]byte, 4, 4+len(msg))
	binary.BigEndian.PutUint32(payload, uint32(code))
	payload = append(payload, msg...)

	w := quicvarint.NewWriter(s.requestStr)
	return true, http3.WriteCapsule(w, closeWebtransportSessionCapsuleType, payload)
}
// ConnectionState returns the state of the underlying QUIC connection.
func (s *Session) ConnectionState() quic.ConnectionState {
	conn := s.qconn
	return conn.ConnectionState()
}
The pages are generated with Golds v0.8.2 (GOOS=linux GOARCH=amd64).
Golds is a Go 101 project developed by Tapir Liu.
PRs and bug reports are welcome and can be submitted to the issue list.
Please follow @zigo_101 (reachable from the left QR code) to get the latest news about Golds.