package ice
import (
"encoding/binary"
"errors"
"io"
"net"
"strings"
"sync"
"time"
"github.com/pion/logging"
"github.com/pion/stun/v3"
)
var (
	// ErrGetTransportAddress is returned when the mux listener's address is not
	// a *net.TCPAddr, so no local transport address can be derived for a new
	// packet conn.
	ErrGetTransportAddress = errors.New("failed to get local transport address")
)
// TCPMux hands out net.PacketConns keyed by ICE username fragment (ufrag),
// address family and local IP, and releases them by ufrag.
type TCPMux interface {
	// GetConnByUfrag returns the packet conn for ufrag on the given address
	// family and local IP.
	GetConnByUfrag(ufrag string, isIPv6 bool, local net.IP) (net.PacketConn, error)

	// RemoveConnByUfrag releases every packet conn associated with ufrag.
	RemoveConnByUfrag(ufrag string)

	// Close shuts the mux down.
	io.Closer
}
type (
	// ipAddr is the textual form of a local IP (net.IP.String()), used as the
	// inner key of the per-ufrag connection tables.
	ipAddr string

	// TCPMuxDefault is the TCPMux implementation that accepts connections from
	// a net.Listener and routes each one, by the ufrag in its first STUN
	// binding request, to a tcpPacketConn keyed by (ufrag, family, local IP).
	//
	// Create it with NewTCPMuxDefault; the zero value has nil tables and nil
	// params and is not usable.
	TCPMuxDefault struct {
		params *TCPMuxParams

		// closed, connsIPv4 and connsIPv6 are only accessed with mu held.
		closed    bool
		connsIPv4 map[string]map[ipAddr]*tcpPacketConn
		connsIPv6 map[string]map[ipAddr]*tcpPacketConn
		mu        sync.Mutex

		// wg tracks the accept loop and every helper goroutine so Close can
		// wait for them.
		wg sync.WaitGroup
	}
)
// TCPMuxParams configures NewTCPMuxDefault.
type TCPMuxParams struct {
	// Listener supplies the TCP connections the mux serves. Its Addr is what
	// LocalAddr reports, and TCPMuxDefault.Close closes it.
	Listener net.Listener
	// Logger receives the mux's log output. When nil, NewTCPMuxDefault
	// substitutes a logger named "ice" from logging.NewDefaultLoggerFactory.
	Logger logging.LeveledLogger
	// ReadBufferSize and WriteBufferSize are passed through as ReadBuffer and
	// WriteBuffer of every tcpPacketConn the mux creates.
	// NOTE(review): their exact meaning is defined by tcpPacketConn — confirm.
	ReadBufferSize  int
	WriteBufferSize int
	// FirstStunBindTimeout is applied as a read deadline while waiting for the
	// first packet on an accepted connection. Zero selects the 30s default set
	// by NewTCPMuxDefault; a negative value disables the deadline.
	FirstStunBindTimeout time.Duration
	// AliveDurationForConnFromStun is passed as AliveDuration to packet conns
	// created in response to an inbound STUN request; conns created through
	// GetConnByUfrag get 0. Zero selects the 30s default.
	// NOTE(review): the alive-timer semantics live in tcpPacketConn — confirm there.
	AliveDurationForConnFromStun time.Duration
}
func NewTCPMuxDefault (params TCPMuxParams ) *TCPMuxDefault {
if params .Logger == nil {
params .Logger = logging .NewDefaultLoggerFactory ().NewLogger ("ice" )
}
if params .FirstStunBindTimeout == 0 {
params .FirstStunBindTimeout = 30 * time .Second
}
if params .AliveDurationForConnFromStun == 0 {
params .AliveDurationForConnFromStun = 30 * time .Second
}
mux := &TCPMuxDefault {
params : ¶ms ,
connsIPv4 : map [string ]map [ipAddr ]*tcpPacketConn {},
connsIPv6 : map [string ]map [ipAddr ]*tcpPacketConn {},
}
mux .wg .Add (1 )
go func () {
defer mux .wg .Done ()
mux .start ()
}()
return mux
}
// start is the accept loop. It hands every accepted connection to handleConn
// on its own wg-tracked goroutine and returns on the first Accept error
// (including the one caused by Close closing the listener).
func (m *TCPMuxDefault) start() {
	logger := m.params.Logger
	listener := m.params.Listener

	logger.Infof("Listening TCP on %s", listener.Addr())

	for {
		accepted, acceptErr := listener.Accept()
		if acceptErr != nil {
			logger.Infof("Error accepting connection: %s", acceptErr)
			return
		}

		logger.Debugf("Accepted connection from: %s to %s", accepted.RemoteAddr(), accepted.LocalAddr())

		m.wg.Add(1)
		go func(c net.Conn) {
			defer m.wg.Done()
			m.handleConn(c)
		}(accepted)
	}
}
// LocalAddr reports the address of the listener the mux accepts on.
func (m *TCPMuxDefault) LocalAddr() net.Addr {
	listener := m.params.Listener
	return listener.Addr()
}
// GetConnByUfrag returns the packet conn registered for ufrag on the given
// address family and local IP. If none exists, it creates and registers one.
// When an existing conn is reused, its ClearAliveTimer method is called.
//
// It returns io.ErrClosedPipe once the mux has been closed, or the error from
// createConn (e.g. ErrGetTransportAddress). On error the returned
// net.PacketConn is a true nil interface, not a typed nil pointer.
func (m *TCPMuxDefault) GetConnByUfrag(ufrag string, isIPv6 bool, local net.IP) (net.PacketConn, error) {
	m.mu.Lock()
	defer m.mu.Unlock()

	if m.closed {
		return nil, io.ErrClosedPipe
	}

	if conn, ok := m.getConn(ufrag, isIPv6, local); ok {
		conn.ClearAliveTimer()
		return conn, nil
	}

	conn, err := m.createConn(ufrag, isIPv6, local, false)
	if err != nil {
		// Returning conn here would wrap a nil *tcpPacketConn in a non-nil
		// net.PacketConn, so `pc != nil` checks would pass on failure.
		return nil, err
	}
	return conn, nil
}
// createConn allocates a tcpPacketConn for (ufrag, family, local IP) and
// registers it in the matching per-family table, replacing any previous entry
// for that key. It also starts a wg-tracked goroutine that unregisters the
// conn once its CloseChannel fires.
//
// fromStun marks conns created for an inbound STUN request; only those get
// AliveDurationForConnFromStun as their AliveDuration (others get 0).
// Callers must hold m.mu, because the tables are mutated here.
//
// It returns ErrGetTransportAddress when the listener address is not a
// *net.TCPAddr.
func (m *TCPMuxDefault) createConn(ufrag string, isIPv6 bool, local net.IP, fromStun bool) (*tcpPacketConn, error) {
	addr, ok := m.LocalAddr().(*net.TCPAddr)
	if !ok {
		return nil, ErrGetTransportAddress
	}
	// Copy the listener address so overriding IP does not touch the original.
	localAddr := *addr
	localAddr.IP = local

	var alive time.Duration
	if fromStun {
		alive = m.params.AliveDurationForConnFromStun
	}

	conn := newTCPPacketConn(tcpPacketParams{
		ReadBuffer:    m.params.ReadBufferSize,
		WriteBuffer:   m.params.WriteBufferSize,
		LocalAddr:     &localAddr,
		Logger:        m.params.Logger,
		AliveDuration: alive,
	})

	table := m.connsIPv4
	if isIPv6 {
		table = m.connsIPv6
	}
	hosts, ok := table[ufrag]
	if !ok {
		hosts = make(map[ipAddr]*tcpPacketConn)
		table[ufrag] = hosts
	}
	connKey := ipAddr(local.String())
	hosts[connKey] = conn

	m.wg.Add(1)
	go func() {
		defer m.wg.Done()
		<-conn.CloseChannel()

		// Unregister only if the table still maps this key to *this* conn.
		// By the time the close is observed, RemoveConnByUfrag or Close may
		// already have dropped it, and a new conn may have been registered
		// under the same ufrag and local IP. That replacement must not be
		// evicted or closed here. The tables are re-read under the lock
		// because Close swaps them out.
		m.mu.Lock()
		current := m.connsIPv4
		if isIPv6 {
			current = m.connsIPv6
		}
		stillRegistered := false
		if byHost, found := current[ufrag]; found && byHost[connKey] == conn {
			delete(byHost, connKey)
			if len(byHost) == 0 {
				delete(current, ufrag)
			}
			stillRegistered = true
		}
		m.mu.Unlock()

		// Closed outside the lock, like the other removal paths.
		// NOTE(review): assumes tcpPacketConn.Close tolerates a second call.
		if stillRegistered {
			m.closeAndLogError(conn)
		}
	}()

	return conn, nil
}
// closeAndLogError closes closer and reports a failure as a warning instead
// of returning it, for best-effort cleanup paths.
func (m *TCPMuxDefault) closeAndLogError(closer io.Closer) {
	if closeErr := closer.Close(); closeErr != nil {
		m.params.Logger.Warnf("Error closing connection: %s", closeErr)
	}
}
// handleConn serves one newly accepted TCP connection. It reads the first
// length-prefixed packet (bounded by FirstStunBindTimeout when positive) and
// requires it to decode as a STUN Binding message with a USERNAME attribute.
// It then attaches the connection, together with that first packet, to the
// tcpPacketConn for the ufrag, family and local IP, creating that conn if
// needed. On any failure, or if the mux is already closed, the TCP connection
// is closed.
func (m *TCPMuxDefault) handleConn(conn net.Conn) {
	buf := make([]byte, 512)

	if m.params.FirstStunBindTimeout > 0 {
		if err := conn.SetReadDeadline(time.Now().Add(m.params.FirstStunBindTimeout)); err != nil {
			m.params.Logger.Warnf(
				"Failed to set read deadline for first STUN message: %s to %s , err: %s",
				conn.RemoteAddr(),
				conn.LocalAddr(),
				err,
			)
		}
	}

	n, err := readStreamingPacket(conn, buf)
	if err != nil {
		if errors.Is(err, io.ErrShortBuffer) {
			m.params.Logger.Warnf("Buffer too small for first packet from %s: %s", conn.RemoteAddr(), err)
		} else {
			m.params.Logger.Warnf("Error reading first packet from %s: %s", conn.RemoteAddr(), err)
		}
		m.closeAndLogError(conn)
		return
	}

	// The deadline only applies to the first packet.
	if err = conn.SetReadDeadline(time.Time{}); err != nil {
		m.params.Logger.Warnf("Failed to reset read deadline from %s: %s", conn.RemoteAddr(), err)
	}

	buf = buf[:n]

	// Decode from a private copy; buf itself is handed to AddConn below.
	msg := &stun.Message{
		Raw: make([]byte, len(buf)),
	}
	copy(msg.Raw, buf)

	if err = msg.Decode(); err != nil {
		m.closeAndLogError(conn)
		m.params.Logger.Warnf("Failed to handle decode ICE from %s to %s: %v", conn.RemoteAddr(), conn.LocalAddr(), err)
		return
	}

	if msg.Type.Method != stun.MethodBinding {
		m.closeAndLogError(conn)
		m.params.Logger.Warnf("Not a STUN message from %s to %s", conn.RemoteAddr(), conn.LocalAddr())
		return
	}

	for _, attr := range msg.Attributes {
		m.params.Logger.Debugf("Message attribute: %s", attr.String())
	}

	attr, err := msg.Get(stun.AttrUsername)
	if err != nil {
		m.closeAndLogError(conn)
		m.params.Logger.Warnf(
			"No Username attribute in STUN message from %s to %s",
			conn.RemoteAddr(),
			conn.LocalAddr(),
		)
		return
	}

	// The lookup key is everything before the first ':' of USERNAME (the
	// whole value when there is no ':').
	ufrag, _, _ := strings.Cut(string(attr), ":")
	m.params.Logger.Debugf("Ufrag: %s", ufrag)

	host, _, err := net.SplitHostPort(conn.RemoteAddr().String())
	if err != nil {
		m.closeAndLogError(conn)
		m.params.Logger.Warnf(
			"Failed to get host in STUN message from %s to %s",
			conn.RemoteAddr(),
			conn.LocalAddr(),
		)
		return
	}

	// The family is taken from the remote host; anything that is not a
	// parseable IPv4 (or IPv4-mapped) address is treated as IPv6.
	isIPv6 := net.ParseIP(host).To4() == nil

	localAddr, ok := conn.LocalAddr().(*net.TCPAddr)
	if !ok {
		m.closeAndLogError(conn)
		m.params.Logger.Warnf(
			"Failed to get local tcp address in STUN message from %s to %s",
			conn.RemoteAddr(),
			conn.LocalAddr(),
		)
		return
	}

	m.mu.Lock()
	// Close may have run while we were blocked reading the first packet.
	// Registering a conn now would leave it outside Close's cleanup, and its
	// helper goroutine would keep Close's wg.Wait blocked.
	if m.closed {
		m.mu.Unlock()
		m.params.Logger.Debugf("TCPMux closed, dropping connection from %s to %s", conn.RemoteAddr(), conn.LocalAddr())
		m.closeAndLogError(conn)
		return
	}
	packetConn, ok := m.getConn(ufrag, isIPv6, localAddr.IP)
	if !ok {
		packetConn, err = m.createConn(ufrag, isIPv6, localAddr.IP, true)
		if err != nil {
			m.mu.Unlock()
			m.closeAndLogError(conn)
			m.params.Logger.Warnf(
				"Failed to create packetConn for STUN message from %s to %s",
				conn.RemoteAddr(),
				conn.LocalAddr(),
			)
			return
		}
	}
	m.mu.Unlock()

	if err = packetConn.AddConn(conn, buf); err != nil {
		m.closeAndLogError(conn)
		m.params.Logger.Warnf(
			"Error adding conn to tcpPacketConn from %s to %s: %s",
			conn.RemoteAddr(),
			conn.LocalAddr(),
			err,
		)
		return
	}
}
// Close shuts the mux down. It does the following, in order:
//   - marks the mux closed, so GetConnByUfrag returns io.ErrClosedPipe;
//   - closes and forgets every registered packet conn;
//   - closes the listener;
//   - waits for the accept loop and all helper goroutines to exit.
//
// It returns the listener's Close error; per-conn close errors are only
// logged.
func (m *TCPMuxDefault) Close() error {
	m.mu.Lock()
	m.closed = true
	v4, v6 := m.connsIPv4, m.connsIPv6
	m.connsIPv4 = map[string]map[ipAddr]*tcpPacketConn{}
	m.connsIPv6 = map[string]map[ipAddr]*tcpPacketConn{}
	m.mu.Unlock()

	// Close the conns outside the lock, as RemoveConnByUfrag does. The
	// goroutine createConn starts per conn takes m.mu once the conn closes.
	// Because the tables were already swapped, it finds nothing left to remove.
	for _, table := range [...]map[string]map[ipAddr]*tcpPacketConn{v4, v6} {
		for _, hosts := range table {
			for _, conn := range hosts {
				m.closeAndLogError(conn)
			}
		}
	}

	err := m.params.Listener.Close()
	m.wg.Wait()
	return err
}
// RemoveConnByUfrag unregisters every packet conn for ufrag in both address
// families and closes them once m.mu has been released.
func (m *TCPMuxDefault) RemoveConnByUfrag(ufrag string) {
	victims := make([]*tcpPacketConn, 0, 4)

	m.mu.Lock()
	for _, table := range [...]map[string]map[ipAddr]*tcpPacketConn{m.connsIPv4, m.connsIPv6} {
		hosts, found := table[ufrag]
		if !found {
			continue
		}
		delete(table, ufrag)
		for _, c := range hosts {
			victims = append(victims, c)
		}
	}
	m.mu.Unlock()

	for _, c := range victims {
		m.closeAndLogError(c)
	}
}
// removeConnByUfragAndLocalHost unregisters the packet conn stored under
// (ufrag, localIPAddr) in each address family, pruning a ufrag entry that
// becomes empty. Any removed conns are closed after m.mu is released.
func (m *TCPMuxDefault) removeConnByUfragAndLocalHost(ufrag string, localIPAddr ipAddr) {
	victims := make([]*tcpPacketConn, 0, 4)

	m.mu.Lock()
	for _, table := range [...]map[string]map[ipAddr]*tcpPacketConn{m.connsIPv4, m.connsIPv6} {
		hosts, found := table[ufrag]
		if !found {
			continue
		}
		c, found := hosts[localIPAddr]
		if !found {
			continue
		}
		delete(hosts, localIPAddr)
		if len(hosts) == 0 {
			delete(table, ufrag)
		}
		victims = append(victims, c)
	}
	m.mu.Unlock()

	for _, c := range victims {
		m.closeAndLogError(c)
	}
}
// getConn looks up the packet conn for (ufrag, family, local IP). The caller
// must hold m.mu. ok reports whether an entry was found.
func (m *TCPMuxDefault) getConn(ufrag string, isIPv6 bool, local net.IP) (val *tcpPacketConn, ok bool) {
	table := m.connsIPv4
	if isIPv6 {
		table = m.connsIPv6
	}

	hosts, found := table[ufrag]
	if hosts == nil {
		return nil, found
	}

	val, ok = hosts[ipAddr(local.String())]
	return val, ok
}
// streamingPacketHeaderLen is the size of the big-endian uint16 length prefix
// that frames every packet on the TCP stream.
const streamingPacketHeaderLen = 2

