import torch
import gradio as gr
from transformers import TextIteratorStreamer, AutoProcessor, LlavaForConditionalGeneration
from PIL import Image
from threading import Thread
import spaces
import accelerate
import time

DESCRIPTION = '''
<div>
<h1 style="text-align: center;">Krypton πŸ•‹</h1>
<p>This uses an Open Source model from <a href="https://huggingface.co/xtuner/llava-llama-3-8b-v1_1-transformers"><b>xtuner/llava-llama-3-8b-v1_1-transformers</b></a></p>
</div>
'''

model_id = "xtuner/llava-llama-3-8b-v1_1-transformers"
model = LlavaForConditionalGeneration.from_pretrained(
    model_id,
    torch_dtype=torch.float16,
    low_cpu_mem_usage=True
)

model.to('cuda')

processor = AutoProcessor.from_pretrained(model_id)

# Ensure generation stops at the tokenizer's end-of-sequence token
model.generation_config.eos_token_id = processor.tokenizer.eos_token_id

@spaces.GPU
def bot_streaming(message, history):
    print(message)
    if message["files"]:
        # message["files"][-1] is a Dict or just a string
        if type(message["files"][-1]) == dict:
            image = message["files"][-1]["path"]
        else:
            image = message["files"][-1]
    else:
        # if there's no image uploaded for this turn, look for images in the past turns
        # kept inside tuples, take the last one
        for hist in history:
            if type(hist[0]) == tuple:
                image = hist[0][0]
    try:
        if image is None:
            # Handle the case where image is None
            gr.Error("You need to upload an image for LLaVA to work.")
    except NameError:
        # Handle the case where 'image' is not defined at all
        gr.Error("You need to upload an image for LLaVA to work.")

    prompt = f"<|start_header_id|>user<|end_header_id|>\n\n<image>\n{message['text']}<|eot_id|><|start_header_id|>assistant<|end_header_id|>\n\n"
    # print(f"prompt: {prompt}")
    image = Image.open(image)
    inputs = processor(prompt, image, return_tensors='pt').to(0, torch.float16)

    streamer = TextIteratorStreamer(processor, skip_special_tokens=False, skip_prompt=True)
    generation_kwargs = dict(inputs, streamer=streamer, max_new_tokens=1024, do_sample=False)

    # Run generate() on a background thread so this generator can keep consuming the
    # streamer and yield partial text back to the Gradio chat as tokens arrive.
    thread = Thread(target=model.generate, kwargs=generation_kwargs)
    thread.start()

    text_prompt = f"<|start_header_id|>user<|end_header_id|>\n\n{message['text']}<|eot_id|><|start_header_id|>assistant<|end_header_id|>\n\n"
    # print(f"text_prompt: {text_prompt}")

    buffer = ""
    time.sleep(0.5)
    for new_text in streamer:
        # find <|eot_id|> and remove it from the new_text
        if "<|eot_id|>" in new_text:
            new_text = new_text.split("<|eot_id|>")[0]
        buffer += new_text

        time.sleep(0.06)
        yield buffer


chatbot = gr.Chatbot(height=600, label="Krypt AI")
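# MultimodalTextbox lets a single message carry both typed text and an attached image file.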
chat_input = gr.MultimodalTextbox(interactive=True, file_types=["image"], placeholder="Enter your question or upload an image.", show_label=False)

with gr.Blocks(fill_height=True) as demo:
    gr.Markdown(DESCRIPTION)
    gr.ChatInterface(
        fn=bot_streaming,
        chatbot=chatbot,
        fill_height=True,
        multimodal=True,
        textbox=chat_input,
    )

demo.queue(api_open=False)
demo.launch(show_api=False, share=False)
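# To run this demo locally (a sketch; exact package versions are assumptions):
#   pip install torch transformers accelerate gradio spaces pillow
#   python app.py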