Kano001's picture
Upload 654 files
3f7c971 verified
raw
history blame
3.41 kB
import hashlib
import pickle
import random
import warnings
import numpy as np
from pettingzoo.utils import parallel_to_aec
def hash(val):
    """Return the hexadecimal MD5 digest of the pickled form of *val*.

    NOTE: this intentionally keeps the module's existing name, which shadows
    the builtin ``hash`` for everything defined in this file.
    """
    return hashlib.md5(pickle.dumps(val)).hexdigest()
def calc_hash(new_env, rand_issue, max_env_iters):
    """Roll out *new_env* three times and return a digest of what happened.

    Before each rollout, ``rand_issue + 1`` rounds of values are drawn from
    the global ``random`` and ``np.random`` generators and discarded
    (presumably so that an environment relying on global random state yields
    a different digest for different ``rand_issue`` values).

    For every agent step the agent, the hashed observation and the reward
    (as a float) are recorded; the digest of the full record is returned.
    """
    trace = []
    # Dedicated generator used only to pick among legal actions when the
    # observation carries an action mask.
    mask_rng = random.Random(42)
    for _ in range(3):
        new_env.reset()
        for _ in range(rand_issue + 1):
            # Results are deliberately thrown away; only the side effect on
            # the global generators matters.
            random.randint(0, 1000)
            np.random.normal(size=100)
        for agent in new_env.agent_iter(max_env_iters):
            obs, rew, done, _ = new_env.last()
            if done:
                action = None
            elif isinstance(obs, dict) and 'action_mask' in obs:
                legal_actions = np.flatnonzero(obs['action_mask'])
                action = mask_rng.choice(legal_actions)
            else:
                action = new_env.action_space(agent).sample()
            new_env.step(action)
            trace.extend((agent, hash_obsevation(obs), float(rew)))
    return hash(tuple(trace))
def seed_action_spaces(env):
    """Seed each agent's action space with 42 plus the agent's position.

    Does nothing when *env* has no ``possible_agents`` attribute.
    """
    if not hasattr(env, 'possible_agents'):
        return
    for seed, agent in enumerate(env.possible_agents, start=42):
        env.action_space(agent).seed(seed)
def check_environment_deterministic(env1, env2, num_cycles):
    '''
    env1 and env2 should be seeded environments

    returns a bool: true if env1 and env2 execute the same way
    '''
    envs = (env1, env2)

    # Seed the action spaces so that sampled actions are deterministic.
    for env in envs:
        seed_action_spaces(env)

    # Step budget per rollout: num_cycles steps for every possible agent
    # (at least one agent is assumed when the attribute is missing or empty).
    agent_count = max(1, len(getattr(env1, 'possible_agents', [])))
    max_env_iters = num_cycles * agent_count

    # Each environment is hashed with a different amount of global-RNG
    # perturbation (its index), then the digests are compared.
    hashes = [
        calc_hash(env, index, max_env_iters)
        for index, env in enumerate(envs)
    ]
    return len(set(hashes)) == 1
def hash_obsevation(obs):
    """Return a hash for an observation.

    Objects exposing ``tobytes()`` are hashed through their byte
    representation; anything else is hashed directly. If hashing raises
    ``TypeError`` a warning is emitted and ``0`` is returned.
    """
    # First attempt: byte representation (e.g. objects with ``tobytes``).
    try:
        return hash(obs.tobytes())
    except AttributeError:
        pass

    # Fallback: hash the object itself.
    try:
        return hash(obs)
    except TypeError:
        warnings.warn("Observation not an int or an Numpy array")
        return 0
def test_environment_reset_deterministic(env1, num_cycles):
    """Assert that re-seeding and resetting *env1* reproduces the same rollout.

    The environment is seeded with 42, reset and hashed twice, each time with
    a different ``rand_issue`` value passed to ``calc_hash``; the two digests
    must match.
    """
    digests = []
    for rand_issue in (1, 2):
        seed_action_spaces(env1)
        env1.seed(42)
        env1.reset()
        digests.append(calc_hash(env1, rand_issue, num_cycles))
    first, second = digests
    assert first == second, "environments kept state after seed(42) and reset()"
def seed_test(env_constructor, num_cycles=10, test_kept_state=True):
    """Check that environments built by *env_constructor* are deterministic.

    When *test_kept_state* is true, the first environment is additionally
    checked for state leaking across ``seed`` + ``reset``. Two environments
    are then seeded with the same value and their rollouts compared; an
    ``AssertionError`` is raised if they differ.
    """
    first_env = env_constructor()
    if test_kept_state:
        test_environment_reset_deterministic(first_env, num_cycles)
    second_env = env_constructor()

    base_seed = 42
    for env in (first_env, second_env):
        env.seed(base_seed)

    failure_message = "The environment gives different results on multiple runs when initialized with the same seed. This is usually a sign that you are using np.random or random modules directly, which uses a global random state."
    is_deterministic = check_environment_deterministic(first_env, second_env, num_cycles)
    assert is_deterministic, failure_message
def parallel_seed_test(parallel_env_fn, num_cycles=10, test_kept_state=True):
    """Run ``seed_test`` on a parallel environment factory.

    Each environment produced by *parallel_env_fn* is passed through
    ``parallel_to_aec`` before being handed to ``seed_test``.
    """
    seed_test(
        lambda: parallel_to_aec(parallel_env_fn()),
        num_cycles,
        test_kept_state,
    )