ReactionT5 / app.py
sagawa's picture
Update app.py
8da1516 verified
raw
history blame
13.6 kB
import gc
import os
import warnings
from types import SimpleNamespace
import pandas as pd
import numpy as np
import streamlit as st
import torch
# Local imports
from generation_utils import (
ReactionT5Dataset,
decode_output,
save_multiple_predictions,
)
from models import ReactionT5Yield2
from torch.utils.data import DataLoader
from transformers import AutoModelForSeq2SeqLM, AutoTokenizer
from utils import seed_everything
warnings.filterwarnings("ignore")

# ------------------------------
# Page setup
# ------------------------------
# Markdown shown inside the "How to format your CSV" expander.
_CSV_FORMAT_HELP = """
- Include a required `REACTANT` column.
- Optional columns: `REAGENT`, `SOLVENT`, `CATALYST`.
- If a field lists multiple compounds, separate them with a dot (`.`).
- For details, download **demo_reaction_data.csv** and check its contents.
- Output contains predicted product SMILES and the sum of log-likelihoods for each prediction, sorted by log-likelihood (index 0 is most probable).
"""

# Browser-tab title and a wide layout for the two-column main area.
st.set_page_config(page_title="ReactionT5", page_icon=None, layout="wide")
st.title("ReactionT5")
st.caption(
    "Predict reaction products, reactants, or yields from your inputs "
    "using a pretrained ReactionT5 model."
)
with st.expander("How to format your CSV", expanded=False):
    st.markdown(_CSV_FORMAT_HELP)
# ------------------------------
# Demo data download
# ------------------------------
import io
@st.cache_data(show_spinner=False)
def parse_csv_from_bytes(file_bytes: bytes) -> pd.DataFrame:
    """Parse an uploaded CSV payload into a DataFrame.

    The result is cached on the raw bytes, so reruns with the same upload
    reuse the parsed frame. Text decoding uses pandas' ``read_csv`` default
    (UTF-8).
    """
    with io.BytesIO(file_bytes) as buffer:
        return pd.read_csv(buffer)
@st.cache_data(show_spinner=False)
def load_demo_csv_as_bytes() -> bytes:
    """Return the bundled demo CSV, re-serialized by pandas, as UTF-8 bytes.

    The file is read from ``data/demo_reaction_data.csv`` (relative to the
    working directory) and written back without the DataFrame index.
    """
    frame = pd.read_csv("data/demo_reaction_data.csv")
    csv_text = frame.to_csv(index=False)
    return csv_text.encode("utf-8")
# Let users grab the demo CSV to see the expected input format.
demo_csv_payload = load_demo_csv_as_bytes()
st.download_button(
    label="Download demo_reaction_data.csv",
    data=demo_csv_payload,
    file_name="demo_reaction_data.csv",
    mime="text/csv",
    use_container_width=True,
)
st.divider()
# ------------------------------
# Sidebar: configuration
# ------------------------------
# task -> (model choices, model help text, default input max length,
#          default output max length). Yield prediction does not generate
#          sequences, so it has no output-length default.
_TASK_PRESETS = {
    "product prediction": (
        [
            "sagawa/ReactionT5v2-forward",
            "sagawa/ReactionT5v2-forward-USPTO_MIT",
        ],
        "Recommended models for product prediction.",
        400,
        300,
    ),
    "retrosynthesis prediction": (
        [
            "sagawa/ReactionT5v2-retrosynthesis",
            "sagawa/ReactionT5v2-retrosynthesis-USPTO_50k",
        ],
        "Recommended models for retrosynthesis prediction.",
        100,
        400,
    ),
    "yield prediction": (
        ["sagawa/ReactionT5v2-yield"],  # default as requested
        "Default model for yield prediction.",
        400,
        None,
    ),
}

