Bofeee5675 committed on
Commit 9a092ba · verified · 1 Parent(s): ab6bb61

Upload folder using huggingface_hub

Files changed (3)
  1. .gitattributes +1 -0
  2. app.py +292 -59
  3. requirements.txt +7 -1
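
The commit message above refers to huggingface_hub's folder-upload API. A minimal sketch of how such a push is usually issued (the local folder path and Space repo id below are assumptions for illustration, not values taken from this commit):

from huggingface_hub import HfApi

api = HfApi()
api.upload_folder(
    folder_path="./tongui-demo",       # hypothetical local folder holding app.py, requirements.txt and the example images
    repo_id="Bofeee5675/TongUI-Demo",  # assumed Space id; the actual target repo is not shown on this page
    repo_type="space",                 # upload to a Space rather than a model/dataset repo
)
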
.gitattributes CHANGED
@@ -33,6 +33,7 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
  *.zip filter=lfs diff=lfs merge=lfs -text
  *.zst filter=lfs diff=lfs merge=lfs -text
  *tfevents* filter=lfs diff=lfs merge=lfs -text
+ *.png filter=lfs diff=lfs merge=lfs -text
  app_store.png filter=lfs diff=lfs merge=lfs -text
  apple_music.png filter=lfs diff=lfs merge=lfs -text
  safari_google.png filter=lfs diff=lfs merge=lfs -text
app.py CHANGED
@@ -1,64 +1,297 @@
  import gradio as gr
- from huggingface_hub import InferenceClient
-
- """
- For more information on `huggingface_hub` Inference API support, please check the docs: https://huggingface.co/docs/huggingface_hub/v0.22.2/en/guides/inference
- """
- client = InferenceClient("HuggingFaceH4/zephyr-7b-beta")
-
-
- def respond(
-     message,
-     history: list[tuple[str, str]],
-     system_message,
-     max_tokens,
-     temperature,
-     top_p,
- ):
-     messages = [{"role": "system", "content": system_message}]
-
-     for val in history:
-         if val[0]:
-             messages.append({"role": "user", "content": val[0]})
-         if val[1]:
-             messages.append({"role": "assistant", "content": val[1]})
-
-     messages.append({"role": "user", "content": message})
-
-     response = ""
-
-     for message in client.chat_completion(
-         messages,
-         max_tokens=max_tokens,
-         stream=True,
-         temperature=temperature,
-         top_p=top_p,
-     ):
-         token = message.choices[0].delta.content
-
-         response += token
-         yield response
-
-
- """
- For information on how to customize the ChatInterface, peruse the gradio docs: https://www.gradio.app/docs/chatinterface
- """
- demo = gr.ChatInterface(
-     respond,
-     additional_inputs=[
-         gr.Textbox(value="You are a friendly Chatbot.", label="System message"),
-         gr.Slider(minimum=1, maximum=2048, value=512, step=1, label="Max new tokens"),
-         gr.Slider(minimum=0.1, maximum=4.0, value=0.7, step=0.1, label="Temperature"),
-         gr.Slider(
-             minimum=0.1,
-             maximum=1.0,
-             value=0.95,
-             step=0.05,
-             label="Top-p (nucleus sampling)",
-         ),
-     ],
  )


  if __name__ == "__main__":
-     demo.launch()
+ import ast
+ import json
+ import os
+ from datetime import datetime
+
  import gradio as gr
+ import numpy as np
+ import spaces
+ import torch
+ from peft import PeftModel
+ from PIL import Image, ImageDraw
+ from qwen_vl_utils import process_vision_info
+ from transformers import (
+     AutoProcessor,
+ )
+
+ from transformers import Qwen2_5_VLForConditionalGeneration, AutoConfig
+ from peft.peft_model import PeftModel
+
+ def load_model_and_processor(model_path, lora_path=None, merge_lora=True):
+     """
+     Load the Qwen2.5-VL model and processor with optional LoRA weights.
+
+     Args:
+         model_path: Path to the base model
+         lora_path: Path to LoRA weights (optional)
+         merge_lora: Whether to merge the LoRA weights into the base model
+
+     Returns:
+         tuple: (processor, model) - The initialized processor and model
+     """
+     # Initialize processor
+     try:
+         processor = AutoProcessor.from_pretrained(
+             model_path,
+             min_pixels=256*28*28,
+             max_pixels=1344*28*28,
+             model_max_length=8196,
+         )
+     except Exception as e:
+         print(f"Error loading processor: {e}")
+         processor = None
+         config = AutoConfig.from_pretrained(model_path)
+         print(config)
+         raise e
+     # Initialize base model
+     model = Qwen2_5_VLForConditionalGeneration.from_pretrained(
+         model_path,
+         device_map="cpu",
+         torch_dtype=torch.bfloat16,
+         # attn_implementation="flash_attention_2",
+     )
+
+     # Load LoRA weights if path is provided
+     if lora_path is not None and len(lora_path) > 0:
+         print(f"Loading LoRA weights from {lora_path}")
+         model = PeftModel.from_pretrained(model, lora_path)
+
+         if merge_lora:
+             print("Merging LoRA weights into base model")
+             model = model.merge_and_unload()
+
+     model.eval()
+
+     return processor, model
+ # Define constants
+ DESCRIPTION = "[TongUI Demo](https://huggingface.co/datasets/Bofeee5675/TongUI-143K)"
+ _SYSTEM = "Based on the screenshot of the page, I give a text description and you give its corresponding location. The coordinate represents a clickable location [x, y] for an element, which is a relative coordinate on the screenshot, scaled from 0 to 1."
+ MIN_PIXELS = 256 * 28 * 28
+ MAX_PIXELS = 1344 * 28 * 28
+
+ processor, model = load_model_and_processor(
+     model_path="Qwen/Qwen2.5-VL-3B-Instruct",
+     lora_path="Bofeee5675/TongUI-3B",
+     merge_lora=True,
  )
+ # Helper functions
+ def draw_point(image_input, point=None, radius=5):
+     """Draw a point on the image."""
+     if isinstance(image_input, str):
+         image = Image.open(image_input)
+     else:
+         image = Image.fromarray(np.uint8(image_input))
+
+     if point:
+         x, y = point[0] * image.width, point[1] * image.height
+         ImageDraw.Draw(image).ellipse((x - radius, y - radius, x + radius, y + radius), fill='red')
+     return image
+
+ def array_to_image_path(image_array):
+     """Save the uploaded image and return its path."""
+     if image_array is None:
+         raise ValueError("No image provided. Please upload an image before submitting.")
+     img = Image.fromarray(np.uint8(image_array))
+     timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
+     filename = f"image_{timestamp}.png"
+     img.save(filename)
+     return os.path.abspath(filename)
+
+ @spaces.GPU
+ def run_tongui(image, query):
+     """Main function for inference."""
+     image_path = array_to_image_path(image)
+
+     messages = [
+         {
+             "role": "user",
+             "content": [
+                 {"type": "text", "text": _SYSTEM},
+                 {"type": "image", "image": image_path, "min_pixels": MIN_PIXELS, "max_pixels": MAX_PIXELS},
+                 {"type": "text", "text": query}
+             ],
+         }
+     ]
+
+     # Prepare inputs for the model
+
+     global model
+
+     model = model.to("cuda")
+
+     text = processor.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
+     image_inputs, video_inputs = process_vision_info(messages)
+     inputs = processor(
+         text=[text],
+         images=image_inputs,
+         videos=video_inputs,
+         padding=True,
+         return_tensors="pt"
+     )
+     inputs = inputs.to("cuda")
+
+     # Generate output
+     generated_ids = model.generate(**inputs, max_new_tokens=128)
+     generated_ids_trimmed = [
+         out_ids[len(in_ids):] for in_ids, out_ids in zip(inputs.input_ids, generated_ids)
+     ]
+     output_text = processor.batch_decode(
+         generated_ids_trimmed, skip_special_tokens=True, clean_up_tokenization_spaces=False
+     )[0]
+
+     # Parse the output into coordinates
+     click_xy = ast.literal_eval(output_text)
+
+     # Draw the point on the image
+     result_image = draw_point(image_path, click_xy, radius=10)
+     return result_image, str(click_xy)
+
+ # Function to record votes
+ def record_vote(vote_type, image_path, query, action_generated):
+     """Record a vote in a JSON file."""
+     vote_data = {
+         "vote_type": vote_type,
+         "image_path": image_path,
+         "query": query,
+         "action_generated": action_generated,
+         "timestamp": datetime.now().isoformat()
+     }
+     with open("votes.json", "a") as f:
+         f.write(json.dumps(vote_data) + "\n")
+     return f"Your {vote_type} has been recorded. Thank you!"
+
+ # Helper function to handle vote recording
+ def handle_vote(vote_type, image_path, query, action_generated):
+     """Handle vote recording by using the consistent image path."""
+     if image_path is None:
+         return "No image uploaded. Please upload an image before voting."
+     return record_vote(vote_type, image_path, query, action_generated)
+
+
+
+
+ # Define layout and UI
+ def build_demo(embed_mode, concurrency_count=1):
+     with gr.Blocks(title="TongUI Demo", theme=gr.themes.Default()) as demo:
+         # State to store the consistent image path
+         state_image_path = gr.State(value=None)
+
+         if not embed_mode:
+             gr.HTML(
+                 """
+                 <div style="text-align: center; margin-bottom: 20px;">
+                     <p>TongUI: Building Generalized GUI Agents by Learning from Multimodal Web Tutorials</p>
+                 </div>
+                 """
+             )
+
+         with gr.Row():
+             with gr.Column(scale=3):
+                 # Input components
+                 imagebox = gr.Image(type="numpy", label="Input Screenshot")
+                 textbox = gr.Textbox(
+                     show_label=True,
+                     placeholder="Enter a query (e.g., 'Click Nahant')",
+                     label="Query",
+                 )
+                 submit_btn = gr.Button(value="Submit", variant="primary")
+
+                 # Placeholder examples
+                 gr.Examples(
+                     examples=[
+                         ["./examples/app_store.png", "Download Kindle."],
+                         ["./examples/apple_music.png", "Star to favorite."],
+                         ["./examples/safari_google.png", "Click on search bar."],
+                     ],
+                     inputs=[imagebox, textbox],
+                     examples_per_page=3
+                 )
+
+             with gr.Column(scale=8):
+                 # Output components
+                 output_img = gr.Image(type="pil", label="Output Image")
+                 # Add a note below the image to explain the red point
+                 gr.HTML(
+                     """
+                     <p><strong>Note:</strong> The <span style="color: red;">red point</span> on the output image represents the predicted clickable coordinates.</p>
+                     """
+                 )
+                 output_coords = gr.Textbox(label="Clickable Coordinates")
+
+                 # Buttons for voting, flagging, regenerating, and clearing
+                 with gr.Row(elem_id="action-buttons", equal_height=True):
+                     vote_btn = gr.Button(value="👍 Vote", variant="secondary")
+                     downvote_btn = gr.Button(value="👎 Downvote", variant="secondary")
+                     flag_btn = gr.Button(value="🚩 Flag", variant="secondary")
+                     regenerate_btn = gr.Button(value="🔄 Regenerate", variant="secondary")
+                     clear_btn = gr.Button(value="🗑️ Clear", interactive=True)  # Combined Clear button
+
+         # Define button actions
+         def on_submit(image, query):
+             """Handle the submit button click."""
+             if image is None:
+                 raise ValueError("No image provided. Please upload an image before submitting.")
+
+             # Generate consistent image path and store it in the state
+             image_path = array_to_image_path(image)
+             return run_tongui(image, query) + (image_path,)
+
+         submit_btn.click(
+             on_submit,
+             [imagebox, textbox],
+             [output_img, output_coords, state_image_path],
+         )
+
+         clear_btn.click(
+             lambda: (None, None, None, None, None),
+             inputs=None,
+             outputs=[imagebox, textbox, output_img, output_coords, state_image_path],  # Clear all outputs
+             queue=False
+         )
+
+         regenerate_btn.click(
+             lambda image, query, state_image_path: run_tongui(image, query),
+             [imagebox, textbox, state_image_path],
+             [output_img, output_coords],
+         )
+
+         # Record vote actions without feedback messages
+         vote_btn.click(
+             lambda image_path, query, action_generated: handle_vote(
+                 "upvote", image_path, query, action_generated
+             ),
+             inputs=[state_image_path, textbox, output_coords],
+             outputs=[],
+             queue=False
+         )
+
+         downvote_btn.click(
+             lambda image_path, query, action_generated: handle_vote(
+                 "downvote", image_path, query, action_generated
+             ),
+             inputs=[state_image_path, textbox, output_coords],
+             outputs=[],
+             queue=False
+         )

+         flag_btn.click(
+             lambda image_path, query, action_generated: handle_vote(
+                 "flag", image_path, query, action_generated
+             ),
+             inputs=[state_image_path, textbox, output_coords],
+             outputs=[],
+             queue=False
+         )

+     return demo
+ # Launch the app
  if __name__ == "__main__":
+     demo = build_demo(embed_mode=False)
+     demo.queue(api_open=False).launch(
+         server_name="0.0.0.0",
+         server_port=7860,
+         ssr_mode=False,
+         debug=True,
+     )
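
The coordinate handling in run_tongui and draw_point above amounts to parsing a relative [x, y] pair from the model's text output and scaling it by the screenshot size. A small sketch of that step, assuming the model replies with a bracketed pair as the _SYSTEM prompt requests (the reply string and image dimensions below are made up for illustration):

import ast

output_text = "[0.46, 0.82]"              # hypothetical model reply: relative coordinates in [0, 1]
click_xy = ast.literal_eval(output_text)  # -> [0.46, 0.82]

width, height = 1170, 2532                # assumed screenshot dimensions in pixels
pixel_x, pixel_y = click_xy[0] * width, click_xy[1] * height
print(pixel_x, pixel_y)                   # 538.2 2076.24 -> where the red point gets drawn
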
requirements.txt CHANGED
@@ -1 +1,7 @@
- huggingface_hub==0.25.2
+ huggingface_hub>=0.30.0
+ numpy
+ torch
+ peft
+ qwen_vl_utils
+ torchvision
+ git+https://github.com/huggingface/transformers.git