MPO Extension -- A more intuitive interpretation of MPO
Introduction
As stated in the previous post, MPO is motivated from the perspective of "RL as inference". A follow-up work, which can be seen as an extension of MPO, provides an alternative and more intuitive interpretation of the algorithm in terms of policy evaluation and policy improvement.
Generally speaking, many off-policy algorithms are implemented by alternating between two steps:
- Policy Evaluation: for the current policy, learn the action-value function (Q-function).
- Policy Improvement: given the current action-value function, improve the policy.
By standard definition, the objective function is
\begin{array}{cc} \mathcal{J}(\pi) = \mathbb{E}_{\pi,p(s_{0})}[\sum_{t=0}^{\infty}\gamma^{t} r(s_{t},a_{t})\,|\,s_{0}\sim p(\cdot), a_{t}\sim \pi(\cdot | s_{t})] \end{array}
Policy Evaluation
In principle, any sufficiently accurate off-policy method for learning Q-functions can be applied here, e.g., simple 1-step TD learning: fit a parametric Q-function $Q_{\phi}^{\pi}(s,a)$ with parameters $\phi$ by minimizing the squared TD error
\begin{array}{cc} \min_{\phi} (r(s_{t},a_{t})+\gamma Q_{\phi^{'}}^{\pi^{(k)}}(s_{t+1},a_{t+1}) - Q_{\phi}^{\pi^{(k)}}(s_{t},a_{t}))^{2},\quad a_{t+1}\sim \pi^{(k)}(\cdot|s_{t+1}) \end{array}
where $\phi^{'}$ denotes the parameters of a target network.
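As a concrete illustration, here is a minimal PyTorch sketch of this TD(0) update. The names `q_net`, `q_target`, and `policy` are hypothetical stand-ins for $Q_{\phi}$, the target network $Q_{\phi'}$, and the current policy $\pi^{(k)}$, and the batch is assumed to be a dict of tensors sampled from a replay buffer.

```python
import torch

def td0_loss(q_net, q_target, policy, batch, gamma=0.99):
    """Squared TD(0) error for evaluating the current policy pi^(k).

    batch: dict of tensors with keys 's', 'a', 'r', 's_next' from a replay buffer.
    q_net: the parametric Q-function Q_phi being fitted.
    q_target: a slowly updated copy Q_phi' used for the bootstrap target.
    policy: the current policy pi^(k), used only to sample the bootstrap action.
    """
    with torch.no_grad():
        # a_{t+1} ~ pi^(k)(.|s_{t+1})
        a_next = policy(batch["s_next"]).sample()
        # r(s_t, a_t) + gamma * Q_phi'(s_{t+1}, a_{t+1})
        target = batch["r"] + gamma * q_target(batch["s_next"], a_next)
    # (target - Q_phi(s_t, a_t))^2, averaged over the batch
    return ((target - q_net(batch["s"], batch["a"])) ** 2).mean()
```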
Policy Improvement
Intuitively, if for every state $s$ we improve the expectation $\bar{\mathcal{J}}(s,\pi) = \mathbb{E}_{\pi}[Q^{\pi^{(k)}}(s,a)]$ and our evaluation of $Q$ is accurate enough, then the objective $\mathcal{J}$ will improve. Note that we do not want to fully optimize $\bar{\mathcal{J}}$, because the evaluation of $Q$ is not exact and we do not want to be misled by its errors. The approach is a two-step procedure:
- Construct a non-parametric estimate $q$ s.t. $\bar{\mathcal{J}}(s,q)\geq \bar{\mathcal{J}}(s,\pi^{(k)})$
- Update policy by supervised learning (MLE): \(\begin{array}{cc} \pi^{(k+1)} = \underset{\pi_{\theta}}{arg\;min}\;\mathbb{E}_{\mu_{\pi}(s)}[KL(q(a|s)||\pi_{\theta}(a|s))] \end{array}\)
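Before going into the details of each step, here is a minimal sketch of how evaluation and improvement alternate in a training loop; the callables `fit_q`, `compute_action_weights`, and `fit_policy` are hypothetical placeholders for the procedures described in the rest of this post.

```python
def policy_iteration_loop(q_net, policy, replay_buffer, num_iterations,
                          fit_q, compute_action_weights, fit_policy):
    """Alternate policy evaluation and policy improvement.

    fit_q, compute_action_weights, and fit_policy are placeholders for the
    procedures described below; they are passed in so the loop is self-contained.
    """
    for _ in range(num_iterations):
        # Policy evaluation: fit Q^{pi^(k)} off-policy from replay data.
        q_net = fit_q(q_net, policy, replay_buffer)
        # Policy improvement, step 1: build the non-parametric q by
        # re-weighting actions sampled from pi^(k) (the "E-step" below).
        states, actions, weights = compute_action_weights(q_net, policy, replay_buffer)
        # Policy improvement, step 2: fit pi^(k+1) to q by weighted MLE
        # under a KL constraint (the "M-step" below).
        policy = fit_policy(policy, states, actions, weights)
    return q_net, policy
```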
Finding Action Weights (Corresponds to the E-step)
In this step, we construct $q$ by a sample-based estimate. Intuitively, we want to assign probabilities under $q$ such that the 'better' actions receive higher probability.
Given the learned approximate Q-function, we sample $K$ states $\{s_{j}\}_{j=1,...,K}$ from the replay buffer. For each state $s_{j}$, we sample $N$ actions from the last policy $\pi^{(k)}$, then evaluate each state-action pair with the approximate Q-function. This gives the states, actions, and their corresponding Q-values: $\{s_{j},\{a_{i},Q^{\pi^{(k)}}(s_{j},a_{i})\}_{i=1,...,N}\}_{j=1,...,K}$. For all $s_{j},a_{i}$, denote $q(a_{i}|s_{j})=q_{ij}$.
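A minimal sketch of this sampling step, assuming (hypothetically) that `replay_buffer.sample_states` returns a batch of states, `policy(states)` returns a `torch.distributions` object for $\pi^{(k)}$, and `q_net` scores state-action pairs:

```python
import torch

def build_sample_set(q_net, policy, replay_buffer, num_states_K, num_actions_N):
    """Build {s_j, {a_i, Q^{pi^(k)}(s_j, a_i)}} for the E-step."""
    states = replay_buffer.sample_states(num_states_K)            # [K, state_dim]
    actions = policy(states).sample((num_actions_N,))             # [N, K, action_dim]
    # Evaluate every sampled action at its state with the learned Q-function.
    states_rep = states.unsqueeze(0).expand(num_actions_N, -1, -1)  # [N, K, state_dim]
    with torch.no_grad():
        q_values = q_net(states_rep, actions)                     # [N, K]
    return states, actions, q_values
```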
In general, we can calculate the weights using any rank-preserving transformation of the Q-values; here are two choices:
- Using a ranking of the Q-values: set the weight of the $i$-th best action for the $j$-th sampled state to $q_{ij} \propto \ln(\frac{N+\eta}{i})$, where $\eta$ is a temperature parameter.
- Using an exponential transformation of the Q-values: we want to obtain the weights by optimizing for an optimal assignment of action probabilities, while constraining the change of the policy so that it does not collapse onto one action immediately. This can be achieved by solving the following KL-regularized objective:
\begin{array}{cc} q_{ij} = \underset{q(a_{i}|s_{j})}{arg\;max}\sum_{j}^{K}\sum_{i}^{N} q(a_{i}|s_{j})Q^{\pi^{(k)}}(s_{j},a_{i})\\ s.t.\; \frac{1}{K}\sum_{j}^{K}\sum_{i}^{N}q(a_{i}|s_{j})\log\frac{q(a_{i}|s_{j})}{1/N}<\epsilon;\;\forall j,\sum_{i}^{N}q(a_{i}|s_{j})=1 \end{array}
The first constraint forces the weights to stay close to the probabilities of the last policy (the sampled actions, weighted uniformly by $1/N$, represent $\pi^{(k)}$), and the second ensures the weights are normalized. The solution can be obtained in closed form:
\begin{array}{ccc} q_{ij} = q(a_{i}|s_{j}) = exp(Q^{\pi^{(k)}}(s_{j},a_{i})/\eta)/\sum_{i}exp(Q^{\pi^{(k)}}(s_{j},a_{i})/\eta). \end{array}
where the temperature parameter $\eta$ can be computed by solving the convex dual function:
\begin{array}{cc} \eta = arg\;min_{\eta}\eta\epsilon + \eta\sum_{j}^{K}\frac{1}{K}\log(\sum_{i}^{N}\frac{1}{N}exp(\frac{Q(s_{j},a_{i})}{\eta})) \end{array}
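A sketch of both weight transformations in NumPy/SciPy; `q_values` is the $N \times K$ array of $Q^{\pi^{(k)}}(s_{j},a_{i})$ from the sampling step above, and the optimizer settings (initial $\eta$, bounds, L-BFGS-B) are illustrative choices rather than prescribed ones.

```python
import numpy as np
from scipy.optimize import minimize

def weights_from_ranking(q_values, eta=1.0):
    """Rank-based weights: q_ij proportional to ln((N + eta) / rank_i)."""
    N, K = q_values.shape
    # rank 1 = best action per state (column-wise ranking of Q-values).
    ranks = np.argsort(np.argsort(-q_values, axis=0), axis=0) + 1
    w = np.log((N + eta) / ranks)
    return w / w.sum(axis=0, keepdims=True)  # normalize per state

def weights_from_exponential(q_values, epsilon=0.1):
    """Softmax weights with temperature eta found by minimizing the convex dual."""
    N, K = q_values.shape
    q_max = q_values.max(axis=0)  # per-state max, for numerical stability

    def dual(eta_arr):
        eta = eta_arr[0]
        # eta*eps + eta * mean_j log( mean_i exp(Q_ij / eta) )
        logmeanexp = np.log(np.mean(np.exp((q_values - q_max) / eta), axis=0)) + q_max / eta
        return eta * epsilon + eta * np.mean(logmeanexp)

    eta = minimize(dual, x0=np.array([1.0]), bounds=[(1e-6, None)],
                   method="L-BFGS-B").x[0]
    # Closed-form solution: per-state softmax of Q / eta.
    z = np.exp((q_values - q_max) / eta)
    return z / z.sum(axis=0, keepdims=True)
```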
If you are familiar with the bandit literature, it is easy to see that this is similar to the EXP3 algorithm for adversarial bandits. In fact, at a high level, when the MDP collapses to a bandit setting, this framework can be related to the black-box optimization literature.
This corresponds to the E-step in MPO, where we choose the variational distribution $q(a|s)=\frac{q(a,s)}{\mu(s)}$ such that the lower bound on $\log p(\theta_{t}|R=1)$ is as tight as possible; the same solution can thus be derived from two different perspectives.
Fitting an Improved Policy (Corresponds to the M-step)
Note that the $q$ we obtained is defined only over the sampled states and actions, so we want to generalize it to the whole state and action space. To do so, we minimize the KL divergence between the obtained sample-based distribution and the parametric policy $\pi_{\theta}$, i.e., we solve a weighted supervised learning problem:
\begin{array}{cc} \pi^{(k+1)} = \underset{\pi_{\theta}}{arg\;max}\sum_{j}^{K}\sum_{i}^{N}q_{ij}\log \pi_{\theta}(a_{i}|s_{j}) \end{array}
Since this is a supervised learning problem, it can suffer from overfitting; moreover, because the approximation of $Q^{\pi^{(k)}}$ is inexact, the change in the action distribution could be in the wrong direction. To limit the change in the parametric policy, we employ an additional KL constraint, and the objective becomes:
\begin{array}{cc} \pi^{(k+1)} = \underset{\pi_{\theta}}{arg\;max}\sum_{j}^{K}\sum_{i}^{N}q_{ij}\log \pi_{\theta}(a_{i}|s_{j})\\ s.t.\;\sum_{j}^{K}\frac{1}{K}KL(\pi^{(k)}(a|s_{j})||\pi_{\theta}(a|s_{j}))<\epsilon_{\pi} \end{array}
where $\epsilon_{\pi}$ denotes the allowed expected KL change of the policy over the state distribution.
Via Lagrangian relaxation, this constrained objective can be turned into a saddle-point problem amenable to gradient-based optimization:
\begin{array}{cc} \underset{\theta}{\max}\;\underset{\alpha>0}{\min}\;L(\theta,\alpha) = \sum_{j}\sum_{i}q_{ij}\log \pi_{\theta}(a_{i}|s_{j}) + \\ \alpha(\epsilon_{\pi}-\sum_{j}^{K}\frac{1}{K}KL(\pi^{(k)}(a|s_{j})||\pi_{\theta}(a|s_{j}))) \end{array}
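A sketch of this Lagrangian in PyTorch, alternating gradient steps on $\theta$ and $\alpha$. Here `policy_new` and `policy_old` are assumed to map states to `torch.distributions` objects for $\pi_{\theta}$ and $\pi^{(k)}$, and parameterizing the multiplier through a softplus is an implementation choice, not the paper's exact recipe.

```python
import torch
import torch.nn.functional as F

def m_step_losses(policy_new, policy_old, states, actions, weights,
                  alpha_raw, epsilon_pi=0.01):
    """Lagrangian of the KL-constrained weighted maximum-likelihood objective.

    states: [K, state_dim], actions: [N, K, action_dim], weights: [N, K]
    (the q_ij from the E-step); alpha_raw is a learnable scalar parameter.
    """
    alpha = F.softplus(alpha_raw)  # keep the Lagrange multiplier positive

    # Weighted maximum likelihood term: sum_ij q_ij * log pi_theta(a_i | s_j).
    log_probs = policy_new(states).log_prob(actions)                  # [N, K]
    wml = (weights.detach() * log_probs).sum()

    # Expected KL(pi^(k) || pi_theta) over the sampled states.
    kl = torch.distributions.kl_divergence(policy_old(states), policy_new(states)).mean()

    # max_theta min_alpha  wml + alpha * (eps_pi - KL):
    # theta ascends the Lagrangian, alpha descends it, so return two losses.
    policy_loss = -(wml + alpha.detach() * (epsilon_pi - kl))
    alpha_loss = alpha * (epsilon_pi - kl.detach())
    return policy_loss, alpha_loss
```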
This step corresponds to the M-step in MPO, where we optimize the parameters $\theta$ of the policy $\pi(a|s,\theta)$ towards the obtained variational distribution $q(a|s)$.
Summary
Motivated by the perspective of policy evaluation and policy improvement, an off-policy, gradient-free actor-critic algorithm is derived, with connections to the black-box optimization literature and to 'RL as inference'. It can be seen as an interpretation of MPO (covered in the previous post) from a different perspective.
Reference
Abdolmaleki, A., Springenberg, J. T., Degrave, J., Bohez, S., Tassa, Y., Belov, D., Heess, N., and Riedmiller, M. Relative entropy regularized policy iteration. arXiv preprint arXiv:1812.02256, 2018.