# app.py
import gc
import io
import os
import sys
import warnings
from typing import Optional, Tuple

import pandas as pd
import streamlit as st
import torch
from torch.utils.data import DataLoader
from transformers import AutoModelForSeq2SeqLM, AutoTokenizer

# Local imports
sys.path.append(
    os.path.abspath(os.path.join(os.path.dirname(__file__), "task_forward"))
)
from generation_utils import ReactionT5Dataset, decode_output, save_multiple_predictions
from train import preprocess_df
from utils import seed_everything
warnings.filterwarnings("ignore")

# -----------------------------
# Page / Theme / Global Styles
# -----------------------------
# App-wide CSS, injected once as raw HTML: card-like containers, rounded
# selects and buttons, badge pills, and slightly smaller table text.
_GLOBAL_STYLES = """
<style>
/* Base */
.block-container {padding-top: 1.5rem; padding-bottom: 2rem;}
h1, h2, h3 { letter-spacing: .2px; }
.st-emotion-cache-1jicfl2 {padding: 1rem !important;} /* tabs pad (HF class may vary)*/
/* Card container */
.card {
border-radius: 18px;
padding: 1rem 1.2rem;
border: 1px solid rgba(127,127,127,0.15);
background: rgba(250,250,250,0.6);
backdrop-filter: blur(6px);
}
[data-baseweb="select"] div { border-radius: 12px !important; }
/* Buttons */
.stButton>button {
border-radius: 12px;
padding: .6rem 1rem;
font-weight: 600;
}
/* Badges */
.badge {
display:inline-block;
padding: .35em .6em;
border-radius: 10px;
background: rgba(0,0,0,.08);
font-size: .82rem;
margin-right: .4rem;
}
/* Tables */
.dataframe td, .dataframe th { font-size: 0.92rem; }
</style>
"""
# unsafe_allow_html is required for the <style> tag to be emitted as HTML.
st.markdown(_GLOBAL_STYLES, unsafe_allow_html=True)
# -----------------------------
# Header
# -----------------------------
# The text and badge markup are prepared first so the layout code below only
# places elements into the two header columns.
_HEADER_INTRO = """
Predict **reaction products** from inputs formatted as
`REACTANT:{reactants}REAGENT:{reagents}`
For multiple compounds: join with `"."` • If no reagent: use a single space `" "`.
"""

gpu = torch.cuda.is_available()
_device_label = "CUDA" if gpu else "CPU"
_STATUS_BADGES = f"""
<span class='badge'>Device: {_device_label}</span>
<span class='badge'>Transformers</span>
<span class='badge'>Streamlit</span>
"""

col_l, col_r = st.columns([0.78, 0.22])

# Left column: title and a short description of the expected input format.
with col_l:
    st.title("ReactionT5 • Task Forward")
    st.markdown(_HEADER_INTRO)

# Right column: status card showing the compute device and the libraries used.
with col_r:
    st.markdown("<div class='card'>", unsafe_allow_html=True)
    st.markdown("**Status**")
    st.markdown(_STATUS_BADGES, unsafe_allow_html=True)
    st.markdown("</div>", unsafe_allow_html=True)
# -----------------------------
# Sidebar: Controls / Parameters
# -----------------------------
# Every widget below binds a module-level name that the execution section
# reads when the Predict button is pressed.
with st.sidebar:
    st.header("Settings")

    st.caption("Model")
    model_name_or_path = st.text_input(
        "Model name or path",
        value="sagawa/ReactionT5v2-forward",
        help="Hugging Face Hub repo or local path",
    )

    st.divider()
    st.caption("Generation")
    num_beams = st.slider("num_beams", 1, 10, 5, 1)
    if num_beams > 1:
        # Upper bound is num_beams: the slider never offers more returned
        # sequences than there are beams.
        num_return_sequences = st.slider(
            "num_return_sequences", 1, num_beams, num_beams, 1
        )
    else:
        # With a single beam the only valid choice is 1, and a slider whose
        # minimum equals its maximum is degenerate (Streamlit rejects it), so
        # the value is pinned instead of rendering a widget.
        num_return_sequences = 1
        st.caption("num_return_sequences: 1 (fixed because num_beams is 1)")
    output_max_length = st.slider("max_length", 64, 512, 300, 8)
    # The default of -1 is forwarded unchanged to `model.generate(min_length=...)`.
    output_min_length = st.number_input("min_length", value=-1, step=1)

    st.caption("Batch / Reproducibility")
    batch_size = st.slider("batch_size", 1, 8, 1, 1)
    seed = st.number_input("seed", value=42, step=1)

    st.caption("Tokenizer / Input")
    input_max_length = st.slider("input_max_length", 64, 512, 400, 8)

    st.info(
        "Rough guide: ~15 sec / reaction with `num_beams=5`.",
    )
