publications
* denotes equal contribution
An up-to-date list is available on Google Scholar.
PhD thesis
preprints
- [arXiv] Few-step Cofolding with All-Atom Flow Maps. G. Scarpellini, R. Shprints, P. Holderrieth, J. Nam, P. Murugan, R. Gómez-Bombarelli, T. Jaakkola, M. Al-Shedivat, N. M. Boffi, and A. J. Bose. arXiv preprint, 2026
All-atom generative modeling of 3D biomolecular complexes has emerged as the dominant paradigm for predicting the structure of proteins and protein-ligand systems. Generating structures at atomic fidelity, however, typically requires expensive iterative diffusion rollouts, making both conventional deployment and inference-time search techniques computationally costly. In this paper, we introduce the Denoiser Cofolding All-Atom Flowmap (DeCAF) framework for distilling state-of-the-art all-atom cofolding models into all-atom flow maps that produce high-quality samples in only a few inference steps. We build DeCAF on a denoiser-based formulation of flow maps with endpoint losses that naturally support SE(3) rigid alignment, which we show is critical for training accurate models. We further derive a simple change of variables that lets DeCAF operate in the sigma-space noise schedule of EDM-style architectures, enabling direct distillation from pretrained cofolding diffusion models. Building on DeCAF’s flow-map lookahead, we introduce a purpose-built inference-time framework that improves sampling through reward-guided search. Empirically, DeCAF-Boltz statistically improves over Boltz-1x in both accuracy (RMSD) and physical validity scores of protein-ligand poses at strict NFE budgets on the challenging Runs N’ Poses benchmark, while also achieving a better Pareto frontier across all inference compute budgets on PoseBusters. When distilled from the state-of-the-art Pearl cofolding model, DeCAF-Pearl outperforms diffusion-based cofolding models and matches its teacher’s success rate while using 5x fewer NFEs.
- [arXiv] Pearl: A Foundation Model for Placing Every Atom in the Right Location. Genesis Research Team. arXiv preprint, 2025. Blog: Release · OpenBind results
Accurately predicting the three-dimensional structures of protein-ligand complexes remains a fundamental challenge in computational drug discovery that limits the pace and success of therapeutic design. Deep learning methods have recently shown strong potential as structural prediction tools, achieving promising accuracy across diverse biomolecular systems. However, their performance and utility are constrained by scarce experimental data, inefficient architectures, physically invalid poses, and a limited ability to exploit auxiliary information available at inference. To address these issues, we introduce Pearl (Placing Every Atom in the Right Location), a foundation model for protein-ligand cofolding at scale. Pearl addresses these challenges with three key innovations: (1) training recipes that include large-scale synthetic data to overcome data scarcity; (2) architectures that incorporate an SO(3)-equivariant diffusion module to inherently respect 3D rotational symmetries, improving generalization and sample efficiency; and (3) controllable inference, including a generalized multi-chain templating system supporting both protein and non-polymeric components as well as dual unconditional/conditional modes. Pearl establishes new state-of-the-art performance in protein-ligand cofolding. On the key metric of generating accurate (RMSD < 2 Å) and physically valid poses, Pearl surpasses AlphaFold 3 and other open-source baselines on the public Runs N’ Poses and PoseBusters benchmarks, delivering 14.5% and 14.2% improvements, respectively, over the next best model. In the pocket-conditional cofolding regime, Pearl delivers a 3.6x improvement on a proprietary set of challenging, real-world drug targets at the more rigorous RMSD < 1 Å threshold. Finally, we demonstrate that model performance correlates directly with the size of the synthetic dataset used in training.
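The headline metric above counts a prediction as a success only if it is both accurate (RMSD < 2 Å) and physically valid. The following is a minimal sketch of that bookkeeping. It rests on several assumptions:

- Predicted and reference ligand coordinates are already in a common frame (for example, superimposed via the protein pocket).
- Both poses share the same atom order.
- Validity was computed beforehand as a boolean. The `is_valid` flag is a hypothetical stand-in for checks such as PoseBusters.

Real evaluations also handle symmetry-equivalent atom mappings, which this sketch omits.

```python
import math


def rmsd(pred, ref):
    """Root-mean-square deviation (in Å) between matched atom coordinates.

    Assumes both poses are in the same frame and share atom order; no
    superposition or symmetry correction is performed here."""
    if len(pred) != len(ref):
        raise ValueError("poses must have the same number of atoms")
    sq = sum((p - r) ** 2 for xp, xr in zip(pred, ref) for p, r in zip(xp, xr))
    return math.sqrt(sq / len(ref))


def success_rate(poses, threshold=2.0):
    """Fraction of (pred, ref, is_valid) triples that are both accurate
    (RMSD < threshold) and flagged as physically valid."""
    hits = sum(1 for pred, ref, is_valid in poses
               if is_valid and rmsd(pred, ref) < threshold)
    return hits / len(poses)


ref = [(0.0, 0.0, 0.0), (1.0, 0.0, 0.0)]
poses = [
    ([(0.0, 0.0, 1.0), (1.0, 0.0, 1.0)], ref, True),   # RMSD 1.0, valid
    ([(0.0, 0.0, 3.0), (1.0, 0.0, 3.0)], ref, True),   # RMSD 3.0, too far
    ([(0.0, 0.0, 1.0), (1.0, 0.0, 1.0)], ref, False),  # accurate but invalid
]
print(success_rate(poses))  # 1 of 3 poses counts as a success
```

Requiring both conditions is what separates this metric from plain RMSD accuracy: the third pose is close to the reference but is not counted.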
- [arXiv] A Field Guide to Federated Optimization. arXiv preprint, 2021
Federated learning and analytics are a distributed approach for collaboratively learning models (or statistics) from decentralized data, motivated by and designed for privacy protection. The distributed learning process can be formulated as solving federated optimization problems, which emphasize communication efficiency, data heterogeneity, compatibility with privacy and system requirements, and other constraints that are not primary considerations in other problem settings. This paper provides recommendations and guidelines on formulating, designing, evaluating and analyzing federated optimization algorithms through concrete examples and practical implementation, with a focus on conducting effective simulations to infer real-world performance. The goal of this work is not to survey the current literature, but to inspire researchers and practitioners to design federated learning algorithms that can be used in various practical applications.
conference & journal papers
2026
- Triangle Multiplication Is All You Need for Biomolecular Structure Representations. J. Ouyang-Zhang, P. Murugan, D. J. Diaz, G. Scarpellini, R. S. Bowen, N. Gruver, A. Klivans, P. Krähenbühl, A. Faust, and M. Al-Shedivat. In International Conference on Learning Representations (ICLR), 2026
AlphaFold has transformed protein structure prediction, but emerging applications such as virtual ligand screening, proteome-wide folding, and de novo binder design demand predictions at a massive scale, where runtime and memory costs become prohibitive. A major bottleneck lies in the Pairformer backbone of AlphaFold3-style models, which relies on computationally expensive triangular primitives—especially triangle attention—for pairwise reasoning. We introduce Pairmixer, a streamlined alternative that eliminates triangle attention while preserving higher-order geometric reasoning capabilities that are critical for structure prediction. Pairmixer substantially improves computational efficiency, matching state-of-the-art structure predictors across folding and docking benchmarks, delivering up to 4x faster inference on long sequences while reducing training cost by 34%. Its efficiency alleviates the computational burden of downstream applications such as modeling large protein complexes, high-throughput ligand and binder screening, and hallucination-based design. Within BoltzDesign, for example, Pairmixer delivers over 2x faster sampling and scales to sequences 30% longer than the memory limits of Pairformer.
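The title suggests that Pairmixer keeps triangle multiplication as its pairwise-reasoning primitive while dropping triangle attention. A small sketch of that primitive may therefore help. It follows the AlphaFold2-style "outgoing" and "incoming" contractions on an N × N × C pair representation, with these simplifications:

- The gated projections `a` and `b` are assumed to be already computed from the pair features.
- Layer norms, learned projections and the output gate of a full block are omitted.

This is not Pairmixer's implementation. The point is only that each pair (i, j) is updated by summing over a third node k, with no attention softmax involved.

```python
def triangle_multiply(a, b, outgoing=True):
    """Core contraction of a triangle multiplicative update.

    a, b: N x N x C nested lists (already projected and gated pair features).
    outgoing: out[i][j][c] = sum_k a[i][k][c] * b[j][k][c]  (edges i-k and j-k)
    incoming: out[i][j][c] = sum_k a[k][i][c] * b[k][j][c]  (edges k-i and k-j)
    """
    n, channels = len(a), len(a[0][0])
    out = [[[0.0] * channels for _ in range(n)] for _ in range(n)]
    for i in range(n):
        for j in range(n):
            for k in range(n):  # third node closing the triangle (i, j, k)
                for c in range(channels):
                    if outgoing:
                        out[i][j][c] += a[i][k][c] * b[j][k][c]
                    else:
                        out[i][j][c] += a[k][i][c] * b[k][j][c]
    return out


# One channel, N = 2: per channel, "outgoing" is A @ B^T and "incoming" is A^T @ B.
a = [[[1.0], [2.0]], [[3.0], [4.0]]]
b = [[[5.0], [6.0]], [[7.0], [8.0]]]
print(triangle_multiply(a, b, outgoing=True))   # [[[17.0], [23.0]], [[39.0], [53.0]]]
print(triangle_multiply(a, b, outgoing=False))  # [[[26.0], [30.0]], [[38.0], [44.0]]]
```

Per channel, the update is just a matrix product, which is why it maps well onto fast batched kernels.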
2021
- [EMNLP, Oral] Knowledge-Aware Meta-learning for Low-Resource Text Classification. H. Yao, Y.-X. Wu, M. Al-Shedivat, and E. Xing. In Conference on Empirical Methods in Natural Language Processing (EMNLP), 2021
Meta-learning has achieved great success in leveraging previously learned knowledge to facilitate learning new tasks. However, current meta-learning algorithms learn only from historical tasks, and this knowledge may not generalize well to test tasks that are not well supported by the training tasks. This paper studies low-resource text classification and bridges the gap between meta-training and meta-testing tasks by leveraging external knowledge bases. Specifically, we propose KGML, which introduces an additional representation for each sentence, learned from an extracted sentence-specific knowledge graph. Extensive experiments on three datasets demonstrate the effectiveness of KGML in both supervised and unsupervised adaptation settings.
- [NAACL, Oral] Progressive Generation of Long Text with Pretrained Language Models. In Annual Conference of the North American Chapter of the Association for Computational Linguistics (NAACL), 2021
Large-scale language models pretrained on massive corpora of text, such as GPT-2, are powerful open-domain text generators. However, as our systematic examination reveals, it is still challenging for such models to generate coherent long passages of text (1000 tokens), especially when the models are fine-tuned to the target domain on a small corpus. To overcome the limitation, we propose a simple but effective method of generating text in a progressive manner, inspired by generating images from low to high resolution. Our method first produces domain-specific content keywords and then progressively refines them into complete passages in multiple stages. The simple design allows our approach to take advantage of pretrained language models at each stage and effectively adapt to any target domain given only a small set of examples. We conduct a comprehensive empirical study with a broad set of evaluation metrics, and show that our approach significantly improves upon the fine-tuned GPT-2 in terms of domain-specific quality and sample efficiency. The coarse-to-fine nature of progressive generation also allows for a higher degree of control over the generated content.
- Federated Learning via Posterior Averaging: A New Perspective and Practical Algorithms. M. Al-Shedivat, J. Gillenwater, E. Xing, and A. Rostamizadeh. In International Conference on Learning Representations (ICLR), 2021
Federated learning is typically approached as an optimization problem, where the goal is to minimize a global loss function by distributing computation across client devices that possess local data and specify different parts of the global objective. We present an alternative perspective and formulate federated learning as a posterior inference problem, where the goal is to infer a global posterior distribution by having client devices each infer the posterior of their local data. While exact inference is often intractable, this perspective provides a principled way to search for global optima in federated settings. Further, starting with the analysis of federated quadratic objectives, we develop a computation- and communication-efficient approximate posterior inference algorithm—federated posterior averaging (FedPA). Our algorithm uses MCMC for approximate inference of local posteriors on the clients and efficiently communicates their statistics to the server, where the latter uses them to refine a global estimate of the posterior mode. Finally, we show that FedPA generalizes federated averaging (FedAvg), can similarly benefit from adaptive optimizers, and yields state-of-the-art results on four realistic and challenging benchmarks, converging faster, to better optima.
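The quadratic case that motivates the posterior-inference view can be checked by hand. Suppose client i's local posterior is Gaussian with mean μ_i and precision A_i (equivalently, its local objective is quadratic). Then the global posterior is their product, and its mode is (Σ_i A_i)^{-1} Σ_i A_i μ_i rather than the plain average of the μ_i.

The 1-D sketch below contrasts the two using made-up numbers for two hypothetical clients. It illustrates the perspective only. It is not FedPA, which approximates the local statistics with MCMC on the clients.

```python
def global_posterior_mode(means, precisions):
    """Mode of the product of 1-D Gaussian local posteriors N(mu_i, 1/a_i).

    Equivalently, the minimiser of sum_i a_i / 2 * (x - mu_i) ** 2, i.e. the
    optimum of the sum of the clients' quadratic objectives."""
    total_precision = sum(precisions)
    return sum(a * m for a, m in zip(precisions, means)) / total_precision


def uniform_average(means):
    """Equal-weight average of client optima, ignoring their curvature."""
    return sum(means) / len(means)


# Two hypothetical clients: the second has a sharper (more confident) posterior.
means = [0.0, 10.0]
precisions = [1.0, 3.0]
print(global_posterior_mode(means, precisions))  # 7.5
print(uniform_average(means))                    # 5.0
```

The gap between 7.5 and 5.0 is the bias that equal-weight averaging of client solutions can introduce when clients' objectives have different curvature.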
- [AISTATS] On Data Efficiency of Meta-learning. In International Conference on Artificial Intelligence and Statistics (AISTATS), 2021
Meta-learning has enabled learning statistical models that can be quickly adapted to new prediction tasks. Motivated by use-cases in personalized federated learning, we study the often overlooked aspect of the modern meta-learning algorithms—their data efficiency. To shed more light on which methods are more efficient, we use techniques from algorithmic stability to derive bounds on the transfer risk that have important practical implications, indicating how much supervision is needed and how it must be allocated for each method to attain the desired level of generalization. Further, we introduce a new simple framework for evaluating meta-learning methods under a limit on the available supervision, conduct an empirical study of MAML, Reptile, and ProtoNets, and demonstrate the differences in the behavior of these methods on few-shot and federated learning benchmarks. Finally, we propose active meta-learning, which incorporates active data selection into learning-to-learn, leading to better performance of all methods in the limited supervision regime.
2020
- Regularizing Black-box Models for Improved Interpretability. In Advances in Neural Information Processing Systems (NeurIPS), 2020
Most work on interpretable machine learning has focused on designing either inherently interpretable models, which typically trade off accuracy for interpretability, or post-hoc explanation systems, whose explanation quality can be unpredictable. Our method, ExpO, is a hybridization of these approaches that regularizes a model for explanation quality at training time. Importantly, these regularizers are differentiable, model-agnostic, and require no domain knowledge to define. We demonstrate that post-hoc explanations for ExpO-regularized models have better explanation quality, as measured by the common fidelity and stability metrics. We verify that improving these metrics leads to significantly more useful explanations with a user study on a realistic task.
- Contextual Explanation Networks. M. Al-Shedivat, A. Dubey, and E. P. Xing. Journal of Machine Learning Research (JMLR), 2020. Press: NLP Highlights.
Modern learning algorithms excel at producing accurate but complex models of the data. However, deploying such models in the real world requires extra care: we must ensure their reliability, robustness, and absence of undesired biases. This motivates the development of models that are equally accurate but can also be easily inspected and assessed beyond their predictive performance. To this end, we introduce contextual explanation networks (CENs), a class of architectures that learn to predict by generating and utilizing intermediate, simplified probabilistic models. Specifically, CENs generate parameters for intermediate graphical models which are further used for prediction and play the role of explanations. Contrary to existing post-hoc model-explanation tools, CENs learn to predict and to explain jointly. Our approach offers two major advantages: (i) for each prediction, valid, instance-specific explanations are generated with no computational overhead, and (ii) prediction via explanation acts as a regularizer and boosts performance in low-resource settings. We analyze the proposed framework theoretically and experimentally. Our results on image and text classification and survival analysis tasks demonstrate that CENs are not only competitive with state-of-the-art methods but also offer additional insights behind each prediction, which are valuable for decision support. We also show that while post-hoc methods may produce misleading explanations in certain cases, CENs are always consistent and allow such cases to be detected systematically.
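In the simplest case, the intermediate model a CEN generates could be a logistic regression over interpretable features, with weights produced from the context. The sketch below assumes a hypothetical linear context encoder (`enc_weights`) for clarity; a real CEN would use a learned deep encoder and possibly richer explanation families. The generated weights `theta` serve both as the parameters that make the prediction and as that prediction's instance-specific explanation.

```python
import math


def sigmoid(z):
    return 1.0 / (1.0 + math.exp(-z))


def context_encoder(context, enc_weights):
    """Hypothetical encoder: a fixed linear map from the context to the
    weights of the explanation model (a real CEN would learn a deep network)."""
    return [sum(w * c for w, c in zip(row, context)) for row in enc_weights]


def cen_predict(context, features, enc_weights):
    """Predict via the generated explanation: logistic regression on the
    interpretable features with context-dependent weights theta."""
    theta = context_encoder(context, enc_weights)
    prob = sigmoid(sum(t * x for t, x in zip(theta, features)))
    return prob, theta


enc_weights = [[1.0, 0.0], [0.0, -1.0]]
prob, theta = cen_predict([2.0, 0.5], [1.0, 2.0], enc_weights)
print(theta)  # [2.0, -0.5]: per-feature weights explaining this prediction
print(prob)   # sigmoid(2.0 * 1.0 - 0.5 * 2.0) = sigmoid(1.0)
```

Because the explanation is computed on the way to the prediction, producing it adds no separate post-hoc step.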
2019
- [NAACL, Full Oral] Consistency by Agreement in Zero-shot Neural Machine Translation. M. Al-Shedivat and A. P. Parikh. In Annual Conference of the North American Chapter of the Association for Computational Linguistics (NAACL), 2019
Generalization and reliability of multilingual translation often depend heavily on the amount of available parallel data for each language pair of interest. In this paper, we focus on zero-shot generalization, a challenging setup that tests models on translation directions they have not been optimized for at training time. To solve the problem, we (i) reformulate multilingual translation as probabilistic inference, (ii) define the notion of zero-shot consistency and show why standard training often results in models unsuitable for zero-shot tasks, and (iii) introduce a consistent agreement-based training method that encourages the model to produce equivalent translations of parallel sentences in auxiliary languages. We test our multilingual NMT models on multiple public zero-shot translation benchmarks (IWSLT17, UN corpus, Europarl) and show that agreement-based learning often yields zero-shot improvements of 2-3 BLEU over strong baselines without any loss in performance on supervised translation directions.
- A Baseline for Any Order Gradient Estimation in Stochastic Computation Graphs. In International Conference on Machine Learning (ICML), 2019
By enabling correct differentiation in stochastic computation graphs (SCGs), the infinitely differentiable Monte-Carlo estimator (DiCE) can generate correct estimates for the higher-order gradients that arise in, e.g., multi-agent reinforcement learning and meta-learning. However, the baseline term in DiCE that serves as a control variate for reducing variance applies only to first-order gradient estimation, limiting the utility of higher-order gradient estimates. To improve the sample efficiency of DiCE, we propose a new baseline term for higher-order gradient estimation. This term may be easily included in the objective, and produces unbiased variance-reduced estimators under (automatic) differentiation, without affecting the estimate of the objective itself or of the first-order gradient estimate. It reuses the same baseline function (e.g., the state-value function in reinforcement learning) already used for the first-order baseline. We provide theoretical analysis and numerical evaluations of this new baseline, which demonstrate that it can dramatically reduce the variance of DiCE’s second-order gradient estimators, and also show empirically that it reduces the variance of third- and fourth-order gradients. This computational tool can be easily used to estimate higher-order gradients with unprecedented efficiency and simplicity wherever automatic differentiation is utilised, and it has the potential to unlock applications of higher-order gradients in reinforcement learning and meta-learning.
2018
- [ICML, Full Oral] Learning Policy Representations in Multiagent Systems. In International Conference on Machine Learning (ICML), 2018
Modeling agent behavior is central to understanding the emergence of complex phenomena in multiagent systems. Prior work in agent modeling has largely been task-specific and driven by hand-engineering domain-specific prior knowledge. We propose a general learning framework for modeling agent behavior in any multiagent system using only a handful of interaction data. Our framework casts agent modeling as a representation learning problem. Consequently, we construct a novel objective inspired by imitation learning and agent identification and design an algorithm for unsupervised learning of representations of agent policies. We demonstrate empirically the utility of the proposed framework in (i) a challenging high-dimensional competitive environment for continuous control and (ii) a cooperative environment for communication, on supervised predictive tasks, unsupervised clustering, and policy optimization using deep reinforcement learning.
- [ICML, Full Oral] DiCE: The Infinitely Differentiable Monte-Carlo Estimator. In International Conference on Machine Learning (ICML), 2018
The score function estimator is widely used for estimating gradients of stochastic objectives in Stochastic Computation Graphs (SCG), e.g., in reinforcement learning and meta-learning. While deriving the first-order gradient estimators by differentiating a surrogate loss (SL) objective is computationally and conceptually simple, using the same approach for higher-order gradients is more challenging. Firstly, analytically deriving and implementing such estimators is laborious and not compliant with automatic differentiation. Secondly, repeatedly applying SL to construct new objectives for each gradient order involves increasingly cumbersome graph manipulations. Lastly, to match the first-order gradient under differentiation, SL treats part of the cost as a fixed sample, which we show leads to missing and wrong terms for higher-order gradient estimators. To address all these shortcomings in a unified way, we introduce DiCE, which provides a single objective that can be differentiated repeatedly, generating correct gradient estimators of any order in SCGs. Unlike SL, DiCE relies on automatic differentiation for performing the requisite graph manipulations. We verify the correctness of DiCE both through a proof and through numerical evaluation of the DiCE gradient estimates. We also use DiCE to propose and evaluate a novel approach for multi-agent learning. Our code is available at https://goo.gl/xkkGxN.
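The core of DiCE is the "magic box" operator ⊡(τ) = exp(τ − ⊥(τ)), where ⊥ stops gradients and τ is the log-probability of the stochastic nodes. The operator evaluates to 1, but every differentiation multiplies in the derivative of τ.

The self-contained sketch below makes several assumptions:

- It uses a hypothetical one-node graph, x ~ Bernoulli(θ²), with cost 3 for x = 1 and cost 1 for x = 0. The true objective is therefore 1 + 2θ².
- To keep it deterministic, sampling is replaced by exact enumeration, with each outcome weighted by its stopped probability.
- A small nested dual-number class stands in for an autodiff framework.

Repeated differentiation should then recover the exact first and second derivatives, 4θ and 4.

```python
import math


class Dual:
    """Dual number val + eps*e with e**2 == 0. Components may themselves be
    Duals, so nesting two levels gives exact second derivatives."""

    def __init__(self, val, eps=0.0):
        self.val, self.eps = val, eps

    def __add__(self, other):
        other = _lift(other)
        return Dual(self.val + other.val, self.eps + other.eps)

    __radd__ = __add__

    def __sub__(self, other):
        other = _lift(other)
        return Dual(self.val - other.val, self.eps - other.eps)

    def __rsub__(self, other):
        return _lift(other) - self

    def __mul__(self, other):
        other = _lift(other)
        return Dual(self.val * other.val,
                    self.val * other.eps + self.eps * other.val)

    __rmul__ = __mul__


def _lift(x):
    # Plain floats are constants with zero derivative.
    return x if isinstance(x, Dual) else Dual(x, 0.0)


def exp(x):
    if isinstance(x, Dual):
        e = exp(x.val)
        return Dual(e, e * x.eps)
    return math.exp(x)


def recip(x):
    if isinstance(x, Dual):
        r = recip(x.val)
        return Dual(r, x.eps * r * r * -1.0)  # d(1/x) = -dx / x**2
    return 1.0 / x


def log(x):
    if isinstance(x, Dual):
        return Dual(log(x.val), x.eps * recip(x.val))
    return math.log(x)


def stop_gradient(x):
    # Keep the value, drop derivative components at every nesting level.
    if isinstance(x, Dual):
        return Dual(stop_gradient(x.val), 0.0)
    return x


def magic_box(tau):
    # Evaluates to 1; each derivative multiplies in d(tau), at any order.
    return exp(tau - stop_gradient(tau))


def expected_dice_objective(theta):
    """Hypothetical toy SCG: x ~ Bernoulli(theta**2), cost 3 if x = 1 else 1.
    Exact enumeration weighted by stopped probabilities replaces sampling."""
    total = 0.0
    for p, cost in ((theta * theta, 3.0), (1.0 - theta * theta, 1.0)):
        total = total + stop_gradient(p) * magic_box(log(p)) * cost
    return total


first = expected_dice_objective(Dual(0.5, 1.0))
second = expected_dice_objective(Dual(Dual(0.5, 1.0), Dual(1.0, 0.0)))
print(first.val, first.eps)  # ≈ 1.5 and 2.0: E = 1 + 2θ², dE/dθ = 4θ at θ = 0.5
print(second.eps.eps)        # ≈ 4.0: d²E/dθ²
```

The same objective is simply differentiated again for the second order. No new surrogate has to be constructed, which is the property the abstract emphasizes.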
- Learning with Opponent-Learning Awareness. In International Conference on Autonomous Agents and Multiagent Systems (AAMAS), 2018
Multi-agent settings are quickly gathering importance in machine learning. Beyond a plethora of recent work on deep multi-agent reinforcement learning, hierarchical reinforcement learning, generative adversarial networks and decentralized optimization can all be seen as instances of this setting. However, the presence of multiple learning agents in these settings renders the training problem non-stationary and often leads to unstable training or undesired final results. We present Learning with Opponent-Learning Awareness (LOLA), a method that reasons about the anticipated learning of the other agents. The LOLA learning rule includes an additional term that accounts for the impact of the agent’s policy on the anticipated parameter update of the other agents. We show that the LOLA update rule can be efficiently calculated using an extension of the likelihood ratio policy gradient update, making the method suitable for model-free RL. This method thus scales to large parameter and input spaces and nonlinear function approximators. Preliminary results show that the encounter of two LOLA agents leads to the emergence of tit-for-tat and therefore cooperation in the iterated prisoners’ dilemma (IPD), while independent learning does not. In this domain, LOLA also receives higher payouts than a naive learner, and is robust against exploitation by higher-order gradient-based methods. Applied to infinitely repeated matching pennies, LOLA agents converge to the Nash equilibrium. In a round robin tournament we show that LOLA agents can successfully shape the learning of a range of multi-agent learning algorithms from the literature, resulting in the highest average returns on the IPD. We also apply LOLA to a grid world task with an embedded social dilemma using deep recurrent policies. Again, by considering the learning of the other agent, LOLA agents learn to cooperate out of selfish interests.
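The LOLA rule can be illustrated with exact values in a hypothetical two-player bilinear game, V1(a, b) = a·b and V2 = −a·b. This is a continuous analogue of matching pennies, and its Nash equilibrium is (0, 0).

Following the first-order form of the update, agent 1 ascends ∂V1/∂a + η (∂V1/∂b)(∂²V2/∂a∂b), and agent 2 does the symmetric thing. This sketch uses finite differences in place of the paper's policy-gradient estimator, and the step sizes are illustrative assumptions. In this toy game, naive simultaneous gradient ascent should spiral away from the equilibrium, while the LOLA correction should pull both players toward it.

```python
import math


def partial(f, a, b, wrt, h=1e-4):
    """Central-difference partial derivative of f(a, b) w.r.t. 'a' or 'b'."""
    if wrt == "a":
        return (f(a + h, b) - f(a - h, b)) / (2 * h)
    return (f(a, b + h) - f(a, b - h)) / (2 * h)


def mixed_partial(f, a, b, h=1e-4):
    """Central-difference estimate of d^2 f / (da db)."""
    return (f(a + h, b + h) - f(a + h, b - h)
            - f(a - h, b + h) + f(a - h, b - h)) / (4 * h * h)


def V1(a, b):  # agent 1 maximises V1
    return a * b


def V2(a, b):  # agent 2 maximises V2 (zero-sum)
    return -a * b


def step(a, b, delta, eta, lola):
    """One simultaneous update; lola=False gives naive gradient ascent."""
    grad_a = partial(V1, a, b, "a")
    grad_b = partial(V2, a, b, "b")
    if lola:
        # Each agent accounts for how its parameter shapes the other agent's
        # anticipated learning step of size eta.
        grad_a += eta * partial(V1, a, b, "b") * mixed_partial(V2, a, b)
        grad_b += eta * partial(V2, a, b, "a") * mixed_partial(V1, a, b)
    return a + delta * grad_a, b + delta * grad_b


def run(lola, steps=300, delta=0.1, eta=1.0):
    a, b = 1.0, 1.0
    for _ in range(steps):
        a, b = step(a, b, delta, eta, lola)
    return math.hypot(a, b)  # distance from the equilibrium (0, 0)


print(run(lola=False))  # grows: naive learners spiral outward
print(run(lola=True))   # shrinks toward 0
```

In this game, the extra term works out to −η times the agent's own parameter. It acts as a damping force, which is why the LOLA iterates contract instead of cycling.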
- [ICLR, Best Paper Award] Continuous Adaptation via Meta-Learning in Nonstationary and Competitive Environments. In International Conference on Learning Representations (ICLR), 2018
The ability to continuously learn and adapt from limited experience in nonstationary environments is an important milestone on the path towards general intelligence. In this paper, we cast the problem of continuous adaptation into the learning-to-learn framework. We develop a simple gradient-based meta-learning algorithm suitable for adaptation in dynamically changing and adversarial scenarios. Additionally, we design a new multi-agent competitive environment, RoboSumo, and define iterated adaptation games for testing various aspects of continuous adaptation strategies. We demonstrate that meta-learning enables significantly more efficient adaptation than reactive baselines in the few-shot regime. Our experiments with a population of agents that learn and compete suggest that meta-learners are the fittest.
2017
- Learning Scalable Deep Kernels with Recurrent Structure. Journal of Machine Learning Research (JMLR), 2017
Many applications in speech, robotics, finance, and biology deal with sequential data, where ordering matters and recurrent structures are common. However, this structure cannot be easily captured by standard kernel functions. To model such structure, we propose expressive closed-form kernel functions for Gaussian processes. The resulting model, GP-LSTM, fully encapsulates the inductive biases of long short-term memory (LSTM) recurrent networks, while retaining the non-parametric probabilistic advantages of Gaussian processes. We learn the properties of the proposed kernels by optimizing the Gaussian process marginal likelihood using a new provably convergent semi-stochastic procedure and exploit the structure of these kernels for fast and scalable training and prediction. We demonstrate state-of-the-art performance on several benchmarks, and thoroughly investigate a consequential autonomous driving application, where the predictive uncertainties provided by GP-LSTM are uniquely valuable.
2016
- Learning HMMs with Nonparametric Emissions via Decompositions of Continuous Matrices. In Advances in Neural Information Processing Systems (NeurIPS), 2016
Recently, there has been a surge of interest in using spectral methods for estimating latent variable models. However, it is usually assumed that the distribution of the observations conditioned on the latent variables is either discrete or belongs to a parametric family. In this paper, we study the estimation of an m-state hidden Markov model (HMM) with only smoothness assumptions, such as Hölderian conditions, on the emission probabilities. By leveraging some recent advances in continuous linear algebra and numerical analysis, we develop a computationally efficient spectral algorithm for learning nonparametric HMMs. Our technique is based on computing an SVD on nonparametric estimates of density functions by viewing them as continuous matrices. We derive sample complexity bounds via concentration results for nonparametric density estimation and novel perturbation theory results for these continuous matrices. We implement our method using Chebyshev polynomial approximations. Our method is competitive with other baselines on synthetic and real problems and is also very computationally efficient.
- ADIOS: Architectures Deep In Output Space. M. Cissé, M. Al-Shedivat, and S. Bengio. In International Conference on Machine Learning (ICML), 2016
Multi-label classification is a generalization of binary classification where the task consists in predicting sets of labels. With the availability of ever larger datasets, the multi-label setting has become a natural one in many applications, and the interest in solving multi-label problems has grown significantly. As expected, deep learning approaches are now yielding state-of-the-art performance for this class of problems. Unfortunately, they usually do not take into account the often unknown but nevertheless rich relationships between labels. In this paper, we propose to make use of this underlying structure by learning to partition the labels into a Markov Blanket Chain and then applying a novel deep architecture that exploits the partition. Experiments on several popular and large multi-label datasets demonstrate that our approach not only yields significant improvements, but also helps to overcome trade-offs specific to the multi-label classification setting.
- [Frontiers] Stochastic Synapses Enable Efficient Brain-Inspired Learning Machines. Frontiers in Neuroscience, 2016
Recent studies have shown that synaptic unreliability is a robust and sufficient mechanism for inducing the stochasticity observed in cortex. Here, we introduce Synaptic Sampling Machines (S2Ms), a class of neural network models that uses synaptic stochasticity as a means to Monte Carlo sampling and unsupervised learning. Similar to the original formulation of Boltzmann machines, these models can be viewed as a stochastic counterpart of Hopfield networks, but where stochasticity is induced by a random mask over the connections. Synaptic stochasticity plays the dual role of an efficient mechanism for sampling, and a regularizer during learning akin to DropConnect. A local synaptic plasticity rule implementing an event-driven form of contrastive divergence enables the learning of generative models in an on-line fashion. S2Ms perform equally well using discrete-timed artificial units (as in Hopfield networks) or continuous-timed leaky integrate and fire neurons. The learned representations are remarkably sparse and robust to reductions in bit precision and synapse pruning: removal of more than 75% of the weakest connections followed by cursory re-learning causes a negligible performance loss on benchmark classification tasks. The spiking neuron-based S2Ms outperform existing spike-based unsupervised learners, while potentially offering substantial advantages in terms of power and complexity, and are thus promising models for on-line learning in brain-inspired hardware.
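The sampling mechanism described above can be sketched with binary threshold units. On every update, each synapse independently transmits with probability p, giving a fresh random mask as in DropConnect. This turns a deterministic Hopfield-style threshold rule into a stochastic sampler.

The weights, biases, sequential sweep order and the "≥ 0" firing threshold below are illustrative assumptions rather than the paper's exact model. The event-driven learning rule and the spiking variant are not shown.

```python
import random


def s2m_sweep(z, weights, biases, p, rng):
    """One sequential sweep over binary units with stochastic synapses.

    For unit i, every incoming connection j -> i is kept independently with
    probability p; the unit turns on if its masked input is non-negative."""
    z = list(z)
    for i in range(len(z)):
        u = biases[i]
        for j in range(len(z)):
            if j != i and rng.random() < p:  # Bernoulli(p) synaptic mask
                u += weights[i][j] * z[j]
        z[i] = 1 if u >= 0 else 0
    return z


rng = random.Random(0)
W = [[0.0, -2.0], [-2.0, 0.0]]  # hypothetical mutually inhibitory pair
b = [1.0, -0.5]
print(s2m_sweep([0, 1], W, b, p=1.0, rng=rng))  # every synapse transmits: [0, 0]
print(s2m_sweep([0, 1], W, b, p=0.0, rng=rng))  # all synapses blanked out: [1, 0]
samples = [tuple(s2m_sweep([0, 1], W, b, p=0.5, rng=rng)) for _ in range(5)]
print(samples)  # stochastic: unit 0 fires only when its inhibitory synapse drops out
```

With p = 1, the sweep reduces to a deterministic threshold network. Intermediate values of p make each update a random draw, which is the sampling role the abstract attributes to synaptic unreliability.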
2015
- Stochasticity Modeling in Memristors. R. Naous, M. Al-Shedivat, and K. N. Salama. IEEE Transactions on Nanotechnology, 2015
Diverse models have been proposed over the past years to explain the behavior exhibited by memristors, the fourth fundamental circuit element. The models vary in complexity, ranging from descriptions of physical mechanisms to more generalized mathematical modeling. Nonetheless, stochasticity, a widely observed phenomenon, has been largely overlooked from the modeling perspective. This inherent variability within the operation of the memristor is a vital feature for integrating this nonlinear device into the realm of stochastic electronics. In this paper, experimentally observed innate stochasticity is modeled in a circuit-compatible format. The proposed model is generic and could be incorporated into variants of threshold-based memristor models in which apparent variations in the output hysteresis convey the switching threshold shift. Its further application as an alternative means of noise injection paves the way for novel approaches in neuromorphic circuit design. On the other hand, extra caution is needed in variability-intolerant digital designs based on nondeterministic memristor logic.
- Memristors Empower Spiking Neurons With Stochasticity. M. Al-Shedivat, R. Naous, G. Cauwenberghs, and K. N. Salama. IEEE Journal on Emerging and Selected Topics in Circuits and Systems, 2015
Recent theoretical studies have shown that probabilistic spiking can be interpreted as learning and inference in cortical microcircuits. This interpretation creates new opportunities for building neuromorphic systems driven by probabilistic learning algorithms. However, such systems must have two crucial features: 1) the neurons should follow a specific behavioral model, and 2) stochastic spiking should be implemented efficiently for it to be scalable. This paper proposes a memristor-based stochastically spiking neuron that fulfills these requirements. First, the analytical model of the memristor is enhanced so it can capture the behavioral stochasticity consistent with experimentally observed phenomena. The switching behavior of the memristor model is demonstrated to be akin to the firing of the stochastic spike response neuron model, the primary building block for probabilistic algorithms in spiking neural networks. Furthermore, the paper proposes a neural soma circuit that utilizes the intrinsic nondeterminism of memristive switching for efficient spike generation. The simulations and analysis of the behavior of a single stochastic neuron and a winner-take-all network built of such neurons and trained on handwritten digits confirm that the circuit can be used for building probabilistic sampling and pattern adaptation machinery in spiking networks. The findings constitute an important step towards scalable and efficient probabilistic neuromorphic platforms.
- Inherently Stochastic Spiking Neurons for Probabilistic Neural Computation. In International IEEE/EMBS Conference on Neural Engineering (NER), 2015
Neuromorphic engineering aims to design hardware that efficiently mimics neural circuitry and provides the means for emulating and studying neural systems. In this paper, we propose a new memristor-based neuron circuit that uniquely complements the scope of neuron implementations and follows the stochastic spike response model (SRM), which plays a cornerstone role in spike-based probabilistic algorithms. We demonstrate that the switching of the memristor is akin to the stochastic firing of the SRM. Our analysis and simulations show that the proposed neuron circuit satisfies a neural computability condition that enables probabilistic neural sampling and spike-based Bayesian learning and inference. Our findings constitute an important step towards memristive, scalable and efficient stochastic neuromorphic platforms.
- Learning Non-deterministic Representations with Energy-based Ensembles. M. Al-Shedivat, E. Neftci, and G. Cauwenberghs. In International Conference on Learning Representations (ICLR), workshop track, 2015
The goal of a generative model is to capture the distribution underlying the data, typically through latent variables. After training, these variables are often used as a new representation, more effective than the original features in a variety of learning tasks. However, the representations constructed by contemporary generative models are usually point-wise deterministic mappings from the original feature space. Thus, even with representations robust to class-specific transformations, statistically driven models trained on them would not be able to generalize when the labeled data is scarce. Inspired by the stochasticity of the synaptic connections in the brain, we introduce Energy-based Stochastic Ensembles. These ensembles can learn non-deterministic representations, i.e., mappings from the feature space to a family of distributions in the latent space. These mappings are encoded in a distribution over a (possibly infinite) collection of models. By conditionally sampling models from the ensemble, we obtain multiple representations for every input example and effectively augment the data. We propose an algorithm similar to contrastive divergence for training restricted Boltzmann stochastic ensembles. Finally, we demonstrate the concept of the stochastic representations on a synthetic dataset as well as test them in the one-shot learning scenario on MNIST.
2014
- Supervised Transfer Sparse Coding. M. Al-Shedivat, J. J.-Y. Wang, M. Alzahrani, J. Z. Huang, and X. Gao. In AAAI Conference on Artificial Intelligence, 2014
Combining sparse coding with transfer learning has been shown to be accurate and robust in classification tasks where training and test objects share a feature space but are sampled from different underlying distributions, i.e., belong to different domains. The key assumption in such cases is that, despite the domain disparity, samples from different domains share some common hidden factors. Previous methods often assumed that all objects in the target domain are unlabeled, so the training set consisted solely of objects from the source domain. In real-world applications, however, the target domain often contains some labeled objects, or a small number of them can always be labeled manually. In this paper, we explore this possibility and show how a small amount of labeled data in the target domain can significantly improve the classification accuracy of state-of-the-art transfer sparse coding methods. We further propose a unified framework named supervised transfer sparse coding (STSC), which simultaneously optimizes sparse representation, domain transfer, and classification. Experimental results on three applications demonstrate that a small amount of manual labeling, followed by learning the model in a supervised fashion, can significantly improve classification accuracy.
technical reports & short papers
- medRxiv · Discriminative Subtyping of Lung Cancers from Histopathology Images via Contextual Deep Learning. medRxiv preprint, 2020
When designing individualized treatment protocols for cancer patients, clinicians must synthesize the information from multiple data modalities into a single parsimonious description of the patient’s personal disease. However, such a description of a patient is never observed. In this work, we propose to model these patient descriptions as latent discriminative subtypes—sample representations which can be learned from one data modality and used to contextualize predictions based on another data modality. We apply contextual deep learning to learn these sample-specific discriminative subtypes from lung cancer histopathology imagery. Based on these subtypes, we produce sample-specific transcriptomic models which accurately classify samples as adenocarcinoma, squamous cell carcinoma, or healthy tissue (F1 score of 0.97, outperforming previous state-of-the-art multimodal approaches). Combining these data modalities in a single pipeline not only improves the predictive accuracy, but also gives biological interpretations of the discriminative subtypes and ties the phenotypic patterns present in histopathology images to biological processes.
- arXiv · Learning from Imperfect Annotations. arXiv preprint, 2020
Many machine learning systems today are trained on large amounts of human-annotated data. Annotation tasks that require a high level of competency make data acquisition expensive, while the resulting labels are often subjective, inconsistent, and may contain a variety of human biases. To improve data quality, practitioners often collect multiple annotations per example and aggregate them before training models. Such a multi-stage approach results in redundant annotations and often produces imperfect "ground truth" that can limit the accuracy of the trained models. We propose a new end-to-end framework that enables us to: (i) merge the aggregation step with model training, allowing deep learning systems to learn to predict ground truth estimates directly from the available data, and (ii) model the difficulty of examples and learn representations of the annotators that let us estimate and account for their competencies. Our approach is general and has many applications, including training more accurate models on crowdsourced data, ensemble learning, and estimating classifier accuracy from unlabeled data. We conduct an extensive experimental evaluation of our method on 5 crowdsourcing datasets of varied difficulty and show accuracy gains of up to 25% over the current state-of-the-art approaches for aggregating annotations, as well as significant reductions in the required annotation redundancy.
- On the Complexity of Exploration in Goal-Driven Navigation. In Relational Representation Learning Workshop, NeurIPS, 2018
Building agents that can explore their environments intelligently is a challenging open problem. In this paper, we take a step towards understanding how a hierarchical design of the agent’s policy can affect its exploration capabilities. First, we design EscapeRoom environments, where the agent must figure out how to navigate to the exit by accomplishing a number of intermediate tasks (subgoals), such as finding keys or opening doors. Our environments are procedurally generated and vary in complexity, which can be controlled by the number of subgoals and the relationships between them. Next, we propose to measure the complexity of each environment by constructing dependency graphs between the goals and analytically computing hitting times of a random walk on the graph. We empirically evaluate Proximal Policy Optimization (PPO) with sparse and shaped rewards, a variation of policy sketches, and a hierarchical version of PPO (called HiPPO) akin to h-DQN. We show that the analytically estimated hitting time in goal dependency graphs is an informative metric of environment complexity. We conjecture that this result should also hold for environments other than navigation. Finally, we show that solving environments beyond a certain level of complexity requires hierarchical approaches.
- Contextual Explanation Networks Enable Integrated Analysis of Imaging and Genomic Data. In 26th Conference on Intelligent Systems for Molecular Biology (ISMB), 2018
A fundamental goal of precision medicine is to use multiple types of data to generate a high-quality understanding of a patient’s disease. For data types with well-understood effects on the outcome of interest, it is straightforward to incorporate the data directly into a probabilistic graphical model. Complex data types, however, can be difficult to incorporate. To overcome this challenge, we present a deep learning framework that incorporates complex covariate data (e.g., histology images) into an interpretable graphical model framework by using Contextual Explanation Networks to encode the covariate data into sample-specific “contexts” that determine the interpretable parameters of a graphical model. We apply the framework to a dataset of Kidney Renal Clear Cell Carcinoma patients and find that using imaging contexts improves the accuracy of predicting case/control status with logistic regression from a baseline of 95% to over 99%. Finally, we investigate the learned contexts to uncover molecular subtypes of the disease.
- Evaluating Generalization in Multiagent Systems using Agent-Interaction Graphs. In International Conference on Autonomous Agents and Multiagent Systems (AAMAS), 2018
Learning from interactions between agents is a key component for inference in multiagent systems. Depending on the downstream task, there could be multiple criteria for evaluating the generalization performance of learning. In this work, we propose a novel framework for evaluating generalization in multiagent systems based on agent-interaction graphs. An agent-interaction graph models agents as nodes and interactions as hyper-edges between participating agents. Using this abstract data structure, we define three notions of generalization for principled evaluation of learning in multiagent systems.
- NeurIPS Spotlight · The Intriguing Properties of Model Explanations. In Symposium on Interpretable Machine Learning, NeurIPS, 2017
Linear approximations to the decision boundary of a complex model have become one of the most popular tools for interpreting predictions. In this paper, we study such linear explanations produced either post-hoc by a few recent methods or generated along with predictions with contextual explanation networks (CENs). We focus on two questions: (i) whether linear explanations are always consistent or can be misleading, and (ii) when integrated into the prediction process, whether and how explanations affect the performance of the model. Our analysis sheds more light on certain properties of explanations produced by different methods and suggests that learning models that explain and predict jointly is often advantageous.
- NeurIPS Spotlight · Personalized Survival Prediction with Contextual Explanation Networks. In Machine Learning for Healthcare Workshop, NeurIPS, 2017
Accurate and transparent prediction of cancer survival times at the level of individual patients can inform and improve patient care and treatment practices. In this paper, we design a model that concurrently learns to accurately predict patient-specific survival distributions and to explain its predictions in terms of patient attributes such as clinical tests or assessments. Our model is flexible: it is based on a recurrent network and can handle various data modalities, including temporal measurements, yet it constructs and uses simple explanations in the form of patient- and time-specific linear regression. In our analysis, we use two publicly available datasets and show that our networks outperform a number of baselines in prediction while providing a way to inspect the reasons behind each prediction.
- Scalable GP-LSTMs with Semi-Stochastic Gradients. In Bayesian Deep Learning Workshop, NeurIPS, 2016
Many applications in speech, robotics, finance, and biology deal with sequential data, where ordering matters and recurrent structures are common. However, this structure cannot be easily captured by standard kernel functions. To model such structure, we propose expressive closed-form kernel functions for Gaussian processes. The resulting model, GP-LSTM, fully encapsulates the inductive biases of long short-term memory (LSTM) recurrent networks, while retaining the non-parametric probabilistic advantages of Gaussian processes. We learn the properties of the proposed kernels by optimizing the Gaussian process marginal likelihood using a new provably convergent semi-stochastic gradient procedure and exploit the structure of these kernels for fast and scalable training and prediction.
- Learning Diverse Overcomplete Dictionaries via Determinantal Priors. In Geometry in Machine Learning Workshop, ICML, 2016
Sparse coding represents signals as sparse linear combinations of atoms drawn from a learned, overcomplete dictionary, and the quality of the resulting representations depends strongly on the dictionary: highly correlated atoms tend to yield redundant and unstable codes. In this work, we encourage diversity among the dictionary atoms by placing a determinantal point process (DPP) prior over them, which assigns higher probability to subsets of atoms that are mutually dissimilar. We discuss learning and inference for overcomplete dictionaries under this determinantal prior and how it promotes more diverse, better-conditioned representations.
- Neural Generative Models with Stochastic Synapses Capture Richer Representations. M. Al-Shedivat, E. Neftci, and G. Cauwenberghs. In Computational and Systems Neuroscience (Cosyne), 2015
Stochasticity in synaptic vesicle release is one of the major sources of noise in the brain. While the concept of cellular neural noise gave rise to computational models of biological learning such as deep belief networks and algorithms such as spike-sampling, the functional implications of synaptic stochasticity for learning remain poorly understood and are often reduced to filtering, decorrelation, or regularization. In this work, we approach synaptic stochasticity from the perspective of representation learning, showing that it can improve fast concept learning in situations where labeled data is scarce. We study a two-layer neural network that implements a Boltzmann machine with probabilistic connections, where noisy synaptic strengths lead to a notion of stochastic ensembles of generative models. We demonstrate how such ensembles can be tuned using variational inference and, by analytically marginalizing synaptic noise for the Bernoulli and Gaussian cases, how stochastic optimization based on Gibbs sampling can be used to learn the synaptic distributions.
- FiO/LS · Shaping of Femtosecond Laser Pulses with Plasmonic Crystals. M. Shcherbakov, P. Vabishchevich, V. Zubjuk, M. Al-Shedivat, T. Dolgova, and A. Fedyanin. In Frontiers in Optics, 2013
Temporal shaping of femtosecond laser pulses reflected from a one-dimensional plasmonic crystal—a commercially available polymer grating coated with a silver film—is demonstrated experimentally by time-resolved measurements of the intensity correlation function. The shaping is achieved by exciting surface plasmon-polaritons with a lifetime comparable to the 130 fs laser pulse duration. The measurements demonstrate flexible shaping of femtosecond pulses by delaying, advancing, splitting, broadening, compressing, and changing the topological properties of the pulse with plasmonic crystals.
other theses
- M.Sc.
- B.Sc. · Femtosecond Dynamics of Light Polarization Conversion by Chiral Plasmonic Metamaterials. M. Al-Shedivat. Lomonosov Moscow State University, 2013