Flow Annealed Importance Sampling Bootstrap
My notes on Bootstrap Your Flow and Flow Annealed Importance Sampling Bootstrap (FAB).
TL;DR
Training a parametric distribution \(q_{\theta}\) to approximate a target density \(p\) often involves objectives whose Monte Carlo estimators have high variance. These objectives can sometimes be rewritten via an Annealed Importance Sampling (AIS) estimator, and choosing the AIS target distribution in a smart way can substantially reduce that variance.
Setup
Suppose we want to approximate some target density \(p\). We assume access to a parametric family of densities \(q_{\theta}\) that we can both sample from and evaluate; for example, \(q_{\theta}\) may be parametrized as a normalizing flow. Given some divergence \(D\) between probability distributions, we select \(\theta\) to minimize \(D\) between \(q_{\theta}\) and \(p\).
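To make the setup concrete, here is a minimal sketch of my own (not from the papers): a diagonal Gaussian stands in for the flow \(q_{\theta}\), since all we need is to sample from it and evaluate its log-density, and \(p\) is a toy bimodal target known only up to normalization. All names are placeholders.

```python
# Minimal sketch of the setup. A diagonal Gaussian stands in for the flow
# q_theta (we only need to sample from it and evaluate its log-density), and
# log_p is an unnormalized bimodal target. All names here are hypothetical.
import jax
import jax.numpy as jnp

def sample_q(theta, key, n):
    """Draw n samples from q_theta = N(mean, diag(exp(log_std)^2))."""
    mean, log_std = theta
    eps = jax.random.normal(key, (n,) + mean.shape)
    return mean + jnp.exp(log_std) * eps

def log_q(theta, x):
    """Evaluate log q_theta(x) for the diagonal Gaussian."""
    mean, log_std = theta
    z = (x - mean) / jnp.exp(log_std)
    return jnp.sum(-0.5 * z**2 - log_std - 0.5 * jnp.log(2.0 * jnp.pi), axis=-1)

def log_p(x):
    """Unnormalized log-density of a toy bimodal target (two Gaussian modes)."""
    return jnp.logaddexp(
        jnp.sum(-0.5 * (x - 3.0) ** 2, axis=-1),
        jnp.sum(-0.5 * (x + 3.0) ** 2, axis=-1),
    )
```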
KL Example
Consider the classical example of the KL divergence \(\text{KL}(q \| p) = \int q(x)\log \frac{q(x)}{p(x)}dx\). We could try to minimize either \(\text{KL}(q_{\theta} \| p)\) or \(\text{KL}(p \| q_{\theta})\). Let’s consider \(\text{KL}(q_{\theta} \| p)\) first. This objective is tractable to optimize (yay!) since we only need samples from \(q_{\theta}\), the density of \(q_{\theta}\), and \(p\) up to a normalizing constant. However, it exhibits mode-seeking behaviour: \(\log \frac{q_{\theta}(x)}{p(x)}\) blows up when \(q_{\theta}\) assigns high density to an \(x\) where \(p(x)\) is small, so minimizing \(\text{KL}(q_{\theta} \| p)\) pushes \(q_{\theta}\) to place mass only where \(p\) has sufficient density. If \(p\) is multi-modal, \(q_{\theta}\) will most likely learn only a few of the modes and place all of its mass there, dropping the rest. This is a downside of the objective.
On the other hand, optimizing \(\text{KL}(p \| q_{\theta})\) exhibits mode-covering behaviour. This is because \(\log \frac{p(x)}{q_{\theta}(x)}\) blows up when \(p\) assigns a lot of mass to an \(x\) where \(q_{\theta}\) does not, so the learned \(q_{\theta}\) will try to place mass everywhere that \(p\) has mass (yay!). The issue is that this objective is an expectation under \(p\), so training with it requires samples from \(p\), which we may not have access to (or which may be very expensive to obtain). So, we need a way of optimizing this objective without relying on samples from \(p\).
We could try to use Importance Sampling (IS) and rewrite the objective as \(\mathbb{E}_{x \sim q_{\theta}}[\frac{p(x)}{q_{\theta}(x)}\log \frac{p(x)}{q_{\theta}(x)}]\). We still have the mode-covering behaviour (since the objective itself has not changed) and we can now train on samples from \(q_{\theta}\). However, the variance of this IS estimator can be very large. That is, \(\text{Var}_{x \sim q_{\theta}}(\frac{p(x)}{q_{\theta}(x)}\log \frac{p(x)}{q_{\theta}(x)})\) can be big and make optimization tricky. In particular, we know that the expected value is non-negative (since KL is non-negative), and the integrand \(\frac{p(x)}{q_{\theta}(x)}\log \frac{p(x)}{q_{\theta}(x)}\) is positive exactly for \(x\) such that \(p(x) > q_{\theta}(x)\). However, if \(q_{\theta}\) is very different from \(p\) (for example, at the onset of training), then most samples \(x\) from \(q_{\theta}\) will satisfy \(q_{\theta}(x) > p(x)\) and thus \(\frac{p(x)}{q_{\theta}(x)}\log \frac{p(x)}{q_{\theta}(x)} < 0\). Said another way, for most \(x \sim q_{\theta}\) the integrand will be negative, and to compensate, once in a while the integrand will take a very large positive value. This makes training challenging: the gradient estimates have high variance.
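Here is a small sketch of this IS estimator (my own illustration, reusing the toy \(q_{\theta}\) and \(p\) from the setup sketch; the helper names are assumptions). The sample variance of the integrand is exactly the quantity that blows up when \(q_{\theta}\) and \(p\) are mismatched.

```python
# Sketch of the importance-sampled forward-KL estimator
# E_{x~q}[(p/q) log(p/q)], together with its sample variance. The helpers
# log_p, log_q, sample_q are assumed to follow the setup sketch above.
import jax
import jax.numpy as jnp

def forward_kl_is_estimate(key, theta, log_p, log_q, sample_q, n=1024):
    """Monte Carlo IS estimate of KL(p || q_theta) using samples from q_theta."""
    x = sample_q(theta, key, n)
    log_ratio = log_p(x) - log_q(theta, x)      # log(p/q), assuming p is normalized
    integrand = jnp.exp(log_ratio) * log_ratio  # (p/q) * log(p/q)
    # Most terms are slightly negative where q over-covers p; rare, very large
    # positive terms (where q badly under-covers p) dominate the mean and variance.
    return jnp.mean(integrand), jnp.var(integrand)
```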
Main Idea of FAB
This tradeoff exists for many divergences \(D\), which can often be written in the form \(D(q \| p) = \int q(x) f(q(x), p(x))dx\) for some \(f\). \(D(q_{\theta} \| p)\) will often be mode-seeking, whereas estimating \(D(p \| q_{\theta})\) will either require samples from \(p\) or have high variance if we rewrite it as an IS objective over samples from \(q_{\theta}\). (This is exactly what we saw in the KL example above.)
FAB considers the case of a mode-covering \(D(p \|q_{\theta})\). Note that for any density \(g\), we have
\[D(p \|q_{\theta}) = \mathbb{E}_{x \sim g}[\frac{p}{g}f(p, q_{\theta})].\]As discussed for the KL example, setting \(g = q_{\theta}\) often gives an IS estimator with high variance. However, we know that the optimal \(g\) (the one minimizing the variance) is \(g_*(x) \propto p(x)f(p(x), q_{\theta}(x))\) (assuming the integrand is non-negative; in general the optimum is proportional to its absolute value). Sampling directly from this optimal \(g_{*}\), which is what we would want for IS, is usually intractable. So, FAB suggests using AIS to estimate the above objective. In particular, AIS is initialized at \(q_{\theta}\) and targets \(g_{*}\). The output of AIS is a collection of samples \(x_1, \dots, x_n\) and corresponding weights \(w_1, \dots , w_n\). If AIS is implemented well, the particles \(x_i\) will be distributed approximately as \(g_{*}\). Regardless, AIS guarantees that the weights \(w_i\) are chosen such that for any function \(h\), we have
\[\mathbb{E}_{\text{AIS}}[w_i h(x_i)] = \int h(x) g_{*}(x)dx,\]where the expectation is over the randomness of the AIS sampling procedure. Therefore, we can use the samples and weights output by AIS to obtain an unbiased estimator of our training objective (i.e. we take \(h = \frac{p}{g_{*}}f(p, q_{\theta})\)). Since the samples are approximately distributed as \(g_{*}\), the variance of our estimator should be small and training becomes tractable!
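In code, the resulting estimator is just a weighted average. A hedged sketch (assuming some AIS routine, not shown here, that returns particles and unnormalized log-weights with the property above):

```python
# Sketch: turning AIS output into an unbiased estimate of the objective.
# `x` and `log_w` are assumed to come from an AIS run initialized at q_theta
# and targeting g_star; `h` is the integrand we care about, e.g.
# h(x) = (p(x)/g_star(x)) * f(p(x), q_theta(x)) for the objective above.
import jax.numpy as jnp

def ais_expectation(x, log_w, h):
    """Estimate the integral of h(x) g_star(x) dx as the weighted average of h."""
    return jnp.mean(jnp.exp(log_w) * h(x))
```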
Paper Details
Both papers specifically focus on using the \(\alpha\)-divergence with \(\alpha = 2\). Up to constants, this is given by
\[D_{2}(p \|q_{\theta}) = \int \frac{p^2(x)}{q_{\theta}(x)}dx.\]This objective is a natural one: it is mode-covering, and it corresponds to minimizing the variance of the importance weights \(\frac{p(x)}{q_{\theta}(x)}\). (In particular, after training, \(q_{\theta}\) can be used as an IS proposal for unbiased estimates of expectations w.r.t. \(p\), and those importance weights will have low variance.) As discussed, this objective can be estimated with samples from either \(p\) or \(q_{\theta}\), but the estimator has lower variance if we use samples from \(p\). So, the original Bootstrap Your Flow paper suggests using AIS to target \(p\). The Flow Annealed Importance Sampling Bootstrap (FAB) paper instead suggests using AIS to target \(g_{*} \propto \frac{p^2}{q_{\theta}}\) (i.e. the \(g\) that induces the minimum-variance estimator).
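To spell out the link to the variance of the importance weights (a standard identity, assuming \(p\) is normalized; with an unnormalized \(p\) the same holds up to constants):
\[\text{Var}_{x \sim q_{\theta}}\left(\frac{p(x)}{q_{\theta}(x)}\right) = \mathbb{E}_{q_{\theta}}\left[\frac{p^2(x)}{q_{\theta}^2(x)}\right] - \left(\mathbb{E}_{q_{\theta}}\left[\frac{p(x)}{q_{\theta}(x)}\right]\right)^2 = \int \frac{p^2(x)}{q_{\theta}(x)}dx - 1 = D_{2}(p \| q_{\theta}) - 1,\]so minimizing \(D_{2}\) over \(\theta\) is exactly minimizing the variance of the importance weights.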
This method performs a sort of bootstrapping: as the flow \(q_{\theta}\) improves, the AIS initialization gets closer to \(p\), which should make the final samples closer to \(\frac{p^2}{q_{\theta}}\) and the AIS estimator lower-variance. The expectations computed via AIS then improve, so the gradient information used to optimize \(q_{\theta}\) gets better, and thus \(q_{\theta}\) moves even closer to \(p\).
The FAB paper explores using both the original weights \(w_i\) from AIS and the self-normalized weights \(\frac{w_i}{\sum_j w_j}\), and found that the self-normalized weights improved training. (Just watch out for whether or not you are differentiating through the weights, as they depend implicitly on \(\theta\).)
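One way to handle this caveat in JAX (my own sketch, not the papers' reference code): wrap the AIS samples and weights in stop_gradient and differentiate only through \(\log q_{\theta}(x_i)\). Under my reading, the gradient of this surrogate matches the \(\alpha = 2\) gradient estimator \(-\sum_i \bar{w}_i \nabla_{\theta}\log q_{\theta}(x_i)\) with self-normalized weights \(\bar{w}_i\).

```python
# Hedged sketch of a surrogate loss for the alpha=2 objective: the AIS
# samples x and log-weights log_w depend on theta through the sampler, so we
# wrap them in stop_gradient and differentiate only through log q_theta(x).
# `log_q` is an assumed callable log_q(theta, x); all names are placeholders.
import jax
import jax.numpy as jnp

def fab_surrogate_loss(theta, x, log_w, log_q):
    """Surrogate whose gradient is - sum_i w_bar_i * grad log q_theta(x_i)."""
    x = jax.lax.stop_gradient(x)
    w_bar = jax.lax.stop_gradient(jax.nn.softmax(log_w))  # self-normalized weights
    return -jnp.sum(w_bar * log_q(theta, x))

# Usage: gradients w.r.t. theta only, with x and log_w held fixed:
# grads = jax.grad(lambda th: fab_surrogate_loss(th, x, log_w, log_q))(theta)
```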
See the papers for extensions on using AIS bootstrapping to train with other divergences/objectives/models.
Here’s one idea I have. As the authors point out, the distribution that minimizes the variance when using self-normalized weights is different from the one for the unnormalized weights. What if AIS were used to target this distribution instead? In general, this would not be tractable, but for some divergences it might be, and it could improve results even further.
Replay Buffer
The next contribution of the FAB paper is a prioritized replay buffer. Once the parameters \(\theta\) are updated, we could discard the old AIS samples and draw entirely new ones from AIS initialized at the new \(q_{\theta}\) and targeting the new \(g_{*}\). This would require invoking the AIS procedure often, which may be very expensive. Instead, the authors propose to keep old samples in a buffer and to adjust the training procedure accordingly, so that the AIS sampling procedure can be run less frequently.
Suppose that the replay buffer has samples \(x_i\) with corresponding weights \(w_i\) targeting \(g_{\theta'} \propto \frac{p^2}{q_{\theta'}}\). By rescaling the weights via
\[w_i \mapsto w_i \frac{g_{\theta}}{g_{\theta'}} = w_i \frac{q_{\theta'}}{q_{\theta}},\]the existing samples \(x_i\) with the updated weights \(w_i \frac{q_{\theta'}(x_i)}{q_{\theta}(x_i)}\) now target \(g_{\theta} \propto \frac{p^2}{q_{\theta}}\). So, suppose we store the weights, the particles, and the corresponding \(q_{\theta'}\) values in the buffer. After we update to the new parameters \(\theta\), we can iterate through the buffer and update the weights and stored \(q\) values so that they target the new \(g_{\theta}\). In this way, we can keep old samples and add fresh ones to the buffer without breaking anything. It is also nice that we never need to re-evaluate \(p\), which could be very expensive.
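A sketch of this reweighting step (my own illustration; the buffer layout and names are assumptions, not the paper's implementation): each buffer entry stores the sample, its log-weight, and the \(\log q_{\theta'}\) value under the parameters that produced it.

```python
# Sketch of the replay-buffer reweighting: stored log-weights targeting
# g_{theta'} = p^2 / q_{theta'} are shifted by log q_{theta'}(x) - log q_theta(x)
# so that they target g_theta = p^2 / q_theta. Buffer layout is hypothetical.
import jax.numpy as jnp

def reweight_buffer(x, log_w, log_q_old_x, theta, log_q):
    """Return updated (log_w, log_q_x) for the current parameters theta.

    x           : stored samples
    log_w       : stored log AIS weights, valid for the old parameters theta'
    log_q_old_x : stored values of log q_{theta'}(x)
    log_q       : callable log_q(theta, x) for the current flow
    Note that p never needs to be re-evaluated.
    """
    log_q_new_x = log_q(theta, x)
    new_log_w = log_w + log_q_old_x - log_q_new_x  # w <- w * q_{theta'}(x) / q_theta(x)
    return new_log_w, log_q_new_x
```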
The authors also propose that, instead of sampling uniformly from the buffer and weighting each term by its AIS weight, we sample in proportion to the normalized AIS weights (and drop the weights from the expectation calculation). This procedure has the same expected value as the self-normalized objective.
Since we are sampling in proportion to the weights \(w \approx \frac{g}{q_{\theta}} = \frac{p^2}{q_{\theta}^2}\), sampling prioritizes points where the current flow \(q_{\theta}\) under-assigns mass relative to the target \(p\). As \(q_{\theta}\) improves on these points, the weights of these particles decrease, so the buffer shifts priority to other samples, providing variety.
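A sketch of the prioritized draw (again my own, with hypothetical helper names): sample buffer indices with probability proportional to the stored weights, then treat the drawn points as unweighted in the loss.

```python
# Sketch of prioritized sampling from the buffer: indices are drawn with
# probability proportional to the stored AIS weights, and the resulting
# minibatch is then used without per-sample weights in the objective.
import jax
import jax.numpy as jnp

def sample_from_buffer(key, x, log_w, batch_size):
    """Draw a minibatch of buffer samples with probability proportional to exp(log_w)."""
    idx = jax.random.categorical(key, log_w, shape=(batch_size,))
    return x[idx]
```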
Iterating through all of the data points in the buffer before sampling is expensive, so the paper discusses some approximations to speed up training at the cost of introducing a small bias.
TODO
I want to read some more of the related literature before returning to their RELATED WORK and DISCUSSION sections.