Gapeleon committed
Commit dd3ef50 · verified · 1 Parent(s): d54f19d

Update app.py

Files changed (1)
  1. app.py +136 -160
app.py CHANGED
@@ -1,178 +1,154 @@
- import gradio as gr
  import torch
  import torchaudio
- import torchaudio.transforms as T
- import numpy as np
- import traceback
- import io
  import time

- # Attempt to import SNAC (should work if requirements.txt is correct)
- try:
-     from snac import SNAC
-     print("SNAC module imported successfully.")
- except ImportError as e:
-     print(f"Error importing SNAC: {e}")
-     # Raise a more informative error if SNAC isn't installed
-     raise ImportError("Could not import SNAC. Make sure 'snac' is listed in requirements.txt and installed correctly.") from e

- # --- Configuration ---
- TARGET_SR = 24000  # SNAC operates at 24kHz
- DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
- print(f"Using device: {DEVICE}")

- # --- Load Model (Load once globally) ---
- snac_model = None
- try:
-     print("Loading SNAC model...")
      start_time = time.time()
-     snac_model = SNAC.from_pretrained("hubertsiuzdak/snac_24khz")
-     snac_model = snac_model.to(DEVICE)
-     snac_model.eval()  # Set model to evaluation mode
-     end_time = time.time()
-     print(f"SNAC model loaded successfully to {DEVICE}. Time taken: {end_time - start_time:.2f} seconds.")
- except Exception as e:
-     print(f"FATAL: Error loading SNAC model: {e}")
-     print(traceback.format_exc())
-     # If the model fails to load, the app can't function.
-     # Gradio will likely show an error, but we print specifics here.
-
- # --- Main Processing Function ---
- def process_audio(audio_filepath):
-     """
-     Loads, resamples, encodes, decodes audio using SNAC, and returns results.
-     """
-     if snac_model is None:
-         return None, None, None, "Error: SNAC model could not be loaded. Cannot process audio."
-
-     if audio_filepath is None:
-         return None, None, None, "Please upload an audio file."
-
-     logs = ["--- Starting Audio Processing ---"]
      try:
-         # 1. Load Audio
-         logs.append(f"Loading audio file: {audio_filepath}")
-         load_start = time.time()
-         original_waveform, original_sr = torchaudio.load(audio_filepath)
-         load_end = time.time()
-         logs.append(f"Audio loaded. Original SR: {original_sr} Hz, Shape: {original_waveform.shape}, Time: {load_end - load_start:.2f}s")
-
-         # Ensure float32
-         original_waveform = original_waveform.to(dtype=torch.float32)
-
-         # Handle multi-channel audio: Use the first channel
-         if original_waveform.shape[0] > 1:
-             logs.append(f"Warning: Input audio has {original_waveform.shape[0]} channels. Using only the first channel.")
-             original_waveform = original_waveform[0:1, :]  # Keep channel dim for consistency initially
-
-         # --- Prepare Original for Playback ---
-         # Gradio Audio component expects (sample_rate, numpy_array)
-         # Ensure numpy array is 1D or 2D [channels, samples]
-         original_audio_playback = (original_sr, original_waveform.squeeze().numpy())  # Squeeze removes channel dim if 1
-         logs.append("Prepared original audio for playback.")
-
-         # 2. Resample if necessary
-         resample_start = time.time()
-         if original_sr != TARGET_SR:
-             logs.append(f"Resampling waveform from {original_sr} Hz to {TARGET_SR} Hz...")
-             resampler = T.Resample(orig_freq=original_sr, new_freq=TARGET_SR).to(original_waveform.device)  # Resampler on same device
-             waveform_to_encode = resampler(original_waveform)
-             logs.append(f"Resampling complete. New Shape: {waveform_to_encode.shape}")
          else:
-             logs.append("Waveform is already at the target sample rate (24kHz).")
-             waveform_to_encode = original_waveform
-         resample_end = time.time()
-         logs.append(f"Resampling time: {resample_end - resample_start:.2f}s")
-
-         # --- Prepare Resampled for Playback ---
-         resampled_audio_playback = (TARGET_SR, waveform_to_encode.squeeze().numpy())
-         logs.append("Prepared resampled audio for playback.")
-
-         # 3. Prepare for SNAC Encoding (add batch dim, move to device)
-         # Input should be [Batch, Channel, Time] = [1, 1, Time]
-         # waveform_to_encode should currently be [1, Time] after channel selection/resampling
-         waveform_batch = waveform_to_encode.unsqueeze(0).to(DEVICE)  # Add batch dimension -> [1, 1, Time]
-         logs.append(f"Waveform prepared for encoding. Shape: {waveform_batch.shape}, Device: {DEVICE}")
-
-         # 4. Encode Audio using SNAC
-         logs.append("Encoding audio with snac_model.encode()...")
-         encode_start = time.time()
-         with torch.inference_mode():
-             codes = snac_model.encode(waveform_batch)
-         encode_end = time.time()
-
-         if not codes or not all(isinstance(c, torch.Tensor) for c in codes):
-             log_msg = f"Encoding failed: Expected list of Tensors, but got: {type(codes)}"
-             if isinstance(codes, list):
-                 log_msg += f" with types {[type(c) for c in codes]}"
-             logs.append(log_msg)
-             raise ValueError(log_msg)
-
-         logs.append(f"Encoding complete. Received {len(codes)} code layers. Time: {encode_end - encode_start:.2f}s")
-         for i, layer_codes in enumerate(codes):
-             logs.append(f" Layer {i+1} codes shape: {layer_codes.shape}, Device: {layer_codes.device}")
-
-         # 5. Decode the Codes using SNAC
-         logs.append("Decoding the generated codes with snac_model.decode()...")
-         decode_start = time.time()
-         with torch.inference_mode():
-             reconstructed_waveform = snac_model.decode(codes)  # codes are already on DEVICE
-         decode_end = time.time()
-         logs.append(f"Decoding complete. Reconstructed waveform shape: {reconstructed_waveform.shape}, Device: {reconstructed_waveform.device}. Time: {decode_end - decode_start:.2f}s")
-
-         # 6. Prepare Reconstructed Audio for Playback
-         # Output is [Batch, 1, Time]. Move to CPU, remove Batch/Channel, convert to NumPy.
-         reconstructed_audio_np = reconstructed_waveform.cpu().squeeze().numpy()  # Squeeze removes Batch and Channel dims
-         logs.append(f"Reconstructed audio prepared for playback. Shape: {reconstructed_audio_np.shape}")
-         reconstructed_audio_playback = (TARGET_SR, reconstructed_audio_np)
-
-         logs.append("\n--- Audio Processing Completed Successfully ---")
-         return original_audio_playback, resampled_audio_playback, reconstructed_audio_playback, "\n".join(logs)
-
      except Exception as e:
-         logs.append("\n--- An Error Occurred ---")
-         logs.append(f"Error Type: {type(e).__name__}")
-         logs.append(f"Error Details: {e}")
-         logs.append("\n--- Traceback ---")
-         logs.append(traceback.format_exc())
-         # Return None for audio components on error, and the detailed log
-         return None, None, None, "\n".join(logs)
-
- # --- Gradio Interface ---
- DESCRIPTION = """
- This Space demonstrates the **SNAC (Scalable Neural Audio Codec)** model (`hubertsiuzdak/snac_24khz`).
- 1. Upload an audio file (wav, mp3, flac, etc.).
- 2. The audio will be automatically resampled to 24kHz if needed.
- 3. The 24kHz audio is encoded into discrete codes by SNAC.
- 4. These codes are then decoded back into audio by SNAC.
- 5. You can listen to the original, the 24kHz version (if resampled), and the final reconstructed audio.
-
- **Note:** Processing happens on the server. Larger files will take longer. If the input is stereo, only the first channel is processed.
  """

  iface = gr.Interface(
-     fn=process_audio,
-     inputs=gr.Audio(type="filepath", label="Upload Audio File"),
      outputs=[
-         gr.Audio(label="Original Audio"),
-         gr.Audio(label="Resampled Audio (24kHz Input to SNAC)"),
-         gr.Audio(label="Reconstructed Audio (Output from SNAC)"),
-         gr.Textbox(label="Log Output", lines=15)
-     ],
-     title="SNAC Audio Codec Demo (24kHz)",
-     description=DESCRIPTION,
-     examples=[
-         # Add paths to example audio files if you upload some to your Space repo
-         # ["examples/example1.wav"],
-         # ["examples/example2.mp3"],
      ],
-     cache_examples=False  # Disable caching if examples change or have issues
  )

  if __name__ == "__main__":
-     if snac_model is None:
-         print("Cannot launch Gradio interface because SNAC model failed to load.")
-     else:
-         print("Launching Gradio Interface...")
-         iface.launch()
  import torch
  import torchaudio
+ from transformers import AutoProcessor, AutoModelForSpeechSeq2Seq, BitsAndBytesConfig
+ import gradio as gr
+ import os
  import time
+ import numpy as np

+ # Load model and processor (runs once on startup)
+ model_name = "ibm-granite/granite-speech-3.2-8b"
+ device = "cuda" if torch.cuda.is_available() else "cpu"
+ print(f"Using device: {device}")
+
+ print("Loading processor...")
+ speech_granite_processor = AutoProcessor.from_pretrained(
+     model_name, trust_remote_code=True)
+ tokenizer = speech_granite_processor.tokenizer
+
+ print("Loading model with 4-bit quantization...")
+ quantization_config = BitsAndBytesConfig(
+     load_in_4bit=True,
+     bnb_4bit_compute_dtype=torch.float16,
+     bnb_4bit_quant_type="nf4",
+     bnb_4bit_use_double_quant=True
+ )

+ speech_granite = AutoModelForSpeechSeq2Seq.from_pretrained(
+     model_name,
+     quantization_config=quantization_config,
+     device_map="auto",
+     trust_remote_code=True
+ )
+ print("Model loaded successfully")

+ def transcribe_audio(audio_input):
+     """Process audio input and return transcription"""
      start_time = time.time()
+
+     logs = [f"Audio input received: {type(audio_input)}"]
+
+     if audio_input is None:
+         return "Error: No audio provided.", 0.0
+
      try:
+         # Handle different audio input formats
+         if isinstance(audio_input, tuple) and len(audio_input) == 2:
+             # Microphone input: (sample_rate, numpy_array)
+             logs.append("Processing microphone input")
+             sr, wav_np = audio_input
+             wav = torch.from_numpy(wav_np).float().unsqueeze(0)
          else:
+             # File input: filepath string
+             logs.append(f"Processing file input: {audio_input}")
+             wav, sr = torchaudio.load(audio_input)
+             logs.append(f"Loaded audio file with sample rate {sr}Hz and shape {wav.shape}")
+
+         # Convert to mono if stereo
+         if wav.shape[0] > 1:
+             wav = torch.mean(wav, dim=0, keepdim=True)
+             logs.append("Converted stereo to mono")
+
+         # Resample to 16kHz if needed
+         if sr != 16000:
+             resampler = torchaudio.transforms.Resample(orig_freq=sr, new_freq=16000)
+             wav = resampler(wav)
+             sr = 16000
+             logs.append(f"Resampled to {sr}Hz")
+
+         logs.append(f"Final audio: sample rate {sr}Hz, shape {wav.shape}, min: {wav.min().item()}, max: {wav.max().item()}")
+
+         # Create text prompt
+         chat = [
+             {
+                 "role": "system",
+                 "content": "Knowledge Cutoff Date: April 2024.\nToday's Date: December 19, 2024.\nYou are Granite, developed by IBM. You are a helpful AI assistant",
+             },
+             {
+                 "role": "user",
+                 "content": "<|audio|>can you transcribe the speech into a written format?",
+             }
+         ]
+
+         text = tokenizer.apply_chat_template(
+             chat, tokenize=False, add_generation_prompt=True
+         )
+
+         # Compute audio embeddings
+         logs.append("Preparing model inputs")
+         model_inputs = speech_granite_processor(
+             text=text,
+             audio=wav.numpy().squeeze(),  # Convert to numpy and squeeze
+             sampling_rate=sr,
+             return_tensors="pt",
+         ).to(device)
+
+         # Generate transcription
+         logs.append("Generating transcription")
+         model_outputs = speech_granite.generate(
+             **model_inputs,
+             max_new_tokens=1000,
+             num_beams=4,
+             do_sample=False,
+             min_length=1,
+             top_p=1.0,
+             repetition_penalty=3.0,
+             length_penalty=1.0,
+             temperature=1.0,
+             bos_token_id=tokenizer.bos_token_id,
+             eos_token_id=tokenizer.eos_token_id,
+             pad_token_id=tokenizer.pad_token_id,
+         )
+
+         # Extract the generated text (skipping input tokens)
+         logs.append("Processing output")
+         num_input_tokens = model_inputs["input_ids"].shape[-1]
+         new_tokens = torch.unsqueeze(model_outputs[0, num_input_tokens:], dim=0)
+
+         output_text = tokenizer.batch_decode(
+             new_tokens, add_special_tokens=False, skip_special_tokens=True
+         )
+
+         transcription = output_text[0].strip().upper()
+         logs.append(f"Transcription complete: {transcription[:50]}...")
+
      except Exception as e:
+         import traceback
+         error_trace = traceback.format_exc()
+         print(error_trace)
+         print("\n".join(logs))
+         return f"Error: {str(e)}\n\nLogs:\n" + "\n".join(logs), 0.0
+
+     processing_time = round(time.time() - start_time, 2)
+     return transcription, processing_time
+
+ # Create Gradio interface
+ title = "IBM Granite Speech-to-Text (8B Quantized)"
+ description = """
+ Transcribe speech using IBM's Granite Speech 3.2 8B model (loaded in 4-bit).
+ Upload an audio file or use your microphone to record speech.
  """

  iface = gr.Interface(
+     fn=transcribe_audio,
+     inputs=gr.Audio(sources=["upload", "microphone"], type="filepath"),
      outputs=[
+         gr.Textbox(label="Transcription", lines=5),
+         gr.Number(label="Processing Time (seconds)")
      ],
+     title=title,
+     description=description,
  )

  if __name__ == "__main__":
+     iface.launch()
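
For local sanity checking, a minimal sketch of how the new transcribe_audio function could be exercised outside Gradio (assumptions not in the commit: app.py is importable from the working directory, the Granite model loads successfully at import time, and "sample.wav" is a placeholder path to a short speech recording you supply yourself):

    # Hypothetical smoke test; "sample.wav" is a placeholder, not a file in this repo.
    # Importing app triggers the one-time processor/model load at module scope.
    from app import transcribe_audio

    # With type="filepath", Gradio passes a path string for both upload and
    # microphone sources, so the tuple-handling branch acts as a fallback.
    text, seconds = transcribe_audio("sample.wav")
    print(f"Transcribed in {seconds}s: {text}")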