fffiloni committed on
Commit
2a72565
·
verified ·
1 Parent(s): 076dd71

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +61 -63
app.py CHANGED
@@ -17,6 +17,66 @@ model, transform = torch.hub.load("fkryan/gazelle", "gazelle_dinov2_vitl14_inout
17
  model.eval()
18
  model.to(device)
19
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
20
  def main(image_input, progress=gr.Progress(track_tqdm=True)):
21
  # load image
22
  image = Image.open(image_input)
@@ -47,74 +107,12 @@ def main(image_input, progress=gr.Progress(track_tqdm=True)):
47
  print(img1_person1_inout.item())
48
 
49
  # visualize predicted gaze heatmap for each person and gaze in/out of frame score
50
-
51
- def visualize_heatmap(pil_image, heatmap, bbox=None, inout_score=None):
52
- if isinstance(heatmap, torch.Tensor):
53
- heatmap = heatmap.detach().cpu().numpy()
54
- heatmap = Image.fromarray((heatmap * 255).astype(np.uint8)).resize(pil_image.size, Image.Resampling.BILINEAR)
55
- heatmap = plt.cm.jet(np.array(heatmap) / 255.)
56
- heatmap = (heatmap[:, :, :3] * 255).astype(np.uint8)
57
- heatmap = Image.fromarray(heatmap).convert("RGBA")
58
- heatmap.putalpha(90)
59
- overlay_image = Image.alpha_composite(pil_image.convert("RGBA"), heatmap)
60
-
61
- if bbox is not None:
62
- width, height = pil_image.size
63
- xmin, ymin, xmax, ymax = bbox
64
- draw = ImageDraw.Draw(overlay_image)
65
- draw.rectangle([xmin * width, ymin * height, xmax * width, ymax * height], outline="lime", width=int(min(width, height) * 0.01))
66
-
67
- if inout_score is not None:
68
- text = f"in-frame: {inout_score:.2f}"
69
- text_width = draw.textlength(text)
70
- text_height = int(height * 0.01)
71
- text_x = xmin * width
72
- text_y = ymax * height + text_height
73
- draw.text((text_x, text_y), text, fill="lime", font=ImageFont.load_default(size=int(min(width, height) * 0.05)))
74
- return overlay_image
75
-
76
  heatmap_results = []
77
  for i in range(len(bboxes)):
78
- overlay_img = visualize_heatmap(image, output['heatmap'][0][i], norm_bboxes[0][i], inout_score=output['inout'][0][i] if output['inout'] is not None else None))
79
  heatmap_results.append(overlay_img)
80
 
81
  # combined visualization with maximal gaze points for each person
82
-
83
- def visualize_all(pil_image, heatmaps, bboxes, inout_scores, inout_thresh=0.5):
84
- colors = ['lime', 'tomato', 'cyan', 'fuchsia', 'yellow']
85
- overlay_image = pil_image.convert("RGBA")
86
- draw = ImageDraw.Draw(overlay_image)
87
- width, height = pil_image.size
88
-
89
- for i in range(len(bboxes)):
90
- bbox = bboxes[i]
91
- xmin, ymin, xmax, ymax = bbox
92
- color = colors[i % len(colors)]
93
- draw.rectangle([xmin * width, ymin * height, xmax * width, ymax * height], outline=color, width=int(min(width, height) * 0.01))
94
-
95
- if inout_scores is not None:
96
- inout_score = inout_scores[i]
97
- text = f"in-frame: {inout_score:.2f}"
98
- text_width = draw.textlength(text)
99
- text_height = int(height * 0.01)
100
- text_x = xmin * width
101
- text_y = ymax * height + text_height
102
- draw.text((text_x, text_y), text, fill=color, font=ImageFont.load_default(size=int(min(width, height) * 0.05)))
103
-
104
- if inout_scores is not None and inout_score > inout_thresh:
105
- heatmap = heatmaps[i]
106
- heatmap_np = heatmap.detach().cpu().numpy()
107
- max_index = np.unravel_index(np.argmax(heatmap_np), heatmap_np.shape)
108
- gaze_target_x = max_index[1] / heatmap_np.shape[1] * width
109
- gaze_target_y = max_index[0] / heatmap_np.shape[0] * height
110
- bbox_center_x = ((xmin + xmax) / 2) * width
111
- bbox_center_y = ((ymin + ymax) / 2) * height
112
-
113
- draw.ellipse([(gaze_target_x-5, gaze_target_y-5), (gaze_target_x+5, gaze_target_y+5)], fill=color, width=int(0.005*min(width, height)))
114
- draw.line([(bbox_center_x, bbox_center_y), (gaze_target_x, gaze_target_y)], fill=color, width=int(0.005*min(width, height)))
115
-
116
- return overlay_image
117
-
118
  result_gazed = visualize_all(image, output['heatmap'][0], norm_bboxes[0], output['inout'][0] if output['inout'] is not None else None, inout_thresh=0.5)
119
 
120
  return result_gazed, heatmap_results
 
17
  model.eval()
18
  model.to(device)
19
 
