// SPDX-FileCopyrightText: 2023 The Pion community <https://pion.ly>
// SPDX-License-Identifier: MIT

package sctp

import (
	"encoding/binary"
	"errors"
	"fmt"
	"hash/crc32"
	"strings"
)

// Checksum helpers shared by every packet. Both are set up once at package
// initialization, so computing a checksum neither rebuilds the CRC table nor
// allocates and clears a zero buffer each time.
var (
	// castagnoliTable is the CRC32c (Castagnoli polynomial) lookup table used
	// for the SCTP packet checksum.
	castagnoliTable = crc32.MakeTable(crc32.Castagnoli) // nolint:gochecknoglobals

	// fourZeroes is fed to the CRC in place of the packet's checksum field.
	// It must never be written to.
	fourZeroes [4]byte // nolint:gochecknoglobals
)

/*
packet represents an SCTP packet, as defined in https://tools.ietf.org/html/rfc4960#section-3
An SCTP packet is composed of a common header and chunks.  A chunk
contains either control information or user data.

						SCTP Packet Format
	 0                   1                   2                   3
	 0 1 2 3 4 5 6 7 8 9 0 1 2 3 4 5 6 7 8 9 0 1 2 3 4 5 6 7 8 9 0 1
	+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+
	|                        Common Header                          |
	+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+
	|                          Chunk #1                             |
	+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+
	|                           ...                                 |
	+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+
	|                          Chunk #n                             |
	+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+

					SCTP Common Header Format
	 0                   1                   2                   3
	 0 1 2 3 4 5 6 7 8 9 0 1 2 3 4 5 6 7 8 9 0 1 2 3 4 5 6 7 8 9 0 1
	+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+
	|     Source Port Number        |     Destination Port Number   |
	+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+
	|                      Verification Tag                         |
	+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+
	|                           Checksum                            |
	+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+
*/
type packet struct {
	// Port numbers from the first four bytes of the common header.
	sourcePort, destinationPort uint16

	// verificationTag is header bytes 4-8.
	verificationTag uint32

	// chunks holds the packet's chunks in the order they appear on the wire.
	chunks []chunk
}

// packetHeaderSize is the length in bytes of the SCTP common header: source
// port (2) + destination port (2) + verification tag (4) + checksum (4).
const packetHeaderSize = 12

// SCTP packet errors. unmarshal wraps them with extra context via %w, so
// callers should match them with errors.Is rather than by comparing messages.
var (
	// ErrPacketRawTooSmall means the input is shorter than the SCTP common header.
	ErrPacketRawTooSmall = errors.New("raw is smaller than the minimum length for a SCTP packet")

	// ErrParseSCTPChunkNotEnoughData means the bytes after the last decoded
	// chunk cannot hold another complete chunk header.
	ErrParseSCTPChunkNotEnoughData = errors.New("unable to parse SCTP chunk, not enough data for complete header")

	// ErrUnmarshalUnknownChunkType means a chunk carries a type this package
	// does not decode.
	ErrUnmarshalUnknownChunkType = errors.New("failed to unmarshal, contains unknown chunk type")

	// ErrChecksumMismatch means the stored CRC32c checksum differs from the
	// one computed over the received packet.
	ErrChecksumMismatch = errors.New("checksum mismatch theirs")
)

