vendor/github.com/gorilla/websocket/client.go
changeset 260 445e01aede7e
parent 251 1c52a0eeb952
equal deleted inserted replaced
259:db4911b0c721 260:445e01aede7e
    46 	}
    46 	}
    47 	return d.Dial(u.String(), requestHeader)
    47 	return d.Dial(u.String(), requestHeader)
    48 }
    48 }
    49 
    49 
    50 // A Dialer contains options for connecting to WebSocket server.
    50 // A Dialer contains options for connecting to WebSocket server.
       
    51 //
       
    52 // It is safe to call Dialer's methods concurrently.
    51 type Dialer struct {
    53 type Dialer struct {
    52 	// NetDial specifies the dial function for creating TCP connections. If
    54 	// NetDial specifies the dial function for creating TCP connections. If
    53 	// NetDial is nil, net.Dial is used.
    55 	// NetDial is nil, net.Dial is used.
    54 	NetDial func(network, addr string) (net.Conn, error)
    56 	NetDial func(network, addr string) (net.Conn, error)
    55 
    57 
    56 	// NetDialContext specifies the dial function for creating TCP connections. If
    58 	// NetDialContext specifies the dial function for creating TCP connections. If
    57 	// NetDialContext is nil, net.DialContext is used.
    59 	// NetDialContext is nil, NetDial is used.
    58 	NetDialContext func(ctx context.Context, network, addr string) (net.Conn, error)
    60 	NetDialContext func(ctx context.Context, network, addr string) (net.Conn, error)
       
    61 
       
    62 	// NetDialTLSContext specifies the dial function for creating TLS/TCP connections. If
       
    63 	// NetDialTLSContext is nil, NetDialContext is used.
       
    64 	// If NetDialTLSContext is set, Dial assumes the TLS handshake is done there and
       
    65 	// TLSClientConfig is ignored.
       
    66 	NetDialTLSContext func(ctx context.Context, network, addr string) (net.Conn, error)
    59 
    67 
    60 	// Proxy specifies a function to return a proxy for a given
    68 	// Proxy specifies a function to return a proxy for a given
    61 	// Request. If the function returns a non-nil error, the
    69 	// Request. If the function returns a non-nil error, the
    62 	// request is aborted with the provided error.
    70 	// request is aborted with the provided error.
    63 	// If Proxy is nil or returns a nil *URL, no proxy is used.
    71 	// If Proxy is nil or returns a nil *URL, no proxy is used.
    64 	Proxy func(*http.Request) (*url.URL, error)
    72 	Proxy func(*http.Request) (*url.URL, error)
    65 
    73 
    66 	// TLSClientConfig specifies the TLS configuration to use with tls.Client.
    74 	// TLSClientConfig specifies the TLS configuration to use with tls.Client.
    67 	// If nil, the default configuration is used.
    75 	// If nil, the default configuration is used.
       
    76 	// If either NetDialTLS or NetDialTLSContext are set, Dial assumes the TLS handshake
       
    77 	// is done there and TLSClientConfig is ignored.
    68 	TLSClientConfig *tls.Config
    78 	TLSClientConfig *tls.Config
    69 
    79 
    70 	// HandshakeTimeout specifies the duration for the handshake to complete.
    80 	// HandshakeTimeout specifies the duration for the handshake to complete.
    71 	HandshakeTimeout time.Duration
    81 	HandshakeTimeout time.Duration
    72 
    82 
   174 		// User name and password are not allowed in websocket URIs.
   184 		// User name and password are not allowed in websocket URIs.
   175 		return nil, nil, errMalformedURL
   185 		return nil, nil, errMalformedURL
   176 	}
   186 	}
   177 
   187 
   178 	req := &http.Request{
   188 	req := &http.Request{
   179 		Method:     "GET",
   189 		Method:     http.MethodGet,
   180 		URL:        u,
   190 		URL:        u,
   181 		Proto:      "HTTP/1.1",
   191 		Proto:      "HTTP/1.1",
   182 		ProtoMajor: 1,
   192 		ProtoMajor: 1,
   183 		ProtoMinor: 1,
   193 		ProtoMinor: 1,
   184 		Header:     make(http.Header),
   194 		Header:     make(http.Header),
   235 	}
   245 	}
   236 
   246 
   237 	// Get network dial function.
   247 	// Get network dial function.
   238 	var netDial func(network, add string) (net.Conn, error)
   248 	var netDial func(network, add string) (net.Conn, error)
   239 
   249 
   240 	if d.NetDialContext != nil {
   250 	switch u.Scheme {
   241 		netDial = func(network, addr string) (net.Conn, error) {
   251 	case "http":
   242 			return d.NetDialContext(ctx, network, addr)
   252 		if d.NetDialContext != nil {
   243 		}
   253 			netDial = func(network, addr string) (net.Conn, error) {
   244 	} else if d.NetDial != nil {
   254 				return d.NetDialContext(ctx, network, addr)
   245 		netDial = d.NetDial
   255 			}
   246 	} else {
   256 		} else if d.NetDial != nil {
       
   257 			netDial = d.NetDial
       
   258 		}
       
   259 	case "https":
       
   260 		if d.NetDialTLSContext != nil {
       
   261 			netDial = func(network, addr string) (net.Conn, error) {
       
   262 				return d.NetDialTLSContext(ctx, network, addr)
       
   263 			}
       
   264 		} else if d.NetDialContext != nil {
       
   265 			netDial = func(network, addr string) (net.Conn, error) {
       
   266 				return d.NetDialContext(ctx, network, addr)
       
   267 			}
       
   268 		} else if d.NetDial != nil {
       
   269 			netDial = d.NetDial
       
   270 		}
       
   271 	default:
       
   272 		return nil, nil, errMalformedURL
       
   273 	}
       
   274 
       
   275 	if netDial == nil {
   247 		netDialer := &net.Dialer{}
   276 		netDialer := &net.Dialer{}
   248 		netDial = func(network, addr string) (net.Conn, error) {
   277 		netDial = func(network, addr string) (net.Conn, error) {
   249 			return netDialer.DialContext(ctx, network, addr)
   278 			return netDialer.DialContext(ctx, network, addr)
   250 		}
   279 		}
   251 	}
   280 	}
   302 		if netConn != nil {
   331 		if netConn != nil {
   303 			netConn.Close()
   332 			netConn.Close()
   304 		}
   333 		}
   305 	}()
   334 	}()
   306 
   335 
   307 	if u.Scheme == "https" {
   336 	if u.Scheme == "https" && d.NetDialTLSContext == nil {
       
   337 		// If NetDialTLSContext is set, assume that the TLS handshake has already been done
       
   338 
   308 		cfg := cloneTLSConfig(d.TLSClientConfig)
   339 		cfg := cloneTLSConfig(d.TLSClientConfig)
   309 		if cfg.ServerName == "" {
   340 		if cfg.ServerName == "" {
   310 			cfg.ServerName = hostNoPort
   341 			cfg.ServerName = hostNoPort
   311 		}
   342 		}
   312 		tlsConn := tls.Client(netConn, cfg)
   343 		tlsConn := tls.Client(netConn, cfg)
   313 		netConn = tlsConn
   344 		netConn = tlsConn
   314 
   345 
   315 		var err error
   346 		if trace != nil && trace.TLSHandshakeStart != nil {
   316 		if trace != nil {
   347 			trace.TLSHandshakeStart()
   317 			err = doHandshakeWithTrace(trace, tlsConn, cfg)
   348 		}
   318 		} else {
   349 		err := doHandshake(ctx, tlsConn, cfg)
   319 			err = doHandshake(tlsConn, cfg)
   350 		if trace != nil && trace.TLSHandshakeDone != nil {
       
   351 			trace.TLSHandshakeDone(tlsConn.ConnectionState(), err)
   320 		}
   352 		}
   321 
   353 
   322 		if err != nil {
   354 		if err != nil {
   323 			return nil, nil, err
   355 			return nil, nil, err
   324 		}
   356 		}
   346 			d.Jar.SetCookies(u, rc)
   378 			d.Jar.SetCookies(u, rc)
   347 		}
   379 		}
   348 	}
   380 	}
   349 
   381 
   350 	if resp.StatusCode != 101 ||
   382 	if resp.StatusCode != 101 ||
   351 		!strings.EqualFold(resp.Header.Get("Upgrade"), "websocket") ||
   383 		!tokenListContainsValue(resp.Header, "Upgrade", "websocket") ||
   352 		!strings.EqualFold(resp.Header.Get("Connection"), "upgrade") ||
   384 		!tokenListContainsValue(resp.Header, "Connection", "upgrade") ||
   353 		resp.Header.Get("Sec-Websocket-Accept") != computeAcceptKey(challengeKey) {
   385 		resp.Header.Get("Sec-Websocket-Accept") != computeAcceptKey(challengeKey) {
   354 		// Before closing the network connection on return from this
   386 		// Before closing the network connection on return from this
   355 		// function, slurp up some of the response to aid application
   387 		// function, slurp up some of the response to aid application
   356 		// debugging.
   388 		// debugging.
   357 		buf := make([]byte, 1024)
   389 		buf := make([]byte, 1024)
   380 	netConn.SetDeadline(time.Time{})
   412 	netConn.SetDeadline(time.Time{})
   381 	netConn = nil // to avoid close in defer.
   413 	netConn = nil // to avoid close in defer.
   382 	return conn, resp, nil
   414 	return conn, resp, nil
   383 }
   415 }
   384 
   416 
   385 func doHandshake(tlsConn *tls.Conn, cfg *tls.Config) error {
   417 func cloneTLSConfig(cfg *tls.Config) *tls.Config {
   386 	if err := tlsConn.Handshake(); err != nil {
   418 	if cfg == nil {
   387 		return err
   419 		return &tls.Config{}
   388 	}
   420 	}
   389 	if !cfg.InsecureSkipVerify {
   421 	return cfg.Clone()
   390 		if err := tlsConn.VerifyHostname(cfg.ServerName); err != nil {
   422 }
   391 			return err
       
   392 		}
       
   393 	}
       
   394 	return nil
       
   395 }