Skip to content

Commit 3b70655

Browse files
authored
{Core} Decouple MSAL credentials from SDK get_token protocol (#29955)
1 parent 7161d86 commit 3b70655

File tree

8 files changed

+112
-97
lines changed

8 files changed

+112
-97
lines changed

src/azure-cli-core/azure/cli/core/_profile.py

+19-20
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,7 @@
1111
from azure.cli.core._session import ACCOUNT
1212
from azure.cli.core.azclierror import AuthenticationError
1313
from azure.cli.core.cloud import get_active_cloud, set_cloud_subscription
14+
from azure.cli.core.auth.credential_adaptor import CredentialAdaptor
1415
from azure.cli.core.util import in_cloud_console, can_launch_browser, is_github_codespaces
1516
from knack.log import get_logger
1617
from knack.util import CLIError
@@ -313,9 +314,10 @@ def login_with_managed_identity_azure_arc(self, identity_id=None, allow_no_subsc
313314
import jwt
314315
identity_type = MsiAccountTypes.system_assigned
315316
from .auth.msal_credentials import ManagedIdentityCredential
317+
from .auth.constants import ACCESS_TOKEN
316318

317319
cred = ManagedIdentityCredential()
318-
token = cred.get_token(*self._arm_scope).token
320+
token = cred.acquire_token(self._arm_scope)[ACCESS_TOKEN]
319321
logger.info('Managed identity: token was retrieved. Now trying to initialize local accounts...')
320322
decode = jwt.decode(token, algorithms=['RS256'], options={"verify_signature": False})
321323
tenant = decode['tid']
@@ -339,9 +341,10 @@ def login_with_managed_identity_azure_arc(self, identity_id=None, allow_no_subsc
339341
def login_in_cloud_shell(self):
340342
import jwt
341343
from .auth.msal_credentials import CloudShellCredential
344+
from .auth.constants import ACCESS_TOKEN
342345

343346
cred = CloudShellCredential()
344-
token = cred.get_token(*self._arm_scope).token
347+
token = cred.acquire_token(self._arm_scope)[ACCESS_TOKEN]
345348
logger.info('Cloud Shell token was retrieved. Now trying to initialize local accounts...')
346349
decode = jwt.decode(token, algorithms=['RS256'], options={"verify_signature": False})
347350
tenant = decode['tid']
@@ -397,21 +400,19 @@ def get_login_credentials(self, subscription_id=None, aux_subscriptions=None, au
397400
if in_cloud_console() and account[_USER_ENTITY].get(_CLOUD_SHELL_ID):
398401
# Cloud Shell
399402
from .auth.msal_credentials import CloudShellCredential
400-
from azure.cli.core.auth.credential_adaptor import CredentialAdaptor
401403
# The credential must be wrapped by CredentialAdaptor so that it can work with Track 1 SDKs.
402-
cred = CredentialAdaptor(CloudShellCredential())
404+
sdk_cred = CredentialAdaptor(CloudShellCredential())
403405

404406
elif managed_identity_type:
405407
# managed identity
406408
if _on_azure_arc():
407409
from .auth.msal_credentials import ManagedIdentityCredential
408-
from azure.cli.core.auth.credential_adaptor import CredentialAdaptor
409410
# The credential must be wrapped by CredentialAdaptor so that it can work with Track 1 SDKs.
410-
cred = CredentialAdaptor(ManagedIdentityCredential())
411+
sdk_cred = CredentialAdaptor(ManagedIdentityCredential())
411412
else:
412413
# The resource is merely used by msrestazure to get the first access token.
413414
# It is not actually used in an API invocation.
414-
cred = MsiAccountTypes.msi_auth_factory(
415+
sdk_cred = MsiAccountTypes.msi_auth_factory(
415416
managed_identity_type, managed_identity_id,
416417
self.cli_ctx.cloud.endpoints.active_directory_resource_id)
417418

@@ -431,10 +432,9 @@ def get_login_credentials(self, subscription_id=None, aux_subscriptions=None, au
431432
external_credentials = []
432433
for external_tenant in external_tenants:
433434
external_credentials.append(self._create_credential(account, tenant_id=external_tenant))
434-
from azure.cli.core.auth.credential_adaptor import CredentialAdaptor
435-
cred = CredentialAdaptor(credential, auxiliary_credentials=external_credentials)
435+
sdk_cred = CredentialAdaptor(credential, auxiliary_credentials=external_credentials)
436436

437-
return (cred,
437+
return (sdk_cred,
438438
str(account[_SUBSCRIPTION_ID]),
439439
str(account[_TENANT_ID]))
440440

@@ -460,24 +460,24 @@ def get_raw_token(self, resource=None, scopes=None, subscription=None, tenant=No
460460
if tenant:
461461
raise CLIError("Tenant shouldn't be specified for Cloud Shell account")
462462
from .auth.msal_credentials import CloudShellCredential
463-
cred = CloudShellCredential()
463+
sdk_cred = CredentialAdaptor(CloudShellCredential())
464464

465465
elif managed_identity_type:
466466
# managed identity
467467
if tenant:
468468
raise CLIError("Tenant shouldn't be specified for managed identity account")
469469
if _on_azure_arc():
470470
from .auth.msal_credentials import ManagedIdentityCredential
471-
cred = ManagedIdentityCredential()
471+
sdk_cred = CredentialAdaptor(ManagedIdentityCredential())
472472
else:
473473
from .auth.util import scopes_to_resource
474-
cred = MsiAccountTypes.msi_auth_factory(managed_identity_type, managed_identity_id,
475-
scopes_to_resource(scopes))
474+
sdk_cred = MsiAccountTypes.msi_auth_factory(managed_identity_type, managed_identity_id,
475+
scopes_to_resource(scopes))
476476

477477
else:
478-
cred = self._create_credential(account, tenant_id=tenant)
478+
sdk_cred = CredentialAdaptor(self._create_credential(account, tenant_id=tenant))
479479

480-
sdk_token = cred.get_token(*scopes)
480+
sdk_token = sdk_cred.get_token(*scopes)
481481
# Convert epoch int 'expires_on' to datetime string 'expiresOn' for backward compatibility
482482
# WARNING: expiresOn is deprecated and will be removed in future release.
483483
import datetime
@@ -856,7 +856,6 @@ def find_using_common_tenant(self, username, credential=None):
856856
specific_tenant_credential = identity.get_user_credential(username)
857857

858858
try:
859-
860859
subscriptions = self.find_using_specific_tenant(tenant_id, specific_tenant_credential,
861860
tenant_id_description=t)
862861
except AuthenticationError as ex:
@@ -927,9 +926,9 @@ def _create_subscription_client(self, credential):
927926
raise CLIInternalError("Unable to get '{}' in profile '{}'"
928927
.format(ResourceType.MGMT_RESOURCE_SUBSCRIPTIONS, self.cli_ctx.cloud.profile))
929928
api_version = get_api_version(self.cli_ctx, ResourceType.MGMT_RESOURCE_SUBSCRIPTIONS)
930-
client_kwargs = _prepare_mgmt_client_kwargs_track2(self.cli_ctx, credential)
931-
932-
client = client_type(credential, api_version=api_version,
929+
sdk_cred = CredentialAdaptor(credential)
930+
client_kwargs = _prepare_mgmt_client_kwargs_track2(self.cli_ctx, sdk_cred)
931+
client = client_type(sdk_cred, api_version=api_version,
933932
base_url=self.cli_ctx.cloud.endpoints.resource_manager,
934933
**client_kwargs)
935934
return client

src/azure-cli-core/azure/cli/core/auth/constants.py

+3
Original file line numberDiff line numberDiff line change
@@ -4,3 +4,6 @@
44
# --------------------------------------------------------------------------------------------
55

66
AZURE_CLI_CLIENT_ID = '04b07795-8ddb-461a-bbee-02f9e1bf7b46'
7+
8+
ACCESS_TOKEN = 'access_token'
9+
EXPIRES_IN = "expires_in"

src/azure-cli-core/azure/cli/core/auth/credential_adaptor.py

+8-7
Original file line numberDiff line numberDiff line change
@@ -4,19 +4,19 @@
44
# --------------------------------------------------------------------------------------------
55

66
from knack.log import get_logger
7+
from .util import build_sdk_access_token
78

89
logger = get_logger(__name__)
910

1011

1112
class CredentialAdaptor:
1213
def __init__(self, credential, auxiliary_credentials=None):
13-
"""Cross-tenant credential adaptor. It takes a main credential and auxiliary credentials.
14-
14+
"""Credential adaptor between MSAL credential and SDK credential.
1515
It implements Track 2 SDK's azure.core.credentials.TokenCredential by exposing get_token.
1616
17-
:param credential: Main credential from .msal_authentication
18-
:param auxiliary_credentials: Credentials from .msal_authentication for cross tenant authentication.
19-
Details about cross tenant authentication:
17+
:param credential: MSAL credential from ._msal_credentials
18+
:param auxiliary_credentials: MSAL credentials for cross-tenant authentication.
19+
Details about cross-tenant authentication:
2020
https://learn.microsoft.com/en-us/azure/azure-resource-manager/management/authenticate-multi-tenant
2121
"""
2222

@@ -32,11 +32,12 @@ def get_token(self, *scopes, **kwargs):
3232
if 'data' in kwargs:
3333
filtered_kwargs['data'] = kwargs['data']
3434

35-
return self._credential.get_token(*scopes, **filtered_kwargs)
35+
return build_sdk_access_token(self._credential.acquire_token(list(scopes), **filtered_kwargs))
3636

3737
def get_auxiliary_tokens(self, *scopes, **kwargs):
3838
"""Get access tokens from auxiliary credentials."""
3939
# To test cross-tenant authentication, see https://github.com/Azure/azure-cli/issues/16691
4040
if self._auxiliary_credentials:
41-
return [cred.get_token(*scopes, **kwargs) for cred in self._auxiliary_credentials]
41+
return [build_sdk_access_token(cred.acquire_token(list(scopes), **kwargs))
42+
for cred in self._auxiliary_credentials]
4243
return None

src/azure-cli-core/azure/cli/core/auth/identity.py

+2-3
Original file line numberDiff line numberDiff line change
@@ -192,9 +192,8 @@ def login_with_service_principal(self, client_id, credential, scopes):
192192
"""
193193
sp_auth = ServicePrincipalAuth.build_from_credential(self.tenant_id, client_id, credential)
194194
client_credential = sp_auth.get_msal_client_credential()
195-
cca = ConfidentialClientApplication(client_id, client_credential=client_credential, **self._msal_app_kwargs)
196-
result = cca.acquire_token_for_client(scopes)
197-
check_result(result)
195+
cred = ServicePrincipalCredential(client_id, client_credential, **self._msal_app_kwargs)
196+
cred.acquire_token(scopes)
198197

199198
# Only persist the service principal after a successful login
200199
entry = sp_auth.get_entry_to_persist()

src/azure-cli-core/azure/cli/core/auth/msal_credentials.py

+24-33
Original file line numberDiff line numberDiff line change
@@ -4,16 +4,7 @@
44
# --------------------------------------------------------------------------------------------
55

66
"""
7-
Credentials defined in this module are alternative implementations of credentials provided by Azure Identity.
8-
9-
These credentials implement azure.core.credentials.TokenCredential by exposing `get_token` method for Track 2
10-
SDK invocation.
11-
12-
If you want to implement your own credential, the credential must also expose `get_token` method.
13-
14-
`get_token` method takes `scopes` as positional arguments and other optional `kwargs`, such as `claims`, `data`.
15-
The return value should be a named tuple containing two elements: token (str), expires_on (int). You may simply use
16-
azure.cli.core.auth.util.AccessToken to build the return value. See below credentials as examples.
7+
Credentials to acquire tokens from MSAL.
178
"""
189

1910
from knack.log import get_logger
@@ -22,15 +13,15 @@
2213
ManagedIdentityClient, SystemAssignedManagedIdentity)
2314

2415
from .constants import AZURE_CLI_CLIENT_ID
25-
from .util import check_result, build_sdk_access_token
16+
from .util import check_result
2617

2718
logger = get_logger(__name__)
2819

2920

3021
class UserCredential: # pylint: disable=too-few-public-methods
3122

3223
def __init__(self, client_id, username, **kwargs):
33-
"""User credential implementing get_token interface.
24+
"""User credential wrapping msal.application.PublicClientApplication
3425
3526
:param client_id: Client ID of the CLI.
3627
:param username: The username for user credential.
@@ -52,14 +43,16 @@ def __init__(self, client_id, username, **kwargs):
5243

5344
self._account = accounts[0]
5445

55-
def get_token(self, *scopes, claims=None, **kwargs):
56-
# scopes = ['https://pas.windows.net/CheckMyAccess/Linux/.default']
57-
logger.debug("UserCredential.get_token: scopes=%r, claims=%r, kwargs=%r", scopes, claims, kwargs)
46+
def acquire_token(self, scopes, claims=None, **kwargs):
47+
# scopes must be a list.
48+
# For acquiring SSH certificate, scopes is ['https://pas.windows.net/CheckMyAccess/Linux/.default']
49+
# kwargs is already sanitized by CredentialAdaptor, so it can be safely passed to MSAL
50+
logger.debug("UserCredential.acquire_token: scopes=%r, claims=%r, kwargs=%r", scopes, claims, kwargs)
5851

5952
if claims:
6053
logger.warning('Acquiring new access token silently for tenant %s with claims challenge: %s',
6154
self._msal_app.authority.tenant, claims)
62-
result = self._msal_app.acquire_token_silent_with_error(list(scopes), self._account, claims_challenge=claims,
55+
result = self._msal_app.acquire_token_silent_with_error(scopes, self._account, claims_challenge=claims,
6356
**kwargs)
6457

6558
from azure.cli.core.azclierror import AuthenticationError
@@ -82,7 +75,7 @@ def get_token(self, *scopes, claims=None, **kwargs):
8275
success_template, error_template = read_response_templates()
8376

8477
result = self._msal_app.acquire_token_interactive(
85-
list(scopes), login_hint=self._account['username'],
78+
scopes, login_hint=self._account['username'],
8679
port=8400 if self._msal_app.authority.is_adfs else None,
8780
success_template=success_template, error_template=error_template, **kwargs)
8881
check_result(result)
@@ -91,25 +84,24 @@ def get_token(self, *scopes, claims=None, **kwargs):
9184
# launch browser, but show the error message and `az login` command instead.
9285
else:
9386
raise
94-
return build_sdk_access_token(result)
87+
return result
9588

9689

9790
class ServicePrincipalCredential: # pylint: disable=too-few-public-methods
9891

9992
def __init__(self, client_id, client_credential, **kwargs):
100-
"""Service principal credential implementing get_token interface.
93+
"""Service principal credential wrapping msal.application.ConfidentialClientApplication.
10194
10295
:param client_id: The service principal's client ID.
10396
:param client_credential: client_credential that will be passed to MSAL.
10497
"""
105-
self._msal_app = ConfidentialClientApplication(client_id, client_credential, **kwargs)
106-
107-
def get_token(self, *scopes, **kwargs):
108-
logger.debug("ServicePrincipalCredential.get_token: scopes=%r, kwargs=%r", scopes, kwargs)
98+
self._msal_app = ConfidentialClientApplication(client_id, client_credential=client_credential, **kwargs)
10999

110-
result = self._msal_app.acquire_token_for_client(list(scopes), **kwargs)
100+
def acquire_token(self, scopes, **kwargs):
101+
logger.debug("ServicePrincipalCredential.acquire_token: scopes=%r, kwargs=%r", scopes, kwargs)
102+
result = self._msal_app.acquire_token_for_client(scopes, **kwargs)
111103
check_result(result)
112-
return build_sdk_access_token(result)
104+
return result
113105

114106

115107
class CloudShellCredential: # pylint: disable=too-few-public-methods
@@ -126,12 +118,11 @@ def __init__(self):
126118
# token_cache=...
127119
)
128120

129-
def get_token(self, *scopes, **kwargs):
130-
logger.debug("CloudShellCredential.get_token: scopes=%r, kwargs=%r", scopes, kwargs)
131-
# kwargs is already sanitized by CredentialAdaptor, so it can be safely passed to MSAL
132-
result = self._msal_app.acquire_token_interactive(list(scopes), prompt="none", **kwargs)
121+
def acquire_token(self, scopes, **kwargs):
122+
logger.debug("CloudShellCredential.acquire_token: scopes=%r, kwargs=%r", scopes, kwargs)
123+
result = self._msal_app.acquire_token_interactive(scopes, prompt="none", **kwargs)
133124
check_result(result, scopes=scopes)
134-
return build_sdk_access_token(result)
125+
return result
135126

136127

137128
class ManagedIdentityCredential: # pylint: disable=too-few-public-methods
@@ -143,10 +134,10 @@ def __init__(self):
143134
import requests
144135
self._msal_client = ManagedIdentityClient(SystemAssignedManagedIdentity(), http_client=requests.Session())
145136

146-
def get_token(self, *scopes, **kwargs):
147-
logger.debug("ManagedIdentityCredential.get_token: scopes=%r, kwargs=%r", scopes, kwargs)
137+
def acquire_token(self, scopes, **kwargs):
138+
logger.debug("ManagedIdentityCredential.acquire_token: scopes=%r, kwargs=%r", scopes, kwargs)
148139

149140
from .util import scopes_to_resource
150141
result = self._msal_client.acquire_token_for_client(resource=scopes_to_resource(scopes))
151142
check_result(result)
152-
return build_sdk_access_token(result)
143+
return result

src/azure-cli-core/azure/cli/core/auth/util.py

+7-4
Original file line numberDiff line numberDiff line change
@@ -140,9 +140,6 @@ def check_result(result, **kwargs):
140140

141141

142142
def build_sdk_access_token(token_entry):
143-
import time
144-
request_time = int(time.time())
145-
146143
# MSAL token entry sample:
147144
# {
148145
# 'access_token': 'eyJ0eXAiOiJKV...',
@@ -153,7 +150,8 @@ def build_sdk_access_token(token_entry):
153150
# Importing azure.core.credentials.AccessToken is expensive.
154151
# This can slow down commands that doesn't need azure.core, like `az account get-access-token`.
155152
# So We define our own AccessToken.
156-
return AccessToken(token_entry["access_token"], request_time + token_entry["expires_in"])
153+
from .constants import ACCESS_TOKEN, EXPIRES_IN
154+
return AccessToken(token_entry[ACCESS_TOKEN], _now_timestamp() + token_entry[EXPIRES_IN])
157155

158156

159157
def decode_access_token(access_token):
@@ -177,3 +175,8 @@ def read_response_templates():
177175
error_template = f.read()
178176

179177
return success_template, error_template
178+
179+
180+
def _now_timestamp():
181+
import time
182+
return int(time.time())

0 commit comments

Comments
 (0)