Fix stopping criteria
app.py CHANGED
@@ -6,6 +6,16 @@ import spaces
 from transformers import AutoTokenizer, AutoModelForCausalLM, TextIteratorStreamer
 import threading
 
+from transformers import TextIteratorStreamer, StoppingCriteria, StoppingCriteriaList
+import threading
+
+class StopOnEos(StoppingCriteria):
+    def __init__(self, eos_token_id):
+        self.eos_token_id = eos_token_id
+
+    def __call__(self, input_ids, scores, **kwargs):
+        return input_ids[0, -1].item() == self.eos_token_id
+
 @spaces.GPU
 def chat_with_model(messages):
     global current_model, current_tokenizer
@@ -18,15 +28,17 @@ def chat_with_model(messages):
     prompt = format_prompt(messages)
     inputs = current_tokenizer(prompt, return_tensors="pt").to(current_model.device)
 
-    streamer = TextIteratorStreamer(current_tokenizer, skip_prompt=True, skip_special_tokens=True)
+    streamer = TextIteratorStreamer(current_tokenizer, skip_prompt=True, skip_special_tokens=True)
+    stopping_criteria = StoppingCriteriaList([StopOnEos(current_tokenizer.eos_token_id)])
+
     generation_kwargs = dict(
         **inputs,
         max_new_tokens=256,
         do_sample=True,
-        streamer=streamer
+        streamer=streamer,
+        stopping_criteria=stopping_criteria
     )
 
-    # Launch generation in a background thread
     thread = threading.Thread(target=current_model.generate, kwargs=generation_kwargs)
     thread.start()
 
@@ -37,15 +49,13 @@ def chat_with_model(messages):
     for new_text in streamer:
         output_text += new_text
         messages[-1]["content"] = output_text
-        if current_tokenizer.eos_token and current_tokenizer.eos_token in output_text:
-            break
         yield messages
-
 
     current_model.to("cpu")
     torch.cuda.empty_cache()
 
 
+
 # Globals
 current_model = None
 current_tokenizer = None