package yamux
import (
"bufio"
"context"
"fmt"
"io"
"log"
"math"
"net"
"os"
"runtime/debug"
"strings"
"sync"
"sync/atomic"
"time"
pool "github.com/libp2p/go-buffer-pool"
)
// MemoryManager accounts for the memory a single stream reserves. The session
// obtains a fresh one from its factory for every stream it opens or accepts
// and reserves the initial receive window from it before the stream is used.
type MemoryManager interface {
	// ReserveMemory requests size bytes at priority prio. A non-nil error
	// aborts the stream setup that asked for the reservation.
	ReserveMemory(size int, prio uint8) error
	// ReleaseMemory presumably hands back size previously reserved bytes.
	ReleaseMemory(size int)
	// Done ends the manager's use; the session calls it when the stream is
	// removed or when opening the stream is abandoned.
	Done()
}
// nullMemoryManagerImpl is a MemoryManager that does no accounting at all:
// every reservation succeeds and releasing or finishing is a no-op.
type nullMemoryManagerImpl struct{}

// ReserveMemory always succeeds.
func (nullMemoryManagerImpl) ReserveMemory(int, uint8) error { return nil }

// ReleaseMemory does nothing.
func (nullMemoryManagerImpl) ReleaseMemory(int) {}

// Done does nothing.
func (nullMemoryManagerImpl) Done() {}