with st.sidebar:
    st.header("Configuration")
    task = st.selectbox(
        "Task",
        options=list(_TASK_PRESETS),
        index=0,
        help="Choose the task to run.",
    )
    (
        model_options,
        model_help,
        input_max_length_default,
        output_max_length_default,
    ) = _TASK_PRESETS[task]

    # Each task ships its own input preprocessing; import the matching one.
    if task == "product prediction":
        from task_forward.train import preprocess_df
    elif task == "retrosynthesis prediction":
        from task_retrosynthesis.train import preprocess_df
    else:  # yield prediction
        from task_yield.train import preprocess_df

    # Beam search and output lengths only apply to sequence-generating tasks.
    generative_task = task != "yield prediction"

    model_name_or_path = st.selectbox(
        "Model",
        options=model_options,
        index=0,
        help=model_help,
    )
    if generative_task:
        num_beams = st.slider(
            "Beam size",
            min_value=1,
            max_value=10,
            value=5,
            step=1,
            help="Number of beams for beam search.",
        )
    seed = st.number_input(
        "Random seed",
        min_value=0,
        max_value=2**32 - 1,
        value=42,
        step=1,
        help="Seed for reproducibility.",
    )

    # Shared bounds for the token-length inputs below.
    length_bounds = {"min_value": 8, "max_value": 1024, "step": 8}
    with st.expander("Advanced generation", expanded=False):
        input_max_length = st.number_input(
            "Input max length", value=input_max_length_default, **length_bounds
        )
        if generative_task:
            output_max_length = st.number_input(
                "Output max length", value=output_max_length_default, **length_bounds
            )
            output_min_length = st.number_input(
                "Output min length",
                min_value=-1,
                max_value=1024,
                value=-1,
                step=1,
                help="Use -1 to let the model decide.",
            )
        batch_size = st.number_input(
            "Batch size", min_value=1, max_value=16, value=1, step=1
        )
        num_workers = st.number_input(
            "DataLoader workers",
            min_value=0,
            max_value=8,
            value=4,
            step=1,
            help="Set to 0 if multiprocessing is restricted in your environment.",
        )

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    st.caption(f"Detected device: **{device.type.upper()}**")
# ------------------------------
# Cached loaders
# ------------------------------
@st.cache_resource(show_spinner=False)
def load_tokenizer(model_ref: str):
    """Load the tokenizer for ``model_ref``, cached across Streamlit reruns.

    An existing local path is made absolute first. Anything else (e.g. one of
    the sidebar's ``sagawa/...`` model ids) goes to ``from_pretrained`` as-is.
    """
    if os.path.exists(model_ref):
        source = os.path.abspath(model_ref)
    else:
        source = model_ref
    return AutoTokenizer.from_pretrained(source, return_tensors="pt")
@st.cache_resource(show_spinner=True)
def load_model(model_ref: str, device_str: str, task: str):
    """Load the model for ``task``, move it to ``device_str`` and set eval mode.

    Yield prediction uses ``ReactionT5Yield2``. Every other task uses a
    seq2seq LM. The loaded model is cached per (model_ref, device_str, task).
    """
    source = os.path.abspath(model_ref) if os.path.exists(model_ref) else model_ref
    model_cls = (
        ReactionT5Yield2 if task == "yield prediction" else AutoModelForSeq2SeqLM
    )
    model = model_cls.from_pretrained(source)
    model.to(torch.device(device_str))
    model.eval()
    return model
@st.cache_data(show_spinner=False)
def df_to_csv_bytes(df: pd.DataFrame) -> bytes:
    """Serialize ``df`` (without its index) to UTF-8 CSV bytes for download."""
    csv_text = df.to_csv(index=False)
    return csv_text.encode("utf-8")
# ------------------------------
# Main interaction
# ------------------------------
left, right = st.columns([1.4, 1.0], vertical_alignment="top")

