import gradio as gr
import torch
import torchaudio
import torchaudio.transforms as T
import numpy as np
import traceback
import io
import time

# Attempt to import SNAC (should work if requirements.txt is correct)
try:
    from snac import SNAC
    print("SNAC module imported successfully.")
except ImportError as e:
    print(f"Error importing SNAC: {e}")
    # Raise a more informative error if SNAC isn't installed
    raise ImportError("Could not import SNAC. Make sure 'snac' is listed in requirements.txt and installed correctly.") from e

# --- Configuration ---
TARGET_SR = 24000 # SNAC operates at 24kHz
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
print(f"Using device: {DEVICE}")

# --- Load Model (Load once globally) ---
snac_model = None
try:
    print("Loading SNAC model...")
    start_time = time.time()
    snac_model = SNAC.from_pretrained("hubertsiuzdak/snac_24khz")
    snac_model = snac_model.to(DEVICE)
    snac_model.eval() # Set model to evaluation mode
    end_time = time.time()
    print(f"SNAC model loaded successfully to {DEVICE}. Time taken: {end_time - start_time:.2f} seconds.")
except Exception as e:
    print(f"FATAL: Error loading SNAC model: {e}")
    print(traceback.format_exc())
    # If the model fails to load, the app can't function.
    # Gradio will likely show an error, but we print specifics here.

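# Note (assumption): SNAC.from_pretrained() is expected to fetch the checkpoint from the
# Hugging Face Hub on first run and reuse the local cache afterwards, so the initial
# startup of the Space can be noticeably slower than subsequent launches.
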
# --- Main Processing Function ---
def process_audio(audio_filepath):
    """
    Loads, resamples, encodes, decodes audio using SNAC, and returns results.
    """
    if snac_model is None:
        return None, None, None, "Error: SNAC model could not be loaded. Cannot process audio."

    if audio_filepath is None:
        return None, None, None, "Please upload an audio file."

    logs = ["--- Starting Audio Processing ---"]
    try:
        # 1. Load Audio
        logs.append(f"Loading audio file: {audio_filepath}")
        load_start = time.time()
        original_waveform, original_sr = torchaudio.load(audio_filepath)
        load_end = time.time()
        logs.append(f"Audio loaded. Original SR: {original_sr} Hz, Shape: {original_waveform.shape}, Time: {load_end - load_start:.2f}s")

        # Ensure float32
        original_waveform = original_waveform.to(dtype=torch.float32)

        # Handle multi-channel audio: Use the first channel
        if original_waveform.shape[0] > 1:
            logs.append(f"Warning: Input audio has {original_waveform.shape[0]} channels. Using only the first channel.")
            original_waveform = original_waveform[0:1, :] # Keep channel dim for consistency initially

        # --- Prepare Original for Playback ---
        # Gradio Audio component expects (sample_rate, numpy_array)
        # Ensure numpy array is 1D or 2D [channels, samples]
        original_audio_playback = (original_sr, original_waveform.squeeze().numpy()) # Squeeze removes channel dim if 1
        logs.append("Prepared original audio for playback.")


        # 2. Resample if necessary
        resample_start = time.time()
        if original_sr != TARGET_SR:
            logs.append(f"Resampling waveform from {original_sr} Hz to {TARGET_SR} Hz...")
            resampler = T.Resample(orig_freq=original_sr, new_freq=TARGET_SR).to(original_waveform.device) # Resampler on same device
            waveform_to_encode = resampler(original_waveform)
            logs.append(f"Resampling complete. New Shape: {waveform_to_encode.shape}")
        else:
            logs.append("Waveform is already at the target sample rate (24kHz).")
            waveform_to_encode = original_waveform
        resample_end = time.time()
        logs.append(f"Resampling time: {resample_end - resample_start:.2f}s")

        # --- Prepare Resampled for Playback ---
        resampled_audio_playback = (TARGET_SR, waveform_to_encode.squeeze().numpy())
        logs.append("Prepared resampled audio for playback.")

        # 3. Prepare for SNAC Encoding (add batch dim, move to device)
        # Input should be [Batch, Channel, Time] = [1, 1, Time]
        # waveform_to_encode should currently be [1, Time] after channel selection/resampling
        waveform_batch = waveform_to_encode.unsqueeze(0).to(DEVICE) # Add batch dimension -> [1, 1, Time]
        logs.append(f"Waveform prepared for encoding. Shape: {waveform_batch.shape}, Device: {DEVICE}")

        # 4. Encode Audio using SNAC
        logs.append("Encoding audio with snac_model.encode()...")
        encode_start = time.time()
        with torch.inference_mode():
            codes = snac_model.encode(waveform_batch)
        encode_end = time.time()

        if not codes or not all(isinstance(c, torch.Tensor) for c in codes):
            log_msg = f"Encoding failed: Expected list of Tensors, but got: {type(codes)}"
            if isinstance(codes, list):
                log_msg += f" with types {[type(c) for c in codes]}"
            logs.append(log_msg)
            raise ValueError(log_msg)

        logs.append(f"Encoding complete. Received {len(codes)} code layers. Time: {encode_end - encode_start:.2f}s")
        for i, layer_codes in enumerate(codes):
             logs.append(f"  Layer {i+1} codes shape: {layer_codes.shape}, Device: {layer_codes.device}")

        # 5. Decode the Codes using SNAC
        logs.append("Decoding the generated codes with snac_model.decode()...")
        decode_start = time.time()
        with torch.inference_mode():
            reconstructed_waveform = snac_model.decode(codes) # codes are already on DEVICE
        decode_end = time.time()
        logs.append(f"Decoding complete. Reconstructed waveform shape: {reconstructed_waveform.shape}, Device: {reconstructed_waveform.device}. Time: {decode_end - decode_start:.2f}s")

        # 6. Prepare Reconstructed Audio for Playback
        # Output is [Batch, 1, Time]. Move to CPU, remove Batch/Channel, convert to NumPy.
        reconstructed_audio_np = reconstructed_waveform.cpu().squeeze().numpy() # Squeeze removes Batch and Channel dims
        logs.append(f"Reconstructed audio prepared for playback. Shape: {reconstructed_audio_np.shape}")
        reconstructed_audio_playback = (TARGET_SR, reconstructed_audio_np)

        logs.append("\n--- Audio Processing Completed Successfully ---")
        return original_audio_playback, resampled_audio_playback, reconstructed_audio_playback, "\n".join(logs)

    except Exception as e:
        logs.append("\n--- An Error Occurred ---")
        logs.append(f"Error Type: {type(e).__name__}")
        logs.append(f"Error Details: {e}")
        logs.append("\n--- Traceback ---")
        logs.append(traceback.format_exc())
        # Return None for audio components on error, and the detailed log
        return None, None, None, "\n".join(logs)

# --- Gradio Interface ---
DESCRIPTION = """
This Space demonstrates the **SNAC (Multi-Scale Neural Audio Codec)** model (`hubertsiuzdak/snac_24khz`).
1. Upload an audio file (wav, mp3, flac, etc.).
2. The audio will be automatically resampled to 24kHz if needed.
3. The 24kHz audio is encoded into discrete codes by SNAC.
4. These codes are then decoded back into audio by SNAC.
5. You can listen to the original, the 24kHz version (if resampled), and the final reconstructed audio.

**Note:** Processing happens on the server. Larger files will take longer. If the input is stereo, only the first channel is processed.
"""

iface = gr.Interface(
    fn=process_audio,
    inputs=gr.Audio(type="filepath", label="Upload Audio File"),
    outputs=[
        gr.Audio(label="Original Audio"),
        gr.Audio(label="Resampled Audio (24kHz Input to SNAC)"),
        gr.Audio(label="Reconstructed Audio (Output from SNAC)"),
        gr.Textbox(label="Log Output", lines=15)
    ],
    title="SNAC Audio Codec Demo (24kHz)",
    description=DESCRIPTION,
    examples=[
        # Add paths to example audio files if you upload some to your Space repo
        # ["examples/example1.wav"],
        # ["examples/example2.mp3"],
    ],
    cache_examples=False # Disable caching if examples change or have issues
)

if __name__ == "__main__":
    if snac_model is None:
        print("Cannot launch Gradio interface because SNAC model failed to load.")
    else:
        print("Launching Gradio Interface...")
        iface.launch()