// SPDX-FileCopyrightText: 2023 The Pion community <https://pion.ly>
// SPDX-License-Identifier: MIT

package ice

import (
	"errors"
	"io"
	"net"
	"net/netip"
	"os"
	"strings"
	"sync"

	"github.com/pion/logging"
	"github.com/pion/stun/v3"
	"github.com/pion/transport/v3"
	"github.com/pion/transport/v3/stdnet"
)

// UDPMux lets many ICE connections share a single UDP port.
//
// Implementations must be closable; closing the mux tears down every
// connection it handed out.
type UDPMux interface {
	io.Closer

	// GetConn returns the packet connection for ufrag on addr, creating it
	// when none exists yet.
	GetConn(ufrag string, addr net.Addr) (net.PacketConn, error)

	// RemoveConnByUfrag forgets the connection(s) registered for ufrag.
	RemoveConnByUfrag(ufrag string)

	// GetListenAddresses reports the local addresses the mux listens on.
	GetListenAddresses() []net.Addr
}

// UDPMuxDefault is the default UDPMux implementation. A single goroutine reads
// every packet from params.UDPConn and hands it to the udpMuxedConn that owns
// the sender (looked up by remote address, or by ufrag for new STUN senders).
type UDPMuxDefault struct {
	params UDPMuxParams

	// closedChan is closed exactly once, by Close (guarded by closeOnce).
	closedChan chan struct{}
	closeOnce  sync.Once

	// mu guards connsIPv4 and connsIPv6. Each table maps a ufrag to its muxed
	// connection; the IP family of the GetConn address selects the table.
	mu        sync.Mutex
	connsIPv4 map[string]*udpMuxedConn
	connsIPv6 map[string]*udpMuxedConn

	// addressMapMu guards addressMap, which routes a remote address to the
	// muxed connection registered for it.
	addressMapMu sync.RWMutex
	addressMap   map[ipPort]*udpMuxedConn

	// pool recycles packet buffers (bufferHolder values) between connections.
	pool *sync.Pool

	// localAddrsForUnspecified holds concrete local addresses when UDPConn is
	// bound to an unspecified address; it is empty otherwise.
	localAddrsForUnspecified []net.Addr
}

// UDPMuxParams configures NewUDPMuxDefault.
type UDPMuxParams struct {
	// Logger receives the mux's log output. NewUDPMuxDefault substitutes a
	// default "ice" logger when it is nil.
	Logger logging.LeveledLogger

	// UDPConn is the shared connection that all muxed traffic flows through.
	UDPConn net.PacketConn

	// UDPConnString is overwritten by NewUDPMuxDefault with
	// UDPConn.LocalAddr().String(); a value set by the caller is not used.
	UDPConnString string

	// Net is used to gather local interface addresses when UDPConn is not
	// bound to a specific local address (e.g. 0.0.0.0 or ::). If it is nil
	// in that case, NewUDPMuxDefault creates one with stdnet.NewNet.
	Net transport.Net
}

// NewUDPMuxDefault creates an implementation of UDPMux.
func ( UDPMuxParams) *UDPMuxDefault { //nolint:cyclop
	if .Logger == nil {
		.Logger = logging.NewDefaultLoggerFactory().NewLogger("ice")
	}

	var  []net.Addr
	if ,  := .UDPConn.LocalAddr().(*net.UDPAddr); ! { //nolint:nestif
		.Logger.Errorf("LocalAddr is not a net.UDPAddr, got %T", .UDPConn.LocalAddr())
	} else if  && .IP.IsUnspecified() {
		// For unspecified addresses, the correct behavior is to return errListenUnspecified, but
		// it will break the applications that are already using unspecified UDP connection
		// with UDPMuxDefault, so print a warn log and create a local address list for mux.
		.Logger.Warn("UDPMuxDefault should not listening on unspecified address, use NewMultiUDPMuxFromPort instead")
		var  []NetworkType
		switch {
		case .IP.To4() != nil:
			 = []NetworkType{NetworkTypeUDP4}

		case .IP.To16() != nil:
			 = []NetworkType{NetworkTypeUDP4, NetworkTypeUDP6}

		default:
			.Logger.Errorf("LocalAddr expected IPV4 or IPV6, got %T", .UDPConn.LocalAddr())
		}
		if len() > 0 {
			if .Net == nil {
				var  error
				if .Net,  = stdnet.NewNet();  != nil {
					.Logger.Errorf("Failed to get create network: %v", )
				}
			}

			, ,  := localInterfaces(.Net, nil, nil, , true)
			if  == nil {
				for ,  := range  {
					 = append(, &net.UDPAddr{
						IP:   .AsSlice(),
						Port: .Port,
						Zone: .Zone(),
					})
				}
			} else {
				.Logger.Errorf("Failed to get local interfaces for unspecified addr: %v", )
			}
		}
	}
	.UDPConnString = .UDPConn.LocalAddr().String()

	 := &UDPMuxDefault{
		addressMap: map[ipPort]*udpMuxedConn{},
		params:     ,
		connsIPv4:  make(map[string]*udpMuxedConn),
		connsIPv6:  make(map[string]*udpMuxedConn),
		closedChan: make(chan struct{}, 1),
		pool: &sync.Pool{
			New: func() interface{} {
				// Big enough buffer to fit both packet and address
				return newBufferHolder(receiveMTU)
			},
		},
		localAddrsForUnspecified: ,
	}

	go .connWorker()

	return 
}

