pendar02 committed on
Commit
a12289f
·
verified ·
1 Parent(s): cfd2959

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +79 -18
app.py CHANGED
@@ -150,46 +150,107 @@ def post_process_summary(summary):
150
  return cleaned_summary
151
 
152
  def improve_summary_generation(text, model, tokenizer):
153
- """Enhanced version of generate_summary with better parameters and post-processing"""
154
  if not isinstance(text, str) or not text.strip():
155
  return "No abstract available to summarize."
156
 
 
157
  word_count = len(text.split())
158
- if word_count < 50:
159
  return text
160
 
 
161
  formatted_text = preprocess_text(text)
162
 
163
- # Adjust generation parameters for better coherence
164
- inputs = tokenizer(formatted_text, return_tensors="pt", max_length=1024, truncation=True)
 
 
 
 
 
 
165
  inputs = {k: v.to(model.device) for k, v in inputs.items()}
166
 
 
167
  with torch.no_grad():
168
  summary_ids = model.generate(
169
  **{
170
  "input_ids": inputs["input_ids"],
171
  "attention_mask": inputs["attention_mask"],
172
- "max_length": min(200, word_count + 50),
173
- "min_length": min(50, word_count),
174
- "num_beams": 5, # Increased from 4
175
- "length_penalty": 1.5, # Adjusted from 2.0
176
- "early_stopping": True,
177
  "no_repeat_ngram_size": 3,
178
- "temperature": 0.7, # Added temperature for better diversity
179
- "top_p": 0.9, # Added top_p sampling
180
- "repetition_penalty": 1.2 # Added repetition penalty
 
 
181
  }
182
  )
183
 
184
  summary = tokenizer.decode(summary_ids[0], skip_special_tokens=True)
185
 
186
- # Apply post-processing
187
- summary = post_process_summary(summary)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
188
 
189
- # Check if summary is too similar to original
190
- if summary.lower() == text.lower() or len(summary.split()) / word_count > 0.9:
191
- return text
192
-
193
  return summary
194
 
195
  def generate_focused_summary(question, abstracts, model, tokenizer):
 
150
  return cleaned_summary
151
 
152
  def improve_summary_generation(text, model, tokenizer):
153
+ """Enhanced version of summary generation optimized for biomedical papers"""
154
  if not isinstance(text, str) or not text.strip():
155
  return "No abstract available to summarize."
156
 
157
+ # Don't summarize if text is too short
158
  word_count = len(text.split())
159
+ if word_count < 100: # Increased minimum length for medical texts
160
  return text
161
 
162
+ # Preprocess text
163
  formatted_text = preprocess_text(text)
164
 
165
+ # Prepare inputs
166
+ inputs = tokenizer(
167
+ formatted_text,
168
+ return_tensors="pt",
169
+ max_length=1024,
170
+ truncation=True,
171
+ padding=True
172
+ )
173
  inputs = {k: v.to(model.device) for k, v in inputs.items()}
174
 
175
+ # Generate summary with parameters tuned for biomedical text
176
  with torch.no_grad():
177
  summary_ids = model.generate(
178
  **{
179
  "input_ids": inputs["input_ids"],
180
  "attention_mask": inputs["attention_mask"],
181
+ "max_length": 300, # Increased for medical summaries
182
+ "min_length": 100, # Increased to ensure comprehensive coverage
183
+ "num_beams": 4,
184
+ "length_penalty": 2.0, # Encourage slightly longer summaries
 
185
  "no_repeat_ngram_size": 3,
186
+ "early_stopping": True,
187
+ "do_sample": True, # Enable sampling
188
+ "top_p": 0.95, # Nucleus sampling
189
+ "temperature": 0.85, # Slightly higher temperature for medical terms
190
+ "repetition_penalty": 1.5 # Increased to avoid repeated stats/numbers
191
  }
192
  )
193
 
194
  summary = tokenizer.decode(summary_ids[0], skip_special_tokens=True)
195
 
196
+ # Enhanced post-processing for medical text
197
+ summary = post_process_medical_summary(summary)
198
+
199
+ return summary
200
+
201
+ def post_process_medical_summary(summary):
202
+ """Special post-processing for medical/scientific summaries"""
203
+ if not summary:
204
+ return summary
205
+
206
+ # Fix common medical text issues
207
+ summary = (summary
208
+ .replace(" p =", " p=") # Fix p-value spacing
209
+ .replace(" n =", " n=") # Fix sample size spacing
210
+ .replace("( ", "(") # Fix parentheses spacing
211
+ .replace(" )", ")")
212
+ .replace("vs.", "versus") # Expand abbreviations
213
+ .replace("..", ".") # Fix double periods
214
+ )
215
+
216
+ # Ensure statistical significance symbols are correct
217
+ summary = (summary
218
+ .replace("p < ", "p<")
219
+ .replace("p > ", "p>")
220
+ .replace("P < ", "p<")
221
+ .replace("P > ", "p>")
222
+ )
223
+
224
+ # Fix number formatting
225
+ summary = (summary
226
+ .replace(" +/- ", "±")
227
+ .replace(" ± ", "±")
228
+ )
229
+
230
+ # Split into sentences and process each
231
+ sentences = [s.strip() for s in summary.split('.')]
232
+ processed_sentences = []
233
+
234
+ for sentence in sentences:
235
+ if sentence:
236
+ # Capitalize first letter
237
+ sentence = sentence[0].upper() + sentence[1:] if sentence else sentence
238
+
239
+ # Fix common medical abbreviations spacing
240
+ sentence = (sentence
241
+ .replace(" et al ", " et al. ")
242
+ .replace("et al.", "et al.") # Fix double period
243
+ )
244
+
245
+ processed_sentences.append(sentence)
246
+
247
+ # Join sentences
248
+ summary = '. '.join(processed_sentences)
249
+
250
+ # Ensure proper ending
251
+ if summary and not summary.endswith('.'):
252
+ summary += '.'
253
 
 
 
 
 
254
  return summary
255
 
256
  def generate_focused_summary(question, abstracts, model, tokenizer):