sagawa committed on
Commit
50ea5b6
·
verified ·
1 Parent(s): f7811db

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +78 -322
app.py CHANGED
@@ -1,289 +1,98 @@
1
- # app.py
2
  import gc
3
  import os
4
  import sys
5
  import warnings
6
- from typing import Optional, Tuple
7
 
8
  import pandas as pd
9
  import streamlit as st
10
  import torch
11
  from torch.utils.data import DataLoader
 
12
  from transformers import AutoModelForSeq2SeqLM, AutoTokenizer
13
 
14
- # Local imports
15
  sys.path.append(
16
  os.path.abspath(os.path.join(os.path.dirname(__file__), "task_forward"))
17
  )
18
- from generation_utils import ReactionT5Dataset, decode_output, save_multiple_predictions
 
 
 
 
19
  from train import preprocess_df
20
  from utils import seed_everything
21
 
22
  warnings.filterwarnings("ignore")
23
 
24
- # -----------------------------
25
- # Page / Theme / Global Styles
26
- # -----------------------------
27
-
28
- # Subtle modern styles (card-like blocks, nicer headers, compact tables)
29
- st.markdown(
30
- """
31
- <style>
32
- /* Base */
33
- .block-container {padding-top: 1.5rem; padding-bottom: 2rem;}
34
- h1, h2, h3 { letter-spacing: .2px; }
35
- .st-emotion-cache-1jicfl2 {padding: 1rem !important;} /* tabs pad (HF class may vary)*/
36
-
37
- /* Card container */
38
- .card {
39
- border-radius: 18px;
40
- padding: 1rem 1.2rem;
41
- border: 1px solid rgba(127,127,127,0.15);
42
- background: rgba(250,250,250,0.6);
43
- backdrop-filter: blur(6px);
44
- }
45
- [data-baseweb="select"] div { border-radius: 12px !important; }
46
-
47
- /* Buttons */
48
- .stButton>button {
49
- border-radius: 12px;
50
- padding: .6rem 1rem;
51
- font-weight: 600;
52
- }
53
 
54
- /* Badges */
55
- .badge {
56
- display:inline-block;
57
- padding: .35em .6em;
58
- border-radius: 10px;
59
- background: rgba(0,0,0,.08);
60
- font-size: .82rem;
61
- margin-right: .4rem;
62
- }
63
-
64
- /* Tables */
65
- .dataframe td, .dataframe th { font-size: 0.92rem; }
66
- </style>
67
- """,
68
- unsafe_allow_html=True,
69
  )
70
 
71
- # -----------------------------
72
- # Header
73
- # -----------------------------
74
- col_l, col_r = st.columns([0.78, 0.22])
75
- with col_l:
76
- st.title("ReactionT5 • Task Forward")
77
- st.markdown(
78
- """
79
- Predict **reaction products** from inputs formatted as
80
- `REACTANT:{reactants}REAGENT:{reagents}`
81
- For multiple compounds: join with `"."` • If no reagent: use a single space `" "`.
82
- """
83
- )
84
- with col_r:
85
- st.markdown("<div class='card'>", unsafe_allow_html=True)
86
- st.markdown("**Status**")
87
- gpu = torch.cuda.is_available()
88
- st.markdown(
89
- f"""
90
- <span class='badge'>Device: {"CUDA" if gpu else "CPU"}</span>
91
- <span class='badge'>Transformers</span>
92
- <span class='badge'>Streamlit</span>
93
- """,
94
- unsafe_allow_html=True,
95
- )
96
- st.markdown("</div>", unsafe_allow_html=True)
97
-
98
- # -----------------------------
99
- # Sidebar: Controls / Parameters
100
- # -----------------------------
101
- with st.sidebar:
102
- st.header("Settings")
103
-
104
- st.caption("Model")
105
- model_name_or_path = st.text_input(
106
- "Model name or path",
107
- value="sagawa/ReactionT5v2-forward",
108
- help="Hugging Face Hub repo or local path",
109
- )
110
- st.markdown("---")
111
-
112
- st.caption("Generation")
113
- num_beams = st.slider("num_beams", 1, 10, 5, 1)
114
- num_return_sequences = st.slider("num_return_sequences", 1, num_beams, num_beams, 1)
115
- output_max_length = st.slider("max_length", 64, 512, 300, 8)
116
- output_min_length = st.number_input("min_length", value=-1, step=1)
117
-
118
- st.caption("Batch / Reproducibility")
119
- batch_size = st.slider("batch_size", 1, 8, 1, 1)
120
- seed = st.number_input("seed", value=42, step=1)
121
-
122
- st.caption("Tokenizer / Input")
123
- input_max_length = st.slider("input_max_length", 64, 512, 400, 8)
124
-
125
- st.info(
126
- "Rough guide: ~15 sec / reaction with `num_beams=5`.",
127
- )
128
-
129
-
130
- # -----------------------------
131
- # Helper: caching
132
- # -----------------------------
133
- @st.cache_resource(show_spinner=False)
134
- def load_model_and_tokenizer(
135
- path_or_name: str,
136
- ) -> Tuple[AutoModelForSeq2SeqLM, AutoTokenizer]:
137
- tok = AutoTokenizer.from_pretrained(
138
- os.path.abspath(path_or_name) if os.path.exists(path_or_name) else path_or_name,
139
- return_tensors="pt",
140
- )
141
- mdl = AutoModelForSeq2SeqLM.from_pretrained(
142
- os.path.abspath(path_or_name) if os.path.exists(path_or_name) else path_or_name
143
- )
144
- return mdl, tok
145
-
146
-
147
- @st.cache_data(show_spinner=False)
148
- def read_demo_csv() -> str:
149
- df = pd.read_csv("data/demo_reaction_data.csv")
150
- return df.to_csv(index=False)
151
-
152
-
153
- @st.cache_data(show_spinner=False)
154
- def to_csv_bytes(df: pd.DataFrame) -> bytes:
155
- return df.to_csv(index=False).encode("utf-8")
156
-
157
-
158
- # -----------------------------
159
- # I/O Tabs
160
- # -----------------------------
161
- tabs = st.tabs(["Input", "Output", "Guide"])
162
- with tabs[0]:
163
- st.markdown("<div class='card'>", unsafe_allow_html=True)
164
- st.subheader("Provide your input")
165
 
