package stun
import (
"crypto/rand"
"encoding/base64"
"errors"
"fmt"
"io"
)
// Wire-format sizes and the magic cookie of a STUN message.
const (
	// magicCookie is the fixed value stored in bytes [4:8) of every message
	// header; IsMessage and Decode reject buffers that do not carry it.
	magicCookie = 0x2112A442

	// messageHeaderSize is the size, in bytes, of the fixed message header.
	messageHeaderSize = 20
	// attributeHeaderSize is the size, in bytes, of an attribute's
	// type + length prefix.
	attributeHeaderSize = 4

	// TransactionIDSize is the length of a transaction ID in bytes.
	TransactionIDSize = 12
)
// NewTransactionID returns a transaction ID filled from crypto/rand.
// It goes through readFullOrPanic, so (presumably) a failing random source
// panics instead of returning an error.
func NewTransactionID() (b [TransactionIDSize]byte) {
	var id [TransactionIDSize]byte
	readFullOrPanic(rand.Reader, id[:])
	return id
}
// IsMessage reports whether b looks like a STUN message: it must be at least
// a full header long and carry the magic cookie in bytes [4:8). Nothing past
// the cookie is validated.
func IsMessage(b []byte) bool {
	if len(b) < messageHeaderSize {
		return false
	}
	return bin.Uint32(b[4:8]) == magicCookie
}
// New returns an empty *Message whose Raw is a zeroed, header-sized buffer
// with extra capacity, so small messages can be built without reallocating.
func New() *Message {
	const defaultRawCapacity = 120
	m := new(Message)
	m.Raw = make([]byte, messageHeaderSize, defaultRawCapacity)
	return m
}
// ErrDecodeToNil is returned by Decode when the destination message is nil.
var ErrDecodeToNil = errors.New("attempt to decode to nil message")

