// SPDX-FileCopyrightText: 2023 The Pion community <https://pion.ly>
// SPDX-License-Identifier: MIT

package stun

import (
	"crypto/rand"
	"encoding/base64"
	"errors"
	"fmt"
	"io"
)

const (
	// magicCookie is the fixed value that aids in distinguishing STUN packets
	// from packets of other protocols when STUN is multiplexed with those
	// other protocols on the same Port.
	//
	// The magic cookie field MUST contain the fixed value 0x2112A442 in
	// network byte order.
	//
	// Defined in "STUN Message Structure", section 6.
	magicCookie         = 0x2112A442
	attributeHeaderSize = 4  // bytes before an attribute value: type (2) + length (2)
	messageHeaderSize   = 20 // bytes in the message header: type (2) + length (2) + cookie (4) + transaction ID (12)

	// TransactionIDSize is the length of the transaction ID array (in bytes).
	TransactionIDSize = 12 // 96 bit
)

// NewTransactionID returns new random transaction ID using crypto/rand
// as source.
func () ( [TransactionIDSize]byte) {
	readFullOrPanic(rand.Reader, [:])
	return 
}

// IsMessage returns true if b looks like STUN message.
// Useful for multiplexing. IsMessage does not guarantee
// that decoding will be successful.
func ( []byte) bool {
	return len() >= messageHeaderSize && bin.Uint32([4:8]) == magicCookie
}

// New returns *Message with pre-allocated Raw.
func () *Message {
	const  = 120
	return &Message{
		Raw: make([]byte, messageHeaderSize, ),
	}
}

// ErrDecodeToNil is returned by Decode when the destination message is nil,
// i.e. on a Decode(data, nil) call.
var ErrDecodeToNil = errors.New("attempt to decode to nil message")

// Decode decodes Message from data to m, returning error if any.
func ( []byte,  *Message) error {
	if  == nil {
		return ErrDecodeToNil
	}
	.Raw = append(.Raw[:0], ...)
	return .Decode()
}

// Message represents a single STUN packet. It uses aggressive internal
// buffering to enable zero-allocation encoding and decoding,
// so there are some usage constraints:
//
//	Message, its fields, results of m.Get or any attribute a.GetFrom
//	remain valid only as long as Message.Raw is not modified.
//
// Raw holds the encoded packet (header followed by attributes). Type, Length
// and TransactionID mirror the header fields, and Attributes holds the
// attributes whose values are sub-slices of Raw.
type Message struct {
	Type          MessageType
	Length        uint32 // len(Raw) not including header
	TransactionID [TransactionIDSize]byte
	Attributes    Attributes
	Raw           []byte
}

// MarshalBinary implements the encoding.BinaryMarshaler interface.
func ( Message) () ( []byte,  error) {
	// We can't return m.Raw, allocation is expected by implicit interface
	// contract induced by other implementations.
	 := make([]byte, len(.Raw))
	copy(, .Raw)
	return , nil
}

// UnmarshalBinary implements the encoding.BinaryUnmarshaler interface.
func ( *Message) ( []byte) error {
	// We can't retain data, copy is expected by interface contract.
	.Raw = append(.Raw[:0], ...)
	return .Decode()
}

// GobEncode implements the gob.GobEncoder interface.
func ( Message) () ([]byte, error) {
	return .MarshalBinary()
}

// GobDecode implements the gob.GobDecoder interface.
func ( *Message) ( []byte) error {
	return .UnmarshalBinary()
}

// AddTo sets b.TransactionID to m.TransactionID.
//
// Implements Setter to aid in crafting responses.
func ( *Message) ( *Message) error {
	.TransactionID = .TransactionID
	.WriteTransactionID()
	return nil
}

// NewTransactionID sets m.TransactionID to random value from crypto/rand
// and returns error if any.
func ( *Message) () error {
	,  := io.ReadFull(rand.Reader, .TransactionID[:])
	if  == nil {
		.WriteTransactionID()
	}
	return 
}

