|
|
|
import torch
|
|
import torch.nn as nn
|
|
import torch.optim as optim
|
|
import random
|
|
import numpy as np
|
|
from collections import deque
|
|
|
|
|
|
class DuelingDQN(nn.Module):
    """Dueling Q-network: a shared MLP trunk feeding value and advantage heads.

    The trunk is two ReLU-activated linear layers. Its output feeds a scalar
    state-value head V(s) and a per-action advantage head A(s, a). The heads
    are recombined as

        Q(s, a) = V(s) + (A(s, a) - mean_a A(s, a))

    Subtracting the mean advantage pins down how Q is split between V and A.

    Args:
        state_size: Number of input features per state.
        action_size: Number of discrete actions, i.e. the width of the Q output.
        hidden_size: Width of both shared trunk layers (default 128, the
            original hard-coded value).
        stream_size: Width of the hidden layer inside each head (default 64,
            the original hard-coded value).
    """

    def __init__(self, state_size, action_size, hidden_size=128, stream_size=64):
        super().__init__()
        # Attribute names (fc1, fc2, value_stream, advantage_stream) and the
        # Sequential layouts are kept so existing state_dict checkpoints load.
        self.fc1 = nn.Linear(state_size, hidden_size)
        self.fc2 = nn.Linear(hidden_size, hidden_size)

        self.value_stream = nn.Sequential(
            nn.Linear(hidden_size, stream_size),
            nn.ReLU(),
            nn.Linear(stream_size, 1),
        )

        self.advantage_stream = nn.Sequential(
            nn.Linear(hidden_size, stream_size),
            nn.ReLU(),
            nn.Linear(stream_size, action_size),
        )

    def forward(self, x):
        """Return Q-values for every action.

        Args:
            x: Float tensor whose last dimension is ``state_size``. It may be
                batched, e.g. ``(B, state_size)``, or a single state of shape
                ``(state_size,)``.

        Returns:
            Tensor with the same leading shape as ``x`` and a last dimension
            of ``action_size``.
        """
        x = torch.relu(self.fc1(x))
        x = torch.relu(self.fc2(x))

        value = self.value_stream(x)
        advantage = self.advantage_stream(x)

        # Average over the last (action) axis. This equals dim=1 for batched
        # 2-D input and also works for unbatched 1-D input, where dim=1 would
        # raise IndexError.
        q_values = value + (advantage - advantage.mean(dim=-1, keepdim=True))
        return q_values
