ryhm commited on
Commit
0b41393
ยท
verified ยท
1 Parent(s): 0dd8dff

Create app.py

Browse files
Files changed (1) hide show
  1. app.py +439 -0
app.py ADDED
@@ -0,0 +1,439 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import gradio as gr
3
+ # from app_util import ContextDetDemo
4
+ import torch
5
+ import cv2
6
+ import matplotlib.pyplot as plt
7
+ import torchvision.transforms as transforms
8
+ from utils.my_model import MyCNN
9
+ from models.common import DetectMultiBackend
10
+ import numpy as np
11
+ import csv
12
+ import torch.nn.functional as F
13
+ from PIL import Image, ImageOps
14
+ from utils.augmentations import letterbox
15
+ from utils.general import (scale_boxes, non_max_suppression)
16
+ import pandas as pd
17
+ import os
18
+
19
+ from torchvision.ops import roi_align
20
+ from utils.general import (LOGGER, Profile, check_file, check_img_size, check_imshow, check_requirements, colorstr, cv2,
21
+ increment_path, non_max_suppression, print_args, scale_boxes, strip_optimizer, xyxy2xywh,get_fixed_xyxy)
22
+ # Initialize Model with Error Handling
23
+ try:
24
+ # model = DetectMultiBackend('best.pt')
25
+ # model = DetectMultiBackend('best.pt')
26
+
27
+ device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
28
+ cell_attribute_model= MyCNN(num_classes=12, dropout_prob=0.5, in_channels=480).cpu()
29
+ folder_name = '/home/iml1/AR/Sparse_Det_TMI/Attribute_model'
30
+ custom_weights_path = f"Attridet_weight/Attrihead_hcm_100x.pth"
31
+ custom_weights = torch.load(custom_weights_path,map_location=torch.device('cpu'))
32
+ cell_attribute_model.load_state_dict(custom_weights)
33
+ cell_attribute_model.eval().to(device)
34
+
35
+ model = DetectMultiBackend('Attridet_weight/hcm_100x_yolo.pt')
36
+ except Exception as e:
37
+ print(f"Error loading model: {e}")
38
+
39
+ header = """
40
+ <div align=center>
41
+ <h1 style="font-weight: 900; margin-bottom: 7px;">
42
+ Leukemia Detection with Morphology Attributes
43
+ </h1>
44
+ </div>
45
+ """
46
+
47
+ abstract = """
48
+ ๐Ÿค— This is the demo for <b>Leukemia Detection with with Morphology Attributes</b>.
49
+
50
+ ๐Ÿ†’ Our goal is to detect infected cells with better Morphology for the bettre diagnosis explainabilty.
51
+
52
+ โšก For faster inference, you may duplicate the space and use the GPU setting.
53
+ """
54
+
55
+ footer = r"""
56
+ ๐Ÿฆ **Github Repo**
57
+ We would be grateful if you consider starring our <a href="Website">https://github.com/intelligentMachines-ITU/Blood-Cancer-Dataset-Lukemia-Attri-MICCAI-2024</a>
58
+
59
+ ๐Ÿ“ **Citation**
60
+ ```bibtex
61
+ @inproceedings{rehman2024large,
62
+ title={A large-scale multi domain leukemia dataset for the white blood cells detection with morphological attributes for explainability},
63
+ author={Rehman, Abdul and Meraj, Talha and Minhas, Aiman Mahmood and Imran, Ayisha and Ali, Mohsen and Sultani, Waqas},
64
+ booktitle={International Conference on Medical Image Computing and Computer-Assisted Intervention},
65
+ pages={553--563},
66
+ year={2024},
67
+ organization={Springer}
68
+ }
69
+
70
+
71
+ ๐Ÿ“ง **Contact**
72
+ If you have any questions, please feel free to contact Abdul Rehman <b>(phdcs23002@itu.edu.pk)</b>.
73
+ """
74
+
75
+ css = """
76
+ h1#title {
77
+ text-align: center;
78
+ }
79
+ """
80
+
81
+
82
+
83
+
84
+ def capture_image(pil_img):
85
+ # if self.session_started:
86
+ # slide_number = self.slide_number_entry.text().strip()
87
+ # if slide_number:
88
+
89
+ # self.slide_dir = os.path.join(os.getcwd(), slide_number)
90
+ # # print(slide_dir)
91
+ # image_path = os.path.join(self.slide_dir, f"image_{self.image_counter}.png")
92
+ # ret, frame = self.camera.read()
93
+
94
+
95
+ # self.image_counter_label.setText(f"{self.image_counter}")
96
+ # cv2.imwrite(image_path, frame)
97
+
98
+ conf_thres=0.1
99
+ iou_thres=0.45
100
+ max_det=1000
101
+ hide_labels=False
102
+ hide_conf=False
103
+
104
+ all_predictions = []
105
+ # pil_img = Image.fromarray(frame)
106
+ image = pil_img.resize((640,640), Image.LANCZOS)
107
+ im0 = np.array(image)
108
+
109
+
110
+ im = letterbox(im0, 640, 32, auto=True)[0] # padded resize
111
+ im = im.transpose((2, 0, 1))[::-1] # HWC to CHW, BGR to RGB
112
+ img = np.ascontiguousarray(im)
113
+ img= torch.from_numpy(img)
114
+
115
+
116
+
117
+
118
+
119
+ # transform = transforms.Compose([
120
+ # transforms.ToPILImage(), # Convert numpy array to PIL Image
121
+ # transforms.Resize((640, 640)), # Resize image
122
+ # transforms.ToTensor(), # Convert PIL Image to tensor
123
+ # # transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]) # Normalize
124
+ # ])
125
+ # # Add batch dimension
126
+
127
+ # # Inference
128
+ # # pred, int_feats = model(img, augment=False, visualize=False)
129
+ # frame=transform(frame)
130
+ img = img.half() if model.fp16 else img.float() # uint8 to fp16/32
131
+ img /= 255
132
+
133
+ # Inference
134
+ img=img.unsqueeze(0)
135
+ pred, int_feats,_ = model(img, augment=False, visualize=False)
136
+
137
+
138
+ #attri
139
+
140
+ int_feats_p2 = int_feats[0][0].to(torch.float32).unsqueeze(0)
141
+ int_feats_p3 = int_feats[1][0].to(torch.float32).unsqueeze(0)
142
+ in_channels = int_feats_p2.shape[1]+int_feats_p3.shape[1]
143
+
144
+
145
+
146
+
147
+
148
+
149
+
150
+
151
+
152
+
153
+
154
+
155
+
156
+
157
+
158
+
159
+
160
+
161
+
162
+
163
+
164
+
165
+
166
+
167
+
168
+
169
+ # Apply NMS
170
+ pred = non_max_suppression(pred, conf_thres, iou_thres, max_det=max_det)
171
+
172
+
173
+ if (len(pred[0])>0):
174
+ all_top_indices_cell_pred = []
175
+ top_indices_cell_pred = []
176
+ pred_Nuclear_Chromatin_array = []
177
+ pred_Nuclear_Shape_array = []
178
+ pred_Nucleus_array = []
179
+ pred_Cytoplasm_array = []
180
+ pred_Cytoplasmic_Basophilia_array = []
181
+ pred_Cytoplasmic_Vacuoles_array = []
182
+
183
+ for i in range(len(pred[0])):
184
+ if pred[0][i].numel() > 0: # Check if the tensor is not empty
185
+
186
+ pred_tensor = pred[0][i][0:4]
187
+
188
+ if pred[0][i][5] != 0:
189
+
190
+ img_shape_tensor = torch.tensor([img.shape[2], img.shape[3],img.shape[2],img.shape[3]]).to(device)
191
+
192
+ normalized_xyxy=pred_tensor.to(device) / img_shape_tensor
193
+ p2_feature_shape_tensor = torch.tensor([int_feats[0].shape[1], int_feats[0].shape[2],int_feats[0].shape[1],int_feats[0].shape[2]]).to(device) # reduce_channels_layer = torch.nn.Conv2d(1280, 250, kernel_size=1).to(device)
194
+ p3_feature_shape_tensor = torch.tensor([int_feats[1].shape[1], int_feats[1].shape[2],int_feats[1].shape[1],int_feats[1].shape[2]]).to(device) # reduce_channels_layer = torch.nn.Conv2d(1280, 250, kernel_size=1).to(device)
195
+
196
+
197
+ p2_normalized_xyxy = normalized_xyxy*p2_feature_shape_tensor
198
+ p3_normalized_xyxy = normalized_xyxy*p3_feature_shape_tensor
199
+ p2_x_min, p2_y_min, p2_x_max, p2_y_max = get_fixed_xyxy(p2_normalized_xyxy,int_feats_p2)
200
+ p3_x_min, p3_y_min, p3_x_max, p3_y_max = get_fixed_xyxy(p3_normalized_xyxy,int_feats_p3)
201
+
202
+ p2_roi = torch.tensor([p2_x_min, p2_y_min, p2_x_max, p2_y_max], device=device).float()
203
+ p3_roi = torch.tensor([p3_x_min, p3_y_min, p3_x_max, p3_y_max], device=device).float()
204
+
205
+ batch_index = torch.tensor([0], dtype=torch.float32, device = device)
206
+
207
+ # Concatenate the batch index to the bounding box coordinates
208
+ p2_roi_with_batch_index = torch.cat([batch_index, p2_roi])
209
+ p3_roi_with_batch_index = torch.cat([batch_index, p3_roi])
210
+ p2_resized_object = roi_align(int_feats_p2.to(device), p2_roi_with_batch_index.unsqueeze(0).to(device), output_size=(24, 30))
211
+ p3_resized_object = roi_align(int_feats_p3.to(device), p3_roi_with_batch_index.unsqueeze(0).to(device), output_size=(24, 30))
212
+ concat_box = torch.cat([p2_resized_object,p3_resized_object],dim=1)
213
+
214
+ output_cell_prediction= cell_attribute_model(concat_box)
215
+ output_cell_prediction_prob = F.softmax(output_cell_prediction.view(6,2), dim=1)
216
+ top_indices_cell_pred = torch.argmax(output_cell_prediction_prob, dim=1)
217
+ pred_Nuclear_Chromatin_array.append(top_indices_cell_pred[0].item())
218
+ pred_Nuclear_Shape_array.append(top_indices_cell_pred[1].item())
219
+ pred_Nucleus_array.append(top_indices_cell_pred[2].item())
220
+ pred_Cytoplasm_array.append(top_indices_cell_pred[3].item())
221
+ pred_Cytoplasmic_Basophilia_array.append(top_indices_cell_pred[4].item())
222
+ pred_Cytoplasmic_Vacuoles_array.append(top_indices_cell_pred[5].item())
223
+ # all_top_indices_cell_pred.append(top_indices_cell_pred.item())
224
+ else:
225
+ # top_indices_cell_pred = torch.tensor([0,0,0,0,0,0]).to(device)
226
+ pred_Nuclear_Chromatin_array.append(4)
227
+ pred_Nuclear_Shape_array.append(4)
228
+ pred_Nucleus_array.append(4)
229
+ pred_Cytoplasm_array.append(4)
230
+ pred_Cytoplasmic_Basophilia_array.append(4)
231
+ pred_Cytoplasmic_Vacuoles_array.append(4)
232
+
233
+
234
+
235
+
236
+ # Second-stage classifier (optional)
237
+ # pred = utils.general.apply_classifier(pred, classifier_model, im, im0s)
238
+
239
+ # Define the path for the CSV file
240
+ df_predictions = pd.DataFrame(columns=['Image Name', 'Prediction', 'Confidence', 'Nuclear Chromatin',
241
+ 'Nuclear Shape', 'Nucleus', 'Cytoplasm', 'Cytoplasmic Basophilia',
242
+ 'Cytoplasmic Vacuoles', 'x_min', 'y_min', 'x_max', 'y_max'])
243
+
244
+ # Function to add data to the DataFrame and plot labels
245
+ def write_to_dataframe(img, name, predicts, confid, pred_NC, pred_NS,
246
+ pred_N, pred_C, pred_CB, pred_CV,
247
+ x_min, y_min, x_max, y_max):
248
+ # global df_predictions
249
+
250
+ new_data = pd.DataFrame([{
251
+ 'Image Name': name,
252
+ 'Prediction': predicts,
253
+ 'Confidence': confid,
254
+ 'Nuclear Chromatin': pred_NC,
255
+ 'Nuclear Shape': pred_NS,
256
+ 'Nucleus': pred_N,
257
+ 'Cytoplasm': pred_C,
258
+ 'Cytoplasmic Basophilia': pred_CB,
259
+ 'Cytoplasmic Vacuoles': pred_CV,
260
+ 'x_min': x_min,
261
+ 'y_min': y_min,
262
+ 'x_max': x_max,
263
+ 'y_max': y_max
264
+ }])
265
+
266
+ # df_predictions = pd.concat([df_predictions, new_data], ignore_index=True)
267
+
268
+ # Draw bounding box and label
269
+ # cv2.rectangle(img, (x_min, y_min), (x_max, y_max), (0, 255, 0), 2)
270
+ # cv2.putText(img, predicts, (x_min, y_min - 10),
271
+ # cv2.FONT_HERSHEY_SIMPLEX, 0.5, (0, 255, 0), 2)
272
+
273
+ return new_data
274
+
275
+ names = ["None", "Myeloblast", "Lymphoblast", "Neutrophil", "Atypical lymphocyte",
276
+ "Promonocyte", "Monoblast", "Lymphocyte", "Myelocyte", "Abnormal promyelocyte",
277
+ "Monocyte", "Metamyelocyte", "Eosinophil", "Basophil"]
278
+
279
+ # Process predictions
280
+ for i, det in enumerate(pred): # per image
281
+
282
+ # img = cv2.imread("image.png") # Load the image
283
+
284
+ for count, (*xyxy, conf, cls) in enumerate(det):
285
+ c = int(cls) # integer class
286
+ label = names[c]
287
+ confidence = float(conf)
288
+ confidence_str = f'{confidence:.2f}'
289
+
290
+ x_min, y_min, x_max, y_max = xyxy
291
+ new_data_update = write_to_dataframe (im0 , "image.png", label, confidence_str,
292
+ pred_Nuclear_Chromatin_array[count],
293
+ pred_Nuclear_Shape_array[count],
294
+ pred_Nucleus_array[count],
295
+ pred_Cytoplasm_array[count],
296
+ pred_Cytoplasmic_Basophilia_array[count],
297
+ pred_Cytoplasmic_Vacuoles_array[count],
298
+ int(x_min.detach().cpu().item()),
299
+ int(y_min.detach().cpu().item()),
300
+ int(x_max.detach().cpu().item()),
301
+ int(y_max.detach().cpu().item()))
302
+ df_predictions = pd.concat([df_predictions, new_data_update], ignore_index=True)
303
+
304
+ # Save or display the result
305
+ # cv2.imwrite("annotated_image.png", img)
306
+ # cv2.imshow("Annotated Image", img)
307
+ # cv2.waitKey(0)
308
+ # cv2.destroyAllWindows()
309
+
310
+ # Optionally, display or export the DataFrame
311
+ result_list = []
312
+
313
+ # Conditions for each column
314
+ result_list.append("open" if (df_predictions['Nuclear Chromatin'] == 0).sum() > (df_predictions['Nuclear Chromatin'] == 1).sum() else "Coarse")
315
+ result_list.append("regular" if (df_predictions['Nuclear Shape'] == 0).sum() > (df_predictions['Nuclear Shape'] == 1).sum() else "irregular")
316
+ result_list.append("inconspicuous" if (df_predictions['Nucleus'] == 0).sum() > (df_predictions['Nucleus'] == 1).sum() else "prominent")
317
+ result_list.append("scanty" if (df_predictions['Cytoplasm'] == 0).sum() > (df_predictions['Cytoplasm'] == 1).sum() else "abundant")
318
+ result_list.append("slight" if (df_predictions['Cytoplasmic Basophilia'] == 0).sum() > (df_predictions['Cytoplasmic Basophilia'] == 1).sum() else "moderate")
319
+ result_list.append("absent" if (df_predictions['Cytoplasmic Vacuoles'] == 0).sum() > (df_predictions['Cytoplasmic Vacuoles'] == 1).sum() else "prominent")
320
+ # Sample text with <mask> placeholders
321
+ text = """These WBCโ€™s are, <mask> chromatin, and <mask> shaped nuclei. The nucleoli are <mask>, and the cytoplasm is <mask> with <mask> basophilia. Cytoplasmic vacuoles are <mask>."""
322
+
323
+ # Replace <mask> with values from result_list
324
+ filled_text = text.replace("<mask>", "{}").format(*result_list)
325
+
326
+
327
+ def plot_bboxes_from_dataframe(img, df_predictions):
328
+ # Iterate through the DataFrame
329
+ for _, row in df_predictions.iterrows():
330
+ # Extract coordinates (convert from string to int)
331
+ x_min, y_min, x_max, y_max = map(int, [row['x_min'], row['y_min'], row['x_max'], row['y_max']])
332
+ prediction = row['Prediction']
333
+ confidence = float(row['Confidence'])
334
+
335
+ # Skip predictions marked as 'None'
336
+ if prediction == "None":
337
+ continue
338
+
339
+ # Draw the bounding box
340
+ cv2.rectangle(img, (x_min, y_min), (x_max, y_max), (0, 255, 0), 2)
341
+
342
+ # Display prediction with confidence
343
+ label = f"{prediction} ({confidence:.2f})"
344
+ cv2.putText(img, label, (x_min, max(0, y_min - 10)),
345
+ cv2.FONT_HERSHEY_SIMPLEX, 0.5, (0, 255, 0), 2)
346
+
347
+ return img # Return the annotated image
348
+ # df_predictions.to_csv("predictions.csv", index=False) # Save if needed
349
+ annotated_img = plot_bboxes_from_dataframe(im0, df_predictions)
350
+ # cv2.rectangle(img, (x_min, y_min), (x_max, y_max), (0, 255, 0), 2)
351
+ # cv2.putText(img, predicts, (x_min, y_min - 10)),
352
+ # print(df_predictions)
353
+
354
+
355
+ # else:
356
+ # QMessageBox.critical(self, "Error", "Please enter a slide number.")
357
+ # image_counter = 1
358
+ return annotated_img ,filled_text
359
+ # Process detections
360
+ # for i, det in enumerate(pred):
361
+ # if len(det):
362
+ # det[:, :4] = scale_boxes(img.shape[2:], det[:, :4], frame.shape).round()
363
+ # for *xyxy, conf, cls in reversed(det):
364
+ # c = int(cls) # integer class
365
+ # label = None if self.hide_labels else (model.names[c] if self.hide_conf else f'{model.names[c]} {conf:.2f}')
366
+ # img0 = self.plot_one_box(xyxy, frame, label=label, color=(0,255,0))
367
+
368
+ # # Save image with bounding boxes
369
+ # output_path = os.path.join(self.slide_dir, f"image_detection{self.image_counter}.png")
370
+
371
+
372
+
373
+ # if len(det):
374
+ # cv2.imwrite(output_path, img0)
375
+ # #QMessageBox.information(self, "Success", f"Image {self.image_counter} captured and saved.")
376
+ # self.image_counter += 1
377
+ # self.image_counter_label.setText(f"{self.image_counter}")
378
+
379
+ def inference_fn_select(image_input):
380
+ try:
381
+ # img = letterbox(image_input, (640, 640), stride=32, auto=True)[0] # Resize and pad image
382
+ # img = img.transpose(2, 0, 1)[::-1] # Convert to channel-first format
383
+ # img = np.ascontiguousarray(img)
384
+ results,filled_text = capture_image(image_input)
385
+ state = 1# Model inference
386
+ result_pil = Image.fromarray(results)
387
+ return result_pil,filled_text
388
+ except Exception as e:
389
+ return None, f"Error in inference: {e}"
390
+
391
+ def set_cloze_samples(example: list) -> dict:
392
+ return gr.Image.update(example[0]), gr.Textbox.update(example[1]), 'Cloze Test'
393
+
394
+ with gr.Blocks(css=css, theme=gr.themes.Soft()) as demo:
395
+ gr.Markdown(header)
396
+ gr.Markdown(abstract)
397
+ state = gr.State([])
398
+
399
+ with gr.Row():
400
+ with gr.Column(scale=0.5, min_width=500):
401
+ image_input = gr.Image(type="pil", interactive=True, label="Upload an image ๐Ÿ“", height=250)
402
+ with gr.Column(scale=0.5, min_width=500):
403
+ task_button = gr.Radio(label="Contextual Task type", interactive=True,
404
+ choices=['Detect'],
405
+ value='Detect')
406
+ with gr.Row():
407
+ submit_button = gr.Button(value="๐Ÿƒ Run", interactive=True, variant="primary")
408
+ clear_button = gr.Button(value="๐Ÿ”„ Clear", interactive=True)
409
+
410
+ with gr.Row():
411
+ with gr.Column(scale=0.5, min_width=500):
412
+ image_output = gr.Image(type='pil', interactive=False, label="Detection output")
413
+ with gr.Column(scale=0.5, min_width=500):
414
+ chat_output = gr.Textbox(label="Text output")
415
+
416
+ submit_button.click(
417
+ inference_fn_select,
418
+ [image_input],
419
+ [image_output, chat_output],
420
+ )
421
+
422
+ clear_button.click(
423
+ lambda: (None, None, "", [], [], 'Detect'),
424
+ [],
425
+ [image_input, image_output, chat_output, task_button],
426
+ queue=False,
427
+ )
428
+
429
+ image_input.change(
430
+ lambda: (None, "", []),
431
+ [],
432
+ [image_output, chat_output],
433
+ queue=False,
434
+ )
435
+
436
+ gr.Markdown(footer)
437
+
438
+ demo.queue() # Enable request queuing
439
+ demo.launch(share=False)