File size: 3,411 Bytes
3f7c971
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
import hashlib
import pickle
import random
import warnings

import numpy as np

from pettingzoo.utils import parallel_to_aec


def hash(val):
    '''
    Return the hexadecimal MD5 digest of the pickled form of ``val``.

    Note that this intentionally replaces the builtin ``hash`` within this
    module; the result is a hex string, not an int.
    '''
    return hashlib.md5(pickle.dumps(val)).hexdigest()


def calc_hash(new_env, rand_issue, max_env_iters):
    '''
    Roll out three episodes on ``new_env`` and return a digest of the trace.

    new_env: environment exposing reset(), agent_iter(), last(), step() and
        action_space(agent)
    rand_issue: before each rollout, ``rand_issue + 1`` rounds of values are
        drawn from the global ``random`` and ``np.random`` generators and
        discarded, so the global random state differs between calls made
        with different ``rand_issue`` values
    max_env_iters: iteration limit handed to ``new_env.agent_iter``

    returns the hash of the recorded (agent, observation hash, reward) values
    '''
    trace = []
    # private generator for masked-action choices, independent of global state
    mask_rng = random.Random(42)
    for _ in range(3):
        new_env.reset()
        # advance the global generators; the drawn values are thrown away
        for _ in range(rand_issue + 1):
            random.randint(0, 1000)
            np.random.normal(size=100)
        for agent in new_env.agent_iter(max_env_iters):
            obs, rew, done, _info = new_env.last()
            has_mask = isinstance(obs, dict) and 'action_mask' in obs
            if done:
                action = None
            elif has_mask:
                legal_moves = np.flatnonzero(obs['action_mask'])
                action = mask_rng.choice(legal_moves)
            else:
                action = new_env.action_space(agent).sample()
            new_env.step(action)
            # recorded after the step, exactly as observed before it
            trace.extend((agent, hash_obsevation(obs), float(rew)))

    return hash(tuple(trace))


def seed_action_spaces(env):
    '''
    Seed the action space of every agent in ``env.possible_agents``.

    The i-th agent's action space is seeded with ``42 + i``. Environments
    without a ``possible_agents`` attribute are left untouched.
    '''
    if not hasattr(env, 'possible_agents'):
        return
    for space_seed, agent in enumerate(env.possible_agents, start=42):
        env.action_space(agent).seed(space_seed)


def check_environment_deterministic(env1, env2, num_cycles):
    '''
    Compare the rollouts of two environments that were seeded by the caller.

    env1, env2: seeded environments to compare
    num_cycles: number of cycles to run; the per-rollout iteration limit is
        ``num_cycles`` times the number of possible agents of env1 (at least 1)

    returns a bool: true if env1 and env2 execute the same way
    '''
    env_pair = (env1, env2)

    # action spaces are seeded so that sampled actions are deterministic
    for env in env_pair:
        seed_action_spaces(env)

    agent_count = max(1, len(getattr(env1, 'possible_agents', [])))
    step_budget = num_cycles * agent_count

    # each environment is hashed with a different rand_issue (its index)
    digests = [
        calc_hash(env, index, step_budget)
        for index, env in enumerate(env_pair)
    ]

    return digests[0] == digests[1]


def hash_obsevation(obs):
    '''
    Return a hash for a single observation.

    Objects providing ``tobytes()`` are hashed through their byte
    representation; anything else is passed to ``hash`` directly. If that
    raises TypeError, a warning is issued and 0 is returned.
    '''
    try:
        return hash(obs.tobytes())
    except AttributeError:
        # no tobytes(): fall through and hash the object itself
        pass

    try:
        return hash(obs)
    except TypeError:
        warnings.warn("Observation not an int or an Numpy array")
        return 0


def test_environment_reset_deterministic(env1, num_cycles):
    '''
    Assert that re-seeding and resetting ``env1`` reproduces the same rollout.

    The environment is seeded with 42, reset and hashed twice, using
    rand_issue values 1 and 2 respectively; the two hashes must be equal.
    '''
    digests = []
    for rand_issue in (1, 2):
        seed_action_spaces(env1)
        env1.seed(42)
        env1.reset()
        digests.append(calc_hash(env1, rand_issue, num_cycles))

    first_digest, second_digest = digests
    assert first_digest == second_digest, "environments kept state after seed(42) and reset()"


def seed_test(env_constructor, num_cycles=10, test_kept_state=True):
    '''
    Check that environments built by ``env_constructor`` are deterministic
    when seeded.

    env_constructor: zero-argument callable returning a new environment
    num_cycles: number of cycles passed on to the determinism checks
    test_kept_state: if true, first verify that the first environment gives
        the same rollout after being re-seeded and reset

    raises AssertionError if any check fails
    '''
    first_env = env_constructor()
    if test_kept_state:
        test_environment_reset_deterministic(first_env, num_cycles)

    # the second environment is only built after the kept-state check
    second_env = env_constructor()
    shared_seed = 42
    for env in (first_env, second_env):
        env.seed(shared_seed)

    runs_match = check_environment_deterministic(first_env, second_env, num_cycles)
    assert runs_match, \
        ("The environment gives different results on multiple runs when initialized with the same seed. This is usually a sign that you are using np.random or random modules directly, which uses a global random state.")


def parallel_seed_test(parallel_env_fn, num_cycles=10, test_kept_state=True):
    '''
    Run ``seed_test`` on a parallel environment.

    Each environment produced by ``parallel_env_fn`` is wrapped with
    ``parallel_to_aec`` before being handed to ``seed_test``; the remaining
    arguments are forwarded unchanged.
    '''
    seed_test(
        lambda: parallel_to_aec(parallel_env_fn()),
        num_cycles,
        test_kept_state,
    )