import gradio as gr
import cv2
import numpy as np
import tempfile
from ultralytics import YOLO
from transformers import Qwen2_5_VLForConditionalGeneration, AutoProcessor
from PIL import Image
import torch
# Load the Latex2Layout model for layout detection
latex2layout_model_path = "latex2layout_object_detection_yolov8.pt"
latex2layout_model = YOLO(latex2layout_model_path)
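
# Note: the YOLO checkpoint above is expected to sit next to app.py. A minimal sketch for
# fetching it from the Hub instead (the repo_id below is an assumption, not confirmed here):
#
#     from huggingface_hub import hf_hub_download
#     latex2layout_model_path = hf_hub_download(
#         repo_id="ChaseHan/Latex2Layout",  # hypothetical repo id -- replace with the real one
#         filename="latex2layout_object_detection_yolov8.pt",
#     )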

# Download and load the Qwen2.5-VL-3B-Instruct model and its processor
qwen_model_path = "Qwen/Qwen2.5-VL-3B-Instruct"
qwen_model = Qwen2_5_VLForConditionalGeneration.from_pretrained(
    qwen_model_path, torch_dtype="auto", device_map="auto"
)
qwen_processor = AutoProcessor.from_pretrained(qwen_model_path)

def detect_layout(image):
    """
    Detect layout elements in the image using the Latex2Layout model.

    Args:
        image: The uploaded image as an RGB numpy array (as provided by Gradio).

    Returns:
        layout_description: Textual description of the detected layout elements.
    """
    if image is None:
        return "Error: No image provided."
    # Run layout detection; Ultralytics expects BGR channel order for numpy inputs,
    # so convert from Gradio's RGB before inference
    results = latex2layout_model(cv2.cvtColor(image, cv2.COLOR_RGB2BGR))
    result = results[0]
    layout_description = []
    for box in result.boxes:
        x1, y1, x2, y2 = map(int, box.xyxy[0].cpu().numpy())
        cls_id = int(box.cls[0])
        cls_name = result.names[cls_id]
        layout_description.append(f"{cls_name} at position ({x1}, {y1}, {x2}, {y2})")
    return ", ".join(layout_description) if layout_description else "No elements detected."

def process_image_and_question(image, question):
    """
    Process the image with Latex2Layout and answer the question using Qwen2.5-VL.

    Args:
        image: The uploaded image as an RGB numpy array (as provided by Gradio).
        question: The user's question (string).

    Returns:
        annotated_image: Image with detection boxes drawn on it.
        response: Answer from Qwen2.5-VL.
    """
    if image is None or not question:
        return None, "Error: Please upload an image and provide a question."
    # Gradio supplies RGB arrays, so the image can be wrapped directly as a PIL image
    image_pil = Image.fromarray(image)
    # Detect layout using Latex2Layout
    layout_description = detect_layout(image)
    # Prepare annotated image (kept in RGB, which is what the Gradio output component expects)
    annotated_image = image.copy()
    results = latex2layout_model(cv2.cvtColor(image, cv2.COLOR_RGB2BGR))[0]
    for box in results.boxes:
        x1, y1, x2, y2 = map(int, box.xyxy[0].cpu().numpy())
        conf = float(box.conf[0])
        cls_id = int(box.cls[0])
        cls_name = results.names[cls_id]
        # Random color per box, plus a filled label background above the box
        color = tuple(np.random.randint(0, 255, 3).tolist())
        cv2.rectangle(annotated_image, (x1, y1), (x2, y2), color, 2)
        label = f'{cls_name} {conf:.2f}'
        (label_width, label_height), _ = cv2.getTextSize(label, cv2.FONT_HERSHEY_SIMPLEX, 0.5, 1)
        cv2.rectangle(annotated_image, (x1, y1 - label_height - 5), (x1 + label_width, y1), color, -1)
        cv2.putText(annotated_image, label, (x1, y1 - 5), cv2.FONT_HERSHEY_SIMPLEX, 0.5, (255, 255, 255), 1)
    # Prepare input for Qwen2.5-VL: the detected layout is prepended to the user's question
    input_text = f"Layout: {layout_description}\nQuestion: {question}"
    messages = [
        {
            "role": "user",
            "content": [
                {"type": "image", "image": image_pil},
                {"type": "text", "text": input_text}
            ]
        }
    ]
    # Build the prompt with the chat template, then tokenize text and image together
    prompt = qwen_processor.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
    model_inputs = qwen_processor(
        text=[prompt], images=[image_pil], padding=True, return_tensors="pt"
    ).to(qwen_model.device)
    with torch.no_grad():
        output_ids = qwen_model.generate(**model_inputs, max_new_tokens=100)
    # Decode only the newly generated tokens, skipping the prompt
    response = qwen_processor.batch_decode(
        output_ids[:, model_inputs["input_ids"].shape[1]:], skip_special_tokens=True
    )[0]
    return annotated_image, response
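
# Example (sketch): calling the pipeline directly, outside the Gradio UI. File names are
# illustrative; the annotated image is converted back to BGR before saving with OpenCV.
#
#     img = cv2.cvtColor(cv2.imread("page.png"), cv2.COLOR_BGR2RGB)
#     annotated, answer = process_image_and_question(img, "How many figures are on this page?")
#     cv2.imwrite("annotated.png", cv2.cvtColor(annotated, cv2.COLOR_RGB2BGR))
#     print(answer)
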
# Custom CSS for styling
custom_css = """
.container { max-width: 1200px; margin: auto; }
.button-primary { background-color: #4CAF50; color: white; }
.gr-image { border: 2px solid #ddd; border-radius: 5px; }
.gr-textbox { font-family: Arial; }
"""
# Create Gradio interface
with gr.Blocks(
    title="Latex2Layout Visual Q&A",
    theme=gr.themes.Default(),
    css=custom_css
) as demo:
    gr.Markdown(
        """
        # Latex2Layout Visual Q&A
        Upload an image and ask a question about its layout. The **Latex2Layout** model detects elements, and **Qwen2.5-VL** provides answers based on the image and layout information.
        """
    )
    with gr.Row():
        with gr.Column(scale=1):
            input_image = gr.Image(
                label="Upload Image",
                type="numpy",
                height=400,
                elem_classes="gr-image"
            )
            question_input = gr.Textbox(
                label="Ask a Question",
                placeholder="e.g., What elements are in the image?",
                lines=2
            )
            submit_btn = gr.Button(
                "Get Answer",
                variant="primary",
                elem_classes="button-primary"
            )
        with gr.Column(scale=1):
            output_image = gr.Image(
                label="Detected Layout",
                height=400,
                elem_classes="gr-image"
            )
            output_text = gr.Textbox(
                label="Answer",
                lines=5,
                max_lines=10,
                elem_classes="gr-textbox"
            )
    # Event handlers: disable the button and show "Processing..." while the request runs
    submit_btn.click(
        fn=lambda: gr.update(value="Processing...", interactive=False),
        outputs=submit_btn,
    ).then(
        fn=process_image_and_question,
        inputs=[input_image, question_input],
        outputs=[output_image, output_text],
    ).then(
        fn=lambda: gr.update(value="Get Answer", interactive=True),
        outputs=submit_btn,
    )
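
    # Optional (sketch): for long-running GPU inference it can help to enable Gradio's request
    # queue, e.g. calling `demo.queue()` before `demo.launch()` below; omitted here to keep the
    # original launch behavior unchanged.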
# Launch the application
if __name__ == "__main__":
    demo.launch()