Spaces:
Running
Running
Commit
·
1d48158
1
Parent(s):
befc43b
add demo
Browse files
app.py
CHANGED
@@ -31,16 +31,26 @@ def line_plot_fn(data, cutoff, ci_form):
|
|
31 |
y = torch.from_numpy(data.y.values).float().unsqueeze(1)
|
32 |
|
33 |
rest_prob = (1 - (ci / 100)) / 2
|
34 |
-
predictions = model.predict_quantiles(x_train=x[:cutoff], y_train=y[:cutoff], x_test=x[
|
35 |
|
36 |
fig, ax = plt.subplots()
|
37 |
|
38 |
ax.plot(x, data.y, "black", label="target")
|
39 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
40 |
# plot extrapolation
|
41 |
ax.plot(x[(cutoff-1):], predictions[:, 1], "blue", label="Extrapolation by PFN")
|
42 |
ax.fill_between(
|
43 |
-
x[(cutoff-1):].flatten(), predictions[:, 0], predictions[:, 2], color="blue", alpha=0.2, label=f"CI of {
|
44 |
)
|
45 |
|
46 |
# plot cutoff
|
|
|
31 |
y = torch.from_numpy(data.y.values).float().unsqueeze(1)
|
32 |
|
33 |
rest_prob = (1 - (ci / 100)) / 2
|
34 |
+
predictions = model.predict_quantiles(x_train=x[:cutoff], y_train=y[:cutoff], x_test=x[cutoff:], qs=[rest_prob, 0.5, 1-rest_prob])
|
35 |
|
36 |
fig, ax = plt.subplots()
|
37 |
|
38 |
ax.plot(x, data.y, "black", label="target")
|
39 |
|
40 |
+
predictions = predictions.numpy()
|
41 |
+
new = np.array([y[cutoff-1], y[cutoff-1], y[cutoff-1]]).reshape(1, 3)
|
42 |
+
predictions = np.concatenate(
|
43 |
+
[
|
44 |
+
new,
|
45 |
+
predictions
|
46 |
+
],
|
47 |
+
axis=0
|
48 |
+
)
|
49 |
+
|
50 |
# plot extrapolation
|
51 |
ax.plot(x[(cutoff-1):], predictions[:, 1], "blue", label="Extrapolation by PFN")
|
52 |
ax.fill_between(
|
53 |
+
x[(cutoff-1):].flatten(), predictions[:, 0], predictions[:, 2], color="blue", alpha=0.2, label=f"CI of {ci}%"
|
54 |
)
|
55 |
|
56 |
# plot cutoff
|