Kano001's picture
Upload 654 files
3f7c971 verified
raw
history blame
5.34 kB
import math
import warnings
import magent
import numpy as np
from gym.spaces import Box, Discrete
from gym.utils import EzPickle
from pettingzoo import AECEnv
from pettingzoo.magent.render import Renderer
from pettingzoo.utils import agent_selector
from pettingzoo.utils.conversions import parallel_to_aec_wrapper, parallel_wrapper_fn
from .magent_env import magent_parallel_env, make_env
# Side length (width == height) of the square grid world built by this module.
map_size = 200
# Default value of the ``max_cycles`` argument of the env constructors.
max_cycles_default = 500
# Value used as the ``kill_reward`` attribute of the "food" agent type.
KILL_REWARD = 5
# Default value of the ``minimap_mode`` argument of the env constructors.
minimap_mode_default = False
# Reward keyword arguments forwarded to ``load_config``; callers may override
# any of them via ``**reward_args``. Insertion order is kept stable.
default_reward_args = {
    "step_reward": -0.01,
    "attack_penalty": -0.1,
    "dead_penalty": -1,
    "attack_food_reward": 0.5,
}
def parallel_env(max_cycles=max_cycles_default, minimap_mode=minimap_mode_default, extra_features=False, **reward_args):
    """Create the parallel version of the gather environment.

    Any keyword in ``reward_args`` overrides the matching entry of
    ``default_reward_args``; the merged mapping is handed to ``_parallel_env``
    together with the module-level ``map_size``. The shared defaults dict is
    never mutated.
    """
    merged_rewards = {**default_reward_args, **reward_args}
    return _parallel_env(map_size, minimap_mode, merged_rewards, max_cycles, extra_features)
def raw_env(max_cycles=max_cycles_default, minimap_mode=minimap_mode_default, extra_features=False, **reward_args):
    """Create the environment via ``parallel_env`` and convert it with
    ``parallel_to_aec_wrapper``.

    All arguments are passed straight through to ``parallel_env``.
    """
    par_env = parallel_env(max_cycles, minimap_mode, extra_features, **reward_args)
    return parallel_to_aec_wrapper(par_env)


# Module-level constructor produced by ``make_env`` (from ``.magent_env``)
# around ``raw_env``; the wrapping it applies is defined in that module.
env = make_env(raw_env)
def load_config(size, minimap_mode, step_reward, attack_penalty, dead_penalty, attack_food_reward):
    """Build the ``magent.gridworld`` configuration for the gather game.

    Registers two agent types, "agent" and "food". It then adds the food group
    first and the agent group second, and installs one reward rule paying
    ``attack_food_reward`` to an agent that attacks a food unit.

    Args:
        size: width and height of the square map.
        minimap_mode: value stored under the ``"minimap_mode"`` config key.
        step_reward: ``step_reward`` attribute of the "agent" type.
        attack_penalty: ``attack_penalty`` attribute of the "agent" type.
        dead_penalty: ``dead_penalty`` attribute of the "agent" type.
        attack_food_reward: reward value of the agent-attacks-food rule.

    Returns:
        The populated ``gw.Config`` object.
    """
    gw = magent.gridworld
    config = gw.Config()
    config.set({"map_width": size, "map_height": size})
    config.set({"minimap_mode": minimap_mode})

    forager_attrs = {
        'width': 1, 'length': 1, 'hp': 3, 'speed': 3,
        'view_range': gw.CircleRange(7), 'attack_range': gw.CircleRange(1),
        'damage': 6, 'step_recover': 0, 'attack_in_group': 1,
        'step_reward': step_reward, 'attack_penalty': attack_penalty, 'dead_penalty': dead_penalty,
    }
    forager_type = config.register_agent_type(name="agent", attr=forager_attrs)

    # Food is stationary (speed 0) and sturdy (hp 25); killing it yields KILL_REWARD.
    food_attrs = {
        'width': 1, 'length': 1, 'hp': 25, 'speed': 0,
        'view_range': gw.CircleRange(1), 'attack_range': gw.CircleRange(0),
        'kill_reward': KILL_REWARD,
    }
    food_type = config.register_agent_type(name='food', attr=food_attrs)

    # Keep this order: the environment class treats handle 0 as food and the
    # rest as agents (presumably handles follow group-creation order).
    food_group = config.add_group(food_type)
    agent_group = config.add_group(forager_type)

    attacker = gw.AgentSymbol(agent_group, index='any')
    prey = gw.AgentSymbol(food_group, index='any')
    config.add_reward_rule(gw.Event(attacker, 'attack', prey), receiver=attacker, value=attack_food_reward)
    return config
