import spaces
import gradio as gr

import torch
import matplotlib.pyplot as plt
from PIL import Image, ImageDraw, ImageFont
import numpy as np

# lightweight face detector used to get head bounding boxes for Gaze-LLE
from retinaface import RetinaFace

device = "cuda" if torch.cuda.is_available() else "cpu"

# load Gaze-LLE model
model, transform = torch.hub.load("fkryan/gazelle", "gazelle_dinov2_vitl14_inout")
model.eval()
model.to(device)
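
# The fkryan/gazelle torch.hub entry point also exposes other checkpoints (per
# the repo README; treat the exact names as assumptions and verify at
# https://github.com/fkryan/gazelle): "gazelle_dinov2_vitb14",
# "gazelle_dinov2_vitl14", "gazelle_dinov2_vitb14_inout".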

def visualize_heatmap(pil_image, heatmap, bbox=None, inout_score=None):
    if isinstance(heatmap, torch.Tensor):
        heatmap = heatmap.detach().cpu().numpy()
    heatmap = Image.fromarray((heatmap * 255).astype(np.uint8)).resize(pil_image.size, Image.Resampling.BILINEAR)
    heatmap = plt.cm.jet(np.array(heatmap) / 255.)
    heatmap = (heatmap[:, :, :3] * 255).astype(np.uint8)
    heatmap = Image.fromarray(heatmap).convert("RGBA")
    heatmap.putalpha(90)
    overlay_image = Image.alpha_composite(pil_image.convert("RGBA"), heatmap)

    if bbox is not None:
        width, height = pil_image.size
        xmin, ymin, xmax, ymax = bbox
        draw = ImageDraw.Draw(overlay_image)
        draw.rectangle([xmin * width, ymin * height, xmax * width, ymax * height], outline="lime", width=int(min(width, height) * 0.01))

        if inout_score is not None:
            text = f"in-frame: {inout_score:.2f}"
            text_height = int(height * 0.01)
            text_x = xmin * width
            text_y = ymax * height + text_height
            draw.text((text_x, text_y), text, fill="lime", font=ImageFont.load_default(size=int(min(width, height) * 0.05)))
    return overlay_image

def visualize_all(pil_image, heatmaps, bboxes, inout_scores, inout_thresh=0.5):
    colors = ['lime', 'tomato', 'cyan', 'fuchsia', 'yellow']
    overlay_image = pil_image.convert("RGBA")
    draw = ImageDraw.Draw(overlay_image)
    width, height = pil_image.size

    for i in range(len(bboxes)):
        bbox = bboxes[i]
        xmin, ymin, xmax, ymax = bbox
        color = colors[i % len(colors)]
        draw.rectangle([xmin * width, ymin * height, xmax * width, ymax * height], outline=color, width=int(min(width, height) * 0.01))

        if inout_scores is not None:
            inout_score = inout_scores[i]
            text = f"in-frame: {inout_score:.2f}"
            text_height = int(height * 0.01)
            text_x = xmin * width
            text_y = ymax * height + text_height
            draw.text((text_x, text_y), text, fill=color, font=ImageFont.load_default(size=int(min(width, height) * 0.05)))

        if inout_scores is not None and inout_score > inout_thresh:
            heatmap = heatmaps[i]
            heatmap_np = heatmap.detach().cpu().numpy()
            max_index = np.unravel_index(np.argmax(heatmap_np), heatmap_np.shape)
            gaze_target_x = max_index[1] / heatmap_np.shape[1] * width
            gaze_target_y = max_index[0] / heatmap_np.shape[0] * height
            bbox_center_x = ((xmin + xmax) / 2) * width
            bbox_center_y = ((ymin + ymax) / 2) * height

            draw.ellipse([(gaze_target_x-5, gaze_target_y-5), (gaze_target_x+5, gaze_target_y+5)], fill=color)
            draw.line([(bbox_center_x, bbox_center_y), (gaze_target_x, gaze_target_y)], fill=color, width=int(0.005*min(width, height)))

    return overlay_image

@spaces.GPU() # ZeroGPU ready
def main(image_input, progress=gr.Progress(track_tqdm=True)):
    """Estimate gaze direction for detected faces in an image using Gaze-LLE.
    
    Detects faces in the input image, estimates a gaze heatmap for each face with a
    pre-trained Gaze-LLE model, and visualizes the results, including gaze direction
    and whether each person's gaze target lies within the frame.

    Args:
        image_input: A filepath to the input image. Should be a photo containing one or more human faces.
        progress: Optional Gradio progress tracker for UI feedback (used during inference).

    Returns:
        result_gazed (PIL.Image.Image): A single composite image with bounding boxes around faces, 
            lines indicating predicted gaze direction, and indicators of whether gaze is "in-frame".
        heatmap_results (List[PIL.Image.Image]): A list of individual images, one per face, each showing 
            the original image overlaid with a heatmap of the predicted gaze target.
    """
    
    # load image (force RGB so RGBA or grayscale inputs don't break the pipeline)
    image = Image.open(image_input).convert("RGB")
    width, height = image.size

    # detect faces (RetinaFace expects a numpy array)
    resp = RetinaFace.detect_faces(np.array(image))
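    # Hedged guard: assumes RetinaFace.detect_faces returns a dict keyed by
    # "face_1", "face_2", ... on success and something else (e.g. an empty result)
    # when nothing is found. Failing early gives the UI a readable error message.
    if not isinstance(resp, dict) or len(resp) == 0:
        raise gr.Error("No faces detected in the input image.")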
    bboxes = [resp[key]["facial_area"] for key in resp.keys()]

    # prepare gazelle input
    img_tensor = transform(image).unsqueeze(0).to(device)
    norm_bboxes = [[np.array(bbox) / np.array([width, height, width, height]) for bbox in bboxes]]
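    # Worked example (hypothetical numbers): a face box (50, 80, 150, 280) in a
    # 1000x800 image normalizes to [0.05, 0.10, 0.15, 0.35].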

    model_input = {
        "images": img_tensor,  # [num_images, 3, 448, 448]
        "bboxes": norm_bboxes  # [[img1_bbox1, img1_bbox2, ...], [img2_bbox1, ...]]
    }

    with torch.no_grad():
        output = model(model_input)

    # output['heatmap'][0][i] is a [64, 64] gaze heatmap for person i in image 0;
    # output['inout'][0][i] is that person's gaze-in-frame score (only present for
    # the *_inout model variants, otherwise output['inout'] is None)

    # visualize the predicted gaze heatmap and in/out-of-frame score per person
    heatmap_results = []
    for i in range(len(bboxes)):
        overlay_img = visualize_heatmap(
            image,
            output['heatmap'][0][i],
            norm_bboxes[0][i],
            inout_score=output['inout'][0][i] if output['inout'] is not None else None,
        )
        heatmap_results.append(overlay_img)

    # combined visualization with the maximal gaze point for each person
    result_gazed = visualize_all(
        image,
        output['heatmap'][0],
        norm_bboxes[0],
        output['inout'][0] if output['inout'] is not None else None,
        inout_thresh=0.5,
    )

    return result_gazed, heatmap_results
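
# Minimal headless usage sketch, kept commented out so the Space only runs via
# the UI. The example path is an assumption about what ships with this repo (it
# is the same file wired into gr.Examples below):
#
#   combined, per_face = main("examples/the_office.png")
#   combined.save("gaze_result.png")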

css="""
div#col-container{
    margin: 0 auto;
    max-width: 982px;
}
"""

with gr.Blocks(css=css) as demo: 
    with gr.Column(elem_id="col-container"):
        gr.Markdown("# Gaze-LLE: Gaze Target Estimation via Large-Scale Learned Encoders")
        gr.Markdown("A transformer approach for estimating gaze targets that leverages the power of pretrained visual foundation models. Gaze-LLE provides a streamlined gaze architecture that learns only a lightweight gaze decoder on top of a frozen, pretrained visual encoder (DINOv2). Gaze-LLE learns 1-2 orders of magnitude fewer parameters than prior works and doesn't require any extra input modalities like depth and pose!")
        gr.HTML("""
        <div style="display:flex;column-gap:4px;">
            <a href="https://github.com/fkryan/gazelle">
                <img src='https://img.shields.io/badge/GitHub-Repo-blue'>
            </a> 
            <a href="https://arxiv.org/abs/2412.09586">
                <img src='https://img.shields.io/badge/ArXiv-Paper-red'>
            </a>
            <a href="https://huggingface.co/spaces/fffiloni/Gaze-LLE?duplicate=true">
                <img src="https://huggingface.co/datasets/huggingface/badges/resolve/main/duplicate-this-space-sm.svg" alt="Duplicate this Space">
            </a>
            <a href="https://huggingface.co/fffiloni">
                <img src="https://huggingface.co/datasets/huggingface/badges/resolve/main/follow-me-on-HF-sm-dark.svg" alt="Follow me on HF">
            </a>
        </div>
        """)
        with gr.Row():
            with gr.Column():
                input_image = gr.Image(label="Image Input", type="filepath")
                submit_button = gr.Button("Submit")
                gr.Examples(
                    examples = ["examples/the_office.png", "examples/succession.png"],
                    inputs = [input_image]
                )
            with gr.Column():
                result = gr.Image(label="Result")
                heatmaps = gr.Gallery(label="Heatmap", columns=3)

    submit_button.click(
        fn = main,
        inputs = [input_image],
        outputs = [result, heatmaps]
    )
demo.queue().launch(ssr_mode=False, show_error=True, mcp_server=True)