# GraphRag/graphrag-ollama/lib/python3.12/site-packages/azure/identity/_internal/interactive.py
# ------------------------------------ | |
# Copyright (c) Microsoft Corporation. | |
# Licensed under the MIT License. | |
# ------------------------------------ | |
"""Base class for credentials using MSAL for interactive user authentication""" | |
import abc | |
import base64 | |
import json | |
import logging | |
import time | |
from typing import Any, Optional, Iterable, Dict | |
from urllib.parse import urlparse | |
from azure.core.credentials import AccessToken, AccessTokenInfo, TokenRequestOptions | |
from azure.core.exceptions import ClientAuthenticationError | |
from .msal_credentials import MsalCredential | |
from .._auth_record import AuthenticationRecord | |
from .._constants import KnownAuthorities | |
from .._exceptions import AuthenticationRequiredError, CredentialUnavailableError | |
from .._internal import wrap_exceptions | |
_LOGGER = logging.getLogger(__name__)

# Alias so the credential base class can be declared as an abstract base class.
ABC = abc.ABC

# Scopes that InteractiveCredential.authenticate requests when the caller passes none, keyed by the
# credential's authority URL ("https://" + authority host). Authorities missing from this mapping make
# authenticate() require an explicit ``scopes`` argument. Scope strings are kept verbatim, including the
# doubled slash before ".default".
_DEFAULT_AUTHENTICATE_SCOPES = {
    "https://" + authority_host: (default_scope,)
    for authority_host, default_scope in (
        (KnownAuthorities.AZURE_CHINA, "https://management.core.chinacloudapi.cn//.default"),
        (KnownAuthorities.AZURE_GOVERNMENT, "https://management.core.usgovcloudapi.net//.default"),
        (KnownAuthorities.AZURE_PUBLIC_CLOUD, "https://management.core.windows.net//.default"),
    )
}
def _decode_client_info(raw) -> str:
    """Decode a base64url-encoded ``client_info`` value. Adapted from msal.oauth2cli.oidc.

    The encoded value may arrive without its trailing ``=`` padding, so the padding is restored before
    decoding. Both text and binary inputs are accepted.

    :param raw: base64url-encoded client info
    :type raw: str or bytes
    :return: decoded client info
    :rtype: str
    :raises ValueError: if ``raw`` cannot be decoded (``binascii.Error`` and ``UnicodeDecodeError`` are
        ``ValueError`` subclasses, so callers catching ``ValueError`` see every decoding failure)
    """
    if isinstance(raw, (bytes, bytearray)):
        # Normalize binary input to text so the str padding below works; base64 data is pure ASCII.
        raw = bytes(raw).decode("ascii")
    # Restore the "=" padding that base64url encoders commonly omit; urlsafe_b64decode requires it.
    raw += "=" * (-len(raw) % 4)
    return base64.urlsafe_b64decode(raw).decode("utf-8")
def _build_auth_record(response):
    """Build an AuthenticationRecord from the result of an MSAL ClientApplication token request.

    :param response: The result of a token request
    :type response: dict[str, typing.Any]
    :return: An AuthenticationRecord
    :rtype: ~azure.identity.AuthenticationRecord
    :raises ~azure.core.exceptions.ClientAuthenticationError: If the response doesn't contain expected data
    """
    # Only the parsing of the response is guarded, so that any failure in constructing the record
    # itself is not misreported as a malformed identity token.
    try:
        id_token = response["id_token_claims"]

        if "client_info" in response:
            client_info = json.loads(_decode_client_info(response["client_info"]))
            home_account_id = "{uid}.{utid}".format(**client_info)
        else:
            # MSAL uses the subject claim as home_account_id when the STS doesn't provide client_info
            home_account_id = id_token["sub"]

        # "iss" is the URL of the issuing tenant e.g. https://authority/tenant
        issuer = urlparse(id_token["iss"])

        # tenant which issued the token, not necessarily user's home tenant
        tenant_id = id_token.get("tid") or issuer.path.strip("/")

        # Microsoft Entra ID returns "preferred_username", ADFS returns "upn"
        username = id_token.get("preferred_username") or id_token["upn"]

        client_id = id_token["aud"]
    except (AttributeError, KeyError, TypeError, ValueError) as ex:
        # Besides missing keys (KeyError) and bad encodings/JSON (ValueError), wrongly-shaped data also
        # lands here, e.g. id_token_claims of None, client_info decoding to a JSON array (TypeError from
        # ``format(**...)``), or a non-string "iss" (urlparse raises AttributeError).
        auth_error = ClientAuthenticationError(
            message="Failed to build AuthenticationRecord from unexpected identity token"
        )
        raise auth_error from ex

    return AuthenticationRecord(
        authority=issuer.netloc,
        client_id=client_id,
        home_account_id=home_account_id,
        tenant_id=tenant_id,
        username=username,
    )