// Decode copies data into m.Raw (reusing m.Raw's capacity) and then decodes
// m from it. It returns ErrDecodeToNil if m is nil, otherwise the result of
// m.Decode.
func Decode(data []byte, m *Message) error {
	if m != nil {
		m.Raw = append(m.Raw[:0], data...)
		return m.Decode()
	}
	return ErrDecodeToNil
}
// Message is a single STUN message in both decoded and wire form.
//
// Raw holds the encoded bytes; Decode fills the other fields from Raw, and
// the Write*/Add/Encode methods update Raw from the other fields. Attribute
// values produced by Decode and Add are sub-slices of Raw, not copies.
type Message struct {
	// Type is the decoded message type (method + class).
	Type MessageType
	// Length is the number of bytes following the 20-byte header.
	Length uint32
	// TransactionID is the transaction ID carried in bytes [8:20) of Raw.
	TransactionID [TransactionIDSize]byte
	// Attributes are the decoded attributes, in the order they appear.
	Attributes Attributes
	// Raw is the wire encoding of the message.
	Raw []byte
}
// MarshalBinary implements encoding.BinaryMarshaler. It returns a fresh copy
// of m.Raw that the caller may modify freely; it never fails.
func (m Message) MarshalBinary() (data []byte, err error) {
	data = make([]byte, len(m.Raw))
	copy(data, m.Raw)
	return data, nil
}
// UnmarshalBinary implements encoding.BinaryUnmarshaler. It copies data
// into m.Raw (reusing its capacity) and decodes m from it.
func (m *Message) UnmarshalBinary(data []byte) error {
	raw := m.Raw[:0]
	raw = append(raw, data...)
	m.Raw = raw
	return m.Decode()
}
// GobEncode implements gob.GobEncoder using the same bytes as MarshalBinary.
func (m Message) GobEncode() ([]byte, error) {
	data, err := m.MarshalBinary()
	return data, err
}
// GobDecode implements gob.GobDecoder by delegating to UnmarshalBinary.
func (m *Message) GobDecode(data []byte) error {
	err := m.UnmarshalBinary(data)
	return err
}
// AddTo copies m's transaction ID into b, both the field and b.Raw, so a
// *Message can act as a setter that reuses another message's ID.
// It always returns nil.
func (m *Message) AddTo(b *Message) error {
	id := m.TransactionID
	b.TransactionID = id
	b.WriteTransactionID()
	return nil
}
// NewTransactionID fills m.TransactionID with bytes from crypto/rand and,
// only if that succeeds, writes the new ID into m.Raw. A read failure is
// returned unchanged and m.Raw is not touched.
func (m *Message) NewTransactionID() error {
	if _, err := io.ReadFull(rand.Reader, m.TransactionID[:]); err != nil {
		return err
	}
	m.WriteTransactionID()
	return nil
}
// String returns a debugging summary of m: its type, body length, number of
// attributes, base64-encoded transaction ID and the type of each attribute.
func (m *Message) String() string {
	var attrs string
	for i := range m.Attributes {
		attrs += fmt.Sprintf("attr%d=%s ", i, m.Attributes[i].Type)
	}
	id := base64.StdEncoding.EncodeToString(m.TransactionID[:])
	return fmt.Sprintf("%s l=%d attrs=%d id=%s, %s",
		m.Type, m.Length, len(m.Attributes), id, attrs)
}
// Reset empties the message for reuse: Raw and Attributes are truncated to
// zero length (their backing arrays are kept) and Length is zeroed.
// Type and TransactionID are left unchanged.
func (m *Message) Reset() {
	m.Length = 0
	m.Raw, m.Attributes = m.Raw[:0], m.Attributes[:0]
}
// grow makes sure len(m.Raw) >= n; it never shrinks Raw. When the existing
// capacity suffices Raw is simply resliced, which does not clear the bytes
// it exposes. Otherwise zero bytes are appended, which may reallocate.
func (m *Message) grow(n int) {
	switch have := len(m.Raw); {
	case have >= n:
		// Already long enough; nothing to do.
	case cap(m.Raw) >= n:
		m.Raw = m.Raw[:n]
	default:
		m.Raw = append(m.Raw, make([]byte, n-have)...)
	}
}
// Add appends an attribute of type t with value v to the message.
//
// It writes the 4-byte attribute header (type, then length) followed by v
// into m.Raw right after the current body (at messageHeaderSize+m.Length),
// zero-pads the value up to nearestPaddedValueLength(len(v)), grows
// m.Length by header+value+padding, appends the attribute to m.Attributes
// and finally rewrites the header length field via WriteLength.
//
// The appended RawAttribute's Value is a sub-slice of m.Raw, not v itself.
// NOTE(review): if the padding grow below (or a later Add) reallocates
// m.Raw, that Value keeps referring to the old backing array; it still
// holds the right bytes but no longer aliases m.Raw. Confirm callers only
// read it.
// NOTE(review): the attribute length is stored as uint16(len(v)), so values
// longer than 65535 bytes get a wrapped length field.
func (m *Message) Add(t AttrType, v []byte) {
	// Layout of the new attribute inside m.Raw:
	//   [first, first+4)   attribute header (type, length)
	//   [first+4, last)    value bytes
	//   followed by zero padding when len(v) is not a multiple of padding.
	allocSize := attributeHeaderSize + len(v)
	first := messageHeaderSize + int(m.Length)
	last := first + allocSize
	m.grow(last)
	// grow never shrinks, so cut off anything beyond the new end.
	m.Raw = m.Raw[:last]
	m.Length += uint32(allocSize)
	buf := m.Raw[first:last]
	value := buf[attributeHeaderSize:]
	attr := RawAttribute{
		Type:   t,
		Length: uint16(len(v)),
		Value:  value,
	}
	bin.PutUint16(buf[0:2], attr.Type.Value())
	bin.PutUint16(buf[2:4], attr.Length)
	copy(value, v)
	if attr.Length%padding != 0 {
		// Padding bytes count toward m.Length but not toward the
		// attribute's own Length field written above.
		bytesToAdd := nearestPaddedValueLength(len(v)) - len(v)
		last += bytesToAdd
		m.grow(last)
		buf = m.Raw[last-bytesToAdd : last]
		// grow may expose stale bytes when reslicing, so zero them.
		for i := range buf {
			buf[i] = 0
		}
		m.Raw = m.Raw[:last]
		m.Length += uint32(bytesToAdd)
	}
	m.Attributes = append(m.Attributes, attr)
	m.WriteLength()
}
// attrSliceEqual reports whether every attribute in a has a counterpart in
// b with the same Type for which Equal returns true. The check is
// one-directional and ignores how many times an attribute occurs.
func attrSliceEqual(a, b Attributes) bool {
outer:
	for _, want := range a {
		for _, got := range b {
			// Equal is only consulted for attributes of matching type.
			if got.Type == want.Type && got.Equal(want) {
				continue outer
			}
		}
		return false
	}
	return true
}
// attrEqual reports whether a and b contain the same attributes, ignoring
// order. Both nil counts as equal; exactly one nil (even against an empty
// non-nil slice) counts as different. Otherwise the lengths must match and
// attrSliceEqual must hold in both directions.
func attrEqual(a, b Attributes) bool {
	switch {
	case a == nil && b == nil:
		return true
	case a == nil || b == nil, len(a) != len(b):
		return false
	}
	return attrSliceEqual(a, b) && attrSliceEqual(b, a)
}
// Equal reports whether m and b have the same Type, TransactionID, Length
// and (order-insensitive) Attributes. Raw is not compared. Two nil messages
// are equal; a nil and a non-nil message are not.
func (m *Message) Equal(b *Message) bool {
	if m == nil || b == nil {
		return m == b
	}
	return m.Type == b.Type &&
		m.TransactionID == b.TransactionID &&
		m.Length == b.Length &&
		attrEqual(m.Attributes, b.Attributes)
}
// WriteLength stores m.Length in the header length field, bytes [2:4) of
// m.Raw, growing Raw first if needed. Only the low 16 bits of Length fit.
func (m *Message) WriteLength() {
	const lengthStart, lengthEnd = 2, 4
	m.grow(lengthEnd)
	bin.PutUint16(m.Raw[lengthStart:lengthEnd], uint16(m.Length))
}
// WriteHeader writes the full 20-byte header into m.Raw: type, m.Length,
// the magic cookie and m.TransactionID. Raw is grown to header size first.
func (m *Message) WriteHeader() {
	m.grow(messageHeaderSize)
	// After grow, len(m.Raw) >= messageHeaderSize, so the smaller grows done
	// by WriteType/WriteLength cannot reallocate and header stays aliased.
	header := m.Raw[:messageHeaderSize]
	m.WriteType()   // bytes [0:2)
	m.WriteLength() // bytes [2:4)
	bin.PutUint32(header[4:8], magicCookie)
	copy(header[8:], m.TransactionID[:])
}
// WriteTransactionID writes m.TransactionID into the header's transaction ID
// field, bytes [8:20) of m.Raw.
//
// Like WriteHeader, WriteType and WriteLength, it first grows m.Raw to a
// full header. Without that, slicing m.Raw[8:20] panics when Raw's capacity
// is below the header size, e.g. on a zero Message reached through AddTo or
// a transaction ID setter.
func (m *Message) WriteTransactionID() {
	m.grow(messageHeaderSize)
	copy(m.Raw[8:messageHeaderSize], m.TransactionID[:])
}
// WriteAttributes re-encodes every attribute in m.Attributes into m.Raw by
// calling Add for each one, in order. Add places each attribute after the
// current m.Length bytes of body, so callers reset m.Length first (Encode
// sets it to 0 before calling this).
func (m *Message) WriteAttributes() {
	attributes := m.Attributes
	// Add appends to m.Attributes. Starting from attributes[:0] reuses the
	// same backing array, so iteration i reads attributes[i] before Add
	// overwrites that slot with the freshly encoded attribute.
	m.Attributes = attributes[:0]
	for _, a := range attributes {
		m.Add(a.Type, a.Value)
	}
	m.Attributes = attributes
}
// WriteType encodes m.Type into the first two bytes of m.Raw, growing Raw
// first if needed.
func (m *Message) WriteType() {
	const typeEnd = 2
	m.grow(typeEnd)
	bin.PutUint16(m.Raw[:typeEnd], m.Type.Value())
}
// SetType records t as the message type and immediately encodes it into the
// first two bytes of m.Raw.
func (m *Message) SetType(t MessageType) {
	typ := t
	m.Type = typ
	m.WriteType()
}
// Encode rebuilds m.Raw from scratch using m.Type, m.TransactionID and
// m.Attributes. Any previous contents of m.Raw are discarded, though its
// capacity is reused.
func (m *Message) Encode() {
	m.Raw = m.Raw[:0]
	// Reset Length before writing the header. WriteHeader writes m.Length
	// into the length field, and only Add (via WriteAttributes) rewrites it
	// later. Resetting afterwards would leave the old length in the header
	// of a message with no attributes.
	m.Length = 0
	m.WriteHeader()
	m.WriteAttributes()
}
// WriteTo implements io.WriterTo by writing m.Raw to w with a single Write
// call, returning the byte count and error from that call.
func (m *Message) WriteTo(w io.Writer) (int64, error) {
	written, err := w.Write(m.Raw)
	return int64(written), err
}
// ReadFrom performs a single Read from r into the full capacity of m.Raw and
// decodes the bytes received. A zero-capacity m.Raw gives Read an empty
// buffer. It returns the number of bytes read and either the read error or
// the result of Decode.
//
// The io.Reader contract allows a Read to return n > 0 bytes together with
// io.EOF. Those bytes are decoded rather than discarded; io.EOF is still
// returned when nothing was read.
func (m *Message) ReadFrom(r io.Reader) (int64, error) {
	tBuf := m.Raw[:cap(m.Raw)]
	n, err := r.Read(tBuf)
	if err == io.EOF && n > 0 {
		// Data arrived along with EOF: treat it as a complete read.
		err = nil
	}
	if err != nil {
		return int64(n), err
	}
	m.Raw = tBuf[:n]
	return int64(n), m.Decode()
}
// ErrUnexpectedHeaderEOF is returned by Decode when m.Raw is shorter than a
// message header.
var ErrUnexpectedHeaderEOF = errors.New("unexpected EOF: not enough bytes to read header")

