File size: 8,068 Bytes
d5cd46b
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233

import argparse
import json
import os
import torch
import base64
from io import BytesIO
from PIL import Image

# Conditional imports, depending on which inference engine is selected.
# Each engine is optional: an ImportError is recorded in a module-level flag
# (and a warning is printed) instead of aborting the import of this module.
try:
    from transformers import AutoProcessor, Qwen2_5_VLForConditionalGeneration
    transformers_available = True
except ImportError:
    # AutoProcessor / Qwen2_5_VLForConditionalGeneration stay undefined in this case.
    transformers_available = False
    print("警告: transformers相关库未安装,无法使用transformers引擎")

try:
    from vllm import LLM, SamplingParams
    vllm_available = True
except ImportError:
    # LLM / SamplingParams stay undefined in this case.
    vllm_available = False
    print("警告: vllm相关库未安装,无法使用vllm引擎")

# Code merged in from qwen_vl_utils.

# Pillow modes the PNG writer can store directly; anything else (e.g. CMYK)
# must be converted before saving.
_PNG_WRITABLE_MODES = frozenset({"1", "L", "LA", "I", "I;16", "P", "RGB", "RGBA"})


def _encode_image_as_base64_png(image):
    """Return *image* (a PIL image) serialized as a base64-encoded PNG string."""
    if image.mode not in _PNG_WRITABLE_MODES:
        # PNG cannot hold this mode; RGB is the safe common denominator.
        image = image.convert("RGB")
    buffered = BytesIO()
    image.save(buffered, format="PNG")
    return base64.b64encode(buffered.getvalue()).decode("utf-8")


def process_vision_info(messages):
    """Collect the images referenced by multimodal chat messages.

    Args:
        messages: Iterable of chat messages. A message is expected to be a
            dict whose "content" is either a plain string (ignored) or a list
            of content items. Items shaped like
            ``{"type": "image", "image": <path str or PIL image>}`` are
            collected; everything else is skipped. ``None`` is treated as an
            empty message list.

    Returns:
        A tuple ``(images_data, videos_data)``. ``images_data`` is a list of
        base64-encoded PNG strings in encounter order, or ``None`` when no
        image was collected. ``videos_data`` is always ``None`` because video
        items are not supported yet.
    """
    images_list = []
    for message in messages or ():
        if not isinstance(message, dict):
            # Be as lenient with malformed messages as with malformed items.
            continue
        content = message.get("content", None)
        if not isinstance(content, list):
            # Plain-text (or missing) content carries no vision data.
            continue

        for item in content:
            if not isinstance(item, dict):
                continue
            if item.get("type") != "image" or "image" not in item:
                # Video items are not supported yet; other item types are text.
                continue

            image = item["image"]
            if isinstance(image, str):
                # Path given: open, decode and encode inside one try block.
                # Image.open is lazy, so a corrupt file only fails once the
                # pixels are read during encoding; the with-block also closes
                # the underlying file handle.
                try:
                    with Image.open(image) as opened:
                        image_str = _encode_image_as_base64_png(opened)
                except Exception as e:
                    print(f"图像加载失败: {e}")
                    continue
                images_list.append(image_str)
            elif isinstance(image, Image.Image):
                images_list.append(_encode_image_as_base64_png(image))

    return images_list or None, None

def _build_prompt_messages(image):
    """Build the system + user chat messages asking where *image* was taken."""
    question_text = "In which country and within which first-level administrative region of that country was this picture taken?Please answer in the format of <answer>$country,administrative_area_level_1$</answer>?"
    system_message = "You are a helpful assistant good at solving problems with step-by-step reasoning. You should first thinks about the reasoning process in the mind and then provides the user with the answer. The reasoning process and answer are enclosed within <think> </think> and <answer> </answer> tags."

    return [
        {
            "role": "system",
            "content": [
                {"type": "text", "text": system_message}
            ]
        },
        {
            "role": "user",
            "content": [
                {"type": "image", "image": image},
                {"type": "text", "text": question_text}
            ]
        }
    ]