20
def visualize_heatmap(pil_image, heatmap, bbox=None, inout_score=None):
    """Blend one gaze heatmap over the image and optionally annotate a box.

    Args:
        pil_image: PIL image used as the background. It is not modified; an
            RGBA composite is returned instead.
        heatmap: 2-D ``torch.Tensor`` or array. It is multiplied by 255 and
            cast to ``uint8``, so values are presumably in [0, 1]
            (TODO confirm against the model output).
        bbox: Optional ``(xmin, ymin, xmax, ymax)``. The values are multiplied
            by the image width/height, i.e. treated as normalized coordinates.
        inout_score: Optional score rendered as ``"in-frame: x.xx"`` just
            below the box. It is only drawn when ``bbox`` is given, because
            the label is positioned relative to the box.

    Returns:
        An RGBA PIL image with the colorized heatmap and any annotations.
    """
    if isinstance(heatmap, torch.Tensor):
        heatmap = heatmap.detach().cpu().numpy()

    # Quantize to 8 bits so PIL can resize the map to the image resolution.
    heatmap_img = Image.fromarray((heatmap * 255).astype(np.uint8))
    heatmap_img = heatmap_img.resize(pil_image.size, Image.Resampling.BILINEAR)

    # Apply the jet colormap, drop its alpha channel and use a fixed alpha.
    colored = plt.cm.jet(np.array(heatmap_img) / 255.)
    colored = (colored[:, :, :3] * 255).astype(np.uint8)
    heatmap_rgba = Image.fromarray(colored).convert("RGBA")
    heatmap_rgba.putalpha(90)
    overlay_image = Image.alpha_composite(pil_image.convert("RGBA"), heatmap_rgba)

    if bbox is None:
        return overlay_image

    width, height = pil_image.size
    short_side = min(width, height)
    xmin, ymin, xmax, ymax = bbox
    draw = ImageDraw.Draw(overlay_image)
    draw.rectangle(
        [xmin * width, ymin * height, xmax * width, ymax * height],
        outline="lime",
        # Clamp to 1: the scaled width truncates to 0 on small images.
        width=max(1, int(short_side * 0.01)),
    )

    if inout_score is not None:
        text = f"in-frame: {inout_score:.2f}"
        text_height = int(height * 0.01)
        text_x = xmin * width
        text_y = ymax * height + text_height
        # Clamp to 1 so the scaled font size never truncates to 0.
        font = ImageFont.load_default(size=max(1, int(short_side * 0.05)))
        draw.text((text_x, text_y), text, fill="lime", font=font)
    return overlay_image
44
+
45
def visualize_all(pil_image, heatmaps, bboxes, inout_scores, inout_thresh=0.5):
    """Draw every box, its in-frame score and its predicted gaze target.

    Args:
        pil_image: PIL image used as the background. It is not modified; an
            RGBA copy is drawn on and returned.
        heatmaps: Indexable collection with one 2-D heatmap per box
            (``torch.Tensor`` or array).
        bboxes: Sequence of ``(xmin, ymin, xmax, ymax)``. The values are
            multiplied by the image width/height, i.e. treated as normalized
            coordinates.
        inout_scores: Per-box scores, or ``None``. When ``None``, only the
            boxes are drawn (no labels and no gaze lines).
        inout_thresh: A gaze target is drawn only for boxes whose score is
            strictly greater than this value.

    Returns:
        An RGBA PIL image with all annotations drawn.
    """
    colors = ['lime', 'tomato', 'cyan', 'fuchsia', 'yellow']
    overlay_image = pil_image.convert("RGBA")
    draw = ImageDraw.Draw(overlay_image)
    width, height = pil_image.size
    short_side = min(width, height)

    # Clamp to 1: the scaled values truncate to 0 on small images.
    box_width = max(1, int(short_side * 0.01))
    line_width = max(1, int(0.005 * short_side))
    font_size = max(1, int(short_side * 0.05))

    for i, bbox in enumerate(bboxes):
        xmin, ymin, xmax, ymax = bbox
        color = colors[i % len(colors)]  # cycle the palette past five boxes
        draw.rectangle(
            [xmin * width, ymin * height, xmax * width, ymax * height],
            outline=color,
            width=box_width,
        )

        if inout_scores is None:
            continue

        inout_score = inout_scores[i]
        text = f"in-frame: {inout_score:.2f}"
        text_height = int(height * 0.01)
        text_x = xmin * width
        text_y = ymax * height + text_height
        draw.text((text_x, text_y), text, fill=color, font=ImageFont.load_default(size=font_size))

        if inout_score > inout_thresh:
            heatmap = heatmaps[i]
            # Accept arrays as well as tensors, like visualize_heatmap does.
            if isinstance(heatmap, torch.Tensor):
                heatmap = heatmap.detach().cpu().numpy()
            heatmap_np = np.asarray(heatmap)

            # Scale the heatmap peak to image pixel coordinates.
            max_index = np.unravel_index(np.argmax(heatmap_np), heatmap_np.shape)
            gaze_target_x = max_index[1] / heatmap_np.shape[1] * width
            gaze_target_y = max_index[0] / heatmap_np.shape[0] * height
            bbox_center_x = ((xmin + xmax) / 2) * width
            bbox_center_y = ((ymin + ymax) / 2) * height

            draw.ellipse([(gaze_target_x-5, gaze_target_y-5), (gaze_target_x+5, gaze_target_y+5)], fill=color, width=line_width)
            draw.line([(bbox_center_x, bbox_center_y), (gaze_target_x, gaze_target_y)], fill=color, width=line_width)

    return overlay_image
79
+
80
  def main(image_input, progress=gr.Progress(track_tqdm=True)):
81
  # load image
82
  image = Image.open(image_input)
 
107
  print(img1_person1_inout.item())
108
 
109
  # visualize predicted gaze heatmap for each person and gaze in/out of frame score
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
110
  heatmap_results = []
111
  for i in range(len(bboxes)):
112
+ overlay_img = visualize_heatmap(image, output['heatmap'][0][i], norm_bboxes[0][i], inout_score=output['inout'][0][i] if output['inout'] is not None else None)
113
  heatmap_results.append(overlay_img)
114
 
115
  # combined visualization with maximal gaze points for each person
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
116
  result_gazed = visualize_all(image, output['heatmap'][0], norm_bboxes[0], output['inout'][0] if output['inout'] is not None else None, inout_thresh=0.5)
117
 
118
  return result_gazed, heatmap_results