package webrtc
import (
"errors"
"fmt"
"io"
"sync"
"sync/atomic"
"time"
"github.com/pion/datachannel"
"github.com/pion/logging"
"github.com/pion/webrtc/v4/pkg/rtcerr"
)
// errSCTPNotEstablished is returned by open when the SCTP transport has no
// association to dial the channel on.
var errSCTPNotEstablished = errors.New("SCTP not established")
// DataChannel represents a single data channel and the state needed to open
// it, deliver events to user callbacks and report statistics.
type DataChannel struct {
	// mu guards the fields below that are read or written after construction.
	mu sync.RWMutex

	// statsID is the identifier under which this channel reports its stats.
	statsID string

	// Channel configuration, copied from DataChannelParameters on creation.
	label             string
	ordered           bool
	maxPacketLifeTime *uint16
	maxRetransmits    *uint16
	protocol          string
	negotiated        bool
	// id is the channel identifier used when dialing; when nil at open time
	// one is generated through the SCTP transport.
	id *uint16

	// readyState stores the current DataChannelState.
	readyState atomic.Value
	// bufferedAmountLowThreshold is remembered here so it can be applied to
	// the underlying channel once that exists.
	bufferedAmountLowThreshold uint64
	// detachCalled is set once DetachWithDeadline has handed out the channel.
	detachCalled bool
	// readLoopActive is created when the read loop starts and closed when
	// the read loop exits.
	readLoopActive chan struct{}
	// isGracefulClosed is set as soon as close is requested (by Close or
	// GracefulClose); most event callbacks are suppressed afterwards.
	isGracefulClosed bool

	// User supplied event handlers and the guards that make the open/dial
	// handlers fire at most once per registration.
	onMessageHandler    func(DataChannelMessage)
	openHandlerOnce     sync.Once
	onOpenHandler       func()
	dialHandlerOnce     sync.Once
	onDialHandler       func()
	onCloseHandler      func()
	onBufferedAmountLow func()
	onErrorHandler      func(error)

	// sctpTransport is the transport the channel is bound to.
	sctpTransport *SCTPTransport
	// dataChannel is the underlying pion/datachannel; nil until opened.
	dataChannel *datachannel.DataChannel

	// api gives access to the setting engine; log is the channel's logger.
	api *API
	log logging.LeveledLogger
}
// NewDataChannel builds a DataChannel from params and immediately opens it on
// transport. It returns an error when the parameters are rejected or when the
// channel cannot be opened.
func (api *API) NewDataChannel(transport *SCTPTransport, params *DataChannelParameters) (*DataChannel, error) {
	logger := api.settingEngine.LoggerFactory.NewLogger("ortc")

	// The transport is deliberately not handed to newDataChannel: open is a
	// no-op for a channel that already has a transport set.
	channel, err := api.newDataChannel(params, nil, logger)
	if err != nil {
		return nil, err
	}

	if err := channel.open(transport); err != nil {
		return nil, err
	}

	return channel, nil
}
// newDataChannel validates params and returns a DataChannel in the connecting
// state. The channel is not opened here.
//
// A label longer than 65535 bytes is rejected with an rtcerr.TypeError
// wrapping ErrStringSizeLimit.
func (api *API) newDataChannel(
	params *DataChannelParameters,
	sctpTransport *SCTPTransport,
	log logging.LeveledLogger,
) (*DataChannel, error) {
	if len(params.Label) > 65535 {
		return nil, &rtcerr.TypeError{Err: ErrStringSizeLimit}
	}

	channel := new(DataChannel)
	channel.sctpTransport = sctpTransport
	// The stats identifier is derived from the creation time in nanoseconds.
	channel.statsID = fmt.Sprintf("DataChannel-%d", time.Now().UnixNano())
	channel.label = params.Label
	channel.protocol = params.Protocol
	channel.negotiated = params.Negotiated
	channel.id = params.ID
	channel.ordered = params.Ordered
	channel.maxPacketLifeTime = params.MaxPacketLifeTime
	channel.maxRetransmits = params.MaxRetransmits
	channel.api = api
	channel.log = log

	channel.setReadyState(DataChannelStateConnecting)

	return channel, nil
}
// open binds the channel to sctpTransport and dials the underlying data
// channel over the transport's association.
//
// It returns errSCTPNotEstablished when the transport has no association and
// is a no-op (nil error) when the channel already has a transport. On success
// the dial handler is notified and handleOpen is invoked.
func (d *DataChannel) open(sctpTransport *SCTPTransport) error {
	association := sctpTransport.association()
	if association == nil {
		return errSCTPNotEstablished
	}

	d.mu.Lock()
	if d.sctpTransport != nil {
		d.mu.Unlock()

		return nil
	}
	d.sctpTransport = sctpTransport

	// Derive the channel type from the reliability settings. maxRetransmits
	// takes precedence over maxPacketLifeTime when both are set.
	var (
		channelType          datachannel.ChannelType
		reliabilityParameter uint32
	)
	switch {
	case d.maxRetransmits != nil:
		reliabilityParameter = uint32(*d.maxRetransmits)
		channelType = datachannel.ChannelTypePartialReliableRexmitUnordered
		if d.ordered {
			channelType = datachannel.ChannelTypePartialReliableRexmit
		}
	case d.maxPacketLifeTime != nil:
		reliabilityParameter = uint32(*d.maxPacketLifeTime)
		channelType = datachannel.ChannelTypePartialReliableTimedUnordered
		if d.ordered {
			channelType = datachannel.ChannelTypePartialReliableTimed
		}
	default:
		channelType = datachannel.ChannelTypeReliableUnordered
		if d.ordered {
			channelType = datachannel.ChannelTypeReliable
		}
	}

	cfg := &datachannel.Config{
		ChannelType:          channelType,
		Priority:             datachannel.ChannelPriorityNormal,
		ReliabilityParameter: reliabilityParameter,
		Label:                d.label,
		Protocol:             d.protocol,
		Negotiated:           d.negotiated,
		LoggerFactory:        d.api.settingEngine.LoggerFactory,
	}

	if d.id == nil {
		// The channel lock is released while the transport generates the ID
		// and re-acquired afterwards to store it.
		d.mu.Unlock()

		var generatedID *uint16
		role := d.sctpTransport.dtlsTransport.role()
		if err := d.sctpTransport.generateAndSetDataChannelID(role, &generatedID); err != nil {
			return err
		}

		d.mu.Lock()
		d.id = generatedID
	}

	dc, err := datachannel.Dial(association, *d.id, cfg)
	if err != nil {
		d.mu.Unlock()

		return err
	}

	// Apply the threshold and callback that may have been configured before
	// the underlying channel existed.
	dc.SetBufferedAmountLowThreshold(d.bufferedAmountLowThreshold)
	dc.OnBufferedAmountLow(d.onBufferedAmountLow)
	d.mu.Unlock()

	d.onDial()
	d.handleOpen(dc, false, d.negotiated)

	return nil
}
// Transport returns the SCTPTransport this channel is bound to, or nil when
// none has been set.
func (d *DataChannel) Transport() *SCTPTransport {
	d.mu.RLock()
	transport := d.sctpTransport
	d.mu.RUnlock()

	return transport
}
// checkDetachAfterOpen logs a warning when detaching is enabled in the
// setting engine but Detach has not been called on this channel.
func (d *DataChannel) checkDetachAfterOpen() {
	d.mu.RLock()
	defer d.mu.RUnlock()

	if !d.api.settingEngine.detach.DataChannels || d.detachCalled {
		return
	}

	d.log.Warn("webrtc.DetachDataChannels() enabled but didn't Detach, call Detach from OnOpen")
}
// OnOpen registers f as the open handler, replacing any previous one and
// resetting the once-guard. If the channel is already open, f is run
// asynchronously, followed by the detach check.
func (d *DataChannel) OnOpen(f func()) {
	d.mu.Lock()
	d.openHandlerOnce = sync.Once{}
	d.onOpenHandler = f
	d.mu.Unlock()

	if d.ReadyState() != DataChannelStateOpen {
		return
	}

	notify := func() {
		f()
		d.checkDetachAfterOpen()
	}
	go d.openHandlerOnce.Do(notify)
}
// onOpen fires the registered open handler at most once, on its own
// goroutine. Nothing happens when close has been requested or when no handler
// is registered.
func (d *DataChannel) onOpen() {
	d.mu.RLock()
	handler, closed := d.onOpenHandler, d.isGracefulClosed
	d.mu.RUnlock()

	if closed || handler == nil {
		return
	}

	go d.openHandlerOnce.Do(func() {
		handler()
		d.checkDetachAfterOpen()
	})
}
// OnDial registers f as the dial handler, replacing any previous one and
// resetting the once-guard. If the channel is already open, f is run
// asynchronously.
func (d *DataChannel) OnDial(f func()) {
	d.mu.Lock()
	d.dialHandlerOnce = sync.Once{}
	d.onDialHandler = f
	d.mu.Unlock()

	if d.ReadyState() != DataChannelStateOpen {
		return
	}

	go d.dialHandlerOnce.Do(f)
}
// onDial fires the registered dial handler at most once, on its own
// goroutine. Nothing happens when close has been requested or when no handler
// is registered.
func (d *DataChannel) onDial() {
	d.mu.RLock()
	handler, closed := d.onDialHandler, d.isGracefulClosed
	d.mu.RUnlock()

	if closed || handler == nil {
		return
	}

	go d.dialHandlerOnce.Do(handler)
}
// OnClose registers f as the close handler, replacing any previous one.
func (d *DataChannel) OnClose(f func()) {
	d.mu.Lock()
	d.onCloseHandler = f
	d.mu.Unlock()
}
// onClose runs the registered close handler, if any, on its own goroutine.
func (d *DataChannel) onClose() {
	d.mu.RLock()
	handler := d.onCloseHandler
	d.mu.RUnlock()

	if handler == nil {
		return
	}

	go handler()
}
// OnMessage registers f as the message handler, replacing any previous one.
func (d *DataChannel) OnMessage(f func(msg DataChannelMessage)) {
	d.mu.Lock()
	d.onMessageHandler = f
	d.mu.Unlock()
}
// onMessage passes msg to the registered message handler on the caller's
// goroutine. The message is dropped when close has been requested or when no
// handler is registered.
func (d *DataChannel) onMessage(msg DataChannelMessage) {
	d.mu.RLock()
	handler, closed := d.onMessageHandler, d.isGracefulClosed
	d.mu.RUnlock()

	if closed || handler == nil {
		return
	}

	handler(msg)
}
// handleOpen attaches the underlying channel dc, moves the channel to the open
// state, arranges for the open handler to fire and starts the read loop unless
// data channels are detached.
//
// If close was requested before dc became available, dc is closed, the close
// handler is fired and nothing else happens.
func (d *DataChannel) handleOpen(dc *datachannel.DataChannel, isRemote, isAlreadyNegotiated bool) {
	d.mu.Lock()
	if d.isGracefulClosed {
		d.mu.Unlock()

		if closeErr := dc.Close(); closeErr != nil {
			d.log.Errorf("Failed to close DataChannel that was closed during connecting state %v", closeErr.Error())
		}
		d.onClose()

		return
	}
	d.dataChannel = dc
	threshold := d.bufferedAmountLowThreshold
	lowHandler := d.onBufferedAmountLow
	d.mu.Unlock()

	d.setReadyState(DataChannelStateOpen)

	switch {
	case d.api.settingEngine.detach.DataChannels, isRemote, isAlreadyNegotiated:
		// Detached, remote and pre-negotiated channels fire the open handler
		// right away, after re-applying the buffered-amount settings.
		d.dataChannel.SetBufferedAmountLowThreshold(threshold)
		d.dataChannel.OnBufferedAmountLow(lowHandler)
		d.onOpen()
	default:
		// Otherwise the open handler is driven by the underlying channel's
		// own open notification.
		dc.OnOpen(func() {
			d.onOpen()
		})
	}

	d.mu.Lock()
	defer d.mu.Unlock()

	// No read loop when close was requested in the meantime or when the user
	// reads from the detached channel directly.
	if d.isGracefulClosed || d.api.settingEngine.detach.DataChannels {
		return
	}

	d.readLoopActive = make(chan struct{})
	go d.readLoop()
}
// OnError registers f as the error handler, replacing any previous one.
func (d *DataChannel) OnError(f func(err error)) {
	d.mu.Lock()
	d.onErrorHandler = f
	d.mu.Unlock()
}
// onError passes err to the registered error handler on its own goroutine.
// The error is dropped when close has been requested or when no handler is
// registered.
func (d *DataChannel) onError(err error) {
	d.mu.RLock()
	handler, closed := d.onErrorHandler, d.isGracefulClosed
	d.mu.RUnlock()

	if closed || handler == nil {
		return
	}

	go handler(err)
}
// readLoop reads messages from the underlying channel and hands each one to
// onMessage until a read fails. On failure the channel is marked closed, the
// error handler is fired for anything other than io.EOF, and the close handler
// is fired. readLoopActive is closed when the loop exits so that a graceful
// close can wait for it.
func (d *DataChannel) readLoop() {
	defer func() {
		d.mu.Lock()
		readLoopActive := d.readLoopActive
		d.mu.Unlock()

		close(readLoopActive)
	}()

	buffer := make([]byte, sctpMaxMessageSizeUnsetValue)
	for {
		n, isString, err := d.dataChannel.ReadDataChannel(buffer)
		if err != nil {
			if errors.Is(err, io.ErrShortBuffer) {
				maxMessageSize := d.api.settingEngine.getSCTPMaxMessageSize()
				if int64(n) < int64(maxMessageSize) {
					// Double the buffer and retry the read.
					buffer = append(buffer, make([]byte, len(buffer))...)

					continue
				}

				d.log.Errorf(
					"Incoming DataChannel message larger than Max Message size %v",
					maxMessageSize,
				)
			}

			d.setReadyState(DataChannelStateClosed)
			if !errors.Is(err, io.EOF) {
				d.onError(err)
			}
			d.onClose()

			return
		}

		// Copy the payload: buffer is reused by the next read.
		d.onMessage(DataChannelMessage{
			Data:     append([]byte{}, buffer[:n]...),
			IsString: isString,
		})
	}
}
// Send writes data to the channel as a binary message. It returns
// io.ErrClosedPipe when the channel is not open.
func (d *DataChannel) Send(data []byte) error {
	if err := d.ensureOpen(); err != nil {
		return err
	}

	_, err := d.dataChannel.WriteDataChannel(data, false)

	return err
}
// SendText writes s to the channel as a string message. It returns
// io.ErrClosedPipe when the channel is not open.
func (d *DataChannel) SendText(s string) error {
	if err := d.ensureOpen(); err != nil {
		return err
	}

	payload := []byte(s)
	_, err := d.dataChannel.WriteDataChannel(payload, true)

	return err
}
// ensureOpen returns io.ErrClosedPipe unless the channel is in the open state.
func (d *DataChannel) ensureOpen() error {
	d.mu.RLock()
	state := d.ReadyState()
	d.mu.RUnlock()

	if state == DataChannelStateOpen {
		return nil
	}

	return io.ErrClosedPipe
}
// Detach hands out the underlying channel as a datachannel.ReadWriteCloser.
// It delegates to DetachWithDeadline and returns the same errors.
func (d *DataChannel) Detach() (datachannel.ReadWriteCloser, error) {
	detached, err := d.DetachWithDeadline()

	return detached, err
}
// DetachWithDeadline hands out the underlying channel so the caller can read
// and write it directly, and removes this DataChannel from the transport's
// list of channels.
//
// It returns errDetachNotEnabled when detaching is not enabled in the setting
// engine and errDetachBeforeOpened when the underlying channel does not exist
// yet.
func (d *DataChannel) DetachWithDeadline() (datachannel.ReadWriteCloserDeadliner, error) {
	d.mu.Lock()
	switch {
	case !d.api.settingEngine.detach.DataChannels:
		d.mu.Unlock()

		return nil, errDetachNotEnabled
	case d.dataChannel == nil:
		d.mu.Unlock()

		return nil, errDetachBeforeOpened
	}
	d.detachCalled = true
	detached := d.dataChannel
	d.mu.Unlock()

	// Drop this channel from the transport's list, compacting in place and
	// clearing the vacated tail so no stale pointers remain in the backing
	// array.
	transport := d.sctpTransport
	transport.lock.Lock()
	all := transport.dataChannels
	kept := all[:0]
	for _, channel := range all {
		if channel != d {
			kept = append(kept, channel)
		}
	}
	for i := len(kept); i < len(all); i++ {
		all[i] = nil
	}
	transport.dataChannels = kept
	transport.lock.Unlock()

	return detached, nil
}
// Close closes the channel without waiting for the read loop to finish.
func (d *DataChannel) Close() error {
	const waitForReadLoop = false

	return d.close(waitForReadLoop)
}
// GracefulClose closes the channel and, when a read loop was started, waits
// for it to exit before returning.
func (d *DataChannel) GracefulClose() error {
	const waitForReadLoop = true

	return d.close(waitForReadLoop)
}
// close marks the channel as closed-by-request, moves it to the closing state
// and closes the underlying channel when one exists. It is a no-op when the
// channel is already in the closed state.
//
// When shouldGracefullyClose is true and a read loop was started, close blocks
// until the read loop has exited before returning.
func (d *DataChannel) close(shouldGracefullyClose bool) error {
	d.mu.Lock()
	d.isGracefulClosed = true
	readLoopDone := d.readLoopActive
	hasDataChannel := d.dataChannel != nil
	d.mu.Unlock()

	if shouldGracefullyClose && readLoopDone != nil {
		// Runs on every return path below.
		defer func() {
			<-readLoopDone
		}()
	}

	if d.ReadyState() == DataChannelStateClosed {
		return nil
	}

	d.setReadyState(DataChannelStateClosing)
	if hasDataChannel {
		return d.dataChannel.Close()
	}

	return nil
}
// Label returns the label the channel was created with.
func (d *DataChannel) Label() string {
	d.mu.RLock()
	label := d.label
	d.mu.RUnlock()

	return label
}
// Ordered reports the ordered setting the channel was created with.
func (d *DataChannel) Ordered() bool {
	d.mu.RLock()
	ordered := d.ordered
	d.mu.RUnlock()

	return ordered
}
// MaxPacketLifeTime returns the configured maximum packet lifetime, or nil
// when none was set.
func (d *DataChannel) MaxPacketLifeTime() *uint16 {
	d.mu.RLock()
	lifeTime := d.maxPacketLifeTime
	d.mu.RUnlock()

	return lifeTime
}
// MaxRetransmits returns the configured maximum number of retransmissions, or
// nil when none was set.
func (d *DataChannel) MaxRetransmits() *uint16 {
	d.mu.RLock()
	retransmits := d.maxRetransmits
	d.mu.RUnlock()

	return retransmits
}
// Protocol returns the sub-protocol name the channel was created with.
func (d *DataChannel) Protocol() string {
	d.mu.RLock()
	protocol := d.protocol
	d.mu.RUnlock()

	return protocol
}
// Negotiated reports the negotiated setting the channel was created with.
func (d *DataChannel) Negotiated() bool {
	d.mu.RLock()
	negotiated := d.negotiated
	d.mu.RUnlock()

	return negotiated
}
// ID returns the channel identifier, or nil when none has been assigned yet.
func (d *DataChannel) ID() *uint16 {
	d.mu.RLock()
	id := d.id
	d.mu.RUnlock()

	return id
}
// ReadyState returns the current state of the channel. When no state has been
// stored yet it returns the zero DataChannelState.
func (d *DataChannel) ReadyState() DataChannelState {
	state, ok := d.readyState.Load().(DataChannelState)
	if !ok {
		return DataChannelState(0)
	}

	return state
}
// BufferedAmount returns the buffered amount reported by the underlying
// channel, or 0 when the underlying channel does not exist yet.
func (d *DataChannel) BufferedAmount() uint64 {
	d.mu.RLock()
	defer d.mu.RUnlock()

	if d.dataChannel != nil {
		return d.dataChannel.BufferedAmount()
	}

	return 0
}
// BufferedAmountLowThreshold returns the threshold reported by the underlying
// channel, or the locally stored value when the underlying channel does not
// exist yet.
func (d *DataChannel) BufferedAmountLowThreshold() uint64 {
	d.mu.RLock()
	defer d.mu.RUnlock()

	if d.dataChannel != nil {
		return d.dataChannel.BufferedAmountLowThreshold()
	}

	return d.bufferedAmountLowThreshold
}
// SetBufferedAmountLowThreshold stores th and, when the underlying channel
// already exists, forwards it there as well.
func (d *DataChannel) SetBufferedAmountLowThreshold(th uint64) {
	d.mu.Lock()
	defer d.mu.Unlock()

	d.bufferedAmountLowThreshold = th
	if d.dataChannel == nil {
		return
	}

	d.dataChannel.SetBufferedAmountLowThreshold(th)
}
// OnBufferedAmountLow stores f as the buffered-amount-low handler and, when
// the underlying channel already exists, registers it there as well.
func (d *DataChannel) OnBufferedAmountLow(f func()) {
	d.mu.Lock()
	defer d.mu.Unlock()

	d.onBufferedAmountLow = f
	if d.dataChannel == nil {
		return
	}

	d.dataChannel.OnBufferedAmountLow(f)
}
// getStatsID returns the identifier under which this channel reports stats.
//
// Only a read lock is taken: the method does not modify the channel, so it
// does not need to exclude other readers.
func (d *DataChannel) getStatsID() string {
	d.mu.RLock()
	defer d.mu.RUnlock()

	return d.statsID
}
// collectStats builds a DataChannelStats snapshot for this channel and hands
// it to collector. Message and byte counters are included only when the
// underlying channel exists.
//
// Only a read lock is taken: the snapshot reads channel state without
// modifying it, so it does not need to exclude other readers.
func (d *DataChannel) collectStats(collector *statsReportCollector) {
	collector.Collecting()

	d.mu.RLock()
	defer d.mu.RUnlock()

	stats := DataChannelStats{
		Timestamp: statsTimestampNow(),
		Type:      StatsTypeDataChannel,
		ID:        d.statsID,
		Label:     d.label,
		Protocol:  d.protocol,
		State:     d.ReadyState(),
	}

	if d.id != nil {
		stats.DataChannelIdentifier = int32(*d.id)
	}

	if d.dataChannel != nil {
		stats.MessagesSent = d.dataChannel.MessagesSent()
		stats.BytesSent = d.dataChannel.BytesSent()
		stats.MessagesReceived = d.dataChannel.MessagesReceived()
		stats.BytesReceived = d.dataChannel.BytesReceived()
	}

	collector.Collect(stats.ID, stats)
}
// setReadyState atomically records r as the channel's current state.
func (d *DataChannel) setReadyState(r DataChannelState) {
	d.readyState.Store(r)
}
The pages are generated with Golds v0.8.2. (GOOS=linux GOARCH=amd64)
Golds is a Go 101 project developed by Tapir Liu.
PRs and bug reports are welcome and can be submitted to the issue list.
Please follow @zigo_101 (reachable from the left QR code) to get the latest news of Golds.