# File: src/dqn_agent.py
import random
from collections import deque

import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim


# Dueling DQN network architecture for state-action value estimation
class DuelingDQN(nn.Module):
    def __init__(self, state_size, action_size):
        super(DuelingDQN, self).__init__()
        self.fc1 = nn.Linear(state_size, 128)
        self.fc2 = nn.Linear(128, 128)
        # Value stream: estimates the scalar state value V(s)
        self.value_stream = nn.Sequential(
            nn.Linear(128, 64),
            nn.ReLU(),
            nn.Linear(64, 1),
        )
        # Advantage stream: estimates the per-action advantage A(s, a)
        self.advantage_stream = nn.Sequential(
            nn.Linear(128, 64),
            nn.ReLU(),
            nn.Linear(64, action_size),
        )

    def forward(self, x):
        x = torch.relu(self.fc1(x))
        x = torch.relu(self.fc2(x))
        value = self.value_stream(x)
        advantage = self.advantage_stream(x)
        # Combine streams: Q(s, a) = V(s) + A(s, a) - mean_a A(s, a)
        q_values = value + (advantage - advantage.mean(dim=1, keepdim=True))
        return q_values


class AdvancedDQNAgent:
    def __init__(self, state_size, action_size, device="cpu"):
        self.state_size = state_size
        self.action_size = action_size
        self.device = device
        self.memory = deque(maxlen=10000)
        self.gamma = 0.99           # discount factor
        self.epsilon = 1.0          # exploration rate
        self.epsilon_min = 0.01
        self.epsilon_decay = 0.995
        self.learning_rate = 0.001
        self.batch_size = 64

        self.policy_net = DuelingDQN(state_size, action_size).to(device)
        self.target_net = DuelingDQN(state_size, action_size).to(device)
        self.update_target_network()

        self.optimizer = optim.Adam(self.policy_net.parameters(), lr=self.learning_rate)
        self.criterion = nn.MSELoss()

    def update_target_network(self):
        self.target_net.load_state_dict(self.policy_net.state_dict())

    def act(self, state):
        # Epsilon-greedy action selection
        if np.random.rand() <= self.epsilon:
            return random.randrange(self.action_size)
        state_tensor = torch.FloatTensor(state).unsqueeze(0).to(self.device)
        with torch.no_grad():
            q_values = self.policy_net(state_tensor)
        return int(torch.argmax(q_values).item())

    def remember(self, state, action, reward, next_state, done):
        self.memory.append((state, action, reward, next_state, done))

    def replay(self):
        if len(self.memory) < self.batch_size:
            return
        batch = random.sample(self.memory, self.batch_size)
        states, actions, rewards, next_states, dones = zip(*batch)

        # Stack into numpy arrays first to avoid the slow tensor-from-list-of-arrays path
        states = torch.FloatTensor(np.array(states)).to(self.device)
        actions = torch.LongTensor(actions).unsqueeze(1).to(self.device)
        rewards = torch.FloatTensor(rewards).unsqueeze(1).to(self.device)
        next_states = torch.FloatTensor(np.array(next_states)).to(self.device)
        dones = torch.FloatTensor(dones).unsqueeze(1).to(self.device)

        # Compute current Q-values for the actions actually taken
        current_q = self.policy_net(states).gather(1, actions)

        # Double DQN: select the next action with the policy net, evaluate it with the target net
        with torch.no_grad():
            next_actions = torch.argmax(self.policy_net(next_states), dim=1, keepdim=True)
            next_q = self.target_net(next_states).gather(1, next_actions)
            target_q = rewards + self.gamma * next_q * (1 - dones)

        loss = self.criterion(current_q, target_q)
        self.optimizer.zero_grad()
        loss.backward()
        self.optimizer.step()

        # Decay exploration rate after each learning step
        if self.epsilon > self.epsilon_min:
            self.epsilon *= self.epsilon_decay

    def save(self, path):
        torch.save(self.policy_net.state_dict(), path)

    def load(self, path):
        # map_location keeps loading portable across CPU/GPU checkpoints
        self.policy_net.load_state_dict(torch.load(path, map_location=self.device))
        self.update_target_network()
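

# --- Usage sketch (illustrative, not part of the original module) ---
# A minimal training loop showing how the agent's act/remember/replay cycle fits
# together, assuming a gymnasium environment such as CartPole-v1. The environment
# name, episode budget, and target-sync interval below are placeholder choices,
# not values taken from this file.
if __name__ == "__main__":
    import gymnasium as gym

    env = gym.make("CartPole-v1")
    state_size = env.observation_space.shape[0]
    action_size = env.action_space.n
    agent = AdvancedDQNAgent(state_size, action_size)

    num_episodes = 200          # placeholder episode budget
    target_update_every = 10    # placeholder target-network sync interval (episodes)

    for episode in range(num_episodes):
        state, _ = env.reset()
        done = False
        total_reward = 0.0
        while not done:
            action = agent.act(state)
            next_state, reward, terminated, truncated, _ = env.step(action)
            done = terminated or truncated
            agent.remember(state, action, reward, next_state, done)
            agent.replay()
            state = next_state
            total_reward += reward
        if episode % target_update_every == 0:
            agent.update_target_network()
        print(f"episode {episode}: reward={total_reward:.1f}, epsilon={agent.epsilon:.3f}")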