File size: 4,180 Bytes
d83d604
98ac441
d1155a6
5813702
ff5002a
d1155a6
ff5002a
d1155a6
 
ff5002a
d1155a6
73fdc8f
d1155a6
5813702
 
861709c
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
6ed741d
 
 
5813702
73fdc8f
5813702
0c65a3b
98ac441
 
f13147a
ff5002a
d1155a6
98ac441
5813702
 
 
d1155a6
 
 
 
 
 
5813702
d1155a6
 
 
 
 
5813702
d83d604
d1155a6
6ed741d
5813702
 
 
98ac441
5813702
 
 
6ed741d
5813702
d83d604
d1155a6
5813702
d1155a6
6ed741d
5813702
d1155a6
 
5813702
 
 
 
6ed741d
d1155a6
5813702
 
 
 
 
 
fd8e8ce
e0515de
fd8e8ce
 
d1155a6
 
fd8e8ce
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
import gradio as gr
import re
import torch
from transformers import AutoTokenizer, AutoModelForSeq2SeqLM
from nltk.corpus import stopwords
from spaces import GPU
import nltk

# Make sure the NLTK English stopword corpus is available.
# nltk.download() consults the remote NLTK package index even when the corpus
# is already installed, so only fall back to it when the corpus cannot be
# found locally (this also lets the app start offline once the data exists).
try:
    nltk.data.find('corpora/stopwords')
except LookupError:
    nltk.download('stopwords')
# Lowercase English stopwords, used by clean_text() to filter input words.
stop_words = set(stopwords.words('english'))

# Model choices: dropdown label -> Hugging Face Hub checkpoint id.
# Every checkpoint is loaded with AutoModelForSeq2SeqLM in load_model(), so only
# encoder-decoder models can work here. "tiiuae/falcon-7b-instruct" was removed:
# it is a decoder-only (causal LM) checkpoint that AutoModelForSeq2SeqLM cannot
# load, so selecting it always raised an error.
model_choices = {
    "Pegasus (google/pegasus-xsum)": "google/pegasus-xsum",
    "BigBird-Pegasus (google/bigbird-pegasus-large-arxiv)": "google/bigbird-pegasus-large-arxiv",
    "LongT5 Large (google/long-t5-tglobal-large)": "google/long-t5-tglobal-large",
    "BART Large CNN (facebook/bart-large-cnn)": "facebook/bart-large-cnn",
    "ProphetNet (microsoft/prophetnet-large-uncased-cnndm)": "microsoft/prophetnet-large-uncased-cnndm",
    "LED (allenai/led-base-16384)": "allenai/led-base-16384",
    "T5 Large (t5-large)": "t5-large",
    "Flan-T5 Large (google/flan-t5-large)": "google/flan-t5-large",
    "DistilBART CNN (sshleifer/distilbart-cnn-12-6)": "sshleifer/distilbart-cnn-12-6",
    "DistilBART XSum (mrm8488/distilbart-xsum-12-6)": "mrm8488/distilbart-xsum-12-6",
    "T5 Base (t5-base)": "t5-base",
    "Flan-T5 Base (google/flan-t5-base)": "google/flan-t5-base",
    "BART CNN SamSum (philschmid/bart-large-cnn-samsum)": "philschmid/bart-large-cnn-samsum",
    # NOTE(review): despite the "T5" label this checkpoint is a Pegasus model.
    # The label is kept as-is so existing selections / API calls still match.
    "T5 SamSum (knkarthick/pegasus-samsum)": "knkarthick/pegasus-samsum",
    "LongT5 Base (google/long-t5-tglobal-base)": "google/long-t5-tglobal-base",
    "T5 Small (t5-small)": "t5-small",
    # NOTE(review): appears to be the pretrained (not summarization-fine-tuned)
    # mBART checkpoint - confirm it produces useful summaries.
    "MBART (facebook/mbart-large-cc25)": "facebook/mbart-large-cc25",
    # NOTE(review): English->Romanian translation model, not a summarizer.
    "MarianMT (Helsinki-NLP/opus-mt-en-ro)": "Helsinki-NLP/opus-mt-en-ro",
    "BART ELI5 (yjernite/bart_eli5)": "yjernite/bart_eli5"
}

# Checkpoint id -> (tokenizer, model); filled lazily by load_model().
model_cache = {}

