package sctp
import (
"encoding/binary"
"errors"
"fmt"
"hash/crc32"
)
var (
	// castagnoliTable is the CRC32c (Castagnoli polynomial) lookup table used
	// to compute SCTP packet checksums.
	castagnoliTable = crc32.MakeTable(crc32.Castagnoli)

	// fourZeroes is fed to the CRC in place of the packet's checksum field,
	// so the checksum is computed as if that field were zero.
	fourZeroes [4]byte
)
// packet is an SCTP packet: the fields of its common header plus the chunks
// that follow it.
type packet struct {
	// Common-header fields (serialized big-endian by marshal/unmarshal).
	sourcePort, destinationPort uint16
	verificationTag             uint32

	// Chunks carried after the header, in the order they appear on the wire.
	chunks []chunk
}
// packetHeaderSize is the size in bytes of the SCTP common header: source
// port (2), destination port (2), verification tag (4) and checksum (4).
const packetHeaderSize = 12
// ErrPacketRawTooSmall is returned (wrapped) when the input is shorter than
// the SCTP common header.
var ErrPacketRawTooSmall = errors.New("raw is smaller than the minimum length for a SCTP packet")

// ErrParseSCTPChunkNotEnoughData is returned (wrapped) when the bytes left
// after the last parsed chunk cannot hold another chunk header.
var ErrParseSCTPChunkNotEnoughData = errors.New("unable to parse SCTP chunk, not enough data for complete header")

// ErrUnmarshalUnknownChunkType is returned (wrapped) when a packet contains a
// chunk type this package does not know how to decode.
var ErrUnmarshalUnknownChunkType = errors.New("failed to unmarshal, contains unknown chunk type")