func ( *Message) () string {
	 := base64.StdEncoding.EncodeToString(.TransactionID[:])
	 := ""
	for ,  := range .Attributes {
		 += fmt.Sprintf("attr%d=%s ", , .Type)
	}
	return fmt.Sprintf("%s l=%d attrs=%d id=%s, %s", .Type, .Length, len(.Attributes), , )
}

// Reset resets Message, attributes and underlying buffer length.
func ( *Message) () {
	.Raw = .Raw[:0]
	.Length = 0
	.Attributes = .Attributes[:0]
}

// grow ensures that internal buffer has n length.
func ( *Message) ( int) {
	if len(.Raw) >=  {
		return
	}
	if cap(.Raw) >=  {
		.Raw = .Raw[:]
		return
	}
	.Raw = append(.Raw, make([]byte, -len(.Raw))...)
}

// Add appends new attribute to message. Not goroutine-safe.
//
// Value of attribute is copied to internal buffer so
// it is safe to reuse v.
func ( *Message) ( AttrType,  []byte) {
	// Allocating buffer for TLV (type-length-value).
	// T = t, L = len(v), V = v.
	// m.Raw will look like:
	// [0:20]                               <- message header
	// [20:20+m.Length]                     <- existing message attributes
	// [20+m.Length:20+m.Length+len(v) + 4] <- allocated buffer for new TLV
	// [first:last]                         <- same as previous
	// [0 1|2 3|4    4 + len(v)]            <- mapping for allocated buffer
	//   T   L        V
	 := attributeHeaderSize + len()  // ~ len(TLV) = len(TL) + len(V)
	 := messageHeaderSize + int(.Length) // first byte number
	 :=  +                   // last byte number
	.grow()                               // growing cap(Raw) to fit TLV
	.Raw = .Raw[:]                       // now len(Raw) = last
	.Length += uint32()              // rendering length change

	// Sub-slicing internal buffer to simplify encoding.
	 := .Raw[:]           // slice for TLV
	 := [attributeHeaderSize:] // slice for V
	 := RawAttribute{
		Type:   ,              // T
		Length: uint16(len()), // L
		Value:  ,          // V
	}

	// Encoding attribute TLV to allocated buffer.
	bin.PutUint16([0:2], .Type.Value()) // T
	bin.PutUint16([2:4], .Length)       // L
	copy(, )                             // V

	// Checking that attribute value needs padding.
	if .Length%padding != 0 {
		// Performing padding.
		 := nearestPaddedValueLength(len()) - len()
		 += 
		.grow()
		// setting all padding bytes to zero
		// to prevent data leak from previous
		// data in next bytesToAdd bytes
		 = .Raw[- : ]
		for  := range  {
			[] = 0
		}
		.Raw = .Raw[:]           // increasing buffer length
		.Length += uint32() // rendering length change
	}
	.Attributes = append(.Attributes, )
	.WriteLength()
}

func attrSliceEqual(,  Attributes) bool {
	for ,  := range  {
		 := false
		for ,  := range  {
			if .Type != .Type {
				continue
			}
			if .Equal() {
				 = true
				break
			}
		}
		if ! {
			return false
		}
	}
	return true
}

func attrEqual(,  Attributes) bool {
	if  == nil &&  == nil {
		return true
	}
	if  == nil ||  == nil {
		return false
	}
	if len() != len() {
		return false
	}
	if !attrSliceEqual(, ) {
		return false
	}
	if !attrSliceEqual(, ) {
		return false
	}
	return true
}

// Equal returns true if Message b equals to m.
// Ignores m.Raw.
func ( *Message) ( *Message) bool {
	if  == nil &&  == nil {
		return true
	}
	if  == nil ||  == nil {
		return false
	}
	if .Type != .Type {
		return false
	}
	if .TransactionID != .TransactionID {
		return false
	}
	if .Length != .Length {
		return false
	}
	if !attrEqual(.Attributes, .Attributes) {
		return false
	}
	return true
}

