// Source path: root/vendor/github.com/Azure/go-autorest/autorest/azure/token.go
// Source blob: cfcd030114c63d5e3a742bd1be51fb21ca97fedc

package azure

import (
    "crypto/rand"
    "crypto/rsa"
    "crypto/sha1"
    "crypto/x509"
    "encoding/base64"
    "fmt"
    "net/http"
    "net/url"
    "strconv"
    "time"

    "github.com/Azure/go-autorest/autorest"
    "github.com/dgrijalva/jwt-go"
)

const (
    // defaultRefresh is the refresh window assigned to newly constructed
    // ServicePrincipalTokens.
    defaultRefresh = 5 * time.Minute
    // tokenBaseDate is the RFC 3339 instant that Token.ExpiresOn offsets are
    // measured from.
    tokenBaseDate = "1970-01-01T00:00:00Z"

    // OAuthGrantTypeDeviceCode is the "grant_type" value used in the device flow.
    OAuthGrantTypeDeviceCode = "device_code"

    // OAuthGrantTypeClientCredentials is the "grant_type" value used in credential flows.
    OAuthGrantTypeClientCredentials = "client_credentials"

    // OAuthGrantTypeRefreshToken is the "grant_type" value used in refresh token flows.
    OAuthGrantTypeRefreshToken = "refresh_token"
)

// expirationBase is tokenBaseDate parsed into a time.Time. It is assigned once, by init.
var expirationBase time.Time

func init() {
    // tokenBaseDate is a well-formed constant, so the parse error is deliberately discarded.
    base, _ := time.Parse(time.RFC3339, tokenBaseDate)
    expirationBase = base
}

// TokenRefreshCallback is the type representing callbacks that will be called after
// a successful token refresh. InvokeRefreshCallbacks runs the registered callbacks in
// order and stops at the first one that returns a non-nil error, returning that error
// wrapped.
type TokenRefreshCallback func(Token) error

// Token encapsulates the access token used to authorize Azure requests.
// Every field is a string decoded from the JSON key named in its struct tag.
type Token struct {
    // AccessToken is the value WithAuthorization sends as the bearer credential.
    AccessToken  string `json:"access_token"`
    // RefreshToken, when non-empty, makes a ServicePrincipalToken refresh with the
    // refresh_token grant instead of the client_credentials grant.
    RefreshToken string `json:"refresh_token"`

    ExpiresIn string `json:"expires_in"`
    // ExpiresOn is parsed by Expires as a whole number of seconds after
    // 1970-01-01T00:00:00Z.
    ExpiresOn string `json:"expires_on"`
    NotBefore string `json:"not_before"`

    Resource string `json:"resource"`
    Type     string `json:"token_type"`
}

// Expires reports, in UTC, the instant at which the Token expires.
//
// ExpiresOn is read as a whole number of seconds counted from expirationBase.
// When it cannot be parsed as an integer, an offset of -3600 seconds is used
// instead, which yields an instant one hour before expirationBase.
func (t Token) Expires() time.Time {
    seconds, parseErr := strconv.Atoi(t.ExpiresOn)
    if parseErr != nil {
        seconds = -3600
    }
    offset := time.Duration(seconds) * time.Second
    return expirationBase.Add(offset).UTC()
}

// IsExpired reports whether the Token has already expired, i.e. whether it
// will have expired zero time from now.
func (t Token) IsExpired() bool {
    expired := t.WillExpireIn(0)
    return expired
}

// WillExpireIn reports whether the Token will have expired once the duration d
// has elapsed from now. It returns true when the expiry instant is at or before
// now+d, and false when the expiry instant is strictly later.
func (t Token) WillExpireIn(d time.Duration) bool {
    expiry := t.Expires()
    horizon := time.Now().Add(d)
    return !expiry.After(horizon)
}

