package sctp
import (
"errors"
"fmt"
"io"
"os"
"sync"
"sync/atomic"
"time"
"github.com/pion/logging"
"github.com/pion/transport/v3/deadline"
)
// Reliability modes for a stream. These are byte values; presumably they are
// what callers pass as relType to SetReliabilityParams — confirm at call sites.
const (
	// ReliabilityTypeReliable requests fully reliable delivery.
	ReliabilityTypeReliable byte = iota
	// ReliabilityTypeRexmit presumably limits retransmissions (the limit being
	// the reliability value) — confirm against the association logic.
	ReliabilityTypeRexmit
	// ReliabilityTypeTimed presumably limits delivery by time (the limit being
	// the reliability value) — confirm against the association logic.
	ReliabilityTypeTimed
)
// StreamState is the lifecycle state of a Stream.
type StreamState int

// Stream lifecycle states. Close moves an open stream to Closing (or directly
// to Closed when reads have already failed); an inbound stream reset moves a
// Closing stream to Closed.
const (
	StreamStateOpen StreamState = iota
	StreamStateClosing
	StreamStateClosed
)

// String returns the lowercase name of the state, or "unknown" for any value
// outside the defined set.
func (ss StreamState) String() string {
	names := [...]string{
		StreamStateOpen:    "open",
		StreamStateClosing: "closing",
		StreamStateClosed:  "closed",
	}
	if ss < 0 || int(ss) >= len(names) {
		return "unknown"
	}
	return names[ss]
}
var (
	// ErrOutboundPacketTooLarge is returned (wrapped with the limit) when a
	// message exceeds the association's maximum message size.
	ErrOutboundPacketTooLarge = errors.New("outbound packet larger than maximum message size")

	// ErrStreamClosed is returned when writing to a stream that is not open.
	ErrStreamClosed = errors.New("stream closed")

	// ErrReadDeadlineExceeded is reported by reads once the read deadline has
	// passed; it wraps os.ErrDeadlineExceeded so errors.Is checks against the
	// standard sentinel succeed.
	ErrReadDeadlineExceeded = fmt.Errorf("read deadline exceeded: %w", os.ErrDeadlineExceeded)
)
// Stream represents one SCTP stream within an Association. Unless noted
// otherwise, the mutable fields below are read and written while holding lock.
type Stream struct {
// association is the owning association; used for MaxMessageSize,
// maxPayloadSize, sendPayloadData and sendResetRequest.
association *Association
// lock guards the stream state. ReadSCTP waits on readNotifier while holding
// it, so readNotifier is presumably built on &lock — confirm at construction.
lock sync .RWMutex
// streamIdentifier is stamped on every outgoing chunk and used for resets.
streamIdentifier uint16
// defaultPayloadType is the PPI used by Write; it is accessed atomically
// (as a uint32) rather than under lock.
defaultPayloadType PayloadProtocolIdentifier
// reassemblyQueue collects inbound DATA chunks until a message is readable.
reassemblyQueue *reassemblyQueue
// sequenceNumber is the stream sequence number for the next ordered message.
sequenceNumber uint16
// readNotifier wakes blocked readers when data becomes readable or readErr is set.
readNotifier *sync .Cond
// readErr, once set (io.EOF on inbound reset, ErrReadDeadlineExceeded on
// timeout), is returned by reads when no queued data is readable.
readErr error
// readTimeoutCancel, when non-nil, is closed to cancel the pending
// read-deadline timer goroutine started by SetReadDeadline.
readTimeoutCancel chan struct {}
// writeDeadline is handed to sendPayloadData on every write.
writeDeadline *deadline .Deadline
// writeLock serializes WriteSCTP calls when the association is in
// blocking-write mode.
writeLock sync .Mutex
// unordered, reliabilityType and reliabilityValue are the per-stream
// reliability parameters set via SetReliabilityParams.
unordered bool
reliabilityType byte
reliabilityValue uint32
// bufferedAmount counts outbound bytes added by packetize and not yet
// released through onBufferReleased.
bufferedAmount uint64
// bufferedAmountLow is the threshold at which onBufferedAmountLow fires.
bufferedAmountLow uint64
// onBufferedAmountLow is invoked (outside lock) when bufferedAmount drops
// from above to at-or-below bufferedAmountLow.
onBufferedAmountLow func ()
// state is the stream lifecycle state (open, closing, closed).
state StreamState
log logging .LeveledLogger
// name prefixes log messages for this stream.
name string
}
// StreamIdentifier returns the SCTP stream identifier of this stream.
func (s *Stream) StreamIdentifier() uint16 {
	s.lock.RLock()
	id := s.streamIdentifier
	s.lock.RUnlock()
	return id
}
// SetDefaultPayloadType sets the payload protocol identifier that Write uses.
// The field is stored atomically so it can be changed without taking s.lock.
func (s *Stream) SetDefaultPayloadType(defaultPayloadType PayloadProtocolIdentifier) {
	slot := (*uint32)(&s.defaultPayloadType)
	atomic.StoreUint32(slot, uint32(defaultPayloadType))
}
// SetReliabilityParams configures ordered/unordered delivery and the
// reliability type and value for this stream, taking s.lock.
func (s *Stream) SetReliabilityParams(unordered bool, relType byte, relVal uint32) {
	s.lock.Lock()
	s.setReliabilityParams(unordered, relType, relVal)
	s.lock.Unlock()
}
// setReliabilityParams stores the reliability parameters and logs them.
// The caller must hold s.lock.
func (s *Stream) setReliabilityParams(unordered bool, relType byte, relVal uint32) {
	ordered := !unordered
	s.log.Debugf("[%s] reliability params: ordered=%v type=%d value=%d",
		s.name, ordered, relType, relVal)
	s.unordered, s.reliabilityType, s.reliabilityValue = unordered, relType, relVal
}
// Read implements io.Reader by delegating to ReadSCTP and discarding the
// payload protocol identifier.
func (s *Stream) Read(p []byte) (n int, err error) {
	n, _, err = s.ReadSCTP(p)
	return n, err
}
// ReadSCTP copies the next readable message from the reassembly queue into
// payload and returns its length and payload protocol identifier.
//
// It blocks on readNotifier until either the queue yields a message or
// readErr is set (e.g. io.EOF after an inbound reset, or
// ErrReadDeadlineExceeded). io.ErrShortBuffer from the queue is returned
// unchanged together with whatever n and ppi the queue reported.
func (s *Stream) ReadSCTP(payload []byte) (int, PayloadProtocolIdentifier, error) {
	s.lock.Lock()
	defer s.lock.Unlock()

	// Once reads have failed, a still-armed read-deadline timer has nothing
	// left to do, so cancel it before releasing the lock.
	defer func() {
		if s.readErr == nil || s.readTimeoutCancel == nil {
			return
		}
		close(s.readTimeoutCancel)
		s.readTimeoutCancel = nil
	}()

	for {
		n, ppi, err := s.reassemblyQueue.read(payload)
		switch {
		case err == nil, errors.Is(err, io.ErrShortBuffer):
			return n, ppi, err
		case s.readErr != nil:
			return 0, PayloadProtocolIdentifier(0), s.readErr
		}
		s.readNotifier.Wait()
	}
}
// SetReadDeadline sets the time after which reads return
// ErrReadDeadlineExceeded once no queued data is readable. A zero deadline
// disables the timeout.
//
// Any timer armed by a previous call is cancelled first. If the stream
// already has a read error other than ErrReadDeadlineExceeded (for example
// io.EOF after an inbound reset), nothing else happens. A previous
// deadline-exceeded error is cleared so reads can block again. Always
// returns nil.
func (s *Stream) SetReadDeadline(deadline time.Time) error {
	s.lock.Lock()
	defer s.lock.Unlock()

	if s.readTimeoutCancel != nil {
		close(s.readTimeoutCancel)
		s.readTimeoutCancel = nil
	}

	if s.readErr != nil {
		if !errors.Is(s.readErr, ErrReadDeadlineExceeded) {
			return nil
		}
		s.readErr = nil
	}

	if deadline.IsZero() {
		return nil
	}

	cancel := make(chan struct{})
	s.readTimeoutCancel = cancel

	go func() {
		timer := time.NewTimer(time.Until(deadline))
		defer timer.Stop()

		select {
		case <-cancel:
			return
		case <-timer.C:
		}

		s.lock.Lock()
		// Re-check cancellation while holding s.lock. Everyone who closes the
		// cancel channel (this function, ReadSCTP) does so under s.lock.
		// Checking here therefore cannot race with a concurrent
		// SetReadDeadline that superseded this timer. Without it, a stale
		// timer could set a spurious deadline error and clobber the newer
		// timer's cancel channel.
		select {
		case <-cancel:
			s.lock.Unlock()
			return
		default:
		}
		if s.readErr == nil {
			s.readErr = ErrReadDeadlineExceeded
		}
		if s.readTimeoutCancel == cancel {
			s.readTimeoutCancel = nil
		}
		s.lock.Unlock()

		// Wake every blocked reader; each must observe readErr.
		s.readNotifier.Broadcast()
	}()

	return nil
}
// handleData pushes an inbound DATA chunk into the reassembly queue and, if
// the push was accepted and a message is now readable, signals one waiting
// reader. Runs entirely under s.lock.
func (s *Stream) handleData(pd *chunkPayloadData) {
	s.lock.Lock()
	defer s.lock.Unlock()

	if !s.reassemblyQueue.push(pd) {
		return
	}

	readable := s.reassemblyQueue.isReadable()
	s.log.Debugf("[%s] reassemblyQueue readable=%v", s.name, readable)
	if !readable {
		return
	}

	s.log.Debugf("[%s] readNotifier.signal()", s.name)
	s.readNotifier.Signal()
	s.log.Debugf("[%s] readNotifier.signal() done", s.name)
}
// handleForwardTSNForOrdered applies a FORWARD-TSN stream sequence number to
// the reassembly queue of an ordered stream. It is ignored for unordered
// streams. A reader is signalled, outside s.lock, if data became readable.
func (s *Stream) handleForwardTSNForOrdered(ssn uint16) {
	readable := func() bool {
		s.lock.Lock()
		defer s.lock.Unlock()

		if s.unordered {
			return false
		}
		s.reassemblyQueue.forwardTSNForOrdered(ssn)
		return s.reassemblyQueue.isReadable()
	}()

	if !readable {
		return
	}
	s.readNotifier.Signal()
}
// handleForwardTSNForUnordered applies a new cumulative TSN to the reassembly
// queue of an unordered stream. It is ignored for ordered streams. A reader
// is signalled, outside s.lock, if data became readable.
func (s *Stream) handleForwardTSNForUnordered(newCumulativeTSN uint32) {
	readable := func() bool {
		s.lock.Lock()
		defer s.lock.Unlock()

		if !s.unordered {
			return false
		}
		s.reassemblyQueue.forwardTSNForUnordered(newCumulativeTSN)
		return s.reassemblyQueue.isReadable()
	}()

	if !readable {
		return
	}
	s.readNotifier.Signal()
}
// Write implements io.Writer by sending payload with the stream's default
// payload protocol identifier (see SetDefaultPayloadType).
func (s *Stream) Write(payload []byte) (n int, err error) {
	raw := atomic.LoadUint32((*uint32)(&s.defaultPayloadType))
	n, err = s.WriteSCTP(payload, PayloadProtocolIdentifier(raw))
	return n, err
}
// WriteSCTP sends payload as one SCTP user message on this stream, tagged
// with the payload protocol identifier ppi.
//
// It returns len(payload) on success. It returns 0 and an error when:
//   - payload exceeds the association's MaxMessageSize: wrapped
//     ErrOutboundPacketTooLarge;
//   - the stream is not open: ErrStreamClosed;
//   - sendPayloadData fails. In this case the bufferedAmount and
//     sequence-number bookkeeping done by packetize is rolled back.
func (s *Stream) WriteSCTP(payload []byte, ppi PayloadProtocolIdentifier) (int, error) {
	maxMessageSize := s.association.MaxMessageSize()
	if len(payload) > int(maxMessageSize) {
		return 0, fmt.Errorf("%w: %v", ErrOutboundPacketTooLarge, maxMessageSize)
	}

	if s.State() != StreamStateOpen {
		return 0, ErrStreamClosed
	}

	// In blocking-write mode, writes are serialized. Presumably this ensures
	// a failed send's rollback (sequence number, bufferedAmount) happens
	// before another writer packetizes.
	// The mode is queried exactly once and the unlock is deferred. This keeps
	// Lock/Unlock balanced even if isBlockWrite() changed between calls or
	// sendPayloadData panicked.
	if s.association.isBlockWrite() {
		s.writeLock.Lock()
		defer s.writeLock.Unlock()
	}

	chunks, unordered := s.packetize(payload, ppi)
	n := len(payload)
	if err := s.association.sendPayloadData(s.writeDeadline, chunks); err != nil {
		// Undo what packetize recorded for this message.
		s.lock.Lock()
		s.bufferedAmount -= uint64(n)
		if !unordered {
			s.sequenceNumber--
		}
		s.lock.Unlock()
		return 0, err
	}
	return n, nil
}
// SetWriteDeadline sets the deadline passed to the association for
// subsequent writes. A zero value presumably disables it (per the deadline
// package). Always returns nil.
func (s *Stream) SetWriteDeadline(deadline time.Time) error {
	wd := s.writeDeadline
	wd.Set(deadline)
	return nil
}
// SetDeadline sets both the read and write deadlines to t. The write deadline
// is only applied if setting the read deadline succeeded.
func (s *Stream) SetDeadline(t time.Time) error {
	err := s.SetReadDeadline(t)
	if err == nil {
		err = s.SetWriteDeadline(t)
	}
	return err
}
// packetize splits raw into DATA chunks of at most association.maxPayloadSize
// bytes each. Every chunk carries the current stream sequence number, and
// every chunk after the first points at the first one via head.
//
// Messages with PPI PayloadTypeWebRTCDCEP are always sent ordered (RFC 8832
// requires DCEP messages to use ordered, reliable delivery). For ordered
// messages the stream sequence number is advanced. len(raw) is added to
// bufferedAmount. It returns the chunks and the effective unordered flag.
//
// NOTE(review): for an empty raw no chunk is produced, yet an ordered
// stream's sequence number is still advanced — confirm this cannot leave an
// SSN gap at the peer.
func (s *Stream) packetize(raw []byte, ppi PayloadProtocolIdentifier) ([]*chunkPayloadData, bool) {
	s.lock.Lock()
	defer s.lock.Unlock()

	unordered := s.unordered && ppi != PayloadTypeWebRTCDCEP
	total := uint32(len(raw))

	var (
		chunks []*chunkPayloadData
		first  *chunkPayloadData
	)
	for offset := uint32(0); offset < total; {
		size := min32(s.association.maxPayloadSize, total-offset)
		end := offset + size

		data := make([]byte, size)
		copy(data, raw[offset:end])

		c := &chunkPayloadData{
			streamIdentifier:     s.streamIdentifier,
			userData:             data,
			unordered:            unordered,
			beginningFragment:    offset == 0,
			endingFragment:       end == total,
			immediateSack:        false,
			payloadType:          ppi,
			streamSequenceNumber: s.sequenceNumber,
			head:                 first,
		}
		if first == nil {
			first = c
		}
		chunks = append(chunks, c)
		offset = end
	}

	if !unordered {
		s.sequenceNumber++
	}

	s.bufferedAmount += uint64(len(raw))
	s.log.Tracef("[%s] bufferedAmount = %d", s.name, s.bufferedAmount)

	return chunks, unordered
}
// Close begins closing the stream.
//
// If the stream is open, it moves to StreamStateClosing, or directly to
// StreamStateClosed when a read error is already set. An outgoing stream
// reset is then requested for this stream identifier, and its error is
// returned. Closing a stream that is not open does nothing and returns nil.
func (s *Stream) Close() error {
	transition := func() (uint16, bool) {
		s.lock.Lock()
		defer s.lock.Unlock()

		s.log.Debugf("[%s] Close: state=%s", s.name, s.state.String())
		sid := s.streamIdentifier
		if s.state != StreamStateOpen {
			return sid, false
		}

		next := StreamStateClosing
		if s.readErr != nil {
			next = StreamStateClosed
		}
		s.state = next
		s.log.Debugf("[%s] state change: open => %s", s.name, s.state.String())
		return sid, true
	}

	sid, wasOpen := transition()
	if !wasOpen {
		return nil
	}
	// The reset request is issued after s.lock has been released.
	return s.association.sendResetRequest(sid)
}
// BufferedAmount returns the number of outbound bytes that have been
// packetized but not yet released through onBufferReleased.
func (s *Stream) BufferedAmount() uint64 {
	s.lock.RLock()
	amount := s.bufferedAmount
	s.lock.RUnlock()
	return amount
}
// BufferedAmountLowThreshold returns the threshold at or below which the
// OnBufferedAmountLow callback fires.
func (s *Stream) BufferedAmountLowThreshold() uint64 {
	s.lock.RLock()
	threshold := s.bufferedAmountLow
	s.lock.RUnlock()
	return threshold
}
// SetBufferedAmountLowThreshold sets the threshold used to decide when the
// OnBufferedAmountLow callback fires.
func (s *Stream) SetBufferedAmountLowThreshold(th uint64) {
	s.lock.Lock()
	s.bufferedAmountLow = th
	s.lock.Unlock()
}
// OnBufferedAmountLow registers f to be called whenever the buffered amount
// drops from above the low threshold to at or below it. Passing nil disables
// the callback.
func (s *Stream) OnBufferedAmountLow(f func()) {
	s.lock.Lock()
	s.onBufferedAmountLow = f
	s.lock.Unlock()
}
// onBufferReleased subtracts nBytesReleased from bufferedAmount. Presumably
// the association calls it once queued outbound data no longer needs
// buffering — confirm at call sites.
//
// Non-positive values are ignored. Releasing more than is buffered clamps the
// amount to zero and logs an error. If the amount crosses from above
// bufferedAmountLow to at or below it, the OnBufferedAmountLow callback is
// invoked after s.lock has been released.
func (s *Stream) onBufferReleased(nBytesReleased int) {
	if nBytesReleased <= 0 {
		return
	}

	s.lock.Lock()
	fromAmount := s.bufferedAmount
	released := uint64(nBytesReleased)
	if fromAmount < released {
		// Report the amount that was buffered before clamping. Logging
		// s.bufferedAmount after the reset would always print 0.
		s.log.Errorf("[%s] released buffer size %d should be <= %d",
			s.name, nBytesReleased, fromAmount)
		s.bufferedAmount = 0
	} else {
		s.bufferedAmount = fromAmount - released
	}
	s.log.Tracef("[%s] bufferedAmount = %d", s.name, s.bufferedAmount)

	crossedLow := fromAmount > s.bufferedAmountLow && s.bufferedAmount <= s.bufferedAmountLow
	if s.onBufferedAmountLow != nil && crossedLow {
		f := s.onBufferedAmountLow
		s.lock.Unlock()
		// Invoked without s.lock (sync.RWMutex is not reentrant), presumably so
		// the callback may call Stream methods that lock.
		f()
		return
	}
	s.lock.Unlock()
}
// getNumBytesInReassemblyQueue reports how many bytes the reassembly queue
// currently holds. It does not take s.lock.
// NOTE(review): this presumes reassemblyQueue.getNumBytes is safe for
// concurrent use (e.g. an atomic load) — confirm.
func (s *Stream) getNumBytesInReassemblyQueue() int {
	q := s.reassemblyQueue
	return q.getNumBytes()
}
// onInboundStreamReset presumably runs when the peer resets the incoming
// direction of this stream. It makes reads return io.EOF once no queued data
// is readable and wakes every blocked reader. A stream that was Closing
// becomes Closed.
func (s *Stream) onInboundStreamReset() {
	s.lock.Lock()
	defer s.lock.Unlock()

	s.log.Debugf("[%s] onInboundStreamReset: state=%s", s.name, s.state.String())

	s.readErr = io.EOF
	s.readNotifier.Broadcast()

	if s.state != StreamStateClosing {
		return
	}
	s.log.Debugf("[%s] state change: closing => closed", s.name)
	s.state = StreamStateClosed
}
// State returns the current lifecycle state of the stream.
func (s *Stream) State() StreamState {
	s.lock.RLock()
	current := s.state
	s.lock.RUnlock()
	return current
}
The pages are generated with Golds v0.8.2 (GOOS=linux GOARCH=amd64).
Golds is a Go 101 project developed by Tapir Liu.
PR and bug reports are welcome and can be submitted to the issue list.
Please follow @zigo_101 (reachable from the left QR code) to get the latest news of Golds.