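"""Separate overlapping speakers in a mixture WAV file with the TIGER model.

Loads the pretrained JusperLee/TIGER-speech checkpoint via look2hear,
resamples the input audio to 16 kHz, runs separation, and writes one WAV
file per estimated speaker track.

Example invocation (the script filename here is illustrative):
    python separate.py --audio_path test/mix.wav --output_dir separated_audio
"""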
import argparse
import os
import sys

import torch
import torchaudio
import torchaudio.transforms as T  # used for resampling

import look2hear.models

# --- Argument Parsing ---
parser = argparse.ArgumentParser(description="Split audio with WAVbot.com")
parser.add_argument("--audio_path", default="test/mix.wav", help="Path to the mixture audio file.")
parser.add_argument("--output_dir", default="separated_audio", help="Directory to save separated audio files.")
parser.add_argument("--model_cache_dir", default="cache", help="Directory to cache the downloaded model.")

# Parse arguments once at the beginning
args = parser.parse_args()
audio_path = args.audio_path
output_dir = args.output_dir
cache_dir = args.model_cache_dir
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
print(f"Using device: {device}")

# --- Model Loading ---
print("Loading TIGER model...")
# Ensure cache directory exists if specified
if cache_dir:
    os.makedirs(cache_dir, exist_ok=True)
# Load the pretrained model
model = look2hear.models.TIGER.from_pretrained("JusperLee/TIGER-speech", cache_dir=cache_dir)
model.to(device)
model.eval()


# --- Audio Loading and Preprocessing ---
# Target sample rate expected by the model (TIGER speech models are typically 16 kHz)
target_sr = 16000
print(f"Loading audio from: {audio_path}")
try:
    # Load audio and get its original sample rate
    waveform, original_sr = torchaudio.load(audio_path)
except Exception as e:
    print(f"Error loading audio file {audio_path}: {e}")
    sys.exit(1)
print(f"Original sample rate: {original_sr} Hz, Target sample rate: {target_sr} Hz")

# Downmix to mono if needed (the pretrained TIGER speech models typically expect mono input)
if waveform.shape[0] > 1:
    print(f"Downmixing {waveform.shape[0]} channels to mono...")
    waveform = waveform.mean(dim=0, keepdim=True)

# Resample if necessary
if original_sr != target_sr:
    print(f"Resampling audio from {original_sr} Hz to {target_sr} Hz...")
    resampler = T.Resample(orig_freq=original_sr, new_freq=target_sr)
    waveform = resampler(waveform)
    print("Resampling complete.")

# Move waveform to the target device
audio = waveform.to(device)

# Prepare the input tensor for the model.
# torchaudio.load returns [C, T]; after the mono downmix above this is [1, T].
# The model expects a leading batch dimension, so the final shape is [1, 1, T].
if audio.dim() == 1:
    audio = audio.unsqueeze(0)  # add channel dimension -> [1, T]
audio_input = audio.unsqueeze(0)  # add batch dimension (equivalent to audio[None]) -> [1, 1, T]
print(f"Audio tensor prepared with shape: {audio_input.shape}")

# --- Speech Separation ---
# Create the output directory if it doesn't exist
os.makedirs(output_dir, exist_ok=True)
print(f"Output directory: {output_dir}")
print("Performing separation...")

with torch.no_grad():
    # Pass the prepared input tensor to the model
    ests_speech = model(audio_input)  # Expected output shape: [B, num_spk, T]

# Process the estimated sources: remove the batch dimension -> [num_spk, T]
ests_speech = ests_speech.squeeze(0)
num_speakers = ests_speech.shape[0]
print(f"Separation complete. Estimated {num_speakers} speaker tracks.")



# --- Save Separated Audio ---
# Save every separated track (torchaudio.save expects a 2-D [channels, time] tensor)
for i in range(num_speakers):
    output_filename = os.path.join(output_dir, f"spk{i+1}.wav")
    speaker_track = ests_speech[i].unsqueeze(0).cpu()  # i-th track -> [1, T], moved to CPU
    print(f"Saving speaker {i+1} to {output_filename}")
    try:
        torchaudio.save(
            output_filename,
            speaker_track, # Save the individual track
            target_sr      # Save with the target sample rate
        )
    except Exception as e:
        print(f"Error saving file {output_filename}: {e}")