Update app.py
Browse files
app.py
CHANGED
@@ -2,16 +2,17 @@ import numpy as np
|
|
2 |
import torch
|
3 |
import gym
|
4 |
from models.attention_model_wrapper import Agent
|
5 |
-
|
6 |
-
|
7 |
-
|
8 |
-
agent
|
|
|
9 |
|
10 |
from wrappers.syncVectorEnvPomo import SyncVectorEnv
|
11 |
from wrappers.recordWrapper import RecordEpisodeStatistics
|
12 |
|
13 |
-
env_id =
|
14 |
-
env_entry_point =
|
15 |
seed = 0
|
16 |
|
17 |
gym.envs.register(
|
@@ -19,6 +20,7 @@ gym.envs.register(
|
|
19 |
entry_point=env_entry_point,
|
20 |
)
|
21 |
|
|
|
22 |
def make_env(env_id, seed, cfg={}):
|
23 |
def thunk():
|
24 |
env = gym.make(env_id, **cfg)
|
@@ -27,109 +29,124 @@ def make_env(env_id, seed, cfg={}):
|
|
27 |
env.action_space.seed(seed)
|
28 |
env.observation_space.seed(seed)
|
29 |
return env
|
|
|
30 |
return thunk
|
31 |
|
32 |
|
33 |
def inference(data):
|
34 |
-
envs = SyncVectorEnv(
|
35 |
-
|
36 |
-
|
37 |
-
|
|
|
|
|
|
|
38 |
|
39 |
trajectories = []
|
40 |
agent.eval()
|
41 |
obs = envs.reset()
|
42 |
done = np.array([False])
|
43 |
while not done.all():
|
44 |
-
|
45 |
with torch.no_grad():
|
46 |
action, logits = agent(obs)
|
47 |
obs, reward, done, info = envs.step(action.cpu().numpy())
|
48 |
trajectories.append(action.cpu().numpy())
|
49 |
-
nodes_coordinates = obs[
|
50 |
-
final_return = info[0][
|
51 |
-
resulting_traj = np.array(trajectories)[:,0,0]
|
52 |
return resulting_traj, final_return
|
53 |
|
54 |
-
|
55 |
-
|
56 |
-
|
57 |
-
|
58 |
-
|
59 |
-
|
60 |
-
|
61 |
-
|
62 |
-
|
63 |
-
|
64 |
-
|
65 |
-
|
|
|
|
|
|
|
|
|
|
|
66 |
# colorline taken from https://nbviewer.org/github/dpsanders/matplotlib-examples/blob/master/colorline.ipynb
|
67 |
import matplotlib.pyplot as plt
|
68 |
from matplotlib.collections import LineCollection
|
69 |
from matplotlib.colors import ListedColormap, BoundaryNorm
|
70 |
|
|
|
71 |
def make_segments(x, y):
|
72 |
-
|
73 |
Create list of line segments from x and y coordinates, in the correct format for LineCollection:
|
74 |
an array of the form numlines x (points per line) x 2 (x and y) array
|
75 |
-
|
76 |
|
77 |
points = np.array([x, y]).T.reshape(-1, 1, 2)
|
78 |
segments = np.concatenate([points[:-1], points[1:]], axis=1)
|
79 |
-
|
80 |
return segments
|
81 |
|
82 |
-
|
83 |
-
|
|
|
84 |
Plot a colored line with coordinates x and y
|
85 |
Optionally specify colors in the array z
|
86 |
Optionally specify a colormap, a norm function and a line width
|
87 |
-
|
88 |
-
|
89 |
# Default colors equally spaced on [0,1]:
|
90 |
if z is None:
|
91 |
z = np.linspace(0.3, 1.0, len(x))
|
92 |
-
|
93 |
# Special case if a single number:
|
94 |
if not hasattr(z, "__iter__"): # to check for numerical input -- this is a hack
|
95 |
z = np.array([z])
|
96 |
-
|
97 |
z = np.asarray(z)
|
98 |
-
|
99 |
segments = make_segments(x, y)
|
100 |
lc = LineCollection(segments, array=z, cmap=cmap, norm=norm, linewidth=linewidth, alpha=alpha)
|
101 |
-
|
102 |
ax = plt.gca()
|
103 |
ax.add_collection(lc)
|
104 |
-
|
105 |
return lc
|
106 |
|
|
|
107 |
def plot(coords):
|
108 |
fig = plt.figure()
|
109 |
-
x,y = coords.T
|
110 |
-
lc = colorline(x,y,cmap=
|
111 |
-
plt.axis(
|
112 |
return fig
|
113 |
|
|
|
114 |
import gradio as gr
|
115 |
|
|
|
116 |
def run_inference(data):
|
117 |
data = data.astype(float).to_numpy()
|
118 |
resulting_traj, final_return = inference(data)
|
119 |
-
result_text = f
|
120 |
-
return [plot(data[resulting_traj]),result_text]
|
121 |
-
|
122 |
-
|
123 |
-
|
124 |
-
|
125 |
-
|
126 |
-
|
127 |
-
|
128 |
-
|
129 |
-
|
130 |
-
|
131 |
-
|
132 |
-
|
133 |
-
|
134 |
-
|
135 |
-
|
|
|
|
2 |
import torch
|
3 |
import gym
|
4 |
from models.attention_model_wrapper import Agent
|
5 |
+
|
6 |
+
device = "cpu"
|
7 |
+
ckpt_path = "./runs/tsp-v0__ppo_or__1__1678160003/ckpt/12000.pt"
|
8 |
+
agent = Agent(device=device, name="tsp").to(device)
|
9 |
+
agent.load_state_dict(torch.load(ckpt_path, map_location=torch.device("cpu")))
|
10 |
|
11 |
from wrappers.syncVectorEnvPomo import SyncVectorEnv
|
12 |
from wrappers.recordWrapper import RecordEpisodeStatistics
|
13 |
|
14 |
+
env_id = "tsp-v0"
|
15 |
+
env_entry_point = "envs.tsp_vector_env:TSPVectorEnv"
|
16 |
seed = 0
|
17 |
|
18 |
gym.envs.register(
|
|
|
20 |
entry_point=env_entry_point,
|
21 |
)
|
22 |
|
23 |
+
|
24 |
def make_env(env_id, seed, cfg={}):
|
25 |
def thunk():
|
26 |
env = gym.make(env_id, **cfg)
|
|
|
29 |
env.action_space.seed(seed)
|
30 |
env.observation_space.seed(seed)
|
31 |
return env
|
32 |
+
|
33 |
return thunk
|
34 |
|
35 |
|
36 |
def inference(data):
|
37 |
+
envs = SyncVectorEnv(
|
38 |
+
[
|
39 |
+
make_env(
|
40 |
+
env_id, seed, dict(n_traj=1, max_nodes=len(data), eval_data="from_input", eval_data_from_input=data)
|
41 |
+
)
|
42 |
+
]
|
43 |
+
)
|
44 |
|
45 |
trajectories = []
|
46 |
agent.eval()
|
47 |
obs = envs.reset()
|
48 |
done = np.array([False])
|
49 |
while not done.all():
|
50 |
+
# ALGO LOGIC: action logic
|
51 |
with torch.no_grad():
|
52 |
action, logits = agent(obs)
|
53 |
obs, reward, done, info = envs.step(action.cpu().numpy())
|
54 |
trajectories.append(action.cpu().numpy())
|
55 |
+
nodes_coordinates = obs["observations"][0]
|
56 |
+
final_return = info[0]["episode"]["r"]
|
57 |
+
resulting_traj = np.array(trajectories)[:, 0, 0]
|
58 |
return resulting_traj, final_return
|
59 |
|
60 |
+
|
61 |
+
default_data = np.array(
|
62 |
+
[
|
63 |
+
[0.5488135, 0.71518937],
|
64 |
+
[0.60276338, 0.54488318],
|
65 |
+
[0.4236548, 0.64589411],
|
66 |
+
[0.43758721, 0.891773],
|
67 |
+
[0.96366276, 0.38344152],
|
68 |
+
[0.79172504, 0.52889492],
|
69 |
+
[0.56804456, 0.92559664],
|
70 |
+
[0.07103606, 0.0871293],
|
71 |
+
[0.0202184, 0.83261985],
|
72 |
+
[0.77815675, 0.87001215],
|
73 |
+
]
|
74 |
+
)
|
75 |
+
|
76 |
+
# @title Helper function for plotting
|
77 |
# colorline taken from https://nbviewer.org/github/dpsanders/matplotlib-examples/blob/master/colorline.ipynb
|
78 |
import matplotlib.pyplot as plt
|
79 |
from matplotlib.collections import LineCollection
|
80 |
from matplotlib.colors import ListedColormap, BoundaryNorm
|
81 |
|
82 |
+
|
83 |
def make_segments(x, y):
|
84 |
+
"""
|
85 |
Create list of line segments from x and y coordinates, in the correct format for LineCollection:
|
86 |
an array of the form numlines x (points per line) x 2 (x and y) array
|
87 |
+
"""
|
88 |
|
89 |
points = np.array([x, y]).T.reshape(-1, 1, 2)
|
90 |
segments = np.concatenate([points[:-1], points[1:]], axis=1)
|
91 |
+
|
92 |
return segments
|
93 |
|
94 |
+
|
95 |
+
def colorline(x, y, z=None, cmap=plt.get_cmap("copper"), norm=plt.Normalize(0.0, 1.0), linewidth=1, alpha=1.0):
|
96 |
+
"""
|
97 |
Plot a colored line with coordinates x and y
|
98 |
Optionally specify colors in the array z
|
99 |
Optionally specify a colormap, a norm function and a line width
|
100 |
+
"""
|
101 |
+
|
102 |
# Default colors equally spaced on [0,1]:
|
103 |
if z is None:
|
104 |
z = np.linspace(0.3, 1.0, len(x))
|
105 |
+
|
106 |
# Special case if a single number:
|
107 |
if not hasattr(z, "__iter__"): # to check for numerical input -- this is a hack
|
108 |
z = np.array([z])
|
109 |
+
|
110 |
z = np.asarray(z)
|
111 |
+
|
112 |
segments = make_segments(x, y)
|
113 |
lc = LineCollection(segments, array=z, cmap=cmap, norm=norm, linewidth=linewidth, alpha=alpha)
|
114 |
+
|
115 |
ax = plt.gca()
|
116 |
ax.add_collection(lc)
|
117 |
+
|
118 |
return lc
|
119 |
|
120 |
+
|
121 |
def plot(coords):
|
122 |
fig = plt.figure()
|
123 |
+
x, y = coords.T
|
124 |
+
lc = colorline(x, y, cmap="Reds")
|
125 |
+
plt.axis("square")
|
126 |
return fig
|
127 |
|
128 |
+
|
129 |
import gradio as gr
|
130 |
|
131 |
+
|
132 |
def run_inference(data):
|
133 |
data = data.astype(float).to_numpy()
|
134 |
resulting_traj, final_return = inference(data)
|
135 |
+
result_text = f"Planned Tour:\t{resulting_traj}\nTotal tour length:\t{final_return[0]:.2f}"
|
136 |
+
return [plot(data[resulting_traj]), result_text]
|
137 |
+
|
138 |
+
|
139 |
+
demo = gr.Interface(
|
140 |
+
run_inference,
|
141 |
+
gr.Dataframe(
|
142 |
+
label="Input",
|
143 |
+
headers=["x", "y"],
|
144 |
+
row_count=10,
|
145 |
+
col_count=(2, "fixed"),
|
146 |
+
max_rows=10,
|
147 |
+
value=default_data.tolist(),
|
148 |
+
overflow_row_behaviour="show_ends",
|
149 |
+
),
|
150 |
+
[gr.Plot(label="Results Visualization"), gr.Code(label="Results", interactive=False)],
|
151 |
+
)
|
152 |
+
demo.launch()
|