// WithAuthorization returns a PrepareDecorator that adds an HTTP Authorization header whose
// value is "Bearer " followed by the AccessToken of the Token.
//
// AccessToken is read each time a request is prepared, not when the decorator is built.
func (t *Token) WithAuthorization() autorest.PrepareDecorator {
    return func(next autorest.Preparer) autorest.Preparer {
        prepare := func(req *http.Request) (*http.Request, error) {
            bearer := autorest.WithBearerAuthorization(t.AccessToken)
            return bearer(next).Prepare(req)
        }
        return autorest.PreparerFunc(prepare)
    }
}

// ServicePrincipalNoSecret represents a secret type that contains no secret,
// meaning it is not valid for fetching a fresh token. It is the secret that
// NewServicePrincipalTokenFromManualToken installs on the tokens it creates.
type ServicePrincipalNoSecret struct {
}

// SetAuthenticationValues implements ServicePrincipalSecret. Because a
// ServicePrincipalNoSecret carries no secret material, it never touches v and
// always returns an error.
func (noSecret *ServicePrincipalNoSecret) SetAuthenticationValues(spt *ServicePrincipalToken, v *url.Values) error {
    const reason = "Manually created ServicePrincipalToken does not contain secret material to retrieve a new access token"
    return fmt.Errorf(reason)
}

// ServicePrincipalSecret is an interface that allows various secret mechanisms to fill the form
// that is submitted when acquiring an OAuth token.
type ServicePrincipalSecret interface {
    // SetAuthenticationValues adds the secret-specific fields to values for a token
    // request made on behalf of spt. A non-nil error aborts the token refresh that
    // requested the values.
    SetAuthenticationValues(spt *ServicePrincipalToken, values *url.Values) error
}

// ServicePrincipalTokenSecret implements ServicePrincipalSecret for client_secret type authorization.
type ServicePrincipalTokenSecret struct {
    // ClientSecret is submitted verbatim as the "client_secret" form value.
    ClientSecret string
}

// SetAuthenticationValues implements ServicePrincipalSecret. It stores the
// client secret in v under the "client_secret" key and always returns nil.
func (tokenSecret *ServicePrincipalTokenSecret) SetAuthenticationValues(spt *ServicePrincipalToken, v *url.Values) error {
    clientSecret := tokenSecret.ClientSecret
    v.Set("client_secret", clientSecret)
    return nil
}

// ServicePrincipalCertificateSecret implements ServicePrincipalSecret for generic RSA cert auth with signed JWTs.
type ServicePrincipalCertificateSecret struct {
    // Certificate supplies the raw bytes whose SHA-1 digest SignJwt places in the
    // JWT "x5t" header.
    Certificate *x509.Certificate
    // PrivateKey is the key SignJwt signs the JWT with.
    PrivateKey  *rsa.PrivateKey
}

// SignJwt returns a JWT for spt, signed with the secret's private key using RS256.
//
// The JWT header carries "x5t": the URL-safe base64 encoding of the SHA-1 digest of
// the certificate's raw bytes. The claims are:
//
//   - aud: the token endpoint of spt's OAuth configuration
//   - iss, sub: spt's client ID
//   - jti: 20 random bytes, URL-safe base64 encoded
//   - nbf: the current Unix time
//   - exp: 24 hours after nbf
//
// An error is returned when spt is nil, when the secret has no certificate or no
// private key, when random bytes cannot be read, or when signing fails.
func (secret *ServicePrincipalCertificateSecret) SignJwt(spt *ServicePrincipalToken) (string, error) {
    // Guard the pointers dereferenced below so that a misconfigured secret surfaces
    // as an error from the refresh path instead of a nil-pointer panic.
    if secret == nil || secret.Certificate == nil || secret.PrivateKey == nil {
        return "", fmt.Errorf("azure: signing a JWT requires both a certificate and a private key")
    }
    if spt == nil {
        return "", fmt.Errorf("azure: signing a JWT requires a ServicePrincipalToken")
    }

    digest := sha1.Sum(secret.Certificate.Raw)
    thumbprint := base64.URLEncoding.EncodeToString(digest[:])

    // The jti (JWT ID) claim provides a unique identifier for the JWT.
    jti := make([]byte, 20)
    if _, err := rand.Read(jti); err != nil {
        return "", err
    }

    // Take a single clock reading so that exp is exactly 24 hours after nbf.
    now := time.Now()

    token := jwt.New(jwt.SigningMethodRS256)
    token.Header["x5t"] = thumbprint
    token.Claims = jwt.MapClaims{
        "aud": spt.oauthConfig.TokenEndpoint.String(),
        "iss": spt.clientID,
        "sub": spt.clientID,
        "jti": base64.URLEncoding.EncodeToString(jti),
        "nbf": now.Unix(),
        "exp": now.Add(24 * time.Hour).Unix(),
    }

    return token.SignedString(secret.PrivateKey)
}