# -----------------------------
# Helper: caching
# -----------------------------
@st.cache_resource(show_spinner=False)
def load_model_and_tokenizer(
    path_or_name: str,
) -> Tuple[AutoModelForSeq2SeqLM, AutoTokenizer]:
    """Load the seq2seq model and its tokenizer, cached across Streamlit reruns.

    The function is wrapped in ``st.cache_resource`` so that pressing Predict
    repeatedly with the same ``path_or_name`` reuses the already-loaded
    objects instead of loading the weights again on every script rerun.

    Args:
        path_or_name: A path on the local filesystem or a Hugging Face Hub
            repository id. If it exists on disk it is converted to an
            absolute path; otherwise it is passed through unchanged.

    Returns:
        A ``(model, tokenizer)`` tuple — note the model comes first.

    Raises:
        Whatever ``from_pretrained`` raises when the model or tokenizer
        cannot be resolved; nothing is caught here.
    """
    # Resolve the location once so both loaders are guaranteed to agree.
    source = (
        os.path.abspath(path_or_name) if os.path.exists(path_or_name) else path_or_name
    )
    # NOTE(review): `return_tensors` is normally an argument of the tokenizer
    # call rather than of `from_pretrained`; kept as-is to avoid changing
    # behaviour — confirm whether it is needed.
    tok = AutoTokenizer.from_pretrained(source, return_tensors="pt")
    mdl = AutoModelForSeq2SeqLM.from_pretrained(source)
    return mdl, tok
def read_demo_csv() -> str:
    """Return the demo reaction data as CSV text.

    Reads ``data/demo_reaction_data.csv`` (resolved against the current
    working directory) with pandas and re-serialises it without an index
    column.
    """
    return pd.read_csv("data/demo_reaction_data.csv").to_csv(index=False)
def to_csv_bytes(df: pd.DataFrame) -> bytes:
    """Serialise *df* to CSV (index omitted) and return it as UTF-8 bytes."""
    csv_text = df.to_csv(index=False)
    return bytes(csv_text, "utf-8")
# -----------------------------
# I/O Tabs
# -----------------------------
tabs = st.tabs(["Input", "Output", "Guide"])

# Input tab: the user either uploads a CSV or types inputs line by line.
with tabs[0]:
    st.markdown("<div class='card'>", unsafe_allow_html=True)
    st.subheader("Provide your input")
    input_mode = st.radio(
        "Choose input mode",
        options=("CSV upload", "Text area"),
        horizontal=True,
    )

    # At most one of these is filled in, depending on the selected mode; the
    # execution section checks them when Predict is pressed.
    csv_buffer: Optional[bytes] = None
    text_area_value: Optional[str] = None

    if input_mode != "CSV upload":
        st.caption('Each line will be treated as one sample in the `"input"` column.')
        text_area_value = st.text_area(
            "Enter one or more inputs (one per line)",
            height=140,
            placeholder="REACTANT:CCO.REAGENT:O\nREACTANT:CC(=O)O.REAGENT: ",
        )
    else:
        st.caption('CSV must contain an `"input"` column.')
        uploaded = st.file_uploader("Upload CSV", type=["csv"])
        if uploaded is not None:
            csv_buffer = uploaded.read()
            st.success("CSV uploaded.")
        # The demo file lets users inspect the expected CSV layout.
        st.download_button(
            label="Download demo_reaction_data.csv",
            data=read_demo_csv(),
            file_name="demo_reaction_data.csv",
            mime="text/csv",
            use_container_width=True,
        )

    st.markdown("</div>", unsafe_allow_html=True)
# Guide tab: static reference text describing the accepted input format.
_FORMATTING_RULES = """
- **Template**: `REACTANT:{reactants}REAGENT:{reagents}`
- **Multiple compounds**: join with `"."`
- **No reagent**: provide a single space `" "` after `REAGENT:`
- **CSV schema**: must contain an `input` column
- **Outputs**: predicted products (SMILES) and sum of log-likelihood per hypothesis
"""

with tabs[2]:
    st.markdown("<div class='card'>", unsafe_allow_html=True)
    st.subheader("Formatting rules")
    st.markdown(_FORMATTING_RULES)
    st.markdown("</div>", unsafe_allow_html=True)
# -----------------------------
# Predict Button
# -----------------------------
run = st.button("🚀 Predict", use_container_width=True)

