package stun
import (
"crypto/tls"
"errors"
"fmt"
"io"
"log"
"net"
"runtime"
"strconv"
"sync"
"sync/atomic"
"time"
"github.com/pion/dtls/v2"
"github.com/pion/transport/v2"
"github.com/pion/transport/v2/stdnet"
)
// ErrUnsupportedURI is returned by DialURI when the URI's scheme/protocol
// combination is not one it knows how to dial.
var ErrUnsupportedURI = errors.New("invalid schema or transport")
// Dial connects to address on the named network using net.Dial and returns a
// Client with default options over that connection.
//
// If the Client cannot be created, the freshly dialed connection is closed
// before the NewClient error is returned, so no socket is leaked.
func Dial(network, address string) (*Client, error) {
	conn, err := net.Dial(network, address)
	if err != nil {
		return nil, err
	}
	client, err := NewClient(conn)
	if err != nil {
		_ = conn.Close() // best effort; the NewClient error is the one worth reporting
		return nil, err
	}
	return client, nil
}
// DialConfig carries the transport settings used by DialURI.
type DialConfig struct {
	// DTLSConfig is copied for turns: URIs over UDP; DialURI sets the copy's
	// ServerName to the URI host.
	DTLSConfig dtls.Config

	// TLSConfig is copied for turns:/stuns: URIs over TCP; DialURI sets the
	// copy's ServerName to the URI host.
	TLSConfig tls.Config

	// Net is used to dial the underlying connection. When nil, DialURI
	// creates one with stdnet.NewNet.
	Net transport.Net
}
// DialURI connects to the STUN/TURN server described by uri and returns a
// Client bound to that connection.
//
// The transport is selected from the URI scheme and protocol:
//   - stun:                 UDP
//   - turn:                 UDP, or TCP when uri.Proto is ProtoTypeTCP
//   - turns over UDP:       DTLS, using a copy of cfg.DTLSConfig
//   - turns/stuns over TCP: TLS, using a clone of cfg.TLSConfig
//
// In the (D)TLS cases, ServerName is set to uri.Host on the copy; the caller's
// configuration is never modified. A nil cfg behaves like an empty DialConfig,
// and a nil cfg.Net falls back to stdnet.NewNet. A nil uri, or any other
// scheme/protocol combination, yields ErrUnsupportedURI.
//
// Every connection opened here is closed again if a later step fails.
func DialURI(uri *URI, cfg *DialConfig) (*Client, error) {
	if uri == nil {
		return nil, ErrUnsupportedURI
	}
	if cfg == nil {
		cfg = &DialConfig{}
	}
	nw := cfg.Net
	if nw == nil {
		var err error
		if nw, err = stdnet.NewNet(); err != nil {
			return nil, fmt.Errorf("failed to create net: %w", err)
		}
	}
	addr := net.JoinHostPort(uri.Host, strconv.Itoa(uri.Port))

	var conn Connection
	switch {
	case uri.Scheme == SchemeTypeSTUN:
		udpConn, err := nw.Dial("udp", addr)
		if err != nil {
			return nil, fmt.Errorf("failed to dial: %w", err)
		}
		conn = udpConn
	case uri.Scheme == SchemeTypeTURN:
		network := "udp"
		if uri.Proto == ProtoTypeTCP {
			network = "tcp"
		}
		rawConn, err := nw.Dial(network, addr)
		if err != nil {
			return nil, fmt.Errorf("failed to dial: %w", err)
		}
		conn = rawConn
	case uri.Scheme == SchemeTypeTURNS && uri.Proto == ProtoTypeUDP:
		dtlsCfg := cfg.DTLSConfig // value copy so the caller's config is untouched
		dtlsCfg.ServerName = uri.Host
		udpConn, err := nw.Dial("udp", addr)
		if err != nil {
			return nil, fmt.Errorf("failed to dial: %w", err)
		}
		dtlsConn, err := dtls.Client(udpConn, &dtlsCfg)
		if err != nil {
			_ = udpConn.Close() // don't leak the socket when the handshake fails
			return nil, fmt.Errorf("failed to connect to '%s': %w", addr, err)
		}
		conn = dtlsConn
	case (uri.Scheme == SchemeTypeTURNS || uri.Scheme == SchemeTypeSTUNS) && uri.Proto == ProtoTypeTCP:
		// Clone instead of a value copy: tls.Config embeds a mutex, and Clone
		// is the documented way to duplicate a possibly in-use config.
		tlsCfg := cfg.TLSConfig.Clone()
		tlsCfg.ServerName = uri.Host
		tcpConn, err := nw.Dial("tcp", addr)
		if err != nil {
			return nil, fmt.Errorf("failed to dial: %w", err)
		}
		conn = tls.Client(tcpConn, tlsCfg)
	default:
		return nil, ErrUnsupportedURI
	}

	client, err := NewClient(conn)
	if err != nil {
		_ = conn.Close() // best effort; report the NewClient failure
		return nil, err
	}
	return client, nil
}
// ErrNoConnection is returned by NewClient when it ends up without a
// Connection to use.
var ErrNoConnection = errors.New(
	"no connection provided",
)
// ClientOption customizes a Client inside NewClient. Options run in the order
// given, before NewClient fills in its remaining defaults (agent, collector).
type ClientOption func(client *Client)
// WithHandler sets the client-wide handler. It receives agent events whose
// transaction ID is not tracked by the client, except ErrTransactionStopped
// events.
func WithHandler(h Handler) ClientOption {
	return func(client *Client) { client.handler = h }
}
// WithRTO sets the initial retransmission timeout. A transaction's deadline
// for attempt n is (n+1)*rto after it was (re)started.
func WithRTO(rto time.Duration) ClientOption {
	nanos := int64(rto)
	return func(client *Client) { client.rto = nanos }
}
// WithClock sets the Clock the client uses for transaction timestamps and
// deadlines, and that the default collector uses for its ticks.
func WithClock(clock Clock) ClientOption {
	return func(client *Client) { client.clock = clock }
}
// WithTimeoutRate sets the interval passed to the collector's Start, i.e. how
// often the agent's Collect is invoked to check for expired transactions.
func WithTimeoutRate(d time.Duration) ClientOption {
	return func(client *Client) { client.rtoRate = d }
}
// WithAgent makes the client use a instead of creating one with NewAgent.
// NewClient installs the client's event callback on it via SetHandler.
func WithAgent(a ClientAgent) ClientOption {
	return func(client *Client) { client.a = a }
}
// WithCollector replaces the default ticker-based Collector. NewClient calls
// its Start and Client.Close calls its Close.
func WithCollector(coll Collector) ClientOption {
	return func(client *Client) { client.collector = coll }
}
// WithNoConnClose stops Client.Close from closing the Connection; the caller
// stays responsible for it.
//
// NOTE(review): Close still waits for the read loop, which may be blocked in
// Read on this connection until the caller closes it.
func WithNoConnClose() ClientOption {
	return func(client *Client) {
		client.closeConn = false
	}
}
// WithNoRetransmit is itself a ClientOption (pass it without calling it). It
// disables retransmissions, so the first error event completes a transaction.
// If the RTO is zero when this option runs (e.g. WithRTO(0) came earlier), it
// is replaced by defaultMaxAttempts*defaultRTO.
//
// NOTE(review): with the default, non-zero RTO the single attempt times out
// after one RTO rather than the combined retry budget — confirm intended.
func WithNoRetransmit(c *Client) {
	c.maxAttempts = 0
	if c.rto != 0 {
		return
	}
	c.rto = int64(defaultRTO) * defaultMaxAttempts
}
// Defaults applied by NewClient unless overridden by options.
const (
	// defaultTimeoutRate is the collector tick interval (see WithTimeoutRate).
	defaultTimeoutRate = 5 * time.Millisecond

	// defaultRTO is the initial retransmission timeout (see WithRTO).
	defaultRTO = 300 * time.Millisecond

	// defaultMaxAttempts bounds retransmissions after the initial send.
	defaultMaxAttempts = 7
)
// NewClient builds a Client over conn and applies options in order.
//
// Unless an option overrides them, the client uses the system clock,
// defaultRTO, defaultTimeoutRate, defaultMaxAttempts, an agent from
// NewAgent(nil), and a ticker-based collector, and it closes conn on Close.
// NewClient starts the collector and a background read loop; call Close to
// stop both. It returns ErrNoConnection when no connection is set.
func NewClient(conn Connection, options ...ClientOption) (*Client, error) {
	client := &Client{
		c:           conn,
		close:       make(chan struct{}),
		clock:       systemClock(),
		rto:         int64(defaultRTO),
		rtoRate:     defaultTimeoutRate,
		maxAttempts: defaultMaxAttempts,
		closeConn:   true,
		t:           make(map[transactionID]*clientTransaction, 100),
	}
	for i := range options {
		options[i](client)
	}
	if client.c == nil {
		return nil, ErrNoConnection
	}

	if client.a == nil {
		client.a = NewAgent(nil)
	}
	if err := client.a.SetHandler(client.handleAgentCallback); err != nil {
		return nil, err
	}

	if client.collector == nil {
		client.collector = &tickerCollector{
			close: make(chan struct{}),
			clock: client.clock,
		}
	}
	collect := func(now time.Time) {
		// ErrAgentClosed is tolerated (shutdown); any other error panics.
		closedOrPanic(client.a.Collect(now))
	}
	if err := client.collector.Start(client.rtoRate, collect); err != nil {
		return nil, err
	}

	client.wg.Add(1)
	go client.readUntilClosed()
	runtime.SetFinalizer(client, clientFinalizer)
	return client, nil
}
// clientFinalizer is registered by NewClient via runtime.SetFinalizer. It
// closes a client that was never closed explicitly and logs that fact;
// clients that were already closed are ignored silently.
func clientFinalizer(c *Client) {
	if c == nil {
		return
	}
	switch err := c.Close(); {
	case errors.Is(err, ErrClientClosed):
		// Close was already called by the owner; nothing to report.
	case err == nil:
		log.Println("client: called finalizer on non-closed client")
	default:
		log.Println("client: called finalizer on non-closed client:", err)
	}
}
// Connection is the transport a Client reads STUN messages from and writes
// them to. Its method set is exactly Read, Write and Close.
type Connection interface {
	io.ReadWriteCloser
}
type ClientAgent interface {
Process (*Message ) error
Close () error
Start (id [TransactionIDSize ]byte , deadline time .Time ) error
Stop (id [TransactionIDSize ]byte ) error
Collect (time .Time ) error
SetHandler (h Handler ) error
}
// Client sends STUN messages over a Connection and tracks the resulting
// transactions through a ClientAgent, retransmitting on error events up to
// maxAttempts times. Create it with NewClient (or Dial/DialURI) and release it
// with Close.
type Client struct {
	// rto is the retransmission timeout in nanoseconds (a time.Duration); it
	// is read and written with sync/atomic (see SetRTO and Start). 64-bit
	// atomics need 64-bit alignment on 32-bit platforms, which sync/atomic only
	// guarantees for the first word of an allocated struct, so keeping rto
	// first is presumably deliberate — do not reorder.
	rto int64

	// a tracks transaction deadlines and processes incoming messages.
	a ClientAgent

	// c is the transport messages are read from and written to.
	c Connection

	// close is closed by Close to tell readUntilClosed to stop.
	close chan struct{}

	// rtoRate is the tick interval passed to collector.Start.
	rtoRate time.Duration

	// maxAttempts bounds retransmissions; read atomically in handleAgentCallback.
	maxAttempts int32

	// closed is set once by Close; guarded by mux.
	closed bool

	// closeConn reports whether Close also closes c (false with WithNoConnClose).
	closeConn bool

	// wg waits for the readUntilClosed goroutine.
	wg sync.WaitGroup

	// clock supplies timestamps for transactions and the default collector.
	clock Clock

	// handler receives agent events for IDs not found in t (optional).
	handler Handler

	// collector periodically drives a.Collect.
	collector Collector

	// t holds the in-flight transactions keyed by ID; guarded by mux.
	t map[transactionID]*clientTransaction

	// mux guards closed and t.
	mux sync.RWMutex
}
// clientTransaction is the Client-side state of one in-flight request.
type clientTransaction struct {
	// id is the transaction ID and the key in Client.t.
	id transactionID

	// h receives the outcome; see handle.
	h Handler

	// raw is a copy of the encoded request, kept for retransmission.
	raw []byte

	// start is when the transaction was first started.
	start time.Time

	// rto is the retransmission timeout captured when the transaction started.
	rto time.Duration

	// attempt counts retransmissions so far; see nextTimeout.
	attempt int32

	// calls counts handle invocations (atomically) so h runs at most once.
	calls int32
}
// handle forwards e to the transaction's handler on the first call only;
// every later call is silently dropped.
func (t *clientTransaction) handle(e Event) {
	if n := atomic.AddInt32(&t.calls, 1); n != 1 {
		return
	}
	t.h(e)
}
// clientTransactionPool recycles clientTransaction values together with their
// raw buffers (1500 bytes when freshly allocated).
var clientTransactionPool = &sync.Pool{
	New: func() interface{} {
		t := new(clientTransaction)
		t.raw = make([]byte, 1500)
		return t
	},
}
// acquireClientTransaction takes a clientTransaction from the pool. Callers
// must initialize every field they rely on; see Client.Start.
func acquireClientTransaction() *clientTransaction {
	v := clientTransactionPool.Get()
	return v.(*clientTransaction)
}
// putClientTransaction resets t and returns it to clientTransactionPool.
//
// The raw buffer is truncated but keeps its capacity for reuse. The handler
// is cleared so the pool does not keep the caller's callback (and anything it
// captures) reachable. calls and rto are zeroed too, so a recycled value
// carries no state from its previous use. t must not be used afterwards.
func putClientTransaction(t *clientTransaction) {
	t.id = transactionID{}
	t.h = nil
	t.start = time.Time{}
	t.rto = 0
	t.attempt = 0
	t.calls = 0
	t.raw = t.raw[:0]
	clientTransactionPool.Put(t)
}
// nextTimeout returns the deadline for the current attempt: now plus
// (attempt+1) RTOs, so each retransmission waits one RTO longer than the last.
func (t *clientTransaction) nextTimeout(now time.Time) time.Time {
	wait := t.rto * time.Duration(t.attempt+1)
	return now.Add(wait)
}
// start registers t in the in-flight table. It fails with ErrClientClosed
// after Close, or ErrTransactionExists if t.id is already registered.
func (c *Client) start(t *clientTransaction) error {
	c.mux.Lock()
	defer c.mux.Unlock()
	if c.closed {
		return ErrClientClosed
	}
	if _, dup := c.t[t.id]; dup {
		return ErrTransactionExists
	}
	c.t[t.id] = t
	return nil
}
// Clock supplies the current time to a Client; inject a custom one with
// WithClock.
type Clock interface {
	// Now returns the current time.
	Now() time.Time
}
// systemClockService is the default Clock, backed by time.Now.
type systemClockService struct{}

