BabakScrapes committed
Commit e863891 · verified · 1 parent: 461f0ea

Create pipeline.py

Files changed (1):
  pipeline.py +209 -0
pipeline.py ADDED
import torch
import numpy as np
import os
from transformers import RobertaTokenizer, AutoModelForTokenClassification, RobertaForSequenceClassification
import spacy
import spacy_alignments as tokenizations
from copy import deepcopy
from sty import bg, RgbBg, Style
from tqdm import tqdm
import gradio as gr

# Project root: strip everything from the "webapp" folder down (the script is assumed to live under it).
dir_path = os.path.dirname(os.path.realpath(__file__)).split(os.sep + "webapp")[0]
os.environ["KMP_DUPLICATE_LIB_OK"] = "TRUE"
os.system("python -m spacy download en_core_web_sm")
nlp = spacy.load("en_core_web_sm")
tokenizer = RobertaTokenizer.from_pretrained("roberta-base")
# os.path.join keeps the model paths portable; the original hard-coded Windows backslashes.
clause_model = AutoModelForTokenClassification.from_pretrained(
    os.path.join(dir_path, "Trained Models", "clause_model_512"), num_labels=3
)
classification_model = RobertaForSequenceClassification.from_pretrained(
    os.path.join(dir_path, "Trained Models", "classfication_model"),  # folder name spelled as on disk
    num_labels=18,
)
clause_model.eval()
classification_model.eval()

labels2attrs = {
    "##BOUNDED EVENT (SPECIFIC)": ("specific", "dynamic", "episodic"),
    "##BOUNDED EVENT (GENERIC)": ("generic", "dynamic", "episodic"),
    "##UNBOUNDED EVENT (SPECIFIC)": ("specific", "dynamic", "static"),  # this should be (static, or habitual)
    "##UNBOUNDED EVENT (GENERIC)": ("generic", "dynamic", "static"),
    "##BASIC STATE": ("specific", "stative", "static"),
    "##COERCED STATE (SPECIFIC)": ("specific", "dynamic", "static"),
    "##COERCED STATE (GENERIC)": ("generic", "dynamic", "static"),
    "##PERFECT COERCED STATE (SPECIFIC)": ("specific", "dynamic", "episodic"),
    "##PERFECT COERCED STATE (GENERIC)": ("generic", "dynamic", "episodic"),
    "##GENERIC SENTENCE (DYNAMIC)": ("generic", "dynamic", "habitual"),  # habitual counts as unbounded
    "##GENERIC SENTENCE (STATIC)": ("generic", "stative", "static"),  # "The car is red now" (static)
    "##GENERIC SENTENCE (HABITUAL)": ("generic", "stative", "habitual"),  # "I go to the gym regularly" (habitual)
    "##GENERALIZING SENTENCE (DYNAMIC)": ("specific", "dynamic", "habitual"),
    "##GENERALIZING SENTENCE (STATIVE)": ("specific", "stative", "habitual"),
    "##QUESTION": ("NA", "NA", "NA"),
    "##IMPERATIVE": ("NA", "NA", "NA"),
    "##NONSENSE": ("NA", "NA", "NA"),
    "##OTHER": ("NA", "NA", "NA"),
}

label2index = {l: i for i, l in enumerate(labels2attrs)}
index2label = {i: l for l, i in label2index.items()}

def auto_split(text):
    """Split a document into snippets of at most ~200 words, breaking only at sentence boundaries."""
    doc = nlp(text)
    current_len = 0
    snippets = []
    current_snippet = ""
    for sent in doc.sents:
        sent_text = sent.text  # avoid shadowing the `text` argument
        words = sent_text.split()
        if current_len + len(words) > 200:
            snippets.append(current_snippet)
            current_snippet = sent_text
            current_len = len(words)
        else:
            current_snippet += " " + sent_text
            current_len += len(words)
    snippets.append(current_snippet)  # the leftover part
    return snippets


def majority_vote(array):
    """Return the most frequent value in `array`."""
    unique, counts = np.unique(np.array(array), return_counts=True)
    return unique[np.argmax(counts)]
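# Example: majority_vote([1, 2, 2]) -> 2; on ties, np.argmax picks the first
# entry of the sorted uniques, so majority_vote([0, 2]) -> 0.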

def get_pred_clause_labels(text, words):
    """Predict a per-word clause tag with the token-classification model.

    Tag scheme (inferred from seg_clause below): 0/1 mark tokens inside a
    clause, 2 marks a clause-final token.
    """
    model_inputs = tokenizer(text, padding='max_length', max_length=512, truncation=True, return_tensors='pt')
    roberta_tokens = tokenizer.convert_ids_to_tokens(model_inputs['input_ids'][0])
    a2b, b2a = tokenizations.get_alignments(words, roberta_tokens)
    with torch.no_grad():  # inference only; no gradients needed
        logits = clause_model(**model_inputs)[0]
    tagging = logits.argmax(-1)[0].numpy()
    pred_labels = []
    for alignment in a2b:  # word index -> aligned RoBERTa token indices
        if len(alignment) == 0:
            pred_labels.append(1)
        elif len(alignment) == 1:
            pred_labels.append(tagging[alignment[0]])
        else:
            pred_labels.append(majority_vote([tagging[a] for a in alignment]))
    assert len(pred_labels) == len(words)
    return pred_labels

def seg_clause(text):
    """Group words into clauses based on the predicted clause tags."""
    words = text.strip().split()
    labels = get_pred_clause_labels(text, words)
    segmented_clauses = []
    prev_label = 2
    current_clause = None
    for cur_token, cur_label in zip(words, labels):
        if prev_label == 2:  # the previous clause just closed; start a new one
            current_clause = []
        if current_clause is not None:
            current_clause.append(cur_token)

        if cur_label == 2 and prev_label in [0, 1]:  # e.g. 0 1 1 1 1 2 | 0 1 1
            segmented_clauses.append(deepcopy(current_clause))
            current_clause = None
        prev_label = cur_label

    if current_clause is not None and len(current_clause) != 0:  # append leftover
        segmented_clauses.append(deepcopy(current_clause))
    return [" ".join(clause) for clause in segmented_clauses if clause is not None]

def pretty_print_segmented_clause(segmented_clauses):
    """Print clauses to the terminal, each on an alternating background color."""
    np.random.seed(42)
    bg.orange = Style(RgbBg(255, 150, 50))
    bg.purple = Style(RgbBg(180, 130, 225))
    colors = [bg.red, bg.orange, bg.yellow, bg.green, bg.blue, bg.purple]
    prev_color = 0
    to_print = []
    for cl in segmented_clauses:
        # pick any color except the one used for the previous clause
        color_choice = np.random.choice(np.delete(np.arange(len(colors)), prev_color))
        prev_color = color_choice
        colored_cl = colors[color_choice] + cl + bg.rs
        to_print.append(colored_cl)
    print(*to_print, sep=" ")


def get_pred_classification_labels(clauses, batch_size=32):
    """Classify each clause into one of the 18 situation-entity labels."""
    clause2labels = []
    for i in range(0, len(clauses), batch_size):
        batch_examples = clauses[i : i + batch_size]
        model_inputs = tokenizer(batch_examples, padding='max_length', max_length=128, truncation=True, return_tensors='pt')
        with torch.no_grad():
            logits = classification_model(**model_inputs)[0]
        pred_labels = logits.argmax(-1).numpy()
        pred_labels = [index2label[l] for l in pred_labels]
        clause2labels.extend([(s, str(l)) for s, l in zip(batch_examples, pred_labels)])
    return clause2labels


def run_pipeline(text):
    """Full pipeline: split the text into snippets, segment clauses, classify each clause."""
    snippets = auto_split(text)
    # print(snippets)  # debug
    all_clauses = []
    for s in snippets:
        all_clauses.extend(seg_clause(s))

    clause2labels = get_pred_classification_labels(all_clauses)
    output_clauses = [(c, str(i + 1)) for i, c in enumerate(all_clauses)]
    return output_clauses, clause2labels
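# Sketch of direct use outside Gradio (hypothetical input text):
#   clauses, labeled = run_pipeline("The cat sat on the mat. It purrs every morning.")
# `clauses` pairs each clause with its 1-based index (for highlighting);
# `labeled` pairs each clause with its predicted ##-label string.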

# Batch-processing variant (requires `import json`):
# with open("pipeline_outputs.jsonl", "w") as fw:
#     with open("all_text.txt", "r") as f:
#         lines = f.readlines()
#     print(f"Detected {len(lines)} documents in total.")
#     for text in tqdm(lines):
#         snippets = auto_split(text)
#         all_clauses = []
#         for s in snippets:
#             segmented_clauses = seg_clause(s)
#             all_clauses.extend(segmented_clauses)
#             # pretty_print_segmented_clause(segmented_clauses)
#
#         clause2labels = get_pred_classification_labels(all_clauses)
#         json.dump(clause2labels, fw)
#         fw.write("\n")

color_panel_1 = ["red", "green", "yellow", "DodgerBlue", "orange", "DarkSalmon", "pink", "cyan", "gold", "aqua", "violet"]
index_colormap = {str(i): color_panel_1[i % len(color_panel_1)] for i in range(1, 100000)}
color_panel_2 = ["Violet", "DodgerBlue", "Wheat", "OliveDrab", "DarkKhaki", "DarkSalmon", "Orange", "Gold", "Aqua", "Tomato", "Gray"]
# Sort the unique attribute tuples so the color assignment is deterministic
# across runs (iterating a set directly is not).
str_attrs = sorted(str(v) for v in set(labels2attrs.values()))
# print(str_attrs, len(str_attrs), len(color_panel_2))  # debug
assert len(str_attrs) == len(color_panel_2)
attr_colormap = {a: c for a, c in zip(str_attrs, color_panel_2)}

# attr_colormap = {
#     ("specific", "dynamic", "episodic"):
#     ("generic", "dynamic", "episodic"):
#     ("specific", "dynamic", "static"):
#     ("generic", "dynamic", "static"):
#     ("specific", "stative", "static"):
#     ("specific", "dynamic", "static"):
#     ("generic", "dynamic", "static"):
#     ("specific", "dynamic", "episodic"):
#     ("generic", "dynamic", "episodic"):
#     ("generic", "dynamic", "habitual"):
#     ("generic", "stative", "static"):
#     ("generic", "stative", "habitual"):
#     ("specific", "dynamic", "habitual"):
#     ("specific", "stative", "habitual"):
#     ("NA", "NA", "NA"):
# }


if __name__ == "__main__":
    demo = gr.Interface(
        fn=run_pipeline,
        inputs=["text"],
        outputs=[
            gr.HighlightedText(
                label="Clause Segmentation",
                show_label=True,
                combine_adjacent=False,
            ).style(color_map=index_colormap),
            gr.HighlightedText(
                label="Attribute Classification",
                show_label=True,
                show_legend=True,
                combine_adjacent=False,
            ).style(color_map=attr_colormap),
        ],
    )

    demo.launch(share=True)