// SPDX-FileCopyrightText: 2023 The Pion community <https://pion.ly>
// SPDX-License-Identifier: MIT

package ice

import (
	
	
	
	
	
	
	

	
	
)

// ErrGetTransportAddress is returned when a net.Addr cannot be converted to
// its underlying concrete type (UDPAddr or TCPAddr).
var (
	ErrGetTransportAddress = errors.New("failed to get local transport address")
)

// TCPMux allows grouping multiple TCP net.Conns and using them like UDP
// net.PacketConns. The main implementation of this is TCPMuxDefault, and this
// interface exists to allow mocking in tests.
type TCPMux interface {
	// Close releases the mux (see io.Closer).
	io.Closer
	// GetConnByUfrag retrieves an existing or creates a new net.PacketConn
	// for the given ufrag, address family and local IP.
	GetConnByUfrag(ufrag string, isIPv6 bool, local net.IP) (net.PacketConn, error)
	// RemoveConnByUfrag closes and removes the net.PacketConn(s) registered
	// under the given ufrag.
	RemoveConnByUfrag(ufrag string)
}

// ipAddr is the string form of a local IP address. It is the key type of the
// inner, per-ufrag connection maps held by TCPMuxDefault.
type ipAddr string

// TCPMuxDefault is the default TCPMux implementation: it muxes TCP net.Conns
// into net.PacketConns and groups them by Ufrag.
type TCPMuxDefault struct {
	// params is the configuration the mux was constructed with.
	params *TCPMuxParams
	// closed is set once Close has been called; GetConnByUfrag then fails.
	closed bool

	// connsIPv4 holds every IPv4 tcpPacketConn, indexed by ufrag and then by
	// local address.
	connsIPv4 map[string]map[ipAddr]*tcpPacketConn
	// connsIPv6 holds every IPv6 tcpPacketConn, indexed by ufrag and then by
	// local address.
	connsIPv6 map[string]map[ipAddr]*tcpPacketConn

	// mu guards closed and both connection maps.
	mu sync.Mutex
	// wg tracks the goroutines started by the mux; Close waits on it.
	wg sync.WaitGroup
}

// TCPMuxParams are parameters for TCPMux.
type TCPMuxParams struct {
	// Listener supplies the incoming TCP connections. Logger receives log
	// output; NewTCPMuxDefault installs a default "ice" logger when it is nil.
	// ReadBufferSize is handed to each tcpPacketConn as its ReadBuffer.
	Listener       net.Listener
	Logger         logging.LeveledLogger
	ReadBufferSize int

	// WriteBufferSize is the maximum buffer size for write operations.
	// 0 means no write buffer: a write blocks until the whole packet is written.
	// When the write buffer is full, subsequent packets are dropped until
	// enough space is available again. A default of 4MB is recommended.
	WriteBufferSize int

	// FirstStunBindTimeout bounds how long a newly established connection may
	// wait for its first STUN binding request. Connections that miss the
	// deadline are removed, so that clients on bad networks or attackers cannot
	// pile up empty connections.
	// A default of 30s is used if not set.
	FirstStunBindTimeout time.Duration

	// AliveDurationForConnFromStun applies to connections that TCPMux creates
	// from a STUN binding request carrying an unknown username: if such a
	// connection is not used within this duration it is removed, to avoid a
	// resource leak / attack.
	// A default of 30s is used if not set.
	AliveDurationForConnFromStun time.Duration
}

// NewTCPMuxDefault creates a new instance of TCPMuxDefault.
//
// Unset parameters are defaulted first: a nil Logger is replaced by a logger
// named "ice" from the default logger factory, and FirstStunBindTimeout and
// AliveDurationForConnFromStun each fall back to 30 seconds when zero. The
// new mux starts with empty IPv4 and IPv6 connection maps, and a goroutine
// registered with its WaitGroup is launched to run start().
//
// NOTE(review): identifier names (function name, parameter, locals) are
// missing from this block as shown; restore them before building.
func ( TCPMuxParams) *TCPMuxDefault {
	if .Logger == nil {
		.Logger = logging.NewDefaultLoggerFactory().NewLogger("ice")
	}

	if .FirstStunBindTimeout == 0 {
		.FirstStunBindTimeout = 30 * time.Second
	}

	if .AliveDurationForConnFromStun == 0 {
		.AliveDurationForConnFromStun = 30 * time.Second
	}

	 := &TCPMuxDefault{
		params: &,

		connsIPv4: map[string]map[ipAddr]*tcpPacketConn{},
		connsIPv6: map[string]map[ipAddr]*tcpPacketConn{},
	}

	// The goroutine is counted in wg so that Close can wait for it.
	.wg.Add(1)
	go func() {
		defer .wg.Done()
		.start()
	}()

	return 
}

