import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim
import random
from collections import deque

# Neural Network for Deep Q-Learning
class QNetwork(nn.Module):
    def __init__(self, state_size, action_size, hidden_sizes=[256, 128, 64], dropout_rate=0.1):
        super(QNetwork, self).__init__()
        self.state_size = state_size
        self.action_size = action_size
        # Build a deeper network with configurable hidden layers
        layers = []
        prev_size = state_size
        for hidden_size in hidden_sizes:
            layers.append(nn.Linear(prev_size, hidden_size))
            layers.append(nn.BatchNorm1d(hidden_size))  # Add batch normalization
            layers.append(nn.ReLU())
            layers.append(nn.Dropout(dropout_rate))  # Add dropout for regularization
            prev_size = hidden_size
        layers.append(nn.Linear(prev_size, action_size))
        self.network = nn.Sequential(*layers)

    def forward(self, x):
        return self.network(x)

# Prioritized Experience Replay Memory (simplified, proportional variant)
class PriorityReplayMemory:
    def __init__(self, capacity, alpha=0.6, beta=0.4, beta_increment=0.001):
        self.capacity = capacity
        self.alpha = alpha  # Priority exponent
        self.beta = beta  # Importance sampling exponent
        self.beta_increment = beta_increment
        self.priorities = deque(maxlen=capacity)  # Priorities, index-aligned with experiences
        self.experiences = deque(maxlen=capacity)  # Store experiences
        self.max_priority = 1.0

    def add(self, experience, error=None):
        priority = error if error is not None else self.max_priority
        priority = (abs(priority) + 1e-5) ** self.alpha  # Small constant to avoid zero priority
        self.experiences.append(experience)
        self.priorities.append(priority)

    def sample(self, batch_size):
        if len(self.experiences) < batch_size:
            return None, None, None
        # Sampling probabilities proportional to stored priorities
        priorities = np.array(self.priorities, dtype=np.float64)
        probs = priorities / priorities.sum()
        # Sample indices
        indices = np.random.choice(len(self.experiences), batch_size, p=probs, replace=False)
        samples = [self.experiences[idx] for idx in indices]
        # Importance sampling weights
        weights = (len(self.experiences) * probs[indices]) ** (-self.beta)
        weights /= weights.max()  # Normalize so the largest weight is 1
        self.beta = min(1.0, self.beta + self.beta_increment)  # Anneal beta towards 1
        return samples, indices, torch.FloatTensor(weights)

    def update_priorities(self, indices, errors):
        for idx, error in zip(indices, errors):
            priority = (abs(error) + 1e-5) ** self.alpha
            self.priorities[idx] = priority
            self.max_priority = max(self.max_priority, priority)

    def __len__(self):
        return len(self.experiences)

# Enhanced Reinforcement Learning Agent
class RLAgent:
    def __init__(self, state_size, action_size,
                 lr=0.0005, gamma=0.99, epsilon=1.0, epsilon_decay=0.995, min_epsilon=0.01,
                 memory_capacity=10000, batch_size=64, target_update_freq=1000,
                 use_double_dqn=True, clip_grad_norm=1.0):
        self.state_size = state_size
        self.action_size = action_size
        self.gamma = gamma
        self.epsilon = epsilon
        self.epsilon_decay = epsilon_decay
        self.min_epsilon = min_epsilon
        self.batch_size = batch_size
        self.use_double_dqn = use_double_dqn
        self.clip_grad_norm = clip_grad_norm
        # Networks
        self.policy_net = QNetwork(state_size, action_size)
        self.target_net = QNetwork(state_size, action_size)
        self.target_net.load_state_dict(self.policy_net.state_dict())  # Copy weights
        self.target_net.eval()  # Target network doesn't train
        self.optimizer = optim.Adam(self.policy_net.parameters(), lr=lr)
        self.criterion = nn.SmoothL1Loss(reduction='none')  # Per-sample Huber loss for stability
        # Memory
        self.memory = PriorityReplayMemory(memory_capacity)
        self.steps = 0
        self.target_update_freq = target_update_freq

    def remember(self, state, action, reward, next_state, done):
        # Initial priority from the current TD error estimate
        state_tensor = torch.FloatTensor(state).unsqueeze(0)
        next_state_tensor = torch.FloatTensor(next_state).unsqueeze(0)
        self.policy_net.eval()  # BatchNorm/Dropout need eval mode for a single sample
        with torch.no_grad():
            q_value = self.policy_net(state_tensor)[0, action]
            next_q = self.target_net(next_state_tensor).max().item()
        self.policy_net.train()
        target = reward + (1 - done) * self.gamma * next_q
        error = abs(q_value.item() - target)
        self.memory.add((state, action, reward, next_state, done), error)

    def act(self, state):
        self.steps += 1
        if random.random() < self.epsilon:
            return random.randint(0, self.action_size - 1)
        state = torch.FloatTensor(state).unsqueeze(0)
        self.policy_net.eval()  # Single-sample inference: avoid BatchNorm batch statistics
        with torch.no_grad():
            action = torch.argmax(self.policy_net(state)).item()
        self.policy_net.train()
        return action

    def train(self):
        if len(self.memory) < self.batch_size:
            return
        # Sample a prioritized batch from memory
        batch, indices, weights = self.memory.sample(self.batch_size)
        if batch is None:
            return
        states, actions, rewards, next_states, dones = zip(*batch)
        states = torch.FloatTensor(np.array(states))
        actions = torch.LongTensor(actions)
        rewards = torch.FloatTensor(rewards)
        next_states = torch.FloatTensor(np.array(next_states))
        dones = torch.FloatTensor(dones)
        # Compute Q-values of the taken actions
        q_values = self.policy_net(states).gather(1, actions.unsqueeze(1)).squeeze(1)
        # Double DQN or standard DQN targets (no gradient through the target path)
        with torch.no_grad():
            if self.use_double_dqn:
                next_actions = self.policy_net(next_states).argmax(1)
                next_q_values = self.target_net(next_states).gather(1, next_actions.unsqueeze(1)).squeeze(1)
            else:
                next_q_values = self.target_net(next_states).max(1)[0]
            # Compute targets
            targets = rewards + self.gamma * next_q_values * (1 - dones)
        # Compute TD errors for priority update
        td_errors = (q_values - targets).detach().cpu().numpy()
        # Per-sample loss scaled by importance sampling weights
        loss = (self.criterion(q_values, targets) * weights).mean()
        # Optimize
        self.optimizer.zero_grad()
        loss.backward()
        torch.nn.utils.clip_grad_norm_(self.policy_net.parameters(), self.clip_grad_norm)
        self.optimizer.step()
        # Update priorities
        self.memory.update_priorities(indices, td_errors)
        # Update target network
        if self.steps % self.target_update_freq == 0:
            self.target_net.load_state_dict(self.policy_net.state_dict())
        # Decay epsilon
        self.epsilon = max(self.min_epsilon, self.epsilon * self.epsilon_decay)
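

# Minimal usage sketch (illustrative, not part of the original agent code): it assumes a
# Gym-compatible environment; "CartPole-v1", the gymnasium dependency, and the episode
# count below are placeholder choices, not requirements of RLAgent.
if __name__ == "__main__":
    import gymnasium as gym  # assumed dependency; any Gym-style reset()/step() API works

    env = gym.make("CartPole-v1")
    state_size = env.observation_space.shape[0]
    action_size = env.action_space.n
    agent = RLAgent(state_size, action_size)

    for episode in range(200):
        state, _ = env.reset()
        episode_reward = 0.0
        done = False
        while not done:
            action = agent.act(state)
            next_state, reward, terminated, truncated, _ = env.step(action)
            done = terminated or truncated
            agent.remember(state, action, reward, next_state, done)
            agent.train()  # One gradient step per environment step
            state = next_state
            episode_reward += reward
        print(f"Episode {episode}: reward={episode_reward:.1f}, epsilon={agent.epsilon:.3f}")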