Nadil Karunarathna
committed on
Commit
·
991fe21
1
Parent(s):
436d1f7
wip
Browse files
app.py
CHANGED
@@ -42,18 +42,32 @@ def correct(text):
|
|
42 |
prediction = outputs[0]
|
43 |
|
44 |
special_token_id_to_keep = tokenizer.convert_tokens_to_ids('<ZWJ>')
|
45 |
-
all_special_ids =
|
46 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
47 |
|
48 |
-
pred_tokens = prediction.to(device)
|
49 |
-
tokens_tensor = pred_tokens.clone().detach().to(dtype=torch.int64)
|
50 |
-
mask = (tokens_tensor == special_token_tensor) | (~torch.isin(tokens_tensor, all_special_ids))
|
51 |
-
filtered_tokens = tokens_tensor[mask].tolist()
|
52 |
-
|
53 |
prediction_decoded = tokenizer.decode(filtered_tokens, skip_special_tokens=False).replace('\n', '').strip()
|
54 |
-
|
55 |
return re.sub(r'<ZWJ>\s?', '\u200d', prediction_decoded)
|
56 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
57 |
# prediction_decoded = tokenizer.decode(prediction, skip_special_tokens=True).replace('\n', '').strip()
|
58 |
# prediction_decoded = re.sub(r'<ZWJ>\s?', '\u200d', prediction_decoded)
|
59 |
|
|
|
42 |
prediction = outputs[0]
|
43 |
|
44 |
special_token_id_to_keep = tokenizer.convert_tokens_to_ids('<ZWJ>')
|
45 |
+
all_special_ids = set(tokenizer.all_special_ids)
|
46 |
+
pred_tokens = prediction.cpu()
|
47 |
+
|
48 |
+
tokens_list = pred_tokens.tolist()
|
49 |
+
filtered_tokens = [
|
50 |
+
token for token in tokens_list
|
51 |
+
if token == special_token_id_to_keep or token not in all_special_ids
|
52 |
+
]
|
53 |
|
|
|
|
|
|
|
|
|
|
|
54 |
prediction_decoded = tokenizer.decode(filtered_tokens, skip_special_tokens=False).replace('\n', '').strip()
|
55 |
+
|
56 |
return re.sub(r'<ZWJ>\s?', '\u200d', prediction_decoded)
|
57 |
|
58 |
+
# special_token_id_to_keep = tokenizer.convert_tokens_to_ids('<ZWJ>')
|
59 |
+
# all_special_ids = torch.tensor(tokenizer.all_special_ids, dtype=torch.int64).to(device)
|
60 |
+
# special_token_tensor = torch.tensor([special_token_id_to_keep], dtype=torch.int64).to(device)
|
61 |
+
|
62 |
+
# pred_tokens = prediction.to(device)
|
63 |
+
# tokens_tensor = pred_tokens.clone().detach().to(dtype=torch.int64)
|
64 |
+
# mask = (tokens_tensor == special_token_tensor) | (~torch.isin(tokens_tensor, all_special_ids))
|
65 |
+
# filtered_tokens = tokens_tensor[mask].tolist()
|
66 |
+
|
67 |
+
# prediction_decoded = tokenizer.decode(filtered_tokens, skip_special_tokens=False).replace('\n', '').strip()
|
68 |
+
|
69 |
+
# return re.sub(r'<ZWJ>\s?', '\u200d', prediction_decoded)
|
70 |
+
|
71 |
# prediction_decoded = tokenizer.decode(prediction, skip_special_tokens=True).replace('\n', '').strip()
|
72 |
# prediction_decoded = re.sub(r'<ZWJ>\s?', '\u200d', prediction_decoded)
|
73 |
|