baliddeki commited on
Commit
e65d0e5
·
1 Parent(s): d4d94d2

fix with endpoints

Browse files
Files changed (1) hide show
  1. app.py +21 -27
app.py CHANGED
@@ -14,9 +14,11 @@ import gc
14
 
15
  device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
16
 
17
- # Load tokenizer and models
18
  HF_TOKEN = os.getenv("HF_TOKEN")
 
19
 
 
20
  tokenizer = AutoTokenizer.from_pretrained("baliddeki/phronesis-ml", token=HF_TOKEN)
21
  video_model = models.video.r3d_18(weights="KINETICS400_V1")
22
  video_model.fc = torch.nn.Linear(video_model.fc.in_features, 512)
@@ -26,27 +28,20 @@ projector = ImageToTextProjector(512, report_generator.config.d_model)
26
 
27
  num_classes = 4
28
  class_names = ["acute", "normal", "chronic", "lacunar"]
29
- combined_model = CombinedModel(
30
- video_model, report_generator, num_classes, projector, tokenizer
31
- )
32
- model_file = hf_hub_download(
33
- "baliddeki/phronesis-ml", "pytorch_model.bin", token=HF_TOKEN
34
- )
35
  state_dict = torch.load(model_file, map_location=device)
36
  combined_model.load_state_dict(state_dict)
 
37
  combined_model.eval()
38
 
39
  # Image transforms
40
- image_transform = transforms.Compose(
41
- [
42
- transforms.Resize((112, 112)),
43
- transforms.ToTensor(),
44
- transforms.Normalize(
45
- mean=[0.43216, 0.394666, 0.37645], std=[0.22803, 0.22145, 0.216989]
46
- ),
47
- ]
48
- )
49
-
50
 
51
  def dicom_to_image(file_bytes):
52
  dicom_file = pydicom.dcmread(io.BytesIO(file_bytes))
@@ -55,23 +50,21 @@ def dicom_to_image(file_bytes):
55
  pixel_array = pixel_array.astype(np.uint8)
56
  return Image.fromarray(pixel_array).convert("RGB")
57
 
58
-
59
  def predict(images):
60
  if not images:
61
  return "No image uploaded.", ""
62
 
63
- # Convert images
64
  processed_imgs = []
65
  for img in images:
66
  filename = img.name.lower()
67
  if filename.endswith((".dcm", ".ima")):
68
- dicom_img = dicom_to_image(img.read())
 
69
  processed_imgs.append(dicom_img)
70
  else:
71
  pil_img = Image.open(img).convert("RGB")
72
  processed_imgs.append(pil_img)
73
 
74
- # Sample or pad
75
  n_frames = 16
76
  if len(processed_imgs) >= n_frames:
77
  images_sampled = [
@@ -79,9 +72,7 @@ def predict(images):
79
  for i in np.linspace(0, len(processed_imgs) - 1, n_frames, dtype=int)
80
  ]
81
  else:
82
- images_sampled = processed_imgs + [processed_imgs[-1]] * (
83
- n_frames - len(processed_imgs)
84
- )
85
 
86
  tensor_imgs = [image_transform(i) for i in images_sampled]
87
  input_tensor = torch.stack(tensor_imgs).permute(1, 0, 2, 3).unsqueeze(0).to(device)
@@ -97,8 +88,7 @@ def predict(images):
97
 
98
  return class_name, report[0] if report else "No report generated."
99
 
100
-
101
- # Gradio interface
102
  demo = gr.Interface(
103
  fn=predict,
104
  inputs=gr.File(
@@ -106,9 +96,13 @@ demo = gr.Interface(
106
  file_count="multiple",
107
  label="Upload CT Scan Images",
108
  ),
109
- outputs=[gr.Textbox(label="Predicted Class"), gr.Textbox(label="Generated Report")],
 
 
 
110
  title="Phronesis Medical Report Generator",
111
  description="Upload CT scan DICOM or image files. Returns diagnosis classification and generated report.",
112
  )
113
 
 
114
  demo.launch()
 
14
 
15
  device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
16
 
17
+ # Environment setup
18
  HF_TOKEN = os.getenv("HF_TOKEN")
19
+ os.environ["HF_HOME"] = "/tmp/huggingface_cache"
20
 
21
+ # Model initialization
22
  tokenizer = AutoTokenizer.from_pretrained("baliddeki/phronesis-ml", token=HF_TOKEN)
23
  video_model = models.video.r3d_18(weights="KINETICS400_V1")
24
  video_model.fc = torch.nn.Linear(video_model.fc.in_features, 512)
 
28
 
29
  num_classes = 4
30
  class_names = ["acute", "normal", "chronic", "lacunar"]
31
+ combined_model = CombinedModel(video_model, report_generator, num_classes, projector, tokenizer)
32
+
33
+ model_file = hf_hub_download("baliddeki/phronesis-ml", "pytorch_model.bin", token=HF_TOKEN)
 
 
 
34
  state_dict = torch.load(model_file, map_location=device)
35
  combined_model.load_state_dict(state_dict)
36
+ combined_model.to(device)
37
  combined_model.eval()
38
 
39
  # Image transforms
40
+ image_transform = transforms.Compose([
41
+ transforms.Resize((112, 112)),
42
+ transforms.ToTensor(),
43
+ transforms.Normalize(mean=[0.43216, 0.394666, 0.37645], std=[0.22803, 0.22145, 0.216989]),
44
+ ])
 
 
 
 
 
45
 
46
  def dicom_to_image(file_bytes):
47
  dicom_file = pydicom.dcmread(io.BytesIO(file_bytes))
 
50
  pixel_array = pixel_array.astype(np.uint8)
51
  return Image.fromarray(pixel_array).convert("RGB")
52
 
 
53
  def predict(images):
54
  if not images:
55
  return "No image uploaded.", ""
56
 
 
57
  processed_imgs = []
58
  for img in images:
59
  filename = img.name.lower()
60
  if filename.endswith((".dcm", ".ima")):
61
+ file_bytes = img.read()
62
+ dicom_img = dicom_to_image(file_bytes)
63
  processed_imgs.append(dicom_img)
64
  else:
65
  pil_img = Image.open(img).convert("RGB")
66
  processed_imgs.append(pil_img)
67
 
 
68
  n_frames = 16
69
  if len(processed_imgs) >= n_frames:
70
  images_sampled = [
 
72
  for i in np.linspace(0, len(processed_imgs) - 1, n_frames, dtype=int)
73
  ]
74
  else:
75
+ images_sampled = processed_imgs + [processed_imgs[-1]] * (n_frames - len(processed_imgs))
 
 
76
 
77
  tensor_imgs = [image_transform(i) for i in images_sampled]
78
  input_tensor = torch.stack(tensor_imgs).permute(1, 0, 2, 3).unsqueeze(0).to(device)
 
88
 
89
  return class_name, report[0] if report else "No report generated."
90
 
91
+ # Define Gradio Interface explicitly
 
92
  demo = gr.Interface(
93
  fn=predict,
94
  inputs=gr.File(
 
96
  file_count="multiple",
97
  label="Upload CT Scan Images",
98
  ),
99
+ outputs=[
100
+ gr.Textbox(label="Predicted Class"),
101
+ gr.Textbox(label="Generated Report")
102
+ ],
103
  title="Phronesis Medical Report Generator",
104
  description="Upload CT scan DICOM or image files. Returns diagnosis classification and generated report.",
105
  )
106
 
107
+ # Launch with explicit api_name for REST API compatibility
108
  demo.launch()