Spaces:
Running
Running
File size: 3,305 Bytes
960b1a0 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 |
import torch
import logging
from transformers import get_scheduler
class DummyScheduler:
    """No-op stand-in used when learning-rate scheduling is disabled.

    It exposes a ``step`` method that accepts and discards any arguments, so
    it can be driven exactly like a real scheduler without touching the
    optimizer.
    """

    def step(self, *_ignored_args, **_ignored_kwargs):
        """Do nothing; the learning rate is left unchanged."""
        return None
class SmartScheduler:
    """Uniform wrapper around several learning-rate scheduler flavours.

    Supported ``scheduler_type`` values (matched case-insensitively):

    * ``"plateau"``  -- ``ReduceLROnPlateau(mode="max", factor=0.5, patience=2,
      min_lr=1e-7)``, stepped once per epoch with a metric.
    * ``"cosine"``   -- ``CosineAnnealingLR(T_max=config.num_epochs,
      eta_min=1e-6)``, stepped once per epoch.
    * ``"onecycle"`` -- ``OneCycleLR(max_lr=config.lr)`` over
      ``config.num_epochs * steps_per_epoch`` steps, stepped once per batch.
    * ``"huggingface_<name>"`` -- ``get_scheduler(<name>, ...)`` with warmup
      derived from ``config.warmup_ratio``, stepped once per batch.
    * ``"none"`` -- a ``DummyScheduler`` that never changes the learning rate.

    Attributes:
        scheduler_type: the lower-cased type string.
        scheduler: the underlying scheduler object.
        is_batch_level: True when the scheduler must be stepped after every
            batch rather than once per epoch.
    """

    def __init__(self, scheduler_type, optimizer, config, steps_per_epoch):
        """Build the scheduler selected by ``scheduler_type``.

        Args:
            scheduler_type: one of the types listed in the class docstring.
            optimizer: the torch optimizer whose learning rate is scheduled.
            config: object exposing ``num_epochs``; additionally ``lr`` for
                ``"onecycle"`` and, optionally, ``warmup_ratio`` for the
                Hugging Face schedulers (treated as 0.0 when absent).
            steps_per_epoch: number of scheduler steps (batches) per epoch.

        Raises:
            ValueError: for an unknown ``scheduler_type``, or for
                ``"onecycle"`` when ``steps_per_epoch`` is missing or not
                positive.
        """
        self.scheduler_type = scheduler_type.lower()
        self.is_batch_level = False
        hf_prefix = "huggingface_"

        if self.scheduler_type == "plateau":
            # mode="max": the LR is reduced when the monitored metric stops increasing.
            self.scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(
                optimizer,
                mode="max",
                factor=0.5,
                patience=2,
                min_lr=1e-7,
            )
            logging.info("[Scheduler] Используется ReduceLROnPlateau (по метрике).")
        elif self.scheduler_type == "cosine":
            # T_max is counted in epochs, so this scheduler is stepped once per epoch.
            self.scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(
                optimizer,
                T_max=config.num_epochs,
                eta_min=1e-6,
            )
            logging.info("[Scheduler] Используется CosineAnnealingLR.")
        elif self.scheduler_type == "onecycle":
            # Fail early with a clear message: OneCycleLR needs a positive step
            # count (an empty train loader yields 0; None/negatives are invalid too).
            if steps_per_epoch is None or steps_per_epoch <= 0:
                raise ValueError("train_loader пустой, OneCycle не может работать без данных.")
            self.scheduler = torch.optim.lr_scheduler.OneCycleLR(
                optimizer,
                max_lr=config.lr,
                steps_per_epoch=steps_per_epoch,
                epochs=config.num_epochs,
            )
            self.is_batch_level = True
            logging.info(f"[Scheduler] Используется OneCycleLR ({steps_per_epoch} шагов на эпоху).")
        elif self.scheduler_type.startswith(hf_prefix):
            # Strip only the leading prefix; str.replace would remove every occurrence.
            scheduler_name = self.scheduler_type[len(hf_prefix):]
            total_steps = steps_per_epoch * config.num_epochs
            # warmup_ratio is optional on the config: no warmup when it is absent.
            warmup_ratio = getattr(config, "warmup_ratio", 0.0)
            warmup_steps = int(total_steps * warmup_ratio)
            self.scheduler = get_scheduler(
                name=scheduler_name,
                optimizer=optimizer,
                num_warmup_steps=warmup_steps,
                num_training_steps=total_steps,
            )
            # total_steps above is counted in batches, so step after every batch.
            self.is_batch_level = True
            logging.info(f"[Scheduler] HuggingFace: {scheduler_name} — warmup_steps={warmup_steps}, total_steps={total_steps}")
        elif self.scheduler_type == "none":
            self.scheduler = DummyScheduler()
            logging.info("[Scheduler] Нет шедулера (ручное управление lr).")
        else:
            raise ValueError(f"Неизвестный scheduler_type: {scheduler_type}")

    def step(self, metric=None, batch_level=False):
        """Advance the scheduler if this call matches its stepping granularity.

        Call with ``batch_level=True`` after every batch (OneCycle, Hugging Face
        schedulers) and with ``batch_level=False`` after every epoch (plateau,
        cosine). Calls at the wrong granularity are ignored, so both call sites
        can stay in the training loop regardless of the scheduler type.

        Args:
            metric: the monitored value; required by ``"plateau"`` on
                epoch-level calls and ignored by every other type.
            batch_level: True when called after a batch, False after an epoch.

        Raises:
            TypeError: for ``"plateau"`` when an epoch-level call gets no metric.
        """
        if isinstance(self.scheduler, DummyScheduler):
            return
        if self.scheduler_type == "plateau":
            if batch_level:
                return
            if metric is None:
                # Otherwise ReduceLROnPlateau fails deep inside float(None).
                raise TypeError(
                    "SmartScheduler.step(): 'plateau' scheduler requires a metric "
                    "on epoch-level calls"
                )
            self.scheduler.step(metric)
        elif bool(batch_level) == self.is_batch_level:
            # Batch-level schedulers step only on batch calls; epoch-level ones
            # only on epoch calls.
            self.scheduler.step()
|