with left:
    # Upload widget and the Predict button live in one form; `uploaded` and
    # `run` are read by the inference section further down.
    with st.form("predict_form", clear_on_submit=False):
        uploaded = st.file_uploader(
            "Upload a CSV file with reactions",
            type=["csv"],
            accept_multiple_files=False,
            help="Must contain a REACTANT column. Optional: REAGENT, SOLVENT, CATALYST.",
        )
        run = st.form_submit_button("Predict", use_container_width=True)

    # Preview the first rows of the uploaded CSV. Only the parse is guarded,
    # so rendering problems are not misreported as a CSV read failure.
    if uploaded is not None:
        try:
            raw_df = parse_csv_from_bytes(uploaded.getvalue())
        except Exception as e:
            st.error(f"Failed to read CSV: {e}")
        else:
            st.subheader("Input preview")
            st.dataframe(raw_df.head(20), use_container_width=True)

with right:
    st.subheader("Notes")
    st.markdown(
        """
- Approximate time: about **15 seconds per reaction** when `beam size = 5` (varies by hardware).
- Results include the **sum of log-likelihoods** per prediction and are **sorted** by that value.
"""
    )
    # The device is auto-detected (no CPU switch in the UI), so point users
    # at the knobs the sidebar actually exposes.
    st.info(
        "If you encounter CUDA OOM issues, reduce the batch size, max lengths, or beam size."
    )
# ------------------------------
# Inference
# ------------------------------
# Session-state slots persist across Streamlit reruns, so the latest results
# and error stay visible after the user interacts with other widgets.
if "results_df" not in st.session_state:
    st.session_state["results_df"] = None
if "last_error" not in st.session_state:
    st.session_state["last_error"] = None


def _predict_yields(yield_model, loader, target_device):
    """Score every batch in ``loader`` with the yield model; return a flat 1-D array.

    Batch outputs are kept whole and concatenated once, so (batch, 1),
    (batch,) and 0-d outputs all flatten to one entry per value. This assumes
    the model emits one value per reaction.
    """
    batch_outputs = []
    total = len(loader)
    progress = st.progress(0, text="Predicting yields...")
    info_placeholder = st.empty()
    for i, inputs in enumerate(loader, start=1):
        inputs = {k: v.to(target_device) for k, v in inputs.items()}
        with torch.no_grad():
            y_preds = yield_model(inputs)
        # atleast_1d: a 0-d output cannot be passed to np.concatenate.
        batch_outputs.append(np.atleast_1d(y_preds.to("cpu").numpy()))
        del y_preds
        progress.progress(i / total, text=f"Predicting yields... {i}/{total}")
        info_placeholder.caption(f"Processed batch {i} of {total}")
    progress.empty()
    info_placeholder.empty()
    if not batch_outputs:
        # np.concatenate([]) raises, so return an explicit empty array.
        return np.empty(0, dtype=np.float32)
    return np.concatenate(batch_outputs).reshape(-1)


def _generate_predictions(gen_model, loader, target_device, cfg):
    """Beam-search decode every batch in ``loader``; return (sequences, scores) lists."""
    all_sequences, all_scores = [], []
    total = len(loader)
    progress = st.progress(0, text="Generating predictions...")
    info_placeholder = st.empty()
    for i, inputs in enumerate(loader, start=1):
        inputs = {k: v.to(target_device) for k, v in inputs.items()}
        with torch.no_grad():
            output = gen_model.generate(
                **inputs,
                min_length=cfg.output_min_length,
                max_length=cfg.output_max_length,
                num_beams=cfg.num_beams,
                num_return_sequences=cfg.num_return_sequences,
                return_dict_in_generate=True,
                output_scores=True,
            )
        sequences, scores = decode_output(output, cfg)
        all_sequences.extend(sequences)
        if scores:
            all_scores.extend(scores)
        del output
        # Release cached GPU blocks between batches to keep peak memory down.
        if target_device.type == "cuda":
            torch.cuda.empty_cache()
        gc.collect()
        progress.progress(i / total, text=f"Generating predictions... {i}/{total}")
        info_placeholder.caption(f"Processed batch {i} of {total}")
    progress.empty()
    info_placeholder.empty()
    return all_sequences, all_scores


