Spaces:
Running
Running
File size: 4,108 Bytes
e085437 2548c76 e085437 20946b6 e085437 2548c76 e085437 50ea5b6 e085437 50ea5b6 e085437 20946b6 e085437 50ea5b6 8a0130f 50ea5b6 e085437 20946b6 062afec 50ea5b6 062afec 50ea5b6 062afec 50ea5b6 062afec 50ea5b6 e085437 50ea5b6 20946b6 50ea5b6 e085437 50ea5b6 e085437 50ea5b6 20946b6 50ea5b6 e085437 50ea5b6 e085437 965bb86 50ea5b6 e085437 50ea5b6 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 |
import gc
import os
import sys
import warnings
import pandas as pd
import streamlit as st
import torch
from torch.utils.data import DataLoader
from tqdm import tqdm
from transformers import AutoModelForSeq2SeqLM, AutoTokenizer
sys.path.append(
os.path.abspath(os.path.join(os.path.dirname(__file__), "task_forward"))
)
from generation_utils import (
ReactionT5Dataset,
decode_output,
save_multiple_predictions,
)
from train import preprocess_df
from utils import seed_everything
# Silence all Python warnings for this app run.
warnings.filterwarnings("ignore")

# --- Page header and usage instructions -------------------------------------
st.title("ReactionT5 task forward")
st.markdown("""
##### Predict reaction products from your inputs.
##### Upload a CSV that contains a `REACTANT` column. Optionally include `REAGENT`, `SOLVENT`, and/or `CATALYST`.
##### If a field lists multiple compounds, separate them with a dot (`.`). For details, download **demo_reaction_data.csv** and check its contents.
##### The output shows product SMILES and the sum of log-likelihoods for each prediction, sorted by log-likelihood (index 0 is the most probable).
""")

# --- Demo input file ---------------------------------------------------------
# The bundled example is read (path relative to the working directory) and
# re-serialised without the index column so users can download a template.
demo_frame = pd.read_csv("data/demo_reaction_data.csv")
demo_csv_text = demo_frame.to_csv(index=False)
st.download_button(
    label="Download demo_reaction_data.csv",
    data=demo_csv_text,
    file_name="demo_reaction_data.csv",
    mime="text/csv",
)
class CFG:
    # Run configuration shared with the project helpers: this script passes
    # the class itself to ReactionT5Dataset, decode_output and
    # save_multiple_predictions, and attaches `device` and `tokenizer` to it
    # at prediction time.

    # --- Fixed settings ------------------------------------------------------
    model_name_or_path = "sagawa/ReactionT5v2-forward"
    model = "t5"  # not read in this script; presumably used by the helpers
    input_column = "input"
    input_max_length = 400
    output_max_length = 300
    output_min_length = -1
    batch_size = 1
    seed = 42

    # --- User-controlled settings (Streamlit widgets) ------------------------
    # Beam width chosen in the UI; every beam is returned as a prediction.
    num_beams = st.number_input(
        label="num beams",
        min_value=1,
        max_value=10,
        value=5,
        step=1,
    )
    num_return_sequences = num_beams

    # Uploaded reactions file, handed to pd.read_csv when "predict" is pressed.
    input_data = st.file_uploader("Choose a CSV file")
# Prediction handler: runs when the "predict" button is pressed.
# Loads the tokenizer and model, generates `CFG.num_return_sequences` beam
# candidates for every uploaded reaction, and offers the result as a CSV.
if st.button("predict"):
    if CFG.input_data is None:
        # Nothing uploaded yet: pd.read_csv(None) would raise, so report the
        # problem to the user instead of surfacing a traceback.
        st.error("Please upload a CSV file before pressing predict.")
    else:
        with st.spinner(
            "Now processing. If num beams=5, this process takes about 15 seconds per reaction."
        ):
            CFG.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
            seed_everything(seed=CFG.seed)

            # Resolve the checkpoint once: use the absolute local path when
            # the name exists on disk, otherwise pass the name through as-is.
            model_path = (
                os.path.abspath(CFG.model_name_or_path)
                if os.path.exists(CFG.model_name_or_path)
                else CFG.model_name_or_path
            )
            CFG.tokenizer = AutoTokenizer.from_pretrained(
                model_path,
                return_tensors="pt",
            )
            model = AutoModelForSeq2SeqLM.from_pretrained(model_path).to(CFG.device)
            model.eval()

            input_data = pd.read_csv(CFG.input_data)
            input_data = preprocess_df(input_data, drop_duplicates=False)
            dataset = ReactionT5Dataset(CFG, input_data)
            dataloader = DataLoader(
                dataset,
                batch_size=CFG.batch_size,
                shuffle=False,
                num_workers=4,
                pin_memory=True,
                drop_last=False,
            )

            all_sequences, all_scores = [], []
            for inputs in tqdm(dataloader, total=len(dataloader)):
                inputs = {k: v.to(CFG.device) for k, v in inputs.items()}
                with torch.no_grad():
                    output = model.generate(
                        **inputs,
                        min_length=CFG.output_min_length,
                        max_length=CFG.output_max_length,
                        num_beams=CFG.num_beams,
                        num_return_sequences=CFG.num_return_sequences,
                        return_dict_in_generate=True,
                        output_scores=True,
                    )
                sequences, scores = decode_output(output, CFG)
                all_sequences.extend(sequences)
                # Scores are only collected when decode_output returns a
                # non-empty value for them.
                if scores:
                    all_scores.extend(scores)
                # Drop the generation output and release cached memory before
                # the next batch to keep peak usage down.
                del output
                torch.cuda.empty_cache()
                gc.collect()

            output_df = save_multiple_predictions(
                input_data, all_sequences, all_scores, CFG
            )
            # Serialise directly: the previous @st.cache wrapper was a
            # deprecated API around a function called exactly once.
            csv_text = output_df.to_csv(index=False)

        st.download_button(
            label="Download data as CSV",
            data=csv_text,
            file_name="output.csv",
            mime="text/csv",
        )