sagawa commited on
Commit
062afec
·
verified ·
1 Parent(s): c3b9df6

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +297 -126
app.py CHANGED
@@ -1,143 +1,290 @@
 
1
  import gc
2
  import os
3
  import sys
4
  import warnings
 
5
 
6
  import pandas as pd
7
  import streamlit as st
8
  import torch
9
  from torch.utils.data import DataLoader
10
- from tqdm import tqdm
11
  from transformers import AutoModelForSeq2SeqLM, AutoTokenizer
12
 
 
13
  sys.path.append(
14
  os.path.abspath(os.path.join(os.path.dirname(__file__), "task_forward"))
15
  )
16
- from generation_utils import (
17
- ReactionT5Dataset,
18
- decode_output,
19
- save_multiple_predictions,
20
- )
21
  from train import preprocess_df
22
  from utils import seed_everything
23
 
24
  warnings.filterwarnings("ignore")
25
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
26
 
27
- st.title("ReactionT5 task forward")
28
- st.markdown("""
29
- ##### At this space, you can predict the products of reactions from their inputs.
30
- ##### The code expects input_data as a string or CSV file that contains an "input" column.
31
- ##### The format of the string or contents of the column should be "REACTANT:{reactants}REAGENT:{reagents}".
32
- ##### If there is no reagent, fill the blank with a space. For multiple compounds, concatenate them with ".".
33
- ##### The output contains SMILES of predicted products and the sum of log-likelihood for each prediction, ordered by their log-likelihood (0th is the most probable product).
34
- """)
35
-
36
- st.download_button(
37
- label="Download demo_reaction_data.csv",
38
- data=pd.read_csv("data/demo_reaction_data.csv").to_csv(index=False),
39
- file_name="demo_reaction_data.csv",
40
- mime="text/csv",
 
41
  )
42
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
43
 
44
- class CFG:
45
- num_beams = st.number_input(
46
- label="num beams", min_value=1, max_value=10, value=5, step=1
 
 
 
 
 
 
 
 
 
47
  )
48
- num_return_sequences = num_beams
49
- input_data = st.file_uploader("Choose a CSV file")
50
- model_name_or_path = "sagawa/ReactionT5v2-forward"
51
- input_column = "input"
52
- input_max_length = 400
53
- output_max_length = 300
54
- output_min_length = -1
55
- model = "t5"
56
- seed = 42
57
- batch_size = 1
58
-
59
-
60
- if st.button("predict"):
61
- with st.spinner(
62
- "Now processing. If num beams=5, this process takes about 15 seconds per reaction."
63
- ):
64
- # device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
65
-
66
- # seed_everything(seed=CFG.seed)
67
-
68
- # tokenizer = AutoTokenizer.from_pretrained(CFG.model_name_or_path, return_tensors="pt")
69
- # model = AutoModelForSeq2SeqLM.from_pretrained(CFG.model_name_or_path).to(device)
70
- # model.eval()
71
-
72
- # if CFG.uploaded_file is None:
73
- # input_compound = CFG.input_data
74
- # output = predict_single_input(input_compound)
75
- # sequences, scores = decode_output(output)
76
- # output_df = save_single_prediction(input_compound, sequences, scores)
77
- # else:
78
- # input_data = pd.read_csv(CFG.uploaded_file)
79
- # dataset = ProductDataset(CFG, input_data)
80
- # dataloader = DataLoader(
81
- # dataset,
82
- # batch_size=CFG.batch_size,
83
- # shuffle=False,
84
- # num_workers=4,
85
- # pin_memory=True,
86
- # drop_last=False,
87
- # )
88
-
89
- # all_sequences, all_scores = [], []
90
- # for inputs in dataloader:
91
- # inputs = {k: v[0].to(device) for k, v in inputs.items()}
92
- # with torch.no_grad():
93
- # output = model.generate(
94
- # **inputs,
95
- # num_beams=CFG.num_beams,
96
- # num_return_sequences=CFG.num_return_sequences,
97
- # return_dict_in_generate=True,
98
- # output_scores=True,
99
- # )
100
- # sequences, scores = decode_output(output)
101
- # all_sequences.extend(sequences)
102
- # if scores:
103
- # all_scores.extend(scores)
104
- # del output
105
- # torch.cuda.empty_cache()
106
- # gc.collect()
107
-
108
- # output_df = save_multiple_predictions(input_data, all_sequences, all_scores)
109
-
110
- CFG.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
111
-
112
- seed_everything(seed=CFG.seed)
113
-
114
- CFG.tokenizer = AutoTokenizer.from_pretrained(
115
- os.path.abspath(CFG.model_name_or_path)
116
- if os.path.exists(CFG.model_name_or_path)
117
- else CFG.model_name_or_path,
118
- return_tensors="pt",
 
 
 
 
 
 
119
  )
