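"""Gradio playground app for tiiuae/Falcon-E-3B-Instruct.

Loads the model with transformers, compiles it with torch.compile, and serves a
streaming chat interface (gr.ChatInterface) with adjustable sampling parameters.
"""
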
import torch
import os

from transformers import AutoModelForCausalLM, AutoTokenizer, TextIteratorStreamer
import gradio as gr
from datetime import date
from threading import Thread

MODEL = "tiiuae/Falcon-E-3B-Instruct"

today = date.today()

TITLE = "<h1><center>Falcon-E-3B-Instruct playground</center></h1>"
SUB_TITLE = """<center>This interface has been created for quick validation purposes, do not use it for production.</center>"""
SUB_SUB_TITLE = "<h2><center>Try out also <a href='https://chat.falconllm.tii.ae/'>our demo</a> powered by <a href='https://www.openinnovation.ai/'>OpenInnovation AI</a> based on the bfloat16 variant of the model </center></h2>"


SYSTEM_PROMPT = f"""You are Falcon-Edge, a Language Model (LLM) with weights ternary format (leveraging Bitnet architecture) created by the Technology Innovation Institute (TII) a global leading research institution based in Abu Dhabi, UAE. The current date is {today}.
When you're not sure about some information, you say that you don't have the information and don't make up anything.
If the user's question is not clear, ambiguous, or does not provide enough context for you to accurately answer the question, you do not try to answer it right away and you rather ask the user to clarify their request (e.g. "What are some good restaurants around me?" => "Where are you?" or "When is the next flight to Tokyo" => "Where do you travel from?").
You follow these instructions in all languages, and always respond to the user in the language they use or request."""

CSS = """
.duplicate-button {
    margin: auto !important;
    color: white !important;
    background: black !important;
    border-radius: 100vh !important;
}
h3 {
    text-align: center;
}
"""

END_MESSAGE = """
\n
**The conversation has reached its end, please press "Clear" to start a new conversation**
"""

device = "cuda" # for GPU usage or "cpu" for CPU usage

tokenizer = AutoTokenizer.from_pretrained(MODEL)
model = AutoModelForCausalLM.from_pretrained(
    MODEL,
    torch_dtype=torch.bfloat16
).to(device)

model = torch.compile(model)
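
# Streaming chat handler: generation runs in a background thread and the growing
# response string is yielded piece by piece via TextIteratorStreamer, which
# gr.ChatInterface renders incrementally.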

def stream_chat(
    message: str, 
    history: list, 
    temperature: float = 0.3, 
    max_new_tokens: int = 128, 
    top_p: float = 1.0, 
    top_k: int = 20, 
    penalty: float = 1.2,
):
    print(f'message: {message}')
    print(f'history: {history}')

    conversation = [{"role": "system", "content": SYSTEM_PROMPT}]
    for prompt, answer in history:
        conversation.extend([
            {"role": "user", "content": prompt}, 
            {"role": "assistant", "content": answer},
        ])


    conversation.append({"role": "user", "content": message})
    input_text = tokenizer.apply_chat_template(conversation, tokenize=False, add_generation_prompt=True)

    inputs = tokenizer.encode(input_text, return_tensors="pt").to(device)
    streamer = TextIteratorStreamer(tokenizer, timeout=60.0, skip_prompt=True, skip_special_tokens=True)
    
    generate_kwargs = dict(
        input_ids=inputs,
        max_new_tokens=max_new_tokens,
        do_sample=temperature > 0,  # greedy decoding when temperature is 0
        top_p=top_p,
        top_k=top_k,
        temperature=temperature,
        repetition_penalty=penalty,
        streamer=streamer,
        pad_token_id=tokenizer.eos_token_id,  # pad with EOS to avoid the missing pad-token warning
    )

    # model.generate already disables gradient tracking internally, so no explicit
    # torch.no_grad() is needed here; generation runs in a background thread so the
    # streamer can be consumed while tokens are being produced.
    thread = Thread(target=model.generate, kwargs=generate_kwargs)
    thread.start()
        
    buffer = ""
    for new_text in streamer:
        buffer += new_text
        yield buffer


    print(f'response: {buffer}')
            
chatbot = gr.Chatbot(height=600)
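
# UI assembly: a Blocks layout wraps gr.ChatInterface around the Chatbot defined
# above, exposing the sampling parameters as sliders in a collapsible accordion.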

with gr.Blocks(css=CSS, theme="soft") as demo:
    gr.HTML(TITLE)
    gr.HTML(SUB_TITLE)
    gr.HTML(SUB_SUB_TITLE)
    gr.DuplicateButton(value="Duplicate Space for private use", elem_classes="duplicate-button")
    gr.ChatInterface(
        fn=stream_chat,
        chatbot=chatbot,
        fill_height=True,
        additional_inputs_accordion=gr.Accordion(label="⚙️ Parameters", open=False, render=False),
        additional_inputs=[
            gr.Slider(
                minimum=0,
                maximum=1,
                step=0.1,
                value=0.3,
                label="Temperature",
                render=False,
            ),
            gr.Slider(
                minimum=128,
                maximum=4096,
                step=1,
                value=1024,
                label="Max new tokens",
                render=False,
            ),
            gr.Slider(
                minimum=0.0,
                maximum=1.0,
                step=0.1,
                value=1.0,
                label="top_p",
                render=False,
            ),
            gr.Slider(
                minimum=1,
                maximum=20,
                step=1,
                value=20,
                label="top_k",
                render=False,
            ),
            gr.Slider(
                minimum=1.0,
                maximum=2.0,
                step=0.1,
                value=1.2,
                label="Repetition penalty",
                render=False,
            ),
        ],
        cache_examples=False,
    )


if __name__ == "__main__":
    demo.launch()
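
# To run locally (assuming this file is saved as app.py, with torch, transformers
# and gradio installed):
#   python app.py
# Gradio then serves the app on http://127.0.0.1:7860 by default.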