# voice-clone-app / src/data/dataset.py
# Author: hengjie yang
# Commit 9580089: "Initial commit: Voice Clone App with Gradio interface"
import copy
import os
import random
import warnings
from typing import Dict, List, Optional, Tuple

import numpy as np
import torch
from torch.utils.data import Dataset, DataLoader

from .audio_processor import AudioProcessor
from ..configs.config import AudioConfig, Config
class SpeakerDataset(Dataset):
    """
    Speaker dataset: loads the audio files of a single speaker.

    Each item is a dict with ``'audio'`` and ``'mel_spec'`` FloatTensors built
    from ``audio_processor.preprocess_audio(path)`` plus the source
    ``'file_path'``. Successfully loaded items are memoized in a fill-once
    cache of at most ``cache_size`` entries (entries are never evicted).
    """

    def __init__(
        self,
        audio_files: List[str],
        audio_processor: "AudioProcessor",
        cache_size: int = 100,  # max number of items kept in the in-memory cache
    ):
        self.audio_files = audio_files
        self.audio_processor = audio_processor
        # audio_path -> item dict; filled until cache_size entries, never evicted.
        self.cache = {}
        self.cache_size = cache_size

    def __len__(self) -> int:
        return len(self.audio_files)

    def __getitem__(self, idx: int) -> Dict[str, torch.Tensor]:
        """
        Return the item at ``idx``, falling back to the following files on failure.

        If a file cannot be processed, the error is printed and the next file
        (wrapping around to the start) is tried. Every file is tried at most
        once, so a dataset made entirely of unreadable files fails fast instead
        of recursing forever.

        Note: cached items are returned by reference, so mutating a returned
        dict also mutates the cached entry.

        Raises:
            IndexError: if ``idx`` is out of range (same semantics as a list).
            RuntimeError: if none of the dataset's files can be processed.
        """
        num_files = len(self.audio_files)
        if not -num_files <= idx < num_files:
            raise IndexError(
                f"index {idx} out of range for dataset of size {num_files}"
            )
        start = idx % num_files  # normalizes negative indices like list indexing
        last_error = None
        for offset in range(num_files):
            audio_path = self.audio_files[(start + offset) % num_files]
            if audio_path in self.cache:
                return self.cache[audio_path]
            try:
                audio, mel_spec = self.audio_processor.preprocess_audio(audio_path)
                item = {
                    'audio': torch.FloatTensor(audio),
                    'mel_spec': torch.FloatTensor(mel_spec),
                    'file_path': audio_path
                }
            except Exception as e:
                # Best-effort: report the bad file and move on to the next one.
                print(f"Error processing file {audio_path}: {str(e)}")
                last_error = e
                continue
            if len(self.cache) < self.cache_size:
                self.cache[audio_path] = item
            return item
        raise RuntimeError(
            f"None of the {num_files} audio files in the dataset could be processed"
        ) from last_error
