vendor/golang.org/x/oauth2/internal/token.go
changeset 251 1c52a0eeb952
parent 242 2a9ec03fe5a1
equal deleted inserted replaced
250:c040f992052f 251:1c52a0eeb952
     3 // license that can be found in the LICENSE file.
     3 // license that can be found in the LICENSE file.
     4 
     4 
     5 package internal
     5 package internal
     6 
     6 
     7 import (
     7 import (
       
     8 	"context"
     8 	"encoding/json"
     9 	"encoding/json"
     9 	"errors"
    10 	"errors"
    10 	"fmt"
    11 	"fmt"
    11 	"io"
    12 	"io"
    12 	"io/ioutil"
    13 	"io/ioutil"
       
    14 	"math"
    13 	"mime"
    15 	"mime"
    14 	"net/http"
    16 	"net/http"
    15 	"net/url"
    17 	"net/url"
    16 	"strconv"
    18 	"strconv"
    17 	"strings"
    19 	"strings"
       
    20 	"sync"
    18 	"time"
    21 	"time"
    19 
    22 
    20 	"golang.org/x/net/context"
       
    21 	"golang.org/x/net/context/ctxhttp"
    23 	"golang.org/x/net/context/ctxhttp"
    22 )
    24 )
    23 
    25 
    24 // Token represents the credentials used to authorize
    26 // Token represents the credentials used to authorize
    25 // the requests to access protected resources on the OAuth 2.0
    27 // the requests to access protected resources on the OAuth 2.0
    59 type tokenJSON struct {
    61 type tokenJSON struct {
    60 	AccessToken  string         `json:"access_token"`
    62 	AccessToken  string         `json:"access_token"`
    61 	TokenType    string         `json:"token_type"`
    63 	TokenType    string         `json:"token_type"`
    62 	RefreshToken string         `json:"refresh_token"`
    64 	RefreshToken string         `json:"refresh_token"`
    63 	ExpiresIn    expirationTime `json:"expires_in"` // at least PayPal returns string, while most return number
    65 	ExpiresIn    expirationTime `json:"expires_in"` // at least PayPal returns string, while most return number
    64 	Expires      expirationTime `json:"expires"`    // broken Facebook spelling of expires_in
       
    65 }
    66 }
    66 
    67 
    67 func (e *tokenJSON) expiry() (t time.Time) {
    68 func (e *tokenJSON) expiry() (t time.Time) {
    68 	if v := e.ExpiresIn; v != 0 {
    69 	if v := e.ExpiresIn; v != 0 {
    69 		return time.Now().Add(time.Duration(v) * time.Second)
    70 		return time.Now().Add(time.Duration(v) * time.Second)
    70 	}
    71 	}
    71 	if v := e.Expires; v != 0 {
       
    72 		return time.Now().Add(time.Duration(v) * time.Second)
       
    73 	}
       
    74 	return
    72 	return
    75 }
    73 }
    76 
    74 
    77 type expirationTime int32
    75 type expirationTime int32
    78 
    76 
    79 func (e *expirationTime) UnmarshalJSON(b []byte) error {
    77 func (e *expirationTime) UnmarshalJSON(b []byte) error {
       
    78 	if len(b) == 0 || string(b) == "null" {
       
    79 		return nil
       
    80 	}
    80 	var n json.Number
    81 	var n json.Number
    81 	err := json.Unmarshal(b, &n)
    82 	err := json.Unmarshal(b, &n)
    82 	if err != nil {
    83 	if err != nil {
    83 		return err
    84 		return err
    84 	}
    85 	}
    85 	i, err := n.Int64()
    86 	i, err := n.Int64()
    86 	if err != nil {
    87 	if err != nil {
    87 		return err
    88 		return err
    88 	}
    89 	}
       
    90 	if i > math.MaxInt32 {
       
    91 		i = math.MaxInt32
       
    92 	}
    89 	*e = expirationTime(i)
    93 	*e = expirationTime(i)
    90 	return nil
    94 	return nil
    91 }
    95 }
    92 
    96 
    93 var brokenAuthHeaderProviders = []string{
    97 // RegisterBrokenAuthHeaderProvider previously added tokenURL to the list of providers treated as having broken Authorization-header support. It is now a no-op.
    94 	"https://accounts.google.com/",
    98 //
    95 	"https://api.codeswholesale.com/oauth/token",
    99 // Deprecated: this function no longer does anything. Caller code that
    96 	"https://api.dropbox.com/",
   100 // wants to avoid potential extra HTTP requests made during
    97 	"https://api.dropboxapi.com/",
   101 // auto-probing of the provider's auth style should set
    98 	"https://api.instagram.com/",
   102 // Endpoint.AuthStyle.
    99 	"https://api.netatmo.net/",
   103 func RegisterBrokenAuthHeaderProvider(tokenURL string) {}
   100 	"https://api.odnoklassniki.ru/",
   104 
   101 	"https://api.pushbullet.com/",
   105 // AuthStyle is a copy of the golang.org/x/oauth2 package's AuthStyle type.
   102 	"https://api.soundcloud.com/",
   106 type AuthStyle int
   103 	"https://api.twitch.tv/",
   107 
   104 	"https://id.twitch.tv/",
   108 const (
   105 	"https://app.box.com/",
   109 	AuthStyleUnknown  AuthStyle = 0
   106 	"https://api.box.com/",
   110 	AuthStyleInParams AuthStyle = 1
   107 	"https://connect.stripe.com/",
   111 	AuthStyleInHeader AuthStyle = 2
   108 	"https://login.mailchimp.com/",
   112 )
   109 	"https://login.microsoftonline.com/",
   113 
   110 	"https://login.salesforce.com/",
   114 // authStyleCache is the set of tokenURLs we've successfully used via
   111 	"https://login.windows.net",
   115 // RetrieveToken and which auth style we ended up using.
   112 	"https://login.live.com/",
   116 // It's called a cache, but it doesn't (yet?) shrink. It's expected that
   113 	"https://oauth.sandbox.trainingpeaks.com/",
   117 // the set of OAuth2 servers a program contacts over time is fixed and
   114 	"https://oauth.trainingpeaks.com/",
   118 // small.
   115 	"https://oauth.vk.com/",
   119 var authStyleCache struct {
   116 	"https://openapi.baidu.com/",
   120 	sync.Mutex
   117 	"https://slack.com/",
   121 	m map[string]AuthStyle // keyed by tokenURL
   118 	"https://test-sandbox.auth.corp.google.com",
   122 }
   119 	"https://test.salesforce.com/",
   123 
   120 	"https://user.gini.net/",
   124 // ResetAuthCache resets the global authentication style cache used
   121 	"https://www.douban.com/",
   125 // for AuthStyleUnknown token requests.
   122 	"https://www.googleapis.com/",
   126 func ResetAuthCache() {
   123 	"https://www.linkedin.com/",
   127 	authStyleCache.Lock()
   124 	"https://www.strava.com/oauth/",
   128 	defer authStyleCache.Unlock()
   125 	"https://www.wunderlist.com/oauth/",
   129 	authStyleCache.m = nil
   126 	"https://api.patreon.com/",
   130 }
   127 	"https://sandbox.codeswholesale.com/oauth/token",
   131 
   128 	"https://api.sipgate.com/v1/authorization/oauth",
   132 // lookupAuthStyle reports which auth style we last used with tokenURL
   129 	"https://api.medium.com/v1/tokens",
   133 // when calling RetrieveToken and whether we have ever done so.
   130 	"https://log.finalsurge.com/oauth/token",
   134 func lookupAuthStyle(tokenURL string) (style AuthStyle, ok bool) {
   131 	"https://multisport.todaysplan.com.au/rest/oauth/access_token",
   135 	authStyleCache.Lock()
   132 	"https://whats.todaysplan.com.au/rest/oauth/access_token",
   136 	defer authStyleCache.Unlock()
   133 	"https://stackoverflow.com/oauth/access_token",
   137 	style, ok = authStyleCache.m[tokenURL]
   134 	"https://account.health.nokia.com",
   138 	return
   135 }
   139 }
   136 
   140 
   137 // brokenAuthHeaderDomains lists broken providers that issue dynamic endpoints.
   141 // setAuthStyle adds an entry to authStyleCache, documented above.
   138 var brokenAuthHeaderDomains = []string{
   142 func setAuthStyle(tokenURL string, v AuthStyle) {
   139 	".auth0.com",
   143 	authStyleCache.Lock()
   140 	".force.com",
   144 	defer authStyleCache.Unlock()
   141 	".myshopify.com",
   145 	if authStyleCache.m == nil {
   142 	".okta.com",
   146 		authStyleCache.m = make(map[string]AuthStyle)
   143 	".oktapreview.com",
   147 	}
   144 }
   148 	authStyleCache.m[tokenURL] = v
   145 
   149 }
   146 func RegisterBrokenAuthHeaderProvider(tokenURL string) {
   150 
   147 	brokenAuthHeaderProviders = append(brokenAuthHeaderProviders, tokenURL)
   151 // newTokenRequest returns a new *http.Request to retrieve a new token
   148 }
   152 // from tokenURL using the provided clientID, clientSecret, and POST
   149 
   153 // body parameters.
   150 // providerAuthHeaderWorks reports whether the OAuth2 server identified by the tokenURL
   154 //
   151 // implements the OAuth2 spec correctly
   155 // authStyle selects where the clientID & clientSecret are sent:
   152 // See https://code.google.com/p/goauth2/issues/detail?id=31 for background.
   156 // AuthStyleInParams means to send them in the POST body (along with
   153 // In summary:
   157 // any values in v, which is cloned first); AuthStyleInHeader means to
   154 // - Reddit only accepts client secret in the Authorization header
   158 // send them in the Authorization header using HTTP Basic auth.
   155 // - Dropbox accepts either it in URL param or Auth header, but not both.
   159 func newTokenRequest(tokenURL, clientID, clientSecret string, v url.Values, authStyle AuthStyle) (*http.Request, error) {
   156 // - Google only accepts URL param (not spec compliant?), not Auth header
   160 	if authStyle == AuthStyleInParams {
   157 // - Stripe only accepts client secret in Auth header with Bearer method, not Basic
   161 		v = cloneURLValues(v)
   158 func providerAuthHeaderWorks(tokenURL string) bool {
       
   159 	for _, s := range brokenAuthHeaderProviders {
       
   160 		if strings.HasPrefix(tokenURL, s) {
       
   161 			// Some sites fail to implement the OAuth2 spec fully.
       
   162 			return false
       
   163 		}
       
   164 	}
       
   165 
       
   166 	if u, err := url.Parse(tokenURL); err == nil {
       
   167 		for _, s := range brokenAuthHeaderDomains {
       
   168 			if strings.HasSuffix(u.Host, s) {
       
   169 				return false
       
   170 			}
       
   171 		}
       
   172 	}
       
   173 
       
   174 	// Assume the provider implements the spec properly
       
   175 	// otherwise. We can add more exceptions as they're
       
   176 	// discovered. We will _not_ be adding configurable hooks
       
   177 	// to this package to let users select server bugs.
       
   178 	return true
       
   179 }
       
   180 
       
   181 func RetrieveToken(ctx context.Context, clientID, clientSecret, tokenURL string, v url.Values) (*Token, error) {
       
   182 	bustedAuth := !providerAuthHeaderWorks(tokenURL)
       
   183 	if bustedAuth {
       
   184 		if clientID != "" {
   162 		if clientID != "" {
   185 			v.Set("client_id", clientID)
   163 			v.Set("client_id", clientID)
   186 		}
   164 		}
   187 		if clientSecret != "" {
   165 		if clientSecret != "" {
   188 			v.Set("client_secret", clientSecret)
   166 			v.Set("client_secret", clientSecret)
   191 	req, err := http.NewRequest("POST", tokenURL, strings.NewReader(v.Encode()))
   169 	req, err := http.NewRequest("POST", tokenURL, strings.NewReader(v.Encode()))
   192 	if err != nil {
   170 	if err != nil {
   193 		return nil, err
   171 		return nil, err
   194 	}
   172 	}
   195 	req.Header.Set("Content-Type", "application/x-www-form-urlencoded")
   173 	req.Header.Set("Content-Type", "application/x-www-form-urlencoded")
   196 	if !bustedAuth {
   174 	if authStyle == AuthStyleInHeader {
   197 		req.SetBasicAuth(url.QueryEscape(clientID), url.QueryEscape(clientSecret))
   175 		req.SetBasicAuth(url.QueryEscape(clientID), url.QueryEscape(clientSecret))
   198 	}
   176 	}
       
   177 	return req, nil
       
   178 }
       
   179 
       
    // cloneURLValues returns a copy of v: a new map in which every key
    // maps to a newly allocated slice holding the same strings. Changes
    // made to the copy (for example with Set) are not visible through v.
    // A nil or empty v yields an empty, non-nil url.Values.
    180 func cloneURLValues(v url.Values) url.Values {
       
    181 	v2 := make(url.Values, len(v))
       
    182 	for k, vv := range v {
       
    // Appending to a nil slice allocates fresh backing storage, so the
    // copy never aliases the caller's slice.
    183 		v2[k] = append([]string(nil), vv...)
       
    184 	}
       
    185 	return v2
       
    186 }
       
   187 
       
    // RetrieveToken requests a token from tokenURL, sending clientID,
    // clientSecret and the form values v, and returns the *Token produced
    // by doTokenRoundTrip.
    //
    // authStyle selects how the client credentials are sent. When it is 0
    // (AuthStyleUnknown) the style is probed: a style already recorded for
    // tokenURL is reused if lookupAuthStyle finds one; otherwise
    // AuthStyleInHeader is tried first and, if that round trip returns any
    // error, the request is repeated with AuthStyleInParams. A style that
    // succeeds while probing is recorded with setAuthStyle for later calls.
    //
    // If the first request cannot be built, its error is returned
    // immediately. Otherwise the token and error of the last round trip are
    // returned; a non-nil token with an empty RefreshToken has it filled in
    // from v's "refresh_token" value.
    188 func RetrieveToken(ctx context.Context, clientID, clientSecret, tokenURL string, v url.Values, authStyle AuthStyle) (*Token, error) {
       
    189 	needsAuthStyleProbe := authStyle == 0
       
    190 	if needsAuthStyleProbe {
       
    191 		if style, ok := lookupAuthStyle(tokenURL); ok {
       
    192 			authStyle = style
       
    193 			needsAuthStyleProbe = false
       
    194 		} else {
       
    195 			authStyle = AuthStyleInHeader // the first way we'll try
       
    196 		}
       
    197 	}
       
    198 	req, err := newTokenRequest(tokenURL, clientID, clientSecret, v, authStyle)
       
    199 	if err != nil {
       
    200 		return nil, err
       
    201 	}
       
    202 	token, err := doTokenRoundTrip(ctx, req)
       
    203 	if err != nil && needsAuthStyleProbe {
       
    204 		// If we get an error, assume the server wants the
       
    205 		// clientID & clientSecret in a different form.
       
    206 		// See https://code.google.com/p/goauth2/issues/detail?id=31 for background.
       
    207 		// In summary:
       
    208 		// - Reddit only accepts client secret in the Authorization header
       
    209 		// - Dropbox accepts it in either the URL param or the Auth header, but not both.
       
    210 		// - Google only accepts URL param (not spec compliant?), not Auth header
       
    211 		// - Stripe only accepts client secret in Auth header with Bearer method, not Basic
       
    212 		//
       
    213 		// We used to maintain a big table in this code of all the sites and which way
       
    214 		// they went, but maintaining it didn't scale & got annoying.
       
    215 		// So just try both ways.
       
    216 		authStyle = AuthStyleInParams // the second way we'll try
       
    // NOTE(review): the error from newTokenRequest is discarded here. The
    // earlier call with the same tokenURL succeeded, so this presumably
    // cannot fail -- confirm no other failure mode exists.
    217 		req, _ = newTokenRequest(tokenURL, clientID, clientSecret, v, authStyle)
       
    218 		token, err = doTokenRoundTrip(ctx, req)
       
    219 	}
       
    // Only remember the style when it was discovered by probing in this call
    // and the round trip that used it succeeded.
    220 	if needsAuthStyleProbe && err == nil {
       
    221 		setAuthStyle(tokenURL, authStyle)
       
    222 	}
       
    223 	// Don't overwrite `RefreshToken` with an empty value
       
    224 	// if this was a token refreshing request.
       
    225 	if token != nil && token.RefreshToken == "" {
       
    226 		token.RefreshToken = v.Get("refresh_token")
       
    227 	}
       
    228 	return token, err
       
    229 }
       
   230 
       
   231 func doTokenRoundTrip(ctx context.Context, req *http.Request) (*Token, error) {
   199 	r, err := ctxhttp.Do(ctx, ContextClient(ctx), req)
   232 	r, err := ctxhttp.Do(ctx, ContextClient(ctx), req)
   200 	if err != nil {
   233 	if err != nil {
   201 		return nil, err
   234 		return nil, err
   202 	}
   235 	}
   203 	defer r.Body.Close()
       
   204 	body, err := ioutil.ReadAll(io.LimitReader(r.Body, 1<<20))
   236 	body, err := ioutil.ReadAll(io.LimitReader(r.Body, 1<<20))
       
   237 	r.Body.Close()
   205 	if err != nil {
   238 	if err != nil {
   206 		return nil, fmt.Errorf("oauth2: cannot fetch token: %v", err)
   239 		return nil, fmt.Errorf("oauth2: cannot fetch token: %v", err)
   207 	}
   240 	}
   208 	if code := r.StatusCode; code < 200 || code > 299 {
   241 	if code := r.StatusCode; code < 200 || code > 299 {
   209 		return nil, &RetrieveError{
   242 		return nil, &RetrieveError{
   225 			TokenType:    vals.Get("token_type"),
   258 			TokenType:    vals.Get("token_type"),
   226 			RefreshToken: vals.Get("refresh_token"),
   259 			RefreshToken: vals.Get("refresh_token"),
   227 			Raw:          vals,
   260 			Raw:          vals,
   228 		}
   261 		}
   229 		e := vals.Get("expires_in")
   262 		e := vals.Get("expires_in")
   230 		if e == "" {
       
   231 			// TODO(jbd): Facebook's OAuth2 implementation is broken and
       
   232 			// returns expires_in field in expires. Remove the fallback to expires,
       
   233 			// when Facebook fixes their implementation.
       
   234 			e = vals.Get("expires")
       
   235 		}
       
   236 		expires, _ := strconv.Atoi(e)
   263 		expires, _ := strconv.Atoi(e)
   237 		if expires != 0 {
   264 		if expires != 0 {
   238 			token.Expiry = time.Now().Add(time.Duration(expires) * time.Second)
   265 			token.Expiry = time.Now().Add(time.Duration(expires) * time.Second)
   239 		}
   266 		}
   240 	default:
   267 	default:
   249 			Expiry:       tj.expiry(),
   276 			Expiry:       tj.expiry(),
   250 			Raw:          make(map[string]interface{}),
   277 			Raw:          make(map[string]interface{}),
   251 		}
   278 		}
   252 		json.Unmarshal(body, &token.Raw) // no error checks for optional fields
   279 		json.Unmarshal(body, &token.Raw) // no error checks for optional fields
   253 	}
   280 	}
   254 	// Don't overwrite `RefreshToken` with an empty value
       
   255 	// if this was a token refreshing request.
       
   256 	if token.RefreshToken == "" {
       
   257 		token.RefreshToken = v.Get("refresh_token")
       
   258 	}
       
   259 	if token.AccessToken == "" {
   281 	if token.AccessToken == "" {
   260 		return token, errors.New("oauth2: server response missing access_token")
   282 		return nil, errors.New("oauth2: server response missing access_token")
   261 	}
   283 	}
   262 	return token, nil
   284 	return token, nil
   263 }
   285 }
   264 
   286 
   265 type RetrieveError struct {
   287 type RetrieveError struct {