RobPruzan committed
Commit 680cdda · 1 Parent(s): dbe5251

Adding word sense disambiguation + definitions to synonym generation

Files changed (1):
  1. app.py  +218  -3
app.py CHANGED
@@ -11,6 +11,7 @@ import gradio as gr
 import readability
 import seaborn as sns
 import torch
+import torch.nn.functional as F
 from fuzzywuzzy import fuzz
 from nltk.corpus import stopwords
 from nltk.corpus import wordnet as wn
@@ -18,6 +19,8 @@ from nltk.tokenize import word_tokenize
 from sklearn.metrics.pairwise import cosine_similarity
 from transformers import DistilBertTokenizer
 from transformers import pipeline
+from transformers import BertTokenizer
+from transformers import AutoTokenizer, BertForSequenceClassification
 
 
 nltk.download('wordnet')
@@ -442,6 +445,218 @@ def vocab_level_inter(text):
     interp.append(('', 0))
     return {'original': text, 'interpretation': interp}, f'{level(sum/total*4*2.5)[1:]} Level Vocabulary'
 
+
+
+logger = logging.getLogger(__name__)
+logging.basicConfig(format='%(asctime)s - %(levelname)s - %(name)s - %(message)s',
+                    datefmt='%m/%d/%Y %H:%M:%S',
+                    level=logging.INFO)
+tokenizer4 = AutoTokenizer.from_pretrained('kanishka/GlossBERT')
+
+def construct_context_gloss_pairs_through_nltk(input, target_start_id, target_end_id):
+    """
+    construct context gloss pairs like sent_cls_ws
+    :param input: str, a sentence
+    :param target_start_id: int
+    :param target_end_id: int
+    :param lemma: lemma of the target word
+    :return: candidate lists
+    """
+
+    sent = tokenizer4.tokenize(input)
+    assert 0 <= target_start_id and target_start_id < target_end_id and target_end_id <= len(sent)
+    target = " ".join(sent[target_start_id:target_end_id])
+    if len(sent) > target_end_id:
+        sent = sent[:target_start_id] + ['"'] + sent[target_start_id:target_end_id] + ['"'] + sent[target_end_id:]
+    else:
+        sent = sent[:target_start_id] + ['"'] + sent[target_start_id:target_end_id] + ['"']
+
+    sent = " ".join(sent)
+
+    candidate = []
+    syns = wn.synsets(target)
+
+    for syn in syns:
+        if target == syn.name().split('.')[0]:
+            continue
+
+        gloss = (syn.definition(), syn.name())
+        candidate.append((sent, f"{target} : {gloss}", target, gloss))
+
+    assert len(candidate) != 0, f'there is no candidate sense of "{target}" in WordNet, please check'
+    # print(f'there are {len(candidate)} candidate senses of "{target}"')
+
+
+    return candidate
+
+
+class InputFeatures(object):
+    """A single set of features of data."""
+
+    def __init__(self, input_ids, input_mask, segment_ids):
+        self.input_ids = input_ids
+        self.input_mask = input_mask
+        self.segment_ids = segment_ids
+
+
+def convert_to_features(candidate, tokenizer3, max_seq_length=512):
+
+    candidate_results = []
+    features = []
+    for item in candidate:
+        text_a = item[0]  # sentence
+        text_b = item[1]  # gloss
+        candidate_results.append((item[-2], item[-1]))  # (target, gloss)
+
+
+        tokens_a = tokenizer3.tokenize(text_a)
+        tokens_b = tokenizer3.tokenize(text_b)
+        _truncate_seq_pair(tokens_a, tokens_b, max_seq_length - 3)
+        tokens = ["[CLS]"] + tokens_a + ["[SEP]"]
+        segment_ids = [0] * len(tokens)
+        tokens += tokens_b + ["[SEP]"]
+        segment_ids += [1] * (len(tokens_b) + 1)
+
+        input_ids = tokenizer3.convert_tokens_to_ids(tokens)
+
+        # The mask has 1 for real tokens and 0 for padding tokens. Only real
+        # tokens are attended to.
+        input_mask = [1] * len(input_ids)
+
+        # Zero-pad up to the sequence length.
+        padding = [0] * (max_seq_length - len(input_ids))
+        input_ids += padding
+        input_mask += padding
+        segment_ids += padding
+
+        assert len(input_ids) == max_seq_length
+        assert len(input_mask) == max_seq_length
+        assert len(segment_ids) == max_seq_length
+
+        features.append(
+            InputFeatures(input_ids=input_ids,
+                          input_mask=input_mask,
+                          segment_ids=segment_ids))
+
+
+    return features, candidate_results
+
+
+
+def _truncate_seq_pair(tokens_a, tokens_b, max_length):
+    """Truncates a sequence pair in place to the maximum length."""
+
+    # This is a simple heuristic which will always truncate the longer sequence
+    # one token at a time. This makes more sense than truncating an equal percent
+    # of tokens from each, since if one sequence is very short then each token
+    # that's truncated likely contains more information than a longer sequence.
+    while True:
+        total_length = len(tokens_a) + len(tokens_b)
+        if total_length <= max_length:
+            break
+        if len(tokens_a) > len(tokens_b):
+            tokens_a.pop()
+        else:
+            tokens_b.pop()
+
+
+def infer(input, target_start_id, target_end_id, args):
+    sent = tokenizer4.tokenize(input)
+    assert 0 <= target_start_id and target_start_id < target_end_id and target_end_id <= len(sent)
+    target = " ".join(sent[target_start_id:target_end_id])
+
+
+    device = torch.device("cuda" if torch.cuda.is_available() and not args.no_cuda else "cpu")
+
+
+    label_list = ["0", "1"]
+    num_labels = len(label_list)
+
+    model = BertForSequenceClassification.from_pretrained(args.bert_model,
+                                                          num_labels=num_labels)
+    model.to(device)
+
+    # print(f"input: {input}\ntarget: {target}")
+    examples = construct_context_gloss_pairs_through_nltk(input, target_start_id, target_end_id)
+    eval_features, candidate_results = convert_to_features(examples, tokenizer4)
+    input_ids = torch.tensor([f.input_ids for f in eval_features], dtype=torch.long)
+    input_mask = torch.tensor([f.input_mask for f in eval_features], dtype=torch.long)
+    segment_ids = torch.tensor([f.segment_ids for f in eval_features], dtype=torch.long)
+
+
+    model.eval()
+    input_ids = input_ids.to(device)
+    input_mask = input_mask.to(device)
+    segment_ids = segment_ids.to(device)
+    with torch.no_grad():
+        logits = model(input_ids=input_ids, token_type_ids=segment_ids, attention_mask=input_mask, labels=None).logits
+    logits_ = F.softmax(logits, dim=-1)
+    logits_ = logits_.detach().cpu().numpy()
+    output = np.argmax(logits_, axis=0)[1]
+    results = []
+    for idx, i in enumerate(logits_):
+        results.append((candidate_results[idx][1], i[1]*100))
+    sorted_results = sorted(results, key=lambda x: x[1], reverse=True)
+
+    return sorted_results
+
+def format_for_gradio(inp):
+    retval = ''
+    for idx, i in enumerate(inp):
+        if idx == len(inp)-1:
+            retval += i.split('.')[0]
+            break
+        retval += f'''{i.split('.')[0]} | '''
+    return retval
+
+
+def smart_synonyms(text, level):
+    parser = argparse.ArgumentParser()
+    parser.add_argument("--bert_model", default="kanishka/GlossBERT", type=str)
+    parser.add_argument("--no_cuda", default=False, action='store_true', help="Whether not to use CUDA when available")
+    args, unknown = parser.parse_known_args()
+
+    location = 0
+    word = ''
+    tokens = tokenizer4.tokenize(text)
+    school_to_level = {"Elementary Level": '1', "Middle School Level": '2', "High School Level": '3', "College Level": '4'}
+    for idx, i in enumerate(tokens):
+        if i[0] == '@':
+            location = idx
+            text = text.replace('@', '')
+            word = tokens[location]
+            break
+    raw_syns = []
+    raw_defs = []
+    raw_scores = []
+    syns = []
+    defs = []
+    scores = []
+    preds = infer(text, location, location+1, args)
+    for i in preds:
+        if not i[0][1].split('.')[0] in data[school_to_level[level]]:
+            continue
+        raw_syns.append(i[0][1])
+        raw_defs.append(i[0][0])
+        raw_scores.append(i[1])
+        if i[1] > 5:
+            syns.append(i[0][1])
+            defs.append(i[0][0])
+            scores.append(i[1])
+
+    if not syns:
+        top_syns = int(len(raw_syns)*.25//1+1)
+        syns = raw_syns[:top_syns]
+        defs = raw_defs[:top_syns]
+        scores = raw_scores[:top_syns]
+
+    cleaned_syns = format_for_gradio(syns)
+    cleaned_defs = format_for_gradio(defs)
+
+    return f'{cleaned_syns}: Definition- {cleaned_defs} | '
+
+
+
 with gr.Blocks(title="Automatic Literacy and Speech Assessment") as demo:
     gr.HTML("""<center><h7 style="font-size: 35px">Automatic Literacy and Speech Assessment</h7></center>""")
     gr.HTML("""<center><h7 style="font-size: 15px">This may take 60s to generate all statistics</h7></center>""")
@@ -460,8 +675,8 @@ with gr.Blocks(title="Automatic Literacy and Speech Assessment") as demo:
     audio_file = gr.Audio(source="microphone", type="filepath")
     grade1 = gr.Button("Grade Your Speech")
     with gr.Group():
-        gr.Markdown("Reading Level Based Synonyms | Enter only one word at a time")
-        words = gr.Textbox(label="Word For Synonyms")
+        gr.Markdown("""Reading Level Based Synonyms | Enter a sentence containing the word you want a synonym for | Mark the target word with @, e.g. "Today is an @amazing day" (target word = amazing)""")
+        words = gr.Textbox(label="Text with word for synonyms")
         lvl = gr.Dropdown(choices=["Elementary Level", "Middle School Level", "High School Level", "College Level"], label="Intended Reading Level For Synonym")
         get_syns = gr.Button("Get Synonyms")
         reccos = gr.Label()
@@ -532,6 +747,6 @@ with gr.Blocks(title="Automatic Literacy and Speech Assessment") as demo:
     grade.click(vocab_level_inter, inputs=in_text, outputs=[interpretation3, vocab_output])
     grade1.click(speech_to_score, inputs=audio_file, outputs=diff_output)
    b1.click(speech_to_text, inputs=[audio_file1, target], outputs=[text, some_val, phones])
-    get_syns.click(gen_syns, inputs=[words, lvl], outputs=reccos)
+    get_syns.click(smart_synonyms, inputs=[words, lvl], outputs=reccos)
     find_sim.click(get_sim_words, inputs=[in_text, words1], outputs=sims)
 demo.launch(debug=True)
 
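A minimal usage sketch of the new synonym path (not part of the commit): it would run inside app.py and assumes the module-level objects defined there are already loaded, in particular tokenizer4 and the reading-level word lists in data. The sentence and reading level below are hypothetical examples.

# The target word is marked with '@'. smart_synonyms tokenizes the sentence,
# scores the candidate WordNet senses of that word with GlossBERT
# (kanishka/GlossBERT via BertForSequenceClassification), keeps the senses
# whose lemma appears in the chosen reading-level word list, and returns the
# synonyms plus their definitions as one formatted string for the gr.Label.
sentence = "Today is an @amazing day"   # hypothetical example input
reading_level = "Middle School Level"   # one of the gr.Dropdown choices
print(smart_synonyms(sentence, reading_level))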