// SetAuthenticationValues implements ServicePrincipalSecret. It signs a JWT for
// spt and stores it in v as "client_assertion", together with the matching
// "client_assertion_type". When signing fails, v is left untouched and the
// signing error is returned.
func (secret *ServicePrincipalCertificateSecret) SetAuthenticationValues(spt *ServicePrincipalToken, v *url.Values) error {
    assertion, signErr := secret.SignJwt(spt)
    if signErr == nil {
        v.Set("client_assertion", assertion)
        v.Set("client_assertion_type", "urn:ietf:params:oauth:client-assertion-type:jwt-bearer")
    }
    return signErr
}

// ServicePrincipalToken encapsulates a Token created for a Service Principal.
type ServicePrincipalToken struct {
    // Token is the most recently obtained (or manually supplied) token; a
    // successful refresh replaces it.
    Token

    // secret fills in the authentication form values when refreshing with the
    // client_credentials grant.
    secret        ServicePrincipalSecret
    // oauthConfig supplies the token endpoint that refresh requests are sent to.
    oauthConfig   OAuthConfig
    // clientID is sent as the "client_id" form value.
    clientID      string
    // resource is the resource that Refresh requests a token for.
    resource      string
    // autoRefresh controls whether WithAuthorization calls EnsureFresh before
    // authorizing a request.
    autoRefresh   bool
    // refreshWithin is the window EnsureFresh passes to WillExpireIn.
    refreshWithin time.Duration
    // sender is used to send the refresh request.
    sender        autorest.Sender

    // refreshCallbacks are run by InvokeRefreshCallbacks after a successful refresh.
    refreshCallbacks []TokenRefreshCallback
}

// NewServicePrincipalTokenWithSecret creates a ServicePrincipalToken that authenticates
// with the supplied ServicePrincipalSecret implementation.
//
// The returned token has automatic refresh enabled, uses defaultRefresh as its refresh
// window and sends refresh requests through a zero-value http.Client. The returned
// error is always nil.
func NewServicePrincipalTokenWithSecret(oauthConfig OAuthConfig, id string, resource string, secret ServicePrincipalSecret, callbacks ...TokenRefreshCallback) (*ServicePrincipalToken, error) {
    spt := new(ServicePrincipalToken)

    // Identity and target of the token.
    spt.oauthConfig = oauthConfig
    spt.secret = secret
    spt.clientID = id
    spt.resource = resource

    // Refresh behaviour defaults.
    spt.autoRefresh = true
    spt.refreshWithin = defaultRefresh
    spt.sender = &http.Client{}
    spt.refreshCallbacks = callbacks

    return spt, nil
}

// NewServicePrincipalTokenFromManualToken creates a ServicePrincipalToken that starts out
// holding the supplied token. Its secret is a ServicePrincipalNoSecret, which returns an
// error whenever it is asked for authentication values.
func NewServicePrincipalTokenFromManualToken(oauthConfig OAuthConfig, clientID string, resource string, token Token, callbacks ...TokenRefreshCallback) (*ServicePrincipalToken, error) {
    noSecret := &ServicePrincipalNoSecret{}

    spt, err := NewServicePrincipalTokenWithSecret(oauthConfig, clientID, resource, noSecret, callbacks...)
    if err != nil {
        return nil, err
    }

    spt.Token = token
    return spt, nil
}

