Ruurd committed
Commit 22564e3 · 1 Parent(s): cb98777

Stop at eos token

Files changed (1)
  1. app.py +4 -1
app.py CHANGED
@@ -18,7 +18,7 @@ def chat_with_model(messages):
     prompt = format_prompt(messages)
     inputs = current_tokenizer(prompt, return_tensors="pt").to(current_model.device)
 
-    streamer = TextIteratorStreamer(current_tokenizer, skip_prompt=True, skip_special_tokens=True)
+    streamer = TextIteratorStreamer(current_tokenizer, skip_prompt=True, skip_special_tokens=False)
     generation_kwargs = dict(
         **inputs,
         max_new_tokens=256,
@@ -37,7 +37,10 @@ def chat_with_model(messages):
     for new_text in streamer:
         output_text += new_text
         messages[-1]["content"] = output_text
+        if current_tokenizer.eos_token and current_tokenizer.eos_token in output_text:
+            break
         yield messages
+
     current_model.to("cpu")
     torch.cuda.empty_cache()
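
For context, a minimal runnable sketch of the streaming pattern this commit modifies, with the EOS check applied. The "gpt2" checkpoint and the bare prompt are placeholders; format_prompt and the surrounding Gradio app from app.py are omitted. The key point is that skip_special_tokens=False keeps the EOS token's text in the decoded stream, so the substring check can see it.

# Minimal sketch, not the full app.py: "gpt2" is a placeholder checkpoint.
from threading import Thread

from transformers import AutoModelForCausalLM, AutoTokenizer, TextIteratorStreamer

tokenizer = AutoTokenizer.from_pretrained("gpt2")
model = AutoModelForCausalLM.from_pretrained("gpt2")

inputs = tokenizer("Q: What is streaming generation?\nA:", return_tensors="pt").to(model.device)

# skip_special_tokens=False keeps the EOS token text in the stream,
# so the loop below can detect it and stop early.
streamer = TextIteratorStreamer(tokenizer, skip_prompt=True, skip_special_tokens=False)
generation_kwargs = dict(**inputs, max_new_tokens=256, streamer=streamer)

# generate() blocks, so it runs in a background thread while the
# main thread consumes the streamer.
Thread(target=model.generate, kwargs=generation_kwargs).start()

output_text = ""
for new_text in streamer:
    output_text += new_text
    if tokenizer.eos_token and tokenizer.eos_token in output_text:
        break  # the model emitted EOS; stop streaming
print(output_text)

Since generate() normally stops on eos_token_id by itself, the substring check mainly ends the yield loop as soon as the EOS text arrives; one might also strip tokenizer.eos_token from output_text before display, since skip_special_tokens=False now leaves it visible.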