120
- model = AutoModelForSeq2SeqLM.from_pretrained(
121
- os.path.abspath(CFG.model_name_or_path)
122
- if os.path.exists(CFG.model_name_or_path)
123
- else CFG.model_name_or_path
124
- ).to(CFG.device)
125
- model.eval()
126
-
127
- input_data = pd.read_csv(CFG.input_data)
128
- input_data = preprocess_df(input_data, drop_duplicates=False)
129
- dataset = ReactionT5Dataset(CFG, input_data)
130
- dataloader = DataLoader(
131
- dataset,
132
- batch_size=CFG.batch_size,
133
- shuffle=False,
134
- num_workers=4,
135
- pin_memory=True,
136
- drop_last=False,
137
  )
138
 
139
- all_sequences, all_scores = [], []
140
- for inputs in tqdm(dataloader, total=len(dataloader)):
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
141
  inputs = {k: v.to(CFG.device) for k, v in inputs.items()}
142
  with torch.no_grad():
143
  output = model.generate(
@@ -153,25 +300,49 @@ if st.button("predict"):
153
  all_sequences.extend(sequences)
154
  if scores:
155
  all_scores.extend(scores)
 
 
156
  del output
157
- torch.cuda.empty_cache()
 
158
  gc.collect()
159
 
160
- output_df = save_multiple_predictions(
161
- input_data, all_sequences, all_scores, CFG
162
- )
163
-
164
- # output_df.to_csv(os.path.join(CFG.output_dir, "output.csv"), index=False)
165
 
166
- @st.cache
167
- def convert_df(df):
168
- return df.to_csv(index=False)
 
169
 
170
- csv = convert_df(output_df)
 
 
 
 
 
171
 
 
 
 
172
  st.download_button(
173
- label="Download data as CSV",
174
- data=csv,
175
- file_name="output.csv",
176
  mime="text/csv",
 
177
  )
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # app.py
2
  import gc
3
  import os
4
  import sys
5
  import warnings
6
+ from typing import Optional, Tuple
7
 
8
  import pandas as pd
9
  import streamlit as st
10
  import torch
11
  from torch.utils.data import DataLoader
 
12
  from transformers import AutoModelForSeq2SeqLM, AutoTokenizer
13
 
14
+ # Local imports
15
  sys.path.append(
16
  os.path.abspath(os.path.join(os.path.dirname(__file__), "task_forward"))
17
  )
18
+ from generation_utils import ReactionT5Dataset, decode_output, save_multiple_predictions
 
 
 
 
19
  from train import preprocess_df
20
  from utils import seed_everything
21
 
22
  warnings.filterwarnings("ignore")
23
 
24
+ # -----------------------------
25
+ # Page / Theme / Global Styles
26
+ # -----------------------------
27
+
28
+ # Subtle modern styles (card-like blocks, nicer headers, compact tables)
29
+ st.markdown(
30
+ """
31
+ <style>
32
+ /* Base */
33
+ .block-container {padding-top: 1.5rem; padding-bottom: 2rem;}
34
+ h1, h2, h3 { letter-spacing: .2px; }
35
+ .st-emotion-cache-1jicfl2 {padding: 1rem !important;} /* tabs pad (HF class may vary)*/
36
+
37
+ /* Card container */
38
+ .card {
39
+ border-radius: 18px;
40
+ padding: 1rem 1.2rem;
41
+ border: 1px solid rgba(127,127,127,0.15);
42
+ background: rgba(250,250,250,0.6);
43
+ backdrop-filter: blur(6px);
44
+ }
45
+ [data-baseweb="select"] div { border-radius: 12px !important; }
46
+
47
+ /* Buttons */
48
+ .stButton>button {
49
+ border-radius: 12px;
50
+ padding: .6rem 1rem;
51
+ font-weight: 600;
52
+ }
53
 
54
+ /* Badges */
55
+ .badge {
56
+ display:inline-block;
57
+ padding: .35em .6em;
58
+ border-radius: 10px;
59
+ background: rgba(0,0,0,.08);
60
+ font-size: .82rem;
61
+ margin-right: .4rem;
62
+ }
63
+
64
+ /* Tables */
65
+ .dataframe td, .dataframe th { font-size: 0.92rem; }
66
+ </style>
67
+ """,
68
+ unsafe_allow_html=True,
69
  )
70
 
71
+ # -----------------------------
72
+ # Header
73
+ # -----------------------------
74
+ col_l, col_r = st.columns([0.78, 0.22])
75
+ with col_l:
76
+ st.title("ReactionT5 • Task Forward")
77
+ st.markdown(
78
+ """
79
+ Predict **reaction products** from inputs formatted as
80
+ `REACTANT:{reactants}REAGENT:{reagents}`
81
+ For multiple compounds: join with `"."` • If no reagent: use a single space `" "`.
82
+ """
83
+ )
84
+ with col_r:
85
+ st.markdown("<div class='card'>", unsafe_allow_html=True)
86
+ st.markdown("**Status**")
87
+ gpu = torch.cuda.is_available()
88
+ st.markdown(
89
+ f"""
90
+ <span class='badge'>Device: {"CUDA" if gpu else "CPU"}</span>
91
+ <span class='badge'>Transformers</span>
92
+ <span class='badge'>Streamlit</span>
93
+ """,
94
+ unsafe_allow_html=True,
95
+ )
96
+ st.markdown("</div>", unsafe_allow_html=True)
97
 
98
+ # -----------------------------
99
+ # Sidebar: Controls / Parameters
100
+ # -----------------------------
101
+ with st.sidebar:
102
+ st.header("Settings")
103
+
104
+ st.caption("Model")
105
+ model_name_or_path = st.text_input(
106
+ "Model name or path",
107
+ value="sagawa/ReactionT5v2-forward",
108
+ help="Hugging Face Hub repo or local path",
109
+ label_visibility="collapsed",
110
  )
111
+ st.divider()
112
+
113
+ st.caption("Generation")
114
+ num_beams = st.slider("num_beams", 1, 10, 5, 1)
115
+ num_return_sequences = st.slider("num_return_sequences", 1, num_beams, num_beams, 1)
116
+ output_max_length = st.slider("max_length", 64, 512, 300, 8)
117
+ output_min_length = st.number_input("min_length", value=-1, step=1)
118
+
119
+ st.caption("Batch / Reproducibility")
120
+ batch_size = st.slider("batch_size", 1, 8, 1, 1)
121
+ seed = st.number_input("seed", value=42, step=1)
122
+
123
+ st.caption("Tokenizer / Input")
124
+ input_max_length = st.slider("input_max_length", 64, 512, 400, 8)
125
+
126
+ st.info(
127
+ "Rough guide: ~15 sec / reaction with `num_beams=5`.",
128
+ )
129
+
130
+
131
+ # -----------------------------
132
+ # Helper: caching
133
+ # -----------------------------
134
+ @st.cache_resource(show_spinner=False)
135
+ def load_model_and_tokenizer(
136
+ path_or_name: str,
137
+ ) -> Tuple[AutoModelForSeq2SeqLM, AutoTokenizer]:
138
+ tok = AutoTokenizer.from_pretrained(
139
+ os.path.abspath(path_or_name) if os.path.exists(path_or_name) else path_or_name,
140
+ return_tensors="pt",
141
+ )
142
+ mdl = AutoModelForSeq2SeqLM.from_pretrained(
143
+ os.path.abspath(path_or_name) if os.path.exists(path_or_name) else path_or_name
144
+ )
145
+ return mdl, tok
146
+
147
+
148
+ @st.cache_data(show_spinner=False)
149
+ def read_demo_csv() -> str:
150
+ df = pd.read_csv("data/demo_reaction_data.csv")
151
+ return df.to_csv(index=False)
152
+
153
+
154
+ @st.cache_data(show_spinner=False)
155
+ def to_csv_bytes(df: pd.DataFrame) -> bytes:
156
+ return df.to_csv(index=False).encode("utf-8")
157
+
158
+
159
+ # -----------------------------
160
+ # I/O Tabs
161
+ # -----------------------------
162
+ tabs = st.tabs(["Input", "Output", "Guide"])
163
+ with tabs[0]:
164
+ st.markdown("<div class='card'>", unsafe_allow_html=True)
165
+ st.subheader("Provide your input")
166
+
167
+ input_mode = st.radio(
168
+ "Choose input mode",
169
+ options=("CSV upload", "Text area"),
170
+ horizontal=True,
171
+ )
172
+
173
+ csv_buffer: Optional[bytes] = None
174
+ text_area_value: Optional[str] = None
175
+
176
+ if input_mode == "CSV upload":
177
+ st.caption('CSV must contain an `"input"` column.')
178
+ up = st.file_uploader("Upload CSV", type=["csv"])
179
+ if up is not None:
180
+ csv_buffer = up.read()
181
+ st.success("CSV uploaded.")
182
+ st.download_button(
183
+ label="Download demo_reaction_data.csv",
184
+ data=read_demo_csv(),
185
+ file_name="demo_reaction_data.csv",
186
+ mime="text/csv",
187
+ use_container_width=True,
188
  )
189
+ else:
190
+ st.caption('Each line will be treated as one sample in the `"input"` column.')
191
+ text_area_value = st.text_area(
192
+ "Enter one or more inputs (one per line)",
193
+ height=140,
194
+ placeholder="REACTANT:CCO.REAGENT:O\nREACTANT:CC(=O)O.REAGENT: ",
 
 
 
 
 
 
 
 
 
 
 
195
  )
196
 
197
+ st.markdown("</div>", unsafe_allow_html=True)
198
+
199
+ with tabs[2]:
200
+ st.markdown("<div class='card'>", unsafe_allow_html=True)
201
+ st.subheader("Formatting rules")
202
+ st.markdown(
203
+ """
204
+ - **Template**: `REACTANT:{reactants}REAGENT:{reagents}`
205
+ - **Multiple compounds**: join with `"."`
206
+ - **No reagent**: provide a single space `" "` after `REAGENT:`
207
+ - **CSV schema**: must contain an `input` column
208
+ - **Outputs**: predicted products (SMILES) and sum of log-likelihood per hypothesis
209
+ """
210
+ )
211
+ st.markdown("</div>", unsafe_allow_html=True)
212
+
213
+ # -----------------------------
214
+ # Predict Button
215
+ # -----------------------------
216
+ run = st.button("🚀 Predict", use_container_width=True)
217
+
218
+ # -----------------------------
219
+ # Execution
220
+ # -----------------------------
221
+ if run:
222
+ # Validate input
223
+ if input_mode == "CSV upload" and not csv_buffer:
224
+ st.error(
225
+ "Please upload a CSV file with an `input` column, or switch to Text area."
226
+ )
227
+ st.stop()
228
+
229
+ if input_mode == "Text area" and (
230
+ text_area_value is None or not text_area_value.strip()
231
+ ):
232
+ st.error("Please enter at least one line of input.")
233
+ st.stop()
234
+
235
+ with st.status("Initializing model & tokenizer…", expanded=False) as status:
236
+ device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
237
+ seed_everything(seed=seed)
238
+ model, tokenizer = load_model_and_tokenizer(model_name_or_path)
239
+ model = model.to(device).eval()
240
+ status.update(label="Model ready", state="complete")
241
+
242
+ # Prepare dataframe
243
+ if input_mode == "CSV upload":
244
+ df_in = pd.read_csv(pd.io.common.BytesIO(csv_buffer))
245
+ else:
246
+ lines = [x.strip() for x in text_area_value.splitlines() if x.strip()]
247
+ df_in = pd.DataFrame({"input": lines})
248
+
249
+ # Preprocess and dataset
250
+ try:
251
+ df_in = preprocess_df(df_in, drop_duplicates=False)
252
+ except Exception as e:
253
+ st.error(f"Input preprocessing failed: {e}")
254
+ st.stop()
255
+
256
+ class CFG:
257
+ # Configuration object used by ReactionT5Dataset/decode_output utilities
258
+ num_beams = num_beams
259
+ num_return_sequences = num_return_sequences
260
+ model_name_or_path = model_name_or_path
261
+ input_column = "input"
262
+ input_max_length = input_max_length
263
+ output_max_length = output_max_length
264
+ output_min_length = output_min_length
265
+ model = "t5"
266
+ seed = seed
267
+ batch_size = batch_size
268
+ device = device
269
+ tokenizer = tokenizer
270
+
271
+ dataset = ReactionT5Dataset(CFG, df_in)
272
+ dataloader = DataLoader(
273
+ dataset,
274
+ batch_size=CFG.batch_size,
275
+ shuffle=False,
276
+ num_workers=0 if not torch.cuda.is_available() else 4,
277
+ pin_memory=torch.cuda.is_available(),
278
+ drop_last=False,
279
+ )
280
+
281
+ # Progress UI
282
+ total_steps = len(dataloader)
283
+ progress = st.progress(0, text=f"Running generation… 0 / {total_steps}")
284
+ all_sequences, all_scores = [], []
285
+
286
+ try:
287
+ for idx, inputs in enumerate(dataloader, start=1):
288
  inputs = {k: v.to(CFG.device) for k, v in inputs.items()}
289
  with torch.no_grad():
290
  output = model.generate(
 
300
  all_sequences.extend(sequences)
301
  if scores:
302
  all_scores.extend(scores)
303
+
304
+ # Memory hygiene
305
  del output
306
+ if torch.cuda.is_available():
307
+ torch.cuda.empty_cache()
308
  gc.collect()
309
 
310
+ progress.progress(
311
+ idx / total_steps, text=f"Running generation… {idx} / {total_steps}"
312
+ )
 
 
313
 
314
+ st.toast("Generation complete")
315
+ except Exception as e:
316
+ st.error(f"Generation failed: {e}")
317
+ st.stop()
318
 
319
+ # Save & show
320
+ try:
321
+ output_df = save_multiple_predictions(df_in, all_sequences, all_scores, CFG)
322
+ except Exception as e:
323
+ st.error(f"Post-processing failed: {e}")
324
+ st.stop()
325
 
326
+ with tabs[1]:
327
+ st.subheader("Results")
328
+ st.dataframe(output_df, use_container_width=True, hide_index=True)
329
  st.download_button(
330
+ label="Download results (CSV)",
331
+ data=to_csv_bytes(output_df),
332
+ file_name="reactiont5_output.csv",
333
  mime="text/csv",
334
+ use_container_width=True,
335
  )
336
+
337
+ # -----------------------------
338
+ # Footer Note
339
+ # -----------------------------
340
+ st.markdown(
341
+ """
342
+ <hr/>
343
+ <small>
344
+ Built with ❤️ using Streamlit & 🤗 Transformers.
345
+ </small>
346
+ """,
347
+ unsafe_allow_html=True,
348
+ )