// SPDX-FileCopyrightText: 2023 The Pion community <https://pion.ly>
// SPDX-License-Identifier: MIT

package sctp

import (
	"errors"
	"io"
	"sort"
	"sync/atomic"
)

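// sortChunksByTSN sorts chunks in place by TSN, using serial number
// arithmetic (sna32LT) so comparisons stay correct across 32-bit
// TSN wrap-around.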
func sortChunksByTSN(chunks []*chunkPayloadData) {
	sort.Slice(chunks, func(i, j int) bool {
		return sna32LT(chunks[i].tsn, chunks[j].tsn)
	})
}

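// sortChunksBySSN sorts chunk sets in place by SSN, using serial number
// arithmetic (sna16LT) to handle 16-bit SSN wrap-around.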
func sortChunksBySSN(sets []*chunkSet) {
	sort.Slice(sets, func(i, j int) bool {
		return sna16LT(sets[i].ssn, sets[j].ssn)
	})
}

// chunkSet is a set of chunks that share the same SSN.
type chunkSet struct {
	ssn    uint16 // used only with the ordered chunks
	ppi    PayloadProtocolIdentifier
	chunks []*chunkPayloadData
}

func newChunkSet(ssn uint16, ppi PayloadProtocolIdentifier) *chunkSet {
	return &chunkSet{
		ssn:    ssn,
		ppi:    ppi,
		chunks: []*chunkPayloadData{},
	}
}

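// push adds a chunk to the set, ignoring duplicates (same TSN), and
// reports whether the set now forms a complete message.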
func (s *chunkSet) push(chunk *chunkPayloadData) bool {
	// Check for a duplicate (a chunk with the same TSN)
	for _, c := range s.chunks {
		if c.tsn == chunk.tsn {
			return false
		}
	}

	// Append the chunk and keep the set sorted by TSN
	s.chunks = append(s.chunks, chunk)
	sortChunksByTSN(s.chunks)

	// Report whether we now have a complete set
	return s.isComplete()
}

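// isComplete reports whether the set forms a complete user message:
// it must start with a beginning fragment, end with an ending fragment,
// and have strictly sequential TSNs in between.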
func (s *chunkSet) isComplete() bool {
	// Conditions for a complete set:
	//   0. Has at least one chunk.
	//   1. Begins with beginningFragment set to true.
	//   2. Ends with endingFragment set to true.
	//   3. TSN monotonically increases by 1 from beginning to end.

	// 0.
	nChunks := len(s.chunks)
	if nChunks == 0 {
		return false
	}

	// 1.
	if !s.chunks[0].beginningFragment {
		return false
	}

	// 2.
	if !s.chunks[nChunks-1].endingFragment {
		return false
	}

	// 3.
	var lastTSN uint32
	for i, c := range s.chunks {
		if i > 0 {
			// Fragments must have contiguous TSNs.
			// From RFC 4960 Section 3.3.1:
			//   When a user message is fragmented into multiple chunks, the TSNs are
			//   used by the receiver to reassemble the message.  This means that the
			//   TSNs for each fragment of a fragmented user message MUST be strictly
			//   sequential.
			if c.tsn != lastTSN+1 {
				// A middle or end fragment is missing
				return false
			}
		}

		lastTSN = c.tsn
	}

	return true
}

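// reassemblyQueue buffers DATA chunks for a single stream (si) and
// reassembles them into complete user messages. Ordered chunks are
// grouped into chunkSets keyed by SSN; unordered chunks are collected
// in unorderedChunks until a contiguous beginning-to-ending TSN run is
// found. nBytes tracks the total buffered payload size.
//
// A minimal usage sketch (the chunk and buffer values are hypothetical):
//
//	rq := newReassemblyQueue(0) // stream identifier 0
//	if rq.push(chunk) {         // true when a message becomes complete
//		buf := make([]byte, 64*1024)
//		n, ppi, err := rq.read(buf)
//		_, _, _ = n, ppi, err
//	}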
type reassemblyQueue struct {
	si              uint16
	nextSSN         uint16 // expected SSN for next ordered chunk
	ordered         []*chunkSet
	unordered       []*chunkSet
	unorderedChunks []*chunkPayloadData
	nBytes          uint64
}

var errTryAgain = errors.New("try again")

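// newReassemblyQueue creates a reassembly queue for the stream
// identified by si.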
func newReassemblyQueue(si uint16) *reassemblyQueue {
	// From RFC 4960 Sec 6.5:
	//   The Stream Sequence Number in all the streams MUST start from 0 when
	//   the association is established.  Also, when the Stream Sequence
	//   Number reaches the value 65535 the next Stream Sequence Number MUST
	//   be set to 0.
	return &reassemblyQueue{
		si:        si,
		nextSSN:   0,
		ordered:   make([]*chunkSet, 0),
		unordered: make([]*chunkSet, 0),
	}
}

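// push accepts a chunk belonging to this queue's stream and returns
// true if the chunk completes a message (all fragments present).
// Chunks for other streams, and ordered chunks older than nextSSN,
// are rejected.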
func (r *reassemblyQueue) push(chunk *chunkPayloadData) bool { //nolint:cyclop
	var cset *chunkSet

	if chunk.streamIdentifier != r.si {
		return false
	}

	if chunk.unordered {
		// First, insert into the unorderedChunks array
		r.unorderedChunks = append(r.unorderedChunks, chunk)
		atomic.AddUint64(&r.nBytes, uint64(len(chunk.userData)))
		sortChunksByTSN(r.unorderedChunks)

		// Scan unorderedChunks for a complete set of chunks contiguous in TSN
		cset = r.findCompleteUnorderedChunkSet()

		// If found, append the complete set to the unordered array
		if cset != nil {
			r.unordered = append(r.unordered, cset)

			return true
		}

		return false
	}

	// This is an ordered chunk

	if sna16LT(chunk.streamSequenceNumber, r.nextSSN) {
		return false
	}

	// Check if a fragmented chunkSet with the same SSN already exists
	if chunk.isFragmented() {
		for _, set := range r.ordered {
			// nolint:godox
			// TODO: add caution around SSN wrapping here... this helps only a little bit
			// by ensuring we don't add to an unfragmented cset (1 chunk). There's
			// a case where if the SSN does wrap around, we may see the same SSN
			// for a different chunk.

			// nolint:godox
			// TODO: this slice can get pretty big; it may be worth maintaining a map
			// for O(1) lookups at the cost of 2x memory.
			if set.ssn == chunk.streamSequenceNumber && set.chunks[0].isFragmented() {
				cset = set

				break
			}
		}
	}

	// If not found, create a new chunkSet
	if cset == nil {
		cset = newChunkSet(chunk.streamSequenceNumber, chunk.payloadType)
		r.ordered = append(r.ordered, cset)
		if !chunk.unordered {
			sortChunksBySSN(r.ordered)
		}
	}

	atomic.AddUint64(&r.nBytes, uint64(len(chunk.userData)))

	return cset.push(chunk)
}

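// findCompleteUnorderedChunkSet scans the TSN-sorted unorderedChunks
// for a contiguous run from a beginning fragment to an ending fragment.
// If one is found, the chunks are removed from unorderedChunks and
// returned as a new chunkSet; otherwise nil is returned.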
func (r *reassemblyQueue) findCompleteUnorderedChunkSet() *chunkSet {
	startIdx := -1
	nChunks := 0
	var lastTSN uint32
	var found bool

	for i, c := range r.unorderedChunks {
		// Seek the beginning fragment
		if c.beginningFragment {
			startIdx = i
			nChunks = 1
			lastTSN = c.tsn

			if c.endingFragment {
				found = true

				break
			}

			continue
		}

		if startIdx < 0 {
			continue
		}

		// Check if contiguous in TSN
		if c.tsn != lastTSN+1 {
			startIdx = -1

			continue
		}

		lastTSN = c.tsn
		nChunks++

		if c.endingFragment {
			found = true

			break
		}
	}

	if !found {
		return nil
	}

	// Extract the range of chunks
	var chunks []*chunkPayloadData
	chunks = append(chunks, r.unorderedChunks[startIdx:startIdx+nChunks]...)

	r.unorderedChunks = append(
		r.unorderedChunks[:startIdx],
		r.unorderedChunks[startIdx+nChunks:]...)

	cset := newChunkSet(0, chunks[0].payloadType)
	cset.chunks = chunks

	return cset
}

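// isReadable reports whether read would return a complete message:
// either an unordered set is queued, or the first ordered set is
// complete and its SSN is not beyond nextSSN.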
func (r *reassemblyQueue) isReadable() bool {
	// Check unordered first
	if len(r.unordered) > 0 {
		// The chunk sets in r.unordered should all be complete.
		return true
	}

	// Check ordered sets
	if len(r.ordered) > 0 {
		cset := r.ordered[0]
		if cset.isComplete() {
			if sna16LTE(cset.ssn, r.nextSSN) {
				return true
			}
		}
	}

	return false
}

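// read copies the next complete message into buf and returns the number
// of bytes written and the message's PPI. It returns errTryAgain when no
// message is ready, and the full message length with io.ErrShortBuffer
// when buf is too small, in which case the message remains queued.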
func (r *reassemblyQueue) read(buf []byte) (int, PayloadProtocolIdentifier, error) { // nolint: cyclop
	var (
		cset        *chunkSet
		isUnordered bool
		nWritten    int
		err         error
	)

	switch {
	case len(r.unordered) > 0:
		cset = r.unordered[0]
		isUnordered = true
	case len(r.ordered) > 0:
		cset = r.ordered[0]
		if !cset.isComplete() {
			return 0, 0, errTryAgain
		}
		if sna16GT(cset.ssn, r.nextSSN) {
			return 0, 0, errTryAgain
		}
	default:
		return 0, 0, errTryAgain
	}

	// Concatenate all fragments into buf
	for _, c := range cset.chunks {
		if len(buf)-nWritten < len(c.userData) {
			err = io.ErrShortBuffer
		} else {
			copy(buf[nWritten:], c.userData)
		}

		nWritten += len(c.userData)
	}

	switch {
	case err != nil:
		return nWritten, 0, err
	case isUnordered:
		r.unordered = r.unordered[1:]
	default:
		r.ordered = r.ordered[1:]
		if cset.ssn == r.nextSSN {
			r.nextSSN++
		}
	}

	r.subtractNumBytes(nWritten)

	return nWritten, cset.ppi, err
}

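// forwardTSNForOrdered handles a FORWARD TSN for ordered chunks: it
// drops all incomplete ordered sets with SSN at or below lastSSN and
// advances nextSSN past lastSSN.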
func (r *reassemblyQueue) forwardTSNForOrdered(lastSSN uint16) {
	// Use lastSSN to locate each chunkSet, then remove the set if it is
	// not yet complete
	keep := []*chunkSet{}
	for _, set := range r.ordered {
		if sna16LTE(set.ssn, lastSSN) {
			if !set.isComplete() {
				// Drop the set
				for _, c := range set.chunks {
					r.subtractNumBytes(len(c.userData))
				}

				continue
			}
		}
		keep = append(keep, set)
	}
	r.ordered = keep

	// Finally, forward nextSSN
	if sna16LTE(r.nextSSN, lastSSN) {
		r.nextSSN = lastSSN + 1
	}
}

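// forwardTSNForUnordered handles a FORWARD TSN for unordered chunks:
// it discards all queued fragments with TSN at or below
// newCumulativeTSN, since the sender will not retransmit them.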
func (r *reassemblyQueue) forwardTSNForUnordered(newCumulativeTSN uint32) {
	// Remove all fragments in the unordered sets that contain chunks
	// equal to or older than `newCumulativeTSN`.
	// We know all sets in r.unordered are complete.
	// Just remove chunks that are equal to or older than newCumulativeTSN
	// from unorderedChunks.
	lastIdx := -1
	for i, c := range r.unorderedChunks {
		if sna32GT(c.tsn, newCumulativeTSN) {
			break
		}
		lastIdx = i
	}
	if lastIdx >= 0 {
		for _, c := range r.unorderedChunks[0 : lastIdx+1] {
			r.subtractNumBytes(len(c.userData))
		}
		r.unorderedChunks = r.unorderedChunks[lastIdx+1:]
	}
}

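// subtractNumBytes decreases the buffered-bytes counter, clamping at
// zero so it never underflows.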
func (r *reassemblyQueue) subtractNumBytes(nBytes int) {
	cur := atomic.LoadUint64(&r.nBytes)
	if int(cur) >= nBytes { //nolint:gosec // G115
		atomic.AddUint64(&r.nBytes, -uint64(nBytes)) //nolint:gosec // G115
	} else {
		atomic.StoreUint64(&r.nBytes, 0)
	}
}

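// getNumBytes returns the number of payload bytes currently buffered.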
func (r *reassemblyQueue) getNumBytes() int {
	return int(atomic.LoadUint64(&r.nBytes)) //nolint:gosec // G115
}