// Copyright 2021-2023 The NATS Authors
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

package nats

import (
	"bufio"
	"bytes"
	"crypto/rand"
	"crypto/sha1"
	"encoding/base64"
	"encoding/binary"
	"errors"
	"fmt"
	"io"
	mrand "math/rand"
	"net/http"
	"net/url"
	"strings"
	"time"
	"unicode/utf8"

	"github.com/klauspost/compress/flate"
)

// wsOpCode is a websocket frame opcode, taken from the low 4 bits of the
// first byte of a frame header (RFC 6455 §5.2).
type wsOpCode int

const (
	// Frame opcodes (low 4 bits of the first header byte).
	// From https://tools.ietf.org/html/rfc6455#section-5.2
	wsTextMessage   = wsOpCode(1)
	wsBinaryMessage = wsOpCode(2)
	wsCloseMessage  = wsOpCode(8)
	wsPingMessage   = wsOpCode(9)
	wsPongMessage   = wsOpCode(10)

	// Flags in the first header byte: FIN (last fragment of a message)
	// followed by the three reserved bits.
	wsFinalBit = 1 << 7
	wsRsv1Bit  = 1 << 6 // Used for compression, from https://tools.ietf.org/html/rfc7692#section-6
	wsRsv2Bit  = 1 << 5
	wsRsv3Bit  = 1 << 4

	// Flag in the second header byte marking the payload as masked. The
	// other 7 bits of that byte hold the payload length, or 126/127 to
	// announce the extended 16-bit/64-bit length forms.
	wsMaskBit = 1 << 7

	// wsContinuationFrame is the opcode of a non-first message fragment.
	// wsMaxFrameHeaderSize is the largest header written by this client:
	// 2 fixed bytes + 8 bytes of extended length + 4 bytes of masking key.
	// wsMaxControlPayloadSize is the payload limit for control frames.
	// wsCloseSatusSize is the size of the status code that starts a close
	// frame payload (the "Satus" misspelling is part of the identifier).
	wsContinuationFrame     = 0
	wsMaxFrameHeaderSize    = 14
	wsMaxControlPayloadSize = 125
	wsCloseSatusSize        = 2

	// Close frame status codes.
	// From https://tools.ietf.org/html/rfc6455#section-11.7
	wsCloseStatusNormalClosure      = 1000
	wsCloseStatusNoStatusReceived   = 1005
	wsCloseStatusAbnormalClosure    = 1006
	wsCloseStatusInvalidPayloadData = 1007

	// URL schemes of websocket server URLs; "wss" requires TLS.
	wsScheme    = "ws"
	wsSchemeTLS = "wss"

	// permessage-deflate negotiation (RFC 7692). The request header value
	// asks the server to use no context takeover in both directions.
	wsPMCExtension      = "permessage-deflate" // per-message compression
	wsPMCSrvNoCtx       = "server_no_context_takeover"
	wsPMCCliNoCtx       = "client_no_context_takeover"
	wsPMCReqHeaderValue = wsPMCExtension + "; " + wsPMCSrvNoCtx + "; " + wsPMCCliNoCtx
)

// wsGUID is appended to the client's Sec-WebSocket-Key and hashed with SHA-1
// to compute the expected Sec-WebSocket-Accept value.
// From https://tools.ietf.org/html/rfc6455#section-1.3
var wsGUID = []byte("258EAFA5-E914-47DA-95CA-C5AB0DC85B11")

// compressFinalBlock is appended to the compressed payload of a message
// before inflating it. The first 4 bytes are the 0x00 0x00 0xff 0xff
// trailer that senders strip (RFC 7692 §7.2.2). The next 5 bytes are an
// empty final stored deflate block, so the flate reader reaches a clean end
// of stream instead of reporting an unexpected EOF.
var compressFinalBlock = []byte{0x00, 0x00, 0xff, 0xff, 0x01, 0x00, 0x00, 0xff, 0xff}

