# snac_test / app.py
# Gapeleon's picture
# Update app.py
# f0169cd verified
import gradio as gr
import torch
import torchaudio
import torchaudio.transforms as T
import numpy as np
import traceback
import io
import time
# SNAC is a third-party package; it is expected to be installed from
# requirements.txt. Fail early with an actionable message when it is missing.
try:
    from snac import SNAC
except ImportError as import_error:
    print(f"Error importing SNAC: {import_error}")
    # Re-raise with installation guidance, keeping the original error as cause.
    raise ImportError("Could not import SNAC. Make sure 'snac' is listed in requirements.txt and installed correctly.") from import_error
else:
    print("SNAC module imported successfully.")
# --- Configuration ---
# Sample rate (Hz) that audio is resampled to before encoding; SNAC operates at 24kHz.
TARGET_SR = 24_000

# Run on the GPU when torch reports CUDA as available, otherwise on the CPU.
if torch.cuda.is_available():
    DEVICE = "cuda"
else:
    DEVICE = "cpu"
print(f"Using device: {DEVICE}")
# --- Load Model (Load once globally) ---
# snac_model is the shared codec instance, or None when loading failed.
# It is only published after every step (load, device move, eval) has
# succeeded, so a failure part-way through cannot leave a model that is
# non-None but not on DEVICE. Code that checks `snac_model is None` can then
# rely on that check.
snac_model = None
try:
    print("Loading SNAC model...")
    start_time = time.time()
    _candidate_model = SNAC.from_pretrained("hubertsiuzdak/snac_24khz")
    _candidate_model = _candidate_model.to(DEVICE)
    _candidate_model.eval()  # Set model to evaluation mode
    snac_model = _candidate_model
    end_time = time.time()
    print(f"SNAC model loaded successfully to {DEVICE}. Time taken: {end_time - start_time:.2f} seconds.")
except Exception as e:
    # Deliberately broad: any failure here means the app cannot function, and
    # the specifics are printed below. Make sure no half-initialised model
    # stays visible to the rest of the module.
    snac_model = None
    print(f"FATAL: Error loading SNAC model: {e}")
    print(traceback.format_exc())
    # If the model fails to load, the app can't function.
    # Gradio will likely show an error, but we print specifics here.
