fix please
app.py
CHANGED
@@ -16,7 +16,6 @@ class RichTextStreamer(TextIteratorStreamer):
     def __init__(self, tokenizer, **kwargs):
         super().__init__(tokenizer, **kwargs)
         self.token_queue = queue.Queue()
-        self.prompt_shown = not self.skip_prompt
 
     def put(self, value):
         if isinstance(value, torch.Tensor):
@@ -27,9 +26,6 @@ class RichTextStreamer(TextIteratorStreamer):
             token_ids = [value]
 
         for token_id in token_ids:
-            if self.skip_prompt and not self.prompt_shown:
-                continue  # skip prompt tokens
-
             token_str = self.tokenizer.decode([token_id], **self.decode_kwargs)
             is_special = token_id in self.tokenizer.all_special_ids
             self.token_queue.put({
@@ -37,7 +33,6 @@ class RichTextStreamer(TextIteratorStreamer):
                 "token": token_str,
                 "is_special": is_special
             })
-        self.prompt_shown = True
 
     def __iter__(self):
         while True:
@@ -92,6 +87,8 @@ def chat_with_model(messages):
 
     print(f'Step 1: {messages}')
 
+    prompt_text = current_tokenizer.decode(inputs["input_ids"][0], skip_special_tokens=True)
+
    for token_info in streamer:
         token_str = token_info["token"]
         token_id = token_info["token_id"]
@@ -119,6 +116,12 @@ def chat_with_model(messages):
             output_text = output_text.split("\nUser:")[0].rstrip()
             break
 
+        # Strip prompt from start of generated output
+        if output_text.startswith(prompt_text):
+            stripped_output = output_text[len(prompt_text):]
+        else:
+            stripped_output = output_text
+
         generated_tokens += 1
         if generated_tokens >= max_new_tokens:
             break
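The change swaps per-token bookkeeping (the removed prompt_shown / skip_prompt flag) for decoding the prompt once and trimming it off the front of the decoded output. Below is a minimal, self-contained sketch of that idea; the helper name strip_prompt and the example strings are illustrative only and do not appear in app.py:

# Sketch of the prompt-stripping approach this commit switches to:
# decode the prompt once, then trim it from the front of the decoded
# output, instead of tracking a prompt_shown flag per streamed token.
# Helper name and example strings are illustrative, not from app.py.

def strip_prompt(output_text: str, prompt_text: str) -> str:
    # Remove the leading prompt if the model echoed it back verbatim;
    # otherwise return the output unchanged.
    if output_text.startswith(prompt_text):
        return output_text[len(prompt_text):]
    return output_text

prompt = "User: What is 2+2?\nAssistant:"
raw = "User: What is 2+2?\nAssistant: 4"
print(strip_prompt(raw, prompt))  # -> " 4"

Decoding both the prompt and the output with the same skip_special_tokens setting matters here: if special tokens are kept on one side but stripped on the other, the startswith check fails and the prompt leaks through.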