madankn79 committed on
Commit
d1155a6
·
1 Parent(s): 198fb99
Files changed (1) hide show
  1. app.py +27 -36
app.py CHANGED
@@ -1,13 +1,16 @@
1
  import gradio as gr
2
  import re
 
3
  from transformers import AutoTokenizer, AutoModelForSeq2SeqLM
4
  from nltk.corpus import stopwords
5
-
6
- # Download the NLTK stopwords (only the first time you run)
7
  import nltk
 
 
8
  nltk.download('stopwords')
 
9
 
10
- # Model choices ordered by accuracy
11
  model_choices = {
12
  "Pegasus (google/pegasus-xsum)": "google/pegasus-xsum",
13
  "BigBird-Pegasus (google/bigbird-pegasus-large-arxiv)": "google/bigbird-pegasus-large-arxiv",
@@ -33,66 +36,54 @@ model_choices = {
33
 
34
  model_cache = {}
35
 
36
- # Get NLTK stopwords (common stop words)
37
- stop_words = set(stopwords.words('english'))
38
-
39
- # Function to clean input text by removing unnecessary words like stop words
40
  def clean_text(input_text):
41
- # Replace special characters with a space
42
  cleaned_text = re.sub(r'[^A-Za-z0-9\s]', ' ', input_text)
43
-
44
- # Tokenize the input text and remove stop words
45
  words = cleaned_text.split()
46
  words = [word for word in words if word.lower() not in stop_words]
47
-
48
- # Rebuild the cleaned text
49
- cleaned_text = " ".join(words)
50
-
51
- # Strip leading and trailing spaces
52
- cleaned_text = cleaned_text.strip()
53
-
54
- return cleaned_text
55
 
56
- # Load model and tokenizer
57
  def load_model(model_name):
58
  if model_name not in model_cache:
59
  tokenizer = AutoTokenizer.from_pretrained(model_name)
60
- model = AutoModelForSeq2SeqLM.from_pretrained(model_name)
 
 
 
 
 
61
  model_cache[model_name] = (tokenizer, model)
 
 
 
 
 
62
  return model_cache[model_name]
63
 
64
- # Summarize the text using a selected model
65
  def summarize_text(input_text, model_label, char_limit):
66
  if not input_text.strip():
67
  return "Please enter some text."
68
 
69
- # Clean the input text by removing special characters and stop words
70
  input_text = clean_text(input_text)
71
-
72
  model_name = model_choices[model_label]
73
  tokenizer, model = load_model(model_name)
74
 
75
- # Adjust the input format for T5 and FLAN models
76
  if "t5" in model_name.lower() or "flan" in model_name.lower():
77
  input_text = "summarize: " + input_text
78
 
 
79
  inputs = tokenizer(input_text, return_tensors="pt", truncation=True)
 
80
 
81
  summary_ids = model.generate(
82
- inputs["input_ids"],
83
- max_length=20, # Still approximate; can be tuned per model
84
  min_length=5,
85
  do_sample=False
86
  )
87
 
88
- # Decode the summary while skipping special tokens and cleaning unwanted characters
89
  summary = tokenizer.decode(summary_ids[0], skip_special_tokens=True)
90
-
91
- # Remove unwanted characters like pipes or any unwanted symbols
92
- summary = summary.replace("|", "") # Remove pipes
93
- summary = summary.strip() # Remove leading/trailing whitespace
94
-
95
- return summary[:char_limit] # Enforce character limit
96
 
97
  # Gradio UI
98
  iface = gr.Interface(
@@ -100,11 +91,11 @@ iface = gr.Interface(
100
  inputs=[
101
  gr.Textbox(lines=6, label="Enter text to summarize"),
102
  gr.Dropdown(choices=list(model_choices.keys()), label="Choose summarization model", value="Pegasus (google/pegasus-xsum)"),
103
- gr.Slider(minimum=30, maximum=200, value=65, step=1, label="Max Character Limit")
104
  ],
105
  outputs=gr.Textbox(lines=3, label="Summary (truncated to character limit)"),
106
- title="Multi-Model Text Summarizer",
107
- description="Summarize text using different Hugging Face models with a user-defined character limit."
108
  )
109
 
110
  iface.launch()
 
1
  import gradio as gr
2
  import re
3
+ import torch
4
  from transformers import AutoTokenizer, AutoModelForSeq2SeqLM
5
  from nltk.corpus import stopwords
6
+ from spaces import GPU
 
7
  import nltk
8
+
9
+ # Download the NLTK stopwords corpus if it is not already present
10
  nltk.download('stopwords')
11
+ stop_words = set(stopwords.words('english'))
12
 
13
+ # Model choices
14
  model_choices = {
15
  "Pegasus (google/pegasus-xsum)": "google/pegasus-xsum",
16
  "BigBird-Pegasus (google/bigbird-pegasus-large-arxiv)": "google/bigbird-pegasus-large-arxiv",
 
36
 
37
  model_cache = {}
38
 
 
 
 
 
39
  def clean_text(input_text):
 
40
  cleaned_text = re.sub(r'[^A-Za-z0-9\s]', ' ', input_text)
 
 
41
  words = cleaned_text.split()
42
  words = [word for word in words if word.lower() not in stop_words]
43
+ return " ".join(words).strip()
 
 
 
 
 
 
 
44
 
 
45
  def load_model(model_name):
46
  if model_name not in model_cache:
47
  tokenizer = AutoTokenizer.from_pretrained(model_name)
48
+ model = AutoModelForSeq2SeqLM.from_pretrained(
49
+ model_name,
50
+ torch_dtype=torch.float16 if torch.cuda.is_available() else torch.float32
51
+ )
52
+ device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
53
+ model = model.to(device)
54
  model_cache[model_name] = (tokenizer, model)
55
+
56
+ # Warm up with dummy input
57
+ dummy_input = tokenizer("summarize: hello world", return_tensors="pt").input_ids.to(device)
58
+ model.generate(dummy_input, max_length=10)
59
+
60
  return model_cache[model_name]
61
 
62
@GPU  # Required for ZeroGPU to allocate a GPU while this function runs
def summarize_text(input_text, model_label, char_limit):
    """Summarize ``input_text`` with the model selected in the UI.

    Args:
        input_text: Raw text to summarize.
        model_label: A key of ``model_choices`` identifying the checkpoint.
        char_limit: Maximum number of characters to return. Accepts an int
            or a float (Gradio sliders may deliver floats).

    Returns:
        The generated summary cut to ``char_limit`` characters and stripped
        of surrounding whitespace; or ``"Please enter some text."`` when
        there is nothing to summarize.
    """
    if not input_text.strip():
        return "Please enter some text."

    input_text = clean_text(input_text)
    # Cleaning can remove everything (input made only of punctuation or
    # stop words); do not send an empty prompt to the model.
    if not input_text:
        return "Please enter some text."

    model_name = model_choices[model_label]
    tokenizer, model = load_model(model_name)

    # T5 / FLAN checkpoints are given an explicit task prefix.
    lowered_name = model_name.lower()
    if "t5" in lowered_name or "flan" in lowered_name:
        input_text = "summarize: " + input_text

    # The token ids must be on the same device as the model weights.
    device = model.device
    inputs = tokenizer(input_text, return_tensors="pt", truncation=True)
    input_ids = inputs["input_ids"].to(device)

    summary_ids = model.generate(
        input_ids,
        max_length=30,
        min_length=5,
        do_sample=False
    )

    summary = tokenizer.decode(summary_ids[0], skip_special_tokens=True)

    # A float slider value is not a valid slice bound; coerce it first.
    return summary[:int(char_limit)].strip()
 
 
 
 
 
87
 
88
  # Gradio UI
89
  iface = gr.Interface(
 
91
  inputs=[
92
  gr.Textbox(lines=6, label="Enter text to summarize"),
93
  gr.Dropdown(choices=list(model_choices.keys()), label="Choose summarization model", value="Pegasus (google/pegasus-xsum)"),
94
+ gr.Slider(minimum=30, maximum=200, value=80, step=1, label="Max Character Limit")
95
  ],
96
  outputs=gr.Textbox(lines=3, label="Summary (truncated to character limit)"),
97
+ title="Multi-Model Text Summarizer (GPU Ready)",
98
+ description="Summarize long or short texts using state-of-the-art Hugging Face models with GPU acceleration (ZeroGPU-compatible)."
99
  )
100
 
101
  iface.launch()