Nadil Karunarathna committed on
Commit fdd932a · 1 Parent(s): a46e61a
Files changed (1)
  1. app.py +8 -14
app.py CHANGED
@@ -40,21 +40,15 @@ def correct(text):
     )
     prediction = outputs[0]

-    # special_token_id_to_keep = tokenizer.convert_tokens_to_ids('<ZWJ>')
-    # all_special_ids = set(tokenizer.all_special_ids)
-    # pred_tokens = prediction.cpu()
-
-    # tokens_list = pred_tokens.tolist()
-    # filtered_tokens = [
-    #     token for token in tokens_list
-    #     if token == special_token_id_to_keep or token not in all_special_ids
-    # ]
-
     special_token_id_to_keep = tokenizer.convert_tokens_to_ids('<ZWJ>')
-    all_special_ids_tensor = torch.tensor(tokenizer.all_special_ids, dtype=torch.long)
+    all_special_ids = set(tokenizer.all_special_ids)
+    pred_tokens = prediction.cpu()

-    mask = (prediction == special_token_id_to_keep) | (~torch.isin(prediction, all_special_ids_tensor))
-    filtered_tokens = prediction[mask]
+    tokens_list = pred_tokens.tolist()
+    filtered_tokens = [
+        token for token in tokens_list
+        if token == special_token_id_to_keep or token not in all_special_ids
+    ]

     prediction_decoded = tokenizer.decode(filtered_tokens, skip_special_tokens=False).replace('\n', '').strip()

@@ -63,4 +57,4 @@ def correct(text):
 init()

 demo = gr.Interface(fn=correct, inputs="text", outputs="text")
-demo.launch()
+demo.launch(share=True)
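For context, the removed torch-based mask and the new list comprehension appear to implement the same filtering rule: drop every special token ID from the prediction except the ID of '<ZWJ>'. A minimal standalone sketch of that rule, using made-up token IDs instead of a real tokenizer so it runs without loading the model:

# Sketch of the filtering rule; the IDs below are placeholders, not real tokenizer IDs.
def filter_special_tokens(token_ids, all_special_ids, keep_id):
    # Keep ordinary tokens, plus the one special token we want to preserve (<ZWJ>).
    return [
        token for token in token_ids
        if token == keep_id or token not in all_special_ids
    ]

# Example: pretend 0 = <pad>, 1 = </s>, 5 = <ZWJ>; 100 and 101 are ordinary tokens.
prediction_ids = [0, 100, 5, 101, 1]
print(filter_special_tokens(prediction_ids, all_special_ids={0, 1, 5}, keep_id=5))
# -> [100, 5, 101]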