From c4a79484510551692033f11669eb85075fe1abef Mon Sep 17 00:00:00 2001 From: Nilesh Choudhary <107404295+4gust@users.noreply.github.com> Date: Tue, 28 Jan 2025 18:52:05 +0000 Subject: [PATCH 1/2] Fix Bug: Prevent Empty Region in WithAzureRegion from Overriding MSAL_FORCE_REGION (#545) * Fixed a bug where if empty region is passed in WithAzureRegion it would override the MSAL_FORCE_REGION * Updated the first tests * Removed dead code. * Update confidential_test.go * Cleaned up test --- apps/confidential/confidential.go | 4 +- apps/confidential/confidential_test.go | 95 +++++++++++++++----------- 2 files changed, 58 insertions(+), 41 deletions(-) diff --git a/apps/confidential/confidential.go b/apps/confidential/confidential.go index 5b375794..22c17d20 100644 --- a/apps/confidential/confidential.go +++ b/apps/confidential/confidential.go @@ -305,7 +305,9 @@ func WithInstanceDiscovery(enabled bool) Option { // If an invalid region name is provided, the non-regional endpoint MIGHT be used or the token request MIGHT fail. func WithAzureRegion(val string) Option { return func(o *clientOptions) { - o.azureRegion = val + if val != "" { + o.azureRegion = val + } } } diff --git a/apps/confidential/confidential_test.go b/apps/confidential/confidential_test.go index 1036c4f7..8d6d61a9 100644 --- a/apps/confidential/confidential_test.go +++ b/apps/confidential/confidential_test.go @@ -195,50 +195,65 @@ func TestRegionAutoEnable_EmptyRegion_EnvRegion(t *testing.T) { } } -func TestRegionAutoEnable_SpecifiedRegion_EnvRegion(t *testing.T) { - cred, err := NewCredFromSecret(fakeSecret) - if err != nil { - t.Fatal(err) - } - - envRegion := "envRegion" - err = os.Setenv("MSAL_FORCE_REGION", envRegion) - if err != nil { - t.Fatal(err) - } - defer os.Unsetenv("MSAL_FORCE_REGION") - - lmo := "login.microsoftonline.com" - tenant := "tenant" - mockClient := mock.Client{} - testRegion := "region" - client, err := New(fmt.Sprintf(authorityFmt, lmo, tenant), fakeClientID, cred, WithHTTPClient(&mockClient), WithAzureRegion(testRegion)) - if err != nil { - t.Fatal(err) - } - - if client.base.AuthParams.AuthorityInfo.Region != testRegion { - t.Fatalf("wanted %q, got %q", testRegion, client.base.AuthParams.AuthorityInfo.Region) +func TestRegionAutoEnable_SpecifiedEmptyRegion_EnvRegion(t *testing.T) { + tests := []struct { + name string + envRegion string + region string + resultRegion string + }{ + { + name: "Region is empty, envRegion is set", + envRegion: "region", + region: "", + resultRegion: "region", + }, + { + name: "Region is set, envRegion is set", + envRegion: "region", + region: "setRegion", + resultRegion: "setRegion", + }, + { + name: "Region is set, envRegion is empty", + envRegion: "", + region: "setRegion", + resultRegion: "setRegion", + }, + { + name: "Disable region is set, envRegion is set", + envRegion: "region", + region: "DisableMsalForceRegion", + resultRegion: "", + }, } -} -func TestRegionAutoEnable_DisableMsalForceRegion(t *testing.T) { - cred, err := NewCredFromSecret(fakeSecret) - if err != nil { - t.Fatal(err) - } + for _, test := range tests { + t.Run(test.name, func(t *testing.T) { + cred, err := NewCredFromSecret(fakeSecret) + if err != nil { + t.Fatal(err) + } + if test.envRegion != "" { + t.Setenv("MSAL_FORCE_REGION", test.envRegion) + } + lmo := "login.microsoftonline.com" + tenant := "tenant" + mockClient := mock.Client{} - lmo := "login.microsoftonline.com" - tenant := "tenant" - mockClient := mock.Client{} - testRegion := "DisableMsalForceRegion" - client, err := New(fmt.Sprintf(authorityFmt, lmo, tenant), fakeClientID, cred, WithHTTPClient(&mockClient), WithAzureRegion(testRegion)) - if err != nil { - t.Fatal(err) - } + client, err := New(fmt.Sprintf(authorityFmt, lmo, tenant), fakeClientID, cred, WithHTTPClient(&mockClient), WithAzureRegion(test.region)) + if err != nil { + t.Fatal(err) + } - if client.base.AuthParams.AuthorityInfo.Region != "" { - t.Fatalf("wanted empty, got %q", client.base.AuthParams.AuthorityInfo.Region) + if test.resultRegion == "DisableMsalForceRegion" { + if client.base.AuthParams.AuthorityInfo.Region != "" { + t.Fatalf("wanted %q, got %q", test.resultRegion, client.base.AuthParams.AuthorityInfo.Region) + } + } else if client.base.AuthParams.AuthorityInfo.Region != test.resultRegion { + t.Fatalf("wanted %q, got %q", test.resultRegion, client.base.AuthParams.AuthorityInfo.Region) + } + }) } } From e6d9244268108fbdb46b86c2a3b223c7102615b6 Mon Sep 17 00:00:00 2001 From: Andrew O Hart Date: Fri, 14 Feb 2025 15:28:34 +0000 Subject: [PATCH 2/2] Add Managed Identity Support (#552) * Added Managed Identity support for multiple sources (IMDS, App Service, CloudShell, AzureML, Service Fabric, Azure Arc) * Updated tests * Updated documentation * Added new Managed Identity client that currently supports cache and retry policies --- .github/workflows/go.yml | 2 +- README.md | 32 + apps/confidential/confidential_test.go | 57 +- apps/errors/errors.go | 9 + apps/internal/base/base.go | 5 +- apps/internal/base/base_test.go | 12 +- .../base/{internal => }/storage/items.go | 3 +- .../base/{internal => }/storage/items_test.go | 2 +- .../storage/partitioned_storage.go | 2 +- .../storage/partitioned_storage_test.go | 5 +- .../base/{internal => }/storage/storage.go | 7 +- .../{internal => }/storage/storage_test.go | 2 +- .../testdata/test_serialized_cache.json | 0 .../storage/testdata/v1.0_cache.json | 0 .../storage/testdata/v1.0_v1.1_cache.json | 0 apps/internal/local/server.go | 3 +- apps/internal/mock/mock.go | 14 + apps/internal/oauth/oauth.go | 9 +- .../oauth/ops/accesstokens/accesstokens.go | 12 +- .../ops/accesstokens/accesstokens_test.go | 150 ++- .../internal/oauth/ops/accesstokens/tokens.go | 65 +- apps/internal/oauth/ops/internal/comm/comm.go | 4 +- apps/managedidentity/azure_ml.go | 28 + apps/managedidentity/cloud_shell.go | 37 + apps/managedidentity/managedidentity.go | 681 ++++++++++ apps/managedidentity/managedidentity_test.go | 1133 +++++++++++++++++ apps/managedidentity/servicefabric.go | 25 + apps/managedidentity/servicefabric_test.go | 116 ++ apps/public/public_test.go | 8 +- apps/tests/benchmarks/confidential.go | 5 +- apps/tests/devapps/main.go | 2 +- .../docs/msi_manual_testing.md | 3 + apps/tests/performance/performance_test.go | 3 +- docs/managedidentity_public_api.md | 205 +++ 34 files changed, 2576 insertions(+), 65 deletions(-) rename apps/internal/base/{internal => }/storage/items.go (98%) rename apps/internal/base/{internal => }/storage/items_test.go (99%) rename apps/internal/base/{internal => }/storage/partitioned_storage.go (99%) rename apps/internal/base/{internal => }/storage/partitioned_storage_test.go (98%) rename apps/internal/base/{internal => }/storage/storage.go (99%) rename apps/internal/base/{internal => }/storage/storage_test.go (99%) rename apps/internal/base/{internal => }/storage/testdata/test_serialized_cache.json (100%) rename apps/internal/base/{internal => }/storage/testdata/v1.0_cache.json (100%) rename apps/internal/base/{internal => }/storage/testdata/v1.0_v1.1_cache.json (100%) create mode 100644 apps/managedidentity/azure_ml.go create mode 100644 apps/managedidentity/cloud_shell.go create mode 100644 apps/managedidentity/managedidentity.go create mode 100644 apps/managedidentity/managedidentity_test.go create mode 100644 apps/managedidentity/servicefabric.go create mode 100644 apps/managedidentity/servicefabric_test.go create mode 100644 apps/tests/devapps/managedidentity/docs/msi_manual_testing.md create mode 100644 docs/managedidentity_public_api.md diff --git a/.github/workflows/go.yml b/.github/workflows/go.yml index 695517eb..463360cd 100644 --- a/.github/workflows/go.yml +++ b/.github/workflows/go.yml @@ -39,7 +39,7 @@ jobs: run: go build ./apps/... - name: Unit Tests - run: go test -race -short ./apps/cache/... ./apps/confidential/... ./apps/public/... ./apps/internal/... + run: go test -race -short ./apps/cache/... ./apps/confidential/... ./apps/public/... ./apps/internal/... ./apps/managedidentity/... # Intergration tests runs on ADO # - name: Integration Tests # run: go test -race ./apps/tests/integration/... diff --git a/README.md b/README.md index b90801c6..3d6ccce2 100644 --- a/README.md +++ b/README.md @@ -50,6 +50,28 @@ Acquiring tokens with MSAL Go follows this general pattern. There might be some } confidentialClient, err := confidential.New("https://login.microsoftonline.com/your_tenant", "client_id", cred) ``` + * Initializing a Managed Identity client for SystemAssigned: + + ```go + import mi "github.com/AzureAD/microsoft-authentication-library-for-go/apps/managedidentity" + + // Managed identity client have a type of ID required, SystemAssigned or UserAssigned + miSystemAssigned, err := mi.New(mi.SystemAssigned()) + if err != nil { + // TODO: handle error + } + ``` + * Initializing a Managed Identity client for UserAssigned: + + ```go + import mi "github.com/AzureAD/microsoft-authentication-library-for-go/apps/managedidentity" + + // Managed identity client have a type of ID required, SystemAssigned or UserAssigned + miSystemAssigned, err := mi.New(mi.UserAssignedClientID("YOUR_CLIENT_ID")) + if err != nil { + // TODO: handle error + } + ``` 1. Call `AcquireTokenSilent()` to look for a cached token. If `AcquireTokenSilent()` returns an error, call another `AcquireToken...` method to authenticate. @@ -96,6 +118,16 @@ Acquiring tokens with MSAL Go follows this general pattern. There might be some accessToken := result.AccessToken ``` + * ManagedIdentity clietn can simply call `AcquireToken()`: + ```go + resource := "" + result, err := miSystemAssigned.AcquireToken(context.TODO(), resource) + if err != nil { + // TODO: handle error + } + accessToken := result.AccessToken + ``` + ## Community Help and Support We use [Stack Overflow](http://stackoverflow.com/questions/tagged/msal) to work with the community on supporting Azure Active Directory and its SDKs, including this one! We highly recommend you ask your questions on Stack Overflow (we're all on there!) Also browse existing issues to see if someone has had your question before. Please use the "msal" tag when asking your questions. diff --git a/apps/confidential/confidential_test.go b/apps/confidential/confidential_test.go index 8d6d61a9..77b2fcf6 100644 --- a/apps/confidential/confidential_test.go +++ b/apps/confidential/confidential_test.go @@ -10,7 +10,6 @@ import ( "crypto/x509" "encoding/base64" "encoding/json" - "errors" "fmt" "io" "net/http" @@ -25,6 +24,7 @@ import ( "github.com/kylelemons/godebug/pretty" "github.com/AzureAD/microsoft-authentication-library-for-go/apps/cache" + "github.com/AzureAD/microsoft-authentication-library-for-go/apps/errors" "github.com/AzureAD/microsoft-authentication-library-for-go/apps/internal/exported" internalTime "github.com/AzureAD/microsoft-authentication-library-for-go/apps/internal/json/types/time" "github.com/AzureAD/microsoft-authentication-library-for-go/apps/internal/mock" @@ -35,6 +35,7 @@ import ( // errorClient is an HTTP client for tests that should fail when confidential.Client sends a request type errorClient struct{} +type contextKey struct{} func (*errorClient) Do(req *http.Request) (*http.Response, error) { return nil, fmt.Errorf("expected no requests but received one for %s", req.URL.String()) @@ -138,7 +139,7 @@ func TestAcquireTokenByCredential(t *testing.T) { } client, err := fakeClient(accesstokens.TokenResponse{ AccessToken: token, - ExpiresOn: internalTime.DurationTime{T: time.Now().Add(1 * time.Hour)}, + ExpiresOn: time.Now().Add(1 * time.Hour), ExtExpiresOn: internalTime.DurationTime{T: time.Now().Add(1 * time.Hour)}, GrantedScopes: accesstokens.Scopes{Slice: tokenScope}, TokenType: "Bearer", @@ -305,7 +306,7 @@ func TestAcquireTokenOnBehalfOf(t *testing.T) { func TestAcquireTokenByAssertionCallback(t *testing.T) { calls := 0 - key := struct{}{} + key := contextKey{} ctx := context.WithValue(context.Background(), key, true) getAssertion := func(c context.Context, o AssertionRequestOptions) (string, error) { if v := c.Value(key); v == nil || !v.(bool) { @@ -358,7 +359,7 @@ func TestAcquireTokenByAuthCode(t *testing.T) { tr := accesstokens.TokenResponse{ AccessToken: token, RefreshToken: refresh, - ExpiresOn: internalTime.DurationTime{T: time.Now().Add(1 * time.Hour)}, + ExpiresOn: time.Now().Add(1 * time.Hour), ExtExpiresOn: internalTime.DurationTime{T: time.Now().Add(1 * time.Hour)}, GrantedScopes: accesstokens.Scopes{Slice: tokenScope}, IDToken: accesstokens.IDToken{ @@ -427,6 +428,40 @@ func TestAcquireTokenByAuthCode(t *testing.T) { } } +func TestInvalidJsonErrFromResponse(t *testing.T) { + cred, err := NewCredFromSecret(fakeSecret) + if err != nil { + t.Fatal(err) + } + tenant := "A" + lmo := "login.microsoftonline.com" + mockClient := mock.Client{} + mockClient.AppendResponse(mock.WithBody(mock.GetInstanceDiscoveryBody(lmo, tenant))) + client, err := New(fmt.Sprintf(authorityFmt, lmo, tenant), fakeClientID, cred, WithHTTPClient(&mockClient)) + if err != nil { + t.Fatal(err) + } + ctx := context.Background() + // cache an access token for each tenant. To simplify determining their provenance below, the value of each token is the ID of the tenant that provided it. + if _, err = client.AcquireTokenSilent(ctx, tokenScope, WithTenantID(tenant)); err == nil { + t.Fatal("silent auth should fail because the cache is empty") + } + mockClient.AppendResponse(mock.WithBody(mock.GetTenantDiscoveryBody(lmo, tenant))) + body := fmt.Sprintf( + `{"access_token": "%s","expires_in": %d,"expires_on": %d,"token_type": "Bearer"`, + tenant, 3600, time.Now().Add(time.Duration(3600)*time.Second).Unix(), + ) + mockClient.AppendResponse(mock.WithBody([]byte(body))) + _, err = client.AcquireTokenByCredential(ctx, tokenScope, WithTenantID(tenant)) + if err == nil { + t.Fatal("should have failed with InvalidJsonErr Response") + } + var ie errors.InvalidJsonErr + if !errors.As(err, &ie) { + t.Fatal("should have revieved a InvalidJsonErr, but got", err) + } +} + func TestAcquireTokenSilentTenants(t *testing.T) { cred, err := NewCredFromSecret(fakeSecret) if err != nil { @@ -478,7 +513,7 @@ func TestADFSTokenCaching(t *testing.T) { AccessToken: "at1", RefreshToken: "rt", TokenType: "bearer", - ExpiresOn: internalTime.DurationTime{T: time.Now().Add(time.Hour)}, + ExpiresOn: time.Now().Add(time.Hour), ExtExpiresOn: internalTime.DurationTime{T: time.Now().Add(time.Hour)}, GrantedScopes: accesstokens.Scopes{Slice: tokenScope}, IDToken: accesstokens.IDToken{ @@ -608,7 +643,7 @@ func TestNewCredFromCert(t *testing.T) { t.Run(fmt.Sprintf("%s/%v", filepath.Base(file.path), sendX5c), func(t *testing.T) { client, err := fakeClient(accesstokens.TokenResponse{ AccessToken: token, - ExpiresOn: internalTime.DurationTime{T: time.Now().Add(time.Hour)}, + ExpiresOn: time.Now().Add(time.Hour), GrantedScopes: accesstokens.Scopes{Slice: tokenScope}, }, cred, fakeAuthority, opts...) if err != nil { @@ -724,7 +759,7 @@ func TestNewCredFromTokenProvider(t *testing.T) { expectedToken := "expected token" called := false expiresIn := 4200 - key := struct{}{} + key := contextKey{} ctx := context.WithValue(context.Background(), key, true) cred := NewCredFromTokenProvider(func(c context.Context, tp exported.TokenProviderParameters) (exported.TokenProviderResult, error) { if called { @@ -982,7 +1017,7 @@ func TestWithClaims(t *testing.T) { case "password": ar, err = client.AcquireTokenByUsernamePassword(ctx, tokenScope, "username", "password", WithClaims(test.claims)) default: - t.Fatalf("test bug: no test for " + method) + t.Fatalf("test bug: no test for %s", method) } if err != nil { t.Fatal(err) @@ -1092,7 +1127,7 @@ func TestWithTenantID(t *testing.T) { case "obo": ar, err = client.AcquireTokenOnBehalfOf(ctx, "assertion", tokenScope, WithTenantID(test.tenant)) default: - t.Fatalf("test bug: no test for " + method) + t.Fatalf("test bug: no test for %s", method) } if err != nil { if test.expectError { @@ -1402,7 +1437,7 @@ func TestWithAuthenticationScheme(t *testing.T) { } client, err := fakeClient(accesstokens.TokenResponse{ AccessToken: token, - ExpiresOn: internalTime.DurationTime{T: time.Now().Add(1 * time.Hour)}, + ExpiresOn: time.Now().Add(1 * time.Hour), ExtExpiresOn: internalTime.DurationTime{T: time.Now().Add(1 * time.Hour)}, GrantedScopes: accesstokens.Scopes{Slice: tokenScope}, TokenType: "TokenType", @@ -1442,7 +1477,7 @@ func TestAcquireTokenByCredentialFromDSTS(t *testing.T) { } client, err := fakeClient(accesstokens.TokenResponse{ AccessToken: token, - ExpiresOn: internalTime.DurationTime{T: time.Now().Add(1 * time.Hour)}, + ExpiresOn: time.Now().Add(1 * time.Hour), ExtExpiresOn: internalTime.DurationTime{T: time.Now().Add(1 * time.Hour)}, GrantedScopes: accesstokens.Scopes{Slice: tokenScope}, TokenType: "Bearer", diff --git a/apps/errors/errors.go b/apps/errors/errors.go index c9b8dbed..b5cbb572 100644 --- a/apps/errors/errors.go +++ b/apps/errors/errors.go @@ -64,11 +64,20 @@ type CallErr struct { Err error } +type InvalidJsonErr struct { + Err error +} + // Errors implements error.Error(). func (e CallErr) Error() string { return e.Err.Error() } +// Errors implements error.Error(). +func (e InvalidJsonErr) Error() string { + return e.Err.Error() +} + // Verbose prints a versbose error message with the request or response. func (e CallErr) Verbose() string { e.Resp.Request = nil // This brings in a bunch of TLS crap we don't need diff --git a/apps/internal/base/base.go b/apps/internal/base/base.go index e473d126..6011a00b 100644 --- a/apps/internal/base/base.go +++ b/apps/internal/base/base.go @@ -14,7 +14,7 @@ import ( "time" "github.com/AzureAD/microsoft-authentication-library-for-go/apps/cache" - "github.com/AzureAD/microsoft-authentication-library-for-go/apps/internal/base/internal/storage" + "github.com/AzureAD/microsoft-authentication-library-for-go/apps/internal/base/storage" "github.com/AzureAD/microsoft-authentication-library-for-go/apps/internal/oauth" "github.com/AzureAD/microsoft-authentication-library-for-go/apps/internal/oauth/ops/accesstokens" "github.com/AzureAD/microsoft-authentication-library-for-go/apps/internal/oauth/ops/authority" @@ -111,7 +111,6 @@ func AuthResultFromStorage(storageTokenResponse storage.TokenResponse) (AuthResu if err := storageTokenResponse.AccessToken.Validate(); err != nil { return AuthResult{}, fmt.Errorf("problem with access token in StorageTokenResponse: %w", err) } - account := storageTokenResponse.Account accessToken := storageTokenResponse.AccessToken.Secret grantedScopes := strings.Split(storageTokenResponse.AccessToken.Scopes, scopeSeparator) @@ -146,7 +145,7 @@ func NewAuthResult(tokenResponse accesstokens.TokenResponse, account shared.Acco Account: account, IDToken: tokenResponse.IDToken, AccessToken: tokenResponse.AccessToken, - ExpiresOn: tokenResponse.ExpiresOn.T, + ExpiresOn: tokenResponse.ExpiresOn, GrantedScopes: tokenResponse.GrantedScopes.Slice, Metadata: AuthResultMetadata{ TokenSource: IdentityProvider, diff --git a/apps/internal/base/base_test.go b/apps/internal/base/base_test.go index 09238780..f06f13fa 100644 --- a/apps/internal/base/base_test.go +++ b/apps/internal/base/base_test.go @@ -12,7 +12,7 @@ import ( "time" "github.com/AzureAD/microsoft-authentication-library-for-go/apps/cache" - "github.com/AzureAD/microsoft-authentication-library-for-go/apps/internal/base/internal/storage" + "github.com/AzureAD/microsoft-authentication-library-for-go/apps/internal/base/storage" internalTime "github.com/AzureAD/microsoft-authentication-library-for-go/apps/internal/json/types/time" "github.com/AzureAD/microsoft-authentication-library-for-go/apps/internal/oauth" "github.com/AzureAD/microsoft-authentication-library-for-go/apps/internal/oauth/fake" @@ -50,7 +50,7 @@ func fakeClient(t *testing.T, opts ...Option) Client { client.Token.AccessTokens = &fake.AccessTokens{ AccessToken: accesstokens.TokenResponse{ AccessToken: fakeAccessToken, - ExpiresOn: internalTime.DurationTime{T: time.Now().Add(time.Hour)}, + ExpiresOn: time.Now().Add(time.Hour), FamilyID: "family-id", GrantedScopes: accesstokens.Scopes{Slice: testScopes}, IDToken: fakeIDToken, @@ -135,7 +135,7 @@ func TestAcquireTokenSilentScopes(t *testing.T) { accesstokens.TokenResponse{ AccessToken: fakeAccessToken, ClientInfo: accesstokens.ClientInfo{UID: "uid", UTID: "utid"}, - ExpiresOn: internalTime.DurationTime{T: time.Now().Add(-time.Hour)}, + ExpiresOn: time.Now().Add(-time.Hour), GrantedScopes: accesstokens.Scopes{Slice: test.cachedTokenScopes}, IDToken: fakeIDToken, RefreshToken: fakeRefreshToken, @@ -178,7 +178,7 @@ func TestAcquireTokenSilentGrantedScopes(t *testing.T) { }, accesstokens.TokenResponse{ AccessToken: expectedToken, - ExpiresOn: internalTime.DurationTime{T: time.Now().Add(time.Hour)}, + ExpiresOn: time.Now().Add(time.Hour), GrantedScopes: accesstokens.Scopes{Slice: grantedScopes}, TokenType: "Bearer", }, @@ -335,7 +335,7 @@ func TestCreateAuthenticationResult(t *testing.T) { desc: "no declined scopes", input: accesstokens.TokenResponse{ AccessToken: "accessToken", - ExpiresOn: internalTime.DurationTime{T: future}, + ExpiresOn: future, GrantedScopes: accesstokens.Scopes{Slice: []string{"user.read"}}, DeclinedScopes: nil, }, @@ -353,7 +353,7 @@ func TestCreateAuthenticationResult(t *testing.T) { desc: "declined scopes", input: accesstokens.TokenResponse{ AccessToken: "accessToken", - ExpiresOn: internalTime.DurationTime{T: future}, + ExpiresOn: future, GrantedScopes: accesstokens.Scopes{Slice: []string{"user.read"}}, DeclinedScopes: []string{"openid"}, }, diff --git a/apps/internal/base/internal/storage/items.go b/apps/internal/base/storage/items.go similarity index 98% rename from apps/internal/base/internal/storage/items.go rename to apps/internal/base/storage/items.go index f9be9027..95cb2b41 100644 --- a/apps/internal/base/internal/storage/items.go +++ b/apps/internal/base/storage/items.go @@ -102,8 +102,9 @@ func NewAccessToken(homeID, env, realm, clientID string, cachedAt, expiresOn, ex // Key outputs the key that can be used to uniquely look up this entry in a map. func (a AccessToken) Key() string { + ks := []string{a.HomeAccountID, a.Environment, a.CredentialType, a.ClientID, a.Realm, a.Scopes} key := strings.Join( - []string{a.HomeAccountID, a.Environment, a.CredentialType, a.ClientID, a.Realm, a.Scopes}, + ks, shared.CacheKeySeparator, ) // add token type to key for new access tokens types. skip for bearer token type to diff --git a/apps/internal/base/internal/storage/items_test.go b/apps/internal/base/storage/items_test.go similarity index 99% rename from apps/internal/base/internal/storage/items_test.go rename to apps/internal/base/storage/items_test.go index d1df933d..24f45063 100644 --- a/apps/internal/base/internal/storage/items_test.go +++ b/apps/internal/base/storage/items_test.go @@ -305,7 +305,7 @@ func TestContractUnmarshalJSON(t *testing.T) { } if diff := pretty.Compare(want, got); diff != "" { t.Errorf("TestContractUnmarshalJSON: -want/+got:\n%s", diff) - t.Errorf(string(got.AdditionalFields["unknownEntity"].(stdJSON.RawMessage))) + t.Errorf("%s", string(got.AdditionalFields["unknownEntity"].(stdJSON.RawMessage))) } } diff --git a/apps/internal/base/internal/storage/partitioned_storage.go b/apps/internal/base/storage/partitioned_storage.go similarity index 99% rename from apps/internal/base/internal/storage/partitioned_storage.go rename to apps/internal/base/storage/partitioned_storage.go index c0931833..b816766e 100644 --- a/apps/internal/base/internal/storage/partitioned_storage.go +++ b/apps/internal/base/storage/partitioned_storage.go @@ -114,7 +114,7 @@ func (m *PartitionedManager) Write(authParameters authority.AuthParams, tokenRes realm, clientID, cachedAt, - tokenResponse.ExpiresOn.T, + tokenResponse.ExpiresOn, tokenResponse.ExtExpiresOn.T, target, tokenResponse.AccessToken, diff --git a/apps/internal/base/internal/storage/partitioned_storage_test.go b/apps/internal/base/storage/partitioned_storage_test.go similarity index 98% rename from apps/internal/base/internal/storage/partitioned_storage_test.go rename to apps/internal/base/storage/partitioned_storage_test.go index 86859cf2..6a73e91e 100644 --- a/apps/internal/base/internal/storage/partitioned_storage_test.go +++ b/apps/internal/base/storage/partitioned_storage_test.go @@ -10,7 +10,6 @@ import ( "testing" "time" - internalTime "github.com/AzureAD/microsoft-authentication-library-for-go/apps/internal/json/types/time" "github.com/AzureAD/microsoft-authentication-library-for-go/apps/internal/oauth/ops/accesstokens" "github.com/AzureAD/microsoft-authentication-library-for-go/apps/internal/oauth/ops/authority" "github.com/AzureAD/microsoft-authentication-library-for-go/apps/internal/shared" @@ -59,7 +58,7 @@ func TestOBOAccessTokenScopes(t *testing.T) { accesstokens.TokenResponse{ AccessToken: scope[0] + "-at", ClientInfo: accesstokens.ClientInfo{UID: upn, UTID: idt.TenantID}, - ExpiresOn: internalTime.DurationTime{T: time.Now().Add(time.Hour)}, + ExpiresOn: time.Now().Add(time.Hour), GrantedScopes: accesstokens.Scopes{Slice: scope}, IDToken: idt, RefreshToken: upn + "-rt", @@ -121,7 +120,7 @@ func TestOBOPartitioning(t *testing.T) { accesstokens.TokenResponse{ AccessToken: upn + "-at", ClientInfo: accesstokens.ClientInfo{UID: upn, UTID: idt.TenantID}, - ExpiresOn: internalTime.DurationTime{T: time.Now().Add(time.Hour)}, + ExpiresOn: time.Now().Add(time.Hour), GrantedScopes: accesstokens.Scopes{Slice: scopes}, IDToken: idt, RefreshToken: upn + "-rt", diff --git a/apps/internal/base/internal/storage/storage.go b/apps/internal/base/storage/storage.go similarity index 99% rename from apps/internal/base/internal/storage/storage.go rename to apps/internal/base/storage/storage.go index 2221e60c..334431b7 100644 --- a/apps/internal/base/internal/storage/storage.go +++ b/apps/internal/base/storage/storage.go @@ -173,6 +173,7 @@ func (m *Manager) Write(authParameters authority.AuthParams, tokenResponse acces environment := authParameters.AuthorityInfo.Host realm := authParameters.AuthorityInfo.Tenant clientID := authParameters.ClientID + target := strings.Join(tokenResponse.GrantedScopes.Slice, scopeSeparator) cachedAt := time.Now() authnSchemeKeyID := authParameters.AuthnScheme.KeyID() @@ -193,7 +194,7 @@ func (m *Manager) Write(authParameters authority.AuthParams, tokenResponse acces realm, clientID, cachedAt, - tokenResponse.ExpiresOn.T, + tokenResponse.ExpiresOn, tokenResponse.ExtExpiresOn.T, target, tokenResponse.AccessToken, @@ -265,6 +266,9 @@ func (m *Manager) aadMetadataFromCache(ctx context.Context, authorityInfo author } func (m *Manager) aadMetadata(ctx context.Context, authorityInfo authority.Info) (authority.InstanceDiscoveryMetadata, error) { + if m.requests == nil { + return authority.InstanceDiscoveryMetadata{}, fmt.Errorf("httpclient in oauth instance for fetching metadata is nil") + } m.aadCacheMu.Lock() defer m.aadCacheMu.Unlock() discoveryResponse, err := m.requests.AADInstanceDiscovery(ctx, authorityInfo) @@ -459,6 +463,7 @@ func (m *Manager) readAccount(homeAccountID string, envAliases []string, realm s func (m *Manager) writeAccount(account shared.Account) error { key := account.Key() + m.contractMu.Lock() defer m.contractMu.Unlock() m.contract.Accounts[key] = account diff --git a/apps/internal/base/internal/storage/storage_test.go b/apps/internal/base/storage/storage_test.go similarity index 99% rename from apps/internal/base/internal/storage/storage_test.go rename to apps/internal/base/storage/storage_test.go index 0570115c..59eae24a 100644 --- a/apps/internal/base/internal/storage/storage_test.go +++ b/apps/internal/base/storage/storage_test.go @@ -1007,7 +1007,7 @@ func TestWrite(t *testing.T) { Oid: "lid", PreferredUsername: "username", } - expiresOn := internalTime.DurationTime{T: now.Add(1000 * time.Second)} + expiresOn := now.Add(1000 * time.Second) tokenResponse := accesstokens.TokenResponse{ AccessToken: "accessToken", RefreshToken: "refreshToken", diff --git a/apps/internal/base/internal/storage/testdata/test_serialized_cache.json b/apps/internal/base/storage/testdata/test_serialized_cache.json similarity index 100% rename from apps/internal/base/internal/storage/testdata/test_serialized_cache.json rename to apps/internal/base/storage/testdata/test_serialized_cache.json diff --git a/apps/internal/base/internal/storage/testdata/v1.0_cache.json b/apps/internal/base/storage/testdata/v1.0_cache.json similarity index 100% rename from apps/internal/base/internal/storage/testdata/v1.0_cache.json rename to apps/internal/base/storage/testdata/v1.0_cache.json diff --git a/apps/internal/base/internal/storage/testdata/v1.0_v1.1_cache.json b/apps/internal/base/storage/testdata/v1.0_v1.1_cache.json similarity index 100% rename from apps/internal/base/internal/storage/testdata/v1.0_v1.1_cache.json rename to apps/internal/base/storage/testdata/v1.0_v1.1_cache.json diff --git a/apps/internal/local/server.go b/apps/internal/local/server.go index fda5d7dd..cda678e3 100644 --- a/apps/internal/local/server.go +++ b/apps/internal/local/server.go @@ -146,7 +146,8 @@ func (s *Server) handler(w http.ResponseWriter, r *http.Request) { // Note: It is a little weird we handle some errors by not going to the failPage. If they all should, // change this to s.error() and make s.error() write the failPage instead of an error code. _, _ = w.Write([]byte(fmt.Sprintf(failPage, headerErr, desc))) - s.putResult(Result{Err: fmt.Errorf(desc)}) + s.putResult(Result{Err: fmt.Errorf("%s", desc)}) + return } diff --git a/apps/internal/mock/mock.go b/apps/internal/mock/mock.go index 5de171fd..a612c2c1 100644 --- a/apps/internal/mock/mock.go +++ b/apps/internal/mock/mock.go @@ -46,6 +46,20 @@ func WithCallback(callback func(*http.Request)) responseOption { }) } +// WithHTTPHeader sets the HTTP headers of the response to the specified value. +func WithHTTPHeader(header http.Header) responseOption { + return respOpt(func(r *response) { + r.headers = header + }) +} + +// WithHTTPStatusCode sets the HTTP statusCode of response to the specified value. +func WithHTTPStatusCode(statusCode int) responseOption { + return respOpt(func(r *response) { + r.code = statusCode + }) +} + // Client is a mock HTTP client that returns a sequence of responses. Use AppendResponse to specify the sequence. type Client struct { resp []response diff --git a/apps/internal/oauth/oauth.go b/apps/internal/oauth/oauth.go index e0653134..ad476e07 100644 --- a/apps/internal/oauth/oauth.go +++ b/apps/internal/oauth/oauth.go @@ -14,7 +14,6 @@ import ( "github.com/AzureAD/microsoft-authentication-library-for-go/apps/errors" "github.com/AzureAD/microsoft-authentication-library-for-go/apps/internal/exported" - internalTime "github.com/AzureAD/microsoft-authentication-library-for-go/apps/internal/json/types/time" "github.com/AzureAD/microsoft-authentication-library-for-go/apps/internal/oauth/ops" "github.com/AzureAD/microsoft-authentication-library-for-go/apps/internal/oauth/ops/accesstokens" "github.com/AzureAD/microsoft-authentication-library-for-go/apps/internal/oauth/ops/authority" @@ -120,11 +119,9 @@ func (t *Client) Credential(ctx context.Context, authParams authority.AuthParams return accesstokens.TokenResponse{}, err } return accesstokens.TokenResponse{ - TokenType: authParams.AuthnScheme.AccessTokenType(), - AccessToken: tr.AccessToken, - ExpiresOn: internalTime.DurationTime{ - T: now.Add(time.Duration(tr.ExpiresInSeconds) * time.Second), - }, + TokenType: authParams.AuthnScheme.AccessTokenType(), + AccessToken: tr.AccessToken, + ExpiresOn: now.Add(time.Duration(tr.ExpiresInSeconds) * time.Second), GrantedScopes: accesstokens.Scopes{Slice: authParams.Scopes}, }, nil } diff --git a/apps/internal/oauth/ops/accesstokens/accesstokens.go b/apps/internal/oauth/ops/accesstokens/accesstokens.go index a7b7b074..71275b32 100644 --- a/apps/internal/oauth/ops/accesstokens/accesstokens.go +++ b/apps/internal/oauth/ops/accesstokens/accesstokens.go @@ -262,11 +262,7 @@ func (c Client) FromClientSecret(ctx context.Context, authParameters authority.A qv.Set(clientID, authParameters.ClientID) addScopeQueryParam(qv, authParameters) - token, err := c.doTokenResp(ctx, authParameters, qv) - if err != nil { - return token, fmt.Errorf("FromClientSecret(): %w", err) - } - return token, nil + return c.doTokenResp(ctx, authParameters, qv) } func (c Client) FromAssertion(ctx context.Context, authParameters authority.AuthParams, assertion string) (TokenResponse, error) { @@ -281,11 +277,7 @@ func (c Client) FromAssertion(ctx context.Context, authParameters authority.Auth qv.Set(clientInfo, clientInfoVal) addScopeQueryParam(qv, authParameters) - token, err := c.doTokenResp(ctx, authParameters, qv) - if err != nil { - return token, fmt.Errorf("FromAssertion(): %w", err) - } - return token, nil + return c.doTokenResp(ctx, authParameters, qv) } func (c Client) FromUserAssertionClientSecret(ctx context.Context, authParameters authority.AuthParams, userAssertion string, clientSecret string) (TokenResponse, error) { diff --git a/apps/internal/oauth/ops/accesstokens/accesstokens_test.go b/apps/internal/oauth/ops/accesstokens/accesstokens_test.go index 59d3506d..5099294a 100644 --- a/apps/internal/oauth/ops/accesstokens/accesstokens_test.go +++ b/apps/internal/oauth/ops/accesstokens/accesstokens_test.go @@ -757,18 +757,85 @@ func TestTokenResponseUnmarshal(t *testing.T) { jwtDecoder: jwtDecoderFake, }, { - desc: "Success", + desc: "Success: ExpiresOn as Unix timestamp number, expires_in present", + payload: fmt.Sprintf(` + { + "access_token": "secret", + "expires_on": %d, + "expires_in": "3600", + "ext_expires_in": 86399, + "client_info": {"uid": "uid","utid": "utid"}, + "scope": "openid profile" + }`, time.Now().Add(time.Hour).Unix()), + want: TokenResponse{ + AccessToken: "secret", + ExpiresOn: time.Now().Add(time.Hour), // from expires_on + ExtExpiresOn: internalTime.DurationTime{T: time.Unix(86399, 0)}, + GrantedScopes: Scopes{Slice: []string{"openid", "profile"}}, + ClientInfo: ClientInfo{ + UID: "uid", + UTID: "utid", + }, + }, + jwtDecoder: jwtDecoderFake, + }, + { + desc: "Success: ExpiresOn as ISO 8601 string", payload: ` { "access_token": "secret", - "expires_in": 86399, + "expires_on": "2024-12-31T23:59:59Z", + "ext_expires_in": 86399, + "client_info": {"uid": "uid","utid": "utid"}, + "scope": "openid profile" + }`, + want: TokenResponse{ + AccessToken: "secret", + ExpiresOn: time.Date(2024, 12, 31, 23, 59, 59, 0, time.UTC), + ExtExpiresOn: internalTime.DurationTime{T: time.Unix(86399, 0)}, + GrantedScopes: Scopes{Slice: []string{"openid", "profile"}}, + ClientInfo: ClientInfo{ + UID: "uid", + UTID: "utid", + }, + }, + jwtDecoder: jwtDecoderFake, + }, + { + desc: "Success: ExpiresOn as MM/dd/yyyy HH:mm:ss string", + payload: ` + { + "access_token": "secret", + "expires_on": "12/31/2024 23:59:59", + "ext_expires_in": 86399, + "client_info": {"uid": "uid","utid": "utid"}, + "scope": "openid profile" + }`, + want: TokenResponse{ + AccessToken: "secret", + ExpiresOn: time.Date(2024, 12, 31, 23, 59, 59, 0, time.UTC), + ExtExpiresOn: internalTime.DurationTime{T: time.Unix(86399, 0)}, + GrantedScopes: Scopes{Slice: []string{"openid", "profile"}}, + ClientInfo: ClientInfo{ + UID: "uid", + UTID: "utid", + }, + }, + jwtDecoder: jwtDecoderFake, + }, + { + desc: "Success: ExpiresOn as yyyy-MM-dd HH:mm:ss string", + payload: ` + { + "access_token": "secret", + "expires_on": "2024-12-31 23:59:59", "ext_expires_in": 86399, - "client_info": {"uid": "uid","utid": "utid"}, + "client_info": {"uid": "uid","utid": "utid"}, "scope": "openid profile" }`, want: TokenResponse{ AccessToken: "secret", - ExpiresOn: internalTime.DurationTime{T: time.Unix(86399, 0)}, + ExpiresOn: time.Date(2024, 12, 31, 23, 59, 59, 0, time.UTC), ExtExpiresOn: internalTime.DurationTime{T: time.Unix(86399, 0)}, GrantedScopes: Scopes{Slice: []string{"openid", "profile"}}, ClientInfo: ClientInfo{ @@ -778,6 +845,77 @@ func TestTokenResponseUnmarshal(t *testing.T) { }, jwtDecoder: jwtDecoderFake, }, + { + desc: "Success: ExpiresOn empty, fallback to ExpiresIn", + payload: ` + { + "access_token": "secret", + "expires_on": "", + "expires_in": 3600, + "ext_expires_in": 86399, + "client_info": {"uid": "uid","utid": "utid"}, + "scope": "openid profile" + }`, + want: TokenResponse{ + AccessToken: "secret", + ExpiresOn: time.Now().Add(time.Hour), + ExtExpiresOn: internalTime.DurationTime{T: time.Unix(86399, 0)}, + GrantedScopes: Scopes{Slice: []string{"openid", "profile"}}, + ClientInfo: ClientInfo{ + UID: "uid", + UTID: "utid", + }, + }, + jwtDecoder: jwtDecoderFake, + }, + { + desc: "Success: Only expires_in", + payload: ` + { + "access_token": "secret", + "expires_in": 3600, + "ext_expires_in": 86399, + "client_info": {"uid": "uid","utid": "utid"}, + "scope": "openid profile" + }`, + want: TokenResponse{ + AccessToken: "secret", + ExpiresOn: time.Now().Add(time.Hour), + ExtExpiresOn: internalTime.DurationTime{T: time.Unix(86399, 0)}, + GrantedScopes: Scopes{Slice: []string{"openid", "profile"}}, + ClientInfo: ClientInfo{ + UID: "uid", + UTID: "utid", + }, + }, + jwtDecoder: jwtDecoderFake, + }, + { + desc: "Error: Missing both expires_on and expires_in", + payload: ` + { + "access_token": "secret", + "client_info": {"uid": "uid","utid": "utid"}, + "scope": "openid profile" + }`, + want: TokenResponse{}, + err: true, + jwtDecoder: jwtDecoderFake, + }, + { + desc: "Error: Invalid ExpiresOn format", + payload: ` + { + "access_token": "secret", + "expires_on": "invalid-date-format", + "ext_expires_in": 86399, + "client_info": {"uid": "uid","utid": "utid"}, + "scope": "openid profile" + }`, + want: TokenResponse{}, + err: true, + jwtDecoder: jwtDecoderFake, + }, } for _, test := range tests { @@ -795,7 +933,9 @@ func TestTokenResponseUnmarshal(t *testing.T) { case err != nil: continue } - + if got.ExpiresOn.Unix() != test.want.ExpiresOn.Unix() { + t.Errorf("TestCreateTokenResponse: got %v, want %v", got.ExpiresOn.Unix(), test.want.ExpiresOn.Unix()) + } // Note: IncludeUnexported prevents minor differences in time.Time due to internal fields. if diff := (&pretty.Config{IncludeUnexported: false}).Compare(test.want, got); diff != "" { t.Errorf("TestCreateTokenResponse: -want/+got:\n%s", diff) diff --git a/apps/internal/oauth/ops/accesstokens/tokens.go b/apps/internal/oauth/ops/accesstokens/tokens.go index 3107b45c..84b17a1c 100644 --- a/apps/internal/oauth/ops/accesstokens/tokens.go +++ b/apps/internal/oauth/ops/accesstokens/tokens.go @@ -10,6 +10,7 @@ import ( "errors" "fmt" "reflect" + "strconv" "strings" "time" @@ -173,14 +174,74 @@ type TokenResponse struct { FamilyID string `json:"foci"` IDToken IDToken `json:"id_token"` ClientInfo ClientInfo `json:"client_info"` - ExpiresOn internalTime.DurationTime `json:"expires_in"` + ExpiresOn time.Time `json:"-"` ExtExpiresOn internalTime.DurationTime `json:"ext_expires_in"` GrantedScopes Scopes `json:"scope"` DeclinedScopes []string // This is derived AdditionalFields map[string]interface{} + scopesComputed bool +} + +func (tr *TokenResponse) UnmarshalJSON(data []byte) error { + type Alias TokenResponse + aux := &struct { + ExpiresIn internalTime.DurationTime `json:"expires_in,omitempty"` + ExpiresOn any `json:"expires_on,omitempty"` + *Alias + }{ + Alias: (*Alias)(tr), + } + + // Unmarshal the JSON data into the aux struct + if err := json.Unmarshal(data, &aux); err != nil { + return err + } + + // Function to parse different date formats + // This is a workaround for the issue described here: + // https://github.com/AzureAD/microsoft-authentication-library-for-dotnet/issues/4963 + parseExpiresOn := func(expiresOn string) (time.Time, error) { + var formats = []string{ + "01/02/2006 15:04:05", // MM/dd/yyyy HH:mm:ss + "2006-01-02 15:04:05", // yyyy-MM-dd HH:mm:ss + time.RFC3339Nano, // ISO 8601 (with nanosecond precision) + } + + for _, format := range formats { + if t, err := time.Parse(format, expiresOn); err == nil { + return t, nil + } + } + return time.Time{}, fmt.Errorf("invalid ExpiresOn format: %s", expiresOn) + } - scopesComputed bool + if expiresOnStr, ok := aux.ExpiresOn.(string); ok { + if ts, err := strconv.ParseInt(expiresOnStr, 10, 64); err == nil { + tr.ExpiresOn = time.Unix(ts, 0) + return nil + } + if expiresOnStr != "" { + if t, err := parseExpiresOn(expiresOnStr); err != nil { + return err + } else { + tr.ExpiresOn = t + return nil + } + } + } + + // Check if ExpiresOn is a number (Unix timestamp or ISO 8601) + if expiresOnNum, ok := aux.ExpiresOn.(float64); ok { + tr.ExpiresOn = time.Unix(int64(expiresOnNum), 0) + return nil + } + + if !aux.ExpiresIn.T.IsZero() { + tr.ExpiresOn = aux.ExpiresIn.T + return nil + } + return errors.New("expires_in and expires_on are both missing or invalid") } // ComputeScope computes the final scopes based on what was granted by the server and diff --git a/apps/internal/oauth/ops/internal/comm/comm.go b/apps/internal/oauth/ops/internal/comm/comm.go index d62aac74..79068036 100644 --- a/apps/internal/oauth/ops/internal/comm/comm.go +++ b/apps/internal/oauth/ops/internal/comm/comm.go @@ -98,7 +98,7 @@ func (c *Client) JSONCall(ctx context.Context, endpoint string, headers http.Hea if resp != nil { if err := unmarshal(data, resp); err != nil { - return fmt.Errorf("json decode error: %w\njson message bytes were: %s", err, string(data)) + return errors.InvalidJsonErr{Err: fmt.Errorf("json decode error: %w\njson message bytes were: %s", err, string(data))} } } return nil @@ -221,7 +221,7 @@ func (c *Client) URLFormCall(ctx context.Context, endpoint string, qv url.Values } if resp != nil { if err := unmarshal(data, resp); err != nil { - return fmt.Errorf("json decode error: %w\nraw message was: %s", err, string(data)) + return errors.InvalidJsonErr{Err: fmt.Errorf("json decode error: %w\nraw message was: %s", err, string(data))} } } return nil diff --git a/apps/managedidentity/azure_ml.go b/apps/managedidentity/azure_ml.go new file mode 100644 index 00000000..d7cffc29 --- /dev/null +++ b/apps/managedidentity/azure_ml.go @@ -0,0 +1,28 @@ +// Copyright (c) Microsoft Corporation. +// Licensed under the MIT license. + +package managedidentity + +import ( + "context" + "net/http" + "os" +) + +func createAzureMLAuthRequest(ctx context.Context, id ID, resource string) (*http.Request, error) { + req, err := http.NewRequestWithContext(ctx, http.MethodGet, os.Getenv(msiEndpointEnvVar), nil) + if err != nil { + return nil, err + } + + req.Header.Set("secret", os.Getenv(msiSecretEnvVar)) + q := req.URL.Query() + q.Set(apiVersionQueryParameterName, azureMLAPIVersion) + q.Set(resourceQueryParameterName, resource) + q.Set("clientid", os.Getenv("DEFAULT_IDENTITY_CLIENT_ID")) + if cid, ok := id.(UserAssignedClientID); ok { + q.Set("clientid", string(cid)) + } + req.URL.RawQuery = q.Encode() + return req, nil +} diff --git a/apps/managedidentity/cloud_shell.go b/apps/managedidentity/cloud_shell.go new file mode 100644 index 00000000..be9a0bca --- /dev/null +++ b/apps/managedidentity/cloud_shell.go @@ -0,0 +1,37 @@ +// Copyright (c) Microsoft Corporation. +// Licensed under the MIT license. + +package managedidentity + +import ( + "context" + "fmt" + "io" + "net/http" + "net/url" + "os" + "strings" +) + +func createCloudShellAuthRequest(ctx context.Context, resource string) (*http.Request, error) { + msiEndpoint := os.Getenv(msiEndpointEnvVar) + msiEndpointParsed, err := url.Parse(msiEndpoint) + if err != nil { + return nil, fmt.Errorf("couldn't parse %q: %s", msiEndpoint, err) + } + + data := url.Values{} + data.Set(resourceQueryParameterName, resource) + msiDataEncoded := data.Encode() + body := io.NopCloser(strings.NewReader(msiDataEncoded)) + + req, err := http.NewRequestWithContext(ctx, http.MethodPost, msiEndpointParsed.String(), body) + if err != nil { + return nil, fmt.Errorf("error creating http request %s", err) + } + + req.Header.Set(metaHTTPHeaderName, "true") + req.Header.Set("Content-Type", "application/x-www-form-urlencoded") + + return req, nil +} diff --git a/apps/managedidentity/managedidentity.go b/apps/managedidentity/managedidentity.go new file mode 100644 index 00000000..4e274869 --- /dev/null +++ b/apps/managedidentity/managedidentity.go @@ -0,0 +1,681 @@ +// Copyright (c) Microsoft Corporation. +// Licensed under the MIT license. + +/* +Package managedidentity provides a client for retrieval of Managed Identity applications. +The Managed Identity Client is used to acquire a token for managed identity assigned to +an azure resource such as Azure function, app service, virtual machine, etc. to acquire a token +without using credentials. +*/ +package managedidentity + +import ( + "context" + "encoding/json" + "fmt" + "io" + "net/http" + "net/url" + "os" + "path/filepath" + "runtime" + "strings" + "time" + + "github.com/AzureAD/microsoft-authentication-library-for-go/apps/errors" + "github.com/AzureAD/microsoft-authentication-library-for-go/apps/internal/base" + "github.com/AzureAD/microsoft-authentication-library-for-go/apps/internal/base/storage" + "github.com/AzureAD/microsoft-authentication-library-for-go/apps/internal/oauth/ops" + "github.com/AzureAD/microsoft-authentication-library-for-go/apps/internal/oauth/ops/accesstokens" + "github.com/AzureAD/microsoft-authentication-library-for-go/apps/internal/oauth/ops/authority" + "github.com/AzureAD/microsoft-authentication-library-for-go/apps/internal/shared" +) + +const ( + // DefaultToIMDS indicates that the source is defaulted to IMDS when no environment variables are set. + DefaultToIMDS Source = "DefaultToIMDS" + AzureArc Source = "AzureArc" + ServiceFabric Source = "ServiceFabric" + CloudShell Source = "CloudShell" + AzureML Source = "AzureML" + AppService Source = "AppService" + + // General request query parameter names + metaHTTPHeaderName = "Metadata" + apiVersionQueryParameterName = "api-version" + resourceQueryParameterName = "resource" + wwwAuthenticateHeaderName = "www-authenticate" + + // UAMI query parameter name + miQueryParameterClientId = "client_id" + miQueryParameterObjectId = "object_id" + miQueryParameterPrincipalId = "principal_id" + miQueryParameterResourceIdIMDS = "msi_res_id" + miQueryParameterResourceId = "mi_res_id" + + // IMDS + imdsDefaultEndpoint = "http://169.254.169.254/metadata/identity/oauth2/token" + imdsAPIVersion = "2018-02-01" + systemAssignedManagedIdentity = "system_assigned_managed_identity" + + // Azure Arc + azureArcEndpoint = "http://127.0.0.1:40342/metadata/identity/oauth2/token" + azureArcAPIVersion = "2020-06-01" + azureArcFileExtension = ".key" + azureArcMaxFileSizeBytes int64 = 4096 + linuxTokenPath = "/var/opt/azcmagent/tokens" + linuxHimdsPath = "/opt/azcmagent/bin/himds" + azureConnectedMachine = "AzureConnectedMachineAgent" + himdsExecutableName = "himds.exe" + tokenName = "Tokens" + + // App Service + appServiceAPIVersion = "2019-08-01" + + // AzureML + azureMLAPIVersion = "2017-09-01" + // Service Fabric + serviceFabricAPIVersion = "2019-07-01-preview" + + // Environment Variables + identityEndpointEnvVar = "IDENTITY_ENDPOINT" + identityHeaderEnvVar = "IDENTITY_HEADER" + azurePodIdentityAuthorityHostEnvVar = "AZURE_POD_IDENTITY_AUTHORITY_HOST" + imdsEndVar = "IMDS_ENDPOINT" + msiEndpointEnvVar = "MSI_ENDPOINT" + msiSecretEnvVar = "MSI_SECRET" + identityServerThumbprintEnvVar = "IDENTITY_SERVER_THUMBPRINT" + + defaultRetryCount = 3 +) + +var retryCodesForIMDS = []int{ + http.StatusNotFound, // 404 + http.StatusGone, // 410 + http.StatusTooManyRequests, // 429 + http.StatusInternalServerError, // 500 + http.StatusNotImplemented, // 501 + http.StatusBadGateway, // 502 + http.StatusServiceUnavailable, // 503 + http.StatusGatewayTimeout, // 504 + http.StatusHTTPVersionNotSupported, // 505 + http.StatusVariantAlsoNegotiates, // 506 + http.StatusInsufficientStorage, // 507 + http.StatusLoopDetected, // 508 + http.StatusNotExtended, // 510 + http.StatusNetworkAuthenticationRequired, // 511 +} + +var retryStatusCodes = []int{ + http.StatusRequestTimeout, // 408 + http.StatusTooManyRequests, // 429 + http.StatusInternalServerError, // 500 + http.StatusBadGateway, // 502 + http.StatusServiceUnavailable, // 503 + http.StatusGatewayTimeout, // 504 +} + +var getAzureArcPlatformPath = func(platform string) string { + switch platform { + case "windows": + return filepath.Join(os.Getenv("ProgramData"), azureConnectedMachine, tokenName) + case "linux": + return linuxTokenPath + default: + return "" + } +} + +var getAzureArcHimdsFilePath = func(platform string) string { + switch platform { + case "windows": + return filepath.Join(os.Getenv("ProgramData"), azureConnectedMachine, himdsExecutableName) + case "linux": + return linuxHimdsPath + default: + return "" + } +} + +type Source string + +type ID interface { + value() string +} + +type systemAssignedValue string // its private for a reason to make the input consistent. +type UserAssignedClientID string +type UserAssignedObjectID string +type UserAssignedResourceID string + +func (s systemAssignedValue) value() string { return string(s) } +func (c UserAssignedClientID) value() string { return string(c) } +func (o UserAssignedObjectID) value() string { return string(o) } +func (r UserAssignedResourceID) value() string { return string(r) } +func SystemAssigned() ID { + return systemAssignedValue(systemAssignedManagedIdentity) +} + +// cache never uses the client because instance discovery is always disabled. +var cacheManager *storage.Manager = storage.New(nil) + +type Client struct { + httpClient ops.HTTPClient + miType ID + source Source + authParams authority.AuthParams + retryPolicyEnabled bool +} + +type AcquireTokenOptions struct { + claims string +} + +type ClientOption func(*Client) + +type AcquireTokenOption func(o *AcquireTokenOptions) + +// WithClaims sets additional claims to request for the token, such as those required by token revocation or conditional access policies. +// Use this option when Azure AD returned a claims challenge for a prior request. The argument must be decoded. +func WithClaims(claims string) AcquireTokenOption { + return func(o *AcquireTokenOptions) { + o.claims = claims + } +} + +// WithHTTPClient allows for a custom HTTP client to be set. +func WithHTTPClient(httpClient ops.HTTPClient) ClientOption { + return func(c *Client) { + c.httpClient = httpClient + } +} + +func WithRetryPolicyDisabled() ClientOption { + return func(c *Client) { + c.retryPolicyEnabled = false + } +} + +// Client to be used to acquire tokens for managed identity. +// ID: [SystemAssigned], [UserAssignedClientID], [UserAssignedResourceID], [UserAssignedObjectID] +// +// Options: [WithHTTPClient] +func New(id ID, options ...ClientOption) (Client, error) { + source, err := GetSource() + if err != nil { + return Client{}, err + } + + // Check for user-assigned restrictions based on the source + switch source { + case AzureArc: + switch id.(type) { + case UserAssignedClientID, UserAssignedResourceID, UserAssignedObjectID: + return Client{}, errors.New("Azure Arc doesn't support user-assigned managed identities") + } + case AzureML: + switch id.(type) { + case UserAssignedObjectID, UserAssignedResourceID: + return Client{}, errors.New("Azure ML supports specifying a user-assigned managed identity by client ID only") + } + case CloudShell: + switch id.(type) { + case UserAssignedClientID, UserAssignedResourceID, UserAssignedObjectID: + return Client{}, errors.New("Cloud Shell doesn't support user-assigned managed identities") + } + case ServiceFabric: + switch id.(type) { + case UserAssignedClientID, UserAssignedResourceID, UserAssignedObjectID: + return Client{}, errors.New("Service Fabric API doesn't support specifying a user-assigned identity. The identity is determined by cluster resource configuration. See https://aka.ms/servicefabricmi") + } + } + + switch t := id.(type) { + case UserAssignedClientID: + if len(string(t)) == 0 { + return Client{}, fmt.Errorf("empty %T", t) + } + case UserAssignedResourceID: + if len(string(t)) == 0 { + return Client{}, fmt.Errorf("empty %T", t) + } + case UserAssignedObjectID: + if len(string(t)) == 0 { + return Client{}, fmt.Errorf("empty %T", t) + } + case systemAssignedValue: + default: + return Client{}, fmt.Errorf("unsupported type %T", id) + } + client := Client{ + miType: id, + httpClient: shared.DefaultClient, + retryPolicyEnabled: true, + source: source, + } + for _, option := range options { + option(&client) + } + fakeAuthInfo, err := authority.NewInfoFromAuthorityURI("https://login.microsoftonline.com/managed_identity", false, true) + if err != nil { + return Client{}, err + } + client.authParams = authority.NewAuthParams(client.miType.value(), fakeAuthInfo) + return client, nil +} + +// GetSource detects and returns the managed identity source available on the environment. +func GetSource() (Source, error) { + identityEndpoint := os.Getenv(identityEndpointEnvVar) + identityHeader := os.Getenv(identityHeaderEnvVar) + identityServerThumbprint := os.Getenv(identityServerThumbprintEnvVar) + msiEndpoint := os.Getenv(msiEndpointEnvVar) + msiSecret := os.Getenv(msiSecretEnvVar) + imdsEndpoint := os.Getenv(imdsEndVar) + + if identityEndpoint != "" && identityHeader != "" { + if identityServerThumbprint != "" { + return ServiceFabric, nil + } + return AppService, nil + } else if msiEndpoint != "" { + if msiSecret != "" { + return AzureML, nil + } else { + return CloudShell, nil + } + } else if isAzureArcEnvironment(identityEndpoint, imdsEndpoint) { + return AzureArc, nil + } + + return DefaultToIMDS, nil +} + +// Acquires tokens from the configured managed identity on an azure resource. +// +// Resource: scopes application is requesting access to +// Options: [WithClaims] +func (c Client) AcquireToken(ctx context.Context, resource string, options ...AcquireTokenOption) (base.AuthResult, error) { + resource = strings.TrimSuffix(resource, "/.default") + o := AcquireTokenOptions{} + for _, option := range options { + option(&o) + } + c.authParams.Scopes = []string{resource} + + // ignore cached access tokens when given claims + if o.claims == "" { + storageTokenResponse, err := cacheManager.Read(ctx, c.authParams) + if err != nil { + return base.AuthResult{}, err + } + ar, err := base.AuthResultFromStorage(storageTokenResponse) + if err == nil { + ar.AccessToken, err = c.authParams.AuthnScheme.FormatAccessToken(ar.AccessToken) + return ar, err + } + } + switch c.source { + case AzureArc: + return c.acquireTokenForAzureArc(ctx, resource) + case AzureML: + return c.acquireTokenForAzureML(ctx, resource) + case CloudShell: + return c.acquireTokenForCloudShell(ctx, resource) + case DefaultToIMDS: + return c.acquireTokenForIMDS(ctx, resource) + case AppService: + return c.acquireTokenForAppService(ctx, resource) + case ServiceFabric: + return c.acquireTokenForServiceFabric(ctx, resource) + default: + return base.AuthResult{}, fmt.Errorf("unsupported source %q", c.source) + } +} + +func (c Client) acquireTokenForAppService(ctx context.Context, resource string) (base.AuthResult, error) { + req, err := createAppServiceAuthRequest(ctx, c.miType, resource) + if err != nil { + return base.AuthResult{}, err + } + tokenResponse, err := c.getTokenForRequest(req, resource) + if err != nil { + return base.AuthResult{}, err + } + return authResultFromToken(c.authParams, tokenResponse) +} + +func (c Client) acquireTokenForIMDS(ctx context.Context, resource string) (base.AuthResult, error) { + req, err := createIMDSAuthRequest(ctx, c.miType, resource) + if err != nil { + return base.AuthResult{}, err + } + tokenResponse, err := c.getTokenForRequest(req, resource) + if err != nil { + return base.AuthResult{}, err + } + return authResultFromToken(c.authParams, tokenResponse) +} + +func (c Client) acquireTokenForCloudShell(ctx context.Context, resource string) (base.AuthResult, error) { + req, err := createCloudShellAuthRequest(ctx, resource) + if err != nil { + return base.AuthResult{}, err + } + tokenResponse, err := c.getTokenForRequest(req, resource) + if err != nil { + return base.AuthResult{}, err + } + return authResultFromToken(c.authParams, tokenResponse) +} + +func (c Client) acquireTokenForAzureML(ctx context.Context, resource string) (base.AuthResult, error) { + req, err := createAzureMLAuthRequest(ctx, c.miType, resource) + if err != nil { + return base.AuthResult{}, err + } + tokenResponse, err := c.getTokenForRequest(req, resource) + if err != nil { + return base.AuthResult{}, err + } + return authResultFromToken(c.authParams, tokenResponse) +} + +func (c Client) acquireTokenForServiceFabric(ctx context.Context, resource string) (base.AuthResult, error) { + req, err := createServiceFabricAuthRequest(ctx, resource) + if err != nil { + return base.AuthResult{}, err + } + tokenResponse, err := c.getTokenForRequest(req, resource) + if err != nil { + return base.AuthResult{}, err + } + return authResultFromToken(c.authParams, tokenResponse) +} + +func (c Client) acquireTokenForAzureArc(ctx context.Context, resource string) (base.AuthResult, error) { + req, err := createAzureArcAuthRequest(ctx, resource, "") + if err != nil { + return base.AuthResult{}, err + } + + response, err := c.httpClient.Do(req) + if err != nil { + return base.AuthResult{}, err + } + defer response.Body.Close() + + if response.StatusCode != http.StatusUnauthorized { + return base.AuthResult{}, fmt.Errorf("expected a 401 response, received %d", response.StatusCode) + } + + secret, err := c.getAzureArcSecretKey(response, runtime.GOOS) + if err != nil { + return base.AuthResult{}, err + } + + secondRequest, err := createAzureArcAuthRequest(ctx, resource, string(secret)) + if err != nil { + return base.AuthResult{}, err + } + + tokenResponse, err := c.getTokenForRequest(secondRequest, resource) + if err != nil { + return base.AuthResult{}, err + } + return authResultFromToken(c.authParams, tokenResponse) +} + +func authResultFromToken(authParams authority.AuthParams, token accesstokens.TokenResponse) (base.AuthResult, error) { + if cacheManager == nil { + return base.AuthResult{}, errors.New("cache instance is nil") + } + account, err := cacheManager.Write(authParams, token) + if err != nil { + return base.AuthResult{}, err + } + ar, err := base.NewAuthResult(token, account) + if err != nil { + return base.AuthResult{}, err + } + ar.AccessToken, err = authParams.AuthnScheme.FormatAccessToken(ar.AccessToken) + return ar, err +} + +// contains checks if the element is present in the list. +func contains[T comparable](list []T, element T) bool { + for _, v := range list { + if v == element { + return true + } + } + return false +} + +// retry performs an HTTP request with retries based on the provided options. +func (c Client) retry(maxRetries int, req *http.Request) (*http.Response, error) { + var resp *http.Response + var err error + for attempt := 0; attempt < maxRetries; attempt++ { + tryCtx, tryCancel := context.WithTimeout(req.Context(), time.Minute) + defer tryCancel() + if resp != nil && resp.Body != nil { + _, _ = io.Copy(io.Discard, resp.Body) + resp.Body.Close() + } + cloneReq := req.Clone(tryCtx) + resp, err = c.httpClient.Do(cloneReq) + retrylist := retryStatusCodes + if c.source == DefaultToIMDS { + retrylist = retryCodesForIMDS + } + if err == nil && !contains(retrylist, resp.StatusCode) { + return resp, nil + } + select { + case <-time.After(time.Second): + case <-req.Context().Done(): + err = req.Context().Err() + return resp, err + } + } + return resp, err +} + +func (c Client) getTokenForRequest(req *http.Request, resource string) (accesstokens.TokenResponse, error) { + r := accesstokens.TokenResponse{} + var resp *http.Response + var err error + + if c.retryPolicyEnabled { + resp, err = c.retry(defaultRetryCount, req) + } else { + resp, err = c.httpClient.Do(req) + } + if err != nil { + return r, err + } + responseBytes, err := io.ReadAll(resp.Body) + defer resp.Body.Close() + if err != nil { + return r, err + } + switch resp.StatusCode { + case http.StatusOK, http.StatusAccepted: + default: + sd := strings.TrimSpace(string(responseBytes)) + if sd != "" { + return r, errors.CallErr{ + Req: req, + Resp: resp, + Err: fmt.Errorf("http call(%s)(%s) error: reply status code was %d:\n%s", + req.URL.String(), + req.Method, + resp.StatusCode, + sd), + } + } + return r, errors.CallErr{ + Req: req, + Resp: resp, + Err: fmt.Errorf("http call(%s)(%s) error: reply status code was %d", req.URL.String(), req.Method, resp.StatusCode), + } + } + + err = json.Unmarshal(responseBytes, &r) + if err != nil { + return r, errors.InvalidJsonErr{ + Err: fmt.Errorf("error parsing the json error: %s", err), + } + } + r.GrantedScopes.Slice = append(r.GrantedScopes.Slice, resource) + + return r, err +} + +func createAppServiceAuthRequest(ctx context.Context, id ID, resource string) (*http.Request, error) { + identityEndpoint := os.Getenv(identityEndpointEnvVar) + req, err := http.NewRequestWithContext(ctx, http.MethodGet, identityEndpoint, nil) + if err != nil { + return nil, err + } + req.Header.Set("X-IDENTITY-HEADER", os.Getenv(identityHeaderEnvVar)) + q := req.URL.Query() + q.Set("api-version", appServiceAPIVersion) + q.Set("resource", resource) + switch t := id.(type) { + case UserAssignedClientID: + q.Set(miQueryParameterClientId, string(t)) + case UserAssignedResourceID: + q.Set(miQueryParameterResourceId, string(t)) + case UserAssignedObjectID: + q.Set(miQueryParameterObjectId, string(t)) + case systemAssignedValue: + default: + return nil, fmt.Errorf("unsupported type %T", id) + } + req.URL.RawQuery = q.Encode() + return req, nil +} + +func createIMDSAuthRequest(ctx context.Context, id ID, resource string) (*http.Request, error) { + msiEndpoint, err := url.Parse(imdsDefaultEndpoint) + if err != nil { + return nil, fmt.Errorf("couldn't parse %q: %s", imdsDefaultEndpoint, err) + } + msiParameters := msiEndpoint.Query() + msiParameters.Set(apiVersionQueryParameterName, imdsAPIVersion) + msiParameters.Set(resourceQueryParameterName, resource) + + switch t := id.(type) { + case UserAssignedClientID: + msiParameters.Set(miQueryParameterClientId, string(t)) + case UserAssignedResourceID: + msiParameters.Set(miQueryParameterResourceIdIMDS, string(t)) + case UserAssignedObjectID: + msiParameters.Set(miQueryParameterObjectId, string(t)) + case systemAssignedValue: // not adding anything + default: + return nil, fmt.Errorf("unsupported type %T", id) + } + + msiEndpoint.RawQuery = msiParameters.Encode() + req, err := http.NewRequestWithContext(ctx, http.MethodGet, msiEndpoint.String(), nil) + if err != nil { + return nil, fmt.Errorf("error creating http request %s", err) + } + req.Header.Set(metaHTTPHeaderName, "true") + return req, nil +} + +func createAzureArcAuthRequest(ctx context.Context, resource string, key string) (*http.Request, error) { + identityEndpoint := os.Getenv(identityEndpointEnvVar) + if identityEndpoint == "" { + identityEndpoint = azureArcEndpoint + } + msiEndpoint, parseErr := url.Parse(identityEndpoint) + + if parseErr != nil { + return nil, fmt.Errorf("couldn't parse %q: %s", identityEndpoint, parseErr) + } + + msiParameters := msiEndpoint.Query() + msiParameters.Set(apiVersionQueryParameterName, azureArcAPIVersion) + msiParameters.Set(resourceQueryParameterName, resource) + + msiEndpoint.RawQuery = msiParameters.Encode() + req, err := http.NewRequestWithContext(ctx, http.MethodGet, msiEndpoint.String(), nil) + if err != nil { + return nil, fmt.Errorf("error creating http request %s", err) + } + req.Header.Set(metaHTTPHeaderName, "true") + + if key != "" { + req.Header.Set("Authorization", fmt.Sprintf("Basic %s", key)) + } + + return req, nil +} + +func isAzureArcEnvironment(identityEndpoint, imdsEndpoint string) bool { + if identityEndpoint != "" && imdsEndpoint != "" { + return true + } + himdsFilePath := getAzureArcHimdsFilePath(runtime.GOOS) + if himdsFilePath != "" { + if _, err := os.Stat(himdsFilePath); err == nil { + return true + } + } + return false +} + +func (c *Client) getAzureArcSecretKey(response *http.Response, platform string) (string, error) { + wwwAuthenticateHeader := response.Header.Get(wwwAuthenticateHeaderName) + + if len(wwwAuthenticateHeader) == 0 { + return "", errors.New("response has no www-authenticate header") + } + + // check if the platform is supported + expectedSecretFilePath := getAzureArcPlatformPath(platform) + if expectedSecretFilePath == "" { + return "", errors.New("platform not supported, expected linux or windows") + } + + parts := strings.Split(wwwAuthenticateHeader, "Basic realm=") + if len(parts) < 2 { + return "", fmt.Errorf("basic realm= not found in the string, instead found: %s", wwwAuthenticateHeader) + } + + secretFilePath := parts + + // check that the file in the file path is a .key file + fileName := filepath.Base(secretFilePath[1]) + if !strings.HasSuffix(fileName, azureArcFileExtension) { + return "", fmt.Errorf("invalid file extension, expected %s, got %s", azureArcFileExtension, filepath.Ext(fileName)) + } + + // check that file path from header matches the expected file path for the platform + if expectedSecretFilePath != filepath.Dir(secretFilePath[1]) { + return "", fmt.Errorf("invalid file path, expected %s, got %s", expectedSecretFilePath, filepath.Dir(secretFilePath[1])) + } + + fileInfo, err := os.Stat(secretFilePath[1]) + if err != nil { + return "", fmt.Errorf("failed to get metadata for %s due to error: %s", secretFilePath[1], err) + } + + // Throw an error if the secret file's size is greater than 4096 bytes + if s := fileInfo.Size(); s > azureArcMaxFileSizeBytes { + return "", fmt.Errorf("invalid secret file size, expected %d, file size was %d", azureArcMaxFileSizeBytes, s) + } + + // Attempt to read the contents of the secret file + secret, err := os.ReadFile(secretFilePath[1]) + if err != nil { + return "", fmt.Errorf("failed to read %q due to error: %s", secretFilePath[1], err) + } + + return string(secret), nil +} diff --git a/apps/managedidentity/managedidentity_test.go b/apps/managedidentity/managedidentity_test.go new file mode 100644 index 00000000..ce9e8cfc --- /dev/null +++ b/apps/managedidentity/managedidentity_test.go @@ -0,0 +1,1133 @@ +// Copyright (c) Microsoft Corporation. +// Licensed under the MIT license. +package managedidentity + +import ( + "bytes" + "context" + "encoding/json" + "fmt" + "io" + "net/http" + "net/url" + "os" + "path/filepath" + "strings" + "testing" + "time" + + "github.com/AzureAD/microsoft-authentication-library-for-go/apps/errors" + "github.com/AzureAD/microsoft-authentication-library-for-go/apps/internal/base" + "github.com/AzureAD/microsoft-authentication-library-for-go/apps/internal/base/storage" + "github.com/AzureAD/microsoft-authentication-library-for-go/apps/internal/mock" +) + +const ( + // Test Resources + resource = "https://management.azure.com" + resourceDefaultSuffix = "https://management.azure.com/.default" + token = "fake-access-token" + fakeAzureArcFilePath = "fake/fake" + secretKey = "secret.key" + basicRealm = "Basic realm=" + + errorExpectedButGot = "expected %v, got %v" + errorFormingJsonResponse = "error while forming json response : %s" +) + +type SuccessfulResponse struct { + AccessToken string `json:"access_token"` + ExpiresIn int64 `json:"expires_in,omitempty"` + ExpiresOn int64 `json:"expires_on,omitempty"` + Resource string `json:"resource"` + TokenType string `json:"token_type"` +} + +type ErrorResponse struct { + Err string `json:"error"` + Desc string `json:"error_description"` +} + +func getSuccessfulResponse(resource string, doesHaveExpireIn bool) ([]byte, error) { + var response SuccessfulResponse + if doesHaveExpireIn { + duration := 10 * time.Minute + expiresIn := duration.Seconds() + response = SuccessfulResponse{ + AccessToken: token, + ExpiresIn: int64(expiresIn), + Resource: resource, + TokenType: "Bearer", + } + } else { + response = SuccessfulResponse{ + AccessToken: token, + ExpiresOn: time.Now().Add(time.Hour).Unix(), + Resource: resource, + TokenType: "Bearer", + } + } + jsonResponse, err := json.Marshal(response) + return jsonResponse, err +} + +func makeResponseWithErrorData(err string, desc string) ([]byte, error) { + responseBody := ErrorResponse{ + Err: err, + Desc: desc, + } + jsonResponse, e := json.Marshal(responseBody) + return jsonResponse, e +} + +func createMockFile(t *testing.T, path string, size int64) { + dir := filepath.Dir(path) + if err := os.MkdirAll(dir, 0755); err != nil { + t.Fatalf("failed to create directory: %v", err) + } + + f, err := os.Create(path) + if err != nil { + t.Fatalf("failed to create file: %v", err) + } + defer f.Close() + + if size > 0 { + if err := f.Truncate(size); err != nil { + t.Fatalf("failed to truncate file: %v", err) + } + } + + // Write the content to the file + if _, err := f.WriteString("secret file data"); err != nil { + t.Fatalf("failed to write to file: %v", err) + } + t.Cleanup(func() { os.Remove(path) }) +} + +func setEnvVars(t *testing.T, source Source) { + switch source { + case AzureArc: + t.Setenv(identityEndpointEnvVar, "http://127.0.0.1:40342/metadata/identity/oauth2/token") + t.Setenv(imdsEndVar, "http://169.254.169.254/metadata/identity/oauth2/token") + case AppService: + t.Setenv(identityEndpointEnvVar, "http://127.0.0.1:41564/msi/token") + t.Setenv(identityHeaderEnvVar, "secret") + case CloudShell: + t.Setenv(msiEndpointEnvVar, "http://localhost:50342/oauth2/token") + case ServiceFabric: + t.Setenv(identityEndpointEnvVar, "http://localhost:40342/metadata/identity/oauth2/token") + t.Setenv(identityHeaderEnvVar, "secret") + t.Setenv(identityServerThumbprintEnvVar, "thumbprint") + case AzureML: + t.Setenv(msiEndpointEnvVar, "http://127.0.0.1:41564/msi/token") + t.Setenv(msiSecretEnvVar, "redacted") + } +} + +func setCustomAzureArcPlatformPath(t *testing.T, path string) { + originalFunc := getAzureArcPlatformPath + getAzureArcPlatformPath = func(string) string { + return path + } + + t.Cleanup(func() { getAzureArcPlatformPath = originalFunc }) +} + +func setCustomAzureArcFilePath(t *testing.T, path string) { + originalFunc := getAzureArcHimdsFilePath + getAzureArcHimdsFilePath = func(string) string { + return path + } + + t.Cleanup(func() { getAzureArcHimdsFilePath = originalFunc }) +} + +func TestSource(t *testing.T) { + for _, testCase := range []Source{AzureArc, DefaultToIMDS, CloudShell, AzureML, AppService} { + t.Run(string(testCase), func(t *testing.T) { + setEnvVars(t, testCase) + setCustomAzureArcFilePath(t, fakeAzureArcFilePath) + + actualSource, err := GetSource() + if err != nil { + t.Fatalf("error while getting source: %s", err.Error()) + } + if actualSource != testCase { + t.Fatalf(errorExpectedButGot, testCase, actualSource) + } + }) + } +} + +func TestRetryFunction(t *testing.T) { + tests := []struct { + name string + mockResponses []struct { + body string + statusCode int + } + expectedStatus int + expectedBody string + maxRetries int + source Source + }{ + { + name: "Successful Request", + mockResponses: []struct { + body string + statusCode int + }{ + {"Failed", http.StatusInternalServerError}, + {"Success", http.StatusOK}, + }, + expectedStatus: http.StatusOK, + expectedBody: "Success", + maxRetries: 3, + source: AzureArc, + }, + { + name: "Successful Request", + mockResponses: []struct { + body string + statusCode int + }{ + {"Failed", http.StatusNotFound}, + {"Success", http.StatusOK}, + }, + expectedStatus: http.StatusOK, + expectedBody: "Success", + maxRetries: 3, + source: DefaultToIMDS, + }, + { + name: "Max Retries Reached", + mockResponses: []struct { + body string + statusCode int + }{ + {"Error", http.StatusInternalServerError}, + {"Error", http.StatusInternalServerError}, + }, + expectedStatus: http.StatusInternalServerError, + expectedBody: "Error", + maxRetries: 2, + source: AzureArc, + }, + { + name: "Max Retries Reached", + mockResponses: []struct { + body string + statusCode int + }{ + {"Error", http.StatusNotFound}, + {"Error", http.StatusInternalServerError}, + }, + expectedStatus: http.StatusInternalServerError, + expectedBody: "Error", + maxRetries: 2, + source: DefaultToIMDS, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + mockClient := &mock.Client{} + for _, resp := range tt.mockResponses { + body := bytes.NewBufferString(resp.body) + mockClient.AppendResponse(mock.WithBody(body.Bytes()), mock.WithHTTPStatusCode(resp.statusCode)) + } + client, err := New(SystemAssigned(), WithHTTPClient(mockClient), WithRetryPolicyDisabled()) + if err != nil { + t.Fatal(err) + } + reqBody := bytes.NewBufferString("Test Body") + req, err := http.NewRequest("POST", "https://example.com", reqBody) + if err != nil { + t.Fatal(err) + } + finalResp, err := client.retry(tt.maxRetries, req) + if err != nil { + t.Fatal(err) + } + if finalResp.StatusCode != tt.expectedStatus { + t.Fatalf("Expected status code %d, got %d", tt.expectedStatus, finalResp.StatusCode) + } + bodyBytes, err := io.ReadAll(finalResp.Body) + if err != nil { + t.Fatalf("Failed to read response body: %v", err) + } + finalResp.Body.Close() + if string(bodyBytes) != tt.expectedBody { + t.Fatalf("Expected body %q, got %q", tt.expectedBody, bodyBytes) + } + }) + } +} + +func Test_RetryPolicy_For_AcquireToken(t *testing.T) { + testCases := []struct { + numberOfFails int + expectedFail bool + disableRetry bool + }{ + {numberOfFails: 1, expectedFail: false, disableRetry: false}, + {numberOfFails: 1, expectedFail: true, disableRetry: true}, + {numberOfFails: 1, expectedFail: true, disableRetry: true}, + {numberOfFails: 2, expectedFail: false, disableRetry: false}, + {numberOfFails: 3, expectedFail: true, disableRetry: false}, + } + for _, testCase := range testCases { + t.Run(fmt.Sprintf("Testing retry policy with %d ", testCase.numberOfFails), func(t *testing.T) { + fakeErrorClient := mock.Client{} + responseBody, err := makeResponseWithErrorData("sample error", "sample error desc") + if err != nil { + t.Fatalf("error while forming json response : %s", err.Error()) + } + errorRetryCounter := 0 + for i := 0; i < testCase.numberOfFails; i++ { + fakeErrorClient.AppendResponse(mock.WithHTTPStatusCode(http.StatusInternalServerError), + mock.WithBody(responseBody), mock.WithCallback(func(r *http.Request) { + errorRetryCounter++ + })) + } + if !testCase.expectedFail { + successRespBody, err := getSuccessfulResponse(resource, true) + if err != nil { + t.Fatalf("error while forming json response : %s", err.Error()) + } + fakeErrorClient.AppendResponse(mock.WithHTTPStatusCode(http.StatusAccepted), + mock.WithBody(successRespBody)) + } + var client Client + if testCase.disableRetry { + client, err = New(SystemAssigned(), WithHTTPClient(&fakeErrorClient), WithRetryPolicyDisabled()) + } else { + client, err = New(SystemAssigned(), WithHTTPClient(&fakeErrorClient)) + } + if err != nil { + t.Fatal(err) + } + resp, err := client.AcquireToken(context.Background(), resource, WithClaims("noCache")) + if testCase.expectedFail { + if err == nil { + t.Fatalf("should have encountered the error") + } + if resp.AccessToken != "" { + t.Fatalf("accesstoken should be empty") + } + } else { + if err != nil { + t.Fatal(err) + } + if resp.AccessToken != token { + t.Fatalf("wanted %q, got %q", token, resp.AccessToken) + } + } + if testCase.disableRetry { + if errorRetryCounter != 1 { + t.Fatalf("expected Number of retry of 1, got %d", errorRetryCounter) + } + } else if errorRetryCounter != testCase.numberOfFails && testCase.numberOfFails < 3 { + t.Fatalf("expected Number of retry of %d, got %d", testCase.numberOfFails, errorRetryCounter) + } + }) + } +} + +func TestCacheScopes(t *testing.T) { + before := cacheManager + defer func() { cacheManager = before }() + cacheManager = storage.New(nil) + + mc := mock.Client{} + client, err := New(SystemAssigned(), WithHTTPClient(&mc)) + if err != nil { + t.Fatal(err) + } + + for _, r := range []string{"A", "B/.default"} { + mc.AppendResponse(mock.WithBody(mock.GetAccessTokenBody(r, "", "", "", 3600))) + for i := 0; i < 2; i++ { + ar, err := client.AcquireToken(context.Background(), r) + if err != nil { + t.Fatal(err) + } + if ar.AccessToken != r { + t.Fatalf("expected %q, got %q", r, ar.AccessToken) + } + } + } +} + +func TestAzureArcReturnsWhenHimdsFound(t *testing.T) { + mockFilePath := filepath.Join(t.TempDir(), "himds") + setCustomAzureArcFilePath(t, mockFilePath) + + // Create the mock himds file + createMockFile(t, mockFilePath, 1024) + + actualSource, err := GetSource() + if err != nil { + t.Fatalf("error while getting source: %s", err.Error()) + } + + if actualSource != AzureArc { + t.Fatalf(errorExpectedButGot, AzureArc, actualSource) + } +} + +func TestIMDSAcquireTokenReturnsTokenSuccess(t *testing.T) { + testCases := []struct { + resource string + miType ID + }{ + {resource: resource, miType: SystemAssigned()}, + {resource: resourceDefaultSuffix, miType: SystemAssigned()}, + {resource: resource, miType: UserAssignedClientID("clientId")}, + {resource: resourceDefaultSuffix, miType: UserAssignedResourceID("resourceId")}, + {resource: resourceDefaultSuffix, miType: UserAssignedObjectID("objectId")}, + } + for _, testCase := range testCases { + t.Run(string(DefaultToIMDS)+"-"+testCase.miType.value(), func(t *testing.T) { + endpoint := imdsDefaultEndpoint + + var localUrl *url.URL + mockClient := mock.Client{} + responseBody, err := getSuccessfulResponse(resource, true) + if err != nil { + t.Fatalf(errorFormingJsonResponse, err.Error()) + } + + mockClient.AppendResponse(mock.WithHTTPStatusCode(http.StatusOK), mock.WithBody(responseBody), mock.WithCallback(func(r *http.Request) { + localUrl = r.URL + })) + // resetting cache + before := cacheManager + defer func() { cacheManager = before }() + cacheManager = storage.New(nil) + + client, err := New(testCase.miType, WithHTTPClient(&mockClient)) + if err != nil { + t.Fatal(err) + } + result, err := client.AcquireToken(context.Background(), testCase.resource) + if err != nil { + t.Fatal(err) + } + if localUrl == nil || !strings.HasPrefix(localUrl.String(), endpoint) { + t.Fatalf("url request is not on %s got %s", endpoint, localUrl) + } + query := localUrl.Query() + + if query.Get(apiVersionQueryParameterName) != imdsAPIVersion { + t.Fatalf("api-version not on %s got %s", imdsAPIVersion, query.Get(apiVersionQueryParameterName)) + } + if query.Get(resourceQueryParameterName) != strings.TrimSuffix(testCase.resource, "/.default") { + t.Fatal("suffix /.default was not removed.") + } + switch i := testCase.miType.(type) { + case UserAssignedClientID: + if query.Get(miQueryParameterClientId) != i.value() { + t.Fatalf("resource client-id is incorrect, wanted %s got %s", i.value(), query.Get(miQueryParameterClientId)) + } + case UserAssignedResourceID: + if query.Get(miQueryParameterResourceIdIMDS) != i.value() { + t.Fatalf("resource resource-id is incorrect, wanted %s got %s", i.value(), query.Get(miQueryParameterResourceIdIMDS)) + } + case UserAssignedObjectID: + if query.Get(miQueryParameterObjectId) != i.value() { + t.Fatalf("resource objectid is incorrect, wanted %s got %s", i.value(), query.Get(miQueryParameterObjectId)) + } + } + if result.Metadata.TokenSource != base.IdentityProvider { + t.Fatalf("expected IndenityProvider tokensource, got %d", result.Metadata.TokenSource) + } + if result.AccessToken != token { + t.Fatalf("wanted %q, got %q", token, result.AccessToken) + } + result, err = client.AcquireToken(context.Background(), testCase.resource) + if err != nil { + t.Fatal(err) + } + if result.Metadata.TokenSource != base.Cache { + t.Fatalf("wanted cache token source, got %d", result.Metadata.TokenSource) + } + secondFakeClient, err := New(testCase.miType, WithHTTPClient(&mockClient)) + if err != nil { + t.Fatal(err) + } + result, err = secondFakeClient.AcquireToken(context.Background(), testCase.resource) + if err != nil { + t.Fatal(err) + } + if result.Metadata.TokenSource != base.Cache { + t.Fatalf("cache result wanted cache token source, got %d", result.Metadata.TokenSource) + } + }) + } +} + +func TestInvalidJsonErrReturnOnAcquireToken(t *testing.T) { + resource := "https://resource/.default" + miType := SystemAssigned() + + setEnvVars(t, DefaultToIMDS) + mockClient := mock.Client{} + responseBody := fmt.Sprintf( + `{"access_token": "%s","expires_in": %d,"expires_on": %d,"token_type": "Bearer"`, + "tetant", 3600, time.Now().Add(time.Duration(3600)*time.Second).Unix(), + ) + + mockClient.AppendResponse(mock.WithHTTPStatusCode(http.StatusOK), mock.WithBody([]byte(responseBody))) + + // Resetting cache + before := cacheManager + defer func() { cacheManager = before }() + cacheManager = storage.New(nil) + + client, err := New(miType, WithHTTPClient(&mockClient)) + if err != nil { + t.Fatal(err) + } + + _, err = client.AcquireToken(context.Background(), resource) + if err == nil { + t.Fatal("should have failed with InvalidJsonErr Response") + } + var ie errors.InvalidJsonErr + if !errors.As(err, &ie) { + t.Fatal("should have revieved a InvalidJsonErr, but got", err) + } +} + +func TestCloudShellAcquireTokenReturnsTokenSuccess(t *testing.T) { + resource := "https://resource/.default" + miType := SystemAssigned() + + setEnvVars(t, CloudShell) + endpoint := os.Getenv(msiEndpointEnvVar) + + var localUrl *url.URL + var resourceString string + mockClient := mock.Client{} + responseBody, err := getSuccessfulResponse(resource, false) + if err != nil { + t.Fatalf(errorFormingJsonResponse, err.Error()) + } + + mockClient.AppendResponse(mock.WithHTTPStatusCode(http.StatusOK), mock.WithBody(responseBody), mock.WithCallback(func(r *http.Request) { + localUrl = r.URL + err = r.ParseForm() + if err != nil { + t.Fatal(err) + } + resourceString = r.FormValue(resourceQueryParameterName) + })) + + // Resetting cache + before := cacheManager + defer func() { cacheManager = before }() + cacheManager = storage.New(nil) + + client, err := New(miType, WithHTTPClient(&mockClient)) + if err != nil { + t.Fatal(err) + } + + result, err := client.AcquireToken(context.Background(), resource) + if err != nil { + t.Fatal(err) + } + + if localUrl == nil || !strings.HasPrefix(localUrl.String(), endpoint) { + t.Fatalf("url request is not on %s got %s", endpoint, localUrl) + } + + if resourceString != strings.TrimSuffix(resource, "/.default") { + t.Fatal("suffix /.default was not removed.") + } + + if result.Metadata.TokenSource != base.IdentityProvider { + t.Fatalf("expected IdentityProvider tokensource, got %d", result.Metadata.TokenSource) + } + + if result.AccessToken != token { + t.Fatalf("wanted %q, got %q", token, result.AccessToken) + } + + result, err = client.AcquireToken(context.Background(), resource) + if err != nil { + t.Fatal(err) + } + + if result.Metadata.TokenSource != base.Cache { + t.Fatalf("wanted cache token source, got %d", result.Metadata.TokenSource) + } + + secondFakeClient, err := New(miType, WithHTTPClient(&mockClient)) + if err != nil { + t.Fatal(err) + } + + result, err = secondFakeClient.AcquireToken(context.Background(), resource) + if err != nil { + t.Fatal(err) + } + + if result.Metadata.TokenSource != base.Cache { + t.Fatalf("cache result wanted cache token source, got %d", result.Metadata.TokenSource) + } +} + +func TestCloudShellOnlySystemAssignedSupported(t *testing.T) { + setEnvVars(t, CloudShell) + mockClient := mock.Client{} + + for _, testCase := range []ID{ + UserAssignedClientID("client"), + UserAssignedObjectID("ObjectId"), + UserAssignedResourceID("resourceid"), + } { + _, err := New(testCase, WithHTTPClient(&mockClient)) + if err == nil { + t.Fatal(`expected error: CloudShell not supported error"`) + + } + if err.Error() != "Cloud Shell doesn't support user-assigned managed identities" { + t.Fatalf("expected error: Cloud Shell doesn't support user-assigned managed identities, got error: %q", err) + } + + } +} + +func TestAppServiceAcquireTokenReturnsTokenSuccess(t *testing.T) { + setEnvVars(t, AppService) + testCases := []struct { + resource string + miType ID + }{ + {resource: resource, miType: SystemAssigned()}, + {resource: resourceDefaultSuffix, miType: SystemAssigned()}, + {resource: resource, miType: UserAssignedClientID("clientId")}, + {resource: resourceDefaultSuffix, miType: UserAssignedResourceID("resourceId")}, + {resource: resourceDefaultSuffix, miType: UserAssignedObjectID("objectId")}, + } + for _, testCase := range testCases { + t.Run(string(AppService)+"-"+testCase.miType.value(), func(t *testing.T) { + endpoint := "http://127.0.0.1:41564/msi/token" + + var localUrl *url.URL + mockClient := mock.Client{} + responseBody, err := getSuccessfulResponse(resource, false) + if err != nil { + t.Fatalf(errorFormingJsonResponse, err.Error()) + } + mockClient.AppendResponse(mock.WithHTTPStatusCode(http.StatusOK), + mock.WithBody(responseBody), + mock.WithCallback(func(r *http.Request) { + localUrl = r.URL + })) + // resetting cache + before := cacheManager + defer func() { cacheManager = before }() + cacheManager = storage.New(nil) + + client, err := New(testCase.miType, WithHTTPClient(&mockClient)) + if err != nil { + t.Fatal(err) + } + result, err := client.AcquireToken(context.Background(), testCase.resource) + if err != nil { + t.Fatal(err) + } + if localUrl == nil || !strings.HasPrefix(localUrl.String(), endpoint) { + t.Fatalf("url request is not on %s got %s", endpoint, localUrl) + } + query := localUrl.Query() + + if query.Get(apiVersionQueryParameterName) != appServiceAPIVersion { + t.Fatalf("api-version not on %s got %s", appServiceAPIVersion, query.Get(apiVersionQueryParameterName)) + } + if r := query.Get(resourceQueryParameterName); strings.HasSuffix(r, "/.default") { + t.Fatal("suffix /.default was not removed.") + } + switch i := testCase.miType.(type) { + case UserAssignedClientID: + if actual := query.Get(miQueryParameterClientId); actual != i.value() { + t.Fatalf("resource client-id is incorrect, wanted %s got %s", i.value(), actual) + } + case UserAssignedResourceID: + if query.Get(miQueryParameterResourceId) != i.value() { + t.Fatalf("resource resource id is incorrect, wanted %s got %s", i.value(), query.Get(miQueryParameterResourceId)) + } + case UserAssignedObjectID: + if query.Get(miQueryParameterObjectId) != i.value() { + t.Fatalf("resource objectid is incorrect, wanted %s got %s", i.value(), query.Get(miQueryParameterObjectId)) + } + } + if result.Metadata.TokenSource != base.IdentityProvider { + t.Fatalf("expected IndenityProvider tokensource, got %d", result.Metadata.TokenSource) + } + if result.AccessToken != token { + t.Fatalf("wanted %q, got %q", token, result.AccessToken) + } + result, err = client.AcquireToken(context.Background(), testCase.resource) + if err != nil { + t.Fatal(err) + } + if result.Metadata.TokenSource != base.Cache { + t.Fatalf("wanted cache token source, got %d", result.Metadata.TokenSource) + } + secondFakeClient, err := New(testCase.miType, WithHTTPClient(&mockClient)) + if err != nil { + t.Fatal(err) + } + result, err = secondFakeClient.AcquireToken(context.Background(), testCase.resource) + if err != nil { + t.Fatal(err) + } + if result.Metadata.TokenSource != base.Cache { + t.Fatalf("cache result wanted cache token source, got %d", result.Metadata.TokenSource) + } + }) + } +} + +func TestAzureMLAcquireTokenReturnsTokenSuccess(t *testing.T) { + defaultClientID := "A" + t.Setenv("DEFAULT_IDENTITY_CLIENT_ID", defaultClientID) + + setEnvVars(t, AzureML) + testCases := []struct { + resource string + miType ID + expectedClientID string + }{ + {resource: resource, miType: SystemAssigned(), expectedClientID: defaultClientID}, + {resource: resourceDefaultSuffix, miType: SystemAssigned(), expectedClientID: defaultClientID}, + {resource: resource, miType: UserAssignedClientID("B"), expectedClientID: "B"}, + } + for _, testCase := range testCases { + t.Run(string(AzureML)+"-"+testCase.miType.value(), func(t *testing.T) { + endpoint := "http://127.0.0.1:41564/msi/token" + + var localUrl *url.URL + mockClient := mock.Client{} + responseBody, err := getSuccessfulResponse(resource, false) + if err != nil { + t.Fatalf(errorFormingJsonResponse, err.Error()) + } + mockClient.AppendResponse(mock.WithHTTPStatusCode(http.StatusOK), + mock.WithBody(responseBody), + mock.WithCallback(func(r *http.Request) { + localUrl = r.URL + })) + // resetting cache + before := cacheManager + defer func() { cacheManager = before }() + cacheManager = storage.New(nil) + + client, err := New(testCase.miType, WithHTTPClient(&mockClient)) + if err != nil { + t.Fatal(err) + } + result, err := client.AcquireToken(context.Background(), testCase.resource) + if err != nil { + t.Fatal(err) + } + if localUrl == nil || !strings.HasPrefix(localUrl.String(), endpoint) { + t.Fatalf("url request is not on %s got %s", endpoint, localUrl) + } + query := localUrl.Query() + + if query.Get(apiVersionQueryParameterName) != azureMLAPIVersion { + t.Fatalf("api-version not on %s got %s", azureMLAPIVersion, query.Get(apiVersionQueryParameterName)) + } + if r := query.Get(resourceQueryParameterName); strings.HasSuffix(r, "/.default") { + t.Fatal("suffix /.default was not removed.") + } + if result.Metadata.TokenSource != base.IdentityProvider { + t.Fatalf("expected IdentityProvider tokensource, got %d", result.Metadata.TokenSource) + } + if result.AccessToken != token { + t.Fatalf("wanted %q, got %q", token, result.AccessToken) + } + if actual := query.Get("clientid"); actual != testCase.expectedClientID { + t.Fatalf("expected clientid to be set to %s, got %s", testCase.expectedClientID, actual) + } + + result, err = client.AcquireToken(context.Background(), testCase.resource) + if err != nil { + t.Fatal(err) + } + if result.Metadata.TokenSource != base.Cache { + t.Fatalf("wanted cache token source, got %d", result.Metadata.TokenSource) + } + secondFakeClient, err := New(testCase.miType, WithHTTPClient(&mockClient)) + if err != nil { + t.Fatal(err) + } + result, err = secondFakeClient.AcquireToken(context.Background(), testCase.resource) + if err != nil { + t.Fatal(err) + } + if result.Metadata.TokenSource != base.Cache { + t.Fatalf("cache result wanted cache token source, got %d", result.Metadata.TokenSource) + } + }) + } +} + +func TestAzureMLErrors(t *testing.T) { + setEnvVars(t, AzureML) + mockClient := mock.Client{} + + for _, testCase := range []ID{ + UserAssignedObjectID("ObjectId"), + UserAssignedResourceID("resourceid")} { + _, err := New(testCase, WithHTTPClient(&mockClient)) + if err == nil { + t.Fatal("expected error: Azure ML supports specifying a user-assigned managed identity by client ID only") + + } + if err.Error() != "Azure ML supports specifying a user-assigned managed identity by client ID only" { + t.Fatalf("expected error: Azure ML supports specifying a user-assigned managed identity by client ID only, got error: %q", err) + } + + } +} + +func TestAzureArc(t *testing.T) { + testCaseFilePath := filepath.Join(t.TempDir(), azureConnectedMachine) + + endpoint := azureArcEndpoint + setEnvVars(t, AzureArc) + setCustomAzureArcFilePath(t, fakeAzureArcFilePath) + + var localUrl *url.URL + mockClient := mock.Client{} + + mockFilePath := filepath.Join(testCaseFilePath, secretKey) + setCustomAzureArcPlatformPath(t, testCaseFilePath) + + createMockFile(t, mockFilePath, 0) + + headers := http.Header{} + headers.Set(wwwAuthenticateHeaderName, basicRealm+filepath.Join(testCaseFilePath, secretKey)) + + mockClient.AppendResponse(mock.WithHTTPStatusCode(http.StatusUnauthorized), + mock.WithHTTPHeader(headers), + mock.WithCallback(func(r *http.Request) { + localUrl = r.URL + })) + + responseBody, err := getSuccessfulResponse(resource, true) + if err != nil { + t.Fatalf(errorFormingJsonResponse, err.Error()) + } + mockClient.AppendResponse(mock.WithHTTPStatusCode(http.StatusOK), mock.WithHTTPHeader(headers), + mock.WithBody(responseBody), mock.WithCallback(func(r *http.Request) { + localUrl = r.URL + })) + + // resetting cache + before := cacheManager + defer func() { cacheManager = before }() + cacheManager = storage.New(nil) + + client, err := New(SystemAssigned(), WithHTTPClient(&mockClient)) + if err != nil { + t.Fatal(err) + } + result, err := client.AcquireToken(context.Background(), resourceDefaultSuffix) + if err != nil { + t.Fatal(err) + } + + if localUrl == nil || !strings.HasPrefix(localUrl.String(), endpoint) { + t.Fatalf("url request is not on %s got %s", endpoint, localUrl) + } + + query := localUrl.Query() + + if query.Get(apiVersionQueryParameterName) != azureArcAPIVersion { + t.Fatalf("api-version not on %s got %s", azureArcAPIVersion, query.Get(apiVersionQueryParameterName)) + } + if query.Get(resourceQueryParameterName) != strings.TrimSuffix(resourceDefaultSuffix, "/.default") { + t.Fatal("suffix /.default was not removed.") + } + if result.Metadata.TokenSource != base.IdentityProvider { + t.Fatalf("expected IndenityProvider tokensource, got %d", result.Metadata.TokenSource) + } + if result.AccessToken != token { + t.Fatalf("wanted %q, got %q", token, result.AccessToken) + } + result, err = client.AcquireToken(context.Background(), resource) + if err != nil { + t.Fatal(err) + } + if result.Metadata.TokenSource != base.Cache { + t.Fatalf("wanted cache token source, got %d", result.Metadata.TokenSource) + } + secondFakeClient, err := New(SystemAssigned(), WithHTTPClient(&mockClient)) + if err != nil { + t.Fatal(err) + } + result, err = secondFakeClient.AcquireToken(context.Background(), resource) + if err != nil { + t.Fatal(err) + } + if result.Metadata.TokenSource != base.Cache { + t.Fatalf("cache result wanted cache token source, got %d", result.Metadata.TokenSource) + } + +} + +func TestAzureArcOnlySystemAssignedSupported(t *testing.T) { + setEnvVars(t, AzureArc) + mockClient := mock.Client{} + + setCustomAzureArcFilePath(t, fakeAzureArcFilePath) + for _, testCase := range []ID{ + UserAssignedClientID("client"), + UserAssignedObjectID("ObjectId"), + UserAssignedResourceID("resourceid")} { + _, err := New(testCase, WithHTTPClient(&mockClient)) + if err == nil { + t.Fatal(`expected error: AzureArc not supported error"`) + + } + if err.Error() != "Azure Arc doesn't support user-assigned managed identities" { + t.Fatalf("expected error: AzureArc not supported error, got error: %q", err) + } + + } +} + +func TestAzureArcPlatformSupported(t *testing.T) { + setEnvVars(t, AzureArc) + setCustomAzureArcFilePath(t, fakeAzureArcFilePath) + before := cacheManager + defer func() { cacheManager = before }() + cacheManager = storage.New(nil) + + mockClient := mock.Client{} + headers := http.Header{} + headers.Set(wwwAuthenticateHeaderName, "Basic realm=/path/to/secret.key") + + mockClient.AppendResponse(mock.WithHTTPStatusCode(http.StatusUnauthorized), + mock.WithHTTPHeader(headers), + ) + setCustomAzureArcPlatformPath(t, "") + + client, err := New(SystemAssigned(), WithHTTPClient(&mockClient)) + if err != nil { + t.Fatal(err) + } + result, err := client.AcquireToken(context.Background(), resource) + if err == nil || !strings.Contains(err.Error(), "platform not supported") { + t.Fatalf(`expected error: "%v" got error: "%v"`, "platform not supported", err) + + } + if result.AccessToken != "" { + t.Fatalf("access token should be empty") + } +} + +func TestAzureArcErrors(t *testing.T) { + setEnvVars(t, AzureArc) + setCustomAzureArcFilePath(t, fakeAzureArcFilePath) + testCaseFilePath := filepath.Join(t.TempDir(), "AzureConnectedMachineAgent") + + testCases := []struct { + name string + headerValue string + expectedError string + fileSize int64 + }{ + { + name: "No www-authenticate header", + expectedError: "response has no www-authenticate header", + }, + { + name: "Basic realm= not found", + headerValue: "Basic ", + expectedError: "basic realm= not found in the string, instead found: Basic ", + }, + { + name: "Invalid file extension", + headerValue: "Basic realm=/path/to/secret.txt", + expectedError: "invalid file extension, expected .key, got .txt", + }, + { + name: "Invalid file path", + headerValue: "Basic realm=" + filepath.Join("path", "to", secretKey), + expectedError: "invalid file path, expected " + testCaseFilePath + ", got " + filepath.Join("path", "to"), + }, + { + name: "Unable to get file info", + headerValue: basicRealm + filepath.Join(testCaseFilePath, "2secret.key"), + expectedError: "failed to get metadata", + }, + { + name: "Invalid secret file size", + headerValue: basicRealm + filepath.Join(testCaseFilePath, secretKey), + expectedError: "invalid secret file size, expected 4096, file size was 5000", + fileSize: 5000, + }, + } + + for _, testCase := range testCases { + t.Run(string(testCase.name), func(t *testing.T) { + before := cacheManager + defer func() { cacheManager = before }() + cacheManager = storage.New(nil) + mockClient := mock.Client{} + mockFilePath := filepath.Join(testCaseFilePath, secretKey) + setCustomAzureArcPlatformPath(t, testCaseFilePath) + createMockFile(t, mockFilePath, testCase.fileSize) + headers := http.Header{} + headers.Set(wwwAuthenticateHeaderName, testCase.headerValue) + + mockClient.AppendResponse(mock.WithHTTPStatusCode(http.StatusUnauthorized), + mock.WithHTTPHeader(headers), + ) + + responseBody, err := getSuccessfulResponse(resource, true) + if err != nil { + t.Fatalf(errorFormingJsonResponse, err.Error()) + } + + mockClient.AppendResponse(mock.WithHTTPStatusCode(http.StatusOK), mock.WithHTTPHeader(headers), + mock.WithBody(responseBody)) + + client, err := New(SystemAssigned(), WithHTTPClient(&mockClient)) + if err != nil { + t.Fatal(err) + return + } + result, err := client.AcquireToken(context.Background(), resource) + if err == nil || !strings.Contains(err.Error(), testCase.expectedError) { + t.Fatalf(`expected error: "%v" got error: "%v"`, testCase.expectedError, err) + + } + if result.AccessToken != "" { + t.Fatal("access token should be empty") + } + }) + } +} + +func TestSystemAssignedReturnsAcquireTokenFailure(t *testing.T) { + testCases := []struct { + code int + err string + desc string + }{ + {code: http.StatusNotFound}, + {code: http.StatusNotImplemented}, + {code: http.StatusServiceUnavailable}, + {code: http.StatusBadRequest, + err: "invalid_request", + desc: "Identity not found", + }, + } + + for _, testCase := range testCases { + t.Run(http.StatusText(testCase.code), func(t *testing.T) { + setCustomAzureArcFilePath(t, fakeAzureArcFilePath) + before := cacheManager + defer func() { cacheManager = before }() + cacheManager = storage.New(nil) + fakeErrorClient := mock.Client{} + responseBody, err := makeResponseWithErrorData(testCase.err, testCase.desc) + if err != nil { + t.Fatalf(errorFormingJsonResponse, err.Error()) + } + fakeErrorClient.AppendResponse(mock.WithHTTPStatusCode(testCase.code), + mock.WithBody(responseBody)) + client, err := New(SystemAssigned(), WithHTTPClient(&fakeErrorClient), WithRetryPolicyDisabled()) + if err != nil { + t.Fatal(err) + } + resp, err := client.AcquireToken(context.Background(), resource) + if err == nil { + t.Fatalf("should have encountered the error") + } + var callErr errors.CallErr + if errors.As(err, &callErr) { + if !strings.Contains(err.Error(), testCase.err) { + t.Fatalf("expected message '%s' in error, got %q", testCase.err, callErr.Error()) + } + if callErr.Resp.StatusCode != testCase.code { + t.Fatalf("expected status code %d, got %d", testCase.code, callErr.Resp.StatusCode) + } + } else { + t.Fatalf("expected error of type %T, got %T", callErr, err) + } + if resp.AccessToken != "" { + t.Fatalf("access token should be empty") + } + }) + } +} + +func TestCreatingIMDSClient(t *testing.T) { + tests := []struct { + name string + id ID + wantErr bool + }{ + { + name: "System Assigned", + id: SystemAssigned(), + }, + { + name: "Client ID", + id: UserAssignedClientID("test-client-id"), + }, + { + name: "Resource ID", + id: UserAssignedResourceID("test-resource-id"), + }, + { + name: "Object ID", + id: UserAssignedObjectID("test-object-id"), + }, + { + name: "Empty Client ID", + id: UserAssignedClientID(""), + wantErr: true, + }, + { + name: "Empty Resource ID", + id: UserAssignedResourceID(""), + wantErr: true, + }, + { + name: "Empty Object ID", + id: UserAssignedObjectID(""), + wantErr: true, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + setCustomAzureArcFilePath(t, fakeAzureArcFilePath) + client, err := New(tt.id) + if tt.wantErr { + if err == nil { + t.Fatal("client New() should return a error but did not.") + } + return + } + if err != nil { + t.Fatal(err) + } + if client.miType.value() != tt.id.value() { + t.Fatal("client New() did not assign a correct value to type.") + } + }) + } +} diff --git a/apps/managedidentity/servicefabric.go b/apps/managedidentity/servicefabric.go new file mode 100644 index 00000000..535065e9 --- /dev/null +++ b/apps/managedidentity/servicefabric.go @@ -0,0 +1,25 @@ +// Copyright (c) Microsoft Corporation. +// Licensed under the MIT license. + +package managedidentity + +import ( + "context" + "net/http" + "os" +) + +func createServiceFabricAuthRequest(ctx context.Context, resource string) (*http.Request, error) { + identityEndpoint := os.Getenv(identityEndpointEnvVar) + req, err := http.NewRequestWithContext(ctx, http.MethodGet, identityEndpoint, nil) + if err != nil { + return nil, err + } + req.Header.Set("Accept", "application/json") + req.Header.Set("Secret", os.Getenv(identityHeaderEnvVar)) + q := req.URL.Query() + q.Set("api-version", serviceFabricAPIVersion) + q.Set("resource", resource) + req.URL.RawQuery = q.Encode() + return req, nil +} diff --git a/apps/managedidentity/servicefabric_test.go b/apps/managedidentity/servicefabric_test.go new file mode 100644 index 00000000..58f6a4ec --- /dev/null +++ b/apps/managedidentity/servicefabric_test.go @@ -0,0 +1,116 @@ +// Copyright (c) Microsoft Corporation. +// Licensed under the MIT license. + +package managedidentity + +import ( + "context" + "net/http" + "net/url" + "strings" + "testing" + + "github.com/AzureAD/microsoft-authentication-library-for-go/apps/internal/base" + "github.com/AzureAD/microsoft-authentication-library-for-go/apps/internal/base/storage" + "github.com/AzureAD/microsoft-authentication-library-for-go/apps/internal/mock" +) + +func TestServiceFabricAcquireTokenReturnsTokenSuccess(t *testing.T) { + setEnvVars(t, ServiceFabric) + testCases := []struct { + resource string + miType ID + }{ + {resource: resource, miType: SystemAssigned()}, + {resource: resourceDefaultSuffix, miType: SystemAssigned()}, + } + for _, testCase := range testCases { + t.Run(string(DefaultToIMDS)+"-"+testCase.miType.value(), func(t *testing.T) { + endpoint := imdsDefaultEndpoint + var localUrl *url.URL + var localHeader http.Header + mockClient := mock.Client{} + responseBody, err := getSuccessfulResponse(resource, true) + if err != nil { + t.Fatalf(errorFormingJsonResponse, err.Error()) + } + + mockClient.AppendResponse(mock.WithHTTPStatusCode(http.StatusOK), mock.WithBody(responseBody), mock.WithCallback(func(r *http.Request) { + localUrl = r.URL + localHeader = r.Header + })) + // resetting cache + before := cacheManager + defer func() { cacheManager = before }() + cacheManager = storage.New(nil) + + client, err := New(testCase.miType, WithHTTPClient(&mockClient)) + if err != nil { + t.Fatal(err) + } + result, err := client.AcquireToken(context.Background(), testCase.resource) + if err != nil { + t.Fatal(err) + } + if localUrl == nil || !strings.HasPrefix(localUrl.String(), "http://localhost:40342/metadata/identity/oauth2/token") { + t.Fatalf("url request is not on %s got %s", endpoint, localUrl) + } + query := localUrl.Query() + + if got := query.Get(apiVersionQueryParameterName); got != serviceFabricAPIVersion { + t.Fatalf("api-version not on %s got %s", serviceFabricAPIVersion, got) + } + if query.Get(resourceQueryParameterName) != strings.TrimSuffix(testCase.resource, "/.default") { + t.Fatal("suffix /.default was not removed.") + } + if localHeader.Get("Accept") != "application/json" { + t.Fatalf("expected Accept header to be application/json, got %s", localHeader.Get("Accept")) + } + if localHeader.Get("Secret") != "secret" { + t.Fatalf("expected secret to be secret, got %s", query.Get("Secret")) + } + if result.Metadata.TokenSource != base.IdentityProvider { + t.Fatalf("expected IndenityProvider tokensource, got %d", result.Metadata.TokenSource) + } + if result.AccessToken != token { + t.Fatalf("wanted %q, got %q", token, result.AccessToken) + } + result, err = client.AcquireToken(context.Background(), testCase.resource) + if err != nil { + t.Fatal(err) + } + if result.Metadata.TokenSource != base.Cache { + t.Fatalf("wanted cache token source, got %d", result.Metadata.TokenSource) + } + secondFakeClient, err := New(testCase.miType, WithHTTPClient(&mockClient)) + if err != nil { + t.Fatal(err) + } + result, err = secondFakeClient.AcquireToken(context.Background(), testCase.resource) + if err != nil { + t.Fatal(err) + } + if result.Metadata.TokenSource != base.Cache { + t.Fatalf("cache result wanted cache token source, got %d", result.Metadata.TokenSource) + } + }) + } +} +func TestServiceFabricErrors(t *testing.T) { + setEnvVars(t, ServiceFabric) + mockClient := mock.Client{} + + for _, testCase := range []ID{ + UserAssignedObjectID("ObjectId"), + UserAssignedResourceID("resourceid"), + UserAssignedClientID("ClientID")} { + _, err := New(testCase, WithHTTPClient(&mockClient)) + if err == nil { + t.Fatal("expected error: Service Fabric API doesn't support specifying a user-assigned identity. The identity is determined by cluster resource configuration. See https://aka.ms/servicefabricmi") + } + if err.Error() != "Service Fabric API doesn't support specifying a user-assigned identity. The identity is determined by cluster resource configuration. See https://aka.ms/servicefabricmi" { + t.Fatalf("expected error: Service Fabric API doesn't support specifying a user-assigned identity. The identity is determined by cluster resource configuration. See https://aka.ms/servicefabricmi, got error: %q", err) + } + + } +} diff --git a/apps/public/public_test.go b/apps/public/public_test.go index c0fb9b33..4fe575d2 100644 --- a/apps/public/public_test.go +++ b/apps/public/public_test.go @@ -255,7 +255,7 @@ func TestAcquireTokenWithTenantID(t *testing.T) { case "password": ar, err = client.AcquireTokenByUsernamePassword(ctx, tokenScope, "username", "password", WithTenantID(test.tenant)) default: - t.Fatalf("test bug: no test for " + method) + t.Fatalf("test bug: no test for %s", method) } if err != nil { if test.expectError { @@ -309,7 +309,7 @@ func TestADFSTokenCaching(t *testing.T) { AccessToken: "at1", RefreshToken: "rt", TokenType: "Bearer", - ExpiresOn: internalTime.DurationTime{T: time.Now().Add(time.Hour)}, + ExpiresOn: time.Now().Add(time.Hour), ExtExpiresOn: internalTime.DurationTime{T: time.Now().Add(time.Hour)}, GrantedScopes: accesstokens.Scopes{Slice: tokenScope}, IDToken: accesstokens.IDToken{ @@ -620,7 +620,7 @@ func TestWithClaims(t *testing.T) { client.base.Token.WSTrust = fake.WSTrust{SamlTokenInfo: wstrust.SamlTokenInfo{AssertionType: "urn:ietf:params:oauth:grant-type:saml1_1-bearer"}} ar, err = client.AcquireTokenByUsernamePassword(ctx, tokenScope, "username", "password", WithClaims(test.claims)) default: - t.Fatalf("test bug: no test for " + method) + t.Fatalf("test bug: no test for %s", method) } if method == "devicecode" && err == nil { // complete the device code flow @@ -910,7 +910,7 @@ func TestWithAuthenticationScheme(t *testing.T) { case "password": ar, err = client.AcquireTokenByUsernamePassword(ctx, tokenScope, "username", "password", WithAuthenticationScheme(authScheme)) default: - t.Fatalf("test bug: no test for " + testCase.name) + t.Fatalf("test bug: no test for %s", testCase.name) } // validate that the token is created correctly diff --git a/apps/tests/benchmarks/confidential.go b/apps/tests/benchmarks/confidential.go index 802bafbb..09a706d4 100644 --- a/apps/tests/benchmarks/confidential.go +++ b/apps/tests/benchmarks/confidential.go @@ -14,7 +14,6 @@ import ( "time" "github.com/AzureAD/microsoft-authentication-library-for-go/apps/internal/base" - internalTime "github.com/AzureAD/microsoft-authentication-library-for-go/apps/internal/json/types/time" "github.com/AzureAD/microsoft-authentication-library-for-go/apps/internal/oauth" "github.com/AzureAD/microsoft-authentication-library-for-go/apps/internal/oauth/fake" "github.com/AzureAD/microsoft-authentication-library-for-go/apps/internal/oauth/ops/accesstokens" @@ -40,7 +39,7 @@ func fakeClient() (base.Client, error) { AccessTokens: &fake.AccessTokens{ AccessToken: accesstokens.TokenResponse{ AccessToken: accessToken, - ExpiresOn: internalTime.DurationTime{T: time.Now().Add(1 * time.Hour)}, + ExpiresOn: time.Now().Add(1 * time.Hour), GrantedScopes: accesstokens.Scopes{Slice: tokenScope}, }, }, @@ -89,7 +88,7 @@ func populateTokenCache(client base.Client, params testParams) execTime { // each token has a different scope which is what makes them unique _, err := client.AuthResultFromToken(context.Background(), authParams, accesstokens.TokenResponse{ AccessToken: accessToken, - ExpiresOn: internalTime.DurationTime{T: time.Now().Add(1 * time.Hour)}, + ExpiresOn: time.Now().Add(1 * time.Hour), GrantedScopes: accesstokens.Scopes{Slice: []string{strconv.FormatInt(int64(i), 10)}}, }, true) if err != nil { diff --git a/apps/tests/devapps/main.go b/apps/tests/devapps/main.go index 59672fab..ca0d12e6 100644 --- a/apps/tests/devapps/main.go +++ b/apps/tests/devapps/main.go @@ -5,7 +5,7 @@ import ( ) var ( - //config = CreateConfig("config.json") + //config = CreateConfig("config.json") reenable when config is implemented cacheAccessor = &TokenCache{file: "serialized_cache.json"} ) diff --git a/apps/tests/devapps/managedidentity/docs/msi_manual_testing.md b/apps/tests/devapps/managedidentity/docs/msi_manual_testing.md new file mode 100644 index 00000000..b09d87b0 --- /dev/null +++ b/apps/tests/devapps/managedidentity/docs/msi_manual_testing.md @@ -0,0 +1,3 @@ +# Running Managed Identity Sources + +A full overview of how to run each sample source can be found in the [Azure Samples - MSAL GO](https://github.com/Azure-Samples/msal-managed-identity/tree/main/src/go) repository diff --git a/apps/tests/performance/performance_test.go b/apps/tests/performance/performance_test.go index 66050b72..30e98e70 100644 --- a/apps/tests/performance/performance_test.go +++ b/apps/tests/performance/performance_test.go @@ -12,7 +12,6 @@ import ( "time" "github.com/AzureAD/microsoft-authentication-library-for-go/apps/internal/base" - internalTime "github.com/AzureAD/microsoft-authentication-library-for-go/apps/internal/json/types/time" "github.com/AzureAD/microsoft-authentication-library-for-go/apps/internal/oauth" "github.com/AzureAD/microsoft-authentication-library-for-go/apps/internal/oauth/fake" "github.com/AzureAD/microsoft-authentication-library-for-go/apps/internal/oauth/ops/accesstokens" @@ -55,7 +54,7 @@ func populateCache(users int, tokens int, authParams authority.AuthParams, clien AccessToken: fmt.Sprintf("fake_access_token%d", user), RefreshToken: "fake_refresh_token", ClientInfo: accesstokens.ClientInfo{UID: "my_uid", UTID: fmt.Sprintf("%dmy_utid", user)}, - ExpiresOn: internalTime.DurationTime{T: time.Now().Add(1 * time.Hour)}, + ExpiresOn: time.Now().Add(1 * time.Hour), GrantedScopes: accesstokens.Scopes{Slice: []string{scope}}, IDToken: accesstokens.IDToken{ RawToken: "x.e30", diff --git a/docs/managedidentity_public_api.md b/docs/managedidentity_public_api.md new file mode 100644 index 00000000..6c5401cc --- /dev/null +++ b/docs/managedidentity_public_api.md @@ -0,0 +1,205 @@ +# Managed Identity Public API Design Specification + +The purpose of this file is to go over the changes required for adding the Managed Identity feature to MSAL GO + +## Public API + +The public API will be quite small. Based on the Java and .NET implementations, there is only 1 exposed method, **acquireTokenForManagedIdentity()** + +```go +// Acquires tokens from the configured managed identity on an azure resource. +// +// Resource: scopes application is requesting access to +// Options: [WithClaims] +func (client Client) AcquireToken(context context.Context, resource string, options ...AcquireTokenOption) (base.AuthResult, error) { + return base.AuthResult{}, nil +} + +// Source represents the managed identity sources supported. +type Source int + +const ( + // AzureArc represents the source to acquire token for managed identity is Azure Arc. + AzureArc = 0 + + // DefaultToIMDS indicates that the source is defaulted to IMDS since no environment variables are set. + DefaultToIMDS = 1 +) + +// Detects and returns the managed identity source available on the environment. +func GetSource() Source { + return DefaultToIMDS +} +``` + +The end user simply needs to create their own instance of Managed Identity Client, i.e **managedIdentity.Client()**, passing in the **ManagedIdentityType** they want to use, and then call the public API. The example below shows creation of different clients for each of the different Managed Identity Types + +```go +import ( + "context" + "fmt" + "net/http" + + mi "github.com/AzureAD/microsoft-authentication-library-for-go/apps/managedidentity" +) + +func RunManagedIdentity() { + customHttpClient := &http.Client{} + + miSystemAssigned, error := mi.New(mi.SystemAssigned()) + if error != nil { + fmt.Println(error) + } + + miClientIdAssigned, error := mi.New(mi.ClientID("client id 123"), mi.WithHTTPClient(customHttpClient)) + if error != nil { + fmt.Println(error) + } + + miResourceIdAssigned, error := mi.New(mi.ResourceID("resource id 123")) + if error != nil { + fmt.Println(error) + } + + miObjectIdAssigned, error := mi.New(mi.ObjectID("object id 123")) + if error != nil { + fmt.Println(error) + } + + miSystemAssigned.AcquireToken(context.Background(), "resource", mi.WithClaims("claim")) + + miClientIdAssigned.AcquireToken(context.Background(), "resource") + + miResourceIdAssigned.AcquireToken(context.Background(), "resource", mi.WithClaims("claim")) + + miObjectIdAssigned.AcquireToken(context.Background(), "resource") +} +``` + +To create a new **ManagedIdentityClient** + +```go +// Client to be used to acquire tokens for managed identity. +// ID: [SystemAssigned()], [ClientID("clientID")], [ResourceID("resourceID")], [ObjectID("objectID")] +// +// Options: [WithHTTPClient] +func New(id ID, options ...Option) (Client, error) { + // implementation details +} +``` + +The options available for passing to the client are + +```go +// WithHTTPClient allows for a custom HTTP client to be set. +func WithHTTPClient(httpClient ops.HTTPClient) Option { + // implementation details +} +``` + +The options available for the request are + +```go +// WithClaims sets additional claims to request for the token, such as those required by conditional access policies. +// Use this option when Azure AD returned a claims challenge for a prior request. The argument must be decoded. +func WithClaims(claims string) AcquireTokenOption { + // implementation details +} +``` + +## Error Handling + +Error handling in GO is different to what we used to in languages like Java or Swift. +There is no concept of ‘exceptions’, instead we just return errors and immediately check if an error was returned and handle it there and then. +The SDK will return client-side errors like so: + +```go +if err != nil { + return errors.New("Some Managed Identity Error here”) +} +``` + +This will be inside of any client methods that throw errors, using descriptive errors based on the .NET and Java Implementation. These errors will be propagated down the chain and handled when they are received + +For service side errors it works a little differently + +```go +switch reply.StatusCode { + case 200, 201: + default: + sd := strings.TrimSpace(string(data)) + + if sd != "" { + // We probably have the error in the body. + return nil, errors.CallErr { + Req: req, + Resp: reply, + Err: fmt.Errorf("http call(%s)(%s) error: reply status code was %d:\n%s",req.URL.String(), req.Method, reply.StatusCode, sd) + } + } + + return nil, errors.CallErr{ + Req: req, + Resp: reply, + Err: fmt.Errorf("http call(%s)(%s) error: reply status code was %d", req.URL.String(), req.Method, reply.StatusCode), + } +} +``` + +In this example, you can see we are returning **errors.CallErr(Req: httpRequest, Resp: httpResponse, Err: error)** + +For the service side errors we have a struct object like this: + +```go +type CallErr struct { + Req *http.Request + // Resp contains response body + Resp *http.Response + Err error +} +``` + +This structure should be followed for future service calls. More information on this implementation can be found [here](https://github.com/AzureAD/microsoft-authentication-library-for-go/blob/ae2db6b72c7010958355f448e99209bd28e76e67/apps/errors/error_design.md#L1) + +## Caching + +Other MSALs have an Enum called **TokenSource** that lets us differentiate between **IdentityProvider**, **Cache** and **Broker**. + +Since GO does not have Brokers, we have created a PR [here](https://github.com/AzureAD/microsoft-authentication-library-for-go/pull/498) that adds a **AuthenticationResultMetadata** class to the **_base.go_** instance of **AuthResult** + +This **AuthenticationResultMetadata** contains the **TokenSource** and **RefreshOn** values, like .NET and Java implementations. The **TokenSource** here does not contain the broker field as it is not something that is planned currently + +```go +type TokenSource int + +const ( + IdentityProvider TokenSource = 0 + Cache = 1 +) + +type AuthResultMetadata struct { + TokenSource TokenSource + RefreshOn time.Time +} +``` + +## FIC Support + +You can review information on FIC [here](https://review.learn.microsoft.com/en-us/identity/microsoft-identity-platform/federated-identity-credentials?branch=main&tabs=dotnet) + +Managed Identity abstracts the complexity of certificates away, by virtue of being hosted on an Azure VM you get access to the services you need i.e. key vault + +Managed Identity is a single tenant. This is an issue as Microsoft has many multi tenanted apps. +FIC solves this by allowing you to declare a trust relationship with an identity provider and application i.e. ‘I trust this GitHub token, if I see this Git Hub token, give me a token for something I want access to i.e. Key Vault’ +So, if you can get a token for Managed Identity you can use it to access the key vault in all tenants + +Right now, we shouldn’t have to do anything. +Currently FIC would be the token for the certificate in **acquireTokenByCredential()**, we would just provide the token for ManagedIdentity instead of using the certificate + +This is a 2-step process: + +1. Get token for Managed Identity. Would be a special token for a specific scope. + +2. Create a confidential client and get a token. Will get an API certificate for the assertion, and use the Managed Identity token instead of the certificate + +All we need to do for now is test FIC with Managed Identity, and update any documentation to go along with it \ No newline at end of file