// SPDX-FileCopyrightText: 2023 The Pion community <https://pion.ly>
// SPDX-License-Identifier: MIT

package turn

import (
	
	
	
	

	
	
)

// Sentinel errors produced while splitting a byte stream into TURN frames.
var (
	// errIncompleteTURNFrame reports that the data seen so far does not yet
	// contain a whole STUN or TURN frame.
	errIncompleteTURNFrame = errors.New("data contains incomplete STUN or TURN frame")

	// errInvalidTURNFrame reports that the data is neither a STUN message
	// nor ChannelData.
	errInvalidTURNFrame = errors.New("data is not a valid TURN frame, no STUN or ChannelData found")
)

// STUNConn wraps a net.Conn and implements
// net.PacketConn by being STUN aware and
// packetizing the stream.
//
// Fields:
//   - nextConn is the wrapped stream connection; the methods in this file
//     delegate reads, writes, close, address and deadline calls to it.
//   - buff accumulates bytes read from nextConn and is the slice handed to
//     consumeSingleTURNFrame when looking for the next frame.
type STUNConn struct {
	nextConn net.Conn
	buff     []byte
}

// Sizes, in bytes, used when measuring STUN and ChannelData frames.
const (
	// stunHeaderSize is added to the length field of a STUN message to get
	// the size of the whole message.
	stunHeaderSize = 20

	// channelDataLengthSize is the width of the ChannelData length field.
	channelDataLengthSize = 2

	// channelDataNumberSize is the width of the ChannelData channel number
	// field; it is the same width as the length field.
	channelDataNumberSize = 2

	// channelDataHeaderSize is the combined width of the channel number and
	// length fields.
	channelDataHeaderSize = channelDataNumberSize + channelDataLengthSize

	// channelDataPadding is the boundary to which a ChannelData payload
	// length is rounded up.
	channelDataPadding = 4
)

// consumeSingleTURNFrame looks at the start of a buffer and reports the size,
// in bytes, of the single TURN frame (STUN message or ChannelData) found there.
//
// It returns:
//   - (size, nil) when the buffer covers the whole computed frame size;
//   - (0, errIncompleteTURNFrame) when there are too few bytes to classify
//     the data, or fewer bytes than the computed frame size;
//   - (0, errInvalidTURNFrame) when the data is not recognised as a STUN
//     message, does not start with a valid channel number, and is at least
//     stunHeaderSize bytes long.
//
// NOTE(review): the parameter name, the local variable names and every use of
// them are missing from this listing (e.g. `len()`, `var  uint16`), so the
// function does not compile as shown. The comments below assume the blank
// operands are the input buffer and the uint16 size accumulator — confirm
// when restoring the identifiers.
func consumeSingleTURNFrame( []byte) (int, error) {
	// Too short to determine if ChannelData or STUN
	if len() < 9 {
		return 0, errIncompleteTURNFrame
	}

	// Frame size accumulator, declared as a 16-bit unsigned integer.
	var  uint16
	switch {
	case stun.IsMessage():
		// STUN: big-endian 16-bit value at bytes [2:4] plus stunHeaderSize.
		// NOTE(review): Uint16 returns a uint16 and stunHeaderSize is an
		// untyped constant, so this sum is computed in 16 bits and wraps
		// for values above 65515 — confirm whether that can occur here.
		 = binary.BigEndian.Uint16([2:4]) + stunHeaderSize
	case proto.ChannelNumber(binary.BigEndian.Uint16([0:2])).Valid():
		// ChannelData: bytes [0:2] hold the channel number, and the length
		// is read from bytes [channelDataNumberSize:channelDataHeaderSize],
		// i.e. [2:4].
		 = binary.BigEndian.Uint16([channelDataNumberSize:channelDataHeaderSize])
		// Round up to the next multiple of channelDataPadding when the
		// remainder is non-zero. This arithmetic is also 16-bit.
		if  := ( + channelDataPadding) % channelDataPadding;  != 0 {
			 = ( + channelDataPadding) - 
		}

		// Account for the channel number and length fields themselves.
		 += channelDataHeaderSize
	case len() < stunHeaderSize:
		// Neither case matched, but fewer than stunHeaderSize bytes are
		// present, so this is reported as incomplete rather than invalid.
		return 0, errIncompleteTURNFrame
	default:
		return 0, errInvalidTURNFrame
	}

	// The frame type is known but not all of its bytes are available yet.
	if len() < int() {
		return 0, errIncompleteTURNFrame
	}

	return int(), nil
}

