package client
import (
	"encoding/binary"
	"errors"
	"fmt"
	"io"
	"math"
	"net"
	"time"

	"github.com/pion/stun/v3"
	"github.com/pion/transport/v3"
	"github.com/pion/turn/v4/internal/proto"
)
// Compile-time assertions that *TCPAllocation satisfies both the
// transport.TCPListener and transport.Dialer interfaces. The blank
// identifiers mean nothing is kept at runtime.
var (
_ transport .TCPListener = (*TCPAllocation )(nil )
_ transport .Dialer = (*TCPAllocation )(nil )
)
// noDeadline returns the zero time.Time, which SetDeadline treats as the
// sentinel meaning "no deadline".
func noDeadline() time.Time {
	var zero time.Time

	return zero
}
// TCPAllocation is a TURN allocation used to relay TCP connections. Through
// its methods it can dial peers via the server (Dial, DialTCP, ...) and accept
// connection attempts announced by the server (Accept, AcceptTCP, ...).
type TCPAllocation struct {
// connAttemptCh queues connection attempts pushed by HandleConnectionAttempt
// until AcceptTCPWithConn consumes them.
connAttemptCh chan *connectionAttempt
// acceptTimer bounds how long AcceptTCPWithConn waits; it is re-armed by
// SetDeadline.
acceptTimer *time .Timer
// allocation is the embedded allocation state (client, addresses,
// credentials, permission map, refresh timers, logger).
allocation
}
// NewTCPAllocation builds a TCPAllocation from config and starts its two
// periodic refresh timers (allocation refresh at half the lifetime, and
// permission refresh at permRefreshInterval).
func NewTCPAllocation(config *AllocationConfig) *TCPAllocation {
	alloc := new(TCPAllocation)

	// Room for a handful of pending connection attempts; the accept timer is
	// armed with the largest possible duration so it does not fire until
	// SetDeadline re-arms it.
	alloc.connAttemptCh = make(chan *connectionAttempt, 10)
	alloc.acceptTimer = time.NewTimer(time.Duration(math.MaxInt64))

	// Populate the embedded allocation in place.
	base := &alloc.allocation
	base.client = config.Client
	base.relayedAddr = config.RelayedAddr
	base.serverAddr = config.ServerAddr
	base.username = config.Username
	base.realm = config.Realm
	base.permMap = newPermissionMap()
	base.integrity = config.Integrity
	base._nonce = config.Nonce
	base._lifetime = config.Lifetime
	base.net = config.Net
	base.log = config.Log

	alloc.log.Debugf("Initial lifetime: %d seconds", int(alloc.lifetime().Seconds()))

	alloc.refreshAllocTimer = NewPeriodicTimer(timerIDRefreshAlloc, alloc.onRefreshTimers, alloc.lifetime()/2)
	alloc.refreshPermsTimer = NewPeriodicTimer(timerIDRefreshPerms, alloc.onRefreshTimers, permRefreshInterval)

	if started := alloc.refreshAllocTimer.Start(); started {
		alloc.log.Debug("Started refreshAllocTimer")
	}
	if started := alloc.refreshPermsTimer.Start(); started {
		alloc.log.Debug("Started refreshPermsTimer")
	}

	return alloc
}
// Connect sends a Connect request for peer to the TURN server and returns the
// connection ID carried in the response. An error response from the server is
// reported as an error that includes the message type and, when present, the
// error code attribute.
func (a *TCPAllocation) Connect(peer net.Addr) (proto.ConnectionID, error) {
	msg, err := stun.Build(
		stun.TransactionID,
		stun.NewType(stun.MethodConnect, stun.ClassRequest),
		addr2PeerAddress(peer),
		a.username,
		a.realm,
		a.nonce(),
		a.integrity,
		stun.Fingerprint,
	)
	if err != nil {
		return 0, err
	}

	a.log.Debugf("Send connect request (peer=%v)", peer)
	trRes, err := a.client.PerformTransaction(msg, a.serverAddr, false)
	if err != nil {
		return 0, err
	}

	res := trRes.Msg
	if res.Type.Class == stun.ClassErrorResponse {
		var code stun.ErrorCodeAttribute
		if getErr := code.GetFrom(res); getErr != nil {
			// No usable error code attribute: report the message type only.
			return 0, fmt.Errorf("%s", res.Type)
		}

		return 0, fmt.Errorf("%s (error %s)", res.Type, code)
	}

	var cid proto.ConnectionID
	if getErr := cid.GetFrom(res); getErr != nil {
		return 0, getErr
	}

	a.log.Debugf("Connect request successful (cid=%v)", cid)

	return cid, nil
}
// Dial resolves rAddrStr as a TCP address on network and connects to it
// through the TURN server by delegating to DialTCP with no local address.
//
// On failure it returns a literal nil net.Conn together with the error.
func (a *TCPAllocation) Dial(network, rAddrStr string) (net.Conn, error) {
	rAddr, err := net.ResolveTCPAddr(network, rAddrStr)
	if err != nil {
		return nil, err
	}

	conn, err := a.DialTCP(network, nil, rAddr)
	if err != nil {
		// Do not forward the nil *TCPConn: converted to net.Conn it would be
		// a non-nil interface value wrapping a nil pointer.
		return nil, err
	}

	return conn, nil
}
// DialWithConn resolves rAddrStr as a TCP address on network and connects to
// it over the caller-supplied conn by delegating to DialTCPWithConn.
func (a *TCPAllocation) DialWithConn(conn net.Conn, network, rAddrStr string) (*TCPConn, error) {
	peerAddr, resolveErr := net.ResolveTCPAddr(network, rAddrStr)
	if resolveErr == nil {
		return a.DialTCPWithConn(conn, network, peerAddr)
	}

	return nil, resolveErr
}
// DialTCP opens a new TCP connection to the TURN server (using lAddr as the
// local address) and uses it as the data connection towards rAddr via
// DialTCPWithConn. The server address may be stored as either a *net.TCPAddr
// or a *net.UDPAddr; any other type yields errInvalidTURNAddress. If the data
// connection cannot be established, the freshly dialed connection is closed.
func (a *TCPAllocation) DialTCP(network string, lAddr, rAddr *net.TCPAddr) (*TCPConn, error) {
	var serverTCPAddr *net.TCPAddr

	// Only IP and port are carried over to the TCP address used for dialing.
	switch srv := a.serverAddr.(type) {
	case *net.TCPAddr:
		serverTCPAddr = &net.TCPAddr{IP: srv.IP, Port: srv.Port}
	case *net.UDPAddr:
		serverTCPAddr = &net.TCPAddr{IP: srv.IP, Port: srv.Port}
	default:
		return nil, errInvalidTURNAddress
	}

	serverConn, dialErr := a.net.DialTCP(network, lAddr, serverTCPAddr)
	if dialErr != nil {
		return nil, dialErr
	}

	dataConn, bindErr := a.DialTCPWithConn(serverConn, network, rAddr)
	if bindErr != nil {
		// Best-effort cleanup of the connection dialed above.
		serverConn.Close()
	}

	return dataConn, bindErr
}
// DialTCPWithConn establishes a relayed data connection to rAddr over the
// caller-supplied conn, which must implement transport.TCPConn.
//
// It makes sure a permission for rAddr exists (retrying createPermission up to
// maxRetryAttempts times while it reports errTryAgain), sends a Connect
// request, and finally binds conn to the returned connection ID. The network
// argument is unused.
//
// It returns errTCPAddrCast if conn is not a transport.TCPConn; this is
// checked before any request is sent so that no connection ID is obtained
// from the server for a transport that can never be bound to it.
func (a *TCPAllocation) DialTCPWithConn(conn net.Conn, _ string, rAddr *net.TCPAddr) (*TCPConn, error) {
	// Fail fast on an unusable transport, before any server round trip.
	tcpConn, ok := conn.(transport.TCPConn)
	if !ok {
		return nil, errTCPAddrCast
	}

	// Reuse the permission for this peer if there is one, otherwise register
	// a fresh one.
	perm, ok := a.permMap.find(rAddr)
	if !ok {
		perm = &permission{}
		a.permMap.insert(rAddr, perm)
	}

	var err error
	for i := 0; i < maxRetryAttempts; i++ {
		if err = a.createPermission(perm, rAddr); !errors.Is(err, errTryAgain) {
			break
		}
	}
	if err != nil {
		return nil, err
	}

	cid, err := a.Connect(rAddr)
	if err != nil {
		return nil, err
	}

	dataConn := &TCPConn{
		TCPConn:       tcpConn,
		ConnectionID:  cid,
		remoteAddress: rAddr,
		allocation:    a,
	}

	if err := a.BindConnection(dataConn, cid); err != nil {
		return nil, fmt.Errorf("failed to bind connection: %w", err)
	}

	return dataConn, nil
}
// BindConnection sends a ConnectionBind request carrying cid over dataConn and
// reads exactly one STUN message back from it as the response.
//
// It returns nil on a success response. An error response is returned as an
// error containing the message type and, when present, the error code. Any
// other message class yields errUnexpectedSTUNRequestMessage. A stream that
// ends before a full message was received yields errIncompleteTURNFrame, and
// a header that is not a STUN message yields errInvalidTURNFrame.
func (a *TCPAllocation) BindConnection(dataConn *TCPConn, cid proto.ConnectionID) error {
	msg, err := stun.Build(
		stun.TransactionID,
		stun.NewType(stun.MethodConnectionBind, stun.ClassRequest),
		cid,
		a.username,
		a.realm,
		a.nonce(),
		a.integrity,
		stun.Fingerprint,
	)
	if err != nil {
		return err
	}

	a.log.Debugf("Send connectionBind request (cid=%v)", cid)
	if _, err = dataConn.Write(msg.Raw); err != nil {
		return err
	}

	// The connection is a byte stream, so a single Read may return fewer
	// bytes than requested. io.ReadFull keeps reading until the buffer is
	// filled or the stream fails.
	header := make([]byte, stunHeaderSize)
	if _, err = io.ReadFull(dataConn, header); err != nil {
		return fmt.Errorf("%w: %v", errIncompleteTURNFrame, err)
	}
	if !stun.IsMessage(header) {
		return errInvalidTURNFrame
	}

	// Bytes 2-3 of the header hold the length of the message body. Widen to
	// int before adding the header size so a large length cannot wrap around
	// in 16-bit arithmetic.
	bodyLen := int(binary.BigEndian.Uint16(header[2:4]))
	raw := make([]byte, stunHeaderSize+bodyLen)
	copy(raw, header)
	if _, err = io.ReadFull(dataConn, raw[stunHeaderSize:]); err != nil {
		if errors.Is(err, io.EOF) || errors.Is(err, io.ErrUnexpectedEOF) {
			return errIncompleteTURNFrame
		}

		return err
	}

	res := &stun.Message{Raw: raw}
	if err = res.Decode(); err != nil {
		return fmt.Errorf("failed to decode STUN message: %w", err)
	}

	switch res.Type.Class {
	case stun.ClassErrorResponse:
		var code stun.ErrorCodeAttribute
		if err = code.GetFrom(res); err == nil {
			return fmt.Errorf("%s (error %s)", res.Type, code)
		}

		return fmt.Errorf("%s", res.Type)
	case stun.ClassSuccessResponse:
		a.log.Debug("Successful connectionBind request")

		return nil
	default:
		return fmt.Errorf("%w: %s", errUnexpectedSTUNRequestMessage, res.String())
	}
}
// Accept waits for the next relayed connection by delegating to AcceptTCP and
// returns it as a net.Conn.
//
// On failure it returns a literal nil net.Conn together with the error.
func (a *TCPAllocation) Accept() (net.Conn, error) {
	conn, err := a.AcceptTCP()
	if err != nil {
		// Never forward the connection value on error: it may be an
		// interface wrapping a nil pointer, which would compare non-nil.
		return nil, err
	}

	return conn, nil
}
// AcceptTCP dials a new TCP connection to the TURN server and hands it to
// AcceptTCPWithConn, which binds it to the next pending connection attempt.
// If that fails, the freshly dialed connection is closed.
//
// On failure it returns a literal nil transport.TCPConn together with the
// error.
func (a *TCPAllocation) AcceptTCP() (transport.TCPConn, error) {
	// NOTE(review): the server address is resolved as "tcp4" but dialed as
	// "tcp"; confirm whether IPv6 servers are meant to be supported here.
	addr, err := net.ResolveTCPAddr("tcp4", a.serverAddr.String())
	if err != nil {
		return nil, err
	}

	tcpConn, err := a.net.DialTCP("tcp", nil, addr)
	if err != nil {
		return nil, err
	}

	dataConn, err := a.AcceptTCPWithConn(tcpConn)
	if err != nil {
		// Best-effort cleanup; the accept error is the one worth reporting.
		_ = tcpConn.Close()

		// Return an untyped nil rather than the nil *TCPConn, which would
		// become a non-nil transport.TCPConn interface value.
		return nil, err
	}

	return dataConn, nil
}
// AcceptTCPWithConn waits for the next pending connection attempt and binds
// the caller-supplied conn to it, returning the resulting data connection.
//
// conn must implement transport.TCPConn, otherwise errTCPAddrCast is returned.
// This is checked before waiting so that an unusable conn never consumes (and
// thereby discards) a queued connection attempt. If the accept timer fires
// before an attempt arrives, a *net.OpError with Op "accept" wrapping a
// timeout error is returned.
func (a *TCPAllocation) AcceptTCPWithConn(conn net.Conn) (*TCPConn, error) {
	tcpConn, ok := conn.(transport.TCPConn)
	if !ok {
		return nil, errTCPAddrCast
	}

	select {
	case attempt := <-a.connAttemptCh:
		dataConn := &TCPConn{
			TCPConn:       tcpConn,
			ConnectionID:  attempt.cid,
			remoteAddress: attempt.from,
			allocation:    a,
		}

		if err := a.BindConnection(dataConn, attempt.cid); err != nil {
			return nil, fmt.Errorf("failed to bind connection: %w", err)
		}

		return dataConn, nil
	case <-a.acceptTimer.C:
		return nil, &net.OpError{
			Op:   "accept",
			Net:  a.Addr().Network(),
			Addr: a.Addr(),
			Err:  newTimeoutError("i/o timeout"),
		}
	}
}
// SetDeadline sets the deadline after which a pending or future accept times
// out. A zero time.Time removes the deadline. It always returns nil.
func (a *TCPAllocation) SetDeadline(t time.Time) error {
	var d time.Duration
	// Equal compares instants, so a zero time carrying a location is still
	// recognized as "no deadline" (== would also compare the location).
	if t.Equal(noDeadline()) {
		d = time.Duration(math.MaxInt64)
	} else {
		d = time.Until(t)
	}

	// Stop the timer and drop any expiry already sitting in its channel
	// before re-arming it; otherwise a stale tick from an earlier deadline
	// would make the next accept time out immediately. The drain must not
	// block because a concurrent accept may already have consumed the tick.
	if !a.acceptTimer.Stop() {
		select {
		case <-a.acceptTimer.C:
		default:
		}
	}
	a.acceptTimer.Reset(d)

	return nil
}
// Close stops both refresh timers, notifies the client that the relayed
// address has been deallocated, and then refreshes the allocation with a
// lifetime of zero, returning the result of that refresh.
func (a *TCPAllocation) Close() error {
	a.refreshAllocTimer.Stop()
	a.refreshPermsTimer.Stop()

	a.client.OnDeallocated(a.relayedAddr)

	err := a.refreshAllocation(0, true)

	return err
}
// Addr returns the relayed address of this allocation.
func (a *TCPAllocation) Addr() net.Addr {
	var relayed net.Addr = a.relayedAddr

	return relayed
}
// HandleConnectionAttempt queues a connection attempt from the given peer
// address with connection ID cid so that a later accept can pick it up. The
// send blocks while the queue is full.
func (a *TCPAllocation) HandleConnectionAttempt(from *net.TCPAddr, cid proto.ConnectionID) {
	attempt := &connectionAttempt{from: from, cid: cid}
	a.connAttemptCh <- attempt
}
The pages are generated with Golds v0.8.2. (GOOS=linux GOARCH=amd64)
Golds is a Go 101 project developed by Tapir Liu.
PR and bug reports are welcome and can be submitted to the issue list.
Please follow @zigo_101 (reachable from the left QR code) to get the latest news of Golds.