File size: 4,757 Bytes
960b1a0
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
# train.py
# coding: utf-8
import logging
import os
import shutil
import datetime
import whisper
import toml
# os.environ["HF_HOME"] = "models"

from utils.config_loader import ConfigLoader
from utils.logger_setup import setup_logger
from utils.search_utils import greedy_search, exhaustive_search
from training.train_utils import (
    make_dataset_and_loader,
    train_once
)
from data_loading.feature_extractor import PretrainedAudioEmbeddingExtractor, PretrainedTextEmbeddingExtractor

# Accepted values of ``search_type`` in the config.
_SEARCH_TYPES = ("greedy", "exhaustive", "none")


def _load_search_params(path, search_type):
    """Read the hyper-parameter search grid (and greedy-search defaults) from *path*.

    Returns ``(param_grid, default_values)``. ``default_values`` is only read
    for greedy search (exhaustive search never uses it), so a file without a
    ``[defaults]`` table works for exhaustive search; for other modes it is ``None``.
    """
    search_config = toml.load(path)
    param_grid = dict(search_config["grid"])
    default_values = dict(search_config["defaults"]) if search_type == "greedy" else None
    return param_grid, default_values


def _build_loaders(config, audio_feature_extractor, text_feature_extractor, whisper_model):
    """Build the shared train loader plus per-dataset ``(name, loader)`` lists for dev and test."""
    # One train loader over all datasets.
    _, train_loader = make_dataset_and_loader(
        config, "train", audio_feature_extractor, text_feature_extractor, whisper_model
    )

    # Separate dev/test loaders per dataset so metrics can be reported per dataset.
    dev_loaders = []
    test_loaders = []
    for dataset_name in config.datasets:
        _, dev_loader = make_dataset_and_loader(
            config, "dev", audio_feature_extractor, text_feature_extractor, whisper_model,
            only_dataset=dataset_name,
        )
        _, test_loader = make_dataset_and_loader(
            config, "test", audio_feature_extractor, text_feature_extractor, whisper_model,
            only_dataset=dataset_name,
        )
        dev_loaders.append((dataset_name, dev_loader))
        test_loaders.append((dataset_name, test_loader))

    return train_loader, dev_loaders, test_loaders


def main(config_path="config.toml", search_params_path="search_params.toml"):
    """Run one session: prepare data, then train once or run a hyper-parameter search.

    Creates a ``results_<model_name>_<timestamp>`` directory containing the
    session log, a copy of the config, the search overrides file and a
    ``metrics_by_epoch`` sub-directory for per-epoch metric CSVs.

    Args:
        config_path: TOML config passed to ``ConfigLoader`` and copied into the
            results directory as ``config_copy.toml``.
        search_params_path: TOML file with a ``[grid]`` table (and, for greedy
            search, a ``[defaults]`` table). Only read when ``search_type`` is
            ``"greedy"`` or ``"exhaustive"``.

    Raises:
        ValueError: if ``search_type`` is not 'greedy', 'exhaustive' or 'none'.
            Checked before any data preparation, and not checked at all in
            ``prepare_only`` mode.
    """
    base_config = ConfigLoader(config_path)

    # Results directory name: sanitized model name + start time.
    model_name = base_config.model_name.replace("/", "_").replace(" ", "_").lower()
    timestamp = datetime.datetime.now().strftime('%Y-%m-%d_%H-%M-%S')
    results_dir = f"results_{model_name}_{timestamp}"
    epochlog_dir = os.path.join(results_dir, "metrics_by_epoch")
    # makedirs creates results_dir as the intermediate directory.
    os.makedirs(epochlog_dir, exist_ok=True)

    log_file = os.path.join(results_dir, "session_log.txt")
    setup_logger(logging.DEBUG, log_file=log_file)

    base_config.show_config()
    shutil.copy(config_path, os.path.join(results_dir, "config_copy.toml"))

    # File the search functions write their overrides to.
    overrides_file = os.path.join(results_dir, "overrides.txt")
    csv_prefix = os.path.join(epochlog_dir, "metrics_epochlog")

    prepare_only = base_config.prepare_only
    search_type = None
    param_grid = default_values = None
    if not prepare_only:
        # Fail fast: reject a bad mode / unreadable search grid before the
        # expensive Whisper and dataset preparation below.
        search_type = base_config.search_type
        if search_type not in _SEARCH_TYPES:
            raise ValueError(f"⛔️ Неверное значение search_type в конфиге: '{base_config.search_type}'. Используй 'greedy', 'exhaustive' или 'none'.")
        if search_type != "none":
            param_grid, default_values = _load_search_params(search_params_path, search_type)

    audio_feature_extractor = PretrainedAudioEmbeddingExtractor(base_config)
    text_feature_extractor = PretrainedTextEmbeddingExtractor(base_config)

    # Load the Whisper model once and share it across all loaders.
    logging.info(f"Инициализация Whisper: модель={base_config.whisper_model}, устройство={base_config.whisper_device}")
    whisper_model = whisper.load_model(base_config.whisper_model, device=base_config.whisper_device)

    train_loader, dev_loaders, test_loaders = _build_loaders(
        base_config, audio_feature_extractor, text_feature_extractor, whisper_model
    )

    if prepare_only:
        logging.info("== Режим prepare_only: только подготовка данных, без обучения ==")
        return

    if search_type == "greedy":
        greedy_search(
            base_config       = base_config,
            train_loader      = train_loader,
            dev_loader        = dev_loaders,
            test_loader       = test_loaders,
            train_fn          = train_once,
            overrides_file    = overrides_file,
            param_grid        = param_grid,
            default_values    = default_values,
            csv_prefix        = csv_prefix
        )
    elif search_type == "exhaustive":
        exhaustive_search(
            base_config       = base_config,
            train_loader      = train_loader,
            dev_loader        = dev_loaders,
            test_loader       = test_loaders,
            train_fn          = train_once,
            overrides_file    = overrides_file,
            param_grid        = param_grid,
            csv_prefix        = csv_prefix
        )
    else:
        # search_type == "none" (validated above): a single training run.
        logging.info("== Режим одиночной тренировки (без поиска параметров) ==")

        run_stamp = datetime.datetime.now().strftime("%Y%m%d_%H%M%S")
        csv_file_path = f"{csv_prefix}_single_{run_stamp}.csv"

        train_once(
            config           = base_config,
            train_loader     = train_loader,
            dev_loaders      = dev_loaders,
            test_loaders     = test_loaders,
            metrics_csv_path = csv_file_path
        )


# Script entry point: start a session only when this file is executed
# directly, not when it is imported as a module.
if __name__ == "__main__":
    main()