# --- Main Processing Function ---
def process_audio(audio_filepath):
    """
    Load, resample, encode and decode an audio file with SNAC.

    Args:
        audio_filepath: Path of the audio file to process, or None when no
            file was provided.

    Returns:
        A 4-tuple ``(original, resampled, reconstructed, log_text)``.
        Each of the first three items is a ``(sample_rate, numpy_array)``
        pair for playback, or None when processing could not run or failed.
        ``log_text`` is the newline-joined processing log (including the
        traceback on failure) or a short error message.

    Exceptions raised while processing are caught and reported through
    ``log_text``; they are not propagated to the caller.
    """
    if snac_model is None:
        return None, None, None, "Error: SNAC model could not be loaded. Cannot process audio."
    if audio_filepath is None:
        return None, None, None, "Please upload an audio file."

    logs = ["--- Starting Audio Processing ---"]
    try:
        # 1. Load Audio
        logs.append(f"Loading audio file: {audio_filepath}")
        load_start = time.time()
        original_waveform, original_sr = torchaudio.load(audio_filepath)
        load_end = time.time()
        logs.append(f"Audio loaded. Original SR: {original_sr} Hz, Shape: {original_waveform.shape}, Time: {load_end - load_start:.2f}s")

        # Nothing can be played back or encoded from zero samples; stop with a
        # clear message instead of continuing into resampling and the codec.
        if original_waveform.numel() == 0:
            logs.append("Error: The audio file contains no samples. Nothing to process.")
            return None, None, None, "\n".join(logs)

        # Ensure float32
        original_waveform = original_waveform.to(dtype=torch.float32)

        # Handle multi-channel audio: Use the first channel
        if original_waveform.shape[0] > 1:
            logs.append(f"Warning: Input audio has {original_waveform.shape[0]} channels. Using only the first channel.")
            original_waveform = original_waveform[0:1, :]  # Keep channel dim: [1, Time]

        # --- Prepare Original for Playback ---
        # Gradio Audio component expects (sample_rate, numpy_array).
        # squeeze(0) drops only the channel dim, so a one-sample waveform
        # stays a 1-D array instead of collapsing to a 0-d scalar.
        original_audio_playback = (original_sr, original_waveform.squeeze(0).numpy())
        logs.append("Prepared original audio for playback.")

        # 2. Resample if necessary
        resample_start = time.time()
        if original_sr != TARGET_SR:
            logs.append(f"Resampling waveform from {original_sr} Hz to {TARGET_SR} Hz...")
            # Resampler lives on the same device as the waveform it processes
            resampler = T.Resample(orig_freq=original_sr, new_freq=TARGET_SR).to(original_waveform.device)
            waveform_to_encode = resampler(original_waveform)
            logs.append(f"Resampling complete. New Shape: {waveform_to_encode.shape}")
        else:
            logs.append("Waveform is already at the target sample rate (24kHz).")
            waveform_to_encode = original_waveform
        resample_end = time.time()
        logs.append(f"Resampling time: {resample_end - resample_start:.2f}s")

        # --- Prepare Resampled for Playback ---
        resampled_audio_playback = (TARGET_SR, waveform_to_encode.squeeze(0).numpy())
        logs.append("Prepared resampled audio for playback.")

        # 3. Prepare for SNAC Encoding (add batch dim, move to device)
        # waveform_to_encode is [1, Time]; the model input is [Batch, Channel, Time] = [1, 1, Time]
        waveform_batch = waveform_to_encode.unsqueeze(0).to(DEVICE)
        input_length = waveform_batch.shape[-1]
        logs.append(f"Waveform prepared for encoding. Shape: {waveform_batch.shape}, Device: {DEVICE}")

        # 4. Encode Audio using SNAC
        logs.append("Encoding audio with snac_model.encode()...")
        encode_start = time.time()
        with torch.inference_mode():
            codes = snac_model.encode(waveform_batch)
        encode_end = time.time()

        # Check the container type first: truth-testing a multi-element Tensor
        # raises, which would hide the intended error message below.
        codes_is_sequence = isinstance(codes, (list, tuple))
        if not codes_is_sequence or not codes or not all(isinstance(c, torch.Tensor) for c in codes):
            log_msg = f"Encoding failed: Expected list of Tensors, but got: {type(codes)}"
            if codes_is_sequence:
                log_msg += f" with types {[type(c) for c in codes]}"
            logs.append(log_msg)
            raise ValueError(log_msg)

        logs.append(f"Encoding complete. Received {len(codes)} code layers. Time: {encode_end - encode_start:.2f}s")
        for layer_number, layer_codes in enumerate(codes, start=1):
            logs.append(f" Layer {layer_number} codes shape: {layer_codes.shape}, Device: {layer_codes.device}")

        # 5. Decode the Codes using SNAC
        logs.append("Decoding the generated codes with snac_model.decode()...")
        decode_start = time.time()
        with torch.inference_mode():
            reconstructed_waveform = snac_model.decode(codes)  # codes are already on DEVICE
        decode_end = time.time()
        logs.append(f"Decoding complete. Reconstructed waveform shape: {reconstructed_waveform.shape}, Device: {reconstructed_waveform.device}. Time: {decode_end - decode_start:.2f}s")

        # Per the SNAC implementation, encode() pads its input before coding,
        # so decode() can return more samples than were fed in. Trim back to
        # the input length (as SNAC.forward does) so the reconstruction lines
        # up with the resampled input.
        if reconstructed_waveform.shape[-1] > input_length:
            reconstructed_waveform = reconstructed_waveform[..., :input_length]
            logs.append(f"Trimmed codec padding to match input length. Shape: {reconstructed_waveform.shape}")

        # 6. Prepare Reconstructed Audio for Playback
        # Move to CPU, drop the leading batch and channel dims (expected to be
        # size 1 each), convert to NumPy.
        reconstructed_audio_np = reconstructed_waveform.cpu().squeeze(0).squeeze(0).numpy()
        logs.append(f"Reconstructed audio prepared for playback. Shape: {reconstructed_audio_np.shape}")
        reconstructed_audio_playback = (TARGET_SR, reconstructed_audio_np)

        logs.append("\n--- Audio Processing Completed Successfully ---")
        return original_audio_playback, resampled_audio_playback, reconstructed_audio_playback, "\n".join(logs)
    except Exception as e:
        # UI boundary: report any failure through the log output instead of
        # letting it propagate.
        logs.append("\n--- An Error Occurred ---")
        logs.append(f"Error Type: {type(e).__name__}")
        logs.append(f"Error Details: {e}")
        logs.append("\n--- Traceback ---")
        logs.append(traceback.format_exc())
        # Return None for audio components on error, and the detailed log
        return None, None, None, "\n".join(logs)
# --- Gradio Interface ---
# Text passed to the interface as its description.
DESCRIPTION = """
This Space demonstrates the **SNAC (Scalable Neural Audio Codec)** model (`hubertsiuzdak/snac_24khz`).
1. Upload an audio file (wav, mp3, flac, etc.).
2. The audio will be automatically resampled to 24kHz if needed.
3. The 24kHz audio is encoded into discrete codes by SNAC.
4. These codes are then decoded back into audio by SNAC.
5. You can listen to the original, the 24kHz version (if resampled), and the final reconstructed audio.
**Note:** Processing happens on the server. Larger files will take longer. If the input is stereo, only the first channel is processed.
"""

# One uploaded file in; three audio players plus a text log out, listed in the
# same order as the values process_audio returns.
iface = gr.Interface(
    title="SNAC Audio Codec Demo (24kHz)",
    description=DESCRIPTION,
    fn=process_audio,
    inputs=gr.Audio(type="filepath", label="Upload Audio File"),
    outputs=[
        *(
            gr.Audio(label=player_label)
            for player_label in (
                "Original Audio",
                "Resampled Audio (24kHz Input to SNAC)",
                "Reconstructed Audio (Output from SNAC)",
            )
        ),
        gr.Textbox(label="Log Output", lines=15),
    ],
    # No bundled examples yet. Add one-item path lists here, e.g.
    # ["examples/example1.wav"], once example files exist in the Space repo.
    examples=[],
    cache_examples=False,  # Disable caching if examples change or have issues
)
# Script entry point: start the web UI only when the codec model is available.
if __name__ == "__main__":
    if snac_model is not None:
        print("Launching Gradio Interface...")
        iface.launch()
    else:
        print("Cannot launch Gradio interface because SNAC model failed to load.")