vendor/golang.org/x/oauth2/internal/token.go
changeset 251 1c52a0eeb952
parent 242 2a9ec03fe5a1
--- a/vendor/golang.org/x/oauth2/internal/token.go	Wed Sep 18 19:17:42 2019 +0200
+++ b/vendor/golang.org/x/oauth2/internal/token.go	Sun Feb 16 18:54:01 2020 +0100
@@ -5,19 +5,21 @@
 package internal
 
 import (
+	"context"
 	"encoding/json"
 	"errors"
 	"fmt"
 	"io"
 	"io/ioutil"
+	"math"
 	"mime"
 	"net/http"
 	"net/url"
 	"strconv"
 	"strings"
+	"sync"
 	"time"
 
-	"golang.org/x/net/context"
 	"golang.org/x/net/context/ctxhttp"
 )
 
@@ -61,22 +63,21 @@
 	TokenType    string         `json:"token_type"`
 	RefreshToken string         `json:"refresh_token"`
 	ExpiresIn    expirationTime `json:"expires_in"` // at least PayPal returns string, while most return number
-	Expires      expirationTime `json:"expires"`    // broken Facebook spelling of expires_in
 }
 
 func (e *tokenJSON) expiry() (t time.Time) {
 	if v := e.ExpiresIn; v != 0 {
 		return time.Now().Add(time.Duration(v) * time.Second)
 	}
-	if v := e.Expires; v != 0 {
-		return time.Now().Add(time.Duration(v) * time.Second)
-	}
 	return
 }
 
 type expirationTime int32
 
 func (e *expirationTime) UnmarshalJSON(b []byte) error {
+	if len(b) == 0 || string(b) == "null" {
+		return nil
+	}
 	var n json.Number
 	err := json.Unmarshal(b, &n)
 	if err != nil {
@@ -86,101 +87,78 @@
 	if err != nil {
 		return err
 	}
+	if i > math.MaxInt32 {
+		i = math.MaxInt32
+	}
 	*e = expirationTime(i)
 	return nil
 }
 
-var brokenAuthHeaderProviders = []string{
-	"https://accounts.google.com/",
-	"https://api.codeswholesale.com/oauth/token",
-	"https://api.dropbox.com/",
-	"https://api.dropboxapi.com/",
-	"https://api.instagram.com/",
-	"https://api.netatmo.net/",
-	"https://api.odnoklassniki.ru/",
-	"https://api.pushbullet.com/",
-	"https://api.soundcloud.com/",
-	"https://api.twitch.tv/",
-	"https://id.twitch.tv/",
-	"https://app.box.com/",
-	"https://api.box.com/",
-	"https://connect.stripe.com/",
-	"https://login.mailchimp.com/",
-	"https://login.microsoftonline.com/",
-	"https://login.salesforce.com/",
-	"https://login.windows.net",
-	"https://login.live.com/",
-	"https://oauth.sandbox.trainingpeaks.com/",
-	"https://oauth.trainingpeaks.com/",
-	"https://oauth.vk.com/",
-	"https://openapi.baidu.com/",
-	"https://slack.com/",
-	"https://test-sandbox.auth.corp.google.com",
-	"https://test.salesforce.com/",
-	"https://user.gini.net/",
-	"https://www.douban.com/",
-	"https://www.googleapis.com/",
-	"https://www.linkedin.com/",
-	"https://www.strava.com/oauth/",
-	"https://www.wunderlist.com/oauth/",
-	"https://api.patreon.com/",
-	"https://sandbox.codeswholesale.com/oauth/token",
-	"https://api.sipgate.com/v1/authorization/oauth",
-	"https://api.medium.com/v1/tokens",
-	"https://log.finalsurge.com/oauth/token",
-	"https://multisport.todaysplan.com.au/rest/oauth/access_token",
-	"https://whats.todaysplan.com.au/rest/oauth/access_token",
-	"https://stackoverflow.com/oauth/access_token",
-	"https://account.health.nokia.com",
+// RegisterBrokenAuthHeaderProvider previously did something. It is now a no-op.
+//
+// Deprecated: this function no longer does anything. Caller code that
+// wants to avoid potential extra HTTP requests made during
+// auto-probing of the provider's auth style should set
+// Endpoint.AuthStyle.
+func RegisterBrokenAuthHeaderProvider(tokenURL string) {}
+
+// AuthStyle is a copy of the golang.org/x/oauth2 package's AuthStyle type.
+type AuthStyle int
+
+const (
+	AuthStyleUnknown  AuthStyle = 0
+	AuthStyleInParams AuthStyle = 1
+	AuthStyleInHeader AuthStyle = 2
+)
+
+// authStyleCache is the set of tokenURLs we've successfully used via
+// RetrieveToken and which style auth we ended up using.
+// It's called a cache, but it doesn't (yet?) shrink. It's expected that
+// the set of OAuth2 servers a program contacts over time is fixed and
+// small.
+var authStyleCache struct {
+	sync.Mutex
+	m map[string]AuthStyle // keyed by tokenURL
 }
 
-// brokenAuthHeaderDomains lists broken providers that issue dynamic endpoints.
-var brokenAuthHeaderDomains = []string{
-	".auth0.com",
-	".force.com",
-	".myshopify.com",
-	".okta.com",
-	".oktapreview.com",
+// ResetAuthCache resets the global authentication style cache used
+// for AuthStyleUnknown token requests.
+func ResetAuthCache() {
+	authStyleCache.Lock()
+	defer authStyleCache.Unlock()
+	authStyleCache.m = nil
 }
 
-func RegisterBrokenAuthHeaderProvider(tokenURL string) {
-	brokenAuthHeaderProviders = append(brokenAuthHeaderProviders, tokenURL)
+// lookupAuthStyle reports which auth style we last used with tokenURL
+// when calling RetrieveToken and whether we have ever done so.
+func lookupAuthStyle(tokenURL string) (style AuthStyle, ok bool) {
+	authStyleCache.Lock()
+	defer authStyleCache.Unlock()
+	style, ok = authStyleCache.m[tokenURL]
+	return
 }
 
-// providerAuthHeaderWorks reports whether the OAuth2 server identified by the tokenURL
-// implements the OAuth2 spec correctly
-// See https://code.google.com/p/goauth2/issues/detail?id=31 for background.
-// In summary:
-// - Reddit only accepts client secret in the Authorization header
-// - Dropbox accepts either it in URL param or Auth header, but not both.
-// - Google only accepts URL param (not spec compliant?), not Auth header
-// - Stripe only accepts client secret in Auth header with Bearer method, not Basic
-func providerAuthHeaderWorks(tokenURL string) bool {
-	for _, s := range brokenAuthHeaderProviders {
-		if strings.HasPrefix(tokenURL, s) {
-			// Some sites fail to implement the OAuth2 spec fully.
-			return false
-		}
+// setAuthStyle adds an entry to authStyleCache, documented above.
+func setAuthStyle(tokenURL string, v AuthStyle) {
+	authStyleCache.Lock()
+	defer authStyleCache.Unlock()
+	if authStyleCache.m == nil {
+		authStyleCache.m = make(map[string]AuthStyle)
 	}
-
-	if u, err := url.Parse(tokenURL); err == nil {
-		for _, s := range brokenAuthHeaderDomains {
-			if strings.HasSuffix(u.Host, s) {
-				return false
-			}
-		}
-	}
-
-	// Assume the provider implements the spec properly
-	// otherwise. We can add more exceptions as they're
-	// discovered. We will _not_ be adding configurable hooks
-	// to this package to let users select server bugs.
-	return true
+	authStyleCache.m[tokenURL] = v
 }
 
-func RetrieveToken(ctx context.Context, clientID, clientSecret, tokenURL string, v url.Values) (*Token, error) {
-	bustedAuth := !providerAuthHeaderWorks(tokenURL)
-	if bustedAuth {
+// newTokenRequest returns a new *http.Request to retrieve a new token
+// from tokenURL using the provided clientID, clientSecret, and POST
+// body parameters.
+//
+// inParams is whether the clientID & clientSecret should be encoded
+// as the POST body. An 'inParams' value of true means to send it in
+// the POST body (along with any values in v); false means to send it
+// in the Authorization header.
+func newTokenRequest(tokenURL, clientID, clientSecret string, v url.Values, authStyle AuthStyle) (*http.Request, error) {
+	if authStyle == AuthStyleInParams {
+		v = cloneURLValues(v)
 		if clientID != "" {
 			v.Set("client_id", clientID)
 		}
@@ -193,15 +171,70 @@
 		return nil, err
 	}
 	req.Header.Set("Content-Type", "application/x-www-form-urlencoded")
-	if !bustedAuth {
+	if authStyle == AuthStyleInHeader {
 		req.SetBasicAuth(url.QueryEscape(clientID), url.QueryEscape(clientSecret))
 	}
+	return req, nil
+}
+
+func cloneURLValues(v url.Values) url.Values {
+	v2 := make(url.Values, len(v))
+	for k, vv := range v {
+		v2[k] = append([]string(nil), vv...)
+	}
+	return v2
+}
+
+func RetrieveToken(ctx context.Context, clientID, clientSecret, tokenURL string, v url.Values, authStyle AuthStyle) (*Token, error) {
+	needsAuthStyleProbe := authStyle == 0
+	if needsAuthStyleProbe {
+		if style, ok := lookupAuthStyle(tokenURL); ok {
+			authStyle = style
+			needsAuthStyleProbe = false
+		} else {
+			authStyle = AuthStyleInHeader // the first way we'll try
+		}
+	}
+	req, err := newTokenRequest(tokenURL, clientID, clientSecret, v, authStyle)
+	if err != nil {
+		return nil, err
+	}
+	token, err := doTokenRoundTrip(ctx, req)
+	if err != nil && needsAuthStyleProbe {
+		// If we get an error, assume the server wants the
+		// clientID & clientSecret in a different form.
+		// See https://code.google.com/p/goauth2/issues/detail?id=31 for background.
+		// In summary:
+		// - Reddit only accepts client secret in the Authorization header
+		// - Dropbox accepts either it in URL param or Auth header, but not both.
+		// - Google only accepts URL param (not spec compliant?), not Auth header
+		// - Stripe only accepts client secret in Auth header with Bearer method, not Basic
+		//
+		// We used to maintain a big table in this code of all the sites and which way
+		// they went, but maintaining it didn't scale & got annoying.
+		// So just try both ways.
+		authStyle = AuthStyleInParams // the second way we'll try
+		req, _ = newTokenRequest(tokenURL, clientID, clientSecret, v, authStyle)
+		token, err = doTokenRoundTrip(ctx, req)
+	}
+	if needsAuthStyleProbe && err == nil {
+		setAuthStyle(tokenURL, authStyle)
+	}
+	// Don't overwrite `RefreshToken` with an empty value
+	// if this was a token refreshing request.
+	if token != nil && token.RefreshToken == "" {
+		token.RefreshToken = v.Get("refresh_token")
+	}
+	return token, err
+}
+
+func doTokenRoundTrip(ctx context.Context, req *http.Request) (*Token, error) {
 	r, err := ctxhttp.Do(ctx, ContextClient(ctx), req)
 	if err != nil {
 		return nil, err
 	}
-	defer r.Body.Close()
 	body, err := ioutil.ReadAll(io.LimitReader(r.Body, 1<<20))
+	r.Body.Close()
 	if err != nil {
 		return nil, fmt.Errorf("oauth2: cannot fetch token: %v", err)
 	}
@@ -227,12 +260,6 @@
 			Raw:          vals,
 		}
 		e := vals.Get("expires_in")
-		if e == "" {
-			// TODO(jbd): Facebook's OAuth2 implementation is broken and
-			// returns expires_in field in expires. Remove the fallback to expires,
-			// when Facebook fixes their implementation.
-			e = vals.Get("expires")
-		}
 		expires, _ := strconv.Atoi(e)
 		if expires != 0 {
 			token.Expiry = time.Now().Add(time.Duration(expires) * time.Second)
@@ -251,13 +278,8 @@
 		}
 		json.Unmarshal(body, &token.Raw) // no error checks for optional fields
 	}
-	// Don't overwrite `RefreshToken` with an empty value
-	// if this was a token refreshing request.
-	if token.RefreshToken == "" {
-		token.RefreshToken = v.Get("refresh_token")
-	}
 	if token.AccessToken == "" {
-		return token, errors.New("oauth2: server response missing access_token")
+		return nil, errors.New("oauth2: server response missing access_token")
 	}
 	return token, nil
 }