// ReadFrom implements ReadFrom from net.PacketConn.
//
// Visible flow:
//  1. consumeSingleTURNFrame is run on the buffered bytes (buff). If it
//     reports errInvalidTURNFrame, that error is returned.
//  2. If it succeeds, data from buff is copied out, buff is re-sliced, and
//     the call returns with nextConn.RemoteAddr() as the address.
//  3. Otherwise one Read is issued on nextConn. A read error is returned as
//     (0, nil, err). On success a copy of the bytes is appended to buff and
//     the method ends with a call through the receiver.
//
// NOTE(review): the receiver, method, parameter and result names and every
// use of them are missing from this listing, so the method does not compile
// as shown. The slice bounds in steps 2 and 3 and the final `return .()` are
// blank; presumably the consumed frame is dropped from buff and the method
// calls itself again — confirm when restoring the identifiers.
func ( *STUNConn) ( []byte) ( int,  net.Addr,  error) {
	// First pass any buffered data from previous reads
	,  = consumeSingleTURNFrame(.buff)
	if errors.Is(, errInvalidTURNFrame) {
		return 0, nil, 
	} else if  == nil {
		// NOTE(review): copy transfers at most len(dst) bytes; if the
		// destination is the caller's slice and is shorter than the frame,
		// the remainder would be lost — confirm the intended contract.
		copy(, .buff[:])
		.buff = .buff[:]

		return , .nextConn.RemoteAddr(), nil
	}

	// Then read from the nextConn, appending to our buff
	,  = .nextConn.Read()
	if  != nil {
		return 0, nil, 
	}

	// The inner append onto an empty slice makes a fresh copy of the bytes
	// before they are added to buff.
	.buff = append(.buff, append([]byte{}, [:]...)...)

	return .()
}

// WriteTo implements WriteTo from net.PacketConn.
//
// The call is forwarded to nextConn.Write and its results are returned
// unchanged. nextConn is a net.Conn, whose Write takes only the payload, so
// the net.Addr argument cannot be forwarded to it.
//
// NOTE(review): receiver, method, parameter and result names are missing from
// this listing, so it does not compile as shown.
func ( *STUNConn) ( []byte,  net.Addr) ( int,  error) {
	return .nextConn.Write()
}

// Close implements Close from net.PacketConn.
//
// It returns the result of nextConn.Close.
//
// NOTE(review): the receiver and method names are missing from this listing,
// so it does not compile as shown.
func ( *STUNConn) () error {
	return .nextConn.Close()
}

// LocalAddr implements LocalAddr from net.PacketConn.
//
// It returns the result of nextConn.LocalAddr.
//
// NOTE(review): the receiver and method names are missing from this listing,
// so it does not compile as shown.
func ( *STUNConn) () net.Addr {
	return .nextConn.LocalAddr()
}

// SetDeadline implements SetDeadline from net.PacketConn.
//
// It returns the result of nextConn.SetDeadline.
//
// NOTE(review): the receiver, method and parameter names are missing from
// this listing, so it does not compile as shown; presumably the time.Time
// argument is passed through — confirm when restoring the identifiers.
func ( *STUNConn) ( time.Time) error {
	return .nextConn.SetDeadline()
}

// SetReadDeadline implements SetReadDeadline from net.PacketConn.
//
// It returns the result of nextConn.SetReadDeadline.
//
// NOTE(review): the receiver, method and parameter names are missing from
// this listing, so it does not compile as shown; presumably the time.Time
// argument is passed through — confirm when restoring the identifiers.
func ( *STUNConn) ( time.Time) error {
	return .nextConn.SetReadDeadline()
}

// SetWriteDeadline implements SetWriteDeadline from net.PacketConn.
//
// It returns the result of nextConn.SetWriteDeadline.
//
// NOTE(review): the receiver, method and parameter names are missing from
// this listing, so it does not compile as shown; presumably the time.Time
// argument is passed through — confirm when restoring the identifiers.
func ( *STUNConn) ( time.Time) error {
	return .nextConn.SetWriteDeadline()
}

// NewSTUNConn creates a STUNConn.
//
// It takes a net.Conn and returns a pointer to a new STUNConn whose nextConn
// field is set by the composite literal; buff is left at its zero value.
//
// NOTE(review): the function and parameter names are missing from this
// listing, so it does not compile as shown; presumably the net.Conn argument
// is what gets stored in nextConn — confirm when restoring the identifiers.
func ( net.Conn) *STUNConn {
	return &STUNConn{nextConn: }
}