package udp
import (
	"context"
	"errors"
	"io"
	"net"
	"sync"
	"sync/atomic"
	"time"

	"github.com/pion/logging"
	"github.com/pion/transport/v2/deadline"
	"github.com/pion/transport/v2/packetio"
	"golang.org/x/net/ipv4"
)
const (
	// receiveMTU is the size in bytes of every receive buffer used by the
	// read loops; datagrams larger than this do not fit in one buffer.
	receiveMTU = 8 << 10 // 8192

	// sendMTU is 1500 bytes. It is not referenced in this part of the file;
	// presumably it is used elsewhere in the package.
	sendMTU = 1500

	// defaultListenBacklog is the accept-queue capacity used when
	// ListenConfig.Backlog is left at zero.
	defaultListenBacklog = 128
)
// ErrClosedListener is returned by Accept once the listener has been closed,
// and internally when a datagram arrives from an unknown remote after the
// listener stopped accepting new connections.
var ErrClosedListener = errors.New("udp: listener closed")

// ErrListenQueueExceeded reports that the accept backlog was full, so a
// datagram from a new remote address could not create a connection.
var ErrListenQueueExceeded = errors.New("udp: listen queue exceeded")

// ErrInvalidBatchConfig is returned by ListenConfig.Listen when batch I/O is
// enabled but WriteBatchSize or WriteBatchInterval is not positive.
var ErrInvalidBatchConfig = errors.New("udp: invalid batch config")
// listener implements net.Listener on top of a single UDP socket. Datagrams
// are demultiplexed by remote address into per-peer Conns; the first datagram
// from an unseen address creates a Conn that is handed out through Accept.
type listener struct {
// pConn is the shared socket. Listen replaces it with a NewBatchConn wrapper
// when batch I/O is enabled.
pConn net .PacketConn
// readBatchSize is the number of messages per ReadBatch call; batched reads
// are only used when it is > 1 and pConn implements BatchReader.
readBatchSize int
// accepting holds a bool: true from Listen until Close. While false, getConn
// refuses to create Conns for new remote addresses.
accepting atomic .Value
// acceptCh queues newly created Conns for Accept; its capacity is the backlog.
acceptCh chan *Conn
// doneCh is closed by Close to unblock Accept.
doneCh chan struct {}
// doneOnce makes Close run its shutdown logic only once.
doneOnce sync .Once
// acceptFilter, when non-nil, must return true for the first datagram of an
// unknown remote or no Conn is created for it.
acceptFilter func ([]byte ) bool
// connLock guards conns and serializes queueing into / draining acceptCh.
connLock sync .Mutex
// conns maps rAddr.String() to the live Conn for that remote.
conns map [string ]*Conn
// connWG holds one count for the listener itself (released by Close) plus one
// per Conn returned from Accept (released by Conn.Close). When it reaches
// zero, the goroutine started in Listen closes pConn.
connWG *sync .WaitGroup
// readWG tracks the read loop and the pConn-closing goroutine (Add(2) in Listen).
readWG sync .WaitGroup
// errClose stores the error, if any, returned by pConn.Close.
errClose atomic .Value
// readDoneCh is closed when the read loop exits.
readDoneCh chan struct {}
// errRead stores the error that terminated the read loop; Accept returns it.
errRead atomic .Value
// logger receives trace-level diagnostics from the read path.
logger logging .LeveledLogger
}
// Accept waits for the next connection from a previously unseen remote
// address and returns it as a net.Conn.
//
// It returns ErrClosedListener after Close, or the read loop's terminating
// error once the read loop has exited. Every Conn returned here keeps the
// shared socket open until that Conn is itself closed.
func (l *listener) Accept() (net.Conn, error) {
	select {
	case <-l.doneCh:
		return nil, ErrClosedListener
	case <-l.readDoneCh:
		readErr, _ := l.errRead.Load().(error)
		return nil, readErr
	case accepted := <-l.acceptCh:
		// Balanced by connWG.Done in Conn.Close.
		l.connWG.Add(1)
		return accepted, nil
	}
}
// Close stops the listener from accepting new remotes and unblocks Accept.
//
// Conns still waiting in the backlog (never returned by Accept) are dropped.
// Conns already accepted keep working. The shared socket is closed only once
// the last of them is closed, and that final Conn.Close reports the socket
// error. If no accepted Conns remain, Close waits for the read loop and the
// socket-closing goroutine and returns the socket close error, if any.
// Subsequent calls return nil.
func (l *listener) Close() error {
	var closeErr error
	l.doneOnce.Do(func() {
		// Refuse new remotes before anything else so getConn stops queueing.
		l.accepting.Store(false)
		close(l.doneCh)

		l.connLock.Lock()
	drain:
		for {
			select {
			case pending := <-l.acceptCh:
				// Never handed to the user: drop it from the table.
				close(pending.doneCh)
				delete(l.conns, pending.rAddr.String())
			default:
				break drain
			}
		}
		remaining := len(l.conns)
		l.connLock.Unlock()

		// Release the listener's own hold on the shared socket.
		l.connWG.Done()

		if remaining != 0 {
			// Accepted Conns still use the socket; the last Conn.Close
			// reports the socket close error instead.
			return
		}
		l.readWG.Wait()
		if e, ok := l.errClose.Load().(error); ok {
			closeErr = e
		}
	})
	return closeErr
}
// Addr returns the local address of the listener's shared UDP socket.
func (l *listener) Addr() net.Addr {
	sock := l.pConn
	return sock.LocalAddr()
}
// BatchIOConfig configures optional batched socket I/O for ListenConfig.Listen.
type BatchIOConfig struct {
// Enable wraps the listening socket with NewBatchConn.
Enable bool
// ReadBatchSize is the number of datagrams requested per batched read.
// Batched reads are only used when it is greater than 1 and the wrapped
// socket implements BatchReader; otherwise datagrams are read one at a time.
ReadBatchSize int
// WriteBatchSize is passed to NewBatchConn and must be > 0 when Enable is set
// (Listen returns ErrInvalidBatchConfig otherwise). Presumably the maximum
// number of queued writes per flush — confirm against NewBatchConn.
WriteBatchSize int
// WriteBatchInterval is passed to NewBatchConn and must be > 0 when Enable is
// set. Presumably the periodic flush interval — confirm against NewBatchConn.
WriteBatchInterval time .Duration
}
// ListenConfig stores the options used by ListenConfig.Listen to create a
// UDP listener.
type ListenConfig struct {
// Backlog is the capacity of the queue of new connections waiting for
// Accept. Zero selects defaultListenBacklog (128). Datagrams from new remotes
// that arrive while the queue is full are dropped.
Backlog int
// AcceptFilter, if non-nil, is called with the first datagram from an unknown
// remote address; returning false drops the datagram without creating a
// connection.
AcceptFilter func ([]byte ) bool
// ReadBufferSize, if > 0, is applied with SetReadBuffer on the socket; errors
// from that call are ignored.
ReadBufferSize int
// WriteBufferSize, if > 0, is applied with SetWriteBuffer on the socket;
// errors from that call are ignored.
WriteBufferSize int
// Batch configures optional batched reads and writes.
Batch BatchIOConfig
// LoggerFactory creates the listener's "transport" logger; nil selects
// logging.NewDefaultLoggerFactory().
LoggerFactory logging .LoggerFactory
}
// Listen creates a new UDP listener based on the ListenConfig.
//
// It binds a socket with net.ListenUDP(network, laddr), applies the optional
// socket buffer sizes, optionally wraps the socket for batch I/O, and starts
// two goroutines. The read loop demultiplexes incoming datagrams into
// per-remote Conns. A closer goroutine closes the socket once the listener
// and every accepted Conn have been closed.
//
// A Backlog of zero (and, defensively, a negative value) selects
// defaultListenBacklog. The receiver is never modified, so a single
// ListenConfig may safely be shared by concurrent Listen calls.
//
// ErrInvalidBatchConfig is returned when Batch.Enable is set but
// WriteBatchSize or WriteBatchInterval is not positive. Errors from
// net.ListenUDP are returned unchanged.
func (lc *ListenConfig) Listen(network string, laddr *net.UDPAddr) (net.Listener, error) {
	// Resolve the backlog locally instead of writing it back into lc. Mutating
	// the caller's config races when the config is shared, and a negative
	// capacity would make the make(chan) below panic.
	backlog := lc.Backlog
	if backlog <= 0 {
		backlog = defaultListenBacklog
	}

	if lc.Batch.Enable && (lc.Batch.WriteBatchSize <= 0 || lc.Batch.WriteBatchInterval <= 0) {
		return nil, ErrInvalidBatchConfig
	}

	conn, err := net.ListenUDP(network, laddr)
	if err != nil {
		return nil, err
	}

	// Socket buffer sizing is best-effort. Failures are deliberately ignored
	// so the listener still works with the OS defaults.
	if lc.ReadBufferSize > 0 {
		_ = conn.SetReadBuffer(lc.ReadBufferSize)
	}
	if lc.WriteBufferSize > 0 {
		_ = conn.SetWriteBuffer(lc.WriteBufferSize)
	}

	loggerFactory := lc.LoggerFactory
	if loggerFactory == nil {
		loggerFactory = logging.NewDefaultLoggerFactory()
	}
	logger := loggerFactory.NewLogger("transport")

	l := &listener{
		pConn:        conn,
		acceptCh:     make(chan *Conn, backlog),
		conns:        make(map[string]*Conn),
		doneCh:       make(chan struct{}),
		acceptFilter: lc.AcceptFilter,
		connWG:       &sync.WaitGroup{},
		readDoneCh:   make(chan struct{}),
		logger:       logger,
	}

	if lc.Batch.Enable {
		l.pConn = NewBatchConn(conn, lc.Batch.WriteBatchSize, lc.Batch.WriteBatchInterval)
		l.readBatchSize = lc.Batch.ReadBatchSize
	}

	l.accepting.Store(true)
	// connWG holds one count for the listener itself (released by
	// listener.Close) plus one per accepted Conn (added in Accept, released by
	// Conn.Close).
	l.connWG.Add(1)
	// readWG tracks the read loop and the socket-closing goroutine below.
	l.readWG.Add(2)
	go l.readLoop()
	go func() {
		// Close the shared socket only after the listener and every accepted
		// Conn are closed. This is expected to make the read loop's pending
		// read fail, which ends the read loop.
		l.connWG.Wait()
		if err := l.pConn.Close(); err != nil {
			l.errClose.Store(err)
		}
		l.readWG.Done()
	}()

	return l, nil
}
// Listen creates a UDP listener on laddr using a zero-value ListenConfig:
// default backlog, no accept filter, OS-default socket buffers and no batch
// I/O.
func Listen(network string, laddr *net.UDPAddr) (net.Listener, error) {
	var cfg ListenConfig
	return cfg.Listen(network, laddr)
}
// readLoop is the listener's receive goroutine. It picks batched reads when
// the socket supports them and a batch size above one was configured, and
// falls back to single-datagram reads otherwise. On exit it closes readDoneCh
// (unblocking Accept) and then releases its readWG count.
func (l *listener) readLoop() {
	defer l.readWG.Done()
	defer close(l.readDoneCh)

	batchReader, canBatch := l.pConn.(BatchReader)
	switch {
	case canBatch && l.readBatchSize > 1:
		l.readBatch(batchReader)
	default:
		l.read()
	}
}
// readBatch is the receive loop used when the socket supports batched reads.
// Each ReadBatch call fills up to l.readBatchSize messages, and every received
// datagram is dispatched to its per-remote Conn. The loop ends on the first
// read error (including the socket being closed). That error is stored in
// errRead for Accept and traced, consistent with the single-read loop.
func (l *listener) readBatch(br BatchReader) {
	msgs := make([]ipv4.Message, l.readBatchSize)
	for i := range msgs {
		msg := &msgs[i]
		msg.Buffers = [][]byte{make([]byte, receiveMTU)}
		// NOTE(review): 40 bytes of out-of-band space per message; presumably
		// sized for the control messages the batch reader may return — confirm.
		msg.OOB = make([]byte, 40)
	}

	for {
		n, err := br.ReadBatch(msgs, 0)
		if err != nil {
			l.errRead.Store(err)
			l.logger.Tracef("error reading from connection err=%v", err)
			return
		}
		for i := 0; i < n; i++ {
			msg := &msgs[i]
			l.dispatchMsg(msg.Addr, msg.Buffers[0][:msg.N])
		}
	}
}
// read is the single-datagram receive loop. It reuses one receiveMTU-sized
// buffer and hands each datagram to dispatchMsg. On the first ReadFrom error
// it stores the error in errRead, traces it and returns.
func (l *listener) read() {
	packet := make([]byte, receiveMTU)
	for {
		size, from, readErr := l.pConn.ReadFrom(packet)
		if readErr != nil {
			l.errRead.Store(readErr)
			l.logger.Tracef("error reading from connection err=%v", readErr)
			return
		}
		l.dispatchMsg(from, packet[:size])
	}
}
// dispatchMsg routes one received datagram to the Conn for addr. When addr
// has not been seen before, getConn creates a Conn and queues it for Accept.
//
// The datagram is dropped when:
//   - the listener no longer accepts new remotes,
//   - the accept backlog is full, or
//   - the write into the Conn's buffer fails.
//
// Each of these cases is traced. Datagrams rejected by the accept filter are
// dropped silently.
func (l *listener) dispatchMsg(addr net.Addr, buf []byte) {
	conn, ok, err := l.getConn(addr, buf)
	if err != nil {
		// Without this trace, backlog overflow (ErrListenQueueExceeded) and
		// post-close traffic would be dropped with no diagnostic at all.
		l.logger.Tracef("dropping message addr=%v err=%v", addr, err)
		return
	}
	if !ok {
		// Rejected by the accept filter.
		return
	}
	if _, err := conn.buffer.Write(buf); err != nil {
		l.logger.Tracef("error dispatching message addr=%v err=%v", addr, err)
	}
}
// getConn returns the Conn associated with raddr, creating one on first
// contact. It returns:
//   - (conn, true, nil) when the datagram should be delivered to conn;
//   - (nil, false, nil) when acceptFilter rejected the first datagram;
//   - (nil, false, ErrClosedListener) when raddr is unknown and the listener
//     no longer accepts new remotes;
//   - (nil, false, ErrListenQueueExceeded) when the accept backlog is full.
//
// A new Conn is recorded in conns only after it was successfully queued for
// Accept.
func (l *listener) getConn(raddr net.Addr, buf []byte) (*Conn, bool, error) {
	// Formatting the address is the same for lookup and insert; do it once,
	// outside the lock, since this runs for every received datagram.
	key := raddr.String()

	l.connLock.Lock()
	defer l.connLock.Unlock()

	if conn, ok := l.conns[key]; ok {
		return conn, true, nil
	}
	if isAccepting, ok := l.accepting.Load().(bool); !isAccepting || !ok {
		return nil, false, ErrClosedListener
	}
	if l.acceptFilter != nil && !l.acceptFilter(buf) {
		return nil, false, nil
	}

	// Only this function enqueues (under connLock), so a full buffered queue
	// cannot gain room from another producer. Reject before allocating a Conn
	// that would immediately be discarded (e.g. under a flood of new sources).
	if c := cap(l.acceptCh); c > 0 && len(l.acceptCh) >= c {
		return nil, false, ErrListenQueueExceeded
	}

	conn := l.newConn(raddr)
	select {
	case l.acceptCh <- conn:
		l.conns[key] = conn
		return conn, true, nil
	default:
		return nil, false, ErrListenQueueExceeded
	}
}
// Conn is the net.Conn handed out by the listener's Accept. It carries the
// datagrams exchanged with one remote address over the listener's shared
// UDP socket.
type Conn struct {
// listener is the owning listener; its socket is used for writes.
listener *listener
// rAddr is the remote peer's address; rAddr.String() is its key in listener.conns.
rAddr net .Addr
// buffer receives the datagrams dispatched by the read loop; Read drains it.
buffer *packetio .Buffer
// doneCh is closed by Close, or by listener.Close if the Conn was still
// waiting in the backlog.
doneCh chan struct {}
// doneOnce makes Close run its shutdown logic only once.
doneOnce sync .Once
// writeDeadline is checked by Write and set by SetDeadline/SetWriteDeadline.
writeDeadline *deadline .Deadline
}
// newConn allocates a Conn bound to l for the remote address rAddr, with an
// empty receive buffer, an open done channel and a fresh write deadline.
func (l *listener) newConn(rAddr net.Addr) *Conn {
	c := &Conn{listener: l, rAddr: rAddr}
	c.buffer = packetio.NewBuffer()
	c.doneCh = make(chan struct{})
	c.writeDeadline = deadline.New()
	return c
}
// Read copies the next datagram received from the remote peer into p,
// delegating entirely to the Conn's receive buffer (including deadline and
// close handling).
func (c *Conn) Read(p []byte) (int, error) { return c.buffer.Read(p) }
// Write sends p as one datagram to the Conn's remote address through the
// listener's shared socket.
//
// It returns io.ErrClosedPipe once the Conn has been closed, and
// context.DeadlineExceeded once the write deadline has passed. Otherwise the
// result of the underlying WriteTo is returned unchanged.
func (c *Conn) Write(p []byte) (n int, err error) {
	// The shared socket stays open while any other Conn is alive, so without
	// this check a closed Conn could keep sending to its peer.
	select {
	case <-c.doneCh:
		return 0, io.ErrClosedPipe
	default:
	}

	select {
	case <-c.writeDeadline.Done():
		return 0, context.DeadlineExceeded
	default:
	}

	return c.listener.pConn.WriteTo(p, c.rAddr)
}
// Close closes the connection and removes it from the listener's table.
//
// It runs only once; later calls return nil. If this was the last tracked
// Conn and the listener has already been closed, Close waits for the
// listener's background goroutines and returns the error from closing the
// shared socket, if any. Otherwise, or if that error is nil, it returns the
// error from closing the receive buffer.
func (c *Conn) Close() error {
	var closeErr error
	c.doneOnce.Do(func() {
		l := c.listener
		// Release this Conn's hold on the shared socket.
		l.connWG.Done()
		close(c.doneCh)

		l.connLock.Lock()
		delete(l.conns, c.rAddr.String())
		remaining := len(l.conns)
		l.connLock.Unlock()

		accepting, stored := l.accepting.Load().(bool)
		if lastAfterShutdown := remaining == 0 && stored && !accepting; lastAfterShutdown {
			// The socket is being closed now; wait for it and report its error.
			l.readWG.Wait()
			if e, ok := l.errClose.Load().(error); ok {
				closeErr = e
			}
		}

		if bufErr := c.buffer.Close(); bufErr != nil && closeErr == nil {
			closeErr = bufErr
		}
	})
	return closeErr
}
// LocalAddr returns the local address of the listener socket shared by all
// of the listener's Conns.
func (c *Conn) LocalAddr() net.Addr {
	shared := c.listener.pConn
	return shared.LocalAddr()
}
// RemoteAddr returns the address of the peer this Conn exchanges datagrams with.
func (c *Conn) RemoteAddr() net.Addr { return c.rAddr }
// SetDeadline sets both the write deadline and the read deadline to t,
// returning any error from setting the read deadline.
func (c *Conn) SetDeadline(t time.Time) error {
	if err := c.SetWriteDeadline(t); err != nil {
		return err
	}
	return c.SetReadDeadline(t)
}
// SetReadDeadline sets the deadline for pending and future Read calls by
// forwarding it to the receive buffer.
func (c *Conn) SetReadDeadline(t time.Time) error { return c.buffer.SetReadDeadline(t) }
// SetWriteDeadline sets the deadline after which Write fails with
// context.DeadlineExceeded. It always returns nil.
func (c *Conn) SetWriteDeadline(t time.Time) error {
	wd := c.writeDeadline
	wd.Set(t)
	return nil
}
The pages are generated with Golds v0.8.2 (GOOS=linux GOARCH=amd64).
Golds is a Go 101 project developed by Tapir Liu.
PRs and bug reports are welcome and can be submitted to the issue list.
Please follow @zigo_101 (reachable from the left QR code) to get the latest news of Golds.