"""Streamlit front end for running pretrained ReactionT5 models on a CSV file.

The page lets the user pick a task (product, retrosynthesis, or yield
prediction), choose a model and generation settings in the sidebar, upload a
CSV of reactions, and download the predictions as CSV.
"""
import gc
import io
import os
import warnings
from types import SimpleNamespace

import numpy as np
import pandas as pd
import streamlit as st
import torch
from torch.utils.data import DataLoader
from transformers import AutoModelForSeq2SeqLM, AutoTokenizer

# Local imports
from generation_utils import (
    ReactionT5Dataset,
    decode_output,
    save_multiple_predictions,
)
from models import ReactionT5Yield2
from utils import seed_everything

warnings.filterwarnings("ignore")

# ------------------------------
# Task-specific constants
# ------------------------------
TASK_PRODUCT = "product prediction"
TASK_RETRO = "retrosynthesis prediction"
TASK_YIELD = "yield prediction"

DEMO_DATA_DIR = "data"
DEFAULT_DEMO_FILE = "demo_reaction_data.csv"

# Demo file names referenced by the CSV-format help text of each task.
DEMO_FILES_BY_TASK = {
    TASK_PRODUCT: "demo_reaction_data.csv",
    TASK_RETRO: "demo_retro_data.csv",
    TASK_YIELD: "demo_yield_data.csv",
}

CSV_FORMAT_HELP = {
    TASK_PRODUCT: (
        "- `REACTANT` column is required.\n"
        "- Optional columns: `REAGENT`, `SOLVENT`, `CATALYST`.\n"
        "- If a field lists multiple compounds, separate them with a dot (`.`).\n"
        "- For details, download **demo_reaction_data.csv** and check its contents.\n"
    ),
    TASK_RETRO: (
        "- `PRODUCT` column is required.\n"
        "- No optional columns are used.\n"
        "- If a field lists multiple compounds, separate them with a dot (`.`).\n"
        "- For details, download **demo_retro_data.csv** and check its contents.\n"
    ),
    TASK_YIELD: (
        "- `REACTANT` and `PRODUCT` columns are required.\n"
        "- Optional columns: `REAGENT`, `SOLVENT`, `CATALYST`.\n"
        "- If a field lists multiple compounds, separate them with a dot (`.`).\n"
        "- For details, download **demo_yield_data.csv** and check its contents.\n"
        "- Output contains predicted **reaction yield** on a **0–100% scale**.\n"
    ),
}

# Shown next to the uploader; must agree with CSV_FORMAT_HELP for each task.
UPLOAD_HELP = {
    TASK_PRODUCT: "Must contain a REACTANT column. Optional: REAGENT, SOLVENT, CATALYST.",
    TASK_RETRO: "Must contain a PRODUCT column.",
    TASK_YIELD: (
        "Must contain REACTANT and PRODUCT columns. "
        "Optional: REAGENT, SOLVENT, CATALYST."
    ),
}

TASK_NOTES = {
    TASK_PRODUCT: (
        "- Approximate time: about **3 seconds per reaction** when `beam size = 5` "
        "(varies by hardware).\n"
        "- Output contains predicted **sets of product SMILES** and their "
        "log-likelihoods, sorted by log-likelihood (index 0 is most probable).\n"
    ),
    TASK_RETRO: (
        "- Approximate time: about **5 seconds per reaction** when `beam size = 5` "
        "(varies by hardware).\n"
        "- Output contains predicted **sets of reactant SMILES** and their "
        "log-likelihoods, sorted by log-likelihood (index 0 is most probable).\n"
    ),
    TASK_YIELD: (
        "- Approximate time: about **0.25 seconds per reaction** when "
        "`batch size = 1` (varies by hardware).\n"
        "- Output contains predicted **reaction yield** on a **0–100% scale**.\n"
    ),
}

# ------------------------------
# Page setup
# ------------------------------
st.set_page_config(
    page_title="ReactionT5",
    page_icon=None,
    layout="wide",
)

st.title("ReactionT5")
st.caption(
    "Predict reaction products, reactants, or yields from your inputs using a pretrained ReactionT5 model."
)

# ------------------------------
# Sidebar: task selection
# ------------------------------
with st.sidebar:
    st.header("Configuration")

    task = st.selectbox(
        "Task",
        options=[TASK_PRODUCT, TASK_RETRO, TASK_YIELD],
        index=0,
        help="Choose the task to run.",
    )

