File size: 5,282 Bytes
d83d604
98ac441
d1155a6
5813702
7245c1f
28b68a3
ff5002a
d1155a6
28b68a3
020246b
e66c157
020246b
73fdc8f
704c23d
 
 
 
 
 
 
 
 
020246b
5813702
2f8d685
2289b0c
861709c
2289b0c
daec533
2289b0c
 
 
f91ed41
5813702
73fdc8f
5813702
0c65a3b
94567bf
 
 
 
 
 
 
28b68a3
98ac441
28b68a3
32a64aa
d665d2a
3d19915
2c7735b
0506c95
94567bf
0506c95
94567bf
0506c95
 
 
94567bf
28b68a3
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
d83d604
020246b
 
6ed741d
5813702
 
 
98ac441
28b68a3
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
d665d2a
28b68a3
 
7245c1f
 
28b68a3
6b9cd40
ab605a2
 
 
6b9cd40
 
 
 
d1155a6
7245c1f
5813702
 
 
 
 
ea8b822
7b5b68f
fd8e8ce
 
826f4b1
7ee2762
fd8e8ce
 
020246b
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
import gradio as gr
import re
import torch
from transformers import AutoTokenizer, AutoModelForSeq2SeqLM
from nltk.corpus import stopwords
from spaces import GPU  # Required for ZeroGPU Spaces
import nltk

# Download the NLTK corpora if they are not already available.
# NOTE(review): nothing visible in this file uses "punkt" (text is split with
# str.split()); it is kept to preserve start-up behavior — consider dropping it.
nltk.download("stopwords")
nltk.download('punkt')

# Additional words (prepositions, conjunctions, articles) to remove on top of
# NLTK's English stop-word list; the set union below discards duplicates.
extra_stopwords = {
    'a', 'an', 'the', 'and', 'but', 'or', 'for', 'nor', 'so', 'yet',
    'at', 'in', 'on', 'with', 'about', 'as', 'by', 'to', 'from', 'of',
    'over', 'under', 'during', 'before', 'after', 'between', 'into',
    'through', 'among', 'above', 'below',
}

# Combined stop-word set consulted by clean_text(); built exactly once.
stop_words = set(stopwords.words("english")) | extra_stopwords


# Model list: UI label -> Hugging Face model id, in dropdown display order.
model_choices = dict([
    ("Xindus Summarizer", "madankn/xindus_t5base"),
    ("T5 Base (t5-base)", "t5-base"),
    ("DistilBART CNN (sshleifer/distilbart-cnn-12-6)", "sshleifer/distilbart-cnn-12-6"),
    ("DistilBART XSum (sshleifer/distilbart-xsum-12-6)", "sshleifer/distilbart-xsum-12-6"),
    ("T5 Small (t5-small)", "t5-small"),
    ("Flan-T5 Base (google/flan-t5-base)", "google/flan-t5-base"),
    ("BART Large CNN (facebook/bart-large-cnn)", "facebook/bart-large-cnn"),
    ("PEGASUS XSum (google/pegasus-xsum)", "google/pegasus-xsum"),
    ("BART Large XSum (facebook/bart-large-xsum)", "facebook/bart-large-xsum"),
])

# Loaded models keyed by model id -> (tokenizer, model); filled lazily by load_model().
model_cache = dict()

def emphasize_keywords(text, keywords, repeat=3):
    """Repeat every keyword occurrence in *text* to give it more weight.

    Each whole-word, case-insensitive match of a keyword is replaced by the
    keyword (spelled as given in *keywords*) followed by a space, repeated
    *repeat* times.

    Args:
        text: Input string.
        keywords: Iterable of literal keywords; regex metacharacters in them
            are escaped, so they are matched verbatim.
        repeat: Number of keyword copies that replace each match.

    Returns:
        The rewritten string.
    """
    for kw in keywords:
        pattern = r"\b" + re.escape(kw) + r"\b"
        replacement = (kw + " ") * repeat
        # Use a callable rather than a string: a string replacement is parsed
        # as a template, so backslashes in a keyword would be interpreted as
        # escapes or group references (or raise re.error).
        text = re.sub(
            pattern,
            lambda _match, rep=replacement: rep,
            text,
            flags=re.IGNORECASE,
        )
    return text


