import spaces
import gradio as gr
import os
from transformers import AutoModelForCausalLM, AutoTokenizer, TextIteratorStreamer, BitsAndBytesConfig
from threading import Thread
import torch
from torch.nn.attention import SDPBackend, sdpa_kernel

# Optional HF token for gated/private repos; None is fine for public models.
HF_TOKEN = os.getenv("HF_TOKEN", None)
#REPO_ID = "deepseek-ai/DeepSeek-R1-Distill-Llama-8B"
REPO_ID = "nicoboss/DeepSeek-R1-Distill-Qwen-32B-Uncensored"

DESCRIPTION = f'''

{REPO_ID}

'''

PLACEHOLDER = f"""

{REPO_ID}

Ask me anything...

""" css = """ h1 { text-align: center; display: block; } #duplicate-button { margin: auto; color: white; background: #1565c0; border-radius: 100vh; } """ tokenizer = AutoTokenizer.from_pretrained(REPO_ID) if torch.cuda.is_available(): nf4_config = BitsAndBytesConfig(load_in_4bit=True, bnb_4bit_quant_type="nf4", bnb_4bit_use_double_quant=True, bnb_4bit_compute_dtype=torch.bfloat16) model = AutoModelForCausalLM.from_pretrained(REPO_ID, quantization_config=nf4_config) else: model = AutoModelForCausalLM.from_pretrained(REPO_ID, torch_dtype=torch.float32) streamer = TextIteratorStreamer(tokenizer, timeout=10.0, skip_prompt=True, skip_special_tokens=True) @spaces.GPU(duration=30) def chat(message: str, history: list[dict], temperature: float, max_new_tokens: int, top_p: float, top_k: int, repetition_penalty: float, sys_prompt: str, progress=gr.Progress(track_tqdm=True) ): try: messages = [] response = [] if not history: history = [] messages.append({"role": "system", "content": sys_prompt}) messages.append({"role": "user", "content": message}) input_tensors = tokenizer.apply_chat_template(history + messages, add_generation_prompt=True, return_dict=True, return_tensors="pt").to(model.device) input_ids = input_tensors["input_ids"] attention_mask = input_tensors["attention_mask"] generate_kwargs = dict( input_ids=input_ids, attention_mask=attention_mask, streamer=streamer, max_new_tokens=max_new_tokens, do_sample=True, temperature=temperature, top_k=top_k, top_p=top_p, repetition_penalty=repetition_penalty, pad_token_id=tokenizer.eos_token_id, ) if temperature == 0: generate_kwargs['do_sample'] = False response.append({"role": "assistant", "content": ""}) with sdpa_kernel(SDPBackend.FLASH_ATTENTION): t = Thread(target=model.generate, kwargs=generate_kwargs) t.start() for text in streamer: response[-1]["content"] += text yield response except Exception as e: print(e) gr.Warning(f"Error: {e}") yield response with gr.Blocks(fill_height=True, fill_width=True, css=css) as demo: gr.Markdown(DESCRIPTION) gr.ChatInterface( fn=chat, type="messages", chatbot=gr.Chatbot(height=450, type="messages", placeholder=PLACEHOLDER, label='Gradio ChatInterface'), fill_height=True, additional_inputs_accordion=gr.Accordion(label="⚙️ Parameters", open=False, render=False), additional_inputs=[ gr.Slider(minimum=0, maximum=1, step=0.1, value=0.7, label="Temperature", render=False), gr.Slider(minimum=128, maximum=4096, step=1, value=512, label="Max new tokens", render=False), gr.Slider(minimum=0.1, maximum=1.0, value=0.95, step=0.05, label="Top-p", render=False), gr.Slider(minimum=0, maximum=100, value=40, step=1, label="Top-k", render=False), gr.Slider(minimum=0.0, maximum=2.0, value=1.1, step=0.1, label="Repetition penalty", render=False), gr.Textbox(value="", label="System prompt", render=False), ], save_history=True, examples=[ ['How to setup a human base on Mars? Give short answer.'], ['Explain theory of relativity to me like I’m 8 years old.'], ['What is 9,000 * 9,000?'], ['Write a pun-filled happy birthday message to my friend Alex.'], ['Justify why a penguin might make a good king of the jungle.'] ], cache_examples=False) if __name__ == "__main__": demo.queue().launch(ssr_mode=False)