Spaces:
Runtime error
Runtime error
File size: 4,790 Bytes
04c9edc 77c7cf9 04c9edc 8c30c71 04c9edc 77c7cf9 04c9edc 77c7cf9 04c9edc 77c7cf9 5d07c9b 77c7cf9 8c30c71 77c7cf9 04c9edc 8c30c71 04c9edc 8c30c71 04c9edc 77c7cf9 8c30c71 04c9edc 77c7cf9 8c30c71 77c7cf9 04c9edc 77c7cf9 04c9edc 77c7cf9 04c9edc 77c7cf9 04c9edc 77c7cf9 04c9edc 77c7cf9 04c9edc 77c7cf9 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 |
import spaces
import gradio as gr
import os
from transformers import AutoModelForCausalLM, AutoTokenizer, TextIteratorStreamer, BitsAndBytesConfig
from threading import Thread
import torch
from torch.nn.attention import SDPBackend, sdpa_kernel
# Hugging Face access token read from the environment; None when the variable is unset.
HF_TOKEN = os.environ.get("HF_TOKEN")

# Model checkpoint served by this app.
# Alternative checkpoint (previously commented out here): "deepseek-ai/DeepSeek-R1-Distill-Llama-8B"
REPO_ID = "nicoboss/DeepSeek-R1-Distill-Qwen-32B-Uncensored"

# HTML header rendered above the chat (shows the checkpoint name).
DESCRIPTION = (
    "\n"
    "<div>\n"
    f'<h1 style="text-align: center;">{REPO_ID}</h1>\n'
    "</div>\n"
)

# HTML shown inside the empty chatbot before the first message.
PLACEHOLDER = (
    "\n"
    '<div style="padding: 30px; text-align: center; display: flex; flex-direction: column; align-items: center;">\n'
    f'<h1 style="font-size: 28px; margin-bottom: 2px; opacity: 0.55;">{REPO_ID}</h1>\n'
    '<p style="font-size: 18px; margin-bottom: 2px; opacity: 0.65;">Ask me anything...</p>\n'
    "</div>\n"
)

# Custom CSS passed to gr.Blocks: centred headings and a styled #duplicate-button.
css = """
h1 {
text-align: center;
display: block;
}
#duplicate-button {
margin: auto;
color: white;
background: #1565c0;
border-radius: 100vh;
}
"""
# Load the tokenizer and model once, at import time.
tokenizer = AutoTokenizer.from_pretrained(REPO_ID)
if torch.cuda.is_available():
    # 4-bit NF4 weights with double quantization; the 4-bit matmuls compute in bfloat16.
    nf4_config = BitsAndBytesConfig(
        load_in_4bit=True,
        bnb_4bit_quant_type="nf4",
        bnb_4bit_use_double_quant=True,
        bnb_4bit_compute_dtype=torch.bfloat16,
    )
    # Load the non-quantized modules (e.g. norms, embeddings) in the same dtype as the
    # 4-bit compute dtype. Without an explicit torch_dtype the bitsandbytes integration
    # falls back to float16, mixing fp16 activations with bf16 compute.
    model = AutoModelForCausalLM.from_pretrained(
        REPO_ID, quantization_config=nf4_config, torch_dtype=torch.bfloat16
    )
else:
    # CPU fallback: unquantized float32 weights (needs a lot of RAM for a model this size).
    model = AutoModelForCausalLM.from_pretrained(REPO_ID, torch_dtype=torch.float32)

# Module-level streamer shared by anything that uses it (kept for compatibility).
# timeout=10.0: seconds to wait for the next chunk before the iterator raises.
streamer = TextIteratorStreamer(tokenizer, timeout=10.0, skip_prompt=True, skip_special_tokens=True)
def _build_messages(history, message, sys_prompt):
    """Return the message list for the chat template: [system?] + history + new user turn.

    The system prompt goes first, where chat templates expect it, and is omitted when it
    is blank. Only ``role``/``content`` are forwarded from history entries (entries may
    carry extra UI keys), and entries whose content is not plain text are skipped.
    """
    messages = []
    if sys_prompt and sys_prompt.strip():
        messages.append({"role": "system", "content": sys_prompt})
    for turn in history or []:
        content = turn.get("content")
        if isinstance(content, str):
            messages.append({"role": turn["role"], "content": content})
    messages.append({"role": "user", "content": message})
    return messages


def _generation_kwargs(inputs, streamer, temperature, max_new_tokens, top_p, top_k, repetition_penalty):
    """Build the keyword arguments for ``model.generate`` from the UI settings."""
    kwargs = dict(
        input_ids=inputs["input_ids"],
        attention_mask=inputs["attention_mask"],
        streamer=streamer,
        max_new_tokens=int(max_new_tokens),
        # transformers rejects a non-positive repetition penalty; 1.0 means "no penalty".
        repetition_penalty=repetition_penalty if repetition_penalty > 0 else 1.0,
        pad_token_id=tokenizer.eos_token_id,
    )
    if temperature > 0:
        kwargs.update(do_sample=True, temperature=temperature, top_k=int(top_k), top_p=top_p)
    else:
        # Greedy decoding: sampling parameters are unused, so don't pass them.
        kwargs["do_sample"] = False
    return kwargs


