// Copyright 2021 The NATS Authors // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. // You may obtain a copy of the License at // // http://www.apache.org/licenses/LICENSE-2.0 // // Unless required by applicable law or agreed to in writing, software // distributed under the License is distributed on an "AS IS" BASIS, // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. // See the License for the specific language governing permissions and // limitations under the License. package nats import ( "bufio" "bytes" "compress/flate" "crypto/rand" "crypto/sha1" "encoding/base64" "encoding/binary" "errors" "fmt" "io" "io/ioutil" mrand "math/rand" "net/http" "net/url" "strings" "time" "unicode/utf8" ) type wsOpCode int const ( // From https://tools.ietf.org/html/rfc6455#section-5.2 wsTextMessage = wsOpCode(1) wsBinaryMessage = wsOpCode(2) wsCloseMessage = wsOpCode(8) wsPingMessage = wsOpCode(9) wsPongMessage = wsOpCode(10) wsFinalBit = 1 << 7 wsRsv1Bit = 1 << 6 // Used for compression, from https://tools.ietf.org/html/rfc7692#section-6 wsRsv2Bit = 1 << 5 wsRsv3Bit = 1 << 4 wsMaskBit = 1 << 7 wsContinuationFrame = 0 wsMaxFrameHeaderSize = 14 wsMaxControlPayloadSize = 125 wsCloseSatusSize = 2 // From https://tools.ietf.org/html/rfc6455#section-11.7 wsCloseStatusNormalClosure = 1000 wsCloseStatusNoStatusReceived = 1005 wsCloseStatusAbnormalClosure = 1006 wsCloseStatusInvalidPayloadData = 1007 wsScheme = "ws" wsSchemeTLS = "wss" wsPMCExtension = "permessage-deflate" // per-message compression wsPMCSrvNoCtx = "server_no_context_takeover" wsPMCCliNoCtx = "client_no_context_takeover" wsPMCReqHeaderValue = wsPMCExtension + "; " + wsPMCSrvNoCtx + "; " + wsPMCCliNoCtx ) // From https://tools.ietf.org/html/rfc6455#section-1.3 var wsGUID = []byte("258EAFA5-E914-47DA-95CA-C5AB0DC85B11") var compressFinalBlock = []byte{0x00, 0x00, 0xff, 0xff, 0x01, 0x00, 0x00, 0xff, 0xff} type websocketReader struct { r io.Reader pending [][]byte ib []byte ff bool fc bool nl bool dc *wsDecompressor nc *Conn } type wsDecompressor struct { flate io.ReadCloser bufs [][]byte off int } type websocketWriter struct { w io.Writer compress bool compressor *flate.Writer ctrlFrames [][]byte // pending frames that should be sent at the next Write() cm []byte // close message that needs to be sent when everything else has been sent cmDone bool // a close message has been added or sent (never going back to false) noMoreSend bool // if true, even if there is a Write() call, we should not send anything } func (d *wsDecompressor) Read(dst []byte) (int, error) { if len(dst) == 0 { return 0, nil } if len(d.bufs) == 0 { return 0, io.EOF } copied := 0 rem := len(dst) for buf := d.bufs[0]; buf != nil && rem > 0; { n := len(buf[d.off:]) if n > rem { n = rem } copy(dst[copied:], buf[d.off:d.off+n]) copied += n rem -= n d.off += n buf = d.nextBuf() } return copied, nil } func (d *wsDecompressor) nextBuf() []byte { // We still have remaining data in the first buffer if d.off != len(d.bufs[0]) { return d.bufs[0] } // We read the full first buffer. Reset offset. d.off = 0 // We were at the last buffer, so we are done. if len(d.bufs) == 1 { d.bufs = nil return nil } // Here we move to the next buffer. d.bufs = d.bufs[1:] return d.bufs[0] } func (d *wsDecompressor) ReadByte() (byte, error) { if len(d.bufs) == 0 { return 0, io.EOF } b := d.bufs[0][d.off] d.off++ d.nextBuf() return b, nil } func (d *wsDecompressor) addBuf(b []byte) { d.bufs = append(d.bufs, b) } func (d *wsDecompressor) decompress() ([]byte, error) { d.off = 0 // As per https://tools.ietf.org/html/rfc7692#section-7.2.2 // add 0x00, 0x00, 0xff, 0xff and then a final block so that flate reader // does not report unexpected EOF. d.bufs = append(d.bufs, compressFinalBlock) // Create or reset the decompressor with his object (wsDecompressor) // that provides Read() and ReadByte() APIs that will consume from // the compressed buffers (d.bufs). if d.flate == nil { d.flate = flate.NewReader(d) } else { d.flate.(flate.Resetter).Reset(d, nil) } // TODO: When Go 1.15 support is dropped, replace with io.ReadAll() b, err := ioutil.ReadAll(d.flate) // Now reset the compressed buffers list d.bufs = nil return b, err } func wsNewReader(r io.Reader) *websocketReader { return &websocketReader{r: r, ff: true} } // From now on, reads will be from the readLoop and we will need to // acquire the connection lock should we have to send/write a control // message from handleControlFrame. // // Note: this runs under the connection lock. func (r *websocketReader) doneWithConnect() { r.nl = true } func (r *websocketReader) Read(p []byte) (int, error) { var err error var buf []byte if l := len(r.ib); l > 0 { buf = r.ib r.ib = nil } else { if len(r.pending) > 0 { return r.drainPending(p), nil } // Get some data from the underlying reader. n, err := r.r.Read(p) if err != nil { return 0, err } buf = p[:n] } // Now parse this and decode frames. We will possibly read more to // ensure that we get a full frame. var ( tmpBuf []byte pos int max = len(buf) rem = 0 ) for pos < max { b0 := buf[pos] frameType := wsOpCode(b0 & 0xF) final := b0&wsFinalBit != 0 compressed := b0&wsRsv1Bit != 0 pos++ tmpBuf, pos, err = wsGet(r.r, buf, pos, 1) if err != nil { return 0, err } b1 := tmpBuf[0] // Store size in case it is < 125 rem = int(b1 & 0x7F) switch frameType { case wsPingMessage, wsPongMessage, wsCloseMessage: if rem > wsMaxControlPayloadSize { return 0, fmt.Errorf( fmt.Sprintf("control frame length bigger than maximum allowed of %v bytes", wsMaxControlPayloadSize)) } if compressed { return 0, errors.New("control frame should not be compressed") } if !final { return 0, errors.New("control frame does not have final bit set") } case wsTextMessage, wsBinaryMessage: if !r.ff { return 0, errors.New("new message started before final frame for previous message was received") } r.ff = final r.fc = compressed case wsContinuationFrame: // Compressed bit must be only set in the first frame if r.ff || compressed { return 0, errors.New("invalid continuation frame") } r.ff = final default: return 0, fmt.Errorf("unknown opcode %v", frameType) } // If the encoded size is <= 125, then `rem` is simply the remainder size of the // frame. If it is 126, then the actual size is encoded as a uint16. For larger // frames, `rem` will initially be 127 and the actual size is encoded as a uint64. switch rem { case 126: tmpBuf, pos, err = wsGet(r.r, buf, pos, 2) if err != nil { return 0, err } rem = int(binary.BigEndian.Uint16(tmpBuf)) case 127: tmpBuf, pos, err = wsGet(r.r, buf, pos, 8) if err != nil { return 0, err } rem = int(binary.BigEndian.Uint64(tmpBuf)) } // Handle control messages in place... if wsIsControlFrame(frameType) { pos, err = r.handleControlFrame(frameType, buf, pos, rem) if err != nil { return 0, err } rem = 0 continue } var b []byte // This ensures that we get the full payload for this frame. b, pos, err = wsGet(r.r, buf, pos, rem) if err != nil { return 0, err } // We read the full frame. rem = 0 addToPending := true if r.fc { // Don't add to pending if we are not dealing with the final frame. addToPending = r.ff // Add the compressed payload buffer to the list. r.addCBuf(b) // Decompress only when this is the final frame. if r.ff { b, err = r.dc.decompress() if err != nil { return 0, err } r.fc = false } } // Add to the pending list if dealing with uncompressed frames or // after we have received the full compressed message and decompressed it. if addToPending { r.pending = append(r.pending, b) } } // In case of compression, there may be nothing to drain if len(r.pending) > 0 { return r.drainPending(p), nil } return 0, nil } func (r *websocketReader) addCBuf(b []byte) { if r.dc == nil { r.dc = &wsDecompressor{} } // Add a copy of the incoming buffer to the list of compressed buffers. r.dc.addBuf(append([]byte(nil), b...)) } func (r *websocketReader) drainPending(p []byte) int { var n int var max = len(p) for i, buf := range r.pending { if n+len(buf) <= max { copy(p[n:], buf) n += len(buf) } else { // Is there room left? if n < max { // Write the partial and update this slice. rem := max - n copy(p[n:], buf[:rem]) n += rem r.pending[i] = buf[rem:] } // These are the remaining slices that will need to be used at // the next Read() call. r.pending = r.pending[i:] return n } } r.pending = r.pending[:0] return n } func wsGet(r io.Reader, buf []byte, pos, needed int) ([]byte, int, error) { avail := len(buf) - pos if avail >= needed { return buf[pos : pos+needed], pos + needed, nil } b := make([]byte, needed) start := copy(b, buf[pos:]) for start != needed { n, err := r.Read(b[start:cap(b)]) start += n if err != nil { return b, start, err } } return b, pos + avail, nil } func (r *websocketReader) handleControlFrame(frameType wsOpCode, buf []byte, pos, rem int) (int, error) { var payload []byte var err error if rem > 0 { payload, pos, err = wsGet(r.r, buf, pos, rem) if err != nil { return pos, err } } switch frameType { case wsCloseMessage: status := wsCloseStatusNoStatusReceived var body string lp := len(payload) // If there is a payload, the status is represented as a 2-byte // unsigned integer (in network byte order). Then, there may be an // optional body. hasStatus, hasBody := lp >= wsCloseSatusSize, lp > wsCloseSatusSize if hasStatus { // Decode the status status = int(binary.BigEndian.Uint16(payload[:wsCloseSatusSize])) // Now if there is a body, capture it and make sure this is a valid UTF-8. if hasBody { body = string(payload[wsCloseSatusSize:]) if !utf8.ValidString(body) { // https://tools.ietf.org/html/rfc6455#section-5.5.1 // If body is present, it must be a valid utf8 status = wsCloseStatusInvalidPayloadData body = "invalid utf8 body in close frame" } } } r.nc.wsEnqueueCloseMsg(r.nl, status, body) // Return io.EOF so that readLoop will close the connection as client closed // after processing pending buffers. return pos, io.EOF case wsPingMessage: r.nc.wsEnqueueControlMsg(r.nl, wsPongMessage, payload) case wsPongMessage: // Nothing to do.. } return pos, nil } func (w *websocketWriter) Write(p []byte) (int, error) { if w.noMoreSend { return 0, nil } var total int var n int var err error // If there are control frames, they can be sent now. Actually spec says // that they should be sent ASAP, so we will send before any application data. if len(w.ctrlFrames) > 0 { n, err = w.writeCtrlFrames() if err != nil { return n, err } total += n } // Do the following only if there is something to send. // We will end with checking for need to send close message. if len(p) > 0 { if w.compress { buf := &bytes.Buffer{} if w.compressor == nil { w.compressor, _ = flate.NewWriter(buf, flate.BestSpeed) } else { w.compressor.Reset(buf) } w.compressor.Write(p) w.compressor.Close() b := buf.Bytes() p = b[:len(b)-4] } fh, key := wsCreateFrameHeader(w.compress, wsBinaryMessage, len(p)) wsMaskBuf(key, p) n, err = w.w.Write(fh) total += n if err == nil { n, err = w.w.Write(p) total += n } } if err == nil && w.cm != nil { n, err = w.writeCloseMsg() total += n } return total, err } func (w *websocketWriter) writeCtrlFrames() (int, error) { var ( n int total int i int err error ) for ; i < len(w.ctrlFrames); i++ { buf := w.ctrlFrames[i] n, err = w.w.Write(buf) total += n if err != nil { break } } if i != len(w.ctrlFrames) { w.ctrlFrames = w.ctrlFrames[i+1:] } else { w.ctrlFrames = w.ctrlFrames[:0] } return total, err } func (w *websocketWriter) writeCloseMsg() (int, error) { n, err := w.w.Write(w.cm) w.cm, w.noMoreSend = nil, true return n, err } func wsMaskBuf(key, buf []byte) { for i := 0; i < len(buf); i++ { buf[i] ^= key[i&3] } } // Create the frame header. // Encodes the frame type and optional compression flag, and the size of the payload. func wsCreateFrameHeader(compressed bool, frameType wsOpCode, l int) ([]byte, []byte) { fh := make([]byte, wsMaxFrameHeaderSize) n, key := wsFillFrameHeader(fh, compressed, frameType, l) return fh[:n], key } func wsFillFrameHeader(fh []byte, compressed bool, frameType wsOpCode, l int) (int, []byte) { var n int b := byte(frameType) b |= wsFinalBit if compressed { b |= wsRsv1Bit } b1 := byte(wsMaskBit) switch { case l <= 125: n = 2 fh[0] = b fh[1] = b1 | byte(l) case l < 65536: n = 4 fh[0] = b fh[1] = b1 | 126 binary.BigEndian.PutUint16(fh[2:], uint16(l)) default: n = 10 fh[0] = b fh[1] = b1 | 127 binary.BigEndian.PutUint64(fh[2:], uint64(l)) } var key []byte var keyBuf [4]byte if _, err := io.ReadFull(rand.Reader, keyBuf[:4]); err != nil { kv := mrand.Int31() binary.LittleEndian.PutUint32(keyBuf[:4], uint32(kv)) } copy(fh[n:], keyBuf[:4]) key = fh[n : n+4] n += 4 return n, key } func (nc *Conn) wsInitHandshake(u *url.URL) error { compress := nc.Opts.Compression tlsRequired := u.Scheme == wsSchemeTLS || nc.Opts.Secure || nc.Opts.TLSConfig != nil // Do TLS here as needed. if tlsRequired { if err := nc.makeTLSConn(); err != nil { return err } } else { nc.bindToNewConn() } var err error // For http request, we need the passed URL to contain either http or https scheme. scheme := "http" if tlsRequired { scheme = "https" } ustr := fmt.Sprintf("%s://%s", scheme, u.Host) if nc.Opts.ProxyPath != "" { proxyPath := nc.Opts.ProxyPath if !strings.HasPrefix(proxyPath, "/") { proxyPath = "/" + proxyPath } ustr += proxyPath } u, err = url.Parse(ustr) if err != nil { return err } req := &http.Request{ Method: "GET", URL: u, Proto: "HTTP/1.1", ProtoMajor: 1, ProtoMinor: 1, Header: make(http.Header), Host: u.Host, } wsKey, err := wsMakeChallengeKey() if err != nil { return err } req.Header["Upgrade"] = []string{"websocket"} req.Header["Connection"] = []string{"Upgrade"} req.Header["Sec-WebSocket-Key"] = []string{wsKey} req.Header["Sec-WebSocket-Version"] = []string{"13"} if compress { req.Header.Add("Sec-WebSocket-Extensions", wsPMCReqHeaderValue) } if err := req.Write(nc.conn); err != nil { return err } var resp *http.Response br := bufio.NewReaderSize(nc.conn, 4096) nc.conn.SetReadDeadline(time.Now().Add(nc.Opts.Timeout)) resp, err = http.ReadResponse(br, req) if err == nil && (resp.StatusCode != 101 || !strings.EqualFold(resp.Header.Get("Upgrade"), "websocket") || !strings.EqualFold(resp.Header.Get("Connection"), "upgrade") || resp.Header.Get("Sec-Websocket-Accept") != wsAcceptKey(wsKey)) { err = fmt.Errorf("invalid websocket connection") } // Check compression extension... if err == nil && compress { // Check that not only permessage-deflate extension is present, but that // we also have server and client no context take over. srvCompress, noCtxTakeover := wsPMCExtensionSupport(resp.Header) // If server does not support compression, then simply disable it in our side. if !srvCompress { compress = false } else if !noCtxTakeover { err = fmt.Errorf("compression negotiation error") } } if resp != nil { resp.Body.Close() } nc.conn.SetReadDeadline(time.Time{}) if err != nil { return err } wsr := wsNewReader(nc.br.r) wsr.nc = nc // We have to slurp whatever is in the bufio reader and copy to br.r if n := br.Buffered(); n != 0 { wsr.ib, _ = br.Peek(n) } nc.br.r = wsr nc.bw.w = &websocketWriter{w: nc.bw.w, compress: compress} nc.ws = true return nil } func (nc *Conn) wsClose() { nc.mu.Lock() defer nc.mu.Unlock() if !nc.ws { return } nc.wsEnqueueCloseMsgLocked(wsCloseStatusNormalClosure, _EMPTY_) } func (nc *Conn) wsEnqueueCloseMsg(needsLock bool, status int, payload string) { // In some low-level unit tests it will happen... if nc == nil { return } if needsLock { nc.mu.Lock() defer nc.mu.Unlock() } nc.wsEnqueueCloseMsgLocked(status, payload) } func (nc *Conn) wsEnqueueCloseMsgLocked(status int, payload string) { wr, ok := nc.bw.w.(*websocketWriter) if !ok || wr.cmDone { return } statusAndPayloadLen := 2 + len(payload) frame := make([]byte, 2+4+statusAndPayloadLen) n, key := wsFillFrameHeader(frame, false, wsCloseMessage, statusAndPayloadLen) // Set the status binary.BigEndian.PutUint16(frame[n:], uint16(status)) // If there is a payload, copy if len(payload) > 0 { copy(frame[n+2:], payload) } // Mask status + payload wsMaskBuf(key, frame[n:n+statusAndPayloadLen]) wr.cm = frame wr.cmDone = true nc.bw.flush() } func (nc *Conn) wsEnqueueControlMsg(needsLock bool, frameType wsOpCode, payload []byte) { // In some low-level unit tests it will happen... if nc == nil { return } if needsLock { nc.mu.Lock() defer nc.mu.Unlock() } wr, ok := nc.bw.w.(*websocketWriter) if !ok { return } fh, key := wsCreateFrameHeader(false, frameType, len(payload)) wr.ctrlFrames = append(wr.ctrlFrames, fh) if len(payload) > 0 { wsMaskBuf(key, payload) wr.ctrlFrames = append(wr.ctrlFrames, payload) } nc.bw.flush() } func wsPMCExtensionSupport(header http.Header) (bool, bool) { for _, extensionList := range header["Sec-Websocket-Extensions"] { extensions := strings.Split(extensionList, ",") for _, extension := range extensions { extension = strings.Trim(extension, " \t") params := strings.Split(extension, ";") for i, p := range params { p = strings.Trim(p, " \t") if strings.EqualFold(p, wsPMCExtension) { var snc bool var cnc bool for j := i + 1; j < len(params); j++ { p = params[j] p = strings.Trim(p, " \t") if strings.EqualFold(p, wsPMCSrvNoCtx) { snc = true } else if strings.EqualFold(p, wsPMCCliNoCtx) { cnc = true } if snc && cnc { return true, true } } return true, false } } } } return false, false } func wsMakeChallengeKey() (string, error) { p := make([]byte, 16) if _, err := io.ReadFull(rand.Reader, p); err != nil { return "", err } return base64.StdEncoding.EncodeToString(p), nil } func wsAcceptKey(key string) string { h := sha1.New() h.Write([]byte(key)) h.Write(wsGUID) return base64.StdEncoding.EncodeToString(h.Sum(nil)) } // Returns true if the op code corresponds to a control frame. func wsIsControlFrame(frameType wsOpCode) bool { return frameType >= wsCloseMessage } func isWebsocketScheme(u *url.URL) bool { return u.Scheme == wsScheme || u.Scheme == wsSchemeTLS }