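"""Streamlit app for analyzing biomedical papers.

Takes an uploaded Excel file of papers, generates a summary for each abstract
with a LoRA-adapted BART model, and optionally produces a question-focused
summary over the most relevant abstracts with a LoRA-adapted BioBART model.
All inference runs on CPU.
"""
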
import gc
import time
from pathlib import Path

import pandas as pd
import streamlit as st
import torch
from peft import PeftModel
from transformers import AutoTokenizer, AutoModelForSeq2SeqLM

from text_processing import TextProcessor

# Configure page
st.set_page_config(
    page_title="Biomedical Papers Analysis",
    page_icon="🔬",
    layout="wide"
)

# Initialize session state
if 'processed_data' not in st.session_state:
    st.session_state.processed_data = None
if 'summaries' not in st.session_state:
    st.session_state.summaries = None
if 'text_processor' not in st.session_state:
    # Instantiate up front: find_most_relevant_abstracts() is called on this
    # object later, so leaving it as None would raise an AttributeError.
    # (Assumes TextProcessor takes no constructor arguments.)
    st.session_state.text_processor = TextProcessor()

def manage_resources():
    """Clear memory and ensure resources are available"""
    # Force garbage collection
    gc.collect()
    # Clear CUDA cache if available
    if torch.cuda.is_available():
        torch.cuda.empty_cache()
    # Cap the number of CPU threads torch may use, leaving headroom for the
    # rest of the app (tune this to the host machine's core count)
    torch.set_num_threads(8)
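
# Model/adapter pairs loaded below (Hub base model + PEFT LoRA adapter):
#   "summarize"        -> facebook/bart-large-cnn + pendar02/results
#   "question_focused" -> GanjinZero/biobart-base + pendar02/biobart-finetune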
def load_model(model_type):
    """Load appropriate model based on type with resource management"""
    manage_resources()
    try:
        if model_type == "summarize":
            base_model = AutoModelForSeq2SeqLM.from_pretrained(
                "facebook/bart-large-cnn",
                cache_dir="./models",
                device_map=None,  # explicitly None: we place the model on CPU below
                torch_dtype=torch.float32
            ).to("cpu")  # force CPU
            model = PeftModel.from_pretrained(
                base_model,
                "pendar02/results",
                device_map=None,
                torch_dtype=torch.float32,
                is_trainable=False  # inference mode
            ).to("cpu")
            tokenizer = AutoTokenizer.from_pretrained(
                "facebook/bart-large-cnn",
                cache_dir="./models"
            )
        else:  # question_focused
            base_model = AutoModelForSeq2SeqLM.from_pretrained(
                "GanjinZero/biobart-base",
                cache_dir="./models",
                device_map=None,
                torch_dtype=torch.float32
            ).to("cpu")
            model = PeftModel.from_pretrained(
                base_model,
                "pendar02/biobart-finetune",
                device_map=None,
                torch_dtype=torch.float32,
                is_trainable=False  # inference mode
            ).to("cpu")
            tokenizer = AutoTokenizer.from_pretrained(
                "GanjinZero/biobart-base",
                cache_dir="./models"
            )
        model.eval()  # evaluation mode (disables dropout)
        return model, tokenizer
    except Exception as e:
        st.error(f"Error loading model: {str(e)}")
        raise

def process_excel(uploaded_file):
    """Process uploaded Excel file"""
    try:
        df = pd.read_excel(uploaded_file)
        required_columns = ['Abstract', 'Article Title', 'Authors',
                            'Source Title', 'Publication Year', 'DOI']
        # Check required columns
        missing_columns = [col for col in required_columns if col not in df.columns]
        if missing_columns:
            st.error(f"Missing required columns: {', '.join(missing_columns)}")
            return None
        return df[required_columns]
    except Exception as e:
        st.error(f"Error processing file: {str(e)}")
        return None

def preprocess_text(text):
    """Preprocess text to add appropriate formatting before summarization"""
    if not isinstance(text, str) or not text.strip():
        return text
    # Split text into sentences (basic implementation: split on '. ')
    sentences = [s.strip() for s in text.replace('. ', '.\n').split('\n')]
    # Remove empty sentences
    sentences = [s for s in sentences if s]
    # Join with line breaks, one sentence per line
    return '\n'.join(sentences)
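
# Example of the formatting above:
#   preprocess_text("Background. We studied X. Results were mixed.")
#   -> "Background.\nWe studied X.\nResults were mixed."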

def generate_summary(text, model, tokenizer):
    """Generate summary for single abstract"""
    if not isinstance(text, str) or not text.strip():
        return "No abstract available to summarize."
    # Preprocess the text first
    formatted_text = preprocess_text(text)
    inputs = tokenizer(formatted_text, return_tensors="pt", max_length=1024, truncation=True)
    with torch.no_grad():
        summary_ids = model.generate(
            input_ids=inputs["input_ids"],
            attention_mask=inputs["attention_mask"],
            max_length=150,
            min_length=50,
            num_beams=4,
            length_penalty=2.0,
            early_stopping=True
        )
    return tokenizer.decode(summary_ids[0], skip_special_tokens=True)

def generate_focused_summary(question, abstracts, model, tokenizer):
    """Generate focused summary based on question"""
    # Preprocess each abstract
    formatted_abstracts = [preprocess_text(abstract) for abstract in abstracts]
    combined_input = f"Question: {question} Abstracts: " + " [SEP] ".join(formatted_abstracts)
    inputs = tokenizer(combined_input, return_tensors="pt", max_length=1024, truncation=True)
    with torch.no_grad():
        summary_ids = model.generate(
            input_ids=inputs["input_ids"],
            attention_mask=inputs["attention_mask"],
            max_length=200,
            min_length=50,
            num_beams=4,
            length_penalty=2.0,
            early_stopping=True
        )
    return tokenizer.decode(summary_ids[0], skip_special_tokens=True)
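
# Note: the combined question + abstracts string is truncated to 1024 tokens
# by the tokenizer above, so with several long abstracts only the first few
# may actually reach the model.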

def main():
    st.title("🔬 Biomedical Papers Analysis")

    # File upload section
    uploaded_file = st.file_uploader(
        "Upload Excel file containing papers",
        type=['xlsx', 'xls'],
        help="File must contain: Abstract, Article Title, Authors, Source Title, Publication Year, DOI"
    )

    if uploaded_file is not None:
        # Process Excel file
        if st.session_state.processed_data is None:
            with st.spinner("Processing file..."):
                df = process_excel(uploaded_file)
                if df is not None:
                    st.session_state.processed_data = df.dropna(subset=["Abstract"])

        if st.session_state.processed_data is not None:
            df = st.session_state.processed_data
            st.write(f"📊 Loaded {len(df)} papers")

            # Individual Summaries Section
            st.header("📝 Individual Paper Summaries")

            # Question input before the unified generate button
            st.header("❓ Question-focused Summary (Optional)")
            question = st.text_input("Enter your research question (optional):")

            # Unified generate button
            if st.button("Generate Analysis"):
                try:
                    # Step 1: Generate Individual Summaries
                    if st.session_state.summaries is None:
                        with st.spinner("Generating individual summaries..."):
                            model, tokenizer = load_model("summarize")
                            progress_text = st.empty()
                            progress_bar = st.progress(0)
                            summary_display = st.container()
                            summaries = []
                            for i, (_, row) in enumerate(df.iterrows()):
                                progress_text.text(f"Processing paper {i+1} of {len(df)}")
                                progress_bar.progress((i + 1) / len(df))
                                summary = generate_summary(row['Abstract'], model, tokenizer)
                                summaries.append(summary)
                                with summary_display:
                                    st.write(f"**Paper {i+1}:** {row['Article Title']}")
                                    st.write(summary)
                                    st.divider()
                            st.session_state.summaries = summaries
                            # Clear memory after individual summaries
                            del model
                            del tokenizer
                            torch.cuda.empty_cache()
                            gc.collect()
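
                    # Step 2 relies on TextProcessor for retrieval; it is assumed
                    # to return a dict with 'top_indices' (positions into the
                    # abstract list) and parallel 'scores', which is how the
                    # results are consumed below.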
                    # Step 2: Generate Question-Focused Summary (only if question is provided)
                    if question.strip():
                        with st.spinner("Generating question-focused summary..."):
                            # Clear memory before question processing
                            torch.cuda.empty_cache()
                            gc.collect()
                            results = st.session_state.text_processor.find_most_relevant_abstracts(
                                question,
                                df['Abstract'].tolist(),
                                top_k=5
                            )
                            model, tokenizer = load_model("question_focused")
                            relevant_abstracts = df['Abstract'].iloc[results['top_indices']].tolist()
                            focused_summary = generate_focused_summary(
                                question,
                                relevant_abstracts,
                                model,
                                tokenizer
                            )
                            st.subheader("Question-Focused Summary")
                            st.write(focused_summary)

                            st.subheader("Most Relevant Papers")
                            # .copy() avoids pandas' chained-assignment warning
                            # when adding columns to this slice below
                            relevant_papers = df.iloc[results['top_indices']][
                                ['Article Title', 'Authors', 'Publication Year', 'DOI']
                            ].copy()
                            relevant_papers['Relevance Score'] = results['scores']
                            relevant_papers['Publication Year'] = relevant_papers['Publication Year'].astype(int)
                            st.dataframe(
                                relevant_papers,
                                column_config={
                                    'Publication Year': st.column_config.NumberColumn('Year', format="%d"),
                                    'Relevance Score': st.column_config.NumberColumn('Relevance', format="%.3f")
                                },
                                hide_index=True
                            )

                            # Clear memory after question processing
                            del model
                            del tokenizer
                            torch.cuda.empty_cache()
                            gc.collect()
                except Exception as e:
                    st.error(f"Error in analysis: {str(e)}")

            # Display sorted summaries if they exist
            if st.session_state.summaries is not None:
                st.subheader("All Paper Summaries")

                sort_options = {
                    'Article Title': 'Article Title',
                    'Authors': 'Authors',
                    'Publication Year': 'Publication Year',
                    'Source Title': 'Source Title'
                }
                col1, col2 = st.columns(2)
                with col1:
                    sort_column = st.selectbox("Sort by:", list(sort_options.keys()))
                with col2:
                    ascending = st.checkbox("Ascending order", True)

                display_df = df.copy()
                display_df['Summary'] = st.session_state.summaries
                display_df['Publication Year'] = display_df['Publication Year'].astype(int)
                sorted_df = display_df.sort_values(by=sort_options[sort_column], ascending=ascending)

                st.dataframe(
                    sorted_df[['Article Title', 'Authors', 'Source Title',
                               'Publication Year', 'DOI', 'Summary']],
                    column_config={
                        'Article Title': st.column_config.TextColumn('Article Title', width='medium'),
                        'Authors': st.column_config.TextColumn('Authors', width='medium'),
                        'Source Title': st.column_config.TextColumn('Source Title', width='medium'),
                        'Publication Year': st.column_config.NumberColumn('Year', format="%d"),
                        'DOI': st.column_config.TextColumn('DOI', width='small'),
                        'Summary': st.column_config.TextColumn('Summary', width='large'),
                    },
                    hide_index=True
                )

if __name__ == "__main__":
    main()
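
# Run locally with:
#   streamlit run app.py
# (assuming this file is saved as app.py, the conventional entry point for
# Streamlit apps and Hugging Face Spaces)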