# chatbot-zero / app.py
# Author: John6666 (commit f53e84c, "Upload 2 files", verified; 7.33 kB)
import spaces
import gradio as gr
import os
from transformers import AutoModelForCausalLM, AutoTokenizer, TextIteratorStreamer, BitsAndBytesConfig
from threading import Thread
import torch
import gc
def flush():
    """Free memory between requests.

    Runs a full Python garbage-collection pass, then asks PyTorch to release
    its cached-but-unused CUDA memory (a no-op when CUDA was never initialised).
    """
    for release in (gc.collect, torch.cuda.empty_cache):
        release()
# Let PyTorch use faster, reduced-precision algorithms (e.g. TF32) for float32 matmuls where the hardware supports them.
torch.set_float32_matmul_precision("high")
# Hugging Face access token read from the environment; None when the variable is unset.
HF_TOKEN = os.getenv("HF_TOKEN", None)
# Model repository to load and serve. The commented-out assignments are alternative checkpoints.
#REPO_ID = "deepseek-ai/DeepSeek-R1-Distill-Llama-8B"
REPO_ID = "nicoboss/DeepSeek-R1-Distill-Qwen-32B-Uncensored"
#REPO_ID = "deepseek-ai/DeepSeek-R1-Distill-Qwen-1.5B"
# HTML rendered via gr.Markdown at the top of the page: the repo id as a centered heading.
DESCRIPTION = f'''
<div>
<h1 style="text-align: center;">{REPO_ID}</h1>
</div>
'''
# HTML shown inside the empty chatbot (its `placeholder`) before the first message is sent.
PLACEHOLDER = f"""
<div style="padding: 30px; text-align: center; display: flex; flex-direction: column; align-items: center;">
<h1 style="font-size: 28px; margin-bottom: 2px; opacity: 0.55;">{REPO_ID}</h1>
<p style="font-size: 18px; margin-bottom: 2px; opacity: 0.65;">Ask me anything...</p>
</div>
"""
# Page CSS passed to gr.Blocks: centers every <h1> and styles an element with id
# "duplicate-button" (no element with that id is created in the code visible here).
css = """
h1 {
text-align: center;
display: block;
}
#duplicate-button {
margin: auto;
color: white;
background: #1565c0;
border-radius: 100vh;
}
"""
# Tokenizer and model are loaded once at import time and shared by every request.
# Passing `token=HF_TOKEN` lets gated/private repositories load when HF_TOKEN is set;
# when it is None, transformers falls back to its default credential lookup.
tokenizer = AutoTokenizer.from_pretrained(REPO_ID, token=HF_TOKEN)
if torch.cuda.is_available():
    # On GPU: 4-bit NF4 weights (with double quantization, bf16 compute) via bitsandbytes
    # to reduce memory; device_map="auto" lets accelerate place the layers.
    nf4_config = BitsAndBytesConfig(
        load_in_4bit=True,
        bnb_4bit_quant_type="nf4",
        bnb_4bit_use_double_quant=True,
        bnb_4bit_compute_dtype=torch.bfloat16,
    )
    model = AutoModelForCausalLM.from_pretrained(
        REPO_ID, device_map="auto", quantization_config=nf4_config, token=HF_TOKEN
    )
else:
    # CPU fallback: unquantized float32 weights.
    model = AutoModelForCausalLM.from_pretrained(REPO_ID, torch_dtype=torch.float32, token=HF_TOKEN)