// nullMemoryManager is the shared instance newSession falls back to when the
// caller supplies no memory-manager factory.
var nullMemoryManager = &nullMemoryManagerImpl{}
// Session multiplexes many Streams over a single net.Conn. It owns a receive
// goroutine (recvLoop), a send goroutine (sendLoop) and an RTT-measurement
// goroutine, all started by newSession.
type Session struct {
// rtt is the smoothed round-trip time in nanoseconds, read and written only
// through sync/atomic. NOTE: keep it as the first field: 64-bit atomic
// operations require 8-byte alignment on 32-bit platforms, and only the first
// word of an allocated struct is guaranteed to be aligned that way.
rtt int64
// localGoAway is set to 1 (atomically) once we have sent or prepared a
// GoAway; incomingStream then resets newly opened remote streams.
localGoAway int32
// nextStreamID is the next ID this side will use; it advances by 2 via CAS.
nextStreamID uint32
config *Config
logger *log .Logger
conn net .Conn
// reader is conn, optionally wrapped in a bufio.Reader (see newSession).
reader io .Reader
// newMemoryManager creates the per-stream memory manager.
newMemoryManager func () (MemoryManager , error )
// pingLock guards pingID and activePing.
pingLock sync .Mutex
pingID uint32
// activePing is the ping currently in flight, or nil.
activePing *ping
// numIncomingStreams counts remotely opened streams; guarded by streamLock.
numIncomingStreams uint32
// streams maps stream ID to stream; guarded by streamLock.
streams map [uint32 ]*Stream
// inflight holds IDs of outbound streams not yet established; guarded by
// streamLock.
inflight map [uint32 ]struct {}
streamLock sync .Mutex
// synCh is a semaphore (capacity AcceptBacklog) bounding outbound streams
// that are opened but not yet established.
synCh chan struct {}
// acceptCh queues incoming streams for AcceptStream.
acceptCh chan *Stream
// sendCh carries pooled, fully encoded frames to sendLoop.
sendCh chan []byte
// pongCh carries ping IDs to acknowledge; pingCh carries ping IDs to send.
pongCh, pingCh chan uint32
// recvDoneCh is closed when recvLoop exits; recvErr is its result.
recvDoneCh chan struct {}
recvErr error
// sendDoneCh is closed when sendLoop exits.
sendDoneCh chan struct {}
// client selects odd (client) or even (server) stream IDs.
client bool
// shutdown and shutdownErr are written under shutdownLock; shutdownCh is
// closed exactly once when the session begins shutting down.
shutdown bool
shutdownErr error
shutdownCh chan struct {}
shutdownLock sync .Mutex
// keepaliveLock guards keepaliveTimer and keepaliveActive.
keepaliveLock sync .Mutex
keepaliveTimer *time .Timer
// keepaliveActive is true while a keepalive ping is in progress.
keepaliveActive bool
}
// newSession builds a Session on top of conn and starts its background work:
// the keepalive timer (when config.EnableKeepAlive is set), the receive loop,
// the send loop, and periodic RTT measurement.
//
// readBuf > 0 wraps conn in a bufio.Reader of that size. A nil
// newMemoryManager disables memory accounting. Clients allocate odd stream IDs
// starting at 1; servers allocate even IDs starting at 2.
func newSession(config *Config, conn net.Conn, client bool, readBuf int, newMemoryManager func() (MemoryManager, error)) *Session {
	var reader io.Reader = conn
	if readBuf > 0 {
		reader = bufio.NewReaderSize(conn, readBuf)
	}

	if newMemoryManager == nil {
		newMemoryManager = func() (MemoryManager, error) { return nullMemoryManager, nil }
	}

	firstStreamID := uint32(2)
	if client {
		firstStreamID = 1
	}

	s := &Session{
		config:           config,
		client:           client,
		logger:           log.New(config.LogOutput, "", log.LstdFlags),
		conn:             conn,
		reader:           reader,
		newMemoryManager: newMemoryManager,
		nextStreamID:     firstStreamID,

		streams:  make(map[uint32]*Stream),
		inflight: make(map[uint32]struct{}),

		synCh:    make(chan struct{}, config.AcceptBacklog),
		acceptCh: make(chan *Stream, config.AcceptBacklog),
		sendCh:   make(chan []byte, 64),
		pongCh:   make(chan uint32, config.PingBacklog),
		pingCh:   make(chan uint32),

		recvDoneCh: make(chan struct{}),
		sendDoneCh: make(chan struct{}),
		shutdownCh: make(chan struct{}),
	}

	if config.EnableKeepAlive {
		s.startKeepalive()
	}
	go s.recv()
	go s.send()
	go s.startMeasureRTT()
	return s
}
// IsClosed reports whether the session has begun shutting down. It never
// blocks.
func (s *Session) IsClosed() bool {
	closed := false
	select {
	case <-s.shutdownCh:
		closed = true
	default:
	}
	return closed
}
// CloseChan returns a channel that is closed once the session begins shutting
// down, for use in select statements.
func (s *Session) CloseChan() <-chan struct{} { return s.shutdownCh }
// NumStreams returns the number of streams currently registered with the
// session, both inbound and outbound.
func (s *Session) NumStreams() int {
	s.streamLock.Lock()
	defer s.streamLock.Unlock()
	return len(s.streams)
}
// Open is OpenStream returning the stream as a net.Conn.
func (s *Session) Open(ctx context.Context) (net.Conn, error) {
	stream, err := s.OpenStream(ctx)
	if err != nil {
		// Return an untyped nil so callers never get a non-nil net.Conn
		// that wraps a nil *Stream.
		return nil, err
	}
	return stream, nil
}
// OpenStream creates a new outbound stream and sends its initial window
// update to the peer.
//
// It blocks while the number of opened-but-not-yet-established streams is at
// the synCh capacity, until a slot frees up, ctx is done (returning
// ctx.Err()), or the session shuts down (returning the shutdown error).
// Every failure after the slot is acquired gives the slot back and releases
// the stream's memory span, so a failed open neither shrinks the backlog
// permanently nor leaks reserved memory. It returns ErrStreamsExhausted once
// this side has used up its stream ID space.
func (s *Session) OpenStream(ctx context.Context) (*Stream, error) {
	if s.IsClosed() {
		return nil, s.shutdownErr
	}

	// Acquire an in-flight SYN slot.
	select {
	case s.synCh <- struct{}{}:
	case <-ctx.Done():
		return nil, ctx.Err()
	case <-s.shutdownCh:
		return nil, s.shutdownErr
	}

	// releaseSyn returns the slot acquired above when the open is abandoned
	// before the stream is registered.
	releaseSyn := func() {
		select {
		case <-s.synCh:
		default:
			s.logger.Printf("[ERR] yamux: aborted stream open without inflight syn semaphore")
		}
	}

	span, err := s.newMemoryManager()
	if err != nil {
		releaseSyn()
		return nil, fmt.Errorf("failed to create resource scope span: %w", err)
	}
	if err := span.ReserveMemory(initialStreamWindow, 255); err != nil {
		span.Done()
		releaseSyn()
		return nil, err
	}

	// Claim the next ID of our parity; retry if a concurrent opener won the CAS.
	var id uint32
	for {
		id = atomic.LoadUint32(&s.nextStreamID)
		if id >= math.MaxUint32-1 {
			span.Done()
			releaseSyn()
			return nil, ErrStreamsExhausted
		}
		if atomic.CompareAndSwapUint32(&s.nextStreamID, id, id+2) {
			break
		}
	}

	stream := newStream(s, id, streamInit, initialStreamWindow, span)
	s.streamLock.Lock()
	s.streams[id] = stream
	s.inflight[id] = struct{}{}
	s.streamLock.Unlock()

	if err := stream.sendWindowUpdate(ctx.Done()); err != nil {
		// The stream is never handed to the caller, so tear it down fully.
		// closeStream frees the in-flight SYN slot and, through deleteStream,
		// unregisters the stream and calls its memorySpan.Done (newStream was
		// given span above). Leaving it registered would keep it counted
		// until session close, which would then release the span a second
		// time.
		s.closeStream(id)
		return nil, err
	}
	return stream, nil
}
// Accept is AcceptStream returning the stream as a net.Conn.
func (s *Session) Accept() (net.Conn, error) {
	stream, err := s.AcceptStream()
	if err != nil {
		// Untyped nil: avoid a non-nil net.Conn that wraps a nil *Stream.
		return nil, err
	}
	return stream, nil
}
// AcceptStream blocks until the peer opens a stream or the session shuts
// down (returning the shutdown error). Before returning a stream it sends the
// stream's window update; streams for which that send fails are logged and
// skipped, and the wait continues.
func (s *Session) AcceptStream() (*Stream, error) {
	for {
		var stream *Stream
		select {
		case stream = <-s.acceptCh:
		case <-s.shutdownCh:
			return nil, s.shutdownErr
		}

		err := stream.sendWindowUpdate(nil)
		if err == nil {
			return stream, nil
		}
		s.logger.Printf("[WARN] error sending window update before accepting: %s", err)
	}
}
// Close shuts the session down without sending a GoAway frame. The session's
// shutdown error becomes ErrSessionShutdown unless one was already recorded.
func (s *Session) Close() error {
	const sendGoAway = false
	return s.close(ErrSessionShutdown, sendGoAway, goAwayNormal)
}
// CloseWithError shuts the session down after trying to send the peer a
// GoAway frame carrying errCode. No frame is sent when errCode equals
// goAwayNormal. The session's shutdown error becomes a local *GoAwayError
// with that code unless one was already recorded.
func (s *Session) CloseWithError(errCode uint32) error {
	localErr := &GoAwayError{Remote: false, ErrorCode: errCode}
	return s.close(localErr, true, errCode)
}
// close shuts the session down exactly once; later calls are no-ops that
// return nil.
//
// shutdownErr becomes the session's shutdown error unless one was already
// recorded (for example by send). When sendGoAway is set and errCode is not
// goAwayNormal, a GoAway frame with errCode is written straight to the
// connection after the send loop has exited, bounded by goAwayWaitTime.
// close then closes the connection, waits for both loops to exit, and
// force-closes every remaining stream, releasing its memory span.
// It always returns nil.
func (s *Session ) close (shutdownErr error , sendGoAway bool , errCode uint32 ) error {
s .shutdownLock .Lock ()
defer s .shutdownLock .Unlock ()
if s .shutdown {
return nil
}
s .shutdown = true
if s .shutdownErr == nil {
s .shutdownErr = shutdownErr
}
// Closing shutdownCh tells sendLoop, sendMsg, Ping and the RTT loop to stop.
close (s .shutdownCh )
s .stopKeepalive ()
if sendGoAway && errCode != goAwayNormal {
// Wait for sendLoop to exit so this write cannot interleave with a queued
// frame, then write the GoAway directly with a bounded deadline.
<-s .sendDoneCh
ga := s .goAway (errCode )
if err := s .conn .SetWriteDeadline (time .Now ().Add (goAwayWaitTime )); err == nil {
_, _ = s .conn .Write (ga [:])
}
// Best effort: clear the deadline again; the connection is closed next.
s .conn .SetWriteDeadline (time .Time {})
}
// Closing the connection unblocks recvLoop's pending read.
s .conn .Close ()
<-s .sendDoneCh
<-s .recvDoneCh
// Streams see a GoAwayError as-is; any other cause is wrapped so callers
// can match ErrStreamReset as well as the original error.
resetErr := shutdownErr
if _ , ok := resetErr .(*GoAwayError ); !ok {
resetErr = fmt .Errorf ("%w: connection closed: %w" , ErrStreamReset , shutdownErr )
}
s .streamLock .Lock ()
defer s .streamLock .Unlock ()
for id , stream := range s .streams {
stream .forceClose (resetErr )
delete (s .streams , id )
stream .memorySpan .Done ()
}
return nil
}
// GoAway queues a normal GoAway frame for the peer and marks the session so
// that streams the peer opens afterwards are reset (see incomingStream).
// If the send loop has already exited, it waits for shutdown and returns the
// shutdown error.
func (s *Session) GoAway() error {
	hdr := s.goAway(goAwayNormal)
	return s.sendMsg(hdr, nil, nil, true)
}
// goAway sets the local go-away flag and returns an encoded GoAway header
// carrying reason. It does not send anything itself.
func (s *Session) goAway(reason uint32) header {
	atomic.StoreInt32(&s.localGoAway, 1)
	return encode(typeGoAway, 0, 0, reason)
}
// measureRTT pings the peer and folds the result into s.rtt. The first
// successful sample is stored as-is. Later samples are averaged 50/50 with the
// previous value. A failed ping leaves s.rtt unchanged.
func (s *Session) measureRTT() {
	sample, err := s.Ping()
	if err != nil {
		return
	}
	ns := sample.Nanoseconds()
	if atomic.CompareAndSwapInt64(&s.rtt, 0, ns) {
		return
	}
	prev := atomic.LoadInt64(&s.rtt)
	atomic.StoreInt64(&s.rtt, prev/2+ns/2)
}
// startMeasureRTT measures the RTT once immediately, then again every
// config.MeasureRTTInterval until the session shuts down.
func (s *Session) startMeasureRTT() {
	s.measureRTT()

	ticker := time.NewTicker(s.config.MeasureRTTInterval)
	defer ticker.Stop()

	done := s.CloseChan()
	for {
		select {
		case <-ticker.C:
			s.measureRTT()
		case <-done:
			return
		}
	}
}
// getRTT returns the current smoothed RTT, or 0 if no ping has succeeded yet.
func (s *Session) getRTT() time.Duration {
	ns := atomic.LoadInt64(&s.rtt)
	return time.Duration(ns)
}
// Ping sends a ping to the peer and returns the measured round-trip time.
//
// Only one ping is in flight at a time. A caller that finds a ping already
// active returns that ping's outcome (via activePing.wait) instead of sending
// its own. Queueing the ping and waiting for the reply are each bounded by
// config.ConnectionWriteTimeout and fail with ErrTimeout. Shutdown aborts
// either phase with the session's shutdown error.
func (s *Session ) Ping () (dur time .Duration , err error ) {
s .pingLock .Lock ()
if activePing := s .activePing ; activePing != nil {
s .pingLock .Unlock ()
return activePing .wait ()
}
activePing := newPing (s .pingID )
s .pingID ++
s .activePing = activePing
s .pingLock .Unlock ()
// Publish this call's result (named returns) to any waiters, then clear the
// active slot so the next Ping sends a fresh one.
defer func () {
activePing .finish (dur , err )
s .pingLock .Lock ()
s .activePing = nil
s .pingLock .Unlock ()
}()
timer := time .NewTimer (s .config .ConnectionWriteTimeout )
defer timer .Stop ()
// Phase 1: hand the ping ID to sendLoop.
select {
case s .pingCh <- activePing .id :
case <- timer .C :
return 0 , ErrTimeout
case <- s .shutdownCh :
return 0 , s .shutdownErr
}
start := time .Now ()
// Re-arm the same timer for phase 2. If Stop reports the timer already fired,
// drain the pending tick so the Reset timer cannot be satisfied by a stale value.
if !timer .Stop () {
<-timer .C
}
timer .Reset (s .config .ConnectionWriteTimeout )
// Phase 2: wait for handlePing to signal the matching ACK.
select {
case <- activePing .pingResponse :
case <- timer .C :
return 0 , ErrTimeout
case <- s .shutdownCh :
return 0 , s .shutdownErr
}
return time .Since (start ), nil
}
// startKeepalive arms a timer that pings the peer every
// config.KeepAliveInterval. The receive loop pushes the timer back whenever a
// frame arrives (extendKeepalive). If a keepalive ping fails, the error is
// logged and the session is closed with ErrKeepAliveTimeout.
func (s *Session ) startKeepalive () {
s .keepaliveLock .Lock ()
defer s .keepaliveLock .Unlock ()
s .keepaliveTimer = time .AfterFunc (s .config .KeepAliveInterval , func () {
s .keepaliveLock .Lock ()
// Skip if keepalive was stopped or a previous keepalive ping is still running.
if s .keepaliveTimer == nil || s .keepaliveActive {
s .keepaliveLock .Unlock ()
return
}
// Mark the ping active so extendKeepalive leaves the timer alone meanwhile.
s .keepaliveActive = true
s .keepaliveLock .Unlock ()
// Ping without holding keepaliveLock; it may block up to twice the write timeout.
_ , err := s .Ping ()
s .keepaliveLock .Lock ()
s .keepaliveActive = false
// Re-arm for the next interval unless stopKeepalive ran in the meantime.
if s .keepaliveTimer != nil {
s .keepaliveTimer .Reset (s .config .KeepAliveInterval )
}
s .keepaliveLock .Unlock ()
if err != nil {
s .logger .Printf ("[ERR] yamux: keepalive failed: %v" , err )
s .close (ErrKeepAliveTimeout , false , 0 )
}
})
}
// stopKeepalive stops the keepalive timer and clears it. A keepalive callback
// that is already running sees the nil timer and does not re-arm.
func (s *Session) stopKeepalive() {
	s.keepaliveLock.Lock()
	defer s.keepaliveLock.Unlock()

	if s.keepaliveTimer == nil {
		return
	}
	s.keepaliveTimer.Stop()
	s.keepaliveTimer = nil
}
// extendKeepalive postpones the next keepalive ping by a full interval. It is
// called for every received frame.
func (s *Session) extendKeepalive() {
	s.keepaliveLock.Lock()
	defer s.keepaliveLock.Unlock()

	// While a keepalive ping is running, its callback re-arms the timer itself.
	if s.keepaliveActive || s.keepaliveTimer == nil {
		return
	}
	s.keepaliveTimer.Reset(s.config.KeepAliveInterval)
}
// sendMsg queues one frame (hdr followed by body) for the send loop.
//
// It returns:
//   - the shutdown error if the session is closed;
//   - ErrTimeout if deadline fires first (a nil deadline never fires);
//   - errSendLoopDone if the send loop has already exited, unless
//     waitForShutDown is set, in which case it waits for shutdown and returns
//     the shutdown error.
//
// On success the pooled frame buffer belongs to the send loop, which puts it
// back in the pool after writing.
func (s *Session) sendMsg(hdr header, body []byte, deadline <-chan struct{}, waitForShutDown bool) error {
	// Cheap checks before taking a buffer; shutdown takes precedence.
	if s.IsClosed() {
		return s.shutdownErr
	}
	select {
	case <-deadline:
		return ErrTimeout
	default:
	}

	frame := pool.Get(headerSize + len(body))
	copy(frame[:headerSize], hdr[:])
	copy(frame[headerSize:], body)

	select {
	case s.sendCh <- frame:
		return nil
	case <-s.shutdownCh:
		pool.Put(frame)
		return s.shutdownErr
	case <-deadline:
		pool.Put(frame)
		return ErrTimeout
	case <-s.sendDoneCh:
		pool.Put(frame)
		if !waitForShutDown {
			return errSendLoopDone
		}
		<-s.shutdownCh
		return s.shutdownErr
	}
}
// send runs the send loop and, if it fails, shuts the session down.
// When no shutdown error has been recorded yet, it closes the connection,
// waits for the receive loop to exit, and records the send failure as the
// shutdown error. A GoAwayError from the receive loop is recorded instead,
// since it reports why the peer went away.
func (s *Session ) send () {
if err := s .sendLoop (); err != nil {
s .shutdownLock .Lock ()
if s .shutdownErr == nil {
// Closing the conn unblocks recvLoop so recvDoneCh/recvErr become available.
s .conn .Close ()
<-s .recvDoneCh
if _ , ok := s .recvErr .(*GoAwayError ); ok {
err = s .recvErr
}
s .shutdownErr = err
}
s .shutdownLock .Unlock ()
s .close (err , false , 0 )
}
}
// sendLoop writes queued frames to the connection until shutdown or a write
// error. Pending ping and pong frames take priority over queued data frames.
// The write deadline is kept about one ConnectionWriteTimeout ahead.
// A write that times out is reported as ErrConnectionWriteTimeout.
// sendDoneCh is closed when the loop exits, and a panic is recovered and
// returned as an error.
func (s *Session) sendLoop() (err error) {
	defer func() {
		if r := recover(); r != nil {
			fmt.Fprintf(os.Stderr, "caught panic: %s\n%s\n", r, debug.Stack())
			err = fmt.Errorf("panic in yamux send loop: %s", r)
		}
	}()
	defer close(s.sendDoneCh)

	// frame copies an encoded header into a pooled buffer sized for it.
	frame := func(hdr header) []byte {
		buf := pool.Get(headerSize)
		copy(buf, hdr[:])
		return buf
	}

	var deadline time.Time
	// extendWriteDeadline refreshes the socket deadline only when the current
	// one is less than half a timeout away, avoiding a syscall per frame.
	extendWriteDeadline := func() error {
		now := time.Now()
		if !now.Add(s.config.ConnectionWriteTimeout / 2).After(deadline) {
			return nil
		}
		deadline = now.Add(s.config.ConnectionWriteTimeout)
		return s.conn.SetWriteDeadline(deadline)
	}

	for {
		select {
		case <-s.shutdownCh:
			return nil
		default:
		}

		var buf []byte
		// First look only at control frames without blocking; fall back to
		// waiting on everything.
		select {
		case id := <-s.pingCh:
			buf = frame(encode(typePing, flagSYN, 0, id))
		case id := <-s.pongCh:
			buf = frame(encode(typePing, flagACK, 0, id))
		default:
			select {
			case buf = <-s.sendCh:
			case id := <-s.pingCh:
				buf = frame(encode(typePing, flagSYN, 0, id))
			case id := <-s.pongCh:
				buf = frame(encode(typePing, flagACK, 0, id))
			case <-s.shutdownCh:
				return nil
			}
		}

		if derr := extendWriteDeadline(); derr != nil {
			pool.Put(buf)
			return derr
		}
		_, werr := s.conn.Write(buf)
		pool.Put(buf)
		if werr != nil {
			if os.IsTimeout(werr) {
				return ErrConnectionWriteTimeout
			}
			return werr
		}
	}
}
// recv runs the receive loop and closes the session with its error, if any.
func (s *Session) recv() {
	err := s.recvLoop()
	if err == nil {
		return
	}
	s.close(err, false, 0)
}
// handlers maps a frame's message type to the Session method that processes
// it. recvLoop checks that the type is in range before indexing.
var handlers = []func(*Session, header) error{
	typeData:         (*Session).handleStreamMessage,
	typeWindowUpdate: (*Session).handleStreamMessage,
	typePing:         (*Session).handlePing,
	typeGoAway:       (*Session).handleGoAway,
}
// recvLoop reads frame headers from the connection and dispatches each frame
// to its handler until a read fails, the peer speaks an unknown protocol
// version or message type, or a handler returns an error.
//
// On exit it stores the result in s.recvErr and closes s.recvDoneCh. A panic
// is recovered and returned as an error.
func (s *Session) recvLoop() (err error) {
	// Deferred calls run last-in-first-out. Register the bookkeeping first so
	// it runs after the panic handler below. That way s.recvErr records the
	// error synthesized from a recovered panic rather than nil.
	defer func() {
		s.recvErr = err
		close(s.recvDoneCh)
	}()
	defer func() {
		if rerr := recover(); rerr != nil {
			fmt.Fprintf(os.Stderr, "caught panic: %s\n%s\n", rerr, debug.Stack())
			err = fmt.Errorf("panic in yamux receive loop: %s", rerr)
		}
	}()

	var hdr header
	for {
		if _, err := io.ReadFull(s.reader, hdr[:]); err != nil {
			// Do not log the routine ways a connection ends.
			if err != io.EOF && !strings.Contains(err.Error(), "closed") && !strings.Contains(err.Error(), "reset by peer") {
				s.logger.Printf("[ERR] yamux: Failed to read header: %v", err)
			}
			return err
		}

		// Any inbound frame shows the peer is alive.
		s.extendKeepalive()

		if hdr.Version() != protoVersion {
			s.logger.Printf("[ERR] yamux: Invalid protocol version: %d", hdr.Version())
			return ErrInvalidVersion
		}
		mt := hdr.MsgType()
		if mt < typeData || mt > typeGoAway {
			return ErrInvalidMsgType
		}
		if err := handlers[mt](s, hdr); err != nil {
			return err
		}
	}
}
// handleStreamMessage processes a data or window-update frame.
//
// A SYN flag first registers the new incoming stream. Frames for unknown
// streams are dropped, and any data payload is drained from the reader so the
// next read starts on a frame header. If reading a stream's data fails, a
// protocol-error GoAway is sent to the peer and the error is returned.
func (s *Session) handleStreamMessage(hdr header) error {
	id := hdr.StreamID()
	flags := hdr.Flags()
	if flags&flagSYN == flagSYN {
		if err := s.incomingStream(id); err != nil {
			return err
		}
	}

	s.streamLock.Lock()
	stream := s.streams[id]
	s.streamLock.Unlock()

	if stream == nil {
		if hdr.MsgType() == typeData && hdr.Length() > 0 {
			if _, err := io.CopyN(io.Discard, s.reader, int64(hdr.Length())); err != nil {
				// Only part of the payload was consumed, so the reader is no
				// longer aligned on a frame header. Continuing would parse
				// payload bytes as headers, so stop the receive loop.
				return err
			}
		}
		return nil
	}

	if hdr.MsgType() == typeWindowUpdate {
		stream.incrSendWindow(hdr, flags)
		return nil
	}

	if err := stream.readData(hdr, flags, s.reader); err != nil {
		if sendErr := s.sendMsg(s.goAway(goAwayProtoErr), nil, nil, false); sendErr != nil && sendErr != errSendLoopDone {
			s.logger.Printf("[WARN] yamux: failed to send go away: %v", sendErr)
		}
		return err
	}
	return nil
}
// handlePing processes a ping frame; the frame's length field carries the
// ping ID. A SYN ping queues a pong for the send loop, and the reply is
// dropped with a warning if the pong queue is full. An ACK wakes the local
// Ping waiting on the same ID, if any. It always returns nil.
func (s *Session) handlePing(hdr header) error {
	flags := hdr.Flags()
	pingID := hdr.Length()

	if flags&flagSYN != flagSYN {
		s.pingLock.Lock()
		defer s.pingLock.Unlock()
		if p := s.activePing; p != nil && p.id == pingID {
			// Non-blocking: a response may already be pending.
			select {
			case p.pingResponse <- struct{}{}:
			default:
			}
		}
		return nil
	}

	select {
	case s.pongCh <- pingID:
	default:
		s.logger.Printf("[WARN] yamux: dropped ping reply")
	}
	return nil
}
// handleGoAway processes a GoAway frame from the peer; the frame's length
// field carries the reason code. A normal GoAway yields ErrRemoteGoAway. Any
// other code is logged and returned as a remote *GoAwayError. Either way the
// non-nil result ends the receive loop.
func (s *Session) handleGoAway(hdr header) error {
	code := hdr.Length()
	if code == goAwayNormal {
		return ErrRemoteGoAway
	}

	switch code {
	case goAwayProtoErr:
		s.logger.Printf("[ERR] yamux: received protocol error go away")
	case goAwayInternalErr:
		s.logger.Printf("[ERR] yamux: received internal error go away")
	default:
		s.logger.Printf("[ERR] yamux: received go away with error code: %d", code)
	}
	return &GoAwayError{Remote: true, ErrorCode: code}
}
// incomingStream registers a stream the peer opened with ID id and queues it
// for AcceptStream.
//
// It returns an error, ending the receive loop, if the ID has the wrong
// parity, if memory for the stream cannot be set up, or if the ID is already
// in use (after sending a protocol-error GoAway). Otherwise it resets the
// stream with an RST instead of registering it:
//   - after we have sent a GoAway,
//   - when config.MaxIncomingStreams is reached,
//   - when the accept backlog is full.
//
// In those cases it returns the result of sending the RST.
func (s *Session) incomingStream(id uint32) error {
	if s.client != (id%2 == 0) {
		s.logger.Printf("[ERR] yamux: both endpoints are clients")
		return fmt.Errorf("both yamux endpoints are clients")
	}
	if atomic.LoadInt32(&s.localGoAway) == 1 {
		hdr := encode(typeWindowUpdate, flagRST, id, 0)
		return s.sendMsg(hdr, nil, nil, false)
	}

	span, err := s.newMemoryManager()
	if err != nil {
		return fmt.Errorf("failed to create resource span: %w", err)
	}
	if err := span.ReserveMemory(initialStreamWindow, 255); err != nil {
		// The span is never attached to a stream, so release it here.
		span.Done()
		return err
	}
	stream := newStream(s, id, streamSYNReceived, initialStreamWindow, span)

	s.streamLock.Lock()
	defer s.streamLock.Unlock()

	if _, ok := s.streams[id]; ok {
		s.logger.Printf("[ERR] yamux: duplicate stream declared")
		if sendErr := s.sendMsg(s.goAway(goAwayProtoErr), nil, nil, false); sendErr != nil && sendErr != errSendLoopDone {
			s.logger.Printf("[WARN] yamux: failed to send go away: %v", sendErr)
		}
		span.Done()
		return ErrDuplicateStream
	}

	if s.numIncomingStreams >= s.config.MaxIncomingStreams {
		s.logger.Printf("[WARN] yamux: MaxIncomingStreams exceeded, forcing stream reset")
		defer span.Done()
		hdr := encode(typeWindowUpdate, flagRST, id, 0)
		return s.sendMsg(hdr, nil, nil, false)
	}

	s.numIncomingStreams++
	s.streams[id] = stream

	select {
	case s.acceptCh <- stream:
		return nil
	default:
		s.logger.Printf("[WARN] yamux: backlog exceeded, forcing stream reset")
		// The stream is registered at this point, so deleteStream undoes the
		// registration and calls its memorySpan.Done (the span given to
		// newStream above). An extra span.Done here would release it twice.
		s.deleteStream(id)
		hdr := encode(typeWindowUpdate, flagRST, id, 0)
		return s.sendMsg(hdr, nil, nil, false)
	}
}
// closeStream unregisters stream id. If the stream was still in flight (never
// established), its slot in synCh is returned first.
func (s *Session) closeStream(id uint32) {
	s.streamLock.Lock()
	defer s.streamLock.Unlock()

	if _, pending := s.inflight[id]; pending {
		delete(s.inflight, id)
		select {
		case <-s.synCh:
		default:
			s.logger.Printf("[ERR] yamux: SYN tracking out of sync")
		}
	}
	s.deleteStream(id)
}
// deleteStream removes stream id from the registry and calls its
// memorySpan.Done. Unknown IDs are ignored. Callers hold streamLock (as
// closeStream and incomingStream do).
func (s *Session) deleteStream(id uint32) {
	str, found := s.streams[id]
	if !found {
		return
	}

	// IDs with the peer's parity were opened remotely and are counted in
	// numIncomingStreams.
	if remotelyOpened := s.client == (id%2 == 0); remotelyOpened {
		switch s.numIncomingStreams {
		case 0:
			s.logger.Printf("[ERR] yamux: numIncomingStreams underflow")
			// Saturate the counter so incomingStream's MaxIncomingStreams
			// check rejects every further stream.
			s.numIncomingStreams = math.MaxUint32
		default:
			s.numIncomingStreams--
		}
	}

	delete(s.streams, id)
	str.memorySpan.Done()
}
// establishStream marks outbound stream id as established: it drops the
// in-flight entry and frees one slot in synCh. Missing bookkeeping is logged
// rather than treated as fatal.
func (s *Session) establishStream(id uint32) {
	s.streamLock.Lock()
	defer s.streamLock.Unlock()

	if _, tracked := s.inflight[id]; !tracked {
		s.logger.Printf("[ERR] yamux: established stream without inflight SYN (no tracking entry)")
	}
	delete(s.inflight, id) // no-op when the entry is missing

	select {
	case <-s.synCh:
	default:
		s.logger.Printf("[ERR] yamux: established stream without inflight SYN (didn't have semaphore)")
	}
}
The pages are generated with Golds v0.8.2 (GOOS=linux GOARCH=amd64).
Golds is a Go 101 project developed by Tapir Liu.
PRs and bug reports are welcome and can be submitted to the issue list.
Please follow @zigo_101 (reachable from the left QR code) to get the latest news of Golds.