# HF Spaces ZeroGPU (A100); provides the @spaces.GPU decorator used below
import spaces

# TroL Package
import torch
from PIL import Image
from utils.utils import *  # provides output_filtering used below
import torch.nn.functional as F
from trol.load_trol import load_trol
from torchvision.transforms.functional import pil_to_tensor

# Gradio Package
import time
import gradio as gr
from threading import Thread
from accelerate import Accelerator
from transformers import TextIteratorStreamer

# Install flash attention at runtime (CUDA build skipped, per the usual HF Spaces recipe)
import subprocess
subprocess.run('pip install flash-attn --no-build-isolation', env={'FLASH_ATTENTION_SKIP_CUDA_BUILD': "TRUE"}, shell=True)
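# Note: a bare `env=` dict replaces the whole inherited environment for the
# child process; if pip cannot be found this way, merging with os.environ,
# e.g. env={**os.environ, 'FLASH_ATTENTION_SKIP_CUDA_BUILD': 'TRUE'}, is a
# more robust variant.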

# accelerate: provides the target device for generation
accel = Accelerator()

# model selection
link = "TroL-7B" # [Select One] 'TroL-1.8B' | 'TroL-3.8B' | 'TroL-7B'

# Example user inputs (consumed by the offline sketch below, not by the Gradio app)
prompt_type = "with_image"  # [Select One] 'text_only' | 'with_image'
img_path = 'figures/demo.png'
question = "What is the troll doing? Describe the details in the image and imagine what event might be happening."

# loading model
model, tokenizer = load_trol(link=link)

# move any parameters still on the CPU to the GPU
for param in model.parameters():
    if not param.is_cuda:
        param.data = param.to('cuda:0')

def threading_function(inputs, image_token_number, streamer, device, temperature, new_max_token, top_p):
    # preprocess the raw inputs into generation-ready kwargs (propagation step)
    _inputs = model.eval_process(inputs=inputs,
                                 data='demo',
                                 tokenizer=tokenizer,
                                 device=device,
                                 img_token_number=image_token_number)
    generation_kwargs = _inputs
    generation_kwargs.update({'streamer': streamer,
                              'do_sample': True,
                              'max_new_tokens': new_max_token,
                              'top_p': top_p,
                              'temperature': temperature,
                              'use_cache': True})
    return model.generate(**generation_kwargs)
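
# Illustrative, optional helper: a one-off, non-streaming generation driven by
# the example inputs defined above, for use without the Gradio UI. It assumes
# model.generate returns token ids decodable with tokenizer.batch_decode when
# no streamer is passed; treat it as a sketch rather than part of the app.
def run_offline_demo():
    if prompt_type == "with_image":
        image = pil_to_tensor(Image.open(img_path).convert("RGB"))
        if "3.8B" not in link:
            # non-3.8B variants use 490x490 inputs and 1225 image tokens,
            # mirroring bot_streaming below
            image = F.interpolate(image.unsqueeze(0), size=(490, 490), mode='bicubic').squeeze(0)
            image_token_number = 1225
        else:
            image_token_number = None
        inputs = [{'image': image, 'question': question}]
    else:
        image_token_number = None
        inputs = [{'question': question}]
    with torch.inference_mode():
        _inputs = model.eval_process(inputs=inputs, data='demo', tokenizer=tokenizer,
                                     device=accel.device, img_token_number=image_token_number)
        generate_ids = model.generate(**_inputs, do_sample=False, max_new_tokens=256, use_cache=True)
    return output_filtering(tokenizer.batch_decode(generate_ids, skip_special_tokens=True)[0], model)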

@spaces.GPU
def bot_streaming(message, history, temperature, new_max_token, top_p):
    # the model variant is fixed by the module-level `link` selected above;
    # ChatInterface supplies message, history, and the three slider values
    try:
        # prompt type -> input prompt
        image_token_number = None
        if len(message['files']) != 0:
            # load the uploaded image as an RGB tensor
            image = pil_to_tensor(Image.open(message['files'][0]).convert("RGB"))
            if "3.8B" not in link:
                image_token_number = 1225
                image = F.interpolate(image.unsqueeze(0), size=(490, 490), mode='bicubic').squeeze(0)
            inputs = [{'image': image, 'question': message['text']}]

        else:
            inputs = [{'question': message['text']}]

        # Text Generation
        with torch.inference_mode():
            # streamer yields decoded text incrementally as tokens are generated
            streamer = TextIteratorStreamer(tokenizer, skip_special_tokens=True)

            # run generation in a background thread so the streamer can be
            # consumed here while tokens are still being produced
            thread = Thread(target=threading_function, kwargs=dict(inputs=inputs,
                                                                   image_token_number=image_token_number,
                                                                   streamer=streamer,
                                                                   device=accel.device,
                                                                   temperature=temperature,
                                                                   new_max_token=new_max_token,
                                                                   top_p=top_p))
            thread.start()

            # accumulate the streamed text
            generated_text = ""
            for new_text in streamer:
                generated_text += new_text
            thread.join()

        # post-process the raw generation into the final response
        response = output_filtering(generated_text, model)

    except Exception:
        response = "Unsupported input format (e.g., PDF, video, audio): this version supports a single image only."

    # server-side log of the incoming request
    text = message['text']
    files = message['files']
    print(f'Text: {text}')
    print(f'MM Files: {files}')


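    # stream the response back character by character (typewriter effect)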
    buffer = ""
    for character in response:
        buffer += character
        time.sleep(0.015)
        yield buffer

demo = gr.ChatInterface(fn=bot_streaming,
                        additional_inputs=[gr.Slider(0, 1, 0.9, label="temperature"),
                                           gr.Slider(1, 1024, 128, label="new_max_token"),
                                           gr.Slider(0, 1, 0.95, label="top_p")],
                        additional_inputs_accordion="Generation Hyperparameters",
                        theme=gr.themes.Soft(),
                        title="TroL",
                        description="TroL is a family of efficient 1.8B, 3.8B, and 7B Large Language and Vision Models built on a new propagation strategy.\n"
                                    "Inference speed depends heavily on being assigned an idle GPU; if every GPU is busy, a request may wait indefinitely.",
                        stop_btn="Stop Generation", multimodal=True)
demo.launch()
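# Note: on HF Spaces, launch() serves the app directly; for a local run,
# demo.launch(share=True) would additionally create a public share link.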