import gradio as gr
import re
import torch
from transformers import AutoTokenizer, AutoModelForSeq2SeqLM
from nltk.corpus import stopwords
from spaces import GPU # Required for ZeroGPU Spaces
import nltk
# Download stopwords if not already available
nltk.download("stopwords")
stop_words = set(stopwords.words("english"))
# Model list
model_choices = {
"DistilBART CNN (sshleifer/distilbart-cnn-12-6)": "sshleifer/distilbart-cnn-12-6",
"T5 Small (t5-small)": "t5-small",
"T5 Base (t5-base)": "t5-base",
"Pegasus XSum (google/pegasus-xsum)": "google/pegasus-xsum",
"BART CNN (facebook/bart-large-cnn)": "facebook/bart-large-cnn",
}
model_cache = {}
# Clean text: remove special characters and stop words
def clean_text(input_text):
    cleaned = re.sub(r"[^A-Za-z0-9\s]", " ", input_text)
    words = cleaned.split()
    words = [word for word in words if word.lower() not in stop_words]
    return " ".join(words).strip()
# Load model and tokenizer
def load_model(model_name):
    if model_name not in model_cache:
        tokenizer = AutoTokenizer.from_pretrained(model_name)
        model = AutoModelForSeq2SeqLM.from_pretrained(
            model_name,
            torch_dtype=torch.float16 if torch.cuda.is_available() else torch.float32,
        )
        model.to("cuda" if torch.cuda.is_available() else "cpu")
        model_cache[model_name] = (tokenizer, model)

        # Warm up: run one tiny generation so the first real request is not slowed by lazy initialization
        dummy_input = tokenizer("summarize: warmup", return_tensors="pt").input_ids.to(model.device)
        model.generate(dummy_input, max_length=10)
    return model_cache[model_name]
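# Note: the first request for each model pays the download + warm-up cost;
# later requests reuse the cached (tokenizer, model) pair from model_cache.
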
# Main function triggered by Gradio
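# On ZeroGPU Spaces the spaces.GPU decorator requests a GPU only for the duration
# of the decorated call; on non-ZeroGPU hardware it should behave as a no-op.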
@GPU # 👈 Required for ZeroGPU to trigger GPU spin-up
def summarize_text(input_text, model_label, char_limit):
    if not input_text.strip():
        return "Please enter some text."

    input_text = clean_text(input_text)
    model_name = model_choices[model_label]
    tokenizer, model = load_model(model_name)

    # T5-style models expect a task prefix
    if "t5" in model_name.lower():
        input_text = "summarize: " + input_text

    inputs = tokenizer(input_text, return_tensors="pt", truncation=True)
    input_ids = inputs["input_ids"].to(model.device)

    # Generation parameters
    summary_ids = model.generate(
        input_ids,
        max_length=30,           # Cap the summary at ~30 tokens to keep it short
        min_length=15,           # Ensure the summary is not too short
        do_sample=False,         # Deterministic decoding; no sampling
        num_beams=5,             # Beam search for the most likely token sequence
        early_stopping=True,     # Stop once the beams have finished
        no_repeat_ngram_size=2   # Prevent repeated bigrams
    )

    summary = tokenizer.decode(summary_ids[0], skip_special_tokens=True)
    return summary[:char_limit].strip()
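# Example call (hypothetical input text):
#   summarize_text("Hugging Face hosts thousands of open models ...",
#                  "T5 Small (t5-small)", 65)
# returns a beam-searched summary truncated to at most 65 characters.
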
# Gradio UI
iface = gr.Interface(
    fn=summarize_text,
    inputs=[
        gr.Textbox(lines=6, label="Enter text to summarize"),
        gr.Dropdown(
            choices=list(model_choices.keys()),
            label="Choose summarization model",
            value="DistilBART CNN (sshleifer/distilbart-cnn-12-6)",
        ),
        gr.Slider(minimum=30, maximum=200, value=65, step=1, label="Max Character Limit"),
    ],
    outputs=gr.Textbox(lines=3, label="Summary (truncated to character limit)"),
    title="🔥 Fast Summarizer (GPU-Optimized)",
    description="Summarizes input using Hugging Face models with ZeroGPU support. Now faster with CUDA, float16, and warm start!",
)
iface.launch()
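# When run locally (outside Spaces), Gradio serves the app at http://127.0.0.1:7860 by default.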