baliddeki committed
Commit 8c6ba75 · 1 Parent(s): a54164e

fix with endpoints 2

Files changed (1)
  1. app.py +30 -41
app.py CHANGED
@@ -17,7 +17,7 @@ device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
 HF_TOKEN = os.getenv("HF_TOKEN")
 os.environ["HF_HOME"] = "/tmp/huggingface_cache"
 
-# Load tokenizer and models
+# Model loading
 tokenizer = AutoTokenizer.from_pretrained("baliddeki/phronesis-ml", token=HF_TOKEN)
 video_model = models.video.r3d_18(weights="KINETICS400_V1")
 video_model.fc = torch.nn.Linear(video_model.fc.in_features, 512)
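The two context lines above keep the Kinetics-400 pretrained r3d_18 but replace its classifier head with a 512-d projection, so the backbone emits feature vectors for the report generator instead of action-class logits. A quick standalone shape check of that wiring (illustrative, mirrors the code above):

```python
import torch
from torchvision import models

video_model = models.video.r3d_18(weights="KINETICS400_V1")
video_model.fc = torch.nn.Linear(video_model.fc.in_features, 512)

# torchvision video ResNets take (batch, channels, frames, height, width)
dummy = torch.randn(1, 3, 16, 112, 112)
features = video_model(dummy)
print(features.shape)  # torch.Size([1, 512])
```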
@@ -27,29 +27,19 @@ projector = ImageToTextProjector(512, report_generator.config.d_model)
 
 num_classes = 4
 class_names = ["acute", "normal", "chronic", "lacunar"]
-combined_model = CombinedModel(
-    video_model, report_generator, num_classes, projector, tokenizer
-)
+combined_model = CombinedModel(video_model, report_generator, num_classes, projector, tokenizer)
 
-model_file = hf_hub_download(
-    "baliddeki/phronesis-ml", "pytorch_model.bin", token=HF_TOKEN
-)
+model_file = hf_hub_download("baliddeki/phronesis-ml", "pytorch_model.bin", token=HF_TOKEN)
 state_dict = torch.load(model_file, map_location=device)
 combined_model.load_state_dict(state_dict)
 combined_model.to(device)
 combined_model.eval()
 
-# Image transforms
-image_transform = transforms.Compose(
-    [
-        transforms.Resize((112, 112)),
-        transforms.ToTensor(),
-        transforms.Normalize(
-            mean=[0.43216, 0.394666, 0.37645], std=[0.22803, 0.22145, 0.216989]
-        ),
-    ]
-)
-
+image_transform = transforms.Compose([
+    transforms.Resize((112, 112)),
+    transforms.ToTensor(),
+    transforms.Normalize(mean=[0.43216, 0.394666, 0.37645], std=[0.22803, 0.22145, 0.216989]),
+])
 
 def dicom_to_image(file_bytes):
     dicom_file = pydicom.dcmread(io.BytesIO(file_bytes))
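The diff elides the body of dicom_to_image between dcmread and the uint8 cast shown in the next hunk; those hidden lines presumably rescale the raw pixel values. As a labeled assumption (not recovered from this commit), a typical min-max scaling that makes the cast safe would be:

```python
import numpy as np

# Hypothetical stand-in for the elided scaling step: map the raw DICOM
# pixel range onto 0-255 so the .astype(np.uint8) below cannot wrap around.
def scale_to_uint8_range(pixel_array):
    pixel_array = pixel_array.astype(np.float32)
    pixel_array -= pixel_array.min()
    peak = pixel_array.max()
    if peak > 0:
        pixel_array *= 255.0 / peak
    return pixel_array
```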
@@ -58,32 +48,28 @@ def dicom_to_image(file_bytes):
     pixel_array = pixel_array.astype(np.uint8)
     return Image.fromarray(pixel_array).convert("RGB")
 
-
 def predict(files):
     if not files:
-        return "No image uploaded.", ""
+        return "No images uploaded.", ""
 
     processed_imgs = []
-    for file in files:
-        filename = file.name.lower()
+    for file_obj in files:
+        filename = file_obj.name.lower()
         if filename.endswith((".dcm", ".ima")):
-            file_bytes = file.read()
-            dicom_img = dicom_to_image(file_bytes)
-            processed_imgs.append(dicom_img)
+            file_bytes = file_obj.read()
+            img = dicom_to_image(file_bytes)
         else:
-            pil_img = Image.open(file).convert("RGB")
-            processed_imgs.append(pil_img)
+            img = Image.open(file_obj).convert("RGB")
+        processed_imgs.append(img)
 
     n_frames = 16
     if len(processed_imgs) >= n_frames:
         images_sampled = [
             processed_imgs[i]
-            for i in np.linspace(0, len(processed_imgs) - 1, n_frames, dtype=int)
+            for i in np.linspace(0, len(processed_imgs)-1, n_frames, dtype=int)
         ]
     else:
-        images_sampled = processed_imgs + [processed_imgs[-1]] * (
-            n_frames - len(processed_imgs)
-        )
+        images_sampled = processed_imgs + [processed_imgs[-1]] * (n_frames - len(processed_imgs))
 
     tensor_imgs = [image_transform(i) for i in images_sampled]
     input_tensor = torch.stack(tensor_imgs).permute(1, 0, 2, 3).unsqueeze(0).to(device)
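predict always feeds the model exactly 16 frames: evenly spaced indices via np.linspace when enough slices are supplied, last-frame padding otherwise. The permute on the final context line then reorders the stack from (T, C, H, W) to (C, T, H, W) and unsqueeze adds the batch dimension, giving the (B, C, T, H, W) layout r3d_18 expects. A self-contained sketch of the sampling rule (names are illustrative, not from the commit):

```python
import numpy as np

def sample_or_pad(frames, n_frames=16):
    # Even subsampling when there are too many frames...
    if len(frames) >= n_frames:
        idx = np.linspace(0, len(frames) - 1, n_frames, dtype=int)
        return [frames[i] for i in idx]
    # ...otherwise repeat the final frame until n_frames is reached.
    return list(frames) + [frames[-1]] * (n_frames - len(frames))

print(sample_or_pad(list(range(40)))[:4])  # [0, 2, 5, 7], spread across 0..39
print(sample_or_pad(list(range(5))))       # [0, 1, 2, 3, 4, 4, 4, ...]
```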
@@ -99,19 +85,22 @@ def predict(files):
 
     return class_name, report[0] if report else "No report generated."
 
-
-# Gradio Blocks setup (explicitly)
+# Gradio Blocks (100% reliable approach)
 with gr.Blocks() as demo:
-    gr.Markdown("## 🩺 Phronesis Medical Report Generator")
-    file_input = gr.File(
-        file_count="multiple",
-        file_types=[".dcm", ".jpg", ".jpeg", ".png"],
-        label="Upload CT Scan Images",
-    )
-    btn = gr.Button("Generate Report")
+    gr.Markdown("# 🩺 Phronesis Medical Report Generator")
+
+    upload_button = gr.UploadButton("Upload CT Scan Images", file_types=[".dcm", ".jpg", ".jpeg", ".png"], file_count="multiple")
+    files_state = gr.State([])
+
+    def store_files(new_files):
+        return new_files
+
+    upload_button.upload(store_files, upload_button, files_state)
+
+    generate_btn = gr.Button("Generate Report")
     class_output = gr.Textbox(label="Predicted Class")
     report_output = gr.Textbox(label="Generated Report")
 
-    btn.click(fn=predict, inputs=file_input, outputs=[class_output, report_output])
+    generate_btn.click(predict, files_state, [class_output, report_output])
 
 demo.launch()
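The commit message ("fix with endpoints 2") and the switch from a direct gr.File input to gr.UploadButton buffered through gr.State suggest the goal was to make the click handler usable as a programmatic endpoint. Assuming the Space is public and the handler is exposed under its function name as /predict (both assumptions, not shown in this diff), a remote call might look like:

```python
from gradio_client import Client, handle_file

# "baliddeki/phronesis" and api_name="/predict" are illustrative guesses;
# the Space's "Use via API" page lists the real id, route, and parameters.
client = Client("baliddeki/phronesis")
class_name, report = client.predict(
    [handle_file("slice_001.dcm"), handle_file("slice_002.dcm")],
    api_name="/predict",
)
print(class_name)
print(report)
```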