Metal3d committed on
Commit
48a12f0
·
unverified ·
1 Parent(s): 633edd7

Moving spaces.GPU

Browse files
Files changed (1) hide show
  1. main.py +27 -23
main.py CHANGED
@@ -57,6 +57,29 @@ def reformat_math(text):
57
  return text
58
 
59
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
60
  async def chat(prompt, history):
61
  """Respond to a chat prompt."""
62
  message = {
@@ -64,31 +87,12 @@ async def chat(prompt, history):
64
  "content": prompt,
65
  }
66
 
 
67
  history = [] if history is None else history
 
68
 
69
- @spaces.GPU
70
- def _generate():
71
- text = tokenizer.apply_chat_template(
72
- history + [message],
73
- tokenize=False,
74
- add_generation_prompt=True,
75
- )
76
-
77
- model_inputs = tokenizer([text], return_tensors="pt").to(model.device)
78
- streamer = AsyncTextIteratorStreamer(tokenizer, skip_special_tokens=True)
79
-
80
- task = asyncio.get_running_loop().run_in_executor(
81
- None,
82
- functools.partial(
83
- model.generate,
84
- max_new_tokens=1024 * 128,
85
- streamer=streamer,
86
- **model_inputs,
87
- ),
88
- )
89
- return task, streamer
90
-
91
- task, streamer = _generate()
92
 
93
  buffer = ""
94
  reasoning = ""
 
57
  return text
58
 
59
 
60
@spaces.GPU
def _generate(history):
    """Start GPU text generation for a chat *history* and return its stream.

    Renders *history* (a list of chat message dicts) through the model's
    chat template, then launches ``model.generate`` in the default
    thread-pool executor so the running asyncio event loop stays free to
    consume tokens as they arrive.

    Returns:
        tuple: ``(task, streamer)`` — the executor future wrapping the
        blocking ``generate`` call, and the ``AsyncTextIteratorStreamer``
        to iterate tokens from.
    """
    # Build the raw prompt text from the templated message list.
    prompt_text = tokenizer.apply_chat_template(
        history,
        tokenize=False,
        add_generation_prompt=True,
    )

    encoded = tokenizer([prompt_text], return_tensors="pt").to(model.device)
    token_stream = AsyncTextIteratorStreamer(tokenizer, skip_special_tokens=True)

    # generate() blocks; run it off the event-loop thread so streaming
    # consumers are not starved while the model produces tokens.
    loop = asyncio.get_running_loop()
    generation_task = loop.run_in_executor(
        None,
        functools.partial(
            model.generate,
            max_new_tokens=1024 * 128,
            streamer=token_stream,
            **encoded,
        ),
    )
    return generation_task, token_stream
81
+
82
+
83
  async def chat(prompt, history):
84
  """Respond to a chat prompt."""
85
  message = {
 
87
  "content": prompt,
88
  }
89
 
90
+ # build the messages list
91
  history = [] if history is None else history
92
+ message_list = history + [message]
93
 
94
+ # get the task and the streamer
95
+ task, streamer = _generate(message_list)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
96
 
97
  buffer = ""
98
  reasoning = ""