Lecture - Uncertainty in Deep Learning MT25, Stochastic approximate inference in DNNs
In this lecture, we use the mathematical tools developed in [[Lecture - Uncertainty in Deep Learning MT25, Some very useful mathematical tools]] to derive an update rule for the approximate posterior in a shallow neural network, and then show how this generalises to deep neural networks.
Reparameterisation trick in a shallow network
Suppose:
- We are considering a deep neural network with vector outputs $\pmb y \in \mathbb R^D$
- $W \in \mathbb R^{K \times D}$ is the weight matrix of the last layer
- The bias of the output layer is $b = 0$
- We treat the remaining weights of the network as fixed, so that we may write $f^W(\pmb x) = W^\top \phi(\pmb x)$, where $\phi(\pmb x) \in \mathbb R^K$ is the feature map implemented by the previous layers of the network
and consider the following generative story:
- Nature chose $W$, which defines a function $f^W(\pmb x) := W^\top \phi(\pmb x)$
- Then nature evaluated this function at the inputs $\pmb x _ 1, \ldots, \pmb x _ N$ to obtain function values $f^W(\pmb x _ n)$
- These were corrupted with additive Gaussian noise: $\pmb y _ n := f^W(\pmb x _ n) + \pmb \epsilon _ n$, where $\pmb \epsilon _ n \sim \mathcal N(\pmb 0, \sigma^2 I _ D)$
- We then observe the corrupted values $\{(\pmb x _ 1, \pmb y _ 1), \ldots, (\pmb x _ N, \pmb y _ N)\}$
We have the prior:
- $p(w _ {k,d}) = \mathcal N(w _ {k,d} \mid 0, s^2)$ (i.e. each entry of $W$ is independently Gaussian distributed)
and the likelihood:
- $p(Y \mid X, W) = \prod _ n \mathcal N(\pmb y _ n \mid f^W(\pmb x _ n), \sigma^2 I _ D)$
and finally, we wish to approximate the posterior over $W$ with the factorised variational distribution $q _ \theta(W) = \prod _ {k,d} q _ \theta(w _ {k,d})$, where
\[q _ {\theta}(w _ {k, d}) = \mathcal N(w _ {k,d} \mid m _ {k, d}, s^2 _ {k,d})\]
and we collect the means and standard deviations into the matrices $M \in \mathbb R^{K \times D}$ and $S \in \mathbb R^{K \times D}$, writing $\theta := \{M, S\}$ for the variational parameters (note this is quite a strong restriction on the variational family, since each $w _ {k,d}$ is independent under $q _ \theta$).
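To make this setup concrete, here is a minimal sketch in JAX that samples a dataset from the generative story above and evaluates $\log p(Y \mid X, W)$. The feature map `phi`, the input dimension, and all numerical values are illustrative assumptions rather than anything fixed by the lecture.

```python
import jax
import jax.numpy as jnp

# Placeholder dimensions and noise scales (illustrative assumptions).
K, D, N = 8, 2, 100      # feature dimension, output dimension, number of observations
s, sigma = 1.0, 0.1      # prior std s and observation noise std sigma

def phi(x):
    # Stand-in for the feature map implemented by the earlier (fixed) layers.
    return jnp.tanh(x @ jnp.ones((x.shape[-1], K)))

key = jax.random.PRNGKey(0)
k_w, k_x, k_eps = jax.random.split(key, 3)

# Generative story: nature draws W from the prior, evaluates f^W at the inputs,
# and corrupts the function values with additive Gaussian noise.
W_true = s * jax.random.normal(k_w, (K, D))               # p(w_kd) = N(0, s^2)
X = jax.random.normal(k_x, (N, 3))                        # arbitrary 3-dimensional inputs
Y = phi(X) @ W_true + sigma * jax.random.normal(k_eps, (N, D))

def log_likelihood(W, X, Y):
    # log p(Y | X, W) = sum_n log N(y_n | W^T phi(x_n), sigma^2 I_D)
    resid = Y - phi(X) @ W
    return jnp.sum(-0.5 * (resid / sigma) ** 2 - 0.5 * jnp.log(2 * jnp.pi * sigma ** 2))
```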
Recall that we fit $\theta$ by maximising the ELBO, which we may write as
\[\text{ELBO}(\theta) = \mathbb E _ {q _ {\theta}(W)} [\log p(Y \mid X, W)] - \text{KL}(q _ \theta(W), p(W))\]
We have shown previously that $\text{KL}(q _ \theta(W), p(W))$ has a closed-form expression, so the tricky part is approximating the gradient of $\mathbb E _ {q _ {\theta}(W)} [\log p(Y \mid X, W)]$ with respect to $\theta$.
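For reference, since both $q _ \theta$ and the prior factorise over the entries of $W$, the KL term is a sum of univariate Gaussian KL divergences, and by the standard identity for the KL between two Gaussians it evaluates to
\[\text{KL}(q _ \theta(W), p(W)) = \sum _ {k,d} \left( \log \frac{s}{s _ {k,d}} + \frac{s _ {k,d}^2 + m _ {k,d}^2}{2 s^2} - \frac{1}{2} \right)\]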
Derive an estimate $\hat G(\theta, \hat \epsilon)$ of this gradient using the reparameterisation trick.
Since each $w _ {k,d}$ is an independent Gaussian under $q _ \theta$, we can reparameterise a sample from $q _ \theta(W)$ as a deterministic transformation of parameter-free noise:
\[W = g(\theta, \epsilon) := M + S \odot \epsilon, \quad \epsilon _ {k,d} \sim \mathcal N(0, 1) \text{ i.i.d.}\]
where $\odot$ denotes the elementwise product. Substituting this into the expectation,
\[\mathbb E _ {q _ \theta(W)} [\log p(Y \mid X, W)] = \mathbb E _ {p(\epsilon)} [\log p(Y \mid X, g(\theta, \epsilon))]\]
and since $p(\epsilon)$ does not depend on $\theta$, we may exchange the gradient and the expectation:
\[\nabla _ \theta \, \mathbb E _ {q _ \theta(W)} [\log p(Y \mid X, W)] = \mathbb E _ {p(\epsilon)} [\nabla _ \theta \log p(Y \mid X, g(\theta, \epsilon))]\]
Drawing a single $\hat \epsilon \sim p(\epsilon)$ therefore gives the unbiased one-sample Monte Carlo estimate
\[\hat G(\theta, \hat \epsilon) := \nabla _ \theta \log p(Y \mid X, M + S \odot \hat \epsilon)\]
and averaging over several draws of $\hat \epsilon$ reduces its variance.
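A minimal sketch of this estimator, reusing the hypothetical `phi`, `log_likelihood`, `X`, `Y`, `K` and `D` from the earlier sketch and a single Monte Carlo sample (an illustration of the estimator, not a full training loop):

```python
def expected_ll_sample(theta, eps, X, Y):
    # One-sample estimate of E_{q_theta(W)}[log p(Y | X, W)],
    # using the reparameterisation W = g(theta, eps) = M + S * eps.
    M, S = theta
    W = M + S * eps
    return log_likelihood(W, X, Y)

# hat G(theta, hat eps) = grad_theta log p(Y | X, M + S * hat eps)
G_hat = jax.grad(expected_ll_sample)

theta = (jnp.zeros((K, D)), 0.1 * jnp.ones((K, D)))    # initial M and S (arbitrary values)
eps_hat = jax.random.normal(jax.random.PRNGKey(1), (K, D))
grad_M, grad_S = G_hat(theta, eps_hat, X, Y)           # unbiased estimate of the gradient w.r.t. (M, S)
```

To optimise the ELBO one would subtract the gradient of the analytic KL term from this estimate and pass the result to a stochastic optimiser; resampling $\hat \epsilon$ at every step (or averaging over several draws) reduces the variance of the updates.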