// websocketReader decodes websocket frames read from r and returns the
// (possibly decompressed) message payloads through its Read method.
type websocketReader struct {
	// r is the underlying reader that frames are read from.
	r        io.Reader
	// pending holds decoded payloads not yet copied out by Read.
	pending  [][]byte
	// compress is set when permessage-deflate was negotiated; uncompressed
	// payloads are then cloned before being queued in pending.
	compress bool
	// ib holds bytes left buffered by the handshake's bufio.Reader; they
	// are decoded by the first Read.
	ib       []byte
	// ff is true when no fragmented message is in progress (the last data
	// frame had FIN set). wsNewReader initializes it to true.
	ff       bool
	// fc is true while the message being received is compressed (RSV1 set
	// on its first frame); it is cleared once the message is decompressed.
	fc       bool
	// nl is set by doneWithConnect and passed to the wsEnqueue* helpers,
	// which take the connection lock when it is true.
	nl       bool
	// dc inflates compressed messages; it is created lazily by addCBuf.
	dc       *wsDecompressor
	// nc is the connection used to enqueue pong and close replies.
	nc       *Conn
}

// wsDecompressor feeds the queued compressed chunks of one message to a
// flate reader through its Read and ReadByte methods.
type wsDecompressor struct {
	// flate is the inflating reader; it is created once and Reset for each
	// message.
	flate io.ReadCloser
	// bufs are the compressed chunks of the current message, in order.
	bufs  [][]byte
	// off is the read offset into bufs[0].
	off   int
}

// websocketWriter wraps outgoing data in masked binary websocket frames
// (compressed when compress is set). It also writes queued control frames
// and the final close message.
type websocketWriter struct {
	// w is the underlying writer receiving the framed bytes.
	w          io.Writer
	// compress enables permessage-deflate for outgoing data frames.
	compress   bool
	// compressor is created on first use and Reset for every Write.
	compressor *flate.Writer
	ctrlFrames [][]byte // pending frames that should be sent at the next Write()
	cm         []byte   // close message that needs to be sent when everything else has been sent
	cmDone     bool     // a close message has been added or sent (never going back to false)
	noMoreSend bool     // if true, even if there is a Write() call, we should not send anything
}

func ( *wsDecompressor) ( []byte) (int, error) {
	if len() == 0 {
		return 0, nil
	}
	if len(.bufs) == 0 {
		return 0, io.EOF
	}
	 := 0
	 := len()
	for  := .bufs[0];  != nil &&  > 0; {
		 := len([.off:])
		if  >  {
			 = 
		}
		copy([:], [.off:.off+])
		 += 
		 -= 
		.off += 
		 = .nextBuf()
	}
	return , nil
}

func ( *wsDecompressor) () []byte {
	// We still have remaining data in the first buffer
	if .off != len(.bufs[0]) {
		return .bufs[0]
	}
	// We read the full first buffer. Reset offset.
	.off = 0
	// We were at the last buffer, so we are done.
	if len(.bufs) == 1 {
		.bufs = nil
		return nil
	}
	// Here we move to the next buffer.
	.bufs = .bufs[1:]
	return .bufs[0]
}

func ( *wsDecompressor) () (byte, error) {
	if len(.bufs) == 0 {
		return 0, io.EOF
	}
	 := .bufs[0][.off]
	.off++
	.nextBuf()
	return , nil
}

func ( *wsDecompressor) ( []byte) {
	.bufs = append(.bufs, )
}

func ( *wsDecompressor) () ([]byte, error) {
	.off = 0
	// As per https://tools.ietf.org/html/rfc7692#section-7.2.2
	// add 0x00, 0x00, 0xff, 0xff and then a final block so that flate reader
	// does not report unexpected EOF.
	.bufs = append(.bufs, compressFinalBlock)
	// Create or reset the decompressor with his object (wsDecompressor)
	// that provides Read() and ReadByte() APIs that will consume from
	// the compressed buffers (d.bufs).
	if .flate == nil {
		.flate = flate.NewReader()
	} else {
		.flate.(flate.Resetter).Reset(, nil)
	}
	,  := io.ReadAll(.flate)
	// Now reset the compressed buffers list
	.bufs = nil
	return , 
}

func wsNewReader( io.Reader) *websocketReader {
	return &websocketReader{r: , ff: true}
}

// From now on, reads will be from the readLoop and we will need to
// acquire the connection lock should we have to send/write a control
// message from handleControlFrame.
//
// Note: this runs under the connection lock.
func ( *websocketReader) () {
	.nl = true
}

