# Spaces:
# Running
# Running
import gc
import io
import os
import warnings
from types import SimpleNamespace

import numpy as np
import pandas as pd
import streamlit as st
import torch
from torch.utils.data import DataLoader
from transformers import AutoModelForSeq2SeqLM, AutoTokenizer

# Local imports
from generation_utils import (
    ReactionT5Dataset,
    decode_output,
    save_multiple_predictions,
)
from models import ReactionT5Yield2
from utils import seed_everything
# Suppress all Python warnings for the remainder of the app run.
warnings.filterwarnings(action="ignore")

# ------------------------------
# Page setup
# ------------------------------
# page_icon is left at Streamlit's default (None).
st.set_page_config(page_title="ReactionT5", layout="wide")

st.title("ReactionT5")
st.caption(
    "Predict reaction products, reactants, or yields from your inputs "
    "using a pretrained ReactionT5 model."
)
# ------------------------------
# Sidebar: configuration
# ------------------------------
# Markdown shown in the "How to format your CSV" expander, keyed by task.
# st.markdown dedents and strips its body, so the indentation here is cosmetic.
_CSV_FORMAT_HELP = {
    "product prediction": """
        - `REACTANT` column is required.
        - Optional columns: `REAGENT`, `SOLVENT`, `CATALYST`.
        - If a field lists multiple compounds, separate them with a dot (`.`).
        - For details, download **demo_reaction_data.csv** and check its contents.
        """,
    "retrosynthesis prediction": """
        - `PRODUCT` column is required.
        - No optional columns are used.
        - If a field lists multiple compounds, separate them with a dot (`.`).
        - For details, download **demo_retro_data.csv** and check its contents.
        """,
    "yield prediction": """
        - `REACTANT` and `PRODUCT` columns are required.
        - Optional columns: `REAGENT`, `SOLVENT`, `CATALYST`.
        - If a field lists multiple compounds, separate them with a dot (`.`).
        - For details, download **demo_yield_data.csv** and check its contents.
        - Output contains predicted **reaction yield** on a **0–100% scale**.
        """,
}

with st.sidebar:
    st.header("Configuration")
    task = st.selectbox(
        "Task",
        options=["product prediction", "retrosynthesis prediction", "yield prediction"],
        index=0,
        help="Choose the task to run.",
    )
    with st.expander("How to format your CSV", expanded=False):
        # Any task that is not one of the two generation tasks gets the yield
        # notes, mirroring a trailing ``else`` branch.
        st.markdown(_CSV_FORMAT_HELP.get(task, _CSV_FORMAT_HELP["yield prediction"]))
# ------------------------------ | |
# Demo data download | |
# ------------------------------ | |
import io | |
def parse_csv_from_bytes(file_bytes: bytes) -> pd.DataFrame: | |
# If your files are always UTF-8, this is fine: | |
return pd.read_csv(io.BytesIO(file_bytes)) | |
# If you prefer explicit text decoding: | |
# return pd.read_csv(io.StringIO(file_bytes.decode("utf-8"))) | |
@st.cache_data(show_spinner=False)
def load_demo_csv_as_bytes(file_name: str = "demo_reaction_data.csv") -> bytes:
    """Return the raw bytes of a bundled demo CSV from the ``data/`` directory.

    The file is served verbatim (no pandas read/write round-trip), so the
    download matches the shipped file exactly. The result is cached because
    Streamlit reruns this script on every widget interaction.

    Args:
        file_name: CSV file name inside ``data/``. The default keeps existing
            zero-argument calls working.

    Returns:
        The file contents as bytes.

    Raises:
        FileNotFoundError: if ``data/<file_name>`` does not exist.
    """
    with open(os.path.join("data", file_name), "rb") as fh:
        return fh.read()


# Demo file matching the selected task; the sidebar's CSV-format help tells
# users to download these specific names.
_DEMO_FILE_BY_TASK = {
    "product prediction": "demo_reaction_data.csv",
    "retrosynthesis prediction": "demo_retro_data.csv",
    "yield prediction": "demo_yield_data.csv",
}
demo_file_name = _DEMO_FILE_BY_TASK.get(task, "demo_reaction_data.csv")
if not os.path.exists(os.path.join("data", demo_file_name)):
    # NOTE(review): the task-specific demo files are referenced by the help
    # text but their presence is not guaranteed here; fall back to the
    # reaction demo so the page still renders.
    demo_file_name = "demo_reaction_data.csv"

