File size: 1,677 Bytes
5ca9963
2f4647f
2319467
 
2f4647f
5ca9963
2f4647f
 
 
699fe26
2f4647f
 
5dda5d9
2f4647f
2319467
 
2f4647f
2319467
73d58c2
2f4647f
e5e2e79
73d58c2
5a6e90f
73d58c2
2f4647f
5a6e90f
73d58c2
 
5a6e90f
f580378
 
 
 
73d58c2
 
 
 
 
 
 
2f4647f
5a6e90f
5ca9963
2f4647f
 
2319467
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
import torch
from transformers import Qwen2VLForConditionalGeneration, AutoProcessor
from PIL import Image
import requests
from io import BytesIO

# Run Qwen2-VL-2B-Instruct on a single demo image and print its description.
#
# Qwen2-VL expects vision placeholder tokens in the prompt at the position of
# the image. Those are produced by the processor's chat template, so the prompt
# is built from a chat message containing an {"type": "image"} entry instead of
# being passed to the processor as bare text.

# Initialize the model and processor
model_name = "Qwen/Qwen2-VL-2B-Instruct"
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

model = Qwen2VLForConditionalGeneration.from_pretrained(model_name).to(device)
processor = AutoProcessor.from_pretrained(model_name)

# Load the image from URL; time out instead of hanging forever and fail fast
# on HTTP errors rather than handing an error page to PIL.
image_url = "https://qianwen-res.oss-cn-beijing.aliyuncs.com/Qwen-VL/assets/demo.jpeg"
response = requests.get(image_url, timeout=30)
response.raise_for_status()
# Normalize palette/RGBA/grayscale images to 3-channel RGB.
img = Image.open(BytesIO(response.content)).convert("RGB")

text_input = "Describe this image."

# Wrap the question in a user chat turn that also carries an image slot. The
# chat template renders the slot as the model's vision placeholder tokens,
# which the processor then expands to match the image's patch count.
messages = [
    {
        "role": "user",
        "content": [
            {"type": "image"},
            {"type": "text", "text": text_input},
        ],
    }
]
prompt = processor.apply_chat_template(
    messages, tokenize=False, add_generation_prompt=True
)

# The processor handles resizing, normalization, and tokenization
inputs = processor(
    images=[img],
    text=[prompt],
    return_tensors="pt",
    padding=True,  # pad to the longest prompt in the batch (no-op for one prompt)
)

# Check the number of tokens generated by the processor and the shape of inputs
print("Input tokens:", inputs.input_ids.shape)
print("Image features shape:", inputs.pixel_values.shape)

# Explicit checks rather than `assert`, which is stripped under `python -O`.
if inputs.input_ids.shape[1] == 0:
    raise RuntimeError("No tokens generated for text input!")
if inputs.pixel_values.shape[0] == 0:
    raise RuntimeError("No features generated for the image!")

# Move every tensor to the model's device. BatchFeature.to returns the same
# mapping type, so attribute access (inputs.input_ids) keeps working.
inputs = inputs.to(device)

# Inference
generated_ids = model.generate(**inputs, max_new_tokens=128)

# generate() returns the (padded) prompt followed by the new tokens; drop the
# prompt prefix so only the model's answer is decoded.
prompt_length = inputs.input_ids.shape[1]
new_token_ids = generated_ids[:, prompt_length:]
output_text = processor.batch_decode(new_token_ids, skip_special_tokens=True)
print(output_text)