## Neural Processes for Image Completion

*Presented by Christabella Irwanto*


## Agenda

- Motivation
- Related work
- Neural process variants
- Experiments
- Key learnings
- Conclusion

## Motivation

- Two popular approaches to function approximation in machine learning:
  - **Neural networks (NNs)**
  - Bayesian inference on stochastic processes, most commonly **Gaussian processes (GPs)**
- Goal: combine the best of GPs and NNs

### Deep neural networks (NNs)

- Train a parametric function via gradient descent
- Pro: excel at function approximation
- Con: tend to need a lot of training data
  - *(Why?)* Prior knowledge can only be specified in rather limited ways, e.g. via the architecture
- Pro: fast forward passes at test time
- Con: inflexible; need to be trained from scratch for new functions

### Gaussian processes (GPs)

- Compute a prior distribution over functions, update beliefs with data, and sample from the distribution for predictions
- Allow **reasoning about multiple functions**
  - Capture covariability in outputs given inputs
- Pro: encode knowledge in the prior stochastic process
- Con: kernels are predefined in restricted functional forms
- Con: computationally **expensive** for large datasets

### Best of both worlds

- Neural processes ((Garnelo et al., 2018), (Garnelo et al., 2018)) are a class of NN-based models that approximate stochastic processes
- Like GPs, NPs
  - model distributions over functions
  - provide uncertainty estimates
  - are more flexible at test time
- Unlike GPs, NPs
  - learn an implicit kernel directly from the data
  - are computationally efficient at training and prediction (like NNs)

### Meta-learning

- Standard deep learning approximates a single function
- Meta-learning algorithms instead try to approximate distributions over functions
- Often implemented as deep generative models that perform few-shot estimation of the data's underlying density, as NPs do
- NPs constitute a high-level abstraction, **reusable for multiple tasks**

### Image completion task

Figure 5 from (Garnelo et al., 2018)

- Regression over functions \(f: [0,1]^2 \rightarrow [0,1]^C\)
  - \(f\) accepts **pixel coordinate pairs** normalized to \([0, 1]\)
  - \(f\) outputs **channel intensities** at the input coordinates
- The task is not an end in itself (as a generative model)
- Complex 2-D functions that are easy to evaluate visually
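The coordinate/intensity encoding above can be sketched as a small helper (a minimal NumPy sketch; the name `image_to_pairs` is ours, assuming an `(H, W, C)` image array with values in \([0, 1]\)):

```python
import numpy as np

def image_to_pairs(img):
    """Convert an image of shape (H, W, C) with values in [0, 1]
    into (coordinate, intensity) pairs for NP-style regression."""
    h, w, c = img.shape
    # Pixel coordinates, normalized to [0, 1]^2.
    ys, xs = np.meshgrid(np.arange(h) / (h - 1),
                         np.arange(w) / (w - 1), indexing="ij")
    coords = np.stack([ys, xs], axis=-1).reshape(-1, 2)  # (H*W, 2)
    values = img.reshape(-1, c)                          # (H*W, C)
    return coords, values

coords, values = image_to_pairs(np.zeros((28, 28, 1)))
print(coords.shape, values.shape)  # (784, 2) (784, 1)
```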

## Neural process variants

- Conditional neural process (CNP)
- Latent variable model of CNP
- Neural process (NP)
- Attentive neural process (ANP)

### Conditional neural process (CNP)

Architecture of CNP from (DeepMind, 2018)

- Accepts **pairs of context points** \((x_i, y_i)\) (coordinates and pixel intensities)
- Each \((x_i, y_i)\) is passed through encoder \(e\) to get an **individual representation** \(r_i\)
- Aggregator \(a\) combines the \(r_i\) into a **global representation** \(r\)
- Decoder \(d\) accepts both \(r\) and the **target location** \(x_T\)
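The encoder/aggregator/decoder pipeline might be sketched in PyTorch as follows (a minimal sketch; layer sizes and the softplus bound on sigma are illustrative choices, not the paper's):

```python
import torch
import torch.nn as nn

class CNP(nn.Module):
    """Minimal CNP sketch: encoder e, mean aggregator a, decoder d."""
    def __init__(self, x_dim=2, y_dim=1, r_dim=128):
        super().__init__()
        self.e = nn.Sequential(nn.Linear(x_dim + y_dim, 128), nn.ReLU(),
                               nn.Linear(128, r_dim))
        self.d = nn.Sequential(nn.Linear(r_dim + x_dim, 128), nn.ReLU(),
                               nn.Linear(128, 2 * y_dim))  # mu and raw sigma

    def forward(self, x_c, y_c, x_t):
        # Individual representations r_i for each context pair (x_i, y_i).
        r_i = self.e(torch.cat([x_c, y_c], dim=-1))
        # Aggregator a: permutation-invariant mean over context points.
        r = r_i.mean(dim=0, keepdim=True).expand(x_t.shape[0], -1)
        # Decoder d takes the global representation r and target locations x_T.
        mu, raw = self.d(torch.cat([r, x_t], dim=-1)).chunk(2, dim=-1)
        sigma = 0.1 + 0.9 * torch.nn.functional.softplus(raw)  # keep sigma > 0
        return mu, sigma

mu, sigma = CNP()(torch.randn(10, 2), torch.randn(10, 1), torch.randn(5, 2))
```

Note the mean aggregation: the prediction is invariant to the order of the context points, which matches the factorized behaviour discussed below.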

### Training of CNP

- We have a set of observations \(O = \{(x_i, y_i)\}_{i=0}^{n-1} \subset X \times Y\)
- The CNP is an approximate conditional stochastic process \(Q_\theta\)
- We train \(Q_\theta\) to predict \(O\) conditioned on a random subset \(O_N = \{(x_i, y_i)\}_{i=0}^{N}\) of size \(N \sim \text{uniform}[0, \dotsc, n-1]\)
- Minimize the negative conditional log-likelihood of the targets' ground-truth outputs \(y_i\) under the predicted distribution \(f: X \rightarrow Y\) with \(f(x_i) = \mathcal{N}(\mu_i, \sigma_i^2)\)

\[ \mathcal{L}(\theta) = - \mathbb{E}_{f \sim P} \left[ \mathbb{E}_{N} \left[ \log Q_{\theta} \left( \{ y_i \}_{i=0}^{n-1} \mid O_N, \{ x_i \}_{i=0}^{n-1} \right) \right] \right] \] (Eq. 4 in (Garnelo et al., 2018))

```
# The decoder outputs mu and sigma for each target point.
dist = tf.contrib.distributions.MultivariateNormalDiag(
    loc=mu, scale_diag=sigma)
# Negative log-likelihood of the target outputs under the prediction.
log_prob = dist.log_prob(target_y)
loss = -tf.reduce_mean(log_prob)
train_step = optimizer.minimize(loss)
```
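The random context subsetting used during training might be sketched like this (a NumPy sketch; here \(N\) is drawn from \(1, \dotsc, n\) to avoid an empty context set, a common implementation choice):

```python
import numpy as np

def sample_context(coords, values, rng):
    """Draw a random context subset O_N of the full observation set O."""
    n = coords.shape[0]
    N = int(rng.integers(1, n + 1))  # random context size
    idx = rng.permutation(n)[:N]     # random context indices
    return coords[idx], values[idx]

rng = np.random.default_rng(0)
coords, values = np.random.rand(784, 2), np.random.rand(784, 1)
x_c, y_c = sample_context(coords, values, rng)
```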

### CNP paper results

Figure 3a from CNP (Garnelo et al., 2018)

- Outperforms kNN and GP (in MSE) when the number of context points is below 50%
- The order of context points matters little (consistent with the factorized model)
- \(\therefore\) CNPs learn a good data-specific "prior" (few points needed)
- **Uncertainty estimation**: variance becomes localized to the digit's edges
- But… a CNP *cannot produce coherent samples*
  - \(\mu_i\) and \(\sigma_i\) will always be the same given the same context \(\{ (x_i, y_i) \}_{i=1}^{C}\)
  - Even if the observed pixels could conform to a 1, 4, or 7, the CNP can only output an incoherent image of the mean and variance of all those digits

### Latent formulation of CNP

Architecture of CNP latent variable model, modified from (DeepMind, 2018)

- Let \(r\) parameterize the distribution of a global latent variable, \(z \sim \mathcal{N}(\mu(r), I\sigma(r))\)
- \(z\) is sampled once and passed into decoder \(d\)
- \(z\) captures the **global uncertainty**
  - Allows sampling at the global level (one function at a time) rather than locally (one output \(y_i\) for each input \(x_i\) at a time)
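One way \(r\) can parameterize \(z\) is through two linear heads that map \(r\) to the Gaussian's mean and log-variance (an assumption matching common open-source implementations, not spelled out in the paper):

```python
import torch
import torch.nn as nn

r_dim = z_dim = 128
# Hypothetical heads mapping the global representation r to the
# parameters of the Gaussian latent z ~ N(mu(r), I * sigma(r)).
r_to_z_mean = nn.Linear(r_dim, z_dim)
r_to_z_logvar = nn.Linear(r_dim, z_dim)

r = torch.randn(1, r_dim)                 # aggregated representation
mu, logvar = r_to_z_mean(r), r_to_z_logvar(r)
sigma = torch.exp(0.5 * logvar)
z = mu + sigma * torch.randn_like(sigma)  # sampled once per function
```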

### Implications for image completion

Figure 6 from (Garnelo et al., 2018)

- Latent CNPs (and NPs) can produce **different function samples** for the same conditioned observations
  - They produce different digits/faces when sampled
  - Whereas CNPs just output (roughly) the mean of all digits/faces

### Training of latent CNP

\[\text {ELBO} = \mathbb{E}_{q(z | C, T)} \left[ \sum_{t=1}^T \log p(y_t^{\ast} | z, x_t^{\ast}) + \log \frac{q(z | C)}{q(z | C, T)} \right]\]

- As in a VAE, maximize a variational lower bound on the log-likelihood
- First term: expected log-likelihood of \(y_t\) under the approximate posterior
- Second term: negative **KL divergence** between prior and posterior
  - Instead of a Gaussian prior \(p(z)\), we have a conditional prior \(p(z|C)\) on the context \(C \subset X \times Y\)
  - We cannot access \(p(z|C)\) since it depends on encoder \(h\); use the approximation \(q(z|C)\)
  - The Gaussian posterior \(q(z|C, T)\) conditions on both the context \(C\) and the targets \(T \subset X\)
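Putting the two ELBO terms together, a per-batch loss might be sketched as follows (a sketch; `np_loss`, `q_ct`, and `q_c` are our names for the loss, \(q(z|C,T)\), and \(q(z|C)\)):

```python
import torch
from torch.distributions import Normal, kl_divergence

def np_loss(mu_t, sigma_t, y_t, q_ct, q_c):
    """Negative ELBO: reconstruction term over the targets plus
    KL(q(z|C,T) || q(z|C)), i.e. the two terms in the bound above."""
    log_lik = Normal(mu_t, sigma_t).log_prob(y_t).sum()
    kl = kl_divergence(q_ct, q_c).sum()
    return -log_lik + kl

q_c = Normal(torch.zeros(64), torch.ones(64))              # approx. prior q(z|C)
q_ct = Normal(0.1 * torch.ones(64), 0.9 * torch.ones(64))  # posterior q(z|C,T)
loss = np_loss(torch.zeros(10, 1), torch.ones(10, 1),
               torch.zeros(10, 1), q_ct, q_c)
```

During training the loss is minimized, which maximizes the ELBO.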

### Neural process (NP)

NP diagram from (DeepMind, 2018)

- NPs are a generalization of CNPs, **very similar to the latent CNP**
- **But** \(e\) is split into two encoders, producing the global representation \(r_c\) and a separate code \(s_c\) that parameterises \(z\) (instead of directly parameterising \(z\) with \(r\))
- \(d\) also accepts \(r\), not only the target location \(x_*\) (i.e. \(x_T\)) and \(z\)

### NPs underfit

Figure 1 from (Kim et al., 2019)

- Con: NPs tend to underfit the context set
  - Inaccurate predictive means and overestimated variances
  - Imperfect reconstruction of the top half of the image
- Mean aggregation in the encoder acts as a bottleneck
  - The **same weight** is given to each context point
  - Difficult for the decoder to **learn which context points are relevant** for a given target

### Attentive neural process (ANP)

- The aggregator is replaced by cross-attention
  - Each query can **attend more closely to relevant context points**
  - Uniform attention would recover the NP's mean aggregation
- The encoder MLPs are replaced with self-attention
  - To **model interactions between context points**
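Scaled dot-product cross-attention over context points might be sketched as follows (a minimal NumPy sketch; the ANP additionally passes the locations through learned key/query transformations, which we omit here):

```python
import numpy as np

def cross_attention(x_t, x_c, r_c):
    """Target locations x_t (queries) attend over context locations x_c
    (keys) and context representations r_c (values). Uniform weights
    would recover the NP's mean aggregation."""
    scores = x_t @ x_c.T / np.sqrt(x_c.shape[-1])   # (T, C)
    scores -= scores.max(axis=-1, keepdims=True)    # numerical stability
    weights = np.exp(scores)
    weights /= weights.sum(axis=-1, keepdims=True)  # softmax over context
    return weights @ r_c                            # (T, r_dim)

x_c, r_c = np.random.rand(20, 2), np.random.rand(20, 128)
r_star = cross_attention(np.random.rand(5, 2), x_c, r_c)
print(r_star.shape)  # (5, 128)
```

Each target thus gets its own query-specific representation \(r_*\) instead of a single shared \(r\).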

### Related work

Modified from Figure 2 in (Garnelo et al., 2018)

- a) and b) **do not allow for targeted sampling** at positions \(x_T\)
  - They can only generate i.i.d. samples from the estimated output density of \(y_T\)
- The CNP **cannot model global uncertainty**, i.e. a distribution over functions; it gives only one \(y_T\) given \(y_C, x_C, x_T\)
- c), d), and e) can sample different functions
- GPs contain a kernel predicting the covariance between \(C\) and \(T\), forming a multivariate Gaussian from which samples are drawn

## Experiments

- CNP latent variable model in PyTorch for \(28 \times 28\) monochrome MNIST images

```
(e_1): Linear(in_features=3, out_features=400, bias=True)
(e_2): Linear(in_features=400, out_features=400, bias=True)
(e_3): Linear(in_features=400, out_features=128, bias=True)
(r_to_z_mean): Linear(in_features=128, out_features=128, bias=True)
(r_to_z_logvar): Linear(in_features=128, out_features=128, bias=True)
(d_1): Linear(in_features=130, out_features=400, bias=True)
(d_2): Linear(in_features=400, out_features=400, bias=True)
(d_3): Linear(in_features=400, out_features=400, bias=True)
(d_4): Linear(in_features=400, out_features=400, bias=True)
(d_5): Linear(in_features=400, out_features=1, bias=True)
```

- Architecture modified from unofficial PyTorch implementation (geniki, 2018)
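Based on the listed layer sizes, the wiring is presumably as follows (an assumption: the 130-dimensional decoder input is the 128-dimensional \(z\) concatenated with a 2-dimensional target coordinate):

```python
import torch
import torch.nn as nn

class LatentCNP(nn.Module):
    """Wiring sketch matching the listed layer sizes; the connections
    themselves are our assumption, not spelled out in the listing."""
    def __init__(self):
        super().__init__()
        self.encoder = nn.Sequential(        # e_1..e_3: (x, y, intensity) -> r_i
            nn.Linear(3, 400), nn.ReLU(),
            nn.Linear(400, 400), nn.ReLU(),
            nn.Linear(400, 128))
        self.r_to_z_mean = nn.Linear(128, 128)
        self.r_to_z_logvar = nn.Linear(128, 128)
        self.decoder = nn.Sequential(        # d_1..d_5: [z, coords] -> intensity
            nn.Linear(130, 400), nn.ReLU(),  # 130 = 128 (z) + 2 (coordinates)
            nn.Linear(400, 400), nn.ReLU(),
            nn.Linear(400, 400), nn.ReLU(),
            nn.Linear(400, 400), nn.ReLU(),
            nn.Linear(400, 1))

model = LatentCNP()
r_i = model.encoder(torch.randn(10, 3))   # per-point representations
y = model.decoder(torch.randn(4, 130))    # predicted intensities
```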

### Training

400 epochs took 7 hours on an NVIDIA Quadro P5000 GPU (Paniikki computer)

### Results

- Blue pixels are unobserved; 5 samples per set of context points

#### 10 and 100 context points at epoch 400

- With only 10 context points, generated images are sometimes incoherent
- Increasing to 100 results in more coherent images

#### 300 context points

- Only 38% of the pixels are observed, so there is still **some variation among the different coherent samples drawn**
  - E.g. in the last column, images of both "3" and "9" are generated that conform to those 300 observed pixels

#### 784 (100%) context points

- When fully observed, **uncertainty is reduced** and the samples converge to a single estimate
- Even with full observation, pixel-perfect reconstruction is not achieved, due to the bottleneck at the representation level

### What about those nice celebrity faces?

- Unable to reproduce results on more complex \(64 \times 64\) RGB CelebA images
- Problems with convergence and lack of implementation details in CNP paper

## Key learnings and findings

- The paper omits many implementation details; some can be filled in with experience and some cannot
- Would have been lost without open-source implementations
  - E.g. how does \(r\) parameterize \(z\)? Through further MLP layers that map \(r\) to the parameters of the Gaussian over \(z\)
  - Tricks like `sigma = 0.1 + 0.9 * tf.sigmoid(log_sigma)` to avoid a zero standard deviation

- Many ways to achieve similar outcomes
- E.g. using or not using VAE reparameterization trick when sampling latent \(z\)

## Conclusion

- Incorporates ideas from GP inference into an NN training regime to overcome certain drawbacks of both
- Possible to **achieve desirable GP-like behaviour**, e.g. predictive uncertainties and encapsulating high-level statistics of a family of functions
  - However, **challenging to interpret**, since the effects are implicit
- NPs are ultimately still neural-network-based approximations and suffer from problems such as the difficulty of hyperparameter tuning and model selection

## Lingering doubts

- Section 3.2 (Meta-learning) of the CNP paper: "the latent variant of CNP can also be seen as an approximated amortized version of Bayesian DL (Gal & Ghahramani, 2016; Blundell et al., 2015; Louizos et al., 2017; Louizos & Welling, 2017)"
  - Dropout as a Bayesian approximation; Weight uncertainty in neural networks; Bayesian compression for DL; Multiplicative normalizing flows for variational Bayesian neural networks
- "CNPs can only model a factored prediction of the mean and the variances, disregarding covariance between target points"
- "Since CNP returns factored outputs, the best prediction it can produce given limited context information is to average over all possible predictions that agree with the context."

### Inference in NP

- What does the approximate posterior look like in terms of the network architecture?
- How does the conditional prior \(p(z|C)\) depend on encoder \(h\) (making it intractable)?

# Bibliography

Wikipedia contributors (2019). *Gaussian process — Wikipedia, the free encyclopedia*. Retrieved from https://en.wikipedia.org/w/index.php?title=Gaussian_process&oldid=896566703. [Online; accessed 16-May-2019].

Garnelo, M., Schwarz, J., Rosenbaum, D., Viola, F., Rezende, D. J., Eslami, S. M. A., & Teh, Y. W. (2018). *Neural Processes*. CoRR.

Garnelo, M., Rosenbaum, D., Maddison, C. J., Ramalho, T., Saxton, D., Shanahan, M., Teh, Y. W., … (2018). *Conditional Neural Processes*. CoRR.

DeepMind (2018). *Open-source implementations of neural process variants*. Retrieved from https://github.com/deepmind/neural-processes. (Accessed on 05/08/2019).

Kim, H., Mnih, A., Schwarz, J., Garnelo, M., Eslami, A., Rosenbaum, D., Vinyals, O., … (2019). *Attentive Neural Processes*. CoRR.

geniki (2018). *PyTorch implementation of neural processes*. Retrieved from https://github.com/geniki/neural-processes/. (Accessed on 05/17/2019).

Märtens, K. (2018). *Neural processes as distributions over functions*. Retrieved from https://kasparmartens.rbind.io/post/np/. (Accessed on 05/17/2019).