Spaces:
Sleeping
Sleeping
File size: 4,206 Bytes
3f7c971 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 |
import random
import re
import warnings
import gym
import numpy as np
import pettingzoo
def test_state_space(env):
    """Sanity-check the ``state_space`` declared by an environment.

    The space must be a ``gym.spaces.Space``. A warning is issued when it is
    neither a ``Box`` nor a ``Discrete`` space. For ``Box`` spaces the bounds
    are inspected as well: infinite or equal bounds only produce warnings,
    while an inverted bound (``low > high``) or a bound whose shape differs
    from the shape of the space fails an assertion.

    Args:
        env: Object exposing a ``state_space`` attribute.

    Raises:
        AssertionError: If the space is not a ``gym.spaces.Space``, if any
            lower bound exceeds its upper bound, or if ``low``/``high`` do
            not have the same shape as the space.
    """
    space = env.state_space
    assert isinstance(space, gym.spaces.Space), (
        "State space for each environment must extend gym.spaces.Space"
    )

    if not isinstance(space, (gym.spaces.Box, gym.spaces.Discrete)):
        warnings.warn(
            "State space for each environment probably should be gym.spaces.box or gym.spaces.discrete"
        )

    # Every remaining check looks at Box bounds, so other spaces are done.
    if not isinstance(space, gym.spaces.Box):
        return

    lower, upper = space.low, space.high

    if np.any(np.equal(lower, -np.inf)):
        warnings.warn(
            "Environment's minimum state space value is -infinity. This is probably too low."
        )
    if np.any(np.equal(upper, np.inf)):
        warnings.warn(
            "Environment's maxmimum state space value is infinity. This is probably too high"
        )
    if np.any(np.equal(lower, upper)):
        warnings.warn(
            "Environment's maximum and minimum state space values are equal"
        )

    assert not np.any(np.greater(lower, upper)), (
        "Environment's minimum state space value is greater than it's maximum"
    )
    assert lower.shape == space.shape, (
        "Environment's state_space.low and state_space have different shapes"
    )
    assert upper.shape == space.shape, (
        "Environment's state_space.high and state_space have different shapes"
    )
def test_state(env, num_cycles):
    """Step an environment with random actions and sanity-check ``env.state()``.

    The environment is reset and its initial state is kept as a reference.
    It is then stepped for up to ``env.num_agents * num_cycles`` agent turns.
    After every step the new state must be contained in ``env.state_space``;
    in addition a number of heuristics emit warnings for states that look
    suspicious (infinities, NaNs, too many dimensions, changing shapes,
    non-numeric dtype, all zeros, negative values in image-shaped arrays).

    Args:
        env: Object exposing ``reset``, ``state``, ``state_space``,
            ``num_agents``, ``agent_iter``, ``last``, ``step`` and
            ``action_space(agent)``.
        num_cycles: Multiplier applied to ``env.num_agents`` to obtain the
            maximum number of agent turns to play.

    Raises:
        AssertionError: If a state is not contained in ``env.state_space`` or
            is an array of shape ``(0,)``.
    """
    env.reset()
    state_0 = env.state()
    # Shape comparisons against the reference are only possible when the
    # reference itself is an array; otherwise ``state_0.shape`` would raise
    # AttributeError right after the "different classes" warning.
    reference_is_array = isinstance(state_0, np.ndarray)

    for agent in env.agent_iter(env.num_agents * num_cycles):
        _, _, done, _ = env.last(observe=False)
        # A done agent is stepped with ``None`` instead of a sampled action.
        action = None if done else env.action_space(agent).sample()
        env.step(action)

        new_state = env.state()
        assert env.state_space.contains(new_state), (
            "Environment's state is outside of it's state space"
        )

        if not isinstance(new_state, np.ndarray):
            warnings.warn("State is not NumPy array")
            continue

        if np.isinf(new_state).any():
            warnings.warn(
                "State contains infinity (np.inf) or negative infinity (-np.inf)"
            )
        if np.isnan(new_state).any():
            warnings.warn("State contains NaNs")

        ndim = new_state.ndim
        if ndim > 3:
            warnings.warn("State has more than 3 dimensions")
        assert new_state.shape != (0,), "State can not be an empty array"
        if new_state.shape == (1,):
            warnings.warn("State is a single number")

        if not isinstance(new_state, state_0.__class__):
            warnings.warn("State between Observations are different classes")
        if reference_is_array:
            # The two shape warnings are mutually exclusive: either the rank
            # changed, or the rank is the same but the extents differ.
            if ndim != state_0.ndim:
                warnings.warn("States have different number of dimensions")
            elif new_state.shape != state_0.shape:
                warnings.warn("States are different shapes")

        if not np.can_cast(new_state.dtype, np.dtype("float64")):
            warnings.warn("State numpy array is not a numeric dtype")
        if np.array_equal(new_state, np.zeros(new_state.shape)):
            warnings.warn("State numpy array is all zeros.")

        # 2-D arrays and 3-D arrays whose last axis has 1 or 3 entries are
        # treated as image-like; negative values there are flagged.
        image_shaped = ndim == 2 or (ndim == 3 and new_state.shape[2] in (1, 3))
        if not np.all(new_state >= 0) and image_shaped:
            warnings.warn(
                "The state contains negative numbers and is in the shape of a graphical observation. This might be a bad thing."
            )
def test_parallel_env(parallel_env):
    """Check the state space and initial state of a parallel environment.

    After resetting the environment, its ``state_space`` must be a
    ``gym.spaces.Space`` and the state returned by ``state()`` must be
    contained in that space.

    Args:
        parallel_env: Object exposing ``reset``, ``state`` and ``state_space``.

    Raises:
        AssertionError: If either of the two checks fails.
    """
    parallel_env.reset()

    declared_space = parallel_env.state_space
    assert isinstance(declared_space, gym.spaces.Space), (
        "State space for each parallel environment must extend gym.spaces.Space"
    )

    initial_state = parallel_env.state()
    assert parallel_env.state_space.contains(initial_state), (
        "ParallelEnvironment's state is outside of it's state space"
    )
def state_test(env, parallel_env, num_cycles=10):
    """Run every state-related check on an environment pair.

    Runs, in order: the state-space check on ``env``, the stepping state
    check on ``env`` for ``num_cycles`` cycles, and the parallel-environment
    check on ``parallel_env``.

    Args:
        env: Environment passed to ``test_state_space`` and ``test_state``.
        parallel_env: Environment passed to ``test_parallel_env``.
        num_cycles: Cycle count forwarded to ``test_state`` (default 10).
    """
    # Declared space first, then the states actually produced while stepping.
    test_state_space(env)
    test_state(env, num_cycles)

    # The parallel environment is validated independently.
    test_parallel_env(parallel_env)
|