import spaces
import gradio as gr
import os
from transformers import AutoModelForCausalLM, AutoTokenizer, TextIteratorStreamer, BitsAndBytesConfig
from threading import Thread
import torch
from torch.nn.attention import SDPBackend, sdpa_kernel
# Hugging Face access token taken from the environment (None when unset).
# NOTE(review): not passed to the from_pretrained calls visible in this file;
# huggingface_hub may read the HF_TOKEN env var by itself — confirm before removing.
HF_TOKEN = os.environ.get("HF_TOKEN")

# Model served by the app. A smaller alternative used previously:
# "deepseek-ai/DeepSeek-R1-Distill-Llama-8B".
REPO_ID = "nicoboss/DeepSeek-R1-Distill-Qwen-32B-Uncensored"

# Markdown header rendered above the chat: the repo id on its own line.
DESCRIPTION = "\n" + REPO_ID + "\n"

# Text shown inside the empty chatbot before the first message.
PLACEHOLDER = f"\n{REPO_ID}\nAsk me anything...\n"

# Page CSS: centre <h1> headings and style the element with id "duplicate-button".
css = (
    "\n"
    "h1 {\n"
    "text-align: center;\n"
    "display: block;\n"
    "}\n"
    "#duplicate-button {\n"
    "margin: auto;\n"
    "color: white;\n"
    "background: #1565c0;\n"
    "border-radius: 100vh;\n"
    "}\n"
)
# Tokenizer and model are loaded once, at import time, and reused by every request.
tokenizer = AutoTokenizer.from_pretrained(REPO_ID)

if not torch.cuda.is_available():
    # CPU fallback: unquantized float32 weights.
    model = AutoModelForCausalLM.from_pretrained(REPO_ID, torch_dtype=torch.float32)
else:
    # CUDA: 4-bit NF4 weights via bitsandbytes, with double quantization and
    # bfloat16 compute.
    nf4_config = BitsAndBytesConfig(
        load_in_4bit=True,
        bnb_4bit_quant_type="nf4",
        bnb_4bit_use_double_quant=True,
        bnb_4bit_compute_dtype=torch.bfloat16,
    )
    model = AutoModelForCausalLM.from_pretrained(REPO_ID, quantization_config=nf4_config)

# Streams decoded text (prompt and special tokens skipped), waiting at most
# 10 s for each new chunk.
# NOTE(review): this is a single module-level instance, so every caller that uses
# this global shares it; concurrent or abandoned generations would interleave.
streamer = TextIteratorStreamer(
    tokenizer,
    timeout=10.0,
    skip_prompt=True,
    skip_special_tokens=True,
)
def _build_messages(message: str, history: list[dict], sys_prompt: str) -> list[dict]:
    """Return the conversation to render with the tokenizer's chat template.

    Order: the system prompt (only when non-blank), then the prior turns, then
    the new user message. Chat templates conventionally expect the system turn
    first; a blank system prompt is dropped so the template's default applies.

    Only ``role`` and ``content`` are copied from each history entry, so any
    extra keys the UI attaches (presumably e.g. Gradio ``metadata``) never reach
    the template.
    """
    conversation = []
    if sys_prompt and sys_prompt.strip():
        conversation.append({"role": "system", "content": sys_prompt})
    for turn in history or []:
        conversation.append({"role": turn["role"], "content": turn["content"]})
    conversation.append({"role": "user", "content": message})
    return conversation


