yifan0sun committed
Commit d63cce4 · 1 Parent(s): 5e5b29d
Files changed (3):
  1. BERTmodel.py +7 -4
  2. DISTILLBERTmodel.py +5 -4
  3. ROBERTAmodel.py +5 -4
BERTmodel.py CHANGED
@@ -11,23 +11,26 @@ from transformers import (
 )
 import torch.nn.functional as F

+
+CACHE_DIR = "./hf_cache"


 class BERTVisualizer(TransformerVisualizer):
     def __init__(self,task):
         super().__init__()
         self.task = task
-        self.tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')
+        self.tokenizer = BertTokenizer.from_pretrained('bert-base-uncased', cache_dir=CACHE_DIR)
         print('finding model', self.task)
         if self.task == 'mlm':
             self.model = BertForMaskedLM.from_pretrained(
                 "bert-base-uncased",
-                attn_implementation="eager" # fallback to standard attention
+                attn_implementation="eager", # fallback to standard attention
+                cache_dir=CACHE_DIR
             ).to(self.device)
         elif self.task == 'sst':
-            self.model = BertForSequenceClassification.from_pretrained("textattack/bert-base-uncased-SST-2",device_map=None)
+            self.model = BertForSequenceClassification.from_pretrained("textattack/bert-base-uncased-SST-2",device_map=None, cache_dir=CACHE_DIR)
         elif self.task == 'mnli':
-            self.model = BertForSequenceClassification.from_pretrained("textattack/bert-base-uncased-MNLI", device_map=None)
+            self.model = BertForSequenceClassification.from_pretrained("textattack/bert-base-uncased-MNLI", device_map=None, cache_dir=CACHE_DIR)
         else:
             raise ValueError(f"Unsupported task: {self.task}")
         print('model found')
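Note: threading cache_dir through every from_pretrained call is explicit but repetitive. A process-wide alternative is to point the Hugging Face home at the same directory via an environment variable. A minimal sketch, assuming no other code relies on the default cache location and that the variable is set before transformers is imported (hub downloads then land under ./hf_cache/hub):

import os
os.environ["HF_HOME"] = "./hf_cache"  # must run before importing transformers

from transformers import BertTokenizer

# No cache_dir argument needed now; the default cache resolves under ./hf_cache
tok = BertTokenizer.from_pretrained("bert-base-uncased")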
DISTILLBERTmodel.py CHANGED
@@ -12,17 +12,18 @@ from transformers import (
     DistilBertForMaskedLM, DistilBertForSequenceClassification
 )

+CACHE_DIR = "./hf_cache"
 class DistilBERTVisualizer(TransformerVisualizer):
     def __init__(self, task):
         super().__init__()
         self.task = task
-        self.tokenizer = DistilBertTokenizer.from_pretrained('distilbert-base-uncased')
+        self.tokenizer = DistilBertTokenizer.from_pretrained('distilbert-base-uncased', cache_dir=CACHE_DIR)
         if self.task == 'mlm':
-            self.model = DistilBertForMaskedLM.from_pretrained('distilbert-base-uncased')
+            self.model = DistilBertForMaskedLM.from_pretrained('distilbert-base-uncased', cache_dir=CACHE_DIR)
         elif self.task == 'sst':
-            self.model = DistilBertForSequenceClassification.from_pretrained('distilbert-base-uncased-finetuned-sst-2-english')
+            self.model = DistilBertForSequenceClassification.from_pretrained('distilbert-base-uncased-finetuned-sst-2-english', cache_dir=CACHE_DIR)
         elif self.task == 'mnli':
-            self.model = DistilBertForSequenceClassification.from_pretrained("textattack/distilbert-base-uncased-MNLI")
+            self.model = DistilBertForSequenceClassification.from_pretrained("textattack/distilbert-base-uncased-MNLI", cache_dir=CACHE_DIR)


         else:
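With a local cache_dir in place, the first request for each task still pays the download cost. A hypothetical build-time warm-up script (the file name and its placement are assumptions, not part of this commit) could pre-populate ./hf_cache using the same checkpoint ids the diffs reference:

# warm_cache.py — hypothetical helper; runs each download once at build time
from transformers import (
    AutoTokenizer, AutoModelForMaskedLM, AutoModelForSequenceClassification
)

CACHE_DIR = "./hf_cache"

# Base checkpoints used for the 'mlm' task
for name in ["bert-base-uncased", "distilbert-base-uncased", "roberta-base"]:
    AutoTokenizer.from_pretrained(name, cache_dir=CACHE_DIR)
    AutoModelForMaskedLM.from_pretrained(name, cache_dir=CACHE_DIR)

# Fine-tuned checkpoints used for the 'sst' and 'mnli' tasks
for name in [
    "textattack/bert-base-uncased-SST-2",
    "textattack/bert-base-uncased-MNLI",
    "distilbert-base-uncased-finetuned-sst-2-english",
    "textattack/distilbert-base-uncased-MNLI",
    "textattack/roberta-base-SST-2",
    "roberta-large-mnli",
]:
    AutoTokenizer.from_pretrained(name, cache_dir=CACHE_DIR)
    AutoModelForSequenceClassification.from_pretrained(name, cache_dir=CACHE_DIR)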
ROBERTAmodel.py CHANGED
@@ -6,17 +6,18 @@ from transformers import (
     RobertaForMaskedLM, RobertaForSequenceClassification
 )

+CACHE_DIR = "./hf_cache"
 class RoBERTaVisualizer(TransformerVisualizer):
     def __init__(self, task):
         super().__init__()
         self.task = task
-        self.tokenizer = RobertaTokenizer.from_pretrained('roberta-base')
+        self.tokenizer = RobertaTokenizer.from_pretrained('roberta-base', cache_dir=CACHE_DIR)
         if self.task == 'mlm':
-            self.model = RobertaForMaskedLM.from_pretrained("roberta-base")
+            self.model = RobertaForMaskedLM.from_pretrained("roberta-base", cache_dir=CACHE_DIR)
         elif self.task == 'sst':
-            self.model = RobertaForSequenceClassification.from_pretrained('textattack/roberta-base-SST-2')
+            self.model = RobertaForSequenceClassification.from_pretrained('textattack/roberta-base-SST-2', cache_dir=CACHE_DIR)
         elif self.task == 'mnli':
-            self.model = RobertaForSequenceClassification.from_pretrained("roberta-large-mnli")
+            self.model = RobertaForSequenceClassification.from_pretrained("roberta-large-mnli", cache_dir=CACHE_DIR)


         self.model.to(self.device)
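After this commit, CACHE_DIR = "./hf_cache" is duplicated across all three modules. A shared constant would keep them in sync if the path ever changes; a sketch with a hypothetical config.py (the module name and the env-var override are assumptions):

# config.py — hypothetical shared module for the cache path
import os
CACHE_DIR = os.environ.get("HF_CACHE_DIR", "./hf_cache")

# In BERTmodel.py / DISTILLBERTmodel.py / ROBERTAmodel.py, replace the
# local constant with:
# from config import CACHE_DIR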