foxxy-hm commited on
Commit
d94f94a
·
1 Parent(s): 242ba47

Update src/models/qa_model.py

Browse files
Files changed (1) hide show
  1. src/models/qa_model.py +2 -2
src/models/qa_model.py CHANGED
@@ -9,12 +9,12 @@ from src.features.graph_utils import find_best_cluster
9
  class QAEnsembleModel(nn.Module):
10
 
11
  def __init__(self, model_name, model_checkpoints, entity_dict,
12
- thr=0.1, device="cuda:0"):
13
  super(QAEnsembleModel, self).__init__()
14
  self.nlps = []
15
  for model_checkpoint in model_checkpoints:
16
  model = AutoModelForQuestionAnswering.from_pretrained(model_name).half()
17
- model.load_state_dict(torch.load(model_checkpoint), strict=False)
18
  nlp = pipeline('question-answering', model=model,
19
  tokenizer=model_name, device=int(device.split(":")[-1]))
20
  self.nlps.append(nlp)
 
9
  class QAEnsembleModel(nn.Module):
10
 
11
  def __init__(self, model_name, model_checkpoints, entity_dict,
12
+ thr=0.1, device="cpu"):
13
  super(QAEnsembleModel, self).__init__()
14
  self.nlps = []
15
  for model_checkpoint in model_checkpoints:
16
  model = AutoModelForQuestionAnswering.from_pretrained(model_name).half()
17
+ model.load_state_dict(torch.load(model_checkpoint, map_location=torch.device('cpu')), strict=False)
18
  nlp = pipeline('question-answering', model=model,
19
  tokenizer=model_name, device=int(device.split(":")[-1]))
20
  self.nlps.append(nlp)