Llava-Video / app.py
juvix's picture
Loads and samples video frames with accurate timestamps. Sends a precise multi-task prompt to LLaVA (frame analysis + transcription). Extracts and formats the output cleanly into visual events and speech. Uses a default prompt to auto-analyze the uploaded video.
9558bd2 verified
raw
history blame
5.38 kB
import spaces
import gradio as gr
import subprocess
# Install Flash-Attention at startup without compiling its CUDA kernels.
# The child environment extends os.environ: a bare dict passed as ``env``
# replaces the inherited environment (PATH, HOME, proxy and CUDA settings)
# instead of adding to it.
import os

_flash_attn_install = subprocess.run(
    "pip install flash-attn --no-build-isolation",
    env={**os.environ, "FLASH_ATTENTION_SKIP_CUDA_BUILD": "TRUE"},
    shell=True,
)
if _flash_attn_install.returncode != 0:
    # Best-effort install: report the failure but do not abort startup.
    print(f"Warning: flash-attn install exited with code {_flash_attn_install.returncode}")
import torch
from llava.model.builder import load_pretrained_model
from llava.mm_utils import get_model_name_from_path, process_images, tokenizer_image_token
from llava.constants import IMAGE_TOKEN_INDEX, DEFAULT_IMAGE_TOKEN
from llava.conversation import conv_templates
import copy
from decord import VideoReader, cpu
import numpy as np
# App info
# Markdown snippets shown at the top of the Gradio page. The emoji and the
# "·" separators are stored as real UTF-8 characters (they were previously
# mis-decoded into runs of Thai and Latin-1 glyphs).
title = "# 🙋🏻‍♂️Welcome to 🌟Tonic's 🌋📹LLaVA-Video!"
description1 = """**🌋📹LLaVA-Video-7B-Qwen2** analyzes visual content and transcribes speech from videos. It supports fine-grained reasoning over video frames using 64 sampled keyframes."""
description2 = """**Max Frames**: 64 · **Languages**: English, Chinese · **Aspect Ratio**: any · **Precision**: bfloat16"""
join_us = """
## Join us :
🌟TeamTonic🌟 is always making cool demos! Join our active builder's 🛠️community 👻 [![Discord](https://img.shields.io/discord/1109943800132010065?label=Discord&logo=discord&style=flat-square)](https://discord.gg/qdfnvSPcqP)
"""
# ---------- Load & Sample Video ----------
def load_video(video_path, max_frames_num=64, fps=1, force_sample=True):
    """Decode a video and return sampled frames together with their timestamps.

    Args:
        video_path: Path of the video file handed to ``VideoReader``.
        max_frames_num: Upper bound on the number of frames returned; it is
            also the exact number of indices drawn when uniform sampling is
            used.
        fps: Target sampling rate in frames per second. Only used when
            ``force_sample`` is false.
        force_sample: When true, always draw ``max_frames_num`` indices evenly
            spaced between the first and the last frame.

    Returns:
        A tuple ``(frames, frame_time_str, video_time)``: the selected frames
        converted with ``asnumpy()``, a comma-separated string of the sampled
        timestamps formatted like ``"1.50s"``, and the video duration in
        seconds (frame count divided by the average frame rate).

    Raises:
        ValueError: If ``max_frames_num`` is not positive, if ``fps`` is not
            positive while it is needed, or if the video reports no frames or
            a non-positive average frame rate.
    """
    if max_frames_num <= 0:
        raise ValueError("max_frames_num must be a positive integer")

    vr = VideoReader(video_path, ctx=cpu(0), num_threads=1)
    total_frame_num = len(vr)
    avg_fps = vr.get_avg_fps()
    if total_frame_num == 0 or avg_fps <= 0:
        raise ValueError(
            f"Cannot sample frames from {video_path!r}: "
            f"{total_frame_num} frames at {avg_fps} fps"
        )
    video_time = total_frame_num / avg_fps

    use_uniform = force_sample
    frame_idx = []
    if not use_uniform:
        if fps <= 0:
            raise ValueError("fps must be positive")
        # Clamp to 1: round() yields 0 when the video's frame rate is below
        # half the requested rate, and range() rejects a zero step.
        step = max(1, round(avg_fps / fps))
        frame_idx = list(range(0, total_frame_num, step))
        use_uniform = len(frame_idx) > max_frames_num
    if use_uniform:
        frame_idx = np.linspace(0, total_frame_num - 1, max_frames_num, dtype=int).tolist()

    # Timestamp of each sampled frame, derived from the average frame rate.
    frame_time_str = ", ".join(f"{i / avg_fps:.2f}s" for i in frame_idx)
    frames = vr.get_batch(frame_idx).asnumpy()
    return frames, frame_time_str, video_time
