package ciphersuite
import (
"crypto/aes"
"crypto/cipher"
"crypto/hmac"
"crypto/rand"
"encoding/binary"
"hash"
"github.com/pion/dtls/v3/internal/util"
"github.com/pion/dtls/v3/pkg/crypto/prf"
"github.com/pion/dtls/v3/pkg/protocol"
"github.com/pion/dtls/v3/pkg/protocol/recordlayer"
"golang.org/x/crypto/cryptobyte"
)
// cbcMode is a cipher.BlockMode whose IV can be replaced between calls,
// letting one CBC instance process every record with that record's own IV.
type cbcMode interface {
	cipher.BlockMode
	SetIV(iv []byte)
}
// CBC holds the per-direction state used to protect DTLS records with
// AES-CBC and an HMAC over the plaintext.
type CBC struct {
	// Block modes for outgoing (write) and incoming (read) records; the IV is
	// reset for each record via SetIV.
	writeCBC cbcMode
	readCBC  cbcMode

	// HMAC keys for outgoing (write) and incoming (read) records.
	writeMac []byte
	readMac  []byte

	// h constructs the hash function used for the record HMAC.
	h prf.HashFunc
}
// NewCBC creates a CBC record protector from negotiated key material.
//
// localKey, localWriteIV and localMac are used for records this side sends;
// remoteKey, remoteWriteIV and remoteMac for records received from the peer.
// Both keys are passed to aes.NewCipher, whose error is returned unchanged,
// and hashFunc selects the HMAC hash. errFailedToCast is returned when a mode
// produced by crypto/cipher does not support SetIV. crypto/cipher panics if
// an IV's length differs from the AES block size.
func NewCBC(
	localKey, localWriteIV, localMac, remoteKey, remoteWriteIV, remoteMac []byte,
	hashFunc prf.HashFunc,
) (*CBC, error) {
	writeBlock, err := aes.NewCipher(localKey)
	if err != nil {
		return nil, err
	}
	readBlock, err := aes.NewCipher(remoteKey)
	if err != nil {
		return nil, err
	}

	suite := &CBC{
		writeMac: localMac,
		readMac:  remoteMac,
		h:        hashFunc,
	}

	// The encrypter is built (and checked) before the decrypter, as callers
	// may observe through which error or panic surfaces first.
	var ok bool
	if suite.writeCBC, ok = cipher.NewCBCEncrypter(writeBlock, localWriteIV).(cbcMode); !ok {
		return nil, errFailedToCast
	}
	if suite.readCBC, ok = cipher.NewCBCDecrypter(readBlock, remoteWriteIV).(cbcMode); !ok {
		return nil, errFailedToCast
	}
	return suite, nil
}
// Encrypt protects a marshaled DTLS record with AES-CBC, MAC-then-encrypt.
//
// raw must contain pkt.Header's wire encoding followed by the plaintext
// payload. An HMAC is computed over the payload (via hmacCID for
// connection-ID records, hmac otherwise). Payload||MAC is padded to a whole
// number of blocks, with every padding byte set to paddingLen-1, and
// encrypted under a fresh random IV. The returned slice is
// header||IV||ciphertext, with the header's trailing 16-bit length field
// rewritten to the new body length.
//
// The result is built in a newly allocated buffer and raw is never written
// to. Appending to and encrypting sub-slices of raw would, whenever raw has
// spare capacity, silently overwrite the caller's plaintext with ciphertext.
func (c *CBC) Encrypt(pkt *recordlayer.RecordLayer, raw []byte) ([]byte, error) {
	h := pkt.Header
	headerSize := h.Size()
	payload := raw[headerSize:]
	blockSize := c.writeCBC.BlockSize()

	var (
		mac []byte
		err error
	)
	if h.ContentType == protocol.ContentTypeConnectionID {
		mac, err = c.hmacCID(h.Epoch, h.SequenceNumber, h.Version, payload, c.writeMac, c.h, h.ConnectionID)
	} else {
		mac, err = c.hmac(h.Epoch, h.SequenceNumber, h.ContentType, h.Version, payload, c.writeMac, c.h)
	}
	if err != nil {
		return nil, err
	}

	// Padding always adds between 1 and blockSize bytes so that
	// payload||MAC||padding is block aligned.
	paddingLen := blockSize - (len(payload)+len(mac))%blockSize
	bodyLen := len(payload) + len(mac) + paddingLen

	// header || IV || body, allocated once at its final size.
	out := make([]byte, headerSize+blockSize+bodyLen)
	copy(out, raw[:headerSize])

	iv := out[headerSize : headerSize+blockSize]
	if _, err := rand.Read(iv); err != nil {
		return nil, err
	}

	body := out[headerSize+blockSize:]
	n := copy(body, payload)
	n += copy(body[n:], mac)
	for i := n; i < len(body); i++ {
		body[i] = byte(paddingLen - 1)
	}

	c.writeCBC.SetIV(iv)
	c.writeCBC.CryptBlocks(body, body)

	// The record length now covers IV, ciphertext, MAC and padding.
	binary.BigEndian.PutUint16(out[headerSize-2:], uint16(len(out)-headerSize))
	return out, nil
}
// Decrypt verifies and decrypts an AES-CBC protected DTLS record.
//
// header is unmarshaled from in, and its Size determines where the body
// starts. ChangeCipherSpec records are returned unchanged. For other records
// the body must be whole blocks: an explicit IV block followed by ciphertext
// with room for at least the MAC and one padding byte. Otherwise
// errNotEnoughRoomForNonce is returned. After decryption, bad padding, an
// inconsistent padding/MAC layout or a MAC mismatch all yield errInvalidMAC.
// On success the returned slice is the header followed by the plaintext; it
// reuses in's backing array, which is also decrypted in place.
func (c *CBC) Decrypt(header recordlayer.Header, in []byte) ([]byte, error) {
	blockSize := c.readCBC.BlockSize()
	mac := c.h()

	if err := header.Unmarshal(in); err != nil {
		return nil, err
	}
	body := in[header.Size():]

	switch {
	case header.ContentType == protocol.ContentTypeChangeCipherSpec:
		// Nothing to decrypt for ChangeCipherSpec.
		return in, nil
	case len(body)%blockSize != 0 || len(body) < blockSize+util.Max(mac.Size()+1, blockSize):
		return nil, errNotEnoughRoomForNonce
	}

	// The first block is the per-record explicit IV; decrypt the rest in place.
	c.readCBC.SetIV(body[:blockSize])
	body = body[blockSize:]
	c.readCBC.CryptBlocks(body, body)

	// NOTE(review): rejecting bad padding before the MAC is computed makes this
	// failure return faster than a MAC mismatch (a padding-oracle timing
	// signal); Go's crypto/tls folds both checks into one decision.
	paddingLen, paddingGood := examinePadding(body)
	if paddingGood != 255 {
		return nil, errInvalidMAC
	}

	macSize := mac.Size()
	if len(body) < macSize {
		return nil, errInvalidMAC
	}

	// The length check above only guarantees room for the MAC plus one
	// padding byte. A longer (yet well-formed) padding run can overlap the MAC
	// region, which would make dataEnd negative and panic the slice below on
	// attacker-influenced input.
	dataEnd := len(body) - macSize - paddingLen
	if dataEnd < 0 {
		return nil, errInvalidMAC
	}
	expectedMAC := body[dataEnd : dataEnd+macSize]

	var err error
	var actualMAC []byte
	if header.ContentType == protocol.ContentTypeConnectionID {
		actualMAC, err = c.hmacCID(
			header.Epoch, header.SequenceNumber, header.Version, body[:dataEnd], c.readMac, c.h, header.ConnectionID,
		)
	} else {
		actualMAC, err = c.hmac(
			header.Epoch, header.SequenceNumber, header.ContentType, header.Version, body[:dataEnd], c.readMac, c.h,
		)
	}
	if err != nil || !hmac.Equal(actualMAC, expectedMAC) {
		return nil, errInvalidMAC
	}

	// Slide the plaintext down over the IV so it directly follows the header.
	return append(in[:header.Size()], body[:dataEnd]...), nil
}
// hmac computes the MAC for a record without a connection ID: an HMAC keyed
// with key over hash hf, covering a 13-byte pseudo-header
//
//	epoch(2) || sequence_number(6) || content_type(1) || version(2) || length(2)
//
// followed by payload, where length is len(payload).
func (c *CBC) hmac(
	epoch uint16,
	sequenceNumber uint64,
	contentType protocol.ContentType,
	protocolVersion protocol.Version,
	payload []byte,
	key []byte,
	hf func() hash.Hash,
) ([]byte, error) {
	var pseudoHeader [13]byte
	binary.BigEndian.PutUint16(pseudoHeader[:], epoch)
	util.PutBigEndianUint48(pseudoHeader[2:], sequenceNumber)
	pseudoHeader[8] = byte(contentType)
	pseudoHeader[9], pseudoHeader[10] = protocolVersion.Major, protocolVersion.Minor
	binary.BigEndian.PutUint16(pseudoHeader[11:], uint16(len(payload)))

	mac := hmac.New(hf, key)
	for _, part := range [][]byte{pseudoHeader[:], payload} {
		if _, err := mac.Write(part); err != nil {
			return nil, err
		}
	}
	return mac.Sum(nil), nil
}
// hmacCID computes the MAC for a connection-ID (tls12_cid) record following
// the block-cipher MAC layout of RFC 9146, Section 5.1:
//
//	MAC(key, seq_num_placeholder + tls12_cid + cid_length + tls12_cid +
//	    version + epoch + sequence_number + cid +
//	    length_of_DTLSInnerPlaintext + DTLSInnerPlaintext)
//
// payload is the serialized DTLSInnerPlaintext (content || real_type ||
// zeros); it is parsed first and the parse error, if any, is returned.
//
// The inner plaintext is MACed exactly once. An earlier version appended the
// parsed content/real_type/zeros to the pseudo-header and then wrote payload
// again, MACing the same bytes twice. That deviates from the RFC, so its MACs
// will not verify against this one (CID + CBC with such peers breaks).
func (c *CBC) hmacCID(
	epoch uint16,
	sequenceNumber uint64,
	protocolVersion protocol.Version,
	payload []byte,
	key []byte,
	hf func() hash.Hash,
	cid []byte,
) ([]byte, error) {
	// Parsed only for validation: malformed inner plaintexts keep failing
	// with the same error.
	if err := (&recordlayer.InnerPlaintext{}).Unmarshal(payload); err != nil {
		return nil, err
	}

	var msg cryptobyte.Builder
	msg.AddUint64(seqNumPlaceholder)
	msg.AddUint8(uint8(protocol.ContentTypeConnectionID))
	msg.AddUint8(uint8(len(cid)))
	msg.AddUint8(uint8(protocol.ContentTypeConnectionID))
	msg.AddUint8(protocolVersion.Major)
	msg.AddUint8(protocolVersion.Minor)
	msg.AddUint16(epoch)
	util.AddUint48(&msg, sequenceNumber)
	msg.AddBytes(cid)
	msg.AddUint16(uint16(len(payload)))

	mac := hmac.New(hf, key)
	if _, err := mac.Write(msg.BytesOrPanic()); err != nil {
		return nil, err
	}
	// payload already is content || real_type || zeros.
	if _, err := mac.Write(payload); err != nil {
		return nil, err
	}
	return mac.Sum(nil), nil
}
The pages are generated with Golds v0.8.2 (GOOS=linux GOARCH=amd64).
Golds is a Go 101 project developed by Tapir Liu.
PRs and bug reports are welcome and can be submitted to the issue list.
Please follow @zigo_101 (reachable from the left QR code) to get the latest news of Golds.