Gapeleon committed on
Commit
3003400
·
verified ·
1 Parent(s): 44f2a3c

Create app.py

Browse files
Files changed (1) hide show
  1. app.py +180 -0
app.py ADDED
@@ -0,0 +1,180 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import gradio as gr
2
+ import torch
3
+ import torchaudio
4
+ import torchaudio.transforms as T
5
+ import numpy as np
6
+ import traceback
7
+ import io
8
+ import time
9
+
10
# SNAC is a hard requirement: report the outcome of the import and abort start-up
# with an actionable message when the package is missing.
try:
    from snac import SNAC
except ImportError as import_error:
    print(f"Error importing SNAC: {import_error}")
    # Re-raise with installation guidance, keeping the original failure chained.
    raise ImportError("Could not import SNAC. Make sure 'snac-codec' is listed in requirements.txt and installed correctly.") from import_error
else:
    print("SNAC module imported successfully.")
18
+
19
# --- Configuration ---
# Sample rate every waveform is converted to before it is handed to SNAC (24 kHz).
TARGET_SR = 24000
# Default to the CPU and switch to the GPU only when torch reports CUDA as available.
DEVICE = "cpu"
if torch.cuda.is_available():
    DEVICE = "cuda"
print(f"Using device: {DEVICE}")
23
+
24
# --- Load Model (Load once globally) ---
# `snac_model` stays None unless the checkpoint was loaded, moved to DEVICE and put
# into eval mode. process_audio and the launch guard both check it against None.
snac_model = None
try:
    print("Loading SNAC model...")
    # perf_counter is monotonic, so the reported duration cannot be skewed by
    # wall-clock adjustments the way time.time() can.
    start_time = time.perf_counter()
    _loaded_model = SNAC.from_pretrained("hubertsiuzdak/snac_24khz")
    _loaded_model = _loaded_model.to(DEVICE)
    _loaded_model.eval()  # Set model to evaluation mode
    end_time = time.perf_counter()
    # Publish the model only after every step succeeded. Binding the global right
    # after from_pretrained() would leave a half-configured model visible (and the
    # None checks passing) if .to() or .eval() raised afterwards.
    snac_model = _loaded_model
    print(f"SNAC model loaded successfully to {DEVICE}. Time taken: {end_time - start_time:.2f} seconds.")
except Exception as e:
    # Top-level boundary: log the full details and keep the module importable.
    # With snac_model left as None the app reports the failure instead of crashing.
    print(f"FATAL: Error loading SNAC model: {e}")
    print(traceback.format_exc())
