Macropodus committed
Commit 6b2b9b7 · verified · 1 Parent(s): beff6fd

Update app.py

Files changed (1)
  1. app.py +63 -3
app.py CHANGED
@@ -12,6 +12,13 @@ model = BertForMaskedLM.from_pretrained(pretrained_model_name_or_path)
  vocab = tokenizer.vocab


+ # from modelscope import AutoTokenizer, AutoModelForMaskedLM
+ # pretrained_model_name_or_path = "Macadam/macbert4mdcspell_v2"
+ # tokenizer = AutoTokenizer.from_pretrained(pretrained_model_name_or_path)
+ # model = AutoModelForMaskedLM.from_pretrained(pretrained_model_name_or_path)
+ # vocab = tokenizer.vocab
+
+
  def func_macro_correct(text):
      with torch.no_grad():
          outputs = model(**tokenizer([text], padding=True, return_tensors='pt'))
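The commented-out block keeps an alternative loading path via modelscope next to the transformers-based setup named in the hunk header. For orientation, a minimal sketch of that setup and of the single masked-LM pass the surrounding context lines perform; the tokenizer class, the checkpoint name (taken from the commented lines) and the wrapper function are assumptions, not the exact code in app.py:

import torch
from transformers import BertForMaskedLM, BertTokenizerFast

# Assumed checkpoint; app.py's transformers path may point at a different repo id.
pretrained_model_name_or_path = "Macadam/macbert4mdcspell_v2"
tokenizer = BertTokenizerFast.from_pretrained(pretrained_model_name_or_path)
model = BertForMaskedLM.from_pretrained(pretrained_model_name_or_path)
vocab = tokenizer.vocab  # token -> id map, later passed as know_tokens

def correct_once(text):
    # Hypothetical wrapper mirroring func_macro_correct's context lines:
    # encode, predict one token per position, decode, trim to the input length.
    with torch.no_grad():
        outputs = model(**tokenizer([text], padding=True, return_tensors="pt"))
    ids = torch.argmax(outputs.logits[0], dim=-1)
    decoded = tokenizer.decode(ids, skip_special_tokens=True).replace(" ", "")
    return decoded[:len(text)]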
@@ -29,7 +36,53 @@ def func_macro_correct(text):
                  return False
          return True

-     def get_errors(corrected_text, origin_text, unk_tokens=[], know_tokens=[]):
+     def get_errors_from_diff_length(corrected_text, origin_text, unk_tokens=[], know_tokens=[]):
+         """Get errors between corrected text and origin text
+         code from: https://github.com/shibing624/pycorrector
+         """
+         new_corrected_text = ""
+         errors = []
+         i, j = 0, 0
+         unk_tokens = unk_tokens or [' ', '“', '”', '‘', '’', '琊', '\n', '…', '擤', '\t', '玕', '']
+         while i < len(origin_text) and j < len(corrected_text):
+             if origin_text[i] in unk_tokens or origin_text[i] not in know_tokens:
+                 new_corrected_text += origin_text[i]
+                 i += 1
+             elif corrected_text[j] in unk_tokens:
+                 new_corrected_text += corrected_text[j]
+                 j += 1
+             # Deal with Chinese characters
+             elif flag_total_chinese(origin_text[i]) and flag_total_chinese(corrected_text[j]):
+                 # If the two characters are the same, then the two pointers move forward together
+                 if origin_text[i] == corrected_text[j]:
+                     new_corrected_text += corrected_text[j]
+                     i += 1
+                     j += 1
+                 else:
+                     # Check for insertion errors
+                     if j + 1 < len(corrected_text) and origin_text[i] == corrected_text[j + 1]:
+                         errors.append(('', corrected_text[j], j))
+                         new_corrected_text += corrected_text[j]
+                         j += 1
+                     # Check for deletion errors
+                     elif i + 1 < len(origin_text) and origin_text[i + 1] == corrected_text[j]:
+                         errors.append((origin_text[i], '', i))
+                         i += 1
+                     # Check for replacement errors
+                     else:
+                         errors.append((origin_text[i], corrected_text[j], i))
+                         new_corrected_text += corrected_text[j]
+                         i += 1
+                         j += 1
+             else:
+                 new_corrected_text += origin_text[i]
+                 if origin_text[i] == corrected_text[j]:
+                     j += 1
+                 i += 1
+         errors = sorted(errors, key=operator.itemgetter(2))
+         return new_corrected_text, errors
+
+     def get_errors_from_same_length(corrected_text, origin_text, unk_tokens=[], know_tokens=[]):
          """Get new corrected text and errors between corrected text and origin text
          code from: https://github.com/shibing624/pycorrector
          """
@@ -57,7 +110,14 @@ def func_macro_correct(text):

      _text = tokenizer.decode(torch.argmax(outputs.logits[0], dim=-1), skip_special_tokens=True).replace(' ', '')
      corrected_text = _text[:len(text)]
-     corrected_text, details = get_errors(corrected_text, text, know_tokens=vocab)
+     print("#" * 128)
+     print(text)
+     print(corrected_text)
+     print(len(text), len(corrected_text))
+     if len(corrected_text) == len(text):
+         corrected_text, details = get_errors_from_same_length(corrected_text, text, know_tokens=vocab)
+     else:
+         corrected_text, details = get_errors_from_diff_length(corrected_text, text, know_tokens=vocab)
      print(text, ' => ', corrected_text, details)
      return corrected_text + ' ' + str(details)

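The replacement for the old get_errors() call picks the character-by-character comparison when the decoded text keeps the input length and falls back to the two-pointer alignment otherwise. A usage sketch with a made-up sentence pair; it assumes get_errors_from_same_length, get_errors_from_diff_length and vocab are in scope (in app.py they live inside func_macro_correct, so this is purely illustrative), and the details value shown is indicative only:

origin_text = "少先队员应该为老人让坐"      # hypothetical input with one wrong character
corrected_text = "少先队员应该为老人让座"   # hypothetical model output of the same length

if len(corrected_text) == len(origin_text):
    corrected_text, details = get_errors_from_same_length(corrected_text, origin_text, know_tokens=vocab)
else:
    corrected_text, details = get_errors_from_diff_length(corrected_text, origin_text, know_tokens=vocab)

# details is a list of (wrong_char, right_char, position) entries,
# here roughly [('坐', '座', 10)].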
 
@@ -88,6 +148,6 @@ if __name__ == '__main__':
          description="Copy or input error Chinese text. Submit and the machine will correct text.",
          article="Link to <a href='https://github.com/yongzhuo/macro-correct' style='color:blue;' target='_blank\'>Github REPO: macro-correct</a>",
          examples=examples
-     ).launch()
+     ).launch()  # .launch(server_name="0.0.0.0", server_port=8036, share=False, debug=True)
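The trailing comment records how to expose the demo on a fixed host and port instead of the default local launch. For context, a hypothetical sketch of how the gr.Interface shown in this hunk is typically assembled around func_macro_correct; the positional arguments, widget types, title and example sentences sit outside this hunk and are assumptions:

import gradio as gr

if __name__ == '__main__':
    examples = [['少先队员应该为老人让坐'],
                ['真麻烦你了。希望你们好好的跳无']]  # hypothetical sample inputs
    gr.Interface(
        func_macro_correct,
        inputs='text',
        outputs='text',
        title="macro-correct: Chinese Spelling Correction",  # hypothetical title
        description="Copy or input error Chinese text. Submit and the machine will correct text.",
        article="Link to <a href='https://github.com/yongzhuo/macro-correct' style='color:blue;' target='_blank'>Github REPO: macro-correct</a>",
        examples=examples
    ).launch()  # or .launch(server_name="0.0.0.0", server_port=8036, share=False, debug=True)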
 
 