package datachannel
import (
"errors"
"fmt"
"io"
"sync"
"sync/atomic"
"time"
"github.com/pion/logging"
"github.com/pion/sctp"
)
// receiveMTU is the size, in bytes, of the buffer Server allocates to read the
// first message arriving on a freshly accepted stream.
const receiveMTU = 8 * 1024
// Reader is implemented by types that read one data channel message at a time
// and report whether the message was sent as a string.
type Reader interface {
	// ReadDataChannel reads a message into p, returning the number of bytes
	// read, whether the message was a string, and any error.
	ReadDataChannel(p []byte) (n int, isString bool, err error)
}
// ReadDeadliner is implemented by types whose reads can be bounded by a
// deadline.
type ReadDeadliner interface {
	// SetReadDeadline sets the deadline applied to read operations.
	SetReadDeadline(deadline time.Time) error
}
// Writer is implemented by types that send one data channel message per call,
// either as a string or as binary data.
type Writer interface {
	// WriteDataChannel sends p as a single message, flagged as a string when
	// isString is true, and returns the number of bytes written.
	WriteDataChannel(p []byte, isString bool) (n int, err error)
}
// WriteDeadliner is implemented by types whose writes can be bounded by a
// deadline.
type WriteDeadliner interface {
	// SetWriteDeadline sets the deadline applied to write operations.
	SetWriteDeadline(deadline time.Time) error
}
// ReadWriteCloser combines the plain io interfaces with the message-aware
// Reader and Writer of this package, plus io.Closer.
type ReadWriteCloser interface {
	io.Reader
	Reader

	io.Writer
	Writer

	io.Closer
}
type ReadWriteCloserDeadliner interface {
ReadWriteCloser
ReadDeadliner
WriteDeadliner
}
// DataChannel is a data channel carried on a single SCTP stream. It adds DCEP
// handling, message/byte counters and an on-open callback on top of the stream.
type DataChannel struct {
// Config is a copy of the configuration the channel was created with.
Config
// Message and byte counters; every access in this file goes through sync/atomic.
// NOTE(review): 64-bit atomics need 8-byte alignment on 32-bit platforms, so
// check alignment before reordering these fields.
messagesSent uint32
messagesReceived uint32
bytesSent uint64
bytesReceived uint64
// mu guards onOpenCompleteHandler and the reset of openCompleteHandlerOnce
// performed by OnOpen.
mu sync .Mutex
onOpenCompleteHandler func ()
// openCompleteHandlerOnce limits the on-open handler to one run per
// registration; OnOpen replaces it with a fresh sync.Once.
openCompleteHandlerOnce sync .Once
// stream is the SCTP stream that all reads and writes go through.
stream *sctp .Stream
// log is created from Config.LoggerFactory with the scope "datachannel".
log logging .LeveledLogger
}
// Config holds the parameters of a data channel. Client reads them to build the
// DATA_CHANNEL_OPEN message, Server overwrites the channel fields with the
// values announced by the peer, and newDataChannel keeps a copy in the
// resulting DataChannel.
type Config struct {
// ChannelType selects ordering and reliability; commitReliabilityParams maps
// each value onto the SCTP stream settings.
ChannelType ChannelType
// Negotiated, when true, makes Client skip sending DATA_CHANNEL_OPEN.
// NOTE(review): presumably because both peers agreed on the channel out of band.
Negotiated bool
// Priority is carried in the DATA_CHANNEL_OPEN message.
Priority uint16
// ReliabilityParameter is handed to the stream's reliability settings; its
// meaning depends on ChannelType (presumably a retransmit count or a
// lifetime — confirm against the ChannelType definitions).
ReliabilityParameter uint32
// Label and Protocol are carried as raw bytes in DATA_CHANNEL_OPEN.
Label string
Protocol string
// LoggerFactory creates the channel's logger. newDataChannel calls it
// unconditionally, so it must not be nil.
LoggerFactory logging .LoggerFactory
}
// newDataChannel wraps stream in a DataChannel that holds its own copy of
// *config and a logger scoped to "datachannel".
func newDataChannel(stream *sctp.Stream, config *Config) *DataChannel {
	logger := config.LoggerFactory.NewLogger("datachannel")
	dc := &DataChannel{stream: stream, log: logger}
	dc.Config = *config
	return dc
}
// Dial opens stream id on association a, with binary as the default payload
// type, and starts a data channel on it as the opening side (see Client).
// On failure the returned *DataChannel is nil.
func Dial(a *sctp.Association, id uint16, config *Config) (*DataChannel, error) {
	stream, openErr := a.OpenStream(id, sctp.PayloadTypeWebRTCBinary)
	if openErr != nil {
		return nil, openErr
	}
	// Client already returns a nil channel alongside any error.
	return Client(stream, config)
}
// Client starts a data channel on stream as the opening side.
//
// When the channel is not pre-negotiated, a DATA_CHANNEL_OPEN message built
// from config is sent on stream. The reliability parameters are committed
// later, when the peer's DATA_CHANNEL_ACK is handled by handleDCEP. RFC 8832
// §6 requires the opener to send user data ordered until it hears back from
// the peer.
//
// A negotiated channel (config.Negotiated) skips the handshake entirely, so no
// ACK will ever arrive. Its reliability parameters are therefore committed
// immediately; otherwise an unordered or partially-reliable negotiated channel
// would silently behave as reliable and ordered. An unrecognised
// config.ChannelType is then reported as an error.
//
// On any error the returned *DataChannel is nil.
func Client(stream *sctp.Stream, config *Config) (*DataChannel, error) {
	if config.Negotiated {
		dc := newDataChannel(stream, config)
		if err := dc.commitReliabilityParams(); err != nil {
			return nil, err
		}
		return dc, nil
	}

	msg := &channelOpen{
		ChannelType:          config.ChannelType,
		Priority:             config.Priority,
		ReliabilityParameter: config.ReliabilityParameter,
		Label:                []byte(config.Label),
		Protocol:             []byte(config.Protocol),
	}
	rawMsg, err := msg.Marshal()
	if err != nil {
		return nil, fmt.Errorf("failed to marshal ChannelOpen %w", err)
	}
	if _, err = stream.WriteSCTP(rawMsg, sctp.PayloadTypeWebRTCDCEP); err != nil {
		return nil, fmt.Errorf("failed to send ChannelOpen %w", err)
	}
	return newDataChannel(stream, config), nil
}
// Accept waits for the next incoming stream on association a.
//
// If the stream identifier matches one of existingChannels, that channel's
// stream gets binary as its default payload type and the existing channel is
// returned. Otherwise the new stream gets the same default payload type and is
// handed to Server to complete the DCEP handshake, using config.
// On failure the returned *DataChannel is nil.
func Accept(a *sctp.Association, config *Config, existingChannels ...*DataChannel) (*DataChannel, error) {
	stream, err := a.AcceptStream()
	if err != nil {
		return nil, err
	}

	id := stream.StreamIdentifier()
	for _, existing := range existingChannels {
		if existing.StreamIdentifier() != id {
			continue
		}
		existing.stream.SetDefaultPayloadType(sctp.PayloadTypeWebRTCBinary)
		return existing, nil
	}

	stream.SetDefaultPayloadType(sctp.PayloadTypeWebRTCBinary)
	// Server returns a nil channel alongside any error.
	return Server(stream, config)
}
// Server completes the DCEP handshake on stream as the answering side.
//
// It reads the first message, which must arrive with the DCEP payload protocol
// identifier and is parsed by parseExpectDataChannelOpen. The announced channel
// type, priority, reliability parameter, label and protocol are written into
// *config; note that this mutates the caller's Config. Server then replies with
// an ACK and commits the reliability parameters.
// On any error the returned *DataChannel is nil.
func Server(stream *sctp.Stream, config *Config) (*DataChannel, error) {
	buf := make([]byte, receiveMTU)
	n, ppi, err := stream.ReadSCTP(buf)
	switch {
	case err != nil:
		return nil, err
	case ppi != sctp.PayloadTypeWebRTCDCEP:
		return nil, fmt.Errorf("%w %s", ErrInvalidPayloadProtocolIdentifier, ppi)
	}

	req, err := parseExpectDataChannelOpen(buf[:n])
	if err != nil {
		return nil, fmt.Errorf("failed to parse DataChannelOpen packet %w", err)
	}

	config.ChannelType, config.Priority, config.ReliabilityParameter =
		req.ChannelType, req.Priority, req.ReliabilityParameter
	config.Label, config.Protocol = string(req.Label), string(req.Protocol)

	dc := newDataChannel(stream, config)
	if ackErr := dc.writeDataChannelAck(); ackErr != nil {
		return nil, ackErr
	}
	if relErr := dc.commitReliabilityParams(); relErr != nil {
		return nil, relErr
	}
	return dc, nil
}
// Read implements io.Reader: it reads one message via ReadDataChannel and
// drops the string/binary flag.
func (c *DataChannel) Read(p []byte) (n int, err error) {
	n, _, err = c.ReadDataChannel(p)
	return n, err
}
// ReadDataChannel reads the next user message into p. It returns the payload
// length (0 for the "empty" payload types), whether the message was a string,
// and any error.
//
// DCEP control messages are consumed internally: they are handed to
// handleDCEP, failures are only logged, and reading continues. When the stream
// reports io.EOF, the stream is closed; a close error takes precedence over
// the EOF.
func (c *DataChannel) ReadDataChannel(p []byte) (int, bool, error) {
	for {
		n, ppi, err := c.stream.ReadSCTP(p)
		if err != nil {
			if errors.Is(err, io.EOF) {
				// NOTE(review): io.EOF presumably means the peer reset the
				// stream; close our side as well.
				if closeErr := c.stream.Close(); closeErr != nil {
					return 0, false, closeErr
				}
			}
			return 0, false, err
		}

		switch ppi {
		case sctp.PayloadTypeWebRTCDCEP:
			if dcepErr := c.handleDCEP(p[:n]); dcepErr != nil {
				c.log.Errorf("Failed to handle DCEP: %s", dcepErr.Error())
			}
			continue
		case sctp.PayloadTypeWebRTCBinaryEmpty, sctp.PayloadTypeWebRTCStringEmpty:
			// The placeholder byte of an empty message is not user data.
			n = 0
		}

		atomic.AddUint32(&c.messagesReceived, 1)
		atomic.AddUint64(&c.bytesReceived, uint64(n))

		isString := ppi == sctp.PayloadTypeWebRTCString || ppi == sctp.PayloadTypeWebRTCStringEmpty
		return n, isString, nil
	}
}
// SetReadDeadline forwards the read deadline to the underlying SCTP stream.
func (c *DataChannel) SetReadDeadline(deadline time.Time) error {
	stream := c.stream
	return stream.SetReadDeadline(deadline)
}
// SetWriteDeadline forwards the write deadline to the underlying SCTP stream.
func (c *DataChannel) SetWriteDeadline(deadline time.Time) error {
	stream := c.stream
	return stream.SetWriteDeadline(deadline)
}
// MessagesSent reports how many WriteDataChannel calls have been made,
// counted before each write is attempted.
func (c *DataChannel) MessagesSent() (count uint32) {
	count = atomic.LoadUint32(&c.messagesSent)
	return count
}
// MessagesReceived reports how many user messages (not DCEP control
// messages) ReadDataChannel has returned.
func (c *DataChannel) MessagesReceived() (count uint32) {
	count = atomic.LoadUint32(&c.messagesReceived)
	return count
}
// OnOpen registers f to run on its own goroutine when the peer's ACK for this
// channel is handled. Registering a handler re-arms the one-shot guard, so a
// newly registered handler can fire even if an earlier one already did.
func (c *DataChannel) OnOpen(f func()) {
	c.mu.Lock()
	defer c.mu.Unlock()

	c.onOpenCompleteHandler = f
	c.openCompleteHandlerOnce = sync.Once{}
}
// onOpenComplete fires the handler registered with OnOpen on a new goroutine,
// at most once per registration. handleDCEP calls it after an ACK is processed.
//
// The one-shot guard is consulted while c.mu is held. OnOpen replaces
// openCompleteHandlerOnce under that same lock. Calling Do on the Once from the
// spawned goroutine, without the lock, would race with a concurrent OnOpen
// overwriting it and could corrupt the Once's internal state. Only the user
// handler itself runs outside the lock.
func (c *DataChannel) onOpenComplete() {
	c.mu.Lock()
	handler := c.onOpenCompleteHandler
	fire := false
	if handler != nil {
		c.openCompleteHandlerOnce.Do(func() { fire = true })
	}
	c.mu.Unlock()

	if fire {
		go handler()
	}
}
// BytesSent reports the total payload length passed to WriteDataChannel,
// counted before each write is attempted.
func (c *DataChannel) BytesSent() (total uint64) {
	total = atomic.LoadUint64(&c.bytesSent)
	return total
}
// BytesReceived reports the total user payload length returned by
// ReadDataChannel.
func (c *DataChannel) BytesReceived() (total uint64) {
	total = atomic.LoadUint64(&c.bytesReceived)
	return total
}
// StreamIdentifier reports the identifier of the SCTP stream backing this
// channel.
func (c *DataChannel) StreamIdentifier() uint16 {
	stream := c.stream
	return stream.StreamIdentifier()
}
// handleDCEP processes one DCEP control message received on the stream.
// Only an ACK is expected: it commits the channel's reliability parameters and
// then fires the on-open handler. Any other message type, or a parse failure,
// is returned as an error.
func (c *DataChannel) handleDCEP(data []byte) error {
	parsed, err := parse(data)
	if err != nil {
		return fmt.Errorf("failed to parse DataChannel packet %w", err)
	}

	if _, isAck := parsed.(*channelAck); !isAck {
		return fmt.Errorf("%w, wanted ACK got %v", ErrUnexpectedDataChannelType, parsed)
	}

	if err = c.commitReliabilityParams(); err != nil {
		return err
	}
	c.onOpenComplete()
	return nil
}
// Write implements io.Writer by sending p as a single binary message.
func (c *DataChannel) Write(p []byte) (n int, err error) {
	const asString = false
	return c.WriteDataChannel(p, asString)
}
// WriteDataChannel sends p as one message, tagged as a string or as binary data.
// The sent counters are bumped before the write is attempted.
//
// Following RFC 8831 §6.6, an empty message is sent as a single zero byte
// tagged with the matching "Empty" payload protocol identifier. In that case
// 0 is reported as the number of bytes written.
func (c *DataChannel) WriteDataChannel(p []byte, isString bool) (n int, err error) {
	empty := len(p) == 0

	var ppi sctp.PayloadProtocolIdentifier
	if isString {
		ppi = sctp.PayloadTypeWebRTCString
		if empty {
			ppi = sctp.PayloadTypeWebRTCStringEmpty
		}
	} else {
		ppi = sctp.PayloadTypeWebRTCBinary
		if empty {
			ppi = sctp.PayloadTypeWebRTCBinaryEmpty
		}
	}

	atomic.AddUint32(&c.messagesSent, 1)
	atomic.AddUint64(&c.bytesSent, uint64(len(p)))

	if !empty {
		return c.stream.WriteSCTP(p, ppi)
	}
	_, err = c.stream.WriteSCTP([]byte{0}, ppi)
	return 0, err
}
// writeDataChannelAck marshals a DCEP ACK and sends it on the stream with the
// DCEP payload protocol identifier.
func (c *DataChannel) writeDataChannelAck() error {
	var ack channelAck
	raw, err := ack.Marshal()
	if err != nil {
		return fmt.Errorf("failed to marshal ChannelOpen ACK: %w", err)
	}

	if _, writeErr := c.stream.WriteSCTP(raw, sctp.PayloadTypeWebRTCDCEP); writeErr != nil {
		return fmt.Errorf("failed to send ChannelOpen ACK: %w", writeErr)
	}
	return nil
}
// Close closes the underlying SCTP stream and returns its error, if any.
func (c *DataChannel) Close() (err error) {
	err = c.stream.Close()
	return err
}
// BufferedAmount reports the amount of outgoing data the underlying stream
// currently has buffered.
func (c *DataChannel) BufferedAmount() (amount uint64) {
	amount = c.stream.BufferedAmount()
	return amount
}
// BufferedAmountLowThreshold reports the stream's current buffered-amount-low
// threshold.
func (c *DataChannel) BufferedAmountLowThreshold() (threshold uint64) {
	threshold = c.stream.BufferedAmountLowThreshold()
	return threshold
}
// SetBufferedAmountLowThreshold forwards the buffered-amount-low threshold to
// the underlying SCTP stream.
func (c *DataChannel) SetBufferedAmountLowThreshold(threshold uint64) {
	stream := c.stream
	stream.SetBufferedAmountLowThreshold(threshold)
}
// OnBufferedAmountLow registers handler with the underlying SCTP stream's
// buffered-amount-low notification.
func (c *DataChannel) OnBufferedAmountLow(handler func()) {
	stream := c.stream
	stream.OnBufferedAmountLow(handler)
}
// commitReliabilityParams applies Config.ChannelType and
// Config.ReliabilityParameter to the underlying SCTP stream.
//
// The channel type decides two independent things: whether delivery is
// unordered, and which SCTP reliability mode is used (reliable, rexmit or
// timed). An unrecognised channel type yields ErrInvalidChannelType and leaves
// the stream untouched.
func (c *DataChannel) commitReliabilityParams() error {
	unordered := false
	relType := sctp.ReliabilityTypeReliable

	switch c.Config.ChannelType {
	case ChannelTypeReliable:
		// Ordered and fully reliable: the defaults above.
	case ChannelTypeReliableUnordered:
		unordered = true
	case ChannelTypePartialReliableRexmit:
		relType = sctp.ReliabilityTypeRexmit
	case ChannelTypePartialReliableRexmitUnordered:
		unordered, relType = true, sctp.ReliabilityTypeRexmit
	case ChannelTypePartialReliableTimed:
		relType = sctp.ReliabilityTypeTimed
	case ChannelTypePartialReliableTimedUnordered:
		unordered, relType = true, sctp.ReliabilityTypeTimed
	default:
		return fmt.Errorf("%w %v", ErrInvalidChannelType, c.Config.ChannelType)
	}

	c.stream.SetReliabilityParams(unordered, relType, c.Config.ReliabilityParameter)
	return nil
}
The pages are generated with Golds v0.8.2 (GOOS=linux GOARCH=amd64).
Golds is a Go 101 project developed by Tapir Liu.
PRs and bug reports are welcome and can be submitted to the issue list.
Please follow @zigo_101 (reachable from the QR code on the left) to get the latest news about Golds.