# Visualize dataset

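Utilities to visualize episodes from a dataset: load one episode's `.npz` file, overlay its per-step reward as a colored bar on each frame, and render the frames as an MP4 video.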

In [None]:
%matplotlib inline
from matplotlib import pyplot as plt
from matplotlib import animation
import pathlib
from IPython.display import Video
import numpy as np
import os

# Directory holding the episode files, resolved relative to the parent of the
# current working directory.
dataset_path = pathlib.Path(os.path.abspath('')).parent / 'data/stickman_example'

directory = dataset_path.expanduser()
filenames = sorted(directory.glob('*.npz'))
if not filenames:
    raise ValueError(f"Empty directory (or no episodes): {directory}")


def _episode_number(path):
    """Return the episode index encoded at the start of ``path``'s file name.

    Names are expected to look like ``<index>-<rest>.npz`` (or ``<index>.npz``);
    the text before the first ``-`` must parse as an integer.

    Raises:
        ValueError: if that prefix is not an integer.
    """
    return int(path.stem.split("-", 1)[0])


# Map episode index -> file path.
filenames_dict = {}
for f in filenames:
    try:
        filenames_dict[_episode_number(f)] = f
    except ValueError as e:
        # Skip files that don't follow the naming scheme, but keep
        # `filenames_dict` defined so later cells still work.
        print(f"Skipping {f.name}: {e}")
if not filenames_dict:
    raise ValueError(f"No episode files with a numeric '<index>-' prefix in {directory}")
# Order by numeric index: the lexicographic sort above would put e.g. episode
# 10 before episode 2, so iterating the dict would not start at the lowest one.
filenames_dict = dict(sorted(filenames_dict.items()))

print(directory)
print(len(filenames))

In [None]:
# Visualize the first episode listed in `filenames_dict`.
ep_num = next(iter(filenames_dict))

filename = filenames_dict[ep_num]
with filename.open('rb') as f:
    episode = np.load(f)
    # Copy every array out while the file is still open.
    episode = {k: episode[k] for k in episode.keys()}

# Show the reward as a bar along the top row of each frame: red for negative
# rewards, green otherwise, round(|reward| / 2 * 64) + 1 pixels long.
# Each observation is indexed as (channel, row, column) here, i.e. frames are
# assumed to be stored channels-first with 3 channels (see the transpose below).
neg_color = np.array([255, 0, 0]).reshape(3, 1)
pos_color = np.array([0, 255, 0]).reshape(3, 1)
pix_rew_max = np.round(episode['reward'] / 2 * 64)
for ob, pix_n in zip(episode['observation'], pix_rew_max):
    color = neg_color if pix_n < 0 else pos_color
    ob[:, 0, :int(abs(pix_n) + 1)] = color

# np array with shape (frames, height, width, channels), as imshow expects.
video = np.transpose(episode['observation'], axes=[0, 2, 3, 1])

fig = plt.figure(frameon=False)
ax = plt.Axes(fig, [0., 0., 1., 1.])
ax.set_axis_off()
fig.add_axes(ax)
fig.subplots_adjust(top=1, bottom=0, right=1, left=0, hspace=0, wspace=0)
fig.set_size_inches(2, 2)
im = ax.imshow(video[0])
plt.close(fig)  # keep the notebook from also rendering the static first frame


def init():
    """Reset the image to the first frame."""
    im.set_data(video[0])
    return [im]


def animate(i):
    """Show frame ``i`` of the video."""
    im.set_data(video[i])
    return [im]


print('Episode reward', np.sum(episode['reward']))
anim = animation.FuncAnimation(fig, animate, init_func=init, frames=video.shape[0], interval=45)

# Path relative to the working directory; also used as the <video> URL below.
video_rel_path = 'videos/temp.mp4'
file_path = str(pathlib.Path(os.path.abspath('')) / video_rel_path)
# anim.save does not create missing folders.
pathlib.Path(file_path).parent.mkdir(parents=True, exist_ok=True)
anim.save(file_path)
print('Video file', file_path)
# Video() uses a file path as the <video src> URL, which the browser resolves
# against the notebook location, not the filesystem root, so pass the relative
# path. NOTE(review): assumes the kernel's cwd is the notebook's directory.
Video(video_rel_path)