// ErrChecksumMismatch is returned (wrapped) when the received CRC32c
// checksum does not match the locally computed one. The wrapping message
// appends both values after this text.
var ErrChecksumMismatch = errors.New("checksum mismatch theirs")
// unmarshal parses raw as an SCTP packet (common header followed by chunks)
// into p.
//
// Checksum handling: the little-endian checksum in bytes 8..11 is verified
// whenever it is non-zero or doChecksum is true. Verification is also forced
// when the first chunk is INIT or COOKIE ECHO, regardless of doChecksum. A
// zero checksum is otherwise accepted without verification.
//
// Parsed chunks are appended to p.chunks in wire order. On error, p may be
// left partially populated.
//
// Returned errors wrap ErrPacketRawTooSmall, ErrChecksumMismatch,
// ErrParseSCTPChunkNotEnoughData or ErrUnmarshalUnknownChunkType. An error
// from an individual chunk's unmarshal is returned unwrapped.
func (p *packet) unmarshal(doChecksum bool, raw []byte) error {
	if len(raw) < packetHeaderSize {
		return fmt.Errorf("%w: raw only %d bytes, %d is the minimum length", ErrPacketRawTooSmall, len(raw), packetHeaderSize)
	}

	offset := packetHeaderSize

	// Packets starting with INIT or COOKIE ECHO must always be checksummed,
	// even if the caller would otherwise accept a zero checksum.
	// NOTE(review): presumably mirrors the zero-checksum rules of RFC 9653 — confirm.
	if offset+chunkHeaderSize <= len(raw) {
		switch chunkType(raw[offset]) {
		case ctInit, ctCookieEcho:
			doChecksum = true
		default:
		}
	}

	// The checksum field is read little-endian, matching how marshal writes it.
	theirChecksum := binary.LittleEndian.Uint32(raw[8:])
	if theirChecksum != 0 || doChecksum {
		ourChecksum := generatePacketChecksum(raw)
		if theirChecksum != ourChecksum {
			return fmt.Errorf("%w: %d ours: %d", ErrChecksumMismatch, theirChecksum, ourChecksum)
		}
	}

	p.sourcePort = binary.BigEndian.Uint16(raw[0:])
	p.destinationPort = binary.BigEndian.Uint16(raw[2:])
	p.verificationTag = binary.BigEndian.Uint32(raw[4:])

	// Consume chunks until offset lands exactly on the end of raw.
	for offset != len(raw) {
		// offset may also have overshot len(raw) if a chunk's declared length
		// plus its padding ran past the end of the input. In that case the
		// reported remaining count is negative.
		if offset+chunkHeaderSize > len(raw) {
			return fmt.Errorf("%w: offset %d remaining %d", ErrParseSCTPChunkNotEnoughData, offset, len(raw)-offset)
		}

		var dataChunk chunk
		switch chunkType(raw[offset]) {
		case ctInit:
			dataChunk = &chunkInit{}
		case ctInitAck:
			dataChunk = &chunkInitAck{}
		case ctAbort:
			dataChunk = &chunkAbort{}
		case ctCookieEcho:
			dataChunk = &chunkCookieEcho{}
		case ctCookieAck:
			dataChunk = &chunkCookieAck{}
		case ctHeartbeat:
			dataChunk = &chunkHeartbeat{}
		case ctPayloadData:
			dataChunk = &chunkPayloadData{}
		case ctSack:
			dataChunk = &chunkSelectiveAck{}
		case ctReconfig:
			dataChunk = &chunkReconfig{}
		case ctForwardTSN:
			dataChunk = &chunkForwardTSN{}
		case ctError:
			dataChunk = &chunkError{}
		case ctShutdown:
			dataChunk = &chunkShutdown{}
		case ctShutdownAck:
			dataChunk = &chunkShutdownAck{}
		case ctShutdownComplete:
			dataChunk = &chunkShutdownComplete{}
		default:
			return fmt.Errorf("%w: %s", ErrUnmarshalUnknownChunkType, chunkType(raw[offset]).String())
		}

		if err := dataChunk.unmarshal(raw[offset:]); err != nil {
			return err
		}

		p.chunks = append(p.chunks, dataChunk)

		// Advance past header, value and the padding that aligns the next chunk.
		chunkValuePadding := getPadding(dataChunk.valueLength())
		offset += chunkHeaderSize + dataChunk.valueLength() + chunkValuePadding
	}

	return nil
}
// marshal serializes p into a newly allocated byte slice.
//
// The output starts with the common header: ports and verification tag are
// written big-endian, and the checksum field is left zero. Each chunk's
// encoding is then appended in order and zero-padded to a 4-byte boundary.
// When doChecksum is true, the CRC32c of the finished packet is written
// little-endian into bytes 8..11.
//
// The first error from a chunk's marshal aborts serialization and is
// returned as-is, with a nil slice.
func (p *packet) marshal(doChecksum bool) ([]byte, error) {
	out := make([]byte, packetHeaderSize)
	binary.BigEndian.PutUint16(out[0:2], p.sourcePort)
	binary.BigEndian.PutUint16(out[2:4], p.destinationPort)
	binary.BigEndian.PutUint32(out[4:8], p.verificationTag)

	for _, c := range p.chunks {
		encoded, err := c.marshal()
		if err != nil {
			return nil, err
		}
		out = append(out, encoded...)

		// Pad based on the running length; since the header and every prior
		// chunk are aligned, this aligns the chunk just appended.
		if pad := getPadding(len(out)); pad != 0 {
			out = append(out, make([]byte, pad)...)
		}
	}

	if !doChecksum {
		return out, nil
	}
	binary.LittleEndian.PutUint32(out[8:12], generatePacketChecksum(out))
	return out, nil
}
func generatePacketChecksum(raw []byte ) (sum uint32 ) {
sum = crc32 .Update (sum , castagnoliTable , raw [0 :8 ])
sum = crc32 .Update (sum , castagnoliTable , fourZeroes [:])
sum = crc32 .Update (sum , castagnoliTable , raw [12 :])
return sum
}
// String renders p as multi-line text: the common-header fields followed by
// each chunk's %s form, prefixed with its index in p.chunks.
func (p *packet) String() string {
	format := `Packet:
sourcePort: %d
destinationPort: %d
verificationTag: %d
`
	// Accumulate into a byte slice so the cost stays linear in the number of
	// chunks; repeated string += would recopy the whole result every iteration.
	buf := []byte(fmt.Sprintf(format,
		p.sourcePort,
		p.destinationPort,
		p.verificationTag,
	))
	// Loop variable is c, not chunk, to avoid shadowing the chunk type.
	for i, c := range p.chunks {
		buf = append(buf, fmt.Sprintf("Chunk %d:\n %s", i, c)...)
	}
	return string(buf)
}
// TryMarshalUnmarshal reports whether msg round-trips through the packet
// codec. It returns 1 when msg unmarshals into a packet (without forcing
// checksum verification) and that packet marshals back without error, and
// 0 otherwise. The re-marshaled bytes are discarded.
// NOTE(review): the 1/0 return suggests a go-fuzz style entry point — confirm.
func TryMarshalUnmarshal(msg []byte) int {
	var p packet
	if err := p.unmarshal(false, msg); err != nil {
		return 0
	}
	if _, err := p.marshal(false); err != nil {
		return 0
	}
	return 1
}
These pages were generated with Golds v0.8.2 (GOOS=linux GOARCH=amd64).
Golds is a Go 101 project developed by Tapir Liu.
PRs and bug reports are welcome and can be submitted to the issue list.
Please follow @zigo_101 (reachable from the QR code on the left) to get the latest news about Golds.