# Module-level streamer kept for backward compatibility. A TextIteratorStreamer buffers
# one generation's text in a single queue, so it must not be shared by concurrent
# generations; prefer creating one per request.
streamer = TextIteratorStreamer(tokenizer, timeout=10.0, skip_prompt=True, skip_special_tokens=True)
@spaces.GPU(duration=59)
@torch.inference_mode()
def chat_stream(message: str,
                history: list[dict],
                temperature: float,
                max_new_tokens: int,
                top_p: float,
                top_k: int,
                repetition_penalty: float,
                sys_prompt: str,
                progress=gr.Progress(track_tqdm=True)
                ):
    """Stream the assistant's reply to ``message`` given the prior chat ``history``.

    Args:
        message: The new user message.
        history: Prior turns as dicts; entries without a "role" key are ignored,
            and only "role"/"content" are forwarded to the chat template.
        temperature: Sampling temperature; 0 switches to greedy decoding.
        max_new_tokens: Upper bound on generated tokens.
        top_p, top_k: Nucleus / top-k sampling limits (used only when sampling).
        repetition_penalty: Passed straight to ``model.generate``.
        sys_prompt: System prompt, placed first in the conversation.
        progress: Gradio progress tracker (unused, kept for the Gradio signature).

    Yields:
        A list holding one ``{"role": "assistant", "content": ...}`` dict whose
        content grows as text is decoded. On error the exception is printed,
        shown as a ``gr.Warning``, and the current (possibly empty) list is
        yielded once more.
    """
    response = []
    try:
        # System prompt first, then previous turns, then the new user message.
        prior_turns = [{"role": x["role"], "content": x["content"]} for x in (history or []) if "role" in x]
        conversation = [
            {"role": "system", "content": sys_prompt},
            *prior_turns,
            {"role": "user", "content": message},
        ]
        input_tensors = tokenizer.apply_chat_template(
            conversation,
            add_generation_prompt=True,
            return_dict=True,
            add_special_tokens=False,
            return_tensors="pt",
        ).to(model.device)
        # One streamer per request: a shared streamer would mix the tokens of
        # concurrent generations in the same queue.
        streamer = TextIteratorStreamer(tokenizer, timeout=10.0, skip_prompt=True, skip_special_tokens=True)
        generate_kwargs = dict(
            input_ids=input_tensors["input_ids"],
            attention_mask=input_tensors["attention_mask"],
            max_new_tokens=max_new_tokens,
            streamer=streamer,
            repetition_penalty=repetition_penalty,
            pad_token_id=tokenizer.eos_token_id,
        )
        if temperature == 0:
            # Greedy decoding; sampling-only parameters would just trigger warnings.
            generate_kwargs["do_sample"] = False
        else:
            generate_kwargs.update(do_sample=True, temperature=temperature, top_k=top_k, top_p=top_p)

        errors = []

        def _generate():
            # Run generation in the worker thread. On failure, record the error and
            # end the streamer so the consumer loop stops immediately instead of
            # waiting for the streamer timeout.
            try:
                model.generate(**generate_kwargs)
            except Exception as err:
                errors.append(err)
                streamer.end()

        response.append({"role": "assistant", "content": ""})
        worker = Thread(target=_generate)
        worker.start()
        for text in streamer:
            response[-1]["content"] += text
            yield response
        worker.join()
        if errors:
            # Surface the worker's exception through the handler below.
            raise errors[0]
    except Exception as e:
        print(e)
        gr.Warning(f"Error: {e}")
        yield response
    finally:
        flush()
