import torch
import logging
from transformers import get_scheduler

class DummyScheduler:
    """No-op placeholder used when scheduler_type == "none" (manual lr control)."""
    def step(self, *args, **kwargs):
        pass

class SmartScheduler:
    """Wraps several LR schedulers behind a single step() interface and tracks
    whether the chosen scheduler should be stepped per batch or per epoch."""

    def __init__(self, scheduler_type, optimizer, config, steps_per_epoch):
        self.scheduler_type = scheduler_type.lower()
        self.is_batch_level = False

        if self.scheduler_type == "plateau":
            self.scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(
                optimizer,
                mode="max",
                factor=0.5,
                patience=2,
                min_lr=1e-7
            )
            logging.info("[Scheduler] Используется ReduceLROnPlateau (по метрике).")

        elif self.scheduler_type == "cosine":
            self.scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(
                optimizer,
                T_max=config.num_epochs,
                eta_min=1e-6
            )
            logging.info("[Scheduler] Используется CosineAnnealingLR.")

        elif self.scheduler_type == "onecycle":
            if steps_per_epoch == 0:
                raise ValueError("train_loader пустой, OneCycle не может работать без данных.")
            self.scheduler = torch.optim.lr_scheduler.OneCycleLR(
                optimizer,
                max_lr=config.lr,
                steps_per_epoch=steps_per_epoch,
                epochs=config.num_epochs
            )
            self.is_batch_level = True
            logging.info(f"[Scheduler] Используется OneCycleLR ({steps_per_epoch} шагов на эпоху).")

        elif self.scheduler_type.startswith("huggingface_"):
            scheduler_name = self.scheduler_type.replace("huggingface_", "")

            total_steps = steps_per_epoch * config.num_epochs
            warmup_steps = int(total_steps * config.warmup_ratio)

            self.scheduler = get_scheduler(
                name=scheduler_name,
                optimizer=optimizer,
                num_warmup_steps=warmup_steps,
                num_training_steps=total_steps,
            )
            self.is_batch_level = True  # HuggingFace schedulers are typically stepped per batch
            logging.info(f"[Scheduler] HuggingFace: {scheduler_name} (warmup_steps={warmup_steps}, total_steps={total_steps})")

        elif self.scheduler_type == "none":
            self.scheduler = DummyScheduler()
            logging.info("[Scheduler] Нет шедулера (ручное управление lr).")

        else:
            raise ValueError(f"Неизвестный scheduler_type: {scheduler_type}")

    def step(self, metric=None, batch_level=False):
        """
        batch_level=True  ➔ шагать после батча (например, для OneCycle, HuggingFace schedulers)
        batch_level=False ➔ шагать после эпохи
        """
        if isinstance(self.scheduler, DummyScheduler):
            return

        if self.scheduler_type == "plateau":
            if not batch_level:
                self.scheduler.step(metric)
        elif self.is_batch_level:
            if batch_level:
                self.scheduler.step()
        else:
            if not batch_level:
                self.scheduler.step()
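

# --- Usage sketch (illustrative, not part of the module) ---
# A minimal training loop showing how SmartScheduler is meant to be driven.
# `model`, `train_loader`, `validate`, and the `config` fields it relies on
# (lr, num_epochs, warmup_ratio) are hypothetical placeholders, not defined here.
#
# optimizer = torch.optim.AdamW(model.parameters(), lr=config.lr)
# scheduler = SmartScheduler("onecycle", optimizer, config, steps_per_epoch=len(train_loader))
#
# for epoch in range(config.num_epochs):
#     for batch in train_loader:
#         loss = model.training_step(batch)
#         loss.backward()
#         optimizer.step()
#         optimizer.zero_grad()
#         scheduler.step(batch_level=True)   # no-op for epoch-level schedulers
#     val_metric = validate(model)
#     scheduler.step(metric=val_metric)      # no-op for batch-level schedulers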