func ( *websocketReader) ( []byte) (int, error) {
	var  error
	var  []byte

	if  := len(.ib);  > 0 {
		 = .ib
		.ib = nil
	} else {
		if len(.pending) > 0 {
			return .drainPending(), nil
		}

		// Get some data from the underlying reader.
		,  := .r.Read()
		if  != nil {
			return 0, 
		}
		 = [:]
	}

	// Now parse this and decode frames. We will possibly read more to
	// ensure that we get a full frame.
	var (
		 []byte
		    int
		    = len()
		    = 0
	)
	for  <  {
		 := []
		 := wsOpCode( & 0xF)
		 := &wsFinalBit != 0
		 := &wsRsv1Bit != 0
		++

		, ,  = wsGet(.r, , , 1)
		if  != nil {
			return 0, 
		}
		 := [0]

		// Store size in case it is < 125
		 = int( & 0x7F)

		switch  {
		case wsPingMessage, wsPongMessage, wsCloseMessage:
			if  > wsMaxControlPayloadSize {
				return 0, fmt.Errorf(
					"control frame length bigger than maximum allowed of %v bytes",
					wsMaxControlPayloadSize)
			}
			if  {
				return 0, errors.New("control frame should not be compressed")
			}
			if ! {
				return 0, errors.New("control frame does not have final bit set")
			}
		case wsTextMessage, wsBinaryMessage:
			if !.ff {
				return 0, errors.New("new message started before final frame for previous message was received")
			}
			.ff = 
			.fc = 
		case wsContinuationFrame:
			// Compressed bit must be only set in the first frame
			if .ff ||  {
				return 0, errors.New("invalid continuation frame")
			}
			.ff = 
		default:
			return 0, fmt.Errorf("unknown opcode %v", )
		}

		// If the encoded size is <= 125, then `rem` is simply the remainder size of the
		// frame. If it is 126, then the actual size is encoded as a uint16. For larger
		// frames, `rem` will initially be 127 and the actual size is encoded as a uint64.
		switch  {
		case 126:
			, ,  = wsGet(.r, , , 2)
			if  != nil {
				return 0, 
			}
			 = int(binary.BigEndian.Uint16())
		case 127:
			, ,  = wsGet(.r, , , 8)
			if  != nil {
				return 0, 
			}
			 = int(binary.BigEndian.Uint64())
		}

		// Handle control messages in place...
		if wsIsControlFrame() {
			,  = .handleControlFrame(, , , )
			if  != nil {
				return 0, 
			}
			 = 0
			continue
		}

		var  []byte
		// This ensures that we get the full payload for this frame.
		, ,  = wsGet(.r, , , )
		if  != nil {
			return 0, 
		}
		// We read the full frame.
		 = 0
		 := true
		if .fc {
			// Don't add to pending if we are not dealing with the final frame.
			 = .ff
			// Add the compressed payload buffer to the list.
			.addCBuf()
			// Decompress only when this is the final frame.
			if .ff {
				,  = .dc.decompress()
				if  != nil {
					return 0, 
				}
				.fc = false
			}
		} else if .compress {
			 = bytes.Clone()
		}
		// Add to the pending list if dealing with uncompressed frames or
		// after we have received the full compressed message and decompressed it.
		if  {
			.pending = append(.pending, )
		}
	}
	// In case of compression, there may be nothing to drain
	if len(.pending) > 0 {
		return .drainPending(), nil
	}
	return 0, nil
}

func ( *websocketReader) ( []byte) {
	if .dc == nil {
		.dc = &wsDecompressor{}
	}
	// Add a copy of the incoming buffer to the list of compressed buffers.
	.dc.addBuf(append([]byte(nil), ...))
}

func ( *websocketReader) ( []byte) int {
	var  int
	var  = len()

	for ,  := range .pending {
		if +len() <=  {
			copy([:], )
			 += len()
		} else {
			// Is there room left?
			if  <  {
				// Write the partial and update this slice.
				 :=  - 
				copy([:], [:])
				 += 
				.pending[] = [:]
			}
			// These are the remaining slices that will need to be used at
			// the next Read() call.
			.pending = .pending[:]
			return 
		}
	}
	.pending = .pending[:0]
	return 
}

