package nats
import (
"bufio"
"bytes"
"crypto/rand"
"crypto/sha1"
"encoding/base64"
"encoding/binary"
"errors"
"fmt"
"io"
mrand "math/rand"
"net/http"
"net/url"
"strings"
"time"
"unicode/utf8"
"github.com/klauspost/compress/flate"
)
// wsOpCode is a websocket frame opcode: the value carried in the low 4 bits of
// a frame's first header byte.
type wsOpCode int
const (
// Frame opcodes understood by the reader/writer (RFC 6455 §5.2).
wsTextMessage = wsOpCode (1 )
wsBinaryMessage = wsOpCode (2 )
wsCloseMessage = wsOpCode (8 )
wsPingMessage = wsOpCode (9 )
wsPongMessage = wsOpCode (10 )
// Bits of the first header byte: FIN marks the last frame of a message and
// RSV1 marks a compressed (permessage-deflate) message.
wsFinalBit = 1 << 7
wsRsv1Bit = 1 << 6
wsRsv2Bit = 1 << 5
wsRsv3Bit = 1 << 4
// Bit of the second header byte indicating that the payload is masked.
wsMaskBit = 1 << 7
// Opcode of a continuation frame of a fragmented message.
wsContinuationFrame = 0
// Largest header the writer produces: 2 base bytes + 8 extended-length
// bytes + 4 masking-key bytes.
wsMaxFrameHeaderSize = 14
// Control frames (close/ping/pong) may carry at most this many payload bytes.
wsMaxControlPayloadSize = 125
// Size in bytes of the status code that begins a close frame payload.
// (The name misspells "Status"; it is kept because other code refers to it.)
wsCloseSatusSize = 2
// Close status codes used by this file.
wsCloseStatusNormalClosure = 1000
wsCloseStatusNoStatusReceived = 1005
wsCloseStatusAbnormalClosure = 1006
wsCloseStatusInvalidPayloadData = 1007
// URL schemes for plain and TLS websocket connections.
wsScheme = "ws"
wsSchemeTLS = "wss"
// permessage-deflate extension name and parameters. The client always asks
// for no context takeover in both directions.
wsPMCExtension = "permessage-deflate"
wsPMCSrvNoCtx = "server_no_context_takeover"
wsPMCCliNoCtx = "client_no_context_takeover"
wsPMCReqHeaderValue = wsPMCExtension + "; " + wsPMCSrvNoCtx + "; " + wsPMCCliNoCtx
)
// wsGUID is the fixed GUID concatenated with the client's Sec-WebSocket-Key
// when computing the expected Sec-WebSocket-Accept value (see wsAcceptKey).
var wsGUID = []byte ("258EAFA5-E914-47DA-95CA-C5AB0DC85B11" )
// compressFinalBlock is appended to a message's compressed fragments right
// before inflating them (see wsDecompressor.decompress). The first four bytes
// restore the 0x00 0x00 0xff 0xff tail that a permessage-deflate sender strips
// (this file's writer drops the last 4 bytes after Flush). The remaining bytes
// appear to form a final empty stored deflate block so the flate reader reaches
// end of stream — confirm against RFC 1951 / RFC 7692 before changing.
var compressFinalBlock = []byte {0x00 , 0x00 , 0xff , 0xff , 0x01 , 0x00 , 0x00 , 0xff , 0xff }
// websocketReader adapts a websocket byte stream to io.Reader. Its Read method
// parses frames read from r and hands back only application payload. Close and
// ping frames are answered through nc.
type websocketReader struct {
// r is the underlying reader the frames are read from.
r io .Reader
// pending holds decoded payloads that Read has not yet copied to a caller.
pending [][]byte
// compress is true when compression was negotiated during the handshake.
compress bool
// ib holds bytes that were already buffered while reading the HTTP upgrade
// response. They are parsed before anything else is read from r.
ib []byte
// ff is true when the last data frame seen had the FIN bit set, i.e. no
// fragmented message is in progress. It starts out true (see wsNewReader).
ff bool
// fc is true while the current message is compressed (RSV1 on its first frame).
fc bool
// nl is passed as needsLock when enqueueing close/pong replies. It is set by
// doneWithConnect; presumably false only while connecting — verify with callers.
nl bool
// dc accumulates and inflates compressed fragments; created lazily by addCBuf.
dc *wsDecompressor
// nc is the owning connection, used to enqueue close and pong replies.
nc *Conn
}
// wsDecompressor feeds the queued compressed fragments of one message to a
// flate reader. It implements io.Reader and io.ByteReader over bufs.
type wsDecompressor struct {
// flate is the inflater; created on first use and Reset for later messages.
flate io .ReadCloser
// bufs are the compressed chunks still to be consumed, in order.
bufs [][]byte
// off is the read offset inside bufs[0].
off int
}
// websocketWriter frames outgoing data as masked binary websocket messages on
// top of w. It also flushes queued control frames and, once requested, a close
// frame.
type websocketWriter struct {
// w is the underlying writer the frames are written to.
w io .Writer
// compress is true when compression was negotiated during the handshake.
compress bool
// compressor deflates outgoing payloads; created lazily by Write.
compressor *flate .Writer
// ctrlFrames are queued control-frame pieces (header and payload are
// separate entries), written before the next data frame.
ctrlFrames [][]byte
// cm is a pending close frame, written after the data of the next Write.
cm []byte
// cmDone records that a close frame has already been enqueued.
cmDone bool
// noMoreSend is set once the close frame has been written. After that,
// Write discards its input and returns (0, nil).
noMoreSend bool
}
// Read implements io.Reader over the queued compressed chunks. It copies as
// many bytes as fit into dst, advancing across chunk boundaries. It returns
// (0, nil) for an empty dst and (0, io.EOF) once every chunk has been consumed.
// Copying stops early if the next chunk is nil.
func (d *wsDecompressor) Read(dst []byte) (int, error) {
	switch {
	case len(dst) == 0:
		return 0, nil
	case len(d.bufs) == 0:
		return 0, io.EOF
	}
	total := 0
	cur := d.bufs[0]
	for cur != nil && total < len(dst) {
		// copy takes min(space left in dst, bytes left in this chunk).
		n := copy(dst[total:], cur[d.off:])
		total += n
		d.off += n
		cur = d.nextBuf()
	}
	return total, nil
}
// nextBuf returns the chunk to read from next. While bytes remain in bufs[0]
// it returns that chunk unchanged. Once bufs[0] is fully consumed it drops it,
// resets off and returns the following chunk, or nil (clearing bufs) when
// none is left. bufs must not be empty.
func (d *wsDecompressor) nextBuf() []byte {
	if cur := d.bufs[0]; d.off != len(cur) {
		return cur
	}
	d.off = 0
	if rest := d.bufs[1:]; len(rest) > 0 {
		d.bufs = rest
		return rest[0]
	}
	d.bufs = nil
	return nil
}
// ReadByte implements io.ByteReader over the queued compressed chunks. It
// returns io.EOF once every chunk has been consumed.
func (d *wsDecompressor) ReadByte() (c byte, err error) {
	if len(d.bufs) == 0 {
		err = io.EOF
		return
	}
	cur := d.bufs[0]
	c = cur[d.off]
	d.off++
	// Move to the next chunk if this one is now exhausted.
	_ = d.nextBuf()
	return c, nil
}
// addBuf queues b as the next chunk of compressed input.
//
// Empty chunks are skipped. ReadByte indexes d.bufs[0][d.off] directly, and
// Read stops as soon as it meets a nil chunk. So an empty entry (which is what
// addCBuf produces for a zero-length fragment, since append([]byte(nil)) of
// nothing is nil) would either panic or cut the compressed stream short.
// Dropping it is safe because an empty fragment contributes no deflate bytes.
func (d *wsDecompressor) addBuf(b []byte) {
	if len(b) == 0 {
		return
	}
	d.bufs = append(d.bufs, b)
}
// decompress inflates all queued compressed fragments of one message and
// returns the plain payload.
//
// It appends compressFinalBlock to terminate the deflate stream, then reads it
// through a flate reader. The reader is created on first use and reset on
// later calls. The queued chunks are released afterwards, whether or not
// inflation succeeded.
func (d *wsDecompressor) decompress() ([]byte, error) {
	d.bufs = append(d.bufs, compressFinalBlock)
	d.off = 0
	if d.flate != nil {
		_ = d.flate.(flate.Resetter).Reset(d, nil)
	} else {
		d.flate = flate.NewReader(d)
	}
	out, err := io.ReadAll(d.flate)
	// Drop references to the consumed compressed chunks.
	d.bufs = nil
	return out, err
}
// wsNewReader wraps r in a websocketReader. ff starts out true because no
// fragmented message is in progress yet, so the first data frame must begin a
// new message.
func wsNewReader(r io.Reader) *websocketReader {
	wr := &websocketReader{r: r}
	wr.ff = true
	return wr
}
// doneWithConnect sets nl. From then on, close/pong replies triggered by
// incoming control frames acquire the connection lock themselves (nl is passed
// as needsLock to the enqueue helpers).
func (r *websocketReader ) doneWithConnect () {
r .nl = true
}
// Read implements io.Reader on top of a websocket stream. It parses the frames
// contained in the bytes it reads and fills p with application payload
// (decompressed when needed). Control frames are handled inline.
//
// Input comes first from r.ib (bytes buffered during the HTTP handshake).
// Otherwise it is read from r.r directly into p, so uncompressed payloads may
// alias p. Fully received data-frame payloads are queued on r.pending and then
// copied into p by drainPending. Whatever does not fit stays queued for the
// next call. Fragments of a compressed message are accumulated in the
// decompressor and only inflated (and queued) when the final fragment arrives.
//
// Read may return (0, nil) when the consumed bytes carried no payload (for
// example only control frames). A close frame yields io.EOF (see
// handleControlFrame). Protocol violations are reported as errors.
func (r *websocketReader) Read(p []byte) (int, error) {
	var err error
	var buf []byte
	if len(r.ib) > 0 {
		// Consume bytes that were already buffered during the handshake.
		buf = r.ib
		r.ib = nil
	} else {
		// Hand out previously decoded payload before reading more.
		if len(r.pending) > 0 {
			return r.drainPending(p), nil
		}
		n, err := r.r.Read(p)
		if err != nil {
			return 0, err
		}
		buf = p[:n]
	}

	var (
		tmpBuf []byte
		pos    int
		limit  = len(buf)
	)
	for pos < limit {
		b0 := buf[pos]
		frameType := wsOpCode(b0 & 0xF)
		final := b0&wsFinalBit != 0
		compressed := b0&wsRsv1Bit != 0
		pos++

		tmpBuf, pos, err = wsGet(r.r, buf, pos, 1)
		if err != nil {
			return 0, err
		}
		b1 := tmpBuf[0]

		// 7-bit payload length; 126 and 127 announce an extended length.
		rem := int(b1 & 0x7F)

		switch frameType {
		case wsPingMessage, wsPongMessage, wsCloseMessage:
			if rem > wsMaxControlPayloadSize {
				return 0, fmt.Errorf(
					"control frame length bigger than maximum allowed of %v bytes",
					wsMaxControlPayloadSize)
			}
			if compressed {
				return 0, errors.New("control frame should not be compressed")
			}
			if !final {
				return 0, errors.New("control frame does not have final bit set")
			}
		case wsTextMessage, wsBinaryMessage:
			if !r.ff {
				return 0, errors.New("new message started before final frame for previous message was received")
			}
			r.ff = final
			r.fc = compressed
		case wsContinuationFrame:
			// RSV1 may only be set on the first frame of a message.
			if r.ff || compressed {
				return 0, errors.New("invalid continuation frame")
			}
			r.ff = final
		default:
			return 0, fmt.Errorf("unknown opcode %v", frameType)
		}

		switch rem {
		case 126:
			tmpBuf, pos, err = wsGet(r.r, buf, pos, 2)
			if err != nil {
				return 0, err
			}
			rem = int(binary.BigEndian.Uint16(tmpBuf))
		case 127:
			tmpBuf, pos, err = wsGet(r.r, buf, pos, 8)
			if err != nil {
				return 0, err
			}
			// RFC 6455 §5.2: the most significant bit of a 64-bit length MUST
			// be 0. A value that does not fit in an int would convert to a
			// negative or truncated length and make wsGet panic.
			l := binary.BigEndian.Uint64(tmpBuf)
			if l > uint64(^uint(0)>>1) {
				return 0, fmt.Errorf("invalid websocket frame payload length %v", l)
			}
			rem = int(l)
		}

		if wsIsControlFrame(frameType) {
			pos, err = r.handleControlFrame(frameType, buf, pos, rem)
			if err != nil {
				return 0, err
			}
			continue
		}

		// Get the full payload of this data frame (reading more if needed).
		var b []byte
		b, pos, err = wsGet(r.r, buf, pos, rem)
		if err != nil {
			return 0, err
		}
		addToPending := true
		if r.fc {
			// Queue only once the whole compressed message has arrived.
			addToPending = r.ff
			r.addCBuf(b)
			if r.ff {
				b, err = r.dc.decompress()
				if err != nil {
					return 0, err
				}
				r.fc = false
			}
		} else if r.compress {
			// With compression negotiated, pending may mix freshly allocated
			// (decompressed) buffers with slices of p. Copying a larger
			// decompressed buffer into p could presumably overwrite a later
			// payload still living in p, so detach this one.
			b = bytes.Clone(b)
		}
		if addToPending {
			r.pending = append(r.pending, b)
		}
	}
	if len(r.pending) > 0 {
		return r.drainPending(p), nil
	}
	return 0, nil
}
// addCBuf queues a compressed fragment of the message being assembled and
// creates the decompressor on first use. The fragment is copied because b may
// alias the read buffer, which a later Read overwrites.
func (r *websocketReader) addCBuf(b []byte) {
	if r.dc == nil {
		r.dc = new(wsDecompressor)
	}
	fragment := append([]byte(nil), b...)
	r.dc.addBuf(fragment)
}
// drainPending copies queued payloads into p, in order, and returns the number
// of bytes copied. A payload that only partly fits is split: its head fills
// the rest of p and its tail stays queued, followed by every later payload.
// When everything fits, the queue is emptied (its capacity is kept).
func (r *websocketReader) drainPending(p []byte) int {
	written := 0
	for idx := range r.pending {
		chunk := r.pending[idx]
		room := len(p) - written
		if len(chunk) > room {
			if room > 0 {
				written += copy(p[written:], chunk[:room])
				r.pending[idx] = chunk[room:]
			}
			r.pending = r.pending[idx:]
			return written
		}
		written += copy(p[written:], chunk)
	}
	r.pending = r.pending[:0]
	return written
}
// wsGet returns the next `needed` bytes starting at buf[pos], together with
// the position to continue parsing buf from.
//
// If buf already holds enough bytes, the result is a sub-slice of buf. If not,
// the available tail of buf is copied into a new slice and the rest is read
// from r. The returned position is then len(buf), since buf is exhausted. On a
// read error, the partially filled slice, the number of bytes gathered and the
// error are returned.
func wsGet(r io.Reader, buf []byte, pos, needed int) ([]byte, int, error) {
	if len(buf)-pos >= needed {
		end := pos + needed
		return buf[pos:end], end, nil
	}
	out := make([]byte, needed)
	got := copy(out, buf[pos:])
	for got != needed {
		n, err := r.Read(out[got:])
		got += n
		if err != nil {
			return out, got, err
		}
	}
	return out, len(buf), nil
}
// handleControlFrame consumes the payload (rem bytes starting at buf[pos]) of
// a control frame and reacts to it. It returns the position after the payload.
//
//   - Close: replies with a close frame echoing the peer's status (or 1005 when
//     none was sent), or 1007 if the reason text is not valid UTF-8. It then
//     returns io.EOF.
//   - Ping: replies with a pong carrying the same payload.
//   - Pong: ignored.
func (r *websocketReader) handleControlFrame(frameType wsOpCode, buf []byte, pos, rem int) (int, error) {
	var payload []byte
	if rem > 0 {
		var err error
		if payload, pos, err = wsGet(r.r, buf, pos, rem); err != nil {
			return pos, err
		}
	}
	switch frameType {
	case wsCloseMessage:
		status, body := wsCloseStatusNoStatusReceived, ""
		if len(payload) >= wsCloseSatusSize {
			status = int(binary.BigEndian.Uint16(payload))
			if len(payload) > wsCloseSatusSize {
				body = string(payload[wsCloseSatusSize:])
				if !utf8.ValidString(body) {
					status = wsCloseStatusInvalidPayloadData
					body = "invalid utf8 body in close frame"
				}
			}
		}
		r.nc.wsEnqueueCloseMsg(r.nl, status, body)
		return pos, io.EOF
	case wsPingMessage:
		r.nc.wsEnqueueControlMsg(r.nl, wsPongMessage, payload)
	}
	return pos, nil
}
// Write sends p as a single masked binary websocket message over w.w.
//
// Output order: queued control frames (ctrlFrames) first, then the data frame
// for p (only if p is non-empty), then the pending close frame (cm). Once the
// close frame is written, later calls send nothing.
//
// When compression was negotiated, p is deflated, flushed, and the last 4
// bytes of the flushed output are dropped before framing (presumably the
// 0x00 0x00 0xff 0xff sync-flush tail that RFC 7692 says to strip).
//
// The returned count is the number of bytes written to the underlying writer
// (headers plus possibly-compressed payload), not len(p).
//
// NOTE(review): without compression, wsMaskBuf masks p in place, so the
// caller's buffer is modified. io.Writer implementations normally must not do
// this; confirm callers never reuse p after Write.
func (w *websocketWriter ) Write (p []byte ) (int , error ) {
if w .noMoreSend {
// The close frame was already sent: silently drop further data.
return 0 , nil
}
var total int
var n int
var err error
if len (w .ctrlFrames ) > 0 {
n , err = w .writeCtrlFrames ()
if err != nil {
return n , err
}
total += n
}
if len (p ) > 0 {
if w .compress {
buf := &bytes .Buffer {}
if w .compressor == nil {
// flate.NewWriter only fails for an invalid level; BestSpeed is valid.
w .compressor , _ = flate .NewWriter (buf , flate .BestSpeed )
} else {
w .compressor .Reset (buf )
}
if n , err = w .compressor .Write (p ); err != nil {
return n , err
}
if err = w .compressor .Flush (); err != nil {
return n , err
}
b := buf .Bytes ()
// Strip the 4-byte tail produced by Flush (see the doc comment).
p = b [:len (b )-4 ]
}
fh , key := wsCreateFrameHeader (w .compress , wsBinaryMessage , len (p ))
wsMaskBuf (key , p )
n , err = w .w .Write (fh )
total += n
if err == nil {
n , err = w .w .Write (p )
total += n
}
}
if err == nil && w .cm != nil {
n , err = w .writeCloseMsg ()
total += n
}
return total , err
}
// writeCtrlFrames writes every queued control-frame piece to w.w, in order,
// and returns the total number of bytes written.
//
// On the first write error, it stops and drops the frames already sent as
// well as the failed one. Only the frames after it stay queued. On success the
// queue is emptied (its capacity is kept).
func (w *websocketWriter) writeCtrlFrames() (int, error) {
	written := 0
	for idx, frame := range w.ctrlFrames {
		n, err := w.w.Write(frame)
		written += n
		if err != nil {
			w.ctrlFrames = w.ctrlFrames[idx+1:]
			return written, err
		}
	}
	w.ctrlFrames = w.ctrlFrames[:0]
	return written, nil
}
// writeCloseMsg writes the pending close frame to w.w. It clears it and marks
// the writer as closed for sending, whatever the outcome of the write.
func (w *websocketWriter) writeCloseMsg() (int, error) {
	msg := w.cm
	w.cm = nil
	w.noMoreSend = true
	return w.w.Write(msg)
}
// wsMaskBuf XORs buf in place with the 4-byte masking key, repeating the key
// across the buffer. Applying it twice with the same key restores the input.
func wsMaskBuf(key, buf []byte) {
	for i := range buf {
		buf[i] ^= key[i%4]
	}
}
// wsCreateFrameHeader builds a complete masked frame header for a payload of
// length l. It returns the header and the masking key, which is a sub-slice of
// the header.
func wsCreateFrameHeader(compressed bool, frameType wsOpCode, l int) ([]byte, []byte) {
	var hdr [wsMaxFrameHeaderSize]byte
	size, key := wsFillFrameHeader(hdr[:], compressed, frameType, l)
	return hdr[:size], key
}
// wsFillFrameHeader writes a final (FIN), masked frame header into fh for a
// payload of length l and returns the header size plus the masking key.
//
// The length uses the shortest RFC 6455 encoding: 7 bits, 16 bits (marker
// 126) or 64 bits (marker 127). RSV1 is set when compressed is true. The
// 4-byte key is drawn from crypto/rand, falling back to math/rand if that
// fails. It is stored in fh right after the length, and the returned key
// slice aliases that part of fh. fh must hold at least wsMaxFrameHeaderSize
// bytes for large payloads.
func wsFillFrameHeader(fh []byte, compressed bool, frameType wsOpCode, l int) (int, []byte) {
	first := byte(frameType) | wsFinalBit
	if compressed {
		first |= wsRsv1Bit
	}
	fh[0] = first

	var n int
	switch {
	case l <= 125:
		fh[1] = wsMaskBit | byte(l)
		n = 2
	case l < 65536:
		fh[1] = wsMaskBit | 126
		binary.BigEndian.PutUint16(fh[2:], uint16(l))
		n = 4
	default:
		fh[1] = wsMaskBit | 127
		binary.BigEndian.PutUint64(fh[2:], uint64(l))
		n = 10
	}

	key := fh[n : n+4]
	if _, err := io.ReadFull(rand.Reader, key); err != nil {
		binary.LittleEndian.PutUint32(key, uint32(mrand.Int31()))
	}
	return n + 4, key
}
// wsInitHandshake upgrades nc's connection to a websocket.
//
// Steps:
//   - establish TLS when required, or bind the plain connection;
//   - send an HTTP/1.1 upgrade request (optionally asking for
//     permessage-deflate) and validate the 101 response;
//   - swap nc's reader and writer for websocket framing adapters.
//
// Only u's scheme and host are used; Opts.ProxyPath, if set, becomes the
// request path. Errors from TLS setup, request writing, response reading,
// response validation or compression negotiation are returned unchanged or as
// "invalid websocket connection" / "compression negotiation error".
func (nc *Conn ) wsInitHandshake (u *url .URL ) error {
compress := nc .Opts .Compression
// TLS is used for the wss scheme or whenever any TLS option is configured.
tlsRequired := u .Scheme == wsSchemeTLS || nc .Opts .Secure || nc .Opts .TLSConfig != nil || nc .Opts .TLSCertCB != nil || nc .Opts .RootCAsCB != nil
if tlsRequired {
if err := nc .makeTLSConn (); err != nil {
return err
}
} else {
nc .bindToNewConn ()
}
var err error
// Build the request URL from the host; the scheme only affects how the URL
// is formed, the request is written directly on nc.conn below.
scheme := "http"
if tlsRequired {
scheme = "https"
}
ustr := fmt .Sprintf ("%s://%s" , scheme , u .Host )
if nc .Opts .ProxyPath != "" {
proxyPath := nc .Opts .ProxyPath
if !strings .HasPrefix (proxyPath , "/" ) {
proxyPath = "/" + proxyPath
}
ustr += proxyPath
}
u , err = url .Parse (ustr )
if err != nil {
return err
}
req := &http .Request {
Method : "GET" ,
URL : u ,
Proto : "HTTP/1.1" ,
ProtoMajor : 1 ,
ProtoMinor : 1 ,
Header : make (http .Header ),
Host : u .Host ,
}
// Random nonce; the server must answer with wsAcceptKey(wsKey).
wsKey , err := wsMakeChallengeKey ()
if err != nil {
return err
}
req .Header ["Upgrade" ] = []string {"websocket" }
req .Header ["Connection" ] = []string {"Upgrade" }
req .Header ["Sec-WebSocket-Key" ] = []string {wsKey }
req .Header ["Sec-WebSocket-Version" ] = []string {"13" }
if compress {
req .Header .Add ("Sec-WebSocket-Extensions" , wsPMCReqHeaderValue )
}
// Add user-supplied headers (static or from the handler callback).
if err := nc .wsUpdateConnectionHeaders (req ); err != nil {
return err
}
if err := req .Write (nc .conn ); err != nil {
return err
}
var resp *http .Response
br := bufio .NewReaderSize (nc .conn , 4096 )
// Bound the wait for the upgrade response by Opts.Timeout.
// NOTE(review): a zero Opts.Timeout yields an already-expired deadline here;
// confirm callers never pass 0.
nc .conn .SetReadDeadline (time .Now ().Add (nc .Opts .Timeout ))
resp , err = http .ReadResponse (br , req )
// Require 101 Switching Protocols, the upgrade headers and the accept key.
if err == nil &&
(resp .StatusCode != 101 ||
!strings .EqualFold (resp .Header .Get ("Upgrade" ), "websocket" ) ||
!strings .EqualFold (resp .Header .Get ("Connection" ), "upgrade" ) ||
resp .Header .Get ("Sec-Websocket-Accept" ) != wsAcceptKey (wsKey )) {
err = errors .New ("invalid websocket connection" )
}
// The server may decline compression (then it is simply disabled). If it
// accepts, it must also accept no context takeover in both directions.
if err == nil && compress {
srvCompress , noCtxTakeover := wsPMCExtensionSupport (resp .Header )
if !srvCompress {
compress = false
} else if !noCtxTakeover {
err = errors .New ("compression negotiation error" )
}
}
if resp != nil {
resp .Body .Close ()
}
// Clear the handshake deadline.
nc .conn .SetReadDeadline (time .Time {})
if err != nil {
return err
}
// Swap in the websocket framing reader/writer.
wsr := wsNewReader (nc .br .r )
wsr .nc = nc
wsr .compress = compress
// Bytes the handshake reader buffered past the HTTP response belong to the
// websocket stream; hand them to the websocket reader.
if n := br .Buffered (); n != 0 {
wsr .ib , _ = br .Peek (n )
}
nc .br .r = wsr
nc .bw .w = &websocketWriter {w : nc .bw .w , compress : compress }
nc .ws = true
return nil
}
// wsClose enqueues a normal-closure (1000) close frame with an empty reason,
// under nc.mu. It does nothing for non-websocket connections.
func (nc *Conn) wsClose() {
	nc.mu.Lock()
	defer nc.mu.Unlock()
	if nc.ws {
		nc.wsEnqueueCloseMsgLocked(wsCloseStatusNormalClosure, _EMPTY_)
	}
}
// wsEnqueueCloseMsg enqueues a close frame with the given status and reason.
// nc.mu is taken first when needsLock is true. A nil nc is tolerated and
// ignored.
func (nc *Conn) wsEnqueueCloseMsg(needsLock bool, status int, payload string) {
	if nc == nil {
		return
	}
	if !needsLock {
		nc.wsEnqueueCloseMsgLocked(status, payload)
		return
	}
	nc.mu.Lock()
	defer nc.mu.Unlock()
	nc.wsEnqueueCloseMsgLocked(status, payload)
}
// wsEnqueueCloseMsgLocked builds a masked close frame (2-byte status followed
// by payload) and stores it as the writer's pending close message. It then
// flushes the writer and closes the compressor, if any. It is a no-op when the
// writer is not a websocketWriter or a close frame was already enqueued. The
// caller must hold nc.mu.
//
// The frame buffer is sized for a short (2-byte) length header, so
// 2+len(payload) is expected to stay within the 125-byte control-frame limit.
func (nc *Conn) wsEnqueueCloseMsgLocked(status int, payload string) {
	wr, ok := nc.bw.w.(*websocketWriter)
	if !ok || wr.cmDone {
		return
	}
	bodyLen := 2 + len(payload)
	frame := make([]byte, 2+4+bodyLen)
	hdrLen, key := wsFillFrameHeader(frame, false, wsCloseMessage, bodyLen)
	body := frame[hdrLen : hdrLen+bodyLen]
	binary.BigEndian.PutUint16(body, uint16(status))
	copy(body[2:], payload)
	wsMaskBuf(key, body)
	wr.cm, wr.cmDone = frame, true
	nc.bw.flush()
	if wr.compressor != nil {
		wr.compressor.Close()
	}
}
// wsEnqueueControlMsg queues a masked control frame (e.g. a pong) with the
// given payload on the websocket writer and then flushes nc.bw. nc.mu is taken
// first when needsLock is true. A nil nc and non-websocket writers are ignored.
//
// The payload is copied before masking. Callers such as handleControlFrame
// pass a slice that may alias the reader's input buffer, and the queued frame
// may be written after this call returns (if the buffered writer defers the
// flush — TODO confirm when that happens). Masking and queuing the caller's
// slice directly would mutate their buffer, and could send corrupted data if
// that buffer is reused before the frame goes out.
func (nc *Conn) wsEnqueueControlMsg(needsLock bool, frameType wsOpCode, payload []byte) {
	if nc == nil {
		return
	}
	if needsLock {
		nc.mu.Lock()
		defer nc.mu.Unlock()
	}
	wr, ok := nc.bw.w.(*websocketWriter)
	if !ok {
		return
	}
	fh, key := wsCreateFrameHeader(false, frameType, len(payload))
	wr.ctrlFrames = append(wr.ctrlFrames, fh)
	if len(payload) > 0 {
		masked := append([]byte(nil), payload...)
		wsMaskBuf(key, masked)
		wr.ctrlFrames = append(wr.ctrlFrames, masked)
	}
	nc.bw.flush()
}
// wsUpdateConnectionHeaders adds user-configured headers to the upgrade
// request. Headers come from Opts.WebSocketConnectionHeadersHandler when set
// (its error is returned as is), otherwise from
// Opts.WebSocketConnectionHeaders. Values are appended, not replaced.
func (nc *Conn) wsUpdateConnectionHeaders(req *http.Request) error {
	headers := nc.Opts.WebSocketConnectionHeaders
	if handler := nc.Opts.WebSocketConnectionHeadersHandler; handler != nil {
		h, err := handler()
		if err != nil {
			return err
		}
		headers = h
	}
	for name, vals := range headers {
		for _, v := range vals {
			req.Header.Add(name, v)
		}
	}
	return nil
}
// wsPMCExtensionSupport inspects the Sec-Websocket-Extensions response header.
// The first result reports whether the server accepted permessage-deflate. The
// second reports whether, for the first such extension entry, it also accepted
// both server_no_context_takeover and client_no_context_takeover. Names are
// matched case-insensitively after trimming spaces and tabs.
func wsPMCExtensionSupport(header http.Header) (bool, bool) {
	const cutset = " \t"
	for _, list := range header["Sec-Websocket-Extensions"] {
		for _, ext := range strings.Split(list, ",") {
			params := strings.Split(strings.Trim(ext, cutset), ";")
			for idx := range params {
				if !strings.EqualFold(strings.Trim(params[idx], cutset), wsPMCExtension) {
					continue
				}
				srvNoCtx, cliNoCtx := false, false
				for _, opt := range params[idx+1:] {
					opt = strings.Trim(opt, cutset)
					switch {
					case strings.EqualFold(opt, wsPMCSrvNoCtx):
						srvNoCtx = true
					case strings.EqualFold(opt, wsPMCCliNoCtx):
						cliNoCtx = true
					}
					if srvNoCtx && cliNoCtx {
						return true, true
					}
				}
				return true, false
			}
		}
	}
	return false, false
}
// wsMakeChallengeKey returns a fresh Sec-WebSocket-Key value: 16 bytes from
// crypto/rand, base64 (standard) encoded.
func wsMakeChallengeKey() (string, error) {
	var nonce [16]byte
	if _, err := io.ReadFull(rand.Reader, nonce[:]); err != nil {
		return "", err
	}
	return base64.StdEncoding.EncodeToString(nonce[:]), nil
}
// wsAcceptKey computes the Sec-WebSocket-Accept value expected for key:
// base64(SHA-1(key + wsGUID)).
func wsAcceptKey(key string) string {
	digest := sha1.Sum(append([]byte(key), wsGUID...))
	return base64.StdEncoding.EncodeToString(digest[:])
}
// wsIsControlFrame reports whether frameType is a control opcode (close, ping
// or pong). It relies on every control opcode being >= wsCloseMessage (8),
// while data and continuation opcodes are smaller.
func wsIsControlFrame(frameType wsOpCode ) bool {
return frameType >= wsCloseMessage
}
// isWebsocketScheme reports whether u uses the "ws" or "wss" scheme (exact,
// case-sensitive match).
func isWebsocketScheme(u *url.URL) bool {
	switch u.Scheme {
	case wsScheme, wsSchemeTLS:
		return true
	default:
		return false
	}
}
// The pages are generated with Golds v0.8.2. (GOOS=linux GOARCH=amd64)
// Golds is a Go 101 project developed by Tapir Liu.
// PR and bug reports are welcome and can be submitted to the issue list.
// Please follow @zigo_101 (reachable from the left QR code) to get the latest news of Golds.