Spaces:
Running
Running
File size: 4,757 Bytes
960b1a0 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 |
# train.py
# coding: utf-8
import logging
import os
import shutil
import datetime
import whisper
import toml
# os.environ["HF_HOME"] = "models"
from utils.config_loader import ConfigLoader
from utils.logger_setup import setup_logger
from utils.search_utils import greedy_search, exhaustive_search
from training.train_utils import (
make_dataset_and_loader,
train_once
)
from data_loading.feature_extractor import PretrainedAudioEmbeddingExtractor, PretrainedTextEmbeddingExtractor
# Paths of the config files read by this script (relative to the working dir).
_CONFIG_PATH = "config.toml"
_SEARCH_PARAMS_PATH = "search_params.toml"
# Values of `search_type` that main() knows how to dispatch.
_VALID_SEARCH_TYPES = ("greedy", "exhaustive", "none")


def _invalid_search_type_error(search_type):
    """Build the ValueError raised for an unsupported ``search_type`` value."""
    return ValueError(
        f"⛔️ Неверное значение search_type в конфиге: '{search_type}'. "
        "Используй 'greedy', 'exhaustive' или 'none'."
    )


def _prepare_results_dir(base_config):
    """Create a timestamped results directory plus its per-epoch metrics subdirectory.

    The directory name is derived from ``base_config.model_name`` (with "/" and
    spaces replaced by "_", lower-cased) and the current local time.

    Returns:
        Tuple ``(results_dir, epochlog_dir)`` of the created directory paths.
    """
    model_name = base_config.model_name.replace("/", "_").replace(" ", "_").lower()
    run_stamp = datetime.datetime.now().strftime('%Y-%m-%d_%H-%M-%S')
    results_dir = f"results_{model_name}_{run_stamp}"
    epochlog_dir = os.path.join(results_dir, "metrics_by_epoch")
    # os.makedirs also creates intermediate directories, so this creates results_dir too.
    os.makedirs(epochlog_dir, exist_ok=True)
    return results_dir, epochlog_dir


def _build_loaders(base_config, audio_feature_extractor, text_feature_extractor, whisper_model):
    """Build the shared train loader and one dev/test loader per configured dataset.

    Returns:
        Tuple ``(train_loader, dev_loaders, test_loaders)``. ``dev_loaders`` and
        ``test_loaders`` are lists of ``(dataset_name, loader)`` pairs in the
        order of ``base_config.datasets``.
    """
    extractors = (audio_feature_extractor, text_feature_extractor, whisper_model)

    # One train loader over all datasets.
    _, train_loader = make_dataset_and_loader(base_config, "train", *extractors)

    # Separate dev/test loaders per dataset so metrics can be reported per dataset.
    dev_loaders = []
    test_loaders = []
    for dataset_name in base_config.datasets:
        _, dev_loader = make_dataset_and_loader(
            base_config, "dev", *extractors, only_dataset=dataset_name
        )
        _, test_loader = make_dataset_and_loader(
            base_config, "test", *extractors, only_dataset=dataset_name
        )
        dev_loaders.append((dataset_name, dev_loader))
        test_loaders.append((dataset_name, test_loader))
    return train_loader, dev_loaders, test_loaders


def main():
    """Prepare data and run training or a hyperparameter search, per ``config.toml``.

    Steps:
      1. Load the config and validate ``search_type`` (unless ``prepare_only``).
      2. Create a timestamped results directory, set up logging into it and
         copy the config there.
      3. Build feature extractors, load the Whisper model once, and build the
         train loader plus per-dataset dev/test loaders.
      4. Stop here if ``prepare_only``; otherwise run a single training
         (``search_type == "none"``) or a greedy/exhaustive search whose grid
         is read from ``search_params.toml``.

    Raises:
        ValueError: if ``search_type`` is not one of "greedy", "exhaustive", "none".
    """
    base_config = ConfigLoader(_CONFIG_PATH)

    # Fail fast on a bad search_type: otherwise the error would only surface after
    # Whisper is loaded and every dataset is built. Short-circuit on prepare_only so
    # that mode never needs search_type, as before.
    if not base_config.prepare_only and base_config.search_type not in _VALID_SEARCH_TYPES:
        raise _invalid_search_type_error(base_config.search_type)

    results_dir, epochlog_dir = _prepare_results_dir(base_config)

    # Set up logging into the run's results directory.
    log_file = os.path.join(results_dir, "session_log.txt")
    setup_logger(logging.DEBUG, log_file=log_file)

    # Print the loaded config and keep a copy of it next to the run's results.
    base_config.show_config()
    shutil.copy(_CONFIG_PATH, os.path.join(results_dir, "config_copy.toml"))

    # File the greedy/exhaustive search writes its parameter overrides to.
    overrides_file = os.path.join(results_dir, "overrides.txt")
    csv_prefix = os.path.join(epochlog_dir, "metrics_epochlog")

    audio_feature_extractor = PretrainedAudioEmbeddingExtractor(base_config)
    text_feature_extractor = PretrainedTextEmbeddingExtractor(base_config)

    # Load the Whisper model once and share it across all dataset builds.
    logging.info(f"Инициализация Whisper: модель={base_config.whisper_model}, устройство={base_config.whisper_device}")
    whisper_model = whisper.load_model(base_config.whisper_model, device=base_config.whisper_device)

    train_loader, dev_loaders, test_loaders = _build_loaders(
        base_config, audio_feature_extractor, text_feature_extractor, whisper_model
    )

    if base_config.prepare_only:
        logging.info("== Режим prepare_only: только подготовка данных, без обучения ==")
        return

    search_type = base_config.search_type

    if search_type == "none":
        logging.info("== Режим одиночной тренировки (без поиска параметров) ==")
        single_stamp = datetime.datetime.now().strftime("%Y%m%d_%H%M%S")
        csv_file_path = f"{csv_prefix}_single_{single_stamp}.csv"
        train_once(
            config=base_config,
            train_loader=train_loader,
            dev_loaders=dev_loaders,
            test_loaders=test_loaders,
            metrics_csv_path=csv_file_path,
        )
        return

    # Only the search modes need search_params.toml; single-run mode must not
    # depend on that file (or its sections) existing.
    search_config = toml.load(_SEARCH_PARAMS_PATH)
    param_grid = dict(search_config["grid"])
    # Arguments shared by both search strategies. The search functions take the
    # lists of (dataset_name, loader) pairs under singular parameter names.
    search_kwargs = dict(
        base_config=base_config,
        train_loader=train_loader,
        dev_loader=dev_loaders,
        test_loader=test_loaders,
        train_fn=train_once,
        overrides_file=overrides_file,
        param_grid=param_grid,
        csv_prefix=csv_prefix,
    )

    if search_type == "greedy":
        # Only greedy search uses the per-parameter default values.
        greedy_search(default_values=dict(search_config["defaults"]), **search_kwargs)
    elif search_type == "exhaustive":
        exhaustive_search(**search_kwargs)
    else:
        # Defensive: search_type was validated at the top of main().
        raise _invalid_search_type_error(search_type)
# Script entry point: run the pipeline only when executed directly, not on import.
if __name__ == "__main__":
    main()
|