Spaces: Running on Zero
Update utils/models.py
utils/models.py  CHANGED  (+15 -5)
@@ -154,6 +154,9 @@ def run_inference(model_name, context, question):
     if generation_interrupt.is_set():
         return ""
 
+    # Create interrupt criteria for this generation
+    interrupt_criteria = InterruptCriteria(generation_interrupt)
+
     print("REACHED HERE BEFORE pipe")
     print(f"Loading model {model_name}...")
     if "bitnet" in model_name.lower():
@@ -206,7 +209,10 @@ def run_inference(model_name, context, question):
         result = pipe(
             text_input,
             max_new_tokens=512,
-            generation_kwargs={
+            generation_kwargs={
+                "skip_special_tokens": True,
+                "stopping_criteria": [interrupt_criteria]  # ADD INTERRUPT SUPPORT
+            },
         )[0]["generated_text"]
 
         result = result[-1]["content"]
@@ -221,7 +227,6 @@ def run_inference(model_name, context, question):
             **tokenizer_kwargs,
         )
 
-
         model_inputs = model_inputs.to(model.device)
 
         input_ids = model_inputs.input_ids
@@ -239,7 +244,8 @@ def run_inference(model_name, context, question):
             attention_mask=attention_mask,
             max_new_tokens=512,
             eos_token_id=tokenizer.eos_token_id,
-            pad_token_id=tokenizer.pad_token_id
+            pad_token_id=tokenizer.pad_token_id,
+            stopping_criteria=[interrupt_criteria]  # ADD INTERRUPT SUPPORT
         )
 
         generated_token_ids = output_sequences[0][prompt_tokens_length:]
@@ -259,6 +265,7 @@ def run_inference(model_name, context, question):
         # output_sequences = bitnet_model.generate(
         #     **formatted,
        #     max_new_tokens=512,
+        #     stopping_criteria=[interrupt_criteria]  # ADD INTERRUPT SUPPORT
         # )
 
         # result = tokenizer.decode(output_sequences[0][formatted['input_ids'].shape[-1]:], skip_special_tokens=True)
@@ -275,7 +282,10 @@ def run_inference(model_name, context, question):
         outputs = pipe(
             formatted,
             max_new_tokens=512,
-            generation_kwargs={
+            generation_kwargs={
+                "skip_special_tokens": True,
+                "stopping_criteria": [interrupt_criteria]  # ADD INTERRUPT SUPPORT
+            },
         )
         # print(outputs[0]['generated_text'])
         result = outputs[0]["generated_text"][input_length:]
@@ -290,4 +300,4 @@ def run_inference(model_name, context, question):
     if torch.cuda.is_available():
         torch.cuda.empty_cache()
 
-    return result
+    return result
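The change passes an InterruptCriteria instance into both the pipeline calls (via generation_kwargs) and the direct model.generate() call, but the class itself is not defined in this diff; it is presumably defined or imported elsewhere in utils/models.py. A minimal sketch of what such a class could look like, assuming generation_interrupt is a threading.Event shared with the UI thread (the class name and event come from the diff; the body below is an assumption, not the repository's actual implementation):

# Sketch only: a transformers StoppingCriteria that halts generation
# once a shared threading.Event has been set.
import threading
from transformers import StoppingCriteria

class InterruptCriteria(StoppingCriteria):
    def __init__(self, interrupt_event: threading.Event):
        self.interrupt_event = interrupt_event

    def __call__(self, input_ids, scores, **kwargs) -> bool:
        # Returning True tells generate() to stop after the current step.
        return self.interrupt_event.is_set()

With something like this in place, another thread (for example a "stop" button handler in the Space's UI) can call generation_interrupt.set(), and the next stopping-criteria check ends the generate() or pipeline call early instead of running to max_new_tokens.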