# vision/demo.py
from transformers import (
    Qwen2VLForConditionalGeneration,
    AutoProcessor,
    TextIteratorStreamer,
)
from PIL import Image
from threading import Thread
import gradio as gr

model_name = "scb10x/typhoon2-qwen2vl-7b-vision-instruct"
model = Qwen2VLForConditionalGeneration.from_pretrained(
    model_name, torch_dtype="auto", device_map="auto"
)

# Bound the image resolution fed to the model: the processor resizes each
# image so its pixel count stays within [min_pixels, max_pixels].
min_pixels = 256 * 28 * 28
max_pixels = 1280 * 28 * 28
processor = AutoProcessor.from_pretrained(
    model_name, min_pixels=min_pixels, max_pixels=max_pixels
)
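
# Quick sanity check on the bounds above (illustrative only): with Qwen2-VL's
# 28 x 28 pixel patches, these limits correspond to roughly 256-1280 patches
# per image before the model's patch merging.
assert min_pixels // (28 * 28) == 256
assert max_pixels // (28 * 28) == 1280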


def bot_streaming(message, history, max_new_tokens=512):
    txt = message["text"]
    messages = []
    images = []

    # Rebuild prior turns in the chat-template format. In Gradio's tuple-style
    # history, an image upload appears as ((file_path,), ...) and the caption
    # text arrives as the next history entry.
    for i, msg in enumerate(history):
        if isinstance(msg[0], tuple):
            # Image turn: pair the image with the text of the following entry.
            messages.append(
                {
                    "role": "user",
                    "content": [
                        {"type": "text", "text": history[i + 1][0]},
                        {"type": "image"},
                    ],
                }
            )
            messages.append(
                {
                    "role": "assistant",
                    "content": [{"type": "text", "text": history[i + 1][1]}],
                }
            )
            images.append(Image.open(msg[0][0]).convert("RGB"))
        elif isinstance(history[i - 1][0], tuple) and isinstance(msg[0], str):
            # This text turn was already folded into the image turn above.
            pass
        elif isinstance(history[i - 1][0], str) and isinstance(msg[0], str):
            # Text-only turn.
            messages.append(
                {"role": "user", "content": [{"type": "text", "text": msg[0]}]}
            )
            messages.append(
                {"role": "assistant", "content": [{"type": "text", "text": msg[1]}]}
            )

    # Append the current user message, attaching the uploaded image if any.
    if len(message["files"]) == 1:
        if isinstance(message["files"][0], str):  # plain file path
            image = Image.open(message["files"][0]).convert("RGB")
        else:  # file dict with a "path" key
            image = Image.open(message["files"][0]["path"]).convert("RGB")
        images.append(image)
        messages.append(
            {
                "role": "user",
                "content": [{"type": "text", "text": txt}, {"type": "image"}],
            }
        )
    else:
        messages.append({"role": "user", "content": [{"type": "text", "text": txt}]})

    texts = processor.apply_chat_template(messages, add_generation_prompt=True)
    if images == []:
        inputs = processor(text=texts, return_tensors="pt").to("cuda")
    else:
        inputs = processor(text=texts, images=images, return_tensors="pt").to("cuda")

    # Stream decoded tokens as generate() produces them.
    streamer = TextIteratorStreamer(
        processor, skip_special_tokens=True, skip_prompt=True
    )
    generation_kwargs = dict(
        inputs,
        streamer=streamer,
        max_new_tokens=max_new_tokens,
        do_sample=True,
        temperature=0.6,
        top_p=0.9,
    )
    # Run generation in a background thread so we can yield partial output.
    thread = Thread(target=model.generate, kwargs=generation_kwargs)
    thread.start()

    buffer = ""
    for new_text in streamer:
        buffer += new_text
        yield buffer
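
# A minimal smoke test of bot_streaming outside the UI (hypothetical file
# path; assumes a CUDA device, matching the .to("cuda") calls above):
#
#   for partial in bot_streaming(
#       {"text": "Describe this image.", "files": ["example.jpg"]}, history=[]
#   ):
#       print(partial)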

demo = gr.ChatInterface(
    fn=bot_streaming,
    title="Typhoon 2 Vision",
    textbox=gr.MultimodalTextbox(),
    additional_inputs=[
        gr.Slider(
            minimum=512,
            maximum=1024,
            value=512,
            step=1,
            label="Maximum number of new tokens to generate",
        )
    ],
    cache_examples=False,
    stop_btn="Stop Generation",
    fill_height=True,
    multimodal=True,
)

demo.launch(debug=True)
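
# Note: launch(debug=True) blocks and surfaces errors in the console. To expose
# the demo on a network interface, Gradio's standard launch options apply, e.g.:
#
#   demo.launch(server_name="0.0.0.0", server_port=7860)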