rcook committed on
Commit
5483dd7
·
verified ·
1 Parent(s): a5adb63

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +13 -0
app.py CHANGED
@@ -38,6 +38,19 @@ def summarize():
38
  data_collator = DataCollatorForSeq2Seq(tokenizer=tokenizer, model=checkpoint)
39
 
40
  rouge = evaluate.load("rouge")
 
 
 
 
 
 
 
 
 
 
 
 
 
41
 
42
  return data_collator
43
  # return type(tokenized_billsum)
 
38
  data_collator = DataCollatorForSeq2Seq(tokenizer=tokenizer, model=checkpoint)
39
 
40
  rouge = evaluate.load("rouge")
41
+
42
+ def compute_metrics(eval_pred):
43
+ predictions, labels = eval_pred
44
+ decoded_preds = tokenizer.batch_decode(predictions, skip_special_tokens=True)
45
+ labels = np.where(labels != -100, labels, tokenizer.pad_token_id)
46
+ decoded_labels = tokenizer.batch_decode(labels, skip_special_tokens=True)
47
+
48
+ result = rouge.compute(predictions=decoded_preds, references=decoded_labels, use_stemmer=True)
49
+
50
+ prediction_lens = [np.count_nonzero(pred != tokenizer.pad_token_id) for pred in predictions]
51
+ result["gen_len"] = np.mean(prediction_lens)
52
+
53
+ return {k: round(v, 4) for k, v in result.items()}
54
 
55
  return data_collator
56
  # return type(tokenized_billsum)