baliddeki committed on
Commit
a54164e
·
1 Parent(s): e65d0e5

fix with endpoints 2

Browse files
Files changed (1) hide show
  1. app.py +39 -30
app.py CHANGED
@@ -14,11 +14,10 @@ import gc
14
 
15
  device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
16
 
17
- # Environment setup
18
  HF_TOKEN = os.getenv("HF_TOKEN")
19
  os.environ["HF_HOME"] = "/tmp/huggingface_cache"
20
 
21
- # Model initialization
22
  tokenizer = AutoTokenizer.from_pretrained("baliddeki/phronesis-ml", token=HF_TOKEN)
23
  video_model = models.video.r3d_18(weights="KINETICS400_V1")
24
  video_model.fc = torch.nn.Linear(video_model.fc.in_features, 512)
@@ -28,20 +27,29 @@ projector = ImageToTextProjector(512, report_generator.config.d_model)
28
 
29
  num_classes = 4
30
  class_names = ["acute", "normal", "chronic", "lacunar"]
31
- combined_model = CombinedModel(video_model, report_generator, num_classes, projector, tokenizer)
 
 
32
 
33
- model_file = hf_hub_download("baliddeki/phronesis-ml", "pytorch_model.bin", token=HF_TOKEN)
 
 
34
  state_dict = torch.load(model_file, map_location=device)
35
  combined_model.load_state_dict(state_dict)
36
  combined_model.to(device)
37
  combined_model.eval()
38
 
39
  # Image transforms
40
- image_transform = transforms.Compose([
41
- transforms.Resize((112, 112)),
42
- transforms.ToTensor(),
43
- transforms.Normalize(mean=[0.43216, 0.394666, 0.37645], std=[0.22803, 0.22145, 0.216989]),
44
- ])
 
 
 
 
 
45
 
46
  def dicom_to_image(file_bytes):
47
  dicom_file = pydicom.dcmread(io.BytesIO(file_bytes))
@@ -50,19 +58,20 @@ def dicom_to_image(file_bytes):
50
  pixel_array = pixel_array.astype(np.uint8)
51
  return Image.fromarray(pixel_array).convert("RGB")
52
 
53
- def predict(images):
54
- if not images:
 
55
  return "No image uploaded.", ""
56
 
57
  processed_imgs = []
58
- for img in images:
59
- filename = img.name.lower()
60
  if filename.endswith((".dcm", ".ima")):
61
- file_bytes = img.read()
62
  dicom_img = dicom_to_image(file_bytes)
63
  processed_imgs.append(dicom_img)
64
  else:
65
- pil_img = Image.open(img).convert("RGB")
66
  processed_imgs.append(pil_img)
67
 
68
  n_frames = 16
@@ -72,7 +81,9 @@ def predict(images):
72
  for i in np.linspace(0, len(processed_imgs) - 1, n_frames, dtype=int)
73
  ]
74
  else:
75
- images_sampled = processed_imgs + [processed_imgs[-1]] * (n_frames - len(processed_imgs))
 
 
76
 
77
  tensor_imgs = [image_transform(i) for i in images_sampled]
78
  input_tensor = torch.stack(tensor_imgs).permute(1, 0, 2, 3).unsqueeze(0).to(device)
@@ -88,21 +99,19 @@ def predict(images):
88
 
89
  return class_name, report[0] if report else "No report generated."
90
 
91
- # Define Gradio Interface explicitly
92
- demo = gr.Interface(
93
- fn=predict,
94
- inputs=gr.File(
95
- file_types=[".dcm", ".jpg", ".jpeg", ".png"],
96
  file_count="multiple",
 
97
  label="Upload CT Scan Images",
98
- ),
99
- outputs=[
100
- gr.Textbox(label="Predicted Class"),
101
- gr.Textbox(label="Generated Report")
102
- ],
103
- title="Phronesis Medical Report Generator",
104
- description="Upload CT scan DICOM or image files. Returns diagnosis classification and generated report.",
105
- )
106
 
