# lm-spell / app.py
# Author: Nadil Karunarathna
# Branch: eval — commit 3ee3f83
import gradio as gr
import torch
import re
# Module-level handles populated by init(); correct() reads them.
model = None
tokenizer = None

def init():
    """Load the fine-tuned mT5 spell-correction model and its tokenizer.

    Populates the module-level ``model`` and ``tokenizer`` globals. Reads an
    optional Hugging Face access token from the HF_TOKEN environment variable
    (needed if the model repo is gated/private).
    """
    import os
    from transformers import MT5ForConditionalGeneration, T5TokenizerFast
    global model, tokenizer

    access_token = os.environ.get("HF_TOKEN")
    model = MT5ForConditionalGeneration.from_pretrained(
        "lm-spell/mt5-base-ft-ssc", token=access_token
    )
    model.eval()  # inference only — disable dropout etc.

    tokenizer = T5TokenizerFast.from_pretrained("google/mt5-base")
    # Register a sentinel token for the zero-width joiner so correct() can
    # round-trip U+200D through tokenization.
    tokenizer.add_special_tokens({'additional_special_tokens': ['<ZWJ>']})
def correct(text):
    """Spell-correct *text* with the globally loaded mT5 model.

    Zero-width joiners (U+200D — significant in Sinhala script) are mapped to
    the sentinel token '<ZWJ>' before tokenization and restored afterwards,
    because the SentencePiece tokenizer would otherwise drop them.

    Requires init() to have populated the module globals ``model`` and
    ``tokenizer``.
    """
    # Protect ZWJ characters behind the registered sentinel token.
    text = re.sub(r'\u200d', '<ZWJ>', text)
    inputs = tokenizer(
        text,
        return_tensors='pt',
        padding='do_not_pad',
        # BUG FIX: without truncation=True the tokenizer ignores max_length,
        # so over-long inputs were passed to the model untruncated.
        truncation=True,
        max_length=1024,
    )
    with torch.inference_mode():
        outputs = model.generate(
            input_ids=inputs["input_ids"],
            attention_mask=inputs["attention_mask"],
            max_length=1024,
            num_beams=1,      # greedy decoding
            do_sample=False,  # deterministic output
        )
    prediction = outputs[0]

    # Strip every special token EXCEPT the ZWJ sentinel, which must survive
    # decoding so it can be converted back to U+200D below.
    zwj_id = tokenizer.convert_tokens_to_ids('<ZWJ>')
    special_ids = set(tokenizer.all_special_ids)
    kept_tokens = [
        tok for tok in prediction.cpu().tolist()
        if tok == zwj_id or tok not in special_ids
    ]
    decoded = tokenizer.decode(kept_tokens, skip_special_tokens=False).replace('\n', '').strip()
    # Restore ZWJ; the tokenizer may insert a space after the sentinel token.
    return re.sub(r'<ZWJ>\s?', '\u200d', decoded)
# Load the model and tokenizer once at import time, then serve the UI.
init()
# "text" shorthand maps both input and output to single-line Textbox components.
demo = gr.Interface(fn=correct, inputs="text", outputs="text")
demo.launch()