File size: 4,277 Bytes
d83d604
98ac441
5813702
73fdc8f
6ed741d
5813702
 
861709c
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
6ed741d
 
 
5813702
73fdc8f
5813702
0c65a3b
f13147a
 
 
 
 
 
 
 
98ac441
 
 
f13147a
 
 
 
 
 
 
 
98ac441
 
f13147a
98ac441
 
 
5813702
 
 
 
 
 
d83d604
98ac441
6ed741d
5813702
 
 
98ac441
 
 
5813702
 
 
98ac441
6ed741d
5813702
d83d604
5813702
6ed741d
5813702
 
b7e9fc8
5813702
 
 
 
ce89b24
6ed741d
ce89b24
 
 
 
 
6ed741d
5813702
 
 
 
 
 
f13147a
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
import gradio as gr
import re
from transformers import AutoTokenizer, AutoModelForSeq2SeqLM

# Human-readable dropdown labels mapped to Hugging Face model ids.
# Insertion order is meaningful: the original author ordered entries by
# (claimed) summarization accuracy, and the Gradio dropdown below is
# presumably built from these keys in order.
model_choices = {
    "Pegasus (google/pegasus-xsum)": "google/pegasus-xsum",
    "BigBird-Pegasus (google/bigbird-pegasus-large-arxiv)": "google/bigbird-pegasus-large-arxiv",
    "LongT5 Large (google/long-t5-tglobal-large)": "google/long-t5-tglobal-large",
    "BART Large CNN (facebook/bart-large-cnn)": "facebook/bart-large-cnn",
    "ProphetNet (microsoft/prophetnet-large-uncased-cnndm)": "microsoft/prophetnet-large-uncased-cnndm",
    "LED (allenai/led-base-16384)": "allenai/led-base-16384",
    "T5 Large (t5-large)": "t5-large",
    "Flan-T5 Large (google/flan-t5-large)": "google/flan-t5-large",
    "DistilBART CNN (sshleifer/distilbart-cnn-12-6)": "sshleifer/distilbart-cnn-12-6",
    "DistilBART XSum (mrm8488/distilbart-xsum-12-6)": "mrm8488/distilbart-xsum-12-6",
    "T5 Base (t5-base)": "t5-base",
    "Flan-T5 Base (google/flan-t5-base)": "google/flan-t5-base",
    "BART CNN SamSum (philschmid/bart-large-cnn-samsum)": "philschmid/bart-large-cnn-samsum",
    # NOTE(review): label says "T5 SamSum" but the checkpoint id is Pegasus-based — confirm intended.
    "T5 SamSum (knkarthick/pegasus-samsum)": "knkarthick/pegasus-samsum",
    "LongT5 Base (google/long-t5-tglobal-base)": "google/long-t5-tglobal-base",
    "T5 Small (t5-small)": "t5-small",
    # NOTE(review): the three entries below are translation / causal-LM checkpoints,
    # not summarizers; AutoModelForSeq2SeqLM may fail to load falcon-7b-instruct — verify.
    "MBART (facebook/mbart-large-cc25)": "facebook/mbart-large-cc25",
    "MarianMT (Helsinki-NLP/opus-mt-en-ro)": "Helsinki-NLP/opus-mt-en-ro",
    "Falcon Instruct (tiiuae/falcon-7b-instruct)": "tiiuae/falcon-7b-instruct",
    "BART ELI5 (yjernite/bart_eli5)": "yjernite/bart_eli5"
}

# Lazy cache: model id -> (tokenizer, model), filled on first use by load_model().
model_cache = {}

# Common English prepositions and conjunctions (lowercase) that clean_text()
# removes from input before summarization.
# Fix: use a set literal instead of set([...]) (avoids building a throwaway
# list) and drop the "for" that was listed twice in the original.
prepositions_and_conjunctions = {
    "in", "on", "at", "by", "for", "with", "about", "as", "into", "during", "before", "after",
    "of", "to", "from", "and", "but", "or", "nor", "so", "yet", "because", "although", "since",
    "unless", "until", "while", "if", "than", "whether", "where", "when", "that", "which", "who", "whom",
}

# Strip punctuation and drop common prepositions/conjunctions from the input.
def clean_text(input_text):
    """Return *input_text* with non-alphanumeric characters replaced by
    spaces and every word in ``prepositions_and_conjunctions`` removed.

    Matching is case-insensitive. Surviving words are re-joined with single
    spaces, so runs of whitespace collapse and the result carries no
    leading/trailing whitespace.
    """
    alnum_only = re.sub(r'[^A-Za-z0-9\s]', ' ', input_text)
    kept_words = [
        token
        for token in alnum_only.split()
        if token.lower() not in prepositions_and_conjunctions
    ]
    # " ".join over split() output is already stripped of edge whitespace.
    return " ".join(kept_words)

# Load model and tokenizer, memoizing each checkpoint in model_cache.
def load_model(model_name):
    """Return a ``(tokenizer, model)`` pair for *model_name*.

    The first request for a given checkpoint downloads/loads it via the
    transformers Auto classes; subsequent requests hit ``model_cache``.
    """
    cached = model_cache.get(model_name)
    if cached is None:
        cached = (
            AutoTokenizer.from_pretrained(model_name),
            AutoModelForSeq2SeqLM.from_pretrained(model_name),
        )
        model_cache[model_name] = cached
    return cached

# Summarize the text using a selected model
def summarize_text(input_text, model_label, char_limit):
    """Generate a short summary of *input_text* with the chosen model.

    Parameters:
        input_text: raw user text; punctuation and stop-words are removed
            before tokenization.
        model_label: a key of ``model_choices`` selecting the checkpoint.
        char_limit: maximum number of characters in the returned summary.

    Returns the decoded summary truncated to *char_limit* characters, or a
    prompt string when the input is blank. Raises KeyError for an unknown
    *model_label* (Gradio's dropdown only supplies known labels).
    """
    if not input_text.strip():
        return "Please enter some text."

    # Clean the input text by removing special characters and stop-words.
    input_text = clean_text(input_text)

    model_name = model_choices[model_label]
    tokenizer, model = load_model(model_name)

    # T5-family checkpoints expect an explicit task prefix.
    if "t5" in model_name.lower() or "flan" in model_name.lower():
        input_text = "summarize: " + input_text

    inputs = tokenizer(input_text, return_tensors="pt", truncation=True)

    # Fix: pass the attention mask explicitly. Calling generate() with only
    # input_ids forces transformers to infer the mask from the pad token,
    # which is ambiguous for checkpoints whose pad and eos tokens coincide.
    summary_ids = model.generate(
        inputs["input_ids"],
        attention_mask=inputs.get("attention_mask"),
        max_length=15,  # token budget, approximate; can be tuned per model
        min_length=5,
        do_sample=False,  # deterministic greedy decoding
    )

    # Decode the summary while skipping special tokens.
    summary = tokenizer.decode(summary_ids[0], skip_special_tokens=True)

    # Remove pipe characters some checkpoints emit, then trim whitespace.
    summary = summary.replace("|", "").strip()

    return summary[:char_limit]  # Enforce the caller's character limit

# Gradio UI
iface = gr.Interface(
    fn=summarize_text,
    inputs=[
        gr.Textbox(lines=6, label="Enter text to summarize"),
        gr.Dropdown(choices