Metal3d committed on
Commit
bc76ed6
·
unverified ·
1 Parent(s): 2c7beb2

Use thread, all in chat function

Browse files
Files changed (1) hide show
  1. main.py +33 -64
main.py CHANGED
@@ -1,6 +1,5 @@
1
- import asyncio
2
- import functools
3
  import re
 
4
 
5
  import gradio as gr
6
  import spaces
@@ -46,12 +45,6 @@ print(model.config)
46
  tokenizer = AutoTokenizer.from_pretrained(model_name)
47
 
48
 
49
async def stream(streamer):
    """Asynchronously relay text chunks from a blocking streamer.

    A short sleep before each chunk hands control back to the event loop
    so other coroutines can make progress during generation.
    """
    iterator = iter(streamer)
    while True:
        try:
            chunk = next(iterator)
        except StopIteration:
            return
        await asyncio.sleep(0.01)
        yield chunk
55
  def reformat_math(text):
56
  """Fix MathJax delimiters to use the Gradio syntax.
57
 
@@ -64,79 +57,55 @@ def reformat_math(text):
64
 
65
 
66
@spaces.GPU
def generate(history):
    """Kick off model generation for a chat history in a background executor.

    Args:
        history: list of chat message dicts ({"role": ..., "content": ...})
            fed to the tokenizer's chat template.

    Returns:
        (task, streamer): the Future running ``model.generate`` in the
        default executor, and the TextIteratorStreamer that yields the
        decoded text chunks as they are produced.
    """
    text = tokenizer.apply_chat_template(
        history,
        tokenize=False,
        add_generation_prompt=True,
    )

    # get_event_loop() raises RuntimeError when the current thread has no
    # running loop; catch only that — a bare `except:` would also swallow
    # KeyboardInterrupt/SystemExit.
    try:
        loop = asyncio.get_event_loop()
    except RuntimeError:
        loop = asyncio.new_event_loop()
        asyncio.set_event_loop(loop)

    model_inputs = tokenizer([text], return_tensors="pt").to(model.device)
    streamer = TextIteratorStreamer(tokenizer, skip_special_tokens=True)

    # model.generate blocks, so run it in the executor; the streamer is
    # consumed concurrently by the caller.
    task = loop.run_in_executor(
        None,
        functools.partial(
            model.generate,
            max_new_tokens=1024 * 128,
            streamer=streamer,
            **model_inputs,
        ),
    )
    return task, streamer
93
-
94
-
95
async def chat(prompt, history):
    """Respond to a chat prompt.

    Yields ``(answer, reasoning)`` tuples for the UI: while the model is
    inside a <think>...</think> span, a placeholder answer plus the growing
    reasoning text; afterwards the accumulated answer text.
    """
    message = {
        "role": "user",
        "content": prompt,
    }

    # build the messages list
    history = [] if history is None else history
    message_list = history + [message]

    task, streamer = generate(message_list)

    buffer = ""       # visible answer text accumulated so far
    reasoning = ""    # chain-of-thought text accumulated so far
    thinking = False  # True while between <think> and </think> markers

    try:
        async for new_text in stream(streamer):
            if task.cancelled():
                print("Cancelled")
                break  # stop streaming if the task was cancelled

            # the marker chunks themselves are swallowed, never displayed
            if not thinking and "<think>" in new_text:
                thinking = True
                continue
            if thinking and "</think>" in new_text:
                thinking = False
                continue

            if thinking:
                reasoning += new_text
                heading = "# Reasoning\n\n"
                yield "I'm thinking, please wait a moment...", heading + reasoning
                continue

            buffer += new_text
            yield reformat_math(buffer), reasoning

    except asyncio.CancelledError:
        # this doesn't work, I don't find a way to stop generation thread
        print("Cancelled by exception")
        streamer.on_finalized_text("cancelled", True)
        print("Signal sent")
        raise
140
 
141
 
142
  chat_bot = gr.Chatbot(
 
 
 
1
  import re
2
+ import threading
3
 
4
  import gradio as gr
5
  import spaces
 
45
  tokenizer = AutoTokenizer.from_pretrained(model_name)
46
 
47
 
 
 
 
 
 
 
48
  def reformat_math(text):
49
  """Fix MathJax delimiters to use the Gradio syntax.
50
 
 
57
 
58
 
59
@spaces.GPU
def chat(prompt, history):
    """Respond to a chat prompt, streaming ``(answer, reasoning)`` pairs.

    Args:
        prompt: the user's new message text.
        history: prior chat messages as a list of
            ``{"role": ..., "content": ...}`` dicts, or None for a fresh
            conversation.

    Yields:
        ``(answer, reasoning)`` tuples for Gradio: while the model is inside
        a <think>...</think> span, a placeholder answer plus the growing
        reasoning text; afterwards the accumulated answer text.
    """
    message = {
        "role": "user",
        "content": prompt,
    }

    # build the messages list
    history = [] if history is None else history
    message_list = history + [message]

    text = tokenizer.apply_chat_template(
        message_list,
        tokenize=False,
        add_generation_prompt=True,
    )

    model_inputs = tokenizer([text], return_tensors="pt").to(model.device)
    # NOTE(review): without skip_prompt=True the streamer also re-emits the
    # prompt text before the reply — confirm this is intended.
    streamer = TextIteratorStreamer(tokenizer, skip_special_tokens=True)

    # daemon=True so a still-running generation thread cannot keep the
    # interpreter alive if the app is shut down mid-generation.
    threading.Thread(
        target=model.generate,
        kwargs=dict(
            max_new_tokens=1024 * 128,
            streamer=streamer,
            **model_inputs,
        ),
        daemon=True,
    ).start()

    buffer = ""       # visible answer text accumulated so far
    reasoning = ""    # chain-of-thought text accumulated so far
    thinking = False  # True while between <think> and </think> markers

    # iterating the streamer blocks until the generation thread produces
    # each decoded chunk
    for new_text in streamer:
        # the marker chunks themselves are swallowed, never displayed
        if not thinking and "<think>" in new_text:
            thinking = True
            continue
        if thinking and "</think>" in new_text:
            thinking = False
            continue

        if thinking:
            reasoning += new_text
            heading = "# Reasoning\n\n"
            yield "I'm thinking, please wait a moment...", heading + reasoning
            continue

        buffer += new_text
        yield reformat_math(buffer), reasoning
 
 
 
 
 
 
 
 
 
 
 
 
109
 
110
 
111
  chat_bot = gr.Chatbot(