class _parallel_env(magent_parallel_env, EzPickle):
    """Parallel MAgent "gather" environment.

    A single learning group ("omnivore") is placed on square rings near the
    map border. Food is placed on inner square rings and in a lattice pattern
    at the centre of the map.

    ``__init__`` arguments:
        map_size: side length of the square map.
        minimap_mode: forwarded to ``load_config`` and the base class.
        reward_args: dict of keyword arguments for ``load_config``
            (step_reward, attack_penalty, dead_penalty, attack_food_reward).
        max_cycles: forwarded to the base class.
        extra_features: forwarded to the base class.
    """

    metadata = {
        'render.modes': ['human', 'rgb_array'],
        'name': "gather_v4",
        'video.frames_per_second': 5,
    }

    def __init__(self, map_size, minimap_mode, reward_args, max_cycles, extra_features):
        EzPickle.__init__(self, map_size, minimap_mode, reward_args, max_cycles, extra_features)
        env = magent.GridWorld(load_config(map_size, minimap_mode, **reward_args))
        handles = env.get_handles()
        # Coarse reward bounds: the sum of all negative terms and the sum of
        # all positive terms. Use the shared KILL_REWARD constant (the food's
        # 'kill_reward') instead of a literal so the range cannot drift out of
        # sync with the configured reward.
        reward_vals = np.array([KILL_REWARD] + list(reward_args.values()))
        reward_range = [np.minimum(reward_vals, 0).sum(), np.maximum(reward_vals, 0).sum()]
        names = ["omnivore"]
        # handles[0] is the food group (see generate_map); only the remaining
        # handles are controlled agents.
        super().__init__(env, handles[1:], names, map_size, max_cycles, reward_range, minimap_mode, extra_features)

    def generate_map(self):
        """Populate the map: agents on three outer rings, food on inner rings
        and in a central lattice pattern."""
        env, map_size = self.env, self.map_size
        handles = env.get_handles()[1:]
        food_handle = env.get_handles()[0]
        center_x, center_y = map_size // 2, map_size // 2

        def add_square(pos, side, gap):
            """Append to ``pos`` the perimeter points of a square of side
            ``int(side)`` centred on the map, every ``gap`` cells along each
            edge. Corners may be appended twice (once per edge loop)."""
            side = int(side)
            half = side // 2
            for x in range(center_x - half, center_x + half + 1, gap):
                pos.append([x, center_y - half])
                pos.append([x, center_y + half])
            for y in range(center_y - half, center_y + half + 1, gap):
                pos.append([center_x - half, y])
                pos.append([center_x + half, y])

        # agents: (side, gap) of each ring, outermost first
        agent_rings = (
            (map_size * 0.9, 3),
            (map_size * 0.8, 4),
            (map_size * 0.7, 6),
        )
        pos = []
        for side, gap in agent_rings:
            add_square(pos, side, gap)
        env.add_agents(handles[0], method="custom", pos=pos)

        # food rings: (side, gap), sparse outside, dense (gap 1) near the centre
        food_rings = (
            (map_size * 0.65, 10),
            (map_size * 0.6, 10),
            (map_size * 0.55, 10),
            (map_size * 0.5, 4),
            (map_size * 0.45, 3),
            (map_size * 0.4, 1),
            (map_size * 0.3, 1),
            (map_size * 0.3 - 2, 1),
            (map_size * 0.3 - 4, 1),
            (map_size * 0.3 - 6, 1),
        )
        pos = []
        for side, gap in food_rings:
            add_square(pos, side, gap)
        env.add_agents(food_handle, method="custom", pos=pos)

        # central food pattern: a cell holds food iff both its row and column
        # index are 2 or 3 (mod 4), i.e. a lattice of 2x2 food blocks.
        pattern_size = 53
        pattern = [[int(i % 4 >= 2 and j % 4 >= 2) for j in range(pattern_size)] for i in range(pattern_size)]
        base_x = map_size // 2 - len(pattern) // 2
        base_y = map_size // 2 - len(pattern[0]) // 2
        # Positions are emitted as [column, row]; the pattern is symmetric,
        # so this transposition does not change the placed cells.
        pos = [
            [base_y + j, base_x + i]
            for i, row in enumerate(pattern)
            for j, cell in enumerate(row)
            if cell == 1
        ]
        env.add_agents(food_handle, method="custom", pos=pos)