Kano001's picture
Upload 654 files
3f7c971 verified
raw
history blame
1.66 kB
import random
import numpy as np
def collect_render_results(env, mode):
results = []
env.reset()
for i in range(5):
if i > 0:
for agent in env.agent_iter(env.num_agents // 2 + 1):
obs, reward, done, info = env.last()
if done:
action = None
elif isinstance(obs, dict) and 'action_mask' in obs:
action = random.choice(np.flatnonzero(obs['action_mask']))
else:
action = env.action_space(agent).sample()
env.step(action)
render_result = env.render(mode=mode)
results.append(render_result)
return results
def render_test(env_fn, custom_tests={}):
env = env_fn()
render_modes = env.metadata.get('render.modes')[:]
assert render_modes is not None, "Environment's that support rendering must define render modes in metadata"
for mode in render_modes:
render_results = collect_render_results(env, mode)
for res in render_results:
if mode in custom_tests.keys():
assert custom_tests[mode](res)
if mode == 'rgb_array':
assert isinstance(res, np.ndarray) and len(res.shape) == 3 and res.shape[2] == 3 and res.dtype == np.uint8, f"rgb_array mode must return a valid image array, is {res}"
if mode == 'ansi':
assert isinstance(res, str) # and len(res.shape) == 3 and res.shape[2] == 3 and res.dtype == np.uint8, "rgb_array mode must have shit in it"
if mode == "human":
assert res is None
env.close()
env = env_fn()