166
- input_mode = st.radio(
167
- "Choose input mode",
168
- options=("CSV upload", "Text area"),
169
- horizontal=True,
170
  )
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
171
 
172
- csv_buffer: Optional[bytes] = None
173
- text_area_value: Optional[str] = None
 
174
 
175
- if input_mode == "CSV upload":
176
- st.caption('CSV must contain an `"input"` column.')
177
- up = st.file_uploader("Upload CSV", type=["csv"])
178
- if up is not None:
179
- csv_buffer = up.read()
180
- st.success("CSV uploaded.")
181
- st.download_button(
182
- label="Download demo_reaction_data.csv",
183
- data=read_demo_csv(),
184
- file_name="demo_reaction_data.csv",
185
- mime="text/csv",
186
- use_container_width=True,
187
  )
188
- else:
189
- st.caption('Each line will be treated as one sample in the `"input"` column.')
190
- text_area_value = st.text_area(
191
- "Enter one or more inputs (one per line)",
192
- height=140,
193
- placeholder="REACTANT:CCO.REAGENT:O\nREACTANT:CC(=O)O.REAGENT: ",
 
 
 
 
 
 
 
 
 
 
 
194
  )
195
 
196
- st.markdown("</div>", unsafe_allow_html=True)
197
-
198
- with tabs[2]:
199
- st.markdown("<div class='card'>", unsafe_allow_html=True)
200
- st.subheader("Formatting rules")
201
- st.markdown(
202
- """
203
- - **Template**: `REACTANT:{reactants}REAGENT:{reagents}`
204
- - **Multiple compounds**: join with `"."`
205
- - **No reagent**: provide a single space `" "` after `REAGENT:`
206
- - **CSV schema**: must contain an `input` column
207
- - **Outputs**: predicted products (SMILES) and sum of log-likelihood per hypothesis
208
- """
209
- )
210
- st.markdown("</div>", unsafe_allow_html=True)
211
-
212
- # -----------------------------
213
- # Predict Button
214
- # -----------------------------
215
- run = st.button("🚀 Predict", use_container_width=True)
216
-
217
- # -----------------------------
218
- # Execution
219
- # -----------------------------
220
- if run:
221
- # Validate input
222
- if input_mode == "CSV upload" and not csv_buffer:
223
- st.error(
224
- "Please upload a CSV file with an `input` column, or switch to Text area."
225
- )
226
- st.stop()
227
-
228
- if input_mode == "Text area" and (
229
- text_area_value is None or not text_area_value.strip()
230
- ):
231
- st.error("Please enter at least one line of input.")
232
- st.stop()
233
-
234
- with st.status("Initializing model & tokenizer…", expanded=False) as status:
235
- device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
236
- seed_everything(seed=seed)
237
- model, tokenizer = load_model_and_tokenizer(model_name_or_path)
238
- model = model.to(device).eval()
239
- status.update(label="Model ready", state="complete")
240
-
241
- # Prepare dataframe
242
- if input_mode == "CSV upload":
243
- df_in = pd.read_csv(pd.io.common.BytesIO(csv_buffer))
244
- else:
245
- lines = [x.strip() for x in text_area_value.splitlines() if x.strip()]
246
- df_in = pd.DataFrame({"input": lines})
247
-
248
- # Preprocess and dataset
249
- try:
250
- df_in = preprocess_df(df_in, drop_duplicates=False)
251
- except Exception as e:
252
- st.error(f"Input preprocessing failed: {e}")
253
- st.stop()
254
-
255
- class CFG:
256
- # Configuration object used by ReactionT5Dataset/decode_output utilities
257
- num_beams = num_beams
258
- num_return_sequences = num_return_sequences
259
- model_name_or_path = model_name_or_path
260
- input_column = "input"
261
- input_max_length = input_max_length
262
- output_max_length = output_max_length
263
- output_min_length = output_min_length
264
- model = "t5"
265
- seed = seed
266
- batch_size = batch_size
267
- device = device
268
- tokenizer = tokenizer
269
-
270
- dataset = ReactionT5Dataset(CFG, df_in)
271
- dataloader = DataLoader(
272
- dataset,
273
- batch_size=CFG.batch_size,
274
- shuffle=False,
275
- num_workers=0 if not torch.cuda.is_available() else 4,
276
- pin_memory=torch.cuda.is_available(),
277
- drop_last=False,
278
- )
279
-
280
- # Progress UI
281
- total_steps = len(dataloader)
282
- progress = st.progress(0, text=f"Running generation… 0 / {total_steps}")
283
- all_sequences, all_scores = [], []
284
-
285
- try:
286
- for idx, inputs in enumerate(dataloader, start=1):
287
  inputs = {k: v.to(CFG.device) for k, v in inputs.items()}
