cpwan commited on
Commit
4d5f005
·
1 Parent(s): a4e57fd

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +11 -9
app.py CHANGED
@@ -3,13 +3,21 @@ import torch
3
  import gym
4
  from models.attention_model_wrapper import Agent
5
 
 
 
 
 
 
 
 
 
 
 
6
  device = "cpu"
7
  ckpt_path = "./runs/tsp-v0__ppo_or__1__1678160003/ckpt/12000.pt"
8
  agent = Agent(device=device, name="tsp").to(device)
9
  agent.load_state_dict(torch.load(ckpt_path, map_location=torch.device("cpu")))
10
 
11
- from wrappers.syncVectorEnvPomo import SyncVectorEnv
12
- from wrappers.recordWrapper import RecordEpisodeStatistics
13
 
14
  env_id = "tsp-v0"
15
  env_entry_point = "envs.tsp_vector_env:TSPVectorEnv"
@@ -75,9 +83,6 @@ default_data = np.array(
75
 
76
  # @title Helper function for plotting
77
  # colorline taken from https://nbviewer.org/github/dpsanders/matplotlib-examples/blob/master/colorline.ipynb
78
- import matplotlib.pyplot as plt
79
- from matplotlib.collections import LineCollection
80
- from matplotlib.colors import ListedColormap, BoundaryNorm
81
 
82
 
83
  def make_segments(x, y):
@@ -126,9 +131,6 @@ def plot(coords):
126
  return fig
127
 
128
 
129
- import gradio as gr
130
-
131
-
132
  def run_inference(data):
133
  data = data.astype(float).to_numpy()
134
  resulting_traj, final_return = inference(data)
@@ -149,4 +151,4 @@ demo = gr.Interface(
149
  ),
150
  [gr.Plot(label="Results Visualization"), gr.Code(label="Results", interactive=False)],
151
  )
152
- demo.launch()
 
3
  import gym
4
  from models.attention_model_wrapper import Agent
5
 
6
+ from wrappers.syncVectorEnvPomo import SyncVectorEnv
7
+ from wrappers.recordWrapper import RecordEpisodeStatistics
8
+
9
+
10
+ import matplotlib.pyplot as plt
11
+ from matplotlib.collections import LineCollection
12
+ from matplotlib.colors import ListedColormap, BoundaryNorm
13
+
14
+ import gradio as gr
15
+
16
  device = "cpu"
17
  ckpt_path = "./runs/tsp-v0__ppo_or__1__1678160003/ckpt/12000.pt"
18
  agent = Agent(device=device, name="tsp").to(device)
19
  agent.load_state_dict(torch.load(ckpt_path, map_location=torch.device("cpu")))
20
 
 
 
21
 
22
  env_id = "tsp-v0"
23
  env_entry_point = "envs.tsp_vector_env:TSPVectorEnv"
 
83
 
84
  # @title Helper function for plotting
85
  # colorline taken from https://nbviewer.org/github/dpsanders/matplotlib-examples/blob/master/colorline.ipynb
 
 
 
86
 
87
 
88
  def make_segments(x, y):
 
131
  return fig
132
 
133
 
 
 
 
134
  def run_inference(data):
135
  data = data.astype(float).to_numpy()
136
  resulting_traj, final_return = inference(data)
 
151
  ),
152
  [gr.Plot(label="Results Visualization"), gr.Code(label="Results", interactive=False)],
153
  )
154
+ demo.launch(share=True)