Commit 86892b5 · 1 Parent(s): bba8dbb
Better error logging, fallbacks
app.py CHANGED
@@ -4,7 +4,7 @@ import torch
 import cv2
 from PIL import Image
 import numpy as np
-from transformers import
+from transformers import AutoProcessor, AutoModelForVision2Seq
 from transformers import AutoModelForCausalLM, AutoTokenizer, AutoConfig
 import time
 import nltk
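Besides the import swapped above, the new code in this commit uses several names that do not appear in this hunk (ParlerTTSForConditionalGeneration, AutoFeatureExtractor, set_seed, EnglishNumberNormalizer, punctuation, os, and gradio as gr). They are presumably imported in the unchanged top of app.py; the exact lines are not shown in the diff, so the following is only a plausible sketch:

    # Assumed imports, not part of this diff; the real app.py may differ.
    import os
    import gradio as gr
    from string import punctuation
    from parler_tts import ParlerTTSForConditionalGeneration
    from transformers import AutoFeatureExtractor, set_seed
    from transformers.models.whisper.english_normalizer import EnglishNumberNormalizer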
@@ -41,7 +41,8 @@ def analyze_image(image, vision_components):
         outputs = model.generate(**inputs, max_length=30)
         caption = processor.decode(outputs[0], skip_special_tokens=True)
         return caption if isinstance(caption, str) else ""
-    except Exception:
+    except Exception as e:
+        print(f"Error in analyze_image: {str(e)}")
         return "" # Return empty string on error
 
 def initialize_llm():
@@ -88,21 +89,37 @@ def generate_roast(caption, llm_components):
         response = tokenizer.decode(outputs[0], skip_special_tokens=True)
         response = response.split("[/INST]")[1].strip()
         return response if isinstance(response, str) else ""
-    except Exception:
+    except Exception as e:
+        print(f"Error in generate_roast: {str(e)}")
         return "" # Return empty string on error
 
 # Parler-TTS setup
-
-
-
-
-
-
-
-
+def setup_tts():
+    try:
+        parler_device = "cuda:0" if torch.cuda.is_available() else "cpu"
+        parler_repo_id = "parler-tts/parler-tts-mini-expresso"
+        parler_model = ParlerTTSForConditionalGeneration.from_pretrained(parler_repo_id).to(parler_device)
+        parler_tokenizer = AutoTokenizer.from_pretrained(parler_repo_id)
+        parler_feature_extractor = AutoFeatureExtractor.from_pretrained(parler_repo_id)
+        PARLER_SAMPLE_RATE = parler_feature_extractor.sampling_rate
+        PARLER_SEED = 42
+        parler_number_normalizer = EnglishNumberNormalizer()
+
+        return {
+            "model": parler_model,
+            "tokenizer": parler_tokenizer,
+            "feature_extractor": parler_feature_extractor,
+            "sample_rate": PARLER_SAMPLE_RATE,
+            "seed": PARLER_SEED,
+            "number_normalizer": parler_number_normalizer,
+            "device": parler_device
+        }
+    except Exception as e:
+        print(f"Error setting up TTS: {str(e)}")
+        return None
 
-def parler_preprocess(text):
-    text =
+def parler_preprocess(text, number_normalizer):
+    text = number_normalizer(text).strip()
     if text and text[-1] not in punctuation:
         text = f"{text}."
     abbreviations_pattern = r'\b[A-Z][A-Z\.]+\b'
@@ -115,79 +132,149 @@ def parler_preprocess(text):
         text = text.replace(abv, separate_abb(abv))
     return text
 
-def text_to_speech(text):
+def text_to_speech(text, tts_components):
+    if tts_components is None:
+        return (16000, np.zeros(1)) # Default sample rate if components failed to load
+
+    model = tts_components["model"]
+    tokenizer = tts_components["tokenizer"]
+    device = tts_components["device"]
+    sample_rate = tts_components["sample_rate"]
+    seed = tts_components["seed"]
+    number_normalizer = tts_components["number_normalizer"]
+
     description = ("Elisabeth speaks in a mature, strict, nagging, and slightly disappointed tone, "
-
-
-
+                   "with a hint of love and high expectations, at a moderate pace with high quality audio. "
+                   "She sounds like a stereotypical Asian mother who compares you to your cousins, "
+                   "questions your life choices, and threatens you with a slipper, but ultimately wants the best for you.")
     if not text or not isinstance(text, str):
-        return (
-
-
-
-
-
-
-
-
-
-
-    if not roast or not isinstance(roast, str):
-        audio = (PARLER_SAMPLE_RATE, np.zeros(1))
-    else:
-        audio = text_to_speech(roast)
-    return caption, roast, audio
+        return (sample_rate, np.zeros(1))
+    try:
+        inputs = tokenizer(description, return_tensors="pt").to(device)
+        prompt = tokenizer(parler_preprocess(text, number_normalizer), return_tensors="pt").to(device)
+        set_seed(seed)
+        generation = model.generate(input_ids=inputs.input_ids, prompt_input_ids=prompt.input_ids)
+        audio_arr = generation.cpu().numpy().squeeze()
+        return (sample_rate, audio_arr)
+    except Exception as e:
+        print(f"Error in text_to_speech: {str(e)}")
+        return (sample_rate, np.zeros(1))
 
-def
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-        )
-        final_caption = caption if isinstance(caption, str) else default_caption
-        final_roast = roast if isinstance(roast, str) else default_roast
-        final_audio = audio if isinstance(audio, tuple) and len(audio) == 2 and isinstance(audio[1], np.ndarray) else default_audio
-        return image, final_caption, final_roast, final_audio
-    return image, default_caption, default_roast, default_audio
-    video_feed.change(
-        process_webcam,
-        inputs=[video_feed],
-        outputs=[video_feed, analysis_output, roast_output, audio_output]
-    )
+def process_frame(image, vision_components, llm_components, tts_components):
+    try:
+        caption = analyze_image(image, vision_components)
+        roast = generate_roast(caption, llm_components)
+
+        default_sample_rate = 16000
+        if tts_components is not None:
+            default_sample_rate = tts_components["sample_rate"]
+
+        if not roast or not isinstance(roast, str):
+            audio = (default_sample_rate, np.zeros(1))
+        else:
+            audio = text_to_speech(roast, tts_components)
+        return caption, roast, audio
+    except Exception as e:
+        print(f"Error in process_frame: {str(e)}")
+        return "", "", (default_sample_rate, np.zeros(1))
 
 def create_app():
-
-
-
+    try:
+        # Initialize components before creating the app
+        vision_components = initialize_vision_model()
+        tts_components = setup_tts()
 
-    with
-
-
+        # Try to initialize LLM with Hugging Face token
+        hf_token = os.environ.get("HF_TOKEN")
+        llm_components = None
+        if hf_token:
+            try:
+                llm_components = initialize_llm()
+            except Exception as e:
+                print(f"Error initializing LLM: {str(e)}. Will use fallback.")
+
+        # Fallback if LLM initialization failed
+        if llm_components is None:
+            def fallback_generate_roast(caption, _):
+                return f"I see you {caption}. Why you not doctor yet? Your cousin studying at Harvard!"
 
-
-
-
-
+            llm_components = {"generate_roast": fallback_generate_roast}
+
+        # Set initial values and processing parameters
+        last_process_time = time.time() - 10
+        processing_interval = 5
 
-
-
+        with gr.Blocks(theme=gr.themes.Soft()) as app:
+            gr.Markdown("# AsianMOM: Asian Mother Observer & Mocker")
+            gr.Markdown("### Camera captures what you're doing and your Asian mom responds appropriately")
+
+            with gr.Row():
+                with gr.Column():
+                    video_feed = gr.Image(sources=["webcam"], streaming=True, label="Camera Feed")
+
+                with gr.Column():
+                    analysis_output = gr.Textbox(label="What AsianMOM Sees", lines=2)
+                    roast_output = gr.Textbox(label="AsianMOM's Thoughts", lines=4)
+                    audio_output = gr.Audio(label="AsianMOM Says", autoplay=True)
+
+            # Define processing function
+            def process_webcam(image):
+                nonlocal last_process_time
+                current_time = time.time()
+                default_caption = ""
+                default_roast = ""
+                default_sample_rate = 16000
+                if tts_components is not None:
+                    default_sample_rate = tts_components["sample_rate"]
+                default_audio = (default_sample_rate, np.zeros(1))
 
-
+                if current_time - last_process_time >= processing_interval and image is not None:
+                    last_process_time = current_time
+                    try:
+                        caption, roast, audio = process_frame(
+                            image,
+                            vision_components,
+                            llm_components,
+                            tts_components
+                        )
+                        final_caption = caption if isinstance(caption, str) else default_caption
+                        final_roast = roast if isinstance(roast, str) else default_roast
+                        final_audio = audio if isinstance(audio, tuple) and len(audio) == 2 and isinstance(audio[1], np.ndarray) else default_audio
+                        return image, final_caption, final_roast, final_audio
+                    except Exception as e:
+                        print(f"Error in process_webcam: {str(e)}")
+                return image, default_caption, default_roast, default_audio
+
+            # Setup the processing chain
+            video_feed.change(
+                process_webcam,
+                inputs=[video_feed],
+                outputs=[video_feed, analysis_output, roast_output, audio_output]
+            )
+
+        return app
+    except Exception as e:
+        print(f"Error creating app: {str(e)}")
+        # Create a fallback simple app that reports the error
+        with gr.Blocks() as fallback_app:
+            gr.Markdown("# AsianMOM: Error Initializing")
+            gr.Markdown(f"Error: {str(e)}")
+            gr.Markdown("Please check your environment setup and try again.")
+        return fallback_app
 
 if __name__ == "__main__":
-
-
-
-
+    try:
+        # Download required resources
+        os.system('python -m unidic download')
+        nltk.download('averaged_perceptron_tagger_eng')
+
+        # Create and launch app
+        app = create_app()
+        app.launch(share=True, debug=True)
+    except Exception as e:
+        print(f"Fatal error: {str(e)}")
+        # If all else fails, create a minimal app
+        with gr.Blocks() as minimal_app:
+            gr.Markdown("# AsianMOM: Fatal Error")
+            gr.Markdown(f"Fatal error: {str(e)}")
+            minimal_app.launch(share=True)
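For orientation, the new fallback contract can be exercised on its own: setup_tts() returns None when the Parler-TTS model cannot be loaded, and text_to_speech() then degrades to a one-sample silent clip at a default 16 kHz rate. A minimal sketch, not part of the commit, assuming app.py's functions are importable:

    # Hypothetical smoke test of the fallback path; names come from the diff above.
    from app import setup_tts, text_to_speech  # assumes app.py imports cleanly as a module

    tts_components = setup_tts()               # None if the Parler-TTS model fails to load
    rate, wave = text_to_speech("Why you not doctor yet?", tts_components)
    print(rate, wave.shape)                    # 16000 and a 1-sample silent array when components are None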
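One subtlety in the fallback wiring: when the LLM fails to initialize, llm_components becomes the plain dict {"generate_roast": fallback_generate_roast}, while process_frame() still calls generate_roast(caption, llm_components), whose body is outside this diff. Whether that function detects the dict is not visible here; one way a call site could dispatch between the two paths, sketched purely as an assumption:

    # Hypothetical dispatch helper; generate_roast's implementation is not shown in this diff.
    def roast_for(caption, llm_components):
        fallback = llm_components.get("generate_roast") if isinstance(llm_components, dict) else None
        if callable(fallback):
            return fallback(caption, llm_components)    # canned roast when the LLM failed to load
        return generate_roast(caption, llm_components)  # normal path through the loaded model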