class InteractiveCredential(MsalCredential, ABC):
    """Base class for credentials that authenticate a user interactively through MSAL.

    Token requests first try to acquire a token silently for the account described by the credential's
    authentication record. When that raises :class:`AuthenticationRequiredError`, the credential authenticates
    interactively through :func:`_request_token`, which subclasses implement. If automatic authentication is
    disabled, the error is raised instead, and callers must invoke :func:`authenticate` themselves.

    :keyword authentication_record: record of a previous authentication. When provided, its client ID and
        authority override any ``client_id``/``authority`` arguments. Its tenant ID is used unless a non-empty
        ``tenant_id`` is passed.
    :paramtype authentication_record: ~azure.identity.AuthenticationRecord
    :keyword bool disable_automatic_authentication: if True, :func:`get_token` raises
        :class:`AuthenticationRequiredError` instead of starting interactive authentication. Defaults to False.
    """

    def __init__(
        self,
        *,
        authentication_record: Optional[AuthenticationRecord] = None,
        disable_automatic_authentication: bool = False,
        **kwargs: Any,
    ) -> None:
        self._disable_automatic_authentication = disable_automatic_authentication
        self._auth_record = authentication_record
        if self._auth_record:
            # authentication_record overrides the client_id and authority arguments. Popping "authority" also
            # prevents it from reaching the base __init__ twice (explicitly and via **kwargs), which would
            # raise TypeError.
            kwargs.pop("client_id", None)
            kwargs.pop("authority", None)
            tenant_id = kwargs.pop("tenant_id", None) or self._auth_record.tenant_id
            super().__init__(
                client_id=self._auth_record.client_id,
                authority=self._auth_record.authority,
                tenant_id=tenant_id,
                **kwargs,
            )
        else:
            super().__init__(**kwargs)

    def get_token(
        self,
        *scopes: str,
        claims: Optional[str] = None,
        tenant_id: Optional[str] = None,
        enable_cae: bool = False,
        **kwargs: Any,
    ) -> AccessToken:
        """Request an access token for `scopes`.

        This method is called automatically by Azure SDK clients.

        :param str scopes: desired scopes for the access token. This method requires at least one scope.
            For more information about scopes, see
            https://learn.microsoft.com/entra/identity-platform/scopes-oidc.
        :keyword str claims: additional claims required in the token, such as those returned in a resource provider's
            claims challenge following an authorization failure
        :keyword str tenant_id: optional tenant to include in the token request.
        :keyword bool enable_cae: indicates whether to enable Continuous Access Evaluation (CAE) for the requested
            token. Defaults to False.

        :return: An access token with the desired scopes.
        :rtype: ~azure.core.credentials.AccessToken
        :raises CredentialUnavailableError: the credential is unable to attempt authentication because it lacks
            required data, state, or platform support
        :raises ~azure.core.exceptions.ClientAuthenticationError: authentication failed. The error's ``message``
            attribute gives a reason.
        :raises AuthenticationRequiredError: user interaction is necessary to acquire a token, and the credential is
            configured not to begin this automatically. Call :func:`authenticate` to begin interactive authentication.
        """
        options: TokenRequestOptions = {}
        if claims:
            options["claims"] = claims
        if tenant_id:
            options["tenant_id"] = tenant_id
        options["enable_cae"] = enable_cae

        token_info = self._get_token_base(*scopes, options=options, base_method_name="get_token", **kwargs)
        return AccessToken(token_info.token, token_info.expires_on)

    def get_token_info(self, *scopes: str, options: Optional[TokenRequestOptions] = None) -> AccessTokenInfo:
        """Request an access token for `scopes`.

        This is an alternative to `get_token` to enable certain scenarios that require additional properties
        on the token. This method is called automatically by Azure SDK clients.

        :param str scopes: desired scopes for the access token. This method requires at least one scope.
            For more information about scopes, see https://learn.microsoft.com/entra/identity-platform/scopes-oidc.
        :keyword options: A dictionary of options for the token request. Unknown options will be ignored. Optional.
        :paramtype options: ~azure.core.credentials.TokenRequestOptions

        :rtype: AccessTokenInfo
        :return: An AccessTokenInfo instance containing information about the token.
        :raises CredentialUnavailableError: the credential is unable to attempt authentication because it lacks
            required data, state, or platform support
        :raises ~azure.core.exceptions.ClientAuthenticationError: authentication failed. The error's ``message``
            attribute gives a reason.
        :raises AuthenticationRequiredError: user interaction is necessary to acquire a token, and the credential is
            configured not to begin this automatically. Call :func:`authenticate` to begin interactive authentication.
        """
        return self._get_token_base(*scopes, options=options, base_method_name="get_token_info")

    def _get_token_base(
        self,
        *scopes: str,
        options: Optional[TokenRequestOptions] = None,
        base_method_name: str = "get_token_info",
        **kwargs: Any,
    ) -> AccessTokenInfo:
        """Shared implementation of :func:`get_token` and :func:`get_token_info`.

        The method tries silent acquisition first. If that raises :class:`AuthenticationRequiredError` and
        prompting is allowed, it authenticates interactively via :func:`_request_token` and stores an
        authentication record for the account that signed in.

        :param str scopes: desired scopes for the access token; at least one is required
        :keyword options: token request options. "claims", "tenant_id" and "enable_cae" are read directly. Any
            other key not declared by TokenRequestOptions is forwarded as a keyword argument, unless a keyword
            argument with that name was already passed.
        :paramtype options: ~azure.core.credentials.TokenRequestOptions
        :keyword str base_method_name: name of the public method being served; used in log and error messages
        :keyword bool _allow_prompt: (read from ``kwargs``) whether interactive authentication may start. Defaults
            to the inverse of ``disable_automatic_authentication``.
        :return: the acquired token
        :rtype: ~azure.core.credentials.AccessTokenInfo
        :raises ValueError: no scopes were given
        """
        if not scopes:
            message = f"'{base_method_name}' requires at least one scope"
            _LOGGER.warning("%s.%s failed: %s", self.__class__.__name__, base_method_name, message)
            raise ValueError(message)

        allow_prompt = kwargs.pop("_allow_prompt", not self._disable_automatic_authentication)
        options = options or {}
        claims = options.get("claims")
        tenant_id = options.get("tenant_id")
        enable_cae = options.get("enable_cae", False)

        # Check for arbitrary additional options to enable intermediary support for PoP tokens.
        for key in options:
            if key not in TokenRequestOptions.__annotations__:  # pylint:disable=no-member
                kwargs.setdefault(key, options[key])  # type: ignore

        try:
            token = self._acquire_token_silent(
                *scopes, claims=claims, tenant_id=tenant_id, enable_cae=enable_cae, **kwargs
            )
            _LOGGER.info("%s.%s succeeded", self.__class__.__name__, base_method_name)
            return token
        except Exception as ex:  # pylint:disable=broad-except
            # Only "user interaction required" may fall through to interactive authentication, and only
            # when prompting is allowed; everything else is logged and re-raised.
            if not (isinstance(ex, AuthenticationRequiredError) and allow_prompt):
                _LOGGER.warning(
                    "%s.%s failed: %s",
                    self.__class__.__name__,
                    base_method_name,
                    ex,
                    exc_info=_LOGGER.isEnabledFor(logging.DEBUG),
                )
                raise

        # silent authentication failed -> authenticate interactively
        now = int(time.time())
        try:
            result = self._request_token(*scopes, claims=claims, tenant_id=tenant_id, enable_cae=enable_cae, **kwargs)
            if "access_token" not in result:
                message = "Authentication failed: {}".format(result.get("error_description") or result.get("error"))
                response = self._client.get_error_response(result)
                raise ClientAuthenticationError(message=message, response=response)

            # this may be the first authentication, or the user may have authenticated a different identity
            self._auth_record = _build_auth_record(result)
        except Exception as ex:  # pylint:disable=broad-except
            _LOGGER.warning(
                "%s.%s failed: %s",
                self.__class__.__name__,
                base_method_name,
                ex,
                exc_info=_LOGGER.isEnabledFor(logging.DEBUG),
            )
            raise

        _LOGGER.info("%s.%s succeeded", self.__class__.__name__, base_method_name)
        refresh_on = int(result["refresh_on"]) if "refresh_on" in result else None
        return AccessTokenInfo(
            result["access_token"],
            now + int(result["expires_in"]),
            token_type=result.get("token_type", "Bearer"),
            refresh_on=refresh_on,
        )

    def authenticate(
        self, *, scopes: Optional[Iterable[str]] = None, claims: Optional[str] = None, **kwargs: Any
    ) -> AuthenticationRecord:
        """Interactively authenticate a user. This method will always generate a challenge to the user.

        :keyword Iterable[str] scopes: scopes to request during authentication, such as those provided by
            :func:`AuthenticationRequiredError.scopes`. If provided, successful authentication will cache an access token
            for these scopes. If omitted, a default management scope is chosen from the credential's authority.
        :keyword str claims: additional claims required in the token, such as those provided by
            :func:`AuthenticationRequiredError.claims`

        :return: a record describing the authenticated account
        :rtype: ~azure.identity.AuthenticationRecord
        :raises CredentialUnavailableError: no scopes were given and the credential's authority has no default
            scope
        :raises ~azure.core.exceptions.ClientAuthenticationError: authentication failed. The error's ``message``
            attribute gives a reason.
        """
        if not scopes:
            if self._authority not in _DEFAULT_AUTHENTICATE_SCOPES:
                # the credential is configured to use a cloud whose ARM scope we can't determine
                raise CredentialUnavailableError(
                    message="Authenticating in this environment requires a value for the 'scopes' keyword argument."
                )

            scopes = _DEFAULT_AUTHENTICATE_SCOPES[self._authority]

        # _allow_prompt=True forces interactive authentication even when automatic authentication is disabled
        _ = self.get_token(*scopes, _allow_prompt=True, claims=claims, **kwargs)
        return self._auth_record  # type: ignore

    @wrap_exceptions
    def _acquire_token_silent(self, *scopes: str, **kwargs: Any) -> AccessTokenInfo:
        """Acquire a token from MSAL without user interaction.

        Only cached accounts matching the authentication record's username and home account ID are tried.

        :param str scopes: desired scopes for the access token
        :return: the acquired token
        :rtype: ~azure.core.credentials.AccessTokenInfo
        :raises AuthenticationRequiredError: no token could be acquired silently. This happens when there is no
            authentication record, no matching cached account, or an MSAL result without a usable access token.
        """
        # NOTE(review): wrap_exceptions (imported from .._internal) is expected to re-raise non-azure-core
        # errors (e.g. from MSAL) as ClientAuthenticationError while passing AuthenticationRequiredError through
        # unchanged, as _get_token_base relies on; confirm against its definition.
        result = None
        claims = kwargs.get("claims")
        if self._auth_record:
            app = self._get_app(**kwargs)
            for account in app.get_accounts(username=self._auth_record.username):
                if account.get("home_account_id") != self._auth_record.home_account_id:
                    continue

                now = int(time.time())
                result = app.acquire_token_silent_with_error(list(scopes), account=account, claims_challenge=claims)
                if result and "access_token" in result and "expires_in" in result:
                    refresh_on = int(result["refresh_on"]) if "refresh_on" in result else None
                    return AccessTokenInfo(
                        result["access_token"],
                        now + int(result["expires_in"]),
                        token_type=result.get("token_type", "Bearer"),
                        refresh_on=refresh_on,
                    )

        # if we get this far, result is either None or the content of a Microsoft Entra ID error response
        if result:
            response = self._client.get_error_response(result)
            raise AuthenticationRequiredError(scopes, claims=claims, response=response)
        raise AuthenticationRequiredError(scopes, claims=claims)

    @abc.abstractmethod
    def _request_token(self, *scopes, **kwargs) -> Dict:
        """Interactively request a token from MSAL. Subclasses must implement this.

        :param str scopes: desired scopes for the access token
        :return: the MSAL result. On success, :func:`_get_token_base` reads "access_token" and "expires_in" (and
            optionally "refresh_on"/"token_type"), plus "id_token_claims" and optionally "client_info" to build the
            authentication record. On failure, it reads "error_description" or "error".
        :rtype: dict
        """