lighthouse-emnlp2024 committed
Commit 2b48bef · 1 Parent(s): dfb7b55
Files changed (3)
  1. app.py +348 -0
  2. packages.txt +1 -0
  3. requirements.txt +5 -0
app.py ADDED
@@ -0,0 +1,348 @@
+ """
+ Copyright $today.year LY Corporation
+
+ LY Corporation licenses this file to you under the Apache License,
+ version 2.0 (the "License"); you may not use this file except in compliance
+ with the License. You may obtain a copy of the License at:
+
+ https://www.apache.org/licenses/LICENSE-2.0
+
+ Unless required by applicable law or agreed to in writing, software
+ distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
+ WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
+ License for the specific language governing permissions and limitations
+ under the License.
+ """
+
+ import os
+ import subprocess
+
+ import ffmpeg
+ import gradio as gr
+ import pandas as pd
+ import torch
+ from lighthouse.models import *
+ from tqdm import tqdm
+
+ # use GPU if available
+ device = "cuda" if torch.cuda.is_available() else "cpu"
+ MODEL_NAMES = ["cg_detr", "moment_detr", "eatr", "qd_detr", "tr_detr", "uvcom"]
+ FEATURES = ["clip"]
+ TOPK_MOMENT = 5
+ TOPK_HIGHLIGHT = 5
+
+ """
+ Helper functions
+ """
+
+
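+ # Fetch the QVHighlights checkpoint for every (feature, model) pair from Zenodo,
+ # plus the SlowFast and PANNs backbone weights, skipping files that already exist.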
+ def load_pretrained_weights():
+     file_urls = []
+     for model_name in MODEL_NAMES:
+         for feature in FEATURES:
+             file_urls.append(
+                 "https://zenodo.org/records/13960580/files/{}_{}_qvhighlight.ckpt".format(
+                     feature, model_name
+                 )
+             )
+     for file_url in tqdm(file_urls):
+         if not os.path.exists("gradio_demo/weights/" + os.path.basename(file_url)):
+             command = "wget -P gradio_demo/weights/ {}".format(file_url)
+             subprocess.run(command, shell=True)
+
+     # Slowfast weights
+     if not os.path.exists("SLOWFAST_8x8_R50.pkl"):
+         subprocess.run(
+             "wget https://dl.fbaipublicfiles.com/pyslowfast/model_zoo/kinetics400/SLOWFAST_8x8_R50.pkl",
+             shell=True,
+         )
+
+     # PANNs weights
+     if not os.path.exists("Cnn14_mAP=0.431.pth"):
+         subprocess.run(
+             "wget https://zenodo.org/record/3987831/files/Cnn14_mAP%3D0.431.pth",
+             shell=True,
+         )
+
+     return file_urls
+
+
+ def flatten(array2d):
+     list1d = []
+     for elem in array2d:
+         list1d += elem
+     return list1d
+
+
+ """
+ Model initialization
+ """
+ load_pretrained_weights()
+ model = CGDETRPredictor(
+     "gradio_demo/weights/clip_cg_detr_qvhighlight.ckpt",
+     device=device,
+     feature_name="clip",
+     slowfast_path=None,
+     pann_path=None,
+ )
+ loaded_video = None
+ loaded_video_path = None
+
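+ # Per-button JavaScript: strip the "moment N:" prefix and " Score: ..." suffix from
+ # the clicked button's label, parse the remaining "[start, end]", and seek the
+ # page's <video> element to the moment's start time.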
+ js_codes = [
+     """() => {{
+     let moment_text = document.getElementById('result_{}').textContent;
+     let replaced_text = moment_text.replace(/moment..../, '').replace(/\\ Score.*/, '');
+     let start_end = JSON.parse(replaced_text);
+     document.getElementsByTagName("video")[0].currentTime = start_end[0];
+     document.getElementsByTagName("video")[0].play();
+     }}""".format(i)
+     for i in range(TOPK_MOMENT)
+ ]
+
+ """
+ Gradio functions
+ """
+
+
+ def video_upload(video):
+     global loaded_video, loaded_video_path
+     if video is None:
+         loaded_video = None
+         loaded_video_path = None
+         yield gr.update(value="Removed the video", visible=True)
+     else:
+         yield gr.update(
+             value="Processing the video. Wait for a minute...", visible=True
+         )
+         loaded_video = model.encode_video(video)
+         loaded_video_path = video
+         yield gr.update(value="Finished video processing!", visible=True)
+
+
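+ # Swap in the newly selected model and, if a video is already loaded, re-encode it
+ # so its features match the new model.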
+ def model_load(radio, video):
+     global loaded_video, loaded_video_path
+     if radio is not None:
+         loading_msg = "Loading new model. Wait for a minute..."
+         yield (
+             gr.update(value=loading_msg, visible=True),
+             gr.update(value=loading_msg, visible=True),
+         )
+         global model
+         feature, model_name = radio.split("+")
+         feature, model_name = feature.strip(), model_name.strip()
+
+         if model_name == "moment_detr":
+             model_class = MomentDETRPredictor
+         elif model_name == "qd_detr":
+             model_class = QDDETRPredictor
+         elif model_name == "eatr":
+             model_class = EaTRPredictor
+         elif model_name == "tr_detr":
+             model_class = TRDETRPredictor
+         elif model_name == "uvcom":
+             model_class = UVCOMPredictor
+         elif model_name == "cg_detr":
+             model_class = CGDETRPredictor
+         else:
+             raise gr.Error("Select from the models")
+
+         model = model_class(
+             "gradio_demo/weights/{}_{}_qvhighlight.ckpt".format(feature, model_name),
+             device=device,
+             feature_name=feature,
+             slowfast_path="SLOWFAST_8x8_R50.pkl",
+             pann_path="Cnn14_mAP=0.431.pth",
+         )
+
+         load_finished_msg = "Model loaded: {}".format(radio)
+         encode_process_msg = (
+             "Processing the video. Wait for a minute..." if video is not None else ""
+         )
+         yield (
+             gr.update(value=load_finished_msg, visible=True),
+             gr.update(value=encode_process_msg, visible=True),
+         )
+
+         if video is not None:
+             loaded_video = model.encode_video(video)
+             loaded_video_path = video
+             encode_finished_msg = "Finished video processing!"
+             yield (
+                 gr.update(value=load_finished_msg, visible=True),
+                 gr.update(value=encode_finished_msg, visible=True),
+             )
+         else:
+             loaded_video = None
+             loaded_video_path = None
+
+
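+ # Run moment retrieval and highlight detection for the query, then build the moment
+ # buttons, the saliency line plot, and the highlight-frame gallery.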
+ def predict(textbox, line, gallery):
+     global loaded_video, loaded_video_path
+     if loaded_video is None:
+         raise gr.Error(
+             "Upload the video before pushing the `Retrieve moment & highlight detection` button."
+         )
+     else:
+         prediction = model.predict(textbox, loaded_video)
+
+         mr_results = prediction["pred_relevant_windows"]
+         hl_results = prediction["pred_saliency_scores"]
+
+         buttons = []
+         for i, pred in enumerate(mr_results[:TOPK_MOMENT]):
+             buttons.append(
+                 gr.Button(
+                     value="moment {}: [{}, {}] Score: {}".format(
+                         i + 1, pred[0], pred[1], pred[2]
+                     ),
+                     visible=True,
+                 )
+             )
+
+         # Visualize the HD score
+         seconds = [model._vision_encoder._clip_len * i for i in range(len(hl_results))]
+         hl_data = pd.DataFrame({"second": seconds, "saliency_score": hl_results})
+         min_val, max_val = min(hl_results), max(hl_results) + 1
+         min_x, max_x = min(seconds), max(seconds)
+         line = gr.LinePlot(
+             value=hl_data,
+             x="second",
+             y="saliency_score",
+             visible=True,
+             y_lim=[min_val, max_val],
+             x_lim=[min_x, max_x],
+         )
+
+         # Show highlight frames
+         n_largest_df = hl_data.nlargest(columns="saliency_score", n=TOPK_HIGHLIGHT)
+         highlighted_seconds = n_largest_df.second.tolist()
+         highlighted_scores = n_largest_df.saliency_score.tolist()
+
+         output_image_paths = []
+         for i, (second, score) in enumerate(
+             zip(highlighted_seconds, highlighted_scores)
+         ):
+             output_path = "gradio_demo/highlight_frames/highlight_{}.png".format(i)
+             (
+                 ffmpeg.input(loaded_video_path, ss=second)
+                 .output(output_path, vframes=1, qscale=2)
+                 .global_args("-loglevel", "quiet", "-y")
+                 .run()
+             )
+             output_image_paths.append(
+                 (output_path, "Highlight: {} - score: {:.02f}".format(i + 1, score))
+             )
+         gallery = gr.Gallery(
+             value=output_image_paths,
+             label="gradio",
+             columns=5,
+             show_download_button=True,
+             visible=True,
+         )
+         return buttons + [line, gallery]
+
+
+ def main():
+     title = """# Moment Retrieval & Highlight Detection Demo"""
+
+     with gr.Blocks(theme=gr.themes.Soft()) as demo:
+         gr.Markdown(title)
+
+         with gr.Row():
+             with gr.Column():
+                 with gr.Group():
+                     gr.Markdown("## Model selection")
+                     radio_list = flatten(
+                         [
+                             [
+                                 "{} + {}".format(feature, model_name)
+                                 for model_name in MODEL_NAMES
+                             ]
+                             for feature in FEATURES
+                         ]
+                     )
+                     radio = gr.Radio(
+                         radio_list,
+                         label="models",
+                         value="clip + cg_detr",
+                         info="Which model do you want to use?",
+                     )
+                     load_status_text = gr.Textbox(
+                         label="Model load status", value="Model loaded: clip + cg_detr"
+                     )
+
+                 with gr.Group():
+                     gr.Markdown("## Video and query")
+                     video_input = gr.Video(elem_id="video", height=600)
+                     output = gr.Textbox(label="Video processing progress")
+                     query_input = gr.Textbox(label="query")
+                     button = gr.Button(
+                         "Retrieve moment & highlight detection", variant="primary"
+                     )
+
+             with gr.Column():
+                 with gr.Group():
+                     gr.Markdown("## Retrieved moments")
+
+                     button_1 = gr.Button(
+                         value="moment 1", visible=False, elem_id="result_0"
+                     )
+                     button_2 = gr.Button(
+                         value="moment 2", visible=False, elem_id="result_1"
+                     )
+                     button_3 = gr.Button(
+                         value="moment 3", visible=False, elem_id="result_2"
+                     )
+                     button_4 = gr.Button(
+                         value="moment 4", visible=False, elem_id="result_3"
+                     )
+                     button_5 = gr.Button(
+                         value="moment 5", visible=False, elem_id="result_4"
+                     )
+
+                     button_1.click(None, None, None, js=js_codes[0])
+                     button_2.click(None, None, None, js=js_codes[1])
+                     button_3.click(None, None, None, js=js_codes[2])
+                     button_4.click(None, None, None, js=js_codes[3])
+                     button_5.click(None, None, None, js=js_codes[4])
+
+                 # dummy
+                 with gr.Group():
+                     gr.Markdown("## Saliency score")
+                     line = gr.LinePlot(
+                         value=pd.DataFrame({"x": [], "y": []}),
+                         x="x",
+                         y="y",
+                         visible=False,
+                     )
+                     gr.Markdown("### Highlighted frames")
+                     gallery = gr.Gallery(
+                         value=[], label="highlight", columns=5, visible=False
+                     )
+
+         video_input.change(video_upload, inputs=[video_input], outputs=output)
+         radio.select(
+             model_load,
+             inputs=[radio, video_input],
+             outputs=[load_status_text, output],
+         )
+
+         button.click(
+             predict,
+             inputs=[query_input, line, gallery],
+             outputs=[
+                 button_1,
+                 button_2,
+                 button_3,
+                 button_4,
+                 button_5,
+                 line,
+                 gallery,
+             ],
+         )
+
+     demo.launch()
+
+
+ if __name__ == "__main__":
+     main()
packages.txt ADDED
@@ -0,0 +1 @@
+ ffmpeg
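On Hugging Face Spaces, packages.txt lists Debian packages installed with apt before the app starts; here, ffmpeg backs both the ffmpeg-python calls in predict() and video handling in general.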
requirements.txt ADDED
@@ -0,0 +1,5 @@
+ git+https://github.com/line/lighthouse.git
+ torch==2.1.0
+ torchvision==0.16.0
+ torchaudio==2.1.0
+ torchtext==0.16.0
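To run the demo locally (a reading of the files above, not part of the commit): install the Python dependencies with pip install -r requirements.txt, make sure the ffmpeg binary from packages.txt is available, then run python app.py; the model checkpoints and the SlowFast/PANNs weights are fetched on first launch by load_pretrained_weights().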