# train.py
# coding: utf-8
"""Training entry point: build data loaders, then run a hyperparameter search or a single training run.

Reads ``config.toml`` (and, for the ``greedy`` / ``exhaustive`` search modes,
``search_params.toml``) from the current working directory. Each run writes
its artifacts into a fresh ``results_<model>_<timestamp>`` directory: the
session log, a copy of the config, the search overrides file and the
per-epoch metric CSVs.
"""
import logging
import os
import shutil
import datetime

import whisper
import toml

# Optional: uncomment to point the Hugging Face cache (HF_HOME) at ./models.
# os.environ["HF_HOME"] = "models"

from utils.config_loader import ConfigLoader
from utils.logger_setup import setup_logger
from utils.search_utils import greedy_search, exhaustive_search
from training.train_utils import (
    make_dataset_and_loader,
    train_once,
)
from data_loading.feature_extractor import (
    PretrainedAudioEmbeddingExtractor,
    PretrainedTextEmbeddingExtractor,
)

CONFIG_PATH = "config.toml"
SEARCH_PARAMS_PATH = "search_params.toml"
# Values accepted for ``search_type`` in the config.
SEARCH_TYPES = ("greedy", "exhaustive", "none")


def _check_search_type(search_type):
    """Raise ``ValueError`` unless *search_type* is one of ``SEARCH_TYPES``."""
    if search_type not in SEARCH_TYPES:
        raise ValueError(
            f"⛔️ Неверное значение search_type в конфиге: '{search_type}'. "
            "Используй 'greedy', 'exhaustive' или 'none'."
        )


def _load_search_params(path, search_type):
    """Read the search space from the TOML file at *path*.

    Returns ``(param_grid, default_values)``. ``param_grid`` comes from the
    ``[grid]`` table. ``default_values`` comes from the ``[defaults]`` table
    and is only read for greedy search. For exhaustive search it is ``None``,
    since exhaustive search does not use it.

    Raises ``KeyError`` if a required table is missing.
    """
    search_config = toml.load(path)
    param_grid = dict(search_config["grid"])
    default_values = dict(search_config["defaults"]) if search_type == "greedy" else None
    return param_grid, default_values


def main():
    """Prepare all data loaders and dispatch to the configured training mode.

    Raises ``ValueError`` if ``search_type`` in the config is not one of
    ``SEARCH_TYPES`` (checked only when ``prepare_only`` is off).
    """
    base_config = ConfigLoader(CONFIG_PATH)

    # Per-run output directory, named after the model and the start time.
    model_name = base_config.model_name.replace("/", "_").replace(" ", "_").lower()
    timestamp = datetime.datetime.now().strftime('%Y-%m-%d_%H-%M-%S')
    results_dir = f"results_{model_name}_{timestamp}"
    os.makedirs(results_dir, exist_ok=True)

    epochlog_dir = os.path.join(results_dir, "metrics_by_epoch")
    os.makedirs(epochlog_dir, exist_ok=True)

    # Set up logging into the run directory.
    log_file = os.path.join(results_dir, "session_log.txt")
    setup_logger(logging.DEBUG, log_file=log_file)

    # Log the effective config and keep a copy of it next to the results.
    base_config.show_config()
    shutil.copy(CONFIG_PATH, os.path.join(results_dir, "config_copy.toml"))

    # File the hyperparameter search writes its overrides to.
    overrides_file = os.path.join(results_dir, "overrides.txt")
    csv_prefix = os.path.join(epochlog_dir, "metrics_epochlog")

    # Fail fast on configuration errors. The Whisper load and dataset
    # preparation below are expensive, so validate search_type (and read
    # search_params.toml when a search will run) before doing them.
    # prepare_only runs never train, so they need neither.
    param_grid = default_values = None
    if not base_config.prepare_only:
        _check_search_type(base_config.search_type)
        if base_config.search_type != "none":
            param_grid, default_values = _load_search_params(
                SEARCH_PARAMS_PATH, base_config.search_type
            )

    audio_feature_extractor = PretrainedAudioEmbeddingExtractor(base_config)
    text_feature_extractor = PretrainedTextEmbeddingExtractor(base_config)

    # Load the Whisper model once and share it across every split/dataset.
    logging.info(f"Инициализация Whisper: модель={base_config.whisper_model}, устройство={base_config.whisper_device}")
    whisper_model = whisper.load_model(base_config.whisper_model, device=base_config.whisper_device)

    # A single shared train loader (built without an only_dataset filter).
    _, train_loader = make_dataset_and_loader(
        base_config, "train", audio_feature_extractor, text_feature_extractor, whisper_model
    )

    # Separate dev/test loaders per dataset, kept as (dataset_name, loader) pairs.
    dev_loaders = []
    test_loaders = []
    for dataset_name in base_config.datasets:
        _, dev_loader = make_dataset_and_loader(
            base_config, "dev", audio_feature_extractor, text_feature_extractor, whisper_model,
            only_dataset=dataset_name,
        )
        _, test_loader = make_dataset_and_loader(
            base_config, "test", audio_feature_extractor, text_feature_extractor, whisper_model,
            only_dataset=dataset_name,
        )
        dev_loaders.append((dataset_name, dev_loader))
        test_loaders.append((dataset_name, test_loader))

    if base_config.prepare_only:
        logging.info("== Режим prepare_only: только подготовка данных, без обучения ==")
        return

    search_type = base_config.search_type
    if search_type == "greedy":
        greedy_search(
            base_config=base_config,
            train_loader=train_loader,
            dev_loader=dev_loaders,
            test_loader=test_loaders,
            train_fn=train_once,
            overrides_file=overrides_file,
            param_grid=param_grid,
            default_values=default_values,
            csv_prefix=csv_prefix,
        )
    elif search_type == "exhaustive":
        exhaustive_search(
            base_config=base_config,
            train_loader=train_loader,
            dev_loader=dev_loaders,
            test_loader=test_loaders,
            train_fn=train_once,
            overrides_file=overrides_file,
            param_grid=param_grid,
            csv_prefix=csv_prefix,
        )
    else:
        # search_type == "none" (validated above): one training run, no search.
        logging.info("== Режим одиночной тренировки (без поиска параметров) ==")
        # Separate stamp for the CSV name. Deliberately distinct from the
        # run-directory `timestamp` (different format, taken after data prep).
        run_stamp = datetime.datetime.now().strftime("%Y%m%d_%H%M%S")
        csv_file_path = f"{csv_prefix}_single_{run_stamp}.csv"
        train_once(
            config=base_config,
            train_loader=train_loader,
            dev_loaders=dev_loaders,
            test_loaders=test_loaders,
            metrics_csv_path=csv_file_path,
        )


if __name__ == "__main__":
    main()