Spaces:
Runtime error
Runtime error
# coding=utf-8 | |
# Copyright 2021-present, the Recognai S.L. team. | |
# | |
# Licensed under the Apache License, Version 2.0 (the "License"); | |
# you may not use this file except in compliance with the License. | |
# You may obtain a copy of the License at | |
# | |
# http://www.apache.org/licenses/LICENSE-2.0 | |
# | |
# Unless required by applicable law or agreed to in writing, software | |
# distributed under the License is distributed on an "AS IS" BASIS, | |
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |
# See the License for the specific language governing permissions and | |
# limitations under the License. | |
import asyncio | |
import importlib | |
import os | |
import threading | |
import warnings | |
from itertools import chain | |
from types import ModuleType | |
from typing import Any, Optional, Tuple | |
class LazyargillaModule(ModuleType):
    """Module class that surfaces all objects but only performs associated imports when the objects are requested.

    Shamelessly copied and adapted from the Hugging Face transformers implementation.

    Attribute lookups that are not yet cached on the module are resolved, in order, against:
    the ``extra_objects`` mapping, the submodule names of ``import_structure``, the object names
    of ``import_structure``, and finally the submodule/object names of
    ``deprecated_import_structure`` (which emit a ``FutureWarning`` when imported).
    """

    def __init__(
        self,
        name,
        module_file,
        import_structure,
        deprecated_import_structure=None,
        module_spec=None,
        extra_objects=None,
    ):
        """Build the lazy module.

        Args:
            name: Module name; also used as the package anchor for the relative imports.
            module_file: Value stored in ``__file__``; its directory becomes ``__path__``.
            import_structure: Mapping of submodule name -> iterable of object names it provides.
            deprecated_import_structure: Same shape as ``import_structure``; names listed here
                still resolve but trigger a ``FutureWarning``. Defaults to an empty mapping.
            module_spec: Value stored in ``__spec__``.
            extra_objects: Mapping of attribute name -> object returned as-is, without importing.
        """
        super().__init__(name)
        self._modules = set(import_structure.keys())
        # Reverse index: object name -> submodule that defines it.
        self._class_to_module = {
            value: key for key, values in import_structure.items() for value in values
        }
        # Needed for autocompletion in an IDE
        self.__all__ = list(import_structure.keys()) + list(
            chain.from_iterable(import_structure.values())
        )
        self.__file__ = module_file
        self.__spec__ = module_spec
        self.__path__ = [os.path.dirname(module_file)]
        self._objects = {} if extra_objects is None else extra_objects
        self._name = name
        self._import_structure = import_structure
        # deprecated stuff
        deprecated_import_structure = deprecated_import_structure or {}
        # Kept so that ``__reduce__`` can rebuild the deprecated aliases as well.
        self._deprecated_import_structure = deprecated_import_structure
        self._deprecated_modules = set(deprecated_import_structure.keys())
        self._deprecated_class_to_module = {
            value: key
            for key, values in deprecated_import_structure.items()
            for value in values
        }

    # Needed for autocompletion in an IDE
    def __dir__(self):
        """Return the regular ``dir()`` entries plus every name in ``__all__``."""
        result = super().__dir__()
        # The elements of self.__all__ that are submodules may or may not be in the dir already, depending on whether
        # they have been accessed or not. So we only add the elements of self.__all__ that are not already in the dir.
        for attr in self.__all__:
            if attr not in result:
                result.append(attr)
        return result

    def __getattr__(self, name: str) -> Any:
        """Resolve ``name`` lazily; only invoked when normal attribute lookup fails.

        Raises:
            AttributeError: If ``name`` is not declared in any of the known structures.
            RuntimeError: If importing the owning submodule fails (see ``_get_module``).
        """
        if name in self._objects:
            return self._objects[name]
        if name in self._modules:
            value = self._get_module(name)
        elif name in self._class_to_module:
            module = self._get_module(self._class_to_module[name])
            value = getattr(module, name)
        elif name in self._deprecated_modules:
            value = self._get_module(name, deprecated=True)
        elif name in self._deprecated_class_to_module:
            module = self._get_module(
                self._deprecated_class_to_module[name], deprecated=True, class_name=name
            )
            value = getattr(module, name)
        else:
            raise AttributeError(f"module {self.__name__} has no attribute {name}")
        # Cache on the instance so later lookups bypass ``__getattr__`` (and the deprecation warning).
        setattr(self, name, value)
        return value

    def _get_module(
        self,
        module_name: str,
        deprecated: bool = False,
        class_name: Optional[str] = None,
    ):
        """Import and return the submodule ``module_name`` relative to this module.

        Args:
            module_name: Submodule name, without the leading dot.
            deprecated: When true, emit a ``FutureWarning`` before importing.
            class_name: Name shown in the deprecation message instead of ``module_name``.

        Raises:
            RuntimeError: Wrapping any exception raised by the import.
        """
        if deprecated:
            warnings.warn(
                f"Importing '{class_name or module_name}' from the argilla namespace (that is "
                f"`argilla.{class_name or module_name}`) is deprecated and will not work in a future version. "
                f"Make sure you update your code accordingly.",
                category=FutureWarning,
                # Frames: 1 = _get_module, 2 = __getattr__, 3 = the code that accessed the attribute.
                stacklevel=3,
            )
        try:
            return importlib.import_module("." + module_name, self.__name__)
        except Exception as e:
            raise RuntimeError(
                f"Failed to import {self.__name__}.{module_name} because of the following error "
                f"(look up to see its traceback):\n{e}"
            ) from e

    def __reduce__(self):
        """Describe how to rebuild this module when pickling/copying.

        The deprecated import structure is included so the rebuilt module still resolves
        deprecated names. ``module_spec`` and ``extra_objects`` are not carried over.
        """
        return self.__class__, (
            self._name,
            self.__file__,
            self._import_structure,
            self._deprecated_import_structure,
        )