func ( *packet) ( bool,  []byte) error { //nolint:cyclop
	if len() < packetHeaderSize {
		return fmt.Errorf("%w: raw only %d bytes, %d is the minimum length", ErrPacketRawTooSmall, len(), packetHeaderSize)
	}

	 := packetHeaderSize

	// Check if doing CRC32c is required.
	// Without having SCTP AUTH implemented, this depends only on the type
	// og the first chunk.
	if +chunkHeaderSize <= len() {
		switch chunkType([]) {
		case ctInit, ctCookieEcho:
			 = true
		default:
		}
	}
	 := binary.LittleEndian.Uint32([8:])
	if  != 0 ||  {
		 := generatePacketChecksum()
		if  !=  {
			return fmt.Errorf("%w: %d ours: %d", ErrChecksumMismatch, , )
		}
	}

	.sourcePort = binary.BigEndian.Uint16([0:])
	.destinationPort = binary.BigEndian.Uint16([2:])
	.verificationTag = binary.BigEndian.Uint32([4:])

	for {
		// Exact match, no more chunks
		if  == len() {
			break
		} else if +chunkHeaderSize > len() {
			return fmt.Errorf("%w: offset %d remaining %d", ErrParseSCTPChunkNotEnoughData, , len())
		}

		var  chunk
		switch chunkType([]) {
		case ctInit:
			 = &chunkInit{}
		case ctInitAck:
			 = &chunkInitAck{}
		case ctAbort:
			 = &chunkAbort{}
		case ctCookieEcho:
			 = &chunkCookieEcho{}
		case ctCookieAck:
			 = &chunkCookieAck{}
		case ctHeartbeat:
			 = &chunkHeartbeat{}
		case ctPayloadData:
			 = &chunkPayloadData{}
		case ctSack:
			 = &chunkSelectiveAck{}
		case ctReconfig:
			 = &chunkReconfig{}
		case ctForwardTSN:
			 = &chunkForwardTSN{}
		case ctError:
			 = &chunkError{}
		case ctShutdown:
			 = &chunkShutdown{}
		case ctShutdownAck:
			 = &chunkShutdownAck{}
		case ctShutdownComplete:
			 = &chunkShutdownComplete{}
		default:
			return fmt.Errorf("%w: %s", ErrUnmarshalUnknownChunkType, chunkType([]).String())
		}

		if  := .unmarshal([:]);  != nil {
			return 
		}

		.chunks = append(.chunks, )
		 := getPadding(.valueLength())
		 += chunkHeaderSize + .valueLength() + 
	}

	return nil
}

func ( *packet) ( bool) ([]byte, error) {
	 := make([]byte, packetHeaderSize)

	// Populate static headers
	// 8-12 is Checksum which will be populated when packet is complete
	binary.BigEndian.PutUint16([0:], .sourcePort)
	binary.BigEndian.PutUint16([2:], .destinationPort)
	binary.BigEndian.PutUint32([4:], .verificationTag)

	// Populate chunks
	for ,  := range .chunks {
		,  := .marshal()
		if  != nil {
			return nil, 
		}
		 = append(, ...) //nolint:makezero // todo:fix

		 := getPadding(len())
		if  != 0 {
			 = append(, make([]byte, )...) //nolint:makezero // todo:fix
		}
	}

	if  {
		// golang CRC32C uses reflected input and reflected output, the
		// net result of this is to have the bytes flipped compared to
		// the non reflected variant that the spec expects.
		//
		// Use LittleEndian.PutUint32 to avoid flipping the bytes in to
		// the spec compliant checksum order
		binary.LittleEndian.PutUint32([8:], generatePacketChecksum())
	}

	return , nil
}

func generatePacketChecksum( []byte) ( uint32) {
	// Fastest way to do a crc32 without allocating.
	 = crc32.Update(, castagnoliTable, [0:8])
	 = crc32.Update(, castagnoliTable, fourZeroes[:])
	 = crc32.Update(, castagnoliTable, [12:])

	return 
}

// String makes packet printable.
func ( *packet) () string {
	 := `Packet:
	sourcePort: %d
	destinationPort: %d
	verificationTag: %d
	`
	 := fmt.Sprintf(,
		.sourcePort,
		.destinationPort,
		.verificationTag,
	)
	for ,  := range .chunks {
		 += fmt.Sprintf("Chunk %d:\n %s", , )
	}

	return 
}

// TryMarshalUnmarshal attempts to marshal and unmarshal a message. Added for fuzzing.
func ( []byte) int {
	 := &packet{}
	 := .unmarshal(false, )
	if  != nil {
		return 0
	}

	_,  = .marshal(false)
	if  != nil {
		return 0
	}

	return 1
}