sagawa commited on
Commit
955d374
·
verified ·
1 Parent(s): ac1e254

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +15 -2
app.py CHANGED
@@ -4,6 +4,7 @@ import warnings
4
  from types import SimpleNamespace
5
 
6
  import pandas as pd
 
7
  import streamlit as st
8
  import torch
9
 
@@ -112,7 +113,6 @@ with st.sidebar:
112
  model_help = "Default model for yield prediction."
113
  input_max_length_default = 400
114
  from task_yield.train import preprocess_df
115
- from task_yield.prediction import inference_fn
116
 
117
  model_name_or_path = st.selectbox(
118
  "Model",
@@ -311,7 +311,20 @@ if run:
311
 
312
  if task == "yield prediction":
313
  # Use custom inference function for yield prediction
314
- prediction = inference_fn(dataloader, model, CFG)
 
 
 
 
 
 
 
 
 
 
 
 
 
315
  output_df = input_df.copy()
316
  output_df["prediction"] = prediction
317
  output_df["prediction"] = output_df["prediction"].clip(lower=0.0, upper=100.0)
 
4
  from types import SimpleNamespace
5
 
6
  import pandas as pd
7
+ import numpy as np
8
  import streamlit as st
9
  import torch
10
 
 
113
  model_help = "Default model for yield prediction."
114
  input_max_length_default = 400
115
  from task_yield.train import preprocess_df
 
116
 
117
  model_name_or_path = st.selectbox(
118
  "Model",
 
311
 
312
  if task == "yield prediction":
313
  # Use custom inference function for yield prediction
314
+ prediction = []
315
+ total = len(dataloader)
316
+ progress = st.progress(0, text="Predicting yields...")
317
+ info_placeholder = st.empty()
318
+ for i, inputs in enumerate(dataloader, start=1):
319
+ inputs = {k: v.to(device) for k, v in inputs.items()}
320
+ with torch.no_grad():
321
+ y_preds = model(inputs)
322
+ prediction.extend(y_preds.to("cpu").numpy())
323
+ del y_preds
324
+ progress.progress(i / total, text=f"Predicting yields... {i}/{total}")
325
+ info_placeholder.caption(f"Processed batch {i} of {total}")
326
+
327
+ prediction = np.concatenate(prediction)
328
  output_df = input_df.copy()
329
  output_df["prediction"] = prediction
330
  output_df["prediction"] = output_df["prediction"].clip(lower=0.0, upper=100.0)