import logging

import torch

try:
    from transformers import get_scheduler
except ImportError as _exc:
    # transformers is only needed for the "huggingface_*" scheduler types; a
    # missing install must not prevent using the pure-torch schedulers.
    get_scheduler = None
    _TRANSFORMERS_IMPORT_ERROR = _exc
else:
    _TRANSFORMERS_IMPORT_ERROR = None

# Prefix that routes a scheduler_type to transformers.get_scheduler.
_HF_PREFIX = "huggingface_"


class DummyScheduler:
    """No-op scheduler used when learning-rate scheduling is disabled."""

    def step(self, *args, **kwargs):
        """Accept and ignore any arguments; the learning rate is not changed."""
        pass


class SmartScheduler:
    """Uniform wrapper around several learning-rate schedulers.

    The training loop may call :meth:`step` both after every batch and after
    every epoch. The wrapper advances the underlying scheduler only on the
    calls that match its cadence (see ``is_batch_level``).

    Supported ``scheduler_type`` values (case-insensitive):

    * ``"plateau"``  -- ``ReduceLROnPlateau`` (mode="max"), stepped once per
      epoch with a metric.
    * ``"cosine"``   -- ``CosineAnnealingLR`` over ``config.num_epochs``,
      stepped once per epoch.
    * ``"onecycle"`` -- ``OneCycleLR`` with ``max_lr=config.lr``, stepped once
      per batch.
    * ``"huggingface_<name>"`` -- ``transformers.get_scheduler(<name>, ...)``.
      Warmup covers ``config.warmup_ratio`` of all training steps. Stepped once
      per batch. Requires the ``transformers`` package.
    * ``"none"`` -- no scheduling (:class:`DummyScheduler`).

    Attributes:
        scheduler_type: the lower-cased scheduler type.
        scheduler: the wrapped scheduler instance.
        is_batch_level: True when the wrapped scheduler advances per batch.
    """

    def __init__(self, scheduler_type, optimizer, config, steps_per_epoch):
        """Build the scheduler selected by ``scheduler_type``.

        Args:
            scheduler_type: one of the values listed in the class docstring.
            optimizer: torch optimizer whose learning rate is scheduled.
            config: object exposing:

                * ``num_epochs`` (cosine, onecycle, huggingface),
                * ``lr`` (onecycle),
                * ``warmup_ratio`` (huggingface).

            steps_per_epoch: number of batches per epoch. It must be positive
                for the per-batch schedulers.

        Raises:
            ValueError: if ``scheduler_type`` is unknown, or if
                ``steps_per_epoch <= 0`` for a per-batch scheduler.
            ImportError: if a ``huggingface_*`` type is requested but
                ``transformers`` is not installed.
        """
        self.scheduler_type = scheduler_type.lower()
        self.is_batch_level = False

        if self.scheduler_type == "plateau":
            # mode="max": the LR is multiplied by 0.5 once the metric has not
            # improved (increased) for `patience` epochs.
            self.scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(
                optimizer, mode="max", factor=0.5, patience=2, min_lr=1e-7
            )
            logging.info("[Scheduler] Используется ReduceLROnPlateau (по метрике).")

        elif self.scheduler_type == "cosine":
            self.scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(
                optimizer, T_max=config.num_epochs, eta_min=1e-6
            )
            logging.info("[Scheduler] Используется CosineAnnealingLR.")

        elif self.scheduler_type == "onecycle":
            if steps_per_epoch <= 0:
                raise ValueError("train_loader пустой, OneCycle не может работать без данных.")
            self.scheduler = torch.optim.lr_scheduler.OneCycleLR(
                optimizer,
                max_lr=config.lr,
                steps_per_epoch=steps_per_epoch,
                epochs=config.num_epochs,
            )
            self.is_batch_level = True
            logging.info(f"[Scheduler] Используется OneCycleLR ({steps_per_epoch} шагов на эпоху).")

        elif self.scheduler_type.startswith(_HF_PREFIX):
            # Without any steps the schedule is degenerate (num_training_steps=0),
            # so fail loudly, consistent with the OneCycle branch.
            if steps_per_epoch <= 0:
                raise ValueError("train_loader пустой, HuggingFace-шедулер не может работать без данных.")
            if get_scheduler is None:
                raise ImportError(
                    "Для scheduler_type='huggingface_*' требуется пакет transformers."
                ) from _TRANSFORMERS_IMPORT_ERROR
            # Strip only the leading prefix; str.replace would also remove any
            # later occurrence of the substring.
            scheduler_name = self.scheduler_type[len(_HF_PREFIX):]
            total_steps = steps_per_epoch * config.num_epochs
            warmup_steps = int(total_steps * config.warmup_ratio)
            self.scheduler = get_scheduler(
                name=scheduler_name,
                optimizer=optimizer,
                num_warmup_steps=warmup_steps,
                num_training_steps=total_steps,
            )
            # The schedule is expressed in optimizer steps (batches), so advance
            # it per batch.
            self.is_batch_level = True
            logging.info(f"[Scheduler] HuggingFace: {scheduler_name} — warmup_steps={warmup_steps}, total_steps={total_steps}")

        elif self.scheduler_type == "none":
            self.scheduler = DummyScheduler()
            logging.info("[Scheduler] Нет шедулера (ручное управление lr).")

        else:
            raise ValueError(f"Неизвестный scheduler_type: {scheduler_type}")

    def step(self, metric=None, batch_level=False):
        """Advance the wrapped scheduler if this call matches its cadence.

        Call once after every batch with ``batch_level=True`` and once after
        every epoch with ``batch_level=False``. Calls that do not match the
        scheduler's cadence are ignored.

        Args:
            metric: validation metric for ``"plateau"``. It is required on
                epoch-level calls for that scheduler and ignored otherwise.
            batch_level: True after a batch (OneCycle, HuggingFace schedulers);
                False after an epoch (plateau, cosine).

        Raises:
            ValueError: for ``"plateau"`` when an epoch-level call omits
                ``metric``.
        """
        if isinstance(self.scheduler, DummyScheduler):
            return

        if self.scheduler_type == "plateau":
            if batch_level:
                return
            if metric is None:
                # ReduceLROnPlateau would otherwise fail with an opaque
                # TypeError on float(None).
                raise ValueError("Для scheduler_type='plateau' нужно передать metric при шаге по эпохе.")
            self.scheduler.step(metric)
            return

        # Per-batch schedulers step only on batch-level calls; per-epoch ones
        # step only on epoch-level calls.
        if bool(batch_level) == self.is_batch_level:
            self.scheduler.step()