// readStreamingPacket reads one length-prefixed packet from conn into
// buf[:length] and returns its length.
//
// If the announced length exceeds cap(buf), it returns that length together
// with io.ErrShortBuffer, and the payload is left unread on the stream. Any
// other read failure returns 0 and the error. A stream that ends before any
// header byte yields io.EOF; a stream that ends mid-packet yields
// io.ErrUnexpectedEOF.
func readStreamingPacket(conn net.Conn, buf []byte) (int, error) {
	header := make([]byte, streamingPacketHeaderLen)
	// io.ReadFull keeps bytes that arrive together with a non-nil error, as the
	// io.Reader contract requires, instead of discarding them.
	if _, err := io.ReadFull(conn, header); err != nil {
		return 0, err
	}

	length := int(binary.BigEndian.Uint16(header))
	if length > cap(buf) {
		return length, io.ErrShortBuffer
	}

	// Slicing up to cap(buf) is valid even when len(buf) is shorter.
	if _, err := io.ReadFull(conn, buf[:length]); err != nil {
		return 0, err
	}
	return length, nil
}

// writeStreamingPacket writes buf to conn with a big-endian uint16 length
// prefix, as a single Write. It returns the number of payload bytes written
// (excluding the prefix). Payloads longer than 65535 bytes cannot be framed
// and are rejected without writing anything.
func writeStreamingPacket(conn net.Conn, buf []byte) (int, error) {
	const maxPayload = 1<<(8*streamingPacketHeaderLen) - 1
	if len(buf) > maxPayload {
		// uint16(len(buf)) would silently truncate and desynchronize the
		// framing for the peer.
		return 0, errors.New("streaming packet exceeds maximum length of 65535 bytes")
	}

	framed := make([]byte, streamingPacketHeaderLen+len(buf))
	binary.BigEndian.PutUint16(framed, uint16(len(buf)))
	copy(framed[streamingPacketHeaderLen:], buf)

	n, err := conn.Write(framed)
	if err != nil {
		return 0, err
	}
	return n - streamingPacketHeaderLen, nil
}
// The pages are generated with Golds v0.8.2 (GOOS=linux GOARCH=amd64).
// Golds is a Go 101 project developed by Tapir Liu.
// PRs and bug reports are welcome and can be submitted to the issue list.
// Please follow @zigo_101 (reachable from the left QR code) to get the latest news of Golds.