Spaces:
Sleeping
Sleeping
dcachedir
Browse files
- BERTmodel.py +7 -4
- DISTILLBERTmodel.py +5 -4
- ROBERTAmodel.py +5 -4
BERTmodel.py
CHANGED
@@ -11,23 +11,26 @@ from transformers import (
|
|
11 |
)
|
12 |
import torch.nn.functional as F
|
13 |
|
|
|
|
|
14 |
|
15 |
|
16 |
class BERTVisualizer(TransformerVisualizer):
|
17 |
def __init__(self,task):
|
18 |
super().__init__()
|
19 |
self.task = task
|
20 |
-
self.tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')
|
21 |
print('finding model', self.task)
|
22 |
if self.task == 'mlm':
|
23 |
self.model = BertForMaskedLM.from_pretrained(
|
24 |
"bert-base-uncased",
|
25 |
-
attn_implementation="eager" # fallback to standard attention
|
|
|
26 |
).to(self.device)
|
27 |
elif self.task == 'sst':
|
28 |
-
self.model = BertForSequenceClassification.from_pretrained("textattack/bert-base-uncased-SST-2",device_map=None)
|
29 |
elif self.task == 'mnli':
|
30 |
-
self.model = BertForSequenceClassification.from_pretrained("textattack/bert-base-uncased-MNLI", device_map=None)
|
31 |
else:
|
32 |
raise ValueError(f"Unsupported task: {self.task}")
|
33 |
print('model found')
|
|
|
11 |
)
|
12 |
import torch.nn.functional as F
|
13 |
|
14 |
+
|
15 |
+
CACHE_DIR = "./hf_cache"
|
16 |
|
17 |
|
18 |
class BERTVisualizer(TransformerVisualizer):
|
19 |
def __init__(self,task):
|
20 |
super().__init__()
|
21 |
self.task = task
|
22 |
+
self.tokenizer = BertTokenizer.from_pretrained('bert-base-uncased', cache_dir=CACHE_DIR)
|
23 |
print('finding model', self.task)
|
24 |
if self.task == 'mlm':
|
25 |
self.model = BertForMaskedLM.from_pretrained(
|
26 |
"bert-base-uncased",
|
27 |
+
attn_implementation="eager", # fallback to standard attention
|
28 |
+
cache_dir=CACHE_DIR
|
29 |
).to(self.device)
|
30 |
elif self.task == 'sst':
|
31 |
+
self.model = BertForSequenceClassification.from_pretrained("textattack/bert-base-uncased-SST-2",device_map=None, cache_dir=CACHE_DIR)
|
32 |
elif self.task == 'mnli':
|
33 |
+
self.model = BertForSequenceClassification.from_pretrained("textattack/bert-base-uncased-MNLI", device_map=None, cache_dir=CACHE_DIR)
|
34 |
else:
|
35 |
raise ValueError(f"Unsupported task: {self.task}")
|
36 |
print('model found')
|
DISTILLBERTmodel.py
CHANGED
@@ -12,17 +12,18 @@ from transformers import (
|
|
12 |
DistilBertForMaskedLM, DistilBertForSequenceClassification
|
13 |
)
|
14 |
|
|
|
15 |
class DistilBERTVisualizer(TransformerVisualizer):
|
16 |
def __init__(self, task):
|
17 |
super().__init__()
|
18 |
self.task = task
|
19 |
-
self.tokenizer = DistilBertTokenizer.from_pretrained('distilbert-base-uncased')
|
20 |
if self.task == 'mlm':
|
21 |
-
self.model = DistilBertForMaskedLM.from_pretrained('distilbert-base-uncased')
|
22 |
elif self.task == 'sst':
|
23 |
-
self.model = DistilBertForSequenceClassification.from_pretrained('distilbert-base-uncased-finetuned-sst-2-english')
|
24 |
elif self.task == 'mnli':
|
25 |
-
self.model = DistilBertForSequenceClassification.from_pretrained("textattack/distilbert-base-uncased-MNLI")
|
26 |
|
27 |
|
28 |
else:
|
|
|
12 |
DistilBertForMaskedLM, DistilBertForSequenceClassification
|
13 |
)
|
14 |
|
15 |
+
CACHE_DIR = "./hf_cache"
|
16 |
class DistilBERTVisualizer(TransformerVisualizer):
|
17 |
def __init__(self, task):
|
18 |
super().__init__()
|
19 |
self.task = task
|
20 |
+
self.tokenizer = DistilBertTokenizer.from_pretrained('distilbert-base-uncased', cache_dir=CACHE_DIR)
|
21 |
if self.task == 'mlm':
|
22 |
+
self.model = DistilBertForMaskedLM.from_pretrained('distilbert-base-uncased', cache_dir=CACHE_DIR)
|
23 |
elif self.task == 'sst':
|
24 |
+
self.model = DistilBertForSequenceClassification.from_pretrained('distilbert-base-uncased-finetuned-sst-2-english', cache_dir=CACHE_DIR)
|
25 |
elif self.task == 'mnli':
|
26 |
+
self.model = DistilBertForSequenceClassification.from_pretrained("textattack/distilbert-base-uncased-MNLI", cache_dir=CACHE_DIR)
|
27 |
|
28 |
|
29 |
else:
|
ROBERTAmodel.py
CHANGED
@@ -6,17 +6,18 @@ from transformers import (
|
|
6 |
RobertaForMaskedLM, RobertaForSequenceClassification
|
7 |
)
|
8 |
|
|
|
9 |
class RoBERTaVisualizer(TransformerVisualizer):
|
10 |
def __init__(self, task):
|
11 |
super().__init__()
|
12 |
self.task = task
|
13 |
-
self.tokenizer = RobertaTokenizer.from_pretrained('roberta-base')
|
14 |
if self.task == 'mlm':
|
15 |
-
self.model = RobertaForMaskedLM.from_pretrained("roberta-base")
|
16 |
elif self.task == 'sst':
|
17 |
-
self.model = RobertaForSequenceClassification.from_pretrained('textattack/roberta-base-SST-2')
|
18 |
elif self.task == 'mnli':
|
19 |
-
self.model = RobertaForSequenceClassification.from_pretrained("roberta-large-mnli")
|
20 |
|
21 |
|
22 |
self.model.to(self.device)
|
|
|
6 |
RobertaForMaskedLM, RobertaForSequenceClassification
|
7 |
)
|
8 |
|
9 |
+
CACHE_DIR = "./hf_cache"
|
10 |
class RoBERTaVisualizer(TransformerVisualizer):
|
11 |
def __init__(self, task):
|
12 |
super().__init__()
|
13 |
self.task = task
|
14 |
+
self.tokenizer = RobertaTokenizer.from_pretrained('roberta-base', cache_dir=CACHE_DIR)
|
15 |
if self.task == 'mlm':
|
16 |
+
self.model = RobertaForMaskedLM.from_pretrained("roberta-base", cache_dir=CACHE_DIR)
|
17 |
elif self.task == 'sst':
|
18 |
+
self.model = RobertaForSequenceClassification.from_pretrained('textattack/roberta-base-SST-2', cache_dir=CACHE_DIR)
|
19 |
elif self.task == 'mnli':
|
20 |
+
self.model = RobertaForSequenceClassification.from_pretrained("roberta-large-mnli", cache_dir=CACHE_DIR)
|
21 |
|
22 |
|
23 |
self.model.to(self.device)
|