
JaxGaussianProcesses/jax-decision-making

jax-decision-making


jax-decision-making is an early-stage library which aims to provide algorithms for a variety of sequential decision-making problems. It currently provides implementations of several acquisition/utility functions for Bayesian optimisation, including probability of improvement, expected improvement and Thompson sampling. The implementations are built upon the JAX library, enabling automatic differentiation, vectorisation, and just-in-time (JIT) compilation for high performance on CPUs, GPUs, and TPUs. This allows for efficient research, development, and deployment of decision-making agents.
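To show what two of these acquisition functions compute, here is a minimal sketch of closed-form probability of improvement (PI) and expected improvement (EI) under a Gaussian posterior. It is written in plain JAX and does not use this library's API. It assumes a minimisation convention, a posterior mean and a strictly positive posterior standard deviation at each candidate, and `best_y`, the best value observed so far. The function names are hypothetical. The sketch also uses the JAX transformations mentioned above: `jax.vmap` to batch over candidates, `jax.jit` to compile, and `jax.grad` for derivatives.

```python
import jax
import jax.numpy as jnp
from jax.scipy.stats import norm


def probability_of_improvement(mean, std, best_y):
    """Hypothetical PI for minimisation: P(f(x) < best_y) under N(mean, std^2).

    Assumes std > 0.
    """
    z = (best_y - mean) / std
    return norm.cdf(z)


def expected_improvement(mean, std, best_y):
    """Hypothetical closed-form EI for minimisation: E[max(best_y - f(x), 0)].

    Assumes std > 0.
    """
    improvement = best_y - mean
    z = improvement / std
    return improvement * norm.cdf(z) + std * norm.pdf(z)


# Vectorise EI over a batch of (mean, std) pairs sharing one incumbent best_y,
# then JIT-compile the batched function.
batched_ei = jax.jit(jax.vmap(expected_improvement, in_axes=(0, 0, None)))

# Gradient of EI with respect to the posterior mean via automatic
# differentiation; analytically this equals -Phi(z).
dei_dmean = jax.grad(expected_improvement, argnums=0)
```

For example, `batched_ei(jnp.array([0.0, 1.0, -0.5]), jnp.array([1.0, 0.5, 0.2]), 0.0)` scores three candidates in one call, and the candidate with the largest EI would be the natural next query.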

Initially, jax-decision-making was created as a sub-package within the GPJax library, but it has now been separated into its own package. Currently, Gaussian processes (implemented in GPJax) are the primary surrogate model around which the library has been developed. Nonetheless, now that the packages have been decoupled, we are happy to extend support to alternative surrogate models, such as Bayesian neural networks. Please feel free to open an issue for any feature you would like to see implemented. These might include:

  • Support for additional acquisition functions/tricks for Bayesian optimisation and experimental design (e.g. trust-regions for high-dimensional problems).
  • Support for alternative surrogate models beyond Gaussian processes (e.g. Bayesian neural networks).
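Because Gaussian processes are currently the primary surrogate model, the following is a minimal sketch of Thompson sampling with an exact GP posterior. It is written in plain JAX, independent of GPJax and of this library's API, and it is not necessarily how the library draws samples. Everything in it is an assumption for illustration: 1-D inputs, a zero-mean prior, a squared-exponential kernel with fixed hyperparameters, a discrete grid of candidate points, and the hypothetical helper names. It draws one joint posterior sample over the grid and proposes the grid point that minimises that sample.

```python
import jax
import jax.numpy as jnp
from jax.scipy.linalg import cho_solve, solve_triangular


def rbf_kernel(x1, x2, lengthscale=0.3, variance=1.0):
    """Squared-exponential kernel between 1-D input arrays (hypothetical defaults)."""
    sq_dist = (x1[:, None] - x2[None, :]) ** 2
    return variance * jnp.exp(-0.5 * sq_dist / lengthscale**2)


def gp_posterior(x_train, y_train, x_test, noise=1e-3):
    """Exact GP posterior mean and covariance at x_test, assuming a zero prior mean."""
    k_tt = rbf_kernel(x_train, x_train) + noise * jnp.eye(x_train.shape[0])
    k_ts = rbf_kernel(x_train, x_test)
    k_ss = rbf_kernel(x_test, x_test)
    chol = jnp.linalg.cholesky(k_tt)  # lower-triangular factor of K_tt
    alpha = cho_solve((chol, True), y_train)  # K_tt^{-1} y
    mean = k_ts.T @ alpha
    v = solve_triangular(chol, k_ts, lower=True)  # L^{-1} K_ts
    cov = k_ss - v.T @ v
    return mean, cov


def thompson_sample_argmin(key, x_train, y_train, x_candidates, jitter=1e-6):
    """Draw one joint posterior sample on the candidates and return its minimiser."""
    mean, cov = gp_posterior(x_train, y_train, x_candidates)
    cov = cov + jitter * jnp.eye(x_candidates.shape[0])
    # SVD-based sampling tolerates the near-singular covariance of a dense grid.
    sample = jax.random.multivariate_normal(key, mean, cov, method="svd")
    return x_candidates[jnp.argmin(sample)]
```

Calling `thompson_sample_argmin(jax.random.PRNGKey(0), x_train, y_train, jnp.linspace(0.0, 1.0, 50))` returns one proposed query point. Using a fresh key for each round gives the randomised exploration that characterises Thompson sampling.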


Installation

pip install jax-decision-making

Contributing

Please refer to the contributing guidelines file for details on how to contribute to the project.

License

jax-decision-making is distributed under the terms of the MIT license.

About

Sequential decision-making in JAX.