# -----------------------------
# Execution
# -----------------------------
# Runs only on the script rerun triggered by the Predict button. Each failure
# path reports through st.error and then halts the script with st.stop().
if run:
    # Validate input
    if input_mode == "CSV upload" and not csv_buffer:
        st.error(
            "Please upload a CSV file with an `input` column, or switch to Text area."
        )
        st.stop()
    if input_mode == "Text area" and (
        text_area_value is None or not text_area_value.strip()
    ):
        st.error("Please enter at least one line of input.")
        st.stop()

    # Load the model. A failure (e.g. a mistyped model name) is remembered and
    # reported after the status block, because that block is collapsed and an
    # error rendered inside it would be hidden from the user.
    load_error: Optional[Exception] = None
    with st.status("Initializing model & tokenizer…", expanded=False) as status:
        device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        seed_everything(seed=seed)
        try:
            model, tokenizer = load_model_and_tokenizer(model_name_or_path)
            model = model.to(device).eval()
        except Exception as e:
            load_error = e
            status.update(label="Model loading failed", state="error")
        else:
            status.update(label="Model ready", state="complete")
    if load_error is not None:
        st.error(f"Model loading failed: {load_error}")
        st.stop()

    # Prepare dataframe
    if input_mode == "CSV upload":
        try:
            # io.BytesIO is the public stdlib class; it wraps the uploaded
            # bytes so pandas can read them like a file.
            df_in = pd.read_csv(io.BytesIO(csv_buffer))
        except (UnicodeDecodeError, pd.errors.EmptyDataError, pd.errors.ParserError) as e:
            st.error(f"Could not read the uploaded CSV: {e}")
            st.stop()
    else:
        # One sample per non-blank line, with surrounding whitespace removed.
        lines = [x.strip() for x in text_area_value.splitlines() if x.strip()]
        df_in = pd.DataFrame({"input": lines})

    # Preprocess and dataset
    try:
        df_in = preprocess_df(df_in, drop_duplicates=False)
    except Exception as e:
        st.error(f"Input preprocessing failed: {e}")
        st.stop()

    # Checked after preprocessing so that anything preprocess_df itself adds
    # or fixes is taken into account. CFG.input_column below is "input".
    if "input" not in df_in.columns:
        st.error("The input data has no `input` column.")
        st.stop()
    if df_in.empty:
        st.error("The input data contains no rows.")
        st.stop()

    class CFG:
        # Configuration object used by ReactionT5Dataset/decode_output utilities.
        # Each `name = name` reads the module-level variable of the same name:
        # inside a class body a not-yet-assigned name falls back to the module
        # globals. This only works because the class is defined at module
        # level, not inside a function.
        num_beams = num_beams
        num_return_sequences = num_return_sequences
        model_name_or_path = model_name_or_path
        input_column = "input"
        input_max_length = input_max_length
        output_max_length = output_max_length
        output_min_length = output_min_length
        # A plain label; the loaded network is the module-level `model`.
        model = "t5"
        seed = seed
        batch_size = batch_size
        device = device
        tokenizer = tokenizer

    dataset = ReactionT5Dataset(CFG, df_in)
    dataloader = DataLoader(
        dataset,
        batch_size=CFG.batch_size,
        shuffle=False,
        num_workers=0 if not torch.cuda.is_available() else 4,
        pin_memory=torch.cuda.is_available(),
        drop_last=False,
    )

    # Progress UI
    total_steps = len(dataloader)
    progress = st.progress(0, text=f"Running generation… 0 / {total_steps}")
    all_sequences, all_scores = [], []
    try:
        for idx, inputs in enumerate(dataloader, start=1):
            inputs = {k: v.to(CFG.device) for k, v in inputs.items()}
            with torch.no_grad():
                output = model.generate(
                    **inputs,
                    min_length=CFG.output_min_length,
                    max_length=CFG.output_max_length,
                    num_beams=CFG.num_beams,
                    num_return_sequences=CFG.num_return_sequences,
                    return_dict_in_generate=True,
                    output_scores=True,
                )
            sequences, scores = decode_output(output, CFG)
            all_sequences.extend(sequences)
            if scores:
                all_scores.extend(scores)

            # Memory hygiene: drop the generation output after every batch.
            del output
            if torch.cuda.is_available():
                torch.cuda.empty_cache()
            gc.collect()

            progress.progress(
                idx / total_steps, text=f"Running generation… {idx} / {total_steps}"
            )
        st.toast("Generation complete")
    except Exception as e:
        st.error(f"Generation failed: {e}")
        st.stop()

    # Save & show
    try:
        output_df = save_multiple_predictions(df_in, all_sequences, all_scores, CFG)
    except Exception as e:
        st.error(f"Post-processing failed: {e}")
        st.stop()

    # Results go into the "Output" tab created earlier.
    with tabs[1]:
        st.subheader("Results")
        st.dataframe(output_df, use_container_width=True, hide_index=True)
        st.download_button(
            label="Download results (CSV)",
            data=to_csv_bytes(output_df),
            file_name="reactiont5_output.csv",
            mime="text/csv",
            use_container_width=True,
        )
# -----------------------------
# Footer Note
# -----------------------------
# Small attribution line under a horizontal rule; rendered as raw HTML.
_FOOTER_HTML = """
<hr/>
<small>
Built with ❤️ using Streamlit & 🤗 Transformers.
</small>
"""
st.markdown(_FOOTER_HTML, unsafe_allow_html=True)