if run:
    if uploaded is None:
        st.warning("Please upload a CSV file before running prediction.")
    else:
        # Start each run clean so an error from an earlier run is not shown
        # again after this one succeeds.
        st.session_state["last_error"] = None
        is_generative = task != "yield prediction"

        # Build config object expected by your dataset/utils
        CFG = SimpleNamespace(
            task=task,
            num_beams=int(num_beams) if is_generative else None,
            # tie to beams by default
            num_return_sequences=int(num_beams) if is_generative else None,
            model_name_or_path=model_name_or_path,
            input_column="input",
            # The sidebar exposes an input length for every task (default 400
            # for yield), so honor it everywhere. The tokenizer presumably
            # needs a concrete value to pad batches to a common length;
            # confirm in ReactionT5Dataset.
            input_max_length=int(input_max_length),
            output_max_length=int(output_max_length) if is_generative else None,
            output_min_length=int(output_min_length) if is_generative else None,
            seed=int(seed),
            batch_size=int(batch_size),
            debug=False,
        )
        seed_everything(seed=CFG.seed)

        # Load model & tokenizer. The error is rendered outside the collapsed
        # status box so the user sees it in this run, not only after a rerun.
        load_error = None
        with st.status("Loading model and tokenizer...", expanded=False) as status:
            try:
                tokenizer = load_tokenizer(CFG.model_name_or_path)
                CFG.tokenizer = tokenizer
                model = load_model(CFG.model_name_or_path, device.type, task)
                status.update(label="Model ready.", state="complete")
            except Exception as e:
                load_error = f"Failed to load model: {e}"
                status.update(label="Model load failed.", state="error")
        if load_error is not None:
            st.session_state["last_error"] = load_error
            st.error(load_error)
            st.stop()

        # Prepare data: parsing and task-specific preprocessing can fail on
        # malformed CSVs or missing columns; report that instead of crashing.
        try:
            input_df = parse_csv_from_bytes(uploaded.getvalue())
            if is_generative:
                input_df = preprocess_df(input_df, drop_duplicates=False)
            else:
                input_df = preprocess_df(input_df, cfg=CFG, drop_duplicates=False)
        except Exception as e:
            st.session_state["last_error"] = f"Failed to prepare input data: {e}"
            st.error(st.session_state["last_error"])
            st.stop()
        if input_df.empty:
            st.warning("The uploaded CSV contains no reactions to predict.")
            st.stop()

        # Dataset & loader
        dataset = ReactionT5Dataset(CFG, input_df)
        dataloader = DataLoader(
            dataset,
            batch_size=CFG.batch_size,
            shuffle=False,
            num_workers=int(num_workers),
            pin_memory=(device.type == "cuda"),
            drop_last=False,
        )

        if not is_generative:
            prediction = _predict_yields(model, dataloader, device)
            output_df = input_df.copy()
            output_df["prediction"] = prediction
            # Yields are percentages; clamp regression outputs to [0, 100].
            output_df["prediction"] = output_df["prediction"].clip(
                lower=0.0, upper=100.0
            )
            st.session_state["results_df"] = output_df
            st.success("Prediction complete.")
        else:
            all_sequences, all_scores = _generate_predictions(
                model, dataloader, device, CFG
            )
            # Save predictions
            try:
                output_df = save_multiple_predictions(
                    input_df, all_sequences, all_scores, CFG
                )
                st.session_state["results_df"] = output_df
                st.success("Prediction complete.")
            except Exception as e:
                st.session_state["last_error"] = f"Failed to assemble output: {e}"
                st.error(st.session_state["last_error"])
                st.stop()
# ------------------------------
# Results
# ------------------------------
# Render whatever the most recent run stored in session state.
_results_df = st.session_state.get("results_df")
if _results_df is not None:
    st.subheader("Results preview")
    st.dataframe(_results_df.head(50), use_container_width=True)
    st.download_button(
        label="Download predictions as CSV",
        data=df_to_csv_bytes(_results_df),
        file_name="output.csv",
        mime="text/csv",
        use_container_width=True,
    )

_last_error = st.session_state.get("last_error")
if _last_error:
    st.error(_last_error)