Speeding up the training of a GNN-PPO model using the GPU

I am working on a small project to train a model to navigate a network – moving from an origin node to a destination node.

I first built a GNN-DQN model, which works fairly well on a 100-node network but becomes extremely slow to train with 1,000 nodes.

It wasn't straightforward to see how parallel computing could be used to speed it up, so I moved to a GNN-PPO model instead.
I am using PyTorch with CUDA on an NVIDIA RTX 4070 graphics card, but I am not getting the orders-of-magnitude improvement I was hoping for.

Basically, the processing time for each training episode has remained the same.
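
For reference, I time things roughly like this (just a rough sketch of a timing helper, not part of the script below; the torch.cuda.synchronize() calls make sure queued GPU work has actually finished before the clock is read):

import time
import torch

def timed(fn, *args, **kwargs):
    # Rough wall-clock timing of a function that may launch CUDA kernels.
    # Without synchronize(), only the asynchronous kernel launches get measured.
    if torch.cuda.is_available():
        torch.cuda.synchronize()
    start = time.perf_counter()
    result = fn(*args, **kwargs)
    if torch.cuda.is_available():
        torch.cuda.synchronize()
    return result, time.perf_counter() - start

I wrap a single training run (or a single update) with this helper when comparing runs.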

Could you please give me some guidance on how to improve this?

You can see the full script here.
In short, it generates a 100-node network, creates the machine-learning models, runs the training, and showcases a few examples applying the trained model.


import networkx as nx
import numpy as np
import random
import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
import matplotlib.pyplot as plt
from torch_geometric.nn import GCNConv
from torch_geometric.utils import from_networkx

# Device setup
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Generate a grid-like graph with additional random connections
def generate_road_structure_graph(grid_size=(10, 10), additional_edges_prob=0.002):
    G = nx.grid_2d_graph(*grid_size)
    G = nx.convert_node_labels_to_integers(G)

    nodes = list(G.nodes())
    for u in nodes:
        for v in nodes:
            if u != v and not G.has_edge(u, v) and random.random() < additional_edges_prob:
                G.add_edge(u, v)

    for node in G.nodes:
        G.nodes[node]['pos'] = np.array([node % grid_size[1], node // grid_size[1]]) + 0.1 * np.random.randn(2)

    for (u, v) in G.edges():
        G.edges[u, v]['weight'] = np.linalg.norm(G.nodes[u]['pos'] - G.nodes[v]['pos'])
    
    return G

# Generate the road-like graph
G = generate_road_structure_graph()

# Extract positions for nodes
pos = nx.get_node_attributes(G, 'pos')

# Rescale node positions for better visualization
scaled_pos = {node: pos[node] * 10 for node in G.nodes}

# Draw the graph
plt.figure(figsize=(8, 6))
nx.draw(G, scaled_pos, with_labels=True, node_color='skyblue', edge_color='gray', node_size=500, font_size=10, font_weight='bold')
edges = nx.get_edge_attributes(G, 'weight')
nx.draw_networkx_edge_labels(G, scaled_pos, edge_labels={edge: f"{weight:.2f}" for edge, weight in edges.items()})
plt.title("Road-like Graph Structure")
plt.show()
#%%
# Convert NetworkX graph to PyTorch Geometric data
data = from_networkx(G).to(device)

# Graph Neural Network (GNN)
class GNN(nn.Module):
    def __init__(self, input_dim, hidden_dim, output_dim):
        super(GNN, self).__init__()
        self.conv1 = GCNConv(input_dim, hidden_dim)
        self.conv2 = GCNConv(hidden_dim, hidden_dim)
        self.conv3 = GCNConv(hidden_dim, output_dim)

    def forward(self, x, edge_index):
        x = F.relu(self.conv1(x, edge_index))
        x = F.relu(self.conv2(x, edge_index))
        x = self.conv3(x, edge_index)
        return x

# Actor-Critic Networks for PPO
class ActorCritic(nn.Module):
    def __init__(self, input_dim, output_dim):
        super(ActorCritic, self).__init__()
        self.fc1 = nn.Linear(input_dim, 256)
        self.fc2 = nn.Linear(256, 128)
        self.actor = nn.Linear(128, output_dim)
        self.critic = nn.Linear(128, 1)
        self.initialize_weights()

    def initialize_weights(self):
        nn.init.kaiming_normal_(self.fc1.weight)
        nn.init.kaiming_normal_(self.fc2.weight)
        nn.init.kaiming_normal_(self.actor.weight)
        nn.init.kaiming_normal_(self.critic.weight)

    def forward(self, x):
        x = F.relu(self.fc1(x))
        x = F.relu(self.fc2(x))
        action_probs = F.softmax(self.actor(x), dim=-1)
        value = self.critic(x)
        return action_probs, value

# Memory for PPO
class PPOBuffer:
    def __init__(self, capacity):
        self.states = []
        self.actions = []
        self.rewards = []
        self.next_states = []
        self.log_probs = []
        self.dones = []
        self.capacity = capacity

    def push(self, state, action, reward, next_state, log_prob, done):
        if len(self.states) >= self.capacity:
            self.states.pop(0)
            self.actions.pop(0)
            self.rewards.pop(0)
            self.next_states.pop(0)
            self.log_probs.pop(0)
            self.dones.pop(0)
        self.states.append(state)
        self.actions.append(action)
        self.rewards.append(reward)
        self.next_states.append(next_state)
        self.log_probs.append(log_prob)
        self.dones.append(done)

    def sample(self):
        return self.states, self.actions, self.rewards, self.next_states, self.log_probs, self.dones

    def clear(self):
        self.states = []
        self.actions = []
        self.rewards = []
        self.next_states = []
        self.log_probs = []
        self.dones = []

# Environment to represent the graph and agent navigation
class GraphEnv:
    def __init__(self, G, data):
        self.G = G
        self.data = data  # Data already on GPU
        self.nodes = list(G.nodes)
        self.node_features = torch.eye(len(G.nodes), dtype=torch.float).to(device)
        self.gnn = GNN(input_dim=self.node_features.shape[1], hidden_dim=64, output_dim=32).to(device)
        
        # Initialize node_positions correctly
        self.node_positions = torch.stack(
            [torch.tensor(G.nodes[node]['pos'], dtype=torch.float) for node in self.nodes]
        ).to(device)
        self.max_distance = self.compute_max_distance()

        # Precompute distances, direction vectors, and edge existence
        self.distance_matrix, self.direction_matrix = self.compute_distances_and_directions()
        self.edge_existence_matrix = self.compute_edge_existence()

        self.reset()

    def compute_edge_existence(self):
        num_nodes = len(self.nodes)
        edges = torch.zeros((num_nodes, num_nodes), dtype=torch.bool, device=device)
        for i in range(num_nodes):
            for j in range(num_nodes):
                if self.G.has_edge(i, j):
                    edges[i, j] = True
        return edges

    def compute_max_distance(self):
        distances = torch.cdist(self.node_positions, self.node_positions)  # Compute pairwise distances
        return distances.max().item()

    def compute_distances_and_directions(self):
        num_nodes = len(self.nodes)
        distances = torch.zeros((num_nodes, num_nodes), device=device)
        directions = torch.zeros((num_nodes, num_nodes, 2), device=device)

        for i in range(num_nodes):
            for j in range(num_nodes):
                if i != j:
                    direction = self.node_positions[j] - self.node_positions[i]
                    distance = direction.norm()
                    distances[i, j] = distance
                    if distance > 0:
                        directions[i, j] = direction / distance

        return distances, directions

    def reset(self):
        self.origin = random.choice(self.nodes)
        self.destination = random.choice(self.nodes)
        while self.destination == self.origin:
            self.destination = random.choice(self.nodes)
        self.current_node = self.origin
        self.visited_nodes = set()
        self.destination_pos_tensor = self.node_positions[self.destination]  # Store tensor
        return self._get_state()
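
    # Note: test_ppo() further down calls env.set_start_and_destination(start, destination),
    # which is not defined in the code as posted. This is a minimal assumed stand-in that
    # mirrors reset() with a fixed origin and destination.
    def set_start_and_destination(self, start, destination):
        self.origin = start
        self.destination = destination
        self.current_node = self.origin
        self.visited_nodes = set()
        self.destination_pos_tensor = self.node_positions[self.destination]
        return self._get_state()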

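    # Note: _get_state() runs a full GNN forward pass over every node in the graph
    # each time it is called (once per environment per step).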
    def _get_state(self):
        with torch.no_grad():  # Disable gradient tracking for state computations
            self.data.x = self.node_features
            embeddings = self.gnn(self.data.x, self.data.edge_index)
            current_node_embedding = embeddings[self.current_node].detach()

            relative_position = self.node_positions[self.current_node] - self.destination_pos_tensor
            remaining_distance_vector = self.destination_pos_tensor - self.node_positions[self.current_node]
            remaining_distance_magnitude = self.distance_matrix[self.current_node, self.destination]
            if remaining_distance_magnitude != 0:
                remaining_distance_vector /= remaining_distance_magnitude

            visited_neighbors = torch.zeros(len(self.nodes), dtype=torch.float, device=device)
            for neighbor in self.G.neighbors(self.current_node):
                visited_neighbors[neighbor] = 1 if neighbor in self.visited_nodes else 0

            neighbor_distances = self.distance_matrix[self.current_node].clone()
            neighbor_distances[~visited_neighbors.bool()] = -1.0  # -1 everywhere except the visited neighbours of the current node

            neighbor_direction_vectors = self.direction_matrix[self.current_node]

            state = torch.cat([
                current_node_embedding,
                relative_position,
                remaining_distance_vector,
                visited_neighbors,
                neighbor_distances,
                neighbor_direction_vectors.flatten()
            ]).to(device)

        return state

    def step(self, action):
        next_node = self.nodes[action]
        if not self.edge_existence_matrix[self.current_node, next_node]:
            return self._get_state(), -0.1, False  # Penalty for invalid move
    
        move_distance = self.distance_matrix[self.current_node, next_node]
        distance_before = self.distance_matrix[self.current_node, self.destination]
        distance_after = self.distance_matrix[next_node, self.destination]
    
        reward = 0
        done = False
    
        if next_node in self.visited_nodes:
            reward -= 0.1  # Penalty for revisiting a node
        if next_node == self.destination:
            reward += 1.0  # Reward for reaching the destination
            done = True
        else:
            distance_difference = distance_before - distance_after
            progress_reward = distance_difference / self.max_distance
            reward += progress_reward
            if distance_difference <= 0:
                reward -= abs(distance_difference) / self.max_distance  # extra penalty for moving away from the destination
    
        penalty = 2 * move_distance / self.max_distance
        reward -= penalty
    
        self.visited_nodes.add(next_node)
        self.current_node = next_node
        return self._get_state(), reward, done

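# Each GraphEnv holds its own GNN instance, and step() advances the environments
# one at a time in a Python loop (one env.step() call per sampled action).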
class VectorizedEnvironments:
    def __init__(self, num_envs, G, data):
        self.num_envs = num_envs
        self.envs = [GraphEnv(G, data) for _ in range(num_envs)]

    def reset(self):
        # Ensure all reset states are returned as GPU tensors
        return torch.stack([env.reset() for env in self.envs]).to(device)

    def step(self, actions):
        # Actions should be provided as a tensor on the GPU
        actions = actions.to(device)
        next_states, rewards, dones = [], [], []
        for env, action in zip(self.envs, actions):
            next_state, reward, done = env.step(action.item())
            next_states.append(next_state)
            rewards.append(reward)
            dones.append(done)
        return (torch.stack(next_states).to(device),
                torch.tensor(rewards, device=device),
                torch.tensor(dones, device=device))

def train_ppo_parallel():
    # Hyperparameters
    gamma = 0.99
    clip_epsilon = 0.3
    num_epochs = 5
    max_steps_per_episode = 150
    buffer_capacity = 50000
    num_envs = 50  # Number of parallel environments

    # Calculate the number of episodes per update based on parallel environments
    episodes_per_update = num_envs  # Since num_envs episodes are run in parallel each time
    num_updates = 50000 // episodes_per_update  # Adjust this as needed based on total desired episodes

    # Initialize vectorized environments and PPO agent
    vec_env = VectorizedEnvironments(num_envs, G, data)
    input_dim = 32 + 2 + 2 + len(G.nodes) + len(G.nodes) + 2 * len(G.nodes)
    output_dim = len(G.nodes)
    ppo = ActorCritic(input_dim, output_dim).to(device)
    optimizer = optim.Adam(ppo.parameters(), lr=0.0005)

    # Memory for PPO
    memory = PPOBuffer(buffer_capacity)

    # Training Loop
    total_episodes = 0
    for update in range(num_updates):
        # Get initial states from all environments as GPU tensors
        states = vec_env.reset()  # Already on the correct device
        total_rewards = torch.zeros(num_envs, device=device)
        dones = torch.zeros(num_envs, dtype=torch.bool, device=device)
        successful_episodes = 0  # Track successful episodes
        steps = 0

        while not dones.all() and steps < max_steps_per_episode:
            with torch.no_grad():
                action_probs, _ = ppo(states)
            action_dists = torch.distributions.Categorical(action_probs)
            actions = action_dists.sample()
            log_probs = action_dists.log_prob(actions)

            next_states, rewards, next_dones = vec_env.step(actions)
            rewards = rewards.to(device)
            dones = next_dones.clone().detach().to(dtype=torch.bool, device=device)

            # Move the whole batch to the CPU and convert it to NumPy once per step before storing it in the buffer
            memory.push(states.cpu().numpy(), actions.cpu().numpy(), rewards.cpu().numpy(), next_states.cpu().numpy(), log_probs.cpu().numpy(), dones.cpu().numpy())

            states = next_states
            total_rewards += rewards * (~dones)
            steps += 1

            successful_episodes += next_dones.sum().item()  # Increment successful episode count

        total_episodes += num_envs  # Each update completes `num_envs` episodes

        # Update policy
        states, actions, rewards, next_states, log_probs, dones = memory.sample()
        states = torch.FloatTensor(states).to(device)
        actions = torch.LongTensor(actions).to(device)
        rewards = torch.FloatTensor(rewards).to(device)
        next_states = torch.FloatTensor(next_states).to(device)
        old_log_probs = torch.FloatTensor(log_probs).to(device)
        dones = torch.FloatTensor(dones).to(device)

        # Calculate discounted rewards and advantages
        with torch.no_grad():
            _, next_values = ppo(next_states)
            next_values = next_values.squeeze(-1)
        returns = rewards + gamma * next_values * (1 - dones)

        _, values = ppo(states)
        values = values.squeeze(-1)
        advantages = returns - values.detach()

        total_loss = 0

        for _ in range(num_epochs):
            new_action_probs, new_values = ppo(states)
            new_values = new_values.squeeze(-1)
            new_action_dist = torch.distributions.Categorical(new_action_probs)
            new_log_probs = new_action_dist.log_prob(actions)

            ratios = torch.exp(new_log_probs - old_log_probs)

            surr1 = ratios * advantages
            surr2 = torch.clamp(ratios, 1 - clip_epsilon, 1 + clip_epsilon) * advantages
            actor_loss = -torch.min(surr1, surr2).mean()

            critic_loss = F.mse_loss(new_values, returns)

            loss = actor_loss + 0.5 * critic_loss
            total_loss += loss.item()

            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

        memory.clear()
        avg_reward = total_rewards.mean().item()
        success_fraction = successful_episodes / num_envs
        print(f"Update {update + 1}/{num_updates}, "
              f"Total Episodes: {total_episodes}, "
              f"Average Total Reward: {avg_reward}, "
              f"Total Loss: {total_loss}, "
              f"Success Fraction: {success_fraction:.2f}")

    torch.save(ppo.state_dict(), 'ppo_model.pth')
    print("Model saved as 'ppo_model.pth'.")


train_ppo_parallel()



#%%
def test_ppo(test_examples):
    # Re-initialize the environment and PPO model
    env = GraphEnv(G, data)
    n = len(env.nodes)
    input_dim = 32 + 2 + 2 + n + n + 2 * n
    output_dim = n
    ppo = ActorCritic(input_dim, output_dim).to(device)
    ppo.load_state_dict(torch.load('ppo_model.pth', map_location=device))
    ppo.eval()  # Set the model to evaluation mode

    for start, destination in test_examples:
        print(f"Testing with start: {start}, destination: {destination}")
        state = env.set_start_and_destination(start, destination)
        done = False
        steps = 0
        path = [start]

        while not done and steps < 30:
            state_tensor = state.unsqueeze(0).to(device)  # state is already a float tensor on the device
            with torch.no_grad():
                action_probs, _ = ppo(state_tensor)
            action_dist = torch.distributions.Categorical(action_probs)
            action = action_dist.sample()

            next_state, reward, done = env.step(action.item())
            state = next_state
            steps += 1
            path.append(env.current_node)
            print(f"Step {steps}: Moved to node {env.current_node} with reward {reward}")

        # Visualize the path taken on the graph
        plt.figure(figsize=(12, 10))
        nx.draw(G, scaled_pos, with_labels=True, node_color='skyblue', edge_color='gray', node_size=500, font_size=8, font_weight='bold')
        nx.draw_networkx_nodes(G, scaled_pos, nodelist=path, node_color='green', node_size=300)
        nx.draw_networkx_nodes(G, scaled_pos, nodelist=[start], node_color='blue', node_size=500)
        nx.draw_networkx_nodes(G, scaled_pos, nodelist=[destination], node_color='red', node_size=500)
        nx.draw_networkx_edges(G, scaled_pos, edgelist=[(path[i], path[i+1]) for i in range(len(path)-1)], width=2, edge_color='black')
        nx.draw_networkx_labels(G, scaled_pos, {start: f"Start\n{start}", destination: f"Destination\n{destination}"}, font_color='white')
        
        # Hardcoding the legend
        plt.plot([], [], 'bo', label="Start Node")
        plt.plot([], [], 'ro', label="Destination Node")
        plt.plot([], [], 'go', label="Path Taken")
        plt.plot([], [], 'o', color='skyblue', label="Other Nodes")
        plt.legend(scatterpoints=1)

        plt.title(f"Path from {start} to {destination}")
        plt.show()

# Define test examples (start and destination node pairs)
test_examples = [(0, 99), (10, 50), (25, 75), (30, 45)]

# Test the trained model with the given examples
test_ppo(test_examples)
