|
import gradio as gr |
|
import torch |
|
import re |
|
|
|
model = None |
|
tokenizer = None |
|
|
|
def init():
    """Load the fine-tuned mT5 spell-correction model and its tokenizer.

    Populates the module-level ``model`` and ``tokenizer`` globals as a
    side effect; returns nothing.
    """
    import os

    from transformers import MT5ForConditionalGeneration, T5TokenizerFast

    global model, tokenizer

    # The checkpoint may be gated/private: authenticate with the Hugging Face
    # token when one is configured in the environment.
    hf_token = os.environ.get("HF_TOKEN")

    model = MT5ForConditionalGeneration.from_pretrained(
        "lm-spell/mt5-base-ft-ssc", token=hf_token
    )
    model.eval()  # inference only — disable dropout etc.

    # Base mT5 tokenizer, extended with a sentinel token that stands in for
    # the zero-width joiner (U+200D) during encoding/decoding.
    tokenizer = T5TokenizerFast.from_pretrained("google/mt5-base")
    tokenizer.add_special_tokens({'additional_special_tokens': ['<ZWJ>']})
|
|
|
|
|
def correct(text):
    """Spell-correct *text* with the fine-tuned mT5 model.

    The mT5 tokenizer does not preserve the zero-width joiner (U+200D),
    so occurrences are swapped for the registered sentinel token ``<ZWJ>``
    before encoding and mapped back after decoding.

    Args:
        text: Raw input string; may contain U+200D joiners.

    Returns:
        The model's corrected string, with zero-width joiners restored.
    """
    # Protect zero-width joiners through tokenization (plain literal —
    # no regex needed).
    text = text.replace('\u200d', '<ZWJ>')

    inputs = tokenizer(
        text,
        return_tensors='pt',
        padding='do_not_pad',
        truncation=True,  # BUGFIX: without this, max_length was silently ignored
        max_length=1024,
    )

    with torch.inference_mode():
        outputs = model.generate(
            input_ids=inputs["input_ids"],
            attention_mask=inputs["attention_mask"],
            max_length=1024,
            num_beams=1,       # greedy decoding
            do_sample=False,   # deterministic output
        )
    prediction = outputs[0]

    # Strip every special token EXCEPT the <ZWJ> sentinel, which must survive
    # decoding so it can be mapped back to U+200D below.
    zwj_id = tokenizer.convert_tokens_to_ids('<ZWJ>')
    special_ids = set(tokenizer.all_special_ids)
    filtered_tokens = [
        token for token in prediction.cpu().tolist()
        if token == zwj_id or token not in special_ids
    ]

    prediction_decoded = tokenizer.decode(filtered_tokens, skip_special_tokens=False).replace('\n', '').strip()

    # Restore joiners; the tokenizer may emit a space right after the sentinel,
    # so consume at most one trailing whitespace character too.
    return re.sub(r'<ZWJ>\s?', '\u200d', prediction_decoded)
|
|
|
# Load model/tokenizer once at startup, then serve the demo.
init()

# Minimal Gradio UI: one text box in, corrected text out.
demo = gr.Interface(fn=correct, inputs="text", outputs="text")

demo.launch()
|
|