File size: 5,380 Bytes
eec1ab5
17d4493
9558bd2
eec1ab5
9558bd2
21fd64b
 
 
 
 
621c8c9
17d4493
 
 
9558bd2
 
17d4493
 
 
9558bd2
 
bb1a1ab
9558bd2
 
bb1a1ab
fd74da5
bb1a1ab
9558bd2
bb1a1ab
17d4493
9558bd2
 
17d4493
 
9558bd2
 
 
 
 
17d4493
9558bd2
 
 
 
 
 
 
 
 
17d4493
 
 
 
 
 
9558bd2
17d4493
 
 
9558bd2
 
 
 
 
 
 
 
 
21fd64b
17d4493
 
 
9558bd2
 
17d4493
9558bd2
17d4493
9558bd2
 
17d4493
 
 
 
9558bd2
 
 
 
17d4493
 
 
 
 
 
 
 
 
9558bd2
 
 
 
 
 
 
 
 
17d4493
 
 
9558bd2
 
17d4493
 
fd74da5
 
9558bd2
 
fd74da5
 
9558bd2
bb1a1ab
 
9558bd2
 
 
 
 
 
17d4493
 
9558bd2
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
import spaces
import gradio as gr
import subprocess

# Install Flash-Attention at startup, before torch/llava are imported below.
# FLASH_ATTENTION_SKIP_CUDA_BUILD=TRUE asks flash-attn's build not to compile the
# CUDA extension locally (presumably so a prebuilt wheel is used — confirm against
# flash-attn's setup.py).
# The variable is merged into the current environment instead of replacing it:
# passing only {"FLASH_ATTENTION_SKIP_CUDA_BUILD": ...} would drop PATH, HOME and
# any active virtualenv, leaving the shell to pick an arbitrary `pip`.
# `sys.executable -m pip` installs into the interpreter that is running this app.
import os
import sys

subprocess.run(
    [sys.executable, "-m", "pip", "install", "flash-attn", "--no-build-isolation"],
    env={**os.environ, "FLASH_ATTENTION_SKIP_CUDA_BUILD": "TRUE"},
    check=False,  # best effort, as before: a failed install does not abort startup
)

import torch
from llava.model.builder import load_pretrained_model
from llava.mm_utils import get_model_name_from_path, process_images, tokenizer_image_token
from llava.constants import IMAGE_TOKEN_INDEX, DEFAULT_IMAGE_TOKEN
from llava.conversation import conv_templates
import copy
from decord import VideoReader, cpu
import numpy as np

# App info
# User-facing Markdown shown at the top of the Gradio page. The emoji were
# previously stored as mis-decoded UTF-8 (e.g. "πŸ“Ή" instead of "📹") and are
# restored here.
title = "# 🙋🏻‍♂️Welcome to 🌟Tonic's 🌋📹LLaVA-Video!"
# NOTE(review): process_video feeds only sampled video frames to the model (no
# audio), so the "transcribes speech" claim below is not backed by an audio input.
description1 = """**🌋📹LLaVA-Video-7B-Qwen2** analyzes visual content and transcribes speech from videos. It supports fine-grained reasoning over video frames using 64 sampled keyframes."""
description2 = """**Max Frames**: 64 · **Languages**: English, Chinese · **Aspect Ratio**: any · **Precision**: bfloat16"""

# Community blurb rendered inside the collapsible "Join Us" accordion.
join_us = """
## Join us :
🌟TeamTonic🌟 is always making cool demos! Join our active builder's 🛠️community 👻 [![Discord](https://img.shields.io/discord/1109943800132010065?label=Discord&logo=discord&style=flat-square)](https://discord.gg/qdfnvSPcqP)
"""