# True for the two sequence-generation tasks, False for yield regression.
is_generation = task != TASK_YIELD

with st.expander("How to format your CSV", expanded=False):
    st.markdown(CSV_FORMAT_HELP[task])


# ------------------------------
# Demo data download
# ------------------------------
@st.cache_data(show_spinner=False)
def parse_csv_from_bytes(file_bytes: bytes) -> pd.DataFrame:
    """Parse raw CSV bytes into a DataFrame (cached on the byte content)."""
    # If your files are always UTF-8, this is fine:
    return pd.read_csv(io.BytesIO(file_bytes))
    # If you prefer explicit text decoding:
    # return pd.read_csv(io.StringIO(file_bytes.decode("utf-8")))


@st.cache_data(show_spinner=False)
def load_demo_csv_as_bytes(path: str = "data/demo_reaction_data.csv") -> bytes:
    """Read the demo CSV at *path* and return it re-serialised as UTF-8 bytes."""
    demo_df = pd.read_csv(path)
    return demo_df.to_csv(index=False).encode("utf-8")


# Offer the demo file that matches the selected task; fall back to the
# reaction demo when the task-specific file is not shipped with the app.
demo_file_name = DEMO_FILES_BY_TASK[task]
if not os.path.exists(os.path.join(DEMO_DATA_DIR, demo_file_name)):
    demo_file_name = DEFAULT_DEMO_FILE

st.download_button(
    label=f"Download {demo_file_name}",
    data=load_demo_csv_as_bytes(os.path.join(DEMO_DATA_DIR, demo_file_name)),
    file_name=demo_file_name,
    mime="text/csv",
    use_container_width=True,
)

st.divider()

# ------------------------------
# Sidebar: model and generation settings
# ------------------------------
# The "Configuration" header is already rendered above, so it is not repeated.
with st.sidebar:
    # Model options tied to task
    if task == TASK_PRODUCT:
        model_options = [
            "sagawa/ReactionT5v2-forward",
            "sagawa/ReactionT5v2-forward-USPTO_MIT",
        ]
        model_help = "Recommended models for product prediction."
        input_max_length_default = 400
        output_max_length_default = 300
        from task_forward.train import preprocess_df
    elif task == TASK_RETRO:
        model_options = [
            "sagawa/ReactionT5v2-retrosynthesis",
            "sagawa/ReactionT5v2-retrosynthesis-USPTO_50k",
        ]
        model_help = "Recommended models for retrosynthesis prediction."
        input_max_length_default = 100
        output_max_length_default = 400
        from task_retrosynthesis.train import preprocess_df
    else:  # yield prediction
        model_options = ["sagawa/ReactionT5v2-yield"]  # default as requested
        model_help = "Default model for yield prediction."
        input_max_length_default = 400
        from task_yield.train import preprocess_df

    model_name_or_path = st.selectbox(
        "Model",
        options=model_options,
        index=0,
        help=model_help,
    )

    if is_generation:
        num_beams = st.slider(
            "Beam size",
            min_value=1,
            max_value=10,
            value=5,
            step=1,
            help="Number of beams for beam search.",
        )

    seed = st.number_input(
        "Random seed",
        min_value=0,
        max_value=2**32 - 1,
        value=42,
        step=1,
        help="Seed for reproducibility.",
    )

    with st.expander("Advanced generation", expanded=False):
        input_max_length = st.number_input(
            "Input max length",
            min_value=8,
            max_value=1024,
            value=input_max_length_default,
            step=8,
        )
        if is_generation:
            output_max_length = st.number_input(
                "Output max length",
                min_value=8,
                max_value=1024,
                value=output_max_length_default,
                step=8,
            )
            output_min_length = st.number_input(
                "Output min length",
                min_value=-1,
                max_value=1024,
                value=-1,
                step=1,
                help="Use -1 to let the model decide.",
            )
        batch_size = st.number_input(
            "Batch size", min_value=1, max_value=16, value=1, step=1
        )
        num_workers = st.number_input(
            "DataLoader workers",
            min_value=0,
            max_value=8,
            value=4,
            step=1,
            help="Set to 0 if multiprocessing is restricted in your environment.",
        )

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    st.caption(f"Detected device: **{device.type.upper()}**")