def _start_generation(generate_kwargs):
    """Run ``model.generate`` on a worker thread and return ``(thread, errors)``.

    If generation raises, the exception is appended to ``errors`` and the streamer is
    ended so the consumer loop stops right away instead of waiting for its timeout.
    """
    errors = []
    streamer = generate_kwargs["streamer"]

    def _run():
        try:
            model.generate(**generate_kwargs)
        except Exception as exc:  # reported back to the caller through ``errors``
            errors.append(exc)
            streamer.end()

    thread = Thread(target=_run)
    thread.start()
    return thread, errors


@spaces.GPU(duration=30)
def chat(message: str,
         history: list[dict],
         temperature: float,
         max_new_tokens: int,
         top_p: float,
         top_k: int,
         repetition_penalty: float,
         sys_prompt: str,
         progress=gr.Progress(track_tqdm=True)
         ):
    """Stream the assistant's reply to ``message`` given the prior ``history``.

    Args:
        message: The new user message.
        history: Prior turns as ``{"role", "content"}`` dicts (may be empty/None).
        temperature: Sampling temperature; 0 selects greedy decoding.
        max_new_tokens: Upper bound on generated tokens.
        top_p: Nucleus-sampling probability mass.
        top_k: Top-k sampling cutoff.
        repetition_penalty: Repetition penalty; non-positive values are treated as 1.0.
        sys_prompt: Optional system prompt; ignored when blank.
        progress: Gradio progress tracker (tracks tqdm bars).

    Yields:
        A list holding one assistant message whose content grows as text arrives. On
        error, a Gradio warning is shown and the partial response is yielded once more.
    """
    response = []
    try:
        messages = _build_messages(history, message, sys_prompt)
        inputs = tokenizer.apply_chat_template(
            messages, add_generation_prompt=True, return_dict=True, return_tensors="pt"
        ).to(model.device)
        # Fresh streamer per request: a shared one keeps its queue between calls, so text
        # from an abandoned or failed generation could leak into the next reply.
        streamer = TextIteratorStreamer(tokenizer, timeout=10.0, skip_prompt=True, skip_special_tokens=True)
        generate_kwargs = _generation_kwargs(
            inputs, streamer, temperature, max_new_tokens, top_p, top_k, repetition_penalty
        )
        response.append({"role": "assistant", "content": ""})
        # Prefer FlashAttention but allow fallbacks, so SDPA does not abort with
        # "no available kernel" on hardware/dtypes that flash does not support.
        with sdpa_kernel([SDPBackend.FLASH_ATTENTION, SDPBackend.EFFICIENT_ATTENTION, SDPBackend.MATH]):
            thread, errors = _start_generation(generate_kwargs)
            for text in streamer:
                response[-1]["content"] += text
                yield response
            thread.join()
        if errors:
            raise errors[0]
    except Exception as e:
        print(e)
        gr.Warning(f"Error: {e}")
        yield response
# Gradio UI: a header plus a streaming ChatInterface backed by ``chat``.
# The additional inputs are handed to ``chat`` after (message, history), so their order
# must match chat's parameters: temperature, max_new_tokens, top_p, top_k,
# repetition_penalty, sys_prompt.
with gr.Blocks(fill_height=True, fill_width=True, css=css) as demo:
    gr.Markdown(DESCRIPTION)
    gr.ChatInterface(
        fn=chat,
        type="messages",
        chatbot=gr.Chatbot(height=450, type="messages", placeholder=PLACEHOLDER, label='Gradio ChatInterface'),
        fill_height=True,
        additional_inputs_accordion=gr.Accordion(label="⚙️ Parameters", open=False, render=False),
        additional_inputs=[
            # 0 selects greedy decoding in ``chat``.
            gr.Slider(minimum=0, maximum=1, step=0.1, value=0.7, label="Temperature", render=False),
            gr.Slider(minimum=128, maximum=4096, step=1, value=512, label="Max new tokens", render=False),
            gr.Slider(minimum=0.1, maximum=1.0, value=0.95, step=0.05, label="Top-p", render=False),
            gr.Slider(minimum=0, maximum=100, value=40, step=1, label="Top-k", render=False),
            # transformers requires a strictly positive repetition penalty, so the range
            # starts at 0.1 instead of 0.0 (1.0 disables the penalty).
            gr.Slider(minimum=0.1, maximum=2.0, value=1.1, step=0.1, label="Repetition penalty", render=False),
            gr.Textbox(value="", label="System prompt", render=False),
        ],
        save_history=True,
        examples=[
            ['How to setup a human base on Mars? Give short answer.'],
            ['Explain theory of relativity to me like I’m 8 years old.'],
            ['What is 9,000 * 9,000?'],
            ['Write a pun-filled happy birthday message to my friend Alex.'],
            ['Justify why a penguin might make a good king of the jungle.']
        ],
        cache_examples=False)
if __name__ == "__main__":
    # Enable Gradio's request queue, then start the server with SSR disabled.
    app = demo.queue()
    app.launch(ssr_mode=False)
|