I’m writing an environment for RL agent training.
My `env.step` method takes as its action an array of length 3:
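```python
from typing import Any, Dict, Tuple, Union

import chex
import jax.numpy as jnp
import numpy as np

# EnvState, EnvParams and self.Rx are defined elsewhere in the environment (not shown)

def scan(self, f, init, xs, length=None):
    # Pure-Python loop with the same semantics as jax.lax.scan
    if xs is None:
        xs = [None] * length
    carry = init
    ys = []
    for x in xs:  # iterates over the leading axis of xs
        carry, y = f(carry, x)
        ys.append(y)
    return carry, np.stack(ys)

def step_env(
    self,
    key: chex.PRNGKey,
    state: EnvState,
    action: Union[int, float, chex.Array],
    params: EnvParams,
) -> Tuple[chex.Array, EnvState, jnp.ndarray, jnp.ndarray, Dict[Any, Any]]:
    # Clip the action to the allowed range (the unclipped `action` is what is used below)
    c_action = jnp.clip(action, params.min_action, params.max_action)
    # Run Rx over each of the three action components
    _, m1 = self.scan(self.Rx, 0, action[0])
    _, m2 = self.scan(self.Rx, 0, action[1])
    _, m3 = self.scan(self.Rx, 0, action[2])
    # ... rest of step_env omitted
```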
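For comparison, the `scan` helper above follows the reference semantics of `jax.lax.scan`. The minimal sketch below, with a hypothetical `running_sum` body standing in for `self.Rx`, runs both on a concrete 1-D input. Both iterate over the leading axis of `xs`, so `xs` must have at least one dimension. Also, `np.stack` needs concrete values, so under `jax.vmap` or `jax.jit` it cannot stack traced arrays, whereas `jax.lax.scan` (or `jnp.stack`) can.

```python
import numpy as np
import jax
import jax.numpy as jnp

def py_scan(f, init, xs, length=None):
    # Same loop as the `scan` method above, written as a free function
    if xs is None:
        xs = [None] * length
    carry = init
    ys = []
    for x in xs:
        carry, y = f(carry, x)
        ys.append(y)
    return carry, np.stack(ys)

def running_sum(carry, x):
    # Hypothetical stand-in for self.Rx: accumulate and emit the running total
    carry = carry + x
    return carry, carry

xs = jnp.array([1.0, 2.0, 3.0])          # 1-D input: scanned over its leading axis
init = jnp.zeros((), dtype=jnp.float32)  # scalar carry with a fixed dtype

carry_py, ys_py = py_scan(running_sum, init, xs)
carry_lax, ys_lax = jax.lax.scan(running_sum, init, xs)
print(carry_py, ys_py)
print(carry_lax, ys_lax)
```

Unlike the Python loop, `jax.lax.scan` traces its body once and keeps the result as a JAX array, which is why it is usually preferred inside jitted or vmapped environment code.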
I vectorize `env.step` using `jax.vmap`:

```python
obsv, env_state, reward, done, info = jax.vmap(env.step, in_axes=(0, 0, 0, None))(
    rng_step, env_state, action, env_params
)
```
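To check what `step_env` receives under this `jax.vmap` call, here is a minimal sketch, assuming the batched action has shape `(5, 3)` (5 environments, 3 components each); `per_env` and `iterate_first_component` are hypothetical stand-ins. With `in_axes=0` for the action, each mapped call sees one environment's action of shape `(3,)`, so `action[0]` is 0-d, and a Python `for` loop over it (as in `scan`) raises a `TypeError`.

```python
import jax
import jax.numpy as jnp

# Hypothetical batch: 5 environments, 3 action components each
actions = jnp.arange(15.0).reshape(5, 3)

seen_shapes = []

def per_env(action):
    # Shapes are static, so they can be recorded while vmap traces the function
    seen_shapes.append((action.shape, action[0].shape))
    return action.sum()

def iterate_first_component(action):
    # Mirrors `for x in xs` in the Python scan, with xs = action[0]
    total = 0.0
    for x in action[0]:  # action[0] is 0-d inside vmap
        total = total + x
    return total

out = jax.vmap(per_env)(actions)
print(seen_shapes)  # per-env shapes of action and action[0]
print(out.shape)    # one result per environment

try:
    jax.vmap(iterate_first_component)(actions)
    error = None
except TypeError as e:
    error = e
print(type(error).__name__, error)
```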
When I run `env.step`, I get an error.
How is this possible? If I print the action inside the `scan` function, I get an array of length 5 (I vectorized `env.step` over 5 environments).