class VoiceDatasetManager:
    """
    Dataset manager: organizes the per-speaker audio files and builds datasets.

    ``root_dir`` is expected to contain one sub-directory per speaker. Every
    file found recursively under a speaker directory whose name ends with one
    of ``config.data.valid_audio_extensions`` counts as one sample.
    """

    def __init__(
        self,
        root_dir: str,
        audio_processor: "Optional[AudioProcessor]" = None,
        config: "Optional[Config]" = None
    ):
        """
        Args:
            root_dir: directory holding one sub-directory per speaker.
            audio_processor: processor handed to the datasets built here; when
                omitted, ``AudioProcessor(config=self.config.audio)`` is created.
            config: project configuration; ``Config()`` is used when omitted.

        Raises:
            OSError: if ``root_dir`` cannot be listed (propagated from ``os.listdir``).
        """
        self.root_dir = root_dir
        self.config = config if config is not None else Config()
        self.audio_processor = (
            audio_processor
            if audio_processor is not None
            else AudioProcessor(config=self.config.audio)
        )
        self.speakers = self._scan_speakers()

    def _scan_speakers(self) -> Dict[str, List[str]]:
        """
        Map each speaker directory name under ``root_dir`` to its audio file paths.

        Speakers and files are visited in sorted order so the result, and any
        seeded random sampling done over it, does not depend on the
        filesystem's arbitrary listing order. Extension matching is
        case-insensitive. Empty or inaccessible files are skipped. Speakers with
        fewer than ``config.data.min_samples_per_speaker`` usable files are
        dropped with a printed warning.
        """
        extensions = self.config.data.valid_audio_extensions
        if isinstance(extensions, str):
            extensions = (extensions,)
        # str.endswith needs a tuple (a list raises TypeError); lower-case the
        # extensions so e.g. "clip.WAV" matches ".wav".
        extensions = tuple(ext.lower() for ext in extensions)

        speakers = {}
        for speaker_id in sorted(os.listdir(self.root_dir)):
            speaker_dir = os.path.join(self.root_dir, speaker_id)
            if not os.path.isdir(speaker_dir):
                continue
            audio_files = []
            # Recursively search every sub-directory of the speaker folder.
            for root, dirs, files in os.walk(speaker_dir):
                dirs.sort()  # in-place sort makes os.walk descend deterministically
                for file in sorted(files):
                    if not file.lower().endswith(extensions):
                        continue
                    audio_path = os.path.join(root, file)
                    # Verify the file is reachable (not a broken link) and non-empty.
                    if os.path.exists(audio_path) and os.path.getsize(audio_path) > 0:
                        audio_files.append(audio_path)
            # Keep only speakers that have enough samples.
            if len(audio_files) >= self.config.data.min_samples_per_speaker:
                speakers[speaker_id] = audio_files
            else:
                print(f"Warning: Speaker {speaker_id} has insufficient samples")
        return speakers

    def get_speaker_dataset(self, speaker_id: str) -> "SpeakerDataset":
        """
        Build a ``SpeakerDataset`` over the files of one speaker.

        Raises:
            ValueError: if ``speaker_id`` was not kept by ``_scan_speakers``.
        """
        if speaker_id not in self.speakers:
            raise ValueError(f"Speaker {speaker_id} not found in dataset")
        return SpeakerDataset(
            self.speakers[speaker_id],
            self.audio_processor,
            cache_size=self.config.data.cache_size
        )