// LocalAddr returns the listening address of this UDPMuxDefault.
func ( *UDPMuxDefault) () net.Addr {
	return .params.UDPConn.LocalAddr()
}

// GetListenAddresses returns the list of addresses that this mux is listening on.
func ( *UDPMuxDefault) () []net.Addr {
	if len(.localAddrsForUnspecified) > 0 {
		return .localAddrsForUnspecified
	}

	return []net.Addr{.LocalAddr()}
}

// GetConn returns a PacketConn given the connection's ufrag and network address.
// creates the connection if an existing one can't be found.
func ( *UDPMuxDefault) ( string,  net.Addr) (net.PacketConn, error) {
	// don't check addr for mux using unspecified address
	if len(.localAddrsForUnspecified) == 0 && .params.UDPConnString != .String() {
		return nil, errInvalidAddress
	}

	var  bool
	if ,  := .(*net.UDPAddr);  != nil && .IP.To4() == nil {
		 = true
	}
	.mu.Lock()
	defer .mu.Unlock()

	if .IsClosed() {
		return nil, io.ErrClosedPipe
	}

	if ,  := .getConn(, );  {
		return , nil
	}

	 := .createMuxedConn()
	go func() {
		<-.CloseChannel()
		.RemoveConnByUfrag()
	}()

	if  {
		.connsIPv6[] = 
	} else {
		.connsIPv4[] = 
	}

	return , nil
}

// RemoveConnByUfrag stops and removes the muxed packet connection.
func ( *UDPMuxDefault) ( string) {
	 := make([]*udpMuxedConn, 0, 2)

	// Keep lock section small to avoid deadlock with conn lock.
	.mu.Lock()
	if ,  := .connsIPv4[];  {
		delete(.connsIPv4, )
		 = append(, )
	}
	if ,  := .connsIPv6[];  {
		delete(.connsIPv6, )
		 = append(, )
	}
	.mu.Unlock()

	if len() == 0 {
		// No need to lock if no connection was found.
		return
	}

	.addressMapMu.Lock()
	defer .addressMapMu.Unlock()

	for ,  := range  {
		 := .getAddresses()
		for ,  := range  {
			delete(.addressMap, )
		}
	}
}

// IsClosed returns true if the mux had been closed.
func ( *UDPMuxDefault) () bool {
	select {
	case <-.closedChan:
		return true
	default:
		return false
	}
}

// Close the mux, no further connections could be created.
func ( *UDPMuxDefault) () error {
	var  error
	.closeOnce.Do(func() {
		.mu.Lock()
		defer .mu.Unlock()

		for ,  := range .connsIPv4 {
			_ = .Close()
		}
		for ,  := range .connsIPv6 {
			_ = .Close()
		}

		.connsIPv4 = make(map[string]*udpMuxedConn)
		.connsIPv6 = make(map[string]*udpMuxedConn)

		close(.closedChan)

		_ = .params.UDPConn.Close()
	})

	return 
}

func ( *UDPMuxDefault) ( []byte,  net.Addr) ( int,  error) {
	return .params.UDPConn.WriteTo(, )
}

