# Uploaded by kdamevski: "Upload folder using huggingface_hub" (commit 1c60c6e)
# coding=utf-8
# Copyright 2021-present, the Recognai S.L. team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import asyncio
import importlib
import os
import threading
import warnings
from itertools import chain
from types import ModuleType
from typing import Any, Optional, Tuple
class LazyargillaModule(ModuleType):
    """Module class that surfaces all objects but only performs associated imports when the objects are requested.

    Shamelessly copied and adapted from the Hugging Face transformers implementation.

    Args:
        name: Fully qualified name of the package this object stands in for (becomes ``__name__``).
        module_file: Path of the package's ``__init__`` file; its directory becomes ``__path__``.
        import_structure: Mapping of submodule name -> list of public names defined in that submodule.
        deprecated_import_structure: Same shape as ``import_structure``; accessing any of these names or
            submodules emits a ``FutureWarning`` before importing them.
        module_spec: Value stored as ``__spec__``.
        extra_objects: Mapping of attribute name -> object that is returned as-is, without any import.
    """

    def __init__(
        self,
        name,
        module_file,
        import_structure,
        deprecated_import_structure=None,
        module_spec=None,
        extra_objects=None,
    ):
        super().__init__(name)
        self._modules = set(import_structure.keys())
        # Reverse index: public name -> submodule that defines it (later entries win on duplicates).
        self._class_to_module = {}
        for key, values in import_structure.items():
            for value in values:
                self._class_to_module[value] = key
        # Needed for autocompletion in an IDE
        self.__all__ = list(import_structure.keys()) + list(chain(*import_structure.values()))
        self.__file__ = module_file
        self.__spec__ = module_spec
        self.__path__ = [os.path.dirname(module_file)]
        self._objects = {} if extra_objects is None else extra_objects
        self._name = name
        self._import_structure = import_structure

        # deprecated stuff
        deprecated_import_structure = deprecated_import_structure or {}
        # Kept so that ``__reduce__`` can rebuild a module with the same deprecated aliases.
        self._deprecated_import_structure = deprecated_import_structure
        self._deprecated_modules = set(deprecated_import_structure.keys())
        self._deprecated_class_to_module = {}
        for key, values in deprecated_import_structure.items():
            for value in values:
                self._deprecated_class_to_module[value] = key

    # Needed for autocompletion in an IDE
    def __dir__(self):
        result = super().__dir__()
        # The elements of self.__all__ that are submodules may or may not be in the dir already, depending on whether
        # they have been accessed or not. So we only add the elements of self.__all__ that are not already in the dir.
        for attr in self.__all__:
            if attr not in result:
                result.append(attr)
        return result

    def __getattr__(self, name: str) -> Any:
        """Resolve ``name`` lazily, importing the submodule that provides it on first access.

        Lookup order: ``extra_objects``, submodules, names exported by submodules, then the
        deprecated equivalents of the last two. The resolved value is cached on the module via
        ``setattr``, so later lookups never reach ``__getattr__`` again (and a deprecated name
        only warns on its first access per module object).

        Raises:
            AttributeError: if ``name`` is unknown, or its submodule does not define it.
            RuntimeError: if importing the providing submodule fails.
        """
        if name in self._objects:
            return self._objects[name]
        if name in self._modules:
            value = self._get_module(name)
        elif name in self._class_to_module:
            module = self._get_module(self._class_to_module[name])
            value = getattr(module, name)
        elif name in self._deprecated_modules:
            value = self._get_module(name, deprecated=True)
        elif name in self._deprecated_class_to_module:
            module = self._get_module(self._deprecated_class_to_module[name], deprecated=True, class_name=name)
            value = getattr(module, name)
        else:
            raise AttributeError(f"module {self.__name__} has no attribute {name}")
        setattr(self, name, value)
        return value

    def _get_module(
        self,
        module_name: str,
        deprecated: bool = False,
        class_name: Optional[str] = None,
    ):
        """Import and return the submodule ``<package>.<module_name>``.

        When ``deprecated`` is true, a ``FutureWarning`` mentioning ``class_name`` (or the module
        name) is emitted first. Any import failure is re-raised as ``RuntimeError`` chained to the
        original exception.
        """
        if deprecated:
            warnings.warn(
                f"Importing '{class_name or module_name}' from the argilla namespace (that is "
                f"`argilla.{class_name or module_name}`) is deprecated and will not work in a future version. "
                f"Make sure you update your code accordingly.",
                category=FutureWarning,
                # 1 = this method, 2 = __getattr__, 3 = the user code performing the attribute access.
                stacklevel=3,
            )
        try:
            return importlib.import_module("." + module_name, self.__name__)
        except Exception as e:
            raise RuntimeError(
                f"Failed to import {self.__name__}.{module_name} because of the following error "
                f"(look up to see its traceback):\n{e}"
            ) from e

    def __reduce__(self):
        # Rebuild from the constructor arguments so that deprecated aliases and extra objects
        # (e.g. ``__version__``) survive copying/pickling. ``module_spec`` is still not forwarded.
        return self.__class__, (
            self._name,
            self.__file__,
            self._import_structure,
            self._deprecated_import_structure,
            None,
            self._objects,
        )