// WriteLength writes m.Length to m.Raw.
func ( *Message) () {
	.grow(4)
	bin.PutUint16(.Raw[2:4], uint16(.Length))
}

// WriteHeader writes header to underlying buffer. Not goroutine-safe.
func ( *Message) () {
	.grow(messageHeaderSize)
	_ = .Raw[:messageHeaderSize] // early bounds check to guarantee safety of writes below

	.WriteType()
	.WriteLength()
	bin.PutUint32(.Raw[4:8], magicCookie)               // magic cookie
	copy(.Raw[8:messageHeaderSize], .TransactionID[:]) // transaction ID
}

// WriteTransactionID writes m.TransactionID to m.Raw.
func ( *Message) () {
	copy(.Raw[8:messageHeaderSize], .TransactionID[:]) // transaction ID
}

// WriteAttributes encodes all m.Attributes to m.
func ( *Message) () {
	 := .Attributes
	.Attributes = [:0]
	for ,  := range  {
		.Add(.Type, .Value)
	}
	.Attributes = 
}

// WriteType writes m.Type to m.Raw.
func ( *Message) () {
	.grow(2)
	bin.PutUint16(.Raw[0:2], .Type.Value()) // message type
}

// SetType sets m.Type and writes it to m.Raw.
func ( *Message) ( MessageType) {
	.Type = 
	.WriteType()
}

// Encode re-encodes message into m.Raw.
func ( *Message) () {
	.Raw = .Raw[:0]
	.WriteHeader()
	.Length = 0
	.WriteAttributes()
}

// WriteTo implements WriterTo via calling Write(m.Raw) on w and returning
// call result.
func ( *Message) ( io.Writer) (int64, error) {
	,  := .Write(.Raw)
	return int64(), 
}

// ReadFrom implements ReaderFrom. Reads message from r into m.Raw,
// Decodes it and return error if any. If m.Raw is too small, will return
// ErrUnexpectedEOF, ErrUnexpectedHeaderEOF or *DecodeErr.
//
// Can return *DecodeErr while decoding too.
func ( *Message) ( io.Reader) (int64, error) {
	 := .Raw[:cap(.Raw)]
	var (
		   int
		 error
	)
	if ,  = .Read();  != nil {
		return int64(), 
	}
	.Raw = [:]
	return int64(), .Decode()
}

// ErrUnexpectedHeaderEOF means that there were not enough bytes in
// m.Raw to read the message header.
var ErrUnexpectedHeaderEOF = errors.New("unexpected EOF: not enough bytes to read header")

// Decode decodes m.Raw into m.
func ( *Message) () error {
	// decoding message header
	 := .Raw
	if len() < messageHeaderSize {
		return ErrUnexpectedHeaderEOF
	}
	var (
		        = bin.Uint16([0:2])      // first 2 bytes
		     = int(bin.Uint16([2:4])) // second 2 bytes
		   = bin.Uint32([4:8])      // last 4 bytes
		 = messageHeaderSize +   // len(m.Raw)
	)
	if  != magicCookie {
		 := fmt.Sprintf("%x is invalid magic cookie (should be %x)", , magicCookie)
		return newDecodeErr("message", "cookie", )
	}
	if len() <  {
		 := fmt.Sprintf("buffer length %d is less than %d (expected message size)", len(), )
		return newAttrDecodeErr("message", )
	}
	// saving header data
	.Type.ReadValue()
	.Length = uint32()
	copy(.TransactionID[:], [8:messageHeaderSize])

	.Attributes = .Attributes[:0]
	var (
		 = 0
		      = [messageHeaderSize:]
	)
	for  <  {
		// checking that we have enough bytes to read header
		if len() < attributeHeaderSize {
			 := fmt.Sprintf("buffer length %d is less than %d (expected header size)", len(), attributeHeaderSize)
			return newAttrDecodeErr("header", )
		}
		var (
			 = RawAttribute{
				Type:   compatAttrType(bin.Uint16([0:2])), // first 2 bytes
				Length: bin.Uint16([2:4]),                 // second 2 bytes
			}
			     = int(.Length)                // attribute length
			 = nearestPaddedValueLength() // expected buffer length (with padding)
		)
		 = [attributeHeaderSize:] // slicing again to simplify value read
		 += attributeHeaderSize
		if len() <  { // checking size
			 := fmt.Sprintf("buffer length %d is less than %d (expected value size for %s)", len(), , .Type)
			return newAttrDecodeErr("value", )
		}
		.Value = [:]
		 += 
		 = [:]

		.Attributes = append(.Attributes, )
	}
	return nil
}