# ---------- Load LLaVA-Video Model ----------
# Checkpoint identifier and loader settings.
pretrained = "lmms-lab/LLaVA-Video-7B-Qwen2"
model_name = "llava_qwen"
device_map = "auto"
# Device that process_video moves its input tensors to.
device = "cuda" if torch.cuda.is_available() else "cpu"

print("Loading model...")
# The second positional argument is left as None
# (presumably the base-model path; confirm against llava.model.builder).
tokenizer, model, image_processor, _ = load_pretrained_model(
    pretrained,
    None,
    model_name,
    torch_dtype="bfloat16",
    device_map=device_map,
)
model.eval()
print("Model loaded successfully!")
# ---------- Response Formatter ----------
import re
# One timestamped entry: an optional "<start>s - " prefix followed by
# "<time>s: text". The prefix (group 1) distinguishes a visual event
# ("0.0s - 0.1s: ...") from a speech line ("0.0s: ...").
_TIMESTAMPED_ENTRY = re.compile(r"(?:(\d+\.\d+s)\s*-\s*)?\d+\.\d+s:\s*.+")
_NO_ENTRIES = "(none detected)"


def format_response(response: str):
    """Split a model answer into visual events and a speech transcript.

    Each timestamped entry is classified exactly once: entries with a time
    range ("0.0s - 0.1s: ...") are visual events, entries with a single
    timestamp ("0.0s: ...") are speech. A single scan is used so that the
    tail of a range entry ("0.1s: ...") is not counted again as speech.

    Args:
        response: Raw text produced by the model.

    Returns:
        A two-section string ("Visual Events" then "Speech Transcript"), with
        a placeholder under a section that has no entries. If no timestamped
        entry is found at all, ``response`` is returned unchanged.
    """
    actions = []
    speech = []
    for entry in _TIMESTAMPED_ENTRY.finditer(response):
        bucket = actions if entry.group(1) else speech
        bucket.append(entry.group(0))

    if not actions and not speech:
        return response

    visual_section = "\n".join(actions) if actions else _NO_ENTRIES
    speech_section = "\n".join(speech) if speech else _NO_ENTRIES
    return (
        "**🟢 Visual Events:**\n"
        + visual_section
        + "\n\n**🗣️ Speech Transcript:**\n"
        + speech_section
    )
