cpwan commited on
Commit
a4e57fd
·
1 Parent(s): 9c0f3da

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +75 -58
app.py CHANGED
@@ -2,16 +2,17 @@ import numpy as np
2
  import torch
3
  import gym
4
  from models.attention_model_wrapper import Agent
5
- device = 'cpu'
6
- ckpt_path = './runs/tsp-v0__ppo_or__1__1678160003/ckpt/12000.pt'
7
- agent = Agent(device=device, name='tsp').to(device)
8
- agent.load_state_dict(torch.load(ckpt_path))
 
9
 
10
  from wrappers.syncVectorEnvPomo import SyncVectorEnv
11
  from wrappers.recordWrapper import RecordEpisodeStatistics
12
 
13
- env_id = 'tsp-v0'
14
- env_entry_point = 'envs.tsp_vector_env:TSPVectorEnv'
15
  seed = 0
16
 
17
  gym.envs.register(
@@ -19,6 +20,7 @@ gym.envs.register(
19
  entry_point=env_entry_point,
20
  )
21
 
 
22
  def make_env(env_id, seed, cfg={}):
23
  def thunk():
24
  env = gym.make(env_id, **cfg)
@@ -27,109 +29,124 @@ def make_env(env_id, seed, cfg={}):
27
  env.action_space.seed(seed)
28
  env.observation_space.seed(seed)
29
  return env
 
30
  return thunk
31
 
32
 
33
  def inference(data):
34
- envs = SyncVectorEnv([make_env(env_id, seed, dict(n_traj=1,
35
- max_nodes = len(data),
36
- eval_data = 'from_input',
37
- eval_data_from_input = data))])
 
 
 
38
 
39
  trajectories = []
40
  agent.eval()
41
  obs = envs.reset()
42
  done = np.array([False])
43
  while not done.all():
44
- # ALGO LOGIC: action logic
45
  with torch.no_grad():
46
  action, logits = agent(obs)
47
  obs, reward, done, info = envs.step(action.cpu().numpy())
48
  trajectories.append(action.cpu().numpy())
49
- nodes_coordinates = obs['observations'][0]
50
- final_return = info[0]['episode']['r']
51
- resulting_traj = np.array(trajectories)[:,0,0]
52
  return resulting_traj, final_return
53
 
54
- default_data = np.array([[0.5488135 , 0.71518937],
55
- [0.60276338, 0.54488318],
56
- [0.4236548 , 0.64589411],
57
- [0.43758721, 0.891773 ],
58
- [0.96366276, 0.38344152],
59
- [0.79172504, 0.52889492],
60
- [0.56804456, 0.92559664],
61
- [0.07103606, 0.0871293 ],
62
- [0.0202184 , 0.83261985],
63
- [0.77815675, 0.87001215],])
64
-
65
- #@title Helper function for plotting
 
 
 
 
 
66
  # colorline taken from https://nbviewer.org/github/dpsanders/matplotlib-examples/blob/master/colorline.ipynb
67
  import matplotlib.pyplot as plt
68
  from matplotlib.collections import LineCollection
69
  from matplotlib.colors import ListedColormap, BoundaryNorm
70
 
 
71
  def make_segments(x, y):
72
- '''
73
  Create list of line segments from x and y coordinates, in the correct format for LineCollection:
74
  an array of the form numlines x (points per line) x 2 (x and y) array
75
- '''
76
 
77
  points = np.array([x, y]).T.reshape(-1, 1, 2)
78
  segments = np.concatenate([points[:-1], points[1:]], axis=1)
79
-
80
  return segments
81
 
82
- def colorline(x, y, z=None, cmap=plt.get_cmap('copper'), norm=plt.Normalize(0.0, 1.0), linewidth=1, alpha=1.0):
83
- '''
 
84
  Plot a colored line with coordinates x and y
85
  Optionally specify colors in the array z
86
  Optionally specify a colormap, a norm function and a line width
87
- '''
88
-
89
  # Default colors equally spaced on [0,1]:
90
  if z is None:
91
  z = np.linspace(0.3, 1.0, len(x))
92
-
93
  # Special case if a single number:
94
  if not hasattr(z, "__iter__"): # to check for numerical input -- this is a hack
95
  z = np.array([z])
96
-
97
  z = np.asarray(z)
98
-
99
  segments = make_segments(x, y)
100
  lc = LineCollection(segments, array=z, cmap=cmap, norm=norm, linewidth=linewidth, alpha=alpha)
101
-
102
  ax = plt.gca()
103
  ax.add_collection(lc)
104
-
105
  return lc
106
 
 
107
  def plot(coords):
108
  fig = plt.figure()
109
- x,y = coords.T
110
- lc = colorline(x,y,cmap='Reds')
111
- plt.axis('square')
112
  return fig
113
 
 
114
  import gradio as gr
115
 
 
116
  def run_inference(data):
117
  data = data.astype(float).to_numpy()
118
  resulting_traj, final_return = inference(data)
119
- result_text = f'Planned Tour:\t{resulting_traj}\nTotal tour length:\t{final_return[0]:.2f}'
120
- return [plot(data[resulting_traj]),result_text]
121
-
122
- demo = gr.Interface(run_inference, gr.Dataframe(
123
- label = 'Input',
124
- headers=['x','y'],
125
- row_count=10,
126
- col_count=(2, "fixed"),
127
- max_rows = 10,
128
- value = default_data.tolist(),
129
- overflow_row_behaviour = 'show_ends'
130
- ),
131
- [gr.Plot(label= 'Results Visualization'),
132
- gr.Code(label= 'Results',
133
- interactive=False)])
134
- demo.launch(share = True)
135
-
 
 
2
  import torch
3
  import gym
4
  from models.attention_model_wrapper import Agent
5
+
6
+ device = "cpu"
7
+ ckpt_path = "./runs/tsp-v0__ppo_or__1__1678160003/ckpt/12000.pt"
8
+ agent = Agent(device=device, name="tsp").to(device)
9
+ agent.load_state_dict(torch.load(ckpt_path, map_location=torch.device("cpu")))
10
 
11
  from wrappers.syncVectorEnvPomo import SyncVectorEnv
12
  from wrappers.recordWrapper import RecordEpisodeStatistics
13
 
14
+ env_id = "tsp-v0"
15
+ env_entry_point = "envs.tsp_vector_env:TSPVectorEnv"
16
  seed = 0
17
 
18
  gym.envs.register(
 
20
  entry_point=env_entry_point,
21
  )
22
 
23
+
24
  def make_env(env_id, seed, cfg={}):
25
  def thunk():
26
  env = gym.make(env_id, **cfg)
 
29
  env.action_space.seed(seed)
30
  env.observation_space.seed(seed)
31
  return env
32
+
33
  return thunk
34
 
35
 
36
  def inference(data):
37
+ envs = SyncVectorEnv(
38
+ [
39
+ make_env(
40
+ env_id, seed, dict(n_traj=1, max_nodes=len(data), eval_data="from_input", eval_data_from_input=data)
41
+ )
42
+ ]
43
+ )
44
 
45
  trajectories = []
46
  agent.eval()
47
  obs = envs.reset()
48
  done = np.array([False])
49
  while not done.all():
50
+ # ALGO LOGIC: action logic
51
  with torch.no_grad():
52
  action, logits = agent(obs)
53
  obs, reward, done, info = envs.step(action.cpu().numpy())
54
  trajectories.append(action.cpu().numpy())
55
+ nodes_coordinates = obs["observations"][0]
56
+ final_return = info[0]["episode"]["r"]
57
+ resulting_traj = np.array(trajectories)[:, 0, 0]
58
  return resulting_traj, final_return
59
 
60
+
61
+ default_data = np.array(
62
+ [
63
+ [0.5488135, 0.71518937],
64
+ [0.60276338, 0.54488318],
65
+ [0.4236548, 0.64589411],
66
+ [0.43758721, 0.891773],
67
+ [0.96366276, 0.38344152],
68
+ [0.79172504, 0.52889492],
69
+ [0.56804456, 0.92559664],
70
+ [0.07103606, 0.0871293],
71
+ [0.0202184, 0.83261985],
72
+ [0.77815675, 0.87001215],
73
+ ]
74
+ )
75
+
76
+ # @title Helper function for plotting
77
  # colorline taken from https://nbviewer.org/github/dpsanders/matplotlib-examples/blob/master/colorline.ipynb
78
  import matplotlib.pyplot as plt
79
  from matplotlib.collections import LineCollection
80
  from matplotlib.colors import ListedColormap, BoundaryNorm
81
 
82
+
83
  def make_segments(x, y):
84
+ """
85
  Create list of line segments from x and y coordinates, in the correct format for LineCollection:
86
  an array of the form numlines x (points per line) x 2 (x and y) array
87
+ """
88
 
89
  points = np.array([x, y]).T.reshape(-1, 1, 2)
90
  segments = np.concatenate([points[:-1], points[1:]], axis=1)
91
+
92
  return segments
93
 
94
+
95
+ def colorline(x, y, z=None, cmap=plt.get_cmap("copper"), norm=plt.Normalize(0.0, 1.0), linewidth=1, alpha=1.0):
96
+ """
97
  Plot a colored line with coordinates x and y
98
  Optionally specify colors in the array z
99
  Optionally specify a colormap, a norm function and a line width
100
+ """
101
+
102
  # Default colors equally spaced on [0,1]:
103
  if z is None:
104
  z = np.linspace(0.3, 1.0, len(x))
105
+
106
  # Special case if a single number:
107
  if not hasattr(z, "__iter__"): # to check for numerical input -- this is a hack
108
  z = np.array([z])
109
+
110
  z = np.asarray(z)
111
+
112
  segments = make_segments(x, y)
113
  lc = LineCollection(segments, array=z, cmap=cmap, norm=norm, linewidth=linewidth, alpha=alpha)
114
+
115
  ax = plt.gca()
116
  ax.add_collection(lc)
117
+
118
  return lc
119
 
120
+
121
  def plot(coords):
122
  fig = plt.figure()
123
+ x, y = coords.T
124
+ lc = colorline(x, y, cmap="Reds")
125
+ plt.axis("square")
126
  return fig
127
 
128
+
129
  import gradio as gr
130
 
131
+
132
  def run_inference(data):
133
  data = data.astype(float).to_numpy()
134
  resulting_traj, final_return = inference(data)
135
+ result_text = f"Planned Tour:\t{resulting_traj}\nTotal tour length:\t{final_return[0]:.2f}"
136
+ return [plot(data[resulting_traj]), result_text]
137
+
138
+
139
+ demo = gr.Interface(
140
+ run_inference,
141
+ gr.Dataframe(
142
+ label="Input",
143
+ headers=["x", "y"],
144
+ row_count=10,
145
+ col_count=(2, "fixed"),
146
+ max_rows=10,
147
+ value=default_data.tolist(),
148
+ overflow_row_behaviour="show_ends",
149
+ ),
150
+ [gr.Plot(label="Results Visualization"), gr.Code(label="Results", interactive=False)],
151
+ )
152
+ demo.launch()