John6666 committed · verified
Commit 43e2aa6 · Parent(s): 77c7cf9

Upload 2 files

Files changed (2):
  1. app.py +20 -10
  2. requirements.txt +3 -1
app.py CHANGED
@@ -4,11 +4,18 @@ import os
 from transformers import AutoModelForCausalLM, AutoTokenizer, TextIteratorStreamer, BitsAndBytesConfig
 from threading import Thread
 import torch
-from torch.nn.attention import SDPBackend, sdpa_kernel
+import gc
+
+def flush():
+    gc.collect()
+    torch.cuda.empty_cache()
+
+torch.set_float32_matmul_precision("high")
 
 HF_TOKEN = os.getenv("HF_TOKEN", None)
 #REPO_ID = "deepseek-ai/DeepSeek-R1-Distill-Llama-8B"
 REPO_ID = "nicoboss/DeepSeek-R1-Distill-Qwen-32B-Uncensored"
+#REPO_ID = "Qwen/QwQ-32B"
 
 DESCRIPTION = f'''
 <div>
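The removed `sdpa_kernel` import goes away because the flash-attention context manager is dropped later in the file. In its place, the commit adds a `flush()` helper and opts into TF32 matmuls via `torch.set_float32_matmul_precision("high")`. A minimal, self-contained sketch of the flush pattern (the allocation below is only for illustration):

import gc
import torch

def flush():
    gc.collect()                 # drop unreferenced Python objects first
    torch.cuda.empty_cache()     # then return cached CUDA blocks to the driver

if torch.cuda.is_available():
    x = torch.zeros(1024, 1024, device="cuda")   # ~4 MB scratch tensor
    del x                                        # now unreachable
    flush()
    print(torch.cuda.memory_reserved())          # reserved bytes shrink after flush()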
@@ -40,11 +47,13 @@ h1 {
 tokenizer = AutoTokenizer.from_pretrained(REPO_ID)
 if torch.cuda.is_available():
     nf4_config = BitsAndBytesConfig(load_in_4bit=True, bnb_4bit_quant_type="nf4", bnb_4bit_use_double_quant=True, bnb_4bit_compute_dtype=torch.bfloat16)
-    model = AutoModelForCausalLM.from_pretrained(REPO_ID, quantization_config=nf4_config)
+    model = AutoModelForCausalLM.from_pretrained(REPO_ID, device_map="auto", quantization_config=nf4_config)
 else: model = AutoModelForCausalLM.from_pretrained(REPO_ID, torch_dtype=torch.float32)
 streamer = TextIteratorStreamer(tokenizer, timeout=10.0, skip_prompt=True, skip_special_tokens=True)
+flush()
 
-@spaces.GPU(duration=30)
+@spaces.GPU(duration=59)
+@torch.inference_mode()
 def chat(message: str,
          history: list[dict],
          temperature: float,
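Two changes land in this hunk: the 4-bit model now loads with `device_map="auto"` (accelerate places the shards across available devices), and the handler gains a longer ZeroGPU slice (`@spaces.GPU(duration=59)`) plus `@torch.inference_mode()`, which disables autograd bookkeeping during generation. For reference, a sketch of the same NF4 loading pattern on a small checkpoint (the model name below is an example, not from this repo):

import torch
from transformers import AutoModelForCausalLM, BitsAndBytesConfig

nf4_config = BitsAndBytesConfig(
    load_in_4bit=True,                      # quantize linear weights to 4 bits on load
    bnb_4bit_quant_type="nf4",              # NormalFloat4 quantization
    bnb_4bit_use_double_quant=True,         # quantize the quantization constants too
    bnb_4bit_compute_dtype=torch.bfloat16,  # matmuls run in bf16
)

model = AutoModelForCausalLM.from_pretrained(
    "Qwen/Qwen2.5-0.5B-Instruct",   # example checkpoint
    device_map="auto",              # let accelerate pick GPU/CPU placement
    quantization_config=nf4_config,
)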
@@ -62,7 +71,8 @@ def chat(message: str,
     messages.append({"role": "system", "content": sys_prompt})
     messages.append({"role": "user", "content": message})
 
-    input_tensors = tokenizer.apply_chat_template(history + messages, add_generation_prompt=True, return_dict=True, return_tensors="pt").to(model.device)
+    #input_tensors = tokenizer.apply_chat_template(history + messages, add_generation_prompt=True, return_dict=True, return_tensors="pt").to(model.device)
+    input_tensors = tokenizer.apply_chat_template(messages, add_generation_prompt=True, return_dict=True, return_tensors="pt").to(model.device)
 
     input_ids = input_tensors["input_ids"]
     attention_mask = input_tensors["attention_mask"]
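Note the behavioral change here: only the current turn (`messages`) is templated now, so `history` no longer enters the prompt and each request is effectively single-turn. A minimal illustration of what `apply_chat_template(..., return_dict=True)` produces (example checkpoint, not from this repo):

from transformers import AutoTokenizer

tok = AutoTokenizer.from_pretrained("Qwen/Qwen2.5-0.5B-Instruct")
messages = [
    {"role": "system", "content": "You are helpful."},
    {"role": "user", "content": "Hi!"},
]
enc = tok.apply_chat_template(
    messages,
    add_generation_prompt=True,  # append the assistant header so the model continues from it
    return_dict=True,            # returns input_ids and attention_mask, not bare ids
    return_tensors="pt",
)
print(enc["input_ids"].shape, enc["attention_mask"].shape)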
@@ -70,8 +80,8 @@ def chat(message: str,
     generate_kwargs = dict(
         input_ids=input_ids,
         attention_mask=attention_mask,
-        streamer=streamer,
         max_new_tokens=max_new_tokens,
+        streamer=streamer,
         do_sample=True,
         temperature=temperature,
         top_k=top_k,
@@ -82,10 +92,8 @@ def chat(message: str,
         if temperature == 0: generate_kwargs['do_sample'] = False
         response.append({"role": "assistant", "content": ""})
 
-        with sdpa_kernel(SDPBackend.FLASH_ATTENTION):
-            t = Thread(target=model.generate, kwargs=generate_kwargs)
-            t.start()
-
+        t = Thread(target=model.generate, kwargs=generate_kwargs)
+        t.start()
         for text in streamer:
             response[-1]["content"] += text
             yield response
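Dropping the `sdpa_kernel(SDPBackend.FLASH_ATTENTION)` context means generation now runs under whichever SDPA backend PyTorch selects on its own. The background-thread pattern that remains is the standard `TextIteratorStreamer` recipe: `generate()` blocks, so it runs in a worker thread while the caller iterates decoded text. A self-contained sketch with a small example model:

from threading import Thread
from transformers import AutoModelForCausalLM, AutoTokenizer, TextIteratorStreamer

name = "Qwen/Qwen2.5-0.5B-Instruct"   # example checkpoint
tok = AutoTokenizer.from_pretrained(name)
model = AutoModelForCausalLM.from_pretrained(name)

streamer = TextIteratorStreamer(tok, skip_prompt=True, skip_special_tokens=True)
inputs = tok("The capital of France is", return_tensors="pt")

# generate() blocks until done, so it runs in a worker thread while the
# main thread consumes decoded fragments from the streamer iterator.
t = Thread(target=model.generate, kwargs=dict(**inputs, streamer=streamer, max_new_tokens=32))
t.start()
for piece in streamer:
    print(piece, end="", flush=True)
t.join()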
@@ -93,6 +101,8 @@ def chat(message: str,
         print(e)
         gr.Warning(f"Error: {e}")
         yield response
+    finally:
+        flush()
 
 with gr.Blocks(fill_height=True, fill_width=True, css=css) as demo:
     gr.Markdown(DESCRIPTION)
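The new `finally: flush()` guarantees the cache release runs whether generation succeeds or raises. The shape of the pattern in a generator, as a sketch (the `try` and `except` lines sit in unchanged parts of the file):

import gc
import torch

def flush():
    gc.collect()
    if torch.cuda.is_available():   # guard so the sketch also runs on CPU-only machines
        torch.cuda.empty_cache()

def chat_like(message):
    try:
        yield f"echo: {message}"    # streaming happens here in the real handler
    except Exception as e:
        print(e)
        yield f"Error: {e}"
    finally:
        flush()                     # runs on success, on error, and when the generator closes

for chunk in chat_like("hi"):
    print(chunk)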
@@ -108,7 +118,7 @@ with gr.Blocks(fill_height=True, fill_width=True, css=css) as demo:
         gr.Slider(minimum=0.1, maximum=1.0, value=0.95, step=0.05, label="Top-p", render=False),
         gr.Slider(minimum=0, maximum=100, value=40, step=1, label="Top-k", render=False),
         gr.Slider(minimum=0.0, maximum=2.0, value=1.1, step=0.1, label="Repetition penalty", render=False),
-        gr.Textbox(value="", label="System prompt", render=False),
+        gr.Textbox(value="", label="System prompt", render=False)
     ],
     save_history=True,
     examples=[
 
requirements.txt CHANGED
@@ -2,8 +2,10 @@ huggingface_hub
 torch==2.4.0
 torchvision
 accelerate
-transformers
+git+https://github.com/huggingface/transformers
 numpy<2
 sentencepiece
 triton
+optimum
+optimum-quanto
 bitsandbytes
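requirements.txt now installs transformers from the GitHub main branch instead of the latest PyPI release, and adds optimum and optimum-quanto. Since main is unpinned, each build picks up whatever is current; a quick way to confirm which dev build got installed (illustrative check, not from the repo):

import transformers
print(transformers.__version__)   # dev builds from git main report a version like "4.xx.0.dev0"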
 