import torch
from torch.utils.data import Dataset, DataLoader
import numpy as np
import os
import random
from typing import Dict, List, Optional, Tuple

from .audio_processor import AudioProcessor
from ..configs.config import AudioConfig, Config

class SpeakerDataset(Dataset):
    """
    Speaker dataset: loads the audio data of a single speaker.
    """
    def __init__(
        self,
        audio_files: List[str],
        audio_processor: AudioProcessor,
        cache_size: int = 100  # simple in-memory cache for preprocessed items
    ):
        self.audio_files = audio_files
        self.audio_processor = audio_processor
        self.cache = {}
        self.cache_size = cache_size

    def __len__(self) -> int:
        return len(self.audio_files)

    def __getitem__(self, idx: int) -> Dict[str, torch.Tensor]:
        audio_path = self.audio_files[idx]

        # Serve from cache if this file was already preprocessed
        if audio_path in self.cache:
            return self.cache[audio_path]

        try:
            audio, mel_spec = self.audio_processor.preprocess_audio(audio_path)
            item = {
                'audio': torch.FloatTensor(audio),
                'mel_spec': torch.FloatTensor(mel_spec),
                'file_path': audio_path
            }
            # Cache the item if there is room left
            if len(self.cache) < self.cache_size:
                self.cache[audio_path] = item
            return item
        except Exception as e:
            print(f"Error processing file {audio_path}: {str(e)}")
            # Fall back to the next valid sample in the dataset
            return self.__getitem__((idx + 1) % len(self))
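

# Illustrative usage sketch (not part of the original module): wrap a list of
# audio files in a SpeakerDataset and iterate it with a standard DataLoader.
# The file paths below are placeholders; batch_size=1 avoids collating clips of
# different lengths.
def _example_single_speaker_loader() -> None:
    processor = AudioProcessor(config=Config().audio)
    files = [
        "data/speaker_001/utt_001.wav",  # placeholder paths
        "data/speaker_001/utt_002.wav",
    ]
    dataset = SpeakerDataset(files, processor, cache_size=10)
    loader = DataLoader(dataset, batch_size=1, shuffle=True)
    for batch in loader:
        print(batch['mel_spec'].shape)  # [1, n_mels, time]
        break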

class VoiceDatasetManager:
    """
    Dataset manager: organizes the speaker corpus and samples tasks from it.
    """
    def __init__(
        self,
        root_dir: str,
        audio_processor: Optional[AudioProcessor] = None,
        config: Optional[Config] = None
    ):
        self.root_dir = root_dir
        self.config = config or Config()
        self.audio_processor = audio_processor or AudioProcessor(config=self.config.audio)
        self.speakers = self._scan_speakers()

    def _scan_speakers(self) -> Dict[str, List[str]]:
        # Expected layout: root_dir/<speaker_id>/**/<audio file>
        speakers = {}
        for speaker_id in os.listdir(self.root_dir):
            speaker_dir = os.path.join(self.root_dir, speaker_id)
            if os.path.isdir(speaker_dir):
                audio_files = []
                # Recursively search all subdirectories
                for root, _, files in os.walk(speaker_dir):
                    for file in files:
                        if file.endswith(tuple(self.config.data.valid_audio_extensions)):
                            audio_path = os.path.join(root, file)
                            # Keep only files that exist and are non-empty
                            if os.path.exists(audio_path) and os.path.getsize(audio_path) > 0:
                                audio_files.append(audio_path)
                # Keep only speakers with enough samples
                if len(audio_files) >= self.config.data.min_samples_per_speaker:
                    speakers[speaker_id] = audio_files
                else:
                    print(f"Warning: Speaker {speaker_id} has insufficient samples")
        return speakers

    def get_speaker_dataset(self, speaker_id: str) -> SpeakerDataset:
        """Return the dataset for a specific speaker."""
        if speaker_id not in self.speakers:
            raise ValueError(f"Speaker {speaker_id} not found in dataset")
        return SpeakerDataset(
            self.speakers[speaker_id],
            self.audio_processor,
            cache_size=self.config.data.cache_size
        )
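

# Illustrative usage sketch (not part of the original module): scan a corpus
# and load one speaker's dataset. The root directory is a placeholder.
def _example_speaker_dataset() -> None:
    manager = VoiceDatasetManager("data/speakers")  # placeholder root directory
    print(f"Found {len(manager.speakers)} speakers")
    speaker_id = next(iter(manager.speakers))
    dataset = manager.get_speaker_dataset(speaker_id)
    item = dataset[0]
    print(item['mel_spec'].shape)  # [n_mels, time]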

class MetaLearningDataset(Dataset):
    """
    Meta-learning dataset for training few-shot voice cloning.
    Each item is one task, consisting of a support set and a query set.
    """
    def __init__(
        self,
        dataset_manager: VoiceDatasetManager,
        config: Config
    ):
        self.dataset_manager = dataset_manager
        self.config = config

        # Validate that enough speakers have enough samples per task
        available_speakers = [
            speaker_id for speaker_id, files in dataset_manager.speakers.items()
            if len(files) >= (config.meta_learning.k_shot + config.meta_learning.k_query)
        ]
        if len(available_speakers) < config.meta_learning.n_way:
            raise ValueError(
                f"Not enough speakers with sufficient samples. "
                f"Need {config.meta_learning.n_way} speakers with "
                f"{config.meta_learning.k_shot + config.meta_learning.k_query} samples each, "
                f"but only found {len(available_speakers)}"
            )
        self.available_speakers = available_speakers

    def __len__(self) -> int:
        return self.config.meta_learning.n_tasks

    def __getitem__(self, idx: int) -> Tuple[Dict[str, torch.Tensor], Dict[str, torch.Tensor]]:
        """
        Return the data for one task.

        Returns:
            support_data: dictionary with the support set
                - mel_spec: [n_way*k_shot, n_mels, time]
                - speaker_ids: [n_way*k_shot]
            query_data: dictionary with the query set
                - mel_spec: [n_way*k_query, n_mels, time]
                - speaker_ids: [n_way*k_query]
        """
        # Randomly select the speakers for this task
        selected_speakers = random.sample(self.available_speakers, self.config.meta_learning.n_way)

        support_data = {
            'mel_spec': [],
            'speaker_ids': []
        }
        query_data = {
            'mel_spec': [],
            'speaker_ids': []
        }

        for speaker_idx, speaker_id in enumerate(selected_speakers):
            speaker_files = self.dataset_manager.speakers[speaker_id]
            selected_files = random.sample(
                speaker_files,
                self.config.meta_learning.k_shot + self.config.meta_learning.k_query
            )
            for i, file_path in enumerate(selected_files):
                try:
                    _, mel_spec = self.dataset_manager.audio_processor.preprocess_audio(file_path)
                    mel_tensor = torch.FloatTensor(mel_spec)  # [n_mels, time]
                    # The first k_shot files go to the support set, the rest to the query set
                    target_dict = support_data if i < self.config.meta_learning.k_shot else query_data
                    target_dict['mel_spec'].append(mel_tensor)
                    target_dict['speaker_ids'].append(speaker_idx)
                except Exception as e:
                    print(f"Error processing {file_path}: {str(e)}")
                    continue

        # Convert lists to tensors; torch.stack assumes the audio processor
        # produces mel spectrograms with a fixed time dimension
        for data_dict in [support_data, query_data]:
            if len(data_dict['mel_spec']) == 0:
                raise RuntimeError("No valid samples found for task")
            data_dict['mel_spec'] = torch.stack(data_dict['mel_spec'])
            data_dict['speaker_ids'] = torch.LongTensor(data_dict['speaker_ids'])

        return support_data, query_data
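

# Illustrative usage sketch (not part of the original module): sample a single
# task and inspect its support/query shapes. The root directory is a placeholder.
def _example_meta_task() -> None:
    config = Config()
    manager = VoiceDatasetManager("data/speakers", config=config)  # placeholder root directory
    dataset = MetaLearningDataset(manager, config)
    support, query = dataset[0]
    print(support['mel_spec'].shape)  # [n_way * k_shot, n_mels, time]
    print(query['mel_spec'].shape)    # [n_way * k_query, n_mels, time]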


def create_meta_learning_dataloader(
    root_dir: str,
    config: Optional[Config] = None,
    **kwargs
) -> DataLoader:
    """
    Create a dataloader for meta-learning.

    Args:
        root_dir: root directory of the dataset
        config: configuration object
        **kwargs: overrides applied to config.meta_learning

    Returns:
        DataLoader: meta-learning dataloader
    """
    config = config or Config()

    # Apply keyword overrides to the meta-learning config
    for key, value in kwargs.items():
        if hasattr(config.meta_learning, key):
            setattr(config.meta_learning, key, value)

    # Build the dataset manager and the task-level dataset
    dataset_manager = VoiceDatasetManager(root_dir, config=config)
    dataset = MetaLearningDataset(dataset_manager, config)

    return DataLoader(
        dataset,
        batch_size=1,   # fixed at 1: each sample is already a complete task
        shuffle=True,
        num_workers=0,  # avoid multiprocessing issues
        pin_memory=True,
        collate_fn=lambda x: x[0]  # strip the batch dimension
    )
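

# Illustrative usage sketch (not part of the original module): build the
# meta-learning dataloader and iterate one task. The root directory and the
# n_way/k_shot/k_query overrides are placeholders and are only applied if the
# corresponding attributes exist on config.meta_learning.
if __name__ == "__main__":
    loader = create_meta_learning_dataloader(
        "data/speakers",  # placeholder dataset root
        n_way=5,
        k_shot=2,
        k_query=2,
    )
    for support, query in loader:
        print(support['mel_spec'].shape, query['mel_spec'].shape)
        break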