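"""Streamlit demo for ReactionT5.

Predicts reaction products, reactants (retrosynthesis), or yields from an
uploaded CSV of reactions using pretrained ReactionT5 models.
"""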
import gc
import io
import os
import warnings
from types import SimpleNamespace
import pandas as pd
import numpy as np
import streamlit as st
import torch
# Local imports
from generation_utils import (
ReactionT5Dataset,
decode_output,
save_multiple_predictions,
)
from models import ReactionT5Yield2
from torch.utils.data import DataLoader
from transformers import AutoModelForSeq2SeqLM, AutoTokenizer
from utils import seed_everything
warnings.filterwarnings("ignore")
# ------------------------------
# Page setup
# ------------------------------
st.set_page_config(
page_title="ReactionT5",
page_icon=None,
layout="wide",
)
st.title("ReactionT5")
st.caption(
"Predict reaction products, reactants, or yields from your inputs using a pretrained ReactionT5 model."
)
# ------------------------------
# Sidebar: configuration
# ------------------------------
with st.sidebar:
st.header("Configuration")
task = st.selectbox(
"Task",
options=["product prediction", "retrosynthesis prediction", "yield prediction"],
index=0,
help="Choose the task to run.",
)
with st.expander("How to format your CSV", expanded=False):
if task == "product prediction":
st.markdown(
"""
- `REACTANT` column is required.
- Optional columns: `REAGENT`, `SOLVENT`, `CATALYST`.
- If a field lists multiple compounds, separate them with a dot (`.`).
- For details, download **demo_reaction_data.csv** and check its contents.
"""
)
elif task == "retrosynthesis prediction":
st.markdown(
"""
- `PRODUCT` column is required.
- No optional columns are used.
- If a field lists multiple compounds, separate them with a dot (`.`).
- For details, download **demo_retro_data.csv** and check its contents.
"""
)
else: # yield prediction
st.markdown(
"""
- `REACTANT` and `PRODUCT` columns are required.
- Optional columns: `REAGENT`, `SOLVENT`, `CATALYST`.
- If a field lists multiple compounds, separate them with a dot (`.`).
- For details, download **demo_yield_data.csv** and check its contents.
- Output contains predicted **reaction yield** on a **0–100% scale**.
"""
)
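        # Illustrative input rows (hypothetical SMILES, not taken from the
        # bundled demo files), one line per task:
        #   product prediction:        REACTANT column, e.g. "CCO.CC(=O)O"
        #   retrosynthesis prediction: PRODUCT column, e.g. "CC(=O)OCC"
        #   yield prediction:          REACTANT and PRODUCT columns,
        #                              e.g. "CCO.CC(=O)O" and "CC(=O)OCC"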
# ------------------------------
# Demo data download
# ------------------------------
@st.cache_data(show_spinner=False)
def parse_csv_from_bytes(file_bytes: bytes) -> pd.DataFrame:
    # Assumes UTF-8 input. For explicit text decoding, use
    # pd.read_csv(io.StringIO(file_bytes.decode("utf-8"))) instead.
    return pd.read_csv(io.BytesIO(file_bytes))
@st.cache_data(show_spinner=False)
def load_demo_csv_as_bytes(path: str) -> bytes:
    demo_df = pd.read_csv(path)
    return demo_df.to_csv(index=False).encode("utf-8")


# Offer the demo file that matches the selected task. The retro and yield
# paths assume those demo CSVs live under data/, as demo_reaction_data.csv does.
demo_paths = {
    "product prediction": "data/demo_reaction_data.csv",
    "retrosynthesis prediction": "data/demo_retro_data.csv",
    "yield prediction": "data/demo_yield_data.csv",
}
demo_path = demo_paths[task]
st.download_button(
    label=f"Download {os.path.basename(demo_path)}",
    data=load_demo_csv_as_bytes(demo_path),
    file_name=os.path.basename(demo_path),
    mime="text/csv",
    use_container_width=True,
)
st.divider()
# ------------------------------
# Sidebar: model and generation settings
# ------------------------------
with st.sidebar:
# Model options tied to task
if task == "product prediction":
model_options = [
"sagawa/ReactionT5v2-forward",
"sagawa/ReactionT5v2-forward-USPTO_MIT",
]
model_help = "Recommended models for product prediction."
input_max_length_default = 400
output_max_length_default = 300
from task_forward.train import preprocess_df
elif task == "retrosynthesis prediction":
model_options = [
"sagawa/ReactionT5v2-retrosynthesis",
"sagawa/ReactionT5v2-retrosynthesis-USPTO_50k",
]
model_help = "Recommended models for retrosynthesis prediction."
input_max_length_default = 100
output_max_length_default = 400
from task_retrosynthesis.train import preprocess_df
else: # yield prediction
model_options = ["sagawa/ReactionT5v2-yield"] # default as requested
model_help = "Default model for yield prediction."
input_max_length_default = 400
from task_yield.train import preprocess_df
model_name_or_path = st.selectbox(
"Model",
options=model_options,
index=0,
help=model_help,
)
if task != "yield prediction":
num_beams = st.slider(
"Beam size",
min_value=1,
max_value=10,
value=5,
step=1,
help="Number of beams for beam search.",
)
seed = st.number_input(
"Random seed",
min_value=0,
max_value=2**32 - 1,
value=42,
step=1,
help="Seed for reproducibility.",
)
with st.expander("Advanced generation", expanded=False):
input_max_length = st.number_input(
"Input max length",
min_value=8,
max_value=1024,
value=input_max_length_default,
step=8,
)
if task != "yield prediction":
output_max_length = st.number_input(
"Output max length",
min_value=8,
max_value=1024,
value=output_max_length_default,
step=8,
)
output_min_length = st.number_input(
"Output min length",
min_value=-1,
max_value=1024,
value=-1,
step=1,
help="Use -1 to let the model decide.",
)
batch_size = st.number_input(
"Batch size", min_value=1, max_value=16, value=1, step=1
)
num_workers = st.number_input(
"DataLoader workers",
min_value=0,
max_value=8,
value=4,
step=1,
help="Set to 0 if multiprocessing is restricted in your environment.",
)
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
st.caption(f"Detected device: **{device.type.upper()}**")
# ------------------------------
# Cached loaders
# ------------------------------
@st.cache_resource(show_spinner=False)
def load_tokenizer(model_ref: str):
resolved = os.path.abspath(model_ref) if os.path.exists(model_ref) else model_ref
    # return_tensors is a per-call tokenizer argument, not an init argument,
    # so it is not passed to from_pretrained.
    return AutoTokenizer.from_pretrained(resolved)
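# A minimal sketch of how the cached tokenizer is consumed downstream (the
# real call lives inside ReactionT5Dataset; the column name is illustrative):
#   enc = tokenizer(row["input"], truncation=True,
#                   max_length=CFG.input_max_length, return_tensors="pt")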
@st.cache_resource(show_spinner=True)
def load_model(model_ref: str, device_str: str, task: str):
resolved = os.path.abspath(model_ref) if os.path.exists(model_ref) else model_ref
if task != "yield prediction":
model = AutoModelForSeq2SeqLM.from_pretrained(resolved)
else:
model = ReactionT5Yield2.from_pretrained(resolved)
model.to(torch.device(device_str))
model.eval()
return model
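# Note: st.cache_resource keys the cache on the function arguments, so one
# model instance is kept per (model_ref, device_str, task) combination and
# reused across Streamlit reruns.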
@st.cache_data(show_spinner=False)
def df_to_csv_bytes(df: pd.DataFrame) -> bytes:
return df.to_csv(index=False).encode("utf-8")
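# st.cache_data hashes the DataFrame argument, so re-downloading the same
# results reuses the cached CSV bytes instead of re-serializing.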
# ------------------------------
# Main interaction
# ------------------------------
left, right = st.columns([1.4, 1.0], vertical_alignment="top")
with left:
with st.form("predict_form", clear_on_submit=False):
uploaded = st.file_uploader(
"Upload a CSV file with reactions",
type=["csv"],
accept_multiple_files=False,
help="Must contain a REACTANT column. Optional: REAGENT, SOLVENT, CATALYST.",
)
run = st.form_submit_button("Predict", use_container_width=True)
if uploaded is not None:
try:
file_bytes = uploaded.getvalue()
raw_df = parse_csv_from_bytes(file_bytes)
st.subheader("Input preview")
st.dataframe(raw_df.head(20), use_container_width=True)
except Exception as e:
st.error(f"Failed to read CSV: {e}")
with right:
st.subheader("Notes")
if task == "product prediction":
st.markdown(
f"""
- Approximate time: about **3 seconds per reaction** when `beam size = 5` (varies by hardware).
- Output contains predicted **sets of reactant SMILES** and their log-likelihoods, sorted by log-likelihood (index 0 is most probable).
"""
)
elif task == "retrosynthesis prediction":
st.markdown(
f"""
- Approximate time: about **5 seconds per reaction** when `beam size = 5` (varies by hardware).
- Output contains predicted **sets of reactant SMILES** and their log-likelihoods, sorted by log-likelihood (index 0 is most probable).
"""
)
else: # yield prediction
st.markdown(
f"""
- Approximate time: about **0.25 seconds per reaction** when `batch size = 1` (varies by hardware).
- Output contains predicted **reaction yield** on a **0–100% scale**.
"""
)
    st.info(
        "This Space runs inference on CPU, so predictions are slower than they would be on a GPU."
    )
# ------------------------------
# Inference
# ------------------------------
if "results_df" not in st.session_state:
st.session_state["results_df"] = None
if "last_error" not in st.session_state:
st.session_state["last_error"] = None
if run:
if uploaded is None:
st.warning("Please upload a CSV file before running prediction.")
else:
# Build config object expected by your dataset/utils
CFG = SimpleNamespace(
task=task,
num_beams=int(num_beams) if task != "yield prediction" else None,
num_return_sequences=int(num_beams)
if task != "yield prediction"
else None, # tie to beams by default
model_name_or_path=model_name_or_path,
            input_column="input",
            # input_max_length is set in the sidebar for every task,
            # including yield prediction, so pass it through unconditionally.
            input_max_length=int(input_max_length),
output_max_length=int(output_max_length)
if task != "yield prediction"
else None,
output_min_length=int(output_min_length)
if task != "yield prediction"
else None,
seed=int(seed),
batch_size=int(batch_size),
debug=False
)
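        # For example, product prediction with the default settings yields
        # roughly: SimpleNamespace(task="product prediction", num_beams=5,
        # num_return_sequences=5, input_max_length=400, output_max_length=300,
        # output_min_length=-1, seed=42, batch_size=1, ...).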
seed_everything(seed=CFG.seed)
# Load model & tokenizer
with st.status("Loading model and tokenizer...", expanded=False) as status:
try:
tokenizer = load_tokenizer(CFG.model_name_or_path)
CFG.tokenizer = tokenizer
model = load_model(CFG.model_name_or_path, device.type, task)
status.update(label="Model ready.", state="complete")
except Exception as e:
st.session_state["last_error"] = f"Failed to load model: {e}"
status.update(label="Model load failed.", state="error")
st.stop()
# Prepare data
file_bytes = uploaded.getvalue()
input_df = parse_csv_from_bytes(file_bytes)
if task != "yield prediction":
input_df = preprocess_df(input_df, drop_duplicates=False)
else:
            input_df = preprocess_df(input_df, cfg=CFG, drop_duplicates=False)
# Dataset & loader
dataset = ReactionT5Dataset(CFG, input_df)
dataloader = DataLoader(
dataset,
batch_size=CFG.batch_size,
shuffle=False,
num_workers=int(num_workers),
pin_memory=(device.type == "cuda"),
drop_last=False,
)
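        # Each batch is assumed to be a dict of tensors (e.g. input_ids and
        # attention_mask from the tokenizer), which is why the loops below
        # move it to the device key by key.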
if task == "yield prediction":
# Use custom inference function for yield prediction
prediction = []
total = len(dataloader)
progress = st.progress(0, text="Predicting yields...")
info_placeholder = st.empty()
for i, inputs in enumerate(dataloader, start=1):
inputs = {k: v.to(device) for k, v in inputs.items()}
with torch.no_grad():
y_preds = model(inputs)
prediction.extend(y_preds.to("cpu").numpy())
del y_preds
progress.progress(i / total, text=f"Predicting yields... {i}/{total}")
info_placeholder.caption(f"Processed batch {i} of {total}")
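            # Each y_preds batch is expected to be a 2-D array of shape
            # (batch, 1); extend() splits it into per-sample 1-D arrays, and
            # np.concatenate flattens them into a single (N,) vector.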
            progress.empty()
            info_placeholder.empty()
            prediction = np.concatenate(prediction)
            output_df = input_df.copy()
            # Clamp predictions to the valid 0-100% yield range.
            output_df["prediction"] = prediction.clip(0.0, 100.0)
            st.session_state["results_df"] = output_df
            st.success("Prediction complete.")
else:
# Generation loop with progress
all_sequences, all_scores = [], []
total = len(dataloader)
progress = st.progress(0, text="Generating predictions...")
info_placeholder = st.empty()
for i, inputs in enumerate(dataloader, start=1):
inputs = {k: v.to(device) for k, v in inputs.items()}
with torch.no_grad():
output = model.generate(
**inputs,
min_length=CFG.output_min_length,
max_length=CFG.output_max_length,
num_beams=CFG.num_beams,
num_return_sequences=CFG.num_return_sequences,
return_dict_in_generate=True,
output_scores=True,
)
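                # With return_dict_in_generate and output_scores, generate()
                # returns an object whose .sequences holds the beam outputs
                # and .sequences_scores their log-likelihoods; decode_output
                # is expected to unpack both.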
sequences, scores = decode_output(output, CFG)
all_sequences.extend(sequences)
if scores:
all_scores.extend(scores)
del output
if device.type == "cuda":
torch.cuda.empty_cache()
gc.collect()
progress.progress(i / total, text=f"Generating predictions... {i}/{total}")
info_placeholder.caption(f"Processed batch {i} of {total}")
progress.empty()
info_placeholder.empty()
# Save predictions
try:
output_df = save_multiple_predictions(
input_df, all_sequences, all_scores, CFG
)
st.session_state["results_df"] = output_df
st.success("Prediction complete.")
except Exception as e:
st.session_state["last_error"] = f"Failed to assemble output: {e}"
st.error(st.session_state["last_error"])
st.stop()
# ------------------------------
# Results
# ------------------------------
if st.session_state.get("results_df") is not None:
st.subheader("Results preview")
st.dataframe(st.session_state["results_df"].head(50), use_container_width=True)
st.download_button(
label="Download predictions as CSV",
data=df_to_csv_bytes(st.session_state["results_df"]),
file_name="output.csv",
mime="text/csv",
use_container_width=True,
)
if st.session_state.get("last_error"):
st.error(st.session_state["last_error"])