// NewServicePrincipalToken creates a ServicePrincipalToken from the supplied Service Principal
// client ID and client secret, scoped to the named resource.
func NewServicePrincipalToken(oauthConfig OAuthConfig, clientID string, secret string, resource string, callbacks ...TokenRefreshCallback) (*ServicePrincipalToken, error) {
    clientSecret := &ServicePrincipalTokenSecret{ClientSecret: secret}
    return NewServicePrincipalTokenWithSecret(oauthConfig, clientID, resource, clientSecret, callbacks...)
}

// NewServicePrincipalTokenFromCertificate creates a ServicePrincipalToken that authenticates
// with a JWT signed by the supplied RSA private key and identified by the supplied
// certificate, scoped to the named resource.
func NewServicePrincipalTokenFromCertificate(oauthConfig OAuthConfig, clientID string, certificate *x509.Certificate, privateKey *rsa.PrivateKey, resource string, callbacks ...TokenRefreshCallback) (*ServicePrincipalToken, error) {
    certSecret := &ServicePrincipalCertificateSecret{
        Certificate: certificate,
        PrivateKey:  privateKey,
    }
    return NewServicePrincipalTokenWithSecret(oauthConfig, clientID, resource, certSecret, callbacks...)
}

// EnsureFresh refreshes the token when it will expire within the refresh window
// (as set by SetRefreshWithin) and otherwise does nothing. It returns the error
// from Refresh, if any.
func (spt *ServicePrincipalToken) EnsureFresh() error {
    if !spt.WillExpireIn(spt.refreshWithin) {
        // Still valid beyond the refresh window.
        return nil
    }
    return spt.Refresh()
}

// InvokeRefreshCallbacks calls, in registration order, every TokenRefreshCallback that was
// added to the SPT during initialization, passing each of them token.
//
// It stops at the first callback that returns a non-nil error and returns that error
// wrapped; the remaining callbacks are not called. It returns nil when every callback
// succeeds or when no callbacks are registered.
func (spt *ServicePrincipalToken) InvokeRefreshCallbacks(token Token) error {
    // Ranging over a nil slice is a no-op, so no separate nil check is needed.
    for _, callback := range spt.refreshCallbacks {
        // Hand the callback the token the caller supplied rather than whatever
        // spt.Token happens to hold.
        if err := callback(token); err != nil {
            return autorest.NewErrorWithError(err,
                "azure.ServicePrincipalToken", "InvokeRefreshCallbacks", nil, "A TokenRefreshCallback handler returned an error")
        }
    }
    return nil
}

// Refresh obtains a fresh token for the Service Principal, scoped to the resource
// the ServicePrincipalToken was created with.
func (spt *ServicePrincipalToken) Refresh() error {
    target := spt.resource
    return spt.refreshInternal(target)
}

// RefreshExchange obtains a fresh token just like Refresh, but requests it for the
// given resource instead of the one the ServicePrincipalToken was created with.
func (spt *ServicePrincipalToken) RefreshExchange(resource string) error {
    err := spt.refreshInternal(resource)
    return err
}

