sagawa committed on
Commit
60b0e86
·
verified ·
1 Parent(s): 05b9666

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +56 -21
app.py CHANGED
@@ -52,11 +52,23 @@ with st.expander("How to format your CSV", expanded=False):
52
  # ------------------------------
53
  # Demo data download
54
  # ------------------------------
 
 
 
 
 
 
 
 
 
 
 
55
  @st.cache_data(show_spinner=False)
56
  def load_demo_csv_as_bytes() -> bytes:
57
  demo_df = pd.read_csv("data/demo_reaction_data.csv")
58
  return demo_df.to_csv(index=False).encode("utf-8")
59
 
 
60
  st.download_button(
61
  label="Download demo_reaction_data.csv",
62
  data=load_demo_csv_as_bytes(),
@@ -81,13 +93,19 @@ with st.sidebar:
81
 
82
  num_beams = st.slider(
83
  "Beam size",
84
- min_value=1, max_value=10, value=5, step=1,
 
 
 
85
  help="Number of beams for beam search.",
86
  )
87
 
88
  seed = st.number_input(
89
  "Random seed",
90
- min_value=0, max_value=2**32 - 1, value=42, step=1,
 
 
 
91
  help="Seed for reproducibility.",
92
  )
93
 
@@ -99,20 +117,29 @@ with st.sidebar:
99
  "Output max length", min_value=8, max_value=1024, value=300, step=8
100
  )
101
  output_min_length = st.number_input(
102
- "Output min length", min_value=-1, max_value=1024, value=-1, step=1,
 
 
 
 
103
  help="Use -1 to let the model decide.",
104
  )
105
  batch_size = st.number_input(
106
  "Batch size", min_value=1, max_value=16, value=1, step=1
107
  )
108
  num_workers = st.number_input(
109
- "DataLoader workers", min_value=0, max_value=8, value=4, step=1,
 
 
 
 
110
  help="Set to 0 if multiprocessing is restricted in your environment.",
111
  )
112
 
113
  device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
114
  st.caption(f"Detected device: **{device.type.upper()}**")
115
 
 
116
  # ------------------------------
117
  # Cached loaders
118
  # ------------------------------
@@ -121,6 +148,7 @@ def load_tokenizer(model_ref: str):
121
  resolved = os.path.abspath(model_ref) if os.path.exists(model_ref) else model_ref
122
  return AutoTokenizer.from_pretrained(resolved, return_tensors="pt")
123
 
 
124
  @st.cache_resource(show_spinner=True)
125
  def load_model(model_ref: str, device_str: str):
126
  resolved = os.path.abspath(model_ref) if os.path.exists(model_ref) else model_ref
@@ -129,10 +157,12 @@ def load_model(model_ref: str, device_str: str):
129
  model.eval()
130
  return model
131
 
 
132
  @st.cache_data(show_spinner=False)
133
  def df_to_csv_bytes(df: pd.DataFrame) -> bytes:
134
  return df.to_csv(index=False).encode("utf-8")
135
 
 
136
  # ------------------------------
137
  # Main interaction
138
  # ------------------------------
@@ -150,7 +180,9 @@ with left:
150
 
151
  if uploaded is not None:
152
  try:
153
- raw_df = pd.read_csv(uploaded)
 
 
154
  st.subheader("Input preview")
155
  st.dataframe(raw_df.head(20), use_container_width=True)
156
  except Exception as e:
@@ -172,11 +204,11 @@ with right:
172
  # ------------------------------
173
  # Inference
174
  # ------------------------------
175
- if 'results_df' not in st.session_state:
176
- st.session_state['results_df'] = None
177
 
178
- if 'last_error' not in st.session_state:
179
- st.session_state['last_error'] = None
180
 
181
  if run:
182
  if uploaded is None:
@@ -205,14 +237,15 @@ if run:
205
  model = load_model(CFG.model_name_or_path, device.type)
206
  status.update(label="Model ready.", state="complete")
207
  except Exception as e:
208
- st.session_state['last_error'] = f"Failed to load model: {e}"
209
  status.update(label="Model load failed.", state="error")
210
  st.stop()
211
 
212
  # Prepare data
213
- input_df = pd.read_csv(uploaded)
 
 
214
  input_df = preprocess_df(input_df, drop_duplicates=False)
215
-
216
 
217
  # Dataset & loader
218
  dataset = ReactionT5Dataset(CFG, input_df)
@@ -261,28 +294,30 @@ if run:
261
 
262
  # Save predictions
263
  try:
264
- output_df = save_multiple_predictions(input_df, all_sequences, all_scores, CFG)
265
- st.session_state['results_df'] = output_df
 
 
266
  st.success("Prediction complete.")
267
  except Exception as e:
268
- st.session_state['last_error'] = f"Failed to assemble output: {e}"
269
- st.error(st.session_state['last_error'])
270
  st.stop()
271
 
272
  # ------------------------------
273
  # Results
274
  # ------------------------------
275
- if st.session_state.get('results_df') is not None:
276
  st.subheader("Results preview")
277
- st.dataframe(st.session_state['results_df'].head(50), use_container_width=True)
278
 
279
  st.download_button(
280
  label="Download predictions as CSV",
281
- data=df_to_csv_bytes(st.session_state['results_df']),
282
  file_name="output.csv",
283
  mime="text/csv",
284
  use_container_width=True,
285
  )
286
 
287
- if st.session_state.get('last_error'):
288
- st.error(st.session_state['last_error'])
 
52
  # ------------------------------
53
  # Demo data download
54
  # ------------------------------
