package srtp
import (
"bytes"
"fmt"
"github.com/pion/transport/v3/replaydetector"
)
const (
	// Key-derivation labels, one per derived SRTP/SRTCP session key.
	// NOTE(review): values presumably follow the RFC 3711 key-derivation
	// label assignments — confirm against the spec.
	labelSRTPEncryption         = 0x00
	labelSRTPAuthenticationTag  = 0x01
	labelSRTPSalt               = 0x02
	labelSRTCPEncryption        = 0x03
	labelSRTCPAuthenticationTag = 0x04
	labelSRTCPSalt              = 0x05

	// Largest 16-bit RTP sequence number.
	maxSequenceNumber = 0xFFFF
	// Largest 32-bit rollover counter; incrementing past it wraps to 0.
	maxROC = 0xFFFFFFFF
	// Half of the 16-bit sequence space, used as the wrap-detection threshold.
	seqNumMedian = 0x8000
	// Size of the 16-bit sequence space (seqNumMax-1 masks a sequence number).
	seqNumMax = 0x10000
)
// srtpSSRCState is the per-SSRC SRTP state held by a Context; instances are
// created lazily by Context.getSRTPSSRCState.
type srtpSSRCState struct {
	// ssrc is the synchronization source this state belongs to.
	ssrc uint32
	// rolloverHasProcessed is false until the first sequence number has been
	// folded into index by updateRolloverCount; SetROC resets it to false.
	rolloverHasProcessed bool
	// index packs the rollover counter in its upper bits (index>>16) and the
	// last sequence number advanced to in its low 16 bits.
	index uint64
	// replayDetector is built from Context.newSRTPReplayDetector.
	replayDetector replaydetector.ReplayDetector
}
// srtcpSSRCState is the per-SSRC SRTCP state held by a Context; instances
// are created lazily by Context.getSRTCPSSRCState.
type srtcpSSRCState struct {
	// srtcpIndex is the SRTCP index reported by Context.Index and written by
	// Context.SetIndex (which reduces it modulo maxSRTCPIndex+1).
	srtcpIndex uint32
	// ssrc is the synchronization source this state belongs to.
	ssrc uint32
	// replayDetector is built from Context.newSRTCPReplayDetector.
	replayDetector replaydetector.ReplayDetector
}
// RCCMode selects how (and whether) the rollover counter is carried in
// packets. NOTE(review): the modes presumably correspond to RCCm1..RCCm3 of
// RFC 4771 — confirm. Context.checkRCCMode accepts RCCMode3 only for the
// AEAD-GCM profiles and RCCMode2 only for the HMAC-SHA1 profiles.
type RCCMode int