st.download_button(
    label=f"Download {demo_file_name}",
    data=load_demo_csv_as_bytes(demo_file_name),
    file_name=demo_file_name,
    mime="text/csv",
    use_container_width=True,
)
st.divider()
# ------------------------------
# Sidebar: model and generation settings
# ------------------------------
with st.sidebar:
    # The "Configuration" header is rendered once, by the first sidebar
    # section; repeating st.header here produced a duplicate header in the
    # same sidebar.

    # Model choices, length defaults, and the matching ``preprocess_df``
    # (used by the inference section) all depend on the selected task.
    if task == "product prediction":
        model_options = [
            "sagawa/ReactionT5v2-forward",
            "sagawa/ReactionT5v2-forward-USPTO_MIT",
        ]
        model_help = "Recommended models for product prediction."
        input_max_length_default = 400
        output_max_length_default = 300
        from task_forward.train import preprocess_df
    elif task == "retrosynthesis prediction":
        model_options = [
            "sagawa/ReactionT5v2-retrosynthesis",
            "sagawa/ReactionT5v2-retrosynthesis-USPTO_50k",
        ]
        model_help = "Recommended models for retrosynthesis prediction."
        input_max_length_default = 100
        output_max_length_default = 400
        from task_retrosynthesis.train import preprocess_df
    else:  # yield prediction
        model_options = ["sagawa/ReactionT5v2-yield"]
        model_help = "Default model for yield prediction."
        input_max_length_default = 400
        from task_yield.train import preprocess_df

    model_name_or_path = st.selectbox(
        "Model",
        options=model_options,
        index=0,
        help=model_help,
    )

    # Beam search only applies to the sequence-generation tasks; for yield
    # prediction ``num_beams`` is intentionally left undefined.
    if task != "yield prediction":
        num_beams = st.slider(
            "Beam size",
            min_value=1,
            max_value=10,
            value=5,
            step=1,
            help="Number of beams for beam search.",
        )

    seed = st.number_input(
        "Random seed",
        min_value=0,
        max_value=2**32 - 1,
        value=42,
        step=1,
        help="Seed for reproducibility.",
    )

    with st.expander("Advanced generation", expanded=False):
        input_max_length = st.number_input(
            "Input max length",
            min_value=8,
            max_value=1024,
            value=input_max_length_default,
            step=8,
        )
        # Output length limits only matter for generation tasks.
        if task != "yield prediction":
            output_max_length = st.number_input(
                "Output max length",
                min_value=8,
                max_value=1024,
                value=output_max_length_default,
                step=8,
            )
            output_min_length = st.number_input(
                "Output min length",
                min_value=-1,
                max_value=1024,
                value=-1,
                step=1,
                help="Use -1 to let the model decide.",
            )
        batch_size = st.number_input(
            "Batch size", min_value=1, max_value=16, value=1, step=1
        )
        num_workers = st.number_input(
            "DataLoader workers",
            min_value=0,
            max_value=8,
            value=4,
            step=1,
            help="Set to 0 if multiprocessing is restricted in your environment.",
        )

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    st.caption(f"Detected device: **{device.type.upper()}**")
# ------------------------------
# Cached loaders
# ------------------------------
# st.cache_resource keeps one shared object per argument tuple across script
# reruns, so pressing "Predict" again does not reload from disk / the Hub.
@st.cache_resource(show_spinner=False)
def load_tokenizer(model_ref: str):
    """Load the tokenizer for *model_ref* (a local path or a Hub model id).

    If *model_ref* exists on the local filesystem it is resolved to an
    absolute path; otherwise it is passed through unchanged.
    """
    resolved = os.path.abspath(model_ref) if os.path.exists(model_ref) else model_ref
    return AutoTokenizer.from_pretrained(resolved, return_tensors="pt")
@st.cache_resource(show_spinner=False)
def load_model(model_ref: str, device_str: str, task: str):
    """Load the model for *task*, move it to *device_str*, and set eval mode.

    Generation tasks use ``AutoModelForSeq2SeqLM``; ``"yield prediction"``
    uses ``ReactionT5Yield2``. A local path in *model_ref* is resolved to an
    absolute path, otherwise it is treated as a Hub id.

    The returned model is cached by Streamlit (keyed on all three arguments)
    and shared between reruns, so callers should not mutate it (e.g. switch it
    back to train mode).
    """
    resolved = os.path.abspath(model_ref) if os.path.exists(model_ref) else model_ref
    if task != "yield prediction":
        model = AutoModelForSeq2SeqLM.from_pretrained(resolved)
    else:
        model = ReactionT5Yield2.from_pretrained(resolved)
    model.to(torch.device(device_str))
    model.eval()
    return model
