import gradio as gr
import re
import torch
from transformers import AutoTokenizer, AutoModelForSeq2SeqLM
from nltk.corpus import stopwords
from spaces import GPU
import nltk

# Download the NLTK stopword list if it is not already present
nltk.download('stopwords', quiet=True)
stop_words = set(stopwords.words('english'))
# Model choices
model_choices = {
"Pegasus (google/pegasus-xsum)": "google/pegasus-xsum",
"BigBird-Pegasus (google/bigbird-pegasus-large-arxiv)": "google/bigbird-pegasus-large-arxiv",
"LongT5 Large (google/long-t5-tglobal-large)": "google/long-t5-tglobal-large",
"BART Large CNN (facebook/bart-large-cnn)": "facebook/bart-large-cnn",
"ProphetNet (microsoft/prophetnet-large-uncased-cnndm)": "microsoft/prophetnet-large-uncased-cnndm",
"LED (allenai/led-base-16384)": "allenai/led-base-16384",
"T5 Large (t5-large)": "t5-large",
"Flan-T5 Large (google/flan-t5-large)": "google/flan-t5-large",
"DistilBART CNN (sshleifer/distilbart-cnn-12-6)": "sshleifer/distilbart-cnn-12-6",
"DistilBART XSum (mrm8488/distilbart-xsum-12-6)": "mrm8488/distilbart-xsum-12-6",
"T5 Base (t5-base)": "t5-base",
"Flan-T5 Base (google/flan-t5-base)": "google/flan-t5-base",
"BART CNN SamSum (philschmid/bart-large-cnn-samsum)": "philschmid/bart-large-cnn-samsum",
"T5 SamSum (knkarthick/pegasus-samsum)": "knkarthick/pegasus-samsum",
"LongT5 Base (google/long-t5-tglobal-base)": "google/long-t5-tglobal-base",
"T5 Small (t5-small)": "t5-small",
"MBART (facebook/mbart-large-cc25)": "facebook/mbart-large-cc25",
"MarianMT (Helsinki-NLP/opus-mt-en-ro)": "Helsinki-NLP/opus-mt-en-ro",
"Falcon Instruct (tiiuae/falcon-7b-instruct)": "tiiuae/falcon-7b-instruct",
"BART ELI5 (yjernite/bart_eli5)": "yjernite/bart_eli5"
}
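
# Cache of (tokenizer, model) pairs so each checkpoint is loaded only once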
model_cache = {}
def clean_text(input_text):
    # Replace punctuation with spaces and drop English stopwords
    cleaned_text = re.sub(r'[^A-Za-z0-9\s]', ' ', input_text)
    words = [word for word in cleaned_text.split() if word.lower() not in stop_words]
    return " ".join(words).strip()
def load_model(model_name):
    if model_name not in model_cache:
        tokenizer = AutoTokenizer.from_pretrained(model_name)
        # Half precision on GPU to cut memory use; full precision on CPU
        model = AutoModelForSeq2SeqLM.from_pretrained(
            model_name,
            torch_dtype=torch.float16 if torch.cuda.is_available() else torch.float32
        )
        device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        model = model.to(device)
        model_cache[model_name] = (tokenizer, model)

        # Warm up with a dummy input so the first real request does not pay
        # the one-time CUDA kernel initialization cost
        dummy_input = tokenizer("summarize: hello world", return_tensors="pt").input_ids.to(device)
        model.generate(dummy_input, max_length=10)
    return model_cache[model_name]
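
# Inference entry point: cleans the input, loads the selected model, and
# generates a short summary truncated to the requested character count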
@GPU # 👈 Required for ZeroGPU to allocate GPU when this is called
def summarize_text(input_text, model_label, char_limit):
    if not input_text.strip():
        return "Please enter some text."

    input_text = clean_text(input_text)
    model_name = model_choices[model_label]
    tokenizer, model = load_model(model_name)

    # T5-family checkpoints (t5, flan-t5, long-t5) expect a task prefix
    if "t5" in model_name.lower():
        input_text = "summarize: " + input_text

    device = model.device
    inputs = tokenizer(input_text, return_tensors="pt", truncation=True)
    input_ids = inputs["input_ids"].to(device)

    # Scale generation length to the slider (roughly 4 characters per token);
    # the decoded text is truncated to the exact character limit below
    summary_ids = model.generate(
        input_ids,
        max_length=max(10, int(char_limit) // 4),
        min_length=5,
        do_sample=False  # greedy decoding for deterministic output
    )
    summary = tokenizer.decode(summary_ids[0], skip_special_tokens=True)
    return summary[:char_limit].strip()
# Gradio UI
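# The three inputs map positionally onto summarize_text's parameters
# (input_text, model_label, char_limit)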
iface = gr.Interface(
fn=summarize_text,
inputs=[
gr.Textbox(lines=6, label="Enter text to summarize"),
gr.Dropdown(choices=list(model_choices.keys()), label="Choose summarization model", value="Pegasus (google/pegasus-xsum)"),
gr.Slider(minimum=30, maximum=200, value=80, step=1, label="Max Character Limit")
],
outputs=gr.Textbox(lines=3, label="Summary (truncated to character limit)"),
title="Multi-Model Text Summarizer (GPU Ready)",
description="Summarize long or short texts using state-of-the-art Hugging Face models with GPU acceleration (ZeroGPU-compatible)."
)
iface.launch()