Update app_chat.py
app_chat.py  CHANGED  (+20 -20)
@@ -6,7 +6,7 @@ import gradio as gr
 import spaces
 import torch
 from transformers import AutoModelForCausalLM, AutoTokenizer, TextIteratorStreamer
-
+from transformers import StoppingCriteria, StoppingCriteriaList, StopStringCriteria
 
 import subprocess
 # useradd -m -u 1000 user
@@ -33,19 +33,19 @@ tokenizer = AutoTokenizer.from_pretrained(model_id, trust_remote_code=True)
 tokenizer.chat_template = "{{'<extra_id_0>System'}}{% for message in messages %}{% if message['role'] == 'system' %}{{'\n' + message['content'].strip()}}{% if tools or contexts %}{{'\n'}}{% endif %}{% endif %}{% endfor %}{% if tools %}{% for tool in tools %}{{ '\n<tool> ' + tool|tojson + ' </tool>' }}{% endfor %}{% endif %}{% if contexts %}{% if tools %}{{'\n'}}{% endif %}{% for context in contexts %}{{ '\n<context> ' + context.strip() + ' </context>' }}{% endfor %}{% endif %}{{'\n\n'}}{% for message in messages %}{% if message['role'] == 'user' %}{{ '<extra_id_1>User\n' + message['content'].strip() + '\n' }}{% elif message['role'] == 'assistant' %}{{ '<extra_id_1>Assistant\n' + message['content'].strip() + '\n' }}{% elif message['role'] == 'tool' %}{{ '<extra_id_1>Tool\n' + message['content'].strip() + '\n' }}{% endif %}{% endfor %}{%- if add_generation_prompt %}{{'<extra_id_1>Assistant\n'}}{%- endif %}"
 #tokenizer.use_default_system_prompt = False
 
-
-
-
-
-
-
+class StoppingCriteriaSub(StoppingCriteria):
+    def __init__(self, tokenizer, stops = [], encounters=1):
+        super().__init__()
+        self.stops = [stop.to("cuda") for stop in stops]
+        self.tokenizer = tokenizer
+        self.num_mamba_stop_ids = 8
 
-
-
-
-
-
-
+    def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor):
+        last_token = input_ids[0][-self.num_mamba_stop_ids:]
+        for stop in self.stops:
+            if self.tokenizer.decode(stop) in self.tokenizer.decode(last_token):
+                return True
+        return False
 
 @spaces.GPU
 def generate(
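The StoppingCriteriaSub class added above decodes the last num_mamba_stop_ids generated tokens and stops generation once any configured stop sequence appears in that decoded tail. It is only defined here; the changes to generate() further down rely on the built-in StopStringCriteria instead. A minimal, hypothetical usage sketch, assuming the model and tokenizer objects loaded elsewhere in app_chat.py and an illustrative "</s>" stop string:

# Hypothetical usage of StoppingCriteriaSub (illustrative only; the commit's
# generate() function passes StopStringCriteria rather than this class).
# Assumes `model`, `tokenizer`, and StoppingCriteriaSub from app_chat.py.
from transformers import StoppingCriteriaList

# Encode the stop string into token ids, as __init__ expects tensors it can move to the GPU.
stop_ids = [tokenizer("</s>", return_tensors="pt").input_ids[0]]
custom_criteria = StoppingCriteriaList([StoppingCriteriaSub(tokenizer=tokenizer, stops=stop_ids)])
# output = model.generate(input_ids, max_new_tokens=64, stopping_criteria=custom_criteria)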
@@ -66,7 +66,7 @@ def generate(
 
     input_ids = tokenizer.apply_chat_template(conversation, tokenize=True, add_generation_prompt=True, return_tensors="pt").to('cuda')
 
-
+    stopping_criteria = StoppingCriteriaList([StopStringCriteria(tokenizer=tokenizer, stop_strings="</s>")])
 
     if input_ids.shape[1] > MAX_INPUT_TOKEN_LENGTH:
         input_ids = input_ids[:, -MAX_INPUT_TOKEN_LENGTH:]
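For reference, the template assigned to tokenizer.chat_template earlier in the file is what apply_chat_template expands here: it serializes the conversation into Hymba's <extra_id_*> turn format. A small sketch of the rendered prompt for a two-message conversation; tokenize=False is used only to make the string visible, and the expected output is read off the Jinja template itself rather than taken from a live run:

# Sketch: how the custom chat_template renders a short conversation.
# Assumes the tokenizer configured earlier in app_chat.py.
conversation = [
    {"role": "system", "content": "You are a helpful assistant."},
    {"role": "user", "content": "Hello!"},
]
prompt = tokenizer.apply_chat_template(conversation, tokenize=False, add_generation_prompt=True)
print(prompt)
# Expected shape of the prompt:
# <extra_id_0>System
# You are a helpful assistant.
#
# <extra_id_1>User
# Hello!
# <extra_id_1>Assistant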
@@ -84,7 +84,7 @@ def generate(
         temperature=temperature,
         num_beams=1,
         repetition_penalty=repetition_penalty,
-
+        stopping_criteria=stopping_criteria,
     )
     t = Thread(target=model.generate, kwargs=generate_kwargs)
     t.start()
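Context for the hunk above: generate_kwargs is handed to model.generate on a worker thread, and TextIteratorStreamer feeds decoded text back to the caller, which yields partial output to Gradio. A condensed, hypothetical sketch of that pattern; the prompt and all generation settings below are illustrative, not the Space's actual values:

# Sketch of the background-thread streaming pattern used by generate().
# Assumes `model` and `tokenizer` from app_chat.py; settings are illustrative.
from threading import Thread
from transformers import StoppingCriteriaList, StopStringCriteria, TextIteratorStreamer

streamer = TextIteratorStreamer(tokenizer, skip_prompt=True, skip_special_tokens=True)
input_ids = tokenizer("Hello!", return_tensors="pt").input_ids.to("cuda")
stopping_criteria = StoppingCriteriaList([StopStringCriteria(tokenizer=tokenizer, stop_strings="</s>")])

generate_kwargs = dict(
    input_ids=input_ids,
    streamer=streamer,
    max_new_tokens=128,
    do_sample=True,
    temperature=0.7,
    stopping_criteria=stopping_criteria,
)

# model.generate blocks until generation finishes, so it runs on a worker thread
# while this thread consumes decoded text from the streamer as it arrives.
Thread(target=model.generate, kwargs=generate_kwargs).start()

outputs = []
for text in streamer:
    outputs.append(text)  # the Space yields "".join(outputs) to gr.ChatInterface at each step
print("".join(outputs))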
@@ -98,11 +98,11 @@ def generate(
 chat_interface = gr.ChatInterface(
     fn=generate,
     additional_inputs=[
-        gr.Textbox(label="System prompt", lines=6, value="You are a helpful assistant. Your name is Hymba-1.5B-Instruct-8K. \
-            You are a new family of small language models featuring a hybrid-head architecture that strategically integrates attention mechanisms with state space models (SSMs). \
-            You are developed by Deep Learning Efficiency Research (DLER) team at NVIDIA Research. \
-            The above is just a background context. You can answer any questions not limited to the above background context."),
-
+        # gr.Textbox(label="System prompt", lines=6, value="You are a helpful assistant. Your name is Hymba-1.5B-Instruct-8K. \
+        #     You are a new family of small language models featuring a hybrid-head architecture that strategically integrates attention mechanisms with state space models (SSMs). \
+        #     You are developed by Deep Learning Efficiency Research (DLER) team at NVIDIA Research. \
+        #     The above is just a background context. You can answer any questions not limited to the above background context."),
+        gr.Textbox(label="System prompt", lines=6, value="You are a helpful assistant. Your name is Hymba-1.5B-Instruct-8K. "),
         gr.Slider(
             label="Max new tokens",
             minimum=1,
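The additional_inputs list is passed positionally to the fn callback after (message, history), so the shortened system-prompt Textbox above and the sliders that follow must stay in the same order as generate's extra parameters. A hedged sketch of that wiring; the parameter names in the comment and the slider bounds are assumptions, not values taken from the file:

# Sketch of the ChatInterface wiring; slider bounds/defaults are illustrative.
import gradio as gr

chat_interface = gr.ChatInterface(
    fn=generate,  # assumed signature: generate(message, history, system_prompt, max_new_tokens, ...)
    additional_inputs=[
        gr.Textbox(label="System prompt", lines=6,
                   value="You are a helpful assistant. Your name is Hymba-1.5B-Instruct-8K. "),
        gr.Slider(label="Max new tokens", minimum=1, maximum=2048, step=1, value=512),
    ],
)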