# ---------- Core Inference ----------
@spaces.GPU
def process_video(video_path, question):
    """Run LLaVA-Video on one video and return the formatted answer.

    The video is reduced to 64 uniformly sampled frames, the user's question
    is prefixed with the video duration and the sampled timestamps, and the
    decoded generation is passed through ``format_response``. This function
    hands only the sampled frames to the model.

    Args:
        video_path: Path of the video to analyze.
        question: Prompt text appended after the timing information.

    Returns:
        The value returned by ``format_response`` for the decoded output.
    """
    max_frames_num = 64
    frames, frame_time, video_time = load_video(video_path, max_frames_num, fps=1, force_sample=True)
    pixel_values = image_processor.preprocess(frames, return_tensors="pt")["pixel_values"]
    video_batch = [pixel_values.to(device).bfloat16()]

    # Tell the model how long the clip is and where the frames were taken.
    time_instruction = f"The video is {video_time:.2f} seconds long, and {max_frames_num} frames were uniformly sampled at these times: {frame_time}. Analyze them."
    full_question = f"{DEFAULT_IMAGE_TOKEN}{time_instruction}\n{question}"

    conversation = copy.deepcopy(conv_templates["qwen_1_5"])
    conversation.append_message(conversation.roles[0], full_question)
    # The second turn is appended with no content before building the prompt.
    conversation.append_message(conversation.roles[1], None)
    prompt = conversation.get_prompt()

    token_ids = tokenizer_image_token(prompt, tokenizer, IMAGE_TOKEN_INDEX, return_tensors="pt")
    input_ids = token_ids.unsqueeze(0).to(device)

    generation_kwargs = {
        "images": video_batch,
        "modalities": ["video"],
        "do_sample": False,
        "temperature": 0,
        "max_new_tokens": 4096,
    }
    with torch.no_grad():
        generated = model.generate(input_ids, **generation_kwargs)

    decoded = tokenizer.batch_decode(generated, skip_special_tokens=True)
    return format_response(decoded[0].strip())
# ---------- Gradio UI ----------
# Prompt pre-filled in the prompt box and used by gradio_interface when the
# user submits no prompt. Its two example formats ("0.0s - 0.1s: ..." and
# "0.0s: ...") are the ones format_response looks for.
default_prompt = " ".join(
    (
        "Analyze the video frame by frame.",
        "For each visible action or change (e.g., motion, expression, object, movement),",
        "output the timestamp and what happens, like '0.0s - 0.1s: man lifts arm'.",
        "Also transcribe any spoken dialogue with timestamps in the format '0.0s: speech...'.",
    )
)
def gradio_interface(video_file, question):
    """Validate the UI inputs and run the video analysis.

    Args:
        video_file: Value of the video input; ``None`` when nothing was
            uploaded.
        question: Prompt typed by the user. ``None``, an empty string, or a
            whitespace-only string falls back to ``default_prompt``.

    Returns:
        A message asking for an upload when ``video_file`` is ``None``,
        otherwise whatever ``process_video`` returns.
    """
    if video_file is None:
        return "❗ Please upload a video."
    # A prompt made only of whitespace carries no instruction for the model.
    has_prompt = bool(question and question.strip())
    return process_video(video_file, question if has_prompt else default_prompt)
# Page layout: header text, a collapsible community section, then the inputs
# and the result box. Emoji in the labels are real UTF-8 characters.
# NOTE(review): the original indentation was lost, so the nesting below is
# reconstructed (result box in the same row as the input column, click handler
# registered inside the Blocks context) — confirm against the intended layout.
with gr.Blocks() as demo:
    gr.Markdown(title)
    with gr.Row():
        gr.Markdown(description1)
        gr.Markdown(description2)
    with gr.Accordion("Join Us", open=False):
        gr.Markdown(join_us)
    with gr.Row():
        with gr.Column():
            video_input = gr.Video(label="📹 Upload Your Video")
            question_input = gr.Textbox(label="🙋🏻‍♂️ Your Prompt", value=default_prompt, lines=4)
            submit_button = gr.Button("Analyze with 🌋📹LLaVA-Video")
        output = gr.Textbox(label="🧠 Result", lines=20)
    # Send the uploaded video and the prompt to gradio_interface on click.
    submit_button.click(fn=gradio_interface, inputs=[video_input, question_input], outputs=output)
def _launch_demo():
    """Start the Gradio app, passing show_error=True and ssr_mode=False to launch()."""
    demo.launch(show_error=True, ssr_mode=False)


# Only start the server when this file is executed as a script.
if __name__ == "__main__":
    _launch_demo()