// refreshInternal requests a new token for resource from the configured token endpoint
// and, on success, stores it in spt.Token and runs the refresh callbacks.
//
// When spt already holds a refresh token, the refresh_token grant is used. Otherwise the
// client_credentials grant is used and spt.secret supplies the authentication values.
//
// It returns an error when the secret cannot provide authentication values, when the
// request cannot be prepared or sent, when the response is not 200 OK or cannot be
// decoded as a Token, or when a refresh callback fails. spt.Token is only replaced
// after the response has been decoded successfully.
func (spt *ServicePrincipalToken) refreshInternal(resource string) error {
    v := url.Values{}
    v.Set("client_id", spt.clientID)
    v.Set("resource", resource)

    if spt.RefreshToken != "" {
        v.Set("grant_type", OAuthGrantTypeRefreshToken)
        v.Set("refresh_token", spt.RefreshToken)
    } else {
        v.Set("grant_type", OAuthGrantTypeClientCredentials)
        if err := spt.secret.SetAuthenticationValues(spt, &v); err != nil {
            return err
        }
    }

    req, err := autorest.Prepare(&http.Request{},
        autorest.AsPost(),
        autorest.AsFormURLEncoded(),
        autorest.WithBaseURL(spt.oauthConfig.TokenEndpoint.String()),
        autorest.WithFormData(v))
    if err != nil {
        // Do not send a request that could not be fully prepared.
        return autorest.NewErrorWithError(err,
            "azure.ServicePrincipalToken", "Refresh", nil, "Failure preparing request for Service Principal %s",
            spt.clientID)
    }

    resp, err := autorest.SendWithSender(spt.sender, req)
    if err != nil {
        return autorest.NewErrorWithError(err,
            "azure.ServicePrincipalToken", "Refresh", resp, "Failure sending request for Service Principal %s",
            spt.clientID)
    }

    var newToken Token
    err = autorest.Respond(resp,
        autorest.WithErrorUnlessStatusCode(http.StatusOK),
        autorest.ByUnmarshallingJSON(&newToken),
        autorest.ByClosing())
    if err != nil {
        return autorest.NewErrorWithError(err,
            "azure.ServicePrincipalToken", "Refresh", resp, "Failure handling response to Service Principal %s request",
            spt.clientID)
    }

    spt.Token = newToken

    // Any callback error is already wrapped inside InvokeRefreshCallbacks.
    return spt.InvokeRefreshCallbacks(newToken)
}

// SetAutoRefresh enables or disables automatic refreshing of stale tokens.
// The flag is consulted by WithAuthorization each time a request is prepared.
func (spt *ServicePrincipalToken) SetAutoRefresh(autoRefresh bool) {
    spt.autoRefresh = autoRefresh
}

// SetRefreshWithin sets the refresh window: EnsureFresh refreshes the token when it
// will expire within d from now.
func (spt *ServicePrincipalToken) SetRefreshWithin(d time.Duration) {
    spt.refreshWithin = d
}

// SetSender sets the autorest.Sender used when obtaining the Service Principal token. An
// undecorated http.Client is used by default. The sender is not validated here; it is
// used as-is by the next refresh.
func (spt *ServicePrincipalToken) SetSender(s autorest.Sender) {
    spt.sender = s
}

// WithAuthorization returns a PrepareDecorator that adds an HTTP Authorization header whose
// value is "Bearer " followed by the AccessToken of the ServicePrincipalToken.
//
// By default, the token is refreshed automatically when it is about to expire (as
// determined by the interval set with SetRefreshWithin). Use SetAutoRefresh to enable or
// disable automatic refreshing. When a refresh fails, the request is returned unmodified
// together with the wrapped refresh error.
func (spt *ServicePrincipalToken) WithAuthorization() autorest.PrepareDecorator {
    return func(next autorest.Preparer) autorest.Preparer {
        prepare := func(req *http.Request) (*http.Request, error) {
            if spt.autoRefresh {
                if refreshErr := spt.EnsureFresh(); refreshErr != nil {
                    return req, autorest.NewErrorWithError(refreshErr,
                        "azure.ServicePrincipalToken", "WithAuthorization", nil, "Failed to refresh Service Principal Token for request to %s",
                        req.URL)
                }
            }
            // Read AccessToken only after any refresh so the header carries the new token.
            authorize := autorest.WithBearerAuthorization(spt.AccessToken)
            return authorize(next).Prepare(req)
        }
        return autorest.PreparerFunc(prepare)
    }
}