import gradio as gr
import torch
import torchaudio
import torchaudio.transforms as T
import numpy as np
import traceback
import io
import time

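# Import the SNAC codec package; fail fast with a clear error if it is not installed.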
try:
    from snac import SNAC
    print("SNAC module imported successfully.")
except ImportError as e:
    print(f"Error importing SNAC: {e}")
    raise ImportError("Could not import SNAC. Make sure 'snac' is listed in requirements.txt and installed correctly.") from e

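# Target sample rate for the snac_24khz checkpoint; run on GPU when available.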
TARGET_SR = 24000
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
print(f"Using device: {DEVICE}")

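# Load the SNAC model once at startup so every request served by process_audio() reuses the same weights.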
snac_model = None
try:
    print("Loading SNAC model...")
    start_time = time.time()
    snac_model = SNAC.from_pretrained("hubertsiuzdak/snac_24khz")
    snac_model = snac_model.to(DEVICE)
    snac_model.eval()
    end_time = time.time()
    print(f"SNAC model loaded successfully to {DEVICE}. Time taken: {end_time - start_time:.2f} seconds.")
except Exception as e:
    print(f"FATAL: Error loading SNAC model: {e}")
    print(traceback.format_exc())


def process_audio(audio_filepath):
    """
    Loads, resamples, encodes, and decodes audio with SNAC, then returns the
    original, resampled, and reconstructed audio plus a processing log.
    """
    if snac_model is None:
        return None, None, None, "Error: SNAC model could not be loaded. Cannot process audio."

    if audio_filepath is None:
        return None, None, None, "Please upload an audio file."

    logs = ["--- Starting Audio Processing ---"]
    try:
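        # Load the uploaded file into a (channels, samples) waveform tensor.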
        logs.append(f"Loading audio file: {audio_filepath}")
        load_start = time.time()
        original_waveform, original_sr = torchaudio.load(audio_filepath)
        load_end = time.time()
        logs.append(f"Audio loaded. Original SR: {original_sr} Hz, Shape: {original_waveform.shape}, Time: {load_end - load_start:.2f}s")

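        # Ensure float32 samples before resampling and encoding.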
        original_waveform = original_waveform.to(dtype=torch.float32)

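        # SNAC operates on mono audio, so keep only the first channel of multi-channel input.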
        if original_waveform.shape[0] > 1:
            logs.append(f"Warning: Input audio has {original_waveform.shape[0]} channels. Using only the first channel.")
            original_waveform = original_waveform[0:1, :]

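        # Gradio's Audio component accepts a (sample_rate, numpy_array) tuple for playback.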
        original_audio_playback = (original_sr, original_waveform.squeeze().numpy())
        logs.append("Prepared original audio for playback.")

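        # Resample to TARGET_SR (24 kHz) if the source sample rate differs.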
        resample_start = time.time()
        if original_sr != TARGET_SR:
            logs.append(f"Resampling waveform from {original_sr} Hz to {TARGET_SR} Hz...")
            resampler = T.Resample(orig_freq=original_sr, new_freq=TARGET_SR).to(original_waveform.device)
            waveform_to_encode = resampler(original_waveform)
            logs.append(f"Resampling complete. New Shape: {waveform_to_encode.shape}")
        else:
            logs.append("Waveform is already at the target sample rate (24kHz).")
            waveform_to_encode = original_waveform
        resample_end = time.time()
        logs.append(f"Resampling time: {resample_end - resample_start:.2f}s")

        resampled_audio_playback = (TARGET_SR, waveform_to_encode.squeeze().numpy())
        logs.append("Prepared resampled audio for playback.")

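        # Add a batch dimension: (1, T) -> (1, 1, T), the shape snac_model.encode() expects.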
        waveform_batch = waveform_to_encode.unsqueeze(0).to(DEVICE)
        logs.append(f"Waveform prepared for encoding. Shape: {waveform_batch.shape}, Device: {DEVICE}")

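        # Encode the 24 kHz waveform into discrete codes (a list of tensors, one per code layer).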
        logs.append("Encoding audio with snac_model.encode()...")
        encode_start = time.time()
        with torch.inference_mode():
            codes = snac_model.encode(waveform_batch)
        encode_end = time.time()

        if not codes or not all(isinstance(c, torch.Tensor) for c in codes):
            log_msg = f"Encoding failed: Expected list of Tensors, but got: {type(codes)}"
            if isinstance(codes, list):
                log_msg += f" with types {[type(c) for c in codes]}"
            logs.append(log_msg)
            raise ValueError(log_msg)

        logs.append(f"Encoding complete. Received {len(codes)} code layers. Time: {encode_end - encode_start:.2f}s")
        for i, layer_codes in enumerate(codes):
            logs.append(f"  Layer {i+1} codes shape: {layer_codes.shape}, Device: {layer_codes.device}")

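        # Decode the codes back into a 24 kHz waveform.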
        logs.append("Decoding the generated codes with snac_model.decode()...")
        decode_start = time.time()
        with torch.inference_mode():
            reconstructed_waveform = snac_model.decode(codes)
        decode_end = time.time()
        logs.append(f"Decoding complete. Reconstructed waveform shape: {reconstructed_waveform.shape}, Device: {reconstructed_waveform.device}. Time: {decode_end - decode_start:.2f}s")

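        # Move the reconstruction to the CPU and squeeze away the batch/channel dimensions for playback.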
        reconstructed_audio_np = reconstructed_waveform.cpu().squeeze().numpy()
        logs.append(f"Reconstructed audio prepared for playback. Shape: {reconstructed_audio_np.shape}")
        reconstructed_audio_playback = (TARGET_SR, reconstructed_audio_np)

        logs.append("\n--- Audio Processing Completed Successfully ---")
        return original_audio_playback, resampled_audio_playback, reconstructed_audio_playback, "\n".join(logs)

    except Exception as e:
        logs.append("\n--- An Error Occurred ---")
        logs.append(f"Error Type: {type(e).__name__}")
        logs.append(f"Error Details: {e}")
        logs.append("\n--- Traceback ---")
        logs.append(traceback.format_exc())
        return None, None, None, "\n".join(logs)


DESCRIPTION = """
This Space demonstrates the **SNAC (Multi-Scale Neural Audio Codec)** model (`hubertsiuzdak/snac_24khz`).

1. Upload an audio file (wav, mp3, flac, etc.).
2. The audio is automatically resampled to 24kHz if needed.
3. The 24kHz audio is encoded into discrete codes by SNAC.
4. These codes are then decoded back into audio by SNAC.
5. You can listen to the original, the 24kHz version (if resampled), and the final reconstructed audio.

**Note:** Processing happens on the server. Larger files will take longer. If the input is stereo, only the first channel is processed.
"""

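# For reference, the round trip described above reduces to a couple of SNAC API calls.
# A minimal sketch (assuming `audio` is a mono 24 kHz tensor of shape (1, 1, T) already
# on DEVICE, matching the encode call used in process_audio above):
#
#     with torch.inference_mode():
#         codes = snac_model.encode(audio)            # list of discrete code tensors
#         reconstruction = snac_model.decode(codes)   # reconstructed (1, 1, T) waveform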

iface = gr.Interface(
    fn=process_audio,
    inputs=gr.Audio(type="filepath", label="Upload Audio File"),
    outputs=[
        gr.Audio(label="Original Audio"),
        gr.Audio(label="Resampled Audio (24kHz Input to SNAC)"),
        gr.Audio(label="Reconstructed Audio (Output from SNAC)"),
        gr.Textbox(label="Log Output", lines=15),
    ],
    title="SNAC Audio Codec Demo (24kHz)",
    description=DESCRIPTION,
    examples=[],
    cache_examples=False,
)

if __name__ == "__main__":
    if snac_model is None:
        print("Cannot launch Gradio interface because SNAC model failed to load.")
    else:
        print("Launching Gradio Interface...")
        iface.launch()