"""Gradio app: upload a DICOM series, render one slice, and ask
MONAI/Llama3-VILA-M3-3B for a free-text report."""

import os

import gradio as gr
import numpy as np
import pydicom
import torch
from PIL import Image
from pydicom.errors import InvalidDicomError
from transformers import AutoModelForVision2Seq, AutoProcessor

# Load the model and processor once, at import time.
model_id = "MONAI/Llama3-VILA-M3-3B"
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
# Half precision only on GPU; float32 on CPU (same choice as before).
DTYPE = torch.float16 if DEVICE == "cuda" else torch.float32
# NOTE(review): VILA-family checkpoints may require trust_remote_code=True or a
# dedicated loader rather than AutoModelForVision2Seq -- confirm this loads.
model = AutoModelForVision2Seq.from_pretrained(model_id, torch_dtype=DTYPE)
# Move the weights once here instead of calling model.to() on every request.
model.to(DEVICE)
processor = AutoProcessor.from_pretrained(model_id)


def _file_path(file):
    """Return a filesystem path for one Gradio upload.

    Older Gradio versions pass tempfile wrappers exposing ``.name``; newer ones
    pass path strings. Strings/PathLike are checked first because
    ``pathlib.Path.name`` is only the basename, not the full path.
    """
    if isinstance(file, (str, os.PathLike)):
        return file
    return file.name


def _instance_number(ds):
    """Return the dataset's InstanceNumber as an int, or 0 if missing, blank or invalid."""
    try:
        return int(ds.get("InstanceNumber"))
    except (TypeError, ValueError):
        # Missing tag -> None, blank tag -> '', multi-valued -> not int-convertible.
        return 0


def _select_frame(ds):
    """Decode pixel data and return a single frame (2-D, or H x W x C for colour).

    For multi-frame objects (NumberOfFrames > 1) pydicom's ``pixel_array`` has a
    leading frame axis; the middle frame is chosen, mirroring the series logic.
    """
    pixels = ds.pixel_array
    try:
        n_frames = int(ds.get("NumberOfFrames", 1) or 1)
    except (TypeError, ValueError):
        n_frames = 1
    if n_frames > 1:
        pixels = pixels[len(pixels) // 2]
    return pixels


def _to_uint8(pixels, photometric):
    """Min-max scale ``pixels`` to 8-bit and undo MONOCHROME1 inversion."""
    img = pixels.astype(np.float32)
    # The epsilon keeps a constant image from dividing by zero.
    img = (img - img.min()) / (img.max() - img.min() + 1e-5) * 255
    img = img.astype(np.uint8)
    if photometric == "MONOCHROME1":
        # In MONOCHROME1 the minimum sample value is meant to be shown as white,
        # whereas PIL "L" images show 0 as black.
        img = 255 - img
    return img


def dicom_to_image(files):
    """Render the middle slice of an uploaded DICOM series as a PIL image.

    Slices are ordered by InstanceNumber (stable, so ties keep upload order)
    and the middle one is rendered with min-max scaling to 8 bits.

    Args:
        files: Iterable of Gradio uploads (path strings or objects with ``.name``).

    Returns:
        PIL.Image.Image: the rendered slice ("L" for grayscale, RGB for colour).

    Raises:
        ValueError: if no files were given or a file is not valid DICOM.
    """
    if not files:
        raise ValueError("No DICOM files were uploaded.")

    # Read headers only to order the series; decoding every slice's pixels
    # just to keep one wastes time and memory on large stacks.
    headers = []
    for file in files:
        path = _file_path(file)
        try:
            ds = pydicom.dcmread(path, stop_before_pixels=True)
        except InvalidDicomError as err:
            name = os.path.basename(os.fspath(path))
            raise ValueError(f"Not a valid DICOM file: {name}") from err
        headers.append((_instance_number(ds), path))
    headers.sort(key=lambda item: item[0])

    # Middle slice of the stack (the only slice when there is just one).
    _, middle_path = headers[len(headers) // 2]
    ds = pydicom.dcmread(middle_path)
    img = _to_uint8(_select_frame(ds), ds.get("PhotometricInterpretation", ""))
    return Image.fromarray(img)


def _move_inputs(inputs):
    """Move processor outputs to DEVICE, casting floating tensors to the model dtype."""
    moved = {}
    for key, value in inputs.items():
        if torch.is_tensor(value):
            value = value.to(DEVICE)
            if torch.is_floating_point(value):
                # float32 pixel values fed to float16 weights raise a dtype mismatch.
                value = value.to(DTYPE)
        moved[key] = value
    return moved


def _strip_prompt(generated, input_ids):
    """Drop an echoed prompt prefix from generated token ids, if present."""
    if input_ids is None:
        return generated
    prompt_ids = input_ids[0]
    n = prompt_ids.shape[0]
    if generated.shape[0] >= n and torch.equal(generated[:n], prompt_ids):
        return generated[n:]
    return generated


def interpret(files, prompt):
    """Gradio callback: render the uploaded series and generate a report.

    Args:
        files: Uploaded DICOM files from the ``gr.File`` component.
        prompt: Instruction text passed to the model with the image.

    Returns:
        tuple: (rendered PIL image, generated report text).

    Raises:
        gr.Error: for missing or invalid uploads, shown to the user in the UI.
    """
    try:
        pil_img = dicom_to_image(files)
    except ValueError as err:
        raise gr.Error(str(err)) from err

    # Grayscale DICOM renders as mode "L"; hand the processor a 3-channel image.
    inputs = processor(images=pil_img.convert("RGB"), text=prompt, return_tensors="pt")
    inputs = _move_inputs(inputs)

    # Generate report
    output = model.generate(**inputs, max_new_tokens=256)
    # Decoder-only generation returns prompt + continuation; keep only new text.
    report_ids = _strip_prompt(output[0], inputs.get("input_ids"))
    report = processor.decode(report_ids, skip_special_tokens=True)
    return pil_img, report


iface = gr.Interface(
    fn=interpret,
    inputs=[
        gr.File(file_count="multiple", label="Upload DICOM files"),
        gr.Textbox(label="Prompt", value="Describe the findings in this image."),
    ],
    outputs=[
        gr.Image(type="pil", label="Selected Image"),
        gr.Textbox(label="AI-generated Report"),
    ],
    title="Radiology Image Interpretation (VILA-M3-3B)",
    description="Upload DICOM files (CT, MRI, or X-ray). The app will select the middle slice (for stacks), send it to MONAI/Llama3-VILA-M3-3B, and display the AI-generated report.",
)

if __name__ == "__main__":
    iface.launch()