Stop at eos token
app.py
CHANGED
@@ -18,7 +18,7 @@ def chat_with_model(messages):
     prompt = format_prompt(messages)
     inputs = current_tokenizer(prompt, return_tensors="pt").to(current_model.device)
 
-    streamer = TextIteratorStreamer(current_tokenizer, skip_prompt=True, skip_special_tokens=
+    streamer = TextIteratorStreamer(current_tokenizer, skip_prompt=True, skip_special_tokens=False)
     generation_kwargs = dict(
         **inputs,
         max_new_tokens=256,
@@ -37,7 +37,10 @@ def chat_with_model(messages):
     for new_text in streamer:
         output_text += new_text
         messages[-1]["content"] = output_text
+        if current_tokenizer.eos_token and current_tokenizer.eos_token in output_text:
+            break
         yield messages
+
 
     current_model.to("cpu")
     torch.cuda.empty_cache()
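The commit flips skip_special_tokens to False so the eos token text is visible in the streamed output, which is what the new substring check relies on. Below is a minimal sketch of how the full streaming function might look after this change; the threaded generate call, the streamer= keyword in generation_kwargs, and the output_text initialisation are assumptions about the parts of app.py the hunks do not show, and format_prompt is the app's own helper defined elsewhere in the file.

from threading import Thread

import torch
from transformers import TextIteratorStreamer

def chat_with_model(messages):
    prompt = format_prompt(messages)  # app.py helper, defined outside these hunks
    inputs = current_tokenizer(prompt, return_tensors="pt").to(current_model.device)

    # skip_special_tokens=False keeps the eos token text in the stream,
    # so the substring check below can see it
    streamer = TextIteratorStreamer(current_tokenizer, skip_prompt=True, skip_special_tokens=False)
    generation_kwargs = dict(
        **inputs,
        max_new_tokens=256,
        streamer=streamer,  # assumed: the hunk cuts off before this kwarg
    )
    # assumed: generation runs in a background thread while the loop consumes the streamer
    Thread(target=current_model.generate, kwargs=generation_kwargs).start()

    output_text = ""
    for new_text in streamer:
        output_text += new_text
        messages[-1]["content"] = output_text
        # stop as soon as the model emits its end-of-sequence token
        if current_tokenizer.eos_token and current_tokenizer.eos_token in output_text:
            break
        yield messages

    current_model.to("cpu")
    torch.cuda.empty_cache()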