aizip-dev committed
Commit c1f1ebf · verified · 1 Parent(s): c508d80

Update utils/models.py

Files changed (1)
  utils/models.py +15 -5
utils/models.py CHANGED
@@ -154,6 +154,9 @@ def run_inference(model_name, context, question):
     if generation_interrupt.is_set():
         return ""
 
+    # Create interrupt criteria for this generation
+    interrupt_criteria = InterruptCriteria(generation_interrupt)
+
     print("REACHED HERE BEFORE pipe")
     print(f"Loading model {model_name}...")
     if "bitnet" in model_name.lower():
@@ -206,7 +209,10 @@ def run_inference(model_name, context, question):
         result = pipe(
             text_input,
             max_new_tokens=512,
-            generation_kwargs={"skip_special_tokens": True},
+            generation_kwargs={
+                "skip_special_tokens": True,
+                "stopping_criteria": [interrupt_criteria]  # ADD INTERRUPT SUPPORT
+            },
         )[0]["generated_text"]
 
         result = result[-1]["content"]
@@ -221,7 +227,6 @@ def run_inference(model_name, context, question):
             **tokenizer_kwargs,
         )
 
-
         model_inputs = model_inputs.to(model.device)
 
         input_ids = model_inputs.input_ids
@@ -239,7 +244,8 @@ def run_inference(model_name, context, question):
            attention_mask=attention_mask,
            max_new_tokens=512,
            eos_token_id=tokenizer.eos_token_id,
-           pad_token_id=tokenizer.pad_token_id  # Addresses the warning
+           pad_token_id=tokenizer.pad_token_id,
+           stopping_criteria=[interrupt_criteria]  # ADD INTERRUPT SUPPORT
        )
 
        generated_token_ids = output_sequences[0][prompt_tokens_length:]
@@ -259,6 +265,7 @@ def run_inference(model_name, context, question):
        # output_sequences = bitnet_model.generate(
        #     **formatted,
        #     max_new_tokens=512,
+       #     stopping_criteria=[interrupt_criteria]  # ADD INTERRUPT SUPPORT
        # )
 
        # result = tokenizer.decode(output_sequences[0][formatted['input_ids'].shape[-1]:], skip_special_tokens=True)
@@ -275,7 +282,10 @@ def run_inference(model_name, context, question):
        outputs = pipe(
            formatted,
            max_new_tokens=512,
-           generation_kwargs={"skip_special_tokens": True},
+           generation_kwargs={
+               "skip_special_tokens": True,
+               "stopping_criteria": [interrupt_criteria]  # ADD INTERRUPT SUPPORT
+           },
        )
        # print(outputs[0]['generated_text'])
        result = outputs[0]["generated_text"][input_length:]
@@ -290,4 +300,4 @@ def run_inference(model_name, context, question):
    if torch.cuda.is_available():
        torch.cuda.empty_cache()
 
-    return result
+    return result
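
Note: the commit constructs InterruptCriteria(generation_interrupt) but does not show the class definition. Below is a minimal sketch of what such a class could look like, assuming generation_interrupt is a threading.Event and InterruptCriteria is a transformers StoppingCriteria subclass; this is an illustration, not the repository's actual implementation.

import threading

import torch
from transformers import StoppingCriteria


class InterruptCriteria(StoppingCriteria):
    """Stop generation as soon as the shared interrupt event is set."""

    def __init__(self, interrupt_event: threading.Event):
        self.interrupt_event = interrupt_event

    def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor, **kwargs) -> bool:
        # Returning True tells generate() / the pipeline to stop producing further tokens.
        return self.interrupt_event.is_set()


# Hypothetical usage: another thread (e.g. a UI "Stop" handler) sets the event,
# and the next stopping-criteria check ends generation early.
# generation_interrupt = threading.Event()
# ... pass stopping_criteria=[InterruptCriteria(generation_interrupt)] to generate() ...
# generation_interrupt.set()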