// Accept loop of the mux (presumably the start method launched by the
// constructor; the method name is missing here). It logs the listening
// address, then repeatedly accepts connections from the configured Listener.
// Each accepted connection is logged at Debug level and passed on through a
// handleConn call in its own goroutine, registered with the WaitGroup. The
// loop returns on the first Accept error, which is logged at Info level.
//
// NOTE(review): identifier names are missing from this block as shown;
// restore them before building.
func ( *TCPMuxDefault) () {
	.params.Logger.Infof("Listening TCP on %s", .params.Listener.Addr())
	for {
		,  := .params.Listener.Accept()
		if  != nil {
			.params.Logger.Infof("Error accepting connection: %s", )

			return
		}

		.params.Logger.Debugf("Accepted connection from: %s to %s", .RemoteAddr(), .LocalAddr())

		.wg.Add(1)
		go func() {
			defer .wg.Done()
			.handleConn()
		}()
	}
}

// LocalAddr returns the listening address of this TCPMuxDefault, as reported
// by the Addr method of the configured Listener.
//
// NOTE(review): the receiver and method names are missing from this block as
// shown; restore them before building.
func ( *TCPMuxDefault) () net.Addr {
	return .params.Listener.Addr()
}

// GetConnByUfrag retrieves an existing or creates a new net.PacketConn.
//
// The lookup and the creation both run while holding the mux mutex. Once the
// mux has been closed, io.ErrClosedPipe is returned. When getConn finds an
// existing connection, ClearAliveTimer is called on it before it is returned.
// Otherwise a new connection is created through createConn, with false as the
// last argument.
//
// NOTE(review): identifier names are missing from this block as shown;
// restore them before building.
func ( *TCPMuxDefault) ( string,  bool,  net.IP) (net.PacketConn, error) {
	.mu.Lock()
	defer .mu.Unlock()

	if .closed {
		return nil, io.ErrClosedPipe
	}

	if ,  := .getConn(, , );  {
		.ClearAliveTimer()

		return , nil
	}

	return .createConn(, , , false)
}

// Creates and registers a new tcpPacketConn (presumably the createConn method
// called by GetConnByUfrag and by the connection handler; the method name is
// missing here).
//
// The mux's LocalAddr must be a *net.TCPAddr, otherwise
// ErrGetTransportAddress is returned. A copy of that address with its IP
// field overridden becomes the LocalAddr of the new connection. The
// connection is stored in connsIPv6 or connsIPv4 under an outer key and an
// ipAddr key, creating the inner map on first use. A goroutine registered
// with the WaitGroup waits for the connection's CloseChannel and then calls
// removeConnByUfragAndLocalHost.
//
// This method does not take the mutex itself; both visible call sites invoke
// it with the mutex held.
//
// NOTE(review): identifier names are missing from this block as shown;
// restore them before building.
func ( *TCPMuxDefault) ( string,  bool,  net.IP,  bool) (*tcpPacketConn, error) {
	,  := .LocalAddr().(*net.TCPAddr)
	if ! {
		return nil, ErrGetTransportAddress
	}
	 := *
	// Note: this is missing zone for IPv6
	.IP = 

	// The alive duration stays zero unless the guarding bool is set, in which
	// case AliveDurationForConnFromStun is used.
	var  time.Duration
	if  {
		 = .params.AliveDurationForConnFromStun
	}

	 := newTCPPacketConn(tcpPacketParams{
		ReadBuffer:    .params.ReadBufferSize,
		WriteBuffer:   .params.WriteBufferSize,
		LocalAddr:     &,
		Logger:        .params.Logger,
		AliveDuration: ,
	})

	// Pick the inner map for the address family, creating it on first use.
	var  map[ipAddr]*tcpPacketConn
	if  {
		if ,  = .connsIPv6[]; ! {
			 = make(map[ipAddr]*tcpPacketConn)
			.connsIPv6[] = 
		}
	} else {
		if ,  = .connsIPv4[]; ! {
			 = make(map[ipAddr]*tcpPacketConn)
			.connsIPv4[] = 
		}
	}
	// Note: this is missing zone for IPv6
	 := ipAddr(.String())
	[] = 

	// Unregister the connection from the mux once it reports being closed.
	.wg.Add(1)
	go func() {
		defer .wg.Done()
		<-.CloseChannel()
		.removeConnByUfragAndLocalHost(, )
	}()

	return , nil
}

