John6666 committed (verified)
Commit: 77c7cf9 · Parent: 5d07c9b

Upload app.py

Files changed (1): app.py (+36 −30)
app.py CHANGED
```diff
@@ -4,6 +4,7 @@ import os
 from transformers import AutoModelForCausalLM, AutoTokenizer, TextIteratorStreamer, BitsAndBytesConfig
 from threading import Thread
 import torch
+from torch.nn.attention import SDPBackend, sdpa_kernel
 
 HF_TOKEN = os.getenv("HF_TOKEN", None)
 #REPO_ID = "deepseek-ai/DeepSeek-R1-Distill-Llama-8B"
```
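The new import pulls in PyTorch's SDPA backend selector, used further down to pin attention to FlashAttention. For context, a minimal sketch of how `sdpa_kernel` is used on its own, assuming PyTorch 2.3+ and a CUDA device (the shapes and dtypes below are illustrative, not taken from the Space):

```python
# Minimal sketch: forcing scaled_dot_product_attention onto the
# FlashAttention backend. Assumes PyTorch >= 2.3 and a CUDA GPU;
# tensor shapes/dtypes are illustrative only.
import torch
import torch.nn.functional as F
from torch.nn.attention import SDPBackend, sdpa_kernel

q = torch.randn(1, 8, 128, 64, device="cuda", dtype=torch.float16)
k = torch.randn(1, 8, 128, 64, device="cuda", dtype=torch.float16)
v = torch.randn(1, 8, 128, 64, device="cuda", dtype=torch.float16)

# Inside this context SDPA may only dispatch to FlashAttention; with
# unsupported inputs (e.g. float32) it errors instead of silently
# falling back to another backend.
with sdpa_kernel(SDPBackend.FLASH_ATTENTION):
    out = F.scaled_dot_product_attention(q, k, v)
print(out.shape)  # torch.Size([1, 8, 128, 64])
```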
```diff
@@ -43,18 +44,25 @@ if torch.cuda.is_available():
 else: model = AutoModelForCausalLM.from_pretrained(REPO_ID, torch_dtype=torch.float32)
 streamer = TextIteratorStreamer(tokenizer, timeout=10.0, skip_prompt=True, skip_special_tokens=True)
 
-@spaces.GPU(duration=59)
-def chat(message: str,
-         history: list[dict],
-         temperature: float,
+@spaces.GPU(duration=30)
+def chat(message: str,
+         history: list[dict],
+         temperature: float,
          max_new_tokens: int,
+         top_p: float,
+         top_k: int,
+         repetition_penalty: float,
+         sys_prompt: str,
          progress=gr.Progress(track_tqdm=True)
          ):
     try:
+        messages = []
+        response = []
         if not history: history = []
-        history.append({"role": "user", "content": message})
+        messages.append({"role": "system", "content": sys_prompt})
+        messages.append({"role": "user", "content": message})
 
-        input_tensors = tokenizer.apply_chat_template(history, add_generation_prompt=True, return_dict=True, return_tensors="pt").to(model.device)
+        input_tensors = tokenizer.apply_chat_template(history + messages, add_generation_prompt=True, return_dict=True, return_tensors="pt").to(model.device)
 
         input_ids = input_tensors["input_ids"]
         attention_mask = input_tensors["attention_mask"]
```
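The reworked handler collects the system prompt and the incoming user message into a scratch `messages` list and concatenates it with the stored `history` before templating; note that with a non-empty history this places the system message after the earlier turns, which some chat templates may not expect. A minimal sketch of the same prompt assembly, using a small placeholder model rather than the Space's `REPO_ID`:

```python
# Sketch of the prompt assembly the new handler performs, with a toy
# history; the model name is a placeholder, not the Space's REPO_ID.
from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("HuggingFaceTB/SmolLM2-135M-Instruct")

history = [
    {"role": "user", "content": "Hi!"},
    {"role": "assistant", "content": "Hello, how can I help?"},
]
messages = [
    {"role": "system", "content": "You are a concise assistant."},
    {"role": "user", "content": "What is 2 + 2?"},
]

# add_generation_prompt=True appends the assistant header so the model
# starts a fresh reply; return_dict=True yields input_ids plus
# attention_mask, matching how the Space feeds model.generate.
inputs = tokenizer.apply_chat_template(
    history + messages,
    add_generation_prompt=True,
    return_dict=True,
    return_tensors="pt",
)
print(inputs["input_ids"].shape)
```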
```diff
@@ -66,53 +74,51 @@ def chat(message: str,
             max_new_tokens=max_new_tokens,
             do_sample=True,
             temperature=temperature,
+            top_k=top_k,
+            top_p=top_p,
+            repetition_penalty=repetition_penalty,
             pad_token_id=tokenizer.eos_token_id,
         )
         if temperature == 0: generate_kwargs['do_sample'] = False
-        history.append({"role": "assistant", "content": ""})
+        response.append({"role": "assistant", "content": ""})
 
-        t = Thread(target=model.generate, kwargs=generate_kwargs)
-        t.start()
+        with sdpa_kernel(SDPBackend.FLASH_ATTENTION):
+            t = Thread(target=model.generate, kwargs=generate_kwargs)
+            t.start()
 
         for text in streamer:
-            history[-1]["content"] += text
-            yield history
+            response[-1]["content"] += text
+            yield response
     except Exception as e:
         print(e)
         gr.Warning(f"Error: {e}")
-        yield history
-
-chatbot=gr.Chatbot(height=450, type="messages", placeholder=PLACEHOLDER, label='Gradio ChatInterface')
+        yield response
 
 with gr.Blocks(fill_height=True, fill_width=True, css=css) as demo:
     gr.Markdown(DESCRIPTION)
     gr.ChatInterface(
         fn=chat,
-        chatbot=chatbot,
+        type="messages",
+        chatbot=gr.Chatbot(height=450, type="messages", placeholder=PLACEHOLDER, label='Gradio ChatInterface'),
         fill_height=True,
         additional_inputs_accordion=gr.Accordion(label="⚙️ Parameters", open=False, render=False),
         additional_inputs=[
-            gr.Slider(minimum=0,
-                      maximum=1,
-                      step=0.1,
-                      value=0.5,
-                      label="Temperature",
-                      render=False),
-            gr.Slider(minimum=128,
-                      maximum=4096,
-                      step=1,
-                      value=512,
-                      label="Max new tokens",
-                      render=False),
-        ],
+            gr.Slider(minimum=0, maximum=1, step=0.1, value=0.7, label="Temperature", render=False),
+            gr.Slider(minimum=128, maximum=4096, step=1, value=512, label="Max new tokens", render=False),
+            gr.Slider(minimum=0.1, maximum=1.0, value=0.95, step=0.05, label="Top-p", render=False),
+            gr.Slider(minimum=0, maximum=100, value=40, step=1, label="Top-k", render=False),
+            gr.Slider(minimum=0.0, maximum=2.0, value=1.1, step=0.1, label="Repetition penalty", render=False),
+            gr.Textbox(value="", label="System prompt", render=False),
+        ],
+        save_history=True,
         examples=[
             ['How to setup a human base on Mars? Give short answer.'],
             ['Explain theory of relativity to me like I’m 8 years old.'],
             ['What is 9,000 * 9,000?'],
             ['Write a pun-filled happy birthday message to my friend Alex.'],
             ['Justify why a penguin might make a good king of the jungle.']
-        ],
+            ],
         cache_examples=False)
 
 if __name__ == "__main__":
-    demo.launch(ssr_mode=False)
+    demo.queue().launch(ssr_mode=False)
```
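Generation keeps the standard `TextIteratorStreamer` pattern: `model.generate` runs in a background thread and the handler yields the growing assistant message as chunks arrive. One caveat worth noting: the `with sdpa_kernel(...)` block wraps only thread creation and `t.start()` and exits as soon as the thread is launched, so most (possibly all) of the actual generation likely runs outside the pinned-backend context. A self-contained sketch of the streaming pattern itself, again with a small placeholder model:

```python
# The streaming pattern the handler relies on, reduced to its core:
# generate() runs in a background thread while TextIteratorStreamer
# yields decoded text pieces as they are produced. The model name is
# a small placeholder, not the Space's REPO_ID.
from threading import Thread
from transformers import AutoModelForCausalLM, AutoTokenizer, TextIteratorStreamer

repo = "HuggingFaceTB/SmolLM2-135M-Instruct"
tokenizer = AutoTokenizer.from_pretrained(repo)
model = AutoModelForCausalLM.from_pretrained(repo)

streamer = TextIteratorStreamer(tokenizer, skip_prompt=True, skip_special_tokens=True)
inputs = tokenizer.apply_chat_template(
    [{"role": "user", "content": "Count to five."}],
    add_generation_prompt=True, return_dict=True, return_tensors="pt",
)

# Run generation off the main thread so the streamer can be consumed here.
Thread(target=model.generate,
       kwargs=dict(**inputs, max_new_tokens=64, streamer=streamer)).start()

# Iterating the streamer blocks until the next chunk arrives, so the
# caller can forward partial text to the UI as it comes in.
partial = ""
for text in streamer:
    partial += text
    print(partial)
```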
 
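On the UI side, the interface now runs the chatbot in `messages` format, exposes the new sampling controls as `additional_inputs`, enables `save_history`, and launches through `demo.queue()` so streaming generators are served correctly. A minimal sketch of that wiring, assuming Gradio 5.x (where `ChatInterface` accepts `type="messages"` and `save_history`):

```python
# Minimal shape of the UI wiring; parameter availability assumes
# Gradio 5.x. The chat function is a stand-in, not the Space's.
import gradio as gr

def chat(message, history, temperature):
    # A generator like the Space's handler would yield message dicts;
    # returning a plain string is enough for a non-streaming sketch.
    return f"(T={temperature}) You said: {message}"

demo = gr.ChatInterface(
    fn=chat,
    type="messages",  # history arrives as a list of role/content dicts
    additional_inputs=[gr.Slider(0, 1, value=0.7, step=0.1, label="Temperature")],
    save_history=True,  # persists past conversations in the browser
)

if __name__ == "__main__":
    demo.queue().launch()
```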