// Decode parses m.Raw into m.Type, m.Length, m.TransactionID and
// m.Attributes.
//
// m.Attributes is truncated and refilled (its backing array is reused), and
// each attribute's Value is a sub-slice of m.Raw, not a copy. Bytes in m.Raw
// past the header-declared length are ignored.
//
// Errors:
//   - ErrUnexpectedHeaderEOF if m.Raw is shorter than the header;
//   - a newDecodeErr error if the magic cookie does not match;
//   - a newAttrDecodeErr error if m.Raw is shorter than the declared message
//     size, or an attribute header or padded value does not fit in the body.
//
// When an attribute error occurs, the header fields have already been stored
// and m.Attributes holds the attributes decoded so far.
func (m *Message) Decode() error {
	buf := m.Raw
	if len(buf) < messageHeaderSize {
		return ErrUnexpectedHeaderEOF
	}
	// Header: type [0:2), body length [2:4), cookie [4:8), transaction ID [8:20).
	var (
		t        = bin.Uint16(buf[0:2])
		size     = int(bin.Uint16(buf[2:4]))
		cookie   = bin.Uint32(buf[4:8])
		fullSize = messageHeaderSize + size
	)
	if cookie != magicCookie {
		msg := fmt.Sprintf("%x is invalid magic cookie (should be %x)", cookie, magicCookie)
		return newDecodeErr("message", "cookie", msg)
	}
	if len(buf) < fullSize {
		msg := fmt.Sprintf("buffer length %d is less than %d (expected message size)", len(buf), fullSize)
		return newAttrDecodeErr("message", msg)
	}
	m.Type.ReadValue(t)
	m.Length = uint32(size)
	copy(m.TransactionID[:], buf[8:messageHeaderSize])
	m.Attributes = m.Attributes[:0]
	// offset counts consumed body bytes; b is the unconsumed rest of the body.
	var (
		offset = 0
		b      = buf[messageHeaderSize:fullSize]
	)
	for offset < size {
		if len(b) < attributeHeaderSize {
			msg := fmt.Sprintf("buffer length %d is less than %d (expected header size)", len(b), attributeHeaderSize)
			return newAttrDecodeErr("header", msg)
		}
		var (
			a = RawAttribute{
				Type:   compatAttrType(bin.Uint16(b[0:2])),
				Length: bin.Uint16(b[2:4]),
			}
			aL     = int(a.Length)
			aBuffL = nearestPaddedValueLength(aL)
		)
		b = b[attributeHeaderSize:]
		offset += attributeHeaderSize
		// The value is aL bytes long but occupies aBuffL (padded) bytes
		// on the wire; the padded size must fit in what remains.
		if len(b) < aBuffL {
			msg := fmt.Sprintf("buffer length %d is less than %d (expected value size for %s)", len(b), aBuffL, a.Type)
			return newAttrDecodeErr("value", msg)
		}
		a.Value = b[:aL]
		offset += aBuffL
		b = b[aBuffL:]
		m.Attributes = append(m.Attributes, a)
	}
	return nil
}
// Write implements io.Writer. It copies tBuf into m.Raw (reusing its
// capacity) and decodes m from it. It always reports len(tBuf) bytes
// consumed, even when decoding fails.
func (m *Message) Write(tBuf []byte) (int, error) {
	m.Raw = append(m.Raw[:0], tBuf...)
	err := m.Decode()
	return len(tBuf), err
}
// CloneTo copies m.Raw into b.Raw (reusing b's capacity) and decodes b from
// it, so b's fields reflect m's encoded form. It returns b.Decode's result.
func (m *Message) CloneTo(b *Message) error {
	raw := append(b.Raw[:0], m.Raw...)
	b.Raw = raw
	return b.Decode()
}
// MessageClass is the class of a STUN message type. Only its two low bits
// are encoded on the wire (see MessageType.Value).
type MessageClass byte

