madankn79 commited on
Commit
53d5734
·
1 Parent(s): 9388e0e
Files changed (1) hide show
  1. app.py +30 -16
app.py CHANGED
@@ -3,14 +3,13 @@ import re
3
  import torch
4
  from transformers import AutoTokenizer, AutoModelForSeq2SeqLM
5
  from nltk.corpus import stopwords
6
- from spaces import GPU
7
  import nltk
8
 
9
- # Download stopwords if not already
10
  nltk.download('stopwords')
11
  stop_words = set(stopwords.words('english'))
12
 
13
- # Model choices
14
  model_choices = {
15
  "DistilBART CNN (sshleifer/distilbart-cnn-12-6)": "sshleifer/distilbart-cnn-12-6",
16
  "T5 Small (t5-small)": "t5-small",
@@ -21,12 +20,27 @@ model_choices = {
21
 
22
  model_cache = {}
23
 
 
24
  def clean_text(input_text):
25
- cleaned_text = re.sub(r'[^A-Za-z0-9\s]', ' ', input_text)
 
 
 
 
 
 
26
  words = cleaned_text.split()
27
  words = [word for word in words if word.lower() not in stop_words]
28
- return " ".join(words).strip()
 
 
 
 
 
 
 
29
 
 
30
  def load_model(model_name):
31
  if model_name not in model_cache:
32
  tokenizer = AutoTokenizer.from_pretrained(model_name)
@@ -35,16 +49,16 @@ def load_model(model_name):
35
  torch_dtype=torch.float16 if torch.cuda.is_available() else torch.float32
36
  )
37
  device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
38
- model = model.to(device)
39
- model_cache[model_name] = (tokenizer, model)
40
 
41
- # Warm up with dummy input
42
- dummy_input = tokenizer("summarize: hello world", return_tensors="pt").input_ids.to(device)
43
  model.generate(dummy_input, max_length=10)
44
 
 
45
  return model_cache[model_name]
46
 
47
- @GPU # 👈 Required for ZeroGPU to allocate GPU when this is called
48
  def summarize_text(input_text, model_label, char_limit):
49
  if not input_text.strip():
50
  return "Please enter some text."
@@ -62,8 +76,8 @@ def summarize_text(input_text, model_label, char_limit):
62
 
63
  summary_ids = model.generate(
64
  input_ids,
65
- max_length=20,
66
- min_length=5,
67
  do_sample=False
68
  )
69
 
@@ -75,12 +89,12 @@ iface = gr.Interface(
75
  fn=summarize_text,
76
  inputs=[
77
  gr.Textbox(lines=6, label="Enter text to summarize"),
78
- gr.Dropdown(choices=list(model_choices.keys()), label="Choose summarization model", value="Pegasus (google/pegasus-xsum)"),
79
- gr.Slider(minimum=10, maximum=100, value=65, step=1, label="Max Character Limit")
80
  ],
81
  outputs=gr.Textbox(lines=3, label="Summary (truncated to character limit)"),
82
- title="Multi-Model Text Summarizer (GPU Ready)",
83
- description="Summarize long or short texts using state-of-the-art Hugging Face models with GPU acceleration (ZeroGPU-compatible)."
84
  )
85
 
86
  iface.launch()
 
3
  import torch
4
  from transformers import AutoTokenizer, AutoModelForSeq2SeqLM
5
  from nltk.corpus import stopwords
 
6
  import nltk
7
 
8
+ # Download NLTK stopwords
9
  nltk.download('stopwords')
10
  stop_words = set(stopwords.words('english'))
11
 
12
+ # Best lightweight summarization models
13
  model_choices = {
14
  "DistilBART CNN (sshleifer/distilbart-cnn-12-6)": "sshleifer/distilbart-cnn-12-6",
15
  "T5 Small (t5-small)": "t5-small",
 
20
 
21
  model_cache = {}
22
 
23
+ # Clean input text (remove stopwords and SKUs/product codes)
24
  def clean_text(input_text):
25
+ # Remove simple SKU codes (e.g., ST1642, AB1234, etc.)
26
+ cleaned_text = re.sub(r'\b[A-Za-z]{2,}[0-9]{3,}\b', '', input_text) # Alphanumeric SKU
27
+
28
+ # Replace special characters with a space
29
+ cleaned_text = re.sub(r'[^A-Za-z0-9\s]', ' ', cleaned_text)
30
+
31
+ # Tokenize the input text and remove stop words
32
  words = cleaned_text.split()
33
  words = [word for word in words if word.lower() not in stop_words]
34
+
35
+ # Rebuild the cleaned text
36
+ cleaned_text = " ".join(words)
37
+
38
+ # Strip leading and trailing spaces
39
+ cleaned_text = cleaned_text.strip()
40
+
41
+ return cleaned_text
42
 
43
+ # Load model and tokenizer
44
  def load_model(model_name):
45
  if model_name not in model_cache:
46
  tokenizer = AutoTokenizer.from_pretrained(model_name)
 
49
  torch_dtype=torch.float16 if torch.cuda.is_available() else torch.float32
50
  )
51
  device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
52
+ model.to(device)
 
53
 
54
+ # Warm-up
55
+ dummy_input = tokenizer("summarize: warm up", return_tensors="pt").input_ids.to(device)
56
  model.generate(dummy_input, max_length=10)
57
 
58
+ model_cache[model_name] = (tokenizer, model)
59
  return model_cache[model_name]
60
 
61
+ # Summarize the text using a selected model
62
  def summarize_text(input_text, model_label, char_limit):
63
  if not input_text.strip():
64
  return "Please enter some text."
 
76
 
77
  summary_ids = model.generate(
78
  input_ids,
79
+ max_length=30, # Ensure max_length is greater than min_length
80
+ min_length=5, # Ensure min_length is less than max_length
81
  do_sample=False
82
  )
83
 
 
89
  fn=summarize_text,
90
  inputs=[
91
  gr.Textbox(lines=6, label="Enter text to summarize"),
92
+ gr.Dropdown(choices=list(model_choices.keys()), label="Choose summarization model", value="DistilBART CNN (sshleifer/distilbart-cnn-12-6)"),
93
+ gr.Slider(minimum=30, maximum=200, value=80, step=1, label="Max Character Limit")
94
  ],
95
  outputs=gr.Textbox(lines=3, label="Summary (truncated to character limit)"),
96
+ title="🚀 Fast Lightweight Summarizer (GPU Optimized)",
97
+ description="Summarize text quickly using compact models ideal for low-latency and ZeroGPU Spaces."
98
  )
99
 
100
  iface.launch()