Ruurd committed
Commit 2ed78b7
1 Parent(s): 55c85d9

Try to fix

Files changed (1): app.py (+6, -3)
app.py CHANGED
@@ -12,6 +12,7 @@ import threading
 from transformers import TextIteratorStreamer
 import queue
 
+@spaces.GPU
 class RichTextStreamer(TextIteratorStreamer):
     def __init__(self, tokenizer, **kwargs):
         super().__init__(tokenizer, **kwargs)
@@ -74,7 +75,7 @@ def chat_with_model(messages):
         pad_token_id=pad_id
     )
 
-    thread = threading.Thread(target=current_model.generate, kwargs=generation_kwargs)
+    thread = threading.Thread(target=current_model.generate, kwargs=generation_kwargs)
     thread.start()
 
     output_text = ""
@@ -87,11 +88,10 @@ def chat_with_model(messages):
         token_id = token_info["token_id"]
         is_special = token_info["is_special"]
 
-        # Skip appending the EOS token to output
         if token_id == current_tokenizer.eos_token_id:
+            streamer.end_of_generation.set()  # signal to stop generation thread
             break
 
-        # Detect reasoning block
         if "<think>" in token_str:
             in_think = True
             token_str = token_str.replace("<think>", "")
@@ -112,6 +112,9 @@ def chat_with_model(messages):
         messages[-1]["content"] = output_text
         yield messages
 
+    # Ensure generation thread stops
+    thread.join(timeout=1.0)
+
     current_model.to("cpu")
     torch.cuda.empty_cache()
 
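Note: TextIteratorStreamer itself does not define an end_of_generation attribute, so the streamer.end_of_generation.set() call added above relies on RichTextStreamer providing one. Below is a minimal sketch of one way that could be wired, assuming the event is checked through a custom StoppingCriteria; StopOnEvent and the commented wiring are hypothetical, since the rest of the class body is not part of this diff.

    import threading

    import torch
    from transformers import StoppingCriteria, StoppingCriteriaList, TextIteratorStreamer


    class RichTextStreamer(TextIteratorStreamer):
        def __init__(self, tokenizer, **kwargs):
            super().__init__(tokenizer, **kwargs)
            # Set by the consumer loop once it sees EOS; the generation
            # thread observes it through StopOnEvent below.
            self.end_of_generation = threading.Event()


    class StopOnEvent(StoppingCriteria):
        """Returns True (stop generating) once the shared event is set."""

        def __init__(self, event: threading.Event):
            self.event = event

        def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor, **kwargs) -> bool:
            return self.event.is_set()


    # Hypothetical wiring inside chat_with_model, before the thread starts:
    # generation_kwargs["stopping_criteria"] = StoppingCriteriaList(
    #     [StopOnEvent(streamer.end_of_generation)]
    # )

With wiring along these lines, setting the event from the consumer loop makes generate() stop at the next decoding step, so the thread.join(timeout=1.0) added at the end of the diff should return promptly rather than time out.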