def _generate_with_vllm(model_name, processor, prompt_messages, image):
    """Load the model with vLLM and return the generated answer text."""
    print(f"使用vLLM加载模型: {model_name}")
    llm = LLM(
        model=model_name,
        limit_mm_per_prompt={"image": 10, "video": 10},
        dtype="auto",
        gpu_memory_utilization=0.95,
    )

    sampling_params = SamplingParams(
        temperature=0.7,
        top_p=0.8,
        repetition_penalty=1.05,
        max_tokens=2048,
        stop_token_ids=[],
    )

    # Render the chat template to a plain prompt string for vLLM.
    prompt = processor.apply_chat_template(
        prompt_messages,
        tokenize=False,
        add_generation_prompt=True,
    )

    # vLLM expects the decoded PIL image itself as multi-modal data,
    # not a base64-encoded string.
    llm_input = {
        "prompt": prompt,
        "multi_modal_data": {"image": image},
    }

    outputs = llm.generate([llm_input], sampling_params=sampling_params)
    return outputs[0].outputs[0].text


def _generate_with_transformers(model_name, processor, prompt_messages, image):
    """Load the model with transformers and return the generated answer text."""
    print(f"使用transformers加载模型: {model_name}")
    model = Qwen2_5_VLForConditionalGeneration.from_pretrained(
        model_name, torch_dtype="auto", device_map="auto"
    )

    text = processor.apply_chat_template(
        prompt_messages, tokenize=False, add_generation_prompt=True
    )
    inputs = processor(
        text=text,
        images=image,
        return_tensors="pt",
    )
    inputs = inputs.to(model.device)

    with torch.no_grad():
        generated_ids = model.generate(**inputs, max_new_tokens=2048)

    # Drop the prompt tokens so only the newly generated answer is decoded.
    prompt_length = len(inputs['input_ids'][0])
    generated_ids_trimmed = generated_ids[0][prompt_length:]
    response = processor.decode(generated_ids_trimmed, skip_special_tokens=True, clean_up_tokenization_spaces=False)

    # Release cached GPU memory
    if torch.cuda.is_available():
        torch.cuda.empty_cache()

    return response


def predict_location(
    image_path,
    model_name="Qwen/Qwen2.5-VL-7B-Instruct",
    inference_engine="transformers"
):
    """Predict the country and first-level administrative region of one image.

    Args:
        image_path: Path of the image file.
        model_name: Model name or local path passed to the processor and engine.
        inference_engine: "vllm" selects vLLM; any other value selects
            transformers.

    Returns:
        The generated answer text. Input and environment problems (missing
        file, unreadable image, required library not installed) are reported
        as a returned string starting with "错误:" rather than raised.
    """
    # Check that the image exists
    if not os.path.exists(image_path):
        return f"错误: 图片文件不存在: {image_path}"

    # Check library availability before touching any library symbol. The
    # processor comes from transformers for both engines, so a missing
    # transformers install must be reported here instead of surfacing as a
    # NameError on AutoProcessor.
    use_vllm = inference_engine == "vllm"
    if use_vllm and not vllm_available:
        return "错误: vLLM库不可用,请安装vllm或选择transformers引擎"
    if not transformers_available:
        return "错误: Transformers相关库不可用,请安装必要的包"

    # Load the image. Image.open is lazy, so convert() forces a full decode
    # here: a corrupt file is reported through the error return below, and the
    # file handle can be closed right away.
    try:
        with Image.open(image_path) as opened:
            image = opened.convert("RGB")
        print(f"成功加载图片: {image_path}")
    except Exception as e:
        return f"错误: 无法加载图片: {str(e)}"

    # Load the processor
    print(f"加载处理器: {model_name}")
    processor = AutoProcessor.from_pretrained(model_name, padding_side='left')

    prompt_messages = _build_prompt_messages(image)

    # Run inference with the selected engine
    if use_vllm:
        response = _generate_with_vllm(model_name, processor, prompt_messages, image)
    else:
        response = _generate_with_transformers(model_name, processor, prompt_messages, image)

    print("\n=== 推理结果 ===")
    print(response)
    print("=================\n")

    return response

if __name__ == "__main__":
    # Command-line interface: one required image path plus optional
    # model and inference-engine overrides.
    cli = argparse.ArgumentParser(description='对单个图片进行位置识别预测')
    cli.add_argument(
        '--image_path',
        type=str,
        required=True,
        help='图片文件路径',
    )
    cli.add_argument(
        '--model_name',
        type=str,
        default="Qwen/Qwen2.5-VL-7B-Instruct",
        help='模型名称或路径',
    )
    cli.add_argument(
        '--inference_engine',
        type=str,
        default="transformers",
        choices=["vllm", "transformers"],
        help='推理引擎: vllm 或 transformers',
    )
    options = cli.parse_args()

    # Run inference on the single image and echo the final prediction.
    prediction = predict_location(
        image_path=options.image_path,
        model_name=options.model_name,
        inference_engine=options.inference_engine,
    )

    print(f"最终预测结果: {prediction}")