import os

import numpy as np
import spacy
import spacy_alignments as tokenizations
import torch
import gradio as gr
from copy import deepcopy
from sty import bg, Style, RgbBg
from tqdm import tqdm
from transformers import (
    RobertaTokenizer,
    AutoModelForTokenClassification,
    RobertaForSequenceClassification,
)

# Project root: this file lives under a "webapp" folder, so strip everything
# from "\webapp" onward (note: this split assumes Windows-style path separators).
dir_path = os.path.dirname(os.path.realpath(__file__)).split(r"\webapp")[0]
os.environ["KMP_DUPLICATE_LIB_OK"] = "TRUE"

os.system("python -m spacy download en_core_web_sm")
nlp = spacy.load("en_core_web_sm")

tokenizer = RobertaTokenizer.from_pretrained("roberta-base")
# Token-level tagger that marks clause boundaries (3 tags per token).
clause_model = AutoModelForTokenClassification.from_pretrained(
    "{}/clause_model_512".format(dir_path), num_labels=3
)
# Clause-level classifier over the 18 situation-entity labels below.
classification_model = RobertaForSequenceClassification.from_pretrained(
    "{}/classfication_model".format(dir_path), num_labels=18
)

# Each label maps to an attribute triple:
# (specific/generic, dynamic/stative, episodic/static/habitual).
labels2attrs = {
    "##BOUNDED EVENT (SPECIFIC)": ("specific", "dynamic", "episodic"),
    "##BOUNDED EVENT (GENERIC)": ("generic", "dynamic", "episodic"),
    "##UNBOUNDED EVENT (SPECIFIC)": ("specific", "dynamic", "static"),  # should be (static or habitual)
    "##UNBOUNDED EVENT (GENERIC)": ("generic", "dynamic", "static"),
    "##BASIC STATE": ("specific", "stative", "static"),
    "##COERCED STATE (SPECIFIC)": ("specific", "dynamic", "static"),
    "##COERCED STATE (GENERIC)": ("generic", "dynamic", "static"),
    "##PERFECT COERCED STATE (SPECIFIC)": ("specific", "dynamic", "episodic"),
    "##PERFECT COERCED STATE (GENERIC)": ("generic", "dynamic", "episodic"),
    "##GENERIC SENTENCE (DYNAMIC)": ("generic", "dynamic", "habitual"),  # habitual counts as unbounded
    "##GENERIC SENTENCE (STATIC)": ("generic", "stative", "static"),  # "The car is red now" (static)
    "##GENERIC SENTENCE (HABITUAL)": ("generic", "stative", "habitual"),  # "I go to the gym regularly" (habitual)
    "##GENERALIZING SENTENCE (DYNAMIC)": ("specific", "dynamic", "habitual"),
    "##GENERALIZING SENTENCE (STATIVE)": ("specific", "stative", "habitual"),
    "##QUESTION": ("NA", "NA", "NA"),
    "##IMPERATIVE": ("NA", "NA", "NA"),
    "##NONSENSE": ("NA", "NA", "NA"),
    "##OTHER": ("NA", "NA", "NA"),
}

label2index = {l: i for i, l in enumerate(labels2attrs)}
index2label = {i: l for l, i in label2index.items()}


def auto_split(text):
    """Split a document into snippets of at most ~200 words on sentence boundaries."""
    doc = nlp(text)
    current_len = 0
    snippets = []
    current_snippet = ""
    for sent in doc.sents:
        sent_text = sent.text
        words = sent_text.split()
        if current_len + len(words) > 200:
            snippets.append(current_snippet)
            current_snippet = sent_text
            current_len = len(words)
        else:
            current_snippet += " " + sent_text
            current_len += len(words)
    snippets.append(current_snippet)  # the leftover part
    return snippets
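# A minimal usage sketch for auto_split (the input file is hypothetical).
# Snippets stay at or under the 200-word budget unless a single sentence
# alone exceeds it, since splits only happen between sentences:
#
#   snippets = auto_split(open("sample_doc.txt").read())
#   for s in snippets:
#       print(len(s.split()), s[:60])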
def majority_vote(array):
    """Return the most frequent element of `array`."""
    unique, counts = np.unique(np.array(array), return_counts=True)
    return unique[np.argmax(counts)]


def get_pred_clause_labels(text, words):
    """Tag each whitespace-separated word of `text` with a clause label (0/1/2)."""
    model_inputs = tokenizer(text, padding='max_length', max_length=512,
                             truncation=True, return_tensors='pt')
    roberta_tokens = tokenizer.convert_ids_to_tokens(model_inputs['input_ids'][0])
    a2b, b2a = tokenizations.get_alignments(words, roberta_tokens)
    with torch.no_grad():
        logits = clause_model(**model_inputs)[0]
    tagging = logits.argmax(-1)[0].numpy()
    pred_labels = []
    for alignment in a2b:  # word index -> aligned roberta-token indices
        if len(alignment) == 0:
            pred_labels.append(1)  # unaligned words default to clause-internal
        elif len(alignment) == 1:
            pred_labels.append(tagging[alignment[0]])
        else:
            pred_labels.append(majority_vote([tagging[a] for a in alignment]))
    assert len(pred_labels) == len(words)
    return pred_labels


def seg_clause(text):
    """Group words into clauses; label 2 marks the last token of a clause."""
    words = text.strip().split()
    labels = get_pred_clause_labels(text, words)
    segmented_clauses = []
    prev_label = 2
    current_clause = None
    for cur_token, cur_label in zip(words, labels):
        if prev_label == 2:
            current_clause = []
        if current_clause is not None:
            current_clause.append(cur_token)
        if cur_label == 2:
            if prev_label in [0, 1]:
                # e.g. tag sequence 0 1 1 1 1 2 | 0 1 1 closes a clause at the 2
                segmented_clauses.append(deepcopy(current_clause))
            current_clause = None
        prev_label = cur_label
    if current_clause is not None and len(current_clause) != 0:  # append leftover
        segmented_clauses.append(deepcopy(current_clause))
    return [" ".join(clause) for clause in segmented_clauses if clause is not None]


def pretty_print_segmented_clause(segmented_clauses):
    """Print clauses to the terminal, each with a distinct background color."""
    np.random.seed(42)
    bg.orange = Style(RgbBg(255, 150, 50))
    bg.purple = Style(RgbBg(180, 130, 225))
    colors = [bg.red, bg.orange, bg.yellow, bg.green, bg.blue, bg.purple]
    prev_color = 0
    to_print = []
    for cl in segmented_clauses:
        # Never reuse the previous color, so adjacent clauses stay distinguishable.
        color_choice = np.random.choice(np.delete(np.arange(len(colors)), prev_color))
        prev_color = color_choice
        to_print.append(colors[color_choice] + cl + bg.rs)
    print(*to_print, sep=" ")


def get_pred_classification_labels(clauses, batch_size=32):
    """Classify each clause into one of the 18 labels, in batches."""
    clause2labels = []
    for i in range(0, len(clauses), batch_size):
        batch_examples = clauses[i : i + batch_size]
        model_inputs = tokenizer(batch_examples, padding='max_length', max_length=128,
                                 truncation=True, return_tensors='pt')
        with torch.no_grad():
            logits = classification_model(**model_inputs)[0]
        pred_labels = logits.argmax(-1).numpy()
        pred_labels = [index2label[l] for l in pred_labels]
        clause2labels.extend([(s, str(l)) for s, l in zip(batch_examples, pred_labels)])
    return clause2labels


def run_pipeline(text):
    """Split -> segment -> classify; returns per-clause indices and labels."""
    snippets = auto_split(text)
    print(snippets)
    all_clauses = []
    for s in snippets:
        all_clauses.extend(seg_clause(s))
    clause2labels = get_pred_classification_labels(all_clauses)
    output_clauses = [(c, str(i + 1)) for i, c in enumerate(all_clauses)]
    return output_clauses, clause2labels


# Batch mode over a file of documents (kept for reference):
# with open("pipeline_outputs.jsonl", "w") as fw:
#     with open("all_text.txt", "r") as f:
#         lines = f.readlines()
#         print(f"Totally detected {len(lines)} documents.")
#         for text in tqdm(lines):
#             snippets = auto_split(text)
#             all_clauses = []
#             for s in snippets:
#                 segmented_clauses = seg_clause(s)
#                 all_clauses.extend(segmented_clauses)
#                 # pretty_print_segmented_clause(segmented_clauses)
#             clause2labels = get_pred_classification_labels(all_clauses)
#             json.dump(clause2labels, fw)
#             fw.write("\n")
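# A minimal sketch of run_pipeline's two outputs (hypothetical text and labels;
# the actual clause boundaries and labels depend on the trained models):
#
#   output_clauses, clause2labels = run_pipeline("She ran home. The sky is blue.")
#   # output_clauses -> [("She ran home.", "1"), ("The sky is blue.", "2")]
#   # clause2labels  -> [("She ran home.", "##BOUNDED EVENT (SPECIFIC)"),
#   #                    ("The sky is blue.", "##BASIC STATE")]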
color_panel_1 = ["red", "green", "yellow", "DodgerBlue", "orange", "DarkSalmon",
                 "pink", "cyan", "gold", "aqua", "violet"]
index_colormap = {str(i): color_panel_1[i % len(color_panel_1)]
                  for i in np.arange(1, 100000)}

color_panel_2 = ["Violet", "DodgerBlue", "Wheat", "OliveDrab", "DarkKhaki", "DarkSalmon",
                 "Orange", "Gold", "Aqua", "Tomato", "Gray"]
# Sort the unique attribute triples so the attr -> color assignment is
# deterministic across runs (plain set() order varies with hash randomization).
str_attrs = [str(v) for v in sorted(set(labels2attrs.values()))]
print(str_attrs, len(str_attrs), len(color_panel_2))
assert len(str_attrs) == len(color_panel_2)
attr_colormap = {a: c for a, c in zip(str_attrs, color_panel_2)}

# Manual color assignment, left unfinished:
# attr_colormap = {
#     ("specific", "dynamic", "episodic"):
#     ("generic", "dynamic", "episodic"):
#     ("specific", "dynamic", "static"):
#     ("generic", "dynamic", "static"):
#     ("specific", "stative", "static"):
#     ("specific", "dynamic", "static"):
#     ("generic", "dynamic", "static"):
#     ("specific", "dynamic", "episodic"):
#     ("generic", "dynamic", "episodic"):
#     ("generic", "dynamic", "habitual"):
#     ("generic", "stative", "static"):
#     ("generic", "stative", "habitual"):
#     ("specific", "dynamic", "habitual"):
#     ("specific", "stative", "habitual"):
#     ("NA", "NA", "NA"):
# }

if __name__ == "__main__":
    demo = gr.Interface(
        fn=run_pipeline,
        inputs=["text"],
        outputs=[
            gr.HighlightedText(
                label="Clause Segmentation",
                show_label=True,
                combine_adjacent=False,
            ).style(color_map=index_colormap),
            gr.HighlightedText(
                label="Attribute Classification",
                show_label=True,
                show_legend=True,
                combine_adjacent=False,
            ).style(color_map=attr_colormap),
        ],
    )
    demo.launch(share=True)
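# A hedged sketch of querying the running demo programmatically. This assumes
# a gradio 3.x install (matching the .style(...) API above) where gradio_client
# is available and "/predict" is the default endpoint; check the launch logs
# if the endpoint name differs:
#
#   from gradio_client import Client
#   client = Client("http://127.0.0.1:7860")
#   client.predict("John ran home. He was tired.", api_name="/predict")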