import torch
import torchvision.transforms as T
from fastapi import FastAPI, File, UploadFile
from io import BytesIO
from PIL import Image
from torchvision.transforms.functional import InterpolationMode
from transformers import AutoModel, AutoTokenizer

# FastAPI app initialization
app = FastAPI()

# Device and dtype configuration: bfloat16 on GPU matches the reference
# InternVL usage; fall back to float32 on CPU, where bfloat16 is slow
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
dtype = torch.bfloat16 if device.type == "cuda" else torch.float32

# ImageNet normalization constants used by the InternVL vision encoder
IMAGENET_MEAN = (0.485, 0.456, 0.406)
IMAGENET_STD = (0.229, 0.224, 0.225)

def build_transform(input_size):
    # Standard InternVL image transform: force RGB, bicubic resize to a square
    # tile, then normalize with the ImageNet statistics above
    return T.Compose([
        T.Lambda(lambda img: img.convert('RGB') if img.mode != 'RGB' else img),
        T.Resize((input_size, input_size), interpolation=InterpolationMode.BICUBIC),
        T.ToTensor(),
        T.Normalize(mean=IMAGENET_MEAN, std=IMAGENET_STD)
    ])

def find_closest_aspect_ratio(aspect_ratio, target_ratios, width, height, image_size):
    # Pick the tiling grid whose aspect ratio best matches the input image;
    # on ties, prefer the larger grid when the image has enough pixels to fill it
    best_ratio_diff = float('inf')
    best_ratio = (1, 1)
    area = width * height
    for ratio in target_ratios:
        target_aspect_ratio = ratio[0] / ratio[1]
        ratio_diff = abs(aspect_ratio - target_aspect_ratio)
        if ratio_diff < best_ratio_diff:
            best_ratio_diff = ratio_diff
            best_ratio = ratio
        elif ratio_diff == best_ratio_diff:
            if area > 0.5 * image_size * image_size * ratio[0] * ratio[1]:
                best_ratio = ratio
    return best_ratio

def dynamic_preprocess(image, min_num=1, max_num=12, image_size=448, use_thumbnail=False):
    # Split the image into a grid of image_size x image_size tiles whose layout
    # approximates the original aspect ratio (InternVL's dynamic-resolution scheme)
    orig_width, orig_height = image.size
    aspect_ratio = orig_width / orig_height

    # All (columns, rows) grids with min_num <= columns * rows <= max_num
    target_ratios = set(
        (i, j) for n in range(min_num, max_num + 1) for i in range(1, n + 1)
        for j in range(1, n + 1) if min_num <= i * j <= max_num)
    target_ratios = sorted(target_ratios, key=lambda x: x[0] * x[1])

    target_aspect_ratio = find_closest_aspect_ratio(
        aspect_ratio, target_ratios, orig_width, orig_height, image_size)

    target_width = image_size * target_aspect_ratio[0]
    target_height = image_size * target_aspect_ratio[1]
    blocks = target_aspect_ratio[0] * target_aspect_ratio[1]

    # Resize to the target grid, then crop each tile left-to-right, top-to-bottom
    resized_img = image.resize((target_width, target_height))
    processed_images = []
    for i in range(blocks):
        box = (
            (i % (target_width // image_size)) * image_size,
            (i // (target_width // image_size)) * image_size,
            ((i % (target_width // image_size)) + 1) * image_size,
            ((i // (target_width // image_size)) + 1) * image_size
        )
        processed_images.append(resized_img.crop(box))
    if use_thumbnail and len(processed_images) != 1:
        # Append a downscaled view of the whole image as global context
        processed_images.append(image.resize((image_size, image_size)))
    return processed_images
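
# For example, with the closest-ratio selection above, a 1280x960 input (4:3)
# maps to a (4, 3) grid: twelve 448x448 tiles plus one thumbnail when
# use_thumbnail=True, i.e. 13 tiles in total.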

def load_image(image_file: BytesIO, input_size=448, max_num=12):
    # Returns a (num_tiles, 3, input_size, input_size) tensor on the model's
    # device and dtype
    image = Image.open(image_file).convert('RGB')
    transform = build_transform(input_size=input_size)
    images = dynamic_preprocess(image, image_size=input_size, use_thumbnail=True, max_num=max_num)
    pixel_values = torch.stack([transform(tile) for tile in images])
    return pixel_values.to(device=device, dtype=dtype)

# Load model and tokenizer once at startup
path = 'OpenGVLab/InternVL2_5-1B'
model = AutoModel.from_pretrained(
    path,
    torch_dtype=dtype,
    low_cpu_mem_usage=True,
    use_flash_attn=False,
    trust_remote_code=True
).eval().to(device)
tokenizer = AutoTokenizer.from_pretrained(path, trust_remote_code=True, use_fast=False)
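
# Note: the first startup downloads the checkpoint from the Hugging Face Hub;
# later runs reuse the local cache.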

@app.post("/predict")
async def predict(file: UploadFile = File(...), question: str = "Describe the image"):
    # Read the upload and preprocess it into model-ready tiles
    file_bytes = BytesIO(await file.read())
    pixel_values = load_image(file_bytes)

    # Generate a response; model.chat returns just the reply string unless
    # return_history=True is passed
    generation_config = dict(max_new_tokens=1024, do_sample=True)
    response = model.chat(tokenizer, pixel_values, question, generation_config)
    return {"question": question, "response": response}