fermentord/vendor/github.com/nats-io/nats.go/ws.go

780 lines
19 KiB
Go

// Copyright 2021-2023 The NATS Authors
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package nats
import (
"bufio"
"bytes"
"crypto/rand"
"crypto/sha1"
"encoding/base64"
"encoding/binary"
"errors"
"fmt"
"io"
mrand "math/rand"
"net/http"
"net/url"
"strings"
"time"
"unicode/utf8"
"github.com/klauspost/compress/flate"
)
// wsOpCode is a websocket frame opcode. The reader extracts it from the
// low 4 bits of the first byte of a frame header.
type wsOpCode int
const (
// From https://tools.ietf.org/html/rfc6455#section-5.2
// Data frame opcodes.
wsTextMessage = wsOpCode(1)
wsBinaryMessage = wsOpCode(2)
// Control frame opcodes.
wsCloseMessage = wsOpCode(8)
wsPingMessage = wsOpCode(9)
wsPongMessage = wsOpCode(10)
// Flag bits of the first header byte.
wsFinalBit = 1 << 7
wsRsv1Bit = 1 << 6 // Used for compression, from https://tools.ietf.org/html/rfc7692#section-6
wsRsv2Bit = 1 << 5
wsRsv3Bit = 1 << 4
// Flag bit of the second header byte; always set on frames built by
// wsFillFrameHeader.
wsMaskBit = 1 << 7
// Opcode of a continuation frame (non-first fragment of a message).
wsContinuationFrame = 0
// Largest header produced by wsFillFrameHeader: 2 bytes, plus an 8-byte
// extended length, plus the 4-byte masking key.
wsMaxFrameHeaderSize = 14
// Largest payload accepted for ping, pong and close frames.
wsMaxControlPayloadSize = 125
// Size in bytes of the status code at the start of a close frame payload.
// (The identifier is misspelled, "Satus"; it is kept as-is because other
// code refers to it by this name.)
wsCloseSatusSize = 2
// From https://tools.ietf.org/html/rfc6455#section-11.7
wsCloseStatusNormalClosure = 1000
wsCloseStatusNoStatusReceived = 1005
wsCloseStatusAbnormalClosure = 1006
wsCloseStatusInvalidPayloadData = 1007
// URL schemes recognized as websocket (plain and TLS).
wsScheme = "ws"
wsSchemeTLS = "wss"
// Extension name and parameters used to negotiate compression.
wsPMCExtension = "permessage-deflate" // per-message compression
wsPMCSrvNoCtx = "server_no_context_takeover"
wsPMCCliNoCtx = "client_no_context_takeover"
// Value sent in the Sec-WebSocket-Extensions request header when
// compression is requested.
wsPMCReqHeaderValue = wsPMCExtension + "; " + wsPMCSrvNoCtx + "; " + wsPMCCliNoCtx
)
// From https://tools.ietf.org/html/rfc6455#section-1.3
// wsGUID is appended to the client key when computing the expected
// Sec-WebSocket-Accept value (see wsAcceptKey).
var wsGUID = []byte("258EAFA5-E914-47DA-95CA-C5AB0DC85B11")
// compressFinalBlock is appended to the compressed buffers before inflating
// a message (see wsDecompressor.decompress): the 0x00 0x00 0xff 0xff trailer
// followed by a final block so that the flate reader does not report an
// unexpected EOF.
var compressFinalBlock = []byte{0x00, 0x00, 0xff, 0xff, 0x01, 0x00, 0x00, 0xff, 0xff}
// websocketReader decodes websocket frames read from an underlying reader
// and exposes the frame payloads through Read.
type websocketReader struct {
// r is the underlying reader the raw frames are read from.
r io.Reader
// pending holds decoded payloads not yet handed out by Read.
pending [][]byte
// ib holds bytes that were already buffered during the HTTP handshake;
// they are consumed by the first Read before touching r.
ib []byte
// ff is true when the last data frame seen had the final bit set, i.e.
// a new message may start. It starts out true.
ff bool
// fc is true while the message being received is compressed.
fc bool
// nl tells the control-frame helpers that they need to acquire the
// connection lock (set by doneWithConnect).
nl bool
// dc is the decompressor, created on the first compressed frame.
dc *wsDecompressor
// nc is the owning connection, used to enqueue pong/close replies.
nc *Conn
}
// wsDecompressor accumulates the compressed payloads of one message and
// acts as the byte source (Read/ReadByte) for the flate reader that
// inflates them.
type wsDecompressor struct {
// flate is the inflating reader; created on first use and reset afterwards.
flate io.ReadCloser
// bufs is the list of compressed buffers still to be consumed.
bufs [][]byte
// off is the read offset into bufs[0].
off int
}
// websocketWriter wraps an underlying writer and frames (and optionally
// compresses) everything written through Write.
type websocketWriter struct {
// w is the underlying writer that receives the encoded frames.
w io.Writer
// compress indicates that data frames are to be compressed.
compress bool
// compressor is created on the first compressed Write and reset on
// each subsequent one.
compressor *flate.Writer
ctrlFrames [][]byte // pending frames that should be sent at the next Write()
cm []byte // close message that needs to be sent when everything else has been sent
cmDone bool // a close message has been added or sent (never going back to false)
noMoreSend bool // if true, even if there is a Write() call, we should not send anything
}
// Read copies compressed bytes from the queued buffers into dst, consuming
// them, and returns the number of bytes copied. It returns io.EOF when no
// data is left.
//
// Empty (or nil) entries in the buffer list are skipped. The loop is driven
// by the number of remaining buffers rather than by a nil check on the
// current buffer, so that an empty entry cannot be mistaken for "end of
// data" and make Read return (0, nil) forever.
func (d *wsDecompressor) Read(dst []byte) (int, error) {
	if len(dst) == 0 {
		return 0, nil
	}
	copied := 0
	for len(d.bufs) > 0 && copied < len(dst) {
		n := copy(dst[copied:], d.bufs[0][d.off:])
		copied += n
		d.off += n
		// Moves on to the next buffer once the current one is exhausted
		// (which is immediately the case for an empty buffer).
		d.nextBuf()
	}
	if copied == 0 {
		// dst has room but nothing could be copied: all buffers are consumed.
		return 0, io.EOF
	}
	return copied, nil
}
// nextBuf returns the buffer that the next read should consume from.
// As long as the first buffer has unread bytes it is returned as-is.
// Once it is fully consumed, the offset is reset and the list advances
// to the following buffer; when there is none, the list is cleared and
// nil is returned. The buffer list must not be empty when this is called.
func (d *wsDecompressor) nextBuf() []byte {
	head := d.bufs[0]
	if d.off != len(head) {
		// Still some unread data in the current buffer.
		return head
	}
	// Current buffer fully consumed: start at the beginning of the next.
	d.off = 0
	if rest := d.bufs[1:]; len(rest) > 0 {
		d.bufs = rest
		return rest[0]
	}
	// That was the last buffer.
	d.bufs = nil
	return nil
}
// ReadByte returns the next compressed byte from the queued buffers, or
// io.EOF when they are all consumed.
//
// Empty (or nil) buffers at the head of the list are skipped first. Such
// an entry exists when a compressed frame carries no payload, and indexing
// into it would panic.
func (d *wsDecompressor) ReadByte() (byte, error) {
	// nextBuf advances when the offset equals the current buffer length,
	// which for an empty buffer is true at offset 0.
	for len(d.bufs) > 0 && d.off == len(d.bufs[0]) {
		d.nextBuf()
	}
	if len(d.bufs) == 0 {
		return 0, io.EOF
	}
	b := d.bufs[0][d.off]
	d.off++
	d.nextBuf()
	return b, nil
}
// addBuf queues b as the next chunk of compressed data.
//
// Empty buffers are not queued: they contribute nothing to the compressed
// stream, and the consumers (Read, ReadByte, nextBuf) index into the head
// buffer on the assumption that it holds at least one unread byte.
func (d *wsDecompressor) addBuf(b []byte) {
	if len(b) == 0 {
		return
	}
	d.bufs = append(d.bufs, b)
}
// decompress inflates the compressed buffers accumulated so far and returns
// the uncompressed message. The list of compressed buffers is always
// cleared before returning, whether or not an error occurred.
func (d *wsDecompressor) decompress() ([]byte, error) {
	d.off = 0
	// As per https://tools.ietf.org/html/rfc7692#section-7.2.2
	// add 0x00, 0x00, 0xff, 0xff and then a final block so that flate reader
	// does not report unexpected EOF.
	d.bufs = append(d.bufs, compressFinalBlock)
	// Create or reset the decompressor with this object (wsDecompressor)
	// that provides Read() and ReadByte() APIs that will consume from
	// the compressed buffers (d.bufs).
	if d.flate == nil {
		d.flate = flate.NewReader(d)
	} else if err := d.flate.(flate.Resetter).Reset(d, nil); err != nil {
		// Do not try to read from a reader that failed to reset.
		d.bufs = nil
		return nil, err
	}
	b, err := io.ReadAll(d.flate)
	// Now reset the compressed buffers list
	d.bufs = nil
	return b, err
}
// wsNewReader returns a websocketReader that decodes frames read from r.
// The reader starts in the "previous message complete" state, so the first
// data frame is accepted as the start of a new message.
func wsNewReader(r io.Reader) *websocketReader {
	wr := new(websocketReader)
	wr.r = r
	wr.ff = true
	return wr
}
// doneWithConnect records that the connect phase is over.
//
// From now on, reads will be from the readLoop and we will need to
// acquire the connection lock should we have to send/write a control
// message from handleControlFrame. The flag set here is what
// handleControlFrame passes as "needsLock" when enqueuing replies.
//
// Note: this runs under the connection lock.
func (r *websocketReader) doneWithConnect() {
r.nl = true
}
// Read decodes websocket frames and copies their payload into p.
//
// Data comes first from the bytes left over from the handshake (r.ib), then
// from payloads still pending from a previous call, and otherwise from one
// read of the underlying reader. Every frame found in that data is parsed;
// if a frame is only partially present, more is read from the underlying
// reader to complete it. Control frames are handled in place and never
// returned to the caller. Fragments of a compressed message are accumulated
// and inflated once the final fragment arrives.
//
// It returns the number of payload bytes copied into p. This can be 0 with
// a nil error, for instance when only control frames or non-final
// compressed fragments were received. Protocol violations and errors from
// the underlying reader are returned as errors; a received close frame
// surfaces as io.EOF (see handleControlFrame).
func (r *websocketReader) Read(p []byte) (int, error) {
	// Largest value an int can hold on this platform.
	const maxInt = int(^uint(0) >> 1)

	var (
		err error
		buf []byte
	)
	if len(r.ib) > 0 {
		buf = r.ib
		r.ib = nil
	} else {
		if len(r.pending) > 0 {
			return r.drainPending(p), nil
		}
		// Get some data from the underlying reader.
		n, rerr := r.r.Read(p)
		if rerr != nil {
			return 0, rerr
		}
		buf = p[:n]
	}

	// Now parse this and decode frames. We will possibly read more to
	// ensure that we get a full frame.
	var (
		tmpBuf []byte
		pos    int
		end    = len(buf)
	)
	for pos < end {
		b0 := buf[pos]
		frameType := wsOpCode(b0 & 0xF)
		final := b0&wsFinalBit != 0
		compressed := b0&wsRsv1Bit != 0
		pos++

		tmpBuf, pos, err = wsGet(r.r, buf, pos, 1)
		if err != nil {
			return 0, err
		}
		b1 := tmpBuf[0]

		// Store size in case it is < 125
		rem := int(b1 & 0x7F)

		switch frameType {
		case wsPingMessage, wsPongMessage, wsCloseMessage:
			if rem > wsMaxControlPayloadSize {
				return 0, fmt.Errorf("control frame length bigger than maximum allowed of %v bytes",
					wsMaxControlPayloadSize)
			}
			if compressed {
				return 0, errors.New("control frame should not be compressed")
			}
			if !final {
				return 0, errors.New("control frame does not have final bit set")
			}
		case wsTextMessage, wsBinaryMessage:
			if !r.ff {
				return 0, errors.New("new message started before final frame for previous message was received")
			}
			r.ff = final
			r.fc = compressed
		case wsContinuationFrame:
			// Compressed bit must be only set in the first frame
			if r.ff || compressed {
				return 0, errors.New("invalid continuation frame")
			}
			r.ff = final
		default:
			return 0, fmt.Errorf("unknown opcode %v", frameType)
		}

		// If the encoded size is <= 125, then `rem` is simply the remainder size of the
		// frame. If it is 126, then the actual size is encoded as a uint16. For larger
		// frames, `rem` will initially be 127 and the actual size is encoded as a uint64.
		switch rem {
		case 126:
			tmpBuf, pos, err = wsGet(r.r, buf, pos, 2)
			if err != nil {
				return 0, err
			}
			rem = int(binary.BigEndian.Uint16(tmpBuf))
		case 127:
			tmpBuf, pos, err = wsGet(r.r, buf, pos, 8)
			if err != nil {
				return 0, err
			}
			// The most significant bit of a 64-bit length must be 0
			// (https://tools.ietf.org/html/rfc6455#section-5.2). Converting
			// a larger value to int would produce a negative or truncated
			// size, so reject it here. Note that a size that fits in an int
			// is accepted regardless of how large it is.
			size := binary.BigEndian.Uint64(tmpBuf)
			if size > uint64(maxInt) {
				return 0, fmt.Errorf("frame payload length %v is too large", size)
			}
			rem = int(size)
		}

		// Handle control messages in place...
		if wsIsControlFrame(frameType) {
			pos, err = r.handleControlFrame(frameType, buf, pos, rem)
			if err != nil {
				return 0, err
			}
			continue
		}

		var b []byte
		// This ensures that we get the full payload for this frame.
		b, pos, err = wsGet(r.r, buf, pos, rem)
		if err != nil {
			return 0, err
		}
		addToPending := true
		if r.fc {
			// Don't add to pending if we are not dealing with the final frame.
			addToPending = r.ff
			// Add the compressed payload buffer to the list.
			r.addCBuf(b)
			// Decompress only when this is the final frame.
			if r.ff {
				b, err = r.dc.decompress()
				if err != nil {
					return 0, err
				}
				r.fc = false
			}
		}
		// Add to the pending list if dealing with uncompressed frames or
		// after we have received the full compressed message and decompressed it.
		if addToPending {
			r.pending = append(r.pending, b)
		}
	}
	// In case of compression, there may be nothing to drain
	if len(r.pending) > 0 {
		return r.drainPending(p), nil
	}
	return 0, nil
}
// addCBuf queues a private copy of b as the next compressed chunk of the
// message being received, creating the decompressor on first use. A copy is
// taken because b may point into a buffer that is reused by later reads.
func (r *websocketReader) addCBuf(b []byte) {
	dc := r.dc
	if dc == nil {
		dc = &wsDecompressor{}
		r.dc = dc
	}
	clone := append([]byte(nil), b...)
	dc.addBuf(clone)
}
// drainPending copies as much pending payload as fits into p and returns
// the number of bytes copied. Payloads that were copied entirely are
// dropped from the pending list; one that only partially fits is replaced
// by its unread tail and stays at the head of the list for the next call.
func (r *websocketReader) drainPending(p []byte) int {
	written := 0
	for idx, chunk := range r.pending {
		room := len(p) - written
		if len(chunk) > room {
			// This chunk does not fit entirely: hand out what we can and
			// remember the rest.
			if room > 0 {
				copy(p[written:], chunk[:room])
				written += room
				r.pending[idx] = chunk[room:]
			}
			// Everything from this chunk on is for the next Read() call.
			r.pending = r.pending[idx:]
			return written
		}
		written += copy(p[written:], chunk)
	}
	// All pending chunks were consumed.
	r.pending = r.pending[:0]
	return written
}
// wsGet returns `needed` bytes starting at buf[pos].
//
// If buf holds enough bytes from pos, a sub-slice of buf is returned along
// with the position just past it. Otherwise a new slice is allocated, filled
// with what is left in buf and completed by reading from r; the returned
// position is then len(buf), since buf has been fully consumed.
//
// A negative `needed` (for example a size decoded from a malformed frame)
// is rejected with an error instead of panicking on an invalid slice
// expression. If r fails before `needed` bytes are available, the error
// from r is returned together with the partially filled slice.
func wsGet(r io.Reader, buf []byte, pos, needed int) ([]byte, int, error) {
	if needed < 0 {
		return nil, pos, fmt.Errorf("invalid frame size %d", needed)
	}
	avail := len(buf) - pos
	if avail >= needed {
		return buf[pos : pos+needed], pos + needed, nil
	}
	b := make([]byte, needed)
	start := copy(b, buf[pos:])
	for start != needed {
		n, err := r.Read(b[start:cap(b)])
		start += n
		if err != nil {
			return b, start, err
		}
	}
	return b, pos + avail, nil
}
// handleControlFrame processes a ping, pong or close frame whose payload of
// `rem` bytes starts at buf[pos], reading more from the underlying reader if
// the payload is not fully in buf. It returns the position following the
// payload.
//
//   - ping: a pong carrying the same payload is enqueued.
//   - pong: ignored.
//   - close: the status and optional body are decoded, a close message is
//     enqueued in response and io.EOF is returned so that the caller stops
//     reading.
func (r *websocketReader) handleControlFrame(frameType wsOpCode, buf []byte, pos, rem int) (int, error) {
	var payload []byte
	if rem > 0 {
		var err error
		if payload, pos, err = wsGet(r.r, buf, pos, rem); err != nil {
			return pos, err
		}
	}

	if frameType == wsPingMessage {
		r.nc.wsEnqueueControlMsg(r.nl, wsPongMessage, payload)
		return pos, nil
	}
	if frameType != wsCloseMessage {
		// Pong (or anything else): nothing to do.
		return pos, nil
	}

	// Close frame. When there is a payload, it starts with the status as a
	// 2-byte unsigned integer in network byte order, optionally followed
	// by a body.
	status, body := wsCloseStatusNoStatusReceived, ""
	if len(payload) >= wsCloseSatusSize {
		status = int(binary.BigEndian.Uint16(payload[:wsCloseSatusSize]))
		if len(payload) > wsCloseSatusSize {
			body = string(payload[wsCloseSatusSize:])
			// https://tools.ietf.org/html/rfc6455#section-5.5.1
			// If body is present, it must be a valid utf8
			if !utf8.ValidString(body) {
				status = wsCloseStatusInvalidPayloadData
				body = "invalid utf8 body in close frame"
			}
		}
	}
	r.nc.wsEnqueueCloseMsg(r.nl, status, body)
	// io.EOF tells the read loop that the peer closed the connection.
	return pos, io.EOF
}
// Write sends p to the underlying writer as a single binary websocket
// frame, compressing it first when compression is enabled.
//
// Before the data frame, any queued control frames are written. After it,
// the pending close message (if any) is written, which also stops all
// further sends. Write may be called with an empty p just to push out
// control frames and/or the close message.
//
// The returned count is the number of bytes written to the underlying
// writer (frame headers, control frames and close message included), not
// the number of bytes consumed from p. Once a close message has been sent,
// Write does nothing and returns 0 and a nil error.
//
// NOTE(review): when compression is off, the masking is applied directly
// to p, so the caller's buffer content is modified by this call.
func (w *websocketWriter) Write(p []byte) (int, error) {
if w.noMoreSend {
return 0, nil
}
var total int
var n int
var err error
// If there are control frames, they can be sent now. Actually spec says
// that they should be sent ASAP, so we will send before any application data.
if len(w.ctrlFrames) > 0 {
n, err = w.writeCtrlFrames()
if err != nil {
return n, err
}
total += n
}
// Do the following only if there is something to send.
// We will end with checking for need to send close message.
if len(p) > 0 {
if w.compress {
// Compress into a fresh buffer; the compressor itself is reused.
buf := &bytes.Buffer{}
if w.compressor == nil {
w.compressor, _ = flate.NewWriter(buf, flate.BestSpeed)
} else {
w.compressor.Reset(buf)
}
if n, err = w.compressor.Write(p); err != nil {
return n, err
}
if err = w.compressor.Flush(); err != nil {
return n, err
}
b := buf.Bytes()
// Drop the last 4 bytes of the flushed output; the reading side
// appends them back before inflating (see compressFinalBlock).
p = b[:len(b)-4]
}
// Build the header for the (possibly compressed) payload and mask the
// payload in place with the key stored in that header.
fh, key := wsCreateFrameHeader(w.compress, wsBinaryMessage, len(p))
wsMaskBuf(key, p)
n, err = w.w.Write(fh)
total += n
if err == nil {
n, err = w.w.Write(p)
total += n
}
}
// The close message goes last, and only if everything before succeeded.
if err == nil && w.cm != nil {
n, err = w.writeCloseMsg()
total += n
}
return total, err
}
// writeCtrlFrames writes the queued control frames, in order, to the
// underlying writer and returns the total number of bytes written.
// On success the queue is emptied. If a write fails, the frames written so
// far and the one that failed are removed from the queue, and the error is
// returned.
func (w *websocketWriter) writeCtrlFrames() (int, error) {
	written := 0
	for idx, frame := range w.ctrlFrames {
		n, err := w.w.Write(frame)
		written += n
		if err != nil {
			w.ctrlFrames = w.ctrlFrames[idx+1:]
			return written, err
		}
	}
	w.ctrlFrames = w.ctrlFrames[:0]
	return written, nil
}
// writeCloseMsg writes the pending close message to the underlying writer.
// Whether or not the write succeeds, the message is then discarded and the
// writer is marked as not sending anything anymore.
func (w *websocketWriter) writeCloseMsg() (int, error) {
	written, err := w.w.Write(w.cm)
	w.noMoreSend = true
	w.cm = nil
	return written, err
}
// wsMaskBuf XORs buf in place with the 4-byte masking key, cycling through
// the key bytes. Applying it twice with the same key restores the original
// content.
func wsMaskBuf(key, buf []byte) {
	for idx := range buf {
		buf[idx] ^= key[idx%4]
	}
}
// wsCreateFrameHeader allocates and fills the header of a final, masked
// frame of the given type carrying a payload of l bytes, with the
// compression flag set when `compressed` is true.
// It returns the header (trimmed to its actual size) and the masking key,
// which is a slice of that same header.
func wsCreateFrameHeader(compressed bool, frameType wsOpCode, l int) ([]byte, []byte) {
	var hdr [wsMaxFrameHeaderSize]byte
	size, maskKey := wsFillFrameHeader(hdr[:], compressed, frameType, l)
	return hdr[:size], maskKey
}
// wsFillFrameHeader writes into fh the header of a final, masked frame of
// the given type whose payload is l bytes long; the compression flag is set
// when `compressed` is true. fh must be large enough for the header
// (wsMaxFrameHeaderSize always is).
//
// It returns the header size (masking key included) and the masking key as
// a slice of fh. The key is random; should the cryptographic source fail,
// a pseudo-random key is used instead.
func wsFillFrameHeader(fh []byte, compressed bool, frameType wsOpCode, l int) (int, []byte) {
	first := byte(frameType) | wsFinalBit
	if compressed {
		first |= wsRsv1Bit
	}
	fh[0] = first

	// The payload length is encoded on 7 bits, or as 126 followed by 16 bits,
	// or as 127 followed by 64 bits.
	size := 2
	switch {
	case l <= 125:
		fh[1] = wsMaskBit | byte(l)
	case l < 65536:
		fh[1] = wsMaskBit | 126
		binary.BigEndian.PutUint16(fh[2:], uint16(l))
		size = 4
	default:
		fh[1] = wsMaskBit | 127
		binary.BigEndian.PutUint64(fh[2:], uint64(l))
		size = 10
	}

	var mask [4]byte
	if _, err := io.ReadFull(rand.Reader, mask[:]); err != nil {
		binary.LittleEndian.PutUint32(mask[:], uint32(mrand.Int31()))
	}
	// The masking key directly follows the length.
	copy(fh[size:], mask[:])
	return size + 4, fh[size : size+4]
}
// wsInitHandshake performs the websocket opening handshake with the server
// at u over the connection's socket and, on success, switches the
// connection's buffered reader and writer to websocket framing.
//
// TLS is set up first when the scheme is "wss" or any TLS related option is
// set. The HTTP upgrade request is always sent to the host of u; the request
// path is the ProxyPath option when set. Compression is requested when the
// Compression option is set and silently disabled if the server does not
// offer it.
//
// An error is returned if the request cannot be written, if the response
// cannot be read within the Timeout option, if the response is not a valid
// websocket upgrade, or if the server accepts compression without both
// "no context takeover" parameters.
func (nc *Conn) wsInitHandshake(u *url.URL) error {
compress := nc.Opts.Compression
tlsRequired := u.Scheme == wsSchemeTLS || nc.Opts.Secure || nc.Opts.TLSConfig != nil || nc.Opts.TLSCertCB != nil || nc.Opts.RootCAsCB != nil
// Do TLS here as needed.
if tlsRequired {
if err := nc.makeTLSConn(); err != nil {
return err
}
} else {
nc.bindToNewConn()
}
var err error
// For http request, we need the passed URL to contain either http or https scheme.
scheme := "http"
if tlsRequired {
scheme = "https"
}
ustr := fmt.Sprintf("%s://%s", scheme, u.Host)
if nc.Opts.ProxyPath != "" {
proxyPath := nc.Opts.ProxyPath
// Make sure the path is separated from the host.
if !strings.HasPrefix(proxyPath, "/") {
proxyPath = "/" + proxyPath
}
ustr += proxyPath
}
u, err = url.Parse(ustr)
if err != nil {
return err
}
req := &http.Request{
Method: "GET",
URL: u,
Proto: "HTTP/1.1",
ProtoMajor: 1,
ProtoMinor: 1,
Header: make(http.Header),
Host: u.Host,
}
// Random key that the server must echo back, transformed, in its
// Sec-WebSocket-Accept header (verified below with wsAcceptKey).
wsKey, err := wsMakeChallengeKey()
if err != nil {
return err
}
req.Header["Upgrade"] = []string{"websocket"}
req.Header["Connection"] = []string{"Upgrade"}
req.Header["Sec-WebSocket-Key"] = []string{wsKey}
req.Header["Sec-WebSocket-Version"] = []string{"13"}
if compress {
req.Header.Add("Sec-WebSocket-Extensions", wsPMCReqHeaderValue)
}
if err := req.Write(nc.conn); err != nil {
return err
}
var resp *http.Response
// The response is read through a local buffered reader, bounded by a
// read deadline derived from the Timeout option.
br := bufio.NewReaderSize(nc.conn, 4096)
nc.conn.SetReadDeadline(time.Now().Add(nc.Opts.Timeout))
resp, err = http.ReadResponse(br, req)
// The server must switch protocols (101), confirm the upgrade and return
// the accept key matching our challenge.
if err == nil &&
(resp.StatusCode != 101 ||
!strings.EqualFold(resp.Header.Get("Upgrade"), "websocket") ||
!strings.EqualFold(resp.Header.Get("Connection"), "upgrade") ||
resp.Header.Get("Sec-Websocket-Accept") != wsAcceptKey(wsKey)) {
err = fmt.Errorf("invalid websocket connection")
}
// Check compression extension...
if err == nil && compress {
// Check that not only permessage-deflate extension is present, but that
// we also have server and client no context take over.
srvCompress, noCtxTakeover := wsPMCExtensionSupport(resp.Header)
// If server does not support compression, then simply disable it in our side.
if !srvCompress {
compress = false
} else if !noCtxTakeover {
err = fmt.Errorf("compression negotiation error")
}
}
if resp != nil {
resp.Body.Close()
}
// Clear the read deadline set for the handshake.
nc.conn.SetReadDeadline(time.Time{})
if err != nil {
return err
}
// From here on, everything read from or written to the connection goes
// through the websocket reader and writer, which wrap the existing ones.
wsr := wsNewReader(nc.br.r)
wsr.nc = nc
// We have to slurp whatever is in the bufio reader and copy to br.r
if n := br.Buffered(); n != 0 {
wsr.ib, _ = br.Peek(n)
}
nc.br.r = wsr
nc.bw.w = &websocketWriter{w: nc.bw.w, compress: compress}
nc.ws = true
return nil
}
// wsClose enqueues a close message with the "normal closure" status and no
// body. It does nothing if the connection is not a websocket connection.
// The connection lock is acquired for the duration of the call.
func (nc *Conn) wsClose() {
	nc.mu.Lock()
	defer nc.mu.Unlock()
	if nc.ws {
		nc.wsEnqueueCloseMsgLocked(wsCloseStatusNormalClosure, _EMPTY_)
	}
}
// wsEnqueueCloseMsg enqueues a close message with the given status and
// body, acquiring the connection lock first when needsLock is true.
// It is a no-op on a nil connection.
func (nc *Conn) wsEnqueueCloseMsg(needsLock bool, status int, payload string) {
	// In some low-level unit tests it will happen...
	if nc == nil {
		return
	}
	if !needsLock {
		nc.wsEnqueueCloseMsgLocked(status, payload)
		return
	}
	nc.mu.Lock()
	defer nc.mu.Unlock()
	nc.wsEnqueueCloseMsgLocked(status, payload)
}
// wsEnqueueCloseMsgLocked builds a masked close frame carrying the given
// status and body, stores it as the writer's close message and flushes the
// connection's buffered writer. The compressor, if any, is closed
// afterwards. Only the first call has an effect; it is also a no-op when
// the connection's writer is not a websocket writer.
//
// The body is truncated, on a rune boundary, so that status and body fit in
// the maximum control frame payload. A longer body would also not fit the
// frame allocated below, which is sized for the short length encoding only.
//
// The "Locked" suffix indicates that callers hold the connection lock.
func (nc *Conn) wsEnqueueCloseMsgLocked(status int, payload string) {
	wr, ok := nc.bw.w.(*websocketWriter)
	if !ok || wr.cmDone {
		return
	}
	const maxBody = wsMaxControlPayloadSize - wsCloseSatusSize
	if len(payload) > maxBody {
		cut := maxBody
		// Back up to the start of a rune so that a multi-byte character is
		// not split in two.
		for cut > 0 && !utf8.RuneStart(payload[cut]) {
			cut--
		}
		payload = payload[:cut]
	}
	statusAndPayloadLen := wsCloseSatusSize + len(payload)
	// 2 bytes of header + 4 bytes of masking key + status and body.
	frame := make([]byte, 2+4+statusAndPayloadLen)
	n, key := wsFillFrameHeader(frame, false, wsCloseMessage, statusAndPayloadLen)
	// Set the status
	binary.BigEndian.PutUint16(frame[n:], uint16(status))
	// If there is a payload, copy
	if len(payload) > 0 {
		copy(frame[n+2:], payload)
	}
	// Mask status + payload
	wsMaskBuf(key, frame[n:n+statusAndPayloadLen])
	wr.cm = frame
	wr.cmDone = true
	nc.bw.flush()
	if c := wr.compressor; c != nil {
		c.Close()
	}
}
// wsEnqueueControlMsg queues a control frame of the given type carrying
// payload, then flushes the connection's buffered writer. The connection
// lock is acquired first when needsLock is true. It is a no-op on a nil
// connection or when the connection's writer is not a websocket writer.
//
// Header and masked payload are assembled in one newly allocated buffer.
// This leaves the caller's payload slice untouched (it may point into a
// read buffer that gets reused) and makes the frame a single queue entry,
// so that a failed write cannot leave a header queued without its payload
// or the other way around.
func (nc *Conn) wsEnqueueControlMsg(needsLock bool, frameType wsOpCode, payload []byte) {
	// In some low-level unit tests it will happen...
	if nc == nil {
		return
	}
	if needsLock {
		nc.mu.Lock()
		defer nc.mu.Unlock()
	}
	wr, ok := nc.bw.w.(*websocketWriter)
	if !ok {
		return
	}
	frame := make([]byte, wsMaxFrameHeaderSize+len(payload))
	n, key := wsFillFrameHeader(frame, false, frameType, len(payload))
	end := n + copy(frame[n:], payload)
	wsMaskBuf(key, frame[n:end])
	wr.ctrlFrames = append(wr.ctrlFrames, frame[:end])
	nc.bw.flush()
}
// wsPMCExtensionSupport inspects the Sec-Websocket-Extensions values of a
// handshake response header.
//
// The first result reports whether the permessage-deflate extension is
// present. The second reports whether both the server and client
// "no context takeover" parameters follow it. Only the first occurrence of
// the extension is considered. Names are matched case-insensitively and
// surrounding spaces and tabs are ignored.
func wsPMCExtensionSupport(header http.Header) (bool, bool) {
	const blanks = " \t"
	for _, value := range header["Sec-Websocket-Extensions"] {
		for _, ext := range strings.Split(value, ",") {
			params := strings.Split(strings.Trim(ext, blanks), ";")
			for idx := range params {
				if !strings.EqualFold(strings.Trim(params[idx], blanks), wsPMCExtension) {
					continue
				}
				// Found the extension: look at the parameters that follow.
				var srvNoCtx, cliNoCtx bool
				for _, opt := range params[idx+1:] {
					opt = strings.Trim(opt, blanks)
					switch {
					case strings.EqualFold(opt, wsPMCSrvNoCtx):
						srvNoCtx = true
					case strings.EqualFold(opt, wsPMCCliNoCtx):
						cliNoCtx = true
					}
				}
				return true, srvNoCtx && cliNoCtx
			}
		}
	}
	return false, false
}
// wsMakeChallengeKey returns a new handshake challenge key: 16 bytes from
// the cryptographic random source, base64 encoded. An error is returned if
// the random source fails.
func wsMakeChallengeKey() (string, error) {
	var nonce [16]byte
	if _, err := io.ReadFull(rand.Reader, nonce[:]); err != nil {
		return "", err
	}
	return base64.StdEncoding.EncodeToString(nonce[:]), nil
}
// wsAcceptKey returns the Sec-WebSocket-Accept value a server is expected
// to send back for the given challenge key: the base64 encoding of the
// SHA-1 digest of the key followed by the websocket GUID.
func wsAcceptKey(key string) string {
	material := make([]byte, 0, len(key)+len(wsGUID))
	material = append(material, key...)
	material = append(material, wsGUID...)
	digest := sha1.Sum(material)
	return base64.StdEncoding.EncodeToString(digest[:])
}
// wsIsControlFrame returns true if the op code corresponds to a control
// frame, that is, if it is greater than or equal to the close opcode (8).
// Among the opcodes defined in this file, those are close, ping and pong.
func wsIsControlFrame(frameType wsOpCode) bool {
return frameType >= wsCloseMessage
}
// isWebsocketScheme reports whether the scheme of u is one of the websocket
// schemes, plain ("ws") or TLS ("wss"). The comparison is case-sensitive.
func isWebsocketScheme(u *url.URL) bool {
	switch u.Scheme {
	case wsScheme, wsSchemeTLS:
		return true
	}
	return false
}