TheEighthDay commited on
Commit
83ba4e4
·
verified ·
1 Parent(s): df7d303

Delete simple_inference.py

Browse files
Files changed (1) hide show
  1. simple_inference.py +0 -238
simple_inference.py DELETED
@@ -1,238 +0,0 @@
1
- #!/usr/bin/env python3
2
- # -*- coding: utf-8 -*-
3
-
4
- import argparse
5
- import json
6
- import os
7
- import torch
8
- import base64
9
- from io import BytesIO
10
- from PIL import Image
11
-
12
- # 条件导入,根据选择的推理引擎
13
- try:
14
- from transformers import AutoProcessor, Qwen2_5_VLForConditionalGeneration
15
- transformers_available = True
16
- except ImportError:
17
- transformers_available = False
18
- print("警告: transformers相关库未安装,无法使用transformers引擎")
19
-
20
- try:
21
- from vllm import LLM, SamplingParams
22
- vllm_available = True
23
- except ImportError:
24
- vllm_available = False
25
- print("警告: vllm相关库未安装,无法使用vllm引擎")
26
-
27
- # 合并 qwen_vl_utils 的代码
28
- def process_vision_info(messages):
29
- """处理多模态消息中的图像和视频信息
30
-
31
- Args:
32
- messages: 包含图像或视频的消息列表
33
-
34
- Returns:
35
- images_data: 处理后的图像数据
36
- videos_data: 处理后的视频数据
37
- """
38
- images_list, videos_list = [], []
39
- for message in messages:
40
- content = message.get("content", None)
41
- if isinstance(content, str):
42
- # 纯文本消息,不处理
43
- continue
44
- elif isinstance(content, list):
45
- # 混合消息,可能包含图像或视频
46
- for item in content:
47
- if not isinstance(item, dict):
48
- continue
49
-
50
- # 处理图像
51
- if item.get("type") == "image" and "image" in item:
52
- image = item["image"]
53
- if isinstance(image, str):
54
- # 图像URL或路径,尝试加载
55
- try:
56
- image = Image.open(image)
57
- except Exception as e:
58
- print(f"图像加载失败: {e}")
59
- continue
60
-
61
- # 转换PIL图像为base64编码
62
- if isinstance(image, Image.Image):
63
- buffered = BytesIO()
64
- image.save(buffered, format="PNG")
65
- image_str = base64.b64encode(buffered.getvalue()).decode("utf-8")
66
- images_list.append(image_str)
67
-
68
- # 处理视频(如有需要)
69
- elif item.get("type") == "video" and "video" in item:
70
- # 暂不支持视频
71
- pass
72
-
73
- return images_list or None, videos_list or None
74
-
75
- def predict_location(
76
- image_path,
77
- model_name="Qwen/Qwen2.5-VL-7B-Instruct",
78
- inference_engine="transformers"
79
- ):
80
- """
81
- 对单个图片进行位置识别预测
82
-
83
- 参数:
84
- image_path: 图片文件路径
85
- model_name: 模型名称或路径
86
- inference_engine: 推理引擎,"vllm" 或 "transformers"
87
-
88
- 返回:
89
- 预测结果文本
90
- """
91
- # 检查图片是否存在
92
- if not os.path.exists(image_path):
93
- return f"错误: 图片文件不存在: {image_path}"
94
-
95
- # 加载图片
96
- try:
97
- image = Image.open(image_path)
98
- print(f"成功加载图片: {image_path}")
99
- except Exception as e:
100
- return f"错误: 无法加载图片: {str(e)}"
101
-
102
- # 加载处理器
103
- print(f"加载处理器: {model_name}")
104
- processor = AutoProcessor.from_pretrained(model_name, padding_side='left')
105
-
106
- # 构建提示消息 - 简化版本,没有SFT和COT
107
- question_text = "In which country and within which first-level administrative region of that country was this picture taken?Please answer in the format of <answer>$country,administrative_area_level_1$</answer>?"
108
- system_message = "You are a helpful assistant good at solving problems with step-by-step reasoning. You should first thinks about the reasoning process in the mind and then provides the user with the answer. The reasoning process and answer are enclosed within <think> </think> and <answer> </answer> tags."
109
-
110
- # 构建简化后的提示消息
111
- prompt_messages = [
112
- {
113
- "role": "system",
114
- "content": [
115
- {"type": "text", "text": system_message}
116
- ]
117
- },
118
- {
119
- "role": "user",
120
- "content": [
121
- {"type": "image", "image": image},
122
- {"type": "text", "text": question_text}
123
- ]
124
- }
125
- ]
126
-
127
- # 根据选定的引擎进行推理
128
- if inference_engine == "vllm":
129
- if not vllm_available:
130
- return "错误: vLLM库不可用,请安装vllm或选择transformers引擎"
131
-
132
- # 使用vLLM进行推理
133
- print(f"使用vLLM加载模型: {model_name}")
134
- llm = LLM(
135
- model=model_name,
136
- limit_mm_per_prompt={"image": 10, "video": 10},
137
- dtype="auto",
138
- gpu_memory_utilization=0.95,
139
- )
140
-
141
- # 设置采样参数
142
- sampling_params = SamplingParams(
143
- temperature=0.7,
144
- top_p=0.8,
145
- repetition_penalty=1.05,
146
- max_tokens=2048,
147
- stop_token_ids=[],
148
- )
149
-
150
- # 处理消息为vLLM格式
151
- prompt = processor.apply_chat_template(
152
- prompt_messages,
153
- tokenize=False,
154
- add_generation_prompt=True,
155
- )
156
-
157
- # 处理图像数据
158
- image_inputs, video_inputs = process_vision_info(prompt_messages)
159
-
160
- mm_data = {}
161
- if image_inputs is not None:
162
- mm_data["image"] = image_inputs
163
-
164
- # 构建vLLM输入
165
- llm_input = {
166
- "prompt": prompt,
167
- "multi_modal_data": mm_data,
168
- }
169
-
170
- # 生成回答
171
- outputs = llm.generate([llm_input], sampling_params=sampling_params)
172
- response = outputs[0].outputs[0].text
173
-
174
- else: # transformers
175
- if not transformers_available:
176
- return "错误: Transformers相关库不可用,请安装必要的包"
177
-
178
- # 使用transformers加载模型
179
- print(f"使用transformers加载模型: {model_name}")
180
- model = Qwen2_5_VLForConditionalGeneration.from_pretrained(
181
- model_name, torch_dtype="auto", device_map="auto"
182
- )
183
-
184
- # 准备输入
185
- text = processor.apply_chat_template(
186
- prompt_messages, tokenize=False, add_generation_prompt=True
187
- )
188
-
189
- # 处理输入
190
- inputs = processor(
191
- text=text,
192
- images=prompt_messages[1]['content'][0]['image'],
193
- return_tensors="pt",
194
- )
195
-
196
- inputs = inputs.to(model.device)
197
-
198
- # 生成回答
199
- with torch.no_grad():
200
- generated_ids = model.generate(**inputs, max_new_tokens=2048)
201
-
202
- # 处理输出
203
- generated_ids_trimmed = generated_ids[0][len(inputs['input_ids'][0]):]
204
- response = processor.decode(generated_ids_trimmed, skip_special_tokens=True, clean_up_tokenization_spaces=False)
205
-
206
- # 清理GPU缓存
207
- if torch.cuda.is_available():
208
- torch.cuda.empty_cache()
209
-
210
- print("\n=== 推理结果 ===")
211
- print(response)
212
- print("=================\n")
213
-
214
- return response
215
-
216
- # if __name__ == "__main__":
217
- # # 命令行参数设置
218
- # parser = argparse.ArgumentParser(description='对单个图片进行位置识别预测')
219
- # parser.add_argument('--image_path', type=str, required=True,
220
- # help='图片文件路径')
221
- # parser.add_argument('--model_name', type=str, default="Qwen/Qwen2.5-VL-7B-Instruct",
222
- # help='模型名称或路径')
223
- # parser.add_argument('--inference_engine', type=str, default="transformers", choices=["vllm", "transformers"],
224
- # help='推理引擎: vllm 或 transformers')
225
-
226
- # args = parser.parse_args()
227
-
228
- # # 单个图片推理
229
- # result = predict_location(
230
- # image_path=args.image_path,
231
- # model_name=args.model_name,
232
- # inference_engine=args.inference_engine
233
- # )
234
-
235
- # print(f"最终预测结果: {result}")
236
-
237
- # 使用示例:
238
- # python simple_inference.py --image_path /data/phd/tiankaibin/dataset/data/streetview_images_first_tier_cities/testaccio_rome_italy_h45_r100_20250317_183133.jpg --model_name TheEighthDay/SeekWorld_RL_PLUS --inference_engine vllm