// Closes the given io.Closer and logs a warning through the mux logger when
// Close returns an error (presumably the closeAndLogError helper used
// throughout this file; the method name is missing here).
//
// NOTE(review): identifier names are missing from this block as shown;
// restore them before building.
func ( *TCPMuxDefault) ( io.Closer) {
	 := .Close()
	if  != nil {
		.params.Logger.Warnf("Error closing connection: %s", )
	}
}

// Handles one accepted net.Conn (presumably the handleConn method called from
// the accept loop; the method name is missing here).
//
// Steps, in order:
//  1. Read the first framed packet into a 512-byte buffer with
//     readStreamingPacket. When FirstStunBindTimeout is positive a read
//     deadline is set first; it is cleared again after a successful read.
//  2. Decode the packet as a STUN message and require the Binding method and
//     a Username attribute.
//  3. Take the text before the first ":" as the ufrag (it is logged as
//     "Ufrag").
//  4. Determine the address family from the remote host: a host that does not
//     parse as IPv4 is treated as IPv6.
//  5. Under the mux mutex, look up the matching tcpPacketConn with getConn or
//     create one with createConn (last argument true).
//  6. Attach the connection with AddConn, outside the mutex.
//
// On every failure path closeAndLogError is called (presumably on the
// connection) and a warning is logged.
//
// NOTE(review): identifier names are missing from this block as shown;
// restore them before building.
func ( *TCPMuxDefault) ( net.Conn) { //nolint:cyclop
	 := make([]byte, 512)

	if .params.FirstStunBindTimeout > 0 {
		if  := .SetReadDeadline(time.Now().Add(.params.FirstStunBindTimeout));  != nil {
			.params.Logger.Warnf(
				"Failed to set read deadline for first STUN message: %s to %s , err: %s",
				.RemoteAddr(),
				.LocalAddr(),
				,
			)
		}
	}
	,  := readStreamingPacket(, )
	if  != nil {
		if errors.Is(, io.ErrShortBuffer) {
			.params.Logger.Warnf("Buffer too small for first packet from %s: %s", .RemoteAddr(), )
		} else {
			.params.Logger.Warnf("Error reading first packet from %s: %s", .RemoteAddr(), )
		}
		.closeAndLogError()

		return
	}
	// A zero time.Time clears the read deadline set above.
	if  = .SetReadDeadline(time.Time{});  != nil {
		.params.Logger.Warnf("Failed to reset read deadline from %s: %s", .RemoteAddr(), )
	}

	 = [:]

	 := &stun.Message{
		Raw: make([]byte, len()),
	}
	// Explicitly copy raw buffer so Message can own the memory.
	copy(.Raw, )
	if  = .Decode();  != nil {
		.closeAndLogError()
		.params.Logger.Warnf("Failed to handle decode ICE from %s to %s: %v", .RemoteAddr(), .LocalAddr(), )

		return
	}

	if  == nil || .Type.Method != stun.MethodBinding { // Not a STUN
		.closeAndLogError()
		.params.Logger.Warnf("Not a STUN message from %s to %s", .RemoteAddr(), .LocalAddr())

		return
	}

	for ,  := range .Attributes {
		.params.Logger.Debugf("Message attribute: %s", .String())
	}

	,  := .Get(stun.AttrUsername)
	if  != nil {
		.closeAndLogError()
		.params.Logger.Warnf(
			"No Username attribute in STUN message from %s to %s",
			.RemoteAddr(),
			.LocalAddr(),
		)

		return
	}

	// Only the part before the first ":" is kept.
	 := strings.Split(string(), ":")[0]
	.params.Logger.Debugf("Ufrag: %s", )

	, ,  := net.SplitHostPort(.RemoteAddr().String())
	if  != nil {
		.closeAndLogError()
		.params.Logger.Warnf(
			"Failed to get host in STUN message from %s to %s",
			.RemoteAddr(),
			.LocalAddr(),
		)

		return
	}

	// To4 returns nil for anything that is not an IPv4 address.
	 := net.ParseIP().To4() == nil

	,  := .LocalAddr().(*net.TCPAddr)
	if ! {
		.closeAndLogError()
		.params.Logger.Warnf(
			"Failed to get local tcp address in STUN message from %s to %s",
			.RemoteAddr(),
			.LocalAddr(),
		)

		return
	}
	// Lookup and creation must happen under the same lock acquisition; the
	// lock is released on both exits before any close or AddConn call.
	.mu.Lock()

	,  := .getConn(, , .IP)
	if ! {
		,  = .createConn(, , .IP, true)
		if  != nil {
			.mu.Unlock()
			.closeAndLogError()
			.params.Logger.Warnf(
				"Failed to create packetConn for STUN message from %s to %s",
				.RemoteAddr(),
				.LocalAddr(),
			)

			return
		}
	}
	.mu.Unlock()

	if  := .AddConn(, );  != nil {
		.closeAndLogError()
		.params.Logger.Warnf(
			"Error adding conn to tcpPacketConn from %s to %s: %s",
			.RemoteAddr(),
			.LocalAddr(),
			,
		)

		return
	}
}

