import gradio as gr
import cv2
import numpy as np
from ultralytics import YOLO
from transformers import Qwen2_5_VLForConditionalGeneration, AutoProcessor
from PIL import Image
import torch
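
# Assumed environment (the original pins no versions): gradio 4.x, ultralytics,
# torch, and a transformers release that ships Qwen2.5-VL support. The YOLOv8
# weights file referenced below is expected to sit next to this script.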
# Load the Latex2Layout model for layout detection
latex2layout_model_path = "latex2layout_object_detection_yolov8.pt"
latex2layout_model = YOLO(latex2layout_model_path)

# Download and load the Qwen2.5-VL-3B model. The instruction-tuned checkpoint
# is assumed here; vision-language checkpoints need the dedicated model class
# and AutoProcessor (AutoModelForCausalLM/AutoTokenizer cannot pass the image
# through to the model).
qwen_model_path = "Qwen/Qwen2.5-VL-3B-Instruct"
qwen_model = Qwen2_5_VLForConditionalGeneration.from_pretrained(
    qwen_model_path, torch_dtype="auto", device_map="auto"
)
qwen_processor = AutoProcessor.from_pretrained(qwen_model_path)

def detect_layout(image):
    """
    Detect layout elements in the image using the Latex2Layout model.

    Args:
        image: The uploaded image (numpy array)

    Returns:
        layout_description: Textual description of detected layout elements
    """
    if image is None:
        return "Error: No image provided."

    # Run layout detection
    results = latex2layout_model(image)
    result = results[0]

    layout_description = []
    for box in result.boxes:
        x1, y1, x2, y2 = map(int, box.xyxy[0].cpu().numpy())
        cls_id = int(box.cls[0])
        cls_name = result.names[cls_id]
        layout_description.append(f"{cls_name} at position ({x1}, {y1}, {x2}, {y2})")

    return ", ".join(layout_description) if layout_description else "No elements detected."

def process_image_and_question(image, question):
    """
    Process the image with Latex2Layout and answer the question using Qwen2.5-VL.

    Args:
        image: The uploaded image (numpy array)
        question: The user's question (string)

    Returns:
        annotated_image: Image with detection boxes
        response: Answer from Qwen2.5-VL
    """
    if image is None or not question:
        return None, "Error: Please upload an image and provide a question."

    # Gradio delivers numpy images in RGB order, so no BGR->RGB swap is needed
    image_pil = Image.fromarray(image)

    # Detect layout using Latex2Layout
    layout_description = detect_layout(image)

    # Prepare annotated image (a second detection pass, kept for simplicity)
    annotated_image = image.copy()
    results = latex2layout_model(image)[0]

    for box in results.boxes:
        x1, y1, x2, y2 = map(int, box.xyxy[0].cpu().numpy())
        conf = float(box.conf[0])
        cls_id = int(box.cls[0])
        cls_name = results.names[cls_id]

        # Draw the box and a filled label background in a random color
        color = tuple(np.random.randint(0, 255, 3).tolist())
        cv2.rectangle(annotated_image, (x1, y1), (x2, y2), color, 2)
        label = f"{cls_name} {conf:.2f}"
        (label_width, label_height), _ = cv2.getTextSize(label, cv2.FONT_HERSHEY_SIMPLEX, 0.5, 1)
        cv2.rectangle(annotated_image, (x1, y1 - label_height - 5), (x1 + label_width, y1), color, -1)
        cv2.putText(annotated_image, label, (x1, y1 - 5), cv2.FONT_HERSHEY_SIMPLEX, 0.5, (255, 255, 255), 1)
    # Prepare input for Qwen2.5-VL: prepend the detected layout to the question
    input_text = f"Layout: {layout_description}\nQuestion: {question}"
    messages = [
        {
            "role": "user",
            "content": [
                {"type": "image", "image": image_pil},
                {"type": "text", "text": input_text}
            ]
        }
    ]

    # Build the chat prompt, then encode text and image together
    prompt = qwen_processor.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
    model_inputs = qwen_processor(
        text=[prompt], images=[image_pil], padding=True, return_tensors="pt"
    ).to(qwen_model.device)

    with torch.no_grad():
        output_ids = qwen_model.generate(**model_inputs, max_new_tokens=100)

    # Decode only the newly generated tokens, not the prompt
    response = qwen_processor.batch_decode(
        output_ids[:, model_inputs["input_ids"].shape[1]:], skip_special_tokens=True
    )[0]

    return annotated_image, response

# Custom CSS for styling
custom_css = """
.container { max-width: 1200px; margin: auto; }
.button-primary { background-color: #4CAF50; color: white; }
.gr-image { border: 2px solid #ddd; border-radius: 5px; }
.gr-textbox { font-family: Arial; }
"""
# Create Gradio interface
with gr.Blocks(
    title="Latex2Layout Visual Q&A",
    theme=gr.themes.Default(),
    css=custom_css
) as demo:
    gr.Markdown(
        """
        # Latex2Layout Visual Q&A
        Upload an image and ask a question about its layout. The **Latex2Layout** model detects elements, and **Qwen2.5-VL** provides answers based on the image and layout information.
        """
    )

    with gr.Row():
        with gr.Column(scale=1):
            input_image = gr.Image(
                label="Upload Image",
                type="numpy",
                height=400,
                elem_classes="gr-image"
            )
            question_input = gr.Textbox(
                label="Ask a Question",
                placeholder="e.g., What elements are in the image?",
                lines=2
            )
            submit_btn = gr.Button(
                "Get Answer",
                variant="primary",
                elem_classes="button-primary"
            )
        with gr.Column(scale=1):
            output_image = gr.Image(
                label="Detected Layout",
                height=400,
                elem_classes="gr-image"
            )
            output_text = gr.Textbox(
                label="Answer",
                lines=5,
                max_lines=10,
                elem_classes="gr-textbox"
            )
    # Event handlers: swap the button label while processing, then restore it
    submit_btn.click(
        fn=lambda: gr.update(value="Processing...", interactive=False),
        outputs=submit_btn,
    ).then(
        fn=process_image_and_question,
        inputs=[input_image, question_input],
        outputs=[output_image, output_text],
        show_progress="full",
    ).then(
        fn=lambda: gr.update(value="Get Answer", interactive=True),
        outputs=submit_btn,
    )

# Launch the application
if __name__ == "__main__":
    demo.launch()
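    # Locally, Gradio serves on http://127.0.0.1:7860 by default; on a Hugging
    # Face Space, launch() picks up the host and port from the environment.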