// The four message classes, numbered 0 through 3.
const (
	ClassRequest MessageClass = iota
	ClassIndication
	ClassSuccessResponse
	ClassErrorResponse
)
var (
BindingRequest = NewType (MethodBinding , ClassRequest )
BindingSuccess = NewType (MethodBinding , ClassSuccessResponse )
BindingError = NewType (MethodBinding , ClassErrorResponse )
)
// String returns the lowercase name of the class. It panics for values
// other than the four defined classes.
func (c MessageClass) String() string {
	names := [...]string{
		ClassRequest:         "request",
		ClassIndication:      "indication",
		ClassSuccessResponse: "success response",
		ClassErrorResponse:   "error response",
	}
	if int(c) < len(names) {
		return names[c]
	}
	panic("unknown message class")
}
// Method is the method part of a STUN message type. Only its low 12 bits
// are encoded on the wire (see MessageType.Value).
type Method uint16

// Known methods.
const (
	MethodBinding          Method = 0x001
	MethodAllocate         Method = 0x003
	MethodRefresh          Method = 0x004
	MethodSend             Method = 0x006
	MethodData             Method = 0x007
	MethodCreatePermission Method = 0x008
	MethodChannelBind      Method = 0x009

	MethodConnect           Method = 0x00a
	MethodConnectionBind    Method = 0x00b
	MethodConnectionAttempt Method = 0x00c
)
// methodName returns a newly allocated map from each known Method to its
// name; callers may modify the result.
func methodName() map[Method]string {
	names := make(map[Method]string, 10)
	names[MethodBinding] = "Binding"
	names[MethodAllocate] = "Allocate"
	names[MethodRefresh] = "Refresh"
	names[MethodSend] = "Send"
	names[MethodData] = "Data"
	names[MethodCreatePermission] = "CreatePermission"
	names[MethodChannelBind] = "ChannelBind"
	names[MethodConnect] = "Connect"
	names[MethodConnectionBind] = "ConnectionBind"
	names[MethodConnectionAttempt] = "ConnectionAttempt"
	return names
}
func (m Method ) String () string {
s , ok := methodName ()[m ]
if !ok {
s = fmt .Sprintf ("0x%x" , uint16 (m ))
}
return s
}
// MessageType is a decoded STUN message type: a method plus a class.
// Value and ReadValue convert it to and from the 16-bit wire form.
type MessageType struct {
	Method Method       // e.g. MethodBinding
	Class  MessageClass // e.g. ClassRequest
}
// AddTo sets m's type to t (field and encoded bytes) via SetType.
// It never fails.
func (t MessageType) AddTo(m *Message) error {
	m.SetType(t)
	return nil
}
// NewType returns the MessageType made of the given method and class.
func NewType(method Method, class MessageClass) MessageType {
	var t MessageType
	t.Method, t.Class = method, class
	return t
}
// Bit layout used by MessageType.Value and ReadValue to interleave the
// method and class bits in the 16-bit message type field.
const (
	// Method bit groups: bits 0-3, 4-6 and 7-11 of the method.
	methodABits = 0xf
	methodBBits = 0x70
	methodDBits = 0xf80

	// Shifts applied to the second and third method groups on the wire.
	methodBShift = 1
	methodDShift = 2

	firstBit  = 0x1
	secondBit = 0x2

	// Class bits and their positions on the wire.
	c0Bit        = firstBit
	c1Bit        = secondBit
	classC0Shift = 4
	classC1Shift = 7
)
// Value encodes t as the 16-bit message type field. The 12 method bits are
// split into three groups (bits 0-3 stay, bits 4-6 shift up by one, bits
// 7-11 shift up by two). The class bits C0 and C1 land at positions 4 and 8.
// All pieces occupy disjoint bits, so OR-ing them is exact.
func (t MessageType) Value() uint16 {
	method := uint16(t.Method)
	class := uint16(t.Class)
	v := method & methodABits
	v |= (method & methodBBits) << methodBShift
	v |= (method & methodDBits) << methodDShift
	v |= (class & c0Bit) << classC0Shift
	v |= (class & c1Bit) << classC1Shift
	return v
}
// ReadValue decodes the 16-bit message type field v into t, reversing
// Value: the class is rebuilt from bits 4 and 8, and the method from the
// three groups of the remaining bits. The extracted pieces never overlap,
// so OR-ing them is exact.
func (t *MessageType) ReadValue(v uint16) {
	t.Class = MessageClass((v>>classC0Shift)&c0Bit | (v>>classC1Shift)&c1Bit)
	t.Method = Method(v&methodABits | (v>>methodBShift)&methodBBits | (v>>methodDShift)&methodDBits)
}
// String returns "<method> <class>", e.g. "Binding request".
func (t MessageType) String() string {
	const layout = "%s %s"
	return fmt.Sprintf(layout, t.Method, t.Class)
}
// Contains reports whether m.Attributes has at least one attribute of
// type t.
func (m *Message) Contains(t AttrType) bool {
	for i := range m.Attributes {
		if m.Attributes[i].Type == t {
			return true
		}
	}
	return false
}
// transactionIDValueSetter is a Setter that assigns a fixed transaction ID.
type transactionIDValueSetter [TransactionIDSize]byte

// NewTransactionIDSetter returns a Setter that sets a message's transaction
// ID to value.
func NewTransactionIDSetter(value [TransactionIDSize]byte) Setter {
	s := transactionIDValueSetter(value)
	return s
}

// AddTo copies t into m.TransactionID and writes it into m.Raw.
// It always returns nil.
func (t transactionIDValueSetter) AddTo(m *Message) error {
	m.TransactionID = [TransactionIDSize]byte(t)
	m.WriteTransactionID()
	return nil
}
// The pages are generated with Golds v0.8.2 (GOOS=linux GOARCH=amd64).
// Golds is a Go 101 project developed by Tapir Liu.
// PRs and bug reports are welcome and can be submitted to the issue list.
// Please follow @zigo_101 (reachable from the left QR code) to get the latest news of Golds.