Metal3d committed on
Commit 2c7beb2 · unverified · 1 Parent(s): fb5a7c9

Changing the loop methodology

Files changed (1)
  1. main.py +13 -7
main.py CHANGED
@@ -4,7 +4,7 @@ import re
 
 import gradio as gr
 import spaces
-from transformers import AsyncTextIteratorStreamer, AutoModelForCausalLM, AutoTokenizer
+from transformers import AutoModelForCausalLM, AutoTokenizer, TextIteratorStreamer
 
 JS = """
 () => {
@@ -46,6 +46,12 @@ print(model.config)
 tokenizer = AutoTokenizer.from_pretrained(model_name)
 
 
+async def stream(streamer):
+    for txt in streamer:
+        await asyncio.sleep(0.01)
+        yield txt
+
+
 def reformat_math(text):
     """Fix MathJax delimiters to use the Gradio syntax.
 
@@ -58,7 +64,7 @@ def reformat_math(text):
 
 
 @spaces.GPU
-def _generate(history):
+def generate(history):
     text = tokenizer.apply_chat_template(
         history,
         tokenize=False,
@@ -72,7 +78,7 @@ def _generate(history):
     asyncio.set_event_loop(loop)
 
     model_inputs = tokenizer([text], return_tensors="pt").to(model.device)
-    streamer = AsyncTextIteratorStreamer(tokenizer, skip_special_tokens=True)
+    streamer = TextIteratorStreamer(tokenizer, skip_special_tokens=True)
 
     task = loop.run_in_executor(
         None,
@@ -97,15 +103,15 @@ async def chat(prompt, history):
     history = [] if history is None else history
     message_list = history + [message]
 
-    task, streamer = _generate(message_list)
+    task, streamer = generate(message_list)
 
     buffer = ""
     reasoning = ""
     thinking = False
 
     try:
-        async for new_text in streamer:
-            if task.done() or task.cancelled():
+        async for new_text in stream(streamer):
+            if task.cancelled():
                 print("Cancelled")
                 break  # Stop streaming if the task is cancelled
 
@@ -127,7 +133,7 @@ async def chat(prompt, history):
 
     except asyncio.CancelledError:
         # this doesn't work, I can't find a way to stop the generation thread
-        print("Cancelled")
+        print("Cancelled by exception")
         streamer.on_finalized_text("cancelled", True)
         print("Signal sent")
         raise
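
For context, a minimal self-contained sketch of the streaming pattern this commit switches to: model.generate() runs in a background thread, and the synchronous TextIteratorStreamer is consumed through a small async generator that yields to the event loop between chunks. The model name, max_new_tokens, and the plain threading.Thread (used here instead of the Space's loop.run_in_executor) are illustrative assumptions, not part of the commit.

import asyncio
from threading import Thread

from transformers import AutoModelForCausalLM, AutoTokenizer, TextIteratorStreamer

model_name = "Qwen/Qwen2.5-0.5B-Instruct"  # assumption: any small chat model works here
tokenizer = AutoTokenizer.from_pretrained(model_name)
model = AutoModelForCausalLM.from_pretrained(model_name)


async def stream(streamer):
    # TextIteratorStreamer is a plain blocking iterator; sleeping between
    # chunks hands control back to the event loop so other coroutines
    # (e.g. the Gradio UI) keep running.
    for txt in streamer:
        await asyncio.sleep(0.01)
        yield txt


async def main():
    messages = [{"role": "user", "content": "Hello!"}]
    text = tokenizer.apply_chat_template(
        messages, tokenize=False, add_generation_prompt=True
    )
    inputs = tokenizer([text], return_tensors="pt").to(model.device)
    streamer = TextIteratorStreamer(tokenizer, skip_special_tokens=True)

    # Generation runs in a worker thread; it feeds the streamer consumed above.
    thread = Thread(
        target=model.generate,
        kwargs=dict(**inputs, streamer=streamer, max_new_tokens=128),
    )
    thread.start()

    async for chunk in stream(streamer):
        print(chunk, end="", flush=True)

    thread.join()


if __name__ == "__main__":
    asyncio.run(main())

Compared with the previous AsyncTextIteratorStreamer approach, the only asynchronous piece left is the thin stream() wrapper, so the generation side stays purely synchronous while the chat handler still gets an async iterator to consume.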