"""Separate speech sources from a mixture with the Look2Hear TIGER model."""

import argparse
import os
import sys

import torch
import torchaudio
import torchaudio.transforms as T

import look2hear.models

# --- Command-line arguments ---
parser = argparse.ArgumentParser(description="Separate speech sources using the Look2Hear TIGER model.")
parser.add_argument("--audio_path", default="test/mix.wav", help="Path to the input audio file (mixture).")
parser.add_argument("--output_dir", default="separated_audio", help="Directory to save the separated audio files.")
parser.add_argument("--model_cache_dir", default="cache", help="Directory to cache the downloaded model.")
args = parser.parse_args()

audio_path = args.audio_path
output_dir = args.output_dir
cache_dir = args.model_cache_dir

device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
print(f"Using device: {device}")

# --- Load the pretrained model ---
print("Loading TIGER model...")
if cache_dir:
    os.makedirs(cache_dir, exist_ok=True)
model = look2hear.models.TIGER.from_pretrained("JusperLee/TIGER-speech", cache_dir=cache_dir)
model.to(device)
model.eval()
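# Hub note: from_pretrained() fetches the checkpoint from the Hugging Face Hub
# repo "JusperLee/TIGER-speech" on first run; with cache_dir set, subsequent
# runs reuse the local copy instead of re-downloading.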

# --- Load the input audio and resample if needed ---
target_sr = 16000  # sample rate the model expects; the mixture is resampled to this
print(f"Loading audio from: {audio_path}")
try:
    waveform, original_sr = torchaudio.load(audio_path)
except Exception as e:
    print(f"Error loading audio file {audio_path}: {e}")
    sys.exit(1)
print(f"Original sample rate: {original_sr} Hz, target sample rate: {target_sr} Hz")

if original_sr != target_sr:
    print(f"Resampling audio from {original_sr} Hz to {target_sr} Hz...")
    resampler = T.Resample(orig_freq=original_sr, new_freq=target_sr)
    waveform = resampler(waveform)
    print("Resampling complete.")

audio = waveform.to(device)

# torchaudio.load() returns a (channels, samples) tensor; guard against 1-D input anyway.
if audio.dim() == 1:
    audio = audio.unsqueeze(0)

# Add a batch dimension: (channels, samples) -> (1, channels, samples).
audio_input = audio.unsqueeze(0)
print(f"Audio tensor prepared with shape: {audio_input.shape}")

# --- Run separation ---
os.makedirs(output_dir, exist_ok=True)
print(f"Output directory: {output_dir}")
print("Performing separation...")

with torch.no_grad():
    ests_speech = model(audio_input)

# Drop the batch dimension: (1, num_speakers, samples) -> (num_speakers, samples).
ests_speech = ests_speech.squeeze(0)
num_speakers = ests_speech.shape[0]
print(f"Separation complete. Detected {num_speakers} potential speakers.")

# --- Save each separated track ---
for i in range(num_speakers):
    output_filename = os.path.join(output_dir, f"spk{i + 1}.wav")
    # torchaudio.save() expects a 2-D (channels, samples) tensor, so restore a channel dim.
    speaker_track = ests_speech[i].unsqueeze(0).cpu()
    print(f"Saving speaker {i + 1} to {output_filename}")
    try:
        torchaudio.save(output_filename, speaker_track, target_sr)
    except Exception as e:
        print(f"Error saving file {output_filename}: {e}")
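
# Example invocation (the script filename is illustrative):
#   python separate_tiger.py --audio_path test/mix.wav --output_dir separated_audio --model_cache_dir cache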