vendor/github.com/gorilla/websocket/client.go
changeset 260 445e01aede7e
parent 251 1c52a0eeb952
--- a/vendor/github.com/gorilla/websocket/client.go	Tue Aug 23 22:33:28 2022 +0200
+++ b/vendor/github.com/gorilla/websocket/client.go	Tue Aug 23 22:39:43 2022 +0200
@@ -48,15 +48,23 @@
 }
 
 // A Dialer contains options for connecting to WebSocket server.
+//
+// It is safe to call Dialer's methods concurrently.
 type Dialer struct {
 	// NetDial specifies the dial function for creating TCP connections. If
 	// NetDial is nil, net.Dial is used.
 	NetDial func(network, addr string) (net.Conn, error)
 
 	// NetDialContext specifies the dial function for creating TCP connections. If
-	// NetDialContext is nil, net.DialContext is used.
+	// NetDialContext is nil, NetDial is used.
 	NetDialContext func(ctx context.Context, network, addr string) (net.Conn, error)
 
+	// NetDialTLSContext specifies the dial function for creating TLS/TCP connections. If
+	// NetDialTLSContext is nil, NetDialContext is used.
+	// If NetDialTLSContext is set, Dial assumes the TLS handshake is done there and
+	// TLSClientConfig is ignored.
+	NetDialTLSContext func(ctx context.Context, network, addr string) (net.Conn, error)
+
 	// Proxy specifies a function to return a proxy for a given
 	// Request. If the function returns a non-nil error, the
 	// request is aborted with the provided error.
@@ -65,6 +73,8 @@
 
 	// TLSClientConfig specifies the TLS configuration to use with tls.Client.
 	// If nil, the default configuration is used.
+	// If either NetDialTLS or NetDialTLSContext are set, Dial assumes the TLS handshake
+	// is done there and TLSClientConfig is ignored.
 	TLSClientConfig *tls.Config
 
 	// HandshakeTimeout specifies the duration for the handshake to complete.
@@ -176,7 +186,7 @@
 	}
 
 	req := &http.Request{
-		Method:     "GET",
+		Method:     http.MethodGet,
 		URL:        u,
 		Proto:      "HTTP/1.1",
 		ProtoMajor: 1,
@@ -237,13 +247,32 @@
 	// Get network dial function.
 	var netDial func(network, add string) (net.Conn, error)
 
-	if d.NetDialContext != nil {
-		netDial = func(network, addr string) (net.Conn, error) {
-			return d.NetDialContext(ctx, network, addr)
+	switch u.Scheme {
+	case "http":
+		if d.NetDialContext != nil {
+			netDial = func(network, addr string) (net.Conn, error) {
+				return d.NetDialContext(ctx, network, addr)
+			}
+		} else if d.NetDial != nil {
+			netDial = d.NetDial
 		}
-	} else if d.NetDial != nil {
-		netDial = d.NetDial
-	} else {
+	case "https":
+		if d.NetDialTLSContext != nil {
+			netDial = func(network, addr string) (net.Conn, error) {
+				return d.NetDialTLSContext(ctx, network, addr)
+			}
+		} else if d.NetDialContext != nil {
+			netDial = func(network, addr string) (net.Conn, error) {
+				return d.NetDialContext(ctx, network, addr)
+			}
+		} else if d.NetDial != nil {
+			netDial = d.NetDial
+		}
+	default:
+		return nil, nil, errMalformedURL
+	}
+
+	if netDial == nil {
 		netDialer := &net.Dialer{}
 		netDial = func(network, addr string) (net.Conn, error) {
 			return netDialer.DialContext(ctx, network, addr)
@@ -304,7 +333,9 @@
 		}
 	}()
 
-	if u.Scheme == "https" {
+	if u.Scheme == "https" && d.NetDialTLSContext == nil {
+		// If NetDialTLSContext is set, assume that the TLS handshake has already been done
+
 		cfg := cloneTLSConfig(d.TLSClientConfig)
 		if cfg.ServerName == "" {
 			cfg.ServerName = hostNoPort
@@ -312,11 +343,12 @@
 		tlsConn := tls.Client(netConn, cfg)
 		netConn = tlsConn
 
-		var err error
-		if trace != nil {
-			err = doHandshakeWithTrace(trace, tlsConn, cfg)
-		} else {
-			err = doHandshake(tlsConn, cfg)
+		if trace != nil && trace.TLSHandshakeStart != nil {
+			trace.TLSHandshakeStart()
+		}
+		err := doHandshake(ctx, tlsConn, cfg)
+		if trace != nil && trace.TLSHandshakeDone != nil {
+			trace.TLSHandshakeDone(tlsConn.ConnectionState(), err)
 		}
 
 		if err != nil {
@@ -348,8 +380,8 @@
 	}
 
 	if resp.StatusCode != 101 ||
-		!strings.EqualFold(resp.Header.Get("Upgrade"), "websocket") ||
-		!strings.EqualFold(resp.Header.Get("Connection"), "upgrade") ||
+		!tokenListContainsValue(resp.Header, "Upgrade", "websocket") ||
+		!tokenListContainsValue(resp.Header, "Connection", "upgrade") ||
 		resp.Header.Get("Sec-Websocket-Accept") != computeAcceptKey(challengeKey) {
 		// Before closing the network connection on return from this
 		// function, slurp up some of the response to aid application
@@ -382,14 +414,9 @@
 	return conn, resp, nil
 }
 
-func doHandshake(tlsConn *tls.Conn, cfg *tls.Config) error {
-	if err := tlsConn.Handshake(); err != nil {
-		return err
+func cloneTLSConfig(cfg *tls.Config) *tls.Config {
+	if cfg == nil {
+		return &tls.Config{}
 	}
-	if !cfg.InsecureSkipVerify {
-		if err := tlsConn.VerifyHostname(cfg.ServerName); err != nil {
-			return err
-		}
-	}
-	return nil
+	return cfg.Clone()
 }