# ---------- Load & Sample Video ----------
def load_video(video_path, max_frames_num=64, fps=1, force_sample=True):
    """Decode a video with decord and sample frames from it.

    Frames are first taken roughly every ``avg_fps / fps`` frames. If that
    yields more than ``max_frames_num`` frames, or ``force_sample`` is true,
    exactly ``max_frames_num`` indices are instead spread evenly from the first
    to the last frame. Indices repeat when the video has fewer frames than that.

    Args:
        video_path: Path to the video file; decoded on the CPU with one thread.
        max_frames_num: Frame cap, and the count used for uniform sampling.
        fps: Target sampling rate (frames per second of video) for the
            non-forced path.
        force_sample: Always use uniform sampling of ``max_frames_num`` frames.

    Returns:
        A tuple ``(frames, frame_time_str, video_time)``:
        - ``frames``: the result of decord's ``get_batch(...).asnumpy()``.
        - ``frame_time_str``: each sampled frame's timestamp, comma-separated,
          e.g. ``"0.00s, 1.23s"``.
        - ``video_time``: the video duration in seconds.

    Raises:
        ValueError: if the video has no frames or reports a non-positive
            average frame rate.
    """
    vr = VideoReader(video_path, ctx=cpu(0), num_threads=1)
    total_frame_num = len(vr)
    avg_fps = vr.get_avg_fps()
    # Without this guard an empty video yields negative indices from linspace
    # and a zero fps divides by zero below.
    if total_frame_num == 0 or avg_fps <= 0:
        raise ValueError(
            f"Cannot sample frames from {video_path!r}: "
            f"{total_frame_num} frames at {avg_fps} fps."
        )
    video_time = total_frame_num / avg_fps
    # round() returns 0 when the video's fps is below half the target rate,
    # and range() rejects a zero step, so step by at least one frame.
    step = max(1, round(avg_fps / fps))

    frame_idx = list(range(0, total_frame_num, step))
    if len(frame_idx) > max_frames_num or force_sample:
        frame_idx = np.linspace(0, total_frame_num - 1, max_frames_num, dtype=int).tolist()

    frame_time_str = ", ".join(f"{i / avg_fps:.2f}s" for i in frame_idx)

    frames = vr.get_batch(frame_idx).asnumpy()
    return frames, frame_time_str, video_time

# ---------- Load LLaVA-Video Model ----------
# Hugging Face Hub checkpoint, and the model name passed to load_pretrained_model with it.
pretrained = "lmms-lab/LLaVA-Video-7B-Qwen2"
model_name = "llava_qwen"
# Device that process_video moves its input tensors to; CPU is only a fallback.
device = "cuda" if torch.cuda.is_available() else "cpu"
# Placement strategy handed to load_pretrained_model (presumably forwarded to
# transformers' device_map — TODO confirm in llava.model.builder).
device_map = "auto"

# Runs once at import time; every request then reuses these module-level objects.
print("Loading model...")
# The second argument (None) and the discarded fourth return value are specific to
# llava.model.builder — NOTE(review): presumably model_base and context length; confirm.
tokenizer, model, image_processor, _ = load_pretrained_model(pretrained, None, model_name, torch_dtype="bfloat16", device_map=device_map)
model.eval()  # switch the model to evaluation mode for inference
print("Model loaded successfully!")

# ---------- Response Formatter ----------
import re

# A decimal timestamp such as "12.50s".
_TIMESTAMP = r"\d+\.\d+s"
# "start - end: description" entries describing visual events.
_ACTION_RE = re.compile(rf"({_TIMESTAMP}\s*-\s*{_TIMESTAMP}:\s*.+)")
# "time: text" entries carrying transcribed speech.
_SPEECH_RE = re.compile(rf"({_TIMESTAMP}:\s*.+)")


def format_response(response: str) -> str:
    """Group a model answer into visual-event and speech sections.

    Args:
        response: Raw decoded model output.

    Returns:
        A Markdown string with a "Visual Events" section (entries like
        ``"0.0s - 1.5s: ..."``) and a "Speech Transcript" section (entries like
        ``"2.0s: ..."``). If neither kind of entry is found, ``response`` is
        returned unchanged.
    """
    actions = _ACTION_RE.findall(response)
    # Remove the action entries before looking for speech. Otherwise the tail
    # of every "a - b: text" entry ("b: text") also fits the speech pattern,
    # and each visual event would be duplicated into the transcript.
    speech = _SPEECH_RE.findall(_ACTION_RE.sub("", response))
    if not actions and not speech:
        return response
    return (
        "**🟢 Visual Events:**\n" + "\n".join(actions)
        + "\n\n**🗣️ Speech Transcript:**\n" + "\n".join(speech)
    )

