Spaces:
Sleeping
Sleeping
import math | |
import warnings | |
import magent | |
import numpy as np | |
from gym.spaces import Box, Discrete | |
from gym.utils import EzPickle | |
from pettingzoo import AECEnv | |
from pettingzoo.magent.render import Renderer | |
from pettingzoo.utils import agent_selector | |
from pettingzoo.utils.conversions import parallel_to_aec_wrapper, parallel_wrapper_fn | |
from .magent_env import magent_parallel_env, make_env | |
# Side length passed to the environment; load_config uses it for both the
# map width and the map height.
map_size = 200

# Default for the `max_cycles` argument of parallel_env / raw_env.
max_cycles_default = 500

# Value registered as the food type's `kill_reward` attribute in load_config.
KILL_REWARD = 5

# Default for the `minimap_mode` argument of parallel_env / raw_env.
minimap_mode_default = False

# Default reward settings. The keys match load_config's keyword parameters;
# callers may override individual entries through **reward_args.
default_reward_args = dict(
    step_reward=-0.01,
    attack_penalty=-0.1,
    dead_penalty=-1,
    attack_food_reward=0.5,
)
def parallel_env(max_cycles=max_cycles_default, minimap_mode=minimap_mode_default, extra_features=False, **reward_args):
    """Create the parallel version of the gather environment.

    Args:
        max_cycles: Forwarded unchanged to ``_parallel_env``.
        minimap_mode: Forwarded unchanged to ``_parallel_env``.
        extra_features: Forwarded unchanged to ``_parallel_env``.
        **reward_args: Overrides for entries of ``default_reward_args``.
            The merged dict is later expanded into ``load_config``, so an
            unknown key will be rejected there as an unexpected keyword.

    Returns:
        A ``_parallel_env`` instance built on a ``map_size`` x ``map_size`` map.
    """
    # Defaults first, caller overrides second; the module-level default dict
    # is never mutated.
    env_reward_args = {**default_reward_args, **reward_args}
    return _parallel_env(map_size, minimap_mode, env_reward_args, max_cycles, extra_features)
def raw_env(max_cycles=max_cycles_default, minimap_mode=minimap_mode_default, extra_features=False, **reward_args):
    """Create the gather environment behind the parallel-to-AEC wrapper.

    Accepts the same arguments as ``parallel_env``, builds that environment,
    and returns it wrapped in ``parallel_to_aec_wrapper``.
    """
    parallel = parallel_env(max_cycles, minimap_mode, extra_features, **reward_args)
    return parallel_to_aec_wrapper(parallel)
# Module-level constructor produced by make_env (imported from .magent_env)
# from raw_env. NOTE(review): presumably this is the user-facing entry point
# with the standard wrappers applied -- confirm against magent_env.make_env.
env = make_env(raw_env)
def load_config(size, minimap_mode, step_reward, attack_penalty, dead_penalty, attack_food_reward):
    """Build the magent gridworld configuration for the gather scenario.

    Two agent types are registered ("agent" and "food"), one group is created
    for each, and a reward rule is added that gives ``attack_food_reward`` to
    a member of the agent group whenever it attacks a member of the food group.

    Args:
        size: Used for both ``map_width`` and ``map_height``.
        minimap_mode: Stored in the config under ``minimap_mode``.
        step_reward: ``step_reward`` attribute of the "agent" type.
        attack_penalty: ``attack_penalty`` attribute of the "agent" type.
        dead_penalty: ``dead_penalty`` attribute of the "agent" type.
        attack_food_reward: Value of the agent-attacks-food reward rule.

    Returns:
        The populated ``magent.gridworld.Config`` object.
    """
    gw = magent.gridworld
    cfg = gw.Config()

    cfg.set({"map_width": size, "map_height": size})
    cfg.set({"minimap_mode": minimap_mode})

    agent_attrs = {
        'width': 1, 'length': 1, 'hp': 3, 'speed': 3,
        'view_range': gw.CircleRange(7), 'attack_range': gw.CircleRange(1),
        'damage': 6, 'step_recover': 0, 'attack_in_group': 1,
        'step_reward': step_reward,
        'attack_penalty': attack_penalty,
        'dead_penalty': dead_penalty,
    }
    agent = cfg.register_agent_type(name="agent", attr=agent_attrs)

    food_attrs = {
        'width': 1, 'length': 1, 'hp': 25, 'speed': 0,
        'view_range': gw.CircleRange(1), 'attack_range': gw.CircleRange(0),
        'kill_reward': KILL_REWARD,
    }
    food = cfg.register_agent_type(name='food', attr=food_attrs)

    # The food group must be added before the agent group: the environment
    # class in this file reads get_handles()[0] as the food handle and treats
    # the remaining handles as the controllable agents.
    g_f = cfg.add_group(food)
    g_s = cfg.add_group(agent)

    a = gw.AgentSymbol(g_s, index='any')
    b = gw.AgentSymbol(g_f, index='any')

    # Reward the attacker (a) for each attack event on food (b).
    cfg.add_reward_rule(gw.Event(a, 'attack', b), receiver=a, value=attack_food_reward)

    return cfg