# Clean text: remove special characters, product codes, numbers and stop words
def clean_text(
    input_text,
    keywords=("blazer", "shirt", "trouser", "saree", "tie", "suit"),
):
    """Normalize raw user text before it is fed to a summarization model.

    Steps:
      1. Replace every character other than ASCII letters, digits and
         whitespace with a space.
      2. Drop SKU-like tokens (2+ letters followed by 2+ digits, e.g. "AB12")
         and purely numeric tokens.
      3. Repeat domain *keywords* via ``emphasize_keywords``.
      4. Drop stop words (compared case-insensitively).

    Args:
        input_text: Raw user text.
        keywords: Words to emphasize; defaults to a small list of garment terms.

    Returns:
        The remaining words joined by single spaces (may be an empty string).
    """
    cleaned = re.sub(r"[^A-Za-z0-9\s]", " ", input_text)
    # SKU/product-code tokens. After the substitution above tokens are pure
    # ASCII alphanumeric runs, so this one pattern also covers codes with 3+
    # digits.
    cleaned = re.sub(r"\b[A-Za-z]{2,}[0-9]{2,}\b", "", cleaned)
    cleaned = re.sub(r"\b\d+\b", "", cleaned)  # Remove numbers as tokens

    # Repeat domain keywords so they carry more weight in the summary.
    cleaned = emphasize_keywords(cleaned, keywords)

    # split()/join() also collapses whitespace and trims both ends.
    words = [word for word in cleaned.split() if word.lower() not in stop_words]
    return " ".join(words)

# Load model and tokenizer
def load_model(model_name):
    """Return the ``(tokenizer, model)`` pair for *model_name*, loading it on first use.

    On a cache miss the model is loaded in float16 when CUDA is available
    (float32 otherwise) and moved to that device. It is stored in
    ``model_cache`` and then warmed up with one short ``generate`` call.
    """
    if model_name in model_cache:
        return model_cache[model_name]

    use_cuda = torch.cuda.is_available()
    tokenizer = AutoTokenizer.from_pretrained(model_name)
    model = AutoModelForSeq2SeqLM.from_pretrained(
        model_name,
        torch_dtype=torch.float16 if use_cuda else torch.float32,
    )
    model.to("cuda" if use_cuda else "cpu")

    entry = (tokenizer, model)
    model_cache[model_name] = entry

    # Warm up with one tiny generation; presumably this absorbs one-off
    # start-up cost before the first real request.
    warmup_ids = tokenizer("summarize: warmup", return_tensors="pt").input_ids
    model.generate(warmup_ids.to(model.device), max_length=10)
    return entry

# Main function triggered by Gradio
@GPU  # 👈 Required for ZeroGPU to trigger GPU spin-up
def summarize_text(input_text, model_label, char_limit):
    """Summarize *input_text* with the model selected in the UI.

    Args:
        input_text: Raw text from the Gradio textbox.
        model_label: Display label; must be a key of ``model_choices``.
        char_limit: Maximum number of characters in the returned summary.
            It may arrive as a float from the slider, so it is coerced to int.

    Returns:
        The generated summary, reduced to ASCII letters, digits and whitespace
        and truncated to ``char_limit`` characters. A short user-facing
        message is returned instead when there is nothing to summarize.
    """
    if not input_text or not input_text.strip():
        return "Please enter some text."

    prompt = clean_text(input_text)
    if not prompt:
        # Cleaning removed everything (only stop words, numbers, codes, ...);
        # an empty prompt would only make the model produce noise.
        return "No meaningful text left to summarize after cleaning."

    model_name = model_choices[model_label]
    tokenizer, model = load_model(model_name)

    # Prefix for T5/FLAN-style models
    if "t5" in model_name.lower():
        prompt = "summarize: " + prompt

    encoded = tokenizer(prompt, return_tensors="pt", truncation=True).to(model.device)

    summary_ids = model.generate(
        input_ids=encoded["input_ids"],
        attention_mask=encoded.get("attention_mask"),
        max_length=30,           # Cap the summary at 30 generated tokens
        min_length=15,           # ...but require at least 15 tokens
        do_sample=False,         # Deterministic decoding (no sampling)
        num_beams=5,             # Beam search over 5 candidate sequences
        early_stopping=True,     # Stop beam search once all beams have finished
        no_repeat_ngram_size=1,  # Forbid any single token from appearing twice
    )

    summary = tokenizer.decode(summary_ids[0], skip_special_tokens=True)

    # Keep only ASCII letters, digits and whitespace. This also removes any
    # trailing punctuation, so no separate punctuation strip is needed.
    summary = re.sub(r"[^A-Za-z0-9\s]", "", summary).strip()

    # Slice bounds must be ints; Gradio sliders can deliver floats.
    return summary[: int(char_limit)].strip()

# Gradio UI: free text + model picker + character limit -> summary text box.
iface = gr.Interface(
    summarize_text,
    [
        gr.Textbox(label="Enter text to summarize", lines=6),
        gr.Dropdown(
            label="Choose summarization model",
            choices=[*model_choices],
            value="T5 Base (t5-base)",
        ),
        gr.Slider(
            label="Max Character Limit",
            minimum=30,
            maximum=200,
            step=1,
            value=65,
        ),
    ],
    gr.Textbox(label="Summary (truncated to character limit)", lines=3),
    title="🔥 Xindus Summarizer (GPU-Optimized)",
    description="Summarizes input using Hugging Face models with ZeroGPU. Now faster with CUDA, float16, and warm start!",
)

# Start the web server.
iface.launch()