55
+ import io
56
+
57
+
58
@st.cache_data(show_spinner=False)
def parse_csv_from_bytes(file_bytes: bytes) -> pd.DataFrame:
    """Parse the raw bytes of a CSV file into a DataFrame.

    The bytes are wrapped in an in-memory binary buffer and passed to
    ``pd.read_csv`` with its default settings; no explicit text decoding
    is performed here, which is fine as long as the files are UTF-8.

    NOTE: if explicit decoding is preferred, decode the bytes to text
    first and parse them through ``io.StringIO`` instead.
    """
    buffer = io.BytesIO(file_bytes)
    frame = pd.read_csv(buffer)
    return frame
64
+
65
+
66
  @st.cache_data(show_spinner=False)
67
  def load_demo_csv_as_bytes() -> bytes:
68
  demo_df = pd.read_csv("data/demo_reaction_data.csv")
69
  return demo_df.to_csv(index=False).encode("utf-8")
70
 
71
+
72
  st.download_button(
73
  label="Download demo_reaction_data.csv",
74
  data=load_demo_csv_as_bytes(),
 
93
 
94
  num_beams = st.slider(
95
  "Beam size",
96
+ min_value=1,
97
+ max_value=10,
98
+ value=5,
99
+ step=1,
100
  help="Number of beams for beam search.",
101
  )
102
 
103
  seed = st.number_input(
104
  "Random seed",
105
+ min_value=0,
106
+ max_value=2**32 - 1,
107
+ value=42,
108
+ step=1,
109
  help="Seed for reproducibility.",
110
  )
111
 
 
117
  "Output max length", min_value=8, max_value=1024, value=300, step=8
118
  )
119
  output_min_length = st.number_input(
120
+ "Output min length",
121
+ min_value=-1,
122
+ max_value=1024,
123
+ value=-1,
124
+ step=1,
125
  help="Use -1 to let the model decide.",
126
  )
127
  batch_size = st.number_input(
128
  "Batch size", min_value=1, max_value=16, value=1, step=1
129
  )
130
  num_workers = st.number_input(
131
+ "DataLoader workers",
132
+ min_value=0,
133
+ max_value=8,
134
+ value=4,
135
+ step=1,
136
  help="Set to 0 if multiprocessing is restricted in your environment.",
137
  )
138
 
139
  device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
140
  st.caption(f"Detected device: **{device.type.upper()}**")
141
 
142
+
143
  # ------------------------------
144
  # Cached loaders
145
  # ------------------------------
 
148
  resolved = os.path.abspath(model_ref) if os.path.exists(model_ref) else model_ref
149
  return AutoTokenizer.from_pretrained(resolved, return_tensors="pt")
150
 
151
+
152
  @st.cache_resource(show_spinner=True)
153
  def load_model(model_ref: str, device_str: str):
154
  resolved = os.path.abspath(model_ref) if os.path.exists(model_ref) else model_ref
 
157
  model.eval()
158
  return model
159
 
160
+
161
  @st.cache_data(show_spinner=False)
162
  def df_to_csv_bytes(df: pd.DataFrame) -> bytes:
163
  return df.to_csv(index=False).encode("utf-8")
164
 
165
+
166
  # ------------------------------
167
  # Main interaction
168
  # ------------------------------
 
180
 
181
  if uploaded is not None:
182
  try:
183
+ file_bytes = uploaded.getvalue()
184
+ raw_df = parse_csv_from_bytes(file_bytes)
185
+ # raw_df = pd.read_csv(uploaded)
186
  st.subheader("Input preview")
187
  st.dataframe(raw_df.head(20), use_container_width=True)
188
  except Exception as e:
 
204
  # ------------------------------
205
  # Inference
206
  # ------------------------------
207
+ if "results_df" not in st.session_state:
208
+ st.session_state["results_df"] = None
209
 
210
+ if "last_error" not in st.session_state:
211
+ st.session_state["last_error"] = None
212
 
213
  if run:
214
  if uploaded is None:
 
237
  model = load_model(CFG.model_name_or_path, device.type)
238
  status.update(label="Model ready.", state="complete")
239
  except Exception as e:
240
+ st.session_state["last_error"] = f"Failed to load model: {e}"
241
  status.update(label="Model load failed.", state="error")
242
  st.stop()
243
 
244
  # Prepare data
245
+ file_bytes = uploaded.getvalue()
246
+ input_df = parse_csv_from_bytes(file_bytes)
247
+ # input_df = pd.read_csv(uploaded)
248
  input_df = preprocess_df(input_df, drop_duplicates=False)
 
249
 
250
  # Dataset & loader
251
  dataset = ReactionT5Dataset(CFG, input_df)
 
294
 
295
  # Save predictions
296
  try:
297
+ output_df = save_multiple_predictions(
298
+ input_df, all_sequences, all_scores, CFG
299
+ )
300
+ st.session_state["results_df"] = output_df
301
  st.success("Prediction complete.")
302
  except Exception as e:
303
+ st.session_state["last_error"] = f"Failed to assemble output: {e}"
304
+ st.error(st.session_state["last_error"])
305
  st.stop()
306
 
307
  # ------------------------------
308
  # Results
309
  # ------------------------------
310
+ if st.session_state.get("results_df") is not None:
311
  st.subheader("Results preview")
312
+ st.dataframe(st.session_state["results_df"].head(50), use_container_width=True)
313
 
314
  st.download_button(
315
  label="Download predictions as CSV",
316
+ data=df_to_csv_bytes(st.session_state["results_df"]),
317
  file_name="output.csv",
318
  mime="text/csv",
319
  use_container_width=True,
320
  )
321
 
322
+ if st.session_state.get("last_error"):
323
+ st.error(st.session_state["last_error"])