Ruurd committed on
Commit f69cdc0 · 1 Parent(s): 22564e3

Fix stopping criteria

Files changed (1)
  1. app.py +16 -6
app.py CHANGED
@@ -6,6 +6,16 @@ import spaces
 from transformers import AutoTokenizer, AutoModelForCausalLM, TextIteratorStreamer
 import threading
 
+from transformers import TextIteratorStreamer, StoppingCriteria, StoppingCriteriaList
+import threading
+
+class StopOnEos(StoppingCriteria):
+    def __init__(self, eos_token_id):
+        self.eos_token_id = eos_token_id
+
+    def __call__(self, input_ids, scores, **kwargs):
+        return input_ids[0, -1].item() == self.eos_token_id
+
 @spaces.GPU
 def chat_with_model(messages):
     global current_model, current_tokenizer
@@ -18,15 +28,17 @@ def chat_with_model(messages):
     prompt = format_prompt(messages)
     inputs = current_tokenizer(prompt, return_tensors="pt").to(current_model.device)
 
-    streamer = TextIteratorStreamer(current_tokenizer, skip_prompt=True, skip_special_tokens=False)
+    streamer = TextIteratorStreamer(current_tokenizer, skip_prompt=True, skip_special_tokens=True)
+    stopping_criteria = StoppingCriteriaList([StopOnEos(current_tokenizer.eos_token_id)])
+
     generation_kwargs = dict(
         **inputs,
         max_new_tokens=256,
         do_sample=True,
-        streamer=streamer
+        streamer=streamer,
+        stopping_criteria=stopping_criteria
     )
 
-    # Launch generation in a background thread
     thread = threading.Thread(target=current_model.generate, kwargs=generation_kwargs)
     thread.start()
 
@@ -37,15 +49,13 @@ def chat_with_model(messages):
     for new_text in streamer:
         output_text += new_text
         messages[-1]["content"] = output_text
-        if current_tokenizer.eos_token and current_tokenizer.eos_token in output_text:
-            break
         yield messages
-
 
     current_model.to("cpu")
     torch.cuda.empty_cache()
 
 
+
 # Globals
 current_model = None
 current_tokenizer = None
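
For context, a minimal self-contained sketch of the pattern this commit moves to: generate() runs in a background thread, decoded text streams out through TextIteratorStreamer, and a custom StoppingCriteria halts generation as soon as the last sampled token is the EOS id, instead of scanning the decoded output for the EOS string. The "gpt2" checkpoint and the prompt below are placeholders for illustration, not what this Space actually loads.

```python
import threading

from transformers import (
    AutoModelForCausalLM,
    AutoTokenizer,
    StoppingCriteria,
    StoppingCriteriaList,
    TextIteratorStreamer,
)

class StopOnEos(StoppingCriteria):
    def __init__(self, eos_token_id):
        self.eos_token_id = eos_token_id

    def __call__(self, input_ids, scores, **kwargs):
        # True as soon as the most recently generated token is EOS.
        return input_ids[0, -1].item() == self.eos_token_id

# Placeholder model/tokenizer; the Space uses its own.
tokenizer = AutoTokenizer.from_pretrained("gpt2")
model = AutoModelForCausalLM.from_pretrained("gpt2")

inputs = tokenizer("Hello, world.", return_tensors="pt")
streamer = TextIteratorStreamer(tokenizer, skip_prompt=True, skip_special_tokens=True)

generation_kwargs = dict(
    **inputs,
    max_new_tokens=64,
    do_sample=True,
    streamer=streamer,
    stopping_criteria=StoppingCriteriaList([StopOnEos(tokenizer.eos_token_id)]),
)

# generate() runs in a background thread; the streamer yields decoded text
# on the main thread until the stopping criterion fires.
thread = threading.Thread(target=model.generate, kwargs=generation_kwargs)
thread.start()
for new_text in streamer:
    print(new_text, end="", flush=True)
thread.join()
```

Stopping on the token id is what lets the streamer use skip_special_tokens=True: the removed string check needed the EOS text to survive decoding, while the criterion inspects the raw ids. Note that recent transformers releases already stop at the configured eos_token_id by default, so the explicit criterion mainly guarantees the stop regardless of the model's generation config.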