""" | |
Copyright $today.year LY Corporation | |
LY Corporation licenses this file to you under the Apache License, | |
version 2.0 (the "License"); you may not use this file except in compliance | |
with the License. You may obtain a copy of the License at: | |
https://www.apache.org/licenses/LICENSE-2.0 | |
Unless required by applicable law or agreed to in writing, software | |
distributed under the License is distributed on an "AS IS" BASIS, WITHOUT | |
WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the | |
License for the specific language governing permissions and limitations | |
under the License. | |
""" | |
import os
import subprocess

import ffmpeg
import gradio as gr
import pandas as pd
import torch
from lighthouse.models import *
from tqdm import tqdm

# use GPU if available
device = "cuda" if torch.cuda.is_available() else "cpu"

MODEL_NAMES = ["cg_detr", "moment_detr", "eatr", "qd_detr", "tr_detr", "uvcom"]
FEATURES = ["clip"]
TOPK_MOMENT = 5
TOPK_HIGHLIGHT = 5
""" | |
Helper functions | |
""" | |
def load_pretrained_weights():
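    """Download the pretrained QVHighlights checkpoints from Zenodo into weights/.

    Checkpoints that already exist locally are skipped, so re-running is cheap.
    Returns the list of checkpoint URLs.
    """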
    file_urls = []
    for model_name in MODEL_NAMES:
        for feature in FEATURES:
            file_urls.append(
                "https://zenodo.org/records/13960580/files/{}_{}_qvhighlight.ckpt".format(
                    feature, model_name
                )
            )
    for file_url in tqdm(file_urls):
        if not os.path.exists("weights/" + os.path.basename(file_url)):
            command = "wget -P weights/ {}".format(file_url)
            subprocess.run(command, shell=True)
    return file_urls


def flatten(array2d):
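    """Flatten a list of lists into a single flat list."""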
    list1d = []
    for elem in array2d:
        list1d += elem
    return list1d
""" | |
Model initialization | |
""" | |
load_pretrained_weights() | |
model = CGDETRPredictor( | |
"weights/clip_cg_detr_qvhighlight.ckpt", | |
device=device, | |
feature_name="clip", | |
slowfast_path=None, | |
pann_path=None, | |
) | |
loaded_video = None | |
loaded_video_path = None | |
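
# Client-side JS for the moment buttons: each snippet reads the button label
# "moment i: [start, end] Score: s", parses out the [start, end] window, then
# seeks the video player to the start time and plays it.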
js_codes = [
    """() => {{
    let moment_text = document.getElementById('result_{}').textContent;
    var replaced_text = moment_text.replace(/moment..../, '').replace(/\\ Score.*/, '');
    let start_end = JSON.parse(replaced_text);
    document.getElementsByTagName("video")[0].currentTime = start_end[0];
    document.getElementsByTagName("video")[0].play();
}}""".format(i)
    for i in range(TOPK_MOMENT)
]
""" | |
Gradio functions | |
""" | |
def video_upload(video):
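    """Encode an uploaded (or removed) video, yielding progress messages."""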
    global loaded_video, loaded_video_path
    if video is None:
        loaded_video = None
        loaded_video_path = None
        yield gr.update(value="Removed the video", visible=True)
    else:
        yield gr.update(
            value="Processing the video. This may take a minute...", visible=True
        )
        loaded_video = model.encode_video(video)
        loaded_video_path = video
        yield gr.update(value="Finished video processing!", visible=True)


def model_load(radio, video):
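    """Swap in the model selected on the radio and re-encode the current video.

    Yields paired status updates for the model-load and video-progress textboxes.
    """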
    global model, loaded_video, loaded_video_path
    if radio is not None:
        loading_msg = "Loading the new model. This may take a minute..."
        yield (
            gr.update(value=loading_msg, visible=True),
            gr.update(value=loading_msg, visible=True),
        )
        feature, model_name = radio.split("+")
        feature, model_name = feature.strip(), model_name.strip()
        if model_name == "moment_detr":
            model_class = MomentDETRPredictor
        elif model_name == "qd_detr":
            model_class = QDDETRPredictor
        elif model_name == "eatr":
            model_class = EaTRPredictor
        elif model_name == "tr_detr":
            model_class = TRDETRPredictor
        elif model_name == "uvcom":
            model_class = UVCOMPredictor
        elif model_name == "cg_detr":
            model_class = CGDETRPredictor
        else:
            raise gr.Error("Select one of the available models")
        model = model_class(
            "weights/{}_{}_qvhighlight.ckpt".format(feature, model_name),
            device=device,
            feature_name=feature,
        )
        load_finished_msg = "Model loaded: {}".format(radio)
        encode_process_msg = (
            "Processing the video. This may take a minute..."
            if video is not None
            else ""
        )
        yield (
            gr.update(value=load_finished_msg, visible=True),
            gr.update(value=encode_process_msg, visible=True),
        )
        if video is not None:
            loaded_video = model.encode_video(video)
            loaded_video_path = video
            encode_finished_msg = "Finished video processing!"
            yield (
                gr.update(value=load_finished_msg, visible=True),
                gr.update(value=encode_finished_msg, visible=True),
            )
        else:
            loaded_video = None
            loaded_video_path = None


def predict(textbox, line, gallery):
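    """Run moment retrieval and highlight detection for the given text query.

    Returns the top-K moment buttons, a saliency-score line plot, and a
    gallery of the highest-scoring frames extracted with ffmpeg.
    """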
    global loaded_video, loaded_video_path
    if loaded_video is None:
        raise gr.Error(
            "Upload a video before clicking the `Retrieve moment & highlight detection` button."
        )
    else:
        prediction = model.predict(textbox, loaded_video)
        mr_results = prediction["pred_relevant_windows"]
        hl_results = prediction["pred_saliency_scores"]

        # Top-K moments as clickable [start, end] buttons
        buttons = []
        for i, pred in enumerate(mr_results[:TOPK_MOMENT]):
            buttons.append(
                gr.Button(
                    value="moment {}: [{}, {}] Score: {}".format(
                        i + 1, pred[0], pred[1], pred[2]
                    ),
                    visible=True,
                )
            )

        # Visualize the HD score
        seconds = [model._vision_encoder._clip_len * i for i in range(len(hl_results))]
        hl_data = pd.DataFrame({"second": seconds, "saliency_score": hl_results})
        min_val, max_val = min(hl_results), max(hl_results) + 1
        min_x, max_x = min(seconds), max(seconds)
        line = gr.LinePlot(
            value=hl_data,
            x="second",
            y="saliency_score",
            visible=True,
            y_lim=[min_val, max_val],
            x_lim=[min_x, max_x],
        )

        # Show highlight frames
        os.makedirs("highlight_frames", exist_ok=True)  # ffmpeg does not create it
        n_largest_df = hl_data.nlargest(columns="saliency_score", n=TOPK_HIGHLIGHT)
        highlighted_seconds = n_largest_df.second.tolist()
        highlighted_scores = n_largest_df.saliency_score.tolist()
        output_image_paths = []
        for i, (second, score) in enumerate(
            zip(highlighted_seconds, highlighted_scores)
        ):
            output_path = "highlight_frames/highlight_{}.png".format(i)
            (
                ffmpeg.input(loaded_video_path, ss=second)
                .output(output_path, vframes=1, qscale=2)
                .global_args("-loglevel", "quiet", "-y")
                .run()
            )
            output_image_paths.append(
                (output_path, "Highlight: {} - score: {:.02f}".format(i + 1, score))
            )
        gallery = gr.Gallery(
            value=output_image_paths,
            label="gradio",
            columns=5,
            show_download_button=True,
            visible=True,
        )
        return buttons + [line, gallery]


def main():
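    """Build the Gradio UI, wire up the event handlers, and launch the demo."""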
    title = """# Moment Retrieval & Highlight Detection Demo"""
    with gr.Blocks(theme=gr.themes.Soft()) as demo:
        gr.Markdown(title)
        with gr.Row():
            with gr.Column():
                with gr.Group():
                    gr.Markdown("## Model selection")
                    radio_list = flatten(
                        [
                            [
                                "{} + {}".format(feature, model_name)
                                for model_name in MODEL_NAMES
                            ]
                            for feature in FEATURES
                        ]
                    )
                    radio = gr.Radio(
                        radio_list,
                        label="models",
                        value="clip + cg_detr",
                        info="Which model do you want to use? More models are available in the original repository. Please refer to https://github.com/line/lighthouse for more details.",
                    )
                    load_status_text = gr.Textbox(
                        label="Model load status", value="Model loaded: clip + cg_detr"
                    )
                with gr.Group():
                    gr.Markdown("## Video and query")
                    video_input = gr.Video(elem_id="video", height=600)
                    output = gr.Textbox(label="Video processing progress")
                    query_input = gr.Textbox(label="query")
                    button = gr.Button(
                        "Retrieve moment & highlight detection", variant="primary"
                    )
            with gr.Column():
                with gr.Group():
                    gr.Markdown("## Retrieved moments")
                    button_1 = gr.Button(
                        value="moment 1", visible=False, elem_id="result_0"
                    )
                    button_2 = gr.Button(
                        value="moment 2", visible=False, elem_id="result_1"
                    )
                    button_3 = gr.Button(
                        value="moment 3", visible=False, elem_id="result_2"
                    )
                    button_4 = gr.Button(
                        value="moment 4", visible=False, elem_id="result_3"
                    )
                    button_5 = gr.Button(
                        value="moment 5", visible=False, elem_id="result_4"
                    )
                    # clicking a moment button seeks the video to that moment
                    button_1.click(None, None, None, js=js_codes[0])
                    button_2.click(None, None, None, js=js_codes[1])
                    button_3.click(None, None, None, js=js_codes[2])
                    button_4.click(None, None, None, js=js_codes[3])
                    button_5.click(None, None, None, js=js_codes[4])

                # dummy placeholders; predict() replaces them with visible outputs
                with gr.Group():
                    gr.Markdown("## Saliency score")
                    line = gr.LinePlot(
                        value=pd.DataFrame({"x": [], "y": []}),
                        x="x",
                        y="y",
                        visible=False,
                    )
                    gr.Markdown("### Highlighted frames")
                    gallery = gr.Gallery(
                        value=[], label="highlight", columns=5, visible=False
                    )

        video_input.change(video_upload, inputs=[video_input], outputs=output)
        radio.select(
            model_load,
            inputs=[radio, video_input],
            outputs=[load_status_text, output],
        )
        button.click(
            predict,
            inputs=[query_input, line, gallery],
            outputs=[
                button_1,
                button_2,
                button_3,
                button_4,
                button_5,
                line,
                gallery,
            ],
        )

    demo.launch(share=True, server_name="0.0.0.0")


if __name__ == "__main__":
    main()