def df_to_csv_bytes(df: pd.DataFrame) -> bytes:
    """Serialise *df* to UTF-8 encoded CSV bytes, omitting the index column."""
    csv_text = df.to_csv(index=False)
    return bytes(csv_text, encoding="utf-8")
# ------------------------------
# Main interaction
# ------------------------------
# Uploader tooltip per task; mirrors the required/optional columns listed in
# the sidebar's CSV-format help (retrosynthesis needs PRODUCT, yield needs
# REACTANT and PRODUCT).
_UPLOAD_HELP_BY_TASK = {
    "product prediction": "Must contain a REACTANT column. Optional: REAGENT, SOLVENT, CATALYST.",
    "retrosynthesis prediction": "Must contain a PRODUCT column.",
    "yield prediction": "Must contain REACTANT and PRODUCT columns. Optional: REAGENT, SOLVENT, CATALYST.",
}

left, right = st.columns([1.4, 1.0], vertical_alignment="top")

with left:
    with st.form("predict_form", clear_on_submit=False):
        uploaded = st.file_uploader(
            "Upload a CSV file with reactions",
            type=["csv"],
            accept_multiple_files=False,
            help=_UPLOAD_HELP_BY_TASK.get(
                task,
                "Must contain a REACTANT column. Optional: REAGENT, SOLVENT, CATALYST.",
            ),
        )
        run = st.form_submit_button("Predict", use_container_width=True)

    # Preview the uploaded file; parse errors are shown instead of raised.
    if uploaded is not None:
        try:
            raw_df = parse_csv_from_bytes(uploaded.getvalue())
            st.subheader("Input preview")
            st.dataframe(raw_df.head(20), use_container_width=True)
        except Exception as e:
            st.error(f"Failed to read CSV: {e}")
with right:
    st.subheader("Notes")
    if task == "product prediction":
        # Forward prediction outputs product candidates, not reactants.
        st.markdown(
            """
            - Approximate time: about **3 seconds per reaction** when `beam size = 5` (varies by hardware).
            - Output contains predicted **product SMILES** and their log-likelihoods, sorted by log-likelihood (index 0 is most probable).
            """
        )
    elif task == "retrosynthesis prediction":
        st.markdown(
            """
            - Approximate time: about **5 seconds per reaction** when `beam size = 5` (varies by hardware).
            - Output contains predicted **sets of reactant SMILES** and their log-likelihoods, sorted by log-likelihood (index 0 is most probable).
            """
        )
    else:  # yield prediction
        st.markdown(
            """
            - Approximate time: about **0.25 seconds per reaction** when `batch size = 1` (varies by hardware).
            - Output contains predicted **reaction yield** on a **0–100% scale**.
            """
        )
    # Only state that the CPU is used when the sidebar's device detection
    # actually fell back to the CPU.
    if device.type == "cpu":
        st.info(
            "In this space, CPU is used for inference. So the speed is slower than using a GPU."
        )
# ------------------------------
# Inference
# ------------------------------
# Results and the last error live in session state so the results section at
# the bottom of the script can render them across reruns.
if "results_df" not in st.session_state:
    st.session_state["results_df"] = None
if "last_error" not in st.session_state:
    st.session_state["last_error"] = None


def _fail_and_stop(message):
    """Record *message* as the last error, display it now, and end this run.

    The error is shown here because ``st.stop()`` halts the script before the
    results section further down would render ``last_error``.
    """
    st.session_state["last_error"] = message
    st.error(message)
    st.stop()


