# ReactionT5 / app.py
# Last commit by sagawa: 8a0130f "Update app.py" (verified), 4.11 kB
import gc
import os
import sys
import warnings
import pandas as pd
import streamlit as st
import torch
from torch.utils.data import DataLoader
from tqdm import tqdm
from transformers import AutoModelForSeq2SeqLM, AutoTokenizer
sys.path.append(
os.path.abspath(os.path.join(os.path.dirname(__file__), "task_forward"))
)
from generation_utils import (
ReactionT5Dataset,
decode_output,
save_multiple_predictions,
)
from train import preprocess_df
from utils import seed_everything
# Silence every warning globally (library deprecation notices included) so
# they do not clutter the Streamlit server log.
warnings.filterwarnings("ignore")

st.title("ReactionT5 task forward")
st.markdown("""
##### Predict reaction products from your inputs.
##### Upload a CSV that contains a `REACTANT` column. Optionally include `REAGENT`, `SOLVENT`, and/or `CATALYST`.
##### If a field lists multiple compounds, separate them with a dot (`.`). For details, download **demo_reaction_data.csv** and check its contents.
##### The output shows product SMILES and the sum of log-likelihoods for each prediction, sorted by log-likelihood (index 0 is the most probable).
""")

# Resolve the demo file relative to this script (as is done for the
# `task_forward` import path) so the app works no matter which directory
# `streamlit run` is launched from.
DEMO_DATA_PATH = os.path.join(
    os.path.dirname(os.path.abspath(__file__)), "data", "demo_reaction_data.csv"
)

# Offer the bundled demo CSV (round-tripped through pandas, index column
# dropped) so users can see the expected input format.
st.download_button(
    label="Download demo_reaction_data.csv",
    data=pd.read_csv(DEMO_DATA_PATH).to_csv(index=False),
    file_name="demo_reaction_data.csv",
    mime="text/csv",
)
class CFG:
    """Run configuration handed to the generation helpers.

    The object itself (not an instance) is passed to ``ReactionT5Dataset``,
    ``decode_output`` and ``save_multiple_predictions``. ``device`` and
    ``tokenizer`` are attached later, once "predict" is pressed.
    """

    # --- Fixed settings (not exposed in the UI) ---------------------------
    model_name_or_path = "sagawa/ReactionT5v2-forward"  # local dir or Hub id
    # presumably the column built by preprocess_df and read by the dataset -- confirm
    input_column = "input"
    input_max_length = 400  # presumably the tokenizer truncation length -- confirm
    output_max_length = 300  # passed to generate() as max_length
    output_min_length = -1  # passed to generate() as min_length
    model = "t5"  # presumably read by the helper functions -- confirm
    seed = 42  # passed to seed_everything()
    batch_size = 1  # DataLoader batch size

    # --- Widgets: rendered when this class body runs, in this order ------
    num_beams = st.number_input(
        label="num beams",
        min_value=1,
        max_value=10,
        value=5,
        step=1,
    )
    # Return every beam so each reaction gets num_beams ranked candidates.
    num_return_sequences = num_beams
    # The uploaded file object, or None while nothing has been uploaded.
    input_data = st.file_uploader("Choose a CSV file")
def _resolve_model_path(name_or_path):
    """Return an absolute path if *name_or_path* exists on disk, else return it unchanged.

    An unchanged value is handed to ``from_pretrained`` as-is (e.g. a Hub model id).
    """
    if os.path.exists(name_or_path):
        return os.path.abspath(name_or_path)
    return name_or_path


if st.button("predict"):
    # Without an upload, pd.read_csv(None) would raise; explain instead of crashing.
    if CFG.input_data is None:
        st.error("Please upload a CSV file before pressing predict.")
        st.stop()

    with st.spinner(
        "Now processing. If num beams=5, this process takes about 15 seconds per reaction."
    ):
        CFG.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        use_cuda = CFG.device.type == "cuda"
        seed_everything(seed=CFG.seed)

        # Resolve once and share between tokenizer and model.
        model_path = _resolve_model_path(CFG.model_name_or_path)
        CFG.tokenizer = AutoTokenizer.from_pretrained(model_path, return_tensors="pt")
        model = AutoModelForSeq2SeqLM.from_pretrained(model_path).to(CFG.device)
        model.eval()

        input_data = pd.read_csv(CFG.input_data)
        # Keep duplicates so every uploaded row gets a prediction.
        input_data = preprocess_df(input_data, drop_duplicates=False)
        dataset = ReactionT5Dataset(CFG, input_data)
        dataloader = DataLoader(
            dataset,
            batch_size=CFG.batch_size,
            shuffle=False,  # output rows must stay aligned with input rows
            num_workers=4,
            # Pinned host memory only helps host->GPU copies; on CPU it just warns.
            pin_memory=use_cuda,
            drop_last=False,
        )

        all_sequences, all_scores = [], []
        for inputs in tqdm(dataloader, total=len(dataloader)):
            inputs = {k: v.to(CFG.device) for k, v in inputs.items()}
            with torch.no_grad():
                output = model.generate(
                    **inputs,
                    min_length=CFG.output_min_length,
                    max_length=CFG.output_max_length,
                    num_beams=CFG.num_beams,
                    num_return_sequences=CFG.num_return_sequences,
                    return_dict_in_generate=True,
                    output_scores=True,
                )
            sequences, scores = decode_output(output, CFG)
            all_sequences.extend(sequences)
            if scores:
                all_scores.extend(scores)
            # Free per-batch generation outputs before the next batch.
            del output
            if use_cuda:
                torch.cuda.empty_cache()
            gc.collect()

        output_df = save_multiple_predictions(
            input_data, all_sequences, all_scores, CFG
        )

    # Serialize directly: the former @st.cache wrapper was redefined on every
    # click (so never reused) and st.cache is deprecated/removed in Streamlit.
    st.download_button(
        label="Download data as CSV",
        data=output_df.to_csv(index=False),
        file_name="output.csv",
        mime="text/csv",
    )