# Spaces: Running
import gc | |
import os | |
import sys | |
import warnings | |
import pandas as pd | |
import streamlit as st | |
import torch | |
from torch.utils.data import DataLoader | |
from tqdm import tqdm | |
from transformers import AutoModelForSeq2SeqLM, AutoTokenizer | |
sys.path.append( | |
os.path.abspath(os.path.join(os.path.dirname(__file__), "task_forward")) | |
) | |
from generation_utils import ( | |
ReactionT5Dataset, | |
decode_output, | |
save_multiple_predictions, | |
) | |
from train import preprocess_df | |
from utils import seed_everything | |
# Suppress all Python warnings for the whole process (keeps the app output clean).
warnings.filterwarnings("ignore")

# Page title and usage instructions shown at the top of the app.
st.title("ReactionT5 task forward")
st.markdown("""
##### Predict reaction products from your inputs.
##### Upload a CSV that contains a `REACTANT` column. Optionally include `REAGENT`, `SOLVENT`, and/or `CATALYST`.
##### If a field lists multiple compounds, separate them with a dot (`.`). For details, download **demo_reaction_data.csv** and check its contents.
##### The output shows product SMILES and the sum of log-likelihoods for each prediction, sorted by log-likelihood (index 0 is the most probable).
""")

# Locate the demo CSV relative to this script, the same way the
# ``task_forward`` import path is anchored to ``__file__``, so the app does not
# depend on the directory it was launched from. Fall back to the original
# CWD-relative location if the file is not next to the script.
_DEMO_CSV_RELPATH = os.path.join("data", "demo_reaction_data.csv")
_demo_csv_path = os.path.join(
    os.path.dirname(os.path.abspath(__file__)), _DEMO_CSV_RELPATH
)
if not os.path.exists(_demo_csv_path):
    _demo_csv_path = _DEMO_CSV_RELPATH

# Offer the demo file for download; it is round-tripped through pandas so the
# served CSV has no index column.
st.download_button(
    label="Download demo_reaction_data.csv",
    data=pd.read_csv(_demo_csv_path).to_csv(index=False),
    file_name="demo_reaction_data.csv",
    mime="text/csv",
)
class CFG:
    """Run configuration for the forward-prediction app.

    All settings are class attributes read as ``CFG.<name>``; the class itself
    is also handed to the ``task_forward`` helpers (dataset, decoding, saving).
    The two Streamlit widgets below are rendered each time the script runs,
    because they are evaluated when this class body executes.
    """

    # Beam width chosen in the UI (integer between 1 and 10, default 5).
    num_beams = st.number_input(
        label="num beams", min_value=1, max_value=10, value=5, step=1
    )
    # Return every beam, so each input yields ``num_beams`` ranked candidates.
    num_return_sequences = num_beams
    # Uploaded input file, or None until the user uploads one. Restricted to
    # the .csv extension so other file types are rejected by the uploader
    # instead of failing later inside pandas.
    input_data = st.file_uploader("Choose a CSV file", type=["csv"])

    # Local checkpoint directory if it exists, otherwise a Hub model id.
    model_name_or_path = "sagawa/ReactionT5v2-forward"
    # NOTE(review): presumably the column built by ``preprocess_df`` that the
    # dataset tokenizes — confirm in task_forward.
    input_column = "input"
    input_max_length = 400
    output_max_length = 300
    # Passed as ``min_length`` to ``model.generate``; a negative value
    # presumably disables the minimum-length constraint — confirm for the
    # installed transformers version.
    output_min_length = -1
    # NOTE(review): not read in this file; presumably consumed by the
    # task_forward helpers — verify.
    model = "t5"
    seed = 42
    batch_size = 1
def _resolve_model_path(name_or_path):
    """Return the absolute path if *name_or_path* exists locally, else return it unchanged.

    This lets ``from_pretrained`` load a local checkpoint directory when one is
    present and otherwise treat the value as a model identifier.
    """
    if os.path.exists(name_or_path):
        return os.path.abspath(name_or_path)
    return name_or_path


def _generate_predictions(model, dataloader):
    """Run beam-search generation over every batch of *dataloader*.

    Returns:
        ``(all_sequences, all_scores)``: flat lists accumulated, in dataloader
        order, from ``decode_output`` for each batch. ``all_scores`` only
        receives entries for batches where ``decode_output`` returned scores.
    """
    all_sequences, all_scores = [], []
    with torch.no_grad():
        for inputs in tqdm(dataloader, total=len(dataloader)):
            inputs = {k: v.to(CFG.device) for k, v in inputs.items()}
            output = model.generate(
                **inputs,
                min_length=CFG.output_min_length,
                max_length=CFG.output_max_length,
                num_beams=CFG.num_beams,
                num_return_sequences=CFG.num_return_sequences,
                return_dict_in_generate=True,
                output_scores=True,
            )
            sequences, scores = decode_output(output, CFG)
            all_sequences.extend(sequences)
            if scores:
                all_scores.extend(scores)
            # Release this batch's generation outputs before the next batch.
            del output
            torch.cuda.empty_cache()
            gc.collect()
    return all_sequences, all_scores


if st.button("predict"):
    # Without this guard, clicking predict before uploading crashed in
    # pd.read_csv(None); explain what is missing and end this script run.
    if CFG.input_data is None:
        st.warning("Please upload a CSV file before clicking predict.")
        st.stop()

    with st.spinner(
        "Now processing. If num beams=5, this process takes about 15 seconds per reaction."
    ):
        CFG.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        seed_everything(seed=CFG.seed)

        # Read the upload before the (expensive) model load so bad input
        # fails fast with a readable message instead of a traceback.
        try:
            input_data = pd.read_csv(CFG.input_data)
        except (
            pd.errors.EmptyDataError,
            pd.errors.ParserError,
            UnicodeDecodeError,
        ) as err:
            st.error(f"Could not read the uploaded CSV file: {err}")
            st.stop()
        input_data = preprocess_df(input_data, drop_duplicates=False)

        model_path = _resolve_model_path(CFG.model_name_or_path)
        # Stored on CFG because the dataset/decoding helpers receive CFG
        # (presumably they read the tokenizer from it).
        CFG.tokenizer = AutoTokenizer.from_pretrained(
            model_path,
            return_tensors="pt",
        )
        model = AutoModelForSeq2SeqLM.from_pretrained(model_path).to(CFG.device)
        model.eval()

        dataset = ReactionT5Dataset(CFG, input_data)
        dataloader = DataLoader(
            dataset,
            batch_size=CFG.batch_size,
            shuffle=False,
            num_workers=4,
            # Pinned host memory only helps host-to-GPU copies.
            pin_memory=CFG.device.type == "cuda",
            drop_last=False,
        )

        all_sequences, all_scores = _generate_predictions(model, dataloader)
        output_df = save_multiple_predictions(
            input_data, all_sequences, all_scores, CFG
        )

    st.download_button(
        label="Download data as CSV",
        data=output_df.to_csv(index=False),
        file_name="output.csv",
        mime="text/csv",
    )