import gradio as gr
from transformers import AutoProcessor, Qwen2_5_VLForConditionalGeneration, TextIteratorStreamer
from transformers.image_utils import load_image
from threading import Thread
import time
import torch
import spaces
import cv2
import numpy as np
from PIL import Image
import html  # used to HTML-escape streamed model output before rendering in the chat
def progress_bar_html(label: str) -> str:
"""
Returns an HTML snippet for a thin progress bar with a label.
The progress bar is styled as a dark animated bar.
"""
return f'''
<div style="display: flex; align-items: center;">
<span style="margin-right: 10px; font-size: 14px;">{label}</span>
<div style="width: 110px; height: 5px; background-color: #9370DB; border-radius: 2px; overflow: hidden;">
<div style="width: 100%; height: 100%; background-color: #4B0082; animation: loading 1.5s linear infinite;"></div>
</div>
</div>
<style>
@keyframes loading {{
0% {{ transform: translateX(-100%); }}
100% {{ transform: translateX(100%); }}
}}
</style>
'''
def downsample_video(video_path):
"""
Downsamples the video to 10 evenly spaced frames.
Each frame is converted to a PIL Image along with its timestamp.
"""
vidcap = cv2.VideoCapture(video_path)
total_frames = int(vidcap.get(cv2.CAP_PROP_FRAME_COUNT))
fps = vidcap.get(cv2.CAP_PROP_FPS)
frames = []
if total_frames <= 0 or fps <= 0:
vidcap.release()
return frames
# Sample 10 evenly spaced frames.
frame_indices = np.linspace(0, total_frames - 1, 10, dtype=int)
for i in frame_indices:
vidcap.set(cv2.CAP_PROP_POS_FRAMES, i)
success, image = vidcap.read()
if success:
image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
pil_image = Image.fromarray(image)
timestamp = round(i / fps, 2)
frames.append((pil_image, timestamp))
vidcap.release()
return frames
MODEL_ID = "XiaomiMiMo/MiMo-VL-7B-RL"
# MODEL_ID = "Qwen/Qwen2.5-VL-7B-Instruct"  # alternative checkpoint with the same architecture
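# Load the processor and model once at startup. bfloat16 roughly halves memory
# use relative to float32 and is well supported on recent NVIDIA GPUs.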
processor = AutoProcessor.from_pretrained(MODEL_ID, trust_remote_code=True)
model = Qwen2_5_VLForConditionalGeneration.from_pretrained(
MODEL_ID,
trust_remote_code=True,
torch_dtype=torch.bfloat16
).to("cuda").eval()
print(f"Successfully load the model: {model}")
@spaces.GPU
def model_inference(input_dict, history):
text = input_dict["text"]
files = input_dict["files"]
    images = [load_image(path) for path in files]
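    # Note: the textbox also accepts video uploads, but the code above treats every
    # file as an image. An untested sketch of video support (assuming the model can
    # consume sampled frames as ordinary images) would reuse downsample_video():
    #
    #     videos = [f for f in files if str(f).lower().endswith((".mp4", ".avi", ".mov", ".webm"))]
    #     for video_path in videos:
    #         for frame, timestamp in downsample_video(video_path):
    #             images.append(frame)
    #             text += f"\n[frame at {timestamp}s]"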
if text == "" and not images:
gr.Error("Please input a query and optionally image(s).")
return
if text == "" and images:
gr.Error("Please input a text query along with the image(s).")
return
messages = [
{
"role": "user",
"content": [
*[{"type": "image", "image": image} for image in images],
{"type": "text", "text": text},
],
}
]
prompt = processor.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
inputs = processor(
text=[prompt],
images=images if images else None,
return_tensors="pt",
padding=True,
).to("cuda")
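    # Stream the response: model.generate runs in a background thread while
    # TextIteratorStreamer yields decoded text chunks back to the chat UI.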
streamer = TextIteratorStreamer(processor, skip_prompt=True, skip_special_tokens=True)
generation_kwargs = dict(inputs, streamer=streamer, max_new_tokens=4096)
thread = Thread(target=model.generate, kwargs=generation_kwargs)
thread.start()
buffer = ""
yield progress_bar_html("Processing with MiMo-VL-7B-RL Model")
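    # Escape HTML so raw tags in the model output (e.g. <think> reasoning blocks)
    # are shown literally instead of being interpreted by the chat renderer.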
for new_text in streamer:
escaped_new_text = html.escape(new_text)
buffer += escaped_new_text
time.sleep(0.01)
yield buffer
examples = [
    [{"text": "Describe the image.", "files": ["example_images/document.jpg"]}],
    [{"text": "How many dogs are in the image?", "files": ["example_images/dogs.jpg"]}],
    [{"text": "How many dogs are in the image? Count the dogs by grounding. Put the number in \\boxed{}.", "files": ["example_images/dogs.jpg"]}],
]
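# The example prompts assume the referenced files exist under example_images/
# relative to the working directory.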
demo = gr.ChatInterface(
fn=model_inference,
    description="# **MiMo-VL-7B-RL `@video-infer for video understanding`**",
examples=examples,
fill_height=True,
textbox=gr.MultimodalTextbox(label="Query Input", file_types=["image", "video"], file_count="multiple"),
stop_btn="Stop Generation",
multimodal=True,
cache_examples=False,
)
demo.launch(debug=True)
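# With debug=True, errors and logs are printed to the console; when run locally,
# Gradio serves the demo at http://127.0.0.1:7860 by default.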