class MetaLearningDataset(Dataset):
    """
    Meta-learning dataset for few-shot voice-cloning training.

    Each item is one randomly sampled N-way task consisting of a support set
    and a query set. The ``idx`` argument of ``__getitem__`` is ignored: tasks
    are drawn with the ``random`` module, so seed it for reproducibility.
    """

    def __init__(
        self,
        dataset_manager: "VoiceDatasetManager",
        config: "Config"
    ):
        """
        Args:
            dataset_manager: provides ``speakers`` (speaker id -> list of file
                paths) and ``audio_processor``.
            config: reads ``config.meta_learning.n_way``, ``k_shot``,
                ``k_query`` and ``n_tasks``.

        Raises:
            ValueError: if fewer than ``n_way`` speakers have at least
                ``k_shot + k_query`` files.
        """
        self.dataset_manager = dataset_manager
        self.config = config
        samples_needed = config.meta_learning.k_shot + config.meta_learning.k_query
        # Only speakers able to fill both the support and query sets are eligible.
        available_speakers = [
            speaker_id for speaker_id, files in dataset_manager.speakers.items()
            if len(files) >= samples_needed
        ]
        if len(available_speakers) < config.meta_learning.n_way:
            raise ValueError(
                f"Not enough speakers with sufficient samples. "
                f"Need {config.meta_learning.n_way} speakers with "
                f"{samples_needed} samples each, "
                f"but only found {len(available_speakers)}"
            )
        self.available_speakers = available_speakers

    def __len__(self) -> int:
        """Number of tasks per epoch (``config.meta_learning.n_tasks``)."""
        return self.config.meta_learning.n_tasks

    def _load_mel(self, file_path: str) -> Optional[torch.Tensor]:
        """Return the mel spectrogram of ``file_path`` as a FloatTensor, or None if loading fails."""
        try:
            _, mel_spec = self.dataset_manager.audio_processor.preprocess_audio(file_path)
            return torch.FloatTensor(mel_spec)
        except Exception as e:
            # Best-effort: report the bad file; the caller picks a replacement.
            print(f"Error processing {file_path}: {str(e)}")
            return None

    def _sample_speaker_mels(self, speaker_id: str, count: int) -> List[torch.Tensor]:
        """
        Load up to ``count`` mel spectrograms from random files of one speaker.

        Files are tried in a random order. A file that fails to load is
        replaced by the next untried file, so one corrupt file does not shrink
        the task as long as the speaker has enough other valid files. Fewer than
        ``count`` tensors are returned only when the speaker runs out of files.
        """
        candidates = list(self.dataset_manager.speakers[speaker_id])
        random.shuffle(candidates)
        mels = []
        for file_path in candidates:
            if len(mels) == count:
                break
            mel = self._load_mel(file_path)
            if mel is not None:
                mels.append(mel)
        return mels

    def __getitem__(self, idx: int) -> Tuple[Dict[str, torch.Tensor], Dict[str, torch.Tensor]]:
        """
        Sample and return one task.

        Returns:
            support_data: dict with
                - mel_spec: [n_way*k_shot, n_mels, time]
                - speaker_ids: [n_way*k_shot], task-local labels 0..n_way-1
            query_data: dict with
                - mel_spec: [n_way*k_query, n_mels, time]
                - speaker_ids: [n_way*k_query], task-local labels 0..n_way-1
            (Counts can be smaller only if a speaker runs out of loadable files.)

        Raises:
            RuntimeError: if the support or the query set ends up empty.
        """
        k_shot = self.config.meta_learning.k_shot
        k_query = self.config.meta_learning.k_query
        # Randomly pick the speakers of this task.
        selected_speakers = random.sample(self.available_speakers, self.config.meta_learning.n_way)
        support_data = {'mel_spec': [], 'speaker_ids': []}
        query_data = {'mel_spec': [], 'speaker_ids': []}

        for speaker_idx, speaker_id in enumerate(selected_speakers):
            mels = self._sample_speaker_mels(speaker_id, k_shot + k_query)
            # The first k_shot valid samples form the support set, the rest the query set.
            for i, mel_tensor in enumerate(mels):
                target_dict = support_data if i < k_shot else query_data
                target_dict['mel_spec'].append(mel_tensor)
                target_dict['speaker_ids'].append(speaker_idx)

        # Convert the lists to tensors.
        for data_dict in (support_data, query_data):
            if not data_dict['mel_spec']:
                raise RuntimeError("No valid samples found for task")
            # torch.stack requires every mel to have the same shape; this
            # presumably relies on preprocess_audio producing fixed-length
            # output — TODO confirm.
            data_dict['mel_spec'] = torch.stack(data_dict['mel_spec'])
            data_dict['speaker_ids'] = torch.LongTensor(data_dict['speaker_ids'])
        return support_data, query_data
def _unwrap_single_task(batch):
    """Collate function: drop the batch dimension added by ``batch_size=1``.

    Defined at module level (rather than as a lambda) so the DataLoader stays
    picklable if worker processes are enabled later.
    """
    return batch[0]


def create_meta_learning_dataloader(
    root_dir: str,
    config: Optional[Config] = None,
    **kwargs
) -> DataLoader:
    """
    Create a DataLoader for meta-learning.

    Args:
        root_dir: dataset root directory (one sub-directory per speaker).
        config: configuration object; ``Config()`` is used when omitted. It is
            deep-copied before overrides are applied, so the caller's object is
            never modified.
        **kwargs: overrides for attributes of ``config.meta_learning`` (e.g.
            ``n_way``, ``k_shot``). Names that ``config.meta_learning`` lacks
            are ignored with a warning.

    Returns:
        DataLoader: yields one ``(support_data, query_data)`` task per step.
    """
    # Work on a copy so overrides do not leak into the caller's config
    # (e.g. across repeated calls with different kwargs).
    config = copy.deepcopy(config) if config is not None else Config()
    # Apply the overrides.
    for key, value in kwargs.items():
        if hasattr(config.meta_learning, key):
            setattr(config.meta_learning, key, value)
        else:
            warnings.warn(
                f"Ignoring unknown meta-learning option {key!r}", stacklevel=2
            )
    dataset_manager = VoiceDatasetManager(root_dir, config=config)
    dataset = MetaLearningDataset(dataset_manager, config)
    return DataLoader(
        dataset,
        batch_size=1,  # fixed at 1: each item is already a complete task
        shuffle=True,  # has no effect on task content: tasks are sampled randomly per item
        num_workers=0,  # avoid multiprocessing issues
        # Pinned memory only helps (and only works without warnings) with a GPU.
        pin_memory=torch.cuda.is_available(),
        collate_fn=_unwrap_single_task,
    )