wondervictor committed
Commit 3ed4749
0 Parent(s)
Files changed (7)
  1. .gitattributes +44 -0
  2. .gitignore +165 -0
  3. README.md +14 -0
  4. app.py +533 -0
  5. infer.py +299 -0
  6. requirements.txt +7 -0
  7. visualizer.py +126 -0
.gitattributes ADDED
@@ -0,0 +1,44 @@
1
+ *.7z filter=lfs diff=lfs merge=lfs -text
2
+ *.arrow filter=lfs diff=lfs merge=lfs -text
3
+ *.bin filter=lfs diff=lfs merge=lfs -text
4
+ *.bz2 filter=lfs diff=lfs merge=lfs -text
5
+ *.ckpt filter=lfs diff=lfs merge=lfs -text
6
+ *.ftz filter=lfs diff=lfs merge=lfs -text
7
+ *.gz filter=lfs diff=lfs merge=lfs -text
8
+ *.h5 filter=lfs diff=lfs merge=lfs -text
9
+ *.joblib filter=lfs diff=lfs merge=lfs -text
10
+ *.lfs.* filter=lfs diff=lfs merge=lfs -text
11
+ *.mlmodel filter=lfs diff=lfs merge=lfs -text
12
+ *.model filter=lfs diff=lfs merge=lfs -text
13
+ *.msgpack filter=lfs diff=lfs merge=lfs -text
14
+ *.npy filter=lfs diff=lfs merge=lfs -text
15
+ *.npz filter=lfs diff=lfs merge=lfs -text
16
+ *.onnx filter=lfs diff=lfs merge=lfs -text
17
+ *.ot filter=lfs diff=lfs merge=lfs -text
18
+ *.parquet filter=lfs diff=lfs merge=lfs -text
19
+ *.pb filter=lfs diff=lfs merge=lfs -text
20
+ *.pickle filter=lfs diff=lfs merge=lfs -text
21
+ *.pkl filter=lfs diff=lfs merge=lfs -text
22
+ *.pt filter=lfs diff=lfs merge=lfs -text
23
+ *.pth filter=lfs diff=lfs merge=lfs -text
24
+ *.rar filter=lfs diff=lfs merge=lfs -text
25
+ *.safetensors filter=lfs diff=lfs merge=lfs -text
26
+ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
27
+ *.tar.* filter=lfs diff=lfs merge=lfs -text
28
+ *.tar filter=lfs diff=lfs merge=lfs -text
29
+ *.tflite filter=lfs diff=lfs merge=lfs -text
30
+ *.tgz filter=lfs diff=lfs merge=lfs -text
31
+ *.wasm filter=lfs diff=lfs merge=lfs -text
32
+ *.xz filter=lfs diff=lfs merge=lfs -text
33
+ *.zip filter=lfs diff=lfs merge=lfs -text
34
+ *.zst filter=lfs diff=lfs merge=lfs -text
35
+ *tfevents* filter=lfs diff=lfs merge=lfs -text
36
+ examples/I7pTpMjqNRM_1080p_small.mp4 filter=lfs diff=lfs merge=lfs -text
37
+ examples/interface.jpg filter=lfs diff=lfs merge=lfs -text
38
+ examples/newyork.jpg filter=lfs diff=lfs merge=lfs -text
39
+ examples/puzzle.png filter=lfs diff=lfs merge=lfs -text
40
+ examples/000000001000.jpeg filter=lfs diff=lfs merge=lfs -text
41
+ examples/000000018380.jpeg filter=lfs diff=lfs merge=lfs -text
42
+ examples/bancopy.jpg filter=lfs diff=lfs merge=lfs -text
43
+ examples/beijing.jpg filter=lfs diff=lfs merge=lfs -text
44
+ simhei.ttf filter=lfs diff=lfs merge=lfs -text
.gitignore ADDED
@@ -0,0 +1,165 @@
1
+ # Byte-compiled / optimized / DLL files
2
+ __pycache__/
3
+ *.py[cod]
4
+ *$py.class
5
+
6
+ # C extensions
7
+ *.so
8
+
9
+ # Distribution / packaging
10
+ .Python
11
+ build/
12
+ develop-eggs/
13
+ dist/
14
+ downloads/
15
+ eggs/
16
+ .eggs/
17
+ lib/
18
+ lib64/
19
+ parts/
20
+ sdist/
21
+ var/
22
+ wheels/
23
+ share/python-wheels/
24
+ *.egg-info/
25
+ .installed.cfg
26
+ *.egg
27
+ MANIFEST
28
+
29
+ # PyInstaller
30
+ # Usually these files are written by a python script from a template
31
+ # before PyInstaller builds the exe, so as to inject date/other infos into it.
32
+ *.manifest
33
+ *.spec
34
+
35
+ # Installer logs
36
+ pip-log.txt
37
+ pip-delete-this-directory.txt
38
+
39
+ # Unit test / coverage reports
40
+ htmlcov/
41
+ .tox/
42
+ .nox/
43
+ .coverage
44
+ .coverage.*
45
+ .cache
46
+ nosetests.xml
47
+ coverage.xml
48
+ *.cover
49
+ *.py,cover
50
+ .hypothesis/
51
+ .pytest_cache/
52
+ cover/
53
+
54
+ # Translations
55
+ *.mo
56
+ *.pot
57
+
58
+ # Django stuff:
59
+ *.log
60
+ local_settings.py
61
+ db.sqlite3
62
+ db.sqlite3-journal
63
+
64
+ # Flask stuff:
65
+ instance/
66
+ .webassets-cache
67
+
68
+ # Scrapy stuff:
69
+ .scrapy
70
+
71
+ # Sphinx documentation
72
+ docs/_build/
73
+
74
+ # PyBuilder
75
+ .pybuilder/
76
+ target/
77
+
78
+ # Jupyter Notebook
79
+ .ipynb_checkpoints
80
+
81
+ # IPython
82
+ profile_default/
83
+ ipython_config.py
84
+
85
+ # pyenv
86
+ # For a library or package, you might want to ignore these files since the code is
87
+ # intended to run in multiple environments; otherwise, check them in:
88
+ # .python-version
89
+
90
+ # pipenv
91
+ # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
92
+ # However, in case of collaboration, if having platform-specific dependencies or dependencies
93
+ # having no cross-platform support, pipenv may install dependencies that don't work, or not
94
+ # install all needed dependencies.
95
+ #Pipfile.lock
96
+
97
+ # poetry
98
+ # Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control.
99
+ # This is especially recommended for binary packages to ensure reproducibility, and is more
100
+ # commonly ignored for libraries.
101
+ # https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control
102
+ #poetry.lock
103
+
104
+ # pdm
105
+ # Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control.
106
+ #pdm.lock
107
+ # pdm stores project-wide configurations in .pdm.toml, but it is recommended to not include it
108
+ # in version control.
109
+ # https://pdm.fming.dev/#use-with-ide
110
+ .pdm.toml
111
+
112
+ # PEP 582; used by e.g. github.com/David-OConnor/pyflow and github.com/pdm-project/pdm
113
+ __pypackages__/
114
+
115
+ # Celery stuff
116
+ celerybeat-schedule
117
+ celerybeat.pid
118
+
119
+ # SageMath parsed files
120
+ *.sage.py
121
+
122
+ # Environments
123
+ .env
124
+ .venv
125
+ env/
126
+ venv/
127
+ ENV/
128
+ env.bak/
129
+ venv.bak/
130
+
131
+ # Spyder project settings
132
+ .spyderproject
133
+ .spyproject
134
+
135
+ # Rope project settings
136
+ .ropeproject
137
+
138
+ # mkdocs documentation
139
+ /site
140
+
141
+ # mypy
142
+ .mypy_cache/
143
+ .dmypy.json
144
+ dmypy.json
145
+
146
+ # Pyre type checker
147
+ .pyre/
148
+
149
+ # pytype static type analyzer
150
+ .pytype/
151
+
152
+ # Cython debug symbols
153
+ cython_debug/
154
+
155
+ # PyCharm
156
+ # JetBrains specific template is maintained in a separate JetBrains.gitignore that can
157
+ # be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore
158
+ # and can be added to the global gitignore or merged into this file. For a more nuclear
159
+ # option (not recommended) you can uncomment the following to ignore the entire idea folder.
160
+ #.idea/
161
+
162
+ .DS_Store
163
+ video_frames
164
+ examples
165
+ simhei.ttf
README.md ADDED
@@ -0,0 +1,14 @@
1
+ ---
2
+ title: Seed1.5 VL
3
+ emoji: 🚀
4
+ colorFrom: green
5
+ colorTo: pink
6
+ sdk: gradio
7
+ sdk_version: 5.29.0
8
+ app_file: app.py
9
+ pinned: false
10
+ license: apache-2.0
11
+ short_description: Seed1.5-VL API Demo
12
+ ---
13
+
14
+ Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
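The Space reads its endpoint credentials from the environment: app.py (added below) constructs `SeedVLInfer(model_id=os.getenv('MODEL_ID'), api_key=os.getenv('API_KEY'))` and launches the Gradio demo at import time. A minimal local-run sketch, assuming the two variables point at a valid Volcengine Ark endpoint (the placeholder values are not real):

```python
# Minimal local-run sketch (not part of the repo). The placeholder values below
# are assumptions; set them to a real Ark endpoint id and API key.
import os

os.environ["MODEL_ID"] = "<your-ark-endpoint-id>"
os.environ["API_KEY"] = "<your-ark-api-key>"

# app.py builds the Gradio Blocks and calls demo.queue(...).launch(...) at
# module level, so importing it starts the demo.
import app  # noqa: F401
```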
app.py ADDED
@@ -0,0 +1,533 @@
1
+ # Copyright (2025) [Seed-VL-Cookbook] Bytedance Seed
2
+ import os
3
+ import re
4
+ import cv2
5
+ import json
6
+ import time
7
+ import numpy as np
8
+ import gradio as gr
9
+ from infer import SeedVLInfer, ConversationModeI18N, ConversationModeCN
10
+ from visualizer import draw_boxes_points_with_labels
11
+
12
+ infer = SeedVLInfer(model_id=os.getenv('MODEL_ID'), api_key=os.getenv('API_KEY'))
13
+
14
+ label_translations = {
15
+ "gr_chatinterface_ofl": {
16
+ "English": "Chatbot",
17
+ "中文": "对话界面"
18
+ },
19
+ "gr_chatinterface_ol": {
20
+ "English": "Chatbot",
21
+ "中文": "对话界面"
22
+ },
23
+ "gr_tab_ol": {
24
+ "English": "Online",
25
+ "中文": "在线模式"
26
+ },
27
+ "gr_tab_ofl": {
28
+ "English": "Offline",
29
+ "中文": "离线模式"
30
+ },
31
+ "gr_thinking": {
32
+ "English": ConversationModeI18N.D,
33
+ "中文": ConversationModeCN.D,
34
+ },
35
+ "gr_temperature": {
36
+ "English": "Temperature",
37
+ "中文": "温度系数"
38
+ },
39
+ "gr_webcam_image": {
40
+ "English": "🤳 Open Webcam",
41
+ "中文": "🤳 打开摄像头"
42
+ },
43
+ "gr_webcam_images": {
44
+ "English": "📹 Recorded Frames",
45
+ "中文": "📹 录制的视频帧"
46
+ },
47
+ "gr_chatinterface_ofl.textbox.placeholder": {
48
+ "English":
49
+ "Ask me anything. You can also drop in images and .mp4 videos.",
50
+ "中文": "有什么想问的?支持上传图片和.mp4视频。"
51
+ },
52
+ "gr_chatinterface_ol.textbox.placeholder": {
53
+ "English": "Ask me anything...",
54
+ "中文": "有什么想问的?"
55
+ },
56
+ "gr_clear_button": {
57
+ "English": "🧹 Clear History",
58
+ "中文": "🧹 清除历史对话"
59
+ }
60
+ }
61
+
62
+ def add_escape(text: str):
63
+ return text.replace('<', '\\<').replace('>', '\\>')
64
+
65
+ def remove_escape(text: str):
66
+ return text.replace('\\<', '<').replace('\\>', '>')
67
+
68
+ def plot_boxes_points_detections(image_path, message):
69
+ detection_pattern = r'\[\s*{.*?}\s*\]'
70
+ detection_matches = re.finditer(detection_pattern, message, flags=re.DOTALL)
71
+ bboxes, categories = [], []
72
+ for match in detection_matches:
73
+ matched_str = match.group(0)
74
+ detections = json.loads(matched_str)
75
+ for detection in detections:
76
+ cat, bbox_str = detection['category'], detection['bbox']
77
+ bbox_str = bbox_str.replace('<bbox>', '').replace('</bbox>', '').replace('</bbox', '')
78
+ bbox = list(map(float, bbox_str.split(' ')))
79
+ bboxes.append(bbox)
80
+ categories.append(cat)
81
+ if not bboxes:
82
+ box_pattern = r'<bbox>(\d+(?:\.\d+)?)\s+(\d+(?:\.\d+)?)\s+(\d+(?:\.\d+)?)\s+(\d+(?:\.\d+)?)</bbox>'
83
+ box_matches = re.finditer(box_pattern, message)
84
+ bboxes = [
85
+ [float(match.group(1)), float(match.group(2)),
86
+ float(match.group(3)), float(match.group(4))]
87
+ for match in box_matches
88
+ ]
89
+
90
+ points = []
91
+ if not bboxes:
92
+ point_pattern = r'<point>(\d+(?:\.\d+)?)\s+(\d+(?:\.\d+)?)</point>'
93
+ point_matches = re.finditer(point_pattern, message)
94
+ points = [
95
+ [float(match.group(1)), float(match.group(2))]
96
+ for match in point_matches
97
+ ]
98
+
99
+ if not bboxes and not points:
100
+ return
101
+
102
+ bboxes = np.array(bboxes, dtype='float') / 1000
103
+ points = np.array(points, dtype='float') / 1000
104
+
105
+ image = cv2.imread(image_path)
106
+ h, w, c = image.shape
107
+ if bboxes.size:
108
+ bboxes[:, 0::2] *= w
109
+ bboxes[:, 1::2] *= h
110
+ if points.size:
111
+ points[:, 0] *= w
112
+ points[:, 1] *= h
113
+ output_image = draw_boxes_points_with_labels(image, bboxes, points, categories)
114
+ return output_image
115
+
116
+ def general_chat(inputs: dict, gr_history: list, infer_history: list,
117
+ if_thinking: bool, temperature: float, online: bool = False):
118
+ if 'text' in inputs:
119
+ inputs['text'] = remove_escape(inputs['text'])
120
+ mode = ConversationModeI18N.D if if_thinking else ConversationModeI18N.G
121
+ for response_text, infer_history, finished in infer(inputs=inputs,
122
+ history=infer_history,
123
+ mode=mode,
124
+ temperature=temperature,
125
+ online=online):
126
+ if if_thinking:
127
+ reasoning_text, response_text = response_text.split('</think>')
128
+ reasoning_text = reasoning_text.removeprefix('<think>')
129
+ response_message = [{
130
+ "role": "assistant",
131
+ "content": add_escape(reasoning_text),
132
+ 'metadata': {
133
+ 'title': '🤔 Thinking'
134
+ }
135
+ }, {
136
+ "role": "assistant",
137
+ "content": add_escape(response_text)
138
+ }]
139
+ else:
140
+ response_message = [{
141
+ "role": "assistant",
142
+ "content": add_escape(response_text)
143
+ }]
144
+ if finished and len(inputs.get('files', [])) == 1 and not inputs['files'][0].endswith('.mp4'):
145
+ image_path = inputs['files'][0]
146
+ response_text = infer_history[-1]['content']
147
+ try:
148
+ if if_thinking:
149
+ reasoning_text, response_text = response_text.split('</think>')
150
+ output_image = plot_boxes_points_detections(image_path, response_text)
151
+ if output_image is not None:
152
+ response_message.append({
153
+ "role": "assistant",
154
+ "content": gr.Image(output_image),
155
+ })
156
+ except Exception as e:
157
+ print(e)
158
+ yield response_message, infer_history
159
+
160
+ def online_record_chat(text: str, gr_history: list, gr_webcam_images: list,
161
+ gr_counter: int, infer_history: list, if_thinking: bool,
162
+ temperature: float):
163
+ if not gr_webcam_images:
164
+ gr_webcam_images = []
165
+ gr_webcam_images = gr_webcam_images[gr_counter:]
166
+ inputs = {'text': text, 'files': [webp for webp, _ in gr_webcam_images]}
167
+ yield f'received {len(gr_webcam_images)} new frames, processing...', gr_counter + len(
168
+ gr_webcam_images), infer_history
169
+ for response_message, infer_history in general_chat(
170
+ inputs, gr_history, infer_history, if_thinking, temperature, online=True):
171
+ yield response_message, gr.skip(), infer_history
172
+
173
+
174
+ with gr.Blocks() as demo:
175
+ with gr.Column():
176
+ gr_title = gr.Markdown('# Seed1.5-VL')
177
+ with gr.Row():
178
+ gr.Markdown(
179
+ """
180
+ <div style="display:flex; flex-direction:column; gap:10px;">
181
+ <a
182
+ href="https://github.com/ByteDance-Seed/Seed1.5-VL"
183
+ target="_blank"
184
+ style="
185
+ display: inline-flex;
186
+ align-items: center;
187
+ gap: 8px;
188
+ white-space: nowrap;
189
+ text-decoration: none;
190
+ "
191
+ >
192
+ <img
193
+ src="https://cdn.jsdelivr.net/gh/devicons/devicon/icons/github/github-original.svg"
194
+ alt="GitHub"
195
+ width="24"
196
+ >
197
+ Seed1.5-VL Cookbook
198
+ </a>
199
+ </div>
200
+ """
201
+ )
202
+ gr.Markdown(
203
+ """
204
+ <div style="display:flex; flex-direction:column; gap:10px;">
205
+ <a
206
+ href="https://huggingface.co/papers/2505.07062"
207
+ target="_blank"
208
+ style="
209
+ display: inline-flex;
210
+ align-items: center;
211
+ gap: 8px;
212
+ white-space: nowrap;
213
+ text-decoration: none;
214
+ "
215
+ >
216
+ <img
217
+ src="https://huggingface.co/front/assets/huggingface_logo-noborder.svg"
218
+ alt="Paper"
219
+ width="24"
220
+ >
221
+ Seed1.5-VL Paper
222
+ </a>
223
+ </div>
224
+ """,
225
+ )
226
+ gr.Markdown(' ')
227
+ gr.Markdown(' ')
228
+ gr.Markdown(' ')
229
+ gr.Markdown(' ')
230
+ with gr.Row():
231
+ gr_lang_selector = gr.Dropdown(choices=["English", "中文"],
232
+ value="English",
233
+ label="🌐 English Interface/中文界面",
234
+ interactive=True,
235
+ min_width=400,
236
+ scale=0)
237
+
238
+ with gr.Tabs():
239
+ with gr.Tab("Offline") as gr_tab_ofl:
240
+ gr_infer_history = gr.State([])
241
+ gr_thinking_hidden = gr.Checkbox(value=True, visible=False)
242
+ gr_temperature_hidden = gr.Slider(minimum=0.0,
243
+ maximum=2.0,
244
+ step=0.1,
245
+ value=0.0,
246
+ interactive=True,
247
+ visible=False)
248
+ gr_chatinterface_ofl = gr.ChatInterface(
249
+ fn=general_chat,
250
+ type="messages",
251
+ multimodal=True,
252
+ textbox=gr.MultimodalTextbox(
253
+ file_count="multiple",
254
+ file_types=["image", ".mp4"],
255
+ sources=["upload"],
256
+ stop_btn=True,
257
+ placeholder=label_translations[
258
+ 'gr_chatinterface_ofl.textbox.placeholder']['English'],
259
+ ),
260
+ additional_inputs=[
261
+ gr_infer_history, gr_thinking_hidden, gr_temperature_hidden
262
+ ],
263
+ additional_outputs=[gr_infer_history],
264
+ )
265
+ def add_escape_fn(inputs: dict):
266
+ if inputs and 'text' in inputs:
267
+ inputs['text'] = add_escape(inputs['text'])
268
+ return inputs
269
+ gr_chatinterface_ofl.textbox.submit(
270
+ fn=add_escape_fn,
271
+ inputs=[gr_chatinterface_ofl.saved_input],
272
+ outputs=[gr_chatinterface_ofl.saved_input]
273
+ )
274
+ gr.on(triggers=[gr_chatinterface_ofl.chatbot.clear],
275
+ fn=lambda: [],
276
+ outputs=[gr_infer_history])
277
+ with gr.Row():
278
+ gr_thinking_ofl = gr.Checkbox(
279
+ value=True,
280
+ label=label_translations['gr_thinking']['English'],
281
+ )
282
+ gr_thinking_ofl.change(lambda x: x,
283
+ inputs=gr_thinking_ofl,
284
+ outputs=gr_thinking_hidden)
285
+ gr_temperature_ofl = gr.Slider(
286
+ minimum=0.0,
287
+ maximum=2.0,
288
+ step=0.1,
289
+ value=0.0,
290
+ label=label_translations['gr_temperature']['English'],
291
+ interactive=True)
292
+ gr_temperature_ofl.change(lambda x: x,
293
+ inputs=gr_temperature_ofl,
294
+ outputs=gr_temperature_hidden)
295
+ gr_clear_button_ofl = gr.Button(value=label_translations['gr_clear_button']['English'])
296
+ def clear_history_fn():
297
+ return None, [], [], [], []
298
+ gr_clear_button_ofl.click(
299
+ fn=clear_history_fn,
300
+ outputs=[
301
+ gr_chatinterface_ofl.conversation_id,
302
+ gr_chatinterface_ofl.saved_conversations,
303
+ gr_chatinterface_ofl.chatbot,
304
+ gr_chatinterface_ofl.chatbot_state,
305
+ gr_infer_history
306
+ ]
307
+ )
308
+ with gr.Column(visible=True) as gr_examples_en:
309
+ gr.Examples(
310
+ label='7 Examples: text, image, video, multiple images/videos, visual puzzle, points grounding, open-vocabulary detection.',
311
+ examples=[
312
+ {
313
+ "text": "Who are you?",
314
+ "files": []
315
+ },
316
+ {
317
+ "text": "Introduce this.",
318
+ "files": ["examples/bancopy.jpg"]
319
+ },
320
+ {
321
+ "text":
322
+ """Find Curry's "Good Night" celebration time.""",
323
+ "files":
324
+ ["examples/I7pTpMjqNRM_1080p_small.mp4"]
325
+ },
326
+ {
327
+ "text":
328
+ "Share your feelings.",
329
+ "files": [
330
+ "examples/newyork.jpg",
331
+ "examples/beijing.jpg"
332
+ ]
333
+ },
334
+ {
335
+ "text": "Look and answer.",
336
+ "files": ["examples/puzzle.png"]
337
+ },
338
+ {
339
+ "text": "Please point out all the hats on people's heads in the image, output concatenated point coordinates like <point>x y</point><point>x y</point>",
340
+ "files": ["examples/000000001000.jpeg"]
341
+ },
342
+ {
343
+ "text": """Please detect all plate, photo, kid, cup in the image, and output all objects in the JSON format, which is a list of dict like [{"category": category, "bbox": "<bbox>x1 y1 x2 y2</bbox>"}, {"category": category, "bbox": "<bbox>x1 y1 x2 y2</bbox>"}]""",
344
+ "files": ["examples/000000018380.jpeg"]
345
+ }
346
+ ],
347
+ inputs=[gr_chatinterface_ofl.textbox],
348
+ )
349
+ with gr.Column(visible=False) as gr_examples_cn:
350
+ gr.Examples(
351
+ label='七个示例:文本,图像,视频,多个图像/视频,视觉解谜,坐标定位,开放式物体检测。',
352
+ examples=[
353
+ {
354
+ "text": "你是谁?",
355
+ "files": []
356
+ },
357
+ {
358
+ "text": "介绍一下。",
359
+ "files": ["examples/bancopy.jpg"]
360
+ },
361
+ {
362
+ "text":
363
+ "找到库里的“晚安”庆祝时间段。",
364
+ "files":
365
+ ["examples/I7pTpMjqNRM_1080p_small.mp4"]
366
+ },
367
+ {
368
+ "text":
369
+ "你有什么感想?",
370
+ "files": [
371
+ "examples/newyork.jpg",
372
+ "examples/beijing.jpg"
373
+ ]
374
+ },
375
+ {
376
+ "text": "看图回答。",
377
+ "files": ["examples/puzzle.png"]
378
+ },
379
+ {
380
+ "text": "请点出图像中所有戴在头上的帽子, 输出串联的点坐标<point>x y</point><point>x y</point>",
381
+ "files": ["examples/000000001000.jpeg"]
382
+ },
383
+ {
384
+ "text": """请检测图像中所有的盘子、照片、小孩和杯子。请以JSON格式输出一个由字典组成的列表,就像:[{"category": 类别, "bbox": "<bbox>x1 y1 x2 y2</bbox>"}, {"category": 类别, "bbox": "<bbox>x1 y1 x2 y2</bbox>"}]""",
385
+ "files": ["examples/000000018380.jpeg"]
386
+ }
387
+ ],
388
+ inputs=[gr_chatinterface_ofl.textbox],
389
+ )
390
+ with gr.Tab("Online") as gr_tab_ol:
391
+ with gr.Row():
392
+ with gr.Column(scale=1):
393
+ gr_infer_history_ol = gr.State([])
394
+ gr_thinking_hidden = gr.Checkbox(value=True, visible=False)
395
+ gr_temperature_hidden = gr.Slider(minimum=0.0,
396
+ maximum=2.0,
397
+ step=0.1,
398
+ value=1.0,
399
+ interactive=True,
400
+ visible=False)
401
+ with gr.Row():
402
+ with gr.Column(scale=1):
403
+ gr_webcam_image = gr.Image(
404
+ label=label_translations['gr_webcam_image']
405
+ ['English'],
406
+ sources="webcam",
407
+ height=250,
408
+ type='filepath')
409
+ gr_webcam_images = gr.Gallery(
410
+ label=label_translations['gr_webcam_images']
411
+ ['English'],
412
+ show_label=True,
413
+ format='webp',
414
+ columns=1,
415
+ height=250,
416
+ preview=True,
417
+ interactive=False)
418
+ gr_counter = gr.Number(value=0, visible=False)
419
+ with gr.Column(scale=3):
420
+ gr_chatinterface_ol = gr.ChatInterface(
421
+ fn=online_record_chat,
422
+ type="messages",
423
+ multimodal=False,
424
+ textbox=gr.
425
+ Textbox(placeholder=label_translations[
426
+ 'gr_chatinterface_ol.textbox.placeholder']
427
+ ['English'],
428
+ submit_btn=True,
429
+ stop_btn=True),
430
+ additional_inputs=[
431
+ gr_webcam_images, gr_counter,
432
+ gr_infer_history_ol, gr_thinking_hidden,
433
+ gr_temperature_hidden
434
+ ],
435
+ additional_outputs=[
436
+ gr_counter, gr_infer_history_ol
437
+ ],
438
+ )
439
+
440
+ def cache_webcam(recorded_image: str,
441
+ recorded_images: list):
442
+ if not recorded_images:
443
+ recorded_images = []
444
+ return recorded_images + [recorded_image]
445
+
446
+ gr_webcam_image.stream(
447
+ fn=cache_webcam,
448
+ inputs=[gr_webcam_image, gr_webcam_images],
449
+ outputs=[gr_webcam_images],
450
+ stream_every=1,
451
+ concurrency_limit=30,
452
+ )
453
+ with gr.Row():
454
+ gr_thinking_ol = gr.Checkbox(
455
+ value=True,
456
+ label=label_translations['gr_thinking']
457
+ ['English'],
458
+ )
459
+ gr_thinking_ol.change(
460
+ lambda x: x,
461
+ inputs=gr_thinking_ol,
462
+ outputs=gr_thinking_hidden)
463
+ gr_temperature_ol = gr.Slider(
464
+ minimum=0.0,
465
+ maximum=2.0,
466
+ step=0.1,
467
+ value=1.0,
468
+ label=label_translations['gr_temperature']
469
+ ['English'],
470
+ interactive=True)
471
+ gr_temperature_ol.change(
472
+ lambda x: x,
473
+ inputs=gr_temperature_ol,
474
+ outputs=gr_temperature_hidden)
475
+ gr_clear_button_ol = gr.Button(value=label_translations['gr_clear_button']['English'])
476
+ def clear_history_fn():
477
+ return None, [], [], [], []
478
+ gr_clear_button_ol.click(
479
+ fn=clear_history_fn,
480
+ outputs=[
481
+ gr_chatinterface_ol.conversation_id,
482
+ gr_chatinterface_ol.saved_conversations,
483
+ gr_chatinterface_ol.chatbot,
484
+ gr_chatinterface_ol.chatbot_state,
485
+ gr_infer_history_ol
486
+ ]
487
+ )
488
+
489
+ def update_lang(lang: str):
490
+ return (
491
+ gr.update(label=label_translations['gr_chatinterface_ofl'][lang]),
492
+ gr.update(label=label_translations['gr_chatinterface_ol'][lang]),
493
+ gr.update(placeholder=label_translations[
494
+ 'gr_chatinterface_ofl.textbox.placeholder'][lang]),
495
+ gr.update(placeholder=label_translations[
496
+ 'gr_chatinterface_ol.textbox.placeholder'][lang]),
497
+ gr.update(label=label_translations['gr_tab_ofl'][lang]),
498
+ gr.update(label=label_translations['gr_tab_ol'][lang]),
499
+ gr.update(label=label_translations['gr_thinking'][lang]),
500
+ gr.update(label=label_translations['gr_thinking'][lang]),
501
+ gr.update(label=label_translations['gr_temperature'][lang]),
502
+ gr.update(label=label_translations['gr_temperature'][lang]),
503
+ gr.update(visible=lang == 'English'),
504
+ gr.update(visible=lang != 'English'),
505
+ gr.update(label=label_translations['gr_webcam_image'][lang]),
506
+ gr.update(label=label_translations['gr_webcam_images'][lang]),
507
+ gr.update(value=label_translations['gr_clear_button'][lang]),
508
+ gr.update(value=label_translations['gr_clear_button'][lang]),
509
+ )
510
+
511
+ gr_lang_selector.change(fn=update_lang,
512
+ inputs=[gr_lang_selector],
513
+ outputs=[
514
+ gr_chatinterface_ofl.chatbot,
515
+ gr_chatinterface_ol.chatbot,
516
+ gr_chatinterface_ofl.textbox,
517
+ gr_chatinterface_ol.textbox,
518
+ gr_tab_ofl,
519
+ gr_tab_ol,
520
+ gr_thinking_ofl,
521
+ gr_thinking_ol,
522
+ gr_temperature_ofl,
523
+ gr_temperature_ol,
524
+ gr_examples_en,
525
+ gr_examples_cn,
526
+ gr_webcam_image,
527
+ gr_webcam_images,
528
+ gr_clear_button_ofl,
529
+ gr_clear_button_ol,
530
+ ])
531
+ demo.queue(default_concurrency_limit=100, max_size=100).launch(share=True,
532
+ max_threads=100,
533
+ ssr_mode=False)
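plot_boxes_points_detections above assumes the grounding convention used in the example prompts: the model emits `<point>x y</point>` and `<bbox>x1 y1 x2 y2</bbox>` coordinates on a 0-1000 scale, which are divided by 1000 and then multiplied by the image width/height before drawing. A small standalone sketch of that conversion (the helper name `bbox_to_pixels` is hypothetical):

```python
# Sketch of the coordinate convention parsed by plot_boxes_points_detections:
# <bbox> values are on a 0-1000 scale and are rescaled to pixel coordinates.
import re

BBOX_PATTERN = r'<bbox>(\d+(?:\.\d+)?)\s+(\d+(?:\.\d+)?)\s+(\d+(?:\.\d+)?)\s+(\d+(?:\.\d+)?)</bbox>'

def bbox_to_pixels(message: str, width: int, height: int) -> list[tuple[float, float, float, float]]:
    """Hypothetical helper: extract <bbox> tags and map them to pixel xyxy."""
    boxes = []
    for m in re.finditer(BBOX_PATTERN, message):
        x1, y1, x2, y2 = (float(g) / 1000 for g in m.groups())
        boxes.append((x1 * width, y1 * height, x2 * width, y2 * height))
    return boxes

# Example on a 640x480 image:
# bbox_to_pixels('[{"category": "cup", "bbox": "<bbox>100 200 300 400</bbox>"}]', 640, 480)
# -> [(64.0, 96.0, 192.0, 192.0)]
```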
infer.py ADDED
@@ -0,0 +1,299 @@
1
+ # Copyright (2025) [Seed-VL-Cookbook] Bytedance Seed
2
+ import json
3
+ import time
4
+ import math
5
+ import base64
6
+ import requests
7
+
8
+ import torch
9
+ import decord
10
+ import numpy as np
11
+ from PIL import Image, ImageSequence
12
+ from torchvision.io import read_image, encode_jpeg
13
+ from torchvision.transforms.functional import resize
14
+ from torchvision.transforms import InterpolationMode
15
+
16
+
17
+ class ConversationModeI18N:
18
+ G = "General"
19
+ D = "Deep Thinking"
20
+
21
+
22
+ class ConversationModeCN:
23
+ G = "常规"
24
+ D = "深度思考"
25
+
26
+
27
+ def round_by_factor(number: int, factor: int) -> int:
28
+ """Returns the closest integer to 'number' that is divisible by 'factor'."""
29
+ return round(number / factor) * factor
30
+
31
+
32
+ def ceil_by_factor(number: int, factor: int) -> int:
33
+ """Returns the smallest integer greater than or equal to 'number' that is divisible by 'factor'."""
34
+ return math.ceil(number / factor) * factor
35
+
36
+
37
+ def floor_by_factor(number: int, factor: int) -> int:
38
+ """Returns the largest integer less than or equal to 'number' that is divisible by 'factor'."""
39
+ return math.floor(number / factor) * factor
40
+
41
+
42
+ def get_resized_hw_for_Navit(
43
+ height: int,
44
+ width: int,
45
+ min_pixels: int,
46
+ max_pixels: int,
47
+ max_ratio: int = 200,
48
+ factor: int = 28,
49
+ ):
50
+ if max(height, width) / min(height, width) > max_ratio:
51
+ raise ValueError(
52
+ f"absolute aspect ratio must be smaller than {max_ratio}, got {max(height, width) / min(height, width)}"
53
+ )
54
+ h_bar = max(factor, round_by_factor(height, factor))
55
+ w_bar = max(factor, round_by_factor(width, factor))
56
+ if h_bar * w_bar > max_pixels:
57
+ beta = math.sqrt((height * width) / max_pixels)
58
+ h_bar = floor_by_factor(height / beta, factor)
59
+ w_bar = floor_by_factor(width / beta, factor)
60
+ elif h_bar * w_bar < min_pixels:
61
+ beta = math.sqrt(min_pixels / (height * width))
62
+ h_bar = ceil_by_factor(height * beta, factor)
63
+ w_bar = ceil_by_factor(width * beta, factor)
64
+ return int(h_bar), int(w_bar)
65
+
66
+
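As an illustration of the resizing rule above (input numbers chosen for this sketch, not taken from the repo), a 1080x1920 frame under the largest per-frame video budget of 640*28*28 pixels is snapped to multiples of 28 and scaled down to fit the budget:

```python
# Standalone check of get_resized_hw_for_Navit with illustrative inputs.
from infer import get_resized_hw_for_Navit

h, w = get_resized_hw_for_Navit(1080, 1920,
                                min_pixels=4 * 28 * 28,    # default min_pixels in SeedVLInfer
                                max_pixels=640 * 28 * 28)  # largest entry in max_pixels_choices
print(h, w)  # 504 924 -> 504 * 924 = 465,696 pixels, within the 501,760 budget
```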
67
+ class SeedVLInfer:
68
+ def __init__(
69
+ self,
70
+ model_id: str,
71
+ api_key: str,
72
+ base_url: str = 'https://ark.cn-beijing.volces.com/api/v3/chat/completions',
73
+ min_pixels: int = 4 * 28 * 28,
74
+ max_pixels: int = 5120 * 28 * 28,
75
+ video_sampling_strategy: dict = {
76
+ 'sampling_fps':
77
+ 1,
78
+ 'min_n_frames':
79
+ 16,
80
+ 'max_video_length':
81
+ 81920,
82
+ 'max_pixels_choices': [
83
+ 640 * 28 * 28, 512 * 28 * 28, 384 * 28 * 28, 256 * 28 * 28,
84
+ 160 * 28 * 28, 128 * 28 * 28
85
+ ],
86
+ 'use_timestamp':
87
+ True,
88
+ },
89
+ ):
90
+ self.base_url = base_url
91
+ self.api_key = api_key
92
+ self.model_id = model_id
93
+ self.min_pixels = min_pixels
94
+ self.max_pixels = max_pixels
95
+ self.sampling_fps = video_sampling_strategy.get('sampling_fps', 1)
96
+ self.min_n_frames = video_sampling_strategy.get('min_n_frames', 16)
97
+ self.max_video_length = video_sampling_strategy.get(
98
+ 'max_video_length', 81920)
99
+ self.max_pixels_choices = video_sampling_strategy.get(
100
+ 'max_pixels_choices', [
101
+ 640 * 28 * 28, 512 * 28 * 28, 384 * 28 * 28, 256 * 28 * 28,
102
+ 160 * 28 * 28, 128 * 28 * 28
103
+ ])
104
+ self.use_timestamp = video_sampling_strategy.get('use_timestamp', True)
105
+
106
+ def preprocess_video(self, video_path: str):
107
+ try:
108
+ video_reader = decord.VideoReader(video_path, num_threads=2)
109
+ fps = video_reader.get_avg_fps()
110
+ except decord._ffi.base.DECORDError:
111
+ video_reader = [
112
+ frame.convert('RGB')
113
+ for frame in ImageSequence.Iterator(Image.open(video_path))
114
+ ]
115
+ fps = 1
116
+
117
+ length = len(video_reader)
118
+ n_frames = min(
119
+ max(math.ceil(length / fps * self.sampling_fps),
120
+ self.min_n_frames), length)
121
+ frame_indices = np.linspace(0, length - 1,
122
+ n_frames).round().astype(int).tolist()
123
+ max_pixels = self.max_pixels
124
+ for round_idx, max_pixels in enumerate(self.max_pixels_choices):
125
+ is_last_round = round_idx == len(self.max_pixels_choices) - 1
126
+ if len(frame_indices
127
+ ) * max_pixels / 28 / 28 > self.max_video_length:
128
+ if is_last_round:
129
+ max_frame_num = int(self.max_video_length / max_pixels *
130
+ 28 * 28)
131
+ select_ids = np.linspace(
132
+ 0,
133
+ len(frame_indices) - 1,
134
+ max_frame_num).round().astype(int).tolist()
135
+ frame_indices = [
136
+ frame_indices[select_id] for select_id in select_ids
137
+ ]
138
+ else:
139
+ continue
140
+ else:
141
+ break
142
+
143
+ if hasattr(video_reader, "get_batch"):
144
+ video_clip = torch.from_numpy(
145
+ video_reader.get_batch(frame_indices).asnumpy()).permute(
146
+ 0, 3, 1, 2)
147
+ else:
148
+ video_clip_array = np.stack(
149
+ [np.array(video_reader[i]) for i in frame_indices], axis=0)
150
+ video_clip = torch.from_numpy(video_clip_array).permute(0, 3, 1, 2)
151
+
152
+ height, width = video_clip.shape[-2:]
153
+ resized_height, resized_width = get_resized_hw_for_Navit(
154
+ height,
155
+ width,
156
+ min_pixels=self.min_pixels,
157
+ max_pixels=max_pixels,
158
+ )
159
+ resized_video_clip = resize(video_clip,
160
+ (resized_height, resized_width),
161
+ interpolation=InterpolationMode.BICUBIC,
162
+ antialias=True)
163
+ if self.use_timestamp:
164
+ resized_video_clip = [
165
+ (round(i / fps, 1), f)
166
+ for i, f in zip(frame_indices, resized_video_clip)
167
+ ]
168
+ return resized_video_clip
169
+
170
+ def preprocess_streaming_frame(self, frame: torch.Tensor):
171
+ height, width = frame.shape[-2:]
172
+ resized_height, resized_width = get_resized_hw_for_Navit(
173
+ height,
174
+ width,
175
+ min_pixels=self.min_pixels,
176
+ max_pixels=self.max_pixels_choices[0],
177
+ )
178
+ resized_frame = resize(frame[None], (resized_height, resized_width),
179
+ interpolation=InterpolationMode.BICUBIC,
180
+ antialias=True)[0]
181
+ return resized_frame
182
+
183
+ def encode_image(self, image: torch.Tensor) -> str:
184
+ if image.shape[0] == 4:
185
+ image = image[:3]
186
+ encoded = encode_jpeg(image)
187
+ return base64.b64encode(encoded.numpy()).decode('utf-8')
188
+
189
+ def construct_messages(self,
190
+ inputs: dict,
191
+ streaming_timestamp: int = None,
192
+ online: bool = False) -> list[dict]:
193
+ content = []
194
+ for i, path in enumerate(inputs.get('files', [])):
195
+ if path.endswith('.mp4'):
196
+ video = self.preprocess_video(video_path=path)
197
+ for frame in video:
198
+ if self.use_timestamp:
199
+ timestamp, frame = frame
200
+ content.append({
201
+ "type": "text",
202
+ "text": f'[{timestamp} second]',
203
+ })
204
+ content.append({
205
+ "type": "image_url",
206
+ "image_url": {
207
+ "url":
208
+ f"data:image/jpeg;base64,{self.encode_image(frame)}",
209
+ "detail": "high"
210
+ },
211
+ })
212
+ else:
213
+ image = read_image(path)
214
+ if online and path.endswith('.webp'):
215
+ streaming_timestamp = i
216
+ if streaming_timestamp is not None:
217
+ image = self.preprocess_streaming_frame(frame=image)
218
+ content.append({
219
+ "type": "image_url",
220
+ "image_url": {
221
+ "url":
222
+ f"data:image/jpeg;base64,{self.encode_image(image)}",
223
+ "detail": "high"
224
+ },
225
+ })
226
+ if streaming_timestamp is not None:
227
+ content.insert(-1, {
228
+ "type": "text",
229
+ "text": f'[{streaming_timestamp} second]',
230
+ })
231
+ query = inputs.get('text', '')
232
+ if query:
233
+ content.append({
234
+ "type": "text",
235
+ "text": query,
236
+ })
237
+ messages = [{
238
+ "role": "user",
239
+ "content": content,
240
+ }]
241
+ return messages
242
+
243
+ def request(self,
244
+ messages,
245
+ thinking: bool = True,
246
+ temperature: float = 1.0):
247
+ headers = {
248
+ "Authorization": f"Bearer {self.api_key}",
249
+ "Content-Type": "application/json"
250
+ }
251
+ payload = {
252
+ "model": self.model_id,
253
+ "messages": messages,
254
+ "stream": True,
255
+ "thinking": {
256
+ "type": "enabled" if thinking else "disabled",
257
+ },
258
+ "temperature": temperature,
259
+ }
260
+ for _ in range(10):
261
+ try:
262
+ requested = requests.post(self.base_url,
263
+ headers=headers,
264
+ json=payload,
265
+ stream=True,
266
+ timeout=600)
267
+ break
268
+ except Exception as e:
269
+ time.sleep(0.1)
270
+ print(e)
271
+ content, reasoning_content = '', ''
272
+ for line in requested.iter_lines():
273
+ if not line:
274
+ continue
275
+ if line.startswith(b'data:'):
276
+ data = line[len("data: "):]
277
+ if data == b"[DONE]":
278
+ yield content, reasoning_content, True
279
+ break
280
+ delta = json.loads(data)['choices'][0]['delta']
281
+ content += delta['content']
282
+ reasoning_content += delta.get('reasoning_content', '')
283
+ yield content, reasoning_content, False
284
+
285
+ def __call__(self,
286
+ inputs: dict,
287
+ history: list[dict] = [],
288
+ mode: str = ConversationModeI18N.D,
289
+ temperature: float = 1.0,
290
+ online: bool = False):
291
+ messages = self.construct_messages(inputs=inputs, online=online)
292
+ updated_history = history + messages
293
+ for response, reasoning, finished in self.request(
294
+ messages=updated_history,
295
+ thinking=mode == ConversationModeI18N.D,
296
+ temperature=temperature):
297
+ if mode == ConversationModeI18N.D:
298
+ response = '<think>' + reasoning + '</think>' + response
299
+ yield response, updated_history + [{'role': 'assistant', 'content': response}], finished
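A minimal usage sketch of SeedVLInfer, mirroring how app.py drives it (MODEL_ID and API_KEY are assumed to reference a valid Ark endpoint; the example image is one of the repo's example files):

```python
# Usage sketch mirroring app.py: stream a Deep Thinking response for one image.
import os
from infer import SeedVLInfer, ConversationModeI18N

infer = SeedVLInfer(model_id=os.getenv('MODEL_ID'), api_key=os.getenv('API_KEY'))

history = []
inputs = {'text': 'Introduce this.', 'files': ['examples/bancopy.jpg']}
for response, history, finished in infer(inputs=inputs,
                                         history=history,
                                         mode=ConversationModeI18N.D,
                                         temperature=0.0):
    if finished:
        # In Deep Thinking mode the reasoning is wrapped in <think>...</think>
        # ahead of the final answer (see __call__ above).
        print(response)
```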
requirements.txt ADDED
@@ -0,0 +1,7 @@
1
+ supervision==0.25.1
2
+ openai==1.76.0
3
+ opencv-python==4.10.0.84
4
+ numpy==1.26.2
5
+ pillow==11.0.0
6
+ matplotlib==3.10.0
7
+ decord==0.6.0
visualizer.py ADDED
@@ -0,0 +1,126 @@
1
+ import os
2
+ import cv2
3
+ import numpy as np
4
+ import supervision as sv
5
+ from PIL import Image, ImageDraw, ImageFont
6
+ from supervision.annotators.utils import resolve_color
7
+ # visualization tools based on supervision
8
+ BOUNDING_BOX_ANNOTATOR = sv.BoundingBoxAnnotator(thickness=2)
9
+
10
+ class LabelAnnotator(sv.LabelAnnotator):
11
+
12
+ @staticmethod
13
+ def resolve_text_background_xyxy(
14
+ center_coordinates,
15
+ text_wh,
16
+ position,
17
+ ):
18
+ center_x, center_y = center_coordinates
19
+ text_w, text_h = text_wh
20
+ return center_x, center_y, center_x + text_w, center_y + text_h
21
+
22
+ def _draw_labels(
23
+ self,
24
+ scene: np.ndarray,
25
+ labels: list[str],
26
+ label_properties: np.ndarray,
27
+ detections,
28
+ custom_color_lookup,
29
+ ) -> None:
30
+ assert len(labels) == len(label_properties) == len(detections), (
31
+ f"Number of label properties ({len(label_properties)}), "
32
+ f"labels ({len(labels)}) and detections ({len(detections)}) "
33
+ "do not match."
34
+ )
35
+
36
+ color_lookup = (
37
+ custom_color_lookup
38
+ if custom_color_lookup is not None
39
+ else self.color_lookup
40
+ )
41
+
42
+ font = ImageFont.truetype("simhei.ttf", int(30 * self.text_scale))
43
+
44
+ for idx, label_property in enumerate(label_properties):
45
+ background_color = resolve_color(
46
+ color=self.color,
47
+ detections=detections,
48
+ detection_idx=idx,
49
+ color_lookup=color_lookup,
50
+ )
51
+ text_color = resolve_color(
52
+ color=self.text_color,
53
+ detections=detections,
54
+ detection_idx=idx,
55
+ color_lookup=color_lookup,
56
+ )
57
+
58
+ box_xyxy = label_property[:4]
59
+ text_height_padded = label_property[4]
60
+ self.draw_rounded_rectangle(
61
+ scene=scene,
62
+ xyxy=box_xyxy,
63
+ color=background_color.as_bgr(),
64
+ border_radius=self.border_radius,
65
+ )
66
+
67
+ text_x = box_xyxy[0] + self.text_padding
68
+ text_y = box_xyxy[1]
69
+
70
+ scene_pil = Image.fromarray(cv2.cvtColor(scene, cv2.COLOR_BGR2RGB))
71
+ draw = ImageDraw.Draw(scene_pil)
72
+ draw.text(
73
+ (text_x, text_y),
74
+ labels[idx],
75
+ font=font,
76
+ fill=(text_color.r, text_color.g, text_color.b),
77
+ )
78
+ scene[:] = cv2.cvtColor(np.array(scene_pil), cv2.COLOR_RGB2BGR)
79
+
80
+
81
+ LABEL_ANNOTATOR = LabelAnnotator(text_padding=4,
82
+ text_scale=0.5,
83
+ text_thickness=1)
84
+
85
+
86
+ POINT_ANNOTATOR = sv.DotAnnotator(radius=6)
87
+
88
+ def draw_boxes_points_with_labels(
89
+ cv2_image,
90
+ boxes=None,
91
+ points=None,
92
+ classes=None,
93
+ output_path=None,
94
+ ):
95
+ annotated_image = cv2.cvtColor(cv2_image, cv2.COLOR_BGR2RGB)
96
+
97
+ if boxes is not None and boxes.size:
98
+ detections = sv.Detections(
99
+ xyxy=boxes,
100
+ class_id=np.arange(len(boxes)),
101
+ confidence=np.ones(len(boxes))
102
+ )
103
+ annotated_image = BOUNDING_BOX_ANNOTATOR.annotate(
104
+ annotated_image, detections)
105
+ if points is not None and points.size:
106
+ points = np.concatenate([points, points], axis=1)
107
+ detections = sv.Detections(
108
+ xyxy=points,
109
+ class_id=np.arange(len(points)),
110
+ confidence=np.ones(len(points))
111
+ )
112
+ annotated_image = POINT_ANNOTATOR.annotate(
113
+ annotated_image, detections,
114
+ )
115
+ if classes:
116
+ annotated_image = LABEL_ANNOTATOR.annotate(
117
+ annotated_image, detections, labels=classes
118
+ )
119
+
120
+ if output_path:
121
+ cv2.imwrite(
122
+ output_path,
123
+ cv2.cvtColor(annotated_image, cv2.COLOR_RGB2BGR)
124
+ )
125
+
126
+ return annotated_image
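A short usage sketch for draw_boxes_points_with_labels, following how app.py calls it (the example values are assumptions; boxes are pixel-space xyxy and the label font requires simhei.ttf in the working directory):

```python
# Usage sketch: annotate one labeled box on a BGR image.
import cv2
import numpy as np
from visualizer import draw_boxes_points_with_labels

image = cv2.imread('examples/000000018380.jpeg')  # any BGR image works
boxes = np.array([[50.0, 60.0, 200.0, 220.0]])    # pixel-space xyxy, shape (N, 4)
labels = ['cup']                                  # one label per box

# Returns an RGB array; output_path (optional) is written back to disk as BGR.
annotated = draw_boxes_points_with_labels(image, boxes=boxes, classes=labels,
                                          output_path='annotated.jpg')
```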