// Write decodes message and return error if any.
//
// Any error is unrecoverable, but message could be partially decoded.
func ( *Message) ( []byte) (int, error) {
	.Raw = append(.Raw[:0], ...)
	return len(), .Decode()
}

// CloneTo clones m to b securing any further m mutations.
func ( *Message) ( *Message) error {
	.Raw = append(.Raw[:0], .Raw...)
	return .Decode()
}

// MessageClass is an 8-bit representation of the 2-bit class of a STUN
// message.
type MessageClass byte

// Possible values for message class in STUN Message Type.
const (
	ClassRequest         MessageClass = 0x00 // 0b00
	ClassIndication      MessageClass = 0x01 // 0b01
	ClassSuccessResponse MessageClass = 0x02 // 0b10
	ClassErrorResponse   MessageClass = 0x03 // 0b11
)

// Common STUN message types.
var (
	// BindingRequest is the Binding request message type.
	BindingRequest = NewType(MethodBinding, ClassRequest) //nolint:gochecknoglobals
	// BindingSuccess is the Binding success response message type.
	BindingSuccess = NewType(MethodBinding, ClassSuccessResponse) //nolint:gochecknoglobals
	// BindingError is the Binding error response message type.
	BindingError = NewType(MethodBinding, ClassErrorResponse) //nolint:gochecknoglobals
)

func ( MessageClass) () string {
	switch  {
	case ClassRequest:
		return "request"
	case ClassIndication:
		return "indication"
	case ClassSuccessResponse:
		return "success response"
	case ClassErrorResponse:
		return "error response"
	default:
		panic("unknown message class") //nolint
	}
}

// Method is a uint16 representation of the 12-bit STUN method.
type Method uint16

// Possible methods for STUN Message.
const (
	MethodBinding          Method = 0x001
	MethodAllocate         Method = 0x003
	MethodRefresh          Method = 0x004
	MethodSend             Method = 0x006
	MethodData             Method = 0x007
	MethodCreatePermission Method = 0x008
	MethodChannelBind      Method = 0x009
)

// Methods from RFC 6062.
const (
	MethodConnect           Method = 0x000a
	MethodConnectionBind    Method = 0x000b
	MethodConnectionAttempt Method = 0x000c
)

// methodName returns the table that maps each known Method to its
// human-readable name. A new map is built on every call.
func methodName() map[Method]string {
	return map[Method]string{
		MethodBinding:          "Binding",
		MethodAllocate:         "Allocate",
		MethodRefresh:          "Refresh",
		MethodSend:             "Send",
		MethodData:             "Data",
		MethodCreatePermission: "CreatePermission",
		MethodChannelBind:      "ChannelBind",

		// RFC 6062.
		MethodConnect:           "Connect",
		MethodConnectionBind:    "ConnectionBind",
		MethodConnectionAttempt: "ConnectionAttempt",
	}
}

func ( Method) () string {
	,  := methodName()[]
	if ! {
		// Falling back to hex representation.
		 = fmt.Sprintf("0x%x", uint16())
	}
	return 
}

// MessageType is the STUN Message Type field, split into its method and
// class components.
type MessageType struct {
	Method Method       // e.g. binding
	Class  MessageClass // e.g. request
}

// AddTo sets m type to t.
func ( MessageType) ( *Message) error {
	.SetType()
	return nil
}

