package webrtc
import (
"errors"
"io"
"sync"
"time"
"github.com/pion/datachannel"
"github.com/pion/logging"
"github.com/pion/sctp"
"github.com/pion/webrtc/v4/pkg/rtcerr"
)
// sctpMaxChannels is the default upper bound on simultaneously usable data
// channels (and therefore SCTP stream identifiers) for an SCTPTransport.
const sctpMaxChannels uint16 = 65535
// SCTPTransport owns an SCTP association that runs over a DTLSTransport and
// tracks the data channels carried by that association.
type SCTPTransport struct {
	// lock guards the mutable fields below when accessed by the methods of
	// this type.
	lock sync.RWMutex

	// dtlsTransport is the transport whose connection Start hands to sctp.Client.
	dtlsTransport *DTLSTransport

	// state is Connecting after construction, Connected once Start creates
	// the association, and Closed after Stop aborts it.
	state SCTPTransportState
	// isStarted makes repeated Start calls no-ops.
	isStarted bool

	// maxChannels is set by updateMaxChannels; MaxChannels falls back to
	// sctpMaxChannels when it is nil.
	maxChannels *uint16

	// Callbacks registered via OnError / OnClose; both are invoked on their
	// own goroutines.
	onErrorHandler func(error)
	onCloseHandler func(error)

	// sctpAssociation is non-nil between a successful Start and Stop.
	sctpAssociation *sctp.Association

	// Callbacks registered via OnDataChannel / OnDataChannelOpened for
	// channels accepted from the remote peer.
	onDataChannelHandler       func(*DataChannel)
	onDataChannelOpenedHandler func(*DataChannel)

	// dataChannels are the channels known to this transport; Start opens the
	// ones still connecting, and onDataChannel appends accepted ones.
	dataChannels []*DataChannel
	// dataChannelIDsUsed is the set of stream IDs already taken, consulted by
	// generateAndSetDataChannelID.
	dataChannelIDsUsed map[uint16]struct{}

	// Counters. dataChannelsOpened is incremented by Start and
	// acceptDataChannels, dataChannelsAccepted by onDataChannel.
	// NOTE(review): dataChannelsRequested is not touched in this file;
	// presumably updated elsewhere in the package — confirm.
	dataChannelsOpened    uint32
	dataChannelsRequested uint32
	dataChannelsAccepted  uint32

	// api supplies the setting engine (logger factory, SCTP tuning knobs).
	api *API
	log logging.LeveledLogger
}
// NewSCTPTransport creates an SCTPTransport that will run over the given
// DTLSTransport. The new transport is in the connecting state, logs through
// the API's "ortc" logger, and has its channel limit initialized.
func (api *API) NewSCTPTransport(dtls *DTLSTransport) *SCTPTransport {
	transport := &SCTPTransport{
		api:                api,
		dtlsTransport:      dtls,
		dataChannelIDsUsed: make(map[uint16]struct{}),
		log:                api.settingEngine.LoggerFactory.NewLogger("ortc"),
		state:              SCTPTransportStateConnecting,
	}
	transport.updateMaxChannels()

	return transport
}
// Transport returns the DTLSTransport this SCTPTransport runs over.
func (r *SCTPTransport) Transport() *DTLSTransport {
	r.lock.RLock()
	dtls := r.dtlsTransport
	r.lock.RUnlock()

	return dtls
}
// GetCapabilities reports the SCTP capabilities of this transport. The
// MaxMessageSize comes from the current association and is zero when no
// association exists yet.
func (r *SCTPTransport) GetCapabilities() SCTPCapabilities {
	caps := SCTPCapabilities{}
	if assoc := r.association(); assoc != nil {
		caps.MaxMessageSize = assoc.MaxMessageSize()
	}

	return caps
}
// Start does the following:
//   - creates the SCTP association over the DTLS transport's connection;
//   - marks the transport connected;
//   - opens every known data channel that is still connecting;
//   - launches a goroutine that accepts data channels opened by the remote peer.
//
// A zero capabilities.MaxMessageSize is replaced by sctpMaxMessageSizeUnsetValue.
// Once a Start call has claimed the transport, further calls return nil
// without doing anything. If that Start fails, the claim is released so a
// later call can retry instead of silently reporting success.
//
// Errors: errSCTPTransportDTLS when there is no usable DTLS connection, or
// whatever sctp.Client returns.
func (r *SCTPTransport) Start(capabilities SCTPCapabilities) error {
	// Check-and-set isStarted under the lock: an unguarded read/write races
	// with concurrent callers and could let two of them build associations.
	r.lock.Lock()
	if r.isStarted {
		r.lock.Unlock()

		return nil
	}
	r.isStarted = true
	r.lock.Unlock()

	// releaseStart undoes the claim above so a failed Start does not make
	// every later Start a silent no-op.
	releaseStart := func(err error) error {
		r.lock.Lock()
		r.isStarted = false
		r.lock.Unlock()

		return err
	}

	maxMessageSize := capabilities.MaxMessageSize
	if maxMessageSize == 0 {
		maxMessageSize = sctpMaxMessageSizeUnsetValue
	}

	dtlsTransport := r.Transport()
	if dtlsTransport == nil || dtlsTransport.conn == nil {
		return releaseStart(errSCTPTransportDTLS)
	}

	sctpAssociation, err := sctp.Client(sctp.Config{
		NetConn:              dtlsTransport.conn,
		MaxReceiveBufferSize: r.api.settingEngine.sctp.maxReceiveBufferSize,
		EnableZeroChecksum:   r.api.settingEngine.sctp.enableZeroChecksum,
		LoggerFactory:        r.api.settingEngine.LoggerFactory,
		// rtoMax is a time.Duration; the SCTP config wants milliseconds as float64.
		RTOMax:         float64(r.api.settingEngine.sctp.rtoMax) / float64(time.Millisecond),
		BlockWrite:     r.api.settingEngine.detach.DataChannels && r.api.settingEngine.dataChannelBlockWrite,
		MaxMessageSize: maxMessageSize,
		MTU:            outboundMTU,
		MinCwnd:        r.api.settingEngine.sctp.minCwnd,
		FastRtxWnd:     r.api.settingEngine.sctp.fastRtxWnd,
		CwndCAStep:     r.api.settingEngine.sctp.cwndCAStep,
	})
	if err != nil {
		return releaseStart(err)
	}

	r.lock.Lock()
	r.sctpAssociation = sctpAssociation
	r.state = SCTPTransportStateConnected
	// Snapshot so the channels can be opened without holding r.lock.
	dataChannels := append([]*DataChannel{}, r.dataChannels...)
	r.lock.Unlock()

	var openedDCCount uint32
	for _, d := range dataChannels {
		if d.ReadyState() != DataChannelStateConnecting {
			continue
		}
		if openErr := d.open(r); openErr != nil {
			r.log.Warnf("failed to open data channel: %s", openErr)

			continue
		}
		openedDCCount++
	}

	r.lock.Lock()
	r.dataChannelsOpened += openedDCCount
	r.lock.Unlock()

	go r.acceptDataChannels(sctpAssociation, dataChannels)

	return nil
}
// Stop aborts the current SCTP association, forgets it, and moves the
// transport to the closed state. If no association exists it does nothing.
// It always returns nil.
func (r *SCTPTransport) Stop() error {
	r.lock.Lock()
	defer r.lock.Unlock()

	assoc := r.sctpAssociation
	if assoc == nil {
		return nil
	}

	assoc.Abort("")
	r.sctpAssociation = nil
	r.state = SCTPTransportStateClosed

	return nil
}
// acceptDataChannels loops accepting data channels opened by the remote peer
// on assoc until datachannel.Accept fails.
//
// Channels whose stream identifier matches one of existingDataChannels are
// skipped. Each new channel goes through these steps:
//   - it is wrapped in a *DataChannel;
//   - it is announced via onDataChannel, and the loop waits for that handler;
//   - it is opened and counted;
//   - it is reported to the OnDataChannelOpened handler.
//
// When Accept fails, the loop reports close and returns. io.EOF reports a nil
// close error. Any other error is logged and reported through both onError
// and onClose.
func (r *SCTPTransport) acceptDataChannels(
	assoc *sctp.Association,
	existingDataChannels []*DataChannel,
) {
	dataChannels := make([]*datachannel.DataChannel, 0, len(existingDataChannels))
	for _, dc := range existingDataChannels {
		// Capture the handle while holding dc.mu; re-reading it after the
		// unlock would race with writers of dc.dataChannel.
		dc.mu.Lock()
		inner := dc.dataChannel
		dc.mu.Unlock()
		if inner == nil {
			continue
		}
		dataChannels = append(dataChannels, inner)
	}

ACCEPT:
	for {
		dc, err := datachannel.Accept(assoc, &datachannel.Config{
			LoggerFactory: r.api.settingEngine.LoggerFactory,
		}, dataChannels...)
		if err != nil {
			if !errors.Is(err, io.EOF) {
				r.log.Errorf("Failed to accept data channel: %v", err)
				r.onError(err)
				r.onClose(err)
			} else {
				r.onClose(nil)
			}

			return
		}

		// Channels we already created locally are not new remote channels.
		for _, ch := range dataChannels {
			if ch.StreamIdentifier() == dc.StreamIdentifier() {
				continue ACCEPT
			}
		}

		var (
			maxRetransmits    *uint16
			maxPacketLifeTime *uint16
		)
		// ReliabilityParameter is a uint32 but the parameters expose a uint16.
		// Saturate rather than letting large values wrap (e.g. 65536 -> 0).
		val := ^uint16(0)
		if p := dc.Config.ReliabilityParameter; p < uint32(val) {
			val = uint16(p)
		}
		ordered := true
		switch dc.Config.ChannelType {
		case datachannel.ChannelTypeReliable:
			ordered = true
		case datachannel.ChannelTypeReliableUnordered:
			ordered = false
		case datachannel.ChannelTypePartialReliableRexmit:
			ordered = true
			maxRetransmits = &val
		case datachannel.ChannelTypePartialReliableRexmitUnordered:
			ordered = false
			maxRetransmits = &val
		case datachannel.ChannelTypePartialReliableTimed:
			ordered = true
			maxPacketLifeTime = &val
		case datachannel.ChannelTypePartialReliableTimedUnordered:
			ordered = false
			maxPacketLifeTime = &val
		default:
		}

		sid := dc.StreamIdentifier()
		rtcDC, err := r.api.newDataChannel(&DataChannelParameters{
			ID:                &sid,
			Label:             dc.Config.Label,
			Protocol:          dc.Config.Protocol,
			Negotiated:        dc.Config.Negotiated,
			Ordered:           ordered,
			MaxPacketLifeTime: maxPacketLifeTime,
			MaxRetransmits:    maxRetransmits,
		}, r, r.api.settingEngine.LoggerFactory.NewLogger("ortc"))
		if err != nil {
			// Could not wrap the channel: close it so the stream is not left dangling.
			if err1 := dc.Close(); err1 != nil {
				r.log.Errorf("Failed to close invalid data channel: %v", err1)
			}
			r.log.Errorf("Failed to accept data channel: %v", err)
			r.onError(err)

			continue ACCEPT
		}

		// Wait for the user's OnDataChannel handler before opening, so it can
		// register callbacks before any events fire.
		<-r.onDataChannel(rtcDC)
		rtcDC.handleOpen(dc, true, dc.Config.Negotiated)

		r.lock.Lock()
		r.dataChannelsOpened++
		handler := r.onDataChannelOpenedHandler
		r.lock.Unlock()

		if handler != nil {
			handler(rtcDC)
		}
	}
}
// OnError sets the handler that is called, on its own goroutine, when the
// transport reports an error (for example a failure to accept a data
// channel). Passing nil removes the handler.
func (r *SCTPTransport) OnError(f func(err error)) {
	r.lock.Lock()
	r.onErrorHandler = f
	r.lock.Unlock()
}
// onError hands err to the registered OnError handler, if any, on a new
// goroutine so the caller is never blocked by user code.
func (r *SCTPTransport) onError(err error) {
	r.lock.RLock()
	cb := r.onErrorHandler
	r.lock.RUnlock()

	if cb == nil {
		return
	}
	go cb(err)
}
// OnClose sets the handler that is called, on its own goroutine, when the
// transport stops accepting data channels. The error is nil for a clean
// (io.EOF) shutdown. Passing nil removes the handler.
func (r *SCTPTransport) OnClose(f func(err error)) {
	r.lock.Lock()
	r.onCloseHandler = f
	r.lock.Unlock()
}
// onClose hands err to the registered OnClose handler, if any, on a new
// goroutine so the caller is never blocked by user code.
func (r *SCTPTransport) onClose(err error) {
	r.lock.RLock()
	cb := r.onCloseHandler
	r.lock.RUnlock()

	if cb == nil {
		return
	}
	go cb(err)
}
// OnDataChannel sets the handler that is called for each data channel opened
// by the remote peer. The accept loop waits for the handler to return before
// opening the channel. Passing nil removes the handler.
func (r *SCTPTransport) OnDataChannel(f func(*DataChannel)) {
	r.lock.Lock()
	r.onDataChannelHandler = f
	r.lock.Unlock()
}
// OnDataChannelOpened sets the handler that is called after a remotely
// created data channel has been opened. Passing nil removes the handler.
func (r *SCTPTransport) OnDataChannelOpened(f func(*DataChannel)) {
	r.lock.Lock()
	r.onDataChannelOpenedHandler = f
	r.lock.Unlock()
}
// onDataChannel registers an accepted data channel with the transport and
// then runs the OnDataChannel handler on a new goroutine.
//
// Registration records the channel, bumps the accepted counter, and marks
// its stream ID as used.
//
// The returned channel is closed once the handler has returned, or
// immediately when there is no handler or dc is nil. A nil dc is not
// registered: storing it would later crash Start, which calls methods on
// every registered channel.
func (r *SCTPTransport) onDataChannel(dc *DataChannel) (done chan struct{}) {
	done = make(chan struct{})

	// Check for nil before dc.ID() is called; the previous check came after
	// the dereference and could never prevent a panic.
	if dc == nil {
		close(done)

		return done
	}

	id := dc.ID()

	r.lock.Lock()
	r.dataChannels = append(r.dataChannels, dc)
	r.dataChannelsAccepted++
	if id != nil {
		r.dataChannelIDsUsed[*id] = struct{}{}
	} else {
		r.log.Errorf("accepted data channel with no ID")
	}
	handler := r.onDataChannelHandler
	r.lock.Unlock()

	if handler == nil {
		close(done)

		return done
	}

	go func() {
		handler(dc)
		close(done)
	}()

	return done
}
// updateMaxChannels installs a freshly allocated limit equal to
// sctpMaxChannels. It writes r.maxChannels without taking r.lock, so it is
// only safe before the transport is shared (as in NewSCTPTransport).
func (r *SCTPTransport) updateMaxChannels() {
	limit := new(uint16)
	*limit = sctpMaxChannels
	r.maxChannels = limit
}
// MaxChannels returns the maximum number of data channels that can be used
// simultaneously. It returns sctpMaxChannels when no limit has been set.
func (r *SCTPTransport) MaxChannels() uint16 {
	r.lock.Lock()
	limit := sctpMaxChannels
	if r.maxChannels != nil {
		limit = *r.maxChannels
	}
	r.lock.Unlock()

	return limit
}
// State returns the current state of the SCTP transport.
func (r *SCTPTransport) State() SCTPTransportState {
	r.lock.RLock()
	current := r.state
	r.lock.RUnlock()

	return current
}
// collectStats adds an SCTPTransportStats entry with ID "sctpTransport" to
// collector. The association counters (bytes, SRTT, windows, MTU) are filled
// in only when an association exists; otherwise they stay zero.
func (r *SCTPTransport) collectStats(collector *statsReportCollector) {
	collector.Collecting()

	stats := SCTPTransportStats{
		Timestamp: statsTimestampFrom(time.Now()),
		Type:      StatsTypeSCTPTransport,
		ID:        "sctpTransport",
	}

	if assoc := r.association(); assoc != nil {
		stats.BytesSent = assoc.BytesSent()
		stats.BytesReceived = assoc.BytesReceived()
		// The 0.001 factor converts the association's SRTT (presumably
		// milliseconds) into seconds.
		stats.SmoothedRoundTripTime = assoc.SRTT() * 0.001
		stats.CongestionWindow = assoc.CWND()
		stats.ReceiverWindow = assoc.RWND()
		stats.MTU = assoc.MTU()
	}

	collector.Collect(stats.ID, stats)
}
// generateAndSetDataChannelID picks the lowest free stream ID of the parity
// owned by dtlsRole: even IDs for the DTLS client and odd IDs otherwise. It
// marks the ID as used and stores a pointer to it in *idOut.
//
// Candidates are limited to values below MaxChannels()-1. When none is free,
// it returns an *rtcerr.OperationError wrapping ErrMaxDataChannelID and
// leaves *idOut untouched.
func (r *SCTPTransport) generateAndSetDataChannelID(dtlsRole DTLSRole, idOut **uint16) error {
	var first uint16
	if dtlsRole != DTLSRoleClient {
		first = 1
	}
	limit := r.MaxChannels() - 1

	r.lock.Lock()
	defer r.lock.Unlock()

	for candidate := first; candidate < limit; candidate += 2 {
		if _, taken := r.dataChannelIDsUsed[candidate]; taken {
			continue
		}
		chosen := candidate
		*idOut = &chosen
		r.dataChannelIDsUsed[candidate] = struct{}{}

		return nil
	}

	return &rtcerr.OperationError{Err: ErrMaxDataChannelID}
}
// association returns the current SCTP association, or nil when there is
// none. It is safe to call on a nil *SCTPTransport.
func (r *SCTPTransport) association() *sctp.Association {
	if r == nil {
		return nil
	}

	r.lock.RLock()
	defer r.lock.RUnlock()

	return r.sctpAssociation
}
// BufferedAmount returns the association's buffered amount, or 0 when no
// association exists. The transport lock is held for the duration of the
// query.
func (r *SCTPTransport) BufferedAmount() int {
	r.lock.Lock()
	defer r.lock.Unlock()

	assoc := r.sctpAssociation
	if assoc == nil {
		return 0
	}

	return assoc.BufferedAmount()
}
// These pages were generated with Golds v0.8.2 (GOOS=linux GOARCH=amd64).
// Golds is a Go 101 project developed by Tapir Liu.
// PRs and bug reports are welcome and can be submitted to the issue list.
// Please follow @zigo_101 (reachable from the left QR code) to get the latest news of Golds.