// Now returns time.Now().
func (systemClockService) Now() time.Time {
	return time.Now()
}

// systemClock returns the default Clock implementation.
func systemClock() systemClockService {
	var s systemClockService
	return s
}
// SetRTO changes the retransmission timeout for transactions started after
// this call. Transactions already in flight keep the RTO they captured.
func (c *Client) SetRTO(rto time.Duration) {
	nanos := int64(rto)
	atomic.StoreInt64(&c.rto, nanos)
}
// StopErr is returned when a transaction had to be stopped because of an
// error (Cause) and stopping it failed as well (Err).
type StopErr struct {
	// Err is the error returned by the agent's Stop.
	Err error

	// Cause is the error that triggered the stop.
	Cause error
}

// Error implements error; nil parts are rendered as "<nil>".
func (e StopErr) Error() string {
	cause, stop := sprintErr(e.Cause), sprintErr(e.Err)
	return fmt.Sprintf("error while stopping due to %s: %s", cause, stop)
}
// CloseErr reports what went wrong while closing a Client; either field may
// be nil.
type CloseErr struct {
	// AgentErr is the result of closing the ClientAgent.
	AgentErr error

	// ConnectionErr is the result of closing the Connection; it stays nil
	// when the client does not own the connection.
	ConnectionErr error
}
// sprintErr returns err's message, or "<nil>" when err is nil.
func sprintErr(err error) string {
	if err != nil {
		return err.Error()
	}
	return "<nil>"
}
// Error implements error. Both parts are always printed; nil shows as "<nil>".
func (c CloseErr) Error() string {
	conn := sprintErr(c.ConnectionErr)
	agent := sprintErr(c.AgentErr)
	return fmt.Sprintf("failed to close: %s (connection), %s (agent)", conn, agent)
}
// readUntilClosed is the client's receive loop, started by NewClient. It reads
// one message at a time into a reused 1024-byte buffer and hands every
// successfully read message to the agent.
//
// The loop ends when c.close is closed (checked between reads), when the agent
// reports ErrAgentClosed, or when reading reports io.EOF. Other read or decode
// errors are skipped, because they may be transient (a single bad datagram,
// say). A stream at EOF typically returns EOF on every further Read, so
// continuing would spin this goroutine at full CPU until Close.
// NOTE(review): assumes Message.ReadFrom passes the reader's error through
// (directly or wrapped); errors.Is covers both.
func (c *Client) readUntilClosed() {
	defer c.wg.Done()
	m := new(Message)
	m.Raw = make([]byte, 1024)
	for {
		select {
		case <-c.close:
			return
		default:
		}
		_, err := m.ReadFrom(c.c)
		if errors.Is(err, io.EOF) {
			return
		}
		if err != nil {
			continue
		}
		if pErr := c.a.Process(m); errors.Is(pErr, ErrAgentClosed) {
			return
		}
	}
}
// closedOrPanic ignores nil and ErrAgentClosed (expected once the agent is
// closed) and panics on any other error.
func closedOrPanic(err error) {
	switch {
	case err == nil, errors.Is(err, ErrAgentClosed):
		return
	default:
		panic(err)
	}
}
// tickerCollector is the default Collector: one goroutine driven by a
// time.Ticker.
type tickerCollector struct {
	// close is closed by Close to stop the goroutine.
	close chan struct{}

	// wg waits for the goroutine started by Start.
	wg sync.WaitGroup

	// clock supplies the time passed to the collect callback.
	clock Clock
}
// Collector periodically invokes a callback with the current time. The Client
// uses it to drive the agent's Collect.
type Collector interface {
	// Start begins calling f roughly every rate until Close.
	Start(rate time.Duration, f func(now time.Time)) error
	// Close stops the collector.
	Close() error
}
// Start launches a goroutine that calls f(a.clock.Now()) on every tick of a
// ticker with period rate, until Close is called.
//
// time.NewTicker panics for a non-positive period (e.g. WithTimeoutRate(0)), so
// such a rate is rejected with an error, which NewClient passes to its caller.
func (a *tickerCollector) Start(rate time.Duration, f func(now time.Time)) error {
	if rate <= 0 {
		return fmt.Errorf("collector rate must be positive, got %s", rate)
	}
	ticker := time.NewTicker(rate)
	a.wg.Add(1)
	go func() {
		defer a.wg.Done()
		defer ticker.Stop()
		for {
			select {
			case <-a.close:
				return
			case <-ticker.C:
				f(a.clock.Now())
			}
		}
	}()
	return nil
}
// Close signals the collector goroutine to stop and waits until it has
// exited. It must be called at most once; a second call would close an
// already-closed channel and panic.
func (a *tickerCollector) Close() error {
	close(a.close)
	a.wg.Wait() // the goroutine returns as soon as it observes the closed channel
	return nil
}
// ErrClientClosed is returned by Close on an already-closed client, and by
// Start (and the methods built on it) once the client has been closed.
var ErrClientClosed = errors.New(
	"client is closed",
)
// Close shuts the client down. After the first call it returns
// ErrClientClosed; for a nil or uninitialized Client it returns
// ErrClientNotInitialized.
//
// Teardown order:
//  1. Mark the client closed, so new transactions are rejected.
//  2. Stop the collector.
//  3. Close the agent.
//  4. Close the connection, if the client owns it (see WithNoConnClose).
//  5. Signal the read loop and wait for it.
//
// Every step runs even if stopping the collector fails, so the agent,
// connection and read goroutine are never leaked. A collector error is then
// returned as-is. Otherwise agent and connection errors are reported together
// in a CloseErr.
//
// NOTE(review): with WithNoConnClose the read loop may be blocked in Read, in
// which case the final wait lasts until that Read returns.
func (c *Client) Close() error {
	if err := c.checkInit(); err != nil {
		return err
	}
	c.mux.Lock()
	if c.closed {
		c.mux.Unlock()
		return ErrClientClosed
	}
	c.closed = true
	c.mux.Unlock()

	// Keep going on failure; the error is reported after full teardown.
	collectorErr := c.collector.Close()

	agentErr := c.a.Close()
	var connErr error
	if c.closeConn {
		connErr = c.c.Close()
	}
	close(c.close)
	c.wg.Wait()

	if collectorErr != nil {
		return collectorErr
	}
	if agentErr == nil && connErr == nil {
		return nil
	}
	return CloseErr{
		AgentErr:      agentErr,
		ConnectionErr: connErr,
	}
}
// Indicate writes m without registering a transaction, so no response or
// timeout is tracked. It is equivalent to Start(m, nil).
func (c *Client) Indicate(m *Message) error {
	var noHandler Handler
	return c.Start(m, noHandler)
}
// callbackWaitHandler turns a transaction callback into a blocking call for
// Do. HandleEvent runs the callback and signals cond; wait blocks until that
// has happened. Instances are pooled in callbackWaitHandlerPool.
type callbackWaitHandler struct {
	// handler is the HandleEvent method value, bound once by setCallback.
	handler Handler

	// callback is the user function for the current Do call.
	callback func(event Event)

	// cond guards callback and processed and signals completion.
	cond *sync.Cond

	// processed reports that callback ran since the last wait.
	processed bool
}
// HandleEvent runs the registered callback with e under the lock, marks the
// event as processed, and wakes any goroutine blocked in wait. It panics if no
// callback is set.
func (s *callbackWaitHandler) HandleEvent(e Event) {
	mu := s.cond.L
	mu.Lock()
	cb := s.callback
	if cb == nil {
		panic("s.callback is nil")
	}
	cb(e)
	s.processed = true
	s.cond.Broadcast()
	mu.Unlock()
}
// wait blocks until HandleEvent has run, then resets the handler (processed
// flag and callback) so it can be returned to the pool.
func (s *callbackWaitHandler) wait() {
	mu := s.cond.L
	mu.Lock()
	for !s.processed {
		s.cond.Wait()
	}
	s.processed, s.callback = false, nil
	mu.Unlock()
}
// setCallback installs f for the next event and, the first time only, binds
// s.handler to s.HandleEvent. It panics if f is nil.
func (s *callbackWaitHandler) setCallback(f func(event Event)) {
	if f == nil {
		panic("f is nil")
	}
	mu := s.cond.L
	mu.Lock()
	if s.handler == nil {
		// The method value is reused across pool round-trips.
		s.handler = s.HandleEvent
	}
	s.callback = f
	mu.Unlock()
}
// callbackWaitHandlerPool recycles callbackWaitHandler values used by Do; each
// new value gets its own mutex-backed condition variable.
var callbackWaitHandlerPool = sync.Pool{
	New: func() interface{} {
		mu := new(sync.Mutex)
		return &callbackWaitHandler{cond: sync.NewCond(mu)}
	},
}
// ErrClientNotInitialized is returned by Client methods called on a nil
// Client or one that was not built by NewClient.
var ErrClientNotInitialized = errors.New(
	"client not initialized",
)
// checkInit reports ErrClientNotInitialized for a nil Client or one missing
// its connection, agent or close channel.
func (c *Client) checkInit() error {
	switch {
	case c == nil, c.c == nil, c.a == nil, c.close == nil:
		return ErrClientNotInitialized
	}
	return nil
}
// Do sends m and blocks until f has been called with the transaction's
// outcome. With a nil f it behaves like Indicate and returns right after
// writing. If the transaction cannot be started, that error is returned and f
// is not waited for.
func (c *Client) Do(m *Message, f func(Event)) error {
	if err := c.checkInit(); err != nil {
		return err
	}
	if f == nil {
		return c.Indicate(m)
	}
	waiter := callbackWaitHandlerPool.Get().(*callbackWaitHandler)
	waiter.setCallback(f)
	defer callbackWaitHandlerPool.Put(waiter)
	if err := c.Start(m, waiter.handler); err != nil {
		return err
	}
	waiter.wait()
	return nil
}
// delete removes id from the in-flight transaction table, if present.
func (c *Client) delete(id transactionID) {
	c.mux.Lock()
	defer c.mux.Unlock()
	if c.t == nil {
		return
	}
	delete(c.t, id)
}
// buffer holds a pooled scratch byte slice.
type buffer struct {
	buf []byte
}

