pmolchanov committed
Commit ee30c97 · verified · 1 Parent(s): b24e029

Update app_chat.py

Files changed (1): app_chat.py (+20, −20)
app_chat.py CHANGED
@@ -6,7 +6,7 @@ import gradio as gr
 import spaces
 import torch
 from transformers import AutoModelForCausalLM, AutoTokenizer, TextIteratorStreamer
-# from transformers import StoppingCriteria, StoppingCriteriaList, StopStringCriteria
+from transformers import StoppingCriteria, StoppingCriteriaList, StopStringCriteria
 
 import subprocess
 # useradd -m -u 1000 user
@@ -33,19 +33,19 @@ tokenizer = AutoTokenizer.from_pretrained(model_id, trust_remote_code=True)
 tokenizer.chat_template = "{{'<extra_id_0>System'}}{% for message in messages %}{% if message['role'] == 'system' %}{{'\n' + message['content'].strip()}}{% if tools or contexts %}{{'\n'}}{% endif %}{% endif %}{% endfor %}{% if tools %}{% for tool in tools %}{{ '\n<tool> ' + tool|tojson + ' </tool>' }}{% endfor %}{% endif %}{% if contexts %}{% if tools %}{{'\n'}}{% endif %}{% for context in contexts %}{{ '\n<context> ' + context.strip() + ' </context>' }}{% endfor %}{% endif %}{{'\n\n'}}{% for message in messages %}{% if message['role'] == 'user' %}{{ '<extra_id_1>User\n' + message['content'].strip() + '\n' }}{% elif message['role'] == 'assistant' %}{{ '<extra_id_1>Assistant\n' + message['content'].strip() + '\n' }}{% elif message['role'] == 'tool' %}{{ '<extra_id_1>Tool\n' + message['content'].strip() + '\n' }}{% endif %}{% endfor %}{%- if add_generation_prompt %}{{'<extra_id_1>Assistant\n'}}{%- endif %}"
 #tokenizer.use_default_system_prompt = False
 
-# class StoppingCriteriaSub(StoppingCriteria):
-#     def __init__(self, tokenizer, stops = [], encounters=1):
-#         super().__init__()
-#         self.stops = [stop.to("cuda") for stop in stops]
-#         self.tokenizer = tokenizer
-#         self.num_mamba_stop_ids = 8
+class StoppingCriteriaSub(StoppingCriteria):
+    def __init__(self, tokenizer, stops = [], encounters=1):
+        super().__init__()
+        self.stops = [stop.to("cuda") for stop in stops]
+        self.tokenizer = tokenizer
+        self.num_mamba_stop_ids = 8
 
-#     def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor):
-#         last_token = input_ids[0][-self.num_mamba_stop_ids:]
-#         for stop in self.stops:
-#             if self.tokenizer.decode(stop) in self.tokenizer.decode(last_token):
-#                 return True
-#         return False
+    def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor):
+        last_token = input_ids[0][-self.num_mamba_stop_ids:]
+        for stop in self.stops:
+            if self.tokenizer.decode(stop) in self.tokenizer.decode(last_token):
+                return True
+        return False
 
 @spaces.GPU
 def generate(
@@ -66,7 +66,7 @@ def generate(
 
     input_ids = tokenizer.apply_chat_template(conversation, tokenize=True, add_generation_prompt=True, return_tensors="pt").to('cuda')
 
-    # stopping_criteria = StoppingCriteriaList([StopStringCriteria(tokenizer=tokenizer, stop_strings="</s>")])
+    stopping_criteria = StoppingCriteriaList([StopStringCriteria(tokenizer=tokenizer, stop_strings="</s>")])
 
     if input_ids.shape[1] > MAX_INPUT_TOKEN_LENGTH:
         input_ids = input_ids[:, -MAX_INPUT_TOKEN_LENGTH:]
@@ -84,7 +84,7 @@ def generate(
         temperature=temperature,
         num_beams=1,
         repetition_penalty=repetition_penalty,
-        # "stopping_criteria": stopping_criteria,
+        stopping_criteria=stopping_criteria,
     )
     t = Thread(target=model.generate, kwargs=generate_kwargs)
     t.start()
@@ -98,11 +98,11 @@ def generate(
 chat_interface = gr.ChatInterface(
     fn=generate,
     additional_inputs=[
-        gr.Textbox(label="System prompt", lines=6, value="You are a helpful assistant. Your name is Hymba-1.5B-Instruct-8K. \
-            You are a new family of small language models featuring a hybrid-head architecture that strategically integrates attention mechanisms with state space models (SSMs). \
-            You are developed by Deep Learning Efficiency Research (DLER) team at NVIDIA Research. \
-            Nvidia Corporation is an American multinational corporation and technology company headquartered in Santa Clara, California. Nvidia was founded on April 5, 1993 by Jensen Huang. \
-            The above is just a background context. You can answer any questions not limited to the above background context."),
+        # gr.Textbox(label="System prompt", lines=6, value="You are a helpful assistant. Your name is Hymba-1.5B-Instruct-8K. \
+        #     You are a new family of small language models featuring a hybrid-head architecture that strategically integrates attention mechanisms with state space models (SSMs). \
+        #     You are developed by Deep Learning Efficiency Research (DLER) team at NVIDIA Research. \
+        #     The above is just a background context. You can answer any questions not limited to the above background context."),
+        gr.Textbox(label="System prompt", lines=6, value="You are a helpful assistant. Your name is Hymba-1.5B-Instruct-8K. "),
         gr.Slider(
             label="Max new tokens",
             minimum=1,
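Net effect of the change: generation now halts as soon as the decoded output contains the "</s>" stop string, via the built-in StopStringCriteria, instead of always running to max_new_tokens. Below is a minimal, self-contained sketch of the same pattern, assuming the public transformers API; the model_id, prompt, and generation settings are illustrative placeholders, not values taken from this commit.

    from threading import Thread

    import torch
    from transformers import (
        AutoModelForCausalLM,
        AutoTokenizer,
        StoppingCriteriaList,
        StopStringCriteria,
        TextIteratorStreamer,
    )

    model_id = "nvidia/Hymba-1.5B-Instruct"  # assumed checkpoint name
    tokenizer = AutoTokenizer.from_pretrained(model_id, trust_remote_code=True)
    model = AutoModelForCausalLM.from_pretrained(
        model_id, trust_remote_code=True, torch_dtype=torch.bfloat16
    ).to("cuda")

    input_ids = tokenizer("Who are you?", return_tensors="pt").input_ids.to("cuda")
    streamer = TextIteratorStreamer(tokenizer, skip_prompt=True, skip_special_tokens=True)

    # Halt generation as soon as the decoded text contains "</s>".
    stopping_criteria = StoppingCriteriaList(
        [StopStringCriteria(tokenizer=tokenizer, stop_strings="</s>")]
    )

    # Same threaded pattern as app_chat.py: generate in a worker thread,
    # stream partial text back on the main thread.
    thread = Thread(
        target=model.generate,
        kwargs=dict(
            input_ids=input_ids,
            max_new_tokens=256,
            do_sample=False,
            streamer=streamer,
            stopping_criteria=stopping_criteria,
        ),
    )
    thread.start()
    for chunk in streamer:
        print(chunk, end="", flush=True)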
 
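One loose end: the commit also restores the custom StoppingCriteriaSub class, but generate() only registers the built-in StopStringCriteria, so the custom class is defined yet never instantiated. If it were wired in instead, the call might look like the sketch below, which assumes the class and tokenizer from app_chat.py; the stop-string encoding is an illustrative guess (the class expects token-id tensors, which its __init__ moves to CUDA).

    # Hypothetical wiring for the otherwise-unused StoppingCriteriaSub.
    # __call__ decodes the last num_mamba_stop_ids (8) generated tokens and
    # stops when any decoded stop sequence appears in that window.
    stop_words = ["</s>"]
    stop_ids = [
        tokenizer(word, return_tensors="pt", add_special_tokens=False).input_ids.squeeze(0)
        for word in stop_words
    ]
    stopping_criteria = StoppingCriteriaList(
        [StoppingCriteriaSub(tokenizer=tokenizer, stops=stop_ids)]
    )
    # Then passed exactly like the built-in variant:
    #     model.generate(..., stopping_criteria=stopping_criteria)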