// SPDX-FileCopyrightText: 2023 The Pion community <https://pion.ly>
// SPDX-License-Identifier: MIT

package netctx

import (
	"context"
	"io"
	"net"
	"sync"
	"sync/atomic"
	"time"
)

// ReaderFrom is an interface for context controlled packet reader.
type ReaderFrom interface {
	// ReadFromContext reads one packet into buf and returns the number of
	// bytes copied, the address the packet came from, and any error.
	// Implementations use ctx to control how long the call may block.
	ReadFromContext(ctx context.Context, buf []byte) (n int, addr net.Addr, err error)
}

// WriterTo is an interface for context controlled packet writer.
type WriterTo interface {
	// WriteToContext sends buf as one packet to addr and returns the number
	// of bytes written and any error. Implementations use ctx to control
	// how long the call may block.
	WriteToContext(ctx context.Context, buf []byte, addr net.Addr) (n int, err error)
}

// PacketConn is a wrapper of net.PacketConn using context.Context.
// It combines context-controlled reads and writes with Close, and also
// exposes the local address and the wrapped connection itself.
type PacketConn interface {
	ReaderFrom
	WriterTo
	io.Closer

	// LocalAddr returns the local network address, if known.
	LocalAddr() net.Addr
	// Conn returns the wrapped net.PacketConn.
	Conn() net.PacketConn
}

// packetConn is the PacketConn implementation returned by NewPacketConn.
type packetConn struct {
	// nextConn is the wrapped connection that all I/O is delegated to.
	nextConn net.PacketConn

	// closed is closed exactly once, under closeOnce, when Close is called.
	closed    chan struct{}
	closeOnce sync.Once

	// readMu serializes ReadFromContext calls and writeMu serializes
	// WriteToContext calls; Close takes both before closing the channel.
	readMu  sync.Mutex
	writeMu sync.Mutex
}

// NewPacketConn creates a new PacketConn wrapping the given net.PacketConn.
func ( net.PacketConn) PacketConn {
	 := &packetConn{
		nextConn: ,
		closed:   make(chan struct{}),
	}
	return 
}

// ReadFromContext reads a packet from the connection,
// copying the payload into p. It returns the number of
// bytes copied into p and the return address that
// was on the packet.
// It returns the number of bytes read (0 <= n <= len(p))
// and any error encountered. Callers should always process
// the n > 0 bytes returned before considering the error err.
// Unlike net.PacketConn.ReadFrom(), the provided context is
// used to control timeout.
func ( *packetConn) ( context.Context,  []byte) (int, net.Addr, error) {
	.readMu.Lock()
	defer .readMu.Unlock()

	select {
	case <-.closed:
		return 0, nil, net.ErrClosed
	default:
	}

	 := make(chan struct{})
	var  sync.WaitGroup
	var  atomic.Value
	.Add(1)
	go func() {
		defer .Done()
		select {
		case <-.Done():
			// context canceled
			if  := .nextConn.SetReadDeadline(veryOld);  != nil {
				.Store()
				return
			}
			<-
			if  := .nextConn.SetReadDeadline(time.Time{});  != nil {
				.Store()
			}
		case <-:
		}
	}()

	, ,  := .nextConn.ReadFrom()

	close()
	.Wait()
	if  := .Err();  != nil &&  == 0 {
		 = 
	}
	if ,  := .Load().(error);  &&  == nil &&  != nil {
		 = 
	}
	return , , 
}

// WriteToContext writes a packet with payload p to addr.
// Unlike net.PacketConn.WriteTo(), the provided context
// is used to control timeout.
// On packet-oriented connections, write timeouts are rare.
func ( *packetConn) ( context.Context,  []byte,  net.Addr) (int, error) {
	.writeMu.Lock()
	defer .writeMu.Unlock()

	select {
	case <-.closed:
		return 0, ErrClosing
	default:
	}

	 := make(chan struct{})
	var  sync.WaitGroup
	var  atomic.Value
	.Add(1)
	go func() {
		defer .Done()
		select {
		case <-.Done():
			// context canceled
			if  := .nextConn.SetWriteDeadline(veryOld);  != nil {
				.Store()
				return
			}
			<-
			if  := .nextConn.SetWriteDeadline(time.Time{});  != nil {
				.Store()
			}
		case <-:
		}
	}()

	,  := .nextConn.WriteTo(, )

	close()
	.Wait()
	if  := .Err();  != nil &&  == 0 {
		 = 
	}
	if ,  := .Load().(error);  &&  == nil &&  != nil {
		 = 
	}
	return , 
}

// Close closes the connection.
// Any blocked ReadFromContext or WriteToContext operations will be unblocked
// and return errors.
func ( *packetConn) () error {
	 := .nextConn.Close()
	.closeOnce.Do(func() {
		.writeMu.Lock()
		.readMu.Lock()
		close(.closed)
		.readMu.Unlock()
		.writeMu.Unlock()
	})
	return 
}

// LocalAddr returns the local network address, if known.
func ( *packetConn) () net.Addr {
	return .nextConn.LocalAddr()
}

// Conn returns the underlying net.PacketConn.
func ( *packetConn) () net.PacketConn {
	return .nextConn
}