def clean_text(input_text, stopword_set=None):
    """Remove punctuation and stopwords from *input_text*.

    Every character that is not a letter, digit or whitespace is replaced by a
    space; the result is split on whitespace and words whose lowercase form is
    in the stopword set are dropped.

    Args:
        input_text: The text to clean.
        stopword_set: Lowercase words to drop. Defaults to the module-level
            NLTK English ``stop_words`` set.

    Returns:
        The remaining words joined by single spaces ("" if none remain).
    """
    if stopword_set is None:
        stopword_set = stop_words
    # ``\w`` is Unicode-aware for str patterns, so accented and non-Latin
    # letters survive (the previous ASCII-only class turned "café" into "caf").
    # ``\w`` also matches "_", which is stripped explicitly so it is still
    # treated as punctuation.
    cleaned_text = re.sub(r"[^\w\s]|_", " ", input_text)
    return " ".join(
        word for word in cleaned_text.split() if word.lower() not in stopword_set
    )

def load_model(model_name):
    """Return ``(tokenizer, model)`` for *model_name*, loading and caching it on first request.

    On a cache miss the tokenizer and seq2seq model are fetched with
    ``from_pretrained``. The model weights are requested in float16 when CUDA
    is available (float32 otherwise), the model is moved to that device, the
    pair is stored in ``model_cache``, and a tiny dummy prompt is generated
    once as a warm-up before returning.
    """
    if model_name in model_cache:
        return model_cache[model_name]

    tokenizer = AutoTokenizer.from_pretrained(model_name)
    dtype = torch.float16 if torch.cuda.is_available() else torch.float32
    model = AutoModelForSeq2SeqLM.from_pretrained(model_name, torch_dtype=dtype)

    target = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    model = model.to(target)
    model_cache[model_name] = (tokenizer, model)

    # Warm-up: a single short generate() call on a dummy prompt right after loading.
    warmup_ids = tokenizer("summarize: hello world", return_tensors="pt").input_ids
    model.generate(warmup_ids.to(target), max_length=10)

    return model_cache[model_name]

@GPU  # 👈 Required for ZeroGPU to allocate GPU when this is called
def summarize_text(input_text, model_label, char_limit):
    """Summarize *input_text* with the model selected by *model_label*.

    The text is first passed through ``clean_text`` (punctuation and English
    stopwords removed), prefixed with ``"summarize: "`` when the checkpoint id
    contains "t5" or "flan", and decoded greedily with ``min_length=5`` and
    ``max_length=30`` tokens.

    Args:
        input_text: Raw text from the UI textbox (may be empty or None).
        model_label: A key of ``model_choices`` (the dropdown label).
        char_limit: Maximum number of characters to return. Gradio's Slider
            passes its value as a number that may be a float, so it is
            coerced to ``int`` before slicing.

    Returns:
        The decoded summary cut to ``char_limit`` characters and stripped, or a
        short message string when there is nothing to summarize.
    """
    if not input_text or not input_text.strip():
        return "Please enter some text."

    input_text = clean_text(input_text)
    # NOTE(review): removing stopwords/punctuation feeds ungrammatical text to
    # abstractive models and likely lowers summary quality - consider dropping.
    if not input_text:
        # Everything was punctuation or stopwords; generating from an empty
        # prompt would only produce unrelated text.
        return "Nothing left to summarize after removing punctuation and stopwords."

    model_name = model_choices[model_label]
    tokenizer, model = load_model(model_name)

    # T5-family checkpoints (ids containing "t5"/"flan") get the task prefix.
    if "t5" in model_name.lower() or "flan" in model_name.lower():
        input_text = "summarize: " + input_text

    device = model.device
    inputs = tokenizer(input_text, return_tensors="pt", truncation=True)
    input_ids = inputs["input_ids"].to(device)

    summary_ids = model.generate(
        input_ids,
        max_length=30,
        min_length=5,
        do_sample=False
    )

    summary = tokenizer.decode(summary_ids[0], skip_special_tokens=True)
    # Slicing requires an int; the slider value may arrive as a float (e.g. 65.0).
    return summary[:int(char_limit)].strip()

# Gradio UI: the three input components map positionally onto
# summarize_text(input_text, model_label, char_limit).
iface = gr.Interface(
    fn=summarize_text,
    inputs=[
        gr.Textbox(lines=6, label="Enter text to summarize"),
        gr.Dropdown(choices=list(model_choices.keys()), label="Choose summarization model", value="Pegasus (google/pegasus-xsum)"),
        gr.Slider(minimum=30, maximum=200, value=65, step=1, label="Max Character Limit")
    ],
    outputs=gr.Textbox(lines=3, label="Summary (truncated to character limit)"),
    title="Multi-Model Text Summarizer (GPU Ready)",
    description="Summarize long or short texts using state-of-the-art Hugging Face models with GPU acceleration (ZeroGPU-compatible)."
)

# Start the web server only when this file is executed as a script, so that
# importing the module (tests, tooling) does not launch the app.
if __name__ == "__main__":
    iface.launch()