@spaces.GPU(duration=59)
@torch.inference_mode()
def chat(message: str,
         history: list[dict],
         temperature: float,
         max_new_tokens: int,
         top_p: float,
         top_k: int,
         repetition_penalty: float,
         sys_prompt: str,
         progress=gr.Progress(track_tqdm=True)
         ):
    """Generate the complete assistant reply to ``message`` in one call (no streaming).

    Args:
        message: The new user message.
        history: Prior turns as dicts; entries without a "role" key are ignored,
            and only "role"/"content" are forwarded to the chat template.
        temperature: Sampling temperature; 0 switches to greedy decoding.
        max_new_tokens: Upper bound on generated tokens.
        top_p, top_k: Nucleus / top-k sampling limits (used only when sampling).
        repetition_penalty: Passed straight to ``model.generate``.
        sys_prompt: System prompt, placed first in the conversation.
        progress: Gradio progress tracker (unused, kept for the Gradio signature).

    Returns:
        A list holding one ``{"role": "assistant", "content": <reply>}`` dict. On
        error the exception is printed and shown as a ``gr.Warning``, and the
        list built so far (possibly empty) is returned.
    """
    response = []
    try:
        # System prompt first, then previous turns, then the new user message.
        prior_turns = [{"role": x["role"], "content": x["content"]} for x in (history or []) if "role" in x]
        conversation = [
            {"role": "system", "content": sys_prompt},
            *prior_turns,
            {"role": "user", "content": message},
        ]
        input_tensors = tokenizer.apply_chat_template(
            conversation,
            add_generation_prompt=True,
            return_dict=True,
            add_special_tokens=False,
            return_tensors="pt",
        ).to(model.device)
        input_ids = input_tensors["input_ids"]
        generate_kwargs = dict(
            input_ids=input_ids,
            attention_mask=input_tensors["attention_mask"],
            max_new_tokens=max_new_tokens,
            repetition_penalty=repetition_penalty,
            pad_token_id=tokenizer.eos_token_id,
        )
        if temperature == 0:
            # Greedy decoding; sampling-only parameters would just trigger warnings.
            generate_kwargs["do_sample"] = False
        else:
            generate_kwargs.update(do_sample=True, temperature=temperature, top_k=top_k, top_p=top_p)
        response.append({"role": "assistant", "content": ""})
        output_ids = model.generate(**generate_kwargs)
        # generate() returns prompt + continuation; decode only the new tokens.
        response[-1]["content"] = tokenizer.decode(output_ids.tolist()[0][input_ids.size(1):], skip_special_tokens=True)
        return response
    except Exception as e:
        print(e)
        gr.Warning(f"Error: {e}")
        return response
    finally:
        flush()
# Gradio UI: a streaming chat interface (chat_stream) with generation parameters
# in a collapsed accordion. Chat history is saved by ChatInterface (save_history=True).
with gr.Blocks(fill_height=True, fill_width=True, css=css) as demo:
    gr.Markdown(DESCRIPTION)
    gr.ChatInterface(
        fn=chat_stream,
        type="messages",
        chatbot=gr.Chatbot(height=450, type="messages", placeholder=PLACEHOLDER, label='Gradio ChatInterface'),
        fill_height=True,
        additional_inputs_accordion=gr.Accordion(label="⚙️ Parameters", open=False, render=False),
        # Order must match chat_stream's parameters after (message, history).
        additional_inputs=[
            # 0 selects greedy decoding in chat_stream.
            gr.Slider(minimum=0, maximum=1, step=0.1, value=0.7, label="Temperature", render=False),
            gr.Slider(minimum=128, maximum=4096, step=1, value=512, label="Max new tokens", render=False),
            gr.Slider(minimum=0.1, maximum=1.0, value=0.95, step=0.05, label="Top-p", render=False),
            gr.Slider(minimum=0, maximum=100, value=40, step=1, label="Top-k", render=False),
            # transformers requires repetition_penalty to be strictly positive (0 raises
            # ValueError), so the slider starts at 0.1 rather than 0.0.
            gr.Slider(minimum=0.1, maximum=2.0, value=1.1, step=0.1, label="Repetition penalty", render=False),
            gr.Textbox(value="", label="System prompt", render=False),
        ],
        save_history=True,
        examples=[
            ['How to setup a human base on Mars? Give short answer.'],
            ['Explain theory of relativity to me like I’m 8 years old.'],
            ['What is 9,000 * 9,000?'],
            ['Write a pun-filled happy birthday message to my friend Alex.'],
            ['Justify why a penguin might make a good king of the jungle.'],
        ],
        cache_examples=False,
    )

if __name__ == "__main__":
    demo.queue().launch(ssr_mode=False)