// SPDX-FileCopyrightText: 2023 The Pion community <https://pion.ly>
// SPDX-License-Identifier: MIT

package dtls

import (
	
	
	
)

// fragmentBufferMaxSize caps the fragment buffer at 2 megabytes
// (2,000,000 bytes). Data already buffered plus a newly offered packet must
// stay strictly below this value, otherwise the packet is rejected with
// errFragmentBufferOverflow.
const fragmentBufferMaxSize = 2 * 1000 * 1000

// fragment is a single handshake fragment together with the headers it was
// parsed with.
type fragment struct {
	// recordLayerHeader is the record layer header unmarshaled from the
	// start of the packet; its ContentType and Epoch are consulted later.
	recordLayerHeader recordlayer.Header
	// handshakeHeader is the handshake header of this fragment (message
	// sequence, total length, fragment offset and fragment length).
	handshakeHeader   handshake.Header
	// data is a private copy of the fragment body with all headers removed.
	data              []byte
}

// fragmentBuffer collects handshake fragments until a complete message for
// the current message sequence number can be rebuilt.
type fragmentBuffer struct {
	// map of MessageSequenceNumbers that hold slices of fragments;
	// fragments are appended in the order they are received.
	cache map[uint16][]*fragment

	// currentMessageSequenceNumber is the message sequence looked up when
	// reassembling; it is incremented after each successful reassembly.
	currentMessageSequenceNumber uint16
}

// newFragmentBuffer returns an empty fragmentBuffer whose cache map is
// already allocated, so fragments can be stored right away. The current
// message sequence number starts at its zero value.
func newFragmentBuffer() *fragmentBuffer {
	buffer := new(fragmentBuffer)
	buffer.cache = make(map[uint16][]*fragment)

	return buffer
}

// Current total size of the buffer in bytes: walks every fragment slice in
// the cache and sums the length of each fragment's data. Headers are not
// counted, only the stored fragment bodies.
//
// NOTE(review): the receiver, method and local identifiers are missing from
// this declaration as it appears here — confirm against the original file.
func ( *fragmentBuffer) () int {
	 := 0
	for  := range .cache {
		for  := range .cache[] {
			 += len(.cache[][].data)
		}
	}
	return 
}

// Attempts to push a DTLS packet to the fragmentBuffer.
//
// When it returns true the fragmentBuffer has stored the packet's handshake
// fragments and the caller shouldn't handle the buffer itself.
// When it returns false with a nil error the record was not a handshake
// record and nothing was stored.
// When an error is returned (buffer overflow or a header that fails to
// unmarshal) it is fatal, and the DTLS connection should be stopped.
//
// NOTE(review): the receiver, method, parameter and local identifiers are
// missing from this declaration as it appears here — confirm against the
// original file.
func ( *fragmentBuffer) ( []byte) (bool, error) {
	// Overflow guard: the bytes already buffered plus the length being added
	// must stay strictly below fragmentBufferMaxSize.
	if .size()+len() >= fragmentBufferMaxSize {
		return false, errFragmentBufferOverflow
	}

	// Parse the record layer header into a fresh fragment.
	 := new(fragment)
	if  := .recordLayerHeader.Unmarshal();  != nil {
		return false, 
	}

	// fragment isn't a handshake, we don't need to handle it
	if .recordLayerHeader.ContentType != protocol.ContentTypeHandshake {
		return false, nil
	}

	// Skip the record layer header, then keep consuming until no bytes are
	// left; the post statement allocates a new fragment for every iteration.
	for  = [recordlayer.HeaderSize:]; len() != 0;  = new(fragment) {
		if  := .handshakeHeader.Unmarshal();  != nil {
			return false, 
		}

		// Make sure the cache has an entry for this message sequence.
		if ,  := .cache[.handshakeHeader.MessageSequence]; ! {
			.cache[.handshakeHeader.MessageSequence] = []*fragment{}
		}

		// end index should be the length of handshake header but if the handshake
		// was fragmented, we should keep them all
		// (the computed index is clamped to the number of bytes available)
		 := int(handshake.HeaderLength + .handshakeHeader.Length)
		if  := len();  >  {
			 = 
		}

		// Discard all headers, when rebuilding the packet we will re-build.
		// Appending to an empty slice copies the bytes, so the stored data
		// does not alias the input.
		.data = append([]byte{}, [handshake.HeaderLength:]...)
		.cache[.handshakeHeader.MessageSequence] = append(.cache[.handshakeHeader.MessageSequence], )
		 = [:]
	}

	return true, nil
}

// Attempts to rebuild the complete handshake message for the current message
// sequence number from the cached fragments.
//
// It returns nil, 0 when no fragments are cached for that sequence number,
// when the cached fragments do not form a complete chain starting at
// fragment offset 0, or when the rebuilt handshake header fails to marshal.
// On success the cache entry is deleted, the current message sequence number
// is incremented, and the second result is the epoch taken from the record
// layer header of the first cached fragment. The first result appears to be
// the marshaled header followed by the reassembled body.
//
// NOTE(review): the receiver, method, result and local identifiers are
// missing from this declaration as it appears here — confirm against the
// original file.
func ( *fragmentBuffer) () ( []byte,  uint16) {
	,  := .cache[.currentMessageSequenceNumber]
	if ! {
		return nil, 0
	}

	// Go doesn't support recursive lambdas: the function variable has to be
	// declared before the closure that calls it is assigned.
	var  func( uint32) bool

	 := []byte{}
	 = func( uint32) bool {
		// Look for a fragment whose offset matches the requested offset.
		for ,  := range  {
			if .handshakeHeader.FragmentOffset ==  {
				 := (.handshakeHeader.FragmentOffset + .handshakeHeader.FragmentLength)
				// This fragment does not reach the end of the message (and is
				// not zero-length), so the fragment that starts where this one
				// ends must be found first.
				if  != .handshakeHeader.Length && .handshakeHeader.FragmentLength != 0 {
					if !() {
						return false
					}
				}

				// The recursion above has already run, so this fragment's
				// data goes in front: the result is built back to front.
				 = append(.data, ...)
				return true
			}
		}
		// No fragment starts at the requested offset: the chain is incomplete.
		return false
	}

	// Recursively collect up
	if !(0) {
		return nil, 0
	}

	// Rebuild a header describing one unfragmented message: offset 0 and a
	// fragment length equal to the full message length.
	 := [0].handshakeHeader
	.FragmentOffset = 0
	.FragmentLength = .Length

	,  := .Marshal()
	if  != nil {
		return nil, 0
	}

	 := [0].recordLayerHeader.Epoch

	// The message is complete: drop its fragments and move on to the next
	// message sequence number.
	delete(.cache, .currentMessageSequenceNumber)
	.currentMessageSequenceNumber++
	return append(, ...), 
}