Lecture - Uncertainty in Deep Learning MT25, Approximate inference
In this lecture, we are introduced to the idea of “variational inference”, in which we approximate the posterior distribution $\mathbb P(W \mid X, Y)$ by optimising the parameters $\theta$ of a simpler distribution.
Why do this? In [[Lecture - Uncertainty in Deep Learning MT25, Bayesian probabilistic modelling of functions]] we considered the simple situation of Bayesian linear regression over features $\phi(\cdot)$, where our likelihood and prior were both Gaussian, and found that it was possible to derive the mean and variance of the posterior and predictive distributions analytically:
\[\mathbb P(W \mid X, Y) = \mathcal N(W \mid \mu, \Sigma)\]where
\[\begin{aligned} &\Sigma = \left( \frac{1}{\sigma^2} \phi(X)^\top \phi(X) + \frac{1}{s^2} I _ K \right)^{-1} \\ &\mu = \frac{1}{\sigma^2} \Sigma \phi(X)^\top Y \end{aligned}\]and
\[\mathbb P(y^\ast \mid x^\ast, X, Y) = \mathcal N(y^\ast \mid \mu^\ast, \Sigma^\ast)\]where
\[\begin{aligned} &\mu^\ast = \mu^\top \phi(x^\ast) \\ &\Sigma^\ast = \sigma^2 + \phi(x^\ast)^\top \Sigma \phi(x^\ast) \end{aligned}\]Take a look at these terms – just because they are analytic, it doesn’t mean that they are easy to compute. For example, if $\phi$ outputs vectors of size $K = 10,000$, we have to invert a $10,000 \times 10,000$ matrix.
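For concreteness, here is a minimal NumPy sketch of these analytic formulas (the function names, toy data and hyperparameter values are illustrative, not from the lecture); the $K \times K$ inverse in `posterior` is the step that becomes expensive when $K$ is large:

```python
import numpy as np

def posterior(Phi, Y, sigma2, s2):
    """Analytic posterior over weights for Bayesian regression on features.
    Phi: (N, K) matrix with rows phi(x_n); Y: (N,) targets."""
    K = Phi.shape[1]
    # Sigma = (Phi^T Phi / sigma^2 + I / s^2)^{-1}: a K x K matrix inverse.
    Sigma = np.linalg.inv(Phi.T @ Phi / sigma2 + np.eye(K) / s2)
    mu = Sigma @ Phi.T @ Y / sigma2
    return mu, Sigma

def predictive(phi_star, mu, Sigma, sigma2):
    """Predictive mean and variance at a test feature vector phi_star: (K,)."""
    return mu @ phi_star, sigma2 + phi_star @ Sigma @ phi_star

# Toy usage with random features, just to show the shapes involved.
rng = np.random.default_rng(0)
Phi = rng.normal(size=(50, 10))                     # N = 50 points, K = 10 features
Y = Phi @ rng.normal(size=10) + 0.1 * rng.normal(size=50)
mu, Sigma = posterior(Phi, Y, sigma2=0.01, s2=1.0)
print(predictive(Phi[0], mu, Sigma, sigma2=0.01))
```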
More generally, we might not even be able to compute the posterior distribution in the first place. In this specific setup it was tractable because, by properties of Gaussians, we knew the posterior was also Gaussian; with more general likelihoods and priors we typically cannot even evaluate the posterior in closed form.
Hence it’s necessary to come up with some tractable way of approximating the posterior distribution $\mathbb P(W \mid X, Y)$. Variational inference aims to approximate $\mathbb P(W \mid X, Y)$ by a simpler distribution $q _ \theta(W)$ parameterised by $\theta$, and find the best values of $\theta$.
This might seem impossible: how can we improve our approximation $q _ \theta (W)$ if we can’t compare it to $\mathbb P(W \mid X, Y)$? It turns out that by a clever choice of optimisation objective and some manipulations, we can derive a tractable objective that doesn’t involve computing $\mathbb P(W \mid X, Y)$ explicitly.
Distances (divergences) between distributions
See [[Notes - Uncertainty in Deep Learning MT25, Kullback-Leibler divergence]].
Minimising the divergence between the variational distribution and the true posterior
In the context of variational inference, we aim to approximate the posterior $p(W \mid X, Y)$ by some variational distribution $q _ \theta(W)$.
What objective do we minimise?
In the context of variational inference, we aim to approximate the posterior $p(W \mid X, Y)$ by some variational distribution $q _ \theta(W)$. To do this, we minimise
\[\text{KL}(q _ \theta(W), p(W \mid X, Y)) = \int q _ \theta(W) \log \frac{q _ \theta (W)}{p(W \mid X, Y)} \text dW\]
@State the identity that turns this into a tractable optimisation objective.
In the context of variational inference, we aim to approximate the posterior $p(W \mid X, Y)$ by some variational distribution $q _ \theta(W)$. To do this, we minimise
\[\text{KL}(q _ \theta(W), p(W \mid X, Y)) = \int q _ \theta(W) \log \frac{q _ \theta (W)}{p(W \mid X, Y)} \text dW\]
@Prove that we may rearrange this to
\[\text{KL}(q _ \theta(W), p(W \mid X, Y)) = \log p(Y \mid X) - \int q _ \theta(W) \log p(Y \mid X, W) \text dW + \text{KL}(q _ \theta(W), p(W))\]
By Bayes’ rule, $p(W \mid X, Y) = \frac{p(Y \mid X, W) \, p(W)}{p(Y \mid X)}$. Substituting this into the definition above:
\[\begin{aligned} \text{KL}(q _ \theta(W), p(W \mid X, Y)) &= \int q _ \theta(W) \log \frac{q _ \theta(W) \, p(Y \mid X)}{p(Y \mid X, W) \, p(W)} \text dW \\ &= \int q _ \theta(W) \log p(Y \mid X) \text dW - \int q _ \theta(W) \log p(Y \mid X, W) \text dW + \int q _ \theta(W) \log \frac{q _ \theta(W)}{p(W)} \text dW \\ &= \log p(Y \mid X) - \int q _ \theta(W) \log p(Y \mid X, W) \text dW + \text{KL}(q _ \theta(W), p(W)) \end{aligned}\]where the last step uses $\int q _ \theta(W) \text dW = 1$, as required.
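As a sanity check, the identity can be verified numerically. Below is a minimal NumPy sketch (the toy model, values and names are chosen here for illustration, not taken from the lecture) using a one-dimensional conjugate Gaussian model in which every term has a closed form:

```python
import numpy as np

def gauss_kl(m1, v1, m2, v2):
    """KL(N(m1, v1), N(m2, v2)) for univariate Gaussians."""
    return 0.5 * (v1 / v2 + (m1 - m2) ** 2 / v2 - 1 + np.log(v2 / v1))

# Toy model: w ~ N(0, s2), y_n = w * x_n + eps_n with eps_n ~ N(0, sigma2).
rng = np.random.default_rng(0)
N, s2, sigma2 = 20, 1.0, 0.25
x = rng.normal(size=N)
y = 0.7 * x + np.sqrt(sigma2) * rng.normal(size=N)

# Exact posterior over w (1-D analogue of the earlier analytic formulas).
v_post = 1.0 / (x @ x / sigma2 + 1.0 / s2)
m_post = v_post * (x @ y) / sigma2

# Log evidence: Y | X ~ N(0, sigma2 * I + s2 * x x^T).
C = sigma2 * np.eye(N) + s2 * np.outer(x, x)
_, logdet = np.linalg.slogdet(C)
log_evidence = -0.5 * (N * np.log(2 * np.pi) + logdet + y @ np.linalg.solve(C, y))

# An arbitrary variational distribution q(w) = N(m, v).
m, v = 0.3, 0.5

# E_q[log p(Y | X, w)]: closed form, using E_q[(y_n - w x_n)^2] = (y_n - m x_n)^2 + v x_n^2.
exp_loglik = np.sum(-0.5 * np.log(2 * np.pi * sigma2)
                    - ((y - m * x) ** 2 + v * x ** 2) / (2 * sigma2))

lhs = gauss_kl(m, v, m_post, v_post)                      # KL(q, posterior)
rhs = log_evidence - exp_loglik + gauss_kl(m, v, 0.0, s2)
print(np.isclose(lhs, rhs))                               # True (up to floating point)
```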
In the context of variational inference, we aim to approximate the posterior $p(W \mid X, Y)$ by some variational distribution $q _ \theta(W)$. To do this, we minimise
\[\text{KL}(q _ \theta(W), p(W \mid X, Y)) = \int q _ \theta(W) \log \frac{q _ \theta (W)}{p(W \mid X, Y)} \text dW\]
We may rearrange this to
\[\text{KL}(q _ \theta(W), p(W \mid X, Y)) = \log p(Y \mid X) - \int q _ \theta(W) \log p(Y \mid X, W) \text dW + \text{KL}(q _ \theta(W), p(W))\]
In this context, @state the ELBO, describe where it gets its name and explain why it is useful.
Let
\[\text{ELBO}(\theta) = \int q _ \theta(W) \log p(Y \mid X, W) \text dW - \text{KL}(q _ \theta(W), p(W))\]Then we can write
\[\log p(Y \mid X) - \text{ELBO}(\theta) = \text{KL}(q _ \theta(W), p(W \mid X, Y))\]Since $\log p(Y \mid X)$ does not depend on $\theta$ (it is the log marginal likelihood, or evidence, of the data), maximising $\text{ELBO}(\theta)$ is equivalent to minimising $\text{KL}(q _ \theta(W), p(W \mid X, Y))$.
Since $\text{KL}(q _ \theta(W), p(W \mid X, Y)) \ge 0$, we see that
\[\text{ELBO}(\theta) \le \log p(Y \mid X)\]so $\text{ELBO}(\theta)$ is a lower bound on the log evidence $\log p(Y \mid X)$, hence the name “evidence lower bound”.
It is useful because every term of the ELBO is tractable: the first term is an expectation of the log likelihood under $q _ \theta$, and the second is a KL divergence between $q _ \theta$ and the prior, so neither requires evaluating the intractable posterior $p(W \mid X, Y)$.
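Concretely, even when the first term has no closed form, it can be estimated by sampling from $q _ \theta$, while the KL to a Gaussian prior stays analytic. A minimal sketch, assuming a single scalar weight $w$ with $q _ \theta(w) = \mathcal N(m, \mathrm e^{2 \log s})$, prior $\mathcal N(0, s^2 _ 0)$ and likelihood $y _ n \sim \mathcal N(w x _ n, \sigma^2)$ (all names and the Monte Carlo estimator are illustrative, not from the lecture):

```python
import numpy as np

def elbo_mc(m, log_s, x, y, sigma2, s2_0, n_samples=1000, seed=0):
    """Monte Carlo estimate of the ELBO for q(w) = N(m, exp(log_s)^2),
    prior p(w) = N(0, s2_0) and likelihood y_n ~ N(w * x_n, sigma2)."""
    rng = np.random.default_rng(seed)
    std = np.exp(log_s)
    w = m + std * rng.normal(size=n_samples)              # samples from q
    # Expected log-likelihood, estimated by averaging over the samples of w.
    resid = y[None, :] - w[:, None] * x[None, :]          # (n_samples, N)
    exp_loglik = np.mean(np.sum(-0.5 * np.log(2 * np.pi * sigma2)
                                - resid ** 2 / (2 * sigma2), axis=1))
    # KL between two univariate Gaussians, available in closed form.
    kl = 0.5 * (std ** 2 / s2_0 + m ** 2 / s2_0 - 1 + np.log(s2_0 / std ** 2))
    return exp_loglik - kl
```

Written this way (with $w = m + \mathrm e^{\log s} \epsilon$), the estimate is differentiable in $(m, \log s)$, so it can be maximised by gradient ascent.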
In the context of variational inference, we aim to approximate the posterior $p(W \mid X, Y)$ by some variational distribution $q _ \theta(W)$. To do this, we minimise
\[\text{KL}(q _ \theta(W), p(W \mid X, Y)) = \int q _ \theta(W) \log \frac{q _ \theta (W)}{p(W \mid X, Y)} \text dW\]
Let
\[\text{ELBO}(\theta) = \int q _ \theta(W) \log p(Y \mid X, W) \text dW - \text{KL}(q _ \theta(W), p(W))\]
Can you describe what both of these terms represent intuitively?
- $\int q _ \theta(W) \log p(Y \mid X, W) \text dW$: the expected log likelihood under $q _ \theta$, which measures how well we explain the data on average; it would be maximised by concentrating $q _ \theta(W)$ into a point mass at the maximum likelihood estimate $\hat W _ \text{MLE}$.
- $\text{KL}(q _ \theta(W), p(W))$: a complexity penalty, which discourages variational distributions that stray far from the prior; all else being equal, we want $q _ \theta(W)$ to remain close to $p(W)$.
In the context of variational inference, we aim to approximate the posterior $p(W \mid X, Y)$ by some variational distribution $q _ \theta(W)$. To do this, we minimise
\[\text{KL}(q _ \theta(W), p(W \mid X, Y)) = \int q _ \theta(W) \log \frac{q _ \theta (W)}{p(W \mid X, Y)} \text dW\]
Let
\[\text{ELBO}(\theta) = \int q _ \theta(W) \log p(Y \mid X, W) \text dW - \text{KL}(q _ \theta(W), p(W))\]
@Prove that $\log p(Y \mid X) \ge \text{ELBO}(\theta)$ using Jensen’s inequality.
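Writing the evidence as an expectation under $q _ \theta$ and applying Jensen’s inequality to the concave $\log$ (assuming $q _ \theta(W) > 0$ wherever $p(W) > 0$):
\[\begin{aligned} \log p(Y \mid X) &= \log \int p(Y \mid X, W) p(W) \text dW \\ &= \log \int q _ \theta(W) \frac{p(Y \mid X, W) p(W)}{q _ \theta(W)} \text dW \\ &\ge \int q _ \theta(W) \log \frac{p(Y \mid X, W) p(W)}{q _ \theta(W)} \text dW \\ &= \int q _ \theta(W) \log p(Y \mid X, W) \text dW - \int q _ \theta(W) \log \frac{q _ \theta(W)}{p(W)} \text dW \\ &= \text{ELBO}(\theta) \end{aligned}\]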
The (analytic) ELBO for multivariate Bayesian basis function regression
Suppose:
- We are considering a deep neural network with vector outputs $y$
- $W \in \mathbb R^{K \times D}$ is the weight matrix of the last layer
- $b = 0$ for the output layer
- We don’t consider the rest of the weights of the network, so that we may write $f^W(\pmb x) = W^\top \phi(\pmb x)$, where $\phi(\pmb x)$ is the feature map implemented by the previous layers of the network
and consider the following generative story:
- Nature chose $W$ which defines a function $f^W(\pmb x) := W^\top \phi(x)$
- Then nature generated function values with inputs $x _ 1, \ldots, x _ N$ given by $f^W(x _ n)$
- These were corrupted with additive Gaussian noise $y _ n := f^W(x _ n) + \epsilon _ n$, $\epsilon _ n \sim \mathcal N(0, \sigma^2)$
- We then observe these corrupted values $\{(x _ 1, y _ 1), \ldots, (x _ N, y _ N)\}$
We have the prior:
- $p(w _ {k,d}) = \mathcal N(w _ {k,d} \mid 0, s^2)$ (i.e. each entry is Gaussian distributed)
and the likelihood:
- $P(Y \mid X, W) = \prod _ n \mathcal N(Y _ n; f^W(X _ n), \sigma^2 I _ D)$
and finally, we wish to approximate the posterior via the variational distribution $q _ {m, \sigma}(w _ {k,d})$ where
\[q _ {m, \sigma}(w _ {k, d}) = \mathcal N(w _ {k,d} \mid m _ {k, d}, \sigma^2 _ {k,d})\]
and we collect the means $m _ {k,d}$ into the matrix $M \in \mathbb R^{K \times D}$ and the standard deviations $\sigma _ {k,d}$ into $S \in \mathbb R^{K \times D}$ (note this is quite a strong restriction on the variational distribution, since it makes each $w _ {k,d}$ independent).
@Prove that the ELBO can be given analytically in this case as
\[\begin{aligned} \text{ELBO}(M, S) = &-\frac{1}{2\sigma^2} \sum _ n \left( \|Y _ n - M^\top \phi(X _ n)\| _ 2^2 + \phi(X _ n)^\top \mathrm{diag}(SS^\top)\phi(X _ n) \right) - \frac{ND}{2} \log (2\pi\sigma^2) \\ &- \sum _ {k,d} \frac{1}{2}\left( s^{-2}\sigma _ {k,d}^2 + s^{-2}m _ {k,d}^2 - 1 + \log\left(\frac{s^2}{\sigma _ {k,d}^2}\right) \right) \end{aligned}\]
Recall that the ELBO is defined (or derived) as follows:
\[\text{ELBO}(\theta) = \int q _ \theta(W) \log p(Y \mid X, W) \text dW - \text{KL}(q _ \theta(W), p(W))\]We tackle each term separately. For notational convenience, write $\phi _ n := \phi(X _ n)$.
First term:
\[\begin{aligned} &\int q(W) \log p(Y \mid X, W) \text dW \\ =& \int q(W) \sum^N _ {n = 1} \left[-\frac D 2 \log (2\pi \sigma^2) - \frac{1}{2\sigma^2} \vert \vert Y _ n - W^\top \phi _ n \vert \vert _ 2^2\right] \text dW \\ =& -\frac{ND}{2} \log (2\pi \sigma^2) - \frac{1}{2\sigma^2} \sum^N _ {n = 1} \mathbb E _ q[ \vert \vert Y _ n - W^\top \phi _ n \vert \vert ^2 _ 2] \\ =& -\frac{ND}{2} \log(2\pi \sigma^2) - \frac{1}{2\sigma^2} \sum^N _ {n = 1} \sum^D _ {d = 1} \mathbb E _ q[(Y _ {n,d} - W _ {:, d}^\top \phi _ n)^2] \\ =& -\frac{ND}{2} \log(2\pi \sigma^2) - \frac{1}{2\sigma^2} \sum^N _ {n=1} \sum^D _ {d=1} ((Y _ {n,d} - M _ {:,d}^\top \phi _ n)^2 + \text{Var} _ q(W _ {:, d}^\top \phi _ n)) \\ =& -\frac{ND}{2} \log(2\pi \sigma^2) - \frac{1}{2\sigma^2} \sum^N _ {n=1} ( \vert \vert Y _ n - M^\top \phi _ n \vert \vert ^2 _ 2 + \sum^D _ {d=1} \phi _ n^\top \text{Cov} _ q(W _ {:, d}) \phi _ n) \\ =& -\frac{ND}{2} \log(2\pi\sigma^2) - \frac{1}{2\sigma^2} \sum^N _ {n=1} ( \vert \vert Y _ n - M^\top \phi _ n \vert \vert ^2 _ 2 + \phi _ n^\top \text{diag}(SS^\top) \phi _ n) \end{aligned}\]Second term:
\[\begin{aligned} \text{KL}(q(W), p(W)) &= \int q(W) \log \frac{q(W)}{p(W)} \text dW \\ &= \sum^K _ {k = 1} \sum^D _ {d = 1} \int q(w _ {k, d}) \log \frac{q(w _ {k,d})}{p(w _ {k, d})} \text d w _ {k, d} \\ &= \sum^K _ {k=1} \sum^D _ {d=1} \text{KL}(\mathcal N(m _ {k, d}, \sigma^2 _ {k,d}), \mathcal N(0, s^2)) \\ &= \sum^K _ {k = 1} \sum^D _ {d=1} \frac 1 2 \left( \frac{\sigma^2 _ {k,d}}{s^2} + \frac{m _ {k,d}^2}{s^2} - 1 + \log \frac{s^2}{\sigma^2 _ {k, d}} \right) \end{aligned}\]Combining these together:
\[\begin{aligned} \text{ELBO}(M, S) &= \int q(W) \log p(Y \mid X, W )\text dW - \text{KL}(q(W), p(W)) \\ &= -\frac{1}{2\sigma^2} \sum^N _ {n=1} \left( \vert \vert Y _ n - M^\top \phi(X _ n) \vert \vert _ 2^2 + \phi(X _ n)^\top \text{diag}(SS^\top) \phi(X _ n) \right) - \frac{ND}{2} \log(2\pi\sigma^2) \\ &\quad\quad - \sum^K _ {k=1}\sum^D _ {d=1} \frac 1 2 \left( s^{-2} \sigma _ {k,d}^2 + s^{-2} m _ {k,d}^2 - 1 + \log \frac{s^2}{\sigma^2 _ {k,d}} \right) \end{aligned}\]as required.
Suppose:
- We are considering a deep neural network with vector outputs $y$
- $W \in \mathbb R^{K \times D}$ is the weight matrix of the last layer
- $b = 0$ for the output layer
- We don’t consider the rest of the weights of the network, so that we may write $f^W(\pmb x) = W^\top \phi(\pmb x)$, where $\phi(\pmb x)$ is the feature map implemented by the previous layers of the network
and consider the following generative story:
- Nature chose $W$ which defines a function $f^W(\pmb x) := W^\top \phi(x)$
- Then nature generated function values with inputs $x _ 1, \ldots, x _ N$ given by $f^W(x _ n)$
- These were corrupted with additive Gaussian noise $y _ n := f^W(x _ n) + \epsilon _ n$, $\epsilon _ n \sim \mathcal N(0, \sigma^2)$
- We then observe these corrupted values $\{(x _ 1, y _ 1), \ldots, (x _ N, y _ N)\}$
We have the prior:
- $p(w _ {k,d}) = \mathcal N(w _ {k,d} \mid 0, s^2)$ (i.e. each entry is Gaussian distributed)
and the likelihood:
- $P(Y \mid X, W) = \prod _ n \mathcal N(Y _ n; f^W(X _ n), \sigma^2 I _ D)$
and finally, we wish to approximate the posterior via the variational distribution $q _ {m, \sigma}(w _ {k,d})$ where
\[q _ {m, \sigma}(w _ {k, d}) = \mathcal N(w _ {k,d} \mid m _ {k, d}, \sigma^2 _ {k,d})\]
and we collect the means $m _ {k,d}$ into the matrix $M \in \mathbb R^{K \times D}$ and the standard deviations $\sigma _ {k,d}$ into $S \in \mathbb R^{K \times D}$ (note this is quite a strong restriction on the variational distribution, since it makes each $w _ {k,d}$ independent).
The ELBO can be given analytically in this case as:
\[\begin{aligned} \text{ELBO}(M, S) = &-\frac{1}{2\sigma^2} \sum _ n \left( \|Y _ n - M^\top \phi(X _ n)\| _ 2^2 + \phi(X _ n)^\top \mathrm{diag}(SS^\top)\phi(X _ n) \right) - \frac{ND}{2} \log (2\pi\sigma^2) \\ &- \sum _ {k,d} \frac{1}{2}\left( s^{-2}\sigma _ {k,d}^2 + s^{-2}m _ {k,d}^2 - 1 + \log\left(\frac{s^2}{\sigma _ {k,d}^2}\right) \right) \end{aligned}\]
Given that this yields the optimal approximate posterior as
\[p(W \mid X,Y) \approx q^\ast _ {M,S}(W) = \mathcal N(W; M^\ast, S^\ast)\]
where $M^\ast, S^\ast$ are the parameters that maximise the ELBO, what does the predictive distribution $p(y^\ast \mid x^\ast, X, Y)$ look like? What is this similar to?
The approximate predictive distribution is obtained by integrating the likelihood against $q^\ast _ {M,S}$ instead of the exact posterior:
\[p(y^\ast \mid x^\ast, X, Y) \approx \int q^\ast _ {M, S}(W) \, p(y^\ast \mid x^\ast, W) \, \text dW\]Since $q^\ast _ {M,S}$ is Gaussian and $f^W(x^\ast) = W^\top \phi(x^\ast)$ is linear in $W$, this is again Gaussian, with mean $M^{\ast\top} \phi(x^\ast)$ and, for each output dimension, variance equal to $\sigma^2$ plus the variational weight variances propagated through $\phi(x^\ast)$. This is almost identical to the predictive distribution in the exact Gaussian posterior setting, but with the approximate mean and (diagonal) covariance in place of the exact ones.
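A minimal NumPy sketch of the analytic ELBO above and of this approximate predictive (shapes and names are assumed here for illustration; `S` is taken to hold the standard deviations $\sigma _ {k,d}$, matching the $\mathrm{diag}(SS^\top)$ term, and the per-dimension predictive variance follows from the diagonal Gaussian $q^\ast$ rather than being stated explicitly in the lecture):

```python
import numpy as np

def elbo(M, S, Phi, Y, sigma2, s2):
    """Analytic ELBO(M, S). Phi: (N, K) features phi(X_n); Y: (N, D) targets;
    M, S: (K, D) variational means and standard deviations."""
    N, D = Y.shape
    resid = Y - Phi @ M                                  # (N, D): Y_n - M^T phi(X_n)
    row_var = np.sum(S ** 2, axis=1)                     # diag(S S^T)_k = sum_d sigma_{k,d}^2
    data_term = (-(np.sum(resid ** 2) + np.sum((Phi ** 2) @ row_var)) / (2 * sigma2)
                 - 0.5 * N * D * np.log(2 * np.pi * sigma2))
    kl_term = 0.5 * np.sum(S ** 2 / s2 + M ** 2 / s2 - 1 + np.log(s2 / S ** 2))
    return data_term - kl_term

def predictive(phi_star, M, S, sigma2):
    """Approximate predictive mean and per-dimension variance at phi_star: (K,),
    obtained by integrating the Gaussian likelihood against q*_{M,S}."""
    mean = M.T @ phi_star                                # (D,)
    var = sigma2 + (S ** 2).T @ (phi_star ** 2)          # (D,): noise + propagated weight variance
    return mean, var
```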