class _parallel_env(magent_parallel_env, EzPickle):
    """Parallel gather environment: one agent group on a map seeded with food.

    The underlying ``magent.GridWorld`` is built from ``load_config``. Its
    first handle is the food group; the remaining handles are passed to the
    ``magent_parallel_env`` base class under the single name "omnivore".
    """

    metadata = {
        'render.modes': ['human', 'rgb_array'],
        'name': "gather_v4",
        'video.frames_per_second': 5,
    }

    def __init__(self, map_size, minimap_mode, reward_args, max_cycles, extra_features):
        """Create the gridworld and initialise the parallel-env base class.

        Args:
            map_size: Side length of the map, forwarded to ``load_config``.
            minimap_mode: Forwarded to ``load_config`` and to the base class.
            reward_args: Dict expanded into ``load_config`` keyword arguments;
                its values also feed the reward-range computation.
            max_cycles: Forwarded to the base class.
            extra_features: Forwarded to the base class.
        """
        EzPickle.__init__(self, map_size, minimap_mode, reward_args, max_cycles, extra_features)
        env = magent.GridWorld(load_config(map_size, minimap_mode, **reward_args))
        handles = env.get_handles()

        # Reward range: the sum of all negative components is the lower
        # bound, the sum of all positive components the upper bound. The kill
        # reward comes from the shared constant that load_config also uses,
        # so the two cannot drift apart.
        reward_vals = np.array([KILL_REWARD] + list(reward_args.values()))
        reward_range = [np.minimum(reward_vals, 0).sum(), np.maximum(reward_vals, 0).sum()]

        names = ["omnivore"]
        # handles[0] is the food group (see generate_map); skip it.
        super().__init__(env, handles[1:], names, map_size, max_cycles, reward_range, minimap_mode, extra_features)

    def generate_map(self):
        """Place the agents and all food on the map.

        Agents go on three concentric square outlines near the map border.
        Food goes on ten smaller concentric square outlines, plus a 53x53
        pattern of 2x2 food blocks centred on the map.
        """
        env, map_size = self.env, self.map_size
        group_handles = env.get_handles()
        food_handle, agent_handle = group_handles[0], group_handles[1]
        center_x, center_y = map_size // 2, map_size // 2

        def square_outline(side, gap):
            """Return points on a centred square outline, sampled every `gap` cells.

            Horizontal edges are emitted first (top/bottom pairs), then the
            vertical edges (left/right pairs). Corner points may be emitted
            twice; that is kept deliberately to match the established layout.
            """
            half = int(side) // 2
            cells = []
            for x in range(center_x - half, center_x + half + 1, gap):
                cells.append([x, center_y - half])
                cells.append([x, center_y + half])
            for y in range(center_y - half, center_y + half + 1, gap):
                cells.append([center_x - half, y])
                cells.append([center_x + half, y])
            return cells

        # agent: (side length, sampling gap) for each outline
        agent_rings = (
            (map_size * 0.9, 3),
            (map_size * 0.8, 4),
            (map_size * 0.7, 6),
        )
        agent_pos = []
        for side, gap in agent_rings:
            agent_pos.extend(square_outline(side, gap))
        env.add_agents(agent_handle, method="custom", pos=agent_pos)

        # food: sparse outer rings, dense inner rings
        food_rings = (
            (map_size * 0.65, 10),
            (map_size * 0.6, 10),
            (map_size * 0.55, 10),
            (map_size * 0.5, 4),
            (map_size * 0.45, 3),
            (map_size * 0.4, 1),
            (map_size * 0.3, 1),
            (map_size * 0.3 - 2, 1),
            (map_size * 0.3 - 4, 1),
            (map_size * 0.3 - 6, 1),
        )
        food_pos = []
        for side, gap in food_rings:
            food_pos.extend(square_outline(side, gap))
        env.add_agents(food_handle, method="custom", pos=food_pos)

        # pattern: a cell holds food when both indices are 2 or 3 modulo 4,
        # which yields 2x2 food blocks separated by 2-cell gaps.
        pattern_size = 53
        pattern = [
            [int(i % 4 >= 2 and j % 4 >= 2) for j in range(pattern_size)]
            for i in range(pattern_size)
        ]

        def draw(base_x, base_y, data):
            """Add food at every cell of `data` equal to 1, offset by the base."""
            cells = []
            for i, row in enumerate(data):
                for j, value in enumerate(row):
                    if value == 1:
                        # Stored as [j + base_y, i + base_x], i.e. the
                        # column-derived coordinate comes first.
                        cells.append([j + base_y, i + base_x])
            env.add_agents(food_handle, method="custom", pos=cells)

        w, h = len(pattern), len(pattern[0])
        draw(map_size // 2 - w // 2, map_size // 2 - h // 2, pattern)