// wsGet returns the next `needed` bytes starting at buf[pos], together with
// the position following them in buf.
//
// If buf already holds enough bytes, the result is a sub-slice of buf.
// Otherwise a new slice is allocated and seeded with the rest of buf, and
// the missing bytes are read from r. The returned position is then len(buf),
// since buf has been fully consumed. On a read error the partially filled
// slice, the number of bytes filled so far, and the error are returned.
func wsGet(r io.Reader, buf []byte, pos, needed int) ([]byte, int, error) {
	avail := len(buf) - pos
	if needed <= avail {
		end := pos + needed
		return buf[pos:end], end, nil
	}
	out := make([]byte, needed)
	filled := copy(out, buf[pos:])
	for filled != needed {
		n, err := r.Read(out[filled:])
		filled += n
		if err != nil {
			return out, filled, err
		}
	}
	return out, len(buf), nil
}

func ( *websocketReader) ( wsOpCode,  []byte, ,  int) (int, error) {
	var  []byte
	var  error

	if  > 0 {
		, ,  = wsGet(.r, , , )
		if  != nil {
			return , 
		}
	}
	switch  {
	case wsCloseMessage:
		 := wsCloseStatusNoStatusReceived
		var  string
		 := len()
		// If there is a payload, the status is represented as a 2-byte
		// unsigned integer (in network byte order). Then, there may be an
		// optional body.
		,  :=  >= wsCloseSatusSize,  > wsCloseSatusSize
		if  {
			// Decode the status
			 = int(binary.BigEndian.Uint16([:wsCloseSatusSize]))
			// Now if there is a body, capture it and make sure this is a valid UTF-8.
			if  {
				 = string([wsCloseSatusSize:])
				if !utf8.ValidString() {
					// https://tools.ietf.org/html/rfc6455#section-5.5.1
					// If body is present, it must be a valid utf8
					 = wsCloseStatusInvalidPayloadData
					 = "invalid utf8 body in close frame"
				}
			}
		}
		.nc.wsEnqueueCloseMsg(.nl, , )
		// Return io.EOF so that readLoop will close the connection as client closed
		// after processing pending buffers.
		return , io.EOF
	case wsPingMessage:
		.nc.wsEnqueueControlMsg(.nl, wsPongMessage, )
	case wsPongMessage:
		// Nothing to do..
	}
	return , nil
}

func ( *websocketWriter) ( []byte) (int, error) {
	if .noMoreSend {
		return 0, nil
	}
	var  int
	var  int
	var  error
	// If there are control frames, they can be sent now. Actually spec says
	// that they should be sent ASAP, so we will send before any application data.
	if len(.ctrlFrames) > 0 {
		,  = .writeCtrlFrames()
		if  != nil {
			return , 
		}
		 += 
	}
	// Do the following only if there is something to send.
	// We will end with checking for need to send close message.
	if len() > 0 {
		if .compress {
			 := &bytes.Buffer{}
			if .compressor == nil {
				.compressor, _ = flate.NewWriter(, flate.BestSpeed)
			} else {
				.compressor.Reset()
			}
			if ,  = .compressor.Write();  != nil {
				return , 
			}
			if  = .compressor.Flush();  != nil {
				return , 
			}
			 := .Bytes()
			 = [:len()-4]
		}
		,  := wsCreateFrameHeader(.compress, wsBinaryMessage, len())
		wsMaskBuf(, )
		,  = .w.Write()
		 += 
		if  == nil {
			,  = .w.Write()
			 += 
		}
	}
	if  == nil && .cm != nil {
		,  = .writeCloseMsg()
		 += 
	}
	return , 
}

func ( *websocketWriter) () (int, error) {
	var (
		     int
		 int
		     int
		   error
	)
	for ;  < len(.ctrlFrames); ++ {
		 := .ctrlFrames[]
		,  = .w.Write()
		 += 
		if  != nil {
			break
		}
	}
	if  != len(.ctrlFrames) {
		.ctrlFrames = .ctrlFrames[+1:]
	} else {
		.ctrlFrames = .ctrlFrames[:0]
	}
	return , 
}

func ( *websocketWriter) () (int, error) {
	,  := .w.Write(.cm)
	.cm, .noMoreSend = nil, true
	return , 
}

