I am working on a small project to train a model to navigate a network, moving from an origin node to a destination node.
I first built a GNN-DQN model, which works fairly well on a 100-node network but becomes extremely slow to train on 1000 nodes.
It wasn't straightforward to see how parallel computing could be used to speed it up, so I switched to a GNN-PPO model instead.
I am using PyTorch and CUDA with my NVIDIA RTX 4070 graphics card, but I am not getting the order-of-magnitude improvement I was hoping for.
Basically, the processing time per training episode has stayed the same.
Could you please give me some guidance on how to improve this?
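(A minimal sketch of how per-episode wall-clock time can be measured when CUDA is involved; torch.cuda.synchronize() is called before reading the clock so that queued GPU kernels are included in the measurement. The helper name time_block is purely illustrative and is not part of the script.)

import time
import torch

def time_block(fn, *args, **kwargs):
    # Run fn once and return (result, elapsed seconds), including any GPU work it launches.
    if torch.cuda.is_available():
        torch.cuda.synchronize()  # don't count previously queued kernels
    start = time.perf_counter()
    result = fn(*args, **kwargs)
    if torch.cuda.is_available():
        torch.cuda.synchronize()  # wait for kernels launched by fn to finish
    return result, time.perf_counter() - start

Wrapping a single PPO update (or a single vec_env.step call) in a helper like this makes it easier to compare the 100-node and 1000-node settings like for like.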
You can see the full script below.
In short, it generates a 100-node network, creates the machine-learning models, runs the training, and then demonstrates a few examples of applying the trained model.
import networkx as nx
import numpy as np
import random
import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
import matplotlib.pyplot as plt
from torch_geometric.nn import GCNConv
from torch_geometric.utils import from_networkx
# Device setup
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
# Generate a grid-like graph with additional random connections
def generate_road_structure_graph(grid_size=(10, 10), additional_edges_prob=0.002):
G = nx.grid_2d_graph(*grid_size)
G = nx.convert_node_labels_to_integers(G)
nodes = list(G.nodes())
for u in nodes:
for v in nodes:
if u != v and not G.has_edge(u, v) and random.random() < additional_edges_prob:
G.add_edge(u, v)
for node in G.nodes:
G.nodes[node]['pos'] = np.array([node % grid_size[1], node // grid_size[1]]) + 0.1 * np.random.randn(2)
for (u, v) in G.edges():
G.edges[u, v]['weight'] = np.linalg.norm(G.nodes[u]['pos'] - G.nodes[v]['pos'])
return G
# Generate the road-like graph
G = generate_road_structure_graph()
# Extract positions for nodes
pos = nx.get_node_attributes(G, 'pos')
# Rescale node positions for better visualization
scaled_pos = {node: pos[node] * 10 for node in G.nodes}
# Draw the graph
plt.figure(figsize=(8, 6))
nx.draw(G, scaled_pos, with_labels=True, node_color='skyblue', edge_color='gray', node_size=500, font_size=10, font_weight='bold')
edges = nx.get_edge_attributes(G, 'weight')
nx.draw_networkx_edge_labels(G, scaled_pos, edge_labels={edge: f"{weight:.2f}" for edge, weight in edges.items()})
plt.title("Road-like Graph Structure")
plt.show()
#%%
# Convert NetworkX graph to PyTorch Geometric data
data = from_networkx(G).to(device)
# Graph Neural Network (GNN)
class GNN(nn.Module):
def __init__(self, input_dim, hidden_dim, output_dim):
super(GNN, self).__init__()
self.conv1 = GCNConv(input_dim, hidden_dim)
self.conv2 = GCNConv(hidden_dim, hidden_dim)
self.conv3 = GCNConv(hidden_dim, output_dim)
def forward(self, x, edge_index):
x = F.relu(self.conv1(x, edge_index))
x = F.relu(self.conv2(x, edge_index))
x = self.conv3(x, edge_index)
return x
# Actor-Critic Networks for PPO
class ActorCritic(nn.Module):
def __init__(self, input_dim, output_dim):
super(ActorCritic, self).__init__()
self.fc1 = nn.Linear(input_dim, 256)
self.fc2 = nn.Linear(256, 128)
self.actor = nn.Linear(128, output_dim)
self.critic = nn.Linear(128, 1)
self.initialize_weights()
def initialize_weights(self):
nn.init.kaiming_normal_(self.fc1.weight)
nn.init.kaiming_normal_(self.fc2.weight)
nn.init.kaiming_normal_(self.actor.weight)
nn.init.kaiming_normal_(self.critic.weight)
def forward(self, x):
x = F.relu(self.fc1(x))
x = F.relu(self.fc2(x))
action_probs = F.softmax(self.actor(x), dim=-1)
value = self.critic(x)
return action_probs, value
# Memory for PPO
class PPOBuffer:
def __init__(self, capacity):
self.states = []
self.actions = []
self.rewards = []
self.next_states = []
self.log_probs = []
self.dones = []
self.capacity = capacity
def push(self, state, action, reward, next_state, log_prob, done):
if len(self.states) >= self.capacity:
self.states.pop(0)
self.actions.pop(0)
self.rewards.pop(0)
self.next_states.pop(0)
self.log_probs.pop(0)
self.dones.pop(0)
self.states.append(state)
self.actions.append(action)
self.rewards.append(reward)
self.next_states.append(next_state)
self.log_probs.append(log_prob)
self.dones.append(done)
def sample(self):
return self.states, self.actions, self.rewards, self.next_states, self.log_probs, self.dones
def clear(self):
self.states = []
self.actions = []
self.rewards = []
self.next_states = []
self.log_probs = []
self.dones = []
# Environment to represent the graph and agent navigation
class GraphEnv:
def __init__(self, G, data):
self.G = G
self.data = data # Data already on GPU
self.nodes = list(G.nodes)
self.node_features = torch.eye(len(G.nodes), dtype=torch.float).to(device)
self.gnn = GNN(input_dim=self.node_features.shape[1], hidden_dim=64, output_dim=32).to(device)
# Initialize node_positions correctly
self.node_positions = torch.stack(
[torch.tensor(G.nodes[node]['pos'], dtype=torch.float) for node in self.nodes]
).to(device)
self.max_distance = self.compute_max_distance()
# Precompute distances, direction vectors, and edge existence
self.distance_matrix, self.direction_matrix = self.compute_distances_and_directions()
self.edge_existence_matrix = self.compute_edge_existence()
self.reset()
def compute_edge_existence(self):
num_nodes = len(self.nodes)
edges = torch.zeros((num_nodes, num_nodes), dtype=torch.bool, device=device)
for i in range(num_nodes):
for j in range(num_nodes):
if self.G.has_edge(i, j):
edges[i, j] = True
return edges
def compute_max_distance(self):
distances = torch.cdist(self.node_positions, self.node_positions) # Compute pairwise distances
return distances.max().item()
def compute_distances_and_directions(self):
num_nodes = len(self.nodes)
distances = torch.zeros((num_nodes, num_nodes), device=device)
directions = torch.zeros((num_nodes, num_nodes, 2), device=device)
for i in range(num_nodes):
for j in range(num_nodes):
if i != j:
direction = self.node_positions[j] - self.node_positions[i]
distance = direction.norm()
distances[i, j] = distance
if distance > 0:
directions[i, j] = direction / distance
return distances, directions
def reset(self):
self.origin = random.choice(self.nodes)
self.destination = random.choice(self.nodes)
while self.destination == self.origin:
self.destination = random.choice(self.nodes)
self.current_node = self.origin
self.visited_nodes = set()
self.destination_pos_tensor = self.node_positions[self.destination] # Store tensor
return self._get_state()
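    # Minimal helper assumed by test_ppo() further below: like reset(), but with a
    # fixed origin and destination instead of randomly chosen ones.
    def set_start_and_destination(self, origin, destination):
        self.origin = origin
        self.destination = destination
        self.current_node = self.origin
        self.visited_nodes = set()
        self.destination_pos_tensor = self.node_positions[self.destination]
        return self._get_state()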
def _get_state(self):
with torch.no_grad(): # Disable gradient tracking for state computations
self.data.x = self.node_features
embeddings = self.gnn(self.data.x, self.data.edge_index)
current_node_embedding = embeddings[self.current_node].detach()
relative_position = self.node_positions[self.current_node] - self.destination_pos_tensor
remaining_distance_vector = self.destination_pos_tensor - self.node_positions[self.current_node]
remaining_distance_magnitude = self.distance_matrix[self.current_node, self.destination]
if remaining_distance_magnitude != 0:
remaining_distance_vector /= remaining_distance_magnitude
visited_neighbors = torch.zeros(len(self.nodes), dtype=torch.float, device=device)
for neighbor in self.G.neighbors(self.current_node):
visited_neighbors[neighbor] = 1 if neighbor in self.visited_nodes else 0
neighbor_distances = self.distance_matrix[self.current_node].clone()
            neighbor_distances[~visited_neighbors.bool()] = -1.0  # Mask out every node that is not a visited neighbor
neighbor_direction_vectors = self.direction_matrix[self.current_node]
state = torch.cat([
current_node_embedding,
relative_position,
remaining_distance_vector,
visited_neighbors,
neighbor_distances,
neighbor_direction_vectors.flatten()
]).to(device)
return state
def step(self, action):
next_node = self.nodes[action]
if not self.edge_existence_matrix[self.current_node, next_node]:
return self._get_state(), -0.1, False # Penalty for invalid move
move_distance = self.distance_matrix[self.current_node, next_node]
distance_before = self.distance_matrix[self.current_node, self.destination]
distance_after = self.distance_matrix[next_node, self.destination]
reward = 0
done = False
if next_node in self.visited_nodes:
reward -= 0.1 # Penalty for revisiting a node
if next_node == self.destination:
reward += 1.0 # Reward for reaching the destination
done = True
else:
distance_difference = distance_before - distance_after
progress_reward = distance_difference / self.max_distance
reward += progress_reward
if distance_difference <= 0:
reward -= -distance_difference / self.max_distance
penalty = 2 * move_distance / self.max_distance
reward -= penalty
self.visited_nodes.add(next_node)
self.current_node = next_node
return self._get_state(), reward, done
class VectorizedEnvironments:
def __init__(self, num_envs, G, data):
self.num_envs = num_envs
self.envs = [GraphEnv(G, data) for _ in range(num_envs)]
def reset(self):
# Ensure all reset states are returned as GPU tensors
return torch.stack([env.reset() for env in self.envs]).to(device)
def step(self, actions):
# Actions should be provided as a tensor on the GPU
actions = actions.to(device)
next_states, rewards, dones = [], [], []
for env, action in zip(self.envs, actions):
next_state, reward, done = env.step(action.item())
next_states.append(next_state)
rewards.append(reward)
dones.append(done)
return (torch.stack(next_states).to(device),
torch.tensor(rewards, device=device),
torch.tensor(dones, device=device))
def train_ppo_parallel():
# Hyperparameters
gamma = 0.99
clip_epsilon = 0.3
num_epochs = 5
max_steps_per_episode = 150
buffer_capacity = 50000
num_envs = 50 # Number of parallel environments
# Calculate the number of episodes per update based on parallel environments
episodes_per_update = num_envs # Since num_envs episodes are run in parallel each time
num_updates = 50000 // episodes_per_update # Adjust this as needed based on total desired episodes
# Initialize vectorized environments and PPO agent
vec_env = VectorizedEnvironments(num_envs, G, data)
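    # State layout from GraphEnv._get_state(): 32 GNN embedding + 2 relative position
    # + 2 unit vector towards the destination + N visited-neighbor flags
    # + N neighbor distances + 2N flattened direction vectors (N = number of nodes)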
input_dim = 32 + 2 + 2 + len(G.nodes) + len(G.nodes) + 2 * len(G.nodes)
output_dim = len(G.nodes)
ppo = ActorCritic(input_dim, output_dim).to(device)
optimizer = optim.Adam(ppo.parameters(), lr=0.0005)
# Memory for PPO
memory = PPOBuffer(buffer_capacity)
# Training Loop
total_episodes = 0
for update in range(num_updates):
# Get initial states from all environments as GPU tensors
states = vec_env.reset() # Already on the correct device
total_rewards = torch.zeros(num_envs, device=device)
dones = torch.zeros(num_envs, dtype=torch.bool, device=device)
successful_episodes = 0 # Track successful episodes
steps = 0
while not dones.all() and steps < max_steps_per_episode:
with torch.no_grad():
action_probs, _ = ppo(states)
action_dists = torch.distributions.Categorical(action_probs)
actions = action_dists.sample()
log_probs = action_dists.log_prob(actions)
next_states, rewards, next_dones = vec_env.step(actions)
rewards = rewards.to(device)
dones = next_dones.clone().detach().to(dtype=torch.bool, device=device)
# Convert to numpy only once and as a batch
memory.push(states.cpu().numpy(), actions.cpu().numpy(), rewards.cpu().numpy(), next_states.cpu().numpy(), log_probs.cpu().numpy(), dones.cpu().numpy())
states = next_states
total_rewards += rewards * (~dones)
steps += 1
successful_episodes += next_dones.sum().item() # Increment successful episode count
total_episodes += num_envs # Each update completes `num_envs` episodes
# Update policy
states, actions, rewards, next_states, log_probs, dones = memory.sample()
states = torch.FloatTensor(states).to(device)
actions = torch.LongTensor(actions).to(device)
rewards = torch.FloatTensor(rewards).to(device)
next_states = torch.FloatTensor(next_states).to(device)
old_log_probs = torch.FloatTensor(log_probs).to(device)
dones = torch.FloatTensor(dones).to(device)
# Calculate discounted rewards and advantages
with torch.no_grad():
_, next_values = ppo(next_states)
next_values = next_values.squeeze(-1)
returns = rewards + gamma * next_values * (1 - dones)
_, values = ppo(states)
values = values.squeeze(-1)
advantages = returns - values.detach()
total_loss = 0
for _ in range(num_epochs):
new_action_probs, new_values = ppo(states)
new_values = new_values.squeeze(-1)
new_action_dist = torch.distributions.Categorical(new_action_probs)
new_log_probs = new_action_dist.log_prob(actions)
ratios = torch.exp(new_log_probs - old_log_probs)
surr1 = ratios * advantages
surr2 = torch.clamp(ratios, 1 - clip_epsilon, 1 + clip_epsilon) * advantages
actor_loss = -torch.min(surr1, surr2).mean()
critic_loss = F.mse_loss(new_values, returns)
loss = actor_loss + 0.5 * critic_loss
total_loss += loss.item()
optimizer.zero_grad()
loss.backward()
optimizer.step()
memory.clear()
avg_reward = total_rewards.mean().item()
success_fraction = successful_episodes / num_envs
print(f"Update {update + 1}/{num_updates}, "
f"Total Episodes: {total_episodes}, "
f"Average Total Reward: {avg_reward}, "
f"Total Loss: {total_loss}, "
f"Success Fraction: {success_fraction:.2f}")
torch.save(ppo.state_dict(), 'ppo_model.pth')
print("Model saved as 'ppo_model.pth'.")
train_ppo_parallel()
#%%
def test_ppo(test_examples):
# Re-initialize the environment and PPO model
env = GraphEnv(G, data)
n = len(env.nodes)
input_dim = 32 + 2 + 2 + n + n + 2 * n
output_dim = n
    ppo = ActorCritic(input_dim, output_dim).to(device)
    ppo.load_state_dict(torch.load('ppo_model.pth', map_location=device))
ppo.eval() # Set the model to evaluation mode
for start, destination in test_examples:
print(f"Testing with start: {start}, destination: {destination}")
state = env.set_start_and_destination(start, destination)
done = False
steps = 0
path = [start]
while not done and steps < 30:
            state_tensor = state.unsqueeze(0).to(device)  # state is already a float tensor on the device
with torch.no_grad():
action_probs, _ = ppo(state_tensor)
action_dist = torch.distributions.Categorical(action_probs)
action = action_dist.sample()
next_state, reward, done = env.step(action.item())
state = next_state
steps += 1
path.append(env.current_node)
print(f"Step {steps}: Moved to node {env.current_node} with reward {reward}")
# Visualize the path taken on the graph
plt.figure(figsize=(12, 10))
nx.draw(G, scaled_pos, with_labels=True, node_color='skyblue', edge_color='gray', node_size=500, font_size=8, font_weight='bold')
nx.draw_networkx_nodes(G, scaled_pos, nodelist=path, node_color='green', node_size=300)
nx.draw_networkx_nodes(G, scaled_pos, nodelist=[start], node_color='blue', node_size=500)
nx.draw_networkx_nodes(G, scaled_pos, nodelist=[destination], node_color='red', node_size=500)
nx.draw_networkx_edges(G, scaled_pos, edgelist=[(path[i], path[i+1]) for i in range(len(path)-1)], width=2, edge_color='black')
        nx.draw_networkx_labels(G, scaled_pos, {start: f"Start\n{start}", destination: f"Destination\n{destination}"}, font_color='white')
# Hardcoding the legend
plt.plot([], [], 'bo', label="Start Node")
plt.plot([], [], 'ro', label="Destination Node")
plt.plot([], [], 'go', label="Path Taken")
plt.plot([], [], 'o', color='skyblue', label="Other Nodes")
plt.legend(scatterpoints=1)
plt.title(f"Path from {start} to {destination}")
plt.show()
# Define test examples (start and destination node pairs)
test_examples = [(0, 99), (10, 50), (25, 75), (30, 45)]
# Test the trained model with the given examples
test_ppo(test_examples)