func ( *UDPMuxDefault) ( *udpMuxedConn,  ipPort) {
	if .IsClosed() {
		return
	}

	.addressMapMu.Lock()
	defer .addressMapMu.Unlock()

	,  := .addressMap[]
	if  {
		.removeAddress()
	}
	.addressMap[] = 

	.params.Logger.Debugf("Registered %s for %s", .addr.String(), .params.Key)
}

func ( *UDPMuxDefault) ( string) *udpMuxedConn {
	 := newUDPMuxedConn(&udpMuxedConnParams{
		Mux:       ,
		Key:       ,
		AddrPool:  .pool,
		LocalAddr: .LocalAddr(),
		Logger:    .params.Logger,
	})

	return 
}

func ( *UDPMuxDefault) () { //nolint:cyclop
	 := .params.Logger

	defer func() {
		_ = .Close()
	}()

	 := make([]byte, receiveMTU)
	for {
		, ,  := .params.UDPConn.ReadFrom()
		if .IsClosed() {
			return
		} else if  != nil {
			if os.IsTimeout() {
				continue
			} else if !errors.Is(, io.EOF) {
				.Errorf("Failed to read UDP packet: %v", )
			}

			return
		}

		,  := .(*net.UDPAddr)
		if ! {
			.Errorf("Underlying PacketConn did not return a UDPAddr")

			return
		}
		,  := newIPPort(.IP, .Zone, uint16(.Port)) //nolint:gosec
		if  != nil {
			.Errorf("Failed to create a new IP/Port host pair")

			return
		}

		// If we have already seen this address dispatch to the appropriate destination
		.addressMapMu.Lock()
		 := .addressMap[]
		.addressMapMu.Unlock()

		// If we haven't seen this address before but is a STUN packet lookup by ufrag
		if  == nil && stun.IsMessage([:]) {
			 := &stun.Message{
				Raw: append([]byte{}, [:]...),
			}

			if  = .Decode();  != nil {
				.params.Logger.Warnf("Failed to handle decode ICE from %s: %v", .String(), )

				continue
			}

			,  := .Get(stun.AttrUsername)
			if  != nil {
				.params.Logger.Warnf("No Username attribute in STUN message from %s", .String())

				continue
			}

			 := strings.Split(string(), ":")[0]
			 := .IP.To4() == nil

			.mu.Lock()
			, _ = .getConn(, )
			.mu.Unlock()
		}

		if  == nil {
			.params.Logger.Tracef("Dropping packet from %s, addr: %s", .addr, )

			continue
		}

		if  = .writePacket([:], );  != nil {
			.params.Logger.Errorf("Failed to write packet: %v", )
		}
	}
}

func ( *UDPMuxDefault) ( string,  bool) ( *udpMuxedConn,  bool) {
	if  {
		,  = .connsIPv6[]
	} else {
		,  = .connsIPv4[]
	}

	return
}

// bufferHolder is a reusable packet buffer; the mux's sync.Pool creates them
// with newBufferHolder(receiveMTU).
type bufferHolder struct {
	// next presumably links holders into a queue inside udpMuxedConn — TODO confirm.
	next *bufferHolder
	buf  []byte
	addr *net.UDPAddr
}

// newBufferHolder returns a bufferHolder whose buf has length size.
func newBufferHolder(size int) *bufferHolder {
	return &bufferHolder{
		buf: make([]byte, size),
	}
}

// reset clears next and addr so the holder can be reused; buf is kept as is.
func (b *bufferHolder) reset() {
	b.next = nil
	b.addr = nil
}

// ipPort is a comparable IP address and port pair; UDPMuxDefault uses it as
// the key of its remote-address routing map.
type ipPort struct {
	addr netip.Addr // IP address, including its zone if any
	port uint16     // UDP port
}

// newIPPort create a custom type of address based on netip.Addr and
// port. The underlying ip address passed is converted to IPv6 format
// to simplify ip address handling.
func newIPPort( net.IP,  string,  uint16) (ipPort, error) {
	,  := netip.AddrFromSlice(.To16())
	if ! {
		return ipPort{}, errInvalidIPAddress
	}

	return ipPort{
		addr: .WithZone(),
		port: ,
	}, nil
}