# TongUI / app.py
import ast
import json
import os
from datetime import datetime

import gradio as gr
import numpy as np
import spaces
import torch
from omegaconf import OmegaConf
from peft import PeftModel
from PIL import Image, ImageDraw
from qwen_vl_utils import process_vision_info
from transformers import AutoConfig, AutoProcessor, Qwen2_5_VLForConditionalGeneration

config = OmegaConf.load("app_config.yaml")
def load_model_and_processor(model_path, lora_path=None, merge_lora=True):
    """
    Load the Qwen2.5-VL model and processor, optionally applying LoRA weights.

    Args:
        model_path: Path or hub id of the base model.
        lora_path: Path to LoRA weights (optional).
        merge_lora: Whether to merge the LoRA weights into the base model.

    Returns:
        tuple: (processor, model) - the initialized processor and model.
    """
# Initialize processor
try:
processor = AutoProcessor.from_pretrained(
model_path,
min_pixels=256*28*28,
max_pixels=1344*28*28,
model_max_length=8196,
)
    except Exception as e:
        # Dump the model's config to help diagnose the failure, then re-raise.
        print(f"Error loading processor: {e}")
        hf_config = AutoConfig.from_pretrained(model_path)
        print(hf_config)
        raise
# Initialize base model
model = Qwen2_5_VLForConditionalGeneration.from_pretrained(
model_path,
device_map="cpu",
torch_dtype=torch.bfloat16,
# attn_implementation="flash_attention_2",
)
# Load LoRA weights if path is provided
if lora_path is not None and len(lora_path) > 0:
print(f"Loading LoRA weights from {lora_path}")
model = PeftModel.from_pretrained(model, lora_path)
if merge_lora:
print("Merging LoRA weights into base model")
model = model.merge_and_unload()
model.eval()
return processor, model
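# Usage sketch (the model id below is illustrative; the app reads its real
# paths from app_config.yaml):
# processor, model = load_model_and_processor("Qwen/Qwen2.5-VL-3B-Instruct")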
# Define constants
DESCRIPTION = "[TongUI Demo](https://huggingface.co/datasets/Bofeee5675/TongUI-143K)"
_SYSTEM = "Based on the screenshot of the page, I give a text description and you give its corresponding location. The coordinate represents a clickable location [x, y] for an element, which is a relative coordinate on the screenshot, scaled from 0 to 1."
MIN_PIXELS = 256 * 28 * 28
MAX_PIXELS = 1344 * 28 * 28
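# Qwen2.5-VL's image processor works in units of 28x28 pixels, so these
# budgets correspond to roughly 256 and 1344 such patches per image.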
processor, model = load_model_and_processor(
model_path=config.model,
lora_path=config.lora_path,
merge_lora=True,
)
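# Note: the model is loaded on CPU at import time (device_map="cpu" above) and
# only moved to the GPU inside the @spaces.GPU handler below, the usual pattern
# for ZeroGPU Spaces on Hugging Face.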
# Helper functions
def draw_point(image_input, point=None, radius=5):
"""Draw a point on the image."""
if isinstance(image_input, str):
image = Image.open(image_input)
else:
image = Image.fromarray(np.uint8(image_input))
if point:
x, y = point[0] * image.width, point[1] * image.height
ImageDraw.Draw(image).ellipse((x - radius, y - radius, x + radius, y + radius), fill='red')
return image
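# Example (illustrative): mark the center of a screenshot with a red dot.
# marked = draw_point("./examples/app_store.png", point=[0.5, 0.5], radius=10)
# marked.save("marked.png")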
def array_to_image_path(image_array):
"""Save the uploaded image and return its path."""
if image_array is None:
raise ValueError("No image provided. Please upload an image before submitting.")
img = Image.fromarray(np.uint8(image_array))
timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
filename = f"image_{timestamp}.png"
img.save(filename)
return os.path.abspath(filename)
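# Note: filenames have one-second resolution, so rapid successive uploads can
# overwrite each other; acceptable for a demo, but a unique suffix would be
# safer in production.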
@spaces.GPU
def run_tongui(image, query):
"""Main function for inference."""
image_path = array_to_image_path(image)
messages = [
{
"role": "user",
"content": [
{"type": "text", "text": _SYSTEM},
{"type": "image", "image": image_path, "min_pixels": MIN_PIXELS, "max_pixels": MAX_PIXELS},
{"type": "text", "text": query}
],
}
]
# Prepare inputs for the model
global model
model = model.to("cuda")
text = processor.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
image_inputs, video_inputs = process_vision_info(messages)
inputs = processor(
text=[text],
images=image_inputs,
videos=video_inputs,
padding=True,
return_tensors="pt"
)
inputs = inputs.to("cuda")
# Generate output
generated_ids = model.generate(**inputs, max_new_tokens=128)
generated_ids_trimmed = [
out_ids[len(in_ids):] for in_ids, out_ids in zip(inputs.input_ids, generated_ids)
]
output_text = processor.batch_decode(
generated_ids_trimmed, skip_special_tokens=True, clean_up_tokenization_spaces=False
)[0]
    # Parse the output into coordinates. The system prompt asks the model to
    # answer with a relative [x, y] list literal, e.g. "[0.42, 0.17]".
    try:
        click_xy = ast.literal_eval(output_text)
    except (ValueError, SyntaxError) as e:
        raise gr.Error(f"Could not parse model output as coordinates: {output_text!r}") from e
# Draw the point on the image
result_image = draw_point(image_path, click_xy, radius=10)
return result_image, str(click_xy)
# Function to record votes
def record_vote(vote_type, image_path, query, action_generated):
"""Record a vote in a JSON file."""
vote_data = {
"vote_type": vote_type,
"image_path": image_path,
"query": query,
"action_generated": action_generated,
"timestamp": datetime.now().isoformat()
}
with open("votes.json", "a") as f:
f.write(json.dumps(vote_data) + "\n")
return f"Your {vote_type} has been recorded. Thank you!"
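# A recorded line in votes.json looks like this (values are illustrative):
# {"vote_type": "upvote", "image_path": "/abs/path/image_20250101_120000.png",
#  "query": "Click Nahant", "action_generated": "[0.42, 0.17]",
#  "timestamp": "2025-01-01T12:00:00"}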
# Helper function to handle vote recording
def handle_vote(vote_type, image_path, query, action_generated):
"""Handle vote recording by using the consistent image path."""
if image_path is None:
return "No image uploaded. Please upload an image before voting."
return record_vote(vote_type, image_path, query, action_generated)
# Define layout and UI
def build_demo(embed_mode):
with gr.Blocks(title="TongUI Demo", theme=gr.themes.Default()) as demo:
# State to store the consistent image path
state_image_path = gr.State(value=None)
if not embed_mode:
gr.HTML(
"""
<div style="text-align: center; margin-bottom: 20px;">
<p>TongUI: Building Generalized GUI Agents by Learning from Multimodal Web Tutorials</p>
</div>
"""
)
with gr.Row():
with gr.Column(scale=3):
# Input components
imagebox = gr.Image(type="numpy", label="Input Screenshot")
textbox = gr.Textbox(
show_label=True,
placeholder="Enter a query (e.g., 'Click Nahant')",
label="Query",
)
submit_btn = gr.Button(value="Submit", variant="primary")
                # Example inputs
gr.Examples(
examples=[
["./examples/app_store.png", "Download Kindle."],
["./examples/apple_music.png", "Star to favorite."],
["./examples/safari_google.png", "Click on search bar."],
],
inputs=[imagebox, textbox],
examples_per_page=3
)
with gr.Column(scale=8):
# Output components
output_img = gr.Image(type="pil", label="Output Image")
# Add a note below the image to explain the red point
gr.HTML(
"""
<p><strong>Note:</strong> The <span style="color: red;">red point</span> on the output image represents the predicted clickable coordinates.</p>
"""
)
output_coords = gr.Textbox(label="Clickable Coordinates")
# Buttons for voting, flagging, regenerating, and clearing
with gr.Row(elem_id="action-buttons", equal_height=True):
vote_btn = gr.Button(value="πŸ‘ Vote", variant="secondary")
downvote_btn = gr.Button(value="πŸ‘Ž Downvote", variant="secondary")
flag_btn = gr.Button(value="🚩 Flag", variant="secondary")
regenerate_btn = gr.Button(value="πŸ”„ Regenerate", variant="secondary")
clear_btn = gr.Button(value="πŸ—‘οΈ Clear", interactive=True) # Combined Clear button
# Define button actions
def on_submit(image, query):
"""Handle the submit button click."""
if image is None:
raise ValueError("No image provided. Please upload an image before submitting.")
# Generate consistent image path and store it in the state
image_path = array_to_image_path(image)
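            # Note: run_tongui saves its own copy of the screenshot, so the path
            # stored in state may differ from the file actually used for inference.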
return run_tongui(image, query) + (image_path,)
submit_btn.click(
on_submit,
[imagebox, textbox],
[output_img, output_coords, state_image_path],
)
clear_btn.click(
lambda: (None, None, None, None, None),
inputs=None,
outputs=[imagebox, textbox, output_img, output_coords, state_image_path], # Clear all outputs
queue=False
)
        # run_tongui can be wired directly; the stored state path is not needed here.
        regenerate_btn.click(
            run_tongui,
            [imagebox, textbox],
            [output_img, output_coords],
        )
# Record vote actions without feedback messages
vote_btn.click(
lambda image_path, query, action_generated: handle_vote(
"upvote", image_path, query, action_generated
),
inputs=[state_image_path, textbox, output_coords],
outputs=[],
queue=False
)
downvote_btn.click(
lambda image_path, query, action_generated: handle_vote(
"downvote", image_path, query, action_generated
),
inputs=[state_image_path, textbox, output_coords],
outputs=[],
queue=False
)
flag_btn.click(
lambda image_path, query, action_generated: handle_vote(
"flag", image_path, query, action_generated
),
inputs=[state_image_path, textbox, output_coords],
outputs=[],
queue=False
)
return demo
# Launch the app
if __name__ == "__main__":
demo = build_demo(embed_mode=False)
demo.queue(api_open=False).launch(
server_name="0.0.0.0",
server_port=7860,
ssr_mode=False,
debug=True,
)