# ------------------------------
# Cached loaders
# ------------------------------
@st.cache_resource(show_spinner=False)
def load_tokenizer(model_ref: str):
    """Load (and cache) the tokenizer for a local path or hub model id."""
    # A reference that exists on disk is treated as a local checkpoint.
    resolved = os.path.abspath(model_ref) if os.path.exists(model_ref) else model_ref
    return AutoTokenizer.from_pretrained(resolved, return_tensors="pt")


@st.cache_resource(show_spinner=True)
def load_model(model_ref: str, device_str: str, task: str):
    """Load (and cache) the model for *task*, moved to *device_str*, in eval mode.

    Yield prediction uses ``ReactionT5Yield2``; every other task uses
    ``AutoModelForSeq2SeqLM``.
    """
    resolved = os.path.abspath(model_ref) if os.path.exists(model_ref) else model_ref
    if task != "yield prediction":
        model = AutoModelForSeq2SeqLM.from_pretrained(resolved)
    else:
        model = ReactionT5Yield2.from_pretrained(resolved)
    model.to(torch.device(device_str))
    model.eval()
    return model


@st.cache_data(show_spinner=False)
def df_to_csv_bytes(df: pd.DataFrame) -> bytes:
    """Serialise *df* to CSV (without the index) as UTF-8 bytes."""
    return df.to_csv(index=False).encode("utf-8")


# ------------------------------
# Inference helpers
# ------------------------------
def fail(message: str) -> None:
    """Record *message* in session state, show it now, and halt this run."""
    st.session_state["last_error"] = message
    st.error(message)
    st.stop()


def predict_yields(model, dataloader, device) -> np.ndarray:
    """Run the yield model over *dataloader* and return a flat prediction array.

    A progress bar and a caption are updated per batch and removed at the end.
    """
    batches = []
    total = len(dataloader)
    progress = st.progress(0, text="Predicting yields...")
    info_placeholder = st.empty()

    for i, inputs in enumerate(dataloader, start=1):
        inputs = {k: v.to(device) for k, v in inputs.items()}
        with torch.no_grad():
            y_preds = model(inputs)
        # Flatten per batch so the result is 1-D whether the model returns
        # shape (batch,) or (batch, 1).
        batches.append(y_preds.to("cpu").numpy().reshape(-1))
        del y_preds
        progress.progress(i / total, text=f"Predicting yields... {i}/{total}")
        info_placeholder.caption(f"Processed batch {i} of {total}")

    progress.empty()
    info_placeholder.empty()
    return np.concatenate(batches)


def generate_predictions(model, dataloader, cfg, device):
    """Run beam-search generation over *dataloader*.

    Returns a ``(sequences, scores)`` pair of lists accumulated from
    ``decode_output`` for every batch.
    """
    all_sequences, all_scores = [], []
    total = len(dataloader)
    progress = st.progress(0, text="Generating predictions...")
    info_placeholder = st.empty()

    for i, inputs in enumerate(dataloader, start=1):
        inputs = {k: v.to(device) for k, v in inputs.items()}
        with torch.no_grad():
            output = model.generate(
                **inputs,
                min_length=cfg.output_min_length,
                max_length=cfg.output_max_length,
                num_beams=cfg.num_beams,
                num_return_sequences=cfg.num_return_sequences,
                return_dict_in_generate=True,
                output_scores=True,
            )
        sequences, scores = decode_output(output, cfg)
        all_sequences.extend(sequences)
        if scores:
            all_scores.extend(scores)

        # Release per-batch generation buffers before the next batch.
        del output
        if device.type == "cuda":
            torch.cuda.empty_cache()
        gc.collect()

        progress.progress(i / total, text=f"Generating predictions... {i}/{total}")
        info_placeholder.caption(f"Processed batch {i} of {total}")

    progress.empty()
    info_placeholder.empty()
    return all_sequences, all_scores


# ------------------------------
# Main interaction
# ------------------------------
left, right = st.columns([1.4, 1.0], vertical_alignment="top")

with left:
    with st.form("predict_form", clear_on_submit=False):
        uploaded = st.file_uploader(
            "Upload a CSV file with reactions",
            type=["csv"],
            accept_multiple_files=False,
            help=UPLOAD_HELP[task],
        )
        run = st.form_submit_button("Predict", use_container_width=True)

    if uploaded is not None:
        try:
            file_bytes = uploaded.getvalue()
            raw_df = parse_csv_from_bytes(file_bytes)
            st.subheader("Input preview")
            st.dataframe(raw_df.head(20), use_container_width=True)
        except Exception as e:
            st.error(f"Failed to read CSV: {e}")