def limit_value_length(data: Any, max_length: int) -> Any:
    """
    Given an input data, limits string values to a max_length by keeping
    their last max_length characters.

    Dicts (values only, keys are left untouched), lists, tuples (including
    named tuples), sets and frozensets are processed recursively and rebuilt
    with the same container type. Any other value is returned unchanged.

    Parameters
    ----------
    data:
        Input data
    max_length:
        Max length for string values. A value <= 0 truncates strings to "".

    Returns
    -------
    Limited version of data, if any
    """
    if isinstance(data, str):
        # Use an explicit start index: ``data[-max_length:]`` would return the
        # whole string when max_length == 0, because -0 == 0.
        return data[max(len(data) - max_length, 0) :]
    if isinstance(data, dict):
        return {k: limit_value_length(v, max_length=max_length) for k, v in data.items()}
    if isinstance(data, (list, tuple, set, frozenset)):
        new_values = [limit_value_length(x, max_length=max_length) for x in data]
        if isinstance(data, tuple) and hasattr(data, "_fields"):
            # Named tuples take one positional argument per field, not a single iterable.
            return type(data)(*new_values)
        return type(data)(new_values)
    return data
# Shared background event loop and the thread running it, created lazily.
__LOOP__, __THREAD__ = None, None
# Guards lazy creation so concurrent callers cannot each start their own loop/thread pair.
__LOOP_LOCK__ = threading.Lock()


def setup_loop_in_thread() -> Tuple[asyncio.AbstractEventLoop, threading.Thread]:
    """Sets up a new asyncio event loop in a new daemon thread, and runs it forever.

    The loop/thread pair is created on the first call and reused by later calls.
    If the thread has exited (e.g. the loop was stopped) or the loop was closed,
    a fresh pair is created instead of handing back a loop that no longer runs.
    The previous loop object is not closed here.

    Returns:
        A tuple containing the event loop and the thread.
    """

    def start_background_loop(loop: asyncio.AbstractEventLoop) -> None:
        asyncio.set_event_loop(loop)
        loop.run_forever()

    global __LOOP__
    global __THREAD__

    with __LOOP_LOCK__:
        needs_new_loop = (
            __LOOP__ is None
            or __THREAD__ is None
            or not __THREAD__.is_alive()
            or __LOOP__.is_closed()
        )
        if needs_new_loop:
            loop = asyncio.new_event_loop()
            thread = threading.Thread(target=start_background_loop, args=(loop,), daemon=True)
            thread.start()
            __LOOP__, __THREAD__ = loop, thread
        return __LOOP__, __THREAD__