vendor/github.com/gorilla/websocket/conn.go
changeset 260 445e01aede7e
parent 256 6d9efbef00a9
--- a/vendor/github.com/gorilla/websocket/conn.go	Tue Aug 23 22:33:28 2022 +0200
+++ b/vendor/github.com/gorilla/websocket/conn.go	Tue Aug 23 22:39:43 2022 +0200
@@ -13,6 +13,7 @@
 	"math/rand"
 	"net"
 	"strconv"
+	"strings"
 	"sync"
 	"time"
 	"unicode/utf8"
@@ -401,6 +402,12 @@
 	return nil
 }
 
+func (c *Conn) writeBufs(bufs ...[]byte) error {
+	b := net.Buffers(bufs)
+	_, err := b.WriteTo(c.conn)
+	return err
+}
+
 // WriteControl writes a control message with the given deadline. The allowed
 // message types are CloseMessage, PingMessage and PongMessage.
 func (c *Conn) WriteControl(messageType int, data []byte, deadline time.Time) error {
@@ -794,47 +801,69 @@
 	}
 
 	// 2. Read and parse first two bytes of frame header.
+	// To aid debugging, collect and report all errors in the first two bytes
+	// of the header.
+
+	var errors []string
 
 	p, err := c.read(2)
 	if err != nil {
 		return noFrame, err
 	}
 
+	frameType := int(p[0] & 0xf)
 	final := p[0]&finalBit != 0
-	frameType := int(p[0] & 0xf)
+	rsv1 := p[0]&rsv1Bit != 0
+	rsv2 := p[0]&rsv2Bit != 0
+	rsv3 := p[0]&rsv3Bit != 0
 	mask := p[1]&maskBit != 0
 	c.setReadRemaining(int64(p[1] & 0x7f))
 
 	c.readDecompress = false
-	if c.newDecompressionReader != nil && (p[0]&rsv1Bit) != 0 {
-		c.readDecompress = true
-		p[0] &^= rsv1Bit
+	if rsv1 {
+		if c.newDecompressionReader != nil {
+			c.readDecompress = true
+		} else {
+			errors = append(errors, "RSV1 set")
+		}
 	}
 
-	if rsv := p[0] & (rsv1Bit | rsv2Bit | rsv3Bit); rsv != 0 {
-		return noFrame, c.handleProtocolError("unexpected reserved bits 0x" + strconv.FormatInt(int64(rsv), 16))
+	if rsv2 {
+		errors = append(errors, "RSV2 set")
+	}
+
+	if rsv3 {
+		errors = append(errors, "RSV3 set")
 	}
 
 	switch frameType {
 	case CloseMessage, PingMessage, PongMessage:
 		if c.readRemaining > maxControlFramePayloadSize {
-			return noFrame, c.handleProtocolError("control frame length > 125")
+			errors = append(errors, "len > 125 for control")
 		}
 		if !final {
-			return noFrame, c.handleProtocolError("control frame not final")
+			errors = append(errors, "FIN not set on control")
 		}
 	case TextMessage, BinaryMessage:
 		if !c.readFinal {
-			return noFrame, c.handleProtocolError("message start before final message frame")
+			errors = append(errors, "data before FIN")
 		}
 		c.readFinal = final
 	case continuationFrame:
 		if c.readFinal {
-			return noFrame, c.handleProtocolError("continuation after final message frame")
+			errors = append(errors, "continuation after FIN")
 		}
 		c.readFinal = final
 	default:
-		return noFrame, c.handleProtocolError("unknown opcode " + strconv.Itoa(frameType))
+		errors = append(errors, "bad opcode "+strconv.Itoa(frameType))
+	}
+
+	if mask != c.isServer {
+		errors = append(errors, "bad MASK")
+	}
+
+	if len(errors) > 0 {
+		return noFrame, c.handleProtocolError(strings.Join(errors, ", "))
 	}
 
 	// 3. Read and parse frame length as per
@@ -872,10 +901,6 @@
 
 	// 4. Handle frame masking.
 
-	if mask != c.isServer {
-		return noFrame, c.handleProtocolError("incorrect mask flag")
-	}
-
 	if mask {
 		c.readMaskPos = 0
 		p, err := c.read(len(c.readMaskKey))
@@ -935,7 +960,7 @@
 		if len(payload) >= 2 {
 			closeCode = int(binary.BigEndian.Uint16(payload))
 			if !isValidReceivedCloseCode(closeCode) {
-				return noFrame, c.handleProtocolError("invalid close code")
+				return noFrame, c.handleProtocolError("bad close code " + strconv.Itoa(closeCode))
 			}
 			closeText = string(payload[2:])
 			if !utf8.ValidString(closeText) {
@@ -952,7 +977,11 @@
 }
 
 func (c *Conn) handleProtocolError(message string) error {
-	c.WriteControl(CloseMessage, FormatCloseMessage(CloseProtocolError, message), time.Now().Add(writeWait))
+	data := FormatCloseMessage(CloseProtocolError, message)
+	if len(data) > maxControlFramePayloadSize {
+		data = data[:maxControlFramePayloadSize]
+	}
+	c.WriteControl(CloseMessage, data, time.Now().Add(writeWait))
 	return errors.New("websocket: " + message)
 }