suryadev1 commited on
Commit
3d3c5a2
·
verified ·
1 Parent(s): 61a85a4
Files changed (1) hide show
  1. new_test_saved_finetuned_model.py +2 -1
new_test_saved_finetuned_model.py CHANGED
@@ -183,7 +183,8 @@ class BERTFineTuneTrainer:
183
  tlabels.extend(data['label'].cpu().numpy())
184
  positive_class_probs = [prob[1] for prob in probabs]
185
  # Compare predicted labels to true labels and calculate accuracy
186
- correct = (data['label'] == predicted_labels).sum().item()
 
187
 
188
  avg_loss += loss.item()
189
  total_correct += correct
 
183
  tlabels.extend(data['label'].cpu().numpy())
184
  positive_class_probs = [prob[1] for prob in probabs]
185
  # Compare predicted labels to true labels and calculate accuracy
186
+ correct = (data['label'].to(predicted_labels.device) == predicted_labels).sum().item()
187
+
188
 
189
  avg_loss += loss.item()
190
  total_correct += correct