Curious Observation in SAC Implementation


While working on my entry for the MineRL 2020 competition, I had to implement the soft actor-critic (SAC) algorithm from scratch. While coding the calculation of the log-probabilities of the actions, I ran into a small obstacle.

Let me provide a little context here: the policy network outputs the mean and standard deviation of a Gaussian distribution, from which we can get our actions via the reparametrization trick:

\[\mathbf{a}_t = \mu_\phi(\mathbf{s}_t) + \sigma_\phi(\mathbf{s}_t) \odot \epsilon, \quad \text{where } \epsilon \sim \mathcal{N}(0, I)\]

But we need our actions to be bounded, so we apply a $ \mathrm{tanh} $ to the sampled actions, $ \tilde{\mathbf{a}}_t = \mathrm{tanh}(\mathbf{a}_t) $. This “squashing” requires an appropriate correction when calculating $ \log \pi (\tilde{\mathbf{a}}_t \mid \mathbf{s}_t) $, which the authors give as:

\[\begin{equation} \log \pi(\tilde{\mathbf{a}}_t \mid \mathbf{s}_t)=\log \mathcal{N}_\phi(\mathbf{a}_t \mid \mathbf{s}_t, \epsilon)-\sum_{i=1}^{D} \log \left(1-\tanh ^{2}\left(\left[ \mathbf{a}_{t} \right]_i\right)\right) \label{eq1} \end{equation}\]

where $ \mathcal{N}_\phi $ denotes the reparametrized Gaussian distribution and $ D $ is the number of action dimensions.
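As a sanity check on the correction term, here is a minimal one-dimensional sketch using only the standard library. It assumes a single action dimension with hypothetical values `mu=0.3` and `sigma=0.5`; the function names are made up for illustration. If dividing the Gaussian density by $ 1 - \tanh^2(a) $ is the right change of variables, the resulting density over the squashed action should integrate to roughly 1 on $ (-1, 1) $.

```python
import math


def normal_pdf(a, mu, sigma):
    """Density of a univariate Gaussian N(mu, sigma^2) evaluated at a."""
    z = (a - mu) / sigma
    return math.exp(-0.5 * z * z) / (sigma * math.sqrt(2.0 * math.pi))


def squashed_pdf(y, mu, sigma):
    """Density of y = tanh(a) with a ~ N(mu, sigma^2), by change of variables."""
    a = math.atanh(y)  # invert the squashing
    # d tanh(a)/da = 1 - tanh(a)^2 = 1 - y^2, so divide by it
    return normal_pdf(a, mu, sigma) / (1.0 - y * y)


def integrate_squashed_pdf(mu, sigma, n_bins=100_000):
    """Midpoint-rule integral of the squashed density over (-1, 1)."""
    width = 2.0 / n_bins
    return sum(
        squashed_pdf(-1.0 + (k + 0.5) * width, mu, sigma) * width
        for k in range(n_bins)
    )


total = integrate_squashed_pdf(mu=0.3, sigma=0.5)
print(total)  # should be very close to 1 if the correction is right
```

Taking the logarithm of `squashed_pdf` gives exactly the one-dimensional case of the formula above: the Gaussian log-density minus $ \log(1 - \tanh^2(a)) $.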

In code, you can achieve this simply by:

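import torch
from torch.distributions import Normal

# mean, std: policy network outputs, shape (batch, n_actions)
dist = Normal(mean, std)
# sample actions from reparametrized distribution
gauss_actions = dist.rsample()
# squashed actions
actions = torch.tanh(gauss_actions)
# log probabilities (first attempt: dist.log_prob is not summed)
log_prob = (dist.log_prob(gauss_actions)
            - torch.sum(torch.log(1 - actions ** 2), dim=1))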

This is what I wrote the first time I coded up the algorithm. However, if you try to run this, you will notice that dist.log_prob() returns one value per action dimension instead of a single value per sample, which is where the problem began for me.

After comparing this with other implementations, I noticed a key operation I had missed. Take, for example, the stable-baselines implementation, which boils down to the following:

log_prob = (dist.log_prob(gauss_actions).sum(dim=1)
            - torch.sum(torch.log(1 - actions ** 2), dim=1))
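Putting the pieces together, a minimal runnable sketch might look like the following. The function name `sample_squashed_action`, the tensor sizes, and the small constant `eps` are assumptions made for illustration, not part of the post or of stable-baselines; `eps` is a common guard that keeps the logarithm finite when an action saturates at ±1.

```python
import torch
from torch.distributions import Normal


def sample_squashed_action(mean, std, eps=1e-6):
    """Sample tanh-squashed actions and their log-probabilities.

    mean, std: tensors of shape (batch, n_actions).
    eps: assumed small constant (not from the post) that keeps
         log(1 - tanh^2) finite for saturated actions.
    """
    dist = Normal(mean, std)
    gauss_actions = dist.rsample()        # (batch, n_actions)
    actions = torch.tanh(gauss_actions)   # (batch, n_actions)
    # sum the per-dimension log-densities over the action dimensions
    gauss_log_prob = dist.log_prob(gauss_actions).sum(dim=1)   # (batch,)
    # tanh correction, also summed over the action dimensions
    correction = torch.log(1 - actions ** 2 + eps).sum(dim=1)  # (batch,)
    return actions, gauss_log_prob - correction


# hypothetical batch of 4 states with 3 action dimensions
mean = torch.zeros(4, 3)
std = torch.ones(4, 3)
actions, log_prob = sample_squashed_action(mean, std)
print(actions.shape, log_prob.shape)
```

Both terms are reduced over `dim=1`, so each sample in the batch ends up with exactly one log-probability.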

This operation is justified (and made clearer) by the sum_independent_dims() function:

import torch as th

def sum_independent_dims(tensor: th.Tensor) -> th.Tensor:
    """
    Continuous actions are usually considered to be independent,
    so we can sum components of the ``log_prob`` or the entropy.

    :param tensor: (th.Tensor) shape: (n_batch, n_actions) or (n_batch,)
    :return: (th.Tensor) shape: (n_batch,)
    """
    if len(tensor.shape) > 1:
        tensor = tensor.sum(dim=1)
    else:
        tensor = tensor.sum()
    return tensor


The first part of the problem comes from the distribution these standard implementations use. Stable-baselines uses torch.distributions.Normal as the underlying distribution. However, if you read the documentation page, you will see that PyTorch also offers a MultivariateNormal distribution. The difference between the two is that the former is a univariate (scalar) Gaussian distribution. This means that if you give it a multi-dimensional mean, $ \mathbf{\mu} \in \mathbb{R}^n, ~n > 1 $, it behaves like $ n $ separate one-dimensional Gaussian distributions. This is why calling log_prob() returns a vector with the implementation above: you are actually working with multiple distributions, and you get one log-density per dimension.
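The shape difference is easy to see in a small sketch. The sizes below are hypothetical, and `Independent` is a PyTorch wrapper that the post does not mention; it is included here only because it reinterprets the last batch dimension of a `Normal` as an event dimension, which performs the same sum over action dimensions internally.

```python
import torch
from torch.distributions import Independent, Normal

# hypothetical batch of 4 samples with 3 action dimensions
mean = torch.zeros(4, 3)
std = torch.ones(4, 3)
x = torch.zeros(4, 3)

# Normal treats every entry as its own 1-D Gaussian
per_dim = Normal(mean, std).log_prob(x)
# Independent(..., 1) treats the last dimension as one joint event
joint = Independent(Normal(mean, std), 1).log_prob(x)

print(per_dim.shape)  # one log-density per action dimension
print(joint.shape)    # one log-density per sample
```

Here `joint` should match `per_dim.sum(dim=1)`, which is the reduction sum_independent_dims() applies by hand.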

Further note that the log-density of a multivariate Gaussian distribution with a diagonal covariance matrix, $ \boldsymbol{\Sigma} = \mathrm{diag}(\sigma_1^2, \dots, \sigma_n^2) $, reduces to a sum of univariate log-densities:

\[\begin{align*} \log \pi( \mathbf{x} ) & = \left(-\frac{1}{2}(\mathbf{x}-\boldsymbol{\mu})^{\mathrm{T}} \boldsymbol{\Sigma}^{-1}(\mathbf{x}-\boldsymbol{\mu})\right) - \frac{n}{2} \log 2 \pi -\frac{1}{2}\log|\boldsymbol{\Sigma}| \\ & = -\frac{1}{2}\sum_{i=1}^n \frac{1}{\sigma_i^2}(x_i - \mu_i)^2 - \frac{n}{2}\log 2\pi - \sum_{i=1}^n \log \sigma_i \\ &= \sum_{i=1}^{n} \log \mathcal{N}(x_i \mid \mu_i, \sigma_i^2)\end{align*}\]
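The identity can also be checked numerically. The sketch below uses randomly generated, hypothetical means, standard deviations, and inputs, builds the diagonal covariance matrices with `torch.diag_embed`, and compares the `MultivariateNormal` log-density against the summed `Normal` log-densities.

```python
import torch
from torch.distributions import MultivariateNormal, Normal

torch.manual_seed(0)
# hypothetical batch of 4 samples with 3 dimensions each
mean = torch.randn(4, 3)
std = torch.rand(4, 3) + 0.5   # keep standard deviations away from zero
x = torch.randn(4, 3)

# sum of univariate log-densities, one value per sample
summed = Normal(mean, std).log_prob(x).sum(dim=1)

# multivariate Gaussian with covariance diag(std^2), shape (4, 3, 3)
cov = torch.diag_embed(std ** 2)
joint = MultivariateNormal(mean, covariance_matrix=cov).log_prob(x)

# the two should agree up to floating-point error
print(torch.allclose(summed, joint, atol=1e-4))
```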

Therefore, we can sum over action dimensions because we have enforced this independence by using a diagonal covariance matrix (whose entries are output by the policy network). The question, then, is: why do the standard implementations not use MultivariateNormal instead of Normal?
