File size: 4,890 Bytes
d83d604
98ac441
d1155a6
5813702
7245c1f
28b68a3
ff5002a
d1155a6
28b68a3
020246b
e66c157
020246b
73fdc8f
704c23d
 
 
 
 
 
 
 
 
020246b
5813702
2289b0c
861709c
2289b0c
daec533
2289b0c
 
 
 
 
 
5813702
73fdc8f
5813702
0c65a3b
28b68a3
98ac441
28b68a3
32a64aa
2c7735b
28b68a3
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
d83d604
020246b
 
6ed741d
5813702
 
 
98ac441
28b68a3
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
7245c1f
 
28b68a3
6b9cd40
ab605a2
 
 
6b9cd40
 
 
 
d1155a6
7245c1f
5813702
 
 
 
 
53d5734
7b5b68f
fd8e8ce
 
28b68a3
 
fd8e8ce
 
020246b
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
import gradio as gr
import re
import torch
from transformers import AutoTokenizer, AutoModelForSeq2SeqLM
from nltk.corpus import stopwords
from spaces import GPU  # Required for ZeroGPU Spaces
import nltk

# Download stopwords if not already available
nltk.download("stopwords")
# NOTE(review): nothing visible in this file uses the punkt tokenizer; the
# download is kept in case another part of the app relies on it - confirm.
nltk.download('punkt')

# Define additional words (prepositions, conjunctions, articles) to remove
extra_stopwords = {
    'a', 'an', 'the', 'and', 'but', 'or', 'for', 'nor', 'so', 'yet',
    'at', 'in', 'on', 'with', 'about', 'as', 'by', 'to', 'from', 'of',
    'over', 'under', 'during', 'before', 'after', 'between', 'into',
    'through', 'among', 'above', 'below',
}

# Combine NLTK stopwords with extra stopwords. Built once: the earlier
# intermediate assignment was immediately overwritten by this one.
stop_words = set(stopwords.words("english")) | extra_stopwords


# Model list: dropdown label -> Hugging Face Hub checkpoint id.
# Every checkpoint listed here is loaded through AutoModelForSeq2SeqLM, so only
# encoder-decoder (seq2seq) models belong in this table. The decoder-only
# "distilgpt2" checkpoint was removed for that reason: it cannot be
# instantiated as a seq2seq model and selecting it always raised at load time.
model_choices = {
    "T5 Base (t5-base)": "t5-base",
    "DistilBART CNN (sshleifer/distilbart-cnn-12-6)": "sshleifer/distilbart-cnn-12-6",
    "DistilBART XSum (sshleifer/distilbart-xsum-12-6)": "sshleifer/distilbart-xsum-12-6",
    "T5 Small (t5-small)": "t5-small",
    "Flan-T5 Base (google/flan-t5-base)": "google/flan-t5-base",
    "BART Large CNN (facebook/bart-large-cnn)": "facebook/bart-large-cnn",
    "PEGASUS XSum (google/pegasus-xsum)": "google/pegasus-xsum",
    "BART Large XSum (facebook/bart-large-xsum)": "facebook/bart-large-xsum",
    # NOTE(review): could not confirm this repo id exists on the Hub - verify.
    "BART Large SciTLDR (facebook/bart-large-scitldr)": "facebook/bart-large-scitldr",
}

# Checkpoint id -> (tokenizer, model), filled lazily as models are first used.
model_cache = {}

# Clean text: remove special characters and stop words
def clean_text(input_text):
    """Strip punctuation, SKU-like codes and stop words from ``input_text``.

    Processing order:
      1. Every character that is not an ASCII letter, digit or whitespace
         is replaced by a space.
      2. Words made of two or more letters directly followed by three or
         more digits (e.g. ``AB1234``) are deleted.
      3. The text is split on whitespace and any word whose lowercase form
         is in the module-level ``stop_words`` set is dropped.

    Returns the remaining words joined by single spaces.
    """
    alnum_only = re.sub(r"[^A-Za-z0-9\s]", " ", input_text)
    without_codes = re.sub(r"\b[A-Za-z]{2,}[0-9]{3,}\b", "", alnum_only)

    kept = []
    for token in without_codes.split():
        if token.lower() in stop_words:
            continue
        kept.append(token)
    return " ".join(kept).strip()

