herilalaina commited on
Commit
1d48158
·
1 Parent(s): befc43b
Files changed (1) hide show
  1. app.py +12 -2
app.py CHANGED
@@ -31,16 +31,26 @@ def line_plot_fn(data, cutoff, ci_form):
31
  y = torch.from_numpy(data.y.values).float().unsqueeze(1)
32
 
33
  rest_prob = (1 - (ci / 100)) / 2
34
- predictions = model.predict_quantiles(x_train=x[:cutoff], y_train=y[:cutoff], x_test=x[(cutoff-1):], qs=[rest_prob, 0.5, 1-rest_prob])
35
 
36
  fig, ax = plt.subplots()
37
 
38
  ax.plot(x, data.y, "black", label="target")
39
 
 
 
 
 
 
 
 
 
 
 
40
  # plot extrapolation
41
  ax.plot(x[(cutoff-1):], predictions[:, 1], "blue", label="Extrapolation by PFN")
42
  ax.fill_between(
43
- x[(cutoff-1):].flatten(), predictions[:, 0], predictions[:, 2], color="blue", alpha=0.2, label=f"CI of {cutoff}%"
44
  )
45
 
46
  # plot cutoff
 
31
  y = torch.from_numpy(data.y.values).float().unsqueeze(1)
32
 
33
  rest_prob = (1 - (ci / 100)) / 2
34
+ predictions = model.predict_quantiles(x_train=x[:cutoff], y_train=y[:cutoff], x_test=x[cutoff:], qs=[rest_prob, 0.5, 1-rest_prob])
35
 
36
  fig, ax = plt.subplots()
37
 
38
  ax.plot(x, data.y, "black", label="target")
39
 
40
+ predictions = predictions.numpy()
41
+ new = np.array([y[cutoff-1], y[cutoff-1], y[cutoff-1]]).reshape(1, 3)
42
+ predictions = np.concatenate(
43
+ [
44
+ new,
45
+ predictions
46
+ ],
47
+ axis=0
48
+ )
49
+
50
  # plot extrapolation
51
  ax.plot(x[(cutoff-1):], predictions[:, 1], "blue", label="Extrapolation by PFN")
52
  ax.fill_between(
53
+ x[(cutoff-1):].flatten(), predictions[:, 0], predictions[:, 2], color="blue", alpha=0.2, label=f"CI of {ci}%"
54
  )
55
 
56
  # plot cutoff