Nadil Karunarathna committed on
Commit
991fe21
·
1 Parent(s): 436d1f7
Files changed (1) hide show
  1. app.py +22 -8
app.py CHANGED
@@ -42,18 +42,32 @@ def correct(text):
42
  prediction = outputs[0]
43
 
44
  special_token_id_to_keep = tokenizer.convert_tokens_to_ids('<ZWJ>')
45
- all_special_ids = torch.tensor(tokenizer.all_special_ids, dtype=torch.int64).to(device)
46
- special_token_tensor = torch.tensor([special_token_id_to_keep], dtype=torch.int64).to(device)
 
 
 
 
 
 
47
 
48
- pred_tokens = prediction.to(device)
49
- tokens_tensor = pred_tokens.clone().detach().to(dtype=torch.int64)
50
- mask = (tokens_tensor == special_token_tensor) | (~torch.isin(tokens_tensor, all_special_ids))
51
- filtered_tokens = tokens_tensor[mask].tolist()
52
-
53
  prediction_decoded = tokenizer.decode(filtered_tokens, skip_special_tokens=False).replace('\n', '').strip()
54
-
55
  return re.sub(r'<ZWJ>\s?', '\u200d', prediction_decoded)
56
 
 
 
 
 
 
 
 
 
 
 
 
 
 
57
  # prediction_decoded = tokenizer.decode(prediction, skip_special_tokens=True).replace('\n', '').strip()
58
  # prediction_decoded = re.sub(r'<ZWJ>\s?', '\u200d', prediction_decoded)
59
 
 
42
  prediction = outputs[0]
43
 
44
  special_token_id_to_keep = tokenizer.convert_tokens_to_ids('<ZWJ>')
45
+ all_special_ids = set(tokenizer.all_special_ids)
46
+ pred_tokens = prediction.cpu()
47
+
48
+ tokens_list = pred_tokens.tolist()
49
+ filtered_tokens = [
50
+ token for token in tokens_list
51
+ if token == special_token_id_to_keep or token not in all_special_ids
52
+ ]
53
 
 
 
 
 
 
54
  prediction_decoded = tokenizer.decode(filtered_tokens, skip_special_tokens=False).replace('\n', '').strip()
55
+
56
  return re.sub(r'<ZWJ>\s?', '\u200d', prediction_decoded)
57
 
58
+ # special_token_id_to_keep = tokenizer.convert_tokens_to_ids('<ZWJ>')
59
+ # all_special_ids = torch.tensor(tokenizer.all_special_ids, dtype=torch.int64).to(device)
60
+ # special_token_tensor = torch.tensor([special_token_id_to_keep], dtype=torch.int64).to(device)
61
+
62
+ # pred_tokens = prediction.to(device)
63
+ # tokens_tensor = pred_tokens.clone().detach().to(dtype=torch.int64)
64
+ # mask = (tokens_tensor == special_token_tensor) | (~torch.isin(tokens_tensor, all_special_ids))
65
+ # filtered_tokens = tokens_tensor[mask].tolist()
66
+
67
+ # prediction_decoded = tokenizer.decode(filtered_tokens, skip_special_tokens=False).replace('\n', '').strip()
68
+
69
+ # return re.sub(r'<ZWJ>\s?', '\u200d', prediction_decoded)
70
+
71
  # prediction_decoded = tokenizer.decode(prediction, skip_special_tokens=True).replace('\n', '').strip()
72
  # prediction_decoded = re.sub(r'<ZWJ>\s?', '\u200d', prediction_decoded)
73