Post-hoc Uncertainty Estimation for Bayesian Deep Learning
The methods developed in the previous chapter introduced scalable Bayesian inference directly in function space through Deep Variational Implicit Processes. However, in many real-world scenarios, a high-performing deterministic network is already available, and retraining it within a full Bayesian framework is impractical. This chapter therefore shifts focus from end-to-end Bayesian training to post-hoc uncertainty estimation—attaching calibrated predictive distributions to existing neural networks without modifying their original parameters. Two complementary approaches are proposed: the Variational Linearized Laplace Approximation (VaLLA) and the Fixed-Mean Gaussian Process (FMGP), both of which provide principled and computationally efficient routes to retrofit uncertainty into modern deep learning models.
Introduction
Deep neural networks (DNNs) have become the de facto solution for a broad range of pattern-recognition tasks, consistently achieving state-of-the-art accuracy in areas such as computer vision and natural language processing (He et al., 2016; Vaswani et al., 2017). However, conventional (deterministic) DNNs deliver only point predictions; they are frequently miscalibrated (Guo et al., 2017) and represent epistemic uncertainty poorly (Blundell et al., 2015). These shortcomings are unacceptable in risk-sensitive applications such as autonomous driving (Kendall & Gal, 2017) and medical decision support (Leibig et al., 2017), where a faithful quantification of predictive confidence is essential.
A principled remedy is to cast the network weights as random variables and perform Bayesian inference, yielding a Bayesian neural network (BNN) (Graves, 2011; MacKay, 1992a; Neal, 2012). Because exact posterior inference is intractable in modern architectures, one resorts to approximations such as variational inference (VI) (Blundell et al., 2015), Markov chain Monte Carlo (MCMC) (Chen et al., 2014), or the Laplace approximation (LA) (MacKay, 1992b; Ritter et al., 2018). Empirically, VI and MCMC often underperform plain back-propagation (Wenzel et al., 2020), whereas LA strikes a favorable balance between accuracy and cost: starting from the maximum-a-posteriori (MAP) solution obtained by standard training, it fits a Gaussian posterior whose mean is the MAP estimate and whose covariance is the inverse of the negative Hessian of the log-posterior evaluated at that estimate.
Forming and inverting the full Hessian is infeasible in large networks, so practitioners substitute the Generalized Gauss–Newton (GGN) matrix (Martens & Grosse, 2015). Although the resulting Laplace posterior may underfit (Lawrence, 2001), this effect is alleviated when the model is linearized at the MAP parameters. This insight underlies the Linearized Laplace Approximation (LLA) (Foong et al., 2019; Immer et al., 2021), which applies LA to the first-order Taylor expansion of the network. LLA preserves the original predictive mean, thus retaining baseline performance, while attaching principled error bars. Moreover, the GGN is positive semi-definite at any parameter value (and positive definite once the prior precision is added), which allows post-hoc application at arbitrary weight configurations.
Despite these advantages, LLA is not yet turnkey for industrial-scale problems: computing the Jacobian of the network output with respect to millions of parameters for every training instance is prohibitively expensive. Consequently, further structural approximations, e.g. diagonal or Kronecker-factored (KFAC) representations of the GGN, or data sub-sampling are employed, sacrificing some statistical fidelity for tractability. Designing methods that match the calibration quality of Bayesian approaches while scaling gracefully to contemporary architectures therefore remains an active and important research challenge.
Contributions of this chapter. To overcome the scalability bottleneck of LLA while preserving its favorable calibration properties, two recent lines of work are examined:
Variational LLA (VaLLA). LLA’s predictive distribution is reinterpreted as that of a Gaussian process (GP) (Khan et al., 2019). A sparse variational GP with inducing points (Titsias, 2009) is then employed only for the covariance, while working in the dual RKHS representation (see Section) keeps the mean exactly equal to the original DNN output. VaLLA supports stochastic mini-batch optimization and attains sub-linear cost in the data size, outperforming alternative scalable-LA surrogates such as Kronecker/diagonal GGN factorizations, Nyström-based ELLA (Deng et al., 2022), and sample-then-optimize schemes (Antorán et al., 2023).
Fixed-Mean Gaussian Processes (FMGPs). Extending the RKHS perspective (see Section) further, a new family of sparse GPs is introduced, whose posterior mean is fixed to an arbitrary continuous function—e.g. the predictions of a pre-trained high-performance DNN—while the posterior covariance is learned via VI using decoupled inducing points (Cheng & Boots, 2016). FMGP thus converts any pre-trained network into a Bayesian predictor without requiring Jacobians or direct access to internal parameters. Scalability to ImageNet-scale vision models and to molecular-property prediction (QM9) is shown, achieving uncertainty estimates competitive with Hamiltonian Monte Carlo on small problems but at orders-of-magnitude lower cost.
Together, these two methodologies illustrate how LLA’s theoretical appeal can be retained in large-scale, practical settings through judicious kernel approximations and RKHS duality, paving the way for reliable uncertainty quantification in modern deep learning.
Variational Linearized Laplace Approximation
Building on the Gaussian-process (GP) foundations laid out in Section, this section introduces the Variational Linearized Laplace Approximation (VaLLA). The method inherits (i) the scalability of variational sparse GPs and (ii) the post-hoc nature of the linearized Laplace approximation (LLA), so that the uncertainty estimates of a pre-trained deep model can be improved through a sparse approximation to LLA.
As reviewed in Section, variational sparse GPs approximate the GP posterior using a GP parameterized by \(M\) inducing points \(\mathbf{Z}\), each in \(\mathbb{R}^D\), and the associated process values \(\mathbf{u} = f(\mathbf{Z})\) (Titsias, 2009), \begin{equation} P(\mathbf{f}, \mathbf{u} | \mathbf{y}) \approx Q(\mathbf{f}, \mathbf{u}) = P(\mathbf{f} | \mathbf{u})Q(\mathbf{u})\,, \end{equation} where \(\mathbf{f}=f(\mathbf{X})\), \(Q(\mathbf{u}) = \mathcal{N}(\mathbf{u}|\hat{\bm{m}}, \hat{\bm{S}})\) is the variational distribution, and the conditional prior \(P(\mathbf{f} | \mathbf{u})\) is kept fixed. The approximate posterior \(Q(\mathbf{u})\) is obtained by minimizing the KL divergence \(\text{KL}\left(Q(\mathbf{f}, \mathbf{u}) | P(\mathbf{f}, \mathbf{u}| \mathbf{y}) \right)\). In practice, this minimization is recast as the maximization of a lower bound on the log-marginal likelihood, \begin{equation} \log P(\mathbf{y}) \geq \max_{\mathbf{Z}, \hat{\bm{m}}, \hat{\bm{S}}} \int Q(\mathbf{f}, \mathbf{u}) \log \frac{P(\mathbf{y}|\mathbf{f}) P(\mathbf{f}| \mathbf{u})P(\mathbf{u})}{Q(\mathbf{f}, \mathbf{u})} \ \mathrm{d} (\mathbf{f}, \mathbf{u})\,, \end{equation} which has cost \(\mathcal{O}(NM^2+M^3)\) because the factor \(P(\mathbf{f}|\mathbf{u})\) cancels, given that \(Q(\mathbf{f}, \mathbf{u}) = P(\mathbf{f} | \mathbf{u})Q(\mathbf{u})\). The following result connects SVGPs with a subset of the Gaussian measures on the RKHS, building on the RKHS theory presented in Section.
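Before stating that result, the bound can be made concrete. The sketch below assumes a Gaussian likelihood with noise variance `noise_var` and a unit-variance squared-exponential kernel (illustrative choices, not prescribed by the text); the helper names `rbf`, `gauss_kl` and `svgp_elbo` are hypothetical. It evaluates the bound using only the marginals of \(Q(\mathbf f)\), so the largest operations are an \(M\times M\) solve and \(N\times M\) products, in line with the \(\mathcal{O}(NM^2+M^3)\) cost.

```python
import numpy as np


def rbf(X1, X2, lengthscale=1.0, variance=1.0):
    """Squared-exponential kernel matrix between the rows of X1 and X2."""
    sq_dist = ((X1[:, None, :] - X2[None, :, :]) ** 2).sum(axis=-1)
    return variance * np.exp(-0.5 * sq_dist / lengthscale**2)


def gauss_kl(m, S, K):
    """KL( N(m, S) || N(0, K) ) between M-dimensional Gaussians."""
    M = K.shape[0]
    _, logdet_K = np.linalg.slogdet(K)
    _, logdet_S = np.linalg.slogdet(S)
    return 0.5 * (np.trace(np.linalg.solve(K, S)) + m @ np.linalg.solve(K, m)
                  - M + logdet_K - logdet_S)


def svgp_elbo(X, y, Z, m, S, noise_var):
    """Uncollapsed SVGP lower bound for a Gaussian likelihood."""
    Kzz = rbf(Z, Z)
    Kxz = rbf(X, Z)
    B = np.linalg.solve(Kzz, Kxz.T).T              # Kxz Kzz^{-1}, shape (N, M)
    mean_f = B @ m                                 # means of the marginals Q(f_n)
    kxx = np.ones(len(X))                          # unit-variance rbf: k(x, x) = 1
    # Var Q(f_n) = k(x_n, x_n) - [Kxz Kzz^{-1} Kzx]_nn + [B S B^T]_nn
    var_f = kxx - np.sum(B * Kxz, axis=1) + np.sum((B @ S) * B, axis=1)
    exp_loglik = (-0.5 * np.log(2 * np.pi * noise_var)
                  - 0.5 * ((y - mean_f) ** 2 + var_f) / noise_var)
    return exp_loglik.sum() - gauss_kl(m, S, Kzz)
```

With \(\mathbf Z=\mathbf X\) and \(Q(\mathbf u)\) set to the exact GP posterior, the bound is tight and equals \(\log P(\mathbf y)\); any other choice gives a value below it.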
Theorem 4.1. For any kernel function \(K\) and its corresponding RKHS \(\cal{H}\), with feature maps defined as \(\phi_{\mathbf{x}} := K(\cdot, \mathbf{x}) \in \cal{H}\), any \(\bm{a}\in \mathbb{R}^M\), \(\bm{A} \in \mathbb{R}^{M \times M}\) with \(\bm{A} \succeq 0\) and \(\mathbf{Z} = (\mathbf{z}_1, \dots, \mathbf{z}_M) \in \mathcal{X}^{M}\), let \(\tilde{\mu} \in \mathcal{H}\) and \(\tilde{\Sigma} : \cal{H} \to \cal H\) be defined as \begin{equation} \tilde{\mu} = \sum_{m=1}^{M} a_m \phi_{\mathbf{z}_m}, \qquad \text{ and }\qquad \tilde{\Sigma} = \left(I + \sum_{i=1}^{M}\sum_{j=1}^M\phi_{\mathbf{z}_i} A_{i,j} \phi_{\mathbf{z}_j}^T\right)^{-1}\,. \end{equation} Then \(\tilde \Sigma \in \mathcal{L}^{+}(\mathcal H)\), and the Gaussian measure defined by \((\tilde{\mu}, \tilde{\Sigma})\) is equivalent to an SVGP with variational distribution \(Q(\mathbf{u})=\mathcal{N}(\mathbf{u}|\bm{m},\bm{S})\), where \begin{equation} \bm{m} = K(\mathbf{Z}, \mathbf{Z}) \bm a, \quad \text{and} \quad \bm{S} =K(\mathbf{Z}, \mathbf{Z}) - K(\mathbf{Z}, \mathbf{Z}) (\bm A^{-1} + K(\mathbf{Z}, \mathbf{Z}))^{-1} K(\mathbf{Z}, \mathbf{Z})\,. \end{equation}
Proof
Write \(\Psi = \bigl[\phi_{\mathbf z_1} \cdots \phi_{\mathbf z_M}\bigr] :\mathbb R^{M}\to\mathcal H\), so that \(\Psi^{\top}\Psi = K(\mathbf Z, \mathbf Z)\). Because \(\bm A=\bm A^{\top}\), it follows that \(\Psi \bm A \Psi^{\top}= (\Psi \bm A \Psi^{\top})^{\ast}\), hence \(\big(\tilde\Sigma^{-1}\big)^{\ast}= \tilde\Sigma^{-1}\). For any \(f\in\mathcal H\), \begin{equation} \langle f,\tilde\Sigma^{-1}( f)\rangle_{\mathcal H} = \|f\|_{\mathcal H}^{2} + \bigl\langle \Psi^{\top}f,\bm A\Psi^{\top}f \bigr\rangle_{\mathbb R^{M}} \ge \|f\|_{\mathcal H}^{2}\,, \end{equation} because \(\bm A\succeq0\). Thus \(\tilde\Sigma^{-1}\succeq I\), so \(\tilde{\Sigma}^{-1} \in \mathcal{L}^+(\mathcal{H})\) is invertible with a bounded, self-adjoint and positive inverse, and therefore \(\tilde{\Sigma} \in \mathcal{L}^+(\mathcal{H})\). Recall that, in the dual formulation of GPs, the mean function of the GP is \(m^{\star}(\mathbf{x}) := \braket{\phi_{\mathbf{x}}, \mu}\). Hence \begin{equation} m^{\star}(\mathbf{x}) = \braket{\phi_{\mathbf{x}}, \tilde{\mu}} = \braket{\phi_{\mathbf{x}},\sum_{m=1}^{M} a_m \phi_{\mathbf{z}_m}} = \sum_{m=1}^{M} a_m \braket{\phi_{\mathbf{x}}, \phi_{\mathbf{z}_m}} = K(\mathbf{x}, \mathbf{Z}) \bm{a}\,. \end{equation} Using \(\bm{a} = K(\mathbf{Z}, \mathbf{Z})^{-1} \bm{m}\) recovers the mean function of SVGPs (Titsias, 2009), \(K(\mathbf{x}, \mathbf{Z})K(\mathbf{Z}, \mathbf{Z})^{-1}\bm{m}\). The same procedure applies to the covariance: since \(\tilde{\Sigma}^{-1} = I + \Psi \bm A \Psi^{\top}\), Woodbury’s matrix identity gives \begin{equation} \tilde{\Sigma} = I + \Psi\tilde{\bm{A}}\Psi^\top\,, \quad \text{ where } \tilde{\bm{A}} = -(\bm A^{-1} + K(\mathbf Z, \mathbf Z))^{-1}\,. 
\end{equation} Thus, the covariance function \(K^{\star}(\mathbf{x}, \mathbf{x}') := \braket{\phi_{\mathbf{x}},\tilde\Sigma (\phi_{\mathbf{x}'})}\) satisfies \begin{align} K^{\star}(\mathbf{x}, \mathbf{x}') &= \braket{\phi_{\mathbf{x}},\tilde{\Sigma} (\phi_{\mathbf{x}'})} = \braket{\phi_{\mathbf{x}}, \phi_{\mathbf{x}'} + \sum_{i=1}^{M}\sum_{j=1}^M\phi_{\mathbf{z}_i} \tilde A_{i,j} \braket{\phi_{\mathbf{z}_j}, \phi_{\mathbf{x}'}}}\\ &= \braket{\phi_{\mathbf{x}}, \phi_{\mathbf{x}'}} + \sum_{i=1}^{M}\sum_{j=1}^M\braket{\phi_{\mathbf{x}}, \phi_{\mathbf{z}_i}} \tilde A_{i,j} \braket{\phi_{\mathbf{z}_j}, \phi_{\mathbf{x}'}}\\ &= K(\mathbf x, \mathbf x') + \sum_{i=1}^{M}\sum_{j=1}^M K(\mathbf x, \mathbf z_i) \tilde A_{i,j} K(\mathbf z_j, \mathbf x')\\ &= K(\mathbf x, \mathbf x') + K(\mathbf x, \mathbf Z) \tilde{\bm{A}} K(\mathbf Z, \mathbf x')\,. \end{align} From the definition of \(\bm{S}\), \begin{equation} K(\mathbf{Z}, \mathbf{Z})^{-1} - K(\mathbf{Z}, \mathbf{Z})^{-1}\bm{S}K(\mathbf{Z}, \mathbf{Z})^{-1} = (\bm A^{-1} + K(\mathbf{Z}, \mathbf{Z}))^{-1} = -\tilde{\bm{A}} \,. \end{equation} Substituting this expression for \(\tilde{\bm{A}}\) recovers the covariance function of SVGPs, \(K(\mathbf x, \mathbf x') - K(\mathbf x, \mathbf Z)K(\mathbf Z, \mathbf Z)^{-1}\bigl(K(\mathbf Z, \mathbf Z) - \bm S\bigr)K(\mathbf Z, \mathbf Z)^{-1}K(\mathbf Z, \mathbf x')\), and completes the proof. □
This defines a family of Gaussian measures \(\mathcal{Q}\subset \mathcal{P}_{\mathcal{H}}\) as \begin{equation} \mathcal{Q} := \left\{\mathcal{N}(f|\tilde{\mu}_{\bm a}, \tilde{\Sigma}_{\bm A}): \bm{a} \in \mathbb{R}^M, \bm{A} \in \mathbb{R}^{M \times M}, \bm{A}\succeq0, \mathbf Z \in \mathcal{X}^M \right\}\,, \end{equation} where \(\mathcal{H}\) and \(M\) are omitted from the notation of \(\mathcal{Q}\) for simplicity, as more subindices will be used later.
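Theorem 4.1 can be sanity-checked numerically. The sketch below assumes a squared-exponential kernel and a positive-definite (hence invertible) \(\bm A\); it maps \((\bm a, \bm A)\) to \((\bm m, \bm S)\) and compares the Titsias-style predictive mean and covariance with the dual expressions \(K(\mathbf x,\mathbf Z)\bm a\) and \(K(\mathbf x,\mathbf x') - K(\mathbf x,\mathbf Z)(\bm A^{-1}+K(\mathbf Z,\mathbf Z))^{-1}K(\mathbf Z,\mathbf x')\). All function names are illustrative.

```python
import numpy as np


def rbf(X1, X2, lengthscale=1.0, variance=1.0):
    """Squared-exponential kernel matrix between the rows of X1 and X2."""
    sq_dist = ((X1[:, None, :] - X2[None, :, :]) ** 2).sum(axis=-1)
    return variance * np.exp(-0.5 * sq_dist / lengthscale**2)


def dual_to_svgp(Z, a, A):
    """Map the dual parameters (a, A) to the SVGP parameters (m, S) of Theorem 4.1."""
    Kzz = rbf(Z, Z)
    m = Kzz @ a
    S = Kzz - Kzz @ np.linalg.solve(np.linalg.inv(A) + Kzz, Kzz)
    return m, S


def svgp_predict(Xs, Z, m, S):
    """Titsias-style predictive mean and covariance under Q(u) = N(m, S)."""
    Kzz, Ksz, Kss = rbf(Z, Z), rbf(Xs, Z), rbf(Xs, Xs)
    B = np.linalg.solve(Kzz, Ksz.T).T              # Ksz Kzz^{-1}
    return B @ m, Kss - B @ Ksz.T + B @ S @ B.T


def dual_predict(Xs, Z, a, A):
    """Mean and covariance functions of the Gaussian measure (mu_tilde, Sigma_tilde)."""
    Kzz, Ksz, Kss = rbf(Z, Z), rbf(Xs, Z), rbf(Xs, Xs)
    cov = Kss - Ksz @ np.linalg.solve(np.linalg.inv(A) + Kzz, Ksz.T)
    return Ksz @ a, cov
```

Both routes should agree up to floating-point error, and the resulting \(\bm S=(K(\mathbf Z,\mathbf Z)^{-1}+\bm A)^{-1}\) is positive definite.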
Decoupled Sparse Variational Gaussian Processes
Theorem 4.1 establishes that the variational inference procedure of Titsias (2009) optimizes over Gaussian measures whose mean, \(\tilde{\mu}\), and covariance operator, \(\tilde{\Sigma}\), are both expressed in terms of a single finite set of inducing basis functions \(\{\phi_{\mathbf{z}} : \mathbf{z}\in\mathbf{Z}\}\subset\mathcal{H}\), where \(\mathcal{H}\) denotes the RKHS generated by the kernel.
Cheng & Boots (2017) extend this construction by allowing different inducing bases for the mean and the covariance. Let \(M_\alpha, M_\beta \in \mathbb{N}\), and \begin{equation} \mathbf{Z}_{\alpha} =\{\mathbf{z}_{\alpha,1},\dots,\mathbf{z}_{\alpha,M_\alpha}\} \in \mathcal{X}^{M_\alpha}, \qquad \mathbf{Z}_{\beta} =\{\mathbf{z}_{\beta,1},\dots,\mathbf{z}_{\beta,M_\beta}\}\in \mathcal{X}^{M_\beta}\,. \end{equation} Define the Gaussian measure by \begin{equation} \tilde{\mu} = \sum_{m=1}^{M_\alpha} a_m \phi_{\mathbf{z}_{\alpha, m}}, \qquad \text{ and }\qquad \tilde{\Sigma} = \left(I + \sum_{i=1}^{M_\beta}\sum_{j=1}^{M_\beta}\phi_{\mathbf{z}_{\beta, i}} A_{i,j} \phi_{\mathbf{z}_{\beta, j}}^T\right)^{-1}\,, \end{equation} where \(\bm{A} \succeq 0\). This decoupled parameterization generalizes standard SVGPs, which are recovered as the special case \(\mathbf{Z}_\alpha = \mathbf{Z}_\beta\). The decoupled space of Gaussian measures is then defined as \begin{equation} \mathcal{Q}^{+} = \left\{\mathcal{N}(f|\tilde{\mu}_{\alpha, \bm a}, \tilde{\Sigma}_{\beta, \bm A}) : \bm{a} \in \mathbb{R}^{M_\alpha},\bm{A} \in \mathbb{R}^{M_\beta \times M_\beta}, \bm{A} \succeq 0, \mathbf Z_\alpha \in \mathcal{X}^{M_\alpha}, \mathbf Z_\beta \in \mathcal{X}^{M_\beta}\right\}\,, \end{equation} and satisfies \(\mathcal{Q} \subset \mathcal{Q}^{+} \subset \cal P_{\cal{H}}\).
This two-basis parameterization is more flexible than the single-basis formulation of Titsias (2009) at a fixed computational budget. A shared inducing set can reproduce a given pair \((\mathbf{Z}_\alpha,\mathbf{Z}_\beta)\), for instance by taking \(\mathbf{Z}=\mathbf{Z}_\alpha\cup\mathbf{Z}_\beta\) and padding \(\bm a\) and \(\bm A\) with zeros, but the cubic cost of the coupled parameterization then scales with \(M_\alpha+M_\beta\) rather than with \(M_\beta\).
For any \(Q(f)=\mathcal{N}(f|\tilde{\mu},\tilde{\Sigma})\) in the decoupled family \(\mathcal{Q}^{+}\), the evidence lower bound (ELBO) is (Cheng & Boots, 2016) \begin{equation} \label{eq:opt_dual} \mathcal{L}(Q) = \mathbb{E}_{Q}\left[\log P(\mathbf{y}|f)\right] -\mathrm{KL}(Q|P)\,, \end{equation} where \(P(f)\) denotes the prior Gaussian measure (typically \(\mathcal{N}(f| 0, I)\)) and \(P(\mathbf{y}|f)\) the likelihood. Because \(Q\) and \(P\) are Gaussian, the Kullback–Leibler divergence admits a closed-form expression (Cheng & Boots, 2016): \begin{equation} \label{eq:KL_dual_basis} \mathrm{KL}(Q|P) = \frac12\,\bm{a}^{T}\bm{K}_{\alpha}\bm{a} + \frac12 \log\det (I+\bm{K}_{\beta}\bm{A}) - \frac12 \mathrm{tr}[\bm{K}_{\beta} (\bm{A}^{-1}+\bm{K}_\beta)^{-1}]\,, \end{equation} with \((\bm{K}_{\alpha})_{i,j} = K(\mathbf z_{\alpha, i}, \mathbf{z}_{\alpha, j})\) and \((\bm{K}_{\beta})_{i,j} = K(\mathbf z_{\beta, i}, \mathbf{z}_{\beta, j})\). The expectation term in Equation \(\eqref{eq:opt_dual}\) depends on the choice of likelihood: it is available in closed form for a Gaussian likelihood and is otherwise typically approximated via Gaussian quadrature or Monte Carlo with the reparameterization trick (e.g. for a Bernoulli likelihood). Gradients of \(\mathcal{L}(Q)\) with respect to \(\bm{a}\), \(\bm{A}\), and the inducing locations can then be obtained in closed form or by automatic differentiation.
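Equation \(\eqref{eq:KL_dual_basis}\) can be checked in a finite-dimensional stand-in for \(\mathcal H\). In the sketch below, \(\mathcal H=\mathbb R^{D}\), the basis functions are explicit feature vectors (the columns of the hypothetical matrices `Psi_alpha` and `Psi_beta`), the kernel is their dot product, and the prior is \(\mathcal N(\mathbf 0, \bm I_D)\); these are simplifying assumptions for illustration only. The closed form, which only touches \(M_\alpha\times M_\alpha\) and \(M_\beta\times M_\beta\) kernel matrices, is compared with the textbook Gaussian KL computed directly in \(\mathbb R^D\).

```python
import numpy as np


def decoupled_kl(a, A, K_alpha, K_beta):
    """Closed-form KL(Q || P) of the decoupled family with prior P = N(0, I)."""
    M_beta = K_beta.shape[0]
    mean_term = 0.5 * a @ K_alpha @ a
    _, logdet = np.linalg.slogdet(np.eye(M_beta) + K_beta @ A)
    # tr[K_beta (A^{-1} + K_beta)^{-1}] = tr[(A^{-1} + K_beta)^{-1} K_beta]
    trace_term = np.trace(np.linalg.solve(np.linalg.inv(A) + K_beta, K_beta))
    return mean_term + 0.5 * logdet - 0.5 * trace_term


def direct_kl(a, A, Psi_alpha, Psi_beta):
    """KL( N(mu, Sigma) || N(0, I_D) ) with mu = Psi_a a, Sigma = (I + Psi_b A Psi_b^T)^{-1}."""
    D = Psi_alpha.shape[0]
    mu = Psi_alpha @ a
    Sigma = np.linalg.inv(np.eye(D) + Psi_beta @ A @ Psi_beta.T)
    _, logdet_Sigma = np.linalg.slogdet(Sigma)
    return 0.5 * (np.trace(Sigma) + mu @ mu - D - logdet_Sigma)
```

Note that \(M_\alpha\) and \(M_\beta\) can differ freely; only the \(M_\beta\times M_\beta\) terms incur a cubic cost.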
The decoupling of inducing sets introduces two notable advantages:
Expressiveness. Allowing \(M_\alpha\neq M_\beta\) enables the mean to be represented with a smaller (or larger) set of points than the covariance, yielding a more parsimonious yet accurate approximation.
Computational trade-off. Evaluating \(\tilde{\mu}\) at a point scales linearly with \(M_\alpha\) (and the quadratic form \(\bm{a}^{T}\bm{K}_{\alpha}\bm{a}\) in Equation \(\eqref{eq:KL_dual_basis}\) as \(\mathcal{O}(M_\alpha^{2})\)), whereas the log-determinant in the same equation costs \(\mathcal{O}(M_\beta^{3})\). Choosing \(M_\beta<M_\alpha\) can therefore reduce the dominant cubic cost while retaining a rich mean function.
Consequently, the framework of Cheng & Boots (2017) contains the sparse-GP variational method of Titsias (2009) as the special case \(\mathbf{Z}_\alpha=\mathbf{Z}_\beta\), yet affords greater flexibility for the large-scale data sets encountered in contemporary applications.
Linearized Laplace Approximation (LLA)
Consider the task of inferring an unknown target function \(f:\mathbb{R}^D \to \mathbb{R}\) from noisy observations \(\mathbf{y} = (y_1,\dots,y_N)^{\!\top}\) at inputs \(\mathbf{X} = (\mathbf{x}_1,\dots,\mathbf{x}_N)\). In deep learning, a neural network \[g:\mathbb{R}^D \times \mathbb{R}^P \to \mathbb{R},\] with parameters \(\bm{\theta} \in \mathbb{R}^P\), is used to approximate \(f\) such that there exists a parameter configuration \(\hat{\bm{\theta}}\) satisfying \(f(\cdot) \approx g(\cdot,\hat{\bm{\theta}})\). After training, the function \(g(\cdot,\hat{\bm{\theta}})\) acts as a deterministic predictor of the data-generating process. However, standard neural networks do not quantify predictive uncertainty, often yielding overconfident predictions in regions without training data.
From a Bayesian standpoint, uncertainty over the model parameters is introduced by specifying a prior \(P(\bm{\theta})\) and forming the posterior \[P(\bm{\theta}|\mathbf{y}) \propto P(\mathbf{y}|\bm{\theta})\,P(\bm{\theta}),\] where the likelihood depends on \(\bm\theta\) through the network evaluations \(\mathbf{g} = (g(\mathbf{x}_1,\bm{\theta}),\ldots,g(\mathbf{x}_N,\bm{\theta}))^{\!\top}\). In regression, the likelihood is typically Gaussian. In classification, \(y_i \in \{1,\ldots,C\}\), with \(C\) the number of classes, and the likelihood is categorical with class probabilities given by a softmax link function. In this latter case, \[g:\mathbb{R}^D \times \mathbb{R}^P \to \mathbb{R}^C\] produces \(C\) logits, one per class label.
Exact posterior inference is intractable for modern deep networks because the network output is highly non-linear in \(\bm\theta\), so one typically introduces a tractable approximate posterior \(Q(\bm\theta)\approx P(\bm\theta|\mathbf y)\). Predictions for a new input \(\mathbf x^\star\) are then obtained by Monte Carlo marginalization, \begin{equation} P(y^\star|\mathbf x^\star,\mathbf y) = \mathbb E_{P(\bm\theta|\mathbf y)} \bigl[P(y^\star|\mathbf x^\star,\bm\theta)\bigr] \approx \frac1S\sum_{s=1}^{S} P\bigl(y^\star|\mathbf x^\star,\bm\theta_{s}\bigr), \qquad \bm\theta_{s}\sim Q(\bm\theta), \end{equation} where averaging over parameter samples lets the predictive distribution reflect both aleatoric and epistemic uncertainty (Bishop, 2006).
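As a minimal sketch of this Monte Carlo estimator, assume a binary classifier (the \(C=2\) case, with a sigmoid in place of the softmax) whose logit is linear in the parameters, \(g(\mathbf x,\bm\theta)=\mathbf x^{\top}\bm\theta\), and a Gaussian \(Q(\bm\theta)\). The model, the numbers and the function names are hypothetical. In this example, averaging the class probabilities over posterior samples yields a less confident prediction than the plug-in estimate at the posterior mean.

```python
import numpy as np


def sigmoid(z):
    return 1.0 / (1.0 + np.exp(-z))


def mc_predictive(x_star, theta_mean, theta_cov, num_samples=20000, seed=0):
    """Estimate P(y* = 1 | x*, y) by averaging sigmoid(x*^T theta_s), theta_s ~ Q."""
    rng = np.random.default_rng(seed)
    thetas = rng.multivariate_normal(theta_mean, theta_cov, size=num_samples)
    return sigmoid(thetas @ x_star).mean()


theta_mean = np.array([2.0, -1.0])
theta_cov = 0.5 * np.eye(2)
x_star = np.array([1.0, 0.5])
p_mc = mc_predictive(x_star, theta_mean, theta_cov)
p_plugin = sigmoid(x_star @ theta_mean)   # prediction ignoring parameter uncertainty
```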
The Laplace approximation (LA) replaces the true posterior by a Gaussian centered at the maximum-a-posteriori (MAP) solution \(\hat{\bm\theta}=\arg\max_{\bm\theta}\bigl[\log P(\mathbf y|\bm\theta)+ \log P(\bm\theta)\bigr]\): \begin{equation} Q(\bm\theta)= \mathcal N \bigl(\bm\theta|\hat{\bm\theta},\bm\Sigma\bigr), \qquad \bm\Sigma^{-1}=- \nabla_{ \bm\theta\bm\theta}^{2} \Bigl[\log P(\mathbf y|\bm\theta)+\log P(\bm\theta)\Bigr]_{\bm\theta=\hat{\bm\theta}}. \end{equation} Assuming an isotropic Gaussian prior \(P(\bm\theta)=\mathcal N(\bm\theta|\mathbf0,\sigma_{0}^{2}\bm I)\) adds the term \(\bm I/\sigma_{0}^{2}\) to the precision, \begin{equation} \bm\Sigma^{-1}=- \nabla_{ \bm\theta\bm\theta}^{2} \log P(\mathbf y|\bm\theta)\bigr|_{\bm\theta=\hat{\bm\theta}} + \frac{\bm I}{\sigma_{0}^{2}}. \end{equation} For deep networks, the Hessian is too expensive to compute and is not guaranteed to be positive definite. A common remedy is to replace the Hessian of the negative log-likelihood by the positive semi-definite Generalized Gauss–Newton (GGN) matrix (Immer et al., 2021), \begin{equation} \label{eq:ggn} \bm\Sigma^{-1} \approx \sum_{n=1}^{N} J_{\hat{\bm\theta}}(\mathbf x_{n})^{ \top} \Lambda(\mathbf x_{n},y_{n}) J_{\hat{\bm\theta}}(\mathbf x_{n}) + \frac{\bm I}{\sigma_{0}^{2}}, \end{equation} with \(J_{\hat{\bm\theta}}(\mathbf x):=\nabla_{ \bm\theta}\, g(\mathbf x,\bm\theta)\bigr|_{\bm\theta=\hat{\bm\theta}}\) the \(C\times P\) Jacobian of the network outputs (a \(1\times P\) row for scalar outputs) and \begin{equation} \Lambda(\mathbf x_{n},y_{n}):= -\nabla_{ \mathbf g\mathbf g}^{2} \log P(y_{n}|\mathbf g)\bigr|_{\mathbf g=g(\mathbf x_{n},\hat{\bm\theta})} \end{equation} the Hessian of the negative log-likelihood with respect to the network outputs \(\mathbf g\). 
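The GGN precision in Equation \(\eqref{eq:ggn}\) can be sketched for a toy model. The code below assumes a hypothetical one-hidden-layer tanh network with scalar output and a Gaussian likelihood, for which \(\Lambda(\mathbf x_n,y_n)=1/\sigma^2\) is constant. The Jacobian is written out analytically, and the parameter layout (`W1`, `b1`, `w2`, `b2` flattened into one vector) is an illustrative choice, not part of the method.

```python
import numpy as np


def unpack(theta, D, H):
    """Split the flat parameter vector into W1 (H x D), b1 (H), w2 (H), b2 (scalar)."""
    W1 = theta[:H * D].reshape(H, D)
    b1 = theta[H * D:H * D + H]
    w2 = theta[H * D + H:H * D + 2 * H]
    return W1, b1, w2, theta[-1]


def mlp(x, theta, D, H):
    """Scalar-output network g(x, theta) = w2^T tanh(W1 x + b1) + b2."""
    W1, b1, w2, b2 = unpack(theta, D, H)
    return w2 @ np.tanh(W1 @ x + b1) + b2


def jacobian(x, theta, D, H):
    """Analytic Jacobian dg/dtheta at x, laid out like theta (shape (P,))."""
    W1, b1, w2, _ = unpack(theta, D, H)
    h = np.tanh(W1 @ x + b1)
    delta = w2 * (1.0 - h**2)                     # dg / d(pre-activation)
    return np.concatenate([np.outer(delta, x).ravel(), delta, h, [1.0]])


def ggn_precision(X, theta, D, H, noise_var, prior_var):
    """sum_n J_n^T Lambda J_n + I / sigma_0^2 with Lambda = 1 / noise_var (Gaussian lik.)."""
    prec = np.eye(theta.size) / prior_var
    for x in X:
        J = jacobian(x, theta, D, H)[None, :]     # 1 x P row (scalar output)
        prec += J.T @ J / noise_var
    return prec
```

Because each summand is a Gram-type matrix, the result is symmetric with eigenvalues no smaller than \(1/\sigma_0^2\), regardless of whether \(\bm\theta\) is an exact MAP.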
Because the GGN is exactly the Hessian of the negative log-likelihood of the linearized network \begin{equation} g_{\hat{\bm\theta}}^{\text{lin}}(\mathbf x,\bm\theta)= g(\mathbf x,\hat{\bm\theta})+ J_{\hat{\bm\theta}}(\mathbf x)\,(\bm\theta-\hat{\bm\theta}), \end{equation} the approximate posterior is centered on \(\hat{\bm\theta}\) but its curvature is that of the linear surrogate, a mismatch that can lead to underfitting (Lawrence, 2001).
The linearized Laplace approximation (LLA) resolves the mismatch by also making predictions through the linear surrogate: \begin{equation} P_{\text{LLA}}(y^\star|\mathbf x^\star,\mathbf y)=\mathbb E_{Q(\bm\theta)}[P(y^\star|g_{\hat{\bm\theta}}^{\text{lin}}(\mathbf x^\star,\bm\theta))]\,. \end{equation} LLA therefore uses a Gaussian posterior over a linear model, for which the GGN is exact. The main computational bottleneck is now the inversion of the precision matrix, whose cost scales cubically with the number of parameters. A dual Gaussian-process formulation trades this for a cubic dependence on the number of training points instead.
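For a Gaussian likelihood, pushing \(Q(\bm\theta)\) through the linearized network gives a Gaussian predictive with mean \(g(\mathbf x^\star,\hat{\bm\theta})\) and variance \(J_{\hat{\bm\theta}}(\mathbf x^\star)\bm\Sigma J_{\hat{\bm\theta}}(\mathbf x^\star)^\top+\sigma^2\). The sketch below assumes a scalar output and uses a random Jacobian row and covariance as hypothetical stand-ins for a trained network; it compares this closed form with a Monte Carlo estimate over the linear surrogate.

```python
import numpy as np


def lla_predictive_regression(g_map, J_star, Sigma, noise_var):
    """Closed-form LLA predictive for a Gaussian likelihood: N(g_map, J Sigma J^T + noise_var)."""
    return g_map, J_star @ Sigma @ J_star + noise_var


def lla_predictive_mc(g_map, J_star, theta_map, Sigma, num_samples=200000, seed=0):
    """Monte Carlo over theta ~ N(theta_map, Sigma) pushed through g_lin(x*, theta)."""
    rng = np.random.default_rng(seed)
    thetas = rng.multivariate_normal(theta_map, Sigma, size=num_samples)
    f_lin = g_map + (thetas - theta_map) @ J_star   # linearized outputs (no noise)
    return f_lin.mean(), f_lin.var()
```

The sampled outputs are exactly Gaussian here because the surrogate is linear in \(\bm\theta\), which is what makes the closed form available.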
A linear model endowed with a Gaussian prior on its weights induces a Gaussian process (GP) over functions (Rasmussen & Williams, 2006). Adopting the isotropic prior \(P(\bm\theta)=\mathcal N(\bm\theta|\mathbf 0,\sigma_{0}^{2}\mathbf I_{P})\), the linearized Bayesian neural network \(g_{\hat{\bm\theta}}^{\mathrm{lin}} \bigl(\mathbf x,\bm\theta\bigr)\) defines a GP whose mean and covariance functions are \begin{equation} m(\mathbf x) = g_{\hat{\bm\theta}}^{\mathrm{lin}} \bigl(\mathbf x,\mathbf 0\bigr), \qquad \text{ and} \qquad K_{\text{prior}}(\mathbf x,\mathbf x') = \sigma_{0}^{2}\, J_{\hat{\bm\theta}}(\mathbf x)\, J_{\hat{\bm\theta}}(\mathbf x')^{\top}. \end{equation} Replacing the prior \(P(\bm\theta)\) by the Laplace posterior \(Q(\bm\theta)=\mathcal N(\bm\theta|\hat{\bm\theta},\bm\Sigma)\) shifts the GP mean to the non-linear network output, since \(g_{\hat{\bm\theta}}^{\mathrm{lin}}(\mathbf x,\hat{\bm\theta}) = g(\mathbf x,\hat{\bm\theta})\), and replaces \(\sigma_0^2\mathbf I_P\) by \(\bm\Sigma\) in the covariance: \begin{equation} \label{eq:posterior_gp} m(\mathbf x) = g(\mathbf x,\hat{\bm\theta}), \qquad K(\mathbf x,\mathbf x') = J_{\hat{\bm\theta}}(\mathbf x)\, \bm\Sigma\, J_{\hat{\bm\theta}}(\mathbf x')^{\top}. \end{equation} Under the GGN approximation \(\bm\Sigma^{-1} =\sum_{n}J_{n}^{\top}\Lambda_{n}J_{n}+\sigma_{0}^{-2}\mathbf I_{P}\), with \(J_n := J_{\hat{\bm\theta}}(\mathbf x_n)\) and \(\Lambda_n := \Lambda(\mathbf x_n, y_n)\), the Woodbury matrix identity yields \begin{equation} \bm\Sigma = \sigma_{0}^{2}\Bigl( \mathbf I_{P} -\mathbf J^{\top} \bigl(\sigma_{0}^{-2}\bm\Lambda_{\mathbf X,\mathbf y}^{-1} +\mathbf J\mathbf J^{\top}\bigr)^{-1} \mathbf J \Bigr), \end{equation} where \(\mathbf J\) stacks all Jacobians \(J_{n}\) row-wise and \(\bm\Lambda_{\mathbf X,\mathbf y}=\operatorname{diag}(\Lambda_{1},\dots,\Lambda_{N})\) is block-diagonal. Substituting this expression into the covariance formula above leads to a purely function-space representation. 
Define the (scaled) Neural Tangent Kernel (NTK) and the matrix \(\mathbf Q\) as \begin{equation} \kappa(\mathbf x,\mathbf x') = \sigma_{0}^{2}\, J_{\hat{\bm\theta}}(\mathbf x)\, J_{\hat{\bm\theta}}(\mathbf x')^{\top}, \qquad \mathbf Q = \bm\Lambda_{\mathbf X,\mathbf y}^{-1} +\kappa(\mathbf X,\mathbf X). \end{equation} The posterior GP covariance then simplifies to \begin{equation} K(\mathbf x,\mathbf x') ~=~ \kappa(\mathbf x,\mathbf x') - \kappa(\mathbf x,\mathbf X)\, \mathbf Q^{-1}\, \kappa(\mathbf X,\mathbf x'). \end{equation} Interpreting the Linearized Laplace Approximation (LLA) as a GP shifts the computational bottleneck from parameter space to data space: forming \(\kappa(\mathbf X,\mathbf X)\) and inverting \(\mathbf Q\) costs \(\mathcal O(N^{3}+N^{2}P)\) operations. For \(C\)-class classification, the kernel matrices have dimension \(NC\times NC\), so the cost grows to \(\mathcal O(N^{3}C^{3}+N^{2}C^{2}P)\); this cubic dependence on both the sample size \(N\) and the output dimension \(C\) must therefore be addressed by further approximations.
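The equivalence between the weight-space covariance \(J_{\hat{\bm\theta}}(\mathbf x)\bm\Sigma J_{\hat{\bm\theta}}(\mathbf x')^\top\) and the function-space expression can be checked numerically. The sketch assumes a scalar output, so each Jacobian is a \(1\times P\) row and \(\bm\Lambda_{\mathbf X,\mathbf y}\) is diagonal, and uses random Jacobians as hypothetical stand-ins. The weight-space route solves a \(P\times P\) system, while the function-space route solves an \(N\times N\) one, which is cheaper whenever \(N<P\).

```python
import numpy as np


def lla_cov_weight_space(J_train, Lam, J_test, prior_var):
    """J_test Sigma J_test^T with Sigma^{-1} = J^T Lambda J + I / sigma_0^2 (P x P solve)."""
    P = J_train.shape[1]
    prec = J_train.T @ Lam @ J_train + np.eye(P) / prior_var
    return J_test @ np.linalg.solve(prec, J_test.T)


def lla_cov_function_space(J_train, Lam, J_test, prior_var):
    """kappa(x, x') - kappa(x, X) Q^{-1} kappa(X, x'), Q = Lambda^{-1} + kappa(X, X) (N x N solve)."""
    k_tt = prior_var * J_test @ J_test.T
    k_tx = prior_var * J_test @ J_train.T
    k_xx = prior_var * J_train @ J_train.T
    Q = np.linalg.inv(Lam) + k_xx
    return k_tt - k_tx @ np.linalg.solve(Q, k_tx.T)
```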
Using Decoupled SGP and LLA
The decoupled parameterization of sparse GPs makes it possible to build a model whose approximate posterior mean is anchored to a pre-trained MAP solution. This method is named variational LLA (VaLLA).
Proposition 4.2. For any kernel \(K:\cal X \times \cal X \to \mathbb{R}\) and its corresponding RKHS \(\cal H\), if \(g(\cdot, \hat{\bm{\theta}}) \in \mathcal{H}\), then \(\forall \epsilon > 0\), \(\exists M_\alpha \in \mathbb{N}^+\), \(\mathbf{Z}_{\alpha}\in \mathcal{X}^{M_\alpha}\), \(\bm{a} \in \mathbb{R}^{M_\alpha}\) such that, for any \(M_\beta \in \mathbb{N}\), \(\mathbf{Z}_{\beta} \in \cal{X}^{M_\beta}\), \(\bm{A} \in \mathbb{R}^{M_\beta \times M_\beta}\) with \(\bm{A} \succeq 0\), the Gaussian measure \(Q(f) = \mathcal{N}(f|\tilde{\mu}_{\alpha, \bm a}, \tilde{\Sigma}_{\beta, \bm{A}}) \in \cal{Q}^+\) has a corresponding GP with mean and covariance functions defined as \begin{equation} m^{\star}(\mathbf{x}) = \tilde{\mu}_{\alpha, \bm a}(\mathbf{x})\,,\quad K^{\star}(\mathbf{x}, \mathbf{x}') =K(\mathbf{x}, \mathbf{x}') - K_{\mathbf{x}, \mathbf{Z}_\beta}(\bm{A}^{-1} + \bm{K}_\beta)^{-1} K_{\mathbf{Z}_\beta, \mathbf{x}'} \,, \end{equation} and \(d_{\mathcal{H}}(g(\cdot, \hat{\bm{\theta}}), \tilde{\mu}_{\alpha, \bm a}) \leq \epsilon\), where \(d_{\mathcal{H}}(\cdot,\cdot)\) denotes the distance induced by the norm of \(\mathcal{H}\).
Proof
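If \(g(\cdot, \hat{\bm{\theta}}) \in \mathcal{H}\), then, because the linear span of the feature maps \(\{\phi_{\mathbf z}: \mathbf z \in \mathcal X\}\) is dense in \(\mathcal H\) (a consequence of the reproducing property), for every \(\epsilon > 0\) there exist \(M_\alpha \in \mathbb{N}^+\), \(\mathbf{Z}_{\alpha} \in \mathcal{X}^{M_\alpha}\) and \(\bm{a} \in \mathbb{R}^{M_\alpha}\) such that \(\tilde{\mu}_{\alpha, \bm a}= \sum_{i=1}^{M_\alpha} a_i \phi_{\mathbf{z}_{\alpha,i}}\) satisfies \begin{equation} d_{\mathcal{H}}(g(\cdot, \hat{\bm{\theta}}), \tilde{\mu}_{\alpha, \bm a}) \leq \epsilon\,. \end{equation} As a result, the mean function of the SVGP is \begin{equation} m^{\star}(\mathbf{x}) = \braket{\phi_{\mathbf{x}}, \tilde{\mu}_{\alpha, \bm a}} = \tilde{\mu}_{\alpha, \bm a}(\mathbf{x}) \approx g(\mathbf{x}, \hat{\bm{\theta}})\,, \end{equation} where the approximation holds pointwise because, by the reproducing property and the Cauchy–Schwarz inequality, \(|\tilde{\mu}_{\alpha, \bm a}(\mathbf x) - g(\mathbf x, \hat{\bm\theta})| = |\braket{\phi_{\mathbf x}, \tilde{\mu}_{\alpha, \bm a} - g(\cdot, \hat{\bm\theta})}| \leq \epsilon\sqrt{K(\mathbf x, \mathbf x)}\). On the other hand, writing \(\Phi_{\mathbf{Z}_{\beta}} = \bigl[\phi_{\mathbf z_{\beta,1}}\cdots\phi_{\mathbf z_{\beta,M_\beta}}\bigr]\) and applying Woodbury’s identity to \(\tilde{\Sigma}_{\beta, \bm A}\) as in the proof of Theorem 4.1, the covariance function is \begin{align} K^{\star}(\mathbf{x}, \mathbf{x}' ) &= \braket{\phi_{\mathbf{x}},\tilde{\Sigma} (\phi_{\mathbf{x}'})} = \braket{\phi_{\mathbf{x}},\phi_{\mathbf{x}'}} - \braket{\phi_{\mathbf{x}},\Phi_{\mathbf{Z}_{\beta}}(\bm{A}^{-1} + \Phi_{\mathbf{Z}_{\beta}}^T\Phi_{\mathbf{Z}_{\beta}})^{-1}\Phi_{\mathbf{Z}_{\beta}}^T \phi_{\mathbf{x}'}}\\ &= K(\mathbf{x}, \mathbf{x}') - K_{\mathbf{x}, \mathbf{Z}_\beta}(\bm{A}^{-1} + \bm{K}_{\beta})^{-1} K_{\mathbf{Z}_\beta, \mathbf{x}'}\,. \end{align} □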
As a result, the predictive covariance function is \begin{equation} \label{eq:pred_valla} K^{\star}(\mathbf{x}, \mathbf{x}') =K(\mathbf{x}, \mathbf{x}') - K_{\mathbf{x}, \mathbf{Z}_\beta}(\bm{A}^{-1} + \bm{K}_\beta)^{-1} K_{\mathbf{Z}_\beta, \mathbf{x}'}\,. \end{equation}
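Prediction with VaLLA therefore only requires the DNN output for the mean and Equation \(\eqref{eq:pred_valla}\) for the covariance. In the sketch below, a squared-exponential kernel and `np.sin` stand in for the LLA kernel and the pre-trained network, respectively (both are assumptions for illustration), and \(\bm A\) is taken positive definite so that \(\bm A^{-1}\) exists. The function name `valla_predict` is hypothetical.

```python
import numpy as np


def rbf(X1, X2, lengthscale=1.0, variance=1.0):
    """Squared-exponential kernel matrix between the rows of X1 and X2."""
    sq_dist = ((X1[:, None, :] - X2[None, :, :]) ** 2).sum(axis=-1)
    return variance * np.exp(-0.5 * sq_dist / lengthscale**2)


def valla_predict(Xs, Zb, A, mean_fn, kernel=rbf):
    """Mean from the pre-trained model; covariance K - K_sb (A^{-1} + K_b)^{-1} K_bs."""
    Kbb, Ksb, Kss = kernel(Zb, Zb), kernel(Xs, Zb), kernel(Xs, Xs)
    cov = Kss - Ksb @ np.linalg.solve(np.linalg.inv(A) + Kbb, Ksb.T)
    return mean_fn(Xs), cov
```

Since \((\bm A^{-1}+\bm K_\beta)^{-1}\succeq 0\), the predictive variances never exceed the prior variances, while the mean is untouched by the inducing points.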
Proposition 4.2 implies that, if \(g(\cdot, \hat{\bm{\theta}}) \in \mathcal{H}\), there exist values of \(\bm{a}\) and inducing points for the mean \(\mathbf{Z}_\alpha\) such that \(d_{\mathcal{H}}(g(\cdot, \hat{\bm{\theta}}), \tilde{\mu}_{\alpha, \bm a})\) is as small as desired. For a sufficiently small \(\epsilon\), \(m^\star(\cdot) \approx g(\cdot, \hat{\bm{\theta}})\), so \(g(\cdot, \hat{\bm{\theta}})\) can be used for prediction instead of \(m^\star(\cdot)\). Thus, there is no need to optimize \(\bm{a}\) and \(\mathbf{Z}_{\alpha}\) in Equation \(\eqref{eq:opt_dual}\), and the posterior distribution of VaLLA uses \(g(\cdot, \hat{\bm{\theta}})\) as its mean function. The optimal parameters \(\mathbf{Z}_{\beta}\) and \(\bm{A}\) can be found by maximizing Equation \(\eqref{eq:opt_dual}\) with \(\bm{a}\) and \(\mathbf{Z}_{\alpha}\) held constant. By the following proposition, for a Gaussian likelihood the optimal value of \(\bm{A}\) can be computed at cost \(\mathcal{O}(NM_\beta^2 + M_\beta^3)\).
Proposition 4.3. Under a Gaussian likelihood with noise variance \(\sigma^2\), the value of \(\bm{A}\) that maximizes the ELBO in Equation \(\eqref{eq:opt_dual}\) is \begin{equation} \bm{A} = \frac{1}{\sigma^2} \bm{K}_{\beta}^{-1} \bm{K}_{\mathbf{Z}_\beta, \mathbf{X}}\bm{K}_{\mathbf{X}, \mathbf{Z}_\beta} \bm{K}_{\beta}^{-1}\,, \end{equation} where \(\bm{K}_{\mathbf{X}, \mathbf{Z}_\beta}\) is the matrix of prior covariances between \(f(\mathbf{X})\) and \(f(\mathbf{Z}_\beta)\). If \(\mathbf{Z}_\beta=\mathbf{X}\), the covariance function of the predictive distribution in Equation \(\eqref{eq:pred_valla}\) equals that of the full GP.
Proof
The objective is to maximize the ELBO with respect to the covariance parameter of the decoupled variational family under a Gaussian likelihood \(p(\mathbf y|\mathbf f) = \mathcal N\big(\mathbf y| \mathbf f,\, \sigma^2 \mathbf I\big)\). Let \(\bm{u}_\beta := f(\mathbf Z_\beta)\) with prior \(p(\bm{u}_\beta)=\mathcal N(\mathbf 0,\bm K_\beta)\), where \((\bm K_\beta)_{ij}=K(\mathbf z_{\beta,i},\mathbf z_{\beta,j})\). Consider a variational distribution \(q(\bm{u}_\beta)=\mathcal N(\mathbf 0, \bm{S})\) for some positive-definite matrix \(\bm S\). Because the mean is held fixed, the mean-dependent terms of the ELBO do not depend on \(\bm S\), so only the covariance part matters for optimizing \(\bm A\). Using Gaussian conditioning identities, \begin{equation} \mathrm{Cov}_q(\mathbf f) = \bm K_{\mathbf X, \mathbf X} - \bm K_{\mathbf X, \mathbf Z_\beta}\,\bm K_\beta^{-1}\bm K_{\mathbf Z_\beta, \mathbf X} + \bm K_{\mathbf X, \mathbf Z_\beta}\,\bm K_\beta^{-1} \bm S \bm K_\beta^{-1}\,\bm K_{\mathbf Z_\beta, \mathbf X}. \end{equation} The expected log-likelihood contributes, up to a constant independent of \(\bm S\), \begin{equation} -\frac{1}{2\sigma^2}\,\mathrm{tr}\left(\bm K_{\mathbf X, \mathbf Z_\beta}\,\bm K_\beta^{-1} \bm S \bm K_\beta^{-1}\,\bm K_{\mathbf Z_\beta, \mathbf X}\right). \end{equation} Define \begin{equation} \bm C := \bm K_\beta^{-1}\bm K_{\mathbf Z_\beta, \mathbf X}\bm K_{\mathbf X, \mathbf Z_\beta}\bm K_\beta^{-1}. \end{equation} By the cyclic property of the trace, the term simplifies to \begin{equation} -\frac{1}{2\sigma^2}\,\mathrm{tr}(\bm S\bm C). \end{equation} The KL divergence between \(q(\bm u_\beta)=\mathcal N(\mathbf 0,\bm S)\) and \(p(\bm u_\beta)=\mathcal N(\mathbf 0,\bm K_\beta)\) is \begin{equation} \mathrm{KL}(q|p) = \tfrac12\Big(\mathrm{tr}(\bm K_\beta^{-1}\bm S) - \log\det(\bm K_\beta^{-1}\bm S) - M_\beta\Big)\,, \end{equation} where \(M_\beta = |\mathbf{Z}_\beta|\). Thus the \(\bm S\)-dependent part of the ELBO is \begin{equation} \mathcal L(\bm S) \;=\; -\frac{1}{2\sigma^2}\,\mathrm{tr}(\bm S\bm C) \;-\; \tfrac12\Big(\mathrm{tr}(\bm K_\beta^{-1}\bm S) - \log\det(\bm K_\beta^{-1}\bm S) - M_\beta\Big). \end{equation} Taking the derivative with respect to \(\bm S\) and setting it to zero gives \begin{equation} -\frac{1}{2\sigma^2}\bm C - \tfrac12 \bm K_\beta^{-1} + \tfrac12 \bm S^{-1} = 0. \end{equation} Therefore \begin{equation} \bm S^{-1} = \bm K_\beta^{-1} + \tfrac{1}{\sigma^2}\bm C\,. \end{equation} Applying the Woodbury matrix identity (assuming \(\bm C\) is invertible, which requires \(\bm K_{\mathbf Z_\beta, \mathbf X}\) to have full row rank and hence \(N \geq M_\beta\)): \begin{equation} \bm S = \bm K_\beta - \bm K_\beta(\sigma^2\bm C^{-1} + \bm K_\beta)^{-1} \bm K_\beta\,. \end{equation} In the dual parameterization, \(\bm S\) and \(\bm A\) are related by (Theorem 4.1) \begin{equation} \bm S \;=\; \bm K_\beta - \bm K_\beta ( \bm A^{-1} + \bm K_\beta)^{-1} \bm K_\beta. \end{equation} Comparing both expressions yields \begin{equation} \bm A^{-1} = \sigma^2 \bm{C}^{-1} \implies \bm A = \tfrac{1}{\sigma^2}\bm K_\beta^{-1}\bm K_{\mathbf Z_\beta, \mathbf X}\bm K_{\mathbf X, \mathbf Z_\beta}\bm K_\beta^{-1}\,. \end{equation} If \(\mathbf Z_\beta=\mathbf X\), then \(\bm K_{\mathbf Z_\beta,\mathbf X}=\bm K_{\mathbf X,\mathbf X}=\bm K_\beta\), so \(\bm A=\tfrac{1}{\sigma^2}\bm I\) and \((\bm A^{-1} + \bm K_\beta)^{-1} = (\sigma^2\bm I + \bm K_{\mathbf X,\mathbf X})^{-1}\); Equation \(\eqref{eq:pred_valla}\) then coincides with the full GP posterior covariance. □
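Proposition 4.3 can be illustrated numerically. The sketch below assumes a squared-exponential kernel and a Gaussian likelihood with noise variance `noise_var`; it forms \(\bm K_\beta^{-1}\bm K_{\mathbf Z_\beta,\mathbf X}\) with one \(M_\beta\times M_\beta\) solve and one \(N\times M_\beta\) product, in line with the \(\mathcal O(NM_\beta^2+M_\beta^3)\) cost, and checks that \(\mathbf Z_\beta=\mathbf X\) recovers the exact GP posterior covariance. The function names are hypothetical.

```python
import numpy as np


def rbf(X1, X2, lengthscale=1.0, variance=1.0):
    """Squared-exponential kernel matrix between the rows of X1 and X2."""
    sq_dist = ((X1[:, None, :] - X2[None, :, :]) ** 2).sum(axis=-1)
    return variance * np.exp(-0.5 * sq_dist / lengthscale**2)


def optimal_A(X, Zb, noise_var):
    """A = K_b^{-1} K_{Zb,X} K_{X,Zb} K_b^{-1} / sigma^2 (Proposition 4.3)."""
    Kbb = rbf(Zb, Zb)
    Kinv_Kbx = np.linalg.solve(Kbb, rbf(Zb, X))   # K_b^{-1} K_{Zb,X}, shape (M, N)
    return Kinv_Kbx @ Kinv_Kbx.T / noise_var


def valla_cov(Xs, Zb, A):
    """Predictive covariance K - K_sb (A^{-1} + K_b)^{-1} K_bs."""
    Kbb, Ksb = rbf(Zb, Zb), rbf(Xs, Zb)
    return rbf(Xs, Xs) - Ksb @ np.linalg.solve(np.linalg.inv(A) + Kbb, Ksb.T)
```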
Proposition 4.2 assumes that \(g(\cdot, \hat{\bm \theta}) \in \mathcal{H}\), which may not hold in practice. Universal kernels such as the squared exponential or Matérn kernels have RKHSs that are dense in the space of continuous functions on compact domains. However, whether \(g(\cdot, \hat{\bm \theta}) \in \mathcal{H}\) holds in general remains unknown. From here onward, if \(g(\cdot, \hat{\bm{\theta}}) \notin \mathcal{H}\), \(\mathcal{H}\) is assumed to be sufficiently expressive to contain a close approximation to \(g(\cdot, \hat{\bm{\theta}})\). Consequently, \(g(\cdot, \hat{\bm \theta})\) can be used as the sparse GP posterior mean.
Hessian Approximation in VaLLA
Although VaLLA is formulated through the function-space duality of GPs, it can also be understood as a Hessian approximation method. The predictive covariances of VaLLA are given in Equation \(\eqref{eq:pred_valla}\). Comparing this equation with the exact posterior covariances obtained from the full Hessian in Equation \(\eqref{eq:posterior_gp}\) shows that VaLLA’s approximation of the inverse negative Hessian is \begin{equation} \sigma^2_0\bm{I}_P - \sigma^2_0\Phi_{\mathbf{Z}_\beta}^T(\bm{A}^{-1} + \sigma_0^2\Phi_{\mathbf{Z}_\beta}\Phi_{\mathbf{Z}_\beta}^T)^{-1}\Phi_{\mathbf{Z}_\beta}\sigma_0^2\,. \end{equation} By the Woodbury formula, this equals \((\bm{I}_P/\sigma_0^2 + \Phi_{\mathbf{Z}_\beta}^T \bm{A} \Phi_{\mathbf{Z}_\beta})^{-1}\). Here the rows of \(\Phi_{\mathbf{Z}_\beta}\) are the Jacobians at the inducing points, so that \(\sigma_0^2\Phi_{\mathbf{Z}_\beta}\Phi_{\mathbf{Z}_\beta}^T = \bm K_\beta\). \(\bm{A}\) is a free parameter that VaLLA adjusts by minimizing the KL divergence between stochastic processes, as described above. VaLLA’s negative Hessian approximation is the inverse of this quantity, namely \begin{equation} \bm{I}_P/\sigma_0^2 + \Phi_{\mathbf{Z}_\beta}^T \bm{A} \Phi_{\mathbf{Z}_\beta}\,. \end{equation} It has the same structure as the GGN approximation in Equation \(\eqref{eq:ggn}\) used by LLA, with the data-dependent term replaced by the inducing points and the matrix \(\bm{A}\).
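The Woodbury step above can be verified numerically. In the NumPy sketch below, a random \(M \times P\) matrix stands in for \(\Phi_{\mathbf{Z}_\beta}\), with one Jacobian per row; this is the convention under which \(\Phi_{\mathbf{Z}_\beta}^T\bm A\Phi_{\mathbf{Z}_\beta}\) is \(P\times P\). A random symmetric positive-definite matrix stands in for \(\bm A\), and all sizes are arbitrary.

```python
import numpy as np

rng = np.random.default_rng(0)
P, M, sigma0_2 = 30, 5, 0.7            # parameters, inducing points, prior variance
Phi = rng.normal(size=(M, P))          # stand-in for the Jacobians at Z_beta (one per row)
L = rng.normal(size=(M, M))
A = L @ L.T + M * np.eye(M)            # an arbitrary SPD matrix playing the role of A

# Low-rank form: sigma0^2 I - sigma0^2 Phi^T (A^{-1} + sigma0^2 Phi Phi^T)^{-1} Phi sigma0^2
inner = np.linalg.inv(A) + sigma0_2 * Phi @ Phi.T       # M x M; sigma0^2 Phi Phi^T ~ K_beta
cov_lowrank = sigma0_2 * np.eye(P) - sigma0_2**2 * Phi.T @ np.linalg.solve(inner, Phi)

# Negative-Hessian approximation I/sigma0^2 + Phi^T A Phi, inverted directly (P x P)
neg_hessian = np.eye(P) / sigma0_2 + Phi.T @ A @ Phi
cov_direct = np.linalg.inv(neg_hessian)

print(np.allclose(cov_lowrank, cov_direct))
```

The low-rank form only requires solving an \(M\times M\) system. This keeps the approximation cheap when \(M_\beta \ll P\).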
VaLLA’s predictive mean is anchored to the DNN output. As a result, maximizing the ELBO in Equation \(\eqref{eq:opt_dual}\) is unsuitable for tuning hyper-parameters such as the prior variance \(\sigma_0^2\). Specifically, in a regression problem with Gaussian noise of variance \(\sigma^2\), the first term on the r.h.s. of Equation \(\eqref{eq:opt_dual}\) becomes \begin{equation} \sum_{i=1}^N -\frac{\log(2\pi\sigma^2)}{2} - \frac{(y_i - g(\mathbf{x}_i, \hat{\bm{\theta}}))^2}{2\sigma^2} - \frac{K(\mathbf{x}_i, \mathbf{x}_i)}{2\sigma^2}\,, \label{eq:first_term} \end{equation} where \(K(\mathbf{x}_i, \mathbf{x}_i)\) is the posterior variance at \(\mathbf{x}_i\). The squared residual \((y_i - g(\mathbf{x}_i, \hat{\bm{\theta}}))^2\) is constant because the mean is fixed. Maximizing Equation \(\eqref{eq:first_term}\) w.r.t. \(\sigma_0^2\) drives the prior covariances, \(\sigma_0^2 J_{\hat{\bm{\theta}}}(\mathbf{x})^\text{T} J_{\hat{\bm{\theta}}}(\mathbf{x}' )\), to \(0\). This also drives the posterior variances \(K(\mathbf{x}_i, \mathbf{x}_i)\) to \(0\), effectively cancelling the last term in Equation \(\eqref{eq:first_term}\). The KL term in Equation \(\eqref{eq:opt_dual}\) also attains its optimal value, \(0\), as \(\sigma_0^2 \rightarrow 0\). The underlying reason is that, in sparse GPs, tuning hyper-parameters involves a trade-off between fitting the mean to the training data and reducing the predictive variance. In VaLLA’s setting the predictive mean is fixed, so only the second goal remains and the optimal predictive variance tends to zero.
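The collapse can be illustrated on a toy problem. The sketch below evaluates the data term above when every posterior variance is multiplied by a common scale \(s\). Here \(s\) is a hypothetical stand-in for shrinking \(\sigma_0^2\), and the mean stays fixed at the (synthetic) DNN output. Because the residual term does not depend on \(s\), the data term is maximized at \(s = 0\).

```python
import numpy as np

def elbo_data_term(y, mean, post_var, sigma2):
    """Sum over i of -log(2 pi sigma^2)/2 - (y_i - m_i)^2/(2 sigma^2) - K(x_i, x_i)/(2 sigma^2)."""
    return np.sum(-0.5 * np.log(2 * np.pi * sigma2)
                  - (y - mean) ** 2 / (2 * sigma2)
                  - post_var / (2 * sigma2))

rng = np.random.default_rng(0)
y = rng.normal(size=100)
mean = y + 0.3 * rng.normal(size=100)    # fixed synthetic "DNN" predictions
base_var = rng.uniform(0.5, 1.5, 100)    # posterior variances at unit scale
scales = np.linspace(0.0, 2.0, 21)
values = [elbo_data_term(y, mean, s * base_var, sigma2=0.1) for s in scales]
print(scales[int(np.argmax(values))])    # the optimum is at the smallest scale, 0
```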
To address this issue, an alternative to the objective in Equation \(\eqref{eq:opt_dual}\) based on \(\alpha\)-divergences is proposed. This new objective facilitates hyper-parameter optimization: \begin{equation} \label{eq:alpha} \max_{Q(f)} \, \sum_{i=1}^N\frac{1}{\alpha}\log \mathbb{E}_Q \left[P(y_i|f)^\alpha\right] - \mathrm{KL}\left(Q|P\right)\,. \end{equation} Here \(\alpha\in(0,1]\) is a parameter. Instead of minimizing \(\text{KL}(Q(f)|P(f|\mathbf{y}))\), this objective minimizes the \(\alpha\)-divergence between \(P(f|\mathbf{y})\) and \(Q(f)\) (Li & Gal, 2017). Remarkably, this is achieved simply by changing the data-dependent term in the objective of Equation \(\eqref{eq:opt_dual}\).
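The data term of Equation \(\eqref{eq:alpha}\) has a closed form for a Gaussian likelihood and a Gaussian marginal \(Q(f(\mathbf x_i)) = \mathcal N(m_i, v_i)\). The reason is the identity \(\mathcal N(y|f,\sigma^2)^\alpha = (2\pi\sigma^2)^{(1-\alpha)/2}\alpha^{-1/2}\,\mathcal N(y|f,\sigma^2/\alpha)\). Averaging its right-hand side over \(f\) yields \(\mathcal N(y|m,\sigma^2/\alpha + v)\). The sketch below implements this identity together with a Monte Carlo reference. The identity is derived here for illustration and is not taken from the VaLLA code.

```python
import numpy as np

def log_normal_pdf(y, mean, var):
    return -0.5 * np.log(2 * np.pi * var) - 0.5 * (y - mean) ** 2 / var

def alpha_data_term(y, m, v, sigma2, alpha):
    """(1/alpha) * log E_{f ~ N(m, v)}[ N(y | f, sigma2)^alpha ], in closed form."""
    log_const = 0.5 * (1 - alpha) * np.log(2 * np.pi * sigma2) - 0.5 * np.log(alpha)
    return (log_const + log_normal_pdf(y, m, sigma2 / alpha + v)) / alpha

def alpha_data_term_mc(y, m, v, sigma2, alpha, n=400_000, seed=0):
    # Monte Carlo reference: average the tempered likelihood over samples of f
    f = np.random.default_rng(seed).normal(m, np.sqrt(v), size=n)
    return np.log(np.mean(np.exp(alpha * log_normal_pdf(y, f, sigma2)))) / alpha

print(alpha_data_term(1.0, 0.2, 0.5, 0.3, alpha=0.5),
      alpha_data_term_mc(1.0, 0.2, 0.5, 0.3, alpha=0.5))
```

For \(\alpha = 1\) the term reduces to \(\log\mathcal N(y_i|m_i,\sigma^2+v_i)\), the log predictive density.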
The use of \(\alpha\)-divergences (see Section for more details) for approximate inference has been extensively studied (Bui et al., 2017; Rodrı́guez-Santana & Hernández-Lobato, 2022; Villacampa-Calvo & Hernández-Lobato, 2020). These studies suggest that values of \(\alpha \approx 0\) lead to better predictive means, whereas values of \(\alpha \approx 1\) provide better predictive distributions in terms of log-likelihood. All the experiments use \(\alpha = 1\). In this case, unlike Equation \(\eqref{eq:opt_dual}\), Equation \(\eqref{eq:alpha}\) does not promote \(\sigma_0^2 \rightarrow 0\), because its data-dependent term is the log-likelihood of the training data under the predictive distribution. An unexpected drawback, however, is that Equation \(\eqref{eq:alpha}\) may lead to overfitting. To alleviate this, early stopping with a validation set is used (see Section 4 for further details). Early stopping is also used in other LLA approximations, e.g., ELLA (Deng et al., 2022).
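The following example shows why \(\alpha = 1\) does not push the variance to zero. Consider one observation with fixed mean \(m\) and predictive variance \(\sigma^2 + s\,v\), where the scale \(s\) is a hypothetical stand-in for the prior variance. Setting the derivative of \(\log\mathcal N(y|m,\sigma^2+sv)\) to zero gives \(\sigma^2 + s v = (y-m)^2\). Whenever the residual exceeds the noise level, the optimum therefore lies at \(s>0\). The sketch confirms this by grid search.

```python
import numpy as np

def log_pred_density(y, m, var):
    # alpha = 1 data term for one point: log N(y | m, var)
    return -0.5 * np.log(2 * np.pi * var) - 0.5 * (y - m) ** 2 / var

y, m, sigma2, v = 2.0, 0.5, 0.25, 1.0        # residual 1.5 exceeds the noise std 0.5
scales = np.linspace(0.0, 5.0, 5001)
vals = log_pred_density(y, m, sigma2 + scales * v)
s_best = scales[np.argmax(vals)]
s_theory = ((y - m) ** 2 - sigma2) / v       # from sigma2 + s v = (y - m)^2
print(s_best, s_theory)                      # both approximately 2.0
```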
The objective in Equation \(\eqref{eq:alpha}\) supports mini-batch optimization with cost \(\mathcal{O}(M_\beta^3)\). For \(\alpha=1.0\), the mini-batch objective is \begin{equation} N|\mathcal{B}|^{-1} \sum_{b \in \mathcal{B}} \log \mathbb{E}_Q \left[P(y_b|f(\mathbf{x}_b))\right] -\text{KL}\left(Q|P\right)\,. \label{eq:valla_minibatch_objective} \end{equation} Here, \(\mathcal{B}\) denotes a mini-batch, and the expectation can be computed in closed form in regression problems. In classification, it can be approximated using the softmax approximation of Daxberger et al. (2021b). This cost is sub-linear w.r.t. the training set size, which makes VaLLA applicable to very large datasets.
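The sketch below evaluates the mini-batch estimator above for regression, where \(\log \mathbb{E}_Q[P(y_b|f(\mathbf x_b))] = \log\mathcal N(y_b|g(\mathbf x_b,\hat{\bm\theta}), K(\mathbf x_b,\mathbf x_b)+\sigma^2)\). The KL value is passed in as a placeholder number, and all data are synthetic. The check confirms that averaging the estimator over a partition of the data into equal-size batches recovers the full-data objective. In other words, the estimator is unbiased when batches are drawn uniformly from such a partition.

```python
import numpy as np

def minibatch_objective(y_b, mean_b, var_b, sigma2, kl, N):
    """N/|B| * sum_b log N(y_b | mean_b, var_b + sigma2) - KL  (alpha = 1, Gaussian likelihood)."""
    total_var = var_b + sigma2
    log_lik = -0.5 * np.log(2 * np.pi * total_var) - 0.5 * (y_b - mean_b) ** 2 / total_var
    return N / len(y_b) * np.sum(log_lik) - kl

rng = np.random.default_rng(0)
N, batch = 1000, 100
y = rng.normal(size=N)
mean = y + 0.2 * rng.normal(size=N)    # fixed DNN outputs g(x_i)
var = rng.uniform(0.01, 0.1, N)        # posterior variances K(x_i, x_i)
kl = 3.0                               # placeholder value for KL(Q|P)

full = minibatch_objective(y, mean, var, 0.05, kl, N)
batches = np.array_split(rng.permutation(N), N // batch)
estimates = [minibatch_objective(y[b], mean[b], var[b], 0.05, kl, N) for b in batches]
print(np.isclose(np.mean(estimates), full))
```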
Predictions for a test point \(\mathbf{x}^\star\) with target \(y^\star\) are computed using Equation \(\eqref{eq:pred_valla}\), with the DNN output \(g(\mathbf{x}^\star,\hat{\bm{\theta}})\) as the mean. The predictive density \(P(y^\star|\mathbf{x}^\star) \approx \mathbb{E}_Q[P(y^\star|f(\mathbf{x}^\star))]\) is evaluated in the same way as during training.
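Substituting the dual parameterization \(\bm S = \bm K_\beta - \bm K_\beta(\bm A^{-1} + \bm K_\beta)^{-1}\bm K_\beta\) into the sparse-GP predictive covariance gives \(\mathrm{Var}[f(\mathbf x^\star)] = \kappa(\mathbf x^\star,\mathbf x^\star) - \kappa(\mathbf x^\star,\mathbf Z_\beta)(\bm A^{-1}+\bm K_\beta)^{-1}\kappa(\mathbf Z_\beta,\mathbf x^\star)\). The sketch below assumes this is the scalar-output form of Equation \(\eqref{eq:pred_valla}\) and computes the Gaussian predictive distribution for regression. Kernel values are passed in precomputed, and the function name is hypothetical.

```python
import numpy as np

def valla_predictive(g_star, k_ss, k_sz, K_beta, A, sigma2):
    """Gaussian predictive for y* in regression.

    g_star: DNN outputs g(x*, theta_hat), shape (T,), used as the fixed mean
    k_ss:   kernel diagonal kappa(x*, x*), shape (T,)
    k_sz:   cross-kernel kappa(x*, Z_beta), shape (T, M)
    """
    inner = np.linalg.inv(A) + K_beta                        # A^{-1} + K_beta, M x M
    reduction = np.sum(k_sz * np.linalg.solve(inner, k_sz.T).T, axis=1)
    f_var = k_ss - reduction                                 # Var[f(x*)]
    return g_star, f_var + sigma2                            # mean and variance of y*
```

With \(\mathbf Z_\beta = \mathbf X\) and \(\bm A = \bm I/\sigma^2\), this reproduces the exact GP predictive, in line with the result derived above for \(\bm A^\ast\).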
The inducing point locations \(\mathbf{Z}_\beta\) are initialized with K-means and then optimized by maximizing Equation \(\eqref{eq:valla_minibatch_objective}\).
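A minimal sketch of the initialization step follows. It assumes plain Lloyd's K-means on the training inputs; the actual implementation may use a library routine instead. The returned centers would serve as starting values for \(\mathbf Z_\beta\) before they are optimized jointly with the other parameters. The function name is hypothetical.

```python
import numpy as np

def kmeans_init(X, M, n_iter=50, seed=0):
    """Lloyd's K-means: returns M centers used to initialize the inducing points Z_beta."""
    rng = np.random.default_rng(seed)
    centers = X[rng.choice(len(X), size=M, replace=False)].copy()
    for _ in range(n_iter):
        # Assign every point to its closest center
        d2 = ((X[:, None, :] - centers[None, :, :]) ** 2).sum(-1)
        labels = d2.argmin(axis=1)
        # Move each center to the mean of its points (keep it if its cluster is empty)
        for m in range(M):
            if np.any(labels == m):
                centers[m] = X[labels == m].mean(axis=0)
    return centers
```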
Limitations of VaLLA
VaLLA is limited by four factors. (i) Computing the predictive distribution at each training iteration involves inverting \(\bm{A}^{-1} + \bm{K}_{\beta}\) in Equation \(\eqref{eq:pred_valla}\), at a cost cubic in the number of inducing points \(M_{\beta}\). Therefore, VaLLA cannot accommodate a very large number of inducing points. (ii) The objective in Equation \(\eqref{eq:alpha}\) requires a validation set and early stopping for effective optimization of the prior variance \(\sigma_0^2\), which further increases training time. (iii) VaLLA requires additional training compared to other LLA approximations. Early stopping, however, can also reduce the training time by cutting down the number of iterations: in the Taxi experiments of Section 4.2.5, it is triggered when only \(16.6\%\) of the training data has been seen. (iv) Mini-batch optimization of Equation \(\eqref{eq:valla_minibatch_objective}\) involves evaluating \(\bm K_{\mathbf{x},\mathbf{Z}_\beta}\) for all \(\mathbf{x} \in \mathcal{B}\), as well as \(\bm{K}_{\beta}\). Hence, VaLLA requires efficient evaluation of the (scaled) Neural Tangent Kernel, \(\kappa(\cdot, \cdot) = \sigma_0^2 J_{\hat{\bm{\theta}}}(\cdot)^\text{T} J_{\hat{\bm{\theta}}}(\cdot)\), and of its gradients to find \(\mathbf{Z}_\beta\). Some libraries exploit structure in the derivatives to compute \(\kappa(\cdot, \cdot)\) efficiently, but they support only a few DNN models (Novak et al., 2022). A simple but inefficient way to evaluate \(\kappa(\cdot, \cdot)\) is to compute the full Jacobians for each mini-batch instance and inducing point and store them all in memory. This is tractable in many problems, but makes VaLLA infeasible for very large problems, e.g., ImageNet. Appendix 4 describes a much more efficient layer-by-layer method to obtain \(\bm K_{\mathbf{x},\mathbf{Z}_\beta}\) for all \(\mathbf{x} \in \mathcal{B}\), as well as \(\bm{K}_{\beta}\).
However, this method requires computing each layer’s contribution to the Jacobian, which is difficult for large and complex DNNs.
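The simple but memory-hungry approach can be sketched as follows. It materializes the full Jacobian of a scalar-output model with respect to all \(P\) parameters for every input, and then forms \(\kappa = \sigma_0^2 J J^T\). Central finite differences are used only to keep the example dependency-free; in practice an automatic-differentiation library would supply the Jacobian. The function names and the toy linear model are hypothetical.

```python
import numpy as np

def full_jacobian(f, theta, X, eps=1e-6):
    """Jacobian of the scalar-output model f(X, theta) w.r.t. theta, shape (N, P).

    Memory grows as N x P, which is the bottleneck discussed above.
    """
    J = np.empty((len(X), theta.size))
    for p in range(theta.size):
        step = np.zeros_like(theta)
        step[p] = eps
        J[:, p] = (f(X, theta + step) - f(X, theta - step)) / (2 * eps)
    return J

def ntk_naive(f, theta, X1, X2, prior_var):
    # kappa(x, x') = sigma0^2 J(x)^T J(x'), evaluated for all pairs at once
    return prior_var * full_jacobian(f, theta, X1) @ full_jacobian(f, theta, X2).T

def linear_model(X, theta):
    # g(x) = w^T x + b with theta = [w, b]; its Jacobian is [x, 1]
    return X @ theta[:-1] + theta[-1]
```

For the linear model the kernel is \(\sigma_0^2(\mathbf x^T\mathbf x' + 1)\), which gives a convenient sanity check.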
Related Work
LA for DNNs was originally introduced by MacKay (1992b), applying it to small networks using the full Hessian. MacKay (1992c) also proposed an approximation similar to the generalized Gauss-Newton (GGN). The combination of scalable factorizations or diagonal Hessian approximations (Botev et al., 2017; Martens & Grosse, 2015) with the GGN approximation (Martens, 2020) played a crucial role in the resurgence of LA for modern DNNs (Khan et al., 2019; Ritter et al., 2018). Recent works aim to relax the Gaussian assumption of LLA, adopting a Riemannian-Laplace approximation, where samples naturally fall into weight regions with low negative log-posterior (Bergamin et al., 2023).
To address the underfitting issue associated with LA (Lawrence, 2001), particularly when combined with the GGN approximation, Ritter et al. (2018) proposed a Kronecker factored (KFAC) LLA approximation. This approach outperforms LA with a diagonal Hessian matrix.
The GP interpretation of LLA (Khan et al., 2019) makes it possible to use approximate GP methods to speed up the computations. Immer et al. (2021) propose using a subset of the training data as a scalable alternative to the true GP. Lee et al. (2022) propose a Mixture of Experts approach in which each expert is trained on a different soft-margin cluster. This approach has three drawbacks. First, the proposed clustering algorithm, although more efficient than kernel K-means, has linear cost w.r.t. the training set size, whereas VaLLA has sub-linear training time w.r.t. the training set size thanks to mini-batch training. Second, it is not clear how to account for neighboring clusters in high-dimensional input spaces; the authors only provide code for a 1-dimensional problem. Third, fitting a local GP on the data of the corresponding cluster and its neighbors is expected to overestimate the predictive variance, since each local model is trained on fewer instances (see Figure 4.9). This is particularly the case in datasets with millions of training instances, such as Taxi. Immer et al. (2021) also describe this problem.
Deng et al. (2022) proposed a Nyström approximation of the true GP covariance matrix using \(M\ll N\) points chosen at random from the training set. The method, called ELLA, has cost \(\mathcal{O}(NM^3)\). Like VaLLA, ELLA requires computing the costly Jacobian vectors, although it does not need their gradients. Unlike VaLLA, however, the Nyström approximation needs to visit every instance in the training set. Moreover, as stated by Deng et al. (2022), ELLA suffers from overfitting, which they alleviate with an early-stopping strategy based on a validation set. In this case, ELLA only considers a subset of the training data. Finally, unlike VaLLA, ELLA does not allow for hyper-parameter optimization. The prior variance \(\sigma_0^2\) must instead be tuned by grid search on a validation set, which increases training time significantly.
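For comparison, the Nyström idea can be sketched on a generic kernel matrix. It picks \(M\) columns at random and approximates \(\bm K \approx \bm K_{:,M}\bm K_{M,M}^{+}\bm K_{M,:}\). This is the textbook Nyström construction on an explicit kernel matrix, not ELLA's exact algorithm, and the function name is illustrative.

```python
import numpy as np

def nystrom(K, M, seed=0):
    """Rank-M Nystrom approximation of a PSD kernel matrix K from M random columns."""
    idx = np.random.default_rng(seed).choice(K.shape[0], size=M, replace=False)
    K_nm = K[:, idx]                   # columns of the selected points
    K_mm = K[np.ix_(idx, idx)]         # kernel among the selected points
    return K_nm @ np.linalg.pinv(K_mm) @ K_nm.T
```

When all \(N\) points are selected, the approximation is exact. With \(M \ll N\), only an \(M\times M\) matrix is (pseudo-)inverted.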
The recent work of Scannell et al. (2024) proposes an approach similar to VaLLA, in which an inducing-point sparse approximation is used to construct a GP from a pre-trained DNN. However, two main points differentiate that work from the proposed approach. (i) The pre-trained DNN is not kept as the posterior mean of the model, which can lose prediction performance and departs from LLA’s post-hoc nature and goal. (ii) Instead of using mini-batches to optimize the variational parameters, they perform a full pass over the training data to find the optimal variational parameters. This can result in a slower method than VaLLA, which, thanks to early stopping and stochastic optimization, can avoid iterating over the full dataset.
Samples from a GP posterior can be computed efficiently using stochastic optimization, avoiding the explicit inversion of the kernel matrix (Lin et al., 2024). This approach can be extended to LLA to generate samples from the GP posterior, avoiding the \(\mathcal{O}(N^3)\) cost (Antorán et al., 2023). However, this method cannot provide an estimate of the log-marginal likelihood for hyper-parameter optimization. To address this limitation, Antorán et al. (2023) propose using the EM algorithm, which iteratively generates samples (E-step) and then optimizes the hyper-parameters (M-step). The EM algorithm significantly increases the computational cost, as generating a single sample is as expensive as training the original DNN on the full data. Finally, the method of Antorán et al. (2023) only considers classification problems.
Another GP-based approach for obtaining prediction uncertainty in DNNs is the Spectral-normalized Neural Gaussian Process (SNGP) (Liu et al., 2023), in which the last layer of a DNN is replaced by a GP. This approach allows either (i) fine-tuning a pre-trained DNN model or (ii) training a full DNN model from scratch; the former is used in the experiments. However, replacing the last layer with a GP often reduces the prediction performance of the initial DNN, as also observed in the results of Liu et al. (2023). As a result, this method also lies outside the main objective of LLA-based methods, which is to preserve the initial DNN predictions.
Experiments
VaLLA is evaluated against the canonical LLA baselines implemented in the Laplace library of Daxberger et al. (2021a). VaLLA is trained with a fixed mini-batch size of \(100\). For the regression, MNIST, and Fashion-MNIST benchmarks, a conventional multilayer perceptron (MLP) is employed; the resulting model checkpoints are archived to guarantee full reproducibility. For CIFAR-10, the same ResNet backbone employed by Deng et al. (2022) is adopted, where their published metrics are quoted for all competing methods to maintain strict architectural parity. The prior variance of every LLA variant—diagonal, Kronecker-factored (KFAC), and last-layer—is tuned by maximizing the approximate log marginal likelihood. Specifically, the log_marginal_likelihood routine in the Laplace library is used to optimize them for \(40\ 000\) Adam steps with a learning rate of \(10^{-3}\).
The complete VaLLA implementation is publicly available at https://github.com/Ludvins/Variational-LLA.
Synthetic Regression
The predictive distribution of VaLLA is compared with those of LLA (considered the optimal method), other LLA variants, and ELLA on the 1-D regression problem of Izmailov et al. (2020). In this experiment, ELLA and VaLLA use the optimal hyper-parameters found for LLA. The results in Figure 4.1 show that VaLLA’s predictive distribution closely matches that of LLA. In the ablation studies, Figure 4.7 depicts the predictive distributions of VaLLA and ELLA for varying numbers of inducing points and Nyström points, respectively. VaLLA converges to the true posterior faster than ELLA, and it tends to overestimate the predictive variance while ELLA underestimates it. Figure 4.6 shows the effect of tuning the prior variance in VaLLA on another toy 1-D problem, with and without early stopping. Notably, early stopping on a validation set prevents overly small predictive variances in VaLLA. Finally, when VaLLA estimates the prior variance by maximizing Equation \(\eqref{eq:alpha}\), it tends to underestimate LLA’s predictive variance.
[Figure panels: LLA, VaLLA, ELLA, Last-Layer LLA, MoE LLA, Kronecker LLA]
Airline, Year and Taxi Regression Problems
For regression problems (Year, Airline, and Taxi datasets), a 3-layer fully connected NN with \(200\) units per layer was used. Optimal weights were obtained by minimizing the RMSE over \(20{,}000\) iterations with a batch size of \(100\) and the Adam optimizer (Kingma & Ba, 2015) with learning rate \(10^{-2}\) and weight decay \(10^{-2}\).
Three large regression datasets are used to validate the performance of VaLLA: (i) the Year dataset (UCI) with \(515\ 345\) instances and \(90\) features, using the original train/test splits; (ii) the US flight delay (Airline) dataset (Dutordoir et al., 2020). Following Ortega et al. (2023), the first \(700\ 000\) instances are used for training and the next \(100\ 000\) for testing. Eight features are considered: month, day of the month, day of the week, plane age, air time, distance, arrival time, and departure time; (iii) the Taxi dataset, with data recorded in January 2023 (Salimbeni & Deisenroth, 2017). Nine attributes are considered: time of day, day of week, day of month, month, PULocationID, DOLocationID, distance, and duration. Trips shorter than 10 seconds or longer than 5 hours are filtered out, resulting in about \(3\) million instances. The first \(80\%\) is used as training data, the next \(10\%\) as validation data, and the last \(10\%\) as test data. In all experiments, a 3-layer DNN with \(200\) units, tanh activations, and \(\ell_2\) regularization is employed. VaLLA and ELLA use \(100\) inducing points and \(100\) random points, respectively. VaLLA is trained for a total of \(40\ 000\) iterations with mini-batch size \(100\). However, for the Taxi dataset, which contains nearly \(3\) million instances, early stopping terminates training at \(5\ 000\) iterations for one of the random seed initializations (this value differs across seeds). Thus, \(500\ 000\) points are visited during training for that seed, corresponding to only \(16.6\%\) of the complete dataset.
Table 4.1 presents the averaged results over 5 random seeds. LLA is not considered here due to intractability. Negative log likelihood (NLL), continuous ranked probability score (CRPS) (Gneiting & Raftery, 2007) and a centered quantile metric (CQM), described below, are reported. VaLLA performs best according to NLL and CQM, while it gives worse results in terms of CRPS compared to the other methods.
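Both NLL and CRPS have closed forms for Gaussian predictive distributions \(\mathcal N(\mu, s^2)\). The CRPS expression below is the standard one for the Gaussian (Gneiting & Raftery, 2007). This minimal sketch uses only the standard library. It assumes that per-point scores are simply averaged over the test set to produce the table entries.

```python
import math
from statistics import NormalDist

STD = NormalDist()  # standard normal distribution

def gaussian_nll(y, mu, var):
    return 0.5 * math.log(2 * math.pi * var) + 0.5 * (y - mu) ** 2 / var

def gaussian_crps(y, mu, var):
    """CRPS of N(mu, var) at observation y: s * [z (2 Phi(z) - 1) + 2 phi(z) - 1/sqrt(pi)]."""
    s = math.sqrt(var)
    z = (y - mu) / s
    return s * (z * (2 * STD.cdf(z) - 1) + 2 * STD.pdf(z) - 1 / math.sqrt(math.pi))

def average_scores(ys, mus, variances):
    # Mean NLL and mean CRPS over a test set (lower is better for both)
    n = len(ys)
    nll = sum(map(gaussian_nll, ys, mus, variances)) / n
    crps = sum(map(gaussian_crps, ys, mus, variances)) / n
    return nll, crps
```

Unlike NLL, CRPS is expressed in the units of the target. As the variance shrinks, it reduces to the absolute error \(|y-\mu|\).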
| Model | Airline NLL | Airline CRPS | Airline CQM | Year NLL | Year CRPS | Year CQM | Taxi NLL | Taxi CRPS | Taxi CQM |
|---|---|---|---|---|---|---|---|---|---|
| MAP | \(5.087\) | \(18.436\) | \(0.158\) | \(3.674\) | \(5.056\) | \(0.164\) | \(3.763\) | \(\color{purple}\mathbf{3.753}\) | \(\color{teal}\mathbf{0.227}\) |
| LLA Diag | \(5.096\) | \(\color{purple}\mathbf{18.317}\) | \(0.144\) | \(3.650\) | \(4.957\) | \(0.122\) | \(3.714\) | \(3.979\) | \(0.270\) |
| LLA KFAC | \(5.097\) | \(\color{purple}\mathbf{18.317}\) | \(0.144\) | \(3.650\) | \(\color{teal}\mathbf{4.955}\) | \(0.121\) | \(3.705\) | \(3.977\) | \(0.270\) |
| LLA\(^{\ast}\) | \(5.097\) | \(\color{teal}\mathbf{18.319}\) | \(0.144\) | \(3.650\) | \(\color{purple}\mathbf{4.954}\) | \(0.120\) | \(3.718\) | \(3.975\) | \(0.270\) |
| LLA\(^{\ast}\) KFAC | \(5.097\) | \(\color{purple}\mathbf{18.317}\) | \(0.144\) | \(3.650\) | \(\color{purple}\mathbf{4.954}\) | \(0.120\) | \(3.718\) | \(3.976\) | \(0.270\) |
| ELLA | \(5.086\) | \(18.437\) | \(0.158\) | \(3.674\) | \(5.056\) | \(0.164\) | \(3.753\) | \(\color{teal}\mathbf{3.754}\) | \(\color{teal}\mathbf{0.227}\) |
| VaLLA 100 | \(\color{purple}\mathbf{4.923}\) | \(18.610\) | \(\color{teal}\mathbf{0.109}\) | \(\color{teal}\mathbf{3.527}\) | \(5.071\) | \(\color{teal}\mathbf{0.084}\) | \(\color{teal}\mathbf{3.287}\) | \(3.968\) | \(\color{purple}\mathbf{0.188}\) |
| VaLLA 200 | \(\color{purple}\mathbf{4.918}\) | \(18.615\) | \(\color{purple}\mathbf{0.107}\) | \(\color{purple}\mathbf{3.493}\) | \(5.026\) | \(\color{purple}\mathbf{0.076}\) | \(\color{purple}\mathbf{3.280}\) | \(3.993\) | \(\color{purple}\mathbf{0.188}\) |
Centered Quantile Metric (CQM).
The Centered Quantile Metric (CQM) is introduced as a calibration measure for regression models whose Gaussian predictive distributions share the same mean, as is the case for all LLA variants. CQM extends the well-known Expected Calibration Error from classification to this setting by comparing the nominal and empirical coverage of centered quantile intervals.
For a test input \(\mathbf{x}\) with predictive distribution \(\mathcal{N}(\mu(\mathbf{x}),\sigma^{2}(\mathbf{x}))\), define the \(\alpha\)-central interval \begin{equation} I(\mathbf{x},\alpha)=\bigl(\lambda(-\alpha),\lambda(\alpha)\bigr), \qquad \lambda(\alpha)=\Phi_{\mu(\mathbf{x}),\sigma^{2}(\mathbf{x})}^{-1}\left(\frac{1+\alpha}{2}\right), \quad \alpha\in(0,1), \end{equation} where \(\Phi_{\mu,\sigma^{2}}^{-1}\) denotes the inverse CDF of a Gaussian with mean \(\mu\) and variance \(\sigma^{2}\). Let \begin{equation} \gamma(\alpha)=\mathbb{P}_{(\mathbf{x},y)} \bigl[y\in I(\mathbf{x},\alpha)\bigr] \end{equation} be the probability that a test target falls inside this interval. It is estimated by the fraction of test observations that do so. A perfectly calibrated model satisfies \(\gamma(\alpha)=\alpha\) for every \(\alpha\). CQM aggregates the deviation from this ideal across all confidence levels: \begin{equation} \operatorname{CQM} =\int_{0}^{1}\bigl|\gamma(\alpha)-\alpha\bigr|\,d\alpha. \label{eq:CQM} \end{equation} All LLA methods retain the maximum-a-posteriori estimate as predictive mean, so differences in \(I(\mathbf{x},\alpha)\) arise solely from their predictive variances. The integral in Equation \(\eqref{eq:CQM}\) is approximated on an eleven-point grid over \(\alpha\), which yields both a scalar summary and the calibration curve \(\gamma(\alpha)\). Figure 4.2 shows \(\gamma(\alpha) = \mathbb{P}_{(\mathbf{x}^\star, y^\star)}\left[ y^\star \in I(\mathbf{x}^\star, \alpha) \right]\) as a function of \(\alpha\) for the best-performing models in the regression problems. CQM corresponds to the area between each curve and the diagonal \(y = x\) (black line). The figure indicates that, in general, all methods over-estimate the predictive variance, since their curves lie above the diagonal.
That is, for a given \(\alpha \in (0, 1)\), the observed coverage exceeds \(\alpha\), so on average more points fall in \(I(\mathbf{x}, \alpha)\) than expected. Because the mean is shared, the predicted intervals must be too wide, which can only happen if the variance is over-estimated. Moreover, the intervals are nested, so \(\gamma(\alpha)\) is non-decreasing with \(\gamma(0)=0\) and \(\gamma(1)=1\). The area between such a curve and the diagonal never exceeds \(0.5\), so CQM always lies in \([0, 0.5]\), independently of the model and dataset.
In addition, the figure makes it possible to study visually the level of over- or under-estimation of the predictive uncertainty for each confidence level \(\alpha\). For example, on the Year dataset, VaLLA slightly over-estimates the uncertainty for \(\alpha \in (0, 0.7)\), while it under-estimates it for larger values of \(\alpha\).
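A minimal sketch of the CQM computation follows. It assumes an evenly spaced eleven-point grid \(\alpha \in \{0, 0.1, \dots, 1\}\) and the trapezoidal rule; the exact grid and quadrature behind Table 4.1 are not specified here. Because every interval is centered at \(\mu(\mathbf x)\), membership reduces to \(|y-\mu(\mathbf x)|/\sigma(\mathbf x) < z_\alpha\) with \(z_\alpha = \Phi^{-1}((1+\alpha)/2)\) for the standard normal.

```python
import math
from statistics import NormalDist

def coverage_curve(ys, mus, sigmas, n_grid=11):
    """Empirical coverage gamma(alpha) of the centered intervals on an even grid."""
    z = [abs(y - m) / s for y, m, s in zip(ys, mus, sigmas)]
    alphas = [k / (n_grid - 1) for k in range(n_grid)]
    gammas = []
    for a in alphas:
        if a == 0.0:
            z_a = 0.0          # degenerate interval: empty coverage
        elif a == 1.0:
            z_a = math.inf     # the whole real line
        else:
            z_a = NormalDist().inv_cdf((1 + a) / 2)
        gammas.append(sum(zi < z_a for zi in z) / len(z))
    return alphas, gammas

def cqm(ys, mus, sigmas, n_grid=11):
    """Trapezoidal approximation of the integral of |gamma(alpha) - alpha| over [0, 1]."""
    alphas, gammas = coverage_curve(ys, mus, sigmas, n_grid)
    gaps = [abs(g - a) for a, g in zip(alphas, gammas)]
    h = 1 / (n_grid - 1)
    return h * (sum(gaps) - 0.5 * (gaps[0] + gaps[-1]))
```

An extremely over-dispersed model, with every target exactly at the mean, gives \(\gamma(\alpha)=1\) for all \(\alpha>0\). This is the extreme case of the bound above; on the eleven-point grid it evaluates to \(0.45\) rather than \(0.5\).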
[Figure panels: Year, Airline, Taxi]
Image Classification Problems
MNIST and FMNIST.
For MNIST and FMNIST experiments, a 2-layer fully connected NN was used with \(200\) units in each layer. The optimal weights are obtained by minimizing the NLL using \(20\ 000\) iterations of batch size \(100\) and Adam optimizer (Kingma & Ba, 2015) with learning rate \(10^{-3}\) and weight decay \(10^{-3}\).
A fully connected DNN with \(200\) units in each layer and tanh activations is employed. VaLLA is run with \(100\) and \(200\) inducing points, whereas ELLA uses \(2\ 000\) random points. The out-of-distribution (OOD) detection ability of each method is evaluated using the entropy of the predictive distribution as the score. The area under the ROC curve (AUC) is computed for the binary task of distinguishing test instances of the training dataset from those of the other dataset, for the pairs MNIST/FMNIST and FMNIST/MNIST (Immer et al., 2021). Moreover, on FMNIST, the robustness of the predictive distribution is assessed by rotating the test images by up to \(180^\circ\) and computing the ECE and NLL on the rotated images (Ovadia et al., 2019) (Figure 4.3).
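The sketch below shows the OOD score and the AUC. The entropy of each predictive distribution serves as the score, with higher entropy meaning more likely OOD. The AUC is computed through its rank interpretation: the probability that a random OOD instance scores higher than a random in-distribution one, with ties counted as one half. The function names are illustrative.

```python
import numpy as np

def predictive_entropy(probs, eps=1e-12):
    """Entropy of each row of an (n, classes) matrix of predictive probabilities."""
    return -np.sum(probs * np.log(probs + eps), axis=1)

def ood_auc(probs_in, probs_out):
    """AUC for detecting OOD inputs (positives) with entropy as the score.

    Rank interpretation: P(score_out > score_in) + 0.5 * P(tie).
    The pairwise comparison is quadratic in memory, so it suits small test sets.
    """
    s_in = predictive_entropy(probs_in)
    s_out = predictive_entropy(probs_out)
    greater = (s_out[:, None] > s_in[None, :]).mean()
    ties = (s_out[:, None] == s_in[None, :]).mean()
    return greater + 0.5 * ties
```

For full test sets, scikit-learn's `roc_auc_score` gives the same value without the quadratic memory. It should be called on the concatenated entropies with labels 0 (in-distribution) and 1 (OOD).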
Table 4.2 shows the results on MNIST. VaLLA gives better uncertainty estimates in terms of NLL and the Brier score, but performs less well in terms of ECE. Remarkably, VaLLA improves prediction accuracy (ACC), owing to the approximation of Daxberger et al. (2021b) used to compute class probabilities in multi-class problems. In terms of OOD-AUC, VaLLA outperforms the MAP solution but lags behind other methods such as Sampled LLA or LLA with Kronecker approximations. Figure 4.4 illustrates the training time of each method: VaLLA is faster than ELLA, Sampled LLA, and Last-Layer LLA.
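The following sketch computes the two calibration-related metrics in a common way. ECE uses equal-width bins over the top-class confidence; 15 bins is an assumption, since the bin count used in the experiments is not stated. The multi-class Brier score is computed as the mean squared distance between the probability vector and the one-hot label. Conventions vary (some works divide the Brier score by the number of classes), so the sketch is illustrative only.

```python
import numpy as np

def expected_calibration_error(probs, labels, n_bins=15):
    """ECE with equal-width bins over the top-class confidence."""
    conf = probs.max(axis=1)
    correct = (probs.argmax(axis=1) == labels).astype(float)
    edges = np.linspace(0.0, 1.0, n_bins + 1)
    ece = 0.0
    for lo, hi in zip(edges[:-1], edges[1:]):
        in_bin = (conf > lo) & (conf <= hi)
        if in_bin.any():
            # Bin weight times |accuracy - mean confidence| within the bin
            ece += in_bin.mean() * abs(correct[in_bin].mean() - conf[in_bin].mean())
    return ece

def brier_score(probs, labels):
    # Mean over instances of the squared distance to the one-hot label
    onehot = np.eye(probs.shape[1])[labels]
    return np.mean(np.sum((probs - onehot) ** 2, axis=1))
```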
| Model | ACC | NLL | ECE | BRIER | OOD-AUC |
|---|---|---|---|---|---|
| MAP | \(\color{teal}\mathbf{97.6}\) | \(\color{teal}\mathbf{0.076}\) | \(\color{purple}\mathbf{0.008}\) | \(\color{teal}\mathbf{0.036}\) | \(0.905\) |
| LLA Diag | \(97.4\) | \(0.143\) | \(0.072\) | \(0.053\) | \(0.922\) |
| LLA KFAC | \(97.5\) | \(0.094\) | \(0.029\) | \(0.041\) | \(\color{teal}\mathbf{0.949}\) |
| LLA\(^{\star}\) | \(\color{teal}\mathbf{97.6}\) | \(0.081\) | \(0.015\) | \(0.037\) | \(0.909\) |
| LLA\(^{\star}\) KFAC | \(\color{teal}\mathbf{97.6}\) | \(0.081\) | \(0.015\) | \(0.037\) | \(0.909\) |
| ELLA | \(\color{teal}\mathbf{97.6}\) | \(\color{teal}\mathbf{0.076}\) | \(\color{purple}\mathbf{0.008}\) | \(\color{teal}\mathbf{0.036}\) | \(0.905\) |
| Sampled LLA | \(\color{teal}\mathbf{97.6}\) | \(0.087\) | \(0.026\) | \(0.040\) | \(\color{purple}\mathbf{0.954}\) |
| VaLLA 100 | \(\color{purple}\mathbf{97.7}\) | \(\color{teal}\mathbf{0.076}\) | \(\color{teal}\mathbf{0.010}\) | \(\color{teal}\mathbf{0.036}\) | \(0.916\) |
| VaLLA 200 | \(\color{purple}\mathbf{97.7}\) | \(\color{purple}\mathbf{0.075}\) | \(\color{teal}\mathbf{0.010}\) | \(\color{purple}\mathbf{0.035}\) | \(0.921\) |
| Model | ACC | NLL | ECE | BRIER | OOD-AUC |
|---|---|---|---|---|---|
| MAP | \(86.6\) | \(0.373\) | \(\color{teal}\mathbf{0.009}\) | \(0.193\) | \(0.874\) |
| LLA Diag | \(86.2\) | \(0.397\) | \(0.043\) | \(0.201\) | \(0.914\) |
| LLA KFAC | \(86.5\) | \(0.377\) | \(0.014\) | \(0.194\) | \(\color{teal}\mathbf{0.932}\) |
| LLA\(^{\star}\) | \(86.6\) | \(0.373\) | \(\color{purple}\mathbf{0.008}\) | \(0.193\) | \(0.882\) |
| LLA\(^{\star}\) KFAC | \(86.6\) | \(0.373\) | \(\color{purple}\mathbf{0.008}\) | \(0.193\) | \(0.880\) |
| ELLA | \(86.6\) | \(0.373\) | \(\color{purple}\mathbf{0.008}\) | \(0.193\) | \(0.874\) |
| VaLLA 100 | \(\color{teal}\mathbf{87.4}\) | \(\color{teal}\mathbf{0.335}\) | \(0.011\) | \(\color{teal}\mathbf{0.182}\) | \(0.923\) |
| VaLLA 200 | \(\color{purple}\mathbf{87.6}\) | \(\color{purple}\mathbf{0.332}\) | \(0.013\) | \(\color{purple}\mathbf{0.181}\) | \(\color{purple}\mathbf{0.933}\) |
Finally, Table 4.2 displays the results on FMNIST. Here, VaLLA achieves the best prediction accuracy and the best uncertainty estimates in terms of NLL and the Brier score. Although it does not perform as well in terms of ECE, the differences are small. VaLLA also achieves the best OOD-AUC. Figure 4.3 shows that VaLLA maintains better ECE and NLL as the corruption of the test images (rotation level) increases, indicating the greater robustness of VaLLA’s predictive distribution.
CIFAR10 and ResNet.
Various ResNet architectures are used, and the corresponding pre-trained models are those of Deng et al. (2022) (accessible at https://github.com/chenyaofo/pytorch-cifar-models). Table 4.4 reports ACC, NLL, and ECE for each method. The methods include LLA variants, a mean-field VI approach (Deng & Zhu, 2023), fine-tuned SNGP (Liu et al., 2023), and the approach of Immer et al. (2021) that uses the GP interpretation of LLA with a random subset of \(500\) training instances (GP-Subset). For the latter, the prior parameter was set to the weight decay value used when training the MAP solution, \(\sigma_0^2 = 0.04\), as suggested by Immer et al. (2021). The prior was not scaled by the subset size, as scaling resulted in worse performance. For this experiment, VaLLA is trained for \(40\ 000\) iterations or until early stopping is triggered. Since the same pre-trained models are used, the results of all other methods are consistent with those reported by Deng et al. (2022). VaLLA with \(M_\beta=100\) outperforms the other methods in most cases, consistently ranking as the best or second-best method. Figure 4.5 shows the NLL on the perturbed test set with five increasing levels of \(19\) image corruptions (Deng et al., 2022). Each box plot summarizes the test NLL at one intensity level across all \(19\) corruptions. The results again highlight VaLLA’s robust predictive distribution, which achieves lower NLL than the other methods.
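| Method | ResNet-20 ACC | ResNet-20 NLL | ResNet-20 ECE | ResNet-32 ACC | ResNet-32 NLL | ResNet-32 ECE | ResNet-44 ACC | ResNet-44 NLL | ResNet-44 ECE | ResNet-56 ACC | ResNet-56 NLL | ResNet-56 ECE | Rank |
|---|---|---|---|---|---|---|---|---|---|---|---|---|---|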
| MAP | \(\color{teal} \mathbf{92.6}\) | \(0.282\) | \(0.039\) | \(\color{purple} \mathbf{93.5}\) | \(0.292\) | \(0.041\) | \(\color{purple} \mathbf{94.0}\) | \(0.275\) | \(0.039\) | \(\color{purple} \mathbf{94.4}\) | \(0.252\) | \(0.037\) | \(-\) |
| MF-VI | \(\color{purple} \mathbf{92.7}\) | \(\color{teal}\mathbf{0.231}\) | \(0.016\) | \(\color{purple} \mathbf{93.5}\) | \(0.222\) | \(0.020\) | \(\color{teal} \mathbf{93.9}\) | \(0.206\) | \(0.018\) | \(\color{purple} \mathbf{94.4}\) | \(0.188\) | \(0.016\) | \(-\) |
| SNGP | \(92.4\) | \(0.266\) | \(0.024\) | \(93.2\) | \(0.256\) | \(0.025\) | \(93.8\) | \(0.242\) | \(0.028\) | \(93.8\) | \(0.229\) | \(0.022\) | \(-\) |
| GP - Subset | \(\color{teal} \mathbf{92.6}\) | \(0.555\) | \(0.299\) | \(\color{teal} \mathbf{93.4}\) | \(0.462\) | \(0.247\) | \(93.6\) | \(0.424\) | \(0.225\) | \(\color{purple} \mathbf{94.4}\) | \(0.403\) | \(0.221\) | \(-\) |
| LLA Diag | \(92.2\) | \(0.728\) | \(0.404\) | \(92.7\) | \(0.755\) | \(0.430\) | \(92.8\) | \(0.778\) | \(0.445\) | \(\color{teal}\mathbf{92.9}\) | \(0.843\) | \(0.480\) | \(-\) |
| LLA KFAC | \(92.0\) | \(0.852\) | \(0.467\) | \(91.8\) | \(1.027\) | \(0.547\) | \(91.4\) | \(1.091\) | \(0.566\) | \(89.8\) | \(1.174\) | \(0.579\) | \(-\) |
| LLA\(^{\ast}\) | \(\color{teal} \mathbf{92.6}\) | \(0.269\) | \(0.034\) | \(\color{purple} \mathbf{93.5}\) | \(0.259\) | \(0.033\) | \(\color{purple} \mathbf{94.0}\) | \(0.237\) | \(0.028\) | \(\color{purple} \mathbf{94.4}\) | \(0.213\) | \(0.022\) | \(-\) |
| LLA\(^{\ast}\) KFAC | \(\color{teal} \mathbf{92.6}\) | \(0.271\) | \(0.035\) | \(\color{purple} \mathbf{93.5}\) | \(0.260\) | \(0.033\) | \(\color{purple} \mathbf{94.0}\) | \(0.232\) | \(0.028\) | \(\color{purple} \mathbf{94.4}\) | \(0.202\) | \(0.024\) | \(-\) |
| ELLA | \(92.5\) | \(0.233\) | \(0.009\) | \(\color{purple} \mathbf{93.5}\) | \(\color{teal}\mathbf{0.215}\) | \(\color{teal}\mathbf{0.008}\) | \(\color{teal} \mathbf{93.9}\) | \(0.204\) | \(\color{purple} \mathbf{0.007}\) | \(\color{purple} \mathbf{94.4}\) | \(0.187\) | \(\color{purple} \mathbf{0.007}\) | \(2.37\) |
| Sampled LLA | \(92.5\) | \(\color{teal} \mathbf{0.231}\) | \(\color{purple} \mathbf{0.006}\) | \(\color{purple}\mathbf{93.5}\) | \(0.217\) | \(\color{teal}\mathbf{0.008}\) | \(\color{purple}\mathbf{94.0}\) | \(\color{teal} \mathbf{0.200}\) | \(\color{purple}\mathbf{0.007}\) | \(\color{purple}{\mathbf{94.4}}\) | \(\color{teal} \mathbf{0.185}\) | \(0.015\) | \(\color{teal}\mathbf{2.00}\) |
| VaLLA | \(\color{teal} \mathbf{92.6}\) | \(\color{purple} \mathbf{0.228}\) | \(\color{teal}\mathbf{0.007}\) | \(\color{purple} \mathbf{93.5}\) | \(\color{purple} \mathbf{0.211}\) | \(\color{purple} \mathbf{0.007}\) | \(\color{purple} \mathbf{94.0}\) | \(\color{purple} \mathbf{0.198}\) | \(\color{teal}\mathbf{0.008}\) | \(\color{purple} \mathbf{94.4}\) | \(\color{purple} \mathbf{0.183}\) | \(\color{teal} \mathbf{0.009}\) | \(\color{purple}\mathbf{1.37}\) |
Ablation Studies
To better understand the contribution of each design choice in the proposed approach, the next subsection presents a detailed ablation analysis of its main components.
Efficient Kernel Computation for MLP
This section reviews an efficient implementation for computing the Neural Tangent Kernel \(\kappa(\mathbf x, \mathbf x')\). First, note that the kernel reduces to a sum over the parameters of the model: \begin{equation} \kappa(\mathbf x, \mathbf x') = \sigma_0^2 J_{\hat{\bm{\theta}}}(\mathbf x)^T J_{\hat{\bm{\theta}}}(\mathbf x') = \sigma_0^2\sum_{\theta_s \in \hat{\bm{\theta}}} \frac{\partial}{\partial \theta_s} g(\mathbf{x}, \hat{\bm{\theta}}) \frac{\partial}{\partial \theta_s} g(\mathbf{x}', \hat{\bm{\theta}}). \end{equation} A major limitation when computing the kernel is storing \(J_{\hat{\bm{\theta}}}(\mathbf x)\) in memory, a 3-dimensional tensor of shape (batch size, number of classes, number of parameters). Computing the kernel as a sum simplifies the required computations significantly, since the Jacobians no longer need to be stored in memory. Consider now an MLP \begin{equation} g(\mathbf{x}, \hat{\bm{\theta}}) = h_L \circ a \circ h_{L-1} \circ \cdots \circ a \circ h_1(\mathbf{x})\,, \end{equation} where \(a\) is a non-linear activation function and each \(h_l\) is a linear function of the form \begin{equation} h_l(\mathbf x) = \bm{W}_l^T \mathbf x + \bm{b}_l\,. \end{equation} Thus, \(g\) is a fully connected neural network with \(L\) layers. Its partial derivatives are \begin{equation} \frac{\partial}{\partial W_{l,j,i}} g(\mathbf{x}, \hat{\bm{\theta}}) \quad \text{and}\quad \frac{\partial}{\partial b_{l,j}} g(\mathbf{x}, \hat{\bm{\theta}}) \quad \forall l=1,\dots,L\,, \end{equation} and the kernel is computed simply by summing the products of these derivatives. Here, \(i\) indexes the inputs to layer \(l\). Likewise, \(j\) indexes the components of the bias vector \(\bm b_l\) or, equivalently, the outputs of that layer.
Using the structure of the model and the chain rule, the derivative of the \(o^{th}\) output of the network w.r.t. the weight \(W_{l,j,i}\) of the \(l^{th}\) layer is \begin{equation} \frac{\partial}{\partial W_{l,j,i}} g_o(\mathbf{x}, \hat{\bm{\theta}}) = \left(\textcolor{teal}{\frac{\partial}{\partial h_l} g_o(\mathbf{x}, \hat{\bm{\theta}})}\right) ^T \left(\textcolor{red}{\frac{\partial}{\partial W_{l,j,i}} h_l}\right)\,, \label{eq:output_derivative} \end{equation} where each of the two vectors on the r.h.s. has length equal to the number of units in layer \(l\). Since \(W_{l,j,i}\) only affects the \(j^{th}\) output of layer \(l\), \begin{equation} \textcolor{red}{\frac{\partial}{\partial W_{l,j,i}} h_l} = \mathbf{e}_j \cdot \textcolor{purple}{a}(\textcolor{blue}{h_{l-1}})_i\,, \label{mlp:structured_derivative} \end{equation} where \(\mathbf{e}_j\) is the \(j^{th}\) canonical basis vector and \(\textcolor{purple}{a}(\textcolor{blue}{h_{l-1}})_i\) is the \(i^{th}\) input of the \(l^{th}\) layer (for \(l=1\), the \(i^{th}\) component of \(\mathbf{x}\)). Moreover, \(\frac{\partial}{\partial h_l} g_o(\mathbf{x}, \hat{\bm{\theta}})\) can also be computed using the chain rule: \begin{equation} \textcolor{teal}{\frac{\partial}{\partial h_l} g_o(\mathbf{x}, \hat{\bm{\theta}})} = \frac{\partial}{\partial h_{l+1}} g_o(\mathbf{x}, \hat{\bm{\theta}})\frac{\partial}{\partial h_{l}} h_{l+1} = \textcolor{teal}{\frac{\partial}{\partial h_{l+1}} g_o(\mathbf{x}, \hat{\bm{\theta}})} \textcolor{orange}{\bm{W}_{l+1}}^T \text{diag}(\textcolor{cyan}{a'}(\textcolor{blue}{h_{l}}))\,, \end{equation} which can be computed by back-propagating the derivatives from the output layer. The same derivations apply to the biases \(b_{l,j}\) of each layer, with \(\frac{\partial}{\partial b_{l,j}} h_l = \mathbf{e}_j\).
As a result, the derivatives depend only on a back-propagated term \(\textcolor{teal}{\frac{\partial}{\partial h_l} g_o(\mathbf{x}, \hat{\bm{\theta}})}\) for each layer, the parameter values \(\textcolor{orange}{\bm{W}_l}, \textcolor{orange}{\bm{b}_l}\), and the intermediate outputs \(\textcolor{blue}{h_1},\dots,\textcolor{blue}{h_{L-1}}\), passed through the non-linear activation \(\textcolor{purple}{a}(\cdot)\) and its derivative \(\textcolor{cyan}{a'}(\cdot)\). Consequently, if the intermediate outputs of each layer (\(\textcolor{blue}{h_1},\dots,\textcolor{blue}{h_{L-1}}\)) are stored during the forward pass, a single backward pass per output \(o\) suffices to compute \(\textcolor{teal}{\frac{\partial}{\partial h_l} g_o(\mathbf{x}, \hat{\bm{\theta}})}\) for every layer.
Critically, given each \(\textcolor{teal}{\frac{\partial}{\partial h_l} g_o(\mathbf{x}, \hat{\bm{\theta}})}\), the contribution of each layer to the kernel can be accumulated using Equation \(\eqref{eq:output_derivative}\). In this process, computations can be accelerated by exploiting the structure of the derivatives. For example, in Equation \(\eqref{mlp:structured_derivative}\) the derivative has a simple form: a canonical basis vector scaled by a scalar. Therefore, for two instances \(\mathbf{x}\) and \(\mathbf{x}'\), the kernel contribution of \(W_{l,j,i}\) (ignoring the prior variance parameter) corresponding to outputs \(o\) and \(o'\) is: \begin{align*} \tfrac{\partial}{\partial W_{l,j,i}} &g_o(\mathbf{x}, \hat{\bm{\theta}}) \tfrac{\partial}{\partial W_{l,j,i}} g_{o'}(\mathbf{x}', \hat{\bm{\theta}}) = \left(\textcolor{teal}{\tfrac{\partial}{\partial h_l} g_o(\mathbf{x}, \hat{\bm{\theta}})}\right) ^T \left(\textcolor{red}{\tfrac{\partial}{\partial W_{l,j,i}} h_l(\mathbf{x})}\right) \left(\textcolor{teal}{\tfrac{\partial}{\partial h_l} g_{o'}(\mathbf{x}', \hat{\bm{\theta}})}\right) ^T \left(\textcolor{red}{\tfrac{\partial}{\partial W_{l,j,i}} h_l(\mathbf{x}')}\right) \\ &= \left(\textcolor{teal}{\tfrac{\partial}{\partial h_l} g_o(\mathbf{x}, \hat{\bm{\theta}})}\right) ^T \left(\textcolor{red}{\tfrac{\partial}{\partial W_{l,j,i}} h_l(\mathbf{x})}\right) \left(\textcolor{red}{\tfrac{\partial}{\partial W_{l,j,i}} h_l(\mathbf{x}')}\right)^T \left(\textcolor{teal}{\tfrac{\partial}{\partial h_l} g_{o'}(\mathbf{x}', \hat{\bm{\theta}})}\right) \\ &= \left(\textcolor{teal}{\tfrac{\partial}{\partial h_l} g_o(\mathbf{x}, \hat{\bm{\theta}})}\right) ^T \mathbf{e}_j \cdot \textcolor{purple}{a}(\textcolor{blue}{h_{l-1}(\mathbf{x})})_i \cdot \textcolor{purple}{a}(\textcolor{blue}{h_{l-1}(\mathbf{x}')})_i\, \mathbf{e}_j^T \left(\textcolor{teal}{\tfrac{\partial}{\partial h_l} g_{o'}(\mathbf{x}', \hat{\bm{\theta}})}\right) \\ &= s_{o,\mathbf{x}}^{l,j}\, s_{o',\mathbf{x}'}^{l,j}\, \textcolor{purple}{a}(\textcolor{blue}{h_{l-1}(\mathbf{x})})_i\, \textcolor{purple}{a}(\textcolor{blue}{h_{l-1}(\mathbf{x}')})_i \,, \end{align*} with \(s_{o,\mathbf{x}}^{l,j} = \left(\textcolor{teal}{\frac{\partial}{\partial h_l} g_o(\mathbf{x}, \hat{\bm{\theta}})}\right) ^T \mathbf{e}_j\) a scalar, namely the \(j^{th}\) entry of the back-propagated vector. Summing over \(i\) and \(j\), and adding the analogous bias terms \(s_{o,\mathbf{x}}^{l,j}\, s_{o',\mathbf{x}'}^{l,j}\), the contribution of layer \(l\) factorizes into two inner products, \begin{equation*} \Big(\sum_{j} s_{o,\mathbf{x}}^{l,j}\, s_{o',\mathbf{x}'}^{l,j}\Big)\Big(1 + \sum_{i} \textcolor{purple}{a}(\textcolor{blue}{h_{l-1}(\mathbf{x})})_i\, \textcolor{purple}{a}(\textcolor{blue}{h_{l-1}(\mathbf{x}')})_i\Big)\,, \end{equation*} so the per-parameter Jacobian never needs to be formed. Similar simplifications occur in the case of, e.g., a convolutional layer.
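The layer-wise factorization can be checked numerically. The following is a minimal NumPy sketch, not the implementation used in this thesis: it assumes tanh hidden activations, a linear output layer, weights stored as \(\bm{W}_l\) with shape \((n_{l-1}, n_l)\), and \(\sigma_0^2 = 1\); all function names are hypothetical. Each layer contributes \((\bm{\delta}_o(\mathbf{x})^T\bm{\delta}_{o'}(\mathbf{x}'))(1 + a(h_{l-1}(\mathbf{x}))^T a(h_{l-1}(\mathbf{x}')))\), where \(\bm{\delta}_o = \partial g_o / \partial h_l\) is obtained by one backward pass and the 1 accounts for the bias. The result is compared against an explicit finite-difference Jacobian.

```python
import numpy as np


def init_mlp(sizes, seed=0):
    """Random MLP parameters; W_l has shape (n_{l-1}, n_l) so that h_l = W_l^T z + b_l."""
    rng = np.random.default_rng(seed)
    return [(rng.normal(size=(m, n)) / np.sqrt(m), rng.normal(size=n))
            for m, n in zip(sizes[:-1], sizes[1:])]


def forward(params, x):
    """Stores the input a(h_{l-1}) of every layer (x itself for l = 1) and every h_l."""
    inputs, pre = [], []
    z = x
    for l, (W, b) in enumerate(params):
        inputs.append(z)
        h = W.T @ z + b
        pre.append(h)
        z = np.tanh(h) if l < len(params) - 1 else h   # linear output layer
    return inputs, pre


def backward_deltas(params, pre, o):
    """One backward pass: delta_l = d g_o / d h_l for every layer l."""
    delta = np.zeros_like(pre[-1])
    delta[o] = 1.0                                   # g = h_L, so delta_L = e_o
    deltas = [delta]
    for l in range(len(params) - 2, -1, -1):
        W_next = params[l + 1][0]
        delta = (W_next @ delta) * (1.0 - np.tanh(pre[l]) ** 2)   # tanh' = 1 - tanh^2
        deltas.insert(0, delta)
    return deltas


def ntk_entry(params, x, xp, o, op, sigma0_sq=1.0):
    """kappa_{o,o'}(x, x') accumulated layer by layer, without forming any Jacobian."""
    ax, hx = forward(params, x)
    axp, hxp = forward(params, xp)
    dx, dxp = backward_deltas(params, hx, o), backward_deltas(params, hxp, op)
    total = 0.0
    for a1, a2, d1, d2 in zip(ax, axp, dx, dxp):
        # weights contribute (d1 . d2)(a1 . a2); biases contribute d1 . d2
        total += (d1 @ d2) * (a1 @ a2 + 1.0)
    return sigma0_sq * total


def fd_gradient(params, x, o, eps=1e-6):
    """Reference only: central finite differences of g_o w.r.t. every parameter."""
    grads = []
    for W, b in params:
        for P in (W, b):
            flat = P.reshape(-1)                     # a view: edits modify P in place
            for k in range(flat.size):
                old = flat[k]
                flat[k] = old + eps
                up = forward(params, x)[1][-1][o]
                flat[k] = old - eps
                down = forward(params, x)[1][-1][o]
                flat[k] = old                        # restore the parameter
                grads.append((up - down) / (2 * eps))
    return np.array(grads)


params = init_mlp([3, 5, 4, 2])
rng = np.random.default_rng(1)
x, xp = rng.normal(size=3), rng.normal(size=3)
k_fast = ntk_entry(params, x, xp, o=0, op=1)
k_ref = fd_gradient(params, x, 0) @ fd_gradient(params, xp, 1)
print(k_fast, k_ref)
```

The two numbers should agree up to finite-difference error; the cost of `ntk_entry` is that of two forward and two backward passes, independently of the number of parameters.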
In summary, with this method all the required kernel matrices, for a mini-batch of data points and a set of inducing points, can be computed efficiently at a cost similar to that of passing the mini-batch or the inducing points through the DNN. A disadvantage, however, is that these computations must be coded manually for each DNN architecture, which becomes tedious for very large DNNs with complicated layers, as described in Section 4.2.3.2.
Over-fitting and Early Stopping
Standard maximization of the ELBO does not allow the prior variance to be optimized: because the mean is fixed at the MAP solution, the optimal prior variance becomes infinite. As discussed, this issue is circumvented by using \(\alpha\)-divergences, which are well defined in this learning setup and allow the prior to be optimized. However, this objective is not perfect and tends to overfit the prior variance to the training data. The middle column of Figure 4.6 shows the predictive distribution (twice the standard deviation) learned by VaLLA using the black points as training data. The MAP solution is obtained with an MLP with two hidden layers of \(50\) units and tanh activations, trained to minimize the training RMSE for \(10\,000\) iterations using Adam with learning rate \(10^{-3}\). VaLLA, in contrast, is trained for \(20\,000\) iterations. As shown in the figure, the prior variance overfits the data, so the uncertainty does not increase in the central gap of the data.
Figure 4.6 panels (left to right, top row then bottom row): VaLLA Val \(M=5\), VaLLA No-Val \(M=5\), LLA; VaLLA Val \(M=10\), VaLLA No-Val \(M=10\), ELLA \(M=10\).
In this experiment, VaLLA optimizes hyper-parameters jointly with the variational objective. LLA optimizes the prior variance and the likelihood variance by maximizing the marginal log-likelihood, and ELLA uses the optimal hyper-parameters obtained by LLA.
Two simple courses of action are available: (i) return to the original ELBO and select the prior variance via cross-validation; or (ii) apply early stopping to the \(\alpha\)-divergence objective, using a validation set to halt training before the prior variance overfits. The latter approach may not always succeed, as it assumes that, at some point during training, the prior variance captures the underlying data without overfitting. However, when the prior variance is initialized to a value that is large relative to the small, overfitting-prone optimum, this approach yields strong performance for VaLLA while avoiding the cost of cross-validation. The left column of Figure 4.6 shows the predictive distribution (twice the standard deviation) learned by VaLLA in this setting, using the black points as training data and the orange points as the validation set. In the experiments, the validation NLL was computed every \(100\) training iterations, and training was stopped as soon as this metric worsened, which also reduces computational time.
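The stopping rule described above can be sketched as follows. This is a minimal illustration, not the code used in the experiments: `train_step` and `validation_nll` are hypothetical callables standing in for one optimization step of the \(\alpha\)-divergence objective and an evaluation of the NLL on the validation set.

```python
import math


def train_with_early_stopping(train_step, validation_nll, max_iters=20_000, check_every=100):
    """Run train_step() until the validation NLL, checked every `check_every`
    iterations, increases for the first time. Returns (stop_iteration, best_nll)."""
    best = math.inf
    for it in range(1, max_iters + 1):
        train_step()
        if it % check_every == 0:
            nll = validation_nll()
            if nll > best:   # metric worsened: stop before the prior variance overfits
                return it, best
            best = nll
    return max_iters, best


# Usage with a fake validation curve that starts to worsen at the 4th check.
curve = iter([3.0, 2.5, 2.2, 2.4, 2.1])
calls = {"steps": 0}


def fake_step():
    calls["steps"] += 1


stop_iter, best_nll = train_with_early_stopping(fake_step, lambda: next(curve))
print(stop_iter, best_nll)  # 400 2.2
```

A practical variant would also keep a copy of the parameters at the best check, so that the returned model corresponds to `best_nll` rather than to the last iteration.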
Increasing Inducing Points
Using the optimal covariance in Proposition 4.3, if the inducing points equal the training points, \(\mathbf{Z} = \mathbf{X}\), the posterior distribution of VaLLA coincides with that of the exact LLA Gaussian process. This suggests that increasing the number of inducing points brings the uncertainty estimates of VaLLA closer to those of LLA. This section examines how closely the predictive distribution of VaLLA matches that of LLA as the number of inducing points in \(\mathbf{Z}\) increases.
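The analogous property for a standard sparse variational GP can be checked numerically. The sketch below is an illustration under simplifying assumptions, not VaLLA itself: it uses a squared-exponential kernel instead of the NTK, one-dimensional inputs, and the optimal Gaussian variational distribution over the inducing values (with covariance \(\bm{K}_{uu}\bm{\Sigma}\bm{K}_{uu}\), \(\bm{\Sigma}^{-1} = \bm{K}_{uu} + \sigma^{-2}\bm{K}_{uf}\bm{K}_{fu}\)). With \(\mathbf{Z} = \mathbf{X}\), the sparse predictive variance matches the exact GP variance; function names are hypothetical.

```python
import numpy as np


def se_kernel(a, b, lengthscale=1.0, amplitude=1.0):
    return amplitude * np.exp(-0.5 * (a[:, None] - b[None, :]) ** 2 / lengthscale ** 2)


def diag_quad(K_su, M, K_us):
    """diag(K_su M^{-1} K_us) without forming M^{-1} explicitly."""
    return np.einsum("ij,ji->i", K_su, np.linalg.solve(M, K_us))


def exact_gp_var(X, Xs, noise):
    K = se_kernel(X, X) + noise * np.eye(len(X))
    Ks = se_kernel(Xs, X)
    return np.diag(se_kernel(Xs, Xs)) - diag_quad(Ks, K, Ks.T)


def sparse_gp_var(X, Z, Xs, noise, jitter=1e-10):
    """Predictive variance of f under the optimal Gaussian q(u) for inducing inputs Z."""
    Kuu = se_kernel(Z, Z) + jitter * np.eye(len(Z))
    Kuf = se_kernel(Z, X)
    Ksu = se_kernel(Xs, Z)
    Sigma_inv = Kuu + Kuf @ Kuf.T / noise     # q(u) has covariance Kuu Sigma Kuu
    return (np.diag(se_kernel(Xs, Xs))
            - diag_quad(Ksu, Kuu, Ksu.T)         # prior variance explained by u
            + diag_quad(Ksu, Sigma_inv, Ksu.T))  # remaining posterior uncertainty in u


X = np.linspace(-3, 3, 8)
Xs = np.linspace(-4, 4, 50)
var_exact = exact_gp_var(X, Xs, noise=0.1)
var_full = sparse_gp_var(X, X, Xs, noise=0.1)          # Z = X
var_sparse = sparse_gp_var(X, X[::3], Xs, noise=0.1)   # only 3 inducing points
print(np.max(np.abs(var_full - var_exact)))
```

Comparing `var_sparse` with `var_exact` gives a quick feel for how far a small set of inducing points is from the full posterior.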
Figure 4.7 shows the predictive distribution of VaLLA (first column) and ELLA (second column) for \(M=5\), \(M=10\) and \(M=20\) inducing points/samples. The posterior distribution obtained by LLA is shown in dotted orange. The MAP solution is obtained with an MLP with two hidden layers of \(50\) units and tanh activations, trained to minimize the training RMSE for \(12\,000\) iterations with Adam and learning rate \(10^{-3}\). VaLLA, on the other hand, is trained for \(30\,000\) iterations. In this experiment, VaLLA and ELLA use the optimal prior variance and likelihood variance obtained by maximizing LLA’s marginal log-likelihood. The figure shows that, relative to LLA, VaLLA tends to overestimate the variance, whereas ELLA tends to underestimate it. Furthermore, VaLLA requires a smaller \(M\) than ELLA to get close to the LLA posterior, and as \(M\) increases, VaLLA’s predictive distribution becomes increasingly close to that of LLA.
The initial and final positions of the inducing locations are also shown in the figure; the initial values are computed using K-Means. VaLLA tunes the inducing locations, moving them from one cluster of points to another as needed, which is one of the main advantages of this method over ELLA.
Fixed-Mean Gaussian Processes
This section presents a novel family of GPs, Fixed-Mean Gaussian Processes (FMGPs), defined using the dual formulation of sparse variational GPs. First, consider a subset of the input space \(\cal Z \subset \cal X\). For any kernel function \(K\), the kernel section is defined as the closure of all linear combinations of kernel evaluations on \(\cal Z\); that is, \begin{equation} K(\mathcal{Z}) := \overline{\left\{\sum_{i=1}^n a_i K(\cdot, \mathbf{z}_i) :\, n \in \mathbb{N}, \ a_i \in \mathbb{R}, \ \mathbf{z}_i \in \mathcal{Z}\right\}}\,. \end{equation}
Definition 4.4. A kernel function is said to be universal if for any compact subset \(\mathcal{Z}\) of the input space \(\mathcal{X}\), the kernel section \(K(\mathcal{Z})\) is dense in \(C(\mathcal{Z})\) with the infinity norm. That is, for any \(f \in C(\mathcal{Z})\) and any \(\epsilon > 0\), there exists \(g_\epsilon \in K(\mathcal{Z})\) such that \(\|g_\epsilon - f\|_\infty \leq \epsilon\).
Given a universal kernel, and building on the dual representation of GPs (Section), the decoupled family of measures \(\cal{Q}^+\) is employed to show that the mean function of the corresponding GP can be fixed, up to an arbitrarily small error, to any continuous function on a compact subset.
Proposition 4.5. Let \(\mathcal{Z}\subset\mathcal{X}\) be compact and let \(K\colon\mathcal{X}\times\mathcal{X}\to\mathbb{R}\) be a universal kernel in the sense of Definition 4.4. Then, for every function \(f\in C(\mathcal{Z})\) and for every \(\varepsilon>0\) there exist \begin{equation} M_\alpha\in\mathbb{N},\qquad \bigl\{\mathbf{z}_1,\dots,\mathbf{z}_{M_\alpha}\bigr\}\subset\mathcal{Z}, \quad\text{and}\quad \bigl\{a_1,\dots,a_{M_\alpha}\bigr\}\subset\mathbb{R}, \end{equation} such that the finite kernel expansion \begin{equation} m(\mathbf{x})\;:=\;\bigl\langle\phi_{\mathbf{x}},\tilde{\mu}_{\alpha,\bm{a}}\bigr\rangle_{\mathcal{H}_K} \;=\; \sum_{m=1}^{M_\alpha}a_m\,K\!\bigl(\mathbf{x},\mathbf{z}_m\bigr), \qquad\mathbf{x}\in\mathcal{Z}, \end{equation} satisfies the uniform approximation bound \begin{equation} \bigl\lVert f - m\bigr\rVert_{\infty,\mathcal{Z}}\;\le\;\varepsilon. \end{equation}
Proof
Fix an arbitrary function \(f\in C(\mathcal{Z})\) together with \(\varepsilon>0\). Because \(K\) is universal, Definition 4.4 guarantees the existence of a function \(g_{\varepsilon/2}\in K(\mathcal{Z})\) such that \begin{equation} \bigl\lVert f - g_{\varepsilon/2}\bigr\rVert_{\infty,\mathcal{Z}} \;\le\;\frac{\varepsilon}{2}. \end{equation} By construction, the kernel section admits the representation \begin{equation} K(\mathcal{Z}) \;=\; \overline{\Bigl\{ \sum_{i=1}^{n} a_i\,K(\,\cdot\,,\mathbf{z}_i) : n\in\mathbb{N},\;a_i\in\mathbb{R},\; \mathbf{z}_i\in\mathcal{Z} \Bigr\}}^{\lVert\cdot\rVert_{\infty,\mathcal{Z}}}. \end{equation} Consequently, there exists a finite sum \begin{equation} g(\mathbf{x}) \;=\; \sum_{m=1}^{M_\alpha}a_m\,K \bigl(\mathbf{x},\mathbf{z}_m\bigr), \qquad \mathbf{x}\in\mathcal{Z}, \end{equation} such that \begin{equation} \bigl\lVert g_{\varepsilon/2}-g\bigr\rVert_{\infty,\mathcal{Z}} \;\le\;\frac{\varepsilon}{2}. \end{equation} Combining the two preceding inequalities yields \begin{equation} \bigl\lVert f-g\bigr\rVert_{\infty,\mathcal{Z}} \;\le\; \bigl\lVert f-g_{\varepsilon/2}\bigr\rVert_{\infty,\mathcal{Z}} +\bigl\lVert g_{\varepsilon/2}-g\bigr\rVert_{\infty,\mathcal{Z}} \;\le\;\frac{\varepsilon}{2}+\frac{\varepsilon}{2} \;=\;\varepsilon. \end{equation} □
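Proposition 4.5 can be illustrated numerically: a finite kernel expansion approximates a continuous function uniformly on a compact interval, and the error shrinks as the number of terms grows. The sketch below assumes a squared-exponential kernel (which is universal), equispaced centers, and a dense grid standing in for the supremum over \(\mathcal{Z}\); least squares is only a convenient way to choose the coefficients \(a_m\) and is not the construction used in the proof.

```python
import numpy as np


def se_kernel(x, z, lengthscale=0.5):
    return np.exp(-0.5 * (x[:, None] - z[None, :]) ** 2 / lengthscale ** 2)


def sup_error(f, M, lengthscale=0.5, lo=-3.0, hi=3.0):
    """Max |f - sum_m a_m K(., z_m)| on a dense grid, with equispaced centers z_m
    and coefficients a_m fitted by least squares on that grid."""
    grid = np.linspace(lo, hi, 2001)        # dense grid standing in for the compact set
    centers = np.linspace(lo, hi, M)
    Phi = se_kernel(grid, centers, lengthscale)
    a, *_ = np.linalg.lstsq(Phi, f(grid), rcond=None)
    return np.max(np.abs(f(grid) - Phi @ a))


def f(x):
    return np.sin(2 * x) + 0.3 * np.abs(x)  # continuous, not differentiable at 0


errors = {M: sup_error(f, M) for M in (5, 20, 80)}
print(errors)
```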
As a result, for any tolerance \(\epsilon > 0\), the posterior mean of a decoupled GP can be set to within \(\epsilon\) of any continuous function on any compact subset of the input space. In particular, the decoupled formulation of GPs can be used to fix the posterior mean to the output \(g(\cdot)\) of a given pre-trained DNN.
Definition 4.6. For any compact subset of the input space \(\mathcal{Z} \subset \mathcal{X}\), continuous function \(g \in C(\mathcal{Z})\), error \(\epsilon > 0\) and universal kernel \(K: \mathcal{X} \times \mathcal{X} \to \mathbb{R}\), the set of \(g\)-mean Gaussian measures \(\mathcal{Q}^{g}_{\mathcal{Z}, \epsilon} \subset \mathcal{Q}^+\) is defined as \begin{equation} \mathcal{Q}^{g}_{\mathcal{Z}, \epsilon} := \left\{\mathcal{N}(f| \tilde{\mu}_{\alpha, \bm{a}}, \tilde{\Sigma}_{\beta, \bm A}) \ : \ \bm{A} \in \mathbb{R}^{M_\beta \times M_\beta}\,, \bm{A} \succeq 0\,, \mathbf Z_\beta \in \mathcal{X}^{M_\beta} \right\} \,, \end{equation} where \(\tilde{\mu}_{\alpha, \bm{a}}\) satisfies \(\big\|g(\mathbf{x}) - \braket{\phi_{\mathbf{x}}, \tilde{\mu}_{\alpha, \bm{a}}}_\mathcal{H}\big\|_{\mathcal{Z}} \leq \epsilon\).
Thus, for any \(Q(f) \in \mathcal{Q}^{g}_{\mathcal{Z}, \epsilon}\), the corresponding GP \(f \sim \mathcal{GP}(m^{\star}, K^\star_{\mathbf{Z}_\beta, \bm{A}})\) satisfies \begin{equation} \big\|g(\mathbf x) - m^\star(\mathbf{x}) \big\|_{\mathcal{Z}} \leq \epsilon\,, \end{equation} \begin{equation} \label{eq:pred} K^{\star}_{\mathbf{Z}_\beta, \bm{A}}(\mathbf{x}, \mathbf{x}') = K(\mathbf{x}, \mathbf{x}') - K(\mathbf{x}, \mathbf{Z}_\beta)(\bm{A}^{-1} + K(\mathbf{Z}_\beta, \mathbf{Z}_\beta))^{-1}K(\mathbf{Z}_\beta, \mathbf{x}') \,. \end{equation} This set of GPs is referred to as fixed-mean Gaussian processes (FMGPs). By Proposition 4.5, for any \(g \in C(\mathcal{Z})\) the set \(\mathcal{Q}^{g}_{\mathcal{Z}, \epsilon}\) is non-empty, so the corresponding set of FMGPs exists. Again, VI can be used to find the optimal \(Q\) within this parametric family. That is, \begin{align} \argmin_{Q \in \mathcal{Q}^{g}_{\mathcal{Z}, \epsilon}} \ \mathrm{KL}\big(Q(f) \big| P(f|\mathbf{y})\big) = \argmax_{Q \in \mathcal{Q}^{g}_{\mathcal{Z}, \epsilon}} \ \mathbb{E}_{Q(f)}[\log P(\mathbf{y}|f)] - \mathrm{KL}\big(Q(f) \big| P(f)\big)\,, \label{eq:opt_kl2} \end{align} where, setting \(P(f)=\mathcal{N}(f|0, I)\) and \(\bm{K}_\beta = K(\mathbf{Z}_\beta, \mathbf{Z}_\beta)\), the KL term is \begin{equation} \mathrm{KL}\big(Q(f) | P(f)\big) = -\frac{1}{2} \log |\bm{I} - \bm{K}_\beta(\bm{A}^{-1} + \bm{K}_\beta)^{-1}| - \frac{1}{2} \text{tr}\left( \bm{K}_\beta(\bm{A}^{-1} + \bm{K}_\beta)^{-1} \right) + \text{constant}\,, \end{equation} where the constant does not depend on \(\bm{A}\) or \(\mathbf{Z}_\beta\). Only \(\bm{A}\) and \(\mathbf{Z}_\beta\) require optimization. The method that solves Equation \(\eqref{eq:opt_kl2}\) is referred to as the fixed-mean Gaussian process (FMGP).
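The covariance and KL expressions can be sanity-checked in an explicit finite-dimensional feature space, where \(K(\mathbf{x}, \mathbf{x}') = \phi(\mathbf{x})^T\phi(\mathbf{x}')\). This minimal NumPy sketch assumes the decoupled covariance operator is parameterized as \(\tilde{\Sigma} = (\bm{I} + \Phi_\beta \bm{A} \Phi_\beta^T)^{-1}\), with \(\Phi_\beta\) the features of \(\mathbf{Z}_\beta\): a parameterization under which the \((\bm{A}^{-1} + \bm{K}_\beta)^{-1}\) factor appearing in the KL term arises, and under which the covariance part of the KL equals \(\tfrac{1}{2}[\log|\bm{I} + \bm{K}_\beta\bm{A}| - \mathrm{tr}(\bm{K}_\beta(\bm{A}^{-1} + \bm{K}_\beta)^{-1})]\). The mean part of the KL is omitted because it is constant, and the function names are hypothetical.

```python
import numpy as np


def inv_a_plus_k(A, K_zz):
    """(A^{-1} + K_zz)^{-1}, computed as (I + A K_zz)^{-1} A so A need not be inverted."""
    return np.linalg.solve(np.eye(len(K_zz)) + A @ K_zz, A)


def fmgp_cov(K_xx, K_xz, K_zz, A):
    """K*(x, x') = K(x, x') - K(x, Z)(A^{-1} + K(Z, Z))^{-1} K(Z, x')."""
    return K_xx - K_xz @ inv_a_plus_k(A, K_zz) @ K_xz.T


def fmgp_kl_cov(K_zz, A):
    """Covariance part of KL(Q || N(0, I)); uses -log|I - K(A^{-1}+K)^{-1}| = log|I + K A|."""
    _, logdet = np.linalg.slogdet(np.eye(len(K_zz)) + K_zz @ A)
    return 0.5 * (logdet - np.trace(K_zz @ inv_a_plus_k(A, K_zz)))


# Explicit check in a D-dimensional feature space with the assumed covariance operator.
rng = np.random.default_rng(0)
D, M, n = 6, 3, 4
Phi_z = rng.normal(size=(D, M))           # features of the inducing inputs Z_beta
Phi_x = rng.normal(size=(D, n))           # features of some test inputs
L = rng.normal(size=(M, M))
A = L @ L.T                               # positive (semi-)definite A
K_zz, K_xz, K_xx = Phi_z.T @ Phi_z, Phi_x.T @ Phi_z, Phi_x.T @ Phi_x

Sigma = np.linalg.inv(np.eye(D) + Phi_z @ A @ Phi_z.T)
cov_direct = Phi_x.T @ Sigma @ Phi_x
kl_direct = 0.5 * (np.trace(Sigma) - D - np.linalg.slogdet(Sigma)[1])

print(np.allclose(fmgp_cov(K_xx, K_xz, K_zz, A), cov_direct),
      np.isclose(fmgp_kl_cov(K_zz, A), kl_direct))
```

Computing \((\bm{I} + \bm{A}\bm{K}_\beta)^{-1}\bm{A}\) rather than inverting \(\bm{A}\) keeps the expressions usable when \(\bm{A}\) is only positive semi-definite, and the cost remains cubic in \(M_\beta\).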
Figure 4.8 shows a set representation of the different families of Gaussian measures considered in this section. Observe that \(\mathcal{Q}^+\), in which there are different inducing points for the mean and the covariances, is the largest family. This family includes both \(\mathcal{Q}\), in which there is only a single set of inducing points for the mean and the covariances, and \(\mathcal{Q}^{g}_{\mathcal{Z}, \epsilon}\), in which the posterior mean is fixed to approximate \(g\) in \(\mathcal{Z}\) with error at most \(\epsilon\). Note that there could potentially be some overlap between the Gaussian measures in \(\mathcal{Q}\) and in \(\mathcal{Q}^{g}_{\mathcal{Z}, \epsilon}\).
Application to Post-hoc Bayesian Deep Learning
FMGP enables the conversion of DNNs into approximate Bayesian models while keeping the DNN output as the predictive mean. The procedure is straightforward and can be summarized as follows:
Given a pre-trained model \(g \in C(\mathcal{X})\), choose a parametric family of kernels that defines an RKHS and a family of Gaussian measures \(\mathcal{P}_\mathcal{H}\), e.g. squared exponential kernels.
Ensure that there exists a compact set \(\mathcal{Z} \subset \mathcal{X}\) where inputs are expected to lie. Then, by Proposition 4.5, the parametric family of Gaussian measures \(\mathcal{Q}^{g}_{\mathcal{Z}, \epsilon}\) is non-empty for any \(\epsilon > 0\).
Initialize a measure in this family, e.g. initialize \(\mathbf{Z}_\beta\) using K-Means (Lloyd, 1982) and \(\bm{A}\) as the identity matrix.
Perform VI to optimize the variational measure (\(\bm{A}\) and \(\mathbf{Z}_\beta\)), along with the kernel hyper-parameters \(\Omega\) and the noise variance \(\sigma^2\), using Equation \(\eqref{eq:opt_kl2}\). The predictive variance is computed as in Equation \(\eqref{eq:pred}\), and, if \(\mathcal{Z}\) is large and \(\epsilon\) small, the predictive mean \(m^\star\) approximates the pre-trained model \(g\). Thus, in practice, \(m^\star\) can be replaced by \(g\) in the computations.
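The steps above can be sketched for regression as follows, covering only initialization and prediction (the VI step is omitted). This is a minimal NumPy sketch under stated assumptions: a squared-exponential kernel with fixed hyper-parameters, plain Lloyd iterations as the K-Means initializer, \(\bm{A} = \bm{I}\), a toy function standing in for the pre-trained DNN, and the predictive covariance written with the \((\bm{A}^{-1} + \bm{K}_\beta)^{-1}\) factor. All names are hypothetical.

```python
import numpy as np


def se_kernel(X1, X2, amplitude=1.0, lengthscales=1.0):
    """Squared-exponential kernel with (optionally) one length scale per input feature."""
    diff = (X1[:, None, :] - X2[None, :, :]) / lengthscales
    return amplitude * np.exp(-0.5 * np.sum(diff ** 2, axis=-1))


def kmeans(X, M, iters=20, seed=0):
    """Plain Lloyd iterations, used only to initialise Z_beta."""
    rng = np.random.default_rng(seed)
    Z = X[rng.choice(len(X), size=M, replace=False)].copy()
    for _ in range(iters):
        labels = np.argmin(((X[:, None, :] - Z[None, :, :]) ** 2).sum(axis=-1), axis=1)
        for m in range(M):
            if np.any(labels == m):          # leave empty clusters where they are
                Z[m] = X[labels == m].mean(axis=0)
    return Z


def fmgp_predict(g, X_star, Z, A, noise_var, **kern):
    """Predictive mean = pre-trained model g; variance from the FMGP covariance."""
    K_zz = se_kernel(Z, Z, **kern)
    K_sz = se_kernel(X_star, Z, **kern)
    B = np.linalg.solve(np.eye(len(Z)) + A @ K_zz, A)   # (A^{-1} + K_zz)^{-1}
    k_ss = np.diag(se_kernel(X_star, X_star, **kern))
    var_f = k_ss - np.einsum("ij,jk,ik->i", K_sz, B, K_sz)
    return g(X_star), var_f + noise_var


def g(X):
    """Stand-in for a pre-trained DNN (hypothetical)."""
    return np.sin(X[:, 0])


rng = np.random.default_rng(0)
X_train = rng.uniform(-3, 3, size=(200, 1))
Z = kmeans(X_train, M=10)
A = np.eye(10)                               # identity initialisation
X_test = np.linspace(-5, 5, 11)[:, None]
mean, var = fmgp_predict(g, X_test, Z, A, noise_var=0.01)
print(mean, var)
```

Because the mean is returned directly from `g`, the accuracy of the pre-trained model is untouched; only the variance depends on \(\bm{A}\), \(\mathbf{Z}_\beta\) and the kernel.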
Regularization and Loss Function
In standard sparse GPs, tuning the hyper-parameters entails balancing the fit of the mean to the training data against reducing the model’s predictive variance. In FMGPs, however, the predictive mean is fixed, which eliminates this trade-off: the kernel hyper-parameters \(\Omega\) adjust only the predictive variance, without affecting the mean. Maximizing the ELBO in Equation \(\eqref{eq:opt_kl2}\) with respect to \(\Omega\) can therefore yield undesirable solutions in which the predictive variance collapses to zero. To counteract this, a regularization technique based on an additional variational Gaussian measure is introduced. Specifically, consider an auxiliary Gaussian measure \(Q^\star \in \mathcal{Q}\) that shares \(Q\)’s parameters (\(\bm{A}\), \(\mathbf{Z}_\beta\), and the kernel hyper-parameters) and additionally uses \(\bm{a} \in \mathbb{R}^{M_\beta}\) and \(\mathbf{Z} = \mathbf{Z}_\beta\) to define its predictive mean. This yields the following loss function: \begin{equation} \mathcal{L}(\Gamma) = \underbrace{\mathbb{E}_{Q(f)}[\log P(\mathbf{y}|f)] - \mathrm{KL}\big(Q\big| P\big)}_{\text{ELBO}(Q)} + \underbrace{\mathbb{E}_{Q^\star(f)}[\log P(\mathbf{y}|f)] - \mathrm{KL}\big(Q^\star \big| P\big)}_{\text{ELBO}(Q^\star)}\,, \end{equation} with \(\Gamma =\{\bm{a}, \bm{A}, \mathbf{Z}_\beta, \Omega, \sigma^2\}\). This loss function implies that the predictive variances must account for the training data under two predictive means: the pre-trained one, \(g(\cdot)\), and the one defined by \(\bm{a}\). Additionally, \(\Omega\) cannot be adjusted solely to fit the variances, as it also affects the predictive mean of \(Q^\star\).
Moreover, the use of \(\alpha\)-divergences (see Section for more details) for approximate inference has been widely explored (Bui et al., 2017; Hernandez-Lobato et al., 2016; Rodrı́guez-Santana & Hernández-Lobato, 2022; Villacampa-Calvo & Hernández-Lobato, 2020), with findings indicating that values of \(\alpha \approx 0\) improve the estimation of the predictive mean, whereas \(\alpha \approx 1\) improves the predictive distribution, as reflected in higher test log-likelihood. Consequently, instead of minimizing \(\mathrm{KL}(Q(f)|P(f|\mathbf{y}))\), the objective is modified, following a generalized view of VI (Knoblauch et al., 2022), to approximately minimize the \(\alpha\)-divergence between \(P(f|\mathbf{y})\) and \(Q(f)\) for \(\alpha \approx 1.0\) (Li & Gal, 2017). This is achieved by changing the data-dependent term of the loss (Section): \begin{align} \label{eq:loss} \mathcal{L}(\Gamma) = \log \mathbb{E}_{Q(f)}[P(\mathbf{y}|f)] - \mathrm{KL}\big(Q\big|P\big) + \log \mathbb{E}_{Q^\star(f)}[P(\mathbf{y}|f)] - \mathrm{KL}\big(Q^\star \big| P\big)\,, \end{align} where the expectation is now inside the logarithm.
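For a Gaussian likelihood and a Gaussian marginal \(f \sim \mathcal{N}(\mu, s^2)\) at a data point, both data terms have closed forms. The sketch below (function names hypothetical) contrasts \(\mathbb{E}[\log P(y|f)]\), which decreases linearly in \(s^2\), with \(\log \mathbb{E}[P(y|f)] = \log \mathcal{N}(y|\mu, s^2 + \sigma^2)\), which treats \(s^2\) as extra noise, and checks the latter by Monte Carlo.

```python
import numpy as np


def expected_log_lik(y, mu, s2, noise_var):
    """E_{f ~ N(mu, s2)}[log N(y | f, noise_var)]: the standard VI (KL) data term."""
    return -0.5 * np.log(2 * np.pi * noise_var) - ((y - mu) ** 2 + s2) / (2 * noise_var)


def log_expected_lik(y, mu, s2, noise_var):
    """log E_{f ~ N(mu, s2)}[N(y | f, noise_var)] = log N(y | mu, s2 + noise_var)."""
    v = s2 + noise_var
    return -0.5 * np.log(2 * np.pi * v) - (y - mu) ** 2 / (2 * v)


y, mu, s2, noise_var = 0.3, 0.0, 0.5, 0.1
# Monte Carlo estimate of the term with the expectation inside the logarithm.
rng = np.random.default_rng(0)
f = rng.normal(mu, np.sqrt(s2), size=200_000)
lik = np.exp(-0.5 * np.log(2 * np.pi * noise_var) - (y - f) ** 2 / (2 * noise_var))
mc = np.log(np.mean(lik))
print(expected_log_lik(y, mu, s2, noise_var), log_expected_lik(y, mu, s2, noise_var), mc)
```

By Jensen's inequality, the \(\log\mathbb{E}\) term is never smaller than the \(\mathbb{E}\log\) term.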
The objective in Equation \(\eqref{eq:loss}\) supports mini-batch optimization with a cost in \(\mathcal{O}(M_\beta^3 + |\mathcal{B}|M_\beta^2)\): \begin{equation} \begin{aligned} \mathcal{L}(\Gamma) &\approx \frac{N}{|\mathcal{B}|} \sum_{b \in \mathcal{B}} \log \mathbb{E}_{Q(f)} \left[P(y_b|f)\right] -\mathrm{KL}\left(Q|P\right) \\ &\quad+\frac{N}{|\mathcal{B}|} \sum_{b \in \mathcal{B}} \log \mathbb{E}_{Q^\star(f)} \left[P(y_b|f)\right] -\mathrm{KL}\left(Q^\star|P\right)\,, \end{aligned} \end{equation} where \(\mathcal{B}\) is a mini-batch of points. The expectation can be computed in closed form in regression. In classification, an approximation is available via the softmax method in (Daxberger et al., 2021b).
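For regression, the mini-batch estimator can be sketched as below. It assumes, following the description of \(Q^\star\), that \(Q\) and \(Q^\star\) share their predictive variances and differ only in the mean (the DNN output \(g\) versus the mean defined by \(\bm{a}\)); the KL terms are taken as precomputed inputs, and all names are hypothetical.

```python
import numpy as np


def log_expected_gauss(y, mu, s2, noise_var):
    """log E_{f ~ N(mu, s2)}[N(y | f, noise_var)] = log N(y | mu, s2 + noise_var)."""
    v = s2 + noise_var
    return -0.5 * np.log(2 * np.pi * v) - (y - mu) ** 2 / (2 * v)


def minibatch_loss(y_b, g_b, m_star_b, s2_b, noise_var, N, kl_q, kl_q_star):
    """Mini-batch estimate of the objective to be maximized (regression).

    g_b:             outputs g(x_b) of the pre-trained DNN, the fixed mean of Q
    m_star_b:        predictive mean of Q* at x_b (defined by the coefficients a)
    s2_b:            predictive variances K*(x_b, x_b), assumed shared by Q and Q*
    kl_q, kl_q_star: KL terms, assumed to be computed elsewhere
    """
    scale = N / len(y_b)
    data_q = scale * np.sum(log_expected_gauss(y_b, g_b, s2_b, noise_var))
    data_q_star = scale * np.sum(log_expected_gauss(y_b, m_star_b, s2_b, noise_var))
    return data_q - kl_q + data_q_star - kl_q_star


# Averaging the estimate over a partition of the data into equal-size batches
# recovers the full-data objective exactly.
rng = np.random.default_rng(0)
N = 6
y, g_x, m_star_x = rng.normal(size=N), rng.normal(size=N), rng.normal(size=N)
s2 = rng.uniform(0.1, 1.0, size=N)
full = minibatch_loss(y, g_x, m_star_x, s2, 0.1, N, kl_q=0.5, kl_q_star=0.7)
batches = [slice(i, i + 2) for i in range(0, N, 2)]
estimates = [minibatch_loss(y[b], g_x[b], m_star_x[b], s2[b], 0.1, N, 0.5, 0.7)
             for b in batches]
print(full, np.mean(estimates))
```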
Limitations
FMGP is limited by three factors:
Computing the predictive distribution at each training iteration involves inverting \(\bm{A}^{-1} + \bm{K}_{\mathbf{Z}_\beta}\), with cubic cost in the number of inducing points \(M_{\beta}\). Therefore, FMGP cannot accommodate a very large number of inducing points. However, as shown in the experiments, this number can be set to a very low value, such as \(20\), even for classification tasks with a thousand classes.
FMGP requires additional optimization steps compared to other post-hoc approximations, e.g., (Deng et al., 2022). However, these other methods often rely on visiting every training point to compute specific updates. As a result, FMGP training can be faster on large datasets featuring millions of instances, such as ImageNet.
The construction of FMGP requires choosing a (parametric) kernel. This is both an advantage, as it may better capture the underlying data patterns in the modeling process, and a disadvantage, as it may be difficult to efficiently use an effective kernel in some tasks, such as image classification.
Related Work
Due to its post-hoc nature, FMGP is closely related to the Linearized Laplace Approximation (LLA) for deep learning. The Laplace Approximation (LA) was introduced for small DNNs in (MacKay, 1992c). LA approximates the DNN posterior in parameter space with a Gaussian centered at the posterior mode, whose precision is given by the Hessian of the negative log-posterior at that mode. LA can be made more scalable by using a Generalized Gauss-Newton (GGN) approximation of the Hessian at the mode, which is equivalent to linearizing the DNN. When this linearization is also used at prediction time, LA becomes a post-hoc method known as LLA (Ritter et al., 2018). Moreover, LLA addresses the underfitting issues associated with LA (Lawrence, 2001). Despite this, the GGN approximate Hessian of LLA is still intractable in real problems and has to be further approximated using, e.g., Kronecker-factored (KFAC) approximations. Recently, many approaches have been developed to make LLA more scalable and accurate, aiming to match LLA with the GGN approximate Hessian, including Nyström approximations (Deng et al., 2022), variational approaches (Ortega et al., 2024; Scannell et al., 2024) and sample-based approximations (Antorán et al., 2023).
Among LLA methods, FMGP is most closely related to Variational LLA (VaLLA) (Ortega et al., 2024). VaLLA interprets the intractable LLA approximation as a GP (Immer et al., 2021; Khan et al., 2019), which is possible due to the DNN linearization, and then uses a sparse variational GP to approximate the posterior variances. If the Neural Tangent Kernel (NTK) is considered (whose feature map is simply the DNN Jacobian, i.e., \(\phi_\mathbf{x} = \mathrm{d}g(\mathbf{x})/\mathrm{d}\bm{\theta}\)), VaLLA can be recovered as a specific case of the FMGP formulation under certain hypotheses. The key differences between FMGP and VaLLA are: (i) FMGP does not rely on the NTK, which allows it to use kernels better suited to specific tasks. Moreover, the NTK requires computing the Jacobian of the neural network at each iteration, which drastically increases prediction costs and makes VaLLA impractical for large problems (e.g., ImageNet). As the experiments will show, the flexibility of the FMGP kernel yields both better predictive distributions and more efficiently computed predictions. (ii) FMGP does not rely on an LLA approximation; therefore, LLA with the GGN Hessian approximation need not be optimal under this framework. (iii) The NTK is not guaranteed to be universal, and VaLLA relies on the hypothesis that the DNN output \(g(\cdot)\) lies in \(\mathcal{H}\), which may not be the case. (iv) VaLLA does not include a regularization technique to avoid overfitting, and therefore requires early stopping and a validation set. In short, VaLLA is constructed by performing a variational approximation to the GP resulting from LLA, whereas FMGP uses VI to find the optimal measure within a set of Gaussian measures with a fixed mean.
In (Deng et al., 2022), the authors propose a Nyström approximation of the GGN Hessian approximation of LLA using \(M\ll N\) points chosen at random from the training set. The method, called ELLA, has cost \(\mathcal{O}(NM^3)\). ELLA also requires computing the costly Jacobian vectors needed by VaLLA, but not their gradients. Unlike VaLLA, the Nyström approximation needs to visit every instance in the training set. Moreover, as stated in (Deng et al., 2022), ELLA suffers from overfitting. Again, an early-stopping strategy using a validation set is proposed to alleviate it; in this case, ELLA only considers a subset of the training data. Unlike VaLLA, ELLA does not allow for hyper-parameter optimization: the prior variance \(\sigma_0^2\) must be tuned using grid search and a validation set, which significantly increases the required training time.
Samples from LLA’s GP posterior can be computed efficiently using stochastic optimization, without inverting the kernel matrix (Antorán et al., 2023; Lin et al., 2024). This approach avoids LLA’s \(\mathcal{O}(N^3)\) cost. However, it does not provide an estimate of the log-marginal likelihood for hyper-parameter optimization. Thus, (Antorán et al., 2023) propose the EM algorithm for this task, iteratively generating samples (E-step) and then optimizing the hyper-parameters (M-step). This significantly increases the training cost, as generating a single sample is as expensive as fitting the original DNN on the full data. Furthermore, only classification problems are considered in (Antorán et al., 2023), and there is empirical evidence that VaLLA is faster and gives better results.
Another GP-based approach for obtaining predictive uncertainty in DNNs is the Spectral-normalized Neural Gaussian Process (SNGP) (Liu et al., 2023), which replaces the last layer of the DNN with a GP. SNGP allows either (i) fine-tuning a pre-trained DNN or (ii) training a full DNN from scratch. The comparison in the experiments focuses on the former. In practice, replacing the last layer with a GP does not preserve the predictive mean of the pre-trained DNN and often leads to a drop in predictive performance; similar behavior was also reported by Liu et al. (2023).
Another simple option to transform a pre-trained DNN into a Bayesian model is to consider a mean-field VI approximation of the DNN posterior in which the means are initialized to the pre-trained weights and kept fixed. This is known as mean-field VI fine-tuning (Deng & Zhu, 2023) and, as demonstrated in the experiments, it can achieve good results in terms of both prediction performance and uncertainty estimation. However, this method requires learning the variance of every weight, which can be very costly and may require several training epochs. Furthermore, it provides no closed-form predictive distribution and relies on Monte Carlo samples to make predictions. As a result, further approximations, such as the Flipout trick (Wen et al., 2018), must be considered to reduce the training time. Even though these techniques successfully reduce the training time, the required Monte Carlo samples significantly increase prediction time.
Experiments
The proposed method, FMGP, is compared with other methods including: last-layer LLA with and without KFAC approximation, ELLA (Deng et al., 2022), VaLLA (Ortega et al., 2024), a mean-field VI fine-tuning approach (Deng & Zhu, 2023) and SNGP (Liu et al., 2023). FMGP and VaLLA use \(100\) inducing points, as in (Ortega et al., 2024). ELLA employs \(2\,000\) random points and \(20\) random features as in (Deng & Zhu, 2023). All the timed experiments are executed on an Nvidia A100 graphics card. Finally, an implementation of FMGP is publicly available at https://github.com/Ludvins/FixedMeanGaussianProcesses.
Synthetic Experiment
Figure 4.9 panels: MAP, LLA, FMGP, MFVI, GP, HMC.
The experiment in Figure 4.9 illustrates the predictive distributions of commonly used Bayesian approaches on the synthetic 1-dimensional dataset from (Izmailov et al., 2020). It compares the predictive distribution of FMGPs against other methods, including the pre-trained DNN with optimized Gaussian noise (MAP), the linearized Laplace approximation (LLA) with prior precision and Gaussian noise optimized to maximize the marginal likelihood estimate, mean-field VI (MFVI) fine-tuning of the pre-trained model with Gaussian noise optimized on the training data, a GP with a squared exponential kernel and hyper-parameters that maximize the marginal likelihood, and Hamiltonian Monte Carlo (HMC) using a uniform prior for the variance of both the Gaussian noise and the Gaussian prior over the DNN’s weights \(\bm{\theta}\).
In this simple problem, HMC’s predictions serve as the gold standard for assessing the predictive variances of other methods. Note, however, that HMC does not scale to large problems. Figure 4.9 shows that MAP and MFVI tend to underestimate the predictive variance, while LLA tends to overestimate it by interpolating between data clusters. On the other hand, FMGP and the GP produce predictive variances comparable to those of HMC, with the GP yielding slightly larger variances.
However, the GP’s predictive mean does not align with the DNN output (given by the predictive mean of the MAP output) and suffers from a prior mean reversion problem, where the GP mean reverts to the prior mean between the second and third point clusters, which is expected to worsen the resulting predictive performance. Moreover, the GP does not scale well to large problems. By contrast, FMGP not only produces predictive variances similar to those of HMC but also retains the predictive mean equal to the DNN’s output, which is expected to result in improved prediction accuracy.
Regression Problems
Three large regression datasets are considered in the experimental evaluation:
The Year dataset (Bertin-Mahieux, 2011), with \(515\,345\) instances and \(90\) features. The first \(400\,000\) instances are used for training, the following \(63\,715\) for validation, and the remaining \(51\,630\) for testing.
The US flight delay (Airline) dataset (Dutordoir et al., 2020). Following (Ortega et al., 2024), use the first \(600\,000\) instances for training, the following \(100\,000\) instances for validation and the next \(100\,000\) for testing. Here, \(8\) features are considered: month, day of the month, day of the week, plane age, air time, distance, arrival time and departure time.
The Taxi dataset, with data recorded on January, 2023 (Ortega et al., 2023). For this dataset, \(9\) attributes are considered: time of day, day of week, day of month, month, PULocationID, DOLocationID, distance and duration; while the predictive variable is the price. Following (Ortega et al., 2024), filter trips shorter than 10 seconds and larger than 5 hours, resulting in \(3\,050\,311\) million instances. The first \(80\%\) is used as train data, the next \(10\%\) as validation data, and the last \(10\%\) as testing data.
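The splits described above are ordered rather than shuffled: each subset is a contiguous block of rows. The following sketch reproduces the split arithmetic. It assumes that the Taxi duration filter keeps trips with 10 s ≤ duration ≤ 5 h (inclusive bounds) and that the 80%/10% split sizes are truncated to integers. `ordered_split` and `keep_trip` are hypothetical helper names.

```python
import numpy as np


def ordered_split(n_total, n_train, n_val):
    """Split indices 0..n_total-1 in their original order (no shuffling)."""
    idx = np.arange(n_total)
    return idx[:n_train], idx[n_train:n_train + n_val], idx[n_train + n_val:]


def keep_trip(duration_s, min_s=10.0, max_s=5 * 3600.0):
    """Mask of taxi trips kept after removing those shorter than 10 s or longer than 5 h.

    Inclusive bounds are an assumption of this sketch.
    """
    return (duration_s >= min_s) & (duration_s <= max_s)


# Year: first 400 000 for training, next 63 715 for validation, the rest for testing.
year_train, year_val, year_test = ordered_split(515_345, 400_000, 63_715)

# Airline: restricted to the first 800 000 rows, split 600 000 / 100 000 / 100 000.
air_train, air_val, air_test = ordered_split(800_000, 600_000, 100_000)

# Taxi (after filtering): first 80% train, next 10% validation, last 10% test.
n_taxi = 3_050_311
n_tr, n_va = int(0.8 * n_taxi), int(0.1 * n_taxi)  # truncation is an assumption
taxi_train, taxi_val, taxi_test = ordered_split(n_taxi, n_tr, n_va)

print(len(year_test), len(taxi_train), len(taxi_val), len(taxi_test))
```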
In all experiments, following (Ortega et al., 2024), a pre-trained 3-layer DNN with 200 units per layer and tanh activations is employed. ELLA is trained without early stopping, as overfitting is not observed in these regression problems. Hyper-parameters are chosen by grid search on the validation set. FMGP employs the squared-exponential kernel, whose hyper-parameters are the kernel amplitude and one length scale per input feature. MAP results are obtained by learning the optimal Gaussian noise on the validation set. A last-layer Kronecker-factored approximation is used for LLA.
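A squared-exponential kernel with an amplitude and one length scale per input feature is often called an ARD (automatic relevance determination) kernel. A minimal NumPy sketch follows. It assumes a log parameterization of both hyper-parameters to keep them positive during optimization, which is a common choice but not necessarily the one used for FMGP. The inputs are random placeholders.

```python
import numpy as np


def ard_se_kernel(X1, X2, log_amplitude, log_length_scales):
    """ARD squared-exponential kernel.

    k(x, x') = a * exp(-0.5 * sum_d ((x_d - x'_d) / l_d) ** 2), with a = exp(log_amplitude)
    and l_d = exp(log_length_scales[d]); the log parameterization keeps both positive.
    """
    amplitude = np.exp(log_amplitude)
    length_scales = np.exp(log_length_scales)
    Z1, Z2 = X1 / length_scales, X2 / length_scales
    sq_dist = (np.sum(Z1**2, axis=1)[:, None] + np.sum(Z2**2, axis=1)[None, :]
               - 2.0 * Z1 @ Z2.T)
    # Clip tiny negative values caused by floating-point round-off.
    return amplitude * np.exp(-0.5 * np.maximum(sq_dist, 0.0))


# Example with 8 input features, as in the Airline dataset (random inputs).
rng = np.random.default_rng(0)
X = rng.normal(size=(5, 8))
K = ard_se_kernel(X, X, log_amplitude=0.0, log_length_scales=np.zeros(8))
print(K.shape)  # (5, 5)
```

A large learned length scale for a feature makes the kernel nearly insensitive to it, so the per-feature length scales also act as a soft feature-relevance measure.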
Figure 4.10 shows average results for each method over 5 random seeds. The quality of the predictive distribution is measured in terms of the negative log-likelihood (NLL), the continuous ranked probability score (CRPS) (Gneiting & Raftery, 2007) and a centered quantile metric (CQM) (Ortega et al., 2024). Intuitively, CRPS can be understood as a generalization of the mean absolute error to predictive distributions. As discussed in Section 4.2.5, CQM measures the discrepancy between the model’s quantiles and the data quantiles under a common predictive mean, a condition satisfied here because all compared methods share the same predictive mean. CQM can be seen as a regression analogue of the expected calibration error. It is defined as: \begin{align} \text{CQM} & = \int_0^1 \Big|\mathbb{P}_{(\mathbf{x}^\star, y^\star)}\left[ y^\star \in I(\mathbf{x}^\star, \alpha) \right] - \alpha\Big| \, d\alpha\,, \end{align} where \(I(\mathbf{x}, \alpha)=(\lambda(-\alpha), \lambda(\alpha))\), \(\lambda(\alpha) = \Phi_{\mu(\mathbf{x}),\sigma^2(\mathbf{x})}^{-1}(\tfrac{1 + \alpha}{2})\), and \(\Phi_{\mu(\mathbf{x}),\sigma^2(\mathbf{x})}\) is the CDF of a Gaussian with mean \(\mu(\mathbf{x})\) and variance \(\sigma^2(\mathbf{x})\) specified by each model’s predictive distribution. Hence \(I(\mathbf{x}, \alpha)\) is the central interval containing probability mass \(\alpha\) under the model, and CQM averages the gap between its empirical and nominal coverage over \(\alpha \in [0, 1]\).
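All three regression metrics can be computed directly from a Gaussian predictive distribution \(\mathcal{N}(\mu(\mathbf{x}), \sigma^2(\mathbf{x}))\). The following sketch is one possible implementation, not the evaluation code behind Figure 4.10. It uses the standard closed-form CRPS of a Gaussian and approximates the CQM integral with a midpoint Riemann sum over \(\alpha\); the grid size `n_alpha` is an arbitrary choice. Because \(I(\mathbf{x}, \alpha)\) is symmetric around \(\mu(\mathbf{x})\), a target lies inside it exactly when its standardized residual is smaller in absolute value than \(\Phi^{-1}((1+\alpha)/2)\) of a standard normal.

```python
import math
from statistics import NormalDist

import numpy as np

_STD_NORMAL = NormalDist()  # standard normal, used for Phi^{-1}


def gaussian_nll(mu, sigma, y):
    """Average negative log-likelihood of y under N(mu, sigma^2)."""
    return np.mean(0.5 * np.log(2 * np.pi * sigma**2) + 0.5 * ((y - mu) / sigma) ** 2)


def gaussian_crps(mu, sigma, y):
    """Average closed-form CRPS of N(mu, sigma^2) evaluated at y."""
    z = (y - mu) / sigma
    cdf = 0.5 * (1.0 + np.vectorize(math.erf)(z / math.sqrt(2.0)))
    pdf = np.exp(-0.5 * z**2) / math.sqrt(2.0 * math.pi)
    return np.mean(sigma * (z * (2.0 * cdf - 1.0) + 2.0 * pdf - 1.0 / math.sqrt(math.pi)))


def cqm(mu, sigma, y, n_alpha=200):
    """Midpoint Riemann-sum approximation of the CQM integral over alpha in (0, 1).

    y lies in the central interval I(x, alpha) iff |y - mu| / sigma < z_alpha,
    where z_alpha = Phi^{-1}((1 + alpha) / 2) for the standard normal CDF Phi.
    """
    abs_z = np.sort(np.abs((y - mu) / sigma))
    alphas = (np.arange(n_alpha) + 0.5) / n_alpha  # midpoints avoid alpha = 1
    z_alpha = np.array([_STD_NORMAL.inv_cdf((1.0 + a) / 2.0) for a in alphas])
    # Empirical coverage: fraction of standardized residuals strictly below z_alpha.
    coverage = np.searchsorted(abs_z, z_alpha, side="left") / len(abs_z)
    return np.mean(np.abs(coverage - alphas))


rng = np.random.default_rng(0)
y = rng.normal(size=100_000)                  # synthetic targets drawn from N(0, 1)
mu = np.zeros_like(y)                         # shared predictive mean
print(cqm(mu, np.ones_like(y), y))            # well calibrated: close to 0
print(cqm(mu, 0.5 * np.ones_like(y), y))      # overconfident: much larger
```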
Figure 4.10 shows that FMGP performs best on all three metrics (lower is better), with the largest margin in CQM. These results suggest that FMGP provides better uncertainty estimates (in terms of NLL) and better calibration (in terms of CRPS and CQM) than state-of-the-art LLA variants in regression settings.
CIFAR10 Dataset and ResNet Architectures
Several experiments are performed with various ResNet architectures (He et al., 2016) on the CIFAR10 dataset (Krizhevsky et al., 2009). To facilitate reproducibility, the pre-trained models used are publicly available at https://github.com/chenyaofo/pytorch-cifar-models. The considered models are ResNet20 (\(272\,474\) parameters), ResNet32 (\(466\,906\) parameters), ResNet44 (\(661\,338\) parameters) and ResNet56 (\(855\,770\) parameters). Following (Deng et al., 2022) and (Ortega et al., 2024), ELLA and VaLLA use as validation set a data-augmented subset of \(5\,000\) points from the training set. This validation set is obtained by taking random crops of the training images with (relative) sizes in \([0.5, 1]\).
In multi-class classification problems, the FMGP kernel should model dependencies among the different DNN outputs, one per class label. Therefore, the following simple kernel is employed in this setting: \begin{equation} K((\mathbf{x},c),(\mathbf{x}',c')) = B_{c,c'} \times K_{\mathrm{RBF}}(\mathbf{x}, \mathbf{x}') \times \big(\psi(\mathbf{x})^T \psi(\mathbf{x}') + \delta_{\mathbf{x}=\mathbf{x}'}\big)\,, \end{equation} which combines a positive semi-definite matrix \(\bm{B}\in \mathbb{R}^{C \times C}\) that models output dependencies, a squared exponential kernel in the input space, and a linear kernel plus noise on the high-level features \(\psi(\cdot)\), which correspond to the output of the pre-trained model up to the second-to-last layer. The trainable hyper-parameters are the amplitude and the length scales (one per input feature) of the squared exponential kernel, along with the matrix \(\bm{B}\), parameterized through its Cholesky factor. This simple kernel gives good results in the conducted experiments, although more sophisticated kernels are possible and could lead to even better results. The inducing points are randomly assigned to class labels.
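The multi-class kernel above is a product of four positive semi-definite pieces (a class kernel given by \(\bm{B}\), the RBF kernel, a linear kernel on \(\psi\), and a noise indicator), so its Gram matrix is again positive semi-definite by the Schur product theorem. The following NumPy sketch evaluates it under stated assumptions. Inputs are flattened vectors, \(\psi(\mathbf{x})\) is replaced by random placeholder features, \(\bm{B}\) is built from a lower-triangular factor with an exponentiated diagonal, and \(\delta_{\mathbf{x}=\mathbf{x}'}\) is implemented as exact equality of inputs. All function names and sizes are hypothetical.

```python
import numpy as np


def rbf_kernel(X1, X2, amplitude, length_scales):
    """Squared-exponential kernel with one length scale per input feature."""
    Z1, Z2 = X1 / length_scales, X2 / length_scales
    sq = np.sum(Z1**2, axis=1)[:, None] + np.sum(Z2**2, axis=1)[None, :] - 2.0 * Z1 @ Z2.T
    return amplitude * np.exp(-0.5 * np.maximum(sq, 0.0))


def cholesky_factor(raw, log_diag):
    """Lower-triangular factor with positive diagonal, so B = L @ L.T is p.s.d."""
    return np.tril(raw, k=-1) + np.diag(np.exp(log_diag))


def multiclass_kernel(X1, c1, F1, X2, c2, F2, amplitude, length_scales, L_B):
    """K((x, c), (x', c')) = B[c, c'] * k_RBF(x, x') * (psi(x)^T psi(x') + delta_{x = x'}).

    X*: flattened inputs, c*: integer class labels, F*: high-level features psi(x).
    The pairwise equality test uses O(n1 * n2 * d) memory, which is fine for a sketch.
    """
    B = L_B @ L_B.T
    same_input = np.all(X1[:, None, :] == X2[None, :, :], axis=-1).astype(float)
    k_features = F1 @ F2.T + same_input
    return B[np.ix_(c1, c2)] * rbf_kernel(X1, X2, amplitude, length_scales) * k_features


# Hypothetical sizes: C = 10 classes, 6 points, 12 input features, 4 features in psi(x).
rng = np.random.default_rng(0)
C, n, d, m = 10, 6, 12, 4
X = rng.normal(size=(n, d))
F = rng.normal(size=(n, m))           # placeholder for psi(x)
labels = rng.integers(0, C, size=n)   # e.g. random class labels for inducing points
L_B = cholesky_factor(rng.normal(size=(C, C)), np.zeros(C))
K = multiclass_kernel(X, labels, F, X, labels, F, 1.0, np.ones(d), L_B)
print(K.shape)  # (6, 6)
```

Setting `L_B` to the identity makes outputs for different classes independent, while off-diagonal entries of \(\bm{B}\) let the model share information across class outputs.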
Figure 4.11 shows the negative log-likelihood (NLL), expected calibration error (ECE), and Brier score of each method. It also reports the out-of-distribution (OOD) AUC of each method on a binary classification problem that uses the SVHN dataset as OOD data (Netzer et al., 2011). For every method, the predictive entropy is used as the score for discriminating between in-distribution and out-of-distribution inputs. The training and evaluation times of each method are also reported. Recent work (Mucsányi et al., 2024) shows that different uncertainty quantification metrics tend to cluster, which highlights the importance of evaluating predictive uncertainty with as many metrics as possible. Accuracy is not shown, as most methods barely change the accuracy of the pre-trained DNN. Nevertheless, it is worth mentioning that SNGP tends to lower the accuracy of the model, as shown in (Liu et al., 2023), while MFVI tends to increase it slightly, as noted in (Deng et al., 2022) and (Ortega et al., 2024).
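The following NumPy sketch shows how the classification metrics and the entropy-based OOD AUC above are commonly computed from a matrix of predicted class probabilities. Several conventions are assumptions and may differ from the evaluation code behind Figure 4.11: 15 equal-width confidence bins for ECE, a Brier score summed over classes (not divided by the number of classes), and OOD inputs treated as the positive class, with ties counted as one half.

```python
import numpy as np


def nll(probs, labels):
    """Average negative log-likelihood of the true labels."""
    return -np.mean(np.log(np.clip(probs[np.arange(len(labels)), labels], 1e-12, 1.0)))


def brier_score(probs, labels):
    """Average over examples of the squared error between probs and one-hot labels."""
    one_hot = np.eye(probs.shape[1])[labels]
    return np.mean(np.sum((probs - one_hot) ** 2, axis=1))


def expected_calibration_error(probs, labels, n_bins=15):
    """Equal-width bins of top-class confidence; bin-size-weighted |accuracy - confidence|."""
    conf = probs.max(axis=1)
    correct = (probs.argmax(axis=1) == labels).astype(float)
    bins = np.minimum((conf * n_bins).astype(int), n_bins - 1)  # conf = 1 goes in last bin
    ece = 0.0
    for b in range(n_bins):
        in_bin = bins == b
        if in_bin.any():
            ece += in_bin.mean() * abs(correct[in_bin].mean() - conf[in_bin].mean())
    return ece


def predictive_entropy(probs):
    """Entropy of each predictive distribution (0 * log 0 treated as 0)."""
    return -np.sum(probs * np.log(np.clip(probs, 1e-12, 1.0)), axis=1)


def ood_auc(entropy_in, entropy_out):
    """AUC of entropy as an OOD score: P(H_out > H_in) + 0.5 * P(H_out == H_in).

    Uses an O(n_in * n_out) pairwise comparison, which is fine for a sketch.
    """
    diff = entropy_out[:, None] - entropy_in[None, :]
    return np.mean(diff > 0) + 0.5 * np.mean(diff == 0)


probs = np.full((8, 4), 0.25)       # uniform predictions over 4 classes
labels = np.zeros(8, dtype=int)     # argmax of a uniform row is class 0
print(expected_calibration_error(probs, labels), brier_score(probs, labels))
```

For large OOD test sets, a rank-based AUC (for example, via sorting) avoids the quadratic memory of the pairwise comparison.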
Figure 4.11 shows that FMGP, MFVI, VaLLA, and ELLA achieve the best NLL and Brier scores (lower is better). In terms of ECE (also lower is better), SNGP, VaLLA, and FMGP provide the best-calibrated uncertainties. Overall, FMGP and VaLLA appear to provide the best uncertainty quantification, with well-calibrated predictive distributions. For out-of-distribution detection, however, the best AUC is obtained by MFVI, ELLA, VaLLA, and SNGP. Figure 4.12 shows histograms of the predictive entropy of each method on in-distribution and out-of-distribution test data. The poor OOD detection results of FMGP may be due to the kernel choice, and more sophisticated kernels may improve its performance in this setting as well.
Regarding training time, Figure 4.11 shows that last-layer LLA approaches are the fastest to train, while VaLLA is the slowest. At prediction time, SNGP, last-layer LLA, and FMGP are about as fast as the pre-trained model. By contrast, VaLLA, ELLA, and MFVI have longer prediction times: in VaLLA and ELLA this is due to the computation of the Jacobians, and in MFVI it is due to Monte Carlo sampling. Since FMGP is agnostic to the pre-trained model architecture and only uses the DNN’s predictions, all model outputs can be precomputed and used directly both for training FMGP and for making predictions. Therefore, a second bar is shown for FMGP, indicating the training and evaluation times when the outputs for both the training and evaluation sets are precomputed. In this setting, FMGP obtains a speed-up of approximately \(\times1.5\) in training time and \(\times2.2\) in evaluation time.
Regarding predictive robustness, Figure 4.13 shows the NLL and ECE of each method on rotated images of the CIFAR10 test set, as in (Ortega et al., 2024). These results indicate that FMGP is the most robust method in terms of NLL, while it lies in the middle of the pack in terms of ECE, where ELLA and VaLLA achieve the best results.
ImageNet Dataset and Extra ResNet Architectures
Several experiments are performed with additional ResNet architectures (He et al., 2016) on the ImageNet 1k dataset (Russakovsky et al., 2015). This dataset has \(1\,000\) different classes and over 1 million data instances. The pre-trained models are those from TorchVision (Maintainers & Contributors, 2016), available at https://pytorch.org/vision/main/models/resnet.html. Specifically, the considered models are ResNet18 (\(11\,689\,512\) parameters), ResNet34 (\(21\,797\,672\) parameters), ResNet50 (\(25\,557\,032\) parameters), ResNet101 (\(44\,549\,160\) parameters) and ResNet152 (\(60\,192\,808\) parameters). Importantly, due to the size of the DNNs and of the dataset, many methods are infeasible in these experiments. Specifically, LLA cannot be used, even with last-layer approximations, due to memory limitations. Furthermore, Monte Carlo sampling at test time for MFVI takes longer than one day for models larger than ResNet18. For this reason, MFVI is only tested on the ResNet18 architecture. SNGP is not evaluated, as it requires several days of training even on the smallest architecture.
Table 4.5 shows the results obtained by each method on each ResNet architecture. The best method is highlighted in red and the second best in green. Overall, FMGP obtains the best performance (NLL and ECE) while remaining second best in computational time, behind only the MAP solution for the larger models. ELLA’s validation set is computed using the same data-augmentation strategy proposed in (Deng et al., 2022).
The results obtained for ELLA differ slightly from those reported in (Deng et al., 2022), since ELLA’s performance depends strongly on the particular data augmentation used to create the validation set. Even with the same hyper-parameters for this step, current PyTorch versions lead to different results.
| Model | Method | NLL | ECE | Train Time | Test Time |
|---|---|---|---|---|---|
| ResNet18 | MAP | \(1.247 \pm 0.000\) | \(0.026 \pm 0.000\) | \(0.0 \pm 0.0\) | \((5.058 \pm 0.029)\times 10^2\) |
| | ELLA | \(1.248 \pm 0.000\) | \(0.025 \pm 0.000\) | \((7.890 \pm 0.275)\times 10^3\) | \((8.060 \pm 0.010)\times 10^2\) |
| | FMGP | \(1.248 \pm 0.001\) | \(0.015 \pm 0.001\) | \((1.835 \pm 0.099)\times 10^4\) | \((7.324 \pm 0.001)\times 10^2\) |
| | MFVI | \(1.242 \pm 0.001\) | \(0.040 \pm 0.000\) | \((7.602 \pm 0.032)\times 10^4\) | \((3.773 \pm 0.308)\times 10^4\) |
| ResNet34 | MAP | \(1.081 \pm 0.000\) | \(0.035 \pm 0.000\) | \(0.0 \pm 0.0\) | \((5.088 \pm 0.004)\times 10^2\) |
| | ELLA | \(1.082 \pm 0.000\) | \(0.034 \pm 0.000\) | \((1.201 \pm 0.373)\times 10^4\) | \((1.087 \pm 0.018)\times 10^3\) |
| | FMGP | \(1.077 \pm 0.000\) | \(0.016 \pm 0.000\) | \((1.942 \pm 0.103)\times 10^4\) | \((8.563 \pm 0.011)\times 10^2\) |
| ResNet50 | MAP | \(0.962 \pm 0.000\) | \(0.037 \pm 0.000\) | \(0.0 \pm 0.0\) | \((4.954 \pm 0.010)\times 10^2\) |
| | ELLA | \(0.962 \pm 0.000\) | \(0.036 \pm 0.000\) | \((2.997 \pm 1.215)\times 10^4\) | \((1.954 \pm 0.018)\times 10^3\) |
| | FMGP | \(0.958 \pm 0.001\) | \(0.018 \pm 0.001\) | \((2.543 \pm 0.046)\times 10^4\) | \((1.100 \pm 0.010)\times 10^3\) |
| ResNet101 | MAP | \(0.912 \pm 0.000\) | \(0.049 \pm 0.000\) | \(0.0 \pm 0.0\) | \((5.059 \pm 0.001)\times 10^2\) |
| | ELLA | \(0.913 \pm 0.000\) | \(0.048 \pm 0.000\) | \((4.464 \pm 1.649)\times 10^4\) | \((2.808 \pm 0.001)\times 10^3\) |
| | FMGP | \(0.900 \pm 0.000\) | \(0.030 \pm 0.001\) | \((2.654 \pm 0.064)\times 10^4\) | \((1.134 \pm 0.001)\times 10^3\) |
| ResNet152 | MAP | \(0.876 \pm 0.000\) | \(0.050 \pm 0.000\) | \(0.0 \pm 0.0\) | \((6.324 \pm 0.004)\times 10^2\) |
| | ELLA | \(0.877 \pm 0.000\) | \(0.048 \pm 0.000\) | \((6.820 \pm 0.526)\times 10^4\) | \((3.877 \pm 0.007)\times 10^3\) |
| | FMGP | \(0.865 \pm 0.001\) | \(0.024 \pm 0.001\) | \((2.973 \pm 0.069)\times 10^4\) | \((1.267 \pm 0.002)\times 10^3\) |
Molecular Property Prediction Dataset
QM9 is a dataset that provides quantum chemical properties (at the DFT level) for a relevant, consistent, and comprehensive chemical space of around \(130\,000\) small organic molecules (Ruddigkeit et al., 2012). In this experiment, a small message-passing graph neural network is trained following the Torch-Geometric (Fey & Lenssen, 2019) tutorial available at https://github.com/pyg-team/pytorch_geometric/blob/master/examples/qm9_nn_conv.py. The model is trained to predict the dipole moment target.
In this regression setting, the input space consists of molecular graphs rather than the tabular data typically used in supervised learning. Accordingly, for each evaluated method, the network up to the last two linear layers is treated as a feature embedding of the graphs, and the data are assumed to live in that embedding space. The first \(10\,000\) instances are used for testing, the next \(10\,000\) for validation, and the remaining \(110\,000\) for training. ELLA is trained without early stopping, and hyper-parameters are selected via grid search on the validation set. FMGP employs the squared-exponential kernel with hyper-parameters comprising the amplitude and one length scale per dimension. MAP results are obtained by estimating the Gaussian noise on the validation set.
Table 4.6 reports the negative log-likelihood (NLL) and CRPS of MAP, last-layer Kronecker LLA, ELLA, and FMGP, averaged over \(5\) repetitions. The best result is highlighted in red and the second best in green. FMGP achieves the best performance (lower is better) in both NLL and CRPS among the considered methods.
| Method | NLL | CRPS |
|---|---|---|
| MAP | \(-1.76 \pm 0.016\) | \(0.0221 \pm 0.00\) |
| LLA | \(-1.78 \pm 0.021\) | \(0.0218 \pm 0.00\) |
| ELLA | \(-1.80 \pm 0.013\) | \(0.0219 \pm 0.00\) |
| FMGP | \(-1.85 \pm 0.017\) | \(0.0216 \pm 0.00\) |
Conclusions
In this chapter, two complementary post-hoc uncertainty estimation methods were introduced: the Variational Linearized Laplace Approximation (VaLLA) and the Fixed-Mean Gaussian Process (FMGP). Both approaches stem from a generalized sparse GP formulation that enables fixing the predictive mean to any desired function within the RKHS, typically set to the output of a pre-trained DNN. This design allows them to retrofit calibrated predictive distributions onto existing networks without retraining.
VaLLA offers robust uncertainty estimates with computational costs independent of the number of training instances, efficiently scaling to deep models with millions of parameters and to large datasets. It outperforms existing sampling-based and Nyström-approximation methods in speed and stability, while maintaining predictive robustness under input perturbations. FMGP extends this idea through a dual variational formulation, providing a simple yet effective framework that relies solely on DNN outputs, without requiring Jacobians or architectural details, and thereby achieves minimal evaluation time and broad applicability across regression and classification tasks.
Both methods demonstrate that accurate, scalable uncertainty quantification can be achieved even for large deterministic networks, bridging the gap between practical deep learning and Bayesian modeling. Moreover, VaLLA’s and FMGP’s efficiency and reliability suggest promising applications in Bayesian optimization, where fast and accurate uncertainty estimation can significantly improve the search for optimal solutions.