Loads and samples video frames with accurate timestamps

#2
by juvix - opened
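The updated `load_video` derives each sampled frame's timestamp from the source frame rate, so the prompt can state exactly when every frame occurs. Below is a minimal standalone sketch of that sampling logic, assuming only `decord` and `numpy` are installed; `sample_frames` is an illustrative name and `example.mp4` a placeholder path, neither is part of the diff.

from decord import VideoReader, cpu
import numpy as np

def sample_frames(video_path, max_frames_num=64, fps=1, force_sample=True):
    # Open the clip and compute its duration from the container's average fps.
    vr = VideoReader(video_path, ctx=cpu(0), num_threads=1)
    avg_fps = vr.get_avg_fps()
    video_time = len(vr) / avg_fps
    # Take roughly one frame per second first...
    frame_idx = list(range(0, len(vr), round(avg_fps / fps)))
    # ...then fall back to exactly max_frames_num uniformly spaced frames.
    if len(frame_idx) > max_frames_num or force_sample:
        frame_idx = np.linspace(0, len(vr) - 1, max_frames_num, dtype=int).tolist()
    # Timestamps are derived from the source fps, so they match playback time.
    frame_times = [i / avg_fps for i in frame_idx]
    return vr.get_batch(frame_idx).asnumpy(), frame_times, video_time

frames, times, duration = sample_frames("example.mp4")
print(frames.shape, f"{duration:.2f}s", [f"{t:.2f}s" for t in times[:3]])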
Files changed (1)
  1. app.py +65 -80
app.py CHANGED
@@ -1,105 +1,90 @@
 import spaces
 import gradio as gr
+import subprocess
 
-import subprocess # 🥲
+# Install Flash-Attention safely
 subprocess.run(
     "pip install flash-attn --no-build-isolation",
     env={"FLASH_ATTENTION_SKIP_CUDA_BUILD": "TRUE"},
     shell=True,
 )
-# subprocess.run(
-#     "pip install git+https://github.com/LLaVA-VL/LLaVA-NeXT.git",
-#     shell=True,
-# )
 
 import torch
 from llava.model.builder import load_pretrained_model
 from llava.mm_utils import get_model_name_from_path, process_images, tokenizer_image_token
-from llava.constants import IMAGE_TOKEN_INDEX, DEFAULT_IMAGE_TOKEN, DEFAULT_IM_START_TOKEN, DEFAULT_IM_END_TOKEN, IGNORE_INDEX
-from llava.conversation import conv_templates, SeparatorStyle
+from llava.constants import IMAGE_TOKEN_INDEX, DEFAULT_IMAGE_TOKEN
+from llava.conversation import conv_templates
 import copy
-import warnings
 from decord import VideoReader, cpu
 import numpy as np
-import tempfile
-import os
-import shutil
-#warnings.filterwarnings("ignore")
+
+# App info
 title = "# 🙋🏻‍♂️Welcome to 🌟Tonic's 🌋📹LLaVA-Video!"
-description1 ="""The **🌋📹LLaVA-Video-7B-Qwen2** is a 7B parameter model trained on the 🌋📹LLaVA-Video-178K dataset and the LLaVA-OneVision dataset. It is [based on the **Qwen2 language model**](https://huggingface.co/collections/Qwen/qwen2-6659360b33528ced941e557f), supporting a context window of up to 32K tokens. The model can process and interact with images, multi-images, and videos, with specific optimizations for video analysis.
-This model leverages the **SO400M vision backbone** for visual input and Qwen2 for language processing, making it highly efficient in multi-modal reasoning, including visual and video-based tasks.
-🌋📹LLaVA-Video has larger variants of [32B](https://huggingface.co/lmms-lab/LLaVA-NeXT-Video-32B-Qwen) and [72B](https://huggingface.co/lmms-lab/LLaVA-Video-72B-Qwen2) and with a [variant](https://huggingface.co/lmms-lab/LLaVA-Video-7B-Qwen2-Video-Only) only trained on the new synthetic data
-For further details, please visit the [Project Page](https://github.com/LLaVA-VL/LLaVA-NeXT) or check out the corresponding [research paper](https://arxiv.org/abs/2410.02713).
-- **Architecture**: `LlavaQwenForCausalLM`
-- **Attention Heads**: 28
-- **Hidden Layers**: 28
-- **Hidden Size**: 3584
-"""
-description2 ="""
-- **Intermediate Size**: 18944
-- **Max Frames Supported**: 64
-- **Languages Supported**: English, Chinese
-- **Image Aspect Ratio**: `anyres_max_9`
-- **Image Resolution**: Various grid resolutions
-- **Max Position Embeddings**: 32,768
-- **Vocab Size**: 152,064
-- **Model Precision**: bfloat16
-- **Hardware Used for Training**: 256 * Nvidia Tesla A100 GPUs
-"""
+description1 ="""**🌋📹LLaVA-Video-7B-Qwen2** analyzes visual content and transcribes speech from videos. It supports fine-grained reasoning over video frames using 64 sampled keyframes."""
+description2 ="""**Max Frames**: 64 · **Languages**: English, Chinese · **Aspect Ratio**: any · **Precision**: bfloat16"""
 
 join_us = """
 ## Join us :
-🌟TeamTonic🌟 is always making cool demos! Join our active builder's 🛠️community 👻 [![Join us on Discord](https://img.shields.io/discord/1109943800132010065?label=Discord&logo=discord&style=flat-square)](https://discord.gg/qdfnvSPcqP) On 🤗Huggingface:[MultiTransformer](https://huggingface.co/MultiTransformer) On 🌐Github: [Tonic-AI](https://github.com/tonic-ai) & contribute to🌟 [Build Tonic](https://git.tonic-ai.com/contribute)🤗Big thanks to Yuvi Sharma and all the folks at huggingface for the community grant 🤗
+🌟TeamTonic🌟 is always making cool demos! Join our active builder's 🛠️community 👻 [![Discord](https://img.shields.io/discord/1109943800132010065?label=Discord&logo=discord&style=flat-square)](https://discord.gg/qdfnvSPcqP)
 """
 
-def load_video(video_path, max_frames_num, fps=1, force_sample=False):
-    if max_frames_num == 0:
-        return np.zeros((1, 336, 336, 3))
+# ---------- Load & Sample Video ----------
+def load_video(video_path, max_frames_num=64, fps=1, force_sample=True):
     vr = VideoReader(video_path, ctx=cpu(0), num_threads=1)
     total_frame_num = len(vr)
-    video_time = total_frame_num / vr.get_avg_fps()
-    fps = round(vr.get_avg_fps()/fps)
-    frame_idx = [i for i in range(0, len(vr), fps)]
-    frame_time = [i/fps for i in frame_idx]
+    avg_fps = vr.get_avg_fps()
+    video_time = total_frame_num / avg_fps
+    step = round(avg_fps / fps)
+
+    frame_idx = list(range(0, len(vr), step))
     if len(frame_idx) > max_frames_num or force_sample:
-        sample_fps = max_frames_num
-        uniform_sampled_frames = np.linspace(0, total_frame_num - 1, sample_fps, dtype=int)
-        frame_idx = uniform_sampled_frames.tolist()
-        frame_time = [i/vr.get_avg_fps() for i in frame_idx]
-    frame_time = ",".join([f"{i:.2f}s" for i in frame_time])
-    spare_frames = vr.get_batch(frame_idx).asnumpy()
-    return spare_frames, frame_time, video_time
-
-# Load the model
+        frame_idx = np.linspace(0, total_frame_num - 1, max_frames_num, dtype=int).tolist()
+
+    frame_time = [i / avg_fps for i in frame_idx]
+    frame_time_str = ", ".join([f"{t:.2f}s" for t in frame_time])
+
+    frames = vr.get_batch(frame_idx).asnumpy()
+    return frames, frame_time_str, video_time
+
+# ---------- Load LLaVA-Video Model ----------
 pretrained = "lmms-lab/LLaVA-Video-7B-Qwen2"
 model_name = "llava_qwen"
 device = "cuda" if torch.cuda.is_available() else "cpu"
 device_map = "auto"
 
 print("Loading model...")
-tokenizer, model, image_processor, max_length = load_pretrained_model(pretrained, None, model_name, torch_dtype="bfloat16", device_map=device_map)
+tokenizer, model, image_processor, _ = load_pretrained_model(pretrained, None, model_name, torch_dtype="bfloat16", device_map=device_map)
 model.eval()
 print("Model loaded successfully!")
 
+# ---------- Response Formatter ----------
+import re
+def format_response(response: str):
+    actions = re.findall(r"(\d+\.\d+s\s*-\s*\d+\.\d+s:\s*.+)", response)
+    speech = re.findall(r"(\d+\.\d+s:\s*.+)", response)
+    formatted = "**🟢 Visual Events:**\n" + "\n".join(actions) + "\n\n**🗣️ Speech Transcript:**\n" + "\n".join(speech)
+    return formatted if actions or speech else response
+
+# ---------- Core Inference ----------
 @spaces.GPU
 def process_video(video_path, question):
     max_frames_num = 64
     video, frame_time, video_time = load_video(video_path, max_frames_num, 1, force_sample=True)
-    video = image_processor.preprocess(video, return_tensors="pt")["pixel_values"].to(device).bfloat16()
-    video = [video]
+    video_tensor = image_processor.preprocess(video, return_tensors="pt")["pixel_values"].to(device).bfloat16()
+    video = [video_tensor]
 
+    # Add timing metadata to prompt
     conv_template = "qwen_1_5"
-    time_instruction = f"The video lasts for {video_time:.2f} seconds, and {len(video[0])} frames are uniformly sampled from it. These frames are located at {frame_time}. Please answer the following questions related to this video."
-
+    time_instruction = f"The video is {video_time:.2f} seconds long, and {max_frames_num} frames were uniformly sampled at these times: {frame_time}. Analyze them."
+
     full_question = DEFAULT_IMAGE_TOKEN + f"{time_instruction}\n{question}"
-
     conv = copy.deepcopy(conv_templates[conv_template])
     conv.append_message(conv.roles[0], full_question)
     conv.append_message(conv.roles[1], None)
-    prompt_question = conv.get_prompt()
-
-    input_ids = tokenizer_image_token(prompt_question, tokenizer, IMAGE_TOKEN_INDEX, return_tensors="pt").unsqueeze(0).to(device)
-
+    prompt = conv.get_prompt()
+
+    input_ids = tokenizer_image_token(prompt, tokenizer, IMAGE_TOKEN_INDEX, return_tensors="pt").unsqueeze(0).to(device)
+
     with torch.no_grad():
         output = model.generate(
             input_ids,
@@ -109,37 +94,37 @@ def process_video(video_path, question):
             temperature=0,
            max_new_tokens=4096,
         )
-
-    response = tokenizer.batch_decode(output, skip_special_tokens=True)[0].strip()
-    return response
+
+    raw_output = tokenizer.batch_decode(output, skip_special_tokens=True)[0].strip()
+    return format_response(raw_output)
+
+# ---------- Gradio UI ----------
+default_prompt = (
+    "Analyze the video frame by frame. For each visible action or change (e.g., motion, expression, object, movement), "
+    "output the timestamp and what happens, like '0.0s - 0.1s: man lifts arm'. Also transcribe any spoken dialogue with timestamps in the format '0.0s: speech...'."
+)
 
 def gradio_interface(video_file, question):
     if video_file is None:
-        return "Please upload a video file."
-    response = process_video(video_file, question)
-    return response
+        return "❗ Please upload a video."
+    return process_video(video_file, question or default_prompt)
 
 with gr.Blocks() as demo:
     gr.Markdown(title)
     with gr.Row():
-        with gr.Group():
-            gr.Markdown(description1)
-        with gr.Group():
-            gr.Markdown(description2)
+        gr.Markdown(description1)
+        gr.Markdown(description2)
     with gr.Accordion("Join Us", open=False):
         gr.Markdown(join_us)
+
     with gr.Row():
         with gr.Column():
-            video_input = gr.Video()
-            question_input = gr.Textbox(label="🙋🏻‍♂️User Question", placeholder="Ask a question about the video...")
-            submit_button = gr.Button("Ask🌋📹LLaVA-Video")
-            output = gr.Textbox(label="🌋📹LLaVA-Video")
-
-    submit_button.click(
-        fn=gradio_interface,
-        inputs=[video_input, question_input],
-        outputs=output
-    )
+            video_input = gr.Video(label="📹 Upload Your Video")
+            question_input = gr.Textbox(label="🙋🏻‍♂️ Your Prompt", value=default_prompt, lines=4)
+            submit_button = gr.Button("Analyze with 🌋📹LLaVA-Video")
+            output = gr.Textbox(label="🧠 Result", lines=20)
+
+    submit_button.click(fn=gradio_interface, inputs=[video_input, question_input], outputs=output)
 
 if __name__ == "__main__":
-    demo.launch(show_error=True, ssr_mode = False)
+    demo.launch(show_error=True, ssr_mode=False)
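For quick reference, this is how the new `format_response` groups a reply that follows `default_prompt`'s timestamp format. The regexes are copied from the diff above; the sample reply is hypothetical.

import re

def format_response(response: str):
    # Same grouping logic as in the new app.py above.
    actions = re.findall(r"(\d+\.\d+s\s*-\s*\d+\.\d+s:\s*.+)", response)
    speech = re.findall(r"(\d+\.\d+s:\s*.+)", response)
    formatted = "**🟢 Visual Events:**\n" + "\n".join(actions) + "\n\n**🗣️ Speech Transcript:**\n" + "\n".join(speech)
    return formatted if actions or speech else response

# Hypothetical model reply in the format requested by default_prompt.
sample = "0.0s - 1.5s: man lifts arm\n2.0s: hello there"
print(format_response(sample))

Note that the speech pattern also matches the tail of a range line (here it captures "1.5s: man lifts arm" in addition to "2.0s: hello there"), so transcript entries can duplicate visual events.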