I am trying to write a custom model (a TorchModelV2) to integrate with RLlib, but I keep getting a shape-mismatch error.
My environment has a continuous observation space of size 12 and a continuous action space of size 2.
I also want to feed the last action and last reward into the LSTM. Since my task is a navigation task with varying episode lengths, I should also mask or pad my sequences, but for the moment I have implemented neither masking nor padding.
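For reference, the spaces look roughly like this (a minimal sketch; the bounds and dtype are placeholders, only the shapes match my real environment):

import numpy as np
import gymnasium as gym

# Placeholder spaces -- only the shapes (12 and 2) match my real environment.
obs_space = gym.spaces.Box(low=-np.inf, high=np.inf, shape=(12,), dtype=np.float32)
action_space = gym.spaces.Box(low=-1.0, high=1.0, shape=(2,), dtype=np.float32)

# Per step, the input I want to feed through the network into the LSTM is
# 12 (obs) + 2 (prev action) + 1 (prev reward) = 15 features.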
Here is my code snippet:
import torch
import torch.nn as nn

from ray.rllib.models.torch.torch_modelv2 import TorchModelV2
from ray.rllib.policy.view_requirement import ViewRequirement
from ray.rllib.utils.annotations import override


class CustomTorchModel(TorchModelV2, nn.Module):
    def __init__(self, obs_space, action_space, num_outputs, model_config, name):
        super(CustomTorchModel, self).__init__(obs_space, action_space, num_outputs, model_config, name)
        nn.Module.__init__(self)

        self.view_requirements["state_out_0"] = ViewRequirement(
            data_col="state_out_0",
            shift="-1:-1",   # Adjust based on your needs
            space=obs_space  # or appropriate space
        )
        # self.view_requirements["prev_rewards"] = ViewRequirement("rewards", shift=-1, space=obs_space)
        # self.view_requirements["prev_actions"] = ViewRequirement("actions", shift=-1, space=action_space)

        # Debugging initialization
        print(f"Initializing CustomTorchModel with obs_space: {obs_space.shape}, action_space: {action_space.shape}, num_outputs: {num_outputs}")

        self.obs_size = obs_space.shape[0]
        self.action_size = action_space.shape[0]
        input_size = self.obs_size + self.action_size + 1  # Include last action and reward
        self.num_layers = 1

        # Define the fully connected layers
        self.fc1 = nn.Linear(input_size, 128)
        self.fc2 = nn.Linear(128, 64)

        # LSTM configuration
        self.lstm_hidden_size = 32
        # Set the correct input size for LSTM
        self.lstm = nn.LSTM(64, self.lstm_hidden_size, batch_first=True)

        # Define output layers
        self.policy_mean = nn.Linear(self.lstm_hidden_size, self.action_size)
        self.policy_std = nn.Linear(self.lstm_hidden_size, self.action_size)
        self.value_head = nn.Linear(self.lstm_hidden_size, 1)

    @override(TorchModelV2)
    def forward(self, input_dict, state, seq_lens):
        print(f"Forward call with state: {state} and seq_lens: {seq_lens}")
        x = torch.cat(
            [input_dict["obs"], input_dict["prev_actions"], input_dict["prev_rewards"].unsqueeze(-1)],
            dim=1,
        )
        print(f"Concatenated input x shape: {x.shape}")
        x = torch.relu(self.fc1(x))
        x = torch.relu(self.fc2(x))
        print(f"Post-FC layer shape: {x.shape}")
        x = x.unsqueeze(1)  # Add a time dimension of length 1 for the LSTM.
        print(f"Input to LSTM shape: {x.shape}")
        if state is None or len(state) == 0:
            h0 = torch.zeros(self.num_layers, x.size(0), self.lstm_hidden_size, device=x.device)
            c0 = torch.zeros(self.num_layers, x.size(0), self.lstm_hidden_size, device=x.device)
            print("Generated new initial states for LSTM")
        else:
            h0, c0 = state[0], state[1]
        x, (hn, cn) = self.lstm(x, (h0, c0))
        print(f"Output from LSTM x shape: {x.shape}, hn shape: {hn.shape}, cn shape: {cn.shape}")
        x = x.squeeze(1)  # Remove the time dimension again.
        print(f"Squeezed output shape: {x.shape}")
        action_mean = self.policy_mean(x)
        action_std = torch.exp(self.policy_std(x))
        print(f"Action mean shape: {action_mean.shape}, Action std shape: {action_std.shape}")
        self._value_out = self.value_head(x)
        print(f"Value output shape: {self._value_out.shape}")
        return torch.cat([action_mean, action_std], dim=-1), [hn, cn]

    @override(TorchModelV2)
    def value_function(self):
        return self._value_out.squeeze(1)

    # @override(ModelV2)
    def get_initial_state(self, batch_size):
        # Return zeros for both hidden and cell states with correct dimensions
        return (torch.zeros(self.num_layers, batch_size, self.hidden_size),
                torch.zeros(self.num_layers, batch_size, self.hidden_size))
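For completeness, I register and configure the model roughly like this (a simplified sketch; the environment name is a placeholder and my real config has more settings):

from ray.rllib.algorithms.ppo import PPOConfig
from ray.rllib.models import ModelCatalog

# Register the custom model under a name the config can refer to.
ModelCatalog.register_custom_model("custom_torch_model", CustomTorchModel)

config = (
    PPOConfig()
    .environment("MyNavigationEnv-v0")  # placeholder for my custom navigation env
    .framework("torch")
    .training(
        model={
            "custom_model": "custom_torch_model",
            # "max_seq_len": 20,  # would become relevant once I add padding/masking
        }
    )
)

algo = config.build()
algo.train()

The idea is that PPO should pick up CustomTorchModel through the "custom_model" key and call its forward() during rollouts.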
And here is the error I get, along with some of the printed debug output:
(PPO pid=8778) Exception raised in creation task: The actor died because of an error raised in its creation task, ray::PPO.__init__() (pid=8778, ip=127.0.0.1, actor_id=7872f311615162416b32a0ed01000000, repr=PPO)
(PPO pid=8778) File "/Users/federicatonti/miniconda3/envs/tf2/lib/python3.11/site-packages/ray/rllib/algorithms/algorithm.py", line 533, in __init__
(PPO pid=8778) super().__init__(
(PPO pid=8778) File "/Users/federicatonti/miniconda3/envs/tf2/lib/python3.11/site-packages/ray/tune/trainable/trainable.py", line 161, in __init__
(PPO pid=8778) self.setup(copy.deepcopy(self.config))
(PPO pid=8778) ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
(PPO pid=8778) File "/Users/federicatonti/miniconda3/envs/tf2/lib/python3.11/site-packages/ray/rllib/algorithms/algorithm.py", line 631, in setup
(PPO pid=8778) self.workers = WorkerSet(
(PPO pid=8778) ^^^^^^^^^^
(PPO pid=8778) File "/Users/federicatonti/miniconda3/envs/tf2/lib/python3.11/site-packages/ray/rllib/evaluation/worker_set.py", line 159, in __init__
(PPO pid=8778) self._setup(
(PPO pid=8778) File "/Users/federicatonti/miniconda3/envs/tf2/lib/python3.11/site-packages/ray/rllib/evaluation/worker_set.py", line 250, in _setup
(PPO pid=8778) self._local_worker = self._make_worker(
(PPO pid=8778) ^^^^^^^^^^^^^^^^^^
(PPO pid=8778) File "/Users/federicatonti/miniconda3/envs/tf2/lib/python3.11/site-packages/ray/rllib/evaluation/worker_set.py", line 1016, in _make_worker
(PPO pid=8778) worker = cls(
(PPO pid=8778) ^^^^
(PPO pid=8778) File "/Users/federicatonti/miniconda3/envs/tf2/lib/python3.11/site-packages/ray/rllib/evaluation/rollout_worker.py", line 535, in __init__
(PPO pid=8778) self._update_policy_map(policy_dict=self.policy_dict)
(PPO pid=8778) File "/Users/federicatonti/miniconda3/envs/tf2/lib/python3.11/site-packages/ray/rllib/evaluation/rollout_worker.py", line 1743, in _update_policy_map
(PPO pid=8778) self._build_policy_map(
(PPO pid=8778) File "/Users/federicatonti/miniconda3/envs/tf2/lib/python3.11/site-packages/ray/rllib/evaluation/rollout_worker.py", line 1854, in _build_policy_map
(PPO pid=8778) new_policy = create_policy_for_framework(
(PPO pid=8778) ^^^^^^^^^^^^^^^^^^^^^^^^^^^^
(PPO pid=8778) File "/Users/federicatonti/miniconda3/envs/tf2/lib/python3.11/site-packages/ray/rllib/utils/policy.py", line 141, in create_policy_for_framework
(PPO pid=8778) return policy_class(observation_space, action_space, merged_config)
(PPO pid=8778) ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
(PPO pid=8778) File "/Users/federicatonti/miniconda3/envs/tf2/lib/python3.11/site-packages/ray/rllib/algorithms/ppo/ppo_torch_policy.py", line 64, in __init__
(PPO pid=8778) self._initialize_loss_from_dummy_batch()
(PPO pid=8778) File "/Users/federicatonti/miniconda3/envs/tf2/lib/python3.11/site-packages/ray/rllib/policy/policy.py", line 1396, in _initialize_loss_from_dummy_batch
(PPO pid=8778) actions, state_outs, extra_outs = self.compute_actions_from_input_dict(
(PPO pid=8778) ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
(PPO pid=8778) File "/Users/federicatonti/miniconda3/envs/tf2/lib/python3.11/site-packages/ray/rllib/policy/torch_policy_v2.py", line 557, in compute_actions_from_input_dict
(PPO pid=8778) return self._compute_action_helper(
(PPO pid=8778) ^^^^^^^^^^^^^^^^^^^^^^^^^^^^
(PPO pid=8778) File "/Users/federicatonti/miniconda3/envs/tf2/lib/python3.11/site-packages/ray/rllib/utils/threading.py", line 24, in wrapper
(PPO pid=8778) return func(self, *a, **k)
(PPO pid=8778) ^^^^^^^^^^^^^^^^^^^
(PPO pid=8778) File "/Users/federicatonti/miniconda3/envs/tf2/lib/python3.11/site-packages/ray/rllib/policy/torch_policy_v2.py", line 1260, in _compute_action_helper
(PPO pid=8778) dist_inputs, state_out = self.model(input_dict, state_batches, seq_lens)
(PPO pid=8778) ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
(PPO pid=8778) File "/Users/federicatonti/miniconda3/envs/tf2/lib/python3.11/site-packages/ray/rllib/models/modelv2.py", line 255, in __call__
(PPO pid=8778) res = self.forward(restored, state or [], seq_lens)
(PPO pid=8778) ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
(PPO pid=8778) File "/Users/federicatonti/miniconda3/envs/tf2/lib/python3.11/site-packages/ray/rllib/models/torch/recurrent_net.py", line 219, in forward
(PPO pid=8778) wrapped_out, _ = self._wrapped_forward(input_dict, [], None)
(PPO pid=8778) ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
(PPO pid=8778) File "/Users/federicatonti/Desktop/KTH_work/2D_obstacles/PPO/navigation/untitled2.py", line 154, in forward
(PPO pid=8778) x, (hn, cn) = self.lstm(x, (h0, c0))
(PPO pid=8778) ^^^^^^^^^^^^^^^^^^^^^^
(PPO pid=8778) File "/Users/federicatonti/miniconda3/envs/tf2/lib/python3.11/site-packages/torch/nn/modules/module.py", line 1511, in _wrapped_call_impl
(PPO pid=8778) return self._call_impl(*args, **kwargs)
(PPO pid=8778) ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
(PPO pid=8778) File "/Users/federicatonti/miniconda3/envs/tf2/lib/python3.11/site-packages/torch/nn/modules/module.py", line 1520, in _call_impl
(PPO pid=8778) return forward_call(*args, **kwargs)
(PPO pid=8778) ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
(PPO pid=8778) File "/Users/federicatonti/miniconda3/envs/tf2/lib/python3.11/site-packages/torch/nn/modules/rnn.py", line 874, in forward
(PPO pid=8778) self.check_forward_args(input, hx, batch_sizes)
(PPO pid=8778) File "/Users/federicatonti/miniconda3/envs/tf2/lib/python3.11/site-packages/torch/nn/modules/rnn.py", line 789, in check_forward_args
(PPO pid=8778) self.check_input(input, batch_sizes)
(PPO pid=8778) File "/Users/federicatonti/miniconda3/envs/tf2/lib/python3.11/site-packages/torch/nn/modules/rnn.py", line 239, in check_input
(PPO pid=8778) raise RuntimeError(
(PPO pid=8778) RuntimeError: input.size(-1) must be equal to input_size. Expected 15, got 64
(PPO pid=8778) Forward call with state: [] and seq_lens: None
(PPO pid=8778) Concatenated input x shape: torch.Size([32, 15])
(PPO pid=8778) Post-FC layer shape: torch.Size([32, 64])
(PPO pid=8778) Input to LSTM shape: torch.Size([32, 1, 64])
(PPO pid=8778) Generated new initial states for LSTM
Can someone please help me?
I tried to adjust the LSTM dimensions to make the shapes compatible, but unfortunately nothing worked 🙁 For reference, the concatenated input is 12 (obs) + 2 (prev action) + 1 (prev reward) = 15 features, which matches the printed torch.Size([32, 15]), yet the error reports that the LSTM expects an input size of 15 even though I construct it with nn.LSTM(64, ...).