Ravindu9904 committed
Commit 412e767 · verified · Parent: 29d60bc

Update app.py

Files changed (1)
  1. app.py +44 -51
app.py CHANGED
@@ -1,65 +1,58 @@
  import gradio as gr
  import pydicom
  import numpy as np
- import matplotlib.pyplot as plt
+ from PIL import Image
+ from transformers import AutoModelForVision2Seq, AutoProcessor
  import torch
- from monai.networks.nets import UNet
- from monai.transforms import Compose, ScaleIntensity, ToTensor

- # 1. Define a simple MONAI model (2D UNet)
- device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
- model = UNet(
-     dimensions=2,
-     in_channels=1,
-     out_channels=1,
-     channels=(16, 32, 64, 128, 256),
-     strides=(2, 2, 2, 2),
-     num_res_units=2,
- ).to(device)
- model.eval()  # Set model to evaluation mode
-
- # 2. Dummy weights (for demo only)
- # In real use, load pre-trained weights:
- # model.load_state_dict(torch.load("your_model.pth", map_location=device))
-
- def interpret_dicom(files):
+ # Load the model and processor
+ model_id = "MONAI/Llama3-VILA-M3-3B"
+ model = AutoModelForVision2Seq.from_pretrained(model_id, torch_dtype=torch.float16 if torch.cuda.is_available() else torch.float32)
+ processor = AutoProcessor.from_pretrained(model_id)
+
+ def dicom_to_image(files):
+     # Read all DICOM files and sort by InstanceNumber if available
      slices = []
      for file in files:
          ds = pydicom.dcmread(file.name)
-         slices.append(ds.pixel_array)
-     slices = np.array(slices)
-     mid_slice = slices[len(slices)//2]
-
-     # Preprocess for MONAI model
-     transform = Compose([ScaleIntensity(), ToTensor()])
-     input_tensor = transform(mid_slice.astype(np.float32))
-     input_tensor = input_tensor.unsqueeze(0).to(device)  # Add batch dimension
-
-     # 3. Run through MONAI model (dummy output for now)
-     with torch.no_grad():
-         output = model(input_tensor)
-     output_np = output.cpu().numpy()[0, 0]
-
-     # 4. Show original and model output side by side
-     fig, axs = plt.subplots(1, 2, figsize=(8, 4))
-     axs[0].imshow(mid_slice, cmap='gray')
-     axs[0].set_title('Original')
-     axs[0].axis('off')
-     axs[1].imshow(output_np, cmap='hot')
-     axs[1].set_title('Model Output')
-     axs[1].axis('off')
-     plt.tight_layout()
-     plt.savefig('output.png')
-     plt.close()
-
-     return 'output.png', "Interpretation: Model output shown (demo weights)."
+         slices.append((ds, ds.get('InstanceNumber', 0)))
+     slices.sort(key=lambda x: x[1])
+     images = [s[0].pixel_array for s in slices]
+     # If multiple slices, take the middle one
+     img = images[len(images)//2] if len(images) > 1 else images[0]
+     # Normalize and convert to 8-bit
+     img = img.astype(np.float32)
+     img = (img - img.min()) / (img.max() - img.min() + 1e-5) * 255
+     img = img.astype(np.uint8)
+     pil_img = Image.fromarray(img)
+     return pil_img
+
+ def interpret(files, prompt):
+     pil_img = dicom_to_image(files)
+     # Prepare input for the model
+     inputs = processor(images=pil_img, text=prompt, return_tensors="pt")
+     # Move to GPU if available
+     if torch.cuda.is_available():
+         model.to("cuda")
+         for k in inputs:
+             inputs[k] = inputs[k].to("cuda")
+     # Generate report
+     output = model.generate(**inputs, max_new_tokens=256)
+     report = processor.decode(output[0], skip_special_tokens=True)
+     return pil_img, report

  iface = gr.Interface(
-     fn=interpret_dicom,
-     inputs=gr.File(file_count="multiple", label="Upload DICOM files"),
-     outputs=[gr.Image(type="filepath", label="Result"), gr.Textbox(label="Interpretation")],
-     title="DICOM Radiology Interpreter with MONAI",
-     description="Upload your DICOM files (e.g., CT scan slices). The app will show the middle slice and a MONAI model output."
+     fn=interpret,
+     inputs=[
+         gr.File(file_count="multiple", label="Upload DICOM files"),
+         gr.Textbox(label="Prompt", value="Describe the findings in this image.")
+     ],
+     outputs=[
+         gr.Image(type="pil", label="Selected Image"),
+         gr.Textbox(label="AI-generated Report")
+     ],
+     title="Radiology Image Interpretation (VILA-M3-3B)",
+     description="Upload DICOM files (CT, MRI, or X-ray). The app will select the middle slice (for stacks), send it to MONAI/Llama3-VILA-M3-3B, and display the AI-generated report."
  )

  if __name__ == "__main__":
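
A quick way to sanity-check the new dicom_to_image helper without launching the Gradio UI (or downloading the 3B model weights beyond the module-level load) is to run it against one of pydicom's bundled sample files. A minimal sketch, not part of the commit, assuming dicom_to_image from above is in scope and that pydicom's test data is installed; types.SimpleNamespace stands in for Gradio's upload object, which only needs a .name attribute holding a file path:

# Local smoke test for dicom_to_image (not part of the commit).
# Assumes pydicom ships its sample data (pydicom.data.get_testdata_file).
from types import SimpleNamespace
from pydicom.data import get_testdata_file

path = get_testdata_file("CT_small.dcm")  # bundled 128x128 CT slice
upload = SimpleNamespace(name=path)       # Gradio uploads expose .name
img = dicom_to_image([upload])
print(img.size, img.mode)                 # expect (128, 128) and mode "L"

Because dicom_to_image only ever reads file.name, any object carrying that attribute works, which keeps the test independent of Gradio itself.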