# BiBiER / utils /config_loader.py
# farbverlauf's picture
# gpu
# 960b1a0
# utils/config_loader.py
import os
import toml
import logging
class ConfigLoader:
    """Load settings from a TOML file and expose them as instance attributes.

    The parsed document is kept in ``self.config``. Every known option is
    also copied onto the instance; when a key (or the table that holds it)
    is absent from the file, the default hard-coded in ``__init__`` is used.
    """

    def __init__(self, config_path="config.toml"):
        """Parse ``config_path`` and populate the configuration attributes.

        Args:
            config_path: Path to the TOML configuration file.

        Raises:
            FileNotFoundError: If ``config_path`` is not an existing file
                (this includes the case where it names a directory).
        """
        # isfile() instead of exists(): a directory satisfies exists() and
        # would only fail later, inside toml.load(), with an unrelated error.
        if not os.path.isfile(config_path):
            raise FileNotFoundError(f"Файл конфигурации `{config_path}` не найден!")

        self.config = toml.load(config_path)

        # ---------------------------
        # General parameters
        # ---------------------------
        self.split = self.config.get("split", "train")

        # ---------------------------
        # Dataset paths
        # ---------------------------
        self.datasets = self.config.get("datasets", {})

        # ---------------------------
        # Synthetic data paths
        # ---------------------------
        synthetic_data_cfg = self.config.get("synthetic_data", {})
        self.use_synthetic_data = synthetic_data_cfg.get("use_synthetic_data", False)
        self.synthetic_path = synthetic_data_cfg.get("synthetic_path", "E:/MELD_S")
        self.synthetic_ratio = synthetic_data_cfg.get("synthetic_ratio", 0.0)

        # ---------------------------
        # Modalities and emotions
        # ---------------------------
        self.modalities = self.config.get("modalities", ["audio"])
        self.emotion_columns = self.config.get(
            "emotion_columns",
            ["anger", "disgust", "fear", "happy", "neutral", "sad", "surprise"],
        )

        # ---------------------------
        # DataLoader
        # ---------------------------
        dataloader_cfg = self.config.get("dataloader", {})
        self.num_workers = dataloader_cfg.get("num_workers", 0)
        self.shuffle = dataloader_cfg.get("shuffle", True)
        self.prepare_only = dataloader_cfg.get("prepare_only", False)

        # ---------------------------
        # Audio
        # ---------------------------
        audio_cfg = self.config.get("audio", {})
        self.sample_rate = audio_cfg.get("sample_rate", 16000)
        self.wav_length = audio_cfg.get("wav_length", 2)
        self.save_merged_audio = audio_cfg.get("save_merged_audio", True)
        self.merged_audio_base_path = audio_cfg.get("merged_audio_base_path", "saved_merges")
        self.merged_audio_suffix = audio_cfg.get("merged_audio_suffix", "_merged")
        self.force_remerge = audio_cfg.get("force_remerge", False)

        # ---------------------------
        # Whisper / text
        # ---------------------------
        text_cfg = self.config.get("text", {})
        # Attribute names differ from the TOML keys for the next entries:
        # "source" -> text_source, "max_tokens" -> max_text_tokens.
        self.text_source = text_cfg.get("source", "csv")
        self.text_column = text_cfg.get("text_column", "text")
        self.whisper_model = text_cfg.get("whisper_model", "tiny")
        self.max_text_tokens = text_cfg.get("max_tokens", 15)
        self.whisper_device = text_cfg.get("whisper_device", "cuda")
        self.use_whisper_for_nontrain_if_no_text = text_cfg.get(
            "use_whisper_for_nontrain_if_no_text", True
        )

        # The [train] table is read once; its sub-tables are handled below.
        train_cfg = self.config.get("train", {})

        # ---------------------------
        # Training: general
        # ---------------------------
        train_general = train_cfg.get("general", {})
        self.random_seed = train_general.get("random_seed", 42)
        self.subset_size = train_general.get("subset_size", 0)
        self.merge_probability = train_general.get("merge_probability", 0)
        self.batch_size = train_general.get("batch_size", 8)
        self.num_epochs = train_general.get("num_epochs", 100)
        self.max_patience = train_general.get("max_patience", 10)
        self.save_best_model = train_general.get("save_best_model", False)
        self.save_prepared_data = train_general.get("save_prepared_data", True)
        self.save_feature_path = train_general.get("save_feature_path", "./features/")
        self.search_type = train_general.get("search_type", "none")
        self.smoothing_probability = train_general.get("smoothing_probability", 0)
        self.path_to_df_ls = train_general.get("path_to_df_ls", None)

        # ---------------------------
        # Training: model parameters
        # ---------------------------
        train_model = train_cfg.get("model", {})
        self.model_name = train_model.get("model_name", "BiFormer")
        self.hidden_dim = train_model.get("hidden_dim", 256)
        self.hidden_dim_gated = train_model.get("hidden_dim_gated", 256)
        self.num_transformer_heads = train_model.get("num_transformer_heads", 8)
        self.num_graph_heads = train_model.get("num_graph_heads", 8)
        self.tr_layer_number = train_model.get("tr_layer_number", 1)
        self.mamba_d_state = train_model.get("mamba_d_state", 16)
        self.mamba_ker_size = train_model.get("mamba_ker_size", 4)
        self.mamba_layer_number = train_model.get("mamba_layer_number", 3)
        self.positional_encoding = train_model.get("positional_encoding", True)
        self.dropout = train_model.get("dropout", 0.0)
        self.out_features = train_model.get("out_features", 128)
        self.mode = train_model.get("mode", "mean")

        # ---------------------------
        # Training: optimizer
        # ---------------------------
        train_optimizer = train_cfg.get("optimizer", {})
        self.optimizer = train_optimizer.get("optimizer", "adam")
        self.lr = train_optimizer.get("lr", 1e-4)
        self.weight_decay = train_optimizer.get("weight_decay", 0.0)
        self.momentum = train_optimizer.get("momentum", 0.9)

        # ---------------------------
        # Training: scheduler
        # ---------------------------
        train_scheduler = train_cfg.get("scheduler", {})
        self.scheduler_type = train_scheduler.get("scheduler_type", "plateau")
        self.warmup_ratio = train_scheduler.get("warmup_ratio", 0.1)

        # ---------------------------
        # Embeddings
        # ---------------------------
        emb_cfg = self.config.get("embeddings", {})
        # Attribute names differ from the TOML keys for some entries:
        # "audio_model" -> audio_model_name, "text_model" -> text_model_name,
        # "device" -> emb_device.
        self.audio_model_name = emb_cfg.get("audio_model", "amiriparian/ExHuBERT")
        self.text_model_name = emb_cfg.get("text_model", "jinaai/jina-embeddings-v3")
        self.audio_classifier_checkpoint = emb_cfg.get(
            "audio_classifier_checkpoint", "best_audio_model.pt"
        )
        self.text_classifier_checkpoint = emb_cfg.get(
            "text_classifier_checkpoint", "best_text_model.pth"
        )
        self.audio_embedding_dim = emb_cfg.get("audio_embedding_dim", 1024)
        self.text_embedding_dim = emb_cfg.get("text_embedding_dim", 1024)
        self.emb_normalize = emb_cfg.get("emb_normalize", True)
        self.audio_pooling = emb_cfg.get("audio_pooling", None)
        self.text_pooling = emb_cfg.get("text_pooling", None)
        self.max_tokens = emb_cfg.get("max_tokens", 256)
        self.emb_device = emb_cfg.get("device", "cuda")

        # ---------------------------
        # Synthetic text generation (disabled)
        # ---------------------------
        # NOTE: if re-enabled as written, the second line below would
        # overwrite self.model_name set from [train.model] above.
        # textgen_cfg = self.config.get("textgen", {})
        # self.model_name = textgen_cfg.get("model_name", "deepseek-ai/DeepSeek-R1-Distill-Llama-8B")
        # self.max_new_tokens = textgen_cfg.get("max_new_tokens", 50)
        # self.temperature = textgen_cfg.get("temperature", 1.0)
        # self.top_p = textgen_cfg.get("top_p", 0.95)

        # Logs the configuration only when this module itself is the running
        # script; on a normal import __name__ is the module name, so nothing
        # is logged here and callers use log_config()/show_config() instead.
        if __name__ == "__main__":
            self.log_config()

    def log_config(self):
        """Write the loaded configuration to the root logger at INFO level."""
        logging.info("=== CONFIGURATION ===")
        logging.info(f"Split: {self.split}")
        logging.info(f"Datasets loaded: {list(self.datasets.keys())}")
        # One header line plus three path lines per configured dataset.
        for name, ds in self.datasets.items():
            logging.info(f"[Dataset: {name}]")
            logging.info(f" Base Dir: {ds.get('base_dir', 'N/A')}")
            logging.info(f" CSV Path: {ds.get('csv_path', '')}")
            logging.info(f" WAV Dir: {ds.get('wav_dir', '')}")
        logging.info(f"Emotion columns: {self.emotion_columns}")

        # Log training parameters
        logging.info("--- Training Config ---")
        logging.info(f"Sample Rate={self.sample_rate}, Wav Length={self.wav_length}s")
        logging.info(f"Whisper Model={self.whisper_model}, Device={self.whisper_device}, MaxTokens={self.max_text_tokens}")
        logging.info(f"use_whisper_for_nontrain_if_no_text={self.use_whisper_for_nontrain_if_no_text}")
        logging.info(f"DataLoader: batch_size={self.batch_size}, num_workers={self.num_workers}, shuffle={self.shuffle}")
        logging.info(f"Model Name: {self.model_name}")
        logging.info(f"Random Seed: {self.random_seed}")
        logging.info(f"Hidden Dim: {self.hidden_dim}")
        logging.info(f"Hidden Dim in Gated: {self.hidden_dim_gated}")
        logging.info(f"Num Heads in Transformer: {self.num_transformer_heads}")
        logging.info(f"Num Heads in Graph: {self.num_graph_heads}")
        logging.info(f"Mode stat pooling: {self.mode}")
        logging.info(f"Optimizer: {self.optimizer}")
        logging.info(f"Scheduler Type: {self.scheduler_type}")
        logging.info(f"Warmup Ratio: {self.warmup_ratio}")
        logging.info(f"Weight Decay for Adam: {self.weight_decay}")
        logging.info(f"Momentum (SGD): {self.momentum}")
        logging.info(f"Positional Encoding: {self.positional_encoding}")
        logging.info(f"Number of Transformer Layers: {self.tr_layer_number}")
        logging.info(f"Mamba D State: {self.mamba_d_state}")
        logging.info(f"Mamba Kernel Size: {self.mamba_ker_size}")
        logging.info(f"Mamba Layer Number: {self.mamba_layer_number}")
        logging.info(f"Dropout: {self.dropout}")
        logging.info(f"Out Features: {self.out_features}")
        logging.info(f"LR: {self.lr}")
        logging.info(f"Num Epochs: {self.num_epochs}")
        logging.info(f"Merge Probability={self.merge_probability}")
        logging.info(f"Smoothing Probability={self.smoothing_probability}")
        logging.info(f"Max Patience={self.max_patience}")
        logging.info(f"Save Prepared Data={self.save_prepared_data}")
        logging.info(f"Path to Save Features={self.save_feature_path}")
        logging.info(f"Search Type={self.search_type}")

        # Log embedding parameters
        logging.info("--- Embeddings Config ---")
        logging.info(f"Audio Model: {self.audio_model_name}, Text Model: {self.text_model_name}")
        logging.info(f"Audio dim={self.audio_embedding_dim}, Text dim={self.text_embedding_dim}")
        logging.info(f"Audio pooling={self.audio_pooling}, Text pooling={self.text_pooling}")
        logging.info(f"Emb device={self.emb_device}, Normalize={self.emb_normalize}")

    def show_config(self):
        """Alias for :meth:`log_config`."""
        self.log_config()