Spaces:
Sleeping
Sleeping
| import random | |
| import re | |
| import warnings | |
| import gym | |
| import numpy as np | |
| import pettingzoo | |
def test_state_space(env):
    """Validate the state space declared by ``env``.

    Asserts that ``env.state_space`` is a ``gym.spaces.Space`` and warns when it
    is neither a ``Box`` nor a ``Discrete``. For ``Box`` spaces it additionally:

    * asserts that ``low`` and ``high`` have the same shape as the space,
    * warns about ``-inf`` lower bounds, ``+inf`` upper bounds and dimensions
      whose lower and upper bounds are equal,
    * asserts that no lower bound exceeds its upper bound.

    Args:
        env: environment exposing a ``state_space`` attribute.

    Raises:
        AssertionError: if any of the hard requirements above is violated.
    """
    space = env.state_space
    assert isinstance(space, gym.spaces.Space), "State space for each environment must extend gym.spaces.Space"
    if not isinstance(space, (gym.spaces.Box, gym.spaces.Discrete)):
        warnings.warn("State space for each environment probably should be gym.spaces.box or gym.spaces.discrete")
    if not isinstance(space, gym.spaces.Box):
        return

    # Validate shapes before any element-wise comparison: with mismatched
    # bounds the comparisons below would silently broadcast (or raise a
    # broadcasting ValueError) instead of reporting the real problem.
    assert space.low.shape == space.shape, "Environment's state_space.low and state_space have different shapes"
    assert space.high.shape == space.shape, "Environment's state_space.high and state_space have different shapes"

    if np.any(np.equal(space.low, -np.inf)):
        warnings.warn("Environment's minimum state space value is -infinity. This is probably too low.")
    if np.any(np.equal(space.high, np.inf)):
        warnings.warn("Environment's maximum state space value is infinity. This is probably too high")
    if np.any(np.equal(space.low, space.high)):
        warnings.warn("Environment's maximum and minimum state space values are equal")
    assert not np.any(np.greater(space.low, space.high)), "Environment's minimum state space value is greater than it's maximum"


# This is a helper that needs an ``env`` argument, not a pytest test; stop
# pytest from collecting it when it is imported into a test module.
test_state_space.__test__ = False
def _check_ndarray_state(new_state, state_0):
    """Assert/warn about suspicious properties of an ndarray state.

    ``state_0`` is the state read right after ``reset``. It is used for the
    class comparison always, and for shape/ndim comparisons only when it is
    itself an ndarray.
    """
    # isinf/isnan and ordering comparisons are only defined for bool, integer,
    # float and complex data; on e.g. string or object arrays they raise
    # TypeError, which would hide the "not a numeric dtype" warning below.
    is_numeric_kind = new_state.dtype.kind in "biufc"

    if is_numeric_kind:
        if np.isinf(new_state).any():
            warnings.warn("State contains infinity (np.inf) or negative infinity (-np.inf)")
        if np.isnan(new_state).any():
            warnings.warn("State contains NaNs")
    if new_state.ndim > 3:
        warnings.warn("State has more than 3 dimensions")
    assert new_state.shape != (0,), "State can not be an empty array"
    if new_state.shape == (1,):
        warnings.warn("State is a single number")
    if not isinstance(new_state, state_0.__class__):
        warnings.warn("State between Observations are different classes")
    # state_0 may be a non-array (e.g. a list) with no .shape attribute.
    if isinstance(state_0, np.ndarray):
        if new_state.shape != state_0.shape and new_state.ndim == state_0.ndim:
            warnings.warn("States are different shapes")
        if new_state.ndim != state_0.ndim:
            warnings.warn("States have different number of dimensions")
    if not np.can_cast(new_state.dtype, np.dtype("float64")):
        warnings.warn("State numpy array is not a numeric dtype")
    if is_numeric_kind:
        if np.array_equal(new_state, np.zeros(new_state.shape)):
            warnings.warn("State numpy array is all zeros.")
        # 2-D arrays and 3-D arrays with 1 or 3 trailing channels look like images.
        looks_graphical = new_state.ndim == 2 or (new_state.ndim == 3 and new_state.shape[2] in (1, 3))
        if looks_graphical and not np.all(new_state >= 0):
            warnings.warn("The state contains negative numbers and is in the shape of a graphical observation. This might be a bad thing.")


def test_state(env, num_cycles):
    """Roll out ``env`` with random actions and sanity-check ``env.state()`` after each step.

    The environment is reset and its initial state recorded. The rollout then
    iterates ``env.agent_iter(env.num_agents * num_cycles)``. An agent whose
    ``env.last(observe=False)`` reports ``done`` is stepped with ``None``; any
    other agent is stepped with ``env.action_space(agent).sample()``.

    After every step the new state must be contained in ``env.state_space``.
    ndarray states go through further checks, and problems are reported with
    ``warnings.warn``. A non-ndarray state only triggers a warning.

    Args:
        env: AEC-style environment exposing ``reset``, ``state``,
            ``state_space``, ``agent_iter``, ``num_agents``, ``last``,
            ``action_space`` and ``step``.
        num_cycles: multiplier on ``env.num_agents`` giving the number of
            agent iterations to run.

    Raises:
        AssertionError: if a state is outside ``env.state_space`` or is an
            empty 1-D array.
    """
    env.reset()
    state_0 = env.state()
    for agent in env.agent_iter(env.num_agents * num_cycles):
        _, _, done, _ = env.last(observe=False)
        action = None if done else env.action_space(agent).sample()
        env.step(action)

        new_state = env.state()
        assert env.state_space.contains(new_state), "Environment's state is outside of it's state space"
        if isinstance(new_state, np.ndarray):
            _check_ndarray_state(new_state, state_0)
        else:
            warnings.warn("State is not NumPy array")


# This is a helper that needs arguments, not a pytest test; stop pytest from
# collecting it when it is imported into a test module.
test_state.__test__ = False
def test_parallel_env(parallel_env):
    """Reset a parallel environment and check its state API.

    Asserts that ``parallel_env.state_space`` is a ``gym.spaces.Space`` and that
    the state returned by ``parallel_env.state()`` right after ``reset`` lies
    inside that space.

    Args:
        parallel_env: parallel-API environment exposing ``reset``, ``state``
            and ``state_space``.

    Raises:
        AssertionError: if either check fails.
    """
    parallel_env.reset()
    assert isinstance(parallel_env.state_space, gym.spaces.Space), "State space for each parallel environment must extend gym.spaces.Space"
    state_0 = parallel_env.state()
    assert parallel_env.state_space.contains(state_0), "ParallelEnvironment's state is outside of it's state space"


# This is a helper that needs a ``parallel_env`` argument, not a pytest test;
# stop pytest from collecting it when it is imported into a test module.
test_parallel_env.__test__ = False
def state_test(env, parallel_env, num_cycles=10):
    """Run all state-API checks for an environment pair.

    Args:
        env: AEC-style environment, checked statically and via a random rollout.
        parallel_env: parallel-API environment, checked right after ``reset``.
        num_cycles: forwarded to ``test_state`` as its rollout multiplier.
    """
    # Static validation of the declared space first...
    test_state_space(env)
    # ...then states observed while stepping the AEC environment...
    test_state(env, num_cycles=num_cycles)
    # ...and finally the parallel environment's state after reset.
    test_parallel_env(parallel_env)