import gc
import os
import sys
import warnings

import pandas as pd
import streamlit as st
import torch
from torch.utils.data import DataLoader
from tqdm import tqdm
from transformers import AutoModelForSeq2SeqLM, AutoTokenizer

sys.path.append(
    os.path.abspath(os.path.join(os.path.dirname(__file__), "task_forward"))
)
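# Helpers from the local task_forward directory: dataset wrapping, beam-output
# decoding, and saving of multiple predictions per input.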
from generation_utils import (
    ReactionT5Dataset,
    decode_output,
    save_multiple_predictions,
)
from train import preprocess_df
from utils import seed_everything

warnings.filterwarnings("ignore")


st.title("ReactionT5 task forward")
st.markdown("""
##### Predict reaction products from your inputs.
##### Upload a CSV that contains a `REACTANT` column. Optionally include `REAGENT`, `SOLVENT`, and/or `CATALYST`.
##### If a field lists multiple compounds, separate them with a dot (`.`). For details, download **demo_reaction_data.csv** and check its contents.
##### The output shows product SMILES and the sum of log-likelihoods for each prediction, sorted by log-likelihood (index 0 is the most probable).
""")

st.download_button(
    label="Download demo_reaction_data.csv",
    data=pd.read_csv("data/demo_reaction_data.csv").to_csv(index=False),
    file_name="demo_reaction_data.csv",
    mime="text/csv",
)


# Inference configuration. The beam-count widget and file uploader are evaluated when
# this class body runs; the remaining attributes are fixed generation settings.
class CFG:
    num_beams = st.number_input(
        label="num beams", min_value=1, max_value=10, value=5, step=1
    )
    num_return_sequences = num_beams
    input_data = st.file_uploader("Choose a CSV file")
    model_name_or_path = "sagawa/ReactionT5v2-forward"
    input_column = "input"
    input_max_length = 400
    output_max_length = 300
    output_min_length = -1
    model = "t5"
    seed = 42
    batch_size = 1


if st.button("predict"):
    with st.spinner(
        "Now processing. If num beams=5, this process takes about 15 seconds per reaction."
    ):

        # Pick the GPU when available and fix random seeds for reproducible generation.
        CFG.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        seed_everything(seed=CFG.seed)

        # Load the tokenizer and seq2seq model, preferring a local checkpoint path
        # over the Hugging Face Hub name when one exists on disk.
        CFG.tokenizer = AutoTokenizer.from_pretrained(
            os.path.abspath(CFG.model_name_or_path)
            if os.path.exists(CFG.model_name_or_path)
            else CFG.model_name_or_path,
            return_tensors="pt",
        )
        model = AutoModelForSeq2SeqLM.from_pretrained(
            os.path.abspath(CFG.model_name_or_path)
            if os.path.exists(CFG.model_name_or_path)
            else CFG.model_name_or_path
        ).to(CFG.device)
        model.eval()

        # Read and preprocess the uploaded CSV; duplicates are kept so every row
        # receives its own prediction.
        input_data = pd.read_csv(CFG.input_data)
        input_data = preprocess_df(input_data, drop_duplicates=False)
        dataset = ReactionT5Dataset(CFG, input_data)
        dataloader = DataLoader(
            dataset,
            batch_size=CFG.batch_size,
            shuffle=False,
            num_workers=4,
            pin_memory=True,
            drop_last=False,
        )

        # Beam-search generation: collect decoded product SMILES and their summed
        # log-likelihoods for every input reaction.
        all_sequences, all_scores = [], []
        for inputs in tqdm(dataloader, total=len(dataloader)):
            inputs = {k: v.to(CFG.device) for k, v in inputs.items()}
            with torch.no_grad():
                output = model.generate(
                    **inputs,
                    min_length=CFG.output_min_length,
                    max_length=CFG.output_max_length,
                    num_beams=CFG.num_beams,
                    num_return_sequences=CFG.num_return_sequences,
                    return_dict_in_generate=True,
                    output_scores=True,
                )
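            # decode_output turns the beam-search output into SMILES strings and a
            # summed log-likelihood score per returned sequence.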
            sequences, scores = decode_output(output, CFG)
            all_sequences.extend(sequences)
            if scores:
                all_scores.extend(scores)
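            # Release generation buffers between batches to keep memory use bounded.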
            del output
            torch.cuda.empty_cache()
            gc.collect()

        # Combine inputs with their ranked predictions (index 0 is the most probable).
        output_df = save_multiple_predictions(
            input_data, all_sequences, all_scores, CFG
        )

        @st.cache_data  # cache the CSV conversion so Streamlit reruns reuse the result
        def convert_df(df):
            return df.to_csv(index=False)

        csv = convert_df(output_df)

        st.download_button(
            label="Download data as CSV",
            data=csv,
            file_name="output.csv",
            mime="text/csv",
        )
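

# To launch the app locally (assuming this script is saved as app.py):
#   streamlit run app.py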