suryadev1 commited on
Commit
168746a
·
verified ·
1 Parent(s): 5a8d2be
Files changed (1) hide show
  1. new_test_saved_finetuned_model.py +3 -0
new_test_saved_finetuned_model.py CHANGED
@@ -162,6 +162,9 @@ class BERTFineTuneTrainer:
162
  logits = self.model.forward(data["input"], data["segment_label"], data["feat"])
163
 
164
  logits = logits.cpu()
 
 
 
165
  loss = self.criterion(logits, data["label"])
166
  # if torch.cuda.device_count() > 1:
167
  # loss = loss.mean()
 
162
  logits = self.model.forward(data["input"], data["segment_label"], data["feat"])
163
 
164
  logits = logits.cpu()
165
+ devic = logits.device # or self.model.device if available
166
+ labels = data["label"].to(devic)
167
+
168
  loss = self.criterion(logits, data["label"])
169
  # if torch.cuda.device_count() > 1:
170
  # loss = loss.mean()