madankn79 committed
Commit 020246b
Parent(s): d5a7621
Files changed (1)
  1. app.py +28 -51
app.py CHANGED
@@ -2,66 +2,49 @@ import gradio as gr
 import re
 import torch
 from transformers import AutoTokenizer, AutoModelForSeq2SeqLM
-from spaces import GPU  # Required for ZeroGPU Spaces
 from nltk.corpus import stopwords
+from spaces import GPU  # Required for ZeroGPU Spaces
 import nltk
 
-# Download NLTK stopwords (only the first time you run)
-nltk.download('stopwords')
-stop_words = set(stopwords.words('english'))
+# Download stopwords if not already available
+nltk.download("stopwords")
+stop_words = set(stopwords.words("english"))
 
-# Best lightweight summarization models
+# Model list
 model_choices = {
     "DistilBART CNN (sshleifer/distilbart-cnn-12-6)": "sshleifer/distilbart-cnn-12-6",
     "T5 Small (t5-small)": "t5-small",
     "T5 Base (t5-base)": "t5-base",
-    "Flan-T5 Base (google/flan-t5-base)": "google/flan-t5-base",
-    "DistilBART XSum (sshleifer/distilbart-xsum-12-6)": "sshleifer/distilbart-xsum-12-6"
+    "Pegasus XSum (google/pegasus-xsum)": "google/pegasus-xsum",
+    "BART CNN (facebook/bart-large-cnn)": "facebook/bart-large-cnn",
 }
 
 model_cache = {}
 
-# Clean input text (remove stopwords, SKU codes, and non-meaningful text)
+# Clean text: remove special characters and stop words
 def clean_text(input_text):
-    # Step 1: Remove any non-English characters (like special symbols, non-latin characters)
-    cleaned_text = re.sub(r'[^A-Za-z0-9\s]', ' ', input_text)  # Allow only letters and numbers
-    cleaned_text = re.sub(r'\s+', ' ', cleaned_text)  # Replace multiple spaces with a single space
-
-    # Step 2: Tokenize the text and remove stopwords and words that are too short to be meaningful
-    words = cleaned_text.split()
-    filtered_words = [word for word in words if word.lower() not in stop_words and len(word) > 2]
-
-    # Step 3: Rebuild the text from the remaining words
-    filtered_text = " ".join(filtered_words)
-
-    # Step 4: Remove any product codes or sequences (e.g., ST1642, AB1234)
-    # Assuming product codes follow a pattern of letters followed by numbers
-    filtered_text = re.sub(r'\b[A-Za-z]{2,}[0-9]{3,}\b', '', filtered_text)  # SKU/product code pattern
-
-    # Strip leading/trailing spaces
-    filtered_text = filtered_text.strip()
-
-    return filtered_text
+    cleaned = re.sub(r"[^A-Za-z0-9\s]", " ", input_text)
+    words = cleaned.split()
+    words = [word for word in words if word.lower() not in stop_words]
+    return " ".join(words).strip()
 
 # Load model and tokenizer
 def load_model(model_name):
     if model_name not in model_cache:
         tokenizer = AutoTokenizer.from_pretrained(model_name)
         model = AutoModelForSeq2SeqLM.from_pretrained(
-            model_name,
-            torch_dtype=torch.float16 if torch.cuda.is_available() else torch.float32
+            model_name, torch_dtype=torch.float16 if torch.cuda.is_available() else torch.float32
         )
-        device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
-        model.to(device)
+        model.to("cuda" if torch.cuda.is_available() else "cpu")
+        model_cache[model_name] = (tokenizer, model)
 
-        # Warm-up
-        dummy_input = tokenizer("summarize: warm up", return_tensors="pt").input_ids.to(device)
+        # Warm up
+        dummy_input = tokenizer("summarize: warmup", return_tensors="pt").input_ids.to(model.device)
         model.generate(dummy_input, max_length=10)
-
-        model_cache[model_name] = (tokenizer, model)
     return model_cache[model_name]
 
-# Summarize the text using a selected model
+# Main function triggered by Gradio
+@GPU  # 👈 Required for ZeroGPU to trigger GPU spin-up
 def summarize_text(input_text, model_label, char_limit):
     if not input_text.strip():
         return "Please enter some text."
@@ -70,26 +53,20 @@ def summarize_text(input_text, model_label, char_limit):
     model_name = model_choices[model_label]
     tokenizer, model = load_model(model_name)
 
-    if "t5" in model_name.lower() or "flan" in model_name.lower():
+    # Prefix for T5/FLAN-style models
+    if "t5" in model_name.lower():
         input_text = "summarize: " + input_text
 
-    device = model.device
     inputs = tokenizer(input_text, return_tensors="pt", truncation=True)
-    input_ids = inputs["input_ids"].to(device)
-
-    # Adjust the length constraints to make sure min_length < max_length
-    max_len = 20  # Set your desired max length
-    min_len = 5  # Ensure min_length is smaller than max_length
+    input_ids = inputs["input_ids"].to(model.device)
 
-    # Generate summary
     summary_ids = model.generate(
         input_ids,
-        max_length=max_len,
-        min_length=min_len,
+        max_length=50,
+        min_length=10,
        do_sample=False
     )
 
-    # Decode the summary
     summary = tokenizer.decode(summary_ids[0], skip_special_tokens=True)
     return summary[:char_limit].strip()
 
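Worth noting about the generate() call above: max_length and min_length count tokens, while the slider's char_limit counts characters, so the final summary[:char_limit] can clip mid-word. A hedged sketch of a softer cut (truncate_at_word is a hypothetical helper, not part of this commit):

    def truncate_at_word(text, char_limit):
        # Hypothetical helper: cut at the last word boundary inside the
        # limit so the character cap does not slice a word in half.
        if len(text) <= char_limit:
            return text
        return text[:char_limit].rsplit(" ", 1)[0].strip()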
@@ -99,11 +76,11 @@ iface = gr.Interface(
     inputs=[
         gr.Textbox(lines=6, label="Enter text to summarize"),
         gr.Dropdown(choices=list(model_choices.keys()), label="Choose summarization model", value="DistilBART CNN (sshleifer/distilbart-cnn-12-6)"),
-        gr.Slider(minimum=30, maximum=200, value=65, step=1, label="Max Character Limit")
+        gr.Slider(minimum=30, maximum=200, value=80, step=1, label="Max Character Limit")
     ],
     outputs=gr.Textbox(lines=3, label="Summary (truncated to character limit)"),
-    title="🚀 Fast Lightweight Summarizer (GPU Optimized)",
-    description="Summarize text quickly using compact models ideal for low-latency and ZeroGPU Spaces."
+    title="🔥 Fast Summarizer (GPU-Optimized)",
+    description="Summarizes input using Hugging Face models with ZeroGPU support. Now faster with CUDA, float16, and warm start!"
 )
 
-iface.launch(ssr_mode=False)
+iface.launch()
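
The functional core of the commit is the @GPU decorator now applied to summarize_text: on ZeroGPU Spaces, the spaces package attaches a GPU only while a decorated function runs. A minimal sketch of the pattern (placeholder body; the import only resolves inside a Hugging Face Space):

    from spaces import GPU  # provided by the ZeroGPU runtime on Spaces

    @GPU  # a GPU is attached for each call and released afterwards
    def run_on_gpu(prompt):
        # All CUDA work (model.to("cuda"), generate, ...) belongs here,
        # since the GPU only exists while this function executes.
        ...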
 
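For reference, the simplified clean_text above now only strips punctuation and stopwords; unlike the removed Step 4 regex, it no longer drops SKU-style codes such as ST1642. A quick usage sketch (the sample sentence is illustrative):

    >>> clean_text("The model #ST1642 is on sale, and it works!")
    'model ST1642 sale works'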