Nadil Karunarathna committed on
Commit
a46e61a
·
1 Parent(s): 2d278af
Files changed (1) hide show
  1. app.py +14 -8
app.py CHANGED
@@ -40,15 +40,21 @@ def correct(text):
40
  )
41
  prediction = outputs[0]
42
 
 
 
 
 
 
 
 
 
 
 
43
  special_token_id_to_keep = tokenizer.convert_tokens_to_ids('<ZWJ>')
44
- all_special_ids = set(tokenizer.all_special_ids)
45
- pred_tokens = prediction.cpu()
46
-
47
- tokens_list = pred_tokens.tolist()
48
- filtered_tokens = [
49
- token for token in tokens_list
50
- if token == special_token_id_to_keep or token not in all_special_ids
51
- ]
52
 
53
  prediction_decoded = tokenizer.decode(filtered_tokens, skip_special_tokens=False).replace('\n', '').strip()
54
 
 
40
  )
41
  prediction = outputs[0]
42
 
43
+ # special_token_id_to_keep = tokenizer.convert_tokens_to_ids('<ZWJ>')
44
+ # all_special_ids = set(tokenizer.all_special_ids)
45
+ # pred_tokens = prediction.cpu()
46
+
47
+ # tokens_list = pred_tokens.tolist()
48
+ # filtered_tokens = [
49
+ # token for token in tokens_list
50
+ # if token == special_token_id_to_keep or token not in all_special_ids
51
+ # ]
52
+
53
  special_token_id_to_keep = tokenizer.convert_tokens_to_ids('<ZWJ>')
54
+ all_special_ids_tensor = torch.tensor(tokenizer.all_special_ids, dtype=torch.long)
55
+
56
+ mask = (prediction == special_token_id_to_keep) | (~torch.isin(prediction, all_special_ids_tensor))
57
+ filtered_tokens = prediction[mask]
 
 
 
 
58
 
59
  prediction_decoded = tokenizer.decode(filtered_tokens, skip_special_tokens=False).replace('\n', '').strip()
60