with right:
    st.subheader("Notes")
    st.markdown(TASK_NOTES[task])
    # Only warn about CPU speed when inference will actually run on the CPU.
    if device.type == "cpu":
        st.info(
            "In this space, CPU is used for inference. So the speed is slower than using a GPU."
        )

# ------------------------------
# Inference
# ------------------------------
if "results_df" not in st.session_state:
    st.session_state["results_df"] = None
if "last_error" not in st.session_state:
    st.session_state["last_error"] = None

if run:
    # A new run must not inherit the results or the error of a previous one.
    st.session_state["results_df"] = None
    st.session_state["last_error"] = None

    if uploaded is None:
        st.warning("Please upload a CSV file before running prediction.")
    else:
        # Build config object expected by your dataset/utils
        CFG = SimpleNamespace(
            task=task,
            num_beams=int(num_beams) if is_generation else None,
            # tie to beams by default
            num_return_sequences=int(num_beams) if is_generation else None,
            model_name_or_path=model_name_or_path,
            input_column="input",
            # The widget is shown for every task, so its value is always used.
            input_max_length=int(input_max_length),
            output_max_length=int(output_max_length) if is_generation else None,
            output_min_length=int(output_min_length) if is_generation else None,
            seed=int(seed),
            batch_size=int(batch_size),
            debug=False,
        )

        seed_everything(seed=CFG.seed)

        # Load model & tokenizer
        load_error = None
        with st.status("Loading model and tokenizer...", expanded=False) as status:
            try:
                tokenizer = load_tokenizer(CFG.model_name_or_path)
                CFG.tokenizer = tokenizer
                model = load_model(CFG.model_name_or_path, device.type, task)
            except Exception as e:
                load_error = f"Failed to load model: {e}"
                status.update(label="Model load failed.", state="error")
            else:
                status.update(label="Model ready.", state="complete")
        # Report outside the collapsed status container so the message is visible.
        if load_error is not None:
            fail(load_error)

        # Prepare data
        try:
            input_df = parse_csv_from_bytes(uploaded.getvalue())
            if is_generation:
                input_df = preprocess_df(input_df, drop_duplicates=False)
            else:
                input_df = preprocess_df(input_df, cfg=CFG, drop_duplicates=False)
        except Exception as e:
            fail(f"Failed to prepare input data: {e}")

        if len(input_df) == 0:
            st.warning("The uploaded CSV contains no rows to predict.")
            st.stop()

        # Dataset & loader
        dataset = ReactionT5Dataset(CFG, input_df)
        dataloader = DataLoader(
            dataset,
            batch_size=CFG.batch_size,
            shuffle=False,
            num_workers=int(num_workers),
            pin_memory=(device.type == "cuda"),
            drop_last=False,
        )

        if is_generation:
            all_sequences, all_scores = generate_predictions(
                model, dataloader, CFG, device
            )

            # Save predictions
            try:
                output_df = save_multiple_predictions(
                    input_df, all_sequences, all_scores, CFG
                )
            except Exception as e:
                fail(f"Failed to assemble output: {e}")

            st.session_state["results_df"] = output_df
            st.success("Prediction complete.")
        else:
            prediction = predict_yields(model, dataloader, device)

            output_df = input_df.copy()
            output_df["prediction"] = prediction
            # Clamp to the 0–100 range the UI promises for yields.
            output_df["prediction"] = output_df["prediction"].clip(
                lower=0.0, upper=100.0
            )
            st.session_state["results_df"] = output_df
            st.success("Prediction complete.")

# ------------------------------
# Results
# ------------------------------
if st.session_state.get("results_df") is not None:
    st.subheader("Results preview")
    st.dataframe(st.session_state["results_df"].head(50), use_container_width=True)

    st.download_button(
        label="Download predictions as CSV",
        data=df_to_csv_bytes(st.session_state["results_df"]),
        file_name="output.csv",
        mime="text/csv",
        use_container_width=True,
    )

if st.session_state.get("last_error"):
    st.error(st.session_state["last_error"])