Source code for bayesmbar.bayesmbar

from functools import partial
from typing import Literal

import numpy as np
from numpy.typing import NDArray
import jax

import jax.numpy as jnp
from jax import hessian, jit, value_and_grad
from jax import random
import blackjax
from optax import sgd
import optax
from .utils import (
    _solve_mbar,
    fmin_newton,
    _compute_log_likelihood_of_dF,
    _compute_log_likelihood_of_F,
    _compute_loss_likelihood_of_dF,
)

jax.config.update("jax_enable_x64", True)
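# Note: without the x64 flag above, the float64 arrays created in
# BayesMBAR.__init__ (e.g. the jnp.float64 cast of the energy matrix) would
# silently fall back to 32-bit precision.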


class BayesMBAR:
    """Bayesian Multistate Bennett Acceptance Ratio (BayesMBAR) method"""

    def __init__(
        self,
        energy: np.ndarray,
        num_conf: np.ndarray,
        prior: Literal["uniform", "normal"] = "uniform",
        mean: Literal["constant", "linear", "quadratic"] = "constant",
        kernel: Literal["SE", "Matern52", "Matern32", "RQ"] = "SE",
        state_cv: np.ndarray = None,
        sample_size: int = 1000,
        warmup_steps: int = 500,
        optimize_steps: int = 10000,
        verbose: bool = True,
        random_seed: int = 0,
        method: Literal["Newton", "L-BFGS-B"] = "Newton",
    ) -> None:
        """
        Args:
            energy (np.ndarray): Energy matrix of shape (m, n), where m is the number
                of states and n is the number of configurations.
            num_conf (np.ndarray): Number of configurations in each state. It is a 1D
                array of length m.
            prior (str, optional): Prior distribution of dF. It can be either
                "uniform" or "normal". Defaults to "uniform".
            mean (str, optional): Mean function of the prior. It can be "constant",
                "linear", or "quadratic". Defaults to "constant".
            kernel (str, optional): Kernel function of the prior. It can be "SE",
                "Matern52", "Matern32", or "RQ". Defaults to "SE".
            state_cv (np.ndarray, optional): State collective variables. It is a 2D
                array of shape (m, d), where m is the number of states and d is the
                dimension of the collective variables. Defaults to None.
            sample_size (int, optional): Number of samples drawn from the posterior
                distribution. Defaults to 1000.
            warmup_steps (int, optional): Number of warmup steps used to find the
                step size and mass matrix of the NUTS sampler. Defaults to 500.
            optimize_steps (int, optional): Number of optimization steps used to
                learn the hyperparameters when normal priors are used.
                Defaults to 10000.
            verbose (bool, optional): Whether to print the progress bar for the
                optimization and sampling. Defaults to True.
            random_seed (int, optional): Random seed. Defaults to 0.
            method (str, optional): Optimization method used to solve the MBAR
                equation for the mode of the likelihood. It can be either "Newton"
                or "L-BFGS-B". Defaults to "Newton".
        """

        self._energy = jnp.float64(energy)
        self._num_conf = jnp.int32(num_conf)
        self._prior = prior
        self._mean_name = mean
        self._kernel_name = kernel

        if state_cv is not None:
            self._state_cv = state_cv[1:]

        self._sample_size = sample_size
        self._warmup_steps = warmup_steps
        self._optimize_steps = optimize_steps
        self._verbose = verbose
        self._method = method

        self._rng_key = jax.random.PRNGKey(random_seed)

        self._m = self._energy.shape[0]
        self._n = self._energy.shape[1]

        # We first compute the mode estimate based on the likelihood because it is
        # used for both the uniform and normal priors. The mode estimate based on
        # the likelihood is the solution to the MBAR equation.
        print("Solve for the mode of the likelihood")
        dF_init = jnp.zeros((self._m - 1,))
        dF = _solve_mbar(
            dF_init, self._energy, self._num_conf, self._method, self._verbose
        )

        # f = jit(value_and_grad(_compute_loss_likelihood_of_dF))
        # hess = jit(hessian(_compute_loss_likelihood_of_dF))
        # res = fmin_newton(f, hess, dF_init, args=(self._energy, self._num_conf))
        # dF = res["x"]

        self._dF_mode_ll = dF

        # Sample dF based on the likelihood. When the uniform prior is used, the
        # posterior distribution of dF is the same as the likelihood function.
        # Therefore, these samples are also samples from the posterior distribution
        # of dF when the uniform prior is used.
        print("=====================================================")
        print("Sample from the likelihood")
        self._rng_key, subkey = random.split(self._rng_key)

        def logdensity(dF):
            return _compute_log_likelihood_of_dF(dF, self._energy, self._num_conf)

        self._dF_samples_ll = _sample_from_logdensity(
            subkey,
            self._dF_mode_ll,
            logdensity,
            self._warmup_steps,
            self._sample_size,
            self._verbose,
        )

        ## compute the mean, covariance, and precision of dF based on the samples
        ## from the likelihood
        self._dF_mean_ll = jnp.mean(self._dF_samples_ll, axis=0)
        self._dF_cov_ll = jnp.cov(self._dF_samples_ll.T)

        L = jnp.linalg.cholesky(self._dF_cov_ll)
        L_inv = jax.scipy.linalg.solve_triangular(L, jnp.eye(L.shape[0]), lower=True)
        self._dF_prec_ll = L_inv.T.dot(L_inv)
        # self._dF_prec_ll = jnp.linalg.inv(self._dF_cov_ll)

        self._F_mode_ll = _dF_to_F(self._dF_mode_ll, self._num_conf)
        self._F_samples_ll = _dF_to_F(self._dF_samples_ll, self._num_conf)
        self._F_mean_ll = jnp.mean(self._F_samples_ll, axis=0)
        self._F_cov_ll = jnp.cov(self._F_samples_ll.T)

        ## We are done here if the prior is uniform. When the normal prior is used,
        ## we need to learn the hyperparameters of the prior and then sample dF
        ## from the posterior distribution of dF.
        if self._prior == "normal":
            _data = {
                "energy": self._energy,
                "num_conf": self._num_conf,
                "dF_mean_ll": self._dF_mean_ll,
                "dF_prec_ll": self._dF_prec_ll,
                "state_cv": self._state_cv,
            }

            ## mean function of the prior
            if self._mean_name == "constant":
                self.mean_order = 0
                self.mean = partial(_mean, order=self.mean_order)
            elif self._mean_name == "linear":
                self.mean_order = 1
                self.mean = partial(_mean, order=self.mean_order)
            elif self._mean_name == "quadratic":
                self.mean_order = 2
                self.mean = partial(_mean, order=self.mean_order)

            ## kernel function of the prior
            if self._kernel_name == "SE":
                self.kernel = _kernel_SE
            elif self._kernel_name == "Matern52":
                self.kernel = _kernel_Matern52
            elif self._kernel_name == "Matern32":
                self.kernel = _kernel_Matern32
            elif self._kernel_name == "RQ":
                self.kernel = _kernel_RQ

            ## initialize the hyperparameters based on the mode of the likelihood
            params = _init_params(
                self.mean_order,
                self._kernel_name,
                self._dF_mode_ll,
                self._state_cv,
                self._num_conf,
            )
            raw_params = _params_to_raw(params)

            ## optimize the hyperparameters
            self._rng_key, subkey = random.split(self._rng_key)
            optimizer = sgd(learning_rate=1e-3, momentum=0.9, nesterov=True)
            opt_state = optimizer.init(raw_params)

            @partial(jit, static_argnames=["mean", "kernel"])
            def step(key, raw_params, opt_state, mean, kernel, data):
                loss, grads = _compute_elbo_loss(key, raw_params, mean, kernel, data)
                update, opt_state = optimizer.update(grads, opt_state)
                raw_params = optax.apply_updates(raw_params, update)
                return loss, raw_params, opt_state

            for i in range(optimize_steps):
                loss, raw_params, opt_state = step(
                    subkey, raw_params, opt_state, self.mean, self.kernel, _data
                )
                self._rng_key, subkey = random.split(self._rng_key)
                if i % 100 == 0:
                    params = _params_from_raw(raw_params)
                    print(f"step: {i:>10d}, loss: {loss:10.4f}", _print_params(params))

            self._params = _params_from_raw(raw_params)
            self._dF_mean_prior = self.mean(self._params["mean"], self._state_cv)
            self._dF_cov_prior = self.kernel(self._params["kernel"], self._state_cv)
            self._dF_prec_prior = jnp.linalg.inv(self._dF_cov_prior)

            ## solve for the mode of the posterior
            f = jit(value_and_grad(_compute_loss_joint_likelihood_of_dF))
            hess = jit(hessian(_compute_loss_joint_likelihood_of_dF))
            res = fmin_newton(
                f,
                hess,
                self._dF_mode_ll,
                args=(
                    self._energy,
                    self._num_conf,
                    self._dF_mean_prior,
                    self._dF_prec_prior,
                ),
            )
            self._dF_mode_posterior = res["x"]

            ## sample dF from the posterior
            def logdensity(dF):
                return _compute_log_joint_likelihood_of_dF(
                    dF,
                    self._energy,
                    self._num_conf,
                    self._dF_mean_prior,
                    self._dF_prec_prior,
                )

            self._rng_key, subkey = random.split(self._rng_key)
            self._dF_samples_posterior = _sample_from_logdensity(
                subkey,
                self._dF_mode_posterior,
                logdensity,
                self._warmup_steps,
                self._sample_size,
                self._verbose,
            )

            self._dF_mean_posterior = jnp.mean(self._dF_samples_posterior, axis=0)
            self._dF_cov_posterior = jnp.cov(self._dF_samples_posterior.T)
            self._dF_prec_posterior = jnp.linalg.inv(self._dF_cov_posterior)

            self._F_mode_posterior = _dF_to_F(self._dF_mode_posterior, self._num_conf)
            self._F_samples_posterior = _dF_to_F(
                self._dF_samples_posterior, self._num_conf
            )
            self._F_mean_posterior = jnp.mean(self._F_samples_posterior, axis=0)
            self._F_cov_posterior = jnp.cov(self._F_samples_posterior.T)

    @property
    def F_mode(self) -> NDArray:
        r"""The posterior mode estimate of the free energies of the states under the
        constraint that :math:`\sum_{k=1}^{M} N_k F_k = 0`, where :math:`N_k` and
        :math:`F_k` are the number of conformations and the free energy of the k-th
        state, respectively."""
        if self._prior == "uniform":
            F_mode = self._F_mode_ll
        elif self._prior == "normal":
            F_mode = self._F_mode_posterior
        return np.array(jax.device_put(F_mode, jax.devices("cpu")[0]))

    @property
    def F_mean(self) -> NDArray:
        r"""The posterior mean of the free energies of the states under the
        constraint that :math:`\sum_{k=1}^{M} N_k F_k = 0`, where :math:`N_k` and
        :math:`F_k` are the number of conformations and the free energy of the k-th
        state, respectively."""
        if self._prior == "uniform":
            F_mean = self._F_mean_ll
        elif self._prior == "normal":
            F_mean = self._F_mean_posterior
        return np.array(jax.device_put(F_mean, jax.devices("cpu")[0]))

    @property
    def F_cov(self) -> NDArray:
        r"""The posterior covariance matrix of the free energies of the states under
        the constraint that :math:`\sum_{k=1}^{M} N_k F_k = 0`, where :math:`N_k` and
        :math:`F_k` are the number of conformations and the free energy of the k-th
        state, respectively."""
        if self._prior == "uniform":
            F_cov = self._F_cov_ll
        elif self._prior == "normal":
            F_cov = self._F_cov_posterior
        F_cov = F_cov - jnp.diag(1.0 / self._num_conf) + 1.0 / self._num_conf.sum()
        ## if the diagonal elements of F_cov are negative, set them to 1e-4
        condition = jnp.eye(F_cov.shape[0], dtype=bool) & (F_cov <= 0)
        F_cov = jnp.where(condition, 1e-4, F_cov)
        return np.array(jax.device_put(F_cov, jax.devices("cpu")[0]))

    @property
    def F_std(self) -> NDArray:
        r"""The posterior standard deviation of the free energies of the states under
        the constraint that :math:`\sum_{k=1}^{M} N_k F_k = 0`, where :math:`N_k` and
        :math:`F_k` are the number of conformations and the free energy of the k-th
        state, respectively."""
        return np.array(jnp.sqrt(jnp.diag(self.F_cov)))

    @property
    def F_samples(self) -> NDArray:
        r"""Samples of the free energies of the states from the posterior
        distribution under the constraint that :math:`\sum_{k=1}^{M} N_k F_k = 0`,
        where :math:`N_k` and :math:`F_k` are the number of conformations and the
        free energy of the k-th state, respectively."""
        if self._prior == "uniform":
            F_samples = self._F_samples_ll
        elif self._prior == "normal":
            F_samples = self._F_samples_posterior
        return np.array(jax.device_put(F_samples, jax.devices("cpu")[0]))

    @property
    def DeltaF_mode(self) -> NDArray:
        r"""The posterior mode estimate of the free energy differences between
        states. DeltaF_mode[i,j] is the free energy difference between state
        :math:`j` and state :math:`i`, i.e., DeltaF_mode[i,j] = F_mode[j] - F_mode[i].
        """
        return self.F_mode[None, :] - self.F_mode[:, None]

    @property
    def DeltaF_mean(self) -> NDArray:
        r"""The posterior mean of the free energy differences between states.
        DeltaF_mean[i,j] is the free energy difference between state :math:`j` and
        state :math:`i`, i.e., DeltaF_mean[i,j] = F_mean[j] - F_mean[i].
        """
        return self.F_mean[None, :] - self.F_mean[:, None]

    @property
    def DeltaF_std(self) -> NDArray:
        r"""The posterior standard deviation of the free energy differences between
        states. DeltaF_std[i,j] is the posterior standard deviation of the free
        energy difference between state :math:`j` and state :math:`i`.
        """
        DeltaF_cov = (
            np.diag(self.F_cov)[:, None]
            + np.diag(self.F_cov)[None, :]
            - 2 * self.F_cov
        )
        return np.sqrt(DeltaF_cov)
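

# ----------------------------------------------------------------------------
# Overview of the module-level helpers used by the BayesMBAR class above:
#   _dF_to_F                  -- convert free energy differences (relative to
#                                state 0) into per-state free energies shifted
#                                to satisfy sum_k N_k * F_k = 0
#   _compute_log_joint_likelihood_of_dF / _compute_loss_joint_likelihood_of_dF
#                             -- log joint likelihood under the normal prior
#                                and its scaled negative, used for the MAP solve
#   _compute_elbo_loss        -- negative ELBO used to learn the prior's
#                                mean/kernel hyperparameters
#   _mean, _kernel_*          -- prior mean and kernel functions and the
#                                raw/constrained parameter transforms
#   _sample_from_logdensity   -- NUTS sampling with blackjax window adaptation
# ----------------------------------------------------------------------------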
def _dF_to_F(dF, num_conf):
    if dF.ndim == 1:
        F = jnp.concatenate([jnp.zeros((1,)), dF])
    elif dF.ndim == 2:
        F = jnp.concatenate([jnp.zeros((dF.shape[0], 1)), dF], axis=1)
    pi = num_conf / num_conf.sum()
    F = F - jnp.sum(pi * F, axis=-1, keepdims=True)
    return F


def _compute_loss_joint_likelihood_of_dF(dF, energy, num_conf, mean_prior, prec_prior):
    """
    Compute the loss function of dF based on the joint likelihood.

    The logarithm of the joint likelihood of dF scales with the number of
    configurations and the number of states. To make the loss function
    semi-invariant to the number of configurations and the number of states, we
    divide the log likelihood by the total number of configurations and the
    number of states. This helps to set a single tolerance for the optimization
    algorithm used to compute the MAP estimate.

    See the doc of _compute_log_joint_likelihood_of_dF for more details on the
    arguments and the return value.
    """
    loss = -_compute_log_joint_likelihood_of_dF(
        dF, energy, num_conf, mean_prior, prec_prior
    )
    loss = loss / (energy.shape[1] + energy.shape[0] - 1)
    return loss


def _compute_log_joint_likelihood_of_dF(dF, energy, num_conf, mean_prior, prec_prior):
    """
    Compute the logarithm of the joint likelihood of dF when the prior is a
    normal distribution. The joint likelihood is defined by the right hand side
    of Eq. (9) in the reference paper.

    Arguments:
        dF (jnp.ndarray): Free energy differences
        energy (jnp.ndarray): Energy matrix
        num_conf (jnp.ndarray): Number of configurations in each state
        mean_prior (jnp.ndarray): Mean of the prior
        prec_prior (jnp.ndarray): Precision matrix of the prior

    Returns:
        jnp.ndarray: Logarithm of the joint likelihood of dF
    """
    logp = -0.5 * jnp.dot(dF - mean_prior, jnp.dot(prec_prior, dF - mean_prior))
    logp = logp + _compute_log_likelihood_of_dF(dF, energy, num_conf)
    return logp


@partial(value_and_grad, argnums=1)
def _compute_elbo_loss(rng_key, raw_params, mean, kernel, data):
    energy = data["energy"]
    num_conf = data["num_conf"]
    state_cv = data["state_cv"]
    dF_prec_ll = data["dF_prec_ll"]
    dF_mean_ll = data["dF_mean_ll"]

    params = _params_from_raw(raw_params)
    mean_prior = mean(params["mean"], state_cv)
    cov_prior = kernel(params["kernel"], state_cv)

    mu_prop, cov_prop = _compute_proposal_dist(
        mean_prior, cov_prior, dF_mean_ll, dF_prec_ll
    )
    dFs = random.multivariate_normal(rng_key, mu_prop, cov_prop, shape=(1024,))
    Fs = jnp.concatenate([jnp.zeros((dFs.shape[0], 1)), dFs], axis=1)

    elbo = jax.vmap(_compute_log_likelihood_of_F, in_axes=(0, None, None))(
        Fs, energy, num_conf
    )
    elbo = jnp.mean(elbo)
    elbo = elbo - _compute_kl_divergence(mu_prop, cov_prop, mean_prior, cov_prior)
    return -elbo
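

# The two helpers below provide the closed-form Gaussian quantities used in the
# ELBO above.
#   _compute_kl_divergence(mu0, cov0, mu1, cov1) evaluates, via Cholesky factors,
#       KL(N(mu0, cov0) || N(mu1, cov1))
#           = 0.5 * [ tr(cov1^-1 cov0) + (mu1 - mu0)^T cov1^-1 (mu1 - mu0)
#                     - k + log det(cov1) - log det(cov0) ]
#     where k is the dimension of mu0.
#   _compute_proposal_dist combines the Gaussian approximation of the likelihood
#   (dF_mean_ll, dF_prec_ll) with the prior (mean_prior, cov_prior) by adding
#   precisions:
#       prec = dF_prec_ll + inv(cov_prior)
#       mu   = inv(prec) @ (dF_prec_ll @ dF_mean_ll + inv(cov_prior) @ mean_prior)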
range(params["kernel"]["length_scale"].shape[0]): res += f'{params["kernel"]["length_scale"][i].item():.4f}, ' if "alpha" in params["kernel"].keys(): res += f'alpha: {params["kernel"]["alpha"].item():.4f}, ' res += "dscale: " for i in range(params["kernel"]["dscale"].shape[0]): res += f'{params["kernel"]["dscale"][i].item():.4f}, ' return res def _expand(x, order): xx = [jnp.ones((x.shape[0], 1))] for i in range(order): xx.append(x ** (i + 1)) xx = jnp.concatenate(xx, axis=-1) return xx def _mean(params, x, order): xx = _expand(x, order) return jnp.sum(params["beta"] * xx, axis=-1) def _params_from_raw(raw_params): params = {} params["mean"] = raw_params["mean"] params["kernel"] = _kernel_params_from_raw(raw_params["kernel"]) return params def _params_to_raw(params): raw_params = {} raw_params["mean"] = params["mean"] raw_params["kernel"] = _kernel_params_to_raw(params["kernel"]) return raw_params def _init_mean_params(order, dF, state_cv): x = _expand(state_cv, order) beta = jnp.linalg.lstsq(x, dF, rcond=None)[0] params = {"beta": beta} return params def _init_params(mean_order, kernel_name, dF, state_cv, num_conf): params = {} params["mean"] = _init_mean_params(mean_order, dF, state_cv) params["kernel"] = _init_kernel_params(kernel_name, dF, state_cv, num_conf) return params def _init_kernel_params(kernel_name, dF, state_cv, num_conf): params = {} params["scale"] = jnp.std(dF) params["length_scale"] = (state_cv.max(0) - state_cv.min(0)) / state_cv.shape[0] params["dscale"] = jnp.ones_like(dF) * jnp.std(dF) if kernel_name == "RQ": params["alpha"] = jnp.ones((1,)) * 10 return params def _kernel_RQ(params, x): scale = params["scale"] length_scale = params["length_scale"] dscale = params["dscale"] alpha = params["alpha"] x = x / length_scale ds = _compute_squared_distance(x) return scale**2 * (1 + ds / (2 * alpha)) ** (-alpha) + dscale**2 * jnp.eye( ds.shape[0] ) def _kernel_SE(params, x): scale = params["scale"] length_scale = params["length_scale"] dscale = params["dscale"] x = x / length_scale ds = _compute_squared_distance(x) return scale**2 * jnp.exp(-0.5 * ds) + dscale**2 * jnp.eye(ds.shape[0]) def _kernel_Matern52(params, x): scale = params["scale"] length_scale = params["length_scale"] dscale = params["dscale"] x = x / length_scale ds = _compute_squared_distance(x) d = jnp.sqrt(ds + 1e-18) return scale**2 * (1 + jnp.sqrt(5.0) * d + 5.0 / 3.0 * ds) * jnp.exp( -jnp.sqrt(5.0) * d ) + dscale**2 * jnp.eye(ds.shape[0]) def _kernel_Matern32(params, x): scale = params["scale"] length_scale = params["length_scale"] dscale = params["dscale"] x = x / length_scale ds = _compute_squared_distance(x) d = jnp.sqrt(ds + 1e-18) return scale**2 * (1 + jnp.sqrt(3.0) * d) * jnp.exp( -jnp.sqrt(3.0) * d ) + dscale**2 * jnp.eye(ds.shape[0]) def _kernel_params_from_raw(raw_params): params = {} params["scale"] = jax.nn.softplus(raw_params["raw_scale"]) params["length_scale"] = jax.nn.softplus(raw_params["raw_length_scale"]) params["dscale"] = jax.nn.softplus(raw_params["raw_dscale"]) if "raw_alpha" in raw_params.keys(): params["alpha"] = jax.nn.softplus(raw_params["raw_alpha"]) return params def _kernel_params_to_raw(params): raw_params = {} raw_params["raw_scale"] = jnp.log(jnp.exp(params["scale"]) - 1) raw_params["raw_length_scale"] = jnp.log(jnp.exp(params["length_scale"]) - 1) raw_params["raw_dscale"] = jnp.log(jnp.exp(params["dscale"]) - 1) if "alpha" in params.keys(): raw_params["raw_alpha"] = jnp.log(jnp.exp(params["alpha"]) - 1) return raw_params def _compute_squared_distance(x): x1 = x[:, None, :] x2 
= x[None, :, :] return jnp.sum((x1 - x2) ** 2, axis=-1) def _sample_from_logdensity( rng_key, init_dF, logdensity, warmup_steps, num_samples, verbose ): ## warmup to find step size and mass matrix warmup = blackjax.window_adaptation( blackjax.nuts, logdensity, is_mass_matrix_diagonal=False, progress_bar=verbose, ) rng_key, subkey = random.split(rng_key) (state, parameters), _ = warmup.run(subkey, init_dF, num_steps=warmup_steps) print("Sample using the NUTS sampler") ## sample using nuts # ## Use the blackjax.util.run_inference_algorithm function to run the nuts algorithm # ## so that we can have a progress bar. It is a wrap of _sample_loop. # alg = blackjax.nuts(logdensity, **parameters) # _, states, _ = blackjax.util.run_inference_algorithm( # rng_key, state, alg, num_samples, progress_bar=verbose # ) ## sample using nuts rng_key, subkey = random.split(rng_key) kernel = blackjax.nuts(logdensity, **parameters).step states = _sample_loop(subkey, kernel, state, num_samples) return states.position def _sample_loop(rng_key, kernel, init_state, num_samples): @jax.jit def one_step(state, rng_key): state, _ = kernel(rng_key, state) return state, state keys = jax.random.split(rng_key, num_samples) _, states = jax.lax.scan(one_step, init_state, keys) return states
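

# ----------------------------------------------------------------------------
# Minimal usage sketch (illustrative only; the toy system, sample sizes, and
# names such as `fe`, `mu`, and `n_per_state` are made up for this example and
# are not part of the package). Two harmonic states u_k(x) = 0.5 * (x - mu_k)^2
# sampled at their own centers have identical partition functions, so the
# estimated free energy difference should be close to zero.
# ----------------------------------------------------------------------------
if __name__ == "__main__":
    rng = np.random.default_rng(0)
    mu = np.array([0.0, 1.0])  # centers of the two harmonic states
    n_per_state = 200

    # draw configurations from each state's Boltzmann distribution, N(mu_k, 1)
    x = np.concatenate([rng.normal(m, 1.0, n_per_state) for m in mu])
    # reduced energy of every configuration evaluated in every state: shape (2, 400)
    energy = 0.5 * (x[None, :] - mu[:, None]) ** 2
    num_conf = np.array([n_per_state, n_per_state])

    fe = BayesMBAR(
        energy,
        num_conf,
        prior="uniform",
        sample_size=200,
        warmup_steps=200,
        verbose=False,
    )
    print("F_mean:", fe.F_mean)  # posterior mean free energies
    print("F_std:", fe.F_std)  # posterior standard deviations
    print("DeltaF_mean[0, 1]:", fe.DeltaF_mean[0, 1])  # ~0 for this toy system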