288
  with torch.no_grad():
289
  output = model.generate(
@@ -299,76 +108,23 @@ if run:
299
  all_sequences.extend(sequences)
300
  if scores:
301
  all_scores.extend(scores)
302
-
303
- # Memory hygiene
304
  del output
305
- if torch.cuda.is_available():
306
- torch.cuda.empty_cache()
307
  gc.collect()
308
 
309
- progress.progress(
310
- idx / total_steps, text=f"Running generation… {idx} / {total_steps}"
311
- )
312
 
313
- st.toast("Generation complete")
314
- except Exception as e:
315
- st.error(f"Generation failed: {e}")
316
- st.stop()
317
 
318
- # Save & show
319
- try:
320
- output_df = save_multiple_predictions(df_in, all_sequences, all_scores, CFG)
321
- except Exception as e:
322
- st.error(f"Post-processing failed: {e}")
323
- st.stop()
324
 
325
- with tabs[1]:
326
- st.subheader("Results")
327
- st.dataframe(output_df, use_container_width=True, hide_index=True)
328
  st.download_button(
329
- label="Download results (CSV)",
330
- data=to_csv_bytes(output_df),
331
- file_name="reactiont5_output.csv",
332
  mime="text/csv",
333
- use_container_width=True,
334
- )
335
-
336
- # -----------------------------
337
- # Footer Note (replace this whole block)
338
- # -----------------------------
339
- st.markdown(
340
- """
341
- <hr/>
342
- <div style="font-size:0.95rem; line-height:1.6">
343
- <strong>Citation</strong><br/>
344
- Sagawa, T., & Kojima, R. (2025).
345
- <em>ReactionT5: a pre-trained transformer model for accurate chemical reaction prediction with limited data</em>.
346
- <em>Journal of Cheminformatics</em>, 17(1), 126.
347
- <a href="https://doi.org/10.1186/s13321-025-01075-4" target="_blank" rel="noopener">
348
- https://doi.org/10.1186/s13321-025-01075-4
349
- </a>
350
-
351
- <details style="margin-top: .5rem;">
352
- <summary style="cursor: pointer;">Show BibTeX</summary>
353
- <pre style="white-space: pre-wrap; font-size:0.9rem; margin-top:.5rem;">
354
- @article{Sagawa2025,
355
- title = {ReactionT5: a pre-trained transformer model for accurate chemical reaction prediction with limited data},
356
- author = {Sagawa, Tatsuya and Kojima, Ryosuke},
357
- journal = {Journal of Cheminformatics},
358
- year = {2025},
359
- volume = {17},
360
- number = {1},
361
- pages = {126},
362
- doi = {10.1186/s13321-025-01075-4},
363
- url = {https://doi.org/10.1186/s13321-025-01075-4}
364
- }
365
- </pre>
366
- </details>
367
-
368
- <div style="margin-top:.75rem; color:#666;">
369
- Built with Streamlit and Transformers.
370
- </div>
371
- </div>
372
- """,
373
- unsafe_allow_html=True,
374
- )
 
 
1
  import gc
2
  import os
3
  import sys
4
  import warnings
 
5
 
6
  import pandas as pd
7
  import streamlit as st
8
  import torch
9
  from torch.utils.data import DataLoader
10
+ from tqdm import tqdm
11
  from transformers import AutoModelForSeq2SeqLM, AutoTokenizer
12
 
 
13
  sys.path.append(
14
  os.path.abspath(os.path.join(os.path.dirname(__file__), "task_forward"))
15
  )
16
+ from generation_utils import (
17
+ ReactionT5Dataset,
18
+ decode_output,
19
+ save_multiple_predictions,
20
+ )
21
  from train import preprocess_df
22
  from utils import seed_everything
23
 
24
  warnings.filterwarnings("ignore")
25
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
26
 
27
+ st.title("ReactionT5 task forward")
28
+ st.markdown("""
29
+ ##### At this space, you can predict the products of reactions from their inputs.
30
+ ##### The code expects input_data as a string or CSV file that contains an "input" column.
31
+ ##### The format of the string or contents of the column should be "REACTANT:{reactants}REAGENT:{reagents}".
32
+ ##### If there is no reagent, fill the blank with a space. For multiple compounds, concatenate them with ".".
33
+ ##### The output contains SMILES of predicted products and the sum of log-likelihood for each prediction, ordered by their log-likelihood (0th is the most probable product).
34
+ """)
35
+
36
+ st.download_button(
37
+ label="Download demo_reaction_data.csv",
38
+ data=pd.read_csv("data/demo_reaction_data.csv").to_csv(index=False),
39
+ file_name="demo_reaction_data.csv",
40
+ mime="text/csv",
 
41
  )
42
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
43
 
44
+ class CFG:
45
+ num_beams = st.number_input(
46
+ label="num beams", min_value=1, max_value=10, value=5, step=1
 
47
  )
48
+ num_return_sequences = num_beams
49
+ input_data = st.file_uploader("Choose a CSV file")
50
+ model_name_or_path = "sagawa/ReactionT5v2-forward"
51
+ input_column = "input"
52
+ input_max_length = 400
53
+ output_max_length = 300
54
+ output_min_length = -1
55
+ model = "t5"
56
+ seed = 42
57
+ batch_size = 1
58
+
59
+
60
+ if st.button("predict"):
61
+ with st.spinner(
62
+ "Now processing. If num beams=5, this process takes about 15 seconds per reaction."
63
+ ):
64
 
65
+ CFG.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
66
+
67
+ seed_everything(seed=CFG.seed)
68
 
69
+ CFG.tokenizer = AutoTokenizer.from_pretrained(
70
+ os.path.abspath(CFG.model_name_or_path)
71
+ if os.path.exists(CFG.model_name_or_path)
72
+ else CFG.model_name_or_path,
73
+ return_tensors="pt",
 
 
 
 
 
 
 
74
  )
75
+ model = AutoModelForSeq2SeqLM.from_pretrained(
76
+ os.path.abspath(CFG.model_name_or_path)
77
+ if os.path.exists(CFG.model_name_or_path)
78
+ else CFG.model_name_or_path
79
+ ).to(CFG.device)
80
+ model.eval()
81
+
82
+ input_data = pd.read_csv(CFG.input_data)
83
+ input_data = preprocess_df(input_data, drop_duplicates=False)
84
+ dataset = ReactionT5Dataset(CFG, input_data)
85
+ dataloader = DataLoader(
86
+ dataset,
87
+ batch_size=CFG.batch_size,
88
+ shuffle=False,
89
+ num_workers=4,
90
+ pin_memory=True,
91
+ drop_last=False,
92
  )
93
 
94
+ all_sequences, all_scores = [], []
95
+ for inputs in tqdm(dataloader, total=len(dataloader)):
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
96
  inputs = {k: v.to(CFG.device) for k, v in inputs.items()}
97
  with torch.no_grad():
98
  output = model.generate(
 
108
  all_sequences.extend(sequences)
109
  if scores:
110
  all_scores.extend(scores)
 
 
111
  del output
112
+ torch.cuda.empty_cache()
 
113
  gc.collect()
114
 
115
+ output_df = save_multiple_predictions(
116
+ input_data, all_sequences, all_scores, CFG
117
+ )
118
 
119
+ @st.cache
120
+ def convert_df(df):
121
+ return df.to_csv(index=False)
 
122
 
123
+ csv = convert_df(output_df)
 
 
 
 
 
124
 
 
 
 
125
  st.download_button(
126
+ label="Download data as CSV",
127
+ data=csv,
128
+ file_name="output.csv",
129
  mime="text/csv",
130
+ )