Spaces:
Running
Running
Update app.py
Browse files
app.py
CHANGED
@@ -12,6 +12,13 @@ model = BertForMaskedLM.from_pretrained(pretrained_model_name_or_path)
|
|
12 |
vocab = tokenizer.vocab
|
13 |
|
14 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
15 |
def func_macro_correct(text):
|
16 |
with torch.no_grad():
|
17 |
outputs = model(**tokenizer([text], padding=True, return_tensors='pt'))
|
@@ -29,7 +36,53 @@ def func_macro_correct(text):
|
|
29 |
return False
|
30 |
return True
|
31 |
|
32 |
-
def
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
33 |
"""Get new corrected text and errors between corrected text and origin text
|
34 |
code from: https://github.com/shibing624/pycorrector
|
35 |
"""
|
@@ -57,7 +110,14 @@ def func_macro_correct(text):
|
|
57 |
|
58 |
_text = tokenizer.decode(torch.argmax(outputs.logits[0], dim=-1), skip_special_tokens=True).replace(' ', '')
|
59 |
corrected_text = _text[:len(text)]
|
60 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
61 |
print(text, ' => ', corrected_text, details)
|
62 |
return corrected_text + ' ' + str(details)
|
63 |
|
@@ -88,6 +148,6 @@ if __name__ == '__main__':
|
|
88 |
description="Copy or input error Chinese text. Submit and the machine will correct text.",
|
89 |
article="Link to <a href='https://github.com/yongzhuo/macro-correct' style='color:blue;' target='_blank\'>Github REPO: macro-correct</a>",
|
90 |
examples=examples
|
91 |
-
).launch()
|
92 |
|
93 |
|
|
|
12 |
vocab = tokenizer.vocab
|
13 |
|
14 |
|
15 |
+
# from modelscope import AutoTokenizer, AutoModelForMaskedLM
|
16 |
+
# pretrained_model_name_or_path = "Macadam/macbert4mdcspell_v2"
|
17 |
+
# tokenizer = AutoTokenizer.from_pretrained(pretrained_model_name_or_path)
|
18 |
+
# model = AutoModelForMaskedLM.from_pretrained(pretrained_model_name_or_path)
|
19 |
+
# vocab = tokenizer.vocab
|
20 |
+
|
21 |
+
|
22 |
def func_macro_correct(text):
|
23 |
with torch.no_grad():
|
24 |
outputs = model(**tokenizer([text], padding=True, return_tensors='pt'))
|
|
|
36 |
return False
|
37 |
return True
|
38 |
|
39 |
+
def get_errors_from_diff_length(corrected_text, origin_text, unk_tokens=[], know_tokens=[]):
|
40 |
+
"""Get errors between corrected text and origin text
|
41 |
+
code from: https://github.com/shibing624/pycorrector
|
42 |
+
"""
|
43 |
+
new_corrected_text = ""
|
44 |
+
errors = []
|
45 |
+
i, j = 0, 0
|
46 |
+
unk_tokens = unk_tokens or [' ', 'β', 'β', 'β', 'β', 'η', '\n', 'β¦', 'ζ€', '\t', 'η', 'ο']
|
47 |
+
while i < len(origin_text) and j < len(corrected_text):
|
48 |
+
if origin_text[i] in unk_tokens or origin_text[i] not in know_tokens:
|
49 |
+
new_corrected_text += origin_text[i]
|
50 |
+
i += 1
|
51 |
+
elif corrected_text[j] in unk_tokens:
|
52 |
+
new_corrected_text += corrected_text[j]
|
53 |
+
j += 1
|
54 |
+
# Deal with Chinese characters
|
55 |
+
elif flag_total_chinese(origin_text[i]) and flag_total_chinese(corrected_text[j]):
|
56 |
+
# If the two characters are the same, then the two pointers move forward together
|
57 |
+
if origin_text[i] == corrected_text[j]:
|
58 |
+
new_corrected_text += corrected_text[j]
|
59 |
+
i += 1
|
60 |
+
j += 1
|
61 |
+
else:
|
62 |
+
# Check for insertion errors
|
63 |
+
if j + 1 < len(corrected_text) and origin_text[i] == corrected_text[j + 1]:
|
64 |
+
errors.append(('', corrected_text[j], j))
|
65 |
+
new_corrected_text += corrected_text[j]
|
66 |
+
j += 1
|
67 |
+
# Check for deletion errors
|
68 |
+
elif i + 1 < len(origin_text) and origin_text[i + 1] == corrected_text[j]:
|
69 |
+
errors.append((origin_text[i], '', i))
|
70 |
+
i += 1
|
71 |
+
# Check for replacement errors
|
72 |
+
else:
|
73 |
+
errors.append((origin_text[i], corrected_text[j], i))
|
74 |
+
new_corrected_text += corrected_text[j]
|
75 |
+
i += 1
|
76 |
+
j += 1
|
77 |
+
else:
|
78 |
+
new_corrected_text += origin_text[i]
|
79 |
+
if origin_text[i] == corrected_text[j]:
|
80 |
+
j += 1
|
81 |
+
i += 1
|
82 |
+
errors = sorted(errors, key=operator.itemgetter(2))
|
83 |
+
return new_corrected_text, errors
|
84 |
+
|
85 |
+
def get_errors_from_same_length(corrected_text, origin_text, unk_tokens=[], know_tokens=[]):
|
86 |
"""Get new corrected text and errors between corrected text and origin text
|
87 |
code from: https://github.com/shibing624/pycorrector
|
88 |
"""
|
|
|
110 |
|
111 |
_text = tokenizer.decode(torch.argmax(outputs.logits[0], dim=-1), skip_special_tokens=True).replace(' ', '')
|
112 |
corrected_text = _text[:len(text)]
|
113 |
+
print("#" * 128)
|
114 |
+
print(text)
|
115 |
+
print(corrected_text)
|
116 |
+
print(len(text), len(corrected_text))
|
117 |
+
if len(corrected_text) == len(text):
|
118 |
+
corrected_text, details = get_errors_from_same_length(corrected_text, text, know_tokens=vocab)
|
119 |
+
else:
|
120 |
+
corrected_text, details = get_errors_from_diff_length(corrected_text, text, know_tokens=vocab)
|
121 |
print(text, ' => ', corrected_text, details)
|
122 |
return corrected_text + ' ' + str(details)
|
123 |
|
|
|
148 |
description="Copy or input error Chinese text. Submit and the machine will correct text.",
|
149 |
article="Link to <a href='https://github.com/yongzhuo/macro-correct' style='color:blue;' target='_blank\'>Github REPO: macro-correct</a>",
|
150 |
examples=examples
|
151 |
+
).launch() # .launch(server_name="0.0.0.0", server_port=8036, share=False, debug=True)
|
152 |
|
153 |
|