# Load model and tokenizer
def load_model(model_name):
    """Return the ``(tokenizer, model)`` pair for ``model_name``, loading once.

    On a cache miss the tokenizer and seq2seq model are loaded, the model is
    placed on CUDA with float16 weights when CUDA is available (otherwise on
    CPU with float32), the pair is stored in ``model_cache`` and a single
    short generation is run as a warm-up. Later calls return the cached pair.
    """
    if model_name in model_cache:
        return model_cache[model_name]

    tokenizer = AutoTokenizer.from_pretrained(model_name)

    use_cuda = torch.cuda.is_available()
    if use_cuda:
        dtype, device = torch.float16, "cuda"
    else:
        dtype, device = torch.float32, "cpu"

    model = AutoModelForSeq2SeqLM.from_pretrained(model_name, torch_dtype=dtype)
    model.to(device)
    model_cache[model_name] = (tokenizer, model)

    # Warm up
    warmup_ids = tokenizer("summarize: warmup", return_tensors="pt").input_ids
    model.generate(warmup_ids.to(model.device), max_length=10)

    return model_cache[model_name]

# Main function triggered by Gradio
@GPU  # 👈 Required for ZeroGPU to trigger GPU spin-up
def summarize_text(input_text, model_label, char_limit):
    """Summarize ``input_text`` with the chosen model, capped at ``char_limit``.

    Args:
        input_text: Raw text to summarize. ``None``, empty or whitespace-only
            input yields a prompt message instead of a summary.
        model_label: Key into ``model_choices`` selecting the checkpoint.
        char_limit: Maximum number of characters in the returned summary.
            Coerced to ``int`` so a float from the UI can be used for slicing.

    Returns:
        The generated summary with everything except ASCII letters, digits
        and whitespace removed, truncated to ``char_limit`` characters; or
        ``"Please enter some text."`` when there is nothing to summarize.

    Raises:
        KeyError: If ``model_label`` is not a key of ``model_choices``.
    """
    if not input_text or not input_text.strip():
        return "Please enter some text."

    input_text = clean_text(input_text)
    if not input_text:
        # The input consisted only of punctuation, codes and stop words;
        # running generation on an empty prompt would return made-up text.
        return "Please enter some text."

    model_name = model_choices[model_label]
    tokenizer, model = load_model(model_name)

    # Prefix for T5/FLAN-style models
    if "t5" in model_name.lower():
        input_text = "summarize: " + input_text

    inputs = tokenizer(input_text, return_tensors="pt", truncation=True)
    input_ids = inputs["input_ids"].to(model.device)
    # Pass the tokenizer's mask explicitly instead of letting generate() infer it.
    attention_mask = inputs["attention_mask"].to(model.device)

    # Adjust the generation parameters
    summary_ids = model.generate(
        input_ids,
        attention_mask=attention_mask,
        max_length=30,                # Upper bound on generated tokens
        min_length=15,                # Ensure the summary is not too short
        do_sample=False,              # Deterministic decoding (no sampling)
        num_beams=5,                  # Beam search over 5 candidate sequences
        early_stopping=True,          # Stop once enough finished beams exist
        no_repeat_ngram_size=2        # Prevent repetition of bigrams
    )

    summary = tokenizer.decode(summary_ids[0], skip_special_tokens=True)

    # Keep only letters, digits and whitespace. After this pass no punctuation
    # can remain, so no separate trailing-punctuation cleanup is needed.
    summary = re.sub(r"[^A-Za-z0-9\s]", "", summary).strip()

    return summary[:int(char_limit)].strip()

# Gradio UI
# Input widgets, created in the order they are passed to summarize_text.
text_input = gr.Textbox(lines=6, label="Enter text to summarize")
model_dropdown = gr.Dropdown(
    choices=list(model_choices.keys()),
    label="Choose summarization model",
    value="DistilBART CNN (sshleifer/distilbart-cnn-12-6)",
)
char_limit_slider = gr.Slider(
    minimum=30,
    maximum=200,
    value=65,
    step=1,
    label="Max Character Limit",
)

# Output widget showing the (possibly truncated) summary.
summary_output = gr.Textbox(lines=3, label="Summary (truncated to character limit)")

iface = gr.Interface(
    fn=summarize_text,
    inputs=[text_input, model_dropdown, char_limit_slider],
    outputs=summary_output,
    title="🔥 Fast Summarizer (GPU-Optimized)",
    description="Summarizes input using Hugging Face models with ZeroGPU support. Now faster with CUDA, float16, and warm start!",
)

iface.launch()