import os
import threading
import traceback
import inspect

import gradio as gr
import torch
from transformers import AutoProcessor, Gemma3nForConditionalGeneration, TextIteratorStreamer
from PIL import Image
import spaces

# -----------------------------
# Config
# -----------------------------
MODEL_ID = "yasserrmd/GemmaECG-Vision"
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
DTYPE = torch.bfloat16 if DEVICE == "cuda" else torch.float32  # safe CPU dtype

# Generation defaults
GEN_KW = dict(
    max_new_tokens=768,
    do_sample=True,
    temperature=1.0,
    top_p=0.95,
    top_k=64,
    use_cache=True,
)

# Clinical prompt
CLINICAL_PROMPT = """You are a clinical assistant specialized in ECG interpretation.
Given an ECG image, generate a concise, structured, and medically accurate report.

Use this exact format:

Rhythm:
PR Interval:
QRS Duration:
Axis:
Bundle Branch Blocks:
Atrial Abnormalities:
Ventricular Hypertrophy:
Q Wave or QS Complexes:
T Wave Abnormalities:
ST Segment Changes:
Final Impression:

Guidance:
- Confirm sinus rhythm only if consistent P waves precede each QRS.
- Describe PACs only if early, ectopic P waves are visible.
- Do not diagnose myocardial infarction solely based on QS complexes unless accompanied by other signs (e.g., ST elevation, reciprocal changes, poor R wave progression).
- Only mention axis deviation if QRS axis is clearly rightward (RAD) or leftward (LAD).
- Use terms like "suggestive of" or "possible" for uncertain findings.
- Avoid repetition and keep the report clinically focused.
- Do not include external references or source citations.
- Do not diagnose left bundle branch block unless QRS duration is ≥120 ms with typical morphology in leads I, V5, V6.
- Mark T wave changes in inferior leads as “nonspecific” unless clear ST elevation or reciprocal depression is present.

Your goal is to provide a structured ECG summary useful for a cardiologist or internal medicine physician.
"""

# -----------------------------
# Load model & processor
# -----------------------------
model = Gemma3nForConditionalGeneration.from_pretrained(
    MODEL_ID, torch_dtype=DTYPE
).to(DEVICE).eval()

processor = AutoProcessor.from_pretrained(MODEL_ID)

# -----------------------------
# Streaming generator
# -----------------------------
@spaces.GPU
def analyze_ecg_stream(image: Image.Image):
    if image is None:
        yield "Please upload an ECG image."
        return

    messages = [
        {"role": "user", "content": [
            {"type": "text", "text": CLINICAL_PROMPT},
            {"type": "image"},
        ]}
    ]

    # Prepare inputs (try chat template; fallback to plain text+image)
    try:
        chat_text = processor.apply_chat_template(messages, add_generation_prompt=True)
        model_inputs = processor(text=chat_text, images=image, return_tensors="pt")
    except Exception:
        # Fallback when the template/image token count mismatches
        model_inputs = processor(text=CLINICAL_PROMPT, images=image, return_tensors="pt")
        yield "[Note] Using fallback prompt packing.\n"

    model_inputs = {k: v.to(DEVICE) for k, v in model_inputs.items()}

    # Streamer must use the tokenizer (not the processor)
    streamer = TextIteratorStreamer(
        processor.tokenizer, skip_prompt=True, skip_special_tokens=True
    )

    gen_error = []

    def _generate():
        try:
            model.generate(
                **model_inputs,
                streamer=streamer,
                **GEN_KW,
            )
        except Exception:
            # TextIteratorStreamer.put() expects token tensors, not strings,
            # so collect the traceback and surface it after the stream drains.
            gen_error.append(traceback.format_exc())
        finally:
            streamer.end()

    t = threading.Thread(target=_generate, daemon=True)
    t.start()

    buf = ""
    for piece in streamer:
        buf += piece
        yield buf

    if gen_error:
        buf += "\n\n[Generation Error]\n" + gen_error[0]
        yield buf


def reset():
    return None, ""


# -----------------------------
# UI
# -----------------------------
theme = gr.themes.Soft(primary_hue="indigo", neutral_hue="slate")

custom_css = """
#app { max-width: 1100px; margin: 0 auto; }
.header {
  display:flex; align-items:center; justify-content:space-between;
  padding: 16px 14px; border-radius: 14px;
  background: linear-gradient(135deg, #1f2937 0%, #111827 100%);
  color: #fff; box-shadow: 0 6px 20px rgba(0,0,0,0.25);
}
.brand { font-size: 18px; font-weight: 700; letter-spacing: 0.3px; }
.disclaimer {
  margin-top: 12px; padding: 12px 14px; border-radius: 12px;
  background: #fef2f2; color:#7f1d1d; border:1px solid #fecaca; font-weight:600;
}
.card {
  background: #ffffff; border: 1px solid #e5e7eb; border-radius: 14px;
  padding: 16px; box-shadow: 0 8px 18px rgba(17,24,39,0.06);
}
footer { font-size: 12px; color:#6b7280; margin-top: 8px; }
.gr-button { background-color:#1e3a8a !important; color:#fff !important; }
"""

with gr.Blocks(theme=theme, css=custom_css, elem_id="app") as demo:
    with gr.Row():
        gr.HTML("""
            <div>
              <div class="header">
                <div class="brand">🩺 ECG Interpretation Assistant</div>
                <div>Gemma-ECG-Vision</div>
              </div>
              <div class="disclaimer">
                ⚠️ Education & Research Only: This tool is not a medical device and must not be used for diagnosis or treatment. Always consult a licensed clinician for interpretation and clinical decisions.
              </div>
            </div>
        """)

    with gr.Row(equal_height=True):
        with gr.Column(scale=1):
            with gr.Group(elem_classes="card"):
                image_input = gr.Image(
                    type="pil",
                    label="Upload ECG Image",
                    height=360,
                    show_label=True,
                )
                with gr.Row():
                    submit_btn = gr.Button("Generate Report", variant="primary")
                    reset_btn = gr.Button("Reset")

        with gr.Column(scale=2):
            with gr.Group(elem_classes="card"):
                output_box = gr.Textbox(
                    label="Generated ECG Report (Streaming)",
                    lines=26,
                    show_copy_button=True,
                    autoscroll=True,
                    placeholder="The model's report will appear here…",
                )
                gr.Markdown(
                    "Tip: Clear, high-resolution ECGs with visible lead labels improve P wave and ST-segment assessment."
                )

    gr.HTML(f""" """)

    submit_btn.click(
        fn=analyze_ecg_stream,
        inputs=image_input,
        outputs=output_box,
        queue=True,
        api_name="analyze_ecg",
    )
    reset_btn.click(reset, outputs=[image_input, output_box])


def queue_with_compat(demo, max_size=32, limit=4):
    """Queue the app across Gradio versions whose queue() signatures differ."""
    params = inspect.signature(gr.Blocks.queue).parameters
    if "concurrency_count" in params:
        # Older Gradio 3.x / early 4.x
        return demo.queue(concurrency_count=limit, max_size=max_size)
    elif "default_concurrency_limit" in params:
        # Newer Gradio 4.x
        return demo.queue(default_concurrency_limit=limit, max_size=max_size)
    else:
        # Fallback – no knobs exposed
        return demo.queue()


queue_with_compat(demo, max_size=32, limit=4)
demo.launch(share=False, debug=True)