package udp
import (
"context"
"errors"
"net"
"sync"
"sync/atomic"
"time"
idtlsnet "github.com/pion/dtls/v3/internal/net"
dtlsnet "github.com/pion/dtls/v3/pkg/net"
"github.com/pion/transport/v3/deadline"
)
// receiveMTU is the size in bytes of the buffer readLoop allocates for each
// read from the shared UDP socket.
const receiveMTU = 8192

// defaultListenBacklog is the accept-queue capacity used when
// ListenConfig.Backlog is left at zero.
const defaultListenBacklog = 128

// ErrClosedListener is returned by Accept after the listener has been closed,
// and internally when a datagram would need a new connection on a listener
// that is no longer accepting.
var ErrClosedListener = errors.New("udp: listener closed")

// ErrListenQueueExceeded is reported internally when a datagram would create
// a new connection but the accept queue is already full.
var ErrListenQueueExceeded = errors.New("udp: listen queue exceeded")
// listener demultiplexes datagrams arriving on one UDP socket into
// per-remote PacketConns that are handed out through Accept.
type listener struct {
	// pConn is the shared socket: readLoop reads from it and every
	// PacketConn writes to it directly.
	pConn *net.UDPConn

	// Routing hooks copied from ListenConfig.
	acceptFilter   func([]byte) bool
	datagramRouter func([]byte) (string, bool)
	connIdentifier func([]byte) (string, bool)

	// accepting holds a bool that Close flips to false.
	accepting atomic.Value
	// acceptCh queues newly created connections until Accept takes them.
	acceptCh chan *PacketConn
	// doneCh is closed (once, via doneOnce) when Close starts.
	doneCh   chan struct{}
	doneOnce sync.Once

	// conns maps remote-address strings and connection identifiers to their
	// connection; guarded by connLock.
	connLock sync.Mutex
	conns    map[string]*PacketConn

	// connWG counts the listener itself plus every accepted connection; when
	// it drains, a goroutine started by Listen closes pConn.
	connWG sync.WaitGroup
	// readWG tracks readLoop and that socket-closing goroutine.
	readWG sync.WaitGroup

	// errClose holds the error (if any) from closing pConn.
	errClose atomic.Value
	// readDoneCh is closed when readLoop exits; errRead holds its error.
	readDoneCh chan struct{}
	errRead    atomic.Value
}
// Accept waits for the next connection created by an unseen remote address.
//
// It returns ErrClosedListener once the listener is closed, and the error that
// stopped readLoop if the socket read loop has exited.
//
// NOTE(review): select chooses randomly among ready cases, so a queued
// connection can still be returned while Close is running; confirm the
// connWG.Add below cannot race with the listener's final connWG.Done.
func (l *listener) Accept() (net.PacketConn, net.Addr, error) {
	select {
	case <-l.doneCh:
		return nil, nil, ErrClosedListener
	case <-l.readDoneCh:
		readErr, _ := l.errRead.Load().(error)
		return nil, nil, readErr
	case pending := <-l.acceptCh:
		// The accepted connection holds a connWG reference until its Close.
		l.connWG.Add(1)
		return pending, pending.raddr, nil
	}
}
// Close stops the listener from accepting new connections.
//
// Connections queued on acceptCh but never accepted are discarded and removed
// from the routing map. The listener then releases its own connWG reference.
// When the routing map is empty at that point, Close waits for readLoop and
// the socket-closing goroutine and returns any error from closing the socket;
// otherwise it returns nil and the socket stays open for accepted connections.
// Only the first call has any effect.
func (l *listener) Close() error {
	var closeErr error
	l.doneOnce.Do(func() {
		l.accepting.Store(false)
		close(l.doneCh)

		l.connLock.Lock()
		// Drain without blocking: Accept may be receiving concurrently.
		for {
			var pending *PacketConn
			select {
			case pending = <-l.acceptCh:
			default:
			}
			if pending == nil {
				break
			}
			close(pending.doneCh)
			if id := pending.id.Load(); id != nil {
				delete(l.conns, id.(string))
			}
			if pending.rmraddr.Load() == nil {
				delete(l.conns, pending.raddr.String())
				pending.rmraddr.Store(true)
			}
		}
		remaining := len(l.conns)
		l.connLock.Unlock()

		// Release the reference taken for the listener in Listen.
		l.connWG.Done()

		if remaining != 0 {
			return
		}
		l.readWG.Wait()
		if e, ok := l.errClose.Load().(error); ok {
			closeErr = e
		}
	})
	return closeErr
}
// Addr returns the local address of the shared UDP socket.
func (l *listener) Addr() net.Addr {
	sock := l.pConn
	return sock.LocalAddr()
}
// ListenConfig holds the options used by ListenConfig.Listen.
type ListenConfig struct {
	// Backlog is the capacity of the queue of connections waiting for
	// Accept. Zero selects the package default.
	Backlog int

	// AcceptFilter, when set, is consulted for datagrams from unknown remote
	// addresses; returning false drops the datagram without creating a
	// connection.
	AcceptFilter func([]byte) bool

	// DatagramRouter, when set, extracts an identifier from an incoming
	// datagram; a connection registered under that identifier is preferred
	// over the remote-address lookup.
	DatagramRouter func([]byte) (string, bool)

	// ConnectionIdentifier, when set, extracts an identifier from an outgoing
	// datagram; the first successful extraction registers the connection
	// under that identifier.
	ConnectionIdentifier func([]byte) (string, bool)
}
// Listen opens a UDP socket on laddr and returns a listener that
// demultiplexes incoming datagrams into per-remote connections.
//
// A Backlog of zero or less selects defaultListenBacklog. The configuration is
// only read, never modified, so one ListenConfig may be shared by concurrent
// Listen calls. (Previously a negative Backlog panicked in make, and the
// default was written back into lc.)
//
// Errors from net.ListenUDP are returned unchanged.
func (lc *ListenConfig) Listen(network string, laddr *net.UDPAddr) (dtlsnet.PacketListener, error) {
	backlog := lc.Backlog
	if backlog <= 0 {
		backlog = defaultListenBacklog
	}

	conn, err := net.ListenUDP(network, laddr)
	if err != nil {
		return nil, err
	}

	packetListener := &listener{
		pConn:          conn,
		acceptCh:       make(chan *PacketConn, backlog),
		conns:          make(map[string]*PacketConn),
		doneCh:         make(chan struct{}),
		acceptFilter:   lc.AcceptFilter,
		datagramRouter: lc.DatagramRouter,
		connIdentifier: lc.ConnectionIdentifier,
		readDoneCh:     make(chan struct{}),
	}

	packetListener.accepting.Store(true)
	// The listener holds one connWG reference, released by listener.Close.
	packetListener.connWG.Add(1)
	// readWG tracks readLoop and the socket-closing goroutine below.
	packetListener.readWG.Add(2)

	go packetListener.readLoop()
	go func() {
		// Close the socket once the listener and all accepted connections
		// have released their connWG references.
		packetListener.connWG.Wait()
		if err := packetListener.pConn.Close(); err != nil {
			packetListener.errClose.Store(err)
		}
		packetListener.readWG.Done()
	}()

	return packetListener, nil
}
// Listen opens a UDP listener on laddr using a zero-value ListenConfig.
func Listen(network string, laddr *net.UDPAddr) (dtlsnet.PacketListener, error) {
	var cfg ListenConfig
	return cfg.Listen(network, laddr)
}
// readLoop reads datagrams from the shared socket and hands each one to the
// connection chosen by getConn. Datagrams that getConn rejects are dropped.
// The first read error is stored in errRead and ends the loop.
func (l *listener) readLoop() {
	defer func() {
		// Wake Accept first, then release this goroutine's readWG slot.
		close(l.readDoneCh)
		l.readWG.Done()
	}()

	buf := make([]byte, receiveMTU)
	for {
		n, from, err := l.pConn.ReadFrom(buf)
		if err != nil {
			l.errRead.Store(err)
			return
		}
		pkt := buf[:n]
		conn, deliver, routeErr := l.getConn(from, pkt)
		if routeErr != nil || !deliver {
			continue
		}
		_, _ = conn.buffer.WriteTo(pkt, from)
	}
}
// getConn resolves the connection a datagram from raddr belongs to, creating
// and queueing a new one when neither lookup matches.
//
// The DatagramRouter identifier (when configured and registered) is tried
// first, then the remote address. The bool result reports whether buf should
// be delivered to the returned connection. When a new connection would be
// needed, it returns ErrClosedListener if the listener is no longer
// accepting, (nil, false, nil) if AcceptFilter rejects the datagram, and
// ErrListenQueueExceeded if the accept queue is full.
func (l *listener) getConn(raddr net.Addr, buf []byte) (*PacketConn, bool, error) {
	l.connLock.Lock()
	defer l.connLock.Unlock()

	if l.datagramRouter != nil {
		if id, routed := l.datagramRouter(buf); routed {
			if existing, found := l.conns[id]; found {
				return existing, true, nil
			}
		}
	}

	key := raddr.String()
	if existing, found := l.conns[key]; found {
		return existing, true, nil
	}

	if accepting, isBool := l.accepting.Load().(bool); !isBool || !accepting {
		return nil, false, ErrClosedListener
	}
	if l.acceptFilter != nil && !l.acceptFilter(buf) {
		return nil, false, nil
	}

	fresh := l.newPacketConn(raddr)
	select {
	case l.acceptCh <- fresh:
		l.conns[key] = fresh
		return fresh, true, nil
	default:
		return nil, false, ErrListenQueueExceeded
	}
}
// PacketConn is a net.PacketConn for a single remote peer, multiplexed over
// the shared UDP socket of the listener that created it.
type PacketConn struct {
	// listener owns the socket and the routing map.
	listener *listener

	// raddr is the remote address whose first datagram created this conn.
	raddr net.Addr
	// rmraddr is set once raddr's entry has been removed from the map.
	rmraddr atomic.Value
	// id holds the identifier string registered by WriteTo, if any.
	id atomic.Value

	// buffer receives the datagrams routed here by readLoop.
	buffer *idtlsnet.PacketBuffer

	// doneCh is closed when the connection is closed.
	doneCh   chan struct{}
	doneOnce sync.Once

	// writeDeadline is checked by WriteTo before each write.
	writeDeadline *deadline.Deadline
}
// newPacketConn builds a connection for raddr with an empty receive buffer,
// an open done channel and an unset write deadline.
func (l *listener) newPacketConn(raddr net.Addr) *PacketConn {
	pc := new(PacketConn)
	pc.listener = l
	pc.raddr = raddr
	pc.buffer = idtlsnet.NewPacketBuffer()
	pc.doneCh = make(chan struct{})
	pc.writeDeadline = deadline.New()
	return pc
}
// ReadFrom reads the next datagram routed to this connection into buff,
// delegating entirely to the connection's packet buffer.
func (c *PacketConn) ReadFrom(buff []byte) (int, net.Addr, error) {
	n, from, err := c.buffer.ReadFrom(buff)
	return n, from, err
}
// WriteTo writes payload to addr through the listener's shared socket.
//
// When the listener has a ConnectionIdentifier, WriteTo also maintains the
// routing map:
//   - The first outgoing payload for which the identifier reports ok registers
//     this connection under that identifier (unless the conn is closed).
//   - Once an identifier was already registered before this call, writing to
//     an address other than the original remote address releases the
//     remote-address entry so a new peer at that address gets its own conn.
//
// It returns context.DeadlineExceeded if the write deadline has passed.
func (c *PacketConn) WriteTo(payload []byte, addr net.Addr) (n int, err error) {
	ln := c.listener
	if ln.connIdentifier != nil {
		id := c.id.Load()
		if id == nil {
			if candidate, ok := ln.connIdentifier(payload); ok {
				ln.connLock.Lock()
				// Re-check under the lock: a concurrent writer may have won,
				// and a closed connection must not be re-inserted. Storing id
				// while still holding the lock guarantees Close (which reads
				// id under the same lock) sees it and removes the entry.
				closed := false
				select {
				case <-c.doneCh:
					closed = true
				default:
				}
				if !closed && c.id.Load() == nil {
					ln.conns[candidate] = c
					c.id.Store(candidate)
				}
				ln.connLock.Unlock()
			}
		}

		// Releasing the remote address holds it a little longer than strictly
		// needed, but avoids misrouting packets still sent from raddr.
		if id != nil && addr != nil && c.rmraddr.Load() == nil && addr.String() != c.raddr.String() {
			ln.connLock.Lock()
			// Re-check under the lock so a second writer cannot delete an
			// entry that a newer connection for the same address now owns.
			if c.rmraddr.Load() == nil {
				delete(ln.conns, c.raddr.String())
				c.rmraddr.Store(true)
			}
			ln.connLock.Unlock()
		}
	}

	select {
	case <-c.writeDeadline.Done():
		return 0, context.DeadlineExceeded
	default:
	}
	return ln.pConn.WriteTo(payload, addr)
}
// Close closes the connection and removes its entries from the listener's
// routing map.
//
// If this leaves the map empty and the listener has already stopped
// accepting, Close waits for the shared socket to be closed and returns any
// error from that. An error from closing the receive buffer is returned only
// when no socket error was reported. Only the first call has any effect.
func (c *PacketConn) Close() error {
	var closeErr error
	c.doneOnce.Do(func() {
		ln := c.listener
		// Release this connection's connWG reference taken in Accept.
		ln.connWG.Done()
		close(c.doneCh)

		ln.connLock.Lock()
		if id := c.id.Load(); id != nil {
			delete(ln.conns, id.(string))
		}
		if c.rmraddr.Load() == nil {
			delete(ln.conns, c.raddr.String())
			c.rmraddr.Store(true)
		}
		remaining := len(ln.conns)
		ln.connLock.Unlock()

		accepting, known := ln.accepting.Load().(bool)
		if remaining == 0 && known && !accepting {
			ln.readWG.Wait()
			if e, ok := ln.errClose.Load().(error); ok {
				closeErr = e
			}
		}

		if bufErr := c.buffer.Close(); bufErr != nil && closeErr == nil {
			closeErr = bufErr
		}
	})
	return closeErr
}
// LocalAddr returns the local address of the listener's shared socket.
func (c *PacketConn) LocalAddr() net.Addr {
	return c.listener.Addr()
}
// SetDeadline sets both the write deadline and the read deadline to t. The
// returned error comes from setting the read deadline.
func (c *PacketConn) SetDeadline(t time.Time) error {
	// SetWriteDeadline always returns nil.
	_ = c.SetWriteDeadline(t)
	return c.SetReadDeadline(t)
}
// SetReadDeadline sets the deadline for ReadFrom by forwarding t to the
// connection's receive buffer.
func (c *PacketConn) SetReadDeadline(t time.Time) error {
	inbound := c.buffer
	return inbound.SetReadDeadline(t)
}
// SetWriteDeadline sets the deadline WriteTo checks before writing; it never
// fails.
func (c *PacketConn) SetWriteDeadline(t time.Time) error {
	wd := c.writeDeadline
	wd.Set(t)
	return nil
}
The pages are generated with Golds v0.8.2 (GOOS=linux GOARCH=amd64).
Golds is a Go 101 project developed by Tapir Liu.
PRs and bug reports are welcome and can be submitted to the issue list.
Please follow @zigo_101 (reachable from the left QR code) to get the latest news of Golds.