import os
import gradio as gr
import torch
from transformers import AutoProcessor, Gemma3nForConditionalGeneration, TextIteratorStreamer
from PIL import Image
import threading
import traceback
import spaces

# -----------------------------
# Config
# -----------------------------
MODEL_ID = "yasserrmd/GemmaECG-Vision"
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
DTYPE = torch.bfloat16 if DEVICE == "cuda" else torch.float32  # safe CPU dtype

# Generation defaults
GEN_KW = dict(
    max_new_tokens=768,
    do_sample=True,
    temperature=1.0,
    top_p=0.95,
    top_k=64,
    use_cache=True,
)

# Clinical prompt
CLINICAL_PROMPT = """You are a clinical assistant specialized in ECG interpretation.
Given an ECG image, generate a concise, structured, and medically accurate report.

Use this exact format:

Rhythm:
PR Interval:
QRS Duration:
Axis:
Bundle Branch Blocks:
Atrial Abnormalities:
Ventricular Hypertrophy:
Q Wave or QS Complexes:
T Wave Abnormalities:
ST Segment Changes:
Final Impression:

Guidance:
- Confirm sinus rhythm only if consistent P waves precede each QRS.
- Describe PACs only if early, ectopic P waves are visible.
- Do not diagnose myocardial infarction solely based on QS complexes unless accompanied by other signs (e.g., ST elevation, reciprocal changes, poor R wave progression).
- Only mention axis deviation if QRS axis is clearly rightward (RAD) or leftward (LAD).
- Use terms like "suggestive of" or "possible" for uncertain findings.
- Avoid repetition and keep the report clinically focused.
- Do not include external references or source citations.
- Do not diagnose left bundle branch block unless QRS duration is ≥120 ms with typical morphology in leads I, V5, V6.
- Mark T wave changes in inferior leads as “nonspecific” unless clear ST elevation or reciprocal depression is present.

Your goal is to provide a structured ECG summary useful for a cardiologist or internal medicine physician.
"""

# -----------------------------
# Load model & processor
# -----------------------------
model = Gemma3nForConditionalGeneration.from_pretrained(
    MODEL_ID,
    torch_dtype=DTYPE,
).to(DEVICE).eval()

processor = AutoProcessor.from_pretrained(MODEL_ID)

# -----------------------------
# Inference (streaming) function
# -----------------------------
@spaces.GPU
def analyze_ecg_stream(image: Image.Image):
    """
    Streams model output into the Gradio textbox.
    Yields incremental text chunks.
    """
    if image is None:
        yield "Please upload an ECG image."
        return

    # Build a multimodal chat-style message; rely on the model's chat template
    # to inject image tokens.
    messages = [
        {
            "role": "user",
            "content": [
                {"type": "text", "text": CLINICAL_PROMPT},
                {"type": "image"},
            ],
        }
    ]

    note = ""  # optional prefix surfaced when the fallback packing is used
    try:
        # Try with chat template first (recommended for chat-tuned models)
        chat_text = processor.apply_chat_template(messages, add_generation_prompt=True)
        model_inputs = processor(
            text=chat_text,
            images=image,
            return_tensors="pt",
        )
        # Move tensors to the target device; cast floating-point tensors
        # (e.g., pixel_values) to the model dtype to avoid dtype mismatches.
        model_inputs = {
            k: (v.to(DEVICE, dtype=DTYPE) if v.is_floating_point() else v.to(DEVICE))
            for k, v in model_inputs.items()
        }
    except Exception as e:
        # If the template or image-token count fails, fall back to a simple text+image pack.
        # This handles errors like:
        # "Number of images does not match number of special image tokens..."
        fallback_note = (
            "\n[Note] Falling back to a simpler prompt packing due to template/image token mismatch."
        )
        try:
            model_inputs = processor(
                text=CLINICAL_PROMPT,
                images=image,
                return_tensors="pt",
            )
            model_inputs = {
                k: (v.to(DEVICE, dtype=DTYPE) if v.is_floating_point() else v.to(DEVICE))
                for k, v in model_inputs.items()
            }
            # Surface a short note at the start of the stream so the user knows why;
            # keep it as a prefix so later streamed updates do not overwrite it.
            note = fallback_note + "\n"
            yield note
        except Exception as inner_e:
            err_msg = f"Input preparation failed:\n{repr(e)}\n{repr(inner_e)}"
            yield err_msg
            return

    # Prepare streamer
    streamer = TextIteratorStreamer(
        processor.tokenizer,
        skip_prompt=True,
        skip_special_tokens=True,
    )

    # Launch generation in a background thread
    generation_errors = []

    def _generate():
        try:
            model.generate(
                **model_inputs,
                streamer=streamer,
                **GEN_KW,
            )
        except Exception:
            # Record the traceback so the user sees it (useful during debugging).
            # Note: TextIteratorStreamer.put() expects token ids, not strings,
            # so the error text is appended to the stream after the loop instead.
            generation_errors.append(traceback.format_exc())
        finally:
            streamer.end()

    thread = threading.Thread(target=_generate)
    thread.start()

    # Collect incremental tokens and yield the growing buffer
    buffer = note
    for token in streamer:
        buffer += token
        # Stream into Gradio textbox
        yield buffer

    thread.join()
    if generation_errors:
        buffer += "\n\n[Generation Error]\n" + generation_errors[0]
        yield buffer


def reset():
    return None, ""
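
# ----------------------------------------------------------------------------
# Optional local sanity check (illustrative sketch, not executed by the app):
# analyze_ecg_stream() is a generator that yields the cumulative report text,
# so the last yielded value is the full report. The file name below is a
# placeholder/assumption.
#
#   img = Image.open("sample_ecg.png").convert("RGB")
#   report = ""
#   for partial in analyze_ecg_stream(img):
#       report = partial
#   print(report)
# ----------------------------------------------------------------------------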

# -----------------------------
# Gradio UI
# -----------------------------
with gr.Blocks(css="""
.disclaimer {
    padding: 12px 16px;
    border: 1px solid #b91c1c;
    background: #fef2f2;
    color: #7f1d1d;
    border-radius: 8px;
    font-weight: 600;
}
.footer-note {
    font-size: 12px;
    color: #374151;
}
.gr-button {
    background-color: #1e3a8a;
    color: #ffffff;
}
""") as demo:
    gr.Markdown("## 🩺 ECG Interpretation Assistant — Gemma-ECG-Vision")

    gr.HTML("""
    <div class="disclaimer">
        ⚠️ Important Medical Disclaimer: This tool is for education and research purposes only.
        It is not a medical device and must not be used for diagnosis or treatment.
        Always consult a licensed clinician for interpretation and clinical decisions.
    </div>
    """)
""") with gr.Row(): image_input = gr.Image(type="pil", label="Upload ECG Image", height=320) output_box = gr.Textbox( label="Generated ECG Report (Streaming)", lines=24, show_copy_button=True, autoscroll=True, ) with gr.Row(): with gr.Column(): submit_btn = gr.Button("Generate Report", variant="primary") with gr.Column(): reset_btn = gr.Button("Reset") # Wire actions: analyze_ecg_stream yields partial strings for streaming submit_btn.click( fn=analyze_ecg_stream, inputs=image_input, outputs=output_box, queue=True, api_name="analyze_ecg", ) reset_btn.click(fn=reset, outputs=[image_input, output_box]) gr.Markdown( """ """.format(model_id=MODEL_ID, device=DEVICE) ) # Enable queuing for proper streaming under concurrency #demo.queue(concurrency_count=2, max_size=16) # In hosted notebooks, you can set share=True if needed demo.launch(share=False, debug=True)