// bufferPool supplies scratch buffers (2048 bytes when freshly allocated),
// used for retransmitting requests.
var bufferPool = &sync.Pool{
	New: func() interface{} {
		b := new(buffer)
		b.buf = make([]byte, 2048)
		return b
	},
}
// handleAgentCallback is the Handler that NewClient installs on the agent. It
// receives the agent's events for transaction IDs. Presumably these are
// responses from Process and timeouts from Collect; the agent itself is not
// visible here.
//
// Flow:
//   - If the client is closed, the event is dropped.
//   - The matching transaction is removed from c.t under the lock, so from
//     here on this goroutine alone owns t.
//   - Events for unknown IDs go to the client-wide handler, if one is set.
//     ErrTransactionStopped events are skipped; presumably these are what the
//     agent emits after Start/this function stop a transaction.
//   - On success (e.Error == nil), or on an error once t.attempt has reached
//     maxAttempts, the transaction completes: h runs and t goes back to the pool.
//   - Otherwise the request is retransmitted. t is re-registered in c.t and
//     with the agent under a later deadline, then written again. If any step
//     fails, the transaction completes with that error.
func (c *Client) handleAgentCallback(e Event) {
	c.mux.Lock()
	if c.closed {
		c.mux.Unlock()
		return
	}
	t, found := c.t[e.TransactionID]
	if found {
		delete(c.t, t.id)
	}
	c.mux.Unlock()
	if !found {
		if c.handler != nil && !errors.Is(e.Error, ErrTransactionStopped) {
			c.handler(e)
		}
		return
	}
	if atomic.LoadInt32(&c.maxAttempts) <= t.attempt || e.Error == nil {
		t.handle(e)
		putClientTransaction(t)
		return
	}
	t.attempt++
	// Copy the request bytes before re-registering t: once t is back in c.t,
	// another event may complete it and putClientTransaction truncates t.raw
	// (and the pool may hand the slice to a new transaction) while we write.
	// NOTE(review): the copy is limited to cap(b.buf) (2048 for fresh buffers),
	// so a longer request would be retransmitted truncated.
	b := bufferPool.Get().(*buffer)
	b.buf = b.buf[:copy(b.buf[:cap(b.buf)], t.raw)]
	defer bufferPool.Put(b)
	// Capture id now; after c.start, t may be recycled concurrently.
	var (
		now     = c.clock.Now()
		timeOut = t.nextTimeout(now)
		id      = t.id
	)
	if startErr := c.start(t); startErr != nil {
		// NOTE(review): if startErr is ErrTransactionExists, another
		// transaction now owns id, and this delete removes *that* entry.
		c.delete(id)
		e.Error = startErr
		t.handle(e)
		putClientTransaction(t)
		return
	}
	if startErr := c.a.Start(id, timeOut); startErr != nil {
		c.delete(id)
		e.Error = startErr
		t.handle(e)
		putClientTransaction(t)
		return
	}
	_, writeErr := c.c.Write(b.buf)
	if writeErr != nil {
		c.delete(id)
		e.Error = writeErr
		if stopErr := c.a.Stop(id); stopErr != nil {
			e.Error = StopErr{
				Err:   stopErr,
				Cause: writeErr,
			}
		}
		t.handle(e)
		putClientTransaction(t)
		return
	}
}
// Start writes m to the connection. When h is non-nil, it first registers a
// transaction whose outcome is reported to h at most once: a response, an
// error, or the final error after retransmissions. With a nil h the message is
// written untracked, as by Indicate.
//
// Possible errors:
//   - ErrClientClosed after Close.
//   - ErrTransactionExists if m.TransactionID is already in flight.
//   - The agent's Start error.
//   - The write error. If the write fails and the agent also fails to stop the
//     transaction, a StopErr carrying both errors is returned instead.
//
// If registration fails, the transaction is taken out of the in-flight table
// and returned to the pool. Otherwise it would stay in the table, and the ID
// would keep failing with ErrTransactionExists.
func (c *Client) Start(m *Message, h Handler) error {
	if err := c.checkInit(); err != nil {
		return err
	}
	c.mux.RLock()
	closed := c.closed
	c.mux.RUnlock()
	if closed {
		return ErrClientClosed
	}
	if h != nil {
		t := acquireClientTransaction()
		t.id = m.TransactionID
		t.start = c.clock.Now()
		t.h = h
		t.rto = time.Duration(atomic.LoadInt64(&c.rto))
		t.attempt = 0
		t.raw = append(t.raw[:0], m.Raw...)
		t.calls = 0
		d := t.nextTimeout(t.start)
		if err := c.start(t); err != nil {
			// t was never added to c.t, so nothing else can reference it.
			putClientTransaction(t)
			return err
		}
		if err := c.a.Start(m.TransactionID, d); err != nil {
			// Undo the registration, but only if the entry is still ours:
			// handleAgentCallback may already have taken it (e.g. for an
			// event on an agent-side transaction with the same ID). In that
			// case t belongs to that goroutine and must not be recycled here.
			c.mux.Lock()
			owned := c.t[m.TransactionID] == t
			if owned {
				delete(c.t, m.TransactionID)
			}
			c.mux.Unlock()
			if owned {
				putClientTransaction(t)
			}
			return err
		}
	}
	_, err := m.WriteTo(c.c)
	if err != nil && h != nil {
		c.delete(m.TransactionID)
		if stopErr := c.a.Stop(m.TransactionID); stopErr != nil {
			return StopErr{
				Err:   stopErr,
				Cause: err,
			}
		}
	}
	return err
}
The pages are generated with Golds v0.8.2 (GOOS=linux GOARCH=amd64).
Golds is a Go 101 project developed by Tapir Liu.
PRs and bug reports are welcome and can be submitted to the issue list.
Please follow @zigo_101 (reachable from the left QR code) to get the latest news about Golds.