File size: 3,603 Bytes
d83d604
98ac441
d1155a6
5813702
7245c1f
28b68a3
ff5002a
d1155a6
28b68a3
020246b
 
73fdc8f
020246b
5813702
861709c
daec533
861709c
020246b
 
5813702
73fdc8f
5813702
0c65a3b
28b68a3
98ac441
28b68a3
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
d83d604
020246b
 
6ed741d
5813702
 
 
98ac441
28b68a3
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
7245c1f
 
28b68a3
d1155a6
7245c1f
5813702
 
 
 
 
53d5734
7b5b68f
fd8e8ce
 
28b68a3
 
fd8e8ce
 
020246b
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
import gradio as gr
import re
import torch
from transformers import AutoTokenizer, AutoModelForSeq2SeqLM
from nltk.corpus import stopwords
from spaces import GPU  # Required for ZeroGPU Spaces
import nltk

# Download the NLTK stopwords corpus only when it is not already installed,
# so a start-up with a populated NLTK data directory needs no network access.
try:
    nltk.data.find("corpora/stopwords")
except LookupError:
    nltk.download("stopwords")

# English stop words as a set for O(1) membership tests in clean_text.
stop_words = set(stopwords.words("english"))

# Supported checkpoints as (display name, Hugging Face repo id) pairs, listed
# in the order they are offered to the user.
_MODEL_SPECS = (
    ("DistilBART CNN", "sshleifer/distilbart-cnn-12-6"),
    ("T5 Small", "t5-small"),
    ("T5 Base", "t5-base"),
    ("Pegasus XSum", "google/pegasus-xsum"),
    ("BART CNN", "facebook/bart-large-cnn"),
)

# Maps the UI label "<display name> (<repo id>)" to the repo id.
model_choices = {f"{title} ({repo_id})": repo_id for title, repo_id in _MODEL_SPECS}

# Loaded (tokenizer, model) pairs keyed by repo id; starts empty.
model_cache = dict()

# Clean text: remove special characters and stop words
def clean_text(input_text):
    """Strip punctuation/symbols and stop words from *input_text*.

    Every character that is not a letter, digit or whitespace is replaced by
    a space, the result is split on whitespace, and tokens whose lowercase
    form is in the module-level ``stop_words`` set are dropped.

    Letters and digits are matched Unicode-aware, so accented words such as
    "café" are kept whole instead of being cut at the non-ASCII character.

    Args:
        input_text: Raw text to clean.

    Returns:
        The surviving tokens joined by single spaces ("" if none remain).
    """
    # ``\w`` covers Unicode letters and digits but also "_", which is treated
    # as a separator here, hence the explicit alternative.
    cleaned = re.sub(r"[^\w\s]|_", " ", input_text)
    kept = [word for word in cleaned.split() if word.lower() not in stop_words]
    return " ".join(kept)

# Load model and tokenizer
def load_model(model_name):
    """Return the ``(tokenizer, model)`` pair for *model_name*, loading once.

    On the first request for a checkpoint the tokenizer and seq2seq model are
    loaded, the model is placed on CUDA in float16 when CUDA is available
    (otherwise on CPU in float32), the pair is stored in ``model_cache`` and a
    single short generation is run as a warm-up. Later requests return the
    cached pair directly.
    """
    if model_name in model_cache:
        return model_cache[model_name]

    # Decide device and precision together so they cannot disagree.
    if torch.cuda.is_available():
        device, dtype = "cuda", torch.float16
    else:
        device, dtype = "cpu", torch.float32

    tokenizer = AutoTokenizer.from_pretrained(model_name)
    model = AutoModelForSeq2SeqLM.from_pretrained(model_name, torch_dtype=dtype)
    model.to(device)

    # The pair is registered before the warm-up call runs.
    entry = (tokenizer, model)
    model_cache[model_name] = entry

    # Warm up
    warmup_ids = tokenizer("summarize: warmup", return_tensors="pt").input_ids
    model.generate(warmup_ids.to(model.device), max_length=10)
    return entry

# Main function triggered by Gradio
@GPU  # 👈 Required for ZeroGPU to trigger GPU spin-up
def summarize_text(input_text, model_label, char_limit):
    """Summarize *input_text* with the model selected by *model_label*.

    The text is cleaned with ``clean_text``, prefixed with "summarize: " for
    T5 checkpoints, run through beam-search generation and the decoded result
    is cut to ``char_limit`` characters.

    Args:
        input_text: Raw text from the UI; ``None`` is treated as empty.
        model_label: A key of ``model_choices`` selecting the checkpoint.
        char_limit: Maximum number of characters to return. Coerced to int so
            a float-valued slider reading can still be used as a slice bound.

    Returns:
        The truncated summary, or a short message when there is nothing to
        summarize.

    Raises:
        KeyError: If ``model_label`` is not a key of ``model_choices``.
    """
    if not input_text or not input_text.strip():
        return "Please enter some text."

    input_text = clean_text(input_text)
    if not input_text:
        # Only punctuation / stop words were supplied. Bail out before loading
        # a model: there is no content left to generate a summary from.
        return "No content left to summarize after removing stop words and special characters."

    model_name = model_choices[model_label]
    tokenizer, model = load_model(model_name)

    # Prefix for T5/FLAN-style models
    if "t5" in model_name.lower():
        input_text = "summarize: " + input_text

    inputs = tokenizer(input_text, return_tensors="pt", truncation=True)
    input_ids = inputs["input_ids"].to(model.device)

    # Deterministic beam-search decoding with a short, bounded output.
    summary_ids = model.generate(
        input_ids,
        max_length=30,                # Upper bound on the generated length (tokens)
        min_length=15,                # Lower bound on the generated length (tokens)
        do_sample=False,              # No sampling: decoding is deterministic
        num_beams=5,                  # Beam search width
        early_stopping=True,          # End beam search once enough finished candidates exist
        no_repeat_ngram_size=2        # Never emit the same bigram twice
    )

    summary = tokenizer.decode(summary_ids[0], skip_special_tokens=True)
    return summary[:int(char_limit)].strip()

# Gradio UI
# Components are created in the order: text input, model picker, limit
# slider, summary output — then wired into a single Interface.
_text_input = gr.Textbox(lines=6, label="Enter text to summarize")
_model_picker = gr.Dropdown(
    choices=list(model_choices),
    label="Choose summarization model",
    value="DistilBART CNN (sshleifer/distilbart-cnn-12-6)",
)
_limit_slider = gr.Slider(
    minimum=30,
    maximum=200,
    value=65,
    step=1,
    label="Max Character Limit",
)
_summary_output = gr.Textbox(lines=3, label="Summary (truncated to character limit)")

iface = gr.Interface(
    fn=summarize_text,
    inputs=[_text_input, _model_picker, _limit_slider],
    outputs=_summary_output,
    title="🔥 Fast Summarizer (GPU-Optimized)",
    description="Summarizes input using Hugging Face models with ZeroGPU support. Now faster with CUDA, float16, and warm start!",
)

# Start the web app as soon as the module is executed.
iface.launch()