In this PPO tutorial, the split_trajs argument of the SyncDataCollector is False. However, I want to split the collected data into trajectories and learn from them. If I set this argument to True, the collector returns the data split by trajectory, but each trajectory is zero-padded to the length of the longest one. I want to remove this zero padding from the training data.
from torchrl.collectors import SyncDataCollector

collector = SyncDataCollector(
    env,
    policy_module,
    frames_per_batch=frames_per_batch,
    total_frames=total_frames,
    split_trajs=True,  # the tutorial uses False here
    device=device,
)
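For context, with split_trajs=True each batch the collector yields has batch size [n_trajs, max_steps]: every trajectory is zero-padded to the length of the longest one, and a boolean validity mask is stored under the nested key ("collector", "mask"). A quick way to see this (the concrete sizes below are just an example from my runs):

for tensordict_data in collector:
    print(tensordict_data.batch_size)                  # e.g. torch.Size([12, 87])
    print(tensordict_data["collector", "mask"].shape)  # e.g. torch.Size([12, 87])
    print(tensordict_data["collector", "mask"].dtype)  # torch.bool
    break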
# ...

for i, tensordict_data in enumerate(collector):
    for td_id, td_trajectory in enumerate(tensordict_data):
        mask = td_trajectory["collector", "mask"]
        # Here I want to remove the zero padding that each tensor in the
        # trajectory tensordict carries, based on this mask (each tensor
        # has a different shape and number of dimensions).
    for _ in range(num_epochs):
        # We'll need an "advantage" signal to make PPO work.
        # We re-compute it at each epoch as its value depends on the value
        # network which is updated in the inner loop.
        advantage_module(tensordict_data)
        data_view = tensordict_data.reshape(-1)
        replay_buffer.extend(data_view.cpu())
        for _ in range(frames_per_batch // sub_batch_size):
            subdata = replay_buffer.sample(sub_batch_size)
            loss_vals = loss_module(subdata.to(device))
There are masks in tensordict["collector", "mask"], but I don't know how to apply them to the entire tensordict and remove the zero padding in one go. As the loop above stands, tensordict_data.reshape(-1) would flatten the padded steps into the replay buffer as if they were real frames. Each tensor inside the tensordict has a different shape and number of dimensions, so naively calling torch.masked_select on every tensor raises an error, and a hand-written per-tensor implementation feels very cumbersome. Can this be done concisely in PyTorch? Any ideas would be appreciated.
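To make the goal concrete, this is roughly what I am hoping is possible (a sketch only; I am assuming here that a TensorDict can be indexed with a boolean mask over its batch dimensions, which as far as I can tell is what TensorDict.masked_select does):

for i, tensordict_data in enumerate(collector):
    for _ in range(num_epochs):
        # Compute advantages on the padded [n_trajs, max_steps] batch first,
        # since GAE needs the intact time dimension.
        advantage_module(tensordict_data)

        mask = tensordict_data["collector", "mask"]
        # Masking every tensor in the TensorDict along its batch dims would
        # leave a flat batch containing only the valid (non-padded) steps.
        data_view = tensordict_data[mask]  # instead of tensordict_data.reshape(-1)
        replay_buffer.extend(data_view.cpu())

        # After masking, fewer than frames_per_batch frames remain, so the
        # inner loop is scaled by the number of valid frames instead.
        for _ in range(data_view.numel() // sub_batch_size):
            subdata = replay_buffer.sample(sub_batch_size)
            loss_vals = loss_module(subdata.to(device))

Is something along these lines supported, or is there a more idiomatic TorchRL way to drop the padding?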