Five members of the Centre’s faculty will present eighteen papers and contribute to three workshops at this year’s Conference on Neural Information Processing Systems (NeurIPS 2021), one of the most prestigious international academic conferences in machine learning.
Taking place entirely virtually this year, NeurIPS 2021 will run from December 6 through 14. The conference is a multi-track interdisciplinary annual meeting that includes invited talks, demonstrations, symposia, and oral and poster presentations of refereed papers. A professional exposition focusing on machine learning in practice, a series of tutorials, and topical workshops that provide a less formal setting for the exchange of ideas take place alongside the conference.
The Cambridge Centre for AI in Medicine (CCAIM) is at the forefront of research in machine learning and will present a total of eighteen papers and contribute to three workshops at this year’s conference. The papers cover a number of the Centre’s key research areas, including ML interpretability and explainability, causal inference, time series analysis, organ transplantation, biomedical discovery, and adaptive clinical trials.
Of these papers, fourteen come from the van der Schaar Lab (the personal lab of Centre Director Professor Mihaela van der Schaar) and one major paper comes from the Floto Lab (led by Centre Co-Director Professor Andres Floto). Faculty members Dr. Angela Wood, Prof. José Miguel Hernández-Lobato, and Dr. Alexander Gimson will also be contributing to the conference, with three papers accepted between them (outlined in detail below).
Prof. van der Schaar will also be delivering three invited talks at the following workshops:
- December 13 at 13:30 GMT: 2021 Workshop on Meta-Learning (MetaLearn 2021) – Introducing quantitative epistemology, a transformational new area of research pioneered by our lab as a strand of machine learning aimed at understanding, supporting, and improving human decision-making.
- December 14 at 21:00 GMT: Self-supervised Learning for Genomics at the 2021 Self-supervised Learning Workshop
- December 14 at 16:20 GMT: Synthetic Data Generation and Assessment: Challenges, Methods, Impact at the Deep Generative Models and Downstream Applications Workshop – This is one of the van der Schaar Lab’s core research pillars, as they believe synthetic data has the potential to revolutionize how we access and interact with datasets in and beyond healthcare.
One of the Centre’s papers – “Really Doing Great at Estimating CATE? A Critical Look at ML Benchmarking Practices in Treatment Effect Estimation” – was co-authored with Jim Weatherall of AstraZeneca, one of the Centre’s core industry partners. The paper investigates current benchmarking practices for ML-based conditional average treatment effect (CATE) estimators, with a special focus on empirical evaluation based on the popular semi-synthetic IHDP benchmark. It identifies problems with current practice and highlights that semi-synthetic benchmark datasets, which do not necessarily reflect properties of real data, can systematically favor some algorithms over others – a fact that is rarely acknowledged but is of immense relevance for the interpretation of empirical results.
“[I am] proud to have co-authored this NeurIPS paper, as part of our world leading industry-academia collaboration, the CCAIM. Understanding the effects of medicines at the individual level is key to the future of precision medicine. Metrics and benchmarking for machine learning approaches are an essential part of that journey.”
Jim Weatherall, Vice President of Data Science & AI, AstraZeneca
Further details (abstracts, authors, and links) about the papers that have been accepted at NeurIPS 2021:
1. SurvITE: Learning Heterogeneous Treatment Effects from Time-to-Event Data
Alicia Curth, Changhee Lee, Mihaela van der Schaar
Abstract: We study the problem of inferring heterogeneous treatment effects from time-to-event data. While both the related problems of (i) estimating treatment effects for binary or continuous outcomes and (ii) predicting survival outcomes have been well studied in the recent machine learning literature, their combination — albeit of high practical relevance — has received considerably less attention.
With the ultimate goal of reliably estimating the effects of treatments on instantaneous risk and survival probabilities, we focus on the problem of learning (discrete-time) treatment-specific conditional hazard functions. We find that unique challenges arise in this context due to a variety of covariate shift issues that go beyond a mere combination of well-studied confounding and censoring biases. We theoretically analyse their effects by adapting recent generalization bounds from domain adaptation and treatment effect estimation to our setting and discuss implications for model design. We use the resulting insights to propose a novel deep learning method for treatment-specific hazard estimation based on balancing representations.
We investigate performance across a range of experimental settings and empirically confirm that our method outperforms baselines by addressing covariate shifts from various sources.
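To make the setup concrete, here is a minimal sketch (not the authors’ code; the layer sizes, number of time bins, and toy inputs are assumptions) of a discrete-time, treatment-specific hazard model built on a shared representation, the kind of architecture on which SurvITE’s balancing ideas operate:

```python
# A minimal sketch of discrete-time, treatment-specific hazard estimation with a
# shared representation. Sizes and the toy data are illustrative assumptions.
import torch
import torch.nn as nn

N_BINS = 30  # number of discrete time bins

class TreatmentSpecificHazard(nn.Module):
    def __init__(self, n_features, n_bins=N_BINS, hidden=64):
        super().__init__()
        # Shared representation on which balancing penalties would be applied.
        self.phi = nn.Sequential(nn.Linear(n_features, hidden), nn.ReLU())
        # One hazard head per treatment arm; each outputs a hazard per time bin.
        self.heads = nn.ModuleList([nn.Linear(hidden, n_bins) for _ in range(2)])

    def forward(self, x, treatment):
        h = self.phi(x)
        hazard = torch.sigmoid(self.heads[treatment](h))   # hazard per bin, given x
        survival = torch.cumprod(1.0 - hazard, dim=-1)     # survival curve per bin
        return hazard, survival

model = TreatmentSpecificHazard(n_features=10)
x = torch.randn(4, 10)
hazard_treated, surv_treated = model(x, treatment=1)
hazard_control, surv_control = model(x, treatment=0)
# A treatment effect on survival at horizon t is then surv_treated - surv_control.
```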
2. On Inductive Biases for Heterogeneous Treatment Effect Estimation (spotlight paper)
Alicia Curth, Mihaela van der Schaar
Abstract: We investigate how to exploit structural similarities of an individual’s potential outcomes (POs) under different treatments to obtain better estimates of conditional average treatment effects in finite samples.
Especially when it is unknown whether a treatment has an effect at all, it is natural to hypothesize that the POs are similar — yet, some existing strategies for treatment effect estimation employ regularization schemes that implicitly encourage heterogeneity even when it does not exist and fail to fully make use of shared structure.
In this paper, we investigate and compare three end-to-end learning strategies to overcome this problem — based on regularization, reparametrization and a flexible multi-task architecture — each encoding inductive bias favoring shared behavior across POs.
To build understanding of their relative strengths, we implement all strategies using neural networks and conduct a wide range of semi-synthetic experiments. We observe that all three approaches can lead to substantial improvements upon numerous baselines and gain insight into performance differences across various experimental settings.
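As an illustration of the kind of inductive bias discussed above, the following hedged sketch (architecture, penalty form, and weights are assumptions, not the paper’s exact strategies) shows a multi-task network whose regularizer shrinks the two potential-outcome heads towards shared behaviour:

```python
# Illustrative TARNet-style sketch; sizes and the penalty weight are assumptions.
import torch
import torch.nn as nn

class SharedPONet(nn.Module):
    def __init__(self, n_features, hidden=64):
        super().__init__()
        self.shared = nn.Sequential(nn.Linear(n_features, hidden), nn.ReLU())
        self.mu0 = nn.Linear(hidden, 1)   # potential outcome under control
        self.mu1 = nn.Linear(hidden, 1)   # potential outcome under treatment

    def forward(self, x):
        h = self.shared(x)
        return self.mu0(h), self.mu1(h)

def loss_fn(model, x, y, t, lam=0.1):
    mu0, mu1 = model(x)
    y_hat = torch.where(t.bool(), mu1, mu0)          # factual prediction
    factual_loss = ((y_hat - y) ** 2).mean()
    # Inductive bias: shrink the estimated effect mu1 - mu0 towards zero unless the
    # data clearly support heterogeneity.
    shared_behaviour_penalty = ((mu1 - mu0) ** 2).mean()
    return factual_loss + lam * shared_behaviour_penalty

model = SharedPONet(n_features=5)
x, y = torch.randn(8, 5), torch.randn(8, 1)
t = torch.randint(0, 2, (8, 1))
loss_fn(model, x, y, t).backward()
mu0, mu1 = model(x)
cate = mu1 - mu0                                      # estimated CATE
```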
3. Really Doing Great at Estimating CATE? A Critical Look at ML Benchmarking Practices in Treatment Effect Estimation
Alicia Curth, David Svensson, Jim Weatherall, Mihaela van der Schaar
Abstract: The machine learning (ML) toolbox for estimation of heterogeneous treatment effects from observational data is expanding rapidly, yet many of its algorithms have been evaluated only on a very limited set of semi-synthetic benchmark datasets.
In this paper, we investigate current benchmarking practices for ML-based conditional average treatment effect (CATE) estimators, with special focus on empirical evaluation based on the popular semi-synthetic IHDP benchmark. We identify problems with current practice and highlight that semi-synthetic benchmark datasets, which (unlike real-world benchmarks used elsewhere in ML) do not necessarily reflect properties of real data, can systematically favor some algorithms over others — a fact that is rarely acknowledged but of immense relevance for interpretation of empirical results.
Further, we argue that current evaluation metrics evaluate performance only for a small subset of possible use cases of CATE estimators, and discuss alternative metrics relevant for applications in personalized medicine.
Additionally, we discuss alternatives for current benchmark datasets, and implications of our findings for benchmarking in CATE estimation.
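For context on the metrics being critiqued, the quantity such benchmarks typically report is the PEHE (precision in estimating heterogeneous effects), which requires the ground-truth effects that only (semi-)synthetic data can provide; a minimal sketch:

```python
# Minimal PEHE computation, assuming the benchmark exposes ground-truth effects.
import numpy as np

def pehe(tau_true, tau_pred):
    """Root mean squared error between true and estimated individual effects."""
    tau_true, tau_pred = np.asarray(tau_true), np.asarray(tau_pred)
    return np.sqrt(np.mean((tau_true - tau_pred) ** 2))

print(pehe([1.0, 0.5, -0.2], [0.8, 0.7, 0.0]))  # toy example
```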
4. Estimating Multi-cause Treatment Effects via Single-cause Perturbation
Zhaozhi Qian, Alicia Curth, Mihaela van der Schaar
Abstract: Most existing methods for conditional average treatment effect estimation are designed to estimate the effect of a single cause — only one variable can be intervened on at one time.
However, many applications involve simultaneous intervention on multiple variables, which leads to multi-cause treatment effect problems. The multi-cause problem is challenging due to severe data scarcity — we only observe the outcome corresponding to the treatment that was actually given but need to infer a large number of potential outcomes under different combinations of the causes.
In this work, we propose Single-cause Perturbation (SCP), a novel two-step procedure to estimate the multi-cause treatment effect. SCP starts by augmenting the observational dataset with the estimated potential outcomes under single-cause interventions. It then performs covariate adjustment on the augmented dataset to obtain the estimator. SCP is agnostic to the exact choice of algorithm in either step.
We show formally that the procedure is valid under standard assumptions in causal inference. We demonstrate the performance gain of SCP on extensive simulation and real data experiments.
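A heavily simplified sketch of the two-step idea (not the authors’ implementation; the regressor choice and toy data-generating process are assumptions) might look as follows:

```python
# A simplified SCP-style two-step procedure on toy data (illustrative assumptions:
# the data-generating process, and gradient boosting as the plug-in learner).
import numpy as np
from sklearn.ensemble import GradientBoostingRegressor

rng = np.random.default_rng(0)
n, n_causes = 500, 3
X = rng.normal(size=(n, 2))                                   # covariates
T = rng.integers(0, 2, size=(n, n_causes)).astype(float)      # causes actually given
Y = X[:, 0] + T.sum(axis=1) + rng.normal(scale=0.1, size=n)   # observed outcome

# Step 1: estimate potential outcomes under single-cause interventions (here with a
# simple regression on (X, T); the paper uses dedicated single-cause estimators)
# and use them to augment the observational dataset.
single_cause_model = GradientBoostingRegressor().fit(np.hstack([X, T]), Y)
augmented_T, augmented_Y = [T], [Y]
for k in range(n_causes):
    T_flipped = T.copy()
    T_flipped[:, k] = 1 - T_flipped[:, k]                     # flip one cause at a time
    augmented_T.append(T_flipped)
    augmented_Y.append(single_cause_model.predict(np.hstack([X, T_flipped])))

X_aug = np.vstack([np.hstack([X, t]) for t in augmented_T])
Y_aug = np.concatenate(augmented_Y)

# Step 2: covariate adjustment (outcome regression) on the augmented dataset.
outcome_model = GradientBoostingRegressor().fit(X_aug, Y_aug)

# Example multi-cause contrast: all causes administered vs. none.
all_on = outcome_model.predict(np.hstack([X, np.ones_like(T)]))
all_off = outcome_model.predict(np.hstack([X, np.zeros_like(T)]))
print("estimated average effect:", (all_on - all_off).mean())
```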
5. Integrating Expert ODEs into Neural ODEs: Pharmacology and Disease Progression
Zhaozhi Qian, William R. Zame, Lucas Fleuren, Paul Elbers, Mihaela van der Schaar
Abstract: Modeling a system’s temporal behaviour in reaction to external stimuli is a fundamental problem in many areas. Pure Machine Learning (ML) approaches often fail in the small sample regime and cannot provide actionable insights beyond predictions. A promising modification has been to incorporate expert domain knowledge into ML models.
The application we consider is predicting the progression of disease under medications, where a plethora of domain knowledge is available from pharmacology. Pharmacological models describe the dynamics of carefully-chosen medically meaningful variables in terms of systems of Ordinary Differential Equations (ODEs). However, these models only describe a limited collection of variables, and these variables are often not observable in clinical environments.
To close this gap, we propose the latent hybridisation model (LHM) that integrates a system of expert-designed ODEs with machine-learned Neural ODEs to fully describe the dynamics of the system and to link the expert and latent variables to observable quantities.
We evaluated LHM on synthetic data as well as real-world intensive care data of COVID-19 patients. LHM consistently outperforms previous works, especially when few training samples are available such as at the beginning of the pandemic.
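The hybridisation idea can be sketched as follows (an illustrative, assumption-laden toy, not the LHM code): an expert ODE for a pharmacological variable, here a simple elimination equation, is integrated jointly with learned dynamics over latent states, and both feed an emission model that links them to observations:

```python
# A minimal hybrid expert/neural ODE sketch; a plain Euler integrator keeps it
# self-contained. All sizes, rates, and the toy inputs are assumptions.
import torch
import torch.nn as nn

class HybridODE(nn.Module):
    def __init__(self, latent_dim=4, obs_dim=3, k_elim=0.5):
        super().__init__()
        self.k_elim = k_elim                                    # expert parameter
        self.latent_dynamics = nn.Sequential(                   # learned dynamics
            nn.Linear(latent_dim + 1, 32), nn.Tanh(), nn.Linear(32, latent_dim))
        self.emission = nn.Linear(latent_dim + 1, obs_dim)      # link to observations

    def step(self, drug_conc, z, dt):
        d_drug = -self.k_elim * drug_conc                       # expert ODE: dC/dt = -k*C
        d_z = self.latent_dynamics(torch.cat([z, drug_conc], dim=-1))
        return drug_conc + dt * d_drug, z + dt * d_z

    def forward(self, drug_conc0, z0, n_steps=50, dt=0.1):
        drug_conc, z, outputs = drug_conc0, z0, []
        for _ in range(n_steps):
            drug_conc, z = self.step(drug_conc, z, dt)
            outputs.append(self.emission(torch.cat([z, drug_conc], dim=-1)))
        return torch.stack(outputs, dim=1)                      # predicted observations

model = HybridODE()
obs = model(drug_conc0=torch.ones(8, 1), z0=torch.zeros(8, 4))  # shape (8, 50, 3)
```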
6. SyncTwin: Treatment Effect Estimation with Longitudinal Outcomes
Zhaozhi Qian, Yao Zhang, Ioana Bica, Angela Wood, Mihaela van der Schaar
Abstract: Most medical observational studies estimate causal treatment effects using electronic health records (EHRs), in which a patient’s covariates and outcomes are both observed longitudinally. However, previous methods focus only on adjusting for the covariates while neglecting the temporal structure in the outcomes.
To bridge the gap, this paper develops a new method, SyncTwin, that learns a patient-specific time-constant representation from the pre-treatment observations. SyncTwin issues counterfactual prediction of a target patient by constructing a synthetic twin that closely matches the target in representation. The reliability of the estimated treatment effect can be assessed by comparing the observed and synthetic pre-treatment outcomes. The medical experts can interpret the estimate by examining the most important contributing individuals to the synthetic twin.
In the real-data experiment, SyncTwin successfully reproduced the findings of a randomized controlled clinical trial using observational data, demonstrating its usability with complex real-world EHRs.
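The synthetic-twin construction can be illustrated with a simplified sketch (here the “representation” is just the raw pre-treatment trajectory, whereas SyncTwin learns one; the toy data are assumptions):

```python
# Fit non-negative weights over control patients that reconstruct the target's
# pre-treatment trajectory, then reuse them to predict counterfactual outcomes.
import numpy as np
from scipy.optimize import nnls

rng = np.random.default_rng(1)
controls_pre = rng.normal(size=(50, 10))   # 50 controls, 10 pre-treatment outcomes
controls_post = rng.normal(size=(50, 5))   # their 5 post-treatment outcomes
target_pre = rng.normal(size=10)           # treated patient's pre-treatment outcomes
target_post = rng.normal(size=5)           # their observed (treated) outcomes

weights, _ = nnls(controls_pre.T, target_pre)   # non-negative reconstruction weights
weights = weights / weights.sum()               # normalise to a convex combination

counterfactual_post = weights @ controls_post   # the synthetic twin's outcomes
treatment_effect = target_post - counterfactual_post
# The pre-treatment reconstruction error, target_pre - weights @ controls_pre, gives
# a sanity check on how reliable the estimate is (cf. the abstract), and the largest
# weights identify the individuals contributing most to the synthetic twin.
print(treatment_effect)
```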
7. Invariant Causal Imitation Learning for Generalizable Policies
Ioana Bica, Daniel Jarrett, Mihaela van der Schaar
Abstract: Consider learning an imitation policy on the basis of demonstrated behavior from multiple environments, with an eye towards deployment in an unseen environment. Since the observable features from each setting may be different, directly learning individual policies as mappings from features to actions is prone to spurious correlations—and may not generalize well. However, the expert’s policy is often a function of a shared latent structure underlying those observable features that is invariant across settings.
By leveraging data from multiple environments, we propose Invariant Causal Imitation Learning (ICIL), a novel technique in which we learn a feature representation that is invariant across domains, on the basis of which we learn an imitation policy that matches expert behavior. To cope with transition dynamics mismatch, ICIL learns a shared representation of causal features (for all training environments), that is disentangled from the specific representations of noise variables (for each of those environments). Moreover, to ensure that the learned policy matches the observation distribution of the expert’s policy, ICIL estimates the energy of the expert’s observations and uses a regularization term that minimizes the imitator policy’s next state energy.
Experimentally, we compare our methods against several benchmarks in control and healthcare tasks and show its effectiveness in learning imitation policies capable of generalizing to unseen environments.
8. Time-series Generation by Contrastive Imitation
Daniel Jarrett, Ioana Bica, Mihaela van der Schaar
Abstract: Consider learning a generative model for time-series data. The sequential setting poses a unique challenge: Not only should the generator capture the *conditional* dynamics of (stepwise) transitions, but its open-loop rollouts should also preserve the *joint* distribution of (multi-step) trajectories.
On one hand, autoregressive models trained by MLE allow learning and computing explicit transition distributions, but suffer from compounding error during rollouts. On the other hand, adversarial models based on GAN training alleviate such exposure bias, but transitions are implicit and hard to assess.
In this work, we propose a novel framework marrying the best of both worlds: Motivated by a precise moment-matching objective to mitigate compounding error, we optimize a local (but forward-looking) *transition policy*, where the reinforcement signal is provided by a global (but stepwise-decomposable) *energy model* trained by contrastive estimation. In learning, the two components are trained cooperatively, avoiding the instabilities typical of adversarial objectives. Moreover, while the learned policy serves as the generator for sampling, the learned energy naturally serves as a trajectory-level measure for evaluating sample quality. By expressly training a policy to imitate sequential behavior of time-series features in a dataset, our approach embodies “generation by imitation”.
Theoretically, we demonstrate the correctness of our formulation and consistency of our algorithm. Empirically, we evaluate its ability to generate realistic samples using real-world datasets, and verify that it performs at or above the standard of existing benchmarks.
9. The Medkit-Learn(ing) Environment: Medical Decision Modelling through Simulation
Alex Chan, Ioana Bica, Alihan Hüyük, Daniel Jarrett, Mihaela van der Schaar
Abstract: The goal of understanding decision-making behaviours in clinical environments is of paramount importance if we are to bring the strengths of machine learning to ultimately improve patient outcomes.
Mainstream development of algorithms is often geared towards optimal performance in tasks that do not necessarily translate well into the medical regime—due to several factors including the lack of public availability of realistic data, the intrinsically offline nature of the problem, as well as the complexity and variety of human behaviours.
We therefore present a new benchmarking suite designed specifically for medical sequential decision modelling: the Medkit-Learn(ing) Environment, a publicly available Python package providing simple and easy access to high-fidelity synthetic medical data.
While providing a standardised way to compare algorithms in a realistic medical setting, we employ a generating process that disentangles the policy and environment dynamics to allow for a range of customisations, thus enabling systematic evaluation of algorithms’ robustness against specific challenges prevalent in healthcare.
10. Closing the loop in medical decision support by understanding clinical decision-making: A case study on organ transplantation
Yuchao Qin, Fergus Imrie, Alihan Hüyük, Daniel Jarrett, Alexander Gimson, Mihaela van der Schaar
Abstract: Significant effort has been placed on developing decision support tools to improve patient care. However, drivers of real-world clinical decisions in complex medical scenarios are not yet well-understood, resulting in substantial gaps between these tools and practical applications.
In light of this, we highlight that more attention on understanding clinical decision-making is required both to elucidate current clinical practices and to enable effective human-machine interactions. This is imperative in high-stakes scenarios with scarce available resources. Using organ transplantation as a case study, we formalize the desiderata of methods for understanding clinical decision-making.
We show that most existing machine learning methods are insufficient to meet these requirements and propose iTransplant, a novel data-driven framework to learn the factors affecting decisions on organ offers in an instance-wise fashion directly from clinical data, as a possible solution.
Through experiments on real-world liver transplantation data from OPTN, we demonstrate the use of iTransplant to: (1) discover which criteria are most important to clinicians for organ offer acceptance; (2) identify patient-specific organ preferences of clinicians allowing automatic patient stratification; and (3) explore variations in transplantation practices between different transplant centers. Finally, we emphasize that the insights gained by iTransplant can be used to inform the development of future decision support tools.
11. Explaining Latent Representations with a Corpus of Examples (spotlight paper)
Jonathan Crabbé, Zhaozhi Qian, Fergus Imrie, Mihaela van der Schaar
Abstract: Modern machine learning models are complicated. Most of them rely on convoluted latent representations of their input to issue a prediction. To achieve greater transparency than a black-box that connects inputs to predictions, it is necessary to gain a deeper understanding of these latent representations.
To that aim, we propose SimplEx: a user-centred method that provides example-based explanations with reference to a freely selected set of examples, called the corpus. SimplEx uses the corpus to improve the user’s understanding of the latent space with post-hoc explanations answering two questions: (1) Which corpus examples explain the prediction issued for a given test example? (2) What features of these corpus examples are relevant for the model to relate them to the test example? SimplEx provides an answer by reconstructing the test latent representation as a mixture of corpus latent representations.
Further, we propose a novel approach, the integrated Jacobian, that allows SimplEx to make explicit the contribution of each corpus feature in the mixture. Through experiments on tasks ranging from mortality prediction to image classification, we demonstrate that these decompositions are robust and accurate.
With illustrative use cases in medicine, we show that SimplEx empowers the user by highlighting relevant patterns in the corpus that explain model representations. Moreover, we demonstrate how the freedom in choosing the corpus allows the user to have personalized explanations in terms of examples that are meaningful for them.
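The core decomposition can be sketched as follows (an illustrative toy, not the SimplEx package): find weights on the simplex such that the test example’s latent representation is approximated by a mixture of corpus latent representations:

```python
# Optimize simplex weights (via a softmax parameterization) so that a mixture of
# corpus latents reconstructs the test latent. Dimensions and data are assumptions.
import torch

torch.manual_seed(0)
corpus_latents = torch.randn(100, 16)   # latent representations of the corpus
test_latent = torch.randn(16)           # latent representation of the test example

logits = torch.zeros(100, requires_grad=True)
optimizer = torch.optim.Adam([logits], lr=0.1)
for _ in range(500):
    optimizer.zero_grad()
    weights = torch.softmax(logits, dim=0)        # mixture weights (sum to 1)
    reconstruction = weights @ corpus_latents     # mixture of corpus latents
    loss = ((reconstruction - test_latent) ** 2).sum()
    loss.backward()
    optimizer.step()

weights = torch.softmax(logits, dim=0).detach()
top = torch.topk(weights, k=5)
print("most relevant corpus examples:", top.indices.tolist(), top.values.tolist())
```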
12. Conformal Time-Series Forecasting
Kamilė Stankevičiūtė, Ahmed Alaa, Mihaela van der Schaar
Abstract: Current approaches for (multi-horizon) time-series forecasting using recurrent neural networks (RNNs) focus on issuing point estimates, which are insufficient for informing decision-making in critical application domains wherein uncertainty estimates are also required.
Existing methods for uncertainty quantification in RNN-based time-series forecasts are limited as they may require significant alterations to the underlying architecture, may be computationally complex, may be difficult to calibrate, may incur high sample complexity, and may not provide theoretical validity guarantees for the issued uncertainty intervals.
In this work, we extend the inductive conformal prediction framework to the time-series forecasting setup, and propose a lightweight uncertainty estimation procedure to address the above limitations. With minimal exchangeability assumptions, our approach provides uncertainty intervals with theoretical guarantees on frequentist coverage for any multi-horizon forecast predictor and dataset.
We demonstrate the effectiveness of the conformal forecasting framework by comparing it with existing baselines on a variety of synthetic and real-world datasets.
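For readers unfamiliar with inductive (split) conformal prediction, here is a minimal sketch of per-horizon intervals computed from a held-out calibration set, with a Bonferroni-style correction across horizons (the correction choice is an illustrative simplification, not necessarily the paper’s exact procedure):

```python
# Split conformal intervals for a multi-horizon forecaster.
import numpy as np

def conformal_intervals(test_forecasts, cal_forecasts, cal_targets, alpha=0.1):
    """test_forecasts: (n, H); cal_forecasts, cal_targets: (m, H)."""
    n_cal, horizon = cal_targets.shape
    residuals = np.abs(cal_targets - cal_forecasts)              # nonconformity scores
    # Split the miscoverage budget across horizons so coverage holds jointly.
    q_level = np.ceil((n_cal + 1) * (1 - alpha / horizon)) / n_cal
    radius = np.quantile(residuals, min(q_level, 1.0), axis=0)   # one radius per horizon
    return test_forecasts - radius, test_forecasts + radius

rng = np.random.default_rng(0)
cal_pred, cal_y = rng.normal(size=(200, 5)), rng.normal(size=(200, 5))
lower, upper = conformal_intervals(rng.normal(size=(10, 5)), cal_pred, cal_y)
```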
13. MIRACLE: Causally-Aware Imputation via Learning Missing Data Mechanisms
Trent Kyono, Yao Zhang, Alexis Bellot, Mihaela van der Schaar
Abstract: Missing data is an important problem in machine learning practice. Starting from the premise that imputation methods should preserve the causal structure of the data, we develop a regularization scheme that encourages any baseline imputation method to be causally consistent with the underlying data generating mechanism.
Our proposal is a causally-aware imputation algorithm (MIRACLE). MIRACLE iteratively refines the imputation of a baseline by simultaneously modeling the missingness generating mechanism, encouraging imputation to be consistent with the causal structure of the data.
We conduct extensive experiments on synthetic and a variety of publicly available datasets to show that MIRACLE is able to consistently improve imputation over a variety of benchmark methods across all three missingness scenarios: missing at random, missing completely at random, and missing not at random.
14. DECAF: Generating Fair Synthetic Data Using Causally-Aware Generative Networks
Trent Kyono, Boris van Breugel, Jeroen Berrevoets, Mihaela van der Schaar
Abstract: Machine learning models have been criticized for reflecting unfair biases in the training data. Instead of solving for this by introducing fair learning algorithms directly, we focus on generating fair synthetic data, such that any downstream learner is fair. Generating fair synthetic data from unfair data – while remaining truthful to the underlying data-generating process (DGP) – is non-trivial.
In this paper, we introduce DECAF: a GAN-based fair synthetic data generator for tabular data. With DECAF we embed the DGP explicitly as a structural causal model in the input layers of the generator, allowing each variable to be reconstructed conditioned on its causal parents. This procedure enables inference time debiasing, where biased edges can be strategically removed for satisfying user-defined fairness requirements. The DECAF framework is versatile and compatible with several popular definitions of fairness.
In our experiments, we show that DECAF successfully removes undesired bias and – in contrast to existing methods – is capable of generating high-quality synthetic data. Furthermore, we provide theoretical guarantees on the generator’s convergence and the fairness of downstream models.
15. COVID-19 Sounds: A Large-Scale Audio Dataset for Digital Respiratory Screening
Tong Xia, Dimitris Spathis, J Ch, Andreas Grammenos, Jing Han, Apinan Hasthanasombat, Erika Bondareva, Ting Dang, Andres Floto, Pietro Cicuta, Cecilia Mascolo
Abstract: Audio signals are widely recognised as powerful indicators of overall health status, and there has been increasing interest in leveraging sound for affordable COVID-19 screening through machine learning. However, there has also been scepticism regarding the initial efforts, due perhaps to a lack of reproducibility, large datasets, and transparency, which is unfortunately often an issue with machine learning for health. To facilitate the advancement and openness of audio-based machine learning for respiratory health, we release a dataset consisting of 53,449 audio samples (over 552 hours in total) crowd-sourced from 36,116 participants through our COVID-19 Sounds app. Given its scale, this dataset is comprehensive in terms of demographics and spectrum of health conditions. It also provides participants’ self-reported COVID-19 testing status, with 2,106 samples tested positive.
To the best of our knowledge, COVID-19 Sounds is the largest multi-modal dataset of COVID-19 respiratory sounds: it consists of three modalities including breathing, cough, and voice recordings. Additionally, in this paper, we report on several benchmarks for two principal research tasks: respiratory symptom prediction and COVID-19 prediction. For these tasks we demonstrate performance with a ROC-AUC of over 0.7, confirming both the promise of machine learning approaches based on these types of datasets and the usability of our data for such tasks. We describe a realistic experimental setting that we hope will pave the way to fair performance evaluation of future models. In addition, we reflect on how the released dataset can help to scale some existing studies and enable new research directions, which we hope will inspire and benefit a wide range of future work.
17. Weisfeiler and Lehman Go Cellular: CW Networks
Cristian Bodnar, Fabrizio Frasca, Nina Otter, Yu Guang Wang, Pietro Liò, Guido F Montufar, Michael Bronstein
Abstract: Graph Neural Networks (GNNs) are limited in their expressive power, struggle with long-range interactions and lack a principled way to model higher-order structures. These problems can be attributed to the strong coupling between the computational graph and the input graph structure. The recently proposed Message Passing Simplicial Networks naturally decouple these elements by performing message passing on the clique complex of the graph. Nevertheless, these models can be severely constrained by the rigid combinatorial structure of Simplicial Complexes (SCs).
In this work, we extend recent theoretical results on SCs to regular Cell Complexes, topological objects that flexibly subsume SCs and graphs. We show that this generalisation provides a powerful set of graph “lifting” transformations, each leading to a unique hierarchical message passing procedure. The resulting methods, which we collectively call CW Networks (CWNs), are strictly more powerful than the WL test and no less powerful than the 3-WL test. In particular, we demonstrate the effectiveness of one such scheme, based on rings, when applied to molecular graph problems. The proposed architecture benefits from provably larger expressivity than commonly used GNNs, principled modelling of higher-order signals, and compression of the distances between nodes. We demonstrate that our model achieves state-of-the-art results on a variety of molecular datasets.
17. Functional Variational Inference based on Stochastic Process Generators
Chao Ma, José Miguel Hernández-Lobato
Abstract: Bayesian inference in the space of functions has long been an important topic in Bayesian modeling. In this paper, we propose a new solution to this problem called Functional Variational Inference (FVI). In FVI, we minimize a divergence in function space between the variational distribution and the posterior process. This is done by using as functional variational family a new class of flexible distributions called Stochastic Process Generators (SPGs), which are designed so that the functional ELBO can be estimated efficiently using analytic solutions and mini-batch sampling. FVI can be applied to stochastic process priors when random function samples from those priors are available. Our experiments show that FVI consistently outperforms weight-space and function-space VI methods on several tasks, which validates the effectiveness of our approach.
18. Improving black-box optimization in VAE latent space using decoder uncertainty
Pascal Notin, José Miguel Hernández-Lobato, Yarin Gal
Abstract: Optimization in the latent space of variational autoencoders is a promising approach to generate high-dimensional discrete objects that maximize an expensive black-box property (e.g., drug-likeness in molecular generation, function approximation with arithmetic expressions). However, existing methods lack robustness as they may decide to explore areas of the latent space for which no data was available during training and where the decoder can be unreliable, leading to the generation of unrealistic or invalid objects. We propose to leverage the epistemic uncertainty of the decoder to guide the optimization process. This is not trivial, though, as a naive estimation of uncertainty in the high-dimensional and structured settings we consider would result in high estimator variance.
To solve this problem, we introduce an importance sampling-based estimator that provides more robust estimates of epistemic uncertainty. Our uncertainty-guided optimization approach does not require modifications of the model architecture nor the training process. It produces samples with a better trade-off between black-box objective and validity of the generated samples, sometimes improving both simultaneously. We illustrate these advantages across several experimental settings in digit generation, arithmetic expression approximation and molecule generation for drug design.
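The guiding idea can be illustrated with a deliberately naive sketch (Monte-Carlo dropout as a crude stand-in for the paper’s importance-sampling uncertainty estimator; the surrogate property score and all sizes are assumptions): candidate latent points whose decodings look unreliable are penalized before selection:

```python
# Illustrative only: MC-dropout stands in for the paper's importance-sampling
# estimator, and the property score is a toy surrogate, not a real black box.
import torch
import torch.nn as nn

decoder = nn.Sequential(nn.Linear(8, 64), nn.ReLU(), nn.Dropout(0.1), nn.Linear(64, 32))

def surrogate_score(x):
    return -(x ** 2).sum(dim=-1)            # stand-in for the property model

def decoder_uncertainty(z, n_samples=20):
    decoder.train()                          # keep dropout active at evaluation time
    with torch.no_grad():
        samples = torch.stack([decoder(z) for _ in range(n_samples)])
    return samples.var(dim=0).mean(dim=-1)   # crude epistemic-uncertainty proxy

candidates = torch.randn(256, 8)                             # candidate latent points
with torch.no_grad():
    scores = surrogate_score(decoder(candidates))
penalised = scores - 1.0 * decoder_uncertainty(candidates)   # uncertainty penalty
best_latent = candidates[penalised.argmax()]                 # most promising reliable point
```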