import os
from datetime import date
from threading import Thread

import gradio as gr
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer, TextIteratorStreamer
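
# Gradio playground for tiiuae/Falcon-E-3B-Instruct: loads the model in bfloat16,
# streams generations token by token with TextIteratorStreamer, and exposes the
# sampling parameters through a ChatInterface.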
MODEL = "tiiuae/Falcon-E-3B-Instruct"

today = date.today()

TITLE = "<h1><center>Falcon-E-3B-Instruct playground</center></h1>"
SUB_TITLE = """<center>This interface was created for quick validation purposes; do not use it in production.</center>"""
SUB_SUB_TITLE = "<h2><center>Also try out <a href='https://chat.falconllm.tii.ae/'>our demo</a>, powered by <a href='https://www.openinnovation.ai/'>OpenInnovation AI</a>, which is based on the bfloat16 variant of the model.</center></h2>"
SYSTEM_PROMPT = f"""You are Falcon-Edge, a Language Model (LLM) with weights in ternary format (leveraging the BitNet architecture), created by the Technology Innovation Institute (TII), a leading global research institution based in Abu Dhabi, UAE. The current date is {today}.
When you're not sure about some information, you say that you don't have the information and don't make up anything.
If the user's question is not clear, ambiguous, or does not provide enough context for you to accurately answer the question, you do not try to answer it right away and you rather ask the user to clarify their request (e.g. "What are some good restaurants around me?" => "Where are you?" or "When is the next flight to Tokyo?" => "Where do you travel from?").
You follow these instructions in all languages, and always respond to the user in the language they use or request."""
CSS = """
.duplicate-button {
    margin: auto !important;
    color: white !important;
    background: black !important;
    border-radius: 100vh !important;
}
h3 {
    text-align: center;
}
"""
END_MESSAGE = """
\n
**The conversation has reached its end, please press "Clear" to start a new conversation**
"""
# Use the GPU when available, otherwise fall back to CPU.
device = "cuda" if torch.cuda.is_available() else "cpu"

tokenizer = AutoTokenizer.from_pretrained(MODEL)
model = AutoModelForCausalLM.from_pretrained(
    MODEL,
    torch_dtype=torch.bfloat16,
).to(device)
# torch.compile speeds up repeated generation at the cost of a one-time compilation delay on the first call.
model = torch.compile(model)
def stream_chat(
    message: str,
    history: list,
    temperature: float = 0.3,
    max_new_tokens: int = 128,
    top_p: float = 1.0,
    top_k: int = 20,
    penalty: float = 1.2,
):
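    """Stream an assistant reply for `message`, given the chat `history`.

    The history is expected as a list of (user, assistant) tuples, which is
    rebuilt into a chat-template conversation prefixed with SYSTEM_PROMPT.
    The sampling parameters mirror the sliders exposed in the Gradio UI.
    """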
    print(f'message: {message}')
    print(f'history: {history}')

    conversation = [{"role": "system", "content": SYSTEM_PROMPT}]
    for prompt, answer in history:
        conversation.extend([
            {"role": "user", "content": prompt},
            {"role": "assistant", "content": answer},
        ])
    conversation.append({"role": "user", "content": message})

    input_text = tokenizer.apply_chat_template(conversation, tokenize=False, add_generation_prompt=True)
    inputs = tokenizer.encode(input_text, return_tensors="pt").to(device)
    streamer = TextIteratorStreamer(tokenizer, timeout=60.0, skip_prompt=True, skip_special_tokens=True)
    generate_kwargs = dict(
        input_ids=inputs,
        max_new_tokens=max_new_tokens,
        do_sample=temperature > 0,  # greedy decoding when temperature is 0
        top_p=top_p,
        top_k=top_k,
        temperature=temperature,
        repetition_penalty=penalty,  # value from the "Repetition penalty" slider
        streamer=streamer,
        pad_token_id=10,  # hardcoded pad token id
    )
    # Run generation in a background thread so tokens can be streamed as they are produced;
    # model.generate already disables gradient tracking internally.
    thread = Thread(target=model.generate, kwargs=generate_kwargs)
    thread.start()

    buffer = ""
    for new_text in streamer:
        buffer += new_text
        yield buffer

    print(f'response: {buffer}')

chatbot = gr.Chatbot(height=600)

with gr.Blocks(css=CSS, theme="soft") as demo:
    gr.HTML(TITLE)
    gr.HTML(SUB_TITLE)
    gr.HTML(SUB_SUB_TITLE)
    gr.DuplicateButton(value="Duplicate Space for private use", elem_classes="duplicate-button")
    gr.ChatInterface(
        fn=stream_chat,
        chatbot=chatbot,
        fill_height=True,
        additional_inputs_accordion=gr.Accordion(label="⚙️ Parameters", open=False, render=False),
        additional_inputs=[
            gr.Slider(
                minimum=0,
                maximum=1,
                step=0.1,
                value=0.3,
                label="Temperature",
                render=False,
            ),
            gr.Slider(
                minimum=128,
                maximum=4096,
                step=1,
                value=1024,
                label="Max new tokens",
                render=False,
            ),
            gr.Slider(
                minimum=0.0,
                maximum=1.0,
                step=0.1,
                value=1.0,
                label="top_p",
                render=False,
            ),
            gr.Slider(
                minimum=1,
                maximum=20,
                step=1,
                value=20,
                label="top_k",
                render=False,
            ),
            gr.Slider(
                minimum=0.0,
                maximum=2.0,
                step=0.1,
                value=1.2,
                label="Repetition penalty",
                render=False,
            ),
        ],
        cache_examples=False,
    )
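
# Note: this assumes a recent Gradio release where the queue (needed for streaming
# generator outputs) is enabled by default; on older versions, call demo.queue()
# before launch().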
if __name__ == "__main__":
    demo.launch()