leuschnm commited on
Commit
ac3bcbd
·
1 Parent(s): 0ebac9a
Files changed (1) hide show
  1. app.py +3 -3
app.py CHANGED
@@ -62,7 +62,7 @@ def predict(_model, _dataloader, datepicker):
62
  def update_plot(df, preds):
63
  df = pd.merge(df, preds, left_on=["time_idx", "Group"], right_on=["pred_idx", "Group"], how = "left")
64
  df = df[~df["pred"].isna()]
65
- df[["sales", "pred"]] = df[["sales", "pred"]].replace(0.0, np.nan)
66
 
67
  axs[0, 0].plot(df.loc[df['Group'] == '4', 'Date'], df.loc[df['Group'] == '4', 'pred'], color = 'red')
68
  axs[0, 1].plot(df.loc[df['Group'] == '7', 'Date'], df.loc[df['Group'] == '4', 'pred'], color = 'red')
@@ -71,7 +71,7 @@ def update_plot(df, preds):
71
  return st.pyplot(fig)
72
 
73
  @st.cache_resource
74
- def generate_plot():
75
  fig, axs = plt.subplots(2, 2, figsize=(8, 6))
76
 
77
  # Plot scatter plots for each group
@@ -142,7 +142,7 @@ def main():
142
 
143
  datepicker = st.date_input("Start of Forecast", datetime.date(2022, 10, 24), min_value=datetime.date(2022, 6, 26) + datetime.timedelta(days = 35), max_value=datetime.date(2023, 6, 26) - datetime.timedelta(days = 30))
144
 
145
- generate_plot()
146
 
147
  if st.button("Forecast Sales", type="primary"):
148
  dataloader = prepare_dataset(parameters, df.copy(), rain, temperature, datepicker, rain_mapping)
 
62
  def update_plot(df, preds):
63
  df = pd.merge(df, preds, left_on=["time_idx", "Group"], right_on=["pred_idx", "Group"], how = "left")
64
  df = df[~df["pred"].isna()]
65
+ #df[["sales", "pred"]] = df[["sales", "pred"]].replace(0.0, np.nan)
66
 
67
  axs[0, 0].plot(df.loc[df['Group'] == '4', 'Date'], df.loc[df['Group'] == '4', 'pred'], color = 'red')
68
  axs[0, 1].plot(df.loc[df['Group'] == '7', 'Date'], df.loc[df['Group'] == '4', 'pred'], color = 'red')
 
71
  return st.pyplot(fig)
72
 
73
  @st.cache_resource
74
+ def generate_plot(df):
75
  fig, axs = plt.subplots(2, 2, figsize=(8, 6))
76
 
77
  # Plot scatter plots for each group
 
142
 
143
  datepicker = st.date_input("Start of Forecast", datetime.date(2022, 10, 24), min_value=datetime.date(2022, 6, 26) + datetime.timedelta(days = 35), max_value=datetime.date(2023, 6, 26) - datetime.timedelta(days = 30))
144
 
145
+ generate_plot(df)
146
 
147
  if st.button("Forecast Sales", type="primary"):
148
  dataloader = prepare_dataset(parameters, df.copy(), rain, temperature, datepicker, rain_mapping)