Gapeleon committed
Commit f0169cd · verified · 1 Parent(s): dd3ef50

Update app.py

revert accidental change

Files changed (1)
  1. app.py +160 -136
app.py CHANGED
@@ -1,154 +1,178 @@
 import torch
 import torchaudio
-from transformers import AutoProcessor, AutoModelForSpeechSeq2Seq, BitsAndBytesConfig
-import gradio as gr
-import os
-import time
 import numpy as np

-# Load model and processor (runs once on startup)
-model_name = "ibm-granite/granite-speech-3.2-8b"
-device = "cuda" if torch.cuda.is_available() else "cpu"
-print(f"Using device: {device}")
-
-print("Loading processor...")
-speech_granite_processor = AutoProcessor.from_pretrained(
-    model_name, trust_remote_code=True)
-tokenizer = speech_granite_processor.tokenizer
-
-print("Loading model with 4-bit quantization...")
-quantization_config = BitsAndBytesConfig(
-    load_in_4bit=True,
-    bnb_4bit_compute_dtype=torch.float16,
-    bnb_4bit_quant_type="nf4",
-    bnb_4bit_use_double_quant=True
-)

-speech_granite = AutoModelForSpeechSeq2Seq.from_pretrained(
-    model_name,
-    quantization_config=quantization_config,
-    device_map="auto",
-    trust_remote_code=True
-)
-print("Model loaded successfully")

-def transcribe_audio(audio_input):
-    """Process audio input and return transcription"""
     start_time = time.time()
-
-    logs = [f"Audio input received: {type(audio_input)}"]
-
-    if audio_input is None:
-        return "Error: No audio provided.", 0.0
-
     try:
-        # Handle different audio input formats
-        if isinstance(audio_input, tuple) and len(audio_input) == 2:
-            # Microphone input: (sample_rate, numpy_array)
-            logs.append("Processing microphone input")
-            sr, wav_np = audio_input
-            wav = torch.from_numpy(wav_np).float().unsqueeze(0)
         else:
-            # File input: filepath string
-            logs.append(f"Processing file input: {audio_input}")
-            wav, sr = torchaudio.load(audio_input)
-            logs.append(f"Loaded audio file with sample rate {sr}Hz and shape {wav.shape}")
-
-        # Convert to mono if stereo
-        if wav.shape[0] > 1:
-            wav = torch.mean(wav, dim=0, keepdim=True)
-            logs.append("Converted stereo to mono")
-
-        # Resample to 16kHz if needed
-        if sr != 16000:
-            resampler = torchaudio.transforms.Resample(orig_freq=sr, new_freq=16000)
-            wav = resampler(wav)
-            sr = 16000
-            logs.append(f"Resampled to {sr}Hz")
-
-        logs.append(f"Final audio: sample rate {sr}Hz, shape {wav.shape}, min: {wav.min().item()}, max: {wav.max().item()}")
-
-        # Create text prompt
-        chat = [
-            {
-                "role": "system",
-                "content": "Knowledge Cutoff Date: April 2024.\nToday's Date: December 19, 2024.\nYou are Granite, developed by IBM. You are a helpful AI assistant",
-            },
-            {
-                "role": "user",
-                "content": "<|audio|>can you transcribe the speech into a written format?",
-            }
-        ]
-
-        text = tokenizer.apply_chat_template(
-            chat, tokenize=False, add_generation_prompt=True
-        )
-
-        # Compute audio embeddings
-        logs.append("Preparing model inputs")
-        model_inputs = speech_granite_processor(
-            text=text,
-            audio=wav.numpy().squeeze(),  # Convert to numpy and squeeze
-            sampling_rate=sr,
-            return_tensors="pt",
-        ).to(device)
-
-        # Generate transcription
-        logs.append("Generating transcription")
-        model_outputs = speech_granite.generate(
-            **model_inputs,
-            max_new_tokens=1000,
-            num_beams=4,
-            do_sample=False,
-            min_length=1,
-            top_p=1.0,
-            repetition_penalty=3.0,
-            length_penalty=1.0,
-            temperature=1.0,
-            bos_token_id=tokenizer.bos_token_id,
-            eos_token_id=tokenizer.eos_token_id,
-            pad_token_id=tokenizer.pad_token_id,
-        )
-
-        # Extract the generated text (skipping input tokens)
-        logs.append("Processing output")
-        num_input_tokens = model_inputs["input_ids"].shape[-1]
-        new_tokens = torch.unsqueeze(model_outputs[0, num_input_tokens:], dim=0)
-
-        output_text = tokenizer.batch_decode(
-            new_tokens, add_special_tokens=False, skip_special_tokens=True
-        )
-
-        transcription = output_text[0].strip().upper()
-        logs.append(f"Transcription complete: {transcription[:50]}...")
-
     except Exception as e:
-        import traceback
-        error_trace = traceback.format_exc()
-        print(error_trace)
-        print("\n".join(logs))
-        return f"Error: {str(e)}\n\nLogs:\n" + "\n".join(logs), 0.0
-
-    processing_time = round(time.time() - start_time, 2)
-    return transcription, processing_time
-
-# Create Gradio interface
-title = "IBM Granite Speech-to-Text (8B Quantized)"
-description = """
-Transcribe speech using IBM's Granite Speech 3.2 8B model (loaded in 4-bit).
-Upload an audio file or use your microphone to record speech.
 """

 iface = gr.Interface(
-    fn=transcribe_audio,
-    inputs=gr.Audio(sources=["upload", "microphone"], type="filepath"),
     outputs=[
-        gr.Textbox(label="Transcription", lines=5),
-        gr.Number(label="Processing Time (seconds)")
     ],
-    title=title,
-    description=description,
 )

 if __name__ == "__main__":
-    iface.launch()

+import gradio as gr
 import torch
 import torchaudio
+import torchaudio.transforms as T
 import numpy as np
+import traceback
+import io
+import time

+# Attempt to import SNAC (should work if requirements.txt is correct)
+try:
+    from snac import SNAC
+    print("SNAC module imported successfully.")
+except ImportError as e:
+    print(f"Error importing SNAC: {e}")
+    # Raise a more informative error if SNAC isn't installed
+    raise ImportError("Could not import SNAC. Make sure 'snac' is listed in requirements.txt and installed correctly.") from e

+# --- Configuration ---
+TARGET_SR = 24000  # SNAC operates at 24kHz
+DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
+print(f"Using device: {DEVICE}")

+# --- Load Model (Load once globally) ---
+snac_model = None
+try:
+    print("Loading SNAC model...")
     start_time = time.time()
+    snac_model = SNAC.from_pretrained("hubertsiuzdak/snac_24khz")
+    snac_model = snac_model.to(DEVICE)
+    snac_model.eval()  # Set model to evaluation mode
+    end_time = time.time()
+    print(f"SNAC model loaded successfully to {DEVICE}. Time taken: {end_time - start_time:.2f} seconds.")
+except Exception as e:
+    print(f"FATAL: Error loading SNAC model: {e}")
+    print(traceback.format_exc())
+    # If the model fails to load, the app can't function.
+    # Gradio will likely show an error, but we print specifics here.
+
+# --- Main Processing Function ---
+def process_audio(audio_filepath):
+    """
+    Loads, resamples, encodes, decodes audio using SNAC, and returns results.
+    """
+    if snac_model is None:
+        return None, None, None, "Error: SNAC model could not be loaded. Cannot process audio."
+
+    if audio_filepath is None:
+        return None, None, None, "Please upload an audio file."
+
+    logs = ["--- Starting Audio Processing ---"]
     try:
+        # 1. Load Audio
+        logs.append(f"Loading audio file: {audio_filepath}")
+        load_start = time.time()
+        original_waveform, original_sr = torchaudio.load(audio_filepath)
+        load_end = time.time()
+        logs.append(f"Audio loaded. Original SR: {original_sr} Hz, Shape: {original_waveform.shape}, Time: {load_end - load_start:.2f}s")
+
+        # Ensure float32
+        original_waveform = original_waveform.to(dtype=torch.float32)
+
+        # Handle multi-channel audio: Use the first channel
+        if original_waveform.shape[0] > 1:
+            logs.append(f"Warning: Input audio has {original_waveform.shape[0]} channels. Using only the first channel.")
+            original_waveform = original_waveform[0:1, :]  # Keep channel dim for consistency initially
+
+        # --- Prepare Original for Playback ---
+        # Gradio Audio component expects (sample_rate, numpy_array)
+        # Ensure numpy array is 1D or 2D [channels, samples]
+        original_audio_playback = (original_sr, original_waveform.squeeze().numpy())  # Squeeze removes channel dim if 1
+        logs.append("Prepared original audio for playback.")
+
+        # 2. Resample if necessary
+        resample_start = time.time()
+        if original_sr != TARGET_SR:
+            logs.append(f"Resampling waveform from {original_sr} Hz to {TARGET_SR} Hz...")
+            resampler = T.Resample(orig_freq=original_sr, new_freq=TARGET_SR).to(original_waveform.device)  # Resampler on same device
+            waveform_to_encode = resampler(original_waveform)
+            logs.append(f"Resampling complete. New Shape: {waveform_to_encode.shape}")
         else:
+            logs.append("Waveform is already at the target sample rate (24kHz).")
+            waveform_to_encode = original_waveform
+        resample_end = time.time()
+        logs.append(f"Resampling time: {resample_end - resample_start:.2f}s")
+
+        # --- Prepare Resampled for Playback ---
+        resampled_audio_playback = (TARGET_SR, waveform_to_encode.squeeze().numpy())
+        logs.append("Prepared resampled audio for playback.")
+
+        # 3. Prepare for SNAC Encoding (add batch dim, move to device)
+        # Input should be [Batch, Channel, Time] = [1, 1, Time]
+        # waveform_to_encode should currently be [1, Time] after channel selection/resampling
+        waveform_batch = waveform_to_encode.unsqueeze(0).to(DEVICE)  # Add batch dimension -> [1, 1, Time]
+        logs.append(f"Waveform prepared for encoding. Shape: {waveform_batch.shape}, Device: {DEVICE}")
+
+        # 4. Encode Audio using SNAC
+        logs.append("Encoding audio with snac_model.encode()...")
+        encode_start = time.time()
+        with torch.inference_mode():
+            codes = snac_model.encode(waveform_batch)
+        encode_end = time.time()
+
+        if not codes or not all(isinstance(c, torch.Tensor) for c in codes):
+            log_msg = f"Encoding failed: Expected list of Tensors, but got: {type(codes)}"
+            if isinstance(codes, list):
+                log_msg += f" with types {[type(c) for c in codes]}"
+            logs.append(log_msg)
+            raise ValueError(log_msg)
+
+        logs.append(f"Encoding complete. Received {len(codes)} code layers. Time: {encode_end - encode_start:.2f}s")
+        for i, layer_codes in enumerate(codes):
+            logs.append(f"  Layer {i+1} codes shape: {layer_codes.shape}, Device: {layer_codes.device}")
+
+        # 5. Decode the Codes using SNAC
+        logs.append("Decoding the generated codes with snac_model.decode()...")
+        decode_start = time.time()
+        with torch.inference_mode():
+            reconstructed_waveform = snac_model.decode(codes)  # codes are already on DEVICE
+        decode_end = time.time()
+        logs.append(f"Decoding complete. Reconstructed waveform shape: {reconstructed_waveform.shape}, Device: {reconstructed_waveform.device}. Time: {decode_end - decode_start:.2f}s")
+
+        # 6. Prepare Reconstructed Audio for Playback
+        # Output is [Batch, 1, Time]. Move to CPU, remove Batch/Channel, convert to NumPy.
+        reconstructed_audio_np = reconstructed_waveform.cpu().squeeze().numpy()  # Squeeze removes Batch and Channel dims
+        logs.append(f"Reconstructed audio prepared for playback. Shape: {reconstructed_audio_np.shape}")
+        reconstructed_audio_playback = (TARGET_SR, reconstructed_audio_np)
+
+        logs.append("\n--- Audio Processing Completed Successfully ---")
+        return original_audio_playback, resampled_audio_playback, reconstructed_audio_playback, "\n".join(logs)
+
     except Exception as e:
+        logs.append("\n--- An Error Occurred ---")
+        logs.append(f"Error Type: {type(e).__name__}")
+        logs.append(f"Error Details: {e}")
+        logs.append("\n--- Traceback ---")
+        logs.append(traceback.format_exc())
+        # Return None for audio components on error, and the detailed log
+        return None, None, None, "\n".join(logs)
+
+# --- Gradio Interface ---
+DESCRIPTION = """
+This Space demonstrates the **SNAC (Scalable Neural Audio Codec)** model (`hubertsiuzdak/snac_24khz`).
+1. Upload an audio file (wav, mp3, flac, etc.).
+2. The audio will be automatically resampled to 24kHz if needed.
+3. The 24kHz audio is encoded into discrete codes by SNAC.
+4. These codes are then decoded back into audio by SNAC.
+5. You can listen to the original, the 24kHz version (if resampled), and the final reconstructed audio.
+
+**Note:** Processing happens on the server. Larger files will take longer. If the input is stereo, only the first channel is processed.
 """

 iface = gr.Interface(
+    fn=process_audio,
+    inputs=gr.Audio(type="filepath", label="Upload Audio File"),
     outputs=[
+        gr.Audio(label="Original Audio"),
+        gr.Audio(label="Resampled Audio (24kHz Input to SNAC)"),
+        gr.Audio(label="Reconstructed Audio (Output from SNAC)"),
+        gr.Textbox(label="Log Output", lines=15)
+    ],
+    title="SNAC Audio Codec Demo (24kHz)",
+    description=DESCRIPTION,
+    examples=[
+        # Add paths to example audio files if you upload some to your Space repo
+        # ["examples/example1.wav"],
+        # ["examples/example2.mp3"],
     ],
+    cache_examples=False  # Disable caching if examples change or have issues
 )

 if __name__ == "__main__":
+    if snac_model is None:
+        print("Cannot launch Gradio interface because SNAC model failed to load.")
+    else:
+        print("Launching Gradio Interface...")
+        iface.launch()
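
For reference, a minimal standalone sketch of the encode/decode roundtrip the new app performs (steps 3-5 of the description above), assuming the `snac` package and the `hubertsiuzdak/snac_24khz` checkpoint used in the diff; the random tensor is just a placeholder for one second of 24 kHz audio:

import torch
from snac import SNAC

# Load the same checkpoint the Space uses and switch to eval mode.
model = SNAC.from_pretrained("hubertsiuzdak/snac_24khz").eval()

# Placeholder input with shape [batch, channel, time] = [1, 1, 24000].
audio = torch.randn(1, 1, 24000)

with torch.inference_mode():
    codes = model.encode(audio)      # list of code tensors, one per SNAC layer
    audio_hat = model.decode(codes)  # reconstructed waveform, shape [1, 1, time]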