if run:
    if uploaded is None:
        st.warning("Please upload a CSV file before running prediction.")
    else:
        # Start from a clean slate so stale results or errors from a previous
        # run are not presented as the outcome of this one.
        st.session_state["results_df"] = None
        st.session_state["last_error"] = None

        is_yield = task == "yield prediction"

        # Build the config object expected by the dataset/generation utilities.
        # The beam and output-length widgets only exist for generation tasks,
        # so they are read only when ``is_yield`` is False.
        CFG = SimpleNamespace(
            task=task,
            num_beams=None if is_yield else int(num_beams),
            # Tie the number of returned candidates to the beam size by default.
            num_return_sequences=None if is_yield else int(num_beams),
            model_name_or_path=model_name_or_path,
            input_column="input",
            # The sidebar offers "Input max length" for every task (default
            # 400 for yield), so honour the user's value for yield as well.
            input_max_length=int(input_max_length),
            output_max_length=None if is_yield else int(output_max_length),
            output_min_length=None if is_yield else int(output_min_length),
            seed=int(seed),
            batch_size=int(batch_size),
            debug=False,
        )
        seed_everything(seed=CFG.seed)

        # Load model & tokenizer. A failure is reported after leaving the
        # (collapsed) status container so the message is actually visible.
        load_error = None
        with st.status("Loading model and tokenizer...", expanded=False) as status:
            try:
                tokenizer = load_tokenizer(CFG.model_name_or_path)
                CFG.tokenizer = tokenizer
                model = load_model(CFG.model_name_or_path, device.type, task)
                status.update(label="Model ready.", state="complete")
            except Exception as e:
                load_error = f"Failed to load model: {e}"
                status.update(label="Model load failed.", state="error")
        if load_error is not None:
            _fail_and_stop(load_error)

        # Prepare data with the task-specific preprocess_df imported in the
        # sidebar section. Errors raised while parsing/preprocessing the
        # user's CSV are reported instead of surfacing as a raw traceback.
        try:
            input_df = parse_csv_from_bytes(uploaded.getvalue())
            if is_yield:
                input_df = preprocess_df(input_df, cfg=CFG, drop_duplicates=False)
            else:
                input_df = preprocess_df(input_df, drop_duplicates=False)
            dataset = ReactionT5Dataset(CFG, input_df)
        except Exception as e:
            _fail_and_stop(f"Failed to prepare input data: {e}")

        dataloader = DataLoader(
            dataset,
            batch_size=CFG.batch_size,
            shuffle=False,
            num_workers=int(num_workers),
            pin_memory=(device.type == "cuda"),
            drop_last=False,
        )
        total = len(dataloader)

        if is_yield:
            batch_predictions = []
            progress = st.progress(0, text="Predicting yields...")
            info_placeholder = st.empty()
            for i, inputs in enumerate(dataloader, start=1):
                inputs = {k: v.to(device) for k, v in inputs.items()}
                with torch.no_grad():
                    y_preds = model(inputs)
                # Flatten each batch so outputs shaped (batch,) and (batch, 1)
                # both contribute exactly one value per input row.
                batch_predictions.append(y_preds.to("cpu").numpy().reshape(-1))
                del y_preds
                progress.progress(i / total, text=f"Predicting yields... {i}/{total}")
                info_placeholder.caption(f"Processed batch {i} of {total}")
            progress.empty()
            info_placeholder.empty()

            # np.concatenate rejects an empty list, e.g. for a CSV with no rows.
            prediction = (
                np.concatenate(batch_predictions)
                if batch_predictions
                else np.empty(0, dtype=np.float32)
            )
            try:
                output_df = input_df.copy()
                output_df["prediction"] = prediction
                output_df["prediction"] = output_df["prediction"].clip(
                    lower=0.0, upper=100.0
                )
            except Exception as e:
                _fail_and_stop(f"Failed to assemble output: {e}")
            st.session_state["results_df"] = output_df
            st.success("Prediction complete.")
        else:
            # Generation loop with progress
            all_sequences, all_scores = [], []
            progress = st.progress(0, text="Generating predictions...")
            info_placeholder = st.empty()
            for i, inputs in enumerate(dataloader, start=1):
                inputs = {k: v.to(device) for k, v in inputs.items()}
                with torch.no_grad():
                    output = model.generate(
                        **inputs,
                        min_length=CFG.output_min_length,
                        max_length=CFG.output_max_length,
                        num_beams=CFG.num_beams,
                        num_return_sequences=CFG.num_return_sequences,
                        return_dict_in_generate=True,
                        output_scores=True,
                    )
                sequences, scores = decode_output(output, CFG)
                all_sequences.extend(sequences)
                if scores:
                    all_scores.extend(scores)
                del output
                # Release memory between batches.
                if device.type == "cuda":
                    torch.cuda.empty_cache()
                gc.collect()
                progress.progress(i / total, text=f"Generating predictions... {i}/{total}")
                info_placeholder.caption(f"Processed batch {i} of {total}")
            progress.empty()
            info_placeholder.empty()

            # Save predictions
            try:
                output_df = save_multiple_predictions(
                    input_df, all_sequences, all_scores, CFG
                )
            except Exception as e:
                _fail_and_stop(f"Failed to assemble output: {e}")
            st.session_state["results_df"] = output_df
            st.success("Prediction complete.")
# ------------------------------
# Results
# ------------------------------
# Render whatever the most recent run stored in session state.
results = st.session_state.get("results_df")
if results is not None:
    st.subheader("Results preview")
    st.dataframe(results.head(50), use_container_width=True)
    csv_payload = df_to_csv_bytes(results)
    st.download_button(
        label="Download predictions as CSV",
        data=csv_payload,
        file_name="output.csv",
        mime="text/csv",
        use_container_width=True,
    )

error_message = st.session_state.get("last_error")
if error_message:
    st.error(error_message)