@spaces.GPU(duration=30)
def chat(message: str,
         history: list[dict],
         temperature: float,
         max_new_tokens: int,
         top_p: float,
         top_k: int,
         repetition_penalty: float,
         sys_prompt: str,
         progress=gr.Progress(track_tqdm=True)
         ):
    """Stream the model's reply to ``message`` in Gradio "messages" format.

    Args:
        message: The new user message.
        history: Prior turns as role/content dicts (may be empty or None).
        temperature: Sampling temperature; 0 (or below) selects greedy decoding.
        max_new_tokens: Maximum number of tokens to generate.
        top_p: Nucleus-sampling threshold (used only when sampling).
        top_k: Top-k cutoff (used only when sampling).
        repetition_penalty: Forwarded to ``model.generate``.
        sys_prompt: System prompt; left out of the conversation when blank.
        progress: Gradio progress tracker with ``track_tqdm=True``; not used directly.

    Yields:
        A list with one assistant message whose ``content`` grows as text is
        streamed. On any error, the error is printed, shown as a Gradio warning,
        and the current (possibly empty) list is yielded once more.
    """
    response = []
    try:
        conversation = _build_messages(message, history, sys_prompt)
        input_tensors = tokenizer.apply_chat_template(
            conversation, add_generation_prompt=True, return_dict=True, return_tensors="pt"
        ).to(model.device)

        # One streamer per request. A shared module-level streamer would receive
        # text from an abandoned or timed-out generation (its thread keeps running)
        # and leak it into the next reply.
        streamer = TextIteratorStreamer(tokenizer, timeout=10.0, skip_prompt=True, skip_special_tokens=True)

        # UI sliders may deliver integral values as int. transformers' logits
        # processors validate types, e.g. repetition_penalty must be a float and
        # top_k an int, so cast explicitly.
        generate_kwargs = dict(
            input_ids=input_tensors["input_ids"],
            attention_mask=input_tensors["attention_mask"],
            streamer=streamer,
            max_new_tokens=int(max_new_tokens),
            repetition_penalty=float(repetition_penalty),
            pad_token_id=tokenizer.eos_token_id,
        )
        if temperature > 0:
            generate_kwargs.update(
                do_sample=True,
                temperature=float(temperature),
                top_k=int(top_k),
                top_p=float(top_p),
            )
        else:
            # Greedy decoding: omit the sampling-only knobs entirely.
            generate_kwargs["do_sample"] = False

        # generate() runs in a worker thread. Record its exception and close the
        # stream so the loop below stops at once, instead of waiting for the
        # streamer timeout and then reporting an empty error.
        errors = []

        def _generate():
            try:
                model.generate(**generate_kwargs)
            except Exception as exc:
                errors.append(exc)
                streamer.end()

        response.append({"role": "assistant", "content": ""})
        with sdpa_kernel(SDPBackend.FLASH_ATTENTION):
            worker = Thread(target=_generate)
            worker.start()
            for text in streamer:
                response[-1]["content"] += text
                yield response
            worker.join()
        if errors:
            raise errors[0]
    except Exception as e:
        print(e)
        gr.Warning(f"Error: {e}")
        yield response
# Gradio UI: a Markdown header above a streaming ChatInterface backed by `chat`.
with gr.Blocks(fill_height=True, fill_width=True, css=css) as demo:
    gr.Markdown(DESCRIPTION)
    gr.ChatInterface(
        fn=chat,
        type="messages",
        chatbot=gr.Chatbot(height=450, type="messages", placeholder=PLACEHOLDER, label='Gradio ChatInterface'),
        fill_height=True,
        additional_inputs_accordion=gr.Accordion(label="⚙️ Parameters", open=False, render=False),
        # Passed to chat() after (message, history), in this exact order:
        # temperature, max_new_tokens, top_p, top_k, repetition_penalty, sys_prompt.
        additional_inputs=[
            # 0 makes chat() use greedy decoding.
            gr.Slider(minimum=0, maximum=1, step=0.1, value=0.7, label="Temperature", render=False),
            gr.Slider(minimum=128, maximum=4096, step=1, value=512, label="Max new tokens", render=False),
            gr.Slider(minimum=0.1, maximum=1.0, value=0.95, step=0.05, label="Top-p", render=False),
            gr.Slider(minimum=0, maximum=100, value=40, step=1, label="Top-k", render=False),
            # transformers rejects a non-positive repetition penalty, so the floor is
            # 0.1 rather than 0.0 (1.0 means no penalty).
            gr.Slider(minimum=0.1, maximum=2.0, value=1.1, step=0.1, label="Repetition penalty", render=False),
            gr.Textbox(value="", label="System prompt", render=False),
        ],
        save_history=True,
        examples=[
            ['How to setup a human base on Mars? Give short answer.'],
            ['Explain theory of relativity to me like I’m 8 years old.'],
            ['What is 9,000 * 9,000?'],
            ['Write a pun-filled happy birthday message to my friend Alex.'],
            ['Justify why a penguin might make a good king of the jungle.']
        ],
        cache_examples=False)
if __name__ == "__main__":
    # Enable Gradio's request queue, then start the server with SSR turned off.
    queued_demo = demo.queue()
    queued_demo.launch(ssr_mode=False)