vendor/github.com/gorilla/websocket/conn.go
changeset 251 1c52a0eeb952
parent 242 2a9ec03fe5a1
child 256 6d9efbef00a9
equal deleted inserted replaced
250:c040f992052f 251:1c52a0eeb952
   258 	enableWriteCompression bool
   258 	enableWriteCompression bool
   259 	compressionLevel       int
   259 	compressionLevel       int
   260 	newCompressionWriter   func(io.WriteCloser, int) io.WriteCloser
   260 	newCompressionWriter   func(io.WriteCloser, int) io.WriteCloser
   261 
   261 
   262 	// Read fields
   262 	// Read fields
   263 	reader        io.ReadCloser // the current reader returned to the application
   263 	reader  io.ReadCloser // the current reader returned to the application
   264 	readErr       error
   264 	readErr error
   265 	br            *bufio.Reader
   265 	br      *bufio.Reader
   266 	readRemaining int64 // bytes remaining in current frame.
   266 	// bytes remaining in current frame.
       
   267 	// set setReadRemaining to safely update this value and prevent overflow
       
   268 	readRemaining int64
   267 	readFinal     bool  // true the current message has more frames.
   269 	readFinal     bool  // true the current message has more frames.
   268 	readLength    int64 // Message size.
   270 	readLength    int64 // Message size.
   269 	readLimit     int64 // Maximum message size.
   271 	readLimit     int64 // Maximum message size.
   270 	readMaskPos   int
   272 	readMaskPos   int
   271 	readMaskKey   [4]byte
   273 	readMaskKey   [4]byte
   318 	c.SetPingHandler(nil)
   320 	c.SetPingHandler(nil)
   319 	c.SetPongHandler(nil)
   321 	c.SetPongHandler(nil)
   320 	return c
   322 	return c
   321 }
   323 }
   322 
   324 
       
   325 // setReadRemaining tracks the number of bytes remaining on the connection. If n
       
   326 // overflows, an ErrReadLimit is returned.
       
   327 func (c *Conn) setReadRemaining(n int64) error {
       
   328 	if n < 0 {
       
   329 		return ErrReadLimit
       
   330 	}
       
   331 
       
   332 	c.readRemaining = n
       
   333 	return nil
       
   334 }
       
   335 
   323 // Subprotocol returns the negotiated protocol for the connection.
   336 // Subprotocol returns the negotiated protocol for the connection.
   324 func (c *Conn) Subprotocol() string {
   337 func (c *Conn) Subprotocol() string {
   325 	return c.subprotocol
   338 	return c.subprotocol
   326 }
   339 }
   327 
   340 
   449 		c.writeFatal(ErrCloseSent)
   462 		c.writeFatal(ErrCloseSent)
   450 	}
   463 	}
   451 	return err
   464 	return err
   452 }
   465 }
   453 
   466 
   454 func (c *Conn) prepWrite(messageType int) error {
   467 // beginMessage prepares a connection and message writer for a new message.
       
   468 func (c *Conn) beginMessage(mw *messageWriter, messageType int) error {
   455 	// Close previous writer if not already closed by the application. It's
   469 	// Close previous writer if not already closed by the application. It's
   456 	// probably better to return an error in this situation, but we cannot
   470 	// probably better to return an error in this situation, but we cannot
   457 	// change this without breaking existing applications.
   471 	// change this without breaking existing applications.
   458 	if c.writer != nil {
   472 	if c.writer != nil {
   459 		c.writer.Close()
   473 		c.writer.Close()
   468 	err := c.writeErr
   482 	err := c.writeErr
   469 	c.writeErrMu.Unlock()
   483 	c.writeErrMu.Unlock()
   470 	if err != nil {
   484 	if err != nil {
   471 		return err
   485 		return err
   472 	}
   486 	}
       
   487 
       
   488 	mw.c = c
       
   489 	mw.frameType = messageType
       
   490 	mw.pos = maxFrameHeaderSize
   473 
   491 
   474 	if c.writeBuf == nil {
   492 	if c.writeBuf == nil {
   475 		wpd, ok := c.writePool.Get().(writePoolData)
   493 		wpd, ok := c.writePool.Get().(writePoolData)
   476 		if ok {
   494 		if ok {
   477 			c.writeBuf = wpd.buf
   495 			c.writeBuf = wpd.buf
   489 // previous writer if the application has not already done so.
   507 // previous writer if the application has not already done so.
   490 //
   508 //
   491 // All message types (TextMessage, BinaryMessage, CloseMessage, PingMessage and
   509 // All message types (TextMessage, BinaryMessage, CloseMessage, PingMessage and
   492 // PongMessage) are supported.
   510 // PongMessage) are supported.
   493 func (c *Conn) NextWriter(messageType int) (io.WriteCloser, error) {
   511 func (c *Conn) NextWriter(messageType int) (io.WriteCloser, error) {
   494 	if err := c.prepWrite(messageType); err != nil {
   512 	var mw messageWriter
       
   513 	if err := c.beginMessage(&mw, messageType); err != nil {
   495 		return nil, err
   514 		return nil, err
   496 	}
   515 	}
   497 
   516 	c.writer = &mw
   498 	mw := &messageWriter{
       
   499 		c:         c,
       
   500 		frameType: messageType,
       
   501 		pos:       maxFrameHeaderSize,
       
   502 	}
       
   503 	c.writer = mw
       
   504 	if c.newCompressionWriter != nil && c.enableWriteCompression && isData(messageType) {
   517 	if c.newCompressionWriter != nil && c.enableWriteCompression && isData(messageType) {
   505 		w := c.newCompressionWriter(c.writer, c.compressionLevel)
   518 		w := c.newCompressionWriter(c.writer, c.compressionLevel)
   506 		mw.compress = true
   519 		mw.compress = true
   507 		c.writer = w
   520 		c.writer = w
   508 	}
   521 	}
   515 	pos       int  // end of data in writeBuf.
   528 	pos       int  // end of data in writeBuf.
   516 	frameType int  // type of the current frame.
   529 	frameType int  // type of the current frame.
   517 	err       error
   530 	err       error
   518 }
   531 }
   519 
   532 
   520 func (w *messageWriter) fatal(err error) error {
   533 func (w *messageWriter) endMessage(err error) error {
   521 	if w.err != nil {
   534 	if w.err != nil {
   522 		w.err = err
   535 		return err
   523 		w.c.writer = nil
   536 	}
       
   537 	c := w.c
       
   538 	w.err = err
       
   539 	c.writer = nil
       
   540 	if c.writePool != nil {
       
   541 		c.writePool.Put(writePoolData{buf: c.writeBuf})
       
   542 		c.writeBuf = nil
   524 	}
   543 	}
   525 	return err
   544 	return err
   526 }
   545 }
   527 
   546 
   528 // flushFrame writes buffered data and extra as a frame to the network. The
   547 // flushFrame writes buffered data and extra as a frame to the network. The
   532 	length := w.pos - maxFrameHeaderSize + len(extra)
   551 	length := w.pos - maxFrameHeaderSize + len(extra)
   533 
   552 
   534 	// Check for invalid control frames.
   553 	// Check for invalid control frames.
   535 	if isControl(w.frameType) &&
   554 	if isControl(w.frameType) &&
   536 		(!final || length > maxControlFramePayloadSize) {
   555 		(!final || length > maxControlFramePayloadSize) {
   537 		return w.fatal(errInvalidControlFrame)
   556 		return w.endMessage(errInvalidControlFrame)
   538 	}
   557 	}
   539 
   558 
   540 	b0 := byte(w.frameType)
   559 	b0 := byte(w.frameType)
   541 	if final {
   560 	if final {
   542 		b0 |= finalBit
   561 		b0 |= finalBit
   577 	if !c.isServer {
   596 	if !c.isServer {
   578 		key := newMaskKey()
   597 		key := newMaskKey()
   579 		copy(c.writeBuf[maxFrameHeaderSize-4:], key[:])
   598 		copy(c.writeBuf[maxFrameHeaderSize-4:], key[:])
   580 		maskBytes(key, 0, c.writeBuf[maxFrameHeaderSize:w.pos])
   599 		maskBytes(key, 0, c.writeBuf[maxFrameHeaderSize:w.pos])
   581 		if len(extra) > 0 {
   600 		if len(extra) > 0 {
   582 			return c.writeFatal(errors.New("websocket: internal error, extra used in client mode"))
   601 			return w.endMessage(c.writeFatal(errors.New("websocket: internal error, extra used in client mode")))
   583 		}
   602 		}
   584 	}
   603 	}
   585 
   604 
   586 	// Write the buffers to the connection with best-effort detection of
   605 	// Write the buffers to the connection with best-effort detection of
   587 	// concurrent writes. See the concurrency section in the package
   606 	// concurrent writes. See the concurrency section in the package
   598 		panic("concurrent write to websocket connection")
   617 		panic("concurrent write to websocket connection")
   599 	}
   618 	}
   600 	c.isWriting = false
   619 	c.isWriting = false
   601 
   620 
   602 	if err != nil {
   621 	if err != nil {
   603 		return w.fatal(err)
   622 		return w.endMessage(err)
   604 	}
   623 	}
   605 
   624 
   606 	if final {
   625 	if final {
   607 		c.writer = nil
   626 		w.endMessage(errWriteClosed)
   608 		if c.writePool != nil {
       
   609 			c.writePool.Put(writePoolData{buf: c.writeBuf})
       
   610 			c.writeBuf = nil
       
   611 		}
       
   612 		return nil
   627 		return nil
   613 	}
   628 	}
   614 
   629 
   615 	// Setup for next frame.
   630 	// Setup for next frame.
   616 	w.pos = maxFrameHeaderSize
   631 	w.pos = maxFrameHeaderSize
   704 
   719 
   705 func (w *messageWriter) Close() error {
   720 func (w *messageWriter) Close() error {
   706 	if w.err != nil {
   721 	if w.err != nil {
   707 		return w.err
   722 		return w.err
   708 	}
   723 	}
   709 	if err := w.flushFrame(true, nil); err != nil {
   724 	return w.flushFrame(true, nil)
   710 		return err
       
   711 	}
       
   712 	w.err = errWriteClosed
       
   713 	return nil
       
   714 }
   725 }
   715 
   726 
   716 // WritePreparedMessage writes prepared message into connection.
   727 // WritePreparedMessage writes prepared message into connection.
   717 func (c *Conn) WritePreparedMessage(pm *PreparedMessage) error {
   728 func (c *Conn) WritePreparedMessage(pm *PreparedMessage) error {
   718 	frameType, frameData, err := pm.frame(prepareKey{
   729 	frameType, frameData, err := pm.frame(prepareKey{
   740 func (c *Conn) WriteMessage(messageType int, data []byte) error {
   751 func (c *Conn) WriteMessage(messageType int, data []byte) error {
   741 
   752 
   742 	if c.isServer && (c.newCompressionWriter == nil || !c.enableWriteCompression) {
   753 	if c.isServer && (c.newCompressionWriter == nil || !c.enableWriteCompression) {
   743 		// Fast path with no allocations and single frame.
   754 		// Fast path with no allocations and single frame.
   744 
   755 
   745 		if err := c.prepWrite(messageType); err != nil {
   756 		var mw messageWriter
       
   757 		if err := c.beginMessage(&mw, messageType); err != nil {
   746 			return err
   758 			return err
   747 		}
   759 		}
   748 		mw := messageWriter{c: c, frameType: messageType, pos: maxFrameHeaderSize}
       
   749 		n := copy(c.writeBuf[mw.pos:], data)
   760 		n := copy(c.writeBuf[mw.pos:], data)
   750 		mw.pos += n
   761 		mw.pos += n
   751 		data = data[n:]
   762 		data = data[n:]
   752 		return mw.flushFrame(true, data)
   763 		return mw.flushFrame(true, data)
   753 	}
   764 	}
   790 	}
   801 	}
   791 
   802 
   792 	final := p[0]&finalBit != 0
   803 	final := p[0]&finalBit != 0
   793 	frameType := int(p[0] & 0xf)
   804 	frameType := int(p[0] & 0xf)
   794 	mask := p[1]&maskBit != 0
   805 	mask := p[1]&maskBit != 0
   795 	c.readRemaining = int64(p[1] & 0x7f)
   806 	c.setReadRemaining(int64(p[1] & 0x7f))
   796 
   807 
   797 	c.readDecompress = false
   808 	c.readDecompress = false
   798 	if c.newDecompressionReader != nil && (p[0]&rsv1Bit) != 0 {
   809 	if c.newDecompressionReader != nil && (p[0]&rsv1Bit) != 0 {
   799 		c.readDecompress = true
   810 		c.readDecompress = true
   800 		p[0] &^= rsv1Bit
   811 		p[0] &^= rsv1Bit
   824 		c.readFinal = final
   835 		c.readFinal = final
   825 	default:
   836 	default:
   826 		return noFrame, c.handleProtocolError("unknown opcode " + strconv.Itoa(frameType))
   837 		return noFrame, c.handleProtocolError("unknown opcode " + strconv.Itoa(frameType))
   827 	}
   838 	}
   828 
   839 
   829 	// 3. Read and parse frame length.
   840 	// 3. Read and parse frame length as per
       
   841 	// https://tools.ietf.org/html/rfc6455#section-5.2
       
   842 	//
       
   843 	// The length of the "Payload data", in bytes: if 0-125, that is the payload
       
   844 	// length.
       
   845 	// - If 126, the following 2 bytes interpreted as a 16-bit unsigned
       
   846 	// integer are the payload length.
       
   847 	// - If 127, the following 8 bytes interpreted as
       
   848 	// a 64-bit unsigned integer (the most significant bit MUST be 0) are the
       
   849 	// payload length. Multibyte length quantities are expressed in network byte
       
   850 	// order.
   830 
   851 
   831 	switch c.readRemaining {
   852 	switch c.readRemaining {
   832 	case 126:
   853 	case 126:
   833 		p, err := c.read(2)
   854 		p, err := c.read(2)
   834 		if err != nil {
   855 		if err != nil {
   835 			return noFrame, err
   856 			return noFrame, err
   836 		}
   857 		}
   837 		c.readRemaining = int64(binary.BigEndian.Uint16(p))
   858 
       
   859 		if err := c.setReadRemaining(int64(binary.BigEndian.Uint16(p))); err != nil {
       
   860 			return noFrame, err
       
   861 		}
   838 	case 127:
   862 	case 127:
   839 		p, err := c.read(8)
   863 		p, err := c.read(8)
   840 		if err != nil {
   864 		if err != nil {
   841 			return noFrame, err
   865 			return noFrame, err
   842 		}
   866 		}
   843 		c.readRemaining = int64(binary.BigEndian.Uint64(p))
   867 
       
   868 		if err := c.setReadRemaining(int64(binary.BigEndian.Uint64(p))); err != nil {
       
   869 			return noFrame, err
       
   870 		}
   844 	}
   871 	}
   845 
   872 
   846 	// 4. Handle frame masking.
   873 	// 4. Handle frame masking.
   847 
   874 
   848 	if mask != c.isServer {
   875 	if mask != c.isServer {
   861 	// 5. For text and binary messages, enforce read limit and return.
   888 	// 5. For text and binary messages, enforce read limit and return.
   862 
   889 
   863 	if frameType == continuationFrame || frameType == TextMessage || frameType == BinaryMessage {
   890 	if frameType == continuationFrame || frameType == TextMessage || frameType == BinaryMessage {
   864 
   891 
   865 		c.readLength += c.readRemaining
   892 		c.readLength += c.readRemaining
       
   893 		// Don't allow readLength to overflow in the presence of a large readRemaining
       
   894 		// counter.
       
   895 		if c.readLength < 0 {
       
   896 			return noFrame, ErrReadLimit
       
   897 		}
       
   898 
   866 		if c.readLimit > 0 && c.readLength > c.readLimit {
   899 		if c.readLimit > 0 && c.readLength > c.readLimit {
   867 			c.WriteControl(CloseMessage, FormatCloseMessage(CloseMessageTooBig, ""), time.Now().Add(writeWait))
   900 			c.WriteControl(CloseMessage, FormatCloseMessage(CloseMessageTooBig, ""), time.Now().Add(writeWait))
   868 			return noFrame, ErrReadLimit
   901 			return noFrame, ErrReadLimit
   869 		}
   902 		}
   870 
   903 
   874 	// 6. Read control frame payload.
   907 	// 6. Read control frame payload.
   875 
   908 
   876 	var payload []byte
   909 	var payload []byte
   877 	if c.readRemaining > 0 {
   910 	if c.readRemaining > 0 {
   878 		payload, err = c.read(int(c.readRemaining))
   911 		payload, err = c.read(int(c.readRemaining))
   879 		c.readRemaining = 0
   912 		c.setReadRemaining(0)
   880 		if err != nil {
   913 		if err != nil {
   881 			return noFrame, err
   914 			return noFrame, err
   882 		}
   915 		}
   883 		if c.isServer {
   916 		if c.isServer {
   884 			maskBytes(c.readMaskKey, 0, payload)
   917 			maskBytes(c.readMaskKey, 0, payload)
   947 		frameType, err := c.advanceFrame()
   980 		frameType, err := c.advanceFrame()
   948 		if err != nil {
   981 		if err != nil {
   949 			c.readErr = hideTempErr(err)
   982 			c.readErr = hideTempErr(err)
   950 			break
   983 			break
   951 		}
   984 		}
       
   985 
   952 		if frameType == TextMessage || frameType == BinaryMessage {
   986 		if frameType == TextMessage || frameType == BinaryMessage {
   953 			c.messageReader = &messageReader{c}
   987 			c.messageReader = &messageReader{c}
   954 			c.reader = c.messageReader
   988 			c.reader = c.messageReader
   955 			if c.readDecompress {
   989 			if c.readDecompress {
   956 				c.reader = c.newDecompressionReader(c.reader)
   990 				c.reader = c.newDecompressionReader(c.reader)
   987 			n, err := c.br.Read(b)
  1021 			n, err := c.br.Read(b)
   988 			c.readErr = hideTempErr(err)
  1022 			c.readErr = hideTempErr(err)
   989 			if c.isServer {
  1023 			if c.isServer {
   990 				c.readMaskPos = maskBytes(c.readMaskKey, c.readMaskPos, b[:n])
  1024 				c.readMaskPos = maskBytes(c.readMaskKey, c.readMaskPos, b[:n])
   991 			}
  1025 			}
   992 			c.readRemaining -= int64(n)
  1026 			rem := c.readRemaining
       
  1027 			rem -= int64(n)
       
  1028 			c.setReadRemaining(rem)
   993 			if c.readRemaining > 0 && c.readErr == io.EOF {
  1029 			if c.readRemaining > 0 && c.readErr == io.EOF {
   994 				c.readErr = errUnexpectedEOF
  1030 				c.readErr = errUnexpectedEOF
   995 			}
  1031 			}
   996 			return n, c.readErr
  1032 			return n, c.readErr
   997 		}
  1033 		}
  1039 // not time out.
  1075 // not time out.
  1040 func (c *Conn) SetReadDeadline(t time.Time) error {
  1076 func (c *Conn) SetReadDeadline(t time.Time) error {
  1041 	return c.conn.SetReadDeadline(t)
  1077 	return c.conn.SetReadDeadline(t)
  1042 }
  1078 }
  1043 
  1079 
  1044 // SetReadLimit sets the maximum size for a message read from the peer. If a
  1080 // SetReadLimit sets the maximum size in bytes for a message read from the peer. If a
  1045 // message exceeds the limit, the connection sends a close message to the peer
  1081 // message exceeds the limit, the connection sends a close message to the peer
  1046 // and returns ErrReadLimit to the application.
  1082 // and returns ErrReadLimit to the application.
  1047 func (c *Conn) SetReadLimit(limit int64) {
  1083 func (c *Conn) SetReadLimit(limit int64) {
  1048 	c.readLimit = limit
  1084 	c.readLimit = limit
  1049 }
  1085 }