Try to fix
app.py CHANGED

@@ -12,6 +12,7 @@ import threading
 from transformers import TextIteratorStreamer
 import queue
 
+@spaces.GPU
 class RichTextStreamer(TextIteratorStreamer):
     def __init__(self, tokenizer, **kwargs):
         super().__init__(tokenizer, **kwargs)
@@ -74,7 +75,7 @@ def chat_with_model(messages):
         pad_token_id=pad_id
     )
 
-
+    thread = threading.Thread(target=current_model.generate, kwargs=generation_kwargs)
     thread.start()
 
     output_text = ""
@@ -87,11 +88,10 @@ def chat_with_model(messages):
         token_id = token_info["token_id"]
         is_special = token_info["is_special"]
 
-        # Skip appending the EOS token to output
         if token_id == current_tokenizer.eos_token_id:
+            streamer.end_of_generation.set()  # signal to stop generation thread
             break
 
-        # Detect reasoning block
         if "<think>" in token_str:
             in_think = True
             token_str = token_str.replace("<think>", "")
@@ -112,6 +112,9 @@ def chat_with_model(messages):
         messages[-1]["content"] = output_text
         yield messages
 
+    # Ensure generation thread stops
+    thread.join(timeout=1.0)
+
     current_model.to("cpu")
     torch.cuda.empty_cache()
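For context, a minimal sketch of the producer/consumer pattern this commit relies on. The diff shows only RichTextStreamer's __init__ and the later use of an end_of_generation event, so the GenerationStopped exception, the put() override, and the run_generation wrapper below are assumptions about how the event might actually stop the producer thread, not the app's real code (the real streamer also yields per-token dicts with token_id and is_special, which this sketch omits):

import threading
from transformers import TextIteratorStreamer

class GenerationStopped(Exception):
    """Raised inside generate() to unwind the producer thread early."""

class RichTextStreamer(TextIteratorStreamer):
    # Sketch only: the diff shows just __init__ and the event's use,
    # so put() below is an assumption about how the event is consumed.
    def __init__(self, tokenizer, **kwargs):
        super().__init__(tokenizer, **kwargs)
        self.end_of_generation = threading.Event()

    def put(self, value):
        # generate() calls put() once per decoding step; raising here
        # aborts generation as soon as the consumer sets the event.
        if self.end_of_generation.is_set():
            raise GenerationStopped
        super().put(value)

def run_generation(model, generation_kwargs):
    # Wrapper target so an intentional abort does not dump a traceback.
    try:
        model.generate(**generation_kwargs)
    except GenerationStopped:
        pass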
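Hypothetical usage mirroring the commit's start/signal/join sequence; inputs, the "</answer>" stop marker, and max_new_tokens are placeholders:

streamer = RichTextStreamer(current_tokenizer, skip_prompt=True)
generation_kwargs = dict(inputs, streamer=streamer, max_new_tokens=512)
thread = threading.Thread(target=run_generation, args=(current_model, generation_kwargs))
thread.start()

output_text = ""
for text in streamer:          # stock iteration yields decoded text chunks
    output_text += text
    if "</answer>" in output_text:   # hypothetical early-stop condition
        break

streamer.end_of_generation.set()     # mirrors the commit's EOS handling
thread.join(timeout=1.0)             # bounded wait, as in the commit

Note the division of labor: setting the Event only carries the signal; it is the raise inside put() that actually stops generate(), and join(timeout=1.0) keeps the consumer from hanging if the producer thread is slow to unwind.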