// wsMaskBuf XORs buf in place with the 4-byte masking key, repeating the
// key over the whole buffer (RFC 6455 §5.3). Applying it twice restores
// the original bytes.
func wsMaskBuf(key, buf []byte) {
	for i := range buf {
		buf[i] ^= key[i&3]
	}
}

// Create the frame header.
// Encodes the frame type and optional compression flag, and the size of the payload.
func wsCreateFrameHeader( bool,  wsOpCode,  int) ([]byte, []byte) {
	 := make([]byte, wsMaxFrameHeaderSize)
	,  := wsFillFrameHeader(, , , )
	return [:], 
}

func wsFillFrameHeader( []byte,  bool,  wsOpCode,  int) (int, []byte) {
	var  int
	 := byte()
	 |= wsFinalBit
	if  {
		 |= wsRsv1Bit
	}
	 := byte(wsMaskBit)
	switch {
	case  <= 125:
		 = 2
		[0] = 
		[1] =  | byte()
	case  < 65536:
		 = 4
		[0] = 
		[1] =  | 126
		binary.BigEndian.PutUint16([2:], uint16())
	default:
		 = 10
		[0] = 
		[1] =  | 127
		binary.BigEndian.PutUint64([2:], uint64())
	}
	var  []byte
	var  [4]byte
	if ,  := io.ReadFull(rand.Reader, [:4]);  != nil {
		 := mrand.Int31()
		binary.LittleEndian.PutUint32([:4], uint32())
	}
	copy([:], [:4])
	 = [ : +4]
	 += 4
	return , 
}

func ( *Conn) ( *url.URL) error {
	 := .Opts.Compression
	 := .Scheme == wsSchemeTLS || .Opts.Secure || .Opts.TLSConfig != nil || .Opts.TLSCertCB != nil || .Opts.RootCAsCB != nil
	// Do TLS here as needed.
	if  {
		if  := .makeTLSConn();  != nil {
			return 
		}
	} else {
		.bindToNewConn()
	}

	var  error

	// For http request, we need the passed URL to contain either http or https scheme.
	 := "http"
	if  {
		 = "https"
	}
	 := fmt.Sprintf("%s://%s", , .Host)

	if .Opts.ProxyPath != "" {
		 := .Opts.ProxyPath
		if !strings.HasPrefix(, "/") {
			 = "/" + 
		}
		 += 
	}

	,  = url.Parse()
	if  != nil {
		return 
	}
	 := &http.Request{
		Method:     "GET",
		URL:        ,
		Proto:      "HTTP/1.1",
		ProtoMajor: 1,
		ProtoMinor: 1,
		Header:     make(http.Header),
		Host:       .Host,
	}
	,  := wsMakeChallengeKey()
	if  != nil {
		return 
	}

	.Header["Upgrade"] = []string{"websocket"}
	.Header["Connection"] = []string{"Upgrade"}
	.Header["Sec-WebSocket-Key"] = []string{}
	.Header["Sec-WebSocket-Version"] = []string{"13"}
	if  {
		.Header.Add("Sec-WebSocket-Extensions", wsPMCReqHeaderValue)
	}
	if  := .wsUpdateConnectionHeaders();  != nil {
		return 
	}
	if  := .Write(.conn);  != nil {
		return 
	}

	var  *http.Response

	 := bufio.NewReaderSize(.conn, 4096)
	.conn.SetReadDeadline(time.Now().Add(.Opts.Timeout))
	,  = http.ReadResponse(, )
	if  == nil &&
		(.StatusCode != 101 ||
			!strings.EqualFold(.Header.Get("Upgrade"), "websocket") ||
			!strings.EqualFold(.Header.Get("Connection"), "upgrade") ||
			.Header.Get("Sec-Websocket-Accept") != wsAcceptKey()) {

		 = errors.New("invalid websocket connection")
	}
	// Check compression extension...
	if  == nil &&  {
		// Check that not only permessage-deflate extension is present, but that
		// we also have server and client no context take over.
		,  := wsPMCExtensionSupport(.Header)

		// If server does not support compression, then simply disable it in our side.
		if ! {
			 = false
		} else if ! {
			 = errors.New("compression negotiation error")
		}
	}
	if  != nil {
		.Body.Close()
	}
	.conn.SetReadDeadline(time.Time{})
	if  != nil {
		return 
	}

	 := wsNewReader(.br.r)
	.nc = 
	.compress = 
	// We have to slurp whatever is in the bufio reader and copy to br.r
	if  := .Buffered();  != 0 {
		.ib, _ = .Peek()
	}
	.br.r = 
	.bw.w = &websocketWriter{w: .bw.w, compress: }
	.ws = true
	return nil
}

