Kano001's picture
Upload 654 files
3f7c971 verified
raw
history blame
1.13 kB
import os
from .base_atari_env import BaseAtariEnv, base_env_wrapper_fn, parallel_wrapper_fn
avaliable_2p_versions = {
"classic": 4,
"two_paddles": 10,
"soccer": 14,
"foozpong": 19,
"hockey": 27,
"handball": 35,
"volleyball": 39,
"basketball": 45,
}
avaliable_4p_versions = {
"classic": 6,
"two_paddles": 11,
"soccer": 16,
"foozpong": 21,
"hockey": 29,
"quadrapong": 33,
"handball": 37,
"volleyball": 41,
"basketball": 49,
}
def raw_env(num_players=2, game_version="classic", **kwargs):
assert num_players == 2 or num_players == 4, "pong only supports 2 or 4 players"
versions = avaliable_2p_versions if num_players == 2 else avaliable_4p_versions
assert game_version in versions, f"pong version {game_version} not supported for number of players {num_players}. Available options are {list(versions)}"
mode = versions[game_version]
return BaseAtariEnv(game="pong", num_players=num_players, mode_num=mode, env_name=os.path.basename(__file__)[:-3], **kwargs)
env = base_env_wrapper_fn(raw_env)
parallel_env = parallel_wrapper_fn(env)