Ravindu9904's picture
Update app.py
412e767 verified
raw
history blame
2.18 kB
import gradio as gr
import pydicom
import numpy as np
from PIL import Image
from transformers import AutoModelForVision2Seq, AutoProcessor
import torch
# Model and processor are created once, at import time, and shared by every
# request handled by this module.
model_id = "MONAI/Llama3-VILA-M3-3B"

# Half precision is requested only when a CUDA device is present; otherwise
# the weights are loaded as float32.
model = AutoModelForVision2Seq.from_pretrained(
    model_id,
    torch_dtype=torch.float32 if not torch.cuda.is_available() else torch.float16,
)
processor = AutoProcessor.from_pretrained(model_id)
def _dicom_path(file):
    """Return something ``pydicom.dcmread`` can open for one uploaded entry.

    Path strings and ``os.PathLike`` objects are returned unchanged. Any other
    object is assumed to be an upload wrapper exposing its path as ``.name``;
    if it has no such attribute the object itself is returned.
    """
    import os  # local import so this block does not depend on file-level imports

    if isinstance(file, (str, os.PathLike)):
        # Checked first: pathlib.Path also has ``.name``, but that is only the
        # final path component, not the full path.
        return file
    return getattr(file, "name", file)


def _instance_number(ds):
    """Return the dataset's InstanceNumber as an int, or 0 if absent/unusable."""
    try:
        return int(ds.get("InstanceNumber", 0))
    except (TypeError, ValueError):
        # The tag can be present but empty (value None); treat it like a
        # missing tag so sorting never compares None with an int.
        return 0


def dicom_to_image(files):
    """Convert a set of uploaded DICOM files into a single 8-bit PIL image.

    The files are read, ordered by ``InstanceNumber`` (0 when missing), and
    the middle dataset of that ordering is selected. Its pixel data is
    min-max normalised to the 0-255 range and returned as a PIL image.

    Args:
        files: Non-empty sequence of uploaded entries. Each entry may be a
            path (``str`` / ``os.PathLike``) or an object whose ``.name``
            attribute is the path.

    Returns:
        A ``PIL.Image.Image`` built from the selected slice.

    Raises:
        ValueError: If ``files`` is ``None`` or empty.
    """
    if not files:
        raise ValueError("No DICOM files were provided.")

    datasets = [pydicom.dcmread(_dicom_path(entry)) for entry in files]
    # list.sort is stable, so datasets sharing an InstanceNumber keep their
    # upload order.
    datasets.sort(key=_instance_number)

    # Only the chosen dataset has its pixel data decoded; the remaining
    # datasets are needed solely for ordering.
    chosen = datasets[len(datasets) // 2]
    pixels = chosen.pixel_array.astype(np.float32)

    # Min-max normalise to 0-255. The small epsilon keeps the division
    # defined when the image is constant (max == min).
    lowest = pixels.min()
    highest = pixels.max()
    scaled = (pixels - lowest) / (highest - lowest + 1e-5) * 255

    return Image.fromarray(scaled.astype(np.uint8))
def interpret(files, prompt):
    """Generate a text report for the uploaded DICOM files.

    The uploads are reduced to one image via ``dicom_to_image``; that image
    and ``prompt`` are passed through the processor and the model's
    ``generate`` method (up to 256 new tokens), and the first generated
    sequence is decoded.

    Args:
        files: Uploaded DICOM file entries, as delivered by the file input.
        prompt: Instruction text sent to the model together with the image.

    Returns:
        A ``(pil_img, report)`` tuple: the image given to the model and the
        decoded text.

    Raises:
        gr.Error: If no files were uploaded.
    """
    if not files:
        # Raise the Gradio error type so the message is the one reported
        # for this request instead of an IndexError/TypeError from deeper in.
        raise gr.Error("Please upload at least one DICOM file.")

    pil_img = dicom_to_image(files)

    # Prepare input for the model
    inputs = processor(images=pil_img, text=prompt, return_tensors="pt")

    # Move to GPU if available
    if torch.cuda.is_available():
        model.to("cuda")
        for key, value in list(inputs.items()):
            if not torch.is_tensor(value):
                # Leave any non-tensor entries exactly as the processor
                # produced them.
                continue
            value = value.to("cuda")
            if value.is_floating_point():
                # The model is loaded in float16 on CUDA; floating-point
                # inputs must use the same dtype. Integer tensors (token
                # ids, masks) keep their dtype.
                value = value.to(model.dtype)
            inputs[key] = value

    # Generate report
    output = model.generate(**inputs, max_new_tokens=256)
    report = processor.decode(output[0], skip_special_tokens=True)
    return pil_img, report
# UI definition: a multi-file upload plus a prompt box feed `interpret`, whose
# two return values are shown as an image and a text report.
iface = gr.Interface(
    title="Radiology Image Interpretation (VILA-M3-3B)",
    description="Upload DICOM files (CT, MRI, or X-ray). The app will select the middle slice (for stacks), send it to MONAI/Llama3-VILA-M3-3B, and display the AI-generated report.",
    fn=interpret,
    # Order matches the parameters of `interpret`: (files, prompt).
    inputs=[
        gr.File(label="Upload DICOM files", file_count="multiple"),
        gr.Textbox(value="Describe the findings in this image.", label="Prompt"),
    ],
    # Order matches the tuple returned by `interpret`: (image, report).
    outputs=[
        gr.Image(label="Selected Image", type="pil"),
        gr.Textbox(label="AI-generated Report"),
    ],
)

# Start the app only when this file is executed as a script.
if __name__ == "__main__":
    iface.launch()