Nadil Karunarathna
committed on
Commit
·
a46e61a
1
Parent(s):
2d278af
wip
Browse files
app.py
CHANGED
@@ -40,15 +40,21 @@ def correct(text):
|
|
40 |
)
|
41 |
prediction = outputs[0]
|
42 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
43 |
special_token_id_to_keep = tokenizer.convert_tokens_to_ids('<ZWJ>')
|
44 |
-
|
45 |
-
|
46 |
-
|
47 |
-
|
48 |
-
filtered_tokens = [
|
49 |
-
token for token in tokens_list
|
50 |
-
if token == special_token_id_to_keep or token not in all_special_ids
|
51 |
-
]
|
52 |
|
53 |
prediction_decoded = tokenizer.decode(filtered_tokens, skip_special_tokens=False).replace('\n', '').strip()
|
54 |
|
|
|
40 |
)
|
41 |
prediction = outputs[0]
|
42 |
|
43 |
+
# special_token_id_to_keep = tokenizer.convert_tokens_to_ids('<ZWJ>')
|
44 |
+
# all_special_ids = set(tokenizer.all_special_ids)
|
45 |
+
# pred_tokens = prediction.cpu()
|
46 |
+
|
47 |
+
# tokens_list = pred_tokens.tolist()
|
48 |
+
# filtered_tokens = [
|
49 |
+
# token for token in tokens_list
|
50 |
+
# if token == special_token_id_to_keep or token not in all_special_ids
|
51 |
+
# ]
|
52 |
+
|
53 |
special_token_id_to_keep = tokenizer.convert_tokens_to_ids('<ZWJ>')
|
54 |
+
all_special_ids_tensor = torch.tensor(tokenizer.all_special_ids, dtype=torch.long)
|
55 |
+
|
56 |
+
mask = (prediction == special_token_id_to_keep) | (~torch.isin(prediction, all_special_ids_tensor))
|
57 |
+
filtered_tokens = prediction[mask]
|
|
|
|
|
|
|
|
|
58 |
|
59 |
prediction_decoded = tokenizer.decode(filtered_tokens, skip_special_tokens=False).replace('\n', '').strip()
|
60 |
|