"""Streamlit front-end for ReactionT5 forward (reactants -> products) prediction.

The user uploads a CSV, picks a beam width, and clicks "predict". The script
runs beam-search generation with a seq2seq model and offers the ranked
predictions as a downloadable CSV.
"""

import gc
import os
import sys
import warnings

import pandas as pd
import streamlit as st
import torch
from torch.utils.data import DataLoader
from tqdm import tqdm
from transformers import AutoModelForSeq2SeqLM, AutoTokenizer

# Make the task_forward helpers (generation_utils, train, utils) importable.
sys.path.append(
    os.path.abspath(os.path.join(os.path.dirname(__file__), "task_forward"))
)
from generation_utils import (
    ReactionT5Dataset,
    decode_output,
    save_multiple_predictions,
)
from train import preprocess_df
from utils import seed_everything

warnings.filterwarnings("ignore")

st.title("ReactionT5 task forward")
st.markdown("""
##### Predict reaction products from your inputs.
##### Upload a CSV that contains a `REACTANT` column. Optionally include `REAGENT`, `SOLVENT`, and/or `CATALYST`.
##### If a field lists multiple compounds, separate them with a dot (`.`). For details, download **demo_reaction_data.csv** and check its contents.
##### The output shows product SMILES and the sum of log-likelihoods for each prediction, sorted by log-likelihood (index 0 is the most probable).
""")

st.download_button(
    label="Download demo_reaction_data.csv",
    data=pd.read_csv("data/demo_reaction_data.csv").to_csv(index=False),
    file_name="demo_reaction_data.csv",
    mime="text/csv",
)


class CFG:
    """Run configuration.

    The widget-backed attributes are read when the class body runs, which
    happens on every Streamlit rerun. `device` and `tokenizer` are attached
    later, inside the predict branch.
    """

    num_beams = st.number_input(
        label="num beams", min_value=1, max_value=10, value=5, step=1
    )
    # Return every beam so that all candidates are ranked in the output.
    num_return_sequences = num_beams
    input_data = st.file_uploader("Choose a CSV file")
    model_name_or_path = "sagawa/ReactionT5v2-forward"
    input_column = "input"
    input_max_length = 400
    output_max_length = 300
    output_min_length = -1
    model = "t5"
    seed = 42
    batch_size = 1


@st.cache_resource
def _load_tokenizer_and_model(model_name_or_path, device):
    """Load the tokenizer and seq2seq model once per (path, device).

    The pair is reused across Streamlit reruns and button clicks, so the
    model is not reloaded from disk or the Hub on every prediction.

    Args:
        model_name_or_path: A local checkpoint directory or a Hub model id.
        device: A device string such as "cpu" or "cuda".

    Returns:
        A (tokenizer, model) tuple, with the model in eval mode on `device`.
    """
    # A path that exists locally is used as a checkpoint directory;
    # anything else is passed through unchanged as a Hub model id.
    resolved = (
        os.path.abspath(model_name_or_path)
        if os.path.exists(model_name_or_path)
        else model_name_or_path
    )
    tokenizer = AutoTokenizer.from_pretrained(resolved, return_tensors="pt")
    model = AutoModelForSeq2SeqLM.from_pretrained(resolved).to(device)
    model.eval()
    return tokenizer, model


if st.button("predict"):
    # Without an upload, pd.read_csv(None) would fail with a raw traceback.
    if CFG.input_data is None:
        st.warning("Please upload a CSV file before clicking predict.")
        st.stop()

    with st.spinner(
        "Now processing. If num beams=5, this process takes about 15 seconds per reaction."
    ):
        CFG.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        seed_everything(seed=CFG.seed)

        # The dataset reads the tokenizer from CFG, so attach it there.
        CFG.tokenizer, model = _load_tokenizer_and_model(
            CFG.model_name_or_path, str(CFG.device)
        )

        input_data = pd.read_csv(CFG.input_data)
        input_data = preprocess_df(input_data, drop_duplicates=False)
        dataset = ReactionT5Dataset(CFG, input_data)
        dataloader = DataLoader(
            dataset,
            batch_size=CFG.batch_size,
            shuffle=False,
            num_workers=4,
            # Pinned host memory only speeds up host->GPU copies; skip it on CPU.
            pin_memory=CFG.device.type == "cuda",
            drop_last=False,
        )

        all_sequences, all_scores = [], []
        for inputs in tqdm(dataloader, total=len(dataloader)):
            inputs = {k: v.to(CFG.device) for k, v in inputs.items()}
            with torch.no_grad():
                output = model.generate(
                    **inputs,
                    min_length=CFG.output_min_length,
                    max_length=CFG.output_max_length,
                    num_beams=CFG.num_beams,
                    num_return_sequences=CFG.num_return_sequences,
                    return_dict_in_generate=True,
                    output_scores=True,
                )
            sequences, scores = decode_output(output, CFG)
            all_sequences.extend(sequences)
            if scores:
                all_scores.extend(scores)
            # Release this batch's generation buffers before the next batch.
            del output
            torch.cuda.empty_cache()
            gc.collect()

        output_df = save_multiple_predictions(
            input_data, all_sequences, all_scores, CFG
        )

        # The CSV is built once per click, so no cache is needed here
        # (the previous @st.cache decorator is deprecated).
        csv = output_df.to_csv(index=False)
        st.download_button(
            label="Download data as CSV",
            data=csv,
            file_name="output.csv",
            mime="text/csv",
        )