vendor/github.com/gorilla/websocket/conn.go
changeset 251 1c52a0eeb952
parent 242 2a9ec03fe5a1
child 256 6d9efbef00a9
--- a/vendor/github.com/gorilla/websocket/conn.go	Wed Sep 18 19:17:42 2019 +0200
+++ b/vendor/github.com/gorilla/websocket/conn.go	Sun Feb 16 18:54:01 2020 +0100
@@ -260,10 +260,12 @@
 	newCompressionWriter   func(io.WriteCloser, int) io.WriteCloser
 
 	// Read fields
-	reader        io.ReadCloser // the current reader returned to the application
-	readErr       error
-	br            *bufio.Reader
-	readRemaining int64 // bytes remaining in current frame.
+	reader  io.ReadCloser // the current reader returned to the application
+	readErr error
+	br      *bufio.Reader
+	// bytes remaining in current frame.
+	// set setReadRemaining to safely update this value and prevent overflow
+	readRemaining int64
 	readFinal     bool  // true the current message has more frames.
 	readLength    int64 // Message size.
 	readLimit     int64 // Maximum message size.
@@ -320,6 +322,17 @@
 	return c
 }
 
+// setReadRemaining tracks the number of bytes remaining on the connection. If n
+// overflows, an ErrReadLimit is returned.
+func (c *Conn) setReadRemaining(n int64) error {
+	if n < 0 {
+		return ErrReadLimit
+	}
+
+	c.readRemaining = n
+	return nil
+}
+
 // Subprotocol returns the negotiated protocol for the connection.
 func (c *Conn) Subprotocol() string {
 	return c.subprotocol
@@ -451,7 +464,8 @@
 	return err
 }
 
