# utils/config_loader.py
import logging
import os

import toml
class ConfigLoader: | |
""" | |
Класс для загрузки и обработки конфигурации из `config.toml`. | |
""" | |
def __init__(self, config_path="config.toml"): | |
if not os.path.exists(config_path): | |
raise FileNotFoundError(f"Файл конфигурации `{config_path}` не найден!") | |
self.config = toml.load(config_path) | |
# --------------------------- | |
# Общие параметры | |
# --------------------------- | |
self.split = self.config.get("split", "train") | |
# --------------------------- | |
# Пути к данным | |
# --------------------------- | |
self.datasets = self.config.get("datasets", {}) | |
# --------------------------- | |
# Пути к синтетическим данным | |
# --------------------------- | |
synthetic_data_cfg = self.config.get("synthetic_data", {}) | |
self.use_synthetic_data = synthetic_data_cfg.get("use_synthetic_data", False) | |
self.synthetic_path = synthetic_data_cfg.get("synthetic_path", "E:/MELD_S") | |
self.synthetic_ratio = synthetic_data_cfg.get("synthetic_ratio", 0.0) | |
# --------------------------- | |
# Модальности и эмоции | |
# --------------------------- | |
self.modalities = self.config.get("modalities", ["audio"]) | |
self.emotion_columns = self.config.get("emotion_columns", ["anger", "disgust", "fear", "happy", "neutral", "sad", "surprise"]) | |
# --------------------------- | |
# DataLoader | |
# --------------------------- | |
dataloader_cfg = self.config.get("dataloader", {}) | |
self.num_workers = dataloader_cfg.get("num_workers", 0) | |
self.shuffle = dataloader_cfg.get("shuffle", True) | |
self.prepare_only = dataloader_cfg.get("prepare_only", False) | |
# --------------------------- | |
# Аудио | |
# --------------------------- | |
audio_cfg = self.config.get("audio", {}) | |
self.sample_rate = audio_cfg.get("sample_rate", 16000) | |
self.wav_length = audio_cfg.get("wav_length", 2) | |
self.save_merged_audio = audio_cfg.get("save_merged_audio", True) | |
self.merged_audio_base_path = audio_cfg.get("merged_audio_base_path", "saved_merges") | |
self.merged_audio_suffix = audio_cfg.get("merged_audio_suffix", "_merged") | |
self.force_remerge = audio_cfg.get("force_remerge", False) | |
# --------------------------- | |
# Whisper / Текст | |
# --------------------------- | |
text_cfg = self.config.get("text", {}) | |
self.text_source = text_cfg.get("source", "csv") | |
self.text_column = text_cfg.get("text_column", "text") | |
self.whisper_model = text_cfg.get("whisper_model", "tiny") | |
self.max_text_tokens = text_cfg.get("max_tokens", 15) | |
self.whisper_device = text_cfg.get("whisper_device", "cuda") | |
self.use_whisper_for_nontrain_if_no_text = text_cfg.get("use_whisper_for_nontrain_if_no_text", True) | |
# --------------------------- | |
# Тренировка: общие | |
# --------------------------- | |
train_general = self.config.get("train", {}).get("general", {}) | |
self.random_seed = train_general.get("random_seed", 42) | |
self.subset_size = train_general.get("subset_size", 0) | |
self.merge_probability = train_general.get("merge_probability", 0) | |
self.batch_size = train_general.get("batch_size", 8) | |
self.num_epochs = train_general.get("num_epochs", 100) | |
self.max_patience = train_general.get("max_patience", 10) | |
self.save_best_model = train_general.get("save_best_model", False) | |
self.save_prepared_data = train_general.get("save_prepared_data", True) | |
self.save_feature_path = train_general.get("save_feature_path", "./features/") | |
self.search_type = train_general.get("search_type", "none") | |
self.smoothing_probability = train_general.get("smoothing_probability", 0) | |
self.path_to_df_ls = train_general.get("path_to_df_ls", None) | |
# --------------------------- | |
# Тренировка: параметры модели | |
# --------------------------- | |
train_model = self.config.get("train", {}).get("model", {}) | |
self.model_name = train_model.get("model_name", "BiFormer") | |
self.hidden_dim = train_model.get("hidden_dim", 256) | |
self.hidden_dim_gated = train_model.get("hidden_dim_gated", 256) | |
self.num_transformer_heads = train_model.get("num_transformer_heads", 8) | |
self.num_graph_heads = train_model.get("num_graph_heads", 8) | |
self.tr_layer_number = train_model.get("tr_layer_number", 1) | |
self.mamba_d_state = train_model.get("mamba_d_state", 16) | |
self.mamba_ker_size = train_model.get("mamba_ker_size", 4) | |
self.mamba_layer_number = train_model.get("mamba_layer_number", 3) | |
self.positional_encoding = train_model.get("positional_encoding", True) | |
self.dropout = train_model.get("dropout", 0.0) | |
self.out_features = train_model.get("out_features", 128) | |
self.mode = train_model.get("mode", "mean") | |
# --------------------------- | |
# Тренировка: оптимизатор | |
# --------------------------- | |
train_optimizer = self.config.get("train", {}).get("optimizer", {}) | |
self.optimizer = train_optimizer.get("optimizer", "adam") | |
self.lr = train_optimizer.get("lr", 1e-4) | |
self.weight_decay = train_optimizer.get("weight_decay", 0.0) | |
self.momentum = train_optimizer.get("momentum", 0.9) | |
# --------------------------- | |
# Тренировка: шедулер | |
# --------------------------- | |
train_scheduler = self.config.get("train", {}).get("scheduler", {}) | |
self.scheduler_type = train_scheduler.get("scheduler_type", "plateau") | |
self.warmup_ratio = train_scheduler.get("warmup_ratio", 0.1) | |
# --------------------------- | |
# Эмбеддинги | |
# --------------------------- | |
emb_cfg = self.config.get("embeddings", {}) | |
self.audio_model_name = emb_cfg.get("audio_model", "amiriparian/ExHuBERT") | |
self.text_model_name = emb_cfg.get("text_model", "jinaai/jina-embeddings-v3") | |
self.audio_classifier_checkpoint = emb_cfg.get("audio_classifier_checkpoint", "best_audio_model.pt") | |
self.text_classifier_checkpoint = emb_cfg.get("text_classifier_checkpoint", "best_text_model.pth") | |
self.audio_embedding_dim = emb_cfg.get("audio_embedding_dim", 1024) | |
self.text_embedding_dim = emb_cfg.get("text_embedding_dim", 1024) | |
self.emb_normalize = emb_cfg.get("emb_normalize", True) | |
self.audio_pooling = emb_cfg.get("audio_pooling", None) | |
self.text_pooling = emb_cfg.get("text_pooling", None) | |
self.max_tokens = emb_cfg.get("max_tokens", 256) | |
self.emb_device = emb_cfg.get("device", "cuda") | |
# --------------------------- | |
# Синтетика | |
# --------------------------- | |
# textgen_cfg = self.config.get("textgen", {}) | |
# self.model_name = textgen_cfg.get("model_name", "deepseek-ai/DeepSeek-R1-Distill-Llama-8B") | |
# self.max_new_tokens = textgen_cfg.get("max_new_tokens", 50) | |
# self.temperature = textgen_cfg.get("temperature", 1.0) | |
# self.top_p = textgen_cfg.get("top_p", 0.95) | |
if __name__ == "__main__": | |
self.log_config() | |
def log_config(self): | |
logging.info("=== CONFIGURATION ===") | |
logging.info(f"Split: {self.split}") | |
logging.info(f"Datasets loaded: {list(self.datasets.keys())}") | |
for name, ds in self.datasets.items(): | |
logging.info(f"[Dataset: {name}]") | |
logging.info(f" Base Dir: {ds.get('base_dir', 'N/A')}") | |
logging.info(f" CSV Path: {ds.get('csv_path', '')}") | |
logging.info(f" WAV Dir: {ds.get('wav_dir', '')}") | |
logging.info(f"Emotion columns: {self.emotion_columns}") | |
# Логируем обучающие параметры | |
logging.info("--- Training Config ---") | |
logging.info(f"Sample Rate={self.sample_rate}, Wav Length={self.wav_length}s") | |
logging.info(f"Whisper Model={self.whisper_model}, Device={self.whisper_device}, MaxTokens={self.max_text_tokens}") | |
logging.info(f"use_whisper_for_nontrain_if_no_text={self.use_whisper_for_nontrain_if_no_text}") | |
logging.info(f"DataLoader: batch_size={self.batch_size}, num_workers={self.num_workers}, shuffle={self.shuffle}") | |
logging.info(f"Model Name: {self.model_name}") | |
logging.info(f"Random Seed: {self.random_seed}") | |
logging.info(f"Hidden Dim: {self.hidden_dim}") | |
logging.info(f"Hidden Dim in Gated: {self.hidden_dim_gated}") | |
logging.info(f"Num Heads in Transformer: {self.num_transformer_heads}") | |
logging.info(f"Num Heads in Graph: {self.num_graph_heads}") | |
logging.info(f"Mode stat pooling: {self.mode}") | |
logging.info(f"Optimizer: {self.optimizer}") | |
logging.info(f"Scheduler Type: {self.scheduler_type}") | |
logging.info(f"Warmup Ratio: {self.warmup_ratio}") | |
logging.info(f"Weight Decay for Adam: {self.weight_decay}") | |
logging.info(f"Momentum (SGD): {self.momentum}") | |
logging.info(f"Positional Encoding: {self.positional_encoding}") | |
logging.info(f"Number of Transformer Layers: {self.tr_layer_number}") | |
logging.info(f"Mamba D State: {self.mamba_d_state}") | |
logging.info(f"Mamba Kernel Size: {self.mamba_ker_size}") | |
logging.info(f"Mamba Layer Number: {self.mamba_layer_number}") | |
logging.info(f"Dropout: {self.dropout}") | |
logging.info(f"Out Features: {self.out_features}") | |
logging.info(f"LR: {self.lr}") | |
logging.info(f"Num Epochs: {self.num_epochs}") | |
logging.info(f"Merge Probability={self.merge_probability}") | |
logging.info(f"Smoothing Probability={self.smoothing_probability}") | |
logging.info(f"Max Patience={self.max_patience}") | |
logging.info(f"Save Prepared Data={self.save_prepared_data}") | |
logging.info(f"Path to Save Features={self.save_feature_path}") | |
logging.info(f"Search Type={self.search_type}") | |
# Логируем embeddings | |
logging.info("--- Embeddings Config ---") | |
logging.info(f"Audio Model: {self.audio_model_name}, Text Model: {self.text_model_name}") | |
logging.info(f"Audio dim={self.audio_embedding_dim}, Text dim={self.text_embedding_dim}") | |
logging.info(f"Audio pooling={self.audio_pooling}, Text pooling={self.text_pooling}") | |
logging.info(f"Emb device={self.emb_device}, Normalize={self.emb_normalize}") | |
def show_config(self): | |
self.log_config() | |