# ---------- Core Inference ----------
@spaces.GPU
def process_video(video_path, question):
    """Answer ``question`` about the video at ``video_path`` with LLaVA-Video.

    Steps:
    1. Sample 64 frames with ``load_video``.
    2. Prefix the question with the clip's duration and the frame timestamps,
       using the "qwen_1_5" conversation template.
    3. Generate greedily (up to 4096 new tokens).
    4. Return the decoded text after passing it through ``format_response``.
    """
    frame_budget = 64
    frames, sampled_times, duration = load_video(video_path, frame_budget, 1, force_sample=True)
    pixel_values = image_processor.preprocess(frames, return_tensors="pt")["pixel_values"]
    clips = [pixel_values.to(device).bfloat16()]

    # Give the model the timing metadata for the sampled frames.
    timing_note = (
        f"The video is {duration:.2f} seconds long, and {frame_budget} frames were "
        f"uniformly sampled at these times: {sampled_times}. Analyze them."
    )
    user_turn = DEFAULT_IMAGE_TOKEN + f"{timing_note}\n{question}"

    # Deep-copy so the shared template object is never mutated between requests.
    conversation = copy.deepcopy(conv_templates["qwen_1_5"])
    conversation.append_message(conversation.roles[0], user_turn)
    conversation.append_message(conversation.roles[1], None)

    token_ids = tokenizer_image_token(
        conversation.get_prompt(), tokenizer, IMAGE_TOKEN_INDEX, return_tensors="pt"
    )
    input_ids = token_ids.unsqueeze(0).to(device)

    generation_kwargs = dict(
        images=clips,
        modalities=["video"],
        do_sample=False,
        temperature=0,
        max_new_tokens=4096,
    )
    with torch.no_grad():
        generated = model.generate(input_ids, **generation_kwargs)

    decoded = tokenizer.batch_decode(generated, skip_special_tokens=True)
    return format_response(decoded[0].strip())

# ---------- Gradio UI ----------
# Prompt pre-filled in the UI, and used whenever the user leaves the prompt blank.
default_prompt = (
    "Analyze the video frame by frame. For each visible action or change (e.g., motion, expression, object, movement), "
    "output the timestamp and what happens, like '0.0s - 0.1s: man lifts arm'. Also transcribe any spoken dialogue with timestamps in the format '0.0s: speech...'."
)


def gradio_interface(video_file, question):
    """Validate the UI inputs and run inference.

    Args:
        video_file: Value of the video component; ``None`` when nothing was uploaded.
        question: Prompt text from the textbox; may be ``None``, empty, or whitespace.

    Returns:
        An error message when no video was uploaded. Otherwise the formatted
        result of ``process_video``.
    """
    if video_file is None:
        return "❗ Please upload a video."
    # Treat whitespace-only prompts as empty, so they also get the default
    # prompt instead of sending a blank question to the model.
    if not question or not question.strip():
        question = default_prompt
    return process_video(video_file, question)

# Page layout: header and descriptions, a collapsible community blurb, then the
# input column (video, prompt, button) beside the result box.
# The label emoji were previously mis-decoded UTF-8 (e.g. "πŸ“Ή") and are restored.
with gr.Blocks() as demo:
    gr.Markdown(title)
    with gr.Row():
        gr.Markdown(description1)
        gr.Markdown(description2)
    with gr.Accordion("Join Us", open=False):
        gr.Markdown(join_us)

    with gr.Row():
        with gr.Column():
            video_input = gr.Video(label="📹 Upload Your Video")
            question_input = gr.Textbox(label="🙋🏻‍♂️ Your Prompt", value=default_prompt, lines=4)
            submit_button = gr.Button("Analyze with 🌋📹LLaVA-Video")
        # NOTE(review): format_response returns Markdown (**bold** headers), which a
        # plain Textbox displays as literal text; switch components if rendering matters.
        output = gr.Textbox(label="🧠 Result", lines=20)

    # One click runs gradio_interface on (video, prompt) and writes into the result box.
    submit_button.click(fn=gradio_interface, inputs=[video_input, question_input], outputs=output)

if __name__ == "__main__":
    # Serve the UI with server-side rendering disabled; show_error=True surfaces
    # handler exceptions to the user instead of failing silently.
    demo.launch(ssr_mode=False, show_error=True)