-func (c *Conn) prepWrite(messageType int) error {
+// beginMessage prepares a connection and message writer for a new message.
+func (c *Conn) beginMessage(mw *messageWriter, messageType int) error {
 	// Close previous writer if not already closed by the application. It's
 	// probably better to return an error in this situation, but we cannot
 	// change this without breaking existing applications.
@@ -471,6 +485,10 @@
 		return err
 	}
 
+	mw.c = c
+	mw.frameType = messageType
+	mw.pos = maxFrameHeaderSize
+
 	if c.writeBuf == nil {
 		wpd, ok := c.writePool.Get().(writePoolData)
 		if ok {
@@ -491,16 +509,11 @@
 // All message types (TextMessage, BinaryMessage, CloseMessage, PingMessage and
 // PongMessage) are supported.
 func (c *Conn) NextWriter(messageType int) (io.WriteCloser, error) {
-	if err := c.prepWrite(messageType); err != nil {
+	var mw messageWriter
+	if err := c.beginMessage(&mw, messageType); err != nil {
 		return nil, err
 	}
-
-	mw := &messageWriter{
-		c:         c,
-		frameType: messageType,
-		pos:       maxFrameHeaderSize,
-	}
-	c.writer = mw
+	c.writer = &mw
 	if c.newCompressionWriter != nil && c.enableWriteCompression && isData(messageType) {
 		w := c.newCompressionWriter(c.writer, c.compressionLevel)
 		mw.compress = true
@@ -517,10 +530,16 @@
 	err       error
 }
 
-func (w *messageWriter) fatal(err error) error {
+func (w *messageWriter) endMessage(err error) error {
 	if w.err != nil {
-		w.err = err
-		w.c.writer = nil
+		return err
+	}
+	c := w.c
+	w.err = err
+	c.writer = nil
+	if c.writePool != nil {
+		c.writePool.Put(writePoolData{buf: c.writeBuf})
+		c.writeBuf = nil
 	}
 	return err
 }
@@ -534,7 +553,7 @@
 	// Check for invalid control frames.
 	if isControl(w.frameType) &&
 		(!final || length > maxControlFramePayloadSize) {
-		return w.fatal(errInvalidControlFrame)
+		return w.endMessage(errInvalidControlFrame)
 	}
 
 	b0 := byte(w.frameType)
@@ -579,7 +598,7 @@
 		copy(c.writeBuf[maxFrameHeaderSize-4:], key[:])
 		maskBytes(key, 0, c.writeBuf[maxFrameHeaderSize:w.pos])
 		if len(extra) > 0 {
-			return c.writeFatal(errors.New("websocket: internal error, extra used in client mode"))
+			return w.endMessage(c.writeFatal(errors.New("websocket: internal error, extra used in client mode")))
 		}
 	}
 
@@ -600,15 +619,11 @@
 	c.isWriting = false
 
 	if err != nil {
-		return w.fatal(err)
+		return w.endMessage(err)
 	}
 
 	if final {
-		c.writer = nil
-		if c.writePool != nil {
-			c.writePool.Put(writePoolData{buf: c.writeBuf})
-			c.writeBuf = nil
-		}
+		w.endMessage(errWriteClosed)
 		return nil
 	}
 
@@ -706,11 +721,7 @@
 	if w.err != nil {
 		return w.err
 	}
-	if err := w.flushFrame(true, nil); err != nil {
-		return err
-	}
-	w.err = errWriteClosed
-	return nil
+	return w.flushFrame(true, nil)
 }
 
 // WritePreparedMessage writes prepared message into connection.
@@ -742,10 +753,10 @@
 	if c.isServer && (c.newCompressionWriter == nil || !c.enableWriteCompression) {
 		// Fast path with no allocations and single frame.
 
-		if err := c.prepWrite(messageType); err != nil {
+		var mw messageWriter
+		if err := c.beginMessage(&mw, messageType); err != nil {
 			return err
 		}
-		mw := messageWriter{c: c, frameType: messageType, pos: maxFrameHeaderSize}
 		n := copy(c.writeBuf[mw.pos:], data)
 		mw.pos += n
 		data = data[n:]
@@ -792,7 +803,7 @@
 	final := p[0]&finalBit != 0
 	frameType := int(p[0] & 0xf)
 	mask := p[1]&maskBit != 0
-	c.readRemaining = int64(p[1] & 0x7f)
+	c.setReadRemaining(int64(p[1] & 0x7f))
 
 	c.readDecompress = false
 	if c.newDecompressionReader != nil && (p[0]&rsv1Bit) != 0 {
@@ -826,7 +837,17 @@
 		return noFrame, c.handleProtocolError("unknown opcode " + strconv.Itoa(frameType))
 	}
 
-	// 3. Read and parse frame length.
+	// 3. Read and parse frame length as per
+	// https://tools.ietf.org/html/rfc6455#section-5.2
+	//
+	// The length of the "Payload data", in bytes: if 0-125, that is the payload
+	// length.
+	// - If 126, the following 2 bytes interpreted as a 16-bit unsigned
+	// integer are the payload length.
+	// - If 127, the following 8 bytes interpreted as
+	// a 64-bit unsigned integer (the most significant bit MUST be 0) are the
+	// payload length. Multibyte length quantities are expressed in network byte
+	// order.
 
 	switch c.readRemaining {
 	case 126:
@@ -834,13 +855,19 @@
 		if err != nil {
 			return noFrame, err
 		}
-		c.readRemaining = int64(binary.BigEndian.Uint16(p))
+
+		if err := c.setReadRemaining(int64(binary.BigEndian.Uint16(p))); err != nil {
+			return noFrame, err
+		}
 	case 127:
 		p, err := c.read(8)
 		if err != nil {
 			return noFrame, err
 		}
-		c.readRemaining = int64(binary.BigEndian.Uint64(p))
+
+		if err := c.setReadRemaining(int64(binary.BigEndian.Uint64(p))); err != nil {
+			return noFrame, err
+		}
 	}
 
 	// 4. Handle frame masking.
@@ -863,6 +890,12 @@
 	if frameType == continuationFrame || frameType == TextMessage || frameType == BinaryMessage {
 
 		c.readLength += c.readRemaining
+		// Don't allow readLength to overflow in the presence of a large readRemaining
+		// counter.
+		if c.readLength < 0 {
+			return noFrame, ErrReadLimit
+		}
+
 		if c.readLimit > 0 && c.readLength > c.readLimit {
 			c.WriteControl(CloseMessage, FormatCloseMessage(CloseMessageTooBig, ""), time.Now().Add(writeWait))
 			return noFrame, ErrReadLimit
@@ -876,7 +909,7 @@
 	var payload []byte
 	if c.readRemaining > 0 {
 		payload, err = c.read(int(c.readRemaining))
-		c.readRemaining = 0
+		c.setReadRemaining(0)
 		if err != nil {
 			return noFrame, err
 		}
@@ -949,6 +982,7 @@
 			c.readErr = hideTempErr(err)
 			break
 		}
+
 		if frameType == TextMessage || frameType == BinaryMessage {
 			c.messageReader = &messageReader{c}
 			c.reader = c.messageReader
@@ -989,7 +1023,9 @@
 			if c.isServer {
 				c.readMaskPos = maskBytes(c.readMaskKey, c.readMaskPos, b[:n])
 			}
-			c.readRemaining -= int64(n)
+			rem := c.readRemaining
+			rem -= int64(n)
+			c.setReadRemaining(rem)
 			if c.readRemaining > 0 && c.readErr == io.EOF {
 				c.readErr = errUnexpectedEOF
 			}
@@ -1041,7 +1077,7 @@
 	return c.conn.SetReadDeadline(t)
 }
 
-// SetReadLimit sets the maximum size for a message read from the peer. If a
+// SetReadLimit sets the maximum size in bytes for a message read from the peer. If a
 // message exceeds the limit, the connection sends a close message to the peer
 // and returns ErrReadLimit to the application.
 func (c *Conn) SetReadLimit(limit int64) {