Members of the CCAIM research team will publish ten papers at this year’s International Conference on Machine Learning (ICML), the leading international academic conference in machine learning.
Taking place completely online from 18 – 24 July this year, the ICML is the premier gathering of professionals dedicated to the advancement of the branch of artificial intelligence known as machine learning.
The ICML is globally renowned for presenting and publishing cutting-edge research on all aspects of machine learning used in closely related areas like artificial intelligence, statistics and data science, as well as important application areas such as machine vision, computational biology, speech recognition, and robotics.
The Cambridge Centre for AI in Medicine is at the forefront of research in machine learning and will be presenting ten papers at this year’s conference. From better understanding human decision-making to individualized treatment effects and machine learning interpretability, our team – made up of expert minds from the van der Schaar Lab, Cambridge Machine Learning Group, and University of Cambridge Department of Computer Science and Technology – is working on revolutionary technologies that will shape the future of healthcare.
The papers submitted by our team that have been accepted at ICML 2021 include:
1. Inverse Decision Modeling: Learning Interpretable Representations of Behavior
Daniel Jarrett, Alihan Hüyük, Mihaela van der Schaar
Abstract: Decision analysis deals with modeling and enhancing decision processes. A principal challenge in improving behavior is in obtaining a transparent description of existing behavior in the first place.
In this paper, we develop an expressive, unifying perspective on inverse decision modeling: a framework for learning parameterized representations of sequential decision behavior.
First, we formalize the forward problem (as a normative standard), subsuming common classes of control behavior.
Second, we use this to formalize the inverse problem (as a descriptive model), generalizing existing work on imitation/reward learning—while opening up a much broader class of research problems in behavior representation.
Finally, we instantiate this approach with an example (inverse bounded rational control), illustrating how this structure enables learning (interpretable) representations of (bounded) rationality—while naturally capturing intuitive notions of suboptimal actions, biased beliefs, and imperfect knowledge of environments.
2. Explaining Time Series Predictions With Dynamic Masks
Jonathan Crabbé, Mihaela van der Schaar
Abstract: How can we explain the predictions of a machine learning model?
When the data is structured as a multivariate time series, this question induces additional difficulties such as the necessity for the explanation to embody the time dependency and the large number of inputs.
To address these challenges, we propose dynamic masks (Dynamask). This method produces instance-wise importance scores for each feature at each time step by fitting a perturbation mask to the input sequence. In order to incorporate the time dependency of the data, Dynamask studies the effects of dynamic perturbation operators. In order to tackle the large number of inputs, we propose a scheme to make the feature selection parsimonious (to select no more features than necessary) and legible (a notion that we detail by making a parallel with information theory).
With synthetic and real-world data, we demonstrate that the dynamic underpinning of Dynamask, together with its parsimony, offers a neat improvement in the identification of feature importance over time. The modularity of Dynamask makes it ideal as a plug-in to increase the transparency of a wide range of machine learning models in areas such as medicine and finance, where time series are abundant.
3. Policy Analysis using Synthetic Controls in Continuous-Time
Alexis Bellot, Mihaela van der Schaar
Abstract: Counterfactual estimation using synthetic controls is one of the most successful recent methodological developments in causal inference. Despite its popularity, the current description only considers time series aligned across units and synthetic controls expressed as linear combinations of observed control units.
We propose a continuous-time alternative that models the latent counterfactual path explicitly using the formalism of controlled differential equations.
This model is directly applicable to the general setting of irregularly-aligned multivariate time series and may be optimized in rich function spaces – thereby substantially improving on some limitations of existing approaches.
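For readers unfamiliar with the classical setup that this abstract contrasts itself with, the sketch below illustrates a synthetic control expressed as a linear combination of observed control units on aligned time series. It is a toy illustration only, not the paper's continuous-time method: the function names, the restriction to two control units with weights summing to one, and the data are all assumptions made for the example.

```python
# Toy sketch of a classical (discrete-time) synthetic control.
# Assumption: only two control units, combined as w*a + (1 - w)*b with 0 <= w <= 1.
# This is NOT the continuous-time model proposed in the paper.

def fit_two_unit_weight(treated_pre, control_a_pre, control_b_pre):
    """Least-squares weight w in [0, 1] so that w*a + (1-w)*b tracks the
    treated unit over the pre-treatment period."""
    diff = [a - b for a, b in zip(control_a_pre, control_b_pre)]
    resid = [y - b for y, b in zip(treated_pre, control_b_pre)]
    denom = sum(d * d for d in diff)
    if denom == 0.0:
        return 0.5  # the two controls are identical, so any weight fits equally well
    w = sum(r * d for r, d in zip(resid, diff)) / denom
    # The objective is a 1-D convex quadratic, so clipping gives the constrained optimum.
    return min(1.0, max(0.0, w))


def synthetic_control(w, control_a, control_b):
    """Weighted combination of the two control units."""
    return [w * a + (1.0 - w) * b for a, b in zip(control_a, control_b)]


# Hypothetical pre-treatment observations (aligned in time across units).
treated_pre = [2.0, 3.0, 4.0]
control_a_pre = [1.0, 2.0, 3.0]
control_b_pre = [5.0, 6.0, 7.0]

w = fit_two_unit_weight(treated_pre, control_a_pre, control_b_pre)

# Hypothetical post-treatment observations.
treated_post = [6.5, 8.0]
counterfactual = synthetic_control(w, [4.0, 5.0], [8.0, 9.0])

# Estimated effect: observed outcome minus the synthetic counterfactual.
effect = [obs - cf for obs, cf in zip(treated_post, counterfactual)]
print(w, counterfactual, effect)
```

The sketch depends on every unit being observed at the same time points, which is the alignment assumption the abstract says the continuous-time formulation removes.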
4. Learning Queueing Policies for Organ Transplantation Allocation using Interpretable Counterfactual Survival Analysis
Jeroen Berrevoets, Ahmed Alaa, Zhaozhi Qian, James Jordon, Alexander Gimson, and Mihaela van der Schaar
Abstract: Organ transplantation is often the last resort for treating end-stage illnesses, but managing transplant wait-lists is challenging because of organ scarcity and the complexity of assessing donor-recipient compatibility.
In this paper, we develop a data-driven model for (real-time) organ allocation using observational data for transplant outcomes. Our model integrates a queuing-theoretic framework with unsupervised learning to cluster the organs into “organ types”, and then construct priority queues (associated with each organ type) to which incoming patients are assigned. To reason about organ allocations, the model uses synthetic controls to infer a patient’s survival outcomes under counterfactual allocations to the different organ types; the model is trained end-to-end to optimize the trade-off between patient waiting time and expected survival time. The use of synthetic controls enables patient-level interpretations of allocation decisions that can be presented to and understood by clinicians.
We test our model on multiple data sets and show that it outperforms other organ-allocation policies in terms of added life-years and death count. Furthermore, we introduce a novel organ-allocation simulator to accurately test new policies.
5. Directional Graph Networks
Dominique Beaini, Saro Passaro, Vincent Létourneau, William L. Hamilton, Gabriele Corso, and Pietro Liò
Abstract: The lack of anisotropic kernels in graph neural networks (GNNs) strongly limits their expressiveness, contributing to well-known issues such as over-smoothing. To overcome this limitation, we propose the first globally consistent anisotropic kernels for GNNs, allowing for graph convolutions that are defined according to topologically-derived directional flows.
First, by defining a vector field in the graph, we develop a method of applying directional derivatives and smoothing by projecting node-specific messages into the field. Then, we propose the use of the Laplacian eigenvectors as such a vector field. We show that the method generalizes CNNs on an n-dimensional grid and is provably more discriminative than standard GNNs regarding the Weisfeiler-Lehman 1-WL test. We evaluate our method on different standard benchmarks and see a relative error reduction of 8% on the CIFAR10 graph dataset and 11% to 32% on the molecular ZINC dataset, and a relative increase in precision of 1.6% on the MolPCBA dataset.
An important outcome of this work is that it enables graph networks to embed directions in an unsupervised way, thus allowing a better representation of the anisotropic features in different physical or biological problems.
6. How Framelets Enhance Graph Neural Networks
Xuebin Zheng, Bingxin Zhou, Junbin Gao, Yu Guang Wang, Pietro Liò, Ming Li, Guido Montúfar
Abstract: This paper presents a new approach for assembling graph neural networks based on framelet transforms. The latter provides a multi-scale representation for graph-structured data.
With the framelet system, we can decompose the graph feature into low-pass and high-pass frequencies as extracted features for network training, which then defines a framelet-based graph convolution. The framelet decomposition naturally induces a graph pooling strategy by aggregating the graph feature into low-pass and high-pass spectra, which considers both the feature values and geometry of the graph data and conserves the total information. The graph neural networks with the proposed framelet convolution and pooling achieve state-of-the-art performance in many types of node and graph prediction tasks.
Moreover, we propose shrinkage as a new activation for the framelet convolution, which thresholds the high-frequency information at different scales. Compared to ReLU, shrinkage in framelet convolution improves the graph neural network model in terms of denoising and signal compression: noise in both node and structure can be significantly reduced by accurately cutting off the high-pass coefficients from framelet decomposition, and the signal can be compressed to less than half its original size with the prediction performance well preserved.
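To make the contrast between shrinkage and ReLU concrete, the sketch below applies soft-thresholding, one standard form of shrinkage from wavelet denoising, to a handful of made-up high-pass coefficients. The abstract does not say which shrinkage rule the authors use, so the choice of soft-thresholding, the threshold value, and the function names here are assumptions for illustration.

```python
# Illustrative comparison of a shrinkage activation with ReLU.
# Assumption: "shrinkage" is modelled as soft-thresholding; the paper may use a different rule.

def relu(x):
    """Standard ReLU: keeps positive values, zeroes out negatives."""
    return max(0.0, x)


def soft_shrink(x, threshold):
    """Soft-thresholding: coefficients with |x| <= threshold become zero,
    larger ones are pulled toward zero by `threshold` and keep their sign."""
    if x > threshold:
        return x - threshold
    if x < -threshold:
        return x + threshold
    return 0.0


# Hypothetical high-pass coefficients: two small (noise-like), two large.
high_pass = [-2.0, -0.25, 0.1, 1.5]

shrunk = [soft_shrink(c, 0.5) for c in high_pass]
rectified = [relu(c) for c in high_pass]

print(shrunk)     # small coefficients removed, large ones kept with their sign
print(rectified)  # all negative coefficients removed, small positive noise kept
```

In this toy case, shrinkage zeroes the two small-magnitude coefficients and preserves both large ones, whereas ReLU discards the large negative coefficient and lets the small positive one through.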
7. Weisfeiler and Lehman Go Topological: Message Passing Simplicial Networks
Cristian Bodnar, Fabrizio Frasca, Yu Guang Wang, Nina Otter, Guido Montúfar, Pietro Liò, Michael Bronstein
Abstract: The pairwise interaction paradigm of graph machine learning has predominantly governed the modelling of relational systems. However, graphs alone cannot capture the multi-level interactions present in many complex systems and the expressive power of such schemes was proven to be limited.
To overcome these limitations, we propose Message Passing Simplicial Networks (MPSNs), a class of models that perform message passing on simplicial complexes (SCs) – topological objects generalising graphs to higher dimensions. To theoretically analyse the expressivity of our model we introduce a Simplicial Weisfeiler-Lehman (SWL) colouring procedure for distinguishing non-isomorphic SCs. We relate the power of SWL to the problem of distinguishing non-isomorphic graphs and show that SWL and MPSNs are strictly more powerful than the WL test and not less powerful than the 3-WL test.
We deepen the analysis by comparing our model with traditional graph neural networks with ReLU activations in terms of the number of linear regions of the functions they can represent. We empirically support our theoretical claims by showing that MPSNs can distinguish challenging strongly regular graphs for which GNNs fail and, when equipped with orientation equivariant layers, they can improve classification accuracy in oriented SCs compared to a GNN baseline.
Additionally, we implement a library for message passing on simplicial complexes that we plan to release in due course.
8. A Gradient Based Strategy for Hamiltonian Monte Carlo Hyperparameter Optimization
Campbell A., Chen W., Stimper V., Hernández-Lobato J. M. and Zhang Y.
Abstract: Hamiltonian Monte Carlo (HMC) is one of the most successful sampling methods in machine learning. However, its performance is significantly affected by the choice of hyperparameter values. Existing approaches for optimizing the hyperparameters either optimize a proxy for mixing speed or consider the HMC chain as an implicit variational distribution and optimize a tractable lower bound that is too loose to be useful in practice.
Instead, we propose to optimize an objective that directly quantifies the speed of convergence to the target distribution. Our objective can be easily optimized using stochastic gradient descent. We evaluate our proposed method and compare it to baselines on a variety of problems, including synthetic 2D distributions, inference for sparse signal recovery, learning deep latent variable models, and sampling molecular configurations from the Boltzmann distribution of a 22-atom molecule.
We find our method is competitive with or improves upon alternative baselines on all problems we consider.
9. Active Slices for Sliced Stein Discrepancy
Gong W., Zhang K., Li Y. and Hernández-Lobato J. M.
Abstract: Sliced Stein discrepancy (SSD) and its kernelized variants have demonstrated promising successes in goodness-of-fit tests and model learning in high dimensions. Despite the theoretical elegance, their empirical performance depends crucially on the search for the optimal slicing directions to discriminate between two distributions.
Unfortunately, the previous gradient-based optimisation approach returns sub-optimal results for the slicing directions: it is computationally expensive, sensitive to initialization, and lacks a theoretical guarantee of convergence. We address these issues in two steps.
First, we show in theory that the requirement of using optimal slicing directions in the kernelized version of SSD can be relaxed, validating the resulting discrepancy with finite random slicing directions.
Second, given that good slicing directions are crucial for practical performance, we propose a fast algorithm for finding good slicing directions based on ideas of active sub-space construction and spectral decomposition. Experiments in goodness-of-fit tests and model learning show that our approach achieves both the best performance and the fastest convergence. In particular, we demonstrate a 14-80x speed-up in goodness-of-fit tests compared with the gradient-based approach.
10. Bayesian Deep Learning via Subnetwork Inference
Daxberger E., Nalisnick E., Allingham J., Antoran J. and Hernández-Lobato J. M.
Abstract: The Bayesian paradigm has the potential to solve core issues of deep neural networks such as poor calibration and data inefficiency. Alas, scaling Bayesian inference to large weight spaces often requires restrictive approximations.
In this work, we show that it suffices to perform inference over a small subset of model weights in order to obtain accurate predictive posteriors. The other weights are kept as point estimates. This subnetwork inference framework enables us to use expressive, otherwise intractable, posterior approximations over such subsets. In particular, we implement subnetwork linearized Laplace: We first obtain a MAP estimate of all weights and then infer a full-covariance Gaussian posterior over a subnetwork. We propose a subnetwork selection strategy that aims to maximally preserve the model’s predictive uncertainty.
Empirically, our approach is effective compared to ensembles and less expressive posterior approximations over full networks.