46 } |
46 } |
47 return d.Dial(u.String(), requestHeader) |
47 return d.Dial(u.String(), requestHeader) |
48 } |
48 } |
49 |
49 |
50 // A Dialer contains options for connecting to WebSocket server. |
50 // A Dialer contains options for connecting to WebSocket server. |
|
51 // |
|
52 // It is safe to call Dialer's methods concurrently. |
51 type Dialer struct { |
53 type Dialer struct { |
52 // NetDial specifies the dial function for creating TCP connections. If |
54 // NetDial specifies the dial function for creating TCP connections. If |
53 // NetDial is nil, net.Dial is used. |
55 // NetDial is nil, net.Dial is used. |
54 NetDial func(network, addr string) (net.Conn, error) |
56 NetDial func(network, addr string) (net.Conn, error) |
55 |
57 |
56 // NetDialContext specifies the dial function for creating TCP connections. If |
58 // NetDialContext specifies the dial function for creating TCP connections. If |
57 // NetDialContext is nil, net.DialContext is used. |
59 // NetDialContext is nil, NetDial is used. |
58 NetDialContext func(ctx context.Context, network, addr string) (net.Conn, error) |
60 NetDialContext func(ctx context.Context, network, addr string) (net.Conn, error) |
|
61 |
|
62 // NetDialTLSContext specifies the dial function for creating TLS/TCP connections. If |
|
63 // NetDialTLSContext is nil, NetDialContext is used. |
|
64 // If NetDialTLSContext is set, Dial assumes the TLS handshake is done there and |
|
65 // TLSClientConfig is ignored. |
|
66 NetDialTLSContext func(ctx context.Context, network, addr string) (net.Conn, error) |
59 |
67 |
60 // Proxy specifies a function to return a proxy for a given |
68 // Proxy specifies a function to return a proxy for a given |
61 // Request. If the function returns a non-nil error, the |
69 // Request. If the function returns a non-nil error, the |
62 // request is aborted with the provided error. |
70 // request is aborted with the provided error. |
63 // If Proxy is nil or returns a nil *URL, no proxy is used. |
71 // If Proxy is nil or returns a nil *URL, no proxy is used. |
64 Proxy func(*http.Request) (*url.URL, error) |
72 Proxy func(*http.Request) (*url.URL, error) |
65 |
73 |
66 // TLSClientConfig specifies the TLS configuration to use with tls.Client. |
74 // TLSClientConfig specifies the TLS configuration to use with tls.Client. |
67 // If nil, the default configuration is used. |
75 // If nil, the default configuration is used. |
|
76 // If either NetDialTLS or NetDialTLSContext are set, Dial assumes the TLS handshake |
|
77 // is done there and TLSClientConfig is ignored. |
68 TLSClientConfig *tls.Config |
78 TLSClientConfig *tls.Config |
69 |
79 |
70 // HandshakeTimeout specifies the duration for the handshake to complete. |
80 // HandshakeTimeout specifies the duration for the handshake to complete. |
71 HandshakeTimeout time.Duration |
81 HandshakeTimeout time.Duration |
72 |
82 |
235 } |
245 } |
236 |
246 |
237 // Get network dial function. |
247 // Get network dial function. |
238 var netDial func(network, add string) (net.Conn, error) |
248 var netDial func(network, add string) (net.Conn, error) |
239 |
249 |
240 if d.NetDialContext != nil { |
250 switch u.Scheme { |
241 netDial = func(network, addr string) (net.Conn, error) { |
251 case "http": |
242 return d.NetDialContext(ctx, network, addr) |
252 if d.NetDialContext != nil { |
243 } |
253 netDial = func(network, addr string) (net.Conn, error) { |
244 } else if d.NetDial != nil { |
254 return d.NetDialContext(ctx, network, addr) |
245 netDial = d.NetDial |
255 } |
246 } else { |
256 } else if d.NetDial != nil { |
|
257 netDial = d.NetDial |
|
258 } |
|
259 case "https": |
|
260 if d.NetDialTLSContext != nil { |
|
261 netDial = func(network, addr string) (net.Conn, error) { |
|
262 return d.NetDialTLSContext(ctx, network, addr) |
|
263 } |
|
264 } else if d.NetDialContext != nil { |
|
265 netDial = func(network, addr string) (net.Conn, error) { |
|
266 return d.NetDialContext(ctx, network, addr) |
|
267 } |
|
268 } else if d.NetDial != nil { |
|
269 netDial = d.NetDial |
|
270 } |
|
271 default: |
|
272 return nil, nil, errMalformedURL |
|
273 } |
|
274 |
|
275 if netDial == nil { |
247 netDialer := &net.Dialer{} |
276 netDialer := &net.Dialer{} |
248 netDial = func(network, addr string) (net.Conn, error) { |
277 netDial = func(network, addr string) (net.Conn, error) { |
249 return netDialer.DialContext(ctx, network, addr) |
278 return netDialer.DialContext(ctx, network, addr) |
250 } |
279 } |
251 } |
280 } |
302 if netConn != nil { |
331 if netConn != nil { |
303 netConn.Close() |
332 netConn.Close() |
304 } |
333 } |
305 }() |
334 }() |
306 |
335 |
307 if u.Scheme == "https" { |
336 if u.Scheme == "https" && d.NetDialTLSContext == nil { |
|
337 // If NetDialTLSContext is set, assume that the TLS handshake has already been done |
|
338 |
308 cfg := cloneTLSConfig(d.TLSClientConfig) |
339 cfg := cloneTLSConfig(d.TLSClientConfig) |
309 if cfg.ServerName == "" { |
340 if cfg.ServerName == "" { |
310 cfg.ServerName = hostNoPort |
341 cfg.ServerName = hostNoPort |
311 } |
342 } |
312 tlsConn := tls.Client(netConn, cfg) |
343 tlsConn := tls.Client(netConn, cfg) |
313 netConn = tlsConn |
344 netConn = tlsConn |
314 |
345 |
315 var err error |
346 if trace != nil && trace.TLSHandshakeStart != nil { |
316 if trace != nil { |
347 trace.TLSHandshakeStart() |
317 err = doHandshakeWithTrace(trace, tlsConn, cfg) |
348 } |
318 } else { |
349 err := doHandshake(ctx, tlsConn, cfg) |
319 err = doHandshake(tlsConn, cfg) |
350 if trace != nil && trace.TLSHandshakeDone != nil { |
|
351 trace.TLSHandshakeDone(tlsConn.ConnectionState(), err) |
320 } |
352 } |
321 |
353 |
322 if err != nil { |
354 if err != nil { |
323 return nil, nil, err |
355 return nil, nil, err |
324 } |
356 } |
346 d.Jar.SetCookies(u, rc) |
378 d.Jar.SetCookies(u, rc) |
347 } |
379 } |
348 } |
380 } |
349 |
381 |
350 if resp.StatusCode != 101 || |
382 if resp.StatusCode != 101 || |
351 !strings.EqualFold(resp.Header.Get("Upgrade"), "websocket") || |
383 !tokenListContainsValue(resp.Header, "Upgrade", "websocket") || |
352 !strings.EqualFold(resp.Header.Get("Connection"), "upgrade") || |
384 !tokenListContainsValue(resp.Header, "Connection", "upgrade") || |
353 resp.Header.Get("Sec-Websocket-Accept") != computeAcceptKey(challengeKey) { |
385 resp.Header.Get("Sec-Websocket-Accept") != computeAcceptKey(challengeKey) { |
354 // Before closing the network connection on return from this |
386 // Before closing the network connection on return from this |
355 // function, slurp up some of the response to aid application |
387 // function, slurp up some of the response to aid application |
356 // debugging. |
388 // debugging. |
357 buf := make([]byte, 1024) |
389 buf := make([]byte, 1024) |
380 netConn.SetDeadline(time.Time{}) |
412 netConn.SetDeadline(time.Time{}) |
381 netConn = nil // to avoid close in defer. |
413 netConn = nil // to avoid close in defer. |
382 return conn, resp, nil |
414 return conn, resp, nil |
383 } |
415 } |
384 |
416 |
385 func doHandshake(tlsConn *tls.Conn, cfg *tls.Config) error { |
417 func cloneTLSConfig(cfg *tls.Config) *tls.Config { |
386 if err := tlsConn.Handshake(); err != nil { |
418 if cfg == nil { |
387 return err |
419 return &tls.Config{} |
388 } |
420 } |
389 if !cfg.InsecureSkipVerify { |
421 return cfg.Clone() |
390 if err := tlsConn.VerifyHostname(cfg.ServerName); err != nil { |
422 } |
391 return err |
|
392 } |
|
393 } |
|
394 return nil |
|
395 } |
|