package webtransport
import (
"context"
"encoding/binary"
"io"
"math/rand/v2"
"net"
"sync"
"time"
"unicode/utf8"
"github.com/quic-go/quic-go"
"github.com/quic-go/quic-go/http3"
"github.com/quic-go/quic-go/quicvarint"
)
type sessionID uint64
const closeSessionCapsuleType http3 .CapsuleType = 0x2843
const maxCloseCapsuleErrorMsgLen = 1024
type acceptQueue[T any ] struct {
mx sync .Mutex
c chan struct {}
queue []T
}
func newAcceptQueue[T any ]() *acceptQueue [T ] {
return &acceptQueue [T ]{c : make (chan struct {}, 1 )}
}
func (q *acceptQueue [T ]) Add (str T ) {
q .mx .Lock ()
q .queue = append (q .queue , str )
q .mx .Unlock ()
select {
case q .c <- struct {}{}:
default :
}
}
func (q *acceptQueue [T ]) Next () T {
q .mx .Lock ()
defer q .mx .Unlock ()
if len (q .queue ) == 0 {
return *new (T )
}
str := q .queue [0 ]
q .queue = q .queue [1 :]
return str
}
func (q *acceptQueue [T ]) Chan () <-chan struct {} { return q .c }
type http3Stream interface {
io .ReadWriteCloser
ReceiveDatagram(context .Context ) ([]byte , error )
SendDatagram([]byte ) error
CancelRead(quic .StreamErrorCode )
CancelWrite(quic .StreamErrorCode )
SetWriteDeadline(time .Time ) error
}
var (
_ http3Stream = &http3 .Stream {}
_ http3Stream = &http3 .RequestStream {}
)
type SessionState struct {
ConnectionState quic .ConnectionState
ApplicationProtocol string
}
type Session struct {
sessionID sessionID
conn *quic .Conn
str http3Stream
applicationProtocol string
streamHdr []byte
uniStreamHdr []byte
ctx context .Context
closeMx sync .Mutex
closeErr error
streamCtxs map [int ]context .CancelFunc
bidiAcceptQueue acceptQueue [*Stream ]
uniAcceptQueue acceptQueue [*ReceiveStream ]
streams streamsMap
}
func newSession(ctx context .Context , sessionID sessionID , conn *quic .Conn , str http3Stream , applicationProtocol string ) *Session {
ctx , ctxCancel := context .WithCancel (ctx )
c := &Session {
sessionID : sessionID ,
conn : conn ,
str : str ,
applicationProtocol : applicationProtocol ,
ctx : ctx ,
streamCtxs : make (map [int ]context .CancelFunc ),
bidiAcceptQueue : *newAcceptQueue [*Stream ](),
uniAcceptQueue : *newAcceptQueue [*ReceiveStream ](),
streams : *newStreamsMap (),
}
c .uniStreamHdr = make ([]byte , 0 , 2 +quicvarint .Len (uint64 (c .sessionID )))
c .uniStreamHdr = quicvarint .Append (c .uniStreamHdr , webTransportUniStreamType )
c .uniStreamHdr = quicvarint .Append (c .uniStreamHdr , uint64 (c .sessionID ))
c .streamHdr = make ([]byte , 0 , 2 +quicvarint .Len (uint64 (c .sessionID )))
c .streamHdr = quicvarint .Append (c .streamHdr , webTransportFrameType )
c .streamHdr = quicvarint .Append (c .streamHdr , uint64 (c .sessionID ))
go func () {
defer ctxCancel ()
c .handleConn ()
}()
return c
}
func (s *Session ) handleConn () {
err := s .parseNextCapsule ()
s .closeWithError (err )
}
func (s *Session ) parseNextCapsule () error {
for {
typ , r , err := http3 .ParseCapsule (quicvarint .NewReader (s .str ))
if err != nil {
return err
}
switch typ {
case closeSessionCapsuleType :
var b [4 ]byte
if _ , err := io .ReadFull (r , b [:]); err != nil {
return err
}
appErrCode := binary .BigEndian .Uint32 (b [:])
appErrMsg , err := io .ReadAll (io .LimitReader (r , maxCloseCapsuleErrorMsgLen ))
if err != nil {
return err
}
return &SessionError {
Remote : true ,
ErrorCode : SessionErrorCode (appErrCode ),
Message : string (appErrMsg ),
}
default :
if _ , err := io .ReadAll (r ); err != nil {
return err
}
}
}
}
func (s *Session ) addStream (qstr *quic .Stream , addStreamHeader bool ) *Stream {
var hdr []byte
if addStreamHeader {
hdr = s .streamHdr
}
str := newStream (qstr , hdr , func () { s .streams .RemoveStream (qstr .StreamID ()) })
s .streams .AddStream (qstr .StreamID (), str .closeWithSession )
return str
}
func (s *Session ) addReceiveStream (qstr *quic .ReceiveStream ) *ReceiveStream {
str := newReceiveStream (qstr , func () { s .streams .RemoveStream (qstr .StreamID ()) })
s .streams .AddStream (qstr .StreamID (), str .closeWithSession )
return str
}
func (s *Session ) addSendStream (qstr *quic .SendStream ) *SendStream {
str := newSendStream (qstr , s .uniStreamHdr , func () { s .streams .RemoveStream (qstr .StreamID ()) })
s .streams .AddStream (qstr .StreamID (), str .closeWithSession )
return str
}
func (s *Session ) addIncomingStream (qstr *quic .Stream ) {
s .closeMx .Lock ()
closeErr := s .closeErr
if closeErr != nil {
s .closeMx .Unlock ()
qstr .CancelRead (WTSessionGoneErrorCode )
qstr .CancelWrite (WTSessionGoneErrorCode )
return
}
str := s .addStream (qstr , false )
s .closeMx .Unlock ()
s .bidiAcceptQueue .Add (str )
}
func (s *Session ) addIncomingUniStream (qstr *quic .ReceiveStream ) {
s .closeMx .Lock ()
closeErr := s .closeErr
if closeErr != nil {
s .closeMx .Unlock ()
qstr .CancelRead (WTSessionGoneErrorCode )
return
}
str := s .addReceiveStream (qstr )
s .closeMx .Unlock ()
s .uniAcceptQueue .Add (str )
}
func (s *Session ) Context () context .Context {
return s .ctx
}
func (s *Session ) AcceptStream (ctx context .Context ) (*Stream , error ) {
s .closeMx .Lock ()
closeErr := s .closeErr
s .closeMx .Unlock ()
if closeErr != nil {
return nil , closeErr
}
for {
if str := s .bidiAcceptQueue .Next (); str != nil {
return str , nil
}
select {
case <- s .ctx .Done ():
return nil , s .closeErr
case <- ctx .Done ():
return nil , ctx .Err ()
case <- s .bidiAcceptQueue .Chan ():
}
}
}
func (s *Session ) AcceptUniStream (ctx context .Context ) (*ReceiveStream , error ) {
s .closeMx .Lock ()
closeErr := s .closeErr
s .closeMx .Unlock ()
if closeErr != nil {
return nil , s .closeErr
}
for {
if str := s .uniAcceptQueue .Next (); str != nil {
return str , nil
}
select {
case <- s .ctx .Done ():
return nil , s .closeErr
case <- ctx .Done ():
return nil , ctx .Err ()
case <- s .uniAcceptQueue .Chan ():
}
}
}
func (s *Session ) OpenStream () (*Stream , error ) {
s .closeMx .Lock ()
defer s .closeMx .Unlock ()
if s .closeErr != nil {
return nil , s .closeErr
}
qstr , err := s .conn .OpenStream ()
if err != nil {
return nil , err
}
return s .addStream (qstr , true ), nil
}
func (s *Session ) addStreamCtxCancel (cancel context .CancelFunc ) (id int ) {
rand :
id = rand .Int ()
if _ , ok := s .streamCtxs [id ]; ok {
goto rand
}
s .streamCtxs [id ] = cancel
return id
}
func (s *Session ) OpenStreamSync (ctx context .Context ) (*Stream , error ) {
s .closeMx .Lock ()
if s .closeErr != nil {
s .closeMx .Unlock ()
return nil , s .closeErr
}
ctx , cancel := context .WithCancel (ctx )
id := s .addStreamCtxCancel (cancel )
s .closeMx .Unlock ()
qstr , err := s .conn .OpenStreamSync (ctx )
s .closeMx .Lock ()
defer s .closeMx .Unlock ()
delete (s .streamCtxs , id )
if qstr != nil && s .closeErr != nil {
qstr .CancelRead (WTSessionGoneErrorCode )
qstr .CancelWrite (WTSessionGoneErrorCode )
return nil , s .closeErr
}
if err != nil {
if s .closeErr != nil {
return nil , s .closeErr
}
return nil , err
}
return s .addStream (qstr , true ), nil
}
func (s *Session ) OpenUniStream () (*SendStream , error ) {
s .closeMx .Lock ()
defer s .closeMx .Unlock ()
if s .closeErr != nil {
return nil , s .closeErr
}
qstr , err := s .conn .OpenUniStream ()
if err != nil {
return nil , err
}
return s .addSendStream (qstr ), nil
}
func (s *Session ) OpenUniStreamSync (ctx context .Context ) (str *SendStream , err error ) {
s .closeMx .Lock ()
if s .closeErr != nil {
s .closeMx .Unlock ()
return nil , s .closeErr
}
ctx , cancel := context .WithCancel (ctx )
id := s .addStreamCtxCancel (cancel )
s .closeMx .Unlock ()
qstr , err := s .conn .OpenUniStreamSync (ctx )
s .closeMx .Lock ()
defer s .closeMx .Unlock ()
delete (s .streamCtxs , id )
if qstr != nil && s .closeErr != nil {
qstr .CancelWrite (WTSessionGoneErrorCode )
return nil , s .closeErr
}
if err != nil {
if s .closeErr != nil {
return nil , s .closeErr
}
return nil , err
}
return s .addSendStream (qstr ), nil
}
func (s *Session ) LocalAddr () net .Addr {
return s .conn .LocalAddr ()
}
func (s *Session ) RemoteAddr () net .Addr {
return s .conn .RemoteAddr ()
}
func (s *Session ) CloseWithError (code SessionErrorCode , msg string ) error {
first , err := s .closeWithError (&SessionError {ErrorCode : code , Message : msg })
if err != nil || !first {
return err
}
if len (msg ) > maxCloseCapsuleErrorMsgLen {
msg = truncateUTF8 (msg , maxCloseCapsuleErrorMsgLen )
}
b := make ([]byte , 4 , 4 +len (msg ))
binary .BigEndian .PutUint32 (b , uint32 (code ))
b = append (b , []byte (msg )...)
s .str .SetWriteDeadline (time .Now ().Add (10 * time .Millisecond ))
if err := http3 .WriteCapsule (quicvarint .NewWriter (s .str ), closeSessionCapsuleType , b ); err != nil {
s .str .CancelWrite (WTSessionGoneErrorCode )
}
s .str .CancelRead (WTSessionGoneErrorCode )
err = s .str .Close ()
<-s .ctx .Done ()
return err
}
func (s *Session ) SendDatagram (b []byte ) error {
return s .str .SendDatagram (b )
}
func (s *Session ) ReceiveDatagram (ctx context .Context ) ([]byte , error ) {
return s .str .ReceiveDatagram (ctx )
}
func (s *Session ) closeWithError (closeErr error ) (bool , error ) {
s .closeMx .Lock ()
defer s .closeMx .Unlock ()
if s .closeErr != nil {
return false , nil
}
s .closeErr = closeErr
for _ , cancel := range s .streamCtxs {
cancel ()
}
s .streams .CloseSession (closeErr )
return true , nil
}
func (s *Session ) SessionState () SessionState {
return SessionState {
ConnectionState : s .conn .ConnectionState (),
ApplicationProtocol : s .applicationProtocol ,
}
}
func truncateUTF8(s string , n int ) string {
if len (s ) <= n {
return s
}
for n > 0 && !utf8 .RuneStart (s [n ]) {
n --
}
return s [:n ]
}
The pages are generated with Golds v0.8.2 . (GOOS=linux GOARCH=amd64)
Golds is a Go 101 project developed by Tapir Liu .
PR and bug reports are welcome and can be submitted to the issue list .
Please follow @zigo_101 (reachable from the left QR code) to get the latest news of Golds .