File size: 1,594 Bytes
f0c636b
6baca04
 
f0c636b
6baca04
 
faa4aa2
6baca04
 
 
699251d
6baca04
b9e0b01
6baca04
914f0b6
6baca04
3ee3f83
6baca04
 
9ba0dd3
 
6baca04
2d278af
9ba0dd3
6baca04
 
 
 
 
 
 
9ba0dd3
6baca04
 
 
 
 
 
 
 
 
a46e61a
6baca04
 
 
9ba0dd3
6baca04
 
 
 
 
991fe21
6baca04
8068f7e
6baca04
699251d
6baca04
 
252e169
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
import gradio as gr
import torch
import re

# Module-level handles filled in by init(); both stay None until it runs.
model = tokenizer = None

def init():
    """Load the fine-tuned mT5 correction model and its tokenizer into globals.

    Sets the module-level ``model`` (from ``lm-spell/mt5-base-ft-ssc``, put in
    eval mode) and ``tokenizer`` (``google/mt5-base`` with an extra ``<ZWJ>``
    special token). The ``HF_TOKEN`` environment variable, if present, is
    forwarded as the ``token`` argument when fetching the model.
    """
    import os
    from transformers import MT5ForConditionalGeneration, T5TokenizerFast

    global model, tokenizer

    access_token = os.environ.get("HF_TOKEN")

    model = MT5ForConditionalGeneration.from_pretrained(
        "lm-spell/mt5-base-ft-ssc",
        token=access_token,
    )
    # Evaluation mode: disables training-only behavior such as dropout.
    model.eval()

    tokenizer = T5TokenizerFast.from_pretrained("google/mt5-base")
    # <ZWJ> stands in for U+200D so the joiner survives tokenization.
    extra_tokens = {'additional_special_tokens': ['<ZWJ>']}
    tokenizer.add_special_tokens(extra_tokens)


def correct(text):
    """Run the spelling-correction model on *text* and return its output.

    Zero-width joiners (U+200D) are replaced with the ``<ZWJ>`` marker token
    before encoding and converted back in the decoded output.

    Args:
        text: Input string from the Gradio textbox. ``None``, empty, and
            whitespace-only input short-circuit to ``""`` without invoking
            the model.

    Returns:
        The greedy-decoded model output with all special tokens except
        ``<ZWJ>`` removed, newlines dropped, surrounding whitespace stripped,
        and ``<ZWJ>`` markers turned back into U+200D.
    """
    # Nothing to correct: skip tokenization and generation entirely
    # (also avoids a TypeError on None).
    if text is None or not text.strip():
        return ""

    zwj_marker = '<ZWJ>'
    text = text.replace('\u200d', zwj_marker)

    # max_length only caps the input when truncation is enabled; pass it
    # explicitly so inputs really are limited to 1024 tokens.
    inputs = tokenizer(
        text,
        return_tensors='pt',
        padding='do_not_pad',
        truncation=True,
        max_length=1024,
    ).to(model.device)

    with torch.inference_mode():
        outputs = model.generate(
            input_ids=inputs["input_ids"],
            attention_mask=inputs["attention_mask"],
            max_length=1024,
            num_beams=1,
            do_sample=False,
        )

    zwj_id = tokenizer.convert_tokens_to_ids(zwj_marker)
    special_ids = set(tokenizer.all_special_ids)
    # Drop every special token except <ZWJ>, which is itself registered as an
    # additional special token and would otherwise be filtered out.
    kept_ids = [
        token_id
        for token_id in outputs[0].cpu().tolist()
        if token_id == zwj_id or token_id not in special_ids
    ]

    decoded = tokenizer.decode(kept_ids, skip_special_tokens=False)
    decoded = decoded.replace('\n', '').strip()
    # Allow one whitespace char after the marker (presumably a space the
    # decoder inserts after the added token — confirm) and fold it in.
    return re.sub(r'<ZWJ>\s?', '\u200d', decoded)

# Load the model once at startup, then expose `correct` through a simple
# text-in / text-out Gradio interface and start serving it.
init()
demo = gr.Interface(
    fn=correct,
    inputs="text",
    outputs="text",
)
demo.launch()