// Close closes the listener and waits for all goroutines to exit.
//
// While holding the mutex it marks the mux as closed, calls closeAndLogError
// for every entry of the IPv4 and IPv6 maps, replaces both maps with empty
// ones and closes the Listener. After releasing the mutex it waits on the
// WaitGroup. The returned error is presumably the one from Listener.Close.
//
// NOTE(review): identifier names are missing from this block as shown;
// restore them before building.
func ( *TCPMuxDefault) () error {
	.mu.Lock()
	.closed = true

	for ,  := range .connsIPv4 {
		for ,  := range  {
			.closeAndLogError()
		}
	}
	for ,  := range .connsIPv6 {
		for ,  := range  {
			.closeAndLogError()
		}
	}

	.connsIPv4 = map[string]map[ipAddr]*tcpPacketConn{}
	.connsIPv6 = map[string]map[ipAddr]*tcpPacketConn{}

	 := .params.Listener.Close()

	// Release the lock before waiting: goroutines tracked by wg take mu too.
	.mu.Unlock()

	.wg.Wait()

	return 
}

// RemoveConnByUfrag closes and removes a net.PacketConn by Ufrag.
//
// Under the mutex, the entries stored for the key are deleted from both the
// IPv4 and the IPv6 map and their connections are collected into a slice.
// The collected connections are closed only after the mutex is released.
//
// NOTE(review): identifier names are missing from this block as shown;
// restore them before building.
func ( *TCPMuxDefault) ( string) {
	 := make([]*tcpPacketConn, 0, 4)

	// Keep lock section small to avoid deadlock with conn lock
	.mu.Lock()
	if ,  := .connsIPv4[];  {
		delete(.connsIPv4, )
		for ,  := range  {
			 = append(, )
		}
	}
	if ,  := .connsIPv6[];  {
		delete(.connsIPv6, )
		for ,  := range  {
			 = append(, )
		}
	}

	.mu.Unlock()

	// Close the connections outside the critical section to avoid
	// deadlocking TCP mux if (*tcpPacketConn).Close() blocks.
	for ,  := range  {
		.closeAndLogError()
	}
}

