import math
import random
import numpy as np
from numpy import ndarray
import scipy.optimize as optimize
import scipy.integrate as integrate
import jax
from jax import hessian, jit, value_and_grad
import jax.numpy as jnp
from .utils import fmin_newton
from tqdm import tqdm
jax.config.update("jax_enable_x64", True)
class BayesBAR:
"""Bayesian Bennett acceptance ratio method"""
def __init__(
self,
energy: ndarray,
num_conf: ndarray,
sample_size: int = 1000,
method: str = "Newton",
verbose: bool = False,
):
"""
Initialize the BayesBAR class.
Parameters
----------
energy : ndarray
An energy matrix in reduced units. Its shape should be (2, N), where N is the total number of samples from the two states; the first num_conf[0] columns must come from state 0 and the remaining num_conf[1] columns from state 1.
num_conf : ndarray
Number of configurations (samples) drawn from each state. Its shape should be (2,).
sample_size : int, optional
The number of samples from the posterior distribution. Defaults to 1000.
method : str, optional
Optimization method for finding the mode. Options are "Newton" or "L-BFGS-B". Defaults to "Newton".
verbose : bool, optional
Whether to print running information. Defaults to False.
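
Examples
--------
A minimal usage sketch; the random energies below are illustrative
placeholders rather than data from a real simulation:

>>> import numpy as np
>>> energy = np.random.randn(2, 200)  # reduced energies, shape (2, n_0 + n_1)
>>> num_conf = np.array([100, 100])   # n_0 samples from state 0, n_1 from state 1
>>> bar = BayesBAR(energy, num_conf, sample_size=500)
>>> dF, dF_std = bar.DeltaF_mode, bar.DeltaF_std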
"""
assert (
energy.shape[0] == 2
), f"The energy matrix has {energy.shape[0]} rows. It has to have 2 rows."
assert (
len(num_conf) == 2
), f"The size of num_conf is {len(num_conf)}. It has to be 2."
assert (
energy.shape[1] == num_conf[0] + num_conf[1]
), f"The energy matrix has {energy.shape[1]} columns, but the sum of num_conf is {num_conf[0] + num_conf[1]}. The number of columns in the energy matrix must equal the sum of num_conf."
self.energy = jnp.float64(energy)
self.num_conf = jnp.int32(num_conf)
self.n_0, self.n_1 = self.num_conf
self.n = jnp.sum(self.num_conf)
## find the posterior mode which corresponds to the BAR solution
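## (With the uniform prior used below, the posterior mode coincides with the maximum-likelihood BAR estimate.)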
dF_init = jnp.mean(self.energy[1] - self.energy[0]).reshape((-1,))
if method == "Newton":
f = jit(value_and_grad(_compute_loss))
hess = jit(hessian(_compute_loss))
res = fmin_newton(f, hess, dF_init, args=(self.energy, self.num_conf), verbose=verbose)
elif method == "L-BFGS-B":
options = {"disp": verbose, "gtol": 1e-8}
f = jit(value_and_grad(_compute_loss))
res = optimize.minimize(
lambda x: [np.array(r) for r in f(x, self.energy, self.num_conf)],
dF_init,
jac=True,
method="L-BFGS-B",
tol=1e-12,
options=options,
)
else:
raise ValueError(
f"Method {method} is not supported. It must be 'Newton' or 'L-BFGS-B'."
)
self.dF_mode = res["x"]
## compute posterior mean and standard deviation using numerical integration
self.dF_mean, self.dF_std = _compute_posterior_mean_and_std(
self.dF_mode, self.energy, self.num_conf
)
## sampling from the posterior distribution
self.sample_size = sample_size
if self.sample_size > 0:
self.dF_samples = _sample_from_posterior(
self.dF_mode,
self.dF_std,
self.energy,
self.num_conf,
self.sample_size,
)
## compute asymptotic standard deviation
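## The asymptotic variance is -1/H - 1/n_0 - 1/n_1, where H is the Hessian of the
## log-posterior at the mode: H = -sum_i f_i (1 - f_i) with f_i = sigmoid(dF_mode - du_i).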
H = hessian(_compute_logp)(self.dF_mode, self.energy, self.num_conf)
_dF_var_asymptotic = -1.0 / H - 1.0 / self.n_0 - 1.0 / self.n_1
self._dF_std_asymptotic = jnp.reshape(jnp.sqrt(_dF_var_asymptotic), ())
## Bennett's uncertainty
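## Bennett's variance estimate: <f0^2> / (n_0 <f0>^2) + <f1^2> / (n_1 <f1>^2) - 1/n_0 - 1/n_1,
## where f0 and f1 are the Fermi (logistic) functions of the shifted energy differences
## evaluated on the samples from state 0 and state 1, respectively (Bennett, 1976).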
du = (
self.energy[1, :]
- self.energy[0, :]
- jnp.log(self.n_1 / self.n_0)
- self.dF_mode
)
f0 = jax.nn.sigmoid(-du[0 : self.n_0])
f1 = jax.nn.sigmoid(du[self.n_0 :])
_dF_var_bennett = (
jnp.mean(f0**2) / (self.n_0 * jnp.mean(f0) ** 2)
+ jnp.mean(f1**2) / (self.n_1 * jnp.mean(f1) ** 2)
- 1.0 / self.n_0
- 1.0 / self.n_1
)
self._dF_std_bennett = jnp.sqrt(_dF_var_bennett)
@property
def DeltaF_mode(self) -> ndarray:
"""The posterior mode of the free energy difference."""
return np.array(jax.device_put(self.dF_mode, jax.devices("cpu")[0]))
@property
def DeltaF_mean(self) -> ndarray:
"""The posterior mean of the free energy difference."""
return np.array(jax.device_put(self.dF_mean, jax.devices("cpu")[0]))
@property
def DeltaF_std(self) -> ndarray:
"""The posterior standard deviation of the free energy difference."""
return np.array(jax.device_put(self.dF_std, jax.devices("cpu")[0]))
@property
def DeltaF_samples(self) -> ndarray:
"""The samples from the posterior distribution of the free energy difference."""
return np.array(jax.device_put(self.dF_samples, jax.devices("cpu")[0]))
@jit
def _compute_logp(dF, energy, num_conf):
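"""Unnormalized log-posterior of dF: n_1 * dF - sum_i log(1 + exp(dF - du_i)),
where du_i = energy[1, i] - energy[0, i] - log(n_1 / n_0); additive terms
independent of dF are dropped.
"""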
n_0, n_1 = num_conf
du = energy[1, :] - energy[0, :] - jnp.log(n_1 / n_0)
logp = n_1 * dF - jnp.logaddexp(jnp.zeros(1), dF - du).sum()
return jnp.reshape(logp, ())
@jit
def _compute_loss(dF, energy, num_conf):
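"""Negative log-posterior per sample; minimizing this over dF yields the posterior mode."""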
logp = _compute_logp(dF, energy, num_conf)
loss = -logp / num_conf.sum()
return loss
@jit
def _compute_posterior(dF, energy, num_conf, dF_mode):
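"""Unnormalized posterior density of dF, scaled by its value at the mode for numerical stability."""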
logp_max = _compute_logp(dF_mode, energy, num_conf)
## The prior is chosen to be the (improper) uniform distribution over R,
## so the posterior is proportional to the likelihood. The log-likelihood
## is shifted down by the constant logp_max so that the returned density
## equals 1 at the mode, which keeps the numerical quadrature well behaved.
return jnp.exp(_compute_logp(dF, energy, num_conf) - logp_max)
def _compute_posterior_mean_and_std(dF_mode, energy, num_conf):
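"""Posterior mean and standard deviation of dF computed by numerical quadrature of the unnormalized posterior."""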
## compute the normalization constant Z
def f(dF):
return _compute_posterior(dF, energy, num_conf, dF_mode)
Z, Z_err = integrate.quad(jit(f), -np.inf, np.inf)
## posterior mean
def f(dF):
return dF * _compute_posterior(dF, energy, num_conf, dF_mode)
dF, dF_err = integrate.quad(jit(f), -np.inf, np.inf)
dF_mean = dF / Z
## posterior standard deviation
def f(dF):
return (dF - dF_mean) ** 2 * _compute_posterior(dF, energy, num_conf, dF_mode)
dF_var, dF_var_err = integrate.quad(jit(f), -np.inf, np.inf)
dF_var = dF_var / Z
dF_std = math.sqrt(dF_var)
return dF_mean, dF_std
def _sample_from_posterior(dF_mode, dF_std, energy, num_conf, size):
"""sample from the posterior using slice sampling
(https://www.jstor.org/stable/3448413 )
"""
## dF_std is used as the estimate of the typical size of a slice.
width = dF_std
## the size of a slice will be limited to max_size*width
max_size = 3
## start from the posterior mode
x0 = dF_mode
samples = [x0]
for _ in tqdm(range(size - 1)):
## sample the auxiliary random variable
logp = _compute_logp(x0, energy, num_conf)
z = logp - random.expovariate(1.0)
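## Subtracting an Exp(1) draw from logp is equivalent to drawing a height y ~ Uniform(0, p(x0)) and setting z = log(y).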
## find the slice interval using the "stepping out" procedure
u = random.uniform(0.0, 1.0)
L = x0 - u * width
R = L + width
v = random.uniform(0.0, 1.0)
J = math.floor(v * max_size)
K = max_size - 1 - J
while J > 0 and z < _compute_logp(L, energy, num_conf):
L = L - width
J = J - 1
while K > 0 and z < _compute_logp(R, energy, num_conf):
R = R + width
K = K - 1
## sampling from the interval using the "shrinkage" procedure
while True:
x1 = random.uniform(L, R)
if z < _compute_logp(x1, energy, num_conf):
samples.append(x1)
x0 = x1
break
if x1 < x0:
L = x1
else:
R = x1
return jnp.array(samples).reshape(-1)