// Supported RCCMode values.
const (
	// RCCModeNone disables rollover-counter carrying.
	RCCModeNone RCCMode = 0
	// RCCMode1 is rollover-counter carrying mode 1.
	RCCMode1 RCCMode = 1
	// RCCMode2 is rollover-counter carrying mode 2.
	RCCMode2 RCCMode = 2
	// RCCMode3 is rollover-counter carrying mode 3.
	RCCMode3 RCCMode = 3
)
// Context holds the cipher(s) and per-SSRC state needed to protect and
// unprotect SRTP and SRTCP packets. Create one with CreateContext.
type Context struct {
	// cipher is the cipher used with sendMKI; set by CreateContext and
	// replaced by SetSendMKI.
	cipher srtpCipher

	// Per-SSRC state, created lazily by getSRTPSSRCState/getSRTCPSSRCState.
	srtpSSRCStates  map[uint32]*srtpSSRCState
	srtcpSSRCStates map[uint32]*srtcpSSRCState

	// Factories invoked whenever a new per-SSRC state is created.
	newSRTCPReplayDetector func() replaydetector.ReplayDetector
	newSRTPReplayDetector  func() replaydetector.ReplayDetector

	// profile determines key/salt lengths and which cipher createCipher builds.
	profile ProtectionProfile

	// sendMKI is the MKI of the current cipher; empty when MKI is not in use
	// (CreateContext only registers it in mkis when non-empty).
	sendMKI []byte
	// mkis maps each registered MKI (as a string) to its cipher.
	mkis map[string]srtpCipher

	// Encryption switches forwarded to createCipher (the NullHmacSha1
	// profiles ignore them and never encrypt).
	encryptSRTP  bool
	encryptSRTCP bool

	// Rollover-counter carrying configuration, validated by checkRCCMode.
	rccMode         RCCMode
	rocTransmitRate uint16

	// authTagRTPLen optionally overrides the RTP auth tag length; nil means
	// not configured. CreateContext rejects values above the profile's
	// AuthKeyLen.
	authTagRTPLen *int
}
// CreateContext creates a new SRTP/SRTCP Context for the given master key,
// master salt and protection profile.
//
// The defaults SRTPNoReplayProtection, SRTCPNoReplayProtection,
// SRTPEncryption and SRTCPEncryption are applied first, then opts in the
// order given (so caller options can override the defaults, assuming each
// option simply assigns Context fields). The first option error aborts
// creation. Afterwards the RCC configuration is validated, any explicitly
// configured RTP auth-tag length is checked against the profile's
// AuthKeyLen, and the send cipher is built. If a send MKI was configured,
// that cipher is also registered under it.
func CreateContext(
	masterKey, masterSalt []byte,
	profile ProtectionProfile,
	opts ...ContextOption,
) (c *Context, err error) {
	c = &Context{
		srtpSSRCStates:  map[uint32]*srtpSSRCState{},
		srtcpSSRCStates: map[uint32]*srtcpSSRCState{},
		profile:         profile,
		mkis:            map[string]srtpCipher{},
	}

	defaults := []ContextOption{
		SRTPNoReplayProtection(),
		SRTCPNoReplayProtection(),
		SRTPEncryption(),
		SRTCPEncryption(),
	}
	for _, apply := range append(defaults, opts...) {
		if optErr := apply(c); optErr != nil {
			return nil, optErr
		}
	}

	if err = c.checkRCCMode(); err != nil {
		return nil, err
	}

	if tagLen := c.authTagRTPLen; tagLen != nil {
		authKeyLen, keyErr := c.profile.AuthKeyLen()
		if keyErr != nil {
			return nil, keyErr
		}
		if *tagLen > authKeyLen {
			return nil, errTooLongSRTPAuthTag
		}
	}

	if c.cipher, err = c.createCipher(c.sendMKI, masterKey, masterSalt, c.encryptSRTP, c.encryptSRTCP); err != nil {
		return nil, err
	}

	// Only register the cipher by MKI when MKI support was requested.
	if len(c.sendMKI) > 0 {
		c.mkis[string(c.sendMKI)] = c.cipher
	}
	return c, nil
}
// AddCipherForMKI registers an additional cipher under mki, derived from
// masterKey and masterSalt using the context's profile and encryption
// settings.
//
// It returns errMKIIsNotEnabled when no MKI cipher has been registered yet,
// errInvalidMKILength when mki is empty or differs in length from the
// current send MKI, errMKIAlreadyInUse when mki is already registered, or
// any error from building the cipher.
func (c *Context) AddCipherForMKI(mki, masterKey, masterSalt []byte) error {
	switch {
	case len(c.mkis) == 0:
		return errMKIIsNotEnabled
	case len(mki) == 0, len(mki) != len(c.sendMKI):
		return errInvalidMKILength
	}

	key := string(mki)
	if _, exists := c.mkis[key]; exists {
		return errMKIAlreadyInUse
	}

	added, err := c.createCipher(mki, masterKey, masterSalt, c.encryptSRTP, c.encryptSRTCP)
	if err != nil {
		return err
	}
	c.mkis[key] = added
	return nil
}
// createCipher builds the SRTP/SRTCP cipher for the context's protection
// profile from mki, masterKey and masterSalt.
//
// masterKey and masterSalt must match the profile's KeyLen and SaltLen
// exactly; otherwise errShortSrtpMasterKey / errShortSrtpMasterSalt is
// returned, wrapped with the expected and actual lengths. The NullHmacSha1
// profiles always disable encryption regardless of encryptSRTP and
// encryptSRTCP. Unknown profiles yield errNoSuchSRTPProfile.
func (c *Context) createCipher(mki, masterKey, masterSalt []byte, encryptSRTP, encryptSRTCP bool) (srtpCipher, error) {
	keyLen, err := c.profile.KeyLen()
	if err != nil {
		return nil, err
	}
	saltLen, err := c.profile.SaltLen()
	if err != nil {
		return nil, err
	}

	// Report lengths only: formatting the key itself would put secret
	// material into error messages (and from there into logs).
	if masterKeyLen := len(masterKey); masterKeyLen != keyLen {
		return nil, fmt.Errorf("%w expected(%d) actual(%d)", errShortSrtpMasterKey, keyLen, masterKeyLen)
	}
	if masterSaltLen := len(masterSalt); masterSaltLen != saltLen {
		return nil, fmt.Errorf("%w expected(%d) actual(%d)", errShortSrtpMasterSalt, saltLen, masterSaltLen)
	}

	profileWithArgs := protectionProfileWithArgs{
		ProtectionProfile: c.profile,
		authTagRTPLen:     c.authTagRTPLen,
	}

	switch c.profile {
	case ProtectionProfileAeadAes128Gcm, ProtectionProfileAeadAes256Gcm:
		return newSrtpCipherAeadAesGcm(profileWithArgs, masterKey, masterSalt, mki, encryptSRTP, encryptSRTCP)
	case ProtectionProfileAes128CmHmacSha1_32,
		ProtectionProfileAes128CmHmacSha1_80,
		ProtectionProfileAes256CmHmacSha1_32,
		ProtectionProfileAes256CmHmacSha1_80:
		return newSrtpCipherAesCmHmacSha1(profileWithArgs, masterKey, masterSalt, mki, encryptSRTP, encryptSRTCP)
	case ProtectionProfileNullHmacSha1_32, ProtectionProfileNullHmacSha1_80:
		// NULL-cipher profiles: authentication only, never encryption.
		return newSrtpCipherAesCmHmacSha1(profileWithArgs, masterKey, masterSalt, mki, false, false)
	default:
		return nil, fmt.Errorf("%w: %#v", errNoSuchSRTPProfile, c.profile)
	}
}
// RemoveMKI unregisters the cipher stored under mki.
//
// It returns ErrMKINotFound if mki is not registered, and
// errMKIAlreadyInUse if mki is the MKI currently selected for sending.
func (c *Context) RemoveMKI(mki []byte) error {
	key := string(mki)
	if _, registered := c.mkis[key]; !registered {
		return ErrMKINotFound
	}
	// The active send cipher must stay available.
	if bytes.Equal(mki, c.sendMKI) {
		return errMKIAlreadyInUse
	}
	delete(c.mkis, key)
	return nil
}
// SetSendMKI selects the registered MKI (and its cipher) to use from now on.
// It returns ErrMKINotFound if mki is not registered; CreateContext and
// AddCipherForMKI are the visible places that register MKIs.
//
// mki is copied, so later changes to the caller's slice cannot make
// c.sendMKI disagree with c.cipher.
func (c *Context) SetSendMKI(mki []byte) error {
	cipher, ok := c.mkis[string(mki)]
	if !ok {
		return ErrMKINotFound
	}
	// Defensive copy: RemoveMKI and AddCipherForMKI compare against
	// c.sendMKI, so it must not alias caller-owned memory.
	c.sendMKI = append([]byte(nil), mki...)
	c.cipher = cipher
	return nil
}
// nextRolloverCount estimates the rollover counter (ROC) for an incoming
// sequence number relative to the index already stored in s.index. It does
// not modify s.
//
// It returns:
//   - roc: the guessed ROC for sequenceNumber;
//   - diff: the signed distance from the stored sequence number to
//     sequenceNumber, accounting for a guessed wrap in either direction;
//   - overflow: true when the guess wrapped the uint32 ROC from maxROC to 0.
//
// Before the first packet has been processed (rolloverHasProcessed is false)
// it returns the stored ROC with diff 0.
// NOTE(review): this appears to follow the RFC 3711 index estimation
// (Appendix A) — confirm.
func (s *srtpSSRCState) nextRolloverCount(sequenceNumber uint16) (roc uint32, diff int64, overflow bool) {
	seq := int32(sequenceNumber)
	// Split the stored index into ROC (upper bits) and SEQ (low 16 bits).
	localRoc := uint32(s.index >> 16)
	localSeq := int32(s.index & (seqNumMax - 1))
	guessRoc := localRoc
	var difference int32
	if s.rolloverHasProcessed {
		// A backward wrap is only considered once the index exceeds
		// seqNumMedian; at or below it the ROC is 0 and localRoc-1 would
		// underflow.
		if s.index > seqNumMedian {
			if localSeq < seqNumMedian {
				// Stored SEQ is in the lower half: a seq more than half the
				// sequence space ahead is attributed to the previous ROC.
				if seq-localSeq > seqNumMedian {
					// localRoc >= 1 here: s.index > seqNumMedian while
					// localSeq < seqNumMedian implies s.index >= seqNumMax.
					guessRoc = localRoc - 1
					difference = seq - localSeq - seqNumMax
				} else {
					guessRoc = localRoc
					difference = seq - localSeq
				}
			} else {
				// Stored SEQ is in the upper half: a seq more than half the
				// sequence space behind is attributed to the next ROC.
				if localSeq-seqNumMedian > seq {
					guessRoc = localRoc + 1
					difference = seq - localSeq + seqNumMax
				} else {
					guessRoc = localRoc
					difference = seq - localSeq
				}
			}
		} else {
			// s.index <= seqNumMedian, so the ROC is 0 and no wrap is guessed.
			difference = seq - localSeq
		}
	}
	// guessRoc can only be 0 with localRoc == maxROC if localRoc+1 wrapped.
	return guessRoc, int64(difference), (guessRoc == 0 && localRoc == maxROC)
}
// updateRolloverCount advances s.index after a packet has been accepted.
//
// When the sender supplied its ROC (hasRemoteRoc), the index is rebuilt from
// remoteRoc and sequenceNumber. For the first packet after creation or
// SetROC, sequenceNumber is merged into the low 16 bits. Otherwise the index
// moves forward by difference, and only when difference is positive.
func (s *srtpSSRCState) updateRolloverCount(sequenceNumber uint16, difference int64, hasRemoteRoc bool,
	remoteRoc uint32,
) {
	if hasRemoteRoc {
		s.index = uint64(remoteRoc)<<16 | uint64(sequenceNumber)
		s.rolloverHasProcessed = true
		return
	}
	if !s.rolloverHasProcessed {
		s.index |= uint64(sequenceNumber)
		s.rolloverHasProcessed = true
		return
	}
	// Late or duplicate packets (difference <= 0) never move the index back.
	if difference > 0 {
		s.index += uint64(difference)
	}
}
// getSRTPSSRCState returns the SRTP state for ssrc. If none exists yet, it
// creates and stores one with a fresh replay detector.
func (c *Context) getSRTPSSRCState(ssrc uint32) *srtpSSRCState {
	if existing, found := c.srtpSSRCStates[ssrc]; found {
		return existing
	}
	created := &srtpSSRCState{ssrc: ssrc, replayDetector: c.newSRTPReplayDetector()}
	c.srtpSSRCStates[ssrc] = created
	return created
}
// getSRTCPSSRCState returns the SRTCP state for ssrc. If none exists yet, it
// creates and stores one with a fresh replay detector.
func (c *Context) getSRTCPSSRCState(ssrc uint32) *srtcpSSRCState {
	if existing, found := c.srtcpSSRCStates[ssrc]; found {
		return existing
	}
	created := &srtcpSSRCState{ssrc: ssrc, replayDetector: c.newSRTCPReplayDetector()}
	c.srtcpSSRCStates[ssrc] = created
	return created
}
// ROC returns the current SRTP rollover counter for ssrc. The boolean is
// false (and the ROC 0) when no SRTP state exists for ssrc.
func (c *Context) ROC(ssrc uint32) (uint32, bool) {
	if state, known := c.srtpSSRCStates[ssrc]; known {
		return uint32(state.index >> 16), true
	}
	return 0, false
}
// SetROC sets the SRTP rollover counter for ssrc, creating its state if
// needed. The sequence-number bits are cleared, and the next processed
// packet re-seeds them.
func (c *Context) SetROC(ssrc uint32, roc uint32) {
	state := c.getSRTPSSRCState(ssrc)
	state.rolloverHasProcessed = false
	state.index = uint64(roc) << 16
}
// Index returns the current SRTCP index for ssrc. The boolean is false (and
// the index 0) when no SRTCP state exists for ssrc.
func (c *Context) Index(ssrc uint32) (uint32, bool) {
	state, known := c.srtcpSSRCStates[ssrc]
	if known {
		return state.srtcpIndex, true
	}
	return 0, false
}
// SetIndex sets the SRTCP index for ssrc, creating its state if needed.
// The value is reduced modulo maxSRTCPIndex+1.
func (c *Context) SetIndex(ssrc uint32, index uint32) {
	c.getSRTCPSSRCState(ssrc).srtcpIndex = index % (maxSRTCPIndex + 1)
}
// checkRCCMode validates the rollover-counter carrying configuration against
// the protection profile.
//
// RCCModeNone always passes. Otherwise the rules are:
//   - rocTransmitRate must be non-zero;
//   - the AEAD-GCM profiles accept only RCCMode3;
//   - the HMAC-SHA1 profiles accept only RCCMode2, and an explicitly
//     configured RTP auth-tag length must be at least 4;
//   - the *_32 HMAC-SHA1 profiles additionally require an explicit auth-tag
//     length.
//
// Any other profile is rejected with errUnsupportedRccMode.
func (c *Context) checkRCCMode() error {
	if c.rccMode == RCCModeNone {
		return nil
	}
	if c.rocTransmitRate == 0 {
		return errZeroRocTransmitRate
	}

	// checkHMAC applies the constraints shared by every HMAC-SHA1 profile.
	checkHMAC := func() error {
		if c.rccMode != RCCMode2 {
			return errUnsupportedRccMode
		}
		if tagLen := c.authTagRTPLen; tagLen != nil && *tagLen < 4 {
			return errTooShortSRTPAuthTag
		}
		return nil
	}

	switch c.profile {
	case ProtectionProfileAeadAes128Gcm, ProtectionProfileAeadAes256Gcm:
		if c.rccMode == RCCMode3 {
			return nil
		}
		return errUnsupportedRccMode
	case ProtectionProfileAes128CmHmacSha1_32,
		ProtectionProfileAes256CmHmacSha1_32,
		ProtectionProfileNullHmacSha1_32:
		if c.authTagRTPLen == nil {
			return errTooShortSRTPAuthTag
		}
		return checkHMAC()
	case ProtectionProfileAes128CmHmacSha1_80,
		ProtectionProfileAes256CmHmacSha1_80,
		ProtectionProfileNullHmacSha1_80:
		return checkHMAC()
	default:
		return errUnsupportedRccMode
	}
}
These pages were generated with Golds v0.8.2 (GOOS=linux GOARCH=amd64).
Golds is a Go 101 project developed by Tapir Liu.
Pull requests and bug reports are welcome and can be submitted to the issue list.
Please follow @zigo_101 (reachable from the QR code on the left) to get the latest news about Golds.