sagawa commited on
Commit
4a51867
·
verified ·
1 Parent(s): 8a0130f

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +227 -65
app.py CHANGED
@@ -2,14 +2,15 @@ import gc
2
  import os
3
  import sys
4
  import warnings
 
5
 
6
  import pandas as pd
7
  import streamlit as st
8
  import torch
9
  from torch.utils.data import DataLoader
10
- from tqdm import tqdm
11
  from transformers import AutoModelForSeq2SeqLM, AutoTokenizer
12
 
 
13
  sys.path.append(
14
  os.path.abspath(os.path.join(os.path.dirname(__file__), "task_forward"))
15
  )
@@ -23,76 +24,218 @@ from utils import seed_everything
23
 
24
  warnings.filterwarnings("ignore")
25
 
 
 
 
 
 
 
 
 
 
 
 
 
 
26
 
27
- st.title("ReactionT5 task forward")
28
- st.markdown("""
29
- ##### Predict reaction products from your inputs.
30
- ##### Upload a CSV that contains a `REACTANT` column. Optionally include `REAGENT`, `SOLVENT`, and/or `CATALYST`.
31
- ##### If a field lists multiple compounds, separate them with a dot (`.`). For details, download **demo_reaction_data.csv** and check its contents.
32
- ##### The output shows product SMILES and the sum of log-likelihoods for each prediction, sorted by log-likelihood (index 0 is the most probable).
33
- """)
 
 
 
 
 
 
 
 
 
 
 
34
 
35
  st.download_button(
36
  label="Download demo_reaction_data.csv",
37
- data=pd.read_csv("data/demo_reaction_data.csv").to_csv(index=False),
38
  file_name="demo_reaction_data.csv",
39
  mime="text/csv",
 
40
  )
41
 
 
 
 
 
 
 
 
42
 
43
- class CFG:
44
- num_beams = st.number_input(
45
- label="num beams", min_value=1, max_value=10, value=5, step=1
 
 
 
 
 
 
 
 
 
 
 
 
 
46
  )
47
- num_return_sequences = num_beams
48
- input_data = st.file_uploader("Choose a CSV file")
49
- model_name_or_path = "sagawa/ReactionT5v2-forward"
50
- input_column = "input"
51
- input_max_length = 400
52
- output_max_length = 300
53
- output_min_length = -1
54
- model = "t5"
55
- seed = 42
56
- batch_size = 1
57
-
58
-
59
- if st.button("predict"):
60
- with st.spinner(
61
- "Now processing. If num beams=5, this process takes about 15 seconds per reaction."
62
- ):
63
-
64
- CFG.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
65
-
66
- seed_everything(seed=CFG.seed)
67
 
68
- CFG.tokenizer = AutoTokenizer.from_pretrained(
69
- os.path.abspath(CFG.model_name_or_path)
70
- if os.path.exists(CFG.model_name_or_path)
71
- else CFG.model_name_or_path,
72
- return_tensors="pt",
73
  )
74
- model = AutoModelForSeq2SeqLM.from_pretrained(
75
- os.path.abspath(CFG.model_name_or_path)
76
- if os.path.exists(CFG.model_name_or_path)
77
- else CFG.model_name_or_path
78
- ).to(CFG.device)
79
- model.eval()
80
-
81
- input_data = pd.read_csv(CFG.input_data)
82
- input_data = preprocess_df(input_data, drop_duplicates=False)
83
- dataset = ReactionT5Dataset(CFG, input_data)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
84
  dataloader = DataLoader(
85
  dataset,
86
  batch_size=CFG.batch_size,
87
  shuffle=False,
88
- num_workers=4,
89
- pin_memory=True,
90
  drop_last=False,
91
  )
92
 
 
93
  all_sequences, all_scores = [], []
94
- for inputs in tqdm(dataloader, total=len(dataloader)):
95
- inputs = {k: v.to(CFG.device) for k, v in inputs.items()}
 
 
 
 
96
  with torch.no_grad():