39
+
40
+ # --- Main Processing Function ---
41
def process_audio(audio_filepath):
    """
    Round-trip an audio file through the SNAC codec and return playable results.

    The file is loaded with torchaudio, reduced to its first channel, resampled to
    TARGET_SR when its sample rate differs, encoded with ``snac_model.encode`` and
    decoded again with ``snac_model.decode``.

    Args:
        audio_filepath: Path of the audio file to process, or None when no file
            was supplied.

    Returns:
        A 4-tuple ``(original, resampled, reconstructed, log_text)``. Each audio
        entry is a ``(sample_rate, numpy_array)`` pair. All three audio entries are
        None when the model is unavailable, no file was given, or processing
        failed. ``log_text`` is the newline-joined processing log and contains the
        traceback when an error occurred.
    """
    if snac_model is None:
        return None, None, None, "Error: SNAC model could not be loaded. Cannot process audio."

    if audio_filepath is None:
        return None, None, None, "Please upload an audio file."

    logs = ["--- Starting Audio Processing ---"]
    try:
        # 1. Load Audio
        logs.append(f"Loading audio file: {audio_filepath}")
        load_start = time.perf_counter()
        original_waveform, original_sr = torchaudio.load(audio_filepath)
        load_end = time.perf_counter()
        logs.append(f"Audio loaded. Original SR: {original_sr} Hz, Shape: {original_waveform.shape}, Time: {load_end - load_start:.2f}s")

        # Fail early with a clear message instead of an opaque error further down.
        if original_waveform.numel() == 0:
            raise ValueError("Input audio contains no samples.")

        # Ensure float32
        original_waveform = original_waveform.to(dtype=torch.float32)

        # Handle multi-channel audio: Use the first channel
        if original_waveform.shape[0] > 1:
            logs.append(f"Warning: Input audio has {original_waveform.shape[0]} channels. Using only the first channel.")
            original_waveform = original_waveform[0:1, :]  # Keep channel dim: [1, Time]

        # --- Prepare Original for Playback ---
        # Gradio Audio component expects (sample_rate, numpy_array).
        # Index the channel axis explicitly: a bare squeeze() would also drop the
        # time axis of a one-sample clip and produce a 0-d array.
        original_audio_playback = (original_sr, original_waveform[0].numpy())
        logs.append("Prepared original audio for playback.")

        # 2. Resample if necessary
        resample_start = time.perf_counter()
        if original_sr != TARGET_SR:
            logs.append(f"Resampling waveform from {original_sr} Hz to {TARGET_SR} Hz...")
            resampler = T.Resample(orig_freq=original_sr, new_freq=TARGET_SR).to(original_waveform.device)
            waveform_to_encode = resampler(original_waveform)
            logs.append(f"Resampling complete. New Shape: {waveform_to_encode.shape}")
        else:
            logs.append("Waveform is already at the target sample rate (24kHz).")
            waveform_to_encode = original_waveform
        resample_end = time.perf_counter()
        logs.append(f"Resampling time: {resample_end - resample_start:.2f}s")

        # --- Prepare Resampled for Playback ---
        resampled_audio_playback = (TARGET_SR, waveform_to_encode[0].numpy())
        logs.append("Prepared resampled audio for playback.")

        # 3. Prepare for SNAC Encoding (add batch dim, move to device)
        # waveform_to_encode is [1, Time]; the model input is [Batch, Channel, Time].
        waveform_batch = waveform_to_encode.unsqueeze(0).to(DEVICE)  # -> [1, 1, Time]
        input_length = waveform_batch.shape[-1]
        logs.append(f"Waveform prepared for encoding. Shape: {waveform_batch.shape}, Device: {DEVICE}")

        # 4. Encode Audio using SNAC
        logs.append("Encoding audio with snac_model.encode()...")
        encode_start = time.perf_counter()
        with torch.inference_mode():
            codes = snac_model.encode(waveform_batch)
        encode_end = time.perf_counter()

        if not codes or not all(isinstance(c, torch.Tensor) for c in codes):
            log_msg = f"Encoding failed: Expected list of Tensors, but got: {type(codes)}"
            if isinstance(codes, list):
                log_msg += f" with types {[type(c) for c in codes]}"
            logs.append(log_msg)
            raise ValueError(log_msg)

        logs.append(f"Encoding complete. Received {len(codes)} code layers. Time: {encode_end - encode_start:.2f}s")
        for i, layer_codes in enumerate(codes):
            logs.append(f"  Layer {i+1} codes shape: {layer_codes.shape}, Device: {layer_codes.device}")

        # 5. Decode the Codes using SNAC
        logs.append("Decoding the generated codes with snac_model.decode()...")
        decode_start = time.perf_counter()
        with torch.inference_mode():
            reconstructed_waveform = snac_model.decode(codes)  # codes are already on DEVICE
        decode_end = time.perf_counter()
        logs.append(f"Decoding complete. Reconstructed waveform shape: {reconstructed_waveform.shape}, Device: {reconstructed_waveform.device}. Time: {decode_end - decode_start:.2f}s")

        # The decoded signal can be longer than the input.
        # NOTE(review): in the SNAC reference implementation encode() right-pads the
        # input and decode() does not trim it again - confirm against the installed
        # version. Cutting back to the input length keeps the padding out of the
        # playback and is a no-op when the lengths already match.
        if reconstructed_waveform.shape[-1] > input_length:
            logs.append(f"Trimming codec padding: {reconstructed_waveform.shape[-1]} -> {input_length} samples.")
            reconstructed_waveform = reconstructed_waveform[..., :input_length]

        # 6. Prepare Reconstructed Audio for Playback
        # Output is [1, 1, Time]. Flatten to 1-D explicitly so a one-sample result
        # stays an array, then move to CPU and convert to NumPy.
        reconstructed_audio_np = reconstructed_waveform.reshape(-1).cpu().numpy()
        logs.append(f"Reconstructed audio prepared for playback. Shape: {reconstructed_audio_np.shape}")
        reconstructed_audio_playback = (TARGET_SR, reconstructed_audio_np)

        logs.append("\n--- Audio Processing Completed Successfully ---")
        return original_audio_playback, resampled_audio_playback, reconstructed_audio_playback, "\n".join(logs)

    except Exception as e:
        # UI boundary: never let an exception escape to the caller; surface the
        # full details in the log output instead.
        logs.append("\n--- An Error Occurred ---")
        logs.append(f"Error Type: {type(e).__name__}")
        logs.append(f"Error Details: {e}")
        logs.append("\n--- Traceback ---")
        logs.append(traceback.format_exc())
        # Return None for audio components on error, and the detailed log
        return None, None, None, "\n".join(logs)
141
+
142
# --- Gradio Interface ---
# Markdown text passed to the interface as its description.
DESCRIPTION = """
# SNAC Audio Codec Demo (24kHz)

This Space demonstrates the **SNAC (Scalable Neural Audio Codec)** model (`hubertsiuzdak/snac_24khz`).
1. Upload an audio file (wav, mp3, flac, etc.).
2. The audio will be automatically resampled to 24kHz if needed.
3. The 24kHz audio is encoded into discrete codes by SNAC.
4. These codes are then decoded back into audio by SNAC.
5. You can listen to the original, the 24kHz version (if resampled), and the final reconstructed audio.

**Note:** Processing happens on the server. Larger files will take longer. If the input is stereo, only the first channel is processed.
"""

# The single input hands process_audio a file path.
_upload_component = gr.Audio(type="filepath", label="Upload Audio File")

# One output per element of process_audio's 4-tuple, in the same order.
_result_components = [
    gr.Audio(label="Original Audio"),
    gr.Audio(label="Resampled Audio (24kHz Input to SNAC)"),
    gr.Audio(label="Reconstructed Audio (Output from SNAC)"),
    gr.Textbox(label="Log Output", lines=15),
]

# Add paths to example audio files here if you upload some to your Space repo,
# e.g. ["examples/example1.wav"] or ["examples/example2.mp3"].
_example_inputs = []

iface = gr.Interface(
    fn=process_audio,
    inputs=_upload_component,
    outputs=_result_components,
    title="SNAC Audio Codec Demo (24kHz)",
    description=DESCRIPTION,
    examples=_example_inputs,
    cache_examples=False,  # Disable caching if examples change or have issues
)
174
+
175
# Script entry point: start the web UI only when the model is available.
if __name__ == "__main__":
    if snac_model is not None:
        print("Launching Gradio Interface...")
        iface.launch()
    else:
        print("Cannot launch Gradio interface because SNAC model failed to load.")