I’m working with the Stable Baselines3 library to train a Proximal Policy Optimization (PPO) model for a reinforcement learning project. I want to use a Dirichlet distribution for action selection so that the elements of the action vector sum to 1, rather than applying a softmax to the default agent’s actions. By default, PPO in SB3 uses a Gaussian distribution for continuous action spaces.
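For context, the property I am after comes for free from the Dirichlet distribution itself: a sample is a probability vector, so its components are non-negative and sum to 1 by construction. A quick standalone check with torch.distributions.Dirichlet (illustration only, not wired into SB3):

import torch
from torch.distributions import Dirichlet

concentration = torch.full((3,), 0.1)  # example alpha for a 3-dimensional action
dist = Dirichlet(concentration)
action = dist.sample()
print(action, action.sum())  # components sum to 1 (up to floating point error)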
I tried to achieve this by extending the policy class (ActorCriticPolicy) and overriding methods to include Dirichlet sampling for the action probabilities. However, this approach hasn’t worked as expected. Here’s a snippet of what I tried:
import torch
import torch.nn as nn
import gymnasium as gym
from stable_baselines3 import PPO
from stable_baselines3.common.policies import ActorCriticPolicy
import numpy as np
class DirichletActorCriticPolicy(ActorCriticPolicy):
    def __init__(self, *args, alpha=0.1, **kwargs):
        super().__init__(*args, **kwargs)
        # Concentration scale for the Dirichlet distribution
        self.alpha = alpha

    def action(self, obs):
        # Intended behaviour: turn the extracted features into a concentration
        # vector and sample the action from a Dirichlet distribution
        logits = self.extract_features(obs)
        probs = torch.softmax(logits, dim=-1).detach().cpu().numpy()
        action = np.random.dirichlet(self.alpha * probs.reshape(-1))
        return action, None
env = ToyDirichletEnv()
model = PPO(DirichletActorCriticPolicy, env, verbose=1)
print(model.policy)
model.learn(total_timesteps=50)
For debugging, I implemented a toy environment that prints the sum of each action to check whether it equals 1.
class ToyDirichletEnv(gym.Env):
    def __init__(self):
        super().__init__()
        self.action_space = gym.spaces.Box(low=0, high=1, shape=(3,), dtype=np.float32)
        # high covers the full range of positions (5 to 16)
        self.observation_space = gym.spaces.Box(low=np.array([0, 0, 0, 0]), high=np.array([20] * 4), dtype=np.float32)
        self.position = 5

    def reset(self, seed=None, options=None):
        super().reset(seed=seed)
        self.position = 5
        return np.array([self.position] * 4, dtype=np.float32), {}

    def step(self, action):
        # Debug check: a Dirichlet-distributed action should sum to 1
        print('sum(action)', sum(action))
        reward = sum(el ** 2 for el in action)
        self.position += 1
        terminated = self.position > 15
        return np.array([self.position] * 4, dtype=np.float32), reward, terminated, False, {}

    def render(self, mode='human'):
        print(f"Position: {self.position}")
When training with PPO, however, it seems that the action method is never called, so my Dirichlet sampling never runs.
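My best guess (which may be wrong) is that SB3 never looks for a method named action: as far as I understand, rollout collection goes through the policy’s forward() and model.predict() goes through _predict(). A minimal probe I would use to confirm this, with a hypothetical ProbeActorCriticPolicy class that just logs those calls:

class ProbeActorCriticPolicy(ActorCriticPolicy):
    def forward(self, obs, deterministic=False):
        # Called when PPO collects rollouts during learn()
        print("forward() called")
        return super().forward(obs, deterministic)

    def _predict(self, observation, deterministic=False):
        # Called by model.predict()
        print("_predict() called")
        return super()._predict(observation, deterministic)

probe_model = PPO(ProbeActorCriticPolicy, env, n_steps=16, batch_size=8, verbose=0)
probe_model.learn(total_timesteps=16)

If that is right, I presumably need to hook into one of those entry points (or the policy’s action distribution) rather than adding a new method, but I’m not sure what the cleanest way to do that is within SB3.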