# GraphRag/graphrag-ollama/lib/python3.12/site-packages/azure/identity/_credentials/silent.py
# ------------------------------------
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.
# ------------------------------------
import platform
import time
from typing import Dict, Optional, Any

from msal import PublicClientApplication, TokenCache

from azure.core.credentials import AccessToken, AccessTokenInfo, TokenRequestOptions
from azure.core.exceptions import ClientAuthenticationError

from .. import CredentialUnavailableError
from .._internal import resolve_tenant, validate_tenant_id, within_dac
from .._internal.decorators import wrap_exceptions
from .._internal.msal_client import MsalClient
from .._internal.shared_token_cache import NO_TOKEN
from .._persistent_cache import _load_persistent_cache, TokenCachePersistenceOptions
from .. import AuthenticationRecord
class SilentAuthenticationCredential:
    """Internal class for authenticating from the default shared cache given an AuthenticationRecord.

    The only MSAL acquisition API this credential calls is ``acquire_token_silent_with_error``; it never starts an
    interactive flow.

    :param authentication_record: an AuthenticationRecord from which to authenticate
    :type authentication_record: ~azure.identity.AuthenticationRecord

    :keyword str tenant_id: tenant ID of the application the credential is authenticating for. Defaults to the tenant
        recorded in ``authentication_record``.
    :keyword cache_persistence_options: options used to load the persistent token cache. When omitted, the user's
        default cache is loaded and unencrypted storage is allowed.
    :paramtype cache_persistence_options: ~azure.identity.TokenCachePersistenceOptions
    :keyword additionally_allowed_tenants: forwarded to ``resolve_tenant`` when choosing the tenant for a request.
        Defaults to an empty list.

    Any other keyword arguments are passed to the ``MsalClient`` that MSAL uses as its HTTP client.
    """

    def __init__(
        self, authentication_record: AuthenticationRecord, *, tenant_id: Optional[str] = None, **kwargs
    ) -> None:
        self._auth_record = authentication_record

        # authenticate in the tenant that produced the record unless "tenant_id" specifies another
        self._tenant_id = tenant_id or self._auth_record.tenant_id
        validate_tenant_id(self._tenant_id)

        # Private keywords: caller-supplied caches. If either one is given, the caches count as "custom" and are
        # kept in the pickled state instead of being dropped (see __getstate__).
        self._cache = kwargs.pop("_cache", None)
        self._cae_cache = kwargs.pop("_cae_cache", None)
        self._custom_cache = bool(self._cache or self._cae_cache)

        self._cache_persistence_options = kwargs.pop("cache_persistence_options", None)

        # MSAL applications are created on first use and memoized per tenant. CAE and non-CAE requests use separate
        # maps because their applications are built with different client capabilities and token caches.
        self._client_applications: Dict[str, PublicClientApplication] = {}
        self._cae_client_applications: Dict[str, PublicClientApplication] = {}
        self._additionally_allowed_tenants = kwargs.pop("additionally_allowed_tenants", [])

        # the remaining keyword arguments configure the HTTP client handed to MSAL
        self._client = MsalClient(**kwargs)

    def __enter__(self) -> "SilentAuthenticationCredential":
        self._client.__enter__()
        return self

    def __exit__(self, *args):
        self._client.__exit__(*args)

    def close(self) -> None:
        """Close the credential's underlying HTTP client."""
        self.__exit__()

    def get_token(
        self,
        *scopes: str,
        claims: Optional[str] = None,
        tenant_id: Optional[str] = None,
        enable_cae: bool = False,
        **kwargs: Any,
    ) -> AccessToken:
        """Silently acquire an access token for ``scopes`` using the cached account.

        :param str scopes: desired scopes for the access token; at least one is required
        :keyword str claims: claims challenge passed to MSAL with the silent request
        :keyword str tenant_id: tenant requested by the caller; passed to ``resolve_tenant`` together with the
            credential's own tenant and its additionally allowed tenants
        :keyword bool enable_cae: use the CAE token cache and a client application declaring the "CP1" capability
        :return: an access token
        :rtype: ~azure.core.credentials.AccessToken
        :raises ValueError: no scopes were given
        :raises ~azure.identity.CredentialUnavailableError: the cache or a matching account/token is unavailable
        :raises ~azure.core.exceptions.ClientAuthenticationError: the cache is unavailable (outside
            ``within_dac``) or MSAL's silent request returned an error
        """
        options: TokenRequestOptions = {}
        if claims:
            options["claims"] = claims
        if tenant_id:
            options["tenant_id"] = tenant_id
        options["enable_cae"] = enable_cae

        token_info = self._get_token_base(*scopes, options=options, base_method_name="get_token", **kwargs)
        return AccessToken(token_info.token, token_info.expires_on)

    def get_token_info(self, *scopes: str, options: Optional[TokenRequestOptions] = None) -> AccessTokenInfo:
        """Silently acquire token information for ``scopes`` using the cached account.

        :param str scopes: desired scopes for the access token; at least one is required
        :param options: optional request options; the keys read are "claims", "tenant_id" and "enable_cae"
        :type options: ~azure.core.credentials.TokenRequestOptions
        :return: token information, including ``refresh_on`` when MSAL reports it
        :rtype: ~azure.core.credentials.AccessTokenInfo
        :raises ValueError: no scopes were given
        """
        return self._get_token_base(*scopes, options=options, base_method_name="get_token_info")

    def _get_token_base(
        self,
        *scopes: str,
        options: Optional[TokenRequestOptions] = None,
        base_method_name: str = "get_token_info",
        **kwargs: Any,
    ) -> AccessTokenInfo:
        """Shared implementation of ``get_token`` and ``get_token_info``.

        :param str scopes: desired scopes for the access token
        :keyword options: request options ("claims", "tenant_id", "enable_cae")
        :keyword str base_method_name: public method name quoted in the error raised when no scopes are given
        :return: token information
        :rtype: ~azure.core.credentials.AccessTokenInfo
        """
        if not scopes:
            raise ValueError(f"'{base_method_name}' requires at least one scope")

        options = options or {}
        claims = options.get("claims")
        tenant_id = options.get("tenant_id")
        enable_cae = options.get("enable_cae", False)

        # CAE and non-CAE tokens live in separate caches; load the one this request needs if it isn't loaded yet.
        token_cache = self._cae_cache if enable_cae else self._cache
        if not token_cache:
            token_cache = self._initialize_cache(is_cae=enable_cae)

        # Loading failed, so there is nothing to authenticate from.
        if not token_cache:
            if within_dac.get():
                raise CredentialUnavailableError(message="Shared token cache unavailable")
            raise ClientAuthenticationError(message="Shared token cache unavailable")

        return self._acquire_token_silent(
            *scopes, claims=claims, tenant_id=tenant_id, enable_cae=enable_cae, **kwargs
        )

    def _initialize_cache(self, is_cae: bool = False) -> Optional[TokenCache]:
        """Load the persistent (CAE or non-CAE) token cache if it isn't already set.

        :param bool is_cae: load the CAE cache instead of the regular one
        :return: the requested cache, or None when loading it raised
        :rtype: ~msal.TokenCache or None
        :raises ~azure.identity.CredentialUnavailableError: the platform isn't Darwin, Linux or Windows
        """
        # If no cache options were provided, the default cache will be used. This credential accepts the
        # user's default cache regardless of whether it's encrypted. It doesn't create a new cache. If the
        # default cache exists, the user must have created it earlier. If it's unencrypted, the user must
        # have allowed that.
        cache_options = self._cache_persistence_options or TokenCachePersistenceOptions(allow_unencrypted_storage=True)

        if platform.system() not in {"Darwin", "Linux", "Windows"}:
            raise CredentialUnavailableError(message="Shared token cache is not supported on this platform.")

        # Load failures are deliberately swallowed: the caller turns a None result into a
        # "Shared token cache unavailable" error.
        if is_cae:
            if not self._cae_cache:
                try:
                    self._cae_cache = _load_persistent_cache(cache_options, is_cae)
                except Exception:  # pylint:disable=broad-except
                    return None
            return self._cae_cache

        if not self._cache:
            try:
                self._cache = _load_persistent_cache(cache_options, is_cae)
            except Exception:  # pylint:disable=broad-except
                return None
        return self._cache

    def _get_client_application(self, **kwargs: Any):
        """Return the memoized MSAL PublicClientApplication for the tenant resolved from ``kwargs``.

        :keyword bool enable_cae: select the CAE application map and cache, and declare the "CP1" capability
        :return: the client application for the resolved tenant
        :rtype: ~msal.PublicClientApplication
        """
        tenant_id = resolve_tenant(
            self._tenant_id, additionally_allowed_tenants=self._additionally_allowed_tenants, **kwargs
        )

        client_applications_map = self._client_applications
        capabilities = None
        token_cache = self._cache
        if kwargs.get("enable_cae"):
            client_applications_map = self._cae_client_applications
            # CP1 = can handle claims challenges (CAE)
            capabilities = ["CP1"]
            token_cache = self._cae_cache

        if tenant_id not in client_applications_map:
            client_applications_map[tenant_id] = PublicClientApplication(
                client_id=self._auth_record.client_id,
                authority=f"https://{self._auth_record.authority}/{tenant_id}",
                token_cache=token_cache,
                http_client=self._client,
                client_capabilities=capabilities,
            )

        return client_applications_map[tenant_id]

    def _acquire_token_silent(self, *scopes: str, **kwargs: Any) -> AccessTokenInfo:
        """Silently acquire a token from MSAL.

        Only cached accounts whose ``home_account_id`` matches the authentication record are tried, in the order MSAL
        lists them. The first result containing an access token is returned.

        :param str scopes: desired scopes for the access token
        :keyword str claims: claims challenge forwarded to MSAL
        :return: token information for the first matching account MSAL could serve
        :rtype: ~azure.core.credentials.AccessTokenInfo
        :raises ~azure.identity.CredentialUnavailableError: no cached account matches the record's username, or no
            matching account yielded a result
        :raises ~azure.core.exceptions.ClientAuthenticationError: MSAL returned an error result for a matching
            account
        """
        result = None

        client_application = self._get_client_application(**kwargs)
        accounts_for_user = client_application.get_accounts(username=self._auth_record.username)
        if not accounts_for_user:
            raise CredentialUnavailableError("The cache contains no account matching the given AuthenticationRecord.")

        for account in accounts_for_user:
            if account.get("home_account_id") != self._auth_record.home_account_id:
                continue

            # capture the time before the request so the computed expiry errs on the early side
            now = int(time.time())
            result = client_application.acquire_token_silent_with_error(
                list(scopes), account=account, claims_challenge=kwargs.get("claims")
            )
            if result and "access_token" in result and "expires_in" in result:
                refresh_on = int(result["refresh_on"]) if "refresh_on" in result else None
                return AccessTokenInfo(
                    result["access_token"],
                    now + int(result["expires_in"]),
                    token_type=result.get("token_type", "Bearer"),
                    refresh_on=refresh_on,
                )

        # if we get this far, the cache contained a matching account but MSAL failed to authenticate it silently
        if result:
            # cache contains a matching refresh token but STS returned an error response when MSAL tried to use it
            message = "Token acquisition failed"
            details = result.get("error_description") or result.get("error")
            if details:
                message += f": {details}"
            raise ClientAuthenticationError(message=message)

        # cache doesn't contain a matching refresh (or access) token
        raise CredentialUnavailableError(message=NO_TOKEN.format(self._auth_record.username))

    def __getstate__(self) -> Dict[str, Any]:
        """Return picklable state; caches loaded by this credential are dropped and reloaded on demand."""
        state = self.__dict__.copy()
        # Remove the non-picklable entries
        if not self._custom_cache:
            del state["_cache"]
            del state["_cae_cache"]
        return state

    def __setstate__(self, state: Dict[str, Any]) -> None:
        """Restore state produced by ``__getstate__``; dropped caches are reset so they load lazily again."""
        self.__dict__.update(state)
        # Re-create the unpickable entries
        if not self._custom_cache:
            self._cache = None
            self._cae_cache = None