// NewType returns new message type with provided method and class.
func ( Method,  MessageClass) MessageType {
	return MessageType{
		Method: ,
		Class:  ,
	}
}

// Masks and shifts used to interleave the method bits with the two class
// bits when encoding and decoding the message type.
const (
	// Masks selecting the three groups of method bits.
	methodABits = 0xf   // 0b0000000000001111
	methodBBits = 0x70  // 0b0000000001110000
	methodDBits = 0xf80 // 0b0000111110000000

	// Shifts that move groups B and D to make room for the class bits.
	methodBShift = 1
	methodDShift = 2

	firstBit  = 0x1
	secondBit = 0x2

	// Masks selecting the low (C0) and high (C1) bit of the class.
	c0Bit = firstBit
	c1Bit = secondBit

	// Shifts that place C0 and C1 at their positions in the encoded type.
	classC0Shift = 4
	classC1Shift = 7
)

// Value returns bit representation of messageType.
func ( MessageType) () uint16 {
	//	 0                 1
	//	 2  3  4 5 6 7 8 9 0 1 2 3 4 5
	//	+--+--+-+-+-+-+-+-+-+-+-+-+-+-+
	//	|M |M |M|M|M|C|M|M|M|C|M|M|M|M|
	//	|11|10|9|8|7|1|6|5|4|0|3|2|1|0|
	//	+--+--+-+-+-+-+-+-+-+-+-+-+-+-+
	// Figure 3: Format of STUN Message Type Field

	// Warning: Abandon all hope ye who enter here.
	// Splitting M into A(M0-M3), B(M4-M6), D(M7-M11).
	 := uint16(.Method)
	 :=  & methodABits // A = M * 0b0000000000001111 (right 4 bits)
	 :=  & methodBBits // B = M * 0b0000000001110000 (3 bits after A)
	 :=  & methodDBits // D = M * 0b0000111110000000 (5 bits after B)

	// Shifting to add "holes" for C0 (at 4 bit) and C1 (8 bit).
	 =  + ( << methodBShift) + ( << methodDShift)

	// C0 is zero bit of C, C1 is first bit.
	// C0 = C * 0b01, C1 = (C * 0b10) >> 1
	// Ct = C0 << 4 + C1 << 8.
	// Optimizations: "((C * 0b10) >> 1) << 8" as "(C * 0b10) << 7"
	// We need C0 shifted by 4, and C1 by 8 to fit "11" and "7" positions
	// (see figure 3).
	 := uint16(.Class)
	 := ( & c0Bit) << classC0Shift
	 := ( & c1Bit) << classC1Shift
	 :=  + 

	return  + 
}

// ReadValue decodes uint16 into MessageType.
func ( *MessageType) ( uint16) {
	// Decoding class.
	// We are taking first bit from v >> 4 and second from v >> 7.
	 := ( >> classC0Shift) & c0Bit
	 := ( >> classC1Shift) & c1Bit
	 :=  + 
	.Class = MessageClass()

	// Decoding method.
	 :=  & methodABits                   // A(M0-M3)
	 := ( >> methodBShift) & methodBBits // B(M4-M6)
	 := ( >> methodDShift) & methodDBits // D(M7-M11)
	 :=  +  + 
	.Method = Method()
}

func ( MessageType) () string {
	return fmt.Sprintf("%s %s", .Method, .Class)
}

// Contains return true if message contain t attribute.
func ( *Message) ( AttrType) bool {
	for ,  := range .Attributes {
		if .Type ==  {
			return true
		}
	}
	return false
}

// transactionIDValueSetter holds a fixed transaction ID value.
type transactionIDValueSetter [TransactionIDSize]byte

// NewTransactionIDSetter returns new Setter that sets message transaction id
// to provided value.
func ( [TransactionIDSize]byte) Setter {
	return transactionIDValueSetter()
}

func ( transactionIDValueSetter) ( *Message) error {
	.TransactionID = 
	.WriteTransactionID()
	return nil
}