def limit_value_length(data: Any, max_length: int) -> Any:
    """
    Given an input data, limits string values to a max_length by fetching
    last max_length characters

    Dicts, lists, tuples (including namedtuples) and sets are rebuilt with
    their values processed recursively; dict keys are left untouched. Any
    other value is returned unchanged.

    Parameters
    ----------
    data:
        Input data
    max_length:
        Max length for string values. A value of zero or less yields empty
        strings.

    Returns
    -------
        Limited version of data, if any
    """
    if isinstance(data, str):
        # ``data[-0:]`` is the whole string and a negative limit would drop a
        # prefix instead, so non-positive limits are handled explicitly.
        return data[-max_length:] if max_length > 0 else ""
    if isinstance(data, dict):
        return {
            k: limit_value_length(v, max_length=max_length) for k, v in data.items()
        }
    if isinstance(data, (list, tuple, set)):
        new_values = [limit_value_length(x, max_length=max_length) for x in data]
        container = type(data)
        if isinstance(data, tuple) and hasattr(container, "_make"):
            # namedtuple constructors take positional fields, not an iterable.
            return container._make(new_values)
        return container(new_values)
    return data
# Module-level cache for the background event loop and the thread that runs it.
__LOOP__, __THREAD__ = None, None


def setup_loop_in_thread() -> Tuple[asyncio.AbstractEventLoop, threading.Thread]:
    """Sets up a new asyncio event loop in a new thread, and runs it forever.

    The loop/thread pair is cached in the module globals ``__LOOP__`` and
    ``__THREAD__`` and reused by later calls. The cached pair is replaced when
    the loop has been closed or the thread is no longer alive, so callers never
    receive a loop that nothing is running.

    NOTE: this function takes no lock; concurrent first calls from several
    threads could each create their own loop.

    Returns:
        A tuple containing the event loop and the thread.
    """

    def start_background_loop(loop: asyncio.AbstractEventLoop) -> None:
        asyncio.set_event_loop(loop)
        loop.run_forever()

    global __LOOP__
    global __THREAD__

    usable = (
        __LOOP__ is not None
        and __THREAD__ is not None
        and __THREAD__.is_alive()
        and not __LOOP__.is_closed()
    )
    if not usable:
        stale_loop = __LOOP__
        # Release the resources of a loop that was stopped. A loop that still
        # reports itself as running cannot be closed, so it is simply dropped.
        if (
            stale_loop is not None
            and not stale_loop.is_closed()
            and not stale_loop.is_running()
        ):
            stale_loop.close()

        loop = asyncio.new_event_loop()
        thread = threading.Thread(
            target=start_background_loop, args=(loop,), daemon=True
        )
        thread.start()
        __LOOP__, __THREAD__ = loop, thread
    return __LOOP__, __THREAD__