vykanand committed
Commit 5a6e90f · verified · 1 Parent(s): e5e2e79

Update app.py

Files changed (1)
  1. app.py +54 -44
app.py CHANGED
@@ -1,50 +1,60 @@
- import gradio as gr
- import torch
- from torchvision import models, transforms
- from PIL import Image
- import requests
- from io import BytesIO
-
- # Load pre-trained MobileNetV2 model (you can choose another model as needed)
- model = models.mobilenet_v2(pretrained=True)
- model.eval()
-
- # Define the image transformation (resize, normalization)
- transform = transforms.Compose([
-     transforms.Resize(256),
-     transforms.CenterCrop(224),
-     transforms.ToTensor(),
-     transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
- ])
-
- # Download the ImageNet class labels (you can replace this with your own if needed)
- LABELS_URL = "https://storage.googleapis.com/download.tensorflow.org/data/imagenet_class_index.json"
- class_idx = requests.get(LABELS_URL).json()
- idx2label = [class_idx[str(k)][1] for k in range(len(class_idx))]
-
- # Function to perform image inference
- def predict_image(image):
-     image = Image.open(BytesIO(image)).convert("RGB")
-     image = transform(image).unsqueeze(0)
-
-     # Perform inference
-     with torch.no_grad():
-         output = model(image)
-
-     # Get the predicted label
-     _, predicted_class = torch.max(output, 1)
-     label = idx2label[predicted_class.item()]
-     return label
-
- # Gradio interface
- with gr.Blocks() as demo:
-     gr.Markdown("### Image Classification with MobileNetV2")
-
-     with gr.Row():
-         image_input = gr.Image(type="bytes", label="Upload Image")
-         result_output = gr.Textbox(label="Predicted Label")
-
-     # Endpoint for inference
-     gr.Interface(fn=predict_image, inputs=image_input, outputs=result_output, api_name="/predict_image")
-
- demo.launch(debug=True)
+ from transformers import Qwen2VLForConditionalGeneration, AutoTokenizer, AutoProcessor
+ from qwen_vl_utils import process_vision_info
+
+ # default: Load the model on the available device(s)
+ model = Qwen2VLForConditionalGeneration.from_pretrained(
+     "Qwen/Qwen2-VL-2B-Instruct", torch_dtype="auto", device_map="auto"
+ )
+
+ # We recommend enabling flash_attention_2 for better acceleration and memory saving, especially in multi-image and video scenarios.
+ # model = Qwen2VLForConditionalGeneration.from_pretrained(
+ #     "Qwen/Qwen2-VL-2B-Instruct",
+ #     torch_dtype=torch.bfloat16,
+ #     attn_implementation="flash_attention_2",
+ #     device_map="auto",
+ # )
+
+ # default processor
+ processor = AutoProcessor.from_pretrained("Qwen/Qwen2-VL-2B-Instruct")
+
+ # The default range for the number of visual tokens per image in the model is 4-16384. You can set min_pixels and max_pixels according to your needs, such as a token count range of 256-1280, to balance speed and memory usage.
+ # min_pixels = 256*28*28
+ # max_pixels = 1280*28*28
+ # processor = AutoProcessor.from_pretrained("Qwen/Qwen2-VL-2B-Instruct", min_pixels=min_pixels, max_pixels=max_pixels)
+
+ messages = [
+     {
+         "role": "user",
+         "content": [
+             {
+                 "type": "image",
+                 "image": "https://qianwen-res.oss-cn-beijing.aliyuncs.com/Qwen-VL/assets/demo.jpeg",
+             },
+             {"type": "text", "text": "Describe this image."},
+         ],
+     }
+ ]
+
+ # Preparation for inference
+ text = processor.apply_chat_template(
+     messages, tokenize=False, add_generation_prompt=True
+ )
+ image_inputs, video_inputs = process_vision_info(messages)
+ inputs = processor(
+     text=[text],
+     images=image_inputs,
+     videos=video_inputs,
+     padding=True,
+     return_tensors="pt",
+ )
+ inputs = inputs.to("cuda")
+
+ # Inference: Generation of the output
+ generated_ids = model.generate(**inputs, max_new_tokens=128)
+ generated_ids_trimmed = [
+     out_ids[len(in_ids) :] for in_ids, out_ids in zip(inputs.input_ids, generated_ids)
+ ]
+ output_text = processor.batch_decode(
+     generated_ids_trimmed, skip_special_tokens=True, clean_up_tokenization_spaces=False
+ )
+ print(output_text)
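
Note: this commit replaces the Gradio app with a one-shot inference script that runs once against a fixed demo image URL and hard-codes `inputs.to("cuda")`. If the Space still needs an interactive endpoint like the removed version, a minimal sketch along the following lines could wrap the new Qwen2-VL pipeline back into Gradio. The function name `describe_image`, the `type="pil"` image input, the editable prompt box, and the use of `model.device` in place of the hard-coded "cuda" are illustrative choices, not part of this commit; it also assumes `process_vision_info` accepts an in-memory PIL image in the "image" field as well as a URL.

import gradio as gr
from transformers import Qwen2VLForConditionalGeneration, AutoProcessor
from qwen_vl_utils import process_vision_info

# Load the same checkpoint as the committed script, once at startup.
model = Qwen2VLForConditionalGeneration.from_pretrained(
    "Qwen/Qwen2-VL-2B-Instruct", torch_dtype="auto", device_map="auto"
)
processor = AutoProcessor.from_pretrained("Qwen/Qwen2-VL-2B-Instruct")

def describe_image(image, prompt):  # hypothetical wrapper, not in the commit
    # Single-turn chat message; the uploaded PIL image replaces the demo URL.
    messages = [
        {
            "role": "user",
            "content": [
                {"type": "image", "image": image},
                {"type": "text", "text": prompt},
            ],
        }
    ]
    text = processor.apply_chat_template(
        messages, tokenize=False, add_generation_prompt=True
    )
    image_inputs, video_inputs = process_vision_info(messages)
    inputs = processor(
        text=[text],
        images=image_inputs,
        videos=video_inputs,
        padding=True,
        return_tensors="pt",
    ).to(model.device)  # follow the model's device instead of hard-coding "cuda"

    generated_ids = model.generate(**inputs, max_new_tokens=128)
    trimmed = [
        out_ids[len(in_ids):]
        for in_ids, out_ids in zip(inputs.input_ids, generated_ids)
    ]
    return processor.batch_decode(
        trimmed, skip_special_tokens=True, clean_up_tokenization_spaces=False
    )[0]

demo = gr.Interface(
    fn=describe_image,
    inputs=[
        gr.Image(type="pil", label="Upload Image"),
        gr.Textbox(value="Describe this image.", label="Prompt"),
    ],
    outputs=gr.Textbox(label="Model Output"),
)

if __name__ == "__main__":
    demo.launch()

If the installed qwen_vl_utils version does not accept in-memory images, saving the upload to a temporary file and passing its path in the "image" field is an equivalent workaround.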