GotoUsuke's picture
Upload folder using huggingface_hub
ab4488b verified
# ------------------------------------
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.
# ------------------------------------
import platform
import time
from typing import Dict, Optional, Any
from msal import PublicClientApplication, TokenCache
from azure.core.credentials import AccessToken, AccessTokenInfo, TokenRequestOptions
from azure.core.exceptions import ClientAuthenticationError
from .. import CredentialUnavailableError
from .._internal import resolve_tenant, validate_tenant_id, within_dac
from .._internal.decorators import wrap_exceptions
from .._internal.msal_client import MsalClient
from .._internal.shared_token_cache import NO_TOKEN
from .._persistent_cache import _load_persistent_cache, TokenCachePersistenceOptions
from .. import AuthenticationRecord
class SilentAuthenticationCredential:
    """Internal class for authenticating from the default shared cache given an AuthenticationRecord.

    :param authentication_record: an AuthenticationRecord from which to authenticate
    :type authentication_record: ~azure.identity.AuthenticationRecord

    :keyword str tenant_id: tenant ID of the application the credential is authenticating for. Defaults to the tenant
        recorded in ``authentication_record``.
    :keyword cache_persistence_options: options used to load the persistent token cache. When omitted, the user's
        default cache is loaded with unencrypted storage allowed.
    :paramtype cache_persistence_options: ~azure.identity.TokenCachePersistenceOptions
    :keyword List[str] additionally_allowed_tenants: forwarded to ``resolve_tenant`` when choosing the tenant for a
        request. Defaults to an empty list.

    Private keywords ``_cache`` and ``_cae_cache`` supply pre-built token caches; when either is given the caches are
    treated as custom and are kept when the credential is pickled. All remaining keyword arguments are passed to
    ``MsalClient``.
    """

    def __init__(
        self, authentication_record: AuthenticationRecord, *, tenant_id: Optional[str] = None, **kwargs: Any
    ) -> None:
        self._auth_record = authentication_record

        # authenticate in the tenant that produced the record unless "tenant_id" specifies another
        self._tenant_id = tenant_id or self._auth_record.tenant_id
        validate_tenant_id(self._tenant_id)

        # Caches supplied by the caller are "custom": they survive pickling. Caches loaded from persistent
        # storage are dropped by __getstate__ and reloaded lazily by _initialize_cache.
        self._cache = kwargs.pop("_cache", None)
        self._cae_cache = kwargs.pop("_cae_cache", None)
        self._custom_cache = bool(self._cache or self._cae_cache)

        self._cache_persistence_options = kwargs.pop("cache_persistence_options", None)

        # One MSAL application per resolved tenant, kept separately for CAE and non-CAE requests because
        # they are bound to different token caches and client capabilities.
        self._client_applications: Dict[str, PublicClientApplication] = {}
        self._cae_client_applications: Dict[str, PublicClientApplication] = {}
        self._additionally_allowed_tenants = kwargs.pop("additionally_allowed_tenants", [])
        self._client = MsalClient(**kwargs)

    def __enter__(self) -> "SilentAuthenticationCredential":
        self._client.__enter__()
        return self

    def __exit__(self, *args):
        self._client.__exit__(*args)

    def close(self) -> None:
        """Close the underlying MSAL HTTP client."""
        self.__exit__()

    def get_token(
        self,
        *scopes: str,
        claims: Optional[str] = None,
        tenant_id: Optional[str] = None,
        enable_cae: bool = False,
        **kwargs: Any,
    ) -> AccessToken:
        """Request an access token for `scopes` from the shared cache.

        :param str scopes: desired scopes for the access token. At least one scope is required.
        :keyword str claims: additional claims required in the token, forwarded to MSAL as a claims challenge.
        :keyword str tenant_id: optional tenant to include in the token request.
        :keyword bool enable_cae: whether to use the CAE token cache and request CAE-capable tokens.
        :return: An access token with the desired scopes.
        :rtype: ~azure.core.credentials.AccessToken
        :raises ValueError: no scopes were given.
        :raises ~azure.identity.CredentialUnavailableError: the cache is unavailable (inside DefaultAzureCredential)
            or holds no usable token for the record's account.
        :raises ~azure.core.exceptions.ClientAuthenticationError: the cache is unavailable or token acquisition failed.
        """
        options: TokenRequestOptions = {}
        if claims:
            options["claims"] = claims
        if tenant_id:
            options["tenant_id"] = tenant_id
        options["enable_cae"] = enable_cae

        token_info = self._get_token_base(*scopes, options=options, base_method_name="get_token", **kwargs)
        return AccessToken(token_info.token, token_info.expires_on)

    def get_token_info(self, *scopes: str, options: Optional[TokenRequestOptions] = None) -> AccessTokenInfo:
        """Request access token information for `scopes` from the shared cache.

        :param str scopes: desired scopes for the access token. At least one scope is required.
        :keyword options: request options; ``claims``, ``tenant_id`` and ``enable_cae`` are honored.
        :paramtype options: ~azure.core.credentials.TokenRequestOptions
        :return: The access token information, including ``refresh_on`` when MSAL reports it.
        :rtype: ~azure.core.credentials.AccessTokenInfo
        """
        return self._get_token_base(*scopes, options=options, base_method_name="get_token_info")

    def _get_token_base(
        self,
        *scopes: str,
        options: Optional[TokenRequestOptions] = None,
        base_method_name: str = "get_token_info",
        **kwargs: Any,
    ) -> AccessTokenInfo:
        """Shared implementation of get_token and get_token_info.

        :param str scopes: desired scopes for the access token.
        :keyword options: request options (``claims``, ``tenant_id``, ``enable_cae``).
        :keyword str base_method_name: name of the public method, used in the error message for missing scopes.
        :return: The access token information.
        :rtype: ~azure.core.credentials.AccessTokenInfo
        """
        if not scopes:
            raise ValueError(f"'{base_method_name}' requires at least one scope")

        options = options or {}
        claims = options.get("claims")
        tenant_id = options.get("tenant_id")
        enable_cae = options.get("enable_cae", False)

        token_cache = self._cae_cache if enable_cae else self._cache

        # Try to load the cache if it is None.
        if not token_cache:
            token_cache = self._initialize_cache(is_cae=enable_cae)

            # If the cache is still None, raise an error.
            if not token_cache:
                # Inside DefaultAzureCredential an unavailable credential lets the chain move on.
                if within_dac.get():
                    raise CredentialUnavailableError(message="Shared token cache unavailable")
                raise ClientAuthenticationError(message="Shared token cache unavailable")

        return self._acquire_token_silent(*scopes, claims=claims, tenant_id=tenant_id, enable_cae=enable_cae, **kwargs)

    def _initialize_cache(self, is_cae: bool = False) -> Optional[TokenCache]:
        """Load the persistent (CAE or non-CAE) token cache if it isn't loaded yet.

        :param bool is_cae: whether to load the CAE cache instead of the regular cache.
        :return: The requested cache, or None when loading it raised.
        :rtype: ~msal.TokenCache or None
        :raises ~azure.identity.CredentialUnavailableError: the platform is not Darwin, Linux or Windows.
        """
        # If no cache options were provided, the default cache will be used. This credential accepts the
        # user's default cache regardless of whether it's encrypted. It doesn't create a new cache. If the
        # default cache exists, the user must have created it earlier. If it's unencrypted, the user must
        # have allowed that.
        cache_options = self._cache_persistence_options or TokenCachePersistenceOptions(allow_unencrypted_storage=True)

        if platform.system() not in {"Darwin", "Linux", "Windows"}:
            raise CredentialUnavailableError(message="Shared token cache is not supported on this platform.")

        if not self._cache and not is_cae:
            try:
                self._cache = _load_persistent_cache(cache_options, is_cae)
            except Exception:  # pylint:disable=broad-except
                # Best effort: the caller turns a missing cache into an authentication error.
                return None

        if not self._cae_cache and is_cae:
            try:
                self._cae_cache = _load_persistent_cache(cache_options, is_cae)
            except Exception:  # pylint:disable=broad-except
                return None

        return self._cae_cache if is_cae else self._cache

    def _get_client_application(self, **kwargs: Any):
        """Return the MSAL application for the resolved tenant, creating and memoizing it on first use.

        :keyword bool enable_cae: when truthy, use the CAE application map, CAE cache and "CP1" capability.
        :return: The MSAL public client application bound to the appropriate token cache.
        :rtype: ~msal.PublicClientApplication
        """
        tenant_id = resolve_tenant(
            self._tenant_id, additionally_allowed_tenants=self._additionally_allowed_tenants, **kwargs
        )
        client_applications_map = self._client_applications
        capabilities = None
        token_cache = self._cache

        if kwargs.get("enable_cae"):
            client_applications_map = self._cae_client_applications
            # CP1 = can handle claims challenges (CAE)
            capabilities = ["CP1"]
            token_cache = self._cae_cache

        if tenant_id not in client_applications_map:
            client_applications_map[tenant_id] = PublicClientApplication(
                client_id=self._auth_record.client_id,
                authority=f"https://{self._auth_record.authority}/{tenant_id}",
                token_cache=token_cache,
                http_client=self._client,
                client_capabilities=capabilities,
            )

        return client_applications_map[tenant_id]

    @wrap_exceptions
    def _acquire_token_silent(self, *scopes: str, **kwargs: Any) -> AccessTokenInfo:
        """Silently acquire a token from MSAL.

        :param str scopes: desired scopes for the access token
        :return: access token information
        :rtype: ~azure.core.credentials.AccessTokenInfo
        :raises ~azure.identity.CredentialUnavailableError: the cache has no matching account, or no token for it.
        :raises ~azure.core.exceptions.ClientAuthenticationError: MSAL returned an error while redeeming a cached
            token.
        """
        result = None

        client_application = self._get_client_application(**kwargs)
        accounts_for_user = client_application.get_accounts(username=self._auth_record.username)
        if not accounts_for_user:
            raise CredentialUnavailableError("The cache contains no account matching the given AuthenticationRecord.")

        for account in accounts_for_user:
            # the username alone may match several accounts; only the record's home account is acceptable
            if account.get("home_account_id") != self._auth_record.home_account_id:
                continue

            # capture the time before the request so the computed expiry errs on the early side
            now = int(time.time())
            result = client_application.acquire_token_silent_with_error(
                list(scopes), account=account, claims_challenge=kwargs.get("claims")
            )
            if result and "access_token" in result and "expires_in" in result:
                refresh_on = int(result["refresh_on"]) if "refresh_on" in result else None
                return AccessTokenInfo(
                    result["access_token"],
                    now + int(result["expires_in"]),
                    token_type=result.get("token_type", "Bearer"),
                    refresh_on=refresh_on,
                )

        # if we get this far, the cache contained a matching account but MSAL failed to authenticate it silently
        if result:
            # cache contains a matching refresh token but STS returned an error response when MSAL tried to use it
            message = "Token acquisition failed"
            details = result.get("error_description") or result.get("error")
            if details:
                message += f": {details}"
            raise ClientAuthenticationError(message=message)

        # cache doesn't contain a matching refresh (or access) token
        raise CredentialUnavailableError(message=NO_TOKEN.format(self._auth_record.username))

    def __getstate__(self) -> Dict[str, Any]:
        state = self.__dict__.copy()
        # Remove the non-picklable entries
        if not self._custom_cache:
            del state["_cache"]
            del state["_cae_cache"]
            # The MSAL applications were created with token_cache=self._cache / self._cae_cache, so they still
            # reference the caches removed above. Drop them as well; _get_client_application recreates them
            # lazily against the reloaded caches. Assigning into the copy leaves this instance untouched.
            state["_client_applications"] = {}
            state["_cae_client_applications"] = {}
        return state

    def __setstate__(self, state: Dict[str, Any]) -> None:
        self.__dict__.update(state)
        # Re-create the unpicklable entries
        if not self._custom_cache:
            self._cache = None
            self._cae_cache = None
            # Discard any applications carried in the state so none stay bound to a stale cache.
            self._client_applications = {}
            self._cae_client_applications = {}