Spaces:
Runtime error
Runtime error
| import logging | |
| import os | |
| from typing import TYPE_CHECKING, Union | |
| from .constants import FINETRAINERS_LOG_LEVEL | |
| if TYPE_CHECKING: | |
| from .parallel import ParallelBackendType | |
class FinetrainersLoggerAdapter(logging.LoggerAdapter):
    """Logger adapter that decides, per process, whether a record is emitted.

    When no parallel backend is attached, only the process whose ``RANK``
    environment variable is ``0`` (or unset) logs. Once a backend is attached,
    the ``main_process_only`` / ``local_main_process_only`` / ``in_order``
    flags of :meth:`log` select which processes emit the record.
    """

    def __init__(
        self,
        logger: logging.Logger,
        parallel_backend: Union["ParallelBackendType", None] = None,
    ) -> None:
        """Wrap ``logger``; ``parallel_backend`` may be attached later."""
        super().__init__(logger, {})
        self.parallel_backend = parallel_backend
        # State for `log_freq`, keyed by the caller-chosen name:
        # the frequency registered on first use, and the number of calls so far.
        self._log_freq = {}
        self._log_freq_counter = {}

    def log(
        self,
        level,
        msg,
        *args,
        main_process_only: bool = False,
        local_main_process_only: bool = True,
        in_order: bool = False,
        **kwargs,
    ):
        """Emit ``msg`` at ``level`` on the processes selected by the flags.

        Args:
            level: Numeric logging level (e.g. ``logging.INFO``).
            msg: Message or %-style format string.
            *args: Arguments merged into ``msg`` by the logging machinery.
            main_process_only: Log when the backend reports ``is_main_process``.
            local_main_process_only: Log when the backend reports
                ``is_local_main_process``.
            in_order: Log on every process, one rank at a time, calling the
                backend's ``wait_for_everyone()`` after each rank's turn.
            **kwargs: Forwarded to the wrapped logger (``exc_info``,
                ``stacklevel``, ...).

        Raises:
            ValueError: If a backend is attached and ``in_order`` is combined
                with ``main_process_only`` or ``local_main_process_only``.
        """
        # set `stacklevel` to exclude ourself in `Logger.findCaller()` while respecting user's choice
        # NOTE: the emit calls below are intentionally inlined rather than moved
        # into a helper; an extra frame would shift the reported caller.
        kwargs.setdefault("stacklevel", 2)

        if not self.isEnabledFor(level):
            return

        if self.parallel_backend is None:
            # No backend yet: fall back to the launcher-provided RANK variable.
            if int(os.environ.get("RANK", 0)) == 0:
                msg, kwargs = self.process(msg, kwargs)
                self.logger.log(level, msg, *args, **kwargs)
            return

        if (main_process_only or local_main_process_only) and in_order:
            raise ValueError(
                "Cannot set `main_process_only` or `local_main_process_only` to True while `in_order` is True."
            )

        if (main_process_only and self.parallel_backend.is_main_process) or (
            local_main_process_only and self.parallel_backend.is_local_main_process
        ):
            msg, kwargs = self.process(msg, kwargs)
            self.logger.log(level, msg, *args, **kwargs)
            return

        if in_order:
            # The rank comes from the backend: the adapter itself defines no
            # `rank` attribute, so `self.rank` raised AttributeError here.
            # NOTE(review): assumes the backend exposes `rank` next to
            # `world_size` -- confirm against the ParallelBackendType classes.
            current_rank = self.parallel_backend.rank
            for i in range(self.parallel_backend.world_size):
                if current_rank == i:
                    msg, kwargs = self.process(msg, kwargs)
                    self.logger.log(level, msg, *args, **kwargs)
                self.parallel_backend.wait_for_everyone()
            return

        if not main_process_only and not local_main_process_only:
            # No restriction requested: every process logs.
            msg, kwargs = self.process(msg, kwargs)
            self.logger.log(level, msg, *args, **kwargs)
            return

    def log_freq(
        self,
        level: int,
        name: str,
        msg: str,
        frequency: int,
        *,
        main_process_only: bool = False,
        local_main_process_only: bool = True,
        in_order: bool = False,
        **kwargs,
    ) -> None:
        """Log ``msg`` on the first and then every ``frequency``-th call for ``name``.

        The frequency is registered the first time ``name`` is seen; values
        passed on later calls for the same ``name`` are not re-registered.
        Calls with ``frequency <= 0`` are ignored and not counted.

        Args:
            level: Numeric logging level, forwarded to :meth:`log`.
            name: Key identifying the call-count series.
            msg: Message to log.
            frequency: Emit once per this many calls.
            main_process_only: Forwarded to :meth:`log`.
            local_main_process_only: Forwarded to :meth:`log`.
            in_order: Forwarded to :meth:`log`.
            **kwargs: Forwarded to :meth:`log`.
        """
        if frequency <= 0:
            return

        if name not in self._log_freq_counter:
            self._log_freq[name] = frequency
            self._log_freq_counter[name] = 0

        if self._log_freq_counter[name] % self._log_freq[name] == 0:
            # One frame deeper than a direct `log` call, so skip this method
            # too when attributing the record to a caller.
            kwargs.setdefault("stacklevel", 3)
            self.log(
                level,
                msg,
                main_process_only=main_process_only,
                local_main_process_only=local_main_process_only,
                in_order=in_order,
                **kwargs,
            )

        self._log_freq_counter[name] += 1
def get_logger() -> Union[logging.Logger, FinetrainersLoggerAdapter]:
    """Return the module-level logger object bound to ``_logger``."""
    # Reading a module global needs no `global` declaration.
    return _logger
def _set_parallel_backend(parallel_backend: "ParallelBackendType") -> FinetrainersLoggerAdapter:
    """Attach ``parallel_backend`` to the module-level logger and return that logger.

    The function was annotated as returning the adapter but fell off the end
    and returned ``None``; it now returns the object bound to ``_logger`` so
    the result matches the declared type. Callers that ignore the return value
    are unaffected.
    """
    _logger.parallel_backend = parallel_backend
    return _logger
# --- Module-level logger setup (runs once at import time) ---

# Base logger for the package, filtered at the configured level.
_logger = logging.getLogger("finetrainers")
_logger.setLevel(FINETRAINERS_LOG_LEVEL)

# Console handler with a timestamped format, filtered at the same level.
_formatter = logging.Formatter("%(asctime)s - %(name)s - %(levelname)s - %(message)s")
_console_handler = logging.StreamHandler()
_console_handler.setFormatter(_formatter)
_console_handler.setLevel(FINETRAINERS_LOG_LEVEL)
_logger.addHandler(_console_handler)

# Rebind `_logger` to the adapter wrapping the base logger; no parallel
# backend is attached at this point.
_logger = FinetrainersLoggerAdapter(_logger)