File size: 3,520 Bytes
168da77 8584d3c 06de88f bfd4b05 5ae7f9c d364219 168da77 0963b2f 168da77 d3f5533 168da77 d364219 0963b2f 5659ce7 24384a7 5659ce7 24384a7 5659ce7 24384a7 d69619f 24384a7 1b08433 24384a7 d1df8d3 24384a7 daa8caf 24384a7 5ae7f9c 24384a7 5ae7f9c 24384a7 5ae7f9c 24384a7 5ae7f9c 24384a7 1c699c0 55f048a 1c699c0 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 |
import torch
import gradio as gr
from transformers import TextIteratorStreamer, AutoProcessor, LlavaForConditionalGeneration
from PIL import Image
from threading import Thread
import spaces
import accelerate
import time
# HTML header shown above the chat; rendered with gr.Markdown inside the Blocks layout.
DESCRIPTION = '''
<div>
<h1 style="text-align: center;">Krypton π</h1>
<p>This uses an Open Source model from <a href="https://huggingface.co/xtuner/llava-llama-3-8b-v1_1-transformers"><b>xtuner/llava-llama-3-8b-v1_1-transformers</b></a></p>
</div>
'''
# Hugging Face checkpoint used for both the model weights and the processor.
model_id = "xtuner/llava-llama-3-8b-v1_1-transformers"

# The processor bundles the tokenizer and image preprocessor for this checkpoint.
processor = AutoProcessor.from_pretrained(model_id)

# Load the weights in float16 and place the whole model on the GPU.
model = LlavaForConditionalGeneration.from_pretrained(
    model_id,
    low_cpu_mem_usage=True,
    torch_dtype=torch.float16,
).to("cuda")

# Have generation stop on the tokenizer's own end-of-sequence token id.
model.generation_config.eos_token_id = processor.tokenizer.eos_token_id
def _find_image_path(message, history):
    """Return the path of the image to condition on, or None if none is available.

    Prefers the last file attached to the current ``message``. Otherwise it falls
    back to the most recent image found in ``history``.
    """
    files = message["files"]
    if files:
        # The last entry is either a dict carrying a "path" key or a plain path string.
        last = files[-1]
        return last["path"] if isinstance(last, dict) else last
    # No upload this turn: earlier uploads live in history as tuples whose first
    # element is the file path. Scan in order so the most recent one wins.
    image_path = None
    for hist in history:
        if isinstance(hist[0], tuple):
            image_path = hist[0][0]
    return image_path


@spaces.GPU
def bot_streaming(message, history):
    """Stream a LLaVA-Llama-3 reply to ``message`` about the relevant image.

    Args:
        message: MultimodalTextbox payload with "text" and "files" keys.
        history: Chat history supplied by gr.ChatInterface; earlier image uploads
            appear as tuples whose first element is the file path.

    Yields:
        The assistant reply accumulated so far, growing as tokens arrive.

    Raises:
        gr.Error: If neither this turn nor the history contains an image.
    """
    print(message)

    image_path = _find_image_path(message, history)
    if image_path is None:
        # gr.Error only reaches the UI when it is raised, not merely constructed.
        raise gr.Error("You need to upload an image for LLaVA to work.")

    prompt = f"<|start_header_id|>user<|end_header_id|>\n\n<image>\n{message['text']}<|eot_id|><|start_header_id|>assistant<|end_header_id|>\n\n"
    image = Image.open(image_path)
    # Keyword arguments avoid depending on the processor's positional order, which
    # has differed between transformers releases.
    inputs = processor(text=prompt, images=image, return_tensors='pt').to(0, torch.float16)

    # Special tokens are kept in the decoded stream so <|eot_id|> can be detected below.
    streamer = TextIteratorStreamer(processor, skip_special_tokens=False, skip_prompt=True)
    generation_kwargs = dict(inputs, streamer=streamer, max_new_tokens=1024, do_sample=False)
    # generate() runs in a background thread; this generator consumes the streamer.
    thread = Thread(target=model.generate, kwargs=generation_kwargs)
    thread.start()

    buffer = ""
    reached_eot = False
    for new_text in streamer:
        if reached_eot:
            # Keep draining so generation can finish, but ignore text after end-of-turn.
            continue
        if "<|eot_id|>" in new_text:
            new_text = new_text.split("<|eot_id|>")[0]
            reached_eot = True
        buffer += new_text
        # Small pause between UI updates (pacing of the streamed display).
        time.sleep(0.06)
        yield buffer
    thread.join()
# Chat transcript panel plus an input box that accepts text and image uploads.
chatbot = gr.Chatbot(height=600, label="Krypt AI")
chat_input = gr.MultimodalTextbox(interactive=True, file_types=["image"], placeholder="Enter your question or upload an image.", show_label=False)

# Page layout: the HTML header followed by a multimodal chat wired to bot_streaming.
with gr.Blocks(fill_height=True) as demo:
    gr.Markdown(DESCRIPTION)
    gr.ChatInterface(
        fn=bot_streaming,
        chatbot=chatbot,
        fill_height=True,
        multimodal=True,
        textbox=chat_input,
    )

# Configure the request queue with its API closed, then launch without API docs
# or a public share link.
demo.queue(api_open=False)
demo.launch(show_api=False, share=False)