// Removes and closes the connection stored under one outer key and one ipAddr
// key (presumably the removeConnByUfragAndLocalHost method called from the
// goroutine that waits for a connection to close; the method name is missing
// here).
//
// Under the mutex, the matching entry is deleted from the inner map of
// connsIPv4 and of connsIPv6. An inner map that becomes empty is deleted from
// its outer map as well. The removed connections are closed only after the
// mutex is released.
//
// NOTE(review): identifier names are missing from this block as shown;
// restore them before building.
func ( *TCPMuxDefault) ( string,  ipAddr) {
	 := make([]*tcpPacketConn, 0, 4)

	// Keep lock section small to avoid deadlock with conn lock
	.mu.Lock()
	if ,  := .connsIPv4[];  {
		if ,  := [];  {
			delete(, )
			if len() == 0 {
				delete(.connsIPv4, )
			}
			 = append(, )
		}
	}
	if ,  := .connsIPv6[];  {
		if ,  := [];  {
			delete(, )
			if len() == 0 {
				delete(.connsIPv6, )
			}
			 = append(, )
		}
	}
	.mu.Unlock()

	// Close the connections outside the critical section to avoid
	// deadlocking TCP mux if (*tcpPacketConn).Close() blocks.
	for ,  := range  {
		.closeAndLogError()
	}
}

// Looks up a tcpPacketConn (presumably the getConn method used by
// GetConnByUfrag and by the connection handler; the method name is missing
// here).
//
// The bool parameter selects connsIPv6 or connsIPv4. When an inner map is
// found there, it is indexed with an ipAddr key built from a String() call
// (presumably on the net.IP parameter). The results are returned through
// named result parameters, so a miss yields their zero values.
//
// This method does not take the mutex itself; both visible call sites invoke
// it with the mutex held.
//
// NOTE(review): identifier names, including those of the named results, are
// missing from this block as shown; restore them before building.
func ( *TCPMuxDefault) ( string,  bool,  net.IP) ( *tcpPacketConn,  bool) {
	var  map[ipAddr]*tcpPacketConn
	if  {
		,  = .connsIPv6[]
	} else {
		,  = .connsIPv4[]
	}
	if  != nil {
		// Note: this is missing zone for IPv6
		 := ipAddr(.String())
		,  = []
	}

	return
}

// streamingPacketHeaderLen is the size, in bytes, of the big-endian length
// prefix that precedes every packet on the stream.
const streamingPacketHeaderLen = 2

// readStreamingPacket reads exactly one framed packet from conn into buf and
// returns the payload length.
//
// Framing follows https://tools.ietf.org/html/rfc4571#section-2: a 2-byte
// big-endian length header precedes each packet:
//
//	 0                   1                   2                   3
//	 0 1 2 3 4 5 6 7 8 9 0 1 2 3 4 5 6 7 8 9 0 1 2 3 4 5 6 7 8 9 0 1
//	-----------------------------------------------------------------
//	|             LENGTH            |  RTP or RTCP packet ...       |
//	-----------------------------------------------------------------
//
// Both the header and the payload may arrive split across several reads, so
// each is read in a loop until complete. The payload is written into
// buf[:length]; the limit is cap(buf), not len(buf). If the declared length
// exceeds cap(buf), io.ErrShortBuffer is returned together with that length
// and the payload is left unread on the connection. Any read error is
// returned as-is with a count of 0.
func readStreamingPacket(conn net.Conn, buf []byte) (int, error) {
	header := make([]byte, streamingPacketHeaderLen)
	var bytesRead, n int
	var err error

	for bytesRead < streamingPacketHeaderLen {
		if n, err = conn.Read(header[bytesRead:streamingPacketHeaderLen]); err != nil {
			return 0, err
		}
		bytesRead += n
	}

	length := int(binary.BigEndian.Uint16(header))

	if length > cap(buf) {
		return length, io.ErrShortBuffer
	}

	bytesRead = 0
	for bytesRead < length {
		if n, err = conn.Read(buf[bytesRead:length]); err != nil {
			return 0, err
		}
		bytesRead += n
	}

	return bytesRead, nil
}

// writeStreamingPacket writes buf to conn as a single framed packet: a 2-byte
// big-endian length header followed by the payload, sent with one Write call.
//
// It returns the number of payload bytes written (the header is not counted),
// or 0 and the error if the write fails. The length header is 16 bits wide,
// so payloads longer than 65535 bytes cannot be represented correctly.
func writeStreamingPacket(conn net.Conn, buf []byte) (int, error) {
	framed := make([]byte, streamingPacketHeaderLen+len(buf))
	binary.BigEndian.PutUint16(framed, uint16(len(buf))) //nolint:gosec // G115
	copy(framed[streamingPacketHeaderLen:], buf)

	n, err := conn.Write(framed)
	if err != nil {
		return 0, err
	}

	return n - streamingPacketHeaderLen, nil
}