|
|
|
|
class AdvancedDQNAgent:
    """Double-DQN agent using ``DuelingDQN`` networks and a uniform replay buffer.

    Actions are chosen epsilon-greedily from the policy network's Q-values.
    ``replay`` samples a minibatch of stored transitions and builds Double-DQN
    targets: the policy network picks the next action and the target network
    evaluates it. It then takes one optimizer step on the MSE loss. The target
    network is refreshed only when ``update_target_network`` is called; nothing
    in this class calls it after construction/loading, so presumably the
    training loop does so on its own schedule.

    Args:
        state_size: Number of features per state; must match the network input.
        action_size: Number of discrete actions.
        device: Device passed to ``.to()`` for the networks and batch tensors.
        memory_size: Capacity of the replay buffer; oldest transitions are
            evicted first (``deque(maxlen=...)``).
        gamma: Discount factor for bootstrapped targets.
        epsilon: Initial exploration probability.
        epsilon_min: Floor that multiplicative decay will not go below.
        epsilon_decay: Factor applied to epsilon after each training step.
        learning_rate: Adam learning rate.
        batch_size: Minibatch size; ``replay`` is a no-op until the buffer
            holds at least this many transitions.
    """

    def __init__(
        self,
        state_size,
        action_size,
        device="cpu",
        *,
        memory_size=10000,
        gamma=0.99,
        epsilon=1.0,
        epsilon_min=0.01,
        epsilon_decay=0.995,
        learning_rate=0.001,
        batch_size=64,
    ):
        self.state_size = state_size
        self.action_size = action_size
        self.device = device

        # Each entry is a (state, action, reward, next_state, done) tuple.
        self.memory = deque(maxlen=memory_size)
        self.gamma = gamma
        self.epsilon = epsilon
        self.epsilon_min = epsilon_min
        self.epsilon_decay = epsilon_decay
        self.learning_rate = learning_rate
        self.batch_size = batch_size

        self.policy_net = DuelingDQN(state_size, action_size).to(device)
        self.target_net = DuelingDQN(state_size, action_size).to(device)
        # Start both networks from identical weights.
        self.update_target_network()
        self.optimizer = optim.Adam(self.policy_net.parameters(), lr=self.learning_rate)
        self.criterion = nn.MSELoss()

    def update_target_network(self):
        """Copy the policy network's weights into the target network (hard update)."""
        self.target_net.load_state_dict(self.policy_net.state_dict())

    def act(self, state):
        """Pick an action epsilon-greedily.

        With probability ``epsilon``, returns a uniformly random action index.
        Otherwise returns the argmax of the policy network's Q-values for
        ``state``.

        Args:
            state: Array-like single state with ``state_size`` features.

        Returns:
            int: Action index in ``[0, action_size)``.
        """
        if np.random.rand() <= self.epsilon:
            return random.randrange(self.action_size)
        state_tensor = self._as_tensor(state, np.float32).unsqueeze(0)
        with torch.no_grad():
            q_values = self.policy_net(state_tensor)
        return int(torch.argmax(q_values).item())

    def remember(self, state, action, reward, next_state, done):
        """Append one transition to the replay buffer."""
        self.memory.append((state, action, reward, next_state, done))

    def replay(self):
        """Run one training step on a random minibatch from the replay buffer.

        Does nothing (returns ``None``) until the buffer holds at least
        ``batch_size`` transitions. Otherwise it does one gradient step on the
        policy network and then decays epsilon. So epsilon decays once per
        training step, not once per episode.
        """
        if len(self.memory) < self.batch_size:
            return

        states, actions, rewards, next_states, dones = self._sample_batch()

        current_q = self.policy_net(states).gather(1, actions)
        target_q = self._double_dqn_targets(rewards, next_states, dones)

        loss = self.criterion(current_q, target_q)
        self.optimizer.zero_grad()
        loss.backward()
        self.optimizer.step()

        self._decay_epsilon()

    def _as_tensor(self, values, dtype):
        """Convert array-like ``values`` to a tensor of ``dtype`` on ``self.device``."""
        # Stacking through NumPy first avoids PyTorch's very slow per-element
        # path when given a list/tuple of ndarrays.
        return torch.as_tensor(np.asarray(values, dtype=dtype), device=self.device)

    def _sample_batch(self):
        """Sample ``batch_size`` transitions uniformly and stack them into tensors.

        Returns:
            states (B, state_size) float32, actions (B, 1) int64,
            rewards (B, 1) float32, next_states (B, state_size) float32,
            dones (B, 1) float32 (1.0 for terminal transitions).
        """
        batch = random.sample(self.memory, self.batch_size)
        states, actions, rewards, next_states, dones = zip(*batch)
        return (
            self._as_tensor(states, np.float32),
            self._as_tensor(actions, np.int64).unsqueeze(1),
            self._as_tensor(rewards, np.float32).unsqueeze(1),
            self._as_tensor(next_states, np.float32),
            self._as_tensor(dones, np.float32).unsqueeze(1),
        )

    def _double_dqn_targets(self, rewards, next_states, dones):
        """Return r + gamma * Q_target(s', argmax_a Q_policy(s', a)) * (1 - done)."""
        # Targets are constants for the loss. no_grad skips building an
        # autograd graph that the original discarded via .detach() anyway.
        with torch.no_grad():
            next_actions = torch.argmax(self.policy_net(next_states), dim=1, keepdim=True)
            next_q = self.target_net(next_states).gather(1, next_actions)
            return rewards + (self.gamma * next_q * (1 - dones))

    def _decay_epsilon(self):
        """Multiply epsilon by ``epsilon_decay`` without undershooting ``epsilon_min``."""
        # The guard keeps a deliberately below-floor epsilon (e.g. 0 for
        # greedy evaluation) untouched. The max() stops the last decay step
        # from landing just below the floor.
        if self.epsilon > self.epsilon_min:
            self.epsilon = max(self.epsilon_min, self.epsilon * self.epsilon_decay)

    def save(self, path):
        """Save the policy network's ``state_dict`` to ``path``."""
        torch.save(self.policy_net.state_dict(), path)

    def load(self, path):
        """Load policy weights from ``path`` and sync the target network.

        ``map_location`` remaps tensors onto ``self.device``, so a checkpoint
        saved on GPU can be loaded on a CPU-only machine.
        """
        # SECURITY NOTE: torch.load unpickles the file, so only load trusted
        # checkpoints. On torch >= 1.13, consider passing weights_only=True.
        self.policy_net.load_state_dict(torch.load(path, map_location=self.device))
        self.update_target_network()
|
|
|