Ruurd committed
Commit 0bcfdcb · 1 Parent(s): 64a8918

fix please

Files changed (1): app.py (+8 -5)
app.py CHANGED
@@ -16,7 +16,6 @@ class RichTextStreamer(TextIteratorStreamer):
     def __init__(self, tokenizer, **kwargs):
         super().__init__(tokenizer, **kwargs)
         self.token_queue = queue.Queue()
-        self.prompt_shown = not self.skip_prompt
 
     def put(self, value):
         if isinstance(value, torch.Tensor):
@@ -27,9 +26,6 @@ class RichTextStreamer(TextIteratorStreamer):
             token_ids = [value]
 
         for token_id in token_ids:
-            if self.skip_prompt and not self.prompt_shown:
-                continue  # skip prompt tokens
-
             token_str = self.tokenizer.decode([token_id], **self.decode_kwargs)
             is_special = token_id in self.tokenizer.all_special_ids
             self.token_queue.put({
@@ -37,7 +33,6 @@ class RichTextStreamer(TextIteratorStreamer):
                 "token": token_str,
                 "is_special": is_special
             })
-            self.prompt_shown = True
 
     def __iter__(self):
         while True:
@@ -92,6 +87,8 @@ def chat_with_model(messages):
 
     print(f'Step 1: {messages}')
 
+    prompt_text = current_tokenizer.decode(inputs["input_ids"][0], skip_special_tokens=True)
+
     for token_info in streamer:
         token_str = token_info["token"]
         token_id = token_info["token_id"]
@@ -119,6 +116,12 @@ def chat_with_model(messages):
             output_text = output_text.split("\nUser:")[0].rstrip()
             break
 
+        # Strip prompt from start of generated output
+        if output_text.startswith(prompt_text):
+            stripped_output = output_text[len(prompt_text):]
+        else:
+            stripped_output = output_text
+
         generated_tokens += 1
         if generated_tokens >= max_new_tokens:
             break
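
In short, the commit drops the per-token prompt-skip bookkeeping (prompt_shown) from RichTextStreamer.put and instead decodes the prompt once in chat_with_model, then removes that text from the front of the accumulated output. A minimal sketch of that stripping approach follows; the function name strip_prompt and its arguments are illustrative, not taken from app.py:

# Sketch of the prompt-stripping approach this commit adopts (illustrative
# names, not app.py's). Causal LMs typically echo the prompt tokens before
# the newly generated ones, so decoding the prompt separately lets us trim
# it from the decoded text with a single string comparison.
def strip_prompt(tokenizer, input_ids, output_text: str) -> str:
    prompt_text = tokenizer.decode(input_ids, skip_special_tokens=True)
    if output_text.startswith(prompt_text):
        # Output begins with the echoed prompt: cut it off the front.
        return output_text[len(prompt_text):]
    # Output did not echo the prompt (or was already stripped): leave as-is.
    return output_text

String-level stripping like this avoids tracking prompt state inside the streamer, at the cost of assuming the decoded output begins with exactly the decoded prompt text.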