func ( *Conn) () {
	.mu.Lock()
	defer .mu.Unlock()
	if !.ws {
		return
	}
	.wsEnqueueCloseMsgLocked(wsCloseStatusNormalClosure, _EMPTY_)
}

func ( *Conn) ( bool,  int,  string) {
	// In some low-level unit tests it will happen...
	if  == nil {
		return
	}
	if  {
		.mu.Lock()
		defer .mu.Unlock()
	}
	.wsEnqueueCloseMsgLocked(, )
}

func ( *Conn) ( int,  string) {
	,  := .bw.w.(*websocketWriter)
	if ! || .cmDone {
		return
	}
	 := 2 + len()
	 := make([]byte, 2+4+)
	,  := wsFillFrameHeader(, false, wsCloseMessage, )
	// Set the status
	binary.BigEndian.PutUint16([:], uint16())
	// If there is a payload, copy
	if len() > 0 {
		copy([+2:], )
	}
	// Mask status + payload
	wsMaskBuf(, [:+])
	.cm = 
	.cmDone = true
	.bw.flush()
	if  := .compressor;  != nil {
		.Close()
	}
}

func ( *Conn) ( bool,  wsOpCode,  []byte) {
	// In some low-level unit tests it will happen...
	if  == nil {
		return
	}
	if  {
		.mu.Lock()
		defer .mu.Unlock()
	}
	,  := .bw.w.(*websocketWriter)
	if ! {
		return
	}
	,  := wsCreateFrameHeader(false, , len())
	.ctrlFrames = append(.ctrlFrames, )
	if len() > 0 {
		wsMaskBuf(, )
		.ctrlFrames = append(.ctrlFrames, )
	}
	.bw.flush()
}

func ( *Conn) ( *http.Request) error {
	var  http.Header
	var  error
	if .Opts.WebSocketConnectionHeadersHandler != nil {
		,  = .Opts.WebSocketConnectionHeadersHandler()
		if  != nil {
			return 
		}
	} else {
		 = .Opts.WebSocketConnectionHeaders
	}
	for ,  := range  {
		for ,  := range  {
			.Header.Add(, )
		}
	}
	return nil
}

func wsPMCExtensionSupport( http.Header) (bool, bool) {
	for ,  := range ["Sec-Websocket-Extensions"] {
		 := strings.Split(, ",")
		for ,  := range  {
			 = strings.Trim(, " \t")
			 := strings.Split(, ";")
			for ,  := range  {
				 = strings.Trim(, " \t")
				if strings.EqualFold(, wsPMCExtension) {
					var  bool
					var  bool
					for  :=  + 1;  < len(); ++ {
						 = []
						 = strings.Trim(, " \t")
						if strings.EqualFold(, wsPMCSrvNoCtx) {
							 = true
						} else if strings.EqualFold(, wsPMCCliNoCtx) {
							 = true
						}
						if  &&  {
							return true, true
						}
					}
					return true, false
				}
			}
		}
	}
	return false, false
}

// wsMakeChallengeKey returns a new Sec-WebSocket-Key value: 16 random bytes
// from crypto/rand, base64-encoded (RFC 6455 §4.1).
func wsMakeChallengeKey() (string, error) {
	var nonce [16]byte
	if _, err := io.ReadFull(rand.Reader, nonce[:]); err != nil {
		return "", err
	}
	return base64.StdEncoding.EncodeToString(nonce[:]), nil
}

func wsAcceptKey( string) string {
	 := sha1.New()
	.Write([]byte())
	.Write(wsGUID)
	return base64.StdEncoding.EncodeToString(.Sum(nil))
}

// Returns true if the op code corresponds to a control frame.
func wsIsControlFrame( wsOpCode) bool {
	return  >= wsCloseMessage
}

func isWebsocketScheme( *url.URL) bool {
	return .Scheme == wsScheme || .Scheme == wsSchemeTLS
}