97
  output = model.generate(
98
  **inputs,
@@ -107,23 +250,42 @@ if st.button("predict"):
107
  all_sequences.extend(sequences)
108
  if scores:
109
  all_scores.extend(scores)
 
110
  del output
111
- torch.cuda.empty_cache()
 
112
  gc.collect()
113
 
114
- output_df = save_multiple_predictions(
115
- input_data, all_sequences, all_scores, CFG
116
- )
 
 
117
 
118
- @st.cache
119
- def convert_df(df):
120
- return df.to_csv(index=False)
 
 
 
 
 
 
121
 
122
- csv = convert_df(output_df)
 
 
 
 
 
 
 
 
 
 
 
 
 
123
 
124
- st.download_button(
125
- label="Download data as CSV",
126
- data=csv,
127
- file_name="output.csv",
128
- mime="text/csv",
129
- )
 
2
  import os
3
  import sys
4
  import warnings
5
+ from types import SimpleNamespace
6
 
7
  import pandas as pd
8
  import streamlit as st
9
  import torch
10
  from torch.utils.data import DataLoader
 
11
  from transformers import AutoModelForSeq2SeqLM, AutoTokenizer
12
 
13
+ # Local imports
14
  sys.path.append(
15
  os.path.abspath(os.path.join(os.path.dirname(__file__), "task_forward"))
16
  )
 
24
 
25
  warnings.filterwarnings("ignore")
26
 
27
+ # ------------------------------
28
+ # Page setup
29
+ # ------------------------------
30
+ st.set_page_config(
31
+ page_title="ReactionT5 — Product Prediction",
32
+ page_icon=None,
33
+ layout="wide",
34
+ )
35
+
36
+ st.title("ReactionT5 — Product Prediction")
37
+ st.caption(
38
+ "Predict reaction products from your inputs using a pretrained ReactionT5 model."
39
+ )
40
 
41
+ with st.expander("How to format your CSV", expanded=False):
42
+ st.markdown(
43
+ """
44
+ - Include a required `REACTANT` column.
45
+ - Optional columns: `REAGENT`, `SOLVENT`, `CATALYST`.
46
+ - If a field lists multiple compounds, separate them with a dot (`.`).
47
+ - For details, download **demo_reaction_data.csv** and check its contents.
48
+ - Output contains predicted product SMILES and the sum of log-likelihoods for each prediction, sorted by log-likelihood (index 0 is most probable).
49
+ """
50
+ )
51
+
52
+ # ------------------------------
53
+ # Demo data download
54
+ # ------------------------------
55
+ @st.cache_data(show_spinner=False)
56
+ def load_demo_csv_as_bytes() -> bytes:
57
+ demo_df = pd.read_csv("data/demo_reaction_data.csv")
58
+ return demo_df.to_csv(index=False).encode("utf-8")
59
 
60
  st.download_button(
61
  label="Download demo_reaction_data.csv",
62
+ data=load_demo_csv_as_bytes(),
63
  file_name="demo_reaction_data.csv",
64
  mime="text/csv",
65
+ use_container_width=True,
66
  )
67
 
68
+ st.divider()
69
+
70
+ # ------------------------------
71
+ # Sidebar: configuration
72
+ # ------------------------------
73
+ with st.sidebar:
74
+ st.header("Configuration")
75
 
76
+ model_name_or_path = st.text_input(
77
+ "Model",
78
+ value="sagawa/ReactionT5v2-forward",
79
+ help="Hugging Face model repo or a local path.",
80
+ )
81
+
82
+ num_beams = st.slider(
83
+ "Beam size",
84
+ min_value=1, max_value=10, value=5, step=1,
85
+ help="Number of beams for beam search.",
86
+ )
87
+
88
+ seed = st.number_input(
89
+ "Random seed",
90
+ min_value=0, max_value=2**32 - 1, value=42, step=1,
91
+ help="Seed for reproducibility.",
92
  )
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
93
 
94
+ with st.expander("Advanced generation", expanded=False):
95
+ input_max_length = st.number_input(
96
+ "Input max length", min_value=8, max_value=1024, value=400, step=8
 
 
97
  )
98
+ output_max_length = st.number_input(
99
+ "Output max length", min_value=8, max_value=1024, value=300, step=8
100
+ )
101
+ output_min_length = st.number_input(
102
+ "Output min length", min_value=-1, max_value=1024, value=-1, step=1,
103
+ help="Use -1 to let the model decide.",
104
+ )
105
+ batch_size = st.number_input(
106
+ "Batch size", min_value=1, max_value=16, value=1, step=1
107
+ )
108
+ num_workers = st.number_input(
109
+ "DataLoader workers", min_value=0, max_value=8, value=4, step=1,
110
+ help="Set to 0 if multiprocessing is restricted in your environment.",
111
+ )
112
+
113
+ device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
114
+ st.caption(f"Detected device: **{device.type.upper()}**")
115
+
116
+ # ------------------------------
117
+ # Cached loaders
118
+ # ------------------------------
119
+ @st.cache_resource(show_spinner=False)
120
+ def load_tokenizer(model_ref: str):
121
+ resolved = os.path.abspath(model_ref) if os.path.exists(model_ref) else model_ref
122
+ return AutoTokenizer.from_pretrained(resolved, return_tensors="pt")
123
+
124
+ @st.cache_resource(show_spinner=True)
125
+ def load_model(model_ref: str, device_str: str):
126
+ resolved = os.path.abspath(model_ref) if os.path.exists(model_ref) else model_ref
127
+ model = AutoModelForSeq2SeqLM.from_pretrained(resolved)
128
+ model.to(torch.device(device_str))
129
+ model.eval()
130
+ return model
131
+
132
+ @st.cache_data(show_spinner=False)
133
+ def df_to_csv_bytes(df: pd.DataFrame) -> bytes:
134
+ return df.to_csv(index=False).encode("utf-8")
135
+
136
+ # ------------------------------
137
+ # Main interaction
138
+ # ------------------------------
139
+ left, right = st.columns([1.4, 1.0], vertical_alignment="top")
140
+
141
+ with left:
142
+ with st.form("predict_form", clear_on_submit=False):
143
+ uploaded = st.file_uploader(
144
+ "Upload a CSV file with reactions",
145
+ type=["csv"],
146
+ accept_multiple_files=False,
147
+ help="Must contain a REACTANT column. Optional: REAGENT, SOLVENT, CATALYST.",
148
+ )
149
+ run = st.form_submit_button("Predict", use_container_width=True)
150
+
151
+ if uploaded is not None:
152
+ try:
153
+ raw_df = pd.read_csv(uploaded)
154
+ st.subheader("Input preview")
155
+ st.dataframe(raw_df.head(20), use_container_width=True)
156
+ except Exception as e:
157
+ st.error(f"Failed to read CSV: {e}")
158
+
159
+ with right:
160
+ st.subheader("Notes")
161
+ st.markdown(
162
+ f"""
163
+ - Beam size: **{num_beams}**
164
+ - Approximate time: about **15 seconds per reaction** when `beam size = 5` (varies by hardware).
165
+ - Results include the **sum of log-likelihoods** per prediction and are **sorted** by that value.
166
+ """
167
+ )
168
+ st.info(
169
+ "If you encounter CUDA OOM issues, reduce max lengths or beam size, or switch to CPU."
170
+ )
171
+
172
+ # ------------------------------
173
+ # Inference
174
+ # ------------------------------
175
+ if 'results_df' not in st.session_state:
176
+ st.session_state['results_df'] = None
177
+
178
+ if 'last_error' not in st.session_state:
179
+ st.session_state['last_error'] = None
180
+
181
+ if run:
182
+ if uploaded is None:
183
+ st.warning("Please upload a CSV file before running prediction.")
184
+ else:
185
+ # Build config object expected by your dataset/utils
186
+ CFG = SimpleNamespace(
187
+ num_beams=int(num_beams),
188
+ num_return_sequences=int(num_beams), # tie to beams by default
189
+ model_name_or_path=model_name_or_path,
190
+ input_column="input",
191
+ input_max_length=int(input_max_length),
192
+ output_max_length=int(output_max_length),
193
+ output_min_length=int(output_min_length),
194
+ model="t5",
195
+ seed=int(seed),
196
+ batch_size=int(batch_size),
197
+ )
198
+
199
+ seed_everything(seed=CFG.seed)
200
+
201
+ # Load model & tokenizer
202
+ with st.status("Loading model and tokenizer...", expanded=False) as status:
203
+ try:
204
+ tokenizer = load_tokenizer(CFG.model_name_or_path)
205
+ model = load_model(CFG.model_name_or_path, device.type)
206
+ status.update(label="Model ready.", state="complete")
207
+ except Exception as e:
208
+ st.session_state['last_error'] = f"Failed to load model: {e}"
209
+ status.update(label="Model load failed.", state="error")
210
+ st.stop()
211
+
212
+ # Prepare data
213
+ try:
214
+ input_df = pd.read_csv(uploaded)
215
+ input_df = preprocess_df(input_df, drop_duplicates=False)
216
+ except Exception as e:
217
+ st.error(f"Failed to preprocess input: {e}")
218
+ st.stop()
219
+
220
+ # Dataset & loader
221
+ dataset = ReactionT5Dataset(CFG, input_df)
222
  dataloader = DataLoader(
223
  dataset,
224
  batch_size=CFG.batch_size,
225
  shuffle=False,
226
+ num_workers=int(num_workers),
227
+ pin_memory=(device.type == "cuda"),
228
  drop_last=False,
229
  )
230
 
231
+ # Generation loop with progress
232
  all_sequences, all_scores = [], []
233
+ total = len(dataloader)
234
+ progress = st.progress(0, text="Generating predictions...")
235
+ info_placeholder = st.empty()
236
+
237
+ for i, inputs in enumerate(dataloader, start=1):
238
+ inputs = {k: v.to(device) for k, v in inputs.items()}
239
  with torch.no_grad():
240
  output = model.generate(
241
  **inputs,
 
250
  all_sequences.extend(sequences)
251
  if scores:
252
  all_scores.extend(scores)
253
+
254
  del output
255
+ if device.type == "cuda":
256
+ torch.cuda.empty_cache()
257
  gc.collect()
258
 
259
+ progress.progress(i / total, text=f"Generating predictions... {i}/{total}")
260
+ info_placeholder.caption(f"Processed batch {i} of {total}")
261
+
262
+ progress.empty()
263
+ info_placeholder.empty()
264
 
265
+ # Save predictions
266
+ try:
267
+ output_df = save_multiple_predictions(input_df, all_sequences, all_scores, CFG)
268
+ st.session_state['results_df'] = output_df
269
+ st.success("Prediction complete.")
270
+ except Exception as e:
271
+ st.session_state['last_error'] = f"Failed to assemble output: {e}"
272
+ st.error(st.session_state['last_error'])
273
+ st.stop()
274
 
275
+ # ------------------------------
276
+ # Results
277
+ # ------------------------------
278
+ if st.session_state.get('results_df') is not None:
279
+ st.subheader("Results preview")
280
+ st.dataframe(st.session_state['results_df'].head(50), use_container_width=True)
281
+
282
+ st.download_button(
283
+ label="Download predictions as CSV",
284
+ data=df_to_csv_bytes(st.session_state['results_df']),
285
+ file_name="output.csv",
286
+ mime="text/csv",
287
+ use_container_width=True,
288
+ )
289
 
290
+ if st.session_state.get('last_error'):
291
+ st.error(st.session_state['last_error'])