107
- # Launch with explicit api_name for REST API compatibility
108
  demo.launch()
 
14
 
15
  device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
16
 
 
17
  HF_TOKEN = os.getenv("HF_TOKEN")
18
  os.environ["HF_HOME"] = "/tmp/huggingface_cache"
19
 
20
+ # Load tokenizer and models
21
  tokenizer = AutoTokenizer.from_pretrained("baliddeki/phronesis-ml", token=HF_TOKEN)
22
  video_model = models.video.r3d_18(weights="KINETICS400_V1")
23
  video_model.fc = torch.nn.Linear(video_model.fc.in_features, 512)
 
27
 
28
  num_classes = 4
29
  class_names = ["acute", "normal", "chronic", "lacunar"]
30
+ combined_model = CombinedModel(
31
+ video_model, report_generator, num_classes, projector, tokenizer
32
+ )
33
 
34
+ model_file = hf_hub_download(
35
+ "baliddeki/phronesis-ml", "pytorch_model.bin", token=HF_TOKEN
36
+ )
37
  state_dict = torch.load(model_file, map_location=device)
38
  combined_model.load_state_dict(state_dict)
39
  combined_model.to(device)
40
  combined_model.eval()
41
 
42
  # Image transforms
43
+ image_transform = transforms.Compose(
44
+ [
45
+ transforms.Resize((112, 112)),
46
+ transforms.ToTensor(),
47
+ transforms.Normalize(
48
+ mean=[0.43216, 0.394666, 0.37645], std=[0.22803, 0.22145, 0.216989]
49
+ ),
50
+ ]
51
+ )
52
+
53
 
54
  def dicom_to_image(file_bytes):
55
  dicom_file = pydicom.dcmread(io.BytesIO(file_bytes))
 
58
  pixel_array = pixel_array.astype(np.uint8)
59
  return Image.fromarray(pixel_array).convert("RGB")
60
 
61
+
62
+ def predict(files):
63
+ if not files:
64
  return "No image uploaded.", ""
65
 
66
  processed_imgs = []
67
+ for file in files:
68
+ filename = file.name.lower()
69
  if filename.endswith((".dcm", ".ima")):
70
+ file_bytes = file.read()
71
  dicom_img = dicom_to_image(file_bytes)
72
  processed_imgs.append(dicom_img)
73
  else:
74
+ pil_img = Image.open(file).convert("RGB")
75
  processed_imgs.append(pil_img)
76
 
77
  n_frames = 16
 
81
  for i in np.linspace(0, len(processed_imgs) - 1, n_frames, dtype=int)
82
  ]
83
  else:
84
+ images_sampled = processed_imgs + [processed_imgs[-1]] * (
85
+ n_frames - len(processed_imgs)
86
+ )
87
 
88
  tensor_imgs = [image_transform(i) for i in images_sampled]
89
  input_tensor = torch.stack(tensor_imgs).permute(1, 0, 2, 3).unsqueeze(0).to(device)
 
99
 
100
  return class_name, report[0] if report else "No report generated."
101
 
102
+
103
+ # Gradio Blocks setup (explicitly)
104
+ with gr.Blocks() as demo:
105
+ gr.Markdown("## 🩺 Phronesis Medical Report Generator")
106
+ file_input = gr.File(
107
  file_count="multiple",
108
+ file_types=[".dcm", ".jpg", ".jpeg", ".png"],
109
  label="Upload CT Scan Images",
110
+ )
111
+ btn = gr.Button("Generate Report")
112
+ class_output = gr.Textbox(label="Predicted Class")
113
+ report_output = gr.Textbox(label="Generated Report")
114
+
115
+ btn.click(fn=predict, inputs=file_input, outputs=[class_output, report_output])
 
 
116
 
 
117
  demo.launch()